Describe the bug
When warm-starting multinomial logistic regression from a previous fit, the final coefficients are correct, but the intercepts end up with incorrect values.
As a result, predictions from the warm-started model differ from those of a cold-started model fitted on the same data with the same total number of iterations.
This appears to be a recent regression: it works with scikit-learn 1.5, but not with 1.6 or 1.7.
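The report relies on the expectation that, for a deterministic iterative solver, one fit with `max_iter=2` ends at the same point as two warm-started fits with `max_iter=1`, provided the warm start restores the full state, intercepts included. The sketch below illustrates that property with a hypothetical plain gradient-descent solver: `fit`, `softmax`, the learning rate and the synthetic data are assumptions for illustration, not scikit-learn code (the `newton-cholesky` solver takes Newton steps, not gradient steps). It also shows that a warm start which drops the intercept diverges; this only mirrors the kind of symptom reported and says nothing about the actual cause inside scikit-learn.

```python
import numpy as np


def softmax(Z):
    # Subtract the row max for numerical stability; softmax is unchanged by this shift.
    Z = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)


def fit(X, y, n_classes, max_iter, lr=0.1, init=None):
    """Hypothetical gradient-descent solver for multinomial logistic regression.

    `init` is an optional (coef, intercept) pair used as a warm start.
    """
    n_samples, n_features = X.shape
    if init is None:
        coef = np.zeros((n_classes, n_features))
        intercept = np.zeros(n_classes)
    else:
        # Copy so the caller's arrays are not modified in place.
        coef, intercept = init[0].copy(), init[1].copy()
    Y = np.eye(n_classes)[y]  # one-hot encoded targets
    for _ in range(max_iter):
        P = softmax(X @ coef.T + intercept)
        G = (P - Y) / n_samples  # gradient of the mean cross-entropy w.r.t. the logits
        coef -= lr * G.T @ X
        intercept -= lr * G.sum(axis=0)
    return coef, intercept


rng = np.random.default_rng(0)
X = rng.normal(size=(30, 4))
y = np.array([0] * 15 + [1] * 10 + [2] * 5)  # unbalanced, so intercepts move away from zero

# Cold start: two iterations in one call.
cold = fit(X, y, 3, max_iter=2)

# Correct warm start: the second call resumes from both coef and intercept.
warm = fit(X, y, 3, max_iter=1)
warm = fit(X, y, 3, max_iter=1, init=warm)

# Faulty warm start (hypothetical): coef is carried over, intercept is reset to zero.
bad = fit(X, y, 3, max_iter=1)
bad = fit(X, y, 3, max_iter=1, init=(bad[0], np.zeros(3)))

print(np.allclose(cold[0], warm[0]), np.allclose(cold[1], warm[1]))  # True True
print(np.allclose(warm[1], bad[1]))  # False
```

Because both paths execute the same arithmetic in the same order, the correct warm start matches the cold start exactly; any state the warm start fails to restore shows up as a mismatch.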
Steps/Code to Reproduce
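```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)

# Cold start: a single fit with two solver iterations.
model1 = LogisticRegression(
    solver="newton-cholesky",
    max_iter=2,
).fit(X, y)

# Warm start: two successive fits of one iteration each; the second fit
# reuses the solution of the first as its initialization.
model2 = LogisticRegression(
    solver="newton-cholesky",
    max_iter=1,
    warm_start=True,
).fit(X, y).fit(X, y)

# Passes: the coefficients agree.
np.testing.assert_almost_equal(
    model1.coef_,
    model2.coef_,
)

# Fails on 1.6 and 1.7: the predicted probabilities differ because the intercepts differ.
np.testing.assert_almost_equal(
    model1.predict_proba(X[:5]),
    model2.predict_proba(X[:5]),
)
```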
Expected Results
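The intercepts of both models should be identical, or at most differ by the same constant for every class (which leaves the softmax probabilities unchanged), so the predicted probabilities should match.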
Actual Results
The intercepts differ by more than a constant shift, so the predicted probabilities also differ and the second assertion fails.
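The expected result rests on the fact that softmax probabilities do not change when the same constant is added to every class's logit, which is why the report allows the intercepts to differ by a constant. A minimal NumPy sketch of that check follows; `softmax` and `same_up_to_shift` are hypothetical helpers, and the random coefficients and intercept values stand in for a fitted model.

```python
import numpy as np


def softmax(Z):
    # Subtract the row max for numerical stability.
    Z = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)


def same_up_to_shift(b1, b2, atol=1e-8):
    """Return True if two intercept vectors differ only by a constant added to every class."""
    b1 = np.asarray(b1, dtype=float)
    b2 = np.asarray(b2, dtype=float)
    # Centering removes any common constant, so equal centred vectors mean b2 = b1 + c.
    return np.allclose(b1 - b1.mean(), b2 - b2.mean(), atol=atol)


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 4))
coef = rng.normal(size=(3, 4))

b = np.array([0.5, -1.0, 0.25])
shifted = b + 3.0                    # same constant added to every class
other = np.array([0.5, -1.0, 1.0])   # only one class changes: not a constant shift

p1 = softmax(X @ coef.T + b)
p2 = softmax(X @ coef.T + shifted)
p3 = softmax(X @ coef.T + other)

print(same_up_to_shift(b, shifted), np.allclose(p1, p2))  # True True
print(same_up_to_shift(b, other), np.allclose(p1, p3))    # False False
```

Applied to the reproducer, a check along the lines of `same_up_to_shift(model1.intercept_, model2.intercept_)` would be expected to return `True` on a version where warm starting works as intended.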
Versions
```
System:
    python: 3.12.6 | packaged by conda-forge | (main, Sep 22 2024, 14:16:49) [GCC 13.3.0]
executable: /home/david/miniforge3/bin/python
   machine: Linux-6.12.33+deb12-amd64-x86_64-with-glibc2.36

Python dependencies:
      sklearn: 1.7.1
          pip: 24.2
   setuptools: 74.1.2
        numpy: 2.0.1
        scipy: 1.14.1
       Cython: 3.1.0
       pandas: 2.2.3
   matplotlib: 3.9.2
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 20
         prefix: libscipy_openblas
       filepath: /home/david/.local/lib/python3.12/site-packages/numpy.libs/libscipy_openblas64_-99b71e71.so
        version: 0.3.27
threading_layer: pthreads
   architecture: Haswell

       user_api: blas
   internal_api: mkl
    num_threads: 14
         prefix: libmkl_rt
       filepath: /home/david/miniforge3/lib/libmkl_rt.so.2
        version: 2023.2-Product
threading_layer: gnu

       user_api: openmp
   internal_api: openmp
    num_threads: 20
         prefix: libgomp
       filepath: /home/david/miniforge3/lib/libgomp.so.1.0.0
        version: None
```