feat: server multi models support (#799)
@@ -1,10 +1,10 @@
-from typing import Any, List, Dict, Union
+from typing import Any, List, Dict
 
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
+from langchain.schema import get_buffer_string, BaseMessage
 
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI
+from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
+from core.model_providers.models.llm.base import BaseLLM
 from extensions.ext_database import db
 from models.model import Conversation, Message
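The import changes show the point of the commit: the memory class drops the OpenAI-specific langchain wrappers (StreamableChatOpenAI, StreamableOpenAI, BaseLanguageModel) and moves onto the provider-agnostic core.model_providers layer. Below is a minimal sketch of what the imported PromptMessage and MessageType entities could look like, inferred only from how the diff uses them (a content field plus a type field with HUMAN and ASSISTANT members); the real definitions live in core/model_providers/models/entity/message.py and may differ.

# Hypothetical sketch only -- inferred from the diff's usage,
# not copied from core/model_providers/models/entity/message.py.
import enum
from dataclasses import dataclass


class MessageType(enum.Enum):
    HUMAN = 'human'          # enum values are assumptions
    ASSISTANT = 'assistant'


@dataclass
class PromptMessage:
    content: str = ''
    type: MessageType = MessageType.HUMAN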
@@ -13,7 +13,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
     conversation: Conversation
     human_prefix: str = "Human"
     ai_prefix: str = "Assistant"
-    llm: BaseLanguageModel
+    model_instance: BaseLLM
     memory_key: str = "chat_history"
     max_token_limit: int = 2000
     message_limit: int = 10
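With the field swap above, callers now construct the memory around a model_instance rather than a raw langchain llm. A hedged construction sketch, using only the attributes visible in this hunk; conversation and model_instance are placeholders for a Conversation row and a BaseLLM implementation obtained elsewhere in the codebase:

# Usage sketch (placeholders, not taken from the commit): `conversation` is a
# models.model.Conversation row, `model_instance` is any core.model_providers BaseLLM.
memory = ReadOnlyConversationTokenDBBufferSharedMemory(
    conversation=conversation,
    model_instance=model_instance,
    max_token_limit=2000,
    message_limit=10,
)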
@@ -29,23 +29,23 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
 
         messages = list(reversed(messages))
 
-        chat_messages: List[BaseMessage] = []
+        chat_messages: List[PromptMessage] = []
         for message in messages:
-            chat_messages.append(HumanMessage(content=message.query))
-            chat_messages.append(AIMessage(content=message.answer))
+            chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
+            chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
 
         if not chat_messages:
-            return chat_messages
+            return []
 
         # prune the chat message if it exceeds the max token limit
-        curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
+        curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
         if curr_buffer_length > self.max_token_limit:
             pruned_memory = []
             while curr_buffer_length > self.max_token_limit and chat_messages:
                 pruned_memory.append(chat_messages.pop(0))
-                curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
+                curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
 
-        return chat_messages
+        return to_lc_messages(chat_messages)
 
     @property
     def memory_variables(self) -> List[str]:
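After this hunk, the buffer accumulates provider-agnostic PromptMessage objects, counts tokens through model_instance.get_num_tokens, and only converts back to langchain messages at the very end via to_lc_messages. A minimal sketch of what that conversion could look like, assuming it simply maps MessageType to the corresponding langchain classes; the real helper ships alongside the PromptMessage entity and may handle more message types:

# Hypothetical sketch of the PromptMessage -> langchain conversion, inferred
# from the call site in the diff; the actual to_lc_messages may differ.
from typing import List

from langchain.schema import AIMessage, BaseMessage, HumanMessage


def to_lc_messages_sketch(messages: List[PromptMessage]) -> List[BaseMessage]:
    lc_messages: List[BaseMessage] = []
    for message in messages:
        if message.type == MessageType.HUMAN:
            lc_messages.append(HumanMessage(content=message.content))
        elif message.type == MessageType.ASSISTANT:
            lc_messages.append(AIMessage(content=message.content))
    return lc_messages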