diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index a1d1ef67bebd07..06c73fc234903f 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -799,10 +799,50 @@ which incur interpreter overhead. "Returns the sequence elements n times" return chain.from_iterable(repeat(tuple(iterable), n)) + def batched(iterable, n): + "Batch data into tuples of length n. The last batch may be shorter." + # batched('ABCDEFG', 3) --> ABC DEF G + if n < 1: + raise ValueError('n must be at least one') + it = iter(iterable) + while (batch := tuple(islice(it, n))): + yield batch + + def grouper(iterable, n, *, incomplete='fill', fillvalue=None): + "Collect data into non-overlapping fixed-length chunks or blocks" + # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx + # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError + # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF + args = [iter(iterable)] * n + if incomplete == 'fill': + return zip_longest(*args, fillvalue=fillvalue) + if incomplete == 'strict': + return zip(*args, strict=True) + if incomplete == 'ignore': + return zip(*args) + else: + raise ValueError('Expected fill, strict, or ignore') + def sumprod(vec1, vec2): "Compute a sum of products." return sum(starmap(operator.mul, zip(vec1, vec2, strict=True))) + def sum_of_squares(it): + "Add up the squares of the input values." + # sum_of_squares([10, 20, 30]) -> 1400 + return sumprod(*tee(it)) + + def transpose(it): + "Swap the rows and columns of the input." + # transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33) + return zip(*it, strict=True) + + def matmul(m1, m2): + "Multiply two matrices." + # matmul([(7, 5), (3, 5)], [[2, 5], [7, 9]]) --> (49, 80), (41, 60) + n = len(m2[0]) + return batched(starmap(sumprod, product(m1, transpose(m2))), n) + def convolve(signal, kernel): # See: https://betterexplained.com/articles/intuitive-convolution/ # convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur) @@ -886,30 +926,6 @@ which incur interpreter overhead. return starmap(func, repeat(args)) return starmap(func, repeat(args, times)) - def grouper(iterable, n, *, incomplete='fill', fillvalue=None): - "Collect data into non-overlapping fixed-length chunks or blocks" - # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx - # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError - # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF - args = [iter(iterable)] * n - if incomplete == 'fill': - return zip_longest(*args, fillvalue=fillvalue) - if incomplete == 'strict': - return zip(*args, strict=True) - if incomplete == 'ignore': - return zip(*args) - else: - raise ValueError('Expected fill, strict, or ignore') - - def batched(iterable, n): - "Batch data into tuples of length n. The last batch may be shorter." - # batched('ABCDEFG', 3) --> ABC DEF G - if n < 1: - raise ValueError('n must be at least one') - it = iter(iterable) - while (batch := tuple(islice(it, n))): - yield batch - def triplewise(iterable): "Return overlapping triplets from an iterable" # triplewise('ABCDEFG') --> ABC BCD CDE DEF EFG @@ -1184,6 +1200,17 @@ which incur interpreter overhead. >>> sumprod([1,2,3], [4,5,6]) 32 + >>> sum_of_squares([10, 20, 30]) + 1400 + + >>> list(transpose([(1, 2, 3), (11, 22, 33)])) + [(1, 11), (2, 22), (3, 33)] + + >>> list(matmul([(7, 5), (3, 5)], [[2, 5], [7, 9]])) + [(49, 80), (41, 60)] + >>> list(matmul([[2, 5], [7, 9], [3, 4]], [[7, 11, 5, 4, 9], [3, 5, 2, 6, 3]])) + [(29, 47, 20, 38, 33), (76, 122, 53, 82, 90), (33, 53, 23, 36, 39)] + >>> data = [20, 40, 24, 32, 20, 28, 16] >>> list(convolve(data, [0.25, 0.25, 0.25, 0.25])) [5.0, 15.0, 21.0, 29.0, 29.0, 26.0, 24.0, 16.0, 11.0, 4.0]