feat: upgrade langchain (#430)

Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
John Wang
2023-06-25 16:49:14 +08:00
committed by GitHub
parent 1dee5de9b4
commit 3241e4015b
91 changed files with 2703 additions and 3153 deletions

View File

@@ -1,7 +1,5 @@
from typing import Optional
from langchain.callbacks import CallbackManager
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain
@@ -14,7 +12,7 @@ class ChainBuilder:
tool=tool,
input_key=kwargs.get('input_key', 'input'),
output_key=kwargs.get('output_key', 'tool_output'),
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
callbacks=[DifyStdOutCallbackHandler()]
)
@classmethod
@@ -27,7 +25,7 @@ class ChainBuilder:
sensitive_words=sensitive_words.split(","),
canned_response=tool_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output",
callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
callbacks=[DifyStdOutCallbackHandler()],
**kwargs
)

View File

@@ -1,15 +1,16 @@
"""Base classes for LLM-powered router chains."""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import root_validator
from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
from langchain.schema import BaseOutputParser, OutputParserException
from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -51,8 +52,9 @@ class LLMRouterChain(Chain):
raise ValueError
def _call(
self,
inputs: Dict[str, Any]
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
output = cast(
Dict[str, Any],

View File

@@ -1,11 +1,9 @@
from typing import Optional, List
from typing import Optional, List, cast
from langchain.callbacks import SharedCallbackManager, CallbackManager
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
@@ -18,6 +16,7 @@ from models.dataset import Dataset
class MainChainBuilder:
@classmethod
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
rest_tokens: int,
conversation_message_task: ConversationMessageTask):
first_input_key = "input"
final_output_key = "output"
@@ -30,6 +29,7 @@ class MainChainBuilder:
tool_chains, chains_output_key = cls.get_agent_chains(
tenant_id=tenant_id,
agent_mode=agent_mode,
rest_tokens=rest_tokens,
memory=memory,
conversation_message_task=conversation_message_task
)
@@ -42,9 +42,8 @@ class MainChainBuilder:
return None
for chain in chains:
# do not add handler into singleton callback manager
if not isinstance(chain.callback_manager, SharedCallbackManager):
chain.callback_manager.add_handler(chain_callback_handler)
chain = cast(Chain, chain)
chain.callbacks.append(chain_callback_handler)
# build main chain
overall_chain = SequentialChain(
@@ -57,7 +56,9 @@ class MainChainBuilder:
return overall_chain
@classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int,
memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask):
# agent mode
chains = []
@@ -93,7 +94,8 @@ class MainChainBuilder:
tenant_id=tenant_id,
datasets=datasets,
conversation_message_task=conversation_message_task,
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()]
)
chains.append(multi_dataset_router_chain)

View File

@@ -1,9 +1,9 @@
import math
from typing import Mapping, List, Dict, Any, Optional
from langchain import LLMChain, PromptTemplate, ConversationChain
from langchain.callbacks import CallbackManager
from langchain import PromptTemplate
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseLanguageModel
from pydantic import Extra
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
@@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_tool_builder import DatasetToolBuilder
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset
from core.tool.dataset_index_tool import DatasetTool
from models.dataset import Dataset, DatasetProcessRule
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
@@ -52,7 +53,7 @@ class MultiDatasetRouterChain(Chain):
router_chain: LLMRouterChain
"""Chain for deciding a destination chain and the input to it."""
dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
dataset_tools: Mapping[str, DatasetTool]
"""Map of name to candidate chains that inputs can be routed to."""
class Config:
@@ -79,41 +80,56 @@ class MultiDatasetRouterChain(Chain):
tenant_id: str,
datasets: List[Dataset],
conversation_message_task: ConversationMessageTask,
rest_tokens: int,
**kwargs: Any,
):
"""Convenience constructor for instantiating from destination prompts."""
llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=1024,
callback_manager=llm_callback_manager
callbacks=[DifyStdOutCallbackHandler()]
)
destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
else ('useful for when you want to answer queries about the ' + d.name))
for d in datasets]
destinations_str = "\n".join(destinations)
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
destinations=destinations_str
)
router_prompt = PromptTemplate(
template=router_template,
input_variables=["input"],
output_parser=RouterOutputParser(),
)
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
dataset_tools = {}
for dataset in datasets:
dataset_tool = DatasetToolBuilder.build_dataset_tool(
# fulfill description when it is empty
if dataset.available_document_count == 0 or dataset.available_document_count == 0:
continue
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
if k == 0:
continue
dataset_tool = DatasetTool(
name=f"dataset-{dataset.id}",
description=description,
k=k,
dataset=dataset,
response_mode='no_synthesizer', # "compact"
callback_handler=DatasetToolCallbackHandler(conversation_message_task)
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
)
if dataset_tool:
dataset_tools[dataset.id] = dataset_tool
dataset_tools[str(dataset.id)] = dataset_tool
return cls(
router_chain=router_chain,
@@ -121,9 +137,39 @@ class MultiDatasetRouterChain(Chain):
**kwargs,
)
@classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
return DEFAULT_K
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
else:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K:
return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space
return context_limit_tokens // segment_max_tokens
def _call(
self,
inputs: Dict[str, Any]
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if len(self.dataset_tools) == 0:
return {"text": ''}

View File

@@ -1,5 +1,6 @@
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
return self.canned_response
return text
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
output = self._check_sensitive_word(text)
return {self.output_key: output}

View File

@@ -1,5 +1,6 @@
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.tools import BaseTool
@@ -30,12 +31,20 @@ class ToolChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
input = inputs[self.input_key]
output = self.tool.run(input, self.verbose)
return {self.output_key: output}
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
input = inputs[self.input_key]
output = await self.tool.arun(input, self.verbose)