ENH: Preserve float32/64 for SGD #13346

Closed · wants to merge 5 commits

4 changes: 4 additions & 0 deletions doc/whats_new/v0.21.rst
@@ -187,6 +187,10 @@ Support for Python 3.4 and below has been officially dropped.
:mod:`sklearn.linear_model`
...........................

- |Enhancement| :mod:`linear_model.stochastic_gradient` now preserves
``float32`` and ``float64`` dtypes. :issue:`11000` by
:user:`Joan Massich <massich>`.

- |Feature| :class:`linear_model.LogisticRegression` and
:class:`linear_model.LogisticRegressionCV` now support Elastic-Net penalty,
with the 'saga' solver. :issue:`11646` by :user:`Nicolas Hug <NicolasHug>`.
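
A quick illustration of the dtype-preservation behaviour described in the changelog entry above (a minimal sketch against this branch; the data and estimator settings are illustrative, not part of the diff):

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

rng = np.random.RandomState(0)
X = rng.rand(20, 3).astype(np.float32)
y = np.array([0, 1] * 10)

clf = SGDClassifier(alpha=0.01, max_iter=5, tol=None).fit(X, y)
# With this change the fitted coefficients keep the input dtype
# instead of always being upcast to float64.
print(clf.coef_.dtype)  # expected: float32
```
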
2 changes: 2 additions & 0 deletions sklearn/linear_model/base.py
@@ -87,6 +87,8 @@ def make_dataset(X, y, sample_weight, random_state=None):
CSRData = CSRDataset64
ArrayData = ArrayDataset64

sample_weight = sample_weight.astype(X.dtype, copy=False) # XXX: I don't think this should be here

if sp.issparse(X):
dataset = CSRData(X.data, X.indptr, X.indices, y, sample_weight,
seed=seed)
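
Regarding the XXX note above: my reading (an assumption, not something this diff states) is that the cast is needed because the 32-bit dataset helpers expect `sample_weight` to share `X`'s dtype, while `np.ones(...)` produces float64 by default. A minimal sketch of the mismatch:

```python
import numpy as np

X = np.zeros((5, 2), dtype=np.float32)
sample_weight = np.ones(X.shape[0])   # float64 by default
print(sample_weight.dtype)            # float64

# Casting with copy=False is a no-op when the dtypes already match,
# so float64 inputs pay no extra allocation.
sample_weight = sample_weight.astype(X.dtype, copy=False)
print(sample_weight.dtype)            # float32
```
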
46 changes: 24 additions & 22 deletions sklearn/linear_model/sgd_fast.pyx
@@ -16,14 +16,16 @@ import sys
from time import time

cimport cython
from cython cimport floating
from libc.math cimport exp, log, sqrt, pow, fabs
cimport numpy as np
from numpy.math cimport INFINITY
cdef extern from "sgd_fast_helpers.h":
bint skl_isfinite(double) nogil

from sklearn.utils.weight_vector cimport WeightVector
from sklearn.utils.seq_dataset cimport SequentialDataset64 as SequentialDataset
# from sklearn.utils.seq_dataset cimport SequentialDataset64 as SequentialDataset
from sklearn.utils.seq_dataset cimport SequentialDataset32 as SequentialDataset

np.import_array()

@@ -333,7 +335,7 @@ cdef class SquaredEpsilonInsensitive(Regression):
return SquaredEpsilonInsensitive, (self.epsilon,)


def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
def plain_sgd(floating[::1] weights,
double intercept,
LossFunction loss,
int penalty_type,
@@ -450,9 +452,9 @@ def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
return standard_weights, standard_intercept, n_iter_


def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
def average_sgd(floating[::1] weights,
double intercept,
np.ndarray[double, ndim=1, mode='c'] average_weights,
floating[::1] average_weights,
double average_intercept,
LossFunction loss,
int penalty_type,
@@ -579,9 +581,9 @@ def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
average)


def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
def _plain_sgd(floating[::1] weights,
double intercept,
np.ndarray[double, ndim=1, mode='c'] average_weights,
floating[::1] average_weights,
double average_intercept,
LossFunction loss,
int penalty_type,
@@ -605,8 +607,8 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
cdef Py_ssize_t n_features = weights.shape[0]

cdef WeightVector w = WeightVector(weights, average_weights)
cdef double* w_ptr = &weights[0]
cdef double *x_data_ptr = NULL
cdef floating* w_ptr = &weights[0]
cdef floating *x_data_ptr = NULL
cdef int *x_ind_ptr = NULL
cdef double* ps_ptr = NULL

@@ -621,8 +623,8 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
cdef double score = 0.0
cdef double best_loss = INFINITY
cdef double best_score = -INFINITY
cdef double y = 0.0
cdef double sample_weight
cdef floating y = 0.0
cdef floating sample_weight
cdef double class_weight = 1.0
cdef unsigned int count = 0
cdef unsigned int epoch = 0
@@ -638,11 +640,10 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
cdef unsigned char [:] validation_mask_view = validation_mask

# q vector is only used for L1 regularization
cdef np.ndarray[double, ndim = 1, mode = "c"] q = None
cdef double * q_data_ptr = NULL
if penalty_type == L1 or penalty_type == ELASTICNET:
q = np.zeros((n_features,), dtype=np.float64, order="c")
q_data_ptr = <double * > q.data
cdef floating[::1] q = np.zeros((n_features,),
dtype=np.asarray(weights).dtype,
order="c")
cdef floating * q_data_ptr = &q[0]

cdef double u = 0.0

if penalty_type == L2:
@@ -757,7 +758,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,

# floating-point under-/overflow check.
if (not skl_isfinite(intercept)
or any_nonfinite(<double *>weights.data, n_features)):
or any_nonfinite(&weights[0], n_features)):
infinity = True
break

@@ -802,7 +803,8 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights,
return weights, intercept, average_weights, average_intercept, epoch + 1


cdef bint any_nonfinite(double *w, int n) nogil:
cdef bint any_nonfinite(floating *w, int n) nogil:
cdef int i
for i in range(n):
if not skl_isfinite(w[i]):
return True
@@ -819,18 +821,18 @@ cdef double sqnorm(double * x_data_ptr, int * x_ind_ptr, int xnnz) nogil:
return x_norm


cdef void l1penalty(WeightVector w, double * q_data_ptr,
int *x_ind_ptr, int xnnz, double u) nogil:
cdef void l1penalty(WeightVector w, floating * q_data_ptr,
int *x_ind_ptr, int xnnz, floating u) nogil:
"""Apply the L1 penalty to each updated feature

This implements the truncated gradient approach by
[Tsuruoka, Y., Tsujii, J., and Ananiadou, S., 2009].
"""
cdef double z = 0.0
cdef floating z = 0.0
cdef int j = 0
cdef int idx = 0
cdef double wscale = w.wscale
cdef double *w_data_ptr = w.w_data_ptr
cdef floating wscale = w.wscale
cdef floating *w_data_ptr = w.w_data_ptr
for j in range(xnnz):
idx = x_ind_ptr[j]
z = w_data_ptr[idx]
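
For readers of the `l1penalty` helper above: it implements the cumulative (truncated-gradient) L1 penalty of Tsuruoka, Tsujii and Ananiadou (2009). Below is a minimal NumPy sketch of the same update with the `WeightVector` lazy scaling factor (`wscale`) left out; the function name and dense-vector form are mine, not part of this diff:

```python
import numpy as np

def l1_truncated_gradient(w, q, u, idx):
    """Apply the cumulative L1 penalty to the features listed in idx.

    w   -- dense weight vector (modified in place)
    q   -- per-feature penalty actually applied so far (modified in place)
    u   -- total absolute penalty that should have been applied so far
    idx -- indices of the features touched by the current sample
    """
    for j in idx:
        z = w[j]
        if w[j] > 0.0:
            w[j] = max(0.0, w[j] - (u + q[j]))
        elif w[j] < 0.0:
            w[j] = min(0.0, w[j] + (u - q[j]))
        # record how much penalty was actually applied to this feature
        q[j] += w[j] - z

w = np.array([0.5, -0.2, 0.0])
q = np.zeros_like(w)
l1_truncated_gradient(w, q, u=0.1, idx=[0, 1, 2])
print(w)   # [ 0.4 -0.1  0. ]
```

The Cython version folds `wscale` into the comparisons and updates so that the lazily scaled weights held by `WeightVector` never need to be rescaled on every sample.
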
46 changes: 27 additions & 19 deletions sklearn/linear_model/stochastic_gradient.py
@@ -223,7 +223,7 @@ def _validate_sample_weight(self, sample_weight, n_samples):
return sample_weight

def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
intercept_init=None):
intercept_init=None, dtype=np.float64):
"""Allocate mem for parameters; initialize if provided."""
if n_classes > 2:
# allocate coef_ for multi-class
@@ -235,7 +235,7 @@ def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
self.coef_ = coef_init
else:
self.coef_ = np.zeros((n_classes, n_features),
dtype=np.float64, order="C")
dtype=dtype, order="C")

# allocate intercept_ for multi-class
if intercept_init is not None:
@@ -245,12 +245,12 @@ def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
"does not match dataset.")
self.intercept_ = intercept_init
else:
self.intercept_ = np.zeros(n_classes, dtype=np.float64,
self.intercept_ = np.zeros(n_classes, dtype=dtype,
order="C")
else:
# allocate coef_ for binary problem
if coef_init is not None:
coef_init = np.asarray(coef_init, dtype=np.float64,
coef_init = np.asarray(coef_init, dtype=dtype,
order="C")
coef_init = coef_init.ravel()
if coef_init.shape != (n_features,):
@@ -259,28 +259,28 @@ def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
self.coef_ = coef_init
else:
self.coef_ = np.zeros(n_features,
dtype=np.float64,
dtype=dtype,
order="C")

# allocate intercept_ for binary problem
if intercept_init is not None:
intercept_init = np.asarray(intercept_init, dtype=np.float64)
intercept_init = np.asarray(intercept_init, dtype=dtype)
if intercept_init.shape != (1,) and intercept_init.shape != ():
raise ValueError("Provided intercept_init "
"does not match dataset.")
self.intercept_ = intercept_init.reshape(1,)
else:
self.intercept_ = np.zeros(1, dtype=np.float64, order="C")
self.intercept_ = np.zeros(1, dtype=dtype, order="C")

# initialize average parameters
if self.average > 0:
self.standard_coef_ = self.coef_
self.standard_intercept_ = self.intercept_
self.average_coef_ = np.zeros(self.coef_.shape,
dtype=np.float64,
dtype=dtype,
order="C")
self.average_intercept_ = np.zeros(self.standard_intercept_.shape,
dtype=np.float64,
dtype=dtype,
order="C")

def _make_validation_split(self, y):
@@ -331,12 +331,12 @@ def _make_validation_score_cb(self, validation_mask, X, y, sample_weight,
sample_weight[validation_mask], classes=classes)


def _prepare_fit_binary(est, y, i):
def _prepare_fit_binary(est, y, i, dtype=np.float64):
"""Initialization for fit_binary.

Returns y, coef, intercept, average_coef, average_intercept.
"""
y_i = np.ones(y.shape, dtype=np.float64, order="C")
y_i = np.ones(y.shape, dtype=dtype, order="C")
y_i[y != est.classes_[i]] = -1.0
average_intercept = 0
average_coef = None
@@ -412,7 +412,7 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
# if average is not true, average_coef, and average_intercept will be
# unused
y_i, coef, intercept, average_coef, average_intercept = \
_prepare_fit_binary(est, y, i)
_prepare_fit_binary(est, y, i, dtype=X.dtype)
assert y_i.shape[0] == y.shape[0] == sample_weight.shape[0]
dataset, intercept_decay = make_dataset(X, y_i, sample_weight)

@@ -515,7 +515,10 @@ def _partial_fit(self, X, y, alpha, C,
loss, learning_rate, max_iter,
classes, sample_weight,
coef_init, intercept_init):
X, y = check_X_y(X, y, 'csr', dtype=np.float64, order="C",
X, y = check_X_y(X, y,
accept_sparse='csr',
dtype=[np.float64, np.float32],
order="C",
accept_large_sparse=False)

n_samples, n_features = X.shape
Expand All @@ -531,7 +534,8 @@ def _partial_fit(self, X, y, alpha, C,

if getattr(self, "coef_", None) is None or coef_init is not None:
self._allocate_parameter_mem(n_classes, n_features,
coef_init, intercept_init)
coef_init, intercept_init,
dtype=X.dtype)
elif n_features != self.coef_.shape[-1]:
raise ValueError("Number of features %d does not match previous "
"data %d." % (n_features, self.coef_.shape[-1]))
@@ -564,7 +568,10 @@ def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
if hasattr(self, "classes_"):
self.classes_ = None

X, y = check_X_y(X, y, 'csr', dtype=np.float64, order="C",
X, y = check_X_y(X, y,
accept_sparse='csr',
dtype=[np.float64, np.float32],
order="C",
accept_large_sparse=False)

# labels can be encoded as float, int, or string literals
@@ -1136,9 +1143,10 @@ def __init__(self, loss="squared_loss", penalty="l2", alpha=0.0001,

def _partial_fit(self, X, y, alpha, C, loss, learning_rate,
max_iter, sample_weight, coef_init, intercept_init):
X, y = check_X_y(X, y, "csr", copy=False, order='C', dtype=np.float64,
X, y = check_X_y(X, y, "csr", copy=False, order='C',
dtype=[np.float64, np.float32],
accept_large_sparse=False)
y = y.astype(np.float64, copy=False)
y = y.astype(X.dtype, copy=False)  # XXX: isn't this already done in check_X_y?

n_samples, n_features = X.shape

@@ -1153,9 +1161,9 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,
"data %d." % (n_features, self.coef_.shape[-1]))
if self.average > 0 and getattr(self, "average_coef_", None) is None:
self.average_coef_ = np.zeros(n_features,
dtype=np.float64,
dtype=X.dtype,
order="C")
self.average_intercept_ = np.zeros(1, dtype=np.float64, order="C")
self.average_intercept_ = np.zeros(1, dtype=X.dtype, order="C")

self._fit_regressor(X, y, alpha, C, loss, learning_rate,
sample_weight, max_iter)
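
A note on the `dtype=[np.float64, np.float32]` argument used in the `check_X_y` calls above: when `dtype` is a list, the input is left untouched if its dtype is already in the list and is otherwise converted to the first entry. A minimal sketch using `check_array`, which applies the same dtype logic:

```python
import numpy as np
from sklearn.utils import check_array

X32 = np.ones((3, 2), dtype=np.float32)
X_int = np.ones((3, 2), dtype=np.int64)

# float32 is in the accepted list, so it is preserved as-is ...
print(check_array(X32, dtype=[np.float64, np.float32]).dtype)    # float32
# ... any other dtype is converted to the first entry, float64.
print(check_array(X_int, dtype=[np.float64, np.float32]).dtype)  # float64
```
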
21 changes: 21 additions & 0 deletions sklearn/linear_model/tests/test_sgd.py
@@ -7,6 +7,7 @@

from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_allclose
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_less
@@ -1673,3 +1674,23 @@ def test_SGDClassifier_fit_for_all_backends(backend):
with parallel_backend(backend=backend):
clf_parallel.fit(X, y)
assert_array_almost_equal(clf_sequential.coef_, clf_parallel.coef_)


@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier,
SGDRegressor, SparseSGDRegressor])
@pytest.mark.parametrize(
'loss', ['hinge', 'squared_hinge', 'log', 'modified_huber'])
def test_dtype_sgd_match_and_stability(klass, loss):
# rtol = 1e-2 if os.name == 'nt' and _IS_32BIT else 1e-5
rtol = 1e-5
# The parametrized losses only apply to the classifiers; the regressors
# are fitted with their default squared loss.
params = {'alpha': 0.01}
if klass in (SGDClassifier, SparseSGDClassifier):
params['loss'] = loss
clf_dict = dict()
for current_dtype in (np.float32, np.float64):
clf_dict[current_dtype] = (klass(**params)
.fit(X=X.astype(current_dtype, copy=False),
y=np.array(Y, dtype=current_dtype)))

assert clf_dict[np.float32].coef_.dtype == np.float32
assert clf_dict[np.float64].coef_.dtype == np.float64
assert_allclose(clf_dict[np.float32].coef_,
clf_dict[np.float64].coef_,
rtol=rtol)