diff --git a/README.md b/README.md
index 3ec52dbd..9bc3afaf 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,8 @@ request.
   Kubernetes.
 - [marathon](marathon) - Templates for running distributed TensorFlow using
   Marathon, deployed on top of Mesos.
+- [hadoop](hadoop) - TFRecord file InputFormat/OutputFormat for Hadoop MapReduce
+  and Spark.
 
 ## Distributed TensorFlow
diff --git a/hadoop/README.md b/hadoop/README.md
new file mode 100644
index 00000000..2a3a2585
--- /dev/null
+++ b/hadoop/README.md
@@ -0,0 +1,80 @@
+# Hadoop MapReduce InputFormat/OutputFormat for TFRecord files
+
+This directory contains an [Apache Hadoop](http://hadoop.apache.org/) MapReduce
+InputFormat/OutputFormat implementation for TensorFlow TFRecord files. It can
+also be used with [Apache Spark](http://spark.apache.org/).
+
+## Prerequisites
+
+1. [protoc 3.1.0](https://developers.google.com/protocol-buffers/) must be
+installed.
+
+2. [Apache Maven](https://maven.apache.org/) must be installed.
+
+3. The code has been tested with Hadoop 2.6.0, so a Hadoop 2.6.0 or otherwise
+compatible YARN/Spark cluster is recommended.
+
+## Build and install
+
+1. Compile the TensorFlow Example protos
+
+    ```sh
+    # Suppose $TF_SRC_ROOT is the source code root of the tensorflow project
+    protoc --proto_path=$TF_SRC_ROOT --java_out=src/main/java/ $TF_SRC_ROOT/tensorflow/core/example/{example,feature}.proto
+    ```
+
+2. Compile the code
+
+    ```sh
+    mvn clean package
+    ```
+
+3. Optionally install (or deploy) the jars
+
+    ```sh
+    mvn install
+    ```
+
+   Once installed (or deployed), the package can be used with the following
+   dependency:
+
+    ```xml
+    <dependency>
+      <groupId>org.tensorflow</groupId>
+      <artifactId>tensorflow-hadoop</artifactId>
+      <version>1.0-SNAPSHOT</version>
+    </dependency>
+    ```
+
+## Use with MapReduce
+The Hadoop MapReduce example can be found
+[here](src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java).
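+
+A possible invocation once the jar is built (a sketch, assuming a working
+Hadoop installation with the protobuf-java dependency available to the job;
+the path argument is optional and defaults to `/tmp/tfrecord-file-test`):
+
+```sh
+hadoop jar target/tensorflow-hadoop-1.0-SNAPSHOT.jar \
+    org.tensorflow.hadoop.example.TFRecordFileMRExample /tmp/tfrecord-file-test
+```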
+
+## Use with Spark
+Spark supports reading and writing files with Hadoop InputFormat/OutputFormat;
+the following code snippet demonstrates the usage.
+
+```scala
+import com.google.protobuf.ByteString
+import org.apache.hadoop.io.{NullWritable, BytesWritable}
+import org.apache.spark.{SparkConf, SparkContext}
+import org.tensorflow.example.{BytesList, Int64List, Feature, Features, Example}
+import org.tensorflow.hadoop.io.TFRecordFileOutputFormat
+
+val inputPath = "path/to/input.txt"
+val outputPath = "path/to/output.tfr"
+
+val sparkConf = new SparkConf().setAppName("TFRecord Demo")
+val sc = new SparkContext(sparkConf)
+
+val features = sc.textFile(inputPath).map(line => {
+  val text = BytesList.newBuilder().addValue(ByteString.copyFrom(line.getBytes)).build()
+  val features = Features.newBuilder()
+    .putFeature("text", Feature.newBuilder().setBytesList(text).build())
+    .build()
+  val example = Example.newBuilder()
+    .setFeatures(features)
+    .build()
+  (new BytesWritable(example.toByteArray), NullWritable.get())
+})
+
+features.saveAsNewAPIHadoopFile[TFRecordFileOutputFormat](outputPath)
+```
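+
+Reading the records back is symmetric. A minimal sketch, assuming the same
+`SparkContext` and imports plus `TFRecordFileInputFormat` (note that Hadoop
+RDDs reuse `Writable` objects, so the bytes are copied out before parsing):
+
+```scala
+import org.tensorflow.hadoop.io.TFRecordFileInputFormat
+
+val records = sc.newAPIHadoopFile[BytesWritable, NullWritable, TFRecordFileInputFormat](outputPath)
+val texts = records.map { case (bytes, _) =>
+  // Copy out of the reused BytesWritable buffer, then parse the Example proto.
+  val example = Example.parseFrom(bytes.copyBytes())
+  example.getFeatures.getFeatureMap.get("text").getBytesList.getValue(0).toStringUtf8
+}
+texts.collect().foreach(println)
+```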
diff --git a/hadoop/pom.xml b/hadoop/pom.xml
new file mode 100644
index 00000000..cde2ca9c
--- /dev/null
+++ b/hadoop/pom.xml
@@ -0,0 +1,70 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <groupId>org.tensorflow</groupId>
+  <artifactId>tensorflow-hadoop</artifactId>
+  <packaging>jar</packaging>
+  <version>1.0-SNAPSHOT</version>
+  <name>tensorflow-hadoop</name>
+  <url>https://www.tensorflow.org</url>
+
+  <properties>
+    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+    <maven.compiler.source>1.6</maven.compiler.source>
+    <maven.compiler.target>1.6</maven.compiler.target>
+    <hadoop.version>2.6.0</hadoop.version>
+    <protobuf.version>3.1.0</protobuf.version>
+    <junit.version>4.11</junit.version>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <version>${hadoop.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.protobuf</groupId>
+          <artifactId>protobuf-java</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-mapreduce-client-core</artifactId>
+      <version>${hadoop.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.protobuf</groupId>
+          <artifactId>protobuf-java</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <version>${protobuf.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <version>${junit.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-mapreduce-client-jobclient</artifactId>
+      <version>${hadoop.version}</version>
+      <type>test-jar</type>
+      <optional>true</optional>
+      <scope>test</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.protobuf</groupId>
+          <artifactId>protobuf-java</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+  </dependencies>
+</project>
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java b/hadoop/src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java
new file mode 100644
index 00000000..c1b1779a
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java
@@ -0,0 +1,124 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.example;
+
+import com.google.protobuf.ByteString;
+import org.apache.hadoop.mapreduce.InputFormat;
+import org.apache.hadoop.mapreduce.OutputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.tensorflow.example.*;
+import org.tensorflow.hadoop.io.TFRecordFileInputFormat;
+import org.tensorflow.hadoop.io.TFRecordFileOutputFormat;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.*;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+
+import java.io.IOException;
+import java.util.Map;
+
+public class TFRecordFileMRExample {
+  /**
+   * Convert from text file to TFRecord file. Each line is converted into two dummy features: the
+   * content of each line and the starting offset of each line.
+   */
+  static class ToTFRecordMapper extends Mapper<LongWritable, Text, BytesWritable, NullWritable> {
+    ToTFRecordMapper(){}
+
+    @Override protected void map(LongWritable key, Text value,
+        Context context) throws IOException, InterruptedException {
+      Int64List int64List = Int64List.newBuilder().addValue(key.get()).build();
+      Feature offset = Feature.newBuilder().setInt64List(int64List).build();
+
+      ByteString byteString = ByteString.copyFrom(value.copyBytes());
+      BytesList bytesList = BytesList.newBuilder().addValue(byteString).build();
+      Feature text = Feature.newBuilder().setBytesList(bytesList).build();
+
+      Features features = Features.newBuilder()
+          .putFeature("offset", offset)
+          .putFeature("text", text)
+          .build();
+      Example example = Example.newBuilder().setFeatures(features).build();
+      context.write(new BytesWritable(example.toByteArray()), NullWritable.get());
+    }
+  }
+
+  /**
+   * Convert from the previous TFRecord file back to a text file.
+   */
+  static class FromTFRecordMapper extends Mapper<BytesWritable, NullWritable, NullWritable, Text> {
+    FromTFRecordMapper(){}
+
+    @Override protected void map(BytesWritable key, NullWritable value,
+        Context context) throws IOException, InterruptedException {
+      // copyBytes() trims the backing buffer to the valid length; getBytes()
+      // may include padding, which would corrupt the serialized proto.
+      Example example = Example.parseFrom(key.copyBytes());
+      Map<String, Feature> featureMap = example.getFeatures().getFeatureMap();
+      byte[] text = featureMap.get("text").getBytesList().getValue(0).toByteArray();
+      context.write(NullWritable.get(), new Text(text));
+    }
+  }
+
+  public static boolean convert(String jobName,
+                                Class<? extends Mapper> mapperClass,
+                                Class<? extends Writable> outputKeyClass,
+                                Class<? extends Writable> outputValueClass,
+                                Class<? extends InputFormat> inFormatClass,
+                                Class<? extends OutputFormat> outFormatClass,
+                                Path input,
+                                Path output) throws InterruptedException, IOException, ClassNotFoundException {
+    Configuration conf = new Configuration();
+    Job job = Job.getInstance(conf, jobName);
+    job.setJarByClass(mapperClass);
+    job.setMapperClass(mapperClass);
+    job.setNumReduceTasks(0);
+
+    job.setInputFormatClass(inFormatClass);
+    job.setOutputFormatClass(outFormatClass);
+    job.setOutputKeyClass(outputKeyClass);
+    job.setOutputValueClass(outputValueClass);
+
+    final FileSystem fs = FileSystem.get(output.toUri(), conf);
+    fs.delete(output, true);
+    FileInputFormat.addInputPath(job, input);
+    FileOutputFormat.setOutputPath(job, output);
+
+    return job.waitForCompletion(true);
+  }
+
+  public static void main(String[] args) throws Exception {
+    String testRoot = "/tmp/tfrecord-file-test";
+    if (args.length == 1) {
+      testRoot = args[0];
+    } else if (args.length > 1) {
+      System.out.println("Usage: TFRecordFileMRExample [path]");
+      return;
+    }
+
+    Path testRootPath = new Path(testRoot);
+    Path input = new Path(testRootPath, "input.txt");
+    Path tfrout = new Path(testRootPath, "output.tfr");
+    Path txtout = new Path(testRootPath, "output.txt");
+
+    convert("ToTFR", ToTFRecordMapper.class, BytesWritable.class, NullWritable.class,
+        TextInputFormat.class, TFRecordFileOutputFormat.class, input, tfrout);
+    convert("FromTFR", FromTFRecordMapper.class, NullWritable.class, Text.class,
+        TFRecordFileInputFormat.class, TextOutputFormat.class, tfrout, txtout);
+  }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileInputFormat.java b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileInputFormat.java
new file mode 100644
index 00000000..3cf0d9ef
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileInputFormat.java
@@ -0,0 +1,85 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.io;
+
+import org.tensorflow.hadoop.util.TFRecordReader;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.FileSplit;
+
+import java.io.IOException;
+
+public class TFRecordFileInputFormat extends FileInputFormat<BytesWritable, NullWritable> {
+  @Override public RecordReader<BytesWritable, NullWritable> createRecordReader(
+      InputSplit inputSplit, final TaskAttemptContext context) throws IOException, InterruptedException {
+
+    return new RecordReader<BytesWritable, NullWritable>() {
+      private FSDataInputStream fsdis;
+      private TFRecordReader reader;
+      private long length;
+      private long begin;
+      private byte[] current;
+
+      @Override public void initialize(InputSplit split, TaskAttemptContext context)
+          throws IOException, InterruptedException {
+        Configuration conf = context.getConfiguration();
+        FileSplit fileSplit = (FileSplit) split;
+        length = fileSplit.getLength();
+        begin = fileSplit.getStart();
+
+        final Path file = fileSplit.getPath();
+        FileSystem fs = file.getFileSystem(conf);
+        fsdis = fs.open(file, TFRecordIOConf.getBufferSize(conf));
+        reader = new TFRecordReader(fsdis, TFRecordIOConf.getDoCrc32Check(conf));
+      }
+
+      @Override public boolean nextKeyValue() throws IOException, InterruptedException {
+        current = reader.read();
+        return current != null;
+      }
+
+      @Override public BytesWritable getCurrentKey() throws IOException, InterruptedException {
+        return new BytesWritable(current);
+      }
+
+      @Override public NullWritable getCurrentValue() throws IOException, InterruptedException {
+        return NullWritable.get();
+      }
+
+      @Override public float getProgress() throws IOException, InterruptedException {
+        return (fsdis.getPos() - begin) / (length + 1e-6f);
+      }
+
+      @Override public void close() throws IOException {
+        fsdis.close();
+      }
+    };
+  }
+
+  @Override
+  protected boolean isSplitable(JobContext context, Path file) {
+    // TFRecord files carry no sync markers, so a file cannot be split.
+    return false;
+  }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormat.java b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormat.java
new file mode 100644
index 00000000..5cebe1b8
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormat.java
@@ -0,0 +1,53 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.io;
+
+import org.tensorflow.hadoop.util.TFRecordWriter;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.RecordWriter;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+
+import java.io.IOException;
+
+public class TFRecordFileOutputFormat extends FileOutputFormat<BytesWritable, NullWritable> {
+  @Override public RecordWriter<BytesWritable, NullWritable> getRecordWriter(
+      TaskAttemptContext context) throws IOException, InterruptedException {
+    Configuration conf = context.getConfiguration();
+    Path file = getDefaultWorkFile(context, "");
+    FileSystem fs = file.getFileSystem(conf);
+
+    int bufferSize = TFRecordIOConf.getBufferSize(conf);
+    final FSDataOutputStream fsdos = fs.create(file, true, bufferSize);
+    final TFRecordWriter writer = new TFRecordWriter(fsdos);
+    return new RecordWriter<BytesWritable, NullWritable>() {
+      @Override public void write(BytesWritable key, NullWritable value)
+          throws IOException, InterruptedException {
+        // copyBytes() yields exactly getLength() bytes; getBytes() may return
+        // a padded buffer, which would write garbage into the record.
+        writer.write(key.copyBytes());
+      }
+
+      @Override public void close(TaskAttemptContext context)
+          throws IOException, InterruptedException {
+        fsdos.close();
+      }
+    };
+  }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordIOConf.java b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordIOConf.java
new file mode 100644
index 00000000..dba79368
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordIOConf.java
@@ -0,0 +1,28 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.io;
+
+import org.apache.hadoop.conf.Configuration;
+
+public class TFRecordIOConf {
+  static int getBufferSize(Configuration conf) {
+    return conf.getInt("io.file.buffer.size", 4096);
+  }
+
+  static boolean getDoCrc32Check(Configuration conf) {
+    return conf.getBoolean("tensorflow.read.crc32check", true);
+  }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/util/Crc32C.java b/hadoop/src/main/java/org/tensorflow/hadoop/util/Crc32C.java
new file mode 100644
index 00000000..46c5078d
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/util/Crc32C.java
@@ -0,0 +1,83 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.util;
+
+import org.apache.hadoop.util.PureJavaCrc32C;
+
+import java.util.zip.Checksum;
+
+public class Crc32C implements Checksum {
+  private static final int MASK_DELTA = 0xa282ead8;
+  private PureJavaCrc32C crc32C;
+
+  public static int maskedCrc32c(byte[] data) {
+    return maskedCrc32c(data, 0, data.length);
+  }
+
+  public static int maskedCrc32c(byte[] data, int offset, int length) {
+    Crc32C crc32c = new Crc32C();
+    crc32c.update(data, offset, length);
+    return crc32c.getMaskedValue();
+  }
+
+  /**
+   * Return a masked representation of crc.
+   *
+   * Motivation: it is problematic to compute the CRC of a string that
+   * contains embedded CRCs. Therefore we recommend that CRCs stored
+   * somewhere (e.g., in files) should be masked before being stored.
+   */
+  public static int mask(int crc) {
+    // Rotate right by 15 bits and add a constant.
+    return ((crc >>> 15) | (crc << 17)) + MASK_DELTA;
+  }
+
+  /**
+   * Return the crc whose masked representation is masked_crc.
+   */
+  public static int unmask(int maskedCrc) {
+    int rot = maskedCrc - MASK_DELTA;
+    return ((rot >>> 17) | (rot << 15));
+  }
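+
+  // Note: mask() is a 32-bit rotation plus an additive constant and unmask()
+  // is its exact inverse, so a CRC stored in masked form round-trips
+  // losslessly. Illustrative example:
+  //   int masked = Crc32C.mask(rawCrc);    // value written next to the data
+  //   int raw    = Crc32C.unmask(masked);  // raw == rawCrc for any int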
+
+  public Crc32C() {
+    crc32C = new PureJavaCrc32C();
+  }
+
+  public int getMaskedValue() {
+    return mask(getIntValue());
+  }
+
+  public int getIntValue() {
+    return (int) getValue();
+  }
+
+  @Override public void update(int b) {
+    crc32C.update(b);
+  }
+
+  @Override public void update(byte[] b, int off, int len) {
+    crc32C.update(b, off, len);
+  }
+
+  @Override public long getValue() {
+    return crc32C.getValue();
+  }
+
+  @Override public void reset() {
+    crc32C.reset();
+  }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordReader.java b/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordReader.java
new file mode 100644
index 00000000..4526edb8
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordReader.java
@@ -0,0 +1,97 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.util;
+
+import java.io.DataInput;
+import java.io.EOFException;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+public class TFRecordReader {
+  private final DataInput input;
+  private final boolean crcCheck;
+
+  public TFRecordReader(DataInput input, boolean crcCheck) {
+    this.input = input;
+    this.crcCheck = crcCheck;
+  }
+
+  public byte[] read() throws IOException {
+    /*
+     * TFRecord format:
+     * uint64 length
+     * uint32 masked_crc32_of_length
+     * byte   data[length]
+     * uint32 masked_crc32_of_data
+     */
+    byte[] lenBytes = new byte[8];
+    try {
+      // Only catch EOF here; any other failure means a corrupted file.
+      input.readFully(lenBytes);
+    } catch (EOFException eof) {
+      return null; // return null means EOF
+    }
+    long len = fromInt64LE(lenBytes);
+
+    // Verify length crc32
+    if (!crcCheck) {
+      input.skipBytes(4);
+    } else {
+      byte[] lenCrc32Bytes = new byte[4];
+      input.readFully(lenCrc32Bytes);
+      int lenCrc32 = fromInt32LE(lenCrc32Bytes);
+      if (lenCrc32 != Crc32C.maskedCrc32c(lenBytes)) {
+        throw new IOException("Length header crc32 checking failed: " + lenCrc32 + " != "
+            + Crc32C.maskedCrc32c(lenBytes) + ", length = " + len);
+      }
+    }
+
+    // The length field is an unsigned 64-bit value; reject anything that does
+    // not fit into a Java byte array (a negative value means the sign bit was set).
+    if (len < 0 || len > Integer.MAX_VALUE) {
+      throw new IOException("Record size exceeds max value of int32: " + len);
+    }
+    byte[] data = new byte[(int) len];
+    input.readFully(data);
+
+    // Verify data crc32
+    if (!crcCheck) {
+      input.skipBytes(4);
+    } else {
+      byte[] dataCrc32Bytes = new byte[4];
+      input.readFully(dataCrc32Bytes);
+      int dataCrc32 = fromInt32LE(dataCrc32Bytes);
+      if (dataCrc32 != Crc32C.maskedCrc32c(data)) {
+        throw new IOException("Data crc32 checking failed: " + dataCrc32 + " != "
+            + Crc32C.maskedCrc32c(data));
+      }
+    }
+    return data;
+  }
+
+  private long fromInt64LE(byte[] data) {
+    assert data.length == 8;
+    ByteBuffer bb = ByteBuffer.wrap(data);
+    bb.order(ByteOrder.LITTLE_ENDIAN);
+    return bb.getLong();
+  }
+
+  private int fromInt32LE(byte[] data) {
+    assert data.length == 4;
+    ByteBuffer bb = ByteBuffer.wrap(data);
+    bb.order(ByteOrder.LITTLE_ENDIAN);
+    return bb.getInt();
+  }
+}
diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordWriter.java b/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordWriter.java
new file mode 100644
index 00000000..e2db66f2
--- /dev/null
+++ b/hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordWriter.java
@@ -0,0 +1,59 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.util;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+public class TFRecordWriter {
+  private final DataOutput output;
+
+  public TFRecordWriter(DataOutput output) {
+    this.output = output;
+  }
+
+  public void write(byte[] record) throws IOException {
+    /*
+     * TFRecord format:
+     * uint64 length
+     * uint32 masked_crc32_of_length
+     * byte   data[length]
+     * uint32 masked_crc32_of_data
+     */
+    byte[] len = toInt64LE(record.length);
+    output.write(len);
+    output.write(toInt32LE(Crc32C.maskedCrc32c(len)));
+    output.write(record);
+    output.write(toInt32LE(Crc32C.maskedCrc32c(record)));
+  }
+
+  private byte[] toInt64LE(long data) {
+    byte[] buff = new byte[8];
+    ByteBuffer bb = ByteBuffer.wrap(buff);
+    bb.order(ByteOrder.LITTLE_ENDIAN);
+    bb.putLong(data);
+    return buff;
+  }
+
+  private byte[] toInt32LE(int data) {
+    byte[] buff = new byte[4];
+    ByteBuffer bb = ByteBuffer.wrap(buff);
+    bb.order(ByteOrder.LITTLE_ENDIAN);
+    bb.putInt(data);
+    return buff;
+  }
+}
diff --git a/hadoop/src/test/java/org/tensorflow/hadoop/io/TFRecordFileTest.java b/hadoop/src/test/java/org/tensorflow/hadoop/io/TFRecordFileTest.java
new file mode 100644
index 00000000..38fd7f04
--- /dev/null
+++ b/hadoop/src/test/java/org/tensorflow/hadoop/io/TFRecordFileTest.java
@@ -0,0 +1,105 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.io;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.*;
+import org.apache.hadoop.mapreduce.task.MapContextImpl;
+import org.junit.Test;
+import org.tensorflow.example.Example;
+import org.tensorflow.example.Feature;
+import org.tensorflow.example.Features;
+import org.tensorflow.example.Int64List;
+
+import java.util.Map;
+import java.util.Random;
+import java.util.TreeMap;
+
+import static org.junit.Assert.assertEquals;
+
+public class TFRecordFileTest {
+  private static final int RECORDS = 10000;
+
+  @Test
+  public void testInputOutputFormat() throws Exception {
+    Configuration conf = new Configuration();
+    Job job = Job.getInstance(conf);
+
+    Path outdir = new Path(System.getProperty("test.build.data", "/tmp"), "tfr-test");
+
+    TFRecordFileOutputFormat.setOutputPath(job, outdir);
+
+    TaskAttemptContext context =
+        MapReduceTestUtil.createDummyMapTaskAttemptContext(job.getConfiguration());
+    OutputFormat<BytesWritable, NullWritable> outputFormat =
+        new TFRecordFileOutputFormat();
+    OutputCommitter committer = outputFormat.getOutputCommitter(context);
+    committer.setupJob(job);
+    RecordWriter<BytesWritable, NullWritable> writer =
+        outputFormat.getRecordWriter(context);
+
+    // Write Examples with random numbers
+    Random rand = new Random();
+    Map<Long, Long> records = new TreeMap<Long, Long>();
+    try {
+      for (int i = 0; i < RECORDS; ++i) {
+        long randValue = rand.nextLong();
+        records.put((long) i, randValue);
+        Int64List data = Int64List.newBuilder().addValue(i).addValue(randValue).build();
+        Feature feature = Feature.newBuilder().setInt64List(data).build();
+        Features features = Features.newBuilder().putFeature("data", feature).build();
+        Example example = Example.newBuilder().setFeatures(features).build();
+        BytesWritable key = new BytesWritable(example.toByteArray());
+        writer.write(key, NullWritable.get());
+      }
+    } finally {
+      writer.close(context);
+    }
+    committer.commitTask(context);
+    committer.commitJob(job);
+
+    // Read and compare
+    TFRecordFileInputFormat.setInputPaths(job, outdir);
+    InputFormat<BytesWritable, NullWritable> inputFormat = new TFRecordFileInputFormat();
+    for (InputSplit split : inputFormat.getSplits(job)) {
+      RecordReader<BytesWritable, NullWritable> reader =
+          inputFormat.createRecordReader(split, context);
+      MapContext<BytesWritable, NullWritable, BytesWritable, NullWritable> mcontext =
+          new MapContextImpl<BytesWritable, NullWritable, BytesWritable, NullWritable>(
+              job.getConfiguration(), context.getTaskAttemptID(), reader, null, null,
+              MapReduceTestUtil.createDummyReporter(), split);
+      reader.initialize(split, mcontext);
+      try {
+        while (reader.nextKeyValue()) {
+          BytesWritable bytes = reader.getCurrentKey();
+          // copyBytes() trims to the valid length before parsing the proto.
+          Example example = Example.parseFrom(bytes.copyBytes());
+          Int64List data = example.getFeatures().getFeatureMap().get("data").getInt64List();
+          Long key = data.getValue(0);
+          Long value = data.getValue(1);
+          assertEquals(records.get(key), value);
+          records.remove(key);
+        }
+      } finally {
+        reader.close();
+      }
+    }
+    assertEquals(0, records.size());
+  }
+}
diff --git a/hadoop/src/test/java/org/tensorflow/hadoop/util/TFRecordTest.java b/hadoop/src/test/java/org/tensorflow/hadoop/util/TFRecordTest.java
new file mode 100644
index 00000000..5973aac4
--- /dev/null
+++ b/hadoop/src/test/java/org/tensorflow/hadoop/util/TFRecordTest.java
@@ -0,0 +1,43 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.hadoop.util;
+
+import java.io.*;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+public class TFRecordTest {
+  @Test
+  public void testTFRecord() throws IOException {
+    int count = 1000;
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    TFRecordWriter writer = new TFRecordWriter(new DataOutputStream(baos));
+    for (int i = 0; i < count; ++i) {
+      writer.write(Integer.toString(i).getBytes());
+    }
+    baos.close();
+
+    ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
+    TFRecordReader reader = new TFRecordReader(new DataInputStream(bais), true);
+    for (int i = 0; i < count; ++i) {
+      assertEquals(Integer.toString(i), new String(reader.read()));
+    }
+    assertNull(reader.read()); // EOF
+  }
+}