Skip to content

Commit 89b7e9a

Browse files
committed
Used object() in check_X_y to check for sample_weight
1 parent 777ea36 commit 89b7e9a

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

sklearn/utils/validation.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from ..exceptions import NonBLASDotWarning as _NonBLASDotWarning
2222
from ..exceptions import NotFittedError as _NotFittedError
2323

24+
NOT_SPECIFIED = object()
25+
2426

2527
@deprecated("DataConversionWarning has been moved into the sklearn.exceptions"
2628
" module. It will not be available here from version 0.19")
@@ -429,7 +431,7 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
429431
copy=False, force_all_finite=True, ensure_2d=True,
430432
allow_nd=False, multi_output=False, ensure_min_samples=1,
431433
ensure_min_features=1, y_numeric=False,
432-
warn_on_dtype=False, estimator=None, sample_weight=False):
434+
warn_on_dtype=False, estimator=None, sample_weight=NOT_SPECIFIED):
433435
"""Input validation for standard estimators.
434436
435437
Checks X and y for consistent length, enforces X 2d and y 1d.
@@ -503,6 +505,10 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
503505
estimator : str or estimator instance (default=None)
504506
If passed, include the name of the estimator in warning messages.
505507
508+
sample_weight : nd-array, list
509+
Ensures sample_weight is 1-d.
510+
511+
506512
Returns
507513
-------
508514
X_converted : object
@@ -525,16 +531,17 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
525531

526532
check_consistent_length(X, y)
527533

528-
if sample_weight is None:
534+
if sample_weight is NOT_SPECIFIED:
535+
return X, y
536+
537+
elif sample_weight is None:
529538
return X, y, sample_weight
530539

531-
elif sample_weight is not False:
540+
else:
532541
sample_weight = check_array(sample_weight, ensure_2d=False)
533542
check_consistent_length(y, sample_weight)
534543
return X, y, sample_weight
535544

536-
return X, y
537-
538545

539546
def column_or_1d(y, warn=False):
540547
""" Ravel column or 1d numpy array, else raises an error

0 commit comments

Comments
 (0)