diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 2d0ab3fcd73..706ae248f0a 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -17,6 +17,7 @@ from pydantic import BaseModel from yarl import URL from configs import dify_config +from core.helper.http_client_pooling import get_pooled_http_client from core.plugin.endpoint.exc import EndpointSetupFailedError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError from core.plugin.impl.exc import ( @@ -54,6 +55,11 @@ T = TypeVar("T", bound=(BaseModel | dict[str, Any] | list[Any] | bool | str)) logger = logging.getLogger(__name__) +_httpx_client: httpx.Client = get_pooled_http_client( + "plugin_daemon", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100), trust_env=False), +) + class BasePluginClient: def _request( @@ -84,7 +90,7 @@ class BasePluginClient: request_kwargs["content"] = prepared_data try: - response = httpx.request(**request_kwargs) + response = _httpx_client.request(**request_kwargs) except httpx.RequestError: logger.exception("Request to Plugin Daemon Service failed") raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") @@ -171,7 +177,7 @@ class BasePluginClient: stream_kwargs["content"] = prepared_data try: - with httpx.stream(**stream_kwargs) as response: + with _httpx_client.stream(**stream_kwargs) as response: for raw_line in response.iter_lines(): if not raw_line: continue diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 06b17b9e62c..37114be6e72 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -6,11 +6,18 @@ import httpx from httpx import DigestAuth from configs import dify_config +from core.helper.http_client_pooling import get_pooled_http_client from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding from models.enums import TidbAuthBindingStatus +# Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn +_tidb_http_client: httpx.Client = get_pooled_http_client( + "tidb:cloud", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + class TidbService: @staticmethod @@ -50,7 +57,9 @@ class TidbService: "rootPassword": password, } - response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key)) + response = _tidb_http_client.post( + f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key) + ) if response.status_code == 200: response_data = response.json() @@ -84,7 +93,9 @@ class TidbService: :return: The response from the API. """ - response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) + response = _tidb_http_client.delete( + f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key) + ) if response.status_code == 200: return response.json() @@ -103,7 +114,7 @@ class TidbService: :return: The response from the API. """ - response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) + response = _tidb_http_client.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) if response.status_code == 200: return response.json() @@ -128,7 +139,7 @@ class TidbService: body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []} - response = httpx.patch( + response = _tidb_http_client.patch( f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}", json=body, auth=DigestAuth(public_key, private_key), @@ -162,7 +173,9 @@ class TidbService: tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} cluster_ids = [item.cluster_id for item in tidb_serverless_list] params = {"clusterIds": cluster_ids, "view": "BASIC"} - response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key)) + response = _tidb_http_client.get( + f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key) + ) if response.status_code == 200: response_data = response.json() @@ -223,7 +236,7 @@ class TidbService: clusters.append(cluster_data) request_body = {"requests": clusters} - response = httpx.post( + response = _tidb_http_client.post( f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key) ) diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 76e741301cd..a2f11140333 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -7,6 +7,8 @@ from typing import NotRequired import httpx from pydantic import TypeAdapter, ValidationError +from core.helper.http_client_pooling import get_pooled_http_client + if sys.version_info >= (3, 12): from typing import TypedDict else: @@ -20,6 +22,12 @@ JsonObjectList = list[JsonObject] JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject) JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList) +# Reuse a pooled httpx.Client for OAuth flows (public endpoints, no SSRF proxy). +_http_client: httpx.Client = get_pooled_http_client( + "oauth:default", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + class AccessTokenResponse(TypedDict, total=False): access_token: str @@ -115,7 +123,7 @@ class GitHubOAuth(OAuth): "redirect_uri": self.redirect_uri, } headers = {"Accept": "application/json"} - response = httpx.post(self._TOKEN_URL, data=data, headers=headers) + response = _http_client.post(self._TOKEN_URL, data=data, headers=headers) response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response)) access_token = response_json.get("access_token") @@ -127,7 +135,7 @@ class GitHubOAuth(OAuth): def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"token {token}"} - response = httpx.get(self._USER_INFO_URL, headers=headers) + response = _http_client.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response)) @@ -147,7 +155,7 @@ class GitHubOAuth(OAuth): Returns an empty string when no usable email is found. """ try: - email_response = httpx.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers) + email_response = _http_client.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers) email_response.raise_for_status() email_records = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) except (httpx.HTTPStatusError, ValidationError): @@ -204,7 +212,7 @@ class GoogleOAuth(OAuth): "redirect_uri": self.redirect_uri, } headers = {"Accept": "application/json"} - response = httpx.post(self._TOKEN_URL, data=data, headers=headers) + response = _http_client.post(self._TOKEN_URL, data=data, headers=headers) response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response)) access_token = response_json.get("access_token") @@ -216,7 +224,7 @@ class GoogleOAuth(OAuth): def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"Bearer {token}"} - response = httpx.get(self._USER_INFO_URL, headers=headers) + response = _http_client.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() return _json_object(response) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index d5dc35ac977..190558e1f39 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -7,6 +7,7 @@ from flask_login import current_user from pydantic import TypeAdapter from sqlalchemy import select +from core.helper.http_client_pooling import get_pooled_http_client from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding @@ -38,6 +39,13 @@ NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo) NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary) +# Reuse a small pooled client for OAuth data source flows. +_http_client: httpx.Client = get_pooled_http_client( + "oauth:notion", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + + class OAuthDataSource: client_id: str client_secret: str @@ -75,7 +83,7 @@ class NotionOAuth(OAuthDataSource): data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} headers = {"Accept": "application/json"} auth = (self.client_id, self.client_secret) - response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) + response = _http_client.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) response_json = response.json() access_token = response_json.get("access_token") @@ -268,7 +276,7 @@ class NotionOAuth(OAuthDataSource): "Notion-Version": "2022-06-28", } - response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) + response = _http_client.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() results.extend(response_json.get("results", [])) @@ -283,7 +291,7 @@ class NotionOAuth(OAuthDataSource): "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", } - response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) + response = _http_client.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response_json = response.json() if response.status_code != 200: message = response_json.get("message", "unknown error") @@ -299,7 +307,7 @@ class NotionOAuth(OAuthDataSource): "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", } - response = httpx.get(url=self._NOTION_BOT_USER, headers=headers) + response = _http_client.get(url=self._NOTION_BOT_USER, headers=headers) response_json = response.json() if "object" in response_json and response_json["object"] == "user": user_type = response_json["type"] @@ -323,7 +331,7 @@ class NotionOAuth(OAuthDataSource): "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", } - response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) + response = _http_client.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() results.extend(response_json.get("results", [])) diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index e5e2319ce13..e63c9a3a4db 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -2,8 +2,14 @@ import json import httpx +from core.helper.http_client_pooling import get_pooled_http_client from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials +_http_client: httpx.Client = get_pooled_http_client( + "auth:jina_standalone", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + class JinaAuth(ApiKeyAuthBase): def __init__(self, credentials: AuthCredentials): @@ -31,7 +37,7 @@ class JinaAuth(ApiKeyAuthBase): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers): - return httpx.post(url, headers=headers, json=data) + return _http_client.post(url, headers=headers, json=data) def _handle_error(self, response): if response.status_code in {402, 409, 500}: diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index e5e2319ce13..8ea0b6cd69c 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -2,8 +2,14 @@ import json import httpx +from core.helper.http_client_pooling import get_pooled_http_client from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials +_http_client: httpx.Client = get_pooled_http_client( + "auth:jina", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + class JinaAuth(ApiKeyAuthBase): def __init__(self, credentials: AuthCredentials): @@ -31,7 +37,7 @@ class JinaAuth(ApiKeyAuthBase): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers): - return httpx.post(url, headers=headers, json=data) + return _http_client.post(url, headers=headers, json=data) def _handle_error(self, response): if response.status_code in {402, 409, 500}: diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 70d4ce1ee6b..54c595e0cbd 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -10,6 +10,7 @@ from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fix from typing_extensions import TypedDict from werkzeug.exceptions import InternalServerError +from core.helper.http_client_pooling import get_pooled_http_client from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -18,6 +19,11 @@ from models import Account, TenantAccountJoin, TenantAccountRole logger = logging.getLogger(__name__) +_http_client: httpx.Client = get_pooled_http_client( + "billing:default", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + class SubscriptionPlan(TypedDict): """Tenant subscriptionplan information.""" @@ -131,7 +137,7 @@ class BillingService: headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = httpx.request(method, url, json=json, params=params, headers=headers, follow_redirects=True) + response = _http_client.request(method, url, json=json, params=params, headers=headers, follow_redirects=True) if method == "GET" and response.status_code != httpx.codes.OK: raise ValueError("Unable to retrieve billing information. Please try again later or contact support.") if method == "PUT": diff --git a/api/services/website_service.py b/api/services/website_service.py index b2917ba1529..6a521a9cc0c 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -9,12 +9,23 @@ import httpx from flask_login import current_user from core.helper import encrypter +from core.helper.http_client_pooling import get_pooled_http_client from core.rag.extractor.firecrawl.firecrawl_app import CrawlStatusResponse, FirecrawlApp, FirecrawlDocumentData from core.rag.extractor.watercrawl.provider import WaterCrawlProvider from extensions.ext_redis import redis_client from extensions.ext_storage import storage from services.datasource_provider_service import DatasourceProviderService +# Reuse pooled HTTP clients to avoid creating new connections per request and ease testing. +_jina_http_client: httpx.Client = get_pooled_http_client( + "website:jinareader", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) +_adaptive_http_client: httpx.Client = get_pooled_http_client( + "website:adaptivecrawl", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + @dataclass class CrawlOptions: @@ -225,7 +236,7 @@ class WebsiteService: @classmethod def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]: if not request.options.crawl_sub_pages: - response = httpx.get( + response = _jina_http_client.get( f"https://r.jina.ai/{request.url}", headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) @@ -233,7 +244,7 @@ class WebsiteService: raise ValueError("Failed to crawl:") return {"status": "active", "data": response.json().get("data")} else: - response = httpx.post( + response = _adaptive_http_client.post( "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", json={ "url": request.url, @@ -296,7 +307,7 @@ class WebsiteService: @classmethod def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: - response = httpx.post( + response = _adaptive_http_client.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id}, @@ -312,7 +323,7 @@ class WebsiteService: } if crawl_status_data["status"] == "completed": - response = httpx.post( + response = _adaptive_http_client.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, @@ -374,7 +385,7 @@ class WebsiteService: @classmethod def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None: if not job_id: - response = httpx.get( + response = _jina_http_client.get( f"https://r.jina.ai/{url}", headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) @@ -383,7 +394,7 @@ class WebsiteService: return dict(response.json().get("data", {})) else: # Get crawl status first - status_response = httpx.post( + status_response = _adaptive_http_client.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id}, @@ -393,7 +404,7 @@ class WebsiteService: raise ValueError("Crawl job is not completed") # Get processed data - data_response = httpx.post( + data_response = _adaptive_http_client.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py index dc4c0fda1d4..f48c6da690f 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -79,7 +79,7 @@ class TestAuthIntegration: @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") @patch("services.auth.firecrawl.firecrawl.httpx.post") - @patch("services.auth.jina.jina.httpx.post") + @patch("services.auth.jina.jina._http_client.post") def test_multi_tenant_isolation( self, mock_jina_http, diff --git a/api/tests/unit_tests/core/datasource/test_website_crawl.py b/api/tests/unit_tests/core/datasource/test_website_crawl.py index 1d79db2640a..53000881ddc 100644 --- a/api/tests/unit_tests/core/datasource/test_website_crawl.py +++ b/api/tests/unit_tests/core/datasource/test_website_crawl.py @@ -560,7 +560,10 @@ class TestWebsiteService: mock_response = Mock() mock_response.json.return_value = {"code": 200, "data": {"taskId": "task-789"}} - mock_httpx_post = mocker.patch("services.website_service.httpx.post", return_value=mock_response) + mock_httpx_post = mocker.patch( + "services.website_service._adaptive_http_client.post", + return_value=mock_response, + ) from services.website_service import WebsiteCrawlApiRequest @@ -1340,7 +1343,7 @@ class TestProviderSpecificFeatures: "url": "https://example.com/page", }, } - mocker.patch("services.website_service.httpx.get", return_value=mock_response) + mocker.patch("services.website_service._jina_http_client.get", return_value=mock_response) from services.website_service import WebsiteCrawlApiRequest diff --git a/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py index c216906d68e..23894bd417b 100644 --- a/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py +++ b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py @@ -57,7 +57,7 @@ class TestBasePluginClientImpl: def test_stream_request_handles_data_lines_and_dict_payload(self, mocker): client = BasePluginClient() stream_mock = mocker.patch( - "core.plugin.impl.base.httpx.stream", + "httpx.Client.stream", return_value=_StreamContext([b"", b"data: hello", "world"]), ) diff --git a/api/tests/unit_tests/core/plugin/test_endpoint_client.py b/api/tests/unit_tests/core/plugin/test_endpoint_client.py index 48e30e9c2fc..ff9deb918af 100644 --- a/api/tests/unit_tests/core/plugin/test_endpoint_client.py +++ b/api/tests/unit_tests/core/plugin/test_endpoint_client.py @@ -10,12 +10,23 @@ Tests follow the Arrange-Act-Assert pattern for clarity. from unittest.mock import MagicMock, patch +import httpx import pytest from core.plugin.impl.endpoint import PluginEndpointClient from core.plugin.impl.exc import PluginDaemonInternalServerError +@pytest.fixture(autouse=True) +def _patch_shared_httpx_client(): + """Patch module-level client methods to delegate to module httpx.request/stream.""" + with ( + patch("core.plugin.impl.base._httpx_client.request", side_effect=lambda **kw: httpx.request(**kw)), + patch("core.plugin.impl.base._httpx_client.stream", side_effect=lambda **kw: httpx.stream(**kw)), + ): + yield + + class TestPluginEndpointClientDelete: """Unit tests for PluginEndpointClient delete_endpoint operation. diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index 3063ca01970..a3b1e5f6b0e 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -47,6 +47,20 @@ from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager +@pytest.fixture(autouse=True) +def _patch_shared_httpx_client(): + """Make BasePluginClient's module-level httpx client delegate to patched httpx.request/stream. + + After refactor, code uses core.plugin.impl.base._httpx_client directly. + Patch its request/stream to route through module-level httpx so existing mocks still apply. + """ + with ( + patch("core.plugin.impl.base._httpx_client.request", side_effect=lambda **kw: httpx.request(**kw)), + patch("core.plugin.impl.base._httpx_client.stream", side_effect=lambda **kw: httpx.stream(**kw)), + ): + yield + + class TestPluginRuntimeExecution: """Unit tests for plugin execution functionality. diff --git a/api/tests/unit_tests/libs/test_oauth_clients.py b/api/tests/unit_tests/libs/test_oauth_clients.py index ab468c86874..830284e6979 100644 --- a/api/tests/unit_tests/libs/test_oauth_clients.py +++ b/api/tests/unit_tests/libs/test_oauth_clients.py @@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest): ({}, None, True), ], ) - @patch("httpx.post", autospec=True) + @patch("libs.oauth._http_client.post", autospec=True) def test_should_retrieve_access_token( self, mock_post, oauth, mock_response, response_data, expected_token, should_raise ): @@ -109,7 +109,7 @@ class TestGitHubOAuth(BaseOAuthTest): ), ], ) - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email): user_response = MagicMock() user_response.json.return_value = user_data @@ -127,7 +127,7 @@ class TestGitHubOAuth(BaseOAuthTest): # The profile email is absent/null, so /user/emails should be called assert mock_get.call_count == 2 - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_skip_email_endpoint_when_profile_email_present(self, mock_get, oauth): """When the /user profile already contains an email, do not call /user/emails.""" user_response = MagicMock() @@ -162,7 +162,7 @@ class TestGitHubOAuth(BaseOAuthTest): ), ], ) - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_use_noreply_email_when_no_usable_email(self, mock_get, oauth, user_data, email_data): user_response = MagicMock() user_response.json.return_value = user_data @@ -177,7 +177,7 @@ class TestGitHubOAuth(BaseOAuthTest): assert user_info.id == str(user_data["id"]) assert user_info.email == "12345@users.noreply.github.com" - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_use_noreply_email_when_email_endpoint_fails(self, mock_get, oauth): user_response = MagicMock() user_response.json.return_value = {"id": 12345, "login": "testuser", "name": "Test User"} @@ -194,7 +194,7 @@ class TestGitHubOAuth(BaseOAuthTest): assert user_info.id == "12345" assert user_info.email == "12345@users.noreply.github.com" - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_handle_network_errors(self, mock_get, oauth): mock_get.side_effect = httpx.RequestError("Network error") @@ -240,7 +240,7 @@ class TestGoogleOAuth(BaseOAuthTest): ({}, None, True), ], ) - @patch("httpx.post", autospec=True) + @patch("libs.oauth._http_client.post", autospec=True) def test_should_retrieve_access_token( self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise ): @@ -274,7 +274,7 @@ class TestGoogleOAuth(BaseOAuthTest): ({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string ], ) - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name): mock_response.json.return_value = user_data mock_get.return_value = mock_response @@ -295,7 +295,7 @@ class TestGoogleOAuth(BaseOAuthTest): httpx.TimeoutException, ], ) - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_handle_http_errors(self, mock_get, oauth, exception_type): mock_response = MagicMock() mock_response.raise_for_status.side_effect = exception_type("Error") diff --git a/api/tests/unit_tests/services/auth/test_jina_auth.py b/api/tests/unit_tests/services/auth/test_jina_auth.py index 67f252390d5..2c34d46f1e4 100644 --- a/api/tests/unit_tests/services/auth/test_jina_auth.py +++ b/api/tests/unit_tests/services/auth/test_jina_auth.py @@ -35,7 +35,7 @@ class TestJinaAuth: JinaAuth(credentials) assert str(exc_info.value) == "No API key provided" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_validate_valid_credentials_successfully(self, mock_post): """Test successful credential validation""" mock_response = MagicMock() @@ -53,7 +53,7 @@ class TestJinaAuth: json={"url": "https://example.com"}, ) - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_http_402_error(self, mock_post): """Test handling of 402 Payment Required error""" mock_response = MagicMock() @@ -68,7 +68,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_http_409_error(self, mock_post): """Test handling of 409 Conflict error""" mock_response = MagicMock() @@ -83,7 +83,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_http_500_error(self, mock_post): """Test handling of 500 Internal Server Error""" mock_response = MagicMock() @@ -98,7 +98,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_unexpected_error_with_text_response(self, mock_post): """Test handling of unexpected errors with text response""" mock_response = MagicMock() @@ -114,7 +114,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_unexpected_error_without_text(self, mock_post): """Test handling of unexpected errors without text response""" mock_response = MagicMock() @@ -130,7 +130,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_network_errors(self, mock_post): """Test handling of network connection errors""" mock_post.side_effect = httpx.ConnectError("Network error") diff --git a/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py index c2fcd71875c..4b5a97bf3fd 100644 --- a/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py +++ b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py @@ -60,7 +60,7 @@ def test_prepare_headers_includes_bearer_api_key(jina_module: ModuleType) -> Non def test_post_request_calls_httpx(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None: auth = jina_module.JinaAuth(_credentials(api_key="k")) post_mock = MagicMock(name="httpx.post") - monkeypatch.setattr(jina_module.httpx, "post", post_mock) + monkeypatch.setattr(jina_module._http_client, "post", post_mock) auth._post_request("https://r.jina.ai", {"url": "https://example.com"}, {"h": "v"}) post_mock.assert_called_once_with("https://r.jina.ai", headers={"h": "v"}, json={"url": "https://example.com"}) @@ -72,7 +72,7 @@ def test_validate_credentials_success(jina_module: ModuleType, monkeypatch: pyte response = MagicMock() response.status_code = 200 post_mock = MagicMock(return_value=response) - monkeypatch.setattr(jina_module.httpx, "post", post_mock) + monkeypatch.setattr(jina_module._http_client, "post", post_mock) assert auth.validate_credentials() is True post_mock.assert_called_once_with( @@ -90,7 +90,7 @@ def test_validate_credentials_non_200_raises_via_handle_error( response = MagicMock() response.status_code = 402 response.json.return_value = {"error": "Payment required"} - monkeypatch.setattr(jina_module.httpx, "post", MagicMock(return_value=response)) + monkeypatch.setattr(jina_module._http_client, "post", MagicMock(return_value=response)) with pytest.raises(Exception, match="Status code: 402.*Payment required"): auth.validate_credentials() @@ -151,7 +151,7 @@ def test_validate_credentials_propagates_network_errors( jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch ) -> None: auth = jina_module.JinaAuth(_credentials(api_key="k")) - monkeypatch.setattr(jina_module.httpx, "post", MagicMock(side_effect=httpx.ConnectError("boom"))) + monkeypatch.setattr(jina_module._http_client, "post", MagicMock(side_effect=httpx.ConnectError("boom"))) with pytest.raises(httpx.ConnectError, match="boom"): auth.validate_credentials() diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 316381f0ca1..b3d2e608025 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -38,7 +38,7 @@ class TestBillingServiceSendRequest: @pytest.fixture def mock_httpx_request(self): """Mock httpx.request for testing.""" - with patch("services.billing_service.httpx.request") as mock_request: + with patch("services.billing_service._http_client.request") as mock_request: yield mock_request @pytest.fixture diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index 3df7d500cf2..da414816ff1 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock, patch +import httpx import pytest from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy.orm import Session @@ -71,6 +72,8 @@ class TestDatasourceProviderService: @pytest.fixture(autouse=True) def patch_externals(self): with ( + patch("core.plugin.impl.base._httpx_client.request", side_effect=lambda **kw: httpx.request(**kw)), + patch("core.plugin.impl.base._httpx_client.stream", side_effect=lambda **kw: httpx.stream(**kw)), patch("httpx.request") as mock_httpx, patch("services.datasource_provider_service.dify_config") as mock_cfg, patch("services.datasource_provider_service.encrypter") as mock_enc, diff --git a/api/tests/unit_tests/services/test_website_service.py b/api/tests/unit_tests/services/test_website_service.py index e973da7d564..b0ddc7388a1 100644 --- a/api/tests/unit_tests/services/test_website_service.py +++ b/api/tests/unit_tests/services/test_website_service.py @@ -343,7 +343,7 @@ def test_crawl_with_watercrawl_passes_options_dict(monkeypatch: pytest.MonkeyPat def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPatch) -> None: get_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"title": "t"}})) - monkeypatch.setattr(website_service_module.httpx, "get", get_mock) + monkeypatch.setattr(website_service_module._jina_http_client, "get", get_mock) req = WebsiteCrawlApiRequest( provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} @@ -356,7 +356,11 @@ def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPat def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + monkeypatch.setattr( + website_service_module._jina_http_client, + "get", + MagicMock(return_value=_DummyHttpxResponse({"code": 500})), + ) req = WebsiteCrawlApiRequest( provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} ).to_crawl_request() @@ -368,7 +372,7 @@ def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPat def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatch) -> None: post_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"taskId": "t1"}})) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) req = WebsiteCrawlApiRequest( provider="jinareader", @@ -384,7 +388,7 @@ def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatc def test_crawl_with_jinareader_multi_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( - website_service_module.httpx, "post", MagicMock(return_value=_DummyHttpxResponse({"code": 400})) + website_service_module._adaptive_http_client, "post", MagicMock(return_value=_DummyHttpxResponse({"code": 400})) ) req = WebsiteCrawlApiRequest( provider="jinareader", @@ -482,7 +486,7 @@ def test_get_jinareader_status_active(monkeypatch: pytest.MonkeyPatch) -> None: } ) ) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) result = WebsiteService._get_jinareader_status("job-1", "k") assert result["status"] == "active" @@ -518,7 +522,7 @@ def test_get_jinareader_status_completed_formats_processed_items(monkeypatch: py } } post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) result = WebsiteService._get_jinareader_status("job-1", "k") assert result["status"] == "completed" @@ -619,7 +623,7 @@ def test_get_watercrawl_url_data_delegates(monkeypatch: pytest.MonkeyPatch) -> N def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( - website_service_module.httpx, + website_service_module._jina_http_client, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"url": "u"}})), ) @@ -627,7 +631,11 @@ def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.Monk def test_get_jinareader_url_data_without_job_id_failure(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + monkeypatch.setattr( + website_service_module._jina_http_client, + "get", + MagicMock(return_value=_DummyHttpxResponse({"code": 500})), + ) with pytest.raises(ValueError, match="Failed to crawl$"): WebsiteService._get_jinareader_url_data("", "u", "k") @@ -637,7 +645,7 @@ def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(mon processed_payload = {"data": {"processed": {"u1": {"data": {"url": "u", "title": "t"}}}}} post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") == {"url": "u", "title": "t"} assert post_mock.call_count == 2 @@ -645,7 +653,7 @@ def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(mon def test_get_jinareader_url_data_with_job_id_not_completed_raises(monkeypatch: pytest.MonkeyPatch) -> None: post_mock = MagicMock(return_value=_DummyHttpxResponse({"data": {"status": "active"}})) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) with pytest.raises(ValueError, match=r"Crawl job is no\s*t completed"): WebsiteService._get_jinareader_url_data("job-1", "u", "k") @@ -658,7 +666,7 @@ def test_get_jinareader_url_data_with_job_id_completed_but_not_found_returns_non processed_payload = {"data": {"processed": {"u1": {"data": {"url": "other"}}}}} post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") is None