From 05ac88ba540cc9bcdbc8989df50a9774a38b68cc Mon Sep 17 00:00:00 2001 From: Soila Kavulya Date: Thu, 29 Jun 2017 13:15:43 -0700 Subject: [PATCH] Add support for SequenceExample to Spark TensorFlow connector (#52) * Decode and encode SequenceExample feature lists to Spark SQL rows * Added format to specify recordType while writing dataframes * Add unit tests for SequenceExample and fix sql datatype for BytesList DataType for BytesList changed from StringType to ArrayType(StringType) * Add tests for sequence example * Update tests and README example for importing TensorFlow SequenceExample * Update README for Spark TensorFlow connector * Update recordType in README for Spark TensorFlow connector * Fix handling of integers in SequenceExample export * Add schema inference for SequenceExamples in Spark --- spark/spark-tensorflow-connector/README.md | 86 ++++++- .../tfrecords/DataTypesConvertor.scala | 49 ---- .../datasources/tfrecords/DefaultSource.scala | 17 +- .../tfrecords/TensorFlowInferSchema.scala | 223 ++++++++++++++++++ .../tfrecords/TensorflowInferSchema.scala | 197 ---------------- .../tfrecords/TensorflowRelation.scala | 36 ++- .../serde/DefaultTfRecordRowDecoder.scala | 94 ++++++-- .../serde/DefaultTfRecordRowEncoder.scala | 120 ++++++++-- .../tfrecords/serde/FeatureDecoder.scala | 66 ++++-- .../tfrecords/serde/FeatureEncoder.scala | 95 ++------ .../tfrecords/serde/FeatureListDecoder.scala | 74 ++++++ .../tfrecords/serde/FeatureListEncoder.scala | 73 ++++++ .../tfrecords/InferSchemaSuite.scala | 141 +++++++++++ .../tfrecords/TensorFlowSuite.scala | 84 +++++++ .../tfrecords/TensorflowSuite.scala | 210 ----------------- .../datasources/tfrecords/TestingUtils.scala | 156 ++++++++++++ .../tfrecords/serde/FeatureDecoderTest.scala | 52 +++- .../tfrecords/serde/FeatureEncoderTest.scala | 71 ++---- .../serde/FeatureListDecoderTest.scala | 161 +++++++++++++ .../serde/FeatureListEncoderTest.scala | 69 ++++++ .../serde/TfRecordRowDecoderTest.scala | 127 ++++++++++ .../serde/TfRecordRowEncoderTest.scala | 105 +++++++++ 22 files changed, 1639 insertions(+), 667 deletions(-) delete mode 100644 spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DataTypesConvertor.scala create mode 100644 spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowInferSchema.scala delete mode 100644 spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowInferSchema.scala create mode 100644 spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListDecoder.scala create mode 100644 spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListEncoder.scala create mode 100644 spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/InferSchemaSuite.scala create mode 100644 spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowSuite.scala delete mode 100644 spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowSuite.scala create mode 100644 spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TestingUtils.scala create mode 100644 spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListDecoderTest.scala create mode 100644 
spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListEncoderTest.scala
 create mode 100644 spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/TfRecordRowDecoderTest.scala
 create mode 100644 spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/TfRecordRowEncoderTest.scala

diff --git a/spark/spark-tensorflow-connector/README.md b/spark/spark-tensorflow-connector/README.md
index 5239e5aa..050ec75c 100644
--- a/spark/spark-tensorflow-connector/README.md
+++ b/spark/spark-tensorflow-connector/README.md
@@ -53,7 +53,39 @@ SBT Jars
 $SPARK_HOME/bin/spark-shell --jars target/scala-2.11/spark-tensorflow-connector-assembly-1.0.0.jar
 ```
-The following code snippet demonstrates usage.
+## Features
+This library allows reading TensorFlow records from a local or distributed filesystem into [Spark DataFrames](https://spark.apache.org/docs/latest/sql-programming-guide.html).
+When reading TensorFlow records into a Spark DataFrame, the API accepts several options:
+* `load`: input path to TensorFlow records. As with other Spark data sources, the path can contain standard Hadoop globbing expressions.
+* `schema`: optional schema of the TensorFlow records, defined as a Spark StructType. If not provided, the schema is inferred from the records.
+* `recordType`: input format of the TensorFlow records; defaults to `Example`. Possible values are:
+  * `Example`: TensorFlow [Example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
+  * `SequenceExample`: TensorFlow [SequenceExample](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
+
+When writing a Spark DataFrame to TensorFlow records, the API accepts several options:
+* `save`: output path for TensorFlow records on a local or distributed filesystem.
+* `recordType`: output format of the TensorFlow records; defaults to `Example`. Possible values are:
+  * `Example`: TensorFlow [Example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
+  * `SequenceExample`: TensorFlow [SequenceExample](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
+
+## Schema inference
+This library supports automatic schema inference when reading TensorFlow records into Spark DataFrames.
+Schema inference is expensive since it requires an extra pass through the data.
+
+The schema inference rules are described in the table below:
+
+| TFRecordType             | Feature Type             | Inferred Spark Data Type |
+| ------------------------ |:-------------------------|:-------------------------|
+| Example, SequenceExample | Int64List                | LongType if all lists have length=1, else ArrayType(LongType) |
+| Example, SequenceExample | FloatList                | FloatType if all lists have length=1, else ArrayType(FloatType) |
+| Example, SequenceExample | BytesList                | StringType if all lists have length=1, else ArrayType(StringType) |
+| SequenceExample          | FeatureList of Int64List | ArrayType(ArrayType(LongType)) |
+| SequenceExample          | FeatureList of FloatList | ArrayType(ArrayType(FloatType)) |
+| SequenceExample          | FeatureList of BytesList | ArrayType(ArrayType(StringType)) |
+
+## Usage Examples
+
+The code snippets below demonstrate usage, starting with a minimal round trip and then a fuller example on test data.
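+A minimal round-trip sketch (assuming an existing `spark` session and a DataFrame `df`; the output path is arbitrary):
+
+```scala
+// Minimal sketch: write df as TensorFlow Example records, then read them
+// back, letting the connector infer the schema from the records.
+df.write.format("tfrecords").option("recordType", "Example").save("sample.tfrecord")
+val readBack = spark.read.format("tfrecords").option("recordType", "Example").load("sample.tfrecord")
+readBack.show()
+```
+
+The fuller example below also builds the test data being written: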
```scala import org.apache.commons.io.FileUtils @@ -61,31 +93,67 @@ import org.apache.spark.sql.{ DataFrame, Row } import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types._ -val path = "test-output.tfr" +val path = "test-output.tfrecord" val testRows: Array[Row] = Array( new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")), new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2"))) val schema = StructType(List(StructField("id", IntegerType), - StructField("IntegerTypelabel", IntegerType), - StructField("LongTypelabel", LongType), - StructField("FloatTypelabel", FloatType), - StructField("DoubleTypelabel", DoubleType), - StructField("vectorlabel", ArrayType(DoubleType, true)), + StructField("IntegerTypeLabel", IntegerType), + StructField("LongTypeLabel", LongType), + StructField("FloatTypeLabel", FloatType), + StructField("DoubleTypeLabel", DoubleType), + StructField("VectorLabel", ArrayType(DoubleType, true)), StructField("name", StringType))) val rdd = spark.sparkContext.parallelize(testRows) //Save DataFrame as TFRecords val df: DataFrame = spark.createDataFrame(rdd, schema) -df.write.format("tfrecords").save(path) +df.write.format("tfrecords").option("recordType", "Example").save(path) //Read TFRecords into DataFrame. //The DataFrame schema is inferred from the TFRecords if no custom schema is provided. -val importedDf1: DataFrame = spark.read.format("tfrecords").load(path) +val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load(path) importedDf1.show() //Read TFRecords into DataFrame using custom schema val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path) importedDf2.show() +``` + +#### Loading YouTube-8M dataset to Spark +Here's how to import the [YouTube-8M](https://research.google.com/youtube8m/) dataset into a Spark DataFrame. 
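+First download one training shard of the video-level and the frame-level data: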
+```sh
+curl http://us.data.yt8m.org/1/video_level/train/train-0.tfrecord > /tmp/video_level-train-0.tfrecord
+curl http://us.data.yt8m.org/1/frame_level/train/train-0.tfrecord > /tmp/frame_level-train-0.tfrecord
 ```
+
+```scala
+import org.apache.commons.io.FileUtils
+import org.apache.spark.sql.{ DataFrame, Row }
+import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.types._
+
+//Import Video-level Example dataset into DataFrame
+val videoSchema = StructType(List(StructField("video_id", StringType),
+  StructField("labels", ArrayType(IntegerType, true)),
+  StructField("mean_rgb", ArrayType(FloatType, true)),
+  StructField("mean_audio", ArrayType(FloatType, true))))
+val videoDf: DataFrame = spark.read.format("tfrecords").schema(videoSchema).option("recordType", "Example").load("file:///tmp/video_level-train-0.tfrecord")
+videoDf.show()
+videoDf.write.format("tfrecords").option("recordType", "Example").save("youtube-8m-video.tfrecord")
+val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").schema(videoSchema).load("youtube-8m-video.tfrecord")
+importedDf1.show()
+
+//Import Frame-level SequenceExample dataset into DataFrame
+val frameSchema = StructType(List(StructField("video_id", StringType),
+  StructField("labels", ArrayType(IntegerType, true)),
+  StructField("rgb", ArrayType(ArrayType(StringType, true), true)),
+  StructField("audio", ArrayType(ArrayType(StringType, true), true))))
+val frameDf: DataFrame = spark.read.format("tfrecords").schema(frameSchema).option("recordType", "SequenceExample").load("file:///tmp/frame_level-train-0.tfrecord")
+frameDf.show()
+frameDf.write.format("tfrecords").option("recordType", "SequenceExample").save("youtube-8m-frame.tfrecord")
+val importedDf2: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").schema(frameSchema).load("youtube-8m-frame.tfrecord")
+importedDf2.show()
+```
\ No newline at end of file
diff --git a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DataTypesConvertor.scala b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DataTypesConvertor.scala
deleted file mode 100644
index 2f76aa97..00000000
--- a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DataTypesConvertor.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-/**
- * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *       http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ -package org.tensorflow.spark.datasources.tfrecords - -/** - * DataTypes supported - */ -object DataTypesConvertor { - - def toLong(value: Any): Long = { - value match { - case null => throw new IllegalArgumentException("null cannot be converted to Long") - case i: Int => i.toLong - case l: Long => l - case f: Float => f.toLong - case d: Double => d.toLong - case bd: BigDecimal => bd.toLong - case s: String => s.trim().toLong - case _ => throw new RuntimeException(s"${value.getClass.getName} toLong is not implemented") - } - } - - def toFloat(value: Any): Float = { - value match { - case null => throw new IllegalArgumentException("null cannot be converted to Float") - case i: Int => i.toFloat - case l: Long => l.toFloat - case f: Float => f - case d: Double => d.toFloat - case bd: BigDecimal => bd.toFloat - case s: String => s.trim().toFloat - case _ => throw new RuntimeException(s"${value.getClass.getName} toFloat is not implemented") - } - } -} - diff --git a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala index ba9e6628..d34fef8e 100644 --- a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala +++ b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala @@ -44,23 +44,34 @@ class DefaultSource extends DataSourceRegister val path = parameters("path") + val recordType = parameters.getOrElse("recordType", "Example") + //Export DataFrame as TFRecords val features = data.rdd.map(row => { - val example = DefaultTfRecordRowEncoder.encodeTfRecord(row) - (new BytesWritable(example.toByteArray), NullWritable.get()) + recordType match { + case "Example" => + val example = DefaultTfRecordRowEncoder.encodeExample(row) + (new BytesWritable(example.toByteArray), NullWritable.get()) + case "SequenceExample" => + val sequenceExample = DefaultTfRecordRowEncoder.encodeSequenceExample(row) + (new BytesWritable(sequenceExample.toByteArray), NullWritable.get()) + case _ => + throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be Example or SequenceExample") + } }) features.saveAsNewAPIHadoopFile[TFRecordFileOutputFormat](path) TensorflowRelation(parameters)(sqlContext.sparkSession) } + // Reads TensorFlow Records into DataFrame with Custom Schema override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = { TensorflowRelation(parameters, Some(schema))(sqlContext.sparkSession) } - // Reads TensorFlow Records into DataFrame + // Reads TensorFlow Records into DataFrame with schema inferred override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): TensorflowRelation = { TensorflowRelation(parameters)(sqlContext.sparkSession) } diff --git a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowInferSchema.scala b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowInferSchema.scala new file mode 100644 index 00000000..4007a0ec --- /dev/null +++ b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowInferSchema.scala @@ -0,0 +1,223 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.spark.datasources.tfrecords + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types._ +import org.tensorflow.example.{FeatureList, SequenceExample, Example, Feature} +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.reflect.runtime.universe._ + +object TensorFlowInferSchema { + + /** + * Similar to the JSON schema inference. + * [[org.apache.spark.sql.execution.datasources.json.InferSchema]] + * 1. Infer type of each row + * 2. Merge row types to find common type + * 3. Replace any null types with string type + */ + def apply[T : TypeTag](rdd: RDD[T]): StructType = { + val startType: mutable.Map[String, DataType] = mutable.Map.empty[String, DataType] + + val rootTypes: mutable.Map[String, DataType] = typeOf[T] match { + case t if t =:= typeOf[Example] => { + rdd.asInstanceOf[RDD[Example]].aggregate(startType)(inferExampleRowType, mergeFieldTypes) + } + case t if t =:= typeOf[SequenceExample] => { + rdd.asInstanceOf[RDD[SequenceExample]].aggregate(startType)(inferSequenceExampleRowType, mergeFieldTypes) + } + case _ => throw new IllegalArgumentException(s"Unsupported recordType: recordType can be Example or SequenceExample") + } + + val columnsList = rootTypes.map { + case (featureName, featureType) => + if (featureType == null) { + StructField(featureName, StringType) + } + else { + StructField(featureName, featureType) + } + } + StructType(columnsList.toSeq) + } + + private def inferSequenceExampleRowType(schemaSoFar: mutable.Map[String, DataType], + next: SequenceExample): mutable.Map[String, DataType] = { + val featureMap = next.getContext.getFeatureMap.asScala + val updatedSchema = inferFeatureTypes(schemaSoFar, featureMap) + + val featureListMap = next.getFeatureLists.getFeatureListMap.asScala + inferFeatureListTypes(updatedSchema, featureListMap) + } + + private def inferExampleRowType(schemaSoFar: mutable.Map[String, DataType], + next: Example): mutable.Map[String, DataType] = { + val featureMap = next.getFeatures.getFeatureMap.asScala + inferFeatureTypes(schemaSoFar, featureMap) + } + + private def inferFeatureTypes(schemaSoFar: mutable.Map[String, DataType], + featureMap: mutable.Map[String, Feature]): mutable.Map[String, DataType] = { + featureMap.foreach { + case (featureName, feature) => { + val currentType = inferField(feature) + if (schemaSoFar.contains(featureName)) { + val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType) + schemaSoFar(featureName) = updatedType.orNull + } + else { + schemaSoFar += (featureName -> currentType) + } + } + } + schemaSoFar + } + + def inferFeatureListTypes(schemaSoFar: mutable.Map[String, DataType], + featureListMap: mutable.Map[String, FeatureList]): mutable.Map[String, DataType] = { + featureListMap.foreach { + case (featureName, featureList) => { + val featureType = featureList.getFeatureList.asScala.map(f => inferField(f)) + .reduceLeft((a, b) => findTightestCommonType(a, 
b).orNull) + val currentType = featureType match { + case ArrayType(_, _) => ArrayType(featureType) + case _ => ArrayType(ArrayType(featureType)) + } + if (schemaSoFar.contains(featureName)) { + val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType) + schemaSoFar(featureName) = updatedType.orNull + } + else { + schemaSoFar += (featureName -> currentType) + } + } + } + schemaSoFar + } + + private def mergeFieldTypes(first: mutable.Map[String, DataType], + second: mutable.Map[String, DataType]): mutable.Map[String, DataType] = { + //Merge two maps and do the comparison. + val mutMap = collection.mutable.Map[String, DataType]((first.keySet ++ second.keySet) + .map(key => (key, findTightestCommonType(first.getOrElse(key, null), second.getOrElse(key, null)).get)) + .toSeq: _*) + mutMap + } + + /** + * Infer Feature datatype based on field number + */ + private def inferField(feature: Feature): DataType = { + feature.getKindCase.getNumber match { + case Feature.BYTES_LIST_FIELD_NUMBER => { + parseBytesList(feature) + } + case Feature.INT64_LIST_FIELD_NUMBER => { + parseInt64List(feature) + } + case Feature.FLOAT_LIST_FIELD_NUMBER => { + parseFloatList(feature) + } + case _ => throw new RuntimeException("unsupported type ...") + } + } + + private def parseBytesList(feature: Feature): DataType = { + val length = feature.getBytesList.getValueCount + + if (length == 0) { + null + } + else if (length > 1) { + ArrayType(StringType) + } + else { + StringType + } + } + + private def parseInt64List(feature: Feature): DataType = { + val int64List = feature.getInt64List.getValueList.asScala.toArray + val length = int64List.length + + if (length == 0) { + null + } + else if (length > 1) { + ArrayType(LongType) + } + else { + LongType + } + } + + private def parseFloatList(feature: Feature): DataType = { + val floatList = feature.getFloatList.getValueList.asScala.toArray + val length = floatList.length + if (length == 0) { + null + } + else if (length > 1) { + ArrayType(FloatType) + } + else { + FloatType + } + } + + /** + * Copied from internal Spark api + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] + */ + private def getNumericPrecedence(dataType: DataType): Int = { + dataType match { + case LongType => 1 + case FloatType => 2 + case StringType => 3 + case ArrayType(LongType, _) => 4 + case ArrayType(FloatType, _) => 5 + case ArrayType(StringType, _) => 6 + case ArrayType(ArrayType(LongType, _), _) => 7 + case ArrayType(ArrayType(FloatType, _), _) => 8 + case ArrayType(ArrayType(StringType, _), _) => 9 + case _ => throw new RuntimeException("Unable to get the precedence for given datatype...") + } + } + + /** + * Copied from internal Spark api + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] + */ + private def findTightestCommonType(tt1: DataType, tt2: DataType) : Option[DataType] = { + val currType = (tt1, tt2) match { + case (t1, t2) if t1 == t2 => Some(t1) + case (null, t2) => Some(t2) + case (t1, null) => Some(t1) + + // Promote types based on numeric precedence + case (t1, t2) => + val t1Precedence = getNumericPrecedence(t1) + val t2Precedence = getNumericPrecedence(t2) + val newType = if (t1Precedence > t2Precedence) t1 else t2 + Some(newType) + case _ => None + } + currType + } +} + diff --git a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowInferSchema.scala b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowInferSchema.scala deleted 
file mode 100644 index 5a108add..00000000 --- a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowInferSchema.scala +++ /dev/null @@ -1,197 +0,0 @@ -/** - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - *       http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.spark.datasources.tfrecords - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types._ -import org.tensorflow.example.{Example, Feature} - -import scala.collection.JavaConverters._ -import scala.collection.mutable.Map -import scala.util.control.Exception._ - -object TensorflowInferSchema { - - /** - * Similar to the JSON schema inference. - * [[org.apache.spark.sql.execution.datasources.json.InferSchema]] - * 1. Infer type of each row - * 2. Merge row types to find common type - * 3. Replace any null types with string type - */ - def apply(exampleRdd: RDD[Example]): StructType = { - val startType: Map[String, DataType] = Map.empty[String, DataType] - val rootTypes: Map[String, DataType] = exampleRdd.aggregate(startType)(inferRowType, mergeFieldTypes) - val columnsList = rootTypes.map { - case (featureName, featureType) => - if (featureType == null) { - StructField(featureName, StringType) - } - else { - StructField(featureName, featureType) - } - } - StructType(columnsList.toSeq) - } - - private def inferRowType(schemaSoFar: Map[String, DataType], next: Example): Map[String, DataType] = { - next.getFeatures.getFeatureMap.asScala.map { - case (featureName, feature) => { - val currentType = inferField(feature) - if (schemaSoFar.contains(featureName)) { - val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType) - schemaSoFar(featureName) = updatedType.getOrElse(null) - } - else { - schemaSoFar += (featureName -> currentType) - } - } - } - schemaSoFar - } - - private def mergeFieldTypes(first: Map[String, DataType], second: Map[String, DataType]): Map[String, DataType] = { - //Merge two maps and do the comparison. 
- val mutMap = collection.mutable.Map[String, DataType]((first.keySet ++ second.keySet) - .map(key => (key, findTightestCommonType(first.getOrElse(key, null), second.getOrElse(key, null)).get)) - .toSeq: _*) - mutMap - } - - /** - * Infer Feature datatype based on field number - */ - private def inferField(feature: Feature): DataType = { - feature.getKindCase.getNumber match { - case Feature.BYTES_LIST_FIELD_NUMBER => { - StringType - } - case Feature.INT64_LIST_FIELD_NUMBER => { - parseInt64List(feature) - } - case Feature.FLOAT_LIST_FIELD_NUMBER => { - parseFloatList(feature) - } - case _ => throw new RuntimeException("unsupported type ...") - } - } - - private def parseInt64List(feature: Feature): DataType = { - val int64List = feature.getInt64List.getValueList.asScala.toArray - val length = int64List.size - if (length == 0) { - null - } - else if (length > 1) { - ArrayType(LongType) - } - else { - val fieldValue = int64List(0).toString - parseInteger(fieldValue) - } - } - - private def parseFloatList(feature: Feature): DataType = { - val floatList = feature.getFloatList.getValueList.asScala.toArray - val length = floatList.size - if (length == 0) { - null - } - else if (length > 1) { - ArrayType(DoubleType) - } - else { - val fieldValue = floatList(0).toString - parseFloat(fieldValue) - } - } - - private def parseInteger(field: String): DataType = if (allCatch.opt(field.toInt).isDefined) { - IntegerType - } - else { - parseLong(field) - } - - private def parseLong(field: String): DataType = if (allCatch.opt(field.toLong).isDefined) { - LongType - } - else { - throw new RuntimeException("Unable to parse field datatype to int64...") - } - - private def parseFloat(field: String): DataType = { - if ((allCatch opt field.toFloat).isDefined) { - FloatType - } - else { - parseDouble(field) - } - } - - private def parseDouble(field: String): DataType = if (allCatch.opt(field.toDouble).isDefined) { - DoubleType - } - else { - throw new RuntimeException("Unable to parse field datatype to float64...") - } - /** - * Copied from internal Spark api - * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] - */ - private val numericPrecedence: IndexedSeq[DataType] = - IndexedSeq[DataType](IntegerType, - LongType, - FloatType, - DoubleType, - StringType) - - private def getNumericPrecedence(dataType: DataType): Int = { - dataType match { - case x if x.equals(IntegerType) => 0 - case x if x.equals(LongType) => 1 - case x if x.equals(FloatType) => 2 - case x if x.equals(DoubleType) => 3 - case x if x.equals(ArrayType(LongType)) => 4 - case x if x.equals(ArrayType(DoubleType)) => 5 - case x if x.equals(StringType) => 6 - case _ => throw new RuntimeException("Unable to get the precedence for given datatype...") - } - } - - /** - * Copied from internal Spark api - * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] - */ - private val findTightestCommonType: (DataType, DataType) => Option[DataType] = { - case (t1, t2) if t1 == t2 => Some(t1) - case (null, t2) => Some(t2) - case (t1, null) => Some(t1) - case (t1, t2) if t1.equals(ArrayType(LongType)) && t2.equals(ArrayType(DoubleType)) => Some(ArrayType(DoubleType)) - case (t1, t2) if t1.equals(ArrayType(DoubleType)) && t2.equals(ArrayType(LongType)) => Some(ArrayType(DoubleType)) - case (StringType, t2) => Some(StringType) - case (t1, StringType) => Some(StringType) - - // Promote numeric types to the highest of the two and all numeric types to unlimited decimal - case (t1, t2) => - val t1Precedence = getNumericPrecedence(t1) - val 
t2Precedence = getNumericPrecedence(t2) - val newType = if (t1Precedence > t2Precedence) t1 else t2 - Some(newType) - case _ => None - } -} - diff --git a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala index 7766aa05..bde354f6 100644 --- a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala +++ b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala @@ -20,30 +20,44 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext, SparkSession} -import org.tensorflow.example.Example +import org.tensorflow.example.{SequenceExample, Example} import org.tensorflow.hadoop.io.TFRecordFileInputFormat import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowDecoder -case class TensorflowRelation(options: Map[String, String], customSchema: Option[StructType]=None)(@transient val session: SparkSession) extends BaseRelation with TableScan { +case class TensorflowRelation(options: Map[String, String], customSchema: Option[StructType]=None) + (@transient val session: SparkSession) extends BaseRelation with TableScan { //Import TFRecords as DataFrame happens here - lazy val (tf_rdd, tf_schema) = { + lazy val (tfRdd, tfSchema) = { val rdd = session.sparkContext.newAPIHadoopFile(options("path"), classOf[TFRecordFileInputFormat], classOf[BytesWritable], classOf[NullWritable]) - val exampleRdd = rdd.map { - case (bytesWritable, nullWritable) => Example.parseFrom(bytesWritable.getBytes) + val recordType = options.getOrElse("recordType", "Example") + + recordType match { + case "Example" => + val exampleRdd = rdd.map{case (bytesWritable, nullWritable) => + Example.parseFrom(bytesWritable.getBytes) + } + val finalSchema = customSchema.getOrElse(TensorFlowInferSchema(exampleRdd)) + val rowRdd = exampleRdd.map(example => DefaultTfRecordRowDecoder.decodeExample(example, finalSchema)) + (rowRdd, finalSchema) + case "SequenceExample" => + val sequenceExampleRdd = rdd.map{case (bytesWritable, nullWritable) => + SequenceExample.parseFrom(bytesWritable.getBytes) + } + val finalSchema = customSchema.getOrElse(TensorFlowInferSchema(sequenceExampleRdd)) + val rowRdd = sequenceExampleRdd.map(example => DefaultTfRecordRowDecoder.decodeSequenceExample(example, finalSchema)) + (rowRdd, finalSchema) + case _ => + throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be Example or SequenceExample") } - - val finalSchema = customSchema.getOrElse(TensorflowInferSchema(exampleRdd)) - - (exampleRdd.map(example => DefaultTfRecordRowDecoder.decodeTfRecord(example, finalSchema)), finalSchema) } override def sqlContext: SQLContext = session.sqlContext - override def schema: StructType = tf_schema + override def schema: StructType = tfSchema - override def buildScan(): RDD[Row] = tf_rdd + override def buildScan(): RDD[Row] = tfRdd } diff --git a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala index ccde03aa..1ce3fc2d 100644 --- 
a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala +++ b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala @@ -24,14 +24,24 @@ trait TfRecordRowDecoder { /** * Decodes each TensorFlow "Example" as DataFrame "Row" * - * Maps each feature in Example to element in Row with DataType based on custom schema or - * default mapping of Int64List, FloatList, BytesList to column data type + * Maps each feature in Example to element in Row with DataType based on custom schema * * @param example TensorFlow Example to decode * @param schema Decode Example using specified schema * @return a DataFrame row */ - def decodeTfRecord(example: Example, schema: StructType): Row + def decodeExample(example: Example, schema: StructType): Row + + /** + * Decodes each TensorFlow "SequenceExample" as DataFrame "Row" + * + * Maps each feature in SequenceExample to element in Row with DataType based on custom schema + * + * @param sequenceExample TensorFlow SequenceExample to decode + * @param schema Decode SequenceExample using specified schema + * @return a DataFrame row + */ + def decodeSequenceExample(sequenceExample: SequenceExample, schema: StructType): Row } object DefaultTfRecordRowDecoder extends TfRecordRowDecoder { @@ -45,26 +55,76 @@ object DefaultTfRecordRowDecoder extends TfRecordRowDecoder { * @param schema Decode Example using specified schema * @return a DataFrame row */ - def decodeTfRecord(example: Example, schema: StructType): Row = { + def decodeExample(example: Example, schema: StructType): Row = { val row = Array.fill[Any](schema.length)(null) example.getFeatures.getFeatureMap.asScala.foreach { case (featureName, feature) => val index = schema.fieldIndex(featureName) - val colDataType = schema.fields(index).dataType - row(index) = colDataType match { - case IntegerType => IntFeatureDecoder.decode(feature) - case LongType => LongFeatureDecoder.decode(feature) - case FloatType => FloatFeatureDecoder.decode(feature) - case DoubleType => DoubleFeatureDecoder.decode(feature) - case ArrayType(IntegerType, true) => IntListFeatureDecoder.decode(feature) - case ArrayType(LongType, _) => LongListFeatureDecoder.decode(feature) - case ArrayType(FloatType, _) => FloatListFeatureDecoder.decode(feature) - case ArrayType(DoubleType, _) => DoubleListFeatureDecoder.decode(feature) - case StringType => StringFeatureDecoder.decode(feature) - case _ => throw new RuntimeException(s"Cannot convert feature to unsupported data type ${colDataType}") - } + row(index) = decodeFeature(feature, schema, index) + } + Row.fromSeq(row) + } + + /** + * Decodes each TensorFlow "SequenceExample" as DataFrame "Row" + * + * Maps each feature in SequenceExample to element in Row with DataType based on custom schema + * + * @param sequenceExample TensorFlow SequenceExample to decode + * @param schema Decode Example using specified schema + * @return a DataFrame row + */ + def decodeSequenceExample(sequenceExample: SequenceExample, schema: StructType): Row = { + val row = Array.fill[Any](schema.length)(null) + + //Decode features + sequenceExample.getContext.getFeatureMap.asScala.foreach { + case (featureName, feature) => + val index = schema.fieldIndex(featureName) + row(index) = decodeFeature(feature, schema, index) } + + //Decode feature lists + sequenceExample.getFeatureLists.getFeatureListMap.asScala.foreach { + case (featureName, featureList) => + val index = 
schema.fieldIndex(featureName) + row(index) = decodeFeatureList(featureList, schema, index) + } + Row.fromSeq(row) } + + // Decode Feature to Scala Type based on field in schema + private def decodeFeature(feature: Feature, schema: StructType, fieldIndex: Int): Any = { + val colDataType = schema.fields(fieldIndex).dataType + + colDataType match { + case IntegerType => IntFeatureDecoder.decode(feature) + case LongType => LongFeatureDecoder.decode(feature) + case FloatType => FloatFeatureDecoder.decode(feature) + case DoubleType => DoubleFeatureDecoder.decode(feature) + case StringType => StringFeatureDecoder.decode(feature) + case ArrayType(IntegerType, _) => IntListFeatureDecoder.decode(feature) + case ArrayType(LongType, _) => LongListFeatureDecoder.decode(feature) + case ArrayType(FloatType, _) => FloatListFeatureDecoder.decode(feature) + case ArrayType(DoubleType, _) => DoubleListFeatureDecoder.decode(feature) + case ArrayType(StringType, _) => StringListFeatureDecoder.decode(feature) + case _ => throw new scala.RuntimeException(s"Cannot convert Feature to unsupported data type ${colDataType}") + } + } + + // Decode FeatureList to Scala Type based on field in schema + private def decodeFeatureList(featureList: FeatureList, schema: StructType, fieldIndex: Int): Any = { + val colDataType = schema.fields(fieldIndex).dataType + + colDataType match { + case ArrayType(ArrayType(IntegerType, _), _) => IntFeatureListDecoder.decode(featureList) + case ArrayType(ArrayType(LongType, _), _) => LongFeatureListDecoder.decode(featureList) + case ArrayType(ArrayType(FloatType, _), _) => FloatFeatureListDecoder.decode(featureList) + case ArrayType(ArrayType(DoubleType, _), _) => DoubleFeatureListDecoder.decode(featureList) + case ArrayType(ArrayType(StringType, _), _) => StringFeatureListDecoder.decode(featureList) + case _ => throw new scala.RuntimeException(s"Cannot convert FeatureList to unsupported data type ${colDataType}") + } + } } diff --git a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala index 3d9072a5..b7421cc0 100644 --- a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala +++ b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala @@ -16,6 +16,7 @@ package org.tensorflow.spark.datasources.tfrecords.serde import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ import org.tensorflow.example._ @@ -28,7 +29,17 @@ trait TfRecordRowEncoder { * @param row a DataFrame row * @return TensorFlow Example */ - def encodeTfRecord(row: Row): Example + def encodeExample(row: Row): Example + + /** + * Encodes each Row as TensorFlow "SequenceExample" + * + * Maps each column in Row to one of Int64List, FloatList, BytesList or FeatureList based on the column data type + * + * @param row a DataFrame row + * @return TensorFlow SequenceExample + */ + def encodeSequenceExample(row: Row): SequenceExample } object DefaultTfRecordRowEncoder extends TfRecordRowEncoder { @@ -41,26 +52,107 @@ object DefaultTfRecordRowEncoder extends TfRecordRowEncoder { * @param row a DataFrame row * @return TensorFlow Example */ - def encodeTfRecord(row: Row): Example = { + def encodeExample(row: Row): Example = 
{ val features = Features.newBuilder() val example = Example.newBuilder() - row.schema.zipWithIndex.map { + row.schema.zipWithIndex.foreach { case (structField, index) => - val value = row.get(index) - val feature = structField.dataType match { - case IntegerType | LongType => Int64ListFeatureEncoder.encode(value) - case FloatType | DoubleType => FloatListFeatureEncoder.encode(value) - case ArrayType(IntegerType, _) | ArrayType(LongType, _) => Int64ListFeatureEncoder.encode(value) - case ArrayType(DoubleType, _) => FloatListFeatureEncoder.encode(value) - case _ => BytesListFeatureEncoder.encode(value) - } + val feature = encodeFeature(row, structField, index) features.putFeature(structField.name, feature) } - features.build() - example.setFeatures(features) + example.setFeatures(features.build()) example.build() } -} + /** + * Encodes each Row as TensorFlow "SequenceExample" + * + * Maps each column in Row to one of Int64List, FloatList, BytesList or FeatureList based on the column data type + * + * @param row a DataFrame row + * @return TensorFlow SequenceExample + */ + def encodeSequenceExample(row: Row): SequenceExample = { + val features = Features.newBuilder() + val featureLists = FeatureLists.newBuilder() + val sequenceExample = SequenceExample.newBuilder() + + row.schema.zipWithIndex.foreach { + case (structField, index) => structField.dataType match { + case ArrayType(ArrayType(_, _), _) | ArrayType(StringType, _) => + val featureList = encodeFeatureList(row, structField, index) + featureLists.putFeatureList(structField.name, featureList) + case _ => + val feature = encodeFeature(row, structField, index) + features.putFeature(structField.name, feature) + } + } + + sequenceExample.setContext(features.build()) + sequenceExample.setFeatureLists(featureLists.build()) + sequenceExample.build() + } + + //Encode field in row to TensorFlow Feature + private def encodeFeature(row: Row, structField: StructField, index: Int): Feature = { + val feature = structField.dataType match { + case IntegerType => Int64ListFeatureEncoder.encode(Seq(row.getInt(index).toLong)) + case LongType => Int64ListFeatureEncoder.encode(Seq(row.getLong(index))) + case FloatType => FloatListFeatureEncoder.encode(Seq(row.getFloat(index))) + case DoubleType => FloatListFeatureEncoder.encode(Seq(row.getDouble(index).toFloat)) + case StringType => BytesListFeatureEncoder.encode(Seq(row.getString(index))) + case ArrayType(IntegerType, _) => + Int64ListFeatureEncoder.encode(ArrayData.toArrayData(row.get(index)).toIntArray().map(_.toLong)) + case ArrayType(LongType, _) => + Int64ListFeatureEncoder.encode(ArrayData.toArrayData(row.get(index)).toLongArray()) + case ArrayType(FloatType, _) => + FloatListFeatureEncoder.encode(ArrayData.toArrayData(row.get(index)).toFloatArray()) + case ArrayType(DoubleType, _) => + FloatListFeatureEncoder.encode(ArrayData.toArrayData(row.get(index)).toDoubleArray().map(_.toFloat)) + case ArrayType(_, _) => + BytesListFeatureEncoder.encode(ArrayData.toArrayData(row.get(index)).toArray[String](StringType)) + case _ => BytesListFeatureEncoder.encode(Seq(row.getString(index))) + } + feature + } + + //Encode field in row to TensorFlow FeatureList + def encodeFeatureList(row: Row, structField: StructField, index: Int): FeatureList = { + val featureList = structField.dataType match { + case ArrayType(ArrayType(IntegerType, _), _) => + val longArrays = ArrayData.toArrayData(row.get(index)).array.map {arr => + ArrayData.toArrayData(arr).toIntArray().map(_.toLong).toSeq + } + 
Int64FeatureListEncoder.encode(longArrays) + + case ArrayType(ArrayType(LongType, _), _) => + val longArrays = ArrayData.toArrayData(row.get(index)).array.map {arr => + ArrayData.toArrayData(arr).toLongArray().toSeq + } + Int64FeatureListEncoder.encode(longArrays) + + case ArrayType(ArrayType(FloatType, _), _) => + val floatArrays = ArrayData.toArrayData(row.get(index)).array.map {arr => + ArrayData.toArrayData(arr).toFloatArray().toSeq + } + FloatFeatureListEncoder.encode(floatArrays) + + case ArrayType(ArrayType(DoubleType, _), _) => + val floatArrays = ArrayData.toArrayData(row.get(index)).array.map {arr => + ArrayData.toArrayData(arr).toDoubleArray().map(_.toFloat).toSeq + } + FloatFeatureListEncoder.encode(floatArrays) + + case ArrayType(ArrayType(StringType, _), _) => + val arrayData = ArrayData.toArrayData(row.get(index)).array.map {arr => + ArrayData.toArrayData(arr).toArray[String](StringType).toSeq + }.toSeq + BytesFeatureListEncoder.encode(arrayData) + + case _ => throw new RuntimeException(s"Cannot convert row element ${row.get(index)} to FeatureList.") + } + featureList + } +} diff --git a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoder.scala b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoder.scala index c390ae42..766b70c8 100644 --- a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoder.scala +++ b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoder.scala @@ -16,7 +16,6 @@ package org.tensorflow.spark.datasources.tfrecords.serde import org.tensorflow.example.Feature - import scala.collection.JavaConverters._ trait FeatureDecoder[T] { @@ -34,7 +33,7 @@ trait FeatureDecoder[T] { */ object IntFeatureDecoder extends FeatureDecoder[Int] { override def decode(feature: Feature): Int = { - require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") + require(feature != null && feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") try { val int64List = feature.getInt64List.getValueList require(int64List.size() == 1, "Length of Int64List must equal 1") @@ -42,24 +41,24 @@ object IntFeatureDecoder extends FeatureDecoder[Int] { } catch { case ex: Exception => - throw new RuntimeException(s"Cannot convert feature to Int.", ex) + throw new RuntimeException(s"Cannot convert feature to integer.", ex) } } } /** - * Decode TensorFlow "Feature" to Seq[Int] + * Decode TensorFlow "Feature" to Integer array */ object IntListFeatureDecoder extends FeatureDecoder[Seq[Int]] { override def decode(feature: Feature): Seq[Int] = { - require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") + require(feature != null && feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") try { - val array = feature.getInt64List.getValueList.asScala.toArray + val array = feature.getInt64List.getValueList.asScala.toSeq array.map(_.toInt) } catch { case ex: Exception => - throw new RuntimeException(s"Cannot convert feature to Seq[Int].", ex) + throw new RuntimeException(s"Cannot convert feature to Integer array.", ex) } } } @@ -69,7 +68,7 @@ object IntListFeatureDecoder extends FeatureDecoder[Seq[Int]] { */ object LongFeatureDecoder extends FeatureDecoder[Long] { override def 
decode(feature: Feature): Long = { - require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") + require(feature != null && feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") try { val int64List = feature.getInt64List.getValueList require(int64List.size() == 1, "Length of Int64List must equal 1") @@ -83,18 +82,18 @@ object LongFeatureDecoder extends FeatureDecoder[Long] { } /** - * Decode TensorFlow "Feature" to Seq[Long] + * Decode TensorFlow "Feature" to Long array */ object LongListFeatureDecoder extends FeatureDecoder[Seq[Long]] { override def decode(feature: Feature): Seq[Long] = { - require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") + require(feature != null && feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") try { - val array = feature.getInt64List.getValueList.asScala.toArray + val array = feature.getInt64List.getValueList.asScala.toSeq array.map(_.toLong) } catch { case ex: Exception => - throw new RuntimeException(s"Cannot convert feature to Array[Long].", ex) + throw new RuntimeException(s"Cannot convert feature to Long array.", ex) } } } @@ -104,7 +103,7 @@ object LongListFeatureDecoder extends FeatureDecoder[Seq[Long]] { */ object FloatFeatureDecoder extends FeatureDecoder[Float] { override def decode(feature: Feature): Float = { - require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") + require(feature != null && feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") try { val floatList = feature.getFloatList.getValueList require(floatList.size() == 1, "Length of FloatList must equal 1") @@ -118,18 +117,18 @@ object FloatFeatureDecoder extends FeatureDecoder[Float] { } /** - * Decode TensorFlow "Feature" to Seq[Float] + * Decode TensorFlow "Feature" to Float array */ object FloatListFeatureDecoder extends FeatureDecoder[Seq[Float]] { override def decode(feature: Feature): Seq[Float] = { - require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") + require(feature != null && feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") try { - val array = feature.getFloatList.getValueList.asScala.toArray + val array = feature.getFloatList.getValueList.asScala.toSeq array.map(_.toFloat) } catch { case ex: Exception => - throw new RuntimeException(s"Cannot convert feature to Array[Float].", ex) + throw new RuntimeException(s"Cannot convert feature to Float array.", ex) } } } @@ -139,7 +138,7 @@ object FloatListFeatureDecoder extends FeatureDecoder[Seq[Float]] { */ object DoubleFeatureDecoder extends FeatureDecoder[Double] { override def decode(feature: Feature): Double = { - require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") + require(feature != null && feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") try { val floatList = feature.getFloatList.getValueList require(floatList.size() == 1, "Length of FloatList must equal 1") @@ -153,18 +152,18 @@ object DoubleFeatureDecoder extends FeatureDecoder[Double] { } /** - * Decode TensorFlow "Feature" to Seq[Double] + * Decode TensorFlow "Feature" to Double array */ object DoubleListFeatureDecoder extends 
FeatureDecoder[Seq[Double]] { override def decode(feature: Feature): Seq[Double] = { - require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") + require(feature != null && feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") try { - val array = feature.getFloatList.getValueList.asScala.toArray + val array = feature.getFloatList.getValueList.asScala.toSeq array.map(_.toDouble) } catch { case ex: Exception => - throw new RuntimeException(s"Cannot convert feature to Array[Double].", ex) + throw new RuntimeException(s"Cannot convert feature to Double array.", ex) } } } @@ -174,9 +173,11 @@ object DoubleListFeatureDecoder extends FeatureDecoder[Seq[Double]] { */ object StringFeatureDecoder extends FeatureDecoder[String] { override def decode(feature: Feature): String = { - require(feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type ByteList") + require(feature != null && feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type ByteList") try { - feature.getBytesList.toByteString.toStringUtf8.trim + val bytesList = feature.getBytesList.getValueList + require(bytesList.size() == 1, "Length of BytesList must equal 1") + bytesList.get(0).toStringUtf8 } catch { case ex: Exception => @@ -185,3 +186,20 @@ object StringFeatureDecoder extends FeatureDecoder[String] { } } +/** + * Decode TensorFlow "Feature" to String array + */ +object StringListFeatureDecoder extends FeatureDecoder[Seq[String]] { + override def decode(feature: Feature): Seq[String] = { + require(feature != null && feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type ByteList") + try { + val array = feature.getBytesList.getValueList.asScala.toSeq + array.map(_.toStringUtf8) + } + catch { + case ex: Exception => + throw new RuntimeException(s"Cannot convert feature to String array.", ex) + } + } +} + diff --git a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala index 646cbadd..8af04d84 100644 --- a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala +++ b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala @@ -15,11 +15,10 @@ */ package org.tensorflow.spark.datasources.tfrecords.serde -import org.tensorflow.example.{BytesList, Feature, FloatList, Int64List} +import org.tensorflow.example._ import org.tensorflow.hadoop.shaded.protobuf.ByteString -import org.tensorflow.spark.datasources.tfrecords.DataTypesConvertor -trait FeatureEncoder { +trait FeatureEncoder[T] { /** * Encodes input value as TensorFlow "Feature" * @@ -28,91 +27,47 @@ trait FeatureEncoder { * @param value Input value * @return TensorFlow Feature */ - def encode(value: Any): Feature + def encode(value: T): Feature } /** * Encode input value to Int64List */ -object Int64ListFeatureEncoder extends FeatureEncoder { - override def encode(value: Any): Feature = { - try { - val int64List = value match { - case i: Int => Int64List.newBuilder().addValue(i.toLong).build() - case l: Long => Int64List.newBuilder().addValue(l).build() - case arr: scala.collection.mutable.WrappedArray[_] => toInt64List(arr.toArray[Any]) - case arr: Array[_] => 
toInt64List(arr) - case seq: Seq[_] => toInt64List(seq.toArray[Any]) - case _ => throw new RuntimeException(s"Cannot convert object $value to Int64List") - } - Feature.newBuilder().setInt64List(int64List).build() - } - catch { - case ex: Exception => - throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to Int64List feature.", ex) - } - } - - private def toInt64List[T](arr: Array[T]): Int64List = { +object Int64ListFeatureEncoder extends FeatureEncoder[Seq[Long]] { + override def encode(value: Seq[Long]): Feature = { val intListBuilder = Int64List.newBuilder() - arr.foreach(x => { - require(x != null, "Int64List with null values is not supported") - val longValue = DataTypesConvertor.toLong(x) - intListBuilder.addValue(longValue) - }) - intListBuilder.build() + value.foreach {x => + intListBuilder.addValue(x) + } + val int64List = intListBuilder.build() + Feature.newBuilder().setInt64List(int64List).build() } } /** * Encode input value to FloatList */ -object FloatListFeatureEncoder extends FeatureEncoder { - override def encode(value: Any): Feature = { - try { - val floatList = value match { - case i: Int => FloatList.newBuilder().addValue(i.toFloat).build() - case l: Long => FloatList.newBuilder().addValue(l.toFloat).build() - case f: Float => FloatList.newBuilder().addValue(f).build() - case d: Double => FloatList.newBuilder().addValue(d.toFloat).build() - case arr: scala.collection.mutable.WrappedArray[_] => toFloatList(arr.toArray[Any]) - case arr: Array[_] => toFloatList(arr) - case seq: Seq[_] => toFloatList(seq.toArray[Any]) - case _ => throw new RuntimeException(s"Cannot convert object $value to FloatList") - } - Feature.newBuilder().setFloatList(floatList).build() - } - catch { - case ex: Exception => - throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to FloatList feature.", ex) - } - } - - private def toFloatList[T](arr: Array[T]): FloatList = { +object FloatListFeatureEncoder extends FeatureEncoder[Seq[Float]] { + override def encode(value: Seq[Float]): Feature = { val floatListBuilder = FloatList.newBuilder() - arr.foreach(x => { - require(x != null, "FloatList with null values is not supported") - val longValue = DataTypesConvertor.toFloat(x) - floatListBuilder.addValue(longValue) - }) - floatListBuilder.build() + value.foreach {x => + floatListBuilder.addValue(x) + } + val floatList = floatListBuilder.build() + Feature.newBuilder().setFloatList(floatList).build() } } /** * Encode input value to ByteList */ -object BytesListFeatureEncoder extends FeatureEncoder { - override def encode(value: Any): Feature = { - try { - val byteList = BytesList.newBuilder().addValue(ByteString.copyFrom(value.toString.getBytes)).build() - Feature.newBuilder().setBytesList(byteList).build() - } - catch { - case ex: Exception => - throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to ByteList feature.", ex) +object BytesListFeatureEncoder extends FeatureEncoder[Seq[String]] { + override def encode(value: Seq[String]): Feature = { + val bytesListBuilder = BytesList.newBuilder() + value.foreach {x => + bytesListBuilder.addValue(ByteString.copyFrom(x.getBytes)) } + val bytesList = bytesListBuilder.build() + Feature.newBuilder().setBytesList(bytesList).build() } -} - - +} \ No newline at end of file diff --git a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListDecoder.scala 
b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListDecoder.scala new file mode 100644 index 00000000..faf2f986 --- /dev/null +++ b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListDecoder.scala @@ -0,0 +1,74 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.spark.datasources.tfrecords.serde + +import org.tensorflow.example.FeatureList +import scala.collection.JavaConverters._ + +trait FeatureListDecoder[T] extends Serializable{ + /** + * Decodes each TensorFlow "FeatureList" to desired Scala type + * + * @param featureList TensorFlow FeatureList + * @return Decoded featureList + */ + def decode(featureList: FeatureList): T +} + +/** + * Decode TensorFlow "FeatureList" to 2-dimensional Integer array + */ +object IntFeatureListDecoder extends FeatureListDecoder[Seq[Seq[Int]]] { + override def decode(featureList: FeatureList): Seq[Seq[Int]] = { + featureList.getFeatureList.asScala.map(x => IntListFeatureDecoder.decode(x)).toSeq + } +} + +/** + * Decode TensorFlow "FeatureList" to 2-dimensional Long array + */ +object LongFeatureListDecoder extends FeatureListDecoder[Seq[Seq[Long]]] { + override def decode(featureList: FeatureList): Seq[Seq[Long]] = { + featureList.getFeatureList.asScala.map(x => LongListFeatureDecoder.decode(x)).toSeq + } +} + +/** + * Decode TensorFlow "FeatureList" to 2-dimensional Float array + */ +object FloatFeatureListDecoder extends FeatureListDecoder[Seq[Seq[Float]]] { + override def decode(featureList: FeatureList): Seq[Seq[Float]] = { + featureList.getFeatureList.asScala.map(x => FloatListFeatureDecoder.decode(x)).toSeq + } +} + +/** + * Decode TensorFlow "FeatureList" to 2-dimensional Double array + */ +object DoubleFeatureListDecoder extends FeatureListDecoder[Seq[Seq[Double]]] { + override def decode(featureList: FeatureList): Seq[Seq[Double]] = { + featureList.getFeatureList.asScala.map(x => DoubleListFeatureDecoder.decode(x)).toSeq + } +} + +/** + * Decode TensorFlow "FeatureList" to 2-dimensional String array + */ +object StringFeatureListDecoder extends FeatureListDecoder[Seq[Seq[String]]] { + override def decode(featureList: FeatureList): Seq[Seq[String]] = { + featureList.getFeatureList.asScala.map(x => StringListFeatureDecoder.decode(x)).toSeq + } +} \ No newline at end of file diff --git a/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListEncoder.scala b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListEncoder.scala new file mode 100644 index 00000000..5efa6028 --- /dev/null +++ b/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListEncoder.scala @@ -0,0 +1,73 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.spark.datasources.tfrecords.serde + +import org.tensorflow.example.FeatureList + +trait FeatureListEncoder[T] extends Serializable{ + /** + * Encodes input value as TensorFlow "FeatureList" + * + * Maps input value to a feature list of type Int64List, FloatList, or BytesList + * + * @param values Input values + * @return TensorFlow FeatureList + */ + def encode(values: T): FeatureList +} + + +/** + * Encode 2-dimensional Long array to TensorFlow "FeatureList" of type Int64List + */ +object Int64FeatureListEncoder extends FeatureListEncoder[Seq[Seq[Long]]] { + def encode(values: Seq[Seq[Long]]) : FeatureList = { + val builder = FeatureList.newBuilder() + values.foreach { x => + val int64list = Int64ListFeatureEncoder.encode(x) + builder.addFeature(int64list) + } + builder.build() + } +} + +/** + * Encode 2-dimensional Float array to TensorFlow "FeatureList" of type FloatList + */ +object FloatFeatureListEncoder extends FeatureListEncoder[Seq[Seq[Float]]] { + def encode(value: Seq[Seq[Float]]) : FeatureList = { + val builder = FeatureList.newBuilder() + value.foreach { x => + val floatList = FloatListFeatureEncoder.encode(x) + builder.addFeature(floatList) + } + builder.build() + } +} + +/** + * Encode 2-dimensional String array to TensorFlow "FeatureList" of type BytesList + */ +object BytesFeatureListEncoder extends FeatureListEncoder[Seq[Seq[String]]] { + def encode(value: Seq[Seq[String]]) : FeatureList = { + val builder = FeatureList.newBuilder() + value.foreach { x => + val bytesList = BytesListFeatureEncoder.encode(x) + builder.addFeature(bytesList) + } + builder.build() + } +} diff --git a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/InferSchemaSuite.scala b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/InferSchemaSuite.scala new file mode 100644 index 00000000..a334d097 --- /dev/null +++ b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/InferSchemaSuite.scala @@ -0,0 +1,141 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.tensorflow.spark.datasources.tfrecords
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types._
+import org.tensorflow.example._
+import org.tensorflow.hadoop.shaded.protobuf.ByteString
+
+class InferSchemaSuite extends SharedSparkSessionSuite {
+
+  val longFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(Int.MaxValue + 10L)).build()
+  val floatFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(10.0F).build()).build()
+  val strFeature = Feature.newBuilder().setBytesList(
+    BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes))).build()
+
+  val longList = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(-2L).addValue(20L).build()).build()
+  val floatList = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(2.5F).addValue(7F).build()).build()
+  val strList = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes))
+    .addValue(ByteString.copyFrom("r2".getBytes)).build()).build()
+
+  "InferSchema" should {
+
+    "Infer schema from Example records" in {
+      //Build example1
+      val features1 = Features.newBuilder()
+        .putFeature("LongFeature", longFeature)
+        .putFeature("FloatFeature", floatFeature)
+        .putFeature("StrFeature", strFeature)
+        .putFeature("LongList", longFeature)
+        .putFeature("FloatList", floatFeature)
+        .putFeature("StrList", strFeature)
+        .putFeature("MixedTypeList", longList)
+        .build()
+      val example1 = Example.newBuilder()
+        .setFeatures(features1)
+        .build()
+
+      //Example2 contains a subset of the features in example1, to test behavior with missing features
+      val features2 = Features.newBuilder()
+        .putFeature("StrFeature", strFeature)
+        .putFeature("LongList", longList)
+        .putFeature("FloatList", floatList)
+        .putFeature("StrList", strList)
+        .putFeature("MixedTypeList", floatList)
+        .build()
+      val example2 = Example.newBuilder()
+        .setFeatures(features2)
+        .build()
+
+      val exampleRdd: RDD[Example] = spark.sparkContext.parallelize(List(example1, example2))
+      val inferredSchema = TensorFlowInferSchema(exampleRdd)
+
+      //Verify that each TensorFlow data type is inferred as the expected Spark SQL type
+      assert(inferredSchema.fields.length == 7)
+      val schemaMap = inferredSchema.map(f => (f.name, f.dataType)).toMap
+      assert(schemaMap("LongFeature") === LongType)
+      assert(schemaMap("FloatFeature") === FloatType)
+      assert(schemaMap("StrFeature") === StringType)
+      assert(schemaMap("LongList") === ArrayType(LongType))
+      assert(schemaMap("FloatList") === ArrayType(FloatType))
+      assert(schemaMap("StrList") === ArrayType(StringType))
+      assert(schemaMap("MixedTypeList") === ArrayType(FloatType))
+    }
+
+    "Infer schema from SequenceExample records" in {
+
+      //Build sequence example1
+      val features1 = Features.newBuilder()
+        .putFeature("FloatFeature", floatFeature)
+
+      val longFeatureList1 = FeatureList.newBuilder().addFeature(longFeature).addFeature(longList).build()
+      val floatFeatureList1 = FeatureList.newBuilder().addFeature(floatFeature).addFeature(floatList).build()
+      val strFeatureList1 = FeatureList.newBuilder().addFeature(strFeature).build()
+      val mixedFeatureList1 = FeatureList.newBuilder().addFeature(floatFeature).addFeature(strList).build()
+
+      val featureLists1 = FeatureLists.newBuilder()
+        .putFeatureList("LongListOfLists", longFeatureList1)
+        .putFeatureList("FloatListOfLists", floatFeatureList1)
+        .putFeatureList("StringListOfLists", strFeatureList1)
+        .putFeatureList("MixedListOfLists", mixedFeatureList1)
+        .build()
+
+      val seqExample1 = SequenceExample.newBuilder()
+        .setContext(features1)
+        .setFeatureLists(featureLists1)
+        .build()
+
+      //SequenceExample2 contains a subset of the features in sequence example1, to test behavior with missing features
+      val longFeatureList2 = FeatureList.newBuilder().addFeature(longList).build()
+      val floatFeatureList2 = FeatureList.newBuilder().addFeature(floatFeature).build()
+      val strFeatureList2 = FeatureList.newBuilder().addFeature(strFeature).build() //string lists in both records have length=1
+      val mixedFeatureList2 = FeatureList.newBuilder().addFeature(longFeature).addFeature(strFeature).build()
+
+      val featureLists2 = FeatureLists.newBuilder()
+        .putFeatureList("LongListOfLists", longFeatureList2)
+        .putFeatureList("FloatListOfLists", floatFeatureList2)
+        .putFeatureList("StringListOfLists", strFeatureList2)
+        .putFeatureList("MixedListOfLists", mixedFeatureList2)
+        .build()
+
+      val seqExample2 = SequenceExample.newBuilder()
+        .setFeatureLists(featureLists2)
+        .build()
+
+      val seqExampleRdd: RDD[SequenceExample] = spark.sparkContext.parallelize(List(seqExample1, seqExample2))
+      val inferredSchema = TensorFlowInferSchema(seqExampleRdd)
+
+      //Verify that each TensorFlow data type is inferred as the expected Spark SQL type
+      assert(inferredSchema.fields.length == 5)
+      val schemaMap = inferredSchema.map(f => (f.name, f.dataType)).toMap
+      assert(schemaMap("FloatFeature") === FloatType)
+      assert(schemaMap("LongListOfLists") === ArrayType(ArrayType(LongType)))
+      assert(schemaMap("FloatListOfLists") === ArrayType(ArrayType(FloatType)))
+      assert(schemaMap("StringListOfLists") === ArrayType(ArrayType(StringType)))
+      assert(schemaMap("MixedListOfLists") === ArrayType(ArrayType(StringType)))
+    }
+
+    "Throw an exception for unsupported record types" in {
+      intercept[Exception] {
+        val rdd: RDD[Long] = spark.sparkContext.parallelize(List(5L, 6L))
+        TensorFlowInferSchema(rdd)
+      }
+    }
+  }
+}
diff --git a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowSuite.scala b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowSuite.scala
new file mode 100644
index 00000000..f8ceab18
--- /dev/null
+++ b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowSuite.scala
@@ -0,0 +1,84 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords
+
+import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Row}
+
+class TensorFlowSuite extends SharedSparkSessionSuite {
+
+  "Spark TensorFlow module" should {
+
+    "Test Import/Export of Example records" in {
+
+      val path = s"$TF_SANDBOX_DIR/example.tfrecord"
+      val testRows: Array[Row] = Array(
+        new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
+        new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
+
+      val schema = StructType(List(
+        StructField("id", IntegerType),
+        StructField("IntegerLabel", IntegerType),
+        StructField("LongLabel", LongType),
+        StructField("FloatLabel", FloatType),
+        StructField("DoubleLabel", DoubleType),
+        StructField("DoubleArrayLabel", ArrayType(DoubleType, true)),
+        StructField("StrLabel", StringType)))
+
+      val rdd = spark.sparkContext.parallelize(testRows)
+
+      val df: DataFrame = spark.createDataFrame(rdd, schema)
+      df.write.format("tfrecords").option("recordType", "Example").save(path)
+
+      //A schema is provided here; if omitted, it is inferred automatically from the records
+      val importedDf: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").schema(schema).load(path)
+      val actualDf = importedDf.select("id", "IntegerLabel", "LongLabel", "FloatLabel", "DoubleLabel", "DoubleArrayLabel", "StrLabel").sort("StrLabel")
+
+      val expectedRows = df.collect()
+      val actualRows = actualDf.collect()
+
+      assert(expectedRows === actualRows)
+    }
+
+    "Test Import/Export of SequenceExample records" in {
+
+      val path = s"$TF_SANDBOX_DIR/sequenceExample.tfrecord"
+      val testRows: Array[Row] = Array(
+        new GenericRow(Array[Any](23L, Seq(Seq(2.0F, 4.5F)), Seq(Seq("r1", "r2")))),
+        new GenericRow(Array[Any](24L, Seq(Seq(-1.0F, 0F)), Seq(Seq("r3")))))
+
+      val schema = StructType(List(
+        StructField("id", LongType),
+        StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))),
+        StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType)))
+      ))
+
+      val rdd = spark.sparkContext.parallelize(testRows)
+
+      val df: DataFrame = spark.createDataFrame(rdd, schema)
+      df.write.format("tfrecords").option("recordType", "SequenceExample").save(path)
+
+      val importedDf: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").schema(schema).load(path)
+      val actualDf = importedDf.select("id", "FloatArrayOfArrayLabel", "StrArrayOfArrayLabel").sort("id")
+
+      val expectedRows = df.collect()
+      val actualRows = actualDf.collect()
+
+      assert(expectedRows === actualRows)
+    }
+  }
+}
diff --git a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowSuite.scala b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowSuite.scala
deleted file mode 100644
index 0e37968e..00000000
--- a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowSuite.scala
+++ /dev/null
@@ -1,210 +0,0 @@
-/**
- * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - *       http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.spark.datasources.tfrecords - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row} -import org.tensorflow.example._ -import org.tensorflow.hadoop.shaded.protobuf.ByteString -import org.tensorflow.spark.datasources.tfrecords.serde.{DefaultTfRecordRowDecoder, DefaultTfRecordRowEncoder} - -import scala.collection.JavaConverters._ - -class TensorflowSuite extends SharedSparkSessionSuite { - - "Spark TensorFlow module" should { - - "Test Import/Export" in { - - val path = s"$TF_SANDBOX_DIR/output25.tfr" - val testRows: Array[Row] = Array( - new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")), - new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2"))) - - val schema = StructType(List( - StructField("id", IntegerType), - StructField("IntegerTypelabel", IntegerType), - StructField("LongTypelabel", LongType), - StructField("FloatTypelabel", FloatType), - StructField("DoubleTypelabel", DoubleType), - StructField("vectorlabel", ArrayType(DoubleType, true)), - StructField("name", StringType))) - - val rdd = spark.sparkContext.parallelize(testRows) - - val df: DataFrame = spark.createDataFrame(rdd, schema) - df.write.format("tfrecords").save(path) - - //If schema is not provided. 
It will automatically infer schema - val importedDf: DataFrame = spark.read.format("tfrecords").schema(schema).load(path) - val actualDf = importedDf.select("id", "IntegerTypelabel", "LongTypelabel", "FloatTypelabel", "DoubleTypelabel", "vectorlabel", "name").sort("name") - - val expectedRows = df.collect() - val actualRows = actualDf.collect() - - expectedRows should equal(actualRows) - } - - "Encode given Row as TensorFlow example" in { - val schemaStructType = StructType(Array( - StructField("IntegerTypelabel", IntegerType), - StructField("LongTypelabel", LongType), - StructField("FloatTypelabel", FloatType), - StructField("DoubleTypelabel", DoubleType), - StructField("vectorlabel", ArrayType(DoubleType, true)), - StructField("strlabel", StringType) - )) - val doubleArray = Array(1.1, 111.1, 11111.1) - val expectedFloatArray = Array(1.1F, 111.1F, 11111.1F) - - val rowWithSchema = new GenericRowWithSchema(Array[Any](1, 23L, 10.0F, 14.0, doubleArray, "r1"), schemaStructType) - - //Encode Sql Row to TensorFlow example - val example = DefaultTfRecordRowEncoder.encodeTfRecord(rowWithSchema) - import org.tensorflow.example.Feature - - //Verify each Datatype converted to TensorFlow datatypes - val featureMap = example.getFeatures.getFeatureMap.asScala - assert(featureMap("IntegerTypelabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) - assert(featureMap("IntegerTypelabel").getInt64List.getValue(0).toInt == 1) - - assert(featureMap("LongTypelabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) - assert(featureMap("LongTypelabel").getInt64List.getValue(0).toInt == 23) - - assert(featureMap("FloatTypelabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) - assert(featureMap("FloatTypelabel").getFloatList.getValue(0) == 10.0F) - - assert(featureMap("DoubleTypelabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) - assert(featureMap("DoubleTypelabel").getFloatList.getValue(0) == 14.0F) - - assert(featureMap("vectorlabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) - assert(featureMap("vectorlabel").getFloatList.getValueList.toArray === expectedFloatArray) - - assert(featureMap("strlabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) - assert(featureMap("strlabel").getBytesList.toByteString.toStringUtf8.trim == "r1") - - } - - "Throw an exception for a vector with null values during Encode" in { - intercept[Exception] { - val schemaStructType = StructType(Array( - StructField("vectorlabel", ArrayType(DoubleType, true)) - )) - val doubleArray = Array(1.1, null, 111.1, null, 11111.1) - - val rowWithSchema = new GenericRowWithSchema(Array[Any](doubleArray), schemaStructType) - - //Throws NullPointerException - DefaultTfRecordRowEncoder.encodeTfRecord(rowWithSchema) - } - } - - "Decode given TensorFlow Example as Row" in { - - //Here Vector with null's are not supported - val expectedRow = new GenericRow(Array[Any](1, 23L, 10.0F, 14.0, Seq(1.0, 2.0), "r1")) - - val schema = StructType(List( - StructField("IntegerTypelabel", IntegerType), - StructField("LongTypelabel", LongType), - StructField("FloatTypelabel", FloatType), - StructField("DoubleTypelabel", DoubleType), - StructField("vectorlabel", ArrayType(DoubleType)), - StructField("strlabel", StringType))) - - //Build example - val intFeature = Int64List.newBuilder().addValue(1) - val longFeature = Int64List.newBuilder().addValue(23L) - val floatFeature = FloatList.newBuilder().addValue(10.0F) - val doubleFeature = FloatList.newBuilder().addValue(14.0F) - val vectorFeature = 
FloatList.newBuilder().addValue(1F).addValue(2F).build() - val strFeature = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build() - val features = Features.newBuilder() - .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature).build()) - .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature).build()) - .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature).build()) - .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature).build()) - .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature).build()) - .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature).build()) - .build() - val example = Example.newBuilder() - .setFeatures(features) - .build() - - //Decode TensorFlow example to Sql Row - val actualRow = DefaultTfRecordRowDecoder.decodeTfRecord(example, schema) - actualRow should equal(expectedRow) - } - - "Check infer schema" in { - - //Build example1 - val intFeature1 = Int64List.newBuilder().addValue(1) - val longFeature1 = Int64List.newBuilder().addValue(Int.MaxValue + 10L) - val floatFeature1 = FloatList.newBuilder().addValue(10.0F) - val doubleFeature1 = FloatList.newBuilder().addValue(14.0F) - val vectorFeature1 = FloatList.newBuilder().addValue(1F).build() - val strFeature1 = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build() - val features1 = Features.newBuilder() - .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature1).build()) - .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature1).build()) - .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature1).build()) - .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature1).build()) - .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature1).build()) - .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature1).build()) - .build() - val example1 = Example.newBuilder() - .setFeatures(features1) - .build() - - //Build example2 - val intFeature2 = Int64List.newBuilder().addValue(2) - val longFeature2 = Int64List.newBuilder().addValue(24) - val floatFeature2 = FloatList.newBuilder().addValue(12.0F) - val doubleFeature2 = FloatList.newBuilder().addValue(Float.MaxValue + 15) - val vectorFeature2 = FloatList.newBuilder().addValue(2F).addValue(2F).build() - val strFeature2 = BytesList.newBuilder().addValue(ByteString.copyFrom("r2".getBytes)).build() - val features2 = Features.newBuilder() - .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature2).build()) - .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature2).build()) - .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature2).build()) - .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature2).build()) - .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature2).build()) - .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature2).build()) - .build() - val example2 = Example.newBuilder() - .setFeatures(features2) - .build() - - val exampleRDD: RDD[Example] = spark.sparkContext.parallelize(List(example1, example2)) - - val actualSchema = TensorflowInferSchema(exampleRDD) - - //Verify each TensorFlow Datatype is inferred as one of our Datatype - actualSchema.fields.map { colum => - colum.name match { - case "IntegerTypelabel" => colum.dataType.equals(IntegerType) - case "LongTypelabel" => 
colum.dataType.equals(LongType)
-        case "FloatTypelabel" | "DoubleTypelabel" | "vectorlabel" => colum.dataType.equals(FloatType)
-        case "strlabel" => colum.dataType.equals(StringType)
-      }
-    }
-    }
-  }
-}
-
diff --git a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TestingUtils.scala b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TestingUtils.scala
new file mode 100644
index 00000000..37520614
--- /dev/null
+++ b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TestingUtils.scala
@@ -0,0 +1,156 @@
+/**
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.spark.datasources.tfrecords
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types._
+import org.scalatest.Matchers
+
+object TestingUtils extends Matchers {
+
+  /**
+   * Implicit class for comparing two float arrays using absolute tolerance.
+   */
+  implicit class FloatArrayWithAlmostEquals(val left: Seq[Float]) {
+
+    /**
+     * When the differences of all corresponding values are within eps, returns true; otherwise, returns false.
+     */
+    def ~==(right: Seq[Float], epsilon : Float = 1E-6F): Boolean = {
+      if (left.size === right.size) {
+        (left zip right) forall { case (a, b) => a === (b +- epsilon) }
+      }
+      else false
+    }
+  }
+
+  /**
+   * Implicit class for comparing two double arrays using absolute tolerance.
+   */
+  implicit class DoubleArrayWithAlmostEquals(val left: Seq[Double]) {
+
+    /**
+     * When the differences of all corresponding values are within eps, returns true; otherwise, returns false.
+     */
+    def ~==(right: Seq[Double], epsilon : Double = 1E-6): Boolean = {
+      if (left.size === right.size) {
+        (left zip right) forall { case (a, b) => a === (b +- epsilon) }
+      }
+      else false
+    }
+  }
+
+  /**
+   * Implicit class for comparing two 2-dimensional float arrays using absolute tolerance.
+   */
+  implicit class FloatMatrixWithAlmostEquals(val left: Seq[Seq[Float]]) {
+
+    /**
+     * When the differences of all corresponding values are within eps, returns true; otherwise, returns false.
+     */
+    def ~==(right: Seq[Seq[Float]], epsilon : Float = 1E-6F): Boolean = {
+      if (left.size === right.size) {
+        (left zip right) forall { case (a, b) => a ~== (b, epsilon) }
+      }
+      else false
+    }
+  }
+
+  /**
+   * Implicit class for comparing two 2-dimensional double arrays using absolute tolerance.
+   */
+  implicit class DoubleMatrixWithAlmostEquals(val left: Seq[Seq[Double]]) {
+
+    /**
+     * When the differences of all corresponding values are within eps, returns true; otherwise, returns false.
+     */
+    def ~==(right: Seq[Seq[Double]], epsilon : Double = 1E-6): Boolean = {
+      if (left.size === right.size) {
+        (left zip right) forall { case (a, b) => a ~== (b, epsilon) }
+      }
+      else false
+    }
+  }
+
+  /**
+   * Implicit class for comparing two rows using absolute tolerance.
+ */ + implicit class RowWithAlmostEquals(val left: Row) { + + /** + * When all fields in row with given schema are equal or are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Row, schema: StructType): Boolean = { + if (schema != null && schema.fields.size == left.size && schema.fields.size == right.size) { + val leftRowWithSchema = new GenericRowWithSchema(left.toSeq.toArray, schema) + val rightRowWithSchema = new GenericRowWithSchema(right.toSeq.toArray, schema) + leftRowWithSchema ~== rightRowWithSchema + } + else false + } + + /** + * When all fields in row are equal or are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Row, epsilon : Float = 1E-6F): Boolean = { + if (left.size === right.size) { + val leftDataTypes = left.schema.fields.map(_.dataType) + val rightDataTypes = right.schema.fields.map(_.dataType) + + (leftDataTypes zip rightDataTypes).zipWithIndex.forall { + case ((FloatType, FloatType), i) => + left.getFloat(i) === (right.getFloat(i) +- epsilon) + + case ((DoubleType, DoubleType), i) => + left.getDouble(i) === (right.getDouble(i) +- epsilon) + + case ((ArrayType(FloatType,_), ArrayType(FloatType,_)), i) => + val leftArray = ArrayData.toArrayData(left.get(i)).toFloatArray().toSeq + val rightArray = ArrayData.toArrayData(right.get(i)).toFloatArray().toSeq + leftArray ~== (rightArray, epsilon) + + case ((ArrayType(DoubleType,_), ArrayType(DoubleType,_)), i) => + val leftArray = ArrayData.toArrayData(left.get(i)).toDoubleArray().toSeq + val rightArray = ArrayData.toArrayData(right.get(i)).toDoubleArray().toSeq + leftArray ~== (rightArray, epsilon) + + case ((ArrayType(ArrayType(FloatType,_),_), ArrayType(ArrayType(FloatType,_),_)), i) => + val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => + ArrayData.toArrayData(arr).toFloatArray().toSeq + } + val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => + ArrayData.toArrayData(arr).toFloatArray().toSeq + } + leftArrays ~== (rightArrays, epsilon) + + case ((ArrayType(ArrayType(DoubleType,_),_), ArrayType(ArrayType(DoubleType,_),_)), i) => + val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => + ArrayData.toArrayData(arr).toDoubleArray().toSeq + } + val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => + ArrayData.toArrayData(arr).toDoubleArray().toSeq + } + leftArrays ~== (rightArrays, epsilon) + + case((a,b), i) => left.get(i) === right.get(i) + } + } + else false + } + } +} \ No newline at end of file diff --git a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoderTest.scala b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoderTest.scala index 954803d1..8c450014 100644 --- a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoderTest.scala +++ b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureDecoderTest.scala @@ -18,15 +18,17 @@ package org.tensorflow.spark.datasources.tfrecords.serde import org.scalatest.{Matchers, WordSpec} import org.tensorflow.example.{BytesList, Feature, FloatList, Int64List} import org.tensorflow.hadoop.shaded.protobuf.ByteString +import org.tensorflow.spark.datasources.tfrecords.TestingUtils._ class FeatureDecoderTest extends WordSpec with Matchers { + val epsilon = 1E-6 "Int Feature decoder" should { "Decode Feature to Int" in { 
val int64List = Int64List.newBuilder().addValue(4).build() val intFeature = Feature.newBuilder().setInt64List(int64List).build() - IntFeatureDecoder.decode(intFeature) should equal(4) + assert(IntFeatureDecoder.decode(intFeature) === 4) } "Throw an exception if length of feature array exceeds 1" in { @@ -51,7 +53,7 @@ class FeatureDecoderTest extends WordSpec with Matchers { "Decode Feature to Int List" in { val int64List = Int64List.newBuilder().addValue(3).addValue(9).build() val intFeature = Feature.newBuilder().setInt64List(int64List).build() - IntListFeatureDecoder.decode(intFeature) should equal(Seq(3,9)) + assert(IntListFeatureDecoder.decode(intFeature) === Seq(3,9)) } "Throw an exception if feature is not an Int64List" in { @@ -68,7 +70,7 @@ class FeatureDecoderTest extends WordSpec with Matchers { "Decode Feature to Long" in { val int64List = Int64List.newBuilder().addValue(5L).build() val intFeature = Feature.newBuilder().setInt64List(int64List).build() - LongFeatureDecoder.decode(intFeature) should equal(5L) + assert(LongFeatureDecoder.decode(intFeature) === 5L) } "Throw an exception if length of feature array exceeds 1" in { @@ -93,7 +95,7 @@ class FeatureDecoderTest extends WordSpec with Matchers { "Decode Feature to Long List" in { val int64List = Int64List.newBuilder().addValue(3L).addValue(Int.MaxValue+10L).build() val intFeature = Feature.newBuilder().setInt64List(int64List).build() - LongListFeatureDecoder.decode(intFeature) should equal(Seq(3L,Int.MaxValue+10L)) + assert(LongListFeatureDecoder.decode(intFeature) === Seq(3L,Int.MaxValue+10L)) } "Throw an exception if feature is not an Int64List" in { @@ -110,7 +112,7 @@ class FeatureDecoderTest extends WordSpec with Matchers { "Decode Feature to Float" in { val floatList = FloatList.newBuilder().addValue(2.5F).build() val floatFeature = Feature.newBuilder().setFloatList(floatList).build() - FloatFeatureDecoder.decode(floatFeature) should equal(2.5F) + assert(FloatFeatureDecoder.decode(floatFeature) === 2.5F +- epsilon.toFloat) } "Throw an exception if length of feature array exceeds 1" in { @@ -135,7 +137,7 @@ class FeatureDecoderTest extends WordSpec with Matchers { "Decode Feature to Float List" in { val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.3F).build() val floatFeature = Feature.newBuilder().setFloatList(floatList).build() - FloatListFeatureDecoder.decode(floatFeature) should equal(Seq(2.5F, 4.3F)) + assert(FloatListFeatureDecoder.decode(floatFeature) ~== Seq(2.5F, 4.3F)) } "Throw an exception if feature is not a FloatList" in { @@ -152,7 +154,7 @@ class FeatureDecoderTest extends WordSpec with Matchers { "Decode Feature to Double" in { val floatList = FloatList.newBuilder().addValue(2.5F).build() val floatFeature = Feature.newBuilder().setFloatList(floatList).build() - DoubleFeatureDecoder.decode(floatFeature) should equal(2.5d) + assert(DoubleFeatureDecoder.decode(floatFeature) === 2.5d +- epsilon) } "Throw an exception if length of feature array exceeds 1" in { @@ -177,7 +179,7 @@ class FeatureDecoderTest extends WordSpec with Matchers { "Decode Feature to Double List" in { val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.0F).build() val floatFeature = Feature.newBuilder().setFloatList(floatList).build() - DoubleListFeatureDecoder.decode(floatFeature) should equal(Seq(2.5d, 4.0d)) + assert(DoubleListFeatureDecoder.decode(floatFeature) ~== Seq(2.5d, 4.0d)) } "Throw an exception if feature is not a DoubleList" in { @@ -189,12 +191,21 @@ class FeatureDecoderTest extends WordSpec 
with Matchers { } } - "Bytes List Feature decoder" should { + "String Feature decoder" should { - "Decode Feature to Bytes List" in { + "Decode Feature to String" in { val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build() val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build() - StringFeatureDecoder.decode(bytesFeature) should equal("str-input") + assert(StringFeatureDecoder.decode(bytesFeature) === "str-input") + } + + "Throw an exception if length of feature array exceeds 1" in { + intercept[Exception] { + val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("alice".getBytes)) + .addValue(ByteString.copyFrom("bob".getBytes)).build() + val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build() + StringFeatureDecoder.decode(bytesFeature) + } } "Throw an exception if feature is not a BytesList" in { @@ -205,5 +216,22 @@ class FeatureDecoderTest extends WordSpec with Matchers { } } } -} + "String List Feature decoder" should { + + "Decode Feature to String List" in { + val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("alice".getBytes)) + .addValue(ByteString.copyFrom("bob".getBytes)).build() + val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build() + assert(StringListFeatureDecoder.decode(bytesFeature) === Seq("alice", "bob")) + } + + "Throw an exception if feature is not a BytesList" in { + intercept[Exception] { + val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.0F).build() + val floatFeature = Feature.newBuilder().setFloatList(floatList).build() + StringListFeatureDecoder.decode(floatFeature) + } + } + } +} diff --git a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala index 4c09d568..73a549ac 100644 --- a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala +++ b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala @@ -16,84 +16,53 @@ package org.tensorflow.spark.datasources.tfrecords.serde import org.scalatest.{Matchers, WordSpec} - +import org.tensorflow.spark.datasources.tfrecords.TestingUtils._ import scala.collection.JavaConverters._ class FeatureEncoderTest extends WordSpec with Matchers { "Int64List feature encoder" should { "Encode inputs to Int64List" in { - val intFeature = Int64ListFeatureEncoder.encode(5) - val longFeature = Int64ListFeatureEncoder.encode(10L) + val longFeature = Int64ListFeatureEncoder.encode(Seq(10L)) val longListFeature = Int64ListFeatureEncoder.encode(Seq(3L,5L,6L)) - intFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(5L)) - longFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(10L)) - longListFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(3L, 5L, 6L)) - } - - "Throw an exception when inputs contain null" in { - intercept[Exception] { - Int64ListFeatureEncoder.encode(null) - } - intercept[Exception] { - Int64ListFeatureEncoder.encode(Seq(3,null,6)) - } + assert(longFeature.getInt64List.getValueList.asScala.toSeq === Seq(10L)) + assert(longListFeature.getInt64List.getValueList.asScala.toSeq === Seq(3L, 5L, 6L)) } - "Throw an exception for non-numeric inputs" in { - intercept[Exception] { - Int64ListFeatureEncoder.encode("bad-input") - } + 
"Encode empty list to empty feature" in { + val longListFeature = Int64ListFeatureEncoder.encode(Seq.empty[Long]) + assert(longListFeature.getInt64List.getValueList.size() === 0) } } "FloatList feature encoder" should { "Encode inputs to FloatList" in { - val intFeature = FloatListFeatureEncoder.encode(5) - val longFeature = FloatListFeatureEncoder.encode(10L) - val floatFeature = FloatListFeatureEncoder.encode(2.5F) - val doubleFeature = FloatListFeatureEncoder.encode(14.6) + val floatFeature = FloatListFeatureEncoder.encode(Seq(2.5F)) val floatListFeature = FloatListFeatureEncoder.encode(Seq(1.5F,6.8F,-3.2F)) - intFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(5F)) - longFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(10F)) - floatFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(2.5F)) - doubleFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(14.6F)) - floatListFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(1.5F,6.8F,-3.2F)) - } - - "Throw an exception when inputs contain null" in { - intercept[Exception] { - FloatListFeatureEncoder.encode(null) - } - intercept[Exception] { - FloatListFeatureEncoder.encode(Seq(3,null,6)) - } + assert(floatFeature.getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== Seq(2.5F)) + assert(floatListFeature.getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== Seq(1.5F,6.8F,-3.2F)) } - "Throw an exception for non-numeric inputs" in { - intercept[Exception] { - FloatListFeatureEncoder.encode("bad-input") - } + "Encode empty list to empty feature" in { + val floatListFeature = FloatListFeatureEncoder.encode(Seq.empty[Float]) + assert(floatListFeature.getFloatList.getValueList.size() === 0) } } "ByteList feature encoder" should { "Encode inputs to ByteList" in { - val longFeature = BytesListFeatureEncoder.encode(10L) - val longListFeature = BytesListFeatureEncoder.encode(Seq(3L,5L,6L)) - val strFeature = BytesListFeatureEncoder.encode("str-input") + val strFeature = BytesListFeatureEncoder.encode(Seq("str-input")) + val strListFeature = BytesListFeatureEncoder.encode(Seq("alice", "bob")) - longFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("10") - longListFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("List(3, 5, 6)") - strFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("str-input") + assert(strFeature.getBytesList.getValueList.asScala.map(_.toStringUtf8) === Seq("str-input")) + assert(strListFeature.getBytesList.getValueList.asScala.map(_.toStringUtf8) === Seq("alice", "bob")) } - "Throw an exception when inputs contain null" in { - intercept[Exception] { - BytesListFeatureEncoder.encode(null) - } + "Encode empty list to empty feature" in { + val strListFeature = BytesListFeatureEncoder.encode(Seq.empty[String]) + assert(strListFeature.getBytesList.getValueList.size() === 0) } } } diff --git a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListDecoderTest.scala b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListDecoderTest.scala new file mode 100644 index 00000000..a3168cae --- /dev/null +++ b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListDecoderTest.scala @@ -0,0 +1,161 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.spark.datasources.tfrecords.serde + +import org.tensorflow.hadoop.shaded.protobuf.ByteString +import org.tensorflow.spark.datasources.tfrecords.TestingUtils._ +import org.scalatest.{Matchers, WordSpec} +import org.tensorflow.example._ + +class FeatureListDecoderTest extends WordSpec with Matchers{ + + "Int FeatureList decoder" should { + + "Decode FeatureList to 2-dimensional integer array" in { + val int64List1 = Int64List.newBuilder().addValue(1).addValue(3).build() + val int64List2 = Int64List.newBuilder().addValue(-2).addValue(5).addValue(10).build() + val feature1 = Feature.newBuilder().setInt64List(int64List1).build() + val feature2 = Feature.newBuilder().setInt64List(int64List2).build() + val featureList = FeatureList.newBuilder().addFeature(feature1).addFeature(feature2).build() + + assert(IntFeatureListDecoder.decode(featureList) === Seq(Seq(1,3), Seq(-2,5,10))) + } + + "Decode empty feature list to empty array" in { + val featureList = FeatureList.newBuilder().build() + assert(IntFeatureListDecoder.decode(featureList).size === 0) + } + + "Throw an exception if FeatureList is not of type Int64List" in { + intercept[Exception] { + val floatList = FloatList.newBuilder().addValue(4).build() + val feature = Feature.newBuilder().setFloatList(floatList).build() + val featureList = FeatureList.newBuilder().addFeature(feature).build() + IntFeatureListDecoder.decode(featureList) + } + } + } + + "Long FeatureList decoder" should { + + "Decode FeatureList to 2-dimensional long array" in { + val int64List1 = Int64List.newBuilder().addValue(1).addValue(Int.MaxValue+10L).build() + val int64List2 = Int64List.newBuilder().addValue(Int.MinValue-20L).build() + val intFeature1 = Feature.newBuilder().setInt64List(int64List1).build() + val intFeature2 = Feature.newBuilder().setInt64List(int64List2).build() + val featureList = FeatureList.newBuilder().addFeature(intFeature1).addFeature(intFeature2).build() + + assert(LongFeatureListDecoder.decode(featureList) === Seq(Seq(1L,Int.MaxValue+10L), Seq(Int.MinValue-20L))) + } + + "Decode empty feature list to empty array" in { + val featureList = FeatureList.newBuilder().build() + assert(LongFeatureListDecoder.decode(featureList).size === 0) + } + + "Throw an exception if FeatureList is not of type Int64List" in { + intercept[Exception] { + val floatList = FloatList.newBuilder().addValue(4).build() + val feature = Feature.newBuilder().setFloatList(floatList).build() + val featureList = FeatureList.newBuilder().addFeature(feature).build() + LongFeatureListDecoder.decode(featureList) + } + } + } + + "Float FeatureList decoder" should { + + "Decode FeatureList to 2-dimensional float array" in { + val floatList1 = FloatList.newBuilder().addValue(1.3F).addValue(3.85F).build() + val floatList2 = FloatList.newBuilder().addValue(-2.0F).build() + val feature1 = Feature.newBuilder().setFloatList(floatList1).build() + val feature2 = 
Feature.newBuilder().setFloatList(floatList2).build() + val featureList = FeatureList.newBuilder().addFeature(feature1).addFeature(feature2).build() + + assert(FloatFeatureListDecoder.decode(featureList) ~== Seq(Seq(1.3F,3.85F), Seq(-2.0F))) + } + + "Decode empty feature list to empty array" in { + val featureList = FeatureList.newBuilder().build() + assert(FloatFeatureListDecoder.decode(featureList).size === 0) + } + + "Throw an exception if FeatureList is not of type FloatList" in { + intercept[Exception] { + val intList = Int64List.newBuilder().addValue(4).build() + val feature = Feature.newBuilder().setInt64List(intList).build() + val featureList = FeatureList.newBuilder().addFeature(feature).build() + FloatFeatureListDecoder.decode(featureList) + } + } + } + + "Double FeatureList decoder" should { + + "Decode FeatureList to 2-dimensional double array" in { + val floatList1 = FloatList.newBuilder().addValue(4.3F).addValue(13.8F).build() + val floatList2 = FloatList.newBuilder().addValue(-12.0F).build() + val feature1 = Feature.newBuilder().setFloatList(floatList1).build() + val feature2 = Feature.newBuilder().setFloatList(floatList2).build() + val featureList = FeatureList.newBuilder().addFeature(feature1).addFeature(feature2).build() + + assert(DoubleFeatureListDecoder.decode(featureList) ~== Seq(Seq(4.3d,13.8d), Seq(-12.0d))) + } + + "Decode empty feature list to empty array" in { + val featureList = FeatureList.newBuilder().build() + assert(DoubleFeatureListDecoder.decode(featureList).size === 0) + } + + "Throw an exception if FeatureList is not of type FloatList" in { + intercept[Exception] { + val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("charles".getBytes)).build() + val feature = Feature.newBuilder().setBytesList(bytesList).build() + val featureList = FeatureList.newBuilder().addFeature(feature).build() + DoubleFeatureListDecoder.decode(featureList) + } + } + } + + "String FeatureList decoder" should { + + "Decode FeatureList to 2-dimensional string array" in { + val bytesList1 = BytesList.newBuilder().addValue(ByteString.copyFrom("alice".getBytes)) + .addValue(ByteString.copyFrom("bob".getBytes)).build() + val bytesList2 = BytesList.newBuilder().addValue(ByteString.copyFrom("charles".getBytes)).build() + + val feature1 = Feature.newBuilder().setBytesList(bytesList1).build() + val feature2 = Feature.newBuilder().setBytesList(bytesList2).build() + val featureList = FeatureList.newBuilder().addFeature(feature1).addFeature(feature2).build() + + assert(StringFeatureListDecoder.decode(featureList) === Seq(Seq("alice", "bob"), Seq("charles"))) + } + + "Decode empty feature list to empty array" in { + val featureList = FeatureList.newBuilder().build() + assert(StringFeatureListDecoder.decode(featureList).size === 0) + } + + "Throw an exception if FeatureList is not of type BytesList" in { + intercept[Exception] { + val intList = Int64List.newBuilder().addValue(4).build() + val feature = Feature.newBuilder().setInt64List(intList).build() + val featureList = FeatureList.newBuilder().addFeature(feature).build() + StringFeatureListDecoder.decode(featureList) + } + } + } +} diff --git a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListEncoderTest.scala b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListEncoderTest.scala new file mode 100644 index 00000000..1129c39f --- /dev/null +++ 
b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListEncoderTest.scala @@ -0,0 +1,69 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.spark.datasources.tfrecords.serde + +import org.scalatest.{Matchers, WordSpec} +import org.tensorflow.spark.datasources.tfrecords.TestingUtils._ + +import scala.collection.JavaConverters._ + +class FeatureListEncoderTest extends WordSpec with Matchers { + + "Int64 feature list encoder" should { + + "Encode inputs to feature list of Int64" in { + val longListOfLists = Seq(Seq(3L,5L,Int.MaxValue+6L), Seq(-1L,-6L)) + val longFeatureList = Int64FeatureListEncoder.encode(longListOfLists) + + longFeatureList.getFeatureList.asScala.map(_.getInt64List.getValueList.asScala.toSeq) should equal (longListOfLists) + } + + "Encode empty array to empty feature list" in { + val longFeatureList = Int64FeatureListEncoder.encode(Seq.empty[Seq[Long]]) + assert(longFeatureList.getFeatureList.size() === 0) + } + } + + "Float feature list encoder" should { + + "Encode inputs to feature list of Float" in { + val floatListOfLists = Seq(Seq(-2.67F, 1.5F, 0F), Seq(-1.4F,-6F)) + val floatFeatureList = FloatFeatureListEncoder.encode(floatListOfLists) + + assert(floatFeatureList.getFeatureList.asScala.map(_.getFloatList.getValueList.asScala.map(_.toFloat).toSeq) ~== floatListOfLists) + } + + "Encode empty array to empty feature list" in { + val floatFeatureList = FloatFeatureListEncoder.encode(Seq.empty[Seq[Float]]) + assert(floatFeatureList.getFeatureList.size() === 0) + } + } + + "String feature list encoder" should { + + "Encode inputs to feature list of string" in { + val stringListOfLists = Seq(Seq("alice", "bob"), Seq("charles")) + val stringFeatureList = BytesFeatureListEncoder.encode(stringListOfLists) + + assert(stringFeatureList.getFeatureList.asScala.map(_.getBytesList.getValueList.asScala.map(_.toStringUtf8).toSeq) === stringListOfLists) + } + + "Encode empty array to empty feature list" in { + val stringFeatureList = BytesFeatureListEncoder.encode(Seq.empty[Seq[String]]) + assert(stringFeatureList.getFeatureList.size() === 0) + } + } +} diff --git a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/TfRecordRowDecoderTest.scala b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/TfRecordRowDecoderTest.scala new file mode 100644 index 00000000..e92b73d6 --- /dev/null +++ b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/TfRecordRowDecoderTest.scala @@ -0,0 +1,127 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.spark.datasources.tfrecords.serde + +import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.types._ +import org.scalatest.{Matchers, WordSpec} +import org.tensorflow.example._ +import org.tensorflow.hadoop.shaded.protobuf.ByteString +import org.tensorflow.spark.datasources.tfrecords.TestingUtils._ + +class TfRecordRowDecoderTest extends WordSpec with Matchers { + + "TensorFlow row decoder" should { + + "Decode given TensorFlow Example as Row" in { + + val schema = StructType(List( + StructField("IntegerLabel", IntegerType), + StructField("LongLabel", LongType), + StructField("FloatLabel", FloatType), + StructField("DoubleLabel", DoubleType), + StructField("LongArrayLabel", ArrayType(LongType)), + StructField("DoubleArrayLabel", ArrayType(DoubleType)), + StructField("StrLabel", StringType), + StructField("StrArrayLabel", ArrayType(StringType)) + )) + + val expectedRow = new GenericRow( + Array[Any](1, 23L, 10.0F, 14.0, Seq(-2L,7L), Seq(1.0, 2.0), "r1", Seq("r2", "r3")) + ) + + //Build example + val intFeature = Int64List.newBuilder().addValue(1) + val longFeature = Int64List.newBuilder().addValue(23L) + val floatFeature = FloatList.newBuilder().addValue(10.0F) + val doubleFeature = FloatList.newBuilder().addValue(14.0F) + val longArrFeature = Int64List.newBuilder().addValue(-2L).addValue(7L).build() + val doubleArrFeature = FloatList.newBuilder().addValue(1F).addValue(2F).build() + val strFeature = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build() + val strListFeature = BytesList.newBuilder().addValue(ByteString.copyFrom("r2".getBytes)) + .addValue(ByteString.copyFrom("r3".getBytes)).build() + val features = Features.newBuilder() + .putFeature("IntegerLabel", Feature.newBuilder().setInt64List(intFeature).build()) + .putFeature("LongLabel", Feature.newBuilder().setInt64List(longFeature).build()) + .putFeature("FloatLabel", Feature.newBuilder().setFloatList(floatFeature).build()) + .putFeature("DoubleLabel", Feature.newBuilder().setFloatList(doubleFeature).build()) + .putFeature("LongArrayLabel", Feature.newBuilder().setInt64List(longArrFeature).build()) + .putFeature("DoubleArrayLabel", Feature.newBuilder().setFloatList(doubleArrFeature).build()) + .putFeature("StrLabel", Feature.newBuilder().setBytesList(strFeature).build()) + .putFeature("StrArrayLabel", Feature.newBuilder().setBytesList(strListFeature).build()) + .build() + val example = Example.newBuilder() + .setFeatures(features) + .build() + + //Decode TensorFlow example to Sql Row + val actualRow = DefaultTfRecordRowDecoder.decodeExample(example, schema) + assert(actualRow ~== (expectedRow,schema)) + } + + "Decode given TensorFlow SequenceExample as Row" in { + + val schema = StructType(List( + StructField("LongArrayLabel", ArrayType(LongType)), + StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), + StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))), + StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType))) + )) + + val expectedRow = new GenericRow(Array[Any]( + 
Seq(-2L,7L), Seq(Seq(4L, 10L)), Seq(Seq(2.25F), Seq(-1.9F,3.5F)), Seq(Seq("r1", "r2"), Seq("r3"))) + ) + + //Build sequence example + val longArrFeature = Int64List.newBuilder().addValue(-2L).addValue(7L).build() + + val int64List1 = Int64List.newBuilder().addValue(4L).addValue(10L).build() + val intFeature1 = Feature.newBuilder().setInt64List(int64List1).build() + val int64FeatureList = FeatureList.newBuilder().addFeature(intFeature1).build() + + val floatList1 = FloatList.newBuilder().addValue(2.25F).build() + val floatList2 = FloatList.newBuilder().addValue(-1.9F).addValue(3.5F).build() + val floatFeature1 = Feature.newBuilder().setFloatList(floatList1).build() + val floatFeature2 = Feature.newBuilder().setFloatList(floatList2).build() + val floatFeatureList = FeatureList.newBuilder().addFeature(floatFeature1).addFeature(floatFeature2).build() + + val bytesList1 = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)) + .addValue(ByteString.copyFrom("r2".getBytes)).build() + val bytesList2 = BytesList.newBuilder().addValue(ByteString.copyFrom("r3".getBytes)).build() + val bytesFeature1 = Feature.newBuilder().setBytesList(bytesList1).build() + val bytesFeature2 = Feature.newBuilder().setBytesList(bytesList2).build() + val bytesFeatureList = FeatureList.newBuilder().addFeature(bytesFeature1).addFeature(bytesFeature2).build() + + val features = Features.newBuilder() + .putFeature("LongArrayLabel", Feature.newBuilder().setInt64List(longArrFeature).build()) + + val featureLists = FeatureLists.newBuilder() + .putFeatureList("LongArrayOfArrayLabel", int64FeatureList) + .putFeatureList("FloatArrayOfArrayLabel", floatFeatureList) + .putFeatureList("StrArrayOfArrayLabel", bytesFeatureList) + .build() + + val seqExample = SequenceExample.newBuilder() + .setContext(features) + .setFeatureLists(featureLists) + .build() + + //Decode TensorFlow example to Sql Row + val actualRow = DefaultTfRecordRowDecoder.decodeSequenceExample(seqExample, schema) + assert(actualRow ~== (expectedRow, schema)) + } + } +} diff --git a/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/TfRecordRowEncoderTest.scala b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/TfRecordRowEncoderTest.scala new file mode 100644 index 00000000..f2f6540f --- /dev/null +++ b/spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/TfRecordRowEncoderTest.scala @@ -0,0 +1,105 @@ +/** + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.tensorflow.spark.datasources.tfrecords.serde + +import org.tensorflow.example.Feature +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.types._ +import org.scalatest.{Matchers, WordSpec} +import scala.collection.JavaConverters._ +import org.tensorflow.spark.datasources.tfrecords.TestingUtils._ + +class TfRecordRowEncoderTest extends WordSpec with Matchers { + + "TensorFlow row encoder" should { + + "Encode given Row as TensorFlow Example" in { + val schemaStructType = StructType(Array( + StructField("IntegerLabel", IntegerType), + StructField("LongLabel", LongType), + StructField("FloatLabel", FloatType), + StructField("DoubleLabel", DoubleType), + StructField("DoubleArrayLabel", ArrayType(DoubleType)), + StructField("StrLabel", StringType), + StructField("StrArrayLabel", ArrayType(StringType)) + )) + val doubleArray = Array(1.1, 111.1, 11111.1) + val expectedFloatArray = Array(1.1F, 111.1F, 11111.1F) + + val rowWithSchema = new GenericRowWithSchema(Array[Any](1, 23L, 10.0F, 14.0, doubleArray, "r1", Seq("r2", "r3")), schemaStructType) + + //Encode Sql Row to TensorFlow example + val example = DefaultTfRecordRowEncoder.encodeExample(rowWithSchema) + + //Verify each Datatype converted to TensorFlow datatypes + val featureMap = example.getFeatures.getFeatureMap.asScala + assert(featureMap.size == 7) + + assert(featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) + assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 1) + + assert(featureMap("LongLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) + assert(featureMap("LongLabel").getInt64List.getValue(0).toInt == 23) + + assert(featureMap("FloatLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert(featureMap("FloatLabel").getFloatList.getValue(0) == 10.0F) + + assert(featureMap("DoubleLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert(featureMap("DoubleLabel").getFloatList.getValue(0) == 14.0F) + + assert(featureMap("DoubleArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert(featureMap("DoubleArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== expectedFloatArray) + + assert(featureMap("StrLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) + assert(featureMap("StrLabel").getBytesList.getValue(0).toStringUtf8 == "r1") + + assert(featureMap("StrArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) + assert(featureMap("StrArrayLabel").getBytesList.getValueList.asScala.map(_.toStringUtf8) === Seq("r2", "r3")) + } + + "Encode given Row as TensorFlow SequenceExample" in { + + val schemaStructType = StructType(Array( + StructField("IntegerLabel", IntegerType), + StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), + StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))), + StructField("StringArrayOfArrayLabel", ArrayType(ArrayType(StringType))) + )) + + val longListOfLists = Seq(Seq(3L, 5L), Seq(-8L, 0L)) + val floatListOfLists = Seq(Seq(1.5F, -6.5F), Seq(-8.2F, 0F)) + val stringListOfLists = Seq(Seq("r1"), Seq("r2", "r3"), Seq("r4")) + + val rowWithSchema = new GenericRowWithSchema(Array[Any](10, longListOfLists, floatListOfLists, stringListOfLists), schemaStructType) + + //Encode Sql Row to TensorFlow example + val seqExample = DefaultTfRecordRowEncoder.encodeSequenceExample(rowWithSchema) + + //Verify each Datatype converted to TensorFlow datatypes + val featureMap = 
seqExample.getContext.getFeatureMap.asScala + val featureListMap = seqExample.getFeatureLists.getFeatureListMap.asScala + + assert(featureMap.size == 1) + assert(featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) + assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 10) + + assert(featureListMap.size == 3) + assert(featureListMap("LongArrayOfArrayLabel").getFeatureList.asScala.map(_.getInt64List.getValueList.asScala.toSeq) === longListOfLists) + assert(featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.map(_.getFloatList.getValueList.asScala.map(_.toFloat).toSeq) ~== floatListOfLists) + assert(featureListMap("StringArrayOfArrayLabel").getFeatureList.asScala.map(_.getBytesList.getValueList.asScala.map(_.toStringUtf8).toSeq) === stringListOfLists) + } + } +}
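
For quick reference, the round trip exercised by TensorFlowSuite above reduces to the following spark-shell sketch. This is a minimal illustration, not part of the diff: the output path is an assumed placeholder, and `spark` is the session provided by the shell.

```
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// Assumed illustrative path; any local or distributed filesystem path works.
val path = "file:///tmp/sequence-example.tfrecord"

// Each row holds a scalar id plus two list-of-list columns, matching the
// SequenceExample feature-list types added by this patch.
val rows = Seq(
  Row(23L, Seq(Seq(2.0F, 4.5F)), Seq(Seq("r1", "r2"))),
  Row(24L, Seq(Seq(-1.0F, 0F)), Seq(Seq("r3"))))

val schema = StructType(List(
  StructField("id", LongType),
  StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))),
  StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType)))))

val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

// Write the DataFrame out as TFRecords of type SequenceExample...
df.write.format("tfrecords").option("recordType", "SequenceExample").save(path)

// ...and read it back; omitting .schema(...) triggers schema inference,
// at the cost of an extra pass over the data.
val roundTrip = spark.read.format("tfrecords")
  .option("recordType", "SequenceExample")
  .schema(schema)
  .load(path)
roundTrip.show()
```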