21#include "mlir/Dialect/Index/IR/IndexDialect.h"
22#include "mlir/Dialect/Index/IR/IndexOps.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/IR/IRMapping.h"
25#include "mlir/IR/PatternMatch.h"
26#include "llvm/ADT/DenseMapInfoVariant.h"
27#include "llvm/Support/Debug.h"
33#define GEN_PASS_DEF_ELABORATIONPASS
34#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
43#define DEBUG_TYPE "rtg-elaboration"
55 size_t n = w / 32 + (w % 32 != 0);
57 return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
62 const uint32_t diff = b - a + 1;
66 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
70 uint32_t width = digits - llvm::countl_zero(diff) - 1;
71 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
90struct SequenceStorage;
91struct RandomizedSequenceStorage;
92struct InterleavedSequenceStorage;
94struct VirtualRegisterStorage;
95struct UniqueLabelStorage;
101 LabelValue(StringAttr name) : name(name) {}
103 bool operator==(
const LabelValue &other)
const {
return name == other.name; }
// The set of values the elaborator can compute for an SSA value: plain
// attributes, index values (size_t), i1 values (bool), label names, and
// pointers to storage objects owned by the internalizer's allocator.
// The order of the alternatives is significant: hash_value(ElaboratorValue)
// mixes in `val.index()`, so reordering them changes the hash of every value.
110using ElaboratorValue =
111 std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
112 RandomizedSequenceStorage *, InterleavedSequenceStorage *,
113 SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
114 LabelValue, ArrayStorage *>;
117llvm::hash_code
hash_value(
const LabelValue &val) {
118 return llvm::hash_value(val.name);
122llvm::hash_code
hash_value(
const ElaboratorValue &val) {
124 [&val](
const auto &alternative) {
127 return llvm::hash_combine(val.index(), alternative);
142 static bool isEqual(
const bool &lhs,
const bool &rhs) {
return lhs == rhs; }
156 static bool isEqual(
const LabelValue &lhs,
const LabelValue &rhs) {
156 static bool isEqual(
const LabelValue &lhs,
const LabelValue &rhs) {
…}
173template <
typename StorageTy>
174struct HashedStorage {
175 HashedStorage(
unsigned hashcode = 0, StorageTy *storage =
nullptr)
176 : hashcode(hashcode), storage(storage) {}
186template <
typename StorageTy>
187struct StorageKeyInfo {
188 static inline HashedStorage<StorageTy> getEmptyKey() {
189 return HashedStorage<StorageTy>(0,
190 DenseMapInfo<StorageTy *>::getEmptyKey());
192 static inline HashedStorage<StorageTy> getTombstoneKey() {
193 return HashedStorage<StorageTy>(
194 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
197 static inline unsigned getHashValue(
const HashedStorage<StorageTy> &key) {
200 static inline unsigned getHashValue(
const StorageTy &key) {
204 static inline bool isEqual(
const HashedStorage<StorageTy> &lhs,
205 const HashedStorage<StorageTy> &rhs) {
206 return lhs.storage == rhs.storage;
208 static inline bool isEqual(
const StorageTy &lhs,
209 const HashedStorage<StorageTy> &rhs) {
210 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
213 return lhs.isEqual(rhs.storage);
219 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
221 type,
llvm::hash_combine_range(set.begin(), set.
end()))),
222 set(std::move(set)), type(type) {}
224 bool isEqual(
const SetStorage *other)
const {
225 return hashcode == other->hashcode && set == other->set &&
230 const unsigned hashcode;
233 const SetVector<ElaboratorValue> set;
242 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
244 type,
llvm::hash_combine_range(bag.begin(), bag.
end()))),
245 bag(std::move(bag)), type(type) {}
247 bool isEqual(
const BagStorage *other)
const {
248 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
253 const unsigned hashcode;
257 const MapVector<ElaboratorValue, uint64_t> bag;
265struct SequenceStorage {
266 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
268 familyName,
llvm::hash_combine_range(args.begin(), args.
end()))),
269 familyName(familyName), args(std::move(args)) {}
271 bool isEqual(
const SequenceStorage *other)
const {
272 return hashcode == other->hashcode && familyName == other->familyName &&
277 const unsigned hashcode;
280 const StringAttr familyName;
283 const SmallVector<ElaboratorValue> args;
287struct RandomizedSequenceStorage {
288 RandomizedSequenceStorage(StringRef name,
289 ContextResourceAttrInterface context,
290 StringAttr test, SequenceStorage *sequence)
292 context(context), test(test), sequence(sequence) {}
294 bool isEqual(
const RandomizedSequenceStorage *other)
const {
295 return hashcode == other->hashcode && name == other->name &&
296 context == other->context && test == other->test &&
297 sequence == other->sequence;
301 const unsigned hashcode;
304 const StringRef name;
307 const ContextResourceAttrInterface context;
310 const StringAttr test;
312 const SequenceStorage *sequence;
// Interned storage for a group of sequences to be interleaved, together with
// the batch size used for the interleaving.
316struct InterleavedSequenceStorage {
// NOTE(review): hashing bug in this constructor. Members are initialized in
// declaration order (`sequences`, then `batchSize`, then `hashcode`).
// Inside a mem-initializer, the name `sequences` refers to the constructor
// parameter, not the member. `sequences(std::move(sequences))` therefore moves
// out of the parameter before the
// `hash_combine_range(sequences.begin(), sequences.end())` below is evaluated.
// As a result the sequence elements do not contribute to the hash: LLVM's
// SmallVector is left empty after a move (verify).
// The explicit constructor below hashes the already-initialized member instead.
// So storages with identical `sequences`/`batchSize` built through different
// constructors get different hashcodes and are interned as distinct objects.
// Suggested fix: hash `this->sequences` in this constructor.
317 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
319 : sequences(std::move(sequences)), batchSize(batchSize),
321 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
// Wraps a single randomized sequence with a batch size of 1. The parameter is
// named `sequence`, so `sequences` (and `batchSize`) in the hash below name
// the members, which are already initialized at that point.
324 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
325 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
327 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
// Equal iff the hashcode, the sequence list, and the batch size all match.
330 bool isEqual(
const InterleavedSequenceStorage *other)
const {
331 return hashcode == other->hashcode && sequences == other->sequences &&
332 batchSize == other->batchSize;
// Declaration order determines initialization order: `hashcode` is declared
// last, so it is initialized after `sequences` and `batchSize`.
335 const SmallVector<ElaboratorValue> sequences;
337 const uint32_t batchSize;
340 const unsigned hashcode;
344struct VirtualRegisterStorage {
345 VirtualRegisterStorage(ArrayAttr allowedRegs) : allowedRegs(allowedRegs) {}
352 const ArrayAttr allowedRegs;
355struct UniqueLabelStorage {
356 UniqueLabelStorage(StringAttr name) : name(name) {}
362 const StringAttr name;
// NOTE(review): `array` is a named rvalue reference and therefore an lvalue.
// So `array(array)` below copies the vector instead of moving it; the Set,
// Bag, and Sequence storages all use `std::move` here. Moving would be safe:
// `hashcode` is declared before `array`, so the hash is computed from the
// parameter before it would be moved from. Suggested: `array(std::move(array))`.
367 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
369 type,
llvm::hash_combine_range(array.begin(), array.
end()))),
370 type(type), array(array) {}
// Equal iff the hashcode, the array type, and the elements all match.
372 bool isEqual(
const ArrayStorage *other)
const {
373 return hashcode == other->hashcode && type == other->type &&
374 array == other->array;
378 const unsigned hashcode;
385 const SmallVector<ElaboratorValue> array;
398 template <
typename StorageTy,
typename... Args>
399 StorageTy *internalize(Args &&...args) {
400 StorageTy storage(std::forward<Args>(args)...);
402 auto existing = getInternSet<StorageTy>().insert_as(
403 HashedStorage<StorageTy>(storage.hashcode), storage);
404 StorageTy *&storagePtr = existing.first->storage;
407 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
412 template <
typename StorageTy,
typename... Args>
413 StorageTy *create(Args &&...args) {
414 return new (allocator.Allocate<StorageTy>())
415 StorageTy(std::forward<Args>(args)...);
419 template <
typename StorageTy>
420 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
422 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
423 return internedArrays;
424 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
426 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
428 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
429 return internedSequences;
430 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
431 return internedRandomizedSequences;
432 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
433 return internedInterleavedSequences;
435 static_assert(!
sizeof(StorageTy),
436 "no intern set available for this storage type.");
441 llvm::BumpPtrAllocator allocator;
446 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
448 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
449 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
450 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
452 DenseSet<HashedStorage<RandomizedSequenceStorage>,
453 StorageKeyInfo<RandomizedSequenceStorage>>
454 internedRandomizedSequences;
455 DenseSet<HashedStorage<InterleavedSequenceStorage>,
456 StorageKeyInfo<InterleavedSequenceStorage>>
457 internedInterleavedSequences;
464static llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
465 const ElaboratorValue &value);
467static void print(TypedAttr val, llvm::raw_ostream &os) {
468 os <<
"<attr " << val <<
">";
467static void print(TypedAttr val, llvm::raw_ostream &os) {
…}
471static void print(BagStorage *val, llvm::raw_ostream &os) {
473 llvm::interleaveComma(val->bag, os,
474 [&](
const std::pair<ElaboratorValue, uint64_t> &el) {
475 os << el.first <<
" -> " << el.second;
477 os <<
"} at " << val <<
">";
471static void print(BagStorage *val, llvm::raw_ostream &os) {
…}
480static void print(
bool val, llvm::raw_ostream &os) {
481 os <<
"<bool " << (val ?
"true" :
"false") <<
">";
480static void print(
bool val, llvm::raw_ostream &os) {
…}
484static void print(
size_t val, llvm::raw_ostream &os) {
485 os <<
"<index " << val <<
">";
484static void print(
size_t val, llvm::raw_ostream &os) {
…}
488static void print(SequenceStorage *val, llvm::raw_ostream &os) {
489 os <<
"<sequence @" << val->familyName.getValue() <<
"(";
490 llvm::interleaveComma(val->args, os,
491 [&](
const ElaboratorValue &val) { os << val; });
492 os <<
") at " << val <<
">";
488static void print(SequenceStorage *val, llvm::raw_ostream &os) {
…}
495static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
496 os <<
"<randomized-sequence @" << val->name <<
" derived from @"
497 << val->sequence->familyName.getValue() <<
" under context "
498 << val->context <<
" in test " << val->test <<
"(";
499 llvm::interleaveComma(val->sequence->args, os,
500 [&](
const ElaboratorValue &val) { os << val; });
501 os <<
") at " << val <<
">";
495static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
…}
504static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
505 os <<
"<interleaved-sequence [";
506 llvm::interleaveComma(val->sequences, os,
507 [&](
const ElaboratorValue &val) { os << val; });
508 os <<
"] batch-size " << val->batchSize <<
" at " << val <<
">";
504static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
…}
511static void print(ArrayStorage *val, llvm::raw_ostream &os) {
513 llvm::interleaveComma(val->array, os,
514 [&](
const ElaboratorValue &val) { os << val; });
515 os <<
"] at " << val <<
">";
511static void print(ArrayStorage *val, llvm::raw_ostream &os) {
…}
518static void print(SetStorage *val, llvm::raw_ostream &os) {
520 llvm::interleaveComma(val->set, os,
521 [&](
const ElaboratorValue &val) { os << val; });
522 os <<
"} at " << val <<
">";
518static void print(SetStorage *val, llvm::raw_ostream &os) {
…}
525static void print(
const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
526 os <<
"<virtual-register " << val <<
" " << val->allowedRegs <<
">";
525static void print(
const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
…}
529static void print(
const UniqueLabelStorage *val, llvm::raw_ostream &os) {
530 os <<
"<unique-label " << val <<
" " << val->name <<
">";
529static void print(
const UniqueLabelStorage *val, llvm::raw_ostream &os) {
…}
533static void print(
const LabelValue &val, llvm::raw_ostream &os) {
534 os <<
"<label " << val.name <<
">";
533static void print(
const LabelValue &val, llvm::raw_ostream &os) {
…}
538 const ElaboratorValue &value) {
539 std::visit([&](
auto val) {
print(val, os); }, value);
555 Materializer(OpBuilder builder) : builder(builder) {}
559 Value materialize(ElaboratorValue val, Location loc,
560 std::queue<RandomizedSequenceStorage *> &elabRequests,
561 function_ref<InFlightDiagnostic()> emitError) {
562 auto iter = materializedValues.find(val);
563 if (iter != materializedValues.end())
566 LLVM_DEBUG(llvm::dbgs() <<
"Materializing " << val <<
"\n\n");
569 [&](
auto val) {
return visit(val, loc, elabRequests, emitError); },
580 materialize(Operation *op, DenseMap<Value, ElaboratorValue> &state,
581 std::queue<RandomizedSequenceStorage *> &elabRequests) {
582 if (op->getNumRegions() > 0)
583 return op->emitOpError(
"ops with nested regions must be elaborated away");
591 for (
auto res : op->getResults())
592 if (!res.use_empty())
593 return op->emitOpError(
594 "ops with results that have uses are not supported");
596 if (op->getParentRegion() == builder.getBlock()->getParent()) {
599 deleteOpsUntil([&](
auto iter) {
return &*iter == op; });
601 if (builder.getInsertionPoint() == builder.getBlock()->end())
602 return op->emitError(
"operation did not occur after the current "
603 "materializer insertion point");
605 LLVM_DEBUG(llvm::dbgs() <<
"Modifying in-place: " << *op <<
"\n\n");
607 LLVM_DEBUG(llvm::dbgs() <<
"Materializing a clone of " << *op <<
"\n\n");
608 op = builder.clone(*op);
609 builder.setInsertionPoint(op);
612 for (
auto &operand : op->getOpOperands()) {
613 auto emitError = [&]() {
614 auto diag = op->emitError();
615 diag.attachNote(op->getLoc())
616 <<
"while materializing value for operand#"
617 << operand.getOperandNumber();
621 Value val = materialize(state.at(operand.get()), op->getLoc(),
622 elabRequests, emitError);
629 builder.setInsertionPointAfter(op);
636 deleteOpsUntil([](
auto iter) {
return false; });
638 for (
auto *op :
llvm::reverse(toDelete))
642 template <
typename OpTy,
typename... Args>
643 OpTy create(Location location, Args &&...args) {
644 return builder.create<OpTy>(location, std::forward<Args>(args)...);
648 void deleteOpsUntil(function_ref<
bool(Block::iterator)> stop) {
649 auto ip = builder.getInsertionPoint();
650 while (ip != builder.getBlock()->end() && !stop(ip)) {
651 LLVM_DEBUG(llvm::dbgs() <<
"Marking to be deleted: " << *ip <<
"\n\n");
652 toDelete.push_back(&*ip);
654 builder.setInsertionPointAfter(&*ip);
655 ip = builder.getInsertionPoint();
659 Value visit(TypedAttr val, Location loc,
660 std::queue<RandomizedSequenceStorage *> &elabRequests,
661 function_ref<InFlightDiagnostic()> emitError) {
664 if (
auto intAttr = dyn_cast<IntegerAttr>(val);
665 intAttr && isa<IndexType>(val.getType())) {
666 Value res = builder.create<index::ConstantOp>(loc, intAttr);
667 materializedValues[val] = res;
674 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
676 emitError() <<
"materializer of dialect '"
677 << val.getDialect().getNamespace()
678 <<
"' unable to materialize value for attribute '" << val
683 Value res = op->getResult(0);
684 materializedValues[val] = res;
688 Value visit(
size_t val, Location loc,
689 std::queue<RandomizedSequenceStorage *> &elabRequests,
690 function_ref<InFlightDiagnostic()> emitError) {
691 Value res = builder.create<index::ConstantOp>(loc, val);
692 materializedValues[val] = res;
696 Value visit(
bool val, Location loc,
697 std::queue<RandomizedSequenceStorage *> &elabRequests,
698 function_ref<InFlightDiagnostic()> emitError) {
699 Value res = builder.create<index::BoolConstantOp>(loc, val);
700 materializedValues[val] = res;
704 Value visit(ArrayStorage *val, Location loc,
705 std::queue<RandomizedSequenceStorage *> &elabRequests,
706 function_ref<InFlightDiagnostic()> emitError) {
707 SmallVector<Value> elements;
708 elements.reserve(val->array.size());
709 for (
auto el : val->array) {
710 auto materialized = materialize(el, loc, elabRequests, emitError);
714 elements.push_back(materialized);
717 auto res = builder.create<ArrayCreateOp>(loc, val->type, elements);
718 materializedValues[val] = res;
722 Value visit(SetStorage *val, Location loc,
723 std::queue<RandomizedSequenceStorage *> &elabRequests,
724 function_ref<InFlightDiagnostic()> emitError) {
725 SmallVector<Value> elements;
726 elements.reserve(val->set.size());
727 for (
auto el : val->set) {
728 auto materialized = materialize(el, loc, elabRequests, emitError);
732 elements.push_back(materialized);
735 auto res = builder.create<SetCreateOp>(loc, val->type, elements);
736 materializedValues[val] = res;
740 Value visit(BagStorage *val, Location loc,
741 std::queue<RandomizedSequenceStorage *> &elabRequests,
742 function_ref<InFlightDiagnostic()> emitError) {
743 SmallVector<Value> values, weights;
744 values.reserve(val->bag.size());
745 weights.reserve(val->bag.size());
746 for (
auto [val, weight] : val->bag) {
747 auto materializedVal = materialize(val, loc, elabRequests, emitError);
748 auto materializedWeight =
749 materialize(weight, loc, elabRequests, emitError);
750 if (!materializedVal || !materializedWeight)
753 values.push_back(materializedVal);
754 weights.push_back(materializedWeight);
757 auto res = builder.create<BagCreateOp>(loc, val->type, values, weights);
758 materializedValues[val] = res;
762 Value visit(SequenceStorage *val, Location loc,
763 std::queue<RandomizedSequenceStorage *> &elabRequests,
764 function_ref<InFlightDiagnostic()> emitError) {
765 emitError() <<
"materializing a non-randomized sequence not supported yet";
769 Value visit(RandomizedSequenceStorage *val, Location loc,
770 std::queue<RandomizedSequenceStorage *> &elabRequests,
771 function_ref<InFlightDiagnostic()> emitError) {
772 elabRequests.push(val);
773 Value
seq = builder.create<GetSequenceOp>(
774 loc, SequenceType::get(builder.getContext(), {}), val->name);
775 Value res = builder.create<RandomizeSequenceOp>(loc,
seq);
776 materializedValues[val] = res;
780 Value visit(InterleavedSequenceStorage *val, Location loc,
781 std::queue<RandomizedSequenceStorage *> &elabRequests,
782 function_ref<InFlightDiagnostic()> emitError) {
783 SmallVector<Value> sequences;
784 for (
auto seqVal : val->sequences)
785 sequences.push_back(materialize(seqVal, loc, elabRequests, emitError));
787 if (sequences.size() == 1)
791 builder.create<InterleaveSequencesOp>(loc, sequences, val->batchSize);
792 materializedValues[val] = res;
796 Value visit(VirtualRegisterStorage *val, Location loc,
797 std::queue<RandomizedSequenceStorage *> &elabRequests,
798 function_ref<InFlightDiagnostic()> emitError) {
799 Value res = builder.create<VirtualRegisterOp>(loc, val->allowedRegs);
800 materializedValues[val] = res;
804 Value visit(UniqueLabelStorage *val, Location loc,
805 std::queue<RandomizedSequenceStorage *> &elabRequests,
806 function_ref<InFlightDiagnostic()> emitError) {
807 Value res = builder.create<LabelUniqueDeclOp>(loc, val->name, ValueRange());
808 materializedValues[val] = res;
812 Value visit(
const LabelValue &val, Location loc,
813 std::queue<RandomizedSequenceStorage *> &elabRequests,
814 function_ref<InFlightDiagnostic()> emitError) {
815 Value res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
816 materializedValues[val] = res;
826 DenseMap<ElaboratorValue, Value> materializedValues;
832 SmallVector<Operation *> toDelete;
841enum class DeletionKind { Keep, Delete };
844struct ElaboratorSharedState {
845 ElaboratorSharedState(SymbolTable &table,
unsigned seed)
846 : table(table), rng(seed) {}
851 Internalizer internalizer;
855 std::queue<RandomizedSequenceStorage *> worklist;
865 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
871class Elaborator :
public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
874 using RTGBase::visitOp;
876 Elaborator(ElaboratorSharedState &sharedState, TestState &testState,
877 Materializer &materializer,
878 ContextResourceAttrInterface currentContext = {})
879 : sharedState(sharedState), testState(testState),
880 materializer(materializer), currentContext(currentContext) {}
882 template <
typename ValueTy>
883 inline ValueTy
get(Value val)
const {
884 return std::get<ValueTy>(state.at(val));
887 FailureOr<DeletionKind> visitConstantLike(Operation *op) {
888 assert(op->hasTrait<OpTrait::ConstantLike>() &&
889 "op is expected to be constant-like");
891 SmallVector<OpFoldResult, 1> result;
892 auto foldResult = op->fold(result);
894 assert(succeeded(foldResult) &&
895 "constant folder of a constant-like must always succeed");
896 auto attr = dyn_cast<TypedAttr>(result[0].dyn_cast<Attribute>());
898 return op->emitError(
899 "only typed attributes supported for constant-like operations");
901 auto intAttr = dyn_cast<IntegerAttr>(attr);
902 if (intAttr && isa<IndexType>(attr.getType()))
903 state[op->getResult(0)] = size_t(intAttr.getInt());
904 else if (intAttr && intAttr.getType().isSignlessInteger(1))
905 state[op->getResult(0)] = bool(intAttr.getInt());
907 state[op->getResult(0)] = attr;
909 return DeletionKind::Delete;
914 return op->emitOpError(
"elaboration not supported");
918 if (op->hasTrait<OpTrait::ConstantLike>())
919 return visitConstantLike(op);
925 return DeletionKind::Keep;
930 FailureOr<DeletionKind> visitOp(ConstantOp op) {
931 return visitConstantLike(op);
934 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
935 SmallVector<ElaboratorValue> replacements;
936 state[op.getResult()] =
937 sharedState.internalizer.internalize<SequenceStorage>(
938 op.getSequenceAttr(), std::move(replacements));
939 return DeletionKind::Delete;
942 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
943 auto *
seq = get<SequenceStorage *>(op.getSequence());
945 SmallVector<ElaboratorValue> replacements(
seq->args);
946 for (
auto replacement : op.getReplacements())
947 replacements.push_back(state.at(replacement));
949 state[op.getResult()] =
950 sharedState.internalizer.internalize<SequenceStorage>(
951 seq->familyName, std::move(replacements));
953 return DeletionKind::Delete;
956 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
957 auto *
seq = get<SequenceStorage *>(op.getSequence());
959 auto name = sharedState.names.newName(
seq->familyName.getValue());
960 auto *randomizedSeq =
961 sharedState.internalizer.internalize<RandomizedSequenceStorage>(
962 name, currentContext, testState.name,
seq);
963 state[op.getResult()] =
964 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
966 return DeletionKind::Delete;
969 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
970 SmallVector<ElaboratorValue> sequences;
971 for (
auto seq : op.getSequences())
972 sequences.push_back(
get<InterleavedSequenceStorage *>(
seq));
974 state[op.getResult()] =
975 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
976 std::move(sequences), op.getBatchSize());
977 return DeletionKind::Delete;
981 LogicalResult isValidContext(ElaboratorValue value, Operation *op)
const {
982 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
983 auto *
seq = std::get<RandomizedSequenceStorage *>(value);
984 if (
seq->context != currentContext) {
985 auto err = op->emitError(
"attempting to place sequence ")
986 <<
seq->name <<
" derived from "
987 <<
seq->sequence->familyName.getValue() <<
" under context "
989 <<
", but it was previously randomized for context ";
999 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
1000 for (
auto val : interVal->sequences)
1001 if (failed(isValidContext(val, op)))
1006 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
1007 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
1008 if (failed(isValidContext(seqVal, op)))
1011 return DeletionKind::Keep;
1014 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1015 SetVector<ElaboratorValue> set;
1016 for (
auto val : op.getElements())
1017 set.insert(state.at(val));
1019 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1020 std::move(set), op.getSet().getType());
1021 return DeletionKind::Delete;
1024 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
1025 auto set = get<SetStorage *>(op.getSet())->set;
1028 return op->emitError(
"cannot select from an empty set");
1032 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1033 std::mt19937 customRng(intAttr.getInt());
1039 state[op.getResult()] = set[selected];
1040 return DeletionKind::Delete;
1043 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
1044 auto original = get<SetStorage *>(op.getOriginal())->set;
1045 auto diff = get<SetStorage *>(op.getDiff())->set;
1047 SetVector<ElaboratorValue> result(original);
1048 result.set_subtract(diff);
1050 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1051 std::move(result), op.getResult().getType());
1052 return DeletionKind::Delete;
1055 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
1056 SetVector<ElaboratorValue> result;
1057 for (
auto set : op.getSets())
1058 result.set_union(
get<SetStorage *>(set)->set);
1060 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1061 std::move(result), op.getType());
1062 return DeletionKind::Delete;
1065 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1066 auto size = get<SetStorage *>(op.getSet())->set.size();
1067 state[op.getResult()] = size;
1068 return DeletionKind::Delete;
1071 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1072 MapVector<ElaboratorValue, uint64_t> bag;
1073 for (
auto [val, multiple] :
1074 llvm::zip(op.getElements(), op.getMultiples())) {
1078 bag[state.at(val)] += get<size_t>(multiple);
1081 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1082 std::move(bag), op.getType());
1083 return DeletionKind::Delete;
// Elaborates a BagSelectRandomOp: picks one element of the bag with probability
// proportional to its multiplicity. It builds a running prefix sum of the
// weights and uses upper_bound to find the element whose range contains the
// random draw.
1086 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
// NOTE(review): `auto` copies the whole MapVector out of the interned storage;
// `const auto &` would avoid the copy.
1087 auto bag = get<BagStorage *>(op.getBag())->bag;
1090 return op->emitError(
"cannot select from an empty bag");
1092 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1093 prefixSum.reserve(bag.size());
// NOTE(review): the bag's multiplicities are uint64_t, but they are summed into
// a uint32_t. Weights or totals above UINT32_MAX silently wrap, which distorts
// the distribution. Consider widening `accumulator` and the prefix-sum pairs,
// or rejecting such bags. The value searched for below is also a uint32_t.
1094 uint32_t accumulator = 0;
1095 for (
auto [val, weight] : bag) {
1096 accumulator += weight;
1097 prefixSum.push_back({val, accumulator});
// NOTE(review): this copies the shared engine. Assuming the elided draw uses
// `customRng`, `sharedState.rng` is never advanced by this op. Without a custom
// seed, repeated bag selections then start from the same engine state (unless
// other ops advance it) and produce correlated or identical picks. Consider
// drawing from `sharedState.rng` directly and using a local engine only when
// a custom seed is present.
1100 auto customRng = sharedState.rng;
1102 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1103 customRng = std::mt19937(intAttr.getInt());
1107 auto *iter = llvm::upper_bound(
1109 [](uint32_t a,
const std::pair<ElaboratorValue, uint32_t> &b) {
1110 return a < b.second;
1113 state[op.getResult()] = iter->first;
1114 return DeletionKind::Delete;
1117 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1118 auto original = get<BagStorage *>(op.getOriginal())->bag;
1119 auto diff = get<BagStorage *>(op.getDiff())->bag;
1121 MapVector<ElaboratorValue, uint64_t> result;
1122 for (
const auto &el : original) {
1123 if (!diff.contains(el.first)) {
1131 auto toDiff = diff.lookup(el.first);
1132 if (el.second <= toDiff)
1135 result.insert({el.first, el.second - toDiff});
1138 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1139 std::move(result), op.getType());
1140 return DeletionKind::Delete;
1143 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1144 MapVector<ElaboratorValue, uint64_t> result;
1145 for (
auto bag : op.getBags()) {
1146 auto val = get<BagStorage *>(bag)->bag;
1147 for (
auto [el, multiple] : val)
1148 result[el] += multiple;
1151 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1152 std::move(result), op.getType());
1153 return DeletionKind::Delete;
1156 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1157 auto size = get<BagStorage *>(op.getBag())->bag.size();
1158 state[op.getResult()] = size;
1159 return DeletionKind::Delete;
1162 FailureOr<DeletionKind> visitOp(FixedRegisterOp op) {
1163 return visitConstantLike(op);
1166 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1167 state[op.getResult()] =
1168 sharedState.internalizer.create<VirtualRegisterStorage>(
1169 op.getAllowedRegsAttr());
1170 return DeletionKind::Delete;
1173 StringAttr substituteFormatString(StringAttr formatString,
1174 ValueRange substitutes)
const {
1175 if (substitutes.empty() || formatString.empty())
1176 return formatString;
1178 auto original = formatString.getValue().str();
1179 for (
auto [i, subst] :
llvm::enumerate(substitutes)) {
1180 size_t startPos = 0;
1181 std::string from =
"{{" + std::to_string(i) +
"}}";
1182 while ((startPos = original.find(from, startPos)) != std::string::npos) {
1183 auto substString = std::to_string(get<size_t>(subst));
1184 original.replace(startPos, from.length(), substString);
1188 return StringAttr::get(formatString.getContext(), original);
1191 FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
1192 SmallVector<ElaboratorValue> array;
1193 array.reserve(op.getElements().size());
1194 for (
auto val : op.getElements())
1195 array.emplace_back(state.at(val));
1197 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1198 op.getResult().getType(), std::move(array));
1199 return DeletionKind::Delete;
1202 FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
1203 auto array = get<ArrayStorage *>(op.getArray())->array;
1204 size_t idx = get<size_t>(op.getIndex());
1206 if (array.size() <= idx)
1207 return op->emitError(
"invalid to access index ")
1208 << idx <<
" of an array with " << array.size() <<
" elements";
1210 state[op.getResult()] = array[idx];
1211 return DeletionKind::Delete;
1214 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
1216 substituteFormatString(op.getFormatStringAttr(), op.getArgs());
1217 state[op.getLabel()] = LabelValue(substituted);
1218 return DeletionKind::Delete;
1221 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1222 state[op.getLabel()] = sharedState.internalizer.create<UniqueLabelStorage>(
1223 substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
1224 return DeletionKind::Delete;
1227 FailureOr<DeletionKind> visitOp(LabelOp op) {
return DeletionKind::Keep; }
// Elaborates a RandomNumberInRangeOp into a concrete index value. The upper
// bound is exclusive: the largest selectable value is upper bound - 1. An
// optional "rtg.elaboration_custom_seed" attribute selects a dedicated engine.
1229 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1230 size_t lower = get<size_t>(op.getLowerBound());
// NOTE(review): if the upper bound is 0, this subtraction wraps to SIZE_MAX.
// Confirm that the (elided) empty-range check catches that case rather than
// treating it as a huge range. The range helper near the top of the file
// computes in uint32_t, so also confirm that bounds above UINT32_MAX are
// handled.
1231 size_t upper = get<size_t>(op.getUpperBound()) - 1;
1233 return op->emitError(
"cannot select a number from an empty range");
1236 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1237 std::mt19937 customRng(intAttr.getInt());
1238 state[op.getResult()] =
1241 state[op.getResult()] =
1245 return DeletionKind::Delete;
// Converts an elaborated index value into an ImmediateAttr with the width of
// the op's result type. It emits an error if the value does not fit in that
// width.
1248 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1249 size_t input = get<size_t>(op.getInput());
1250 auto width = op.getType().getWidth();
1251 auto emitError = [&]() {
return op->emitError(); };
// NOTE(review): APInt::getZExtValue asserts when the value has more than 64
// active bits, so this check would assert for immediate widths above 64.
// Confirm that ImmediateType widths are bounded. Otherwise consider
// `!llvm::isUIntN(width, input)`, which handles any width.
1252 if (input > APInt::getAllOnes(width).getZExtValue())
1253 return emitError() <<
"cannot represent " << input <<
" with " << width
1256 state[op.getResult()] =
1257 ImmediateAttr::get(op.getContext(), APInt(width, input));
1258 return DeletionKind::Delete;
1261 FailureOr<DeletionKind> visitOp(OnContextOp op) {
1262 ContextResourceAttrInterface from = currentContext,
1263 to = cast<ContextResourceAttrInterface>(
1264 get<TypedAttr>(op.getContext()));
1265 if (!currentContext)
1266 from = DefaultContextAttr::get(op->getContext(), to.getType());
1268 auto emitError = [&]() {
1269 auto diag = op.emitError();
1270 diag.attachNote(op.getLoc())
1271 <<
"while materializing value for context switching for " << op;
1276 Value seqVal = materializer.materialize(
1277 get<SequenceStorage *>(op.getSequence()), op.getLoc(),
1278 sharedState.worklist, emitError);
1280 materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1281 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1282 return DeletionKind::Delete;
1286 auto *iter = testState.contextSwitches.find({from, to});
1289 if (iter == testState.contextSwitches.end())
1290 return op->emitError(
"no context transition registered to switch from ")
1291 << from <<
" to " << to;
1293 auto familyName = iter->second->familyName;
1294 SmallVector<ElaboratorValue> args{from, to,
1295 get<SequenceStorage *>(op.getSequence())};
1296 auto *
seq = sharedState.internalizer.internalize<SequenceStorage>(
1297 familyName, std::move(args));
1299 sharedState.internalizer.internalize<RandomizedSequenceStorage>(
1300 sharedState.names.newName(familyName.getValue()), to,
1301 testState.name,
seq);
1302 Value seqVal = materializer.materialize(randSeq, op.getLoc(),
1303 sharedState.worklist, emitError);
1304 materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
1306 return DeletionKind::Delete;
1309 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1310 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1311 get<SequenceStorage *>(op.getSequence());
1312 return DeletionKind::Delete;
1315 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1316 bool cond = get<bool>(op.getCondition());
1317 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1318 if (toElaborate.empty())
1319 return DeletionKind::Delete;
1325 if (failed(elaborate(toElaborate)))
1329 for (
auto [res, out] :
1330 llvm::zip(op.getResults(),
1331 toElaborate.front().getTerminator()->getOperands()))
1332 state[res] = state.at(out);
1334 return DeletionKind::Delete;
1337 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1338 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1339 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1340 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1341 return op->emitOpError(
"can only elaborate index type iterator");
1343 auto lowerBound = get<size_t>(op.getLowerBound());
1344 auto step = get<size_t>(op.getStep());
1345 auto upperBound = get<size_t>(op.getUpperBound());
1351 state[op.getInductionVar()] = lowerBound;
1352 for (
auto [iterArg, initArg] :
1353 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1354 state[iterArg] = state.at(initArg);
1357 for (
size_t i = lowerBound; i < upperBound; i += step) {
1358 if (failed(elaborate(op.getBodyRegion())))
1363 state[op.getInductionVar()] = i + step;
1364 for (
auto [iterArg, prevIterArg] :
1365 llvm::zip(op.getRegionIterArgs(),
1366 op.getBody()->getTerminator()->getOperands()))
1367 state[iterArg] = state.at(prevIterArg);
1371 for (
auto [res, iterArg] :
1372 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1373 state[res] = state.at(iterArg);
1375 return DeletionKind::Delete;
1378 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1379 return DeletionKind::Delete;
1382 FailureOr<DeletionKind> visitOp(index::AddOp op) {
1383 size_t lhs = get<size_t>(op.getLhs());
1384 size_t rhs = get<size_t>(op.getRhs());
1385 state[op.getResult()] = lhs + rhs;
1386 return DeletionKind::Delete;
1389 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
1390 size_t lhs = get<size_t>(op.getLhs());
1391 size_t rhs = get<size_t>(op.getRhs());
1393 switch (op.getPred()) {
1394 case index::IndexCmpPredicate::EQ:
1395 result = lhs == rhs;
1397 case index::IndexCmpPredicate::NE:
1398 result = lhs != rhs;
1400 case index::IndexCmpPredicate::ULT:
1403 case index::IndexCmpPredicate::ULE:
1404 result = lhs <= rhs;
1406 case index::IndexCmpPredicate::UGT:
1409 case index::IndexCmpPredicate::UGE:
1410 result = lhs >= rhs;
1413 return op->emitOpError(
"elaboration not supported");
1415 state[op.getResult()] = result;
1416 return DeletionKind::Delete;
1420 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
1423 index::AddOp, index::CmpOp,
1425 scf::IfOp, scf::ForOp, scf::YieldOp>(
1426 [&](
auto op) {
return visitOp(op); })
1427 .Default([&](Operation *op) {
return RTGBase::dispatchOpVisitor(op); });
1431 LogicalResult elaborate(Region ®ion,
1432 ArrayRef<ElaboratorValue> regionArguments = {}) {
1433 if (region.getBlocks().size() > 1)
1434 return region.getParentOp()->emitOpError(
1435 "regions with more than one block are not supported");
1437 for (
auto [arg, elabArg] :
1438 llvm::zip(region.getArguments(), regionArguments))
1439 state[arg] = elabArg;
1441 Block *block = ®ion.front();
1442 for (
auto &op : *block) {
1447 if (*result == DeletionKind::Keep)
1448 if (failed(materializer.materialize(&op, state, sharedState.worklist)))
1452 llvm::dbgs() <<
"Elaborated " << op <<
" to\n[";
1454 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](
auto res) {
1455 if (state.contains(res))
1456 llvm::dbgs() << state.at(res);
1458 llvm::dbgs() <<
"unknown";
1461 llvm::dbgs() <<
"]\n\n";
1470 ElaboratorSharedState &sharedState;
1473 TestState &testState;
1477 Materializer &materializer;
1480 DenseMap<Value, ElaboratorValue> state;
1483 ContextResourceAttrInterface currentContext;
1492struct ElaborationPass
1493 :
public rtg::impl::ElaborationPassBase<ElaborationPass> {
1496 void runOnOperation()
override;
1497 void cloneTargetsIntoTests(SymbolTable &table);
1498 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
1502void ElaborationPass::runOnOperation() {
1503 auto moduleOp = getOperation();
1504 SymbolTable table(moduleOp);
1506 cloneTargetsIntoTests(table);
1508 if (failed(elaborateModule(moduleOp, table)))
1509 return signalPassFailure();
1512void ElaborationPass::cloneTargetsIntoTests(SymbolTable &table) {
1513 auto moduleOp = getOperation();
1514 for (
auto target :
llvm::make_early_inc_range(moduleOp.getOps<TargetOp>())) {
1515 for (
auto test : moduleOp.getOps<TestOp>()) {
1517 if (test.getTarget().getEntries().empty())
1522 if (target.getTarget() != test.getTarget())
1525 IRRewriter rewriter(test);
1527 auto newTest = cast<TestOp>(test->clone());
1528 newTest.setSymName(test.getSymName().str() +
"_" +
1529 target.getSymName().str());
1530 table.insert(newTest, rewriter.getInsertionPoint());
1534 rewriter.setInsertionPointToStart(newTest.getBody());
1535 for (
auto &op : target.getBody()->without_terminator())
1536 rewriter.clone(op, mapping);
1538 for (
auto [returnVal, result] :
1539 llvm::zip(target.getBody()->getTerminator()->getOperands(),
1540 newTest.getBody()->getArguments()))
1541 result.replaceAllUsesWith(mapping.lookup(returnVal));
1543 newTest.getBody()->eraseArguments(0,
1544 newTest.getBody()->getNumArguments());
1545 newTest.setTarget(DictType::get(&getContext(), {}));
1552 for (
auto test :
llvm::make_early_inc_range(moduleOp.getOps<TestOp>()))
1553 if (!test.getTarget().getEntries().
empty())
1557LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
1558 SymbolTable &table) {
1559 ElaboratorSharedState state(table, seed);
1562 state.names.add(moduleOp);
1566 DenseMap<StringAttr, TestState> testStates;
1567 for (
auto testOp : moduleOp.getOps<TestOp>()) {
1568 LLVM_DEBUG(llvm::dbgs()
1569 <<
"\n=== Elaborating test @" << testOp.getSymName() <<
"\n\n");
1570 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()));
1571 testStates[testOp.getSymNameAttr()].name = testOp.getSymNameAttr();
1572 Elaborator elaborator(state, testStates[testOp.getSymNameAttr()],
1574 if (failed(elaborator.elaborate(testOp.getBodyRegion())))
1577 materializer.finalize();
1582 while (!state.worklist.empty()) {
1583 auto *curr = state.worklist.front();
1584 state.worklist.pop();
1586 if (table.lookup<SequenceOp>(curr->name))
1589 auto familyOp = table.lookup<SequenceOp>(curr->sequence->familyName);
1592 OpBuilder builder(familyOp);
1593 auto seqOp = builder.cloneWithoutRegions(familyOp);
1594 seqOp.getBodyRegion().emplaceBlock();
1595 seqOp.setSymName(curr->name);
1596 seqOp.setSequenceType(
1597 SequenceType::get(builder.getContext(), ArrayRef<Type>{}));
1598 table.insert(seqOp);
1599 assert(seqOp.getSymName() == curr->name &&
"should not have been renamed");
1601 LLVM_DEBUG(llvm::dbgs()
1602 <<
"\n=== Elaborating sequence family @" << familyOp.getSymName()
1603 <<
" into @" << seqOp.getSymName() <<
" under context "
1604 << curr->context <<
"\n\n");
1606 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()));
1607 Elaborator elaborator(state, testStates[curr->test], materializer,
1609 if (failed(elaborator.elaborate(familyOp.getBodyRegion(),
1610 curr->sequence->args)))
1613 materializer.finalize();
assert(baseType &&"element must be base type")
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the specified range.
static void print(TypedAttr val, llvm::raw_ostream &os)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
This helps visit TypeOp nodes.
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
ResultType dispatchOpVisitor(Operation *op, ExtraArgs... args)
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
static llvm::hash_code hash_value(const ModulePort &port)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static LabelValue getEmptyKey()
static LabelValue getTombstoneKey()
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)
static unsigned getEmptyKey()