Add Spark Tensorflow Connector (tensorflow#34)

* Add spark-tensorflow-connector to ecosystem
* Update Readme
* Add feedback changes

1 parent fd225df · commit f2a87ca
Showing 19 changed files with 1,858 additions and 0 deletions.
# spark-tensorflow-connector

This repo contains a library for loading and storing TensorFlow records with [Apache Spark](http://spark.apache.org/).
The library implements data import from the standard TensorFlow record format ([TFRecords](https://www.tensorflow.org/how_tos/reading_data/)) into Spark SQL DataFrames, and data export from DataFrames to TensorFlow records.

## What's new

This is the initial release of the `spark-tensorflow-connector` repo.

## Known issues

None.

## Prerequisites

1. [Apache Spark 2.0 (or later)](http://spark.apache.org/)

2. [Apache Maven](https://maven.apache.org/)

## Building the library

You can build the library with either Maven or SBT.

#### Maven

Build the library using Maven (3.3) as shown below:

```sh
mvn clean install
```

#### SBT

Build the library using SBT (0.13.13) as shown below:

```sh
sbt clean assembly
```

## Using Spark Shell

Run this library in Spark using the `--jars` command line option in `spark-shell` or `spark-submit`. For example:

Maven Jars
```sh
$SPARK_HOME/bin/spark-shell --jars target/spark-tensorflow-connector-1.0-SNAPSHOT.jar,target/lib/tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar
```

SBT Jars
```sh
$SPARK_HOME/bin/spark-shell --jars target/scala-2.11/spark-tensorflow-connector-assembly-1.0.0.jar
```
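
The snippet below assumes the `spark-shell` REPL, where a `SparkSession` named `spark` is already defined. When packaging the same code into a standalone application for `spark-submit`, create the session yourself; a minimal sketch (the app name is a hypothetical placeholder):

```scala
import org.apache.spark.sql.SparkSession

// spark-shell predefines `spark`; a standalone application must build the
// session explicitly before using the connector.
val spark = SparkSession
  .builder()
  .appName("spark-tensorflow-connector-example") // hypothetical app name
  .getOrCreate()
```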

The following code snippet demonstrates usage.

```scala
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

val path = "test-output.tfr"
val testRows: Array[Row] = Array(
  new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
  new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
val schema = StructType(List(StructField("id", IntegerType),
  StructField("IntegerTypelabel", IntegerType),
  StructField("LongTypelabel", LongType),
  StructField("FloatTypelabel", FloatType),
  StructField("DoubleTypelabel", DoubleType),
  StructField("vectorlabel", ArrayType(DoubleType, true)),
  StructField("name", StringType)))

val rdd = spark.sparkContext.parallelize(testRows)

// Save DataFrame as TFRecords
val df: DataFrame = spark.createDataFrame(rdd, schema)
df.write.format("tfrecords").save(path)

// Read TFRecords into DataFrame.
// The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
val importedDf1: DataFrame = spark.read.format("tfrecords").load(path)
importedDf1.show()

// Read TFRecords into DataFrame using custom schema
val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
importedDf2.show()
```
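
One note on the `FileUtils` import above: the connector writes the records as a directory of part files, and `save` fails if the output path already exists, so it is worth clearing the path between runs. A minimal sketch:

```scala
import java.io.File
import org.apache.commons.io.FileUtils

// Remove the TFRecords output directory recursively; deleteQuietly never
// throws, so this is safe even when the path does not exist yet.
FileUtils.deleteQuietly(new File(path))
```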

The commit also adds the project's SBT build definition, `build.sbt`:
```scala
name := "spark-tensorflow-connector"

organization := "org.tensorflow"

scalaVersion in Global := "2.11.8"

spName := "tensorflow/spark-tensorflow-connector"

sparkVersion := "2.1.0"

sparkComponents ++= Seq("sql", "mllib")

version := "1.0.0"

def ProjectName(name: String, path: String): Project = Project(name, file(path))

resolvers in Global ++= Seq(
  "https://tap.jfrog.io/tap/public" at "https://tap.jfrog.io/tap/public",
  "https://tap.jfrog.io/tap/public-snapshots" at "https://tap.jfrog.io/tap/public-snapshots",
  "https://repo.maven.apache.org/maven2" at "https://repo.maven.apache.org/maven2")

val `junit_junit` = "junit" % "junit" % "4.12"

val `org.apache.hadoop_hadoop-yarn-api` = "org.apache.hadoop" % "hadoop-yarn-api" % "2.7.3"

val `org.apache.spark_spark-core_2.11` = "org.apache.spark" % "spark-core_2.11" % "2.1.0"

val `org.apache.spark_spark-sql_2.11` = "org.apache.spark" % "spark-sql_2.11" % "2.1.0"

val `org.apache.spark_spark-mllib_2.11` = "org.apache.spark" % "spark-mllib_2.11" % "2.1.0"

val `org.scalatest_scalatest_2.11` = "org.scalatest" % "scalatest_2.11" % "2.2.6"

val `org.tensorflow_tensorflow-hadoop` = "org.tensorflow" % "tensorflow-hadoop" % "1.0-01232017-SNAPSHOT"

libraryDependencies in Global ++= Seq(
  `org.tensorflow_tensorflow-hadoop` classifier "shaded-protobuf",
  `org.scalatest_scalatest_2.11` % "test",
  `org.apache.spark_spark-sql_2.11` % "provided",
  `org.apache.spark_spark-mllib_2.11` % "test" classifier "tests",
  `org.apache.spark_spark-core_2.11` % "provided",
  `org.apache.hadoop_hadoop-yarn-api` % "provided",
  `junit_junit` % "test")

assemblyExcludedJars in assembly := {
  val cp = (fullClasspath in assembly).value
  cp filterNot { x =>
    List("spark-tensorflow-connector-1.0-SNAPSHOT.jar",
      "tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar").contains(x.data.getName)
  }
}

/********************
 * Release settings *
 ********************/

spIgnoreProvided := true

spAppendScalaVersion := true

// If you published your package to Maven Central for this release (must be done prior to spPublish)
spIncludeMaven := false

publishMavenStyle := true

licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0"))

pomExtra :=
  <url>https://github.com/tensorflow/ecosystem</url>
  <scm>
    <url>git@github.com:tensorflow/ecosystem.git</url>
    <connection>scm:git:git@github.com:tensorflow/ecosystem.git</connection>
  </scm>
  <developers>
    <developer>
      <id>karthikvadla</id>
      <name>Karthik Vadla</name>
      <url>https://github.com/karthikvadla</url>
    </developer>
    <developer>
      <id>skavulya</id>
      <name>Soila Kavulya</name>
      <url>https://github.com/skavulya</url>
    </developer>
    <developer>
      <id>joyeshmishra</id>
      <name>Joyesh Mishra</name>
      <url>https://github.com/joyeshmishra</url>
    </developer>
  </developers>

credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials") // A file containing credentials

// Add assembly jar to Spark package
test in assembly := {}

spShade := true
```
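
A downstream SBT project can then depend on the locally built connector. A minimal sketch of the relevant consumer-side lines, assuming `mvn clean install` has installed the artifact into the local Maven repository and that the Maven build uses the same `org.tensorflow` organization declared above (the Maven artifacts are versioned `1.0-SNAPSHOT`, per the jar names earlier in this README):

```scala
// Hypothetical consumer build.sbt: resolve the connector from the local
// Maven repository populated by `mvn clean install`.
resolvers += Resolver.mavenLocal

// Assumed coordinates, matching the organization/name declared above and the
// 1.0-SNAPSHOT version produced by the Maven build.
libraryDependencies += "org.tensorflow" % "spark-tensorflow-connector" % "1.0-SNAPSHOT"
```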