mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 02:19:20 +08:00
refactor: select in dataset_service (DocumentService class) (#34528)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1400,8 +1400,8 @@ class DocumentService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_document(dataset_id: str, document_id: str | None = None) -> Document | None:
|
def get_document(dataset_id: str, document_id: str | None = None) -> Document | None:
|
||||||
if document_id:
|
if document_id:
|
||||||
document = (
|
document = db.session.scalar(
|
||||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||||
)
|
)
|
||||||
return document
|
return document
|
||||||
else:
|
else:
|
||||||
@@ -1630,7 +1630,7 @@ class DocumentService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_document_by_id(document_id: str) -> Document | None:
|
def get_document_by_id(document_id: str) -> Document | None:
|
||||||
document = db.session.query(Document).where(Document.id == document_id).first()
|
document = db.session.get(Document, document_id)
|
||||||
|
|
||||||
return document
|
return document
|
||||||
|
|
||||||
@@ -1695,7 +1695,7 @@ class DocumentService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_document_file_detail(file_id: str):
|
def get_document_file_detail(file_id: str):
|
||||||
file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
|
file_detail = db.session.get(UploadFile, file_id)
|
||||||
return file_detail
|
return file_detail
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -1769,9 +1769,11 @@ class DocumentService:
|
|||||||
document.name = name
|
document.name = name
|
||||||
db.session.add(document)
|
db.session.add(document)
|
||||||
if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict:
|
if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict:
|
||||||
db.session.query(UploadFile).where(
|
db.session.execute(
|
||||||
UploadFile.id == document.data_source_info_dict["upload_file_id"]
|
update(UploadFile)
|
||||||
).update({UploadFile.name: name})
|
.where(UploadFile.id == document.data_source_info_dict["upload_file_id"])
|
||||||
|
.values(name=name)
|
||||||
|
)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@@ -1858,8 +1860,8 @@ class DocumentService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_documents_position(dataset_id):
|
def get_documents_position(dataset_id):
|
||||||
document = (
|
document = db.session.scalar(
|
||||||
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
|
select(Document).where(Document.dataset_id == dataset_id).order_by(Document.position.desc()).limit(1)
|
||||||
)
|
)
|
||||||
if document:
|
if document:
|
||||||
return document.position + 1
|
return document.position + 1
|
||||||
@@ -2016,28 +2018,28 @@ class DocumentService:
|
|||||||
if not knowledge_config.data_source.info_list.file_info_list:
|
if not knowledge_config.data_source.info_list.file_info_list:
|
||||||
raise ValueError("File source info is required")
|
raise ValueError("File source info is required")
|
||||||
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
|
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
|
||||||
files = (
|
files = list(
|
||||||
db.session.query(UploadFile)
|
db.session.scalars(
|
||||||
.where(
|
select(UploadFile).where(
|
||||||
UploadFile.tenant_id == dataset.tenant_id,
|
UploadFile.tenant_id == dataset.tenant_id,
|
||||||
UploadFile.id.in_(upload_file_list),
|
UploadFile.id.in_(upload_file_list),
|
||||||
)
|
)
|
||||||
.all()
|
).all()
|
||||||
)
|
)
|
||||||
if len(files) != len(set(upload_file_list)):
|
if len(files) != len(set(upload_file_list)):
|
||||||
raise FileNotExistsError("One or more files not found.")
|
raise FileNotExistsError("One or more files not found.")
|
||||||
|
|
||||||
file_names = [file.name for file in files]
|
file_names = [file.name for file in files]
|
||||||
db_documents = (
|
db_documents = list(
|
||||||
db.session.query(Document)
|
db.session.scalars(
|
||||||
.where(
|
select(Document).where(
|
||||||
Document.dataset_id == dataset.id,
|
Document.dataset_id == dataset.id,
|
||||||
Document.tenant_id == current_user.current_tenant_id,
|
Document.tenant_id == current_user.current_tenant_id,
|
||||||
Document.data_source_type == DataSourceType.UPLOAD_FILE,
|
Document.data_source_type == DataSourceType.UPLOAD_FILE,
|
||||||
Document.enabled == True,
|
Document.enabled == True,
|
||||||
Document.name.in_(file_names),
|
Document.name.in_(file_names),
|
||||||
)
|
)
|
||||||
.all()
|
).all()
|
||||||
)
|
)
|
||||||
documents_map = {document.name: document for document in db_documents}
|
documents_map = {document.name: document for document in db_documents}
|
||||||
for file in files:
|
for file in files:
|
||||||
@@ -2083,15 +2085,15 @@ class DocumentService:
|
|||||||
raise ValueError("No notion info list found.")
|
raise ValueError("No notion info list found.")
|
||||||
exist_page_ids = []
|
exist_page_ids = []
|
||||||
exist_document = {}
|
exist_document = {}
|
||||||
documents = (
|
documents = list(
|
||||||
db.session.query(Document)
|
db.session.scalars(
|
||||||
.filter_by(
|
select(Document).where(
|
||||||
dataset_id=dataset.id,
|
Document.dataset_id == dataset.id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
Document.tenant_id == current_user.current_tenant_id,
|
||||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
Document.data_source_type == DataSourceType.NOTION_IMPORT,
|
||||||
enabled=True,
|
Document.enabled == True,
|
||||||
)
|
)
|
||||||
.all()
|
).all()
|
||||||
)
|
)
|
||||||
if documents:
|
if documents:
|
||||||
for document in documents:
|
for document in documents:
|
||||||
@@ -2522,14 +2524,15 @@ class DocumentService:
|
|||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
|
|
||||||
documents_count = (
|
documents_count = (
|
||||||
db.session.query(Document)
|
db.session.scalar(
|
||||||
.where(
|
select(func.count(Document.id)).where(
|
||||||
Document.completed_at.isnot(None),
|
Document.completed_at.isnot(None),
|
||||||
Document.enabled == True,
|
Document.enabled == True,
|
||||||
Document.archived == False,
|
Document.archived == False,
|
||||||
Document.tenant_id == current_user.current_tenant_id,
|
Document.tenant_id == current_user.current_tenant_id,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
.count()
|
or 0
|
||||||
)
|
)
|
||||||
return documents_count
|
return documents_count
|
||||||
|
|
||||||
@@ -2579,10 +2582,10 @@ class DocumentService:
|
|||||||
raise ValueError("No file info list found.")
|
raise ValueError("No file info list found.")
|
||||||
upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
|
upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
|
||||||
for file_id in upload_file_list:
|
for file_id in upload_file_list:
|
||||||
file = (
|
file = db.session.scalar(
|
||||||
db.session.query(UploadFile)
|
select(UploadFile)
|
||||||
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
|
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# raise error if file not found
|
# raise error if file not found
|
||||||
@@ -2599,8 +2602,8 @@ class DocumentService:
|
|||||||
notion_info_list = document_data.data_source.info_list.notion_info_list
|
notion_info_list = document_data.data_source.info_list.notion_info_list
|
||||||
for notion_info in notion_info_list:
|
for notion_info in notion_info_list:
|
||||||
workspace_id = notion_info.workspace_id
|
workspace_id = notion_info.workspace_id
|
||||||
data_source_binding = (
|
data_source_binding = db.session.scalar(
|
||||||
db.session.query(DataSourceOauthBinding)
|
select(DataSourceOauthBinding)
|
||||||
.where(
|
.where(
|
||||||
sa.and_(
|
sa.and_(
|
||||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||||
@@ -2609,7 +2612,7 @@ class DocumentService:
|
|||||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
if not data_source_binding:
|
if not data_source_binding:
|
||||||
raise ValueError("Data source binding not found.")
|
raise ValueError("Data source binding not found.")
|
||||||
@@ -2654,8 +2657,10 @@ class DocumentService:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
# update document segment
|
# update document segment
|
||||||
|
|
||||||
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
|
db.session.execute(
|
||||||
{DocumentSegment.status: SegmentStatus.RE_SEGMENT}
|
update(DocumentSegment)
|
||||||
|
.where(DocumentSegment.document_id == document.id)
|
||||||
|
.values(status=SegmentStatus.RE_SEGMENT)
|
||||||
)
|
)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
# trigger async task
|
# trigger async task
|
||||||
|
|||||||
@@ -90,13 +90,13 @@ class TestDocumentServiceQueryAndDownloadHelpers:
|
|||||||
result = DocumentService.get_document("dataset-1", None)
|
result = DocumentService.get_document("dataset-1", None)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
mock_db.session.query.assert_not_called()
|
mock_db.session.scalar.assert_not_called()
|
||||||
|
|
||||||
def test_get_document_queries_by_dataset_and_document_id(self):
|
def test_get_document_queries_by_dataset_and_document_id(self):
|
||||||
document = DatasetServiceUnitDataFactory.create_document_mock()
|
document = DatasetServiceUnitDataFactory.create_document_mock()
|
||||||
|
|
||||||
with patch("services.dataset_service.db") as mock_db:
|
with patch("services.dataset_service.db") as mock_db:
|
||||||
mock_db.session.query.return_value.where.return_value.first.return_value = document
|
mock_db.session.scalar.return_value = document
|
||||||
|
|
||||||
result = DocumentService.get_document("dataset-1", "doc-1")
|
result = DocumentService.get_document("dataset-1", "doc-1")
|
||||||
|
|
||||||
@@ -435,7 +435,7 @@ class TestDocumentServiceQueryAndDownloadHelpers:
|
|||||||
upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock()
|
upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock()
|
||||||
|
|
||||||
with patch("services.dataset_service.db") as mock_db:
|
with patch("services.dataset_service.db") as mock_db:
|
||||||
mock_db.session.query.return_value.where.return_value.one_or_none.return_value = upload_file
|
mock_db.session.get.return_value = upload_file
|
||||||
|
|
||||||
result = DocumentService.get_document_file_detail(upload_file.id)
|
result = DocumentService.get_document_file_detail(upload_file.id)
|
||||||
|
|
||||||
@@ -570,7 +570,7 @@ class TestDocumentServiceMutations:
|
|||||||
assert document.name == "New Name"
|
assert document.name == "New Name"
|
||||||
assert document.doc_metadata[BuiltInField.document_name] == "New Name"
|
assert document.doc_metadata[BuiltInField.document_name] == "New Name"
|
||||||
mock_db.session.add.assert_called_once_with(document)
|
mock_db.session.add.assert_called_once_with(document)
|
||||||
mock_db.session.query.return_value.where.return_value.update.assert_called_once()
|
mock_db.session.execute.assert_called()
|
||||||
mock_db.session.commit.assert_called_once()
|
mock_db.session.commit.assert_called_once()
|
||||||
|
|
||||||
def test_recover_document_raises_when_document_is_not_paused(self):
|
def test_recover_document_raises_when_document_is_not_paused(self):
|
||||||
@@ -624,9 +624,7 @@ class TestDocumentServiceMutations:
|
|||||||
document = DatasetServiceUnitDataFactory.create_document_mock(position=7)
|
document = DatasetServiceUnitDataFactory.create_document_mock(position=7)
|
||||||
|
|
||||||
with patch("services.dataset_service.db") as mock_db:
|
with patch("services.dataset_service.db") as mock_db:
|
||||||
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = (
|
mock_db.session.scalar.return_value = document
|
||||||
document
|
|
||||||
)
|
|
||||||
|
|
||||||
result = DocumentService.get_documents_position("dataset-1")
|
result = DocumentService.get_documents_position("dataset-1")
|
||||||
|
|
||||||
@@ -634,7 +632,7 @@ class TestDocumentServiceMutations:
|
|||||||
|
|
||||||
def test_get_documents_position_defaults_to_one_when_dataset_is_empty(self):
|
def test_get_documents_position_defaults_to_one_when_dataset_is_empty(self):
|
||||||
with patch("services.dataset_service.db") as mock_db:
|
with patch("services.dataset_service.db") as mock_db:
|
||||||
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = None
|
mock_db.session.scalar.return_value = None
|
||||||
|
|
||||||
result = DocumentService.get_documents_position("dataset-1")
|
result = DocumentService.get_documents_position("dataset-1")
|
||||||
|
|
||||||
@@ -869,11 +867,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId:
|
|||||||
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
||||||
patch("services.dataset_service.document_indexing_update_task") as update_task,
|
patch("services.dataset_service.document_indexing_update_task") as update_task,
|
||||||
):
|
):
|
||||||
upload_query = MagicMock()
|
mock_db.session.scalar.return_value = SimpleNamespace(id="file-1", name="upload.txt")
|
||||||
upload_query.where.return_value.first.return_value = SimpleNamespace(id="file-1", name="upload.txt")
|
|
||||||
segment_query = MagicMock()
|
|
||||||
segment_query.filter_by.return_value.update.return_value = 3
|
|
||||||
mock_db.session.query.side_effect = [upload_query, segment_query]
|
|
||||||
|
|
||||||
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
||||||
|
|
||||||
@@ -892,7 +886,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId:
|
|||||||
assert document.created_from == "web"
|
assert document.created_from == "web"
|
||||||
assert document.doc_form == IndexStructureType.QA_INDEX
|
assert document.doc_form == IndexStructureType.QA_INDEX
|
||||||
assert mock_db.session.commit.call_count == 3
|
assert mock_db.session.commit.call_count == 3
|
||||||
segment_query.filter_by.return_value.update.assert_called_once()
|
mock_db.session.execute.assert_called()
|
||||||
update_task.delay.assert_called_once_with(document.dataset_id, document.id)
|
update_task.delay.assert_called_once_with(document.dataset_id, document.id)
|
||||||
|
|
||||||
def test_update_document_with_dataset_id_notion_import_requires_binding(self, account_context):
|
def test_update_document_with_dataset_id_notion_import_requires_binding(self, account_context):
|
||||||
@@ -920,9 +914,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId:
|
|||||||
patch.object(DatasetService, "check_dataset_model_setting"),
|
patch.object(DatasetService, "check_dataset_model_setting"),
|
||||||
patch("services.dataset_service.db") as mock_db,
|
patch("services.dataset_service.db") as mock_db,
|
||||||
):
|
):
|
||||||
binding_query = MagicMock()
|
mock_db.session.scalar.return_value = None
|
||||||
binding_query.where.return_value.first.return_value = None
|
|
||||||
mock_db.session.query.return_value = binding_query
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Data source binding not found"):
|
with pytest.raises(ValueError, match="Data source binding not found"):
|
||||||
DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
||||||
@@ -954,10 +946,6 @@ class TestDocumentServiceUpdateDocumentWithDatasetId:
|
|||||||
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
||||||
patch("services.dataset_service.document_indexing_update_task") as update_task,
|
patch("services.dataset_service.document_indexing_update_task") as update_task,
|
||||||
):
|
):
|
||||||
segment_query = MagicMock()
|
|
||||||
segment_query.filter_by.return_value.update.return_value = 2
|
|
||||||
mock_db.session.query.return_value = segment_query
|
|
||||||
|
|
||||||
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
||||||
|
|
||||||
assert result is document
|
assert result is document
|
||||||
@@ -968,7 +956,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId:
|
|||||||
)
|
)
|
||||||
assert document.name == ""
|
assert document.name == ""
|
||||||
assert document.doc_form == IndexStructureType.PARENT_CHILD_INDEX
|
assert document.doc_form == IndexStructureType.PARENT_CHILD_INDEX
|
||||||
segment_query.filter_by.return_value.update.assert_called_once()
|
mock_db.session.execute.assert_called()
|
||||||
update_task.delay.assert_called_once_with("dataset-1", "doc-1")
|
update_task.delay.assert_called_once_with("dataset-1", "doc-1")
|
||||||
|
|
||||||
|
|
||||||
@@ -1218,11 +1206,10 @@ class TestDocumentServiceSaveDocumentWithDatasetId:
|
|||||||
patch("services.dataset_service.secrets.randbelow", return_value=23),
|
patch("services.dataset_service.secrets.randbelow", return_value=23),
|
||||||
):
|
):
|
||||||
mock_redis.lock.return_value = _make_lock_context()
|
mock_redis.lock.return_value = _make_lock_context()
|
||||||
upload_query = MagicMock()
|
mock_db.session.scalars.return_value.all.side_effect = [
|
||||||
upload_query.where.return_value.all.return_value = [upload_file_a, upload_file_b]
|
[upload_file_a, upload_file_b],
|
||||||
existing_documents_query = MagicMock()
|
[duplicate_document],
|
||||||
existing_documents_query.where.return_value.all.return_value = [duplicate_document]
|
]
|
||||||
mock_db.session.query.side_effect = [upload_query, existing_documents_query]
|
|
||||||
|
|
||||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||||
dataset,
|
dataset,
|
||||||
@@ -1302,9 +1289,7 @@ class TestDocumentServiceSaveDocumentWithDatasetId:
|
|||||||
patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls,
|
patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls,
|
||||||
):
|
):
|
||||||
mock_redis.lock.return_value = _make_lock_context()
|
mock_redis.lock.return_value = _make_lock_context()
|
||||||
notion_documents_query = MagicMock()
|
mock_db.session.scalars.return_value.all.return_value = [existing_keep, existing_remove]
|
||||||
notion_documents_query.filter_by.return_value.all.return_value = [existing_keep, existing_remove]
|
|
||||||
mock_db.session.query.return_value = notion_documents_query
|
|
||||||
|
|
||||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||||
dataset,
|
dataset,
|
||||||
@@ -1474,12 +1459,11 @@ class TestDocumentServiceTenantAndUpdateEdges:
|
|||||||
|
|
||||||
def test_get_tenant_documents_count_returns_query_count(self, account_context):
|
def test_get_tenant_documents_count_returns_query_count(self, account_context):
|
||||||
with patch("services.dataset_service.db") as mock_db:
|
with patch("services.dataset_service.db") as mock_db:
|
||||||
mock_db.session.query.return_value.where.return_value.count.return_value = 12
|
mock_db.session.scalar.return_value = 12
|
||||||
|
|
||||||
result = DocumentService.get_tenant_documents_count()
|
result = DocumentService.get_tenant_documents_count()
|
||||||
|
|
||||||
assert result == 12
|
assert result == 12
|
||||||
mock_db.session.query.return_value.where.return_value.count.assert_called_once()
|
|
||||||
|
|
||||||
def test_update_document_with_dataset_id_uses_automatic_process_rule_payload(self, account_context):
|
def test_update_document_with_dataset_id_uses_automatic_process_rule_payload(self, account_context):
|
||||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||||
@@ -1514,11 +1498,7 @@ class TestDocumentServiceTenantAndUpdateEdges:
|
|||||||
):
|
):
|
||||||
process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES
|
process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES
|
||||||
process_rule_cls.return_value = created_process_rule
|
process_rule_cls.return_value = created_process_rule
|
||||||
upload_query = MagicMock()
|
mock_db.session.scalar.return_value = SimpleNamespace(id="file-1", name="upload.txt")
|
||||||
upload_query.where.return_value.first.return_value = SimpleNamespace(id="file-1", name="upload.txt")
|
|
||||||
segment_query = MagicMock()
|
|
||||||
segment_query.filter_by.return_value.update.return_value = 1
|
|
||||||
mock_db.session.query.side_effect = [upload_query, segment_query]
|
|
||||||
|
|
||||||
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
||||||
|
|
||||||
@@ -1567,7 +1547,7 @@ class TestDocumentServiceTenantAndUpdateEdges:
|
|||||||
patch.object(DatasetService, "check_dataset_model_setting"),
|
patch.object(DatasetService, "check_dataset_model_setting"),
|
||||||
patch("services.dataset_service.db") as mock_db,
|
patch("services.dataset_service.db") as mock_db,
|
||||||
):
|
):
|
||||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
mock_db.session.scalar.return_value = None
|
||||||
|
|
||||||
with pytest.raises(FileNotExistsError):
|
with pytest.raises(FileNotExistsError):
|
||||||
DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
||||||
@@ -1618,11 +1598,7 @@ class TestDocumentServiceTenantAndUpdateEdges:
|
|||||||
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
||||||
patch("services.dataset_service.document_indexing_update_task") as update_task,
|
patch("services.dataset_service.document_indexing_update_task") as update_task,
|
||||||
):
|
):
|
||||||
binding_query = MagicMock()
|
mock_db.session.scalar.return_value = SimpleNamespace(id="binding-1")
|
||||||
binding_query.where.return_value.first.return_value = SimpleNamespace(id="binding-1")
|
|
||||||
segment_query = MagicMock()
|
|
||||||
segment_query.filter_by.return_value.update.return_value = 1
|
|
||||||
mock_db.session.query.side_effect = [binding_query, segment_query]
|
|
||||||
|
|
||||||
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
|
||||||
|
|
||||||
@@ -1914,11 +1890,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
|
|||||||
):
|
):
|
||||||
mock_redis.lock.return_value = _make_lock_context()
|
mock_redis.lock.return_value = _make_lock_context()
|
||||||
process_rule_cls.return_value = created_process_rule
|
process_rule_cls.return_value = created_process_rule
|
||||||
upload_query = MagicMock()
|
mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []]
|
||||||
upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")]
|
|
||||||
existing_documents_query = MagicMock()
|
|
||||||
existing_documents_query.where.return_value.all.return_value = []
|
|
||||||
mock_db.session.query.side_effect = [upload_query, existing_documents_query]
|
|
||||||
|
|
||||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
|
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
|
||||||
|
|
||||||
@@ -1958,11 +1930,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
|
|||||||
mock_redis.lock.return_value = _make_lock_context()
|
mock_redis.lock.return_value = _make_lock_context()
|
||||||
process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES
|
process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES
|
||||||
process_rule_cls.return_value = created_process_rule
|
process_rule_cls.return_value = created_process_rule
|
||||||
upload_query = MagicMock()
|
mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []]
|
||||||
upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")]
|
|
||||||
existing_documents_query = MagicMock()
|
|
||||||
existing_documents_query.where.return_value.all.return_value = []
|
|
||||||
mock_db.session.query.side_effect = [upload_query, existing_documents_query]
|
|
||||||
|
|
||||||
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
|
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
|
||||||
|
|
||||||
@@ -1996,11 +1964,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
|
|||||||
mock_redis.lock.return_value = _make_lock_context()
|
mock_redis.lock.return_value = _make_lock_context()
|
||||||
process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES
|
process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES
|
||||||
process_rule_cls.return_value = created_process_rule
|
process_rule_cls.return_value = created_process_rule
|
||||||
upload_query = MagicMock()
|
mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []]
|
||||||
upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")]
|
|
||||||
existing_documents_query = MagicMock()
|
|
||||||
existing_documents_query.where.return_value.all.return_value = []
|
|
||||||
mock_db.session.query.side_effect = [upload_query, existing_documents_query]
|
|
||||||
|
|
||||||
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
|
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
|
||||||
|
|
||||||
@@ -2024,9 +1988,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
|
|||||||
patch("services.dataset_service.secrets.randbelow", return_value=23),
|
patch("services.dataset_service.secrets.randbelow", return_value=23),
|
||||||
):
|
):
|
||||||
mock_redis.lock.return_value = _make_lock_context()
|
mock_redis.lock.return_value = _make_lock_context()
|
||||||
upload_query = MagicMock()
|
mock_db.session.scalars.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")]
|
||||||
upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")]
|
|
||||||
mock_db.session.query.return_value = upload_query
|
|
||||||
|
|
||||||
with pytest.raises(FileNotExistsError, match="One or more files not found"):
|
with pytest.raises(FileNotExistsError, match="One or more files not found"):
|
||||||
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
|
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
|
||||||
|
|||||||
Reference in New Issue
Block a user