20#include "mlir/Dialect/Index/IR/IndexDialect.h"
21#include "mlir/Dialect/Index/IR/IndexOps.h"
22#include "mlir/Dialect/SCF/IR/SCF.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/PatternMatch.h"
25#include "llvm/ADT/DenseMapInfoVariant.h"
26#include "llvm/Support/Debug.h"
32#define GEN_PASS_DEF_ELABORATIONPASS
33#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
42#define DEBUG_TYPE "rtg-elaboration"
// computeMask (excerpt; the signature and the computation of `w0` are not
// visible here). `n` is the number of 32-bit words needed to hold `w` bits,
// i.e. ceil(w / 32).
54 size_t n = w / 32 + (w % 32 != 0);
// Return a mask with the low `w0` bits set, or 0 when `w0` is 0.
// NOTE(review): the shift is only well-defined for 1 <= w0 <= 32 (shift amount
// in [0, 31]); confirm the hidden computation of `w0` guarantees that.
56 return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
/// getUniformlyInRange (excerpt): get a number uniformly at random in the
/// inclusive range [a, b]; `diff` is the number of values in that range.
// NOTE(review): for a == 0 and b == UINT32_MAX the count wraps to 0, and an
// inverted range (b < a) wraps as well; any handling of these cases would be in
// lines not visible here — confirm.
61 const uint32_t diff = b - a + 1;
65 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
// floor(log2(diff)); this underflows for diff == 0, so diff must be non-zero
// by this point.
69 uint32_t width = digits - llvm::countl_zero(diff) - 1;
// If any bit below the highest set bit is set (diff is not a power of two),
// the following (not visible) statement presumably widens `width` by one.
// NOTE(review): when width == 0 (diff == 1) the shift amount is 32, which is
// undefined for a 32-bit operand; confirm diff == 1 is handled earlier.
70 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
// Forward declaration: SequenceStorage is referenced by pointer (e.g. as an
// ElaboratorValue alternative) before its definition.
88struct SequenceStorage;
/// A virtual register value created during elaboration. It is identified by a
/// numeric id and carries an `allowedRegs` attribute (presumably the registers
/// it may later be assigned to).
92struct VirtualRegister {
93 VirtualRegister(uint64_t
id, ArrayAttr allowedRegs)
94 : id(id), allowedRegs(allowedRegs) {}
// Identity is the id alone. The (partially visible) assertion checks that two
// instances with the same id also agree on allowedRegs.
96 bool operator==(
const VirtualRegister &other)
const {
99 allowedRegs == other.allowedRegs &&
100 "instances with the same ID must have the same allowed registers");
101 return id == other.id;
109 ArrayAttr allowedRegs;
/// A label value: a name plus a numeric id (default 0). In this file, plain
/// label declarations use id 0. Unique label declarations take ids from a
/// counter that starts at 1.
113 LabelValue(StringAttr name, uint64_t
id = 0) : name(name), id(id) {}
// Two labels are equal only if both the name and the id match.
115 bool operator==(
const LabelValue &other)
const {
116 return name == other.name &&
id == other.id;
/// The value the elaborator tracks for each SSA value. Index values are held as
/// `size_t` and i1 values as `bool`; other constants are held as TypedAttr.
/// Sets, bags and sequence closures are pointers to interned storage owned by
/// the Internalizer.
127using ElaboratorValue =
128 std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
129 SetStorage *, VirtualRegister, LabelValue>;
/// Hash a virtual register. Only `id` is hashed, which is consistent with
/// VirtualRegister::operator== comparing only `id`.
132llvm::hash_code
hash_value(
const VirtualRegister &val) {
133 return llvm::hash_value(val.id);
/// Hash a label by combining its id and name, the same fields compared by
/// LabelValue::operator==.
137llvm::hash_code
hash_value(
const LabelValue &val) {
138 return llvm::hash_combine(val.id, val.name);
/// Hash an ElaboratorValue by visiting the active alternative. Pointer
/// alternatives hash by address. Because storages are interned (one instance
/// per structurally-equal value), address identity matches structural
/// equality.
142llvm::hash_code
hash_value(
const ElaboratorValue &val) {
144 [&val](
const auto &alternative) {
// Mix in the variant index so the hash depends on which alternative is
// active, not only on its payload.
147 return llvm::hash_combine(val.index(), alternative);
// DenseMapInfo<bool> (excerpt): booleans compare by value.
162 static bool isEqual(
const bool &lhs,
const bool &rhs) {
return lhs == rhs; }
// DenseMapInfo<VirtualRegister> (excerpt).
// Empty key: id 0 with a null ArrayAttr.
168 return VirtualRegister(0, ArrayAttr());
// Tombstone key: id ~0 (all bits set after conversion to uint64_t) with a null
// ArrayAttr.
171 return VirtualRegister(~0, ArrayAttr());
// Unlike hash_value(VirtualRegister), this hash also mixes in allowedRegs.
174 return llvm::hash_combine(val.id, val.allowedRegs);
// NOTE(review): the body of isEqual is not visible. Real registers are numbered
// from 0 (virtualRegisterID starts at 0), so a real register with id 0 matches
// the empty key unless isEqual also compares allowedRegs (null only for the
// sentinel keys). Confirm.
177 static bool isEqual(
const VirtualRegister &lhs,
const VirtualRegister &rhs) {
// DenseMapInfo<LabelValue> (excerpt). Empty key: null name with id 0.
184 static inline LabelValue
getEmptyKey() {
return LabelValue(StringAttr(), 0); }
// Tombstone key: null name with id ~0 (all bits set).
186 return LabelValue(StringAttr(), ~0);
// Combines the same two fields as hash_value(LabelValue), in the other order.
189 return llvm::hash_combine(val.name, val.id);
192 static bool isEqual(
const LabelValue &lhs,
const LabelValue &rhs) {
/// Entry of an intern set: a precomputed hash code plus a pointer to the
/// interned storage. Internalizer::internalize inserts an entry with a null
/// pointer and fills it in after allocating the storage.
209template <
typename StorageTy>
210struct HashedStorage {
211 HashedStorage(
unsigned hashcode = 0, StorageTy *storage =
nullptr)
212 : hashcode(hashcode), storage(storage) {}
/// Key info for DenseSet<HashedStorage<StorageTy>>. Besides HashedStorage keys
/// it accepts a StorageTy value for heterogeneous lookup, as used by
/// Internalizer::internalize through insert_as.
222template <
typename StorageTy>
223struct StorageKeyInfo {
// Sentinel keys: hash 0 plus the pointer sentinels of DenseMapInfo<StorageTy *>.
224 static inline HashedStorage<StorageTy> getEmptyKey() {
225 return HashedStorage<StorageTy>(0,
226 DenseMapInfo<StorageTy *>::getEmptyKey());
228 static inline HashedStorage<StorageTy> getTombstoneKey() {
229 return HashedStorage<StorageTy>(
230 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
// Hash of a stored entry and of a lookup value. The bodies are not visible;
// presumably both return the precomputed `hashcode` so that they agree.
233 static inline unsigned getHashValue(
const HashedStorage<StorageTy> &key) {
236 static inline unsigned getHashValue(
const StorageTy &key) {
// Stored entries are equal iff they point to the same storage object.
240 static inline bool isEqual(
const HashedStorage<StorageTy> &lhs,
241 const HashedStorage<StorageTy> &rhs) {
242 return lhs.storage == rhs.storage;
// Heterogeneous comparison used during lookup. Sentinel buckets are filtered
// out first (that branch's body is not visible; presumably it returns false).
// Otherwise the candidate is compared structurally via StorageTy::isEqual.
244 static inline bool isEqual(
const StorageTy &lhs,
245 const HashedStorage<StorageTy> &rhs) {
246 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
249 return lhs.isEqual(rhs.storage);
/// Interned storage for a set value. It holds the elements (a SetVector, so
/// insertion order is kept), the set's type, and a hash computed once at
/// construction from the type and the elements.
255 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
257 type,
llvm::hash_combine_range(set.begin(), set.
end()))),
258 set(std::move(set)), type(type) {}
// Structural equality: same hash and the same elements. The remaining (not
// visible) operand presumably compares the type.
// NOTE(review): SetVector's == compares the underlying vectors, i.e. it is
// order-sensitive. Sets with equal contents built in different orders therefore
// intern to distinct storages.
260 bool isEqual(
const SetStorage *other)
const {
261 return hashcode == other->hashcode && set == other->set &&
// Declared before `set`, so it is initialized first. The constructor hashes the
// `set` parameter before that parameter is moved into the member.
266 const unsigned hashcode;
// NOTE(review): a `const` member cannot be moved from, so the implicit move
// constructor (used by Internalizer::internalize) copies this container.
269 const SetVector<ElaboratorValue> set;
/// Interned storage for a bag (multiset) value. It holds a MapVector from
/// element to multiplicity (insertion order kept), the bag's type, and a hash
/// computed once at construction from the type and the (element, multiplicity)
/// pairs.
278 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
280 type,
llvm::hash_combine_range(bag.begin(), bag.
end()))),
281 bag(std::move(bag)), type(type) {}
// Structural equality: same hash and the same (element, multiplicity) pairs in
// the same order (llvm::equal). The remaining (not visible) operand presumably
// compares the type.
283 bool isEqual(
const BagStorage *other)
const {
284 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
// Declared before `bag`, so it is initialized first. The constructor hashes the
// `bag` parameter before that parameter is moved into the member.
289 const unsigned hashcode;
// NOTE(review): as a `const` member this container is copied, not moved, by the
// implicit move constructor.
293 const MapVector<ElaboratorValue, uint64_t> bag;
/// Interned storage for a sequence closure. It holds:
/// - the name of the concrete sequence to generate for it,
/// - the sequence family it is derived from,
/// - the elaborated closure arguments.
/// The hash is computed once at construction; the visible part covers at least
/// the arguments.
301struct SequenceStorage {
302 SequenceStorage(StringRef name, StringAttr familyName,
303 SmallVector<ElaboratorValue> &&args)
306 llvm::hash_combine_range(args.begin(), args.
end()))),
307 name(name), familyName(familyName), args(std::move(args)) {}
// Structural equality over the hash, name, family name and arguments.
309 bool isEqual(
const SequenceStorage *other)
const {
310 return hashcode == other->hashcode && name == other->name &&
311 familyName == other->familyName && args == other->args;
// Declared before `args`, so it is initialized first and hashes the `args`
// parameter before that parameter is moved into the member.
315 const unsigned hashcode;
// Non-owning view of the generated sequence name.
// NOTE(review): the characters must outlive this storage. Presumably they are
// owned by the Namespace that produced the name; confirm.
318 const StringRef name;
321 const StringAttr familyName;
324 const SmallVector<ElaboratorValue> args;
/// Return the unique instance of StorageTy that is structurally equal to one
/// built from `args`. The instance is allocated in the arena on first request.
337 template <
typename StorageTy,
typename... Args>
338 StorageTy *internalize(Args &&...args) {
339 StorageTy storage(std::forward<Args>(args)...);
// Look up by the precomputed hash and compare structurally against existing
// entries (StorageKeyInfo); insert a null-pointer entry when there is no match.
341 auto existing = getInternSet<StorageTy>().insert_as(
342 HashedStorage<StorageTy>(storage.hashcode), storage);
// Reference into the set entry so that a fresh entry can be filled with the
// arena-allocated copy (the surrounding condition/assignment is not visible).
343 StorageTy *&storagePtr = existing.first->storage;
346 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
/// Select the intern set for StorageTy at compile time. The dependent
/// `!sizeof(StorageTy)` makes the static_assert fire only for an unsupported
/// instantiation.
352 template <
typename StorageTy>
353 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
355 if constexpr (std::is_same_v<StorageTy, SetStorage>)
357 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
359 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
360 return internedSequences;
362 static_assert(!
sizeof(StorageTy),
363 "no intern set available for this storage type.");
// Arena for interned storages.
// NOTE(review): BumpPtrAllocator does not run destructors. Unless the storages
// are destroyed elsewhere (not visible), heap memory owned by their containers
// is never freed.
368 llvm::BumpPtrAllocator allocator;
// One intern set per storage kind.
373 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
374 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
375 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
/// Print an ElaboratorValue for debugging. It is declared before the print
/// overloads below because they use it to print nested values.
383static llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
384 const ElaboratorValue &value);
/// Debug printers, one per ElaboratorValue alternative. Storage-backed values
/// also print their storage address.
// Attribute constants: `<attr ...>`.
386static void print(TypedAttr val, llvm::raw_ostream &os) {
387 os <<
"<attr " << val <<
">";
// Bags: comma-separated `value -> multiplicity` entries, then `} at <address>>`
// (the opening text is on a line not visible here).
390static void print(BagStorage *val, llvm::raw_ostream &os) {
392 llvm::interleaveComma(val->bag, os,
393 [&](
const std::pair<ElaboratorValue, uint64_t> &el) {
394 os << el.first <<
" -> " << el.second;
396 os <<
"} at " << val <<
">";
// Booleans: `<bool true>` / `<bool false>`.
399static void print(
bool val, llvm::raw_ostream &os) {
400 os <<
"<bool " << (val ?
"true" :
"false") <<
">";
// Index values: `<index N>`.
403static void print(
size_t val, llvm::raw_ostream &os) {
404 os <<
"<index " << val <<
">";
// Sequence closures: generated name, family name, arguments and address.
407static void print(SequenceStorage *val, llvm::raw_ostream &os) {
408 os <<
"<sequence @" << val->name <<
" derived from @"
409 << val->familyName.getValue() <<
"(";
410 llvm::interleaveComma(val->args, os,
411 [&](
const ElaboratorValue &val) { os << val; });
412 os <<
") at " << val <<
">";
// Sets: comma-separated elements, then `} at <address>>`.
415static void print(SetStorage *val, llvm::raw_ostream &os) {
417 llvm::interleaveComma(val->set, os,
418 [&](
const ElaboratorValue &val) { os << val; });
419 os <<
"} at " << val <<
">";
// Virtual registers: id followed by the allowed-registers attribute.
422static void print(
const VirtualRegister &val, llvm::raw_ostream &os) {
423 os <<
"<virtual-register " << val.id <<
" " << val.allowedRegs <<
">";
// Labels: id followed by name.
426static void print(
const LabelValue &val, llvm::raw_ostream &os) {
427 os <<
"<label " << val.id <<
" " << val.name <<
">";
431 const ElaboratorValue &value) {
// Dispatch to the print overload for the active alternative.
432 std::visit([&](
auto val) {
print(val, os); }, value);
/// The materializer recreates IR for elaborated values at the builder's
/// insertion point and caches the result per ElaboratorValue.
448 Materializer(OpBuilder builder) : builder(builder) {}
/// Return an SSA value for `val`. A cached value is reused when present;
/// otherwise the visit overload for the active alternative is called.
/// `elabRequests` collects sequences that still need to be elaborated, and
/// `emitError` creates the diagnostic used to report failures.
452 Value materialize(ElaboratorValue val, Location loc,
453 std::queue<SequenceStorage *> &elabRequests,
454 function_ref<InFlightDiagnostic()> emitError) {
455 auto iter = materializedValues.find(val);
456 if (iter != materializedValues.end())
459 LLVM_DEBUG(llvm::dbgs() <<
"Materializing " << val <<
"\n\n");
462 [&](
auto val) {
return visit(val, loc, elabRequests, emitError); },
/// Emit `op` into the output at the current insertion point, materializing its
/// operands from the elaborated `state`.
/// Ops with regions, and ops whose results have uses, are rejected.
/// An op in the region being rewritten is modified in place; any other op is
/// cloned.
472 LogicalResult materialize(Operation *op,
473 DenseMap<Value, ElaboratorValue> &state,
474 std::queue<SequenceStorage *> &elabRequests) {
475 if (op->getNumRegions() > 0)
476 return op->emitOpError(
"ops with nested regions must be elaborated away");
// Kept ops must not produce values that are still used.
484 for (
auto res : op->getResults())
485 if (!res.use_empty())
486 return op->emitOpError(
487 "ops with results that have uses are not supported");
// In-place path. Every op between the insertion point and `op` is marked for
// deletion. If the scan reached the end of the block, `op` was not after the
// insertion point.
489 if (op->getParentRegion() == builder.getBlock()->getParent()) {
492 deleteOpsUntil([&](
auto iter) {
return &*iter == op; });
494 if (builder.getInsertionPoint() == builder.getBlock()->end())
495 return op->emitError(
"operation did not occur after the current "
496 "materializer insertion point");
498 LLVM_DEBUG(llvm::dbgs() <<
"Modifying in-place: " << *op <<
"\n\n");
// Clone path: copy `op` to the insertion point. Operand materializations are
// then emitted before the clone.
500 LLVM_DEBUG(llvm::dbgs() <<
"Materializing a clone of " << *op <<
"\n\n");
501 op = builder.clone(*op);
502 builder.setInsertionPoint(op);
// Materialize each operand's elaborated value. Diagnostics carry a note naming
// the operand index. The rewiring of the operand is on lines not visible here.
505 for (
auto &operand : op->getOpOperands()) {
506 auto emitError = [&]() {
507 auto diag = op->emitError();
508 diag.attachNote(op->getLoc())
509 <<
"while materializing value for operand#"
510 << operand.getOperandNumber();
514 Value val = materialize(state.at(operand.get()), op->getLoc(),
515 elabRequests, emitError);
// Continue emitting after the materialized op.
522 builder.setInsertionPointAfter(op);
// finalize (excerpt): mark every remaining op after the insertion point for
// deletion, then walk the collected ops in reverse order. The loop body,
// presumably erasing each op, is not visible.
529 deleteOpsUntil([](
auto iter) {
return false; });
531 for (
auto *op :
llvm::reverse(toDelete))
/// Advance the insertion point op by op and record each op in `toDelete`.
/// Stop at the end of the block or at the first op for which `stop` returns
/// true; that op is neither recorded nor skipped.
536 void deleteOpsUntil(function_ref<
bool(Block::iterator)> stop) {
537 auto ip = builder.getInsertionPoint();
538 while (ip != builder.getBlock()->end() && !stop(ip)) {
539 LLVM_DEBUG(llvm::dbgs() <<
"Marking to be deleted: " << *ip <<
"\n\n");
540 toDelete.push_back(&*ip);
542 builder.setInsertionPointAfter(&*ip);
543 ip = builder.getInsertionPoint();
/// Materialize an attribute constant. Index-typed integers become
/// index.constant. Other attributes go through the attribute dialect's
/// materializeConstant hook; an error is emitted if that produces no op.
/// The result is cached.
547 Value visit(TypedAttr val, Location loc,
548 std::queue<SequenceStorage *> &elabRequests,
549 function_ref<InFlightDiagnostic()> emitError) {
552 if (
auto intAttr = dyn_cast<IntegerAttr>(val);
553 intAttr && isa<IndexType>(val.getType())) {
554 Value res = builder.create<index::ConstantOp>(loc, intAttr);
555 materializedValues[val] = res;
562 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
564 emitError() <<
"materializer of dialect '"
565 << val.getDialect().getNamespace()
566 <<
"' unable to materialize value for attribute '" << val
// Use the first result of the op created by the dialect hook.
571 Value res = op->getResult(0);
572 materializedValues[val] = res;
/// Materialize an index value as index.constant and cache it.
576 Value visit(
size_t val, Location loc,
577 std::queue<SequenceStorage *> &elabRequests,
578 function_ref<InFlightDiagnostic()> emitError) {
579 Value res = builder.create<index::ConstantOp>(loc, val);
580 materializedValues[val] = res;
/// Materialize a boolean as index.bool.constant and cache it.
584 Value visit(
bool val, Location loc,
585 std::queue<SequenceStorage *> &elabRequests,
586 function_ref<InFlightDiagnostic()> emitError) {
587 Value res = builder.create<index::BoolConstantOp>(loc, val);
588 materializedValues[val] = res;
/// Materialize a set by materializing each element in order and emitting a
/// SetCreateOp of the set's type. The result is cached.
592 Value visit(SetStorage *val, Location loc,
593 std::queue<SequenceStorage *> &elabRequests,
594 function_ref<InFlightDiagnostic()> emitError) {
595 SmallVector<Value> elements;
596 elements.reserve(val->set.size());
597 for (
auto el : val->set) {
598 auto materialized = materialize(el, loc, elabRequests, emitError);
602 elements.push_back(materialized);
605 auto res = builder.create<SetCreateOp>(loc, val->type, elements);
606 materializedValues[val] = res;
/// Materialize a bag by materializing each (element, multiplicity) pair and
/// emitting a BagCreateOp. A null Value from either operand takes the failure
/// path (not visible). The result is cached.
610 Value visit(BagStorage *val, Location loc,
611 std::queue<SequenceStorage *> &elabRequests,
612 function_ref<InFlightDiagnostic()> emitError) {
613 SmallVector<Value> values, weights;
614 values.reserve(val->bag.size());
615 weights.reserve(val->bag.size());
// The structured binding `val` shadows the parameter only inside this loop.
// The range expression `val->bag` still refers to the parameter.
616 for (
auto [val, weight] : val->bag) {
617 auto materializedVal = materialize(val, loc, elabRequests, emitError);
618 auto materializedWeight =
619 materialize(weight, loc, elabRequests, emitError);
620 if (!materializedVal || !materializedWeight)
623 values.push_back(materializedVal);
624 weights.push_back(materializedWeight);
627 auto res = builder.create<BagCreateOp>(loc, val->type, values, weights);
628 materializedValues[val] = res;
/// Materialize a sequence closure as a reference to its generated sequence
/// name and queue the sequence for elaboration. Unlike the other overloads
/// visible here, the result is not stored in materializedValues, so every
/// occurrence emits a new closure op and a new request.
632 Value visit(SequenceStorage *val, Location loc,
633 std::queue<SequenceStorage *> &elabRequests,
634 function_ref<InFlightDiagnostic()> emitError) {
635 elabRequests.push(val);
636 return builder.create<SequenceClosureOp>(loc, val->name, ValueRange());
/// Materialize a virtual register as a VirtualRegisterOp with the same
/// allowed registers and cache it.
639 Value visit(
const VirtualRegister &val, Location loc,
640 std::queue<SequenceStorage *> &elabRequests,
641 function_ref<InFlightDiagnostic()> emitError) {
642 auto res = builder.create<VirtualRegisterOp>(loc, val.allowedRegs);
643 materializedValues[val] = res;
/// Materialize a label as a plain or a unique label declaration and cache it.
647 Value visit(
const LabelValue &val, Location loc,
648 std::queue<SequenceStorage *> &elabRequests,
649 function_ref<InFlightDiagnostic()> emitError) {
// Plain-label branch (its condition is not visible; presumably id == 0).
651 auto res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
652 materializedValues[val] = res;
// Unique-label branch: the final name is assigned after elaboration.
656 auto res = builder.create<LabelUniqueDeclOp>(loc, val.name, ValueRange());
657 materializedValues[val] = res;
// Cache of values already emitted, keyed by elaborated value.
667 DenseMap<ElaboratorValue, Value> materializedValues;
// Ops marked by deleteOpsUntil; they are processed in reverse order in finalize.
673 SmallVector<Operation *> toDelete;
/// Outcome of visiting an op. `Keep` ops are handed to the Materializer by
/// Elaborator::elaborate. `Delete` ops are not materialized; their effect is
/// kept only in the elaborator's value state.
682enum class DeletionKind { Keep, Delete };
/// State shared by all Elaborator instances in one elaborateModule run: the
/// symbol table, the random engine (seeded with `seed`), interned storages,
/// pending sequence elaboration requests, and id counters.
685struct ElaboratorSharedState {
686 ElaboratorSharedState(SymbolTable &table,
unsigned seed)
687 : table(table), rng(seed) {}
// Owns the interned set/bag/sequence storages that ElaboratorValues point to.
693 Internalizer internalizer;
// Sequences whose elaboration was requested while materializing closures.
697 std::queue<SequenceStorage *> worklist;
// Next virtual register id; ids are handed out starting from 0.
699 uint64_t virtualRegisterID = 0;
// Next unique label id. It starts at 1, presumably so unique labels never share
// id 0 with plain labels.
700 uint64_t uniqueLabelID = 1;
/// Interprets ops at elaboration time. It tracks an ElaboratorValue for every
/// SSA value in `state`. Each visitOp returns whether the op is kept (and then
/// materialized) or deleted.
704class Elaborator :
public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
707 using RTGBase::visitOp;
709 Elaborator(ElaboratorSharedState &sharedState, Materializer &materializer)
710 : sharedState(sharedState), materializer(materializer) {}
/// Return the elaborated value of `val` as alternative ValueTy. `val` must
/// already be in `state`, and std::get fails if a different alternative is
/// held.
712 template <
typename ValueTy>
713 inline ValueTy
get(Value val)
const {
714 return std::get<ValueTy>(state.at(val));
/// Elaborate a constant-like op by folding it:
/// - index-typed integers become `size_t`,
/// - signless i1 integers become `bool`,
/// - any other typed attribute is stored as is.
/// The op itself is not materialized (Delete).
717 FailureOr<DeletionKind> visitConstantLike(Operation *op) {
718 assert(op->hasTrait<OpTrait::ConstantLike>() &&
719 "op is expected to be constant-like");
721 SmallVector<OpFoldResult, 1> result;
// NOTE(review): `foldResult` is only read by the assert below. Unless a line not
// visible here marks it as used, NDEBUG builds warn about an unused variable.
722 auto foldResult = op->fold(result);
724 assert(succeeded(foldResult) &&
725 "constant folder of a constant-like must always succeed");
// Assumes the fold produced an Attribute rather than a Value. A result that is
// not a TypedAttr is rejected with the error below.
726 auto attr = dyn_cast<TypedAttr>(result[0].dyn_cast<Attribute>());
728 return op->emitError(
729 "only typed attributes supported for constant-like operations");
// Map integer constants onto the native ElaboratorValue alternatives.
731 auto intAttr = dyn_cast<IntegerAttr>(attr);
732 if (intAttr && isa<IndexType>(attr.getType()))
733 state[op->getResult(0)] = size_t(intAttr.getInt());
734 else if (intAttr && intAttr.getType().isSignlessInteger(1))
735 state[op->getResult(0)] = bool(intAttr.getInt());
737 state[op->getResult(0)] = attr;
739 return DeletionKind::Delete;
// Fallback visitor hook (the enclosing signature is not visible): ops without a
// dedicated handler are rejected.
744 return op->emitOpError(
"elaboration not supported");
// Fallback visitor hook (the enclosing signature is not visible): constant-like
// ops are folded; the remaining visible path keeps the op.
748 if (op->hasTrait<OpTrait::ConstantLike>())
749 return visitConstantLike(op);
755 return DeletionKind::Keep;
/// Track a sequence closure as interned SequenceStorage holding a freshly
/// generated name, the family name, and the elaborated arguments.
/// Because a new name is generated on every visit and `name` takes part in
/// SequenceStorage::isEqual, closures with identical family and arguments
/// still intern to distinct storages.
760 FailureOr<DeletionKind> visitOp(SequenceClosureOp op) {
761 SmallVector<ElaboratorValue> args;
762 for (
auto arg : op.getArgs())
763 args.push_back(state.at(arg));
765 auto familyName = op.getSequenceAttr();
766 auto name = sharedState.names.newName(familyName.getValue());
767 state[op.getResult()] =
768 sharedState.internalizer.internalize<SequenceStorage>(name, familyName,
770 return DeletionKind::Delete;
/// Sequence invocations are kept; they are inlined after elaboration.
773 FailureOr<DeletionKind> visitOp(InvokeSequenceOp op) {
774 return DeletionKind::Keep;
/// SetCreateOp: collect the elaborated element values (duplicates are dropped
/// by SetVector, first occurrence wins) and intern the resulting set.
777 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
778 SetVector<ElaboratorValue> set;
779 for (
auto val : op.getElements())
780 set.insert(state.at(val));
782 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
783 std::move(set), op.getSet().getType());
784 return DeletionKind::Delete;
/// SetSelectRandomOp: pick one element of the set at random. If the integer
/// attribute `rtg.elaboration_custom_seed` is present, a fresh engine seeded
/// from it is used.
// NOTE(review): the index computation is on lines not visible here. If it does
// not reject an empty set, `set.size() - 1` wraps and `set[selected]` reads out
// of bounds.
787 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
// Bind by const reference. Interned storage is immutable and arena-allocated
// for the whole run, so copying the SetVector is unnecessary.
788 const auto &set = get<SetStorage *>(op.getSet())->set;
792 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
793 std::mt19937 customRng(intAttr.getInt());
799 state[op.getResult()] = set[selected];
800 return DeletionKind::Delete;
/// SetDifferenceOp: elements of `original` that are not in `diff`, keeping
/// the order of `original`.
803 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
804 const auto &original = get<SetStorage *>(op.getOriginal())->set;
805 const auto &diff = get<SetStorage *>(op.getDiff())->set;
807 SetVector<ElaboratorValue> result(original);
808 result.set_subtract(diff);
810 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
811 std::move(result), op.getResult().getType());
812 return DeletionKind::Delete;
/// SetUnionOp: union of all operand sets, in operand order.
815 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
816 SetVector<ElaboratorValue> result;
817 for (
auto set : op.getSets())
818 result.set_union(
get<SetStorage *>(set)->set);
820 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
821 std::move(result), op.getType());
822 return DeletionKind::Delete;
/// SetSizeOp: the number of elements, as an index value.
825 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
826 auto size = get<SetStorage *>(op.getSet())->set.size();
827 state[op.getResult()] = size;
828 return DeletionKind::Delete;
/// BagCreateOp: accumulate multiplicities per elaborated element (repeated
/// elements add up) and intern the resulting bag.
831 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
832 MapVector<ElaboratorValue, uint64_t> bag;
833 for (
auto [val, multiple] :
834 llvm::zip(op.getElements(), op.getMultiples())) {
838 bag[state.at(val)] += get<size_t>(multiple);
841 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
842 std::move(bag), op.getType());
843 return DeletionKind::Delete;
/// BagSelectRandomOp: pick an element with probability proportional to its
/// multiplicity, using a prefix sum over the weights and upper_bound.
846 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
// Interned storage is immutable and lives for the whole run; do not copy it.
847 const auto &bag = get<BagStorage *>(op.getBag())->bag;
849 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
850 prefixSum.reserve(bag.size());
// NOTE(review): weights are uint64_t but are accumulated in uint32_t, so large
// or many weights truncate/wrap. An empty bag yields accumulator == 0; unless
// lines not visible here reject it, the draw range wraps and `iter` is end().
851 uint32_t accumulator = 0;
852 for (
auto [val, weight] : bag) {
853 accumulator += weight;
854 prefixSum.push_back({val, accumulator});
// Work on a copy of the shared engine; it is replaced by a freshly seeded
// engine when a custom seed is given.
857 auto customRng = sharedState.rng;
859 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
860 customRng = std::mt19937(intAttr.getInt());
864 auto *iter = llvm::upper_bound(
866 [](uint32_t a,
const std::pair<ElaboratorValue, uint32_t> &b) {
// `customRng` is a copy, so drawing from it did not advance the shared engine.
// Without a custom seed, write the advanced state back so consecutive
// selections do not repeat the same draw.
if (!op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed"))
  sharedState.rng = customRng;
870 state[op.getResult()] = iter->first;
871 return DeletionKind::Delete;
/// BagDifferenceOp: multiplicities of `original` reduced by those in `diff`.
/// Elements whose multiplicity drops to zero or below are omitted.
874 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
875 const auto &original = get<BagStorage *>(op.getOriginal())->bag;
876 const auto &diff = get<BagStorage *>(op.getDiff())->bag;
878 MapVector<ElaboratorValue, uint64_t> result;
879 for (
const auto &el : original) {
880 if (!diff.contains(el.first)) {
888 auto toDiff = diff.lookup(el.first);
889 if (el.second <= toDiff)
892 result.insert({el.first, el.second - toDiff});
895 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
896 std::move(result), op.getType());
897 return DeletionKind::Delete;
/// BagUnionOp: sum of multiplicities across all operand bags.
900 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
901 MapVector<ElaboratorValue, uint64_t> result;
902 for (
auto bag : op.getBags()) {
903 const auto &val = get<BagStorage *>(bag)->bag;
904 for (
auto [el, multiple] : val)
905 result[el] += multiple;
908 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
909 std::move(result), op.getType());
910 return DeletionKind::Delete;
/// BagUniqueSizeOp: the number of distinct elements, as an index value.
913 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
914 auto size = get<BagStorage *>(op.getBag())->bag.size();
915 state[op.getResult()] = size;
916 return DeletionKind::Delete;
/// Fixed registers are handled like any other constant-like op (folded into an
/// attribute value).
919 FailureOr<DeletionKind> visitOp(FixedRegisterOp op) {
920 return visitConstantLike(op);
/// Each virtual register op gets a fresh id from the shared counter and keeps
/// its allowed-registers attribute.
923 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
924 state[op.getResult()] = VirtualRegister(sharedState.virtualRegisterID++,
925 op.getAllowedRegsAttr());
926 return DeletionKind::Delete;
/// Replace each `{{i}}` placeholder in `formatString` with the decimal value of
/// the i-th substitute. Each substitute must hold an index (`size_t`) value.
/// An empty substitute list or an empty format string takes the early-exit
/// path; its return statement is not visible.
929 StringAttr substituteFormatString(StringAttr formatString,
930 ValueRange substitutes)
const {
931 if (substitutes.empty() || formatString.empty())
934 auto original = formatString.getValue().str();
935 for (
auto [i, subst] :
llvm::enumerate(substitutes)) {
937 std::string from =
"{{" + std::to_string(i) +
"}}";
// NOTE(review): the std::to_string conversion below is invariant for a given
// substitute and could be hoisted out of this loop.
938 while ((startPos = original.find(from, startPos)) != std::string::npos) {
939 auto substString = std::to_string(get<size_t>(subst));
940 original.replace(startPos, from.length(), substString);
944 return StringAttr::get(formatString.getContext(), original);
/// Plain label declaration: substitute the format string and record the name in
/// `labelNames` (presumably so generated unique label names avoid it). Track it
/// as a LabelValue with id 0.
947 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
949 substituteFormatString(op.getFormatStringAttr(), op.getArgs());
950 sharedState.labelNames.add(substituted.getValue());
951 state[op.getLabel()] = LabelValue(substituted);
952 return DeletionKind::Delete;
/// Unique label declaration: the substituted name plus a fresh id from
/// uniqueLabelID. Final names are assigned after elaboration.
955 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
956 state[op.getLabel()] = LabelValue(
957 substituteFormatString(op.getFormatStringAttr(), op.getArgs()),
958 sharedState.uniqueLabelID++);
959 return DeletionKind::Delete;
/// Label placements are kept as-is.
962 FailureOr<DeletionKind> visitOp(LabelOp op) {
return DeletionKind::Keep; }
/// scf.if: elaborate only the branch selected by the elaborated condition (an
/// empty selected region leaves nothing to do). Results are bound to the values
/// that branch yields.
964 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
965 bool cond = get<bool>(op.getCondition());
966 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
967 if (toElaborate.empty())
968 return DeletionKind::Delete;
974 if (failed(elaborate(toElaborate)))
978 for (
auto [res, out] :
979 llvm::zip(op.getResults(),
980 toElaborate.front().getTerminator()->getOperands()))
981 state[res] = state.at(out);
983 return DeletionKind::Delete;
/// scf.for: fully unroll at elaboration time. Bounds and step must elaborate to
/// index values. The body is elaborated once per iteration, and iteration
/// arguments are threaded through the yield. With zero iterations, the results
/// are the init args.
// NOTE(review): unless lines not visible here guard against it, a zero step
// loops forever, and `i += step` may wrap for bounds near SIZE_MAX.
986 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
987 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
988 std::holds_alternative<size_t>(state.at(op.getStep())) &&
989 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
990 return op->emitOpError(
"can only elaborate index type iterator");
992 auto lowerBound = get<size_t>(op.getLowerBound());
993 auto step = get<size_t>(op.getStep());
994 auto upperBound = get<size_t>(op.getUpperBound());
// Seed the first iteration.
1000 state[op.getInductionVar()] = lowerBound;
1001 for (
auto [iterArg, initArg] :
1002 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1003 state[iterArg] = state.at(initArg);
1006 for (
size_t i = lowerBound; i < upperBound; i += step) {
1007 if (failed(elaborate(op.getBodyRegion())))
// Prepare the next iteration from this iteration's yielded values.
1012 state[op.getInductionVar()] = i + step;
1013 for (
auto [iterArg, prevIterArg] :
1014 llvm::zip(op.getRegionIterArgs(),
1015 op.getBody()->getTerminator()->getOperands()))
1016 state[iterArg] = state.at(prevIterArg);
1020 for (
auto [res, iterArg] :
1021 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1022 state[res] = state.at(iterArg);
1024 return DeletionKind::Delete;
/// scf.yield: its operands are read by the parent op's handler above.
1027 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1028 return DeletionKind::Delete;
/// index.add on elaborated index values.
// NOTE(review): this uses host `size_t` arithmetic (wraps modulo 2^N of size_t),
// which may differ from the target's index width.
1031 FailureOr<DeletionKind> visitOp(index::AddOp op) {
1032 size_t lhs = get<size_t>(op.getLhs());
1033 size_t rhs = get<size_t>(op.getRhs());
1034 state[op.getResult()] = lhs + rhs;
1035 return DeletionKind::Delete;
/// index.cmp on elaborated index values. Equality and unsigned predicates are
/// evaluated directly. Other predicates (e.g. signed ones) presumably reach the
/// "elaboration not supported" error; the default label is not visible.
1038 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
1039 size_t lhs = get<size_t>(op.getLhs());
1040 size_t rhs = get<size_t>(op.getRhs());
1042 switch (op.getPred()) {
1043 case index::IndexCmpPredicate::EQ:
1044 result = lhs == rhs;
1046 case index::IndexCmpPredicate::NE:
1047 result = lhs != rhs;
1049 case index::IndexCmpPredicate::ULT:
1052 case index::IndexCmpPredicate::ULE:
1053 result = lhs <= rhs;
1055 case index::IndexCmpPredicate::UGT:
1058 case index::IndexCmpPredicate::UGE:
1059 result = lhs >= rhs;
1062 return op->emitOpError(
"elaboration not supported");
1064 state[op.getResult()] = result;
1065 return DeletionKind::Delete;
// dispatchOpVisitor (excerpt): route the index and scf ops listed below (list
// partly visible) to this class's visitOp overloads. Everything else goes to the
// RTG visitor base.
1069 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
1072 index::AddOp, index::CmpOp,
1074 scf::IfOp, scf::ForOp, scf::YieldOp>(
1075 [&](
auto op) {
return visitOp(op); })
1076 .Default([&](Operation *op) {
return RTGBase::dispatchOpVisitor(op); });
/// Elaborate the single block of `region`.
/// - Block arguments are bound to `regionArguments`; llvm::zip pairs them up
///   and ignores any surplus on either side.
/// - Ops the visitor keeps are handed to the materializer; the others live only
///   in `state`.
/// - Regions with more than one block are rejected.
1080 LogicalResult elaborate(Region &region,
1081 ArrayRef<ElaboratorValue> regionArguments = {}) {
1082 if (region.getBlocks().size() > 1)
1083 return region.getParentOp()->emitOpError(
1084 "regions with more than one block are not supported");
// Bind block arguments to the elaborated values supplied by the caller.
1086 for (
auto [arg, elabArg] :
1087 llvm::zip(region.getArguments(), regionArguments))
1088 state[arg] = elabArg;
// NOTE(review): front() requires a non-empty region. The check above only
// rejects more than one block; the scf.if caller checks emptiness first, so
// confirm this for the other callers.
1090 Block *block = &region.front();
1091 for (
auto &op : *block) {
// Ops the visitor decided to keep are materialized into the output.
1096 if (*result == DeletionKind::Keep)
1097 if (failed(materializer.materialize(&op, state, sharedState.worklist)))
1101 llvm::dbgs() <<
"Elaborated " << op <<
" to\n[";
1103 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](
auto res) {
1104 if (state.contains(res))
1105 llvm::dbgs() << state.at(res);
1107 llvm::dbgs() <<
"unknown";
1110 llvm::dbgs() <<
"]\n\n";
// Cross-elaborator state (symbol table, RNG, interned storage, counters, ...).
1119 ElaboratorSharedState &sharedState;
// Materializer for the test or sequence body currently being produced.
1123 Materializer &materializer;
// Elaborated value of each SSA value seen so far.
1126 DenseMap<Value, ElaboratorValue> state;
/// The RTG elaboration pass. Its generated base class comes from
/// RTGPasses.h.inc.
1135struct ElaborationPass
1136 :
public rtg::impl::ElaborationPassBase<ElaborationPass> {
1139 void runOnOperation()
override;
1140 void cloneTargetsIntoTests(SymbolTable &table);
1141 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
1142 LogicalResult inlineSequences(TestOp testOp, SymbolTable &table);
/// Pass entry point. It first specializes tests for each matching target, then
/// elaborates the whole module. An elaboration failure fails the pass.
1146void ElaborationPass::runOnOperation() {
1147 auto moduleOp = getOperation();
1148 SymbolTable table(moduleOp);
1150 cloneTargetsIntoTests(table);
1152 if (failed(elaborateModule(moduleOp, table)))
1153 return signalPassFailure();
/// For each target and each test whose non-empty target type equals the
/// target's, create a copy of the test:
/// - named `<test>_<target>` (SymbolTable::insert may rename it on conflict),
/// - with the target body's ops cloned at its start,
/// - with the test's block arguments replaced by the values the target yields,
/// - with its target type cleared.
/// Tests with a non-empty target type are then processed by the final loop
/// (its body is not visible; presumably they are erased).
1156void ElaborationPass::cloneTargetsIntoTests(SymbolTable &table) {
1157 auto moduleOp = getOperation();
1158 for (
auto target :
llvm::make_early_inc_range(moduleOp.getOps<TargetOp>())) {
1159 for (
auto test : moduleOp.getOps<TestOp>()) {
// Tests without target requirements, or requiring a different target, are
// skipped. Copies get an empty target type, so they are skipped too if the
// loop reaches them.
1161 if (test.getTarget().getEntries().empty())
1166 if (target.getTarget() != test.getTarget())
1169 IRRewriter rewriter(test);
1171 auto newTest = cast<TestOp>(test->clone());
1172 newTest.setSymName(test.getSymName().str() +
"_" +
1173 target.getSymName().str());
1174 table.insert(newTest, rewriter.getInsertionPoint());
// Clone the target's body (without its terminator) into the new test.
1178 rewriter.setInsertionPointToStart(newTest.getBody());
1179 for (
auto &op : target.getBody()->without_terminator())
1180 rewriter.clone(op, mapping);
// Replace each test block argument with the clone of the value the target
// yields at the same position.
// NOTE(review): mapping.lookup requires every yielded value to have been
// cloned above.
1182 for (
auto [returnVal, result] :
1183 llvm::zip(target.getBody()->getTerminator()->getOperands(),
1184 newTest.getBody()->getArguments()))
1185 result.replaceAllUsesWith(mapping.lookup(returnVal));
1187 newTest.getBody()->eraseArguments(0,
1188 newTest.getBody()->getNumArguments());
1189 newTest.setTarget(DictType::get(&getContext(), {}));
1196 for (
auto test :
llvm::make_early_inc_range(moduleOp.getOps<TestOp>()))
1197 if (!test.getTarget().getEntries().
empty())
/// Elaborate every test, then every requested sequence (worklist), then inline
/// the sequences into the tests. Finally, unique labels are given concrete
/// names and the sequence ops are processed (loop body not visible).
1201LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
1202 SymbolTable &table) {
1203 ElaboratorSharedState state(table, seed);
// Register the module's existing names (presumably so generated names do not
// collide with them).
1206 state.names.add(moduleOp);
// Elaborate each test body in place.
1210 for (
auto testOp : moduleOp.getOps<TestOp>()) {
1211 LLVM_DEBUG(llvm::dbgs()
1212 <<
"\n=== Elaborating test @" << testOp.getSymName() <<
"\n\n");
1213 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()));
1214 Elaborator elaborator(state, materializer);
1215 if (failed(elaborator.elaborate(testOp.getBodyRegion())))
1218 materializer.finalize();
// Generate a concrete sequence for each closure requested during
// materialization. Elaborating a sequence may enqueue further requests.
1223 while (!state.worklist.empty()) {
1224 auto *curr = state.worklist.front();
1225 state.worklist.pop();
// A sequence with this generated name already exists (the skip statement is
// not visible).
1227 if (table.lookup<SequenceOp>(curr->name))
// NOTE(review): no null check on familyOp is visible here.
1230 auto familyOp = table.lookup<SequenceOp>(curr->familyName);
// Create an empty copy of the family op under the generated name, placed
// before the family op.
1233 OpBuilder builder(familyOp);
1234 auto seqOp = builder.cloneWithoutRegions(familyOp);
1235 seqOp.getBodyRegion().emplaceBlock();
1236 seqOp.setSymName(curr->name);
1237 table.insert(seqOp);
1238 assert(seqOp.getSymName() == curr->name &&
"should not have been renamed");
1240 LLVM_DEBUG(llvm::dbgs()
1241 <<
"\n=== Elaborating sequence family @" << familyOp.getSymName()
1242 <<
" into @" << seqOp.getSymName() <<
"\n\n");
// Elaborate the family body with the closure's arguments into the new op.
1244 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()));
1245 Elaborator elaborator(state, materializer);
1246 if (failed(elaborator.elaborate(familyOp.getBodyRegion(), curr->args)))
1249 materializer.finalize();
// Inline sequences into each test, then turn unique label declarations into
// plain ones with fresh names.
1252 for (
auto testOp : moduleOp.getOps<TestOp>()) {
1254 if (failed(inlineSequences(testOp, table)))
1260 llvm::make_early_inc_range(testOp.getOps<LabelUniqueDeclOp>())) {
1261 IRRewriter rewriter(labelOp);
1262 auto newName = state.labelNames.newName(labelOp.getFormatString());
1263 rewriter.replaceOpWithNewOp<LabelDeclOp>(labelOp, newName, ValueRange());
// Post-process all sequence ops; they have been inlined (loop body not visible).
1269 for (
auto seqOp :
llvm::make_early_inc_range(moduleOp.getOps<SequenceOp>()))
/// Inline every sequence invoked from `testOp`. Each invoke_sequence operand
/// must be defined directly by a sequence_closure op. The referenced
/// sequence's body is cloned right after the invoke; the remaining bookkeeping
/// is not visible. The closure op is erased once it has no uses.
1275LogicalResult ElaborationPass::inlineSequences(TestOp testOp,
1276 SymbolTable &table) {
1277 OpBuilder builder(testOp);
1278 for (
auto iter = testOp.getBody()->begin();
1279 iter != testOp.getBody()->end();) {
1280 auto invokeOp = dyn_cast<InvokeSequenceOp>(&*iter);
1287 invokeOp.getSequence().getDefiningOp<SequenceClosureOp>();
1289 return invokeOp->emitError(
1290 "sequence operand not directly defined by sequence_closure op");
// NOTE(review): no null check on seqOp is visible here.
1292 auto seqOp = table.lookup<SequenceOp>(seqClosureOp.getSequenceAttr());
1294 builder.setInsertionPointAfter(invokeOp);
1296 for (
auto &op : *seqOp.getBody())
1297 builder.clone(op, mapping);
1301 if (seqClosureOp->use_empty())
1302 seqClosureOp->erase();
assert(baseType &&"element must be base type")
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the specified range.
static void print(TypedAttr val, llvm::raw_ostream &os)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
This helps visit TypeOp nodes.
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
ResultType dispatchOpVisitor(Operation *op, ExtraArgs... args)
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
static llvm::hash_code hash_value(const ModulePort &port)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static LabelValue getEmptyKey()
static LabelValue getTombstoneKey()
static VirtualRegister getEmptyKey()
static bool isEqual(const VirtualRegister &lhs, const VirtualRegister &rhs)
static unsigned getHashValue(const VirtualRegister &val)
static VirtualRegister getTombstoneKey()
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)
static unsigned getEmptyKey()