diff --git a/layers/deformation_layers.py b/layers/deformation_layers.py index 6119856..9c6cdfa 100644 --- a/layers/deformation_layers.py +++ b/layers/deformation_layers.py @@ -144,6 +144,8 @@ def __call__(self, inputs): for i in range(len(self.shape)): if int(self.shape[i]//inputs.shape[i]) > 1: inputs = tf.repeat(inputs, repeats=int(self.shape[i]//inputs.shape[i]), axis=i) + elif self.shape[i] < inputs.shape[i] and self.shape[i] != 1: + inputs = tf.repeat(inputs, repeats=int(self.shape[i]), axis=i) return inputs @OPERATOR.register_operator("Unsqueeze")