diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index ad212368c8d5..d7ce4ce05f71 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/DenseMapInfoVariant.h" #include "llvm/Support/Debug.h" #include #include @@ -79,343 +80,278 @@ static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) { } //===----------------------------------------------------------------------===// -// Elaborator Values +// Elaborator Value //===----------------------------------------------------------------------===// namespace { +struct BagStorage; +struct SequenceStorage; +struct SetStorage; /// The abstract base class for elaborated values. -struct ElaboratorValue { -public: - enum class ValueKind { Attribute, Set, Bag, Sequence, Index, Bool }; +using ElaboratorValue = std::variant; - ElaboratorValue(ValueKind kind) : kind(kind) {} - virtual ~ElaboratorValue() {} +// NOLINTNEXTLINE(readability-identifier-naming) +llvm::hash_code hash_value(const ElaboratorValue &val) { + return std::visit( + [&val](const auto &alternative) { + // Include index in hash to make sure same value as different + // alternatives don't collide. + return llvm::hash_combine(val.index(), alternative); + }, + val); +} - virtual llvm::hash_code getHashValue() const = 0; - virtual bool isEqual(const ElaboratorValue &other) const = 0; +} // namespace -#ifndef NDEBUG - virtual void print(llvm::raw_ostream &os) const = 0; -#endif +namespace llvm { - ValueKind getKind() const { return kind; } +template <> +struct DenseMapInfo { + static inline unsigned getEmptyKey() { return false; } + static inline unsigned getTombstoneKey() { return true; } + static unsigned getHashValue(const bool &val) { return val * 37U; } -private: - const ValueKind kind; + static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; } }; -/// Holds any typed attribute. Wrapping around an MLIR `Attribute` allows us to -/// use this elaborator value class for any values that have a corresponding -/// MLIR attribute rather than one per kind of attribute. We only support typed -/// attributes because for materialization we need to provide the type to the -/// dialect's materializer. -class AttributeValue : public ElaboratorValue { -public: - AttributeValue(TypedAttr attr) - : ElaboratorValue(ValueKind::Attribute), attr(attr) { - assert(attr && "null attributes not allowed"); - } - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Attribute; - } - - llvm::hash_code getHashValue() const override { - return llvm::hash_combine(attr); - } - - bool isEqual(const ElaboratorValue &other) const override { - auto *attrValue = dyn_cast(&other); - if (!attrValue) - return false; - - return attr == attrValue->attr; - } +} // namespace llvm -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << ""; - } -#endif +//===----------------------------------------------------------------------===// +// Elaborator Value Storages and Internalization +//===----------------------------------------------------------------------===// - TypedAttr getAttr() const { return attr; } +namespace { -private: - const TypedAttr attr; +/// Lightweight object to be used as the key for internalization sets. It caches +/// the hashcode of the internalized object and a pointer to it. This allows a +/// delayed allocation and construction of the actual object and thus only has +/// to happen if the object is not already in the set. +template +struct HashedStorage { + HashedStorage(unsigned hashcode = 0, StorageTy *storage = nullptr) + : hashcode(hashcode), storage(storage) {} + + unsigned hashcode; + StorageTy *storage; }; -/// Holds an evaluated value of a `IndexType`'d value. -class IndexValue : public ElaboratorValue { -public: - IndexValue(size_t index) : ElaboratorValue(ValueKind::Index), index(index) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Index; - } - - llvm::hash_code getHashValue() const override { - return llvm::hash_value(index); +/// A DenseMapInfo implementation to support 'insert_as' for the internalization +/// sets. When comparing two 'HashedStorage's we can just compare the already +/// internalized storage pointers, otherwise we have to call the costly +/// 'isEqual' method. +template +struct StorageKeyInfo { + static inline HashedStorage getEmptyKey() { + return HashedStorage(0, + DenseMapInfo::getEmptyKey()); } - - bool isEqual(const ElaboratorValue &other) const override { - auto *indexValue = dyn_cast(&other); - if (!indexValue) - return false; - - return index == indexValue->index; + static inline HashedStorage getTombstoneKey() { + return HashedStorage( + 0, DenseMapInfo::getTombstoneKey()); } -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << ""; + static inline unsigned getHashValue(const HashedStorage &key) { + return key.hashcode; } -#endif - - size_t getIndex() const { return index; } - -private: - const size_t index; -}; - -/// Holds an evaluated value of an `i1` type'd value. -class BoolValue : public ElaboratorValue { -public: - BoolValue(bool value) : ElaboratorValue(ValueKind::Bool), value(value) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Bool; + static inline unsigned getHashValue(const StorageTy &key) { + return key.hashcode; } - llvm::hash_code getHashValue() const override { - return llvm::hash_value(value); + static inline bool isEqual(const HashedStorage &lhs, + const HashedStorage &rhs) { + return lhs.storage == rhs.storage; } - - bool isEqual(const ElaboratorValue &other) const override { - auto *val = dyn_cast(&other); - if (!val) + static inline bool isEqual(const StorageTy &lhs, + const HashedStorage &rhs) { + if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey())) return false; - return value == val->value; - } - -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << ""; + return lhs.isEqual(rhs.storage); } -#endif - - bool getBool() const { return value; } - -private: - const bool value; }; -/// Holds an evaluated value of a `SetType`'d value. -class SetValue : public ElaboratorValue { -public: - SetValue(SetVector &&set, Type type) - : ElaboratorValue(ValueKind::Set), set(std::move(set)), type(type), - cachedHash(llvm::hash_combine( - llvm::hash_combine_range(set.begin(), set.end()), type)) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Set; - } - - llvm::hash_code getHashValue() const override { return cachedHash; } - - bool isEqual(const ElaboratorValue &other) const override { - auto *otherSet = dyn_cast(&other); - if (!otherSet) - return false; - - if (cachedHash != otherSet->cachedHash) - return false; - - // Make sure empty sets of different types are not considered equal - return set == otherSet->set && type == otherSet->type; - } +/// Storage object for an '!rtg.set'. +struct SetStorage { + SetStorage(SetVector &&set, Type type) + : hashcode(llvm::hash_combine( + type, llvm::hash_combine_range(set.begin(), set.end()))), + set(std::move(set)), type(type) {} -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << "print(os); }); - os << "} at " << this << ">"; + bool isEqual(const SetStorage *other) const { + return hashcode == other->hashcode && set == other->set && + type == other->type; } -#endif - - const SetVector &getSet() const { return set; } - Type getType() const { return type; } + // The cached hashcode to avoid repeated computations. + const unsigned hashcode; -private: - // We currently use a sorted vector to represent sets. Note that it is sorted - // by the pointer value and thus non-deterministic. - // We probably want to do some profiling in the future to see if a DenseSet or - // other representation is better suited. - const SetVector set; + // Stores the elaborated values contained in the set. + const SetVector set; // Store the set type such that we can materialize this evaluated value // also in the case where the set is empty. const Type type; - - // Compute the hash only once at constructor time. - const llvm::hash_code cachedHash; }; -/// Holds an evaluated value of a `BagType`'d value. -class BagValue : public ElaboratorValue { -public: - BagValue(MapVector &&bag, Type type) - : ElaboratorValue(ValueKind::Bag), bag(std::move(bag)), type(type), - cachedHash(llvm::hash_combine( - llvm::hash_combine_range(bag.begin(), bag.end()), type)) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Bag; - } +/// Storage object for an '!rtg.bag'. +struct BagStorage { + BagStorage(MapVector &&bag, Type type) + : hashcode(llvm::hash_combine( + type, llvm::hash_combine_range(bag.begin(), bag.end()))), + bag(std::move(bag)), type(type) {} - llvm::hash_code getHashValue() const override { return cachedHash; } - - bool isEqual(const ElaboratorValue &other) const override { - auto *otherBag = dyn_cast(&other); - if (!otherBag) - return false; - - if (cachedHash != otherBag->cachedHash) - return false; - - return llvm::equal(bag, otherBag->bag) && type == otherBag->type; - } - -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << " el) { - el.first->print(os); - os << " -> " << el.second; - }); - os << "} at " << this << ">"; + bool isEqual(const BagStorage *other) const { + return hashcode == other->hashcode && llvm::equal(bag, other->bag) && + type == other->type; } -#endif - - const MapVector &getBag() const { return bag; } - Type getType() const { return type; } + // The cached hashcode to avoid repeated computations. + const unsigned hashcode; -private: - // Stores the elaborated values of the bag. - const MapVector bag; + // Stores the elaborated values contained in the bag with their number of + // occurences. + const MapVector bag; - // Store the type of the bag such that we can materialize this evaluated value + // Store the bag type such that we can materialize this evaluated value // also in the case where the bag is empty. const Type type; - - // Compute the hash only once at constructor time. - const llvm::hash_code cachedHash; }; -/// Holds an evaluated value of a `SequenceType`'d value. -class SequenceValue : public ElaboratorValue { -public: - SequenceValue(StringRef name, StringAttr familyName, - SmallVector &&args) - : ElaboratorValue(ValueKind::Sequence), name(name), - familyName(familyName), args(std::move(args)), - cachedHash(llvm::hash_combine( - llvm::hash_combine_range(this->args.begin(), this->args.end()), - name, familyName)) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Sequence; +/// Storage object for an '!rtg.sequence'. +struct SequenceStorage { + SequenceStorage(StringRef name, StringAttr familyName, + SmallVector &&args) + : hashcode(llvm::hash_combine( + name, familyName, + llvm::hash_combine_range(args.begin(), args.end()))), + name(name), familyName(familyName), args(std::move(args)) {} + + bool isEqual(const SequenceStorage *other) const { + return hashcode == other->hashcode && name == other->name && + familyName == other->familyName && args == other->args; } - llvm::hash_code getHashValue() const override { return cachedHash; } + // The cached hashcode to avoid repeated computations. + const unsigned hashcode; - bool isEqual(const ElaboratorValue &other) const override { - auto *otherSeq = dyn_cast(&other); - if (!otherSeq) - return false; + // The name of this fully substituted and elaborated sequence. + const StringRef name; - if (cachedHash != otherSeq->cachedHash) - return false; + // The name of the sequence family this sequence is derived from. + const StringAttr familyName; - return name == otherSeq->name && familyName == otherSeq->familyName && - args == otherSeq->args; - } + // The elaborator values used during substitution of the sequence family. + const SmallVector args; +}; -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << "print(os); }); - os << ") at " << this << ">"; +/// An 'Internalizer' object internalizes storages and takes ownership of them. +/// When the initializer object is destroyed, all owned storages are also +/// deallocated and thus must not be accessed anymore. +class Internalizer { +public: + /// Internalize a storage of type `StorageTy` constructed with arguments + /// `args`. The pointers returned by this method can be used to compare + /// objects when, e.g., computing set differences, uniquing the elements in a + /// set, etc. Otherwise, we'd need to do a deep value comparison in those + /// situations. + template + StorageTy *internalize(Args &&...args) { + StorageTy storage(std::forward(args)...); + + auto existing = getInternSet().insert_as( + HashedStorage(storage.hashcode), storage); + StorageTy *&storagePtr = existing.first->storage; + if (existing.second) + storagePtr = + new (allocator.Allocate()) StorageTy(std::move(storage)); + + return storagePtr; } -#endif - - StringRef getName() const { return name; } - StringAttr getFamilyName() const { return familyName; } - ArrayRef getArgs() const { return args; } private: - const StringRef name; - const StringAttr familyName; - const SmallVector args; + template + DenseSet, StorageKeyInfo> & + getInternSet() { + if constexpr (std::is_same_v) + return internedSets; + else if constexpr (std::is_same_v) + return internedBags; + else if constexpr (std::is_same_v) + return internedSequences; + else + static_assert(!sizeof(StorageTy), + "no intern set available for this storage type."); + } - // Compute the hash only once at constructor time. - const llvm::hash_code cachedHash; + // This allocator allocates on the heap. It automatically deallocates all + // objects it allocated once the allocator itself is destroyed. + llvm::BumpPtrAllocator allocator; + + // The sets holding the internalized objects. We use one set per storage type + // such that we can have a simpler equality checking function (no need to + // compare some sort of TypeIDs). + DenseSet, StorageKeyInfo> internedSets; + DenseSet, StorageKeyInfo> internedBags; + DenseSet, StorageKeyInfo> + internedSequences; }; + } // namespace #ifndef NDEBUG + static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const ElaboratorValue &value) { - value.print(os); - return os; + const ElaboratorValue &value); + +static void print(TypedAttr val, llvm::raw_ostream &os) { + os << ""; } -#endif -//===----------------------------------------------------------------------===// -// Hash Map Helpers -//===----------------------------------------------------------------------===// +static void print(BagStorage *val, llvm::raw_ostream &os) { + os << "bag, os, + [&](const std::pair &el) { + os << el.first << " -> " << el.second; + }); + os << "} at " << val << ">"; +} -// NOLINTNEXTLINE(readability-identifier-naming) -static llvm::hash_code hash_value(const ElaboratorValue &val) { - return val.getHashValue(); +static void print(bool val, llvm::raw_ostream &os) { + os << ""; } -namespace { -struct InternMapInfo : public DenseMapInfo { - static unsigned getHashValue(const ElaboratorValue *value) { - assert(value != getTombstoneKey() && value != getEmptyKey()); - return hash_value(*value); - } +static void print(size_t val, llvm::raw_ostream &os) { + os << ""; +} - static bool isEqual(const ElaboratorValue *lhs, const ElaboratorValue *rhs) { - if (lhs == rhs) - return true; +static void print(SequenceStorage *val, llvm::raw_ostream &os) { + os << "name << " derived from @" + << val->familyName.getValue() << "("; + llvm::interleaveComma(val->args, os, + [&](const ElaboratorValue &val) { os << val; }); + os << ") at " << val << ">"; +} - auto *tk = getTombstoneKey(); - auto *ek = getEmptyKey(); - if (lhs == tk || rhs == tk || lhs == ek || rhs == ek) - return false; +static void print(SetStorage *val, llvm::raw_ostream &os) { + os << "set, os, + [&](const ElaboratorValue &val) { os << val; }); + os << "} at " << val << ">"; +} - return lhs->isEqual(*rhs); - } -}; -} // namespace +static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const ElaboratorValue &value) { + std::visit([&](auto val) { print(val, os); }, value); + + return os; +} + +#endif //===----------------------------------------------------------------------===// -// Main Elaborator Implementation +// Elaborator Value Materialization //===----------------------------------------------------------------------===// namespace { @@ -427,23 +363,18 @@ class Materializer { /// Materialize IR representing the provided `ElaboratorValue` and return the /// `Value` or a null value on failure. - Value materialize(ElaboratorValue *val, Location loc, - std::queue &elabRequests, + Value materialize(ElaboratorValue val, Location loc, + std::queue &elabRequests, function_ref emitError) { auto iter = materializedValues.find(val); if (iter != materializedValues.end()) return iter->second; - LLVM_DEBUG(llvm::dbgs() << "Materializing " << *val << "\n\n"); + LLVM_DEBUG(llvm::dbgs() << "Materializing " << val << "\n\n"); - return TypeSwitch(val) - .Case( - [&](auto val) { return visit(val, loc, elabRequests, emitError); }) - .Default([](auto val) { - assert(false && "all cases must be covered above"); - return Value(); - }); + return std::visit( + [&](auto val) { return visit(val, loc, elabRequests, emitError); }, + val); } /// If `op` is not in the same region as the materializer insertion point, a @@ -453,8 +384,8 @@ class Materializer { /// deleted until `op` is reached. An error is returned if the operation is /// before the insertion point. LogicalResult materialize(Operation *op, - DenseMap &state, - std::queue &elabRequests) { + DenseMap &state, + std::queue &elabRequests) { if (op->getNumRegions() > 0) return op->emitOpError("ops with nested regions must be elaborated away"); @@ -521,15 +452,13 @@ class Materializer { } private: - Value visit(AttributeValue *val, Location loc, - std::queue &elabRequests, + Value visit(TypedAttr val, Location loc, + std::queue &elabRequests, function_ref emitError) { - auto attr = val->getAttr(); - // For index attributes (and arithmetic operations on them) we use the // index dialect. - if (auto intAttr = dyn_cast(attr); - intAttr && isa(attr.getType())) { + if (auto intAttr = dyn_cast(val); + intAttr && isa(val.getType())) { Value res = builder.create(loc, intAttr); materializedValues[val] = res; return res; @@ -537,12 +466,12 @@ class Materializer { // For any other attribute, we just call the materializer of the dialect // defining that attribute. - auto *op = attr.getDialect().materializeConstant(builder, attr, - attr.getType(), loc); + auto *op = + val.getDialect().materializeConstant(builder, val, val.getType(), loc); if (!op) { emitError() << "materializer of dialect '" - << attr.getDialect().getNamespace() - << "' unable to materialize value for attribute '" << attr + << val.getDialect().getNamespace() + << "' unable to materialize value for attribute '" << val << "'"; return Value(); } @@ -552,28 +481,28 @@ class Materializer { return res; } - Value visit(IndexValue *val, Location loc, - std::queue &elabRequests, + Value visit(size_t val, Location loc, + std::queue &elabRequests, function_ref emitError) { - Value res = builder.create(loc, val->getIndex()); + Value res = builder.create(loc, val); materializedValues[val] = res; return res; } - Value visit(BoolValue *val, Location loc, - std::queue &elabRequests, + Value visit(bool val, Location loc, + std::queue &elabRequests, function_ref emitError) { - Value res = builder.create(loc, val->getBool()); + Value res = builder.create(loc, val); materializedValues[val] = res; return res; } - Value visit(SetValue *val, Location loc, - std::queue &elabRequests, + Value visit(SetStorage *val, Location loc, + std::queue &elabRequests, function_ref emitError) { SmallVector elements; - elements.reserve(val->getSet().size()); - for (auto *el : val->getSet()) { + elements.reserve(val->set.size()); + for (auto el : val->set) { auto materialized = materialize(el, loc, elabRequests, emitError); if (!materialized) return Value(); @@ -581,47 +510,38 @@ class Materializer { elements.push_back(materialized); } - auto res = builder.create(loc, val->getType(), elements); + auto res = builder.create(loc, val->type, elements); materializedValues[val] = res; return res; } - Value visit(BagValue *val, Location loc, - std::queue &elabRequests, + Value visit(BagStorage *val, Location loc, + std::queue &elabRequests, function_ref emitError) { SmallVector values, weights; - values.reserve(val->getBag().size()); - weights.reserve(val->getBag().size()); - for (auto [val, weight] : val->getBag()) { + values.reserve(val->bag.size()); + weights.reserve(val->bag.size()); + for (auto [val, weight] : val->bag) { auto materializedVal = materialize(val, loc, elabRequests, emitError); - if (!materializedVal) + auto materializedWeight = + materialize(weight, loc, elabRequests, emitError); + if (!materializedVal || !materializedWeight) return Value(); - auto iter = integerValues.find(weight); - Value materializedWeight; - if (iter != integerValues.end()) { - materializedWeight = iter->second; - } else { - materializedWeight = builder.create( - loc, builder.getIndexAttr(weight)); - integerValues[weight] = materializedWeight; - } - values.push_back(materializedVal); weights.push_back(materializedWeight); } - auto res = - builder.create(loc, val->getType(), values, weights); + auto res = builder.create(loc, val->type, values, weights); materializedValues[val] = res; return res; } - Value visit(SequenceValue *val, Location loc, - std::queue &elabRequests, + Value visit(SequenceStorage *val, Location loc, + std::queue &elabRequests, function_ref emitError) { elabRequests.push(val); - return builder.create(loc, val->getName(), ValueRange()); + return builder.create(loc, val->name, ValueRange()); } private: @@ -630,8 +550,7 @@ class Materializer { /// insertion point such that future materializations can also reuse previous /// materializations without running into dominance issues (or requiring /// additional checks to avoid them). - DenseMap materializedValues; - DenseMap integerValues; + DenseMap materializedValues; /// Cache the builder to continue insertions at their current insertion point /// for the reason stated above. @@ -640,6 +559,10 @@ class Materializer { SmallVector toDelete; }; +//===----------------------------------------------------------------------===// +// Elaboration Visitor +//===----------------------------------------------------------------------===// + /// Used to signal to the elaboration driver whether the operation should be /// removed. enum class DeletionKind { Keep, Delete }; @@ -652,21 +575,11 @@ struct ElaboratorSharedState { SymbolTable &table; std::mt19937 rng; Namespace names; - - // A map used to intern elaborator values. We do this such that we can - // compare pointers when, e.g., computing set differences, uniquing the - // elements in a set, etc. Otherwise, we'd need to do a deep value comparison - // in those situations. - // Use a pointer as the key with custom MapInfo because of object slicing when - // inserting an object of a derived class of ElaboratorValue. - // The custom MapInfo makes sure that we do a value comparison instead of - // comparing the pointers. - DenseMap, InternMapInfo> - interned; + Internalizer internalizer; /// The worklist used to keep track of the test and sequence operations to /// make sure they are processed top-down (BFS traversal). - std::queue worklist; + std::queue worklist; }; /// Interprets the IR to perform and lower the represented randomizations. @@ -678,15 +591,9 @@ class Elaborator : public RTGOpVisitor> { Elaborator(ElaboratorSharedState &sharedState, Materializer &materializer) : sharedState(sharedState), materializer(materializer) {} - /// Helper to perform internalization and keep track of interpreted value for - /// the given SSA value. - template - void internalizeResult(Value val, Args &&...args) { - // TODO: this isn't the most efficient way to internalize - auto ptr = std::make_unique(std::forward(args)...); - auto *e = ptr.get(); - auto [iter, _] = sharedState.interned.insert({e, std::move(ptr)}); - state[val] = iter->second.get(); + template + inline ValueTy get(Value val) { + return std::get(state.at(val)); } /// Print a nice error message for operations we don't support yet. @@ -708,11 +615,11 @@ class Elaborator : public RTGOpVisitor> { auto intAttr = dyn_cast(attr); if (intAttr && isa(attr.getType())) - internalizeResult(op->getResult(0), intAttr.getInt()); + state[op->getResult(0)] = size_t(intAttr.getInt()); else if (intAttr && intAttr.getType().isSignlessInteger(1)) - internalizeResult(op->getResult(0), intAttr.getInt()); + state[op->getResult(0)] = bool(intAttr.getInt()); else - internalizeResult(op->getResult(0), attr); + state[op->getResult(0)] = attr; return DeletionKind::Delete; } @@ -727,14 +634,15 @@ class Elaborator : public RTGOpVisitor> { } FailureOr visitOp(SequenceClosureOp op) { - SmallVector args; + SmallVector args; for (auto arg : op.getArgs()) args.push_back(state.at(arg)); auto familyName = op.getSequenceAttr(); auto name = sharedState.names.newName(familyName.getValue()); - internalizeResult(op.getResult(), name, familyName, - std::move(args)); + state[op.getResult()] = + sharedState.internalizer.internalize(name, familyName, + std::move(args)); return DeletionKind::Delete; } @@ -743,84 +651,81 @@ class Elaborator : public RTGOpVisitor> { } FailureOr visitOp(SetCreateOp op) { - SetVector set; + SetVector set; for (auto val : op.getElements()) set.insert(state.at(val)); - internalizeResult(op.getSet(), std::move(set), - op.getSet().getType()); + state[op.getSet()] = sharedState.internalizer.internalize( + std::move(set), op.getSet().getType()); return DeletionKind::Delete; } FailureOr visitOp(SetSelectRandomOp op) { - auto *set = cast(state.at(op.getSet())); + auto set = get(op.getSet())->set; size_t selected; if (auto intAttr = op->getAttrOfType("rtg.elaboration_custom_seed")) { std::mt19937 customRng(intAttr.getInt()); - selected = getUniformlyInRange(customRng, 0, set->getSet().size() - 1); + selected = getUniformlyInRange(customRng, 0, set.size() - 1); } else { - selected = - getUniformlyInRange(sharedState.rng, 0, set->getSet().size() - 1); + selected = getUniformlyInRange(sharedState.rng, 0, set.size() - 1); } - state[op.getResult()] = set->getSet()[selected]; + state[op.getResult()] = set[selected]; return DeletionKind::Delete; } FailureOr visitOp(SetDifferenceOp op) { - auto original = cast(state.at(op.getOriginal()))->getSet(); - auto diff = cast(state.at(op.getDiff()))->getSet(); + auto original = get(op.getOriginal())->set; + auto diff = get(op.getDiff())->set; - SetVector result(original); + SetVector result(original); result.set_subtract(diff); - internalizeResult(op.getResult(), std::move(result), - op.getResult().getType()); + state[op.getResult()] = sharedState.internalizer.internalize( + std::move(result), op.getResult().getType()); return DeletionKind::Delete; } FailureOr visitOp(SetUnionOp op) { - SetVector result; + SetVector result; for (auto set : op.getSets()) - result.set_union(cast(state.at(set))->getSet()); + result.set_union(get(set)->set); - internalizeResult(op.getResult(), std::move(result), - op.getType()); + state[op.getResult()] = sharedState.internalizer.internalize( + std::move(result), op.getType()); return DeletionKind::Delete; } FailureOr visitOp(SetSizeOp op) { - auto size = cast(state.at(op.getSet()))->getSet().size(); - auto sizeAttr = IntegerAttr::get(IndexType::get(op->getContext()), size); - internalizeResult(op.getResult(), sizeAttr); + auto size = get(op.getSet())->set.size(); + state[op.getResult()] = size; return DeletionKind::Delete; } FailureOr visitOp(BagCreateOp op) { - MapVector bag; + MapVector bag; for (auto [val, multiple] : llvm::zip(op.getElements(), op.getMultiples())) { - auto *interpValue = state.at(val); // If the multiple is not stored as an AttributeValue, the elaboration // must have already failed earlier (since we don't have // unevaluated/opaque values). - auto *interpMultiple = cast(state.at(multiple)); - bag[interpValue] += interpMultiple->getIndex(); + bag[state.at(val)] += get(multiple); } - internalizeResult(op.getBag(), std::move(bag), op.getType()); + state[op.getBag()] = sharedState.internalizer.internalize( + std::move(bag), op.getType()); return DeletionKind::Delete; } FailureOr visitOp(BagSelectRandomOp op) { - auto *bag = cast(state.at(op.getBag())); + auto bag = get(op.getBag())->bag; - SmallVector> prefixSum; - prefixSum.reserve(bag->getBag().size()); + SmallVector> prefixSum; + prefixSum.reserve(bag.size()); uint32_t accumulator = 0; - for (auto [val, weight] : bag->getBag()) { + for (auto [val, weight] : bag) { accumulator += weight; prefixSum.push_back({val, accumulator}); } @@ -834,20 +739,21 @@ class Elaborator : public RTGOpVisitor> { auto idx = getUniformlyInRange(customRng, 0, accumulator - 1); auto *iter = llvm::upper_bound( prefixSum, idx, - [](uint32_t a, const std::pair &b) { + [](uint32_t a, const std::pair &b) { return a < b.second; }); + state[op.getResult()] = iter->first; return DeletionKind::Delete; } FailureOr visitOp(BagDifferenceOp op) { - auto *original = cast(state.at(op.getOriginal())); - auto *diff = cast(state.at(op.getDiff())); + auto original = get(op.getOriginal())->bag; + auto diff = get(op.getDiff())->bag; - MapVector result; - for (const auto &el : original->getBag()) { - if (!diff->getBag().contains(el.first)) { + MapVector result; + for (const auto &el : original) { + if (!diff.contains(el.first)) { result.insert(el); continue; } @@ -855,40 +761,39 @@ class Elaborator : public RTGOpVisitor> { if (op.getInf()) continue; - auto toDiff = diff->getBag().lookup(el.first); + auto toDiff = diff.lookup(el.first); if (el.second <= toDiff) continue; result.insert({el.first, el.second - toDiff}); } - internalizeResult(op.getResult(), std::move(result), - op.getType()); + state[op.getResult()] = sharedState.internalizer.internalize( + std::move(result), op.getType()); return DeletionKind::Delete; } FailureOr visitOp(BagUnionOp op) { - MapVector result; + MapVector result; for (auto bag : op.getBags()) { - auto *val = cast(state.at(bag)); - for (auto [el, multiple] : val->getBag()) + auto val = get(bag)->bag; + for (auto [el, multiple] : val) result[el] += multiple; } - internalizeResult(op.getResult(), std::move(result), - op.getType()); + state[op.getResult()] = sharedState.internalizer.internalize( + std::move(result), op.getType()); return DeletionKind::Delete; } FailureOr visitOp(BagUniqueSizeOp op) { - auto size = cast(state.at(op.getBag()))->getBag().size(); - auto sizeAttr = IntegerAttr::get(IndexType::get(op->getContext()), size); - internalizeResult(op.getResult(), sizeAttr); + auto size = get(op.getBag())->bag.size(); + state[op.getResult()] = size; return DeletionKind::Delete; } FailureOr visitOp(scf::IfOp op) { - bool cond = cast(state.at(op.getCondition()))->getBool(); + bool cond = get(op.getCondition()); auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion(); if (toElaborate.empty()) return DeletionKind::Delete; @@ -910,13 +815,15 @@ class Elaborator : public RTGOpVisitor> { } FailureOr visitOp(scf::ForOp op) { - auto *lowerBound = dyn_cast(state.at(op.getLowerBound())); - auto *step = dyn_cast(state.at(op.getStep())); - auto *upperBound = dyn_cast(state.at(op.getUpperBound())); - - if (!lowerBound || !step || !upperBound) + if (!(std::holds_alternative(state.at(op.getLowerBound())) && + std::holds_alternative(state.at(op.getStep())) && + std::holds_alternative(state.at(op.getUpperBound())))) return op->emitOpError("can only elaborate index type iterator"); + auto lowerBound = get(op.getLowerBound()); + auto step = get(op.getStep()); + auto upperBound = get(op.getUpperBound()); + // Prepare for first iteration by assigning the nested regions block // arguments. We can just reuse this elaborator because we need access to // values elaborated in the parent region anyway and materialize everything @@ -927,14 +834,13 @@ class Elaborator : public RTGOpVisitor> { state[iterArg] = state.at(initArg); // This loop performs the actual 'scf.for' loop iterations. - for (size_t i = lowerBound->getIndex(); i < upperBound->getIndex(); - i += step->getIndex()) { + for (size_t i = lowerBound; i < upperBound; i += step) { if (failed(elaborate(op.getBodyRegion()))) return failure(); // Prepare for the next iteration by updating the mapping of the nested // regions block arguments - internalizeResult(op.getInductionVar(), i + step->getIndex()); + state[op.getInductionVar()] = i + step; for (auto [iterArg, prevIterArg] : llvm::zip(op.getRegionIterArgs(), op.getBody()->getTerminator()->getOperands())) @@ -954,15 +860,15 @@ class Elaborator : public RTGOpVisitor> { } FailureOr visitOp(index::AddOp op) { - size_t lhs = cast(state.at(op.getLhs()))->getIndex(); - size_t rhs = cast(state.at(op.getRhs()))->getIndex(); - internalizeResult(op.getResult(), lhs + rhs); + size_t lhs = get(op.getLhs()); + size_t rhs = get(op.getRhs()); + state[op.getResult()] = lhs + rhs; return DeletionKind::Delete; } FailureOr visitOp(index::CmpOp op) { - size_t lhs = cast(state.at(op.getLhs()))->getIndex(); - size_t rhs = cast(state.at(op.getRhs()))->getIndex(); + size_t lhs = get(op.getLhs()); + size_t rhs = get(op.getRhs()); bool result; switch (op.getPred()) { case index::IndexCmpPredicate::EQ: @@ -986,7 +892,7 @@ class Elaborator : public RTGOpVisitor> { default: return op->emitOpError("elaboration not supported"); } - internalizeResult(op.getResult(), result); + state[op.getResult()] = result; return DeletionKind::Delete; } @@ -1003,7 +909,7 @@ class Elaborator : public RTGOpVisitor> { // NOLINTNEXTLINE(misc-no-recursion) LogicalResult elaborate(Region ®ion, - ArrayRef regionArguments = {}) { + ArrayRef regionArguments = {}) { if (region.getBlocks().size() > 1) return region.getParentOp()->emitOpError( "regions with more than one block are not supported"); @@ -1027,7 +933,7 @@ class Elaborator : public RTGOpVisitor> { llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) { if (state.contains(res)) - llvm::dbgs() << *state.at(res); + llvm::dbgs() << state.at(res); else llvm::dbgs() << "unknown"; }); @@ -1048,7 +954,7 @@ class Elaborator : public RTGOpVisitor> { Materializer &materializer; // A map from SSA values to a pointer of an interned elaborator value. - DenseMap state; + DenseMap state; }; } // namespace @@ -1149,19 +1055,18 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp, auto *curr = state.worklist.front(); state.worklist.pop(); - if (table.lookup(curr->getName())) + if (table.lookup(curr->name)) continue; - auto familyOp = table.lookup(curr->getFamilyName()); + auto familyOp = table.lookup(curr->familyName); // TODO: don't clone if this is the only remaining reference to this // sequence OpBuilder builder(familyOp); auto seqOp = builder.cloneWithoutRegions(familyOp); seqOp.getBodyRegion().emplaceBlock(); - seqOp.setSymName(curr->getName()); + seqOp.setSymName(curr->name); table.insert(seqOp); - assert(seqOp.getSymName() == curr->getName() && - "should not have been renamed"); + assert(seqOp.getSymName() == curr->name && "should not have been renamed"); LLVM_DEBUG(llvm::dbgs() << "\n=== Elaborating sequence family @" << familyOp.getSymName() @@ -1169,7 +1074,7 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp, Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody())); Elaborator elaborator(state, materializer); - if (failed(elaborator.elaborate(familyOp.getBodyRegion(), curr->getArgs()))) + if (failed(elaborator.elaborate(familyOp.getBodyRegion(), curr->args))) return failure(); materializer.finalize();