From 628c8d6272ac2439587b229aa98c8ef154747cfe Mon Sep 17 00:00:00 2001 From: Jonathan Hseu Date: Mon, 31 Oct 2016 16:55:26 -0700 Subject: [PATCH] Convert the MNIST example to use best practices (#17) --- docker/Dockerfile | 6 +- docker/Dockerfile.hdfs | 6 +- docker/README.md | 19 +++- docker/mnist.py | 177 ++++++++++++++++++++++++++++++++ docker/mnist_replica.py | 216 ---------------------------------------- marathon/README.md | 4 +- 6 files changed, 199 insertions(+), 229 deletions(-) create mode 100644 docker/mnist.py delete mode 100644 docker/mnist_replica.py diff --git a/docker/Dockerfile b/docker/Dockerfile index 92b6ae7a..5a69d1ff 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM tensorflow/tensorflow:0.11.0rc1 +FROM tensorflow/tensorflow:nightly -COPY mnist_replica.py / -ENTRYPOINT ["python", "/mnist_replica.py"] +COPY mnist.py / +ENTRYPOINT ["python", "/mnist.py"] diff --git a/docker/Dockerfile.hdfs b/docker/Dockerfile.hdfs index c04660db..80b64c45 100644 --- a/docker/Dockerfile.hdfs +++ b/docker/Dockerfile.hdfs @@ -1,4 +1,4 @@ -FROM tensorflow/tensorflow:0.11.0rc1 +FROM tensorflow/tensorflow:nightly # Install java RUN add-apt-repository -y ppa:openjdk-r/ppa && \ @@ -23,6 +23,6 @@ ENV LD_LIBRARY_PATH $LD_LIBRARY_PATH:$JAVA_HOME/jre/lib/amd64/server ENV CLASSPATH /usr/local/hadoop/etc/hadoop:/usr/local/hadoop/share/hadoop/common/lib/httpcore-4.2.5.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-configuration-1.6.jar:/usr/local/hadoop/share/hadoop/common/lib/jackson-xc-1.9.13.jar:/usr/local/hadoop/share/hadoop/common/lib/gson-2.2.4.jar:/usr/local/hadoop/share/hadoop/common/lib/snappy-java-1.0.4.1.jar:/usr/local/hadoop/share/hadoop/common/lib/jaxb-api-2.2.2.jar:/usr/local/hadoop/share/hadoop/common/lib/paranamer-2.3.jar:/usr/local/hadoop/share/hadoop/common/lib/apacheds-kerberos-codec-2.0.0-M15.jar:/usr/local/hadoop/share/hadoop/common/lib/netty-3.6.2.Final.jar:/usr/local/hadoop/share/hadoop/common/lib/hadoop-annotations-2.7.3.jar:/usr/local/hadoop/share/hadoop/common/lib/api-asn1-api-1.0.0-M20.jar:/usr/local/hadoop/share/hadoop/common/lib/xz-1.0.jar:/usr/local/hadoop/share/hadoop/common/lib/java-xmlbuilder-0.4.jar:/usr/local/hadoop/share/hadoop/common/lib/jetty-util-6.1.26.jar:/usr/local/hadoop/share/hadoop/common/lib/slf4j-api-1.7.10.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-cli-1.2.jar:/usr/local/hadoop/share/hadoop/common/lib/servlet-api-2.5.jar:/usr/local/hadoop/share/hadoop/common/lib/jsp-api-2.1.jar:/usr/local/hadoop/share/hadoop/common/lib/protobuf-java-2.5.0.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-io-2.4.jar:/usr/local/hadoop/share/hadoop/common/lib/curator-recipes-2.7.1.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-compress-1.4.1.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-beanutils-1.7.0.jar:/usr/local/hadoop/share/hadoop/common/lib/mockito-all-1.8.5.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-lang-2.6.jar:/usr/local/hadoop/share/hadoop/common/lib/curator-client-2.7.1.jar:/usr/local/hadoop/share/hadoop/common/lib/jersey-json-1.9.jar:/usr/local/hadoop/share/hadoop/common/lib/jackson-jaxrs-1.9.13.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-httpclient-3.1.jar:/usr/local/hadoop/share/hadoop/common/lib/zookeeper-3.4.6.jar:/usr/local/hadoop/share/hadoop/common/lib/curator-framework-2.7.1.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-net-3.1.jar:/usr/local/hadoop/share/hadoop/common/lib/xmlenc-0.52.jar:/usr/local/hadoop/share/hadoop/common/lib/avro-1.7.4.jar:/usr/local/hadoop/sh
are/hadoop/common/lib/jettison-1.1.jar:/usr/local/hadoop/share/hadoop/common/lib/jackson-mapper-asl-1.9.13.jar:/usr/local/hadoop/share/hadoop/common/lib/api-util-1.0.0-M20.jar:/usr/local/hadoop/share/hadoop/common/lib/activation-1.1.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-codec-1.4.jar:/usr/local/hadoop/share/hadoop/common/lib/stax-api-1.0-2.jar:/usr/local/hadoop/share/hadoop/common/lib/apacheds-i18n-2.0.0-M15.jar:/usr/local/hadoop/share/hadoop/common/lib/jersey-server-1.9.jar:/usr/local/hadoop/share/hadoop/common/lib/jackson-core-asl-1.9.13.jar:/usr/local/hadoop/share/hadoop/common/lib/hadoop-auth-2.7.3.jar:/usr/local/hadoop/share/hadoop/common/lib/jetty-6.1.26.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-beanutils-core-1.8.0.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-collections-3.2.2.jar:/usr/local/hadoop/share/hadoop/common/lib/junit-4.11.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-digester-1.8.jar:/usr/local/hadoop/share/hadoop/common/lib/hamcrest-core-1.3.jar:/usr/local/hadoop/share/hadoop/common/lib/jersey-core-1.9.jar:/usr/local/hadoop/share/hadoop/common/lib/slf4j-log4j12-1.7.10.jar:/usr/local/hadoop/share/hadoop/common/lib/jsch-0.1.42.jar:/usr/local/hadoop/share/hadoop/common/lib/jaxb-impl-2.2.3-1.jar:/usr/local/hadoop/share/hadoop/common/lib/guava-11.0.2.jar:/usr/local/hadoop/share/hadoop/common/lib/httpclient-4.2.5.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-logging-1.1.3.jar:/usr/local/hadoop/share/hadoop/common/lib/htrace-core-3.1.0-incubating.jar:/usr/local/hadoop/share/hadoop/common/lib/asm-3.2.jar:/usr/local/hadoop/share/hadoop/common/lib/jsr305-3.0.0.jar:/usr/local/hadoop/share/hadoop/common/lib/commons-math3-3.1.1.jar:/usr/local/hadoop/share/hadoop/common/lib/jets3t-0.9.0.jar:/usr/local/hadoop/share/hadoop/common/lib/log4j-1.2.17.jar:/usr/local/hadoop/share/hadoop/common/hadoop-common-2.7.3.jar:/usr/local/hadoop/share/hadoop/common/hadoop-common-2.7.3-tests.jar:/usr/local/hadoop/share/hadoop/common/hadoop-nfs-2.7.3.jar:/usr/local/hadoop/share/hadoop/hdfs:/usr/local/hadoop/share/hadoop/hdfs/lib/commons-daemon-1.0.13.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/netty-3.6.2.Final.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/xercesImpl-2.9.1.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/jetty-util-6.1.26.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/commons-cli-1.2.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/servlet-api-2.5.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/netty-all-4.0.23.Final.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/protobuf-java-2.5.0.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/commons-io-2.4.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/leveldbjni-all-1.8.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/commons-lang-2.6.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/xmlenc-0.52.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/jackson-mapper-asl-1.9.13.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/commons-codec-1.4.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/jersey-server-1.9.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/jackson-core-asl-1.9.13.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/jetty-6.1.26.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/jersey-core-1.9.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/guava-11.0.2.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/commons-logging-1.1.3.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/htrace-core-3.1.0-incubating.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/asm-3.2.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/jsr305-3.0.0.jar:/usr/local/hadoop/share/hadoop/hdfs/lib/xml-apis-1.3.04.jar:/usr/loc
al/hadoop/share/hadoop/hdfs/lib/log4j-1.2.17.jar:/usr/local/hadoop/share/hadoop/hdfs/hadoop-hdfs-2.7.3.jar:/usr/local/hadoop/share/hadoop/hdfs/hadoop-hdfs-2.7.3-tests.jar:/usr/local/hadoop/share/hadoop/hdfs/hadoop-hdfs-nfs-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jackson-xc-1.9.13.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jaxb-api-2.2.2.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jersey-client-1.9.jar:/usr/local/hadoop/share/hadoop/yarn/lib/netty-3.6.2.Final.jar:/usr/local/hadoop/share/hadoop/yarn/lib/xz-1.0.jar:/usr/local/hadoop/share/hadoop/yarn/lib/aopalliance-1.0.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jetty-util-6.1.26.jar:/usr/local/hadoop/share/hadoop/yarn/lib/commons-cli-1.2.jar:/usr/local/hadoop/share/hadoop/yarn/lib/servlet-api-2.5.jar:/usr/local/hadoop/share/hadoop/yarn/lib/protobuf-java-2.5.0.jar:/usr/local/hadoop/share/hadoop/yarn/lib/commons-io-2.4.jar:/usr/local/hadoop/share/hadoop/yarn/lib/commons-compress-1.4.1.jar:/usr/local/hadoop/share/hadoop/yarn/lib/javax.inject-1.jar:/usr/local/hadoop/share/hadoop/yarn/lib/leveldbjni-all-1.8.jar:/usr/local/hadoop/share/hadoop/yarn/lib/commons-lang-2.6.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jersey-json-1.9.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jackson-jaxrs-1.9.13.jar:/usr/local/hadoop/share/hadoop/yarn/lib/zookeeper-3.4.6.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jersey-guice-1.9.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jettison-1.1.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jackson-mapper-asl-1.9.13.jar:/usr/local/hadoop/share/hadoop/yarn/lib/zookeeper-3.4.6-tests.jar:/usr/local/hadoop/share/hadoop/yarn/lib/activation-1.1.jar:/usr/local/hadoop/share/hadoop/yarn/lib/commons-codec-1.4.jar:/usr/local/hadoop/share/hadoop/yarn/lib/stax-api-1.0-2.jar:/usr/local/hadoop/share/hadoop/yarn/lib/guice-3.0.jar:/usr/local/hadoop/share/hadoop/yarn/lib/guice-servlet-3.0.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jersey-server-1.9.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jackson-core-asl-1.9.13.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jetty-6.1.26.jar:/usr/local/hadoop/share/hadoop/yarn/lib/commons-collections-3.2.2.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jersey-core-1.9.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jaxb-impl-2.2.3-1.jar:/usr/local/hadoop/share/hadoop/yarn/lib/guava-11.0.2.jar:/usr/local/hadoop/share/hadoop/yarn/lib/commons-logging-1.1.3.jar:/usr/local/hadoop/share/hadoop/yarn/lib/asm-3.2.jar:/usr/local/hadoop/share/hadoop/yarn/lib/jsr305-3.0.0.jar:/usr/local/hadoop/share/hadoop/yarn/lib/log4j-1.2.17.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-server-tests-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-api-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-server-nodemanager-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-server-applicationhistoryservice-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-server-common-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-registry-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-server-sharedcachemanager-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-client-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-applications-unmanaged-am-launcher-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-server-resourcemanager-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-applications-distributedshell-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-common-2.7.3.jar:/usr/local/hadoop/share/hadoop/yarn/hadoop-yarn-server-web-proxy-2.7.3.jar:/usr/local/hadoop
/share/hadoop/mapreduce/lib/snappy-java-1.0.4.1.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/paranamer-2.3.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/netty-3.6.2.Final.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/hadoop-annotations-2.7.3.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/xz-1.0.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/aopalliance-1.0.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/protobuf-java-2.5.0.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/commons-io-2.4.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/commons-compress-1.4.1.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/javax.inject-1.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/leveldbjni-all-1.8.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/jersey-guice-1.9.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/avro-1.7.4.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/jackson-mapper-asl-1.9.13.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/guice-3.0.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/guice-servlet-3.0.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/jersey-server-1.9.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/jackson-core-asl-1.9.13.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/junit-4.11.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/hamcrest-core-1.3.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/jersey-core-1.9.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/asm-3.2.jar:/usr/local/hadoop/share/hadoop/mapreduce/lib/log4j-1.2.17.jar:/usr/local/hadoop/share/hadoop/mapreduce/hadoop-mapreduce-client-hs-plugins-2.7.3.jar:/usr/local/hadoop/share/hadoop/mapreduce/hadoop-mapreduce-client-hs-2.7.3.jar:/usr/local/hadoop/share/hadoop/mapreduce/hadoop-mapreduce-client-common-2.7.3.jar:/usr/local/hadoop/share/hadoop/mapreduce/hadoop-mapreduce-examples-2.7.3.jar:/usr/local/hadoop/share/hadoop/mapreduce/hadoop-mapreduce-client-app-2.7.3.jar:/usr/local/hadoop/share/hadoop/mapreduce/hadoop-mapreduce-client-jobclient-2.7.3-tests.jar:/usr/local/hadoop/share/hadoop/mapreduce/hadoop-mapreduce-client-core-2.7.3.jar:/usr/local/hadoop/share/hadoop/mapreduce/hadoop-mapreduce-client-jobclient-2.7.3.jar:/usr/local/hadoop/share/hadoop/mapreduce/hadoop-mapreduce-client-shuffle-2.7.3.jar:/contrib/capacity-scheduler/*.jar -COPY mnist_replica.py / +COPY mnist.py / -ENTRYPOINT ["python", "/mnist_replica.py"] +ENTRYPOINT ["python", "/mnist.py"] diff --git a/docker/README.md b/docker/README.md index 01e36b98..6379a19a 100644 --- a/docker/README.md +++ b/docker/README.md @@ -1,4 +1,4 @@ -# TensorFlow Docker images +# TensorFlow Docker Images This directory contains example Dockerfiles to run TensorFlow on cluster managers. @@ -7,10 +7,10 @@ managers. training program on top of the tensorflow/tensorflow Docker image. - [Dockerfile.hdfs](Dockerfile.hdfs) installs Hadoop libraries and sets the appropriate environment variables to enable reading from HDFS. -- [mnist_replica.py](mnist_replica.py) demonstrates the programmatic setup - required for distributed TensorFlow training. +- [mnist.py](mnist.py) demonstrates the programmatic setup for distributed + TensorFlow training. -## Best practices +## Best Practices - Always pin the TensorFlow version with the Docker image tag. This ensures that TensorFlow updates don't adversely impact your training program for future @@ -20,7 +20,7 @@ managers. Docker image if they have them cached. Also, versions ensure that you have a single copy of the code running for each job. 
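A note on the HDFS image: [Dockerfile.hdfs](Dockerfile.hdfs) only installs Hadoop and sets environment variables; the training code itself does not change, because TensorFlow's file readers resolve `hdfs://` (and `gs://`) paths through the same API as local paths. Below is a minimal sketch, assuming a hypothetical namenode address and data location, of pointing the same TFRecords input pipeline at HDFS:

```python
# Sketch only: "namenode" and the /user/mnist path are hypothetical placeholders.
import os

import tensorflow as tf

# With the Dockerfile.hdfs environment variables (JAVA_HOME, HADOOP_HDFS_HOME,
# CLASSPATH, LD_LIBRARY_PATH) set, an hdfs:// or gs:// location can be used
# wherever a local path is expected, e.g. as the --data_dir flag of mnist.py.
data_dir = "hdfs://namenode:8020/user/mnist"
filename = os.path.join(data_dir, "train.tfrecords")

# Same reader setup as in mnist.py; the filesystem scheme is handled for us.
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
```

Once the converted data has been uploaded, the `--data_dir` flag of [mnist.py](mnist.py) can point at such a location directly.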
-## Building the Docker files +## Building the Docker Files First, pick an image name for the job. When running on a cluster manager, you will want to push your images to a container registry. Note that both the @@ -37,3 +37,12 @@ docker push :v1 If you make any updates to the code, increment the version and rerun the above commands with the new version. + +## Running the mnist Example + +The [mnist.py](mnist.py) example reads the mnist data in the TFRecords format. +You can run the [convert_to_records.py](https://github.com/tensorflow/tensorflow/blob/r0.11/tensorflow/examples/how_tos/reading_data/convert_to_records.py) +program to convert mnist data to TFRecords. + +When running distributed TensorFlow, you should upload the converted data to +a common location on distributed storage, such as GCS or HDFS. diff --git a/docker/mnist.py b/docker/mnist.py new file mode 100644 index 00000000..a02672c5 --- /dev/null +++ b/docker/mnist.py @@ -0,0 +1,177 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import print_function + +import math +import os + +import tensorflow as tf + +from tensorflow.examples.tutorials.mnist import mnist + +flags = tf.app.flags + +# Flags for configuring the task +flags.DEFINE_string("job_name", None, "job name: worker or ps") +flags.DEFINE_integer("task_index", 0, + "Worker task index, should be >= 0. task_index=0 is " + "the chief worker task the performs the variable " + "initialization") +flags.DEFINE_string("ps_hosts", "", + "Comma-separated list of hostname:port pairs") +flags.DEFINE_string("worker_hosts", "", + "Comma-separated list of hostname:port pairs") + +# Training related flags +flags.DEFINE_string("data_dir", None, + "Directory where the mnist data is stored") +flags.DEFINE_string("train_dir", None, + "Directory for storing the checkpoints") +flags.DEFINE_integer("hidden1", 128, + "Number of units in the 1st hidden layer of the NN") +flags.DEFINE_integer("hidden2", 128, + "Number of units in the 2nd hidden layer of the NN") +flags.DEFINE_integer("batch_size", 100, "Training batch size") +flags.DEFINE_float("learning_rate", 0.01, "Learning rate") + +FLAGS = flags.FLAGS +TRAIN_FILE = "train.tfrecords" + + +def read_and_decode(filename_queue): + reader = tf.TFRecordReader() + _, serialized_example = reader.read(filename_queue) + features = tf.parse_single_example( + serialized_example, + # Defaults are not specified since both keys are required. + features={ + 'image_raw': tf.FixedLenFeature([], tf.string), + 'label': tf.FixedLenFeature([], tf.int64), + }) + + # Convert from a scalar string tensor (whose single string has + # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape + # [mnist.IMAGE_PIXELS]. + image = tf.decode_raw(features['image_raw'], tf.uint8) + image.set_shape([mnist.IMAGE_PIXELS]) + + # OPTIONAL: Could reshape into a 28x28 image and apply distortions + # here. 
Since we are not applying any distortions in this + # example, and the next step expects the image to be flattened + # into a vector, we don't bother. + + # Convert from [0, 255] -> [-0.5, 0.5] floats. + image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 + + # Convert label from a scalar uint8 tensor to an int32 scalar. + label = tf.cast(features['label'], tf.int32) + + return image, label + + +def inputs(batch_size): + """Reads input data. + + Args: + batch_size: Number of examples per returned batch. + + Returns: + A tuple (images, labels), where: + * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS] + in the range [-0.5, 0.5]. + * labels is an int32 tensor with shape [batch_size] with the true label, + a number in the range [0, mnist.NUM_CLASSES). + """ + filename = os.path.join(FLAGS.data_dir, TRAIN_FILE) + + with tf.name_scope('input'): + filename_queue = tf.train.string_input_producer([filename]) + + # Even when reading in multiple threads, share the filename + # queue. + image, label = read_and_decode(filename_queue) + + # Shuffle the examples and collect them into batch_size batches. + # (Internally uses a RandomShuffleQueue.) + # We run this in two threads to avoid being a bottleneck. + images, sparse_labels = tf.train.shuffle_batch( + [image, label], batch_size=batch_size, num_threads=2, + capacity=1000 + 3 * batch_size, + # Ensures a minimum amount of shuffling of examples. + min_after_dequeue=1000) + + return images, sparse_labels + + +def device_and_target(): + # If FLAGS.job_name is not set, we're running single-machine TensorFlow. + # Don't set a device. + if FLAGS.job_name is None: + print("Running single-machine training") + return (None, "") + + # Otherwise we're running distributed TensorFlow. + print("Running distributed training") + if FLAGS.task_index is None or FLAGS.task_index == "": + raise ValueError("Must specify an explicit `task_index`") + if FLAGS.ps_hosts is None or FLAGS.ps_hosts == "": + raise ValueError("Must specify an explicit `ps_hosts`") + if FLAGS.worker_hosts is None or FLAGS.worker_hosts == "": + raise ValueError("Must specify an explicit `worker_hosts`") + + cluster_spec = tf.train.ClusterSpec({ + "ps": FLAGS.ps_hosts.split(","), + "worker": FLAGS.worker_hosts.split(","), + }) + server = tf.train.Server( + cluster_spec, job_name=FLAGS.job_name, task_index=FLAGS.task_index) + if FLAGS.job_name == "ps": + server.join() + + worker_device = "/job:worker/task:{}".format(FLAGS.task_index) + # The device setter will automatically place Variables ops on separate + # parameter servers (ps). The non-Variable ops will be placed on the workers. 
+ return ( + tf.train.replica_device_setter( + worker_device=worker_device, + cluster=cluster_spec), + server.target, + ) + + +def main(unused_argv): + if FLAGS.data_dir is None or FLAGS.data_dir == "": + raise ValueError("Must specify an explicit `data_dir`") + if FLAGS.train_dir is None or FLAGS.train_dir == "": + raise ValueError("Must specify an explicit `train_dir`") + + device, target = device_and_target() + with tf.device(device): + images, labels = inputs(FLAGS.batch_size) + logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2) + loss = mnist.loss(logits, labels) + train_op = mnist.training(loss, FLAGS.learning_rate) + + with tf.train.MonitoredTrainingSession( + master=target, + is_chief=(FLAGS.task_index == 0), + checkpoint_dir=FLAGS.train_dir) as sess: + while not sess.should_stop(): + sess.run(train_op) + + +if __name__ == "__main__": + tf.app.run() diff --git a/docker/mnist_replica.py b/docker/mnist_replica.py deleted file mode 100644 index 33c8730a..00000000 --- a/docker/mnist_replica.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import math -import sys -import tempfile -import time - -import tensorflow as tf -from tensorflow.examples.tutorials.mnist import input_data - - -flags = tf.app.flags - -# Flags for configuring the task -flags.DEFINE_integer("task_index", None, - "Worker task index, should be >= 0. 
task_index=0 is " - "the master worker task the performs the variable " - "initialization ") -flags.DEFINE_string("ps_hosts","localhost:2222", - "Comma-separated list of hostname:port pairs") -flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", - "Comma-separated list of hostname:port pairs") -flags.DEFINE_string("job_name", None,"job name: worker or ps") - -# Training related flags -flags.DEFINE_string("data_dir", "/tmp/mnist-data", - "Directory for storing mnist data") -flags.DEFINE_string("train_dir", "/tmp/mnist-logs", - "Directory for storing checkpoint") -flags.DEFINE_integer("replicas_to_aggregate", None, - "Number of replicas to aggregate before parameter update" - "is applied (For sync_replicas mode only; default: " - "num_workers)") -flags.DEFINE_integer("hidden_units", 100, - "Number of units in the hidden layer of the NN") -flags.DEFINE_integer("train_steps", 200, - "Number of (global) training steps to perform") -flags.DEFINE_integer("batch_size", 100, "Training batch size") -flags.DEFINE_float("learning_rate", 0.01, "Learning rate") -flags.DEFINE_boolean("sync_replicas", False, - "Use the sync_replicas (synchronized replicas) mode, " - "wherein the parameter updates from workers are aggregated " - "before applied to avoid stale gradients") - -FLAGS = flags.FLAGS - - -IMAGE_PIXELS = 28 - - -def main(unused_argv): - if FLAGS.job_name is None or FLAGS.job_name == "": - raise ValueError("Must specify an explicit `job_name`") - if FLAGS.task_index is None or FLAGS.task_index =="": - raise ValueError("Must specify an explicit `task_index`") - - # Construct the cluster and start the server - ps_spec = FLAGS.ps_hosts.split(",") - worker_spec = FLAGS.worker_hosts.split(",") - - # Get the number of workers. - num_workers = len(worker_spec) - - cluster_spec = tf.train.ClusterSpec({ - "ps": ps_spec, - "worker": worker_spec}) - - server = tf.train.Server( - cluster_spec, job_name=FLAGS.job_name, task_index=FLAGS.task_index) - if FLAGS.job_name == "ps": - server.join() - - is_chief = (FLAGS.task_index == 0) - - mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) - - task_index = FLAGS.task_index - worker_device = "/job:worker/task:%d/cpu:0" % task_index - - # The device setter will automatically place Variables ops on separate - # parameter servers (ps). The non-Variable ops will be placed on the workers. 
- with tf.device( - tf.train.replica_device_setter( - worker_device=worker_device, - cluster=cluster_spec)): - global_step = tf.Variable(0, name="global_step", trainable=False) - - # Variables of the hidden layer - hid_w = tf.Variable( - tf.truncated_normal( - [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units], - stddev=1.0 / IMAGE_PIXELS), - name="hid_w") - hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b") - - # Variables of the softmax layer - sm_w = tf.Variable( - tf.truncated_normal( - [FLAGS.hidden_units, 10], - stddev=1.0 / math.sqrt(FLAGS.hidden_units)), - name="sm_w") - sm_b = tf.Variable(tf.zeros([10]), name="sm_b") - - # Ops: located on the worker specified with task_index - x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS]) - y_ = tf.placeholder(tf.float32, [None, 10]) - - hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b) - hid = tf.nn.relu(hid_lin) - - y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b)) - cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) - - opt = tf.train.AdamOptimizer(FLAGS.learning_rate) - - if FLAGS.sync_replicas: - if FLAGS.replicas_to_aggregate is None: - replicas_to_aggregate = num_workers - else: - replicas_to_aggregate = FLAGS.replicas_to_aggregate - - opt = tf.train.SyncReplicasOptimizer( - opt, - replicas_to_aggregate=replicas_to_aggregate, - total_num_replicas=num_workers, - replica_id=task_index, - name="mnist_sync_replicas") - - train_step = opt.minimize(cross_entropy, global_step=global_step) - - if FLAGS.sync_replicas and is_chief: - # Initial token and chief queue runners required by the sync_replicas mode - chief_queue_runner = opt.get_chief_queue_runner() - init_tokens_op = opt.get_init_tokens_op() - - init_op = tf.initialize_all_variables() - if FLAGS.train_dir: - train_dir = FLAGS.train_dir - else: - train_dir = tempfile.mkdtemp() - sv = tf.train.Supervisor( - is_chief=is_chief, - logdir=train_dir, - init_op=init_op, - recovery_wait_secs=1, - global_step=global_step) - - sess_config = tf.ConfigProto( - allow_soft_placement=True, - log_device_placement=False, - device_filters=["/job:ps", "/job:worker/task:%d" % task_index]) - - # The chief worker (task_index==0) session will prepare the session, - # while the remaining workers will wait for the preparation to complete. - if is_chief: - print("Worker %d: Initializing session..." % task_index) - else: - print("Worker %d: Waiting for session to be initialized..." % - task_index) - - sess = sv.prepare_or_wait_for_session(server.target, - config=sess_config) - - print("Worker %d: Session initialization complete." 
% task_index) - - if FLAGS.sync_replicas and is_chief: - # Chief worker will start the chief queue runner and call the init op - sv.start_queue_runners(sess, [chief_queue_runner]) - sess.run(init_tokens_op) - - # Perform training - time_begin = time.time() - print("Training begins @ %f" % time_begin) - - local_step = 0 - while True: - # Training feed - batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size) - train_feed = {x: batch_xs, y_: batch_ys} - - _, step = sess.run([train_step, global_step], feed_dict=train_feed) - local_step += 1 - - now = time.time() - print("%f: Worker %d: training step %d done (global step: %d)" % - (now, task_index, local_step, step)) - - if step >= FLAGS.train_steps: - break - - time_end = time.time() - print("Training ends @ %f" % time_end) - training_time = time_end - time_begin - print("Training elapsed time: %f s" % training_time) - - # Validation feed - val_feed = {x: mnist.validation.images, y_: mnist.validation.labels} - val_xent = sess.run(cross_entropy, feed_dict=val_feed) - print("After %d training step(s), validation cross entropy = %g" % - (FLAGS.train_steps, val_xent)) - -if __name__ == "__main__": - tf.app.run() diff --git a/marathon/README.md b/marathon/README.md index 8d4d4103..da7a5b9d 100644 --- a/marathon/README.md +++ b/marathon/README.md @@ -6,7 +6,7 @@ Before you start, you need to set up a Mesos cluster with Marathon installed and ## Write the Training Program This section covers instructions on how to write your training program and build your docker image. - 1. Write your own training program. This program must accept `worker_hosts`, `ps_hosts`, `job_name`, `task_index` as command line flags which are then parsed to build `ClusterSpec`. After that, the task either joins with the server or starts building graphs. Please refero to the [main page](../README.md) for code snippets and description of between-graph replication. An example can be found in `docker/mnist_replica.py`. + 1. Write your own training program. This program must accept `worker_hosts`, `ps_hosts`, `job_name`, `task_index` as command line flags which are then parsed to build `ClusterSpec`. After that, the task either joins with the server or starts building graphs. Please refer to the [main page](../README.md) for code snippets and description of between-graph replication. An example can be found in `docker/mnist.py`. In the case of large training input is needed by the training program, we recommend copying your data to shared storage first and then point each worker to the data. You may want to add a flag called `data_dir`. Please refer to the [adding flags](#add-commandline-flags) section for adding this flag into the marathon config. @@ -47,7 +47,7 @@ To start the cluster, simply post the Marathon json config file to the Marathon curl -i -H 'Content-Type: application/json' -d @mycluster.json http://marathon.mesos:8080/v2/groups ``` -You may want to make sure your cluster is running the training program correctly. Navigate to the DC/OS web console and look for stdout or stderr of the chief worker. The `mnist_replica.py` example would print losses for each step and final loss when training is done. +You may want to make sure your cluster is running the training program correctly. Navigate to the DC/OS web console and look for stdout or stderr of the chief worker. The `mnist.py` example would print losses for each step and final loss when training is done. ![Screenshot of the chief worker] (../images/chief_worker_stdout.png "Screenshot of the chief worker")
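Both the Docker and Marathon instructions above assume the MNIST data has already been converted to TFRecords and copied to shared storage. The actual conversion is done by the `convert_to_records.py` script linked from docker/README.md; the sketch below is only an illustration, with hypothetical `/tmp` paths, of producing a `train.tfrecords` file with the `image_raw`/`label` features that `read_and_decode()` in `docker/mnist.py` parses:

```python
# Illustrative sketch only; the /tmp paths are hypothetical and the canonical
# version of this logic is TensorFlow's convert_to_records.py example.
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def convert_to_tfrecords(data_set, filename):
  """Writes a DataSet as Examples with the features mnist.py expects."""
  images = data_set.images  # uint8, one flattened row of IMAGE_PIXELS per image
  labels = data_set.labels
  writer = tf.python_io.TFRecordWriter(filename)
  for index in range(data_set.num_examples):
    example = tf.train.Example(features=tf.train.Features(feature={
        # 'label' and 'image_raw' match read_and_decode() in docker/mnist.py.
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[int(labels[index])])),
        'image_raw': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[images[index].tostring()])),
    }))
    writer.write(example.SerializeToString())
  writer.close()


if __name__ == "__main__":
  out_dir = "/tmp/mnist-tfrecords"  # hypothetical output location
  if not os.path.exists(out_dir):
    os.makedirs(out_dir)
  # dtype=tf.uint8 keeps raw pixel bytes, matching the decode_raw/uint8 cast
  # and the [0, 255] -> [-0.5, 0.5] rescaling done in mnist.py.
  data_sets = input_data.read_data_sets("/tmp/mnist-input", dtype=tf.uint8)
  convert_to_tfrecords(data_sets.train, os.path.join(out_dir, "train.tfrecords"))
```

After conversion, the resulting `train.tfrecords` file would be copied to GCS or HDFS and that location passed to every worker as `--data_dir`, as both READMEs recommend.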