[MRG] Add download_if_missing argument to fetch_20newsgroups_vectorized #9425

Merged
13 changes: 10 additions & 3 deletions sklearn/datasets/twenty_newsgroups.py
@@ -283,7 +283,8 @@ def fetch_20newsgroups(data_home=None, subset='train', categories=None,
     return data


-def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None):
+def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None,
+                                  download_if_missing=True):
     """Load the 20 newsgroups dataset and transform it into tf-idf vectors.

     This is a convenience function; the tf-idf transformation is done using the
@@ -313,6 +314,10 @@ def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None):
         Specify an download and cache folder for the datasets. If None,
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

+    download_if_missing : optional, True by default
+        If False, raise an IOError if the data is not locally available
+        instead of trying to download the data from the source site.
+
     Returns
     -------
     bunch : Bunch object
@@ -332,14 +337,16 @@ def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None):
                                     categories=None,
                                     shuffle=True,
                                     random_state=12,
-                                    remove=remove)
+                                    remove=remove,
+                                    download_if_missing=download_if_missing)

     data_test = fetch_20newsgroups(data_home=data_home,
                                    subset='test',
                                    categories=None,
                                    shuffle=True,
                                    random_state=12,
-                                   remove=remove)
+                                   remove=remove,
+                                   download_if_missing=download_if_missing)

     if os.path.exists(target_file):
         X_train, X_test = joblib.load(target_file)
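
The change above forwards `download_if_missing` from the vectorized wrapper to the underlying loader, so the caller can ask for an `IOError` instead of a network fetch. The sketch below illustrates that forwarding pattern only. It does not use scikit-learn: `load_raw`, `load_vectorized`, `try_offline`, and the `raw.cache` file name are hypothetical stand-ins, and the "download" and "tf-idf" steps are placeholders.

```python
import os
import tempfile


def load_raw(data_home, download_if_missing=True):
    """Hypothetical stand-in for the inner loader (not scikit-learn code)."""
    cache_path = os.path.join(data_home, "raw.cache")
    if not os.path.exists(cache_path):
        if not download_if_missing:
            # Mirrors the documented behaviour: fail instead of downloading.
            raise IOError("dataset not found in %s" % data_home)
        # Placeholder for a real download: write a tiny cache file.
        with open(cache_path, "w") as f:
            f.write("doc one\ndoc two\n")
    with open(cache_path) as f:
        return f.read().splitlines()


def load_vectorized(data_home, download_if_missing=True):
    """Hypothetical wrapper that forwards the flag instead of dropping it."""
    docs = load_raw(data_home, download_if_missing=download_if_missing)
    # Placeholder for the tf-idf step: count whitespace-separated tokens.
    return [len(doc.split()) for doc in docs]


def try_offline(data_home):
    """Return True if the data can be loaded without any download."""
    try:
        load_vectorized(data_home, download_if_missing=False)
    except IOError:
        return False
    return True


if __name__ == "__main__":
    home = tempfile.mkdtemp()
    print(try_offline(home))      # cache is empty, so the offline load fails
    print(load_vectorized(home))  # default flag allows the placeholder download
    print(try_offline(home))      # cache now exists, so the offline load works
```

Without the forwarding argument, a wrapper like `load_vectorized` would always call the inner loader with its default and could trigger a download even when the caller wanted an offline-only run.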