Skip to content

Commit 8060a64

Browse files
ref(client): Improve get_integration typing (getsentry#3550)
Improve `get_integration` typing to make it clear that we return an `Optional[Integration]`. Further, add overloads to specify that when called with some integration type `I` (i.e. `I` is a subclass of `Integration`), then `get_integration` guarantees a return value of `Optional[I]`. These changes should enhance type safety by explicitly guaranteeing the existing behavior of `get_integration`.
1 parent 2a2fab1 commit 8060a64

22 files changed

+132
-80
lines changed

sentry_sdk/client.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Mapping
66
from datetime import datetime, timezone
77
from importlib import import_module
8-
from typing import cast
8+
from typing import cast, overload
99

1010
from sentry_sdk._compat import PY37, check_uwsgi_thread_support
1111
from sentry_sdk.utils import (
@@ -54,6 +54,7 @@
5454
from typing import Sequence
5555
from typing import Type
5656
from typing import Union
57+
from typing import TypeVar
5758

5859
from sentry_sdk._types import Event, Hint, SDKInfo
5960
from sentry_sdk.integrations import Integration
@@ -62,6 +63,7 @@
6263
from sentry_sdk.session import Session
6364
from sentry_sdk.transport import Transport
6465

66+
I = TypeVar("I", bound=Integration) # noqa: E741
6567

6668
_client_init_debug = ContextVar("client_init_debug")
6769

@@ -195,8 +197,20 @@ def capture_session(self, *args, **kwargs):
195197
# type: (*Any, **Any) -> None
196198
return None
197199

198-
def get_integration(self, *args, **kwargs):
199-
# type: (*Any, **Any) -> Any
200+
if TYPE_CHECKING:
201+
202+
@overload
203+
def get_integration(self, name_or_class):
204+
# type: (str) -> Optional[Integration]
205+
...
206+
207+
@overload
208+
def get_integration(self, name_or_class):
209+
# type: (type[I]) -> Optional[I]
210+
...
211+
212+
def get_integration(self, name_or_class):
213+
# type: (Union[str, type[Integration]]) -> Optional[Integration]
200214
return None
201215

202216
def close(self, *args, **kwargs):
@@ -815,10 +829,22 @@ def capture_session(
815829
else:
816830
self.session_flusher.add_session(session)
817831

832+
if TYPE_CHECKING:
833+
834+
@overload
835+
def get_integration(self, name_or_class):
836+
# type: (str) -> Optional[Integration]
837+
...
838+
839+
@overload
840+
def get_integration(self, name_or_class):
841+
# type: (type[I]) -> Optional[I]
842+
...
843+
818844
def get_integration(
819845
self, name_or_class # type: Union[str, Type[Integration]]
820846
):
821-
# type: (...) -> Any
847+
# type: (...) -> Optional[Integration]
822848
"""Returns the integration for this client by name or class.
823849
If the client does not have that integration then `None` is returned.
824850
"""

sentry_sdk/integrations/aiohttp.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sys
22
import weakref
3+
from functools import wraps
34

45
import sentry_sdk
56
from sentry_sdk.api import continue_trace
@@ -156,11 +157,14 @@ async def sentry_app_handle(self, request, *args, **kwargs):
156157

157158
old_urldispatcher_resolve = UrlDispatcher.resolve
158159

160+
@wraps(old_urldispatcher_resolve)
159161
async def sentry_urldispatcher_resolve(self, request):
160162
# type: (UrlDispatcher, Request) -> UrlMappingMatchInfo
161163
rv = await old_urldispatcher_resolve(self, request)
162164

163165
integration = sentry_sdk.get_client().get_integration(AioHttpIntegration)
166+
if integration is None:
167+
return rv
164168

165169
name = None
166170

sentry_sdk/integrations/anthropic.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from sentry_sdk.scope import should_send_default_pii
88
from sentry_sdk.utils import (
99
capture_internal_exceptions,
10-
ensure_integration_enabled,
1110
event_from_exception,
1211
package_version,
1312
)
@@ -78,10 +77,11 @@ def _calculate_token_usage(result, span):
7877
def _wrap_message_create(f):
7978
# type: (Any) -> Any
8079
@wraps(f)
81-
@ensure_integration_enabled(AnthropicIntegration, f)
8280
def _sentry_patched_create(*args, **kwargs):
8381
# type: (*Any, **Any) -> Any
84-
if "messages" not in kwargs:
82+
integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
83+
84+
if integration is None or "messages" not in kwargs:
8585
return f(*args, **kwargs)
8686

8787
try:
@@ -106,8 +106,6 @@ def _sentry_patched_create(*args, **kwargs):
106106
span.__exit__(None, None, None)
107107
raise exc from None
108108

109-
integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
110-
111109
with capture_internal_exceptions():
112110
span.set_data(SPANDATA.AI_MODEL_ID, model)
113111
span.set_data(SPANDATA.AI_STREAMING, False)

sentry_sdk/integrations/atexit.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
import sentry_sdk
66
from sentry_sdk.utils import logger
77
from sentry_sdk.integrations import Integration
8-
from sentry_sdk.utils import ensure_integration_enabled
9-
108
from typing import TYPE_CHECKING
119

1210
if TYPE_CHECKING:
@@ -44,13 +42,16 @@ def __init__(self, callback=None):
4442
def setup_once():
4543
# type: () -> None
4644
@atexit.register
47-
@ensure_integration_enabled(AtexitIntegration)
4845
def _shutdown():
4946
# type: () -> None
50-
logger.debug("atexit: got shutdown signal")
5147
client = sentry_sdk.get_client()
5248
integration = client.get_integration(AtexitIntegration)
5349

50+
if integration is None:
51+
return
52+
53+
logger.debug("atexit: got shutdown signal")
5454
logger.debug("atexit: shutting down client")
5555
sentry_sdk.get_isolation_scope().end_session()
56+
5657
client.close(callback=integration.callback)

sentry_sdk/integrations/aws_lambda.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import json
23
import re
34
import sys
@@ -70,7 +71,7 @@ def sentry_init_error(*args, **kwargs):
7071

7172
def _wrap_handler(handler):
7273
# type: (F) -> F
73-
@ensure_integration_enabled(AwsLambdaIntegration, handler)
74+
@functools.wraps(handler)
7475
def sentry_handler(aws_event, aws_context, *args, **kwargs):
7576
# type: (Any, Any, *Any, **Any) -> Any
7677

@@ -84,6 +85,12 @@ def sentry_handler(aws_event, aws_context, *args, **kwargs):
8485
# will be the same for all events in the list, since they're all hitting
8586
# the lambda in the same request.)
8687

88+
client = sentry_sdk.get_client()
89+
integration = client.get_integration(AwsLambdaIntegration)
90+
91+
if integration is None:
92+
return handler(aws_event, aws_context, *args, **kwargs)
93+
8794
if isinstance(aws_event, list) and len(aws_event) >= 1:
8895
request_data = aws_event[0]
8996
batch_size = len(aws_event)
@@ -97,9 +104,6 @@ def sentry_handler(aws_event, aws_context, *args, **kwargs):
97104
# this is empty
98105
request_data = {}
99106

100-
client = sentry_sdk.get_client()
101-
integration = client.get_integration(AwsLambdaIntegration)
102-
103107
configured_time = aws_context.get_remaining_time_in_millis()
104108

105109
with sentry_sdk.isolation_scope() as scope:

sentry_sdk/integrations/bottle.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import functools
2+
13
import sentry_sdk
24
from sentry_sdk.tracing import SOURCE_FOR_STYLE
35
from sentry_sdk.utils import (
@@ -81,10 +83,12 @@ def sentry_patched_wsgi_app(self, environ, start_response):
8183

8284
old_handle = Bottle._handle
8385

84-
@ensure_integration_enabled(BottleIntegration, old_handle)
86+
@functools.wraps(old_handle)
8587
def _patched_handle(self, environ):
8688
# type: (Bottle, Dict[str, Any]) -> Any
8789
integration = sentry_sdk.get_client().get_integration(BottleIntegration)
90+
if integration is None:
91+
return old_handle(self, environ)
8892

8993
scope = sentry_sdk.get_isolation_scope()
9094
scope._name = "bottle"

sentry_sdk/integrations/celery/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -248,13 +248,15 @@ def __exit__(self, exc_type, exc_value, traceback):
248248
def _wrap_task_run(f):
249249
# type: (F) -> F
250250
@wraps(f)
251-
@ensure_integration_enabled(CeleryIntegration, f)
252251
def apply_async(*args, **kwargs):
253252
# type: (*Any, **Any) -> Any
254253
# Note: kwargs can contain headers=None, so no setdefault!
255254
# Unsure which backend though.
256-
kwarg_headers = kwargs.get("headers") or {}
257255
integration = sentry_sdk.get_client().get_integration(CeleryIntegration)
256+
if integration is None:
257+
return f(*args, **kwargs)
258+
259+
kwarg_headers = kwargs.get("headers") or {}
258260
propagate_traces = kwarg_headers.pop(
259261
"sentry-propagate-traces", integration.propagate_traces
260262
)

sentry_sdk/integrations/cohere.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@
1414
import sentry_sdk
1515
from sentry_sdk.scope import should_send_default_pii
1616
from sentry_sdk.integrations import DidNotEnable, Integration
17-
from sentry_sdk.utils import (
18-
capture_internal_exceptions,
19-
event_from_exception,
20-
ensure_integration_enabled,
21-
)
17+
from sentry_sdk.utils import capture_internal_exceptions, event_from_exception
2218

2319
try:
2420
from cohere.client import Client
@@ -134,13 +130,15 @@ def collect_chat_response_fields(span, res, include_pii):
134130
set_data_normalized(span, "ai.warnings", res.meta.warnings)
135131

136132
@wraps(f)
137-
@ensure_integration_enabled(CohereIntegration, f)
138133
def new_chat(*args, **kwargs):
139134
# type: (*Any, **Any) -> Any
140-
if "message" not in kwargs:
141-
return f(*args, **kwargs)
135+
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
142136

143-
if not isinstance(kwargs.get("message"), str):
137+
if (
138+
integration is None
139+
or "message" not in kwargs
140+
or not isinstance(kwargs.get("message"), str)
141+
):
144142
return f(*args, **kwargs)
145143

146144
message = kwargs.get("message")
@@ -158,8 +156,6 @@ def new_chat(*args, **kwargs):
158156
span.__exit__(None, None, None)
159157
raise e from None
160158

161-
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
162-
163159
with capture_internal_exceptions():
164160
if should_send_default_pii() and integration.include_prompts:
165161
set_data_normalized(
@@ -227,15 +223,17 @@ def _wrap_embed(f):
227223
# type: (Callable[..., Any]) -> Callable[..., Any]
228224

229225
@wraps(f)
230-
@ensure_integration_enabled(CohereIntegration, f)
231226
def new_embed(*args, **kwargs):
232227
# type: (*Any, **Any) -> Any
228+
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
229+
if integration is None:
230+
return f(*args, **kwargs)
231+
233232
with sentry_sdk.start_span(
234233
op=consts.OP.COHERE_EMBEDDINGS_CREATE,
235234
name="Cohere Embedding Creation",
236235
origin=CohereIntegration.origin,
237236
) as span:
238-
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
239237
if "texts" in kwargs and (
240238
should_send_default_pii() and integration.include_prompts
241239
):

sentry_sdk/integrations/django/__init__.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,11 @@ def _set_transaction_name_and_source(scope, transaction_style, request):
411411
pass
412412

413413

414-
@ensure_integration_enabled(DjangoIntegration)
415414
def _before_get_response(request):
416415
# type: (WSGIRequest) -> None
417416
integration = sentry_sdk.get_client().get_integration(DjangoIntegration)
417+
if integration is None:
418+
return
418419

419420
_patch_drf()
420421

@@ -440,11 +441,10 @@ def _attempt_resolve_again(request, scope, transaction_style):
440441
_set_transaction_name_and_source(scope, transaction_style, request)
441442

442443

443-
@ensure_integration_enabled(DjangoIntegration)
444444
def _after_get_response(request):
445445
# type: (WSGIRequest) -> None
446446
integration = sentry_sdk.get_client().get_integration(DjangoIntegration)
447-
if integration.transaction_style != "url":
447+
if integration is None or integration.transaction_style != "url":
448448
return
449449

450450
scope = sentry_sdk.get_current_scope()
@@ -510,11 +510,12 @@ def wsgi_request_event_processor(event, hint):
510510
return wsgi_request_event_processor
511511

512512

513-
@ensure_integration_enabled(DjangoIntegration)
514513
def _got_request_exception(request=None, **kwargs):
515514
# type: (WSGIRequest, **Any) -> None
516515
client = sentry_sdk.get_client()
517516
integration = client.get_integration(DjangoIntegration)
517+
if integration is None:
518+
return
518519

519520
if request is not None and integration.transaction_style == "url":
520521
scope = sentry_sdk.get_current_scope()

sentry_sdk/integrations/fastapi.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ def _sentry_call(*args, **kwargs):
9999

100100
async def _sentry_app(*args, **kwargs):
101101
# type: (*Any, **Any) -> Any
102-
if sentry_sdk.get_client().get_integration(FastApiIntegration) is None:
102+
integration = sentry_sdk.get_client().get_integration(FastApiIntegration)
103+
if integration is None:
103104
return await old_app(*args, **kwargs)
104105

105-
integration = sentry_sdk.get_client().get_integration(FastApiIntegration)
106106
request = args[0]
107107

108108
_set_transaction_name_and_source(

sentry_sdk/integrations/flask.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,12 @@ def _set_transaction_name_and_source(scope, transaction_style, request):
118118
pass
119119

120120

121-
@ensure_integration_enabled(FlaskIntegration)
122121
def _request_started(app, **kwargs):
123122
# type: (Flask, **Any) -> None
124123
integration = sentry_sdk.get_client().get_integration(FlaskIntegration)
124+
if integration is None:
125+
return
126+
125127
request = flask_request._get_current_object()
126128

127129
# Set the transaction name and source here,

sentry_sdk/integrations/gcp.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import sys
23
from copy import deepcopy
34
from datetime import datetime, timedelta, timezone
@@ -13,7 +14,6 @@
1314
from sentry_sdk.utils import (
1415
AnnotatedValue,
1516
capture_internal_exceptions,
16-
ensure_integration_enabled,
1717
event_from_exception,
1818
logger,
1919
TimeoutThread,
@@ -39,12 +39,14 @@
3939

4040
def _wrap_func(func):
4141
# type: (F) -> F
42-
@ensure_integration_enabled(GcpIntegration, func)
42+
@functools.wraps(func)
4343
def sentry_func(functionhandler, gcp_event, *args, **kwargs):
4444
# type: (Any, Any, *Any, **Any) -> Any
4545
client = sentry_sdk.get_client()
4646

4747
integration = client.get_integration(GcpIntegration)
48+
if integration is None:
49+
return func(functionhandler, gcp_event, *args, **kwargs)
4850

4951
configured_time = environ.get("FUNCTION_TIMEOUT_SEC")
5052
if not configured_time:

0 commit comments

Comments
 (0)