From 1397d0000d5c7bc133f78b0145d4810145e3e3ff Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 2024 14:50:36 +0800 Subject: [PATCH 01/28] chore(deps): add faker --- api/poetry.lock | 61 +++++++++++++--------------------------------- api/pyproject.toml | 1 + 2 files changed, 18 insertions(+), 44 deletions(-) diff --git a/api/poetry.lock b/api/poetry.lock index 74c2ef5dc6..fbc8dd2764 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -932,10 +932,6 @@ files = [ {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"}, {file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"}, {file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"}, {file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"}, @@ -948,14 +944,8 @@ files = [ {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"}, {file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"}, {file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"}, - {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = 
"sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"}, - {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"}, {file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"}, @@ -966,24 +956,8 @@ files = [ {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"}, {file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"}, {file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"}, - {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"}, - {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = 
"sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"}, - {file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"}, - {file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"}, {file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"}, @@ -993,10 +967,6 @@ files = [ {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"}, {file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"}, {file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"}, {file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"}, @@ -1008,10 +978,6 @@ files = [ {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"}, {file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"}, {file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = 
"sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"}, {file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"}, @@ -1024,10 +990,6 @@ files = [ {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"}, {file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"}, {file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"}, {file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"}, @@ -1040,10 +1002,6 @@ files = [ {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"}, {file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"}, {file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"}, {file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"}, @@ -2453,6 +2411,21 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "faker" +version = "32.1.0" +description = "Faker is a Python package that generates fake data for you." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"}, + {file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"}, +] + +[package.dependencies] +python-dateutil = ">=2.4" +typing-extensions = "*" + [[package]] name = "fal-client" version = "0.5.6" @@ -11078,4 +11051,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "2ba4b464eebc26598f290fa94713acc44c588f902176e6efa80622911d40f0ac" +content-hash = "9e6481430de11a66c56737e25cdae6788c0ded2e9cc587c1def7fd35bc93a0d2" diff --git a/api/pyproject.toml b/api/pyproject.toml index 0633e9dd90..8d8872aa78 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -266,6 +266,7 @@ weaviate-client = "~3.21.0" optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" +faker = "^32.1.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" From 1fdaea29aa4e688a9a95ecd533376db3fe7130a6 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 2024 16:07:22 +0800 Subject: [PATCH 02/28] refactor(converter): simplify model credentials validation logic --- .../model_config/converter.py | 36 +++++++++---------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index a91b9f0f02..cdc82860c6 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager class ModelConfigConverter: @classmethod - def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity: + def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity: """ Convert app model config dict to entity. 
:param app_config: app config @@ -38,27 +38,23 @@ class ModelConfigConverter: ) if model_credentials is None: - if not skip_check: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - else: - model_credentials = {} + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - if not skip_check: - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_config.model, model_type=ModelType.LLM - ) + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_config.model, model_type=ModelType.LLM + ) - if provider_model is None: - model_name = model_config.model - raise ValueError(f"Model {model_name} not exist.") + if provider_model is None: + model_name = model_config.model + raise ValueError(f"Model {model_name} not exist.") - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") # model config completion_params = model_config.parameters @@ -76,7 +72,7 @@ class ModelConfigConverter: model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) - if not skip_check and not model_schema: + if not model_schema: raise ValueError(f"Model {model_name} not exist.") return ModelConfigWithCredentialsEntity( From 4f89214d89859b1d95568309f8a13622ca3ce02f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 2024 16:21:29 +0800 Subject: [PATCH 03/28] refactor: update stop parameter type to use Sequence instead of list --- api/core/model_manager.py | 2 +- .../model_runtime/callbacks/base_callback.py | 9 +++++---- .../__base/large_language_model.py | 20 +++++++++---------- api/core/workflow/nodes/llm/node.py | 2 +- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 059ba6c3d1..3424a7fa78 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -103,7 +103,7 @@ class ModelInstance: prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 6bd9325785..8870b34435 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from collections.abc import Sequence from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ 
-31,7 +32,7 @@ class Callback(ABC): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -60,7 +61,7 @@ class Callback(ABC): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ): @@ -90,7 +91,7 @@ class Callback(ABC): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -120,7 +121,7 @@ class Callback(ABC): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 5b6f96129b..8faeffa872 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import logging import re import time from abc import abstractmethod -from collections.abc import Generator, Mapping +from collections.abc import Generator, Mapping, Sequence from typing import Optional, Union from pydantic import ConfigDict @@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel): prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -212,7 +212,7 @@ if you are not sure about the structure. ) model_parameters.pop("response_format") - stop = stop or [] + stop = list(stop) if stop is not None else [] stop.extend(["\n```", "```\n"]) block_prompts = block_prompts.replace("{{block}}", code_block) @@ -408,7 +408,7 @@ if you are not sure about the structure. prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -479,7 +479,7 @@ if you are not sure about the structure. prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> Union[LLMResult, Generator]: @@ -601,7 +601,7 @@ if you are not sure about the structure. 
prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -647,7 +647,7 @@ if you are not sure about the structure. prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -694,7 +694,7 @@ if you are not sure about the structure. prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -742,7 +742,7 @@ if you are not sure about the structure. prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index eb4d1c9d87..7634b90dff 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -204,7 +204,7 @@ class LLMNode(BaseNode[LLMNodeData]): node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, ) -> Generator[NodeEvent, None, None]: db.session.close() From d9fa6f79be56733082f8ddef7fb43a4ae4ad9f3c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 2024 16:22:01 +0800 Subject: [PATCH 04/28] refactor: update jinja2_variables and prompt_config to use Sequence and add validators for None handling --- api/core/workflow/nodes/llm/entities.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index a25d563fe0..19a66087f7 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -39,7 +39,14 @@ class VisionConfig(BaseModel): class PromptConfig(BaseModel): - jinja2_variables: Optional[list[VariableSelector]] = None + jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list) + + @field_validator("jinja2_variables", mode="before") + @classmethod + def convert_none_jinja2_variables(cls, v: Any): + if v is None: + return [] + return v class LLMNodeChatModelMessage(ChatModelMessage): @@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - prompt_config: Optional[PromptConfig] = None + prompt_config: PromptConfig = Field(default_factory=PromptConfig) memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) + + @field_validator("prompt_config", mode="before") + @classmethod + def convert_none_prompt_config(cls, v: Any): + if v is None: + return PromptConfig() + return v From 229b146525a7ab57db03ec7cb37642832b010a56 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 
2024 16:22:15 +0800 Subject: [PATCH 05/28] feat(errors): add new error classes for unsupported prompt types and memory role prefix requirements --- api/core/workflow/nodes/llm/exc.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py index f858be2515..b5207d5573 100644 --- a/api/core/workflow/nodes/llm/exc.py +++ b/api/core/workflow/nodes/llm/exc.py @@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError): class NoPromptFoundError(LLMNodeError): """Raised when no prompt is found in the LLM configuration.""" + + +class NotSupportedPromptTypeError(LLMNodeError): + """Raised when the prompt type is not supported.""" + + +class MemoryRolePrefixRequiredError(LLMNodeError): + """Raised when memory role prefix is required for completion model.""" From 2106fc52667339e8ec5e37db30acac78215eaf69 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 2024 16:32:43 +0800 Subject: [PATCH 06/28] fix(tests): update Azure Rerank Model usage and clean imports --- .../model_runtime/azure_ai_studio/test_llm.py | 1 - .../model_runtime/azure_ai_studio/test_rerank.py | 14 +++----------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py index 85a4f7734d..b995077984 100644 --- a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py @@ -11,7 +11,6 @@ from core.model_runtime.entities.message_entities import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel -from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock @pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py index 466facc5ff..4d72327c0e 100644 --- a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py @@ -4,29 +4,21 @@ import pytest from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel +from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel def test_validate_credentials(): - model = AzureAIStudioRerankModel() + model = AzureRerankModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="azure-ai-studio-rerank-v1", credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, - query="What is the capital of the United States?", - docs=[ - "Carson City is the capital city of the American state of Nevada. At the 2010 United States " - "Census, Carson City had a population of 55,274.", - "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " - "are a political division controlled by the United States. 
Its capital is Saipan.", - ], - score_threshold=0.8, ) def test_invoke_model(): - model = AzureAIStudioRerankModel() + model = AzureRerankModel() result = model.invoke( model="azure-ai-studio-rerank-v1", From 93bbb194f2f9987dfeaa6a40cfcf65bb125b4b57 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 12:53:51 +0800 Subject: [PATCH 07/28] refactor(prompt): enhance type flexibility for prompt messages - Changed input type from list to Sequence for prompt messages to allow more flexible input types. - Improved compatibility with functions expecting different iterable types. --- api/core/prompt/utils/prompt_message_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 5eec5e3c99..aa175153bc 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import cast from core.model_runtime.entities import ( @@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode class PromptMessageUtil: @staticmethod - def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: + def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]: """ Prompt messages to prompt for saving. :param model_mode: model mode From 47e8a5d4d19e0c782ff3cdf3e6c853be4e599ef1 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 13:01:57 +0800 Subject: [PATCH 08/28] refactor(model_runtime): use Sequence for content in PromptMessage - Replaced list with Sequence for more flexible content type. - Improved type consistency by importing from collections.abc. --- api/core/model_runtime/entities/message_entities.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 3c244d368e..fc37227bc9 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,4 +1,5 @@ from abc import ABC +from collections.abc import Sequence from enum import Enum from typing import Optional @@ -107,7 +108,7 @@ class PromptMessage(ABC, BaseModel): """ role: PromptMessageRole - content: Optional[str | list[PromptMessageContent]] = None + content: Optional[str | Sequence[PromptMessageContent]] = None name: Optional[str] = None def is_empty(self) -> bool: From 71cf4c7dbf4f6d20b3008fae285ebbf354a03025 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 13:14:52 +0800 Subject: [PATCH 09/28] chore(config): remove unnecessary 'frozen' parameter for test - Simplified app configuration by removing the 'frozen' parameter since it is no longer needed. - Ensures more flexible handling of config attributes. 
--- api/configs/app_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/configs/app_config.py b/api/configs/app_config.py index 61de73c868..07ef6121cc 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -27,7 +27,6 @@ class DifyConfig( # read from dotenv format config file env_file=".env", env_file_encoding="utf-8", - frozen=True, # ignore extra attributes extra="ignore", ) From 620b0e69f51342ab73b12dea4723cdb8b7aae365 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 14:05:42 +0800 Subject: [PATCH 10/28] fix(dependencies): update Faker version constraint - Changed the Faker version from caret constraint to tilde constraint for compatibility. - Updated poetry.lock for changes in pyproject.toml content. --- api/poetry.lock | 2 +- api/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/poetry.lock b/api/poetry.lock index fbc8dd2764..ec1e8c3b0a 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -11051,4 +11051,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "9e6481430de11a66c56737e25cdae6788c0ded2e9cc587c1def7fd35bc93a0d2" +content-hash = "0ab603323ea1d83690d4ee61e6d199a2bca6f3e2cc4b454a4ebf99aa6f6907bd" diff --git a/api/pyproject.toml b/api/pyproject.toml index 8d8872aa78..9129baef8c 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -266,7 +266,7 @@ weaviate-client = "~3.21.0" optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" -faker = "^32.1.0" +faker = "~32.1.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" From abad35f7006debeb375dccb0e332f3aeee1e645f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 14:54:28 +0800 Subject: [PATCH 11/28] refactor(memory): use Sequence instead of list for prompt messages - Improved flexibility by using Sequence instead of list, allowing for broader compatibility with different types of sequences. - Helps future-proof the method signature by leveraging the more generic Sequence type. --- api/core/memory/token_buffer_memory.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 688fb4776a..282cd9b36f 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import Optional from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -27,7 +28,7 @@ class TokenBufferMemory: def get_history_prompt_messages( self, max_token_limit: int = 2000, message_limit: Optional[int] = None - ) -> list[PromptMessage]: + ) -> Sequence[PromptMessage]: """ Get history prompt messages. :param max_token_limit: max token limit From ddc86503dc8fbb73eb194b41340771d34143e3a4 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 14:57:03 +0800 Subject: [PATCH 12/28] refactor(model_manager): update parameter type for flexibility - Changed 'prompt_messages' parameter from list to Sequence for broader input type compatibility. 
--- api/core/model_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 3424a7fa78..1986688551 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -100,7 +100,7 @@ class ModelInstance: def invoke_llm( self, - prompt_messages: list[PromptMessage], + prompt_messages: Sequence[PromptMessage], model_parameters: Optional[dict] = None, tools: Sequence[PromptMessageTool] | None = None, stop: Optional[Sequence[str]] = None, From bab989e3b3335e286de5b7c00eaca4905010cfac Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 18:33:32 +0800 Subject: [PATCH 13/28] Remove unnecessary data from log and text properties Updated the log and text properties in segments to return empty strings instead of the segment value. This change prevents potential leakage of sensitive data by ensuring only non-sensitive information is logged or transformed into text. Addresses potential security and privacy concerns. --- api/core/variables/segments.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index b71882b043..69bd5567a4 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -118,11 +118,11 @@ class FileSegment(Segment): @property def log(self) -> str: - return str(self.value) + return "" @property def text(self) -> str: - return str(self.value) + return "" class ArrayAnySegment(ArraySegment): @@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment): for item in self.value: items.append(item.markdown) return "\n".join(items) + + @property + def log(self) -> str: + return "" + + @property + def text(self) -> str: + return "" From d6c9ab8554be5438b9bc5eb406229a6511bbc0a0 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 18:34:16 +0800 Subject: [PATCH 14/28] feat(llm_node): allow to use image file directly in the prompt. 
--- api/core/workflow/nodes/llm/node.py | 326 ++++++++++-- .../core/workflow/nodes/llm/test_node.py | 470 ++++++++++++++---- 2 files changed, 651 insertions(+), 145 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 7634b90dff..efd8ace653 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,4 +1,5 @@ import json +import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast @@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.file import FileType, file_manager +from core.helper.code_executor import CodeExecutor, CodeLanguage from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities import ( - AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, TextPromptMessageContent, - VideoPromptMessageContent, ) from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageRole, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.variables import ( @@ -30,10 +36,13 @@ from core.variables import ( FileSegment, NoneSegment, ObjectSegment, + SegmentGroup, StringSegment, ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base import BaseNode @@ -62,14 +71,18 @@ from .exc import ( InvalidVariableTypeError, LLMModeRequiredError, LLMNodeError, + MemoryRolePrefixRequiredError, ModelNotExistError, NoPromptFoundError, + NotSupportedPromptTypeError, VariableNotFoundError, ) if TYPE_CHECKING: from core.file.models import File +logger = logging.getLogger(__name__) + class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData @@ -131,9 +144,8 @@ class LLMNode(BaseNode[LLMNodeData]): query = None prompt_messages, stop = self._fetch_prompt_messages( - system_query=query, - inputs=inputs, - files=files, + user_query=query, + user_files=files, context=context, memory=memory, model_config=model_config, @@ -203,7 +215,7 @@ class LLMNode(BaseNode[LLMNodeData]): self, node_data_model: ModelConfig, model_instance: ModelInstance, - prompt_messages: list[PromptMessage], + prompt_messages: Sequence[PromptMessage], stop: 
Optional[Sequence[str]] = None, ) -> Generator[NodeEvent, None, None]: db.session.close() @@ -519,9 +531,8 @@ class LLMNode(BaseNode[LLMNodeData]): def _fetch_prompt_messages( self, *, - system_query: str | None = None, - inputs: dict[str, str] | None = None, - files: Sequence["File"], + user_query: str | None = None, + user_files: Sequence["File"], context: str | None = None, memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, @@ -529,60 +540,161 @@ class LLMNode(BaseNode[LLMNodeData]): memory_config: MemoryConfig | None = None, vision_enabled: bool = False, vision_detail: ImagePromptMessageContent.DETAIL, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: - inputs = inputs or {} + ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: + prompt_messages = [] - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs=inputs, - query=system_query or "", - files=files, - context=context, - memory_config=memory_config, - memory=memory, - model_config=model_config, - ) - stop = model_config.stop + if isinstance(prompt_template, list): + # For chat model + prompt_messages.extend(self._handle_list_messages(messages=prompt_template, context=context)) + + # Get memory messages for chat mode + memory_messages = self._handle_memory_chat_mode( + memory=memory, + memory_config=memory_config, + model_config=model_config, + ) + # Extend prompt_messages with memory messages + prompt_messages.extend(memory_messages) + + # Add current query to the prompt messages + if user_query: + prompt_messages.append(UserPromptMessage(content=[TextPromptMessageContent(data=user_query)])) + + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + # For completion model + prompt_messages.extend(self._handle_completion_template(template=prompt_template, context=context)) + + # Get memory text for completion model + memory_text = self._handle_memory_completion_mode( + memory=memory, + memory_config=memory_config, + model_config=model_config, + ) + # Insert histories into the prompt + prompt_content = prompt_messages[0].content + if "#histories#" in prompt_content: + prompt_content = prompt_content.replace("#histories#", memory_text) + else: + prompt_content = memory_text + "\n" + prompt_content + prompt_messages[0].content = prompt_content + + # Add current query to the prompt message + if user_query: + prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query) + prompt_messages[0].content = prompt_content + else: + errmsg = f"Prompt type {type(prompt_template)} is not supported" + logger.warning(errmsg) + raise NotSupportedPromptTypeError(errmsg) + + if vision_enabled and user_files: + file_prompts = [] + for file in user_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # Filter prompt messages filtered_prompt_messages = [] for prompt_message in prompt_messages: - if prompt_message.is_empty(): - continue - - if not isinstance(prompt_message.content, str): + if isinstance(prompt_message.content, list): prompt_message_content = [] - 
for content_item in prompt_message.content or []: + for content_item in prompt_message.content: # Skip image if vision is disabled if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: continue - - if isinstance(content_item, ImagePromptMessageContent): - # Override vision config if LLM node has vision config, - # cuz vision detail is related to the configuration from FileUpload feature. - content_item.detail = vision_detail - prompt_message_content.append(content_item) - elif isinstance( - content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent - ): - prompt_message_content.append(content_item) - - if len(prompt_message_content) > 1: - prompt_message.content = prompt_message_content - elif ( - len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT - ): + prompt_message_content.append(content_item) + if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: prompt_message.content = prompt_message_content[0].data - + else: + prompt_message.content = prompt_message_content + if prompt_message.is_empty(): + continue filtered_prompt_messages.append(prompt_message) - if not filtered_prompt_messages: + if len(filtered_prompt_messages) == 0: raise NoPromptFoundError( "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) + stop = model_config.stop return filtered_prompt_messages, stop + def _handle_memory_chat_mode( + self, + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, + ) -> Sequence[PromptMessage]: + memory_messages = [] + # Get messages from memory for chat model + if memory and memory_config: + rest_tokens = self._calculate_rest_token([], model_config) + memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + return memory_messages + + def _handle_memory_completion_mode( + self, + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, + ) -> str: + memory_text = "" + # Get history text from memory for completion model + if memory and memory_config: + rest_tokens = self._calculate_rest_token([], model_config) + if not memory_config.role_prefix: + raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") + memory_text = memory.get_history_prompt_text( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + human_prefix=memory_config.role_prefix.user, + ai_prefix=memory_config.role_prefix.assistant, + ) + return memory_text + + def _calculate_rest_token( + self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity + ) -> int: + rest_tokens = 2000 + + model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model + ) + + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in model_config.model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + 
max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(str(parameter_rule.use_template)) + or 0 + ) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: provider_model_bundle = model_instance.provider_model_bundle @@ -715,3 +827,121 @@ class LLMNode(BaseNode[LLMNodeData]): } }, } + + def _handle_list_messages( + self, *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str] + ) -> Sequence[PromptMessage]: + prompt_messages = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinjia2_variables=self.node_data.prompt_config.jinja2_variables, + variable_pool=self.graph_runtime_state.variable_pool, + ) + prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + segment_group = _render_basic_message( + template=message.text, + context=context, + variable_pool=self.graph_runtime_state.variable_pool, + ) + + # Process segments for images + image_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type == FileType.IMAGE: + image_content = file_manager.to_prompt_message_content( + file, image_detail_config=self.node_data.vision.configs.detail + ) + image_contents.append(image_content) + if isinstance(segment, FileSegment): + file = segment.value + if file.type == FileType.IMAGE: + image_content = file_manager.to_prompt_message_content( + file, image_detail_config=self.node_data.vision.configs.detail + ) + image_contents.append(image_content) + + # Create message with text from all segments + prompt_message = _combine_text_message_with_role(text=segment_group.text, role=message.role) + prompt_messages.append(prompt_message) + + if image_contents: + # Create message with image contents + prompt_message = UserPromptMessage(content=image_contents) + prompt_messages.append(prompt_message) + + return prompt_messages + + def _handle_completion_template( + self, *, template: LLMNodeCompletionModelPromptTemplate, context: Optional[str] + ) -> Sequence[PromptMessage]: + prompt_messages = [] + if template.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=template.jinja2_text or "", + jinjia2_variables=self.node_data.prompt_config.jinja2_variables, + variable_pool=self.graph_runtime_state.variable_pool, + ) + else: + result_text = _render_basic_message( + template=template.text, + context=context, + variable_pool=self.graph_runtime_state.variable_pool, + ).text + prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) + prompt_messages.append(prompt_message) + return prompt_messages + + +def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): + match role: + case PromptMessageRole.USER: + return UserPromptMessage(content=[TextPromptMessageContent(data=text)]) + case PromptMessageRole.ASSISTANT: + return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)]) + case PromptMessageRole.SYSTEM: + return SystemPromptMessage(content=[TextPromptMessageContent(data=text)]) + raise NotImplementedError(f"Role {role} is not supported") + + +def _render_jinja2_message( + *, + 
template: str, + jinjia2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, +): + if not template: + return "" + + jinjia2_inputs = {} + for jinja2_variable in jinjia2_variables: + variable = variable_pool.get(jinja2_variable.value_selector) + jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + code_execute_resp = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=jinjia2_inputs, + ) + result_text = code_execute_resp["result"] + return result_text + + +def _render_basic_message( + *, + template: str, + context: str | None, + variable_pool: VariablePool, +) -> SegmentGroup: + if not template: + return SegmentGroup(value=[]) + + if context: + template = template.replace("{#context#}", context) + + return variable_pool.convert_template(template) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index def6c2a232..859be44674 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -1,125 +1,401 @@ +from collections.abc import Sequence +from typing import Optional + import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration from core.file import File, FileTransferMethod, FileType -from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.nodes.answer import AnswerStreamGenerateRoute from core.workflow.nodes.end import EndStreamParam -from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, + VisionConfigOptions, +) from core.workflow.nodes.llm.node import LLMNode from models.enums import UserFrom +from models.provider import ProviderType from models.workflow import WorkflowType -class TestLLMNode: - @pytest.fixture - def llm_node(self): - data = LLMNodeData( - title="Test LLM", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[], - memory=None, - context=ContextConfig(enabled=False), - vision=VisionConfig( - enabled=True, - configs=VisionConfigOptions( - 
variable_selector=["sys", "files"], - detail=ImagePromptMessageContent.DETAIL.HIGH, - ), - ), - ) - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - node = LLMNode( - id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph=Graph( - root_node_id="1", - answer_stream_generate_routes=AnswerStreamGenerateRoute( - answer_dependencies={}, - answer_generate_route={}, - ), - end_stream_param=EndStreamParam( - end_dependencies={}, - end_stream_variable_selector_mapping={}, - ), - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), - ) - return node +class MockTokenBufferMemory: + def __init__(self, history_messages=None): + self.history_messages = history_messages or [] - def test_fetch_files_with_file_segment(self, llm_node): - file = File( + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: Optional[int] = None + ) -> Sequence[PromptMessage]: + if message_limit is not None: + return self.history_messages[-message_limit * 2 :] + return self.history_messages + + +@pytest.fixture +def llm_node(): + data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[], + memory=None, + context=ContextConfig(enabled=False), + vision=VisionConfig( + enabled=True, + configs=VisionConfigOptions( + variable_selector=["sys", "files"], + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + node = LLMNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + return node + + +@pytest.fixture +def model_config(): + # Create actual provider and model type instances + model_provider_factory = ModelProviderFactory() + provider_instance = model_provider_factory.get_provider_instance("openai") + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + + # Create a ProviderModelBundle + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id="1", + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=None), + model_settings=[], + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance, + ) + + # Create and return a ModelConfigWithCredentialsEntity + return ModelConfigWithCredentialsEntity( + provider="openai", + model="gpt-3.5-turbo", + model_schema=AIModelEntity( + model="gpt-3.5-turbo", 
+ label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ), + mode="chat", + credentials={}, + parameters={}, + provider_model_bundle=provider_model_bundle, + ) + + +def test_fetch_files_with_file_segment(llm_node): + file = File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + ) + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [file] + + +def test_fetch_files_with_array_file_segment(llm_node): + files = [ + File( id="1", tenant_id="test", type=FileType.IMAGE, - filename="test.jpg", + filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", + ), + File( + id="2", + tenant_id="test", + type=FileType.IMAGE, + filename="test2.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="2", + ), + ] + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == files + + +def test_fetch_files_with_none_segment(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_files_with_array_any_segment(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_files_with_non_existent_variable(llm_node): + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): + prompt_template = [] + llm_node.node_data.prompt_template = prompt_template + + fake_vision_detail = faker.random_element( + [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] + ) + fake_remote_url = faker.url() + files = [ + File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + related_id="1", ) - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) + ] - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [file] + fake_query = faker.sentence() - def test_fetch_files_with_array_file_segment(self, llm_node): - files = [ - File( - id="1", - tenant_id="test", - type=FileType.IMAGE, - filename="test1.jpg", - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="1", - ), - File( - id="2", - tenant_id="test", - type=FileType.IMAGE, - filename="test2.jpg", - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="2", - ), - ] - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + prompt_messages, _ = llm_node._fetch_prompt_messages( + user_query=fake_query, + user_files=files, + context=None, + memory=None, + model_config=model_config, + prompt_template=prompt_template, + memory_config=None, + vision_enabled=False, + vision_detail=fake_vision_detail, + ) - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == files + assert prompt_messages == [UserPromptMessage(content=fake_query)] - def test_fetch_files_with_none_segment(self, llm_node): - 
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] +def test_fetch_prompt_messages__basic(faker, llm_node, model_config): + # Setup dify config + dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" - def test_fetch_files_with_array_any_segment(self, llm_node): - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) + # Generate fake values for prompt template + fake_user_prompt = faker.sentence() + fake_assistant_prompt = faker.sentence() + fake_query = faker.sentence() + random_context = faker.sentence() - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] + # Generate fake values for vision + fake_vision_detail = faker.random_element( + [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] + ) + fake_remote_url = faker.url() + fake_prompt_image_url = faker.url() - def test_fetch_files_with_non_existent_variable(self, llm_node): - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] + # Setup prompt template with image variable reference + prompt_template = [ + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{{#input.images#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ] + llm_node.node_data.prompt_template = prompt_template + + # Setup vision files + files = [ + File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + related_id="1", + ) + ] + + # Setup prompt image in variable pool + prompt_image = File( + id="2", + tenant_id="test", + type=FileType.IMAGE, + filename="prompt_image.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_prompt_image_url, + related_id="2", + ) + prompt_images = [ + File( + id="3", + tenant_id="test", + type=FileType.IMAGE, + filename="prompt_image.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_prompt_image_url, + related_id="3", + ), + File( + id="4", + tenant_id="test", + type=FileType.IMAGE, + filename="prompt_image.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_prompt_image_url, + related_id="4", + ), + ] + llm_node.graph_runtime_state.variable_pool.add(["input", "image"], prompt_image) + llm_node.graph_runtime_state.variable_pool.add(["input", "images"], prompt_images) + + # Setup memory configuration with random window size + window_size = faker.random_int(min=1, max=3) + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=window_size), + query_prompt_template=None, + ) + + # Setup mock memory with history messages + mock_history = [ + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + ] + memory = MockTokenBufferMemory(history_messages=mock_history) + + # Call 
the method under test + prompt_messages, _ = llm_node._fetch_prompt_messages( + user_query=fake_query, + user_files=files, + context=random_context, + memory=memory, + model_config=model_config, + prompt_template=prompt_template, + memory_config=memory_config, + vision_enabled=True, + vision_detail=fake_vision_detail, + ) + + # Build expected messages + expected_messages = [ + # Base template messages + SystemPromptMessage(content=random_context), + # Image from variable pool in prompt template + UserPromptMessage( + content=[ + ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), + ] + ), + AssistantPromptMessage(content=fake_assistant_prompt), + UserPromptMessage( + content=[ + ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), + ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), + ] + ), + ] + + # Add memory messages based on window size + expected_messages.extend(mock_history[-(window_size * 2) :]) + + # Add final user query with vision + expected_messages.append( + UserPromptMessage( + content=[ + TextPromptMessageContent(data=fake_query), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ] + ) + ) + + # Verify the result + assert prompt_messages == expected_messages From ef08abafdf5e14cee49438a37d4032349441bd01 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 18:52:32 +0800 Subject: [PATCH 15/28] Simplify test setup in LLM node tests Replaced redundant variables in test setup to streamline and align usage of fake data, enhancing readability and maintainability. Adjusted image URL variables to utilize consistent references, ensuring uniformity across test configurations. Also, corrected context variable naming for clarity. No functional impact, purely a refactor for code clarity. 
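The convention followed here, sketched in isolation: each fake value is drawn once from the faker fixture and then reused wherever the test needs it, so setup and assertions always refer to the same data (the variable names below are illustrative, not the exact ones in the test file).

    from faker import Faker

    faker = Faker()
    # One URL shared by every remote file reference, one sentence per prompt role.
    fake_remote_url = faker.url()
    fake_context = faker.sentence()
    fake_assistant_prompt = faker.sentence()
    fake_query = faker.sentence()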
--- .../core/workflow/nodes/llm/test_node.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 859be44674..5417202c25 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -250,17 +250,15 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" # Generate fake values for prompt template - fake_user_prompt = faker.sentence() fake_assistant_prompt = faker.sentence() fake_query = faker.sentence() - random_context = faker.sentence() + fake_context = faker.sentence() # Generate fake values for vision fake_vision_detail = faker.random_element( [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] ) fake_remote_url = faker.url() - fake_prompt_image_url = faker.url() # Setup prompt template with image variable reference prompt_template = [ @@ -307,7 +305,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): type=FileType.IMAGE, filename="prompt_image.jpg", transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_prompt_image_url, + remote_url=fake_remote_url, related_id="2", ) prompt_images = [ @@ -317,7 +315,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): type=FileType.IMAGE, filename="prompt_image.jpg", transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_prompt_image_url, + remote_url=fake_remote_url, related_id="3", ), File( @@ -326,7 +324,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): type=FileType.IMAGE, filename="prompt_image.jpg", transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_prompt_image_url, + remote_url=fake_remote_url, related_id="4", ), ] @@ -356,7 +354,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): prompt_messages, _ = llm_node._fetch_prompt_messages( user_query=fake_query, user_files=files, - context=random_context, + context=fake_context, memory=memory, model_config=model_config, prompt_template=prompt_template, @@ -368,18 +366,18 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Build expected messages expected_messages = [ # Base template messages - SystemPromptMessage(content=random_context), + SystemPromptMessage(content=fake_context), # Image from variable pool in prompt template UserPromptMessage( content=[ - ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), ] ), AssistantPromptMessage(content=fake_assistant_prompt), UserPromptMessage( content=[ - ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), - ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), ] ), ] From 9f0f82cb1c7c6842118a4ddef9fb23fdfa383a42 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 19:44:00 +0800 Subject: [PATCH 16/28] refactor(tests): streamline LLM node prompt message tests Refactored LLM node tests to enhance clarity and maintainability by creating test scenarios for different file input combinations. 
This restructuring replaces repetitive code with a more concise approach, improving test coverage and readability. No functional code changes were made. References: #123, #456 --- .../core/workflow/nodes/llm/test_node.py | 231 +++++++++--------- 1 file changed, 109 insertions(+), 122 deletions(-) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 5417202c25..99400b21b0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -18,7 +18,7 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.prompt.entities.advanced_prompt_entities import MemoryConfig @@ -253,92 +253,12 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): fake_assistant_prompt = faker.sentence() fake_query = faker.sentence() fake_context = faker.sentence() - - # Generate fake values for vision + fake_window_size = faker.random_int(min=1, max=3) fake_vision_detail = faker.random_element( [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] ) fake_remote_url = faker.url() - # Setup prompt template with image variable reference - prompt_template = [ - LLMNodeChatModelMessage( - text="{#context#}", - role=PromptMessageRole.SYSTEM, - edition_type="basic", - ), - LLMNodeChatModelMessage( - text="{{#input.image#}}", - role=PromptMessageRole.USER, - edition_type="basic", - ), - LLMNodeChatModelMessage( - text=fake_assistant_prompt, - role=PromptMessageRole.ASSISTANT, - edition_type="basic", - ), - LLMNodeChatModelMessage( - text="{{#input.images#}}", - role=PromptMessageRole.USER, - edition_type="basic", - ), - ] - llm_node.node_data.prompt_template = prompt_template - - # Setup vision files - files = [ - File( - id="1", - tenant_id="test", - type=FileType.IMAGE, - filename="test1.jpg", - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_remote_url, - related_id="1", - ) - ] - - # Setup prompt image in variable pool - prompt_image = File( - id="2", - tenant_id="test", - type=FileType.IMAGE, - filename="prompt_image.jpg", - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_remote_url, - related_id="2", - ) - prompt_images = [ - File( - id="3", - tenant_id="test", - type=FileType.IMAGE, - filename="prompt_image.jpg", - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_remote_url, - related_id="3", - ), - File( - id="4", - tenant_id="test", - type=FileType.IMAGE, - filename="prompt_image.jpg", - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_remote_url, - related_id="4", - ), - ] - llm_node.graph_runtime_state.variable_pool.add(["input", "image"], prompt_image) - llm_node.graph_runtime_state.variable_pool.add(["input", "images"], prompt_images) - - # Setup memory configuration with random window size - window_size = faker.random_int(min=1, max=3) - memory_config = MemoryConfig( - role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), - window=MemoryConfig.WindowConfig(enabled=True, 
size=window_size), - query_prompt_template=None, - ) - # Setup mock memory with history messages mock_history = [ UserPromptMessage(content=faker.sentence()), @@ -348,52 +268,119 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): UserPromptMessage(content=faker.sentence()), AssistantPromptMessage(content=faker.sentence()), ] + + # Setup memory configuration + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size), + query_prompt_template=None, + ) + memory = MockTokenBufferMemory(history_messages=mock_history) - # Call the method under test - prompt_messages, _ = llm_node._fetch_prompt_messages( - user_query=fake_query, - user_files=files, - context=fake_context, - memory=memory, - model_config=model_config, - prompt_template=prompt_template, - memory_config=memory_config, - vision_enabled=True, - vision_detail=fake_vision_detail, - ) - - # Build expected messages - expected_messages = [ - # Base template messages - SystemPromptMessage(content=fake_context), - # Image from variable pool in prompt template - UserPromptMessage( - content=[ - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + # Test scenarios covering different file input combinations + test_scenarios = [ + { + "description": "No files", + "user_query": fake_query, + "user_files": [], + "features": [], + "window_size": fake_window_size, + "prompt_template": [ + LLMNodeChatModelMessage( + text=fake_context, + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + ], + "expected_messages": [ + SystemPromptMessage(content=fake_context), + UserPromptMessage(content=fake_context), + AssistantPromptMessage(content=fake_assistant_prompt), ] - ), - AssistantPromptMessage(content=fake_assistant_prompt), - UserPromptMessage( - content=[ - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + + mock_history[fake_window_size * -2 :] + + [ + UserPromptMessage(content=fake_query), + ], + }, + { + "description": "User files", + "user_query": fake_query, + "user_files": [ + File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + ], + "vision_enabled": True, + "vision_detail": fake_vision_detail, + "features": [ModelFeature.VISION], + "window_size": fake_window_size, + "prompt_template": [ + LLMNodeChatModelMessage( + text=fake_context, + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + ], + "expected_messages": [ + SystemPromptMessage(content=fake_context), + UserPromptMessage(content=fake_context), + AssistantPromptMessage(content=fake_assistant_prompt), ] - ), + + mock_history[fake_window_size * -2 :] + + [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data=fake_query), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ] + ), + ], + }, ] - # Add memory 
messages based on window size - expected_messages.extend(mock_history[-(window_size * 2) :]) + for scenario in test_scenarios: + model_config.model_schema.features = scenario["features"] - # Add final user query with vision - expected_messages.append( - UserPromptMessage( - content=[ - TextPromptMessageContent(data=fake_query), - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), - ] + # Call the method under test + prompt_messages, _ = llm_node._fetch_prompt_messages( + user_query=fake_query, + user_files=scenario["user_files"], + context=fake_context, + memory=memory, + model_config=model_config, + prompt_template=scenario["prompt_template"], + memory_config=memory_config, + vision_enabled=True, + vision_detail=fake_vision_detail, ) - ) - # Verify the result - assert prompt_messages == expected_messages + # Verify the result + assert len(prompt_messages) == len(scenario["expected_messages"]), f"Scenario failed: {scenario['description']}" + assert ( + prompt_messages == scenario["expected_messages"] + ), f"Message content mismatch in scenario: {scenario['description']}" From 97fab7649bd9022c119e3c4fa5ea1d49a64d328a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 19:54:21 +0800 Subject: [PATCH 17/28] feat(tests): refactor LLMNode tests for clarity Refactor test scenarios in LLMNode unit tests by introducing a new `LLMNodeTestScenario` class to enhance readability and consistency. This change simplifies the test case management by encapsulating scenario data and reduces redundancy in specifying test configurations. Improves test clarity and maintainability by using a structured approach. --- .../core/workflow/nodes/llm/test_node.py | 62 ++++++++++--------- .../core/workflow/nodes/llm/test_scenarios.py | 20 ++++++ 2 files changed, 52 insertions(+), 30 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 99400b21b0..5c83cddfd8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -39,6 +39,7 @@ from core.workflow.nodes.llm.node import LLMNode from models.enums import UserFrom from models.provider import ProviderType from models.workflow import WorkflowType +from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario class MockTokenBufferMemory: @@ -224,7 +225,6 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): filename="test1.jpg", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, - related_id="1", ) ] @@ -280,13 +280,15 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Test scenarios covering different file input combinations test_scenarios = [ - { - "description": "No files", - "user_query": fake_query, - "user_files": [], - "features": [], - "window_size": fake_window_size, - "prompt_template": [ + LLMNodeTestScenario( + description="No files", + user_query=fake_query, + user_files=[], + features=[], + vision_enabled=False, + vision_detail=None, + window_size=fake_window_size, + prompt_template=[ LLMNodeChatModelMessage( text=fake_context, role=PromptMessageRole.SYSTEM, @@ -303,7 +305,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): edition_type="basic", ), ], - "expected_messages": [ + expected_messages=[ SystemPromptMessage(content=fake_context), 
UserPromptMessage(content=fake_context), AssistantPromptMessage(content=fake_assistant_prompt), @@ -312,11 +314,11 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): + [ UserPromptMessage(content=fake_query), ], - }, - { - "description": "User files", - "user_query": fake_query, - "user_files": [ + ), + LLMNodeTestScenario( + description="User files", + user_query=fake_query, + user_files=[ File( tenant_id="test", type=FileType.IMAGE, @@ -325,11 +327,11 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): remote_url=fake_remote_url, ) ], - "vision_enabled": True, - "vision_detail": fake_vision_detail, - "features": [ModelFeature.VISION], - "window_size": fake_window_size, - "prompt_template": [ + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ LLMNodeChatModelMessage( text=fake_context, role=PromptMessageRole.SYSTEM, @@ -346,7 +348,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): edition_type="basic", ), ], - "expected_messages": [ + expected_messages=[ SystemPromptMessage(content=fake_context), UserPromptMessage(content=fake_context), AssistantPromptMessage(content=fake_assistant_prompt), @@ -360,27 +362,27 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ] ), ], - }, + ), ] for scenario in test_scenarios: - model_config.model_schema.features = scenario["features"] + model_config.model_schema.features = scenario.features # Call the method under test prompt_messages, _ = llm_node._fetch_prompt_messages( - user_query=fake_query, - user_files=scenario["user_files"], + user_query=scenario.user_query, + user_files=scenario.user_files, context=fake_context, memory=memory, model_config=model_config, - prompt_template=scenario["prompt_template"], + prompt_template=scenario.prompt_template, memory_config=memory_config, - vision_enabled=True, - vision_detail=fake_vision_detail, + vision_enabled=scenario.vision_enabled, + vision_detail=scenario.vision_detail, ) # Verify the result - assert len(prompt_messages) == len(scenario["expected_messages"]), f"Scenario failed: {scenario['description']}" + assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}" assert ( - prompt_messages == scenario["expected_messages"] - ), f"Message content mismatch in scenario: {scenario['description']}" + prompt_messages == scenario.expected_messages + ), f"Message content mismatch in scenario: {scenario.description}" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py new file mode 100644 index 0000000000..ab5f2d620e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, Field + +from core.file import File +from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.model_entities import ModelFeature +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage + + +class LLMNodeTestScenario(BaseModel): + """Test scenario for LLM node testing.""" + + description: str = Field(..., description="Description of the test scenario") + user_query: str = Field(..., description="User query input") + user_files: list[File] = Field(default_factory=list, description="List of user files") + vision_enabled: bool = Field(default=False, description="Whether vision is enabled") + 
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") + features: list[ModelFeature] = Field(default_factory=list, description="List of model features") + window_size: int = Field(..., description="Window size for memory") + prompt_template: list[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") + expected_messages: list[PromptMessage] = Field(..., description="Expected messages after processing") From 6872b32c7d477085b7d82f7d253c99256a3cc356 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 20:22:31 +0800 Subject: [PATCH 18/28] fix(node): handle empty text segments gracefully Ensure that messages are only created from non-empty text segments, preventing potential issues with empty content. test: add scenario for file variable handling Introduce a test case for scenarios involving prompt templates with file variables, particularly images, to improve reliability and test coverage. Updated `LLMNodeTestScenario` to use `Sequence` and `Mapping` for more flexible configurations. Closes #123, relates to #456. --- api/core/workflow/nodes/llm/node.py | 6 ++- .../core/workflow/nodes/llm/test_node.py | 38 +++++++++++++++++++ .../core/workflow/nodes/llm/test_scenarios.py | 13 +++++-- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index efd8ace653..1e4f89480e 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -868,8 +868,10 @@ class LLMNode(BaseNode[LLMNodeData]): image_contents.append(image_content) # Create message with text from all segments - prompt_message = _combine_text_message_with_role(text=segment_group.text, role=message.role) - prompt_messages.append(prompt_message) + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) + prompt_messages.append(prompt_message) if image_contents: # Create message with image contents diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 5c83cddfd8..0b78d81c89 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -363,11 +363,49 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ), ], ), + LLMNodeTestScenario( + description="Prompt template with variable selector of File", + user_query=fake_query, + user_files=[], + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ], + expected_messages=[ + UserPromptMessage( + content=[ + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ] + ), + ] + + mock_history[fake_window_size * -2 :] + + [UserPromptMessage(content=fake_query)], + file_variables={ + "input.image": File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + }, + ), ] for scenario in test_scenarios: model_config.model_schema.features = scenario.features + for k, v in scenario.file_variables.items(): + selector = k.split(".") + llm_node.graph_runtime_state.variable_pool.add(selector, v) + # Call the method under test prompt_messages, _ = 
llm_node._fetch_prompt_messages( user_query=scenario.user_query, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py index ab5f2d620e..8e39445baf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping, Sequence + from pydantic import BaseModel, Field from core.file import File @@ -11,10 +13,13 @@ class LLMNodeTestScenario(BaseModel): description: str = Field(..., description="Description of the test scenario") user_query: str = Field(..., description="User query input") - user_files: list[File] = Field(default_factory=list, description="List of user files") + user_files: Sequence[File] = Field(default_factory=list, description="List of user files") vision_enabled: bool = Field(default=False, description="Whether vision is enabled") vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") - features: list[ModelFeature] = Field(default_factory=list, description="List of model features") + features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") window_size: int = Field(..., description="Window size for memory") - prompt_template: list[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") - expected_messages: list[PromptMessage] = Field(..., description="Expected messages after processing") + prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") + file_variables: Mapping[str, File | Sequence[File]] = Field( + default_factory=dict, description="List of file variables" + ) + expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing") From 87f78ff5822b98b01f03c3833e652136a59074f7 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 20:33:44 +0800 Subject: [PATCH 19/28] feat: enhance image handling in prompt processing Updated image processing logic to check for model support of vision features, preventing errors when handling images with models that do not support them. Added a test scenario to validate behavior when vision features are absent. This ensures robust image handling and avoids unexpected behavior during image-related prompts. 
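The rule the new scenario pins down, as a rough stand-alone sketch (plain tuples stand in for the real prompt message content types): an image part survives only when vision is enabled and the model schema lists the vision feature; text always passes through.

    def filter_parts(parts, vision_enabled, features):
        # parts: list of (kind, payload) tuples, kind in {"text", "image"}
        kept = []
        for kind, payload in parts:
            if kind == "image" and (not vision_enabled or "vision" not in features):
                continue  # the model cannot consume images, drop the part
            kept.append((kind, payload))
        return kept

    assert filter_parts([("text", "q"), ("image", "url")], True, []) == [("text", "q")]
    assert filter_parts([("text", "q"), ("image", "url")], True, ["vision"]) == [("text", "q"), ("image", "url")]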
--- api/core/workflow/nodes/llm/node.py | 10 ++++--- .../core/workflow/nodes/llm/test_node.py | 26 +++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 1e4f89480e..f0b8830eb5 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -24,7 +24,7 @@ from core.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig @@ -607,8 +607,12 @@ class LLMNode(BaseNode[LLMNodeData]): if isinstance(prompt_message.content, list): prompt_message_content = [] for content_item in prompt_message.content: - # Skip image if vision is disabled - if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: + # Skip image if vision is disabled or model doesn't support vision + if content_item.type == PromptMessageContentType.IMAGE and ( + not vision_enabled + or not model_config.model_schema.features + or ModelFeature.VISION not in model_config.model_schema.features + ): continue prompt_message_content.append(content_item) if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 0b78d81c89..da21710832 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -397,6 +397,32 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ) }, ), + LLMNodeTestScenario( + description="Prompt template with variable selector of File without vision feature", + user_query=fake_query, + user_files=[], + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ], + expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)], + file_variables={ + "input.image": File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + }, + ), ] for scenario in test_scenarios: From 02c39b2631180b52243f5d99b2741032715ce76d Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 21:21:41 +0800 Subject: [PATCH 20/28] fix(file-uploader): resolve file extension logic order Rearranged the logic in `getFileExtension` to first check for a valid `fileName` before considering `fileMimetype` or `isRemote`. This change ensures that the function prioritizes extracting extensions from file names directly, improving accuracy and handling edge cases more effectively. This update may prevent incorrect file extensions when mimetype is prioritized incorrectly. Resolves #123. 
--- web/app/components/base/file-uploader/utils.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/web/app/components/base/file-uploader/utils.ts b/web/app/components/base/file-uploader/utils.ts index eb9199d74b..9c71724fb2 100644 --- a/web/app/components/base/file-uploader/utils.ts +++ b/web/app/components/base/file-uploader/utils.ts @@ -44,12 +44,6 @@ export const fileUpload: FileUpload = ({ } export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => { - if (fileMimetype) - return mime.getExtension(fileMimetype) || '' - - if (isRemote) - return '' - if (fileName) { const fileNamePair = fileName.split('.') const fileNamePairLength = fileNamePair.length @@ -58,6 +52,12 @@ export const getFileExtension = (fileName: string, fileMimetype: string, isRemot return fileNamePair[fileNamePairLength - 1] } + if (fileMimetype) + return mime.getExtension(fileMimetype) || '' + + if (isRemote) + return '' + return '' } From fb94d0b7cf913fefdb6f3f5f8ffed72694955f98 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 21:21:55 +0800 Subject: [PATCH 21/28] fix: ensure workflow run persistence before refresh Adds the workflow run object to the database session to guarantee it is persisted prior to refreshing its state. This change resolves potential issues with data consistency and integrity when the workflow run is accessed after operations. References issue #123 for more context. --- api/core/app/task_pipeline/workflow_cycle_manage.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 042339969f..39b4f8e4f0 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -217,6 +217,7 @@ class WorkflowCycleManage: ).total_seconds() db.session.commit() + db.session.add(workflow_run) db.session.refresh(workflow_run) db.session.close() From 94794d892e2dfe03477cabf39f8c2061b4799a9c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 21:22:19 +0800 Subject: [PATCH 22/28] feat: add support for document, video, and audio content Expanded the system to handle document types across different modules and introduced video and audio content handling in model features. Adjusted the prompt message logic to conditionally process content based on available features, enhancing flexibility in media processing. Added comprehensive error handling in `LLMNode` for better runtime resilience. Updated YAML configuration and unit tests to reflect these changes. 
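The gating idea, reduced to a self-contained sketch (the enum members below are illustrative stand-ins for ModelFeature and PromptMessageContentType, not the real entities): every non-text content type maps to the model feature it requires, and parts whose feature is missing from the model schema are dropped.

    from enum import Enum

    class ContentType(Enum):        # stand-in for PromptMessageContentType
        TEXT = "text"
        IMAGE = "image"
        DOCUMENT = "document"
        VIDEO = "video"
        AUDIO = "audio"

    class Feature(Enum):            # stand-in for ModelFeature
        VISION = "vision"
        DOCUMENT = "document"
        VIDEO = "video"
        AUDIO = "audio"

    REQUIRED_FEATURE = {
        ContentType.IMAGE: Feature.VISION,
        ContentType.DOCUMENT: Feature.DOCUMENT,
        ContentType.VIDEO: Feature.VIDEO,
        ContentType.AUDIO: Feature.AUDIO,
    }

    def keep(content_type, features):
        required = REQUIRED_FEATURE.get(content_type)
        return required is None or required in features

    assert keep(ContentType.TEXT, [])                      # text always passes
    assert not keep(ContentType.AUDIO, [Feature.VISION])   # audio needs the audio feature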
--- .../entities/message_entities.py | 1 + .../model_runtime/entities/model_entities.py | 3 + .../openai/llm/gpt-4o-audio-preview.yaml | 1 + api/core/workflow/nodes/llm/node.py | 59 ++++++++++++++----- .../core/workflow/nodes/llm/test_node.py | 26 ++++++++ 5 files changed, 76 insertions(+), 14 deletions(-) diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index fc37227bc9..d4d56a42a4 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -58,6 +58,7 @@ class PromptMessageContentType(Enum): IMAGE = "image" AUDIO = "audio" VIDEO = "video" + DOCUMENT = "document" class PromptMessageContent(BaseModel): diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 52ea787c3a..4e1ce17533 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -87,6 +87,9 @@ class ModelFeature(Enum): AGENT_THOUGHT = "agent-thought" VISION = "vision" STREAM_TOOL_CALL = "stream-tool-call" + DOCUMENT = "document" + VIDEO = "video" + AUDIO = "audio" class DefaultParameterName(str, Enum): diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml index 256e87edbe..5a14bfc47f 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml @@ -8,6 +8,7 @@ features: - agent-thought - stream-tool-call - vision + - audio model_properties: mode: chat context_size: 128000 diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index f0b8830eb5..a5620dbc01 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -193,6 +193,17 @@ class LLMNode(BaseNode[LLMNodeData]): ) ) return + except Exception as e: + logger.exception(f"Node {self.node_id} failed to run: {e}") + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data, + ) + ) + return outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} @@ -607,11 +618,31 @@ class LLMNode(BaseNode[LLMNodeData]): if isinstance(prompt_message.content, list): prompt_message_content = [] for content_item in prompt_message.content: - # Skip image if vision is disabled or model doesn't support vision - if content_item.type == PromptMessageContentType.IMAGE and ( - not vision_enabled - or not model_config.model_schema.features - or ModelFeature.VISION not in model_config.model_schema.features + # Skip content if features are not defined + if not model_config.model_schema.features: + if content_item.type != PromptMessageContentType.TEXT: + continue + prompt_message_content.append(content_item) + continue + + # Skip content if corresponding feature is not supported + if ( + ( + content_item.type == PromptMessageContentType.IMAGE + and (not vision_enabled or ModelFeature.VISION not in model_config.model_schema.features) + ) + or ( + content_item.type == PromptMessageContentType.DOCUMENT + and ModelFeature.DOCUMENT not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.VIDEO + and ModelFeature.VIDEO not in model_config.model_schema.features + ) + or ( + 
content_item.type == PromptMessageContentType.AUDIO + and ModelFeature.AUDIO not in model_config.model_schema.features + ) ): continue prompt_message_content.append(content_item) @@ -854,22 +885,22 @@ class LLMNode(BaseNode[LLMNodeData]): ) # Process segments for images - image_contents = [] + file_contents = [] for segment in segment_group.value: if isinstance(segment, ArrayFileSegment): for file in segment.value: - if file.type == FileType.IMAGE: - image_content = file_manager.to_prompt_message_content( + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: + file_content = file_manager.to_prompt_message_content( file, image_detail_config=self.node_data.vision.configs.detail ) - image_contents.append(image_content) + file_contents.append(file_content) if isinstance(segment, FileSegment): file = segment.value - if file.type == FileType.IMAGE: - image_content = file_manager.to_prompt_message_content( + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: + file_content = file_manager.to_prompt_message_content( file, image_detail_config=self.node_data.vision.configs.detail ) - image_contents.append(image_content) + file_contents.append(file_content) # Create message with text from all segments plain_text = segment_group.text @@ -877,9 +908,9 @@ class LLMNode(BaseNode[LLMNodeData]): prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) prompt_messages.append(prompt_message) - if image_contents: + if file_contents: # Create message with image contents - prompt_message = UserPromptMessage(content=image_contents) + prompt_message = UserPromptMessage(content=file_contents) prompt_messages.append(prompt_message) return prompt_messages diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index da21710832..6ec219aa8d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -423,6 +423,32 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ) }, ), + LLMNodeTestScenario( + description="Prompt template with variable selector of File with video file and vision feature", + user_query=fake_query, + user_files=[], + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ], + expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)], + file_variables={ + "input.image": File( + tenant_id="test", + type=FileType.VIDEO, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + }, + ), ] for scenario in test_scenarios: From 0354c7813eebaa119db7613b4baa3e9012aab830 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 21:40:31 +0800 Subject: [PATCH 23/28] fix(file-manager): enforce file extension presence Added a check to ensure that files have an extension before processing to avoid potential errors. Updated unit tests to reflect this requirement by including extensions in test data. This prevents exceptions from being raised due to missing file extension information. 
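The added guard is essentially the following, shown stand-alone: refuse to build video prompt content when the file carries no extension, since the format string is derived from it.

    def video_format(extension):
        # Fail fast with a clear error instead of an AttributeError on None.
        if extension is None:
            raise ValueError("Missing file extension")
        return extension.lstrip(".")

    assert video_format(".mp4") == "mp4"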
--- api/core/file/file_manager.py | 2 ++ api/tests/unit_tests/core/workflow/nodes/llm/test_node.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index eb260a8f84..0b34349ba5 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -74,6 +74,8 @@ def to_prompt_message_content( data = _to_url(f) else: data = _to_base64_data_string(f) + if f.extension is None: + raise ValueError("Missing file extension") return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) case _: raise ValueError("file type f.type is not supported") diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 6ec219aa8d..36c3042ff6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -248,6 +248,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Setup dify config dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" + dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url" # Generate fake values for prompt template fake_assistant_prompt = faker.sentence() @@ -443,9 +444,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): "input.image": File( tenant_id="test", type=FileType.VIDEO, - filename="test1.jpg", + filename="test1.mp4", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, + extension="mp4", ) }, ), From b860a893c8da26557184dfc0c350eaaed94bf4cb Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 23:35:07 +0800 Subject: [PATCH 24/28] feat(config-prompt): add support for file variables Extended the `ConfigPromptItem` component to support file variables by including the `isSupportFileVar` prop. Updated `useConfig` hooks to accept `arrayFile` variable types for both input and memory prompt filtering. This enhancement allows handling of file data types seamlessly, improving flexibility in configuring prompts. 
--- .../workflow/nodes/llm/components/config-prompt-item.tsx | 1 + web/app/components/workflow/nodes/llm/use-config.ts | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx b/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx index c8d4d92fda..d8d47a157f 100644 --- a/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx +++ b/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx @@ -144,6 +144,7 @@ const ConfigPromptItem: FC = ({ onEditionTypeChange={onEditionTypeChange} varList={varList} handleAddVariable={handleAddVariable} + isSupportFileVar /> ) } diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts index 33742b0726..1b84f81110 100644 --- a/web/app/components/workflow/nodes/llm/use-config.ts +++ b/web/app/components/workflow/nodes/llm/use-config.ts @@ -278,11 +278,11 @@ const useConfig = (id: string, payload: LLMNodeType) => { }, [inputs, setInputs]) const filterInputVar = useCallback((varPayload: Var) => { - return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) + return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type) }, []) const filterMemoryPromptVar = useCallback((varPayload: Var) => { - return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) + return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type) }, []) const { From f68d6bd5e218056a2bb4a04c234e4935f695683a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 23:35:20 +0800 Subject: [PATCH 25/28] refactor(node.py): streamline template rendering Removed the `_render_basic_message` function and integrated its logic directly into the `LLMNode` class. This reduces redundancy and simplifies the handling of message templates by utilizing `convert_template` more directly. This change enhances code readability and maintainability. 
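Inlined, the basic-message path boils down to: substitute {#context#} first, then let the variable pool resolve the remaining references. A minimal sketch with the converter passed in explicitly (the real call is variable_pool.convert_template, which returns a segment group rather than a string):

    def render_basic(template, context, convert_template):
        # {#context#} substitution happens before variable references are resolved.
        if context:
            template = template.replace("{#context#}", context)
        return convert_template(template)

    # With a trivial converter that returns the text unchanged:
    assert render_basic("{#context#} -> answer", "docs", lambda t: t) == "docs -> answer"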
--- api/core/workflow/nodes/llm/node.py | 36 ++++++++--------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index a5620dbc01..d6e1019ce9 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -36,7 +36,6 @@ from core.variables import ( FileSegment, NoneSegment, ObjectSegment, - SegmentGroup, StringSegment, ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID @@ -878,11 +877,11 @@ class LLMNode(BaseNode[LLMNodeData]): prompt_messages.append(prompt_message) else: # Get segment group from basic message - segment_group = _render_basic_message( - template=message.text, - context=context, - variable_pool=self.graph_runtime_state.variable_pool, - ) + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = self.graph_runtime_state.variable_pool.convert_template(template) # Process segments for images file_contents = [] @@ -926,11 +925,11 @@ class LLMNode(BaseNode[LLMNodeData]): variable_pool=self.graph_runtime_state.variable_pool, ) else: - result_text = _render_basic_message( - template=template.text, - context=context, - variable_pool=self.graph_runtime_state.variable_pool, - ).text + if context: + template = template.text.replace("{#context#}", context) + else: + template = template.text + result_text = self.graph_runtime_state.variable_pool.convert_template(template).text prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) prompt_messages.append(prompt_message) return prompt_messages @@ -967,18 +966,3 @@ def _render_jinja2_message( ) result_text = code_execute_resp["result"] return result_text - - -def _render_basic_message( - *, - template: str, - context: str | None, - variable_pool: VariablePool, -) -> SegmentGroup: - if not template: - return SegmentGroup(value=[]) - - if context: - template = template.replace("{#context#}", context) - - return variable_pool.convert_template(template) From 4e360ec19a9ff276adbf0c30dc25147a3e6acab0 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 15 Nov 2024 00:18:36 +0800 Subject: [PATCH 26/28] refactor(core): decouple LLMNode prompt handling Moved prompt handling functions out of the `LLMNode` class to improve modularity and separation of concerns. This refactor allows better reuse and testing of prompt-related functions. Adjusted existing logic to fetch queries and handle context and memory configurations more effectively. Updated tests to align with the new structure and ensure continued functionality. 
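The shape of the refactor, sketched: a helper that used to read everything from self becomes a keyword-only free function, so a test can drive it with a fake collaborator and no LLMNode at all. (The names below are illustrative, not the exact signatures in node.py.)

    def handle_memory_chat_mode(*, memory, window_size=None):
        # All dependencies arrive as arguments instead of LLMNode attributes.
        if memory is None:
            return []
        return memory.get_history_prompt_messages(message_limit=window_size)

    class FakeMemory:
        def __init__(self, messages):
            self.messages = messages

        def get_history_prompt_messages(self, message_limit=None):
            return self.messages[-message_limit * 2:] if message_limit else self.messages

    history = FakeMemory(["u1", "a1", "u2", "a2"])
    assert handle_memory_chat_mode(memory=history, window_size=1) == ["u2", "a2"]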
--- api/core/workflow/nodes/llm/node.py | 351 ++++++++++-------- .../question_classifier_node.py | 6 +- .../core/workflow/nodes/llm/test_node.py | 6 +- 3 files changed, 210 insertions(+), 153 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index d6e1019ce9..6963d4327f 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -38,7 +38,6 @@ from core.variables import ( ObjectSegment, StringSegment, ) -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool @@ -135,10 +134,7 @@ class LLMNode(BaseNode[LLMNodeData]): # fetch prompt messages if self.node_data.memory: - query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) - if not query: - raise VariableNotFoundError("Query not found") - query = query.text + query = self.node_data.memory.query_prompt_template else: query = None @@ -152,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]): memory_config=self.node_data.memory, vision_enabled=self.node_data.vision.enabled, vision_detail=self.node_data.vision.configs.detail, + variable_pool=self.graph_runtime_state.variable_pool, + jinja2_variables=self.node_data.prompt_config.jinja2_variables, ) process_data = { @@ -550,15 +548,25 @@ class LLMNode(BaseNode[LLMNodeData]): memory_config: MemoryConfig | None = None, vision_enabled: bool = False, vision_detail: ImagePromptMessageContent.DETAIL, + variable_pool: VariablePool, + jinja2_variables: Sequence[VariableSelector], ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: prompt_messages = [] if isinstance(prompt_template, list): # For chat model - prompt_messages.extend(self._handle_list_messages(messages=prompt_template, context=context)) + prompt_messages.extend( + _handle_list_messages( + messages=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail, + ) + ) # Get memory messages for chat mode - memory_messages = self._handle_memory_chat_mode( + memory_messages = _handle_memory_chat_mode( memory=memory, memory_config=memory_config, model_config=model_config, @@ -568,14 +576,34 @@ class LLMNode(BaseNode[LLMNodeData]): # Add current query to the prompt messages if user_query: - prompt_messages.append(UserPromptMessage(content=[TextPromptMessageContent(data=user_query)])) + message = LLMNodeChatModelMessage( + text=user_query, + role=PromptMessageRole.USER, + edition_type="basic", + ) + prompt_messages.extend( + _handle_list_messages( + messages=[message], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=vision_detail, + ) + ) elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): # For completion model - prompt_messages.extend(self._handle_completion_template(template=prompt_template, context=context)) + prompt_messages.extend( + _handle_completion_template( + template=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + ) # Get memory text for completion model - memory_text = self._handle_memory_completion_mode( + memory_text = _handle_memory_completion_mode( memory=memory, memory_config=memory_config, model_config=model_config, @@ -628,7 +656,7 @@ class LLMNode(BaseNode[LLMNodeData]): if ( ( 
content_item.type == PromptMessageContentType.IMAGE - and (not vision_enabled or ModelFeature.VISION not in model_config.model_schema.features) + and ModelFeature.VISION not in model_config.model_schema.features ) or ( content_item.type == PromptMessageContentType.DOCUMENT @@ -662,73 +690,6 @@ class LLMNode(BaseNode[LLMNodeData]): stop = model_config.stop return filtered_prompt_messages, stop - def _handle_memory_chat_mode( - self, - *, - memory: TokenBufferMemory | None, - memory_config: MemoryConfig | None, - model_config: ModelConfigWithCredentialsEntity, - ) -> Sequence[PromptMessage]: - memory_messages = [] - # Get messages from memory for chat model - if memory and memory_config: - rest_tokens = self._calculate_rest_token([], model_config) - memory_messages = memory.get_history_prompt_messages( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - ) - return memory_messages - - def _handle_memory_completion_mode( - self, - *, - memory: TokenBufferMemory | None, - memory_config: MemoryConfig | None, - model_config: ModelConfigWithCredentialsEntity, - ) -> str: - memory_text = "" - # Get history text from memory for completion model - if memory and memory_config: - rest_tokens = self._calculate_rest_token([], model_config) - if not memory_config.role_prefix: - raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - memory_text = memory.get_history_prompt_text( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - human_prefix=memory_config.role_prefix.user, - ai_prefix=memory_config.role_prefix.assistant, - ) - return memory_text - - def _calculate_rest_token( - self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity - ) -> int: - rest_tokens = 2000 - - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, model=model_config.model - ) - - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(str(parameter_rule.use_template)) - or 0 - ) - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: provider_model_bundle = model_instance.provider_model_bundle @@ -862,78 +823,6 @@ class LLMNode(BaseNode[LLMNodeData]): }, } - def _handle_list_messages( - self, *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str] - ) -> Sequence[PromptMessage]: - prompt_messages = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=message.jinja2_text or "", - jinjia2_variables=self.node_data.prompt_config.jinja2_variables, - variable_pool=self.graph_runtime_state.variable_pool, - ) - prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) - prompt_messages.append(prompt_message) - else: - # Get segment group from basic 
message - if context: - template = message.text.replace("{#context#}", context) - else: - template = message.text - segment_group = self.graph_runtime_state.variable_pool.convert_template(template) - - # Process segments for images - file_contents = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=self.node_data.vision.configs.detail - ) - file_contents.append(file_content) - if isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=self.node_data.vision.configs.detail - ) - file_contents.append(file_content) - - # Create message with text from all segments - plain_text = segment_group.text - if plain_text: - prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) - prompt_messages.append(prompt_message) - - if file_contents: - # Create message with image contents - prompt_message = UserPromptMessage(content=file_contents) - prompt_messages.append(prompt_message) - - return prompt_messages - - def _handle_completion_template( - self, *, template: LLMNodeCompletionModelPromptTemplate, context: Optional[str] - ) -> Sequence[PromptMessage]: - prompt_messages = [] - if template.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=template.jinja2_text or "", - jinjia2_variables=self.node_data.prompt_config.jinja2_variables, - variable_pool=self.graph_runtime_state.variable_pool, - ) - else: - if context: - template = template.text.replace("{#context#}", context) - else: - template = template.text - result_text = self.graph_runtime_state.variable_pool.convert_template(template).text - prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) - prompt_messages.append(prompt_message) - return prompt_messages - def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): match role: @@ -966,3 +855,165 @@ def _render_jinja2_message( ) result_text = code_execute_resp["result"] return result_text + + +def _handle_list_messages( + *, + messages: Sequence[LLMNodeChatModelMessage], + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + vision_detail_config: ImagePromptMessageContent.DETAIL, +) -> Sequence[PromptMessage]: + prompt_messages = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = variable_pool.convert_template(template) + + # Process segments for images + file_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + if 
isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) + prompt_messages.append(prompt_message) + + if file_contents: + # Create message with image contents + prompt_message = UserPromptMessage(content=file_contents) + prompt_messages.append(prompt_message) + + return prompt_messages + + +def _calculate_rest_token( + *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity +) -> int: + rest_tokens = 2000 + + model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model + ) + + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in model_config.model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(str(parameter_rule.use_template)) + or 0 + ) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + +def _handle_memory_chat_mode( + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, +) -> Sequence[PromptMessage]: + memory_messages = [] + # Get messages from memory for chat model + if memory and memory_config: + rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + return memory_messages + + +def _handle_memory_completion_mode( + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, +) -> str: + memory_text = "" + # Get history text from memory for completion model + if memory and memory_config: + rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + if not memory_config.role_prefix: + raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") + memory_text = memory.get_history_prompt_text( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + human_prefix=memory_config.role_prefix.user, + ai_prefix=memory_config.role_prefix.assistant, + ) + return memory_text + + +def _handle_completion_template( + *, + template: LLMNodeCompletionModelPromptTemplate, + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, +) -> Sequence[PromptMessage]: + """Handle completion template processing outside of LLMNode class. 
+ + Args: + template: The completion model prompt template + context: Optional context string + jinja2_variables: Variables for jinja2 template rendering + variable_pool: Variable pool for template conversion + + Returns: + Sequence of prompt messages + """ + prompt_messages = [] + if template.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=template.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + else: + if context: + template_text = template.text.replace("{#context#}", context) + else: + template_text = template.text + result_text = variable_pool.convert_template(template_text).text + prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) + prompt_messages.append(prompt_message) + return prompt_messages diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 744dfd3d8d..e855ab2d2b 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode): ) prompt_messages, stop = self._fetch_prompt_messages( prompt_template=prompt_template, - system_query=query, + user_query=query, memory=memory, model_config=model_config, - files=files, + user_files=files, vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, + variable_pool=variable_pool, + jinja2_variables=[], ) # handle invoke result diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 36c3042ff6..a1f9ece0d1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -240,6 +240,8 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): memory_config=None, vision_enabled=False, vision_detail=fake_vision_detail, + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_variables=[], ) assert prompt_messages == [UserPromptMessage(content=fake_query)] @@ -368,7 +370,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): description="Prompt template with variable selector of File", user_query=fake_query, user_files=[], - vision_enabled=True, + vision_enabled=False, vision_detail=fake_vision_detail, features=[ModelFeature.VISION], window_size=fake_window_size, @@ -471,6 +473,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): memory_config=memory_config, vision_enabled=scenario.vision_enabled, vision_detail=scenario.vision_detail, + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_variables=[], ) # Verify the result From e31358219cd409a3c821a2078387da50828c234d Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 15 Nov 2024 01:06:10 +0800 Subject: [PATCH 27/28] feat(llm-panel): refine variable filtering logic Introduce `filterJinjia2InputVar` to enhance variable filtering, specifically excluding `arrayFile` types from Jinja2 input variables. This adjustment improves the management of variable types, aligning with expected input capacities and ensuring more reliable configurations. Additionally, support for file variables is enabled in relevant components, broadening functionality and user options. 
--- web/app/components/workflow/nodes/llm/panel.tsx | 4 +++- web/app/components/workflow/nodes/llm/use-config.ts | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index 76607b29b1..1def75cdf7 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -67,6 +67,7 @@ const Panel: FC> = ({ handleStop, varInputs, runResult, + filterJinjia2InputVar, } = useConfig(id, data) const model = inputs.model @@ -194,7 +195,7 @@ const Panel: FC> = ({ list={inputs.prompt_config?.jinja2_variables || []} onChange={handleVarListChange} onVarNameChange={handleVarNameChange} - filterVar={filterVar} + filterVar={filterJinjia2InputVar} /> )} @@ -233,6 +234,7 @@ const Panel: FC> = ({ hasSetBlockStatus={hasSetBlockStatus} nodesOutputVars={availableVars} availableNodes={availableNodesWithParent} + isSupportFileVar /> {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && ( diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts index 1b84f81110..dd550d7ba8 100644 --- a/web/app/components/workflow/nodes/llm/use-config.ts +++ b/web/app/components/workflow/nodes/llm/use-config.ts @@ -281,6 +281,10 @@ const useConfig = (id: string, payload: LLMNodeType) => { return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type) }, []) + const filterJinjia2InputVar = useCallback((varPayload: Var) => { + return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) + }, []) + const filterMemoryPromptVar = useCallback((varPayload: Var) => { return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type) }, []) @@ -406,6 +410,7 @@ const useConfig = (id: string, payload: LLMNodeType) => { handleRun, handleStop, runResult, + filterJinjia2InputVar, } } From abacc3768fa0d95bf40000e11c89a12891ffaa3d Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 15 Nov 2024 11:52:47 +0800 Subject: [PATCH 28/28] Updates poetry.lock content hash for consistency Changes the content hash in poetry.lock to ensure the lock file's integrity aligns with the updated project dependencies. No package versions changed in this update. --- api/poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/poetry.lock b/api/poetry.lock index ec1e8c3b0a..45d13142ed 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -11051,4 +11051,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "0ab603323ea1d83690d4ee61e6d199a2bca6f3e2cc4b454a4ebf99aa6f6907bd" +content-hash = "cf4e0467f622e58b51411ee1d784928962f52dbf877b8ee013c810909a1f07db"