8
8
import time
9
9
import sys
10
10
11
- class TestThreadTaskNode ( InputIteratorThreadTask ):
11
+ class _TestTaskBase ( object ):
12
12
def __init__ (self , * args , ** kwargs ):
13
- super (TestThreadTaskNode , self ).__init__ (* args , ** kwargs )
13
+ super (_TestTaskBase , self ).__init__ (* args , ** kwargs )
14
14
self .should_fail = False
15
15
self .lock = threading .Lock () # yes, can't safely do x = x + 1 :)
16
16
self .plock = threading .Lock ()
17
17
self .item_count = 0
18
18
self .process_count = 0
19
- self ._scheduled_items = 0
20
19
21
20
def do_fun (self , item ):
22
21
self .lock .acquire ()
@@ -32,44 +31,118 @@ def process(self, count=1):
32
31
self .plock .acquire ()
33
32
self .process_count += 1
34
33
self .plock .release ()
35
- super (TestThreadTaskNode , self ).process (count )
34
+ super (_TestTaskBase , self ).process (count )
36
35
37
36
def _assert (self , pc , fc , check_scheduled = False ):
38
37
"""Assert for num process counts (pc) and num function counts (fc)
39
38
:return: self"""
40
- # TODO: fixme
41
- return self
42
- self .plock .acquire ()
43
- if self .process_count != pc :
44
- print self .process_count , pc
45
- assert self .process_count == pc
46
- self .plock .release ()
47
39
self .lock .acquire ()
48
40
if self .item_count != fc :
49
41
print self .item_count , fc
50
42
assert self .item_count == fc
51
43
self .lock .release ()
52
44
53
- # if we read all, we can't really use scheduled items
54
- if check_scheduled :
55
- assert self ._scheduled_items == 0
56
- assert not self .error ()
57
45
return self
46
+
47
+ class TestThreadTaskNode (_TestTaskBase , InputIteratorThreadTask ):
48
+ pass
58
49
59
50
60
51
class TestThreadFailureNode (TestThreadTaskNode ):
61
52
"""Fails after X items"""
53
+ def __init__ (self , * args , ** kwargs ):
54
+ self .fail_after = kwargs .pop ('fail_after' )
55
+ super (TestThreadFailureNode , self ).__init__ (* args , ** kwargs )
62
56
57
+ def do_fun (self , item ):
58
+ item = TestThreadTaskNode .do_fun (self , item )
59
+ if self .item_count > self .fail_after :
60
+ raise AssertionError ("Simulated failure after processing %i items" % self .fail_after )
61
+ return item
62
+
63
+
64
+ class TestThreadInputChannelTaskNode (_TestTaskBase , InputChannelTask ):
65
+ """Apply a transformation on items read from an input channel"""
66
+
67
+ def do_fun (self , item ):
68
+ """return tuple(i, i*2)"""
69
+ item = super (TestThreadInputChannelTaskNode , self ).do_fun (item )
70
+ if isinstance (item , tuple ):
71
+ i = item [0 ]
72
+ return item + (i * self .id , )
73
+ else :
74
+ return (item , item * self .id )
75
+ # END handle tuple
76
+
77
+
78
+ class TestThreadInputChannelVerifyTaskNode (_TestTaskBase , InputChannelTask ):
79
+ """An input channel task, which verifies the result of its input channels,
80
+ should be last in the chain.
81
+ Id must be int"""
82
+
83
+ def do_fun (self , item ):
84
+ """return tuple(i, i*2)"""
85
+ item = super (TestThreadInputChannelTaskNode , self ).do_fun (item )
86
+
87
+ # make sure the computation order matches
88
+ assert isinstance (item , tuple )
89
+
90
+ base = item [0 ]
91
+ for num in item [1 :]:
92
+ assert num == base * 2
93
+ base = num
94
+ # END verify order
95
+
96
+ return item
97
+
98
+
63
99
64
100
class TestThreadPool (TestBase ):
65
101
66
102
max_threads = cpu_count ()
67
103
68
- def _add_triple_task (self , p ):
69
- """Add a triplet of feeder, transformer and finalizer to the pool, like
70
- t1 -> t2 -> t3, return all 3 return channels in order"""
71
- # t1 = TestThreadTaskNode(make_task(), 'iterator', None)
72
- # TODO:
104
+ def _add_task_chain (self , p , ni , count = 1 ):
105
+ """Create a task chain of feeder, count transformers and order verifcator
106
+ to the pool p, like t1 -> t2 -> t3
107
+ :return: tuple(list(task1, taskN, ...), list(rc1, rcN, ...))"""
108
+ nt = p .num_tasks ()
109
+
110
+ feeder = self ._make_iterator_task (ni )
111
+ frc = p .add_task (feeder )
112
+
113
+ assert p .num_tasks () == nt + 1
114
+
115
+ rcs = [frc ]
116
+ tasks = [feeder ]
117
+
118
+ inrc = frc
119
+ for tc in xrange (count ):
120
+ t = TestThreadInputChannelTaskNode (inrc , tc , None )
121
+ t .fun = t .do_fun
122
+ inrc = p .add_task (t )
123
+
124
+ tasks .append (t )
125
+ rcs .append (inrc )
126
+ assert p .num_tasks () == nt + 2 + tc
127
+ # END create count transformers
128
+
129
+ verifier = TestThreadInputChannelVerifyTaskNode (inrc , 'verifier' , None )
130
+ verifier .fun = verifier .do_fun
131
+ vrc = p .add_task (verifier )
132
+
133
+ assert p .num_tasks () == nt + tc + 3
134
+
135
+ tasks .append (verifier )
136
+ rcs .append (vrc )
137
+ return tasks , rcs
138
+
139
+ def _make_iterator_task (self , ni , taskcls = TestThreadTaskNode , ** kwargs ):
140
+ """:return: task which yields ni items
141
+ :param taskcls: the actual iterator type to use
142
+ :param **kwargs: additional kwargs to be passed to the task"""
143
+ t = taskcls (iter (range (ni )), 'iterator' , None , ** kwargs )
144
+ t .fun = t .do_fun
145
+ return t
73
146
74
147
def _assert_single_task (self , p , async = False ):
75
148
"""Performs testing in a synchronized environment"""
@@ -82,11 +155,7 @@ def _assert_single_task(self, p, async=False):
82
155
assert ni % 2 == 0 , "ni needs to be dividable by 2"
83
156
assert ni % 4 == 0 , "ni needs to be dividable by 4"
84
157
85
- def make_task ():
86
- t = TestThreadTaskNode (iter (range (ni )), 'iterator' , None )
87
- t .fun = t .do_fun
88
- return t
89
- # END utility
158
+ make_task = lambda * args , ** kwargs : self ._make_iterator_task (ni , * args , ** kwargs )
90
159
91
160
task = make_task ()
92
161
@@ -252,15 +321,44 @@ def make_task():
252
321
253
322
# test failure after ni / 2 items
254
323
# This makes sure it correctly closes the channel on failure to prevent blocking
324
+ nri = ni / 2
325
+ task = make_task (TestThreadFailureNode , fail_after = ni / 2 )
326
+ rc = p .add_task (task )
327
+ assert len (rc .read ()) == nri
328
+ assert task .is_done ()
329
+ assert isinstance (task .error (), AssertionError )
255
330
256
331
257
332
258
- def _assert_async_dependent_tasks (self , p ):
333
+ def _assert_async_dependent_tasks (self , pool ):
259
334
# includes failure in center task, 'recursive' orphan cleanup
260
335
# This will also verify that the channel-close mechanism works
261
336
# t1 -> t2 -> t3
262
337
# t1 -> x -> t3
263
- pass
338
+ null_tasks = pool .num_tasks ()
339
+ ni = 100
340
+ count = 1
341
+ make_task = lambda * args , ** kwargs : self ._add_task_chain (pool , ni , count , * args , ** kwargs )
342
+
343
+ ts , rcs = make_task ()
344
+ assert len (ts ) == count + 2
345
+ assert len (rcs ) == count + 2
346
+ assert pool .num_tasks () == null_tasks + len (ts )
347
+ print pool ._tasks .nodes
348
+
349
+
350
+ # in the end, we expect all tasks to be gone, automatically
351
+
352
+
353
+
354
+ # order of deletion matters - just keep the end, then delete
355
+ final_rc = rcs [- 1 ]
356
+ del (ts )
357
+ del (rcs )
358
+ del (final_rc )
359
+ assert pool .num_tasks () == null_tasks
360
+
361
+
264
362
265
363
@terminate_threads
266
364
def test_base (self ):
@@ -301,8 +399,8 @@ def test_base(self):
301
399
assert p .num_tasks () == 0
302
400
303
401
304
- # DEPENDENT TASKS SERIAL
305
- ########################
402
+ # DEPENDENT TASKS SYNC MODE
403
+ ###########################
306
404
self ._assert_async_dependent_tasks (p )
307
405
308
406
@@ -311,12 +409,11 @@ def test_base(self):
311
409
# step one gear up - just one thread for now.
312
410
p .set_size (1 )
313
411
assert p .size () == 1
314
- print len (threading .enumerate ()), num_threads
315
412
assert len (threading .enumerate ()) == num_threads + 1
316
413
# deleting the pool stops its threads - just to be sure ;)
317
414
# Its not synchronized, hence we wait a moment
318
415
del (p )
319
- time .sleep (0.25 )
416
+ time .sleep (0.05 )
320
417
assert len (threading .enumerate ()) == num_threads
321
418
322
419
p = ThreadPool (1 )
0 commit comments