Skip to content

Commit 257a8a9

Browse files
committed
test: prepared task dependency test, which already helped to find a bug in the reference counting mechanism, causing references to the pool to be kept alive via cycles
1 parent 365fb14 commit 257a8a9

File tree

3 files changed

+172
-49
lines changed

3 files changed

+172
-49
lines changed

lib/git/async/pool.py

+39-16
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626

2727
import sys
28+
import weakref
2829
from time import sleep
2930

3031

@@ -33,25 +34,37 @@ class RPoolChannel(CallbackRChannel):
3334
before and after an item is to be read.
3435
3536
It acts like a handle to the underlying task in the pool."""
36-
__slots__ = ('_task', '_pool')
37+
__slots__ = ('_task_ref', '_pool_ref')
3738

3839
def __init__(self, wchannel, task, pool):
3940
CallbackRChannel.__init__(self, wchannel)
40-
self._task = task
41-
self._pool = pool
41+
self._task_ref = weakref.ref(task)
42+
self._pool_ref = weakref.ref(pool)
4243

4344
def __del__(self):
4445
"""Assures that our task will be deleted if we were the last reader"""
45-
del(self._wc) # decrement ref-count early
46-
# now, if this is the last reader to the wc we just handled, there
46+
task = self._task_ref()
47+
if task is None:
48+
return
49+
50+
pool = self._pool_ref()
51+
if pool is None:
52+
return
53+
54+
# if this is the last reader to the wc we just handled, there
4755
# is no way anyone will ever read from the task again. If so,
4856
# delete the task in question, it will take care of itself and orphans
4957
# it might leave
5058
# 1 is ourselves, + 1 for the call + 1, and 3 magical ones which
5159
# I can't explain, but appears to be normal in the destructor
5260
# On the caller side, getrefcount returns 2, as expected
61+
# When just calling remove_task,
62+
it has no way of knowing that the write channel is about to diminish,
63+
# which is why we pass the info as a private kwarg - not nice, but
64+
# okay for now
65+
# TODO: Fix this - private/public method
5366
if sys.getrefcount(self) < 6:
54-
self._pool.remove_task(self._task)
67+
pool.remove_task(task, _from_destructor_=True)
5568
# END handle refcount based removal of task
5669

5770
def read(self, count=0, block=True, timeout=None):
@@ -72,11 +85,16 @@ def read(self, count=0, block=True, timeout=None):
7285

7386
# if the user tries to use us to read from a done task, we will never
7487
# compute as all produced items are already in the channel
75-
skip_compute = self._task.is_done() or self._task.error()
88+
task = self._task_ref()
89+
if task is None:
90+
return list()
91+
# END abort if task was deleted
92+
93+
skip_compute = task.is_done() or task.error()
7694

7795
########## prepare ##############################
7896
if not skip_compute:
79-
self._pool._prepare_channel_read(self._task, count)
97+
self._pool_ref()._prepare_channel_read(task, count)
8098
# END prepare pool scheduling
8199

82100

@@ -261,11 +279,16 @@ def _prepare_channel_read(self, task, count):
261279
# END for each task to process
262280

263281

264-
def _remove_task_if_orphaned(self, task):
282+
def _remove_task_if_orphaned(self, task, from_destructor):
265283
"""Check the task, and delete it if it is orphaned"""
266284
# 1 as its stored on the task, 1 for the getrefcount call
267-
if sys.getrefcount(task._out_wc) < 3:
268-
self.remove_task(task)
285+
# If we are getting here from the destructor of an RPool channel,
286+
# it's totally valid to virtually decrement the refcount by 1 as
287+
# we can expect it to drop once the destructor completes, which is when
288+
# we finish all recursive calls
289+
max_ref_count = 3 + from_destructor
290+
if sys.getrefcount(task.wchannel()) < max_ref_count:
291+
self.remove_task(task, from_destructor)
269292
#} END internal
270293

271294
#{ Interface
@@ -335,7 +358,7 @@ def num_tasks(self):
335358
finally:
336359
self._taskgraph_lock.release()
337360

338-
def remove_task(self, task):
361+
def remove_task(self, task, _from_destructor_=False):
339362
"""Delete the task
340363
Additionally we will remove orphaned tasks, which can be identified if their
341364
output channel is only held by themselves, so no one will ever consume
@@ -370,7 +393,7 @@ def remove_task(self, task):
370393
# END locked deletion
371394

372395
for t in in_tasks:
373-
self._remove_task_if_orphaned(t)
396+
self._remove_task_if_orphaned(t, _from_destructor_)
374397
# END handle orphans recursively
375398

376399
return self
@@ -409,11 +432,11 @@ def add_task(self, task):
409432

410433
# If the input channel is one of our read channels, we add the relation
411434
if isinstance(task, InputChannelTask):
412-
ic = task.in_rc
413-
if isinstance(ic, RPoolChannel) and ic._pool is self:
435+
ic = task.rchannel()
436+
if isinstance(ic, RPoolChannel) and ic._pool_ref() is self:
414437
self._taskgraph_lock.acquire()
415438
try:
416-
self._tasks.add_edge(ic._task, task)
439+
self._tasks.add_edge(ic._task_ref(), task)
417440
finally:
418441
self._taskgraph_lock.release()
419442
# END handle edge-adding

lib/git/async/task.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -208,5 +208,8 @@ def __init__(self, in_rc, *args, **kwargs):
208208
OutputChannelTask.__init__(self, *args, **kwargs)
209209
self._read = in_rc.read
210210

211-
#{ Configuration
212-
211+
def rchannel(self):
212+
""":return: input channel from which we read"""
213+
# the instance is bound in its instance method - let's use this to keep
214+
# the refcount at one ( per consumer )
215+
return self._read.im_self

test/git/async/test_pool.py

+128-31
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,14 @@
88
import time
99
import sys
1010

11-
class TestThreadTaskNode(InputIteratorThreadTask):
11+
class _TestTaskBase(object):
1212
def __init__(self, *args, **kwargs):
13-
super(TestThreadTaskNode, self).__init__(*args, **kwargs)
13+
super(_TestTaskBase, self).__init__(*args, **kwargs)
1414
self.should_fail = False
1515
self.lock = threading.Lock() # yes, can't safely do x = x + 1 :)
1616
self.plock = threading.Lock()
1717
self.item_count = 0
1818
self.process_count = 0
19-
self._scheduled_items = 0
2019

2120
def do_fun(self, item):
2221
self.lock.acquire()
@@ -32,44 +31,118 @@ def process(self, count=1):
3231
self.plock.acquire()
3332
self.process_count += 1
3433
self.plock.release()
35-
super(TestThreadTaskNode, self).process(count)
34+
super(_TestTaskBase, self).process(count)
3635

3736
def _assert(self, pc, fc, check_scheduled=False):
3837
"""Assert for num process counts (pc) and num function counts (fc)
3938
:return: self"""
40-
# TODO: fixme
41-
return self
42-
self.plock.acquire()
43-
if self.process_count != pc:
44-
print self.process_count, pc
45-
assert self.process_count == pc
46-
self.plock.release()
4739
self.lock.acquire()
4840
if self.item_count != fc:
4941
print self.item_count, fc
5042
assert self.item_count == fc
5143
self.lock.release()
5244

53-
# if we read all, we can't really use scheduled items
54-
if check_scheduled:
55-
assert self._scheduled_items == 0
56-
assert not self.error()
5745
return self
46+
47+
class TestThreadTaskNode(_TestTaskBase, InputIteratorThreadTask):
48+
pass
5849

5950

6051
class TestThreadFailureNode(TestThreadTaskNode):
6152
"""Fails after X items"""
53+
def __init__(self, *args, **kwargs):
54+
self.fail_after = kwargs.pop('fail_after')
55+
super(TestThreadFailureNode, self).__init__(*args, **kwargs)
6256

57+
def do_fun(self, item):
58+
item = TestThreadTaskNode.do_fun(self, item)
59+
if self.item_count > self.fail_after:
60+
raise AssertionError("Simulated failure after processing %i items" % self.fail_after)
61+
return item
62+
63+
64+
class TestThreadInputChannelTaskNode(_TestTaskBase, InputChannelTask):
65+
"""Apply a transformation on items read from an input channel"""
66+
67+
def do_fun(self, item):
68+
"""return tuple(i, i*2)"""
69+
item = super(TestThreadInputChannelTaskNode, self).do_fun(item)
70+
if isinstance(item, tuple):
71+
i = item[0]
72+
return item + (i * self.id, )
73+
else:
74+
return (item, item * self.id)
75+
# END handle tuple
76+
77+
78+
class TestThreadInputChannelVerifyTaskNode(_TestTaskBase, InputChannelTask):
79+
"""An input channel task, which verifies the result of its input channels,
80+
should be last in the chain.
81+
Id must be int"""
82+
83+
def do_fun(self, item):
84+
"""return tuple(i, i*2)"""
85+
item = super(TestThreadInputChannelTaskNode, self).do_fun(item)
86+
87+
# make sure the computation order matches
88+
assert isinstance(item, tuple)
89+
90+
base = item[0]
91+
for num in item[1:]:
92+
assert num == base * 2
93+
base = num
94+
# END verify order
95+
96+
return item
97+
98+
6399

64100
class TestThreadPool(TestBase):
65101

66102
max_threads = cpu_count()
67103

68-
def _add_triple_task(self, p):
69-
"""Add a triplet of feeder, transformer and finalizer to the pool, like
70-
t1 -> t2 -> t3, return all 3 return channels in order"""
71-
# t1 = TestThreadTaskNode(make_task(), 'iterator', None)
72-
# TODO:
104+
def _add_task_chain(self, p, ni, count=1):
105+
"""Create a task chain of feeder, count transformers and order verifier
106+
to the pool p, like t1 -> t2 -> t3
107+
:return: tuple(list(task1, taskN, ...), list(rc1, rcN, ...))"""
108+
nt = p.num_tasks()
109+
110+
feeder = self._make_iterator_task(ni)
111+
frc = p.add_task(feeder)
112+
113+
assert p.num_tasks() == nt + 1
114+
115+
rcs = [frc]
116+
tasks = [feeder]
117+
118+
inrc = frc
119+
for tc in xrange(count):
120+
t = TestThreadInputChannelTaskNode(inrc, tc, None)
121+
t.fun = t.do_fun
122+
inrc = p.add_task(t)
123+
124+
tasks.append(t)
125+
rcs.append(inrc)
126+
assert p.num_tasks() == nt + 2 + tc
127+
# END create count transformers
128+
129+
verifier = TestThreadInputChannelVerifyTaskNode(inrc, 'verifier', None)
130+
verifier.fun = verifier.do_fun
131+
vrc = p.add_task(verifier)
132+
133+
assert p.num_tasks() == nt + tc + 3
134+
135+
tasks.append(verifier)
136+
rcs.append(vrc)
137+
return tasks, rcs
138+
139+
def _make_iterator_task(self, ni, taskcls=TestThreadTaskNode, **kwargs):
140+
""":return: task which yields ni items
141+
:param taskcls: the actual iterator type to use
142+
:param **kwargs: additional kwargs to be passed to the task"""
143+
t = taskcls(iter(range(ni)), 'iterator', None, **kwargs)
144+
t.fun = t.do_fun
145+
return t
73146

74147
def _assert_single_task(self, p, async=False):
75148
"""Performs testing in a synchronized environment"""
@@ -82,11 +155,7 @@ def _assert_single_task(self, p, async=False):
82155
assert ni % 2 == 0, "ni needs to be dividable by 2"
83156
assert ni % 4 == 0, "ni needs to be dividable by 4"
84157

85-
def make_task():
86-
t = TestThreadTaskNode(iter(range(ni)), 'iterator', None)
87-
t.fun = t.do_fun
88-
return t
89-
# END utility
158+
make_task = lambda *args, **kwargs: self._make_iterator_task(ni, *args, **kwargs)
90159

91160
task = make_task()
92161

@@ -252,15 +321,44 @@ def make_task():
252321

253322
# test failure after ni / 2 items
254323
# This makes sure it correctly closes the channel on failure to prevent blocking
324+
nri = ni/2
325+
task = make_task(TestThreadFailureNode, fail_after=ni/2)
326+
rc = p.add_task(task)
327+
assert len(rc.read()) == nri
328+
assert task.is_done()
329+
assert isinstance(task.error(), AssertionError)
255330

256331

257332

258-
def _assert_async_dependent_tasks(self, p):
333+
def _assert_async_dependent_tasks(self, pool):
259334
# includes failure in center task, 'recursive' orphan cleanup
260335
# This will also verify that the channel-close mechanism works
261336
# t1 -> t2 -> t3
262337
# t1 -> x -> t3
263-
pass
338+
null_tasks = pool.num_tasks()
339+
ni = 100
340+
count = 1
341+
make_task = lambda *args, **kwargs: self._add_task_chain(pool, ni, count, *args, **kwargs)
342+
343+
ts, rcs = make_task()
344+
assert len(ts) == count + 2
345+
assert len(rcs) == count + 2
346+
assert pool.num_tasks() == null_tasks + len(ts)
347+
print pool._tasks.nodes
348+
349+
350+
# in the end, we expect all tasks to be gone, automatically
351+
352+
353+
354+
# order of deletion matters - just keep the end, then delete
355+
final_rc = rcs[-1]
356+
del(ts)
357+
del(rcs)
358+
del(final_rc)
359+
assert pool.num_tasks() == null_tasks
360+
361+
264362

265363
@terminate_threads
266364
def test_base(self):
@@ -301,8 +399,8 @@ def test_base(self):
301399
assert p.num_tasks() == 0
302400

303401

304-
# DEPENDENT TASKS SERIAL
305-
########################
402+
# DEPENDENT TASKS SYNC MODE
403+
###########################
306404
self._assert_async_dependent_tasks(p)
307405

308406

@@ -311,12 +409,11 @@ def test_base(self):
311409
# step one gear up - just one thread for now.
312410
p.set_size(1)
313411
assert p.size() == 1
314-
print len(threading.enumerate()), num_threads
315412
assert len(threading.enumerate()) == num_threads + 1
316413
# deleting the pool stops its threads - just to be sure ;)
317414
# Its not synchronized, hence we wait a moment
318415
del(p)
319-
time.sleep(0.25)
416+
time.sleep(0.05)
320417
assert len(threading.enumerate()) == num_threads
321418

322419
p = ThreadPool(1)

0 commit comments

Comments
 (0)