Skip to content

Commit

Permalink
fix Cast bug(MPolaris#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
MPolaris committed Aug 5, 2022
1 parent 46a0aaf commit cf0f4b9
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions layers/common_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args
5: np.int16,
6: np.int32,
7: np.int64,
9: np.bool,
9: np.bool_,
10: np.float16,
11: np.double,
}
Expand All @@ -276,16 +276,15 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args
}

def __call__(self, inputs):
np_cast_op = self.np_cast_map[self.cast_to]
if isinstance(inputs, list):
for i in range(len(inputs)):
if isinstance(inputs[i], np.ndarray) or isinstance(inputs[i], np.generic):
inputs[i] = np_cast_op(inputs[i])
inputs[i] = self.np_cast_map[self.cast_to](inputs[i])
else:
inputs[i] = tf.cast(input[i], dtype=self.tf_cast_map[self.cast_to])
else:
if isinstance(inputs, np.ndarray) or isinstance(inputs, np.generic):
inputs = np_cast_op(inputs)
inputs = self.np_cast_map[self.cast_to](inputs)
else:
inputs = tf.cast(inputs, dtype=self.tf_cast_map[self.cast_to])

Expand Down

0 comments on commit cf0f4b9

Please sign in to comment.