@@ -38,6 +38,7 @@
 _NUM_CLASSES = 10
 _NUM_DATA_FILES = 5

+# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
 _NUM_IMAGES = {
     'train': 50000,
     'validation': 10000,
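The TODO above points at the usual CIFAR-10 practice of validating on a held-out slice of the 50K training images rather than on the 10K test set. As a hypothetical sketch of what the constant would look like once that is done (the 'test' key is an assumption; nothing in this commit implements the split):

# Hypothetical constant once the TODO is addressed (not part of this change):
# carve a 5K validation set out of the 50K training images and keep the 10K
# test images separate.
_NUM_IMAGES = {
    'train': 45000,
    'validation': 5000,
    'test': 10000,
}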
@@ -193,14 +194,14 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
 def cifar10_model_fn(features, labels, mode, params):
   """Model function for CIFAR-10."""
   features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])
-
+  # Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
       batch_size=params['batch_size'], batch_denom=128,
-      num_images=_NUM_IMAGES['train'], boundary_epochs=[100, 150, 200],
+      num_images=_NUM_IMAGES['train'], boundary_epochs=[91, 136, 182],
       decay_rates=[1, 0.1, 0.01, 0.001])

-  # We use a weight decay of 0.0002, which performs better
-  # than the 0.0001 that was originally suggested.
+  # Weight decay of 2e-4 diverges from 1e-4 decay used in the ResNet paper
+  # and seems more stable in testing. The difference was nominal for ResNet-56.
   weight_decay = 2e-4

   # Empirical testing showed that including batch_normalization variables
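The new boundary_epochs of [91, 136, 182] appear to be the paper's 32k/48k/64k-iteration schedule converted to epochs for a 45K-image training split with batch size 128 (e.g. 32000 * 128 / 45000 ≈ 91). Below is a minimal sketch of how boundary_epochs and decay_rates typically become a piecewise-constant schedule; the 0.1 base rate and the batch_size/batch_denom scaling are assumptions about what learning_rate_with_decay does internally, not a copy of it.

# Rough sketch (assumptions noted above), pure Python for clarity.
def piecewise_schedule(batch_size=128, batch_denom=128, num_images=50000,
                       boundary_epochs=(91, 136, 182),
                       decay_rates=(1, 0.1, 0.01, 0.001)):
  initial_lr = 0.1 * batch_size / batch_denom   # assumed base LR scaling
  steps_per_epoch = num_images // batch_size    # 50000 // 128 = 390 steps
  boundaries = [epoch * steps_per_epoch for epoch in boundary_epochs]
  values = [initial_lr * rate for rate in decay_rates]

  def learning_rate(global_step):
    for boundary, value in zip(boundaries, values):
      if global_step < boundary:
        return value
    return values[-1]

  return learning_rate

lr = piecewise_schedule()
print(lr(0), lr(91 * 390), lr(182 * 390))  # 0.1 0.01 0.0001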
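The comment about including batch_normalization variables suggests the weight decay is applied as an explicit L2 term added to the loss rather than through the optimizer. A rough sketch under that assumption (the real model function decides which variables are penalized; here every trainable variable is included):

import tensorflow as tf

def loss_with_weight_decay(cross_entropy, weight_decay=2e-4):
  # Add an explicit L2 penalty over trainable variables to the cross-entropy
  # loss. Sketch only; variable filtering (e.g. batch-norm) is omitted.
  l2 = weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])
  return cross_entropy + l2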
@@ -234,8 +235,8 @@ def define_cifar_flags():
   flags.adopt_module_key_flags(resnet_run_loop)
   flags_core.set_defaults(data_dir='/tmp/cifar10_data',
                           model_dir='/tmp/cifar10_model',
-                          resnet_size='32',
-                          train_epochs=250,
+                          resnet_size='56',
+                          train_epochs=182,
                           epochs_between_evals=10,
                           batch_size=128)

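With these defaults a plain run of the script (for example, python cifar10_main.py --data_dir=/tmp/cifar10_data) trains ResNet-56 for 182 epochs, which lines up with the last entry of boundary_epochs above so training stops right after the final learning-rate decay; any of the defaults shown (e.g. --resnet_size, --train_epochs) can still be overridden on the command line.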