From aab94a14a385223df700ff41d5dc2b8e1a2c80cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Thu, 20 Jul 2017 16:32:05 +0200 Subject: [PATCH] Add download_if_missing argument to fetch_20newsgroups_vectorized --- sklearn/datasets/twenty_newsgroups.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sklearn/datasets/twenty_newsgroups.py b/sklearn/datasets/twenty_newsgroups.py index fe838b1be5fd0..ec6b698dad645 100644 --- a/sklearn/datasets/twenty_newsgroups.py +++ b/sklearn/datasets/twenty_newsgroups.py @@ -283,7 +283,8 @@ def fetch_20newsgroups(data_home=None, subset='train', categories=None, return data -def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None): +def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None, + download_if_missing=True): """Load the 20 newsgroups dataset and transform it into tf-idf vectors. This is a convenience function; the tf-idf transformation is done using the @@ -313,6 +314,10 @@ def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None): Specify an download and cache folder for the datasets. If None, all scikit-learn data is stored in '~/scikit_learn_data' subfolders. + download_if_missing : optional, True by default + If False, raise an IOError if the data is not locally available + instead of trying to download the data from the source site. + Returns ------- bunch : Bunch object @@ -332,14 +337,16 @@ def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None): categories=None, shuffle=True, random_state=12, - remove=remove) + remove=remove, + download_if_missing=download_if_missing) data_test = fetch_20newsgroups(data_home=data_home, subset='test', categories=None, shuffle=True, random_state=12, - remove=remove) + remove=remove, + download_if_missing=download_if_missing) if os.path.exists(target_file): X_train, X_test = joblib.load(target_file)