mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 06:06:13 +08:00
refactor: select in service API dataset document and segment controllers (#34101)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -6,7 +6,7 @@ from uuid import UUID
|
||||
from flask import request, send_file
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy import desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@@ -155,7 +155,9 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
@@ -238,7 +240,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text."""
|
||||
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
|
||||
)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
@@ -315,7 +319,9 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by upload file."""
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
@@ -425,7 +431,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by upload file."""
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
@@ -515,7 +523,9 @@ class DocumentListApi(DatasetApiResource):
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
query_params = DocumentListQuery.model_validate(request.args.to_dict())
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@@ -609,7 +619,9 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
batch = str(batch)
|
||||
tenant_id = str(tenant_id)
|
||||
# get dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# get documents
|
||||
@@ -619,20 +631,23 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
@@ -822,7 +837,9 @@ class DocumentApi(DatasetApiResource):
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# get dataset info
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any
|
||||
from flask import request
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
@@ -92,7 +93,9 @@ class SegmentApi(DatasetApiResource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Create single segment."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
@@ -150,7 +153,9 @@ class SegmentApi(DatasetApiResource):
|
||||
# check dataset
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
@@ -220,7 +225,9 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
@@ -254,7 +261,9 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
@@ -301,7 +310,9 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
@@ -344,7 +355,9 @@ class ChildChunkApi(DatasetApiResource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Create child chunk."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@@ -402,7 +415,9 @@ class ChildChunkApi(DatasetApiResource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Get child chunks."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@@ -468,7 +483,9 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Delete child chunk."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@@ -527,7 +544,9 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Update child chunk."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
|
||||
@@ -788,7 +788,7 @@ class TestSegmentApiGet:
|
||||
"""Test successful segment list retrieval."""
|
||||
# Arrange
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
|
||||
mock_marshal.return_value = [{"id": mock_segment.id}]
|
||||
@@ -813,7 +813,7 @@ class TestSegmentApiGet:
|
||||
"""Test 404 when dataset not found."""
|
||||
# Arrange
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
@@ -833,7 +833,7 @@ class TestSegmentApiGet:
|
||||
"""Test 404 when document not found."""
|
||||
# Arrange
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
@@ -899,7 +899,7 @@ class TestSegmentApiPost:
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc.indexing_status = "completed"
|
||||
@@ -950,7 +950,7 @@ class TestSegmentApiPost:
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc.indexing_status = "completed"
|
||||
@@ -992,7 +992,7 @@ class TestSegmentApiPost:
|
||||
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc.indexing_status = "indexing" # Not completed
|
||||
@@ -1043,7 +1043,7 @@ class TestDatasetSegmentApiDelete:
|
||||
"""Test successful segment deletion."""
|
||||
# Arrange
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
|
||||
mock_doc = Mock()
|
||||
@@ -1087,7 +1087,7 @@ class TestDatasetSegmentApiDelete:
|
||||
"""Test 404 when segment not found."""
|
||||
# Arrange
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc.indexing_status = "completed"
|
||||
@@ -1129,7 +1129,7 @@ class TestDatasetSegmentApiDelete:
|
||||
"""Test 404 when dataset not found for delete."""
|
||||
# Arrange
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
@@ -1163,7 +1163,7 @@ class TestDatasetSegmentApiDelete:
|
||||
"""Test 404 when document not found for delete."""
|
||||
# Arrange
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc_svc.get_document.return_value = None
|
||||
|
||||
@@ -1233,7 +1233,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
@@ -1280,7 +1280,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
"""Test 404 when dataset not found for update."""
|
||||
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id",
|
||||
@@ -1321,7 +1321,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = None
|
||||
@@ -1370,7 +1370,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
):
|
||||
"""Test successful single segment retrieval."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
@@ -1405,7 +1405,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
):
|
||||
"""Test 404 when dataset not found."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id",
|
||||
@@ -1436,7 +1436,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
):
|
||||
"""Test 404 when document not found."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc_svc.get_document.return_value = None
|
||||
|
||||
@@ -1471,7 +1471,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
):
|
||||
"""Test 404 when segment not found."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = None
|
||||
@@ -1515,7 +1515,7 @@ class TestChildChunkApiGet:
|
||||
):
|
||||
"""Test successful child chunk list retrieval."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = Mock()
|
||||
|
||||
@@ -1554,7 +1554,7 @@ class TestChildChunkApiGet:
|
||||
):
|
||||
"""Test 404 when dataset not found."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks",
|
||||
@@ -1583,7 +1583,7 @@ class TestChildChunkApiGet:
|
||||
):
|
||||
"""Test 404 when document not found."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
@@ -1615,7 +1615,7 @@ class TestChildChunkApiGet:
|
||||
):
|
||||
"""Test 404 when segment not found."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = None
|
||||
|
||||
@@ -1676,7 +1676,7 @@ class TestChildChunkApiPost:
|
||||
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = Mock()
|
||||
mock_child = Mock()
|
||||
@@ -1717,7 +1717,7 @@ class TestChildChunkApiPost:
|
||||
"""Test 404 when dataset not found."""
|
||||
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks",
|
||||
@@ -1755,7 +1755,7 @@ class TestChildChunkApiPost:
|
||||
"""Test 404 when segment not found."""
|
||||
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = None
|
||||
|
||||
@@ -1808,7 +1808,7 @@ class TestDatasetChildChunkApiDelete:
|
||||
):
|
||||
"""Test successful child chunk deletion."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
@@ -1858,7 +1858,7 @@ class TestDatasetChildChunkApiDelete:
|
||||
):
|
||||
"""Test 404 when child chunk not found."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
|
||||
segment_id = str(uuid.uuid4())
|
||||
@@ -1899,7 +1899,7 @@ class TestDatasetChildChunkApiDelete:
|
||||
):
|
||||
"""Test 404 when segment does not belong to the document."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
|
||||
segment_id = str(uuid.uuid4())
|
||||
@@ -1939,7 +1939,7 @@ class TestDatasetChildChunkApiDelete:
|
||||
):
|
||||
"""Test 404 when child chunk does not belong to the segment."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
|
||||
segment_id = str(uuid.uuid4())
|
||||
|
||||
@@ -717,7 +717,7 @@ class TestDocumentApiDelete:
|
||||
dataset_id = str(uuid.uuid4())
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc_svc.get_document.return_value = mock_document
|
||||
mock_doc_svc.check_archived.return_value = False
|
||||
@@ -746,7 +746,7 @@ class TestDocumentApiDelete:
|
||||
document_id = str(uuid.uuid4())
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc_svc.get_document.return_value = None
|
||||
|
||||
@@ -767,7 +767,7 @@ class TestDocumentApiDelete:
|
||||
dataset_id = str(uuid.uuid4())
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc_svc.get_document.return_value = mock_document
|
||||
mock_doc_svc.check_archived.return_value = True
|
||||
@@ -788,7 +788,7 @@ class TestDocumentApiDelete:
|
||||
# Arrange
|
||||
dataset_id = str(uuid.uuid4())
|
||||
document_id = str(uuid.uuid4())
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
@@ -809,7 +809,7 @@ class TestDocumentListApi:
|
||||
def test_list_documents_success(self, mock_db, mock_doc_svc, mock_marshal, app, mock_tenant, mock_dataset):
|
||||
"""Test successful document list retrieval."""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_pagination = Mock()
|
||||
mock_pagination.items = [Mock(), Mock()]
|
||||
@@ -838,7 +838,7 @@ class TestDocumentListApi:
|
||||
def test_list_documents_dataset_not_found(self, mock_db, app, mock_tenant, mock_dataset):
|
||||
"""Test 404 when dataset not found."""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
@@ -860,8 +860,6 @@ class TestDocumentIndexingStatusApi:
|
||||
"""Test successful indexing status retrieval."""
|
||||
# Arrange
|
||||
batch_id = "batch_123"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc.id = str(uuid.uuid4())
|
||||
mock_doc.is_paused = False
|
||||
@@ -877,8 +875,8 @@ class TestDocumentIndexingStatusApi:
|
||||
|
||||
mock_doc_svc.get_batch_documents.return_value = [mock_doc]
|
||||
|
||||
# Mock segment count queries
|
||||
mock_db.session.query.return_value.where.return_value.where.return_value.count.return_value = 5
|
||||
# scalar() called 3 times: dataset lookup, completed_segments count, total_segments count
|
||||
mock_db.session.scalar.side_effect = [mock_dataset, 5, 5]
|
||||
mock_marshal.return_value = {"id": mock_doc.id, "indexing_status": "completed"}
|
||||
|
||||
# Act
|
||||
@@ -898,7 +896,7 @@ class TestDocumentIndexingStatusApi:
|
||||
"""Test 404 when dataset not found."""
|
||||
# Arrange
|
||||
batch_id = "batch_123"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
@@ -915,7 +913,7 @@ class TestDocumentIndexingStatusApi:
|
||||
"""Test 404 when no documents found for batch."""
|
||||
# Arrange
|
||||
batch_id = "batch_empty"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_batch_documents.return_value = []
|
||||
|
||||
# Act & Assert
|
||||
@@ -986,7 +984,7 @@ class TestDocumentAddByTextApi:
|
||||
# Arrange — neutralise billing decorators
|
||||
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_current_user.id = str(uuid.uuid4())
|
||||
|
||||
@@ -1035,7 +1033,7 @@ class TestDocumentAddByTextApi:
|
||||
# Arrange — neutralise billing decorators
|
||||
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
@@ -1064,7 +1062,7 @@ class TestDocumentAddByTextApi:
|
||||
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
|
||||
mock_dataset.indexing_technique = None
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
@@ -1150,7 +1148,7 @@ class TestDocumentUpdateByTextApiPost:
|
||||
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_dataset.latest_process_rule = Mock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_current_user.id = "user-1"
|
||||
mock_upload = Mock()
|
||||
@@ -1193,7 +1191,7 @@ class TestDocumentUpdateByTextApiPost:
|
||||
):
|
||||
"""Test ValueError when dataset not found."""
|
||||
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
with app.test_request_context(
|
||||
@@ -1232,7 +1230,7 @@ class TestDocumentAddByFileApiPost:
|
||||
):
|
||||
"""Test ValueError when dataset not found."""
|
||||
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
@@ -1263,7 +1261,7 @@ class TestDocumentAddByFileApiPost:
|
||||
"""Test ValueError when dataset is external."""
|
||||
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_dataset.provider = "external"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
@@ -1298,7 +1296,7 @@ class TestDocumentAddByFileApiPost:
|
||||
mock_dataset.provider = "vendor"
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_dataset.chunk_structure = None
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/document/create_by_file",
|
||||
@@ -1328,7 +1326,7 @@ class TestDocumentAddByFileApiPost:
|
||||
mock_dataset.provider = "vendor"
|
||||
mock_dataset.indexing_technique = None
|
||||
mock_dataset.chunk_structure = None
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
@@ -1366,7 +1364,7 @@ class TestDocumentUpdateByFileApiPost:
|
||||
):
|
||||
"""Test ValueError when dataset not found."""
|
||||
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
@@ -1402,7 +1400,7 @@ class TestDocumentUpdateByFileApiPost:
|
||||
"""Test ValueError when dataset is external."""
|
||||
_setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_dataset.provider = "external"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
@@ -1450,7 +1448,7 @@ class TestDocumentUpdateByFileApiPost:
|
||||
mock_dataset.chunk_structure = None
|
||||
mock_dataset.latest_process_rule = Mock()
|
||||
mock_dataset.created_by_account = Mock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_current_user.id = "user-1"
|
||||
mock_upload = Mock()
|
||||
|
||||
@@ -10,13 +10,13 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models import Account, TenantAccountRole
|
||||
from models.dataset import (
|
||||
|
||||
Reference in New Issue
Block a user