Skip to content

Commit

Permalink
remove elephas optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Mar 20, 2019
1 parent 76d49e5 commit 6c5a74e
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 362 deletions.
18 changes: 0 additions & 18 deletions elephas/ml/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,6 @@ def get_keras_model_config(self):
return self.getOrDefault(self.keras_model_config)


class HasElephasOptimizerConfig(Params):
    """Mixin that adds an ``elephas_optimizer_config`` parameter.

    The parameter is registered with the description
    "Serialized Elephas optimizer properties" and a default of ``None``.
    """

    def __init__(self):
        super(HasElephasOptimizerConfig, self).__init__()
        description = "Serialized Elephas optimizer properties"
        self.elephas_optimizer_config = Param(
            self, "elephas_optimizer_config", description)
        # No optimizer config is set unless the caller provides one.
        self._setDefault(elephas_optimizer_config=None)

    def set_elephas_optimizer_config(self, elephas_optimizer_config):
        """Store the given optimizer config and return ``self``."""
        param = self.elephas_optimizer_config
        # Written straight into the param map, keyed by the Param object.
        self._paramMap[param] = elephas_optimizer_config
        return self

    def get_elephas_optimizer_config(self):
        """Return the result of ``getOrDefault`` for the optimizer config."""
        param = self.elephas_optimizer_config
        return self.getOrDefault(param)


class HasMode(Params):
"""Parameter mixin for Elephas mode
"""
Expand Down
26 changes: 10 additions & 16 deletions elephas/ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
from .utils.rdd_utils import from_vector
from .ml.adapter import df_to_simple_rdd
from .ml.params import *
from .optimizers import get


class ElephasEstimator(Estimator, HasCategoricalLabels, HasValidationSplit, HasKerasModelConfig, HasFeaturesCol,
HasLabelCol, HasMode, HasEpochs, HasBatchSize, HasFrequency, HasVerbosity, HasNumberOfClasses,
HasNumberOfWorkers, HasElephasOptimizerConfig, HasOutputCol, HasLoss,
HasNumberOfWorkers, HasOutputCol, HasLoss,
HasMetrics, HasKerasOptimizerConfig):
"""
SparkML Estimator implementation of an elephas model. This estimator takes all relevant arguments for model
Expand All @@ -38,7 +37,6 @@ def __init__(self, **kwargs):

def get_config(self):
return {'keras_model_config': self.get_keras_model_config(),
'elephas_optimizer_config': self.get_elephas_optimizer_config(),
'mode': self.get_mode(),
'frequency': self.get_frequency(),
'num_workers': self.get_num_workers(),
Expand Down Expand Up @@ -77,27 +75,26 @@ def _fit(self, df):
simple_rdd = df_to_simple_rdd(df, categorical=self.get_categorical_labels(), nb_classes=self.get_nb_classes(),
features_col=self.getFeaturesCol(), label_col=self.getLabelCol())
simple_rdd = simple_rdd.repartition(self.get_num_workers())
elephas_optimizer = None
if self.get_elephas_optimizer_config() is not None:
elephas_optimizer = get({'class_name': self.get_optimizer_config()['class_name'],
'config': self.get_optimizer_config()})

keras_model = model_from_yaml(self.get_keras_model_config())
metrics = self.get_metrics()
loss = self.get_loss()
optimizer = get_optimizer(self.get_optimizer_config())
keras_model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

spark_model = SparkModel(model=keras_model, elephas_optimizer=elephas_optimizer,
mode=self.get_mode(), frequency=self.get_frequency(),
spark_model = SparkModel(model=keras_model,
mode=self.get_mode(),
frequency=self.get_frequency(),
num_workers=self.get_num_workers())
spark_model.fit(simple_rdd, epochs=self.get_epochs(), batch_size=self.get_batch_size(),
verbose=self.get_verbosity(), validation_split=self.get_validation_split())
spark_model.fit(simple_rdd,
epochs=self.get_epochs(),
batch_size=self.get_batch_size(),
verbose=self.get_verbosity(),
validation_split=self.get_validation_split())

model_weights = spark_model.master_network.get_weights()
weights = simple_rdd.ctx.broadcast(model_weights)
return ElephasTransformer(labelCol=self.getLabelCol(),
outputCol='prediction', # TODO: Set default value
outputCol='prediction',
keras_model_config=spark_model.master_network.to_yaml(),
weights=weights)

Expand Down Expand Up @@ -165,9 +162,6 @@ def _transform(self, df):
predictions = predictions.map(lambda x: tuple(str(x)))

results_rdd = rdd.zip(predictions).map(lambda x: x[0] + x[1])
# TODO: Zipping like this is very likely wrong
# results_rdd = rdd.zip(predictions).map(lambda pair: Row(features=to_vector(pair[0].features),
# label=pair[0].label, prediction=float(pair[1])))
results_df = df.sql_ctx.createDataFrame(results_rdd, new_schema)
results_df = results_df.withColumn(
output_col, results_df[output_col].cast(DoubleType()))
Expand Down
297 changes: 0 additions & 297 deletions elephas/optimizers.py

This file was deleted.

Loading

0 comments on commit 6c5a74e

Please sign in to comment.