@@ -23,12 +23,12 @@ def inner(signal, f):
23
23
24
24
@pytest .fixture
25
25
def init_celery (sentry_init , request ):
26
- def inner (propagate_traces = True , backend = "always_eager" , ** kwargs ):
26
+ def inner (propagate_traces = True , backend = "always_eager" , celery_kwargs = None , ** kwargs ):
27
27
sentry_init (
28
28
integrations = [CeleryIntegration (propagate_traces = propagate_traces )],
29
29
** kwargs
30
30
)
31
- celery = Celery (__name__ )
31
+ celery = Celery (__name__ , ** ( celery_kwargs or {}) )
32
32
33
33
if backend == "always_eager" :
34
34
if VERSION < (4 ,):
@@ -100,20 +100,7 @@ def celery_invocation(request):
100
100
return request .param
101
101
102
102
103
- @pytest .mark .parametrize ("custom_celery_task_cls" , (True , False ))
104
- def test_simple (capture_events , celery , celery_invocation , custom_celery_task_cls ):
105
-
106
- if custom_celery_task_cls :
107
-
108
- custom_calls = []
109
-
110
- class CustomTask (celery .Task ):
111
- def __call__ (self , * args , ** kwargs ):
112
- custom_calls .append (1 )
113
- return self .run (* args , ** kwargs )
114
-
115
- celery .Task = CustomTask
116
-
103
+ def test_simple (capture_events , celery , celery_invocation ):
117
104
events = capture_events ()
118
105
119
106
@celery .task (name = "dummy_task" )
@@ -125,8 +112,6 @@ def dummy_task(x, y):
125
112
celery_invocation (dummy_task , 1 , 2 )
126
113
_ , expected_context = celery_invocation (dummy_task , 1 , 0 )
127
114
128
- assert not custom_celery_task_cls or custom_calls
129
-
130
115
(event ,) = events
131
116
132
117
assert event ["contexts" ]["trace" ]["trace_id" ] == transaction .trace_id
@@ -322,9 +307,22 @@ def dummy_task(self):
322
307
323
308
324
309
@pytest .mark .forked
325
- def test_redis_backend_trace_propagation (init_celery , capture_events_forksafe , tmpdir ):
310
+ @pytest .mark .parametrize ("custom_celery_task_cls" , (True , False ))
311
+ def test_redis_backend_trace_propagation (init_celery , capture_events_forksafe , tmpdir , custom_celery_task_cls ):
326
312
celery = init_celery (traces_sample_rate = 1.0 , backend = "redis" , debug = True )
327
313
314
+ custom_calls = []
315
+
316
+ if custom_celery_task_cls :
317
+ import celery as celery_mod
318
+
319
+ class CustomTask (celery_mod .Task ):
320
+ def __call__ (self , * args , ** kwargs ):
321
+ custom_calls .append (1 )
322
+ return self .run (* args , ** kwargs )
323
+
324
+ celery .Task = CustomTask
325
+
328
326
events = capture_events_forksafe ()
329
327
330
328
runs = []
@@ -368,6 +366,8 @@ def dummy_task(self):
368
366
# if this is nonempty, the worker never really forked
369
367
assert not runs
370
368
369
+ assert not custom_celery_task_cls or custom_calls
370
+
371
371
372
372
@pytest .mark .forked
373
373
@pytest .mark .parametrize ("newrelic_order" , ["sentry_first" , "sentry_last" ])
0 commit comments