diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py
index bc088e9dc9..a155de09b3 100644
--- a/api/core/app_runner/app_runner.py
+++ b/api/core/app_runner/app_runner.py
@@ -1,7 +1,7 @@
 import time
 from typing import cast, Optional, List, Tuple, Generator, Union
 
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -183,7 +183,7 @@ class AppRunner:
                         index=index,
                         message=AssistantPromptMessage(content=token)
                     )
-                ))
+                ), PublishFrom.APPLICATION_MANAGER)
                 index += 1
                 time.sleep(0.01)
 
@@ -193,7 +193,8 @@ class AppRunner:
                 prompt_messages=prompt_messages,
                 message=AssistantPromptMessage(content=text),
                 usage=usage if usage else LLMUsage.empty_usage()
-            )
+            ),
+            pub_from=PublishFrom.APPLICATION_MANAGER
         )
 
     def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
@@ -226,7 +227,8 @@ class AppRunner:
         :return:
         """
         queue_manager.publish_message_end(
-            llm_result=invoke_result
+            llm_result=invoke_result,
+            pub_from=PublishFrom.APPLICATION_MANAGER
         )
 
     def _handle_invoke_result_stream(self, invoke_result: Generator,
@@ -242,7 +244,7 @@ class AppRunner:
         text = ''
         usage = None
         for result in invoke_result:
-            queue_manager.publish_chunk_message(result)
+            queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER)
 
             text += result.delta.message.content
 
@@ -263,5 +265,6 @@ class AppRunner:
         )
 
         queue_manager.publish_message_end(
-            llm_result=llm_result
+            llm_result=llm_result,
+            pub_from=PublishFrom.APPLICATION_MANAGER
         )
diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py
index db235c565f..a77baaa495 100644
--- a/api/core/app_runner/basic_app_runner.py
+++ b/api/core/app_runner/basic_app_runner.py
@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
     AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.features.annotation_reply import AnnotationReplyFeature
 from core.features.dataset_retrieval import DatasetRetrievalFeature
 from core.features.external_data_fetch import ExternalDataFetchFeature
@@ -121,7 +121,8 @@ class BasicApplicationRunner(AppRunner):
 
         if annotation_reply:
             queue_manager.publish_annotation_reply(
-                message_annotation_id=annotation_reply.id
+                message_annotation_id=annotation_reply.id,
+                pub_from=PublishFrom.APPLICATION_MANAGER
             )
             self.direct_output(
                 queue_manager=queue_manager,
diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py
index 0281259453..bea045f160 100644
--- a/api/core/app_runner/generate_task_pipeline.py
+++ b/api/core/app_runner/generate_task_pipeline.py
@@ -7,7 +7,7 @@ from pydantic import BaseModel
 
 from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
 from core.entities.application_entities import ApplicationGenerateEntity
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
     QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
     AnnotationReplyEvent
@@ -312,8 +312,11 @@ class GenerateTaskPipeline:
                                     index=0,
                                     message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                                 )
-                            ))
-                            self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
+                            ), PublishFrom.TASK_PIPELINE)
+                            self._queue_manager.publish(
+                                QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
+                                PublishFrom.TASK_PIPELINE
+                            )
                             continue
                         else:
                             self._output_moderation_handler.append_new_token(delta_text)
diff --git a/api/core/app_runner/moderation_handler.py b/api/core/app_runner/moderation_handler.py
index c4f2403e7f..2917da6f29 100644
--- a/api/core/app_runner/moderation_handler.py
+++ b/api/core/app_runner/moderation_handler.py
@@ -6,6 +6,7 @@ from typing import Any, Optional, Dict
 from flask import current_app, Flask
 from pydantic import BaseModel
 
+from core.application_queue_manager import PublishFrom
 from core.moderation.base import ModerationAction, ModerationOutputsResult
 from core.moderation.factory import ModerationFactory
 
@@ -66,7 +67,7 @@ class OutputModerationHandler(BaseModel):
                 final_output = result.text
 
         if public_event:
-            self.on_message_replace_func(final_output)
+            self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)
 
         return final_output
 
diff --git a/api/core/application_manager.py b/api/core/application_manager.py
index 82bf4ec2ef..88500f3a47 100644
--- a/api/core/application_manager.py
+++ b/api/core/application_manager.py
@@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.prompt_template import PromptTemplateParser
 from core.provider_manager import ProviderManager
-from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
+from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
 from extensions.ext_database import db
 from models.account import Account
 from models.model import EndUser, Conversation, Message, MessageFile, App
@@ -169,15 +169,18 @@
         except ConversationTaskStoppedException:
             pass
         except InvokeAuthorizationError:
-            queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
+            queue_manager.publish_error(
+                InvokeAuthorizationError('Incorrect API key provided'),
+                PublishFrom.APPLICATION_MANAGER
+            )
         except ValidationError as e:
             logger.exception("Validation Error when generating")
-            queue_manager.publish_error(e)
+            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except (ValueError, InvokeError) as e:
-            queue_manager.publish_error(e)
+            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except Exception as e:
             logger.exception("Unknown Error when generating")
-            queue_manager.publish_error(e)
+            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         finally:
             db.session.remove()
 
diff --git a/api/core/application_queue_manager.py b/api/core/application_queue_manager.py
index e32d631295..678d08d772 100644
--- a/api/core/application_queue_manager.py
+++ b/api/core/application_queue_manager.py
@@ -1,5 +1,6 @@
 import queue
 import time
+from enum import Enum
 from typing import Generator, Any
 
 from sqlalchemy.orm import DeclarativeMeta
@@ -13,6 +14,11 @@ from extensions.ext_redis import redis_client
 from models.model import MessageAgentThought
 
 
+class PublishFrom(Enum):
+    APPLICATION_MANAGER = 1
+    TASK_PIPELINE = 2
+
+
 class ApplicationQueueManager:
     def __init__(self, task_id: str, user_id: str,
@@ -61,11 +67,14 @@
             if elapsed_time >= listen_timeout or self._is_stopped():
                 # publish two messages to make sure the client can receive the stop signal
                 # and stop listening after the stop signal processed
-                self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
+                self.publish(
+                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
+                    PublishFrom.TASK_PIPELINE
+                )
                 self.stop_listen()
 
             if elapsed_time // 10 > last_ping_time:
-                self.publish(QueuePingEvent())
+                self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
                 last_ping_time = elapsed_time // 10
 
     def stop_listen(self) -> None:
@@ -75,76 +84,83 @@
         """
         self._q.put(None)
 
-    def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
+    def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
         """
         Publish chunk message to channel
         :param chunk: chunk
+        :param pub_from: publish from
         :return:
         """
         self.publish(QueueMessageEvent(
             chunk=chunk
-        ))
+        ), pub_from)
 
-    def publish_message_replace(self, text: str) -> None:
+    def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
         """
         Publish message replace
         :param text: text
+        :param pub_from: publish from
         :return:
         """
         self.publish(QueueMessageReplaceEvent(
             text=text
-        ))
+        ), pub_from)
 
-    def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
+    def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
         """
         Publish retriever resources
         :return:
         """
-        self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
+        self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)
 
-    def publish_annotation_reply(self, message_annotation_id: str) -> None:
+    def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
         """
         Publish annotation reply
         :param message_annotation_id: message annotation id
+        :param pub_from: publish from
         :return:
        """
-        self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
+        self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
 
-    def publish_message_end(self, llm_result: LLMResult) -> None:
+    def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
         """
         Publish message end
         :param llm_result: llm result
+        :param pub_from: publish from
         :return:
         """
-        self.publish(QueueMessageEndEvent(llm_result=llm_result))
+        self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
         self.stop_listen()
 
-    def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
+    def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
         """
         Publish agent thought
         :param message_agent_thought: message agent thought
+        :param pub_from: publish from
        :return:
         """
         self.publish(QueueAgentThoughtEvent(
             agent_thought_id=message_agent_thought.id
-        ))
+        ), pub_from)
 
-    def publish_error(self, e) -> None:
+    def publish_error(self, e, pub_from: PublishFrom) -> None:
         """
         Publish error
         :param e: error
+        :param pub_from: publish from
         :return:
         """
         self.publish(QueueErrorEvent(
             error=e
-        ))
+        ), pub_from)
         self.stop_listen()
 
-    def publish(self, event: AppQueueEvent) -> None:
+    def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
         """
         Publish event to queue
         :param event:
+        :param pub_from:
         :return:
         """
         self._check_for_sqlalchemy_models(event.dict())
@@ -162,6 +178,9 @@
         if isinstance(event, QueueStopEvent):
             self.stop_listen()
 
+        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
+            raise ConversationTaskStoppedException()
+
     @classmethod
     def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
         """
@@ -187,7 +206,6 @@
         stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
         result = redis_client.get(stopped_cache_key)
         if result is not None:
-            redis_client.delete(stopped_cache_key)
             return True
 
         return False
diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py
index ec2964f2af..1c9d3d7139 100644
--- a/api/core/callback_handler/agent_loop_gather_callback_handler.py
+++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py
@@ -8,7 +8,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
 
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.entities.application_entities import ModelConfigEntity
 from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
@@ -232,7 +232,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         db.session.add(message_agent_thought)
         db.session.commit()
 
-        self.queue_manager.publish_agent_thought(message_agent_thought)
+        self.queue_manager.publish_agent_thought(message_agent_thought, PublishFrom.APPLICATION_MANAGER)
 
         return message_agent_thought
 
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index bfd305e2f7..a7dbbac393 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -2,7 +2,7 @@ from typing import List, Union
 
 from langchain.schema import Document
 
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.application_entities import InvokeFrom
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, DatasetQuery
@@ -80,4 +80,4 @@ class DatasetIndexToolCallbackHandler:
             db.session.add(dataset_retriever_resource)
             db.session.commit()
 
-        self._queue_manager.publish_retriever_resources(resource)
+        self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER)
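Not part of the patch: a minimal, self-contained Python sketch of the publish/stop contract these changes introduce. The ToyQueueManager class and its method names below are illustrative stand-ins rather than the real ApplicationQueueManager; the point shown is that every publisher now declares a PublishFrom origin, and only application-manager-side publishers are interrupted with ConversationTaskStoppedException once the task is flagged as stopped, while the task pipeline can still emit its final stop and ping events.

import queue
from enum import Enum


class PublishFrom(Enum):
    APPLICATION_MANAGER = 1
    TASK_PIPELINE = 2


class ConversationTaskStoppedException(Exception):
    pass


class ToyQueueManager:
    """Simplified stand-in for ApplicationQueueManager (illustration only)."""

    def __init__(self) -> None:
        self._q: queue.Queue = queue.Queue()
        self._stopped = False

    def stop(self) -> None:
        # Stand-in for the Redis stop flag checked by _is_stopped() in the patch.
        self._stopped = True

    def publish(self, event: object, pub_from: PublishFrom) -> None:
        self._q.put(event)
        # Only the generating side (application manager) is interrupted when the
        # task is stopped; the task pipeline may still publish stop/ping events.
        if pub_from == PublishFrom.APPLICATION_MANAGER and self._stopped:
            raise ConversationTaskStoppedException()


if __name__ == "__main__":
    qm = ToyQueueManager()
    qm.publish("chunk-1", PublishFrom.APPLICATION_MANAGER)   # delivered normally
    qm.stop()
    qm.publish("stop-event", PublishFrom.TASK_PIPELINE)      # still allowed after stop
    try:
        qm.publish("chunk-2", PublishFrom.APPLICATION_MANAGER)
    except ConversationTaskStoppedException:
        print("generation loop unwinds here, as in ApplicationManager's except block")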