Skip to content

Commit 990bc2c

Browse files
tensorflower-gardenermartinwicke
authored andcommitted
Edits to tf.learn quickstart.
Change: 126101067
1 parent 1fce6b7 commit 990bc2c

File tree

1 file changed

+137
-53
lines changed
  • tensorflow/g3doc/tutorials/tflearn

1 file changed

+137
-53
lines changed
Lines changed: 137 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,62 @@
1-
## TF.Learn Quickstart
1+
## tf.contrib.learn Quickstart
22

3-
TensorFlow’s Learn API (TF.Learn) makes it easy to configure, train, and evaluate a
4-
variety of machine learning models. In this quickstart tutorial, you’ll use TF.Learn
5-
to construct a [Deep Neural Network](https://en.wikipedia.org/wiki/Artificial_neural_network)
6-
classifier model and train it on [Fisher’s Iris data set](https://en.wikipedia.org/wiki/Iris_flower_data_set)
7-
to predict flower species based on sepal/petal geometry. You’ll perform the following four steps:
3+
TensorFlow’s high-level machine learning API (tf.contrib.learn) makes it easy
4+
to configure, train, and evaluate a variety of machine learning models. In
5+
this quickstart tutorial, you’ll use tf.contrib.learn to construct a [neural
6+
network](https://en.wikipedia.org/wiki/Artificial_neural_network) classifier
7+
and train it on [Fisher’s Iris data
8+
set](https://en.wikipedia.org/wiki/Iris_flower_data_set) to predict flower
9+
species based on sepal/petal geometry. You’ll perform the following five
10+
steps:
811

912
1. Load CSVs containing Iris training/test data into a TensorFlow `Dataset`
10-
2. Construct a [Deep Neural Network classifier](https://www.tensorflow.org/versions/r0.9/api_docs/python/contrib.learn.html#DNNClassifier)
11-
3. Fit the DNN model using the training data
13+
2. Construct a [neural network classifier](
14+
../../api_docs/python/contrib.learn.html#DNNClassifier)
15+
3. Fit the model using the training data
1216
4. Evaluate the accuracy of the model
17+
5. Classify new samples
1318

1419
## Get Started
15-
Remember to [install TensorFlow on your machine](https://www.tensorflow.org/versions/r0.9/get_started/os_setup.html#download-and-setup)
16-
before getting started with this tutorial. The full code and datasets for this tutorial
17-
can be found [here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/tflearnqs/),
18-
and the following sections walk through them in detail.
20+
21+
Remember to [install TensorFlow on your
22+
machine](../../get_started/os_setup.html#download-and-setup) before getting
23+
started with this tutorial.
24+
25+
Here is the full code for our neural network:
26+
27+
```python
28+
import tensorflow as tf
29+
import numpy as np
30+
31+
# Data sets
32+
IRIS_TRAINING = "iris_training.csv"
33+
IRIS_TEST = "iris_test.csv"
34+
35+
# Load datasets.
36+
training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING, target_dtype=np.int)
37+
test_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TEST, target_dtype=np.int)
38+
39+
x_train, x_test, y_train, y_test = training_set.data, test_set.data, \
40+
training_set.target, test_set.target
41+
42+
# Build 3 layer DNN with 10, 20, 10 units respectively.
43+
classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
44+
45+
# Fit model.
46+
classifier.fit(x=x_train, y=y_train, steps=200)
47+
48+
# Evaluate accuracy.
49+
accuracy_score = classifier.evaluate(x=x_test, y=y_test)["accuracy"]
50+
print('Accuracy: {0:f}'.format(accuracy_score))
51+
52+
# Classify two new flower samples.
53+
new_samples = np.array(
54+
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
55+
y = classifier.predict(new_samples)
56+
print ('Predictions: {}'.format(str(y)))
57+
```
58+
59+
The following sections walk through the code in detail.
1960

2061
## Load the Iris CSV data to TensorFlow
2162

@@ -41,61 +82,72 @@ Sepal Length | Sepal Width | Petal Length | Petal Width | Species
4182
6.2 | 3.4 | 5.4 | 2.3 | 2
4283
5.9 | 3.0 | 5.1 | 1.8 | 2
4384

44-
<!-- TODO: The rest of this section presumes that CSVs will live in same directory as tutorial examples; if not, update links and code -->
45-
For this tutorial, the Iris data has been randomized and split into two separate CSVs:
46-
a training set of 120 samples ([iris_training.csv](https://www.tensorflow.org/code/tensorflow/examples/tutorials/tflearnqs/iris_training.csv))
47-
and a test set of 30 samples ([iris_test.csv](https://www.tensorflow.org/code/tensorflow/examples/tutorials/tflearnqs/iris_test.csv)).
85+
<!-- TODO: The rest of this section presumes that CSVs will live in same
86+
directory as tutorial examples; if not, update links and code --> For this
87+
tutorial, the Iris data has been randomized and split into two separate CSVs:
88+
a training set of 120 samples
89+
([iris_training.csv](http://download.tensorflow.org/data/iris_training.csv)).
90+
and a test set of 30 samples
91+
([iris_test.csv](http://download.tensorflow.org/data/iris_test.csv)).
4892

49-
To get started, first import TensorFlow, TF.Learn, and numpy:
93+
To get started, first import TensorFlow and numpy:
5094

5195
```python
5296
import tensorflow as tf
53-
from tensorflow.contrib import learn
5497
import numpy as np
5598
```
5699

57-
Next, load the training and test sets into `Dataset`s using the [`load_csv()`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/base.py#L36)
58-
method in `learn.datasets.base`. `load_csv()` has two required arguments:
59-
`filename`, which takes the filepath to the CSV file,
60-
and `target_dtype`, which takes the [`numpy` datatype](http://docs.scipy.org/doc/numpy/user/basics.types.html)
61-
of the dataset's target value. Here, the target (the value you're training the model to predict) is
62-
flower species, which is an integer from 0&ndash;2, so the appropriate `numpy` datatype is `np.int`:
100+
Next, load the training and test sets into `Dataset`s using the [`load_csv()`]
101+
(https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/datasets/base.py) method in `learn.datasets.base`. The
102+
`load_csv()` method has two required arguments:
103+
104+
* `filename`, which takes the filepath to the CSV file, and
105+
* `target_dtype`, which takes the [`numpy` datatype](http://docs.scipy.org/doc/numpy/user/basics.types.html) of the dataset's target value.
106+
107+
Here, the target (the value you're training the model to predict) is flower
108+
species, which is an integer from 0&ndash;2, so the appropriate `numpy`
109+
datatype is `np.int`:
63110

64111
```python
65112
# Data sets
66113
IRIS_TRAINING = "iris_training.csv"
67114
IRIS_TEST = "iris_test.csv"
68115

69116
# Load datasets.
70-
training_set = learn.datasets.base.load_csv(filename=IRIS_TRAINING, target_dtype=np.int)
71-
test_set = learn.datasets.base.load_csv(filename=IRIS_TEST, target_dtype=np.int)
117+
training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING, target_dtype=np.int)
118+
test_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TEST, target_dtype=np.int)
72119
```
73120

74-
Next, assign variables to the feature data and target values: `x_train` for training-set feature data,
75-
`x_test` for test-set feature data, `y_train` for training-set target values, and `y_test` for test-set
76-
target values. Datasets in TensorFlow are [named tuples](https://docs.python.org/2/library/collections.html#collections.namedtuple),
77-
and you can access feature data and target values via the `data` and `target` fields, respectively:
121+
Next, assign variables to the feature data and target values: `x_train` for
122+
training-set feature data, `x_test` for test-set feature data, `y_train` for
123+
training-set target values, and `y_test` for test-set target values. `Dataset`s
124+
in tf.contrib.learn are [named tuples](https://docs.python.org/2/library/collections.h
125+
tml#collections.namedtuple), and you can access feature data and target values
126+
via the `data` and `target` fields, respectively:
78127

79128
```python
80129
x_train, x_test, y_train, y_test = training_set.data, test_set.data, \
81130
training_set.target, test_set.target
82131
```
83132

84-
Later on, in "Fit the DNNClassifier to the Iris Training Data," you'll use `x_train` and `y_train` to
85-
train your model, and in "Evaluate Model Accuracy", you'll use `x_test` and `y_test`. But first,
86-
you'll construct your model in the next section.
133+
Later on, in "Fit the DNNClassifier to the Iris Training Data," you'll use
134+
`x_train` and `y_train` to train your model, and in "Evaluate Model
135+
Accuracy", you'll use `x_test` and `y_test`. But first, you'll construct your
136+
model in the next section.
87137

88138
## Construct a Deep Neural Network Classifier
89139

90-
TF.Learn offers a variety of predefined models, called [`Estimator`s](https://www.tensorflow.org/versions/r0.9/api_docs/python/contrib.learn.html#estimators),
91-
which you can use "out of the box" to run training and evaluation operations on your data.
92-
Here, you'll configure a Deep Neural Network Classifier model to fit the iris data. Using TF.Learn,
93-
you can instantiate your [`DNNClassifier`](https://www.tensorflow.org/versions/r0.9/api_docs/python/contrib.learn.html#DNNClassifier)
94-
with just one line of code:
140+
tf.contrib.learn offers a variety of predefined models, called [`Estimator`s
141+
](../../api_docs/python/contrib.learn.html#estimators), which you can use "out
142+
of the box" to run training and evaluation operations on your data. Here,
143+
you'll configure a Deep Neural Network Classifier model to fit the Iris data.
144+
Using tf.contrib.learn, you can instantiate your
145+
[`DNNClassifier`](../../api_docs/python/contrib.learn.html#DNNClassifier) with
146+
just one line of code:
95147

96148
```python
97149
# Build 3 layer DNN with 10, 20, 10 units respectively.
98-
classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
150+
classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
99151
```
100152

101153
The code above creates a `DNNClassifier` model with three [hidden layers](http://stats.stackexchange.com/questions/181/how-to-choose-the-number-of-hidden-layers-and-nodes-in-a-feedforward-neural-netw),
@@ -106,7 +158,7 @@ classes (`n_classes=3`).
106158
## Fit the DNNClassifier to the Iris Training Data
107159

108160
Now that you've configured your DNN `classifier` model, you can fit it to the Iris training data
109-
using the [`fit`](https://www.tensorflow.org/versions/r0.9/api_docs/python/contrib.learn.html#BaseEstimator.fit)
161+
using the [`fit`](../../api_docs/python/contrib.learn.html#BaseEstimator.fit)
110162
method. Pass as arguments your feature data (`x_train`), target values
111163
(`y_train`), and the number of steps to train (here, 200):
112164

@@ -128,17 +180,18 @@ classifier.fit(x=x_train, y=y_train, steps=100)
128180

129181
<!-- TODO: When tutorial exists for monitoring, link to it here -->
130182
However, if you're looking to track the model while it trains, you'll likely
131-
want to instead use a TensorFlow [`monitor`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/monitors.py)
183+
want to instead use a TensorFlow [`monitor`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/monitors.py)
132184
to perform logging operations.
133185

134186
## Evaluate Model Accuracy
135187

136-
You've fit your `DNNClassifier` model on the Iris training data; now, you can check its accuracy on
137-
the Iris test data using the [`evaluate`](https://www.tensorflow.org/versions/r0.9/api_docs/python/contrib.learn.html#BaseEstimator.evaluate)
138-
method. Like `fit`, `evaluate` takes feature data and target values as arguments,
139-
and returns a `dict` with the evaluation results. The following code passes the Iris
140-
test data&mdash;`x_test` and `y_test`&mdash;to `evaluate`, retrieves `accuracy` from the
141-
results, and prints it to output:
188+
You've fit your `DNNClassifier` model on the Iris training data; now, you can
189+
check its accuracy on the Iris test data using the [`evaluate`
190+
](../../api_docs/python/contrib.learn.html#BaseEstimator.evaluate) method.
191+
Like `fit`, `evaluate` takes feature data and target values as
192+
arguments, and returns a `dict` with the evaluation results. The following
193+
code passes the Iris test data&mdash;`x_test` and `y_test`&mdash;to `evaluate`
194+
and prints the `accuracy` from the results:
142195

143196
```python
144197
accuracy_score = classifier.evaluate(x=x_test, y=y_test)["accuracy"]
@@ -153,15 +206,46 @@ Accuracy: 0.933333
153206

154207
Not bad for a relatively small data set!
155208

209+
## Classify New Samples
210+
211+
Use the estimator's `predict()` method to classify new samples. For example,
212+
say you have these two new flower samples:
213+
214+
Sepal Length | Sepal Width | Petal Length | Petal Width
215+
:----------- | :---------- | :----------- | :----------
216+
6.4 | 3.2 | 4.5 | 1.5
217+
5.8 | 3.1 | 5.0 | 1.7
218+
219+
You can predict their species with the following code:
220+
221+
```python
222+
# Classify two new flower samples.
223+
new_samples = np.array(
224+
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
225+
y = classifier.predict(new_samples)
226+
print ('Predictions: {}'.format(str(y)))
227+
```
228+
229+
The `predict()` method returns an array of predictions, one for each sample:
230+
231+
```python
232+
Prediction: [1 2]
233+
```
234+
235+
The model thus predicts that the first sample is *Iris versicolor*, and the
236+
second sample is *Iris virginica*.
237+
156238
## Additional Resources
157239

158-
* For further reference materials on TF.Learn, see the official [API docs](https://www.tensorflow.org/versions/r0.9/api_docs/python/contrib.learn.html).
240+
* For further reference materials on tf.contrib.learn, see the official
241+
[API docs](../../api_docs/python/contrib.learn.md).
159242

160243
<!-- David, will the below be live when this tutorial is released? -->
161-
* To learn more about using TF.Learn to create linear models, see
162-
[Large-scale Linear Models with TensorFlow](https://www.tensorflow.org/versions/r0.9/tutorials/linear/index.html).
244+
* To learn more about using tf.contrib.learn to create linear models, see
245+
[Large-scale Linear Models with TensorFlow](../linear/).
163246

164-
* To experiment with neural network modeling and visualization in the browser, check out [Deep Playground](http://playground.tensorflow.org/).
247+
* To experiment with neural network modeling and visualization in the browser,
248+
check out [Deep Playground](http://playground.tensorflow.org/).
165249

166-
* For more advanced tutorials on neural networks, see [Convolutional Neural Networks](https://www.tensorflow.org/versions/r0.9/tutorials/deep_cnn/index.html)
167-
and [Recurrent Neural Networks](https://www.tensorflow.org/versions/r0.9/tutorials/recurrent/index.html).
250+
* For more advanced tutorials on neural networks, see [Convolutional Neural
251+
Networks](../deep_cnn/) and [Recurrent Neural Networks](../recurrent/).

0 commit comments

Comments
 (0)