Skip to content

Commit aa80df5

Browse files
tensorflower-gardenerSanders Kleinfeld
authored andcommitted
Code and data for tf.contrib.learn monitors tutorial.
Change: 128588838
1 parent 3cb3995 commit aa80df5

File tree

3 files changed

+222
-0
lines changed

3 files changed

+222
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Model training for Iris data set using Validation Monitor."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import numpy as np
22+
import tensorflow as tf
23+
24+
tf.logging.set_verbosity(tf.logging.INFO)
25+
26+
# Data sets
27+
IRIS_TRAINING = "iris_training.csv"
28+
IRIS_TEST = "iris_test.csv"
29+
30+
# Load datasets.
31+
training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING,
32+
target_dtype=np.int)
33+
test_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TEST,
34+
target_dtype=np.int)
35+
36+
validation_metrics = {"accuracy": tf.contrib.metrics.streaming_accuracy,
37+
"precision": tf.contrib.metrics.streaming_precision,
38+
"recall": tf.contrib.metrics.streaming_recall}
39+
validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
40+
test_set.data,
41+
test_set.target,
42+
every_n_steps=50,
43+
metrics=validation_metrics,
44+
early_stopping_metric="loss",
45+
early_stopping_metric_minimize=True,
46+
early_stopping_rounds=200)
47+
48+
# Build 3 layer DNN with 10, 20, 10 units respectively.
49+
classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10],
50+
n_classes=3,
51+
model_dir="/tmp/iris_model",
52+
config=tf.contrib.learn.RunConfig(
53+
save_checkpoints_secs=1))
54+
55+
# Fit model.
56+
classifier.fit(x=training_set.data,
57+
y=training_set.target,
58+
steps=2000,
59+
monitors=[validation_monitor])
60+
61+
# Evaluate accuracy.
62+
accuracy_score = classifier.evaluate(x=test_set.data,
63+
y=test_set.target)["accuracy"]
64+
print("Accuracy: {0:f}".format(accuracy_score))
65+
66+
# Classify two new flower samples.
67+
new_samples = np.array(
68+
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
69+
y = classifier.predict(new_samples)
70+
print("Predictions: {}".format(str(y)))
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
30,4,setosa,versicolor,virginica
2+
5.9,3.0,4.2,1.5,1
3+
6.9,3.1,5.4,2.1,2
4+
5.1,3.3,1.7,0.5,0
5+
6.0,3.4,4.5,1.6,1
6+
5.5,2.5,4.0,1.3,1
7+
6.2,2.9,4.3,1.3,1
8+
5.5,4.2,1.4,0.2,0
9+
6.3,2.8,5.1,1.5,2
10+
5.6,3.0,4.1,1.3,1
11+
6.7,2.5,5.8,1.8,2
12+
7.1,3.0,5.9,2.1,2
13+
4.3,3.0,1.1,0.1,0
14+
5.6,2.8,4.9,2.0,2
15+
5.5,2.3,4.0,1.3,1
16+
6.0,2.2,4.0,1.0,1
17+
5.1,3.5,1.4,0.2,0
18+
5.7,2.6,3.5,1.0,1
19+
4.8,3.4,1.9,0.2,0
20+
5.1,3.4,1.5,0.2,0
21+
5.7,2.5,5.0,2.0,2
22+
5.4,3.4,1.7,0.2,0
23+
5.6,3.0,4.5,1.5,1
24+
6.3,2.9,5.6,1.8,2
25+
6.3,2.5,4.9,1.5,1
26+
5.8,2.7,3.9,1.2,1
27+
6.1,3.0,4.6,1.4,1
28+
5.2,4.1,1.5,0.1,0
29+
6.7,3.1,4.7,1.5,1
30+
6.7,3.3,5.7,2.5,2
31+
6.4,2.9,4.3,1.3,1
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
120,4,setosa,versicolor,virginica
2+
6.4,2.8,5.6,2.2,2
3+
5.0,2.3,3.3,1.0,1
4+
4.9,2.5,4.5,1.7,2
5+
4.9,3.1,1.5,0.1,0
6+
5.7,3.8,1.7,0.3,0
7+
4.4,3.2,1.3,0.2,0
8+
5.4,3.4,1.5,0.4,0
9+
6.9,3.1,5.1,2.3,2
10+
6.7,3.1,4.4,1.4,1
11+
5.1,3.7,1.5,0.4,0
12+
5.2,2.7,3.9,1.4,1
13+
6.9,3.1,4.9,1.5,1
14+
5.8,4.0,1.2,0.2,0
15+
5.4,3.9,1.7,0.4,0
16+
7.7,3.8,6.7,2.2,2
17+
6.3,3.3,4.7,1.6,1
18+
6.8,3.2,5.9,2.3,2
19+
7.6,3.0,6.6,2.1,2
20+
6.4,3.2,5.3,2.3,2
21+
5.7,4.4,1.5,0.4,0
22+
6.7,3.3,5.7,2.1,2
23+
6.4,2.8,5.6,2.1,2
24+
5.4,3.9,1.3,0.4,0
25+
6.1,2.6,5.6,1.4,2
26+
7.2,3.0,5.8,1.6,2
27+
5.2,3.5,1.5,0.2,0
28+
5.8,2.6,4.0,1.2,1
29+
5.9,3.0,5.1,1.8,2
30+
5.4,3.0,4.5,1.5,1
31+
6.7,3.0,5.0,1.7,1
32+
6.3,2.3,4.4,1.3,1
33+
5.1,2.5,3.0,1.1,1
34+
6.4,3.2,4.5,1.5,1
35+
6.8,3.0,5.5,2.1,2
36+
6.2,2.8,4.8,1.8,2
37+
6.9,3.2,5.7,2.3,2
38+
6.5,3.2,5.1,2.0,2
39+
5.8,2.8,5.1,2.4,2
40+
5.1,3.8,1.5,0.3,0
41+
4.8,3.0,1.4,0.3,0
42+
7.9,3.8,6.4,2.0,2
43+
5.8,2.7,5.1,1.9,2
44+
6.7,3.0,5.2,2.3,2
45+
5.1,3.8,1.9,0.4,0
46+
4.7,3.2,1.6,0.2,0
47+
6.0,2.2,5.0,1.5,2
48+
4.8,3.4,1.6,0.2,0
49+
7.7,2.6,6.9,2.3,2
50+
4.6,3.6,1.0,0.2,0
51+
7.2,3.2,6.0,1.8,2
52+
5.0,3.3,1.4,0.2,0
53+
6.6,3.0,4.4,1.4,1
54+
6.1,2.8,4.0,1.3,1
55+
5.0,3.2,1.2,0.2,0
56+
7.0,3.2,4.7,1.4,1
57+
6.0,3.0,4.8,1.8,2
58+
7.4,2.8,6.1,1.9,2
59+
5.8,2.7,5.1,1.9,2
60+
6.2,3.4,5.4,2.3,2
61+
5.0,2.0,3.5,1.0,1
62+
5.6,2.5,3.9,1.1,1
63+
6.7,3.1,5.6,2.4,2
64+
6.3,2.5,5.0,1.9,2
65+
6.4,3.1,5.5,1.8,2
66+
6.2,2.2,4.5,1.5,1
67+
7.3,2.9,6.3,1.8,2
68+
4.4,3.0,1.3,0.2,0
69+
7.2,3.6,6.1,2.5,2
70+
6.5,3.0,5.5,1.8,2
71+
5.0,3.4,1.5,0.2,0
72+
4.7,3.2,1.3,0.2,0
73+
6.6,2.9,4.6,1.3,1
74+
5.5,3.5,1.3,0.2,0
75+
7.7,3.0,6.1,2.3,2
76+
6.1,3.0,4.9,1.8,2
77+
4.9,3.1,1.5,0.1,0
78+
5.5,2.4,3.8,1.1,1
79+
5.7,2.9,4.2,1.3,1
80+
6.0,2.9,4.5,1.5,1
81+
6.4,2.7,5.3,1.9,2
82+
5.4,3.7,1.5,0.2,0
83+
6.1,2.9,4.7,1.4,1
84+
6.5,2.8,4.6,1.5,1
85+
5.6,2.7,4.2,1.3,1
86+
6.3,3.4,5.6,2.4,2
87+
4.9,3.1,1.5,0.1,0
88+
6.8,2.8,4.8,1.4,1
89+
5.7,2.8,4.5,1.3,1
90+
6.0,2.7,5.1,1.6,1
91+
5.0,3.5,1.3,0.3,0
92+
6.5,3.0,5.2,2.0,2
93+
6.1,2.8,4.7,1.2,1
94+
5.1,3.5,1.4,0.3,0
95+
4.6,3.1,1.5,0.2,0
96+
6.5,3.0,5.8,2.2,2
97+
4.6,3.4,1.4,0.3,0
98+
4.6,3.2,1.4,0.2,0
99+
7.7,2.8,6.7,2.0,2
100+
5.9,3.2,4.8,1.8,1
101+
5.1,3.8,1.6,0.2,0
102+
4.9,3.0,1.4,0.2,0
103+
4.9,2.4,3.3,1.0,1
104+
4.5,2.3,1.3,0.3,0
105+
5.8,2.7,4.1,1.0,1
106+
5.0,3.4,1.6,0.4,0
107+
5.2,3.4,1.4,0.2,0
108+
5.3,3.7,1.5,0.2,0
109+
5.0,3.6,1.4,0.2,0
110+
5.6,2.9,3.6,1.3,1
111+
4.8,3.1,1.6,0.2,0
112+
6.3,2.7,4.9,1.8,2
113+
5.7,2.8,4.1,1.3,1
114+
5.0,3.0,1.6,0.2,0
115+
6.3,3.3,6.0,2.5,2
116+
5.0,3.5,1.6,0.6,0
117+
5.5,2.6,4.4,1.2,1
118+
5.7,3.0,4.2,1.2,1
119+
4.4,2.9,1.4,0.2,0
120+
4.8,3.0,1.4,0.1,0
121+
5.5,2.4,3.7,1.0,1

0 commit comments

Comments
 (0)