diff --git a/README.md b/README.md
index 9bc3afaf..458678fe 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@ request.
 Marathon, deployed on top of Mesos.
 - [hadoop](hadoop) - TFRecord file InputFormat/OutputFormat for Hadoop MapReduce
 and Spark.
+- [spark](spark) - Spark TensorFlow Connector
 
 ## Distributed TensorFlow
diff --git a/spark/README.md b/spark/README.md
new file mode 100644
index 00000000..32a8c10e
--- /dev/null
+++ b/spark/README.md
@@ -0,0 +1,85 @@
+# spark-tensorflow-connector
+
+This package contains a library for loading and storing TensorFlow records with [Apache Spark](http://spark.apache.org/).
+The library implements data import from the standard TensorFlow record format ([TFRecords](https://www.tensorflow.org/how_tos/reading_data/))
+into Spark SQL DataFrames, and data export from DataFrames to TensorFlow records.
+
+## What's new
+
+This is the initial release of `spark-tensorflow-connector`.
+
+## Known issues
+
+None.
+
+## Prerequisites
+
+1. [Apache Spark 2.0 (or later)](http://spark.apache.org/)
+
+2. [Apache Maven](https://maven.apache.org/)
+
+## Building the library
+You can build the library with either Maven or SBT.
+
+#### Maven
+Build the library using Maven 3.3, as shown below:
+
+```sh
+mvn clean install
+```
+
+#### SBT
+Build the library using SBT 0.13.13, as shown below:
+```sh
+sbt clean assembly
+```
+
+## Using Spark Shell
+Run this library in Spark using the `--jars` command line option in `spark-shell` or `spark-submit`. For example:
+
+Jars built with Maven:
+```sh
+$SPARK_HOME/bin/spark-shell --jars target/spark-tensorflow-connector-1.0-SNAPSHOT.jar,target/lib/tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar
+```
+
+Jar built with SBT:
+```sh
+$SPARK_HOME/bin/spark-shell --jars target/scala-2.11/spark-tensorflow-connector-assembly-1.0.0.jar
+```
+
+The following code snippet demonstrates usage.
+
+```scala
+import org.apache.spark.sql.{ DataFrame, Row }
+import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.types._
+
+val path = "test-output.tfr"
+val testRows: Array[Row] = Array(
+  new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
+  new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
+val schema = StructType(List(StructField("id", IntegerType),
+ StructField("IntegerTypelabel", IntegerType),
+ StructField("LongTypelabel", LongType),
+ StructField("FloatTypelabel", FloatType),
+ StructField("DoubleTypelabel", DoubleType),
+ StructField("vectorlabel", ArrayType(DoubleType, true)),
+ StructField("name", StringType)))
+
+val rdd = spark.sparkContext.parallelize(testRows)
+
+//Save DataFrame as TFRecords
+val df: DataFrame = spark.createDataFrame(rdd, schema)
+df.write.format("tfrecords").save(path)
+
+//Read TFRecords into DataFrame.
+//The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
+val importedDf1: DataFrame = spark.read.format("tfrecords").load(path)
+importedDf1.show()
+
+//Read TFRecords into DataFrame using custom schema
+val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
+importedDf2.show()
+
+```
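+On write, column types map to TensorFlow feature types: integer and long
+columns are encoded as `Int64List`, float and double columns (and arrays of
+them) as `FloatList`, and everything else, including strings, as `BytesList`.
+As a quick sanity check from the shell (a minimal sketch, assuming the jars
+above are on the classpath and `df` is the DataFrame from the snippet), you
+can encode a single row with the same encoder that backs the writer:
+
+```scala
+import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowEncoder
+
+// Encode one Row to a TensorFlow Example and print its feature map;
+// each column appears as an Int64List, FloatList, or BytesList feature.
+val example = DefaultTfRecordRowEncoder.encodeTfRecord(df.head)
+println(example.getFeatures.getFeatureMap)
+```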
diff --git a/spark/build.sbt b/spark/build.sbt
new file mode 100644
index 00000000..3ec7bbde
--- /dev/null
+++ b/spark/build.sbt
@@ -0,0 +1,93 @@
+name := "spark-tensorflow-connector"
+
+organization := "org.tensorflow"
+
+scalaVersion in Global := "2.11.8"
+
+spName := "tensorflow/spark-tensorflow-connector"
+
+sparkVersion := "2.1.0"
+
+sparkComponents ++= Seq("sql", "mllib")
+
+version := "1.0.0"
+
+def ProjectName(name: String, path: String): Project = Project(name, file(path))
+
+resolvers in Global ++= Seq(
+  "tap" at "https://tap.jfrog.io/tap/public",
+  "tap-snapshots" at "https://tap.jfrog.io/tap/public-snapshots",
+  "maven-central" at "https://repo.maven.apache.org/maven2")
+
+val `junit_junit` = "junit" % "junit" % "4.12"
+
+val `org.apache.hadoop_hadoop-yarn-api` = "org.apache.hadoop" % "hadoop-yarn-api" % "2.7.3"
+
+val `org.apache.spark_spark-core_2.11` = "org.apache.spark" % "spark-core_2.11" % "2.1.0"
+
+val `org.apache.spark_spark-sql_2.11` = "org.apache.spark" % "spark-sql_2.11" % "2.1.0"
+
+val `org.apache.spark_spark-mllib_2.11` = "org.apache.spark" % "spark-mllib_2.11" % "2.1.0"
+
+val `org.scalatest_scalatest_2.11` = "org.scalatest" % "scalatest_2.11" % "2.2.6"
+
+val `org.tensorflow_tensorflow-hadoop` = "org.tensorflow" % "tensorflow-hadoop" % "1.0-01232017-SNAPSHOT"
+
+libraryDependencies in Global ++= Seq(
+  `org.tensorflow_tensorflow-hadoop` classifier "shaded-protobuf",
+  `org.scalatest_scalatest_2.11` % "test",
+  `org.apache.spark_spark-sql_2.11` % "provided",
+  `org.apache.spark_spark-mllib_2.11` % "test" classifier "tests",
+  `org.apache.spark_spark-core_2.11` % "provided",
+  `org.apache.hadoop_hadoop-yarn-api` % "provided",
+  `junit_junit` % "test")
+
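+// Keep only the connector jar and the shaded tensorflow-hadoop jar in the
+// assembly; every other classpath entry is excluded from the fat jar.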
+assemblyExcludedJars in assembly := {
+ val cp = (fullClasspath in assembly).value
+ cp filterNot {x => List("spark-tensorflow-connector-1.0-SNAPSHOT.jar",
+ "tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar").contains(x.data.getName)}
+}
+
+/********************
+ * Release settings *
+ ********************/
+
+spIgnoreProvided := true
+
+spAppendScalaVersion := true
+
+// If you published your package to Maven Central for this release (must be done prior to spPublish)
+spIncludeMaven := false
+
+publishMavenStyle := true
+
+licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0"))
+
+pomExtra :=
+  <url>https://github.com/tensorflow/ecosystem</url>
+  <scm>
+    <url>git@github.com:tensorflow/ecosystem.git</url>
+    <connection>scm:git:git@github.com:tensorflow/ecosystem.git</connection>
+  </scm>
+  <developers>
+    <developer>
+      <id>karthikvadla</id>
+      <name>Karthik Vadla</name>
+      <url>https://github.com/karthikvadla</url>
+    </developer>
+    <developer>
+      <id>skavulya</id>
+      <name>Soila Kavulya</name>
+      <url>https://github.com/skavulya</url>
+    </developer>
+    <developer>
+      <id>joyeshmishra</id>
+      <name>Joyesh Mishra</name>
+      <url>https://github.com/joyeshmishra</url>
+    </developer>
+  </developers>
+
+credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials") // A file containing credentials
+
+// Add assembly jar to Spark package
+test in assembly := {}
+
+spShade := true
diff --git a/spark/pom.xml b/spark/pom.xml
new file mode 100644
index 00000000..82b45929
--- /dev/null
+++ b/spark/pom.xml
@@ -0,0 +1,306 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+
+  <groupId>org.tensorflow</groupId>
+  <artifactId>spark-tensorflow-connector</artifactId>
+  <packaging>jar</packaging>
+  <version>1.0-SNAPSHOT</version>
+
+  <repositories>
+    <repository>
+      <id>central1</id>
+      <url>http://central1.maven.org/maven2</url>
+      <releases>
+        <enabled>true</enabled>
+      </releases>
+      <snapshots>
+        <enabled>false</enabled>
+      </snapshots>
+    </repository>
+    <repository>
+      <id>tap</id>
+      <url>https://tap.jfrog.io/tap/public</url>
+      <snapshots>
+        <enabled>false</enabled>
+      </snapshots>
+      <releases>
+        <enabled>true</enabled>
+      </releases>
+    </repository>
+    <repository>
+      <id>tap-snapshots</id>
+      <url>https://tap.jfrog.io/tap/public-snapshots</url>
+      <snapshots>
+        <enabled>true</enabled>
+      </snapshots>
+      <releases>
+        <enabled>false</enabled>
+      </releases>
+    </repository>
+  </repositories>
+
+  <profiles>
+    <profile>
+      <id>compile</id>
+      <activation>
+        <activeByDefault>true</activeByDefault>
+        <property>
+          <name>!NEVERSETME</name>
+        </property>
+      </activation>
+      <build>
+        <plugins>
+          <plugin>
+            <inherited>true</inherited>
+            <groupId>net.alchim31.maven</groupId>
+            <artifactId>scala-maven-plugin</artifactId>
+            <version>3.1.6</version>
+            <executions>
+              <execution>
+                <id>compile</id>
+                <goals>
+                  <goal>add-source</goal>
+                  <goal>compile</goal>
+                </goals>
+                <configuration>
+                  <jvmArgs>
+                    <jvmArg>-Xms256m</jvmArg>
+                    <jvmArg>-Xmx512m</jvmArg>
+                  </jvmArgs>
+                  <args>
+                    <arg>-g:vars</arg>
+                    <arg>-deprecation</arg>
+                    <arg>-feature</arg>
+                    <arg>-unchecked</arg>
+                    <arg>-Xfatal-warnings</arg>
+                    <arg>-language:implicitConversions</arg>
+                    <arg>-language:existentials</arg>
+                  </args>
+                </configuration>
+              </execution>
+              <execution>
+                <id>test</id>
+                <goals>
+                  <goal>add-source</goal>
+                  <goal>testCompile</goal>
+                </goals>
+              </execution>
+            </executions>
+            <configuration>
+              <recompileMode>incremental</recompileMode>
+              <useZincServer>true</useZincServer>
+              <scalaCompatVersion>2.11</scalaCompatVersion>
+              <checkMultipleScalaVersions>false</checkMultipleScalaVersions>
+            </configuration>
+          </plugin>
+          <plugin>
+            <groupId>org.apache.maven.plugins</groupId>
+            <artifactId>maven-dependency-plugin</artifactId>
+            <executions>
+              <execution>
+                <id>copy-dependencies</id>
+                <phase>process-resources</phase>
+                <goals>
+                  <goal>copy-dependencies</goal>
+                </goals>
+                <configuration>
+                  <excludeScope>provided</excludeScope>
+                  <overWriteIfNewer>true</overWriteIfNewer>
+                  <excludeGroupIds>org.apache.spark,junit,org.scalatest</excludeGroupIds>
+                  <outputDirectory>${project.build.directory}/lib</outputDirectory>
+                </configuration>
+              </execution>
+            </executions>
+          </plugin>
+          <plugin>
+            <groupId>org.codehaus.mojo</groupId>
+            <artifactId>properties-maven-plugin</artifactId>
+            <version>1.0.0</version>
+            <executions>
+              <execution>
+                <phase>generate-resources</phase>
+                <goals>
+                  <goal>write-project-properties</goal>
+                </goals>
+                <configuration>
+                  <outputFile>${project.build.outputDirectory}/maven.properties</outputFile>
+                </configuration>
+              </execution>
+            </executions>
+          </plugin>
+        </plugins>
+        <pluginManagement>
+          <plugins>
+            <plugin>
+              <groupId>net.alchim31.maven</groupId>
+              <artifactId>scala-maven-plugin</artifactId>
+            </plugin>
+          </plugins>
+        </pluginManagement>
+      </build>
+    </profile>
+    <profile>
+      <id>test</id>
+      <activation>
+        <activeByDefault>true</activeByDefault>
+        <property>
+          <name>!NEVERSETME</name>
+        </property>
+      </activation>
+      <build>
+        <plugins>
+          <plugin>
+            <inherited>true</inherited>
+            <groupId>net.alchim31.maven</groupId>
+            <artifactId>scala-maven-plugin</artifactId>
+            <version>3.2.2</version>
+            <executions>
+              <execution>
+                <id>compile</id>
+              </execution>
+            </executions>
+          </plugin>
+          <plugin>
+            <inherited>true</inherited>
+            <groupId>org.scalatest</groupId>
+            <artifactId>scalatest-maven-plugin</artifactId>
+            <version>1.0</version>
+            <configuration>
+              <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
+              <junitxml>.</junitxml>
+              <filereports>WDF TestSuite.txt</filereports>
+              <parallel>false</parallel>
+              <stdout>FTD</stdout>
+              <argLine>-Xmx1024m -XX:PermSize=256m -XX:MaxDirectMemorySize=1000m</argLine>
+            </configuration>
+            <executions>
+              <execution>
+                <id>scalaTest</id>
+                <phase>test</phase>
+                <goals>
+                  <goal>test</goal>
+                </goals>
+              </execution>
+            </executions>
+          </plugin>
+        </plugins>
+        <pluginManagement>
+          <plugins>
+            <plugin>
+              <groupId>net.alchim31.maven</groupId>
+              <artifactId>scala-maven-plugin</artifactId>
+            </plugin>
+          </plugins>
+        </pluginManagement>
+      </build>
+      <dependencyManagement>
+        <dependencies>
+          <dependency>
+            <groupId>org.scalatest</groupId>
+            <artifactId>scalatest_2.11</artifactId>
+            <version>2.2.6</version>
+            <scope>test</scope>
+          </dependency>
+        </dependencies>
+      </dependencyManagement>
+      <dependencies>
+        <dependency>
+          <groupId>org.scalatest</groupId>
+          <artifactId>scalatest_2.11</artifactId>
+          <scope>test</scope>
+        </dependency>
+      </dependencies>
+    </profile>
+  </profiles>
+
+  <build>
+    <plugins>
+      <plugin>
+        <groupId>net.alchim31.maven</groupId>
+        <artifactId>scala-maven-plugin</artifactId>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-dependency-plugin</artifactId>
+      </plugin>
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-compiler-plugin</artifactId>
+        <version>3.0</version>
+        <configuration>
+          <source>1.8</source>
+          <target>1.8</target>
+        </configuration>
+      </plugin>
+    </plugins>
+    <resources>
+      <resource>
+        <directory>core/src/main/resources</directory>
+        <includes>
+          <include>reference.conf</include>
+        </includes>
+      </resource>
+    </resources>
+    <testResources>
+      <testResource>
+        <directory>core/src/test/resources</directory>
+      </testResource>
+    </testResources>
+  </build>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.tensorflow</groupId>
+      <artifactId>tensorflow-hadoop</artifactId>
+      <version>1.0-01232017-SNAPSHOT</version>
+      <classifier>shaded-protobuf</classifier>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_2.11</artifactId>
+      <version>2.1.0</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_2.11</artifactId>
+      <version>2.1.0</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-api</artifactId>
+      <version>2.7.3</version>
+      <scope>provided</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-mllib_2.11</artifactId>
+      <version>2.1.0</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <version>4.12</version>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+</project>
diff --git a/spark/project/build.properties b/spark/project/build.properties
new file mode 100644
index 00000000..5f32afe7
--- /dev/null
+++ b/spark/project/build.properties
@@ -0,0 +1 @@
+sbt.version=0.13.13
\ No newline at end of file
diff --git a/spark/project/plugins.sbt b/spark/project/plugins.sbt
new file mode 100644
index 00000000..0536c57a
--- /dev/null
+++ b/spark/project/plugins.sbt
@@ -0,0 +1,5 @@
+resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/"
+
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3")
+
+addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.5")
diff --git a/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
new file mode 100644
index 00000000..c2869067
--- /dev/null
+++ b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -0,0 +1 @@
+org.tensorflow.spark.datasources.tfrecords.DefaultSource
\ No newline at end of file
diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DataTypesConvertor.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DataTypesConvertor.scala
new file mode 100644
index 00000000..2f76aa97
--- /dev/null
+++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DataTypesConvertor.scala
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords
+
+/**
+ * Converters from supported input types (numeric, BigDecimal, String) to
+ * Long or Float; e.g. toLong("42") returns 42L. Unsupported types raise a
+ * RuntimeException.
+ */
+object DataTypesConvertor {
+
+ def toLong(value: Any): Long = {
+ value match {
+ case null => throw new IllegalArgumentException("null cannot be converted to Long")
+ case i: Int => i.toLong
+ case l: Long => l
+ case f: Float => f.toLong
+ case d: Double => d.toLong
+ case bd: BigDecimal => bd.toLong
+ case s: String => s.trim().toLong
+ case _ => throw new RuntimeException(s"${value.getClass.getName} toLong is not implemented")
+ }
+ }
+
+ def toFloat(value: Any): Float = {
+ value match {
+ case null => throw new IllegalArgumentException("null cannot be converted to Float")
+ case i: Int => i.toFloat
+ case l: Long => l.toFloat
+ case f: Float => f
+ case d: Double => d.toFloat
+ case bd: BigDecimal => bd.toFloat
+ case s: String => s.trim().toFloat
+ case _ => throw new RuntimeException(s"${value.getClass.getName} toFloat is not implemented")
+ }
+ }
+}
+
diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala
new file mode 100644
index 00000000..ba9e6628
--- /dev/null
+++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords
+
+import org.apache.hadoop.io.{BytesWritable, NullWritable}
+import org.apache.spark.sql._
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.StructType
+import org.tensorflow.hadoop.io.TFRecordFileOutputFormat
+import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowEncoder
+
+/**
+ * Provides access to TensorFlow record source
+ */
+class DefaultSource extends DataSourceRegister
+ with CreatableRelationProvider
+ with RelationProvider
+  with SchemaRelationProvider {
+
+ /**
+ * Short alias for spark-tensorflow data source.
+ */
+ override def shortName(): String = "tfrecords"
+
+ // Writes DataFrame as TensorFlow Records
+ override def createRelation(
+ sqlContext: SQLContext,
+ mode: SaveMode,
+ parameters: Map[String, String],
+ data: DataFrame): BaseRelation = {
+
+ val path = parameters("path")
+
+ //Export DataFrame as TFRecords
+ val features = data.rdd.map(row => {
+ val example = DefaultTfRecordRowEncoder.encodeTfRecord(row)
+ (new BytesWritable(example.toByteArray), NullWritable.get())
+ })
+ features.saveAsNewAPIHadoopFile[TFRecordFileOutputFormat](path)
+
+ TensorflowRelation(parameters)(sqlContext.sparkSession)
+ }
+
+ override def createRelation(sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: StructType): BaseRelation = {
+ TensorflowRelation(parameters, Some(schema))(sqlContext.sparkSession)
+ }
+
+ // Reads TensorFlow Records into DataFrame
+ override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): TensorflowRelation = {
+ TensorflowRelation(parameters)(sqlContext.sparkSession)
+ }
+}
diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowInferSchema.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowInferSchema.scala
new file mode 100644
index 00000000..5a108add
--- /dev/null
+++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowInferSchema.scala
@@ -0,0 +1,197 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types._
+import org.tensorflow.example.{Example, Feature}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.Map
+import scala.util.control.Exception._
+
+object TensorflowInferSchema {
+
+ /**
+ * Similar to the JSON schema inference.
+ * [[org.apache.spark.sql.execution.datasources.json.InferSchema]]
+ * 1. Infer type of each row
+ * 2. Merge row types to find common type
+ * 3. Replace any null types with string type
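+   *
+   * For example, a feature seen as Int64List(1) in one record and Int64List(2)
+   * in another is inferred as IntegerType, while a FloatList holding more than
+   * one value is inferred as ArrayType(DoubleType).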
+ */
+ def apply(exampleRdd: RDD[Example]): StructType = {
+ val startType: Map[String, DataType] = Map.empty[String, DataType]
+ val rootTypes: Map[String, DataType] = exampleRdd.aggregate(startType)(inferRowType, mergeFieldTypes)
+ val columnsList = rootTypes.map {
+ case (featureName, featureType) =>
+ if (featureType == null) {
+ StructField(featureName, StringType)
+ }
+ else {
+ StructField(featureName, featureType)
+ }
+ }
+ StructType(columnsList.toSeq)
+ }
+
+ private def inferRowType(schemaSoFar: Map[String, DataType], next: Example): Map[String, DataType] = {
+ next.getFeatures.getFeatureMap.asScala.map {
+ case (featureName, feature) => {
+ val currentType = inferField(feature)
+ if (schemaSoFar.contains(featureName)) {
+ val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType)
+ schemaSoFar(featureName) = updatedType.getOrElse(null)
+ }
+ else {
+ schemaSoFar += (featureName -> currentType)
+ }
+ }
+ }
+ schemaSoFar
+ }
+
+ private def mergeFieldTypes(first: Map[String, DataType], second: Map[String, DataType]): Map[String, DataType] = {
+ //Merge two maps and do the comparison.
+ val mutMap = collection.mutable.Map[String, DataType]((first.keySet ++ second.keySet)
+ .map(key => (key, findTightestCommonType(first.getOrElse(key, null), second.getOrElse(key, null)).get))
+ .toSeq: _*)
+ mutMap
+ }
+
+ /**
+ * Infer Feature datatype based on field number
+ */
+ private def inferField(feature: Feature): DataType = {
+ feature.getKindCase.getNumber match {
+ case Feature.BYTES_LIST_FIELD_NUMBER => {
+ StringType
+ }
+ case Feature.INT64_LIST_FIELD_NUMBER => {
+ parseInt64List(feature)
+ }
+ case Feature.FLOAT_LIST_FIELD_NUMBER => {
+ parseFloatList(feature)
+ }
+      case _ => throw new RuntimeException("Unsupported feature type")
+ }
+ }
+
+ private def parseInt64List(feature: Feature): DataType = {
+ val int64List = feature.getInt64List.getValueList.asScala.toArray
+ val length = int64List.size
+ if (length == 0) {
+ null
+ }
+ else if (length > 1) {
+ ArrayType(LongType)
+ }
+ else {
+ val fieldValue = int64List(0).toString
+ parseInteger(fieldValue)
+ }
+ }
+
+ private def parseFloatList(feature: Feature): DataType = {
+ val floatList = feature.getFloatList.getValueList.asScala.toArray
+ val length = floatList.size
+ if (length == 0) {
+ null
+ }
+ else if (length > 1) {
+ ArrayType(DoubleType)
+ }
+ else {
+ val fieldValue = floatList(0).toString
+ parseFloat(fieldValue)
+ }
+ }
+
+ private def parseInteger(field: String): DataType = if (allCatch.opt(field.toInt).isDefined) {
+ IntegerType
+ }
+ else {
+ parseLong(field)
+ }
+
+ private def parseLong(field: String): DataType = if (allCatch.opt(field.toLong).isDefined) {
+ LongType
+ }
+ else {
+ throw new RuntimeException("Unable to parse field datatype to int64...")
+ }
+
+ private def parseFloat(field: String): DataType = {
+ if ((allCatch opt field.toFloat).isDefined) {
+ FloatType
+ }
+ else {
+ parseDouble(field)
+ }
+ }
+
+ private def parseDouble(field: String): DataType = if (allCatch.opt(field.toDouble).isDefined) {
+ DoubleType
+ }
+ else {
+ throw new RuntimeException("Unable to parse field datatype to float64...")
+ }
+ /**
+ * Copied from internal Spark api
+ * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
+ */
+ private val numericPrecedence: IndexedSeq[DataType] =
+ IndexedSeq[DataType](IntegerType,
+ LongType,
+ FloatType,
+ DoubleType,
+ StringType)
+
+ private def getNumericPrecedence(dataType: DataType): Int = {
+ dataType match {
+ case x if x.equals(IntegerType) => 0
+ case x if x.equals(LongType) => 1
+ case x if x.equals(FloatType) => 2
+ case x if x.equals(DoubleType) => 3
+ case x if x.equals(ArrayType(LongType)) => 4
+ case x if x.equals(ArrayType(DoubleType)) => 5
+ case x if x.equals(StringType) => 6
+ case _ => throw new RuntimeException("Unable to get the precedence for given datatype...")
+ }
+ }
+
+ /**
+ * Copied from internal Spark api
+ * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
+ */
+ private val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
+ case (t1, t2) if t1 == t2 => Some(t1)
+ case (null, t2) => Some(t2)
+ case (t1, null) => Some(t1)
+ case (t1, t2) if t1.equals(ArrayType(LongType)) && t2.equals(ArrayType(DoubleType)) => Some(ArrayType(DoubleType))
+ case (t1, t2) if t1.equals(ArrayType(DoubleType)) && t2.equals(ArrayType(LongType)) => Some(ArrayType(DoubleType))
+ case (StringType, t2) => Some(StringType)
+ case (t1, StringType) => Some(StringType)
+
+    // Promote numeric types to the tightest common type per the precedence list above
+    case (t1, t2) =>
+      val t1Precedence = getNumericPrecedence(t1)
+      val t2Precedence = getNumericPrecedence(t2)
+      val newType = if (t1Precedence > t2Precedence) t1 else t2
+      Some(newType)
+ }
+}
+
diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala
new file mode 100644
index 00000000..7766aa05
--- /dev/null
+++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords
+
+import org.apache.hadoop.io.{BytesWritable, NullWritable}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.sources.{BaseRelation, TableScan}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{Row, SQLContext, SparkSession}
+import org.tensorflow.example.Example
+import org.tensorflow.hadoop.io.TFRecordFileInputFormat
+import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowDecoder
+
+
+case class TensorflowRelation(options: Map[String, String], customSchema: Option[StructType]=None)(@transient val session: SparkSession) extends BaseRelation with TableScan {
+
+  //Importing TFRecords happens here: read Examples into an RDD[Row] and derive the schema
+ lazy val (tf_rdd, tf_schema) = {
+ val rdd = session.sparkContext.newAPIHadoopFile(options("path"), classOf[TFRecordFileInputFormat], classOf[BytesWritable], classOf[NullWritable])
+
+ val exampleRdd = rdd.map {
+ case (bytesWritable, nullWritable) => Example.parseFrom(bytesWritable.getBytes)
+ }
+
+ val finalSchema = customSchema.getOrElse(TensorflowInferSchema(exampleRdd))
+
+ (exampleRdd.map(example => DefaultTfRecordRowDecoder.decodeTfRecord(example, finalSchema)), finalSchema)
+ }
+
+ override def sqlContext: SQLContext = session.sqlContext
+
+ override def schema: StructType = tf_schema
+
+ override def buildScan(): RDD[Row] = tf_rdd
+}
+
diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala
new file mode 100644
index 00000000..ccde03aa
--- /dev/null
+++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords.serde
+
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+import org.tensorflow.example._
+import scala.collection.JavaConverters._
+
+trait TfRecordRowDecoder {
+ /**
+ * Decodes each TensorFlow "Example" as DataFrame "Row"
+ *
+ * Maps each feature in Example to element in Row with DataType based on custom schema or
+ * default mapping of Int64List, FloatList, BytesList to column data type
+ *
+ * @param example TensorFlow Example to decode
+ * @param schema Decode Example using specified schema
+ * @return a DataFrame row
+ */
+ def decodeTfRecord(example: Example, schema: StructType): Row
+}
+
+object DefaultTfRecordRowDecoder extends TfRecordRowDecoder {
+
+ /**
+ * Decodes each TensorFlow "Example" as DataFrame "Row"
+ *
+ * Maps each feature in Example to element in Row with DataType based on custom schema
+ *
+ * @param example TensorFlow Example to decode
+ * @param schema Decode Example using specified schema
+ * @return a DataFrame row
+ */
+ def decodeTfRecord(example: Example, schema: StructType): Row = {
+ val row = Array.fill[Any](schema.length)(null)
+ example.getFeatures.getFeatureMap.asScala.foreach {
+ case (featureName, feature) =>
+ val index = schema.fieldIndex(featureName)
+ val colDataType = schema.fields(index).dataType
+ row(index) = colDataType match {
+ case IntegerType => IntFeatureDecoder.decode(feature)
+ case LongType => LongFeatureDecoder.decode(feature)
+ case FloatType => FloatFeatureDecoder.decode(feature)
+ case DoubleType => DoubleFeatureDecoder.decode(feature)
+          case ArrayType(IntegerType, _) => IntListFeatureDecoder.decode(feature)
+ case ArrayType(LongType, _) => LongListFeatureDecoder.decode(feature)
+ case ArrayType(FloatType, _) => FloatListFeatureDecoder.decode(feature)
+ case ArrayType(DoubleType, _) => DoubleListFeatureDecoder.decode(feature)
+ case StringType => StringFeatureDecoder.decode(feature)
+ case _ => throw new RuntimeException(s"Cannot convert feature to unsupported data type ${colDataType}")
+ }
+ }
+ Row.fromSeq(row)
+ }
+}
+
diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala
new file mode 100644
index 00000000..3d9072a5
--- /dev/null
+++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords.serde
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.tensorflow.example._
+
+trait TfRecordRowEncoder {
+ /**
+ * Encodes each Row as TensorFlow "Example"
+ *
+ * Maps each column in Row to one of Int64List, FloatList, BytesList based on the column data type
+ *
+ * @param row a DataFrame row
+ * @return TensorFlow Example
+ */
+ def encodeTfRecord(row: Row): Example
+}
+
+object DefaultTfRecordRowEncoder extends TfRecordRowEncoder {
+
+ /**
+ * Encodes each Row as TensorFlow "Example"
+ *
+ * Maps each column in Row to one of Int64List, FloatList, BytesList based on the column data type
+ *
+ * @param row a DataFrame row
+ * @return TensorFlow Example
+ */
+ def encodeTfRecord(row: Row): Example = {
+ val features = Features.newBuilder()
+ val example = Example.newBuilder()
+
+    row.schema.zipWithIndex.foreach {
+      case (structField, index) =>
+        val value = row.get(index)
+        val feature = structField.dataType match {
+          case IntegerType | LongType => Int64ListFeatureEncoder.encode(value)
+          case FloatType | DoubleType => FloatListFeatureEncoder.encode(value)
+          case ArrayType(IntegerType, _) | ArrayType(LongType, _) => Int64ListFeatureEncoder.encode(value)
+          case ArrayType(FloatType, _) | ArrayType(DoubleType, _) => FloatListFeatureEncoder.encode(value)
+          case _ => BytesListFeatureEncoder.encode(value)
+        }
+        features.putFeature(structField.name, feature)
+    }
+
+    example.setFeatures(features)
+    example.build()
+ }
+}
+
diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoder.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoder.scala
new file mode 100644
index 00000000..c390ae42
--- /dev/null
+++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoder.scala
@@ -0,0 +1,187 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords.serde
+
+import org.tensorflow.example.Feature
+
+import scala.collection.JavaConverters._
+
+trait FeatureDecoder[T] {
+ /**
+ * Decodes each TensorFlow "Feature" to desired Scala type
+ *
+ * @param feature TensorFlow Feature
+ * @return Decoded feature
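+   *
+   * For example, IntFeatureDecoder.decode returns 4 for a feature holding
+   * Int64List(4), and fails when the list length is not exactly 1.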
+ */
+ def decode(feature: Feature): T
+}
+
+/**
+ * Decode TensorFlow "Feature" to Integer
+ */
+object IntFeatureDecoder extends FeatureDecoder[Int] {
+ override def decode(feature: Feature): Int = {
+ require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List")
+ try {
+ val int64List = feature.getInt64List.getValueList
+ require(int64List.size() == 1, "Length of Int64List must equal 1")
+ int64List.get(0).intValue()
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert feature to Int.", ex)
+ }
+ }
+}
+
+/**
+ * Decode TensorFlow "Feature" to Seq[Int]
+ */
+object IntListFeatureDecoder extends FeatureDecoder[Seq[Int]] {
+ override def decode(feature: Feature): Seq[Int] = {
+ require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List")
+ try {
+ val array = feature.getInt64List.getValueList.asScala.toArray
+ array.map(_.toInt)
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert feature to Seq[Int].", ex)
+ }
+ }
+}
+
+/**
+ * Decode TensorFlow "Feature" to Long
+ */
+object LongFeatureDecoder extends FeatureDecoder[Long] {
+ override def decode(feature: Feature): Long = {
+ require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List")
+ try {
+ val int64List = feature.getInt64List.getValueList
+ require(int64List.size() == 1, "Length of Int64List must equal 1")
+ int64List.get(0).longValue()
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert feature to Long.", ex)
+ }
+ }
+}
+
+/**
+ * Decode TensorFlow "Feature" to Seq[Long]
+ */
+object LongListFeatureDecoder extends FeatureDecoder[Seq[Long]] {
+ override def decode(feature: Feature): Seq[Long] = {
+ require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List")
+ try {
+ val array = feature.getInt64List.getValueList.asScala.toArray
+ array.map(_.toLong)
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert feature to Array[Long].", ex)
+ }
+ }
+}
+
+/**
+ * Decode TensorFlow "Feature" to Float
+ */
+object FloatFeatureDecoder extends FeatureDecoder[Float] {
+ override def decode(feature: Feature): Float = {
+ require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList")
+ try {
+ val floatList = feature.getFloatList.getValueList
+ require(floatList.size() == 1, "Length of FloatList must equal 1")
+ floatList.get(0).floatValue()
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert feature to Float.", ex)
+ }
+ }
+}
+
+/**
+ * Decode TensorFlow "Feature" to Seq[Float]
+ */
+object FloatListFeatureDecoder extends FeatureDecoder[Seq[Float]] {
+ override def decode(feature: Feature): Seq[Float] = {
+ require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList")
+ try {
+ val array = feature.getFloatList.getValueList.asScala.toArray
+ array.map(_.toFloat)
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert feature to Array[Float].", ex)
+ }
+ }
+}
+
+/**
+ * Decode TensorFlow "Feature" to Double
+ */
+object DoubleFeatureDecoder extends FeatureDecoder[Double] {
+ override def decode(feature: Feature): Double = {
+ require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList")
+ try {
+ val floatList = feature.getFloatList.getValueList
+ require(floatList.size() == 1, "Length of FloatList must equal 1")
+ floatList.get(0).doubleValue()
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert feature to Double.", ex)
+ }
+ }
+}
+
+/**
+ * Decode TensorFlow "Feature" to Seq[Double]
+ */
+object DoubleListFeatureDecoder extends FeatureDecoder[Seq[Double]] {
+ override def decode(feature: Feature): Seq[Double] = {
+ require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList")
+ try {
+ val array = feature.getFloatList.getValueList.asScala.toArray
+ array.map(_.toDouble)
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert feature to Array[Double].", ex)
+ }
+ }
+}
+
+/**
+ * Decode TensorFlow "Feature" to String
+ */
+object StringFeatureDecoder extends FeatureDecoder[String] {
+ override def decode(feature: Feature): String = {
+    require(feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type BytesList")
+ try {
+ feature.getBytesList.toByteString.toStringUtf8.trim
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert feature to String.", ex)
+ }
+ }
+}
+
diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala
new file mode 100644
index 00000000..646cbadd
--- /dev/null
+++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala
@@ -0,0 +1,118 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords.serde
+
+import org.tensorflow.example.{BytesList, Feature, FloatList, Int64List}
+import org.tensorflow.hadoop.shaded.protobuf.ByteString
+import org.tensorflow.spark.datasources.tfrecords.DataTypesConvertor
+
+trait FeatureEncoder {
+ /**
+ * Encodes input value as TensorFlow "Feature"
+ *
+ * Maps input value to one of Int64List, FloatList, BytesList
+ *
+ * @param value Input value
+ * @return TensorFlow Feature
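+   *
+   * For example, Int64ListFeatureEncoder.encode(Seq(3L, 5L, 6L)) yields a
+   * feature whose Int64List holds the values 3, 5, 6.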
+ */
+ def encode(value: Any): Feature
+}
+
+/**
+ * Encode input value to Int64List
+ */
+object Int64ListFeatureEncoder extends FeatureEncoder {
+ override def encode(value: Any): Feature = {
+ try {
+ val int64List = value match {
+ case i: Int => Int64List.newBuilder().addValue(i.toLong).build()
+ case l: Long => Int64List.newBuilder().addValue(l).build()
+ case arr: scala.collection.mutable.WrappedArray[_] => toInt64List(arr.toArray[Any])
+ case arr: Array[_] => toInt64List(arr)
+ case seq: Seq[_] => toInt64List(seq.toArray[Any])
+ case _ => throw new RuntimeException(s"Cannot convert object $value to Int64List")
+ }
+ Feature.newBuilder().setInt64List(int64List).build()
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to Int64List feature.", ex)
+ }
+ }
+
+ private def toInt64List[T](arr: Array[T]): Int64List = {
+ val intListBuilder = Int64List.newBuilder()
+ arr.foreach(x => {
+ require(x != null, "Int64List with null values is not supported")
+ val longValue = DataTypesConvertor.toLong(x)
+ intListBuilder.addValue(longValue)
+ })
+ intListBuilder.build()
+ }
+}
+
+/**
+ * Encode input value to FloatList
+ */
+object FloatListFeatureEncoder extends FeatureEncoder {
+ override def encode(value: Any): Feature = {
+ try {
+ val floatList = value match {
+ case i: Int => FloatList.newBuilder().addValue(i.toFloat).build()
+ case l: Long => FloatList.newBuilder().addValue(l.toFloat).build()
+ case f: Float => FloatList.newBuilder().addValue(f).build()
+ case d: Double => FloatList.newBuilder().addValue(d.toFloat).build()
+ case arr: scala.collection.mutable.WrappedArray[_] => toFloatList(arr.toArray[Any])
+ case arr: Array[_] => toFloatList(arr)
+ case seq: Seq[_] => toFloatList(seq.toArray[Any])
+ case _ => throw new RuntimeException(s"Cannot convert object $value to FloatList")
+ }
+ Feature.newBuilder().setFloatList(floatList).build()
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to FloatList feature.", ex)
+ }
+ }
+
+  private def toFloatList[T](arr: Array[T]): FloatList = {
+    val floatListBuilder = FloatList.newBuilder()
+    arr.foreach(x => {
+      require(x != null, "FloatList with null values is not supported")
+      val floatValue = DataTypesConvertor.toFloat(x)
+      floatListBuilder.addValue(floatValue)
+    })
+    floatListBuilder.build()
+  }
+}
+
+/**
+ * Encode input value to BytesList
+ */
+object BytesListFeatureEncoder extends FeatureEncoder {
+ override def encode(value: Any): Feature = {
+ try {
+ val byteList = BytesList.newBuilder().addValue(ByteString.copyFrom(value.toString.getBytes)).build()
+ Feature.newBuilder().setBytesList(byteList).build()
+ }
+ catch {
+ case ex: Exception =>
+ throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to ByteList feature.", ex)
+ }
+ }
+}
+
+
diff --git a/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/SharedSparkSessionSuite.scala b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/SharedSparkSessionSuite.scala
new file mode 100644
index 00000000..e87c815b
--- /dev/null
+++ b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/SharedSparkSessionSuite.scala
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords
+
+import java.io.File
+
+import org.apache.commons.io.FileUtils
+import org.apache.spark.SharedSparkSession
+import org.junit.{After, Before}
+import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}
+
+
+trait BaseSuite extends WordSpecLike with Matchers with BeforeAndAfterAll
+
+class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite {
+ val TF_SANDBOX_DIR = "tf-sandbox"
+ val file = new File(TF_SANDBOX_DIR)
+
+ @Before
+ override def beforeAll() = {
+ super.setUp()
+ FileUtils.deleteQuietly(file)
+ file.mkdirs()
+ }
+
+ @After
+ override def afterAll() = {
+ FileUtils.deleteQuietly(file)
+ super.tearDown()
+ }
+}
+
diff --git a/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowSuite.scala b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowSuite.scala
new file mode 100644
index 00000000..0e37968e
--- /dev/null
+++ b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowSuite.scala
@@ -0,0 +1,210 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Row}
+import org.tensorflow.example._
+import org.tensorflow.hadoop.shaded.protobuf.ByteString
+import org.tensorflow.spark.datasources.tfrecords.serde.{DefaultTfRecordRowDecoder, DefaultTfRecordRowEncoder}
+
+import scala.collection.JavaConverters._
+
+class TensorflowSuite extends SharedSparkSessionSuite {
+
+ "Spark TensorFlow module" should {
+
+ "Test Import/Export" in {
+
+ val path = s"$TF_SANDBOX_DIR/output25.tfr"
+ val testRows: Array[Row] = Array(
+ new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
+ new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
+
+ val schema = StructType(List(
+ StructField("id", IntegerType),
+ StructField("IntegerTypelabel", IntegerType),
+ StructField("LongTypelabel", LongType),
+ StructField("FloatTypelabel", FloatType),
+ StructField("DoubleTypelabel", DoubleType),
+ StructField("vectorlabel", ArrayType(DoubleType, true)),
+ StructField("name", StringType)))
+
+ val rdd = spark.sparkContext.parallelize(testRows)
+
+ val df: DataFrame = spark.createDataFrame(rdd, schema)
+ df.write.format("tfrecords").save(path)
+
+      //Import the records using the explicit schema (inference is exercised in "Check infer schema" below)
+ val importedDf: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
+ val actualDf = importedDf.select("id", "IntegerTypelabel", "LongTypelabel", "FloatTypelabel", "DoubleTypelabel", "vectorlabel", "name").sort("name")
+
+ val expectedRows = df.collect()
+ val actualRows = actualDf.collect()
+
+ expectedRows should equal(actualRows)
+ }
+
+ "Encode given Row as TensorFlow example" in {
+ val schemaStructType = StructType(Array(
+ StructField("IntegerTypelabel", IntegerType),
+ StructField("LongTypelabel", LongType),
+ StructField("FloatTypelabel", FloatType),
+ StructField("DoubleTypelabel", DoubleType),
+ StructField("vectorlabel", ArrayType(DoubleType, true)),
+ StructField("strlabel", StringType)
+ ))
+ val doubleArray = Array(1.1, 111.1, 11111.1)
+ val expectedFloatArray = Array(1.1F, 111.1F, 11111.1F)
+
+ val rowWithSchema = new GenericRowWithSchema(Array[Any](1, 23L, 10.0F, 14.0, doubleArray, "r1"), schemaStructType)
+
+ //Encode Sql Row to TensorFlow example
+ val example = DefaultTfRecordRowEncoder.encodeTfRecord(rowWithSchema)
+
+ //Verify each Datatype converted to TensorFlow datatypes
+ val featureMap = example.getFeatures.getFeatureMap.asScala
+ assert(featureMap("IntegerTypelabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER)
+ assert(featureMap("IntegerTypelabel").getInt64List.getValue(0).toInt == 1)
+
+ assert(featureMap("LongTypelabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER)
+ assert(featureMap("LongTypelabel").getInt64List.getValue(0).toInt == 23)
+
+ assert(featureMap("FloatTypelabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
+ assert(featureMap("FloatTypelabel").getFloatList.getValue(0) == 10.0F)
+
+ assert(featureMap("DoubleTypelabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
+ assert(featureMap("DoubleTypelabel").getFloatList.getValue(0) == 14.0F)
+
+ assert(featureMap("vectorlabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
+ assert(featureMap("vectorlabel").getFloatList.getValueList.toArray === expectedFloatArray)
+
+ assert(featureMap("strlabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER)
+ assert(featureMap("strlabel").getBytesList.toByteString.toStringUtf8.trim == "r1")
+
+ }
+
+ "Throw an exception for a vector with null values during Encode" in {
+ intercept[Exception] {
+ val schemaStructType = StructType(Array(
+ StructField("vectorlabel", ArrayType(DoubleType, true))
+ ))
+ val doubleArray = Array(1.1, null, 111.1, null, 11111.1)
+
+ val rowWithSchema = new GenericRowWithSchema(Array[Any](doubleArray), schemaStructType)
+
+ //Throws NullPointerException
+ DefaultTfRecordRowEncoder.encodeTfRecord(rowWithSchema)
+ }
+ }
+
+ "Decode given TensorFlow Example as Row" in {
+
+ //Here Vector with null's are not supported
+ val expectedRow = new GenericRow(Array[Any](1, 23L, 10.0F, 14.0, Seq(1.0, 2.0), "r1"))
+
+ val schema = StructType(List(
+ StructField("IntegerTypelabel", IntegerType),
+ StructField("LongTypelabel", LongType),
+ StructField("FloatTypelabel", FloatType),
+ StructField("DoubleTypelabel", DoubleType),
+ StructField("vectorlabel", ArrayType(DoubleType)),
+ StructField("strlabel", StringType)))
+
+ //Build example
+ val intFeature = Int64List.newBuilder().addValue(1)
+ val longFeature = Int64List.newBuilder().addValue(23L)
+ val floatFeature = FloatList.newBuilder().addValue(10.0F)
+ val doubleFeature = FloatList.newBuilder().addValue(14.0F)
+ val vectorFeature = FloatList.newBuilder().addValue(1F).addValue(2F).build()
+ val strFeature = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build()
+ val features = Features.newBuilder()
+ .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature).build())
+ .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature).build())
+ .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature).build())
+ .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature).build())
+ .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature).build())
+ .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature).build())
+ .build()
+ val example = Example.newBuilder()
+ .setFeatures(features)
+ .build()
+
+ //Decode TensorFlow example to Sql Row
+ val actualRow = DefaultTfRecordRowDecoder.decodeTfRecord(example, schema)
+ actualRow should equal(expectedRow)
+ }
+
+ "Check infer schema" in {
+
+ //Build example1
+ val intFeature1 = Int64List.newBuilder().addValue(1)
+ val longFeature1 = Int64List.newBuilder().addValue(Int.MaxValue + 10L)
+ val floatFeature1 = FloatList.newBuilder().addValue(10.0F)
+ val doubleFeature1 = FloatList.newBuilder().addValue(14.0F)
+ val vectorFeature1 = FloatList.newBuilder().addValue(1F).build()
+ val strFeature1 = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build()
+ val features1 = Features.newBuilder()
+ .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature1).build())
+ .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature1).build())
+ .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature1).build())
+ .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature1).build())
+ .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature1).build())
+ .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature1).build())
+ .build()
+ val example1 = Example.newBuilder()
+ .setFeatures(features1)
+ .build()
+
+ //Build example2
+ val intFeature2 = Int64List.newBuilder().addValue(2)
+ val longFeature2 = Int64List.newBuilder().addValue(24)
+ val floatFeature2 = FloatList.newBuilder().addValue(12.0F)
+ val doubleFeature2 = FloatList.newBuilder().addValue(Float.MaxValue + 15)
+ val vectorFeature2 = FloatList.newBuilder().addValue(2F).addValue(2F).build()
+ val strFeature2 = BytesList.newBuilder().addValue(ByteString.copyFrom("r2".getBytes)).build()
+ val features2 = Features.newBuilder()
+ .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature2).build())
+ .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature2).build())
+ .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature2).build())
+ .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature2).build())
+ .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature2).build())
+ .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature2).build())
+ .build()
+ val example2 = Example.newBuilder()
+ .setFeatures(features2)
+ .build()
+
+ val exampleRDD: RDD[Example] = spark.sparkContext.parallelize(List(example1, example2))
+
+ val actualSchema = TensorflowInferSchema(exampleRDD)
+
+      //Verify each TensorFlow datatype is inferred as the expected Spark datatype.
+      //vectorlabel mixes lengths 1 and 2 across records, so it is promoted to ArrayType(DoubleType).
+      actualSchema.fields.foreach { column =>
+        column.name match {
+          case "IntegerTypelabel" => assert(column.dataType.equals(IntegerType))
+          case "LongTypelabel" => assert(column.dataType.equals(LongType))
+          case "FloatTypelabel" | "DoubleTypelabel" => assert(column.dataType.equals(FloatType))
+          case "vectorlabel" => assert(column.dataType.equals(ArrayType(DoubleType)))
+          case "strlabel" => assert(column.dataType.equals(StringType))
+        }
+      }
+ }
+ }
+}
+
diff --git a/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoderTest.scala b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoderTest.scala
new file mode 100644
index 00000000..954803d1
--- /dev/null
+++ b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoderTest.scala
@@ -0,0 +1,209 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords.serde
+
+import org.scalatest.{Matchers, WordSpec}
+import org.tensorflow.example.{BytesList, Feature, FloatList, Int64List}
+import org.tensorflow.hadoop.shaded.protobuf.ByteString
+
+class FeatureDecoderTest extends WordSpec with Matchers {
+
+ "Int Feature decoder" should {
+
+ "Decode Feature to Int" in {
+ val int64List = Int64List.newBuilder().addValue(4).build()
+ val intFeature = Feature.newBuilder().setInt64List(int64List).build()
+ IntFeatureDecoder.decode(intFeature) should equal(4)
+ }
+
+ "Throw an exception if length of feature array exceeds 1" in {
+ intercept[Exception] {
+ val int64List = Int64List.newBuilder().addValue(4).addValue(7).build()
+ val intFeature = Feature.newBuilder().setInt64List(int64List).build()
+ IntFeatureDecoder.decode(intFeature)
+ }
+ }
+
+ "Throw an exception if feature is not an Int64List" in {
+ intercept[Exception] {
+ val floatList = FloatList.newBuilder().addValue(4).build()
+ val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+ IntFeatureDecoder.decode(floatFeature)
+ }
+ }
+ }
+
+ "Int List Feature decoder" should {
+
+ "Decode Feature to Int List" in {
+ val int64List = Int64List.newBuilder().addValue(3).addValue(9).build()
+ val intFeature = Feature.newBuilder().setInt64List(int64List).build()
+ IntListFeatureDecoder.decode(intFeature) should equal(Seq(3,9))
+ }
+
+ "Throw an exception if feature is not an Int64List" in {
+ intercept[Exception] {
+ val floatList = FloatList.newBuilder().addValue(4).build()
+ val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+ IntListFeatureDecoder.decode(floatFeature)
+ }
+ }
+ }
+
+ "Long Feature decoder" should {
+
+ "Decode Feature to Long" in {
+ val int64List = Int64List.newBuilder().addValue(5L).build()
+ val intFeature = Feature.newBuilder().setInt64List(int64List).build()
+ LongFeatureDecoder.decode(intFeature) should equal(5L)
+ }
+
+ "Throw an exception if length of feature array exceeds 1" in {
+ intercept[Exception] {
+ val int64List = Int64List.newBuilder().addValue(4L).addValue(10L).build()
+ val intFeature = Feature.newBuilder().setInt64List(int64List).build()
+ LongFeatureDecoder.decode(intFeature)
+ }
+ }
+
+ "Throw an exception if feature is not an Int64List" in {
+ intercept[Exception] {
+ val floatList = FloatList.newBuilder().addValue(4).build()
+ val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+ LongFeatureDecoder.decode(floatFeature)
+ }
+ }
+ }
+
+ "Long List Feature decoder" should {
+
+ "Decode Feature to Long List" in {
+ val int64List = Int64List.newBuilder().addValue(3L).addValue(Int.MaxValue+10L).build()
+ val intFeature = Feature.newBuilder().setInt64List(int64List).build()
+ LongListFeatureDecoder.decode(intFeature) should equal(Seq(3L,Int.MaxValue+10L))
+ }
+
+ "Throw an exception if feature is not an Int64List" in {
+ intercept[Exception] {
+ val floatList = FloatList.newBuilder().addValue(4).build()
+ val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+ LongListFeatureDecoder.decode(floatFeature)
+ }
+ }
+ }
+
+ "Float Feature decoder" should {
+
+ "Decode Feature to Float" in {
+ val floatList = FloatList.newBuilder().addValue(2.5F).build()
+ val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+ FloatFeatureDecoder.decode(floatFeature) should equal(2.5F)
+ }
+
+ "Throw an exception if length of feature array exceeds 1" in {
+ intercept[Exception] {
+ val floatList = FloatList.newBuilder().addValue(1.5F).addValue(3.33F).build()
+ val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+ FloatFeatureDecoder.decode(floatFeature)
+ }
+ }
+
+ "Throw an exception if feature is not a FloatList" in {
+ intercept[Exception] {
+ val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
+ val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
+ FloatFeatureDecoder.decode(bytesFeature)
+ }
+ }
+ }
+
+ "Float List Feature decoder" should {
+
+ "Decode Feature to Float List" in {
+ val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.3F).build()
+ val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+ FloatListFeatureDecoder.decode(floatFeature) should equal(Seq(2.5F, 4.3F))
+ }
+
+ "Throw an exception if feature is not a FloatList" in {
+ intercept[Exception] {
+ val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
+ val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
+ FloatListFeatureDecoder.decode(bytesFeature)
+ }
+ }
+ }
+
+ "Double Feature decoder" should {
+
+ "Decode Feature to Double" in {
+ val floatList = FloatList.newBuilder().addValue(2.5F).build()
+ val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+ DoubleFeatureDecoder.decode(floatFeature) should equal(2.5d)
+ }
+
+ "Throw an exception if length of feature array exceeds 1" in {
+ intercept[Exception] {
+ val floatList = FloatList.newBuilder().addValue(1.5F).addValue(3.33F).build()
+ val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+ DoubleFeatureDecoder.decode(floatFeature)
+ }
+ }
+
+ "Throw an exception if feature is not a FloatList" in {
+ intercept[Exception] {
+ val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
+ val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
+ DoubleFeatureDecoder.decode(bytesFeature)
+ }
+ }
+ }
+
+ "Double List Feature decoder" should {
+
+ "Decode Feature to Double List" in {
+ val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.0F).build()
+ val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+ DoubleListFeatureDecoder.decode(floatFeature) should equal(Seq(2.5d, 4.0d))
+ }
+
+ "Throw an exception if feature is not a DoubleList" in {
+ intercept[Exception] {
+ val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
+ val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
+ FloatListFeatureDecoder.decode(bytesFeature)
+ }
+ }
+ }
+
+ "Bytes List Feature decoder" should {
+
+ "Decode Feature to Bytes List" in {
+ val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
+ val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
+ StringFeatureDecoder.decode(bytesFeature) should equal("str-input")
+ }
+
+ "Throw an exception if feature is not a BytesList" in {
+ intercept[Exception] {
+ val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.0F).build()
+ val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+ StringFeatureDecoder.decode(floatFeature)
+ }
+ }
+ }
+}
+
diff --git a/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala
new file mode 100644
index 00000000..4c09d568
--- /dev/null
+++ b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala
@@ -0,0 +1,99 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords.serde
+
+import org.scalatest.{Matchers, WordSpec}
+
+import scala.collection.JavaConverters._
+
+class FeatureEncoderTest extends WordSpec with Matchers {
+
+ "Int64List feature encoder" should {
+ "Encode inputs to Int64List" in {
+ val intFeature = Int64ListFeatureEncoder.encode(5)
+ val longFeature = Int64ListFeatureEncoder.encode(10L)
+ val longListFeature = Int64ListFeatureEncoder.encode(Seq(3L,5L,6L))
+
+ intFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(5L))
+ longFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(10L))
+ longListFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(3L, 5L, 6L))
+ }
+
+ "Throw an exception when inputs contain null" in {
+ intercept[Exception] {
+ Int64ListFeatureEncoder.encode(null)
+ }
+ intercept[Exception] {
+ Int64ListFeatureEncoder.encode(Seq(3,null,6))
+ }
+ }
+
+ "Throw an exception for non-numeric inputs" in {
+ intercept[Exception] {
+ Int64ListFeatureEncoder.encode("bad-input")
+ }
+ }
+ }
+
+ "FloatList feature encoder" should {
+ "Encode inputs to FloatList" in {
+ val intFeature = FloatListFeatureEncoder.encode(5)
+ val longFeature = FloatListFeatureEncoder.encode(10L)
+ val floatFeature = FloatListFeatureEncoder.encode(2.5F)
+ val doubleFeature = FloatListFeatureEncoder.encode(14.6)
+ val floatListFeature = FloatListFeatureEncoder.encode(Seq(1.5F,6.8F,-3.2F))
+
+ intFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(5F))
+ longFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(10F))
+ floatFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(2.5F))
+ doubleFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(14.6F))
+ floatListFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(1.5F,6.8F,-3.2F))
+ }
+
+ "Throw an exception when inputs contain null" in {
+ intercept[Exception] {
+ FloatListFeatureEncoder.encode(null)
+ }
+ intercept[Exception] {
+ FloatListFeatureEncoder.encode(Seq(3,null,6))
+ }
+ }
+
+ "Throw an exception for non-numeric inputs" in {
+ intercept[Exception] {
+ FloatListFeatureEncoder.encode("bad-input")
+ }
+ }
+ }
+
+ "ByteList feature encoder" should {
+ "Encode inputs to ByteList" in {
+ val longFeature = BytesListFeatureEncoder.encode(10L)
+ val longListFeature = BytesListFeatureEncoder.encode(Seq(3L,5L,6L))
+ val strFeature = BytesListFeatureEncoder.encode("str-input")
+
+ longFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("10")
+ longListFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("List(3, 5, 6)")
+ strFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("str-input")
+ }
+
+ "Throw an exception when inputs contain null" in {
+ intercept[Exception] {
+ BytesListFeatureEncoder.encode(null)
+ }
+ }
+ }
+}