diff --git a/layers/common_layers.py b/layers/common_layers.py index d9059b8..b372d74 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -259,7 +259,7 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args 5: np.int16, 6: np.int32, 7: np.int64, - 9: np.bool, + 9: np.bool_, 10: np.float16, 11: np.double, } @@ -276,16 +276,15 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args } def __call__(self, inputs): - np_cast_op = self.np_cast_map[self.cast_to] if isinstance(inputs, list): for i in range(len(inputs)): if isinstance(inputs[i], np.ndarray) or isinstance(inputs[i], np.generic): - inputs[i] = np_cast_op(inputs[i]) + inputs[i] = self.np_cast_map[self.cast_to](inputs[i]) else: inputs[i] = tf.cast(input[i], dtype=self.tf_cast_map[self.cast_to]) else: if isinstance(inputs, np.ndarray) or isinstance(inputs, np.generic): - inputs = np_cast_op(inputs) + inputs = self.np_cast_map[self.cast_to](inputs) else: inputs = tf.cast(inputs, dtype=self.tf_cast_map[self.cast_to])