Skip to content

Commit e9b9f23

Browse files
thomasjpfan authored and ogrisel committed
CLN Fixes PendingDeprecationWarning in CountVectorizer (#19299)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent e1781d4 commit e9b9f23

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

sklearn/feature_extraction/tests/test_text.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -773,18 +773,30 @@ def test_vectorizer_inverse_transform(Vectorizer):
773773
vectorizer = Vectorizer()
774774
transformed_data = vectorizer.fit_transform(data)
775775
inversed_data = vectorizer.inverse_transform(transformed_data)
776+
assert isinstance(inversed_data, list)
777+
776778
analyze = vectorizer.build_analyzer()
777779
for doc, inversed_terms in zip(data, inversed_data):
778780
terms = np.sort(np.unique(analyze(doc)))
779781
inversed_terms = np.sort(np.unique(inversed_terms))
780782
assert_array_equal(terms, inversed_terms)
781783

782-
# Test that inverse_transform also works with numpy arrays
783-
transformed_data = transformed_data.toarray()
784-
inversed_data2 = vectorizer.inverse_transform(transformed_data)
784+
assert sparse.issparse(transformed_data)
785+
assert transformed_data.format == "csr"
786+
787+
# Test that inverse_transform also works with numpy arrays and
788+
# scipy
789+
transformed_data2 = transformed_data.toarray()
790+
inversed_data2 = vectorizer.inverse_transform(transformed_data2)
785791
for terms, terms2 in zip(inversed_data, inversed_data2):
786792
assert_array_equal(np.sort(terms), np.sort(terms2))
787793

794+
# Check that inverse_transform also works on non CSR sparse data:
795+
transformed_data3 = transformed_data.tocsc()
796+
inversed_data3 = vectorizer.inverse_transform(transformed_data3)
797+
for terms, terms3 in zip(inversed_data, inversed_data3):
798+
assert_array_equal(np.sort(terms), np.sort(terms3))
799+
788800

789801
def test_count_vectorizer_pipeline_grid_selection():
790802
# raw documents

sklearn/feature_extraction/text.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -1270,22 +1270,20 @@ def inverse_transform(self, X):
12701270
List of arrays of terms.
12711271
"""
12721272
self._check_vocabulary()
1273-
1274-
if sp.issparse(X):
1275-
# We need CSR format for fast row manipulations.
1276-
X = X.tocsr()
1277-
else:
1278-
# We need to convert X to a matrix, so that the indexing
1279-
# returns 2D objects
1280-
X = np.asmatrix(X)
1273+
# We need CSR format for fast row manipulations.
1274+
X = check_array(X, accept_sparse='csr')
12811275
n_samples = X.shape[0]
12821276

12831277
terms = np.array(list(self.vocabulary_.keys()))
12841278
indices = np.array(list(self.vocabulary_.values()))
12851279
inverse_vocabulary = terms[np.argsort(indices)]
12861280

1287-
return [inverse_vocabulary[X[i, :].nonzero()[1]].ravel()
1288-
for i in range(n_samples)]
1281+
if sp.issparse(X):
1282+
return [inverse_vocabulary[X[i, :].nonzero()[1]].ravel()
1283+
for i in range(n_samples)]
1284+
else:
1285+
return [inverse_vocabulary[np.flatnonzero(X[i, :])].ravel()
1286+
for i in range(n_samples)]
12891287

12901288
def get_feature_names(self):
12911289
"""Array mapping from feature integer indices to feature name.

0 commit comments

Comments
 (0)