cifar10_eval.py
1 |
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
---|---|
2 |
#
|
3 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
# you may not use this file except in compliance with the License.
|
5 |
# You may obtain a copy of the License at
|
6 |
#
|
7 |
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
#
|
9 |
# Unless required by applicable law or agreed to in writing, software
|
10 |
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
# ==============================================================================
|
15 |
|
16 |
"""Evaluation for CIFAR-10.
|
17 |
|
18 |
Accuracy:
|
19 |
cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs
|
20 |
of data) as judged by cifar10_eval.py.
|
21 |
|
22 |
Speed:
|
23 |
On a single Tesla K40, cifar10_train.py processes a single batch of 128 images
|
24 |
in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86%
|
25 |
accuracy after 100K steps in 8 hours of training time.
|
26 |
|
27 |
Usage:
|
28 |
Please see the tutorial and website for how to download the CIFAR-10
|
29 |
data set, compile the program and train the model.
|
30 |
|
31 |
http://tensorflow.org/tutorials/deep_cnn/
|
32 |
"""
|
33 |
from __future__ import absolute_import |
34 |
from __future__ import division |
35 |
from __future__ import print_function |
36 |
|
37 |
from datetime import datetime |
38 |
import math |
39 |
import time |
40 |
import os |
41 |
import numpy as np |
42 |
import tensorflow as tf |
43 |
|
44 |
import cifar10 |
45 |
|
46 |
FLAGS = tf.app.flags.FLAGS

# Name used to build per-user default directories under /tmp. $USER is not
# guaranteed to be present in the environment, and indexing os.environ
# directly would raise KeyError at import time, so fall back to USERNAME and
# finally to a fixed placeholder.
username = os.environ.get('USER') or os.environ.get('USERNAME') or 'unknown'

tf.app.flags.DEFINE_string('eval_dir', '/tmp/'+username+'/cifar10_eval',
                           """Directory where to write event logs.""")
# evaluate() compares this value against 'test'.
tf.app.flags.DEFINE_string('eval_data', 'test',
                           """Either 'test' or 'train_eval'.""")
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/'+username+'/cifar10_train',
                           """Directory where to read model checkpoints.""")
# Passed to time.sleep() between evaluations, i.e. measured in seconds.
tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
                            """How often to run the eval.""")
# Divided by FLAGS.batch_size (defined outside this file, presumably by the
# cifar10 module) and rounded up to get the number of eval batches.
tf.app.flags.DEFINE_integer('num_examples', 10000,
                            """Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once', True,
                            """Whether to run eval only once.""")
|
62 |
|
63 |
|
64 |
def eval_once(saver, summary_writer, top_k_op, summary_op):
  """Run one evaluation pass against the latest checkpoint.

  Restores the checkpoint recorded in FLAGS.checkpoint_dir and runs `top_k_op`
  for ceil(FLAGS.num_examples / FLAGS.batch_size) batches. It runs fewer
  batches if the coordinator requests a stop. It then prints precision @ 1
  and writes it, together with the merged summaries, to `summary_writer`.
  If no checkpoint is found, it prints a message and returns without
  evaluating.

  Args:
    saver: Saver used to restore the model variables.
    summary_writer: Summary writer that receives the evaluation summary.
    top_k_op: Op whose value, summed with np.sum, is the number of correct
      predictions in one batch (evaluate() builds it with tf.nn.in_top_k).
    summary_op: Merged summary op, or None when the graph has no summaries.
  """
  with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
      print('No checkpoint file found')
      return

    # Restores from checkpoint.
    saver.restore(sess, ckpt.model_checkpoint_path)
    # Assuming model_checkpoint_path looks something like:
    #   /my-favorite-path/cifar10_train/model.ckpt-0,
    # extract global_step from it.
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

    # Start the queue runners. `threads` exists before the try so that the
    # finally clause can always join whatever was started.
    coord = tf.train.Coordinator()
    threads = []
    try:
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))

      num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
      true_count = 0  # Counts the number of correct predictions.
      step = 0
      while step < num_iter and not coord.should_stop():
        predictions = sess.run([top_k_op])
        true_count += np.sum(predictions)
        step += 1

      # Divide by the batches that actually ran. The loop exits early when
      # the coordinator requests a stop, and using num_iter * batch_size
      # would then inflate the denominator and understate precision.
      total_sample_count = step * FLAGS.batch_size
      if total_sample_count == 0:
        print('%s: no batches evaluated; skipping precision @ 1'
              % datetime.now())
      else:
        # Compute precision @ 1.
        precision = true_count / total_sample_count
        print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))

        summary = tf.Summary()
        # tf.summary.merge_all() returns None when no summaries exist.
        if summary_op is not None:
          summary.ParseFromString(sess.run(summary_op))
        summary.value.add(tag='Precision @ 1', simple_value=precision)
        summary_writer.add_summary(summary, global_step)
    except Exception as e:  # pylint: disable=broad-except
      # Report the error to the coordinator; coord.join() re-raises it once
      # the queue-runner threads have been shut down.
      coord.request_stop(e)
    finally:
      coord.request_stop()
      coord.join(threads, stop_grace_period_secs=10)
|
117 |
|
118 |
|
119 |
def evaluate():
  """Build the CIFAR-10 evaluation graph and run eval_once() on it.

  Builds the input, inference and in_top_k ops. It then creates a Saver that
  restores the exponential-moving-average shadow variables and calls
  eval_once(). When FLAGS.run_once is false, it re-evaluates every
  FLAGS.eval_interval_secs seconds, indefinitely.

  Raises:
    ValueError: if FLAGS.eval_data is neither 'test' nor 'train_eval'.
  """
  # Reject typos explicitly; any value other than 'test' would otherwise be
  # passed to cifar10.inputs() as eval_data=False without complaint.
  if FLAGS.eval_data not in ('test', 'train_eval'):
    raise ValueError("eval_data must be 'test' or 'train_eval', got %r"
                     % (FLAGS.eval_data,))

  with tf.Graph().as_default() as g:
    # Get images and labels for CIFAR-10. cifar10.inputs() takes a boolean
    # that is True for 'test'; presumably False selects the training data.
    eval_data = FLAGS.eval_data == 'test'
    images, labels = cifar10.inputs(eval_data=eval_data)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate predictions: per example, whether the true label is the
    # top-1 prediction.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
    try:
      while True:
        eval_once(saver, summary_writer, top_k_op, summary_op)
        if FLAGS.run_once:
          break
        time.sleep(FLAGS.eval_interval_secs)
    finally:
      # Flush pending events (including the precision summary) to disk.
      # main() returns right after evaluate(), so without an explicit
      # close the last events could be lost.
      summary_writer.close()
149 |
|
150 |
|
151 |
def main(argv=None):  # pylint: disable=unused-argument
  """Entry point invoked by tf.app.run().

  Calls cifar10.maybe_download_and_extract() (presumably fetches the dataset
  if it is not already present). It then deletes any existing FLAGS.eval_dir,
  recreates it empty, and runs evaluate().
  """
  cifar10.maybe_download_and_extract()

  eval_dir = FLAGS.eval_dir
  # Start from an empty event-log directory on every run.
  if tf.gfile.Exists(eval_dir):
    tf.gfile.DeleteRecursively(eval_dir)
  tf.gfile.MakeDirs(eval_dir)

  evaluate()
|
158 |
|
159 |
if __name__ == '__main__':
  # tf.app.run() parses the command-line flags, then calls main().
  tf.app.run(main=main)