20#include "mlir/Dialect/Index/IR/IndexDialect.h"
21#include "mlir/Dialect/Index/IR/IndexOps.h"
22#include "mlir/IR/IRMapping.h"
23#include "mlir/IR/PatternMatch.h"
24#include "llvm/Support/Debug.h"
30#define GEN_PASS_DEF_ELABORATIONPASS
31#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
40#define DEBUG_TYPE "rtg-elaboration"
// Fragment, presumably of computeMask(size_t w): the signature and the
// definition of `w0` are not part of this excerpt.
// Number of 32-bit words needed to hold `w` bits (rounded up).
52 size_t n = w / 32 + (w % 32 != 0);
// Mask with the low `w0` bits set. The `w0 > 0` guard keeps the shift amount
// below 32.
54 return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
// Fragment, presumably of getUniformlyInRange(rng, a, b): the signature and
// several interior lines are not part of this excerpt.
// Number of values in the inclusive range [a, b]; wraps to 0 when the range
// covers every uint32_t value.
59 const uint32_t diff = b - a + 1;
63 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
// Index of the highest set bit of `diff`.
67 uint32_t width = digits - llvm::countl_zero(diff) - 1;
// Tests whether `diff` has any bit set below its highest set bit, i.e. whether
// it is not a power of two.
// NOTE(review): if `diff == 1` can reach this line, `width` is 0 and the shift
// amount equals `digits`, which is undefined — confirm that a guard in the
// lines not shown here returns early for that case.
68 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
// Abstract base class for values tracked by the elaborator. `kind` tags the
// concrete subclass; hashing, equality and printing are pure virtual.
87struct ElaboratorValue {
// One enumerator per concrete subclass.
89 enum class ValueKind { Attribute, Set, Bag, Sequence, Index, Bool };
91 ElaboratorValue(ValueKind kind) : kind(kind) {}
92 virtual ~ElaboratorValue() {}
// Hash of the value.
94 virtual llvm::hash_code getHashValue()
const = 0;
// Compares this value with `other`.
95 virtual bool isEqual(
const ElaboratorValue &other)
const = 0;
// Writes a textual form of the value to `os`.
98 virtual void print(llvm::raw_ostream &os)
const = 0;
101 ValueKind getKind()
const {
return kind; }
// Fixed at construction.
104 const ValueKind kind;
// Elaborator value wrapping a typed attribute.
112class AttributeValue :
public ElaboratorValue {
114 AttributeValue(TypedAttr attr)
115 : ElaboratorValue(ValueKind::Attribute), attr(attr) {
// A null attribute is rejected at construction (debug builds only).
116 assert(attr &&
"null attributes not allowed");
// Kind-tag check used by isa/cast/dyn_cast.
120 static bool classof(
const ElaboratorValue *val) {
121 return val->getKind() == ValueKind::Attribute;
124 llvm::hash_code getHashValue()
const override {
125 return llvm::hash_combine(attr);
128 bool isEqual(
const ElaboratorValue &other)
const override {
// `attrValue` is null when `other` is not an AttributeValue; presumably a
// null check sits in the lines not shown between here and the comparison.
129 auto *attrValue = dyn_cast<AttributeValue>(&other);
133 return attr == attrValue->attr;
// Prints "<attr ATTR at ADDRESS>".
137 void print(llvm::raw_ostream &os)
const override {
138 os <<
"<attr " << attr <<
" at " <<
this <<
">";
142 TypedAttr getAttr()
const {
return attr; }
145 const TypedAttr attr;
// Elaborator value holding an index as a size_t.
149class IndexValue :
public ElaboratorValue {
151 IndexValue(
size_t index) : ElaboratorValue(ValueKind::Index), index(index) {}
// Kind-tag check used by isa/cast/dyn_cast.
154 static bool classof(
const ElaboratorValue *val) {
155 return val->getKind() == ValueKind::Index;
158 llvm::hash_code getHashValue()
const override {
159 return llvm::hash_value(index);
162 bool isEqual(
const ElaboratorValue &other)
const override {
// `indexValue` is null when `other` is not an IndexValue; presumably a null
// check sits in the lines not shown between here and the comparison.
163 auto *indexValue = dyn_cast<IndexValue>(&other);
167 return index == indexValue->index;
// Prints "<index N at ADDRESS>".
171 void print(llvm::raw_ostream &os)
const override {
172 os <<
"<index " << index <<
" at " <<
this <<
">";
176 size_t getIndex()
const {
return index; }
// Elaborator value holding a boolean.
183class BoolValue :
public ElaboratorValue {
185 BoolValue(
bool value) : ElaboratorValue(ValueKind::Bool), value(value) {}
// Kind-tag check used by isa/cast/dyn_cast.
188 static bool classof(
const ElaboratorValue *val) {
189 return val->getKind() == ValueKind::Bool;
192 llvm::hash_code getHashValue()
const override {
193 return llvm::hash_value(value);
196 bool isEqual(
const ElaboratorValue &other)
const override {
// `val` is null when `other` is not a BoolValue; presumably a null check sits
// in the lines not shown between here and the comparison.
197 auto *val = dyn_cast<BoolValue>(&other);
201 return value == val->value;
// Prints "<bool true|false at ADDRESS>".
205 void print(llvm::raw_ostream &os)
const override {
206 os <<
"<bool " << (value ?
"true" :
"false") <<
" at " <<
this <<
">";
210 bool getBool()
const {
return value; }
// Elaborator value holding an insertion-ordered set of elaborator values
// together with the set's IR type. The hash is computed once at construction.
217class SetValue :
public ElaboratorValue {
219 SetValue(SetVector<ElaboratorValue *> &&set, Type type)
220 : ElaboratorValue(ValueKind::Set), set(std::move(set)), type(type),
// Hash the member, not the constructor parameter: inside the initializer list
// the unqualified name `set` refers to the parameter, which has already been
// moved from by the `set(std::move(set))` initializer above. The member is
// declared before `cachedHash`, so it is initialized by the time it is read.
222 llvm::hash_combine_range(this->set.begin(), this->set.
end()), type)) {}
// Kind-tag check used by isa/cast/dyn_cast.
225 static bool classof(
const ElaboratorValue *val) {
226 return val->getKind() == ValueKind::Set;
229 llvm::hash_code getHashValue()
const override {
return cachedHash; }
231 bool isEqual(
const ElaboratorValue &other)
const override {
// `otherSet` is null when `other` is not a SetValue; presumably a null check
// sits in the lines not shown before the first dereference.
232 auto *otherSet = dyn_cast<SetValue>(&other);
// Cheap pre-check on the cached hashes before the element-wise comparison.
236 if (cachedHash != otherSet->cachedHash)
240 return set == otherSet->set && type == otherSet->type;
// Prints the elements comma-separated, followed by "} at ADDRESS>".
244 void print(llvm::raw_ostream &os)
const override {
246 llvm::interleaveComma(set, os, [&](ElaboratorValue *el) { el->print(os); });
247 os <<
"} at " <<
this <<
">";
251 const SetVector<ElaboratorValue *> &getSet()
const {
return set; }
253 Type getType()
const {
return type; }
260 const SetVector<ElaboratorValue *> set;
267 const llvm::hash_code cachedHash;
// Elaborator value holding a bag (multiset): an insertion-ordered map from
// element to a uint64_t count, together with the bag's IR type. The hash is
// computed once at construction.
271class BagValue :
public ElaboratorValue {
273 BagValue(MapVector<ElaboratorValue *, uint64_t> &&bag, Type type)
274 : ElaboratorValue(ValueKind::Bag), bag(std::move(bag)), type(type),
// Hash the member, not the constructor parameter: inside the initializer list
// the unqualified name `bag` refers to the parameter, which has already been
// moved from by the `bag(std::move(bag))` initializer above. The member is
// declared before `cachedHash`, so it is initialized by the time it is read.
276 llvm::hash_combine_range(this->bag.begin(), this->bag.
end()), type)) {}
// Kind-tag check used by isa/cast/dyn_cast.
279 static bool classof(
const ElaboratorValue *val) {
280 return val->getKind() == ValueKind::Bag;
283 llvm::hash_code getHashValue()
const override {
return cachedHash; }
285 bool isEqual(
const ElaboratorValue &other)
const override {
// `otherBag` is null when `other` is not a BagValue; presumably a null check
// sits in the lines not shown before the first dereference.
286 auto *otherBag = dyn_cast<BagValue>(&other);
// Cheap pre-check on the cached hashes before the element-wise comparison.
290 if (cachedHash != otherBag->cachedHash)
// Entries are compared pairwise in iteration order.
293 return llvm::equal(bag, otherBag->bag) && type == otherBag->type;
// Prints each entry followed by " -> COUNT", then "} at ADDRESS>".
297 void print(llvm::raw_ostream &os)
const override {
299 llvm::interleaveComma(bag, os,
300 [&](std::pair<ElaboratorValue *, uint64_t> el) {
302 os <<
" -> " << el.second;
304 os <<
"} at " <<
this <<
">";
308 const MapVector<ElaboratorValue *, uint64_t> &getBag()
const {
return bag; }
310 Type getType()
const {
return type; }
314 const MapVector<ElaboratorValue *, uint64_t> bag;
321 const llvm::hash_code cachedHash;
// Elaborator value describing a sequence: its own name, the name of the
// family it derives from, and the argument values. The hash is computed once
// at construction.
325class SequenceValue :
public ElaboratorValue {
327 SequenceValue(StringRef name, StringAttr familyName,
328 SmallVector<ElaboratorValue *> &&args)
329 : ElaboratorValue(ValueKind::Sequence), name(name),
330 familyName(familyName), args(std::move(args)),
// `this->args` names the member; the parameter of the same name was moved from
// on the line above.
332 llvm::hash_combine_range(this->args.begin(), this->args.
end()),
333 name, familyName)) {}
// Kind-tag check used by isa/cast/dyn_cast.
336 static bool classof(
const ElaboratorValue *val) {
337 return val->getKind() == ValueKind::Sequence;
340 llvm::hash_code getHashValue()
const override {
return cachedHash; }
342 bool isEqual(
const ElaboratorValue &other)
const override {
// `otherSeq` is null when `other` is not a SequenceValue; presumably a null
// check sits in the lines not shown before the first dereference.
343 auto *otherSeq = dyn_cast<SequenceValue>(&other);
// Cheap pre-check on the cached hashes before the field-wise comparison.
347 if (cachedHash != otherSeq->cachedHash)
350 return name == otherSeq->name && familyName == otherSeq->familyName &&
351 args == otherSeq->args;
// Prints the sequence name, its family name and the arguments.
355 void print(llvm::raw_ostream &os)
const override {
356 os <<
"<sequence @" << name <<
" derived from @" << familyName.getValue()
358 llvm::interleaveComma(args, os,
359 [&](ElaboratorValue *val) { val->print(os); });
360 os <<
") at " <<
this <<
">";
364 StringRef
getName()
const {
return name; }
365 StringAttr getFamilyName()
const {
return familyName; }
366 ArrayRef<ElaboratorValue *> getArgs()
const {
return args; }
// NOTE(review): StringRef does not own its characters, so the storage behind
// `name` must outlive this object — confirm against the code that creates it.
369 const StringRef name;
370 const StringAttr familyName;
371 const SmallVector<ElaboratorValue *> args;
374 const llvm::hash_code cachedHash;
380 const ElaboratorValue &value) {
// Free-function hash hook: forwards to the value's virtual getHashValue().
391static llvm::hash_code
hash_value(
const ElaboratorValue &val) {
392 return val.getHashValue();
// DenseMapInfo for ElaboratorValue pointers that compares the pointed-to
// values through ElaboratorValue::isEqual instead of comparing addresses.
396struct InternMapInfo :
public DenseMapInfo<ElaboratorValue *> {
397 static unsigned getHashValue(
const ElaboratorValue *value) {
// The sentinel keys are not real objects and must never be hashed.
398 assert(value != getTombstoneKey() && value != getEmptyKey());
402 static bool isEqual(
const ElaboratorValue *lhs,
const ElaboratorValue *rhs) {
// Sentinel keys are handled first, before the virtual call below dereferences
// either pointer.
406 auto *tk = getTombstoneKey();
407 auto *ek = getEmptyKey();
408 if (lhs == tk || rhs == tk || lhs == ek || rhs == ek)
411 return lhs->isEqual(*rhs);
// Produces an IR Value for `val`. `materializedValues` is consulted first;
// otherwise the call is dispatched on the concrete value class to the
// matching visit() overload, with a builder placed at `insertionPoint`.
// `elabRequests` and `emitError` are forwarded to the overloads.
425 Value materialize(ElaboratorValue *val, Location loc,
426 std::queue<SequenceValue *> &elabRequests,
427 function_ref<InFlightDiagnostic()> emitError) {
// `block` is set by reset().
428 assert(block &&
"must call reset before calling this function");
430 auto iter = materializedValues.find(val);
431 if (iter != materializedValues.end())
434 LLVM_DEBUG(llvm::dbgs() <<
"Materializing " << *val <<
"\n\n");
436 OpBuilder builder(block, insertionPoint);
437 return TypeSwitch<ElaboratorValue *, Value>(val)
438 .Case<AttributeValue, IndexValue, BoolValue, SetValue, BagValue,
439 SequenceValue>([&](
auto val) {
440 return visit(val, builder, loc, elabRequests, emitError);
// Every ValueKind is listed in the Case above, so this branch is unexpected.
442 .Default([](
auto val) {
443 assert(
false &&
"all cases must be covered above");
// Drops both caches and moves the insertion point to the start of `block`.
// Several interior lines (including the return) are not part of this excerpt.
448 Materializer &reset(Block *block) {
449 materializedValues.clear();
450 integerValues.clear();
452 insertionPoint = block->begin();
// Materializes an attribute value. Index-typed integer attributes become an
// index::ConstantOp; any other attribute is handed to the materializeConstant
// hook of the attribute's dialect. The result is recorded in
// `materializedValues`.
457 Value visit(AttributeValue *val, OpBuilder &builder, Location loc,
458 std::queue<SequenceValue *> &elabRequests,
459 function_ref<InFlightDiagnostic()> emitError) {
460 auto attr = val->getAttr();
464 if (
auto intAttr = dyn_cast<IntegerAttr>(attr);
465 intAttr && isa<IndexType>(attr.getType())) {
466 Value res = builder.create<index::ConstantOp>(loc, intAttr);
467 materializedValues[val] = res;
473 auto *op = attr.getDialect().materializeConstant(builder, attr,
474 attr.getType(), loc);
// Diagnostic for a dialect hook that produced nothing; the guarding condition
// is in lines not shown here.
476 emitError() <<
"materializer of dialect '"
477 << attr.getDialect().getNamespace()
478 <<
"' unable to materialize value for attribute '" << attr
// Only result #0 of the materialized operation is used.
483 Value res = op->getResult(0);
484 materializedValues[val] = res;
// Materializes an index value as an index::ConstantOp and records it in
// `materializedValues`. `elabRequests` and `emitError` are not used in the
// lines shown.
488 Value visit(IndexValue *val, OpBuilder &builder, Location loc,
489 std::queue<SequenceValue *> &elabRequests,
490 function_ref<InFlightDiagnostic()> emitError) {
491 Value res = builder.create<index::ConstantOp>(loc, val->getIndex());
492 materializedValues[val] = res;
// Materializes a boolean value as an index::BoolConstantOp and records it in
// `materializedValues`. `elabRequests` and `emitError` are not used in the
// lines shown.
496 Value visit(BoolValue *val, OpBuilder &builder, Location loc,
497 std::queue<SequenceValue *> &elabRequests,
498 function_ref<InFlightDiagnostic()> emitError) {
499 Value res = builder.create<index::BoolConstantOp>(loc, val->getBool());
500 materializedValues[val] = res;
// Materializes a set: each element is materialized via materialize(), then a
// SetCreateOp of the set's type is built from the results and recorded in
// `materializedValues`.
504 Value visit(SetValue *val, OpBuilder &builder, Location loc,
505 std::queue<SequenceValue *> &elabRequests,
506 function_ref<InFlightDiagnostic()> emitError) {
507 SmallVector<Value> elements;
508 elements.reserve(val->getSet().size());
509 for (
auto *el : val->getSet()) {
// Handling of a failed element materialization is presumably in the lines not
// shown between these two statements.
510 auto materialized = materialize(el, loc, elabRequests, emitError);
514 elements.push_back(materialized);
517 auto res = builder.create<SetCreateOp>(loc, val->getType(), elements);
518 materializedValues[val] = res;
// Materializes a bag: each element is materialized via materialize(), each
// count becomes an index constant (shared through `integerValues`), and a
// BagCreateOp of the bag's type is built from both lists.
522 Value visit(BagValue *val, OpBuilder &builder, Location loc,
523 std::queue<SequenceValue *> &elabRequests,
524 function_ref<InFlightDiagnostic()> emitError) {
525 SmallVector<Value> values, weights;
526 values.reserve(val->getBag().size());
527 weights.reserve(val->getBag().size());
// NOTE(review): the structured binding `val` shadows the `val` parameter
// inside this loop; the uses of `val` after the loop refer to the parameter.
528 for (
auto [val, weight] : val->getBag()) {
529 auto materializedVal = materialize(val, loc, elabRequests, emitError);
530 if (!materializedVal)
// Reuse an already created constant for this count when there is one.
533 auto iter = integerValues.find(weight);
534 Value materializedWeight;
535 if (iter != integerValues.end()) {
536 materializedWeight = iter->second;
538 materializedWeight = builder.create<index::ConstantOp>(
539 loc, builder.getIndexAttr(weight));
540 integerValues[weight] = materializedWeight;
543 values.push_back(materializedVal);
544 weights.push_back(materializedWeight);
548 builder.create<BagCreateOp>(loc, val->getType(), values, weights);
549 materializedValues[val] = res;
// Materializes a sequence value: queues it on `elabRequests` and returns a
// SequenceClosureOp that refers to the sequence by name with no arguments.
// NOTE(review): in the lines shown, this overload does not record its result
// in `materializedValues`, unlike the other visit() overloads — confirm that
// is intended.
553 Value visit(SequenceValue *val, OpBuilder &builder, Location loc,
554 std::queue<SequenceValue *> &elabRequests,
555 function_ref<InFlightDiagnostic()> emitError) {
556 elabRequests.push(val);
557 return builder.create<SequenceClosureOp>(loc, val->getName(), ValueRange());
// Cache from elaborator value to the IR value created for it; cleared by
// reset().
566 DenseMap<ElaboratorValue *, Value> materializedValues;
// Cache of index constants keyed by their integer value; cleared by reset().
567 DenseMap<uint64_t, Value> integerValues;
// Position at which new operations are created; set by reset().
572 Block::iterator insertionPoint;
// Result of visiting an operation: keep the operation or delete it.
577enum class DeletionKind { Keep, Delete };
// Op visitor whose visit functions return a DeletionKind or failure.
580class Elaborator :
public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
// Keep the base-class overloads visible next to the ones declared here.
583 using RTGBase::visitOp;
584 using RTGBase::visitRegisterOp;
// Stores the symbol table and the random number generator passed by the
// caller.
586 Elaborator(SymbolTable &table, std::mt19937 &rng) : rng(rng), table(table) {}
// Builds a ValueTy from `args`, inserts it into `interned`, and binds `val` to
// the entry stored in the map. The definition of `e` is in a line not shown.
590 template <
typename ValueTy,
typename... Args>
591 void internalizeResult(Value val, Args &&...args) {
593 auto ptr = std::make_unique<ValueTy>(std::forward<Args>(args)...);
// If the key is already present, the insert leaves the existing entry in
// place and `iter` refers to it.
595 auto [iter, _] = interned.insert({e, std::move(ptr)});
596 state[val] = iter->second.get();
601 return op->emitOpError(
"elaboration not supported");
609 return DeletionKind::Keep;
// Elaborates a sequence closure: collects the elaborated arguments, asks
// `names` for a new name based on the family name, and binds the result to a
// SequenceValue. The operation is marked for deletion.
614 FailureOr<DeletionKind> visitOp(SequenceClosureOp op) {
615 SmallVector<ElaboratorValue *> args;
616 for (
auto arg : op.getArgs())
617 args.push_back(state.at(arg));
619 auto familyName = op.getSequenceAttr();
620 auto name = names.newName(familyName.getValue());
621 internalizeResult<SequenceValue>(op.getResult(), name, familyName,
623 return DeletionKind::Delete;
// Sequence invocations are left in place.
626 FailureOr<DeletionKind> visitOp(InvokeSequenceOp op) {
627 return DeletionKind::Keep;
// Elaborates a set creation: gathers the elaborated element values into a
// SetVector and binds the result to a SetValue of the result's type.
630 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
631 SetVector<ElaboratorValue *> set;
632 for (
auto val : op.getElements())
633 set.insert(state.at(val));
635 internalizeResult<SetValue>(op.getSet(), std::move(set),
636 op.getSet().getType());
637 return DeletionKind::Delete;
// Elaborates a random selection from a set and binds the result to the chosen
// element. The computation of `selected` is in lines not shown here.
640 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
641 auto *set = cast<SetValue>(state.at(op.getSet()));
// An "rtg.elaboration_custom_seed" integer attribute on the operation seeds a
// dedicated generator.
645 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
646 std::mt19937 customRng(intAttr.getInt());
652 state[op.getResult()] = set->getSet()[selected];
653 return DeletionKind::Delete;
// Elaborates a set difference: copies the original set, subtracts the elements
// of `diff`, and binds the result to a SetValue of the result's type.
656 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
// `auto` makes by-value copies of both sets here.
657 auto original = cast<SetValue>(state.at(op.getOriginal()))->getSet();
658 auto diff = cast<SetValue>(state.at(op.getDiff()))->getSet();
660 SetVector<ElaboratorValue *> result(original);
661 result.set_subtract(diff);
663 internalizeResult<SetValue>(op.getResult(), std::move(result),
664 op.getResult().getType());
665 return DeletionKind::Delete;
// Elaborates a set union: merges every operand set into one SetVector, in
// operand order, and binds the result to a SetValue.
668 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
669 SetVector<ElaboratorValue *> result;
670 for (
auto set : op.getSets())
671 result.set_union(cast<SetValue>(state.at(set))->getSet());
673 internalizeResult<SetValue>(op.getResult(), std::move(result),
675 return DeletionKind::Delete;
// Elaborates a set-size query: the element count is bound to the result as an
// index-typed IntegerAttr wrapped in an AttributeValue.
678 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
679 auto size = cast<SetValue>(state.at(op.getSet()))->getSet().size();
680 auto sizeAttr = IntegerAttr::get(IndexType::get(op->getContext()), size);
681 internalizeResult<AttributeValue>(op.getResult(), sizeAttr);
682 return DeletionKind::Delete;
// Elaborates a bag creation: pairs each element with its multiple and
// accumulates the counts per element.
685 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
686 MapVector<ElaboratorValue *, uint64_t> bag;
687 for (
auto [val, multiple] :
688 llvm::zip(op.getElements(), op.getMultiples())) {
689 auto *interpValue = state.at(val);
// The multiple must already be elaborated to an IndexValue; `cast` asserts
// otherwise.
693 auto *interpMultiple = cast<IndexValue>(state.at(multiple));
// `+=` sums the counts when the same element is listed more than once.
694 bag[interpValue] += interpMultiple->getIndex();
697 internalizeResult<BagValue>(op.getBag(), std::move(bag), op.getType());
698 return DeletionKind::Delete;
// Elaborates a weighted random selection from a bag. A running sum of the
// counts is built, then an upper_bound search picks the entry whose element is
// bound to the result. The draw itself is in lines not shown here.
701 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
702 auto *bag = cast<BagValue>(state.at(op.getBag()));
704 SmallVector<std::pair<ElaboratorValue *, uint32_t>> prefixSum;
705 prefixSum.reserve(bag->getBag().size());
// NOTE(review): the bag stores counts as uint64_t, but they are summed into a
// uint32_t here, so large counts are truncated and the sum can wrap.
706 uint32_t accumulator = 0;
707 for (
auto [val, weight] : bag->getBag()) {
708 accumulator += weight;
709 prefixSum.push_back({val, accumulator});
// NOTE(review): `auto` takes a by-value copy of the generator. If the draw in
// the lines not shown uses `customRng`, the member `rng` is never advanced by
// this operation — confirm whether that is intended.
712 auto customRng = rng;
// An "rtg.elaboration_custom_seed" integer attribute on the operation replaces
// the copy with a freshly seeded generator.
714 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
715 customRng = std::mt19937(intAttr.getInt());
719 auto *iter = llvm::upper_bound(
721 [](uint32_t a,
const std::pair<ElaboratorValue *, uint32_t> &b) {
724 state[op.getResult()] = iter->first;
725 return DeletionKind::Delete;
// Elaborates a bag difference: walks the original bag and subtracts the count
// found in `diff` for each element. Several interior lines are not part of
// this excerpt.
728 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
729 auto *original = cast<BagValue>(state.at(op.getOriginal()));
730 auto *diff = cast<BagValue>(state.at(op.getDiff()));
732 MapVector<ElaboratorValue *, uint64_t> result;
733 for (
const auto &el : original->getBag()) {
// Branch for elements that do not occur in `diff`.
734 if (!diff->getBag().contains(el.first)) {
742 auto toDiff = diff->getBag().lookup(el.first);
// Taken when `diff` holds at least as many copies as the original; it guards
// the unsigned subtraction below against wrapping.
743 if (el.second <= toDiff)
746 result.insert({el.first, el.second - toDiff});
749 internalizeResult<BagValue>(op.getResult(), std::move(result),
751 return DeletionKind::Delete;
// Elaborates a bag union: sums the counts of every element across all operand
// bags.
754 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
755 MapVector<ElaboratorValue *, uint64_t> result;
756 for (
auto bag : op.getBags()) {
757 auto *val = cast<BagValue>(state.at(bag));
758 for (
auto [el, multiple] : val->getBag())
759 result[el] += multiple;
762 internalizeResult<BagValue>(op.getResult(), std::move(result),
764 return DeletionKind::Delete;
// Elaborates a unique-size query on a bag: the number of map entries (distinct
// elements) is bound to the result as an index-typed IntegerAttr.
767 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
768 auto size = cast<BagValue>(state.at(op.getBag()))->getBag().size();
769 auto sizeAttr = IntegerAttr::get(IndexType::get(op->getContext()), size);
770 internalizeResult<AttributeValue>(op.getResult(), sizeAttr);
771 return DeletionKind::Delete;
// Elaborates an index addition on two IndexValue operands. The sum is computed
// in size_t, so it wraps on overflow.
774 FailureOr<DeletionKind> visitOp(index::AddOp op) {
775 size_t lhs = cast<IndexValue>(state.at(op.getLhs()))->getIndex();
776 size_t rhs = cast<IndexValue>(state.at(op.getRhs()))->getIndex();
777 internalizeResult<IndexValue>(op.getResult(), lhs + rhs);
778 return DeletionKind::Delete;
// Elaborates an index comparison on two IndexValue operands and binds the
// outcome to a BoolValue. The case bodies that compute `result` are in lines
// not shown here.
781 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
782 size_t lhs = cast<IndexValue>(state.at(op.getLhs()))->getIndex();
783 size_t rhs = cast<IndexValue>(state.at(op.getRhs()))->getIndex();
// Equality and the unsigned predicates have their own cases.
785 switch (op.getPred()) {
786 case index::IndexCmpPredicate::EQ:
789 case index::IndexCmpPredicate::NE:
792 case index::IndexCmpPredicate::ULT:
795 case index::IndexCmpPredicate::ULE:
798 case index::IndexCmpPredicate::UGT:
801 case index::IndexCmpPredicate::UGE:
// Presumably the fallback for the remaining predicates — the label is in a
// line not shown.
805 return op->emitOpError(
"elaboration not supported");
807 internalizeResult<BoolValue>(op.getResult(), result);
808 return DeletionKind::Delete;
// Body of the generic per-operation dispatch (presumably dispatchOpVisitor;
// the signature is in a line not shown). Constant-like operations are folded
// and bound to an elaborator value; index add/cmp go to their visitOp
// overloads; everything else goes to the base visitor.
812 if (op->hasTrait<OpTrait::ConstantLike>()) {
813 SmallVector<OpFoldResult, 1> result;
814 auto foldResult = op->fold(result);
816 assert(succeeded(foldResult) &&
817 "constant folder of a constant-like must always succeed");
// NOTE(review): if the fold result holds a Value, the inner dyn_cast yields a
// null Attribute and the outer `dyn_cast` is applied to null — confirm this
// cannot happen or use a null-tolerant cast.
818 auto attr = dyn_cast<TypedAttr>(result[0].dyn_cast<Attribute>());
820 return op->emitError(
821 "only typed attributes supported for constant-like operations");
// Index-typed integers become IndexValue, signless i1 integers become
// BoolValue, and any other attribute is kept as an AttributeValue.
823 auto intAttr = dyn_cast<IntegerAttr>(attr);
824 if (intAttr && isa<IndexType>(attr.getType()))
825 internalizeResult<IndexValue>(op->getResult(0), intAttr.getInt());
826 else if (intAttr && intAttr.getType().isSignlessInteger(1))
827 internalizeResult<BoolValue>(op->getResult(0), intAttr.getInt());
829 internalizeResult<AttributeValue>(op->getResult(0), attr);
831 return DeletionKind::Delete;
834 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
835 .Case<index::AddOp, index::CmpOp>([&](
auto op) {
return visitOp(op); })
836 .Default([&](Operation *op) {
return RTGBase::dispatchOpVisitor(op); });
// Elaborates the body of `family` into `dest`. The family's block arguments
// are bound to `args`, every operation is visited, and operations that are
// kept are cloned to the end of `dest` with their operands replaced by
// materialized values. Several interior lines are not part of this excerpt.
839 LogicalResult elaborate(SequenceOp family, SequenceOp dest,
840 ArrayRef<ElaboratorValue *> args) {
841 LLVM_DEBUG(llvm::dbgs() <<
"\n=== Elaborating " << family.getOperationName()
842 <<
" @" << family.getSymName() <<
" into @"
843 << dest.getSymName() <<
"\n\n");
// New constants are materialized into the destination body.
850 materializer.reset(dest.getBody());
853 for (
auto [arg, elabArg] :
854 llvm::zip(family.getBody()->getArguments(), args))
855 state[arg] = elabArg;
857 for (
auto &op : *family.getBody()) {
858 if (op.getNumRegions() != 0)
859 return op.emitOpError(
"nested regions not supported");
865 if (*result == DeletionKind::Keep) {
866 for (
auto &operand : op.getOpOperands()) {
// Operands that already have a mapping are checked for here.
867 if (mapping.contains(operand.get()))
// Diagnostic factory handed to the materializer; it attaches a note naming the
// operand being materialized.
870 auto emitError = [&]() {
871 auto diag = op.emitError();
872 diag.attachNote(op.getLoc())
873 <<
"while materializing value for operand#"
874 << operand.getOperandNumber();
877 Value val = materializer.materialize(
878 state.at(operand.get()), op.getLoc(), worklist, emitError);
882 mapping.map(operand.get(), val);
885 OpBuilder builder = OpBuilder::atBlockEnd(dest.getBody());
886 builder.clone(op, mapping);
// Debug output: the operation and the value each of its results elaborated to.
890 llvm::dbgs() <<
"Elaborating " << op <<
" to\n[";
892 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](
auto res) {
893 if (state.contains(res))
894 llvm::dbgs() << *state.at(res);
896 llvm::dbgs() <<
"unknown";
899 llvm::dbgs() <<
"]\n\n";
// Elaborates the body of `op` in place: every operation in the body is
// visited, operands of kept operations are materialized, and other operations
// are collected in `toDelete` and processed in reverse order at the end.
// Several interior lines are not part of this excerpt.
906 template <
typename OpTy>
907 LogicalResult elaborateInPlace(OpTy op) {
908 LLVM_DEBUG(llvm::dbgs()
909 <<
"\n=== Elaborating (in place) " << op.getOperationName()
910 <<
" @" << op.getSymName() <<
"\n\n");
917 materializer.reset(op.getBody());
919 SmallVector<Operation *> toDelete;
// The loop variable `op` shadows the function parameter `op` inside the loop.
920 for (
auto &op : *op.getBody()) {
921 if (op.getNumRegions() != 0)
922 return op.emitOpError(
"nested regions not supported");
928 if (*result == DeletionKind::Keep) {
929 for (
auto &operand : op.getOpOperands()) {
// Diagnostic factory handed to the materializer; it attaches a note naming the
// operand being materialized.
930 auto emitError = [&]() {
931 auto diag = op.emitError();
932 diag.attachNote(op.getLoc())
933 <<
"while materializing value for operand#"
934 << operand.getOperandNumber();
937 Value val = materializer.materialize(
938 state.at(operand.get()), op.getLoc(), worklist, emitError);
944 toDelete.push_back(&op);
// Debug output: the operation and the value each of its results elaborated to.
948 llvm::dbgs() <<
"Elaborating " << op <<
" to\n[";
950 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](
auto res) {
951 if (state.contains(res))
952 llvm::dbgs() << *state.at(res);
954 llvm::dbgs() <<
"unknown";
957 llvm::dbgs() <<
"]\n\n";
// Reverse order, presumably so that users are removed before the operations
// defining their operands.
961 for (
auto *op :
llvm::reverse(toDelete))
// Inlines invoked sequences into `testOp`: for each InvokeSequenceOp in the
// body, the operations of the referenced sequence are cloned after the
// invocation. Several interior lines are not part of this excerpt.
967 LogicalResult inlineSequences(TestOp testOp) {
968 OpBuilder builder(testOp);
// Manual iterator loop: the increment is in lines not shown here.
969 for (
auto iter = testOp.getBody()->begin();
970 iter != testOp.getBody()->end();) {
971 auto invokeOp = dyn_cast<InvokeSequenceOp>(&*iter);
// The sequence operand must be produced directly by a SequenceClosureOp.
978 invokeOp.getSequence().getDefiningOp<SequenceClosureOp>();
980 return invokeOp->emitError(
981 "sequence operand not directly defined by sequence_closure op");
983 auto seqOp = table.lookup<SequenceOp>(seqClosureOp.getSequenceAttr());
985 builder.setInsertionPointAfter(invokeOp);
987 for (
auto &op : *seqOp.getBody())
988 builder.clone(op, mapping);
// The closure is erased once nothing uses it any more.
992 if (seqClosureOp->use_empty())
993 seqClosureOp->erase();
// Drives elaboration of a module: elaborates every test in place, drains the
// worklist of requested sequences by cloning and elaborating their family,
// inlines the sequences into the tests, and finally walks all sequence ops.
// Several interior lines are not part of this excerpt.
999 LogicalResult elaborateModule(ModuleOp moduleOp) {
1002 names.add(moduleOp);
1006 for (
auto testOp : moduleOp.getOps<TestOp>())
1007 if (failed(elaborateInPlace(testOp)))
1012 while (!worklist.empty()) {
1013 auto *curr = worklist.front();
// A sequence with this name is already in the symbol table.
1016 if (table.lookup<SequenceOp>(curr->getName()))
1019 auto familyOp = table.lookup<SequenceOp>(curr->getFamilyName());
// Create an empty-bodied copy of the family op under the new name.
1022 OpBuilder builder(familyOp);
1023 auto seqOp = builder.cloneWithoutRegions(familyOp);
1024 seqOp.getBodyRegion().emplaceBlock();
1025 seqOp.setSymName(curr->getName());
1026 table.insert(seqOp);
1027 assert(seqOp.getSymName() == curr->getName() &&
1028 "should not have been renamed");
1030 if (failed(elaborate(familyOp, seqOp, curr->getArgs())))
1035 for (
auto testOp : moduleOp.getOps<TestOp>())
1036 if (failed(inlineSequences(testOp)))
// Early-increment range, so the current op can be removed while iterating; the
// loop body is in lines not shown here.
1041 for (
auto seqOp :
llvm::make_early_inc_range(moduleOp.getOps<SequenceOp>()))
// Sequences that still have to be elaborated; filled by the materializer.
1054 std::queue<SequenceValue *> worklist;
// Owning storage for elaborator values, keyed by pointer but compared by
// value through InternMapInfo. The member name is in a line not shown here.
1064 DenseMap<ElaboratorValue *, std::unique_ptr<ElaboratorValue>, InternMapInfo>
// Maps each IR value to the elaborator value it evaluated to.
1068 DenseMap<Value, ElaboratorValue *> state;
// Creates IR for elaborator values on demand.
1072 Materializer materializer;
// Pass class built on the generated ElaborationPassBase.
1081struct ElaborationPass
1082 :
public rtg::impl::ElaborationPassBase<ElaborationPass> {
1085 void runOnOperation()
override;
// Helper invoked by runOnOperation() before elaboration starts.
1086 void cloneTargetsIntoTests(SymbolTable &table);
// Pass entry point: clones targets into the matching tests, then elaborates
// the module with a generator seeded from `seed`. The pass is marked as failed
// when elaboration fails.
1090void ElaborationPass::runOnOperation() {
1091 auto moduleOp = getOperation();
1092 SymbolTable table(moduleOp);
1094 cloneTargetsIntoTests(table);
1096 std::mt19937 rng(seed);
1097 Elaborator elaborator(table, rng);
1098 if (failed(elaborator.elaborateModule(moduleOp)))
1099 return signalPassFailure();
// For each target/test pair with matching target types, clones the test under
// the name "<test>_<target>", copies the target's body into the start of the
// clone, and replaces the clone's block arguments with the values the target
// yields. Several interior lines are not part of this excerpt.
1102void ElaborationPass::cloneTargetsIntoTests(SymbolTable &table) {
1103 auto moduleOp = getOperation();
1104 for (
auto target :
llvm::make_early_inc_range(moduleOp.getOps<TargetOp>())) {
1105 for (
auto test : moduleOp.getOps<TestOp>()) {
// Checks for tests whose target dictionary has no entries.
1107 if (test.getTarget().getEntries().empty())
// Checks whether the target type differs from the one the test expects.
1112 if (target.getTarget() != test.getTarget())
1115 IRRewriter rewriter(test);
1117 auto newTest = cast<TestOp>(test->clone());
1118 newTest.setSymName(test.getSymName().str() +
"_" +
1119 target.getSymName().str());
1120 table.insert(newTest, rewriter.getInsertionPoint());
// Copy the target's operations, except the terminator, into the new test.
1124 rewriter.setInsertionPointToStart(newTest.getBody());
1125 for (
auto &op : target.getBody()->without_terminator())
1126 rewriter.clone(op, mapping);
// Each block argument of the new test is replaced by the clone of the
// corresponding operand of the target's terminator.
1128 for (
auto [returnVal, result] :
1129 llvm::zip(target.getBody()->getTerminator()->getOperands(),
1130 newTest.getBody()->getArguments()))
1131 result.replaceAllUsesWith(mapping.lookup(returnVal));
// The new test takes no arguments any more and gets an empty target type.
1133 newTest.getBody()->eraseArguments(0,
1134 newTest.getBody()->getNumArguments());
1135 newTest.setTarget(DictType::get(&getContext(), {}));
// Second pass over tests that still have a non-empty target type; the loop
// body is in lines not shown here.
1142 for (
auto test :
llvm::make_early_inc_range(moduleOp.getOps<TestOp>()))
1143 if (!test.getTarget().getEntries().
empty())
assert(baseType &&"element must be base type")
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the specified range.
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
This helps visit TypeOp nodes.
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
ResultType dispatchOpVisitor(Operation *op, ExtraArgs... args)
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static llvm::hash_code hash_value(const ModulePort &port)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.