From 6c5a74e86838cfcb8b7ff9289341dc2bf383493d Mon Sep 17 00:00:00 2001
From: Max Pumperla
Date: Wed, 20 Mar 2019 15:57:20 +0100
Subject: [PATCH] remove elephas optimizer

---
 elephas/ml/params.py        |  18 ---
 elephas/ml_model.py         |  26 ++--
 elephas/optimizers.py       | 297 ------------------------------------
 elephas/parameter/server.py |  24 ++-
 elephas/spark_model.py      |  24 +--
 elephas/worker.py           |   2 +-
 6 files changed, 29 insertions(+), 362 deletions(-)
 delete mode 100644 elephas/optimizers.py

diff --git a/elephas/ml/params.py b/elephas/ml/params.py
index 0cf62ce..de8a043 100644
--- a/elephas/ml/params.py
+++ b/elephas/ml/params.py
@@ -21,24 +21,6 @@ def get_keras_model_config(self):
         return self.getOrDefault(self.keras_model_config)
 
 
-class HasElephasOptimizerConfig(Params):
-    """Parameter mixin for Elephas optimizer config
-    """
-
-    def __init__(self):
-        super(HasElephasOptimizerConfig, self).__init__()
-        self.elephas_optimizer_config = Param(self, "elephas_optimizer_config",
-                                              "Serialized Elephas optimizer properties")
-        self._setDefault(elephas_optimizer_config=None)
-
-    def set_elephas_optimizer_config(self, elephas_optimizer_config):
-        self._paramMap[self.elephas_optimizer_config] = elephas_optimizer_config
-        return self
-
-    def get_elephas_optimizer_config(self):
-        return self.getOrDefault(self.elephas_optimizer_config)
-
-
 class HasMode(Params):
     """Parameter mixin for Elephas mode
     """
diff --git a/elephas/ml_model.py b/elephas/ml_model.py
index 5f91c0d..b79e15c 100644
--- a/elephas/ml_model.py
+++ b/elephas/ml_model.py
@@ -18,12 +18,11 @@
 from .utils.rdd_utils import from_vector
 from .ml.adapter import df_to_simple_rdd
 from .ml.params import *
-from .optimizers import get
 
 
 class ElephasEstimator(Estimator, HasCategoricalLabels, HasValidationSplit, HasKerasModelConfig, HasFeaturesCol,
                        HasLabelCol, HasMode, HasEpochs, HasBatchSize, HasFrequency, HasVerbosity, HasNumberOfClasses,
-                       HasNumberOfWorkers, HasElephasOptimizerConfig, HasOutputCol, HasLoss,
+                       HasNumberOfWorkers, HasOutputCol, HasLoss,
                        HasMetrics, HasKerasOptimizerConfig):
     """
     SparkML Estimator implementation of an elephas model. This estimator takes all relevant arguments for model
@@ -38,7 +37,6 @@ def __init__(self, **kwargs):
 
     def get_config(self):
         return {'keras_model_config': self.get_keras_model_config(),
-                'elephas_optimizer_config': self.get_elephas_optimizer_config(),
                 'mode': self.get_mode(),
                 'frequency': self.get_frequency(),
                 'num_workers': self.get_num_workers(),
@@ -77,27 +75,26 @@ def _fit(self, df):
         simple_rdd = df_to_simple_rdd(df, categorical=self.get_categorical_labels(), nb_classes=self.get_nb_classes(),
                                       features_col=self.getFeaturesCol(), label_col=self.getLabelCol())
         simple_rdd = simple_rdd.repartition(self.get_num_workers())
-        elephas_optimizer = None
-        if self.get_elephas_optimizer_config() is not None:
-            elephas_optimizer = get({'class_name': self.get_optimizer_config()['class_name'],
-                                     'config': self.get_optimizer_config()})
-
         keras_model = model_from_yaml(self.get_keras_model_config())
         metrics = self.get_metrics()
         loss = self.get_loss()
         optimizer = get_optimizer(self.get_optimizer_config())
         keras_model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
 
-        spark_model = SparkModel(model=keras_model, elephas_optimizer=elephas_optimizer,
-                                 mode=self.get_mode(), frequency=self.get_frequency(),
+        spark_model = SparkModel(model=keras_model,
+                                 mode=self.get_mode(),
+                                 frequency=self.get_frequency(),
                                  num_workers=self.get_num_workers())
-        spark_model.fit(simple_rdd, epochs=self.get_epochs(), batch_size=self.get_batch_size(),
-                        verbose=self.get_verbosity(), validation_split=self.get_validation_split())
+        spark_model.fit(simple_rdd,
+                        epochs=self.get_epochs(),
+                        batch_size=self.get_batch_size(),
+                        verbose=self.get_verbosity(),
+                        validation_split=self.get_validation_split())
 
         model_weights = spark_model.master_network.get_weights()
         weights = simple_rdd.ctx.broadcast(model_weights)
         return ElephasTransformer(labelCol=self.getLabelCol(),
-                                  outputCol='prediction',  # TODO: Set default value
+                                  outputCol='prediction',
                                   keras_model_config=spark_model.master_network.to_yaml(),
                                   weights=weights)
 
@@ -165,9 +162,6 @@ def _transform(self, df):
         predictions = predictions.map(lambda x: tuple(str(x)))
 
         results_rdd = rdd.zip(predictions).map(lambda x: x[0] + x[1])
-        # TODO: Zipping like this is very likely wrong
-        # results_rdd = rdd.zip(predictions).map(lambda pair: Row(features=to_vector(pair[0].features),
-        #                                                         label=pair[0].label, prediction=float(pair[1])))
         results_df = df.sql_ctx.createDataFrame(results_rdd, new_schema)
         results_df = results_df.withColumn(
             output_col, results_df[output_col].cast(DoubleType()))
diff --git a/elephas/optimizers.py b/elephas/optimizers.py
deleted file mode 100644
index a938d4b..0000000
--- a/elephas/optimizers.py
+++ /dev/null
@@ -1,297 +0,0 @@
-"""
-This is essentially a copy of keras' optimizers.py.
-We have to modify the base class 'Optimizer' here,
-as the gradients will be provided by the Spark workers,
-not by one of the backends (Theano or Tensorflow).
-"""
-from __future__ import absolute_import
-from keras import backend as K
-from keras.optimizers import TFOptimizer
-from keras.utils import deserialize_keras_object, serialize_keras_object
-import numpy as np
-import six
-import tensorflow as tf
-from six.moves import zip
-
-
-def clip_norm(g, c, n):
-    """Clip gradients
-    """
-    if c > 0:
-        g = K.switch(K.ge(n, c), g * c / n, g)
-    return g
-
-
-def kl_divergence(p, p_hat):
-    """Kullbach-Leibler divergence """
-    return p_hat - p + p * K.log(p / p_hat)
-
-
-class Optimizer(object):
-    """Optimizer for elephas models, adapted from
-    respective Keras module.
-    """
-
-    def __init__(self, **kwargs):
-        self.__dict__.update(kwargs)
-        self.updates = []
-
-    def get_state(self):
-        """ Get latest status of optimizer updates """
-        return [u[0].get_value() for u in self.updates]
-
-    def set_state(self, value_list):
-        """ Set current status of optimizer """
-        assert len(self.updates) == len(value_list)
-        for u, v in zip(self.updates, value_list):
-            u[0].set_value(v)
-
-    def get_updates(self, params, constraints, grads):
-        """ Compute updates from gradients and constraints """
-        raise NotImplementedError
-
-    def get_gradients(self, grads, params):
-
-        if hasattr(self, 'clipnorm') and self.clipnorm > 0:
-            norm = K.sqrt(sum([K.sum(g ** 2) for g in grads]))
-            grads = [clip_norm(g, self.clipnorm, norm) for g in grads]
-
-        if hasattr(self, 'clipvalue') and self.clipvalue > 0:
-            grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads]
-
-        return K.shared(grads)
-
-    def get_config(self):
-        """ Get configuration dictionary """
-        return {"class_name": self.__class__.__name__}
-
-
-class SGD(Optimizer):
-    """SGD, optionally with nesterov momentum """
-
-    def __init__(self, lr=0.01, momentum=0., decay=0.,
-                 nesterov=False, *args, **kwargs):
-        super(SGD, self).__init__(**kwargs)
-        self.__dict__.update(locals())
-        self.iterations = 0
-        self.lr = lr
-        self.momentum = momentum
-        self.decay = decay
-
-    def get_updates(self, params, constraints, grads):
-        lr = self.lr * (1.0 / (1.0 + self.decay * self.iterations))
-        self.updates = [(self.iterations, self.iterations + 1.)]
-        new_weights = []
-
-        for p, g, c in zip(params, grads, constraints):
-            m = np.zeros_like(p)  # momentum
-            v = self.momentum * m - lr * g  # velocity
-            if self.nesterov:
-                new_p = p + self.momentum * v - lr * g
-            else:
-                new_p = p + v
-            new_weights.append(c(new_p))
-
-        return new_weights
-
-    def get_config(self):
-        return {"class_name": self.__class__.__name__,
-                "lr": float(self.lr),
-                "momentum": float(self.momentum),
-                "decay": float(self.decay),
-                "nesterov": self.nesterov}
-
-
-class RMSprop(Optimizer):
-    """Reference: www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
-    """
-
-    def __init__(self, lr=0.001, rho=0.9, epsilon=1e-6, *args, **kwargs):
-        super(RMSprop, self).__init__(**kwargs)
-        self.__dict__.update(locals())
-        self.lr = lr
-        self.rho = rho
-
-    def get_updates(self, params, constraints, grads):
-        accumulators = [np.zeros_like(p) for p in params]
-        new_weights = []
-
-        for p, g, a, c in zip(params, grads, accumulators, constraints):
-            new_a = self.rho * a + (1 - self.rho) * g ** 2
-            self.updates.append((a, new_a))
-
-            new_p = p - self.lr * g / np.sqrt(new_a + self.epsilon)
-            new_weights.append(c(new_p))
-
-        return new_weights
-
-    def get_config(self):
-        return {"class_name": self.__class__.__name__,
-                "lr": float(self.lr),
-                "rho": float(self.rho),
-                "epsilon": self.epsilon}
-
-
-class Adagrad(Optimizer):
-    """Reference: http://www.magicbroom.info/Papers/DuchiHaSi10.pdf
-    """
-
-    def __init__(self, lr=0.01, epsilon=1e-6, *args, **kwargs):
-        super(Adagrad, self).__init__(**kwargs)
-        self.__dict__.update(locals())
-        self.lr = lr
-
-    def get_updates(self, params, constraints, grads):
-        accumulators = [np.zeros_like(p) for p in params]
-        new_weights = []
-        for p, g, a, c in zip(params, grads, accumulators, constraints):
-            new_a = a + g ** 2
-            new_p = p - self.lr * g / np.sqrt(new_a + self.epsilon)
-            new_weights.append(new_p)
-
-        return new_weights
-
-    def get_config(self):
-        return {"class_name": self.__class__.__name__,
-                "lr": float(self.lr),
-                "epsilon": self.epsilon}
-
-
-class Adadelta(Optimizer):
-    """Reference: http://arxiv.org/abs/1212.5701
-    """
-
-    def __init__(self, lr=1.0, rho=0.95, epsilon=1e-6, *args, **kwargs):
-        super(Adadelta, self).__init__(**kwargs)
-        self.__dict__.update(locals())
-        self.lr = lr
-
-    def get_updates(self, params, constraints, grads):
-        accumulators = [np.zeros_like(p) for p in params]
-        delta_accumulators = [np.zeros_like(p) for p in params]
-        new_weights = []
-
-        for p, g, a, d_a, c in zip(params, grads, accumulators,
-                                   delta_accumulators, constraints):
-            new_a = self.rho * a + (1 - self.rho) * g ** 2
-            self.updates.append((a, new_a))
-            # use the new accumulator and the *old* delta_accumulator
-            div = np.sqrt(new_a + self.epsilon)
-            update = g * np.sqrt(d_a + self.epsilon) / div
-            new_p = p - self.lr * update
-            self.updates.append((p, c(new_p)))  # apply constraints
-
-            new_weights.append(new_p)
-        return new_weights
-
-    def get_config(self):
-        return {"class_name": self.__class__.__name__,
-                "lr": float(self.lr),
-                "rho": self.rho,
-                "epsilon": self.epsilon}
-
-
-class Adam(Optimizer):
-    """Reference: http://arxiv.org/abs/1412.6980v8
-    Default parameters follow those provided in the original paper.
-    """
-
-    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
-                 epsilon=1e-8, *args, **kwargs):
-        super(Adam, self).__init__(**kwargs)
-        self.__dict__.update(locals())
-        self.iterations = 0
-        self.lr = lr
-
-    def get_updates(self, params, constraints, grads):
-        new_weights = []
-
-        t = self.iterations + 1
-        lr_t = self.lr * np.sqrt(1-self.beta_2**t)/(1-self.beta_1**t)
-
-        for p, g, c in zip(params, grads, constraints):
-            m = np.zeros_like(p)  # zero init of moment
-            v = np.zeros_like(p)  # zero init of velocity
-
-            m_t = (self.beta_1 * m) + (1 - self.beta_1) * g
-            v_t = (self.beta_2 * v) + (1 - self.beta_2) * (g**2)
-            p_t = p - lr_t * m_t / (np.sqrt(v_t) + self.epsilon)
-            new_weights.append(c(p_t))
-
-        return new_weights
-
-    def get_config(self):
-        return {"class_name": self.__class__.__name__,
-                "lr": float(self.lr),
-                "beta_1": self.beta_1,
-                "beta_2": self.beta_2,
-                "epsilon": self.epsilon}
-
-
-# aliases
-sgd = SGD
-rmsprop = RMSprop
-adagrad = Adagrad
-adadelta = Adadelta
-adam = Adam
-
-
-def serialize(optimizer):
-    return serialize_keras_object(optimizer)
-
-
-def deserialize(config, custom_objects=None):
-    """Inverse of the `serialize` function.
-    # Arguments
-        config: Optimizer configuration dictionary.
-        custom_objects: Optional dictionary mapping
-            names (strings) to custom objects
-            (classes and functions)
-            to be considered during deserialization.
-    # Returns
-        A Keras Optimizer instance.
-    """
-    all_classes = {
-        'sgd': SGD,
-        'rmsprop': RMSprop,
-        'adagrad': Adagrad,
-        'adadelta': Adadelta,
-        'adam': Adam
-    }
-    # Make deserialization case-insensitive for built-in optimizers.
-    if config['class_name'].lower() in all_classes:
-        config['class_name'] = config['class_name'].lower()
-    return deserialize_keras_object(config,
-                                    module_objects=all_classes,
-                                    custom_objects=custom_objects,
-                                    printable_module_name='optimizer')
-
-
-def get(identifier):
-    """Retrieves a Keras Optimizer instance.
-    # Arguments
-        identifier: Optimizer identifier, one of
-            - String: name of an optimizer
-            - Dictionary: configuration dictionary.
-            - Keras Optimizer instance (it will be returned unchanged).
-            - TensorFlow Optimizer instance
-                (it will be wrapped as a Keras Optimizer).
-    # Returns
-        A Keras Optimizer instance.
-    # Raises
-        ValueError: If `identifier` cannot be interpreted.
-    """
-    if K.backend() == 'tensorflow':
-        # Wrap TF optimizer instances
-        if isinstance(identifier, tf.train.Optimizer):
-            return TFOptimizer(identifier)
-    if isinstance(identifier, dict):
-        return deserialize(identifier)
-    elif isinstance(identifier, six.string_types):
-        config = {'class_name': str(identifier), 'config': {}}
-        return deserialize(config)
-    if isinstance(identifier, Optimizer):
-        return identifier
-    else:
-        raise ValueError('Could not interpret optimizer identifier:',
-                         identifier)
diff --git a/elephas/parameter/server.py b/elephas/parameter/server.py
index cb8e924..218eb12 100644
--- a/elephas/parameter/server.py
+++ b/elephas/parameter/server.py
@@ -5,11 +5,12 @@
 from flask import Flask, request
 from multiprocessing import Process
 
-from ..utils.sockets import determine_master
-from ..utils.sockets import receive, send
-from ..utils.serialization import dict_to_model
-from ..utils.rwlock import RWLock as Lock
-from ..utils.notebook_utils import is_running_in_notebook
+from elephas.utils.sockets import determine_master
+from elephas.utils.sockets import receive, send
+from elephas.utils.serialization import dict_to_model
+from elephas.utils.rwlock import RWLock as Lock
+from elephas.utils.notebook_utils import is_running_in_notebook
+from elephas.utils import subtract_params
 
 
 class BaseParameterServer(object):
@@ -44,15 +45,14 @@ class HttpServer(BaseParameterServer):
     POST updates.
     """
 
-    def __init__(self, model, optimizer, mode, port=4000, debug=True,
+    def __init__(self, model, mode, port=4000, debug=True,
                  threaded=True, use_reloader=True):
-        """Initializes and HTTP server from a serialized Keras model, elephas optimizer,
+        """Initializes and HTTP server from a serialized Keras model
         a parallelisation mode and a port to run the Flask application on.
         In hogwild mode no read- or write-locks will be acquired, in asynchronous mode this is the case.
 
         :param model: Serialized Keras model
-        :param optimizer: Elephas optimizer
         :param mode: parallelization mode, either `asynchronous` or `hogwild`
         :param port: int, port to run the application on
         :param debug: boolean, Flask debug mode
@@ -63,7 +63,6 @@ def __init__(self, model, optimizer, mode, port=4000, debug=True,
         self.master_network = dict_to_model(model)
         self.mode = mode
         self.master_url = None
-        self.optimizer = optimizer
         self.port = port
 
 
@@ -125,10 +124,9 @@ def handle_update_parameters():
             if not self.master_network.built:
                 self.master_network.build()
 
-            def base_constraint(a): return a
-            constraints = [base_constraint for _ in self.weights]
-            self.weights = self.optimizer.get_updates(
-                self.weights, constraints, delta)
+            # Just apply the gradient
+            self.weights = subtract_params(self.weights, delta)
+
             if self.mode == 'asynchronous':
                 self.lock.release()
             return 'Update done'
diff --git a/elephas/spark_model.py b/elephas/spark_model.py
index f190411..d77eef2 100644
--- a/elephas/spark_model.py
+++ b/elephas/spark_model.py
@@ -7,10 +7,10 @@
 from keras.optimizers import serialize as serialize_optimizer
 from keras.models import load_model
 
+from .utils import subtract_params
 from .utils import lp_to_simple_rdd
 from .utils import model_to_dict
 from .mllib import to_matrix, from_matrix, to_vector, from_vector
-from .optimizers import SGD
 from .worker import AsynchronousSparkWorker, SparkWorker
 from .parameter import HttpServer, SocketServer
 from .parameter import HttpClient, SocketClient
@@ -19,7 +19,7 @@
 class SparkModel(object):
 
     def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_server_mode='http', num_workers=None,
-                 elephas_optimizer=None, custom_objects=None, batch_size=32, port=4000, *args, **kwargs):
+                 custom_objects=None, batch_size=32, port=4000, *args, **kwargs):
         """SparkModel
 
         Base class for distributed training on RDDs. Spark model takes a Keras
@@ -31,7 +30,6 @@ def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_ser
         :param frequency: String, either `epoch` or `batch`
         :param parameter_server_mode: String, either `http` or `socket`
         :param num_workers: int, number of workers used for training (defaults to None)
-        :param elephas_optimizer: Elephas optimizer
         :param custom_objects: Keras custom objects
         :param batch_size: batch size used for training and inference
         :param port: port used in case of 'http' parameter server mode
@@ -49,10 +48,6 @@ def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_ser
             custom_objects = {}
         if metrics is None:
             metrics = ["accuracy"]
-        if elephas_optimizer is None:
-            self.optimizer = SGD()
-        else:
-            self.optimizer = elephas_optimizer
         self.mode = mode
         self.frequency = frequency
         self.num_workers = num_workers
@@ -71,7 +66,7 @@ def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_ser
         if self.mode is not 'synchronous':
             if self.parameter_server_mode == 'http':
                 self.parameter_server = HttpServer(
-                    self.serialized_model, self.optimizer, self.mode, self.port)
+                    self.serialized_model, self.mode, self.port)
                 self.client = HttpClient(self.port)
             elif self.parameter_server_mode == 'socket':
                 self.parameter_server = SocketServer(self.serialized_model)
@@ -90,7 +85,6 @@ def get_train_config(epochs, batch_size, verbose, validation_split):
     def get_config(self):
         base_config = {
             'parameter_server_mode': self.parameter_server_mode,
-            'elephas_optimizer': self.optimizer.get_config(),
             'mode': self.mode,
             'frequency': self.frequency,
             'num_workers': self.num_workers,
@@ -187,13 +181,10 @@ def _fit(self, rdd, epochs, batch_size, verbose, validation_split):
         elif self.mode == 'synchronous':
             worker = SparkWorker(yaml, parameters, train_config,
                                  optimizer, loss, metrics, custom)
-            deltas = rdd.mapPartitions(worker.train).collect()
+            gradients = rdd.mapPartitions(worker.train).collect()
             new_parameters = self._master_network.get_weights()
-            for delta, weight in deltas:
-                def base_constraint(a): return a
-                constraints = [base_constraint for _ in weight]
-                new_parameters = self.optimizer.get_updates(
-                    weight, constraints, delta)
+            for grad in gradients:  # simply accumulate gradients one by one
+                new_parameters = subtract_params(new_parameters, grad)
         else:
             raise ValueError("Unsupported mode {}".format(self.mode))
         self._master_network.set_weights(new_parameters)
@@ -227,14 +218,13 @@ def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_serv
         :param frequency: String, either `epoch` or `batch`
         :param parameter_server_mode: String, either `http` or `socket`
         :param num_workers: int, number of workers used for training (defaults to None)
-        :param elephas_optimizer: Elephas optimizer
         :param custom_objects: Keras custom objects
         :param batch_size: batch size used for training and inference
         :param port: port used in case of 'http' parameter server mode
         """
         SparkModel.__init__(self, model=model, mode=mode, frequency=frequency,
                             parameter_server_mode=parameter_server_mode, num_workers=num_workers,
-                            elephas_optimizer=elephas_optimizer, custom_objects=custom_objects,
+                            custom_objects=custom_objects,
                             batch_size=batch_size, port=port, *args, **kwargs)
 
     def fit(self, labeled_points, epochs=10, batch_size=32, verbose=0, validation_split=0.1,
diff --git a/elephas/worker.py b/elephas/worker.py
index 236352f..d1be6ef 100644
--- a/elephas/worker.py
+++ b/elephas/worker.py
@@ -46,7 +46,7 @@ def train(self, data_iterator):
         weights_after_training = self.model.get_weights()
         deltas = subtract_params(
             weights_before_training, weights_after_training)
-        yield deltas, weights_after_training
+        yield deltas
 
 
 class AsynchronousSparkWorker(object):
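
The update rule that replaces the deleted Optimizer classes is plain delta subtraction: the HTTP parameter server now calls subtract_params(self.weights, delta), and the synchronous branch folds each collected worker delta into the master weights the same way. Below is a minimal sketch of that behaviour; it assumes subtract_params performs element-wise subtraction over the lists of NumPy weight arrays, which matches how the patch uses it but is not itself part of this diff.

import numpy as np

def subtract_params(weights, delta):
    # new_w = w - d, applied array by array across the weight list
    return [w - d for w, d in zip(weights, delta)]

# Synchronous mode after this patch: each worker yields its delta
# (weights_before_training - weights_after_training) and the master
# accumulates the deltas one by one.
master_weights = [np.ones((3, 3)), np.zeros(3)]
worker_deltas = [[np.full((3, 3), 0.1), np.full(3, 0.05)],
                 [np.full((3, 3), 0.2), np.full(3, 0.10)]]
for delta in worker_deltas:
    master_weights = subtract_params(master_weights, delta)

Because a worker's delta is defined as weights before training minus weights after training, subtracting it applies the worker's learned change to the master; all optimizer state now lives in the Keras optimizer on the workers, and the parameter server itself stays stateless.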
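With the Elephas-side optimizer gone, the only optimizer left in play is the Keras optimizer passed to compile(); SparkModel and ElephasEstimator no longer take elephas_optimizer or elephas_optimizer_config. A sketch of calling code under the new signatures follows; the network, the data and the to_simple_rdd helper are illustrative assumptions, not taken from this patch.

from keras.models import Sequential
from keras.layers import Dense

from elephas.spark_model import SparkModel
from elephas.utils.rdd_utils import to_simple_rdd  # assumed helper for building the training RDD

model = Sequential()
model.add(Dense(128, input_dim=784, activation='relu'))
model.add(Dense(10, activation='softmax'))
# The Keras optimizer configured here is the only optimizer used after this patch.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

spark_model = SparkModel(model, mode='asynchronous', frequency='epoch',
                         parameter_server_mode='http', num_workers=2)

# rdd = to_simple_rdd(sc, x_train, y_train)  # sc: an existing SparkContext
# spark_model.fit(rdd, epochs=5, batch_size=32, verbose=0, validation_split=0.1)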