Skip to content

[MRG] Add quantile regression #9978

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 147 commits into from
May 25, 2021
Merged

Conversation

avidale
Copy link
Contributor

@avidale avidale commented Oct 22, 2017

This PR fixes issue #3148

This new feature implements quantile regression - an algorithm that directly minimizes mean absolute error of a linear regression model.

The work is still in progress, but I do want to receive some feedback.

@jnothman
Copy link
Member

Thanks. You'll need a lot of patience for core devs to review this in detail, I suspect. We have a number of pressing API issues as well as highly requested features that have been in the reviewing queue for some time. But I hope we'll get here soon enough!

@ashimb9
Copy link
Contributor

ashimb9 commented Nov 28, 2017

@avidale Thanks a lot for your contribution! @jnothman has already outlined the broader picture but FWIW a few pointers from a wanderer for when the core devs become available. First, I would suggest that you look into resolving the CI test failures. After that, you might want to consider adding to the PR to a point where you feel comfortable changing the status of this PR to [MRG] from [WIP]. (Of course, this is of no use if you actually need some comments/ideas before you can start working any further). In my limited experience, [WIP]s are usually not prioritized for review (but don't quote me on this ;)) so you might want to consider the change. Finally when you get to that point, you might want to tag some of the core devs that participated in the original discussion since some of them might have missed the initial post.

@JasonSanchez
Copy link

This is a great add to scikit-learn. Would personally really like to see it merged.

Copy link
Member

@jnothman jnothman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • You have test failure to deal with.
  • Some parameters have not been tested, such as l1_ratio.
  • Please add the class to doc/modules/classes.rst
  • We no longer use assert_true, assert_greater. Just use bare assert.

@avidale
Copy link
Contributor Author

avidale commented Mar 18, 2018

Under Travis CI there is another failure: the 'nit' parameter not found in the result of a scipy.optimize call. It runs with scipy 0.13.3, which might be too old. What would you recommend to do?

  • make a workaround for old versions of scipy?
  • remove the 'nit' functionality at all?
  • mark the class as working only with new scipy and change my test to respect this restriction?

Current solution: just changed the solver from BFGS to L-BFGS-B. The latter has supported nit since scipy 0.12.

@ghost
Copy link

ghost commented May 23, 2021

I have had a go with a real training set as follows:

X = pandas dataframe ([61296 rows x 2846 columns])
y = numpy array (array of float 64) - Size is (61296,)

(There are 2846 variables in the model and I have 61296 measurements of the 2846 variables.)

I run into the following problem in _quantile.py
"MemoryError: Unable to allocate 28.0 GiB for an array with shape (61296, 61296) and data type float64"

This occurs when np.eye(n_mask) tries to create a 61296 x 61296 identity matrix.

Is a more memory efficient implementation possible?


File "", line 1, in
z=est.fit(X, y)

File "\sklearn\linear_model_quantile.py", line 217, in fit
-np.eye(n_mask),

File "\site-packages\numpy\lib\twodim_base.py", line 209, in eye
m = zeros((N, M), dtype=dtype, order=order)

MemoryError: Unable to allocate 28.0 GiB for an array with shape (61296, 61296) and data type float64

@agramfort
Copy link
Member

@RPyElec do you know a solver that would handle problems of this size? if so what optimization method is used?

@glemaitre
Copy link
Member

I did only list free references, otherwise the book by Koenker (2005) would be THE reference

The book reference could be nice I think.

@ghost
Copy link

ghost commented May 23, 2021

I'm not an expert on solvers I'm afraid. I have been running the quantile regression problem above with the implementation here (https://pypi.org/project/asgl/). I have an academic license for MOSEK/Gurobi as the LP solver but have also had a go with the free solvers that it comes with - which also work).

(Scikit-learn is better maintained so is preferable).

@glemaitre
Copy link
Member

do you know a solver that would handle problems of this size

Is there some incremental/online solver?

@agramfort
Copy link
Member

agramfort commented May 23, 2021 via email

@lorentzenchr
Copy link
Member

@RPyElec Cool that you already gave this PR a try and feedback. Once this PR is merged, there might be room for improvements. In particular, if your problem/feature matrix X is sparse, one could use the linprog methods that support sparse input.

If you urgently need a solution right now, you could try the R package https://cran.r-project.org/package=quantreg. Though I don't know if it will work on your problem.

@ghost
Copy link

ghost commented May 24, 2021

@lorentzenchr - understood! Looking forward to seeing this get pushed too. I have a more powerful machine I can run the current implementation on (and ASGL too) - so no worries there. I was just flagging the large identity matrices being created as a potential issue.

@agramfort - I'm trying to put something together that I can use as an example and will get back to you.

@ghost
Copy link

ghost commented May 24, 2021

You manage to make it work with asgl? It seems to use cvxopt/cvxpy which makes me skeptical that it scales? Can you share more details?

@agramfort

QR examples.txt
I have attached a simplified QR script which runs scikit-learn's QR and asgl with MOSEK (licence needed) and SCS (free) as LP solvers. It is slow so not sure how well it scales (though it does run with 16 GB RAM).

If you run with MOSEK, then in asgl.py, you will need to add the MOSEK section to _cvxpy_solver_options:

def _cvxpy_solver_options(self, solver):

    if solver == 'ECOS':
        solver_dict = dict(solver=solver,
                           max_iters=self.max_iters)

    elif solver == 'OSQP':
        solver_dict = dict(solver=solver,
                           max_iter=self.max_iters)

    elif solver == "MOSEK":
        import mosek
        solver_dict = dict(solver=solver,
                           warm_start=True,
                           #max_iters=self.max_iters,
                           mosek_params={mosek.iparam.intpnt_solve_form:
                                         mosek.solveform.dual,
                         #                mosek.iparam.num_threads: 1
                                         })

    else:
        solver_dict = dict(solver=solver)
    return solver_dict

(It might be necessary to uncomment the num_threads line depending on your setup).

I would also recommend adding a print statement in def lasso at the three locations below (again in asgl.py):

            if self.solver == 'default':
                print("Using %s" % self.solver)                                     # PRINT 1
                problem.solve(warm_start=True) 
            else:
                print("Using %s" % self.solver)                                    # PRINT 2
                solver_dict = self._cvxpy_solver_options(solver=self.solver)
                problem.solve(**solver_dict)
        except (ValueError, cvxpy.error.SolverError):
            logging.warning(
                'Default solver failed. Using alternative options. Check solver and solver_stats for more '
                'details')
            solver = ['ECOS', 'OSQP', 'SCS']
            for elt in solver:
                print("Using %s" % elt)                                    # PRINT 3

@avidale
Copy link
Contributor Author

avidale commented May 24, 2021

@RPyElec @glemaitre I think that for quantile regression it is very easy to implement a naive gradient descent based solver that is memory-efficient and is relatively fast on large datasets.

The quantile loss is (resid > 0) * resid * q - (resid < 0) * resid * (1 - q) where resid = y - X @ coef, so its antigradient w.r.t. the coefficients is ((resid > 0) * q - (resid < 0) * (1 - q)) @ X, and we can just add this value to the coefficients until the loss converges. The tricky parts is choosing the right learning rate, but there are a couple of heuristics that I hope will work well on most datasets.

Here is a Colab notebook with my proof-of-concept implementation that trains on a 60,000 x 3,000 dataset in 20 seconds.

If you think that this direction is correct, maybe we include a solver like this into a future version of QuantileRegressor?

@GaelVaroquaux
Copy link
Member

GaelVaroquaux commented May 24, 2021 via email

@agramfort
Copy link
Member

@RPyElec it would make our life easier if you share code snippets eg in https://gist.github.com/ so we have very control if needed and if you could share a branch from your fork of asgl project so we don't have to apply the patch manually.

@avidale as @GaelVaroquaux said above it's a non-smooth problem. What you do is a "sub-gradient" descent will can be quite slow (known theoretical rates). You could add your solver in https://github.com/benchopt/benchmark_quantile_regression to actually compare the solvers objectively. WDYT?

@glemaitre
Copy link
Member

You could add your solver in https://github.com/benchopt/benchmark_quantile_regression to actually compare the solvers objectively. WDYT?

This seems a good idea.

Now, my question is: do we want to merge the current version with solvers that work reasonably well for in-memory problems, or do we want as well benchmark solver optimum for online learning on the offline problems? In short, do we merge now and improve the estimator or by doing so, do are putting ourselve into trouble?

@lorentzenchr
Copy link
Member

+1 for merging now (or else I go crazy - to share my feelings, too:smirk:). linprog is a solid approach. Let's investigate room for improvement (which I'm very interested in, don't get me wrong) after having a solid solution in place => merge.

@glemaitre glemaitre merged commit c1cc67d into scikit-learn:main May 25, 2021
@glemaitre
Copy link
Member

OK so LGTM. I will open a subsequent issue to address the point raised in the discussion.

@agramfort
Copy link
Member

🍻 🎉

@glemaitre
Copy link
Member

Thanks to all contributors

@atrettin
Copy link

Hi! The QuantileRegressor is already very neat, but it can be made even nicer by applying the kernel trick to support more functional forms (in particular the 'rbf' kernel). So, I did that in #23153 ! I thought I'd mention it here since people who are interested in quantile regression might see it. I'd appreciate feedback!

@glemaitre
Copy link
Member

Open an issue when you make a feature request. A closed or merged issue/pr will not raise any attention.

On the topic, it might be better to do a pipeline with a Nystroem transformer followed by a Quantile Regressor. It will scale better.

@atrettin
Copy link

Open an issue when you make a feature request. A closed or merged issue/pr will not raise any attention.

On the topic, it might be better to do a pipeline with a Nystroem transformer followed by a Quantile Regressor. It will scale better.

Well, it already got more than zero engagement 😅 ! I was not aware of this, so should I make an issue for the PR as well? Also, thanks for the tip with the transformer, I'll check it out! On the other hand, maybe L2 regularization would be more desirable? The QuantileRegressor only does L1, the support vector regression does L2.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Moderate Anything that requires some knowledge of conventions and best practices module:linear_model New Feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.