mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 05:09:19 +08:00
refactor: select in account_service (TenantService class) (#34499)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
@@ -8,7 +8,7 @@ from hashlib import sha256
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@@ -1069,11 +1069,11 @@ class TenantService:
|
||||
@staticmethod
|
||||
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
|
||||
"""Check if user have a workspace or not"""
|
||||
available_ta = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter_by(account_id=account.id)
|
||||
available_ta = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id)
|
||||
.order_by(TenantAccountJoin.id.asc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if available_ta:
|
||||
@@ -1104,7 +1104,11 @@ class TenantService:
|
||||
logger.error("Tenant %s has already an owner.", tenant.id)
|
||||
raise Exception("Tenant already has an owner.")
|
||||
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
ta = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
)
|
||||
if ta:
|
||||
ta.role = TenantAccountRole(role)
|
||||
else:
|
||||
@@ -1119,11 +1123,12 @@ class TenantService:
|
||||
@staticmethod
|
||||
def get_join_tenants(account: Account) -> list[Tenant]:
|
||||
"""Get account join tenants"""
|
||||
return (
|
||||
db.session.query(Tenant)
|
||||
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
|
||||
.all()
|
||||
return list(
|
||||
db.session.scalars(
|
||||
select(Tenant)
|
||||
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
|
||||
).all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -1133,7 +1138,11 @@ class TenantService:
|
||||
if not tenant:
|
||||
raise TenantNotFoundError("Tenant not found.")
|
||||
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
ta = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
)
|
||||
if ta:
|
||||
tenant.role = ta.role
|
||||
else:
|
||||
@@ -1148,23 +1157,25 @@ class TenantService:
|
||||
if tenant_id is None:
|
||||
raise ValueError("Tenant ID must be provided.")
|
||||
|
||||
tenant_account_join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
tenant_account_join = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(
|
||||
TenantAccountJoin.account_id == account.id,
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
Tenant.status == TenantStatus.NORMAL,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not tenant_account_join:
|
||||
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
||||
else:
|
||||
db.session.query(TenantAccountJoin).where(
|
||||
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
|
||||
).update({"current": False})
|
||||
db.session.execute(
|
||||
update(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id)
|
||||
.values(current=False)
|
||||
)
|
||||
tenant_account_join.current = True
|
||||
# Set the current tenant for the account
|
||||
account.set_tenant_id(tenant_account_join.tenant_id)
|
||||
@@ -1173,8 +1184,8 @@ class TenantService:
|
||||
@staticmethod
|
||||
def get_tenant_members(tenant: Tenant) -> list[Account]:
|
||||
"""Get tenant members"""
|
||||
query = (
|
||||
db.session.query(Account, TenantAccountJoin.role)
|
||||
stmt = (
|
||||
select(Account, TenantAccountJoin.role)
|
||||
.select_from(Account)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id)
|
||||
@@ -1183,7 +1194,7 @@ class TenantService:
|
||||
# Initialize an empty list to store the updated accounts
|
||||
updated_accounts = []
|
||||
|
||||
for account, role in query:
|
||||
for account, role in db.session.execute(stmt):
|
||||
account.role = role
|
||||
updated_accounts.append(account)
|
||||
|
||||
@@ -1192,8 +1203,8 @@ class TenantService:
|
||||
@staticmethod
|
||||
def get_dataset_operator_members(tenant: Tenant) -> list[Account]:
|
||||
"""Get dataset admin members"""
|
||||
query = (
|
||||
db.session.query(Account, TenantAccountJoin.role)
|
||||
stmt = (
|
||||
select(Account, TenantAccountJoin.role)
|
||||
.select_from(Account)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id)
|
||||
@@ -1203,7 +1214,7 @@ class TenantService:
|
||||
# Initialize an empty list to store the updated accounts
|
||||
updated_accounts = []
|
||||
|
||||
for account, role in query:
|
||||
for account, role in db.session.execute(stmt):
|
||||
account.role = role
|
||||
updated_accounts.append(account)
|
||||
|
||||
@@ -1216,26 +1227,31 @@ class TenantService:
|
||||
raise ValueError("all roles must be TenantAccountRole")
|
||||
|
||||
return (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]))
|
||||
.first()
|
||||
db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
TenantAccountJoin.role.in_([role.value for role in roles]),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
|
||||
"""Get the role of the current account for a given tenant"""
|
||||
join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
join = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
return TenantAccountRole(join.role) if join else None
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_count() -> int:
|
||||
"""Get tenant count"""
|
||||
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
|
||||
return cast(int, db.session.scalar(select(func.count(Tenant.id))))
|
||||
|
||||
@staticmethod
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
|
||||
@@ -1252,7 +1268,11 @@ class TenantService:
|
||||
if operator.id == member.id:
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first()
|
||||
ta_operator = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not ta_operator or ta_operator.role not in perms[action]:
|
||||
raise NoPermissionError(f"No permission to {action} member.")
|
||||
@@ -1270,7 +1290,11 @@ class TenantService:
|
||||
|
||||
TenantService.check_member_permission(tenant, operator, account, "remove")
|
||||
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
ta = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
)
|
||||
if not ta:
|
||||
raise MemberNotInTenantError("Member not in tenant.")
|
||||
|
||||
@@ -1285,7 +1309,12 @@ class TenantService:
|
||||
should_delete_account = False
|
||||
if account.status == AccountStatus.PENDING:
|
||||
# autoflush flushes ta deletion before this query, so 0 means no remaining joins
|
||||
remaining_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).count()
|
||||
remaining_joins = (
|
||||
db.session.scalar(
|
||||
select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if remaining_joins == 0:
|
||||
db.session.delete(account)
|
||||
should_delete_account = True
|
||||
@@ -1320,8 +1349,10 @@ class TenantService:
|
||||
"""Update member role"""
|
||||
TenantService.check_member_permission(tenant, operator, member, "update")
|
||||
|
||||
target_member_join = (
|
||||
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first()
|
||||
target_member_join = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not target_member_join:
|
||||
@@ -1332,8 +1363,10 @@ class TenantService:
|
||||
|
||||
if new_role == "owner":
|
||||
# Find the current owner and change their role to 'admin'
|
||||
current_owner_join = (
|
||||
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
|
||||
current_owner_join = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
|
||||
.limit(1)
|
||||
)
|
||||
if current_owner_join:
|
||||
current_owner_join.role = TenantAccountRole.ADMIN
|
||||
|
||||
@@ -556,12 +556,8 @@ class TestTenantService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
|
||||
# Setup smart database query mock - no existing tenant joins
|
||||
query_results = {
|
||||
("TenantAccountJoin", "account_id", "user-123"): None,
|
||||
("TenantAccountJoin", "tenant_id", "tenant-456"): None, # For has_roles check
|
||||
}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
# Mock scalar to return None (no existing tenant joins)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
# Setup external service mocks
|
||||
mock_external_service_dependencies[
|
||||
@@ -650,9 +646,8 @@ class TestTenantService:
|
||||
mock_tenant.id = "tenant-456"
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
|
||||
# Setup smart database query mock - no existing member
|
||||
query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): None}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
# Mock scalar to return None (no existing member)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
# Mock database operations
|
||||
mock_db_dependencies["db"].session.add = MagicMock()
|
||||
@@ -693,16 +688,8 @@ class TestTenantService:
|
||||
tenant_id="tenant-456", account_id="operator-123", role="owner"
|
||||
)
|
||||
|
||||
query_mock_permission = MagicMock()
|
||||
query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join
|
||||
|
||||
query_mock_ta = MagicMock()
|
||||
query_mock_ta.filter_by.return_value.first.return_value = mock_ta
|
||||
|
||||
query_mock_count = MagicMock()
|
||||
query_mock_count.filter_by.return_value.count.return_value = 0
|
||||
|
||||
mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta, query_mock_count]
|
||||
# scalar calls: permission check, ta lookup, remaining count
|
||||
mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta, 0]
|
||||
|
||||
with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync:
|
||||
mock_sync.return_value = True
|
||||
@@ -741,17 +728,8 @@ class TestTenantService:
|
||||
tenant_id="tenant-456", account_id="operator-123", role="owner"
|
||||
)
|
||||
|
||||
query_mock_permission = MagicMock()
|
||||
query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join
|
||||
|
||||
query_mock_ta = MagicMock()
|
||||
query_mock_ta.filter_by.return_value.first.return_value = mock_ta
|
||||
|
||||
# Remaining join count = 1 (still in another workspace)
|
||||
query_mock_count = MagicMock()
|
||||
query_mock_count.filter_by.return_value.count.return_value = 1
|
||||
|
||||
mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta, query_mock_count]
|
||||
# scalar calls: permission check, ta lookup, remaining count = 1
|
||||
mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta, 1]
|
||||
|
||||
with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync:
|
||||
mock_sync.return_value = True
|
||||
@@ -781,13 +759,8 @@ class TestTenantService:
|
||||
tenant_id="tenant-456", account_id="operator-123", role="owner"
|
||||
)
|
||||
|
||||
query_mock_permission = MagicMock()
|
||||
query_mock_permission.filter_by.return_value.first.return_value = mock_operator_join
|
||||
|
||||
query_mock_ta = MagicMock()
|
||||
query_mock_ta.filter_by.return_value.first.return_value = mock_ta
|
||||
|
||||
mock_db.session.query.side_effect = [query_mock_permission, query_mock_ta]
|
||||
# scalar calls: permission check, ta lookup (no count needed for active member)
|
||||
mock_db.session.scalar.side_effect = [mock_operator_join, mock_ta]
|
||||
|
||||
with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync:
|
||||
mock_sync.return_value = True
|
||||
@@ -810,13 +783,8 @@ class TestTenantService:
|
||||
|
||||
# Mock the complex query in switch_tenant method
|
||||
with patch("services.account_service.db") as mock_db:
|
||||
# Mock the join query that returns the tenant_account_join
|
||||
mock_query = MagicMock()
|
||||
mock_where = MagicMock()
|
||||
mock_where.first.return_value = mock_tenant_join
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_db.session.query.return_value = mock_query
|
||||
# Mock scalar for the join query
|
||||
mock_db.session.scalar.return_value = mock_tenant_join
|
||||
|
||||
# Execute test
|
||||
TenantService.switch_tenant(mock_account, "tenant-456")
|
||||
@@ -851,20 +819,8 @@ class TestTenantService:
|
||||
|
||||
# Mock the database queries in update_member_role method
|
||||
with patch("services.account_service.db") as mock_db:
|
||||
# Mock the first query for operator permission check
|
||||
mock_query1 = MagicMock()
|
||||
mock_filter1 = MagicMock()
|
||||
mock_filter1.first.return_value = mock_operator_join
|
||||
mock_query1.filter_by.return_value = mock_filter1
|
||||
|
||||
# Mock the second query for target member
|
||||
mock_query2 = MagicMock()
|
||||
mock_filter2 = MagicMock()
|
||||
mock_filter2.first.return_value = mock_target_join
|
||||
mock_query2.filter_by.return_value = mock_filter2
|
||||
|
||||
# Make the query method return different mocks for different calls
|
||||
mock_db.session.query.side_effect = [mock_query1, mock_query2]
|
||||
# scalar calls: permission check, target member lookup
|
||||
mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join]
|
||||
|
||||
# Execute test
|
||||
TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator)
|
||||
@@ -886,9 +842,7 @@ class TestTenantService:
|
||||
tenant_id="tenant-456", account_id="operator-123", role="owner"
|
||||
)
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): mock_operator_join}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_operator_join
|
||||
|
||||
# Execute test - should not raise exception
|
||||
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add")
|
||||
|
||||
Reference in New Issue
Block a user