@@ -5,6 +5,7 @@
 bazel run -c opt \
     <...>/tensorflow/g3doc/how_tos/reading_data:fully_connected_preloaded
 """
+from __future__ import print_function
 import os.path
 import time

@@ -31,104 +32,102 @@


 def run_training():
-  """Train MNIST for a number of epochs."""
-  # Get the sets of images and labels for training, validation, and
-  # test on MNIST.
-  data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
-
-  # Tell TensorFlow that the model will be built into the default Graph.
-  with tf.Graph().as_default():
-    with tf.name_scope('input'):
-      # Input data
-      input_images = tf.constant(data_sets.train.images)
-      input_labels = tf.constant(data_sets.train.labels)
-
-      image, label = tf.train.slice_input_producer(
-          [input_images, input_labels], num_epochs=FLAGS.num_epochs)
-      label = tf.cast(label, tf.int32)
-      images, labels = tf.train.batch(
-          [image, label], batch_size=FLAGS.batch_size)
-
-    # Build a Graph that computes predictions from the inference model.
-    logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
-
-    # Add to the Graph the Ops for loss calculation.
-    loss = mnist.loss(logits, labels)
-
-    # Add to the Graph the Ops that calculate and apply gradients.
-    train_op = mnist.training(loss, FLAGS.learning_rate)
-
-    # Add the Op to compare the logits to the labels during evaluation.
-    eval_correct = mnist.evaluation(logits, labels)
-
-    # Build the summary operation based on the TF collection of Summaries.
-    summary_op = tf.merge_all_summaries()
-
-    # Create a saver for writing training checkpoints.
-    saver = tf.train.Saver()
-
-    # Create the op for initializing variables.
-    init_op = tf.initialize_all_variables()
-
-    # Create a session for running Ops on the Graph.
-    sess = tf.Session()
-
-    # Run the Op to initialize the variables.
-    sess.run(init_op)
-
-    # Instantiate a SummaryWriter to output summaries and the Graph.
-    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
-                                            graph_def=sess.graph_def)
-
-    # Start input enqueue threads.
-    coord = tf.train.Coordinator()
-    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
-
-    # And then after everything is built, start the training loop.
-    try:
-      step = 0
-      while not coord.should_stop():
-        start_time = time.time()
-
-        # Run one step of the model.
-        _, loss_value = sess.run([train_op, loss])
-
-        duration = time.time() - start_time
-
-        # Write the summaries and print an overview fairly often.
-        if step % 100 == 0:
-          # Print status to stdout.
-          print 'Step %d: loss = %.2f (%.3f sec)' % (step,
-                                                     loss_value,
-                                                     duration)
-          # Update the events file.
-          summary_str = sess.run(summary_op)
-          summary_writer.add_summary(summary_str, step)
-          step += 1
-
-        # Save a checkpoint periodically.
-        if (step + 1) % 1000 == 0:
-          print 'Saving'
-          saver.save(sess, FLAGS.train_dir, global_step=step)
-
-        step += 1
-    except tf.errors.OutOfRangeError:
-      print 'Saving'
-      saver.save(sess, FLAGS.train_dir, global_step=step)
-      print 'Done training for %d epochs, %d steps.' % (
-          FLAGS.num_epochs, step)
-    finally:
-      # When done, ask the threads to stop.
-      coord.request_stop()
-
-    # Wait for threads to finish.
-    coord.join(threads)
-    sess.close()
+  """Train MNIST for a number of epochs."""
+  # Get the sets of images and labels for training, validation, and
+  # test on MNIST.
+  data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
+
+  # Tell TensorFlow that the model will be built into the default Graph.
+  with tf.Graph().as_default():
+    with tf.name_scope('input'):
+      # Input data
+      input_images = tf.constant(data_sets.train.images)
+      input_labels = tf.constant(data_sets.train.labels)
+
+      image, label = tf.train.slice_input_producer(
+          [input_images, input_labels], num_epochs=FLAGS.num_epochs)
+      label = tf.cast(label, tf.int32)
+      images, labels = tf.train.batch(
+          [image, label], batch_size=FLAGS.batch_size)
+
+    # Build a Graph that computes predictions from the inference model.
+    logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
+
+    # Add to the Graph the Ops for loss calculation.
+    loss = mnist.loss(logits, labels)
+
+    # Add to the Graph the Ops that calculate and apply gradients.
+    train_op = mnist.training(loss, FLAGS.learning_rate)
+
+    # Add the Op to compare the logits to the labels during evaluation.
+    eval_correct = mnist.evaluation(logits, labels)
+
+    # Build the summary operation based on the TF collection of Summaries.
+    summary_op = tf.merge_all_summaries()
+
+    # Create a saver for writing training checkpoints.
+    saver = tf.train.Saver()
+
+    # Create the op for initializing variables.
+    init_op = tf.initialize_all_variables()
+
+    # Create a session for running Ops on the Graph.
+    sess = tf.Session()
+
+    # Run the Op to initialize the variables.
+    sess.run(init_op)
+
+    # Instantiate a SummaryWriter to output summaries and the Graph.
+    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
+                                            graph_def=sess.graph_def)
+
+    # Start input enqueue threads.
+    coord = tf.train.Coordinator()
+    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+    # And then after everything is built, start the training loop.
+    try:
+      step = 0
+      while not coord.should_stop():
+        start_time = time.time()
+
+        # Run one step of the model.
+        _, loss_value = sess.run([train_op, loss])
+
+        duration = time.time() - start_time
+
+        # Write the summaries and print an overview fairly often.
+        if step % 100 == 0:
+          # Print status to stdout.
+          print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
+                                                     duration))
+          # Update the events file.
+          summary_str = sess.run(summary_op)
+          summary_writer.add_summary(summary_str, step)
+          step += 1
+
+        # Save a checkpoint periodically.
+        if (step + 1) % 1000 == 0:
+          print('Saving')
+          saver.save(sess, FLAGS.train_dir, global_step=step)
+
+        step += 1
+    except tf.errors.OutOfRangeError:
+      print('Saving')
+      saver.save(sess, FLAGS.train_dir, global_step=step)
+      print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
+    finally:
+      # When done, ask the threads to stop.
+      coord.request_stop()
+
+    # Wait for threads to finish.
+    coord.join(threads)
+    sess.close()


 def main(_):
-  run_training()
+  run_training()


 if __name__ == '__main__':
-  tf.app.run()
+  tf.app.run()
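The substantive change above is the migration from Python 2 print statements to the print() function, enabled by the added `from __future__ import print_function` import. A minimal self-contained sketch of that pattern, using the same status line as the training loop (the step/loss/duration values below are hypothetical, only there to make the snippet runnable):

    # With this future import, print is a function under Python 2 as well,
    # so the same formatted status line runs unchanged on Python 2 and 3.
    from __future__ import print_function

    step = 100          # hypothetical step counter
    loss_value = 2.31   # hypothetical loss value
    duration = 0.003    # hypothetical per-step time in seconds
    print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))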