chore(api/services): apply ruff reformatting (#7599)

Co-authored-by: -LAN- <laipz8200@outlook.com>
Bowen Liang 2024-08-26 13:43:57 +08:00 committed by GitHub
parent 979422cdc6
commit 17fd773a30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
49 changed files with 2630 additions and 2655 deletions

api/pyproject.toml View File

@@ -74,7 +74,6 @@ exclude = [
"controllers/**/*.py",
"models/**/*.py",
"migrations/**/*",
"services/**/*.py",
]
[tool.pytest_env]

api/services/__init__.py View File

@@ -1,3 +1,3 @@
from . import errors
__all__ = ['errors']
__all__ = ["errors"]

api/services/account_service.py View File

@@ -39,12 +39,7 @@ from tasks.mail_reset_password_task import send_reset_password_mail_task
class AccountService:
reset_password_rate_limiter = RateLimiter(
prefix="reset_password_rate_limit",
max_attempts=5,
time_window=60 * 60
)
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60)
@staticmethod
def load_user(user_id: str) -> None | Account:
@@ -55,12 +50,15 @@ class AccountService:
if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
raise Unauthorized("Account is banned or closed.")
current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(
account_id=account.id, current=True
).first()
if current_tenant:
account.current_tenant_id = current_tenant.tenant_id
else:
available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \
.order_by(TenantAccountJoin.id.asc()).first()
available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
)
if not available_ta:
return None
@@ -74,14 +72,13 @@ class AccountService:
return account
@staticmethod
def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)):
payload = {
"user_id": account.id,
"exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
"iss": dify_config.EDITION,
"sub": 'Console API Passport',
"sub": "Console API Passport",
}
token = PassportService().issue(payload)
@@ -93,10 +90,10 @@ class AccountService:
account = Account.query.filter_by(email=email).first()
if not account:
raise AccountLoginError('Invalid email or password.')
raise AccountLoginError("Invalid email or password.")
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise AccountLoginError('Account is banned or closed.')
raise AccountLoginError("Account is banned or closed.")
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
@@ -104,7 +101,7 @@ class AccountService:
db.session.commit()
if account.password is None or not compare_password(password, account.password, account.password_salt):
raise AccountLoginError('Invalid email or password.')
raise AccountLoginError("Invalid email or password.")
return account
@staticmethod
@@ -129,11 +126,9 @@ class AccountService:
return account
@staticmethod
def create_account(email: str,
name: str,
interface_language: str,
password: Optional[str] = None,
interface_theme: str = 'light') -> Account:
def create_account(
email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light"
) -> Account:
"""create account"""
account = Account()
account.email = email
@@ -155,7 +150,7 @@ class AccountService:
account.interface_theme = interface_theme
# Set timezone based on language
account.timezone = language_timezone_mapping.get(interface_language, 'UTC')
account.timezone = language_timezone_mapping.get(interface_language, "UTC")
db.session.add(account)
db.session.commit()
@@ -166,8 +161,9 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id,
provider=provider).first()
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
account_id=account.id, provider=provider
).first()
if account_integrate:
# If it exists, update the record
@@ -176,15 +172,16 @@ class AccountService:
account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
else:
# If it does not exist, create a new record
account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id,
encrypted_token="")
account_integrate = AccountIntegrate(
account_id=account.id, provider=provider, open_id=open_id, encrypted_token=""
)
db.session.add(account_integrate)
db.session.commit()
logging.info(f'Account {account.id} linked {provider} account {open_id}.')
logging.info(f"Account {account.id} linked {provider} account {open_id}.")
except Exception as e:
logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}')
raise LinkAccountIntegrateError('Failed to link account.') from e
logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}")
raise LinkAccountIntegrateError("Failed to link account.") from e
@staticmethod
def close_account(account: Account) -> None:
@@ -218,7 +215,7 @@ class AccountService:
AccountService.update_last_login(account, ip_address=ip_address)
exp = timedelta(days=30)
token = AccountService.get_account_jwt_token(account, exp=exp)
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds()))
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds()))
return token
@staticmethod
@@ -236,22 +233,18 @@ class AccountService:
if cls.reset_password_rate_limiter.is_rate_limited(account.email):
raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.")
token = TokenManager.generate_token(account, 'reset_password')
send_reset_password_mail_task.delay(
language=account.interface_language,
to=account.email,
token=token
)
token = TokenManager.generate_token(account, "reset_password")
send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token)
cls.reset_password_rate_limiter.increment_rate_limit(account.email)
return token
@classmethod
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, 'reset_password')
TokenManager.revoke_token(token, "reset_password")
@classmethod
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, 'reset_password')
return TokenManager.get_token_data(token, "reset_password")
def _get_login_cache_key(*, account_id: str, token: str):
@@ -259,7 +252,6 @@ def _get_login_cache_key(*, account_id: str, token: str):
class TenantService:
@staticmethod
def create_tenant(name: str) -> Tenant:
"""Create tenant"""
@@ -275,31 +267,28 @@ class TenantService:
@staticmethod
def create_owner_tenant_if_not_exist(account: Account):
"""Create owner tenant if not exist"""
available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \
.order_by(TenantAccountJoin.id.asc()).first()
available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
)
if available_ta:
return
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role='owner')
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
db.session.commit()
tenant_was_created.send(tenant)
@staticmethod
def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin:
def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
"""Create tenant member"""
if role == TenantAccountJoinRole.OWNER.value:
if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
logging.error(f'Tenant {tenant.id} has already an owner.')
raise Exception('Tenant already has an owner.')
logging.error(f"Tenant {tenant.id} has already an owner.")
raise Exception("Tenant already has an owner.")
ta = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=role
)
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
db.session.add(ta)
db.session.commit()
return ta
@@ -307,9 +296,12 @@ class TenantService:
@staticmethod
def get_join_tenants(account: Account) -> list[Tenant]:
"""Get account join tenants"""
return db.session.query(Tenant).join(
TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id
).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all()
return (
db.session.query(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
.all()
)
@staticmethod
def get_current_tenant_by_account(account: Account):
@@ -333,16 +325,23 @@ class TenantService:
if tenant_id is None:
raise ValueError("Tenant ID must be provided.")
tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == tenant_id,
Tenant.status == TenantStatus.NORMAL,
).first()
tenant_account_join = (
db.session.query(TenantAccountJoin)
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
.filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == tenant_id,
Tenant.status == TenantStatus.NORMAL,
)
.first()
)
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False})
TenantAccountJoin.query.filter(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
).update({"current": False})
tenant_account_join.current = True
# Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id
@@ -354,9 +353,7 @@ class TenantService:
query = (
db.session.query(Account, TenantAccountJoin.role)
.select_from(Account)
.join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.filter(TenantAccountJoin.tenant_id == tenant.id)
)
@@ -375,11 +372,9 @@ class TenantService:
query = (
db.session.query(Account, TenantAccountJoin.role)
.select_from(Account)
.join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.filter(TenantAccountJoin.tenant_id == tenant.id)
.filter(TenantAccountJoin.role == 'dataset_operator')
.filter(TenantAccountJoin.role == "dataset_operator")
)
# Initialize an empty list to store the updated accounts
@@ -395,20 +390,25 @@ class TenantService:
def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool:
"""Check if user has any of the given roles for a tenant"""
if not all(isinstance(role, TenantAccountJoinRole) for role in roles):
raise ValueError('all roles must be TenantAccountJoinRole')
raise ValueError("all roles must be TenantAccountJoinRole")
return db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.role.in_([role.value for role in roles])
).first() is not None
return (
db.session.query(TenantAccountJoin)
.filter(
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])
)
.first()
is not None
)
@staticmethod
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]:
"""Get the role of the current account for a given tenant"""
join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.account_id == account.id
).first()
join = (
db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.first()
)
return join.role if join else None
@staticmethod
@@ -420,29 +420,26 @@ class TenantService:
def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None:
"""Check member permission"""
perms = {
'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
'remove': [TenantAccountRole.OWNER],
'update': [TenantAccountRole.OWNER]
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
"remove": [TenantAccountRole.OWNER],
"update": [TenantAccountRole.OWNER],
}
if action not in ['add', 'remove', 'update']:
if action not in ["add", "remove", "update"]:
raise InvalidActionError("Invalid action.")
if member:
if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.")
ta_operator = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id,
account_id=operator.id
).first()
ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f'No permission to {action} member.')
raise NoPermissionError(f"No permission to {action} member.")
@staticmethod
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
"""Remove member from tenant"""
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'):
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"):
raise CannotOperateSelfError("Cannot operate self.")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
@@ -455,23 +452,17 @@ class TenantService:
@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, 'update')
TenantService.check_member_permission(tenant, operator, member, "update")
target_member_join = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id,
account_id=member.id
).first()
target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
if target_member_join.role == new_role:
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
if new_role == 'owner':
if new_role == "owner":
# Find the current owner and change their role to 'admin'
current_owner_join = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id,
role='owner'
).first()
current_owner_join.role = 'admin'
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
current_owner_join.role = "admin"
# Update the role of the target member
target_member_join.role = new_role
@@ -480,8 +471,8 @@ class TenantService:
@staticmethod
def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
"""Dissolve tenant"""
if not TenantService.check_member_permission(tenant, operator, operator, 'remove'):
raise NoPermissionError('No permission to dissolve tenant.')
if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
raise NoPermissionError("No permission to dissolve tenant.")
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
db.session.delete(tenant)
db.session.commit()
@@ -494,10 +485,9 @@ class TenantService:
class RegisterService:
@classmethod
def _get_invitation_token_key(cls, token: str) -> str:
return f'member_invite:token:{token}'
return f"member_invite:token:{token}"
@classmethod
def setup(cls, email: str, name: str, password: str, ip_address: str) -> None:
@@ -523,9 +513,7 @@ class RegisterService:
TenantService.create_owner_tenant_if_not_exist(account)
dify_setup = DifySetup(
version=dify_config.CURRENT_VERSION
)
dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
db.session.add(dify_setup)
db.session.commit()
except Exception as e:
@@ -535,34 +523,35 @@ class RegisterService:
db.session.query(Tenant).delete()
db.session.commit()
logging.exception(f'Setup failed: {e}')
raise ValueError(f'Setup failed: {e}')
logging.exception(f"Setup failed: {e}")
raise ValueError(f"Setup failed: {e}")
@classmethod
def register(cls, email, name,
password: Optional[str] = None,
open_id: Optional[str] = None,
provider: Optional[str] = None,
language: Optional[str] = None,
status: Optional[AccountStatus] = None) -> Account:
def register(
cls,
email,
name,
password: Optional[str] = None,
open_id: Optional[str] = None,
provider: Optional[str] = None,
language: Optional[str] = None,
status: Optional[AccountStatus] = None,
) -> Account:
db.session.begin_nested()
"""Register account"""
try:
account = AccountService.create_account(
email=email,
name=name,
interface_language=language if language else languages[0],
password=password
email=email, name=name, interface_language=language if language else languages[0], password=password
)
account.status = AccountStatus.ACTIVE.value if not status else status.value
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
if open_id is not None or provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
if dify_config.EDITION != 'SELF_HOSTED':
if dify_config.EDITION != "SELF_HOSTED":
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role='owner')
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
@@ -570,30 +559,29 @@ class RegisterService:
db.session.commit()
except Exception as e:
db.session.rollback()
logging.error(f'Register failed: {e}')
raise AccountRegisterError(f'Registration failed: {e}') from e
logging.error(f"Register failed: {e}")
raise AccountRegisterError(f"Registration failed: {e}") from e
return account
@classmethod
def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str:
def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None
) -> str:
"""Invite new member"""
account = Account.query.filter_by(email=email).first()
if not account:
TenantService.check_member_permission(tenant, inviter, None, 'add')
name = email.split('@')[0]
TenantService.check_member_permission(tenant, inviter, None, "add")
name = email.split("@")[0]
account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING)
# Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role)
TenantService.switch_tenant(account, tenant.id)
else:
TenantService.check_member_permission(tenant, inviter, account, 'add')
ta = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id,
account_id=account.id
).first()
TenantService.check_member_permission(tenant, inviter, account, "add")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
if not ta:
TenantService.create_tenant_member(tenant, account, role)
@@ -609,7 +597,7 @@ class RegisterService:
language=account.interface_language,
to=email,
token=token,
inviter_name=inviter.name if inviter else 'Dify',
inviter_name=inviter.name if inviter else "Dify",
workspace_name=tenant.name,
)
@@ -619,23 +607,19 @@ class RegisterService:
def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
token = str(uuid.uuid4())
invitation_data = {
'account_id': account.id,
'email': account.email,
'workspace_id': tenant.id,
"account_id": account.id,
"email": account.email,
"workspace_id": tenant.id,
}
expiryHours = dify_config.INVITE_EXPIRY_HOURS
redis_client.setex(
cls._get_invitation_token_key(token),
expiryHours * 60 * 60,
json.dumps(invitation_data)
)
redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data))
return token
@classmethod
def revoke_token(cls, workspace_id: str, email: str, token: str):
if workspace_id and email:
email_hash = sha256(email.encode()).hexdigest()
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token)
redis_client.delete(cache_key)
else:
redis_client.delete(cls._get_invitation_token_key(token))
@@ -646,17 +630,21 @@ class RegisterService:
if not invitation_data:
return None
tenant = db.session.query(Tenant).filter(
Tenant.id == invitation_data['workspace_id'],
Tenant.status == 'normal'
).first()
tenant = (
db.session.query(Tenant)
.filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
.first()
)
if not tenant:
return None
tenant_account = db.session.query(Account, TenantAccountJoin.role).join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first()
tenant_account = (
db.session.query(Account, TenantAccountJoin.role)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
.first()
)
if not tenant_account:
return None
@@ -665,29 +653,29 @@ class RegisterService:
if not account:
return None
if invitation_data['account_id'] != str(account.id):
if invitation_data["account_id"] != str(account.id):
return None
return {
'account': account,
'data': invitation_data,
'tenant': tenant,
"account": account,
"data": invitation_data,
"tenant": tenant,
}
@classmethod
def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]:
if workspace_id is not None and email is not None:
email_hash = sha256(email.encode()).hexdigest()
cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}'
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
account_id = redis_client.get(cache_key)
if not account_id:
return None
return {
'account_id': account_id.decode('utf-8'),
'email': email,
'workspace_id': workspace_id,
"account_id": account_id.decode("utf-8"),
"email": email,
"workspace_id": workspace_id,
}
else:
data = redis_client.get(cls._get_invitation_token_key(token))

api/services/advanced_prompt_template_service.py View File

@@ -1,4 +1,3 @@
import copy
from core.prompt.prompt_templates.advanced_prompt_templates import (
@@ -17,59 +16,78 @@ from models.model import AppMode
class AdvancedPromptTemplateService:
@classmethod
def get_prompt(cls, args: dict) -> dict:
app_mode = args['app_mode']
model_mode = args['model_mode']
model_name = args['model_name']
has_context = args['has_context']
app_mode = args["app_mode"]
model_mode = args["model_mode"]
model_name = args["model_name"]
has_context = args["has_context"]
if 'baichuan' in model_name.lower():
if "baichuan" in model_name.lower():
return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
else:
return cls.get_common_prompt(app_mode, model_mode, has_context)
@classmethod
def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == AppMode.CHAT.value:
if model_mode == "completion":
return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
return cls.get_completion_prompt(
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == AppMode.COMPLETION.value:
if model_mode == "completion":
return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
return cls.get_completion_prompt(
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
return cls.get_chat_prompt(
copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
)
@classmethod
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
if has_context == 'true':
prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
if has_context == "true":
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
)
return prompt_template
@classmethod
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
if has_context == 'true':
prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
if has_context == "true":
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
)
return prompt_template
@classmethod
def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == AppMode.CHAT.value:
if model_mode == "completion":
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
elif app_mode == AppMode.COMPLETION.value:
if model_mode == "completion":
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
has_context,
baichuan_context_prompt,
)
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)

api/services/agent_service.py View File

@@ -10,59 +10,65 @@ from models.model import App, Conversation, EndUser, Message, MessageAgentThough
class AgentService:
@classmethod
def get_agent_logs(cls, app_model: App,
conversation_id: str,
message_id: str) -> dict:
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
"""
Service to get agent logs
"""
conversation: Conversation = db.session.query(Conversation).filter(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
).first()
conversation: Conversation = (
db.session.query(Conversation)
.filter(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
)
.first()
)
if not conversation:
raise ValueError(f"Conversation not found: {conversation_id}")
message: Message = db.session.query(Message).filter(
Message.id == message_id,
Message.conversation_id == conversation_id,
).first()
message: Message = (
db.session.query(Message)
.filter(
Message.id == message_id,
Message.conversation_id == conversation_id,
)
.first()
)
if not message:
raise ValueError(f"Message not found: {message_id}")
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if conversation.from_end_user_id:
# only select name field
executor = db.session.query(EndUser, EndUser.name).filter(
EndUser.id == conversation.from_end_user_id
).first()
executor = (
db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first()
)
else:
executor = db.session.query(Account, Account.name).filter(
Account.id == conversation.from_account_id
).first()
executor = (
db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first()
)
if executor:
executor = executor.name
else:
executor = 'Unknown'
executor = "Unknown"
timezone = pytz.timezone(current_user.timezone)
result = {
'meta': {
'status': 'success',
'executor': executor,
'start_time': message.created_at.astimezone(timezone).isoformat(),
'elapsed_time': message.provider_response_latency,
'total_tokens': message.answer_tokens + message.message_tokens,
'agent_mode': app_model.app_model_config.agent_mode_dict.get('strategy', 'react'),
'iterations': len(agent_thoughts),
"meta": {
"status": "success",
"executor": executor,
"start_time": message.created_at.astimezone(timezone).isoformat(),
"elapsed_time": message.provider_response_latency,
"total_tokens": message.answer_tokens + message.message_tokens,
"agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"),
"iterations": len(agent_thoughts),
},
'iterations': [],
'files': message.files,
"iterations": [],
"files": message.files,
}
agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict())
@@ -86,12 +92,12 @@ class AgentService:
tool_input = tool_inputs.get(tool_name, {})
tool_output = tool_outputs.get(tool_name, {})
tool_meta_data = tool_meta.get(tool_name, {})
tool_config = tool_meta_data.get('tool_config', {})
if tool_config.get('tool_provider_type', '') != 'dataset-retrieval':
tool_config = tool_meta_data.get("tool_config", {})
if tool_config.get("tool_provider_type", "") != "dataset-retrieval":
tool_icon = ToolManager.get_tool_icon(
tenant_id=app_model.tenant_id,
provider_type=tool_config.get('tool_provider_type', ''),
provider_id=tool_config.get('tool_provider', ''),
provider_type=tool_config.get("tool_provider_type", ""),
provider_id=tool_config.get("tool_provider", ""),
)
if not tool_icon:
tool_entity = find_agent_tool(tool_name)
@@ -102,30 +108,34 @@ class AgentService:
provider_id=tool_entity.provider_id,
)
else:
tool_icon = ''
tool_icon = ""
tool_calls.append({
'status': 'success' if not tool_meta_data.get('error') else 'error',
'error': tool_meta_data.get('error'),
'time_cost': tool_meta_data.get('time_cost', 0),
'tool_name': tool_name,
'tool_label': tool_label,
'tool_input': tool_input,
'tool_output': tool_output,
'tool_parameters': tool_meta_data.get('tool_parameters', {}),
'tool_icon': tool_icon,
})
tool_calls.append(
{
"status": "success" if not tool_meta_data.get("error") else "error",
"error": tool_meta_data.get("error"),
"time_cost": tool_meta_data.get("time_cost", 0),
"tool_name": tool_name,
"tool_label": tool_label,
"tool_input": tool_input,
"tool_output": tool_output,
"tool_parameters": tool_meta_data.get("tool_parameters", {}),
"tool_icon": tool_icon,
}
)
result['iterations'].append({
'tokens': agent_thought.tokens,
'tool_calls': tool_calls,
'tool_raw': {
'inputs': agent_thought.tool_input,
'outputs': agent_thought.observation,
},
'thought': agent_thought.thought,
'created_at': agent_thought.created_at.isoformat(),
'files': agent_thought.files,
})
result["iterations"].append(
{
"tokens": agent_thought.tokens,
"tool_calls": tool_calls,
"tool_raw": {
"inputs": agent_thought.tool_input,
"outputs": agent_thought.observation,
},
"thought": agent_thought.thought,
"created_at": agent_thought.created_at.isoformat(),
"files": agent_thought.files,
}
)
return result
return result

api/services/annotation_service.py View File

@@ -23,21 +23,18 @@ class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
if args.get('message_id'):
message_id = str(args['message_id'])
if args.get("message_id"):
message_id = str(args["message_id"])
# get message info
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app.id
).first()
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first()
if not message:
raise NotFound("Message Not Exists.")
@@ -45,159 +42,166 @@ class AppAnnotationService:
annotation = message.annotation
# save the message annotation
if annotation:
annotation.content = args['answer']
annotation.question = args['question']
annotation.content = args["answer"]
annotation.question = args["question"]
else:
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=message.conversation_id,
message_id=message.id,
content=args['answer'],
question=args['question'],
account_id=current_user.id
content=args["answer"],
question=args["question"],
account_id=current_user.id,
)
else:
annotation = MessageAnnotation(
app_id=app.id,
content=args['answer'],
question=args['question'],
account_id=current_user.id
app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id).first()
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
if annotation_setting:
add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
app_id, annotation_setting.collection_binding_id)
add_annotation_to_index_task.delay(
annotation.id,
args["question"],
current_user.current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
)
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id))
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
return {
'job_id': cache_result,
'job_status': 'processing'
}
return {"job_id": cache_result, "job_status": "processing"}
# async job
job_id = str(uuid.uuid4())
enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))
enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id))
# send batch add segments task
redis_client.setnx(enable_app_annotation_job_key, 'waiting')
enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id,
args['score_threshold'],
args['embedding_provider_name'], args['embedding_model_name'])
return {
'job_id': job_id,
'job_status': 'waiting'
}
redis_client.setnx(enable_app_annotation_job_key, "waiting")
enable_annotation_reply_task.delay(
str(job_id),
app_id,
current_user.id,
current_user.current_tenant_id,
args["score_threshold"],
args["embedding_provider_name"],
args["embedding_model_name"],
)
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def disable_app_annotation(cls, app_id: str) -> dict:
disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id))
cache_result = redis_client.get(disable_app_annotation_key)
if cache_result is not None:
return {
'job_id': cache_result,
'job_status': 'processing'
}
return {"job_id": cache_result, "job_status": "processing"}
# async job
job_id = str(uuid.uuid4())
disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))
disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id))
# send batch add segments task
redis_client.setnx(disable_app_annotation_job_key, 'waiting')
redis_client.setnx(disable_app_annotation_job_key, "waiting")
disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
return {
'job_id': job_id,
'job_status': 'waiting'
}
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
if keyword:
annotations = (db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.filter(
or_(
MessageAnnotation.question.ilike('%{}%'.format(keyword)),
MessageAnnotation.content.ilike('%{}%'.format(keyword))
annotations = (
db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.filter(
or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)),
MessageAnnotation.content.ilike("%{}%".format(keyword)),
)
)
.order_by(MessageAnnotation.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
.order_by(MessageAnnotation.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
else:
annotations = (db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
annotations = (
db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
return annotations.items, annotations.total
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotations = (db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc()).all())
annotations = (
db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc())
.all()
)
return annotations
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation = MessageAnnotation(
app_id=app.id,
content=args['answer'],
question=args['question'],
account_id=current_user.id
app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id).first()
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
if annotation_setting:
add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
app_id, annotation_setting.collection_binding_id)
add_annotation_to_index_task.delay(
annotation.id,
args["question"],
current_user.current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
)
return annotation
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
@@ -207,30 +211,34 @@ class AppAnnotationService:
if not annotation:
raise NotFound("Annotation not found")
annotation.content = args['answer']
annotation.question = args['question']
annotation.content = args["answer"]
annotation.question = args["question"]
db.session.commit()
# if annotation reply is enabled , add annotation to index
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id
).first()
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
update_annotation_to_index_task.delay(annotation.id, annotation.question,
current_user.current_tenant_id,
app_id, app_annotation_setting.collection_binding_id)
update_annotation_to_index_task.delay(
annotation.id,
annotation.question,
current_user.current_tenant_id,
app_id,
app_annotation_setting.collection_binding_id,
)
return annotation
@classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
@@ -242,33 +250,34 @@ class AppAnnotationService:
db.session.delete(annotation)
annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
.filter(AppAnnotationHitHistory.annotation_id == annotation_id)
.all()
)
annotation_hit_histories = (
db.session.query(AppAnnotationHitHistory)
.filter(AppAnnotationHitHistory.annotation_id == annotation_id)
.all()
)
if annotation_hit_histories:
for annotation_hit_history in annotation_hit_histories:
db.session.delete(annotation_hit_history)
db.session.commit()
# if annotation reply is enabled , delete annotation index
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id
).first()
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
delete_annotation_index_task.delay(annotation.id, app_id,
current_user.current_tenant_id,
app_annotation_setting.collection_binding_id)
delete_annotation_index_task.delay(
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
)
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
@@ -278,10 +287,7 @@ class AppAnnotationService:
df = pd.read_csv(file)
result = []
for index, row in df.iterrows():
content = {
'question': row[0],
'answer': row[1]
}
content = {"question": row[0], "answer": row[1]}
result.append(content)
if len(result) == 0:
raise ValueError("The CSV file is empty.")
@@ -293,28 +299,24 @@ class AppAnnotationService:
raise ValueError("The number of annotations exceeds the limit of your subscription.")
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
# send batch add segments task
redis_client.setnx(indexing_cache_key, 'waiting')
batch_import_annotations_task.delay(str(job_id), result, app_id,
current_user.current_tenant_id, current_user.id)
redis_client.setnx(indexing_cache_key, "waiting")
batch_import_annotations_task.delay(
str(job_id), result, app_id, current_user.current_tenant_id, current_user.id
)
except Exception as e:
return {
'error_msg': str(e)
}
return {
'job_id': job_id,
'job_status': 'waiting'
}
return {"error_msg": str(e)}
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
@@ -324,12 +326,15 @@ class AppAnnotationService:
if not annotation:
raise NotFound("Annotation not found")
annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
.filter(AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)
.order_by(AppAnnotationHitHistory.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
annotation_hit_histories = (
db.session.query(AppAnnotationHitHistory)
.filter(
AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)
.order_by(AppAnnotationHitHistory.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
return annotation_hit_histories.items, annotation_hit_histories.total
@classmethod
@@ -341,15 +346,21 @@ class AppAnnotationService:
return annotation
@classmethod
def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str,
annotation_content: str, query: str, user_id: str,
message_id: str, from_source: str, score: float):
def add_annotation_history(
cls,
annotation_id: str,
app_id: str,
annotation_question: str,
annotation_content: str,
query: str,
user_id: str,
message_id: str,
from_source: str,
score: float,
):
# add hit count to annotation
db.session.query(MessageAnnotation).filter(
MessageAnnotation.id == annotation_id
).update(
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1},
synchronize_session=False
db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update(
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False
)
annotation_hit_history = AppAnnotationHitHistory(
@@ -361,7 +372,7 @@ class AppAnnotationService:
score=score,
message_id=message_id,
annotation_question=annotation_question,
annotation_content=annotation_content
annotation_content=annotation_content,
)
db.session.add(annotation_hit_history)
db.session.commit()
@@ -369,17 +380,18 @@ class AppAnnotationService:
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id).first()
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
@@ -388,32 +400,34 @@ class AppAnnotationService:
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name
}
"embedding_model_name": collection_binding_detail.model_name,
},
}
return {
"enabled": False
}
return {"enabled": False}
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id,
AppAnnotationSetting.id == annotation_setting_id,
).first()
annotation_setting = (
db.session.query(AppAnnotationSetting)
.filter(
AppAnnotationSetting.app_id == app_id,
AppAnnotationSetting.id == annotation_setting_id,
)
.first()
)
if not annotation_setting:
raise NotFound("App annotation not found")
annotation_setting.score_threshold = args['score_threshold']
annotation_setting.score_threshold = args["score_threshold"]
annotation_setting.updated_user_id = current_user.id
annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(annotation_setting)
@@ -427,6 +441,6 @@ class AppAnnotationService:
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name
}
"embedding_model_name": collection_binding_detail.model_name,
},
}

api/services/api_based_extension_service.py View File

@@ -5,13 +5,14 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
class APIBasedExtensionService:
@staticmethod
def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
extension_list = db.session.query(APIBasedExtension) \
.filter_by(tenant_id=tenant_id) \
.order_by(APIBasedExtension.created_at.desc()) \
.all()
extension_list = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=tenant_id)
.order_by(APIBasedExtension.created_at.desc())
.all()
)
for extension in extension_list:
extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
@@ -35,10 +36,12 @@ class APIBasedExtensionService:
@staticmethod
def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
extension = db.session.query(APIBasedExtension) \
.filter_by(tenant_id=tenant_id) \
.filter_by(id=api_based_extension_id) \
extension = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=tenant_id)
.filter_by(id=api_based_extension_id)
.first()
)
if not extension:
raise ValueError("API based extension is not found")
@@ -55,20 +58,24 @@ class APIBasedExtensionService:
if not extension_data.id:
# case one: check new data, name must be unique
is_name_existed = db.session.query(APIBasedExtension) \
.filter_by(tenant_id=extension_data.tenant_id) \
.filter_by(name=extension_data.name) \
is_name_existed = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=extension_data.tenant_id)
.filter_by(name=extension_data.name)
.first()
)
if is_name_existed:
raise ValueError("name must be unique, it is already existed")
else:
# case two: check existing data, name must be unique
is_name_existed = db.session.query(APIBasedExtension) \
.filter_by(tenant_id=extension_data.tenant_id) \
.filter_by(name=extension_data.name) \
.filter(APIBasedExtension.id != extension_data.id) \
is_name_existed = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=extension_data.tenant_id)
.filter_by(name=extension_data.name)
.filter(APIBasedExtension.id != extension_data.id)
.first()
)
if is_name_existed:
raise ValueError("name must be unique, it is already existed")
@@ -92,7 +99,7 @@ class APIBasedExtensionService:
try:
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
resp = client.request(point=APIBasedExtensionPoint.PING, params={})
if resp.get('result') != 'pong':
if resp.get("result") != "pong":
raise ValueError(resp)
except Exception as e:
raise ValueError("connection error: {}".format(e))

api/services/app_dsl_service.py View File

@@ -75,43 +75,44 @@ class AppDslService:
# check or repair dsl version
import_data = cls._check_or_fix_dsl(import_data)
app_data = import_data.get('app')
app_data = import_data.get("app")
if not app_data:
raise ValueError("Missing app in data argument")
# get app basic info
name = args.get("name") if args.get("name") else app_data.get('name')
description = args.get("description") if args.get("description") else app_data.get('description', '')
icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get('icon_type')
icon = args.get("icon") if args.get("icon") else app_data.get('icon')
icon_background = args.get("icon_background") if args.get("icon_background") \
else app_data.get('icon_background')
name = args.get("name") if args.get("name") else app_data.get("name")
description = args.get("description") if args.get("description") else app_data.get("description", "")
icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type")
icon = args.get("icon") if args.get("icon") else app_data.get("icon")
icon_background = (
args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background")
)
# import dsl and create app
app_mode = AppMode.value_of(app_data.get('mode'))
app_mode = AppMode.value_of(app_data.get("mode"))
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
app = cls._import_and_create_new_workflow_based_app(
tenant_id=tenant_id,
app_mode=app_mode,
workflow_data=import_data.get('workflow'),
workflow_data=import_data.get("workflow"),
account=account,
name=name,
description=description,
icon_type=icon_type,
icon=icon,
icon_background=icon_background
icon_background=icon_background,
)
elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]:
app = cls._import_and_create_new_model_config_based_app(
tenant_id=tenant_id,
app_mode=app_mode,
model_config_data=import_data.get('model_config'),
model_config_data=import_data.get("model_config"),
account=account,
name=name,
description=description,
icon_type=icon_type,
icon=icon,
icon_background=icon_background
icon_background=icon_background,
)
else:
raise ValueError("Invalid app mode")
@@ -134,27 +135,26 @@ class AppDslService:
# check or repair dsl version
import_data = cls._check_or_fix_dsl(import_data)
app_data = import_data.get('app')
app_data = import_data.get("app")
if not app_data:
raise ValueError("Missing app in data argument")
# import dsl and overwrite app
app_mode = AppMode.value_of(app_data.get('mode'))
app_mode = AppMode.value_of(app_data.get("mode"))
if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
raise ValueError("Only support import workflow in advanced-chat or workflow app.")
if app_data.get('mode') != app_model.mode:
raise ValueError(
f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}")
if app_data.get("mode") != app_model.mode:
raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}")
return cls._import_and_overwrite_workflow_based_app(
app_model=app_model,
workflow_data=import_data.get('workflow'),
workflow_data=import_data.get("workflow"),
account=account,
)
@classmethod
def export_dsl(cls, app_model: App, include_secret:bool = False) -> str:
def export_dsl(cls, app_model: App, include_secret: bool = False) -> str:
"""
Export app
:param app_model: App instance
@@ -168,14 +168,16 @@ class AppDslService:
"app": {
"name": app_model.name,
"mode": app_model.mode,
"icon": '🤖' if app_model.icon_type == 'image' else app_model.icon,
"icon_background": '#FFEAD5' if app_model.icon_type == 'image' else app_model.icon_background,
"description": app_model.description
}
"icon": "🤖" if app_model.icon_type == "image" else app_model.icon,
"icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
"description": app_model.description,
},
}
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret)
cls._append_workflow_export_data(
export_data=export_data, app_model=app_model, include_secret=include_secret
)
else:
cls._append_model_config_export_data(export_data, app_model)
@@ -188,31 +190,35 @@ class AppDslService:
:param import_data: import data
"""
if not import_data.get('version'):
import_data['version'] = "0.1.0"
if not import_data.get("version"):
import_data["version"] = "0.1.0"
if not import_data.get('kind') or import_data.get('kind') != "app":
import_data['kind'] = "app"
if not import_data.get("kind") or import_data.get("kind") != "app":
import_data["kind"] = "app"
if import_data.get('version') != current_dsl_version:
if import_data.get("version") != current_dsl_version:
# Currently only one DSL version, so no difference checks or compatibility fixes will be performed.
logger.warning(f"DSL version {import_data.get('version')} is not compatible "
f"with current version {current_dsl_version}, related to "
f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}.")
logger.warning(
f"DSL version {import_data.get('version')} is not compatible "
f"with current version {current_dsl_version}, related to "
f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}."
)
return import_data
@classmethod
def _import_and_create_new_workflow_based_app(cls,
tenant_id: str,
app_mode: AppMode,
workflow_data: dict,
account: Account,
name: str,
description: str,
icon_type: str,
icon: str,
icon_background: str) -> App:
def _import_and_create_new_workflow_based_app(
cls,
tenant_id: str,
app_mode: AppMode,
workflow_data: dict,
account: Account,
name: str,
description: str,
icon_type: str,
icon: str,
icon_background: str,
) -> App:
"""
Import app dsl and create new workflow based app
@ -227,8 +233,7 @@ class AppDslService:
:param icon_background: app icon background
"""
if not workflow_data:
raise ValueError("Missing workflow in data argument "
"when app mode is advanced-chat or workflow")
raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow")
app = cls._create_app(
tenant_id=tenant_id,
@ -238,37 +243,32 @@ class AppDslService:
description=description,
icon_type=icon_type,
icon=icon,
icon_background=icon_background
icon_background=icon_background,
)
# init draft workflow
environment_variables_list = workflow_data.get('environment_variables') or []
environment_variables_list = workflow_data.get("environment_variables") or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
conversation_variables_list = workflow_data.get('conversation_variables') or []
conversation_variables_list = workflow_data.get("conversation_variables") or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
workflow_service = WorkflowService()
draft_workflow = workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get('graph', {}),
features=workflow_data.get('../core/app/features', {}),
graph=workflow_data.get("graph", {}),
features=workflow_data.get("../core/app/features", {}),
unique_hash=None,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
workflow_service.publish_workflow(
app_model=app,
account=account,
draft_workflow=draft_workflow
)
workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow)
return app
@classmethod
def _import_and_overwrite_workflow_based_app(cls,
app_model: App,
workflow_data: dict,
account: Account) -> Workflow:
def _import_and_overwrite_workflow_based_app(
cls, app_model: App, workflow_data: dict, account: Account
) -> Workflow:
"""
Import app dsl and overwrite workflow based app
@ -277,8 +277,7 @@ class AppDslService:
:param account: Account instance
"""
if not workflow_data:
raise ValueError("Missing workflow in data argument "
"when app mode is advanced-chat or workflow")
raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow")
# fetch draft workflow by app_model
workflow_service = WorkflowService()
@ -289,14 +288,14 @@ class AppDslService:
unique_hash = None
# sync draft workflow
environment_variables_list = workflow_data.get('environment_variables') or []
environment_variables_list = workflow_data.get("environment_variables") or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
conversation_variables_list = workflow_data.get('conversation_variables') or []
conversation_variables_list = workflow_data.get("conversation_variables") or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
draft_workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=workflow_data.get('graph', {}),
features=workflow_data.get('features', {}),
graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
@ -306,16 +305,18 @@ class AppDslService:
return draft_workflow
@classmethod
def _import_and_create_new_model_config_based_app(cls,
tenant_id: str,
app_mode: AppMode,
model_config_data: dict,
account: Account,
name: str,
description: str,
icon_type: str,
icon: str,
icon_background: str) -> App:
def _import_and_create_new_model_config_based_app(
cls,
tenant_id: str,
app_mode: AppMode,
model_config_data: dict,
account: Account,
name: str,
description: str,
icon_type: str,
icon: str,
icon_background: str,
) -> App:
"""
Import app dsl and create new model config based app
@ -329,8 +330,7 @@ class AppDslService:
:param icon_background: app icon background
"""
if not model_config_data:
raise ValueError("Missing model_config in data argument "
"when app mode is chat, agent-chat or completion")
raise ValueError("Missing model_config in data argument " "when app mode is chat, agent-chat or completion")
app = cls._create_app(
tenant_id=tenant_id,
@ -340,7 +340,7 @@ class AppDslService:
description=description,
icon_type=icon_type,
icon=icon,
icon_background=icon_background
icon_background=icon_background,
)
app_model_config = AppModelConfig()
@ -352,23 +352,22 @@ class AppDslService:
app.app_model_config_id = app_model_config.id
app_model_config_was_updated.send(
app,
app_model_config=app_model_config
)
app_model_config_was_updated.send(app, app_model_config=app_model_config)
return app
@classmethod
def _create_app(cls,
tenant_id: str,
app_mode: AppMode,
account: Account,
name: str,
description: str,
icon_type: str,
icon: str,
icon_background: str) -> App:
def _create_app(
cls,
tenant_id: str,
app_mode: AppMode,
account: Account,
name: str,
description: str,
icon_type: str,
icon: str,
icon_background: str,
) -> App:
"""
Create new app
@ -390,7 +389,7 @@ class AppDslService:
icon=icon,
icon_background=icon_background,
enable_site=True,
enable_api=True
enable_api=True,
)
db.session.add(app)
@ -412,7 +411,7 @@ class AppDslService:
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
export_data['workflow'] = workflow.to_dict(include_secret=include_secret)
export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
@classmethod
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
@ -425,4 +424,4 @@ class AppDslService:
if not app_model_config:
raise ValueError("Missing app configuration, please check.")
export_data['model_config'] = app_model_config.to_dict()
export_data["model_config"] = app_model_config.to_dict()

View File

@ -14,14 +14,15 @@ from services.workflow_service import WorkflowService
class AppGenerateService:
@classmethod
def generate(cls, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
streaming: bool = True,
):
def generate(
cls,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
streaming: bool = True,
):
"""
App Content Generate
:param app_model: app model
@ -37,51 +38,54 @@ class AppGenerateService:
try:
request_id = rate_limit.enter(request_id)
if app_model.mode == AppMode.COMPLETION.value:
return rate_limit.generate(CompletionAppGenerator().generate(
app_model=app_model,
user=user,
args=args,
invoke_from=invoke_from,
stream=streaming
), request_id)
return rate_limit.generate(
CompletionAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
),
request_id,
)
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
return rate_limit.generate(AgentChatAppGenerator().generate(
app_model=app_model,
user=user,
args=args,
invoke_from=invoke_from,
stream=streaming
), request_id)
return rate_limit.generate(
AgentChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
),
request_id,
)
elif app_model.mode == AppMode.CHAT.value:
return rate_limit.generate(ChatAppGenerator().generate(
app_model=app_model,
user=user,
args=args,
invoke_from=invoke_from,
stream=streaming
), request_id)
return rate_limit.generate(
ChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
),
request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
workflow = cls._get_workflow(app_model, invoke_from)
return rate_limit.generate(AdvancedChatAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
stream=streaming
), request_id)
return rate_limit.generate(
AdvancedChatAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
stream=streaming,
),
request_id,
)
elif app_model.mode == AppMode.WORKFLOW.value:
workflow = cls._get_workflow(app_model, invoke_from)
return rate_limit.generate(WorkflowAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
stream=streaming
), request_id)
return rate_limit.generate(
WorkflowAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
stream=streaming,
),
request_id,
)
else:
raise ValueError(f'Invalid app mode {app_model.mode}')
raise ValueError(f"Invalid app mode {app_model.mode}")
finally:
if not streaming:
rate_limit.exit(request_id)
@ -94,38 +98,31 @@ class AppGenerateService:
return max_active_requests
@classmethod
def generate_single_iteration(cls, app_model: App,
user: Union[Account, EndUser],
node_id: str,
args: Any,
streaming: bool = True):
def generate_single_iteration(
cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True
):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
stream=streaming
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming
)
elif app_model.mode == AppMode.WORKFLOW.value:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return WorkflowAppGenerator().single_iteration_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
stream=streaming
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming
)
else:
raise ValueError(f'Invalid app mode {app_model.mode}')
raise ValueError(f"Invalid app mode {app_model.mode}")
@classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
-> Union[dict, Generator]:
def generate_more_like_this(
cls,
app_model: App,
user: Union[Account, EndUser],
message_id: str,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[dict, Generator]:
"""
Generate more like this
:param app_model: app model
@ -136,11 +133,7 @@ class AppGenerateService:
:return:
"""
return CompletionAppGenerator().generate_more_like_this(
app_model=app_model,
message_id=message_id,
user=user,
invoke_from=invoke_from,
stream=streaming
app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming
)
@classmethod
@ -157,12 +150,12 @@ class AppGenerateService:
workflow = workflow_service.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError('Workflow not initialized')
raise ValueError("Workflow not initialized")
else:
# fetch published workflow by app_model
workflow = workflow_service.get_published_workflow(app_model=app_model)
if not workflow:
raise ValueError('Workflow not published')
raise ValueError("Workflow not published")
return workflow

View File

@ -5,7 +5,6 @@ from models.model import AppMode
class AppModelConfigService:
@classmethod
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
if app_mode == AppMode.CHAT:

View File

@ -33,27 +33,22 @@ class AppService:
:param args: request args
:return:
"""
filters = [
App.tenant_id == tenant_id,
App.is_universal == False
]
filters = [App.tenant_id == tenant_id, App.is_universal == False]
if args['mode'] == 'workflow':
if args["mode"] == "workflow":
filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value]))
elif args['mode'] == 'chat':
elif args["mode"] == "chat":
filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value]))
elif args['mode'] == 'agent-chat':
elif args["mode"] == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT.value)
elif args['mode'] == 'channel':
elif args["mode"] == "channel":
filters.append(App.mode == AppMode.CHANNEL.value)
if args.get('name'):
name = args['name'][:30]
filters.append(App.name.ilike(f'%{name}%'))
if args.get('tag_ids'):
target_ids = TagService.get_target_ids_by_tag_ids('app',
tenant_id,
args['tag_ids'])
if args.get("name"):
name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%"))
if args.get("tag_ids"):
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
if target_ids:
filters.append(App.id.in_(target_ids))
else:
@ -61,9 +56,9 @@ class AppService:
app_models = db.paginate(
db.select(App).where(*filters).order_by(App.created_at.desc()),
page=args['page'],
per_page=args['limit'],
error_out=False
page=args["page"],
per_page=args["limit"],
error_out=False,
)
return app_models
@ -75,21 +70,20 @@ class AppService:
:param args: request args
:param account: Account instance
"""
app_mode = AppMode.value_of(args['mode'])
app_mode = AppMode.value_of(args["mode"])
app_template = default_app_templates[app_mode]
# get model config
default_model_config = app_template.get('model_config')
default_model_config = app_template.get("model_config")
default_model_config = default_model_config.copy() if default_model_config else None
if default_model_config and 'model' in default_model_config:
if default_model_config and "model" in default_model_config:
# get model provider
model_manager = ModelManager()
# get default model instance
try:
model_instance = model_manager.get_default_model_instance(
tenant_id=account.current_tenant_id,
model_type=ModelType.LLM
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
)
except (ProviderTokenNotInitError, LLMBadRequestError):
model_instance = None
@ -98,39 +92,41 @@ class AppService:
model_instance = None
if model_instance:
if model_instance.model == default_model_config['model']['name'] and model_instance.provider == default_model_config['model']['provider']:
default_model_dict = default_model_config['model']
if (
model_instance.model == default_model_config["model"]["name"]
and model_instance.provider == default_model_config["model"]["provider"]
):
default_model_dict = default_model_config["model"]
else:
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
default_model_dict = {
'provider': model_instance.provider,
'name': model_instance.model,
'mode': model_schema.model_properties.get(ModelPropertyKey.MODE),
'completion_params': {}
"provider": model_instance.provider,
"name": model_instance.model,
"mode": model_schema.model_properties.get(ModelPropertyKey.MODE),
"completion_params": {},
}
else:
provider, model = model_manager.get_default_provider_model_name(
tenant_id=account.current_tenant_id,
model_type=ModelType.LLM
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
)
default_model_config['model']['provider'] = provider
default_model_config['model']['name'] = model
default_model_dict = default_model_config['model']
default_model_config["model"]["provider"] = provider
default_model_config["model"]["name"] = model
default_model_dict = default_model_config["model"]
default_model_config['model'] = json.dumps(default_model_dict)
default_model_config["model"] = json.dumps(default_model_dict)
app = App(**app_template['app'])
app.name = args['name']
app.description = args.get('description', '')
app.mode = args['mode']
app.icon_type = args.get('icon_type', 'emoji')
app.icon = args['icon']
app.icon_background = args['icon_background']
app = App(**app_template["app"])
app.name = args["name"]
app.description = args.get("description", "")
app.mode = args["mode"]
app.icon_type = args.get("icon_type", "emoji")
app.icon = args["icon"]
app.icon_background = args["icon_background"]
app.tenant_id = tenant_id
app.api_rph = args.get('api_rph', 0)
app.api_rpm = args.get('api_rpm', 0)
app.api_rph = args.get("api_rph", 0)
app.api_rpm = args.get("api_rpm", 0)
db.session.add(app)
db.session.flush()
@ -158,7 +154,7 @@ class AppService:
model_config: AppModelConfig = app.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
for tool in agent_mode.get("tools") or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
@ -174,7 +170,7 @@ class AppService:
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app.id}'
identity_id=f"AGENT.{app.id}",
)
# get decrypted parameters
@ -185,7 +181,7 @@ class AppService:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
tool["tool_parameters"] = masked_parameter
except Exception as e:
pass
@ -215,12 +211,12 @@ class AppService:
:param args: request args
:return: App instance
"""
app.name = args.get('name')
app.description = args.get('description', '')
app.max_active_requests = args.get('max_active_requests')
app.icon_type = args.get('icon_type', 'emoji')
app.icon = args.get('icon')
app.icon_background = args.get('icon_background')
app.name = args.get("name")
app.description = args.get("description", "")
app.max_active_requests = args.get("max_active_requests")
app.icon_type = args.get("icon_type", "emoji")
app.icon = args.get("icon")
app.icon_background = args.get("icon_background")
app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
@ -298,10 +294,7 @@ class AppService:
db.session.commit()
# Trigger asynchronous deletion of app and related data
remove_app_and_related_data_task.delay(
tenant_id=app.tenant_id,
app_id=app.id
)
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
def get_app_meta(self, app_model: App) -> dict:
"""
@ -311,9 +304,7 @@ class AppService:
"""
app_mode = AppMode.value_of(app_model.mode)
meta = {
'tool_icons': {}
}
meta = {"tool_icons": {}}
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
workflow = app_model.workflow
@ -321,17 +312,19 @@ class AppService:
return meta
graph = workflow.graph_dict
nodes = graph.get('nodes', [])
nodes = graph.get("nodes", [])
tools = []
for node in nodes:
if node.get('data', {}).get('type') == 'tool':
node_data = node.get('data', {})
tools.append({
'provider_type': node_data.get('provider_type'),
'provider_id': node_data.get('provider_id'),
'tool_name': node_data.get('tool_name'),
'tool_parameters': {}
})
if node.get("data", {}).get("type") == "tool":
node_data = node.get("data", {})
tools.append(
{
"provider_type": node_data.get("provider_type"),
"provider_id": node_data.get("provider_id"),
"tool_name": node_data.get("tool_name"),
"tool_parameters": {},
}
)
else:
app_model_config: AppModelConfig = app_model.app_model_config
@ -341,30 +334,26 @@ class AppService:
agent_config = app_model_config.agent_mode_dict or {}
# get all tools
tools = agent_config.get('tools', [])
tools = agent_config.get("tools", [])
url_prefix = (dify_config.CONSOLE_API_URL
+ "/console/api/workspaces/current/tool-provider/builtin/")
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
for tool in tools:
keys = list(tool.keys())
if len(keys) >= 4:
# current tool standard
provider_type = tool.get('provider_type')
provider_id = tool.get('provider_id')
tool_name = tool.get('tool_name')
if provider_type == 'builtin':
meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon'
elif provider_type == 'api':
provider_type = tool.get("provider_type")
provider_id = tool.get("provider_id")
tool_name = tool.get("tool_name")
if provider_type == "builtin":
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
elif provider_type == "api":
try:
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.id == provider_id
).first()
meta['tool_icons'][tool_name] = json.loads(provider.icon)
provider: ApiToolProvider = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first()
)
meta["tool_icons"][tool_name] = json.loads(provider.icon)
except:
meta['tool_icons'][tool_name] = {
"background": "#252525",
"content": "\ud83d\ude01"
}
meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
return meta

View File

@ -17,7 +17,7 @@ from services.errors.audio import (
FILE_SIZE = 30
FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr']
ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"]
logger = logging.getLogger(__name__)
@ -31,19 +31,19 @@ class AudioService:
raise ValueError("Speech to text is not enabled")
features_dict = workflow.features_dict
if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'):
if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
raise ValueError("Speech to text is not enabled")
else:
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
if not app_model_config.speech_to_text_dict["enabled"]:
raise ValueError("Speech to text is not enabled")
if file is None:
raise NoAudioUploadedServiceError()
extension = file.mimetype
if extension not in [f'audio/{ext}' for ext in ALLOWED_EXTENSIONS]:
if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]:
raise UnsupportedAudioTypeServiceError()
file_content = file.read()
@ -55,20 +55,25 @@ class AudioService:
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id,
model_type=ModelType.SPEECH2TEXT
tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT
)
if model_instance is None:
raise ProviderNotSupportSpeechToTextServiceError()
buffer = io.BytesIO(file_content)
buffer.name = 'temp.mp3'
buffer.name = "temp.mp3"
return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)}
@classmethod
def transcript_tts(cls, app_model: App, text: Optional[str] = None,
voice: Optional[str] = None, end_user: Optional[str] = None, message_id: Optional[str] = None):
def transcript_tts(
cls,
app_model: App,
text: Optional[str] = None,
voice: Optional[str] = None,
end_user: Optional[str] = None,
message_id: Optional[str] = None,
):
from collections.abc import Generator
from flask import Response, stream_with_context
@ -84,65 +89,56 @@ class AudioService:
raise ValueError("TTS is not enabled")
features_dict = workflow.features_dict
if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'):
if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"):
raise ValueError("TTS is not enabled")
voice = features_dict['text_to_speech'].get('voice') if voice is None else voice
voice = features_dict["text_to_speech"].get("voice") if voice is None else voice
else:
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
if not text_to_speech_dict.get('enabled'):
if not text_to_speech_dict.get("enabled"):
raise ValueError("TTS is not enabled")
voice = text_to_speech_dict.get('voice') if voice is None else voice
voice = text_to_speech_dict.get("voice") if voice is None else voice
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id,
model_type=ModelType.TTS
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
)
try:
if not voice:
voices = model_instance.get_tts_voices()
if voices:
voice = voices[0].get('value')
voice = voices[0].get("value")
else:
raise ValueError("Sorry, no voice available.")
return model_instance.invoke_tts(
content_text=text_content.strip(),
user=end_user,
tenant_id=app_model.tenant_id,
voice=voice
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
)
except Exception as e:
raise e
if message_id:
message = db.session.query(Message).filter(
Message.id == message_id
).first()
if message.answer == '' and message.status == 'normal':
message = db.session.query(Message).filter(Message.id == message_id).first()
if message.answer == "" and message.status == "normal":
return None
else:
response = invoke_tts(message.answer, app_model=app_model, voice=voice)
if isinstance(response, Generator):
return Response(stream_with_context(response), content_type='audio/mpeg')
return Response(stream_with_context(response), content_type="audio/mpeg")
return response
else:
response = invoke_tts(text, app_model, voice)
if isinstance(response, Generator):
return Response(stream_with_context(response), content_type='audio/mpeg')
return Response(stream_with_context(response), content_type="audio/mpeg")
return response
@classmethod
def transcript_tts_voices(cls, tenant_id: str, language: str):
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TTS
)
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS)
if model_instance is None:
raise ProviderNotSupportTextToSpeechServiceError()

View File

@ -1,14 +1,12 @@
from services.auth.firecrawl import FirecrawlAuth
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
if provider == 'firecrawl':
if provider == "firecrawl":
self.auth = FirecrawlAuth(credentials)
else:
raise ValueError('Invalid provider')
raise ValueError("Invalid provider")
def validate_credentials(self):
return self.auth.validate_credentials()

View File

@ -7,39 +7,43 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str) -> list:
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.disabled.is_(False)
).all()
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
.all()
)
return data_source_api_key_bindings
@staticmethod
def create_provider_auth(tenant_id: str, args: dict):
auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
if auth_result:
# Encrypt the api key
api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
args['credentials']['config']['api_key'] = api_key
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
args["credentials"]["config"]["api_key"] = api_key
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
data_source_api_key_binding.tenant_id = tenant_id
data_source_api_key_binding.category = args['category']
data_source_api_key_binding.provider = args['provider']
data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
data_source_api_key_binding.category = args["category"]
data_source_api_key_binding.provider = args["provider"]
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
db.session.add(data_source_api_key_binding)
db.session.commit()
@staticmethod
def get_auth_credentials(tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category,
DataSourceApiKeyAuthBinding.provider == provider,
DataSourceApiKeyAuthBinding.disabled.is_(False)
).first()
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category,
DataSourceApiKeyAuthBinding.provider == provider,
DataSourceApiKeyAuthBinding.disabled.is_(False),
)
.first()
)
if not data_source_api_key_bindings:
return None
credentials = json.loads(data_source_api_key_bindings.credentials)
@ -47,24 +51,24 @@ class ApiKeyAuthService:
@staticmethod
def delete_provider_auth(tenant_id: str, binding_id: str):
data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.id == binding_id
).first()
data_source_api_key_binding = (
db.session.query(DataSourceApiKeyAuthBinding)
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
.first()
)
if data_source_api_key_binding:
db.session.delete(data_source_api_key_binding)
db.session.commit()
@classmethod
def validate_api_key_auth_args(cls, args):
if 'category' not in args or not args['category']:
raise ValueError('category is required')
if 'provider' not in args or not args['provider']:
raise ValueError('provider is required')
if 'credentials' not in args or not args['credentials']:
raise ValueError('credentials is required')
if not isinstance(args['credentials'], dict):
raise ValueError('credentials must be a dictionary')
if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
raise ValueError('auth_type is required')
if "category" not in args or not args["category"]:
raise ValueError("category is required")
if "provider" not in args or not args["provider"]:
raise ValueError("provider is required")
if "credentials" not in args or not args["credentials"]:
raise ValueError("credentials is required")
if not isinstance(args["credentials"], dict):
raise ValueError("credentials must be a dictionary")
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
raise ValueError("auth_type is required")

View File

@ -8,49 +8,40 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get('auth_type')
if auth_type != 'bearer':
raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
self.api_key = credentials.get('config').get('api_key', None)
self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
self.api_key = credentials.get("config").get("api_key", None)
self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev")
if not self.api_key:
raise ValueError('No API key provided')
raise ValueError("No API key provided")
def validate_credentials(self):
headers = self._prepare_headers()
options = {
'url': 'https://example.com',
'crawlerOptions': {
'excludes': [],
'includes': [],
'limit': 1
},
'pageOptions': {
'onlyMainContent': True
}
"url": "https://example.com",
"crawlerOptions": {"excludes": [], "includes": [], "limit": 1},
"pageOptions": {"onlyMainContent": True},
}
response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
response = self._post_request(f"{self.base_url}/v0/crawl", options, headers)
if response.status_code == 200:
return True
else:
self._handle_error(response)
def _prepare_headers(self):
return {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in [402, 409, 500]:
error_message = response.json().get('error', 'Unknown error occurred')
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
else:
if response.text:
error_message = json.loads(response.text).get('error', 'Unknown error occurred')
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')
error_message = json.loads(response.text).get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

View File

@ -7,58 +7,40 @@ from models.account import TenantAccountJoin, TenantAccountRole
class BillingService:
base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
@classmethod
def get_info(cls, tenant_id: str):
params = {'tenant_id': tenant_id}
params = {"tenant_id": tenant_id}
billing_info = cls._send_request('GET', '/subscription/info', params=params)
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info
@classmethod
def get_subscription(cls, plan: str,
interval: str,
prefilled_email: str = '',
tenant_id: str = ''):
params = {
'plan': plan,
'interval': interval,
'prefilled_email': prefilled_email,
'tenant_id': tenant_id
}
return cls._send_request('GET', '/subscription/payment-link', params=params)
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
return cls._send_request("GET", "/subscription/payment-link", params=params)
@classmethod
def get_model_provider_payment_link(cls,
provider_name: str,
tenant_id: str,
account_id: str,
prefilled_email: str):
def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str):
params = {
'provider_name': provider_name,
'tenant_id': tenant_id,
'account_id': account_id,
'prefilled_email': prefilled_email
"provider_name": provider_name,
"tenant_id": tenant_id,
"account_id": account_id,
"prefilled_email": prefilled_email,
}
return cls._send_request('GET', '/model-provider/payment-link', params=params)
return cls._send_request("GET", "/model-provider/payment-link", params=params)
@classmethod
def get_invoices(cls, prefilled_email: str = '', tenant_id: str = ''):
params = {
'prefilled_email': prefilled_email,
'tenant_id': tenant_id
}
return cls._send_request('GET', '/invoices', params=params)
def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""):
params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id}
return cls._send_request("GET", "/invoices", params=params)
@classmethod
def _send_request(cls, method, endpoint, json=None, params=None):
headers = {
"Content-Type": "application/json",
"Billing-Api-Secret-Key": cls.secret_key
}
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers)
@ -69,10 +51,11 @@ class BillingService:
def is_tenant_owner_or_admin(current_user):
tenant_id = current_user.current_tenant_id
join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.tenant_id == tenant_id,
TenantAccountJoin.account_id == current_user.id
).first()
join = (
db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first()
)
if not TenantAccountRole.is_privileged_role(join.role):
raise ValueError('Only team owner or team admin can perform this action')
raise ValueError("Only team owner or team admin can perform this action")

View File

@ -2,12 +2,15 @@ from extensions.ext_code_based_extension import code_based_extension
class CodeBasedExtensionService:
@staticmethod
def get_code_based_extension(module: str) -> list[dict]:
module_extensions = code_based_extension.module_extensions(module)
return [{
'name': module_extension.name,
'label': module_extension.label,
'form_schema': module_extension.form_schema
} for module_extension in module_extensions if not module_extension.builtin]
return [
{
"name": module_extension.name,
"label": module_extension.label,
"form_schema": module_extension.form_schema,
}
for module_extension in module_extensions
if not module_extension.builtin
]

View File

@ -15,22 +15,27 @@ from services.errors.message import MessageNotExistsError
class ConversationService:
@classmethod
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
last_id: Optional[str], limit: int,
invoke_from: InvokeFrom,
include_ids: Optional[list] = None,
exclude_ids: Optional[list] = None,
sort_by: str = '-updated_at') -> InfiniteScrollPagination:
def pagination_by_last_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
limit: int,
invoke_from: InvokeFrom,
include_ids: Optional[list] = None,
exclude_ids: Optional[list] = None,
sort_by: str = "-updated_at",
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
base_query = db.session.query(Conversation).filter(
Conversation.is_deleted == False,
Conversation.app_id == app_model.id,
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value)
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
)
if include_ids is not None:
@ -58,28 +63,26 @@ class ConversationService:
has_more = False
if len(conversations) == limit:
current_page_last_conversation = conversations[-1]
rest_filter_condition = cls._build_filter_condition(sort_field, sort_direction,
current_page_last_conversation, is_next_page=True)
rest_filter_condition = cls._build_filter_condition(
sort_field, sort_direction, current_page_last_conversation, is_next_page=True
)
rest_count = base_query.filter(rest_filter_condition).count()
if rest_count > 0:
has_more = True
return InfiniteScrollPagination(
data=conversations,
limit=limit,
has_more=has_more
)
return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more)
@classmethod
def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]:
if sort_by.startswith('-'):
if sort_by.startswith("-"):
return sort_by[1:], desc
return sort_by, asc
@classmethod
def _build_filter_condition(cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation,
is_next_page: bool = False):
def _build_filter_condition(
cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False
):
field_value = getattr(reference_conversation, sort_field)
if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
return getattr(Conversation, sort_field) < field_value
@ -87,8 +90,14 @@ class ConversationService:
return getattr(Conversation, sort_field) > field_value
@classmethod
def rename(cls, app_model: App, conversation_id: str,
user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool):
def rename(
cls,
app_model: App,
conversation_id: str,
user: Optional[Union[Account, EndUser]],
name: str,
auto_generate: bool,
):
conversation = cls.get_conversation(app_model, conversation_id, user)
if auto_generate:
@ -103,11 +112,12 @@ class ConversationService:
@classmethod
def auto_generate_name(cls, app_model: App, conversation: Conversation):
# get conversation first message
message = db.session.query(Message) \
.filter(
Message.app_id == app_model.id,
Message.conversation_id == conversation.id
).order_by(Message.created_at.asc()).first()
message = (
db.session.query(Message)
.filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
.order_by(Message.created_at.asc())
.first()
)
if not message:
raise MessageNotExistsError()
@ -127,15 +137,18 @@ class ConversationService:
@classmethod
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
conversation = db.session.query(Conversation) \
conversation = (
db.session.query(Conversation)
.filter(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
Conversation.is_deleted == False
).first()
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
Conversation.is_deleted == False,
)
.first()
)
if not conversation:
raise ConversationNotExistsError()

File diff suppressed because it is too large

View File

@ -4,15 +4,12 @@ import requests
class EnterpriseRequest:
base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL')
secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY')
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
@classmethod
def send_request(cls, method, endpoint, json=None, params=None):
headers = {
"Content-Type": "application/json",
"Enterprise-Api-Secret-Key": cls.secret_key
}
headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers)

View File

@ -2,11 +2,10 @@ from services.enterprise.base import EnterpriseRequest
class EnterpriseService:
@classmethod
def get_info(cls):
return EnterpriseRequest.send_request('GET', '/info')
return EnterpriseRequest.send_request("GET", "/info")
@classmethod
def get_app_web_sso_enabled(cls, app_code):
return EnterpriseRequest.send_request('GET', f'/app-sso-setting?appCode={app_code}')
return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}")

View File

@ -22,14 +22,16 @@ class CustomConfigurationStatus(Enum):
"""
Enum class for custom configuration status.
"""
ACTIVE = 'active'
NO_CONFIGURE = 'no-configure'
ACTIVE = "active"
NO_CONFIGURE = "no-configure"
class CustomConfigurationResponse(BaseModel):
"""
Model class for provider custom configuration response.
"""
status: CustomConfigurationStatus
@ -37,6 +39,7 @@ class SystemConfigurationResponse(BaseModel):
"""
Model class for provider system configuration response.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
@ -46,6 +49,7 @@ class ProviderResponse(BaseModel):
"""
Model class for provider response.
"""
provider: str
label: I18nObject
description: Optional[I18nObject] = None
@ -67,18 +71,15 @@ class ProviderResponse(BaseModel):
def __init__(self, **data) -> None:
super().__init__(**data)
url_prefix = (dify_config.CONSOLE_API_URL
+ f"/console/api/workspaces/current/model-providers/{self.provider}")
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
if self.icon_small is not None:
self.icon_small = I18nObject(
en_US=f"{url_prefix}/icon_small/en_US",
zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
)
if self.icon_large is not None:
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US",
zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
@ -86,6 +87,7 @@ class ProviderWithModelsResponse(BaseModel):
"""
Model class for provider with models response.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
@ -96,18 +98,15 @@ class ProviderWithModelsResponse(BaseModel):
def __init__(self, **data) -> None:
super().__init__(**data)
url_prefix = (dify_config.CONSOLE_API_URL
+ f"/console/api/workspaces/current/model-providers/{self.provider}")
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
if self.icon_small is not None:
self.icon_small = I18nObject(
en_US=f"{url_prefix}/icon_small/en_US",
zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
)
if self.icon_large is not None:
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US",
zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
@ -119,18 +118,15 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
def __init__(self, **data) -> None:
super().__init__(**data)
url_prefix = (dify_config.CONSOLE_API_URL
+ f"/console/api/workspaces/current/model-providers/{self.provider}")
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
if self.icon_small is not None:
self.icon_small = I18nObject(
en_US=f"{url_prefix}/icon_small/en_US",
zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
)
if self.icon_large is not None:
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US",
zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
@ -138,6 +134,7 @@ class DefaultModelResponse(BaseModel):
"""
Default model entity.
"""
model: str
model_type: ModelType
provider: SimpleProviderEntityResponse
@ -150,6 +147,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
"""
Model with provider entity.
"""
provider: SimpleProviderEntityResponse
def __init__(self, model: ModelWithProviderEntity) -> None:

View File

@ -55,4 +55,3 @@ class RoleAlreadyAssignedError(BaseServiceError):
class RateLimitExceededError(BaseServiceError):
pass

View File

@ -1,3 +1,3 @@
class BaseServiceError(Exception):
def __init__(self, description: str = None):
self.description = description
self.description = description

View File

@ -6,8 +6,8 @@ from services.enterprise.enterprise_service import EnterpriseService
class SubscriptionModel(BaseModel):
plan: str = 'sandbox'
interval: str = ''
plan: str = "sandbox"
interval: str = ""
class BillingModel(BaseModel):
@ -27,7 +27,7 @@ class FeatureModel(BaseModel):
vector_space: LimitationModel = LimitationModel(size=0, limit=5)
annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
docs_processing: str = 'standard'
docs_processing: str = "standard"
can_replace_logo: bool = False
model_load_balancing_enabled: bool = False
dataset_operator_enabled: bool = False
@ -38,13 +38,13 @@ class FeatureModel(BaseModel):
class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ''
sso_enforced_for_signin_protocol: str = ""
sso_enforced_for_web: bool = False
sso_enforced_for_web_protocol: str = ''
sso_enforced_for_web_protocol: str = ""
enable_web_sso_switch_component: bool = False
class FeatureService:
class FeatureService:
@classmethod
def get_features(cls, tenant_id: str) -> FeatureModel:
features = FeatureModel()
@ -76,44 +76,44 @@ class FeatureService:
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
features.billing.enabled = billing_info['enabled']
features.billing.subscription.plan = billing_info['subscription']['plan']
features.billing.subscription.interval = billing_info['subscription']['interval']
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]
features.billing.subscription.interval = billing_info["subscription"]["interval"]
if 'members' in billing_info:
features.members.size = billing_info['members']['size']
features.members.limit = billing_info['members']['limit']
if "members" in billing_info:
features.members.size = billing_info["members"]["size"]
features.members.limit = billing_info["members"]["limit"]
if 'apps' in billing_info:
features.apps.size = billing_info['apps']['size']
features.apps.limit = billing_info['apps']['limit']
if "apps" in billing_info:
features.apps.size = billing_info["apps"]["size"]
features.apps.limit = billing_info["apps"]["limit"]
if 'vector_space' in billing_info:
features.vector_space.size = billing_info['vector_space']['size']
features.vector_space.limit = billing_info['vector_space']['limit']
if "vector_space" in billing_info:
features.vector_space.size = billing_info["vector_space"]["size"]
features.vector_space.limit = billing_info["vector_space"]["limit"]
if 'documents_upload_quota' in billing_info:
features.documents_upload_quota.size = billing_info['documents_upload_quota']['size']
features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit']
if "documents_upload_quota" in billing_info:
features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"]
features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"]
if 'annotation_quota_limit' in billing_info:
features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size']
features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit']
if "annotation_quota_limit" in billing_info:
features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"]
features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"]
if 'docs_processing' in billing_info:
features.docs_processing = billing_info['docs_processing']
if "docs_processing" in billing_info:
features.docs_processing = billing_info["docs_processing"]
if 'can_replace_logo' in billing_info:
features.can_replace_logo = billing_info['can_replace_logo']
if "can_replace_logo" in billing_info:
features.can_replace_logo = billing_info["can_replace_logo"]
if 'model_load_balancing_enabled' in billing_info:
features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled']
if "model_load_balancing_enabled" in billing_info:
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
@classmethod
def _fulfill_params_from_enterprise(cls, features):
enterprise_info = EnterpriseService.get_info()
features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']
features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web']
features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol']
features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"]
features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]

View File

@ -17,27 +17,45 @@ from models.account import Account
from models.model import EndUser, UploadFile
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', 'docx', 'csv']
UNSTRUCTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls',
'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml', 'epub']
ALLOWED_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
UNSTRUCTURED_ALLOWED_EXTENSIONS = [
"txt",
"markdown",
"md",
"pdf",
"html",
"htm",
"xlsx",
"xls",
"docx",
"csv",
"eml",
"msg",
"pptx",
"ppt",
"xml",
"epub",
]
PREVIEW_WORDS_LIMIT = 3000
class FileService:
@staticmethod
def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
filename = file.filename
extension = file.filename.split('.')[-1]
extension = file.filename.split(".")[-1]
if len(filename) > 200:
filename = filename.split('.')[0][:200] + '.' + extension
filename = filename.split(".")[0][:200] + "." + extension
etl_type = dify_config.ETL_TYPE
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \
allowed_extensions = (
UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
if etl_type == "Unstructured"
else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
)
if extension.lower() not in allowed_extensions:
raise UnsupportedFileTypeError()
elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
@ -55,7 +73,7 @@ class FileService:
file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024
if file_size > file_size_limit:
message = f'File size exceeded. {file_size} > {file_size_limit}'
message = f"File size exceeded. {file_size} > {file_size_limit}"
raise FileTooLargeError(message)
# user uuid as file name
@ -67,7 +85,7 @@ class FileService:
# end_user
current_tenant_id = user.tenant_id
file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension
file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension
# save file to storage
storage.save(file_key, file_content)
@ -81,11 +99,11 @@ class FileService:
size=file_size,
extension=extension,
mime_type=file.mimetype,
created_by_role=('account' if isinstance(user, Account) else 'end_user'),
created_by_role=("account" if isinstance(user, Account) else "end_user"),
created_by=user.id,
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
used=False,
hash=hashlib.sha3_256(file_content).hexdigest()
hash=hashlib.sha3_256(file_content).hexdigest(),
)
db.session.add(upload_file)
@ -99,10 +117,10 @@ class FileService:
text_name = text_name[:200]
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt'
file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"
# save file to storage
storage.save(file_key, text.encode('utf-8'))
storage.save(file_key, text.encode("utf-8"))
# save file to db
upload_file = UploadFile(
@ -111,13 +129,13 @@ class FileService:
key=file_key,
name=text_name,
size=len(text),
extension='txt',
mime_type='text/plain',
extension="txt",
mime_type="text/plain",
created_by=current_user.id,
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
used=True,
used_by=current_user.id,
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
)
db.session.add(upload_file)
@ -127,9 +145,7 @@ class FileService:
@staticmethod
def get_file_preview(file_id: str) -> str:
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found")
@ -137,12 +153,12 @@ class FileService:
# extract text from file
extension = upload_file.extension
etl_type = dify_config.ETL_TYPE
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
if extension.lower() not in allowed_extensions:
raise UnsupportedFileTypeError()
text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
text = text[0:PREVIEW_WORDS_LIMIT] if text else ""
return text
@@ -152,9 +168,7 @@ class FileService:
if not result:
raise NotFound("File not found or signature is invalid")
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
@@ -170,9 +184,7 @@ class FileService:
@staticmethod
def get_public_image_preview(file_id: str) -> tuple[Generator, str]:
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
View File
@@ -9,14 +9,11 @@ from models.account import Account
from models.dataset import Dataset, DatasetQuery, DocumentSegment
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
@@ -27,9 +24,9 @@ class HitTestingService:
return {
"query": {
"content": query,
"tsne_position": {'x': 0, 'y': 0},
"tsne_position": {"x": 0, "y": 0},
},
"records": []
"records": [],
}
start = time.perf_counter()
@@ -38,28 +35,28 @@ class HitTestingService:
if not retrieval_model:
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
dataset_id=dataset.id,
query=cls.escape_query_for_search(query),
top_k=retrieval_model.get('top_k', 2),
score_threshold=retrieval_model.get('score_threshold', .0)
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model.get('reranking_model', None)
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model',
weights=retrieval_model.get('weights', None),
)
all_documents = RetrievalService.retrieve(
retrival_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=cls.escape_query_for_search(query),
top_k=retrieval_model.get("top_k", 2),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else None,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode")
if retrieval_model.get("reranking_mode")
else "reranking_model",
weights=retrieval_model.get("weights", None),
)
end = time.perf_counter()
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=query,
source='hit_testing',
created_by_role='account',
created_by=account.id
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
)
db.session.add(dataset_query)
@@ -72,14 +69,18 @@ class HitTestingService:
i = 0
records = []
for document in documents:
index_node_id = document.metadata['doc_id']
index_node_id = document.metadata["doc_id"]
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.enabled == True,
DocumentSegment.status == 'completed',
DocumentSegment.index_node_id == index_node_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)
if not segment:
i += 1
@@ -87,7 +88,7 @@ class HitTestingService:
record = {
"segment": segment,
"score": document.metadata.get('score', None),
"score": document.metadata.get("score", None),
}
records.append(record)
@@ -98,15 +99,15 @@ class HitTestingService:
"query": {
"content": query,
},
"records": records
"records": records,
}
@classmethod
def hit_testing_args_check(cls, args):
query = args['query']
query = args["query"]
if not query or len(query) > 250:
raise ValueError('Query is required and cannot exceed 250 characters')
raise ValueError("Query is required and cannot exceed 250 characters")
@staticmethod
def escape_query_for_search(query: str) -> str:
View File
@@ -27,8 +27,14 @@ from services.workflow_service import WorkflowService
class MessageService:
@classmethod
def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination:
def pagination_by_first_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
conversation_id: str,
first_id: Optional[str],
limit: int,
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
@@ -36,52 +42,69 @@ class MessageService:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
conversation = ConversationService.get_conversation(
app_model=app_model,
user=user,
conversation_id=conversation_id
app_model=app_model, user=user, conversation_id=conversation_id
)
if first_id:
first_message = db.session.query(Message) \
.filter(Message.conversation_id == conversation.id, Message.id == first_id).first()
first_message = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id, Message.id == first_id)
.first()
)
if not first_message:
raise FirstMessageNotExistsError()
history_messages = db.session.query(Message).filter(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id
) \
.order_by(Message.created_at.desc()).limit(limit).all()
history_messages = (
db.session.query(Message)
.filter(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
.limit(limit)
.all()
)
else:
history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
.order_by(Message.created_at.desc()).limit(limit).all()
history_messages = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(limit)
.all()
)
has_more = False
if len(history_messages) == limit:
current_page_first_message = history_messages[-1]
rest_count = db.session.query(Message).filter(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id
).count()
rest_count = (
db.session.query(Message)
.filter(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
.count()
)
if rest_count > 0:
has_more = True
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(
data=history_messages,
limit=limit,
has_more=has_more
)
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
@classmethod
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
last_id: Optional[str], limit: int, conversation_id: Optional[str] = None,
include_ids: Optional[list] = None) -> InfiniteScrollPagination:
def pagination_by_last_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
limit: int,
conversation_id: Optional[str] = None,
include_ids: Optional[list] = None,
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
@@ -89,9 +112,7 @@ class MessageService:
if conversation_id is not None:
conversation = ConversationService.get_conversation(
app_model=app_model,
user=user,
conversation_id=conversation_id
app_model=app_model, user=user, conversation_id=conversation_id
)
base_query = base_query.filter(Message.conversation_id == conversation.id)
@@ -105,10 +126,12 @@ class MessageService:
if not last_message:
raise LastMessageNotExistsError()
history_messages = base_query.filter(
Message.created_at < last_message.created_at,
Message.id != last_message.id
).order_by(Message.created_at.desc()).limit(limit).all()
history_messages = (
base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id)
.order_by(Message.created_at.desc())
.limit(limit)
.all()
)
else:
history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all()
@@ -116,30 +139,22 @@ class MessageService:
if len(history_messages) == limit:
current_page_first_message = history_messages[-1]
rest_count = base_query.filter(
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id
Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id
).count()
if rest_count > 0:
has_more = True
return InfiniteScrollPagination(
data=history_messages,
limit=limit,
has_more=has_more
)
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
@classmethod
def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]],
rating: Optional[str]) -> MessageFeedback:
def create_feedback(
cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str]
) -> MessageFeedback:
if not user:
raise ValueError('user cannot be None')
raise ValueError("user cannot be None")
message = cls.get_message(
app_model=app_model,
user=user,
message_id=message_id
)
message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback
@@ -148,14 +163,14 @@ class MessageService:
elif rating and feedback:
feedback.rating = rating
elif not rating and not feedback:
raise ValueError('rating cannot be None when feedback not exists')
raise ValueError("rating cannot be None when feedback not exists")
else:
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating,
from_source=('user' if isinstance(user, EndUser) else 'admin'),
from_source=("user" if isinstance(user, EndUser) else "admin"),
from_end_user_id=(user.id if isinstance(user, EndUser) else None),
from_account_id=(user.id if isinstance(user, Account) else None),
)
@@ -167,13 +182,17 @@ class MessageService:
@classmethod
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
).first()
message = (
db.session.query(Message)
.filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
)
if not message:
raise MessageNotExistsError()
@@ -181,27 +200,22 @@ class MessageService:
return message
@classmethod
def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]],
message_id: str, invoke_from: InvokeFrom) -> list[Message]:
def get_suggested_questions_after_answer(
cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom
) -> list[Message]:
if not user:
raise ValueError('user cannot be None')
raise ValueError("user cannot be None")
message = cls.get_message(
app_model=app_model,
user=user,
message_id=message_id
)
message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
conversation = ConversationService.get_conversation(
app_model=app_model,
conversation_id=message.conversation_id,
user=user
app_model=app_model, conversation_id=message.conversation_id, user=user
)
if not conversation:
raise ConversationNotExistsError()
if conversation.status != 'normal':
if conversation.status != "normal":
raise ConversationCompletedError()
model_manager = ModelManager()
@@ -216,24 +230,23 @@ class MessageService:
if workflow is None:
return []
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
if not app_config.additional_features.suggested_questions_after_answer:
raise SuggestedQuestionsAfterAnswerDisabledError()
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id,
model_type=ModelType.LLM
tenant_id=app_model.tenant_id, model_type=ModelType.LLM
)
else:
if not conversation.override_model_configs:
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
app_model_config = (
db.session.query(AppModelConfig)
.filter(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
.first()
)
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
app_model_config = AppModelConfig(
@@ -249,16 +262,13 @@ class MessageService:
model_instance = model_manager.get_model_instance(
tenant_id=app_model.tenant_id,
provider=app_model_config.model_dict['provider'],
provider=app_model_config.model_dict["provider"],
model_type=ModelType.LLM,
model=app_model_config.model_dict['name']
model=app_model_config.model_dict["name"],
)
# get memory of conversation (read-only)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
histories = memory.get_history_prompt_text(
max_token_limit=3000,
@@ -267,18 +277,14 @@ class MessageService:
with measure_time() as timer:
questions = LLMGenerator.generate_suggested_questions_after_answer(
tenant_id=app_model.tenant_id,
histories=histories
tenant_id=app_model.tenant_id, histories=histories
)
# get tracing instance
trace_manager = TraceQueueManager(app_id=app_model.id)
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.SUGGESTED_QUESTION_TRACE,
message_id=message_id,
suggested_question=questions,
timer=timer
TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer
)
)
View File
@@ -23,7 +23,6 @@ logger = logging.getLogger(__name__)
class ModelLoadBalancingService:
def __init__(self) -> None:
self.provider_manager = ProviderManager()
@@ -46,10 +45,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Enable model load balancing
provider_configuration.enable_model_load_balancing(
model=model,
model_type=ModelType.value_of(model_type)
)
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
"""
@@ -70,13 +66,11 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# disable model load balancing
provider_configuration.disable_model_load_balancing(
model=model,
model_type=ModelType.value_of(model_type)
)
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \
-> tuple[bool, list[dict]]:
def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str
) -> tuple[bool, list[dict]]:
"""
Get load balancing configurations.
:param tenant_id: workspace id
@@ -107,20 +101,24 @@ class ModelLoadBalancingService:
is_load_balancing_enabled = True
# Get load balancing configurations
load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model
).order_by(LoadBalancingModelConfig.created_at).all()
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.order_by(LoadBalancingModelConfig.created_at)
.all()
)
if provider_configuration.custom_configuration.provider:
# check if the inherit configuration exists,
# inherit is represented for the provider or model custom credentials
inherit_config_exists = False
for load_balancing_config in load_balancing_configs:
if load_balancing_config.name == '__inherit__':
if load_balancing_config.name == "__inherit__":
inherit_config_exists = True
break
@@ -133,7 +131,7 @@ class ModelLoadBalancingService:
else:
# move the inherit configuration to the first
for i, load_balancing_config in enumerate(load_balancing_configs[:]):
if load_balancing_config.name == '__inherit__':
if load_balancing_config.name == "__inherit__":
inherit_config = load_balancing_configs.pop(i)
load_balancing_configs.insert(0, inherit_config)
@@ -151,7 +149,7 @@ class ModelLoadBalancingService:
provider=provider,
model=model,
model_type=model_type,
config_id=load_balancing_config.id
config_id=load_balancing_config.id,
)
try:
@@ -172,32 +170,32 @@ class ModelLoadBalancingService:
if variable in credentials:
try:
credentials[variable] = encrypter.decrypt_token_with_decoding(
credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
)
except ValueError:
pass
# Obfuscate credentials
credentials = provider_configuration.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=credential_schemas.credential_form_schemas
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
)
datas.append({
'id': load_balancing_config.id,
'name': load_balancing_config.name,
'credentials': credentials,
'enabled': load_balancing_config.enabled,
'in_cooldown': in_cooldown,
'ttl': ttl
})
datas.append(
{
"id": load_balancing_config.id,
"name": load_balancing_config.name,
"credentials": credentials,
"enabled": load_balancing_config.enabled,
"in_cooldown": in_cooldown,
"ttl": ttl,
}
)
return is_load_balancing_enabled, datas
def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \
-> Optional[dict]:
def get_load_balancing_config(
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
) -> Optional[dict]:
"""
Get load balancing configuration.
:param tenant_id: workspace id
@@ -219,14 +217,17 @@ class ModelLoadBalancingService:
model_type = ModelType.value_of(model_type)
# Get load balancing configurations
load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id
).first()
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
)
if not load_balancing_model_config:
return None
@@ -244,19 +245,19 @@ class ModelLoadBalancingService:
# Obfuscate credentials
credentials = provider_configuration.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=credential_schemas.credential_form_schemas
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
)
return {
'id': load_balancing_model_config.id,
'name': load_balancing_model_config.name,
'credentials': credentials,
'enabled': load_balancing_model_config.enabled
"id": load_balancing_model_config.id,
"name": load_balancing_model_config.name,
"credentials": credentials,
"enabled": load_balancing_model_config.enabled,
}
def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \
-> LoadBalancingModelConfig:
def _init_inherit_config(
self, tenant_id: str, provider: str, model: str, model_type: ModelType
) -> LoadBalancingModelConfig:
"""
Initialize the inherit configuration.
:param tenant_id: workspace id
@@ -271,18 +272,16 @@ class ModelLoadBalancingService:
provider_name=provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
name='__inherit__'
name="__inherit__",
)
db.session.add(inherit_config)
db.session.commit()
return inherit_config
def update_load_balancing_configs(self, tenant_id: str,
provider: str,
model: str,
model_type: str,
configs: list[dict]) -> None:
def update_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
) -> None:
"""
Update load balancing configurations.
:param tenant_id: workspace id
@@ -304,15 +303,18 @@ class ModelLoadBalancingService:
model_type = ModelType.value_of(model_type)
if not isinstance(configs, list):
raise ValueError('Invalid load balancing configs')
raise ValueError("Invalid load balancing configs")
current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
current_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model
).all()
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.all()
)
# id as key, config as value
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
@@ -320,25 +322,25 @@ class ModelLoadBalancingService:
for config in configs:
if not isinstance(config, dict):
raise ValueError('Invalid load balancing config')
raise ValueError("Invalid load balancing config")
config_id = config.get('id')
name = config.get('name')
credentials = config.get('credentials')
enabled = config.get('enabled')
config_id = config.get("id")
name = config.get("name")
credentials = config.get("credentials")
enabled = config.get("enabled")
if not name:
raise ValueError('Invalid load balancing config name')
raise ValueError("Invalid load balancing config name")
if enabled is None:
raise ValueError('Invalid load balancing config enabled')
raise ValueError("Invalid load balancing config enabled")
# is config exists
if config_id:
config_id = str(config_id)
if config_id not in current_load_balancing_configs_dict:
raise ValueError('Invalid load balancing config id: {}'.format(config_id))
raise ValueError("Invalid load balancing config id: {}".format(config_id))
updated_config_ids.add(config_id)
@@ -347,11 +349,11 @@ class ModelLoadBalancingService:
# check duplicate name
for current_load_balancing_config in current_load_balancing_configs:
if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
raise ValueError('Load balancing config name {} already exists'.format(name))
raise ValueError("Load balancing config name {} already exists".format(name))
if credentials:
if not isinstance(credentials, dict):
raise ValueError('Invalid load balancing config credentials')
raise ValueError("Invalid load balancing config credentials")
# validate custom provider config
credentials = self._custom_credentials_validate(
@ -361,7 +363,7 @@ class ModelLoadBalancingService:
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_config,
validate=False
validate=False,
)
# update load balancing config
@@ -375,19 +377,19 @@ class ModelLoadBalancingService:
self._clear_credentials_cache(tenant_id, config_id)
else:
# create load balancing config
if name == '__inherit__':
raise ValueError('Invalid load balancing config name')
if name == "__inherit__":
raise ValueError("Invalid load balancing config name")
# check duplicate name
for current_load_balancing_config in current_load_balancing_configs:
if current_load_balancing_config.name == name:
raise ValueError('Load balancing config name {} already exists'.format(name))
raise ValueError("Load balancing config name {} already exists".format(name))
if not credentials:
raise ValueError('Invalid load balancing config credentials')
raise ValueError("Invalid load balancing config credentials")
if not isinstance(credentials, dict):
raise ValueError('Invalid load balancing config credentials')
raise ValueError("Invalid load balancing config credentials")
# validate custom provider config
credentials = self._custom_credentials_validate(
@@ -396,7 +398,7 @@ class ModelLoadBalancingService:
model_type=model_type,
model=model,
credentials=credentials,
validate=False
validate=False,
)
# create load balancing config
@@ -406,7 +408,7 @@ class ModelLoadBalancingService:
model_type=model_type.to_origin_model_type(),
model_name=model,
name=name,
encrypted_config=json.dumps(credentials)
encrypted_config=json.dumps(credentials),
)
db.session.add(load_balancing_model_config)
@@ -420,12 +422,15 @@ class ModelLoadBalancingService:
self._clear_credentials_cache(tenant_id, config_id)
def validate_load_balancing_credentials(self, tenant_id: str,
provider: str,
model: str,
model_type: str,
credentials: dict,
config_id: Optional[str] = None) -> None:
def validate_load_balancing_credentials(
self,
tenant_id: str,
provider: str,
model: str,
model_type: str,
credentials: dict,
config_id: Optional[str] = None,
) -> None:
"""
Validate load balancing credentials.
:param tenant_id: workspace id
@@ -450,14 +455,17 @@ class ModelLoadBalancingService:
load_balancing_model_config = None
if config_id:
# Get load balancing config
load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id
).first()
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
)
if not load_balancing_model_config:
raise ValueError(f"Load balancing config {config_id} does not exist.")
@@ -469,16 +477,19 @@ class ModelLoadBalancingService:
model_type=model_type,
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_model_config
load_balancing_model_config=load_balancing_model_config,
)
def _custom_credentials_validate(self, tenant_id: str,
provider_configuration: ProviderConfiguration,
model_type: ModelType,
model: str,
credentials: dict,
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
validate: bool = True) -> dict:
def _custom_credentials_validate(
self,
tenant_id: str,
provider_configuration: ProviderConfiguration,
model_type: ModelType,
model: str,
credentials: dict,
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
validate: bool = True,
) -> dict:
"""
Validate custom credentials.
:param tenant_id: workspace id
@@ -521,12 +532,11 @@ class ModelLoadBalancingService:
provider=provider_configuration.provider.provider,
model_type=model_type,
model=model,
credentials=credentials
credentials=credentials,
)
else:
credentials = model_provider_factory.provider_credentials_validate(
provider=provider_configuration.provider.provider,
credentials=credentials
provider=provider_configuration.provider.provider, credentials=credentials
)
for key, value in credentials.items():
@@ -535,8 +545,9 @@ class ModelLoadBalancingService:
return credentials
def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \
-> ModelCredentialSchema | ProviderCredentialSchema:
def _get_credential_schema(
self, provider_configuration: ProviderConfiguration
) -> ModelCredentialSchema | ProviderCredentialSchema:
"""
Get form schemas.
:param provider_configuration: provider configuration
@ -558,9 +569,7 @@ class ModelLoadBalancingService:
:return:
"""
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=config_id,
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
)
provider_model_credentials_cache.delete()
View File
@@ -73,8 +73,8 @@ class ModelProviderService:
system_configuration=SystemConfigurationResponse(
enabled=provider_configuration.system_configuration.enabled,
current_quota_type=provider_configuration.system_configuration.current_quota_type,
quota_configurations=provider_configuration.system_configuration.quota_configurations
)
quota_configurations=provider_configuration.system_configuration.quota_configurations,
),
)
provider_responses.append(provider_response)
@@ -95,9 +95,9 @@ class ModelProviderService:
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider available models
return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(
provider=provider
)]
return [
ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
]
def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
"""
@@ -195,13 +195,12 @@ class ModelProviderService:
# Get model custom credentials from ProviderModel if exists
return provider_configuration.get_custom_model_credentials(
model_type=ModelType.value_of(model_type),
model=model,
obfuscated=True
model_type=ModelType.value_of(model_type), model=model, obfuscated=True
)
def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str,
credentials: dict) -> None:
def model_credentials_validate(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
) -> None:
"""
validate model credentials.
@@ -222,13 +221,12 @@ class ModelProviderService:
# Validate model credentials
provider_configuration.custom_model_credentials_validate(
model_type=ModelType.value_of(model_type),
model=model,
credentials=credentials
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
)
def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str,
credentials: dict) -> None:
def save_model_credentials(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
) -> None:
"""
save model credentials.
@@ -249,9 +247,7 @@ class ModelProviderService:
# Add or update custom model credentials
provider_configuration.add_or_update_custom_model_credentials(
model_type=ModelType.value_of(model_type),
model=model,
credentials=credentials
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
)
def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
@@ -273,10 +269,7 @@ class ModelProviderService:
raise ValueError(f"Provider {provider} does not exist.")
# Remove custom model credentials
provider_configuration.delete_custom_model_credentials(
model_type=ModelType.value_of(model_type),
model=model
)
provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model)
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
"""
@@ -290,9 +283,7 @@ class ModelProviderService:
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider available models
models = provider_configurations.get_models(
model_type=ModelType.value_of(model_type)
)
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
# Group models by provider
provider_models = {}
@@ -323,16 +314,19 @@ class ModelProviderService:
icon_small=first_model.provider.icon_small,
icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE,
models=[ProviderModelWithStatusEntity(
model=model.model,
label=model.label,
model_type=model.model_type,
features=model.features,
fetch_from=model.fetch_from,
model_properties=model.model_properties,
status=model.status,
load_balancing_enabled=model.load_balancing_enabled
) for model in models]
models=[
ProviderModelWithStatusEntity(
model=model.model,
label=model.label,
model_type=model.model_type,
features=model.features,
fetch_from=model.fetch_from,
model_properties=model.model_properties,
status=model.status,
load_balancing_enabled=model.load_balancing_enabled,
)
for model in models
],
)
)
@@ -361,19 +355,13 @@ class ModelProviderService:
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# fetch credentials
credentials = provider_configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model
)
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
if not credentials:
return []
# Call get_parameter_rules method of model instance to get model parameter rules
return model_type_instance.get_parameter_rules(
model=model,
credentials=credentials
)
return model_type_instance.get_parameter_rules(model=model, credentials=credentials)
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
"""
@@ -384,22 +372,23 @@ class ModelProviderService:
:return:
"""
model_type_enum = ModelType.value_of(model_type)
result = self.provider_manager.get_default_model(
tenant_id=tenant_id,
model_type=model_type_enum
)
result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
try:
return DefaultModelResponse(
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types
return (
DefaultModelResponse(
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types,
),
)
) if result else None
if result
else None
)
except Exception as e:
logger.info(f"get_default_model_of_model_type error: {e}")
return None
@@ -416,13 +405,12 @@ class ModelProviderService:
"""
model_type_enum = ModelType.value_of(model_type)
self.provider_manager.update_default_model_record(
tenant_id=tenant_id,
model_type=model_type_enum,
provider=provider,
model=model
tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model
)
def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[Optional[bytes], Optional[str]]:
def get_model_provider_icon(
self, provider: str, icon_type: str, lang: str
) -> tuple[Optional[bytes], Optional[str]]:
"""
get model provider icon.
@@ -434,11 +422,11 @@ class ModelProviderService:
provider_instance = model_provider_factory.get_provider_instance(provider)
provider_schema = provider_instance.get_provider_schema()
if icon_type.lower() == 'icon_small':
if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
if lang.lower() == 'zh_hans':
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_small.zh_Hans
else:
file_name = provider_schema.icon_small.en_US
@@ -446,13 +434,15 @@ class ModelProviderService:
if not provider_schema.icon_large:
raise ValueError(f"Provider {provider} does not have large icon.")
if lang.lower() == 'zh_hans':
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_large.zh_Hans
else:
file_name = provider_schema.icon_large.en_US
root_path = current_app.root_path
provider_instance_path = os.path.dirname(os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/')))
provider_instance_path = os.path.dirname(
os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/"))
)
file_path = os.path.join(provider_instance_path, "_assets")
file_path = os.path.join(file_path, file_name)
@@ -460,10 +450,10 @@ class ModelProviderService:
return None, None
mimetype, _ = mimetypes.guess_type(file_path)
mimetype = mimetype or 'application/octet-stream'
mimetype = mimetype or "application/octet-stream"
# read binary from file
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
byte_data = f.read()
return byte_data, mimetype
@@ -509,10 +499,7 @@ class ModelProviderService:
raise ValueError(f"Provider {provider} does not exist.")
# Enable model
provider_configuration.enable_model(
model=model,
model_type=ModelType.value_of(model_type)
)
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
"""
@@ -533,78 +520,49 @@ class ModelProviderService:
raise ValueError(f"Provider {provider} does not exist.")
# Enable model
provider_configuration.disable_model(
model=model,
model_type=ModelType.value_of(model_type)
)
provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
def free_quota_submit(self, tenant_id: str, provider: str):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
api_url = api_base_url + '/api/v1/providers/apply'
api_url = api_base_url + "/api/v1/providers/apply"
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {api_key}"
}
response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider})
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider})
if not response.ok:
logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
raise ValueError(f"Error: {response.status_code} ")
if response.json()["code"] != 'success':
raise ValueError(
f"error: {response.json()['message']}"
)
if response.json()["code"] != "success":
raise ValueError(f"error: {response.json()['message']}")
rst = response.json()
if rst['type'] == 'redirect':
return {
'type': rst['type'],
'redirect_url': rst['redirect_url']
}
if rst["type"] == "redirect":
return {"type": rst["type"], "redirect_url": rst["redirect_url"]}
else:
return {
'type': rst['type'],
'result': 'success'
}
return {"type": rst["type"], "result": "success"}
def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
api_url = api_base_url + '/api/v1/providers/qualification-verify'
api_url = api_base_url + "/api/v1/providers/qualification-verify"
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {api_key}"
}
json_data = {'workspace_id': tenant_id, 'provider_name': provider}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
json_data = {"workspace_id": tenant_id, "provider_name": provider}
if token:
json_data['token'] = token
response = requests.post(api_url, headers=headers,
json=json_data)
json_data["token"] = token
response = requests.post(api_url, headers=headers, json=json_data)
if not response.ok:
logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
raise ValueError(f"Error: {response.status_code} ")
rst = response.json()
if rst["code"] != 'success':
raise ValueError(
f"error: {rst['message']}"
)
if rst["code"] != "success":
raise ValueError(f"error: {rst['message']}")
data = rst['data']
if data['qualified'] is True:
return {
'result': 'success',
'provider_name': provider,
'flag': True
}
data = rst["data"]
if data["qualified"] is True:
return {"result": "success", "provider_name": provider, "flag": True}
else:
return {
'result': 'success',
'provider_name': provider,
'flag': False,
'reason': data['reason']
}
return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]}
View File
@@ -4,17 +4,18 @@ from models.model import App, AppModelConfig
class ModerationService:
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
app_model_config: AppModelConfig = None
app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
)
if not app_model_config:
raise ValueError("app model config not found")
name = app_model_config.sensitive_word_avoidance_dict['type']
config = app_model_config.sensitive_word_avoidance_dict['config']
name = app_model_config.sensitive_word_avoidance_dict["type"]
config = app_model_config.sensitive_word_avoidance_dict["config"]
moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
return moderation.moderation_for_outputs(text)

View File

@@ -4,15 +4,12 @@ import requests
class OperationService:
base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
@classmethod
def _send_request(cls, method, endpoint, json=None, params=None):
headers = {
"Content-Type": "application/json",
"Billing-Api-Secret-Key": cls.secret_key
}
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers)
@@ -22,11 +19,11 @@ class OperationService:
@classmethod
def record_utm(cls, tenant_id: str, utm_info: dict):
params = {
'tenant_id': tenant_id,
'utm_source': utm_info.get('utm_source', ''),
'utm_medium': utm_info.get('utm_medium', ''),
'utm_campaign': utm_info.get('utm_campaign', ''),
'utm_content': utm_info.get('utm_content', ''),
'utm_term': utm_info.get('utm_term', '')
"tenant_id": tenant_id,
"utm_source": utm_info.get("utm_source", ""),
"utm_medium": utm_info.get("utm_medium", ""),
"utm_campaign": utm_info.get("utm_campaign", ""),
"utm_content": utm_info.get("utm_content", ""),
"utm_term": utm_info.get("utm_term", ""),
}
return cls._send_request('POST', '/tenant_utms', params=params)
return cls._send_request("POST", "/tenant_utms", params=params)
View File
@@ -12,19 +12,25 @@ class OpsService:
:param tracing_provider: tracing provider
:return:
"""
trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
).first()
trace_config_data: TraceAppConfig = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
if not trace_config_data:
return None
# decrypt_token and obfuscated_token
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config)
if tracing_provider == 'langfuse' and ('project_key' not in decrypt_tracing_config or not decrypt_tracing_config.get('project_key')):
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config
)
if tracing_provider == "langfuse" and (
"project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key")
):
project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider)
decrypt_tracing_config['project_key'] = project_key
decrypt_tracing_config["project_key"] = project_key
decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config)
@@ -44,8 +50,10 @@ class OpsService:
if tracing_provider not in provider_config_map.keys() and tracing_provider:
return {"error": f"Invalid tracing provider: {tracing_provider}"}
config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['other_keys']
config_class, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["other_keys"],
)
default_config_instance = config_class(**tracing_config)
for key in other_keys:
if key in tracing_config and tracing_config[key] == "":
@@ -59,9 +67,11 @@ class OpsService:
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
# check if trace config already exists
trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
).first()
trace_config_data: TraceAppConfig = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
if trace_config_data:
return None
@@ -69,8 +79,8 @@ class OpsService:
# get tenant id
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
if tracing_provider == 'langfuse':
tracing_config['project_key'] = project_key
if tracing_provider == "langfuse":
tracing_config["project_key"] = project_key
trace_config_data = TraceAppConfig(
app_id=app_id,
tracing_provider=tracing_provider,
@@ -94,9 +104,11 @@ class OpsService:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
# check if trace config already exists
current_trace_config = db.session.query(TraceAppConfig).filter(
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
).first()
current_trace_config = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
if not current_trace_config:
return None
@@ -126,9 +138,11 @@ class OpsService:
:param tracing_provider: tracing provider
:return:
"""
trace_config = db.session.query(TraceAppConfig).filter(
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
).first()
trace_config = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
if not trace_config:
return None
View File
@@ -16,7 +16,6 @@ logger = logging.getLogger(__name__)
class RecommendedAppService:
builtin_data: Optional[dict] = None
@classmethod
@@ -27,21 +26,21 @@ class RecommendedAppService:
:return:
"""
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
if mode == 'remote':
if mode == "remote":
try:
result = cls._fetch_recommended_apps_from_dify_official(language)
except Exception as e:
logger.warning(f'fetch recommended apps from dify official failed: {e}, switch to built-in.')
logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.")
result = cls._fetch_recommended_apps_from_builtin(language)
elif mode == 'db':
elif mode == "db":
result = cls._fetch_recommended_apps_from_db(language)
elif mode == 'builtin':
elif mode == "builtin":
result = cls._fetch_recommended_apps_from_builtin(language)
else:
raise ValueError(f'invalid fetch recommended apps mode: {mode}')
raise ValueError(f"invalid fetch recommended apps mode: {mode}")
if not result.get('recommended_apps') and language != 'en-US':
result = cls._fetch_recommended_apps_from_builtin('en-US')
if not result.get("recommended_apps") and language != "en-US":
result = cls._fetch_recommended_apps_from_builtin("en-US")
return result
@@ -52,16 +51,18 @@ class RecommendedAppService:
:param language: language
:return:
"""
recommended_apps = db.session.query(RecommendedApp).filter(
RecommendedApp.is_listed == True,
RecommendedApp.language == language
).all()
recommended_apps = (
db.session.query(RecommendedApp)
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == language)
.all()
)
if len(recommended_apps) == 0:
recommended_apps = db.session.query(RecommendedApp).filter(
RecommendedApp.is_listed == True,
RecommendedApp.language == languages[0]
).all()
recommended_apps = (
db.session.query(RecommendedApp)
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
.all()
)
categories = set()
recommended_apps_result = []
@@ -75,28 +76,28 @@ class RecommendedAppService:
continue
recommended_app_result = {
'id': recommended_app.id,
'app': {
'id': app.id,
'name': app.name,
'mode': app.mode,
'icon': app.icon,
'icon_background': app.icon_background
"id": recommended_app.id,
"app": {
"id": app.id,
"name": app.name,
"mode": app.mode,
"icon": app.icon,
"icon_background": app.icon_background,
},
'app_id': recommended_app.app_id,
'description': site.description,
'copyright': site.copyright,
'privacy_policy': site.privacy_policy,
'custom_disclaimer': site.custom_disclaimer,
'category': recommended_app.category,
'position': recommended_app.position,
'is_listed': recommended_app.is_listed
"app_id": recommended_app.app_id,
"description": site.description,
"copyright": site.copyright,
"privacy_policy": site.privacy_policy,
"custom_disclaimer": site.custom_disclaimer,
"category": recommended_app.category,
"position": recommended_app.position,
"is_listed": recommended_app.is_listed,
}
recommended_apps_result.append(recommended_app_result)
categories.add(recommended_app.category) # add category to categories
return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)}
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
@classmethod
def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
@@ -106,16 +107,16 @@ class RecommendedAppService:
:return:
"""
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
url = f'{domain}/apps?language={language}'
url = f"{domain}/apps?language={language}"
response = requests.get(url, timeout=(3, 10))
if response.status_code != 200:
raise ValueError(f'fetch recommended apps failed, status code: {response.status_code}')
raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
result = response.json()
if "categories" in result:
result["categories"] = sorted(result["categories"])
return result
@classmethod
@@ -126,7 +127,7 @@ class RecommendedAppService:
:return:
"""
builtin_data = cls._get_builtin_data()
return builtin_data.get('recommended_apps', {}).get(language)
return builtin_data.get("recommended_apps", {}).get(language)
@classmethod
def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
@@ -136,18 +137,18 @@ class RecommendedAppService:
:return:
"""
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
if mode == 'remote':
if mode == "remote":
try:
result = cls._fetch_recommended_app_detail_from_dify_official(app_id)
except Exception as e:
logger.warning(f'fetch recommended app detail from dify official failed: {e}, switch to built-in.')
logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.")
result = cls._fetch_recommended_app_detail_from_builtin(app_id)
elif mode == 'db':
elif mode == "db":
result = cls._fetch_recommended_app_detail_from_db(app_id)
elif mode == 'builtin':
elif mode == "builtin":
result = cls._fetch_recommended_app_detail_from_builtin(app_id)
else:
raise ValueError(f'invalid fetch recommended app detail mode: {mode}')
raise ValueError(f"invalid fetch recommended app detail mode: {mode}")
return result
@@ -159,7 +160,7 @@ class RecommendedAppService:
:return:
"""
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
url = f'{domain}/apps/{app_id}'
url = f"{domain}/apps/{app_id}"
response = requests.get(url, timeout=(3, 10))
if response.status_code != 200:
return None
@@ -174,10 +175,11 @@ class RecommendedAppService:
:return:
"""
# is in public recommended list
recommended_app = db.session.query(RecommendedApp).filter(
RecommendedApp.is_listed == True,
RecommendedApp.app_id == app_id
).first()
recommended_app = (
db.session.query(RecommendedApp)
.filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
.first()
)
if not recommended_app:
return None
@@ -188,12 +190,12 @@ class RecommendedAppService:
return None
return {
'id': app_model.id,
'name': app_model.name,
'icon': app_model.icon,
'icon_background': app_model.icon_background,
'mode': app_model.mode,
'export_data': AppDslService.export_dsl(app_model=app_model)
"id": app_model.id,
"name": app_model.name,
"icon": app_model.icon,
"icon_background": app_model.icon_background,
"mode": app_model.mode,
"export_data": AppDslService.export_dsl(app_model=app_model),
}
@classmethod
@@ -204,7 +206,7 @@ class RecommendedAppService:
:return:
"""
builtin_data = cls._get_builtin_data()
return builtin_data.get('app_details', {}).get(app_id)
return builtin_data.get("app_details", {}).get(app_id)
@classmethod
def _get_builtin_data(cls) -> dict:
@@ -216,7 +218,7 @@ class RecommendedAppService:
return cls.builtin_data
root_path = current_app.root_path
with open(path.join(root_path, 'constants', 'recommended_apps.json'), encoding='utf-8') as f:
with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f:
json_data = f.read()
data = json.loads(json_data)
cls.builtin_data = data
@@ -229,27 +231,24 @@ class RecommendedAppService:
Fetch all recommended apps and export datas
:return:
"""
templates = {
"recommended_apps": {},
"app_details": {}
}
templates = {"recommended_apps": {}, "app_details": {}}
for language in languages:
try:
result = cls._fetch_recommended_apps_from_dify_official(language)
except Exception as e:
logger.warning(f'fetch recommended apps from dify official failed: {e}, skip.')
logger.warning(f"fetch recommended apps from dify official failed: {e}, skip.")
continue
templates['recommended_apps'][language] = result
templates["recommended_apps"][language] = result
for recommended_app in result.get('recommended_apps'):
app_id = recommended_app.get('app_id')
for recommended_app in result.get("recommended_apps"):
app_id = recommended_app.get("app_id")
# get app detail
app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id)
if not app_detail:
continue
templates['app_details'][app_id] = app_detail
templates["app_details"][app_id] = app_detail
return templates
View File
@@ -10,46 +10,48 @@ from services.message_service import MessageService
class SavedMessageService:
@classmethod
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
last_id: Optional[str], limit: int) -> InfiniteScrollPagination:
saved_messages = db.session.query(SavedMessage).filter(
SavedMessage.app_id == app_model.id,
SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
SavedMessage.created_by == user.id
).order_by(SavedMessage.created_at.desc()).all()
def pagination_by_last_id(
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
) -> InfiniteScrollPagination:
saved_messages = (
db.session.query(SavedMessage)
.filter(
SavedMessage.app_id == app_model.id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
SavedMessage.created_by == user.id,
)
.order_by(SavedMessage.created_at.desc())
.all()
)
message_ids = [sm.message_id for sm in saved_messages]
return MessageService.pagination_by_last_id(
app_model=app_model,
user=user,
last_id=last_id,
limit=limit,
include_ids=message_ids
app_model=app_model, user=user, last_id=last_id, limit=limit, include_ids=message_ids
)
@classmethod
def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
saved_message = db.session.query(SavedMessage).filter(
SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
SavedMessage.created_by == user.id
).first()
saved_message = (
db.session.query(SavedMessage)
.filter(
SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
SavedMessage.created_by == user.id,
)
.first()
)
if saved_message:
return
message = MessageService.get_message(
app_model=app_model,
user=user,
message_id=message_id
)
message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id)
saved_message = SavedMessage(
app_id=app_model.id,
message_id=message.id,
created_by_role='account' if isinstance(user, Account) else 'end_user',
created_by=user.id
created_by_role="account" if isinstance(user, Account) else "end_user",
created_by=user.id,
)
db.session.add(saved_message)
@ -57,12 +59,16 @@ class SavedMessageService:
@classmethod
def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
saved_message = db.session.query(SavedMessage).filter(
SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
SavedMessage.created_by == user.id
).first()
saved_message = (
db.session.query(SavedMessage)
.filter(
SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
SavedMessage.created_by == user.id,
)
.first()
)
if not saved_message:
return
View File
@ -12,38 +12,32 @@ from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
query = db.session.query(
Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count')
).outerjoin(
TagBinding, Tag.id == TagBinding.tag_id
).filter(
Tag.type == tag_type,
Tag.tenant_id == current_tenant_id
query = (
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
.filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
if keyword:
query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%')))
query = query.group_by(
Tag.id
)
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id)
results = query.order_by(Tag.created_at.desc()).all()
return results
@staticmethod
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
tags = db.session.query(Tag).filter(
Tag.id.in_(tag_ids),
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type
).all()
tags = (
db.session.query(Tag)
.filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
)
if not tags:
return []
tag_ids = [tag.id for tag in tags]
tag_bindings = db.session.query(
TagBinding.target_id
).filter(
TagBinding.tag_id.in_(tag_ids),
TagBinding.tenant_id == current_tenant_id
).all()
tag_bindings = (
db.session.query(TagBinding.target_id)
.filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.all()
)
if not tag_bindings:
return []
results = [tag_binding.target_id for tag_binding in tag_bindings]
@ -51,27 +45,28 @@ class TagService:
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
tags = db.session.query(Tag).join(
TagBinding,
Tag.id == TagBinding.tag_id
).filter(
TagBinding.target_id == target_id,
TagBinding.tenant_id == current_tenant_id,
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type
).all()
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.filter(
TagBinding.target_id == target_id,
TagBinding.tenant_id == current_tenant_id,
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type,
)
.all()
)
return tags if tags else []
@staticmethod
def save_tags(args: dict) -> Tag:
tag = Tag(
id=str(uuid.uuid4()),
name=args['name'],
type=args['type'],
name=args["name"],
type=args["type"],
created_by=current_user.id,
tenant_id=current_user.current_tenant_id
tenant_id=current_user.current_tenant_id,
)
db.session.add(tag)
db.session.commit()
@ -82,7 +77,7 @@ class TagService:
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
tag.name = args['name']
tag.name = args["name"]
db.session.commit()
return tag
@ -107,20 +102,21 @@ class TagService:
@staticmethod
def save_tag_binding(args):
# check if target exists
TagService.check_target_exists(args['type'], args['target_id'])
TagService.check_target_exists(args["type"], args["target_id"])
# save tag binding
for tag_id in args['tag_ids']:
tag_binding = db.session.query(TagBinding).filter(
TagBinding.tag_id == tag_id,
TagBinding.target_id == args['target_id']
).first()
for tag_id in args["tag_ids"]:
tag_binding = (
db.session.query(TagBinding)
.filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
.first()
)
if tag_binding:
continue
new_tag_binding = TagBinding(
tag_id=tag_id,
target_id=args['target_id'],
target_id=args["target_id"],
tenant_id=current_user.current_tenant_id,
created_by=current_user.id
created_by=current_user.id,
)
db.session.add(new_tag_binding)
db.session.commit()
@ -128,34 +124,34 @@ class TagService:
@staticmethod
def delete_tag_binding(args):
# check if target exists
TagService.check_target_exists(args['type'], args['target_id'])
TagService.check_target_exists(args["type"], args["target_id"])
# delete tag binding
tag_bindings = db.session.query(TagBinding).filter(
TagBinding.target_id == args['target_id'],
TagBinding.tag_id == (args['tag_id'])
).first()
tag_bindings = (
db.session.query(TagBinding)
.filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
.first()
)
if tag_bindings:
db.session.delete(tag_bindings)
db.session.commit()
@staticmethod
def check_target_exists(type: str, target_id: str):
if type == 'knowledge':
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == current_user.current_tenant_id,
Dataset.id == target_id
).first()
if type == "knowledge":
dataset = (
db.session.query(Dataset)
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
.first()
)
if not dataset:
raise NotFound("Dataset not found")
elif type == 'app':
app = db.session.query(App).filter(
App.tenant_id == current_user.current_tenant_id,
App.id == target_id
).first()
elif type == "app":
app = (
db.session.query(App)
.filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
.first()
)
if not app:
raise NotFound("App not found")
else:
raise NotFound("Invalid binding type")
View File
@ -29,111 +29,107 @@ class ApiToolManageService:
@staticmethod
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
"""
parse api schema to tool bundle
parse api schema to tool bundle
"""
try:
warnings = {}
try:
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
except Exception as e:
raise ValueError(f'invalid schema: {str(e)}')
raise ValueError(f"invalid schema: {str(e)}")
credentials_schema = [
ToolProviderCredentials(
name='auth_type',
name="auth_type",
type=ToolProviderCredentials.CredentialsType.SELECT,
required=True,
default='none',
default="none",
options=[
ToolCredentialsOption(value='none', label=I18nObject(
en_US='None',
zh_Hans=''
)),
ToolCredentialsOption(value='api_key', label=I18nObject(
en_US='Api Key',
zh_Hans='Api Key'
)),
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="")),
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
],
placeholder=I18nObject(
en_US='Select auth type',
zh_Hans='选择认证方式'
)
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
),
ToolProviderCredentials(
name='api_key_header',
name="api_key_header",
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
required=False,
placeholder=I18nObject(
en_US='Enter api key header',
zh_Hans='输入 api key headerX-API-KEY'
),
default='api_key',
help=I18nObject(
en_US='HTTP header name for api key',
zh_Hans='HTTP 头部字段名,用于传递 api key'
)
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key headerX-API-KEY"),
default="api_key",
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
),
ToolProviderCredentials(
name='api_key_value',
name="api_key_value",
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
required=False,
placeholder=I18nObject(
en_US='Enter api key',
zh_Hans='输入 api key'
),
default=''
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
default="",
),
]
return jsonable_encoder({
'schema_type': schema_type,
'parameters_schema': tool_bundles,
'credentials_schema': credentials_schema,
'warning': warnings
})
return jsonable_encoder(
{
"schema_type": schema_type,
"parameters_schema": tool_bundles,
"credentials_schema": credentials_schema,
"warning": warnings,
}
)
except Exception as e:
raise ValueError(f'invalid schema: {str(e)}')
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
"""
convert schema to tool bundles
convert schema to tool bundles
:return: the list of tool bundles, description
:return: the list of tool bundles, description
"""
try:
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
return tool_bundles
except Exception as e:
raise ValueError(f'invalid schema: {str(e)}')
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
def create_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
user_id: str,
tenant_id: str,
provider_name: str,
icon: dict,
credentials: dict,
schema_type: str,
schema: str,
privacy_policy: str,
custom_disclaimer: str,
labels: list[str],
):
"""
create api tool provider
create api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f'invalid schema type {schema}')
raise ValueError(f"invalid schema type {schema}")
# check if the provider exists
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
).first()
provider: ApiToolProvider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
)
if provider is not None:
raise ValueError(f'provider {provider_name} already exists')
raise ValueError(f"provider {provider_name} already exists")
# parse openapi to tool bundle
extra_info = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
if len(tool_bundles) > 100:
raise ValueError('the number of apis should be less than 100')
raise ValueError("the number of apis should be less than 100")
# create db provider
db_provider = ApiToolProvider(
@ -142,19 +138,19 @@ class ApiToolManageService:
name=provider_name,
icon=json.dumps(icon),
schema=schema,
description=extra_info.get('description', ''),
description=extra_info.get("description", ""),
schema_type_str=schema_type,
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
credentials_str={},
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer
custom_disclaimer=custom_disclaimer,
)
if 'auth_type' not in credentials:
raise ValueError('auth_type is required')
if "auth_type" not in credentials:
raise ValueError("auth_type is required")
# get auth type, none or api key
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
# create provider entity
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
@ -172,14 +168,12 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
return { 'result': 'success' }
return {"result": "success"}
@staticmethod
def get_api_tool_provider_remote_schema(
user_id: str, tenant_id: str, url: str
):
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
"""
get api tool provider remote schema
get api tool provider remote schema
"""
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
@ -189,84 +183,98 @@ class ApiToolManageService:
try:
response = get(url, headers=headers, timeout=10)
if response.status_code != 200:
raise ValueError(f'Got status code {response.status_code}')
raise ValueError(f"Got status code {response.status_code}")
schema = response.text
# try to parse schema, avoid SSRF attack
ApiToolManageService.parser_api_schema(schema)
except Exception as e:
logger.error(f"parse api schema error: {str(e)}")
raise ValueError('invalid schema, please check the url you provided')
return {
'schema': schema
}
raise ValueError("invalid schema, please check the url you provided")
return {"schema": schema}
@staticmethod
def list_api_tool_provider_tools(
user_id: str, tenant_id: str, provider: str
) -> list[UserTool]:
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
"""
list api tool provider tools
list api tool provider tools
"""
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
).first()
provider: ApiToolProvider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
)
.first()
)
if provider is None:
raise ValueError(f'you have not added provider {provider}')
raise ValueError(f"you have not added provider {provider}")
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
labels = ToolLabelManager.get_tool_labels(controller)
return [
ToolTransformService.tool_to_user_tool(
tool_bundle,
labels=labels,
) for tool_bundle in provider.tools
)
for tool_bundle in provider.tools
]
@staticmethod
def update_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
user_id: str,
tenant_id: str,
provider_name: str,
original_provider: str,
icon: dict,
credentials: dict,
schema_type: str,
schema: str,
privacy_policy: str,
custom_disclaimer: str,
labels: list[str],
):
"""
update api tool provider
update api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f'invalid schema type {schema}')
raise ValueError(f"invalid schema type {schema}")
# check if the provider exists
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == original_provider,
).first()
provider: ApiToolProvider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == original_provider,
)
.first()
)
if provider is None:
raise ValueError(f'api provider {provider_name} does not exists')
raise ValueError(f"api provider {provider_name} does not exists")
# parse openapi to tool bundle
extra_info = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
# update db provider
provider.name = provider_name
provider.icon = json.dumps(icon)
provider.schema = schema
provider.description = extra_info.get('description', '')
provider.description = extra_info.get("description", "")
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.privacy_policy = privacy_policy
provider.custom_disclaimer = custom_disclaimer
if 'auth_type' not in credentials:
raise ValueError('auth_type is required')
if "auth_type" not in credentials:
raise ValueError("auth_type is required")
# get auth type, none or api key
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
# create provider entity
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
@ -295,84 +303,91 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
return { 'result': 'success' }
return {"result": "success"}
@staticmethod
def delete_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str
):
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
"""
delete tool provider
delete tool provider
"""
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
).first()
provider: ApiToolProvider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
)
if provider is None:
raise ValueError(f'you have not added provider {provider_name}')
raise ValueError(f"you have not added provider {provider_name}")
db.session.delete(provider)
db.session.commit()
return { 'result': 'success' }
return {"result": "success"}
@staticmethod
def get_api_tool_provider(
user_id: str, tenant_id: str, provider: str
):
def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
"""
get api tool provider
get api tool provider
"""
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
@staticmethod
def test_api_tool_preview(
tenant_id: str,
tenant_id: str,
provider_name: str,
tool_name: str,
credentials: dict,
parameters: dict,
schema_type: str,
schema: str
tool_name: str,
credentials: dict,
parameters: dict,
schema_type: str,
schema: str,
):
"""
test api tool before adding api tool provider
test api tool before adding api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f'invalid schema type {schema_type}')
raise ValueError(f"invalid schema type {schema_type}")
try:
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
except Exception as e:
raise ValueError('invalid schema')
raise ValueError("invalid schema")
# get tool bundle
tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
if tool_bundle is None:
raise ValueError(f'invalid tool name {tool_name}')
db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
).first()
raise ValueError(f"invalid tool name {tool_name}")
db_provider: ApiToolProvider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
)
if not db_provider:
# create a fake db provider
db_provider = ApiToolProvider(
tenant_id='', user_id='', name='', icon='',
tenant_id="",
user_id="",
name="",
icon="",
schema=schema,
description='',
description="",
schema_type_str=ApiProviderSchemaType.OPENAPI.value,
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
credentials_str=json.dumps(credentials),
)
if 'auth_type' not in credentials:
raise ValueError('auth_type is required')
if "auth_type" not in credentials:
raise ValueError("auth_type is required")
# get auth type, none or api key
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
# create provider entity
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
@ -381,10 +396,7 @@ class ApiToolManageService:
# decrypt credentials
if db_provider.id:
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
provider_controller=provider_controller
)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
# check if the credential has changed, save the original credential
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
@ -396,27 +408,27 @@ class ApiToolManageService:
provider_controller.validate_credentials_format(credentials)
# get tool
tool = provider_controller.get_tool(tool_name)
tool = tool.fork_tool_runtime(runtime={
'credentials': credentials,
'tenant_id': tenant_id,
})
tool = tool.fork_tool_runtime(
runtime={
"credentials": credentials,
"tenant_id": tenant_id,
}
)
result = tool.validate_credentials(credentials, parameters)
except Exception as e:
return { 'error': str(e) }
return { 'result': result or 'empty response' }
return {"error": str(e)}
return {"result": result or "empty response"}
@staticmethod
def list_api_tools(
user_id: str, tenant_id: str
) -> list[UserToolProvider]:
def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
"""
list api tools
list api tools
"""
# get all api providers
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id
).all() or []
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
)
result: list[UserToolProvider] = []
@ -425,26 +437,21 @@ class ApiToolManageService:
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
labels = ToolLabelManager.get_tool_labels(provider_controller)
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller,
db_provider=provider,
decrypt_credentials=True
provider_controller, db_provider=provider, decrypt_credentials=True
)
user_provider.labels = labels
# add icon
ToolTransformService.repack_provider(user_provider)
tools = provider_controller.get_tools(
user_id=user_id, tenant_id=tenant_id
)
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
for tool in tools:
user_provider.tools.append(ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id,
tool=tool,
credentials=user_provider.original_credentials,
labels=labels
))
user_provider.tools.append(
ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
)
)
result.append(user_provider)
View File
@ -20,21 +20,25 @@ logger = logging.getLogger(__name__)
class BuiltinToolManageService:
@staticmethod
def list_builtin_tool_provider_tools(
user_id: str, tenant_id: str, provider: str
) -> list[UserTool]:
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
"""
list builtin tool provider tools
list builtin tool provider tools
"""
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
tools = provider_controller.get_tools()
tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_provider_configurations = ToolConfigurationManager(
tenant_id=tenant_id, provider_controller=provider_controller
)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
builtin_provider: BuiltinToolProvider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.first()
)
credentials = {}
if builtin_provider is not None:
@ -44,47 +48,47 @@ class BuiltinToolManageService:
result = []
for tool in tools:
result.append(ToolTransformService.tool_to_user_tool(
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller)
))
result.append(
ToolTransformService.tool_to_user_tool(
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
return result
@staticmethod
def list_builtin_provider_credentials_schema(
provider_name
):
def list_builtin_provider_credentials_schema(provider_name):
"""
list builtin provider credentials schema
list builtin provider credentials schema
:return: the list of tool providers
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name)
return jsonable_encoder([
v for _, v in (provider.credentials_schema or {}).items()
])
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
@staticmethod
def update_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name: str, credentials: dict
):
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
"""
update builtin tool provider
update builtin tool provider
"""
# get if the provider exists
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
provider: BuiltinToolProvider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
.first()
)
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
raise ValueError(f'provider {provider_name} does not need credentials')
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
# get original credentials if exists
if provider is not None:
@ -121,19 +125,21 @@ class BuiltinToolManageService:
# delete cache
tool_configuration.delete_tool_credentials_cache()
return {'result': 'success'}
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_credentials(
user_id: str, tenant_id: str, provider: str
):
def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
"""
get builtin tool provider credentials
get builtin tool provider credentials
"""
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
provider: BuiltinToolProvider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.first()
)
if provider is None:
return {}
@ -145,19 +151,21 @@ class BuiltinToolManageService:
return credentials
@staticmethod
def delete_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name: str
):
def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
"""
delete tool provider
delete tool provider
"""
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
provider: BuiltinToolProvider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
.first()
)
if provider is None:
raise ValueError(f'you have not added provider {provider_name}')
raise ValueError(f"you have not added provider {provider_name}")
db.session.delete(provider)
db.session.commit()
@ -167,38 +175,36 @@ class BuiltinToolManageService:
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache()
return {'result': 'success'}
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_icon(
provider: str
):
def get_builtin_tool_provider_icon(provider: str):
"""
get tool provider icon and it's mimetype
get tool provider icon and it's mimetype
"""
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
with open(icon_path, 'rb') as f:
with open(icon_path, "rb") as f:
icon_bytes = f.read()
return icon_bytes, mime_type
@staticmethod
def list_builtin_tools(
user_id: str, tenant_id: str
) -> list[UserToolProvider]:
def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
"""
list builtin tools
list builtin tools
"""
# get all builtin providers
provider_controllers = ToolManager.list_builtin_providers()
# get all user added providers
db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id
).all() or []
db_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
)
# find provider
find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
find_provider = lambda provider: next(
filter(lambda db_provider: db_provider.provider == provider, db_providers), None
)
result: list[UserToolProvider] = []
@ -209,7 +215,7 @@ class BuiltinToolManageService:
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider_controller,
name_func=lambda x: x.identity.name
name_func=lambda x: x.identity.name,
):
continue
@ -217,7 +223,7 @@ class BuiltinToolManageService:
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.identity.name),
decrypt_credentials=True
decrypt_credentials=True,
)
# add icon
@ -225,12 +231,14 @@ class BuiltinToolManageService:
tools = provider_controller.get_tools()
for tool in tools:
user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller)
))
user_builtin_provider.tools.append(
ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
result.append(user_builtin_provider)
except Exception as e:
View File
@ -5,4 +5,4 @@ from core.tools.entities.values import default_tool_labels
class ToolLabelsService:
@classmethod
def list_tool_labels(cls) -> list[ToolLabel]:
return default_tool_labels
return default_tool_labels
View File
@ -11,13 +11,11 @@ class ToolCommonService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
"""
list tool providers
list tool providers
:return: the list of tool providers
:return: the list of tool providers
"""
providers = ToolManager.user_list_providers(
user_id, tenant_id, typ
)
providers = ToolManager.user_list_providers(user_id, tenant_id, typ)
# add icon
for provider in providers:
@ -26,4 +24,3 @@ class ToolCommonService:
result = [provider.to_dict() for provider in providers]
return result
View File
@ -22,46 +22,39 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
logger = logging.getLogger(__name__)
class ToolTransformService:
@staticmethod
def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
"""
get tool provider icon url
get tool provider icon url
"""
url_prefix = (dify_config.CONSOLE_API_URL
+ "/console/api/workspaces/current/tool-provider/")
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/"
if provider_type == ToolProviderType.BUILT_IN.value:
return url_prefix + 'builtin/' + provider_name + '/icon'
return url_prefix + "builtin/" + provider_name + "/icon"
elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]:
try:
return json.loads(icon)
except:
return {
"background": "#252525",
"content": "\ud83d\ude01"
}
return ''
return {"background": "#252525", "content": "\ud83d\ude01"}
return ""
@staticmethod
def repack_provider(provider: Union[dict, UserToolProvider]):
"""
repack provider
repack provider
:param provider: the provider dict
:param provider: the provider dict
"""
if isinstance(provider, dict) and 'icon' in provider:
provider['icon'] = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider['type'],
provider_name=provider['name'],
icon=provider['icon']
if isinstance(provider, dict) and "icon" in provider:
provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
)
elif isinstance(provider, UserToolProvider):
provider.icon = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value,
provider_name=provider.name,
icon=provider.icon
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
)
@staticmethod
@ -92,14 +85,13 @@ class ToolTransformService:
masked_credentials={},
is_team_authorization=False,
tools=[],
labels=provider_controller.tool_labels
labels=provider_controller.tool_labels,
)
# get credentials schema
schema = provider_controller.get_credentials_schema()
for name, value in schema.items():
result.masked_credentials[name] = \
ToolProviderCredentials.CredentialsType.default(value.type)
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
# check if the provider need credentials
if not provider_controller.need_credentials:
@ -113,8 +105,7 @@ class ToolTransformService:
# init tool configuration
tool_configuration = ToolConfigurationManager(
tenant_id=db_provider.tenant_id,
provider_controller=provider_controller
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
@ -124,7 +115,7 @@ class ToolTransformService:
result.original_credentials = decrypted_credentials
return result
@staticmethod
def api_provider_to_controller(
db_provider: ApiToolProvider,
@ -135,25 +126,23 @@ class ToolTransformService:
# package tool provider controller
controller = ApiToolProviderController.from_db(
db_provider=db_provider,
auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else
ApiProviderAuthType.NONE
auth_type=ApiProviderAuthType.API_KEY
if db_provider.credentials["auth_type"] == "api_key"
else ApiProviderAuthType.NONE,
)
return controller
@staticmethod
def workflow_provider_to_controller(
db_provider: WorkflowToolProvider
) -> WorkflowToolProviderController:
def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
"""
convert provider controller to provider
"""
return WorkflowToolProviderController.from_db(db_provider)
@staticmethod
def workflow_provider_to_user_provider(
provider_controller: WorkflowToolProviderController,
labels: list[str] = None
provider_controller: WorkflowToolProviderController, labels: list[str] = None
):
"""
convert provider controller to user provider
@ -175,7 +164,7 @@ class ToolTransformService:
masked_credentials={},
is_team_authorization=True,
tools=[],
labels=labels or []
labels=labels or [],
)
@staticmethod
@ -183,16 +172,16 @@ class ToolTransformService:
provider_controller: ApiToolProviderController,
db_provider: ApiToolProvider,
decrypt_credentials: bool = True,
labels: list[str] = None
labels: list[str] = None,
) -> UserToolProvider:
"""
convert provider controller to user provider
"""
username = 'Anonymous'
username = "Anonymous"
try:
username = db_provider.user.name
except Exception as e:
logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}')
logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
# add provider into providers
credentials = db_provider.credentials
result = UserToolProvider(
@ -212,14 +201,13 @@ class ToolTransformService:
masked_credentials={},
is_team_authorization=True,
tools=[],
labels=labels or []
labels=labels or [],
)
if decrypt_credentials:
# init tool configuration
tool_configuration = ToolConfigurationManager(
tenant_id=db_provider.tenant_id,
provider_controller=provider_controller
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
)
# decrypt the credentials and mask the credentials
@ -229,23 +217,25 @@ class ToolTransformService:
result.masked_credentials = masked_credentials
return result
@staticmethod
def tool_to_user_tool(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
credentials: dict = None,
tool: Union[ApiToolBundle, WorkflowTool, Tool],
credentials: dict = None,
tenant_id: str = None,
labels: list[str] = None
labels: list[str] = None,
) -> UserTool:
"""
convert tool to user tool
"""
if isinstance(tool, Tool):
# fork tool runtime
tool = tool.fork_tool_runtime(runtime={
'credentials': credentials,
'tenant_id': tenant_id,
})
tool = tool.fork_tool_runtime(
runtime={
"credentials": credentials,
"tenant_id": tenant_id,
}
)
# get tool parameters
parameters = tool.parameters or []
@ -270,20 +260,14 @@ class ToolTransformService:
label=tool.identity.label,
description=tool.description.human,
parameters=current_parameters,
labels=labels
labels=labels,
)
if isinstance(tool, ApiToolBundle):
return UserTool(
author=tool.author,
name=tool.operation_id,
label=I18nObject(
en_US=tool.operation_id,
zh_Hans=tool.operation_id
),
description=I18nObject(
en_US=tool.summary or '',
zh_Hans=tool.summary or ''
),
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
parameters=tool.parameters,
labels=labels
)
labels=labels,
)

View File

@ -19,10 +19,21 @@ class WorkflowToolManageService:
"""
Service class for managing workflow tools.
"""
@classmethod
def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str,
label: str, icon: dict, description: str,
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
def create_workflow_tool(
cls,
user_id: str,
tenant_id: str,
workflow_app_id: str,
name: str,
label: str,
icon: dict,
description: str,
parameters: list[dict],
privacy_policy: str = "",
labels: list[str] = None,
) -> dict:
"""
Create a workflow tool.
:param user_id: the user id
@ -38,27 +49,28 @@ class WorkflowToolManageService:
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id)
).first()
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
.filter(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
)
.first()
)
if existing_workflow_tool_provider is not None:
raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists')
app: App = db.session.query(App).filter(
App.id == workflow_app_id,
App.tenant_id == tenant_id
).first()
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
if app is None:
raise ValueError(f'App {workflow_app_id} not found')
raise ValueError(f"App {workflow_app_id} not found")
workflow: Workflow = app.workflow
if workflow is None:
raise ValueError(f'Workflow not found for app {workflow_app_id}')
raise ValueError(f"Workflow not found for app {workflow_app_id}")
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
@ -76,19 +88,26 @@ class WorkflowToolManageService:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
db.session.add(workflow_tool_provider)
db.session.commit()
return {
'result': 'success'
}
return {"result": "success"}
@classmethod
def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str,
name: str, label: str, icon: dict, description: str,
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
def update_workflow_tool(
cls,
user_id: str,
tenant_id: str,
workflow_tool_id: str,
name: str,
label: str,
icon: dict,
description: str,
parameters: list[dict],
privacy_policy: str = "",
labels: list[str] = None,
) -> dict:
"""
Update a workflow tool.
:param user_id: the user id
@ -106,35 +125,39 @@ class WorkflowToolManageService:
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id
).first()
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
.filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.first()
)
if existing_workflow_tool_provider is not None:
raise ValueError(f'Tool with name {name} already exists')
workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
).first()
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
if workflow_tool_provider is None:
raise ValueError(f'Tool {workflow_tool_id} not found')
app: App = db.session.query(App).filter(
App.id == workflow_tool_provider.app_id,
App.tenant_id == tenant_id
).first()
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App = (
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
)
if app is None:
raise ValueError(f'App {workflow_tool_provider.app_id} not found')
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
workflow: Workflow = app.workflow
if workflow is None:
raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}')
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
workflow_tool_provider.name = name
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
@ -154,13 +177,10 @@ class WorkflowToolManageService:
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
labels
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
return {
'result': 'success'
}
return {"result": "success"}
@classmethod
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
@ -170,9 +190,7 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id
).all()
db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
tools = []
for provider in db_tools:
@ -188,14 +206,12 @@ class WorkflowToolManageService:
for tool in tools:
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=tool,
labels=labels.get(tool.provider_id, [])
provider_controller=tool, labels=labels.get(tool.provider_id, [])
)
ToolTransformService.repack_provider(user_tool_provider)
user_tool_provider.tools = [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=labels.get(tool.provider_id, [])
tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
)
]
result.append(user_tool_provider)
@ -211,15 +227,12 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
"""
db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
).delete()
db.session.commit()
return {
'result': 'success'
}
return {"result": "success"}
@classmethod
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
@ -230,40 +243,37 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
).first()
db_tool: WorkflowToolProvider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
if db_tool is None:
raise ValueError(f'Tool {workflow_tool_id} not found')
workflow_app: App = db.session.query(App).filter(
App.id == db_tool.app_id,
App.tenant_id == tenant_id
).first()
raise ValueError(f"Tool {workflow_tool_id} not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
if workflow_app is None:
raise ValueError(f'App {db_tool.app_id} not found')
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return {
'name': db_tool.name,
'label': db_tool.label,
'workflow_tool_id': db_tool.id,
'workflow_app_id': db_tool.app_id,
'icon': json.loads(db_tool.icon),
'description': db_tool.description,
'parameters': jsonable_encoder(db_tool.parameter_configurations),
'tool': ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool)
"name": db_tool.name,
"label": db_tool.label,
"workflow_tool_id": db_tool.id,
"workflow_app_id": db_tool.app_id,
"icon": json.loads(db_tool.icon),
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
),
'synced': workflow_app.workflow.version == db_tool.version,
'privacy_policy': db_tool.privacy_policy,
"synced": workflow_app.workflow.version == db_tool.version,
"privacy_policy": db_tool.privacy_policy,
}
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
"""
@ -273,40 +283,37 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == workflow_app_id
).first()
db_tool: WorkflowToolProvider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
)
if db_tool is None:
raise ValueError(f'Tool {workflow_app_id} not found')
workflow_app: App = db.session.query(App).filter(
App.id == db_tool.app_id,
App.tenant_id == tenant_id
).first()
raise ValueError(f"Tool {workflow_app_id} not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
if workflow_app is None:
raise ValueError(f'App {db_tool.app_id} not found')
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return {
'name': db_tool.name,
'label': db_tool.label,
'workflow_tool_id': db_tool.id,
'workflow_app_id': db_tool.app_id,
'icon': json.loads(db_tool.icon),
'description': db_tool.description,
'parameters': jsonable_encoder(db_tool.parameter_configurations),
'tool': ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool)
"name": db_tool.name,
"label": db_tool.label,
"workflow_tool_id": db_tool.id,
"workflow_app_id": db_tool.app_id,
"icon": json.loads(db_tool.icon),
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
),
'synced': workflow_app.workflow.version == db_tool.version,
'privacy_policy': db_tool.privacy_policy
"synced": workflow_app.workflow.version == db_tool.version,
"privacy_policy": db_tool.privacy_policy,
}
@classmethod
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
"""
@ -316,19 +323,19 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the list of tools
"""
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
).first()
db_tool: WorkflowToolProvider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
if db_tool is None:
raise ValueError(f'Tool {workflow_tool_id} not found')
raise ValueError(f"Tool {workflow_tool_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool)
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
)
]
]
View File
@ -7,10 +7,10 @@ from models.dataset import Dataset, DocumentSegment
class VectorService:
@classmethod
def create_segments_vector(cls, keywords_list: Optional[list[list[str]]],
segments: list[DocumentSegment], dataset: Dataset):
def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset
):
documents = []
for segment in segments:
document = Document(
@ -20,14 +20,12 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
},
)
documents.append(document)
if dataset.indexing_technique == 'high_quality':
if dataset.indexing_technique == "high_quality":
# save vector index
vector = Vector(
dataset=dataset
)
vector = Vector(dataset=dataset)
vector.add_texts(documents, duplicate_check=True)
# save keyword index
@ -50,13 +48,11 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
},
)
if dataset.indexing_technique == 'high_quality':
if dataset.indexing_technique == "high_quality":
# update vector index
vector = Vector(
dataset=dataset
)
vector = Vector(dataset=dataset)
vector.delete_by_ids([segment.index_node_id])
vector.add_texts([document], duplicate_check=True)
View File
@ -11,18 +11,29 @@ from services.conversation_service import ConversationService
class WebConversationService:
@classmethod
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
last_id: Optional[str], limit: int, invoke_from: InvokeFrom,
pinned: Optional[bool] = None,
sort_by='-updated_at') -> InfiniteScrollPagination:
def pagination_by_last_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
limit: int,
invoke_from: InvokeFrom,
pinned: Optional[bool] = None,
sort_by="-updated_at",
) -> InfiniteScrollPagination:
include_ids = None
exclude_ids = None
if pinned is not None:
pinned_conversations = db.session.query(PinnedConversation).filter(
PinnedConversation.app_id == app_model.id,
PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
PinnedConversation.created_by == user.id
).order_by(PinnedConversation.created_at.desc()).all()
pinned_conversations = (
db.session.query(PinnedConversation)
.filter(
PinnedConversation.app_id == app_model.id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
PinnedConversation.created_by == user.id,
)
.order_by(PinnedConversation.created_at.desc())
.all()
)
pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations]
if pinned:
include_ids = pinned_conversation_ids
@ -37,32 +48,34 @@ class WebConversationService:
invoke_from=invoke_from,
include_ids=include_ids,
exclude_ids=exclude_ids,
sort_by=sort_by
sort_by=sort_by,
)
@classmethod
def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
pinned_conversation = db.session.query(PinnedConversation).filter(
PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
PinnedConversation.created_by == user.id
).first()
pinned_conversation = (
db.session.query(PinnedConversation)
.filter(
PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
PinnedConversation.created_by == user.id,
)
.first()
)
if pinned_conversation:
return
conversation = ConversationService.get_conversation(
app_model=app_model,
conversation_id=conversation_id,
user=user
app_model=app_model, conversation_id=conversation_id, user=user
)
pinned_conversation = PinnedConversation(
app_id=app_model.id,
conversation_id=conversation.id,
created_by_role='account' if isinstance(user, Account) else 'end_user',
created_by=user.id
created_by_role="account" if isinstance(user, Account) else "end_user",
created_by=user.id,
)
db.session.add(pinned_conversation)
@ -70,12 +83,16 @@ class WebConversationService:
@classmethod
def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
pinned_conversation = db.session.query(PinnedConversation).filter(
PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
PinnedConversation.created_by == user.id
).first()
pinned_conversation = (
db.session.query(PinnedConversation)
.filter(
PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
PinnedConversation.created_by == user.id,
)
.first()
)
if not pinned_conversation:
return
View File
@ -11,161 +11,126 @@ from services.auth.api_key_auth_service import ApiKeyAuthService
class WebsiteService:
@classmethod
def document_create_args_validate(cls, args: dict):
if 'url' not in args or not args['url']:
raise ValueError('url is required')
if 'options' not in args or not args['options']:
raise ValueError('options is required')
if 'limit' not in args['options'] or not args['options']['limit']:
raise ValueError('limit is required')
if "url" not in args or not args["url"]:
raise ValueError("url is required")
if "options" not in args or not args["options"]:
raise ValueError("options is required")
if "limit" not in args["options"] or not args["options"]["limit"]:
raise ValueError("limit is required")
@classmethod
def crawl_url(cls, args: dict) -> dict:
provider = args.get('provider')
url = args.get('url')
options = args.get('options')
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
'website',
provider)
if provider == 'firecrawl':
provider = args.get("provider")
url = args.get("url")
options = args.get("options")
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id,
token=credentials.get('config').get('api_key')
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
crawl_sub_pages = options.get('crawl_sub_pages', False)
only_main_content = options.get('only_main_content', False)
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
crawl_sub_pages = options.get("crawl_sub_pages", False)
only_main_content = options.get("only_main_content", False)
if not crawl_sub_pages:
params = {
'crawlerOptions': {
"crawlerOptions": {
"includes": [],
"excludes": [],
"generateImgAltText": True,
"limit": 1,
'returnOnlyUrls': False,
'pageOptions': {
'onlyMainContent': only_main_content,
"includeHtml": False
}
"returnOnlyUrls": False,
"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False},
}
}
else:
includes = options.get('includes').split(',') if options.get('includes') else []
excludes = options.get('excludes').split(',') if options.get('excludes') else []
includes = options.get("includes").split(",") if options.get("includes") else []
excludes = options.get("excludes").split(",") if options.get("excludes") else []
params = {
'crawlerOptions': {
"crawlerOptions": {
"includes": includes if includes else [],
"excludes": excludes if excludes else [],
"generateImgAltText": True,
"limit": options.get('limit', 1),
'returnOnlyUrls': False,
'pageOptions': {
'onlyMainContent': only_main_content,
"includeHtml": False
}
"limit": options.get("limit", 1),
"returnOnlyUrls": False,
"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False},
}
}
if options.get('max_depth'):
params['crawlerOptions']['maxDepth'] = options.get('max_depth')
if options.get("max_depth"):
params["crawlerOptions"]["maxDepth"] = options.get("max_depth")
job_id = firecrawl_app.crawl_url(url, params)
website_crawl_time_cache_key = f'website_crawl_{job_id}'
website_crawl_time_cache_key = f"website_crawl_{job_id}"
time = str(datetime.datetime.now().timestamp())
redis_client.setex(website_crawl_time_cache_key, 3600, time)
return {
'status': 'active',
'job_id': job_id
}
return {"status": "active", "job_id": job_id}
else:
raise ValueError('Invalid provider')
raise ValueError("Invalid provider")
@classmethod
def get_crawl_status(cls, job_id: str, provider: str) -> dict:
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
'website',
provider)
if provider == 'firecrawl':
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id,
token=credentials.get('config').get('api_key')
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
result = firecrawl_app.check_crawl_status(job_id)
crawl_status_data = {
'status': result.get('status', 'active'),
'job_id': job_id,
'total': result.get('total', 0),
'current': result.get('current', 0),
'data': result.get('data', [])
"status": result.get("status", "active"),
"job_id": job_id,
"total": result.get("total", 0),
"current": result.get("current", 0),
"data": result.get("data", []),
}
if crawl_status_data['status'] == 'completed':
website_crawl_time_cache_key = f'website_crawl_{job_id}'
if crawl_status_data["status"] == "completed":
website_crawl_time_cache_key = f"website_crawl_{job_id}"
start_time = redis_client.get(website_crawl_time_cache_key)
if start_time:
end_time = datetime.datetime.now().timestamp()
time_consuming = abs(end_time - float(start_time))
crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
redis_client.delete(website_crawl_time_cache_key)
else:
raise ValueError('Invalid provider')
raise ValueError("Invalid provider")
return crawl_status_data
@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
'website',
provider)
if provider == 'firecrawl':
file_key = 'website_files/' + job_id + '.txt'
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
if provider == "firecrawl":
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
data = storage.load_once(file_key)
if data:
data = json.loads(data.decode('utf-8'))
data = json.loads(data.decode("utf-8"))
else:
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=tenant_id,
token=credentials.get('config').get('api_key')
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
result = firecrawl_app.check_crawl_status(job_id)
if result.get('status') != 'completed':
raise ValueError('Crawl job is not completed')
data = result.get('data')
if result.get("status") != "completed":
raise ValueError("Crawl job is not completed")
data = result.get("data")
if data:
for item in data:
if item.get('source_url') == url:
if item.get("source_url") == url:
return item
return None
else:
raise ValueError('Invalid provider')
raise ValueError("Invalid provider")
@classmethod
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
'website',
provider)
if provider == 'firecrawl':
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=tenant_id,
token=credentials.get('config').get('api_key')
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
params = {
'pageOptions': {
'onlyMainContent': only_main_content,
"includeHtml": False
}
}
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
params = {"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}}
result = firecrawl_app.scrape_url(url, params)
return result
else:
raise ValueError('Invalid provider')
raise ValueError("Invalid provider")
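
As a quick illustration, crawl_url expects the nested args shape that document_create_args_validate enforces (url, options, and options.limit are required) and only accepts the firecrawl provider. A hedged caller sketch: the import path and the concrete option values are assumptions, and the calls need an authenticated Flask request context (current_user):

from services.website_service import WebsiteService  # assumed module path

args = {
    "provider": "firecrawl",           # any other value raises ValueError("Invalid provider")
    "url": "https://example.com",
    "options": {
        "limit": 5,                    # required by document_create_args_validate
        "crawl_sub_pages": True,
        "includes": "docs/*,blog/*",   # comma-separated; crawl_url splits on ","
        "only_main_content": True,
        "max_depth": 2,                # mapped to crawlerOptions.maxDepth
    },
}

WebsiteService.document_create_args_validate(args)
job = WebsiteService.crawl_url(args)                           # {"status": "active", "job_id": "..."}
status = WebsiteService.get_crawl_status(job["job_id"], "firecrawl")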

View File

@ -10,7 +10,6 @@ from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus
class WorkflowAppService:
def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination:
"""
Get paginate workflow app logs
@ -18,20 +17,14 @@ class WorkflowAppService:
:param args: request args
:return:
"""
query = (
db.select(WorkflowAppLog)
.where(
WorkflowAppLog.tenant_id == app_model.tenant_id,
WorkflowAppLog.app_id == app_model.id
)
query = db.select(WorkflowAppLog).where(
WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id
)
status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None
keyword = args['keyword']
status = WorkflowRunStatus.value_of(args.get("status")) if args.get("status") else None
keyword = args["keyword"]
if keyword or status:
query = query.join(
WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id
)
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
if keyword:
keyword_like_val = f"%{args['keyword'][:30]}%"
@ -39,7 +32,7 @@ class WorkflowAppService:
WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val),
# filter keyword by end user session id if created by end user role
and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_like_val))
and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
]
# filter keyword by workflow run id
@ -49,23 +42,16 @@ class WorkflowAppService:
query = query.outerjoin(
EndUser,
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value)
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value),
).filter(or_(*keyword_conditions))
if status:
# join with workflow_run and filter by status
query = query.filter(
WorkflowRun.status == status.value
)
query = query.filter(WorkflowRun.status == status.value)
query = query.order_by(WorkflowAppLog.created_at.desc())
pagination = db.paginate(
query,
page=args['page'],
per_page=args['limit'],
error_out=False
)
pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
return pagination
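
To show how this is driven, the query above only joins WorkflowRun (and, for end-user session lookups, EndUser) when a keyword or status filter is supplied; the args dict controls both filtering and pagination. A usage sketch, assuming an App row loaded elsewhere and the module path in the comment:

from services.workflow_app_service import WorkflowAppService  # assumed module path

args = {
    "keyword": "error",   # matched against run inputs/outputs and end-user session ids (first 30 chars)
    "status": "failed",   # illustrative; must parse via WorkflowRunStatus.value_of()
    "page": 1,
    "limit": 20,
}
# app_model: an App row loaded elsewhere
pagination = WorkflowAppService().get_paginate_workflow_app_logs(app_model=app_model, args=args)
logs = pagination.items   # Flask-SQLAlchemy Pagination of WorkflowAppLog rows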

View File

@ -18,6 +18,7 @@ class WorkflowRunService:
:param app_model: app model
:param args: request args
"""
class WorkflowWithMessage:
message_id: str
conversation_id: str
@ -33,9 +34,7 @@ class WorkflowRunService:
with_message_workflow_runs = []
for workflow_run in pagination.data:
message = workflow_run.message
with_message_workflow_run = WorkflowWithMessage(
workflow_run=workflow_run
)
with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run)
if message:
with_message_workflow_run.message_id = message.id
with_message_workflow_run.conversation_id = message.conversation_id
@ -53,26 +52,30 @@ class WorkflowRunService:
:param app_model: app model
:param args: request args
"""
limit = int(args.get('limit', 20))
limit = int(args.get("limit", 20))
base_query = db.session.query(WorkflowRun).filter(
WorkflowRun.tenant_id == app_model.tenant_id,
WorkflowRun.app_id == app_model.id,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
)
if args.get('last_id'):
if args.get("last_id"):
last_workflow_run = base_query.filter(
WorkflowRun.id == args.get('last_id'),
WorkflowRun.id == args.get("last_id"),
).first()
if not last_workflow_run:
raise ValueError('Last workflow run not exists')
raise ValueError("Last workflow run not exists")
workflow_runs = base_query.filter(
WorkflowRun.created_at < last_workflow_run.created_at,
WorkflowRun.id != last_workflow_run.id
).order_by(WorkflowRun.created_at.desc()).limit(limit).all()
workflow_runs = (
base_query.filter(
WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
)
.order_by(WorkflowRun.created_at.desc())
.limit(limit)
.all()
)
else:
workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
@ -81,17 +84,13 @@ class WorkflowRunService:
current_page_first_workflow_run = workflow_runs[-1]
rest_count = base_query.filter(
WorkflowRun.created_at < current_page_first_workflow_run.created_at,
WorkflowRun.id != current_page_first_workflow_run.id
WorkflowRun.id != current_page_first_workflow_run.id,
).count()
if rest_count > 0:
has_more = True
return InfiniteScrollPagination(
data=workflow_runs,
limit=limit,
has_more=has_more
)
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
"""
@ -100,11 +99,15 @@ class WorkflowRunService:
:param app_model: app model
:param run_id: workflow run id
"""
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.tenant_id == app_model.tenant_id,
WorkflowRun.app_id == app_model.id,
WorkflowRun.id == run_id,
).first()
workflow_run = (
db.session.query(WorkflowRun)
.filter(
WorkflowRun.tenant_id == app_model.tenant_id,
WorkflowRun.app_id == app_model.id,
WorkflowRun.id == run_id,
)
.first()
)
return workflow_run
@ -117,12 +120,17 @@ class WorkflowRunService:
if not workflow_run:
return []
node_executions = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.tenant_id == app_model.tenant_id,
WorkflowNodeExecution.app_id == app_model.id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == run_id,
).order_by(WorkflowNodeExecution.index.desc()).all()
node_executions = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == app_model.tenant_id,
WorkflowNodeExecution.app_id == app_model.id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == run_id,
)
.order_by(WorkflowNodeExecution.index.desc())
.all()
)
return node_executions
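
For context, the debugging-run listing above pages with cursor-style args (limit plus an optional last_id) and derives has_more by counting rows older than the last item returned, while get_workflow_run fetches a single run by id. A sketch; the import path and surrounding objects are assumptions:

from services.workflow_run_service import WorkflowRunService  # assumed module path

service = WorkflowRunService()

# Cursor args for the debugging-run listing: {"limit": 20} on the first call, then
# {"limit": 20, "last_id": <id of the last run already shown>} on subsequent calls.
# An unknown last_id raises ValueError("Last workflow run not exists").

# app_model: an App row loaded elsewhere
run = service.get_workflow_run(app_model=app_model, run_id="...")  # WorkflowRun or None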

View File

@ -37,11 +37,13 @@ class WorkflowService:
Get draft workflow
"""
# fetch draft workflow by app_model
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == 'draft'
).first()
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft"
)
.first()
)
# return draft workflow
return workflow
@ -55,11 +57,15 @@ class WorkflowService:
return None
# fetch published workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == app_model.workflow_id
).first()
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == app_model.workflow_id,
)
.first()
)
return workflow
@ -85,10 +91,7 @@ class WorkflowService:
raise WorkflowHashNotEqualError()
# validate features structure
self.validate_features_structure(
app_model=app_model,
features=features
)
self.validate_features_structure(app_model=app_model, features=features)
# create draft workflow if not found
if not workflow:
@ -96,7 +99,7 @@ class WorkflowService:
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type=WorkflowType.from_app_mode(app_model.mode).value,
version='draft',
version="draft",
graph=json.dumps(graph),
features=json.dumps(features),
created_by=account.id,
@ -122,9 +125,7 @@ class WorkflowService:
# return draft workflow
return workflow
def publish_workflow(self, app_model: App,
account: Account,
draft_workflow: Optional[Workflow] = None) -> Workflow:
def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow:
"""
Publish workflow from draft
@ -137,7 +138,7 @@ class WorkflowService:
draft_workflow = self.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError('No valid workflow found.')
raise ValueError("No valid workflow found.")
# create new workflow
workflow = Workflow(
@ -187,17 +188,16 @@ class WorkflowService:
workflow_engine_manager = WorkflowEngineManager()
return workflow_engine_manager.get_default_config(node_type, filters)
def run_draft_workflow_node(self, app_model: App,
node_id: str,
user_inputs: dict,
account: Account) -> WorkflowNodeExecution:
def run_draft_workflow_node(
self, app_model: App, node_id: str, user_inputs: dict, account: Account
) -> WorkflowNodeExecution:
"""
Run draft workflow node
"""
# fetch draft workflow by app_model
draft_workflow = self.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError('Workflow not initialized')
raise ValueError("Workflow not initialized")
# run draft workflow node
workflow_engine_manager = WorkflowEngineManager()
@ -226,7 +226,7 @@ class WorkflowService:
created_by_role=CreatedByRole.ACCOUNT.value,
created_by=account.id,
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
)
db.session.add(workflow_node_execution)
db.session.commit()
@ -247,14 +247,15 @@ class WorkflowService:
inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None,
execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
if node_run_result.metadata else None),
execution_metadata=(
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
),
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
elapsed_time=time.perf_counter() - start_at,
created_by_role=CreatedByRole.ACCOUNT.value,
created_by=account.id,
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
)
else:
# create workflow node execution
@ -273,7 +274,7 @@ class WorkflowService:
created_by_role=CreatedByRole.ACCOUNT.value,
created_by=account.id,
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
)
db.session.add(workflow_node_execution)
@ -295,16 +296,16 @@ class WorkflowService:
workflow_converter = WorkflowConverter()
if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]:
raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.')
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
# convert to workflow
new_app = workflow_converter.convert_to_workflow(
app_model=app_model,
account=account,
name=args.get('name'),
icon_type=args.get('icon_type'),
icon=args.get('icon'),
icon_background=args.get('icon_background'),
name=args.get("name"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
)
return new_app
@ -312,15 +313,11 @@ class WorkflowService:
def validate_features_structure(self, app_model: App, features: dict) -> dict:
if app_model.mode == AppMode.ADVANCED_CHAT.value:
return AdvancedChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=features,
only_structure_validate=True
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
elif app_model.mode == AppMode.WORKFLOW.value:
return WorkflowAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=features,
only_structure_validate=True
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
else:
raise ValueError(f"Invalid app mode: {app_model.mode}")
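
As a usage note, publish_workflow looks up the draft itself when draft_workflow is not supplied, and validate_features_structure dispatches on the app mode. A sketch of the publish path; the import path and the app_model/account objects are assumptions:

from services.workflow_service import WorkflowService  # assumed module path

service = WorkflowService()
# app_model: an App row; account: the publishing Account
draft = service.get_draft_workflow(app_model=app_model)        # the version == "draft" row, or None
workflow = service.publish_workflow(app_model=app_model, account=account)
# Passing draft_workflow=draft skips the internal lookup; if no draft exists at all,
# publish_workflow raises ValueError("No valid workflow found.").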

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from configs import dify_config
@ -14,34 +13,40 @@ class WorkspaceService:
if not tenant:
return None
tenant_info = {
'id': tenant.id,
'name': tenant.name,
'plan': tenant.plan,
'status': tenant.status,
'created_at': tenant.created_at,
'in_trail': True,
'trial_end_reason': None,
'role': 'normal',
"id": tenant.id,
"name": tenant.name,
"plan": tenant.plan,
"status": tenant.status,
"created_at": tenant.created_at,
"in_trail": True,
"trial_end_reason": None,
"role": "normal",
}
# Get role of user
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.account_id == current_user.id
).first()
tenant_info['role'] = tenant_account_join.role
tenant_account_join = (
db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
.first()
)
tenant_info["role"] = tenant_account_join.role
can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo
can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
if can_replace_logo and TenantService.has_roles(tenant,
[TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
if can_replace_logo and TenantService.has_roles(
tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]
):
base_url = dify_config.FILES_URL
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
replace_webapp_logo = (
f"{base_url}/files/workspaces/{tenant.id}/webapp-logo"
if tenant.custom_config_dict.get("replace_webapp_logo")
else None
)
remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False)
tenant_info['custom_config'] = {
'remove_webapp_brand': remove_webapp_brand,
'replace_webapp_logo': replace_webapp_logo,
tenant_info["custom_config"] = {
"remove_webapp_brand": remove_webapp_brand,
"replace_webapp_logo": replace_webapp_logo,
}
return tenant_info
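
For reference, the payload assembled above ends up shaped roughly as below. Keys come from the code in this hunk, values are illustrative, and custom_config is only attached when the logo-replacement feature is enabled and the current user is an owner or admin:

tenant_info = {
    "id": "...",
    "name": "...",
    "plan": "...",
    "status": "...",
    "created_at": "...",
    "in_trail": True,               # key is spelled "in_trail" in the source
    "trial_end_reason": None,
    "role": "owner",                # overwritten with the TenantAccountJoin role
    "custom_config": {              # present only for OWNER/ADMIN with can_replace_logo
        "remove_webapp_brand": False,
        "replace_webapp_logo": "<FILES_URL>/files/workspaces/<tenant_id>/webapp-logo",  # or None
    },
}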