diff --git a/api/extensions/otel/instrumentation.py b/api/extensions/otel/instrumentation.py index b73ba8df8cd..0a70f6ebe97 100644 --- a/api/extensions/otel/instrumentation.py +++ b/api/extensions/otel/instrumentation.py @@ -1,5 +1,7 @@ import contextlib import logging +from collections.abc import Callable +from typing import Protocol, cast import flask from opentelemetry.instrumentation.celery import CeleryInstrumentor @@ -21,6 +23,38 @@ from extensions.otel.runtime import is_celery_worker logger = logging.getLogger(__name__) +class SupportsInstrument(Protocol): + def instrument(self, **kwargs: object) -> None: ... + + +class SupportsFlaskInstrumentor(Protocol): + def instrument_app( + self, app: DifyApp, response_hook: Callable[[Span, str, list], None] | None = None, **kwargs: object + ) -> None: ... + + +# Some OpenTelemetry instrumentor constructors are typed loosely enough that +# pyrefly infers `NoneType`. Narrow the instances to just the methods we use +# while leaving runtime behavior unchanged. +def _new_celery_instrumentor() -> SupportsInstrument: + return cast( + SupportsInstrument, + CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()), + ) + + +def _new_httpx_instrumentor() -> SupportsInstrument: + return cast(SupportsInstrument, HTTPXClientInstrumentor()) + + +def _new_redis_instrumentor() -> SupportsInstrument: + return cast(SupportsInstrument, RedisInstrumentor()) + + +def _new_sqlalchemy_instrumentor() -> SupportsInstrument: + return cast(SupportsInstrument, SQLAlchemyInstrumentor()) + + class ExceptionLoggingHandler(logging.Handler): """ Handler that records exceptions to the current OpenTelemetry span. @@ -97,7 +131,7 @@ def init_flask_instrumentor(app: DifyApp) -> None: from opentelemetry.instrumentation.flask import FlaskInstrumentor - instrumentor = FlaskInstrumentor() + instrumentor = cast(SupportsFlaskInstrumentor, FlaskInstrumentor()) if dify_config.DEBUG: logger.info("Initializing Flask instrumentor") instrumentor.instrument_app(app, response_hook=response_hook) @@ -106,21 +140,21 @@ def init_flask_instrumentor(app: DifyApp) -> None: def init_sqlalchemy_instrumentor(app: DifyApp) -> None: with app.app_context(): engines = list(app.extensions["sqlalchemy"].engines.values()) - SQLAlchemyInstrumentor().instrument(enable_commenter=True, engines=engines) + _new_sqlalchemy_instrumentor().instrument(enable_commenter=True, engines=engines) def init_redis_instrumentor() -> None: - RedisInstrumentor().instrument() + _new_redis_instrumentor().instrument() def init_httpx_instrumentor() -> None: - HTTPXClientInstrumentor().instrument() + _new_httpx_instrumentor().instrument() def init_instruments(app: DifyApp) -> None: if not is_celery_worker(): init_flask_instrumentor(app) - CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument() + _new_celery_instrumentor().instrument() instrument_exception_logging() init_sqlalchemy_instrumentor(app)