
BUG: random: dirichlet(alpha) can return nans in some cases. #24210

Closed
@WarrenWeckesser

Description


Describe the issue:

When all the values in alpha are less than 0.1, and alpha ends in two or more zeros, the components of the variates returned by dirichlet(alpha) corresponding to those final zeros will be nan.

For example,

In [18]: rng.dirichlet([0.09, 0.085, 0, 0, 0])
Out[18]: array([0.91378938, 0.08621062,        nan,        nan,        nan])

When all the values in alpha are less than 0.1, dirichlet uses the algorithm that is based on the beta distribution. The problem occurs because dirichlet ends up calling the C function random_beta with both parameters a and b being 0, which results in random_beta returning nan. Currently, the public API for beta requires both a and b to be positive; this is checked before the beta method calls the C function random_beta. The dirichlet code calls random_beta directly, so that validation is bypassed.
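To make the failure path concrete, here is a pure-Python sketch of a beta-based (stick-breaking) Dirichlet sampler. It assumes the standard construction in which draw j uses `a = alpha[j]` and `b = alpha[j+1] + ... + alpha[-1]`. The helper `beta_variate` is hypothetical: its zero-parameter branches are assumptions chosen to mimic the behavior described in this issue, not NumPy's C code, and `random.Random.betavariate` stands in for `random_beta` when both parameters are positive.

```python
import math
import random


def beta_variate(rng: random.Random, a: float, b: float) -> float:
    """Hypothetical stand-in for the C function random_beta.

    The zero-parameter branches are assumptions that mimic the behavior
    described in this issue; they are not copied from NumPy's source.
    """
    if a == 0 and b == 0:
        return math.nan  # both parameters 0: the case that yields nan
    if a == 0:
        return 0.0  # assumed: all mass goes to the "b" side
    if b == 0:
        return 1.0  # assumed: all mass goes to the "a" side
    return rng.betavariate(a, b)


def dirichlet_small_alpha(rng: random.Random, alpha: list[float]) -> list[float]:
    """Stick-breaking Dirichlet sampler built from beta variates (sketch)."""
    k = len(alpha)
    # tail_sum[j] = alpha[j] + alpha[j+1] + ... + alpha[k-1]; tail_sum[k] = 0
    tail_sum = [0.0] * (k + 1)
    for j in range(k - 1, -1, -1):
        tail_sum[j] = tail_sum[j + 1] + alpha[j]

    out = [0.0] * k
    remaining = 1.0  # length of the stick not yet broken off
    for j in range(k - 1):
        v = beta_variate(rng, alpha[j], tail_sum[j + 1])
        out[j] = remaining * v
        remaining *= 1.0 - v  # a nan here poisons every later component
    out[k - 1] = remaining
    return out
```

With `alpha = [0.09, 0.085, 0, 0, 0]`, this sketch reproduces the pattern from the example: the first two components are finite and sum to 1, and the last three are nan.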

It looks like random_beta handles a single zero parameter in a manner consistent with the reasoning that allows dirichlet to have zeros in alpha. That is why nans are produced only when alpha ends in two or more zeros: that is the only case in which dirichlet calls random_beta with both parameters equal to 0.
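The claim can be checked by listing the (a, b) parameters of each beta draw, again assuming the stick-breaking scheme in which draw j uses `a = alpha[j]` and `b = sum(alpha[j+1:])`, and the last component gets no draw of its own. A (0, 0) pair needs `alpha[j] == 0` and every later value to be 0 as well, with j not the last index, which means at least two trailing zeros. The first such draw returns nan, and the nan then propagates to all later components through the remaining stick length. The function names below are hypothetical.

```python
def beta_parameter_pairs(alpha: list[float]) -> list[tuple[float, float]]:
    """Return (a, b) for each beta draw made by a stick-breaking loop."""
    # Draw j uses a = alpha[j] and b = alpha[j+1] + ... + alpha[-1].
    return [(alpha[j], sum(alpha[j + 1:])) for j in range(len(alpha) - 1)]


def both_zero_draws(alpha: list[float]) -> list[int]:
    """Indices j at which the beta draw would have a == b == 0."""
    return [
        j
        for j, (a, b) in enumerate(beta_parameter_pairs(alpha))
        if a == 0 and b == 0
    ]


print(both_zero_draws([0.09, 0.085, 0, 0, 0]))  # two (0, 0) draws
print(both_zero_draws([0.09, 0.085, 0]))        # one trailing zero: none
print(both_zero_draws([0.09, 0, 0.085]))        # zero in the middle: none
```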

It shouldn't be too difficult to fix dirichlet to handle the case where all the values in alpha are less than 0.1 and two or more values at the end of alpha are 0.

But there is a remaining question that was not brought up in #22547 or #23440: how should an input that is all zeros be handled? Some options (ordered by my preference):

  • Raise a ValueError (i.e. disallow alpha being all zeros).
  • Return a vector of zeros.
  • Return a random unit vector (i.e. a vector with len(alpha) - 1 zeros and a single 1 at a random position). (On second thought, there probably isn't any reasonable justification for this.)
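One possible shape for a fix, sketched in pure Python under stated assumptions: it covers only the small-alpha (beta-based) branch, uses `random.Random.betavariate` in place of NumPy's `random_beta`, and follows the first option in the list (raise `ValueError` for an all-zero alpha). The idea is to stop breaking the stick as soon as every remaining alpha is 0, giving the rest of the stick to the current component, so no beta draw is ever made with both parameters equal to 0. The function name is hypothetical, and this is not necessarily the fix NumPy would adopt.

```python
import random


def dirichlet_small_alpha_fixed(rng: random.Random, alpha: list[float]) -> list[float]:
    """Stick-breaking Dirichlet sampler that tolerates zeros in alpha (sketch).

    Hypothetical: rejects an all-zero alpha (option 1) and stops breaking
    the stick once every remaining alpha is zero.
    """
    if any(a < 0 for a in alpha):
        raise ValueError("alpha must be nonnegative")
    if not any(a > 0 for a in alpha):
        raise ValueError("at least one value in alpha must be positive")

    k = len(alpha)
    # tail_sum[j] = alpha[j] + ... + alpha[k-1]; tail_sum[k] = 0
    tail_sum = [0.0] * (k + 1)
    for j in range(k - 1, -1, -1):
        tail_sum[j] = tail_sum[j + 1] + alpha[j]

    out = [0.0] * k
    remaining = 1.0  # length of the stick not yet broken off
    for j in range(k - 1):
        if tail_sum[j + 1] == 0:
            # Every later alpha is 0 (and alpha[j] > 0 here), so component j
            # takes the rest of the stick and later components stay 0.
            out[j] = remaining
            return out
        if alpha[j] == 0:
            continue  # a zero alpha gets a zero share of the stick
        v = rng.betavariate(alpha[j], tail_sum[j + 1])  # both parameters > 0
        out[j] = remaining * v
        remaining *= 1.0 - v
    out[k - 1] = remaining
    return out
```

For `[0.09, 0.085, 0, 0, 0]` this returns two finite components that sum to 1 followed by three exact zeros, and `[0, 0, 0]` raises `ValueError`. Choosing "return a vector of zeros" instead would only change the second check.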

Runtime information:

In [4]: import sys, numpy; print(numpy.__version__); print(sys.version)
1.25.0rc1+530.g2e668061db
3.11.4 (main, Jul  3 2023, 14:49:40) [GCC 11.3.0]

In [5]: print(numpy.show_runtime())
[{'numpy_version': '1.25.0rc1+530.g2e668061db',
  'python': '3.11.4 (main, Jul  3 2023, 14:49:40) [GCC 11.3.0]',
  'uname': uname_result(system='Linux', node='pop-os', release='6.2.6-76060206-generic', version='#202303130630~1689015125~22.04~ab2190e SMP PREEMPT_DYNAMIC Mon J', machine='x86_64')},
 {'simd_extensions': {'baseline': [], 'found': [], 'not_found': []}},
 {'architecture': 'Zen',
  'filepath': '/usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so',
  'internal_api': 'openblas',
  'num_threads': 24,
  'prefix': 'libopenblas',
  'threading_layer': 'pthreads',
  'user_api': 'blas',
  'version': '0.3.20'}]
None

Context for the issue:

No response
