From f2a87ca9d0be7e7f826ac96828200c3a4fa616ce Mon Sep 17 00:00:00 2001 From: Karthik Vadla Date: Tue, 7 Mar 2017 12:59:07 -0800 Subject: [PATCH] Add Spark Tensorflow Connector (#34) * Add spark-tensorflow-connector to ecosystem * Update Readme * Add feedback changes --- README.md | 1 + spark/README.md | 85 +++++ spark/build.sbt | 93 ++++++ spark/pom.xml | 306 ++++++++++++++++++ spark/project/build.properties | 1 + spark/project/plugins.sbt | 5 + ...pache.spark.sql.sources.DataSourceRegister | 1 + .../tfrecords/DataTypesConvertor.scala | 49 +++ .../datasources/tfrecords/DefaultSource.scala | 67 ++++ .../tfrecords/TensorflowInferSchema.scala | 197 +++++++++++ .../tfrecords/TensorflowRelation.scala | 49 +++ .../serde/DefaultTfRecordRowDecoder.scala | 70 ++++ .../serde/DefaultTfRecordRowEncoder.scala | 66 ++++ .../tfrecords/serde/FeatureDecoder.scala | 187 +++++++++++ .../tfrecords/serde/FeatureEncoder.scala | 118 +++++++ .../tfrecords/SharedSparkSessionSuite.scala | 45 +++ .../tfrecords/TensorflowSuite.scala | 210 ++++++++++++ .../tfrecords/serde/FeatureDecoderTest.scala | 209 ++++++++++++ .../tfrecords/serde/FeatureEncoderTest.scala | 99 ++++++ 19 files changed, 1858 insertions(+) create mode 100644 spark/README.md create mode 100644 spark/build.sbt create mode 100644 spark/pom.xml create mode 100644 spark/project/build.properties create mode 100644 spark/project/plugins.sbt create mode 100644 spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DataTypesConvertor.scala create mode 100644 spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala create mode 100644 spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowInferSchema.scala create mode 100644 spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala create mode 100644 spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala create mode 100644 spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala create mode 100644 spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoder.scala create mode 100644 spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala create mode 100644 spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/SharedSparkSessionSuite.scala create mode 100644 spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowSuite.scala create mode 100644 spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoderTest.scala create mode 100644 spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala diff --git a/README.md b/README.md index 9bc3afaf..458678fe 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ request. Marathon, deployed on top of Mesos. - [hadoop](hadoop) - TFRecord file InputFormat/OutputFormat for Hadoop MapReduce and Spark. +- [spark](spark) - Spark TensorFlow Connector ## Distributed TensorFlow diff --git a/spark/README.md b/spark/README.md new file mode 100644 index 00000000..32a8c10e --- /dev/null +++ b/spark/README.md @@ -0,0 +1,85 @@ +# spark-tensorflow-connector + +This repo contains a library for loading and storing TensorFlow records with [Apache Spark](http://spark.apache.org/). 
+The library implements data import from the standard TensorFlow record format
+([TFRecords](https://www.tensorflow.org/how_tos/reading_data/)) into Spark SQL DataFrames,
+and data export from DataFrames to TensorFlow records.
+
+## What's new
+
+This is the initial release of the `spark-tensorflow-connector` repo.
+
+## Known issues
+
+None.
+
+## Prerequisites
+
+1. [Apache Spark 2.0 (or later)](http://spark.apache.org/)
+
+2. [Apache Maven](https://maven.apache.org/)
+
+## Building the library
+You can build the library with either the Maven or SBT build tools.
+
+#### Maven
+Build the library with Maven (3.3) as shown below:
+
+```sh
+mvn clean install
+```
+
+#### SBT
+Build the library with SBT (0.13.13) as shown below:
+```sh
+sbt clean assembly
+```
+
+## Using Spark Shell
+Run this library in Spark using the `--jars` command line option in `spark-shell` or `spark-submit`. For example:
+
+Maven jars:
+```sh
+$SPARK_HOME/bin/spark-shell --jars target/spark-tensorflow-connector-1.0-SNAPSHOT.jar,target/lib/tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar
+```
+
+SBT jars:
+```sh
+$SPARK_HOME/bin/spark-shell --jars target/scala-2.11/spark-tensorflow-connector-assembly-1.0.0.jar
+```
+
+The following code snippet demonstrates usage:
+
+```scala
+import org.apache.commons.io.FileUtils
+import org.apache.spark.sql.{ DataFrame, Row }
+import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.types._
+
+val path = "test-output.tfr"
+val testRows: Array[Row] = Array(
+  new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
+  new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
+val schema = StructType(List(StructField("id", IntegerType),
+  StructField("IntegerTypelabel", IntegerType),
+  StructField("LongTypelabel", LongType),
+  StructField("FloatTypelabel", FloatType),
+  StructField("DoubleTypelabel", DoubleType),
+  StructField("vectorlabel", ArrayType(DoubleType, true)),
+  StructField("name", StringType)))
+
+val rdd = spark.sparkContext.parallelize(testRows)
+
+//Save DataFrame as TFRecords
+val df: DataFrame = spark.createDataFrame(rdd, schema)
+df.write.format("tfrecords").save(path)
+
+//Read TFRecords into DataFrame.
+//The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
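+//Note: inference runs an extra pass over the records (see TensorflowInferSchema),
+//so supplying a custom schema, as importedDf2 does below, skips that pass.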
+val importedDf1: DataFrame = spark.read.format("tfrecords").load(path) +importedDf1.show() + +//Read TFRecords into DataFrame using custom schema +val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path) +importedDf2.show() + +``` diff --git a/spark/build.sbt b/spark/build.sbt new file mode 100644 index 00000000..3ec7bbde --- /dev/null +++ b/spark/build.sbt @@ -0,0 +1,93 @@ +name := "spark-tensorflow-connector" + +organization := "org.tensorflow" + +scalaVersion in Global := "2.11.8" + +spName := "tensorflow/spark-tensorflow-connector" + +sparkVersion := "2.1.0" + +sparkComponents ++= Seq("sql", "mllib") + +version := "1.0.0" + +def ProjectName(name: String,path:String): Project = Project(name, file(path)) + +resolvers in Global ++= Seq("https://tap.jfrog.io/tap/public" at "https://tap.jfrog.io/tap/public" , + "https://tap.jfrog.io/tap/public-snapshots" at "https://tap.jfrog.io/tap/public-snapshots" , + "https://repo.maven.apache.org/maven2" at "https://repo.maven.apache.org/maven2" ) + +val `junit_junit` = "junit" % "junit" % "4.12" + +val `org.apache.hadoop_hadoop-yarn-api` = "org.apache.hadoop" % "hadoop-yarn-api" % "2.7.3" + +val `org.apache.spark_spark-core_2.11` = "org.apache.spark" % "spark-core_2.11" % "2.1.0" + +val `org.apache.spark_spark-sql_2.11` = "org.apache.spark" % "spark-sql_2.11" % "2.1.0" + +val `org.apache.spark_spark-mllib_2.11` = "org.apache.spark" % "spark-mllib_2.11" % "2.1.0" + +val `org.scalatest_scalatest_2.11` = "org.scalatest" % "scalatest_2.11" % "2.2.6" + +val `org.tensorflow_tensorflow-hadoop` = "org.tensorflow" % "tensorflow-hadoop" % "1.0-01232017-SNAPSHOT" + +libraryDependencies in Global ++= Seq(`org.tensorflow_tensorflow-hadoop` classifier "shaded-protobuf", + `org.scalatest_scalatest_2.11` % "test" , + `org.apache.spark_spark-sql_2.11` % "provided" , + `org.apache.spark_spark-mllib_2.11` % "test" classifier "tests", + `org.apache.spark_spark-core_2.11` % "provided" , + `org.apache.hadoop_hadoop-yarn-api` % "provided" , + `junit_junit` % "test" ) + +assemblyExcludedJars in assembly := { + val cp = (fullClasspath in assembly).value + cp filterNot {x => List("spark-tensorflow-connector-1.0-SNAPSHOT.jar", + "tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar").contains(x.data.getName)} +} + +/******************** + * Release settings * + ********************/ + +spIgnoreProvided := true + +spAppendScalaVersion := true + +// If you published your package to Maven Central for this release (must be done prior to spPublish) +spIncludeMaven := false + +publishMavenStyle := true + +licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0")) + +pomExtra := + https://github.com/tensorflow/ecosystem + + git@github.com:tensorflow/ecosystem.git + scm:git:git@github.com:tensorflow/ecosystem.git + + + + karthikvadla + Karthik Vadla + https://github.com/karthikvadla + + + skavulya + Soila Kavulya + https://github.com/skavulya + + + joyeshmishra + Joyesh Mishra + https://github.com/joyeshmishra + + + +credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials") // A file containing credentials + +// Add assembly jar to Spark package +test in assembly := {} + +spShade := true diff --git a/spark/pom.xml b/spark/pom.xml new file mode 100644 index 00000000..82b45929 --- /dev/null +++ b/spark/pom.xml @@ -0,0 +1,306 @@ + + + 4.0.0 + + org.tensorflow + spark-tensorflow-connector + jar + 1.0-SNAPSHOT + + + + central1 + http://central1.maven.org/maven2 + + true + + + false + + + + + tap + 
https://tap.jfrog.io/tap/public + + false + + + true + + + + tap-snapshots + https://tap.jfrog.io/tap/public-snapshots + + true + + + false + + + + + + + compile + + true + + !NEVERSETME + + + + + + + + true + net.alchim31.maven + scala-maven-plugin + 3.1.6 + + + compile + + add-source + compile + + + + -Xms256m + -Xmx512m + + + -g:vars + -deprecation + -feature + -unchecked + -Xfatal-warnings + -language:implicitConversions + -language:existentials + + + + + test + + add-source + testCompile + + + + + incremental + true + 2.11 + false + + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-dependencies + process-resources + + copy-dependencies + + + provided + true + org.apache.spark,junit,org.scalatest + ${project.build.directory}/lib + + + + + + + org.codehaus.mojo + properties-maven-plugin + 1.0.0 + + + generate-resources + + write-project-properties + + + ${project.build.outputDirectory}/maven.properties + + + + + + + + + net.alchim31.maven + scala-maven-plugin + + + + + + + test + + true + + !NEVERSETME + + + + + + + + true + net.alchim31.maven + scala-maven-plugin + 3.2.2 + + + compile + + + + + true + org.scalatest + scalatest-maven-plugin + 1.0 + + ${project.build.directory}/surefire-reports + . + WDF TestSuite.txt + false + FTD + -Xmx1024m -XX:PermSize=256m -XX:MaxDirectMemorySize=1000m + + + + scalaTest + test + + test + + + + + + + + + + net.alchim31.maven + scala-maven-plugin + + + + + + + org.scalatest + scalatest_2.11 + 2.2.6 + test + + + + + + + org.scalatest + scalatest_2.11 + test + + + + + + + + + + net.alchim31.maven + scala-maven-plugin + + + org.apache.maven.plugins + maven-dependency-plugin + + + org.scalatest + scalatest-maven-plugin + + + org.apache.maven.plugins + maven-compiler-plugin + 3.0 + + 1.8 + 1.8 + + + + + + + core/src/main/resources + + reference.conf + + + + core/src/test/resources + + + + + + + org.tensorflow + tensorflow-hadoop + 1.0-01232017-SNAPSHOT + shaded-protobuf + + + org.apache.spark + spark-core_2.11 + 2.1.0 + provided + + + org.apache.spark + spark-sql_2.11 + 2.1.0 + provided + + + org.apache.hadoop + hadoop-yarn-api + 2.7.3 + provided + + + + org.apache.spark + spark-mllib_2.11 + 2.1.0 + test-jar + test + + + junit + junit + 4.12 + test + + + + diff --git a/spark/project/build.properties b/spark/project/build.properties new file mode 100644 index 00000000..5f32afe7 --- /dev/null +++ b/spark/project/build.properties @@ -0,0 +1 @@ +sbt.version=0.13.13 \ No newline at end of file diff --git a/spark/project/plugins.sbt b/spark/project/plugins.sbt new file mode 100644 index 00000000..0536c57a --- /dev/null +++ b/spark/project/plugins.sbt @@ -0,0 +1,5 @@ +resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/" + +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3") + +addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.5") diff --git a/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 00000000..c2869067 --- /dev/null +++ b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.tensorflow.spark.datasources.tfrecords.DefaultSource \ No newline at end of file diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DataTypesConvertor.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DataTypesConvertor.scala new file mode 100644 index 
00000000..2f76aa97 --- /dev/null +++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DataTypesConvertor.scala @@ -0,0 +1,49 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.spark.datasources.tfrecords + +/** + * DataTypes supported + */ +object DataTypesConvertor { + + def toLong(value: Any): Long = { + value match { + case null => throw new IllegalArgumentException("null cannot be converted to Long") + case i: Int => i.toLong + case l: Long => l + case f: Float => f.toLong + case d: Double => d.toLong + case bd: BigDecimal => bd.toLong + case s: String => s.trim().toLong + case _ => throw new RuntimeException(s"${value.getClass.getName} toLong is not implemented") + } + } + + def toFloat(value: Any): Float = { + value match { + case null => throw new IllegalArgumentException("null cannot be converted to Float") + case i: Int => i.toFloat + case l: Long => l.toFloat + case f: Float => f + case d: Double => d.toFloat + case bd: BigDecimal => bd.toFloat + case s: String => s.trim().toFloat + case _ => throw new RuntimeException(s"${value.getClass.getName} toFloat is not implemented") + } + } +} + diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala new file mode 100644 index 00000000..ba9e6628 --- /dev/null +++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala @@ -0,0 +1,67 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.spark.datasources.tfrecords + +import org.apache.hadoop.io.{BytesWritable, NullWritable} +import org.apache.spark.sql._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StructType +import org.tensorflow.hadoop.io.TFRecordFileOutputFormat +import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowEncoder + +/** + * Provides access to TensorFlow record source + */ +class DefaultSource extends DataSourceRegister + with CreatableRelationProvider + with RelationProvider + with SchemaRelationProvider{ + + /** + * Short alias for spark-tensorflow data source. 
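+   * The alias is registered through META-INF/services/org.apache.spark.sql.sources.DataSourceRegister,
+   * so callers can write spark.read.format("tfrecords") instead of the fully qualified class name.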
+ */ + override def shortName(): String = "tfrecords" + + // Writes DataFrame as TensorFlow Records + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + + val path = parameters("path") + + //Export DataFrame as TFRecords + val features = data.rdd.map(row => { + val example = DefaultTfRecordRowEncoder.encodeTfRecord(row) + (new BytesWritable(example.toByteArray), NullWritable.get()) + }) + features.saveAsNewAPIHadoopFile[TFRecordFileOutputFormat](path) + + TensorflowRelation(parameters)(sqlContext.sparkSession) + } + + override def createRelation(sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + TensorflowRelation(parameters, Some(schema))(sqlContext.sparkSession) + } + + // Reads TensorFlow Records into DataFrame + override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): TensorflowRelation = { + TensorflowRelation(parameters)(sqlContext.sparkSession) + } +} diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowInferSchema.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowInferSchema.scala new file mode 100644 index 00000000..5a108add --- /dev/null +++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowInferSchema.scala @@ -0,0 +1,197 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.spark.datasources.tfrecords + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types._ +import org.tensorflow.example.{Example, Feature} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.Map +import scala.util.control.Exception._ + +object TensorflowInferSchema { + + /** + * Similar to the JSON schema inference. + * [[org.apache.spark.sql.execution.datasources.json.InferSchema]] + * 1. Infer type of each row + * 2. Merge row types to find common type + * 3. 
Replace any null types with string type + */ + def apply(exampleRdd: RDD[Example]): StructType = { + val startType: Map[String, DataType] = Map.empty[String, DataType] + val rootTypes: Map[String, DataType] = exampleRdd.aggregate(startType)(inferRowType, mergeFieldTypes) + val columnsList = rootTypes.map { + case (featureName, featureType) => + if (featureType == null) { + StructField(featureName, StringType) + } + else { + StructField(featureName, featureType) + } + } + StructType(columnsList.toSeq) + } + + private def inferRowType(schemaSoFar: Map[String, DataType], next: Example): Map[String, DataType] = { + next.getFeatures.getFeatureMap.asScala.map { + case (featureName, feature) => { + val currentType = inferField(feature) + if (schemaSoFar.contains(featureName)) { + val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType) + schemaSoFar(featureName) = updatedType.getOrElse(null) + } + else { + schemaSoFar += (featureName -> currentType) + } + } + } + schemaSoFar + } + + private def mergeFieldTypes(first: Map[String, DataType], second: Map[String, DataType]): Map[String, DataType] = { + //Merge two maps and do the comparison. + val mutMap = collection.mutable.Map[String, DataType]((first.keySet ++ second.keySet) + .map(key => (key, findTightestCommonType(first.getOrElse(key, null), second.getOrElse(key, null)).get)) + .toSeq: _*) + mutMap + } + + /** + * Infer Feature datatype based on field number + */ + private def inferField(feature: Feature): DataType = { + feature.getKindCase.getNumber match { + case Feature.BYTES_LIST_FIELD_NUMBER => { + StringType + } + case Feature.INT64_LIST_FIELD_NUMBER => { + parseInt64List(feature) + } + case Feature.FLOAT_LIST_FIELD_NUMBER => { + parseFloatList(feature) + } + case _ => throw new RuntimeException("unsupported type ...") + } + } + + private def parseInt64List(feature: Feature): DataType = { + val int64List = feature.getInt64List.getValueList.asScala.toArray + val length = int64List.size + if (length == 0) { + null + } + else if (length > 1) { + ArrayType(LongType) + } + else { + val fieldValue = int64List(0).toString + parseInteger(fieldValue) + } + } + + private def parseFloatList(feature: Feature): DataType = { + val floatList = feature.getFloatList.getValueList.asScala.toArray + val length = floatList.size + if (length == 0) { + null + } + else if (length > 1) { + ArrayType(DoubleType) + } + else { + val fieldValue = floatList(0).toString + parseFloat(fieldValue) + } + } + + private def parseInteger(field: String): DataType = if (allCatch.opt(field.toInt).isDefined) { + IntegerType + } + else { + parseLong(field) + } + + private def parseLong(field: String): DataType = if (allCatch.opt(field.toLong).isDefined) { + LongType + } + else { + throw new RuntimeException("Unable to parse field datatype to int64...") + } + + private def parseFloat(field: String): DataType = { + if ((allCatch opt field.toFloat).isDefined) { + FloatType + } + else { + parseDouble(field) + } + } + + private def parseDouble(field: String): DataType = if (allCatch.opt(field.toDouble).isDefined) { + DoubleType + } + else { + throw new RuntimeException("Unable to parse field datatype to float64...") + } + /** + * Copied from internal Spark api + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] + */ + private val numericPrecedence: IndexedSeq[DataType] = + IndexedSeq[DataType](IntegerType, + LongType, + FloatType, + DoubleType, + StringType) + + private def getNumericPrecedence(dataType: DataType): Int = { + dataType match { + case x 
if x.equals(IntegerType) => 0 + case x if x.equals(LongType) => 1 + case x if x.equals(FloatType) => 2 + case x if x.equals(DoubleType) => 3 + case x if x.equals(ArrayType(LongType)) => 4 + case x if x.equals(ArrayType(DoubleType)) => 5 + case x if x.equals(StringType) => 6 + case _ => throw new RuntimeException("Unable to get the precedence for given datatype...") + } + } + + /** + * Copied from internal Spark api + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] + */ + private val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + case (t1, t2) if t1 == t2 => Some(t1) + case (null, t2) => Some(t2) + case (t1, null) => Some(t1) + case (t1, t2) if t1.equals(ArrayType(LongType)) && t2.equals(ArrayType(DoubleType)) => Some(ArrayType(DoubleType)) + case (t1, t2) if t1.equals(ArrayType(DoubleType)) && t2.equals(ArrayType(LongType)) => Some(ArrayType(DoubleType)) + case (StringType, t2) => Some(StringType) + case (t1, StringType) => Some(StringType) + + // Promote numeric types to the highest of the two and all numeric types to unlimited decimal + case (t1, t2) => + val t1Precedence = getNumericPrecedence(t1) + val t2Precedence = getNumericPrecedence(t2) + val newType = if (t1Precedence > t2Precedence) t1 else t2 + Some(newType) + case _ => None + } +} + diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala new file mode 100644 index 00000000..7766aa05 --- /dev/null +++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala @@ -0,0 +1,49 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.tensorflow.spark.datasources.tfrecords + +import org.apache.hadoop.io.{BytesWritable, NullWritable} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.sources.{BaseRelation, TableScan} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SQLContext, SparkSession} +import org.tensorflow.example.Example +import org.tensorflow.hadoop.io.TFRecordFileInputFormat +import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowDecoder + + +case class TensorflowRelation(options: Map[String, String], customSchema: Option[StructType]=None)(@transient val session: SparkSession) extends BaseRelation with TableScan { + + //Import TFRecords as DataFrame happens here + lazy val (tf_rdd, tf_schema) = { + val rdd = session.sparkContext.newAPIHadoopFile(options("path"), classOf[TFRecordFileInputFormat], classOf[BytesWritable], classOf[NullWritable]) + + val exampleRdd = rdd.map { + case (bytesWritable, nullWritable) => Example.parseFrom(bytesWritable.getBytes) + } + + val finalSchema = customSchema.getOrElse(TensorflowInferSchema(exampleRdd)) + + (exampleRdd.map(example => DefaultTfRecordRowDecoder.decodeTfRecord(example, finalSchema)), finalSchema) + } + + override def sqlContext: SQLContext = session.sqlContext + + override def schema: StructType = tf_schema + + override def buildScan(): RDD[Row] = tf_rdd +} + diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala new file mode 100644 index 00000000..ccde03aa --- /dev/null +++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala @@ -0,0 +1,70 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.tensorflow.spark.datasources.tfrecords.serde + +import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row +import org.tensorflow.example._ +import scala.collection.JavaConverters._ + +trait TfRecordRowDecoder { + /** + * Decodes each TensorFlow "Example" as DataFrame "Row" + * + * Maps each feature in Example to element in Row with DataType based on custom schema or + * default mapping of Int64List, FloatList, BytesList to column data type + * + * @param example TensorFlow Example to decode + * @param schema Decode Example using specified schema + * @return a DataFrame row + */ + def decodeTfRecord(example: Example, schema: StructType): Row +} + +object DefaultTfRecordRowDecoder extends TfRecordRowDecoder { + + /** + * Decodes each TensorFlow "Example" as DataFrame "Row" + * + * Maps each feature in Example to element in Row with DataType based on custom schema + * + * @param example TensorFlow Example to decode + * @param schema Decode Example using specified schema + * @return a DataFrame row + */ + def decodeTfRecord(example: Example, schema: StructType): Row = { + val row = Array.fill[Any](schema.length)(null) + example.getFeatures.getFeatureMap.asScala.foreach { + case (featureName, feature) => + val index = schema.fieldIndex(featureName) + val colDataType = schema.fields(index).dataType + row(index) = colDataType match { + case IntegerType => IntFeatureDecoder.decode(feature) + case LongType => LongFeatureDecoder.decode(feature) + case FloatType => FloatFeatureDecoder.decode(feature) + case DoubleType => DoubleFeatureDecoder.decode(feature) + case ArrayType(IntegerType, true) => IntListFeatureDecoder.decode(feature) + case ArrayType(LongType, _) => LongListFeatureDecoder.decode(feature) + case ArrayType(FloatType, _) => FloatListFeatureDecoder.decode(feature) + case ArrayType(DoubleType, _) => DoubleListFeatureDecoder.decode(feature) + case StringType => StringFeatureDecoder.decode(feature) + case _ => throw new RuntimeException(s"Cannot convert feature to unsupported data type ${colDataType}") + } + } + Row.fromSeq(row) + } +} + diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala new file mode 100644 index 00000000..3d9072a5 --- /dev/null +++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala @@ -0,0 +1,66 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.tensorflow.spark.datasources.tfrecords.serde
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.tensorflow.example._
+
+trait TfRecordRowEncoder {
+  /**
+   * Encodes each Row as TensorFlow "Example"
+   *
+   * Maps each column in Row to one of Int64List, FloatList, BytesList based on the column data type
+   *
+   * @param row a DataFrame row
+   * @return TensorFlow Example
+   */
+  def encodeTfRecord(row: Row): Example
+}
+
+object DefaultTfRecordRowEncoder extends TfRecordRowEncoder {
+
+  /**
+   * Encodes each Row as TensorFlow "Example"
+   *
+   * Maps each column in Row to one of Int64List, FloatList, BytesList based on the column data type
+   *
+   * @param row a DataFrame row
+   * @return TensorFlow Example
+   */
+  def encodeTfRecord(row: Row): Example = {
+    val features = Features.newBuilder()
+    val example = Example.newBuilder()
+
+    row.schema.zipWithIndex.foreach {
+      case (structField, index) =>
+        val value = row.get(index)
+        val feature = structField.dataType match {
+          case IntegerType | LongType => Int64ListFeatureEncoder.encode(value)
+          case FloatType | DoubleType => FloatListFeatureEncoder.encode(value)
+          case ArrayType(IntegerType, _) | ArrayType(LongType, _) => Int64ListFeatureEncoder.encode(value)
+          case ArrayType(FloatType, _) | ArrayType(DoubleType, _) => FloatListFeatureEncoder.encode(value)
+          case _ => BytesListFeatureEncoder.encode(value)
+        }
+        features.putFeature(structField.name, feature)
+    }
+
+    example.setFeatures(features.build())
+    example.build()
+  }
+}
+
diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoder.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoder.scala
new file mode 100644
index 00000000..c390ae42
--- /dev/null
+++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoder.scala
@@ -0,0 +1,187 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.tensorflow.spark.datasources.tfrecords.serde + +import org.tensorflow.example.Feature + +import scala.collection.JavaConverters._ + +trait FeatureDecoder[T] { + /** + * Decodes each TensorFlow "Feature" to desired Scala type + * + * @param feature TensorFlow Feature + * @return Decoded feature + */ + def decode(feature: Feature): T +} + +/** + * Decode TensorFlow "Feature" to Integer + */ +object IntFeatureDecoder extends FeatureDecoder[Int] { + override def decode(feature: Feature): Int = { + require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") + try { + val int64List = feature.getInt64List.getValueList + require(int64List.size() == 1, "Length of Int64List must equal 1") + int64List.get(0).intValue() + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert feature to Int.", ex) + } + } +} + +/** + * Decode TensorFlow "Feature" to Seq[Int] + */ +object IntListFeatureDecoder extends FeatureDecoder[Seq[Int]] { + override def decode(feature: Feature): Seq[Int] = { + require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") + try { + val array = feature.getInt64List.getValueList.asScala.toArray + array.map(_.toInt) + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert feature to Seq[Int].", ex) + } + } +} + +/** + * Decode TensorFlow "Feature" to Long + */ +object LongFeatureDecoder extends FeatureDecoder[Long] { + override def decode(feature: Feature): Long = { + require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") + try { + val int64List = feature.getInt64List.getValueList + require(int64List.size() == 1, "Length of Int64List must equal 1") + int64List.get(0).longValue() + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert feature to Long.", ex) + } + } +} + +/** + * Decode TensorFlow "Feature" to Seq[Long] + */ +object LongListFeatureDecoder extends FeatureDecoder[Seq[Long]] { + override def decode(feature: Feature): Seq[Long] = { + require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") + try { + val array = feature.getInt64List.getValueList.asScala.toArray + array.map(_.toLong) + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert feature to Array[Long].", ex) + } + } +} + +/** + * Decode TensorFlow "Feature" to Float + */ +object FloatFeatureDecoder extends FeatureDecoder[Float] { + override def decode(feature: Feature): Float = { + require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") + try { + val floatList = feature.getFloatList.getValueList + require(floatList.size() == 1, "Length of FloatList must equal 1") + floatList.get(0).floatValue() + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert feature to Float.", ex) + } + } +} + +/** + * Decode TensorFlow "Feature" to Seq[Float] + */ +object FloatListFeatureDecoder extends FeatureDecoder[Seq[Float]] { + override def decode(feature: Feature): Seq[Float] = { + require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") + try { + val array = feature.getFloatList.getValueList.asScala.toArray + array.map(_.toFloat) + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert feature to Array[Float].", ex) + } + } +} + +/** + * 
Decode TensorFlow "Feature" to Double + */ +object DoubleFeatureDecoder extends FeatureDecoder[Double] { + override def decode(feature: Feature): Double = { + require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") + try { + val floatList = feature.getFloatList.getValueList + require(floatList.size() == 1, "Length of FloatList must equal 1") + floatList.get(0).doubleValue() + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert feature to Double.", ex) + } + } +} + +/** + * Decode TensorFlow "Feature" to Seq[Double] + */ +object DoubleListFeatureDecoder extends FeatureDecoder[Seq[Double]] { + override def decode(feature: Feature): Seq[Double] = { + require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") + try { + val array = feature.getFloatList.getValueList.asScala.toArray + array.map(_.toDouble) + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert feature to Array[Double].", ex) + } + } +} + +/** + * Decode TensorFlow "Feature" to String + */ +object StringFeatureDecoder extends FeatureDecoder[String] { + override def decode(feature: Feature): String = { + require(feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type ByteList") + try { + feature.getBytesList.toByteString.toStringUtf8.trim + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert feature to String.", ex) + } + } +} + diff --git a/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala new file mode 100644 index 00000000..646cbadd --- /dev/null +++ b/spark/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala @@ -0,0 +1,118 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.tensorflow.spark.datasources.tfrecords.serde + +import org.tensorflow.example.{BytesList, Feature, FloatList, Int64List} +import org.tensorflow.hadoop.shaded.protobuf.ByteString +import org.tensorflow.spark.datasources.tfrecords.DataTypesConvertor + +trait FeatureEncoder { + /** + * Encodes input value as TensorFlow "Feature" + * + * Maps input value to one of Int64List, FloatList, BytesList + * + * @param value Input value + * @return TensorFlow Feature + */ + def encode(value: Any): Feature +} + +/** + * Encode input value to Int64List + */ +object Int64ListFeatureEncoder extends FeatureEncoder { + override def encode(value: Any): Feature = { + try { + val int64List = value match { + case i: Int => Int64List.newBuilder().addValue(i.toLong).build() + case l: Long => Int64List.newBuilder().addValue(l).build() + case arr: scala.collection.mutable.WrappedArray[_] => toInt64List(arr.toArray[Any]) + case arr: Array[_] => toInt64List(arr) + case seq: Seq[_] => toInt64List(seq.toArray[Any]) + case _ => throw new RuntimeException(s"Cannot convert object $value to Int64List") + } + Feature.newBuilder().setInt64List(int64List).build() + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to Int64List feature.", ex) + } + } + + private def toInt64List[T](arr: Array[T]): Int64List = { + val intListBuilder = Int64List.newBuilder() + arr.foreach(x => { + require(x != null, "Int64List with null values is not supported") + val longValue = DataTypesConvertor.toLong(x) + intListBuilder.addValue(longValue) + }) + intListBuilder.build() + } +} + +/** + * Encode input value to FloatList + */ +object FloatListFeatureEncoder extends FeatureEncoder { + override def encode(value: Any): Feature = { + try { + val floatList = value match { + case i: Int => FloatList.newBuilder().addValue(i.toFloat).build() + case l: Long => FloatList.newBuilder().addValue(l.toFloat).build() + case f: Float => FloatList.newBuilder().addValue(f).build() + case d: Double => FloatList.newBuilder().addValue(d.toFloat).build() + case arr: scala.collection.mutable.WrappedArray[_] => toFloatList(arr.toArray[Any]) + case arr: Array[_] => toFloatList(arr) + case seq: Seq[_] => toFloatList(seq.toArray[Any]) + case _ => throw new RuntimeException(s"Cannot convert object $value to FloatList") + } + Feature.newBuilder().setFloatList(floatList).build() + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to FloatList feature.", ex) + } + } + + private def toFloatList[T](arr: Array[T]): FloatList = { + val floatListBuilder = FloatList.newBuilder() + arr.foreach(x => { + require(x != null, "FloatList with null values is not supported") + val longValue = DataTypesConvertor.toFloat(x) + floatListBuilder.addValue(longValue) + }) + floatListBuilder.build() + } +} + +/** + * Encode input value to ByteList + */ +object BytesListFeatureEncoder extends FeatureEncoder { + override def encode(value: Any): Feature = { + try { + val byteList = BytesList.newBuilder().addValue(ByteString.copyFrom(value.toString.getBytes)).build() + Feature.newBuilder().setBytesList(byteList).build() + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to ByteList feature.", ex) + } + } +} + + diff --git a/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/SharedSparkSessionSuite.scala 
b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/SharedSparkSessionSuite.scala new file mode 100644 index 00000000..e87c815b --- /dev/null +++ b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/SharedSparkSessionSuite.scala @@ -0,0 +1,45 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.spark.datasources.tfrecords + +import java.io.File + +import org.apache.commons.io.FileUtils +import org.apache.spark.SharedSparkSession +import org.junit.{After, Before} +import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike} + + +trait BaseSuite extends WordSpecLike with Matchers with BeforeAndAfterAll + +class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite { + val TF_SANDBOX_DIR = "tf-sandbox" + val file = new File(TF_SANDBOX_DIR) + + @Before + override def beforeAll() = { + super.setUp() + FileUtils.deleteQuietly(file) + file.mkdirs() + } + + @After + override def afterAll() = { + FileUtils.deleteQuietly(file) + super.tearDown() + } +} + diff --git a/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowSuite.scala b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowSuite.scala new file mode 100644 index 00000000..0e37968e --- /dev/null +++ b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowSuite.scala @@ -0,0 +1,210 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.tensorflow.spark.datasources.tfrecords + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} +import org.tensorflow.example._ +import org.tensorflow.hadoop.shaded.protobuf.ByteString +import org.tensorflow.spark.datasources.tfrecords.serde.{DefaultTfRecordRowDecoder, DefaultTfRecordRowEncoder} + +import scala.collection.JavaConverters._ + +class TensorflowSuite extends SharedSparkSessionSuite { + + "Spark TensorFlow module" should { + + "Test Import/Export" in { + + val path = s"$TF_SANDBOX_DIR/output25.tfr" + val testRows: Array[Row] = Array( + new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")), + new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2"))) + + val schema = StructType(List( + StructField("id", IntegerType), + StructField("IntegerTypelabel", IntegerType), + StructField("LongTypelabel", LongType), + StructField("FloatTypelabel", FloatType), + StructField("DoubleTypelabel", DoubleType), + StructField("vectorlabel", ArrayType(DoubleType, true)), + StructField("name", StringType))) + + val rdd = spark.sparkContext.parallelize(testRows) + + val df: DataFrame = spark.createDataFrame(rdd, schema) + df.write.format("tfrecords").save(path) + + //If schema is not provided. It will automatically infer schema + val importedDf: DataFrame = spark.read.format("tfrecords").schema(schema).load(path) + val actualDf = importedDf.select("id", "IntegerTypelabel", "LongTypelabel", "FloatTypelabel", "DoubleTypelabel", "vectorlabel", "name").sort("name") + + val expectedRows = df.collect() + val actualRows = actualDf.collect() + + expectedRows should equal(actualRows) + } + + "Encode given Row as TensorFlow example" in { + val schemaStructType = StructType(Array( + StructField("IntegerTypelabel", IntegerType), + StructField("LongTypelabel", LongType), + StructField("FloatTypelabel", FloatType), + StructField("DoubleTypelabel", DoubleType), + StructField("vectorlabel", ArrayType(DoubleType, true)), + StructField("strlabel", StringType) + )) + val doubleArray = Array(1.1, 111.1, 11111.1) + val expectedFloatArray = Array(1.1F, 111.1F, 11111.1F) + + val rowWithSchema = new GenericRowWithSchema(Array[Any](1, 23L, 10.0F, 14.0, doubleArray, "r1"), schemaStructType) + + //Encode Sql Row to TensorFlow example + val example = DefaultTfRecordRowEncoder.encodeTfRecord(rowWithSchema) + import org.tensorflow.example.Feature + + //Verify each Datatype converted to TensorFlow datatypes + val featureMap = example.getFeatures.getFeatureMap.asScala + assert(featureMap("IntegerTypelabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) + assert(featureMap("IntegerTypelabel").getInt64List.getValue(0).toInt == 1) + + assert(featureMap("LongTypelabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) + assert(featureMap("LongTypelabel").getInt64List.getValue(0).toInt == 23) + + assert(featureMap("FloatTypelabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert(featureMap("FloatTypelabel").getFloatList.getValue(0) == 10.0F) + + assert(featureMap("DoubleTypelabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert(featureMap("DoubleTypelabel").getFloatList.getValue(0) == 14.0F) + + assert(featureMap("vectorlabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert(featureMap("vectorlabel").getFloatList.getValueList.toArray === expectedFloatArray) + + 
assert(featureMap("strlabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) + assert(featureMap("strlabel").getBytesList.toByteString.toStringUtf8.trim == "r1") + + } + + "Throw an exception for a vector with null values during Encode" in { + intercept[Exception] { + val schemaStructType = StructType(Array( + StructField("vectorlabel", ArrayType(DoubleType, true)) + )) + val doubleArray = Array(1.1, null, 111.1, null, 11111.1) + + val rowWithSchema = new GenericRowWithSchema(Array[Any](doubleArray), schemaStructType) + + //Throws NullPointerException + DefaultTfRecordRowEncoder.encodeTfRecord(rowWithSchema) + } + } + + "Decode given TensorFlow Example as Row" in { + + //Here Vector with null's are not supported + val expectedRow = new GenericRow(Array[Any](1, 23L, 10.0F, 14.0, Seq(1.0, 2.0), "r1")) + + val schema = StructType(List( + StructField("IntegerTypelabel", IntegerType), + StructField("LongTypelabel", LongType), + StructField("FloatTypelabel", FloatType), + StructField("DoubleTypelabel", DoubleType), + StructField("vectorlabel", ArrayType(DoubleType)), + StructField("strlabel", StringType))) + + //Build example + val intFeature = Int64List.newBuilder().addValue(1) + val longFeature = Int64List.newBuilder().addValue(23L) + val floatFeature = FloatList.newBuilder().addValue(10.0F) + val doubleFeature = FloatList.newBuilder().addValue(14.0F) + val vectorFeature = FloatList.newBuilder().addValue(1F).addValue(2F).build() + val strFeature = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build() + val features = Features.newBuilder() + .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature).build()) + .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature).build()) + .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature).build()) + .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature).build()) + .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature).build()) + .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature).build()) + .build() + val example = Example.newBuilder() + .setFeatures(features) + .build() + + //Decode TensorFlow example to Sql Row + val actualRow = DefaultTfRecordRowDecoder.decodeTfRecord(example, schema) + actualRow should equal(expectedRow) + } + + "Check infer schema" in { + + //Build example1 + val intFeature1 = Int64List.newBuilder().addValue(1) + val longFeature1 = Int64List.newBuilder().addValue(Int.MaxValue + 10L) + val floatFeature1 = FloatList.newBuilder().addValue(10.0F) + val doubleFeature1 = FloatList.newBuilder().addValue(14.0F) + val vectorFeature1 = FloatList.newBuilder().addValue(1F).build() + val strFeature1 = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build() + val features1 = Features.newBuilder() + .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature1).build()) + .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature1).build()) + .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature1).build()) + .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature1).build()) + .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature1).build()) + .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature1).build()) + .build() + val example1 = Example.newBuilder() + .setFeatures(features1) + .build() + + //Build example2 + val intFeature2 = 
Int64List.newBuilder().addValue(2) + val longFeature2 = Int64List.newBuilder().addValue(24) + val floatFeature2 = FloatList.newBuilder().addValue(12.0F) + val doubleFeature2 = FloatList.newBuilder().addValue(Float.MaxValue + 15) + val vectorFeature2 = FloatList.newBuilder().addValue(2F).addValue(2F).build() + val strFeature2 = BytesList.newBuilder().addValue(ByteString.copyFrom("r2".getBytes)).build() + val features2 = Features.newBuilder() + .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature2).build()) + .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature2).build()) + .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature2).build()) + .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature2).build()) + .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature2).build()) + .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature2).build()) + .build() + val example2 = Example.newBuilder() + .setFeatures(features2) + .build() + + val exampleRDD: RDD[Example] = spark.sparkContext.parallelize(List(example1, example2)) + + val actualSchema = TensorflowInferSchema(exampleRDD) + + //Verify each TensorFlow Datatype is inferred as one of our Datatype + actualSchema.fields.map { colum => + colum.name match { + case "IntegerTypelabel" => colum.dataType.equals(IntegerType) + case "LongTypelabel" => colum.dataType.equals(LongType) + case "FloatTypelabel" | "DoubleTypelabel" | "vectorlabel" => colum.dataType.equals(FloatType) + case "strlabel" => colum.dataType.equals(StringType) + } + } + } + } +} + diff --git a/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoderTest.scala b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoderTest.scala new file mode 100644 index 00000000..954803d1 --- /dev/null +++ b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoderTest.scala @@ -0,0 +1,209 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.tensorflow.spark.datasources.tfrecords.serde + +import org.scalatest.{Matchers, WordSpec} +import org.tensorflow.example.{BytesList, Feature, FloatList, Int64List} +import org.tensorflow.hadoop.shaded.protobuf.ByteString + +class FeatureDecoderTest extends WordSpec with Matchers { + + "Int Feature decoder" should { + + "Decode Feature to Int" in { + val int64List = Int64List.newBuilder().addValue(4).build() + val intFeature = Feature.newBuilder().setInt64List(int64List).build() + IntFeatureDecoder.decode(intFeature) should equal(4) + } + + "Throw an exception if length of feature array exceeds 1" in { + intercept[Exception] { + val int64List = Int64List.newBuilder().addValue(4).addValue(7).build() + val intFeature = Feature.newBuilder().setInt64List(int64List).build() + IntFeatureDecoder.decode(intFeature) + } + } + + "Throw an exception if feature is not an Int64List" in { + intercept[Exception] { + val floatList = FloatList.newBuilder().addValue(4).build() + val floatFeature = Feature.newBuilder().setFloatList(floatList).build() + IntFeatureDecoder.decode(floatFeature) + } + } + } + + "Int List Feature decoder" should { + + "Decode Feature to Int List" in { + val int64List = Int64List.newBuilder().addValue(3).addValue(9).build() + val intFeature = Feature.newBuilder().setInt64List(int64List).build() + IntListFeatureDecoder.decode(intFeature) should equal(Seq(3,9)) + } + + "Throw an exception if feature is not an Int64List" in { + intercept[Exception] { + val floatList = FloatList.newBuilder().addValue(4).build() + val floatFeature = Feature.newBuilder().setFloatList(floatList).build() + IntListFeatureDecoder.decode(floatFeature) + } + } + } + + "Long Feature decoder" should { + + "Decode Feature to Long" in { + val int64List = Int64List.newBuilder().addValue(5L).build() + val intFeature = Feature.newBuilder().setInt64List(int64List).build() + LongFeatureDecoder.decode(intFeature) should equal(5L) + } + + "Throw an exception if length of feature array exceeds 1" in { + intercept[Exception] { + val int64List = Int64List.newBuilder().addValue(4L).addValue(10L).build() + val intFeature = Feature.newBuilder().setInt64List(int64List).build() + LongFeatureDecoder.decode(intFeature) + } + } + + "Throw an exception if feature is not an Int64List" in { + intercept[Exception] { + val floatList = FloatList.newBuilder().addValue(4).build() + val floatFeature = Feature.newBuilder().setFloatList(floatList).build() + LongFeatureDecoder.decode(floatFeature) + } + } + } + + "Long List Feature decoder" should { + + "Decode Feature to Long List" in { + val int64List = Int64List.newBuilder().addValue(3L).addValue(Int.MaxValue+10L).build() + val intFeature = Feature.newBuilder().setInt64List(int64List).build() + LongListFeatureDecoder.decode(intFeature) should equal(Seq(3L,Int.MaxValue+10L)) + } + + "Throw an exception if feature is not an Int64List" in { + intercept[Exception] { + val floatList = FloatList.newBuilder().addValue(4).build() + val floatFeature = Feature.newBuilder().setFloatList(floatList).build() + LongListFeatureDecoder.decode(floatFeature) + } + } + } + + "Float Feature decoder" should { + + "Decode Feature to Float" in { + val floatList = FloatList.newBuilder().addValue(2.5F).build() + val floatFeature = Feature.newBuilder().setFloatList(floatList).build() + FloatFeatureDecoder.decode(floatFeature) should equal(2.5F) + } + + "Throw an exception if length of feature array exceeds 1" in { + intercept[Exception] { + val floatList = 
+        val floatList = FloatList.newBuilder().addValue(1.5F).addValue(3.33F).build()
+        val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+        FloatFeatureDecoder.decode(floatFeature)
+      }
+    }
+
+    "Throw an exception if feature is not a FloatList" in {
+      intercept[Exception] {
+        val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
+        val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
+        FloatFeatureDecoder.decode(bytesFeature)
+      }
+    }
+  }
+
+  "Float List Feature decoder" should {
+
+    "Decode Feature to Float List" in {
+      val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.3F).build()
+      val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+      FloatListFeatureDecoder.decode(floatFeature) should equal(Seq(2.5F, 4.3F))
+    }
+
+    "Throw an exception if feature is not a FloatList" in {
+      intercept[Exception] {
+        val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
+        val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
+        FloatListFeatureDecoder.decode(bytesFeature)
+      }
+    }
+  }
+
+  "Double Feature decoder" should {
+
+    "Decode Feature to Double" in {
+      val floatList = FloatList.newBuilder().addValue(2.5F).build()
+      val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+      DoubleFeatureDecoder.decode(floatFeature) should equal(2.5d)
+    }
+
+    "Throw an exception if length of feature array exceeds 1" in {
+      intercept[Exception] {
+        val floatList = FloatList.newBuilder().addValue(1.5F).addValue(3.33F).build()
+        val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+        DoubleFeatureDecoder.decode(floatFeature)
+      }
+    }
+
+    "Throw an exception if feature is not a FloatList" in {
+      intercept[Exception] {
+        val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
+        val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
+        DoubleFeatureDecoder.decode(bytesFeature)
+      }
+    }
+  }
+
+  "Double List Feature decoder" should {
+
+    "Decode Feature to Double List" in {
+      val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.0F).build()
+      val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+      DoubleListFeatureDecoder.decode(floatFeature) should equal(Seq(2.5d, 4.0d))
+    }
+
+    "Throw an exception if feature is not a FloatList" in {
+      intercept[Exception] {
+        val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
+        val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
+        DoubleListFeatureDecoder.decode(bytesFeature)
+      }
+    }
+  }
+
+  "Bytes List Feature decoder" should {
+
+    "Decode Feature to Bytes List" in {
+      val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
+      val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
+      StringFeatureDecoder.decode(bytesFeature) should equal("str-input")
+    }
+
+    "Throw an exception if feature is not a BytesList" in {
+      intercept[Exception] {
+        val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.0F).build()
+        val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
+        StringFeatureDecoder.decode(floatFeature)
+      }
+    }
+  }
+}
+
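As a quick orientation for the decoders tested above, here is a minimal sketch of calling them directly. The decoder names and behavior come straight from the tests; the `DecoderSketch` wrapper object is an illustrative assumption, not part of this patch.

```scala
import org.tensorflow.example.{Feature, Int64List}
import org.tensorflow.spark.datasources.tfrecords.serde.{IntFeatureDecoder, IntListFeatureDecoder}

// Illustrative wrapper only; assumes the connector jar and tensorflow-hadoop protos on the classpath.
object DecoderSketch extends App {
  // A Feature holding exactly one Int64 value decodes to a single Int.
  val single = Feature.newBuilder()
    .setInt64List(Int64List.newBuilder().addValue(4))
    .build()
  assert(IntFeatureDecoder.decode(single) == 4)

  // A Feature holding several Int64 values decodes to a Seq via the list decoder.
  val many = Feature.newBuilder()
    .setInt64List(Int64List.newBuilder().addValue(3).addValue(9))
    .build()
  assert(IntListFeatureDecoder.decode(many) == Seq(3, 9))
}
```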
diff --git a/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala
new file mode 100644
index 00000000..4c09d568
--- /dev/null
+++ b/spark/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala
@@ -0,0 +1,99 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords.serde
+
+import org.scalatest.{Matchers, WordSpec}
+
+import scala.collection.JavaConverters._
+
+class FeatureEncoderTest extends WordSpec with Matchers {
+
+  "Int64List feature encoder" should {
+    "Encode inputs to Int64List" in {
+      val intFeature = Int64ListFeatureEncoder.encode(5)
+      val longFeature = Int64ListFeatureEncoder.encode(10L)
+      val longListFeature = Int64ListFeatureEncoder.encode(Seq(3L, 5L, 6L))
+
+      intFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(5L))
+      longFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(10L))
+      longListFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(3L, 5L, 6L))
+    }
+
+    "Throw an exception when inputs contain null" in {
+      intercept[Exception] {
+        Int64ListFeatureEncoder.encode(null)
+      }
+      intercept[Exception] {
+        Int64ListFeatureEncoder.encode(Seq(3, null, 6))
+      }
+    }
+
+    "Throw an exception for non-numeric inputs" in {
+      intercept[Exception] {
+        Int64ListFeatureEncoder.encode("bad-input")
+      }
+    }
+  }
+
+  "FloatList feature encoder" should {
+    "Encode inputs to FloatList" in {
+      val intFeature = FloatListFeatureEncoder.encode(5)
+      val longFeature = FloatListFeatureEncoder.encode(10L)
+      val floatFeature = FloatListFeatureEncoder.encode(2.5F)
+      val doubleFeature = FloatListFeatureEncoder.encode(14.6)
+      val floatListFeature = FloatListFeatureEncoder.encode(Seq(1.5F, 6.8F, -3.2F))
+
+      intFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(5F))
+      longFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(10F))
+      floatFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(2.5F))
+      doubleFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(14.6F))
+      floatListFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(1.5F, 6.8F, -3.2F))
+    }
+
+    "Throw an exception when inputs contain null" in {
+      intercept[Exception] {
+        FloatListFeatureEncoder.encode(null)
+      }
+      intercept[Exception] {
+        FloatListFeatureEncoder.encode(Seq(3, null, 6))
+      }
+    }
+
+    "Throw an exception for non-numeric inputs" in {
+      intercept[Exception] {
+        FloatListFeatureEncoder.encode("bad-input")
+      }
+    }
+  }
+
+  "BytesList feature encoder" should {
+    "Encode inputs to BytesList" in {
+      val longFeature = BytesListFeatureEncoder.encode(10L)
+      val longListFeature = BytesListFeatureEncoder.encode(Seq(3L, 5L, 6L))
+      val strFeature = BytesListFeatureEncoder.encode("str-input")
+
+      longFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("10")
+      longListFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("List(3, 5, 6)")
+      strFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("str-input")
+    }
+
+    "Throw an exception when inputs contain null" in {
+      intercept[Exception] {
+        BytesListFeatureEncoder.encode(null)
+      }
+    }
+  }
+}
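Taken together, the encoder and decoder suites imply a simple round trip: a value encoded into a TFRecords feature decodes back to the original Scala value. A minimal sketch follows, using only encoder/decoder names exercised in the tests above; the `EncoderRoundTripSketch` wrapper is illustrative and assumes the connector jar on the classpath.

```scala
import org.tensorflow.spark.datasources.tfrecords.serde.{FloatListFeatureDecoder, FloatListFeatureEncoder}

// Illustrative round trip: encode a Seq[Float] into a FloatList feature, then decode it back.
object EncoderRoundTripSketch extends App {
  val values = Seq(1.5F, 6.8F, -3.2F)
  val feature = FloatListFeatureEncoder.encode(values)
  val decoded = FloatListFeatureDecoder.decode(feature)
  assert(decoded == values)
  println(s"round-tripped: $decoded")
}
```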