refactor(api): replace json.loads with Pydantic validation in controllers and infra layers (#34277)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Dream
2026-04-01 01:41:44 -04:00
committed by GitHub
parent 09ee8ea1f5
commit c51cd42cb4
23 changed files with 170 additions and 114 deletions

View File

@@ -9,7 +9,7 @@ from graphon.enums import NodeType
from graphon.file import File
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, ValidationError, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
@@ -268,22 +268,18 @@ class DraftWorkflowApi(Resource):
content_type = request.headers.get("Content-Type", "")
payload_data: dict[str, Any] | None = None
if "application/json" in content_type:
payload_data = request.get_json(silent=True)
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
elif "text/plain" in content_type:
try:
payload_data = json.loads(request.data.decode("utf-8"))
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
if not isinstance(payload_data, dict):
args_model = SyncDraftWorkflowPayload.model_validate_json(request.data)
except (ValueError, ValidationError):
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
args = args_model.model_dump()
workflow_service = WorkflowService()

View File

@@ -5,7 +5,7 @@ from typing import Any, Literal, cast
from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
@@ -186,29 +186,14 @@ class DraftRagPipelineApi(Resource):
if "application/json" in content_type:
payload_dict = console_ns.payload or {}
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
elif "text/plain" in content_type:
try:
data = json.loads(request.data.decode("utf-8"))
if "graph" not in data or "features" not in data:
raise ValueError("graph or features not found in data")
if not isinstance(data.get("graph"), dict):
raise ValueError("graph is not a dict")
payload_dict = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
"rag_pipeline_variables": data.get("rag_pipeline_variables"),
}
except json.JSONDecodeError:
payload = DraftWorkflowSyncPayload.model_validate_json(request.data)
except (ValueError, ValidationError):
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
rag_pipeline_service = RagPipelineService()
try:

View File

@@ -38,6 +38,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
@@ -469,7 +470,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
if trace_info.message_data and trace_info.message_data.message_metadata:
metadata_dict = json.loads(trace_info.message_data.message_metadata)
metadata_dict = JSON_DICT_ADAPTER.validate_json(trace_info.message_data.message_metadata)
if model_params := metadata_dict.get("model_parameters"):
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)

View File

@@ -1,4 +1,3 @@
import json
import logging
import os
from datetime import datetime, timedelta
@@ -25,6 +24,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.utils import JSON_DICT_ADAPTER
from extensions.ext_database import db
from models import EndUser
from models.workflow import WorkflowNodeExecutionModel
@@ -153,7 +153,7 @@ class MLflowDataTrace(BaseTraceInstance):
inputs = node.process_data # contains request URL
if not inputs:
inputs = json.loads(node.inputs) if node.inputs else {}
inputs = JSON_DICT_ADAPTER.validate_json(node.inputs) if node.inputs else {}
node_span = start_span_no_context(
name=node.title,
@@ -180,7 +180,7 @@ class MLflowDataTrace(BaseTraceInstance):
# End node span
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
outputs = json.loads(node.outputs) if node.outputs else {}
outputs = JSON_DICT_ADAPTER.validate_json(node.outputs) if node.outputs else {}
if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
outputs = self._parse_knowledge_retrieval_outputs(outputs)
elif node.node_type == BuiltinNodeTypes.LLM:
@@ -216,8 +216,8 @@ class MLflowDataTrace(BaseTraceInstance):
return {}, {}
try:
data = json.loads(node.process_data)
except (json.JSONDecodeError, TypeError):
data = JSON_DICT_ADAPTER.validate_json(node.process_data)
except (ValueError, TypeError):
return {}, {}
inputs = self._parse_prompts(data.get("prompts"))

View File

@@ -11,8 +11,10 @@ from uuid import UUID, uuid4
from cachetools import LRUCache
from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import TypedDict
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
@@ -33,7 +35,7 @@ from core.ops.entities.trace_entity import (
WorkflowNodeTraceInfo,
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from core.ops.utils import JSON_DICT_ADAPTER, get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
@@ -50,6 +52,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class _AppTracingConfig(TypedDict, total=False):
enabled: bool
tracing_provider: str | None
_app_tracing_config_adapter: TypeAdapter[_AppTracingConfig] = TypeAdapter(_AppTracingConfig)
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
"""Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
app_name = ""
@@ -468,7 +478,7 @@ class OpsTraceManager:
if app is None:
return None
app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
app_ops_trace_config = _app_tracing_config_adapter.validate_json(app.tracing) if app.tracing else None
if app_ops_trace_config is None:
return None
if not app_ops_trace_config.get("enabled"):
@@ -560,7 +570,7 @@ class OpsTraceManager:
raise ValueError("App not found")
if not app.tracing:
return {"enabled": False, "tracing_provider": None}
app_trace_config = json.loads(app.tracing)
app_trace_config = _app_tracing_config_adapter.validate_json(app.tracing)
return app_trace_config
@staticmethod
@@ -636,7 +646,6 @@ class TraceTask:
carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading
large JSON blobs unnecessarily.
"""
import json
from models.workflow import WorkflowNodeExecutionModel
@@ -658,7 +667,7 @@ class TraceTask:
if not raw:
continue
try:
outputs = json.loads(raw) if isinstance(raw, str) else raw
outputs = JSON_DICT_ADAPTER.validate_json(raw) if isinstance(raw, str) else raw
except (ValueError, TypeError):
continue
if not isinstance(outputs, dict):
@@ -1420,7 +1429,7 @@ class TraceTask:
return {}
try:
metadata = json.loads(message_data.message_metadata)
metadata = JSON_DICT_ADAPTER.validate_json(message_data.message_metadata)
usage = metadata.get("usage", {})
time_to_first_token = usage.get("time_to_first_token")
time_to_generate = usage.get("time_to_generate")
@@ -1430,7 +1439,7 @@ class TraceTask:
"llm_streaming_time_to_generate": time_to_generate,
"is_streaming_request": time_to_first_token is not None,
}
except (json.JSONDecodeError, AttributeError):
except (ValueError, AttributeError):
return {}

View File

@@ -3,11 +3,14 @@ from datetime import datetime
from typing import Any, Union
from urllib.parse import urlparse
from pydantic import TypeAdapter
from sqlalchemy import select
from models.engine import db
from models.model import Message
JSON_DICT_ADAPTER: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
def filter_none_values(data: dict[str, Any]) -> dict[str, Any]:
new_data = {}

View File

@@ -10,6 +10,7 @@ from mysql.connector import Error as MySQLError
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -178,9 +179,7 @@ class AlibabaCloudMySQLVector(BaseVector):
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
docs = []
for record in cur:
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
metadata = parse_metadata_json(record["meta"])
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs
@@ -263,15 +262,13 @@ class AlibabaCloudMySQLVector(BaseVector):
# similarity = 1 / (1 + distance)
similarity = 1.0 / (1.0 + distance)
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
metadata = parse_metadata_json(record["meta"])
metadata["score"] = similarity
metadata["distance"] = distance
if similarity >= score_threshold:
docs.append(Document(page_content=record["text"], metadata=metadata))
except (ValueError, json.JSONDecodeError) as e:
except (ValueError, TypeError) as e:
logger.warning("Error processing search result: %s", e)
continue
@@ -306,9 +303,7 @@ class AlibabaCloudMySQLVector(BaseVector):
)
docs = []
for record in cur:
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
metadata = parse_metadata_json(record["meta"])
metadata["score"] = float(record["score"])
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs

View File

@@ -8,6 +8,7 @@ _import_err_msg = (
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
@@ -257,7 +258,7 @@ class AnalyticdbVectorOpenAPI:
documents = []
for match in response.body.matches.match:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata = parse_metadata_json(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
@@ -294,7 +295,7 @@ class AnalyticdbVectorOpenAPI:
documents = []
for match in response.body.matches.match:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata = parse_metadata_json(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),

View File

@@ -29,6 +29,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams,
from configs import dify_config
from core.rag.datasource.vdb.field import Field as VDBField
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -173,15 +174,9 @@ class BaiduVector(BaseVector):
score = row.get("score", 0.0)
meta = row_data.get(VDBField.METADATA_KEY, {})
# Handle both JSON string and dict formats for backward compatibility
if isinstance(meta, str):
try:
import json
meta = json.loads(meta)
except (json.JSONDecodeError, TypeError):
meta = {}
elif not isinstance(meta, dict):
try:
meta = parse_metadata_json(meta)
except (ValueError, TypeError):
meta = {}
if score >= score_threshold:

View File

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
from clickzetta.connector.v0.connection import Connection # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.field import Field, parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.embedding.embedding_base import Embeddings
@@ -357,18 +357,19 @@ class ClickzettaVector(BaseVector):
"""
try:
if raw_metadata:
metadata = json.loads(raw_metadata)
# First parse may yield a string (double-encoded JSON) so use json.loads
first_pass = json.loads(raw_metadata)
# Handle double-encoded JSON
if isinstance(metadata, str):
metadata = json.loads(metadata)
# Ensure we have a dict
if not isinstance(metadata, dict):
if isinstance(first_pass, str):
metadata = parse_metadata_json(first_pass)
elif isinstance(first_pass, dict):
metadata = first_pass
else:
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, TypeError):
except (json.JSONDecodeError, ValueError, TypeError):
logger.exception("JSON parsing failed for metadata")
# Fallback: extract document_id with regex
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "")
@@ -930,17 +931,18 @@ class ClickzettaVector(BaseVector):
# Parse metadata from JSON string (may be double-encoded)
try:
if row[2]:
metadata = json.loads(row[2])
# First parse may yield a string (double-encoded JSON)
first_pass = json.loads(row[2])
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
if not isinstance(metadata, dict):
if isinstance(first_pass, str):
metadata = parse_metadata_json(first_pass)
elif isinstance(first_pass, dict):
metadata = first_pass
else:
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, TypeError):
except (json.JSONDecodeError, ValueError, TypeError):
logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex

View File

@@ -1,4 +1,24 @@
from enum import StrEnum, auto
from typing import Any
from pydantic import TypeAdapter
_metadata_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
def parse_metadata_json(raw: Any) -> dict[str, Any]:
"""Parse metadata from a JSON string or pass through an existing dict.
Many VDB drivers return metadata as either a JSON string or an already-
decoded dict depending on the column type and driver version.
"""
if raw is None or raw in ("", b""):
return {}
if isinstance(raw, dict):
return raw
if not isinstance(raw, (str, bytes, bytearray)):
return {}
return _metadata_adapter.validate_json(raw)
class Field(StrEnum):

View File

@@ -9,6 +9,7 @@ from psycopg import sql as psql
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -217,8 +218,7 @@ class HologresVector(BaseVector):
text = row[2]
meta = row[3]
if isinstance(meta, str):
meta = json.loads(meta)
meta = parse_metadata_json(meta)
# Convert distance to similarity score (consistent with pgvector)
score = 1 - distance
@@ -265,8 +265,7 @@ class HologresVector(BaseVector):
meta = row[2]
score = row[-1] # score is the last column from return_score
if isinstance(meta, str):
meta = json.loads(meta)
meta = parse_metadata_json(meta)
meta["score"] = score
docs.append(Document(page_content=text, metadata=meta))

View File

@@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any
from configs import dify_config
from configs.middleware.vdb.iris_config import IrisVectorConfig
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -269,7 +270,7 @@ class IrisVector(BaseVector):
if len(row) >= 4:
text, meta_str, score = row[1], row[2], float(row[3])
if score >= score_threshold:
metadata = json.loads(meta_str) if meta_str else {}
metadata = parse_metadata_json(meta_str)
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
@@ -384,7 +385,7 @@ class IrisVector(BaseVector):
meta_str = row[2]
score_value = row[3]
metadata = json.loads(meta_str) if meta_str else {}
metadata = parse_metadata_json(meta_str)
# Add score to metadata for hybrid search compatibility
score = float(score_value) if score_value is not None else 0.0
metadata["score"] = score

View File

@@ -9,6 +9,7 @@ from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -196,11 +197,7 @@ class MatrixoneVector(BaseVector):
docs = []
for result in results:
metadata = result.metadata
if isinstance(metadata, str):
import json
metadata = json.loads(metadata)
metadata = parse_metadata_json(result.metadata)
score = 1 - result.distance
if score >= score_threshold:
metadata["score"] = score

View File

@@ -10,6 +10,7 @@ from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -366,8 +367,8 @@ class OceanBaseVector(BaseVector):
# Parse metadata JSON
try:
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str
except json.JSONDecodeError:
metadata = parse_metadata_json(metadata_str)
except (ValueError, TypeError):
logger.warning("Invalid JSON metadata: %s", metadata_str)
metadata = {}

View File

@@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator
from tablestore import BatchGetRowRequest, TableInBatchGetRowItem
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.field import Field, parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -73,7 +73,8 @@ class TableStoreVector(BaseVector):
for item in table_result:
if item.is_ok and item.row:
kv = {k: v for k, v, _ in item.row.attribute_columns}
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY])))
metadata = parse_metadata_json(kv[Field.METADATA_KEY])
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=metadata))
return docs
def get_type(self) -> str:
@@ -311,7 +312,7 @@ class TableStoreVector(BaseVector):
metadata_str = ots_column_map.get(Field.METADATA_KEY)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
metadata = parse_metadata_json(metadata_str)
metadata["score"] = search_hit.score
@@ -371,7 +372,7 @@ class TableStoreVector(BaseVector):
ots_column_map[col[0]] = col[1]
metadata_str = ots_column_map.get(Field.METADATA_KEY)
metadata = json.loads(metadata_str) if metadata_str else {}
metadata = parse_metadata_json(metadata_str)
vector_str = ots_column_map.get(Field.VECTOR)
vector = json.loads(vector_str) if vector_str else None

View File

@@ -11,6 +11,7 @@ from tcvectordb.model import index as vdb_index # type: ignore
from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, WeightedRerank # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -286,13 +287,10 @@ class TencentVector(BaseVector):
return docs
for result in res[0]:
meta = result.get(self.field_metadata)
if isinstance(meta, str):
# Compatible with version 1.1.3 and below.
meta = json.loads(meta)
score = 1 - result.get("score", 0.0)
else:
score = result.get("score", 0.0)
raw_meta = result.get(self.field_metadata)
# Compatible with version 1.1.3 and below: str means old driver.
score = (1 - result.get("score", 0.0)) if isinstance(raw_meta, str) else result.get("score", 0.0)
meta = parse_metadata_json(raw_meta)
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)

View File

@@ -9,7 +9,7 @@ from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.field import Field, parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -228,7 +228,7 @@ class TiDBVector(BaseVector):
)
results = [(row[0], row[1], row[2]) for row in res]
for meta, text, distance in results:
metadata = json.loads(meta)
metadata = parse_metadata_json(meta)
metadata["score"] = 1 - distance
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@@ -15,6 +15,7 @@ from volcengine.viking_db import ( # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field as vdb_Field
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -163,7 +164,7 @@ class VikingDBVector(BaseVector):
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
metadata = parse_metadata_json(metadata)
if metadata.get(key) == value:
ids.append(result.id)
return ids
@@ -189,9 +190,7 @@ class VikingDBVector(BaseVector):
docs = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
metadata = parse_metadata_json(result.fields.get(vdb_Field.METADATA_KEY))
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata)

View File

@@ -20,6 +20,7 @@ from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository
from extensions.logstore.aliyun_logstore import AliyunLogStore
@@ -48,10 +49,10 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
"""
logger.debug("_dict_to_workflow_node_execution: data keys=%s", list(data.keys())[:5])
# Parse JSON fields
inputs = json.loads(data.get("inputs", "{}"))
process_data = json.loads(data.get("process_data", "{}"))
outputs = json.loads(data.get("outputs", "{}"))
metadata = json.loads(data.get("execution_metadata", "{}"))
inputs = JSON_DICT_ADAPTER.validate_json(data.get("inputs") or "{}")
process_data = JSON_DICT_ADAPTER.validate_json(data.get("process_data") or "{}")
outputs = JSON_DICT_ADAPTER.validate_json(data.get("outputs") or "{}")
metadata = JSON_DICT_ADAPTER.validate_json(data.get("execution_metadata") or "{}")
# Convert metadata to domain enum keys
domain_metadata = {}

View File

@@ -15,8 +15,12 @@ from datetime import datetime
from enum import StrEnum, auto
from typing import Any
from pydantic import TypeAdapter
logger = logging.getLogger(__name__)
_metadata_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
class FileStatus(StrEnum):
"""File status enumeration"""
@@ -455,8 +459,8 @@ class FileLifecycleManager:
try:
if self._storage.exists(self._metadata_file):
metadata_content = self._storage.load_once(self._metadata_file)
result = json.loads(metadata_content.decode("utf-8"))
return dict(result) if result else {}
result = _metadata_adapter.validate_json(metadata_content)
return result or {}
else:
return {}
except Exception as e:

View File

@@ -1,13 +1,16 @@
import base64
import io
import json
from collections.abc import Generator
from typing import Any
from google.cloud import storage as google_cloud_storage # type: ignore
from pydantic import TypeAdapter
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
_service_account_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
class GoogleCloudStorage(BaseStorage):
"""Implementation for Google Cloud storage."""
@@ -21,7 +24,7 @@ class GoogleCloudStorage(BaseStorage):
if service_account_json_str:
service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
# convert str to object
service_account_obj = json.loads(service_account_json)
service_account_obj = _service_account_adapter.validate_json(service_account_json)
self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj)
else:
self.client = google_cloud_storage.Client()

View File

@@ -0,0 +1,45 @@
import pytest
from core.rag.datasource.vdb.field import parse_metadata_json
class TestParseMetadataJson:
def test_none_returns_empty_dict(self):
assert parse_metadata_json(None) == {}
def test_empty_string_returns_empty_dict(self):
assert parse_metadata_json("") == {}
def test_valid_json_string(self):
result = parse_metadata_json('{"doc_id": "abc", "score": 0.9}')
assert result == {"doc_id": "abc", "score": 0.9}
def test_dict_passthrough(self):
original = {"doc_id": "abc", "document_id": "123"}
result = parse_metadata_json(original)
assert result == original
def test_empty_json_object(self):
assert parse_metadata_json("{}") == {}
def test_invalid_json_raises_value_error(self):
with pytest.raises(ValueError):
parse_metadata_json("{invalid json")
def test_nested_metadata(self):
result = parse_metadata_json('{"doc_id": "1", "extra": {"nested": true}}')
assert result["extra"]["nested"] is True
def test_non_str_non_dict_returns_empty_dict(self):
assert parse_metadata_json(123) == {}
assert parse_metadata_json([1, 2]) == {}
def test_bytes_input(self):
result = parse_metadata_json(b'{"key": "value"}')
assert result == {"key": "value"}
def test_empty_bytes_returns_empty_dict(self):
assert parse_metadata_json(b"") == {}
def test_empty_bytearray_returns_empty_dict(self):
assert parse_metadata_json(bytearray(b"")) == {}