diff --git a/README.md b/README.md
index 3ec52dbd..9bc3afaf 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,8 @@ request.
Kubernetes.
- [marathon](marathon) - Templates for running distributed TensorFlow using
Marathon, deployed on top of Mesos.
+- [hadoop](hadoop) - TFRecord file InputFormat/OutputFormat for Hadoop MapReduce
+ and Spark.
## Distributed TensorFlow
diff --git a/hadoop/README.md b/hadoop/README.md
new file mode 100644
index 00000000..2a3a2585
--- /dev/null
+++ b/hadoop/README.md
@@ -0,0 +1,80 @@
+# Hadoop MapReduce InputFormat/OutputFormat for TFRecord files
+
+This directory contains an [Apache Hadoop](http://hadoop.apache.org/) MapReduce
+InputFormat/OutputFormat implementation for TensorFlow's TFRecord feature files.
+It can also be used with [Apache Spark](http://spark.apache.org/).
+
+## Prerequisites
+
+1. [protoc 3.1.0](https://developers.google.com/protocol-buffers/) must be
+installed.
+
+2. [Apache Maven](https://maven.apache.org/) must be installed.
+
+3. This has been tested with Hadoop 2.6.0, so a YARN/Spark cluster running
+Hadoop 2.6.0 or a compatible version is recommended.
+
+## Build and install
+
+1. Compile the TensorFlow Example protos
+
+ ```sh
+    # Suppose $TF_SRC_ROOT is the source root of the TensorFlow project
+ protoc --proto_path=$TF_SRC_ROOT --java_out=src/main/java/ $TF_SRC_ROOT/tensorflow/core/example/{example,feature}.proto
+ ```
+
+2. Compile the code
+
+ ```sh
+ mvn clean package
+ ```
+
+3. Optionally install (or deploy) the jars
+
+ ```sh
+ mvn install
+ ```
+
+    Once installed (or deployed), the package can be used with the following dependency:
+
+    ```xml
+    <dependency>
+        <groupId>org.tensorflow</groupId>
+        <artifactId>tensorflow-hadoop</artifactId>
+        <version>1.0-SNAPSHOT</version>
+    </dependency>
+    ```
+
+## Use with MapReduce
+The Hadoop MapReduce example can be found [here](src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java).
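+
+Both formats exchange `(BytesWritable, NullWritable)` pairs in which each key
+holds one serialized `Example` proto, so they drop into standard job wiring.
+As a minimal sketch (the `TFRecordCopy` class name and the paths are ours, for
+illustration only), the identity `Mapper` is enough to copy a TFRecord file:
+
+```java
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.tensorflow.hadoop.io.TFRecordFileInputFormat;
+import org.tensorflow.hadoop.io.TFRecordFileOutputFormat;
+
+/** Map-only job that copies a TFRecord file unchanged via the identity Mapper. */
+public class TFRecordCopy {
+  public static void main(String[] args) throws Exception {
+    Job job = Job.getInstance(new Configuration(), "tfrecord-copy");
+    job.setJarByClass(TFRecordCopy.class);
+    job.setMapperClass(Mapper.class); // identity: passes (BytesWritable, NullWritable) through
+    job.setNumReduceTasks(0);         // no reduce phase; records go straight to the output format
+
+    job.setInputFormatClass(TFRecordFileInputFormat.class);
+    job.setOutputFormatClass(TFRecordFileOutputFormat.class);
+    job.setOutputKeyClass(BytesWritable.class);   // each key is one serialized Example proto
+    job.setOutputValueClass(NullWritable.class);
+
+    FileInputFormat.addInputPath(job, new Path("path/to/input.tfr"));
+    FileOutputFormat.setOutputPath(job, new Path("path/to/output.tfr"));
+    System.exit(job.waitForCompletion(true) ? 0 : 1);
+  }
+}
+```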
+
+## Use with Spark
+Spark supports reading and writing files with Hadoop InputFormat/OutputFormat. The
+following code snippet demonstrates how to write TFRecord files from Spark.
+
+```scala
+import com.google.protobuf.ByteString
+import org.apache.hadoop.io.{NullWritable, BytesWritable}
+import org.apache.spark.{SparkConf, SparkContext}
+import org.tensorflow.example.{BytesList, Int64List, Feature, Features, Example}
+import org.tensorflow.hadoop.io.TFRecordFileOutputFormat
+
+val inputPath = "path/to/input.txt"
+val outputPath = "path/to/output.tfr"
+
+val sparkConf = new SparkConf().setAppName("TFRecord Demo")
+val sc = new SparkContext(sparkConf)
+
+val features = sc.textFile(inputPath).map(line => {
+ val text = BytesList.newBuilder().addValue(ByteString.copyFrom(line.getBytes)).build()
+ val features = Features.newBuilder()
+ .putFeature("text", Feature.newBuilder().setBytesList(text).build())
+ .build()
+ val example = Example.newBuilder()
+ .setFeatures(features)
+ .build()
+ (new BytesWritable(example.toByteArray), NullWritable.get())
+})
+
+features.saveAsNewAPIHadoopFile[TFRecordFileOutputFormat](outputPath)
+```
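+
+Reading TFRecord files back goes through `TFRecordFileInputFormat` in the same
+way. Below is a minimal sketch using Spark's Java API (the class name and paths
+are illustrative; the tensorflow-hadoop jar must be on the application
+classpath):
+
+```java
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.tensorflow.hadoop.io.TFRecordFileInputFormat;
+
+public class TFRecordReadDemo {
+  public static void main(String[] args) {
+    JavaSparkContext sc = new JavaSparkContext(new SparkConf().setAppName("TFRecord Read Demo"));
+
+    // Each record surfaces as a (BytesWritable, NullWritable) pair whose key
+    // bytes are one serialized Example proto (Example.parseFrom can decode them).
+    JavaPairRDD<BytesWritable, NullWritable> records = sc.newAPIHadoopFile(
+        "path/to/output.tfr", TFRecordFileInputFormat.class,
+        BytesWritable.class, NullWritable.class, new Configuration());
+
+    System.out.println("TFRecord count: " + records.count());
+    sc.stop();
+  }
+}
+```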
diff --git a/hadoop/pom.xml b/hadoop/pom.xml
new file mode 100644
index 00000000..cde2ca9c
--- /dev/null
+++ b/hadoop/pom.xml
@@ -0,0 +1,70 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <groupId>org.tensorflow</groupId>
+    <artifactId>tensorflow-hadoop</artifactId>
+    <packaging>jar</packaging>
+    <version>1.0-SNAPSHOT</version>
+    <name>tensorflow-hadoop</name>
+    <url>https://www.tensorflow.org</url>
+
+    <properties>
+        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+        <maven.compiler.source>1.6</maven.compiler.source>
+        <maven.compiler.target>1.6</maven.compiler.target>
+        <hadoop.version>2.6.0</hadoop.version>
+        <protobuf.version>3.1.0</protobuf.version>
+        <junit.version>4.11</junit.version>
+    </properties>
+    <dependencies>
+        <dependency>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-common</artifactId>
+            <version>${hadoop.version}</version>
+            <exclusions>
+                <exclusion>
+                    <groupId>com.google.protobuf</groupId>
+                    <artifactId>protobuf-java</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-mapreduce-client-core</artifactId>
+            <version>${hadoop.version}</version>
+            <exclusions>
+                <exclusion>
+                    <groupId>com.google.protobuf</groupId>
+                    <artifactId>protobuf-java</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
+        <dependency>
+            <groupId>com.google.protobuf</groupId>
+            <artifactId>protobuf-java</artifactId>
+            <version>${protobuf.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>junit</groupId>
+            <artifactId>junit</artifactId>
+            <version>${junit.version}</version>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-mapreduce-client-jobclient</artifactId>
+            <version>${hadoop.version}</version>
+            <type>test-jar</type>
+            <optional>true</optional>
+            <scope>test</scope>
+            <exclusions>
+                <exclusion>
+                    <groupId>com.google.protobuf</groupId>
+                    <artifactId>protobuf-java</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
+    </dependencies>
+</project>
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java b/hadoop/src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java
new file mode 100644
index 00000000..c1b1779a
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java
@@ -0,0 +1,124 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.example;
+
+import com.google.protobuf.ByteString;
+import org.apache.hadoop.mapreduce.InputFormat;
+import org.apache.hadoop.mapreduce.OutputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.tensorflow.example.*;
+import org.tensorflow.hadoop.io.TFRecordFileInputFormat;
+import org.tensorflow.hadoop.io.TFRecordFileOutputFormat;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.*;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+
+import java.io.IOException;
+import java.util.Map;
+
+public class TFRecordFileMRExample {
+ /**
+ * Convert from text file to TFRecord file. Each line is converted into two dummy features: the
+ * content of each line and the starting offset of each line.
+ */
+  static class ToTFRecordMapper extends Mapper<LongWritable, Text, BytesWritable, NullWritable> {
+ ToTFRecordMapper(){}
+
+ @Override protected void map(LongWritable key, Text value,
+ Context context) throws IOException, InterruptedException {
+ Int64List int64List = Int64List.newBuilder().addValue(key.get()).build();
+ Feature offset = Feature.newBuilder().setInt64List(int64List).build();
+
+ ByteString byteString = ByteString.copyFrom(value.copyBytes());
+ BytesList bytesList = BytesList.newBuilder().addValue(byteString).build();
+ Feature text = Feature.newBuilder().setBytesList(bytesList).build();
+
+ Features features = Features.newBuilder()
+ .putFeature("offset", offset)
+ .putFeature("text", text)
+ .build();
+ Example example = Example.newBuilder().setFeatures(features).build();
+ context.write(new BytesWritable(example.toByteArray()), NullWritable.get());
+ }
+ }
+
+ /**
+ * Convert from previous TFRecord file to text file.
+ */
+  static class FromTFRecordMapper extends Mapper<BytesWritable, NullWritable, NullWritable, Text> {
+ FromTFRecordMapper(){}
+
+ @Override protected void map(BytesWritable key, NullWritable value,
+ Context context) throws IOException, InterruptedException {
+ Example example = Example.parseFrom(key.getBytes());
+      Map<String, Feature> featureMap = example.getFeatures().getFeatureMap();
+ byte[] text = featureMap.get("text").getBytesList().getValue(0).toByteArray();
+ context.write(NullWritable.get(), new Text(text));
+ }
+ }
+
+  public static boolean convert(String jobName,
+                                Class<? extends Mapper> mapperClass,
+                                Class<? extends Writable> outputKeyClass,
+                                Class<? extends Writable> outputValueClass,
+                                Class<? extends InputFormat> inFormatClass,
+                                Class<? extends OutputFormat> outFormatClass,
+                                Path input,
+                                Path output) throws InterruptedException, IOException, ClassNotFoundException {
+ Configuration conf = new Configuration();
+ Job job = Job.getInstance(conf, jobName);
+ job.setJarByClass(mapperClass);
+ job.setMapperClass(mapperClass);
+ job.setNumReduceTasks(0);
+
+ job.setInputFormatClass(inFormatClass);
+ job.setOutputFormatClass(outFormatClass);
+ job.setOutputKeyClass(outputKeyClass);
+ job.setOutputValueClass(outputValueClass);
+
+ final FileSystem fs = FileSystem.get(output.toUri(), conf);
+ fs.delete(output, true);
+ FileInputFormat.addInputPath(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+
+ return job.waitForCompletion(true);
+ }
+
+ public static void main(String[] args) throws Exception {
+ String testRoot = "/tmp/tfrecord-file-test";
+ if (args.length == 1) {
+ testRoot = args[0];
+    } else if (args.length > 1) {
+      System.out.println("Usage: TFRecordFileMRExample [path]");
+      return;
+    }
+
+ Path testRootPath = new Path(testRoot);
+ Path input = new Path(testRootPath, "input.txt");
+ Path tfrout = new Path(testRootPath, "output.tfr");
+ Path txtout = new Path(testRootPath, "output.txt");
+
+ convert("ToTFR", ToTFRecordMapper.class, BytesWritable.class, NullWritable.class,
+ TextInputFormat.class, TFRecordFileOutputFormat.class, input, tfrout);
+ convert("FromTFR", FromTFRecordMapper.class, NullWritable.class, Text.class,
+ TFRecordFileInputFormat.class, TextOutputFormat.class, tfrout, txtout);
+ }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileInputFormat.java b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileInputFormat.java
new file mode 100644
index 00000000..3cf0d9ef
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileInputFormat.java
@@ -0,0 +1,85 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.io;
+
+import org.tensorflow.hadoop.util.TFRecordReader;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.FileSplit;
+
+import java.io.IOException;
+
+public class TFRecordFileInputFormat extends FileInputFormat<BytesWritable, NullWritable> {
+  @Override public RecordReader<BytesWritable, NullWritable> createRecordReader(
+      InputSplit inputSplit, final TaskAttemptContext context) throws IOException, InterruptedException {
+
+    return new RecordReader<BytesWritable, NullWritable>() {
+ private FSDataInputStream fsdis;
+ private TFRecordReader reader;
+ private long length;
+ private long begin;
+ private byte[] current;
+
+ @Override public void initialize(InputSplit split, TaskAttemptContext context)
+ throws IOException, InterruptedException {
+ Configuration conf = context.getConfiguration();
+ FileSplit fileSplit = (FileSplit) split;
+ length = fileSplit.getLength();
+ begin = fileSplit.getStart();
+
+ final Path file = fileSplit.getPath();
+ FileSystem fs = file.getFileSystem(conf);
+ fsdis = fs.open(file, TFRecordIOConf.getBufferSize(conf));
+ reader = new TFRecordReader(fsdis, TFRecordIOConf.getDoCrc32Check(conf));
+ }
+
+ @Override public boolean nextKeyValue() throws IOException, InterruptedException {
+ current = reader.read();
+ return current != null;
+ }
+
+ @Override public BytesWritable getCurrentKey() throws IOException, InterruptedException {
+ return new BytesWritable(current);
+ }
+
+ @Override public NullWritable getCurrentValue() throws IOException, InterruptedException {
+ return NullWritable.get();
+ }
+
+ @Override public float getProgress() throws IOException, InterruptedException {
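+        // Fraction of the split consumed so far; the epsilon keeps the
+        // division well-defined for zero-length splits.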
+ return (fsdis.getPos() - begin) / (length + 1e-6f);
+ }
+
+ @Override public void close() throws IOException {
+ fsdis.close();
+ }
+ };
+ }
+
+ @Override
+ protected boolean isSplitable(JobContext context, Path file) {
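+    // TFRecord files carry no sync markers, so a split cannot start at an
+    // arbitrary offset; each file is read as a single split.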
+ return false;
+ }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormat.java b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormat.java
new file mode 100644
index 00000000..5cebe1b8
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormat.java
@@ -0,0 +1,53 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.io;
+
+import org.tensorflow.hadoop.util.TFRecordWriter;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.RecordWriter;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+
+import java.io.IOException;
+
+public class TFRecordFileOutputFormat extends FileOutputFormat<BytesWritable, NullWritable> {
+  @Override public RecordWriter<BytesWritable, NullWritable> getRecordWriter(
+ TaskAttemptContext context) throws IOException, InterruptedException {
+ Configuration conf = context.getConfiguration();
+ Path file = getDefaultWorkFile(context, "");
+ FileSystem fs = file.getFileSystem(conf);
+
+ int bufferSize = TFRecordIOConf.getBufferSize(conf);
+ final FSDataOutputStream fsdos = fs.create(file, true, bufferSize);
+ final TFRecordWriter writer = new TFRecordWriter(fsdos);
+    return new RecordWriter<BytesWritable, NullWritable>() {
+ @Override public void write(BytesWritable key, NullWritable value)
+ throws IOException, InterruptedException {
+        // copyBytes() returns exactly getLength() bytes, so padding in a
+        // reused BytesWritable buffer is never written to the file.
+        writer.write(key.copyBytes());
+ }
+
+ @Override public void close(TaskAttemptContext context)
+ throws IOException, InterruptedException {
+ fsdos.close();
+ }
+ };
+ }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordIOConf.java b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordIOConf.java
new file mode 100644
index 00000000..dba79368
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordIOConf.java
@@ -0,0 +1,28 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.io;
+
+import org.apache.hadoop.conf.Configuration;
+
+public class TFRecordIOConf {
+ static int getBufferSize(Configuration conf) {
+ return conf.getInt("io.file.buffer.size", 4096);
+ }
+
+ static boolean getDoCrc32Check(Configuration conf) {
+ return conf.getBoolean("tensorflow.read.crc32check", true);
+ }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/util/Crc32C.java b/hadoop/src/main/java/org/tensorflow/hadoop/util/Crc32C.java
new file mode 100644
index 00000000..46c5078d
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/util/Crc32C.java
@@ -0,0 +1,83 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.util;
+
+import org.apache.hadoop.util.PureJavaCrc32C;
+
+import java.util.zip.Checksum;
+
+public class Crc32C implements Checksum {
+ private static final int MASK_DELTA = 0xa282ead8;
+ private PureJavaCrc32C crc32C;
+
+ public static int maskedCrc32c(byte[] data) {
+ return maskedCrc32c(data, 0, data.length);
+ }
+
+ public static int maskedCrc32c(byte[] data, int offset, int length) {
+ Crc32C crc32c = new Crc32C();
+ crc32c.update(data, offset, length);
+ return crc32c.getMaskedValue();
+ }
+
+ /**
+ * Return a masked representation of crc.
+ *
+ * Motivation: it is problematic to compute the CRC of a string that
+ * contains embedded CRCs. Therefore we recommend that CRCs stored
+ * somewhere (e.g., in files) should be masked before being stored.
+ */
+ public static int mask(int crc) {
+ // Rotate right by 15 bits and add a constant.
+ return ((crc >>> 15) | (crc << 17)) + MASK_DELTA;
+ }
+
+ /**
+ * Return the crc whose masked representation is masked_crc.
+ */
+ public static int unmask(int maskedCrc) {
+ int rot = maskedCrc - MASK_DELTA;
+ return ((rot >>> 17) | (rot << 15));
+ }
+
+ public Crc32C() {
+ crc32C = new PureJavaCrc32C();
+ }
+
+ public int getMaskedValue() {
+ return mask(getIntValue());
+ }
+
+ public int getIntValue() {
+ return (int) getValue();
+ }
+
+ @Override public void update(int b) {
+ crc32C.update(b);
+ }
+
+ @Override public void update(byte[] b, int off, int len) {
+ crc32C.update(b, off, len);
+ }
+
+ @Override public long getValue() {
+ return crc32C.getValue();
+ }
+
+ @Override public void reset() {
+ crc32C.reset();
+ }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordReader.java b/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordReader.java
new file mode 100644
index 00000000..4526edb8
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordReader.java
@@ -0,0 +1,97 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.util;
+
+import java.io.DataInput;
+import java.io.EOFException;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+public class TFRecordReader {
+ private final DataInput input;
+ private final boolean crcCheck;
+
+ public TFRecordReader(DataInput input, boolean crcCheck) {
+ this.input = input;
+ this.crcCheck = crcCheck;
+ }
+
+ public byte[] read() throws IOException {
+ /**
+ * TFRecord format:
+ * uint64 length
+ * uint32 masked_crc32_of_length
+ * byte data[length]
+ * uint32 masked_crc32_of_data
+ */
+ byte[] lenBytes = new byte[8];
+ try {
+ // Only catch EOF here, other case means corrupted file
+ input.readFully(lenBytes);
+ } catch (EOFException eof) {
+ return null; // return null means EOF
+ }
+    long len = fromInt64LE(lenBytes);
+
+ // Verify length crc32
+ if (!crcCheck) {
+ input.skipBytes(4);
+ } else {
+ byte[] lenCrc32Bytes = new byte[4];
+ input.readFully(lenCrc32Bytes);
+ int lenCrc32 = fromInt32LE(lenCrc32Bytes);
+ if (lenCrc32 != Crc32C.maskedCrc32c(lenBytes)) {
+ throw new IOException("Length header crc32 checking failed: " + lenCrc32 + " != " +
+ Crc32C.maskedCrc32c(lenBytes) + ", length = " + len);
+ }
+ }
+
+ if (len > Integer.MAX_VALUE) {
+ throw new IOException("Record size exceeds max value of int32: " + len);
+ }
+    byte[] data = new byte[(int) len];
+ input.readFully(data);
+
+ // Verify data crc32
+ if (!crcCheck) {
+ input.skipBytes(4);
+ } else {
+ byte[] dataCrc32Bytes = new byte[4];
+ input.readFully(dataCrc32Bytes);
+ int dataCrc32 = fromInt32LE(dataCrc32Bytes);
+ if (dataCrc32 != Crc32C.maskedCrc32c(data)) {
+ throw new IOException("Data crc32 checking failed: " + dataCrc32 + " != " +
+ Crc32C.maskedCrc32c(data));
+ }
+ }
+ return data;
+ }
+
+ private long fromInt64LE(byte[] data) {
+ assert data.length == 8;
+ ByteBuffer bb = ByteBuffer.wrap(data);
+ bb.order(ByteOrder.LITTLE_ENDIAN);
+ return bb.getLong();
+ }
+
+ private int fromInt32LE(byte[] data) {
+ assert data.length == 4;
+ ByteBuffer bb = ByteBuffer.wrap(data);
+ bb.order(ByteOrder.LITTLE_ENDIAN);
+ return bb.getInt();
+ }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordWriter.java b/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordWriter.java
new file mode 100644
index 00000000..e2db66f2
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordWriter.java
@@ -0,0 +1,59 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.util;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+public class TFRecordWriter {
+ private final DataOutput output;
+
+ public TFRecordWriter(DataOutput output) {
+ this.output = output;
+ }
+
+ public void write(byte[] record) throws IOException {
+ /**
+ * TFRecord format:
+ * uint64 length
+ * uint32 masked_crc32_of_length
+ * byte data[length]
+ * uint32 masked_crc32_of_data
+ */
+ byte[] len = toInt64LE(record.length);
+ output.write(len);
+ output.write(toInt32LE(Crc32C.maskedCrc32c(len)));
+ output.write(record);
+ output.write(toInt32LE(Crc32C.maskedCrc32c(record)));
+ }
+
+ private byte[] toInt64LE(long data) {
+ byte[] buff = new byte[8];
+ ByteBuffer bb = ByteBuffer.wrap(buff);
+ bb.order(ByteOrder.LITTLE_ENDIAN);
+ bb.putLong(data);
+ return buff;
+ }
+
+ private byte[] toInt32LE(int data) {
+ byte[] buff = new byte[4];
+ ByteBuffer bb = ByteBuffer.wrap(buff);
+ bb.order(ByteOrder.LITTLE_ENDIAN);
+ bb.putInt(data);
+ return buff;
+ }
+}
diff --git a/hadoop/src/test/java/org/tensorflow/hadoop/io/TFRecordFileTest.java b/hadoop/src/test/java/org/tensorflow/hadoop/io/TFRecordFileTest.java
new file mode 100644
index 00000000..38fd7f04
--- /dev/null
+++ b/hadoop/src/test/java/org/tensorflow/hadoop/io/TFRecordFileTest.java
@@ -0,0 +1,105 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.io;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.*;
+import org.apache.hadoop.mapreduce.task.MapContextImpl;
+import org.junit.Test;
+import org.tensorflow.example.Example;
+import org.tensorflow.example.Feature;
+import org.tensorflow.example.Features;
+import org.tensorflow.example.Int64List;
+
+import java.util.Map;
+import java.util.Random;
+import java.util.TreeMap;
+
+import static org.junit.Assert.assertEquals;
+
+public class TFRecordFileTest {
+ private static final int RECORDS = 10000;
+
+ @Test
+ public void testInputOutputFormat() throws Exception {
+ Configuration conf = new Configuration();
+ Job job = Job.getInstance(conf);
+
+ Path outdir = new Path(System.getProperty("test.build.data", "/tmp"), "tfr-test");
+
+ TFRecordFileOutputFormat.setOutputPath(job, outdir);
+
+ TaskAttemptContext context =
+ MapReduceTestUtil.createDummyMapTaskAttemptContext(job.getConfiguration());
+    OutputFormat<BytesWritable, NullWritable> outputFormat =
+        new TFRecordFileOutputFormat();
+ OutputCommitter committer = outputFormat.getOutputCommitter(context);
+ committer.setupJob(job);
+    RecordWriter<BytesWritable, NullWritable> writer = outputFormat.
+        getRecordWriter(context);
+
+ // Write Example with random numbers
+ Random rand = new Random();
+    Map<Long, Long> records = new TreeMap<Long, Long>();
+ try {
+ for (int i = 0; i < RECORDS; ++i) {
+ long randValue = rand.nextLong();
+ records.put((long) i, randValue);
+ Int64List data = Int64List.newBuilder().addValue(i).addValue(randValue).build();
+ Feature feature = Feature.newBuilder().setInt64List(data).build();
+ Features features = Features.newBuilder().putFeature("data", feature).build();
+ Example example = Example.newBuilder().setFeatures(features).build();
+ BytesWritable key = new BytesWritable(example.toByteArray());
+ writer.write(key, NullWritable.get());
+ }
+ } finally {
+ writer.close(context);
+ }
+ committer.commitTask(context);
+ committer.commitJob(job);
+
+ // Read and compare
+ TFRecordFileInputFormat.setInputPaths(job, outdir);
+    InputFormat<BytesWritable, NullWritable> inputFormat = new TFRecordFileInputFormat();
+ for (InputSplit split : inputFormat.getSplits(job)) {
+      RecordReader<BytesWritable, NullWritable> reader =
+          inputFormat.createRecordReader(split, context);
+      MapContext<BytesWritable, NullWritable, BytesWritable, NullWritable> mcontext =
+          new MapContextImpl<BytesWritable, NullWritable, BytesWritable, NullWritable>(
+              job.getConfiguration(), context.getTaskAttemptID(), reader, null, null,
+              MapReduceTestUtil.createDummyReporter(),
+              split);
+ reader.initialize(split, mcontext);
+ try {
+ while (reader.nextKeyValue()) {
+ BytesWritable bytes = reader.getCurrentKey();
+ Example example = Example.parseFrom(bytes.getBytes());
+ Int64List data = example.getFeatures().getFeatureMap().get("data").getInt64List();
+ Long key = data.getValue(0);
+ Long value = data.getValue(1);
+ assertEquals(records.get(key), value);
+ records.remove(key);
+ }
+ } finally {
+ reader.close();
+ }
+ }
+ assertEquals(0, records.size());
+ }
+}
diff --git a/hadoop/src/test/java/org/tensorflow/hadoop/util/TFRecordTest.java b/hadoop/src/test/java/org/tensorflow/hadoop/util/TFRecordTest.java
new file mode 100644
index 00000000..5973aac4
--- /dev/null
+++ b/hadoop/src/test/java/org/tensorflow/hadoop/util/TFRecordTest.java
@@ -0,0 +1,43 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.util;
+
+import java.io.*;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+public class TFRecordTest {
+ @Test
+ public void testTFRecord() throws IOException {
+ int count = 1000;
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ TFRecordWriter writer = new TFRecordWriter(new DataOutputStream(baos));
+ for (int i = 0; i < count; ++i) {
+ writer.write((Integer.toString(i)).getBytes());
+ }
+ baos.close();
+
+ ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
+ TFRecordReader reader = new TFRecordReader(new DataInputStream(bais), true);
+ for (int i = 0; i < count; ++i) {
+ assertEquals(Integer.toString(i), new String(reader.read()));
+ }
+ assertNull(reader.read()); // EOF
+ }
+}