12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import concurrent .futures
16
+ import random
15
17
import unittest
16
18
from unittest import mock
17
19
@@ -222,6 +224,15 @@ def test_ungrouped_batcher_process_not_stateful(self):
222
224
223
225
224
226
class TestCounterAggregator (unittest .TestCase ):
227
+ @staticmethod
228
+ def call_update (counter ):
229
+ update_total = 0
230
+ for _ in range (0 , 100000 ):
231
+ val = random .getrandbits (32 )
232
+ counter .update (val )
233
+ update_total += val
234
+ return update_total
235
+
225
236
def test_update (self ):
226
237
counter = CounterAggregator ()
227
238
counter .update (1.0 )
@@ -243,13 +254,58 @@ def test_merge(self):
243
254
counter .merge (counter2 )
244
255
self .assertEqual (counter .checkpoint , 4.0 )
245
256
257
+ def test_concurrent_update (self ):
258
+ counter = CounterAggregator ()
259
+
260
+ with concurrent .futures .ThreadPoolExecutor (max_workers = 2 ) as executor :
261
+ fut1 = executor .submit (self .call_update , counter )
262
+ fut2 = executor .submit (self .call_update , counter )
263
+
264
+ updapte_total = fut1 .result () + fut2 .result ()
265
+
266
+ counter .take_checkpoint ()
267
+ self .assertEqual (updapte_total , counter .checkpoint )
268
+
269
+ def test_concurrent_update_and_checkpoint (self ):
270
+ counter = CounterAggregator ()
271
+ checkpoint_total = 0
272
+
273
+ with concurrent .futures .ThreadPoolExecutor (max_workers = 1 ) as executor :
274
+ fut = executor .submit (self .call_update , counter )
275
+
276
+ while not fut .done ():
277
+ counter .take_checkpoint ()
278
+ checkpoint_total += counter .checkpoint
279
+
280
+ counter .take_checkpoint ()
281
+ checkpoint_total += counter .checkpoint
282
+
283
+ self .assertEqual (fut .result (), checkpoint_total )
284
+
246
285
247
286
class TestMinMaxSumCountAggregator (unittest .TestCase ):
287
+ @staticmethod
288
+ def call_update (mmsc ):
289
+ min_ = float ("inf" )
290
+ max_ = float ("-inf" )
291
+ sum_ = 0
292
+ count_ = 0
293
+ for _ in range (0 , 100000 ):
294
+ val = random .getrandbits (32 )
295
+ mmsc .update (val )
296
+ if val < min_ :
297
+ min_ = val
298
+ if val > max_ :
299
+ max_ = val
300
+ sum_ += val
301
+ count_ += 1
302
+ return MinMaxSumCountAggregator ._TYPE (min_ , max_ , sum_ , count_ )
303
+
248
304
def test_update (self ):
249
305
mmsc = MinMaxSumCountAggregator ()
250
306
# test current values without any update
251
307
self .assertEqual (
252
- mmsc .current , ( None , None , None , 0 ) ,
308
+ mmsc .current , MinMaxSumCountAggregator . _EMPTY ,
253
309
)
254
310
255
311
# call update with some values
@@ -267,7 +323,7 @@ def test_checkpoint(self):
267
323
# take checkpoint wihtout any update
268
324
mmsc .take_checkpoint ()
269
325
self .assertEqual (
270
- mmsc .checkpoint , ( None , None , None , 0 ) ,
326
+ mmsc .checkpoint , MinMaxSumCountAggregator . _EMPTY ,
271
327
)
272
328
273
329
# call update with some values
@@ -282,7 +338,7 @@ def test_checkpoint(self):
282
338
)
283
339
284
340
self .assertEqual (
285
- mmsc .current , ( None , None , None , 0 ) ,
341
+ mmsc .current , MinMaxSumCountAggregator . _EMPTY ,
286
342
)
287
343
288
344
def test_merge (self ):
@@ -299,14 +355,34 @@ def test_merge(self):
299
355
300
356
self .assertEqual (
301
357
mmsc1 .checkpoint ,
302
- (
303
- min (checkpoint1 .min , checkpoint2 .min ),
304
- max (checkpoint1 .max , checkpoint2 .max ),
305
- checkpoint1 .sum + checkpoint2 .sum ,
306
- checkpoint1 .count + checkpoint2 .count ,
358
+ MinMaxSumCountAggregator ._merge_checkpoint (
359
+ checkpoint1 , checkpoint2
307
360
),
308
361
)
309
362
363
+ def test_merge_checkpoint (self ):
364
+ func = MinMaxSumCountAggregator ._merge_checkpoint
365
+ _type = MinMaxSumCountAggregator ._TYPE
366
+ empty = MinMaxSumCountAggregator ._EMPTY
367
+
368
+ ret = func (empty , empty )
369
+ self .assertEqual (ret , empty )
370
+
371
+ ret = func (empty , _type (0 , 0 , 0 , 0 ))
372
+ self .assertEqual (ret , _type (0 , 0 , 0 , 0 ))
373
+
374
+ ret = func (_type (0 , 0 , 0 , 0 ), empty )
375
+ self .assertEqual (ret , _type (0 , 0 , 0 , 0 ))
376
+
377
+ ret = func (_type (0 , 0 , 0 , 0 ), _type (0 , 0 , 0 , 0 ))
378
+ self .assertEqual (ret , _type (0 , 0 , 0 , 0 ))
379
+
380
+ ret = func (_type (44 , 23 , 55 , 86 ), empty )
381
+ self .assertEqual (ret , _type (44 , 23 , 55 , 86 ))
382
+
383
+ ret = func (_type (3 , 150 , 101 , 3 ), _type (1 , 33 , 44 , 2 ))
384
+ self .assertEqual (ret , _type (1 , 150 , 101 + 44 , 2 + 3 ))
385
+
310
386
def test_merge_with_empty (self ):
311
387
mmsc1 = MinMaxSumCountAggregator ()
312
388
mmsc2 = MinMaxSumCountAggregator ()
@@ -318,6 +394,42 @@ def test_merge_with_empty(self):
318
394
319
395
self .assertEqual (mmsc1 .checkpoint , checkpoint1 )
320
396
397
+ def test_concurrent_update (self ):
398
+ mmsc = MinMaxSumCountAggregator ()
399
+ with concurrent .futures .ThreadPoolExecutor (max_workers = 2 ) as ex :
400
+ fut1 = ex .submit (self .call_update , mmsc )
401
+ fut2 = ex .submit (self .call_update , mmsc )
402
+
403
+ ret1 = fut1 .result ()
404
+ ret2 = fut2 .result ()
405
+
406
+ update_total = MinMaxSumCountAggregator ._merge_checkpoint (
407
+ ret1 , ret2
408
+ )
409
+ mmsc .take_checkpoint ()
410
+
411
+ self .assertEqual (update_total , mmsc .checkpoint )
412
+
413
+ def test_concurrent_update_and_checkpoint (self ):
414
+ mmsc = MinMaxSumCountAggregator ()
415
+ checkpoint_total = MinMaxSumCountAggregator ._TYPE (2 ** 32 , 0 , 0 , 0 )
416
+
417
+ with concurrent .futures .ThreadPoolExecutor (max_workers = 1 ) as ex :
418
+ fut = ex .submit (self .call_update , mmsc )
419
+
420
+ while not fut .done ():
421
+ mmsc .take_checkpoint ()
422
+ checkpoint_total = MinMaxSumCountAggregator ._merge_checkpoint (
423
+ checkpoint_total , mmsc .checkpoint
424
+ )
425
+
426
+ mmsc .take_checkpoint ()
427
+ checkpoint_total = MinMaxSumCountAggregator ._merge_checkpoint (
428
+ checkpoint_total , mmsc .checkpoint
429
+ )
430
+
431
+ self .assertEqual (checkpoint_total , fut .result ())
432
+
321
433
322
434
class TestController (unittest .TestCase ):
323
435
def test_push_controller (self ):
0 commit comments