Skip to content
14 changes: 8 additions & 6 deletions sklearn/feature_selection/mutual_info_.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..preprocessing import scale
from ..utils import check_random_state
from ..utils.fixes import _astype_copy_false
from ..utils.validation import check_X_y
from ..utils.validation import check_array, check_X_y
from ..utils.multiclass import check_classification_targets


Expand Down Expand Up @@ -247,14 +247,16 @@ def _estimate_mi(X, y, discrete_features='auto', discrete_target=False,
X, y = check_X_y(X, y, accept_sparse='csc', y_numeric=not discrete_target)
n_samples, n_features = X.shape

if discrete_features == 'auto':
discrete_features = issparse(X)

if isinstance(discrete_features, bool):
if isinstance(discrete_features, (str, bool)):
if isinstance(discrete_features, str):
if discrete_features == 'auto':
discrete_features = issparse(X)
else:
raise ValueError("Invalid string value for discrete_features.")
discrete_mask = np.empty(n_features, dtype=bool)
discrete_mask.fill(discrete_features)
else:
discrete_features = np.asarray(discrete_features)
discrete_features = check_array(discrete_features, ensure_2d=False)
if discrete_features.dtype != 'bool':
discrete_mask = np.zeros(n_features, dtype=bool)
discrete_mask[discrete_features] = True
Expand Down
18 changes: 13 additions & 5 deletions sklearn/feature_selection/tests/test_mutual_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,26 @@ def test_mutual_info_options():
X_csr = csr_matrix(X)

for mutual_info in (mutual_info_regression, mutual_info_classif):
assert_raises(ValueError, mutual_info_regression, X_csr, y,
assert_raises(ValueError, mutual_info, X_csr, y,
discrete_features=False)
assert_raises(ValueError, mutual_info, X, y,
discrete_features='manual')
assert_raises(ValueError, mutual_info, X_csr, y,
discrete_features=[True, False, True])
assert_raises(IndexError, mutual_info, X, y,
discrete_features=[True, False, True, False])
assert_raises(IndexError, mutual_info, X, y, discrete_features=[1, 4])

mi_1 = mutual_info(X, y, discrete_features='auto', random_state=0)
mi_2 = mutual_info(X, y, discrete_features=False, random_state=0)

mi_3 = mutual_info(X_csr, y, discrete_features='auto',
random_state=0)
mi_4 = mutual_info(X_csr, y, discrete_features=True,
mi_3 = mutual_info(X_csr, y, discrete_features='auto', random_state=0)
mi_4 = mutual_info(X_csr, y, discrete_features=True, random_state=0)
mi_5 = mutual_info(X, y, discrete_features=[True, False, True],
random_state=0)
mi_6 = mutual_info(X, y, discrete_features=[0, 2], random_state=0)

assert_array_equal(mi_1, mi_2)
assert_array_equal(mi_3, mi_4)
assert_array_equal(mi_5, mi_6)

assert not np.allclose(mi_1, mi_3)