feat: add bincount to the specification #960
Why not allow JAX to implement this function by adding an optional |
I don't think this is correct in the case of JAX. I confirmed that this errors in v0.5.x and v0.6.x:

import jax.numpy as jnp
jnp.bincount(jnp.arange(9).reshape(3, 3))

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-1-1852179904> in <cell line: 0>()
      2 import jax.numpy as jnp
      3 print(jax.__version__)
----> 4 jnp.bincount(jnp.arange(9).reshape(3, 3))

/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/lax_numpy.py in bincount(x, weights, minlength, length)
   2994     raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}")
   2995   if np.ndim(x) != 1:
-> 2996     raise ValueError("only 1-dimensional input supported.")
   2997   minlength = core.concrete_or_error(operator.index, minlength,
   2998       "The error occurred because of argument 'minlength' of jnp.bincount.")

ValueError: only 1-dimensional input supported. |
It may be nice to treat additional dimensions as broadcasted dimensions like e.g., |
Thanks @kgryte for the detailed proposal. This function is heavily used and present everywhere, so it makes sense to add from that perspective. The main question I have at the moment is whether there is a good alternative for The function itself is pretty specific; for it to work you have to shift the values to a non-negative range just above zero. I think that that's usually not done; it's more common to use something like
No. This seems super niche, and it's not supported by NumPy - so no reason to even consider this I'd think.
I'd vote for keyword-only, since it's a very descriptive name and there's no real reason to use positional-only as far as I can tell. |
@jakevdp Would be good to update the docstring then for |
@rgommers I am fine making the change to kwarg-only for guaranteed portability. sklearn includes both positional and kwarg usage, with the latter being more predominant. Similarly, from a search on sourcegraph, kwarg usage is more common, although positional usage of |
Thanks for pointing that out – updated in jax-ml/jax#29441. |
Slight preference for weights as keyword only. It is easy enough to update in scikit-learn. As a user it is annoying/tedious if different libraries require different treatment. The whole point of array API is to have something uniform instead of maintaining a big bunch of |
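The portability point raised in this thread can be illustrated with plain Python signatures. The two functions below are hypothetical stand-ins written only for this illustration; they are not the signature of any library or of the specification. The sketch shows that the keyword spelling of `weights` works under either convention, while the positional spelling does not.

```python
def bincount_positional_or_keyword(x, weights=None, minlength=0):
    """Hypothetical stand-in: `weights` may be passed positionally or by keyword."""
    return weights


def bincount_keyword_only(x, /, *, weights=None, minlength=0):
    """Hypothetical stand-in: `weights` and `minlength` are keyword-only."""
    return weights


x = [0, 1, 1, 3]
w = [0.5, 1.0, 1.0, 2.0]

# The keyword spelling is accepted by both signatures, so it is the portable one.
portable_a = bincount_positional_or_keyword(x, weights=w)
portable_b = bincount_keyword_only(x, weights=w)

# The positional spelling is only accepted by the first signature.
try:
    bincount_keyword_only(x, w)
    positional_rejected = False
except TypeError:
    positional_rejected = True
```

Code that already passes `weights=...` by keyword would run unchanged under either choice, which is the argument made above for keyword-only.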
My takeaway from the discussion in the community meeting was similar to the question @betatim asked above. Can actually be split into two:
|
This PR

- adds `bincount` to the specification for counting the number of occurrences of each element in an input integer array.
- supports a `minlength` keyword argument.
- allows `weights` to be both a positional and keyword argument.
- allows `weights` to have any numeric data type, including complex.
- specifies that, when `weights` is not provided, the output data type must be an integer data type. The data type rules follow other statistical functions (e.g., `sum`), where a minimum precision is required.
- specifies that, when `weights` is provided, the output data type must have the same data type as `weights`.
- specifies that `x` should (not must) be a one-dimensional array. TensorFlow allows `x` to be multi-dimensional, and this PR chooses to provide wiggle room to allow supporting multi-dimensional arrays. One reason this isn't commonly supported in other libraries is that `bincount` has a data-dependent output shape; however, TF supports kwargs which allow specifying a static output shape, thus allowing `bincount` to generalize to multiple dimensions.
- notes that `bincount` has a data-dependent output shape and thus certain libraries are allowed to omit this function if too difficult to implement. This follows similar practice for other APIs having data-dependent output shapes (e.g., `unique*`).
- specifies that, if `x` contains negative values, behavior is unspecified and thus implementation-defined. NumPy raises an exception, while JAX clips.
- requires that `weights` have the same shape as `x`; although, IMO, this restriction is not necessary and could be relaxed to broadcast-compatibility.
- specifies that the output array should have shape `(N,)`, where `N = max(xp.max(x)+1, minlength)`. The use of should is intentional, in order to allow libraries such as JAX and TF to support other keyword arguments which may constrain the output shape.
- specifies that the default value of `minlength` must be `0`. According to docs, both CuPy and TF use a default of `None`.

Questions

- Should we allow `weights` to be complex? This does not appear to be supported in NumPy (ref: "bincount does not accept complex valued weights" numpy/numpy#23313 and "bincount fails for complex weights" numpy/numpy#16903); however, there isn't a technical reason why weights cannot be complex, as summation is well-defined for complex numbers.
- Similar to `sum` and other statistical functions, should we support an output `dtype` kwarg (e.g., in order to support overriding the default integer output type behavior)?
- Are we okay with `weights` being both positional and a keyword argument?
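For readers who want to see the proposed semantics in one place, here is a minimal pure-Python sketch of the rules in the PR description. It is not the specification text or any library's implementation. The name `bincount_reference` is hypothetical. Raising on negative values and returning `minlength` zeros for an empty `x` are assumptions of this sketch, since the description leaves negative values implementation-defined and does not discuss empty input.

```python
def bincount_reference(x, weights=None, *, minlength=0):
    """Pure-Python sketch of the described semantics (one-dimensional input only)."""
    x = list(x)
    if any(v < 0 for v in x):
        # Assumption: negative values are implementation-defined in the PR;
        # this sketch raises, as the description reports NumPy does.
        raise ValueError("x must contain only non-negative integers")
    if weights is not None:
        weights = list(weights)
        if len(weights) != len(x):
            raise ValueError("weights must have the same shape as x")

    # Output length N = max(max(x) + 1, minlength).
    # Assumption: an empty x yields an output of length minlength.
    n = max(max(x) + 1, minlength) if x else minlength

    if weights is None:
        # No weights: integer counts of each value.
        out = [0] * n
        for v in x:
            out[v] += 1
    else:
        # With weights: sum the weights per bin, starting from a zero of the
        # weights' own type (0.0 for float, 0j for complex).
        zero = type(weights[0])() if weights else 0
        out = [zero] * n
        for v, w in zip(x, weights):
            out[v] += w
    return out
```

With this sketch, `bincount_reference([0, 1, 1, 3])` returns `[1, 2, 0, 1]`, and passing `minlength=6` pads the result to `[1, 2, 0, 1, 0, 0]`. Complex weights sum like any other numeric type, which is the point behind the first question above.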