From 71d299d0d3665af2979db65a0e00df791fcf5bd9 Mon Sep 17 00:00:00 2001 From: YBoy Date: Fri, 3 Apr 2026 04:25:30 +0200 Subject: [PATCH] refactor(api): type hit testing retrieve responses with TypedDict (#34484) --- api/services/hit_testing_service.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 82e0b0f8b1f..7900f6da266 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,7 +1,7 @@ import json import logging import time -from typing import Any +from typing import Any, TypedDict from graphon.model_runtime.entities import LLMMode @@ -18,6 +18,16 @@ from models.enums import CreatorUserRole, DatasetQuerySource logger = logging.getLogger(__name__) + +class QueryDict(TypedDict): + content: str + + +class RetrieveResponseDict(TypedDict): + query: QueryDict + records: list[dict[str, Any]] + + default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, @@ -150,7 +160,7 @@ class HitTestingService: return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod - def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]: + def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict: records = RetrievalService.format_retrieval_documents(documents) return { @@ -161,7 +171,7 @@ class HitTestingService: } @classmethod - def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]: + def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> RetrieveResponseDict: records = [] if dataset.provider == "external": for document in documents: