mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 06:19:25 +08:00
perf: use global httpx client instead of per request create new one (#34311)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from pydantic import BaseModel
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from core.plugin.endpoint.exc import EndpointSetupFailedError
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError
|
||||
from core.plugin.impl.exc import (
|
||||
@@ -54,6 +55,11 @@ T = TypeVar("T", bound=(BaseModel | dict[str, Any] | list[Any] | bool | str))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_httpx_client: httpx.Client = get_pooled_http_client(
|
||||
"plugin_daemon",
|
||||
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100), trust_env=False),
|
||||
)
|
||||
|
||||
|
||||
class BasePluginClient:
|
||||
def _request(
|
||||
@@ -84,7 +90,7 @@ class BasePluginClient:
|
||||
request_kwargs["content"] = prepared_data
|
||||
|
||||
try:
|
||||
response = httpx.request(**request_kwargs)
|
||||
response = _httpx_client.request(**request_kwargs)
|
||||
except httpx.RequestError:
|
||||
logger.exception("Request to Plugin Daemon Service failed")
|
||||
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
||||
@@ -171,7 +177,7 @@ class BasePluginClient:
|
||||
stream_kwargs["content"] = prepared_data
|
||||
|
||||
try:
|
||||
with httpx.stream(**stream_kwargs) as response:
|
||||
with _httpx_client.stream(**stream_kwargs) as response:
|
||||
for raw_line in response.iter_lines():
|
||||
if not raw_line:
|
||||
continue
|
||||
|
||||
@@ -6,11 +6,18 @@ import httpx
|
||||
from httpx import DigestAuth
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
# Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn
|
||||
_tidb_http_client: httpx.Client = get_pooled_http_client(
|
||||
"tidb:cloud",
|
||||
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
|
||||
)
|
||||
|
||||
|
||||
class TidbService:
|
||||
@staticmethod
|
||||
@@ -50,7 +57,9 @@ class TidbService:
|
||||
"rootPassword": password,
|
||||
}
|
||||
|
||||
response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key))
|
||||
response = _tidb_http_client.post(
|
||||
f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key)
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
@@ -84,7 +93,9 @@ class TidbService:
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
|
||||
response = _tidb_http_client.delete(
|
||||
f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
@@ -103,7 +114,7 @@ class TidbService:
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
|
||||
response = _tidb_http_client.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
@@ -128,7 +139,7 @@ class TidbService:
|
||||
|
||||
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
|
||||
|
||||
response = httpx.patch(
|
||||
response = _tidb_http_client.patch(
|
||||
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
|
||||
json=body,
|
||||
auth=DigestAuth(public_key, private_key),
|
||||
@@ -162,7 +173,9 @@ class TidbService:
|
||||
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
|
||||
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
|
||||
params = {"clusterIds": cluster_ids, "view": "BASIC"}
|
||||
response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key))
|
||||
response = _tidb_http_client.get(
|
||||
f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key)
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
@@ -223,7 +236,7 @@ class TidbService:
|
||||
clusters.append(cluster_data)
|
||||
|
||||
request_body = {"requests": clusters}
|
||||
response = httpx.post(
|
||||
response = _tidb_http_client.post(
|
||||
f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key)
|
||||
)
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ from typing import NotRequired
|
||||
import httpx
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
@@ -20,6 +22,12 @@ JsonObjectList = list[JsonObject]
|
||||
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
||||
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
||||
|
||||
# Reuse a pooled httpx.Client for OAuth flows (public endpoints, no SSRF proxy).
|
||||
_http_client: httpx.Client = get_pooled_http_client(
|
||||
"oauth:default",
|
||||
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
|
||||
)
|
||||
|
||||
|
||||
class AccessTokenResponse(TypedDict, total=False):
|
||||
access_token: str
|
||||
@@ -115,7 +123,7 @@ class GitHubOAuth(OAuth):
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
response = _http_client.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||
access_token = response_json.get("access_token")
|
||||
@@ -127,7 +135,7 @@ class GitHubOAuth(OAuth):
|
||||
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
headers = {"Authorization": f"token {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response = _http_client.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
|
||||
|
||||
@@ -147,7 +155,7 @@ class GitHubOAuth(OAuth):
|
||||
Returns an empty string when no usable email is found.
|
||||
"""
|
||||
try:
|
||||
email_response = httpx.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers)
|
||||
email_response = _http_client.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers)
|
||||
email_response.raise_for_status()
|
||||
email_records = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||
except (httpx.HTTPStatusError, ValidationError):
|
||||
@@ -204,7 +212,7 @@ class GoogleOAuth(OAuth):
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
response = _http_client.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||
access_token = response_json.get("access_token")
|
||||
@@ -216,7 +224,7 @@ class GoogleOAuth(OAuth):
|
||||
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response = _http_client.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
return _json_object(response)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from flask_login import current_user
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.source import DataSourceOauthBinding
|
||||
@@ -38,6 +39,13 @@ NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo)
|
||||
NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary)
|
||||
|
||||
|
||||
# Reuse a small pooled client for OAuth data source flows.
|
||||
_http_client: httpx.Client = get_pooled_http_client(
|
||||
"oauth:notion",
|
||||
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
|
||||
)
|
||||
|
||||
|
||||
class OAuthDataSource:
|
||||
client_id: str
|
||||
client_secret: str
|
||||
@@ -75,7 +83,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
|
||||
headers = {"Accept": "application/json"}
|
||||
auth = (self.client_id, self.client_secret)
|
||||
response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
|
||||
response = _http_client.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
access_token = response_json.get("access_token")
|
||||
@@ -268,7 +276,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"Notion-Version": "2022-06-28",
|
||||
}
|
||||
|
||||
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
|
||||
response = _http_client.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
|
||||
response_json = response.json()
|
||||
|
||||
results.extend(response_json.get("results", []))
|
||||
@@ -283,7 +291,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
}
|
||||
response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
|
||||
response = _http_client.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
|
||||
response_json = response.json()
|
||||
if response.status_code != 200:
|
||||
message = response_json.get("message", "unknown error")
|
||||
@@ -299,7 +307,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
}
|
||||
response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
|
||||
response = _http_client.get(url=self._NOTION_BOT_USER, headers=headers)
|
||||
response_json = response.json()
|
||||
if "object" in response_json and response_json["object"] == "user":
|
||||
user_type = response_json["type"]
|
||||
@@ -323,7 +331,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
}
|
||||
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
|
||||
response = _http_client.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
|
||||
response_json = response.json()
|
||||
|
||||
results.extend(response_json.get("results", []))
|
||||
|
||||
@@ -2,8 +2,14 @@ import json
|
||||
|
||||
import httpx
|
||||
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
|
||||
_http_client: httpx.Client = get_pooled_http_client(
|
||||
"auth:jina_standalone",
|
||||
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
|
||||
)
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
@@ -31,7 +37,7 @@ class JinaAuth(ApiKeyAuthBase):
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
return _http_client.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
|
||||
@@ -2,8 +2,14 @@ import json
|
||||
|
||||
import httpx
|
||||
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
|
||||
_http_client: httpx.Client = get_pooled_http_client(
|
||||
"auth:jina",
|
||||
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
|
||||
)
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
@@ -31,7 +37,7 @@ class JinaAuth(ApiKeyAuthBase):
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
return _http_client.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
|
||||
@@ -10,6 +10,7 @@ from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fix
|
||||
from typing_extensions import TypedDict
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@@ -18,6 +19,11 @@ from models import Account, TenantAccountJoin, TenantAccountRole
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_http_client: httpx.Client = get_pooled_http_client(
|
||||
"billing:default",
|
||||
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
|
||||
)
|
||||
|
||||
|
||||
class SubscriptionPlan(TypedDict):
|
||||
"""Tenant subscriptionplan information."""
|
||||
@@ -131,7 +137,7 @@ class BillingService:
|
||||
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = httpx.request(method, url, json=json, params=params, headers=headers, follow_redirects=True)
|
||||
response = _http_client.request(method, url, json=json, params=params, headers=headers, follow_redirects=True)
|
||||
if method == "GET" and response.status_code != httpx.codes.OK:
|
||||
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
|
||||
if method == "PUT":
|
||||
|
||||
@@ -9,12 +9,23 @@ import httpx
|
||||
from flask_login import current_user
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import CrawlStatusResponse, FirecrawlApp, FirecrawlDocumentData
|
||||
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
# Reuse pooled HTTP clients to avoid creating new connections per request and ease testing.
|
||||
_jina_http_client: httpx.Client = get_pooled_http_client(
|
||||
"website:jinareader",
|
||||
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
|
||||
)
|
||||
_adaptive_http_client: httpx.Client = get_pooled_http_client(
|
||||
"website:adaptivecrawl",
|
||||
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrawlOptions:
|
||||
@@ -225,7 +236,7 @@ class WebsiteService:
|
||||
@classmethod
|
||||
def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
|
||||
if not request.options.crawl_sub_pages:
|
||||
response = httpx.get(
|
||||
response = _jina_http_client.get(
|
||||
f"https://r.jina.ai/{request.url}",
|
||||
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
@@ -233,7 +244,7 @@ class WebsiteService:
|
||||
raise ValueError("Failed to crawl:")
|
||||
return {"status": "active", "data": response.json().get("data")}
|
||||
else:
|
||||
response = httpx.post(
|
||||
response = _adaptive_http_client.post(
|
||||
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
|
||||
json={
|
||||
"url": request.url,
|
||||
@@ -296,7 +307,7 @@ class WebsiteService:
|
||||
|
||||
@classmethod
|
||||
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
|
||||
response = httpx.post(
|
||||
response = _adaptive_http_client.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id},
|
||||
@@ -312,7 +323,7 @@ class WebsiteService:
|
||||
}
|
||||
|
||||
if crawl_status_data["status"] == "completed":
|
||||
response = httpx.post(
|
||||
response = _adaptive_http_client.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
|
||||
@@ -374,7 +385,7 @@ class WebsiteService:
|
||||
@classmethod
|
||||
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
|
||||
if not job_id:
|
||||
response = httpx.get(
|
||||
response = _jina_http_client.get(
|
||||
f"https://r.jina.ai/{url}",
|
||||
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
@@ -383,7 +394,7 @@ class WebsiteService:
|
||||
return dict(response.json().get("data", {}))
|
||||
else:
|
||||
# Get crawl status first
|
||||
status_response = httpx.post(
|
||||
status_response = _adaptive_http_client.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id},
|
||||
@@ -393,7 +404,7 @@ class WebsiteService:
|
||||
raise ValueError("Crawl job is not completed")
|
||||
|
||||
# Get processed data
|
||||
data_response = httpx.post(
|
||||
data_response = _adaptive_http_client.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
|
||||
|
||||
@@ -79,7 +79,7 @@ class TestAuthIntegration:
|
||||
|
||||
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina._http_client.post")
|
||||
def test_multi_tenant_isolation(
|
||||
self,
|
||||
mock_jina_http,
|
||||
|
||||
@@ -560,7 +560,10 @@ class TestWebsiteService:
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"code": 200, "data": {"taskId": "task-789"}}
|
||||
mock_httpx_post = mocker.patch("services.website_service.httpx.post", return_value=mock_response)
|
||||
mock_httpx_post = mocker.patch(
|
||||
"services.website_service._adaptive_http_client.post",
|
||||
return_value=mock_response,
|
||||
)
|
||||
|
||||
from services.website_service import WebsiteCrawlApiRequest
|
||||
|
||||
@@ -1340,7 +1343,7 @@ class TestProviderSpecificFeatures:
|
||||
"url": "https://example.com/page",
|
||||
},
|
||||
}
|
||||
mocker.patch("services.website_service.httpx.get", return_value=mock_response)
|
||||
mocker.patch("services.website_service._jina_http_client.get", return_value=mock_response)
|
||||
|
||||
from services.website_service import WebsiteCrawlApiRequest
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class TestBasePluginClientImpl:
|
||||
def test_stream_request_handles_data_lines_and_dict_payload(self, mocker):
|
||||
client = BasePluginClient()
|
||||
stream_mock = mocker.patch(
|
||||
"core.plugin.impl.base.httpx.stream",
|
||||
"httpx.Client.stream",
|
||||
return_value=_StreamContext([b"", b"data: hello", "world"]),
|
||||
)
|
||||
|
||||
|
||||
@@ -10,12 +10,23 @@ Tests follow the Arrange-Act-Assert pattern for clarity.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.plugin.impl.endpoint import PluginEndpointClient
|
||||
from core.plugin.impl.exc import PluginDaemonInternalServerError
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_shared_httpx_client():
|
||||
"""Patch module-level client methods to delegate to module httpx.request/stream."""
|
||||
with (
|
||||
patch("core.plugin.impl.base._httpx_client.request", side_effect=lambda **kw: httpx.request(**kw)),
|
||||
patch("core.plugin.impl.base._httpx_client.stream", side_effect=lambda **kw: httpx.stream(**kw)),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestPluginEndpointClientDelete:
|
||||
"""Unit tests for PluginEndpointClient delete_endpoint operation.
|
||||
|
||||
|
||||
@@ -47,6 +47,20 @@ from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_shared_httpx_client():
|
||||
"""Make BasePluginClient's module-level httpx client delegate to patched httpx.request/stream.
|
||||
|
||||
After refactor, code uses core.plugin.impl.base._httpx_client directly.
|
||||
Patch its request/stream to route through module-level httpx so existing mocks still apply.
|
||||
"""
|
||||
with (
|
||||
patch("core.plugin.impl.base._httpx_client.request", side_effect=lambda **kw: httpx.request(**kw)),
|
||||
patch("core.plugin.impl.base._httpx_client.stream", side_effect=lambda **kw: httpx.stream(**kw)),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestPluginRuntimeExecution:
|
||||
"""Unit tests for plugin execution functionality.
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("httpx.post", autospec=True)
|
||||
@patch("libs.oauth._http_client.post", autospec=True)
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
@@ -109,7 +109,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("httpx.get", autospec=True)
|
||||
@patch("libs.oauth._http_client.get", autospec=True)
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = user_data
|
||||
@@ -127,7 +127,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
# The profile email is absent/null, so /user/emails should be called
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
@patch("httpx.get", autospec=True)
|
||||
@patch("libs.oauth._http_client.get", autospec=True)
|
||||
def test_should_skip_email_endpoint_when_profile_email_present(self, mock_get, oauth):
|
||||
"""When the /user profile already contains an email, do not call /user/emails."""
|
||||
user_response = MagicMock()
|
||||
@@ -162,7 +162,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("httpx.get", autospec=True)
|
||||
@patch("libs.oauth._http_client.get", autospec=True)
|
||||
def test_should_use_noreply_email_when_no_usable_email(self, mock_get, oauth, user_data, email_data):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = user_data
|
||||
@@ -177,7 +177,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
assert user_info.id == str(user_data["id"])
|
||||
assert user_info.email == "12345@users.noreply.github.com"
|
||||
|
||||
@patch("httpx.get", autospec=True)
|
||||
@patch("libs.oauth._http_client.get", autospec=True)
|
||||
def test_should_use_noreply_email_when_email_endpoint_fails(self, mock_get, oauth):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = {"id": 12345, "login": "testuser", "name": "Test User"}
|
||||
@@ -194,7 +194,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
assert user_info.id == "12345"
|
||||
assert user_info.email == "12345@users.noreply.github.com"
|
||||
|
||||
@patch("httpx.get", autospec=True)
|
||||
@patch("libs.oauth._http_client.get", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_get, oauth):
|
||||
mock_get.side_effect = httpx.RequestError("Network error")
|
||||
|
||||
@@ -240,7 +240,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("httpx.post", autospec=True)
|
||||
@patch("libs.oauth._http_client.post", autospec=True)
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
@@ -274,7 +274,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
|
||||
],
|
||||
)
|
||||
@patch("httpx.get", autospec=True)
|
||||
@patch("libs.oauth._http_client.get", autospec=True)
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
|
||||
mock_response.json.return_value = user_data
|
||||
mock_get.return_value = mock_response
|
||||
@@ -295,7 +295,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
httpx.TimeoutException,
|
||||
],
|
||||
)
|
||||
@patch("httpx.get", autospec=True)
|
||||
@patch("libs.oauth._http_client.get", autospec=True)
|
||||
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = exception_type("Error")
|
||||
|
||||
@@ -35,7 +35,7 @@ class TestJinaAuth:
|
||||
JinaAuth(credentials)
|
||||
assert str(exc_info.value) == "No API key provided"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@@ -53,7 +53,7 @@ class TestJinaAuth:
|
||||
json={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_handle_http_402_error(self, mock_post):
|
||||
"""Test handling of 402 Payment Required error"""
|
||||
mock_response = MagicMock()
|
||||
@@ -68,7 +68,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_handle_http_409_error(self, mock_post):
|
||||
"""Test handling of 409 Conflict error"""
|
||||
mock_response = MagicMock()
|
||||
@@ -83,7 +83,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_handle_http_500_error(self, mock_post):
|
||||
"""Test handling of 500 Internal Server Error"""
|
||||
mock_response = MagicMock()
|
||||
@@ -98,7 +98,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
|
||||
"""Test handling of unexpected errors with text response"""
|
||||
mock_response = MagicMock()
|
||||
@@ -114,7 +114,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_handle_unexpected_error_without_text(self, mock_post):
|
||||
"""Test handling of unexpected errors without text response"""
|
||||
mock_response = MagicMock()
|
||||
@@ -130,7 +130,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
@patch("services.auth.jina.jina._http_client.post", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_post):
|
||||
"""Test handling of network connection errors"""
|
||||
mock_post.side_effect = httpx.ConnectError("Network error")
|
||||
|
||||
@@ -60,7 +60,7 @@ def test_prepare_headers_includes_bearer_api_key(jina_module: ModuleType) -> Non
|
||||
def test_post_request_calls_httpx(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
post_mock = MagicMock(name="httpx.post")
|
||||
monkeypatch.setattr(jina_module.httpx, "post", post_mock)
|
||||
monkeypatch.setattr(jina_module._http_client, "post", post_mock)
|
||||
|
||||
auth._post_request("https://r.jina.ai", {"url": "https://example.com"}, {"h": "v"})
|
||||
post_mock.assert_called_once_with("https://r.jina.ai", headers={"h": "v"}, json={"url": "https://example.com"})
|
||||
@@ -72,7 +72,7 @@ def test_validate_credentials_success(jina_module: ModuleType, monkeypatch: pyte
|
||||
response = MagicMock()
|
||||
response.status_code = 200
|
||||
post_mock = MagicMock(return_value=response)
|
||||
monkeypatch.setattr(jina_module.httpx, "post", post_mock)
|
||||
monkeypatch.setattr(jina_module._http_client, "post", post_mock)
|
||||
|
||||
assert auth.validate_credentials() is True
|
||||
post_mock.assert_called_once_with(
|
||||
@@ -90,7 +90,7 @@ def test_validate_credentials_non_200_raises_via_handle_error(
|
||||
response = MagicMock()
|
||||
response.status_code = 402
|
||||
response.json.return_value = {"error": "Payment required"}
|
||||
monkeypatch.setattr(jina_module.httpx, "post", MagicMock(return_value=response))
|
||||
monkeypatch.setattr(jina_module._http_client, "post", MagicMock(return_value=response))
|
||||
|
||||
with pytest.raises(Exception, match="Status code: 402.*Payment required"):
|
||||
auth.validate_credentials()
|
||||
@@ -151,7 +151,7 @@ def test_validate_credentials_propagates_network_errors(
|
||||
jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
auth = jina_module.JinaAuth(_credentials(api_key="k"))
|
||||
monkeypatch.setattr(jina_module.httpx, "post", MagicMock(side_effect=httpx.ConnectError("boom")))
|
||||
monkeypatch.setattr(jina_module._http_client, "post", MagicMock(side_effect=httpx.ConnectError("boom")))
|
||||
|
||||
with pytest.raises(httpx.ConnectError, match="boom"):
|
||||
auth.validate_credentials()
|
||||
|
||||
@@ -38,7 +38,7 @@ class TestBillingServiceSendRequest:
|
||||
@pytest.fixture
|
||||
def mock_httpx_request(self):
|
||||
"""Mock httpx.request for testing."""
|
||||
with patch("services.billing_service.httpx.request") as mock_request:
|
||||
with patch("services.billing_service._http_client.request") as mock_request:
|
||||
yield mock_request
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from graphon.model_runtime.entities.provider_entities import FormType
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -71,6 +72,8 @@ class TestDatasourceProviderService:
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_externals(self):
|
||||
with (
|
||||
patch("core.plugin.impl.base._httpx_client.request", side_effect=lambda **kw: httpx.request(**kw)),
|
||||
patch("core.plugin.impl.base._httpx_client.stream", side_effect=lambda **kw: httpx.stream(**kw)),
|
||||
patch("httpx.request") as mock_httpx,
|
||||
patch("services.datasource_provider_service.dify_config") as mock_cfg,
|
||||
patch("services.datasource_provider_service.encrypter") as mock_enc,
|
||||
|
||||
@@ -343,7 +343,7 @@ def test_crawl_with_watercrawl_passes_options_dict(monkeypatch: pytest.MonkeyPat
|
||||
|
||||
def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
get_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"title": "t"}}))
|
||||
monkeypatch.setattr(website_service_module.httpx, "get", get_mock)
|
||||
monkeypatch.setattr(website_service_module._jina_http_client, "get", get_mock)
|
||||
|
||||
req = WebsiteCrawlApiRequest(
|
||||
provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False}
|
||||
@@ -356,7 +356,11 @@ def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPat
|
||||
|
||||
|
||||
def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500})))
|
||||
monkeypatch.setattr(
|
||||
website_service_module._jina_http_client,
|
||||
"get",
|
||||
MagicMock(return_value=_DummyHttpxResponse({"code": 500})),
|
||||
)
|
||||
req = WebsiteCrawlApiRequest(
|
||||
provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False}
|
||||
).to_crawl_request()
|
||||
@@ -368,7 +372,7 @@ def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPat
|
||||
|
||||
def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
post_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"taskId": "t1"}}))
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock)
|
||||
|
||||
req = WebsiteCrawlApiRequest(
|
||||
provider="jinareader",
|
||||
@@ -384,7 +388,7 @@ def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatc
|
||||
|
||||
def test_crawl_with_jinareader_multi_page_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
website_service_module.httpx, "post", MagicMock(return_value=_DummyHttpxResponse({"code": 400}))
|
||||
website_service_module._adaptive_http_client, "post", MagicMock(return_value=_DummyHttpxResponse({"code": 400}))
|
||||
)
|
||||
req = WebsiteCrawlApiRequest(
|
||||
provider="jinareader",
|
||||
@@ -482,7 +486,7 @@ def test_get_jinareader_status_active(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
}
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock)
|
||||
|
||||
result = WebsiteService._get_jinareader_status("job-1", "k")
|
||||
assert result["status"] == "active"
|
||||
@@ -518,7 +522,7 @@ def test_get_jinareader_status_completed_formats_processed_items(monkeypatch: py
|
||||
}
|
||||
}
|
||||
post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)])
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock)
|
||||
|
||||
result = WebsiteService._get_jinareader_status("job-1", "k")
|
||||
assert result["status"] == "completed"
|
||||
@@ -619,7 +623,7 @@ def test_get_watercrawl_url_data_delegates(monkeypatch: pytest.MonkeyPatch) -> N
|
||||
|
||||
def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
website_service_module.httpx,
|
||||
website_service_module._jina_http_client,
|
||||
"get",
|
||||
MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"url": "u"}})),
|
||||
)
|
||||
@@ -627,7 +631,11 @@ def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.Monk
|
||||
|
||||
|
||||
def test_get_jinareader_url_data_without_job_id_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500})))
|
||||
monkeypatch.setattr(
|
||||
website_service_module._jina_http_client,
|
||||
"get",
|
||||
MagicMock(return_value=_DummyHttpxResponse({"code": 500})),
|
||||
)
|
||||
with pytest.raises(ValueError, match="Failed to crawl$"):
|
||||
WebsiteService._get_jinareader_url_data("", "u", "k")
|
||||
|
||||
@@ -637,7 +645,7 @@ def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(mon
|
||||
processed_payload = {"data": {"processed": {"u1": {"data": {"url": "u", "title": "t"}}}}}
|
||||
|
||||
post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)])
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock)
|
||||
|
||||
assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") == {"url": "u", "title": "t"}
|
||||
assert post_mock.call_count == 2
|
||||
@@ -645,7 +653,7 @@ def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(mon
|
||||
|
||||
def test_get_jinareader_url_data_with_job_id_not_completed_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
post_mock = MagicMock(return_value=_DummyHttpxResponse({"data": {"status": "active"}}))
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock)
|
||||
|
||||
with pytest.raises(ValueError, match=r"Crawl job is no\s*t completed"):
|
||||
WebsiteService._get_jinareader_url_data("job-1", "u", "k")
|
||||
@@ -658,7 +666,7 @@ def test_get_jinareader_url_data_with_job_id_completed_but_not_found_returns_non
|
||||
processed_payload = {"data": {"processed": {"u1": {"data": {"url": "other"}}}}}
|
||||
|
||||
post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)])
|
||||
monkeypatch.setattr(website_service_module.httpx, "post", post_mock)
|
||||
monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock)
|
||||
|
||||
assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") is None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user