diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py similarity index 83% rename from api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py rename to api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py index fd38fcbb5e1..64e3de2ca39 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py @@ -1,3 +1,7 @@ +"""Testcontainers integration tests for rag_pipeline_datasets controller endpoints.""" + +from __future__ import annotations + from unittest.mock import MagicMock, patch import pytest @@ -19,6 +23,10 @@ def unwrap(func): class TestCreateRagPipelineDatasetApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def _valid_payload(self): return {"yaml_content": "name: test"} @@ -33,13 +41,6 @@ class TestCreateRagPipelineDatasetApi: mock_service = MagicMock() mock_service.create_rag_pipeline_dataset.return_value = import_info - mock_session_ctx = MagicMock() - mock_session_ctx.__enter__.return_value = MagicMock() - mock_session_ctx.__exit__.return_value = None - - fake_db = MagicMock() - fake_db.engine = MagicMock() - with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), @@ -47,14 +48,6 @@ class TestCreateRagPipelineDatasetApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", return_value=(user, "tenant-1"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session", - return_value=mock_session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", return_value=mock_service, @@ -93,13 +86,6 @@ class TestCreateRagPipelineDatasetApi: mock_service = MagicMock() mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError() - mock_session_ctx = MagicMock() - mock_session_ctx.__enter__.return_value = MagicMock() - mock_session_ctx.__exit__.return_value = None - - fake_db = MagicMock() - fake_db.engine = MagicMock() - with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), @@ -107,14 +93,6 @@ class TestCreateRagPipelineDatasetApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", return_value=(user, "tenant-1"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session", - return_value=mock_session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", return_value=mock_service, @@ -143,6 +121,10 @@ class TestCreateRagPipelineDatasetApi: class TestCreateEmptyRagPipelineDatasetApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_post_success(self, app): api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post)