21
21
from ..exceptions import NonBLASDotWarning as _NonBLASDotWarning
22
22
from ..exceptions import NotFittedError as _NotFittedError
23
23
24
+ NOT_SPECIFIED = object ()
25
+
24
26
25
27
@deprecated ("DataConversionWarning has been moved into the sklearn.exceptions"
26
28
" module. It will not be available here from version 0.19" )
@@ -429,7 +431,7 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
429
431
copy = False , force_all_finite = True , ensure_2d = True ,
430
432
allow_nd = False , multi_output = False , ensure_min_samples = 1 ,
431
433
ensure_min_features = 1 , y_numeric = False ,
432
- warn_on_dtype = False , estimator = None , sample_weight = False ):
434
+ warn_on_dtype = False , estimator = None , sample_weight = NOT_SPECIFIED ):
433
435
"""Input validation for standard estimators.
434
436
435
437
Checks X and y for consistent length, enforces X 2d and y 1d.
@@ -503,6 +505,10 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
503
505
estimator : str or estimator instance (default=None)
504
506
If passed, include the name of the estimator in warning messages.
505
507
508
+ sample_weight : nd-array, list
509
+ Ensures sample_weight is 1-d.
510
+
511
+
506
512
Returns
507
513
-------
508
514
X_converted : object
@@ -525,16 +531,17 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
525
531
526
532
check_consistent_length (X , y )
527
533
528
- if sample_weight is None :
534
+ if sample_weight is NOT_SPECIFIED :
535
+ return X , y
536
+
537
+ elif sample_weight is None :
529
538
return X , y , sample_weight
530
539
531
- elif sample_weight is not False :
540
+ else :
532
541
sample_weight = check_array (sample_weight , ensure_2d = False )
533
542
check_consistent_length (y , sample_weight )
534
543
return X , y , sample_weight
535
544
536
- return X , y
537
-
538
545
539
546
def column_or_1d (y , warn = False ):
540
547
""" Ravel column or 1d numpy array, else raises an error
0 commit comments