feat: server multi models support (#799)

Author: takatost
Date: 2023-08-12 00:57:00 +08:00
Committed by: GitHub
Parent: d8b712b325
Commit: 5fa2161b05
213 changed files with 10556 additions and 2579 deletions


@@ -0,0 +1,99 @@
"""Wrapper around Replicate embedding models."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
class ReplicateEmbeddings(BaseModel, Embeddings):
"""Wrapper around Replicate embedding models.
To use, you should have the ``replicate`` python package installed.
"""
client: Any #: :meta private:
model: str
"""Model name to use."""
replicate_api_token: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
replicate_api_token = get_from_dict_or_env(
values, "replicate_api_token", "REPLICATE_API_TOKEN"
)
try:
import replicate as replicate_python
values["client"] = replicate_python.Client(api_token=replicate_api_token)
except ImportError:
raise ImportError(
"Could not import replicate python package. "
"Please install it with `pip install replicate`."
)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to Replicate's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
# get the model and version
model_str, version_str = self.model.split(":")
model = self.client.models.get(model_str)
version = model.versions.get(version_str)
# sort through the openapi schema to get the name of the first input
input_properties = sorted(
version.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
)
first_input_name = input_properties[0][0]
embeddings = []
for text in texts:
result = self.client.run(self.model, input={first_input_name: text})
embeddings.append(result[0].get('embedding'))
return [list(map(float, e)) for e in embeddings]
def embed_query(self, text: str) -> List[float]:
"""Call out to Replicate's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
# get the model and version
model_str, version_str = self.model.split(":")
model = self.client.models.get(model_str)
version = model.versions.get(version_str)
# sort through the openapi schema to get the name of the first input
input_properties = sorted(
version.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
)
first_input_name = input_properties[0][0]
result = self.client.run(self.model, input={first_input_name: text})
embedding = result[0].get('embedding')
return list(map(float, embedding))
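
A minimal usage sketch for the wrapper above. The "<owner>/<model>:<version-hash>" slug is a placeholder, and REPLICATE_API_TOKEN is assumed to be set in the environment:

# Hypothetical usage; the model slug and version hash are placeholders.
embedder = ReplicateEmbeddings(model="<owner>/<model>:<version-hash>")
doc_vectors = embedder.embed_documents(["first document", "second document"])
query_vector = embedder.embed_query("a search query")
assert len(doc_vectors) == 2  # one embedding (a list of floats) per input text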


@@ -0,0 +1,91 @@
from typing import Dict, Any, Optional, List, Tuple, Union
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models import AzureChatOpenAI
from langchain.chat_models.openai import _convert_dict_to_message
from langchain.schema import ChatResult, BaseMessage, ChatGeneration
from pydantic import root_validator
class EnhanceAzureChatOpenAI(AzureChatOpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 1
"""Maximum number of retries to make when generating."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
**super()._default_params,
"engine": self.deployment_name,
"api_type": self.openai_api_type,
"api_base": self.openai_api_base,
"api_version": self.openai_api_version,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"
params["stream"] = True
function_call: Optional[dict] = None
for stream_resp in self.completion_with_retry(
messages=message_dicts, **params
):
if len(stream_resp["choices"]) > 0:
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content") or ""
inner_completion += token
_function_call = stream_resp["choices"][0]["delta"].get("function_call")
if _function_call:
if function_call is None:
function_call = _function_call
else:
function_call["arguments"] += _function_call["arguments"]
if run_manager:
run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
"function_call": function_call,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
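
The subclass above tightens the client defaults (shorter timeout, a single retry) and, via _default_params, sends the Azure credentials with every request rather than relying on the openai module's global state, which is what lets several providers coexist in one process. A minimal instantiation sketch; resource name, deployment, and key are placeholders:

# Hypothetical instantiation; all credential values are placeholders.
llm = EnhanceAzureChatOpenAI(
    deployment_name="gpt-35-turbo",
    openai_api_type="azure",
    openai_api_base="https://<resource>.openai.azure.com",
    openai_api_version="2023-07-01-preview",
    openai_api_key="<azure-api-key>",
    streaming=True,
)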


@@ -0,0 +1,110 @@
from typing import Dict, Any, Mapping, Optional, List, Union, Tuple
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import AzureOpenAI
from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
update_token_usage
from langchain.schema import LLMResult
from pydantic import root_validator
class EnhanceAzureOpenAI(AzureOpenAI):
openai_api_type: str = "azure"
openai_api_version: str = ""
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 1
"""Maximum number of retries to make when generating."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
values["client"] = openai.Completion
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if values["streaming"] and values["n"] > 1:
raise ValueError("Cannot stream results when n > 1.")
if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.")
return values
@property
def _invocation_params(self) -> Dict[str, Any]:
return {**super()._invocation_params, **{
"api_type": self.openai_api_type,
"api_base": self.openai_api_base,
"api_version": self.openai_api_version,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {**super()._identifying_params, **{
"api_type": self.openai_api_type,
"api_base": self.openai_api_base,
"api_version": self.openai_api_version,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out to OpenAI's endpoint with k unique prompts.
Args:
prompts: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The full LLM output.
Example:
.. code-block:: python
response = openai.generate(["Tell me a joke."])
"""
params = self._invocation_params
params = {**params, **kwargs}
sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = []
token_usage: Dict[str, int] = {}
# Get the token usage from the response.
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
if self.streaming:
if len(_prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
params["stream"] = True
response = _streaming_response_template()
for stream_resp in completion_with_retry(
self, prompt=_prompts, **params
):
if len(stream_resp["choices"]) > 0:
if run_manager:
run_manager.on_llm_new_token(
stream_resp["choices"][0]["text"],
verbose=self.verbose,
logprobs=stream_resp["choices"][0]["logprobs"],
)
_update_response(response, stream_resp)
choices.extend(response["choices"])
else:
response = completion_with_retry(self, prompt=_prompts, **params)
choices.extend(response["choices"])
if not self.streaming:
# Can't update token usage if streaming
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
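
As with the chat variant, _invocation_params injects api_type, api_base, api_version, and the key into each completion call, so concurrent tenants pointing at different Azure resources do not clobber each other's openai module globals. A hedged usage sketch with placeholder credentials:

# Hypothetical usage; all credential values are placeholders.
llm = EnhanceAzureOpenAI(
    deployment_name="text-davinci-003",
    openai_api_base="https://<resource>.openai.azure.com",
    openai_api_version="2023-07-01-preview",
    openai_api_key="<azure-api-key>",
)
print(llm("Say hello."))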


@@ -0,0 +1,49 @@
import os
from typing import Dict, Any, Optional, Union, Tuple
from langchain.chat_models import ChatOpenAI
from pydantic import root_validator
class EnhanceChatOpenAI(ChatOpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 1
"""Maximum number of retries to make when generating."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
**super()._default_params,
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}
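
This variant pins api_type back to 'openai' and resolves api_base from OPENAI_API_BASE on every request, so an earlier Azure call in the same process cannot leak its module-level settings into a plain OpenAI request. A minimal sketch; the key is a placeholder:

# Hypothetical usage; the key is a placeholder. OPENAI_API_BASE may point
# at an OpenAI-compatible proxy; otherwise the public endpoint is used.
llm = EnhanceChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key="<api-key>")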


@@ -0,0 +1,61 @@
import time
from typing import List, Optional, Any, Mapping, Callable
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration
from core.model_providers.models.entity.message import str_to_prompt_messages
class FakeLLM(SimpleChatModel):
"""Fake ChatModel for testing purposes."""
streaming: bool = False
"""Whether to stream the results or not."""
response: str
num_token_func: Optional[Callable] = None
@property
def _llm_type(self) -> str:
return "fake-chat-model"
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
return self.response
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {"response": self.response}
def get_num_tokens(self, text: str) -> int:
return self.num_token_func(str_to_prompt_messages([text])) if self.num_token_func else 0
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
if self.streaming:
for token in output_str:
if run_manager:
run_manager.on_llm_new_token(token)
time.sleep(0.01)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
llm_output = {"token_usage": {
'prompt_tokens': 0,
'completion_tokens': 0,
'total_tokens': 0,
}}
return ChatResult(generations=[generation], llm_output=llm_output)
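
A sketch of how the fake model behaves in tests: it echoes its canned response and, when streaming, replays it character by character through the callback manager:

# Test-only usage; "pong" is an arbitrary canned response.
from langchain.schema import HumanMessage

fake = FakeLLM(response="pong", streaming=False)
result = fake.generate([[HumanMessage(content="ping")]])
print(result.generations[0][0].message.content)  # -> "pong"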


@@ -0,0 +1,50 @@
import os
from typing import Dict, Any, Mapping, Optional, Union, Tuple
from langchain import OpenAI
from pydantic import root_validator
class EnhanceOpenAI(OpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 1
"""Maximum number of retries to make when generating."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
values["client"] = openai.Completion
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if values["streaming"] and values["n"] > 1:
raise ValueError("Cannot stream results when n > 1.")
if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.")
return values
@property
def _invocation_params(self) -> Dict[str, Any]:
return {**super()._invocation_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {**super()._identifying_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
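
Same idea as EnhanceChatOpenAI, applied to the completion-style client. A minimal sketch; the key is a placeholder:

# Hypothetical usage; the key is a placeholder.
llm = EnhanceOpenAI(model_name="text-davinci-003", openai_api_key="<api-key>")
print(llm("Write one sentence about the weather."))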


@@ -0,0 +1,75 @@
from typing import Dict, Optional, List, Any
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Replicate
from langchain.utils import get_from_dict_or_env
from pydantic import root_validator
class EnhanceReplicate(Replicate):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
replicate_api_token = get_from_dict_or_env(
values, "replicate_api_token", "REPLICATE_API_TOKEN"
)
values["replicate_api_token"] = replicate_api_token
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call to replicate endpoint."""
try:
import replicate as replicate_python
except ImportError:
raise ImportError(
"Could not import replicate python package. "
"Please install it with `pip install replicate`."
)
client = replicate_python.Client(api_token=self.replicate_api_token)
# get the model and version
model_str, version_str = self.model.split(":")
model = client.models.get(model_str)
version = model.versions.get(version_str)
# sort through the openapi schema to get the name of the first input
input_properties = sorted(
version.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
)
first_input_name = input_properties[0][0]
inputs = {first_input_name: prompt, **self.input}
prediction = client.predictions.create(
version=version, input={**inputs, **kwargs}
)
current_completion: str = ""
stop_condition_reached = False
for output in prediction.output_iterator():
current_completion += output
# test for stop conditions, if specified
if stop:
for s in stop:
if s in current_completion:
prediction.cancel()
stop_index = current_completion.find(s)
current_completion = current_completion[:stop_index]
stop_condition_reached = True
break
if stop_condition_reached:
break
if self.streaming and run_manager:
run_manager.on_llm_new_token(output)
return current_completion
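
Because Replicate streams raw text with no server-side stop support, the override scans the accumulated completion for each stop string, cancels the prediction, and truncates at the first match. A hedged usage sketch; the model slug is a placeholder and REPLICATE_API_TOKEN is read from the environment:

# Hypothetical usage; the model slug and version hash are placeholders.
llm = EnhanceReplicate(model="<owner>/<model>:<version-hash>", streaming=True)
text = llm("Write a haiku about the sea.", stop=["\n\n"])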


@@ -0,0 +1,185 @@
import re
import string
import threading
from _decimal import Decimal, ROUND_HALF_UP
from typing import Dict, List, Optional, Any, Mapping
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, ChatResult, \
ChatGeneration
from langchain.utils import get_from_dict_or_env
from pydantic import root_validator
from core.third_party.spark.spark_llm import SparkLLMClient
class ChatSpark(BaseChatModel):
r"""Wrapper around Spark's large language model.
To use, you should pass `app_id`, `api_key` and `api_secret`
as named parameters to the constructor.
Example:
.. code-block:: python
chat = ChatSpark(
app_id="<app_id>",
api_key="<api_key>",
api_secret="<api_secret>"
)
"""
client: Any = None #: :meta private:
max_tokens: int = 256
"""Denotes the number of tokens to predict per generation."""
temperature: Optional[float] = None
"""A non-negative float that tunes the degree of randomness in generation."""
top_k: Optional[int] = None
"""Number of most likely tokens to consider at each step."""
user_id: Optional[str] = None
"""User ID to use for the model."""
streaming: bool = False
"""Whether to stream the results."""
app_id: Optional[str] = None
api_key: Optional[str] = None
api_secret: Optional[str] = None
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["app_id"] = get_from_dict_or_env(
values, "app_id", "SPARK_APP_ID"
)
values["api_key"] = get_from_dict_or_env(
values, "api_key", "SPARK_API_KEY"
)
values["api_secret"] = get_from_dict_or_env(
values, "api_secret", "SPARK_API_SECRET"
)
values["client"] = SparkLLMClient(
app_id=values["app_id"],
api_key=values["api_key"],
api_secret=values["api_secret"],
)
return values
@property
def _default_params(self) -> Mapping[str, Any]:
"""Get the default parameters for calling Anthropic API."""
d = {
"max_tokens": self.max_tokens
}
if self.temperature is not None:
d["temperature"] = self.temperature
if self.top_k is not None:
d["top_k"] = self.top_k
return d
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {**{}, **self._default_params}
@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": "API_KEY", "api_secret": "API_SECRET"}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "spark-chat"
@property
def lc_serializable(self) -> bool:
return True
def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> list[dict]:
"""Format a list of messages into a full dict list.
Args:
messages (List[BaseMessage]): List of BaseMessage to combine.
Returns:
list[dict]
"""
messages = messages.copy() # don't mutate the original list
new_messages = []
for message in messages:
if isinstance(message, ChatMessage):
new_messages.append({'role': 'user', 'content': message.content})
elif isinstance(message, HumanMessage) or isinstance(message, SystemMessage):
new_messages.append({'role': 'user', 'content': message.content})
elif isinstance(message, AIMessage):
new_messages.append({'role': 'assistant', 'content': message.content})
else:
raise ValueError(f"Got unknown type {message}")
return new_messages
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
messages = self._convert_messages_to_dicts(messages)
thread = threading.Thread(target=self.client.run, args=(
messages,
self.user_id,
self._default_params,
self.streaming
))
thread.start()
completion = ""
for content in self.client.subscribe():
if isinstance(content, dict):
delta = content['data']
else:
delta = content
completion += delta
if self.streaming and run_manager:
run_manager.on_llm_new_token(
delta,
)
thread.join()
if stop is not None:
completion = enforce_stop_tokens(completion, stop)
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message = AIMessage(content='')
return ChatResult(generations=[ChatGeneration(message=message)])
def get_num_tokens(self, text: str) -> int:
"""Approximate the token count: CJK matches weigh 1.5 each, all other tokens 0.8."""
total = Decimal(0)
words = re.findall(r'\b\w+\b|[{}]|\s'.format(re.escape(string.punctuation)), text)
for word in words:
if word:
if '\u4e00' <= word <= '\u9fff': # if chinese
total += Decimal('1.5')
else:
total += Decimal('0.8')
return int(total)
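
_generate drives SparkLLMClient.run on a background thread and consumes chunks from subscribe() until the session closes, so the call blocks until one complete AIMessage is available. A minimal usage sketch; the credentials are placeholders (they may also come from the SPARK_APP_ID / SPARK_API_KEY / SPARK_API_SECRET environment variables):

# Hypothetical usage; all three credentials are placeholders.
from langchain.schema import HumanMessage

chat = ChatSpark(app_id="<app-id>", api_key="<api-key>", api_secret="<api-secret>")
reply = chat([HumanMessage(content="Introduce yourself in one sentence.")])
print(reply.content)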


@@ -0,0 +1,82 @@
from typing import Dict, Any, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Tongyi
from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry
from langchain.schema import Generation, LLMResult
class EnhanceTongyi(Tongyi):
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
normal_params = {
"top_p": self.top_p,
"api_key": self.dashscope_api_key
}
return {**normal_params, **self.model_kwargs}
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
generations = []
params: Dict[str, Any] = {
**{"model": self.model_name},
**self._default_params,
**kwargs,
}
if self.streaming:
if len(prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
params["stream"] = True
text = ''
for stream_resp in stream_generate_with_retry(
self, prompt=prompts[0], **params
):
if not generations:
current_text = stream_resp["output"]["text"]
else:
current_text = stream_resp["output"]["text"][len(text):]
text = stream_resp["output"]["text"]
generations.append(
[
Generation(
text=current_text,
generation_info=dict(
finish_reason=stream_resp["output"]["finish_reason"],
),
)
]
)
if run_manager:
run_manager.on_llm_new_token(
current_text,
verbose=self.verbose,
logprobs=None,
)
else:
for prompt in prompts:
completion = generate_with_retry(
self,
prompt=prompt,
**params,
)
generations.append(
[
Generation(
text=completion["output"]["text"],
generation_info=dict(
finish_reason=completion["output"]["finish_reason"],
),
)
]
)
return LLMResult(generations=generations)
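
DashScope's streaming API returns the full text generated so far on every chunk, so the override slices off the previously seen prefix to recover each delta before emitting it through the callback manager. A minimal sketch; the model name is a placeholder and DASHSCOPE_API_KEY is assumed to be available:

# Hypothetical usage; "qwen-v1" is a placeholder model name.
llm = EnhanceTongyi(model_name="qwen-v1")
print(llm("Introduce Tongyi Qianwen in one sentence."))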


@@ -0,0 +1,233 @@
"""Wrapper around Wenxin APIs."""
from __future__ import annotations
import json
import logging
from typing import (
Any,
Dict,
List,
Optional, Iterator,
)
import requests
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema.output import GenerationChunk
from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class _WenxinEndpointClient(BaseModel):
"""An API client that talks to a Wenxin llm endpoint."""
base_url: str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
secret_key: str
api_key: str
def get_access_token(self) -> str:
url = f"https://aip.baidubce.com/oauth/2.0/token?client_id={self.api_key}" \
f"&client_secret={self.secret_key}&grant_type=client_credentials"
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
response = requests.post(url, headers=headers)
if not response.ok:
raise ValueError(f"Wenxin HTTP {response.status_code} error: {response.text}")
if 'error' in response.json():
raise ValueError(
f"Wenxin API {response.json()['error']}"
f" error: {response.json()['error_description']}"
)
access_token = response.json()['access_token']
# todo add cache
return access_token
def post(self, request: dict) -> Any:
if 'model' not in request:
raise ValueError(f"Wenxin Model name is required")
model_url_map = {
'ernie-bot': 'completions',
'ernie-bot-turbo': 'eb-instant',
'bloomz-7b': 'bloomz_7b1',
}
stream = 'stream' in request and request['stream']
access_token = self.get_access_token()
api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
headers = {"Content-Type": "application/json"}
response = requests.post(api_url,
headers=headers,
json=request,
stream=stream)
if not response.ok:
raise ValueError(f"Wenxin HTTP {response.status_code} error: {response.text}")
if not stream:
json_response = response.json()
if 'error_code' in json_response:
raise ValueError(
f"Wenxin API {json_response['error_code']}"
f" error: {json_response['error_msg']}"
)
return json_response["result"]
else:
return response
class Wenxin(LLM):
"""Wrapper around Wenxin large language models.
To use, you should have the environment variable
``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
or pass them as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.llms.wenxin import Wenxin
wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
secret_key="my-secret-key")
"""
_client: _WenxinEndpointClient = PrivateAttr()
model: str = "ernie-bot"
"""Model name to use."""
temperature: float = 0.7
"""A non-negative float that tunes the degree of randomness in generation."""
top_p: float = 0.95
"""Total probability mass of tokens to consider at each step."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
streaming: bool = False
"""Whether to stream the response or return it all at once."""
api_key: Optional[str] = None
secret_key: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["api_key"] = get_from_dict_or_env(
values, "api_key", "WENXIN_API_KEY"
)
values["secret_key"] = get_from_dict_or_env(
values, "secret_key", "WENXIN_SECRET_KEY"
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p,
"stream": self.streaming,
**self.model_kwargs,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model": self.model}, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "wenxin"
def __init__(self, **data: Any):
super().__init__(**data)
self._client = _WenxinEndpointClient(
api_key=self.api_key,
secret_key=self.secret_key,
)
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
r"""Call out to Wenxin's completion endpoint to chat
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = wenxin("Tell me a joke.")
"""
if self.streaming:
completion = ""
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
else:
request = self._default_params
request["messages"] = [{"role": "user", "content": prompt}]
request.update(kwargs)
completion = self._client.post(request)
if stop is not None:
completion = enforce_stop_tokens(completion, stop)
return completion
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
r"""Call wenxin completion_stream and return the resulting generator.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from Wenxin.
Example:
.. code-block:: python
prompt = "Write a poem about a stream."
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
generator = wenxin.stream(prompt)
for token in generator:
yield token
"""
request = self._default_params
request["messages"] = [{"role": "user", "content": prompt}]
request.update(kwargs)
for token in self._client.post(request).iter_lines():
if token:
token = token.decode("utf-8")
completion = json.loads(token[5:])
yield GenerationChunk(text=completion['result'])
if run_manager:
run_manager.on_llm_new_token(completion['result'])
if completion['is_end']:
break
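
In streaming mode each server-sent line arrives as "data: {...}", so token[5:] strips the leading "data:" tag before json.loads, and iteration stops once a chunk reports is_end. A minimal usage sketch; both credentials are placeholders (they may also come from WENXIN_API_KEY / WENXIN_SECRET_KEY):

# Hypothetical usage; credentials are placeholders.
llm = Wenxin(model="ernie-bot", api_key="<api-key>", secret_key="<secret-key>")
print(llm("Introduce ERNIE Bot in one sentence."))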

api/core/third_party/spark/spark_llm.py (vendored, new file, 150 lines)

@@ -0,0 +1,150 @@
import base64
import hashlib
import hmac
import json
import queue
from typing import Optional
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import websocket
class SparkLLMClient:
def __init__(self, app_id: str, api_key: str, api_secret: str):
self.api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,
urlparse(self.api_base).path,
self.api_base,
api_key,
api_secret
)
self.queue = queue.Queue()
self.blocking_message = ''
def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
# generate timestamp by RFC1123
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + path + " HTTP/1.1"
# encrypt using hmac-sha256
signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": host
}
# generate url
url = api_base + '?' + urlencode(v)
return url
def run(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None, streaming: bool = False):
websocket.enableTrace(False)
ws = websocket.WebSocketApp(
self.ws_url,
on_message=self.on_message,
on_error=self.on_error,
on_close=self.on_close,
on_open=self.on_open
)
ws.messages = messages
ws.user_id = user_id
ws.model_kwargs = model_kwargs
ws.streaming = streaming
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def on_error(self, ws, error):
self.queue.put({'error': error})
ws.close()
def on_close(self, ws, close_status_code, close_reason):
self.queue.put({'done': True})
def on_open(self, ws):
self.blocking_message = ''
data = json.dumps(self.gen_params(
messages=ws.messages,
user_id=ws.user_id,
model_kwargs=ws.model_kwargs
))
ws.send(data)
def on_message(self, ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
self.queue.put({'error': f"Code: {code}, Error: {data['header']['message']}"})
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
if ws.streaming:
self.queue.put({'data': content})
else:
self.blocking_message += content
if status == 2:
if not ws.streaming:
self.queue.put({'data': self.blocking_message})
ws.close()
def gen_params(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None) -> dict:
data = {
"header": {
"app_id": self.app_id,
"uid": user_id
},
"parameter": {
"chat": {
"domain": "general"
}
},
"payload": {
"message": {
"text": messages
}
}
}
if model_kwargs:
data['parameter']['chat'].update(model_kwargs)
return data
def subscribe(self):
while True:
content = self.queue.get()
if 'error' in content:
raise SparkError(content['error'])
if 'data' not in content:
break
yield content
class SparkError(Exception):
pass
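
The client is a small producer/consumer: run() drives the websocket on whatever thread calls it, the handlers push chunks (or an error) onto an internal queue, and subscribe() drains that queue until the session signals completion. A hedged sketch of driving it directly, with placeholder credentials (ChatSpark above wraps this same flow):

# Hypothetical direct use; all three credentials are placeholders.
import threading

client = SparkLLMClient(app_id="<app-id>", api_key="<api-key>", api_secret="<api-secret>")
worker = threading.Thread(target=client.run, args=(
    [{'role': 'user', 'content': 'hello'}],  # messages
    'user-1',                                 # user_id
    {'temperature': 0.5},                     # model_kwargs
    True,                                     # streaming
))
worker.start()
for chunk in client.subscribe():  # raises SparkError if the session fails
    print(chunk['data'], end='', flush=True)
worker.join()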