Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 31 additions & 24 deletions sklearn/linear_model/tests/test_logistic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import numpy as np
import scipy.sparse as sp
from scipy import linalg, optimize, sparse
Expand All @@ -9,7 +10,7 @@
from sklearn.metrics.scorer import get_scorer
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import compute_class_weight
from sklearn.utils import compute_class_weight, _IS_32BIT
from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_allclose
from sklearn.utils.testing import assert_array_almost_equal
Expand Down Expand Up @@ -1255,7 +1256,8 @@ def test_saga_vs_liblinear():
assert_array_almost_equal(saga.coef_, liblinear.coef_, 3)


def test_dtype_match():
@pytest.mark.parametrize('multi_class', ['ovr', 'multinomial'])
def test_dtype_match(multi_class):
# Test that np.float32 input data is not cast to np.float64 when possible

X_32 = np.array(X).astype(np.float32)
Expand All @@ -1264,28 +1266,33 @@ def test_dtype_match():
y_64 = np.array(Y1).astype(np.float64)
X_sparse_32 = sp.csr_matrix(X, dtype=np.float32)

for solver in ['newton-cg']:
for multi_class in ['ovr', 'multinomial']:

# Check type consistency
lr_32 = LogisticRegression(solver=solver, multi_class=multi_class,
random_state=42)
lr_32.fit(X_32, y_32)
assert_equal(lr_32.coef_.dtype, X_32.dtype)

# check consistency with sparsity
lr_32_sparse = LogisticRegression(solver=solver,
multi_class=multi_class,
random_state=42)
lr_32_sparse.fit(X_sparse_32, y_32)
assert_equal(lr_32_sparse.coef_.dtype, X_sparse_32.dtype)

# Check accuracy consistency
lr_64 = LogisticRegression(solver=solver, multi_class=multi_class,
random_state=42)
lr_64.fit(X_64, y_64)
assert_equal(lr_64.coef_.dtype, X_64.dtype)
assert_almost_equal(lr_32.coef_, lr_64.coef_.astype(np.float32))
solver = 'newton-cg'

# Check type consistency
lr_32 = LogisticRegression(solver=solver, multi_class=multi_class,
random_state=42)
lr_32.fit(X_32, y_32)
assert_equal(lr_32.coef_.dtype, X_32.dtype)

# check consistency with sparsity
lr_32_sparse = LogisticRegression(solver=solver,
multi_class=multi_class,
random_state=42)
lr_32_sparse.fit(X_sparse_32, y_32)
assert_equal(lr_32_sparse.coef_.dtype, X_sparse_32.dtype)

# Check accuracy consistency
lr_64 = LogisticRegression(solver=solver, multi_class=multi_class,
random_state=42)
lr_64.fit(X_64, y_64)
assert_equal(lr_64.coef_.dtype, X_64.dtype)

rtol = 1e-6
if os.name == 'nt' and _IS_32BIT:
# FIXME
rtol = 1e-2

assert_allclose(lr_32.coef_, lr_64.coef_.astype(np.float32), rtol=rtol)


def test_warm_start_converge_LR():
Expand Down
2 changes: 2 additions & 0 deletions sklearn/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import numbers
import platform
import struct

import numpy as np
from scipy.sparse import issparse
Expand Down Expand Up @@ -34,6 +35,7 @@
"register_parallel_backend", "hash", "effective_n_jobs"]

IS_PYPY = platform.python_implementation() == 'PyPy'
_IS_32BIT = 8 * struct.calcsize("P") == 32


class Bunch(dict):
Expand Down
12 changes: 3 additions & 9 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
import traceback
import pickle
from copy import deepcopy
import struct
from functools import partial

import numpy as np
from scipy import sparse
from scipy.stats import rankdata

from sklearn.externals.six.moves import zip
from sklearn.utils import IS_PYPY
from sklearn.utils import IS_PYPY, _IS_32BIT
from sklearn.utils._joblib import hash, Memory
from sklearn.utils.testing import assert_raises, _get_args
from sklearn.utils.testing import assert_raises_regex
Expand Down Expand Up @@ -404,11 +403,6 @@ def __array__(self, dtype=None):
return self.data


def _is_32bit():
"""Detect if process is 32bit Python."""
return struct.calcsize('P') * 8 == 32


def _is_pairwise(estimator):
"""Returns True if estimator has a _pairwise attribute set to True.

Expand Down Expand Up @@ -943,7 +937,7 @@ def check_transformers_unfitted(name, transformer):


def _check_transformer(name, transformer_orig, X, y):
if name in ('CCA', 'LocallyLinearEmbedding', 'KernelPCA') and _is_32bit():
if name in ('CCA', 'LocallyLinearEmbedding', 'KernelPCA') and _IS_32BIT:
# Those transformers yield non-deterministic output when executed on
# a 32bit Python. The same transformers are stable on 64bit Python.
# FIXME: try to isolate a minimalistic reproduction case only depending
Expand Down Expand Up @@ -1021,7 +1015,7 @@ def _check_transformer(name, transformer_orig, X, y):

@ignore_warnings
def check_pipeline_consistency(name, estimator_orig):
if name in ('CCA', 'LocallyLinearEmbedding', 'KernelPCA') and _is_32bit():
if name in ('CCA', 'LocallyLinearEmbedding', 'KernelPCA') and _IS_32BIT:
# Those transformers yield non-deterministic output when executed on
# a 32bit Python. The same transformers are stable on 64bit Python.
# FIXME: try to isolate a minimalistic reproduction case only depending
Expand Down
5 changes: 2 additions & 3 deletions sklearn/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import pkgutil
import warnings
import sys
import struct
import functools

import scipy as sp
Expand Down Expand Up @@ -47,7 +46,7 @@
from sklearn.base import BaseEstimator
from sklearn.externals import joblib
from sklearn.utils.fixes import signature
from sklearn.utils import deprecated, IS_PYPY
from sklearn.utils import deprecated, IS_PYPY, _IS_32BIT


additional_names_in_all = []
Expand Down Expand Up @@ -758,7 +757,7 @@ def run_test(*args, **kwargs):
try:
import pytest

skip_if_32bit = pytest.mark.skipif(8 * struct.calcsize("P") == 32,
skip_if_32bit = pytest.mark.skipif(_IS_32BIT,
reason='skipped on 32bit platforms')
skip_travis = pytest.mark.skipif(os.environ.get('TRAVIS') == 'true',
reason='skip on travis')
Expand Down