Skip to content

Commit 4bf4859

Browse files
authored
fix: Second attempt at fixing trace propagation in Celery 4.2+ (getsentry#831)
Follow-up to getsentry#824 and getsentry#825
1 parent b0f2f41 commit 4bf4859

File tree

3 files changed

+26
-16
lines changed

3 files changed

+26
-16
lines changed

sentry_sdk/integrations/celery.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def sentry_build_tracer(name, task, *args, **kwargs):
6161
# short-circuits to task.run if it thinks it's safe.
6262
task.__call__ = _wrap_task_call(task, task.__call__)
6363
task.run = _wrap_task_call(task, task.run)
64-
task.apply_async = _wrap_apply_async(task, task.apply_async)
6564

6665
# `build_tracer` is apparently called for every task
6766
# invocation. Can't wrap every celery task for every invocation
@@ -72,6 +71,10 @@ def sentry_build_tracer(name, task, *args, **kwargs):
7271

7372
trace.build_tracer = sentry_build_tracer
7473

74+
from celery.app.task import Task # type: ignore
75+
76+
Task.apply_async = _wrap_apply_async(Task.apply_async)
77+
7578
_patch_worker_exit()
7679

7780
# This logger logs every status of every task that ran on the worker.
@@ -85,30 +88,31 @@ def sentry_build_tracer(name, task, *args, **kwargs):
8588
ignore_logger("celery.redirected")
8689

8790

88-
def _wrap_apply_async(task, f):
89-
# type: (Any, F) -> F
91+
def _wrap_apply_async(f):
92+
# type: (F) -> F
9093
@wraps(f)
9194
def apply_async(*args, **kwargs):
9295
# type: (*Any, **Any) -> Any
9396
hub = Hub.current
9497
integration = hub.get_integration(CeleryIntegration)
9598
if integration is not None and integration.propagate_traces:
96-
with hub.start_span(op="celery.submit", description=task.name):
99+
with hub.start_span(op="celery.submit", description=args[0].name):
97100
with capture_internal_exceptions():
98101
headers = dict(hub.iter_trace_propagation_headers())
102+
99103
if headers:
100-
kwarg_headers = kwargs.setdefault("headers", {})
104+
# Note: kwargs can contain headers=None, so no setdefault!
105+
# Unsure which backend though.
106+
kwarg_headers = kwargs.get("headers") or {}
101107
kwarg_headers.update(headers)
102108

103109
# https://github.com/celery/celery/issues/4875
104110
#
105111
# Need to setdefault the inner headers too since other
106112
# tracing tools (dd-trace-py) also employ this exact
107113
# workaround and we don't want to break them.
108-
#
109-
# This is not reproducible outside of AMQP, therefore no
110-
# tests!
111114
kwarg_headers.setdefault("headers", {}).update(headers)
115+
kwargs["headers"] = kwarg_headers
112116

113117
return f(*args, **kwargs)
114118
else:

tests/conftest.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,7 @@ def append_envelope(envelope):
235235
@pytest.fixture
236236
def capture_events_forksafe(monkeypatch, capture_events, request):
237237
def inner():
238-
in_process_events = capture_events()
239-
240-
@request.addfinalizer
241-
def _():
242-
assert not in_process_events
238+
capture_events()
243239

244240
events_r, events_w = os.pipe()
245241
events_r = os.fdopen(events_r, "rb", 0)

tests/integrations/celery/test_celery.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def inner(propagate_traces=True, backend="always_eager", **kwargs):
4242

4343
# this backend requires capture_events_forksafe
4444
celery.conf.worker_max_tasks_per_child = 1
45+
celery.conf.worker_concurrency = 1
4546
celery.conf.broker_url = "redis://127.0.0.1:6379"
4647
celery.conf.result_backend = "redis://127.0.0.1:6379"
4748
celery.conf.task_always_eager = False
@@ -297,7 +298,7 @@ def dummy_task(self):
297298

298299

299300
@pytest.mark.forked
300-
def test_redis_backend(init_celery, capture_events_forksafe, tmpdir):
301+
def test_redis_backend_trace_propagation(init_celery, capture_events_forksafe, tmpdir):
301302
celery = init_celery(traces_sample_rate=1.0, backend="redis", debug=True)
302303

303304
events = capture_events_forksafe()
@@ -309,8 +310,9 @@ def dummy_task(self):
309310
runs.append(1)
310311
1 / 0
311312

312-
# Curious: Cannot use delay() here or py2.7-celery-4.2 crashes
313-
res = dummy_task.apply_async()
313+
with start_transaction(name="submit_celery"):
314+
# Curious: Cannot use delay() here or py2.7-celery-4.2 crashes
315+
res = dummy_task.apply_async()
314316

315317
with pytest.raises(Exception):
316318
# Celery 4.1 raises a gibberish exception
@@ -319,6 +321,13 @@ def dummy_task(self):
319321
# if this is nonempty, the worker never really forked
320322
assert not runs
321323

324+
submit_transaction = events.read_event()
325+
assert submit_transaction["type"] == "transaction"
326+
assert submit_transaction["transaction"] == "submit_celery"
327+
(span,) = submit_transaction["spans"]
328+
assert span["op"] == "celery.submit"
329+
assert span["description"] == "dummy_task"
330+
322331
event = events.read_event()
323332
(exception,) = event["exception"]["values"]
324333
assert exception["type"] == "ZeroDivisionError"
@@ -327,6 +336,7 @@ def dummy_task(self):
327336
assert (
328337
transaction["contexts"]["trace"]["trace_id"]
329338
== event["contexts"]["trace"]["trace_id"]
339+
== submit_transaction["contexts"]["trace"]["trace_id"]
330340
)
331341

332342
events.read_flush()

0 commit comments

Comments (0)