chore: refurbish Python code by applying refurb linter rules (#8296)
parent c69f5b07ba
commit 40fb4d16ef
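
Most hunks in this commit apply refurb rule FURB110: a truthiness fallback written as "x if x else y" collapses to "x or y", which reads better and evaluates x only once. A minimal sketch of the pattern (the args dict here is illustrative, not taken from the diff context):

    args = {"desc": None, "copyright": "Dify"}

    # before: the ternary evaluates args["desc"] twice
    desc = args["desc"] if args["desc"] else ""

    # after: `or` returns the left operand when truthy, else the right one
    desc = args["desc"] or ""

    assert desc == ""
    assert (args["copyright"] or "") == "Dify"

The remaining hunks apply a handful of other refurb rules (pathlib one-liners, operator.itemgetter, min(), itertools.starmap, dropping unused enumerate indexes), each sketched below where it first appears.
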
@@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource):
         site = app.site
         if not site:
-            desc = args["desc"] if args["desc"] else ""
-            copy_right = args["copyright"] if args["copyright"] else ""
-            privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
-            custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
+            desc = args["desc"] or ""
+            copy_right = args["copyright"] or ""
+            privacy_policy = args["privacy_policy"] or ""
+            custom_disclaimer = args["custom_disclaimer"] or ""
         else:
-            desc = site.description if site.description else args["desc"] if args["desc"] else ""
-            copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
-            privacy_policy = (
-                site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
-            )
-            custom_disclaimer = (
-                site.custom_disclaimer
-                if site.custom_disclaimer
-                else args["custom_disclaimer"]
-                if args["custom_disclaimer"]
-                else ""
-            )
+            desc = site.description or args["desc"] or ""
+            copy_right = site.copyright or args["copyright"] or ""
+            privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
+            custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""

         recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()

@@ -99,14 +99,10 @@ class ChatMessageTextApi(Resource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)

@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):

     if not account:
         # Create account
-        account_name = user_info.name if user_info.name else "Dify"
+        account_name = user_info.name or "Dify"
         account = RegisterService.register(
             email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
         )

@@ -550,12 +550,7 @@ class DatasetApiBaseUrlApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        return {
-            "api_base_url": (
-                dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
-            )
-            + "/v1"
-        }
+        return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}


 class DatasetRetrievalSettingApi(Resource):

@@ -86,14 +86,10 @@ class ChatTextApi(InstalledAppResource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)

@@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):

         return ApiToolManageService.test_api_tool_preview(
             current_user.current_tenant_id,
-            args["provider_name"] if args["provider_name"] else "",
+            args["provider_name"] or "",
             args["tool_name"],
             args["credentials"],
             args["parameters"],

@@ -84,14 +84,10 @@ class TextApi(Resource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(

@@ -83,14 +83,10 @@ class TextApi(WebApiResource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None

@@ -256,7 +256,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     model=model_instance.model,
                     prompt_messages=prompt_messages,
                     message=AssistantPromptMessage(content=final_answer),
-                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                     system_fingerprint="",
                 )
             ),

@@ -298,7 +298,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     model=model_instance.model,
                     prompt_messages=prompt_messages,
                     message=AssistantPromptMessage(content=final_answer),
-                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                     system_fingerprint="",
                 )
             ),

@@ -161,7 +161,7 @@ class AppRunner:
             app_mode=AppMode.value_of(app_record.mode),
             prompt_template_entity=prompt_template_entity,
             inputs=inputs,
-            query=query if query else "",
+            query=query or "",
             files=files,
             context=context,
             memory=memory,
@@ -189,7 +189,7 @@ class AppRunner:
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
             inputs=inputs,
-            query=query if query else "",
+            query=query or "",
             files=files,
             context=context,
             memory_config=memory_config,
@@ -238,7 +238,7 @@ class AppRunner:
                     model=app_generate_entity.model_conf.model,
                     prompt_messages=prompt_messages,
                     message=AssistantPromptMessage(content=text),
-                    usage=usage if usage else LLMUsage.empty_usage(),
+                    usage=usage or LLMUsage.empty_usage(),
                 ),
             ),
             PublishFrom.APPLICATION_MANAGER,
@@ -351,7 +351,7 @@ class AppRunner:
             tenant_id=tenant_id,
             app_config=app_generate_entity.app_config,
             inputs=inputs,
-            query=query if query else "",
+            query=query or "",
             message_id=message_id,
             trace_manager=app_generate_entity.trace_manager,
         )

@@ -3,6 +3,7 @@ import importlib.util
 import json
 import logging
 import os
+from pathlib import Path
 from typing import Any, Optional

 from pydantic import BaseModel

@@ -63,8 +64,7 @@ class Extensible:

                 builtin_file_path = os.path.join(subdir_path, "__builtin__")
                 if os.path.exists(builtin_file_path):
-                    with open(builtin_file_path, encoding="utf-8") as f:
-                        position = int(f.read().strip())
+                    position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
                     position_map[extension_name] = position

                 if (extension_name + ".py") not in file_names:

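Several hunks swap open()/read() boilerplate for pathlib one-liners (refurb FURB101 for reads, FURB103 for writes). A minimal sketch of the equivalence; the file name is illustrative:

    from pathlib import Path

    Path("__builtin__").write_text("3\n", encoding="utf-8")

    # before: explicit context manager
    with open("__builtin__", encoding="utf-8") as f:
        position = int(f.read().strip())

    # after: Path.read_text() opens, reads, and closes in one call
    position = int(Path("__builtin__").read_text(encoding="utf-8").strip())

    assert position == 3
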
@@ -39,7 +39,7 @@ class TokenBufferMemory:
         )

         if message_limit and message_limit > 0:
-            message_limit = message_limit if message_limit <= 500 else 500
+            message_limit = min(message_limit, 500)
         else:
             message_limit = 500

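Clamping a value with a conditional becomes a plain min() call (refurb FURB136); the same rewrite shows up again in the worksheet tool's f-string further down. A quick sketch:

    # before
    message_limit = 700
    message_limit = message_limit if message_limit <= 500 else 500

    # after: min() states the upper clamp directly
    assert min(700, 500) == 500
    assert min(300, 500) == 300
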
@@ -449,7 +449,7 @@ if you are not sure about the structure.
                 model=real_model,
                 prompt_messages=prompt_messages,
                 message=prompt_message,
-                usage=usage if usage else LLMUsage.empty_usage(),
+                usage=usage or LLMUsage.empty_usage(),
                 system_fingerprint=system_fingerprint,
             ),
             credentials=credentials,

@@ -409,7 +409,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                     ),
                 )
             elif isinstance(chunk, ContentBlockDeltaEvent):
-                chunk_text = chunk.delta.text if chunk.delta.text else ""
+                chunk_text = chunk.delta.text or ""
                 full_assistant_content += chunk_text

                 # transform assistant message to prompt message

@@ -213,7 +213,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
                 model=real_model,
                 prompt_messages=prompt_messages,
                 message=prompt_message,
-                usage=usage if usage else LLMUsage.empty_usage(),
+                usage=usage or LLMUsage.empty_usage(),
                 system_fingerprint=system_fingerprint,
             ),
             credentials=credentials,

@@ -225,7 +225,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 continue

             # transform assistant message to prompt message
-            text = delta.text if delta.text else ""
+            text = delta.text or ""
             assistant_prompt_message = AssistantPromptMessage(content=text)

             full_text += text
@@ -400,15 +400,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 continue

             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
-            )
+            assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)

-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""

             real_model = chunk.model
             system_fingerprint = chunk.system_fingerprint
-            completion += delta.delta.content if delta.delta.content else ""
+            completion += delta.delta.content or ""

             yield LLMResultChunk(
                 model=real_model,

@@ -84,7 +84,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
                     )
                     for i in range(len(sentences))
                 ]
-                for index, future in enumerate(futures):
+                for future in futures:
                     yield from future.result().__enter__().iter_bytes(1024)

             else:

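The TTS and embedding loops drop enumerate() wherever the index was never read (refurb FURB148). A tiny sketch with stand-in values:

    futures = ["future-0", "future-1"]  # stand-ins for concurrent.futures results

    # before: index is bound but unused
    for index, future in enumerate(futures):
        print(future)

    # after
    for future in futures:
        print(future)
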
@@ -331,10 +331,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             elif "contentBlockDelta" in chunk:
                 delta = chunk["contentBlockDelta"]["delta"]
                 if "text" in delta:
-                    chunk_text = delta["text"] if delta["text"] else ""
+                    chunk_text = delta["text"] or ""
                     full_assistant_content += chunk_text
                     assistant_prompt_message = AssistantPromptMessage(
-                        content=chunk_text if chunk_text else "",
+                        content=chunk_text or "",
                     )
                     index = chunk["contentBlockDelta"]["contentBlockIndex"]
                     yield LLMResultChunk(
@@ -751,7 +751,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         elif model_prefix == "cohere":
             output = response_body.get("generations")[0].get("text")
             prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, output if output else "")
+            completion_tokens = self.get_num_tokens(model, credentials, output or "")

         else:
             raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@@ -828,7 +828,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=content_delta if content_delta else "",
+                content=content_delta or "",
             )
             index += 1

@@ -302,11 +302,11 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
             if delta.delta.function_call:
                 function_calls = [delta.delta.function_call]

-            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
+            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
+                content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
             )

             if delta.finish_reason is not None:

@@ -511,7 +511,7 @@ class LocalAILanguageModel(LargeLanguageModel):
             delta = chunk.choices[0]

             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
+            assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])

             if delta.finish_reason is not None:
                 # temp_assistant_prompt_message is used to calculate usage
@@ -578,11 +578,11 @@ class LocalAILanguageModel(LargeLanguageModel):
             if delta.delta.function_call:
                 function_calls = [delta.delta.function_call]

-            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
+            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
+                content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
             )

             if delta.finish_reason is not None:

@@ -211,7 +211,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                         index=0,
                         message=AssistantPromptMessage(content=message.content, tool_calls=[]),
                         usage=usage,
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        finish_reason=message.stop_reason or None,
                     ),
                 )
             elif message.function_call:
@@ -244,7 +244,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                     delta=LLMResultChunkDelta(
                         index=0,
                         message=AssistantPromptMessage(content=message.content, tool_calls=[]),
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        finish_reason=message.stop_reason or None,
                     ),
                 )

@@ -65,7 +65,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
         inputs = []
        used_tokens = 0

-        for i, text in enumerate(texts):
+        for text in texts:
             # Here token count is only an approximation based on the GPT2 tokenizer
             num_tokens = self._get_num_tokens_by_gpt2(text)

@@ -508,7 +508,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                 continue

             # transform assistant message to prompt message
-            text = delta.text if delta.text else ""
+            text = delta.text or ""
             assistant_prompt_message = AssistantPromptMessage(content=text)

             full_text += text
@@ -760,11 +760,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                 final_tool_calls.extend(tool_calls)

             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
-            )
+            assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)

-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""

             if has_finish_reason:
                 final_chunk = LLMResultChunk(

@@ -88,7 +88,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
                     )
                     for i in range(len(sentences))
                 ]
-                for index, future in enumerate(futures):
+                for future in futures:
                     yield from future.result().__enter__().iter_bytes(1024)

             else:

@@ -179,9 +179,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         features = []

         function_calling_type = credentials.get("function_calling_type", "no_call")
-        if function_calling_type in ["function_call"]:
+        if function_calling_type == "function_call":
             features.append(ModelFeature.TOOL_CALL)
-        elif function_calling_type in ["tool_call"]:
+        elif function_calling_type == "tool_call":
             features.append(ModelFeature.MULTI_TOOL_CALL)

         stream_function_calling = credentials.get("stream_function_calling", "supported")

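Membership tests against a one-element list are rewritten as plain equality (refurb FURB171). A sketch:

    function_calling_type = "tool_call"

    # before: builds a throwaway single-item list
    assert function_calling_type in ["tool_call"]

    # after: equality says the same thing
    assert function_calling_type == "tool_call"
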
@@ -179,7 +179,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
                         index=0,
                         message=AssistantPromptMessage(content=message.content, tool_calls=[]),
                         usage=usage,
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        finish_reason=message.stop_reason or None,
                     ),
                 )
             else:
@@ -189,7 +189,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
                     delta=LLMResultChunkDelta(
                         index=0,
                         message=AssistantPromptMessage(content=message.content, tool_calls=[]),
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        finish_reason=message.stop_reason or None,
                     ),
                 )

@@ -106,7 +106,7 @@ class OpenLLMGenerate:
             timeout = 120

         data = {
-            "stop": stop if stop else [],
+            "stop": stop or [],
             "prompt": "\n".join([message.content for message in prompt_messages]),
             "llm_config": default_llm_config,
         }

@@ -214,7 +214,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):

             index += 1

-            assistant_prompt_message = AssistantPromptMessage(content=output if output else "")
+            assistant_prompt_message = AssistantPromptMessage(content=output or "")

             if index < prediction_output_length:
                 yield LLMResultChunk(

@@ -1,5 +1,6 @@
 import json
 import logging
+import operator
 from typing import Any, Optional

 import boto3

@@ -94,7 +95,7 @@ class SageMakerRerankModel(RerankModel):
         for idx in range(len(scores)):
             candidate_docs.append({"content": docs[idx], "score": scores[idx]})

-        sorted(candidate_docs, key=lambda x: x["score"], reverse=True)
+        sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)

         line = 3
         rerank_documents = []

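Lambda sort keys that only index become operator.itemgetter (refurb FURB118). A sketch with illustrative documents; note that sorted() returns a new list, so the unassigned call in the hunk above had no effect before this commit and still has none after it:

    import operator

    docs = [{"content": "a", "score": 0.2}, {"content": "b", "score": 0.9}]

    by_lambda = sorted(docs, key=lambda x: x["score"], reverse=True)
    by_getter = sorted(docs, key=operator.itemgetter("score"), reverse=True)

    assert by_lambda == by_getter  # same ordering, no lambda needed
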
@@ -260,7 +260,7 @@ class SageMakerText2SpeechModel(TTSModel):
                 for payload in payloads
             ]

-            for index, future in enumerate(futures):
+            for future in futures:
                 resp = future.result()
                 audio_bytes = requests.get(resp.get("s3_presign_url")).content
                 for i in range(0, len(audio_bytes), 1024):

@@ -220,7 +220,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
                 delta = content

                 assistant_prompt_message = AssistantPromptMessage(
-                    content=delta if delta else "",
+                    content=delta or "",
                 )

                 prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)

@@ -1,6 +1,7 @@
 import base64
 import hashlib
 import hmac
+import operator
 import time

 import requests

@@ -127,7 +128,7 @@ class FlashRecognizer:
         return s

     def _build_req_with_signature(self, secret_key, params, header):
-        query = sorted(params.items(), key=lambda d: d[0])
+        query = sorted(params.items(), key=operator.itemgetter(0))
         signstr = self._format_sign_string(query)
         signature = self._sign(signstr, secret_key)
         header["Authorization"] = signature

@@ -4,6 +4,7 @@ import tempfile
 import uuid
 from collections.abc import Generator
 from http import HTTPStatus
+from pathlib import Path
 from typing import Optional, Union, cast

 from dashscope import Generation, MultiModalConversation, get_tokenizer

@@ -454,8 +455,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):

             file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{mime_type.split('/')[1]}")

-            with open(file_path, "wb") as image_file:
-                image_file.write(base64.b64decode(encoded_string))
+            Path(file_path).write_bytes(base64.b64decode(encoded_string))

             return f"file://{file_path}"

@@ -368,11 +368,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
                 final_tool_calls.extend(tool_calls)

             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
-            )
+            assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)

-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""

             if has_finish_reason:
                 final_chunk = LLMResultChunk(

@@ -231,10 +231,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                     ),
                 )
             elif isinstance(chunk, ContentBlockDeltaEvent):
-                chunk_text = chunk.delta.text if chunk.delta.text else ""
+                chunk_text = chunk.delta.text or ""
                 full_assistant_content += chunk_text
                 assistant_prompt_message = AssistantPromptMessage(
-                    content=chunk_text if chunk_text else "",
+                    content=chunk_text or "",
                 )
                 index = chunk.index
                 yield LLMResultChunk(

@@ -1,5 +1,6 @@
 # coding : utf-8
 import datetime
+from itertools import starmap

 import pytz

@@ -48,7 +49,7 @@ class SignResult:
         self.authorization = ""

     def __str__(self):
-        return "\n".join(["{}:{}".format(*item) for item in self.__dict__.items()])
+        return "\n".join(list(starmap("{}:{}".format, self.__dict__.items())))


 class Credentials:

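Comprehensions that only unpack a tuple into a call become itertools.starmap (refurb FURB140). A sketch with illustrative values; since str.join accepts any iterable, the list(...) wrapper kept in the hunk is not strictly needed:

    from itertools import starmap

    items = {"region": "cn-north-1", "service": "ml_maas"}.items()

    via_comprehension = "\n".join(["{}:{}".format(*item) for item in items])
    via_starmap = "\n".join(starmap("{}:{}".format, items))

    assert via_comprehension == via_starmap
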
@@ -1,5 +1,6 @@
 import hashlib
 import hmac
+import operator
 from functools import reduce
 from urllib.parse import quote

@@ -40,4 +41,4 @@ class Util:
         if len(hv) == 1:
             hv = "0" + hv
         lst.append(hv)
-        return reduce(lambda x, y: x + y, lst)
+        return reduce(operator.add, lst)

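In the same spirit, reduce with an addition lambda becomes operator.add. A sketch; for joining strings, "".join(lst) would be the more idiomatic (and faster) choice, but the commit keeps the existing reduce structure:

    import operator
    from functools import reduce

    lst = ["de", "ad", "be", "ef"]

    assert reduce(lambda x, y: x + y, lst) == reduce(operator.add, lst) == "".join(lst)
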
@@ -174,9 +174,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
                     index=index,
-                    message=AssistantPromptMessage(
-                        content=message["content"] if message["content"] else "", tool_calls=[]
-                    ),
+                    message=AssistantPromptMessage(content=message["content"] or "", tool_calls=[]),
                     usage=usage,
                     finish_reason=choice.get("finish_reason"),
                 ),
@@ -208,7 +206,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             model=model,
             prompt_messages=prompt_messages,
             message=AssistantPromptMessage(
-                content=message["content"] if message["content"] else "",
+                content=message["content"] or "",
                 tool_calls=tool_calls,
             ),
             usage=self._calc_response_usage(
@@ -284,7 +282,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             model=model,
             prompt_messages=prompt_messages,
             message=AssistantPromptMessage(
-                content=message.content if message.content else "",
+                content=message.content or "",
                 tool_calls=tool_calls,
             ),
             usage=self._calc_response_usage(

@@ -199,7 +199,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
             secret_key=credentials["secret_key"],
         )

-        user = user if user else "ErnieBotDefault"
+        user = user or "ErnieBotDefault"

         # convert prompt messages to baichuan messages
         messages = [
@@ -289,7 +289,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
                         index=0,
                         message=AssistantPromptMessage(content=message.content, tool_calls=[]),
                         usage=usage,
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        finish_reason=message.stop_reason or None,
                     ),
                 )
             else:
@@ -299,7 +299,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
                     delta=LLMResultChunkDelta(
                         index=0,
                         message=AssistantPromptMessage(content=message.content, tool_calls=[]),
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        finish_reason=message.stop_reason or None,
                     ),
                 )

@@ -85,7 +85,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel):
         api_key = credentials["api_key"]
         secret_key = credentials["secret_key"]
         embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
-        user = user if user else "ErnieBotDefault"
+        user = user or "ErnieBotDefault"

         context_size = self._get_context_size(model, credentials)
         max_chunks = self._get_max_chunks(model, credentials)

@@ -589,7 +589,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):

         # convert tool call to assistant message tool call
         tool_calls = assistant_message.tool_calls
-        assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else [])
+        assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls or [])
         function_call = assistant_message.function_call
         if function_call:
             assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)]
@@ -652,7 +652,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
+                content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
             )

             if delta.finish_reason is not None:
@@ -749,7 +749,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             delta = chunk.choices[0]

             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
+            assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])

             if delta.finish_reason is not None:
                 # temp_assistant_prompt_message is used to calculate usage

@@ -215,7 +215,7 @@ class XinferenceText2SpeechModel(TTSModel):
                     for i in range(len(sentences))
                 ]

-                for index, future in enumerate(futures):
+                for future in futures:
                     response = future.result()
                     for i in range(0, len(response), 1024):
                         yield response[i : i + 1024]

@@ -414,10 +414,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_tool_calls
+                content=delta.delta.content or "", tool_calls=assistant_tool_calls
             )

-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""

             if delta.finish_reason is not None and chunk.usage is not None:
                 completion_tokens = chunk.usage.completion_tokens

@@ -30,6 +30,8 @@ def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
     return {key: val for key, val in merged.items() if val is not None}


+from itertools import starmap
+
 from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT

 ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)

@@ -159,7 +161,7 @@ class HttpClient:
         return [(key, str_data)]

     def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
-        items = flatten([self._object_to_formdata(k, v) for k, v in data.items()])
+        items = flatten(list(starmap(self._object_to_formdata, data.items())))

         serialized: dict[str, object] = {}
         for key, value in items:

@@ -65,7 +65,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             self.generate_name_trace(trace_info)

     def workflow_trace(self, trace_info: WorkflowTraceInfo):
-        trace_id = trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
+        trace_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
         user_id = trace_info.metadata.get("user_id")
         if trace_info.message_id:
             trace_id = trace_info.message_id
@@ -84,7 +84,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             )
             self.add_trace(langfuse_trace_data=trace_data)
             workflow_span_data = LangfuseSpan(
-                id=(trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id),
+                id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
                 name=TraceTaskName.WORKFLOW_TRACE.value,
                 input=trace_info.workflow_run_inputs,
                 output=trace_info.workflow_run_outputs,
@@ -93,7 +93,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 end_time=trace_info.end_time,
                 metadata=trace_info.metadata,
                 level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
-                status_message=trace_info.error if trace_info.error else "",
+                status_message=trace_info.error or "",
             )
             self.add_span(langfuse_span_data=workflow_span_data)
         else:
@@ -143,7 +143,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             else:
                 inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
             outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
-            created_at = node_execution.created_at if node_execution.created_at else datetime.now()
+            created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)

@@ -172,10 +172,8 @@ class LangFuseDataTrace(BaseTraceInstance):
                     end_time=finished_at,
                     metadata=metadata,
                     level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
-                    status_message=trace_info.error if trace_info.error else "",
-                    parent_observation_id=(
-                        trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
-                    ),
+                    status_message=trace_info.error or "",
+                    parent_observation_id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
                 )
             else:
                 span_data = LangfuseSpan(
@@ -188,7 +186,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                     end_time=finished_at,
                     metadata=metadata,
                     level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
-                    status_message=trace_info.error if trace_info.error else "",
+                    status_message=trace_info.error or "",
                 )

                 self.add_span(langfuse_span_data=span_data)
@@ -212,7 +210,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 output=outputs,
                 metadata=metadata,
                 level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
-                status_message=trace_info.error if trace_info.error else "",
+                status_message=trace_info.error or "",
                 usage=generation_usage,
             )

@@ -277,7 +275,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 output=message_data.answer,
                 metadata=metadata,
                 level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
-                status_message=message_data.error if message_data.error else "",
+                status_message=message_data.error or "",
                 usage=generation_usage,
             )

@@ -319,7 +317,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 end_time=trace_info.end_time,
                 metadata=trace_info.metadata,
                 level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
-                status_message=message_data.error if message_data.error else "",
+                status_message=message_data.error or "",
                 usage=generation_usage,
             )

@@ -82,7 +82,7 @@ class LangSmithDataTrace(BaseTraceInstance):
         langsmith_run = LangSmithRunModel(
             file_list=trace_info.file_list,
             total_tokens=trace_info.total_tokens,
-            id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
+            id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
             name=TraceTaskName.WORKFLOW_TRACE.value,
             inputs=trace_info.workflow_run_inputs,
             run_type=LangSmithRunType.tool,
@@ -94,7 +94,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             },
             error=trace_info.error,
             tags=["workflow"],
-            parent_run_id=trace_info.message_id if trace_info.message_id else None,
+            parent_run_id=trace_info.message_id or None,
         )

         self.add_run(langsmith_run)
@@ -133,7 +133,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             else:
                 inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
             outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
-            created_at = node_execution.created_at if node_execution.created_at else datetime.now()
+            created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)

@@ -180,9 +180,7 @@ class LangSmithDataTrace(BaseTraceInstance):
                 extra={
                     "metadata": metadata,
                 },
-                parent_run_id=trace_info.workflow_app_log_id
-                if trace_info.workflow_app_log_id
-                else trace_info.workflow_run_id,
+                parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
                 tags=["node_execution"],
             )

@@ -354,11 +354,11 @@ class TraceTask:
         workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
         workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
         workflow_run_version = workflow_run.version
-        error = workflow_run.error if workflow_run.error else ""
+        error = workflow_run.error or ""

         total_tokens = workflow_run.total_tokens

-        file_list = workflow_run_inputs.get("sys.file") if workflow_run_inputs.get("sys.file") else []
+        file_list = workflow_run_inputs.get("sys.file") or []
         query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""

         # get workflow_app_log_id
@@ -452,7 +452,7 @@ class TraceTask:
             message_tokens=message_tokens,
             answer_tokens=message_data.answer_tokens,
             total_tokens=message_tokens + message_data.answer_tokens,
-            error=message_data.error if message_data.error else "",
+            error=message_data.error or "",
             inputs=inputs,
             outputs=message_data.answer,
             file_list=file_list,
@@ -487,7 +487,7 @@ class TraceTask:
         workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None

         moderation_trace_info = ModerationTraceInfo(
-            message_id=workflow_app_log_id if workflow_app_log_id else message_id,
+            message_id=workflow_app_log_id or message_id,
             inputs=inputs,
             message_data=message_data.to_dict(),
             flagged=moderation_result.flagged,
@@ -527,7 +527,7 @@ class TraceTask:
         workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None

         suggested_question_trace_info = SuggestedQuestionTraceInfo(
-            message_id=workflow_app_log_id if workflow_app_log_id else message_id,
+            message_id=workflow_app_log_id or message_id,
             message_data=message_data.to_dict(),
             inputs=message_data.message,
             outputs=message_data.answer,
@@ -569,7 +569,7 @@ class TraceTask:

         dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
             message_id=message_id,
-            inputs=message_data.query if message_data.query else message_data.inputs,
+            inputs=message_data.query or message_data.inputs,
             documents=[doc.model_dump() for doc in documents],
             start_time=timer.get("start"),
             end_time=timer.get("end"),
@@ -695,8 +695,7 @@ class TraceQueueManager:
         self.start_timer()

     def add_trace_task(self, trace_task: TraceTask):
-        global trace_manager_timer
-        global trace_manager_queue
+        global trace_manager_timer, trace_manager_queue
         try:
             if self.trace_instance:
                 trace_task.app_id = self.app_id

@@ -112,11 +112,11 @@ class SimplePromptTransform(PromptTransform):
         for v in prompt_template_config["special_variable_keys"]:
             # support #context#, #query# and #histories#
             if v == "#context#":
-                variables["#context#"] = context if context else ""
+                variables["#context#"] = context or ""
             elif v == "#query#":
-                variables["#query#"] = query if query else ""
+                variables["#query#"] = query or ""
             elif v == "#histories#":
-                variables["#histories#"] = histories if histories else ""
+                variables["#histories#"] = histories or ""

         prompt_template = prompt_template_config["prompt_template"]
         prompt = prompt_template.format(variables)

@@ -34,7 +34,7 @@ class BaseKeyword(ABC):
         raise NotImplementedError

     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:

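Slice copies become explicit .copy() calls (refurb FURB145), making it obvious that the loop iterates a snapshot while the original list is mutated. A sketch:

    texts = ["a", "b", "c"]

    # iterating over a copy (previously texts[:]) lets us remove from the original safely
    for text in texts.copy():
        if text == "b":
            texts.remove(text)

    assert texts == ["a", "c"]
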
@@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,
@@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,

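The score_threshold rewrites across the vector stores collapse a redundant double lookup into a single kwargs.get with a default. Strictly, the two forms differ when the caller passes an explicit falsy value such as None (presumably acceptable here because callers either omit the key or pass a number). A sketch:

    kwargs = {"score_threshold": None}

    # before: falsy values (None, 0) were coerced to 0.0
    old = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0

    # after: the default applies only when the key is absent
    new = kwargs.get("score_threshold", 0.0)

    assert old == 0.0
    assert new is None
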
@@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         collection = self._client.get_or_create_collection(self._collection_name)
         results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)

         ids: list[str] = results["ids"][0]
         documents: list[str] = results["documents"][0]

@@ -86,8 +86,8 @@ class ElasticSearchVector(BaseVector):
                 id=uuids[i],
                 document={
                     Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
-                    Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
+                    Field.VECTOR.value: embeddings[i] or None,
+                    Field.METADATA_KEY.value: documents[i].metadata or {},
                 },
             )
         self._client.indices.refresh(index=self._collection_name)
@@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):

         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if score > score_threshold:
                 doc.metadata["score"] = score
                 docs.append(doc)

@@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
         for result in results[0]:
             metadata = result["entity"].get(Field.METADATA_KEY.value)
             metadata["score"] = result["distance"]
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if result["distance"] > score_threshold:
                 doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)

@@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):

     def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-        score_threshold = kwargs.get("score_threshold") or 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         where_str = (
             f"WHERE dist < {1 - score_threshold}"
             if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0

@@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
                 metadata = {}

             metadata["score"] = hit["_score"]
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if hit["_score"] > score_threshold:
                 doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)

@@ -200,7 +200,7 @@ class OracleVector(BaseVector):
             [numpy.array(query_vector)],
         )
         docs = []
-        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         for record in cur:
             metadata, text, distance = record
             score = 1 - distance
@@ -212,7 +212,7 @@ class OracleVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
         # just not implement fetch by score_threshold now, may be later
-        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         if len(query) > 0:
             # Check which language the query is in
             zh_pattern = re.compile("[\u4e00-\u9fa5]+")

@@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
             metadata = record.meta
             score = 1 - dis
             metadata["score"] = score
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if score > score_threshold:
                 doc = Document(page_content=record.text, metadata=metadata)
                 docs.append(doc)

@@ -144,7 +144,7 @@ class PGVector(BaseVector):
             (json.dumps(query_vector),),
         )
         docs = []
-        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         for record in cur:
             metadata, text, distance = record
             score = 1 - distance

@@ -339,7 +339,7 @@ class QdrantVector(BaseVector):
         for result in results:
             metadata = result.payload.get(Field.METADATA_KEY.value) or {}
             # duplicate check score threshold
-            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if result.score > score_threshold:
                 metadata["score"] = result.score
                 doc = Document(

@@ -230,7 +230,7 @@ class RelytVector(BaseVector):
         # Organize results.
         docs = []
         for document, score in results:
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if 1 - score > score_threshold:
                 docs.append(document)
         return docs

@@ -153,7 +153,7 @@ class TencentVector(BaseVector):
             limit=kwargs.get("top_k", 4),
             timeout=self._client_config.timeout,
         )
-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         return self._get_search_res(res, score_threshold)

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

@@ -185,7 +185,7 @@ class TiDBVector(BaseVector):

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         filter = kwargs.get("filter")
         distance = 1 - score_threshold

@@ -49,7 +49,7 @@ class BaseVector(ABC):
         raise NotImplementedError

     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:

@@ -153,7 +153,7 @@ class Vector:
         return CacheEmbedding(embedding_model)

     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:

@@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):

         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             # check score threshold
             if score > score_threshold:
                 doc.metadata["score"] = score

@@ -12,7 +12,7 @@ import mimetypes
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Mapping
 from io import BufferedReader, BytesIO
-from pathlib import PurePath
+from pathlib import Path, PurePath
 from typing import Any, Optional, Union

 from pydantic import BaseModel, ConfigDict, model_validator

@@ -56,8 +56,7 @@ class Blob(BaseModel):
     def as_string(self) -> str:
         """Read data as a string."""
         if self.data is None and self.path:
-            with open(str(self.path), encoding=self.encoding) as f:
-                return f.read()
+            return Path(str(self.path)).read_text(encoding=self.encoding)
         elif isinstance(self.data, bytes):
             return self.data.decode(self.encoding)
         elif isinstance(self.data, str):
@@ -72,8 +71,7 @@ class Blob(BaseModel):
         elif isinstance(self.data, str):
             return self.data.encode(self.encoding)
         elif self.data is None and self.path:
-            with open(str(self.path), "rb") as f:
-                return f.read()
+            return Path(str(self.path)).read_bytes()
         else:
             raise ValueError(f"Unable to get bytes for blob {self}")

@@ -68,8 +68,7 @@ class ExtractProcessor:
                 suffix = "." + re.search(r"\.(\w+)$", filename).group(1)

             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            with open(file_path, "wb") as file:
-                file.write(response.content)
+            Path(file_path).write_bytes(response.content)
             extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
             if return_text:
                 delimiter = "\n"
@@ -111,7 +110,7 @@ class ExtractProcessor:
                     )
                 elif file_extension in [".htm", ".html"]:
                     extractor = HtmlExtractor(file_path)
-                elif file_extension in [".docx"]:
+                elif file_extension == ".docx":
                     extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension == ".csv":
                     extractor = CSVExtractor(file_path, autodetect_encoding=True)
@@ -143,7 +142,7 @@ class ExtractProcessor:
                     extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
                 elif file_extension in [".htm", ".html"]:
                     extractor = HtmlExtractor(file_path)
-                elif file_extension in [".docx"]:
+                elif file_extension == ".docx":
                     extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension == ".csv":
                     extractor = CSVExtractor(file_path, autodetect_encoding=True)

@@ -1,6 +1,7 @@
 """Document loader helpers."""

 import concurrent.futures
+from pathlib import Path
 from typing import NamedTuple, Optional, cast


@@ -28,8 +29,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
     import chardet

     def read_and_detect(file_path: str) -> list[dict]:
-        with open(file_path, "rb") as f:
-            rawdata = f.read()
+        rawdata = Path(file_path).read_bytes()
         return cast(list[dict], chardet.detect_all(rawdata))

     with concurrent.futures.ThreadPoolExecutor() as executor:

@@ -1,6 +1,7 @@
 """Abstract interface for document loader implementations."""

 import re
+from pathlib import Path
 from typing import Optional, cast

 from core.rag.extractor.extractor_base import BaseExtractor

@@ -102,15 +103,13 @@ class MarkdownExtractor(BaseExtractor):
         """Parse file into tuples."""
         content = ""
         try:
-            with open(filepath, encoding=self._encoding) as f:
-                content = f.read()
+            content = Path(filepath).read_text(encoding=self._encoding)
         except UnicodeDecodeError as e:
             if self._autodetect_encoding:
                 detected_encodings = detect_file_encodings(filepath)
                 for encoding in detected_encodings:
                     try:
-                        with open(filepath, encoding=encoding.encoding) as f:
-                            content = f.read()
+                        content = Path(filepath).read_text(encoding=encoding.encoding)
                         break
                     except UnicodeDecodeError:
                         continue

@@ -1,5 +1,6 @@
 """Abstract interface for document loader implementations."""

+from pathlib import Path
 from typing import Optional

 from core.rag.extractor.extractor_base import BaseExtractor

@@ -25,15 +26,13 @@ class TextExtractor(BaseExtractor):
         """Load from file path."""
         text = ""
         try:
-            with open(self._file_path, encoding=self._encoding) as f:
-                text = f.read()
+            text = Path(self._file_path).read_text(encoding=self._encoding)
         except UnicodeDecodeError as e:
             if self._autodetect_encoding:
                 detected_encodings = detect_file_encodings(self._file_path)
                 for encoding in detected_encodings:
                     try:
-                        with open(self._file_path, encoding=encoding.encoding) as f:
-                            text = f.read()
+                        text = Path(self._file_path).read_text(encoding=encoding.encoding)
                         break
                     except UnicodeDecodeError:
                         continue

@@ -153,7 +153,7 @@ class WordExtractor(BaseExtractor):
                 if col_index >= total_cols:
                     break
                 cell_content = self._parse_cell(cell, image_map).strip()
-                cell_colspan = cell.grid_span if cell.grid_span else 1
+                cell_colspan = cell.grid_span or 1
                 for i in range(cell_colspan):
                     if col_index + i < total_cols:
                         row_cells[col_index + i] = cell_content if i == 0 else ""

@@ -256,7 +256,7 @@ class DatasetRetrieval:
         # get retrieval model config
         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
         if dataset:
-            retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+            retrieval_model_config = dataset.retrieval_model or default_retrieval_model

             # get top k
             top_k = retrieval_model_config["top_k"]
@@ -410,7 +410,7 @@ class DatasetRetrieval:
             return []

         # get retrieval model , if the model is not setting , using default
-        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+        retrieval_model = dataset.retrieval_model or default_retrieval_model

         if dataset.indexing_technique == "economy":
             # use keyword table query
@@ -433,9 +433,7 @@ class DatasetRetrieval:
                     reranking_model=retrieval_model.get("reranking_model", None)
                     if retrieval_model["reranking_enable"]
                     else None,
-                    reranking_mode=retrieval_model.get("reranking_mode")
-                    if retrieval_model.get("reranking_mode")
-                    else "reranking_model",
+                    reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                     weights=retrieval_model.get("weights", None),
                 )

@@ -486,7 +484,7 @@ class DatasetRetrieval:
         }

         for dataset in available_datasets:
-            retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+            retrieval_model_config = dataset.retrieval_model or default_retrieval_model

             # get top k
             top_k = retrieval_model_config["top_k"]

@@ -106,7 +106,7 @@ class ApiToolProviderController(ToolProviderController):
                         "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
                         "llm": tool_bundle.summary or "",
                     },
-                    "parameters": tool_bundle.parameters if tool_bundle.parameters else [],
+                    "parameters": tool_bundle.parameters or [],
                 }
             )

@ -1,4 +1,5 @@
|
|||
import json
|
||||
import operator
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
@ -71,7 +72,7 @@ class SageMakerReRankTool(BuiltinTool):
|
|||
candidate_docs[idx]["score"] = scores[idx]
|
||||
|
||||
line = 8
|
||||
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x["score"], reverse=True)
|
||||
sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
|
||||
|
||||
line = 9
|
||||
return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
|
||||
|
|
|
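Note: the SageMaker hunk applies FURB118: a trivial lambda used as a sort key becomes operator.itemgetter, which states the lookup declaratively and avoids a Python-level function call per element. A sketch with made-up scores:

    import operator

    docs = [{"score": 0.2}, {"score": 0.9}, {"score": 0.5}]

    # Equivalent sort keys; itemgetter("score") returns d["score"] for each dict.
    by_getter = sorted(docs, key=operator.itemgetter("score"), reverse=True)
    by_lambda = sorted(docs, key=lambda d: d["score"], reverse=True)
    assert by_getter == by_lambda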
@@ -115,7 +115,7 @@ class GetWorksheetFieldsTool(BuiltinTool):
                fields.append(field)
                fields_list.append(
                    f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}"
                    f"|{field['options'] if field['options'] else ''}|"
                    f"|{field['options'] or ''}|"
                )

        fields.append(

@@ -130,7 +130,7 @@ class GetWorksheetPivotDataTool(BuiltinTool):
        # ]
        rows = []
        for row in data["data"]:
            row_data = row["rows"] if row["rows"] else {}
            row_data = row["rows"] or {}
            row_data.update(row["columns"])
            row_data.update(row["values"])
            rows.append(row_data)

@@ -113,7 +113,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
        result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"."
        if result["total"] > 0:
            result_text += (
                f" The following are {result['total'] if result['total'] < limit else limit}"
                f" The following are {min(limit, result['total'])}"
                f" pieces of data presented in a table format:\n\n{table_header}"
            )
            for row in rows:
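Note: the last hunk above is FURB136: a conditional expression that selects the smaller of two values is just min(). Sketch:

    total, limit = 250, 50
    assert (total if total < limit else limit) == min(limit, total) == 50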
@@ -37,7 +37,7 @@ class SearchAPI:
        return {
            "engine": "youtube_transcripts",
            "video_id": video_id,
            "lang": language if language else "en",
            "lang": language or "en",
            **{key: value for key, value in kwargs.items() if value not in [None, ""]},
        }

@@ -160,7 +160,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
            hit_callback.on_query(query, dataset.id)

            # get retrieval model , if the model is not setting , using default
            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
            retrieval_model = dataset.retrieval_model or default_retrieval_model

            if dataset.indexing_technique == "economy":
                # use keyword table query

@@ -183,9 +183,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                    reranking_model=retrieval_model.get("reranking_model", None)
                    if retrieval_model["reranking_enable"]
                    else None,
                    reranking_mode=retrieval_model.get("reranking_mode")
                    if retrieval_model.get("reranking_mode")
                    else "reranking_model",
                    reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                    weights=retrieval_model.get("weights", None),
                )

@@ -55,7 +55,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
        hit_callback.on_query(query, dataset.id)

        # get retrieval model , if the model is not setting , using default
        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
        retrieval_model = dataset.retrieval_model or default_retrieval_model
        if dataset.indexing_technique == "economy":
            # use keyword table query
            documents = RetrievalService.retrieve(

@@ -76,9 +76,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                reranking_model=retrieval_model.get("reranking_model", None)
                if retrieval_model["reranking_enable"]
                else None,
                reranking_mode=retrieval_model.get("reranking_mode")
                if retrieval_model.get("reranking_mode")
                else "reranking_model",
                reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                weights=retrieval_model.get("weights", None),
            )
        else:
@@ -8,6 +8,7 @@ import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from pathlib import Path
from urllib.parse import unquote

import chardet

@@ -98,7 +99,7 @@ def get_url(url: str, user_agent: str = None) -> str:
            authors=a["byline"],
            publish_date=a["date"],
            top_image="",
            text=a["plain_text"] if a["plain_text"] else "",
            text=a["plain_text"] or "",
        )

    return res

@@ -117,8 +118,7 @@ def extract_using_readabilipy(html):
        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

    # Read output of call to Readability.parse() from JSON file and return as Python dictionary
    with open(article_json_path, encoding="utf-8") as json_file:
        input_json = json.loads(json_file.read())
    input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))

    # Deleting files after processing
    os.unlink(article_json_path)

@@ -21,7 +21,7 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
        with open(file_path, encoding="utf-8") as yaml_file:
            try:
                yaml_content = yaml.safe_load(yaml_file)
                return yaml_content if yaml_content else default_value
                return yaml_content or default_value
            except Exception as e:
                raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
    except Exception as e:
@@ -268,7 +268,7 @@ class Graph(BaseModel):
                        f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
                    )

                new_route = route[:]
                new_route = route.copy()
                new_route.append(graph_edge.target_node_id)
                cls._check_connected_to_previous_node(
                    route=new_route,
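Note: FURB145 prefers the explicit list.copy() over the slice idiom route[:]; both build the same shallow copy, so the change is purely about readability. Sketch:

    route = ["start", "node_a"]
    new_route = route.copy()              # same result as route[:]
    new_route.append("node_b")
    assert route == ["start", "node_a"]   # the original list is untouched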
@@ -679,8 +679,7 @@ class Graph(BaseModel):
        all_routes_node_ids = set()
        parallel_start_node_ids: dict[str, list[str]] = {}
        for branch_node_id, node_ids in routes_node_ids.items():
            for node_id in node_ids:
                all_routes_node_ids.add(node_id)
            all_routes_node_ids.update(node_ids)

            if branch_node_id in reverse_edge_mapping:
                for graph_edge in reverse_edge_mapping[branch_node_id]:
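Note: FURB142 replaces an element-by-element add loop with a single set.update call, which accepts any iterable, including the generator expressions used in the two event-handler hunks further down. Sketch:

    routes = {"branch_1": ["a", "b"], "branch_2": ["b", "c"]}

    all_ids = set()
    for node_ids in routes.values():
        all_ids.update(node_ids)   # instead of: for n in node_ids: all_ids.add(n)
    assert all_ids == {"a", "b", "c"}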
@@ -74,7 +74,7 @@ class CodeNode(BaseNode):
        :return:
        """
        if not isinstance(value, str):
            if isinstance(value, type(None)):
            if value is None:
                return None
            else:
                raise ValueError(f"Output variable `{variable}` must be a string")
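Note: the CodeNode hunks apply FURB168: isinstance(value, type(None)) is a roundabout spelling of value is None, since NoneType has exactly one instance. Sketch:

    for value in (None, "text", 0):
        # The identity test gives the same answer with no type() indirection.
        assert isinstance(value, type(None)) == (value is None)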
@@ -95,7 +95,7 @@ class CodeNode(BaseNode):
        :return:
        """
        if not isinstance(value, int | float):
            if isinstance(value, type(None)):
            if value is None:
                return None
            else:
                raise ValueError(f"Output variable `{variable}` must be a number")

@@ -182,7 +182,7 @@ class CodeNode(BaseNode):
                        f"Output {prefix}.{output_name} is not a valid array."
                        f" make sure all elements are of the same type."
                    )
                elif isinstance(output_value, type(None)):
                elif output_value is None:
                    pass
                else:
                    raise ValueError(f"Output {prefix}.{output_name} is not a valid type.")

@@ -284,7 +284,7 @@ class CodeNode(BaseNode):

                for i, value in enumerate(result[output_name]):
                    if not isinstance(value, dict):
                        if isinstance(value, type(None)):
                        if value is None:
                            pass
                        else:
                            raise ValueError(
@@ -79,7 +79,7 @@ class IfElseNode(BaseNode):
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            process_data=process_datas,
            edge_source_handle=selected_case_id if selected_case_id else "false",  # Use case ID or 'default'
            edge_source_handle=selected_case_id or "false",  # Use case ID or 'default'
            outputs=outputs,
        )

@@ -580,7 +580,7 @@ class LLMNode(BaseNode):
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=node_data.prompt_template,
            inputs=inputs,
            query=query if query else "",
            query=query or "",
            files=files,
            context=context,
            memory_config=node_data.memory,

@@ -250,7 +250,7 @@ class QuestionClassifierNode(LLMNode):
        for class_ in classes:
            category = {"category_id": class_.id, "category_name": class_.name}
            categories.append(category)
        instruction = node_data.instruction if node_data.instruction else ""
        instruction = node_data.instruction or ""
        input_text = query
        memory_str = ""
        if memory:
@@ -18,8 +18,7 @@ def handle(sender, **kwargs):
        added_dataset_ids = dataset_ids
    else:
        old_dataset_ids = set()
        for app_dataset_join in app_dataset_joins:
            old_dataset_ids.add(app_dataset_join.dataset_id)
        old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)

        added_dataset_ids = dataset_ids - old_dataset_ids
        removed_dataset_ids = old_dataset_ids - dataset_ids

@@ -22,8 +22,7 @@ def handle(sender, **kwargs):
        added_dataset_ids = dataset_ids
    else:
        old_dataset_ids = set()
        for app_dataset_join in app_dataset_joins:
            old_dataset_ids.add(app_dataset_join.dataset_id)
        old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)

        added_dataset_ids = dataset_ids - old_dataset_ids
        removed_dataset_ids = old_dataset_ids - dataset_ids
@@ -1,6 +1,7 @@
import os
import shutil
from collections.abc import Generator
from pathlib import Path

from flask import Flask

@@ -26,8 +27,7 @@ class LocalStorage(BaseStorage):
            folder = os.path.dirname(filename)
            os.makedirs(folder, exist_ok=True)

        with open(os.path.join(os.getcwd(), filename), "wb") as f:
            f.write(data)
        Path(os.path.join(os.getcwd(), filename)).write_bytes(data)

    def load_once(self, filename: str) -> bytes:
        if not self.folder or self.folder.endswith("/"):
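Note: the storage hunks are the binary counterpart of the earlier read_text rewrite (FURB103 for writes): an open(..., "wb") block whose only statement is f.write(data) becomes Path.write_bytes. Sketch with an illustrative file name:

    from pathlib import Path

    data = b"\x00\x01 payload"
    Path("blob.bin").write_bytes(data)            # replaces open/write/close
    assert Path("blob.bin").read_bytes() == data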
@@ -38,9 +38,7 @@ class LocalStorage(BaseStorage):
        if not os.path.exists(filename):
            raise FileNotFoundError("File not found")

        with open(filename, "rb") as f:
            data = f.read()

        data = Path(filename).read_bytes()
        return data

    def load_stream(self, filename: str) -> Generator:

@@ -144,7 +144,7 @@ class Dataset(db.Model):
            "top_k": 2,
            "score_threshold_enabled": False,
        }
        return self.retrieval_model if self.retrieval_model else default_retrieval_model
        return self.retrieval_model or default_retrieval_model

    @property
    def tags(self):

@@ -160,7 +160,7 @@ class Dataset(db.Model):
            .all()
        )

        return tags if tags else []
        return tags or []

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
@@ -118,7 +118,7 @@ class App(db.Model):

    @property
    def api_base_url(self):
        return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")) + "/v1"
        return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"

    @property
    def tenant(self):

@@ -207,7 +207,7 @@ class App(db.Model):
            .all()
        )

        return tags if tags else []
        return tags or []


class AppModelConfig(db.Model):

@@ -908,7 +908,7 @@ class Message(db.Model):
                    "id": message_file.id,
                    "type": message_file.type,
                    "url": url,
                    "belongs_to": message_file.belongs_to if message_file.belongs_to else "user",
                    "belongs_to": message_file.belongs_to or "user",
                }
            )
@@ -1212,7 +1212,7 @@ class Site(db.Model):

    @property
    def app_base_url(self):
        return dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.url_root.rstrip("/")
        return dify_config.APP_WEB_URL or request.url_root.rstrip("/")


class ApiToken(db.Model):

@@ -1488,7 +1488,7 @@ class TraceAppConfig(db.Model):

    @property
    def tracing_config_dict(self):
        return self.tracing_config if self.tracing_config else {}
        return self.tracing_config or {}

    @property
    def tracing_config_str(self):
@@ -15,6 +15,7 @@ select = [
    "C4", # flake8-comprehensions
    "E", # pycodestyle E rules
    "F", # pyflakes rules
    "FURB", # refurb rules
    "I", # isort rules
    "N", # pep8-naming
    "RUF019", # unnecessary-key-check

@@ -37,6 +38,8 @@ ignore = [
    "F405", # undefined-local-with-import-star-usage
    "F821", # undefined-name
    "F841", # unused-variable
    "FURB113", # repeated-append
    "FURB152", # math-constant
    "UP007", # non-pep604-annotation
    "UP032", # f-string
    "B005", # strip-with-multi-characters
@@ -544,7 +544,7 @@ class RegisterService:
        """Register account"""
        try:
            account = AccountService.create_account(
                email=email, name=name, interface_language=language if language else languages[0], password=password
                email=email, name=name, interface_language=language or languages[0], password=password
            )
            account.status = AccountStatus.ACTIVE.value if not status else status.value
            account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
@@ -81,13 +81,11 @@ class AppDslService:
            raise ValueError("Missing app in data argument")

        # get app basic info
        name = args.get("name") if args.get("name") else app_data.get("name")
        description = args.get("description") if args.get("description") else app_data.get("description", "")
        icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type")
        icon = args.get("icon") if args.get("icon") else app_data.get("icon")
        icon_background = (
            args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background")
        )
        name = args.get("name") or app_data.get("name")
        description = args.get("description") or app_data.get("description", "")
        icon_type = args.get("icon_type") or app_data.get("icon_type")
        icon = args.get("icon") or app_data.get("icon")
        icon_background = args.get("icon_background") or app_data.get("icon_background")
        use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False)

        # import dsl and create app
@@ -155,7 +155,7 @@ class DatasetService:
        dataset.tenant_id = tenant_id
        dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
        dataset.embedding_model = embedding_model.model if embedding_model else None
        dataset.permission = permission if permission else DatasetPermissionEnum.ONLY_ME
        dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
        db.session.add(dataset)
        db.session.commit()
        return dataset

@@ -681,11 +681,7 @@ class DocumentService:
                "score_threshold_enabled": False,
            }

            dataset.retrieval_model = (
                document_data.get("retrieval_model")
                if document_data.get("retrieval_model")
                else default_retrieval_model
            )
            dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model

        documents = []
        batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
@@ -33,7 +33,7 @@ class HitTestingService:

        # get retrieval model , if the model is not setting , using default
        if not retrieval_model:
            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
            retrieval_model = dataset.retrieval_model or default_retrieval_model

        all_documents = RetrievalService.retrieve(
            retrieval_method=retrieval_model.get("search_method", "semantic_search"),

@@ -46,9 +46,7 @@ class HitTestingService:
            reranking_model=retrieval_model.get("reranking_model", None)
            if retrieval_model["reranking_enable"]
            else None,
            reranking_mode=retrieval_model.get("reranking_mode")
            if retrieval_model.get("reranking_mode")
            else "reranking_model",
            reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
            weights=retrieval_model.get("weights", None),
        )
@@ -1,6 +1,7 @@
import logging
import mimetypes
import os
from pathlib import Path
from typing import Optional, cast

import requests

@@ -453,9 +454,8 @@ class ModelProviderService:
        mimetype = mimetype or "application/octet-stream"

        # read binary from file
        with open(file_path, "rb") as f:
            byte_data = f.read()
            return byte_data, mimetype
        byte_data = Path(file_path).read_bytes()
        return byte_data, mimetype

    def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
        """
@@ -1,6 +1,7 @@
import json
import logging
from os import path
from pathlib import Path
from typing import Optional

import requests

@@ -218,10 +219,9 @@ class RecommendedAppService:
            return cls.builtin_data

        root_path = current_app.root_path
        with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f:
            json_data = f.read()
            data = json.loads(json_data)
            cls.builtin_data = data
        cls.builtin_data = json.loads(
            Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8")
        )

        return cls.builtin_data
Some files were not shown because too many files have changed in this diff.