mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 06:09:24 +08:00
test: migrate rag pipeline controller tests to testcontainers (#34303)
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
"""Testcontainers integration tests for rag_pipeline controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
@@ -9,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
PipelineTemplateListApi,
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
)
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@@ -18,6 +25,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -38,6 +49,10 @@ class TestPipelineTemplateListApi:
|
||||
|
||||
|
||||
class TestPipelineTemplateDetailApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -99,6 +114,10 @@ class TestPipelineTemplateDetailApi:
|
||||
|
||||
|
||||
class TestCustomizedPipelineTemplateApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_patch_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
@@ -136,35 +155,29 @@ class TestCustomizedPipelineTemplateApi:
|
||||
delete_mock.assert_called_once_with("tpl-1")
|
||||
assert response == 200
|
||||
|
||||
def test_post_success(self, app):
|
||||
def test_post_success(self, app, db_session_with_containers: Session):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
template = MagicMock()
|
||||
template.yaml_content = "yaml-data"
|
||||
tenant_id = str(uuid4())
|
||||
template = PipelineCustomizedTemplate(
|
||||
tenant_id=tenant_id,
|
||||
name="Test Template",
|
||||
description="Test",
|
||||
chunk_structure="hierarchical",
|
||||
icon={"icon": "📘"},
|
||||
position=0,
|
||||
yaml_content="yaml-data",
|
||||
install_count=0,
|
||||
language="en-US",
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(template)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = template
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
with app.test_request_context("/"):
|
||||
response, status = method(api, template.id)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": "yaml-data"}
|
||||
@@ -173,32 +186,16 @@ class TestCustomizedPipelineTemplateApi:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "tpl-1")
|
||||
method(api, str(uuid4()))
|
||||
|
||||
|
||||
class TestPublishCustomizedPipelineTemplateApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
Reference in New Issue
Block a user