forked from tensorflow/ecosystem
Commit: Add Hadoop InputFormat/OutputFormat for TFRecord feature file (tensor…
Showing 12 changed files with 829 additions and 0 deletions.
@@ -0,0 +1,80 @@
# Hadoop MapReduce InputFormat/OutputFormat for TFRecord file

This directory contains an [Apache Hadoop](http://hadoop.apache.org/) MapReduce
InputFormat/OutputFormat implementation for TensorFlow TFRecord feature files.
It can also be used with [Apache Spark](http://spark.apache.org/).

## Prerequisites

1. [protoc 3.1.0](https://developers.google.com/protocol-buffers/) must be
   installed.

2. [Apache Maven](https://maven.apache.org/) must be installed.

3. This is tested with Hadoop 2.6.0, so you should have a Hadoop 2.6.0 (or
   compatible) YARN/Spark cluster.

## Build and install

1. Compile the TensorFlow Example protos:

    ```sh
    # Suppose $TF_SRC_ROOT is the source code root of the tensorflow project
    protoc --proto_path=$TF_SRC_ROOT --java_out=src/main/java/ $TF_SRC_ROOT/tensorflow/core/example/{example,feature}.proto
    ```

2. Compile the code:

    ```sh
    mvn clean package
    ```

3. Optionally, install (or deploy) the jars:

    ```sh
    mvn install
    ```

    Once installed (or deployed), the package can be used with the following dependency:

    ```xml
    <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>tensorflow-hadoop</artifactId>
      <version>1.0-SNAPSHOT</version>
    </dependency>
    ```

## Use with MapReduce
The Hadoop MapReduce example can be found [here](src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java).

## Use with Spark
Spark supports reading/writing files with Hadoop InputFormat/OutputFormat; the
following snippet demonstrates writing TFRecords from Spark.

```scala
import com.google.protobuf.ByteString
import org.apache.hadoop.io.{NullWritable, BytesWritable}
import org.apache.spark.{SparkConf, SparkContext}
import org.tensorflow.example.{BytesList, Int64List, Feature, Features, Example}
import org.tensorflow.hadoop.io.TFRecordFileOutputFormat

val inputPath = "path/to/input.txt"
val outputPath = "path/to/output.tfr"

val sparkConf = new SparkConf().setAppName("TFRecord Demo")
val sc = new SparkContext(sparkConf)

val features = sc.textFile(inputPath).map(line => {
  val text = BytesList.newBuilder().addValue(ByteString.copyFrom(line.getBytes)).build()
  val features = Features.newBuilder()
    .putFeature("text", Feature.newBuilder().setBytesList(text).build())
    .build()
  val example = Example.newBuilder()
    .setFeatures(features)
    .build()
  (new BytesWritable(example.toByteArray), NullWritable.get())
})
features.saveAsNewAPIHadoopFile[TFRecordFileOutputFormat](outputPath)
```
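
Reading the records back uses `TFRecordFileInputFormat` the same way. The following is a minimal sketch (not part of this commit) that recovers the `text` feature written above, assuming the same `sc` and `outputPath` as in the previous snippet:

```scala
import org.apache.hadoop.io.{BytesWritable, NullWritable}
import org.tensorflow.example.Example
import org.tensorflow.hadoop.io.TFRecordFileInputFormat

// Each record carries one serialized Example in the key; the value is NullWritable.
val records = sc.newAPIHadoopFile[BytesWritable, NullWritable, TFRecordFileInputFormat](outputPath)
val texts = records.map { case (bytes, _) =>
  // copyBytes() trims the BytesWritable backing array to its valid length
  // before it is handed to the protobuf parser.
  Example.parseFrom(bytes.copyBytes())
    .getFeatures.getFeatureMap.get("text")
    .getBytesList.getValue(0).toStringUtf8
}
texts.collect().foreach(println)
```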
@@ -0,0 +1,70 @@
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-hadoop</artifactId>
    <packaging>jar</packaging>
    <version>1.0-SNAPSHOT</version>
    <name>tensorflow-hadoop</name>
    <url>https://www.tensorflow.org</url>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <maven.compiler.source>1.6</maven.compiler.source>
        <maven.compiler.target>1.6</maven.compiler.target>
        <hadoop.version>2.6.0</hadoop.version>
        <protobuf.version>3.1.0</protobuf.version>
        <junit.version>4.11</junit.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>${hadoop.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>com.google.protobuf</groupId>
                    <artifactId>protobuf-java</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-mapreduce-client-core</artifactId>
            <version>${hadoop.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>com.google.protobuf</groupId>
                    <artifactId>protobuf-java</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>com.google.protobuf</groupId>
            <artifactId>protobuf-java</artifactId>
            <version>${protobuf.version}</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>${junit.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-mapreduce-client-jobclient</artifactId>
            <version>${hadoop.version}</version>
            <type>test-jar</type>
            <optional>true</optional>
            <scope>test</scope>
            <exclusions>
                <exclusion>
                    <groupId>com.google.protobuf</groupId>
                    <artifactId>protobuf-java</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
    </dependencies>
</project>
hadoop/src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java
124 changes: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.hadoop.example;

import com.google.protobuf.ByteString;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.OutputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.tensorflow.example.*;
import org.tensorflow.hadoop.io.TFRecordFileInputFormat;
import org.tensorflow.hadoop.io.TFRecordFileOutputFormat;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

import java.io.IOException;
import java.util.Map;

public class TFRecordFileMRExample {
  /**
   * Convert from text file to TFRecord file. Each line is converted into two dummy features: the
   * content of each line and the starting offset of each line.
   */
  static class ToTFRecordMapper extends Mapper<LongWritable, Text, BytesWritable, NullWritable> {
    ToTFRecordMapper() {}

    @Override protected void map(LongWritable key, Text value,
                                 Context context) throws IOException, InterruptedException {
      Int64List int64List = Int64List.newBuilder().addValue(key.get()).build();
      Feature offset = Feature.newBuilder().setInt64List(int64List).build();

      ByteString byteString = ByteString.copyFrom(value.copyBytes());
      BytesList bytesList = BytesList.newBuilder().addValue(byteString).build();
      Feature text = Feature.newBuilder().setBytesList(bytesList).build();

      Features features = Features.newBuilder()
          .putFeature("offset", offset)
          .putFeature("text", text)
          .build();
      Example example = Example.newBuilder().setFeatures(features).build();
      context.write(new BytesWritable(example.toByteArray()), NullWritable.get());
    }
  }

  /**
   * Convert from previous TFRecord file to text file.
   */
  static class FromTFRecordMapper extends Mapper<BytesWritable, NullWritable, NullWritable, Text> {
    FromTFRecordMapper() {}

    @Override protected void map(BytesWritable key, NullWritable value,
                                 Context context) throws IOException, InterruptedException {
      Example example = Example.parseFrom(key.getBytes());
      Map<String, Feature> featureMap = example.getFeatures().getFeatureMap();
      byte[] text = featureMap.get("text").getBytesList().getValue(0).toByteArray();
      context.write(NullWritable.get(), new Text(text));
    }
  }

  public static boolean convert(String jobName,
                                Class<? extends Mapper> mapperClass,
                                Class<? extends Writable> outputKeyClass,
                                Class<? extends Writable> outputValueClass,
                                Class<? extends InputFormat> inFormatClass,
                                Class<? extends OutputFormat> outFormatClass,
                                Path input,
                                Path output) throws InterruptedException, IOException, ClassNotFoundException {
    Configuration conf = new Configuration();
    Job job = Job.getInstance(conf, jobName);
    job.setJarByClass(mapperClass);
    job.setMapperClass(mapperClass);
    job.setNumReduceTasks(0);

    job.setInputFormatClass(inFormatClass);
    job.setOutputFormatClass(outFormatClass);
    job.setOutputKeyClass(outputKeyClass);
    job.setOutputValueClass(outputValueClass);

    final FileSystem fs = FileSystem.get(output.toUri(), conf);
    fs.delete(output, true);
    FileInputFormat.addInputPath(job, input);
    FileOutputFormat.setOutputPath(job, output);

    return job.waitForCompletion(true);
  }

  public static void main(String[] args) throws Exception {
    String testRoot = "/tmp/tfrecord-file-test";
    if (args.length == 1) {
      testRoot = args[0];
    } else if (args.length > 1) {
      System.out.println("Usage: TFRecordFileMRExample [path]");
    }

    Path testRootPath = new Path(testRoot);
    Path input = new Path(testRootPath, "input.txt");
    Path tfrout = new Path(testRootPath, "output.tfr");
    Path txtout = new Path(testRootPath, "output.txt");

    convert("ToTFR", ToTFRecordMapper.class, BytesWritable.class, NullWritable.class,
        TextInputFormat.class, TFRecordFileOutputFormat.class, input, tfrout);
    convert("FromTFR", FromTFRecordMapper.class, NullWritable.class, Text.class,
        TFRecordFileInputFormat.class, TextOutputFormat.class, tfrout, txtout);
  }
}
hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileInputFormat.java
85 changes: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.hadoop.io;

import org.tensorflow.hadoop.util.TFRecordReader;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;

import java.io.IOException;

public class TFRecordFileInputFormat extends FileInputFormat<BytesWritable, NullWritable> {
  @Override public RecordReader<BytesWritable, NullWritable> createRecordReader(
      InputSplit inputSplit, final TaskAttemptContext context) throws IOException, InterruptedException {

    return new RecordReader<BytesWritable, NullWritable>() {
      private FSDataInputStream fsdis;
      private TFRecordReader reader;
      private long length;
      private long begin;
      private byte[] current;

      @Override public void initialize(InputSplit split, TaskAttemptContext context)
          throws IOException, InterruptedException {
        Configuration conf = context.getConfiguration();
        FileSplit fileSplit = (FileSplit) split;
        length = fileSplit.getLength();
        begin = fileSplit.getStart();

        final Path file = fileSplit.getPath();
        FileSystem fs = file.getFileSystem(conf);
        fsdis = fs.open(file, TFRecordIOConf.getBufferSize(conf));
        reader = new TFRecordReader(fsdis, TFRecordIOConf.getDoCrc32Check(conf));
      }

      @Override public boolean nextKeyValue() throws IOException, InterruptedException {
        current = reader.read();
        return current != null;
      }

      @Override public BytesWritable getCurrentKey() throws IOException, InterruptedException {
        return new BytesWritable(current);
      }

      @Override public NullWritable getCurrentValue() throws IOException, InterruptedException {
        return NullWritable.get();
      }

      @Override public float getProgress() throws IOException, InterruptedException {
        return (fsdis.getPos() - begin) / (length + 1e-6f);
      }

      @Override public void close() throws IOException {
        fsdis.close();
      }
    };
  }

  @Override
  protected boolean isSplitable(JobContext context, Path file) {
    return false;
  }
}
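
The `TFRecordReader` and `TFRecordIOConf` helpers used above live in files of this commit that are not shown in this excerpt. For orientation, the sketch below shows how a reader for the standard TFRecord framing can be written in a few lines: each record on disk is a little-endian uint64 length, a masked CRC32C of the length bytes, the payload, and a masked CRC32C of the payload. This is an illustrative reimplementation under those assumptions; `TFRecordWireFormat` and its `read` method are hypothetical names, not the commit's actual `TFRecordReader`.

```scala
import java.io.{DataInputStream, EOFException, IOException}
import java.nio.{ByteBuffer, ByteOrder}
import org.apache.hadoop.util.PureJavaCrc32C

object TFRecordWireFormat {
  // TFRecord stores CRCs "masked" so that a CRC computed over data that itself
  // contains CRCs does not collide; the mask delta constant is 0xa282ead8.
  private def maskedCrc32c(bytes: Array[Byte]): Int = {
    val crc = new PureJavaCrc32C
    crc.update(bytes, 0, bytes.length)
    val c = crc.getValue.toInt
    ((c >>> 15) | (c << 17)) + 0xa282ead8
  }

  /** Reads one record's payload, or returns null at a clean end of stream. */
  def read(in: DataInputStream, doCrc32Check: Boolean): Array[Byte] = {
    val lenBytes = new Array[Byte](8)
    try in.readFully(lenBytes) catch { case _: EOFException => return null }

    val len = ByteBuffer.wrap(lenBytes).order(ByteOrder.LITTLE_ENDIAN).getLong
    // DataInputStream reads big-endian; the file stores little-endian.
    val lenCrc = Integer.reverseBytes(in.readInt())
    if (doCrc32Check && lenCrc != maskedCrc32c(lenBytes))
      throw new IOException("Length CRC mismatch")

    val data = new Array[Byte](len.toInt) // assumes records under 2 GB
    in.readFully(data)
    val dataCrc = Integer.reverseBytes(in.readInt())
    if (doCrc32Check && dataCrc != maskedCrc32c(data))
      throw new IOException("Data CRC mismatch")
    data
  }
}
```

Design note: this framing has no sync markers, so a reader cannot resynchronize from an arbitrary byte offset; that is why `isSplitable` above returns false and each TFRecord file is consumed by a single task.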