
Commit 3b15ca7

yk5 authored and Taylor Robie committed

Adds boosted_trees to the official models (tensorflow#4074)

* Add boosted_trees to the official models.
* Comments addressed from review, and a test added; using absl.flags instead of argparse.
* Used help_wrap. Also added instructions for inference.

1 parent 499071e · commit 3b15ca7

File tree

8 files changed: +568 −1 lines changed

official/README.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -12,6 +12,8 @@
 Below is a list of the models available.
 
+[boosted_trees](boosted_trees): A Gradient Boosted Trees model to classify Higgs boson processes from the HIGGS Data Set.
+
 [mnist](mnist): A basic model to classify digits from the MNIST dataset.
 
 [resnet](resnet): A deep residual network that can be used to classify both CIFAR-10 and ImageNet's dataset of 1000 classes.
```

official/boosted_trees/README.md

Lines changed: 91 additions & 0 deletions
# Classifying Higgs boson processes in the HIGGS Data Set

## Overview

The [HIGGS Data Set](https://archive.ics.uci.edu/ml/datasets/HIGGS) contains 11 million samples with 28 features each, and poses a binary classification problem: distinguishing a signal process that produces Higgs bosons from a background process that does not.

We use the Gradient Boosted Trees algorithm to distinguish the two classes.
---

The code sample uses the high-level `tf.estimator.Estimator` and `tf.data.Dataset` APIs. These APIs are great for fast iteration and for quickly adapting models to your own datasets without major code overhauls. They let you move from single-worker training to distributed training, and make it easy to export model binaries for prediction. Here, for further simplicity and faster execution, we use the utility function `tf.contrib.estimator.boosted_trees_classifier_train_in_memory`. This utility is especially effective when the input is provided as an in-memory data set such as numpy arrays.

An input function for an `Estimator` typically uses the `tf.data.Dataset` API, which handles data manipulation such as streaming, batching, transformation, and shuffling. However, `boosted_trees_classifier_train_in_memory()` requires that the entire data set be provided as a single batch (i.e. without using the `batch()` API). Thus, in this example, `Dataset.from_tensors()` is used to convert numpy arrays into structured tensors, and `Dataset.zip()` is used to pair features and labels together. A sketch of this pipeline follows.

For further reference on `Dataset`, [read more here](https://www.tensorflow.org/programmers_guide/datasets).
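
Below is a minimal sketch of this pipeline, not the sample's exact code. It assumes `train_features` (a float32 numpy array of shape `[N, 28]`) and `train_labels` (integer class ids of shape `[N]`) are already loaded, uses illustrative quantile bucket boundaries (boosted trees in TF 1.x operate on bucketized columns), and uses keyword arguments as in the TF 1.8 contrib API; check your version for the exact signature.

```python
import numpy as np
import tensorflow as tf


def make_input_fn(features, labels):
  """Returns an input_fn that yields the whole data set as a single batch."""
  def input_fn():
    # from_tensors() wraps the arrays as one element (one big batch),
    # so no batch() call is needed.
    features_ds = tf.data.Dataset.from_tensors(
        {'feature_%02d' % (i + 1): features[:, i]
         for i in range(features.shape[1])})
    labels_ds = tf.data.Dataset.from_tensors(labels)
    # zip() pairs the feature dict with the labels.
    return tf.data.Dataset.zip((features_ds, labels_ds))
  return input_fn


def make_feature_columns(features, num_buckets=100):
  """Bucketizes each numeric feature on (illustrative) quantile boundaries."""
  columns = []
  for i in range(features.shape[1]):
    boundaries = np.unique(np.percentile(
        features[:, i], np.linspace(0, 100, num_buckets))).tolist()
    columns.append(tf.feature_column.bucketized_column(
        tf.feature_column.numeric_column('feature_%02d' % (i + 1)),
        boundaries))
  return columns


classifier = tf.contrib.estimator.boosted_trees_classifier_train_in_memory(
    train_input_fn=make_input_fn(train_features, train_labels),
    feature_columns=make_feature_columns(train_features),
    model_dir='/tmp/higgs_model',
    n_trees=100,
    max_depth=6,
    learning_rate=0.1)
```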

## Running the code

First make sure you've [added the models folder to your Python path](/official/#running-the-models); otherwise you may encounter an error like `ImportError: No module named official.boosted_trees`.

### Setup

The [HIGGS Data Set](https://archive.ics.uci.edu/ml/datasets/HIGGS) that this sample uses for training is hosted by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/). We have provided a script that downloads and cleans the necessary files.

```
python data_download.py
```

This will download a file and store the processed version under the directory designated by `--data_dir` (which defaults to `/tmp/higgs_data/`). To change the target directory, set the `--data_dir` flag (see the example below). The directory can also be a network storage path that TensorFlow supports (such as Google Cloud Storage, `gs://<bucket>/<path>/`).

The file downloaded to the local temporary folder is about 2.8 GB, and the processed file is about 0.8 GB, so make sure there is enough storage to handle them.
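
For example, to place the processed data in a (hypothetical) Cloud Storage bucket instead:

```
python data_download.py --data_dir gs://my-bucket/higgs_data
```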

### Training

This example uses about 3 GB of RAM during training.
You can run the code locally as follows:

```
python train_higgs.py
```

The model is saved to `/tmp/higgs_model` by default, which can be changed with the `--model_dir` flag.
Note that the model directory is cleaned up each time training starts.

Model parameters can be adjusted by flags such as `--n_trees`, `--max_depth`, and `--learning_rate`; check out the code for details.

The final accuracy will be around 74% and the loss around 0.516 over the eval set, when trained with the default parameters.

By default, the first 1 million of the 11 million examples are used for training, and the last 1 million examples are used for evaluation.
The training/evaluation data can be selected as index ranges with the flags `--train_start`, `--train_count`, `--eval_start`, and `--eval_count`, as in the example below.
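
For instance (the tree-parameter values here are illustrative; the index ranges mirror the documented defaults):

```
python train_higgs.py --n_trees 100 --max_depth 6 --learning_rate 0.1 \
    --train_start 0 --train_count 1000000 \
    --eval_start 10000000 --eval_count 1000000
```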

### TensorBoard

Run TensorBoard to inspect the details of the graph and training progression.

```
tensorboard --logdir=/tmp/higgs_model  # Set --logdir to the --model_dir used during training.
```

## Inference with SavedModel

You can export the model into the TensorFlow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model) format by using the `--export_dir` flag:

```
python train_higgs.py --export_dir /tmp/higgs_boosted_trees_saved_model
```

After the model finishes training, use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.

Try the following commands to inspect the SavedModel:

**Replace `${TIMESTAMP}` with the folder produced (e.g. 1524249124)**

```
# List possible tag_sets. Only one metagraph is saved, so there will be one option.
saved_model_cli show --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/

# Show SignatureDefs for tag_set=serve. SignatureDefs define the outputs to show.
saved_model_cli show --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/ \
    --tag_set serve --all
```

### Inference

Let's use the model to predict the classes of two examples:

```
saved_model_cli run --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/ \
    --tag_set serve --signature_def="predict" \
    --input_examples='examples=[{"feature_01":[0.8692932],"feature_02":[-0.6350818],"feature_03":[0.2256903],"feature_04":[0.3274701],"feature_05":[-0.6899932],"feature_06":[0.7542022],"feature_07":[-0.2485731],"feature_08":[-1.0920639],"feature_09":[0.0],"feature_10":[1.3749921],"feature_11":[-0.6536742],"feature_12":[0.9303491],"feature_13":[1.1074361],"feature_14":[1.1389043],"feature_15":[-1.5781983],"feature_16":[-1.0469854],"feature_17":[0.0],"feature_18":[0.6579295],"feature_19":[-0.0104546],"feature_20":[-0.0457672],"feature_21":[3.1019614],"feature_22":[1.3537600],"feature_23":[0.9795631],"feature_24":[0.9780762],"feature_25":[0.9200048],"feature_26":[0.7216575],"feature_27":[0.9887509],"feature_28":[0.8766783]}, {"feature_01":[1.5958393],"feature_02":[-0.6078107],"feature_03":[0.0070749],"feature_04":[1.8184496],"feature_05":[-0.1119060],"feature_06":[0.8475499],"feature_07":[-0.5664370],"feature_08":[1.5812393],"feature_09":[2.1730762],"feature_10":[0.7554210],"feature_11":[0.6431096],"feature_12":[1.4263668],"feature_13":[0.0],"feature_14":[0.9216608],"feature_15":[-1.1904324],"feature_16":[-1.6155890],"feature_17":[0.0],"feature_18":[0.6511141],"feature_19":[-0.6542270],"feature_20":[-1.2743449],"feature_21":[3.1019614],"feature_22":[0.8237606],"feature_23":[0.9381914],"feature_24":[0.9717582],"feature_25":[0.7891763],"feature_26":[0.4305533],"feature_27":[0.9613569],"feature_28":[0.9578179]}]'
```

This will print out the predicted classes and class probabilities.
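
If you prefer running inference from Python, here is a sketch using `tf.contrib.predictor` (TF 1.x). The `predict` signature and the `examples` input key match the `saved_model_cli` call above; the timestamp directory and feature values are illustrative.

```python
import tensorflow as tf
from tensorflow.contrib import predictor

# Illustrative values for the 28 features of a single example.
values = [0.0] * 28

# Serialize a tf.train.Example the way the exported parsing receiver expects.
example = tf.train.Example(features=tf.train.Features(feature={
    'feature_%02d' % (i + 1): tf.train.Feature(
        float_list=tf.train.FloatList(value=[v]))
    for i, v in enumerate(values)}))

predict_fn = predictor.from_saved_model(
    '/tmp/higgs_boosted_trees_saved_model/1524249124',  # illustrative timestamp
    signature_def_key='predict')
predictions = predict_fn({'examples': [example.SerializeToString()]})
print(predictions)  # includes the predicted classes and class probabilities
```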

## Additional Links

If you are interested in distributed training, take a look at [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).

You can also [train models on Cloud ML Engine](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction), which provides [hyperparameter tuning](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#hyperparameter_tuning) to maximize your model's results, and enables [deploying your model for prediction](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#deploy_a_model_to_support_prediction).

official/boosted_trees/__init__.py

Whitespace-only changes.

official/boosted_trees/data_download.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
1+
"""Downloads the UCI HIGGS Dataset and prepares train data.
2+
3+
The details on the dataset are in https://archive.ics.uci.edu/ml/datasets/HIGGS
4+
5+
It takes a while as it needs to download 2.8 GB over the network, process, then
6+
store it into the specified location as a compressed numpy file.
7+
8+
Usage:
9+
$ python data_download.py --data_dir=/tmp/higgs_data
10+
"""
11+
from __future__ import absolute_import
12+
from __future__ import division
13+
from __future__ import print_function
14+
15+
import argparse
16+
import os
17+
import sys
18+
import tempfile
19+
20+
import numpy as np
21+
import pandas as pd
22+
from six.moves import urllib
23+
import tensorflow as tf
24+
25+
URL_ROOT = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00280'
26+
INPUT_FILE = 'HIGGS.csv.gz'
27+
NPZ_FILE = 'HIGGS.csv.gz.npz' # numpy compressed file to contain 'data' array.
28+
29+
30+
def parse_args():
31+
"""Parses arguments and returns a tuple (known_args, unparsed_args)."""
32+
parser = argparse.ArgumentParser()
33+
parser.add_argument(
34+
'--data_dir', type=str, default='/tmp/higgs_data',
35+
help='Directory to download higgs dataset and store training/eval data.')
36+
return parser.parse_known_args()
37+
38+
39+
def _download_higgs_data_and_save_npz(data_dir):
40+
"""Download higgs data and store as a numpy compressed file."""
41+
input_url = os.path.join(URL_ROOT, INPUT_FILE)
42+
np_filename = os.path.join(data_dir, NPZ_FILE)
43+
if tf.gfile.Exists(np_filename):
44+
raise ValueError('data_dir already has the processed data file: {}'.format(
45+
np_filename))
46+
if not tf.gfile.Exists(data_dir):
47+
tf.gfile.MkDir(data_dir)
48+
# 2.8 GB to download.
49+
try:
50+
print('Data downloading..')
51+
temp_filename, _ = urllib.request.urlretrieve(input_url)
52+
53+
# Reading and parsing 11 million csv lines takes 2~3 minutes.
54+
print('Data processing.. taking multiple minutes..')
55+
data = pd.read_csv(
56+
temp_filename,
57+
dtype=np.float32,
58+
names=['c%02d' % i for i in range(29)] # label + 28 features.
59+
).as_matrix()
60+
finally:
61+
os.remove(temp_filename)
62+
63+
# Writing to temporary location then copy to the data_dir (0.8 GB).
64+
f = tempfile.NamedTemporaryFile()
65+
np.savez_compressed(f, data=data)
66+
tf.gfile.Copy(f.name, np_filename)
67+
print('Data saved to: {}'.format(np_filename))
68+
69+
70+
def main(unused_argv):
71+
if not tf.gfile.Exists(FLAGS.data_dir):
72+
tf.gfile.MkDir(FLAGS.data_dir)
73+
_download_higgs_data_and_save_npz(FLAGS.data_dir)
74+
75+
76+
if __name__ == '__main__':
77+
FLAGS, unparsed = parse_args()
78+
tf.app.run(argv=[sys.argv[0]] + unparsed)
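
As a quick sanity check of the script's output, the compressed file can be loaded back with numpy; per the comments above, the `data` array holds the label in column 0 followed by the 28 features (a sketch, assuming the default `--data_dir`):

```python
import numpy as np

with np.load('/tmp/higgs_data/HIGGS.csv.gz.npz') as npz:
  data = npz['data']
labels, features = data[:, 0], data[:, 1:]
print(labels.shape, features.shape)  # expect (11000000,) and (11000000, 28)
```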
