1
|
|
2
|
|
3
|
|
4
|
|
5
|
|
6
|
|
7
|
|
8
|
|
9
|
|
10
|
|
11
|
|
12
|
|
13
|
|
14
|
|
15
|
|
16
|
"""Builds the CIFAR-10 network.
|
17
|
|
18
|
Summary of available functions:
|
19
|
|
20
|
# Compute input images and labels for training. If you would like to run
|
21
|
# evaluations, use inputs() instead.
|
22
|
inputs, labels = distorted_inputs()
|
23
|
|
24
|
# Compute inference on the model inputs to make a prediction.
|
25
|
predictions = inference(inputs)
|
26
|
|
27
|
# Compute the total loss of the prediction with respect to the labels.
|
28
|
loss = loss(predictions, labels)
|
29
|
|
30
|
# Create a graph to run one step of training with respect to the loss.
|
31
|
train_op = train(loss, global_step)
|
32
|
"""
|
33
|
|
34
|
from __future__ import absolute_import
|
35
|
from __future__ import division
|
36
|
from __future__ import print_function
|
37
|
|
38
|
import os
|
39
|
import re
|
40
|
import sys
|
41
|
import tarfile
|
42
|
|
43
|
from six.moves import urllib
|
44
|
import tensorflow as tf
|
45
|
|
46
|
import cifar10_input
|
47
|
|
48
|
FLAGS = tf.app.flags.FLAGS

# Name of the current user; used below to build a per-user default data
# directory under /tmp. Read with a fallback so that importing this module
# does not raise KeyError in environments where USER is not set.
username = str(os.environ.get('USER', 'unknown'))

# Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128,
                            """Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', '/tmp/' + username + '/cifar10_data',
                           """Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_boolean('use_fp16', False,
                            """Train the model using fp16.""")
|
58
|
|
59
|
|
60
|
# Data set constants, re-exported from cifar10_input so that code in this
# module (and its callers) can refer to them without the module prefix.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
NUM_CLASSES = cifar10_input.NUM_CLASSES
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = (
    cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN)
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = (
    cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL)

# Training hyper-parameters, consumed by train().
# Decay used for the moving average of the trainable variables.
MOVING_AVERAGE_DECAY = 0.9999
# Number of epochs between learning-rate decay steps.
NUM_EPOCHS_PER_DECAY = 350.0
# Factor the learning rate is multiplied by at each decay step.
LEARNING_RATE_DECAY_FACTOR = 0.1
# Learning rate at the start of training.
INITIAL_LEARNING_RATE = 0.1

# Name prefix that _activation_summary() strips (as '<TOWER_NAME>_<digits>/')
# from op names before using them as summary tags.
TOWER_NAME = 'tower'

# Location of the CIFAR-10 binary tarball fetched by
# maybe_download_and_extract().
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
|
78
|
|
79
|
|
80
|
def _activation_summary(x):
    """Attach histogram and sparsity summaries to an activation tensor.

    The summary tags are derived from the op name of `x` after removing any
    '<TOWER_NAME>_<digits>/' component, so the same tag is produced no matter
    which tower the op was created under.

    Args:
      x: Tensor whose values are summarized.

    Returns:
      None. The summaries are registered as a side effect.
    """
    tower_pattern = '%s_[0-9]*/' % TOWER_NAME
    tag = re.sub(tower_pattern, '', x.op.name)

    # Distribution of the activation values.
    tf.summary.histogram(tag + '/activations', x)

    # Fraction of the activation values that are exactly zero.
    zero_ratio = tf.nn.zero_fraction(x)
    tf.summary.scalar(tag + '/sparsity', zero_ratio)
|
97
|
|
98
|
|
99
|
def _variable_on_cpu(name, shape, initializer):
    """Create (or fetch) a variable that is pinned to '/cpu:0'.

    The variable dtype is float16 when FLAGS.use_fp16 is set and float32
    otherwise.

    Args:
      name: name of the variable.
      shape: list of ints.
      initializer: initializer for the variable.

    Returns:
      The variable returned by tf.get_variable.
    """
    var_dtype = tf.float32
    if FLAGS.use_fp16:
        var_dtype = tf.float16

    with tf.device('/cpu:0'):
        return tf.get_variable(
            name, shape, initializer=initializer, dtype=var_dtype)
|
114
|
|
115
|
|
116
|
def _variable_with_weight_decay(name, shape, stddev, wd):
    """Create a truncated-normal-initialized variable, optionally decayed.

    The variable is created through _variable_on_cpu(). When `wd` is not
    None, an L2 penalty `l2_loss(var) * wd` named 'weight_loss' is added to
    the 'losses' collection; note that this also happens for `wd == 0.0`.

    Args:
      name: name of the variable.
      shape: list of ints.
      stddev: standard deviation of the truncated Gaussian initializer.
      wd: multiplier for the L2 penalty, or None to add no penalty.

    Returns:
      The created variable.
    """
    init_dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    initializer = tf.truncated_normal_initializer(
        stddev=stddev, dtype=init_dtype)
    var = _variable_on_cpu(name, shape, initializer)

    if wd is None:
        return var

    penalty = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', penalty)
    return var
|
141
|
|
142
|
|
143
|
def distorted_inputs():
    """Build the distorted training input pipeline.

    Delegates to cifar10_input.distorted_inputs(), reading from the
    'cifar-10-batches-bin' directory under FLAGS.data_dir with a batch size
    of FLAGS.batch_size. When FLAGS.use_fp16 is set, both images and labels
    are cast to float16.

    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3]
        size.
      labels: Labels. 1D tensor of [batch_size] size.

    Raises:
      ValueError: If FLAGS.data_dir is empty.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')

    batches_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    images, labels = cifar10_input.distorted_inputs(
        data_dir=batches_dir, batch_size=FLAGS.batch_size)

    if not FLAGS.use_fp16:
        return images, labels

    half_images = tf.cast(images, tf.float16)
    half_labels = tf.cast(labels, tf.float16)
    return half_images, half_labels
|
162
|
|
163
|
|
164
|
def inputs(eval_data):
    """Build the (undistorted) input pipeline used for evaluation.

    Delegates to cifar10_input.inputs(), reading from the
    'cifar-10-batches-bin' directory under FLAGS.data_dir with a batch size
    of FLAGS.batch_size. When FLAGS.use_fp16 is set, both images and labels
    are cast to float16.

    Args:
      eval_data: bool, indicating if one should use the train or eval data
        set.

    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3]
        size.
      labels: Labels. 1D tensor of [batch_size] size.

    Raises:
      ValueError: If FLAGS.data_dir is empty.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')

    batches_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    images, labels = cifar10_input.inputs(
        eval_data=eval_data,
        data_dir=batches_dir,
        batch_size=FLAGS.batch_size)

    if not FLAGS.use_fp16:
        return images, labels

    half_images = tf.cast(images, tf.float16)
    half_labels = tf.cast(labels, tf.float16)
    return half_images, half_labels
|
187
|
|
188
|
|
189
|
def inference(images):
    """Build the CIFAR-10 model graph and return its logits.

    Layer order: conv1 -> pool1 -> norm1 -> conv2 -> norm2 -> pool2 ->
    local3 -> local4 -> softmax_linear. All variables are created through
    _variable_on_cpu() / _variable_with_weight_decay() (i.e. with
    tf.get_variable), so they can be shared across repeated calls under a
    reusing variable scope.

    The flattening step before local3 reshapes to FLAGS.batch_size rows, so
    `images` must carry exactly FLAGS.batch_size examples.

    Args:
      images: Images returned from distorted_inputs() or inputs().

    Returns:
      Logits: the output of the final linear layer. No softmax is applied
      here.
    """
    # conv1: 5x5 convolution, 3 input channels -> 64 filters, then ReLU.
    with tf.variable_scope('conv1') as scope:
        conv1_kernel = _variable_with_weight_decay('weights',
                                                   shape=[5, 5, 3, 64],
                                                   stddev=5e-2,
                                                   wd=0.0)
        conv1_out = tf.nn.conv2d(
            images, conv1_kernel, [1, 1, 1, 1], padding='SAME')
        conv1_biases = _variable_on_cpu(
            'biases', [64], tf.constant_initializer(0.0))
        conv1_linear = tf.nn.bias_add(conv1_out, conv1_biases)
        conv1 = tf.nn.relu(conv1_linear, name=scope.name)
        _activation_summary(conv1)

    # pool1: 3x3 max pooling with stride 2.
    pool1 = tf.nn.max_pool(conv1,
                           ksize=[1, 3, 3, 1],
                           strides=[1, 2, 2, 1],
                           padding='SAME',
                           name='pool1')

    # norm1: local response normalization.
    norm1 = tf.nn.lrn(pool1, 4,
                      bias=1.0,
                      alpha=0.001 / 9.0,
                      beta=0.75,
                      name='norm1')

    # conv2: 5x5 convolution, 64 -> 64 filters, then ReLU.
    with tf.variable_scope('conv2') as scope:
        conv2_kernel = _variable_with_weight_decay('weights',
                                                   shape=[5, 5, 64, 64],
                                                   stddev=5e-2,
                                                   wd=0.0)
        conv2_out = tf.nn.conv2d(
            norm1, conv2_kernel, [1, 1, 1, 1], padding='SAME')
        conv2_biases = _variable_on_cpu(
            'biases', [64], tf.constant_initializer(0.1))
        conv2_linear = tf.nn.bias_add(conv2_out, conv2_biases)
        conv2 = tf.nn.relu(conv2_linear, name=scope.name)
        _activation_summary(conv2)

    # norm2: local response normalization (applied before pooling here).
    norm2 = tf.nn.lrn(conv2, 4,
                      bias=1.0,
                      alpha=0.001 / 9.0,
                      beta=0.75,
                      name='norm2')

    # pool2: 3x3 max pooling with stride 2.
    pool2 = tf.nn.max_pool(norm2,
                           ksize=[1, 3, 3, 1],
                           strides=[1, 2, 2, 1],
                           padding='SAME',
                           name='pool2')

    # local3: fully connected layer with 384 units, ReLU, weight decay.
    with tf.variable_scope('local3') as scope:
        # Flatten each example to a vector so a single matmul can be used.
        flat = tf.reshape(pool2, [FLAGS.batch_size, -1])
        flat_dim = flat.get_shape()[1].value
        local3_weights = _variable_with_weight_decay(
            'weights', shape=[flat_dim, 384], stddev=0.04, wd=0.004)
        local3_biases = _variable_on_cpu(
            'biases', [384], tf.constant_initializer(0.1))
        local3_linear = tf.matmul(flat, local3_weights) + local3_biases
        local3 = tf.nn.relu(local3_linear, name=scope.name)
        _activation_summary(local3)

    # local4: fully connected layer with 192 units, ReLU, weight decay.
    with tf.variable_scope('local4') as scope:
        local4_weights = _variable_with_weight_decay(
            'weights', shape=[384, 192], stddev=0.04, wd=0.004)
        local4_biases = _variable_on_cpu(
            'biases', [192], tf.constant_initializer(0.1))
        local4_linear = tf.matmul(local3, local4_weights) + local4_biases
        local4 = tf.nn.relu(local4_linear, name=scope.name)
        _activation_summary(local4)

    # softmax_linear: linear projection to NUM_CLASSES logits. The softmax
    # itself is left to the loss (see loss(), which uses
    # sparse_softmax_cross_entropy_with_logits).
    with tf.variable_scope('softmax_linear') as scope:
        out_weights = _variable_with_weight_decay(
            'weights', [192, NUM_CLASSES], stddev=1/192.0, wd=0.0)
        out_biases = _variable_on_cpu(
            'biases', [NUM_CLASSES], tf.constant_initializer(0.0))
        projected = tf.matmul(local4, out_weights)
        softmax_linear = tf.add(projected, out_biases, name=scope.name)
        _activation_summary(softmax_linear)

    return softmax_linear
|
273
|
|
274
|
|
275
|
def loss(logits, labels):
    """Compute the total loss: mean cross entropy plus collected penalties.

    The mean cross entropy over the batch is added to the 'losses'
    collection, and the returned tensor is the sum of everything in that
    collection (which also holds the weight-decay terms registered by
    _variable_with_weight_decay()).

    Args:
      logits: Logits from inference().
      labels: Labels from distorted_inputs or inputs(). 1-D tensor
        of shape [batch_size]. Cast to int64 before use.

    Returns:
      Loss tensor of type float, named 'total_loss'.
    """
    sparse_labels = tf.cast(labels, tf.int64)

    # Per-example cross entropy, then its average across the batch.
    per_example_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=sparse_labels,
        logits=logits,
        name='cross_entropy_per_example')
    mean_xent = tf.reduce_mean(per_example_xent, name='cross_entropy')
    tf.add_to_collection('losses', mean_xent)

    # Total loss = cross entropy + every other term in the collection.
    collected = tf.get_collection('losses')
    return tf.add_n(collected, name='total_loss')
|
297
|
|
298
|
|
299
|
def _add_loss_summaries(total_loss):
    """Add raw and smoothed scalar summaries for every loss term.

    An exponential moving average (decay 0.9, named 'avg') is maintained for
    each tensor in the 'losses' collection and for `total_loss`. For each of
    them two scalar summaries are registered: the raw value under
    '<op name> (raw)' and the moving average under '<op name>'.

    Args:
      total_loss: Total loss from loss().

    Returns:
      loss_averages_op: op for generating moving averages of losses.
    """
    averager = tf.train.ExponentialMovingAverage(0.9, name='avg')
    tracked = tf.get_collection('losses') + [total_loss]
    loss_averages_op = averager.apply(tracked)

    for loss_tensor in tracked:
        base_tag = loss_tensor.op.name
        # Raw value gets the ' (raw)' suffix; the smoothed value keeps the
        # original op name as its tag.
        tf.summary.scalar(base_tag + ' (raw)', loss_tensor)
        tf.summary.scalar(base_tag, averager.average(loss_tensor))

    return loss_averages_op
|
324
|
|
325
|
|
326
|
def train(total_loss, global_step):
    """Build the op that runs one training step of the CIFAR-10 model.

    Creates a gradient-descent optimizer with a staircase exponentially
    decaying learning rate, applies its gradients to all trainable
    variables, and maintains moving averages of those variables. Summaries
    are added for the learning rate, the losses, the trainable variables and
    their gradients.

    Args:
      total_loss: Total loss from loss().
      global_step: Integer Variable counting the number of training steps
        processed.

    Returns:
      train_op: op for training. A no-op named 'train' that depends on the
        gradient application and the variable moving-average update.
    """
    # Number of steps after which the learning rate is decayed.
    batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
    decay_steps = int(batches_per_epoch * NUM_EPOCHS_PER_DECAY)

    # Learning rate drops by LEARNING_RATE_DECAY_FACTOR every decay_steps.
    lr = tf.train.exponential_decay(
        INITIAL_LEARNING_RATE,
        global_step,
        decay_steps,
        LEARNING_RATE_DECAY_FACTOR,
        staircase=True)
    tf.summary.scalar('learning_rate', lr)

    # Moving averages of the losses, with their summaries.
    loss_averages_op = _add_loss_summaries(total_loss)

    # Gradients are computed only after the loss averages were updated.
    with tf.control_dependencies([loss_averages_op]):
        optimizer = tf.train.GradientDescentOptimizer(lr)
        grads_and_vars = optimizer.compute_gradients(total_loss)

    apply_gradient_op = optimizer.apply_gradients(
        grads_and_vars, global_step=global_step)

    # Histograms for the trainable variables.
    for variable in tf.trainable_variables():
        tf.summary.histogram(variable.op.name, variable)

    # Histograms for the gradients (variables without a gradient skipped).
    for gradient, variable in grads_and_vars:
        if gradient is None:
            continue
        tf.summary.histogram(variable.op.name + '/gradients', gradient)

    # Moving averages of all trainable variables.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = ema.apply(tf.trainable_variables())

    with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
        train_op = tf.no_op(name='train')

    return train_op
|
380
|
|
381
|
|
382
|
def maybe_download_and_extract():
    """Download and extract the CIFAR-10 tarball if it is not present yet.

    The archive named by DATA_URL is stored in FLAGS.data_dir (created when
    missing). The download is skipped when the archive file already exists,
    and extraction is skipped when the 'cifar-10-batches-bin' directory
    already exists under FLAGS.data_dir.

    No progress hook is passed to urlretrieve; only a start message and a
    completion message with the downloaded size are printed.
    """
    dest_directory = FLAGS.data_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)

    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        print('Download strarted')
        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath)
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

    extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
    if not os.path.exists(extracted_dir_path):
        # NOTE(review): the archive is fetched over plain HTTP and extracted
        # without validating member paths; extractall() trusts the archive
        # contents. Consider verifying a checksum before extracting.
        # The context manager makes sure the archive handle is closed even
        # when extraction fails.
        with tarfile.open(filepath, 'r:gz') as archive:
            archive.extractall(dest_directory)
|