dify/api/core/moderation/output_moderation.py
2024-09-10 17:00:20 +08:00

132 lines
4.3 KiB
Python

import logging
import threading
import time
from typing import Any, Optional
from flask import Flask, current_app
from pydantic import BaseModel, ConfigDict
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueMessageReplaceEvent
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory
logger = logging.getLogger(__name__)
class ModerationRule(BaseModel):
type: str
config: dict[str, Any]
class OutputModeration(BaseModel):
tenant_id: str
app_id: str
rule: ModerationRule
queue_manager: AppQueueManager
thread: Optional[threading.Thread] = None
thread_running: bool = True
buffer: str = ""
is_final_chunk: bool = False
final_output: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
def should_direct_output(self) -> bool:
return self.final_output is not None
def get_final_output(self) -> str:
return self.final_output or ""
def append_new_token(self, token: str) -> None:
self.buffer += token
if not self.thread:
self.thread = self.start_thread()
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion)
if not result or not result.flagged:
return completion
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
else:
final_output = result.text
if public_event:
self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
return final_output
def start_thread(self) -> threading.Thread:
buffer_size = dify_config.MODERATION_BUFFER_SIZE
thread = threading.Thread(
target=self.worker,
kwargs={
"flask_app": current_app._get_current_object(),
"buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE,
},
)
thread.start()
return thread
def stop_thread(self):
if self.thread and self.thread.is_alive():
self.thread_running = False
def worker(self, flask_app: Flask, buffer_size: int):
with flask_app.app_context():
current_length = 0
while self.thread_running:
moderation_buffer = self.buffer
buffer_length = len(moderation_buffer)
if not self.is_final_chunk:
chunk_length = buffer_length - current_length
if 0 <= chunk_length < buffer_size:
time.sleep(1)
continue
current_length = buffer_length
result = self.moderation(
tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer
)
if not result or not result.flagged:
continue
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
self.final_output = final_output
else:
final_output = result.text + self.buffer[len(moderation_buffer) :]
# trigger replace event
if self.thread_running:
self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
if result.action == ModerationAction.DIRECT_OUTPUT:
break
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
try:
moderation_factory = ModerationFactory(
name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config
)
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logger.error("Moderation Output error: %s", e)
return None