Skip to content

Commit

Permalink
refactor(ONNX): renames getValueList helper to createScalarSublist
Browse files Browse the repository at this point in the history
- Before:
  - "get": implies retrieval of some private property
  - "Value": restatement of the return type `Value`
  - "List": assumed result of casting the returned instance
- After:
  - "create": contextualizes the need to pass in `rewriter`
  - "Scalar": contextualizes the opaque return type
  - "Sublist": the relationship between the first parameter and the returned result
  • Loading branch information
bjacobgordon committed Feb 6, 2025
1 parent f6a9459 commit db12b36
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ Value extractTorchScalar(
selectionFromGiven1DTensor);
}

Value getValueList(
Value createScalarSublist(
/* at */ Location givenLoc,
/* movingForwardsThrough */ Value given1DTensor,
/* startingAt */ int64_t givenIndex,
Expand Down Expand Up @@ -2828,14 +2828,16 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(binder.getLoc(), scaleOperand,
assumedForemostSpatialDim, rewriter);
scalesValueList =
createScalarSublist(binder.getLoc(), scaleOperand,
assumedForemostSpatialDim, rewriter);
sizesValueList = noneVal;
} else {
Value sizeOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(binder.getLoc(), sizeOperand,
assumedForemostSpatialDim, rewriter);
sizesValueList =
createScalarSublist(binder.getLoc(), sizeOperand,
assumedForemostSpatialDim, rewriter);
}
if (isa<Torch::NoneType>(scalesValueList.getType()) &&
isa<Torch::NoneType>(sizesValueList.getType())) {
Expand Down Expand Up @@ -3359,7 +3361,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, "supports upto 3d upsampling only");

int64_t assumedForemostSpatialDim = 2;
Value scalesValueList = getValueList(
Value scalesValueList = createScalarSublist(
binder.getLoc(), scales, assumedForemostSpatialDim, rewriter);
if (mode == "linear") {
if (resultRank == 4)
Expand Down

0 comments on commit db12b36

Please sign in to comment.