From e85d9a0d721ed3dfa988e275dc5b96ef6d7b62f7 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Fri, 3 Apr 2026 14:01:31 +0200 Subject: [PATCH] refactor: select in dataset_service (DatasetService class) (#34525) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/dataset_service.py | 60 ++++++++++--------- .../services/test_dataset_service_dataset.py | 51 ++++++++-------- 2 files changed, 56 insertions(+), 55 deletions(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 53bc51d4574..4e1fe3f6a1a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -14,7 +14,7 @@ from graphon.file import helpers as file_helpers from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from redis.exceptions import LockNotOwnedError -from sqlalchemy import exists, func, select +from sqlalchemy import exists, func, select, update from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound @@ -114,9 +114,11 @@ class DatasetService: if user: # get permitted dataset ids - dataset_permission = ( - db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all() - ) + dataset_permission = db.session.scalars( + select(DatasetPermission).where( + DatasetPermission.account_id == user.id, DatasetPermission.tenant_id == tenant_id + ) + ).all() permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None if user.current_role == TenantAccountRole.DATASET_OPERATOR: @@ -182,13 +184,12 @@ class DatasetService: @staticmethod def get_process_rules(dataset_id): # get the latest process rule - dataset_process_rule = ( - db.session.query(DatasetProcessRule) + dataset_process_rule = db.session.execute( + select(DatasetProcessRule) .where(DatasetProcessRule.dataset_id == dataset_id) .order_by(DatasetProcessRule.created_at.desc()) .limit(1) - .one_or_none() - ) + ).scalar_one_or_none() if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict @@ -225,7 +226,7 @@ class DatasetService: summary_index_setting: dict | None = None, ): # check if dataset name already exists - if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): + if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None if indexing_technique == IndexTechniqueType.HIGH_QUALITY: @@ -300,17 +301,17 @@ class DatasetService: ): if rag_pipeline_dataset_create_entity.name: # check if dataset name already exists - if ( - db.session.query(Dataset) - .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) - .first() + if db.session.scalar( + select(Dataset) + .where(Dataset.name == rag_pipeline_dataset_create_entity.name, Dataset.tenant_id == tenant_id) + .limit(1) ): raise DatasetNameDuplicateError( f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." ) else: # generate a random name as Untitled 1 2 3 ... - datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all() + datasets = db.session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all() names = [dataset.name for dataset in datasets] rag_pipeline_dataset_create_entity.name = generate_incremental_name( names, @@ -344,7 +345,7 @@ class DatasetService: @staticmethod def get_dataset(dataset_id) -> Dataset | None: - dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first() + dataset: Dataset | None = db.session.get(Dataset, dataset_id) return dataset @staticmethod @@ -466,14 +467,14 @@ class DatasetService: @staticmethod def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str): - dataset = ( - db.session.query(Dataset) + dataset = db.session.scalar( + select(Dataset) .where( Dataset.id != dataset_id, Dataset.name == name, Dataset.tenant_id == tenant_id, ) - .first() + .limit(1) ) return dataset is not None @@ -596,7 +597,7 @@ class DatasetService: filtered_data["icon_info"] = data.get("icon_info") # Update dataset in database - db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) + db.session.execute(update(Dataset).where(Dataset.id == dataset.id).values(**filtered_data)) db.session.commit() # Reload dataset to get updated values @@ -631,7 +632,7 @@ class DatasetService: if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE: return - pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first() + pipeline = db.session.get(Pipeline, dataset.pipeline_id) if not pipeline: return @@ -1138,8 +1139,10 @@ class DatasetService: if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: # For partial team permission, user needs explicit permission or be the creator if dataset.created_by != user.id: - user_permission = ( - db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first() + user_permission = db.session.scalar( + select(DatasetPermission) + .where(DatasetPermission.dataset_id == dataset.id, DatasetPermission.account_id == user.id) + .limit(1) ) if not user_permission: logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id) @@ -1161,7 +1164,9 @@ class DatasetService: elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: if not any( dp.dataset_id == dataset.id - for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all() + for dp in db.session.scalars( + select(DatasetPermission).where(DatasetPermission.account_id == user.id) + ).all() ): raise NoPermissionError("You do not have permission to access this dataset.") @@ -1175,12 +1180,11 @@ class DatasetService: @staticmethod def get_related_apps(dataset_id: str): - return ( - db.session.query(AppDatasetJoin) + return db.session.scalars( + select(AppDatasetJoin) .where(AppDatasetJoin.dataset_id == dataset_id) - .order_by(db.desc(AppDatasetJoin.created_at)) - .all() - ) + .order_by(AppDatasetJoin.created_at.desc()) + ).all() @staticmethod def update_dataset_api_status(dataset_id: str, status: bool): diff --git a/api/tests/unit_tests/services/test_dataset_service_dataset.py b/api/tests/unit_tests/services/test_dataset_service_dataset.py index 92aed7c30a8..849229ff435 100644 --- a/api/tests/unit_tests/services/test_dataset_service_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_dataset.py @@ -62,7 +62,7 @@ class TestDatasetServiceQueries: self, mock_dataset_query_dependencies ): user = DatasetServiceUnitDataFactory.create_user_mock(role=TenantAccountRole.DATASET_OPERATOR) - mock_dataset_query_dependencies["db"].session.query.return_value.filter_by.return_value.all.return_value = [] + mock_dataset_query_dependencies["db"].session.scalars.return_value.all.return_value = [] items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1", user=user) @@ -108,9 +108,7 @@ class TestDatasetServiceQueries: dataset_process_rule.rules_dict = {"delimiter": "\n"} with patch("services.dataset_service.db") as mock_db: - ( - mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value - ) = dataset_process_rule + (mock_db.session.execute.return_value.scalar_one_or_none.return_value) = dataset_process_rule result = DatasetService.get_process_rules("dataset-1") @@ -118,9 +116,7 @@ class TestDatasetServiceQueries: def test_get_process_rules_falls_back_to_default_rules_when_missing(self): with patch("services.dataset_service.db") as mock_db: - ( - mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value - ) = None + (mock_db.session.execute.return_value.scalar_one_or_none.return_value) = None result = DatasetService.get_process_rules("dataset-1") @@ -151,7 +147,7 @@ class TestDatasetServiceQueries: dataset = DatasetServiceUnitDataFactory.create_dataset_mock() with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset + mock_db.session.get.return_value = dataset result = DatasetService.get_dataset(dataset.id) @@ -308,7 +304,7 @@ class TestDatasetServiceCreationAndUpdate: account = SimpleNamespace(id="user-1") with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + mock_db.session.scalar.return_value = object() with pytest.raises(DatasetNameDuplicateError, match="Dataset with name Dataset already exists"): DatasetService.create_empty_dataset("tenant-1", "Dataset", None, "economy", account) @@ -319,6 +315,7 @@ class TestDatasetServiceCreationAndUpdate: with ( patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.select"), patch( "services.dataset_service.Dataset", side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs), @@ -326,7 +323,7 @@ class TestDatasetServiceCreationAndUpdate: patch("services.dataset_service.ModelManager") as model_manager_cls, patch.object(DatasetService, "check_embedding_model_setting") as check_embedding, ): - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model dataset = DatasetService.create_empty_dataset( @@ -355,6 +352,7 @@ class TestDatasetServiceCreationAndUpdate: with ( patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.select"), patch( "services.dataset_service.Dataset", side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs), @@ -368,7 +366,7 @@ class TestDatasetServiceCreationAndUpdate: patch.object(DatasetService, "check_embedding_model_setting") as check_embedding, patch.object(DatasetService, "check_reranking_model_setting") as check_reranking, ): - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model dataset = DatasetService.create_empty_dataset( @@ -412,7 +410,7 @@ class TestDatasetServiceCreationAndUpdate: ) with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + mock_db.session.scalar.return_value = object() with pytest.raises(DatasetNameDuplicateError, match="Existing Dataset already exists"): DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) @@ -435,12 +433,13 @@ class TestDatasetServiceCreationAndUpdate: with ( patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.select"), patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), patch("services.dataset_service.generate_incremental_name", return_value="Untitled 2") as generate_name, patch("services.dataset_service.Pipeline", side_effect=pipeline_factory), patch("services.dataset_service.Dataset", side_effect=dataset_factory), ): - mock_db.session.query.return_value.filter_by.return_value.all.return_value = [ + mock_db.session.scalars.return_value.all.return_value = [ SimpleNamespace(name="Untitled"), SimpleNamespace(name="Untitled 1"), ] @@ -465,7 +464,7 @@ class TestDatasetServiceCreationAndUpdate: patch("services.dataset_service.db") as mock_db, patch("services.dataset_service.current_user", SimpleNamespace(id=None)), ): - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with pytest.raises(ValueError, match="Current user or current user id not found"): DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) @@ -520,7 +519,7 @@ class TestDatasetServiceCreationAndUpdate: def test_has_dataset_same_name_returns_true_when_query_matches(self): with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = object() + mock_db.session.scalar.return_value = object() result = DatasetService._has_dataset_same_name("tenant-1", "dataset-1", "Dataset") @@ -630,7 +629,7 @@ class TestDatasetServiceCreationAndUpdate: result = DatasetService._update_internal_dataset(dataset, update_payload.copy(), user) assert result is dataset - updated_values = mock_db.session.query.return_value.filter_by.return_value.update.call_args.args[0] + updated_values = mock_db.session.execute.call_args.args[0].compile().params assert updated_values["name"] == "Updated Dataset" assert updated_values["description"] is None assert updated_values["retrieval_model"] == {"top_k": 4} @@ -658,13 +657,13 @@ class TestDatasetServiceCreationAndUpdate: with patch("services.dataset_service.db") as mock_db: DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") - mock_db.session.query.assert_not_called() + mock_db.session.get.assert_not_called() def test_update_pipeline_knowledge_base_node_data_returns_when_pipeline_is_missing(self): dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1") with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.get.return_value = None DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") @@ -703,7 +702,7 @@ class TestDatasetServiceCreationAndUpdate: patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service), patch("services.dataset_service.Workflow.new", return_value=new_workflow) as workflow_new, ): - mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline + mock_db.session.get.return_value = pipeline DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") @@ -725,7 +724,7 @@ class TestDatasetServiceCreationAndUpdate: patch("services.dataset_service.db") as mock_db, patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service), ): - mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline + mock_db.session.get.return_value = pipeline with pytest.raises(RuntimeError, match="boom"): DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") @@ -1364,7 +1363,7 @@ class TestDatasetServicePermissionsAndLifecycle: ) with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with pytest.raises(NoPermissionError, match="do not have permission"): DatasetService.check_dataset_permission(dataset, user) @@ -1382,7 +1381,7 @@ class TestDatasetServicePermissionsAndLifecycle: with patch("services.dataset_service.db") as mock_db: DatasetService.check_dataset_permission(dataset, user) - mock_db.session.query.assert_not_called() + mock_db.session.scalar.assert_not_called() def test_check_dataset_permission_allows_partial_team_member_with_binding(self): dataset = DatasetServiceUnitDataFactory.create_dataset_mock( @@ -1395,7 +1394,7 @@ class TestDatasetServicePermissionsAndLifecycle: ) with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + mock_db.session.scalar.return_value = object() DatasetService.check_dataset_permission(dataset, user) @@ -1427,7 +1426,7 @@ class TestDatasetServicePermissionsAndLifecycle: ) with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.filter_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] with pytest.raises(NoPermissionError, match="do not have permission"): DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) @@ -1446,9 +1445,7 @@ class TestDatasetServicePermissionsAndLifecycle: def test_get_related_apps_returns_ordered_query_results(self): with patch("services.dataset_service.db") as mock_db: mock_db.desc.side_effect = lambda column: column - mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [ - "relation-1" - ] + mock_db.session.scalars.return_value.all.return_value = ["relation-1"] result = DatasetService.get_related_apps("dataset-1")