Skip to content

ENH: add setdiff1d #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
comment: false
ignore:
- "src/array_api_extra/_typing"
- "src/array_api_extra/_lib/_compat"
- "src/array_api_extra/_lib/_typing"
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
create_diagonal
expand_dims
kron
setdiff1d
sinc
```
173 changes: 98 additions & 75 deletions pixi.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = []
dependencies = ["typing-extensions"]

[project.optional-dependencies]
tests = [
Expand Down Expand Up @@ -64,6 +64,7 @@ platforms = ["linux-64", "osx-arm64", "win-64"]

[tool.pixi.dependencies]
python = ">=3.10.15,<3.14"
typing_extensions = ">=4.12.2,<4.13"

[tool.pixi.pypi-dependencies]
array-api-extra = { path = ".", editable = true }
Expand Down
4 changes: 3 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from __future__ import annotations

from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc

__version__ = "0.2.1.dev0"

# pylint: disable=duplicate-code
__all__ = [
"__version__",
"atleast_nd",
"cov",
"create_diagonal",
"expand_dims",
"kron",
"setdiff1d",
"sinc",
]
63 changes: 60 additions & 3 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from __future__ import annotations
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

import typing
import warnings

if typing.TYPE_CHECKING:
from ._typing import Array, ModuleType
from ._lib._typing import Array, ModuleType

__all__ = ["atleast_nd", "cov", "create_diagonal", "expand_dims", "kron", "sinc"]
from ._lib import _utils

__all__ = [
"atleast_nd",
"cov",
"create_diagonal",
"expand_dims",
"kron",
"setdiff1d",
"sinc",
]


def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
Expand Down Expand Up @@ -399,6 +409,53 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))


def setdiff1d(
    x1: Array, x2: Array, /, *, assume_unique: bool = False, xp: ModuleType
) -> Array:
    """
    Compute the set difference of two arrays.

    The result holds the values of `x1` that do not occur in `x2`.

    Parameters
    ----------
    x1 : array
        Array whose values are filtered.
    x2 : array
        Array of values to remove from `x1`.
    assume_unique : bool
        When ``True``, neither input is deduplicated: `x1` is only flattened
        and `x2` is used as given, which can speed up the calculation.
        Default is ``False``.
    xp : array_namespace
        The standard-compatible namespace for `x1` and `x2`.

    Returns
    -------
    res : array
        1D array of the values in `x1` that are not in `x2`. The result
        is sorted when `assume_unique` is ``False``, but otherwise only sorted
        if the input is sorted.

    Examples
    --------
    >>> import array_api_strict as xp
    >>> import array_api_extra as xpx

    >>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
    >>> x2 = xp.asarray([3, 4, 5, 6])
    >>> xpx.setdiff1d(x1, x2, xp=xp)
    Array([1, 2], dtype=array_api_strict.int64)

    """
    if not assume_unique:
        # Deduplicate both operands up front so that the membership test
        # below can always run with ``assume_unique=True``.
        x1 = xp.unique_values(x1)
        x2 = xp.unique_values(x2)
    else:
        # Caller vouches for uniqueness; only flatten `x1` to one dimension.
        x1 = xp.reshape(x1, (-1,))

    # Boolean mask selecting the entries of `x1` that are absent from `x2`.
    keep = _utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)
    return x1[keep]


def sinc(x: Array, /, *, xp: ModuleType) -> Array:
r"""
Return the normalized sinc function.
Expand Down
168 changes: 168 additions & 0 deletions src/array_api_extra/_lib/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
### Helpers borrowed from array-api-compat

from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

import inspect
import sys
import typing

from typing_extensions import override

if typing.TYPE_CHECKING:
from ._typing import Array, Device

__all__ = ["device"]


# Sentinel standing in for "the dask device" whenever the array backend is
# not the CPU (it is not easy to tell which device a dask array is on).
class _dask_device:  # pylint: disable=invalid-name
    """Type of the ``_DASK_DEVICE`` placeholder; its repr is ``DASK_DEVICE``."""

    @override
    def __repr__(self) -> str:
        return "DASK_DEVICE"


# Module-level instance returned by `device` for dask arrays.
_DASK_DEVICE = _dask_device()


# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in array objects of this library
# because this library just reuses the respective ndarray classes without
# wrapping or subclassing them. These helper functions can be used instead of
# the wrapper functions for libraries that need to support both NumPy/CuPy and
# other libraries that use devices.
def device(x: Array, /) -> Device:
    """
    Return the hardware device on which the data of `x` resides.

    This is equivalent to `x.device` according to the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__.
    The helper exists because some array libraries either lack the
    `device` attribute or expose it with an incompatible API.

    Parameters
    ----------
    x: array
        array instance from an array API compatible library.

    Returns
    -------
    out: device
        a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
        section of the array API specification).

    Notes
    -----

    For NumPy the device is always `"cpu"`. For Dask, the device is `"cpu"`
    when the array's metadata is a NumPy array and the special `DASK_DEVICE`
    object otherwise.

    """
    if _is_numpy_array(x):
        return "cpu"

    if _is_dask_array(x):
        # Peek at the metadata of the dask array to determine the backing type.
        try:
            import numpy as np  # pylint: disable=import-outside-toplevel

            if isinstance(x._meta, np.ndarray):  # pylint: disable=protected-access
                # Backed by NumPy, hence necessarily on the CPU.
                return "cpu"
        except ImportError:
            # NumPy is unavailable: fall through to the placeholder device.
            pass
        return _DASK_DEVICE

    if _is_jax_array(x):
        # JAX has .device() as a method, but it is being deprecated so that it
        # can become a property, in accordance with the standard. Both forms
        # are handled so this keeps working once JAX makes the flip.
        jax_device = x.device
        return jax_device() if inspect.ismethod(jax_device) else jax_device

    if _is_pydata_sparse_array(x):
        # `sparse` will gain `.device`, so prefer it when it is present.
        own_device = getattr(x, "device", None)
        if own_device is not None:
            return own_device
        # Everything but DOK has a `.data` attribute.
        try:
            constituent = x.data
        except AttributeError:
            return "cpu"
        else:
            # Report the device of the constituent array.
            return device(constituent)

    # Any other library is expected to follow the standard.
    return x.device


def _is_numpy_array(x: Array) -> bool:
    """Tell whether `x` is a NumPy array or scalar (JAX float0 arrays excluded)."""
    # Only consult NumPy when it has already been imported elsewhere, so that
    # this check never triggers the import by itself.
    if "numpy" in sys.modules:
        import numpy as np  # pylint: disable=import-outside-toplevel

        # TODO: Should we reject ndarray subclasses?
        if isinstance(x, (np.ndarray, np.generic)):
            # JAX zero-gradient arrays are ndarrays but must not count as NumPy.
            return not _is_jax_zero_gradient_array(x)
    return False


def _is_dask_array(x: Array) -> bool:
    """Tell whether `x` is an instance of ``dask.array.Array``."""
    # Only consult dask when it has already been imported elsewhere, so that
    # this check never triggers the import by itself.
    if "dask.array" in sys.modules:
        # pylint: disable=import-error, import-outside-toplevel
        import dask.array  # type: ignore[import-not-found]  # pyright: ignore[reportMissingImports]

        return isinstance(x, dask.array.Array)
    return False


def _is_jax_zero_gradient_array(x: Array) -> bool:
    """Tell whether `x` is a JAX zero-gradient array.

    Such an array is a NumPy ``ndarray`` whose dtype is ``jax.float0``.
    These arrays are a design quirk of Jax that may one day be removed.
    See https://github.com/google/jax/issues/20620.
    """
    # Both libraries must already be loaded; neither is imported on demand.
    if not all(name in sys.modules for name in ("numpy", "jax")):
        return False

    # pylint: disable=import-error, import-outside-toplevel
    import jax  # type: ignore[import-not-found]  # pyright: ignore[reportMissingImports]
    import numpy as np  # pylint: disable=import-outside-toplevel

    if not isinstance(x, np.ndarray):
        return False
    return x.dtype == jax.float0  # pyright: ignore[reportUnknownVariableType]


def _is_jax_array(x: Array) -> bool:
    """Tell whether `x` is a JAX array (zero-gradient arrays included)."""
    # Only consult jax when it has already been imported elsewhere, so that
    # this check never triggers the import by itself.
    if "jax" in sys.modules:
        # pylint: disable=import-error, import-outside-toplevel
        import jax  # pyright: ignore[reportMissingImports]

        if isinstance(x, jax.Array):
            return True
        # float0 arrays are NumPy ndarrays, yet they are treated as JAX arrays.
        return _is_jax_zero_gradient_array(x)
    return False


def _is_pydata_sparse_array(x: Array) -> bool:
    """Tell whether `x` is an array from the `sparse` package."""
    # Only consult sparse when it has already been imported elsewhere, so
    # that this check never triggers the import by itself.
    if "sparse" in sys.modules:
        # pylint: disable=import-error, import-outside-toplevel
        import sparse  # type: ignore[import-not-found]  # pyright: ignore[reportMissingImports]

        # TODO: Account for other backends.
        return isinstance(x, sparse.SparseArray)
    return False
10 changes: 10 additions & 0 deletions src/array_api_extra/_lib/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

from types import ModuleType
from typing import Any

# Names re-exported for use in annotations across the package.
__all__ = ["Array", "Device", "ModuleType"]

# Both aliases are plain `Any` for now; they are to be changed to a Protocol
# later (see data-apis/array-api#589).
Array = Any  # type: ignore[no-any-explicit]
Device = Any  # type: ignore[no-any-explicit]
65 changes: 65 additions & 0 deletions src/array_api_extra/_lib/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

import typing

if typing.TYPE_CHECKING:
from ._typing import Array, ModuleType

from . import _compat

__all__ = ["in1d"]


def in1d(
    x1: Array,
    x2: Array,
    /,
    *,
    assume_unique: bool = False,
    invert: bool = False,
    xp: ModuleType,
) -> Array:
    """
    Check whether each element of an array is also present in a second array.

    Returns a boolean array the same length as `x1` that is True
    where an element of `x1` is in `x2` and False otherwise.

    This function has been adapted using the original implementation
    present in numpy:
    https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758

    Parameters
    ----------
    x1 : array
        Array whose elements are tested for membership. Its length is taken
        from ``x1.shape[0]``.
    x2 : array
        Array of values against which the elements of `x1` are tested.
    assume_unique : bool
        If ``True``, the sort-based path skips deduplicating the inputs.
        Default is ``False``.
    invert : bool
        If ``True``, the returned mask is negated: True where an element of
        `x1` is *not* in `x2`. Default is ``False``.
    xp : array_namespace
        The standard-compatible namespace for `x1` and `x2`.

    Returns
    -------
    res : array
        Boolean mask with one entry per element of `x1`.
    """
    n1 = x1.shape[0]

    # Fast path: when `x2` is small relative to `x1`, accumulating one
    # elementwise comparison per value of `x2` is significantly faster than
    # the sort-based approach below.
    if x2.shape[0] < 10 * n1**0.145:
        # Go through the compat helper rather than `x1.device`: not every
        # supported array type exposes a standard `.device` attribute.
        x1_device = _compat.device(x1)
        if invert:
            mask = xp.ones(n1, dtype=xp.bool, device=x1_device)
            for a in x2:
                mask &= x1 != a
        else:
            mask = xp.zeros(n1, dtype=xp.bool, device=x1_device)
            for a in x2:
                mask |= x1 == a
        return mask

    rev_idx = xp.empty(0)  # placeholder
    if not assume_unique:
        # `rev_idx` maps every original element of `x1` to its unique value.
        x1, rev_idx = xp.unique_inverse(x1)
        x2 = xp.unique_values(x2)

    ar = xp.concat((x1, x2))
    device_ = _compat.device(ar)
    # We need this to be a stable sort, so that a value of `x1` always sorts
    # immediately before its duplicate coming from `x2`.
    order = xp.argsort(ar, stable=True)
    # Inverse permutation: brings sorted positions back to the order of `ar`.
    reverse_order = xp.argsort(order, stable=True)
    sar = xp.take(ar, order, axis=0)
    if sar.size >= 1:
        # An element is in both inputs iff it equals its successor in `sar`.
        bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
    else:
        # Keep the fallback on the same device as the arrays it is
        # concatenated with below.
        bool_ar = xp.asarray([not invert], device=device_)
    # The last sorted element has no successor, so it cannot be a match.
    flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
    ret = xp.take(flag, reverse_order, axis=0)

    if assume_unique:
        # Only the leading entries correspond to `x1`; the rest belong to `x2`.
        return ret[: x1.shape[0]]
    # Expand the per-unique-value result back to every element of `x1`.
    return xp.take(ret, rev_idx, axis=0)
9 changes: 0 additions & 9 deletions src/array_api_extra/_typing.py

This file was deleted.

Loading