mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 02:19:20 +08:00
refactor: select in account_service (RegisterService class) (#34500)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -8,7 +8,7 @@ from hashlib import sha256
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@@ -1392,10 +1392,10 @@ class RegisterService:
|
||||
db.session.add(dify_setup)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.query(DifySetup).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.execute(delete(DifySetup))
|
||||
db.session.execute(delete(TenantAccountJoin))
|
||||
db.session.execute(delete(Account))
|
||||
db.session.execute(delete(Tenant))
|
||||
db.session.commit()
|
||||
|
||||
logger.exception("Setup account failed, email: %s, name: %s", email, name)
|
||||
@@ -1496,7 +1496,11 @@ class RegisterService:
|
||||
TenantService.switch_tenant(account, tenant.id)
|
||||
else:
|
||||
TenantService.check_member_permission(tenant, inviter, account, "add")
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
ta = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not ta:
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
@@ -1553,21 +1557,18 @@ class RegisterService:
|
||||
if not invitation_data:
|
||||
return None
|
||||
|
||||
tenant = (
|
||||
db.session.query(Tenant)
|
||||
.where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
|
||||
.first()
|
||||
tenant = db.session.scalar(
|
||||
select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
|
||||
)
|
||||
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
tenant_account = (
|
||||
db.session.query(Account, TenantAccountJoin.role)
|
||||
tenant_account = db.session.execute(
|
||||
select(Account, TenantAccountJoin.role)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
|
||||
.first()
|
||||
)
|
||||
).first()
|
||||
|
||||
if not tenant_account:
|
||||
return None
|
||||
|
||||
@@ -1034,7 +1034,7 @@ class TestRegisterService:
|
||||
)
|
||||
|
||||
# Verify rollback operations were called
|
||||
mock_db_dependencies["db"].session.query.assert_called()
|
||||
mock_db_dependencies["db"].session.execute.assert_called()
|
||||
|
||||
# ==================== Registration Tests ====================
|
||||
|
||||
@@ -1599,10 +1599,8 @@ class TestRegisterService:
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
|
||||
# Mock the db.session.query for TenantAccountJoin
|
||||
mock_db_query = MagicMock()
|
||||
mock_db_query.filter_by.return_value.first.return_value = None # No existing member
|
||||
mock_db_dependencies["db"].session.query.return_value = mock_db_query
|
||||
# Mock scalar for TenantAccountJoin lookup - no existing member
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
# Mock TenantService methods
|
||||
with (
|
||||
@@ -1777,14 +1775,9 @@ class TestRegisterService:
|
||||
}
|
||||
mock_get_invitation_by_token.return_value = invitation_data
|
||||
|
||||
# Mock database queries - complex query mocking
|
||||
mock_query1 = MagicMock()
|
||||
mock_query1.where.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
# Mock scalar for tenant lookup, execute for account+role lookup
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
|
||||
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
@@ -1816,10 +1809,8 @@ class TestRegisterService:
|
||||
}
|
||||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||
|
||||
# Mock database queries - no tenant found
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = None
|
||||
mock_db_dependencies["db"].session.query.return_value = mock_query
|
||||
# Mock scalar for tenant lookup - not found
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
@@ -1842,14 +1833,9 @@ class TestRegisterService:
|
||||
}
|
||||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||
|
||||
# Mock database queries
|
||||
mock_query1 = MagicMock()
|
||||
mock_query1.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = None # No account found
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
# Mock scalar for tenant, execute for account+role
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
|
||||
mock_db_dependencies["db"].session.execute.return_value.first.return_value = None # No account found
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
@@ -1875,14 +1861,9 @@ class TestRegisterService:
|
||||
}
|
||||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||
|
||||
# Mock database queries
|
||||
mock_query1 = MagicMock()
|
||||
mock_query1.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
# Mock scalar for tenant, execute for account+role
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
|
||||
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
|
||||
Reference in New Issue
Block a user