21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Index/IR/IndexDialect.h"
23#include "mlir/Dialect/Index/IR/IndexOps.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/DenseMapInfoVariant.h"
28#include "llvm/Support/Debug.h"
34#define GEN_PASS_DEF_ELABORATIONPASS
35#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
44#define DEBUG_TYPE "rtg-elaboration"
56 size_t n = w / 32 + (w % 32 != 0);
58 return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
63 const uint32_t diff = b - a + 1;
67 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
71 uint32_t width = digits - llvm::countl_zero(diff) - 1;
72 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
91struct SequenceStorage;
92struct RandomizedSequenceStorage;
93struct InterleavedSequenceStorage;
95struct VirtualRegisterStorage;
96struct UniqueLabelStorage;
99struct MemoryBlockStorage;
100struct ValidationValue;
106 LabelValue(StringAttr name) : name(name) {}
108 bool operator==(
const LabelValue &other)
const {
return name == other.name; }
115using ElaboratorValue =
116 std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
117 RandomizedSequenceStorage *, InterleavedSequenceStorage *,
118 SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
119 LabelValue, ArrayStorage *, TupleStorage *, MemoryStorage *,
120 MemoryBlockStorage *, ValidationValue *>;
123llvm::hash_code
hash_value(
const LabelValue &val) {
124 return llvm::hash_value(val.name);
128llvm::hash_code
hash_value(
const ElaboratorValue &val) {
130 [&val](
const auto &alternative) {
133 return llvm::hash_combine(val.index(), alternative);
/// Key equality for plain booleans: true iff both operands hold the same
/// value.
static bool isEqual(const bool &lhs, const bool &rhs) {
  if (lhs != rhs)
    return false;
  return true;
}
162 static bool isEqual(
const LabelValue &lhs,
const LabelValue &rhs) {
162 static bool isEqual(
const LabelValue &lhs,
const LabelValue &rhs) {
…}
179template <
typename StorageTy>
180struct HashedStorage {
181 HashedStorage(
unsigned hashcode = 0, StorageTy *storage =
nullptr)
182 : hashcode(hashcode), storage(storage) {}
192template <
typename StorageTy>
193struct StorageKeyInfo {
194 static inline HashedStorage<StorageTy> getEmptyKey() {
195 return HashedStorage<StorageTy>(0,
196 DenseMapInfo<StorageTy *>::getEmptyKey());
198 static inline HashedStorage<StorageTy> getTombstoneKey() {
199 return HashedStorage<StorageTy>(
200 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
203 static inline unsigned getHashValue(
const HashedStorage<StorageTy> &key) {
206 static inline unsigned getHashValue(
const StorageTy &key) {
210 static inline bool isEqual(
const HashedStorage<StorageTy> &lhs,
211 const HashedStorage<StorageTy> &rhs) {
212 return lhs.storage == rhs.storage;
214 static inline bool isEqual(
const StorageTy &lhs,
215 const HashedStorage<StorageTy> &rhs) {
216 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
219 return lhs.isEqual(rhs.storage);
228 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
230 type,
llvm::hash_combine_range(set.begin(), set.
end()))),
231 set(std::move(set)), type(type) {}
233 bool isEqual(
const SetStorage *other)
const {
234 return hashcode == other->hashcode && set == other->set &&
239 const unsigned hashcode;
242 const SetVector<ElaboratorValue> set;
251 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
253 type,
llvm::hash_combine_range(bag.begin(), bag.
end()))),
254 bag(std::move(bag)), type(type) {}
256 bool isEqual(
const BagStorage *other)
const {
257 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
262 const unsigned hashcode;
266 const MapVector<ElaboratorValue, uint64_t> bag;
274struct SequenceStorage {
275 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
277 familyName,
llvm::hash_combine_range(args.begin(), args.
end()))),
278 familyName(familyName), args(std::move(args)) {}
280 bool isEqual(
const SequenceStorage *other)
const {
281 return hashcode == other->hashcode && familyName == other->familyName &&
286 const unsigned hashcode;
289 const StringAttr familyName;
292 const SmallVector<ElaboratorValue> args;
296struct InterleavedSequenceStorage {
297 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
299 : sequences(std::move(sequences)), batchSize(batchSize),
301 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
304 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
305 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
307 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
310 bool isEqual(
const InterleavedSequenceStorage *other)
const {
311 return hashcode == other->hashcode && sequences == other->sequences &&
312 batchSize == other->batchSize;
315 const SmallVector<ElaboratorValue> sequences;
317 const uint32_t batchSize;
320 const unsigned hashcode;
325 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
327 type,
llvm::hash_combine_range(array.begin(), array.
end()))),
328 type(type), array(array) {}
330 bool isEqual(
const ArrayStorage *other)
const {
331 return hashcode == other->hashcode && type == other->type &&
332 array == other->array;
336 const unsigned hashcode;
343 const SmallVector<ElaboratorValue> array;
348 TupleStorage(SmallVector<ElaboratorValue> &&values)
349 : hashcode(
llvm::hash_combine_range(values.begin(), values.
end())),
350 values(std::move(values)) {}
352 bool isEqual(
const TupleStorage *other)
const {
353 return hashcode == other->hashcode && values == other->values;
357 const unsigned hashcode;
359 const SmallVector<ElaboratorValue> values;
370struct IdentityValue {
372 IdentityValue(Type type) : type(type) {}
385 bool alreadyMaterialized =
false;
393struct VirtualRegisterStorage : IdentityValue {
394 VirtualRegisterStorage(ArrayAttr allowedRegs, Type type)
395 : IdentityValue(type), allowedRegs(allowedRegs) {}
402 const ArrayAttr allowedRegs;
405struct UniqueLabelStorage : IdentityValue {
406 UniqueLabelStorage(StringAttr name)
407 : IdentityValue(LabelType::
get(name.getContext())), name(name) {}
413 const StringAttr name;
417struct MemoryBlockStorage : IdentityValue {
418 MemoryBlockStorage(
const APInt &baseAddress,
const APInt &endAddress,
420 : IdentityValue(type), baseAddress(baseAddress), endAddress(endAddress) {}
425 const APInt baseAddress;
428 const APInt endAddress;
432struct MemoryStorage : IdentityValue {
433 MemoryStorage(MemoryBlockStorage *memoryBlock,
size_t size,
size_t alignment)
434 : IdentityValue(MemoryType::
get(memoryBlock->type.getContext(),
436 memoryBlock(memoryBlock), size(size), alignment(alignment) {}
438 MemoryBlockStorage *memoryBlock;
440 const size_t alignment;
444struct RandomizedSequenceStorage : IdentityValue {
445 RandomizedSequenceStorage(ContextResourceAttrInterface context,
446 SequenceStorage *sequence)
448 RandomizedSequenceType::
get(sequence->familyName.getContext())),
449 context(context), sequence(sequence) {}
452 const ContextResourceAttrInterface context;
454 const SequenceStorage *sequence;
458struct ValidationValue : IdentityValue {
459 ValidationValue(Type type,
const ElaboratorValue &ref,
460 const ElaboratorValue &defaultValue, StringAttr
id)
461 : IdentityValue(type), ref(ref), defaultValue(defaultValue), id(id) {}
463 const ElaboratorValue ref;
464 const ElaboratorValue defaultValue;
478 template <
typename StorageTy,
typename... Args>
479 StorageTy *internalize(Args &&...args) {
480 static_assert(!std::is_base_of_v<IdentityValue, StorageTy> &&
481 "values with identity must not be internalized");
483 StorageTy storage(std::forward<Args>(args)...);
485 auto existing = getInternSet<StorageTy>().insert_as(
486 HashedStorage<StorageTy>(storage.hashcode), storage);
487 StorageTy *&storagePtr = existing.first->storage;
490 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
495 template <
typename StorageTy,
typename... Args>
496 StorageTy *create(Args &&...args) {
497 static_assert(std::is_base_of_v<IdentityValue, StorageTy> &&
498 "values with structural equivalence must be internalized");
500 return new (allocator.Allocate<StorageTy>())
501 StorageTy(std::forward<Args>(args)...);
505 template <
typename StorageTy>
506 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
508 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
509 return internedArrays;
510 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
512 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
514 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
515 return internedSequences;
516 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
517 return internedRandomizedSequences;
518 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
519 return internedInterleavedSequences;
520 else if constexpr (std::is_same_v<StorageTy, TupleStorage>)
521 return internedTuples;
523 static_assert(!
sizeof(StorageTy),
524 "no intern set available for this storage type.");
529 llvm::BumpPtrAllocator allocator;
534 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
536 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
537 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
538 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
540 DenseSet<HashedStorage<RandomizedSequenceStorage>,
541 StorageKeyInfo<RandomizedSequenceStorage>>
542 internedRandomizedSequences;
543 DenseSet<HashedStorage<InterleavedSequenceStorage>,
544 StorageKeyInfo<InterleavedSequenceStorage>>
545 internedInterleavedSequences;
546 DenseSet<HashedStorage<TupleStorage>, StorageKeyInfo<TupleStorage>>
554static llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
555 const ElaboratorValue &value);
557static void print(TypedAttr val, llvm::raw_ostream &os) {
558 os <<
"<attr " << val <<
">";
557static void print(TypedAttr val, llvm::raw_ostream &os) {
…}
561static void print(BagStorage *val, llvm::raw_ostream &os) {
563 llvm::interleaveComma(val->bag, os,
564 [&](
const std::pair<ElaboratorValue, uint64_t> &el) {
565 os << el.first <<
" -> " << el.second;
567 os <<
"} at " << val <<
">";
561static void print(BagStorage *val, llvm::raw_ostream &os) {
…}
570static void print(
bool val, llvm::raw_ostream &os) {
571 os <<
"<bool " << (val ?
"true" :
"false") <<
">";
570static void print(
bool val, llvm::raw_ostream &os) {
…}
574static void print(
size_t val, llvm::raw_ostream &os) {
575 os <<
"<index " << val <<
">";
574static void print(
size_t val, llvm::raw_ostream &os) {
…}
578static void print(SequenceStorage *val, llvm::raw_ostream &os) {
579 os <<
"<sequence @" << val->familyName.getValue() <<
"(";
580 llvm::interleaveComma(val->args, os,
581 [&](
const ElaboratorValue &val) { os << val; });
582 os <<
") at " << val <<
">";
578static void print(SequenceStorage *val, llvm::raw_ostream &os) {
…}
585static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
586 os <<
"<randomized-sequence derived from @"
587 << val->sequence->familyName.getValue() <<
" under context "
588 << val->context <<
"(";
589 llvm::interleaveComma(val->sequence->args, os,
590 [&](
const ElaboratorValue &val) { os << val; });
591 os <<
") at " << val <<
">";
585static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
…}
594static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
595 os <<
"<interleaved-sequence [";
596 llvm::interleaveComma(val->sequences, os,
597 [&](
const ElaboratorValue &val) { os << val; });
598 os <<
"] batch-size " << val->batchSize <<
" at " << val <<
">";
594static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
…}
601static void print(ArrayStorage *val, llvm::raw_ostream &os) {
603 llvm::interleaveComma(val->array, os,
604 [&](
const ElaboratorValue &val) { os << val; });
605 os <<
"] at " << val <<
">";
601static void print(ArrayStorage *val, llvm::raw_ostream &os) {
…}
608static void print(SetStorage *val, llvm::raw_ostream &os) {
610 llvm::interleaveComma(val->set, os,
611 [&](
const ElaboratorValue &val) { os << val; });
612 os <<
"} at " << val <<
">";
608static void print(SetStorage *val, llvm::raw_ostream &os) {
…}
615static void print(
const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
616 os <<
"<virtual-register " << val <<
" " << val->allowedRegs <<
">";
615static void print(
const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
…}
619static void print(
const UniqueLabelStorage *val, llvm::raw_ostream &os) {
620 os <<
"<unique-label " << val <<
" " << val->name <<
">";
619static void print(
const UniqueLabelStorage *val, llvm::raw_ostream &os) {
…}
623static void print(
const LabelValue &val, llvm::raw_ostream &os) {
624 os <<
"<label " << val.name <<
">";
623static void print(
const LabelValue &val, llvm::raw_ostream &os) {
…}
627static void print(
const TupleStorage *val, llvm::raw_ostream &os) {
629 llvm::interleaveComma(val->values, os,
630 [&](
const ElaboratorValue &val) { os << val; });
627static void print(
const TupleStorage *val, llvm::raw_ostream &os) {
…}
634static void print(
const MemoryStorage *val, llvm::raw_ostream &os) {
635 os <<
"<memory {" << ElaboratorValue(val->memoryBlock)
636 <<
", size=" << val->size <<
", alignment=" << val->alignment <<
"}>";
634static void print(
const MemoryStorage *val, llvm::raw_ostream &os) {
…}
639static void print(
const MemoryBlockStorage *val, llvm::raw_ostream &os) {
640 os <<
"<memory-block {"
641 <<
", address-width=" << val->baseAddress.getBitWidth()
642 <<
", base-address=" << val->baseAddress
643 <<
", end-address=" << val->endAddress <<
"}>";
639static void print(
const MemoryBlockStorage *val, llvm::raw_ostream &os) {
…}
646static void print(
const ValidationValue *val, llvm::raw_ostream &os) {
647 os <<
"<validation-value {type=" << val->type <<
", ref=" << val->ref
648 <<
", defaultValue=" << val->defaultValue <<
"}>";
646static void print(
const ValidationValue *val, llvm::raw_ostream &os) {
…}
652 const ElaboratorValue &value) {
653 std::visit([&](
auto val) {
print(val, os); }, value);
668 SharedState(SymbolTable &table,
unsigned seed) : table(table), rng(seed) {}
673 Internalizer internalizer;
683 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
691 Materializer(OpBuilder builder, TestState &testState,
692 SharedState &sharedState,
693 SmallVector<ElaboratorValue> &blockArgs)
694 : builder(builder), testState(testState), sharedState(sharedState),
695 blockArgs(blockArgs) {}
699 Value materialize(ElaboratorValue val, Location loc,
700 function_ref<InFlightDiagnostic()> emitError) {
701 auto iter = materializedValues.find(val);
702 if (iter != materializedValues.end())
705 LLVM_DEBUG(llvm::dbgs() <<
"Materializing " << val);
709 Value res = std::visit(
711 if constexpr (std::is_base_of_v<IdentityValue,
712 std::remove_pointer_t<
713 std::decay_t<
decltype(value)>>>) {
714 if (identityValueRoot.contains(value)) {
717 static_cast<IdentityValue *
>(value)->alreadyMaterialized;
718 assert(!materialized &&
"must not already be materialized");
722 return visit(value, loc, emitError);
725 Value arg = builder.getBlock()->addArgument(value->type, loc);
726 blockArgs.push_back(val);
727 blockArgTypes.push_back(arg.getType());
728 materializedValues[val] = arg;
732 return visit(value, loc, emitError);
736 LLVM_DEBUG(llvm::dbgs() <<
" to\n" << res <<
"\n\n");
747 LogicalResult materialize(Operation *op,
748 DenseMap<Value, ElaboratorValue> &state) {
749 if (op->getNumRegions() > 0)
750 return op->emitOpError(
"ops with nested regions must be elaborated away");
758 for (
auto res : op->getResults())
759 if (!res.use_empty() && !isa<ValidateOp>(op))
760 return op->emitOpError(
761 "ops with results that have uses are not supported");
763 if (op->getParentRegion() == builder.getBlock()->getParent()) {
766 deleteOpsUntil([&](
auto iter) {
return &*iter == op; });
768 if (builder.getInsertionPoint() == builder.getBlock()->end())
769 return op->emitError(
"operation did not occur after the current "
770 "materializer insertion point");
772 LLVM_DEBUG(llvm::dbgs() <<
"Modifying in-place: " << *op <<
"\n\n");
774 LLVM_DEBUG(llvm::dbgs() <<
"Materializing a clone of " << *op <<
"\n\n");
775 op = builder.clone(*op);
776 builder.setInsertionPoint(op);
779 for (
auto &operand : op->getOpOperands()) {
780 auto emitError = [&]() {
781 auto diag = op->emitError();
782 diag.attachNote(op->getLoc())
783 <<
"while materializing value for operand#"
784 << operand.getOperandNumber();
788 auto elabVal = state.at(operand.get());
789 Value val = materialize(elabVal, op->getLoc(), emitError);
793 state[val] = elabVal;
797 builder.setInsertionPointAfter(op);
804 deleteOpsUntil([](
auto iter) {
return false; });
806 for (
auto *op :
llvm::reverse(toDelete))
813 void registerIdentityValue(IdentityValue *val) {
814 identityValueRoot.insert(val);
817 ArrayRef<Type> getBlockArgTypes()
const {
return blockArgTypes; }
819 void map(ElaboratorValue eval, Value val) { materializedValues[eval] = val; }
821 template <
typename OpTy,
typename... Args>
822 OpTy create(Location location, Args &&...args) {
823 return builder.create<OpTy>(location, std::forward<Args>(args)...);
827 SequenceOp elaborateSequence(
const RandomizedSequenceStorage *
seq,
828 SmallVector<ElaboratorValue> &elabArgs);
830 void deleteOpsUntil(function_ref<
bool(Block::iterator)> stop) {
831 auto ip = builder.getInsertionPoint();
832 while (ip != builder.getBlock()->end() && !stop(ip)) {
833 LLVM_DEBUG(llvm::dbgs() <<
"Marking to be deleted: " << *ip <<
"\n\n");
834 toDelete.push_back(&*ip);
836 builder.setInsertionPointAfter(&*ip);
837 ip = builder.getInsertionPoint();
841 Value visit(TypedAttr val, Location loc,
842 function_ref<InFlightDiagnostic()> emitError) {
845 if (
auto intAttr = dyn_cast<IntegerAttr>(val);
846 intAttr && isa<IndexType>(val.getType())) {
847 Value res = builder.create<index::ConstantOp>(loc, intAttr);
848 materializedValues[val] = res;
855 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
857 emitError() <<
"materializer of dialect '"
858 << val.getDialect().getNamespace()
859 <<
"' unable to materialize value for attribute '" << val
864 Value res = op->getResult(0);
865 materializedValues[val] = res;
869 Value visit(
size_t val, Location loc,
870 function_ref<InFlightDiagnostic()> emitError) {
871 Value res = builder.create<index::ConstantOp>(loc, val);
872 materializedValues[val] = res;
876 Value visit(
bool val, Location loc,
877 function_ref<InFlightDiagnostic()> emitError) {
878 Value res = builder.create<index::BoolConstantOp>(loc, val);
879 materializedValues[val] = res;
883 Value visit(ArrayStorage *val, Location loc,
884 function_ref<InFlightDiagnostic()> emitError) {
885 SmallVector<Value> elements;
886 elements.reserve(val->array.size());
887 for (
auto el : val->array) {
888 auto materialized = materialize(el, loc, emitError);
892 elements.push_back(materialized);
895 Value res = builder.create<ArrayCreateOp>(loc, val->type, elements);
896 materializedValues[val] = res;
900 Value visit(SetStorage *val, Location loc,
901 function_ref<InFlightDiagnostic()> emitError) {
902 SmallVector<Value> elements;
903 elements.reserve(val->set.size());
904 for (
auto el : val->set) {
905 auto materialized = materialize(el, loc, emitError);
909 elements.push_back(materialized);
912 auto res = builder.create<SetCreateOp>(loc, val->type, elements);
913 materializedValues[val] = res;
917 Value visit(BagStorage *val, Location loc,
918 function_ref<InFlightDiagnostic()> emitError) {
919 SmallVector<Value> values, weights;
920 values.reserve(val->bag.size());
921 weights.reserve(val->bag.size());
922 for (
auto [val, weight] : val->bag) {
923 auto materializedVal = materialize(val, loc, emitError);
924 auto materializedWeight = materialize(weight, loc, emitError);
925 if (!materializedVal || !materializedWeight)
928 values.push_back(materializedVal);
929 weights.push_back(materializedWeight);
932 auto res = builder.create<BagCreateOp>(loc, val->type, values, weights);
933 materializedValues[val] = res;
937 Value visit(MemoryBlockStorage *val, Location loc,
938 function_ref<InFlightDiagnostic()> emitError) {
939 auto intType = builder.getIntegerType(val->baseAddress.getBitWidth());
940 Value res = builder.create<MemoryBlockDeclareOp>(
941 loc, val->type, IntegerAttr::get(intType, val->baseAddress),
942 IntegerAttr::get(intType, val->endAddress));
943 materializedValues[val] = res;
947 Value visit(MemoryStorage *val, Location loc,
948 function_ref<InFlightDiagnostic()> emitError) {
949 auto memBlock = materialize(val->memoryBlock, loc, emitError);
950 auto memSize = materialize(val->size, loc, emitError);
951 auto memAlign = materialize(val->alignment, loc, emitError);
952 if (!(memBlock && memSize && memAlign))
955 Value res = builder.create<MemoryAllocOp>(loc, memBlock, memSize, memAlign);
956 materializedValues[val] = res;
960 Value visit(SequenceStorage *val, Location loc,
961 function_ref<InFlightDiagnostic()> emitError) {
962 emitError() <<
"materializing a non-randomized sequence not supported yet";
966 Value visit(RandomizedSequenceStorage *val, Location loc,
967 function_ref<InFlightDiagnostic()> emitError) {
973 SmallVector<ElaboratorValue> elabArgs;
974 SequenceOp seqOp = elaborateSequence(val, elabArgs);
980 SmallVector<Value> args;
981 SmallVector<Type> argTypes;
982 for (
auto arg : elabArgs) {
983 Value materialized = materialize(arg, loc, emitError);
987 args.push_back(materialized);
988 argTypes.push_back(materialized.getType());
991 Value res = builder.create<GetSequenceOp>(
992 loc, SequenceType::get(builder.getContext(), argTypes),
998 res = builder.create<SubstituteSequenceOp>(loc, res, args);
1000 res = builder.create<RandomizeSequenceOp>(loc, res);
1002 materializedValues[val] = res;
1006 Value visit(InterleavedSequenceStorage *val, Location loc,
1007 function_ref<InFlightDiagnostic()> emitError) {
1008 SmallVector<Value> sequences;
1009 for (
auto seqVal : val->sequences) {
1010 Value materialized = materialize(seqVal, loc, emitError);
1014 sequences.push_back(materialized);
1017 if (sequences.size() == 1)
1018 return sequences[0];
1021 builder.create<InterleaveSequencesOp>(loc, sequences, val->batchSize);
1022 materializedValues[val] = res;
1026 Value visit(VirtualRegisterStorage *val, Location loc,
1027 function_ref<InFlightDiagnostic()> emitError) {
1028 Value res = builder.create<VirtualRegisterOp>(loc, val->allowedRegs);
1029 materializedValues[val] = res;
1033 Value visit(UniqueLabelStorage *val, Location loc,
1034 function_ref<InFlightDiagnostic()> emitError) {
1035 Value res = builder.create<LabelUniqueDeclOp>(loc, val->name, ValueRange());
1036 materializedValues[val] = res;
1040 Value visit(
const LabelValue &val, Location loc,
1041 function_ref<InFlightDiagnostic()> emitError) {
1042 Value res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
1043 materializedValues[val] = res;
1047 Value visit(TupleStorage *val, Location loc,
1048 function_ref<InFlightDiagnostic()> emitError) {
1049 SmallVector<Value> materialized;
1050 materialized.reserve(val->values.size());
1051 for (
auto v : val->values)
1052 materialized.push_back(materialize(v, loc, emitError));
1053 Value res = builder.create<TupleCreateOp>(loc, materialized);
1054 materializedValues[val] = res;
1058 Value visit(ValidationValue *val, Location loc,
1059 function_ref<InFlightDiagnostic()> emitError) {
1060 Value res = builder.create<ValidateOp>(
1061 loc, val->type, materialize(val->ref, loc, emitError),
1062 materialize(val->defaultValue, loc, emitError), val->id);
1063 materializedValues[val] = res;
1073 DenseMap<ElaboratorValue, Value> materializedValues;
1079 SmallVector<Operation *> toDelete;
1081 TestState &testState;
1082 SharedState &sharedState;
1087 SmallVector<ElaboratorValue> &blockArgs;
1088 SmallVector<Type> blockArgTypes;
1093 DenseSet<IdentityValue *> identityValueRoot;
1102enum class DeletionKind { Keep, Delete };
1105class Elaborator :
public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
1108 using RTGBase::visitOp;
1110 Elaborator(SharedState &sharedState, TestState &testState,
1111 Materializer &materializer,
1112 ContextResourceAttrInterface currentContext = {})
1113 : sharedState(sharedState), testState(testState),
1114 materializer(materializer), currentContext(currentContext) {}
1116 template <
typename ValueTy>
1117 inline ValueTy
get(Value val)
const {
1118 return std::get<ValueTy>(state.at(val));
1121 FailureOr<DeletionKind> visitPureOp(Operation *op) {
1122 SmallVector<Attribute> operands;
1123 for (
auto operand : op->getOperands()) {
1124 auto evalValue = state[operand];
1125 if (std::holds_alternative<TypedAttr>(evalValue))
1126 operands.push_back(std::get<TypedAttr>(evalValue));
1131 SmallVector<OpFoldResult> results;
1132 if (failed(op->fold(operands, results)))
1136 if (results.size() != op->getNumResults())
1139 for (
auto [res, val] :
llvm::zip(results, op->getResults())) {
1140 auto attr = llvm::dyn_cast_or_null<TypedAttr>(res.dyn_cast<Attribute>());
1142 return op->emitError(
1143 "only typed attributes supported for constant-like operations");
1145 auto intAttr = dyn_cast<IntegerAttr>(attr);
1146 if (intAttr && isa<IndexType>(attr.getType()))
1147 state[op->getResult(0)] = size_t(intAttr.getInt());
1148 else if (intAttr && intAttr.getType().isSignlessInteger(1))
1149 state[op->getResult(0)] = bool(intAttr.getInt());
1151 state[op->getResult(0)] = attr;
1154 return DeletionKind::Delete;
1159 return op->emitOpError(
"elaboration not supported");
1163 auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
1164 if (op->hasTrait<OpTrait::ConstantLike>() || (memOp && memOp.hasNoEffect()))
1165 return visitPureOp(op);
1170 if (op->use_empty())
1171 return DeletionKind::Keep;
1176 FailureOr<DeletionKind> visitOp(ConstantOp op) {
return visitPureOp(op); }
1178 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
1179 SmallVector<ElaboratorValue> replacements;
1180 state[op.getResult()] =
1181 sharedState.internalizer.internalize<SequenceStorage>(
1182 op.getSequenceAttr(), std::move(replacements));
1183 return DeletionKind::Delete;
1186 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
1187 auto *
seq = get<SequenceStorage *>(op.getSequence());
1189 SmallVector<ElaboratorValue> replacements(
seq->args);
1190 for (
auto replacement : op.getReplacements())
1191 replacements.push_back(state.at(replacement));
1193 state[op.getResult()] =
1194 sharedState.internalizer.internalize<SequenceStorage>(
1195 seq->familyName, std::move(replacements));
1197 return DeletionKind::Delete;
1200 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
1201 auto *
seq = get<SequenceStorage *>(op.getSequence());
1202 auto *randomizedSeq =
1203 sharedState.internalizer.create<RandomizedSequenceStorage>(
1204 currentContext,
seq);
1205 materializer.registerIdentityValue(randomizedSeq);
1206 state[op.getResult()] =
1207 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1209 return DeletionKind::Delete;
1212 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
1213 SmallVector<ElaboratorValue> sequences;
1214 for (
auto seq : op.getSequences())
1215 sequences.push_back(
get<InterleavedSequenceStorage *>(
seq));
1217 state[op.getResult()] =
1218 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1219 std::move(sequences), op.getBatchSize());
1220 return DeletionKind::Delete;
1224 LogicalResult isValidContext(ElaboratorValue value, Operation *op)
const {
1225 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
1226 auto *
seq = std::get<RandomizedSequenceStorage *>(value);
1227 if (
seq->context != currentContext) {
1228 auto err = op->emitError(
"attempting to place sequence derived from ")
1229 <<
seq->sequence->familyName.getValue() <<
" under context "
1231 <<
", but it was previously randomized for context ";
1233 err <<
seq->context;
1241 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
1242 for (
auto val : interVal->sequences)
1243 if (failed(isValidContext(val, op)))
1248 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
1249 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
1250 if (failed(isValidContext(seqVal, op)))
1253 return DeletionKind::Keep;
1256 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1257 SetVector<ElaboratorValue> set;
1258 for (
auto val : op.getElements())
1259 set.insert(state.at(val));
1261 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1262 std::move(set), op.getSet().getType());
1263 return DeletionKind::Delete;
1266 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
1267 auto set = get<SetStorage *>(op.getSet())->set;
1270 return op->emitError(
"cannot select from an empty set");
1274 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1275 std::mt19937 customRng(intAttr.getInt());
1281 state[op.getResult()] = set[selected];
1282 return DeletionKind::Delete;
1285 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
1286 auto original = get<SetStorage *>(op.getOriginal())->set;
1287 auto diff = get<SetStorage *>(op.getDiff())->set;
1289 SetVector<ElaboratorValue> result(original);
1290 result.set_subtract(diff);
1292 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1293 std::move(result), op.getResult().getType());
1294 return DeletionKind::Delete;
1297 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
1298 SetVector<ElaboratorValue> result;
1299 for (
auto set : op.getSets())
1300 result.set_union(
get<SetStorage *>(set)->set);
1302 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1303 std::move(result), op.getType());
1304 return DeletionKind::Delete;
1307 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1308 auto size = get<SetStorage *>(op.getSet())->set.size();
1309 state[op.getResult()] = size;
1310 return DeletionKind::Delete;
1316 FailureOr<DeletionKind> visitOp(SetCartesianProductOp op) {
1317 SetVector<ElaboratorValue> result;
1318 SmallVector<SmallVector<ElaboratorValue>> tuples;
1319 tuples.push_back({});
1321 for (
auto input : op.getInputs()) {
1322 auto &set = get<SetStorage *>(input)->set;
1324 SetVector<ElaboratorValue>
empty;
1325 state[op.getResult()] =
1326 sharedState.internalizer.internalize<SetStorage>(std::move(
empty),
1328 return DeletionKind::Delete;
1331 for (
unsigned i = 0, e = tuples.size(); i < e; ++i) {
1332 for (
auto setEl : set.getArrayRef().drop_back()) {
1333 tuples.push_back(tuples[i]);
1334 tuples.back().push_back(setEl);
1336 tuples[i].push_back(set.back());
1340 for (
auto &tup : tuples)
1342 sharedState.internalizer.internalize<TupleStorage>(std::move(tup)));
1344 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1345 std::move(result), op.getType());
1346 return DeletionKind::Delete;
1349 FailureOr<DeletionKind> visitOp(SetConvertToBagOp op) {
1350 auto set = get<SetStorage *>(op.getInput())->set;
1351 MapVector<ElaboratorValue, uint64_t> bag;
1352 for (
auto val : set)
1353 bag.insert({val, 1});
1354 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1355 std::move(bag), op.getType());
1356 return DeletionKind::Delete;
1359 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1360 MapVector<ElaboratorValue, uint64_t> bag;
1361 for (
auto [val, multiple] :
1362 llvm::zip(op.getElements(), op.getMultiples())) {
1366 bag[state.at(val)] += get<size_t>(multiple);
1369 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1370 std::move(bag), op.getType());
1371 return DeletionKind::Delete;
// Elaborates a weighted random selection from a bag: builds a running prefix
// sum of the element multiplicities, then uses `llvm::upper_bound` over that
// prefix sum to pick the element that becomes the op's result.
// NOTE(review): this listing has dropped lines (the empty-bag condition, the
// `if (auto intAttr =` head of the custom-seed lookup, the drawn index and the
// first arguments of `upper_bound`); confirm against the real file.
1374 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
1375 auto bag = get<BagStorage *>(op.getBag())->bag;
1378 return op->emitError(
"cannot select from an empty bag");
// NOTE(review): bags elsewhere in this file are built as
// MapVector<ElaboratorValue, uint64_t>; summing into uint32_t could truncate
// or wrap for very large multiplicities -- confirm the intended range.
1380 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1381 prefixSum.reserve(bag.size());
1382 uint32_t accumulator = 0;
1383 for (
auto [val, weight] : bag) {
1384 accumulator += weight;
1385 prefixSum.push_back({val, accumulator});
// Starts from a copy of the shared generator; replaced by a generator seeded
// from the "rtg.elaboration_custom_seed" attribute when that is present.
1388 auto customRng = sharedState.rng;
1390 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1391 customRng = std::mt19937(intAttr.getInt());
1395 auto *iter = llvm::upper_bound(
1397 [](uint32_t a,
const std::pair<ElaboratorValue, uint32_t> &b) {
1398 return a < b.second;
1401 state[op.getResult()] = iter->first;
1402 return DeletionKind::Delete;
// Elaborates a bag difference: walks the original bag and, for elements that
// also occur in `diff` with a smaller multiplicity, keeps the element with the
// multiplicities subtracted.
// NOTE(review): this listing has dropped lines (the body of the
// "not contained in diff" branch, whatever follows it, and the statement under
// `if (el.second <= toDiff)`); confirm against the real file.
1405 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1406 auto original = get<BagStorage *>(op.getOriginal())->bag;
1407 auto diff = get<BagStorage *>(op.getDiff())->bag;
1409 MapVector<ElaboratorValue, uint64_t> result;
1410 for (
const auto &el : original) {
1411 if (!diff.contains(el.first)) {
1419 auto toDiff = diff.lookup(el.first);
1420 if (el.second <= toDiff)
// Reached with el.second > toDiff only if the branch above skips the element
// (presumably via `continue` -- not visible here), so this cannot underflow.
1423 result.insert({el.first, el.second - toDiff});
1426 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1427 std::move(result), op.getType());
1428 return DeletionKind::Delete;
1431 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1432 MapVector<ElaboratorValue, uint64_t> result;
1433 for (
auto bag : op.getBags()) {
1434 auto val = get<BagStorage *>(bag)->bag;
1435 for (
auto [el, multiple] : val)
1436 result[el] += multiple;
1439 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1440 std::move(result), op.getType());
1441 return DeletionKind::Delete;
1444 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1445 auto size = get<BagStorage *>(op.getBag())->bag.size();
1446 state[op.getResult()] = size;
1447 return DeletionKind::Delete;
1450 FailureOr<DeletionKind> visitOp(BagConvertToSetOp op) {
1451 auto bag = get<BagStorage *>(op.getInput())->bag;
1452 SetVector<ElaboratorValue> set;
1453 for (
auto [k, v] : bag)
1455 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1456 std::move(set), op.getType());
1457 return DeletionKind::Delete;
1460 FailureOr<DeletionKind> visitOp(FixedRegisterOp op) {
1461 return visitPureOp(op);
1464 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1465 auto *val = sharedState.internalizer.create<VirtualRegisterStorage>(
1466 op.getAllowedRegsAttr(), op.getType());
1467 state[op.getResult()] = val;
1468 materializer.registerIdentityValue(val);
1469 return DeletionKind::Delete;
1472 StringAttr substituteFormatString(StringAttr formatString,
1473 ValueRange substitutes)
const {
1474 if (substitutes.empty() || formatString.empty())
1475 return formatString;
1477 auto original = formatString.getValue().str();
1478 for (
auto [i, subst] :
llvm::enumerate(substitutes)) {
1479 size_t startPos = 0;
1480 std::string from =
"{{" + std::to_string(i) +
"}}";
1481 while ((startPos = original.find(from, startPos)) != std::string::npos) {
1482 auto substString = std::to_string(get<size_t>(subst));
1483 original.replace(startPos, from.length(), substString);
1487 return StringAttr::get(formatString.getContext(), original);
1490 FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
1491 SmallVector<ElaboratorValue> array;
1492 array.reserve(op.getElements().size());
1493 for (
auto val : op.getElements())
1494 array.emplace_back(state.at(val));
1496 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1497 op.getResult().getType(), std::move(array));
1498 return DeletionKind::Delete;
1501 FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
1502 auto array = get<ArrayStorage *>(op.getArray())->array;
1503 size_t idx = get<size_t>(op.getIndex());
1505 if (array.size() <= idx)
1506 return op->emitError(
"invalid to access index ")
1507 << idx <<
" of an array with " << array.size() <<
" elements";
1509 state[op.getResult()] = array[idx];
1510 return DeletionKind::Delete;
1513 FailureOr<DeletionKind> visitOp(ArrayInjectOp op) {
1514 auto array = get<ArrayStorage *>(op.getArray())->array;
1515 size_t idx = get<size_t>(op.getIndex());
1517 if (array.size() <= idx)
1518 return op->emitError(
"invalid to access index ")
1519 << idx <<
" of an array with " << array.size() <<
" elements";
1521 array[idx] = state[op.getValue()];
1522 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1523 op.getResult().getType(), std::move(array));
1524 return DeletionKind::Delete;
1527 FailureOr<DeletionKind> visitOp(ArraySizeOp op) {
1528 auto array = get<ArrayStorage *>(op.getArray())->array;
1529 state[op.getResult()] = array.size();
1530 return DeletionKind::Delete;
1533 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
1535 substituteFormatString(op.getFormatStringAttr(), op.getArgs());
1536 state[op.getLabel()] = LabelValue(substituted);
1537 return DeletionKind::Delete;
1540 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1541 auto *val = sharedState.internalizer.create<UniqueLabelStorage>(
1542 substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
1543 state[op.getLabel()] = val;
1544 materializer.registerIdentityValue(val);
1545 return DeletionKind::Delete;
1548 FailureOr<DeletionKind> visitOp(LabelOp op) {
return DeletionKind::Keep; }
// Elaborates a random number draw. The upper bound operand is decremented by
// one before use, so it is treated as exclusive. A generator seeded from the
// "rtg.elaboration_custom_seed" attribute is used when that attribute exists.
// NOTE(review): this listing has dropped lines (the empty-range condition, the
// `if (auto intAttr =` head, both right-hand sides that compute the result and
// the else branch); confirm against the real file.
1550 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1551 size_t lower = get<size_t>(op.getLowerBound());
// NOTE(review): if the upper bound is 0 this unsigned subtraction wraps to
// the maximum size_t; verify the (not visible) empty-range check covers it.
1552 size_t upper = get<size_t>(op.getUpperBound()) - 1;
1554 return op->emitError(
"cannot select a number from an empty range");
1557 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1558 std::mt19937 customRng(intAttr.getInt());
1559 state[op.getResult()] =
1562 state[op.getResult()] =
1566 return DeletionKind::Delete;
// Elaborates an index-to-immediate conversion: rejects inputs larger than the
// all-ones value of the result width, otherwise records an ImmediateAttr
// holding the input as an APInt of that width.
// NOTE(review): the tail of the diagnostic stream (after `width`) is missing
// from this listing.
1569 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1570 size_t input = get<size_t>(op.getInput());
1571 auto width = op.getType().getWidth();
1572 auto emitError = [&]() {
return op->emitError(); };
// NOTE(review): getZExtValue() requires the value to fit in 64 bits; confirm
// that immediate widths above 64 cannot reach this point.
1573 if (input > APInt::getAllOnes(width).getZExtValue())
1574 return emitError() <<
"cannot represent " << input <<
" with " << width
1577 state[op.getResult()] =
1578 ImmediateAttr::get(op.getContext(), APInt(width, input));
1579 return DeletionKind::Delete;
// Elaborates running a sequence on another context. `from` is the current
// context (or the default context of the target's type when none is set) and
// `to` is the context operand. A registered context switch is looked up in
// `testState.contextSwitches`; its sequence family is instantiated with
// (from, to, sequence) as arguments, randomized under `to`, materialized and
// embedded.
// NOTE(review): this listing has dropped lines, including the end of the
// `emitError` lambda, the condition guarding the first early return, and the
// declarations of `randSeqVal` and `randSeq`; confirm against the real file.
1582 FailureOr<DeletionKind> visitOp(OnContextOp op) {
1583 ContextResourceAttrInterface from = currentContext,
1584 to = cast<ContextResourceAttrInterface>(
1585 get<TypedAttr>(op.getContext()));
1586 if (!currentContext)
1587 from = DefaultContextAttr::get(op->getContext(), to.getType());
// Diagnostic helper handed to the materializer; attaches a note naming this op.
1589 auto emitError = [&]() {
1590 auto diag = op.emitError();
1591 diag.attachNote(op.getLoc())
1592 <<
"while materializing value for context switching for " << op;
// Shortcut path: the sequence is materialized, randomized and embedded
// directly. The condition selecting this path is not visible in this listing.
1597 Value seqVal = materializer.materialize(
1598 get<SequenceStorage *>(op.getSequence()), op.getLoc(), emitError);
1603 materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1604 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1605 return DeletionKind::Delete;
// Lookup order: exact (from, to), then (from, any), then (any, to), then
// (any, any). The first registered match wins.
1611 auto *iter = testState.contextSwitches.find({from, to});
1614 if (iter == testState.contextSwitches.end())
1615 iter = testState.contextSwitches.find(
1616 {from, AnyContextAttr::get(op->getContext(), to.getType())});
1619 if (iter == testState.contextSwitches.end())
1620 iter = testState.contextSwitches.find(
1621 {AnyContextAttr::get(op->getContext(), from.getType()), to});
1624 if (iter == testState.contextSwitches.end())
1625 iter = testState.contextSwitches.find(
1626 {AnyContextAttr::get(op->getContext(), from.getType()),
1627 AnyContextAttr::get(op->getContext(), to.getType())});
1633 if (iter == testState.contextSwitches.end())
1634 return op->emitError(
"no context transition registered to switch from ")
1635 << from <<
" to " << to;
// Instantiate the switch sequence family with (from, to, sequence).
1637 auto familyName = iter->second->familyName;
1638 SmallVector<ElaboratorValue> args{from, to,
1639 get<SequenceStorage *>(op.getSequence())};
1640 auto *
seq = sharedState.internalizer.internalize<SequenceStorage>(
1641 familyName, std::move(args));
1643 sharedState.internalizer.create<RandomizedSequenceStorage>(to,
seq);
1644 materializer.registerIdentityValue(randSeq);
1645 Value seqVal = materializer.materialize(randSeq, op.getLoc(), emitError);
1649 materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
1650 return DeletionKind::Delete;
1653 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1654 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1655 get<SequenceStorage *>(op.getSequence());
1656 return DeletionKind::Delete;
1659 FailureOr<DeletionKind> visitOp(MemoryBlockDeclareOp op) {
1660 auto *val = sharedState.internalizer.create<MemoryBlockStorage>(
1661 op.getBaseAddress(), op.getEndAddress(), op.getType());
1662 state[op.getResult()] = val;
1663 materializer.registerIdentityValue(val);
1664 return DeletionKind::Delete;
1667 FailureOr<DeletionKind> visitOp(MemoryAllocOp op) {
1668 size_t size = get<size_t>(op.getSize());
1669 size_t alignment = get<size_t>(op.getAlignment());
1670 auto *memBlock = get<MemoryBlockStorage *>(op.getMemoryBlock());
1671 auto *val = sharedState.internalizer.create<MemoryStorage>(memBlock, size,
1673 state[op.getResult()] = val;
1674 materializer.registerIdentityValue(val);
1675 return DeletionKind::Delete;
1678 FailureOr<DeletionKind> visitOp(MemorySizeOp op) {
1679 auto *memory = get<MemoryStorage *>(op.getMemory());
1680 state[op.getResult()] = memory->size;
1681 return DeletionKind::Delete;
1684 FailureOr<DeletionKind> visitOp(TupleCreateOp op) {
1685 SmallVector<ElaboratorValue> values;
1686 values.reserve(op.getElements().size());
1687 for (
auto el : op.getElements())
1688 values.push_back(state[el]);
1690 state[op.getResult()] =
1691 sharedState.internalizer.internalize<TupleStorage>(std::move(values));
1692 return DeletionKind::Delete;
1695 FailureOr<DeletionKind> visitOp(TupleExtractOp op) {
1696 auto *tuple = get<TupleStorage *>(op.getTuple());
1697 state[op.getResult()] = tuple->values[op.getIndex().getZExtValue()];
1698 return DeletionKind::Delete;
1701 FailureOr<DeletionKind> visitOp(CommentOp op) {
return DeletionKind::Keep; }
1703 FailureOr<DeletionKind> visitOp(rtg::YieldOp op) {
1704 return DeletionKind::Keep;
// Elaborates a validate op: creates a fresh ValidationValue from the op's
// type, the elaborated reference and the elaborated default value, records it
// as the result, registers it as an identity value and maps it to the op's
// existing SSA value. The op itself is kept.
// NOTE(review): the last argument(s) of the `create<ValidationValue>` call are
// missing from this listing.
1707 FailureOr<DeletionKind> visitOp(ValidateOp op) {
1708 auto *validationVal = sharedState.internalizer.create<ValidationValue>(
1709 op.getType(), state[op.getRef()], state[op.getDefaultValue()],
1711 state[op.getValue()] = validationVal;
1712 materializer.registerIdentityValue(validationVal);
1713 materializer.map(validationVal, op.getValue());
1714 return DeletionKind::Keep;
1717 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1718 bool cond = get<bool>(op.getCondition());
1719 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1720 if (toElaborate.empty())
1721 return DeletionKind::Delete;
1727 SmallVector<ElaboratorValue> yieldedVals;
1728 if (failed(elaborate(toElaborate, {}, yieldedVals)))
1732 for (
auto [res, out] :
llvm::zip(op.getResults(), yieldedVals))
1735 return DeletionKind::Delete;
1738 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1739 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1740 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1741 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1742 return op->emitOpError(
"can only elaborate index type iterator");
1744 auto lowerBound = get<size_t>(op.getLowerBound());
1745 auto step = get<size_t>(op.getStep());
1746 auto upperBound = get<size_t>(op.getUpperBound());
1752 state[op.getInductionVar()] = lowerBound;
1753 for (
auto [iterArg, initArg] :
1754 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1755 state[iterArg] = state.at(initArg);
1758 SmallVector<ElaboratorValue> yieldedVals;
1759 for (
size_t i = lowerBound; i < upperBound; i += step) {
1760 yieldedVals.clear();
1761 if (failed(elaborate(op.getBodyRegion(), {}, yieldedVals)))
1766 state[op.getInductionVar()] = i + step;
1767 for (
auto [iterArg, prevIterArg] :
1768 llvm::zip(op.getRegionIterArgs(), yieldedVals))
1769 state[iterArg] = prevIterArg;
1773 for (
auto [res, iterArg] :
1774 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1775 state[res] = state.at(iterArg);
1777 return DeletionKind::Delete;
1780 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1781 return DeletionKind::Delete;
1784 FailureOr<DeletionKind> visitOp(arith::AddIOp op) {
1785 if (!isa<IndexType>(op.getType()))
1786 return op->emitError(
"only index operands supported");
1788 size_t lhs = get<size_t>(op.getLhs());
1789 size_t rhs = get<size_t>(op.getRhs());
1790 state[op.getResult()] = lhs + rhs;
1791 return DeletionKind::Delete;
1794 FailureOr<DeletionKind> visitOp(arith::AndIOp op) {
1795 if (!op.getType().isSignlessInteger(1))
1796 return op->emitError(
"only 'i1' operands supported");
1798 bool lhs = get<bool>(op.getLhs());
1799 bool rhs = get<bool>(op.getRhs());
1800 state[op.getResult()] = lhs && rhs;
1801 return DeletionKind::Delete;
1804 FailureOr<DeletionKind> visitOp(arith::XOrIOp op) {
1805 if (!op.getType().isSignlessInteger(1))
1806 return op->emitError(
"only 'i1' operands supported");
1808 bool lhs = get<bool>(op.getLhs());
1809 bool rhs = get<bool>(op.getRhs());
1810 state[op.getResult()] = lhs != rhs;
1811 return DeletionKind::Delete;
1814 FailureOr<DeletionKind> visitOp(arith::OrIOp op) {
1815 if (!op.getType().isSignlessInteger(1))
1816 return op->emitError(
"only 'i1' operands supported");
1818 bool lhs = get<bool>(op.getLhs());
1819 bool rhs = get<bool>(op.getRhs());
1820 state[op.getResult()] = lhs || rhs;
1821 return DeletionKind::Delete;
1824 FailureOr<DeletionKind> visitOp(arith::SelectOp op) {
1825 bool cond = get<bool>(op.getCondition());
1826 auto trueVal = state[op.getTrueValue()];
1827 auto falseVal = state[op.getFalseValue()];
1828 state[op.getResult()] = cond ? trueVal : falseVal;
1829 return DeletionKind::Delete;
1832 FailureOr<DeletionKind> visitOp(index::AddOp op) {
1833 size_t lhs = get<size_t>(op.getLhs());
1834 size_t rhs = get<size_t>(op.getRhs());
1835 state[op.getResult()] = lhs + rhs;
1836 return DeletionKind::Delete;
1839 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
1840 size_t lhs = get<size_t>(op.getLhs());
1841 size_t rhs = get<size_t>(op.getRhs());
1843 switch (op.getPred()) {
1844 case index::IndexCmpPredicate::EQ:
1845 result = lhs == rhs;
1847 case index::IndexCmpPredicate::NE:
1848 result = lhs != rhs;
1850 case index::IndexCmpPredicate::ULT:
1853 case index::IndexCmpPredicate::ULE:
1854 result = lhs <= rhs;
1856 case index::IndexCmpPredicate::UGT:
1859 case index::IndexCmpPredicate::UGE:
1860 result = lhs >= rhs;
1863 return op->emitOpError(
"elaboration not supported");
1865 state[op.getResult()] = result;
1866 return DeletionKind::Delete;
// Routes an operation to the matching `visitOp` overload. The arith, index
// and scf ops listed in the TypeSwitch case are handled here; everything else
// is forwarded to `RTGBase::dispatchOpVisitor`.
// NOTE(review): the `.Case<` head and part of the op list are missing from
// this listing.
1869 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
1870 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
1873 arith::AddIOp, arith::XOrIOp, arith::AndIOp, arith::OrIOp,
1876 index::AddOp, index::CmpOp,
1878 scf::IfOp, scf::ForOp, scf::YieldOp>(
1879 [&](
auto op) {
return visitOp(op); })
1880 .Default([&](Operation *op) {
return RTGBase::dispatchOpVisitor(op); });
// Elaborates a single-block region. Binds the region's block arguments to
// `regionArguments`, visits every operation in order, asks the materializer to
// materialize operations whose visitor returned DeletionKind::Keep, and
// finally appends the elaborated terminator operands to `terminatorOperands`.
// Regions with more than one block are rejected with an error.
// NOTE(review): this listing has dropped lines (the check of `result` for
// failure, the failure returns, the debug-macro wrapper around the printing
// code and the final return); confirm against the real file.
1884 LogicalResult elaborate(Region ®ion,
1885 ArrayRef<ElaboratorValue> regionArguments,
1886 SmallVector<ElaboratorValue> &terminatorOperands) {
1887 if (region.getBlocks().size() > 1)
1888 return region.getParentOp()->emitOpError(
1889 "regions with more than one block are not supported");
// llvm::zip stops at the shorter range, so surplus arguments stay unbound.
1891 for (
auto [arg, elabArg] :
1892 llvm::zip(region.getArguments(), regionArguments))
1893 state[arg] = elabArg;
1895 Block *block = ®ion.front();
1896 for (
auto &op : *block) {
1897 auto result = dispatchOpVisitor(&op);
1901 if (*result == DeletionKind::Keep)
1902 if (failed(materializer.materialize(&op, state)))
// Debug output: prints the elaborated value of each result, or "unknown"
// when no value was recorded for it.
1906 llvm::dbgs() <<
"Elaborated " << op <<
" to\n[";
1908 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](
auto res) {
1909 if (state.contains(res))
1910 llvm::dbgs() << state.at(res);
1912 llvm::dbgs() <<
"unknown";
1915 llvm::dbgs() <<
"]\n\n";
1919 if (region.front().mightHaveTerminator())
1920 for (
auto val : region.front().getTerminator()->getOperands())
1921 terminatorOperands.push_back(state.at(val));
// Data members of the elaborator.
// State shared across elaborations (used above for the internalizer, the
// random generator, the symbol table and the name generator).
1928 SharedState &sharedState;
// Per-test state; holds the registered context switches.
1931 TestState &testState;
// Materializer used to create IR for kept operations and requested values.
1935 Materializer &materializer;
// Maps each SSA value visited so far to its elaborated value.
1938 DenseMap<Value, ElaboratorValue> state;
// Context being elaborated under; tested for null before use in the
// on-context visitor.
1941 ContextResourceAttrInterface currentContext;
// Elaborates one randomized sequence: looks up the sequence family by name,
// clones it without regions under a fresh symbol name, gives the clone an
// empty body block and elaborates the family's body into it with the
// sequence's arguments and context. Afterwards the clone's sequence type is
// set from the materializer's block argument types.
// NOTE(review): this listing has dropped lines (the return type, the
// declaration receiving the family lookup, any null check on it, the tail of
// the `elaborate` call and the final return); confirm against the real file.
1946Materializer::elaborateSequence(
const RandomizedSequenceStorage *
seq,
1947 SmallVector<ElaboratorValue> &elabArgs) {
1949 sharedState.table.lookup<SequenceOp>(
seq->sequence->familyName);
1952 OpBuilder builder(familyOp);
1953 auto seqOp = builder.cloneWithoutRegions(familyOp);
1954 auto name = sharedState.names.newName(
seq->sequence->familyName.getValue());
1955 seqOp.setSymName(name);
1956 seqOp.getBodyRegion().emplaceBlock();
1957 sharedState.table.insert(seqOp);
// The name came from the name generator, so the symbol table insertion is
// expected not to rename the op.
1958 assert(seqOp.getSymName() == name &&
"should not have been renamed");
1960 LLVM_DEBUG(llvm::dbgs() <<
"\n=== Elaborating sequence family @"
1961 << familyOp.getSymName() <<
" into @"
1962 << seqOp.getSymName() <<
" under context "
1963 <<
seq->context <<
"\n\n");
// A nested materializer/elaborator pair writes into the new sequence body.
1965 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()), testState,
1966 sharedState, elabArgs);
1967 Elaborator elaborator(sharedState, testState, materializer,
seq->context);
1968 SmallVector<ElaboratorValue> yieldedVals;
1969 if (failed(elaborator.elaborate(familyOp.getBodyRegion(),
seq->sequence->args,
1973 seqOp.setSequenceType(
1974 SequenceType::get(builder.getContext(), materializer.getBlockArgTypes()));
1975 materializer.finalize();
// The elaboration pass. `runOnOperation` first matches tests against targets
// and then elaborates the module.
// NOTE(review): lines between the base class and the first member are missing
// from this listing, as is the closing `};`.
1985struct ElaborationPass
1986 :
public rtg::impl::ElaborationPassBase<ElaborationPass> {
// Pass entry point.
1989 void runOnOperation()
override;
// Creates a per-target clone of each test whose target type is satisfied.
1990 void matchTestsAgainstTargets(SymbolTable &table);
// Elaborates all targets and then all tests that have a target assigned.
1991 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
1995void ElaborationPass::runOnOperation() {
1996 auto moduleOp = getOperation();
1997 SymbolTable table(moduleOp);
1999 matchTestsAgainstTargets(table);
2001 if (failed(elaborateModule(moduleOp, table)))
2002 return signalPassFailure();
// For every test without a target attribute, finds each target whose entries
// cover the test's required entries (same name and type) and inserts a clone
// of the test named "<test>_<target>" with its target attribute set.
// NOTE(review): this listing has dropped lines (the `continue` statements,
// the `targetIdx` increment, the body of the mismatch branch, the use of
// `isSubtype`, the `matched` update and the statement under the final `if`);
// confirm against the real file.
2005void ElaborationPass::matchTestsAgainstTargets(SymbolTable &table) {
2006 auto moduleOp = getOperation();
// Early-inc range: the loop body may modify the test list while iterating.
2008 for (
auto test :
llvm::make_early_inc_range(moduleOp.getOps<TestOp>())) {
2009 if (test.getTargetAttr())
2012 bool matched =
false;
2014 for (
auto target : moduleOp.getOps<TargetOp>()) {
2018 bool isSubtype =
true;
2019 auto testEntries = test.getTargetType().getEntries();
2020 auto targetEntries = target.getTarget().getEntries();
// Single forward pass over both entry lists using name order.
// Presumably both lists are sorted by name -- TODO confirm.
2024 size_t targetIdx = 0;
2025 for (
auto testEntry : testEntries) {
2027 while (targetIdx < targetEntries.size() &&
2028 targetEntries[targetIdx].name.getValue() <
2029 testEntry.name.getValue())
2033 if (targetIdx >= targetEntries.size() ||
2034 targetEntries[targetIdx].name != testEntry.name ||
2035 targetEntries[targetIdx].type != testEntry.type) {
2044 IRRewriter rewriter(test);
2046 auto newTest = cast<TestOp>(test->clone());
2047 newTest.setSymName(test.getSymName().str() +
"_" +
2048 target.getSymName().str());
2052 newTest.setTargetAttr(target.getSymNameAttr());
2054 table.insert(newTest, rewriter.getInsertionPoint());
// Controlled by the pass option `deleteUnmatchedTests`; the action taken is
// not visible in this listing.
2058 if (matched || deleteUnmatchedTests)
// True for memory block types and context resource types.
// NOTE(review): the enclosing function header is missing from this listing;
// presumably this is the body of onlyLegalToMaterializeInTarget(Type) --
// confirm.
2064 return isa<MemoryBlockType, ContextResourceTypeInterface>(type);
// Elaborates the module in two phases. First every target is elaborated and
// its type, yielded values and test state are cached by symbol name. Then
// every test with a target attribute is elaborated, starting from a copy of
// its target's test state and receiving the target's yielded values for the
// entries the test declares.
// NOTE(review): this listing has dropped lines (the tails of the `elaborate`
// calls and their failure returns, `finalize` for targets, the declaration and
// increment of `i`, the `continue`/`break` statements and the final return);
// confirm against the real file.
2067LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
2068 SymbolTable &table) {
2069 SharedState state(table, seed);
2072 state.names.add(moduleOp);
// Cached result of elaborating one target.
2074 struct TargetElabResult {
2075 DictType targetType;
2076 SmallVector<ElaboratorValue> yields;
2077 TestState testState;
// Phase 1: elaborate every target.
2081 DenseMap<StringAttr, TargetElabResult> targetMap;
2082 for (
auto targetOp : moduleOp.getOps<TargetOp>()) {
2083 LLVM_DEBUG(llvm::dbgs() <<
"=== Elaborating target @"
2084 << targetOp.getSymName() <<
"\n\n");
2086 auto &result = targetMap[targetOp.getSymNameAttr()];
2087 result.targetType = targetOp.getTarget();
2089 SmallVector<ElaboratorValue> blockArgs;
2090 Materializer targetMaterializer(OpBuilder::atBlockBegin(targetOp.getBody()),
2091 result.testState, state, blockArgs);
2092 Elaborator targetElaborator(state, result.testState, targetMaterializer);
2095 if (failed(targetElaborator.elaborate(targetOp.getBodyRegion(), {},
// Phase 2: elaborate every test that has a target assigned.
2102 for (
auto testOp : moduleOp.getOps<TestOp>()) {
2106 if (!testOp.getTargetAttr())
2109 LLVM_DEBUG(llvm::dbgs()
2110 <<
"\n=== Elaborating test @" << testOp.getTemplateName()
2111 <<
" for target @" << *testOp.getTarget() <<
"\n\n");
// Copies the cached target result; each test works on its own TestState.
2114 auto targetResult = targetMap[testOp.getTargetAttr()];
2115 TestState testState = targetResult.testState;
2116 testState.name = testOp.getSymNameAttr();
// Keep only the target yields whose entry name matches the test's i-th entry.
2118 SmallVector<ElaboratorValue> filteredYields;
2120 for (
auto [entry, yield] :
2121 llvm::zip(targetResult.targetType.getEntries(), targetResult.yields)) {
2122 if (i >= testOp.getTargetType().getEntries().size())
2125 if (entry.name == testOp.getTargetType().getEntries()[i].name) {
2126 filteredYields.push_back(yield);
2133 SmallVector<ElaboratorValue> blockArgs;
2134 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()),
2135 testState, state, blockArgs);
// Pre-map yields to the test's existing block arguments.
// NOTE(review): a guarding condition may sit on the missing line before the
// `map` call.
2137 for (
auto [arg, val] :
2138 llvm::zip(testOp.getBody()->getArguments(), filteredYields))
2140 materializer.map(val, arg);
2142 Elaborator elaborator(state, testState, materializer);
2143 SmallVector<ElaboratorValue> ignore;
2144 if (failed(elaborator.elaborate(testOp.getBodyRegion(), filteredYields,
2148 materializer.finalize();
assert(baseType &&"element must be base type")
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the specified range.
static bool onlyLegalToMaterializeInTarget(Type type)
static void print(TypedAttr val, llvm::raw_ostream &os)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
This helps visit TypeOp nodes.
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
static llvm::hash_code hash_value(const ModulePort &port)
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static LabelValue getEmptyKey()
static LabelValue getTombstoneKey()
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)
static unsigned getEmptyKey()