From 981c07a2172d2e057d81fb79cea5d7630b93fefd Mon Sep 17 00:00:00 2001 From: MPolaris <540492239@qq.com> Date: Fri, 23 Sep 2022 09:15:17 +0800 Subject: [PATCH] Expand update --- layers/deformation_layers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/layers/deformation_layers.py b/layers/deformation_layers.py index 6119856..9c6cdfa 100644 --- a/layers/deformation_layers.py +++ b/layers/deformation_layers.py @@ -144,6 +144,8 @@ def __call__(self, inputs): for i in range(len(self.shape)): if int(self.shape[i]//inputs.shape[i]) > 1: inputs = tf.repeat(inputs, repeats=int(self.shape[i]//inputs.shape[i]), axis=i) + elif self.shape[i] < inputs.shape[i] and self.shape[i] != 1: + inputs = tf.repeat(inputs, repeats=int(self.shape[i]), axis=i) return inputs @OPERATOR.register_operator("Unsqueeze")