feat: add bincount to the specification #960

Open
wants to merge 1 commit into
base: main
from

Conversation

kgryte
Contributor

@kgryte kgryte commented Jun 12, 2025

This PR

  • resolves RFC: add bincount #812 by adding bincount to the specification for counting the number of occurrences of each element in an input integer array.
  • based on comparison data, only supports a minlength keyword argument.
  • allows weights to be both a positional and keyword argument.
  • allows weights to have any numeric data type, including complex.
  • specifies that, when weights is not provided, the output data type must be an integer data type. The data type rules follow other statistical functions (e.g., sum), where a minimum precision is required.
  • specifies that, when weights is provided, the output data type must have the same data type as weights.
  • specifies that an input array x should (not must) be a one-dimensional array. TensorFlow allows x to be multi-dimensional, and this PR chooses to provide wiggle room to allow supporting multi-dimensional arrays. One reason this isn't commonly supported in other libraries is that bincount has a data-dependent output shape; however, TF supports kwargs which allow specifying a static output shape, thus allowing bincount to generalize to multiple dimensions.
  • includes an admonition that bincount has a data-dependent output shape and thus certain libraries are allowed to omit this function if too difficult to implement. This follows similar practice for other APIs having data-dependent output shapes (e.g., unique*).
  • specifies that, if x contains negative values, behavior is unspecified and thus implementation-defined. NumPy raises an exception, while JAX clips.
  • specifies that weights have the same shape as x; although, IMO, this restriction is not necessary and could be relaxed to broadcast-compatibility.
  • specifies that the returned array should (not must) have shape (N,), where N = max(xp.max(x)+1, minlength). The use of should is intentional, in order to allow libraries such as JAX and TF to support other keyword arguments which may constrain the output shape.
  • specifies that the default value of minlength must be 0. According to docs, both CuPy and TF use a default of None.
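The rules in the list above can be sketched as a pure-Python reference. This is an illustration under stated assumptions, not the specification text: `bincount_ref` is a hypothetical name, plain lists stand in for one-dimensional arrays, floats stand in for the data type of `weights`, and negative values raise (the NumPy choice; the proposal leaves this implementation-defined).

```python
def bincount_ref(x, /, weights=None, *, minlength=0):
    """Hypothetical reference for the proposed semantics (1-D input only)."""
    x = list(x)
    if any(v < 0 for v in x):
        # Implementation-defined in the proposal; this sketch raises, as NumPy does.
        raise ValueError("x must contain only nonnegative integers")
    if weights is not None and len(weights) != len(x):
        raise ValueError("weights must have the same shape as x")
    # N = max(max(x) + 1, minlength); an empty x contributes no bins of its own.
    n = max(max(x, default=-1) + 1, minlength)
    if weights is None:
        out = [0] * n  # integer counts
        for v in x:
            out[v] += 1
    else:
        out = [0.0] * n  # assumption: floats stand in for the weights' dtype
        for v, w in zip(x, weights):
            out[v] += w
    return out


counts = bincount_ref([0, 1, 1, 3])
padded = bincount_ref([0, 1, 1, 3], minlength=6)
weighted = bincount_ref([0, 1, 1, 3], [0.5, 1.0, 1.0, 2.0])
```

Here `counts` has length `max(x) + 1 = 4`, while `padded` has length 6 because `minlength` exceeds `max(x) + 1`.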

Questions

@kgryte kgryte added this to the v2025 milestone Jun 12, 2025
@kgryte kgryte added the API extension Adds new functions or objects to the API. label Jun 12, 2025
@kgryte kgryte mentioned this pull request Jun 12, 2025
@NeilGirdhar

NeilGirdhar commented Jun 12, 2025

The shape of the output array for this function depends on the data values in x; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) can find this function difficult to implement without knowing the values in x. Accordingly, such libraries may choose to omit this function. See :ref:data-dependent-output-shapes section for more details.

Why not allow Jax to implement this function by adding an optional length argument and making it mandatory for this function to be in the Array API when length is provided? Most algorithms are amenable to that. Otherwise, you'd have to write Jax versions of the same algorithm (yuck).
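A minimal sketch of what a static `length` argument could look like, in pure Python. The function name `bincount_fixed` is hypothetical, and dropping values outside `[0, length)` is an assumption made for this illustration; the thread does not settle how out-of-range values should be handled.

```python
def bincount_fixed(x, /, *, length, weights=None):
    """Hypothetical bincount whose output size never depends on the data."""
    # The output is allocated before any value of x is inspected.
    out = [0] * length if weights is None else [0.0] * length
    for i, v in enumerate(x):
        # Assumption: values outside [0, length) are silently dropped.
        if 0 <= v < length:
            out[v] += 1 if weights is None else weights[i]
    return out


short = bincount_fixed([0, 1, 1, 5], length=4)
long_ = bincount_fixed([0, 1, 1, 5], length=8)
```

Because the output shape is fixed by `length` alone, a graph-building library could trace this function without knowing the values in `x`.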

@jakevdp

jakevdp commented Jun 12, 2025

  • Both JAX and TensorFlow allow x to be multi-dimensional

I don't think this is correct in the case of JAX. I confirmed that this errors in v0.5.x and v0.6.x:

import jax.numpy as jnp
jnp.bincount(jnp.arange(9).reshape(3, 3))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-1-1852179904> in <cell line: 0>()
      2 import jax.numpy as jnp
      3 print(jax.__version__)
----> 4 jnp.bincount(jnp.arange(9).reshape(3, 3))

/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/lax_numpy.py in bincount(x, weights, minlength, length)
   2994     raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}")
   2995   if np.ndim(x) != 1:
-> 2996     raise ValueError("only 1-dimensional input supported.")
   2997   minlength = core.concrete_or_error(operator.index, minlength,
   2998       "The error occurred because of argument 'minlength' of jnp.bincount.")

ValueError: only 1-dimensional input supported.

@NeilGirdhar

NeilGirdhar commented Jun 12, 2025

specifies that an input array x should (not must) be a one-dimensional array. Both JAX and TensorFlow allow x to be multi-dimensional, and this PR chooses to provide wiggle room to allow supporting multi-dimensional arrays. One reason this isn't commonly supported in other libraries is that bincount has a data-dependent output shape; however, both JAX and TF support kwargs which allow specifying a static output shape, thus allowing bincount to generalize to multiple dimensions.

It may be nice to treat additional dimensions as broadcasted dimensions like e.g., matrix_transpose. That is, suppose x has shape (*xs, xn) and you want to return length bins, you could return an array having shape (*xs, length)? This is just the broadcasted generalization of the 1-dimensional case.
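The batched generalization described above can be sketched with nested lists standing in for an array of shape `(*xs, xn)`. The name `bincount_batched` is hypothetical, and a static `length` with out-of-range values dropped is an assumption of this sketch.

```python
def bincount_batched(x, length):
    """Hypothetical batched bincount: counts along the last axis only."""
    # Leaf case: a flat list of integers is counted directly.
    if not x or not isinstance(x[0], list):
        out = [0] * length
        for v in x:
            if 0 <= v < length:  # assumption: out-of-range values are dropped
                out[v] += 1
        return out
    # Otherwise recurse over the leading (batch) dimension.
    return [bincount_batched(row, length) for row in x]


# Input of shape (2, 3) and length=4 gives output of shape (2, 4).
batched = bincount_batched([[0, 1, 1], [2, 2, 2]], 4)
```

Each row is counted independently, so the leading dimensions pass through unchanged and only the last axis is replaced by the bin axis.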

@rgommers
Member

Thanks @kgryte for the detailed proposal. This function is heavily used and present everywhere, so it makes sense to add from that perspective. The main question I have at the moment is whether there is a good alternative for bincount that doesn't suffer from the value-dependent output shape issue.

The function itself is pretty specific; for it to work on arbitrary integers you have to shift the values to a non-negative range starting at zero. I think that's usually not done; it's more common to use something like histogram or scipy.stats.binned_statistic in those cases. bincount is usually used for distributions of integers that are already in the [0, N) range with N not very large (otherwise the output size explodes).
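To make the shifting step concrete, here is a small pure-Python sketch. The helper `count_shifted` is hypothetical and not part of any library; it subtracts the minimum so that every value lands in a non-negative range before counting.

```python
def count_shifted(values):
    """Hypothetical helper: count integers that may be negative."""
    lo = min(values)
    # One bin per integer in [lo, max(values)], so the size depends on the spread.
    counts = [0] * (max(values) - lo + 1)
    for v in values:
        counts[v - lo] += 1  # shift so the smallest value maps to bin 0
    return lo, counts


offset, counts = count_shifted([-2, -1, -1, 3])
```

Bin `i` of `counts` corresponds to the value `offset + i`. The output size is the spread of the data plus one, which is why widely spread inputs make the result very large.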

Are we okay with allowing weights to be complex?

No. This seems super niche, and it's not supported by NumPy - so no reason to even consider this I'd think.

Are we okay with weights being both positional and a keyword argument?

I'd vote for keyword-only, since it's a very descriptive name and there's no real reason to use positional-only as far as I can tell.

@kgryte
Contributor Author

kgryte commented Jun 12, 2025

JAX supporting 1-dimensional arrays.

@jakevdp Would be good to update the docstring then for bincount, as currently it suggests that N-dimensional support is present. It is also not clear why JAX's docs state that the array must consist of positive integers, rather than nonnegative integers.


@kgryte
Contributor Author

kgryte commented Jun 12, 2025

I'd vote for keyword-only, since it's a very descriptive name and there's no real reason to use positional-only as far as I can tell.

@rgommers I am fine making the change to kwarg-only for guaranteed portability. sklearn includes both positional and kwarg usage, with the latter predominating. Similarly, from a search on sourcegraph, kwarg usage is more common, although positional usage of np.bincount is not uncommon.
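For illustration, a keyword-only `weights` could be expressed with the signature below. This is a sketch, not the merged specification: `bincount_kwonly` is a hypothetical name, the body is a plain-list stand-in that assumes nonnegative integers, and floats stand in for the data type of `weights`.

```python
def bincount_kwonly(x, /, *, weights=None, minlength=0):
    """Hypothetical keyword-only variant; x is a flat list of nonnegative ints."""
    n = max(max(x, default=-1) + 1, minlength)
    out = [0] * n if weights is None else [0.0] * n
    for i, v in enumerate(x):
        out[v] += 1 if weights is None else weights[i]
    return out


# Keyword usage works.
counts = bincount_kwonly([0, 1, 1], weights=[0.5, 1.0, 1.5])

# Positional usage of weights is rejected by the signature itself.
try:
    bincount_kwonly([0, 1, 1], [0.5, 1.0, 1.5])
    positional_ok = True
except TypeError:
    positional_ok = False
```

Code that currently passes `weights` positionally would need a one-line change to the keyword form, which also runs unchanged on libraries that accept both.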

@jakevdp

jakevdp commented Jun 12, 2025

@jakevdp Would be good to update the docstring then for bincount

Thanks for pointing that out – updated in jax-ml/jax#29441.

@betatim
Member

betatim commented Jun 13, 2025

Slight preference for weights as keyword only. It is easy enough to update in scikit-learn.

As a user it is annoying/tedious if different libraries require different treatment. The whole point of array API is to have something uniform instead of maintaining a big bunch of if statements. From that point of view it would be nice to have something that works for jax as well. This would mean making length an argument mentioned in the standard. Is it possible to make a generic recommendation for what to pass as a value? At least my first reaction to "you have to provide length" was "how would I know what it should be, can't you work it out for me far better than I can?" But maybe max(a) and len(a) cover the vast majority of cases for naive users/get people started and then they can ponder if there is a better value? Because if it is that easy and it would remove the need to special case libraries like jax ... that might be a tradeoff worth making? Or am I missing something?

@rgommers
Member

My takeaway from the discussion in the community meeting was similar to the question @betatim asked above. Can actually be split into two:

  • Should we add a length keyword? Seems to be the most common use case, and seems to make more sense than minlength (the way minlength is often used is "give me length N", rather than ">=N").
  • Can we leave out minlength? I.e., is it actually useful separately from length?
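One way to compare the two keywords is to emulate an exact `length` on top of `minlength` by padding and then truncating. The sketch below uses plain lists and hypothetical helper names (`bincount_min`, `bincount_exact`), and assumes nonnegative integer input.

```python
def bincount_min(x, *, minlength=0):
    """Hypothetical minlength-style bincount: output length is at least minlength."""
    n = max(max(x, default=-1) + 1, minlength)
    out = [0] * n
    for v in x:
        out[v] += 1
    return out


def bincount_exact(x, *, length):
    """Emulate an exact length: pad up via minlength, then truncate."""
    return bincount_min(x, minlength=length)[:length]


at_least = bincount_min([0, 1, 1, 5], minlength=4)  # grows to fit the value 5
exactly = bincount_exact([0, 1, 1, 5], length=4)    # the value 5 is cut off
```

The emulation only suits eager libraries: the intermediate result of `bincount_min` still has a data-dependent size, so it does not by itself remove the difficulty for graph-building libraries.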

Labels
API extension Adds new functions or objects to the API.
Development

Successfully merging this pull request may close these issues.

RFC: add bincount
5 participants