@@ -283,7 +283,8 @@ def fetch_20newsgroups(data_home=None, subset='train', categories=None,
     return data


-def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None):
+def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None,
+                                  download_if_missing=True):
     """Load the 20 newsgroups dataset and transform it into tf-idf vectors.

     This is a convenience function; the tf-idf transformation is done using the
@@ -313,6 +314,10 @@ def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None):
         Specify a download and cache folder for the datasets. If None,
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

+    download_if_missing : optional, True by default
+        If False, raise an IOError if the data is not locally available
+        instead of trying to download the data from the source site.
+
     Returns
     -------
     bunch : Bunch object
@@ -332,14 +337,16 @@ def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None):
                                     categories=None,
                                     shuffle=True,
                                     random_state=12,
-                                    remove=remove)
+                                    remove=remove,
+                                    download_if_missing=download_if_missing)

     data_test = fetch_20newsgroups(data_home=data_home,
                                    subset='test',
                                    categories=None,
                                    shuffle=True,
                                    random_state=12,
-                                   remove=remove)
+                                   remove=remove,
+                                   download_if_missing=download_if_missing)

     if os.path.exists(target_file):
         X_train, X_test = joblib.load(target_file)
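For reference, a minimal usage sketch of the parameter introduced in this diff (not part of the change itself): it assumes the patch has been applied and that the 20 newsgroups data may or may not already be cached under the default scikit-learn data home.

```python
# Sketch: with download_if_missing=False, fetch_20newsgroups_vectorized is
# expected to raise an IOError when the data is not cached locally, instead
# of downloading it from the source site (per the new docstring entry).
from sklearn.datasets import fetch_20newsgroups_vectorized

try:
    # Succeeds only if the dataset is already under ~/scikit_learn_data
    bunch = fetch_20newsgroups_vectorized(subset='train',
                                          download_if_missing=False)
    print(bunch.data.shape)
except IOError:
    print("20 newsgroups data not cached locally; no download attempted.")
```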