Skip to content

Commit 2578b37

Browse files
Make Counter and MinMaxSumCount aggregators thread safe (open-telemetry#439)
1 parent 5b2e693 commit 2578b37

File tree

2 files changed

+161
-44
lines changed

2 files changed

+161
-44
lines changed

opentelemetry-sdk/src/opentelemetry/sdk/metrics/export/aggregate.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import abc
16+
import threading
1617
from collections import namedtuple
1718

1819

@@ -47,62 +48,66 @@ def __init__(self):
4748
super().__init__()
4849
self.current = 0
4950
self.checkpoint = 0
51+
self._lock = threading.Lock()
5052

5153
def update(self, value):
52-
self.current += value
54+
with self._lock:
55+
self.current += value
5356

5457
def take_checkpoint(self):
55-
self.checkpoint = self.current
56-
self.current = 0
58+
with self._lock:
59+
self.checkpoint = self.current
60+
self.current = 0
5761

5862
def merge(self, other):
59-
self.checkpoint += other.checkpoint
63+
with self._lock:
64+
self.checkpoint += other.checkpoint
6065

6166

6267
class MinMaxSumCountAggregator(Aggregator):
6368
"""Aggregator for Measure metrics that keeps min, max, sum and count."""
6469

6570
_TYPE = namedtuple("minmaxsumcount", "min max sum count")
71+
_EMPTY = _TYPE(None, None, None, 0)
6672

6773
@classmethod
68-
def _min(cls, val1, val2):
69-
if val1 is None and val2 is None:
70-
return None
71-
return min(val1 or val2, val2 or val1)
72-
73-
@classmethod
74-
def _max(cls, val1, val2):
75-
if val1 is None and val2 is None:
76-
return None
77-
return max(val1 or val2, val2 or val1)
78-
79-
@classmethod
80-
def _sum(cls, val1, val2):
81-
if val1 is None and val2 is None:
82-
return None
83-
return (val1 or 0) + (val2 or 0)
74+
def _merge_checkpoint(cls, val1, val2):
75+
if val1 is cls._EMPTY:
76+
return val2
77+
if val2 is cls._EMPTY:
78+
return val1
79+
return cls._TYPE(
80+
min(val1.min, val2.min),
81+
max(val1.max, val2.max),
82+
val1.sum + val2.sum,
83+
val1.count + val2.count,
84+
)
8485

8586
def __init__(self):
8687
super().__init__()
87-
self.current = self._TYPE(None, None, None, 0)
88-
self.checkpoint = self._TYPE(None, None, None, 0)
88+
self.current = self._EMPTY
89+
self.checkpoint = self._EMPTY
90+
self._lock = threading.Lock()
8991

9092
def update(self, value):
91-
self.current = self._TYPE(
92-
self._min(self.current.min, value),
93-
self._max(self.current.max, value),
94-
self._sum(self.current.sum, value),
95-
self.current.count + 1,
96-
)
93+
with self._lock:
94+
if self.current is self._EMPTY:
95+
self.current = self._TYPE(value, value, value, 1)
96+
else:
97+
self.current = self._TYPE(
98+
min(self.current.min, value),
99+
max(self.current.max, value),
100+
self.current.sum + value,
101+
self.current.count + 1,
102+
)
97103

98104
def take_checkpoint(self):
99-
self.checkpoint = self.current
100-
self.current = self._TYPE(None, None, None, 0)
105+
with self._lock:
106+
self.checkpoint = self.current
107+
self.current = self._EMPTY
101108

102109
def merge(self, other):
103-
self.checkpoint = self._TYPE(
104-
self._min(self.checkpoint.min, other.checkpoint.min),
105-
self._max(self.checkpoint.max, other.checkpoint.max),
106-
self._sum(self.checkpoint.sum, other.checkpoint.sum),
107-
self.checkpoint.count + other.checkpoint.count,
108-
)
110+
with self._lock:
111+
self.checkpoint = self._merge_checkpoint(
112+
self.checkpoint, other.checkpoint
113+
)

opentelemetry-sdk/tests/metrics/export/test_export.py

Lines changed: 120 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import concurrent.futures
16+
import random
1517
import unittest
1618
from unittest import mock
1719

@@ -222,6 +224,15 @@ def test_ungrouped_batcher_process_not_stateful(self):
222224

223225

224226
class TestCounterAggregator(unittest.TestCase):
227+
@staticmethod
228+
def call_update(counter):
229+
update_total = 0
230+
for _ in range(0, 100000):
231+
val = random.getrandbits(32)
232+
counter.update(val)
233+
update_total += val
234+
return update_total
235+
225236
def test_update(self):
226237
counter = CounterAggregator()
227238
counter.update(1.0)
@@ -243,13 +254,58 @@ def test_merge(self):
243254
counter.merge(counter2)
244255
self.assertEqual(counter.checkpoint, 4.0)
245256

257+
def test_concurrent_update(self):
258+
counter = CounterAggregator()
259+
260+
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
261+
fut1 = executor.submit(self.call_update, counter)
262+
fut2 = executor.submit(self.call_update, counter)
263+
264+
updapte_total = fut1.result() + fut2.result()
265+
266+
counter.take_checkpoint()
267+
self.assertEqual(updapte_total, counter.checkpoint)
268+
269+
def test_concurrent_update_and_checkpoint(self):
270+
counter = CounterAggregator()
271+
checkpoint_total = 0
272+
273+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
274+
fut = executor.submit(self.call_update, counter)
275+
276+
while not fut.done():
277+
counter.take_checkpoint()
278+
checkpoint_total += counter.checkpoint
279+
280+
counter.take_checkpoint()
281+
checkpoint_total += counter.checkpoint
282+
283+
self.assertEqual(fut.result(), checkpoint_total)
284+
246285

247286
class TestMinMaxSumCountAggregator(unittest.TestCase):
287+
@staticmethod
288+
def call_update(mmsc):
289+
min_ = float("inf")
290+
max_ = float("-inf")
291+
sum_ = 0
292+
count_ = 0
293+
for _ in range(0, 100000):
294+
val = random.getrandbits(32)
295+
mmsc.update(val)
296+
if val < min_:
297+
min_ = val
298+
if val > max_:
299+
max_ = val
300+
sum_ += val
301+
count_ += 1
302+
return MinMaxSumCountAggregator._TYPE(min_, max_, sum_, count_)
303+
248304
def test_update(self):
249305
mmsc = MinMaxSumCountAggregator()
250306
# test current values without any update
251307
self.assertEqual(
252-
mmsc.current, (None, None, None, 0),
308+
mmsc.current, MinMaxSumCountAggregator._EMPTY,
253309
)
254310

255311
# call update with some values
@@ -267,7 +323,7 @@ def test_checkpoint(self):
267323
# take checkpoint without any update
268324
mmsc.take_checkpoint()
269325
self.assertEqual(
270-
mmsc.checkpoint, (None, None, None, 0),
326+
mmsc.checkpoint, MinMaxSumCountAggregator._EMPTY,
271327
)
272328

273329
# call update with some values
@@ -282,7 +338,7 @@ def test_checkpoint(self):
282338
)
283339

284340
self.assertEqual(
285-
mmsc.current, (None, None, None, 0),
341+
mmsc.current, MinMaxSumCountAggregator._EMPTY,
286342
)
287343

288344
def test_merge(self):
@@ -299,14 +355,34 @@ def test_merge(self):
299355

300356
self.assertEqual(
301357
mmsc1.checkpoint,
302-
(
303-
min(checkpoint1.min, checkpoint2.min),
304-
max(checkpoint1.max, checkpoint2.max),
305-
checkpoint1.sum + checkpoint2.sum,
306-
checkpoint1.count + checkpoint2.count,
358+
MinMaxSumCountAggregator._merge_checkpoint(
359+
checkpoint1, checkpoint2
307360
),
308361
)
309362

363+
def test_merge_checkpoint(self):
364+
func = MinMaxSumCountAggregator._merge_checkpoint
365+
_type = MinMaxSumCountAggregator._TYPE
366+
empty = MinMaxSumCountAggregator._EMPTY
367+
368+
ret = func(empty, empty)
369+
self.assertEqual(ret, empty)
370+
371+
ret = func(empty, _type(0, 0, 0, 0))
372+
self.assertEqual(ret, _type(0, 0, 0, 0))
373+
374+
ret = func(_type(0, 0, 0, 0), empty)
375+
self.assertEqual(ret, _type(0, 0, 0, 0))
376+
377+
ret = func(_type(0, 0, 0, 0), _type(0, 0, 0, 0))
378+
self.assertEqual(ret, _type(0, 0, 0, 0))
379+
380+
ret = func(_type(44, 23, 55, 86), empty)
381+
self.assertEqual(ret, _type(44, 23, 55, 86))
382+
383+
ret = func(_type(3, 150, 101, 3), _type(1, 33, 44, 2))
384+
self.assertEqual(ret, _type(1, 150, 101 + 44, 2 + 3))
385+
310386
def test_merge_with_empty(self):
311387
mmsc1 = MinMaxSumCountAggregator()
312388
mmsc2 = MinMaxSumCountAggregator()
@@ -318,6 +394,42 @@ def test_merge_with_empty(self):
318394

319395
self.assertEqual(mmsc1.checkpoint, checkpoint1)
320396

397+
def test_concurrent_update(self):
398+
mmsc = MinMaxSumCountAggregator()
399+
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as ex:
400+
fut1 = ex.submit(self.call_update, mmsc)
401+
fut2 = ex.submit(self.call_update, mmsc)
402+
403+
ret1 = fut1.result()
404+
ret2 = fut2.result()
405+
406+
update_total = MinMaxSumCountAggregator._merge_checkpoint(
407+
ret1, ret2
408+
)
409+
mmsc.take_checkpoint()
410+
411+
self.assertEqual(update_total, mmsc.checkpoint)
412+
413+
def test_concurrent_update_and_checkpoint(self):
414+
mmsc = MinMaxSumCountAggregator()
415+
checkpoint_total = MinMaxSumCountAggregator._TYPE(2 ** 32, 0, 0, 0)
416+
417+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex:
418+
fut = ex.submit(self.call_update, mmsc)
419+
420+
while not fut.done():
421+
mmsc.take_checkpoint()
422+
checkpoint_total = MinMaxSumCountAggregator._merge_checkpoint(
423+
checkpoint_total, mmsc.checkpoint
424+
)
425+
426+
mmsc.take_checkpoint()
427+
checkpoint_total = MinMaxSumCountAggregator._merge_checkpoint(
428+
checkpoint_total, mmsc.checkpoint
429+
)
430+
431+
self.assertEqual(checkpoint_total, fut.result())
432+
321433

322434
class TestController(unittest.TestCase):
323435
def test_push_controller(self):

0 commit comments

Comments
 (0)