mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 16:39:26 +08:00
fix(api): excessive high CPU usage caused by RedisClientWrapper (#32212)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -34,7 +34,7 @@ def stream_topic_events(
|
|||||||
on_subscribe()
|
on_subscribe()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
msg = sub.receive(timeout=0.1)
|
msg = sub.receive(timeout=1)
|
||||||
except SubscriptionClosedError:
|
except SubscriptionClosedError:
|
||||||
return
|
return
|
||||||
if msg is None:
|
if msg is None:
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class RedisClientWrapper:
|
|||||||
|
|
||||||
|
|
||||||
redis_client: RedisClientWrapper = RedisClientWrapper()
|
redis_client: RedisClientWrapper = RedisClientWrapper()
|
||||||
pubsub_redis_client: RedisClientWrapper = RedisClientWrapper()
|
_pubsub_redis_client: redis.Redis | RedisCluster | None = None
|
||||||
|
|
||||||
|
|
||||||
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
|
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
|
||||||
@@ -232,7 +232,7 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
|
|||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]:
|
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
|
||||||
if use_clusters:
|
if use_clusters:
|
||||||
return RedisCluster.from_url(pubsub_url)
|
return RedisCluster.from_url(pubsub_url)
|
||||||
return redis.Redis.from_url(pubsub_url)
|
return redis.Redis.from_url(pubsub_url)
|
||||||
@@ -256,23 +256,19 @@ def init_app(app: DifyApp):
|
|||||||
redis_client.initialize(client)
|
redis_client.initialize(client)
|
||||||
app.extensions["redis"] = redis_client
|
app.extensions["redis"] = redis_client
|
||||||
|
|
||||||
pubsub_client = client
|
global _pubsub_redis_client
|
||||||
|
_pubsub_redis_client = client
|
||||||
if dify_config.normalized_pubsub_redis_url:
|
if dify_config.normalized_pubsub_redis_url:
|
||||||
pubsub_client = _create_pubsub_client(
|
_pubsub_redis_client = _create_pubsub_client(
|
||||||
dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS
|
dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS
|
||||||
)
|
)
|
||||||
pubsub_redis_client.initialize(pubsub_client)
|
|
||||||
|
|
||||||
|
|
||||||
def get_pubsub_redis_client() -> RedisClientWrapper:
|
|
||||||
return pubsub_redis_client
|
|
||||||
|
|
||||||
|
|
||||||
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
||||||
redis_conn = get_pubsub_redis_client()
|
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
|
||||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
||||||
return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
|
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
|
||||||
return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
|
return RedisBroadcastChannel(_pubsub_redis_client)
|
||||||
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ class RedisSubscriptionBase(Subscription):
|
|||||||
"""Iterator for consuming messages from the subscription."""
|
"""Iterator for consuming messages from the subscription."""
|
||||||
while not self._closed.is_set():
|
while not self._closed.is_set():
|
||||||
try:
|
try:
|
||||||
item = self._queue.get(timeout=0.1)
|
item = self._queue.get(timeout=1)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||||
from redis import Redis
|
from redis import Redis, RedisCluster
|
||||||
|
|
||||||
from ._subscription import RedisSubscriptionBase
|
from ._subscription import RedisSubscriptionBase
|
||||||
|
|
||||||
@@ -18,7 +18,7 @@ class BroadcastChannel:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redis_client: Redis,
|
redis_client: Redis | RedisCluster,
|
||||||
):
|
):
|
||||||
self._client = redis_client
|
self._client = redis_client
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ class BroadcastChannel:
|
|||||||
|
|
||||||
|
|
||||||
class Topic:
|
class Topic:
|
||||||
def __init__(self, redis_client: Redis, topic: str):
|
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||||
self._client = redis_client
|
self._client = redis_client
|
||||||
self._topic = topic
|
self._topic = topic
|
||||||
|
|
||||||
|
|||||||
@@ -70,8 +70,9 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
|||||||
# Since we have already filtered at the caller's site, we can safely set
|
# Since we have already filtered at the caller's site, we can safely set
|
||||||
# `ignore_subscribe_messages=False`.
|
# `ignore_subscribe_messages=False`.
|
||||||
if isinstance(self._client, RedisCluster):
|
if isinstance(self._client, RedisCluster):
|
||||||
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message`
|
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` without
|
||||||
# would use busy-looping to wait for incoming message, consuming excessive CPU quota.
|
# specifying the `target_node` argument would use busy-looping to wait
|
||||||
|
# for incoming messages, consuming excessive CPU quota.
|
||||||
#
|
#
|
||||||
# Here we specify the `target_node` to mitigate this problem.
|
# Here we specify the `target_node` to mitigate this problem.
|
||||||
node = self._client.get_node_from_key(self._topic)
|
node = self._client.get_node_from_key(self._topic)
|
||||||
@@ -80,8 +81,10 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
|||||||
timeout=1,
|
timeout=1,
|
||||||
target_node=node,
|
target_node=node,
|
||||||
)
|
)
|
||||||
else:
|
elif isinstance(self._client, Redis):
|
||||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
|
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
|
||||||
|
else:
|
||||||
|
raise AssertionError("client should be either Redis or RedisCluster.")
|
||||||
|
|
||||||
def _get_message_type(self) -> str:
|
def _get_message_type(self) -> str:
|
||||||
return "smessage"
|
return "smessage"
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from libs.exception import BaseHTTPException
|
|||||||
from models.human_input import RecipientType
|
from models.human_input import RecipientType
|
||||||
from models.model import App, AppMode
|
from models.model import App, AppMode
|
||||||
from repositories.factory import DifyAPIRepositoryFactory
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE, resume_app_execution
|
from tasks.app_generate.workflow_execute_task import resume_app_execution
|
||||||
|
|
||||||
|
|
||||||
class Form:
|
class Form:
|
||||||
@@ -230,7 +230,6 @@ class HumanInputService:
|
|||||||
try:
|
try:
|
||||||
resume_app_execution.apply_async(
|
resume_app_execution.apply_async(
|
||||||
kwargs={"payload": payload},
|
kwargs={"payload": payload},
|
||||||
queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE,
|
|
||||||
)
|
)
|
||||||
except Exception: # pragma: no cover
|
except Exception: # pragma: no cover
|
||||||
logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)
|
logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)
|
||||||
|
|||||||
@@ -129,15 +129,15 @@ def build_workflow_event_stream(
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event = buffer_state.queue.get(timeout=0.1)
|
event = buffer_state.queue.get(timeout=1)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
if current_time - last_msg_time > idle_timeout:
|
if current_time - last_msg_time > idle_timeout:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"No workflow events received for %s seconds, keeping stream open",
|
"Idle timeout of %s seconds reached, closing workflow event stream.",
|
||||||
idle_timeout,
|
idle_timeout,
|
||||||
)
|
)
|
||||||
last_msg_time = current_time
|
return
|
||||||
if current_time - last_ping_time >= ping_interval:
|
if current_time - last_ping_time >= ping_interval:
|
||||||
yield StreamEvent.PING.value
|
yield StreamEvent.PING.value
|
||||||
last_ping_time = current_time
|
last_ping_time = current_time
|
||||||
@@ -405,7 +405,7 @@ def _start_buffering(subscription) -> BufferState:
|
|||||||
dropped_count = 0
|
dropped_count = 0
|
||||||
try:
|
try:
|
||||||
while not buffer_state.stop_event.is_set():
|
while not buffer_state.stop_event.is_set():
|
||||||
msg = subscription.receive(timeout=0.1)
|
msg = subscription.receive(timeout=1)
|
||||||
if msg is None:
|
if msg is None:
|
||||||
continue
|
continue
|
||||||
event = _parse_event_message(msg)
|
event = _parse_event_message(msg)
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def _patch_redis_clients_on_loaded_modules():
|
|||||||
continue
|
continue
|
||||||
if hasattr(module, "redis_client"):
|
if hasattr(module, "redis_client"):
|
||||||
module.redis_client = redis_mock
|
module.redis_client = redis_mock
|
||||||
if hasattr(module, "pubsub_redis_client"):
|
if hasattr(module, "_pubsub_redis_client"):
|
||||||
module.pubsub_redis_client = redis_mock
|
# Replace the same private attribute that the hasattr() guard above checks.
# Assigning to the old public name `pubsub_redis_client` would leave the
# module's `_pubsub_redis_client` untouched by this loop and add a stray attribute.
module._pubsub_redis_client = redis_mock
|
||||||
|
|
||||||
|
|
||||||
@@ -72,7 +72,7 @@ def _patch_redis_clients():
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(ext_redis, "redis_client", redis_mock),
|
patch.object(ext_redis, "redis_client", redis_mock),
|
||||||
patch.object(ext_redis, "pubsub_redis_client", redis_mock),
|
patch.object(ext_redis, "_pubsub_redis_client", redis_mock),
|
||||||
):
|
):
|
||||||
_patch_redis_clients_on_loaded_modules()
|
_patch_redis_clients_on_loaded_modules()
|
||||||
yield
|
yield
|
||||||
|
|||||||
@@ -198,6 +198,15 @@ class SubscriptionTestCase:
|
|||||||
description: str = ""
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class FakeRedisClient:
|
||||||
|
"""Minimal fake Redis client for unit tests."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.publish = MagicMock()
|
||||||
|
self.spublish = MagicMock()
|
||||||
|
self.pubsub = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
|
||||||
class TestRedisSubscription:
|
class TestRedisSubscription:
|
||||||
"""Test cases for the _RedisSubscription class."""
|
"""Test cases for the _RedisSubscription class."""
|
||||||
|
|
||||||
@@ -619,10 +628,13 @@ class TestRedisSubscription:
|
|||||||
class TestRedisShardedSubscription:
|
class TestRedisShardedSubscription:
|
||||||
"""Test cases for the _RedisShardedSubscription class."""
|
"""Test cases for the _RedisShardedSubscription class."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_sharded_redis_type(self, monkeypatch):
|
||||||
|
monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.Redis", FakeRedisClient)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_redis_client(self) -> MagicMock:
|
def mock_redis_client(self) -> FakeRedisClient:
|
||||||
client = MagicMock()
|
return FakeRedisClient()
|
||||||
return client
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_pubsub(self) -> MagicMock:
|
def mock_pubsub(self) -> MagicMock:
|
||||||
@@ -636,7 +648,7 @@ class TestRedisShardedSubscription:
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sharded_subscription(
|
def sharded_subscription(
|
||||||
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient
|
||||||
) -> Generator[_RedisShardedSubscription, None, None]:
|
) -> Generator[_RedisShardedSubscription, None, None]:
|
||||||
"""Create a _RedisShardedSubscription instance for testing."""
|
"""Create a _RedisShardedSubscription instance for testing."""
|
||||||
subscription = _RedisShardedSubscription(
|
subscription = _RedisShardedSubscription(
|
||||||
@@ -657,7 +669,7 @@ class TestRedisShardedSubscription:
|
|||||||
|
|
||||||
# ==================== Lifecycle Tests ====================
|
# ==================== Lifecycle Tests ====================
|
||||||
|
|
||||||
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient):
|
||||||
"""Test that sharded subscription is properly initialized."""
|
"""Test that sharded subscription is properly initialized."""
|
||||||
subscription = _RedisShardedSubscription(
|
subscription = _RedisShardedSubscription(
|
||||||
client=mock_redis_client,
|
client=mock_redis_client,
|
||||||
@@ -970,7 +982,7 @@ class TestRedisShardedSubscription:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sharded_subscription_scenarios(
|
def test_sharded_subscription_scenarios(
|
||||||
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient
|
||||||
):
|
):
|
||||||
"""Test various sharded subscription scenarios using table-driven approach."""
|
"""Test various sharded subscription scenarios using table-driven approach."""
|
||||||
subscription = _RedisShardedSubscription(
|
subscription = _RedisShardedSubscription(
|
||||||
@@ -1058,7 +1070,7 @@ class TestRedisShardedSubscription:
|
|||||||
# Close should still work
|
# Close should still work
|
||||||
sharded_subscription.close() # Should not raise
|
sharded_subscription.close() # Should not raise
|
||||||
|
|
||||||
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient):
|
||||||
"""Test various sharded channel name formats."""
|
"""Test various sharded channel name formats."""
|
||||||
channel_names = [
|
channel_names = [
|
||||||
"simple",
|
"simple",
|
||||||
@@ -1120,10 +1132,13 @@ class TestRedisSubscriptionCommon:
|
|||||||
"""Parameterized fixture providing subscription type and class."""
|
"""Parameterized fixture providing subscription type and class."""
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_sharded_redis_type(self, monkeypatch):
|
||||||
|
monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.Redis", FakeRedisClient)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_redis_client(self) -> MagicMock:
|
def mock_redis_client(self) -> FakeRedisClient:
|
||||||
client = MagicMock()
|
return FakeRedisClient()
|
||||||
return client
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_pubsub(self) -> MagicMock:
|
def mock_pubsub(self) -> MagicMock:
|
||||||
@@ -1140,7 +1155,7 @@ class TestRedisSubscriptionCommon:
|
|||||||
return pubsub
|
return pubsub
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: FakeRedisClient):
|
||||||
"""Create a subscription instance based on parameterized type."""
|
"""Create a subscription instance based on parameterized type."""
|
||||||
subscription_type, subscription_class = subscription_params
|
subscription_type, subscription_class = subscription_params
|
||||||
topic_name = f"test-{subscription_type}-topic"
|
topic_name = f"test-{subscription_type}-topic"
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from core.workflow.nodes.human_input.entities import (
|
|||||||
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus
|
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus
|
||||||
from models.human_input import RecipientType
|
from models.human_input import RecipientType
|
||||||
from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError
|
from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError
|
||||||
from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -88,7 +87,6 @@ def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factor
|
|||||||
|
|
||||||
resume_task.apply_async.assert_called_once()
|
resume_task.apply_async.assert_called_once()
|
||||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||||
assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
|
|
||||||
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
||||||
|
|
||||||
|
|
||||||
@@ -130,7 +128,6 @@ def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_f
|
|||||||
|
|
||||||
resume_task.apply_async.assert_called_once()
|
resume_task.apply_async.assert_called_once()
|
||||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||||
assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
|
|
||||||
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user