mirror of https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00

Feat/fix ops trace (#5672)
Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
parent f0ea540b34
commit e8b8f6c6dd
.devcontainer/post_create_command.sh
@@ -3,7 +3,7 @@
 cd web && npm install
 
 echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
-echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail"' >> ~/.bashrc
+echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace"' >> ~/.bashrc
 echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
 echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
 
.vscode/launch.json (vendored, 14 changed lines)
@@ -37,7 +37,19 @@
             "FLASK_DEBUG": "1",
             "GEVENT_SUPPORT": "True"
         },
-        "args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
+        "args": [
+            "-A",
+            "app.celery",
+            "worker",
+            "-P",
+            "gevent",
+            "-c",
+            "1",
+            "--loglevel",
+            "info",
+            "-Q",
+            "dataset,generation,mail,ops_trace"
+        ]
     },
 ]
 }
api/README.md
@@ -66,7 +66,7 @@
 10. If you need to debug local async processing, please start the worker service.
 
    ```bash
-   poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail
+   poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace
    ```
 
    The started celery app handles the async tasks, e.g. dataset importing and documents indexing.
api/app.py
@@ -26,7 +26,6 @@ from werkzeug.exceptions import Unauthorized
 from commands import register_commands
 
 # DO NOT REMOVE BELOW
 from events import event_handlers
 from extensions import (
     ext_celery,
     ext_code_based_extension,
@@ -43,7 +42,6 @@ from extensions import (
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 from libs.passport import PassportService
 from models import account, dataset, model, source, task, tool, tools, web
 from services.account_service import AccountService
 
 # DO NOT REMOVE ABOVE
api/core/moderation/input_moderation.py
@@ -57,7 +57,7 @@ class InputModeration:
                     timer=timer
-                    )
+                )
 
         if not moderation_result.flagged:
             return False, inputs, query
 
api/core/ops/entities/trace_entity.py
@@ -94,5 +94,15 @@ class ToolTraceInfo(BaseTraceInfo):
 
 
 class GenerateNameTraceInfo(BaseTraceInfo):
-    conversation_id: str
+    conversation_id: Optional[str] = None
     tenant_id: str
+
+
+trace_info_info_map = {
+    'WorkflowTraceInfo': WorkflowTraceInfo,
+    'MessageTraceInfo': MessageTraceInfo,
+    'ModerationTraceInfo': ModerationTraceInfo,
+    'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo,
+    'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo,
+    'ToolTraceInfo': ToolTraceInfo,
+    'GenerateNameTraceInfo': GenerateNameTraceInfo,
+}
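Note: trace_info_info_map keys class names to the pydantic models so the celery worker can rebuild the right TraceInfo type from a serialized payload plus its type name. A minimal sketch of that dispatch, assuming pydantic is installed; the two models here stand in for the real ones and the payload values are illustrative:

```python
from typing import Optional

from pydantic import BaseModel


class BaseTraceInfo(BaseModel):
    message_id: Optional[str] = None


class GenerateNameTraceInfo(BaseTraceInfo):
    conversation_id: Optional[str] = None
    tenant_id: str


registry = {'GenerateNameTraceInfo': GenerateNameTraceInfo}

payload = {'conversation_id': None, 'tenant_id': 'tenant-1'}  # illustrative values
cls = registry['GenerateNameTraceInfo']  # pick the model by its serialized name
info = cls(**payload)                    # rehydrate from the plain dict
assert info.tenant_id == 'tenant-1'
```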
api/core/ops/langfuse_trace/langfuse_trace.py
@@ -147,6 +147,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         # add span
+        if trace_info.message_id:
             span_data = LangfuseSpan(
                 id=node_execution_id,
                 name=f"{node_name}_{node_execution_id}",
                 input=inputs,
                 output=outputs,
@@ -160,6 +161,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             )
+        else:
             span_data = LangfuseSpan(
                 id=node_execution_id,
                 name=f"{node_name}_{node_execution_id}",
                 input=inputs,
                 output=outputs,
@@ -173,6 +175,30 @@ class LangFuseDataTrace(BaseTraceInstance):
 
             self.add_span(langfuse_span_data=span_data)
 
+        process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+        if process_data and process_data.get("model_mode") == "chat":
+            total_token = metadata.get("total_tokens", 0)
+            # add generation
+            generation_usage = GenerationUsage(
+                totalTokens=total_token,
+            )
+
+            node_generation_data = LangfuseGeneration(
+                name=f"generation_{node_execution_id}",
+                trace_id=trace_id,
+                parent_observation_id=node_execution_id,
+                start_time=created_at,
+                end_time=finished_at,
+                input=inputs,
+                output=outputs,
+                metadata=metadata,
+                level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
+                status_message=trace_info.error if trace_info.error else "",
+                usage=generation_usage,
+            )
+
+            self.add_generation(langfuse_generation_data=node_generation_data)
+
     def message_trace(
         self, trace_info: MessageTraceInfo, **kwargs
     ):
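Note: the new block nests a Langfuse generation under the node span only when the node's process_data records model_mode == "chat", so LLM nodes get token usage while other nodes stay plain spans. A self-contained sketch of that gate, with a plain dict standing in for LangfuseGeneration:

```python
import json


def build_node_generation(node_execution_id: str, process_data_json: str, metadata: dict):
    process_data = json.loads(process_data_json) if process_data_json else {}
    if not process_data or process_data.get("model_mode") != "chat":
        return None  # non-chat nodes are traced as plain spans only
    return {
        "name": f"generation_{node_execution_id}",
        "parent_observation_id": node_execution_id,  # nest under the node span
        "usage": {"totalTokens": metadata.get("total_tokens", 0)},
    }


assert build_node_generation("n1", '{"model_mode": "chat"}', {"total_tokens": 42}) is not None
assert build_node_generation("n2", "", {}) is None
```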
@@ -186,7 +212,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         if message_data.from_end_user_id:
             end_user_data: EndUser = db.session.query(EndUser).filter(
                 EndUser.id == message_data.from_end_user_id
-            ).first().session_id
+            ).first()
+            user_id = end_user_data.session_id
 
         trace_data = LangfuseTrace(
@@ -220,6 +246,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             output=trace_info.answer_tokens,
             total=trace_info.total_tokens,
             unit=UnitEnum.TOKENS,
+            totalCost=message_data.total_price,
         )
 
         langfuse_generation_data = LangfuseGeneration(
@@ -303,7 +330,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             metadata=trace_info.metadata,
-            level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
+            level=LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR,
             status_message=trace_info.error,
         )
 
api/core/ops/ops_trace_manager.py
@@ -1,16 +1,17 @@
 import json
 import logging
 import os
 import queue
 import threading
 import time
 from datetime import timedelta
 from enum import Enum
 from typing import Any, Optional, Union
 from uuid import UUID
 
-from flask import Flask, current_app
+from flask import current_app
 
 from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import (
     LangfuseConfig,
     LangSmithConfig,
@@ -31,6 +32,7 @@ from core.ops.utils import get_message_data
 from extensions.ext_database import db
 from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
 from models.workflow import WorkflowAppLog, WorkflowRun
+from tasks.ops_trace_task import process_trace_tasks
 
 provider_config_map = {
     TracingProviderEnum.LANGFUSE.value: {
@@ -105,7 +107,7 @@ class OpsTraceManager:
         return config_class(**new_config).model_dump()
 
     @classmethod
-    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config:dict):
+    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
         """
         Decrypt tracing config
         :param tracing_provider: tracing provider
@@ -295,11 +297,9 @@ class TraceTask:
         self.kwargs = kwargs
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
 
-    def execute(self, trace_instance: BaseTraceInstance):
+    def execute(self):
         method_name, trace_info = self.preprocess()
-        if trace_instance:
-            method = trace_instance.trace
-            method(trace_info)
+        return trace_info
 
     def preprocess(self):
         if self.trace_type == TraceTaskName.CONVERSATION_TRACE:
@@ -372,7 +372,7 @@ class TraceTask:
         }
 
         workflow_trace_info = WorkflowTraceInfo(
-            workflow_data=workflow_run,
+            workflow_data=workflow_run.to_dict(),
             conversation_id=conversation_id,
             workflow_id=workflow_id,
             tenant_id=tenant_id,
@@ -427,7 +427,8 @@ class TraceTask:
         message_tokens = message_data.message_tokens
 
         message_trace_info = MessageTraceInfo(
-            message_data=message_data,
+            message_id=message_id,
+            message_data=message_data.to_dict(),
             conversation_model=conversation_mode,
             message_tokens=message_tokens,
             answer_tokens=message_data.answer_tokens,
@@ -469,7 +470,7 @@ class TraceTask:
         moderation_trace_info = ModerationTraceInfo(
             message_id=workflow_app_log_id if workflow_app_log_id else message_id,
             inputs=inputs,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             flagged=moderation_result.flagged,
             action=moderation_result.action,
             preset_response=moderation_result.preset_response,
@@ -508,7 +509,7 @@ class TraceTask:
 
         suggested_question_trace_info = SuggestedQuestionTraceInfo(
             message_id=workflow_app_log_id if workflow_app_log_id else message_id,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             inputs=message_data.message,
             outputs=message_data.answer,
             start_time=timer.get("start"),
@@ -550,11 +551,11 @@ class TraceTask:
         dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
             message_id=message_id,
             inputs=message_data.query if message_data.query else message_data.inputs,
-            documents=documents,
+            documents=[doc.model_dump() for doc in documents],
             start_time=timer.get("start"),
             end_time=timer.get("end"),
             metadata=metadata,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
         )
 
         return dataset_retrieval_trace_info
@@ -613,7 +614,7 @@ class TraceTask:
 
         tool_trace_info = ToolTraceInfo(
             message_id=message_id,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             tool_name=tool_name,
             start_time=timer.get("start") if timer else created_time,
             end_time=timer.get("end") if timer else end_time,
@@ -657,31 +658,71 @@ class TraceTask:
         return generate_name_trace_info
 
 
+trace_manager_timer = None
+trace_manager_queue = queue.Queue()
+trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 1))
+trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
+
+
 class TraceQueueManager:
     def __init__(self, app_id=None, conversation_id=None, message_id=None):
-        tracing_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
-        self.queue = queue.Queue()
-        self.is_running = True
-        self.thread = threading.Thread(
-            target=self.process_queue, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'trace_instance': tracing_instance
-            }
-        )
-        self.thread.start()
-
-    def stop(self):
-        self.is_running = False
-
-    def process_queue(self, flask_app: Flask, trace_instance: BaseTraceInstance):
-        with flask_app.app_context():
-            while self.is_running:
-                try:
-                    task = self.queue.get(timeout=60)
-                    task.execute(trace_instance)
-                    self.queue.task_done()
-                except queue.Empty:
-                    self.stop()
+        global trace_manager_timer
+
+        self.app_id = app_id
+        self.conversation_id = conversation_id
+        self.message_id = message_id
+        self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
+        self.flask_app = current_app._get_current_object()
+        if trace_manager_timer is None:
+            self.start_timer()
 
     def add_trace_task(self, trace_task: TraceTask):
-        self.queue.put(trace_task)
+        global trace_manager_timer
+        global trace_manager_queue
+        try:
+            if self.trace_instance:
+                trace_manager_queue.put(trace_task)
+        except Exception as e:
+            logging.debug(f"Error adding trace task: {e}")
+        finally:
+            self.start_timer()
+
+    def collect_tasks(self):
+        global trace_manager_queue
+        tasks = []
+        while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
+            task = trace_manager_queue.get_nowait()
+            tasks.append(task)
+            trace_manager_queue.task_done()
+        return tasks
+
+    def run(self):
+        try:
+            tasks = self.collect_tasks()
+            if tasks:
+                self.send_to_celery(tasks)
+        except Exception as e:
+            logging.debug(f"Error processing trace tasks: {e}")
+
+    def start_timer(self):
+        global trace_manager_timer
+        if trace_manager_timer is None or not trace_manager_timer.is_alive():
+            trace_manager_timer = threading.Timer(
+                trace_manager_interval, self.run
+            )
+            trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
+            trace_manager_timer.daemon = False
+            trace_manager_timer.start()
+
+    def send_to_celery(self, tasks: list[TraceTask]):
+        with self.flask_app.app_context():
+            for task in tasks:
+                trace_info = task.execute()
+                task_data = {
+                    "app_id": self.app_id,
+                    "conversation_id": self.conversation_id,
+                    "message_id": self.message_id,
+                    "trace_info_type": type(trace_info).__name__,
+                    "trace_info": trace_info.model_dump() if trace_info else {},
+                }
                process_trace_tasks.delay(task_data)
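Note: the long-lived consumer thread (queue.get with a 60-second timeout) is replaced by a one-shot threading.Timer that drains a module-level queue in batches and hands each batch to Celery. A stdlib-only sketch of that timer-batching pattern; the constants are illustrative stand-ins for TRACE_QUEUE_MANAGER_INTERVAL and TRACE_QUEUE_MANAGER_BATCH_SIZE:

```python
import queue
import threading

task_queue: queue.Queue = queue.Queue()
BATCH_SIZE = 100        # stand-in for TRACE_QUEUE_MANAGER_BATCH_SIZE
INTERVAL_SECONDS = 1.0  # stand-in for TRACE_QUEUE_MANAGER_INTERVAL
timer = None


def collect_tasks():
    # Drain up to BATCH_SIZE items without blocking, like collect_tasks() above.
    tasks = []
    while len(tasks) < BATCH_SIZE and not task_queue.empty():
        tasks.append(task_queue.get_nowait())
        task_queue.task_done()
    return tasks


def run():
    tasks = collect_tasks()
    if tasks:
        print(f"would send {len(tasks)} task(s) to celery")


def start_timer():
    # threading.Timer fires once, so it is re-armed on every enqueue,
    # mirroring start_timer() in the diff.
    global timer
    if timer is None or not timer.is_alive():
        timer = threading.Timer(INTERVAL_SECONDS, run)
        timer.start()


task_queue.put("trace-task-1")
start_timer()
timer.join()  # only for this demo, so the flush runs before exit
```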
api/core/rag/retrieval/dataset_retrieval.py
@@ -12,7 +12,7 @@ from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.ops.ops_trace_manager import TraceTask, TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
 from core.ops.utils import measure_time
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.models.document import Document
@@ -357,7 +357,7 @@ class DatasetRetrieval:
         db.session.commit()
 
         # get tracing instance
-        trace_manager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
+        trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -94,7 +94,7 @@ class ParameterExtractorNode(LLMNode):
         memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
 
         if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
-            and node_data.reasoning_mode == 'function_call':
+                and node_data.reasoning_mode == 'function_call':
             # use function call
             prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
                 node_data, query, variable_pool, model_config, memory
api/docker/entrypoint.sh
@@ -9,7 +9,7 @@ fi
 
 if [[ "${MODE}" == "worker" ]]; then
   celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} -c ${CELERY_WORKER_AMOUNT:-1} --loglevel INFO \
-    -Q ${CELERY_QUEUES:-dataset,generation,mail}
+    -Q ${CELERY_QUEUES:-dataset,generation,mail,ops_trace}
 elif [[ "${MODE}" == "beat" ]]; then
   celery -A app.celery beat --loglevel INFO
 else
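Note: a Celery worker only consumes the queues named in -Q, and process_trace_tasks is declared with @shared_task(queue='ops_trace'), so a worker without ops_trace in its queue list would never pick those tasks up. A small illustration of that routing; the broker URL and task body are placeholders:

```python
from celery import Celery

app = Celery("demo", broker="redis://localhost:6379/0")  # placeholder broker


@app.task(queue="ops_trace")  # published to the ops_trace queue, as in the diff
def trace_noop(payload: dict) -> None:
    print("traced", payload)

# A worker started with `celery -A demo worker -Q dataset,generation,mail`
# never receives trace_noop; ops_trace must be listed in -Q.
```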
api/migrations/versions/<revision 1>.py
@@ -31,17 +31,11 @@ def upgrade():
     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.create_index('tracing_app_config_app_id_idx', ['app_id'], unique=False)
 
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.add_column(sa.Column('trace_config', sa.Text(), nullable=True))
-
     # ### end Alembic commands ###
 
 
 def downgrade():
     # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.drop_column('trace_config')
-
-    # ### commands auto generated by Alembic - please adjust! ##
     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.drop_index('tracing_app_config_app_id_idx')
api/migrations/versions/<revision 2>.py
@@ -35,18 +35,11 @@ def upgrade():
 
     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.drop_index('tracing_app_config_app_id_idx')
 
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.drop_column('trace_config')
-
     # ### end Alembic commands ###
 
 
 def downgrade():
     # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.add_column(sa.Column('trace_config', sa.TEXT(), autoincrement=False, nullable=True))
-
     op.create_table('tracing_app_configs',
         sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
         sa.Column('app_id', sa.UUID(), autoincrement=False, nullable=False),
api/models/dataset.py
@@ -352,6 +352,101 @@ class Document(db.Model):
         return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \
             .filter(DocumentSegment.document_id == self.id).scalar()
 
+    def to_dict(self):
+        return {
+            'id': self.id,
+            'tenant_id': self.tenant_id,
+            'dataset_id': self.dataset_id,
+            'position': self.position,
+            'data_source_type': self.data_source_type,
+            'data_source_info': self.data_source_info,
+            'dataset_process_rule_id': self.dataset_process_rule_id,
+            'batch': self.batch,
+            'name': self.name,
+            'created_from': self.created_from,
+            'created_by': self.created_by,
+            'created_api_request_id': self.created_api_request_id,
+            'created_at': self.created_at,
+            'processing_started_at': self.processing_started_at,
+            'file_id': self.file_id,
+            'word_count': self.word_count,
+            'parsing_completed_at': self.parsing_completed_at,
+            'cleaning_completed_at': self.cleaning_completed_at,
+            'splitting_completed_at': self.splitting_completed_at,
+            'tokens': self.tokens,
+            'indexing_latency': self.indexing_latency,
+            'completed_at': self.completed_at,
+            'is_paused': self.is_paused,
+            'paused_by': self.paused_by,
+            'paused_at': self.paused_at,
+            'error': self.error,
+            'stopped_at': self.stopped_at,
+            'indexing_status': self.indexing_status,
+            'enabled': self.enabled,
+            'disabled_at': self.disabled_at,
+            'disabled_by': self.disabled_by,
+            'archived': self.archived,
+            'archived_reason': self.archived_reason,
+            'archived_by': self.archived_by,
+            'archived_at': self.archived_at,
+            'updated_at': self.updated_at,
+            'doc_type': self.doc_type,
+            'doc_metadata': self.doc_metadata,
+            'doc_form': self.doc_form,
+            'doc_language': self.doc_language,
+            'display_status': self.display_status,
+            'data_source_info_dict': self.data_source_info_dict,
+            'average_segment_length': self.average_segment_length,
+            'dataset_process_rule': self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
+            'dataset': self.dataset.to_dict() if self.dataset else None,
+            'segment_count': self.segment_count,
+            'hit_count': self.hit_count
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(
+            id=data.get('id'),
+            tenant_id=data.get('tenant_id'),
+            dataset_id=data.get('dataset_id'),
+            position=data.get('position'),
+            data_source_type=data.get('data_source_type'),
+            data_source_info=data.get('data_source_info'),
+            dataset_process_rule_id=data.get('dataset_process_rule_id'),
+            batch=data.get('batch'),
+            name=data.get('name'),
+            created_from=data.get('created_from'),
+            created_by=data.get('created_by'),
+            created_api_request_id=data.get('created_api_request_id'),
+            created_at=data.get('created_at'),
+            processing_started_at=data.get('processing_started_at'),
+            file_id=data.get('file_id'),
+            word_count=data.get('word_count'),
+            parsing_completed_at=data.get('parsing_completed_at'),
+            cleaning_completed_at=data.get('cleaning_completed_at'),
+            splitting_completed_at=data.get('splitting_completed_at'),
+            tokens=data.get('tokens'),
+            indexing_latency=data.get('indexing_latency'),
+            completed_at=data.get('completed_at'),
+            is_paused=data.get('is_paused'),
+            paused_by=data.get('paused_by'),
+            paused_at=data.get('paused_at'),
+            error=data.get('error'),
+            stopped_at=data.get('stopped_at'),
+            indexing_status=data.get('indexing_status'),
+            enabled=data.get('enabled'),
+            disabled_at=data.get('disabled_at'),
+            disabled_by=data.get('disabled_by'),
+            archived=data.get('archived'),
+            archived_reason=data.get('archived_reason'),
+            archived_by=data.get('archived_by'),
+            archived_at=data.get('archived_at'),
+            updated_at=data.get('updated_at'),
+            doc_type=data.get('doc_type'),
+            doc_metadata=data.get('doc_metadata'),
+            doc_form=data.get('doc_form'),
+            doc_language=data.get('doc_language')
+        )
+
 
 class DocumentSegment(db.Model):
     __tablename__ = 'document_segments'
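Note: these to_dict/from_dict pairs exist so ORM rows can cross the Celery boundary as plain, JSON-serializable dicts instead of session-bound objects. A minimal sketch of the round trip, with a dataclass standing in for the SQLAlchemy model:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Doc:  # stand-in for the Document model above
    id: str
    name: str
    error: Optional[str] = None

    def to_dict(self) -> dict:
        return {'id': self.id, 'name': self.name, 'error': self.error}

    @classmethod
    def from_dict(cls, data: dict) -> 'Doc':
        # .get() tolerates missing keys, matching Document.from_dict above
        return cls(id=data.get('id'), name=data.get('name'), error=data.get('error'))


wire = Doc('d-1', 'handbook.pdf').to_dict()  # producer side (web process)
restored = Doc.from_dict(wire)               # consumer side (celery worker)
assert restored == Doc('d-1', 'handbook.pdf')
```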
api/models/model.py
@@ -838,6 +838,49 @@ class Message(db.Model):
 
         return None
 
+    def to_dict(self) -> dict:
+        return {
+            'id': self.id,
+            'app_id': self.app_id,
+            'conversation_id': self.conversation_id,
+            'inputs': self.inputs,
+            'query': self.query,
+            'message': self.message,
+            'answer': self.answer,
+            'status': self.status,
+            'error': self.error,
+            'message_metadata': self.message_metadata_dict,
+            'from_source': self.from_source,
+            'from_end_user_id': self.from_end_user_id,
+            'from_account_id': self.from_account_id,
+            'created_at': self.created_at.isoformat(),
+            'updated_at': self.updated_at.isoformat(),
+            'agent_based': self.agent_based,
+            'workflow_run_id': self.workflow_run_id
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(
+            id=data['id'],
+            app_id=data['app_id'],
+            conversation_id=data['conversation_id'],
+            inputs=data['inputs'],
+            query=data['query'],
+            message=data['message'],
+            answer=data['answer'],
+            status=data['status'],
+            error=data['error'],
+            message_metadata=json.dumps(data['message_metadata']),
+            from_source=data['from_source'],
+            from_end_user_id=data['from_end_user_id'],
+            from_account_id=data['from_account_id'],
+            created_at=data['created_at'],
+            updated_at=data['updated_at'],
+            agent_based=data['agent_based'],
+            workflow_run_id=data['workflow_run_id']
+        )
+
 
 class MessageFeedback(db.Model):
     __tablename__ = 'message_feedbacks'
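Note: unlike Document.from_dict, Message.from_dict indexes keys directly (data['id']), so the payload must carry every field. Also, to_dict emits created_at/updated_at as isoformat() strings and from_dict passes them through unchanged; a consumer that needs real datetime objects has to parse them back, e.g.:

```python
from datetime import datetime

created_at = datetime(2024, 6, 28, 12, 30, 0)
wire = {'created_at': created_at.isoformat()}          # what to_dict emits
restored = datetime.fromisoformat(wire['created_at'])  # parse back if needed
assert restored == created_at
```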
api/models/workflow.py
@@ -324,6 +324,55 @@ class WorkflowRun(db.Model):
     def workflow(self):
         return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
 
+    def to_dict(self):
+        return {
+            'id': self.id,
+            'tenant_id': self.tenant_id,
+            'app_id': self.app_id,
+            'sequence_number': self.sequence_number,
+            'workflow_id': self.workflow_id,
+            'type': self.type,
+            'triggered_from': self.triggered_from,
+            'version': self.version,
+            'graph': self.graph_dict,
+            'inputs': self.inputs_dict,
+            'status': self.status,
+            'outputs': self.outputs_dict,
+            'error': self.error,
+            'elapsed_time': self.elapsed_time,
+            'total_tokens': self.total_tokens,
+            'total_steps': self.total_steps,
+            'created_by_role': self.created_by_role,
+            'created_by': self.created_by,
+            'created_at': self.created_at,
+            'finished_at': self.finished_at,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> 'WorkflowRun':
+        return cls(
+            id=data.get('id'),
+            tenant_id=data.get('tenant_id'),
+            app_id=data.get('app_id'),
+            sequence_number=data.get('sequence_number'),
+            workflow_id=data.get('workflow_id'),
+            type=data.get('type'),
+            triggered_from=data.get('triggered_from'),
+            version=data.get('version'),
+            graph=json.dumps(data.get('graph')),
+            inputs=json.dumps(data.get('inputs')),
+            status=data.get('status'),
+            outputs=json.dumps(data.get('outputs')),
+            error=data.get('error'),
+            elapsed_time=data.get('elapsed_time'),
+            total_tokens=data.get('total_tokens'),
+            total_steps=data.get('total_steps'),
+            created_by_role=data.get('created_by_role'),
+            created_by=data.get('created_by'),
+            created_at=data.get('created_at'),
+            finished_at=data.get('finished_at'),
+        )
+
 
 class WorkflowNodeExecutionTriggeredFrom(Enum):
     """
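Note: WorkflowRun keeps graph/inputs/outputs in JSON text columns, so to_dict exposes the parsed dicts (graph_dict and friends) while from_dict re-serializes with json.dumps. The asymmetry in miniature:

```python
import json

stored_graph = '{"nodes": []}'      # what the text column holds
as_dict = json.loads(stored_graph)  # what to_dict exposes via graph_dict
column_value = json.dumps(as_dict)  # what from_dict writes back
assert json.loads(column_value) == as_dict
```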
api/tasks/ops_trace_task.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+import logging
+import time
+
+from celery import shared_task
+from flask import current_app
+
+from core.ops.entities.trace_entity import trace_info_info_map
+from core.rag.models.document import Document
+from models.model import Message
+from models.workflow import WorkflowRun
+
+
+@shared_task(queue='ops_trace')
+def process_trace_tasks(tasks_data):
+    """
+    Async process trace tasks
+    :param tasks_data: List of dictionaries containing task data
+
+    Usage: process_trace_tasks.delay(tasks_data)
+    """
+    from core.ops.ops_trace_manager import OpsTraceManager
+
+    trace_info = tasks_data.get('trace_info')
+    app_id = tasks_data.get('app_id')
+    conversation_id = tasks_data.get('conversation_id')
+    message_id = tasks_data.get('message_id')
+    trace_info_type = tasks_data.get('trace_info_type')
+    trace_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
+
+    if trace_info.get('message_data'):
+        trace_info['message_data'] = Message.from_dict(data=trace_info['message_data'])
+    if trace_info.get('workflow_data'):
+        trace_info['workflow_data'] = WorkflowRun.from_dict(data=trace_info['workflow_data'])
+    if trace_info.get('documents'):
+        trace_info['documents'] = [Document(**doc) for doc in trace_info['documents']]
+
+    try:
+        if trace_instance:
+            with current_app.app_context():
+                trace_type = trace_info_info_map.get(trace_info_type)
+                if trace_type:
+                    trace_info = trace_type(**trace_info)
+                    trace_instance.trace(trace_info)
+        end_at = time.perf_counter()
+    except Exception:
+        logging.exception("Processing trace tasks failed")
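Note: each delay() call receives a single dict, built by send_to_celery as sketched below; the IDs are illustrative and calling delay() requires a running broker and a worker subscribed to ops_trace:

```python
task_data = {
    "app_id": "app-123",                      # illustrative IDs
    "conversation_id": "conv-456",
    "message_id": "msg-789",
    "trace_info_type": "MessageTraceInfo",    # key into trace_info_info_map
    "trace_info": {"message_id": "msg-789"},  # model_dump() of the TraceInfo model
}
# process_trace_tasks.delay(task_data)
```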