@@ -753,82 +753,93 @@ def test_fused_types_make_dataset():
     assert_array_equal(yi_64, yicsr_64)


-@pytest.mark.parametrize(
-    "sparseX",
-    [
-        False,
-        pytest.param(
-            True,
-            marks=pytest.mark.xfail(
-                reason="Compare issue #15438: sparse with "
-                "sample weight gives wrong results."
-            ),
-        ),
-    ],
-)
-@pytest.mark.parametrize("fit_intercept", [False, True])
+# FIXME: 'normalize' to be removed in 1.2
+@pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
 @pytest.mark.parametrize("normalize", [False, True])
-def test_linear_regression_sample_weight_consistentcy(
-    sparseX, fit_intercept, normalize
+@pytest.mark.parametrize("sparseX", [False, True])
+@pytest.mark.parametrize("fit_intercept", [False, True])
+@pytest.mark.parametrize("data", ["tall", "wide"])
+def test_linear_regression_sample_weight_consistency(
+    normalize, sparseX, fit_intercept, data
 ):
-    """Test that the impact of sample_weight is consistent."""
-    rng = np.random.RandomState(0)
-    n_samples, n_features = 10, 5
+    """Test that the impact of sample_weight is consistent.

-    X = rng.rand(n_samples, n_features)
-    y = rng.rand(n_samples)
+    Note that this test is stricter than the common test
+    check_sample_weights_invariance alone.
+    """
+    tol = 1e-4
+    if fit_intercept:
+        tol = 1e-1
+    n_samples = 100
+    if data == "tall":
+        n_features = n_samples // 2
+    else:
+        n_features = n_samples * 2
+    rng = np.random.RandomState(42)
+    X, y = make_regression(
+        n_samples=n_samples,
+        n_features=n_features,
+        effective_rank=None,
+        n_informative=50,
+        bias=int(fit_intercept),
+        random_state=rng,
+    )
     if sparseX:
         X = sparse.csr_matrix(X)
     params = dict(fit_intercept=fit_intercept)

-    reg = LinearRegression(**params).fit(X, y)
+    # 1) sample_weight=np.ones(..) should be equivalent to sample_weight=None
+    # same check as check_sample_weights_invariance(name, reg, kind="ones"), but we also
+    # test with sparse input.
+    reg = LinearRegression(**params).fit(X, y, sample_weight=None)
     coef = reg.coef_.copy()
     if fit_intercept:
         intercept = reg.intercept_
-
-    # sample_weight=np.ones(..) should be equivalent to sample_weight=None
-    # same check is done as check_sample_weights_invariance in
-    # sklearn.utils.estimator_checks.py, at least for dense input
     sample_weight = np.ones_like(y)
     reg.fit(X, y, sample_weight=sample_weight)
-    assert_allclose(reg.coef_, coef, rtol=1e-6)
+    assert_allclose(reg.coef_, coef, rtol=tol)
     if fit_intercept:
         assert_allclose(reg.intercept_, intercept)

-    # scaling of sample_weight should have no effect
-    # Note: For models with penalty, scaling the penalty term might work.
-    sample_weight = np.pi * np.ones_like(y)
+    # 2) setting elements of sample_weight to 0 is equivalent to removing these samples
+    # same check as check_sample_weights_invariance(name, reg, kind="zeros"), but we
+    # also test with sparse input
+    sample_weight = rng.uniform(low=0.01, high=2, size=X.shape[0])
+    sample_weight[-5:] = 0
+    y[-5:] *= 1000  # to make excluding those samples important
     reg.fit(X, y, sample_weight=sample_weight)
-    assert_allclose(reg.coef_, coef, rtol=1e-6)
+    coef = reg.coef_.copy()
     if fit_intercept:
-        assert_allclose(reg.intercept_, intercept)
-
-    # setting one element of sample_weight to 0 is equivalent to removing
-    # the corresponding sample, see PR #15015
-    sample_weight = np.ones_like(y)
-    sample_weight[-1] = 0
-    reg.fit(X, y, sample_weight=sample_weight)
-    coef1 = reg.coef_.copy()
+        intercept = reg.intercept_
+    reg.fit(X[:-5, :], y[:-5], sample_weight=sample_weight[:-5])
+    assert_allclose(reg.coef_, coef, rtol=tol)
     if fit_intercept:
-        intercept1 = reg.intercept_
-    reg.fit(X[:-1], y[:-1])
-    assert_allclose(reg.coef_, coef1, rtol=1e-6)
+        assert_allclose(reg.intercept_, intercept, atol=1e-14, rtol=1e-6)
+
+    # 3) scaling of sample_weight should have no effect
+    # Note: For models with penalty, scaling the penalty term might work.
+    reg2 = LinearRegression(**params)
+    reg2.fit(X, y, sample_weight=np.pi * sample_weight)
+    assert_allclose(reg2.coef_, coef, rtol=tol)
     if fit_intercept:
-        assert_allclose(reg.intercept_, intercept1)
+        assert_allclose(reg2.intercept_, intercept, atol=1e-14, rtol=1e-5)

-    # check that multiplying sample_weight by 2 is equivalent
+    # 4) check that multiplying sample_weight by 2 is equivalent
     # to repeating corresponding samples twice
     if sparseX:
         X = X.toarray()
     X2 = np.concatenate([X, X[: n_samples // 2]], axis=0)
     y2 = np.concatenate([y, y[: n_samples // 2]])
-    sample_weight_1 = np.ones_like(y)
-    sample_weight_1[: n_samples // 2] = 2
+    sample_weight_1 = sample_weight.copy()
+    sample_weight_1[: n_samples // 2] *= 2
+    sample_weight_2 = np.concatenate(
+        [sample_weight, sample_weight[: n_samples // 2]], axis=0
+    )
     if sparseX:
         X = sparse.csr_matrix(X)
         X2 = sparse.csr_matrix(X2)
     reg1 = LinearRegression(**params).fit(X, y, sample_weight=sample_weight_1)
-    reg2 = LinearRegression(**params).fit(X2, y2, sample_weight=None)
-    assert_allclose(reg1.coef_, reg2.coef_)
+    reg2 = LinearRegression(**params).fit(X2, y2, sample_weight=sample_weight_2)
+    assert_allclose(reg1.coef_, reg2.coef_, rtol=tol)
     if fit_intercept:
-        assert_allclose(reg1.intercept_, reg2.intercept_)
+        assert_allclose(reg1.intercept_, reg2.intercept_, atol=1e-14, rtol=1e-6)
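The four numbered checks in the new test all rest on the same identity: weighted OLS minimizes sum_i w_i * (y_i - x_i @ beta)**2, which is plain OLS after scaling row i of X and y by sqrt(w_i). Below is a minimal standalone numpy sketch of these equivalences; the wols helper is illustrative only, not part of the patch, and it covers just the dense, no-intercept, overdetermined case:

import numpy as np

rng = np.random.RandomState(0)
X = rng.rand(20, 3)
y = rng.rand(20)
w = rng.uniform(0.5, 2.0, size=20)

def wols(X, y, w):
    # Weighted least squares solved as plain least squares on rows
    # scaled by sqrt(w_i).
    sw = np.sqrt(w)
    return np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)[0]

coef = wols(X, y, w)

# Scaling all weights by a constant leaves the minimizer unchanged (check 3).
assert np.allclose(coef, wols(X, y, np.pi * w))

# A zero weight is the same as dropping that sample (check 2).
w0 = w.copy()
w0[-1] = 0.0
assert np.allclose(wols(X, y, w0), wols(X[:-1], y[:-1], w[:-1]))

# Doubling a weight is the same as repeating that row once (check 4).
w2 = w.copy()
w2[0] *= 2.0
assert np.allclose(
    wols(X, y, w2),
    wols(
        np.vstack([X, X[:1]]),
        np.concatenate([y, y[:1]]),
        np.concatenate([w, w[:1]]),
    ),
)

These identities are exact in arithmetic; the test itself additionally exercises sparse input, fit_intercept, and the wide (n_features > n_samples) setting, which is why it asserts with the looser rtol=tol rather than exact equality.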