From cf0f4b97fad5d8a68b187ed46efffdd8cf4f4b71 Mon Sep 17 00:00:00 2001 From: MPolaris <540492239@qq.com> Date: Fri, 5 Aug 2022 08:58:29 +0800 Subject: [PATCH] fix Cast bug(#7) --- layers/common_layers.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/layers/common_layers.py b/layers/common_layers.py index d9059b8..b372d74 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -259,7 +259,7 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args 5: np.int16, 6: np.int32, 7: np.int64, - 9: np.bool, + 9: np.bool_, 10: np.float16, 11: np.double, } @@ -276,16 +276,15 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args } def __call__(self, inputs): - np_cast_op = self.np_cast_map[self.cast_to] if isinstance(inputs, list): for i in range(len(inputs)): if isinstance(inputs[i], np.ndarray) or isinstance(inputs[i], np.generic): - inputs[i] = np_cast_op(inputs[i]) + inputs[i] = self.np_cast_map[self.cast_to](inputs[i]) else: inputs[i] = tf.cast(input[i], dtype=self.tf_cast_map[self.cast_to]) else: if isinstance(inputs, np.ndarray) or isinstance(inputs, np.generic): - inputs = np_cast_op(inputs) + inputs = self.np_cast_map[self.cast_to](inputs) else: inputs = tf.cast(inputs, dtype=self.tf_cast_map[self.cast_to])