Skip to content

Commit

Permalink
[RTG][Elaboration] Add support for 'index.add' and 'index.cmp'
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Jan 14, 2025
1 parent bbef1fa commit 83c12d2
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 116 deletions.
2 changes: 1 addition & 1 deletion include/circt/Dialect/RTG/Transforms/RTGPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def ElaborationPass : Pass<"rtg-elaborate", "mlir::ModuleOp"> {
"The seed for any RNG constructs used in the pass.">,
];

let dependentDialects = ["mlir::arith::ArithDialect"];
let dependentDialects = ["mlir::index::IndexDialect"];
}

#endif // CIRCT_DIALECT_RTG_TRANSFORMS_RTGPASSES_TD
2 changes: 1 addition & 1 deletion lib/Dialect/RTG/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ add_circt_dialect_library(CIRCTRTGTransforms

LINK_LIBS PRIVATE
CIRCTRTGDialect
MLIRArithDialect
MLIRIndexDialect
MLIRIR
MLIRPass
)
Expand Down
161 changes: 145 additions & 16 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
#include "circt/Dialect/RTG/IR/RTGVisitors.h"
#include "circt/Dialect/RTG/Transforms/RTGPasses.h"
#include "circt/Support/Namespace.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -85,7 +86,7 @@ namespace {
/// The abstract base class for elaborated values.
struct ElaboratorValue {
public:
enum class ValueKind { Attribute, Set, Bag, Sequence };
enum class ValueKind { Attribute, Set, Bag, Sequence, Index, Bool };

ElaboratorValue(ValueKind kind) : kind(kind) {}
virtual ~ElaboratorValue() {}
Expand Down Expand Up @@ -144,6 +145,74 @@ class AttributeValue : public ElaboratorValue {
const TypedAttr attr;
};

/// Holds an evaluated value of a `IndexType`'d value.
class IndexValue : public ElaboratorValue {
public:
IndexValue(size_t index) : ElaboratorValue(ValueKind::Index), index(index) {}

// Implement LLVMs RTTI
static bool classof(const ElaboratorValue *val) {
return val->getKind() == ValueKind::Index;
}

llvm::hash_code getHashValue() const override {
return llvm::hash_value(index);
}

bool isEqual(const ElaboratorValue &other) const override {
auto *indexValue = dyn_cast<IndexValue>(&other);
if (!indexValue)
return false;

return index == indexValue->index;
}

#ifndef NDEBUG
void print(llvm::raw_ostream &os) const override {
os << "<index " << index << " at " << this << ">";
}
#endif

size_t getIndex() const { return index; }

private:
const size_t index;
};

/// Holds an evaluated value of an `i1` type'd value.
class BoolValue : public ElaboratorValue {
public:
BoolValue(bool value) : ElaboratorValue(ValueKind::Bool), value(value) {}

// Implement LLVMs RTTI
static bool classof(const ElaboratorValue *val) {
return val->getKind() == ValueKind::Bool;
}

llvm::hash_code getHashValue() const override {
return llvm::hash_value(value);
}

bool isEqual(const ElaboratorValue &other) const override {
auto *val = dyn_cast<BoolValue>(&other);
if (!val)
return false;

return value == val->value;
}

#ifndef NDEBUG
void print(llvm::raw_ostream &os) const override {
os << "<bool " << (value ? "true" : "false") << " at " << this << ">";
}
#endif

bool getBool() const { return value; }

private:
const bool value;
};

/// Holds an evaluated value of a `SetType`'d value.
class SetValue : public ElaboratorValue {
public:
Expand Down Expand Up @@ -366,7 +435,8 @@ class Materializer {

OpBuilder builder(block, insertionPoint);
return TypeSwitch<ElaboratorValue *, Value>(val)
.Case<AttributeValue, SetValue, BagValue, SequenceValue>([&](auto val) {
.Case<AttributeValue, IndexValue, BoolValue, SetValue, BagValue,
SequenceValue>([&](auto val) {
return visit(val, builder, loc, elabRequests, emitError);
})
.Default([](auto val) {
Expand All @@ -389,13 +459,11 @@ class Materializer {
function_ref<InFlightDiagnostic()> emitError) {
auto attr = val->getAttr();

// For integer attributes (and arithmetic operations on them) we use the
// arith dialect.
if (isa<IntegerAttr>(attr)) {
Value res = builder.getContext()
->getLoadedDialect<arith::ArithDialect>()
->materializeConstant(builder, attr, attr.getType(), loc)
->getResult(0);
// For index attributes (and arithmetic operations on them) we use the
// index dialect.
if (auto intAttr = dyn_cast<IntegerAttr>(attr);
intAttr && isa<IndexType>(attr.getType())) {
Value res = builder.create<index::ConstantOp>(loc, intAttr);
materializedValues[val] = res;
return res;
}
Expand All @@ -417,6 +485,22 @@ class Materializer {
return res;
}

Value visit(IndexValue *val, OpBuilder &builder, Location loc,
std::queue<SequenceValue *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
Value res = builder.create<index::ConstantOp>(loc, val->getIndex());
materializedValues[val] = res;
return res;
}

Value visit(BoolValue *val, OpBuilder &builder, Location loc,
std::queue<SequenceValue *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
Value res = builder.create<index::BoolConstantOp>(loc, val->getBool());
materializedValues[val] = res;
return res;
}

Value visit(SetValue *val, OpBuilder &builder, Location loc,
std::queue<SequenceValue *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
Expand Down Expand Up @@ -451,7 +535,7 @@ class Materializer {
if (iter != integerValues.end()) {
materializedWeight = iter->second;
} else {
materializedWeight = builder.create<arith::ConstantOp>(
materializedWeight = builder.create<index::ConstantOp>(
loc, builder.getIndexAttr(weight));
integerValues[weight] = materializedWeight;
}
Expand Down Expand Up @@ -606,9 +690,8 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
// If the multiple is not stored as an AttributeValue, the elaboration
// must have already failed earlier (since we don't have
// unevaluated/opaque values).
auto *interpMultiple = cast<AttributeValue>(state.at(multiple));
uint64_t m = cast<IntegerAttr>(interpMultiple->getAttr()).getInt();
bag[interpValue] += m;
auto *interpMultiple = cast<IndexValue>(state.at(multiple));
bag[interpValue] += interpMultiple->getIndex();
}

internalizeResult<BagValue>(op.getBag(), std::move(bag), op.getType());
Expand Down Expand Up @@ -688,6 +771,43 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(index::AddOp op) {
size_t lhs = cast<IndexValue>(state.at(op.getLhs()))->getIndex();
size_t rhs = cast<IndexValue>(state.at(op.getRhs()))->getIndex();
internalizeResult<IndexValue>(op.getResult(), lhs + rhs);
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(index::CmpOp op) {
size_t lhs = cast<IndexValue>(state.at(op.getLhs()))->getIndex();
size_t rhs = cast<IndexValue>(state.at(op.getRhs()))->getIndex();
bool result;
switch (op.getPred()) {
case index::IndexCmpPredicate::EQ:
result = lhs == rhs;
break;
case index::IndexCmpPredicate::NE:
result = lhs != rhs;
break;
case index::IndexCmpPredicate::ULT:
result = lhs < rhs;
break;
case index::IndexCmpPredicate::ULE:
result = lhs <= rhs;
break;
case index::IndexCmpPredicate::UGT:
result = lhs > rhs;
break;
case index::IndexCmpPredicate::UGE:
result = lhs >= rhs;
break;
default:
return op->emitOpError("elaboration not supported");
}
internalizeResult<BoolValue>(op.getResult(), result);
return DeletionKind::Delete;
}

FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
if (op->hasTrait<OpTrait::ConstantLike>()) {
SmallVector<OpFoldResult, 1> result;
Expand All @@ -700,11 +820,20 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
return op->emitError(
"only typed attributes supported for constant-like operations");

internalizeResult<AttributeValue>(op->getResult(0), attr);
auto intAttr = dyn_cast<IntegerAttr>(attr);
if (intAttr && isa<IndexType>(attr.getType()))
internalizeResult<IndexValue>(op->getResult(0), intAttr.getInt());
else if (intAttr && intAttr.getType().isSignlessInteger(1))
internalizeResult<BoolValue>(op->getResult(0), intAttr.getInt());
else
internalizeResult<AttributeValue>(op->getResult(0), attr);

return DeletionKind::Delete;
}

return RTGBase::dispatchOpVisitor(op);
return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
.Case<index::AddOp, index::CmpOp>([&](auto op) { return visitOp(op); })
.Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); });
}

LogicalResult elaborate(SequenceOp family, SequenceOp dest,
Expand Down
Loading

0 comments on commit 83c12d2

Please sign in to comment.