Skip to content

Commit 5d040fc

Browse files
lesteveagramfort
authored andcommitted
Add download_if_missing argument to fetch_20newsgroups_vectorized (#9425)
1 parent 93d0766 commit 5d040fc

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

sklearn/datasets/twenty_newsgroups.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,8 @@ def fetch_20newsgroups(data_home=None, subset='train', categories=None,
283283
return data
284284

285285

286-
def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None):
286+
def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None,
287+
download_if_missing=True):
287288
"""Load the 20 newsgroups dataset and transform it into tf-idf vectors.
288289
289290
This is a convenience function; the tf-idf transformation is done using the
@@ -313,6 +314,10 @@ def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None):
313314
Specify an download and cache folder for the datasets. If None,
314315
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
315316
317+
download_if_missing : optional, True by default
318+
If False, raise an IOError if the data is not locally available
319+
instead of trying to download the data from the source site.
320+
316321
Returns
317322
-------
318323
bunch : Bunch object
@@ -332,14 +337,16 @@ def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None):
332337
categories=None,
333338
shuffle=True,
334339
random_state=12,
335-
remove=remove)
340+
remove=remove,
341+
download_if_missing=download_if_missing)
336342

337343
data_test = fetch_20newsgroups(data_home=data_home,
338344
subset='test',
339345
categories=None,
340346
shuffle=True,
341347
random_state=12,
342-
remove=remove)
348+
remove=remove,
349+
download_if_missing=download_if_missing)
343350

344351
if os.path.exists(target_file):
345352
X_train, X_test = joblib.load(target_file)

0 commit comments

Comments
 (0)