21#include "mlir/Dialect/Index/IR/IndexDialect.h"
22#include "mlir/Dialect/Index/IR/IndexOps.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/IR/IRMapping.h"
25#include "mlir/IR/PatternMatch.h"
26#include "llvm/ADT/DenseMapInfoVariant.h"
27#include "llvm/Support/Debug.h"
33#define GEN_PASS_DEF_ELABORATIONPASS
34#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
43#define DEBUG_TYPE "rtg-elaboration"
55 size_t n = w / 32 + (w % 32 != 0);
57 return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
62 const uint32_t diff = b - a + 1;
66 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
70 uint32_t width = digits - llvm::countl_zero(diff) - 1;
71 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
89struct SequenceStorage;
90struct RandomizedSequenceStorage;
91struct InterleavedSequenceStorage;
93struct VirtualRegisterStorage;
94struct UniqueLabelStorage;
100 LabelValue(StringAttr name) : name(name) {}
102 bool operator==(
const LabelValue &other)
const {
return name == other.name; }
109using ElaboratorValue =
110 std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
111 RandomizedSequenceStorage *, InterleavedSequenceStorage *,
112 SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
116llvm::hash_code
hash_value(
const LabelValue &val) {
117 return llvm::hash_value(val.name);
121llvm::hash_code
hash_value(
const ElaboratorValue &val) {
123 [&val](
const auto &alternative) {
126 return llvm::hash_combine(val.index(), alternative);
141 static bool isEqual(
const bool &lhs,
const bool &rhs) {
return lhs == rhs; }
155 static bool isEqual(
const LabelValue &lhs,
const LabelValue &rhs) {
172template <
typename StorageTy>
173struct HashedStorage {
174 HashedStorage(
unsigned hashcode = 0, StorageTy *storage =
nullptr)
175 : hashcode(hashcode), storage(storage) {}
185template <
typename StorageTy>
186struct StorageKeyInfo {
187 static inline HashedStorage<StorageTy> getEmptyKey() {
188 return HashedStorage<StorageTy>(0,
189 DenseMapInfo<StorageTy *>::getEmptyKey());
191 static inline HashedStorage<StorageTy> getTombstoneKey() {
192 return HashedStorage<StorageTy>(
193 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
196 static inline unsigned getHashValue(
const HashedStorage<StorageTy> &key) {
199 static inline unsigned getHashValue(
const StorageTy &key) {
203 static inline bool isEqual(
const HashedStorage<StorageTy> &lhs,
204 const HashedStorage<StorageTy> &rhs) {
205 return lhs.storage == rhs.storage;
207 static inline bool isEqual(
const StorageTy &lhs,
208 const HashedStorage<StorageTy> &rhs) {
209 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
212 return lhs.isEqual(rhs.storage);
218 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
220 type,
llvm::hash_combine_range(set.begin(), set.
end()))),
221 set(std::move(set)), type(type) {}
223 bool isEqual(
const SetStorage *other)
const {
224 return hashcode == other->hashcode && set == other->set &&
229 const unsigned hashcode;
232 const SetVector<ElaboratorValue> set;
241 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
243 type,
llvm::hash_combine_range(bag.begin(), bag.
end()))),
244 bag(std::move(bag)), type(type) {}
246 bool isEqual(
const BagStorage *other)
const {
247 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
252 const unsigned hashcode;
256 const MapVector<ElaboratorValue, uint64_t> bag;
264struct SequenceStorage {
265 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
267 familyName,
llvm::hash_combine_range(args.begin(), args.
end()))),
268 familyName(familyName), args(std::move(args)) {}
270 bool isEqual(
const SequenceStorage *other)
const {
271 return hashcode == other->hashcode && familyName == other->familyName &&
276 const unsigned hashcode;
279 const StringAttr familyName;
282 const SmallVector<ElaboratorValue> args;
286struct RandomizedSequenceStorage {
287 RandomizedSequenceStorage(StringRef name,
288 ContextResourceAttrInterface context,
289 StringAttr test, SequenceStorage *sequence)
291 context(context), test(test), sequence(sequence) {}
293 bool isEqual(
const RandomizedSequenceStorage *other)
const {
294 return hashcode == other->hashcode && name == other->name &&
295 context == other->context && test == other->test &&
296 sequence == other->sequence;
300 const unsigned hashcode;
303 const StringRef name;
306 const ContextResourceAttrInterface context;
309 const StringAttr test;
311 const SequenceStorage *sequence;
315struct InterleavedSequenceStorage {
316 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
318 : sequences(std::move(sequences)), batchSize(batchSize),
320 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
323 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
324 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
326 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
329 bool isEqual(
const InterleavedSequenceStorage *other)
const {
330 return hashcode == other->hashcode && sequences == other->sequences &&
331 batchSize == other->batchSize;
334 const SmallVector<ElaboratorValue> sequences;
336 const uint32_t batchSize;
339 const unsigned hashcode;
343struct VirtualRegisterStorage {
344 VirtualRegisterStorage(ArrayAttr allowedRegs) : allowedRegs(allowedRegs) {}
351 const ArrayAttr allowedRegs;
354struct UniqueLabelStorage {
355 UniqueLabelStorage(StringAttr name) : name(name) {}
361 const StringAttr name;
374 template <
typename StorageTy,
typename... Args>
375 StorageTy *internalize(Args &&...args) {
376 StorageTy storage(std::forward<Args>(args)...);
378 auto existing = getInternSet<StorageTy>().insert_as(
379 HashedStorage<StorageTy>(storage.hashcode), storage);
380 StorageTy *&storagePtr = existing.first->storage;
383 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
388 template <
typename StorageTy,
typename... Args>
389 StorageTy *create(Args &&...args) {
390 return new (allocator.Allocate<StorageTy>())
391 StorageTy(std::forward<Args>(args)...);
395 template <
typename StorageTy>
396 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
398 if constexpr (std::is_same_v<StorageTy, SetStorage>)
400 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
402 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
403 return internedSequences;
404 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
405 return internedRandomizedSequences;
406 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
407 return internedInterleavedSequences;
409 static_assert(!
sizeof(StorageTy),
410 "no intern set available for this storage type.");
415 llvm::BumpPtrAllocator allocator;
420 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
421 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
422 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
424 DenseSet<HashedStorage<RandomizedSequenceStorage>,
425 StorageKeyInfo<RandomizedSequenceStorage>>
426 internedRandomizedSequences;
427 DenseSet<HashedStorage<InterleavedSequenceStorage>,
428 StorageKeyInfo<InterleavedSequenceStorage>>
429 internedInterleavedSequences;
436static llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
437 const ElaboratorValue &value);
439static void print(TypedAttr val, llvm::raw_ostream &os) {
440 os <<
"<attr " << val <<
">";
443static void print(BagStorage *val, llvm::raw_ostream &os) {
445 llvm::interleaveComma(val->bag, os,
446 [&](
const std::pair<ElaboratorValue, uint64_t> &el) {
447 os << el.first <<
" -> " << el.second;
449 os <<
"} at " << val <<
">";
452static void print(
bool val, llvm::raw_ostream &os) {
453 os <<
"<bool " << (val ?
"true" :
"false") <<
">";
456static void print(
size_t val, llvm::raw_ostream &os) {
457 os <<
"<index " << val <<
">";
460static void print(SequenceStorage *val, llvm::raw_ostream &os) {
461 os <<
"<sequence @" << val->familyName.getValue() <<
"(";
462 llvm::interleaveComma(val->args, os,
463 [&](
const ElaboratorValue &val) { os << val; });
464 os <<
") at " << val <<
">";
467static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
468 os <<
"<randomized-sequence @" << val->name <<
" derived from @"
469 << val->sequence->familyName.getValue() <<
" under context "
470 << val->context <<
" in test " << val->test <<
"(";
471 llvm::interleaveComma(val->sequence->args, os,
472 [&](
const ElaboratorValue &val) { os << val; });
473 os <<
") at " << val <<
">";
476static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
477 os <<
"<interleaved-sequence [";
478 llvm::interleaveComma(val->sequences, os,
479 [&](
const ElaboratorValue &val) { os << val; });
480 os <<
"] batch-size " << val->batchSize <<
" at " << val <<
">";
483static void print(SetStorage *val, llvm::raw_ostream &os) {
485 llvm::interleaveComma(val->set, os,
486 [&](
const ElaboratorValue &val) { os << val; });
487 os <<
"} at " << val <<
">";
490static void print(
const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
491 os <<
"<virtual-register " << val <<
" " << val->allowedRegs <<
">";
494static void print(
const UniqueLabelStorage *val, llvm::raw_ostream &os) {
495 os <<
"<unique-label " << val <<
" " << val->name <<
">";
498static void print(
const LabelValue &val, llvm::raw_ostream &os) {
499 os <<
"<label " << val.name <<
">";
503 const ElaboratorValue &value) {
504 std::visit([&](
auto val) {
print(val, os); }, value);
520 Materializer(OpBuilder builder) : builder(builder) {}
524 Value materialize(ElaboratorValue val, Location loc,
525 std::queue<RandomizedSequenceStorage *> &elabRequests,
526 function_ref<InFlightDiagnostic()> emitError) {
527 auto iter = materializedValues.find(val);
528 if (iter != materializedValues.end())
531 LLVM_DEBUG(llvm::dbgs() <<
"Materializing " << val <<
"\n\n");
534 [&](
auto val) {
return visit(val, loc, elabRequests, emitError); },
545 materialize(Operation *op, DenseMap<Value, ElaboratorValue> &state,
546 std::queue<RandomizedSequenceStorage *> &elabRequests) {
547 if (op->getNumRegions() > 0)
548 return op->emitOpError(
"ops with nested regions must be elaborated away");
556 for (
auto res : op->getResults())
557 if (!res.use_empty())
558 return op->emitOpError(
559 "ops with results that have uses are not supported");
561 if (op->getParentRegion() == builder.getBlock()->getParent()) {
564 deleteOpsUntil([&](
auto iter) {
return &*iter == op; });
566 if (builder.getInsertionPoint() == builder.getBlock()->end())
567 return op->emitError(
"operation did not occur after the current "
568 "materializer insertion point");
570 LLVM_DEBUG(llvm::dbgs() <<
"Modifying in-place: " << *op <<
"\n\n");
572 LLVM_DEBUG(llvm::dbgs() <<
"Materializing a clone of " << *op <<
"\n\n");
573 op = builder.clone(*op);
574 builder.setInsertionPoint(op);
577 for (
auto &operand : op->getOpOperands()) {
578 auto emitError = [&]() {
579 auto diag = op->emitError();
580 diag.attachNote(op->getLoc())
581 <<
"while materializing value for operand#"
582 << operand.getOperandNumber();
586 Value val = materialize(state.at(operand.get()), op->getLoc(),
587 elabRequests, emitError);
594 builder.setInsertionPointAfter(op);
601 deleteOpsUntil([](
auto iter) {
return false; });
603 for (
auto *op :
llvm::reverse(toDelete))
607 template <
typename OpTy,
typename... Args>
608 OpTy create(Location location, Args &&...args) {
609 return builder.create<OpTy>(location, std::forward<Args>(args)...);
613 void deleteOpsUntil(function_ref<
bool(Block::iterator)> stop) {
614 auto ip = builder.getInsertionPoint();
615 while (ip != builder.getBlock()->end() && !stop(ip)) {
616 LLVM_DEBUG(llvm::dbgs() <<
"Marking to be deleted: " << *ip <<
"\n\n");
617 toDelete.push_back(&*ip);
619 builder.setInsertionPointAfter(&*ip);
620 ip = builder.getInsertionPoint();
624 Value visit(TypedAttr val, Location loc,
625 std::queue<RandomizedSequenceStorage *> &elabRequests,
626 function_ref<InFlightDiagnostic()> emitError) {
629 if (
auto intAttr = dyn_cast<IntegerAttr>(val);
630 intAttr && isa<IndexType>(val.getType())) {
631 Value res = builder.create<index::ConstantOp>(loc, intAttr);
632 materializedValues[val] = res;
639 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
641 emitError() <<
"materializer of dialect '"
642 << val.getDialect().getNamespace()
643 <<
"' unable to materialize value for attribute '" << val
648 Value res = op->getResult(0);
649 materializedValues[val] = res;
653 Value visit(
size_t val, Location loc,
654 std::queue<RandomizedSequenceStorage *> &elabRequests,
655 function_ref<InFlightDiagnostic()> emitError) {
656 Value res = builder.create<index::ConstantOp>(loc, val);
657 materializedValues[val] = res;
661 Value visit(
bool val, Location loc,
662 std::queue<RandomizedSequenceStorage *> &elabRequests,
663 function_ref<InFlightDiagnostic()> emitError) {
664 Value res = builder.create<index::BoolConstantOp>(loc, val);
665 materializedValues[val] = res;
669 Value visit(SetStorage *val, Location loc,
670 std::queue<RandomizedSequenceStorage *> &elabRequests,
671 function_ref<InFlightDiagnostic()> emitError) {
672 SmallVector<Value> elements;
673 elements.reserve(val->set.size());
674 for (
auto el : val->set) {
675 auto materialized = materialize(el, loc, elabRequests, emitError);
679 elements.push_back(materialized);
682 auto res = builder.create<SetCreateOp>(loc, val->type, elements);
683 materializedValues[val] = res;
687 Value visit(BagStorage *val, Location loc,
688 std::queue<RandomizedSequenceStorage *> &elabRequests,
689 function_ref<InFlightDiagnostic()> emitError) {
690 SmallVector<Value> values, weights;
691 values.reserve(val->bag.size());
692 weights.reserve(val->bag.size());
693 for (
auto [val, weight] : val->bag) {
694 auto materializedVal = materialize(val, loc, elabRequests, emitError);
695 auto materializedWeight =
696 materialize(weight, loc, elabRequests, emitError);
697 if (!materializedVal || !materializedWeight)
700 values.push_back(materializedVal);
701 weights.push_back(materializedWeight);
704 auto res = builder.create<BagCreateOp>(loc, val->type, values, weights);
705 materializedValues[val] = res;
709 Value visit(SequenceStorage *val, Location loc,
710 std::queue<RandomizedSequenceStorage *> &elabRequests,
711 function_ref<InFlightDiagnostic()> emitError) {
712 emitError() <<
"materializing a non-randomized sequence not supported yet";
716 Value visit(RandomizedSequenceStorage *val, Location loc,
717 std::queue<RandomizedSequenceStorage *> &elabRequests,
718 function_ref<InFlightDiagnostic()> emitError) {
719 elabRequests.push(val);
720 Value
seq = builder.create<GetSequenceOp>(
721 loc, SequenceType::get(builder.getContext(), {}), val->name);
722 Value res = builder.create<RandomizeSequenceOp>(loc,
seq);
723 materializedValues[val] = res;
727 Value visit(InterleavedSequenceStorage *val, Location loc,
728 std::queue<RandomizedSequenceStorage *> &elabRequests,
729 function_ref<InFlightDiagnostic()> emitError) {
730 SmallVector<Value> sequences;
731 for (
auto seqVal : val->sequences)
732 sequences.push_back(materialize(seqVal, loc, elabRequests, emitError));
734 if (sequences.size() == 1)
738 builder.create<InterleaveSequencesOp>(loc, sequences, val->batchSize);
739 materializedValues[val] = res;
743 Value visit(VirtualRegisterStorage *val, Location loc,
744 std::queue<RandomizedSequenceStorage *> &elabRequests,
745 function_ref<InFlightDiagnostic()> emitError) {
746 Value res = builder.create<VirtualRegisterOp>(loc, val->allowedRegs);
747 materializedValues[val] = res;
751 Value visit(UniqueLabelStorage *val, Location loc,
752 std::queue<RandomizedSequenceStorage *> &elabRequests,
753 function_ref<InFlightDiagnostic()> emitError) {
754 Value res = builder.create<LabelUniqueDeclOp>(loc, val->name, ValueRange());
755 materializedValues[val] = res;
759 Value visit(
const LabelValue &val, Location loc,
760 std::queue<RandomizedSequenceStorage *> &elabRequests,
761 function_ref<InFlightDiagnostic()> emitError) {
762 Value res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
763 materializedValues[val] = res;
773 DenseMap<ElaboratorValue, Value> materializedValues;
779 SmallVector<Operation *> toDelete;
788enum class DeletionKind { Keep, Delete };
791struct ElaboratorSharedState {
792 ElaboratorSharedState(SymbolTable &table,
unsigned seed)
793 : table(table), rng(seed) {}
798 Internalizer internalizer;
802 std::queue<RandomizedSequenceStorage *> worklist;
812 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
818class Elaborator :
public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
821 using RTGBase::visitOp;
823 Elaborator(ElaboratorSharedState &sharedState, TestState &testState,
824 Materializer &materializer,
825 ContextResourceAttrInterface currentContext = {})
826 : sharedState(sharedState), testState(testState),
827 materializer(materializer), currentContext(currentContext) {}
829 template <
typename ValueTy>
830 inline ValueTy
get(Value val)
const {
831 return std::get<ValueTy>(state.at(val));
834 FailureOr<DeletionKind> visitConstantLike(Operation *op) {
835 assert(op->hasTrait<OpTrait::ConstantLike>() &&
836 "op is expected to be constant-like");
838 SmallVector<OpFoldResult, 1> result;
839 auto foldResult = op->fold(result);
841 assert(succeeded(foldResult) &&
842 "constant folder of a constant-like must always succeed");
843 auto attr = dyn_cast<TypedAttr>(result[0].dyn_cast<Attribute>());
845 return op->emitError(
846 "only typed attributes supported for constant-like operations");
848 auto intAttr = dyn_cast<IntegerAttr>(attr);
849 if (intAttr && isa<IndexType>(attr.getType()))
850 state[op->getResult(0)] = size_t(intAttr.getInt());
851 else if (intAttr && intAttr.getType().isSignlessInteger(1))
852 state[op->getResult(0)] = bool(intAttr.getInt());
854 state[op->getResult(0)] = attr;
856 return DeletionKind::Delete;
861 return op->emitOpError(
"elaboration not supported");
865 if (op->hasTrait<OpTrait::ConstantLike>())
866 return visitConstantLike(op);
872 return DeletionKind::Keep;
877 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
878 SmallVector<ElaboratorValue> replacements;
879 state[op.getResult()] =
880 sharedState.internalizer.internalize<SequenceStorage>(
881 op.getSequenceAttr(), std::move(replacements));
882 return DeletionKind::Delete;
885 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
886 auto *
seq = get<SequenceStorage *>(op.getSequence());
888 SmallVector<ElaboratorValue> replacements(
seq->args);
889 for (
auto replacement : op.getReplacements())
890 replacements.push_back(state.at(replacement));
892 state[op.getResult()] =
893 sharedState.internalizer.internalize<SequenceStorage>(
894 seq->familyName, std::move(replacements));
896 return DeletionKind::Delete;
899 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
900 auto *
seq = get<SequenceStorage *>(op.getSequence());
902 auto name = sharedState.names.newName(
seq->familyName.getValue());
903 auto *randomizedSeq =
904 sharedState.internalizer.internalize<RandomizedSequenceStorage>(
905 name, currentContext, testState.name,
seq);
906 state[op.getResult()] =
907 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
909 return DeletionKind::Delete;
912 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
913 SmallVector<ElaboratorValue> sequences;
914 for (
auto seq : op.getSequences())
915 sequences.push_back(
get<InterleavedSequenceStorage *>(
seq));
917 state[op.getResult()] =
918 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
919 std::move(sequences), op.getBatchSize());
920 return DeletionKind::Delete;
924 LogicalResult isValidContext(ElaboratorValue value, Operation *op)
const {
925 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
926 auto *
seq = std::get<RandomizedSequenceStorage *>(value);
927 if (
seq->context != currentContext) {
928 auto err = op->emitError(
"attempting to place sequence ")
929 <<
seq->name <<
" derived from "
930 <<
seq->sequence->familyName.getValue() <<
" under context "
932 <<
", but it was previously randomized for context ";
942 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
943 for (
auto val : interVal->sequences)
944 if (failed(isValidContext(val, op)))
949 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
950 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
951 if (failed(isValidContext(seqVal, op)))
954 return DeletionKind::Keep;
957 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
958 SetVector<ElaboratorValue> set;
959 for (
auto val : op.getElements())
960 set.insert(state.at(val));
962 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
963 std::move(set), op.getSet().getType());
964 return DeletionKind::Delete;
967 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
968 auto set = get<SetStorage *>(op.getSet())->set;
971 return op->emitError(
"cannot select from an empty set");
975 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
976 std::mt19937 customRng(intAttr.getInt());
982 state[op.getResult()] = set[selected];
983 return DeletionKind::Delete;
986 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
987 auto original = get<SetStorage *>(op.getOriginal())->set;
988 auto diff = get<SetStorage *>(op.getDiff())->set;
990 SetVector<ElaboratorValue> result(original);
991 result.set_subtract(diff);
993 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
994 std::move(result), op.getResult().getType());
995 return DeletionKind::Delete;
998 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
999 SetVector<ElaboratorValue> result;
1000 for (
auto set : op.getSets())
1001 result.set_union(
get<SetStorage *>(set)->set);
1003 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1004 std::move(result), op.getType());
1005 return DeletionKind::Delete;
1008 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1009 auto size = get<SetStorage *>(op.getSet())->set.size();
1010 state[op.getResult()] = size;
1011 return DeletionKind::Delete;
1014 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1015 MapVector<ElaboratorValue, uint64_t> bag;
1016 for (
auto [val, multiple] :
1017 llvm::zip(op.getElements(), op.getMultiples())) {
1021 bag[state.at(val)] += get<size_t>(multiple);
1024 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1025 std::move(bag), op.getType());
1026 return DeletionKind::Delete;
1029 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
1030 auto bag = get<BagStorage *>(op.getBag())->bag;
1033 return op->emitError(
"cannot select from an empty bag");
1035 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1036 prefixSum.reserve(bag.size());
1037 uint32_t accumulator = 0;
1038 for (
auto [val, weight] : bag) {
1039 accumulator += weight;
1040 prefixSum.push_back({val, accumulator});
1043 auto customRng = sharedState.rng;
1045 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1046 customRng = std::mt19937(intAttr.getInt());
1050 auto *iter = llvm::upper_bound(
1052 [](uint32_t a,
const std::pair<ElaboratorValue, uint32_t> &b) {
1053 return a < b.second;
1056 state[op.getResult()] = iter->first;
1057 return DeletionKind::Delete;
1060 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1061 auto original = get<BagStorage *>(op.getOriginal())->bag;
1062 auto diff = get<BagStorage *>(op.getDiff())->bag;
1064 MapVector<ElaboratorValue, uint64_t> result;
1065 for (
const auto &el : original) {
1066 if (!diff.contains(el.first)) {
1074 auto toDiff = diff.lookup(el.first);
1075 if (el.second <= toDiff)
1078 result.insert({el.first, el.second - toDiff});
1081 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1082 std::move(result), op.getType());
1083 return DeletionKind::Delete;
1086 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1087 MapVector<ElaboratorValue, uint64_t> result;
1088 for (
auto bag : op.getBags()) {
1089 auto val = get<BagStorage *>(bag)->bag;
1090 for (
auto [el, multiple] : val)
1091 result[el] += multiple;
1094 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1095 std::move(result), op.getType());
1096 return DeletionKind::Delete;
1099 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1100 auto size = get<BagStorage *>(op.getBag())->bag.size();
1101 state[op.getResult()] = size;
1102 return DeletionKind::Delete;
1105 FailureOr<DeletionKind> visitOp(FixedRegisterOp op) {
1106 return visitConstantLike(op);
1109 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1110 state[op.getResult()] =
1111 sharedState.internalizer.create<VirtualRegisterStorage>(
1112 op.getAllowedRegsAttr());
1113 return DeletionKind::Delete;
1116 StringAttr substituteFormatString(StringAttr formatString,
1117 ValueRange substitutes)
const {
1118 if (substitutes.empty() || formatString.empty())
1119 return formatString;
1121 auto original = formatString.getValue().str();
1122 for (
auto [i, subst] :
llvm::enumerate(substitutes)) {
1123 size_t startPos = 0;
1124 std::string from =
"{{" + std::to_string(i) +
"}}";
1125 while ((startPos = original.find(from, startPos)) != std::string::npos) {
1126 auto substString = std::to_string(get<size_t>(subst));
1127 original.replace(startPos, from.length(), substString);
1131 return StringAttr::get(formatString.getContext(), original);
1134 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
1136 substituteFormatString(op.getFormatStringAttr(), op.getArgs());
1137 state[op.getLabel()] = LabelValue(substituted);
1138 return DeletionKind::Delete;
1141 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1142 state[op.getLabel()] = sharedState.internalizer.create<UniqueLabelStorage>(
1143 substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
1144 return DeletionKind::Delete;
1147 FailureOr<DeletionKind> visitOp(LabelOp op) {
return DeletionKind::Keep; }
1149 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1150 size_t lower = get<size_t>(op.getLowerBound());
1151 size_t upper = get<size_t>(op.getUpperBound()) - 1;
1153 return op->emitError(
"cannot select a number from an empty range");
1156 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1157 std::mt19937 customRng(intAttr.getInt());
1158 state[op.getResult()] =
1161 state[op.getResult()] =
1165 return DeletionKind::Delete;
1168 FailureOr<DeletionKind> visitOp(OnContextOp op) {
1169 ContextResourceAttrInterface from = currentContext,
1170 to = cast<ContextResourceAttrInterface>(
1171 get<TypedAttr>(op.getContext()));
1172 if (!currentContext)
1173 from = DefaultContextAttr::get(op->getContext(), to.getType());
1175 auto emitError = [&]() {
1176 auto diag = op.emitError();
1177 diag.attachNote(op.getLoc())
1178 <<
"while materializing value for context switching for " << op;
1183 Value seqVal = materializer.materialize(
1184 get<SequenceStorage *>(op.getSequence()), op.getLoc(),
1185 sharedState.worklist, emitError);
1187 materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1188 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1189 return DeletionKind::Delete;
1193 auto *iter = testState.contextSwitches.find({from, to});
1196 if (iter == testState.contextSwitches.end())
1197 return op->emitError(
"no context transition registered to switch from ")
1198 << from <<
" to " << to;
1200 auto familyName = iter->second->familyName;
1201 SmallVector<ElaboratorValue> args{from, to,
1202 get<SequenceStorage *>(op.getSequence())};
1203 auto *
seq = sharedState.internalizer.internalize<SequenceStorage>(
1204 familyName, std::move(args));
1206 sharedState.internalizer.internalize<RandomizedSequenceStorage>(
1207 sharedState.names.newName(familyName.getValue()), to,
1208 testState.name,
seq);
1209 Value seqVal = materializer.materialize(randSeq, op.getLoc(),
1210 sharedState.worklist, emitError);
1211 materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
1213 return DeletionKind::Delete;
1216 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1217 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1218 get<SequenceStorage *>(op.getSequence());
1219 return DeletionKind::Delete;
1222 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1223 bool cond = get<bool>(op.getCondition());
1224 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1225 if (toElaborate.empty())
1226 return DeletionKind::Delete;
1232 if (failed(elaborate(toElaborate)))
1236 for (
auto [res, out] :
1237 llvm::zip(op.getResults(),
1238 toElaborate.front().getTerminator()->getOperands()))
1239 state[res] = state.at(out);
1241 return DeletionKind::Delete;
1244 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1245 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1246 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1247 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1248 return op->emitOpError(
"can only elaborate index type iterator");
1250 auto lowerBound = get<size_t>(op.getLowerBound());
1251 auto step = get<size_t>(op.getStep());
1252 auto upperBound = get<size_t>(op.getUpperBound());
1258 state[op.getInductionVar()] = lowerBound;
1259 for (
auto [iterArg, initArg] :
1260 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1261 state[iterArg] = state.at(initArg);
1264 for (
size_t i = lowerBound; i < upperBound; i += step) {
1265 if (failed(elaborate(op.getBodyRegion())))
1270 state[op.getInductionVar()] = i + step;
1271 for (
auto [iterArg, prevIterArg] :
1272 llvm::zip(op.getRegionIterArgs(),
1273 op.getBody()->getTerminator()->getOperands()))
1274 state[iterArg] = state.at(prevIterArg);
1278 for (
auto [res, iterArg] :
1279 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1280 state[res] = state.at(iterArg);
1282 return DeletionKind::Delete;
1285 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1286 return DeletionKind::Delete;
1289 FailureOr<DeletionKind> visitOp(index::AddOp op) {
1290 size_t lhs = get<size_t>(op.getLhs());
1291 size_t rhs = get<size_t>(op.getRhs());
1292 state[op.getResult()] = lhs + rhs;
1293 return DeletionKind::Delete;
1296 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
1297 size_t lhs = get<size_t>(op.getLhs());
1298 size_t rhs = get<size_t>(op.getRhs());
1300 switch (op.getPred()) {
1301 case index::IndexCmpPredicate::EQ:
1302 result = lhs == rhs;
1304 case index::IndexCmpPredicate::NE:
1305 result = lhs != rhs;
1307 case index::IndexCmpPredicate::ULT:
1310 case index::IndexCmpPredicate::ULE:
1311 result = lhs <= rhs;
1313 case index::IndexCmpPredicate::UGT:
1316 case index::IndexCmpPredicate::UGE:
1317 result = lhs >= rhs;
1320 return op->emitOpError(
"elaboration not supported");
1322 state[op.getResult()] = result;
1323 return DeletionKind::Delete;
1327 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
1330 index::AddOp, index::CmpOp,
1332 scf::IfOp, scf::ForOp, scf::YieldOp>(
1333 [&](
auto op) {
return visitOp(op); })
1334 .Default([&](Operation *op) {
return RTGBase::dispatchOpVisitor(op); });
/// Elaborates the operations of a single-block region.
///
/// \param region           Region to elaborate. A region with more than one
///                         block is rejected with an op error emitted on the
///                         region's parent operation.
/// \param regionArguments  Elaborated values bound to the region's block
///                         arguments, paired positionally. Defaults to none.
/// \returns the diagnostic result for a multi-block region. The remaining
///          return statements are not visible in this listing.
// NOTE(review): the line numbering has gaps in this listing. The statements
// that define `result` (1350-1353), the failure returns and the closing lines
// of the debug block are not visible here — verify against the complete file.
1338 LogicalResult elaborate(Region &region,
1339 ArrayRef<ElaboratorValue> regionArguments = {}) {
1340 if (region.getBlocks().size() > 1)
1341 return region.getParentOp()->emitOpError(
1342 "regions with more than one block are not supported");
// Seed the state: each block argument maps to the value supplied by the caller.
1344 for (
auto [arg, elabArg] :
1345 llvm::zip(region.getArguments(), regionArguments))
1346 state[arg] = elabArg;
1348 Block *block = &region.front();
1349 for (
auto &op : *block) {
// Ops whose visitor answered DeletionKind::Keep are handed to the
// materializer together with the current state and the shared worklist.
1354 if (*result == DeletionKind::Keep)
1355 if (failed(materializer.materialize(&op, state, sharedState.worklist)))
// Debug trace: print the op followed by the elaborated value of each result,
// using "unknown" for results that have no entry in `state`.
1359 llvm::dbgs() <<
"Elaborated " << op <<
" to\n[";
1361 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](
auto res) {
1362 if (state.contains(res))
1363 llvm::dbgs() << state.at(res);
1365 llvm::dbgs() <<
"unknown";
1368 llvm::dbgs() <<
"]\n\n";
// State shared between elaborator instances; its `worklist` is passed to the
// materializer when an op is kept.
1377 ElaboratorSharedState &sharedState;
// NOTE(review): presumably per-test bookkeeping supplied by the caller (the
// pass keeps one TestState per test symbol name) — confirm.
1380 TestState &testState;
// Receives the operations that have to be kept in the output.
1384 Materializer &materializer;
// Maps each IR value visited so far to its elaborated value.
1387 DenseMap<Value, ElaboratorValue> state;
// NOTE(review): presumably the context resource this elaborator is currently
// operating under — its uses are outside this excerpt, confirm.
1390 ContextResourceAttrInterface currentContext;
/// The elaboration pass, built on the generated ElaborationPassBase.
/// runOnOperation() is the pass entry point; cloneTargetsIntoTests() and
/// elaborateModule() are the two steps it runs.
// NOTE(review): lines 1401-1402 and the closing of the struct are not visible
// in this listing.
1399struct ElaborationPass
1400 :
public rtg::impl::ElaborationPassBase<ElaborationPass> {
1403 void runOnOperation()
override;
1404 void cloneTargetsIntoTests(SymbolTable &table);
1405 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
/// Pass entry point: builds a symbol table for the operation the pass runs on,
/// clones targets into the tests, then elaborates the module. The pass is
/// marked as failed when elaboration fails.
1409void ElaborationPass::runOnOperation() {
1410 auto moduleOp = getOperation();
1411 SymbolTable table(moduleOp);
1413 cloneTargetsIntoTests(table);
1415 if (failed(elaborateModule(moduleOp, table)))
1416 return signalPassFailure();
/// Specializes tests for the targets in the module. For every (target, test)
/// pair that passes the two checks below, the test is cloned, renamed to
/// "<test>_<target>", and registered in `table`. The target's body (without
/// its terminator) is cloned to the start of the new test's body. The new
/// test's block arguments are replaced by the cloned values that correspond to
/// the target terminator's operands. Finally the arguments are erased and the
/// new test's target type is set to an empty DictType.
///
/// \param table  Symbol table of the module; the cloned tests are inserted
///               into it.
// NOTE(review): the line numbering has gaps in this listing. The bodies of the
// two `if` checks (presumably `continue`), the declaration of `mapping`,
// lines 1453-1458 and the body of the final loop are not visible here —
// verify against the complete file.
1419void ElaborationPass::cloneTargetsIntoTests(SymbolTable &table) {
1420 auto moduleOp = getOperation();
1421 for (
auto target :
llvm::make_early_inc_range(moduleOp.getOps<TargetOp>())) {
1422 for (
auto test : moduleOp.getOps<TestOp>()) {
// Check 1: the test's target dictionary has no entries.
1424 if (test.getTarget().getEntries().empty())
// Check 2: the target's type differs from the type the test asks for.
1429 if (target.getTarget() != test.getTarget())
1432 IRRewriter rewriter(test);
1434 auto newTest = cast<TestOp>(test->clone());
1435 newTest.setSymName(test.getSymName().str() +
"_" +
1436 target.getSymName().str());
1437 table.insert(newTest, rewriter.getInsertionPoint());
// Copy the target's operations in front of the cloned test body.
1441 rewriter.setInsertionPointToStart(newTest.getBody());
1442 for (
auto &op : target.getBody()->without_terminator())
1443 rewriter.clone(op, mapping);
// Each block argument of the new test is replaced by the clone of the value
// the target's terminator holds at the same position.
1445 for (
auto [returnVal, result] :
1446 llvm::zip(target.getBody()->getTerminator()->getOperands(),
1447 newTest.getBody()->getArguments()))
1448 result.replaceAllUsesWith(mapping.lookup(returnVal));
// All arguments are now unused: drop them and clear the target type.
1450 newTest.getBody()->eraseArguments(0,
1451 newTest.getBody()->getNumArguments());
1452 newTest.setTarget(DictType::get(&getContext(), {}));
// Second sweep over the tests that still have target entries. The action taken
// on them is not visible in this listing.
1459 for (
auto test :
llvm::make_early_inc_range(moduleOp.getOps<TestOp>()))
1460 if (!test.getTarget().getEntries().
empty())
/// Elaborates every test in `moduleOp`, then drains the worklist of sequences.
///
/// Phase 1: for each TestOp, a Materializer positioned at the beginning of the
/// test body and a per-test TestState (keyed by the test's symbol name) are
/// created, the test's body region is elaborated, and the materializer is
/// finalized.
///
/// Phase 2: while the shared worklist is non-empty, the front entry is popped.
/// Its sequence family is looked up in `table` and cloned without regions. The
/// clone gets a fresh empty block, the entry's name, and a sequence type with
/// no element types, and is inserted into `table`. The family's body region is
/// then elaborated into the clone with the entry's arguments.
///
/// \param moduleOp  Module whose tests are elaborated.
/// \param table     Symbol table used for lookups and for registering the
///                  newly created sequences.
// NOTE(review): the line numbering has gaps in this listing. The trailing
// constructor arguments of the Elaborator objects, the failure returns, the
// action taken when the sequence already exists (presumably `continue`) and
// the final return are not visible here — verify against the complete file.
// `seed` is not declared in this excerpt; presumably it is a pass option.
1464LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
1465 SymbolTable &table) {
1466 ElaboratorSharedState state(table, seed);
1469 state.names.add(moduleOp);
// One TestState per test, keyed by the test's symbol name.
1473 DenseMap<StringAttr, TestState> testStates;
1474 for (
auto testOp : moduleOp.getOps<TestOp>()) {
1475 LLVM_DEBUG(llvm::dbgs()
1476 <<
"\n=== Elaborating test @" << testOp.getSymName() <<
"\n\n");
1477 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()));
1478 testStates[testOp.getSymNameAttr()].name = testOp.getSymNameAttr();
1479 Elaborator elaborator(state, testStates[testOp.getSymNameAttr()],
1481 if (failed(elaborator.elaborate(testOp.getBodyRegion())))
1484 materializer.finalize();
// Process the sequences queued while elaborating.
1489 while (!state.worklist.empty()) {
1490 auto *curr = state.worklist.front();
1491 state.worklist.pop();
// A sequence op with this name is already present in the symbol table.
1493 if (table.lookup<SequenceOp>(curr->name))
1496 auto familyOp = table.lookup<SequenceOp>(curr->sequence->familyName);
// Create the specialized sequence next to its family: same op without the
// regions, one empty block, the queued name, and no element types.
1499 OpBuilder builder(familyOp);
1500 auto seqOp = builder.cloneWithoutRegions(familyOp);
1501 seqOp.getBodyRegion().emplaceBlock();
1502 seqOp.setSymName(curr->name);
1503 seqOp.setSequenceType(
1504 SequenceType::get(builder.getContext(), ArrayRef<Type>{}));
1505 table.insert(seqOp);
1506 assert(seqOp.getSymName() == curr->name &&
"should not have been renamed");
1508 LLVM_DEBUG(llvm::dbgs()
1509 <<
"\n=== Elaborating sequence family @" << familyOp.getSymName()
1510 <<
" into @" << seqOp.getSymName() <<
" under context "
1511 << curr->context <<
"\n\n");
// Elaborate the family's body into the new sequence, binding the family's
// block arguments to the queued argument values.
1513 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()));
1514 Elaborator elaborator(state, testStates[curr->test], materializer,
1516 if (failed(elaborator.elaborate(familyOp.getBodyRegion(),
1517 curr->sequence->args)))
1520 materializer.finalize();
assert(baseType &&"element must be base type")
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the specified range.
static void print(TypedAttr val, llvm::raw_ostream &os)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
This helps visit TypeOp nodes.
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
ResultType dispatchOpVisitor(Operation *op, ExtraArgs... args)
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
static llvm::hash_code hash_value(const ModulePort &port)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static LabelValue getEmptyKey()
static LabelValue getTombstoneKey()
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)
static unsigned getEmptyKey()