Skip to content

Commit cf2532b

Browse files
authored
[3.12] Tee of tee was not producing n independent iterators (gh-123884) (gh-125153)
1 parent 382ee1c commit cf2532b

File tree

4 files changed

+214
-47
lines changed

4 files changed

+214
-47
lines changed

Doc/library/itertools.rst

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -676,24 +676,37 @@ loops that truncate the stream.
676676
Roughly equivalent to::
677677

678678
def tee(iterable, n=2):
679-
iterator = iter(iterable)
680-
shared_link = [None, None]
681-
return tuple(_tee(iterator, shared_link) for _ in range(n))
682-
683-
def _tee(iterator, link):
684-
try:
685-
while True:
686-
if link[1] is None:
687-
link[0] = next(iterator)
688-
link[1] = [None, None]
689-
value, link = link
690-
yield value
691-
except StopIteration:
692-
return
693-
694-
Once a :func:`tee` has been created, the original *iterable* should not be
695-
used anywhere else; otherwise, the *iterable* could get advanced without
696-
the tee objects being informed.
679+
if n < 0:
680+
raise ValueError
681+
if n == 0:
682+
return ()
683+
iterator = _tee(iterable)
684+
result = [iterator]
685+
for _ in range(n - 1):
686+
result.append(_tee(iterator))
687+
return tuple(result)
688+
689+
class _tee:
690+
691+
def __init__(self, iterable):
692+
it = iter(iterable)
693+
if isinstance(it, _tee):
694+
self.iterator = it.iterator
695+
self.link = it.link
696+
else:
697+
self.iterator = it
698+
self.link = [None, None]
699+
700+
def __iter__(self):
701+
return self
702+
703+
def __next__(self):
704+
link = self.link
705+
if link[1] is None:
706+
link[0] = next(self.iterator)
707+
link[1] = [None, None]
708+
value, self.link = link
709+
return value
697710

698711
``tee`` iterators are not threadsafe. A :exc:`RuntimeError` may be
699712
raised when simultaneously using iterators returned by the same :func:`tee`

Lib/test/test_itertools.py

Lines changed: 169 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,10 +1612,11 @@ def test_tee(self):
16121612
self.assertEqual(len(result), n)
16131613
self.assertEqual([list(x) for x in result], [list('abc')]*n)
16141614

1615-
# tee pass-through to copyable iterator
1615+
# tee objects are independent (see bug gh-123884)
16161616
a, b = tee('abc')
16171617
c, d = tee(a)
1618-
self.assertTrue(a is c)
1618+
e, f = tee(c)
1619+
self.assertTrue(len({a, b, c, d, e, f}) == 6)
16191620

16201621
# test tee_new
16211622
t1, t2 = tee('abc')
@@ -2029,6 +2030,172 @@ def test_islice_recipe(self):
20292030
self.assertEqual(next(c), 3)
20302031

20312032

2033+
def test_tee_recipe(self):
2034+
2035+
# Begin tee() recipe ###########################################
2036+
2037+
def tee(iterable, n=2):
2038+
if n < 0:
2039+
raise ValueError
2040+
if n == 0:
2041+
return ()
2042+
iterator = _tee(iterable)
2043+
result = [iterator]
2044+
for _ in range(n - 1):
2045+
result.append(_tee(iterator))
2046+
return tuple(result)
2047+
2048+
class _tee:
2049+
2050+
def __init__(self, iterable):
2051+
it = iter(iterable)
2052+
if isinstance(it, _tee):
2053+
self.iterator = it.iterator
2054+
self.link = it.link
2055+
else:
2056+
self.iterator = it
2057+
self.link = [None, None]
2058+
2059+
def __iter__(self):
2060+
return self
2061+
2062+
def __next__(self):
2063+
link = self.link
2064+
if link[1] is None:
2065+
link[0] = next(self.iterator)
2066+
link[1] = [None, None]
2067+
value, self.link = link
2068+
return value
2069+
2070+
# End tee() recipe #############################################
2071+
2072+
n = 200
2073+
2074+
a, b = tee([]) # test empty iterator
2075+
self.assertEqual(list(a), [])
2076+
self.assertEqual(list(b), [])
2077+
2078+
a, b = tee(irange(n)) # test 100% interleaved
2079+
self.assertEqual(lzip(a,b), lzip(range(n), range(n)))
2080+
2081+
a, b = tee(irange(n)) # test 0% interleaved
2082+
self.assertEqual(list(a), list(range(n)))
2083+
self.assertEqual(list(b), list(range(n)))
2084+
2085+
a, b = tee(irange(n)) # test dealloc of leading iterator
2086+
for i in range(100):
2087+
self.assertEqual(next(a), i)
2088+
del a
2089+
self.assertEqual(list(b), list(range(n)))
2090+
2091+
a, b = tee(irange(n)) # test dealloc of trailing iterator
2092+
for i in range(100):
2093+
self.assertEqual(next(a), i)
2094+
del b
2095+
self.assertEqual(list(a), list(range(100, n)))
2096+
2097+
for j in range(5): # test randomly interleaved
2098+
order = [0]*n + [1]*n
2099+
random.shuffle(order)
2100+
lists = ([], [])
2101+
its = tee(irange(n))
2102+
for i in order:
2103+
value = next(its[i])
2104+
lists[i].append(value)
2105+
self.assertEqual(lists[0], list(range(n)))
2106+
self.assertEqual(lists[1], list(range(n)))
2107+
2108+
# test argument format checking
2109+
self.assertRaises(TypeError, tee)
2110+
self.assertRaises(TypeError, tee, 3)
2111+
self.assertRaises(TypeError, tee, [1,2], 'x')
2112+
self.assertRaises(TypeError, tee, [1,2], 3, 'x')
2113+
2114+
# tee object should be instantiable
2115+
a, b = tee('abc')
2116+
c = type(a)('def')
2117+
self.assertEqual(list(c), list('def'))
2118+
2119+
# test long-lagged and multi-way split
2120+
a, b, c = tee(range(2000), 3)
2121+
for i in range(100):
2122+
self.assertEqual(next(a), i)
2123+
self.assertEqual(list(b), list(range(2000)))
2124+
self.assertEqual([next(c), next(c)], list(range(2)))
2125+
self.assertEqual(list(a), list(range(100,2000)))
2126+
self.assertEqual(list(c), list(range(2,2000)))
2127+
2128+
# test invalid values of n
2129+
self.assertRaises(TypeError, tee, 'abc', 'invalid')
2130+
self.assertRaises(ValueError, tee, [], -1)
2131+
2132+
for n in range(5):
2133+
result = tee('abc', n)
2134+
self.assertEqual(type(result), tuple)
2135+
self.assertEqual(len(result), n)
2136+
self.assertEqual([list(x) for x in result], [list('abc')]*n)
2137+
2138+
# tee objects are independent (see bug gh-123884)
2139+
a, b = tee('abc')
2140+
c, d = tee(a)
2141+
e, f = tee(c)
2142+
self.assertTrue(len({a, b, c, d, e, f}) == 6)
2143+
2144+
# test tee_new
2145+
t1, t2 = tee('abc')
2146+
tnew = type(t1)
2147+
self.assertRaises(TypeError, tnew)
2148+
self.assertRaises(TypeError, tnew, 10)
2149+
t3 = tnew(t1)
2150+
self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
2151+
2152+
# test that tee objects are weak referencable
2153+
a, b = tee(range(10))
2154+
p = weakref.proxy(a)
2155+
self.assertEqual(getattr(p, '__class__'), type(b))
2156+
del a
2157+
gc.collect() # For PyPy or other GCs.
2158+
self.assertRaises(ReferenceError, getattr, p, '__class__')
2159+
2160+
ans = list('abc')
2161+
long_ans = list(range(10000))
2162+
2163+
# Tests not applicable to the tee() recipe
2164+
if False:
2165+
# check copy
2166+
a, b = tee('abc')
2167+
self.assertEqual(list(copy.copy(a)), ans)
2168+
self.assertEqual(list(copy.copy(b)), ans)
2169+
a, b = tee(list(range(10000)))
2170+
self.assertEqual(list(copy.copy(a)), long_ans)
2171+
self.assertEqual(list(copy.copy(b)), long_ans)
2172+
2173+
# check partially consumed copy
2174+
a, b = tee('abc')
2175+
take(2, a)
2176+
take(1, b)
2177+
self.assertEqual(list(copy.copy(a)), ans[2:])
2178+
self.assertEqual(list(copy.copy(b)), ans[1:])
2179+
self.assertEqual(list(a), ans[2:])
2180+
self.assertEqual(list(b), ans[1:])
2181+
a, b = tee(range(10000))
2182+
take(100, a)
2183+
take(60, b)
2184+
self.assertEqual(list(copy.copy(a)), long_ans[100:])
2185+
self.assertEqual(list(copy.copy(b)), long_ans[60:])
2186+
self.assertEqual(list(a), long_ans[100:])
2187+
self.assertEqual(list(b), long_ans[60:])
2188+
2189+
# Issue 13454: Crash when deleting backward iterator from tee()
2190+
forward, backward = tee(repeat(None, 2000)) # 20000000
2191+
try:
2192+
any(forward) # exhaust the iterator
2193+
del backward
2194+
except:
2195+
del forward, backward
2196+
raise
2197+
2198+
20322199
class TestGC(unittest.TestCase):
20332200

20342201
def makecycle(self, iterator, container):
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Fixed bug in itertools.tee() handling of other tee inputs (a tee in a tee).
2+
The output now has the promised *n* independent new iterators. Formerly,
3+
the first iterator was identical to (not independent of) the input iterator.
4+
This would sometimes give surprising results.

Modules/itertoolsmodule.c

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,7 +1137,7 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
11371137
/*[clinic end generated code: output=1c64519cd859c2f0 input=c99a1472c425d66d]*/
11381138
{
11391139
Py_ssize_t i;
1140-
PyObject *it, *copyable, *copyfunc, *result;
1140+
PyObject *it, *to, *result;
11411141

11421142
if (n < 0) {
11431143
PyErr_SetString(PyExc_ValueError, "n must be >= 0");
@@ -1154,41 +1154,24 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
11541154
return NULL;
11551155
}
11561156

1157-
if (_PyObject_LookupAttr(it, &_Py_ID(__copy__), &copyfunc) < 0) {
1158-
Py_DECREF(it);
1157+
(void)&_Py_ID(__copy__); // Retain a reference to __copy__
1158+
itertools_state *state = get_module_state(module);
1159+
to = tee_fromiterable(state, it);
1160+
Py_DECREF(it);
1161+
if (to == NULL) {
11591162
Py_DECREF(result);
11601163
return NULL;
11611164
}
1162-
if (copyfunc != NULL) {
1163-
copyable = it;
1164-
}
1165-
else {
1166-
itertools_state *state = get_module_state(module);
1167-
copyable = tee_fromiterable(state, it);
1168-
Py_DECREF(it);
1169-
if (copyable == NULL) {
1170-
Py_DECREF(result);
1171-
return NULL;
1172-
}
1173-
copyfunc = PyObject_GetAttr(copyable, &_Py_ID(__copy__));
1174-
if (copyfunc == NULL) {
1175-
Py_DECREF(copyable);
1176-
Py_DECREF(result);
1177-
return NULL;
1178-
}
1179-
}
11801165

1181-
PyTuple_SET_ITEM(result, 0, copyable);
1166+
PyTuple_SET_ITEM(result, 0, to);
11821167
for (i = 1; i < n; i++) {
1183-
copyable = _PyObject_CallNoArgs(copyfunc);
1184-
if (copyable == NULL) {
1185-
Py_DECREF(copyfunc);
1168+
to = tee_copy((teeobject *)to, NULL);
1169+
if (to == NULL) {
11861170
Py_DECREF(result);
11871171
return NULL;
11881172
}
1189-
PyTuple_SET_ITEM(result, i, copyable);
1173+
PyTuple_SET_ITEM(result, i, to);
11901174
}
1191-
Py_DECREF(copyfunc);
11921175
return result;
11931176
}
11941177

0 commit comments

Comments
 (0)