Add Hadoop InputFormat/OutputFormat for TFRecord feature file (tensor…
llhe authored and jhseu committed Nov 1, 2016
1 parent 628c8d6 commit 6fd9200
Showing 12 changed files with 829 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -15,6 +15,8 @@ request.
Kubernetes.
- [marathon](marathon) - Templates for running distributed TensorFlow using
Marathon, deployed on top of Mesos.
- [hadoop](hadoop) - TFRecord file InputFormat/OutputFormat for Hadoop MapReduce
and Spark.

## Distributed TensorFlow

80 changes: 80 additions & 0 deletions hadoop/README.md
@@ -0,0 +1,80 @@
# Hadoop MapReduce InputFormat/OutputFormat for TFRecord files

This directory contains an [Apache Hadoop](http://hadoop.apache.org/) MapReduce
InputFormat/OutputFormat implementation for TensorFlow TFRecord feature files.
It can also be used with [Apache Spark](http://spark.apache.org/).

## Prerequisites

1. [protoc 3.1.0](https://developers.google.com/protocol-buffers/) must be
installed.

2. [Apache Maven](https://maven.apache.org/) must be installed.

3. This has been tested with Hadoop 2.6.0, so a Hadoop 2.6.0 or compatible
   YARN/Spark cluster is recommended.

## Build and install

1. Compile the TensorFlow Example protos

```sh
# Suppose $TF_SRC_ROOT is the source root of the TensorFlow project
protoc --proto_path=$TF_SRC_ROOT --java_out=src/main/java/ $TF_SRC_ROOT/tensorflow/core/example/{example,feature}.proto
```

2. Compile the code

```sh
mvn clean package
```

3. Optionally install (or deploy) the jars

```sh
mvn install
```

Once installed (or deployed), the package can be used with the following dependency:

```xml
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-hadoop</artifactId>
<version>1.0-SNAPSHOT</version>
</dependency>
```

## Use with MapReduce
The Hadoop MapReduce example can be found [here](src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java).
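
For orientation, here is a minimal sketch of the job wiring, condensed from that example. `MyMapper` is a hypothetical stand-in for a mapper that emits serialized `Example` protos as `BytesWritable` keys (such as the `ToTFRecordMapper` in the linked file):

```java
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.tensorflow.hadoop.io.TFRecordFileOutputFormat;

public class TFRecordJobSketch {
  public static void main(String[] args) throws Exception {
    Job job = Job.getInstance(new Configuration(), "text-to-tfrecord");
    job.setJarByClass(TFRecordJobSketch.class);
    job.setMapperClass(MyMapper.class);                       // hypothetical mapper; see ToTFRecordMapper in the bundled example
    job.setNumReduceTasks(0);                                 // map-only conversion
    job.setInputFormatClass(TextInputFormat.class);
    job.setOutputFormatClass(TFRecordFileOutputFormat.class); // each output key becomes one TFRecord
    job.setOutputKeyClass(BytesWritable.class);               // serialized Example proto
    job.setOutputValueClass(NullWritable.class);
    FileInputFormat.addInputPath(job, new Path(args[0]));
    FileOutputFormat.setOutputPath(job, new Path(args[1]));
    System.exit(job.waitForCompletion(true) ? 0 : 1);
  }
}
```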

## Use with Spark
Spark supports reading/writing files with Hadoop InputFormat/OutputFormat; the
following code snippet demonstrates writing TFRecord files from Spark.

```scala
import com.google.protobuf.ByteString
import org.apache.hadoop.io.{NullWritable, BytesWritable}
import org.apache.spark.{SparkConf, SparkContext}
import org.tensorflow.example.{BytesList, Int64List, Feature, Features, Example}
import org.tensorflow.hadoop.io.TFRecordFileOutputFormat
val inputPath = "path/to/input.txt"
val outputPath = "path/to/output.tfr"
val sparkConf = new SparkConf().setAppName("TFRecord Demo")
val sc = new SparkContext(sparkConf)
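// Wrap each input line in an Example proto; the serialized proto becomes the TFRecord payload.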
val features = sc.textFile(inputPath).map(line => {
val text = BytesList.newBuilder().addValue(ByteString.copyFrom(line.getBytes)).build()
val features = Features.newBuilder()
.putFeature("text", Feature.newBuilder().setBytesList(text).build())
.build()
val example = Example.newBuilder()
.setFeatures(features)
.build()
(new BytesWritable(example.toByteArray), NullWritable.get())
})
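// The BytesWritable key of each pair is written out as one TFRecord; the NullWritable value is ignored.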
features.saveAsNewAPIHadoopFile[TFRecordFileOutputFormat](outputPath)
```
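
Reading the records back into an RDD works the same way in reverse: `sc.newAPIHadoopFile(outputPath, classOf[TFRecordFileInputFormat], classOf[BytesWritable], classOf[NullWritable])` yields `(BytesWritable, NullWritable)` pairs whose keys can be decoded with `Example.parseFrom(bytes.copyBytes())`.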
70 changes: 70 additions & 0 deletions hadoop/pom.xml
@@ -0,0 +1,70 @@
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-hadoop</artifactId>
<packaging>jar</packaging>
<version>1.0-SNAPSHOT</version>
<name>tensorflow-hadoop</name>
<url>https://www.tensorflow.org</url>

<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.6</maven.compiler.source>
<maven.compiler.target>1.6</maven.compiler.target>
<hadoop.version>2.6.0</hadoop.version>
<protobuf.version>3.1.0</protobuf.version>
<junit.version>4.11</junit.version>
</properties>

<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>${hadoop.version}</version>
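<!-- Hadoop 2.6.0 depends on protobuf 2.x; exclude it so the protobuf ${protobuf.version} dependency below takes precedence. -->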
<exclusions>
<exclusion>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-mapreduce-client-core</artifactId>
<version>${hadoop.version}</version>
<exclusions>
<exclusion>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-mapreduce-client-jobclient</artifactId>
<version>${hadoop.version}</version>
<type>test-jar</type>
<optional>true</optional>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
</project>
124 changes: 124 additions & 0 deletions hadoop/src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java
@@ -0,0 +1,124 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.hadoop.example;

import com.google.protobuf.ByteString;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.OutputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.tensorflow.example.*;
import org.tensorflow.hadoop.io.TFRecordFileInputFormat;
import org.tensorflow.hadoop.io.TFRecordFileOutputFormat;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

import java.io.IOException;
import java.util.Map;

public class TFRecordFileMRExample {
/**
* Convert from text file to TFRecord file. Each line is converted into two dummy features: the
* content of each line and the starting offset of each line.
*/
static class ToTFRecordMapper extends Mapper<LongWritable, Text, BytesWritable, NullWritable> {
ToTFRecordMapper(){}

@Override protected void map(LongWritable key, Text value,
Context context) throws IOException, InterruptedException {
Int64List int64List = Int64List.newBuilder().addValue(key.get()).build();
Feature offset = Feature.newBuilder().setInt64List(int64List).build();

ByteString byteString = ByteString.copyFrom(value.copyBytes());
BytesList bytesList = BytesList.newBuilder().addValue(byteString).build();
Feature text = Feature.newBuilder().setBytesList(bytesList).build();

Features features = Features.newBuilder()
.putFeature("offset", offset)
.putFeature("text", text)
.build();
Example example = Example.newBuilder().setFeatures(features).build();
context.write(new BytesWritable(example.toByteArray()), NullWritable.get());
}
}

/**
* Convert from previous TFRecord file to text file.
*/
static class FromTFRecordMapper extends Mapper<BytesWritable, NullWritable, NullWritable, Text> {
FromTFRecordMapper(){}

@Override protected void map(BytesWritable key, NullWritable value,
Context context) throws IOException, InterruptedException {
Example example = Example.parseFrom(key.copyBytes()); // copyBytes() trims to the valid length; getBytes() may return a padded buffer
Map<String, Feature> featureMap = example.getFeatures().getFeatureMap();
byte[] text = featureMap.get("text").getBytesList().getValue(0).toByteArray();
context.write(NullWritable.get(), new Text(text));
}
}

public static boolean convert(String jobName,
Class<? extends Mapper> mapperClass,
Class<? extends Writable> outputKeyClass,
Class<? extends Writable> outputValueClass,
Class<? extends InputFormat> inFormatClass,
Class<? extends OutputFormat> outFormatClass,
Path input,
Path output) throws InterruptedException, IOException, ClassNotFoundException {
Configuration conf = new Configuration();
Job job = Job.getInstance(conf, jobName);
job.setJarByClass(mapperClass);
job.setMapperClass(mapperClass);
job.setNumReduceTasks(0);

job.setInputFormatClass(inFormatClass);
job.setOutputFormatClass(outFormatClass);
job.setOutputKeyClass(outputKeyClass);
job.setOutputValueClass(outputValueClass);

final FileSystem fs = FileSystem.get(output.toUri(), conf);
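// Remove any stale output from a previous run; FileOutputFormat refuses to overwrite an existing directory.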
fs.delete(output, true);
FileInputFormat.addInputPath(job, input);
FileOutputFormat.setOutputPath(job, output);

return job.waitForCompletion(true);
}

public static void main(String[] args) throws Exception {
String testRoot = "/tmp/tfrecord-file-test";
if (args.length == 1) {
testRoot = args[0];
} else if (args.length > 1) {
System.out.println("Usage: TFRecordFileMRExample [path]");
}

Path testRootPath = new Path(testRoot);
Path input = new Path(testRootPath, "input.txt");
Path tfrout = new Path(testRootPath, "output.tfr");
Path txtout = new Path(testRootPath, "output.txt");

convert("ToTFR", ToTFRecordMapper.class, BytesWritable.class, NullWritable.class,
TextInputFormat.class, TFRecordFileOutputFormat.class, input, tfrout);
convert("FromTFR", FromTFRecordMapper.class, NullWritable.class, Text.class,
TFRecordFileInputFormat.class, TextOutputFormat.class, tfrout, txtout);
}
}
85 changes: 85 additions & 0 deletions hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileInputFormat.java
@@ -0,0 +1,85 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.hadoop.io;

import org.tensorflow.hadoop.util.TFRecordReader;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;

import java.io.IOException;

public class TFRecordFileInputFormat extends FileInputFormat<BytesWritable, NullWritable> {
@Override public RecordReader<BytesWritable, NullWritable> createRecordReader(
InputSplit inputSplit, final TaskAttemptContext context) throws IOException, InterruptedException {

return new RecordReader<BytesWritable, NullWritable>() {
private FSDataInputStream fsdis;
private TFRecordReader reader;
private long length;
private long begin;
private byte[] current;

@Override public void initialize(InputSplit split, TaskAttemptContext context)
throws IOException, InterruptedException {
Configuration conf = context.getConfiguration();
FileSplit fileSplit = (FileSplit) split;
length = fileSplit.getLength();
begin = fileSplit.getStart();

final Path file = fileSplit.getPath();
FileSystem fs = file.getFileSystem(conf);
fsdis = fs.open(file, TFRecordIOConf.getBufferSize(conf));
reader = new TFRecordReader(fsdis, TFRecordIOConf.getDoCrc32Check(conf));
}

@Override public boolean nextKeyValue() throws IOException, InterruptedException {
current = reader.read();
return current != null;
}

@Override public BytesWritable getCurrentKey() throws IOException, InterruptedException {
return new BytesWritable(current);
}

@Override public NullWritable getCurrentValue() throws IOException, InterruptedException {
return NullWritable.get();
}

@Override public float getProgress() throws IOException, InterruptedException {
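// Fraction of this split already consumed; the epsilon avoids division by zero on empty splits.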
return (fsdis.getPos() - begin) / (length + 1e-6f);
}

@Override public void close() throws IOException {
fsdis.close();
}
};
}

@Override
protected boolean isSplitable(JobContext context, Path file) {
return false;
}
}
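
A note on the two overrides above: `TFRecordReader` (not shown in this diff) parses TensorFlow's published record framing, and because that framing contains no sync markers, a reader cannot begin in the middle of a file, which is why `isSplitable` returns false. Below is a minimal sketch of the framing, assuming the standard TFRecord layout; the bundled reader additionally verifies the masked CRC32 fields when `getDoCrc32Check` is enabled.

```java
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

class TFRecordFramingSketch {
  // Each TFRecord on disk is:
  //   uint64 length (little-endian) | uint32 masked_crc32(length) | data[length] | uint32 masked_crc32(data)
  /** Returns the payload of one record, or null at a clean end of stream. CRCs are skipped, not verified. */
  static byte[] readRecord(DataInputStream in) throws IOException {
    byte[] header = new byte[8];
    try {
      in.readFully(header);
    } catch (EOFException eof) {
      return null; // no more records
    }
    long length = ByteBuffer.wrap(header).order(ByteOrder.LITTLE_ENDIAN).getLong();
    in.readFully(new byte[4]);               // masked CRC32 of the length field (ignored here)
    byte[] data = new byte[(int) length];    // assumes the record fits in an int-sized buffer
    in.readFully(data);
    in.readFully(new byte[4]);               // masked CRC32 of the payload (ignored here)
    return data;
  }
}
```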