mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 09:19:22 +08:00
feat: server multi models support (#799)
This commit is contained in:
0
api/core/third_party/langchain/embeddings/__init__.py
vendored
Normal file
0
api/core/third_party/langchain/embeddings/__init__.py
vendored
Normal file
99
api/core/third_party/langchain/embeddings/replicate_embedding.py
vendored
Normal file
99
api/core/third_party/langchain/embeddings/replicate_embedding.py
vendored
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Wrapper around Replicate embedding models."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class ReplicateEmbeddings(BaseModel, Embeddings):
    """Wrapper around Replicate embedding models.

    To use, you should have the ``replicate`` python package installed and a
    Replicate API token available (``REPLICATE_API_TOKEN`` env var, or the
    ``replicate_api_token`` constructor argument).
    """

    client: Any  #: :meta private:
    model: str
    """Model name to use."""

    replicate_api_token: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        replicate_api_token = get_from_dict_or_env(
            values, "replicate_api_token", "REPLICATE_API_TOKEN"
        )
        try:
            import replicate as replicate_python

            values["client"] = replicate_python.Client(api_token=replicate_api_token)
        except ImportError:
            raise ImportError(
                "Could not import replicate python package. "
                "Please install it with `pip install replicate`."
            )
        return values

    def _first_input_name(self) -> str:
        """Resolve the name of the model's first input property.

        Looks up the model version's OpenAPI schema and returns the input
        property with the lowest ``x-order`` — the one that receives the text.
        Shared by :meth:`embed_documents` and :meth:`embed_query`, which
        previously duplicated this lookup.
        """
        # get the model and version
        model_str, version_str = self.model.split(":")
        model = self.client.models.get(model_str)
        version = model.versions.get(version_str)

        # sort through the openapi schema to get the name of the first input
        input_properties = sorted(
            version.openapi_schema["components"]["schemas"]["Input"][
                "properties"
            ].items(),
            key=lambda item: item[1].get("x-order", 0),
        )
        return input_properties[0][0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Replicate's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # Resolve the input name once per call, not once per text.
        first_input_name = self._first_input_name()

        embeddings = []
        for text in texts:
            result = self.client.run(self.model, input={first_input_name: text})
            embeddings.append(result[0].get('embedding'))

        return [list(map(float, e)) for e in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Call out to Replicate's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        first_input_name = self._first_input_name()
        result = self.client.run(self.model, input={first_input_name: text})
        embedding = result[0].get('embedding')

        return list(map(float, embedding))
|
||||
0
api/core/third_party/langchain/llms/__init__.py
vendored
Normal file
0
api/core/third_party/langchain/llms/__init__.py
vendored
Normal file
91
api/core/third_party/langchain/llms/azure_chat_open_ai.py
vendored
Normal file
91
api/core/third_party/langchain/llms/azure_chat_open_ai.py
vendored
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import Dict, Any, Optional, List, Tuple, Union
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
from langchain.chat_models.openai import _convert_dict_to_message
|
||||
from langchain.schema import ChatResult, BaseMessage, ChatGeneration
|
||||
from pydantic import root_validator
|
||||
|
||||
|
||||
class EnhanceAzureChatOpenAI(AzureChatOpenAI):
    # Azure chat wrapper with a tighter timeout and a single retry; overrides
    # _generate so streamed responses also accumulate function_call deltas.
    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        # Connection settings are passed per request rather than via the
        # openai module's global configuration.
        return {
            **super()._default_params,
            "engine": self.deployment_name,
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result, streaming token-by-token when enabled."""
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        if self.streaming:
            inner_completion = ""
            role = "assistant"
            params["stream"] = True
            function_call: Optional[dict] = None
            for stream_resp in self.completion_with_retry(
                messages=message_dicts, **params
            ):
                # Chunks can arrive with an empty choices list; ignore those.
                if len(stream_resp["choices"]) > 0:
                    role = stream_resp["choices"][0]["delta"].get("role", role)
                    token = stream_resp["choices"][0]["delta"].get("content") or ""
                    inner_completion += token
                    _function_call = stream_resp["choices"][0]["delta"].get("function_call")
                    if _function_call:
                        if function_call is None:
                            function_call = _function_call
                        else:
                            # Subsequent deltas only extend the arguments string.
                            function_call["arguments"] += _function_call["arguments"]
                    if run_manager:
                        run_manager.on_llm_new_token(token)
            # Rebuild one message from the accumulated content / function_call.
            message = _convert_dict_to_message(
                {
                    "content": inner_completion,
                    "role": role,
                    "function_call": function_call,
                }
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        response = self.completion_with_retry(messages=message_dicts, **params)
        return self._create_chat_result(response)
|
||||
110
api/core/third_party/langchain/llms/azure_open_ai.py
vendored
Normal file
110
api/core/third_party/langchain/llms/azure_open_ai.py
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
from typing import Dict, Any, Mapping, Optional, List, Union, Tuple
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import AzureOpenAI
|
||||
from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
|
||||
update_token_usage
|
||||
from langchain.schema import LLMResult
|
||||
from pydantic import root_validator
|
||||
|
||||
|
||||
class EnhanceAzureOpenAI(AzureOpenAI):
    # Azure completion wrapper with explicit per-request api_* parameters,
    # a tighter timeout and a single retry.
    openai_api_type: str = "azure"
    openai_api_version: str = ""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        # Connection settings travel with each request instead of relying on
        # the openai module's global state.
        return {**super()._invocation_params, **{
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {**super()._identifying_params, **{
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to OpenAI's endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The full LLM output.

        Example:
            .. code-block:: python

                response = openai.generate(["Tell me a joke."])
        """
        params = self._invocation_params
        params = {**params, **kwargs}
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")
                params["stream"] = True
                response = _streaming_response_template()
                for stream_resp in completion_with_retry(
                    self, prompt=_prompts, **params
                ):
                    # Chunks can arrive with an empty choices list; skip them.
                    if len(stream_resp["choices"]) > 0:
                        if run_manager:
                            run_manager.on_llm_new_token(
                                stream_resp["choices"][0]["text"],
                                verbose=self.verbose,
                                logprobs=stream_resp["choices"][0]["logprobs"],
                            )
                        _update_response(response, stream_resp)
                choices.extend(response["choices"])
            else:
                response = completion_with_retry(self, prompt=_prompts, **params)
                choices.extend(response["choices"])
            if not self.streaming:
                # Can't update token usage if streaming
                update_token_usage(_keys, response, token_usage)
        return self.create_llm_result(choices, prompts, token_usage)
|
||||
49
api/core/third_party/langchain/llms/chat_open_ai.py
vendored
Normal file
49
api/core/third_party/langchain/llms/chat_open_ai.py
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
|
||||
from typing import Dict, Any, Optional, Union, Tuple
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from pydantic import root_validator
|
||||
|
||||
|
||||
class EnhanceChatOpenAI(ChatOpenAI):
    """ChatOpenAI variant with a tighter request timeout and a single retry."""

    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            chat_completion = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        values["client"] = chat_completion

        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        # Start from the parent's defaults, then layer on explicit
        # per-request connection settings.
        params = dict(super()._default_params)
        params.update({
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization or None,
        })
        return params
|
||||
61
api/core/third_party/langchain/llms/fake.py
vendored
Normal file
61
api/core/third_party/langchain/llms/fake.py
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
import time
|
||||
from typing import List, Optional, Any, Mapping, Callable
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chat_models.base import SimpleChatModel
|
||||
from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration
|
||||
|
||||
from core.model_providers.models.entity.message import str_to_prompt_messages
|
||||
|
||||
|
||||
class FakeLLM(SimpleChatModel):
    """Fake ChatModel for testing purposes: always replies with ``response``."""

    streaming: bool = False
    """Whether to stream the results or not."""
    response: str
    num_token_func: Optional[Callable] = None

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the canned response, ignoring the input messages."""
        return self.response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"response": self.response}

    def get_num_tokens(self, text: str) -> int:
        """Delegate token counting to ``num_token_func`` when provided."""
        if self.num_token_func is None:
            return 0
        return self.num_token_func(str_to_prompt_messages([text]))

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce a ChatResult, optionally emitting per-character stream events."""
        text = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        if self.streaming:
            # Emit one character at a time to mimic a real streaming model.
            for ch in text:
                if run_manager:
                    run_manager.on_llm_new_token(ch)
                time.sleep(0.01)

        usage = {
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'total_tokens': 0,
        }
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=text))],
            llm_output={"token_usage": usage},
        )
|
||||
50
api/core/third_party/langchain/llms/open_ai.py
vendored
Normal file
50
api/core/third_party/langchain/llms/open_ai.py
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
|
||||
from typing import Dict, Any, Mapping, Optional, Union, Tuple
|
||||
from langchain import OpenAI
|
||||
from pydantic import root_validator
|
||||
|
||||
|
||||
class EnhanceOpenAI(OpenAI):
    """OpenAI completion model with a tighter timeout and a single retry."""

    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        if values["streaming"]:
            if values["n"] > 1:
                raise ValueError("Cannot stream results when n > 1.")
            if values["best_of"] > 1:
                raise ValueError("Cannot stream results when best_of > 1.")
        return values

    def _api_overrides(self) -> Dict[str, Any]:
        # The same per-request connection settings are layered onto both the
        # invocation and the identifying parameters.
        return {
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization or None,
        }

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**super()._invocation_params, **self._api_overrides()}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {**super()._identifying_params, **self._api_overrides()}
|
||||
75
api/core/third_party/langchain/llms/replicate_llm.py
vendored
Normal file
75
api/core/third_party/langchain/llms/replicate_llm.py
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import Dict, Optional, List, Any
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import Replicate
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from pydantic import root_validator
|
||||
|
||||
|
||||
class EnhanceReplicate(Replicate):
    # Replicate wrapper that resolves the API token at construction time and
    # adds streaming with client-side stop-sequence handling.

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        replicate_api_token = get_from_dict_or_env(
            values, "replicate_api_token", "REPLICATE_API_TOKEN"
        )
        values["replicate_api_token"] = replicate_api_token
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call to replicate endpoint."""
        try:
            import replicate as replicate_python
        except ImportError:
            raise ImportError(
                "Could not import replicate python package. "
                "Please install it with `pip install replicate`."
            )

        client = replicate_python.Client(api_token=self.replicate_api_token)

        # get the model and version
        model_str, version_str = self.model.split(":")
        model = client.models.get(model_str)
        version = model.versions.get(version_str)

        # sort through the openapi schema to get the name of the first input
        input_properties = sorted(
            version.openapi_schema["components"]["schemas"]["Input"][
                "properties"
            ].items(),
            key=lambda item: item[1].get("x-order", 0),
        )
        first_input_name = input_properties[0][0]
        # The prompt goes into the model's first input; self.input supplies
        # any additional model parameters.
        inputs = {first_input_name: prompt, **self.input}

        prediction = client.predictions.create(
            version=version, input={**inputs, **kwargs}
        )
        current_completion: str = ""
        stop_condition_reached = False
        for output in prediction.output_iterator():
            current_completion += output

            # test for stop conditions, if specified
            if stop:
                for s in stop:
                    if s in current_completion:
                        # Cancel the remote prediction and truncate the output
                        # at the stop sequence.
                        prediction.cancel()
                        stop_index = current_completion.find(s)
                        current_completion = current_completion[:stop_index]
                        stop_condition_reached = True
                        break

            if stop_condition_reached:
                break

            # Stream tokens only while no stop condition has fired.
            if self.streaming and run_manager:
                run_manager.on_llm_new_token(output)
        return current_completion
|
||||
185
api/core/third_party/langchain/llms/spark.py
vendored
Normal file
185
api/core/third_party/langchain/llms/spark.py
vendored
Normal file
@@ -0,0 +1,185 @@
|
||||
import re
|
||||
import string
|
||||
import threading
|
||||
from _decimal import Decimal, ROUND_HALF_UP
|
||||
from typing import Dict, List, Optional, Any, Mapping
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, ChatResult, \
|
||||
ChatGeneration
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.third_party.spark.spark_llm import SparkLLMClient
|
||||
|
||||
|
||||
class ChatSpark(BaseChatModel):
    r"""Wrapper around Spark's large language model.

    To use, you should pass `app_id`, `api_key`, `api_secret`
    as a named parameter to the constructor.

    Example:
        .. code-block:: python

            client = SparkLLMClient(
                app_id="<app_id>",
                api_key="<api_key>",
                api_secret="<api_secret>"
            )
    """
    client: Any = None  #: :meta private:

    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = None
    """A non-negative float that tunes the degree of randomness in generation."""

    top_k: Optional[int] = None
    """Number of most likely tokens to consider at each step."""

    user_id: Optional[str] = None
    """User ID to use for the model."""

    streaming: bool = False
    """Whether to stream the results."""

    app_id: Optional[str] = None
    api_key: Optional[str] = None
    api_secret: Optional[str] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        # Each credential may come from the constructor or its SPARK_* env var.
        values["app_id"] = get_from_dict_or_env(
            values, "app_id", "SPARK_APP_ID"
        )
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "SPARK_API_KEY"
        )
        values["api_secret"] = get_from_dict_or_env(
            values, "api_secret", "SPARK_API_SECRET"
        )

        values["client"] = SparkLLMClient(
            app_id=values["app_id"],
            api_key=values["api_key"],
            api_secret=values["api_secret"],
        )
        return values

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling the Spark API."""
        d = {
            "max_tokens": self.max_tokens
        }
        # Optional knobs are only sent when explicitly configured.
        if self.temperature is not None:
            d["temperature"] = self.temperature
        if self.top_k is not None:
            d["top_k"] = self.top_k
        return d

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{}, **self._default_params}

    @property
    def lc_secrets(self) -> Dict[str, str]:
        # Credential fields that must be masked on serialization.
        return {"api_key": "API_KEY", "api_secret": "API_SECRET"}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "spark-chat"

    @property
    def lc_serializable(self) -> bool:
        return True

    def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> list[dict]:
        """Format a list of messages into a full dict list.

        Args:
            messages (List[BaseMessage]): List of BaseMessage to combine.

        Returns:
            list[dict]
        """
        messages = messages.copy()  # don't mutate the original list

        new_messages = []
        for message in messages:
            if isinstance(message, ChatMessage):
                new_messages.append({'role': 'user', 'content': message.content})
            elif isinstance(message, HumanMessage) or isinstance(message, SystemMessage):
                # System and human messages are both sent with the 'user' role.
                new_messages.append({'role': 'user', 'content': message.content})
            elif isinstance(message, AIMessage):
                new_messages.append({'role': 'assistant', 'content': message.content})
            else:
                raise ValueError(f"Got unknown type {message}")

        return new_messages

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run the request on a worker thread and collect the streamed reply."""
        messages = self._convert_messages_to_dicts(messages)

        # client.run publishes results consumed below via client.subscribe(),
        # so it is executed on its own thread.
        thread = threading.Thread(target=self.client.run, args=(
            messages,
            self.user_id,
            self._default_params,
            self.streaming
        ))
        thread.start()

        completion = ""
        for content in self.client.subscribe():
            # Subscription items are either dicts carrying a 'data' payload or
            # plain text chunks.
            if isinstance(content, dict):
                delta = content['data']
            else:
                delta = content

            completion += delta
            if self.streaming and run_manager:
                run_manager.on_llm_new_token(
                    delta,
                )

        thread.join()

        # Stop sequences are enforced client-side after the full reply.
        if stop is not None:
            completion = enforce_stop_tokens(completion, stop)

        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Async generation is not implemented; returns an empty message.
        message = AIMessage(content='')
        return ChatResult(generations=[ChatGeneration(message=message)])

    def get_num_tokens(self, text: str) -> int:
        """Calculate number of tokens."""
        # Heuristic count: CJK characters weigh 1.5, other tokens 0.8;
        # Decimal with the final int() truncation keeps the tally exact.
        total = Decimal(0)
        words = re.findall(r'\b\w+\b|[{}]|\s'.format(re.escape(string.punctuation)), text)
        for word in words:
            if word:
                if '\u4e00' <= word <= '\u9fff':  # if chinese
                    total += Decimal('1.5')
                else:
                    total += Decimal('0.8')
        return int(total)
|
||||
82
api/core/third_party/langchain/llms/tongyi_llm.py
vendored
Normal file
82
api/core/third_party/langchain/llms/tongyi_llm.py
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import Tongyi
|
||||
from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry
|
||||
from langchain.schema import Generation, LLMResult
|
||||
|
||||
|
||||
class EnhanceTongyi(Tongyi):
    # Tongyi wrapper that always sends top_p and the DashScope api key.

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Tongyi API."""
        normal_params = {
            "top_p": self.top_p,
            "api_key": self.dashscope_api_key
        }

        # model_kwargs take precedence over the defaults above.
        return {**normal_params, **self.model_kwargs}

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Generate completions; streams incrementally when enabled."""
        generations = []
        params: Dict[str, Any] = {
            **{"model": self.model_name},
            **self._default_params,
            **kwargs,
        }
        if self.streaming:
            if len(prompts) > 1:
                raise ValueError("Cannot stream results with multiple prompts.")
            params["stream"] = True
            text = ''
            for stream_resp in stream_generate_with_retry(
                self, prompt=prompts[0], **params
            ):
                # After the first chunk, slice off the previously seen prefix —
                # chunks appear to carry the cumulative text so far
                # (TODO confirm against the DashScope response format).
                if not generations:
                    current_text = stream_resp["output"]["text"]
                else:
                    current_text = stream_resp["output"]["text"][len(text):]

                text = stream_resp["output"]["text"]

                # NOTE(review): one Generation is appended per stream chunk, so
                # a single streamed prompt yields many generations — verify
                # callers rely only on the run_manager callbacks here.
                generations.append(
                    [
                        Generation(
                            text=current_text,
                            generation_info=dict(
                                finish_reason=stream_resp["output"]["finish_reason"],
                            ),
                        )
                    ]
                )

                if run_manager:
                    run_manager.on_llm_new_token(
                        current_text,
                        verbose=self.verbose,
                        logprobs=None,
                    )
        else:
            for prompt in prompts:
                completion = generate_with_retry(
                    self,
                    prompt=prompt,
                    **params,
                )
                generations.append(
                    [
                        Generation(
                            text=completion["output"]["text"],
                            generation_info=dict(
                                finish_reason=completion["output"]["finish_reason"],
                            ),
                        )
                    ]
                )
        return LLMResult(generations=generations)
|
||||
233
api/core/third_party/langchain/llms/wenxin.py
vendored
Normal file
233
api/core/third_party/langchain/llms/wenxin.py
vendored
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Wrapper around Wenxin APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional, Iterator,
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _WenxinEndpointClient(BaseModel):
    """An API client that talks to a Wenxin llm endpoint."""

    base_url: str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
    secret_key: str
    api_key: str

    def get_access_token(self) -> str:
        """Exchange api_key/secret_key for a Wenxin OAuth access token.

        Returns:
            The access token string.

        Raises:
            ValueError: On an HTTP failure or an OAuth error payload.
        """
        url = f"https://aip.baidubce.com/oauth/2.0/token?client_id={self.api_key}" \
              f"&client_secret={self.secret_key}&grant_type=client_credentials"

        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }

        response = requests.post(url, headers=headers)
        if not response.ok:
            raise ValueError(f"Wenxin HTTP {response.status_code} error: {response.text}")
        # Parse the body once instead of re-parsing it on every access.
        json_response = response.json()
        if 'error' in json_response:
            raise ValueError(
                f"Wenxin API {json_response['error']}"
                f" error: {json_response['error_description']}"
            )

        access_token = json_response['access_token']

        # todo add cache

        return access_token

    def post(self, request: dict) -> Any:
        """POST a chat request to the endpoint for the requested model.

        Args:
            request: The request payload; must include a ``model`` key, and may
                set ``stream`` to request a streaming response.

        Returns:
            The ``result`` field of the JSON response, or the raw streaming
            ``requests`` response when ``stream`` is set.

        Raises:
            ValueError: On a missing/unsupported model, an HTTP failure, or an
                API error payload.
        """
        if 'model' not in request:
            raise ValueError("Wenxin Model name is required")

        # Maps the public model name to its endpoint path segment.
        model_url_map = {
            'ernie-bot': 'completions',
            'ernie-bot-turbo': 'eb-instant',
            'bloomz-7b': 'bloomz_7b1',
        }
        if request['model'] not in model_url_map:
            # Fail with a clear message instead of a bare KeyError below.
            raise ValueError(f"Wenxin unsupported model: {request['model']}")

        stream = 'stream' in request and request['stream']

        access_token = self.get_access_token()
        api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"

        headers = {"Content-Type": "application/json"}
        response = requests.post(api_url,
                                 headers=headers,
                                 json=request,
                                 stream=stream)
        if not response.ok:
            raise ValueError(f"Wenxin HTTP {response.status_code} error: {response.text}")

        if not stream:
            json_response = response.json()
            if 'error_code' in json_response:
                raise ValueError(
                    f"Wenxin API {json_response['error_code']}"
                    f" error: {json_response['error_msg']}"
                )
            return json_response["result"]
        else:
            # Caller consumes the streaming body directly.
            return response
|
||||
|
||||
|
||||
class Wenxin(LLM):
    """Wrapper around Wenxin large language models.

    To use, you should have the environment variable
    ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
    or pass them as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain.llms.wenxin import Wenxin
            wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
                            secret_key="my-secret-key")
    """

    # Low-level HTTP client; built in __init__ after pydantic validation.
    _client: _WenxinEndpointClient = PrivateAttr()
    model: str = "ernie-bot"
    """Model name to use."""
    temperature: float = 0.7
    """A non-negative float that tunes the degree of randomness in generation."""
    top_p: float = 0.95
    """Total probability mass of tokens to consider at each step."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    streaming: bool = False
    """Whether to stream the response or return it all at once."""
    # Credentials; resolved from the environment by validate_environment
    # when not passed explicitly.
    api_key: Optional[str] = None
    secret_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "WENXIN_API_KEY"
        )
        values["secret_key"] = get_from_dict_or_env(
            values, "secret_key", "WENXIN_SECRET_KEY"
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Wenxin API."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stream": self.streaming,
            **self.model_kwargs,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "wenxin"

    def __init__(self, **data: Any):
        """Initialize the wrapper and its private endpoint client."""
        super().__init__(**data)
        # Built here (not in the validator) so it uses the final, resolved
        # credential values on this instance.
        self._client = _WenxinEndpointClient(
            api_key=self.api_key,
            secret_key=self.secret_key,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        r"""Call out to Wenxin's completion endpoint to chat

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = wenxin("Tell me a joke.")
        """
        if self.streaming:
            # Streaming mode: accumulate the chunks into one string so the
            # return type matches the non-streaming path.
            completion = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                completion += chunk.text
        else:
            # The prompt is wrapped as a single-turn chat message.
            request = self._default_params
            request["messages"] = [{"role": "user", "content": prompt}]
            request.update(kwargs)
            completion = self._client.post(request)

        # Stop sequences are trimmed client-side; the API itself is not
        # passed the stop list.
        if stop is not None:
            completion = enforce_stop_tokens(completion, stop)

        return completion

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        r"""Call wenxin completion_stream and return the resulting generator.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens from Wenxin.

        Example:
            .. code-block:: python

                prompt = "Write a poem about a stream."
                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
                generator = wenxin.stream(prompt)
                for token in generator:
                    yield token
        """
        request = self._default_params
        request["messages"] = [{"role": "user", "content": prompt}]
        request.update(kwargs)

        # post() returns the raw streaming requests.Response when
        # request["stream"] is true; iterate its lines as they arrive.
        for token in self._client.post(request).iter_lines():
            if token:
                token = token.decode("utf-8")
                # Lines look like SSE "data: {...}" frames; token[5:] drops
                # the "data:" prefix (json.loads tolerates the leading space).
                # NOTE(review): assumes every non-empty line is a data frame
                # — confirm against the Wenxin streaming format.
                completion = json.loads(token[5:])

                yield GenerationChunk(text=completion['result'])
                if run_manager:
                    run_manager.on_llm_new_token(completion['result'])

                # The server flags the final frame with is_end.
                if completion['is_end']:
                    break
|
||||
0
api/core/third_party/spark/__init__.py
vendored
Normal file
0
api/core/third_party/spark/__init__.py
vendored
Normal file
150
api/core/third_party/spark/spark_llm.py
vendored
Normal file
150
api/core/third_party/spark/spark_llm.py
vendored
Normal file
@@ -0,0 +1,150 @@
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import queue
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
import ssl
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import websocket
|
||||
|
||||
|
||||
class SparkLLMClient:
    """Blocking websocket client for the iFLYTEK Spark chat API.

    The server pushes response frames through websocket callbacks; they are
    funnelled into an internal queue and consumed via :meth:`subscribe`.
    """

    def __init__(self, app_id: str, api_key: str, api_secret: str):
        self.api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
        self.app_id = app_id

        endpoint = urlparse(self.api_base)
        self.ws_url = self.create_url(
            endpoint.netloc,
            endpoint.path,
            self.api_base,
            api_key,
            api_secret
        )

        self.queue = queue.Queue()
        self.blocking_message = ''

    def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
        """Build the signed websocket URL (HMAC-SHA256 request signing)."""
        # RFC1123 timestamp — it is part of the signed payload.
        date = format_date_time(mktime(datetime.now().timetuple()))

        raw_signature = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"

        digest = hmac.new(
            api_secret.encode('utf-8'),
            raw_signature.encode('utf-8'),
            digestmod=hashlib.sha256
        ).digest()
        signature = base64.b64encode(digest).decode(encoding='utf-8')

        authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature}"'
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        query = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        return api_base + '?' + urlencode(query)

    def run(self, messages: list, user_id: str,
            model_kwargs: Optional[dict] = None, streaming: bool = False):
        """Open the websocket and block until the conversation finishes."""
        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(
            self.ws_url,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open
        )
        # Stash per-call state on the socket so the callbacks can reach it.
        ws.messages = messages
        ws.user_id = user_id
        ws.model_kwargs = model_kwargs
        ws.streaming = streaming
        # NOTE(review): certificate verification is disabled — confirm intended.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    def on_error(self, ws, error):
        """Forward transport errors to the consumer and stop the socket."""
        self.queue.put({'error': error})
        ws.close()

    def on_close(self, ws, close_status_code, close_reason):
        """Push the done sentinel (no 'data' key) so subscribe() stops."""
        self.queue.put({'done': True})

    def on_open(self, ws):
        """Reset accumulated text and send the request payload."""
        self.blocking_message = ''
        ws.send(json.dumps(self.gen_params(
            messages=ws.messages,
            user_id=ws.user_id,
            model_kwargs=ws.model_kwargs
        )))

    def on_message(self, ws, message):
        """Handle one response frame: queue it (streaming) or accumulate it."""
        data = json.loads(message)
        header = data['header']
        if header['code'] != 0:
            self.queue.put({'error': f"Code: {header['code']}, Error: {header['message']}"})
            ws.close()
            return

        choices = data["payload"]["choices"]
        content = choices["text"][0]["content"]
        if ws.streaming:
            self.queue.put({'data': content})
        else:
            self.blocking_message += content

        # status == 2 marks the final frame of the reply.
        if choices["status"] == 2:
            if not ws.streaming:
                self.queue.put({'data': self.blocking_message})
            ws.close()

    def gen_params(self, messages: list, user_id: str,
                   model_kwargs: Optional[dict] = None) -> dict:
        """Assemble the JSON request payload for one chat call."""
        chat_config = {"domain": "general"}
        if model_kwargs:
            chat_config.update(model_kwargs)

        return {
            "header": {
                "app_id": self.app_id,
                "uid": user_id
            },
            "parameter": {
                "chat": chat_config
            },
            "payload": {
                "message": {
                    "text": messages
                }
            }
        }

    def subscribe(self):
        """Yield queued events; raise SparkError on error; stop on the sentinel."""
        while True:
            event = self.queue.get()
            if 'error' in event:
                raise SparkError(event['error'])

            if 'data' not in event:
                break
            yield event
|
||||
|
||||
|
||||
class SparkError(Exception):
    """Raised for errors reported by the Spark websocket API."""
|
||||
Reference in New Issue
Block a user