Merge branch 'main' into feat/marketplace-template

This commit is contained in:
Stephen Zhou
2026-04-01 11:20:07 +08:00
committed by GitHub
11 changed files with 208 additions and 164 deletions

View File

@@ -302,7 +302,7 @@ class PipelineGenerator(BaseAppGenerator):
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
workflow = db.session.get(Workflow, workflow_id)
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(

View File

@@ -9,6 +9,7 @@ from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.variable_loader import VariableLoader
from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
@@ -84,13 +85,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
user_id = None
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
end_user = db.session.get(EndUser, self.application_generate_entity.user_id)
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
pipeline = db.session.get(Pipeline, app_config.app_id)
if not pipeline:
raise ValueError("Pipeline not found")
@@ -213,10 +214,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
.first()
.limit(1)
)
# return workflow
@@ -297,10 +298,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
"""
if isinstance(event, GraphRunFailedEvent):
if document_id and dataset_id:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
document = db.session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if document:
document.indexing_status = "error"

View File

@@ -153,7 +153,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first()
upload_file: UploadFile | None = db.session.get(UploadFile, id)
if not upload_file:
return None
@@ -171,7 +171,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first()
message_file: MessageFile | None = db.session.get(MessageFile, id)
# Check if message_file is not None
if message_file is not None:
@@ -185,7 +185,7 @@ class DatasourceFileManager:
else:
tool_file_id = None
tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
tool_file: ToolFile | None = db.session.get(ToolFile, tool_file_id)
if not tool_file:
return None
@@ -203,7 +203,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
upload_file: UploadFile | None = db.session.get(UploadFile, upload_file_id)
if not upload_file:
return None, None

View File

@@ -10,7 +10,7 @@ from typing import Any
from flask import Flask, current_app
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import select
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
@@ -78,7 +78,7 @@ class IndexingRunner:
continue
# get dataset
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
dataset = db.session.get(Dataset, requeried_document.dataset_id)
if not dataset:
raise ValueError("no dataset found")
@@ -95,7 +95,7 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
current_user = db.session.get(Account, requeried_document.created_by)
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
@@ -137,23 +137,24 @@ class IndexingRunner:
return
# get dataset
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
dataset = db.session.get(Dataset, requeried_document.dataset_id)
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
document_segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == requeried_document.id,
)
).all()
for document_segment in document_segments:
db.session.delete(document_segment)
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.execute(delete(ChildChunk).where(ChildChunk.segment_id == document_segment.id))
db.session.commit()
# get the process rule
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
@@ -167,7 +168,7 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
current_user = db.session.get(Account, requeried_document.created_by)
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
@@ -207,17 +208,18 @@ class IndexingRunner:
return
# get dataset
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
dataset = db.session.get(Dataset, requeried_document.dataset_id)
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
document_segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == requeried_document.id,
)
).all()
documents = []
if document_segments:
@@ -289,7 +291,7 @@ class IndexingRunner:
embedding_model_instance = None
if dataset_id:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
dataset = db.session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not found.")
if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
@@ -652,24 +654,26 @@ class IndexingRunner:
@staticmethod
def _process_keyword_index(flask_app, dataset_id, document_id, documents):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
dataset = db.session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("no dataset found")
keyword = Keyword(dataset)
keyword.create(documents)
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
document_ids = [document.metadata["doc_id"] for document in documents]
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
db.session.execute(
update(DocumentSegment)
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
)
.values(
status=SegmentStatus.COMPLETED,
enabled=True,
completed_at=naive_utc_now(),
)
)
db.session.commit()
@@ -703,17 +707,19 @@ class IndexingRunner:
)
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
db.session.execute(
update(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
)
.values(
status=SegmentStatus.COMPLETED,
enabled=True,
completed_at=naive_utc_now(),
)
)
db.session.commit()
@@ -734,10 +740,17 @@ class IndexingRunner:
"""
Update the document indexing status.
"""
count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count()
count = (
db.session.scalar(
select(func.count())
.select_from(DatasetDocument)
.where(DatasetDocument.id == document_id, DatasetDocument.is_paused == True)
)
or 0
)
if count > 0:
raise DocumentIsPausedError()
document = db.session.query(DatasetDocument).filter_by(id=document_id).first()
document = db.session.get(DatasetDocument, document_id)
if not document:
raise DocumentIsDeletedPausedError()
@@ -745,7 +758,7 @@ class IndexingRunner:
if extra_update_params:
update_params.update(extra_update_params)
db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore
db.session.execute(update(DatasetDocument).where(DatasetDocument.id == document_id).values(update_params)) # type: ignore
db.session.commit()
@staticmethod
@@ -753,7 +766,9 @@ class IndexingRunner:
"""
Update the document segment by document id.
"""
db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
db.session.execute(
update(DocumentSegment).where(DocumentSegment.document_id == dataset_document_id).values(update_params)
)
db.session.commit()
def _transform(

View File

@@ -345,7 +345,7 @@ def test_generate_raises_when_workflow_not_found(generator, mocker):
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = None
session.get.return_value = None
mocker.patch.object(module.db, "session", session)
with pytest.raises(ValueError):

View File

@@ -80,9 +80,7 @@ def test_get_workflow_returns_workflow(mocker, runner):
pipeline = MagicMock(tenant_id="tenant", id="pipe")
workflow = MagicMock(id="wf")
query = MagicMock()
query.where.return_value.first.return_value = workflow
mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=query)))
mocker.patch.object(module.db, "session", MagicMock(scalar=MagicMock(return_value=workflow)))
result = runner.get_workflow(pipeline=pipeline, workflow_id="wf")
@@ -115,11 +113,8 @@ def test_init_rag_pipeline_graph_not_found(mocker, runner):
def test_update_document_status_on_failure(mocker, runner):
document = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = document
session = MagicMock()
session.query.return_value = query
session.scalar.return_value = document
mocker.patch.object(module.db, "session", session)
event = GraphRunFailedEvent(error="boom")
@@ -189,14 +184,10 @@ def test_run_single_iteration_path(mocker):
app_generate_entity.single_iteration_run = MagicMock()
pipeline = MagicMock(id="pipe")
query_pipeline = MagicMock()
query_pipeline.where.return_value.first.return_value = pipeline
query_end_user = MagicMock()
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
end_user = MagicMock(session_id="sess")
session = MagicMock()
session.query.side_effect = [query_end_user, query_pipeline]
session.get.side_effect = [end_user, pipeline]
mocker.patch.object(module.db, "session", session)
runner = PipelineRunner(
@@ -241,14 +232,10 @@ def test_run_normal_path_builds_graph(mocker):
app_generate_entity = _build_app_generate_entity()
pipeline = MagicMock(id="pipe")
query_pipeline = MagicMock()
query_pipeline.where.return_value.first.return_value = pipeline
query_end_user = MagicMock()
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
end_user = MagicMock(session_id="sess")
session = MagicMock()
session.query.side_effect = [query_end_user, query_pipeline]
session.get.side_effect = [end_user, pipeline]
mocker.patch.object(module.db, "session", session)
workflow = MagicMock(

View File

@@ -287,9 +287,7 @@ class TestDatasourceFileManager:
mock_upload_file.key = "some_key"
mock_upload_file.mime_type = "image/png"
mock_query = mock_db.session.query.return_value
mock_where = mock_query.where.return_value
mock_where.first.return_value = mock_upload_file
mock_db.session.get.return_value = mock_upload_file
mock_storage.load_once.return_value = b"file content"
@@ -300,7 +298,7 @@ class TestDatasourceFileManager:
assert result == (b"file content", "image/png")
# Case: Not found
mock_where.first.return_value = None
mock_db.session.get.return_value = None
assert DatasourceFileManager.get_file_binary("unknown") is None
@patch("core.datasource.datasource_file_manager.db")
@@ -314,16 +312,14 @@ class TestDatasourceFileManager:
mock_tool_file.file_key = "tool_key"
mock_tool_file.mimetype = "image/png"
# Mock query sequence
def mock_query(model):
m = MagicMock()
def mock_get(model, id):
if model == MessageFile:
m.where.return_value.first.return_value = mock_message_file
return mock_message_file
elif model == ToolFile:
m.where.return_value.first.return_value = mock_tool_file
return m
return mock_tool_file
return None
mock_db.session.query.side_effect = mock_query
mock_db.session.get.side_effect = mock_get
mock_storage.load_once.return_value = b"tool content"
# Execute
@@ -344,15 +340,12 @@ class TestDatasourceFileManager:
mock_tool_file.file_key = "tk"
mock_tool_file.mimetype = "image/png"
def mock_query(model):
m = MagicMock()
def mock_get(model, id):
if model == MessageFile:
m.where.return_value.first.return_value = mock_message_file
else:
m.where.return_value.first.return_value = mock_tool_file
return m
return mock_message_file
return mock_tool_file
mock_db.session.query.side_effect = mock_query
mock_db.session.get.side_effect = mock_get
mock_storage.load_once.return_value = b"bits"
result = DatasourceFileManager.get_file_binary_by_message_file_id("m")
@@ -361,27 +354,20 @@ class TestDatasourceFileManager:
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
def test_get_file_binary_by_message_file_id_failures(self, mock_storage, mock_db):
# Setup common mock
mock_query_obj = MagicMock()
mock_db.session.query.return_value = mock_query_obj
mock_query_obj.where.return_value.first.return_value = None
# Case 1: Message file not found
mock_db.session.get.return_value = None
assert DatasourceFileManager.get_file_binary_by_message_file_id("none") is None
# Case 2: Message file found but tool file not found
mock_message_file = MagicMock(spec=MessageFile)
mock_message_file.url = None
def mock_query_v2(model):
m = MagicMock()
def mock_get_v2(model, id):
if model == MessageFile:
m.where.return_value.first.return_value = mock_message_file
else:
m.where.return_value.first.return_value = None
return m
return mock_message_file
return None
mock_db.session.query.side_effect = mock_query_v2
mock_db.session.get.side_effect = mock_get_v2
assert DatasourceFileManager.get_file_binary_by_message_file_id("msg_id") is None
@patch("core.datasource.datasource_file_manager.db")
@@ -392,7 +378,7 @@ class TestDatasourceFileManager:
mock_upload_file.key = "upload_key"
mock_upload_file.mime_type = "text/plain"
mock_db.session.query.return_value.where.return_value.first.return_value = mock_upload_file
mock_db.session.get.return_value = mock_upload_file
mock_storage.load_stream.return_value = iter([b"chunk1", b"chunk2"])
@@ -404,7 +390,7 @@ class TestDatasourceFileManager:
assert list(stream) == [b"chunk1", b"chunk2"]
# Case: Not found
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_db.session.get.return_value = None
stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("none")
assert stream is None
assert mimetype is None

View File

@@ -795,33 +795,21 @@ class TestIndexingRunnerRun:
doc = sample_dataset_documents[0]
# Mock database queries
mock_dependencies["db"].session.get.return_value = doc
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = doc.dataset_id
mock_dataset.tenant_id = doc.tenant_id
mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
mock_current_user = MagicMock()
mock_current_user.set_tenant_id = MagicMock()
get_dispatch = {"Document": doc, "Dataset": mock_dataset, "Account": mock_current_user}
mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__)
mock_process_rule = Mock(spec=DatasetProcessRule)
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
mock_dependencies["db"].session.scalar.return_value = mock_process_rule
# Mock current_user (Account) for _transform
mock_current_user = MagicMock()
mock_current_user.set_tenant_id = MagicMock()
# Setup db.session.query to return different results based on the model
def mock_query_side_effect(model):
mock_query_result = MagicMock()
if model.__name__ == "Dataset":
mock_query_result.filter_by.return_value.first.return_value = mock_dataset
elif model.__name__ == "Account":
mock_query_result.filter_by.return_value.first.return_value = mock_current_user
return mock_query_result
mock_dependencies["db"].session.query.side_effect = mock_query_side_effect
# Mock processor
mock_processor = MagicMock()
mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor
@@ -891,10 +879,11 @@ class TestIndexingRunnerRun:
doc = sample_dataset_documents[0]
# Mock database
mock_dependencies["db"].session.get.return_value = doc
mock_dataset = Mock(spec=Dataset)
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
mock_dataset.tenant_id = doc.tenant_id
get_dispatch = {"Document": doc, "Dataset": mock_dataset}
mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__)
mock_process_rule = Mock(spec=DatasetProcessRule)
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
@@ -917,11 +906,12 @@ class TestIndexingRunnerRun:
runner = IndexingRunner()
doc = sample_dataset_documents[0]
# Mock database to raise ObjectDeletedError
mock_dependencies["db"].session.get.return_value = doc
# Mock database
mock_dataset = Mock(spec=Dataset)
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
mock_dataset.tenant_id = doc.tenant_id
get_dispatch = {"Document": doc, "Dataset": mock_dataset}
mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__)
mock_process_rule = Mock(spec=DatasetProcessRule)
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
@@ -945,17 +935,21 @@ class TestIndexingRunnerRun:
docs = sample_dataset_documents
# Mock database
def get_side_effect(model_class, doc_id):
for doc in docs:
if doc.id == doc_id:
return doc
return None
mock_dependencies["db"].session.get.side_effect = get_side_effect
mock_dataset = Mock(spec=Dataset)
mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset
mock_current_user = MagicMock()
mock_current_user.set_tenant_id = MagicMock()
doc_map = {doc.id: doc for doc in docs}
model_dispatch = {"Dataset": mock_dataset, "Account": mock_current_user}
def get_side_effect(model_class, id):
name = model_class.__name__
if name == "Document":
return doc_map.get(id)
return model_dispatch.get(name)
mock_dependencies["db"].session.get.side_effect = get_side_effect
mock_process_rule = Mock(spec=DatasetProcessRule)
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
@@ -1035,9 +1029,8 @@ class TestIndexingRunnerRetryLogic:
mock_document = Mock(spec=DatasetDocument)
mock_document.id = document_id
mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 0
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_dependencies["db"].session.query.return_value.filter_by.return_value.update.return_value = None
mock_dependencies["db"].session.scalar.return_value = 0
mock_dependencies["db"].session.get.return_value = mock_document
# Act
IndexingRunner._update_document_index_status(
@@ -1053,7 +1046,7 @@ class TestIndexingRunnerRetryLogic:
"""Test document status update when document is paused."""
# Arrange
document_id = str(uuid.uuid4())
mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 1
mock_dependencies["db"].session.scalar.return_value = 1
# Act & Assert
with pytest.raises(DocumentIsPausedError):
@@ -1063,8 +1056,8 @@ class TestIndexingRunnerRetryLogic:
"""Test document status update when document is deleted."""
# Arrange
document_id = str(uuid.uuid4())
mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 0
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = None
mock_dependencies["db"].session.scalar.return_value = 0
mock_dependencies["db"].session.get.return_value = None
# Act & Assert
with pytest.raises(DocumentIsDeletedPausedError):

View File

@@ -91,6 +91,15 @@ const createPayload = (overrides: Partial<VariableAssignerNodeType> = {}): Varia
...overrides,
})
const createPayloadWithoutAdvancedSettings = (): VariableAssignerNodeType => {
const payload = createPayload() as Omit<VariableAssignerNodeType, 'advanced_settings'> & {
advanced_settings?: VariableAssignerNodeType['advanced_settings']
}
delete payload.advanced_settings
return payload as VariableAssignerNodeType
}
describe('useConfig', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -252,4 +261,25 @@ describe('useConfig', () => {
advanced_settings: expect.objectContaining({ group_enabled: false }),
}))
})
it('should not throw when enabling groups with missing advanced settings', () => {
const { result } = renderHook(() => useConfig('assigner-node', createPayloadWithoutAdvancedSettings()))
expect(() => {
result.current.handleGroupEnabledChange(true)
}).not.toThrow()
expect(mockHandleOutVarRenameChange).toHaveBeenCalledWith(
'assigner-node',
['assigner-node', 'output'],
['assigner-node', 'Group1', 'output'],
)
expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({
advanced_settings: expect.objectContaining({
group_enabled: true,
groups: [expect.objectContaining({ group_name: 'Group1' })],
}),
}))
expect(mockDeleteNodeInspectorVars).toHaveBeenCalledWith('assigner-node')
})
})

View File

@@ -26,7 +26,13 @@ export const updateNestedVarGroupItem = (
groupId: string,
payload: VarGroupItem,
) => produce(inputs, (draft) => {
if (!draft.advanced_settings)
return
const index = draft.advanced_settings.groups.findIndex(item => item.groupId === groupId)
if (index < 0)
return
draft.advanced_settings.groups[index] = {
...draft.advanced_settings.groups[index],
...payload,
@@ -37,6 +43,11 @@ export const removeGroupByIndex = (
inputs: VariableAssignerNodeType,
index: number,
) => produce(inputs, (draft) => {
if (!draft.advanced_settings)
return
if (index < 0 || index >= draft.advanced_settings.groups.length)
return
draft.advanced_settings.groups.splice(index, 1)
})
@@ -70,7 +81,8 @@ export const toggleGroupEnabled = ({
export const addGroup = (inputs: VariableAssignerNodeType) => {
let maxInGroupName = 1
inputs.advanced_settings.groups.forEach((item) => {
const groups = inputs.advanced_settings?.groups ?? []
groups.forEach((item) => {
const match = /(\d+)$/.exec(item.group_name)
if (match) {
const num = Number.parseInt(match[1], 10)
@@ -80,6 +92,9 @@ export const addGroup = (inputs: VariableAssignerNodeType) => {
})
return produce(inputs, (draft) => {
if (!draft.advanced_settings)
draft.advanced_settings = { group_enabled: false, groups: [] }
draft.advanced_settings.groups.push({
output_type: VarType.any,
variables: [],
@@ -94,6 +109,12 @@ export const renameGroup = (
groupId: string,
name: string,
) => produce(inputs, (draft) => {
if (!draft.advanced_settings)
return
const index = draft.advanced_settings.groups.findIndex(item => item.groupId === groupId)
if (index < 0)
return
draft.advanced_settings.groups[index].group_name = name
})

View File

@@ -54,10 +54,15 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
const [removedGroupIndex, setRemovedGroupIndex] = useState<number>(-1)
const handleGroupRemoved = useCallback((groupId: string) => {
return () => {
const index = inputs.advanced_settings.groups.findIndex(item => item.groupId === groupId)
if (isVarUsedInNodes([id, inputs.advanced_settings.groups[index].group_name, 'output'])) {
const groups = inputs.advanced_settings?.groups ?? []
const index = groups.findIndex(item => item.groupId === groupId)
if (index < 0)
return
const groupName = groups[index].group_name
if (isVarUsedInNodes([id, groupName, 'output'])) {
showRemoveVarConfirm()
setRemovedVars([[id, inputs.advanced_settings.groups[index].group_name, 'output']])
setRemovedVars([[id, groupName, 'output']])
setRemoveType('group')
setRemovedGroupIndex(index)
return
@@ -67,13 +72,15 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
}, [id, inputs, isVarUsedInNodes, setInputs, showRemoveVarConfirm])
const handleGroupEnabledChange = useCallback((enabled: boolean) => {
if (enabled && inputs.advanced_settings.groups.length === 0) {
const groups = inputs.advanced_settings?.groups ?? []
if (enabled && groups.length === 0) {
handleOutVarRenameChange(id, [id, 'output'], [id, 'Group1', 'output'])
}
if (!enabled && inputs.advanced_settings.groups.length > 0) {
if (inputs.advanced_settings.groups.length > 1) {
const useVars = inputs.advanced_settings.groups.filter((item, index) => index > 0 && isVarUsedInNodes([id, item.group_name, 'output']))
if (!enabled && groups.length > 0) {
if (groups.length > 1) {
const useVars = groups.filter((item, index) => index > 0 && isVarUsedInNodes([id, item.group_name, 'output']))
if (useVars.length > 0) {
showRemoveVarConfirm()
setRemovedVars(useVars.map(item => [id, item.group_name, 'output']))
@@ -82,7 +89,7 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
}
}
handleOutVarRenameChange(id, [id, inputs.advanced_settings.groups[0].group_name, 'output'], [id, 'output'])
handleOutVarRenameChange(id, [id, groups[0].group_name, 'output'], [id, 'output'])
}
setInputs(toggleGroupEnabled({ inputs, enabled }))
@@ -110,11 +117,16 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
const handleVarGroupNameChange = useCallback((groupId: string) => {
return (name: string) => {
const index = inputs.advanced_settings.groups.findIndex(item => item.groupId === groupId)
handleOutVarRenameChange(id, [id, inputs.advanced_settings.groups[index].group_name, 'output'], [id, name, 'output'])
const groups = inputs.advanced_settings?.groups ?? []
const index = groups.findIndex(item => item.groupId === groupId)
if (index < 0)
return
const oldName = groups[index].group_name
handleOutVarRenameChange(id, [id, oldName, 'output'], [id, name, 'output'])
setInputs(renameGroup(inputs, groupId, name))
if (!(id in oldNameRef.current))
oldNameRef.current[id] = inputs.advanced_settings.groups[index].group_name
oldNameRef.current[id] = oldName
renameInspectNameWithDebounce(id, name)
}
}, [handleOutVarRenameChange, id, inputs, renameInspectNameWithDebounce, setInputs])
@@ -125,7 +137,8 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
})
hideRemoveVarConfirm()
if (removeType === 'group') {
setInputs(removeGroupByIndex(inputs, removedGroupIndex))
if (removedGroupIndex >= 0)
setInputs(removeGroupByIndex(inputs, removedGroupIndex))
}
else {
// removeType === 'enableChanged' to enabled