Add Spark Tensorflow Connector (tensorflow#34)
* Add spark-tensorflow-connector to ecosystem

* Update Readme

* Add feedback changes
karthikvadla authored and jhseu committed Mar 7, 2017
1 parent fd225df commit f2a87ca
Showing 19 changed files with 1,858 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -17,6 +17,7 @@ request.
Marathon, deployed on top of Mesos.
- [hadoop](hadoop) - TFRecord file InputFormat/OutputFormat for Hadoop MapReduce
and Spark.
- [spark](spark) - Spark TensorFlow Connector

## Distributed TensorFlow

85 changes: 85 additions & 0 deletions spark/README.md
@@ -0,0 +1,85 @@
# spark-tensorflow-connector

This repo contains a library for loading and storing TensorFlow records with [Apache Spark](http://spark.apache.org/).
The library implements data import from the standard TensorFlow record format ([TFRecords](https://www.tensorflow.org/how_tos/reading_data/)) into Spark SQL DataFrames, and data export from DataFrames to TensorFlow records.

## What's new

This is the initial release of the `spark-tensorflow-connector` repo.

## Known issues

None.

## Prerequisites

1. [Apache Spark 2.0 (or later)](http://spark.apache.org/)

2. [Apache Maven](https://maven.apache.org/)

## Building the library
You can build the library with either the Maven or SBT build tool.

#### Maven
Build the library using Maven (3.3) as shown below:

```sh
mvn clean install
```
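
Note: the connector depends on the `tensorflow-hadoop` artifact from the [hadoop](../hadoop) directory of this repository. If that snapshot is not reachable from the configured repositories, one workaround sketch is to build and install it into the local Maven repository first:

```sh
# Assumption: the sibling hadoop directory provides its own Maven build
# that produces the tensorflow-hadoop jar.
cd ../hadoop
mvn clean install
cd ../spark
```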

#### SBT
Build the library using SBT (0.13.13) as shown below:
```sh
sbt clean assembly
```

## Using Spark Shell
Run this library in Spark using the `--jars` command-line option of `spark-shell` or `spark-submit`. For example:

With the Maven-built jars:
```sh
$SPARK_HOME/bin/spark-shell --jars target/spark-tensorflow-connector-1.0-SNAPSHOT.jar,target/lib/tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar
```

With the SBT-built assembly jar:
```sh
$SPARK_HOME/bin/spark-shell --jars target/scala-2.11/spark-tensorflow-connector-assembly-1.0.0.jar
```
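
The same jar can be supplied to `spark-submit` for a packaged application; a sketch, where `com.example.YourApp` and `your-app.jar` are hypothetical placeholders:

```sh
# Submit a packaged Spark application with the connector on the classpath.
$SPARK_HOME/bin/spark-submit \
  --class com.example.YourApp \
  --jars target/scala-2.11/spark-tensorflow-connector-assembly-1.0.0.jar \
  your-app.jar
```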

The following code snippet demonstrates usage.

```scala
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

val path = "test-output.tfr"
val testRows: Array[Row] = Array(
  new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
  new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
val schema = StructType(List(
  StructField("id", IntegerType),
  StructField("IntegerTypelabel", IntegerType),
  StructField("LongTypelabel", LongType),
  StructField("FloatTypelabel", FloatType),
  StructField("DoubleTypelabel", DoubleType),
  StructField("vectorlabel", ArrayType(DoubleType, true)),
  StructField("name", StringType)))

val rdd = spark.sparkContext.parallelize(testRows)

// Save the DataFrame as TFRecords.
val df: DataFrame = spark.createDataFrame(rdd, schema)
df.write.format("tfrecords").save(path)

// Read TFRecords into a DataFrame.
// The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
val importedDf1: DataFrame = spark.read.format("tfrecords").load(path)
importedDf1.show()

// Read TFRecords into a DataFrame using a custom schema.
val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
importedDf2.show()

```
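
Note that `save` fails if the output path already exists (Spark's default `SaveMode.ErrorIfExists`). A minimal cleanup sketch for re-running the snippet, assuming the same `path` as above:

```scala
import java.io.File
import org.apache.commons.io.FileUtils

// Remove the TFRecord output directory so the example can be re-run.
FileUtils.deleteQuietly(new File("test-output.tfr"))
```
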
93 changes: 93 additions & 0 deletions spark/build.sbt
@@ -0,0 +1,93 @@
name := "spark-tensorflow-connector"

organization := "org.tensorflow"

scalaVersion in Global := "2.11.8"

spName := "tensorflow/spark-tensorflow-connector"

sparkVersion := "2.1.0"

sparkComponents ++= Seq("sql", "mllib")

version := "1.0.0"

def ProjectName(name: String, path: String): Project = Project(name, file(path))

resolvers in Global ++= Seq(
  "https://tap.jfrog.io/tap/public" at "https://tap.jfrog.io/tap/public",
  "https://tap.jfrog.io/tap/public-snapshots" at "https://tap.jfrog.io/tap/public-snapshots",
  "https://repo.maven.apache.org/maven2" at "https://repo.maven.apache.org/maven2")

val `junit_junit` = "junit" % "junit" % "4.12"

val `org.apache.hadoop_hadoop-yarn-api` = "org.apache.hadoop" % "hadoop-yarn-api" % "2.7.3"

val `org.apache.spark_spark-core_2.11` = "org.apache.spark" % "spark-core_2.11" % "2.1.0"

val `org.apache.spark_spark-sql_2.11` = "org.apache.spark" % "spark-sql_2.11" % "2.1.0"

val `org.apache.spark_spark-mllib_2.11` = "org.apache.spark" % "spark-mllib_2.11" % "2.1.0"

val `org.scalatest_scalatest_2.11` = "org.scalatest" % "scalatest_2.11" % "2.2.6"

val `org.tensorflow_tensorflow-hadoop` = "org.tensorflow" % "tensorflow-hadoop" % "1.0-01232017-SNAPSHOT"

libraryDependencies in Global ++= Seq(
  `org.tensorflow_tensorflow-hadoop` classifier "shaded-protobuf",
  `org.scalatest_scalatest_2.11` % "test",
  `org.apache.spark_spark-sql_2.11` % "provided",
  `org.apache.spark_spark-mllib_2.11` % "test" classifier "tests",
  `org.apache.spark_spark-core_2.11` % "provided",
  `org.apache.hadoop_hadoop-yarn-api` % "provided",
  `junit_junit` % "test")

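// Include only the connector jar and the shaded tensorflow-hadoop jar in the
// assembly; every other jar on the classpath is excluded from the fat jar.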
assemblyExcludedJars in assembly := {
  val cp = (fullClasspath in assembly).value
  cp filterNot { x =>
    List("spark-tensorflow-connector-1.0-SNAPSHOT.jar",
      "tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar").contains(x.data.getName)
  }
}

/********************
* Release settings *
********************/

spIgnoreProvided := true

spAppendScalaVersion := true

// If you published your package to Maven Central for this release (must be done prior to spPublish)
spIncludeMaven := false

publishMavenStyle := true

licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0"))

pomExtra :=
<url>https://github.com/tensorflow/ecosystem</url>
<scm>
<url>git@github.com:tensorflow/ecosystem.git</url>
<connection>scm:git:git@github.com:tensorflow/ecosystem.git</connection>
</scm>
<developers>
<developer>
<id>karthikvadla</id>
<name>Karthik Vadla</name>
<url>https://github.com/karthikvadla</url>
</developer>
<developer>
<id>skavulya</id>
<name>Soila Kavulya</name>
<url>https://github.com/skavulya</url>
</developer>
<developer>
<id>joyeshmishra</id>
<name>Joyesh Mishra</name>
<url>https://github.com/joyeshmishra</url>
</developer>
</developers>

credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials") // A file containing credentials

// Add assembly jar to Spark package
test in assembly := {}

spShade := true
