21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Index/IR/IndexDialect.h"
23#include "mlir/Dialect/Index/IR/IndexOps.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/DenseMapInfoVariant.h"
28#include "llvm/Support/Debug.h"
34#define GEN_PASS_DEF_ELABORATIONPASS
35#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
44#define DEBUG_TYPE "rtg-elaboration"
56 size_t n = w / 32 + (w % 32 != 0);
58 return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
63 const uint32_t diff = b - a + 1;
67 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
71 uint32_t width = digits - llvm::countl_zero(diff) - 1;
72 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
// Storage types that `ElaboratorValue` refers to by pointer. They are only
// forward-declared here; their full definitions appear further down.
struct SequenceStorage;
struct RandomizedSequenceStorage;
struct InterleavedSequenceStorage;
struct VirtualRegisterStorage;
struct UniqueLabelStorage;
struct MemoryBlockStorage;
104 LabelValue(StringAttr name) : name(name) {}
106 bool operator==(
const LabelValue &other)
const {
return name == other.name; }
/// The value domain of the elaborator: every SSA value it interprets is mapped
/// to one of these alternatives. Index-typed integer constants are represented
/// as `size_t` and `i1` constants as `bool`; other typed constants stay as
/// `TypedAttr`. Aggregates, sequences, registers and labels are referenced
/// through pointers to their storage objects.
///
/// NOTE: the position of each alternative is observable, because
/// `hash_value(const ElaboratorValue &)` mixes `index()` into the hash;
/// reordering the alternatives changes every computed hash value.
using ElaboratorValue =
    std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
                 RandomizedSequenceStorage *, InterleavedSequenceStorage *,
                 SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
                 LabelValue, ArrayStorage *, TupleStorage *,
                 MemoryBlockStorage *>;
121llvm::hash_code
hash_value(
const LabelValue &val) {
122 return llvm::hash_value(val.name);
126llvm::hash_code
hash_value(
const ElaboratorValue &val) {
128 [&val](
const auto &alternative) {
131 return llvm::hash_combine(val.index(), alternative);
146 static bool isEqual(
const bool &lhs,
const bool &rhs) {
return lhs == rhs; }
160 static bool isEqual(
const LabelValue &lhs,
const LabelValue &rhs) {
177template <
typename StorageTy>
178struct HashedStorage {
179 HashedStorage(
unsigned hashcode = 0, StorageTy *storage =
nullptr)
180 : hashcode(hashcode), storage(storage) {}
190template <
typename StorageTy>
191struct StorageKeyInfo {
192 static inline HashedStorage<StorageTy> getEmptyKey() {
193 return HashedStorage<StorageTy>(0,
194 DenseMapInfo<StorageTy *>::getEmptyKey());
196 static inline HashedStorage<StorageTy> getTombstoneKey() {
197 return HashedStorage<StorageTy>(
198 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
201 static inline unsigned getHashValue(
const HashedStorage<StorageTy> &key) {
204 static inline unsigned getHashValue(
const StorageTy &key) {
208 static inline bool isEqual(
const HashedStorage<StorageTy> &lhs,
209 const HashedStorage<StorageTy> &rhs) {
210 return lhs.storage == rhs.storage;
212 static inline bool isEqual(
const StorageTy &lhs,
213 const HashedStorage<StorageTy> &rhs) {
214 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
217 return lhs.isEqual(rhs.storage);
223 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
225 type,
llvm::hash_combine_range(set.begin(), set.
end()))),
226 set(std::move(set)), type(type) {}
228 bool isEqual(
const SetStorage *other)
const {
229 return hashcode == other->hashcode && set == other->set &&
234 const unsigned hashcode;
237 const SetVector<ElaboratorValue> set;
246 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
248 type,
llvm::hash_combine_range(bag.begin(), bag.
end()))),
249 bag(std::move(bag)), type(type) {}
251 bool isEqual(
const BagStorage *other)
const {
252 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
257 const unsigned hashcode;
261 const MapVector<ElaboratorValue, uint64_t> bag;
269struct SequenceStorage {
270 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
272 familyName,
llvm::hash_combine_range(args.begin(), args.
end()))),
273 familyName(familyName), args(std::move(args)) {}
275 bool isEqual(
const SequenceStorage *other)
const {
276 return hashcode == other->hashcode && familyName == other->familyName &&
281 const unsigned hashcode;
284 const StringAttr familyName;
287 const SmallVector<ElaboratorValue> args;
291struct RandomizedSequenceStorage {
292 RandomizedSequenceStorage(StringRef name,
293 ContextResourceAttrInterface context,
294 StringAttr test, SequenceStorage *sequence)
296 context(context), test(test), sequence(sequence) {}
298 bool isEqual(
const RandomizedSequenceStorage *other)
const {
299 return hashcode == other->hashcode && name == other->name &&
300 context == other->context && test == other->test &&
301 sequence == other->sequence;
305 const unsigned hashcode;
308 const StringRef name;
311 const ContextResourceAttrInterface context;
314 const StringAttr test;
316 const SequenceStorage *sequence;
320struct InterleavedSequenceStorage {
321 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
323 : sequences(std::move(sequences)), batchSize(batchSize),
325 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
328 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
329 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
331 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
334 bool isEqual(
const InterleavedSequenceStorage *other)
const {
335 return hashcode == other->hashcode && sequences == other->sequences &&
336 batchSize == other->batchSize;
339 const SmallVector<ElaboratorValue> sequences;
341 const uint32_t batchSize;
344 const unsigned hashcode;
348struct VirtualRegisterStorage {
349 VirtualRegisterStorage(ArrayAttr allowedRegs) : allowedRegs(allowedRegs) {}
356 const ArrayAttr allowedRegs;
359struct UniqueLabelStorage {
360 UniqueLabelStorage(StringAttr name) : name(name) {}
366 const StringAttr name;
371 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
373 type,
llvm::hash_combine_range(array.begin(), array.
end()))),
374 type(type), array(array) {}
376 bool isEqual(
const ArrayStorage *other)
const {
377 return hashcode == other->hashcode && type == other->type &&
378 array == other->array;
382 const unsigned hashcode;
389 const SmallVector<ElaboratorValue> array;
394 TupleStorage(SmallVector<ElaboratorValue> &&values)
395 : hashcode(
llvm::hash_combine_range(values.begin(), values.
end())),
396 values(std::move(values)) {}
398 bool isEqual(
const TupleStorage *other)
const {
399 return hashcode == other->hashcode && values == other->values;
403 const unsigned hashcode;
405 const SmallVector<ElaboratorValue> values;
409struct MemoryBlockStorage {
410 MemoryBlockStorage(
const APInt &baseAddress,
const APInt &endAddress)
411 : baseAddress(baseAddress), endAddress(endAddress) {}
416 const APInt baseAddress;
419 const APInt endAddress;
432 template <
typename StorageTy,
typename... Args>
433 StorageTy *internalize(Args &&...args) {
434 StorageTy storage(std::forward<Args>(args)...);
436 auto existing = getInternSet<StorageTy>().insert_as(
437 HashedStorage<StorageTy>(storage.hashcode), storage);
438 StorageTy *&storagePtr = existing.first->storage;
441 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
446 template <
typename StorageTy,
typename... Args>
447 StorageTy *create(Args &&...args) {
448 return new (allocator.Allocate<StorageTy>())
449 StorageTy(std::forward<Args>(args)...);
453 template <
typename StorageTy>
454 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
456 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
457 return internedArrays;
458 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
460 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
462 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
463 return internedSequences;
464 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
465 return internedRandomizedSequences;
466 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
467 return internedInterleavedSequences;
468 else if constexpr (std::is_same_v<StorageTy, TupleStorage>)
469 return internedTuples;
471 static_assert(!
sizeof(StorageTy),
472 "no intern set available for this storage type.");
477 llvm::BumpPtrAllocator allocator;
482 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
484 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
485 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
486 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
488 DenseSet<HashedStorage<RandomizedSequenceStorage>,
489 StorageKeyInfo<RandomizedSequenceStorage>>
490 internedRandomizedSequences;
491 DenseSet<HashedStorage<InterleavedSequenceStorage>,
492 StorageKeyInfo<InterleavedSequenceStorage>>
493 internedInterleavedSequences;
494 DenseSet<HashedStorage<TupleStorage>, StorageKeyInfo<TupleStorage>>
// Forward-declared so that the per-kind `print` overloads below can print the
// nested ElaboratorValues they contain (bag, set, array, tuple elements and
// sequence arguments) via `os << el`.
static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                                     const ElaboratorValue &value);
505static void print(TypedAttr val, llvm::raw_ostream &os) {
506 os <<
"<attr " << val <<
">";
509static void print(BagStorage *val, llvm::raw_ostream &os) {
511 llvm::interleaveComma(val->bag, os,
512 [&](
const std::pair<ElaboratorValue, uint64_t> &el) {
513 os << el.first <<
" -> " << el.second;
515 os <<
"} at " << val <<
">";
518static void print(
bool val, llvm::raw_ostream &os) {
519 os <<
"<bool " << (val ?
"true" :
"false") <<
">";
522static void print(
size_t val, llvm::raw_ostream &os) {
523 os <<
"<index " << val <<
">";
526static void print(SequenceStorage *val, llvm::raw_ostream &os) {
527 os <<
"<sequence @" << val->familyName.getValue() <<
"(";
528 llvm::interleaveComma(val->args, os,
529 [&](
const ElaboratorValue &val) { os << val; });
530 os <<
") at " << val <<
">";
533static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
534 os <<
"<randomized-sequence @" << val->name <<
" derived from @"
535 << val->sequence->familyName.getValue() <<
" under context "
536 << val->context <<
" in test " << val->test <<
"(";
537 llvm::interleaveComma(val->sequence->args, os,
538 [&](
const ElaboratorValue &val) { os << val; });
539 os <<
") at " << val <<
">";
542static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
543 os <<
"<interleaved-sequence [";
544 llvm::interleaveComma(val->sequences, os,
545 [&](
const ElaboratorValue &val) { os << val; });
546 os <<
"] batch-size " << val->batchSize <<
" at " << val <<
">";
549static void print(ArrayStorage *val, llvm::raw_ostream &os) {
551 llvm::interleaveComma(val->array, os,
552 [&](
const ElaboratorValue &val) { os << val; });
553 os <<
"] at " << val <<
">";
556static void print(SetStorage *val, llvm::raw_ostream &os) {
558 llvm::interleaveComma(val->set, os,
559 [&](
const ElaboratorValue &val) { os << val; });
560 os <<
"} at " << val <<
">";
563static void print(
const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
564 os <<
"<virtual-register " << val <<
" " << val->allowedRegs <<
">";
567static void print(
const UniqueLabelStorage *val, llvm::raw_ostream &os) {
568 os <<
"<unique-label " << val <<
" " << val->name <<
">";
571static void print(
const LabelValue &val, llvm::raw_ostream &os) {
572 os <<
"<label " << val.name <<
">";
575static void print(
const TupleStorage *val, llvm::raw_ostream &os) {
577 llvm::interleaveComma(val->values, os,
578 [&](
const ElaboratorValue &val) { os << val; });
582static void print(
const MemoryBlockStorage *val, llvm::raw_ostream &os) {
583 os <<
"<memory-block {"
584 <<
", address-width=" << val->baseAddress.getBitWidth()
585 <<
", base-address=" << val->baseAddress
586 <<
", end-address=" << val->endAddress <<
"}>";
590 const ElaboratorValue &value) {
591 std::visit([&](
auto val) {
print(val, os); }, value);
607 Materializer(OpBuilder builder) : builder(builder) {}
611 Value materialize(ElaboratorValue val, Location loc,
612 std::queue<RandomizedSequenceStorage *> &elabRequests,
613 function_ref<InFlightDiagnostic()> emitError) {
614 auto iter = materializedValues.find(val);
615 if (iter != materializedValues.end())
618 LLVM_DEBUG(llvm::dbgs() <<
"Materializing " << val <<
"\n\n");
621 [&](
auto val) {
return visit(val, loc, elabRequests, emitError); },
632 materialize(Operation *op, DenseMap<Value, ElaboratorValue> &state,
633 std::queue<RandomizedSequenceStorage *> &elabRequests) {
634 if (op->getNumRegions() > 0)
635 return op->emitOpError(
"ops with nested regions must be elaborated away");
643 for (
auto res : op->getResults())
644 if (!res.use_empty())
645 return op->emitOpError(
646 "ops with results that have uses are not supported");
648 if (op->getParentRegion() == builder.getBlock()->getParent()) {
651 deleteOpsUntil([&](
auto iter) {
return &*iter == op; });
653 if (builder.getInsertionPoint() == builder.getBlock()->end())
654 return op->emitError(
"operation did not occur after the current "
655 "materializer insertion point");
657 LLVM_DEBUG(llvm::dbgs() <<
"Modifying in-place: " << *op <<
"\n\n");
659 LLVM_DEBUG(llvm::dbgs() <<
"Materializing a clone of " << *op <<
"\n\n");
660 op = builder.clone(*op);
661 builder.setInsertionPoint(op);
664 for (
auto &operand : op->getOpOperands()) {
665 auto emitError = [&]() {
666 auto diag = op->emitError();
667 diag.attachNote(op->getLoc())
668 <<
"while materializing value for operand#"
669 << operand.getOperandNumber();
673 Value val = materialize(state.at(operand.get()), op->getLoc(),
674 elabRequests, emitError);
681 builder.setInsertionPointAfter(op);
688 deleteOpsUntil([](
auto iter) {
return false; });
690 for (
auto *op :
llvm::reverse(toDelete))
694 template <
typename OpTy,
typename... Args>
695 OpTy create(Location location, Args &&...args) {
696 return builder.create<OpTy>(location, std::forward<Args>(args)...);
700 void deleteOpsUntil(function_ref<
bool(Block::iterator)> stop) {
701 auto ip = builder.getInsertionPoint();
702 while (ip != builder.getBlock()->end() && !stop(ip)) {
703 LLVM_DEBUG(llvm::dbgs() <<
"Marking to be deleted: " << *ip <<
"\n\n");
704 toDelete.push_back(&*ip);
706 builder.setInsertionPointAfter(&*ip);
707 ip = builder.getInsertionPoint();
711 Value visit(TypedAttr val, Location loc,
712 std::queue<RandomizedSequenceStorage *> &elabRequests,
713 function_ref<InFlightDiagnostic()> emitError) {
716 if (
auto intAttr = dyn_cast<IntegerAttr>(val);
717 intAttr && isa<IndexType>(val.getType())) {
718 Value res = builder.create<index::ConstantOp>(loc, intAttr);
719 materializedValues[val] = res;
726 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
728 emitError() <<
"materializer of dialect '"
729 << val.getDialect().getNamespace()
730 <<
"' unable to materialize value for attribute '" << val
735 Value res = op->getResult(0);
736 materializedValues[val] = res;
740 Value visit(
size_t val, Location loc,
741 std::queue<RandomizedSequenceStorage *> &elabRequests,
742 function_ref<InFlightDiagnostic()> emitError) {
743 Value res = builder.create<index::ConstantOp>(loc, val);
744 materializedValues[val] = res;
748 Value visit(
bool val, Location loc,
749 std::queue<RandomizedSequenceStorage *> &elabRequests,
750 function_ref<InFlightDiagnostic()> emitError) {
751 Value res = builder.create<index::BoolConstantOp>(loc, val);
752 materializedValues[val] = res;
756 Value visit(ArrayStorage *val, Location loc,
757 std::queue<RandomizedSequenceStorage *> &elabRequests,
758 function_ref<InFlightDiagnostic()> emitError) {
759 SmallVector<Value> elements;
760 elements.reserve(val->array.size());
761 for (
auto el : val->array) {
762 auto materialized = materialize(el, loc, elabRequests, emitError);
766 elements.push_back(materialized);
769 Value res = builder.create<ArrayCreateOp>(loc, val->type, elements);
770 materializedValues[val] = res;
774 Value visit(SetStorage *val, Location loc,
775 std::queue<RandomizedSequenceStorage *> &elabRequests,
776 function_ref<InFlightDiagnostic()> emitError) {
777 SmallVector<Value> elements;
778 elements.reserve(val->set.size());
779 for (
auto el : val->set) {
780 auto materialized = materialize(el, loc, elabRequests, emitError);
784 elements.push_back(materialized);
787 auto res = builder.create<SetCreateOp>(loc, val->type, elements);
788 materializedValues[val] = res;
792 Value visit(BagStorage *val, Location loc,
793 std::queue<RandomizedSequenceStorage *> &elabRequests,
794 function_ref<InFlightDiagnostic()> emitError) {
795 SmallVector<Value> values, weights;
796 values.reserve(val->bag.size());
797 weights.reserve(val->bag.size());
798 for (
auto [val, weight] : val->bag) {
799 auto materializedVal = materialize(val, loc, elabRequests, emitError);
800 auto materializedWeight =
801 materialize(weight, loc, elabRequests, emitError);
802 if (!materializedVal || !materializedWeight)
805 values.push_back(materializedVal);
806 weights.push_back(materializedWeight);
809 auto res = builder.create<BagCreateOp>(loc, val->type, values, weights);
810 materializedValues[val] = res;
814 Value visit(SequenceStorage *val, Location loc,
815 std::queue<RandomizedSequenceStorage *> &elabRequests,
816 function_ref<InFlightDiagnostic()> emitError) {
817 emitError() <<
"materializing a non-randomized sequence not supported yet";
821 Value visit(RandomizedSequenceStorage *val, Location loc,
822 std::queue<RandomizedSequenceStorage *> &elabRequests,
823 function_ref<InFlightDiagnostic()> emitError) {
824 elabRequests.push(val);
825 Value
seq = builder.create<GetSequenceOp>(
826 loc, SequenceType::get(builder.getContext(), {}), val->name);
827 Value res = builder.create<RandomizeSequenceOp>(loc,
seq);
828 materializedValues[val] = res;
832 Value visit(InterleavedSequenceStorage *val, Location loc,
833 std::queue<RandomizedSequenceStorage *> &elabRequests,
834 function_ref<InFlightDiagnostic()> emitError) {
835 SmallVector<Value> sequences;
836 for (
auto seqVal : val->sequences)
837 sequences.push_back(materialize(seqVal, loc, elabRequests, emitError));
839 if (sequences.size() == 1)
843 builder.create<InterleaveSequencesOp>(loc, sequences, val->batchSize);
844 materializedValues[val] = res;
848 Value visit(VirtualRegisterStorage *val, Location loc,
849 std::queue<RandomizedSequenceStorage *> &elabRequests,
850 function_ref<InFlightDiagnostic()> emitError) {
851 Value res = builder.create<VirtualRegisterOp>(loc, val->allowedRegs);
852 materializedValues[val] = res;
856 Value visit(UniqueLabelStorage *val, Location loc,
857 std::queue<RandomizedSequenceStorage *> &elabRequests,
858 function_ref<InFlightDiagnostic()> emitError) {
859 Value res = builder.create<LabelUniqueDeclOp>(loc, val->name, ValueRange());
860 materializedValues[val] = res;
864 Value visit(
const LabelValue &val, Location loc,
865 std::queue<RandomizedSequenceStorage *> &elabRequests,
866 function_ref<InFlightDiagnostic()> emitError) {
867 Value res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
868 materializedValues[val] = res;
872 Value visit(TupleStorage *val, Location loc,
873 std::queue<RandomizedSequenceStorage *> &elabRequests,
874 function_ref<InFlightDiagnostic()> emitError) {
875 SmallVector<Value> materialized;
876 materialized.reserve(val->values.size());
877 for (
auto v : val->values)
878 materialized.push_back(materialize(v, loc, elabRequests, emitError));
879 Value res = builder.create<TupleCreateOp>(loc, materialized);
880 materializedValues[val] = res;
890 DenseMap<ElaboratorValue, Value> materializedValues;
896 SmallVector<Operation *> toDelete;
/// Result of visiting an operation during elaboration: whether the operation
/// is to be kept (`Keep`) or deleted (`Delete`).
enum class DeletionKind {
  Keep = 0,
  Delete = 1,
};
908struct ElaboratorSharedState {
909 ElaboratorSharedState(SymbolTable &table,
unsigned seed)
910 : table(table), rng(seed) {}
915 Internalizer internalizer;
919 std::queue<RandomizedSequenceStorage *> worklist;
929 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
935class Elaborator :
public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
938 using RTGBase::visitOp;
940 Elaborator(ElaboratorSharedState &sharedState, TestState &testState,
941 Materializer &materializer,
942 ContextResourceAttrInterface currentContext = {})
943 : sharedState(sharedState), testState(testState),
944 materializer(materializer), currentContext(currentContext) {}
946 template <
typename ValueTy>
947 inline ValueTy
get(Value val)
const {
948 return std::get<ValueTy>(state.at(val));
951 FailureOr<DeletionKind> visitPureOp(Operation *op) {
952 SmallVector<Attribute> operands;
953 for (
auto operand : op->getOperands()) {
954 auto evalValue = state[operand];
955 if (std::holds_alternative<TypedAttr>(evalValue))
956 operands.push_back(std::get<TypedAttr>(evalValue));
961 SmallVector<OpFoldResult> results;
962 if (failed(op->fold(operands, results)))
966 if (results.size() != op->getNumResults())
969 for (
auto [res, val] :
llvm::zip(results, op->getResults())) {
970 auto attr = llvm::dyn_cast_or_null<TypedAttr>(res.dyn_cast<Attribute>());
972 return op->emitError(
973 "only typed attributes supported for constant-like operations");
975 auto intAttr = dyn_cast<IntegerAttr>(attr);
976 if (intAttr && isa<IndexType>(attr.getType()))
977 state[op->getResult(0)] = size_t(intAttr.getInt());
978 else if (intAttr && intAttr.getType().isSignlessInteger(1))
979 state[op->getResult(0)] = bool(intAttr.getInt());
981 state[op->getResult(0)] = attr;
984 return DeletionKind::Delete;
989 return op->emitOpError(
"elaboration not supported");
993 auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
994 if (op->hasTrait<OpTrait::ConstantLike>() || (memOp && memOp.hasNoEffect()))
995 return visitPureOp(op);
1000 if (op->use_empty())
1001 return DeletionKind::Keep;
1006 FailureOr<DeletionKind> visitOp(ConstantOp op) {
return visitPureOp(op); }
1008 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
1009 SmallVector<ElaboratorValue> replacements;
1010 state[op.getResult()] =
1011 sharedState.internalizer.internalize<SequenceStorage>(
1012 op.getSequenceAttr(), std::move(replacements));
1013 return DeletionKind::Delete;
1016 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
1017 auto *
seq = get<SequenceStorage *>(op.getSequence());
1019 SmallVector<ElaboratorValue> replacements(
seq->args);
1020 for (
auto replacement : op.getReplacements())
1021 replacements.push_back(state.at(replacement));
1023 state[op.getResult()] =
1024 sharedState.internalizer.internalize<SequenceStorage>(
1025 seq->familyName, std::move(replacements));
1027 return DeletionKind::Delete;
1030 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
1031 auto *
seq = get<SequenceStorage *>(op.getSequence());
1033 auto name = sharedState.names.newName(
seq->familyName.getValue());
1034 auto *randomizedSeq =
1035 sharedState.internalizer.internalize<RandomizedSequenceStorage>(
1036 name, currentContext, testState.name,
seq);
1037 state[op.getResult()] =
1038 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1040 return DeletionKind::Delete;
1043 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
1044 SmallVector<ElaboratorValue> sequences;
1045 for (
auto seq : op.getSequences())
1046 sequences.push_back(
get<InterleavedSequenceStorage *>(
seq));
1048 state[op.getResult()] =
1049 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1050 std::move(sequences), op.getBatchSize());
1051 return DeletionKind::Delete;
1055 LogicalResult isValidContext(ElaboratorValue value, Operation *op)
const {
1056 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
1057 auto *
seq = std::get<RandomizedSequenceStorage *>(value);
1058 if (
seq->context != currentContext) {
1059 auto err = op->emitError(
"attempting to place sequence ")
1060 <<
seq->name <<
" derived from "
1061 <<
seq->sequence->familyName.getValue() <<
" under context "
1063 <<
", but it was previously randomized for context ";
1065 err <<
seq->context;
1073 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
1074 for (
auto val : interVal->sequences)
1075 if (failed(isValidContext(val, op)))
1080 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
1081 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
1082 if (failed(isValidContext(seqVal, op)))
1085 return DeletionKind::Keep;
1088 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1089 SetVector<ElaboratorValue> set;
1090 for (
auto val : op.getElements())
1091 set.insert(state.at(val));
1093 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1094 std::move(set), op.getSet().getType());
1095 return DeletionKind::Delete;
1098 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
1099 auto set = get<SetStorage *>(op.getSet())->set;
1102 return op->emitError(
"cannot select from an empty set");
1106 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1107 std::mt19937 customRng(intAttr.getInt());
1113 state[op.getResult()] = set[selected];
1114 return DeletionKind::Delete;
1117 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
1118 auto original = get<SetStorage *>(op.getOriginal())->set;
1119 auto diff = get<SetStorage *>(op.getDiff())->set;
1121 SetVector<ElaboratorValue> result(original);
1122 result.set_subtract(diff);
1124 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1125 std::move(result), op.getResult().getType());
1126 return DeletionKind::Delete;
1129 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
1130 SetVector<ElaboratorValue> result;
1131 for (
auto set : op.getSets())
1132 result.set_union(
get<SetStorage *>(set)->set);
1134 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1135 std::move(result), op.getType());
1136 return DeletionKind::Delete;
1139 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1140 auto size = get<SetStorage *>(op.getSet())->set.size();
1141 state[op.getResult()] = size;
1142 return DeletionKind::Delete;
1148 FailureOr<DeletionKind> visitOp(SetCartesianProductOp op) {
1149 SetVector<ElaboratorValue> result;
1150 SmallVector<SmallVector<ElaboratorValue>> tuples;
1151 tuples.push_back({});
1153 for (
auto input : op.getInputs()) {
1154 auto &set = get<SetStorage *>(input)->set;
1156 SetVector<ElaboratorValue>
empty;
1157 state[op.getResult()] =
1158 sharedState.internalizer.internalize<SetStorage>(std::move(
empty),
1160 return DeletionKind::Delete;
1163 for (
unsigned i = 0, e = tuples.size(); i < e; ++i) {
1164 for (
auto setEl : set.getArrayRef().drop_back()) {
1165 tuples.push_back(tuples[i]);
1166 tuples.back().push_back(setEl);
1168 tuples[i].push_back(set.back());
1172 for (
auto &tup : tuples)
1174 sharedState.internalizer.internalize<TupleStorage>(std::move(tup)));
1176 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1177 std::move(result), op.getType());
1178 return DeletionKind::Delete;
1181 FailureOr<DeletionKind> visitOp(SetConvertToBagOp op) {
1182 auto set = get<SetStorage *>(op.getInput())->set;
1183 MapVector<ElaboratorValue, uint64_t> bag;
1184 for (
auto val : set)
1185 bag.insert({val, 1});
1186 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1187 std::move(bag), op.getType());
1188 return DeletionKind::Delete;
1191 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1192 MapVector<ElaboratorValue, uint64_t> bag;
1193 for (
auto [val, multiple] :
1194 llvm::zip(op.getElements(), op.getMultiples())) {
1198 bag[state.at(val)] += get<size_t>(multiple);
1201 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1202 std::move(bag), op.getType());
1203 return DeletionKind::Delete;
1206 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
1207 auto bag = get<BagStorage *>(op.getBag())->bag;
1210 return op->emitError(
"cannot select from an empty bag");
1212 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1213 prefixSum.reserve(bag.size());
1214 uint32_t accumulator = 0;
1215 for (
auto [val, weight] : bag) {
1216 accumulator += weight;
1217 prefixSum.push_back({val, accumulator});
1220 auto customRng = sharedState.rng;
1222 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1223 customRng = std::mt19937(intAttr.getInt());
1227 auto *iter = llvm::upper_bound(
1229 [](uint32_t a,
const std::pair<ElaboratorValue, uint32_t> &b) {
1230 return a < b.second;
1233 state[op.getResult()] = iter->first;
1234 return DeletionKind::Delete;
1237 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1238 auto original = get<BagStorage *>(op.getOriginal())->bag;
1239 auto diff = get<BagStorage *>(op.getDiff())->bag;
1241 MapVector<ElaboratorValue, uint64_t> result;
1242 for (
const auto &el : original) {
1243 if (!diff.contains(el.first)) {
1251 auto toDiff = diff.lookup(el.first);
1252 if (el.second <= toDiff)
1255 result.insert({el.first, el.second - toDiff});
1258 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1259 std::move(result), op.getType());
1260 return DeletionKind::Delete;
1263 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1264 MapVector<ElaboratorValue, uint64_t> result;
1265 for (
auto bag : op.getBags()) {
1266 auto val = get<BagStorage *>(bag)->bag;
1267 for (
auto [el, multiple] : val)
1268 result[el] += multiple;
1271 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1272 std::move(result), op.getType());
1273 return DeletionKind::Delete;
1276 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1277 auto size = get<BagStorage *>(op.getBag())->bag.size();
1278 state[op.getResult()] = size;
1279 return DeletionKind::Delete;
1282 FailureOr<DeletionKind> visitOp(BagConvertToSetOp op) {
1283 auto bag = get<BagStorage *>(op.getInput())->bag;
1284 SetVector<ElaboratorValue> set;
1285 for (
auto [k, v] : bag)
1287 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1288 std::move(set), op.getType());
1289 return DeletionKind::Delete;
1292 FailureOr<DeletionKind> visitOp(FixedRegisterOp op) {
1293 return visitPureOp(op);
1296 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1297 state[op.getResult()] =
1298 sharedState.internalizer.create<VirtualRegisterStorage>(
1299 op.getAllowedRegsAttr());
1300 return DeletionKind::Delete;
1303 StringAttr substituteFormatString(StringAttr formatString,
1304 ValueRange substitutes)
const {
1305 if (substitutes.empty() || formatString.empty())
1306 return formatString;
1308 auto original = formatString.getValue().str();
1309 for (
auto [i, subst] :
llvm::enumerate(substitutes)) {
1310 size_t startPos = 0;
1311 std::string from =
"{{" + std::to_string(i) +
"}}";
1312 while ((startPos = original.find(from, startPos)) != std::string::npos) {
1313 auto substString = std::to_string(get<size_t>(subst));
1314 original.replace(startPos, from.length(), substString);
1318 return StringAttr::get(formatString.getContext(), original);
1321 FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
1322 SmallVector<ElaboratorValue> array;
1323 array.reserve(op.getElements().size());
1324 for (
auto val : op.getElements())
1325 array.emplace_back(state.at(val));
1327 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1328 op.getResult().getType(), std::move(array));
1329 return DeletionKind::Delete;
1332 FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
1333 auto array = get<ArrayStorage *>(op.getArray())->array;
1334 size_t idx = get<size_t>(op.getIndex());
1336 if (array.size() <= idx)
1337 return op->emitError(
"invalid to access index ")
1338 << idx <<
" of an array with " << array.size() <<
" elements";
1340 state[op.getResult()] = array[idx];
1341 return DeletionKind::Delete;
1344 FailureOr<DeletionKind> visitOp(ArrayInjectOp op) {
1345 auto array = get<ArrayStorage *>(op.getArray())->array;
1346 size_t idx = get<size_t>(op.getIndex());
1348 if (array.size() <= idx)
1349 return op->emitError(
"invalid to access index ")
1350 << idx <<
" of an array with " << array.size() <<
" elements";
1352 array[idx] = state[op.getValue()];
1353 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1354 op.getResult().getType(), std::move(array));
1355 return DeletionKind::Delete;
1358 FailureOr<DeletionKind> visitOp(ArraySizeOp op) {
1359 auto array = get<ArrayStorage *>(op.getArray())->array;
1360 state[op.getResult()] = array.size();
1361 return DeletionKind::Delete;
1364 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
1366 substituteFormatString(op.getFormatStringAttr(), op.getArgs());
1367 state[op.getLabel()] = LabelValue(substituted);
1368 return DeletionKind::Delete;
1371 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1372 state[op.getLabel()] = sharedState.internalizer.create<UniqueLabelStorage>(
1373 substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
1374 return DeletionKind::Delete;
1377 FailureOr<DeletionKind> visitOp(LabelOp op) {
return DeletionKind::Keep; }
1379 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1380 size_t lower = get<size_t>(op.getLowerBound());
1381 size_t upper = get<size_t>(op.getUpperBound()) - 1;
1383 return op->emitError(
"cannot select a number from an empty range");
1386 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1387 std::mt19937 customRng(intAttr.getInt());
1388 state[op.getResult()] =
1391 state[op.getResult()] =
1395 return DeletionKind::Delete;
1398 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1399 size_t input = get<size_t>(op.getInput());
1400 auto width = op.getType().getWidth();
1401 auto emitError = [&]() {
return op->emitError(); };
1402 if (input > APInt::getAllOnes(width).getZExtValue())
1403 return emitError() <<
"cannot represent " << input <<
" with " << width
1406 state[op.getResult()] =
1407 ImmediateAttr::get(op.getContext(), APInt(width, input));
1408 return DeletionKind::Delete;
1411 FailureOr<DeletionKind> visitOp(OnContextOp op) {
1412 ContextResourceAttrInterface from = currentContext,
1413 to = cast<ContextResourceAttrInterface>(
1414 get<TypedAttr>(op.getContext()));
1415 if (!currentContext)
1416 from = DefaultContextAttr::get(op->getContext(), to.getType());
1418 auto emitError = [&]() {
1419 auto diag = op.emitError();
1420 diag.attachNote(op.getLoc())
1421 <<
"while materializing value for context switching for " << op;
1426 Value seqVal = materializer.materialize(
1427 get<SequenceStorage *>(op.getSequence()), op.getLoc(),
1428 sharedState.worklist, emitError);
1430 materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1431 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1432 return DeletionKind::Delete;
1438 auto *iter = testState.contextSwitches.find({from, to});
1441 if (iter == testState.contextSwitches.end())
1442 iter = testState.contextSwitches.find(
1443 {from, AnyContextAttr::get(op->getContext(), to.getType())});
1446 if (iter == testState.contextSwitches.end())
1447 iter = testState.contextSwitches.find(
1448 {AnyContextAttr::get(op->getContext(), from.getType()), to});
1451 if (iter == testState.contextSwitches.end())
1452 iter = testState.contextSwitches.find(
1453 {AnyContextAttr::get(op->getContext(), from.getType()),
1454 AnyContextAttr::get(op->getContext(), to.getType())});
1460 if (iter == testState.contextSwitches.end())
1461 return op->emitError(
"no context transition registered to switch from ")
1462 << from <<
" to " << to;
1464 auto familyName = iter->second->familyName;
1465 SmallVector<ElaboratorValue> args{from, to,
1466 get<SequenceStorage *>(op.getSequence())};
1467 auto *
seq = sharedState.internalizer.internalize<SequenceStorage>(
1468 familyName, std::move(args));
1470 sharedState.internalizer.internalize<RandomizedSequenceStorage>(
1471 sharedState.names.newName(familyName.getValue()), to,
1472 testState.name,
seq);
1473 Value seqVal = materializer.materialize(randSeq, op.getLoc(),
1474 sharedState.worklist, emitError);
1475 materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
1477 return DeletionKind::Delete;
1480 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1481 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1482 get<SequenceStorage *>(op.getSequence());
1483 return DeletionKind::Delete;
1486 FailureOr<DeletionKind> visitOp(TupleCreateOp op) {
1487 SmallVector<ElaboratorValue> values;
1488 values.reserve(op.getElements().size());
1489 for (
auto el : op.getElements())
1490 values.push_back(state[el]);
1492 state[op.getResult()] =
1493 sharedState.internalizer.internalize<TupleStorage>(std::move(values));
1494 return DeletionKind::Delete;
1497 FailureOr<DeletionKind> visitOp(TupleExtractOp op) {
1498 auto *tuple = get<TupleStorage *>(op.getTuple());
1499 state[op.getResult()] = tuple->values[op.getIndex().getZExtValue()];
1500 return DeletionKind::Delete;
1503 FailureOr<DeletionKind> visitOp(CommentOp op) {
return DeletionKind::Keep; }
1505 FailureOr<DeletionKind> visitOp(MemoryBlockDeclareOp op) {
1506 state[op.getResult()] = sharedState.internalizer.create<MemoryBlockStorage>(
1507 op.getEndAddress(), op.getBaseAddress());
1508 return DeletionKind::Delete;
1511 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1512 bool cond = get<bool>(op.getCondition());
1513 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1514 if (toElaborate.empty())
1515 return DeletionKind::Delete;
1521 if (failed(elaborate(toElaborate)))
1525 for (
auto [res, out] :
1526 llvm::zip(op.getResults(),
1527 toElaborate.front().getTerminator()->getOperands()))
1528 state[res] = state.at(out);
1530 return DeletionKind::Delete;
1533 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1534 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1535 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1536 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1537 return op->emitOpError(
"can only elaborate index type iterator");
1539 auto lowerBound = get<size_t>(op.getLowerBound());
1540 auto step = get<size_t>(op.getStep());
1541 auto upperBound = get<size_t>(op.getUpperBound());
1547 state[op.getInductionVar()] = lowerBound;
1548 for (
auto [iterArg, initArg] :
1549 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1550 state[iterArg] = state.at(initArg);
1553 for (
size_t i = lowerBound; i < upperBound; i += step) {
1554 if (failed(elaborate(op.getBodyRegion())))
1559 state[op.getInductionVar()] = i + step;
1560 for (
auto [iterArg, prevIterArg] :
1561 llvm::zip(op.getRegionIterArgs(),
1562 op.getBody()->getTerminator()->getOperands()))
1563 state[iterArg] = state.at(prevIterArg);
1567 for (
auto [res, iterArg] :
1568 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1569 state[res] = state.at(iterArg);
1571 return DeletionKind::Delete;
1574 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1575 return DeletionKind::Delete;
1578 FailureOr<DeletionKind> visitOp(arith::AddIOp op) {
1579 if (!isa<IndexType>(op.getType()))
1580 return op->emitError(
"only index operands supported");
1582 size_t lhs = get<size_t>(op.getLhs());
1583 size_t rhs = get<size_t>(op.getRhs());
1584 state[op.getResult()] = lhs + rhs;
1585 return DeletionKind::Delete;
1588 FailureOr<DeletionKind> visitOp(arith::AndIOp op) {
1589 if (!op.getType().isSignlessInteger(1))
1590 return op->emitError(
"only 'i1' operands supported");
1592 bool lhs = get<bool>(op.getLhs());
1593 bool rhs = get<bool>(op.getRhs());
1594 state[op.getResult()] = lhs && rhs;
1595 return DeletionKind::Delete;
1598 FailureOr<DeletionKind> visitOp(arith::XOrIOp op) {
1599 if (!op.getType().isSignlessInteger(1))
1600 return op->emitError(
"only 'i1' operands supported");
1602 bool lhs = get<bool>(op.getLhs());
1603 bool rhs = get<bool>(op.getRhs());
1604 state[op.getResult()] = lhs != rhs;
1605 return DeletionKind::Delete;
1608 FailureOr<DeletionKind> visitOp(arith::OrIOp op) {
1609 if (!op.getType().isSignlessInteger(1))
1610 return op->emitError(
"only 'i1' operands supported");
1612 bool lhs = get<bool>(op.getLhs());
1613 bool rhs = get<bool>(op.getRhs());
1614 state[op.getResult()] = lhs || rhs;
1615 return DeletionKind::Delete;
1618 FailureOr<DeletionKind> visitOp(arith::SelectOp op) {
1619 bool cond = get<bool>(op.getCondition());
1620 auto trueVal = state[op.getTrueValue()];
1621 auto falseVal = state[op.getFalseValue()];
1622 state[op.getResult()] = cond ? trueVal : falseVal;
1623 return DeletionKind::Delete;
1626 FailureOr<DeletionKind> visitOp(index::AddOp op) {
1627 size_t lhs = get<size_t>(op.getLhs());
1628 size_t rhs = get<size_t>(op.getRhs());
1629 state[op.getResult()] = lhs + rhs;
1630 return DeletionKind::Delete;
1633 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
1634 size_t lhs = get<size_t>(op.getLhs());
1635 size_t rhs = get<size_t>(op.getRhs());
1637 switch (op.getPred()) {
1638 case index::IndexCmpPredicate::EQ:
1639 result = lhs == rhs;
1641 case index::IndexCmpPredicate::NE:
1642 result = lhs != rhs;
1644 case index::IndexCmpPredicate::ULT:
1647 case index::IndexCmpPredicate::ULE:
1648 result = lhs <= rhs;
1650 case index::IndexCmpPredicate::UGT:
1653 case index::IndexCmpPredicate::UGE:
1654 result = lhs >= rhs;
1657 return op->emitOpError(
"elaboration not supported");
1659 state[op.getResult()] = result;
1660 return DeletionKind::Delete;
1663 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
1664 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
1667 arith::AddIOp, arith::XOrIOp, arith::AndIOp, arith::OrIOp,
1670 index::AddOp, index::CmpOp,
1672 scf::IfOp, scf::ForOp, scf::YieldOp>(
1673 [&](
auto op) {
return visitOp(op); })
1674 .Default([&](Operation *op) {
return RTGBase::dispatchOpVisitor(op); });
1678 LogicalResult elaborate(Region ®ion,
1679 ArrayRef<ElaboratorValue> regionArguments = {}) {
1680 if (region.getBlocks().size() > 1)
1681 return region.getParentOp()->emitOpError(
1682 "regions with more than one block are not supported");
1684 for (
auto [arg, elabArg] :
1685 llvm::zip(region.getArguments(), regionArguments))
1686 state[arg] = elabArg;
1688 Block *block = ®ion.front();
1689 for (
auto &op : *block) {
1690 auto result = dispatchOpVisitor(&op);
1694 if (*result == DeletionKind::Keep)
1695 if (failed(materializer.materialize(&op, state, sharedState.worklist)))
1699 llvm::dbgs() <<
"Elaborated " << op <<
" to\n[";
1701 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](
auto res) {
1702 if (state.contains(res))
1703 llvm::dbgs() << state.at(res);
1705 llvm::dbgs() <<
"unknown";
1708 llvm::dbgs() <<
"]\n\n";
/// State shared by all elaborators: the value internalizer, the name
/// generator (`names`), and the `worklist` of sequences that still need to be
/// elaborated (drained by the pass after the tests are processed).
1717 ElaboratorSharedState &sharedState;
/// State of the test currently being elaborated, e.g. its `name` and the
/// context switches registered via `rtg.context_switch`-style ops.
1720 TestState &testState;
/// Builds output IR: materializes kept operations and values (such as
/// sequences) that must be turned back into IR.
1724 Materializer &materializer;
/// Maps every SSA value visited so far to its elaborated value.
1727 DenseMap<Value, ElaboratorValue> state;
/// Context the current region is elaborated under; a null value is treated as
/// the default context when resolving context switches.
1730 ContextResourceAttrInterface currentContext;
1739struct ElaborationPass
1740 :
public rtg::impl::ElaborationPassBase<ElaborationPass> {
1743 void runOnOperation()
override;
1744 void cloneTargetsIntoTests(SymbolTable &table);
1745 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
1749void ElaborationPass::runOnOperation() {
1750 auto moduleOp = getOperation();
1751 SymbolTable table(moduleOp);
1753 cloneTargetsIntoTests(table);
1755 if (failed(elaborateModule(moduleOp, table)))
1756 return signalPassFailure();
1759void ElaborationPass::cloneTargetsIntoTests(SymbolTable &table) {
1760 auto moduleOp = getOperation();
1761 for (
auto target :
llvm::make_early_inc_range(moduleOp.getOps<TargetOp>())) {
1762 for (
auto test : moduleOp.getOps<TestOp>()) {
1764 if (test.getTarget().getEntries().empty())
1769 if (target.getTarget() != test.getTarget())
1772 IRRewriter rewriter(test);
1774 auto newTest = cast<TestOp>(test->clone());
1775 newTest.setSymName(test.getSymName().str() +
"_" +
1776 target.getSymName().str());
1777 table.insert(newTest, rewriter.getInsertionPoint());
1781 rewriter.setInsertionPointToStart(newTest.getBody());
1782 for (
auto &op : target.getBody()->without_terminator())
1783 rewriter.clone(op, mapping);
1785 for (
auto [returnVal, result] :
1786 llvm::zip(target.getBody()->getTerminator()->getOperands(),
1787 newTest.getBody()->getArguments()))
1788 result.replaceAllUsesWith(mapping.lookup(returnVal));
1790 newTest.getBody()->eraseArguments(0,
1791 newTest.getBody()->getNumArguments());
1792 newTest.setTarget(DictType::get(&getContext(), {}));
1799 for (
auto test :
llvm::make_early_inc_range(moduleOp.getOps<TestOp>()))
1800 if (!test.getTarget().getEntries().
empty())
1804LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
1805 SymbolTable &table) {
1806 ElaboratorSharedState state(table, seed);
1809 state.names.add(moduleOp);
1813 DenseMap<StringAttr, TestState> testStates;
1814 for (
auto testOp : moduleOp.getOps<TestOp>()) {
1815 LLVM_DEBUG(llvm::dbgs()
1816 <<
"\n=== Elaborating test @" << testOp.getSymName() <<
"\n\n");
1817 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()));
1818 testStates[testOp.getSymNameAttr()].name = testOp.getSymNameAttr();
1819 Elaborator elaborator(state, testStates[testOp.getSymNameAttr()],
1821 if (failed(elaborator.elaborate(testOp.getBodyRegion())))
1824 materializer.finalize();
1829 while (!state.worklist.empty()) {
1830 auto *curr = state.worklist.front();
1831 state.worklist.pop();
1833 if (table.lookup<SequenceOp>(curr->name))
1836 auto familyOp = table.lookup<SequenceOp>(curr->sequence->familyName);
1839 OpBuilder builder(familyOp);
1840 auto seqOp = builder.cloneWithoutRegions(familyOp);
1841 seqOp.getBodyRegion().emplaceBlock();
1842 seqOp.setSymName(curr->name);
1843 seqOp.setSequenceType(
1844 SequenceType::get(builder.getContext(), ArrayRef<Type>{}));
1845 table.insert(seqOp);
1846 assert(seqOp.getSymName() == curr->name &&
"should not have been renamed");
1848 LLVM_DEBUG(llvm::dbgs()
1849 <<
"\n=== Elaborating sequence family @" << familyOp.getSymName()
1850 <<
" into @" << seqOp.getSymName() <<
" under context "
1851 << curr->context <<
"\n\n");
1853 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()));
1854 Elaborator elaborator(state, testStates[curr->test], materializer,
1856 if (failed(elaborator.elaborate(familyOp.getBodyRegion(),
1857 curr->sequence->args)))
1860 materializer.finalize();
assert(baseType &&"element must be base type")
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the in specified range.
static void print(TypedAttr val, llvm::raw_ostream &os)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
This helps visit TypeOp nodes.
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
static llvm::hash_code hash_value(const ModulePort &port)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static LabelValue getEmptyKey()
static LabelValue getTombstoneKey()
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)
static unsigned getEmptyKey()