mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 16:39:26 +08:00
Fix typing errors in dataset API (#26424)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
from typing import Literal
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import marshal, reqparse
|
from flask_restx import marshal, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
import services.dataset_service
|
import services
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||||
from controllers.service_api.wraps import (
|
from controllers.service_api.wraps import (
|
||||||
@@ -254,19 +254,21 @@ class DatasetListApi(DatasetApiResource):
|
|||||||
"""Resource for creating datasets."""
|
"""Resource for creating datasets."""
|
||||||
args = dataset_create_parser.parse_args()
|
args = dataset_create_parser.parse_args()
|
||||||
|
|
||||||
if args.get("embedding_model_provider"):
|
embedding_model_provider = args.get("embedding_model_provider")
|
||||||
DatasetService.check_embedding_model_setting(
|
embedding_model = args.get("embedding_model")
|
||||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
if embedding_model_provider and embedding_model:
|
||||||
)
|
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||||
|
|
||||||
|
retrieval_model = args.get("retrieval_model")
|
||||||
if (
|
if (
|
||||||
args.get("retrieval_model")
|
retrieval_model
|
||||||
and args.get("retrieval_model").get("reranking_model")
|
and retrieval_model.get("reranking_model")
|
||||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||||
):
|
):
|
||||||
DatasetService.check_reranking_model_setting(
|
DatasetService.check_reranking_model_setting(
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -317,7 +319,7 @@ class DatasetApi(DatasetApiResource):
|
|||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
data = marshal(dataset, dataset_detail_fields)
|
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||||
# check embedding setting
|
# check embedding setting
|
||||||
provider_manager = ProviderManager()
|
provider_manager = ProviderManager()
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
@@ -331,8 +333,8 @@ class DatasetApi(DatasetApiResource):
|
|||||||
for embedding_model in embedding_models:
|
for embedding_model in embedding_models:
|
||||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||||
|
|
||||||
if data["indexing_technique"] == "high_quality":
|
if data.get("indexing_technique") == "high_quality":
|
||||||
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
|
||||||
if item_model in model_names:
|
if item_model in model_names:
|
||||||
data["embedding_available"] = True
|
data["embedding_available"] = True
|
||||||
else:
|
else:
|
||||||
@@ -341,7 +343,9 @@ class DatasetApi(DatasetApiResource):
|
|||||||
data["embedding_available"] = True
|
data["embedding_available"] = True
|
||||||
|
|
||||||
# force update search method to keyword_search if indexing_technique is economic
|
# force update search method to keyword_search if indexing_technique is economic
|
||||||
data["retrieval_model_dict"]["search_method"] = "keyword_search"
|
retrieval_model_dict = data.get("retrieval_model_dict")
|
||||||
|
if retrieval_model_dict:
|
||||||
|
retrieval_model_dict["search_method"] = "keyword_search"
|
||||||
|
|
||||||
if data.get("permission") == "partial_members":
|
if data.get("permission") == "partial_members":
|
||||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||||
@@ -372,19 +376,24 @@ class DatasetApi(DatasetApiResource):
|
|||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
|
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
|
embedding_model_provider = data.get("embedding_model_provider")
|
||||||
DatasetService.check_embedding_model_setting(
|
embedding_model = data.get("embedding_model")
|
||||||
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
|
||||||
)
|
if embedding_model_provider and embedding_model:
|
||||||
|
DatasetService.check_embedding_model_setting(
|
||||||
|
dataset.tenant_id, embedding_model_provider, embedding_model
|
||||||
|
)
|
||||||
|
|
||||||
|
retrieval_model = data.get("retrieval_model")
|
||||||
if (
|
if (
|
||||||
data.get("retrieval_model")
|
retrieval_model
|
||||||
and data.get("retrieval_model").get("reranking_model")
|
and retrieval_model.get("reranking_model")
|
||||||
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||||
):
|
):
|
||||||
DatasetService.check_reranking_model_setting(
|
DatasetService.check_reranking_model_setting(
|
||||||
dataset.tenant_id,
|
dataset.tenant_id,
|
||||||
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||||
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
@@ -397,7 +406,7 @@ class DatasetApi(DatasetApiResource):
|
|||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
result_data = marshal(dataset, dataset_detail_fields)
|
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
@@ -591,9 +600,10 @@ class DatasetTagsApi(DatasetApiResource):
|
|||||||
|
|
||||||
args = tag_update_parser.parse_args()
|
args = tag_update_parser.parse_args()
|
||||||
args["type"] = "knowledge"
|
args["type"] = "knowledge"
|
||||||
tag = TagService.update_tags(args, args.get("tag_id"))
|
tag_id = args["tag_id"]
|
||||||
|
tag = TagService.update_tags(args, tag_id)
|
||||||
|
|
||||||
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
|
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||||
|
|
||||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||||
|
|
||||||
@@ -616,7 +626,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||||||
if not current_user.has_edit_permission:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
args = tag_delete_parser.parse_args()
|
args = tag_delete_parser.parse_args()
|
||||||
TagService.delete_tag(args.get("tag_id"))
|
TagService.delete_tag(args["tag_id"])
|
||||||
|
|
||||||
return 204
|
return 204
|
||||||
|
|
||||||
|
|||||||
@@ -108,19 +108,21 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||||||
if text is None or name is None:
|
if text is None or name is None:
|
||||||
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
||||||
|
|
||||||
if args.get("embedding_model_provider"):
|
embedding_model_provider = args.get("embedding_model_provider")
|
||||||
DatasetService.check_embedding_model_setting(
|
embedding_model = args.get("embedding_model")
|
||||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
if embedding_model_provider and embedding_model:
|
||||||
)
|
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||||
|
|
||||||
|
retrieval_model = args.get("retrieval_model")
|
||||||
if (
|
if (
|
||||||
args.get("retrieval_model")
|
retrieval_model
|
||||||
and args.get("retrieval_model").get("reranking_model")
|
and retrieval_model.get("reranking_model")
|
||||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||||
):
|
):
|
||||||
DatasetService.check_reranking_model_setting(
|
DatasetService.check_reranking_model_setting(
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not current_user:
|
if not current_user:
|
||||||
@@ -187,15 +189,16 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
|||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset does not exist.")
|
raise ValueError("Dataset does not exist.")
|
||||||
|
|
||||||
|
retrieval_model = args.get("retrieval_model")
|
||||||
if (
|
if (
|
||||||
args.get("retrieval_model")
|
retrieval_model
|
||||||
and args.get("retrieval_model").get("reranking_model")
|
and retrieval_model.get("reranking_model")
|
||||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||||
):
|
):
|
||||||
DatasetService.check_reranking_model_setting(
|
DatasetService.check_reranking_model_setting(
|
||||||
tenant_id,
|
tenant_id,
|
||||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# indexing_technique is already set in dataset since this is an update
|
# indexing_technique is already set in dataset since this is an update
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
|||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
|
|
||||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
|
||||||
return marshal(metadata, dataset_metadata_fields), 200
|
return marshal(metadata, dataset_metadata_fields), 200
|
||||||
|
|
||||||
@service_api_ns.doc("delete_dataset_metadata")
|
@service_api_ns.doc("delete_dataset_metadata")
|
||||||
|
|||||||
@@ -8,7 +8,6 @@
|
|||||||
"extensions",
|
"extensions",
|
||||||
"libs",
|
"libs",
|
||||||
"controllers/console/datasets",
|
"controllers/console/datasets",
|
||||||
"controllers/service_api/dataset",
|
|
||||||
"core/ops",
|
"core/ops",
|
||||||
"core/tools",
|
"core/tools",
|
||||||
"core/model_runtime",
|
"core/model_runtime",
|
||||||
|
|||||||
Reference in New Issue
Block a user