diff --git a/elephas/ml_model.py b/elephas/ml_model.py index 5e4f951..90b5438 100644 --- a/elephas/ml_model.py +++ b/elephas/ml_model.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, print_function import numpy as np +import copy import h5py import json @@ -150,7 +151,7 @@ def _transform(self, df): """ output_col = self.getOutputCol() label_col = self.getLabelCol() - new_schema = df.schema + new_schema = copy.deepcopy(df.schema) new_schema.add(StructField(output_col, StringType(), True)) rdd = df.rdd.coalesce(1) @@ -176,4 +177,4 @@ def load_ml_transformer(file_name): f = h5py.File(file_name, mode='r') elephas_conf = json.loads(f.attrs.get('distributed_config')) config = elephas_conf.get('config') - return ElephasTransformer(**config) \ No newline at end of file + return ElephasTransformer(**config)