diff --git a/codecov.yml b/codecov.yml index f76a2bc1..1a517b3a 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,3 +1,4 @@ comment: false ignore: - - "src/array_api_extra/_typing" + - "src/array_api_extra/_lib/_compat" + - "src/array_api_extra/_lib/_typing" diff --git a/docs/api-reference.md b/docs/api-reference.md index 1307b6fc..ffe68f24 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -11,5 +11,6 @@ create_diagonal expand_dims kron + setdiff1d sinc ``` diff --git a/pixi.lock b/pixi.lock index c4169a62..eaf64305 100644 --- a/pixi.lock +++ b/pixi.lock @@ -46,6 +46,7 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2 - pypi: . @@ -81,6 +82,7 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/xz-5.2.6-h57fd34a_0.tar.bz2 - pypi: . @@ -116,6 +118,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda @@ -163,12 +166,13 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-cov-6.0.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_100_cp313.conda + - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_101_cp313.conda - conda: https://prefix.dev/conda-forge/linux-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/linux-64/readline-8.2-h8228510_1.conda - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2 - pypi: . @@ -200,12 +204,13 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-cov-6.0.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-h75c3a9f_100_cp313.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-hbbac1ca_101_cp313.conda - conda: https://prefix.dev/conda-forge/osx-arm64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/osx-arm64/readline-8.2-h92ec313_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/xz-5.2.6-h57fd34a_0.tar.bz2 - pypi: . @@ -237,12 +242,13 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/pthreads-win32-2.9.1-h2466b09_4.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-cov-6.0.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_100_cp313.conda + - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_101_cp313.conda - conda: https://prefix.dev/conda-forge/win-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/win-64/tbb-2021.13.0-hc790b64_0.conda - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda @@ -273,10 +279,11 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/ncurses-6.5-he02047a_1.conda - conda: https://prefix.dev/conda-forge/linux-64/openssl-3.4.0-hb9d3cd8_0.conda - - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_100_cp313.conda + - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_101_cp313.conda - conda: https://prefix.dev/conda-forge/linux-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/linux-64/readline-8.2-h8228510_1.conda - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2 - pypi: . @@ -290,10 +297,11 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/libzlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ncurses-6.5-h7bae524_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/openssl-3.4.0-h39f12f2_0.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-h75c3a9f_100_cp313.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-hbbac1ca_101_cp313.conda - conda: https://prefix.dev/conda-forge/osx-arm64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/osx-arm64/readline-8.2-h92ec313_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/xz-5.2.6-h57fd34a_0.tar.bz2 - pypi: . @@ -306,9 +314,10 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda - conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda - conda: https://prefix.dev/conda-forge/win-64/openssl-3.4.0-h2466b09_0.conda - - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_100_cp313.conda + - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_101_cp313.conda - conda: https://prefix.dev/conda-forge/win-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda @@ -357,7 +366,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/idna-3.10-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/noarch/ipython-8.29.0-pyh707e725_0.conda + - conda: https://prefix.dev/conda-forge/noarch/ipython-8.30.0-pyh707e725_0.conda - conda: https://prefix.dev/conda-forge/noarch/isort-5.13.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/jedi-0.19.2-pyhff2d567_0.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.4-pyhd8ed1ab_0.conda @@ -403,7 +412,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-cov-6.0.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_100_cp313.conda + - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_101_cp313.conda - conda: https://prefix.dev/conda-forge/linux-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/noarch/pytz-2024.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/pyyaml-6.0.2-py313h536fd9c_1.conda @@ -480,7 +489,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/idna-3.10-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/noarch/ipython-8.29.0-pyh707e725_0.conda + - conda: https://prefix.dev/conda-forge/noarch/ipython-8.30.0-pyh707e725_0.conda - conda: https://prefix.dev/conda-forge/noarch/isort-5.13.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/jedi-0.19.2-pyhff2d567_0.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.4-pyhd8ed1ab_0.conda @@ -521,7 +530,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-cov-6.0.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-h75c3a9f_100_cp313.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-hbbac1ca_101_cp313.conda - conda: https://prefix.dev/conda-forge/osx-arm64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/noarch/pytz-2024.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/pyyaml-6.0.2-py313h20a7fcf_1.conda @@ -599,7 +608,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/intel-openmp-2024.2.1-h57928b3_1083.conda - - conda: https://prefix.dev/conda-forge/noarch/ipython-8.29.0-pyh7428d3b_0.conda + - conda: https://prefix.dev/conda-forge/noarch/ipython-8.30.0-pyh7428d3b_0.conda - conda: https://prefix.dev/conda-forge/noarch/isort-5.13.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/jedi-0.19.2-pyhff2d567_0.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.4-pyhd8ed1ab_0.conda @@ -637,7 +646,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pysocks-1.7.1-pyh0701188_6.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-cov-6.0.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_100_cp313.conda + - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_101_cp313.conda - conda: https://prefix.dev/conda-forge/win-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/noarch/pytz-2024.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/pyyaml-6.0.2-py313ha7868ed_1.conda @@ -732,7 +741,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pygments-2.18.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2 - - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_100_cp313.conda + - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_101_cp313.conda - conda: https://prefix.dev/conda-forge/linux-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/noarch/pytz-2024.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/pyyaml-6.0.2-py313h536fd9c_1.conda @@ -750,6 +759,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.2.3-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2 @@ -795,7 +805,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pygments-2.18.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2 - - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-h75c3a9f_100_cp313.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-hbbac1ca_101_cp313.conda - conda: https://prefix.dev/conda-forge/osx-arm64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/noarch/pytz-2024.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/pyyaml-6.0.2-py313h20a7fcf_1.conda @@ -813,6 +823,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.2.3-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/xz-5.2.6-h57fd34a_0.tar.bz2 @@ -856,7 +867,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pygments-2.18.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pysocks-1.7.1-pyh0701188_6.tar.bz2 - - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_100_cp313.conda + - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_101_cp313.conda - conda: https://prefix.dev/conda-forge/win-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/noarch/pytz-2024.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/pyyaml-6.0.2-py313ha7868ed_1.conda @@ -873,6 +884,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.2.3-pyhd8ed1ab_0.conda @@ -946,7 +958,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pylint-3.3.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_100_cp313.conda + - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_101_cp313.conda - conda: https://prefix.dev/conda-forge/linux-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/linux-64/pyyaml-6.0.2-py313h536fd9c_1.conda - conda: https://prefix.dev/conda-forge/linux-64/readline-8.2-h8228510_1.conda @@ -1007,7 +1019,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pylint-3.3.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-h75c3a9f_100_cp313.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-hbbac1ca_101_cp313.conda - conda: https://prefix.dev/conda-forge/osx-arm64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/osx-arm64/pyyaml-6.0.2-py313h20a7fcf_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/readline-8.2-h92ec313_1.conda @@ -1068,7 +1080,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pycparser-2.22-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pylint-3.3.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_100_cp313.conda + - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_101_cp313.conda - conda: https://prefix.dev/conda-forge/win-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/win-64/pyyaml-6.0.2-py313ha7868ed_1.conda - conda: https://prefix.dev/conda-forge/noarch/setuptools-75.6.0-pyhff2d567_1.conda @@ -1130,12 +1142,13 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-cov-6.0.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_100_cp313.conda + - conda: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_101_cp313.conda - conda: https://prefix.dev/conda-forge/linux-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/linux-64/readline-8.2-h8228510_1.conda - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2 - pypi: . @@ -1167,12 +1180,13 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-cov-6.0.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-h75c3a9f_100_cp313.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-hbbac1ca_101_cp313.conda - conda: https://prefix.dev/conda-forge/osx-arm64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/osx-arm64/readline-8.2-h92ec313_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/xz-5.2.6-h57fd34a_0.tar.bz2 - pypi: . @@ -1204,12 +1218,13 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/pthreads-win32-2.9.1-h2466b09_4.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.3-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pytest-cov-6.0.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_100_cp313.conda + - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_101_cp313.conda - conda: https://prefix.dev/conda-forge/win-64/python_abi-3.13-5_cp313.conda - conda: https://prefix.dev/conda-forge/win-64/tbb-2021.13.0-hc790b64_0.conda - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda @@ -1274,8 +1289,9 @@ packages: name: array-api-extra version: 0.2.1.dev0 path: . - sha256: 8df2522a223b90e904144fd84d4a1c3119b3e3eaa1a17f12d3fa98070195d147 + sha256: 8e5573eb0fdab83a4df5f31277f541ecbf3b046edb8730d82dc1580593ced2d6 requires_dist: + - typing-extensions - furo>=2023.8.17 ; extra == 'docs' - myst-parser>=0.13 ; extra == 'docs' - sphinx-autodoc-typehints ; extra == 'docs' @@ -1794,6 +1810,7 @@ packages: arch: x86_64 platform: win license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 319434 @@ -1815,6 +1832,7 @@ packages: arch: x86_64 platform: linux license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 294004 @@ -1836,6 +1854,7 @@ packages: arch: arm64 platform: osx license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 293158 @@ -1857,6 +1876,7 @@ packages: arch: x86_64 platform: linux license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 371219 @@ -1878,6 +1898,7 @@ packages: arch: arm64 platform: osx license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 370706 @@ -1900,6 +1921,7 @@ packages: arch: x86_64 platform: win license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 396450 @@ -2182,13 +2204,13 @@ packages: timestamp: 1723739573141 - kind: conda name: ipython - version: 8.29.0 + version: 8.30.0 build: pyh707e725_0 subdir: noarch noarch: python - url: https://prefix.dev/conda-forge/noarch/ipython-8.29.0-pyh707e725_0.conda - sha256: 606723272a208cca1036852e04fbb61741b78451784746e75edd1becb70347d2 - md5: 56db21d7d51410fcfbfeca3d1a6b4269 + url: https://prefix.dev/conda-forge/noarch/ipython-8.30.0-pyh707e725_0.conda + sha256: 65cdc105e5effea2943d3979cc1592590c923a589009b484d07672faaf047af1 + md5: 5d6e5cb3a4b820f61b2073f0ad5431f1 depends: - __unix - decorator @@ -2207,17 +2229,17 @@ packages: license_family: BSD purls: - pkg:pypi/ipython?source=hash-mapping - size: 599356 - timestamp: 1729866495921 + size: 600248 + timestamp: 1732897026255 - kind: conda name: ipython - version: 8.29.0 + version: 8.30.0 build: pyh7428d3b_0 subdir: noarch noarch: python - url: https://prefix.dev/conda-forge/noarch/ipython-8.29.0-pyh7428d3b_0.conda - sha256: 2208dbe96e94ba653c4e0a5f302e36f16df73eec1968cfb85eff2d9775c9ced1 - md5: 9dc505b3569b4c26cffc241c50695f75 + url: https://prefix.dev/conda-forge/noarch/ipython-8.30.0-pyh7428d3b_0.conda + sha256: 94ee8215bd1f614c9c984437b184e8dbe61a4014eb5813c276e3dcb18aaa7f46 + md5: 6cdaebbc9e3feb2811eb9f52ed0b89e1 depends: - __win - colorama @@ -2236,8 +2258,8 @@ packages: license_family: BSD purls: - pkg:pypi/ipython?source=hash-mapping - size: 600237 - timestamp: 1729866942619 + size: 600466 + timestamp: 1732897444811 - kind: conda name: isort version: 5.13.2 @@ -4125,83 +4147,83 @@ packages: - kind: conda name: python version: 3.13.0 - build: h75c3a9f_100_cp313 - build_number: 100 - subdir: osx-arm64 - url: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-h75c3a9f_100_cp313.conda - sha256: be9464399b76ae1fef77853eed70267ef657a98a5f69f7df012b7c6a34792151 - md5: 94ae22ea862d056ad1bc095443d02d73 + build: h9ebbce0_101_cp313 + build_number: 101 + subdir: linux-64 + url: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_101_cp313.conda + sha256: 66a7997b24b2dca636df11402abec7bd2199ddf6971eb47a3ee6b1d27d4faee9 + md5: f4fea9d5bb3f2e61a39950a7ab70ee4e depends: - - __osx >=11.0 + - __glibc >=2.17,<3.0.a0 - bzip2 >=1.0.8,<2.0a0 - - libexpat >=2.6.3,<3.0a0 + - ld_impl_linux-64 >=2.36.1 + - libexpat >=2.6.4,<3.0a0 - libffi >=3.4,<4.0a0 + - libgcc >=13 - libmpdec >=4.0.0,<5.0a0 - - libsqlite >=3.46.1,<4.0a0 + - libsqlite >=3.47.0,<4.0a0 + - libuuid >=2.38.1,<3.0a0 - libzlib >=1.3.1,<2.0a0 - ncurses >=6.5,<7.0a0 - - openssl >=3.3.2,<4.0a0 + - openssl >=3.4.0,<4.0a0 - python_abi 3.13.* *_cp313 - readline >=8.2,<9.0a0 - tk >=8.6.13,<8.7.0a0 - tzdata - xz >=5.2.6,<6.0a0 - arch: arm64 - platform: osx + arch: x86_64 + platform: linux license: Python-2.0 purls: [] - size: 12804842 - timestamp: 1729168680448 + size: 33054218 + timestamp: 1732736838043 - kind: conda name: python version: 3.13.0 - build: h9ebbce0_100_cp313 - build_number: 100 - subdir: linux-64 - url: https://prefix.dev/conda-forge/linux-64/python-3.13.0-h9ebbce0_100_cp313.conda - sha256: 6ab5179679f0909db828d8316f3b8b379014a82404807310fe7df5a6cf303646 - md5: 08e9aef080f33daeb192b2ddc7e4721f + build: hbbac1ca_101_cp313 + build_number: 101 + subdir: osx-arm64 + url: https://prefix.dev/conda-forge/osx-arm64/python-3.13.0-hbbac1ca_101_cp313.conda + sha256: 742544a4cf9a10cf2c16d35d96fb696c27d58b9df0cc29fbef5629283aeca941 + md5: e972e146a1e0cfb1f26da42cb6f6648c depends: - - __glibc >=2.17,<3.0.a0 + - __osx >=11.0 - bzip2 >=1.0.8,<2.0a0 - - ld_impl_linux-64 >=2.36.1 - - libexpat >=2.6.3,<3.0a0 + - libexpat >=2.6.4,<3.0a0 - libffi >=3.4,<4.0a0 - - libgcc >=13 - libmpdec >=4.0.0,<5.0a0 - - libsqlite >=3.46.1,<4.0a0 - - libuuid >=2.38.1,<3.0a0 + - libsqlite >=3.47.0,<4.0a0 - libzlib >=1.3.1,<2.0a0 - ncurses >=6.5,<7.0a0 - - openssl >=3.3.2,<4.0a0 + - openssl >=3.4.0,<4.0a0 - python_abi 3.13.* *_cp313 - readline >=8.2,<9.0a0 - tk >=8.6.13,<8.7.0a0 - tzdata - xz >=5.2.6,<6.0a0 - arch: x86_64 - platform: linux + arch: arm64 + platform: osx license: Python-2.0 purls: [] - size: 33112481 - timestamp: 1728419573472 + size: 12806496 + timestamp: 1732735488999 - kind: conda name: python version: 3.13.0 - build: hf5aa216_100_cp313 - build_number: 100 + build: hf5aa216_101_cp313 + build_number: 101 subdir: win-64 - url: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_100_cp313.conda - sha256: 18f3f0bd514c9101d38d57835b2d027958f3ae4b3b65c22d187a857aa26b3a08 - md5: 3c2f7ad3f598480fe2a09e4e33cb1a2a + url: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_101_cp313.conda + sha256: b8eba57bd86c7890b27e67b477b52b5bd547946c354f29b9dbbc70ad83f2863b + md5: 158d6077a635cf0c0c23bec3955a4833 depends: - bzip2 >=1.0.8,<2.0a0 - - libexpat >=2.6.3,<3.0a0 + - libexpat >=2.6.4,<3.0a0 - libffi >=3.4,<4.0a0 - libmpdec >=4.0.0,<5.0a0 - - libsqlite >=3.46.1,<4.0a0 + - libsqlite >=3.47.0,<4.0a0 - libzlib >=1.3.1,<2.0a0 - - openssl >=3.3.2,<4.0a0 + - openssl >=3.4.0,<4.0a0 - python_abi 3.13.* *_cp313 - tk >=8.6.13,<8.7.0a0 - tzdata @@ -4213,8 +4235,8 @@ packages: platform: win license: Python-2.0 purls: [] - size: 16641177 - timestamp: 1728417810202 + size: 16697406 + timestamp: 1732734725404 - kind: conda name: python_abi version: '3.10' @@ -5108,6 +5130,7 @@ packages: - platformdirs >=3.9.1,<5 - python >=3.9 license: MIT + license_family: MIT purls: - pkg:pypi/virtualenv?source=hash-mapping size: 3350255 diff --git a/pyproject.toml b/pyproject.toml index 5bc15b9e..ec5fcbf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Typing :: Typed", ] dynamic = ["version"] -dependencies = [] +dependencies = ["typing-extensions"] [project.optional-dependencies] tests = [ @@ -64,6 +64,7 @@ platforms = ["linux-64", "osx-arm64", "win-64"] [tool.pixi.dependencies] python = ">=3.10.15,<3.14" +typing_extensions = ">=4.12.2,<4.13" [tool.pixi.pypi-dependencies] array-api-extra = { path = ".", editable = true } diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index b7751594..46b1388f 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,9 +1,10 @@ from __future__ import annotations -from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc +from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc __version__ = "0.2.1.dev0" +# pylint: disable=duplicate-code __all__ = [ "__version__", "atleast_nd", @@ -11,5 +12,6 @@ "create_diagonal", "expand_dims", "kron", + "setdiff1d", "sinc", ] diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index a305bfb5..4062c56f 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -1,12 +1,22 @@ -from __future__ import annotations +from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 import typing import warnings if typing.TYPE_CHECKING: - from ._typing import Array, ModuleType + from ._lib._typing import Array, ModuleType -__all__ = ["atleast_nd", "cov", "create_diagonal", "expand_dims", "kron", "sinc"] +from ._lib import _utils + +__all__ = [ + "atleast_nd", + "cov", + "create_diagonal", + "expand_dims", + "kron", + "setdiff1d", + "sinc", +] def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array: @@ -399,6 +409,53 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array: return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape))) +def setdiff1d( + x1: Array, x2: Array, /, *, assume_unique: bool = False, xp: ModuleType +) -> Array: + """ + Find the set difference of two arrays. + + Return the unique values in `x1` that are not in `x2`. + + Parameters + ---------- + x1 : array + Input array. + x2 : array + Input comparison array. + assume_unique : bool + If ``True``, the input arrays are both assumed to be unique, which + can speed up the calculation. Default is ``False``. + xp : array_namespace + The standard-compatible namespace for `x1` and `x2`. + + Returns + ------- + res : array + 1D array of values in `x1` that are not in `x2`. The result + is sorted when `assume_unique` is ``False``, but otherwise only sorted + if the input is sorted. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + + >>> x1 = xp.asarray([1, 2, 3, 2, 4, 1]) + >>> x2 = xp.asarray([3, 4, 5, 6]) + >>> xpx.setdiff1d(x1, x2, xp=xp) + Array([1, 2], dtype=array_api_strict.int64) + + """ + + if assume_unique: + x1 = xp.reshape(x1, (-1,)) + else: + x1 = xp.unique_values(x1) + x2 = xp.unique_values(x2) + return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] + + def sinc(x: Array, /, *, xp: ModuleType) -> Array: r""" Return the normalized sinc function. diff --git a/src/array_api_extra/_lib/_compat.py b/src/array_api_extra/_lib/_compat.py new file mode 100644 index 00000000..b9577ff3 --- /dev/null +++ b/src/array_api_extra/_lib/_compat.py @@ -0,0 +1,168 @@ +### Helpers borrowed from array-api-compat + +from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 + +import inspect +import sys +import typing + +from typing_extensions import override + +if typing.TYPE_CHECKING: + from ._typing import Array, Device + +__all__ = ["device"] + + +# Placeholder object to represent the dask device +# when the array backend is not the CPU. +# (since it is not easy to tell which device a dask array is on) +class _dask_device: # pylint: disable=invalid-name + @override + def __repr__(self) -> str: + return "DASK_DEVICE" + + +_DASK_DEVICE = _dask_device() + + +# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray +# or cupy.ndarray. They are not included in array objects of this library +# because this library just reuses the respective ndarray classes without +# wrapping or subclassing them. These helper functions can be used instead of +# the wrapper functions for libraries that need to support both NumPy/CuPy and +# other libraries that use devices. +def device(x: Array, /) -> Device: + """ + Hardware device the array data resides on. + + This is equivalent to `x.device` according to the `standard + `__. + This helper is included because some array libraries either do not have + the `device` attribute or include it with an incompatible API. + + Parameters + ---------- + x: array + array instance from an array API compatible library. + + Returns + ------- + out: device + a ``device`` object (see the `Device Support `__ + section of the array API specification). + + Notes + ----- + + For NumPy the device is always `"cpu"`. For Dask, the device is always a + special `DASK_DEVICE` object. + + See Also + -------- + + to_device : Move array data to a different device. + + """ + if _is_numpy_array(x): + return "cpu" + if _is_dask_array(x): + # Peek at the metadata of the jax array to determine type + try: + import numpy as np # pylint: disable=import-outside-toplevel + + if isinstance(x._meta, np.ndarray): # pylint: disable=protected-access + # Must be on CPU since backed by numpy + return "cpu" + except ImportError: + pass + return _DASK_DEVICE + if _is_jax_array(x): + # JAX has .device() as a method, but it is being deprecated so that it + # can become a property, in accordance with the standard. In order for + # this function to not break when JAX makes the flip, we check for + # both here. + if inspect.ismethod(x.device): + return x.device() + return x.device + if _is_pydata_sparse_array(x): + # `sparse` will gain `.device`, so check for this first. + x_device = getattr(x, "device", None) + if x_device is not None: + return x_device + # Everything but DOK has this attr. + try: + inner = x.data + except AttributeError: + return "cpu" + # Return the device of the constituent array + return device(inner) + return x.device + + +def _is_numpy_array(x: Array) -> bool: + """Return True if `x` is a NumPy array.""" + # Avoid importing NumPy if it isn't already + if "numpy" not in sys.modules: + return False + + import numpy as np # pylint: disable=import-outside-toplevel + + # TODO: Should we reject ndarray subclasses? + return isinstance(x, (np.ndarray, np.generic)) and not _is_jax_zero_gradient_array( + x + ) + + +def _is_dask_array(x: Array) -> bool: + """Return True if `x` is a dask.array Array.""" + # Avoid importing dask if it isn't already + if "dask.array" not in sys.modules: + return False + + # pylint: disable=import-error, import-outside-toplevel + import dask.array # type: ignore[import-not-found] # pyright: ignore[reportMissingImports] + + return isinstance(x, dask.array.Array) + + +def _is_jax_zero_gradient_array(x: Array) -> bool: + """Return True if `x` is a zero-gradient array. + + These arrays are a design quirk of Jax that may one day be removed. + See https://github.com/google/jax/issues/20620. + """ + if "numpy" not in sys.modules or "jax" not in sys.modules: + return False + + # pylint: disable=import-error, import-outside-toplevel + import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports] + import numpy as np # pylint: disable=import-outside-toplevel + + return isinstance(x, np.ndarray) and x.dtype == jax.float0 # pyright: ignore[reportUnknownVariableType] + + +def _is_jax_array(x: Array) -> bool: + """Return True if `x` is a JAX array.""" + # Avoid importing jax if it isn't already + if "jax" not in sys.modules: + return False + + # pylint: disable=import-error, import-outside-toplevel + import jax # pyright: ignore[reportMissingImports] + + return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) + + +def _is_pydata_sparse_array(x: Array) -> bool: + """Return True if `x` is an array from the `sparse` package.""" + + # Avoid importing jax if it isn't already + if "sparse" not in sys.modules: + return False + + # pylint: disable=import-error, import-outside-toplevel + import sparse # type: ignore[import-not-found] # pyright: ignore[reportMissingImports] + + # TODO: Account for other backends. + return isinstance(x, sparse.SparseArray) diff --git a/src/array_api_extra/_lib/_typing.py b/src/array_api_extra/_lib/_typing.py new file mode 100644 index 00000000..13079807 --- /dev/null +++ b/src/array_api_extra/_lib/_typing.py @@ -0,0 +1,10 @@ +from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 + +from types import ModuleType +from typing import Any + +# To be changed to a Protocol later (see data-apis/array-api#589) +Array = Any # type: ignore[no-any-explicit] +Device = Any # type: ignore[no-any-explicit] + +__all__ = ["Array", "Device", "ModuleType"] diff --git a/src/array_api_extra/_lib/_utils.py b/src/array_api_extra/_lib/_utils.py new file mode 100644 index 00000000..bf65340e --- /dev/null +++ b/src/array_api_extra/_lib/_utils.py @@ -0,0 +1,65 @@ +from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 + +import typing + +if typing.TYPE_CHECKING: + from ._typing import Array, ModuleType + +from . import _compat + +__all__ = ["in1d"] + + +def in1d( + x1: Array, + x2: Array, + /, + *, + assume_unique: bool = False, + invert: bool = False, + xp: ModuleType, +) -> Array: + """Checks whether each element of an array is also present in a + second array. + + Returns a boolean array the same length as `x1` that is True + where an element of `x1` is in `x2` and False otherwise. + + This function has been adapted using the original implementation + present in numpy: + https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758 + """ + + # This code is run to make the code significantly faster + if x2.shape[0] < 10 * x1.shape[0] ** 0.145: + if invert: + mask = xp.ones(x1.shape[0], dtype=xp.bool, device=x1.device) + for a in x2: + mask &= x1 != a + else: + mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=x1.device) + for a in x2: + mask |= x1 == a + return mask + + rev_idx = xp.empty(0) # placeholder + if not assume_unique: + x1, rev_idx = xp.unique_inverse(x1) + x2 = xp.unique_values(x2) + + ar = xp.concat((x1, x2)) + device_ = _compat.device(ar) + # We need this to be a stable sort. + order = xp.argsort(ar, stable=True) + reverse_order = xp.argsort(order, stable=True) + sar = xp.take(ar, order, axis=0) + if sar.size >= 1: + bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1] + else: + bool_ar = xp.asarray([False]) if invert else xp.asarray([True]) + flag = xp.concat((bool_ar, xp.asarray([invert], device=device_))) + ret = xp.take(flag, reverse_order, axis=0) + + if assume_unique: + return ret[: x1.shape[0]] + return xp.take(ret, rev_idx, axis=0) diff --git a/src/array_api_extra/_typing.py b/src/array_api_extra/_typing.py deleted file mode 100644 index 5584d511..00000000 --- a/src/array_api_extra/_typing.py +++ /dev/null @@ -1,9 +0,0 @@ -from __future__ import annotations - -from types import ModuleType -from typing import Any - -# To be changed to a Protocol later (see data-apis/array-api#589) -Array = Any # type: ignore[no-any-explicit] - -__all__ = ["Array", "ModuleType"] diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 827da9c4..36411958 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 import contextlib import typing @@ -10,10 +10,18 @@ import pytest from numpy.testing import assert_allclose, assert_array_equal, assert_equal -from array_api_extra import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc +from array_api_extra import ( + atleast_nd, + cov, + create_diagonal, + expand_dims, + kron, + setdiff1d, + sinc, +) if typing.TYPE_CHECKING: - from array_api_extra._typing import Array + from array_api_extra._lib._typing import Array class TestAtLeastND: @@ -263,6 +271,34 @@ def test_positive_negative_repeated(self): expand_dims(a, axis=(3, -3), xp=xp) +class TestSetDiff1D: + def test_setdiff1d(self): + x1 = xp.asarray([6, 5, 4, 7, 1, 2, 7, 4]) + x2 = xp.asarray([2, 4, 3, 3, 2, 1, 5]) + + expected = xp.asarray([6, 7]) + actual = setdiff1d(x1, x2, xp=xp) + assert_array_equal(actual, expected) + + x1 = xp.arange(21) + x2 = xp.arange(19) + expected = xp.asarray([19, 20]) + actual = setdiff1d(x1, x2, xp=xp) + assert_array_equal(actual, expected) + + assert_array_equal(setdiff1d(xp.empty(0), xp.empty(0), xp=xp), xp.empty(0)) + x1 = xp.empty(0, dtype=xp.uint32) + x2 = x1 + assert_equal(setdiff1d(x1, x2, xp=xp).dtype, xp.uint32) + + def test_assume_unique(self): + x1 = xp.asarray([3, 2, 1]) + x2 = xp.asarray([7, 5, 2]) + expected = xp.asarray([3, 1]) + actual = setdiff1d(x1, x2, assume_unique=True, xp=xp) + assert_array_equal(actual, expected) + + class TestSinc: def test_simple(self): assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..a34ec56f --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,24 @@ +from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 + +import typing + +# data-apis/array-api-strict#6 +import array_api_strict as xp # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] +import pytest +from numpy.testing import assert_array_equal + +from array_api_extra._lib._utils import in1d + +if typing.TYPE_CHECKING: + from array_api_extra._lib._typing import Array + + +# some test coverage already provided by TestSetDiff1D +class TestIn1D: + # cover both code paths + @pytest.mark.parametrize("x2", [xp.arange(9), xp.arange(15)]) + def test_no_invert_assume_unique(self, x2: Array): + x1 = xp.asarray([3, 8, 20]) + expected = xp.asarray([True, True, False]) + actual = in1d(x1, x2, xp=xp) + assert_array_equal(actual, expected) diff --git a/tests/test_version.py b/tests/test_version.py index 21d43d17..1b20232f 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 import importlib.metadata