From 25e612bc38503b471621afb298c275b18df4265f Mon Sep 17 00:00:00 2001 From: naman1608 <1608naman@gmail.com> Date: Mon, 28 Aug 2023 03:14:47 +0530 Subject: [PATCH] validation params for weighted mode --- sklearn/tests/test_public_functions.py | 1 + sklearn/utils/extmath.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index bd20de37d405e..26dc041f41429 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -321,6 +321,7 @@ def _check_function_param_validation( "sklearn.utils.resample", "sklearn.utils.safe_mask", "sklearn.utils.extmath.randomized_svd", + "sklearn.utils.extmath.weighted_mode", "sklearn.utils.class_weight.compute_class_weight", "sklearn.utils.class_weight.compute_sample_weight", "sklearn.utils.graph.single_source_shortest_path_length", diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index c2aa9d07e6635..493f8cf78b0b8 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -693,6 +693,14 @@ def _randomized_eigsh( return eigvals, eigvecs +@validate_params( + { + "a": ["array-like"], + "w": ["array-like"], + "axis": [Interval(Integral, 0, None, closed="left"), None], + }, + prefer_skip_nested_validation=True, +) def weighted_mode(a, w, *, axis=0): """Return an array of the weighted modal (most common) value in the passed array.