diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py similarity index 66% rename from api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py rename to api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py index a72ad45110c..cb678928782 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -1,5 +1,11 @@ +"""Testcontainers integration tests for rag_pipeline_import controller endpoints.""" + +from __future__ import annotations + from unittest.mock import MagicMock, patch +import pytest + from controllers.console import console_ns from controllers.console.datasets.rag_pipeline.rag_pipeline_import import ( RagPipelineExportApi, @@ -18,6 +24,10 @@ def unwrap(func): class TestRagPipelineImportApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def _payload(self, mode="create"): return { "mode": mode, @@ -30,7 +40,6 @@ class TestRagPipelineImportApi: method = unwrap(api.post) payload = self._payload() - user = MagicMock() result = MagicMock() result.status = "completed" @@ -39,13 +48,6 @@ class TestRagPipelineImportApi: service = MagicMock() service.import_rag_pipeline.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), @@ -53,14 +55,6 @@ class TestRagPipelineImportApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", return_value=(user, "tenant"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -76,7 +70,6 @@ class TestRagPipelineImportApi: method = unwrap(api.post) payload = self._payload() - user = MagicMock() result = MagicMock() result.status = ImportStatus.FAILED @@ -85,13 +78,6 @@ class TestRagPipelineImportApi: service = MagicMock() service.import_rag_pipeline.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), @@ -99,14 +85,6 @@ class TestRagPipelineImportApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", return_value=(user, "tenant"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -122,7 +100,6 @@ class TestRagPipelineImportApi: method = unwrap(api.post) payload = self._payload() - user = MagicMock() result = MagicMock() result.status = ImportStatus.PENDING @@ -131,13 +108,6 @@ class TestRagPipelineImportApi: service = MagicMock() service.import_rag_pipeline.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), @@ -145,14 +115,6 @@ class TestRagPipelineImportApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", return_value=(user, "tenant"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -165,6 +127,10 @@ class TestRagPipelineImportApi: class TestRagPipelineImportConfirmApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_confirm_success(self, app): api = RagPipelineImportConfirmApi() method = unwrap(api.post) @@ -177,27 +143,12 @@ class TestRagPipelineImportConfirmApi: service = MagicMock() service.confirm_import.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/"), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", return_value=(user, "tenant"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -220,27 +171,12 @@ class TestRagPipelineImportConfirmApi: service = MagicMock() service.confirm_import.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/"), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", return_value=(user, "tenant"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -253,6 +189,10 @@ class TestRagPipelineImportConfirmApi: class TestRagPipelineImportCheckDependenciesApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_success(self, app): api = RagPipelineImportCheckDependenciesApi() method = unwrap(api.get) @@ -264,23 +204,8 @@ class TestRagPipelineImportCheckDependenciesApi: service = MagicMock() service.check_dependencies.return_value = result - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, @@ -293,6 +218,10 @@ class TestRagPipelineImportCheckDependenciesApi: class TestRagPipelineExportApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_with_include_secret(self, app): api = RagPipelineExportApi() method = unwrap(api.get) @@ -301,23 +230,8 @@ class TestRagPipelineExportApi: service = MagicMock() service.export_rag_pipeline_dsl.return_value = {"yaml": "data"} - fake_db = MagicMock() - fake_db.engine = MagicMock() - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = MagicMock() - session_ctx.__exit__.return_value = None - with ( app.test_request_context("/?include_secret=true"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service,