Skip to content

Commit

Permalink
Standardize formatting in Ops.cpp and Interpreter.cpp (#1039)
Browse files Browse the repository at this point in the history
Removes braces around single-statement control-flow bodies (loops and ifs) for
conciseness.
  • Loading branch information
ghpvnist authored Feb 2, 2023
1 parent 00bffab commit aad57d1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 42 deletions.
17 changes: 7 additions & 10 deletions stablehlo/reference/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,24 @@ namespace stablehlo {

llvm::Expected<SmallVector<Tensor>> eval(func::FuncOp func,
ArrayRef<Tensor> args) {
if (func->getNumRegions() != 1) {
if (func->getNumRegions() != 1)
return invalidArgument("Expected one region in func %s",
func.getName().str().c_str());
}
if (!func.getBody().hasOneBlock()) {

if (!func.getBody().hasOneBlock())
return invalidArgument("Expected one block in func %s",
func.getName().str().c_str());
}

Block &block = func.front();
if (block.getNumArguments() != args.size()) {
if (block.getNumArguments() != args.size())
return invalidArgument(
"Expected same amount of func arguments in %s "
"and runtime arguments (%d)",
func.getName().str().c_str(), args.size());
}

llvm::DenseMap<Value, Tensor> stackFrame;
for (auto [ssaArg, runtimeArg] : llvm::zip(block.getArguments(), args)) {
for (auto [ssaArg, runtimeArg] : llvm::zip(block.getArguments(), args))
stackFrame[ssaArg] = runtimeArg;
}

for (Operation &op : block) {
auto fetchOperand = [&](Value value) -> Tensor {
Expand Down Expand Up @@ -130,9 +128,8 @@ llvm::Expected<SmallVector<Tensor>> eval(func::FuncOp func,
populateResults({runtimeResult});
} else if (auto returnOp = dyn_cast<func::ReturnOp>(op)) {
SmallVector<Tensor> runtimeOperands;
for (Value ssaOperand : returnOp.getOperands()) {
for (Value ssaOperand : returnOp.getOperands())
runtimeOperands.push_back(fetchOperand(ssaOperand));
}
return runtimeOperands;
} else if (auto sineOp = dyn_cast<SineOp>(op)) {
Tensor runtimeOperand = fetchOperand(sineOp.getOperand());
Expand Down
48 changes: 16 additions & 32 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,22 @@ SmallVector<int64_t> permute(ArrayRef<int64_t> array, ArrayRef<int64_t> perm) {

// Element-wise addition: result[i] = lhs[i] + rhs[i] for every index of the
// result tensor. lhs, rhs, and resultType are assumed shape-compatible —
// TODO confirm callers verify this.
Tensor evalAddOp(const Tensor &lhs, const Tensor &rhs, Type resultType) {
  Tensor result(resultType);
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) + rhs.get(*it));
  return result;
}

// Element-wise AND: result[i] = lhs[i] & rhs[i]. Iterates over lhs's index
// space (matches the original diff; presumably identical to result's — verify).
Tensor evalAndOp(const Tensor &lhs, const Tensor &rhs, Type resultType) {
  Tensor result(resultType);
  for (auto it = lhs.index_begin(); it != lhs.index_end(); ++it)
    result.set(*it, lhs.get(*it) & rhs.get(*it));
  return result;
}

// Element-wise ceiling: result[i] = ceil(operand[i]), using the
// project-local ceil helper on tensor elements.
Tensor evalCeilOp(const Tensor &operand, Type resultType) {
  Tensor result(resultType);
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, ceil(operand.get(*it)));
  return result;
}

Expand All @@ -70,17 +67,15 @@ Tensor evalConstantOp(ElementsAttr value) {

// Element-wise cosine: result[i] = cosine(operand[i]), using the
// project-local cosine helper on tensor elements.
Tensor evalCosineOp(const Tensor &operand, Type resultType) {
  Tensor result(resultType);
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, cosine(operand.get(*it)));
  return result;
}

// Element-wise floor: result[i] = floor(operand[i]), using the
// project-local floor helper on tensor elements.
Tensor evalFloorOp(const Tensor &operand, Type resultType) {
  Tensor result(resultType);
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, floor(operand.get(*it)));
  return result;
}

Expand Down Expand Up @@ -122,82 +117,72 @@ Tensor evalIotaOp(int64_t iotaDimension, Type resultType) {

// Element-wise maximum: result[i] = max(lhs[i], rhs[i]), using the
// project-local max helper on tensor elements.
Tensor evalMaxOp(const Tensor &lhs, const Tensor &rhs, Type resultType) {
  Tensor result(resultType);
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, max(lhs.get(*it), rhs.get(*it)));
  return result;
}

// Element-wise minimum: result[i] = min(lhs[i], rhs[i]), using the
// project-local min helper on tensor elements.
Tensor evalMinOp(const Tensor &lhs, const Tensor &rhs, Type resultType) {
  Tensor result(resultType);
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, min(lhs.get(*it), rhs.get(*it)));
  return result;
}

// Element-wise multiplication: result[i] = lhs[i] * rhs[i] for every index
// of the result tensor.
Tensor evalMultiplyOp(const Tensor &lhs, const Tensor &rhs, Type resultType) {
  Tensor result(resultType);
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) * rhs.get(*it));
  return result;
}

// Element-wise negation: result[i] = -operand[i] for every index of the
// result tensor.
Tensor evalNegOp(const Tensor &operand, Type resultType) {
  Tensor result(resultType);
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, -operand.get(*it));
  return result;
}

// Element-wise bitwise NOT: result[i] = ~operand[i]. Iterates over operand's
// index space (matches the original diff; presumably identical to result's —
// verify).
Tensor evalNotOp(const Tensor &operand, Type resultType) {
  Tensor result(resultType);
  for (auto it = operand.index_begin(); it != operand.index_end(); ++it)
    result.set(*it, ~operand.get(*it));
  return result;
}

// Element-wise OR: result[i] = lhs[i] | rhs[i]. Iterates over lhs's index
// space (matches the original diff; presumably identical to result's — verify).
Tensor evalOrOp(const Tensor &lhs, const Tensor &rhs, Type resultType) {
  Tensor result(resultType);
  for (auto it = lhs.index_begin(); it != lhs.index_end(); ++it)
    result.set(*it, lhs.get(*it) | rhs.get(*it));
  return result;
}

// Reshape: copies elements from operand to result in iteration order. The
// two iterators advance in lockstep, so element count of operand and result
// must match — TODO confirm callers guarantee this.
Tensor evalReshapeOp(const Tensor &operand, Type resultType) {
  Tensor result(resultType);
  for (auto resultIt = result.index_begin(), operandIt = operand.index_begin();
       resultIt != result.index_end(); ++resultIt, ++operandIt)
    result.set(*resultIt, operand.get(*operandIt));
  return result;
}

// Element-wise sine: result[i] = sine(operand[i]), using the project-local
// sine helper on tensor elements.
Tensor evalSineOp(const Tensor &operand, Type resultType) {
  Tensor result(resultType);
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, sine(operand.get(*it)));
  return result;
}

// Element-wise subtraction: result[i] = lhs[i] - rhs[i] for every index of
// the result tensor.
Tensor evalSubtractOp(const Tensor &lhs, const Tensor &rhs, Type resultType) {
  Tensor result(resultType);
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) - rhs.get(*it));
  return result;
}

// Element-wise hyperbolic tangent: result[i] = tanh(operand[i]), using the
// project-local tanh helper on tensor elements.
Tensor evalTanhOp(const Tensor &operand, Type resultType) {
  Tensor result(resultType);
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, tanh(operand.get(*it)));
  return result;
}

Expand All @@ -214,9 +199,8 @@ Tensor evalTransposeOp(const Tensor &operand, ArrayRef<int64_t> permutation,

// Element-wise XOR: result[i] = lhs[i] ^ rhs[i]. Iterates over lhs's index
// space (matches the original diff; presumably identical to result's — verify).
Tensor evalXorOp(const Tensor &lhs, const Tensor &rhs, Type resultType) {
  Tensor result(resultType);
  for (auto it = lhs.index_begin(); it != lhs.index_end(); ++it)
    result.set(*it, lhs.get(*it) ^ rhs.get(*it));
  return result;
}

Expand Down

0 comments on commit aad57d1

Please sign in to comment.