feat: server multi models support (#799)

Author: takatost
Date: 2023-08-12 00:57:00 +08:00
Committed by: GitHub
Parent: d8b712b325
Commit: 5fa2161b05
213 changed files with 10556 additions and 2579 deletions


@@ -1,13 +1,10 @@
 import logging
 
 from langchain import PromptTemplate
-from langchain.chat_models.base import BaseChatModel
-from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage
+from langchain.schema import OutputParserException
 
-from core.constant import llm_constant
-from core.llm.llm_builder import LLMBuilder
-from core.llm.streamable_open_ai import StreamableOpenAI
-from core.llm.token_calculator import TokenCalculator
-
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelKwargs
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
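
These import changes summarize the whole migration: the hard-wired LangChain chat-model path (LLMBuilder, StreamableOpenAI, TokenCalculator) gives way to a provider-agnostic factory. A minimal sketch of the new calling convention, assuming the tenant has a text-generation model configured; the prompt string is invented for illustration:

    # Resolve a model instance for the tenant's configured provider, then run
    # a plain prompt through it; response.content holds the generated text.
    model_instance = ModelFactory.get_text_generation_model(
        tenant_id=tenant_id,
        model_kwargs=ModelKwargs(max_tokens=50)
    )
    response = model_instance.run([PromptMessage(content="Name this conversation.")])
    print(response.content)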
@@ -15,9 +12,6 @@ from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTempla
 from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
     GENERATOR_QA_PROMPT
 
-# gpt-3.5-turbo works not well
-generate_base_model = 'text-davinci-003'
-
 
 class LLMGenerator:
     @classmethod
@@ -28,29 +22,35 @@ class LLMGenerator:
             query = query[:300] + "...[TRUNCATED]..." + query[-300:]
 
         prompt = prompt.format(query=query)
 
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+        model_instance = ModelFactory.get_text_generation_model(
             tenant_id=tenant_id,
-            model_name='gpt-3.5-turbo',
-            max_tokens=50,
-            timeout=600
+            model_kwargs=ModelKwargs(
+                max_tokens=50
+            )
         )
 
-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        prompts = [PromptMessage(content=prompt)]
+        response = model_instance.run(prompts)
+        answer = response.content
 
         return answer.strip()
 
     @classmethod
     def generate_conversation_summary(cls, tenant_id: str, messages):
         max_tokens = 200
-        model = 'gpt-3.5-turbo'
+
+        model_instance = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id,
+            model_kwargs=ModelKwargs(
+                max_tokens=max_tokens
+            )
+        )
 
         prompt = CONVERSATION_SUMMARY_PROMPT
         prompt_with_empty_context = prompt.format(context='')
-        prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)
-        rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1
+        prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
+        max_context_token_length = model_instance.model_rules.max_tokens.max
+        rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
 
         context = ''
         for message in messages:
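
The context budget now comes from the model's own rules (model_instance.model_rules.max_tokens.max) rather than the static llm_constant table, so it stays correct for whichever provider the tenant configured. A worked example of the rest_tokens arithmetic, with illustrative numbers that are not from the commit:

    # Hypothetical budget: a 4,096-token window, a 180-token prompt shell,
    # and 200 tokens reserved for the completion leave 3,715 for history.
    max_context_token_length = 4096  # model_instance.model_rules.max_tokens.max
    prompt_tokens = 180              # tokens in the prompt with empty context
    max_tokens = 200                 # completion budget
    rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1  # 3715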
@@ -68,25 +68,16 @@ class LLMGenerator:
             answer = message.answer
             message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
-            if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
+            if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
                 context += message_qa_text
 
         if not context:
             return '[message too long, no summary]'
 
         prompt = prompt.format(context=context)
 
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
-            tenant_id=tenant_id,
-            model_name=model,
-            max_tokens=max_tokens
-        )
-
-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        prompts = [PromptMessage(content=prompt)]
+        response = model_instance.run(prompts)
+        answer = response.content
 
         return answer.strip()
 
     @classmethod
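
The loop above greedily appends QA pairs only while they still fit within rest_tokens, serializing history into a Human/Assistant transcript. A hypothetical two-turn context, with invented strings to show the resulting format:

    # What the packed context looks like after two QA pairs (illustrative only).
    context = (
        "\n\nHuman:What is this app for?"
        "\n\nAssistant:It answers questions about your documents."
        "\n\nHuman:Which models can it use?"
        "\n\nAssistant:Any provider configured for the workspace."
    )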
@@ -94,16 +85,13 @@ class LLMGenerator:
         prompt = INTRODUCTION_GENERATE_PROMPT
         prompt = prompt.format(prompt=pre_prompt)
 
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
-            tenant_id=tenant_id,
-            model_name=generate_base_model,
+        model_instance = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id
         )
 
-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        prompts = [PromptMessage(content=prompt)]
+        response = model_instance.run(prompts)
+        answer = response.content
 
         return answer.strip()
 
     @classmethod
@@ -119,23 +107,19 @@ class LLMGenerator:
 
         _input = prompt.format_prompt(histories=histories)
 
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+        model_instance = ModelFactory.get_text_generation_model(
             tenant_id=tenant_id,
-            model_name='gpt-3.5-turbo',
-            temperature=0,
-            max_tokens=256
+            model_kwargs=ModelKwargs(
+                max_tokens=256,
+                temperature=0
+            )
         )
 
-        if isinstance(llm, BaseChatModel):
-            query = [HumanMessage(content=_input.to_string())]
-        else:
-            query = _input.to_string()
+        prompts = [PromptMessage(content=_input.to_string())]
 
         try:
-            output = llm(query)
-            if isinstance(output, BaseMessage):
-                output = output.content
-            questions = output_parser.parse(output)
+            output = model_instance.run(prompts)
+            questions = output_parser.parse(output.content)
         except Exception:
             logging.exception("Error generating suggested questions after answer")
             questions = []
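
The old dual path (message list for chat models, plain string otherwise) collapses into a single PromptMessage list, and the parser now reads output.content. Assuming the suggested-questions parser expects a JSON array of strings, which this diff does not show, the round trip would look like:

    # Hypothetical model output and parse result, under the JSON-array assumption.
    raw = '["How do I add a provider?", "Can I set a default model?"]'
    questions = output_parser.parse(raw)  # -> list of question strings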
@@ -160,21 +144,19 @@ class LLMGenerator:
 
         _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
 
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+        model_instance = ModelFactory.get_text_generation_model(
             tenant_id=tenant_id,
-            model_name=generate_base_model,
-            temperature=0,
-            max_tokens=512
+            model_kwargs=ModelKwargs(
+                max_tokens=512,
+                temperature=0
+            )
         )
 
-        if isinstance(llm, BaseChatModel):
-            query = [HumanMessage(content=_input.to_string())]
-        else:
-            query = _input.to_string()
+        prompts = [PromptMessage(content=_input.to_string())]
 
         try:
-            output = llm(query)
-            rule_config = output_parser.parse(output)
+            output = model_instance.run(prompts)
+            rule_config = output_parser.parse(output.content)
         except OutputParserException:
             raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
         except Exception:
@@ -188,25 +170,21 @@ class LLMGenerator:
         return rule_config
 
     @classmethod
-    async def generate_qa_document(cls, llm: StreamableOpenAI, query):
+    def generate_qa_document(cls, tenant_id: str, query):
         prompt = GENERATOR_QA_PROMPT
 
+        model_instance = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id,
+            model_kwargs=ModelKwargs(
+                max_tokens=2000
+            )
+        )
+
-        if isinstance(llm, BaseChatModel):
-            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
+        prompts = [
+            PromptMessage(content=prompt, type=MessageType.SYSTEM),
+            PromptMessage(content=query)
+        ]
 
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
-
-        return answer.strip()
-
-    @classmethod
-    def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
-        prompt = GENERATOR_QA_PROMPT
-
-        if isinstance(llm, BaseChatModel):
-            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
-
-        response = llm.generate([prompt])
-        answer = response.generations[0][0].text
+        response = model_instance.run(prompts)
+        answer = response.content
 
         return answer.strip()
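
The async generate_qa_document and the separate generate_qa_document_sync collapse into one synchronous method that takes a tenant id instead of a prebuilt StreamableOpenAI. A minimal usage sketch; the query text is invented for illustration:

    # Callers pass only the tenant and the source text; the factory resolves
    # the model, and the system prompt is typed via MessageType.SYSTEM above.
    answer = LLMGenerator.generate_qa_document(
        tenant_id=tenant_id,
        query="Paragraph to convert into question/answer pairs..."
    )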