diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 988bb0ce44..4760e8f118 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -85,7 +85,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): tools=tools, stop=stop, stream=stream, user=user, extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key'), ) ) @@ -106,7 +107,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key') ) if 'completion_type' not in credentials: if 'chat' in extra_param.model_ability: @@ -396,7 +398,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): else: extra_args = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key') ) if 'chat' in extra_args.model_ability: @@ -464,6 +467,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): xinference_client = Client( base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_model = xinference_client.get_model(credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 4e7543fd99..d809537479 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -108,7 +108,8 @@ class XinferenceRerankModel(RerankModel): # initialize client client = Client( - base_url=credentials['server_url'] + base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_client = client.get_model(model_uid=credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 9ee3621317..62b77f22e5 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -52,7 +52,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel): # initialize client client = Client( - base_url=credentials['server_url'] + base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_client = client.get_model(model_uid=credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 11f1e29cb3..3a8d704c25 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -110,14 +110,22 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): server_url = credentials['server_url'] model_uid = credentials['model_uid'] - extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) + api_key = credentials.get('api_key') + extra_args = XinferenceHelper.get_xinference_extra_parameter( + server_url=server_url, + model_uid=model_uid, + api_key=api_key, + ) if extra_args.max_tokens: credentials['max_tokens'] = extra_args.max_tokens if server_url.endswith('/'): server_url = server_url[:-1] - client = Client(base_url=server_url) + client = Client( + base_url=server_url, + api_key=api_key, + ) try: handle = client.get_model(model_uid=model_uid) diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index a564a021b1..bfa752df8c 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -81,7 +81,8 @@ class XinferenceText2SpeechModel(TTSModel): extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key'), ) if 'text-to-audio' not in extra_param.model_ability: @@ -203,7 +204,11 @@ class XinferenceText2SpeechModel(TTSModel): credentials['server_url'] = credentials['server_url'][:-1] try: - handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={}) + api_key = credentials.get('api_key') + auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + handle = RESTfulAudioModelHandle( + credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers + ) model_support_voice = [x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials)] diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 7db483a485..75161ad376 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -35,13 +35,13 @@ cache_lock = Lock() class XinferenceHelper: @staticmethod - def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: XinferenceHelper._clean_cache() with cache_lock: if model_uid not in cache: cache[model_uid] = { 'expires': time() + 300, - 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid) + 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key) } return cache[model_uid]['value'] @@ -56,7 +56,7 @@ class XinferenceHelper: pass @staticmethod - def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: """ get xinference model extra parameter like model_format and model_handle_type """ @@ -70,9 +70,10 @@ class XinferenceHelper: session = Session() session.mount('http://', HTTPAdapter(max_retries=3)) session.mount('https://', HTTPAdapter(max_retries=3)) + headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} try: - response = session.get(url, timeout=10) + response = session.get(url, headers=headers, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') if response.status_code != 200: diff --git a/api/poetry.lock b/api/poetry.lock index 0527507bff..9bfeec30d7 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -9584,4 +9584,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "165e4af9cfbce83ee831dd0e82159446ef595d7a7850ee8644c8e2d24dd7040d" +content-hash = "a74c7b6a72145d5074aa84581df6e543ea422810caf0ba1561cd2d35497243ca" diff --git a/api/pyproject.toml b/api/pyproject.toml index f0df3e5a0e..82ccd0b202 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -156,6 +156,7 @@ markdown = "~3.5.1" novita-client = "^0.5.6" numpy = "~1.26.4" openai = "~1.29.0" +openpyxl = "~3.1.5" oss2 = "2.18.5" pandas = { version = "~2.2.2", extras = ["performance", "excel"] } psycopg2-binary = "~2.9.6" @@ -173,7 +174,6 @@ readabilipy = "0.2.0" redis = { version = "~5.0.3", extras = ["hiredis"] } replicate = "~0.22.0" resend = "~0.7.0" -safetensors = "~0.4.3" scikit-learn = "^1.5.1" sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sqlalchemy = "~2.0.29" @@ -187,10 +187,16 @@ werkzeug = "~3.0.1" xinference-client = "0.13.3" yarl = "~1.9.4" zhipuai = "1.0.7" -rank-bm25 = "~0.2.2" -openpyxl = "^3.1.5" +# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group. + +############################################################ +# Related transparent dependencies with pinned verion +# required by main implementations +############################################################ +[tool.poetry.group.indriect.dependencies] kaleido = "0.2.1" -elasticsearch = "8.14.0" +rank-bm25 = "~0.2.2" +safetensors = "~0.4.3" ############################################################ # Tool dependencies required by tool implementations @@ -198,6 +204,7 @@ elasticsearch = "8.14.0" [tool.poetry.group.tool.dependencies] arxiv = "2.1.0" +cloudscraper = "1.2.71" matplotlib = "~3.8.2" newspaper3k = "0.2.8" duckduckgo-search = "^6.2.6" @@ -209,26 +216,25 @@ twilio = "~9.0.4" vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } wikipedia = "1.4.0" yfinance = "~0.2.40" -cloudscraper = "1.2.71" ############################################################ # VDB dependencies required by vector store clients ############################################################ [tool.poetry.group.vdb.dependencies] +alibabacloud_gpdb20160503 = "~3.8.0" +alibabacloud_tea_openapi = "~0.3.9" chromadb = "0.5.1" +clickhouse-connect = "~0.7.16" +elasticsearch = "8.14.0" oracledb = "~2.2.1" pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvector = "0.2.5" pymilvus = "~2.4.4" -pymysql = "1.1.1" tcvectordb = "1.3.2" tidb-vector = "0.0.9" qdrant-client = "1.7.3" weaviate-client = "~3.21.0" -alibabacloud_gpdb20160503 = "~3.8.0" -alibabacloud_tea_openapi = "~0.3.9" -clickhouse-connect = "~0.7.16" ############################################################ # Dev dependencies for running tests @@ -252,5 +258,5 @@ pytest-mock = "~3.14.0" optional = true [tool.poetry.group.lint.dependencies] -ruff = "~0.6.1" dotenv-linter = "~0.5.0" +ruff = "~0.6.1"