LogisticRegression convert to float64 (for SAG solver) #13243

Merged: 31 commits, Feb 27, 2019

Commits (31)
c4e223f  Remove unused code (Jun 16, 2017)
9463e21  Squash all the PR 9040 commits (Jun 8, 2017)
c1b795b  FIX tests (NelleV, May 29, 2018)
8a46660  TST more numerically stable test_sgd.test_tol_parameter (ogrisel, May 29, 2018)
9168b8f  Added benchmarks to compare SAGA 32b and 64b (NelleV, May 29, 2018)
8d89662  Fixing gael's comments (NelleV, Jun 1, 2018)
21c3f98  fix (NelleV, Jul 8, 2018)
2534661  Merge remote-tracking branch 'origin/master' into NelleV-henley_is_87… (glemaitre, Jul 17, 2018)
b10e2d7  solve some issues (glemaitre, Jul 17, 2018)
3af8134  Merge remote-tracking branch 'origin/master' into NelleV-henley_is_87… (glemaitre, Jul 17, 2018)
30f2080  PEP8 (glemaitre, Jul 17, 2018)
7babe3e  Address lesteve comments (glemaitre, Jul 19, 2018)
8690bd6  Merge branch 'master' into henley_is_8769_minus_merged (Feb 25, 2019)
7864555  fix merging (Feb 25, 2019)
b83873c  avoid using assert_equal (Feb 25, 2019)
7248300  use all_close (Feb 25, 2019)
53fc38e  use explicit ArrayDataset64 and CSRDataset64 (Feb 25, 2019)
5eb99b8  fix: remove unused import (Feb 25, 2019)
191da7a  Use parametrized to cover ArrayDaset-CSRDataset-32-64 matrix (Feb 25, 2019)
a6e9309  for consistency use 32 first then 64 + add 64 suffix to variables (Feb 25, 2019)
0d06913  it would be cool if this worked !!! (Feb 25, 2019)
635d301  more verbose version (Feb 25, 2019)
380bc10  revert SGD changes as much as possible. (Feb 26, 2019)
273651f  Add solvers back to bench_saga (Feb 26, 2019)
edc56c1  make 64 explicit in the naming (Feb 26, 2019)
c17000f  remove checking native python type + add comparison between 32 64 (Feb 26, 2019)
e5a034d  Add whatsnew with everyone with commits (Feb 26, 2019)
567e640  simplify a bit the testing (Feb 26, 2019)
0574e90  simplify the parametrize (Feb 26, 2019)
6bae447  update whatsnew (Feb 26, 2019)
959cd8c  fix pep8 (Feb 27, 2019)
5 changes: 5 additions & 0 deletions .gitignore
@@ -71,3 +71,8 @@ _configtest.o.d

# Used by mypy
.mypy_cache/

+# files generated from a template
+sklearn/utils/seq_dataset.pyx
+sklearn/utils/seq_dataset.pxd
+sklearn/linear_model/sag_fast.pyx
151 changes: 106 additions & 45 deletions benchmarks/bench_saga.py
@@ -1,11 +1,11 @@
-"""Author: Arthur Mensch
+"""Author: Arthur Mensch, Nelle Varoquaux

Benchmarks of sklearn SAGA vs lightning SAGA vs Liblinear. Shows the gain
in using multinomial logistic regression in term of learning time.
"""
import json
import time
-from os.path import expanduser
+import os

from joblib import delayed, Parallel, Memory
import matplotlib.pyplot as plt
@@ -21,7 +21,7 @@


def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
-max_iter=10, skip_slow=False):
+max_iter=10, skip_slow=False, dtype=np.float64):
if skip_slow and solver == 'lightning' and penalty == 'l1':
print('skip_slowping l1 logistic regression with solver lightning.')
return
@@ -37,7 +37,8 @@ def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
multi_class = 'ovr'
else:
multi_class = 'multinomial'

+X = X.astype(dtype)
+y = y.astype(dtype)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42,
stratify=y)
n_samples = X_train.shape[0]
@@ -69,11 +70,15 @@ def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
multi_class=multi_class,
C=C,
penalty=penalty,
-fit_intercept=False, tol=1e-24,
+fit_intercept=False, tol=0,
max_iter=this_max_iter,
random_state=42,
)

+# Makes cpu cache even for all fit calls
+X_train.max()
t0 = time.clock()

lr.fit(X_train, y_train)
train_time = time.clock() - t0

@@ -106,9 +111,13 @@ def _predict_proba(lr, X):
return softmax(pred)


-def exp(solvers, penalties, single_target, n_samples=30000, max_iter=20,
+def exp(solvers, penalty, single_target,
+n_samples=30000, max_iter=20,
dataset='rcv1', n_jobs=1, skip_slow=False):
-mem = Memory(cachedir=expanduser('~/cache'), verbose=0)
+dtypes_mapping = {
+"float64": np.float64,
+"float32": np.float32,
+}

if dataset == 'rcv1':
rcv1 = fetch_rcv1()
@@ -151,21 +160,24 @@ def exp(solvers, penalties, single_target, n_samples=30000, max_iter=20,
X = X[:n_samples]
y = y[:n_samples]

-cached_fit = mem.cache(fit_single)
out = Parallel(n_jobs=n_jobs, mmap_mode=None)(
-delayed(cached_fit)(solver, X, y,
+delayed(fit_single)(solver, X, y,
penalty=penalty, single_target=single_target,
+dtype=dtype,
C=1, max_iter=max_iter, skip_slow=skip_slow)
for solver in solvers
-for penalty in penalties)
+for dtype in dtypes_mapping.values())

res = []
idx = 0
-for solver in solvers:
-for penalty in penalties:
-if not (skip_slow and solver == 'lightning' and penalty == 'l1'):
+for dtype_name in dtypes_mapping.keys():
+for solver in solvers:
+if not (skip_slow and
+solver == 'lightning' and
+penalty == 'l1'):
lr, times, train_scores, test_scores, accuracies = out[idx]
this_res = dict(solver=solver, penalty=penalty,
+dtype=dtype_name,
single_target=single_target,
times=times, train_scores=train_scores,
test_scores=test_scores,
@@ -177,68 +189,117 @@
json.dump(res, f)


-def plot():
+def plot(outname=None):
import pandas as pd
with open('bench_saga.json', 'r') as f:
f = json.load(f)
res = pd.DataFrame(f)
-res.set_index(['single_target', 'penalty'], inplace=True)
+res.set_index(['single_target'], inplace=True)

-grouped = res.groupby(level=['single_target', 'penalty'])
+grouped = res.groupby(level=['single_target'])

-colors = {'saga': 'blue', 'liblinear': 'orange', 'lightning': 'green'}
+colors = {'saga': 'C0', 'liblinear': 'C1', 'lightning': 'C2'}
+linestyles = {"float32": "--", "float64": "-"}
+alpha = {"float64": 0.5, "float32": 1}

for idx, group in grouped:
-single_target, penalty = idx
-fig = plt.figure(figsize=(12, 4))
-ax = fig.add_subplot(131)

-train_scores = group['train_scores'].values
-ref = np.min(np.concatenate(train_scores)) * 0.999

-for scores, times, solver in zip(group['train_scores'], group['times'],
-group['solver']):
-scores = scores / ref - 1
-ax.plot(times, scores, label=solver, color=colors[solver])
+single_target = idx
+fig, axes = plt.subplots(figsize=(12, 4), ncols=4)
+ax = axes[0]

+for scores, times, solver, dtype in zip(group['train_scores'],
+group['times'],
+group['solver'],
+group["dtype"]):
+ax.plot(times, scores, label="%s - %s" % (solver, dtype),
+color=colors[solver],
+alpha=alpha[dtype],
+marker=".",
+linestyle=linestyles[dtype])
+ax.axvline(times[-1], color=colors[solver],
+alpha=alpha[dtype],
+linestyle=linestyles[dtype])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Training objective (relative to min)')
ax.set_yscale('log')

-ax = fig.add_subplot(132)
+ax = axes[1]

-test_scores = group['test_scores'].values
-ref = np.min(np.concatenate(test_scores)) * 0.999
+for scores, times, solver, dtype in zip(group['test_scores'],
+group['times'],
+group['solver'],
+group["dtype"]):
+ax.plot(times, scores, label=solver, color=colors[solver],
+linestyle=linestyles[dtype],
+marker=".",
+alpha=alpha[dtype])
+ax.axvline(times[-1], color=colors[solver],
+alpha=alpha[dtype],
+linestyle=linestyles[dtype])

-for scores, times, solver in zip(group['test_scores'], group['times'],
-group['solver']):
-scores = scores / ref - 1
-ax.plot(times, scores, label=solver, color=colors[solver])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Test objective (relative to min)')
ax.set_yscale('log')

-ax = fig.add_subplot(133)
+ax = axes[2]
+for accuracy, times, solver, dtype in zip(group['accuracies'],
+group['times'],
+group['solver'],
+group["dtype"]):
+ax.plot(times, accuracy, label="%s - %s" % (solver, dtype),
+alpha=alpha[dtype],
+marker=".",
+color=colors[solver], linestyle=linestyles[dtype])
+ax.axvline(times[-1], color=colors[solver],
+alpha=alpha[dtype],
+linestyle=linestyles[dtype])

-for accuracy, times, solver in zip(group['accuracies'], group['times'],
-group['solver']):
-ax.plot(times, accuracy, label=solver, color=colors[solver])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Test accuracy')
ax.legend()
name = 'single_target' if single_target else 'multi_target'
name += '_%s' % penalty
plt.suptitle(name)
-name += '.png'
+if outname is None:
+outname = name + '.png'
+fig.tight_layout()
+fig.subplots_adjust(top=0.9)
-plt.savefig(name)
-plt.close(fig)

+ax = axes[3]
+for scores, times, solver, dtype in zip(group['train_scores'],
+group['times'],
+group['solver'],
+group["dtype"]):
+ax.plot(np.arange(len(scores)),
+scores, label="%s - %s" % (solver, dtype),
+marker=".",
+alpha=alpha[dtype],
+color=colors[solver], linestyle=linestyles[dtype])

+ax.set_yscale("log")
+ax.set_xlabel('# iterations')
+ax.set_ylabel('Objective function')
+ax.legend()

+plt.savefig(outname)


if __name__ == '__main__':
solvers = ['saga', 'liblinear', 'lightning']
penalties = ['l1', 'l2']
+n_samples = [100000, 300000, 500000, 800000, None]
single_target = True
-exp(solvers, penalties, single_target, n_samples=None, n_jobs=1,
-dataset='20newspaper', max_iter=20)
-plot()
+for penalty in penalties:
+for n_sample in n_samples:
+exp(solvers, penalty, single_target,
+n_samples=n_sample, n_jobs=1,
+dataset='rcv1', max_iter=10)
+if n_sample is not None:
+outname = "figures/saga_%s_%d.png" % (penalty, n_sample)
+else:
+outname = "figures/saga_%s_all.png" % (penalty,)
+try:
+os.makedirs("figures")
+except OSError:
+pass
+plot(outname)
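
For readers who want to reproduce the comparison outside the benchmark harness, here is a minimal sketch of what bench_saga.py measures, on synthetic data instead of rcv1; the dataset size, tolerance, and max_iter below are illustrative assumptions, not the benchmark's settings:

import time

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=20000, n_features=100, random_state=42)

for dtype in (np.float32, np.float64):
    X_cast = X.astype(dtype)
    lr = LogisticRegression(solver='saga', tol=1e-3, max_iter=200,
                            random_state=42)
    t0 = time.perf_counter()
    lr.fit(X_cast, y)
    # After this PR, SAGA keeps 32-bit input in 32-bit: coef_ matches X's dtype
    print(dtype.__name__, time.perf_counter() - t0, lr.coef_.dtype)

Casting X once up front mirrors fit_single's X.astype(dtype); with the 'saga' solver the fitted coef_ stays in the input precision, which is what makes the 32-bit runs faster and lighter on memory.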
5 changes: 5 additions & 0 deletions doc/whats_new/v0.21.rst
@@ -137,6 +137,11 @@ Support for Python 3.4 and below has been officially dropped.
:mod:`sklearn.linear_model`
...........................

+- |Enhancement| :func:`linear_model.make_dataset` now preserves
+  ``float32`` and ``float64`` dtypes. :issue:`8769` and :issue:`11000` by
+  :user:`Nelle Varoquaux <NelleV>`, :user:`Arthur Imbert <Henley13>`,
+  :user:`Guillaume Lemaitre <glemaitre>`, and :user:`Joan Massich <massich>`.

- |Feature| :class:`linear_model.LogisticRegression` and
:class:`linear_model.LogisticRegressionCV` now support Elastic-Net penalty,
with the 'saga' solver. :issue:`11646` by :user:`Nicolas Hug <NicolasHug>`.
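As a quick illustration of the entry above, the following sketch exercises the sparse code path; make_dataset is a private helper, so the sklearn.linear_model.base import path is an assumption tied to the 0.21-era layout and may move in later releases:

import numpy as np
import scipy.sparse as sp
from sklearn.linear_model.base import make_dataset

rng = np.random.RandomState(0)
X = sp.random(20, 5, density=0.3, format='csr',
              dtype=np.float32, random_state=rng)
y = rng.rand(20).astype(np.float32)
sample_weight = np.ones(20, dtype=np.float32)

dataset, intercept_decay = make_dataset(X, y, sample_weight, random_state=0)
# Sparse float32 input now yields the 32-bit dataset wrapper
print(type(dataset).__name__)  # CSRDataset32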
18 changes: 13 additions & 5 deletions sklearn/linear_model/base.py
@@ -32,7 +32,8 @@
from ..utils.extmath import safe_sparse_dot
from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
from ..utils.fixes import sparse_lsqr
-from ..utils.seq_dataset import ArrayDataset, CSRDataset
+from ..utils.seq_dataset import ArrayDataset32, CSRDataset32
+from ..utils.seq_dataset import ArrayDataset64, CSRDataset64
from ..utils.validation import check_is_fitted
from ..exceptions import NotFittedError
from ..preprocessing.data import normalize as f_normalize
@@ -76,15 +77,22 @@ def make_dataset(X, y, sample_weight, random_state=None):
"""

rng = check_random_state(random_state)
-# seed should never be 0 in SequentialDataset
+# seed should never be 0 in SequentialDataset64
seed = rng.randint(1, np.iinfo(np.int32).max)

+if X.dtype == np.float32:
+CSRData = CSRDataset32
+ArrayData = ArrayDataset32
+else:
+CSRData = CSRDataset64
+ArrayData = ArrayDataset64

if sp.issparse(X):
-dataset = CSRDataset(X.data, X.indptr, X.indices, y, sample_weight,
-seed=seed)
+dataset = CSRData(X.data, X.indptr, X.indices, y, sample_weight,
+seed=seed)
intercept_decay = SPARSE_INTERCEPT_DECAY
else:
-dataset = ArrayDataset(X, y, sample_weight, seed=seed)
+dataset = ArrayData(X, y, sample_weight, seed=seed)
intercept_decay = 1.0

return dataset, intercept_decay
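A short sketch of the resulting dispatch on dense input (same caveat: make_dataset is private and the import path is the 0.21-era one):

import numpy as np
from sklearn.linear_model.base import make_dataset

rng = np.random.RandomState(0)
for dtype in (np.float32, np.float64):
    X = rng.rand(10, 3).astype(dtype)
    y = rng.rand(10).astype(dtype)
    sample_weight = np.ones(10, dtype=dtype)
    dataset, _ = make_dataset(X, y, sample_weight, random_state=0)
    # float32 -> ArrayDataset32, anything else -> ArrayDataset64
    print(dtype.__name__, type(dataset).__name__)

Only an exact float32 dtype selects the 32-bit classes; integer or float64 input falls through to the 64-bit branch.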
12 changes: 8 additions & 4 deletions sklearn/linear_model/logistic.py
@@ -965,7 +965,7 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,

elif solver in ['sag', 'saga']:
if multi_class == 'multinomial':
-target = target.astype(np.float64)
+target = target.astype(X.dtype, copy=False)
loss = 'multinomial'
else:
loss = 'log'
@@ -1487,6 +1487,10 @@ def fit(self, X, y, sample_weight=None):
Returns
-------
self : object

+Notes
+-----
+The SAGA solver supports both float64 and float32 bit arrays.
"""
solver = _check_solver(self.solver, self.penalty, self.dual)

@@ -1521,10 +1525,10 @@
raise ValueError("Tolerance for stopping criteria must be "
"positive; got (tol=%r)" % self.tol)

-if solver in ['newton-cg']:
-_dtype = [np.float64, np.float32]
-else:
+if solver in ['lbfgs', 'liblinear']:
_dtype = np.float64
+else:
+_dtype = [np.float64, np.float32]

X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype, order="C",
accept_large_sparse=solver != 'liblinear')
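Putting the pieces together, here is a hedged sketch of the dtype behaviour this diff establishes end to end; the solver coverage reflects the branch above at the time of the PR (later releases may widen float32 support):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=200, n_features=10, random_state=0)
X32 = X.astype(np.float32)

for solver in ('saga', 'sag', 'newton-cg', 'lbfgs'):
    clf = LogisticRegression(solver=solver, max_iter=1000, random_state=0)
    clf.fit(X32, y)
    # saga/sag/newton-cg keep float32; lbfgs falls in the float64-only branch
    print(solver, clf.coef_.dtype)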