cifar10_multi_gpu_train.py

Сергей Мальковский, 27.09.2017 15:45

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A binary to train CIFAR-10 using multiple GPUs with synchronous updates.

Accuracy:
cifar10_multi_gpu_train.py achieves ~86% accuracy after 100K steps (256
epochs of data) as judged by cifar10_eval.py.

Speed: With batch_size 128.

System        | Step Time (sec/batch)  |     Accuracy
--------------------------------------------------------------------
1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)
2 Tesla K20m  | 0.13-0.20              | ~84% at 30K steps  (2.5 hours)
3 Tesla K20m  | 0.13-0.18              | ~84% at 30K steps
4 Tesla K20m  | ~0.10                  | ~84% at 30K steps

Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.
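
For example, the model can be trained with an invocation like the following
(illustrative; the flag values shown are simply the defaults defined below):

  python cifar10_multi_gpu_train.py --num_gpus=2 --max_steps=500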

http://tensorflow.org/tutorials/deep_cnn/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os.path
import os
import re
import time
import sys
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
import cifar10

FLAGS = tf.app.flags.FLAGS

username = str(os.environ['USER'])

tf.app.flags.DEFINE_string('train_dir', '/tmp/'+username+'/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 500,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_integer('num_gpus', 2,
                            """How many GPUs to use.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")


def tower_loss(scope, images, labels):
  """Calculate the total loss on a single tower running the CIFAR model.

  Args:
    scope: unique prefix string identifying the CIFAR tower, e.g. 'tower_0'
    images: Images. 4D tensor of shape [batch_size, height, width, 3].
    labels: Labels. 1D tensor of shape [batch_size].

  Returns:
     Tensor of shape [] containing the total loss for a batch of data
  """

  # Build inference Graph.
  logits = cifar10.inference(images)

  # Build the portion of the Graph calculating the losses. Note that we will
  # assemble the total_loss using a custom function below.
  _ = cifar10.loss(logits, labels)

  # Assemble all of the losses for the current tower only.
  losses = tf.get_collection('losses', scope)

  # Calculate the total loss for the current tower.
  total_loss = tf.add_n(losses, name='total_loss')

  # Attach a scalar summary to all individual losses and the total loss; do the
  # same for the averaged version of the losses.
  for l in losses + [total_loss]:
    # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
    # session. This helps the clarity of presentation on tensorboard.
    loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', l.op.name)
    tf.summary.scalar(loss_name, l)

  return total_loss


def average_gradients(tower_grads):
  """Calculate the average gradient for each shared variable across all towers.

  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list
      is over individual gradients. The inner list is over the gradient
      calculation for each tower.
  Returns:
     List of pairs of (gradient, variable) where the gradient has been averaged
     across all towers.
  """
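  # Illustrative example (hypothetical names, not part of the original code):
  # with two towers and a single shared variable v,
  #   tower_grads = [[(g0, v)], [(g1, v)]]
  # the function returns [((g0 + g1) / 2, v)], i.e. the element-wise mean of
  # the per-tower gradients paired with the first tower's reference to v.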
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    grads = []
    for g, _ in grad_and_vars:
      # Add 0 dimension to the gradients to represent the tower.
      expanded_g = tf.expand_dims(g, 0)

      # Append on a 'tower' dimension which we will average over below.
      grads.append(expanded_g)

    # Average over the 'tower' dimension.
    grad = tf.concat(axis=0, values=grads)
    grad = tf.reduce_mean(grad, 0)

    # Keep in mind that the Variables are redundant because they are shared
    # across towers. So we will just return the first tower's pointer to
    # the Variable.
    v = grad_and_vars[0][1]
    grad_and_var = (grad, v)
    average_grads.append(grad_and_var)
  return average_grads


def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default(), tf.device('/cpu:0'):
    # Create a variable to count the number of train() calls. This equals the
    # number of batches processed * FLAGS.num_gpus.
    global_step = tf.get_variable(
        'global_step', [],
        initializer=tf.constant_initializer(0), trainable=False)

    # Calculate the learning rate schedule.
    num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                             FLAGS.batch_size)
    decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY)
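    # For reference (constants assumed from the stock cifar10.py companion
    # file: 50,000 training examples per epoch and 350 epochs per decay), the
    # default batch_size of 128 gives decay_steps = int(50000 / 128 * 350),
    # i.e. 136718.
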
    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE,
                                    global_step,
                                    decay_steps,
                                    cifar10.LEARNING_RATE_DECAY_FACTOR,
                                    staircase=True)

    # Create an optimizer that performs gradient descent.
    opt = tf.train.GradientDescentOptimizer(lr)

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()
    batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * FLAGS.num_gpus)
    # Calculate the gradients for each model tower.
    tower_grads = []
    with tf.variable_scope(tf.get_variable_scope()):
      for i in xrange(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
          with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
            # Dequeues one batch for the GPU
            image_batch, label_batch = batch_queue.dequeue()
            # Calculate the loss for one tower of the CIFAR model. This function
            # constructs the entire CIFAR model but shares the variables across
            # all towers.
            loss = tower_loss(scope, image_batch, label_batch)

            # Reuse variables for the next tower.
            tf.get_variable_scope().reuse_variables()

            # Retain the summaries from the final tower.
            summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)

            # Calculate the gradients for the batch of data on this CIFAR tower.
            grads = opt.compute_gradients(loss)

            # Keep track of the gradients across all towers.
            tower_grads.append(grads)

    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grads = average_gradients(tower_grads)

    # Add a summary to track the learning rate.
    summaries.append(tf.summary.scalar('learning_rate', lr))

    # Add histograms for gradients.
    for grad, var in grads:
      if grad is not None:
        summaries.append(tf.summary.histogram(var.op.name + '/gradients', grad))

    # Apply the gradients to adjust the shared variables.
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # Add histograms for trainable variables.
    for var in tf.trainable_variables():
      summaries.append(tf.summary.histogram(var.op.name, var))

    # Track the moving averages of all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Group all updates into a single train op.
    train_op = tf.group(apply_gradient_op, variables_averages_op)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    # Build the summary operation from the last tower summaries.
    summary_op = tf.summary.merge(summaries)

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    # Start running operations on the Graph. allow_soft_placement must be set to
    # True to build towers on GPU, as some of the ops do not have GPU
    # implementations.
    # NOTE (local edit): changed soft placement - was True.
    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
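    # The event files written to FLAGS.train_dir can be inspected with
    # TensorBoard, e.g. (illustrative command, using the default train_dir):
    #   tensorboard --logdir=/tmp/$USER/cifar10_train
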
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = duration / FLAGS.num_gpus

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        #with open('out.txt', 'a') as f:
        #  print (format_str % (datetime.now(), step, loss_value,
        #                     examples_per_sec, sec_per_batch),file=f)
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)


def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()
  #f.close()


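# tf.app.run() parses the command-line flags defined at the top of this file
# and then invokes main().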
if __name__ == '__main__':
  tf.app.run()