feat: server multi models support (#799)

takatost
2023-08-12 00:57:00 +08:00
committed by GitHub
parent d8b712b325
commit 5fa2161b05
213 changed files with 10556 additions and 2579 deletions

View File

@@ -1,36 +0,0 @@
import os
from typing import Optional
import langchain
from flask import Flask
from pydantic import BaseModel
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.prompt.prompt_template import OneLineFormatter
class HostedOpenAICredential(BaseModel):
api_key: str
class HostedAnthropicCredential(BaseModel):
api_key: str
class HostedLLMCredentials(BaseModel):
openai: Optional[HostedOpenAICredential] = None
anthropic: Optional[HostedAnthropicCredential] = None
hosted_llm_credentials = HostedLLMCredentials()
def init_app(app: Flask):
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
if app.config.get("ANTHROPIC_API_KEY"):
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))

View File

@@ -1,20 +1,17 @@
from typing import cast, List
from typing import List
from langchain import OpenAI
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseMessage
from core.constant import llm_constant
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
class CalcTokenMixin:
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
llm = cast(ChatOpenAI, llm)
return llm.get_num_tokens_from_messages(messages)
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
return model_instance.get_num_tokens(to_prompt_messages(messages))
def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""
Get the remaining tokens available to the model after excluding the message tokens and the completion's max tokens
@@ -22,10 +19,9 @@ class CalcTokenMixin:
:param messages:
:return:
"""
llm = cast(ChatOpenAI, llm)
llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
completion_max_tokens = llm.max_tokens
used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
llm_max_tokens = model_instance.model_rules.max_tokens.max
completion_max_tokens = model_instance.model_kwargs.max_tokens
used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
return rest_tokens
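
The budget arithmetic above is worth pinning down with numbers. A minimal sketch with assumed values (a 4,096-token context window, a 512-token completion budget, 1,000 tokens already spent on the prompt; the real values come from model_rules and model_kwargs as shown above):

# Illustrative numbers only.
llm_max_tokens = 4096        # model_rules.max_tokens.max
completion_max_tokens = 512  # model_kwargs.max_tokens
used_tokens = 1000           # get_num_tokens_from_messages(model_instance, messages)
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
assert rest_tokens == 2584   # room left for extra context and history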

View File

@@ -4,9 +4,11 @@ from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -14,6 +16,12 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
A multi-dataset retrieval agent driven by a router.
"""
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""

View File

@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@@ -84,7 +85,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))

View File

@@ -3,20 +3,28 @@ from typing import cast, List
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.model_providers.models.llm.base import BaseLLM
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
model_instance: BaseLLM
def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages

View File

@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@@ -84,7 +85,7 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))

View File

@@ -0,0 +1,162 @@
import re
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The keywords "Thought", "Action", "Action Input" and "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
model_instance: BaseLLM
dataset_tools: Sequence[BaseTool]
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
Return whether the agent should be used.
Using ReACT just to decide whether an agent is needed is itself a costly model call,
so it is cheaper to skip the check and always reason with the agent.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
Given input, decide what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.dataset_tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
try:
return self.output_parser.parse(full_output)
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
unique_tool_names = set(tool.name for tool in tools)
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
dataset_tools=tools,
**kwargs,
)
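
The quadrupled braces in create_prompt deserve a note: str.format halves runs of doubled braces, so quadrupling lets a literal brace survive two formatting passes, matching the "{{{{" convention already used in FORMAT_INSTRUCTIONS above. A standalone sketch of just that mechanic (the schema value is invented for illustration):

import re

args = str({'query': {'type': 'string'}})
escaped = re.sub("}", "}}}}", re.sub("{", "{{{{", args))
# each format() pass halves brace runs: "{{{{" -> "{{" -> "{"
assert escaped.format().format() == args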

View File

@@ -14,7 +14,7 @@ from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.model_providers.models.llm.base import BaseLLM
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The keywords "Thought", "Action", "Action Input" and "Final Answer" must be expressed in English.
@@ -53,6 +53,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
@@ -89,7 +95,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

View File

@@ -3,7 +3,6 @@ import logging
from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool
@@ -13,14 +12,17 @@ from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
MULTI_FUNCTION_CALL = 'multi_function_call'
@@ -28,10 +30,9 @@ class PlanningStrategy(str, enum.Enum):
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
llm: BaseLanguageModel
model_instance: BaseLLM
tools: list[BaseTool]
summary_llm: BaseLanguageModel
dataset_llm: BaseLanguageModel
summary_model_instance: BaseLLM
memory: Optional[BaseChatMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
@@ -60,36 +61,49 @@ class AgentExecutor:
def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
llm=self.configuration.dataset_llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True
)
else:
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
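
After this change every strategy receives both the wrapped model_instance and its underlying LangChain client. A rough wiring sketch, assuming a ModelFactory lookup like the ones used elsewhere in this commit and that the executor is constructed with the configuration; tenant_id and tools are placeholders:

from core.model_providers.model_factory import ModelFactory

model_instance = ModelFactory.get_text_generation_model(tenant_id=tenant_id)
configuration = AgentConfiguration(
    strategy=PlanningStrategy.REACT_ROUTER,
    model_instance=model_instance,
    tools=tools,                            # only DatasetRetrieverTools are kept
    summary_model_instance=model_instance,  # a cheaper model could be used here
)
agent_executor = AgentExecutor(configuration)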

View File

@@ -10,15 +10,16 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.models.llm.base import BaseLLM
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.model_name = model_name
self.model_instance = model_instance
self.conversation_message_task = conversation_message_task
self._agent_loops = []
self._current_loop = None
@@ -152,7 +153,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)
@@ -183,7 +184,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)

View File

@@ -3,18 +3,20 @@ import time
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
from langchain.schema import LLMResult, BaseMessage
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
from core.model_providers.models.llm.base import BaseLLM
class LLMCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, llm: BaseLanguageModel,
def __init__(self, model_instance: BaseLLM,
conversation_message_task: ConversationMessageTask):
self.llm = llm
self.model_instance = model_instance
self.llm_message = LLMMessage()
self.start_at = None
self.conversation_message_task = conversation_message_task
@@ -46,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
})
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -58,7 +60,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
"text": prompts[0]
}]
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
end_at = time.perf_counter()
@@ -68,7 +70,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
self.conversation_message_task.save_message(self.llm_message)
@@ -89,7 +91,9 @@ class LLMCallbackHandler(BaseCallbackHandler):
if self.conversation_message_task.streaming:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
else:
logging.error(error)
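
Token accounting in the handler now runs through the provider-agnostic PromptMessage entities instead of LangChain's own counters. A condensed sketch of the two conversions used above (model_instance stands for any BaseLLM from this commit):

from typing import List
from langchain.schema import BaseMessage
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages

def count_prompt_tokens(model_instance, messages: List[BaseMessage]) -> int:
    # chat path: convert LangChain messages first, as the handler does above
    return model_instance.get_num_tokens(to_prompt_messages(messages))

def count_completion_tokens(model_instance, completion: str) -> int:
    # completion path: wrap the raw text, as in on_llm_end above
    return model_instance.get_num_tokens([PromptMessage(content=completion)])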

View File

@@ -5,9 +5,7 @@ from typing import Any, Dict, Union
from langchain.callbacks.base import BaseCallbackHandler
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant
from core.conversation_message_task import ConversationMessageTask

View File

@@ -2,27 +2,19 @@ import logging
import re
from typing import Optional, List, Union, Tuple
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, HumanMessage
from langchain.schema import BaseMessage
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.fake import FakeLLM
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
@@ -51,12 +43,10 @@ class Completion:
inputs = conversation.inputs
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
app_model_config=app_model_config,
query=query,
inputs=inputs
model_config=app_model_config.model_dict,
streaming=streaming
)
conversation_message_task = ConversationMessageTask(
@@ -68,10 +58,17 @@ class Completion:
is_override=is_override,
inputs=inputs,
query=query,
streaming=streaming
streaming=streaming,
model_instance=final_model_instance
)
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
model_instance=final_model_instance,
app_model_config=app_model_config,
query=query,
inputs=inputs
)
# init orchestrator rule parser
orchestrator_rule_parser = OrchestratorRuleParser(
@@ -80,6 +77,7 @@ class Completion:
)
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
if sensitive_word_avoidance_chain:
query = sensitive_word_avoidance_chain.run(query)
@@ -102,15 +100,14 @@ class Completion:
# run the final llm
try:
cls.run_final_llm(
tenant_id=app.tenant_id,
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task,
memory=memory,
streaming=streaming
memory=memory
)
except ConversationTaskStoppedException:
return
@@ -121,31 +118,20 @@ class Completion:
return
@classmethod
def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy != PlanningStrategy.ROUTER:
final_llm = FakeLLM(response=agent_execute_result.output,
origin_llm=agent_execute_result.configuration.llm,
streaming=streaming)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
response = final_llm.generate([[HumanMessage(content=query)]])
return response
final_llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict,
streaming=streaming
)
fake_response = agent_execute_result.output
# get llm prompt
prompt, stop_words = cls.get_main_llm_prompt(
prompt_messages, stop_words = cls.get_main_llm_prompt(
mode=mode,
llm=final_llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
@@ -154,25 +140,26 @@ class Completion:
memory=memory
)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=final_llm,
model=app_model_config.model_dict,
prompt=prompt,
mode=mode
model_instance=model_instance,
prompt_messages=prompt_messages,
)
response = final_llm.generate([prompt], stop_words)
response = model_instance.run(
messages=prompt_messages,
stop=stop_words,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
return response
@classmethod
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
def get_main_llm_prompt(cls, mode: str, model: dict,
pre_prompt: str, query: str, inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
Tuple[List[PromptMessage], Optional[List[str]]]:
if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
@@ -200,11 +187,7 @@ And answer according to the language of the user's question.
**prompt_inputs
)
if isinstance(llm, BaseChatModel):
# use chat llm as completion model
return [HumanMessage(content=prompt_content)], None
else:
return prompt_content, None
return [PromptMessage(content=prompt_content)], None
else:
messages: List[BaseMessage] = []
@@ -249,12 +232,14 @@ And answer according to the language of the user's question.
inputs=human_inputs
)
curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
model_name = model['name']
max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = llm_constant.max_context_token_length[model_name] \
- max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
if memory.model_instance.model_rules.max_tokens.max:
curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += "Here is the chat histories between human and assistant, " \
@@ -274,17 +259,7 @@ And answer according to the language of the user's question.
for message in messages:
message.content = re.sub(r'<\|.*?\|>', '', message.content)
return messages, ['\nHuman:', '</histories>']
@classmethod
def get_llm_callbacks(cls, llm: BaseLanguageModel,
streaming: bool,
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
if streaming:
return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
else:
return [llm_callback_handler, DifyStdOutCallbackHandler()]
return to_prompt_messages(messages), ['\nHuman:', '</histories>']
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
@@ -300,15 +275,15 @@ And answer according to the language of the user's question.
conversation: Conversation,
**kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
# only for calc token in memory
memory_llm = LLMBuilder.to_llm_from_model(
memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=tenant_id,
model=app_model_config.model_dict
model_config=app_model_config.model_dict
)
# use llm config from conversation
memory = ReadOnlyConversationTokenDBBufferSharedMemory(
conversation=conversation,
llm=memory_llm,
model_instance=memory_model_instance,
max_token_limit=kwargs.get("max_token_limit", 2048),
memory_key=kwargs.get("memory_key", "chat_history"),
return_messages=kwargs.get("return_messages", True),
@@ -320,21 +295,20 @@ And answer according to the language of the user's question.
return memory
@classmethod
def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
query: str, inputs: dict) -> int:
llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict
)
model_limited_tokens = model_instance.model_rules.max_tokens.max
max_tokens = model_instance.get_model_kwargs().max_tokens
model_name = app_model_config.model_dict.get("name")
model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
if model_limited_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
# get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt(
prompt_messages, _ = cls.get_main_llm_prompt(
mode=mode,
llm=llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
@@ -343,9 +317,7 @@ And answer according to the language of the user's question.
memory=None
)
prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
else llm.get_num_tokens_from_messages(prompt)
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
@@ -354,36 +326,40 @@ And answer according to the language of the user's question.
return rest_tokens
@classmethod
def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
prompt: Union[str, List[BaseMessage]], mode: str):
def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_name = model.get("name")
model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = model.get("completion_params").get('max_tokens')
model_limited_tokens = model_instance.model_rules.max_tokens.max
max_tokens = model_instance.get_model_kwargs().max_tokens
if mode == 'completion' and isinstance(final_llm, BaseLLM):
prompt_tokens = final_llm.get_num_tokens(prompt)
else:
prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
if model_limited_tokens is None:
return
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
final_llm.max_tokens = max_tokens
# update model instance max tokens
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
llm = LLMBuilder.to_llm_from_model(
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
model=app_model_config.model_dict,
model_config=app_model_config.model_dict,
streaming=streaming
)
# get llm prompt
original_prompt, _ = cls.get_main_llm_prompt(
old_prompt_messages, _ = cls.get_main_llm_prompt(
mode="completion",
llm=llm,
model=app_model_config.model_dict,
pre_prompt=pre_prompt,
query=message.query,
@@ -395,10 +371,9 @@ And answer according to the language of the user's question.
original_completion = message.answer.strip()
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
prompt_messages = [PromptMessage(content=prompt)]
conversation_message_task = ConversationMessageTask(
task_id=task_id,
@@ -408,16 +383,16 @@ And answer according to the language of the user's question.
inputs=message.inputs,
query=message.query,
is_override=True if message.override_model_configs else False,
streaming=streaming
streaming=streaming,
model_instance=final_model_instance
)
llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=llm,
model=app_model_config.model_dict,
prompt=prompt,
mode='completion'
model_instance=final_model_instance,
prompt_messages=prompt_messages
)
llm.generate([prompt])
final_model_instance.run(
messages=prompt_messages,
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
)
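
recale_llm_max_tokens clamps the completion budget so prompt plus completion fit within the model window, with a floor of 16 tokens. Worked numbers (assumed for illustration):

model_limited_tokens = 4096  # model_rules.max_tokens.max
max_tokens = 1024            # configured completion budget
prompt_tokens = 3900         # model_instance.get_num_tokens(prompt_messages)

if prompt_tokens + max_tokens > model_limited_tokens:
    max_tokens = max(model_limited_tokens - prompt_tokens, 16)
assert max_tokens == 196     # 4096 - 3900; the 16-token floor never triggered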

View File

@@ -1,109 +0,0 @@
from _decimal import Decimal
models = {
'claude-instant-1': 'anthropic', # 100,000 tokens
'claude-2': 'anthropic', # 100,000 tokens
'gpt-4': 'openai', # 8,192 tokens
'gpt-4-32k': 'openai', # 32,768 tokens
'gpt-3.5-turbo': 'openai', # 4,096 tokens
'gpt-3.5-turbo-16k': 'openai', # 16,384 tokens
'text-davinci-003': 'openai', # 4,097 tokens
'text-davinci-002': 'openai', # 4,097 tokens
'text-curie-001': 'openai', # 2,049 tokens
'text-babbage-001': 'openai', # 2,049 tokens
'text-ada-001': 'openai', # 2,049 tokens
'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
'whisper-1': 'openai'
}
max_context_token_length = {
'claude-instant-1': 100000,
'claude-2': 100000,
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
'text-davinci-002': 4097,
'text-curie-001': 2049,
'text-babbage-001': 2049,
'text-ada-001': 2049,
'text-embedding-ada-002': 8191,
}
models_by_mode = {
'chat': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
],
'completion': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
'text-davinci-003', # 4,097 tokens
'text-davinci-002', # 4,097 tokens
'text-curie-001', # 2,049 tokens
'text-babbage-001', # 2,049 tokens
'text-ada-001' # 2,049 tokens
],
'embedding': [
'text-embedding-ada-002' # 8191 tokens, 1536 dimensions
]
}
model_currency = 'USD'
model_prices = {
'claude-instant-1': {
'prompt': Decimal('0.00163'),
'completion': Decimal('0.00551'),
},
'claude-2': {
'prompt': Decimal('0.01102'),
'completion': Decimal('0.03268'),
},
'gpt-4': {
'prompt': Decimal('0.03'),
'completion': Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': Decimal('0.06'),
'completion': Decimal('0.12')
},
'gpt-3.5-turbo': {
'prompt': Decimal('0.0015'),
'completion': Decimal('0.002')
},
'gpt-3.5-turbo-16k': {
'prompt': Decimal('0.003'),
'completion': Decimal('0.004')
},
'text-davinci-003': {
'prompt': Decimal('0.02'),
'completion': Decimal('0.02')
},
'text-curie-001': {
'prompt': Decimal('0.002'),
'completion': Decimal('0.002')
},
'text-babbage-001': {
'prompt': Decimal('0.0005'),
'completion': Decimal('0.0005')
},
'text-ada-001': {
'prompt': Decimal('0.0004'),
'completion': Decimal('0.0004')
},
'text-embedding-ada-002': {
'usage': Decimal('0.0001'),
}
}
agent_model_name = 'text-davinci-003'
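
For reference, the deleted tables priced tokens in USD per thousand (the figures appear to match the providers' published per-1K rates at the time); the costing they supported looked roughly like the sketch below, and it now lives behind get_token_price/get_currency on each model class:

from decimal import Decimal

prompt_price = Decimal('0.0015')     # gpt-3.5-turbo, USD per 1K prompt tokens
completion_price = Decimal('0.002')  # gpt-3.5-turbo, USD per 1K completion tokens

def total_price(prompt_tokens: int, completion_tokens: int) -> Decimal:
    # illustrative reconstruction of the old constant-table costing
    return (Decimal(prompt_tokens) / 1000 * prompt_price
            + Decimal(completion_tokens) / 1000 * completion_price)

assert total_price(1000, 500) == Decimal('0.0025')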

View File

@@ -6,9 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.provider.llm_provider_service import LLMProviderService
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from events.message_event import message_was_created
@@ -16,12 +16,11 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
from models.provider import ProviderType, Provider
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
self.task_id = task_id
@@ -38,9 +37,12 @@ class ConversationMessageTask:
self.conversation = conversation
self.is_new_conversation = False
self.model_instance = model_instance
self.message = None
self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider')
self.model_name = self.model_dict.get('name')
self.mode = app.mode
@@ -56,9 +58,6 @@ class ConversationMessageTask:
)
def init(self):
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
self.model_dict['provider'] = provider_name
override_model_configs = None
if self.is_override:
override_model_configs = {
@@ -89,15 +88,19 @@ class ConversationMessageTask:
if self.app_model_config.pre_prompt:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_provider_name=self.provider_name,
model_name=self.model_name
)
system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
if not self.conversation:
self.is_new_conversation = True
self.conversation = Conversation(
app_id=self.app_model_config.app_id,
app_model_config_id=self.app_model_config.id,
model_provider=self.model_dict.get('provider'),
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=self.mode,
@@ -117,7 +120,7 @@ class ConversationMessageTask:
self.message = Message(
app_id=self.app_model_config.app_id,
model_provider=self.model_dict.get('provider'),
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=self.conversation.id,
@@ -131,7 +134,7 @@ class ConversationMessageTask:
answer_unit_price=0,
provider_response_latency=0,
total_price=0,
currency=llm_constant.model_currency,
currency=self.model_instance.get_currency(),
from_source=('console' if isinstance(self.user, Account) else 'api'),
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
@@ -145,12 +148,10 @@ class ConversationMessageTask:
self._pub_handler.pub_text(text)
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
model_name = self.app_model_config.model_dict.get('name')
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = llm_constant.model_prices[model_name]['prompt']
answer_unit_price = llm_constant.model_prices[model_name]['completion']
message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
@@ -163,8 +164,6 @@ class ConversationMessageTask:
self.message.provider_response_latency = llm_message.latency
self.message.total_price = total_price
self.update_provider_quota()
db.session.commit()
message_was_created.send(
@@ -176,20 +175,6 @@ class ConversationMessageTask:
if not by_stopped:
self.end()
def update_provider_quota(self):
llm_provider_service = LLMProviderService(
tenant_id=self.app.tenant_id,
provider_name=self.message.model_provider,
)
provider = llm_provider_service.get_provider_db_record()
if provider and provider.provider_type == ProviderType.SYSTEM.value:
db.session.query(Provider).filter(
Provider.tenant_id == self.app.tenant_id,
Provider.provider_name == provider.provider_name,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + 1})
def init_chain(self, chain_result: ChainResult):
message_chain = MessageChain(
message_id=self.message.id,
@@ -229,10 +214,10 @@ class ConversationMessageTask:
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
agent_message_unit_price = agent_model_instance.get_token_price(1, MessageType.HUMAN)
agent_answer_unit_price = agent_model_instance.get_token_price(1, MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
@@ -253,7 +238,7 @@ class ConversationMessageTask:
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = llm_constant.model_currency
message_agent_thought.currency = agent_model_instance.get_currency()
db.session.flush()
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
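
The accounting changes in this class all follow the same substitution: per-model constants become methods on the BaseLLM instance. A small sketch of the lookups used above (model_instance is the task's BaseLLM; calc_total_price itself is not shown in this hunk):

from core.model_providers.models.entity.message import MessageType

def unit_prices(model_instance):
    # mirrors save_message above: get_token_price is asked for the price of
    # one token per message type; get_currency replaces llm_constant.model_currency
    return (model_instance.get_token_price(1, MessageType.HUMAN),
            model_instance.get_token_price(1, MessageType.ASSISTANT),
            model_instance.get_currency())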

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence
from langchain.schema import Document
from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
@@ -13,12 +13,10 @@ class DatesetDocumentStore:
self,
dataset: Dataset,
user_id: str,
embedding_model_name: str,
document_id: Optional[str] = None,
):
self._dataset = dataset
self._user_id = user_id
self._embedding_model_name = embedding_model_name
self._document_id = document_id
@classmethod
@@ -39,10 +37,6 @@ class DatesetDocumentStore:
def user_id(self) -> Any:
return self._user_id
@property
def embedding_model_name(self) -> Any:
return self._embedding_model_name
@property
def docs(self) -> Dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter(
@@ -74,6 +68,10 @@ class DatesetDocumentStore:
if max_position is None:
max_position = 0
embedding_model = ModelFactory.get_embedding_model(
tenant_id=self._dataset.tenant_id
)
for doc in docs:
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")
@@ -88,7 +86,7 @@ class DatesetDocumentStore:
)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
tokens = embedding_model.get_num_tokens(doc.page_content)
if not segment_document:
max_position += 1

View File

@@ -4,14 +4,14 @@ from typing import List
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from core.model_providers.models.embedding.base import BaseEmbedding
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
class CacheEmbedding(Embeddings):
def __init__(self, embeddings: Embeddings):
def __init__(self, embeddings: BaseEmbedding):
self._embeddings = embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -21,48 +21,54 @@ class CacheEmbedding(Embeddings):
embedding_queue_texts = []
for text in texts:
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
if embedding:
text_embeddings.append(embedding.get_embedding())
else:
embedding_queue_texts.append(text)
embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
if embedding_queue_texts:
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
finally:
i += 1
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
text_embeddings.extend(embedding_results)
i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
try:
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
finally:
i += 1
text_embeddings.extend(embedding_results)
return text_embeddings
@handle_openai_exceptions
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
if embedding:
return embedding.get_embedding()
embedding_results = self._embeddings.embed_query(text)
try:
embedding_results = self._embeddings.client.embed_query(text)
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
try:
embedding = Embedding(hash=hash)
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results)
db.session.add(embedding)
db.session.commit()
@@ -72,3 +78,5 @@ class CacheEmbedding(Embeddings):
logging.exception('Failed to add embedding to db')
return embedding_results
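
Cache entries are now keyed by (model_name, hash), so texts embedded by different models no longer collide. A usage sketch, assuming the ModelFactory wiring shown elsewhere in this commit; tenant_id is a placeholder:

from core.model_providers.model_factory import ModelFactory

embedding_model = ModelFactory.get_embedding_model(tenant_id=tenant_id)
embeddings = CacheEmbedding(embedding_model)

vectors = embeddings.embed_documents(["first text", "second text"])  # misses hit the model, then the cache
query_vector = embeddings.embed_query("a question")                  # cached by text hash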

View File

@@ -1,13 +1,10 @@
import logging
from langchain import PromptTemplate
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage
from langchain.schema import OutputParserException
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.llm.token_calculator import TokenCalculator
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -15,9 +12,6 @@ from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTempla
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
GENERATOR_QA_PROMPT
# gpt-3.5-turbo does not work well
generate_base_model = 'text-davinci-003'
class LLMGenerator:
@classmethod
@@ -28,29 +22,35 @@ class LLMGenerator:
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
prompt = prompt.format(query=query)
llm: StreamableOpenAI = LLMBuilder.to_llm(
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=50,
timeout=600
model_kwargs=ModelKwargs(
max_tokens=50
)
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200
model = 'gpt-3.5-turbo'
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=max_tokens
)
)
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)
rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
context = ''
for message in messages:
@@ -68,25 +68,16 @@ class LLMGenerator:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
context += message_qa_text
if not context:
return '[message too long, no summary]'
prompt = prompt.format(context=context)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=model,
max_tokens=max_tokens
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
@@ -94,16 +85,13 @@ class LLMGenerator:
prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
@@ -119,23 +107,19 @@ class LLMGenerator:
_input = prompt.format_prompt(histories=histories)
llm: StreamableOpenAI = LLMBuilder.to_llm(
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=256
model_kwargs=ModelKwargs(
max_tokens=256,
temperature=0
)
)
if isinstance(llm, BaseChatModel):
query = [HumanMessage(content=_input.to_string())]
else:
query = _input.to_string()
prompts = [PromptMessage(content=_input.to_string())]
try:
output = llm(query)
if isinstance(output, BaseMessage):
output = output.content
questions = output_parser.parse(output)
output = model_instance.run(prompts)
questions = output_parser.parse(output.content)
except Exception:
logging.exception("Error generating suggested questions after answer")
questions = []
@@ -160,21 +144,19 @@ class LLMGenerator:
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
llm: StreamableOpenAI = LLMBuilder.to_llm(
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_name=generate_base_model,
temperature=0,
max_tokens=512
model_kwargs=ModelKwargs(
max_tokens=512,
temperature=0
)
)
if isinstance(llm, BaseChatModel):
query = [HumanMessage(content=_input.to_string())]
else:
query = _input.to_string()
prompts = [PromptMessage(content=_input.to_string())]
try:
output = llm(query)
rule_config = output_parser.parse(output)
output = model_instance.run(prompts)
rule_config = output_parser.parse(output.content)
except OutputParserException:
raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
except Exception:
@@ -188,25 +170,21 @@ class LLMGenerator:
return rule_config
@classmethod
async def generate_qa_document(cls, llm: StreamableOpenAI, query):
def generate_qa_document(cls, tenant_id: str, query):
prompt = GENERATOR_QA_PROMPT
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=2000
)
)
if isinstance(llm, BaseChatModel):
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
prompts = [
PromptMessage(content=prompt, type=MessageType.SYSTEM),
PromptMessage(content=query)
]
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip()
@classmethod
def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
prompt = GENERATOR_QA_PROMPT
if isinstance(llm, BaseChatModel):
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
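The hunks above all apply the same migration; condensed, with parameter values taken from the diffs (a sketch for orientation, not an additional change in this commit):

# Before: each call site wired up a provider-specific LLM itself
llm = LLMBuilder.to_llm(tenant_id=tenant_id, model_name='gpt-3.5-turbo', max_tokens=256)
response = llm.generate([[HumanMessage(content=prompt)]])
answer = response.generations[0][0].text

# After: a provider-agnostic model instance from the factory
model_instance = ModelFactory.get_text_generation_model(
    tenant_id=tenant_id,
    model_kwargs=ModelKwargs(max_tokens=256, temperature=0)
)
response = model_instance.run([PromptMessage(content=prompt)])
answer = response.content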

View File

@@ -0,0 +1,20 @@
import base64
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
def obfuscated_token(token: str):
return token[:6] + '*' * (len(token) - 8) + token[-2:]
def encrypt_token(tenant_id: str, token: str):
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(tenant_id: str, token: str):
return rsa.decrypt(base64.b64decode(token), tenant_id)
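A quick sketch of how the new helpers behave (token value illustrative; obfuscated_token assumes the token has at least 8 characters, otherwise the slice arithmetic degenerates):

token = 'sk-abcdefghijklmnop'                             # 19 characters
masked = token[:6] + '*' * (len(token) - 8) + token[-2:]  # what obfuscated_token computes
print(masked)                                             # sk-abc***********op (first 6 and last 2 kept)

# encrypt_token/decrypt_token round-trip through the tenant's RSA key pair:
ciphertext = encrypt_token(tenant_id, token)              # base64-encoded RSA ciphertext
assert decrypt_token(tenant_id, ciphertext) == token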

View File

@@ -1,10 +1,9 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from core.model_providers.model_factory import ModelFactory
from models.dataset import Dataset
@@ -15,16 +14,11 @@ class IndexBuilder:
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
max_retries=1,
**model_credentials
))
embeddings = CacheEmbedding(embedding_model)
return VectorIndex(
dataset=dataset,

View File

@@ -1,4 +1,3 @@
import concurrent
import datetime
import json
import logging
@@ -6,7 +5,6 @@ import re
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, cast
from flask_login import current_user
@@ -18,11 +16,10 @@ from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.llm.error import ProviderTokenNotInitError
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import MessageType
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@@ -35,9 +32,8 @@ from models.source import DataSourceBinding
class IndexingRunner:
def __init__(self, embedding_model_name: str = "text-embedding-ada-002"):
def __init__(self):
self.storage = storage
self.embedding_model_name = embedding_model_name
def run(self, dataset_documents: List[DatasetDocument]):
"""Run the indexing process."""
@@ -227,11 +223,15 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
doc_form: str = None) -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
tokens = 0
preview_texts = []
total_segments = 0
@@ -253,44 +253,49 @@ class IndexingRunner:
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
self.filter_string(document.page_content))
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if doc_form and doc_form == 'qa_model':
if len(preview_texts) > 0:
# qa model document
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=current_user.current_tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
}
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
"currency": embedding_model.get_currency(),
"preview": preview_texts
}
def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
# load data from notion
tokens = 0
preview_texts = []
@@ -336,31 +341,31 @@ class IndexingRunner:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
tokens += embedding_model.get_num_tokens(document.page_content)
text_generation_model = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if doc_form and doc_form == 'qa_model':
if len(preview_texts) > 0:
# qa model document
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=current_user.current_tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
}
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
"currency": embedding_model.get_currency(),
"preview": preview_texts
}
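Both estimate paths use the same fixed heuristics for QA-model documents: 20 derived segments and 2,000 completion tokens per source segment. With illustrative numbers:

total_segments = 10
estimated_segments = total_segments * 20    # 200 QA segments reported
estimated_tokens = total_segments * 2000    # 20,000 completion tokens
# priced via text_generation_model.get_token_price(estimated_tokens, MessageType.HUMAN)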
@@ -459,7 +464,6 @@ class IndexingRunner:
doc_store = DatesetDocumentStore(
dataset=dataset,
user_id=dataset_document.created_by,
embedding_model_name=self.embedding_model_name,
document_id=dataset_document.id
)
@@ -513,17 +517,12 @@ class IndexingRunner:
all_documents.extend(split_documents)
# processing qa document
if document_form == 'qa_model':
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=2000
)
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i:i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents})
'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents})
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
@@ -531,13 +530,13 @@ class IndexingRunner:
return all_qa_documents
return all_documents
def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents):
def format_qa_document(self, tenant_id: str, document_node, all_qa_documents):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
return
try:
# qa model document
response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
@@ -638,6 +637,10 @@ class IndexingRunner:
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
)
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
@@ -648,7 +651,7 @@ class IndexingRunner:
chunk_documents = documents[i:i + chunk_size]
tokens += sum(
TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
embedding_model.get_num_tokens(document.page_content)
for document in chunk_documents
)

View File

@@ -1,148 +0,0 @@
from typing import Union, Optional, List
from langchain.callbacks.base import BaseCallbackHandler
from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType, ProviderName
class LLMBuilder:
"""
This class handles the following logic:
1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
OPENAI_API_TYPE=azure
OPENAI_API_VERSION=2022-12-01
OPENAI_API_BASE=https://your-resource-name.openai.azure.com
OPENAI_API_KEY=<your Azure OpenAI API key>
3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot.
7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting.
"""
@classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
provider = cls.get_default_provider(tenant_id, model_name)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
llm_cls = None
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
if provider == ProviderName.OPENAI.value:
llm_cls = StreamableChatOpenAI
elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureChatOpenAI
elif provider == ProviderName.ANTHROPIC.value:
llm_cls = StreamableChatAnthropic
elif mode == 'completion':
if provider == ProviderName.OPENAI.value:
llm_cls = StreamableOpenAI
elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureOpenAI
if not llm_cls:
raise ValueError(f"model name {model_name} is not supported.")
model_kwargs = {
'model_name': model_name,
'temperature': kwargs.get('temperature', 0),
'max_tokens': kwargs.get('max_tokens', 256),
'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0),
'callbacks': kwargs.get('callbacks', None),
'streaming': kwargs.get('streaming', False),
}
model_kwargs.update(model_credentials)
model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
return llm_cls(**model_kwargs)
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
model_name = model.get("name")
completion_params = model.get("completion_params", {})
return cls.to_llm(
tenant_id=tenant_id,
model_name=model_name,
temperature=completion_params.get('temperature', 0),
max_tokens=completion_params.get('max_tokens', 256),
top_p=completion_params.get('top_p', 0),
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1),
streaming=streaming,
callbacks=callbacks
)
@classmethod
def get_mode_by_model(cls, model_name: str) -> str:
if not model_name:
raise ValueError(f"empty model name is not supported.")
if model_name in llm_constant.models_by_mode['chat']:
return "chat"
elif model_name in llm_constant.models_by_mode['completion']:
return "completion"
else:
raise ValueError(f"model name {model_name} is not supported.")
@classmethod
def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
"""
Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
Raises an exception if the model_name is not found or if the provider is not found.
"""
if not model_name:
raise Exception('model name not found')
#
# if model_name not in llm_constant.models:
# raise Exception('model {} not found'.format(model_name))
# model_provider = llm_constant.models[model_name]
provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
return provider_service.get_credentials(model_name)
@classmethod
def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
provider_name = llm_constant.models[model_name]
if provider_name == 'openai':
# get the default provider (openai / azure_openai) for the tenant
openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
provider = openai_provider
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
provider = azure_openai_provider
elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
provider = openai_provider
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
provider = azure_openai_provider
if not provider:
raise ProviderTokenNotInitError(
f"No valid {provider_name} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
provider_name = provider.provider_name
return provider_name

View File

@@ -1,15 +0,0 @@
import openai
from models.provider import ProviderName
class Moderation:
def __init__(self, provider: str, api_key: str):
self.provider = provider
self.api_key = api_key
if self.provider == ProviderName.OPENAI.value:
self.client = openai.Moderation
def moderate(self, text):
return self.client.create(input=text, api_key=self.api_key)

View File

@@ -1,138 +0,0 @@
import json
import logging
from typing import Optional, Union
import anthropic
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName, ProviderType
class AnthropicProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
return [
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
},
{
'id': 'claude-2',
'name': 'claude-2',
},
]
def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.get_provider_api_key(model_id=model_id)
def get_provider_name(self):
return ProviderName.ANTHROPIC
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'anthropic_api_key': ''
}
if obfuscated:
if not config.get('anthropic_api_key'):
config = {
'anthropic_api_key': ''
}
config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
return config
return config
def get_encrypted_token(self, config: Union[dict | str]):
"""
Returns the encrypted token.
"""
return json.dumps({
'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
})
def get_decrypted_token(self, token: str):
"""
Returns the decrypted token.
"""
config = json.loads(token)
config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
return config
def get_token_type(self):
return dict
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
# check OpenAI / Azure OpenAI credential is valid
openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider:
provider = openai_provider
elif azure_openai_provider:
provider = azure_openai_provider
if not provider:
raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
if quota_used >= quota_limit:
raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
f"please configure OpenAI or Azure OpenAI provider first.")
try:
if not isinstance(config, dict):
raise ValueError('Config must be an object.')
if 'anthropic_api_key' not in config:
raise ValueError('anthropic_api_key must be provided.')
chat_llm = ChatAnthropic(
model='claude-instant-1',
anthropic_api_key=config['anthropic_api_key'],
max_tokens_to_sample=10,
temperature=0,
default_request_timeout=60
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except anthropic.APIConnectionError as ex:
raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
f"{ex.body['error']['type']}: {ex.body['error']['message']}")
except Exception as ex:
logging.exception('Anthropic config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}

View File

@@ -1,145 +0,0 @@
import json
import logging
from typing import Optional, Union
import openai
import requests
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
return []
def check_embedding_model(self, credentials: Optional[dict] = None):
credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
try:
result = openai.Embedding.create(input=['test'],
engine='text-embedding-ada-002',
timeout=60,
api_key=str(credentials.get('openai_api_key')),
api_base=str(credentials.get('openai_api_base')),
api_type='azure',
api_version=str(credentials.get('openai_api_version')))["data"][0][
"embedding"]
except openai.error.AuthenticationError as e:
raise AzureAuthenticationError(str(e))
except openai.error.APIConnectionError as e:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI, please check your API Base Endpoint. The format is `https://xxx.openai.azure.com/`')
except openai.error.InvalidRequestError as e:
if e.http_status == 404:
raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
"deployment name is exists in Azure AI")
else:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
except openai.error.OpenAIError as e:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
if not isinstance(result, list):
raise AzureRequestFailedError('Failed to request Azure OpenAI.')
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the API credentials for Azure OpenAI as a dictionary.
"""
config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure'
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
if model_id == 'text-embedding-ada-002':
config['deployment'] = model_id.replace('.', '') if model_id else None
config['chunk_size'] = 16
else:
config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config
def get_provider_name(self):
return ProviderName.AZURE_OPENAI
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'openai_api_type': 'azure',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': '',
'openai_api_key': ''
}
if obfuscated:
if not config.get('openai_api_key'):
config = {
'openai_api_type': 'azure',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': '',
'openai_api_key': ''
}
config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
return config
return config
def get_token_type(self):
return dict
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
try:
if not isinstance(config, dict):
raise ValueError('Config must be an object.')
if 'openai_api_version' not in config:
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
self.check_embedding_model(credentials=config)
except ValidateFailedError as e:
raise e
except AzureAuthenticationError:
raise ValidateFailedError('Validation failed, please check your API Key.')
except AzureRequestFailedError as ex:
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
except Exception as ex:
logging.exception('Azure OpenAI Credentials validation failed')
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
def get_encrypted_token(self, config: Union[dict | str]):
"""
Returns the encrypted token.
"""
return json.dumps({
'openai_api_type': 'azure',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': config['openai_api_base'],
'openai_api_key': self.encrypt_token(config['openai_api_key'])
})
def get_decrypted_token(self, token: str):
"""
Returns the decrypted token.
"""
config = json.loads(token)
config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
return config
class AzureAuthenticationError(Exception):
pass
class AzureRequestFailedError(Exception):
pass
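For reference, the credentials dict that get_credentials assembles looks roughly like this (values illustrative; note the '.'-stripping, so a 'gpt-3.5-turbo' deployment is looked up as 'gpt-35-turbo'):

{
    'openai_api_type': 'azure',
    'openai_api_version': '2023-07-01-preview',
    'openai_api_base': 'https://your-resource-name.openai.azure.com',
    'openai_api_key': '<decrypted key>',
    'deployment': 'text-embedding-ada-002',  # embedding models, plus 'chunk_size': 16
    # generation models get 'deployment_name' (e.g. 'gpt-35-turbo') instead
}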

View File

@@ -1,132 +0,0 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional, Union
from core.constant import llm_constant
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.provider import Provider, ProviderType, ProviderName
class BaseProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
"""
provider = self.get_provider(only_custom)
if not provider:
raise ProviderTokenNotInitError(
f"No valid {llm_constant.models[model_id]} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
if model_id and model_id == 'gpt-4':
raise ModelCurrentlyNotSupportError()
if quota_used >= quota_limit:
raise QuotaExceededError()
return self.get_hosted_credentials()
else:
return self.get_decrypted_token(provider.encrypted_config)
def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and SYSTEM providers exist, the CUSTOM provider is preferred; pass only_custom to restrict the lookup to CUSTOM providers.
"""
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
@classmethod
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and SYSTEM providers exist, the CUSTOM provider is returned first (results are ordered by provider_type).
"""
query = db.session.query(Provider).filter(
Provider.tenant_id == tenant_id
)
if provider_name:
query = query.filter(Provider.provider_name == provider_name)
if only_custom:
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
providers = query.order_by(Provider.provider_type.asc()).all()
for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
return provider
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
return provider
return None
def get_hosted_credentials(self) -> Union[str | dict]:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = ''
if obfuscated:
return self.obfuscated_token(config)
return config
def obfuscated_token(self, token: str):
return token[:6] + '*' * (len(token) - 8) + token[-2:]
def get_token_type(self):
return str
def get_encrypted_token(self, config: Union[dict | str]):
return self.encrypt_token(config)
def get_decrypted_token(self, token: str):
return self.decrypt_token(token)
def encrypt_token(self, token):
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(self, token):
return rsa.decrypt(base64.b64decode(token), self.tenant_id)
@abstractmethod
def get_provider_name(self):
raise NotImplementedError
@abstractmethod
def get_credentials(self, model_id: Optional[str] = None) -> dict:
raise NotImplementedError
@abstractmethod
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
raise NotImplementedError
@abstractmethod
def config_validate(self, config: str):
raise NotImplementedError

View File

@@ -1,2 +0,0 @@
class ValidateFailedError(Exception):
description = "Provider Validate failed"

View File

@@ -1,22 +0,0 @@
from typing import Optional
from core.llm.provider.base import BaseProvider
from models.provider import ProviderName
class HuggingfaceProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
# todo
return []
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the API credentials for Huggingface as a dictionary, for the given tenant_id.
"""
return {
'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
}
def get_provider_name(self):
return ProviderName.HUGGINGFACEHUB

View File

@@ -1,53 +0,0 @@
from typing import Optional, Union
from core.llm.provider.anthropic_provider import AnthropicProvider
from core.llm.provider.azure_provider import AzureProvider
from core.llm.provider.base import BaseProvider
from core.llm.provider.huggingface_provider import HuggingfaceProvider
from core.llm.provider.openai_provider import OpenAIProvider
from models.provider import Provider
class LLMProviderService:
def __init__(self, tenant_id: str, provider_name: str):
self.provider = self.init_provider(tenant_id, provider_name)
def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
if provider_name == 'openai':
return OpenAIProvider(tenant_id)
elif provider_name == 'azure_openai':
return AzureProvider(tenant_id)
elif provider_name == 'anthropic':
return AnthropicProvider(tenant_id)
elif provider_name == 'huggingface':
return HuggingfaceProvider(tenant_id)
else:
raise Exception('provider {} not found'.format(provider_name))
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
return self.provider.get_models(model_id)
def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.provider.get_credentials(model_id)
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
def get_provider_db_record(self) -> Optional[Provider]:
return self.provider.get_provider()
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
:param config:
:raises: ValidateFailedError
"""
return self.provider.config_validate(config)
def get_token_type(self):
return self.provider.get_token_type()
def get_encrypted_token(self, config: Union[dict | str]):
return self.provider.get_encrypted_token(config)

View File

@@ -1,55 +0,0 @@
import logging
from typing import Optional, Union
import openai
from openai.error import AuthenticationError, OpenAIError
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName
class OpenAIProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
response = openai.Model.list(**credentials)
return [{
'id': model['id'],
'name': model['id'],
} for model in response['data']]
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the credentials for the given tenant_id and provider_name.
"""
return {
'openai_api_key': self.get_provider_api_key(model_id=model_id)
}
def get_provider_name(self):
return ProviderName.OPENAI
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
try:
Moderation(self.get_provider_name().value, config).moderate('test')
except (AuthenticationError, OpenAIError) as ex:
raise ValidateFailedError(str(ex))
except Exception as ex:
logging.exception('OpenAI config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return hosted_llm_credentials.openai.api_key

View File

@@ -1,62 +0,0 @@
from typing import List, Optional, Any, Dict
from httpx import Timeout
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
from pydantic import root_validator
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
class StreamableChatAnthropic(ChatAnthropic):
"""
Wrapper around Anthropic's large language model.
"""
default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)
@root_validator()
def prepare_params(cls, values: Dict) -> Dict:
values['model_name'] = values.get('model')
values['max_tokens'] = values.get('max_tokens_to_sample')
return values
@handle_anthropic_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
params['model'] = params.get('model_name')
del params['model_name']
params['max_tokens_to_sample'] = params.get('max_tokens')
del params['max_tokens']
del params['frequency_penalty']
del params['presence_penalty']
return params
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"{self.HUMAN_PROMPT} {message.content}"
elif isinstance(message, AIMessage):
message_text = f"{self.AI_PROMPT} {message.content}"
elif isinstance(message, SystemMessage):
message_text = f"<admin>{message.content}</admin>"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
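For reference, a sketch of the prompt framing _convert_one_message_to_text produces (HUMAN_PROMPT and AI_PROMPT come from the anthropic SDK and render as '\n\nHuman:' and '\n\nAssistant:'):

# [SystemMessage('Be terse'), HumanMessage('ping'), AIMessage('pong')]
# is rendered roughly as:
#   <admin>Be terse</admin>\n\nHuman: ping\n\nAssistant: pong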

View File

@@ -1,41 +0,0 @@
import decimal
from typing import Optional
import tiktoken
from core.constant import llm_constant
class TokenCalculator:
@classmethod
def get_num_tokens(cls, model_name: str, text: str):
if len(text) == 0:
return 0
enc = tiktoken.encoding_for_model(model_name)
tokenized_text = enc.encode(text)
# calculate the number of tokens in the encoded text
return len(tokenized_text)
@classmethod
def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal:
if model_name in llm_constant.models_by_mode['embedding']:
unit_price = llm_constant.model_prices[model_name]['usage']
elif text_type == 'prompt':
unit_price = llm_constant.model_prices[model_name]['prompt']
elif text_type == 'completion':
unit_price = llm_constant.model_prices[model_name]['completion']
else:
raise Exception('Invalid text type')
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
@classmethod
def get_currency(cls, model_name: str):
return llm_constant.model_currency
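A worked example of the rounding behavior above, with an illustrative unit price (real prices come from llm_constant.model_prices):

import decimal

tokens = 1530
unit_price = decimal.Decimal('0.002')  # illustrative USD per 1K tokens
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                          rounding=decimal.ROUND_HALF_UP)  # 1.530
total_price = (tokens_per_1k * unit_price).quantize(decimal.Decimal('0.0000001'),
                                                    rounding=decimal.ROUND_HALF_UP)
print(total_price)  # 0.0030600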

View File

@@ -1,26 +0,0 @@
import openai
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from models.provider import ProviderName
from core.llm.provider.base import BaseProvider
class Whisper:
def __init__(self, provider: BaseProvider):
self.provider = provider
if self.provider.get_provider_name() == ProviderName.OPENAI:
self.client = openai.Audio
self.credentials = provider.get_credentials()
@handle_openai_exceptions
def transcribe(self, file):
return self.client.transcribe(
model='whisper-1',
file=file,
api_key=self.credentials.get('openai_api_key'),
api_base=self.credentials.get('openai_api_base'),
api_type=self.credentials.get('openai_api_type'),
api_version=self.credentials.get('openai_api_version'),
)

View File

@@ -1,27 +0,0 @@
import logging
from functools import wraps
import anthropic
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_anthropic_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except anthropic.APIConnectionError as e:
logging.exception("Failed to connect to Anthropic API.")
raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
except anthropic.RateLimitError:
raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
except anthropic.AuthenticationError as e:
raise LLMAuthorizationError(f"Anthropic: {e.message}")
except anthropic.BadRequestError as e:
raise LLMBadRequestError(f"Anthropic: {e.message}")
except anthropic.APIStatusError as e:
raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
return wrapper

View File

@@ -1,31 +0,0 @@
import logging
from functools import wraps
import openai
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_openai_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except openai.error.InvalidRequestError as e:
logging.exception("Invalid request to OpenAI API.")
raise LLMBadRequestError(str(e))
except openai.error.APIConnectionError as e:
logging.exception("Failed to connect to OpenAI API.")
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
logging.exception("OpenAI service unavailable.")
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
except openai.error.RateLimitError as e:
raise LLMRateLimitError(str(e))
except openai.error.AuthenticationError as e:
raise LLMAuthorizationError(str(e))
except openai.error.OpenAIError as e:
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper
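Callers of the wrapped clients could then handle every provider uniformly; a minimal sketch with an illustrative fallback policy (error classes per the imports above, in the module this commit removes):

from core.llm.error import LLMRateLimitError, LLMAuthorizationError

def transcribe_or_none(whisper, file):
    try:
        return whisper.transcribe(file)
    except LLMRateLimitError:
        return None   # caller may back off and retry
    except LLMAuthorizationError:
        raise         # surface credential problems immediately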

View File

@@ -1,10 +1,10 @@
from typing import Any, List, Dict, Union
from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
from langchain.schema import get_buffer_string, BaseMessage
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
from core.model_providers.models.llm.base import BaseLLM
from extensions.ext_database import db
from models.model import Conversation, Message
@@ -13,7 +13,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
conversation: Conversation
human_prefix: str = "Human"
ai_prefix: str = "Assistant"
llm: BaseLanguageModel
model_instance: BaseLLM
memory_key: str = "chat_history"
max_token_limit: int = 2000
message_limit: int = 10
@@ -29,23 +29,23 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
messages = list(reversed(messages))
chat_messages: List[BaseMessage] = []
chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(HumanMessage(content=message.query))
chat_messages.append(AIMessage(content=message.answer))
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:
return chat_messages
return []
# prune the chat message if it exceeds the max token limit
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0))
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
return chat_messages
return to_lc_messages(chat_messages)
@property
def memory_variables(self) -> List[str]:

View File

@@ -0,0 +1,293 @@
from typing import Optional
from langchain.callbacks.base import Callbacks
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
class ModelFactory:
@classmethod
def get_text_generation_model_from_model_config(cls, tenant_id: str,
model_config: dict,
streaming: bool = False,
callbacks: Callbacks = None) -> Optional[BaseLLM]:
provider_name = model_config.get("provider")
model_name = model_config.get("name")
completion_params = model_config.get("completion_params", {})
return cls.get_text_generation_model(
tenant_id=tenant_id,
model_provider_name=provider_name,
model_name=model_name,
model_kwargs=ModelKwargs(
temperature=completion_params.get('temperature', 0),
max_tokens=completion_params.get('max_tokens', 256),
top_p=completion_params.get('top_p', 0),
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1)
),
streaming=streaming,
callbacks=callbacks
)
@classmethod
def get_text_generation_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None,
model_kwargs: Optional[ModelKwargs] = None,
streaming: bool = False,
callbacks: Callbacks = None) -> Optional[BaseLLM]:
"""
get text generation model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:param model_kwargs:
:param streaming:
:param callbacks:
:return:
"""
is_default_model = False
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default System Reasoning Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
is_default_model = True
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init text generation model
model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
try:
model_instance = model_class(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs,
streaming=streaming,
callbacks=callbacks
)
except LLMBadRequestError as e:
if is_default_model:
raise LLMBadRequestError(f"Default model {model_name} is not available. "
f"Please check your model provider credentials.")
else:
raise e
if is_default_model:
model_instance.deduct_quota = False
return model_instance
@classmethod
def get_embedding_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
"""
get embedding model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default Embedding Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init embedding model
model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_speech2text_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
"""
get speech to text model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default Speech-to-Text Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init speech to text model
model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_moderation_model(cls,
tenant_id: str,
model_provider_name: str,
model_name: str) -> Optional[BaseProviderModel]:
"""
get moderation model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init moderation model
model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
"""
get default model of model type.
:param tenant_id:
:param model_type:
:return:
"""
# get default model
default_model = db.session.query(TenantDefaultModel) \
.filter(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.value
).first()
if not default_model:
model_provider_rules = ModelProviderFactory.get_provider_rules()
for model_provider_name, model_provider_rule in model_provider_rules.items():
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
continue
model_list = model_provider.get_supported_model_list(model_type)
if model_list:
model_info = model_list[0]
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.value,
provider_name=model_provider_name,
model_name=model_info['id']
)
db.session.add(default_model)
db.session.commit()
break
return default_model
@classmethod
def update_default_model(cls,
tenant_id: str,
model_type: ModelType,
provider_name: str,
model_name: str) -> TenantDefaultModel:
"""
update default model of model type.
:param tenant_id:
:param model_type:
:param provider_name:
:param model_name:
:return:
"""
model_provider_names = ModelProviderFactory.get_provider_names()
if provider_name not in model_provider_names:
raise ValueError(f'Invalid provider name: {provider_name}')
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
model_list = model_provider.get_supported_model_list(model_type)
model_ids = [model['id'] for model in model_list]
if model_name not in model_ids:
raise ValueError(f'Invalid model name: {model_name}')
# get default model
default_model = db.session.query(TenantDefaultModel) \
.filter(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.value
).first()
if default_model:
# update default model
default_model.provider_name = provider_name
default_model.model_name = model_name
db.session.commit()
else:
# create default model
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.value,
provider_name=provider_name,
model_name=model_name,
)
db.session.add(default_model)
db.session.commit()
return default_model
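A sketch of switching a tenant's default reasoning model through the new API (tenant_id illustrative; 'claude-2' is one of the Anthropic models listed earlier in this commit):

from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelType

ModelFactory.update_default_model(
    tenant_id=tenant_id,
    model_type=ModelType.TEXT_GENERATION,
    provider_name='anthropic',
    model_name='claude-2'
)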

View File

@@ -0,0 +1,228 @@
from typing import Type
from sqlalchemy.exc import IntegrityError
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.rules import provider_rules
from extensions.ext_database import db
from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
DEFAULT_MODELS = {
ModelType.TEXT_GENERATION.value: {
'provider_name': 'openai',
'model_name': 'gpt-3.5-turbo',
},
ModelType.EMBEDDINGS.value: {
'provider_name': 'openai',
'model_name': 'text-embedding-ada-002',
},
ModelType.SPEECH_TO_TEXT.value: {
'provider_name': 'openai',
'model_name': 'whisper-1',
}
}
class ModelProviderFactory:
@classmethod
def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
if provider_name == 'openai':
from core.model_providers.providers.openai_provider import OpenAIProvider
return OpenAIProvider
elif provider_name == 'anthropic':
from core.model_providers.providers.anthropic_provider import AnthropicProvider
return AnthropicProvider
elif provider_name == 'minimax':
from core.model_providers.providers.minimax_provider import MinimaxProvider
return MinimaxProvider
elif provider_name == 'spark':
from core.model_providers.providers.spark_provider import SparkProvider
return SparkProvider
elif provider_name == 'tongyi':
from core.model_providers.providers.tongyi_provider import TongyiProvider
return TongyiProvider
elif provider_name == 'wenxin':
from core.model_providers.providers.wenxin_provider import WenxinProvider
return WenxinProvider
elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider
elif provider_name == 'azure_openai':
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
return AzureOpenAIProvider
elif provider_name == 'replicate':
from core.model_providers.providers.replicate_provider import ReplicateProvider
return ReplicateProvider
elif provider_name == 'huggingface_hub':
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
return HuggingfaceHubProvider
else:
raise NotImplementedError
@classmethod
def get_provider_names(cls):
"""
Returns a list of provider names.
"""
return list(provider_rules.keys())
@classmethod
def get_provider_rules(cls):
"""
Returns a list of provider rules.
:return:
"""
return provider_rules
@classmethod
def get_provider_rule(cls, provider_name: str):
"""
Returns provider rule.
"""
return provider_rules[provider_name]
@classmethod
def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
"""
get preferred model provider.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:return:
"""
# get preferred provider
preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
if not preferred_provider or not preferred_provider.is_valid:
return None
# init model provider
model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
return model_provider_class(provider=preferred_provider)
@classmethod
def get_preferred_type_by_preferred_model_provider(cls,
tenant_id: str,
model_provider_name: str,
preferred_model_provider: TenantPreferredModelProvider):
"""
get preferred provider type by preferred model provider.
:param tenant_id:
:param model_provider_name:
:param preferred_model_provider:
:return:
"""
if not preferred_model_provider:
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
support_provider_types = model_provider_rules['support_provider_types']
if ProviderType.CUSTOM.value in support_provider_types:
custom_provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.is_valid == True
).first()
if custom_provider:
return ProviderType.CUSTOM.value
model_provider = cls.get_model_provider_class(model_provider_name)
if ProviderType.SYSTEM.value in support_provider_types \
and model_provider.is_provider_type_system_supported():
return ProviderType.SYSTEM.value
elif ProviderType.CUSTOM.value in support_provider_types:
return ProviderType.CUSTOM.value
else:
return preferred_model_provider.preferred_provider_type
@classmethod
def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
"""
get preferred provider of tenant.
:param tenant_id:
:param model_provider_name:
:return:
"""
# get preferred provider type
preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)
# get providers by preferred provider type
providers = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == preferred_provider_type
).all()
no_system_provider = False
if preferred_provider_type == ProviderType.SYSTEM.value:
quota_type_to_provider_dict = {}
for provider in providers:
quota_type_to_provider_dict[provider.quota_type] = provider
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
for quota_type_enum in ProviderQuotaType:
quota_type = quota_type_enum.value
if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
and quota_type in quota_type_to_provider_dict.keys():
provider = quota_type_to_provider_dict[quota_type]
if provider.is_valid and provider.quota_limit > provider.quota_used:
return provider
no_system_provider = True
if no_system_provider:
providers = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value
).all()
if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
if providers:
return providers[0]
else:
try:
provider = Provider(
tenant_id=tenant_id,
provider_name=model_provider_name,
provider_type=ProviderType.CUSTOM.value,
is_valid=False
)
db.session.add(provider)
db.session.commit()
except IntegrityError:
db.session.rollback()
provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
return provider
return None
@classmethod
def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
"""
get preferred provider type of tenant.
:param tenant_id:
:param model_provider_name:
:return:
"""
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
.filter(
TenantPreferredModelProvider.tenant_id == tenant_id,
TenantPreferredModelProvider.provider_name == model_provider_name
).first()
return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)
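For completeness, the lookup ModelFactory builds on (a sketch; the call returns None until the tenant has a valid preferred provider):

from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType

model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, 'openai')
if model_provider:
    model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)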

View File

@@ -0,0 +1,22 @@
from abc import ABC
from typing import Any
from core.model_providers.providers.base import BaseModelProvider
class BaseProviderModel(ABC):
_client: Any
_model_provider: BaseModelProvider
def __init__(self, model_provider: BaseModelProvider, client: Any):
self._model_provider = model_provider
self._client = client
@property
def client(self):
return self._client
@property
def model_provider(self):
return self._model_provider
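A minimal sketch of the contract this base class establishes; the subclass and its "client" are invented for illustration.

class EchoModel(BaseProviderModel):
    def __init__(self, model_provider: BaseModelProvider):
        # any client object works; a plain function stands in here
        super().__init__(model_provider, client=lambda text: text)

    def run(self, text: str) -> str:
        return self.client(text)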

View File

@@ -0,0 +1,78 @@
import decimal
import logging
import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \
LLMAPIUnavailableError, LLMAPIConnectionError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureOpenAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
self.credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = OpenAIEmbeddings(
deployment=name,
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
chunk_size=16,
max_retries=1,
**self.credentials
)
super().__init__(model_provider, client, name)
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
:param text:
:return:
"""
if len(text) == 0:
return 0
enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name'))
tokenized_text = enc.encode(text)
# calculate the number of tokens in the encoded text
return len(tokenized_text)
def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to Azure OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to Azure OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("Azure OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError('Azure ' + str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError('Azure ' + str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
else:
return ex
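A quick sanity check of the pricing arithmetic in get_token_price above, using the hard-coded 0.0001 USD per 1K tokens rate (standalone sketch):

import decimal

tokens = 1234
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                          rounding=decimal.ROUND_HALF_UP)
total = (tokens_per_1k * decimal.Decimal('0.0001')).quantize(decimal.Decimal('0.0000001'),
                                                             rounding=decimal.ROUND_HALF_UP)
# 1.234 * 0.0001 -> 0.0001234 USD
assert str(total) == '0.0001234'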

View File

@@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any
import tiktoken
from langchain.schema.language_model import _get_token_ids_default_method
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
class BaseEmbedding(BaseProviderModel):
name: str
type: ModelType = ModelType.EMBEDDINGS
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
super().__init__(model_provider, client)
self.name = name
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
:param text:
:return:
"""
if len(text) == 0:
return 0
return len(_get_token_ids_default_method(text))
def get_token_price(self, tokens: int):
return 0
def get_currency(self):
return 'USD'
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
raise NotImplementedError

View File

@@ -0,0 +1,35 @@
import decimal
import logging
from langchain.embeddings import MiniMaxEmbeddings
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
class MinimaxEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = MiniMaxEmbeddings(
model=name,
**credentials
)
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex

View File

@@ -0,0 +1,72 @@
import decimal
import logging
import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
class OpenAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = OpenAIEmbeddings(
max_retries=1,
**credentials
)
super().__init__(model_provider, client, name)
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
:param text:
:return:
"""
if len(text) == 0:
return 0
enc = tiktoken.encoding_for_model(self.name)
tokenized_text = enc.encode(text)
# calculate the number of tokens in the encoded text
return len(tokenized_text)
def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex

View File

@@ -0,0 +1,36 @@
import decimal
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings
from core.model_providers.models.embedding.base import BaseEmbedding
class ReplicateEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = ReplicateEmbeddings(
model=name + ':' + credentials.get('model_version'),
replicate_api_token=credentials.get('replicate_api_token')
)
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
# Replicate charges for prediction time only, not per token
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Replicate: {str(ex)}")
else:
return ex

View File

@@ -0,0 +1,53 @@
import enum
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from pydantic import BaseModel
class LLMRunResult(BaseModel):
content: str
prompt_tokens: int
completion_tokens: int
class MessageType(enum.Enum):
HUMAN = 'human'
ASSISTANT = 'assistant'
SYSTEM = 'system'
class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
content: str = ''
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
lc_messages.append(AIMessage(content=message.content))
elif message.type == MessageType.SYSTEM:
lc_messages.append(SystemMessage(content=message.content))
return lc_messages
def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
elif isinstance(message, AIMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
return prompt_messages
def str_to_prompt_messages(texts: list[str]):
prompt_messages = []
for text in texts:
prompt_messages.append(PromptMessage(content=text))
return prompt_messages
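A small round-trip sketch of the converters above (runs as-is once the entities in this file are imported; pydantic models compare by field values):

msgs = [
    PromptMessage(content='You are helpful.', type=MessageType.SYSTEM),
    PromptMessage(content='ping', type=MessageType.HUMAN),
]
lc_msgs = to_lc_messages(msgs)              # [SystemMessage(...), HumanMessage(...)]
assert to_prompt_messages(lc_msgs) == msgs  # content and role survive the round trip
assert str_to_prompt_messages(['hi']) == [PromptMessage(content='hi')]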

View File

@@ -0,0 +1,59 @@
import enum
from typing import Optional, TypeVar, Generic
from langchain.load.serializable import Serializable
from pydantic import BaseModel
class ModelMode(enum.Enum):
COMPLETION = 'completion'
CHAT = 'chat'
class ModelType(enum.Enum):
TEXT_GENERATION = 'text-generation'
EMBEDDINGS = 'embeddings'
SPEECH_TO_TEXT = 'speech2text'
IMAGE = 'image'
VIDEO = 'video'
MODERATION = 'moderation'
@staticmethod
def value_of(value):
for member in ModelType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ModelKwargs(BaseModel):
max_tokens: Optional[int]
temperature: Optional[float]
top_p: Optional[float]
presence_penalty: Optional[float]
frequency_penalty: Optional[float]
class KwargRuleType(enum.Enum):
STRING = 'string'
INTEGER = 'integer'
FLOAT = 'float'
T = TypeVar('T')
class KwargRule(Generic[T], BaseModel):
enabled: bool = True
min: Optional[T] = None
max: Optional[T] = None
default: Optional[T] = None
alias: Optional[str] = None
class ModelKwargsRules(BaseModel):
max_tokens: KwargRule = KwargRule[int](enabled=False)
temperature: KwargRule = KwargRule[float](enabled=False)
top_p: KwargRule = KwargRule[float](enabled=False)
presence_penalty: KwargRule = KwargRule[float](enabled=False)
frequency_penalty: KwargRule = KwargRule[float](enabled=False)
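As an illustration, a provider could declare rules like the following (all limits invented); fields left out keep their disabled defaults:

rules = ModelKwargsRules(
    temperature=KwargRule[float](min=0, max=2, default=1),
    max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4096, default=256),
)
assert rules.top_p.enabled is False
assert rules.max_tokens.alias == 'max_new_tokens'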

View File

@@ -0,0 +1,10 @@
from enum import Enum
class ProviderQuotaUnit(Enum):
TIMES = 'times'
TOKENS = 'tokens'
class ModelFeature(Enum):
AGENT_THOUGHT = 'agent_thought'

View File

@@ -0,0 +1,107 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any
import anthropic
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class AnthropicModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatAnthropic(
model=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
default_request_timeout=60,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'claude-instant-1': {
'prompt': decimal.Decimal('1.63'),
'completion': decimal.Decimal('5.51'),
},
'claude-2': {
'prompt': decimal.Decimal('11.02'),
'completion': decimal.Decimal('32.68'),
},
}
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1m * unit_price
return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, anthropic.APIConnectionError):
logging.warning("Failed to connect to Anthropic API.")
return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}")
elif isinstance(ex, anthropic.RateLimitError):
return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
elif isinstance(ex, anthropic.AuthenticationError):
return LLMAuthorizationError(f"Anthropic: {ex.message}")
elif isinstance(ex, anthropic.BadRequestError):
return LLMBadRequestError(f"Anthropic: {ex.message}")
elif isinstance(ex, anthropic.APIStatusError):
return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}")
else:
return ex
@classmethod
def support_streaming(cls):
return True
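A worked example of the per-1M pricing above, for claude-2 prompt tokens at 11.02 USD per million:

import decimal

tokens = 2500
tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
                                                             rounding=decimal.ROUND_HALF_UP)
total = (tokens_per_1m * decimal.Decimal('11.02')).quantize(decimal.Decimal('0.00000001'),
                                                            rounding=decimal.ROUND_HALF_UP)
# 0.0025 * 11.02 -> 0.02755 USD
assert str(total) == '0.02755000'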

View File

@@ -0,0 +1,177 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureOpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
if name == 'text-davinci-003':
self.model_mode = ModelMode.COMPLETION
else:
self.model_mode = ModelMode.CHAT
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.name == 'text-davinci-003':
client = EnhanceAzureOpenAI(
deployment_name=self.name,
streaming=self.streaming,
request_timeout=60,
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
openai_api_key=self.credentials.get('openai_api_key'),
openai_api_base=self.credentials.get('openai_api_base'),
callbacks=self.callbacks,
**provider_model_kwargs
)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
client = EnhanceAzureChatOpenAI(
deployment_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
request_timeout=60,
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
openai_api_key=self.credentials.get('openai_api_key'),
openai_api_base=self.credentials.get('openai_api_base'),
callbacks=self.callbacks,
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, str):
return self._client.get_num_tokens(prompts)
else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'gpt-4': {
'prompt': decimal.Decimal('0.03'),
'completion': decimal.Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': decimal.Decimal('0.06'),
'completion': decimal.Decimal('0.12')
},
'gpt-35-turbo': {
'prompt': decimal.Decimal('0.0015'),
'completion': decimal.Decimal('0.002')
},
'gpt-35-turbo-16k': {
'prompt': decimal.Decimal('0.003'),
'completion': decimal.Decimal('0.004')
},
'text-davinci-003': {
'prompt': decimal.Decimal('0.02'),
'completion': decimal.Decimal('0.02')
},
}
base_model_name = self.credentials.get("base_model_name")
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[base_model_name]['prompt']
else:
unit_price = model_unit_prices[base_model_name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name == 'text-davinci-003':
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to Azure OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to Azure OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("Azure OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError('Azure ' + str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError('Azure ' + str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File

@@ -0,0 +1,269 @@
from abc import abstractmethod
from typing import List, Optional, Any, Union
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM
class BaseLLM(BaseProviderModel):
model_mode: ModelMode = ModelMode.COMPLETION
name: str
model_kwargs: ModelKwargs
credentials: dict
streaming: bool = False
type: ModelType = ModelType.TEXT_GENERATION
deduct_quota: bool = True
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
self.name = name
self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
max_tokens=None,
temperature=None,
top_p=None,
presence_penalty=None,
frequency_penalty=None
)
self.credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
self.streaming = streaming
if streaming:
default_callback = DifyStreamingStdOutCallbackHandler()
else:
default_callback = DifyStdOutCallbackHandler()
if not callbacks:
callbacks = [default_callback]
else:
callbacks.append(default_callback)
self.callbacks = callbacks
client = self._init_client()
super().__init__(model_provider, client)
@abstractmethod
def _init_client(self) -> Any:
raise NotImplementedError
def run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMRunResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
if self.deduct_quota:
self.model_provider.check_quota_over_limit()
if not callbacks:
callbacks = self.callbacks
else:
callbacks.extend(self.callbacks)
if 'fake_response' in kwargs and kwargs['fake_response']:
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM(
response=kwargs['fake_response'],
num_token_func=self.get_num_tokens,
streaming=self.streaming,
callbacks=callbacks
)
result = fake_llm.generate([prompts])
else:
try:
result = self._run(
messages=messages,
stop=stop,
callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
**kwargs
)
except Exception as ex:
raise self.handle_exceptions(ex)
if isinstance(result.generations[0][0], ChatGeneration):
completion_content = result.generations[0][0].message.content
else:
completion_content = result.generations[0][0].text
if self.streaming and not self.support_streaming():
# use FakeLLM to simulate streaming when the current model does not support streaming but streaming was requested
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM(
response=completion_content,
num_token_func=self.get_num_tokens,
streaming=self.streaming,
callbacks=callbacks
)
fake_llm.generate([prompts])
if result.llm_output and result.llm_output.get('token_usage'):
prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
completion_tokens = result.llm_output['token_usage']['completion_tokens']
total_tokens = result.llm_output['token_usage']['total_tokens']
else:
prompt_tokens = self.get_num_tokens(messages)
completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
total_tokens = prompt_tokens + completion_tokens
if self.deduct_quota:
self.model_provider.deduct_quota(total_tokens)
return LLMRunResult(
content=completion_content,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
)
@abstractmethod
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_token_price(self, tokens: int, message_type: MessageType):
"""
get token price.
:param tokens:
:param message_type:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_currency(self):
"""
get token currency.
:return:
"""
raise NotImplementedError
def get_model_kwargs(self):
return self.model_kwargs
def set_model_kwargs(self, model_kwargs: ModelKwargs):
self.model_kwargs = model_kwargs
self._set_model_kwargs(model_kwargs)
@abstractmethod
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
raise NotImplementedError
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
"""
Handle llm run exceptions.
:param ex:
:return:
"""
raise NotImplementedError
def add_callbacks(self, callbacks: Callbacks):
"""
Add callbacks to client.
:param callbacks:
:return:
"""
if not self.client.callbacks:
self.client.callbacks = callbacks
else:
self.client.callbacks.extend(callbacks)
@classmethod
def support_streaming(cls):
return False
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
if len(messages) == 0:
raise ValueError("prompt must not be empty.")
if not model_mode:
model_mode = self.model_mode
if model_mode == ModelMode.COMPLETION:
return messages[0].content
else:
chat_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
chat_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
chat_messages.append(AIMessage(content=message.content))
elif message.type == MessageType.SYSTEM:
chat_messages.append(SystemMessage(content=message.content))
return chat_messages
def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
"""
convert model kwargs to provider model kwargs.
:param model_rules:
:param model_kwargs:
:return:
"""
model_kwargs_input = {}
for key, value in model_kwargs.dict().items():
rule = getattr(model_rules, key)
if not rule.enabled:
continue
if rule.alias:
key = rule.alias
if rule.default is not None and value is None:
    value = rule.default
if value is not None:
    if rule.min is not None:
        value = max(value, rule.min)
    if rule.max is not None:
        value = min(value, rule.max)
model_kwargs_input[key] = value
return model_kwargs_input
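A sketch of what _to_model_kwargs_input produces for out-of-range input (rule values invented; since the helper never touches self, it can be exercised without a model instance):

from core.model_providers.models.entity.model_params import ModelKwargs, ModelKwargsRules, KwargRule
from core.model_providers.models.llm.base import BaseLLM

rules = ModelKwargsRules(
    temperature=KwargRule[float](min=0, max=1, default=0.7),
    max_tokens=KwargRule[int](alias='max_tokens_to_sample', min=10, max=1000, default=256),
)
kwargs = ModelKwargs(max_tokens=99999, temperature=None, top_p=0.9,
                     presence_penalty=None, frequency_penalty=None)
out = BaseLLM._to_model_kwargs_input(None, rules, kwargs)
# max_tokens is clamped to its max and renamed via the alias;
# temperature falls back to its default; disabled rules are dropped.
assert out == {'max_tokens_to_sample': 1000, 'temperature': 0.7}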

View File

@@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import ChatGLM
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class ChatGLMModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatGLM(
callbacks=self.callbacks,
endpoint_url=self.credentials.get('api_base'),
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return False

View File

@@ -0,0 +1,82 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
from langchain.llms import HuggingFaceEndpoint
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class HuggingfaceHubModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
client = HuggingFaceEndpoint(
endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
task='text2text-generation',
model_kwargs=provider_model_kwargs,
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
callbacks=self.callbacks,
)
else:
client = HuggingFaceHub(
repo_id=self.name,
task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs,
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
callbacks=self.callbacks,
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.get_num_tokens(prompts)
def get_token_price(self, tokens: int, message_type: MessageType):
# price calculation is not supported for Hugging Face Hub models
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.model_kwargs = provider_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class MinimaxModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return Minimax(
model=self.name,
model_kwargs={
'stream': False
},
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex

View File

@@ -0,0 +1,219 @@
import decimal
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from models.provider import ProviderType, ProviderQuotaType
COMPLETION_MODELS = [
'text-davinci-003', # 4,097 tokens
]
CHAT_MODELS = [
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
]
MODEL_MAX_TOKENS = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
class OpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
if name in COMPLETION_MODELS:
self.model_mode = ModelMode.COMPLETION
else:
self.model_mode = ModelMode.CHAT
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.name in COMPLETION_MODELS:
client = EnhanceOpenAI(
model_name=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
**self.credentials,
**provider_model_kwargs
)
else:
# Fine-tuning is currently only available for the base models
# davinci, curie, babbage, and ada, so apart from the fixed
# completion models listed above, fine-tuned models are also
# completion models; everything else is treated as a chat model here.
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
client = EnhanceChatOpenAI(
model_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
**self.credentials
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
if self.name == 'gpt-4' \
and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, str):
return self._client.get_num_tokens(prompts)
else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'gpt-4': {
'prompt': decimal.Decimal('0.03'),
'completion': decimal.Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': decimal.Decimal('0.06'),
'completion': decimal.Decimal('0.12')
},
'gpt-3.5-turbo': {
'prompt': decimal.Decimal('0.0015'),
'completion': decimal.Decimal('0.002')
},
'gpt-3.5-turbo-16k': {
'prompt': decimal.Decimal('0.003'),
'completion': decimal.Decimal('0.004')
},
'text-davinci-003': {
'prompt': decimal.Decimal('0.02'),
'completion': decimal.Decimal('0.02')
},
}
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name in COMPLETION_MODELS:
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True
# def is_model_valid_or_raise(self):
# """
# check is a valid model.
#
# :return:
# """
# credentials = self._model_provider.get_credentials()
#
# try:
# result = openai.Model.retrieve(
# id=self.name,
# api_key=credentials.get('openai_api_key'),
# request_timeout=60
# )
#
# if 'id' not in result or result['id'] != self.name:
# raise LLMNotExistsError(f"OpenAI Model {self.name} not exists.")
# except openai.error.OpenAIError as e:
# raise LLMNotExistsError(f"OpenAI Model {self.name} not exists, cause: {e.__class__.__name__}:{str(e)}")
# except Exception as e:
# logging.exception("OpenAI Model retrieve failed.")
# raise e
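End-to-end, pulling a chat completion through this class looks roughly like the following (the `provider` instance is assumed to be an already-configured OpenAI model provider):

from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelKwargs

model = OpenAIModel(
    model_provider=provider,  # assumption: a configured OpenAIProvider
    name='gpt-3.5-turbo',
    model_kwargs=ModelKwargs(max_tokens=256, temperature=0.7, top_p=None,
                             presence_penalty=None, frequency_penalty=None),
)
result = model.run([PromptMessage(content='ping')])
print(result.content, result.prompt_tokens, result.completion_tokens)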

View File

@@ -0,0 +1,103 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from replicate.exceptions import ReplicateError, ModelError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.error import LLMBadRequestError
from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class ReplicateModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return EnhanceReplicate(
model=self.name + ':' + self.credentials.get('model_version'),
input=provider_model_kwargs,
streaming=self.streaming,
replicate_api_token=self.credentials.get('replicate_api_token'),
callbacks=self.callbacks,
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
extra_kwargs = {}
if isinstance(prompts, list):
system_messages = [message for message in messages if message.type == MessageType.SYSTEM]
if system_messages:
    extra_kwargs['system_prompt'] = system_messages[0].content
    # rebuild the chat prompt without the system message before flattening
    prompts = self._get_prompt_from_messages(
        [message for message in messages if message.type != MessageType.SYSTEM]
    )
prompts = get_buffer_string(prompts)
# The maximum length the generated tokens can have.
# Corresponds to the length of the input prompt + max_new_tokens.
if 'max_length' in self._client.input:
self._client.input['max_length'] = min(
self._client.input['max_length'] + self.get_num_tokens(messages),
self.model_rules.max_tokens.max
)
return self._client.generate([prompts], stop, callbacks, **extra_kwargs)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, list):
prompts = get_buffer_string(prompts)
return self._client.get_num_tokens(prompts)
def get_token_price(self, tokens: int, message_type: MessageType):
# Replicate charges for prediction time only, not per token
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.input = provider_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Replicate: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return True
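For chat-style Replicate models the system message is lifted out of the prompt; a small sketch of the resulting call inputs (model name invented):

messages = [
    PromptMessage(content='Answer briefly.', type=MessageType.SYSTEM),
    PromptMessage(content='ping', type=MessageType.HUMAN),
]
# With a name like 'a16z-infra/llama-2-13b-chat', model_mode is CHAT, so _run
# lifts {'system_prompt': 'Answer briefly.'} into extra kwargs and flattens the
# remaining messages into one prompt string via get_buffer_string().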

View File

@@ -0,0 +1,73 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.spark import ChatSpark
from core.third_party.spark.spark_llm import SparkError
class SparkModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatSpark(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
contents = [message.content for message in messages]
return max(self._client.get_num_tokens("".join(contents)), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, SparkError):
return LLMBadRequestError(f"Spark: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File

@@ -0,0 +1,77 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from requests import HTTPError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
class TongyiModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
provider_model_kwargs.pop('max_tokens', None)  # the Tongyi client does not accept max_tokens
return EnhanceTongyi(
model_name=self.name,
max_retries=1,
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
provider_model_kwargs.pop('max_tokens', None)  # the Tongyi client does not accept max_tokens
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ValueError, HTTPError)):
return LLMBadRequestError(f"Tongyi: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File

@@ -0,0 +1,92 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.wenxin import Wenxin
class WenxinModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return Wenxin(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction by the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'ernie-bot': {
'prompt': decimal.Decimal('0.012'),
'completion': decimal.Decimal('0.012'),
},
'ernie-bot-turbo': {
'prompt': decimal.Decimal('0.008'),
'completion': decimal.Decimal('0.008')
},
'bloomz-7b': {
'prompt': decimal.Decimal('0.006'),
'completion': decimal.Decimal('0.006')
}
}
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@@ -0,0 +1,48 @@
import logging
import openai
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
DEFAULT_AUDIO_MODEL = 'whisper-1'
class OpenAIModeration(BaseProviderModel):
type: ModelType = ModelType.MODERATION
def __init__(self, model_provider: BaseModelProvider, name: str):
super().__init__(model_provider, openai.Moderation)
def run(self, text):
credentials = self.model_provider.get_model_credentials(
model_name=DEFAULT_AUDIO_MODEL,
model_type=self.type
)
try:
return self._client.create(input=text, api_key=credentials['openai_api_key'])
except Exception as ex:
raise self.handle_exceptions(ex)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
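A usage sketch (the `provider` below is assumed to be a configured OpenAI model provider instance; the response shape follows the pre-1.0 openai SDK):

moderation = OpenAIModeration(model_provider=provider, name='text-moderation-stable')
result = moderation.run('some user supplied text')
print(result['results'][0]['flagged'])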

View File

@@ -0,0 +1,29 @@
from abc import abstractmethod
from typing import Any
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
class BaseSpeech2Text(BaseProviderModel):
name: str
type: ModelType = ModelType.SPEECH_TO_TEXT
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
super().__init__(model_provider, client)
self.name = name
def run(self, file):
try:
return self._run(file)
except Exception as ex:
raise self.handle_exceptions(ex)
@abstractmethod
def _run(self, file):
raise NotImplementedError
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
raise NotImplementedError

View File

@@ -0,0 +1,47 @@
import logging
import openai
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from core.model_providers.providers.base import BaseModelProvider
class OpenAIWhisper(BaseSpeech2Text):
def __init__(self, model_provider: BaseModelProvider, name: str):
super().__init__(model_provider, openai.Audio, name)
def _run(self, file):
credentials = self.model_provider.get_model_credentials(
model_name=self.name,
model_type=self.type
)
return self._client.transcribe(
model=self.name,
file=file,
api_key=credentials.get('openai_api_key'),
api_base=credentials.get('openai_api_base'),
organization=credentials.get('openai_organization'),
)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
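A usage sketch, under the same assumption of a configured OpenAI `provider`:

speech2text = OpenAIWhisper(model_provider=provider, name='whisper-1')
with open('sample.mp3', 'rb') as audio_file:
    transcript = speech2text.run(audio_file)
print(transcript['text'])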

View File

@@ -0,0 +1,224 @@
import json
import logging
from json import JSONDecodeError
from typing import Type, Optional
import anthropic
from flask import current_app
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
from models.provider import ProviderType
class AnthropicProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'anthropic'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
},
{
'id': 'claude-2',
'name': 'claude-2',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = AnthropicModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'anthropic_api_key' not in credentials:
raise CredentialsValidateFailedError('Anthropic API Key must be provided.')
try:
credential_kwargs = {
'anthropic_api_key': credentials['anthropic_api_key']
}
if 'anthropic_api_url' in credentials:
credential_kwargs['anthropic_api_url'] = credentials['anthropic_api_url']
chat_llm = ChatAnthropic(
model='claude-instant-1',
max_tokens_to_sample=10,
temperature=0,
default_request_timeout=60,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except anthropic.APIConnectionError as ex:
raise CredentialsValidateFailedError(str(ex))
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Anthropic config validation failed')
raise ex
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['anthropic_api_key'] = encrypter.encrypt_token(tenant_id, credentials['anthropic_api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'anthropic_api_url': None,
'anthropic_api_key': None
}
if credentials['anthropic_api_key']:
credentials['anthropic_api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['anthropic_api_key']
)
if obfuscated:
credentials['anthropic_api_key'] = encrypter.obfuscated_token(credentials['anthropic_api_key'])
if 'anthropic_api_url' not in credentials:
credentials['anthropic_api_url'] = None
return credentials
else:
if hosted_model_providers.anthropic:
return {
'anthropic_api_url': hosted_model_providers.anthropic.api_base,
'anthropic_api_key': hosted_model_providers.anthropic.api_key,
}
else:
return {
'anthropic_api_url': None,
'anthropic_api_key': None
}
@classmethod
def is_provider_type_system_supported(cls) -> bool:
if current_app.config['EDITION'] != 'CLOUD':
return False
if hosted_model_providers.anthropic:
return True
return False
def should_deduct_quota(self):
if hosted_model_providers.anthropic and \
hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > 0:
return True
return False
def get_payment_info(self) -> Optional[dict]:
"""
        get product info if it is payable.
:return:
"""
if hosted_model_providers.anthropic \
and hosted_model_providers.anthropic.paid_enabled:
return {
'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
}
return None
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
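
Note how max_tokens carries the alias max_tokens_to_sample above: Anthropic's API does not accept a generic max_tokens parameter. A pure-function sketch of how such an alias could be applied when building provider kwargs (the helper itself is illustrative, not this commit's API):

def to_provider_kwargs(rules, model_kwargs: dict) -> dict:
    # Rename generic kwargs via each KwargRule's alias, e.g.
    # max_tokens -> max_tokens_to_sample for Anthropic; skip disabled rules.
    out = {}
    for name, value in model_kwargs.items():
        rule = getattr(rules, name, None)
        if rule is None or not getattr(rule, 'enabled', True):
            continue
        out[getattr(rule, 'alias', None) or name] = value
    return out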

View File

@@ -0,0 +1,387 @@
import json
import logging
from json import JSONDecodeError
from typing import Type
import openai
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
AZURE_OPENAI_API_VERSION
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
from extensions.ext_database import db
from models.provider import ProviderType, ProviderModel, ProviderQuotaType
BASE_MODELS = [
'gpt-4',
'gpt-4-32k',
'gpt-35-turbo',
'gpt-35-turbo-16k',
'text-davinci-003',
'text-embedding-ada-002',
]
class AzureOpenAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'azure_openai'
def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
# convert old provider config to provider models
self._convert_provider_config_to_model_config()
if self.provider.provider_type == ProviderType.CUSTOM.value:
# get configurable provider models
provider_models = db.session.query(ProviderModel).filter(
ProviderModel.tenant_id == self.provider.tenant_id,
ProviderModel.provider_name == self.provider.provider_name,
ProviderModel.model_type == model_type.value,
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()
model_list = []
for provider_model in provider_models:
model_dict = {
'id': provider_model.model_name,
'name': provider_model.model_name
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['base_model_name'] in [
'gpt-4',
'gpt-4-32k',
'gpt-35-turbo',
'gpt-35-turbo-16k',
]:
model_dict['features'] = [
ModelFeature.AGENT_THOUGHT.value
]
model_list.append(model_dict)
else:
model_list = self._get_fixed_model_list(model_type)
return model_list
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
models = [
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4',
'name': 'gpt-4',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
}
]
if self.provider.provider_type == ProviderType.SYSTEM.value \
and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]
return models
elif model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'text-embedding-ada-002',
'name': 'text-embedding-ada-002'
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = AzureOpenAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = AzureOpenAIEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
base_model_max_tokens = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-35-turbo': 4096,
'gpt-35-turbo-16k': 16384,
'text-davinci-003': 4097,
}
model_credentials = self.get_model_credentials(model_name, model_type)
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
model_credentials['base_model_name'],
4097
), default=16),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'openai_api_key' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API key is required')
if 'openai_api_base' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required')
if credentials['base_model_name'] not in BASE_MODELS:
raise CredentialsValidateFailedError('Base Model Name is invalid')
if model_type == ModelType.TEXT_GENERATION:
try:
client = EnhanceAzureChatOpenAI(
deployment_name=model_name,
temperature=0,
max_tokens=15,
request_timeout=10,
openai_api_type='azure',
openai_api_version='2023-07-01-preview',
openai_api_key=credentials['openai_api_key'],
openai_api_base=credentials['openai_api_base'],
)
client.generate([[HumanMessage(content='hi!')]])
except openai.error.OpenAIError as e:
raise CredentialsValidateFailedError(
f"Azure OpenAI deployment {model_name} not exists, cause: {e.__class__.__name__}:{str(e)}")
except Exception as e:
logging.exception("Azure OpenAI Model retrieve failed.")
raise e
elif model_type == ModelType.EMBEDDINGS:
try:
client = OpenAIEmbeddings(
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
deployment=model_name,
chunk_size=16,
max_retries=1,
openai_api_key=credentials['openai_api_key'],
openai_api_base=credentials['openai_api_base']
)
client.embed_query('hi')
except openai.error.OpenAIError as e:
logging.exception("Azure OpenAI Model check error.")
raise CredentialsValidateFailedError(
f"Azure OpenAI deployment {model_name} not exists, cause: {e.__class__.__name__}:{str(e)}")
except Exception as e:
logging.exception("Azure OpenAI Model retrieve failed.")
raise e
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['openai_api_key'] = encrypter.encrypt_token(tenant_id, credentials['openai_api_key'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type == ProviderType.CUSTOM.value:
# convert old provider config to provider models
self._convert_provider_config_to_model_config()
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'openai_api_base': '',
'openai_api_key': '',
'base_model_name': ''
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['openai_api_key']:
credentials['openai_api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['openai_api_key']
)
if obfuscated:
credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key'])
return credentials
else:
if hosted_model_providers.azure_openai:
return {
'openai_api_base': hosted_model_providers.azure_openai.api_base,
'openai_api_key': hosted_model_providers.azure_openai.api_key,
'base_model_name': model_name
}
else:
return {
'openai_api_base': None,
'openai_api_key': None,
'base_model_name': None
}
@classmethod
def is_provider_type_system_supported(cls) -> bool:
if current_app.config['EDITION'] != 'CLOUD':
return False
if hosted_model_providers.azure_openai:
return True
return False
def should_deduct_quota(self):
if hosted_model_providers.azure_openai \
and hosted_model_providers.azure_openai.quota_limit and hosted_model_providers.azure_openai.quota_limit > 0:
return True
return False
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}
def _convert_provider_config_to_model_config(self):
if self.provider.provider_type == ProviderType.CUSTOM.value \
and self.provider.is_valid \
and self.provider.encrypted_config:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'openai_api_base': '',
'openai_api_key': '',
'base_model_name': ''
}
self._add_provider_model(
model_name='gpt-35-turbo',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='gpt-35-turbo-16k',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='gpt-4',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='text-davinci-003',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='text-embedding-ada-002',
model_type=ModelType.EMBEDDINGS,
provider_credentials=credentials
)
self.provider.encrypted_config = None
db.session.commit()
def _add_provider_model(self, model_name: str, model_type: ModelType, provider_credentials: dict):
credentials = provider_credentials.copy()
credentials['base_model_name'] = model_name
provider_model = ProviderModel(
tenant_id=self.provider.tenant_id,
provider_name=self.provider.provider_name,
model_name=model_name,
model_type=model_type.value,
encrypted_config=json.dumps(credentials),
is_valid=True
)
db.session.add(provider_model)
db.session.commit()
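
A hedged example of exercising the deployment check above; the module path is assumed from the repo layout and every credential value is a placeholder:

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider

try:
    AzureOpenAIProvider.is_model_credentials_valid_or_raise(
        model_name='my-gpt35-deployment',  # the Azure deployment name, not the base model
        model_type=ModelType.TEXT_GENERATION,
        credentials={
            'openai_api_key': 'sk-...',
            'openai_api_base': 'https://my-resource.openai.azure.com',
            'base_model_name': 'gpt-35-turbo',
        },
    )
except CredentialsValidateFailedError as e:
    print(f'Deployment rejected: {e}')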

View File

@@ -0,0 +1,283 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Type, Optional
from flask import current_app
from pydantic import BaseModel
from core.model_providers.error import QuotaExceededError, LLMBadRequestError
from extensions.ext_database import db
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.models.entity.provider import ProviderQuotaUnit
from core.model_providers.rules import provider_rules
from models.provider import Provider, ProviderType, ProviderModel
class BaseModelProvider(BaseModel, ABC):
provider: Provider
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
@abstractmethod
def provider_name(self):
"""
Returns the name of a provider.
"""
raise NotImplementedError
def get_rules(self):
"""
Returns the rules of a provider.
"""
return provider_rules[self.provider_name]
def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
"""
get supported model object list for use.
:param model_type:
:return:
"""
rules = self.get_rules()
if 'custom' not in rules['support_provider_types']:
return self._get_fixed_model_list(model_type)
if 'model_flexibility' not in rules:
return self._get_fixed_model_list(model_type)
if rules['model_flexibility'] == 'fixed':
return self._get_fixed_model_list(model_type)
# get configurable provider models
provider_models = db.session.query(ProviderModel).filter(
ProviderModel.tenant_id == self.provider.tenant_id,
ProviderModel.provider_name == self.provider.provider_name,
ProviderModel.model_type == model_type.value,
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()
return [{
'id': provider_model.model_name,
'name': provider_model.model_name
} for provider_model in provider_models]
@abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
"""
get supported model object list for use.
:param model_type:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_model_class(self, model_type: ModelType) -> Type:
"""
get specific model class.
:param model_type:
:return:
"""
raise NotImplementedError
@classmethod
@abstractmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
check provider credentials valid.
:param credentials:
"""
raise NotImplementedError
@classmethod
@abstractmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
"""
encrypt provider credentials for save.
:param tenant_id:
:param credentials:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param obfuscated:
:return:
"""
raise NotImplementedError
@classmethod
@abstractmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
raise NotImplementedError
@classmethod
@abstractmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
raise NotImplementedError
@classmethod
def is_provider_type_system_supported(cls) -> bool:
return current_app.config['EDITION'] == 'CLOUD'
def check_quota_over_limit(self):
"""
check provider quota over limit.
:return:
"""
if self.provider.provider_type != ProviderType.SYSTEM.value:
return
rules = self.get_rules()
if 'system' not in rules['support_provider_types']:
return
provider = db.session.query(Provider).filter(
db.and_(
Provider.id == self.provider.id,
Provider.is_valid == True,
Provider.quota_limit > Provider.quota_used
)
).first()
if not provider:
raise QuotaExceededError()
def deduct_quota(self, used_tokens: int = 0) -> None:
"""
deduct available quota when provider type is system or paid.
:return:
"""
if self.provider.provider_type != ProviderType.SYSTEM.value:
return
rules = self.get_rules()
if 'system' not in rules['support_provider_types']:
return
if not self.should_deduct_quota():
return
if 'system_config' not in rules:
quota_unit = ProviderQuotaUnit.TIMES.value
elif 'quota_unit' not in rules['system_config']:
quota_unit = ProviderQuotaUnit.TIMES.value
else:
quota_unit = rules['system_config']['quota_unit']
if quota_unit == ProviderQuotaUnit.TOKENS.value:
used_quota = used_tokens
else:
used_quota = 1
db.session.query(Provider).filter(
Provider.tenant_id == self.provider.tenant_id,
Provider.provider_name == self.provider.provider_name,
Provider.provider_type == self.provider.provider_type,
Provider.quota_type == self.provider.quota_type,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + used_quota})
db.session.commit()
def should_deduct_quota(self):
return False
def update_last_used(self) -> None:
"""
update last used time.
:return:
"""
db.session.query(Provider).filter(
Provider.tenant_id == self.provider.tenant_id,
Provider.provider_name == self.provider.provider_name
).update({'last_used': datetime.utcnow()})
db.session.commit()
def get_payment_info(self) -> Optional[dict]:
"""
        get product info if it is payable.
:return:
"""
return None
def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
"""
get provider model.
:param model_name:
:param model_type:
:return:
"""
provider_model = db.session.query(ProviderModel).filter(
ProviderModel.tenant_id == self.provider.tenant_id,
ProviderModel.provider_name == self.provider.provider_name,
ProviderModel.model_name == model_name,
ProviderModel.model_type == model_type.value,
ProviderModel.is_valid == True
).first()
if not provider_model:
raise LLMBadRequestError(f"The model {model_name} does not exist. "
f"Please check the configuration.")
return provider_model
class CredentialsValidateFailedError(Exception):
pass
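
A minimal sketch of the quota-unit resolution inside deduct_quota, extracted as a pure function; the 'times' and 'tokens' strings are assumed to match ProviderQuotaUnit's values:

def resolve_used_quota(rules: dict, used_tokens: int) -> int:
    # Mirror deduct_quota: default to per-call ('times') accounting unless
    # the provider rules opt into token-based quotas via system_config.
    system_config = rules.get('system_config') or {}
    quota_unit = system_config.get('quota_unit', 'times')
    return used_tokens if quota_unit == 'tokens' else 1

assert resolve_used_quota({}, 987) == 1
assert resolve_used_quota({'system_config': {'quota_unit': 'tokens'}}, 987) == 987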

View File

@@ -0,0 +1,157 @@
import json
from json import JSONDecodeError
from typing import Type
from langchain.llms import ChatGLM
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType
class ChatGLMProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'chatglm'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B',
},
{
'id': 'chatglm-6b',
'name': 'ChatGLM-6B',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = ChatGLMModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
model_max_tokens = {
'chatglm-6b': 2000,
'chatglm2-6b': 32000,
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_base' not in credentials:
raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')
try:
credential_kwargs = {
'endpoint_url': credentials['api_base']
}
llm = ChatGLM(
max_token=10,
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_base'] = encrypter.encrypt_token(tenant_id, credentials['api_base'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_base': None
}
if credentials['api_base']:
credentials['api_base'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_base']
)
if obfuscated:
credentials['api_base'] = encrypter.obfuscated_token(credentials['api_base'])
return credentials
return {}
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
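
A quick validation sketch for a self-hosted ChatGLM endpoint; the module path is assumed from the repo layout and the URL is a placeholder:

from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.chatglm_provider import ChatGLMProvider

try:
    ChatGLMProvider.is_provider_credentials_valid_or_raise(
        {'api_base': 'http://127.0.0.1:8000'}
    )
except CredentialsValidateFailedError as e:
    print(f'ChatGLM endpoint rejected: {e}')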

View File

@@ -0,0 +1,76 @@
import os
from typing import Optional
import langchain
from flask import Flask
from pydantic import BaseModel
class HostedOpenAI(BaseModel):
api_base: str = None
api_organization: str = None
api_key: str
quota_limit: int = 0
"""Quota limit for the openai hosted model. 0 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1
class HostedAzureOpenAI(BaseModel):
api_base: str
api_key: str
quota_limit: int = 0
"""Quota limit for the azure openai hosted model. 0 means unlimited."""
class HostedAnthropic(BaseModel):
api_base: str = None
api_key: str
quota_limit: int = 0
"""Quota limit for the anthropic hosted model. 0 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1
class HostedModelProviders(BaseModel):
openai: Optional[HostedOpenAI] = None
azure_openai: Optional[HostedAzureOpenAI] = None
anthropic: Optional[HostedAnthropic] = None
hosted_model_providers = HostedModelProviders()
def init_app(app: Flask):
    if os.environ.get("DEBUG", "").lower() == 'true':
langchain.verbose = True
if app.config.get("HOSTED_OPENAI_ENABLED"):
hosted_model_providers.openai = HostedOpenAI(
api_base=app.config.get("HOSTED_OPENAI_API_BASE"),
api_organization=app.config.get("HOSTED_OPENAI_API_ORGANIZATION"),
api_key=app.config.get("HOSTED_OPENAI_API_KEY"),
quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"),
paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"),
)
if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"):
hosted_model_providers.azure_openai = HostedAzureOpenAI(
api_base=app.config.get("HOSTED_AZURE_OPENAI_API_BASE"),
api_key=app.config.get("HOSTED_AZURE_OPENAI_API_KEY"),
quota_limit=app.config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT"),
)
if app.config.get("HOSTED_ANTHROPIC_ENABLED"):
hosted_model_providers.anthropic = HostedAnthropic(
api_base=app.config.get("HOSTED_ANTHROPIC_API_BASE"),
api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"),
quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"),
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
)
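
A sketch of wiring init_app from Flask config, assuming the standard module path; the key is a placeholder, and the unset paid fields are passed explicitly so the pydantic field defaults validate:

from flask import Flask
from core.model_providers.providers import hosted

app = Flask(__name__)
app.config.update(
    HOSTED_OPENAI_ENABLED=True,
    HOSTED_OPENAI_API_BASE=None,
    HOSTED_OPENAI_API_ORGANIZATION=None,
    HOSTED_OPENAI_API_KEY='sk-...',   # placeholder
    HOSTED_OPENAI_QUOTA_LIMIT=200,    # 0 would mean unlimited
    HOSTED_OPENAI_PAID_ENABLED=False,
    HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=None,
    HOSTED_OPENAI_PAID_INCREASE_QUOTA=1,
)
hosted.init_app(app)
assert hosted.hosted_model_providers.openai is not None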

View File

@@ -0,0 +1,183 @@
import json
from typing import Type
from huggingface_hub import HfApi
from langchain.llms import HuggingFaceEndpoint
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from models.provider import ProviderType
class HuggingfaceHubProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'huggingface_hub'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = HuggingfaceHubModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=1500, default=200),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if model_type != ModelType.TEXT_GENERATION:
raise NotImplementedError
if 'huggingfacehub_api_type' not in credentials \
or credentials['huggingfacehub_api_type'] not in ['hosted_inference_api', 'inference_endpoints']:
            raise CredentialsValidateFailedError('Hugging Face Hub API Type is invalid; '
                                                 'it must be hosted_inference_api or inference_endpoints.')
if 'huggingfacehub_api_token' not in credentials:
raise CredentialsValidateFailedError('Hugging Face Hub API Token must be provided.')
hfapi = HfApi(token=credentials['huggingfacehub_api_token'])
try:
hfapi.whoami()
except Exception:
raise CredentialsValidateFailedError("Invalid API Token.")
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
if 'huggingfacehub_endpoint_url' not in credentials:
raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
try:
llm = HuggingFaceEndpoint(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task="text2text-generation",
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)
llm("ping")
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
else:
try:
model_info = hfapi.model_info(repo_id=model_name)
if not model_info:
raise ValueError(f'Model {model_name} not found.')
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
if model_info.pipeline_tag not in VALID_TASKS:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {VALID_TASKS}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['huggingfacehub_api_token'] = encrypter.encrypt_token(tenant_id, credentials['huggingfacehub_api_token'])
if credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
hfapi = HfApi(token=credentials['huggingfacehub_api_token'])
model_info = hfapi.model_info(repo_id=model_name)
if not model_info:
raise ValueError(f'Model {model_name} not found.')
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
credentials['task_type'] = model_info.pipeline_tag
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'huggingfacehub_api_token': None,
'task_type': None
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['huggingfacehub_api_token']:
credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['huggingfacehub_api_token']
)
if obfuscated:
credentials['huggingfacehub_api_token'] = encrypter.obfuscated_token(credentials['huggingfacehub_api_token'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}
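
For reference, the two credential shapes the validator above accepts; the token and endpoint URL are placeholders:

hosted_inference = {
    'huggingfacehub_api_type': 'hosted_inference_api',
    'huggingfacehub_api_token': 'hf_...',
}
inference_endpoint = {
    'huggingfacehub_api_type': 'inference_endpoints',
    'huggingfacehub_api_token': 'hf_...',
    'huggingfacehub_endpoint_url': 'https://xxxx.endpoints.huggingface.cloud',
}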

View File

@@ -0,0 +1,179 @@
import json
from json import JSONDecodeError
from typing import Type
from langchain.llms import Minimax
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType, ProviderQuotaType
class MinimaxProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'minimax'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'abab5.5-chat',
'name': 'abab5.5-chat',
},
{
'id': 'abab5-chat',
'name': 'abab5-chat',
}
]
elif model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'embo-01',
'name': 'embo-01',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = MinimaxModel
elif model_type == ModelType.EMBEDDINGS:
model_class = MinimaxEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
model_max_tokens = {
'abab5.5-chat': 16384,
'abab5-chat': 6144,
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.9),
top_p=KwargRule[float](min=0, max=1, default=0.95),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'minimax_group_id' not in credentials:
raise CredentialsValidateFailedError('MiniMax Group ID must be provided.')
if 'minimax_api_key' not in credentials:
raise CredentialsValidateFailedError('MiniMax API Key must be provided.')
try:
credential_kwargs = {
'minimax_group_id': credentials['minimax_group_id'],
'minimax_api_key': credentials['minimax_api_key'],
}
llm = Minimax(
model='abab5.5-chat',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['minimax_api_key'] = encrypter.encrypt_token(tenant_id, credentials['minimax_api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value \
or (self.provider.provider_type == ProviderType.SYSTEM.value
and self.provider.quota_type == ProviderQuotaType.FREE.value):
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'minimax_group_id': None,
'minimax_api_key': None,
}
if credentials['minimax_api_key']:
credentials['minimax_api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['minimax_api_key']
)
if obfuscated:
credentials['minimax_api_key'] = encrypter.obfuscated_token(credentials['minimax_api_key'])
return credentials
return {}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
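
A standalone check of the token ceilings used above; note that an unrecognized model name falls back to the 6144 default:

model_max_tokens = {'abab5.5-chat': 16384, 'abab5-chat': 6144}
assert model_max_tokens.get('abab5.5-chat', 6144) == 16384
assert model_max_tokens.get('abab6-chat', 6144) == 6144  # hypothetical model name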

View File

@@ -0,0 +1,289 @@
import json
import logging
from json import JSONDecodeError
from typing import Type, Optional
from flask import current_app
from openai.error import AuthenticationError, OpenAIError
import openai
from core.helper import encrypter
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
from models.provider import ProviderType, ProviderQuotaType
class OpenAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'openai'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
models = [
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4',
'name': 'gpt-4',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
}
]
if self.provider.provider_type == ProviderType.SYSTEM.value \
and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]
return models
elif model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'text-embedding-ada-002',
'name': 'text-embedding-ada-002'
}
]
elif model_type == ModelType.SPEECH_TO_TEXT:
return [
{
'id': 'whisper-1',
'name': 'whisper-1'
}
]
elif model_type == ModelType.MODERATION:
return [
{
'id': 'text-moderation-stable',
'name': 'text-moderation-stable'
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = OpenAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = OpenAIEmbedding
elif model_type == ModelType.MODERATION:
model_class = OpenAIModeration
elif model_type == ModelType.SPEECH_TO_TEXT:
model_class = OpenAIWhisper
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
model_max_tokens = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'openai_api_key' not in credentials:
raise CredentialsValidateFailedError('OpenAI API key is required')
try:
credentials_kwargs = {
"api_key": credentials['openai_api_key']
}
if 'openai_api_base' in credentials and credentials['openai_api_base']:
credentials_kwargs['api_base'] = credentials['openai_api_base'] + '/v1'
if 'openai_organization' in credentials:
credentials_kwargs['organization'] = credentials['openai_organization']
openai.ChatCompletion.create(
messages=[{"role": "user", "content": 'ping'}],
model='gpt-3.5-turbo',
timeout=10,
request_timeout=(5, 30),
max_tokens=20,
**credentials_kwargs
)
except (AuthenticationError, OpenAIError) as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('OpenAI config validation failed')
raise ex
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['openai_api_key'] = encrypter.encrypt_token(tenant_id, credentials['openai_api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'openai_api_base': None,
'openai_api_key': self.provider.encrypted_config,
'openai_organization': None
}
if credentials['openai_api_key']:
credentials['openai_api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['openai_api_key']
)
if obfuscated:
credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key'])
if 'openai_api_base' not in credentials or not credentials['openai_api_base']:
credentials['openai_api_base'] = None
else:
credentials['openai_api_base'] = credentials['openai_api_base'] + '/v1'
if 'openai_organization' not in credentials:
credentials['openai_organization'] = None
return credentials
else:
if hosted_model_providers.openai:
return {
'openai_api_base': hosted_model_providers.openai.api_base,
'openai_api_key': hosted_model_providers.openai.api_key,
'openai_organization': hosted_model_providers.openai.api_organization
}
else:
return {
'openai_api_base': None,
'openai_api_key': None,
'openai_organization': None
}
@classmethod
def is_provider_type_system_supported(cls) -> bool:
if current_app.config['EDITION'] != 'CLOUD':
return False
if hosted_model_providers.openai:
return True
return False
def should_deduct_quota(self):
if hosted_model_providers.openai \
and hosted_model_providers.openai.quota_limit and hosted_model_providers.openai.quota_limit > 0:
return True
return False
def get_payment_info(self) -> Optional[dict]:
"""
        get payment info if it is payable.
:return:
"""
if hosted_model_providers.openai \
and hosted_model_providers.openai.paid_enabled:
return {
'product_id': hosted_model_providers.openai.paid_stripe_price_id,
'increase_quota': hosted_model_providers.openai.paid_increase_quota,
}
return None
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
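
A hedged validation sketch; the module path is assumed from the repo layout and the key is a placeholder. Note that when a custom openai_api_base is configured, the provider appends '/v1' before handing it to the SDK:

from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.openai_provider import OpenAIProvider

try:
    OpenAIProvider.is_provider_credentials_valid_or_raise(
        {'openai_api_key': 'sk-...'}
    )
except CredentialsValidateFailedError as e:
    print(f'OpenAI key rejected: {e}')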

View File

@@ -0,0 +1,184 @@
import json
import logging
from typing import Type
import replicate
from replicate.exceptions import ReplicateError
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
from models.provider import ProviderType
class ReplicateProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'replicate'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = ReplicateModel
elif model_type == ModelType.EMBEDDINGS:
model_class = ReplicateEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
model_credentials = self.get_model_credentials(model_name, model_type)
model = replicate.Client(api_token=model_credentials.get("replicate_api_token")).models.get(model_name)
try:
version = model.versions.get(model_credentials['model_version'])
except ReplicateError as e:
raise CredentialsValidateFailedError(f"Model {model_name}:{model_credentials['model_version']} not exists, "
f"cause: {e.__class__.__name__}:{str(e)}")
except Exception as e:
logging.exception("Model validate failed.")
raise e
model_kwargs_rules = ModelKwargsRules()
for key, value in version.openapi_schema['components']['schemas']['Input']['properties'].items():
if key not in ['debug', 'prompt'] and value['type'] in ['number', 'integer']:
                if key in ['temperature', 'top_p']:
kwarg_rule = KwargRule[float](
type=KwargRuleType.FLOAT.value if value['type'] == 'number' else KwargRuleType.INTEGER.value,
min=float(value.get('minimum')) if value.get('minimum') is not None else None,
max=float(value.get('maximum')) if value.get('maximum') is not None else None,
default=float(value.get('default')) if value.get('default') is not None else None,
)
if key == 'temperature':
model_kwargs_rules.temperature = kwarg_rule
else:
model_kwargs_rules.top_p = kwarg_rule
elif key in ['max_length', 'max_new_tokens']:
model_kwargs_rules.max_tokens = KwargRule[int](
alias=key,
type=KwargRuleType.INTEGER.value,
min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
default=int(value.get('default')) if value.get('default') is not None else 500,
)
return model_kwargs_rules
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'replicate_api_token' not in credentials:
raise CredentialsValidateFailedError('Replicate API Key must be provided.')
if 'model_version' not in credentials:
raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
if model_name.count("/") != 1:
            raise CredentialsValidateFailedError('Replicate Model Name must be in the format '
                                                 '{user_name}/{model_name}.')
version = credentials['model_version']
try:
model = replicate.Client(api_token=credentials.get("replicate_api_token")).models.get(model_name)
rst = model.versions.get(version)
if model_type == ModelType.EMBEDDINGS \
and 'Embedding' not in rst.openapi_schema['components']['schemas']:
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
elif model_type == ModelType.TEXT_GENERATION \
and ('type' not in rst.openapi_schema['components']['schemas']['Output']['items']
or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
except ReplicateError as e:
raise CredentialsValidateFailedError(
f"Model {model_name}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}")
except Exception as e:
logging.exception("Replicate config validation failed.")
raise e
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['replicate_api_token'] = encrypter.encrypt_token(tenant_id, credentials['replicate_api_token'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'replicate_api_token': None,
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['replicate_api_token']:
credentials['replicate_api_token'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['replicate_api_token']
)
if obfuscated:
credentials['replicate_api_token'] = encrypter.obfuscated_token(credentials['replicate_api_token'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}
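
A self-contained sketch of the schema-to-rules mapping in get_model_parameter_rules, run against a hypothetical Replicate Input schema snippet:

input_properties = {
    'temperature': {'type': 'number', 'minimum': 0.01, 'maximum': 5, 'default': 0.75},
    'top_p': {'type': 'number', 'minimum': 0, 'maximum': 1, 'default': 0.9},
    'max_new_tokens': {'type': 'integer', 'minimum': 1, 'maximum': 4096, 'default': 500},
    'prompt': {'type': 'string'},
}
for key, value in input_properties.items():
    if key not in ['debug', 'prompt'] and value['type'] in ['number', 'integer']:
        if key in ['temperature', 'top_p']:
            print(f"{key}: min={value.get('minimum')} max={value.get('maximum')}")
        elif key in ['max_length', 'max_new_tokens']:
            print(f"max_tokens (alias {key}): max={value.get('maximum')}")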

View File

@@ -0,0 +1,191 @@
import json
import logging
from json import JSONDecodeError
from typing import Type
from flask import current_app
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark
from core.third_party.spark.spark_llm import SparkError
from models.provider import ProviderType, ProviderQuotaType
class SparkProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'spark'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'spark',
'name': '星火认知大模型',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = SparkModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=0.5),
top_p=KwargRule[float](enabled=False),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4096, default=2048),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'app_id' not in credentials:
raise CredentialsValidateFailedError('Spark app_id must be provided.')
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Spark api_key must be provided.')
if 'api_secret' not in credentials:
raise CredentialsValidateFailedError('Spark api_secret must be provided.')
try:
credential_kwargs = {
'app_id': credentials['app_id'],
'api_key': credentials['api_key'],
'api_secret': credentials['api_secret'],
}
chat_llm = ChatSpark(
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
credentials['api_secret'] = encrypter.encrypt_token(tenant_id, credentials['api_secret'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value \
or (self.provider.provider_type == ProviderType.SYSTEM.value
and self.provider.quota_type == ProviderQuotaType.FREE.value):
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'app_id': None,
'api_key': None,
'api_secret': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
if credentials['api_secret']:
credentials['api_secret'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_secret']
)
if obfuscated:
credentials['api_secret'] = encrypter.obfuscated_token(credentials['api_secret'])
return credentials
else:
return {
'app_id': None,
'api_key': None,
'api_secret': None,
}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
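
A validation sketch for Spark; the module path is assumed from the repo layout and all three fields are placeholders (both api_key and api_secret are encrypted at rest, as shown above):

from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.spark_provider import SparkProvider

try:
    SparkProvider.is_provider_credentials_valid_or_raise({
        'app_id': '<app-id>',
        'api_key': '<api-key>',
        'api_secret': '<api-secret>',
    })
except CredentialsValidateFailedError as e:
    print(f'Spark credentials rejected: {e}')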

View File

@@ -0,0 +1,157 @@
import json
from json import JSONDecodeError
from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
from models.provider import ProviderType
class TongyiProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'tongyi'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'qwen-v1',
'name': 'qwen-v1',
},
{
'id': 'qwen-plus-v1',
'name': 'qwen-plus-v1',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = TongyiModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
model_max_tokens = {
'qwen-v1': 1500,
'qwen-plus-v1': 6500
}
return ModelKwargsRules(
temperature=KwargRule[float](enabled=False),
top_p=KwargRule[float](min=0, max=1, default=0.8),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'dashscope_api_key' not in credentials:
raise CredentialsValidateFailedError('Dashscope API Key must be provided.')
try:
credential_kwargs = {
'dashscope_api_key': credentials['dashscope_api_key']
}
llm = EnhanceTongyi(
model_name='qwen-v1',
max_retries=1,
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['dashscope_api_key'] = encrypter.encrypt_token(tenant_id, credentials['dashscope_api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'dashscope_api_key': None
}
if credentials['dashscope_api_key']:
credentials['dashscope_api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['dashscope_api_key']
)
if obfuscated:
credentials['dashscope_api_key'] = encrypter.obfuscated_token(credentials['dashscope_api_key'])
return credentials
return {}
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
Check whether the given model credentials are valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
Encrypt model credentials for storage.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
Get credentials for LLM use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
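A hedged sketch of how a caller could consume the rules above to clamp a user-supplied max_tokens for qwen-plus-v1. The KwargRule fields (enabled, min, max, default) follow the constructor calls in this file; the clamp helper itself and the `tongyi_provider` instance are illustrative, not part of this commit:

from core.model_providers.models.entity.model_params import ModelType

rules = tongyi_provider.get_model_parameter_rules('qwen-plus-v1', ModelType.TEXT_GENERATION)

def clamp_max_tokens(requested: int) -> int:
    rule = rules.max_tokens
    if not rule.enabled:
        # disabled knobs carry no range; fall back to the rule default (may be None)
        return rule.default
    return max(rule.min, min(rule.max, requested))

clamp_max_tokens(100_000)  # -> 6500, the qwen-plus-v1 ceiling defined above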

View File

@@ -0,0 +1,182 @@
import json
from json import JSONDecodeError
from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.wenxin import Wenxin
from models.provider import ProviderType
class WenxinProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of the provider.
"""
return 'wenxin'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = WenxinModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
Get the parameter rules for the given model.
:param model_name:
:param model_type:
:return:
"""
if model_name in ['ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95),
top_p=KwargRule[float](min=0.01, max=1, default=0.8),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
else:
return ModelKwargsRules(
temperature=KwargRule[float](enabled=False),
top_p=KwargRule[float](enabled=False),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Wenxin api_key must be provided.')
if 'secret_key' not in credentials:
raise CredentialsValidateFailedError('Wenxin secret_key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key'],
'secret_key': credentials['secret_key'],
}
llm = Wenxin(
temperature=0.01,
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
'secret_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
if credentials['secret_key']:
credentials['secret_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['secret_key']
)
if obfuscated:
credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])
return credentials
else:
return {
'api_key': None,
'secret_key': None,
}
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
Check whether the given model credentials are valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
Encrypt model credentials for storage.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
Get credentials for LLM use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
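Usage sketch for the validation entry point above; the key values are placeholders, and CredentialsValidateFailedError is the exception imported at the top of this file:

credentials = {"api_key": "ak-placeholder", "secret_key": "sk-placeholder"}
try:
    WenxinProvider.is_provider_credentials_valid_or_raise(credentials)
except CredentialsValidateFailedError as ex:
    print(f"Wenxin credentials rejected: {ex}")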

View File

@@ -0,0 +1,47 @@
import json
import os
def init_provider_rules():
# Get the absolute path of the subdirectory
subdirectory_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'rules')
# Path to the providers.json file
providers_json_file_path = os.path.join(subdirectory_path, '_providers.json')
try:
# Open the JSON file and read its content
with open(providers_json_file_path, 'r') as json_file:
data = json.load(json_file)
# _providers.json holds the ordered list of provider names
provider_names = data
except FileNotFoundError:
return "JSON file not found or path error"
except json.JSONDecodeError:
return "JSON file decoding error"
# Dictionary to store the content of all JSON files
json_data = {}
try:
# Loop through all files in the directory
for provider_name in provider_names:
filename = provider_name + '.json'
# Path to each JSON file
json_file_path = os.path.join(subdirectory_path, filename)
# Open each JSON file and read its content
with open(json_file_path, 'r') as json_file:
data = json.load(json_file)
# Store the content in the dictionary with the key as the file name (without extension)
json_data[os.path.splitext(filename)[0]] = data
return json_data
except FileNotFoundError:
return "JSON file not found or path error"
except json.JSONDecodeError:
return "JSON file decoding error"
provider_rules = init_provider_rules()
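Given the rule files that follow, provider_rules maps each name from _providers.json to its parsed rule dict. The rule filenames are not shown in this truncated diff, so which rule belongs to which provider is an assumption; a sketch of a lookup:

rule = provider_rules["anthropic"]            # assumed name-to-file mapping
if "system" in rule["support_provider_types"]:
    quota = rule["system_config"]             # e.g. {"supported_quota_types": ["trial"], ...}
print(rule["model_flexibility"])              # "fixed" or "configurable"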

View File

@@ -0,0 +1,12 @@
[
"openai",
"azure_openai",
"anthropic",
"minimax",
"tongyi",
"spark",
"wenxin",
"chatglm",
"replicate",
"huggingface_hub"
]

View File

@@ -0,0 +1,15 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"trial",
"paid"
],
"quota_unit": "times",
"quota_limit": 1000
},
"model_flexibility": "fixed"
}

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
}

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@@ -0,0 +1,13 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"free"
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed"
}

View File

@@ -0,0 +1,14 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"trial"
],
"quota_unit": "times",
"quota_limit": 200
},
"model_flexibility": "fixed"
}

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@@ -0,0 +1,13 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"free"
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed"
}

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
}

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
}

View File

@@ -3,7 +3,6 @@ from typing import Optional
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatOpenAI
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field
@@ -15,7 +14,8 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
@@ -32,11 +32,9 @@ class OrchestratorRuleParser:
def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
self.tenant_id = tenant_id
self.app_model_config = app_model_config
self.agent_summary_model_name = "gpt-3.5-turbo-16k"
self.dataset_retrieve_model_name = "gpt-3.5-turbo"
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
-> Optional[AgentExecutor]:
if not self.app_model_config.agent_mode_dict:
return None
@@ -47,43 +45,50 @@ class OrchestratorRuleParser:
chain = None
if agent_mode_config and agent_mode_config.get('enabled'):
tool_configs = agent_mode_config.get('tools', [])
agent_provider_name = model_dict.get('provider', 'openai')
agent_model_name = model_dict.get('name', 'gpt-4')
agent_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_provider_name=agent_provider_name,
model_name=agent_model_name,
model_kwargs=ModelKwargs(
temperature=0.2,
top_p=0.3,
max_tokens=1500
)
)
# add agent callback to record agent thoughts
agent_callback = AgentLoopGatherCallbackHandler(
model_name=agent_model_name,
model_instant=agent_model_instance,
conversation_message_task=conversation_message_task
)
chain_callback.agent_callback = agent_callback
agent_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name=agent_model_name,
temperature=0,
max_tokens=1500,
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
)
agent_model_instance.add_callbacks([agent_callback])
planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
# Only OpenAI chat models (including Azure) support function calling; use ReAct instead
if not isinstance(agent_llm, ChatOpenAI) \
and planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
planning_strategy = PlanningStrategy.REACT
if agent_model_instance.model_mode != ModelMode.CHAT \
or agent_model_instance.name not in ['openai', 'azure_openai']:
if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
planning_strategy = PlanningStrategy.REACT
elif planning_strategy == PlanningStrategy.ROUTER:
planning_strategy = PlanningStrategy.REACT_ROUTER
summary_llm = LLMBuilder.to_llm(
summary_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_name=self.agent_summary_model_name,
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
model_kwargs=ModelKwargs(
temperature=0,
max_tokens=500
)
)
tools = self.to_tools(
tool_configs=tool_configs,
conversation_message_task=conversation_message_task,
model_name=self.agent_summary_model_name,
rest_tokens=rest_tokens,
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
)
@@ -91,20 +96,11 @@ class OrchestratorRuleParser:
if len(tools) == 0:
return None
dataset_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name=self.dataset_retrieve_model_name,
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
)
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
llm=agent_llm,
model_instance=agent_model_instance,
tools=tools,
summary_llm=summary_llm,
dataset_llm=dataset_llm,
summary_model_instance=summary_model_instance,
memory=memory,
callbacks=[chain_callback, agent_callback],
max_iterations=10,
@@ -141,13 +137,12 @@ class OrchestratorRuleParser:
return None
def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
model_name: str, rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
"""
Convert app agent tool configs to tools
:param rest_tokens:
:param tool_configs: app agent tool configs
:param model_name:
:param conversation_message_task:
:param callbacks:
:return:
@@ -163,7 +158,7 @@ class OrchestratorRuleParser:
if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
elif tool_type == "web_reader":
tool = self.to_web_reader_tool(model_name)
tool = self.to_web_reader_tool()
elif tool_type == "google_search":
tool = self.to_google_search_tool()
elif tool_type == "wikipedia":
@@ -205,20 +200,22 @@ class OrchestratorRuleParser:
return tool
def to_web_reader_tool(self, model_name: str) -> Optional[BaseTool]:
def to_web_reader_tool(self) -> Optional[BaseTool]:
"""
A tool for reading web pages
:return:
"""
summary_llm = LLMBuilder.to_llm(
summary_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_name=model_name,
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
model_kwargs=ModelKwargs(
temperature=0,
max_tokens=500
)
)
summary_llm = summary_model_instance.client
tool = WebReaderTool(
llm=summary_llm,
max_chunk_length=4000,
@@ -273,6 +270,10 @@ class OrchestratorRuleParser:
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
if rest_tokens == -1:
return DEFAULT_K
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
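The body of _dynamic_calc_retrieve_k is cut off here; a hedged reconstruction of the budget math the constants suggest: reserve roughly 30% of the remaining tokens for retrieved context and divide by the dataset's segment size. The segment-size parameter and the rounding are assumptions, not the committed code:

import math

def dynamic_calc_retrieve_k(rest_tokens: int, segment_max_tokens: int) -> int:
    DEFAULT_K = 2
    CONTEXT_TOKENS_PERCENT = 0.3
    if rest_tokens == -1:  # sentinel for models without a computed budget
        return DEFAULT_K
    context_budget = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
    # how many segments of this size fit into the context budget
    return max(DEFAULT_K, context_budget // segment_max_tokens)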

View File

View File

@@ -0,0 +1,99 @@
"""Wrapper around Replicate embedding models."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
class ReplicateEmbeddings(BaseModel, Embeddings):
"""Wrapper around Replicate embedding models.
To use, you should have the ``replicate`` python package installed.
"""
client: Any #: :meta private:
model: str
"""Model name to use."""
replicate_api_token: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
replicate_api_token = get_from_dict_or_env(
values, "replicate_api_token", "REPLICATE_API_TOKEN"
)
try:
import replicate as replicate_python
values["client"] = replicate_python.Client(api_token=replicate_api_token)
except ImportError:
raise ImportError(
"Could not import replicate python package. "
"Please install it with `pip install replicate`."
)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to Replicate's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
# get the model and version
model_str, version_str = self.model.split(":")
model = self.client.models.get(model_str)
version = model.versions.get(version_str)
# sort through the openapi schema to get the name of the first input
input_properties = sorted(
version.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
)
first_input_name = input_properties[0][0]
embeddings = []
for text in texts:
result = self.client.run(self.model, input={first_input_name: text})
embeddings.append(result[0].get('embedding'))
return [list(map(float, e)) for e in embeddings]
def embed_query(self, text: str) -> List[float]:
"""Call out to Replicate's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
# get the model and version
model_str, version_str = self.model.split(":")
model = self.client.models.get(model_str)
version = model.versions.get(version_str)
# sort through the openapi schema to get the name of the first input
input_properties = sorted(
version.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
)
first_input_name = input_properties[0][0]
result = self.client.run(self.model, input={first_input_name: text})
embedding = result[0].get('embedding')
return list(map(float, embedding))
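Usage sketch; the model string is a placeholder owner/model:version id, and REPLICATE_API_TOKEN must be set in the environment (or passed as replicate_api_token):

embedder = ReplicateEmbeddings(model="owner/some-embedding-model:0123abcd")  # placeholder id
doc_vectors = embedder.embed_documents(["first chunk", "second chunk"])
query_vector = embedder.embed_query("a search query")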

View File

View File

@@ -1,15 +1,13 @@
from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
from langchain.chat_models.openai import _convert_dict_to_message
from langchain.schema import BaseMessage, LLMResult, ChatResult, ChatGeneration
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any, Tuple, Union
from typing import Dict, Any, Optional, List, Tuple, Union
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models import AzureChatOpenAI
from langchain.chat_models.openai import _convert_dict_to_message
from langchain.schema import ChatResult, BaseMessage, ChatGeneration
from pydantic import root_validator
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureChatOpenAI(AzureChatOpenAI):
class EnhanceAzureChatOpenAI(AzureChatOpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 1
@@ -52,32 +50,6 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}
@handle_openai_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop, callbacks, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
model_kwargs = {
'top_p': params.get('top_p', 1),
'frequency_penalty': params.get('frequency_penalty', 0),
'presence_penalty': params.get('presence_penalty', 0),
}
del params['top_p']
del params['frequency_penalty']
del params['presence_penalty']
params['model_kwargs'] = model_kwargs
return params
def _generate(
self,
messages: List[BaseMessage],
@@ -116,4 +88,4 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
)
return ChatResult(generations=[ChatGeneration(message=message)])
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
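The (5.0, 300.0) default above is a requests-style (connect, read) timeout tuple: fail fast when Azure is unreachable, but allow long generations to stream back. A hedged instantiation; deployment, endpoint, and key are placeholders:

llm = EnhanceAzureChatOpenAI(
    deployment_name="gpt-35-turbo",                      # placeholder deployment
    openai_api_base="https://example.openai.azure.com",  # placeholder endpoint
    openai_api_key="azure-key-placeholder",
    openai_api_version="2023-05-15",
    request_timeout=(5.0, 300.0),
    max_retries=1,
)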

View File

@@ -1,16 +1,14 @@
from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
from typing import Dict, Any, Mapping, Optional, List, Union, Tuple
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import AzureOpenAI
from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
update_token_usage
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
from pydantic import root_validator
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureOpenAI(AzureOpenAI):
class EnhanceAzureOpenAI(AzureOpenAI):
openai_api_type: str = "azure"
openai_api_version: str = ""
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
@@ -56,20 +54,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}}
@handle_openai_exceptions
def generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(prompts, stop, callbacks, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
return params
def _generate(
self,
prompts: List[str],

Some files were not shown because too many files have changed in this diff.