diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index a26bcbe7c..5e0da20d0 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -632,12 +632,15 @@ def aten_hardtanh(self: TReal, min_val: float = -1.0, max_val: float = 1.0) -> T
     return op.Clip(self, min_val, max_val)
 
 
+@torch_op("aten::hardtanh_backward", trace_only=True)
 def aten_hardtanh_backward(
     grad_output: TensorType, self: TensorType, min_val: float, max_val: float
 ) -> TensorType:
     """hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor"""
 
-    raise NotImplementedError()
+    max_mask = op.Where(op.Greater(self, max_val), 0.0, 1.0)
+    min_mask = op.Where(op.Less(self, min_val), 0.0, 1.0)
+    return op.Mul(op.Mul(grad_output, max_mask), min_mask)
 
 
 def aten_huber_loss(
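
For reference, a minimal NumPy sketch of the same masking logic (not part of the diff; the function and variable names here are illustrative only). It mirrors what the Where/Greater/Less graph built by the patch computes: the incoming gradient is passed through only where min_val <= self <= max_val and zeroed elsewhere.

import numpy as np

def hardtanh_backward_reference(grad_output, x, min_val=-1.0, max_val=1.0):
    # Mirrors op.Where(op.Greater(self, max_val), 0.0, 1.0) and
    # op.Where(op.Less(self, min_val), 0.0, 1.0) from the patch:
    # the gradient survives only inside [min_val, max_val].
    max_mask = np.where(x > max_val, 0.0, 1.0)
    min_mask = np.where(x < min_val, 0.0, 1.0)
    return grad_output * max_mask * min_mask

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
grad = np.ones_like(x)
print(hardtanh_backward_reference(grad, x))  # -> [0. 1. 1. 1. 0.]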