mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 08:39:20 +08:00
Initial commit
This commit is contained in:
178
api/core/callback_handler/agent_loop_gather_callback_handler.py
Normal file
178
api/core/callback_handler/agent_loop_gather_callback_handler.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
|
||||
|
||||
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
|
||||
def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self.model_name = model_name
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
self.current_chain = None
|
||||
|
||||
@property
|
||||
def agent_loops(self) -> List[AgentLoop]:
|
||||
return self._agent_loops
|
||||
|
||||
def clear_agent_loops(self) -> None:
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return True
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
# serialized={'name': 'OpenAI'}
|
||||
# prompts=['Answer the following questions...\nThought:']
|
||||
# kwargs={}
|
||||
if not self._current_loop:
|
||||
# Agent start with a LLM query
|
||||
self._current_loop = AgentLoop(
|
||||
position=len(self._agent_loops) + 1,
|
||||
prompt=prompts[0],
|
||||
status='llm_started',
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
# kwargs={}
|
||||
if self._current_loop and self._current_loop.status == 'llm_started':
|
||||
self._current_loop.status = 'llm_end'
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
self._current_loop.completion = response.generations[0][0].text
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.error(error)
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
pass
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.error(error)
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
# kwargs={'color': 'green', 'llm_prefix': 'Thought:', 'observation_prefix': 'Observation: '}
|
||||
# input_str='action-input'
|
||||
# serialized={'description': 'A search engine. Useful for when you need to answer questions about current events. Input should be a search query.', 'name': 'Search'}
|
||||
pass
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run on agent action."""
|
||||
tool = action.tool
|
||||
tool_input = action.tool_input
|
||||
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
|
||||
thought = action.log[:action_name_position].strip() if action.log else ''
|
||||
|
||||
if self._current_loop and self._current_loop.status == 'llm_end':
|
||||
self._current_loop.status = 'agent_action'
|
||||
self._current_loop.thought = thought
|
||||
self._current_loop.tool_name = tool
|
||||
self._current_loop.tool_input = tool_input
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
# kwargs={'name': 'Search'}
|
||||
# llm_prefix='Thought:'
|
||||
# observation_prefix='Observation: '
|
||||
# output='53 years'
|
||||
|
||||
if self._current_loop and self._current_loop.status == 'agent_action' and output and output != 'None':
|
||||
self._current_loop.status = 'tool_end'
|
||||
self._current_loop.tool_output = output
|
||||
self._current_loop.completed = True
|
||||
self._current_loop.completed_at = time.perf_counter()
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
|
||||
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
self._current_loop = None
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
logging.error(error)
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
color: Optional[str] = None,
|
||||
end: str = "",
|
||||
**kwargs: Optional[str],
|
||||
) -> None:
|
||||
"""Run on additional input from chains and agents."""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
# Final Answer
|
||||
if self._current_loop and (self._current_loop.status == 'llm_end' or self._current_loop.status == 'agent_action'):
|
||||
self._current_loop.status = 'agent_finish'
|
||||
self._current_loop.completed = True
|
||||
self._current_loop.completed_at = time.perf_counter()
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
|
||||
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
self._current_loop = None
|
||||
elif not self._current_loop and self._agent_loops:
|
||||
self._agent_loops[-1].status = 'agent_finish'
|
||||
117
api/core/callback_handler/dataset_tool_callback_handler.py
Normal file
117
api/core/callback_handler/dataset_tool_callback_handler.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import logging
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
from core.callback_handler.entity.dataset_query import DatasetQueryObj
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
|
||||
|
||||
class DatasetToolCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
|
||||
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self.queries = []
|
||||
self.conversation_message_task = conversation_message_task
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return False
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
tool_name = serialized.get('name')
|
||||
dataset_id = tool_name[len("dataset-"):]
|
||||
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# kwargs={'name': 'Search'}
|
||||
# llm_prefix='Thought:'
|
||||
# observation_prefix='Observation: '
|
||||
# output='53 years'
|
||||
pass
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
logging.error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.error(error)
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
color: Optional[str] = None,
|
||||
end: str = "",
|
||||
**kwargs: Optional[str],
|
||||
) -> None:
|
||||
"""Run on additional input from chains and agents."""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
pass
|
||||
23
api/core/callback_handler/entity/agent_loop.py
Normal file
23
api/core/callback_handler/entity/agent_loop.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentLoop(BaseModel):
|
||||
position: int = 1
|
||||
|
||||
thought: str = None
|
||||
tool_name: str = None
|
||||
tool_input: str = None
|
||||
tool_output: str = None
|
||||
|
||||
prompt: str = None
|
||||
prompt_tokens: int = None
|
||||
completion: str = None
|
||||
completion_tokens: int = None
|
||||
|
||||
latency: float = None
|
||||
|
||||
status: str = 'llm_started'
|
||||
completed: bool = False
|
||||
|
||||
started_at: float = None
|
||||
completed_at: float = None
|
||||
16
api/core/callback_handler/entity/chain_result.py
Normal file
16
api/core/callback_handler/entity/chain_result.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChainResult(BaseModel):
|
||||
type: str = None
|
||||
prompt: dict = None
|
||||
completion: dict = None
|
||||
|
||||
status: str = 'chain_started'
|
||||
completed: bool = False
|
||||
|
||||
started_at: float = None
|
||||
completed_at: float = None
|
||||
|
||||
agent_result: dict = None
|
||||
"""only when type is 'AgentExecutor'"""
|
||||
6
api/core/callback_handler/entity/dataset_query.py
Normal file
6
api/core/callback_handler/entity/dataset_query.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DatasetQueryObj(BaseModel):
|
||||
dataset_id: str = None
|
||||
query: str = None
|
||||
9
api/core/callback_handler/entity/llm_message.py
Normal file
9
api/core/callback_handler/entity/llm_message.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMMessage(BaseModel):
|
||||
prompt: str = ''
|
||||
prompt_tokens: int = 0
|
||||
completion: str = ''
|
||||
completion_tokens: int = 0
|
||||
latency: float = 0.0
|
||||
38
api/core/callback_handler/index_tool_callback_handler.py
Normal file
38
api/core/callback_handler/index_tool_callback_handler.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from llama_index import Response
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
|
||||
class IndexToolCallbackHandler:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._response = None
|
||||
|
||||
@property
|
||||
def response(self) -> Response:
|
||||
return self._response
|
||||
|
||||
def on_tool_end(self, response: Response) -> None:
|
||||
"""Handle tool end."""
|
||||
self._response = response
|
||||
|
||||
|
||||
class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
|
||||
"""Callback handler for dataset tool."""
|
||||
|
||||
def __init__(self, dataset_id: str) -> None:
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
|
||||
def on_tool_end(self, response: Response) -> None:
|
||||
"""Handle tool end."""
|
||||
for node in response.source_nodes:
|
||||
index_node_id = node.node.doc_id
|
||||
|
||||
# add hit count to document segment
|
||||
db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self.dataset_id,
|
||||
DocumentSegment.index_node_id == index_node_id
|
||||
).update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
|
||||
147
api/core/callback_handler/llm_callback_handler.py
Normal file
147
api/core/callback_handler/llm_callback_handler.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
|
||||
|
||||
from core.callback_handler.entity.llm_message import LLMMessage
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
|
||||
|
||||
class LLMCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
||||
conversation_message_task: ConversationMessageTask):
|
||||
self.llm = llm
|
||||
self.llm_message = LLMMessage()
|
||||
self.start_at = None
|
||||
self.conversation_message_task = conversation_message_task
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
self.start_at = time.perf_counter()
|
||||
|
||||
if 'Chat' in serialized['name']:
|
||||
real_prompts = []
|
||||
messages = []
|
||||
for prompt in prompts:
|
||||
role, content = prompt.split(': ', maxsplit=1)
|
||||
if role == 'human':
|
||||
role = 'user'
|
||||
message = HumanMessage(content=content)
|
||||
elif role == 'ai':
|
||||
role = 'assistant'
|
||||
message = AIMessage(content=content)
|
||||
else:
|
||||
message = SystemMessage(content=content)
|
||||
|
||||
real_prompt = {
|
||||
"role": role,
|
||||
"text": content
|
||||
}
|
||||
real_prompts.append(real_prompt)
|
||||
messages.append(message)
|
||||
|
||||
self.llm_message.prompt = real_prompts
|
||||
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
|
||||
else:
|
||||
self.llm_message.prompt = [{
|
||||
"role": 'user',
|
||||
"text": prompts[0]
|
||||
}]
|
||||
|
||||
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
end_at = time.perf_counter()
|
||||
self.llm_message.latency = end_at - self.start_at
|
||||
|
||||
if not self.conversation_message_task.streaming:
|
||||
self.conversation_message_task.append_message_text(response.generations[0][0].text)
|
||||
self.llm_message.completion = response.generations[0][0].text
|
||||
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
|
||||
|
||||
self.conversation_message_task.save_message(self.llm_message)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
self.conversation_message_task.append_message_text(token)
|
||||
self.llm_message.completion += token
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
if isinstance(error, ConversationTaskStoppedException):
|
||||
if self.conversation_message_task.streaming:
|
||||
end_at = time.perf_counter()
|
||||
self.llm_message.latency = end_at - self.start_at
|
||||
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
|
||||
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
|
||||
else:
|
||||
logging.error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
color: Optional[str] = None,
|
||||
end: str = "",
|
||||
**kwargs: Optional[str],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_agent_finish(
|
||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
137
api/core/callback_handler/main_chain_gather_callback_handler.py
Normal file
137
api/core/callback_handler/main_chain_gather_callback_handler.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.callback_handler.entity.chain_result import ChainResult
|
||||
from core.constant import llm_constant
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
|
||||
|
||||
class MainChainGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
|
||||
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self._current_chain_result = None
|
||||
self._current_chain_message = None
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler(
|
||||
llm_constant.agent_model_name,
|
||||
conversation_message_task
|
||||
)
|
||||
|
||||
def clear_chain_results(self) -> None:
|
||||
self._current_chain_result = None
|
||||
self._current_chain_message = None
|
||||
self.agent_loop_gather_callback_handler.current_chain = None
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return True
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
if not self._current_chain_result:
|
||||
self._current_chain_result = ChainResult(
|
||||
type=serialized['name'],
|
||||
prompt=inputs,
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
|
||||
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
if self._current_chain_result and self._current_chain_result.status == 'chain_started':
|
||||
self._current_chain_result.status = 'chain_ended'
|
||||
self._current_chain_result.completion = outputs
|
||||
self._current_chain_result.completed = True
|
||||
self._current_chain_result.completed_at = time.perf_counter()
|
||||
|
||||
self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result)
|
||||
|
||||
self.clear_chain_results()
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.error(error)
|
||||
self.clear_chain_results()
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.error(error)
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
logging.error(error)
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
color: Optional[str] = None,
|
||||
end: str = "",
|
||||
**kwargs: Optional[str],
|
||||
) -> None:
|
||||
"""Run on additional input from chains and agents."""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
pass
|
||||
127
api/core/callback_handler/std_out_callback_handler.py
Normal file
127
api/core/callback_handler/std_out_callback_handler.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
|
||||
def __init__(self, color: Optional[str] = None) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self.color = color
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
print_text("\n[on_llm_start]\n", color='blue')
|
||||
|
||||
if 'Chat' in serialized['name']:
|
||||
for prompt in prompts:
|
||||
print_text(prompt + "\n", color='blue')
|
||||
else:
|
||||
print_text(prompts[0] + "\n", color='blue')
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
print_text("\n[on_llm_end]\nOutput: " + str(response.generations[0][0].text) + "\nllm_output: " + str(
|
||||
response.llm_output) + "\n", color='blue')
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
class_name = serialized["name"]
|
||||
print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
print_text("\n[on_chain_error]\nError: " + str(error) + "\n", color='pink')
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
print_text("\n[on_tool_start] " + str(serialized), color='yellow')
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run on agent action."""
|
||||
tool = action.tool
|
||||
tool_input = action.tool_input
|
||||
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
|
||||
thought = action.log[:action_name_position].strip() if action.log else ''
|
||||
|
||||
log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
|
||||
print_text("\n[on_agent_action]\n" + log + "\n", color='green')
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
print_text("\n[on_tool_end]\n", color='yellow')
|
||||
if observation_prefix:
|
||||
print_text(f"\n{observation_prefix}")
|
||||
print_text(output, color='yellow')
|
||||
if llm_prefix:
|
||||
print_text(f"\n{llm_prefix}")
|
||||
print_text("\n")
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='yellow')
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
color: Optional[str] = None,
|
||||
end: str = "",
|
||||
**kwargs: Optional[str],
|
||||
) -> None:
|
||||
"""Run when agent ends."""
|
||||
print_text("\n[on_text] " + text + "\n", color=color if color else self.color, end=end)
|
||||
|
||||
def on_agent_finish(
|
||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
|
||||
|
||||
|
||||
class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
|
||||
"""Callback handler for streaming. Only works with LLMs that support streaming."""
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
sys.stdout.write(token)
|
||||
sys.stdout.flush()
|
||||
Reference in New Issue
Block a user