mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 19:32:16 +08:00
refactor: enhance clean messages task (#29638)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
216
api/services/retention/conversation/messages_clean_policy.py
Normal file
216
api/services/retention/conversation/messages_clean_policy.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import datetime
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.billing_service import BillingService, SubscriptionPlan
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleMessage:
|
||||
id: str
|
||||
app_id: str
|
||||
created_at: datetime.datetime
|
||||
|
||||
|
||||
class MessagesCleanPolicy(ABC):
|
||||
"""
|
||||
Abstract base class for message cleanup policies.
|
||||
|
||||
A policy determines which messages from a batch should be deleted.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def filter_message_ids(
|
||||
self,
|
||||
messages: Sequence[SimpleMessage],
|
||||
app_to_tenant: dict[str, str],
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Filter messages and return IDs of messages that should be deleted.
|
||||
|
||||
Args:
|
||||
messages: Batch of messages to evaluate
|
||||
app_to_tenant: Mapping from app_id to tenant_id
|
||||
|
||||
Returns:
|
||||
List of message IDs that should be deleted
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class BillingDisabledPolicy(MessagesCleanPolicy):
|
||||
"""
|
||||
Policy for community or enterpriseedition (billing disabled).
|
||||
|
||||
No special filter logic, just return all message ids.
|
||||
"""
|
||||
|
||||
def filter_message_ids(
|
||||
self,
|
||||
messages: Sequence[SimpleMessage],
|
||||
app_to_tenant: dict[str, str],
|
||||
) -> Sequence[str]:
|
||||
return [msg.id for msg in messages]
|
||||
|
||||
|
||||
class BillingSandboxPolicy(MessagesCleanPolicy):
|
||||
"""
|
||||
Policy for sandbox plan tenants in cloud edition (billing enabled).
|
||||
|
||||
Filters messages based on sandbox plan expiration rules:
|
||||
- Skip tenants in the whitelist
|
||||
- Only delete messages from sandbox plan tenants
|
||||
- Respect grace period after subscription expiration
|
||||
- Safe default: if tenant mapping or plan is missing, do NOT delete
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]],
|
||||
graceful_period_days: int = 21,
|
||||
tenant_whitelist: Sequence[str] | None = None,
|
||||
current_timestamp: int | None = None,
|
||||
) -> None:
|
||||
self._graceful_period_days = graceful_period_days
|
||||
self._tenant_whitelist: Sequence[str] = tenant_whitelist or []
|
||||
self._plan_provider = plan_provider
|
||||
self._current_timestamp = current_timestamp
|
||||
|
||||
def filter_message_ids(
|
||||
self,
|
||||
messages: Sequence[SimpleMessage],
|
||||
app_to_tenant: dict[str, str],
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Filter messages based on sandbox plan expiration rules.
|
||||
|
||||
Args:
|
||||
messages: Batch of messages to evaluate
|
||||
app_to_tenant: Mapping from app_id to tenant_id
|
||||
|
||||
Returns:
|
||||
List of message IDs that should be deleted
|
||||
"""
|
||||
if not messages or not app_to_tenant:
|
||||
return []
|
||||
|
||||
# Get unique tenant_ids and fetch subscription plans
|
||||
tenant_ids = list(set(app_to_tenant.values()))
|
||||
tenant_plans = self._plan_provider(tenant_ids)
|
||||
|
||||
if not tenant_plans:
|
||||
return []
|
||||
|
||||
# Apply sandbox deletion rules
|
||||
return self._filter_expired_sandbox_messages(
|
||||
messages=messages,
|
||||
app_to_tenant=app_to_tenant,
|
||||
tenant_plans=tenant_plans,
|
||||
)
|
||||
|
||||
def _filter_expired_sandbox_messages(
|
||||
self,
|
||||
messages: Sequence[SimpleMessage],
|
||||
app_to_tenant: dict[str, str],
|
||||
tenant_plans: dict[str, SubscriptionPlan],
|
||||
) -> list[str]:
|
||||
"""
|
||||
Filter messages that should be deleted based on sandbox plan expiration.
|
||||
|
||||
A message should be deleted if:
|
||||
1. It belongs to a sandbox tenant AND
|
||||
2. Either:
|
||||
a) The tenant has no previous subscription (expiration_date == -1), OR
|
||||
b) The subscription expired more than graceful_period_days ago
|
||||
|
||||
Args:
|
||||
messages: List of message objects with id and app_id attributes
|
||||
app_to_tenant: Mapping from app_id to tenant_id
|
||||
tenant_plans: Mapping from tenant_id to subscription plan info
|
||||
|
||||
Returns:
|
||||
List of message IDs that should be deleted
|
||||
"""
|
||||
current_timestamp = self._current_timestamp
|
||||
if current_timestamp is None:
|
||||
current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
|
||||
|
||||
sandbox_message_ids: list[str] = []
|
||||
graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60
|
||||
|
||||
for msg in messages:
|
||||
# Get tenant_id for this message's app
|
||||
tenant_id = app_to_tenant.get(msg.app_id)
|
||||
if not tenant_id:
|
||||
continue
|
||||
|
||||
# Skip tenant messages in whitelist
|
||||
if tenant_id in self._tenant_whitelist:
|
||||
continue
|
||||
|
||||
# Get subscription plan for this tenant
|
||||
tenant_plan = tenant_plans.get(tenant_id)
|
||||
if not tenant_plan:
|
||||
continue
|
||||
|
||||
plan = str(tenant_plan["plan"])
|
||||
expiration_date = int(tenant_plan["expiration_date"])
|
||||
|
||||
# Only process sandbox plans
|
||||
if plan != CloudPlan.SANDBOX:
|
||||
continue
|
||||
|
||||
# Case 1: No previous subscription (-1 means never had a paid subscription)
|
||||
if expiration_date == -1:
|
||||
sandbox_message_ids.append(msg.id)
|
||||
continue
|
||||
|
||||
# Case 2: Subscription expired beyond grace period
|
||||
if current_timestamp - expiration_date > graceful_period_seconds:
|
||||
sandbox_message_ids.append(msg.id)
|
||||
|
||||
return sandbox_message_ids
|
||||
|
||||
|
||||
def create_message_clean_policy(
|
||||
graceful_period_days: int = 21,
|
||||
current_timestamp: int | None = None,
|
||||
) -> MessagesCleanPolicy:
|
||||
"""
|
||||
Factory function to create the appropriate message clean policy.
|
||||
|
||||
Determines which policy to use based on BILLING_ENABLED configuration:
|
||||
- If BILLING_ENABLED is True: returns BillingSandboxPolicy
|
||||
- If BILLING_ENABLED is False: returns BillingDisabledPolicy
|
||||
|
||||
Args:
|
||||
graceful_period_days: Grace period in days after subscription expiration (default: 21)
|
||||
current_timestamp: Current Unix timestamp for testing (default: None, uses current time)
|
||||
"""
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy")
|
||||
return BillingDisabledPolicy()
|
||||
|
||||
# Billing enabled - fetch whitelist from BillingService
|
||||
tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist()
|
||||
plan_provider = BillingService.get_plan_bulk_with_cache
|
||||
|
||||
logger.info(
|
||||
"create_message_clean_policy: billing enabled, using BillingSandboxPolicy "
|
||||
"(graceful_period_days=%s, whitelist=%s)",
|
||||
graceful_period_days,
|
||||
tenant_whitelist,
|
||||
)
|
||||
|
||||
return BillingSandboxPolicy(
|
||||
plan_provider=plan_provider,
|
||||
graceful_period_days=graceful_period_days,
|
||||
tenant_whitelist=tenant_whitelist,
|
||||
current_timestamp=current_timestamp,
|
||||
)
|
||||
334
api/services/retention/conversation/messages_clean_service.py
Normal file
334
api/services/retention/conversation/messages_clean_service.py
Normal file
@@ -0,0 +1,334 @@
|
||||
import datetime
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
DatasetRetrieverResource,
|
||||
Message,
|
||||
MessageAgentThought,
|
||||
MessageAnnotation,
|
||||
MessageChain,
|
||||
MessageFeedback,
|
||||
MessageFile,
|
||||
)
|
||||
from models.web import SavedMessage
|
||||
from services.retention.conversation.messages_clean_policy import (
|
||||
MessagesCleanPolicy,
|
||||
SimpleMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessagesCleanService:
|
||||
"""
|
||||
Service for cleaning expired messages based on retention policies.
|
||||
|
||||
Compatible with non cloud edition (billing disabled): all messages in the time range will be deleted.
|
||||
If billing is enabled: only sandbox plan tenant messages are deleted (with whitelist and grace period support).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: MessagesCleanPolicy,
|
||||
end_before: datetime.datetime,
|
||||
start_from: datetime.datetime | None = None,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the service with cleanup parameters.
|
||||
|
||||
Args:
|
||||
policy: The policy that determines which messages to delete
|
||||
end_before: End time (exclusive) of the range
|
||||
start_from: Optional start time (inclusive) of the range
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
"""
|
||||
self._policy = policy
|
||||
self._end_before = end_before
|
||||
self._start_from = start_from
|
||||
self._batch_size = batch_size
|
||||
self._dry_run = dry_run
|
||||
|
||||
@classmethod
|
||||
def from_time_range(
|
||||
cls,
|
||||
policy: MessagesCleanPolicy,
|
||||
start_from: datetime.datetime,
|
||||
end_before: datetime.datetime,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
) -> "MessagesCleanService":
|
||||
"""
|
||||
Create a service instance for cleaning messages within a specific time range.
|
||||
|
||||
Time range is [start_from, end_before).
|
||||
|
||||
Args:
|
||||
policy: The policy that determines which messages to delete
|
||||
start_from: Start time (inclusive) of the range
|
||||
end_before: End time (exclusive) of the range
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
|
||||
Returns:
|
||||
MessagesCleanService instance
|
||||
|
||||
Raises:
|
||||
ValueError: If start_from >= end_before or invalid parameters
|
||||
"""
|
||||
if start_from >= end_before:
|
||||
raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
|
||||
|
||||
if batch_size <= 0:
|
||||
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
|
||||
|
||||
logger.info(
|
||||
"clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s",
|
||||
start_from,
|
||||
end_before,
|
||||
batch_size,
|
||||
policy.__class__.__name__,
|
||||
)
|
||||
|
||||
return cls(
|
||||
policy=policy,
|
||||
end_before=end_before,
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_days(
|
||||
cls,
|
||||
policy: MessagesCleanPolicy,
|
||||
days: int = 30,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
) -> "MessagesCleanService":
|
||||
"""
|
||||
Create a service instance for cleaning messages older than specified days.
|
||||
|
||||
Args:
|
||||
policy: The policy that determines which messages to delete
|
||||
days: Number of days to look back from now
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
|
||||
Returns:
|
||||
MessagesCleanService instance
|
||||
|
||||
Raises:
|
||||
ValueError: If invalid parameters
|
||||
"""
|
||||
if days < 0:
|
||||
raise ValueError(f"days ({days}) must be greater than or equal to 0")
|
||||
|
||||
if batch_size <= 0:
|
||||
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
|
||||
|
||||
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
|
||||
|
||||
logger.info(
|
||||
"clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
|
||||
days,
|
||||
end_before,
|
||||
batch_size,
|
||||
policy.__class__.__name__,
|
||||
)
|
||||
|
||||
return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
|
||||
|
||||
def run(self) -> dict[str, int]:
|
||||
"""
|
||||
Execute the message cleanup operation.
|
||||
|
||||
Returns:
|
||||
Dict with statistics: batches, filtered_messages, total_deleted
|
||||
"""
|
||||
return self._clean_messages_by_time_range()
|
||||
|
||||
def _clean_messages_by_time_range(self) -> dict[str, int]:
|
||||
"""
|
||||
Clean messages within a time range using cursor-based pagination.
|
||||
|
||||
Time range is [start_from, end_before)
|
||||
|
||||
Steps:
|
||||
1. Iterate messages using cursor pagination (by created_at, id)
|
||||
2. Query app_id -> tenant_id mapping
|
||||
3. Delegate to policy to determine which messages to delete
|
||||
4. Batch delete messages and their relations
|
||||
|
||||
Returns:
|
||||
Dict with statistics: batches, filtered_messages, total_deleted
|
||||
"""
|
||||
stats = {
|
||||
"batches": 0,
|
||||
"total_messages": 0,
|
||||
"filtered_messages": 0,
|
||||
"total_deleted": 0,
|
||||
}
|
||||
|
||||
# Cursor-based pagination using (created_at, id) to avoid infinite loops
|
||||
# and ensure proper ordering with time-based filtering
|
||||
_cursor: tuple[datetime.datetime, str] | None = None
|
||||
|
||||
logger.info(
|
||||
"clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
|
||||
self._dry_run,
|
||||
self._start_from,
|
||||
self._end_before,
|
||||
)
|
||||
|
||||
while True:
|
||||
stats["batches"] += 1
|
||||
|
||||
# Step 1: Fetch a batch of messages using cursor
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
msg_stmt = (
|
||||
select(Message.id, Message.app_id, Message.created_at)
|
||||
.where(Message.created_at < self._end_before)
|
||||
.order_by(Message.created_at, Message.id)
|
||||
.limit(self._batch_size)
|
||||
)
|
||||
|
||||
if self._start_from:
|
||||
msg_stmt = msg_stmt.where(Message.created_at >= self._start_from)
|
||||
|
||||
# Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
|
||||
# This translates to:
|
||||
# created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
|
||||
if _cursor:
|
||||
# Continuing from previous batch
|
||||
msg_stmt = msg_stmt.where(
|
||||
(Message.created_at > _cursor[0])
|
||||
| ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
|
||||
)
|
||||
|
||||
raw_messages = list(session.execute(msg_stmt).all())
|
||||
messages = [
|
||||
SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
|
||||
for msg_id, app_id, msg_created_at in raw_messages
|
||||
]
|
||||
|
||||
# Track total messages fetched across all batches
|
||||
stats["total_messages"] += len(messages)
|
||||
|
||||
if not messages:
|
||||
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
|
||||
break
|
||||
|
||||
# Update cursor to the last message's (created_at, id)
|
||||
_cursor = (messages[-1].created_at, messages[-1].id)
|
||||
|
||||
# Step 2: Extract app_ids and query tenant_ids
|
||||
app_ids = list({msg.app_id for msg in messages})
|
||||
|
||||
if not app_ids:
|
||||
logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
|
||||
continue
|
||||
|
||||
app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
|
||||
apps = list(session.execute(app_stmt).all())
|
||||
|
||||
if not apps:
|
||||
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
|
||||
continue
|
||||
|
||||
# Build app_id -> tenant_id mapping
|
||||
app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
|
||||
|
||||
# Step 3: Delegate to policy to determine which messages to delete
|
||||
message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant)
|
||||
|
||||
if not message_ids_to_delete:
|
||||
logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
|
||||
continue
|
||||
|
||||
stats["filtered_messages"] += len(message_ids_to_delete)
|
||||
|
||||
# Step 4: Batch delete messages and their relations
|
||||
if not self._dry_run:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Delete related records first
|
||||
self._batch_delete_message_relations(session, message_ids_to_delete)
|
||||
|
||||
# Delete messages
|
||||
delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete))
|
||||
delete_result = cast(CursorResult, session.execute(delete_stmt))
|
||||
messages_deleted = delete_result.rowcount
|
||||
session.commit()
|
||||
|
||||
stats["total_deleted"] += messages_deleted
|
||||
|
||||
logger.info(
|
||||
"clean_messages (batch %s): processed %s messages, deleted %s messages",
|
||||
stats["batches"],
|
||||
len(messages),
|
||||
messages_deleted,
|
||||
)
|
||||
else:
|
||||
# Log random sample of message IDs that would be deleted (up to 10)
|
||||
sample_size = min(10, len(message_ids_to_delete))
|
||||
sampled_ids = random.sample(list(message_ids_to_delete), sample_size)
|
||||
|
||||
logger.info(
|
||||
"clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:",
|
||||
stats["batches"],
|
||||
len(message_ids_to_delete),
|
||||
sample_size,
|
||||
)
|
||||
for msg_id in sampled_ids:
|
||||
logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
|
||||
|
||||
logger.info(
|
||||
"clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
|
||||
stats["batches"],
|
||||
stats["total_messages"],
|
||||
stats["filtered_messages"],
|
||||
stats["total_deleted"],
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
@staticmethod
|
||||
def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None:
|
||||
"""
|
||||
Batch delete all related records for given message IDs.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
message_ids: List of message IDs to delete relations for
|
||||
"""
|
||||
if not message_ids:
|
||||
return
|
||||
|
||||
# Delete all related records in batch
|
||||
session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))
|
||||
|
||||
session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))
|
||||
|
||||
session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))
|
||||
|
||||
session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))
|
||||
|
||||
session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))
|
||||
|
||||
session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))
|
||||
|
||||
session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))
|
||||
|
||||
session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))
|
||||
Reference in New Issue
Block a user