Skip to content

Commit

Permalink
fix(ONNX): avoids resizing unsupported dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
bjacobgordon committed Jan 28, 2025
1 parent 82626b5 commit bc9d9b8
Showing 1 changed file with 81 additions and 3 deletions.
84 changes: 81 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,11 @@ int64_t const /**/ channelDim = 1;
int64_t const /* */ heightDim = 2;
// int64_t const /* */ widthDim = 3;
// int64_t const /* */ depthDim = 4;

SmallVector<int64_t> nonResizableDims{
batchDim,
channelDim,
};
} // namespace TorchImageTensor

void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Expand Down Expand Up @@ -2728,6 +2733,33 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"round_prefer_floor") ||
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
return failure();

Value inputTensor = operands[0];
auto InputTensor = cast<Torch::BaseTensorType>(inputTensor.getType());
auto sizesOfInputTensor = InputTensor.getSizes();
auto sizesOfOutputTensor = OutputTensor.getSizes();

auto unknownSize = Torch::kUnknownSize;

// Compile-time check for dimensions of static size
for (auto &eachDim : TorchImageTensor::nonResizableDims) {
auto eachSizeOfInputTensor = sizesOfInputTensor[eachDim];
auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDim];

if (eachSizeOfInputTensor == unknownSize ||
eachSizeOfOutputTensor == unknownSize)
continue;
if (eachSizeOfInputTensor == eachSizeOfOutputTensor)
continue;

auto resizingIntentErrorMessage =
"unsupported: non-trivial intent to resize dimension: " +
std::to_string(eachDim);

return rewriter.notifyMatchFailure(binder.op,
resizingIntentErrorMessage);
};

if (antialias != 0) {
return rewriter.notifyMatchFailure(
binder.op,
Expand Down Expand Up @@ -2775,9 +2807,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
}

Value inputTensor = operands[0];
auto InputTensor = cast<Torch::BaseTensorType>(inputTensor.getType());
auto sizesOfInputTensor = InputTensor.getSizes();
auto rankOfInputTensor = sizesOfInputTensor.size();

// supported modes:
Expand Down Expand Up @@ -2819,16 +2848,65 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

auto foremostSupportedDim = TorchImageTensor::heightDim;

Type Bool = rewriter.getType<Torch::BoolType>();

Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);
Value supportedScaleFactors = noneVal;
Value supportedSizes = noneVal;

if (operands.size() < 4) {
Value proposedScaleFactors = operands[2];

Value scaleIdentity = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(1.0));

// run-time scale factor check for dynamic sizes
for (auto &eachDim : TorchImageTensor::nonResizableDims) {
Value eachProposedScaleFactor = createScalarAs<Torch::FloatType>(
loc, eachDim, proposedScaleFactors, rewriter);

Value eachScaleFactorIsIdentity =
rewriter.create<Torch::AtenEqFloatOp>(
loc, Bool, eachProposedScaleFactor, scaleIdentity);

auto errorMessageForEachDim =
"Unsupported: non-trivial scale factor for dimension " +
std::to_string(eachDim);

rewriter.create<Torch::RuntimeAssertOp>(
loc, eachScaleFactorIsIdentity,
rewriter.getStringAttr(errorMessageForEachDim));
};

supportedScaleFactors = createScalarSublist<Torch::FloatType>(
loc, proposedScaleFactors, foremostSupportedDim, rewriter);
} else {
Value proposedSizes = operands[3];

// run-time target size check for dynamic sizes
for (auto &eachDimAsInt : TorchImageTensor::nonResizableDims) {
Value eachDimAsValue =
rewriter.create<Torch::ConstantIntOp>(loc, eachDimAsInt);

Value eachSizeOfInputTensor = rewriter.create<Torch::AtenSizeIntOp>(
loc, inputTensor, eachDimAsValue);

Value eachProposedSize = createScalarAs<Torch::IntType>(
loc, eachDimAsInt, proposedSizes, rewriter);

Value eachProposedSizeIsTrivial =
rewriter.create<Torch::AtenEqIntOp>(loc, Bool, eachProposedSize,
eachSizeOfInputTensor);

auto errorMessageForEachDim =
"Unsupported: non-trivial resizing of dimension " +
std::to_string(eachDimAsInt);

rewriter.create<Torch::RuntimeAssertOp>(
loc, eachProposedSizeIsTrivial,
rewriter.getStringAttr(errorMessageForEachDim));
};

supportedSizes = createScalarSublist<Torch::IntType>(
loc, proposedSizes, foremostSupportedDim, rewriter);
}
Expand Down

0 comments on commit bc9d9b8

Please sign in to comment.