
More sensitive sample weight estimator check #30143

Draft
wants to merge 7 commits into base: main

Conversation

@ogrisel (Member) commented Oct 24, 2024

Related to #16298.

Based on @jeremiedbb's experience when conducting extra experiments for #29907, I figured we could improve the sensitivity of our sample weight check by having a few data points with considerable weights (instead of just assigning random integer weights uniformly).

I also changed the test to run predictions on more data points, drawn from another distribution, so that it is sensitive to changes in how the decision boundaries extrapolate. Previously, we only tested the predictions on the small number of training data points.

As a result, our check brings the number of XPASS cases (false negatives) down to 0.

I also fixed some false positive XFAIL cases:

  • by avoiding feeding negative feature values to estimators that do not support them;
  • by decreasing the tolerance for the lsqr and sparse-cg solvers. Maybe this indicates that we should adjust their default value, but that is better done in a dedicated follow-up PR on better defaults for linear model solvers (already tracked in Better selection of the default solver for Ridge based on the shape of data #14269).
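
For illustration, here is a minimal, hypothetical sketch of the weighted-vs-repeated comparison described above (a few points with considerable weights, predictions evaluated on many points from a different distribution). The data, parameter values, and names are illustrative and not taken from the actual check:

import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(42)
X = rng.normal(size=(30, 3))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)

# Mostly small integer weights, plus a few considerable ones.
sw = rng.randint(0, 3, size=X.shape[0]).astype(float)
sw[rng.choice(X.shape[0], size=3, replace=False)] = 20.0

# Materialize the equivalent dataset with repeated rows.
X_repeated = np.repeat(X, sw.astype(int), axis=0)
y_repeated = np.repeat(y, sw.astype(int))

params = dict(max_iter=1000, tol=1e-12)
est_weighted = LogisticRegression(**params).fit(X, y, sample_weight=sw)
est_repeated = LogisticRegression(**params).fit(X_repeated, y_repeated)

# Evaluate on many points drawn from a different distribution so that the
# comparison is also sensitive to how the decision boundary extrapolates.
X_test = rng.uniform(low=-3.0, high=3.0, size=(500, 3))
diff = np.abs(est_weighted.predict_proba(X_test) - est_repeated.predict_proba(X_test))
print(f"max abs difference between weighted and repeated fits: {diff.max():.2e}")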


github-actions bot commented Oct 24, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: fe9d047.

"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
@ogrisel (Member, Author) commented Oct 24, 2024


KBinsDiscretizer actually has no sample_weight-related problem when the number of samples is small with the default quantile strategy, which is the case in check_sample_weight_equivalence.

This check would fail with a larger number of samples, once subsampling kicks in (this is being concurrently investigated in #29907), or if we were to enable the kmeans strategy, which would make the fit process non-deterministic even without subsampling. But in either case, we would need a much slower statistical test instead, so let's just test the deterministic case for now.
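
For reference, a small hypothetical sketch of the deterministic case being discussed (few samples, default quantile strategy, no subsampling). This is not the actual common check; the names and values are illustrative only:

import numpy as np
from sklearn.preprocessing import KBinsDiscretizer

rng = np.random.RandomState(0)
X = rng.uniform(size=(20, 1))
sw = rng.randint(1, 4, size=20).astype(float)
X_repeated = np.repeat(X, sw.astype(int), axis=0)

kbd_weighted = KBinsDiscretizer(n_bins=3, encode="ordinal").fit(X, sample_weight=sw)
kbd_repeated = KBinsDiscretizer(n_bins=3, encode="ordinal").fit(X_repeated)

# With so few samples and the deterministic quantile strategy, the two sets of
# bin edges are expected to match; this is what the equivalence check compares.
print(kbd_weighted.bin_edges_[0])
print(kbd_repeated.bin_edges_[0])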

@ogrisel (Member, Author) commented Oct 24, 2024

cc @antoinebaker, @snath-xoc and @jeremiedbb.

@ogrisel (Member, Author) commented Oct 24, 2024

I have discovered a whole new can of worms with this updated test: I am working on making sure the models converge before investigating errors. I will update the test accordingly.

@ogrisel (Member, Author) commented Oct 24, 2024

So LogisticRegression with the lbfgs or liblinear solvers also has a sample_weight-related discrepancy, even when decreasing tol to minimal values. However, the magnitude of the discrepancy is not that large (~1e-5 to 1e-7 in relative error, 1e-6 to 1e-8 in absolute error), so maybe it's our assertion that is too strict. Still, it's much larger than our convergence tolerance (1e-12) or machine precision rounding errors, so this might have uncovered a new bug in one of our most used estimator/solver combos.

It also detected a real bug in HuberRegressor (but this one is arguably less popular than LogisticRegression).
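
As a rough, hypothetical illustration of how the magnitude of such a discrepancy can be quantified for the lbfgs solver (the data, weights, and thresholds below are illustrative, not the ones used by the common check):

import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
X = rng.normal(size=(50, 4))
y = (X @ rng.normal(size=4) > 0).astype(int)
sw = rng.randint(1, 4, size=50).astype(float)
sw[rng.choice(50, size=3, replace=False)] = 15.0

X_rep = np.repeat(X, sw.astype(int), axis=0)
y_rep = np.repeat(y, sw.astype(int))

params = dict(solver="lbfgs", tol=1e-12, max_iter=10_000)
coef_w = LogisticRegression(**params).fit(X, y, sample_weight=sw).coef_
coef_r = LogisticRegression(**params).fit(X_rep, y_rep).coef_

abs_err = np.abs(coef_w - coef_r)
rel_err = abs_err / np.maximum(np.abs(coef_r), 1e-12)
print(f"max abs error: {abs_err.max():.1e}, max rel error: {rel_err.max():.1e}")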

@ogrisel (Member, Author) commented Oct 24, 2024

It's now failing with:

FAILED utils/tests/test_estimator_checks.py::test_check_estimator - ValueError: Buffer dtype mismatch, expected 'const float' but got 'double'
_____________________________ test_check_estimator _____________________________
[gw1] darwin -- Python 3.12.7 /usr/local/miniconda/envs/testvenv/bin/python

    def test_check_estimator():
        # tests that the estimator actually fails on "bad" estimators.
        # not a complete test of all checks, which are very extensive.
    
        # check that we have a fit method
        msg = "object has no attribute 'fit'"
        with raises(AttributeError, match=msg):
            check_estimator(BaseEstimator())
    
        # does error on binary_only untagged estimator
        msg = "Only 2 classes are supported"
        with raises(ValueError, match=msg):
            check_estimator(UntaggedBinaryClassifier())
    
        for csr_container in CSR_CONTAINERS:
            # non-regression test for estimators transforming to sparse data
            check_estimator(SparseTransformer(sparse_container=csr_container))
    
        # doesn't error on actual estimator
>       check_estimator(LogisticRegression())

csr_container = <class 'scipy.sparse._csr.csr_array'>
msg        = 'Only 2 classes are supported'

../1/s/sklearn/utils/tests/test_estimator_checks.py:781: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../1/s/sklearn/utils/estimator_checks.py:624: in check_estimator
    check(estimator)
        check      = functools.partial(<function check_classifiers_train at 0x120123e20>, 'LogisticRegression', readonly_memmap=True, X_dtype='float32')
        checks_generator = <function check_estimator.<locals>.checks_generator at 0x129fdfd80>
        estimator  = LogisticRegression(max_iter=1000, solver='newton-cholesky', tol=1e-12)
        generate_only = False
        legacy     = True
        name       = 'LogisticRegression'
../1/s/sklearn/utils/_testing.py:140: in wrapper
    return fn(*args, **kwargs)
        args       = ('LogisticRegression', LogisticRegression(max_iter=1000, solver='newton-cholesky', tol=1e-12))
        fn         = <function check_classifiers_train at 0x120123d80>
        kwargs     = {'X_dtype': 'float32', 'readonly_memmap': True}
        self       = _IgnoreWarnings(record=True)
../1/s/sklearn/utils/estimator_checks.py:2307: in check_classifiers_train
    classifier.fit(X, y)
        X          = memmap([[ 0.04310846,  0.25109133],
        [ 0.30452606,  0.25814292],
        [ 0.73052144,  0.98760027],
        [ ...    [ 0.70305467, -0.79237926],
        [ 1.5169084 , -1.2554711 ],
        [ 1.753074  , -2.1646235 ]], dtype=float32)
        X_b        = memmap([[ 0.04310846,  0.25109133],
        [ 0.30452606,  0.25814292],
        [ 0.73052144,  0.98760027],
        [ ...    [ 0.70305467, -0.79237926],
        [ 1.5169084 , -1.2554711 ],
        [ 1.753074  , -2.1646235 ]], dtype=float32)
        X_dtype    = 'float32'
        X_m        = memmap([[-0.85589486,  0.21554239],
        [-0.7850048 ,  0.72501343],
        [ 0.04310846,  0.25109133],
        [ ...    [-1.5771745 ,  0.0352128 ],
        [ 1.5169084 , -1.2554711 ],
        [ 1.753074  , -2.1646235 ]], dtype=float32)
        classes    = array([0, 1])
        classifier = LogisticRegression(max_iter=1000, random_state=0, solver='newton-cholesky',
                   tol=1e-12)
        classifier_orig = LogisticRegression(max_iter=1000, solver='newton-cholesky', tol=1e-12)
        n_classes  = 2
        n_features = 2
        n_samples  = 200
        name       = 'LogisticRegression'
        problems   = [(memmap([[ 0.04310846,  0.25109133],
        [ 0.30452606,  0.25814292],
        [ 0.73052144,  0.98760027],
        ... 2, 1, 2, 1, 2, 2, 0, 0, 2, 0, 2, 2, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0,
        0, 0, 0, 2, 0, 2, 0, 0, 1, 0, 1, 2, 1, 1]))]
        readonly_memmap = True
        tags       = Tags(target_tags=TargetTags(required=True, one_d_labels=False, two_d_labels=False, positive_only=False, multi_output=F...alse, sparse=False, categorical=False, string=False, dict=False, positive_only=False, allow_nan=False, pairwise=False))
        y          = memmap([0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, ... 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
        1, 1])
        y_b        = memmap([0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, ... 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
        1, 1])
        y_m        = memmap([2, 2, 0, 0, 2, 0, 1, 1, 2, 1, 0, 2, 1, 1, 0, 0, 2, 2, 1, 0, 1, 2,
        1, 2, 1, 2, 2, 0, 2, 0, 0, 1, 0, 0, ...   2, 1, 2, 1, 2, 2, 0, 0, 2, 0, 2, 2, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0,
        0, 0, 0, 2, 0, 2, 0, 0, 1, 0, 1, 2, 1, 1])
../1/s/sklearn/base.py:1244: in wrapper
    return fit_method(estimator, *args, **kwargs)
        args       = (memmap([[ 0.04310846,  0.25109133],
        [ 0.30452606,  0.25814292],
        [ 0.73052144,  0.98760027],
        [...1, 1, 1, 1, 0, 0, 0, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
        1, 1]))
        estimator  = LogisticRegression(max_iter=1000, random_state=0, solver='newton-cholesky',
                   tol=1e-12)
        fit_method = <function LogisticRegression.fit at 0x11d5a68e0>
        global_skip_validation = False
        kwargs     = {}
        partial_fit_and_fitted = False
        prefer_skip_nested_validation = True
../1/s/sklearn/linear_model/_logistic.py:1350: in fit
    fold_coefs_ = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, prefer=prefer)(
        C_         = 1.0
        X          = array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.90...      [ 0.70305467, -0.79237926],
       [ 1.5169084 , -1.2554711 ],
       [ 1.753074  , -2.1646235 ]], dtype=float32)
        _dtype     = [<class 'numpy.float64'>, <class 'numpy.float32'>]
        classes_   = array([1])
        max_squared_sum = None
        multi_class = 'ovr'
        n_classes  = 1
        n_threads  = 1
        path_func  = <function _logistic_regression_path at 0x129f5fec0>
        penalty    = 'l2'
        prefer     = 'processes'
        sample_weight = None
        self       = LogisticRegression(max_iter=1000, random_state=0, solver='newton-cholesky',
                   tol=1e-12)
        solver     = 'newton-cholesky'
        warm_start_coef = [None]
        y          = array([0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0,
       0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1,...1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
       1, 1])
../1/s/sklearn/utils/parallel.py:77: in __call__
    return super().__call__(iterable_with_config)
        __class__  = <class 'sklearn.utils.parallel.Parallel'>
        config     = {'array_api_dispatch': False, 'assume_finite': False, 'display': 'diagram', 'enable_cython_pairwise_dist': True, ...}
        iterable   = <generator object LogisticRegression.fit.<locals>.<genexpr> at 0x12980fdf0>
        iterable_with_config = <generator object Parallel.__call__.<locals>.<genexpr> at 0x128b0a4d0>
        self       = Parallel(n_jobs=1)
/usr/local/miniconda/envs/testvenv/lib/python3.12/site-packages/joblib/parallel.py:1918: in __call__
    return output if self.return_generator else list(output)
        iterable   = <generator object Parallel.__call__.<locals>.<genexpr> at 0x128b0a4d0>
        n_jobs     = 1
        output     = <generator object Parallel._get_sequential_output at 0x129776e60>
        self       = Parallel(n_jobs=1)
/usr/local/miniconda/envs/testvenv/lib/python3.12/site-packages/joblib/parallel.py:1847: in _get_sequential_output
    res = func(*args, **kwargs)
        args       = (array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.9..., 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
       1, 1]))
        batch_size = 1
        func       = <sklearn.utils.parallel._FuncWrapper object at 0x129b89070>
        iterable   = <generator object Parallel.__call__.<locals>.<genexpr> at 0x128b0a4d0>
        kwargs     = {'Cs': [1.0], 'check_input': False, 'class_weight': None, 'coef': None, ...}
        self       = Parallel(n_jobs=1)
../1/s/sklearn/utils/parallel.py:139: in __call__
    return self.function(*args, **kwargs)
        args       = (array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.9..., 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
       1, 1]))
        config     = {'array_api_dispatch': False, 'assume_finite': False, 'display': 'diagram', 'enable_cython_pairwise_dist': True, ...}
        kwargs     = {'Cs': [1.0], 'check_input': False, 'class_weight': None, 'coef': None, ...}
        self       = <sklearn.utils.parallel._FuncWrapper object at 0x129b89070>
../1/s/sklearn/linear_model/_logistic.py:496: in _logistic_regression_path
    w0 = sol.solve(X=X, y=target, sample_weight=sample_weight)
        C          = 1.0
        Cs         = [1.0]
        X          = array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.90...      [ 0.70305467, -0.79237926],
       [ 1.5169084 , -1.2554711 ],
       [ 1.753074  , -2.1646235 ]], dtype=float32)
        check_input = False
        class_weight = None
        classes    = array([0, 1])
        coef       = None
        coefs      = []
        dual       = False
        fit_intercept = True
        i          = 0
        intercept_scaling = 1.0
        l1_ratio   = None
        l2_reg_strength = 0.005
        le         = LabelEncoder()
        loss       = <sklearn.linear_model._linear_loss.LinearModelLoss object at 0x129b88e60>
        mask       = array([False, False, False,  True,  True,  True, False,  True,  True,
       False, False,  True, False,  True,  True,...False,  True,  True, False,
       False, False, False, False, False, False,  True, False,  True,
        True,  True])
        mask_classes = array([0, 1])
        max_iter   = 1000
        max_squared_sum = None
        multi_class = 'ovr'
        n_features = 2
        n_iter     = array([0], dtype=int32)
        n_samples  = 200
        n_threads  = 1
        penalty    = 'l2'
        pos_class  = 1
        random_state = RandomState(MT19937) at 0x129FC7B40
        sample_weight = None
        sol        = <sklearn.linear_model._glm._newton_solver.NewtonCholeskySolver object at 0x129b89a90>
        solver     = 'newton-cholesky'
        sw_sum     = 200
        target     = array([0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., ... 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32)
        tol        = 1e-12
        verbose    = 0
        w0         = array([0., 0., 0.], dtype=float32)
        warm_start_sag = {'coef': array([[0.],
       [0.],
       [0.]], dtype=float32)}
        y          = array([0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0,
       0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1,...1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
       1, 1])
        y_bin      = array([0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., ... 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32)
../1/s/sklearn/linear_model/_glm/_newton_solver.py:429: in solve
    self.fallback_lbfgs_solve(X=X, y=y, sample_weight=sample_weight)
        X          = array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.90...      [ 0.70305467, -0.79237926],
       [ 1.5169084 , -1.2554711 ],
       [ 1.753074  , -2.1646235 ]], dtype=float32)
        sample_weight = None
        self       = <sklearn.linear_model._glm._newton_solver.NewtonCholeskySolver object at 0x129b89a90>
        y          = array([0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., ... 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32)
../1/s/sklearn/linear_model/_glm/_newton_solver.py:181: in fallback_lbfgs_solve
    opt_res = scipy.optimize.minimize(
        X          = array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.90...      [ 0.70305467, -0.79237926],
       [ 1.5169084 , -1.2554711 ],
       [ 1.753074  , -2.1646235 ]], dtype=float32)
        sample_weight = None
        self       = <sklearn.linear_model._glm._newton_solver.NewtonCholeskySolver object at 0x129b89a90>
        y          = array([0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., ... 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32)
/usr/local/miniconda/envs/testvenv/lib/python3.12/site-packages/scipy/optimize/_minimize.py:710: in minimize
    res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
        args       = (array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.9...0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32), None, 0.005, 1)
        bounds     = None
        callback   = None
        constraints = []
        fun        = <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>
        hess       = None
        hessp      = None
        jac        = <bound method MemoizeJac.derivative of <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>>
        meth       = 'l-bfgs-b'
        method     = 'L-BFGS-B'
        options    = {'ftol': 1.4210854715202004e-14, 'gtol': 1e-12, 'iprint': -1, 'maxiter': 993, ...}
        remove_vars = False
        tol        = None
        x0         = array([ 0.7035002 , -3.827957  , -0.69446135], dtype=float32)
/usr/local/miniconda/envs/testvenv/lib/python3.12/site-packages/scipy/optimize/_lbfgsb_py.py:307: in _minimize_lbfgsb
    sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
        args       = (array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.9...0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32), None, 0.005, 1)
        bounds     = [(None, None), (None, None), (None, None)]
        callback   = None
        disp       = None
        eps        = 1e-08
        factr      = 64.0
        finite_diff_rel_step = None
        ftol       = 1.4210854715202004e-14
        fun        = <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>
        gtol       = 1e-12
        iprint     = -1
        jac        = <bound method MemoizeJac.derivative of <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>>
        m          = 10
        maxcor     = 10
        maxfun     = 15000
        maxiter    = 993
        maxls      = 50
        n          = 3
        new_bounds = (array([-inf, -inf, -inf]), array([inf, inf, inf]))
        pgtol      = 1e-12
        unknown_options = {}
        x0         = array([ 0.70350021, -3.82795691, -0.69446135])
/usr/local/miniconda/envs/testvenv/lib/python3.12/site-packages/scipy/optimize/_optimize.py:383: in _prepare_scalar_function
    sf = ScalarFunction(fun, x0, args, grad, hess,
        args       = (array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.9...0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32), None, 0.005, 1)
        bounds     = (array([-inf, -inf, -inf]), array([inf, inf, inf]))
        epsilon    = 1e-08
        finite_diff_rel_step = None
        fun        = <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>
        grad       = <bound method MemoizeJac.derivative of <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>>
        hess       = <function _prepare_scalar_function.<locals>.hess at 0x129fdfe20>
        jac        = <bound method MemoizeJac.derivative of <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>>
        x0         = array([ 0.70350021, -3.82795691, -0.69446135])
/usr/local/miniconda/envs/testvenv/lib/python3.12/site-packages/scipy/optimize/_differentiable_functions.py:158: in __init__
    self._update_fun()
        args       = (array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.9...0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32), None, 0.005, 1)
        epsilon    = 1e-08
        finite_diff_bounds = (array([-inf, -inf, -inf]), array([inf, inf, inf]))
        finite_diff_options = {}
        finite_diff_rel_step = None
        fun        = <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>
        fun_wrapped = <function ScalarFunction.__init__.<locals>.fun_wrapped at 0x129fdf7e0>
        grad       = <bound method MemoizeJac.derivative of <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>>
        hess       = <function _prepare_scalar_function.<locals>.hess at 0x129fdfe20>
        self       = <scipy.optimize._differentiable_functions.ScalarFunction object at 0x12a347b00>
        update_fun = <function ScalarFunction.__init__.<locals>.update_fun at 0x129fdf600>
        x0         = array([ 0.70350021, -3.82795691, -0.69446135])
/usr/local/miniconda/envs/testvenv/lib/python3.12/site-packages/scipy/optimize/_differentiable_functions.py:251: in _update_fun
    self._update_fun_impl()
        self       = <scipy.optimize._differentiable_functions.ScalarFunction object at 0x12a347b00>
/usr/local/miniconda/envs/testvenv/lib/python3.12/site-packages/scipy/optimize/_differentiable_functions.py:155: in update_fun
    self.f = fun_wrapped(self.x)
        fun_wrapped = <function ScalarFunction.__init__.<locals>.fun_wrapped at 0x129fdf7e0>
        self       = <scipy.optimize._differentiable_functions.ScalarFunction object at 0x12a347b00>
/usr/local/miniconda/envs/testvenv/lib/python3.12/site-packages/scipy/optimize/_differentiable_functions.py:137: in fun_wrapped
    fx = fun(np.copy(x), *args)
        args       = (array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.9...0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32), None, 0.005, 1)
        fun        = <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>
        self       = <scipy.optimize._differentiable_functions.ScalarFunction object at 0x12a347b00>
        x          = array([ 0.70350021, -3.82795691, -0.69446135])
/usr/local/miniconda/envs/testvenv/lib/python3.12/site-packages/scipy/optimize/_optimize.py:77: in __call__
    self._compute_if_needed(x, *args)
        args       = (array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.9...0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32), None, 0.005, 1)
        self       = <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>
        x          = array([ 0.70350021, -3.82795691, -0.69446135])
/usr/local/miniconda/envs/testvenv/lib/python3.12/site-packages/scipy/optimize/_optimize.py:71: in _compute_if_needed
    fg = self.fun(x, *args)
        args       = (array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.9...0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32), None, 0.005, 1)
        self       = <scipy.optimize._optimize.MemoizeJac object at 0x129f02f60>
        x          = array([ 0.70350021, -3.82795691, -0.69446135])
../1/s/sklearn/linear_model/_linear_loss.py:316: in loss_gradient
    loss, grad_pointwise = self.base_loss.loss_gradient(
        X          = array([[ 0.04310846,  0.25109133],
       [ 0.30452606,  0.25814292],
       [ 0.73052144,  0.98760027],
       [ 0.90...      [ 0.70305467, -0.79237926],
       [ 1.5169084 , -1.2554711 ],
       [ 1.753074  , -2.1646235 ]], dtype=float32)
        coef       = array([ 0.70350021, -3.82795691, -0.69446135])
        intercept  = -0.6944613456726074
        l2_reg_strength = 0.005
        n_classes  = 2
        n_dof      = 3
        n_features = 2
        n_samples  = 200
        n_threads  = 1
        raw_prediction = array([-1.62530133, -1.46838717, -3.96103063,  4.38693961,  3.4173551 ,
        4.0500805 , -4.29471632,  1.81024278, ...441381, -3.47273599, -5.14990182, -0.04975034,
        6.63675183, -4.44000637,  2.83333143,  5.17857336,  8.82491211])
        sample_weight = None
        self       = <sklearn.linear_model._linear_loss.LinearModelLoss object at 0x129b88e60>
        weights    = array([ 0.70350021, -3.82795691])
        y          = array([0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., ... 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32)
../1/s/sklearn/_loss/loss.py:258: in loss_gradient
    self.closs.loss_gradient(
        gradient_out = array([-0.93083998, -0.77392582, -3.26656928,  5.08140096,  4.11181644,
        4.74454184, -3.60025498,  2.50470412, ...995246, -2.77827464, -4.45544047,  0.644711  ,
        7.33121317, -3.74554502,  3.52779277,  5.8730347 ,  9.51937346])
        loss_out   = array([-1.6253014 , -1.4683872 , -3.9610307 ,  4.3869395 ,  3.4173555 ,
        4.050081  , -4.2947164 ,  1.8102429 , ...-5.149902  , -0.04975027,
        6.6367526 , -4.4400067 ,  2.8333313 ,  5.178573  ,  8.824912  ],
      dtype=float32)
        n_threads  = 1
        raw_prediction = array([-1.62530133, -1.46838717, -3.96103063,  4.38693961,  3.4173551 ,
        4.0500805 , -4.29471632,  1.81024278, ...441381, -3.47273599, -5.14990182, -0.04975034,
        6.63675183, -4.44000637,  2.83333143,  5.17857336,  8.82491211])
        sample_weight = None
        self       = <sklearn._loss.loss.HalfBinomialLoss object at 0x129b8b5c0>
        y_true     = array([0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., ... 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1.], dtype=float32)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

>   ???
E   ValueError: Buffer dtype mismatch, expected 'const float' but got 'double'


sklearn/_loss/_loss.pyx:1779: ValueError

This happens when running the check_classifiers_train estimator check, which should not have been affected by this PR.

However, one can observe that classifier_orig is defined as:

classifier_orig = LogisticRegression(max_iter=1000, solver='newton-cholesky', tol=1e-12)

in that function call. This is very suspicious because it shows that a parametrization designed for the check_sample_weight_equivalence check has leaked into another, unrelated check. It looks like we are missing a call to clone somewhere. But that does not explain the low-level Cython error that occurs when newton-cholesky falls back to lbfgs either.

@ogrisel (Member, Author) commented Oct 24, 2024

@adrinjalali I pushed ec424ba to resolve the problem found in check_estimator described above. The estimator variable in the for loop was shadowing the local variable captured in the checks_generator definition, making the per-check parameter updates accumulate on top of one another...

We could probably extract this fix from this PR, but I am not sure how to write a non-convoluted non-regression test...
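
For context, here is a minimal, self-contained illustration of how per-check parameter updates can accumulate when a single mutable estimator instance is reused, which mirrors the effect of the shadowing bug described above. It is not the actual scikit-learn code; cloning the estimator per check (sklearn.base.clone) avoids the leak:

from sklearn.linear_model import LogisticRegression


def checks_generator(estimator, per_check_params):
    for params in per_check_params:
        # Buggy pattern: mutating the shared instance in place instead of
        # yielding clone(estimator).set_params(**params).
        yield estimator.set_params(**params)


per_check_params = [{"tol": 1e-12, "max_iter": 1000}, {"solver": "newton-cholesky"}]
for est in checks_generator(LogisticRegression(), per_check_params):
    print(est.get_params()["tol"], est.get_params()["solver"])
# The second check unexpectedly runs with tol=1e-12 as well: the
# parametrization intended for the first check has leaked into the second.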

@ogrisel (Member, Author) commented Oct 24, 2024

Actually, the refactoring you suggested this morning would also have fixed this bug (by extracting the generator code into its own public function).

@adrinjalali (Member) commented:

@ogrisel #30149 does the refactoring now.

@ogrisel (Member, Author) commented Oct 29, 2024

I am not 100% sure if the newly discovered LogisticRegression XFAIL cases are caused by a bug in their handling of sample_weight or by a numerical convergence problem on the new test data used in the check. I have the feeling that fitting on data with large weights/repetitions is much more challenging for numerical solvers. Maybe we need to use lower values for the large weights.

Similarly for HuberRegressor and liblinear-based models.

We should probably investigate before considering reviewing and merging this PR.

Comment on lines 1171 to 1177
with warnings.catch_warnings(record=True):
    # Ensure we converge, otherwise debugging sample_weight equivalence
    # failures can be very misleading.
    warnings.simplefilter("error", category=ConvergenceWarning)

    estimator_repeated.fit(X_repeated, y=y_repeated, sample_weight=None)
    estimator_weighted.fit(X_weigthed, y=y_weighted, sample_weight=sw)

Contributor commented:

Why should we ensure convergence? In my understanding, the deterministic fit should yield the same results for the weighted/repeated data at each iteration, and we can check equivalence after a small number of iterations even if the fit is terrible and hasn't converged. This should ensure that the computations before and inside the loop are sample weight aware, though arguably it doesn't test that the stopping criterion is sample weight aware.

I think having many iterations increases the risk of accumulating numerical errors and producing false positives.

@ogrisel (Member, Author) commented:

I think the point is to make sure that:

  • we do not raise uncaught warnings when running the common tests: this forces us to specify estimator parameters with either a high tol or a large max_iter;
  • to make the tests more sensitive, we may want to specify low tol values so that the influence of the sample weights is fully imprinted in the learned decision function. But we also want max_iter to be high enough that the chosen tol actually has the intended effect.

@ogrisel (Member, Author) commented:

Also, for some convex objectives, the deterministic test might be expected to pass with low tol values, but the optimization paths for the weighted vs reweighted problems might not be equal for stochastic gradient (or minibatch) solvers.

This would be the case for LogisticRegression / Ridge and the likes with the "sag" and "saga" solvers. It would also be the case for SGDClassifier/SGDRegressor with convex objective functions (I think all the choices of loss and penalties we implement lead to convex objectives, although convergence via SGD might not be guaranteed for non-smooth penalties).

@ogrisel marked this pull request as draft on January 16, 2025, 14:08
@ogrisel (Member, Author) commented Jan 16, 2025

I put this PR back to draft because significantly more work is needed here.

I need to find a balance: the weighted problem must be hard enough to reveal weight-handling problems, but not so hard that we break numerical solvers through ill-conditioning of the optimization problem...
