Add support for SequenceExample to Spark TensorFlow connector (tensorflow#52)

* Decode and encode SequenceExample feature lists to Spark SQL rows

* Add option to specify recordType when writing DataFrames

* Add unit tests for SequenceExample and fix sql datatype for BytesList

DataType for BytesList changed from StringType to ArrayType(StringType)

* Add tests for sequence example

* Update tests and README example for importing TensorFlow SequenceExample

* Update README for Spark TensorFlow connector

* Update recordType in README for Spark TensorFlow connector

* Fix handling of integers in SequenceExample export

* Add schema inference for SequenceExamples in Spark
skavulya authored and jhseu committed Jun 29, 2017
1 parent c5cf920 commit 05ac88b
Showing 22 changed files with 1,639 additions and 667 deletions.
86 changes: 77 additions & 9 deletions spark/spark-tensorflow-connector/README.md
@@ -53,39 +53,107 @@ SBT Jars
$SPARK_HOME/bin/spark-shell --jars target/scala-2.11/spark-tensorflow-connector-assembly-1.0.0.jar
```

## Features
This library allows reading TensorFlow records from a local or distributed filesystem into [Spark DataFrames](https://spark.apache.org/docs/latest/sql-programming-guide.html).
When reading TensorFlow records into a Spark DataFrame, the API accepts several options:
* `load`: input path to the TensorFlow records. As with other Spark data sources, the path can include standard Hadoop globbing expressions.
* `schema`: optional schema of the TensorFlow records, defined as a Spark StructType. If not provided, the schema is inferred from the records.
* `recordType`: input format of the TensorFlow records; defaults to `Example`. Possible values are:
  * `Example`: TensorFlow [Example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
  * `SequenceExample`: TensorFlow [SequenceExample](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records

When writing a Spark DataFrame to TensorFlow records, the API accepts several options:
* `save`: output path for the TensorFlow records, on a local or distributed filesystem.
* `recordType`: output format of the TensorFlow records; defaults to `Example`. Possible values are:
  * `Example`: TensorFlow [Example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
  * `SequenceExample`: TensorFlow [SequenceExample](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
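
A minimal sketch of a write/read round trip using these options; the DataFrame `df` (assumed to have columns compatible with SequenceExample feature lists) and the output path are placeholders, not part of the connector's API:

```scala
// Sketch only: `df` and "seq-output.tfrecord" are assumed placeholders.
// Write the DataFrame as SequenceExample records...
df.write.format("tfrecords").option("recordType", "SequenceExample").save("seq-output.tfrecord")

// ...and read it back with the matching recordType.
val readBack = spark.read.format("tfrecords").option("recordType", "SequenceExample").load("seq-output.tfrecord")
readBack.show()
```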

## Schema inference
This library supports automatic schema inference when reading TensorFlow records into Spark DataFrames.
Schema inference is expensive since it requires an extra pass through the data.

The schema inference rules are described in the table below:

| TFRecordType | Feature Type | Inferred Spark Data Type |
| ------------------------ |:--------------|:--------------------------|
| Example, SequenceExample | Int64List | LongType if all lists have length=1, else ArrayType(LongType) |
| Example, SequenceExample | FloatList | FloatType if all lists have length=1, else ArrayType(FloatType) |
| Example, SequenceExample | BytesList | StringType if all lists have length=1, else ArrayType(StringType) |
| SequenceExample | FeatureList of Int64List | ArrayType(ArrayType(LongType)) |
| SequenceExample | FeatureList of FloatList | ArrayType(ArrayType(FloatType)) |
| SequenceExample | FeatureList of BytesList | ArrayType(ArrayType(StringType)) |
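
As a quick illustration of schema inference (a hedged sketch; the input path is a placeholder):

```scala
// Sketch only: "data.tfrecord" is an assumed placeholder path holding Example records.
// No .schema(...) is supplied, so column types are inferred using the rules in the
// table above, e.g. a scalar Int64List feature becomes LongType.
val inferredDf = spark.read.format("tfrecords").option("recordType", "Example").load("data.tfrecord")
inferredDf.printSchema()
```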

## Usage Examples

The following code snippet demonstrates usage on test data.

```scala
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

val path = "test-output.tfr"
val path = "test-output.tfrecord"
val testRows: Array[Row] = Array(
new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
val schema = StructType(List(StructField("id", IntegerType),
StructField("IntegerTypelabel", IntegerType),
StructField("LongTypelabel", LongType),
StructField("FloatTypelabel", FloatType),
StructField("DoubleTypelabel", DoubleType),
StructField("vectorlabel", ArrayType(DoubleType, true)),
StructField("IntegerTypeLabel", IntegerType),
StructField("LongTypeLabel", LongType),
StructField("FloatTypeLabel", FloatType),
StructField("DoubleTypeLabel", DoubleType),
StructField("VectorLabel", ArrayType(DoubleType, true)),
StructField("name", StringType)))

val rdd = spark.sparkContext.parallelize(testRows)

//Save DataFrame as TFRecords
val df: DataFrame = spark.createDataFrame(rdd, schema)
df.write.format("tfrecords").save(path)
df.write.format("tfrecords").option("recordType", "Example").save(path)

//Read TFRecords into DataFrame.
//The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
val importedDf1: DataFrame = spark.read.format("tfrecords").load(path)
val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load(path)
importedDf1.show()

//Read TFRecords into DataFrame using custom schema
val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
importedDf2.show()
```

#### Loading the YouTube-8M dataset into Spark
Here's how to import the [YouTube-8M](https://research.google.com/youtube8m/) dataset into a Spark DataFrame.

```sh
curl http://us.data.yt8m.org/1/video_level/train/train-0.tfrecord > /tmp/video_level-train-0.tfrecord
curl http://us.data.yt8m.org/1/frame_level/train/train-0.tfrecord > /tmp/frame_level-train-0.tfrecord
```

```scala
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

//Import Video-level Example dataset into DataFrame
val videoSchema = StructType(List(StructField("video_id", StringType),
StructField("labels", ArrayType(IntegerType, true)),
StructField("mean_rgb", ArrayType(FloatType, true)),
StructField("mean_audio", ArrayType(FloatType, true))))
val videoDf: DataFrame = spark.read.format("tfrecords").schema(videoSchema).option("recordType", "Example").load("file:///tmp/video_level-train-0.tfrecord")
videoDf.show()
videoDf.write.format("tfrecords").option("recordType", "Example").save("youtube-8m-video.tfrecord")
val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").schema(videoSchema).load("youtube-8m-video.tfrecord")
importedDf1.show()

//Import Frame-level SequenceExample dataset into DataFrame
val frameSchema = StructType(List(StructField("video_id", StringType),
StructField("labels", ArrayType(IntegerType, true)),
StructField("rgb", ArrayType(ArrayType(StringType, true),true)),
StructField("audio", ArrayType(ArrayType(StringType, true),true))))
val frameDf: DataFrame = spark.read.format("tfrecords").schema(frameSchema).option("recordType", "SequenceExample").load("file:///tmp/frame_level-train-0.tfrecord")
frameDf.show()
frameDf.write.format("tfrecords").option("recordType", "SequenceExample").save("youtube-8m-frame.tfrecord")
val importedDf2: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").schema(frameSchema).load("youtube-8m-frame.tfrecord")
importedDf2.show()
```


@@ -44,23 +44,34 @@ class DefaultSource extends DataSourceRegister

val path = parameters("path")

val recordType = parameters.getOrElse("recordType", "Example")

//Export DataFrame as TFRecords
val features = data.rdd.map(row => {
  recordType match {
    case "Example" =>
      val example = DefaultTfRecordRowEncoder.encodeExample(row)
      (new BytesWritable(example.toByteArray), NullWritable.get())
    case "SequenceExample" =>
      val sequenceExample = DefaultTfRecordRowEncoder.encodeSequenceExample(row)
      (new BytesWritable(sequenceExample.toByteArray), NullWritable.get())
    case _ =>
      throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be Example or SequenceExample")
  }
})
features.saveAsNewAPIHadoopFile[TFRecordFileOutputFormat](path)

TensorflowRelation(parameters)(sqlContext.sparkSession)
}

// Reads TensorFlow Records into DataFrame with Custom Schema
override def createRelation(sqlContext: SQLContext,
parameters: Map[String, String],
schema: StructType): BaseRelation = {
TensorflowRelation(parameters, Some(schema))(sqlContext.sparkSession)
}

// Reads TensorFlow Records into DataFrame with schema inferred
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): TensorflowRelation = {
TensorflowRelation(parameters)(sqlContext.sparkSession)
}
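
For reference, a hedged sketch of how these two read entry points are reached from user code (the path and `customSchema` are placeholders): supplying a schema routes Spark to the schema-accepting overload, while omitting it triggers the inference path.

```scala
// Sketch only: "input.tfrecord" and customSchema are assumed placeholders.
// With an explicit schema, Spark calls the createRelation overload that accepts a StructType.
val withSchema = spark.read.format("tfrecords").schema(customSchema).load("input.tfrecord")

// Without a schema, Spark calls the other overload and the schema is inferred from the records.
val inferred = spark.read.format("tfrecords").load("input.tfrecord")
```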