diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png
new file mode 100644
index 0000000000..dfe8e78049
Binary files /dev/null and b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png differ
diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg
new file mode 100644
index 0000000000..bb23bffcf1
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg
@@ -0,0 +1,15 @@
+
diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png
new file mode 100644
index 0000000000..b154821db9
Binary files /dev/null and b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png differ
diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..c5c608cd7c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg
@@ -0,0 +1,11 @@
+
diff --git a/api/core/model_runtime/model_providers/gpustack/gpustack.py b/api/core/model_runtime/model_providers/gpustack/gpustack.py
new file mode 100644
index 0000000000..321100167e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gpustack/gpustack.py
@@ -0,0 +1,10 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class GPUStackProvider(ModelProvider):
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ pass
diff --git a/api/core/model_runtime/model_providers/gpustack/gpustack.yaml b/api/core/model_runtime/model_providers/gpustack/gpustack.yaml
new file mode 100644
index 0000000000..ee4a3c159a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gpustack/gpustack.yaml
@@ -0,0 +1,120 @@
+provider: gpustack
+label:
+ en_US: GPUStack
+icon_small:
+ en_US: icon_s_en.png
+icon_large:
+ en_US: icon_l_en.png
+supported_model_types:
+ - llm
+ - text-embedding
+ - rerank
+configurate_methods:
+ - customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter your model name
+ zh_Hans: 输入模型名称
+ credential_form_schemas:
+ - variable: endpoint_url
+ label:
+ zh_Hans: 服务器地址
+ en_US: Server URL
+ type: text-input
+ required: true
+ placeholder:
+ zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100
+ en_US: Enter the GPUStack server URL, e.g. http://192.168.1.100
+ - variable: api_key
+ label:
+ en_US: API Key
+ type: secret-input
+ required: true
+ placeholder:
+ zh_Hans: 输入您的 API Key
+ en_US: Enter your API Key
+ - variable: mode
+ show_on:
+ - variable: __model_type
+ value: llm
+ label:
+ en_US: Completion mode
+ type: select
+ required: false
+ default: chat
+ placeholder:
+ zh_Hans: 选择补全类型
+ en_US: Select completion type
+ options:
+ - value: completion
+ label:
+ en_US: Completion
+ zh_Hans: 补全
+ - value: chat
+ label:
+ en_US: Chat
+ zh_Hans: 对话
+ - variable: context_size
+ label:
+ zh_Hans: 模型上下文长度
+ en_US: Model context size
+ required: true
+ type: text-input
+ default: "8192"
+ placeholder:
+ zh_Hans: 输入您的模型上下文长度
+ en_US: Enter your Model context size
+ - variable: max_tokens_to_sample
+ label:
+ zh_Hans: 最大 token 上限
+ en_US: Upper bound for max tokens
+ show_on:
+ - variable: __model_type
+ value: llm
+ default: "8192"
+ type: text-input
+ - variable: function_calling_type
+ show_on:
+ - variable: __model_type
+ value: llm
+ label:
+ en_US: Function calling
+ type: select
+ required: false
+ default: no_call
+ options:
+ - value: function_call
+ label:
+ en_US: Function Call
+ zh_Hans: Function Call
+ - value: tool_call
+ label:
+ en_US: Tool Call
+ zh_Hans: Tool Call
+ - value: no_call
+ label:
+ en_US: Not Support
+ zh_Hans: 不支持
+ - variable: vision_support
+ show_on:
+ - variable: __model_type
+ value: llm
+ label:
+ zh_Hans: Vision 支持
+ en_US: Vision Support
+ type: select
+ required: false
+ default: no_support
+ options:
+ - value: support
+ label:
+ en_US: Support
+ zh_Hans: 支持
+ - value: no_support
+ label:
+ en_US: Not Support
+ zh_Hans: 不支持
diff --git a/api/core/model_runtime/model_providers/gpustack/llm/__init__.py b/api/core/model_runtime/model_providers/gpustack/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/gpustack/llm/llm.py b/api/core/model_runtime/model_providers/gpustack/llm/llm.py
new file mode 100644
index 0000000000..ce6780b6a7
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gpustack/llm/llm.py
@@ -0,0 +1,45 @@
+from collections.abc import Generator
+
+from yarl import URL
+
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import (
+ PromptMessage,
+ PromptMessageTool,
+)
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
+ OAIAPICompatLargeLanguageModel,
+)
+
+
+class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
+ def _invoke(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: list[PromptMessageTool] | None = None,
+ stop: list[str] | None = None,
+ stream: bool = True,
+ user: str | None = None,
+ ) -> LLMResult | Generator:
+ return super()._invoke(
+ model,
+ credentials,
+ prompt_messages,
+ model_parameters,
+ tools,
+ stop,
+ stream,
+ user,
+ )
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ self._add_custom_parameters(credentials)
+ super().validate_credentials(model, credentials)
+
+ @staticmethod
+ def _add_custom_parameters(credentials: dict) -> None:
+ credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
+ credentials["mode"] = "chat"
diff --git a/api/core/model_runtime/model_providers/gpustack/rerank/__init__.py b/api/core/model_runtime/model_providers/gpustack/rerank/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py
new file mode 100644
index 0000000000..5ea7532564
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py
@@ -0,0 +1,146 @@
+from json import dumps
+from typing import Optional
+
+import httpx
+from requests import post
+from yarl import URL
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import (
+ AIModelEntity,
+ FetchFrom,
+ ModelPropertyKey,
+ ModelType,
+)
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class GPUStackRerankModel(RerankModel):
+ """
+ Model class for GPUStack rerank model.
+ """
+
+ def _invoke(
+ self,
+ model: str,
+ credentials: dict,
+ query: str,
+ docs: list[str],
+ score_threshold: Optional[float] = None,
+ top_n: Optional[int] = None,
+ user: Optional[str] = None,
+ ) -> RerankResult:
+ """
+ Invoke rerank model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param query: search query
+ :param docs: docs for reranking
+ :param score_threshold: score threshold
+ :param top_n: top n documents to return
+ :param user: unique user id
+ :return: rerank result
+ """
+ if len(docs) == 0:
+ return RerankResult(model=model, docs=[])
+
+ endpoint_url = credentials["endpoint_url"]
+ headers = {
+ "Authorization": f"Bearer {credentials.get('api_key')}",
+ "Content-Type": "application/json",
+ }
+
+ data = {"model": model, "query": query, "documents": docs, "top_n": top_n}
+
+ try:
+ response = post(
+ str(URL(endpoint_url) / "v1" / "rerank"),
+ headers=headers,
+ data=dumps(data),
+ timeout=10,
+ )
+ response.raise_for_status()
+ results = response.json()
+
+ rerank_documents = []
+ for result in results["results"]:
+ index = result["index"]
+ if "document" in result:
+ text = result["document"]["text"]
+ else:
+ text = docs[index]
+
+ rerank_document = RerankDocument(
+ index=index,
+ text=text,
+ score=result["relevance_score"],
+ )
+
+ if score_threshold is None or result["relevance_score"] >= score_threshold:
+ rerank_documents.append(rerank_document)
+
+ return RerankResult(model=model, docs=rerank_documents)
+ except httpx.HTTPStatusError as e:
+ raise InvokeServerUnavailableError(str(e))
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ self._invoke(
+ model=model,
+ credentials=credentials,
+ query="What is the capital of the United States?",
+ docs=[
+ "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+ "Census, Carson City had a population of 55,274.",
+ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+ "are a political division controlled by the United States. Its capital is Saipan.",
+ ],
+ score_threshold=0.8,
+ )
+ except Exception as ex:
+ raise CredentialsValidateFailedError(str(ex))
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ """
+ return {
+ InvokeConnectionError: [httpx.ConnectError],
+ InvokeServerUnavailableError: [httpx.RemoteProtocolError],
+ InvokeRateLimitError: [],
+ InvokeAuthorizationError: [httpx.HTTPStatusError],
+ InvokeBadRequestError: [httpx.RequestError],
+ }
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+ generate custom model entities from credentials
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.RERANK,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
+ )
+
+ return entity
diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py
new file mode 100644
index 0000000000..eb324491a2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py
@@ -0,0 +1,35 @@
+from typing import Optional
+
+from yarl import URL
+
+from core.entities.embedding_type import EmbeddingInputType
+from core.model_runtime.entities.text_embedding_entities import (
+ TextEmbeddingResult,
+)
+from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
+ OAICompatEmbeddingModel,
+)
+
+
+class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
+ """
+ Model class for GPUStack text embedding model.
+ """
+
+ def _invoke(
+ self,
+ model: str,
+ credentials: dict,
+ texts: list[str],
+ user: Optional[str] = None,
+ input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
+ ) -> TextEmbeddingResult:
+ return super()._invoke(model, credentials, texts, user, input_type)
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ self._add_custom_parameters(credentials)
+ super().validate_credentials(model, credentials)
+
+ @staticmethod
+ def _add_custom_parameters(credentials: dict) -> None:
+ credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index fa4a2eb36c..baa100531f 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -96,5 +96,9 @@ VESSL_AI_MODEL_NAME=
VESSL_AI_API_KEY=
VESSL_AI_ENDPOINT_URL=
+# GPUStack Credentials
+GPUSTACK_SERVER_URL=
+GPUSTACK_API_KEY=
+
# Gitee AI Credentials
GITEE_AI_API_KEY=
diff --git a/api/tests/integration_tests/model_runtime/gpustack/__init__.py b/api/tests/integration_tests/model_runtime/gpustack/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py b/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py
new file mode 100644
index 0000000000..f56ad0dadc
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py
@@ -0,0 +1,49 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.gpustack.text_embedding.text_embedding import (
+ GPUStackTextEmbeddingModel,
+)
+
+
+def test_validate_credentials():
+ model = GPUStackTextEmbeddingModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model="bge-m3",
+ credentials={
+ "endpoint_url": "invalid_url",
+ "api_key": "invalid_api_key",
+ },
+ )
+
+ model.validate_credentials(
+ model="bge-m3",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ },
+ )
+
+
+def test_invoke_model():
+ model = GPUStackTextEmbeddingModel()
+
+ result = model.invoke(
+ model="bge-m3",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ "context_size": 8192,
+ },
+ texts=["hello", "world"],
+ user="abc-123",
+ )
+
+ assert isinstance(result, TextEmbeddingResult)
+ assert len(result.embeddings) == 2
+ assert result.usage.total_tokens == 7
diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_llm.py b/api/tests/integration_tests/model_runtime/gpustack/test_llm.py
new file mode 100644
index 0000000000..326b7b16f0
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/gpustack/test_llm.py
@@ -0,0 +1,162 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import (
+ LLMResult,
+ LLMResultChunk,
+ LLMResultChunkDelta,
+)
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ PromptMessageTool,
+ SystemPromptMessage,
+ UserPromptMessage,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel
+
+
+def test_validate_credentials_for_chat_model():
+ model = GPUStackLanguageModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model="llama-3.2-1b-instruct",
+ credentials={
+ "endpoint_url": "invalid_url",
+ "api_key": "invalid_api_key",
+ "mode": "chat",
+ },
+ )
+
+ model.validate_credentials(
+ model="llama-3.2-1b-instruct",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ "mode": "chat",
+ },
+ )
+
+
+def test_invoke_completion_model():
+ model = GPUStackLanguageModel()
+
+ response = model.invoke(
+ model="llama-3.2-1b-instruct",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ "mode": "completion",
+ },
+ prompt_messages=[UserPromptMessage(content="ping")],
+ model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
+ stop=[],
+ user="abc-123",
+ stream=False,
+ )
+
+ assert isinstance(response, LLMResult)
+ assert len(response.message.content) > 0
+ assert response.usage.total_tokens > 0
+
+
+def test_invoke_chat_model():
+ model = GPUStackLanguageModel()
+
+ response = model.invoke(
+ model="llama-3.2-1b-instruct",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ "mode": "chat",
+ },
+ prompt_messages=[UserPromptMessage(content="ping")],
+ model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
+ stop=[],
+ user="abc-123",
+ stream=False,
+ )
+
+ assert isinstance(response, LLMResult)
+ assert len(response.message.content) > 0
+ assert response.usage.total_tokens > 0
+
+
+def test_invoke_stream_chat_model():
+ model = GPUStackLanguageModel()
+
+ response = model.invoke(
+ model="llama-3.2-1b-instruct",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ "mode": "chat",
+ },
+ prompt_messages=[UserPromptMessage(content="Hello World!")],
+ model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
+ stop=["you"],
+ stream=True,
+ user="abc-123",
+ )
+
+ assert isinstance(response, Generator)
+ for chunk in response:
+ assert isinstance(chunk, LLMResultChunk)
+ assert isinstance(chunk.delta, LLMResultChunkDelta)
+ assert isinstance(chunk.delta.message, AssistantPromptMessage)
+ assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_get_num_tokens():
+ model = GPUStackLanguageModel()
+
+ num_tokens = model.get_num_tokens(
+ model="????",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ "mode": "chat",
+ },
+ prompt_messages=[
+ SystemPromptMessage(
+ content="You are a helpful AI assistant.",
+ ),
+ UserPromptMessage(content="Hello World!"),
+ ],
+ tools=[
+ PromptMessageTool(
+ name="get_current_weather",
+ description="Get the current weather in a given location",
+ parameters={
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state e.g. San Francisco, CA",
+ },
+ "unit": {"type": "string", "enum": ["c", "f"]},
+ },
+ "required": ["location"],
+ },
+ )
+ ],
+ )
+
+ assert isinstance(num_tokens, int)
+ assert num_tokens == 80
+
+ num_tokens = model.get_num_tokens(
+ model="????",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ "mode": "chat",
+ },
+ prompt_messages=[UserPromptMessage(content="Hello World!")],
+ )
+
+ assert isinstance(num_tokens, int)
+ assert num_tokens == 10
diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py b/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py
new file mode 100644
index 0000000000..f5c2d2d21c
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py
@@ -0,0 +1,107 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.gpustack.rerank.rerank import (
+ GPUStackRerankModel,
+)
+
+
+def test_validate_credentials_for_rerank_model():
+ model = GPUStackRerankModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model="bge-reranker-v2-m3",
+ credentials={
+ "endpoint_url": "invalid_url",
+ "api_key": "invalid_api_key",
+ },
+ )
+
+ model.validate_credentials(
+ model="bge-reranker-v2-m3",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ },
+ )
+
+
+def test_invoke_rerank_model():
+ model = GPUStackRerankModel()
+
+ response = model.invoke(
+ model="bge-reranker-v2-m3",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ },
+ query="Organic skincare products for sensitive skin",
+ docs=[
+ "Eco-friendly kitchenware for modern homes",
+ "Biodegradable cleaning supplies for eco-conscious consumers",
+ "Organic cotton baby clothes for sensitive skin",
+ "Natural organic skincare range for sensitive skin",
+ "Tech gadgets for smart homes: 2024 edition",
+ "Sustainable gardening tools and compost solutions",
+ "Sensitive skin-friendly facial cleansers and toners",
+ "Organic food wraps and storage solutions",
+ "Yoga mats made from recycled materials",
+ ],
+ top_n=3,
+ score_threshold=-0.75,
+ user="abc-123",
+ )
+
+ assert isinstance(response, RerankResult)
+ assert len(response.docs) == 3
+
+
+def test__invoke():
+ model = GPUStackRerankModel()
+
+ # Test case 1: Empty docs
+ result = model._invoke(
+ model="bge-reranker-v2-m3",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ },
+ query="Organic skincare products for sensitive skin",
+ docs=[],
+ top_n=3,
+ score_threshold=0.75,
+ user="abc-123",
+ )
+ assert isinstance(result, RerankResult)
+ assert len(result.docs) == 0
+
+ # Test case 2: Expected docs
+ result = model._invoke(
+ model="bge-reranker-v2-m3",
+ credentials={
+ "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+ "api_key": os.environ.get("GPUSTACK_API_KEY"),
+ },
+ query="Organic skincare products for sensitive skin",
+ docs=[
+ "Eco-friendly kitchenware for modern homes",
+ "Biodegradable cleaning supplies for eco-conscious consumers",
+ "Organic cotton baby clothes for sensitive skin",
+ "Natural organic skincare range for sensitive skin",
+ "Tech gadgets for smart homes: 2024 edition",
+ "Sustainable gardening tools and compost solutions",
+ "Sensitive skin-friendly facial cleansers and toners",
+ "Organic food wraps and storage solutions",
+ "Yoga mats made from recycled materials",
+ ],
+ top_n=3,
+ score_threshold=-0.75,
+ user="abc-123",
+ )
+ assert isinstance(result, RerankResult)
+ assert len(result.docs) == 3
+ assert all(isinstance(doc, RerankDocument) for doc in result.docs)