refactor(ONNX): replaces getValueList helper with `createScalarSublist` (#3987)

A preliminary refactor to support #3945 
- extracts several new helper functions
- removes cruft
bjacobgordon authored Feb 7, 2025
1 parent 2063ec7 commit a4f5beb
Showing 2 changed files with 76 additions and 51 deletions.
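
For orientation before the diff: `createScalarSublist` walks a 1-D tensor operand forward from a given start index, pulls each element out as a Torch scalar via `extractTorchScalar` (one `torch.aten.select.int` plus `torch.aten.item` per element), and packs the results into a `torch.prim.ListConstruct`. A minimal call-site sketch follows; the `lowerResizeScalesExample` wrapper and its include list are illustrative assumptions (the real helpers are file-local to DefaultDomainQtoZ.cpp), and only the `createScalarSublist` signature and the start index of 2 come from this commit.

// Sketch only: assumes it is compiled in the same translation unit as the new
// (file-local) helpers, alongside the includes that file already uses.
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

static LogicalResult lowerResizeScalesExample(OpBinder binder,
                                              ConversionPatternRewriter &rewriter,
                                              Value scales) {
  // ONNX Resize lays scales/sizes out as [N, C, spatial...]; the batch and
  // channel entries are never resized, so the sublist starts at index 2.
  int64_t assumedForemostSpatialDim = 2;

  // Builds a !torch.list<float> (or !torch.list<int> for integer dtypes) from
  // one select/item pair per remaining element of the 1-D scales tensor.
  Value scalesValueList = createScalarSublist(
      binder.getLoc(), scales, assumedForemostSpatialDim, rewriter);
  (void)scalesValueList; // consumed by the interpolate op the lowering emits
  return success();
}
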
121 changes: 72 additions & 49 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -180,53 +180,67 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter,
return success();
}

-Value getValueList(OpBinder binder, ConversionPatternRewriter &rewriter,
-                   Value operand) {
-  SmallVector<Value> itemList;
-  auto sizes = dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
-  Torch::BaseTensorType operandType =
-      cast<Torch::BaseTensorType>(operand.getType());
-
-  SmallVector<int64_t> selectSizes;
-  selectSizes.push_back(1);
-  Type selectResultType = operandType.getWithSizesAndDtype(
-      llvm::ArrayRef(selectSizes), operandType.getOptionalDtype());
-
-  auto extract = [&rewriter, &binder](Value x, Value v) {
-    auto xTy = cast<Torch::ValueTensorType>(x.getType());
-    Type extractTy = rewriter.getType<Torch::FloatType>();
-    if (isa<IntegerType>(xTy.getDtype()))
-      extractTy = rewriter.getType<Torch::IntType>();
-
-    return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy, v);
-  };
-
-  Value zero = rewriter.create<Torch::ConstantIntOp>(
-      binder.getLoc(), rewriter.getType<Torch::IntType>(),
-      rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-
-  MLIRContext *context = binder.op->getContext();
-  for (int i = 2; i < sizes[0]; i++) {
-    Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
-        binder.getLoc(), rewriter.getType<Torch::IntType>(),
-        rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
-    Value ext = rewriter.create<Torch::AtenSelectIntOp>(
-        binder.getLoc(), selectResultType, operand, zero, selectIndex);
-    Value item = extract(operand, ext);
-    itemList.push_back(item);
-  }
-  auto xTy = cast<Torch::ValueTensorType>(operand.getType());
-  Value ValueList;
-  if (isa<IntegerType>(xTy.getDtype())) {
-    ValueList = rewriter.create<Torch::PrimListConstructOp>(
-        binder.getLoc(), Torch::ListType::get(Torch::IntType::get(context)),
-        itemList);
-  } else {
-    ValueList = rewriter.create<Torch::PrimListConstructOp>(
-        binder.getLoc(), Torch::ListType::get(Torch::FloatType::get(context)),
-        itemList);
-  }
-  return ValueList;
-}
+Type getTorchScalarType(
+    /* forElementIn */ Torch::BaseTensorType givenTensorType,
+    /* using */ ConversionPatternRewriter &rewriter) {
+  auto elementTypeForGivenTensor = givenTensorType.getDtype();
+
+  if (isa<IntegerType>(elementTypeForGivenTensor))
+    return rewriter.getType<Torch::IntType>();
+  if (isa<FloatType>(elementTypeForGivenTensor))
+    return rewriter.getType<Torch::FloatType>();
+
+  assert(false && "dtype for given tensor expected to be either int or float");
+}
+
+Value extractTorchScalar(
+    /* at */ Location givenLoc,
+    /* from */ int64_t givenIndex,
+    /* in */ Value given1DTensor,
+    /* using */ ConversionPatternRewriter &rewriter) {
+  auto some1DTensorType = cast<Torch::BaseTensorType>(given1DTensor.getType());
+
+  Type selectionTypeForSome1DTensor = some1DTensorType.getWithSizesAndDtype(
+      ArrayRef<int64_t>{1}, some1DTensorType.getOptionalDtype());
+
+  Value frontDim = rewriter.create<Torch::ConstantIntOp>(givenLoc, 0);
+
+  Value selectionIndex =
+      rewriter.create<Torch::ConstantIntOp>(givenLoc, givenIndex);
+
+  auto someTorchScalarType = getTorchScalarType(some1DTensorType, rewriter);
+
+  Value selectionFromGiven1DTensor = rewriter.create<Torch::AtenSelectIntOp>(
+      givenLoc, selectionTypeForSome1DTensor, given1DTensor, frontDim,
+      selectionIndex);
+
+  return rewriter.create<Torch::AtenItemOp>(givenLoc, someTorchScalarType,
+                                            selectionFromGiven1DTensor);
+}
+
+Value createScalarSublist(
+    /* at */ Location givenLoc,
+    /* movingForwardsThrough */ Value given1DTensor,
+    /* startingAt */ int64_t givenIndex,
+    /* using */ ConversionPatternRewriter &rewriter) {
+  auto some1DTensorType = cast<Torch::BaseTensorType>(given1DTensor.getType());
+  auto sizesOfSome1DTensor = some1DTensorType.getSizes();
+  auto lengthOfFullList = sizesOfSome1DTensor[0];
+
+  SmallVector<Value> runningScalarSublist;
+
+  for (int indexOfEachScalar = givenIndex; indexOfEachScalar < lengthOfFullList;
+       indexOfEachScalar++) {
+    Value eachScalar = extractTorchScalar(givenLoc, indexOfEachScalar,
+                                          given1DTensor, rewriter);
+    runningScalarSublist.push_back(eachScalar);
+  }
+
+  auto someTorchScalarType = runningScalarSublist.front().getType();
+  Type someTorchScalarListType = Torch::ListType::get(someTorchScalarType);
+
+  return rewriter.create<Torch::PrimListConstructOp>(
+      givenLoc, someTorchScalarListType, runningScalarSublist);
+}
} // namespace

@@ -2809,14 +2823,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
}

+int64_t assumedForemostSpatialDim = 2;

if (operands.size() < 4) {
Value scaleOperand = operands[2];
-scalesValueList = getValueList(binder, rewriter, scaleOperand);
+scalesValueList =
+    createScalarSublist(binder.getLoc(), scaleOperand,
+                        assumedForemostSpatialDim, rewriter);
sizesValueList = noneVal;
} else {
Value sizeOperand = operands[3];
scalesValueList = noneVal;
-sizesValueList = getValueList(binder, rewriter, sizeOperand);
+sizesValueList =
+    createScalarSublist(binder.getLoc(), sizeOperand,
+                        assumedForemostSpatialDim, rewriter);
}
if (isa<Torch::NoneType>(scalesValueList.getType()) &&
isa<Torch::NoneType>(sizesValueList.getType())) {
@@ -3339,7 +3360,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return rewriter.notifyMatchFailure(
binder.op, "supports upto 3d upsampling only");

-Value scalesValueList = getValueList(binder, rewriter, scales);
+int64_t assumedForemostSpatialDim = 2;
+Value scalesValueList = createScalarSublist(
+    binder.getLoc(), scales, assumedForemostSpatialDim, rewriter);
if (mode == "linear") {
if (resultRank == 4)
mode = "bilinear";
6 changes: 4 additions & 2 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -2803,8 +2803,9 @@ func.func @test_upsample_nearest(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !t
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float
+// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
+// CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list<float>
// CHECK: %[[MODE:.*]] = torch.constant.str "nearest"
@@ -2824,8 +2825,9 @@ func.func @test_upsample_bilinear(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float
+// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
+// CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list<float>
// CHECK: %[[MODE:.*]] = torch.constant.str "bilinear"
