From b70d9bfbd2b7301082bb40a3f3059373d992d55b Mon Sep 17 00:00:00 2001 From: KB Sriram Date: Fri, 23 Feb 2024 09:39:27 -0800 Subject: [PATCH 1/3] Type annotations for `adafruit_itertools` This adds another set of type annotations for https://github.com/adafruit/Adafruit_CircuitPython_IterTools/issues/12 - Mypy passes the adafruit_itertools module without errors. - Added tests for functionality and to exercise type annotations, and mypy also passes these tests without errors. --- adafruit_itertools/__init__.py | 47 +++-- tests/README.rst | 15 +- tests/test_itertools.py | 375 ++++++++++++++++++++++++++++++++- 3 files changed, 420 insertions(+), 17 deletions(-) diff --git a/adafruit_itertools/__init__.py b/adafruit_itertools/__init__.py index 6617ba4..857a916 100644 --- a/adafruit_itertools/__init__.py +++ b/adafruit_itertools/__init__.py @@ -53,7 +53,10 @@ pass -def accumulate(iterable, func=lambda x, y: x + y): +def accumulate( + iterable: Iterable[_T], + func: Callable[[_T, _T], _T] = lambda x, y: x + y, # type: ignore[operator] +) -> Iterator[_T]: """Make an iterator that returns accumulated sums, or accumulated results of other binary functions (specified via the optional func argument). If func is supplied, it should be a function of two @@ -200,7 +203,7 @@ def count(start: _N = 0, step: _N = 1) -> Iterator[_N]: start += step -def cycle(p): +def cycle(p: Iterable[_T]) -> Iterator[_T]: """Make an iterator returning elements from the iterable and saving a copy of each. When the iterable is exhausted, return elements from the saved copy. Repeats indefinitely. @@ -209,7 +212,7 @@ def cycle(p): """ try: - len(p) + len(p) # type: ignore[arg-type] except TypeError: # len() is not defined for this type. Assume it is # a finite iterable so we must cache the elements. @@ -242,7 +245,9 @@ def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T]) -> Iterator[_T] yield x -def filterfalse(predicate: _Predicate[_T], iterable: Iterable[_T]) -> Iterator[_T]: +def filterfalse( + predicate: Optional[_Predicate[_T]], iterable: Iterable[_T] +) -> Iterator[_T]: """Make an iterator that filters elements from iterable returning only those for which the predicate is False. If predicate is None, return the items that are false. @@ -288,15 +293,21 @@ class groupby: # [k for k, g in groupby('AAAABBBCCDAABBB')] --> A B C D A B # [list(g) for k, g in groupby('AAAABBBCCD')] --> AAAA BBB CC D - def __init__(self, iterable, key=None): + def __init__( + self, + iterable: Iterable[_T], + key: Optional[Callable[[_T], Any]] = None, + ): self.keyfunc = key if key is not None else lambda x: x self.it = iter(iterable) - self.tgtkey = self.currkey = self.currvalue = object() + # Sentinel values, not actually returned during iteration. + self.currvalue: _T = object() # type: ignore[assignment] + self.tgtkey = self.currkey = self.currvalue - def __iter__(self): + def __iter__(self) -> Iterator[Tuple[Any, Iterator[_T]]]: return self - def __next__(self): + def __next__(self) -> Tuple[Any, Iterator[_T]]: self.id = object() while self.currkey == self.tgtkey: self.currvalue = next(self.it) # Exit on StopIteration @@ -304,7 +315,7 @@ def __next__(self): self.tgtkey = self.currkey return (self.currkey, self._grouper(self.tgtkey, self.id)) - def _grouper(self, tgtkey, id): + def _grouper(self, tgtkey: Any, id: object) -> Iterator[_T]: while self.id is id and self.currkey == tgtkey: yield self.currvalue try: @@ -314,7 +325,12 @@ def _grouper(self, tgtkey, id): self.currkey = self.keyfunc(self.currvalue) -def islice(p, start, stop=(), step=1): +def islice( + p: Iterable[_T], + start: int, + stop: Optional[int] = (), # type: ignore[assignment] + step: int = 1, +) -> Iterator[_T]: """Make an iterator that returns selected elements from the iterable. If start is non-zero and stop is unspecified, then the value for start is used as end, and start is taken to be 0. Thus the @@ -420,7 +436,8 @@ def permutations( return -def product(*args: Iterable[_T], r: int = 1) -> Iterator[Tuple[_T, ...]]: +# def product(*args: Iterable[_T], r: int = 1) -> Iterator[Tuple[_T, ...]]: +def product(*args: Iterable[Any], r: int = 1) -> Iterator[Tuple[Any, ...]]: """Cartesian product of input iterables. Roughly equivalent to nested for-loops in a generator expression. For @@ -444,7 +461,7 @@ def product(*args: Iterable[_T], r: int = 1) -> Iterator[Tuple[_T, ...]]: # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111 pools = [tuple(pool) for pool in args] * r - result: List[List[_T]] = [[]] + result: List[List[Any]] = [[]] for pool in pools: result = [x + [y] for x in result for y in pool] for prod in result: @@ -513,8 +530,8 @@ def tee(iterable: Iterable[_T], n: int = 2) -> Sequence[Iterator[_T]]: def zip_longest( - *args: Iterable[_T], fillvalue: _OptionalFill = None -) -> Iterator[Tuple[Union[_T, _OptionalFill], ...]]: + *args: Iterable[Any], fillvalue: _OptionalFill = None +) -> Iterator[Tuple[Any, ...]]: """Make an iterator that aggregates elements from each of the iterables. If the iterables are of uneven length, missing values are filled-in with fillvalue. Iteration continues until the longest @@ -524,7 +541,7 @@ def zip_longest( :param fillvalue: value to fill in those missing from shorter iterables """ # zip_longest('ABCD', 'xy', fillvalue='-') --> Ax By C- D- - iterators: List[Iterator[Union[_T, _OptionalFill]]] = [iter(it) for it in args] + iterators: List[Iterator[Any]] = [iter(it) for it in args] num_active = len(iterators) if not num_active: return diff --git a/tests/README.rst b/tests/README.rst index ab25406..c91f434 100644 --- a/tests/README.rst +++ b/tests/README.rst @@ -8,7 +8,8 @@ Itertools Tests These tests run under CPython, and are intended to verify that the Adafruit library functions return the same outputs compared to ones in -the standard `itertools` module. +the standard `itertools` module, and also to exercise some type +annotations. These tests run automatically from the standard `circuitpython github workflow `_. To run them manually, first install these packages @@ -21,4 +22,16 @@ the following command:: $ python -m pytest +Type annotation tests don't run automatically at this point. But to +verify type-related issues manually, first install these packages if +necessary:: + + $ pip3 install mypy + +Then ensure you're in the *root* directory of the repository and run +the following command:: + + $ mypy --warn-unused-ignores --disallow-untyped-defs tests + + .. _wf: https://github.com/adafruit/workflows-circuitpython-libs/blob/6e1562eaabced4db1bd91173b698b1cc1dfd35ab/build/action.yml#L78-L84 diff --git a/tests/test_itertools.py b/tests/test_itertools.py index 9a52b5b..f2b13a0 100644 --- a/tests/test_itertools.py +++ b/tests/test_itertools.py @@ -1,14 +1,256 @@ # SPDX-FileCopyrightText: KB Sriram # SPDX-License-Identifier: MIT -from typing import Iterator, Optional, Sequence, TypeVar +from typing import Any, Callable, Iterator, Optional, Sequence, Tuple, TypeVar, Union import itertools as it import pytest import adafruit_itertools as ait +_K = TypeVar("_K") _T = TypeVar("_T") +def _take(n: int, iterator: Iterator[_T]) -> Sequence[_T]: + """Extract the first n elements from a long/infinite iterator.""" + return [v for _, v in zip(range(n), iterator)] + + +@pytest.mark.parametrize( + "seq, func", + [ + ([1, 2, 3, 4], lambda a, x: a - x), + ([], lambda a, _: a), + (["abc", "def"], lambda a, x: a + x), + ("abc", lambda a, x: a + x), + ], +) +def test_accumulate_with(seq: Sequence[_T], func: Callable[[_T, _T], _T]) -> None: + x: Sequence[_T] = list(it.accumulate(seq, func)) + y: Sequence[_T] = list(ait.accumulate(seq, func)) + assert x == y + + +def test_accumulate_types() -> None: + x_int: Iterator[int] = ait.accumulate([1, 2, 3]) + assert list(x_int) == list(it.accumulate([1, 2, 3])) + + x_bad_type: Iterator[str] = ait.accumulate([1, 2, 3]) # type: ignore[list-item] + assert list(x_bad_type) == list(it.accumulate([1, 2, 3])) + + x_str_f: Iterator[str] = ait.accumulate("abc", lambda a, x: a + x) + assert list(x_str_f) == list(it.accumulate("abc", lambda a, x: a + x)) + + x_bad_arg_f: Iterator[int] = ait.accumulate( + [1, 2], lambda a, x: a + ord(x) # type: ignore[arg-type] + ) + with pytest.raises(TypeError): + list(x_bad_arg_f) + + # Note: technically, this works and produces [1, "12"]. But the annotated types + # are declared to be more strict, and reject accumulator functions that produce + # mixed types in the result. + inp = [1, 2] + + def _stringify(acc: Union[int, str], item: int) -> str: + return str(acc) + str(item) + + x_mixed_f: Iterator[Union[int, str]] = ait.accumulate(inp, _stringify) # type: ignore[arg-type] + assert [1, "12"] == list(x_mixed_f) + + +@pytest.mark.parametrize( + "arglist, partial", + [ + ([[1, 2], [3, 4]], 1), + ([[3]], 1), + ([[]], 0), + ([[]], 1), + ([[], [None]], 1), + ([[1, "a"], ["b", 2]], 1), + ([[1, 2, 3], [4, 5, 6]], 4), + ], +) +def test_chain_basic(arglist: Sequence[Sequence[_T]], partial: int) -> None: + x: Sequence[_T] = list(ait.chain(*arglist)) + y: Sequence[_T] = list(it.chain(*arglist)) + assert x == y + xit: Iterator[_T] = ait.chain(*arglist) + yit: Iterator[_T] = it.chain(*arglist) + assert _take(partial, xit) == _take(partial, yit) + + +@pytest.mark.parametrize( + "arglist, partial", + [ + ([[1, 2], [3, 4]], 1), + ([[3]], 1), + ([[]], 0), + ([[]], 1), + ([[], [None]], 1), + ([[1, "a"], ["b", 2]], 1), + ([[1, 2, 3], [4, 5, 6]], 4), + ], +) +def test_chain_from_iterable(arglist: Sequence[Sequence[_T]], partial: int) -> None: + x: Sequence[_T] = list(ait.chain_from_iterable(arglist)) + y: Sequence[_T] = list(it.chain.from_iterable(arglist)) + assert x == y + xit: Iterator[_T] = ait.chain_from_iterable(arglist) + yit: Iterator[_T] = it.chain.from_iterable(arglist) + assert _take(partial, xit) == _take(partial, yit) + + +@pytest.mark.parametrize( + "seq, n", + [ + ([1, 2, 3, 4], 2), + ([1, 2, 3, 4], 3), + ([1, 2, 3], 32), + ([1, 2, 3], 0), + ([], 0), + ([], 1), + ], +) +def test_combinations(seq: Sequence[_T], n: int) -> None: + x: Sequence[Tuple[_T, ...]] = list(ait.combinations(seq, n)) + y: Sequence[Tuple[_T, ...]] = list(it.combinations(seq, n)) + assert x == y + + +@pytest.mark.parametrize( + "seq, n", + [ + ([1, 2, 3, 4], 2), + ([1, 2, 3, 4], 3), + ([1, 2, 3], 32), + ([1, 2, 3], 0), + ([], 0), + ([], 1), + ], +) +def test_combo_with_replacement(seq: Sequence[_T], n: int) -> None: + x: Sequence[Tuple[_T, ...]] = list(ait.combinations_with_replacement(seq, n)) + y: Sequence[Tuple[_T, ...]] = list(it.combinations_with_replacement(seq, n)) + assert x == y + + +@pytest.mark.parametrize( + "data, selectors", + [ + ([1, 2, 3, 4, 5], [True, False, True, False, True]), + ([1, 2, 3, 4, 5], [True, "", True, True, ""]), + ([1, 2, 3, 4, 5], [0, 0, None, 0, 0]), + ([1, 2, 3, 4, 5], [1, 1, 1, True, 1]), + ([1, 2, 3, 4, 5], [1, 0, 1]), + ([1, 2, 3, 4, 5], []), + ([1, 2, 3], [1, 1, 0, 0, 0, 0, 0, 0]), + ([], [1, 2, 3]), + ([], []), + ], +) +def test_compress(data: Sequence[int], selectors: Sequence[Any]) -> None: + x: Sequence[int] = list(ait.compress(data, selectors)) + y: Sequence[int] = list(it.compress(data, selectors)) + assert x == y + + +def test_count() -> None: + assert _take(5, it.count()) == _take(5, ait.count()) + for start in range(-10, 10): + assert _take(5, it.count(start)) == _take(5, ait.count(start)) + + for step in range(-10, 10): + assert _take(5, it.count(step=step)) == _take(5, ait.count(step=step)) + + for start in range(-5, 5): + for step in range(-5, 5): + assert _take(10, it.count(start, step)) == _take(10, ait.count(start, step)) + + +@pytest.mark.parametrize( + "seq", + [ + ([]), + ([None]), + ([1, 2]), + ], +) +def test_cycle(seq: Sequence[_T]) -> None: + x: Iterator[_T] = ait.cycle(seq) + y: Iterator[_T] = it.cycle(seq) + assert _take(10, x) == _take(10, y) + + +@pytest.mark.parametrize( + "predicate, seq", + [ + (ord, ""), + (lambda x: x == 42, [1, 2]), + (lambda x: x == 42, [1, 42]), + ], +) +def test_dropwhile(predicate: Callable[[_T], object], seq: Sequence[_T]) -> None: + x: Iterator[_T] = ait.dropwhile(predicate, seq) + y: Iterator[_T] = it.dropwhile(predicate, seq) + assert list(x) == list(y) + bad_type: Iterator[int] = ait.dropwhile(ord, [1, 2]) # type: ignore[arg-type] + with pytest.raises(TypeError): + list(bad_type) + + +@pytest.mark.parametrize( + "predicate, seq", + [ + (None, []), + (None, [1, 0, 2]), + (lambda x: x % 2, range(10)), + ], +) +def test_filterfalse( + predicate: Optional[Callable[[_T], object]], seq: Sequence[_T] +) -> None: + x: Iterator[_T] = ait.filterfalse(predicate, seq) + y: Iterator[_T] = it.filterfalse(predicate, seq) + assert list(x) == list(y) + bad_type: Iterator[str] = ait.filterfalse(ord, [1, 2]) # type: ignore[list-item] + with pytest.raises(TypeError): + list(bad_type) + + +@pytest.mark.parametrize( + "data, key", + [ + ("abcd", ord), + ("", ord), + ("aabbcbbbaaa", ord), + ([(0, 1), (0, 2), (0, 3), (1, 4), (0, 5), (0, 6)], lambda x: x[0]), + ([(0, 1), (0, 2), (0, 3), (1, 4), (0, 5), (0, 6)], max), + ], +) +def test_groupby(data: Sequence[_T], key: Callable[[_T], _K]) -> None: + def _listify( + iterable: Iterator[Tuple[_K, Iterator[_T]]] + ) -> Sequence[Tuple[_K, Sequence[_T]]]: + return [(k, list(group)) for k, group in iterable] + + it_l = _listify(it.groupby(data, key)) + ait_l = _listify(ait.groupby(data, key)) + assert it_l == ait_l + + +def test_groupby_types() -> None: + assert list(ait.groupby([])) == list(it.groupby([])) + assert list(ait.groupby([], key=id)) == list(it.groupby([], key=id)) + assert list(ait.groupby("", ord)) == list(it.groupby("", ord)) + + with pytest.raises(TypeError): + list(ait.groupby("abc", [])) # type: ignore[arg-type] + with pytest.raises(TypeError): + list(ait.groupby("abc", chr)) # type: ignore[arg-type] + with pytest.raises(TypeError): + ait.groupby(None) # type: ignore[arg-type] + + @pytest.mark.parametrize( "seq, start", [ @@ -73,3 +315,134 @@ def test_islice_error() -> None: list(ait.islice("abc", 0, -1)) with pytest.raises(ValueError): list(ait.islice("abc", 0, 0, 0)) + + +@pytest.mark.parametrize( + "seq", + [ + "", + "A", + "ABCDEFGH", + ], +) +def test_permutations(seq: Sequence[_T]) -> None: + x: Iterator[Tuple[_T, ...]] = ait.permutations(seq) + y: Iterator[Tuple[_T, ...]] = it.permutations(seq) + assert list(x) == list(y) + + for r in range(3): + x = ait.permutations(seq, r) + y = it.permutations(seq, r) + assert list(x) == list(y) + + +@pytest.mark.parametrize( + "seq", + [ + "", + "A", + "ABCDEFGH", + [1, 2, "3", None, 4], + ], +) +def test_product_one(seq: Sequence[object]) -> None: + x: Iterator[Tuple[object, ...]] = ait.product(seq) + y: Iterator[Tuple[object, ...]] = it.product(seq) + assert list(x) == list(y) + + for r in range(3): + x = ait.product(seq, r=r) + y = it.product(seq, repeat=r) + assert list(x) == list(y) + + +@pytest.mark.parametrize( + "seq1, seq2", + [ + ("", []), + ("", [1, 2]), + ("AB", []), + ("ABCDEFGH", [1, 2, 3]), + ], +) +def test_product_two(seq1: Sequence[str], seq2: Sequence[int]) -> None: + x: Iterator[Tuple[str, int]] = ait.product(seq1, seq2) + y: Iterator[Tuple[str, int]] = it.product(seq1, seq2) + assert list(x) == list(y) + + for r in range(3): + x_repeat: Iterator[Tuple[object, ...]] = ait.product(seq1, seq2, r=r) + y_repeat: Iterator[Tuple[object, ...]] = it.product(seq1, seq2, repeat=r) + assert list(x_repeat) == list(y_repeat) + + +@pytest.mark.parametrize( + "element", + ["", None, 5, "abc"], +) +def test_repeat(element: _T) -> None: + x: Iterator[_T] = ait.repeat(element) + y: Iterator[_T] = it.repeat(element) + assert _take(5, x) == _take(5, y) + + for count in range(10): + x = ait.repeat(element, count) + y = it.repeat(element, count) + assert _take(5, x) == _take(5, y) + + +@pytest.mark.parametrize( + "func, seq", + [ + (pow, [(2, 3), (3, 2), (10, 2)]), + (lambda x, y: x + y, [("a", "b"), ("c", "d")]), + ], +) +def test_starmap(func: Callable[[_T, _T], _T], seq: Sequence[Sequence[_T]]) -> None: + x: Iterator[_T] = ait.starmap(func, seq) + y: Iterator[_T] = it.starmap(func, seq) + assert list(x) == list(y) + + +@pytest.mark.parametrize( + "func, seq", + [ + (lambda x: x, []), + (lambda x: x == 3, [1, 2, 3, 2, 3]), + (lambda x: x == 3, [1, 2]), + ], +) +def test_takewhile(func: Callable[[_T], bool], seq: Sequence[_T]) -> None: + x: Iterator[_T] = ait.takewhile(func, seq) + y: Iterator[_T] = it.takewhile(func, seq) + assert list(x) == list(y) + + +@pytest.mark.parametrize( + "seq", + ["", "abc"], +) +def test_tee(seq: Sequence[_T]) -> None: + x: Sequence[Iterator[_T]] = ait.tee(seq) + y: Sequence[Iterator[_T]] = it.tee(seq) + assert [list(v) for v in x] == [list(v) for v in y] + + for n in range(3): + x = ait.tee(seq, n) + y = it.tee(seq, n) + assert [list(v) for v in x] == [list(v) for v in y] + + +@pytest.mark.parametrize( + "seq1, seq2", + [ + ("", []), + ("", [1, 2]), + ("abc", []), + ("abc", [1, 2]), + ], +) +def test_zip_longest(seq1: Sequence[str], seq2: Sequence[int]) -> None: + x: Iterator[Tuple[str, int]] = ait.zip_longest(seq1, seq2) + y: Iterator[Tuple[str, int]] = it.zip_longest(seq1, seq2) + assert list(x) == list(y) From 088e246c0bee4e81a24730be3168c2bc19c46853 Mon Sep 17 00:00:00 2001 From: KB Sriram Date: Sun, 28 Apr 2024 22:39:29 -0700 Subject: [PATCH 2/3] Update type annotations for itertools extras. - Fixed a small bug with all_equal() on empty inputs. - Fixed a documentation bug with pairwise(). - Added tests for all methods. - Verified mypy --warn-unused-ignores --disallow-untyped-defs runs successfully on tests. Fixes https://github.com/adafruit/Adafruit_CircuitPython_IterTools/issues/12 --- .../adafruit_itertools_extras.py | 93 ++++-- tests/test_itertools_extras.py | 285 ++++++++++++++++++ 2 files changed, 353 insertions(+), 25 deletions(-) create mode 100644 tests/test_itertools_extras.py diff --git a/adafruit_itertools/adafruit_itertools_extras.py b/adafruit_itertools/adafruit_itertools_extras.py index 8a41038..a435bbd 100644 --- a/adafruit_itertools/adafruit_itertools_extras.py +++ b/adafruit_itertools/adafruit_itertools_extras.py @@ -41,26 +41,54 @@ import adafruit_itertools as it +try: + from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + ) + from typing_extensions import TypeAlias + + _T = TypeVar("_T") + _N: TypeAlias = Union[int, float, complex] + _Predicate: TypeAlias = Callable[[_T], bool] +except ImportError: + pass + + __version__ = "0.0.0+auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Itertools.git" -def all_equal(iterable): +def all_equal(iterable: Iterable[Any]) -> bool: """Returns True if all the elements are equal to each other. :param iterable: source of values """ g = it.groupby(iterable) - next(g) # should succeed, value isn't relevant try: - next(g) # should fail: only 1 group + next(g) # value isn't relevant + except StopIteration: + # Empty iterable, return True to match cpython behavior. + return True + try: + next(g) + # more than one group, so we have different elements. return False except StopIteration: + # Only one group - all elements must be equal. return True -def dotproduct(vec1, vec2): +def dotproduct(vec1: Iterable[_N], vec2: Iterable[_N]) -> _N: """Compute the dot product of two vectors. :param vec1: the first vector @@ -71,7 +99,11 @@ def dotproduct(vec1, vec2): return sum(map(lambda x, y: x * y, vec1, vec2)) -def first_true(iterable, default=False, pred=None): +def first_true( + iterable: Iterable[_T], + default: Union[bool, _T] = False, + pred: Optional[_Predicate[_T]] = None, +) -> Union[bool, _T]: """Returns the first true value in the iterable. If no true value is found, returns *default* @@ -94,7 +126,7 @@ def first_true(iterable, default=False, pred=None): return default -def flatten(iterable_of_iterables): +def flatten(iterable_of_iterables: Iterable[Iterable[_T]]) -> Iterator[_T]: """Flatten one level of nesting. :param iterable_of_iterables: a sequence of iterables to flatten @@ -104,7 +136,9 @@ def flatten(iterable_of_iterables): return it.chain_from_iterable(iterable_of_iterables) -def grouper(iterable, n, fillvalue=None): +def grouper( + iterable: Iterable[_T], n: int, fillvalue: Optional[_T] = None +) -> Iterator[Tuple[_T, ...]]: """Collect data into fixed-length chunks or blocks. :param iterable: source of values @@ -118,7 +152,7 @@ def grouper(iterable, n, fillvalue=None): return it.zip_longest(*args, fillvalue=fillvalue) -def iter_except(func, exception): +def iter_except(func: Callable[[], _T], exception: Type[BaseException]) -> Iterator[_T]: """Call a function repeatedly, yielding the results, until exception is raised. Converts a call-until-exception interface to an iterator interface. @@ -143,7 +177,7 @@ def iter_except(func, exception): pass -def ncycles(iterable, n): +def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: """Returns the sequence elements a number of times. :param iterable: the source of values @@ -153,7 +187,7 @@ def ncycles(iterable, n): return it.chain_from_iterable(it.repeat(tuple(iterable), n)) -def nth(iterable, n, default=None): +def nth(iterable: Iterable[_T], n: int, default: Optional[_T] = None) -> Optional[_T]: """Returns the nth item or a default value. :param iterable: the source of values @@ -166,7 +200,7 @@ def nth(iterable, n, default=None): return default -def padnone(iterable): +def padnone(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: """Returns the sequence elements and then returns None indefinitely. Useful for emulating the behavior of the built-in map() function. @@ -177,13 +211,17 @@ def padnone(iterable): return it.chain(iterable, it.repeat(None)) -def pairwise(iterable): - """Pair up valuesin the iterable. +def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: + """Return successive overlapping pairs from the iterable. + + The number of tuples from the output will be one fewer than the + number of values in the input. It will be empty if the input has + fewer than two values. :param iterable: source of values """ - # pairwise(range(11)) -> (1, 2), (3, 4), (5, 6), (7, 8), (9, 10) + # pairwise(range(5)) -> (0, 1), (1, 2), (2, 3), (3, 4) a, b = it.tee(iterable) try: next(b) @@ -192,7 +230,9 @@ def pairwise(iterable): return zip(a, b) -def partition(pred, iterable): +def partition( + pred: _Predicate[_T], iterable: Iterable[_T] +) -> Tuple[Iterator[_T], Iterator[_T]]: """Use a predicate to partition entries into false entries and true entries. :param pred: the predicate that divides the values @@ -204,7 +244,7 @@ def partition(pred, iterable): return it.filterfalse(pred, t1), filter(pred, t2) -def prepend(value, iterator): +def prepend(value: _T, iterator: Iterable[_T]) -> Iterator[_T]: """Prepend a single value in front of an iterator :param value: the value to prepend @@ -215,7 +255,7 @@ def prepend(value, iterator): return it.chain([value], iterator) -def quantify(iterable, pred=bool): +def quantify(iterable: Iterable[_T], pred: _Predicate[_T] = bool) -> int: """Count how many times the predicate is true. :param iterable: source of values @@ -227,7 +267,9 @@ def quantify(iterable, pred=bool): return sum(map(pred, iterable)) -def repeatfunc(func, times=None, *args): +def repeatfunc( + func: Callable[..., _T], times: Optional[int] = None, *args: Any +) -> Iterator[_T]: """Repeat calls to func with specified arguments. Example: repeatfunc(random.random) @@ -242,7 +284,7 @@ def repeatfunc(func, times=None, *args): return it.starmap(func, it.repeat(args, times)) -def roundrobin(*iterables): +def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: """Return an iterable created by repeatedly picking value from each argument in order. @@ -263,18 +305,19 @@ def roundrobin(*iterables): nexts = it.cycle(it.islice(nexts, num_active)) -def tabulate(function, start=0): - """Apply a function to a sequence of consecutive integers. +def tabulate(function: Callable[[int], int], start: int = 0) -> Iterator[int]: + """Apply a function to a sequence of consecutive numbers. - :param function: the function of one integer argument + :param function: the function of one numeric argument. :param start: optional value to start at (default is 0) """ # take(5, tabulate(lambda x: x * x))) -> 0 1 4 9 16 - return map(function, it.count(start)) + counter: Iterator[int] = it.count(start) # type: ignore[assignment] + return map(function, counter) -def tail(n, iterable): +def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: """Return an iterator over the last n items :param n: how many values to return @@ -294,7 +337,7 @@ def tail(n, iterable): return iter(buf) -def take(n, iterable): +def take(n: int, iterable: Iterable[_T]) -> List[_T]: """Return first n items of the iterable as a list :param n: how many values to take diff --git a/tests/test_itertools_extras.py b/tests/test_itertools_extras.py new file mode 100644 index 0000000..fad8786 --- /dev/null +++ b/tests/test_itertools_extras.py @@ -0,0 +1,285 @@ +# SPDX-FileCopyrightText: KB Sriram +# SPDX-License-Identifier: MIT + +from typing import ( + Callable, + Iterator, + Optional, + Sequence, + TypeVar, +) +from typing_extensions import TypeAlias + +import more_itertools as itextras +import pytest +from adafruit_itertools import adafruit_itertools_extras as aextras + +_K = TypeVar("_K") +_T = TypeVar("_T") +_S = TypeVar("_S") +_Predicate: TypeAlias = Callable[[_T], bool] + + +def _take(n: int, iterator: Iterator[_T]) -> Sequence[_T]: + """Extract the first n elements from a long/infinite iterator.""" + return [v for _, v in zip(range(n), iterator)] + + +@pytest.mark.parametrize( + "data", + [ + "aaaa", + "abcd", + "a", + "", + (1, 2), + (3, 3), + ("", False), + (42, True), + ], +) +def test_all_equal(data: Sequence[_T]) -> None: + assert itextras.all_equal(data) == aextras.all_equal(data) + + +@pytest.mark.parametrize( + ("vec1", "vec2"), + [ + ([1, 2], [3, 4]), + ([], []), + ([1], [2, 3]), + ([4, 5], [6]), + ], +) +def test_dotproduct(vec1: Sequence[int], vec2: Sequence[int]) -> None: + assert itextras.dotproduct(vec1, vec2) == aextras.dotproduct(vec1, vec2) + + +@pytest.mark.parametrize( + ("seq", "dflt", "pred"), + [ + ([0, 2], 0, None), + ([], 10, None), + ([False], True, None), + ([1, 2], -1, lambda _: False), + ([0, 1], -1, lambda _: True), + ([], -1, lambda _: True), + ], +) +def test_first_true( + seq: Sequence[_T], dflt: _T, pred: Optional[_Predicate[_T]] +) -> None: + assert itextras.first_true(seq, dflt, pred) == aextras.first_true(seq, dflt, pred) + + +@pytest.mark.parametrize( + ("seq1", "seq2"), + [ + ("abc", "def"), + ("", "def"), + ("abc", ""), + ("", ""), + ], +) +def test_flatten(seq1: str, seq2: str) -> None: + assert list(itextras.flatten(seq1 + seq2)) == list(aextras.flatten(seq1 + seq2)) + for repeat in range(3): + assert list(itextras.flatten([seq1] * repeat)) == list( + aextras.flatten([seq1] * repeat) + ) + assert list(itextras.flatten([seq2] * repeat)) == list( + aextras.flatten([seq2] * repeat) + ) + + +@pytest.mark.parametrize( + ("seq", "count", "fill"), + [ + ("abc", 3, None), + ("abcd", 3, None), + ("abc", 3, "x"), + ("abcd", 3, "x"), + ("abc", 0, None), + ("", 3, "xy"), + ], +) +def test_grouper(seq: Sequence[str], count: int, fill: Optional[str]) -> None: + assert list(itextras.grouper(seq, count, fillvalue=fill)) == list( + aextras.grouper(seq, count, fillvalue=fill) + ) + + +@pytest.mark.parametrize( + ("data"), + [ + (1, 2, 3), + (), + ], +) +def test_iter_except(data: Sequence[int]) -> None: + assert list(itextras.iter_except(list(data).pop, IndexError)) == list( + aextras.iter_except(list(data).pop, IndexError) + ) + + +@pytest.mark.parametrize( + ("seq", "count"), + [ + ("abc", 4), + ("abc", 0), + ("", 4), + ], +) +def test_ncycles(seq: str, count: int) -> None: + assert list(itextras.ncycles(seq, count)) == list(aextras.ncycles(seq, count)) + + +@pytest.mark.parametrize( + ("seq", "n", "dflt"), + [ + ("abc", 1, None), + ("abc", 10, None), + ("abc", 10, "x"), + ("", 0, None), + ], +) +def test_nth(seq: str, n: int, dflt: Optional[str]) -> None: + assert itextras.nth(seq, n, dflt) == aextras.nth(seq, n, dflt) + + +@pytest.mark.parametrize( + ("seq"), + [ + "abc", + "", + ], +) +def test_padnone(seq: str) -> None: + assert _take(10, itextras.padnone(seq)) == _take(10, aextras.padnone(seq)) + + +@pytest.mark.parametrize( + ("seq"), + [ + (), + (1,), + (1, 2), + (1, 2, 3), + (1, 2, 3, 4), + ], +) +def test_pairwise(seq: Sequence[int]) -> None: + assert list(itextras.pairwise(seq)) == list(aextras.pairwise(seq)) + + +@pytest.mark.parametrize( + ("pred", "seq"), + [ + (lambda x: x % 2, (0, 1, 2, 3)), + (lambda x: x % 2, (0, 2)), + (lambda x: x % 2, ()), + ], +) +def test_partition(pred: _Predicate[int], seq: Sequence[int]) -> None: + # assert list(itextras.partition(pred, seq)) == list(aextras.partition(pred, seq)) + true1, false1 = itextras.partition(pred, seq) + true2, false2 = aextras.partition(pred, seq) + assert list(true1) == list(true2) + assert list(false1) == list(false2) + + +@pytest.mark.parametrize( + ("value", "seq"), + [ + (1, (2, 3)), + (1, ()), + ], +) +def test_prepend(value: int, seq: Sequence[int]) -> None: + assert list(itextras.prepend(value, seq)) == list(aextras.prepend(value, seq)) + + +@pytest.mark.parametrize( + ("seq", "pred"), + [ + ((0, 1), lambda x: x % 2 == 0), + ((1, 1), lambda x: x % 2 == 0), + ((), lambda x: x % 2 == 0), + ], +) +def test_quantify(seq: Sequence[int], pred: _Predicate[int]) -> None: + assert itextras.quantify(seq) == aextras.quantify(seq) + assert itextras.quantify(seq, pred) == aextras.quantify(seq, pred) + + +@pytest.mark.parametrize( + ("func", "times", "args"), + [ + (lambda: 1, 5, []), + (lambda: 1, 0, []), + (lambda x: x + 1, 10, [3]), + (lambda x, y: x + y, 10, [3, 4]), + ], +) +def test_repeatfunc(func: Callable, times: int, args: Sequence[int]) -> None: + assert _take(5, itextras.repeatfunc(func, None, *args)) == _take( + 5, aextras.repeatfunc(func, None, *args) + ) + assert list(itextras.repeatfunc(func, times, *args)) == list( + aextras.repeatfunc(func, times, *args) + ) + + +@pytest.mark.parametrize( + ("seq1", "seq2"), + [ + ("abc", "def"), + ("a", "bc"), + ("ab", "c"), + ("", "abc"), + ("", ""), + ], +) +def test_roundrobin(seq1: str, seq2: str) -> None: + assert list(itextras.roundrobin(seq1)) == list(aextras.roundrobin(seq1)) + assert list(itextras.roundrobin(seq1, seq2)) == list(aextras.roundrobin(seq1, seq2)) + + +@pytest.mark.parametrize( + ("func", "start"), + [ + (lambda x: 2 * x, 17), + (lambda x: -x, -3), + ], +) +def test_tabulate(func: Callable[[int], int], start: int) -> None: + assert _take(5, itextras.tabulate(func)) == _take(5, aextras.tabulate(func)) + assert _take(5, itextras.tabulate(func, start)) == _take( + 5, aextras.tabulate(func, start) + ) + + +@pytest.mark.parametrize( + ("n", "seq"), + [ + (3, "abcdefg"), + (0, "abcdefg"), + (10, "abcdefg"), + (5, ""), + ], +) +def test_tail(n: int, seq: str) -> None: + assert list(itextras.tail(n, seq)) == list(aextras.tail(n, seq)) + + +@pytest.mark.parametrize( + ("n", "seq"), + [ + (3, "abcdefg"), + (0, "abcdefg"), + (10, "abcdefg"), + (5, ""), + ], +) +def test_take(n: int, seq: str) -> None: + assert list(itextras.take(n, seq)) == list(aextras.take(n, seq)) From 3bd2dd91f89bcf9d35d43a820409221068cafa3e Mon Sep 17 00:00:00 2001 From: KB Sriram Date: Sun, 28 Apr 2024 23:56:38 -0700 Subject: [PATCH 3/3] Include more-itertools in optional_requirements This is only used in tests, for comparison. --- optional_requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/optional_requirements.txt b/optional_requirements.txt index d4e27c4..1856c06 100644 --- a/optional_requirements.txt +++ b/optional_requirements.txt @@ -1,3 +1,6 @@ # SPDX-FileCopyrightText: 2022 Alec Delaney, for Adafruit Industries # # SPDX-License-Identifier: Unlicense + +# For comparison when running tests +more-itertools