Skip to content

Commit 2305fea

Browse files
committed
fix: Fix celery integration with custom tag class
1 parent f0bbd04 commit 2305fea

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

sentry_sdk/integrations/celery.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def sentry_build_tracer(name, task, *args, **kwargs):
6161
# short-circuits to task.run if it thinks it's safe.
6262
task.__call__ = _wrap_task_call(task, task.__call__)
6363
task.run = _wrap_task_call(task, task.run)
64+
task.apply_async = _wrap_apply_async(task, task.apply_async)
6465

6566
# `build_tracer` is apparently called for every task
6667
# invocation. Can't wrap every celery task for every invocation
@@ -71,10 +72,6 @@ def sentry_build_tracer(name, task, *args, **kwargs):
7172

7273
trace.build_tracer = sentry_build_tracer
7374

74-
from celery.app.task import Task # type: ignore
75-
76-
Task.apply_async = _wrap_apply_async(Task.apply_async)
77-
7875
_patch_worker_exit()
7976

8077
# This logger logs every status of every task that ran on the worker.
@@ -88,15 +85,15 @@ def sentry_build_tracer(name, task, *args, **kwargs):
8885
ignore_logger("celery.redirected")
8986

9087

91-
def _wrap_apply_async(f):
92-
# type: (F) -> F
88+
def _wrap_apply_async(task, f):
89+
# type: (Any, F) -> F
9390
@wraps(f)
9491
def apply_async(*args, **kwargs):
9592
# type: (*Any, **Any) -> Any
9693
hub = Hub.current
9794
integration = hub.get_integration(CeleryIntegration)
9895
if integration is not None and integration.propagate_traces:
99-
with hub.start_span(op="celery.submit", description=args[0].name):
96+
with hub.start_span(op="celery.submit", description=task.name):
10097
with capture_internal_exceptions():
10198
headers = dict(hub.iter_trace_propagation_headers())
10299

tests/integrations/celery/test_celery.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,16 @@ def celery_invocation(request):
9999
"""
100100
return request.param
101101

102+
@pytest.mark.parametrize("custom_celery_task_cls", (True, False))
103+
def test_simple(capture_events, celery, celery_invocation, custom_celery_task_cls):
104+
105+
if custom_celery_task_cls:
106+
class CustomTask(celery.Task):
107+
def __call__(self, *args, **kwargs):
108+
return self.run(*args, **kwargs)
109+
110+
celery.Task = CustomTask
102111

103-
def test_simple(capture_events, celery, celery_invocation):
104112
events = capture_events()
105113

106114
@celery.task(name="dummy_task")

0 commit comments

Comments
 (0)