Description
Describe the issue:
When all the values in `alpha` are less than 0.1 and `alpha` ends in two or more zeros, the components of the variates returned by `dirichlet(alpha)` that correspond to those final zeros are `nan`.
For example,
```
In [18]: rng.dirichlet([0.09, 0.085, 0, 0, 0])
Out[18]: array([0.91378938, 0.08621062, nan, nan, nan])
```
When all the values in `alpha` are less than 0.1, `dirichlet` uses the algorithm that is based on the beta distribution. The problem occurs because `dirichlet` ends up calling the C function `random_beta` with both parameters `a` and `b` equal to 0, which makes `random_beta` return `nan`. Currently, the public API for `beta` requires both `a` and `b` to be positive; this is checked before the `beta` method calls the C function `random_beta`. The `dirichlet` code calls `random_beta` directly, so that validation is bypassed.
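For comparison, a quick check of the validation mentioned above: the public `Generator.beta` method rejects `a = b = 0` with a `ValueError` before any sampling happens, so the `nan` can only arise on the internal path that `dirichlet` takes. A minimal sketch; the seed is arbitrary and only the fact that an exception is raised is relied on, not its message.

```python
import numpy as np

rng = np.random.default_rng(12345)  # arbitrary seed

# The public method checks a > 0 and b > 0 before it reaches the C-level
# random_beta, so a=0, b=0 is rejected instead of producing nan.
rejected = False
try:
    rng.beta(0.0, 0.0)
except ValueError as exc:
    rejected = True
    print("rng.beta(0, 0) raised ValueError:", exc)
```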
It looks like `random_beta` handles one parameter being 0 in a manner consistent with the reasoning that allows `dirichlet` to have zeros in `alpha`. That's why `nan`s are produced only when there are two or more zeros at the end of `alpha`: that is the only case where `dirichlet` calls `random_beta` with both parameters equal to 0.
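To make the mechanism concrete, here is a small Python sketch that lists the `(a, b)` pairs handed to the beta sampler. It assumes the usual stick-breaking construction: for every component `j` except the last, `a = alpha[j]` and `b = sum(alpha[j+1:])`. The function name `beta_parameter_pairs` is hypothetical, and the sketch only mirrors the parameter bookkeeping, not NumPy's actual code. For the example above it finds two `beta(0, 0)` calls, at steps 2 and 3. Once the first of them yields `nan`, the remaining stick length becomes `nan` as well, so every later component is `nan`.

```python
def beta_parameter_pairs(alpha):
    """List the (a, b) pairs a stick-breaking Dirichlet sampler would pass
    to its beta sampler (hypothetical helper, assumes b = sum(alpha[j+1:]))."""
    k = len(alpha)
    tail = [0.0] * (k + 1)  # tail[j] == sum(alpha[j:]); tail[k] == 0
    for j in range(k - 1, -1, -1):
        tail[j] = tail[j + 1] + alpha[j]
    # The last component takes whatever is left of the stick, so there are k - 1 draws.
    return [(alpha[j], tail[j + 1]) for j in range(k - 1)]


pairs = beta_parameter_pairs([0.09, 0.085, 0, 0, 0])
both_zero = [j for j, (a, b) in enumerate(pairs) if a == 0 and b == 0]
print(pairs)
print("beta(0, 0) at steps:", both_zero)  # → beta(0, 0) at steps: [2, 3]
```

With only one trailing zero, for example `[0.09, 0.085, 0]`, the last pair is `(0.085, 0.0)`: only one parameter is 0, which matches the observation that a single trailing zero does not produce `nan`.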
It shouldn't be too difficult to fix `dirichlet` to handle the case where all the values in `alpha` are less than 0.1 and two or more values at the end of `alpha` are 0.
But there is a remaining question that was not brought up in #22547 or #23440: how should an input that is all zeros be handled? Some options (ordered by my preference):
- Raise a ValueError (i.e. disallow alpha being all zeros).
- Return a vector of zeros.
- Return a random unit vector (i.e. a vector with `len(alpha) - 1` zeros and a single 1 at a random position). (On second thought, there probably isn't any reasonable justification for this.)
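One way the small-`alpha` path could avoid `beta(0, 0)` is sketched below in Python, assuming the stick-breaking structure described above. For an all-zero `alpha` it takes the first option in the list (raise `ValueError`). The function name `dirichlet_small_alpha` is hypothetical and this is not NumPy's implementation; an actual fix would go in the Cython code. The idea: once the sum of the remaining `alpha` values is zero, the current component takes the rest of the stick, the later components stay 0, and no beta variate is drawn. A zero inside `alpha` simply gets no mass.

```python
import numpy as np


def dirichlet_small_alpha(rng, alpha):
    """Hypothetical stick-breaking Dirichlet sampler that tolerates zeros in
    alpha and raises ValueError when every value is zero (not NumPy's code)."""
    alpha = np.asarray(alpha, dtype=float)
    if alpha.ndim != 1 or alpha.size == 0:
        raise ValueError("alpha must be a non-empty 1-d sequence")
    if np.any(alpha < 0):
        raise ValueError("alpha < 0")
    if not np.any(alpha > 0):
        # Assumed choice: disallow an all-zero alpha.
        raise ValueError("at least one value in alpha must be greater than 0")

    k = alpha.size
    tail = np.cumsum(alpha[::-1])[::-1]  # tail[j] == alpha[j:].sum()
    out = np.zeros(k)
    remaining = 1.0  # length of the stick not yet assigned
    for j in range(k - 1):
        b = tail[j + 1]
        if b == 0.0:
            # All later alphas are zero: give component j the whole remaining
            # stick instead of calling beta(alpha[j], 0) or beta(0, 0).
            out[j] = remaining
            return out
        if alpha[j] == 0.0:
            continue  # zero concentration: component j gets no mass
        v = rng.beta(alpha[j], b)
        out[j] = remaining * v
        remaining *= 1.0 - v
    out[k - 1] = remaining
    return out


rng = np.random.default_rng(2023)  # arbitrary seed
print(dirichlet_small_alpha(rng, [0.09, 0.085, 0, 0, 0]))
```

Returning early when the tail sum hits zero also covers a single trailing zero, so the special case for `beta(a, 0)` is no longer needed on this path.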
Runtime information:
```
In [4]: import sys, numpy; print(numpy.__version__); print(sys.version)
1.25.0rc1+530.g2e668061db
3.11.4 (main, Jul 3 2023, 14:49:40) [GCC 11.3.0]

In [5]: print(numpy.show_runtime())
[{'numpy_version': '1.25.0rc1+530.g2e668061db',
  'python': '3.11.4 (main, Jul 3 2023, 14:49:40) [GCC 11.3.0]',
  'uname': uname_result(system='Linux', node='pop-os', release='6.2.6-76060206-generic', version='#202303130630~1689015125~22.04~ab2190e SMP PREEMPT_DYNAMIC Mon J', machine='x86_64')},
 {'simd_extensions': {'baseline': [], 'found': [], 'not_found': []}},
 {'architecture': 'Zen',
  'filepath': '/usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so',
  'internal_api': 'openblas',
  'num_threads': 24,
  'prefix': 'libopenblas',
  'threading_layer': 'pthreads',
  'user_api': 'blas',
  'version': '0.3.20'}]
None
```
Context for the issue:
No response