mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 16:39:26 +08:00
refactor: init_validate.py to v3 (#31457)
This commit is contained in:
@@ -1,87 +1,74 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from flask import session
|
from flask import session
|
||||||
from flask_restx import Resource, fields
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from controllers.fastopenapi import console_router
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
|
|
||||||
from . import console_ns
|
|
||||||
from .error import AlreadySetupError, InitValidateFailedError
|
from .error import AlreadySetupError, InitValidateFailedError
|
||||||
from .wraps import only_edition_self_hosted
|
from .wraps import only_edition_self_hosted
|
||||||
|
|
||||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
|
||||||
|
|
||||||
|
|
||||||
class InitValidatePayload(BaseModel):
|
class InitValidatePayload(BaseModel):
|
||||||
password: str = Field(..., max_length=30)
|
password: str = Field(..., max_length=30, description="Initialization password")
|
||||||
|
|
||||||
|
|
||||||
console_ns.schema_model(
|
class InitStatusResponse(BaseModel):
|
||||||
InitValidatePayload.__name__,
|
status: Literal["finished", "not_started"] = Field(..., description="Initialization status")
|
||||||
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
|
||||||
|
|
||||||
|
class InitValidateResponse(BaseModel):
|
||||||
|
result: str = Field(description="Operation result", examples=["success"])
|
||||||
|
|
||||||
|
|
||||||
|
@console_router.get(
|
||||||
|
"/init",
|
||||||
|
response_model=InitStatusResponse,
|
||||||
|
tags=["console"],
|
||||||
)
|
)
|
||||||
|
def get_init_status() -> InitStatusResponse:
|
||||||
|
"""Get initialization validation status."""
|
||||||
|
init_status = get_init_validate_status()
|
||||||
|
if init_status:
|
||||||
|
return InitStatusResponse(status="finished")
|
||||||
|
return InitStatusResponse(status="not_started")
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/init")
|
@console_router.post(
|
||||||
class InitValidateAPI(Resource):
|
"/init",
|
||||||
@console_ns.doc("get_init_status")
|
response_model=InitValidateResponse,
|
||||||
@console_ns.doc(description="Get initialization validation status")
|
tags=["console"],
|
||||||
@console_ns.response(
|
status_code=201,
|
||||||
200,
|
)
|
||||||
"Success",
|
@only_edition_self_hosted
|
||||||
model=console_ns.model(
|
def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
|
||||||
"InitStatusResponse",
|
"""Validate initialization password."""
|
||||||
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
|
tenant_count = TenantService.get_tenant_count()
|
||||||
),
|
if tenant_count > 0:
|
||||||
)
|
raise AlreadySetupError()
|
||||||
def get(self):
|
|
||||||
"""Get initialization validation status"""
|
|
||||||
init_status = get_init_validate_status()
|
|
||||||
if init_status:
|
|
||||||
return {"status": "finished"}
|
|
||||||
return {"status": "not_started"}
|
|
||||||
|
|
||||||
@console_ns.doc("validate_init_password")
|
if payload.password != os.environ.get("INIT_PASSWORD"):
|
||||||
@console_ns.doc(description="Validate initialization password for self-hosted edition")
|
session["is_init_validated"] = False
|
||||||
@console_ns.expect(console_ns.models[InitValidatePayload.__name__])
|
raise InitValidateFailedError()
|
||||||
@console_ns.response(
|
|
||||||
201,
|
|
||||||
"Success",
|
|
||||||
model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
|
|
||||||
)
|
|
||||||
@console_ns.response(400, "Already setup or validation failed")
|
|
||||||
@only_edition_self_hosted
|
|
||||||
def post(self):
|
|
||||||
"""Validate initialization password"""
|
|
||||||
# is tenant created
|
|
||||||
tenant_count = TenantService.get_tenant_count()
|
|
||||||
if tenant_count > 0:
|
|
||||||
raise AlreadySetupError()
|
|
||||||
|
|
||||||
payload = InitValidatePayload.model_validate(console_ns.payload)
|
session["is_init_validated"] = True
|
||||||
input_password = payload.password
|
return InitValidateResponse(result="success")
|
||||||
|
|
||||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
|
||||||
session["is_init_validated"] = False
|
|
||||||
raise InitValidateFailedError()
|
|
||||||
|
|
||||||
session["is_init_validated"] = True
|
|
||||||
return {"result": "success"}, 201
|
|
||||||
|
|
||||||
|
|
||||||
def get_init_validate_status():
|
def get_init_validate_status() -> bool:
|
||||||
if dify_config.EDITION == "SELF_HOSTED":
|
if dify_config.EDITION == "SELF_HOSTED":
|
||||||
if os.environ.get("INIT_PASSWORD"):
|
if os.environ.get("INIT_PASSWORD"):
|
||||||
if session.get("is_init_validated"):
|
if session.get("is_init_validated"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
with Session(db.engine) as db_session:
|
with Session(db.engine) as db_session:
|
||||||
return db_session.execute(select(DifySetup)).scalar_one_or_none()
|
return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -27,9 +27,11 @@ def init_app(app: DifyApp) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Ensure route decorators are evaluated.
|
# Ensure route decorators are evaluated.
|
||||||
|
import controllers.console.init_validate as init_validate_module
|
||||||
import controllers.console.ping as ping_module
|
import controllers.console.ping as ping_module
|
||||||
from controllers.console import remote_files, setup
|
from controllers.console import remote_files, setup
|
||||||
|
|
||||||
|
_ = init_validate_module
|
||||||
_ = ping_module
|
_ = ping_module
|
||||||
_ = remote_files
|
_ = remote_files
|
||||||
_ = setup
|
_ = setup
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
import builtins
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
from flask.views import MethodView
|
||||||
|
|
||||||
|
from extensions import ext_fastopenapi
|
||||||
|
|
||||||
|
if not hasattr(builtins, "MethodView"):
|
||||||
|
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app() -> Flask:
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config["TESTING"] = True
|
||||||
|
app.secret_key = "test-secret-key"
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
ext_fastopenapi.init_app(app)
|
||||||
|
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||||
|
|
||||||
|
with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"):
|
||||||
|
client = app.test_client()
|
||||||
|
response = client.get("/console/api/init")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.get_json() == {"status": "finished"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
ext_fastopenapi.init_app(app)
|
||||||
|
monkeypatch.setenv("INIT_PASSWORD", "test-init-password")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"),
|
||||||
|
patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0),
|
||||||
|
):
|
||||||
|
client = app.test_client()
|
||||||
|
response = client.post("/console/api/init", json={"password": "test-init-password"})
|
||||||
|
|
||||||
|
assert response.status_code == 201
|
||||||
|
assert response.get_json() == {"result": "success"}
|
||||||
Reference in New Issue
Block a user