diff --git a/layers/common_layers.py b/layers/common_layers.py index 6db4ad7..d9059b8 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -285,8 +285,8 @@ def __call__(self, inputs): inputs[i] = tf.cast(input[i], dtype=self.tf_cast_map[self.cast_to]) else: if isinstance(inputs, np.ndarray) or isinstance(inputs, np.generic): - inputs[i] = np_cast_op(inputs) + inputs = np_cast_op(inputs) else: - inputs[i] = tf.cast(inputs, dtype=self.tf_cast_map[self.cast_to]) + inputs = tf.cast(inputs, dtype=self.tf_cast_map[self.cast_to]) return inputs \ No newline at end of file