21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Index/IR/IndexDialect.h"
23#include "mlir/Dialect/Index/IR/IndexOps.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/DenseMapInfoVariant.h"
28#include "llvm/Support/Debug.h"
34#define GEN_PASS_DEF_ELABORATIONPASS
35#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
44#define DEBUG_TYPE "rtg-elaboration"
56 size_t n = w / 32 + (w % 32 != 0);
58 return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
63 const uint32_t diff = b - a + 1;
67 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
71 uint32_t width = digits - llvm::countl_zero(diff) - 1;
72 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
91struct SequenceStorage;
92struct RandomizedSequenceStorage;
93struct InterleavedSequenceStorage;
95struct VirtualRegisterStorage;
96struct UniqueLabelStorage;
99struct MemoryBlockStorage;
100struct ValidationValue;
101struct ValidationMuxedValue;
102struct ImmediateConcatStorage;
103struct ImmediateSliceStorage;
109 LabelValue(StringAttr name) : name(name) {}
111 bool operator==(
const LabelValue &other)
const {
return name == other.name; }
118using ElaboratorValue = std::variant<
119 TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
120 RandomizedSequenceStorage *, InterleavedSequenceStorage *, SetStorage *,
121 VirtualRegisterStorage *, UniqueLabelStorage *, LabelValue, ArrayStorage *,
122 TupleStorage *, MemoryStorage *, MemoryBlockStorage *, ValidationValue *,
123 ValidationMuxedValue *, ImmediateConcatStorage *, ImmediateSliceStorage *>;
126llvm::hash_code
hash_value(
const LabelValue &val) {
131llvm::hash_code
hash_value(
const ElaboratorValue &val) {
133 [&val](
const auto &alternative) {
136 return llvm::hash_combine(val.index(), alternative);
/// Booleans are compared by plain value.
static bool isEqual(const bool &lhs, const bool &rhs) {
  return !(lhs != rhs);
}
165 static bool isEqual(
const LabelValue &lhs,
const LabelValue &rhs) {
182template <
typename StorageTy>
183struct HashedStorage {
184 HashedStorage(
unsigned hashcode = 0, StorageTy *storage =
nullptr)
185 : hashcode(hashcode), storage(storage) {}
195template <
typename StorageTy>
196struct StorageKeyInfo {
197 static inline HashedStorage<StorageTy> getEmptyKey() {
198 return HashedStorage<StorageTy>(0,
199 DenseMapInfo<StorageTy *>::getEmptyKey());
201 static inline HashedStorage<StorageTy> getTombstoneKey() {
202 return HashedStorage<StorageTy>(
203 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
206 static inline unsigned getHashValue(
const HashedStorage<StorageTy> &key) {
209 static inline unsigned getHashValue(
const StorageTy &key) {
213 static inline bool isEqual(
const HashedStorage<StorageTy> &lhs,
214 const HashedStorage<StorageTy> &rhs) {
215 return lhs.storage == rhs.storage;
217 static inline bool isEqual(
const StorageTy &lhs,
218 const HashedStorage<StorageTy> &rhs) {
219 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
222 return lhs.isEqual(rhs.storage);
231 static unsigned computeHash(
const SetVector<ElaboratorValue> &set,
233 llvm::hash_code setHash = 0;
234 for (
auto el : set) {
239 setHash = setHash ^ llvm::hash_combine(el);
241 return llvm::hash_combine(type, setHash);
244 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
245 : hashcode(computeHash(set, type)), set(std::move(set)), type(type) {}
247 bool isEqual(
const SetStorage *other)
const {
252 bool allContained =
true;
254 allContained &= other->set.contains(el);
256 return hashcode == other->hashcode && set.size() == other->set.size() &&
257 allContained && type == other->type;
261 const unsigned hashcode;
264 const SetVector<ElaboratorValue> set;
273 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
275 type,
llvm::hash_combine_range(bag.begin(), bag.
end()))),
276 bag(std::move(bag)), type(type) {}
278 bool isEqual(
const BagStorage *other)
const {
279 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
284 const unsigned hashcode;
288 const MapVector<ElaboratorValue, uint64_t> bag;
296struct SequenceStorage {
297 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
299 familyName,
llvm::hash_combine_range(args.begin(), args.
end()))),
300 familyName(familyName), args(std::move(args)) {}
302 bool isEqual(
const SequenceStorage *other)
const {
303 return hashcode == other->hashcode && familyName == other->familyName &&
308 const unsigned hashcode;
311 const StringAttr familyName;
314 const SmallVector<ElaboratorValue> args;
318struct InterleavedSequenceStorage {
319 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
321 : sequences(std::move(sequences)), batchSize(batchSize),
323 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
326 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
327 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
329 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
332 bool isEqual(
const InterleavedSequenceStorage *other)
const {
333 return hashcode == other->hashcode && sequences == other->sequences &&
334 batchSize == other->batchSize;
337 const SmallVector<ElaboratorValue> sequences;
339 const uint32_t batchSize;
342 const unsigned hashcode;
347 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
349 type,
llvm::hash_combine_range(array.begin(), array.
end()))),
350 type(type), array(array) {}
352 bool isEqual(
const ArrayStorage *other)
const {
353 return hashcode == other->hashcode && type == other->type &&
354 array == other->array;
358 const unsigned hashcode;
365 const SmallVector<ElaboratorValue> array;
370 TupleStorage(SmallVector<ElaboratorValue> &&values)
371 : hashcode(
llvm::hash_combine_range(values.begin(), values.
end())),
372 values(std::move(values)) {}
374 bool isEqual(
const TupleStorage *other)
const {
375 return hashcode == other->hashcode && values == other->values;
379 const unsigned hashcode;
381 const SmallVector<ElaboratorValue> values;
386struct ImmediateConcatStorage {
387 ImmediateConcatStorage(SmallVector<ElaboratorValue> &&operands)
388 : hashcode(
llvm::hash_combine_range(operands.begin(), operands.
end())),
389 operands(std::move(operands)) {}
391 bool isEqual(
const ImmediateConcatStorage *other)
const {
392 return hashcode == other->hashcode && operands == other->operands;
395 const unsigned hashcode;
396 const SmallVector<ElaboratorValue> operands;
401struct ImmediateSliceStorage {
402 ImmediateSliceStorage(ElaboratorValue input,
unsigned lowBit, Type type)
404 lowBit(lowBit), type(type) {}
406 bool isEqual(
const ImmediateSliceStorage *other)
const {
407 return hashcode == other->hashcode && input == other->input &&
408 lowBit == other->lowBit && type == other->type;
411 const unsigned hashcode;
412 const ElaboratorValue input;
413 const unsigned lowBit;
425struct IdentityValue {
427 IdentityValue(Type type) : type(type) {}
440 bool alreadyMaterialized =
false;
448struct VirtualRegisterStorage : IdentityValue {
449 VirtualRegisterStorage(VirtualRegisterConfigAttr allowedRegs, Type type)
450 : IdentityValue(type), allowedRegs(allowedRegs) {}
457 const VirtualRegisterConfigAttr allowedRegs;
460struct UniqueLabelStorage : IdentityValue {
461 UniqueLabelStorage(StringAttr name)
462 : IdentityValue(LabelType::
get(name.getContext())), name(name) {}
468 const StringAttr name;
472struct MemoryBlockStorage : IdentityValue {
473 MemoryBlockStorage(
const APInt &baseAddress,
const APInt &endAddress,
475 : IdentityValue(type), baseAddress(baseAddress), endAddress(endAddress) {}
480 const APInt baseAddress;
483 const APInt endAddress;
487struct MemoryStorage : IdentityValue {
488 MemoryStorage(MemoryBlockStorage *memoryBlock,
size_t size,
size_t alignment)
489 : IdentityValue(MemoryType::
get(memoryBlock->type.getContext(),
491 memoryBlock(memoryBlock), size(size), alignment(alignment) {}
493 MemoryBlockStorage *memoryBlock;
495 const size_t alignment;
499struct RandomizedSequenceStorage : IdentityValue {
500 RandomizedSequenceStorage(ContextResourceAttrInterface context,
501 SequenceStorage *sequence)
503 RandomizedSequenceType::
get(sequence->familyName.getContext())),
504 context(context), sequence(sequence) {}
507 const ContextResourceAttrInterface context;
509 const SequenceStorage *sequence;
513struct ValidationValue : IdentityValue {
514 ValidationValue(Type type,
const ElaboratorValue &ref,
515 const ElaboratorValue &defaultValue, StringAttr
id,
516 SmallVector<ElaboratorValue> &&defaultUsedValues,
517 SmallVector<ElaboratorValue> &&elseValues)
518 : IdentityValue(type), ref(ref), defaultValue(defaultValue), id(id),
519 defaultUsedValues(std::move(defaultUsedValues)),
520 elseValues(std::move(elseValues)) {}
522 const ElaboratorValue ref;
523 const ElaboratorValue defaultValue;
525 const SmallVector<ElaboratorValue> defaultUsedValues;
526 const SmallVector<ElaboratorValue> elseValues;
530struct ValidationMuxedValue : IdentityValue {
531 ValidationMuxedValue(Type type,
const ValidationValue *value,
unsigned idx)
532 : IdentityValue(type), value(value), idx(idx) {}
534 const ValidationValue *value;
548 template <
typename StorageTy,
typename... Args>
549 StorageTy *internalize(Args &&...args) {
550 static_assert(!std::is_base_of_v<IdentityValue, StorageTy> &&
551 "values with identity must not be internalized");
553 StorageTy storage(std::forward<Args>(args)...);
555 auto existing = getInternSet<StorageTy>().insert_as(
556 HashedStorage<StorageTy>(storage.hashcode), storage);
557 StorageTy *&storagePtr = existing.first->storage;
560 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
565 template <
typename StorageTy,
typename... Args>
566 StorageTy *create(Args &&...args) {
567 static_assert(std::is_base_of_v<IdentityValue, StorageTy> &&
568 "values with structural equivalence must be internalized");
570 return new (allocator.Allocate<StorageTy>())
571 StorageTy(std::forward<Args>(args)...);
575 template <
typename StorageTy>
576 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
578 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
579 return internedArrays;
580 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
582 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
584 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
585 return internedSequences;
586 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
587 return internedRandomizedSequences;
588 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
589 return internedInterleavedSequences;
590 else if constexpr (std::is_same_v<StorageTy, TupleStorage>)
591 return internedTuples;
592 else if constexpr (std::is_same_v<StorageTy, ImmediateConcatStorage>)
593 return internedImmediateConcatValues;
594 else if constexpr (std::is_same_v<StorageTy, ImmediateSliceStorage>)
595 return internedImmediateSliceValues;
597 static_assert(!
sizeof(StorageTy),
598 "no intern set available for this storage type.");
603 llvm::BumpPtrAllocator allocator;
608 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
610 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
611 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
612 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
614 DenseSet<HashedStorage<RandomizedSequenceStorage>,
615 StorageKeyInfo<RandomizedSequenceStorage>>
616 internedRandomizedSequences;
617 DenseSet<HashedStorage<InterleavedSequenceStorage>,
618 StorageKeyInfo<InterleavedSequenceStorage>>
619 internedInterleavedSequences;
620 DenseSet<HashedStorage<TupleStorage>, StorageKeyInfo<TupleStorage>>
622 DenseSet<HashedStorage<ImmediateConcatStorage>,
623 StorageKeyInfo<ImmediateConcatStorage>>
624 internedImmediateConcatValues;
625 DenseSet<HashedStorage<ImmediateSliceStorage>,
626 StorageKeyInfo<ImmediateSliceStorage>>
627 internedImmediateSliceValues;
634static llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
635 const ElaboratorValue &value);
637static void print(TypedAttr val, llvm::raw_ostream &os) {
638 os <<
"<attr " << val <<
">";
641static void print(BagStorage *val, llvm::raw_ostream &os) {
643 llvm::interleaveComma(val->bag, os,
644 [&](
const std::pair<ElaboratorValue, uint64_t> &el) {
645 os << el.first <<
" -> " << el.second;
647 os <<
"} at " << val <<
">";
650static void print(
bool val, llvm::raw_ostream &os) {
651 os <<
"<bool " << (val ?
"true" :
"false") <<
">";
654static void print(
size_t val, llvm::raw_ostream &os) {
655 os <<
"<index " << val <<
">";
658static void print(SequenceStorage *val, llvm::raw_ostream &os) {
659 os <<
"<sequence @" << val->familyName.getValue() <<
"(";
660 llvm::interleaveComma(val->args, os,
661 [&](
const ElaboratorValue &val) { os << val; });
662 os <<
") at " << val <<
">";
665static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
666 os <<
"<randomized-sequence derived from @"
667 << val->sequence->familyName.getValue() <<
" under context "
668 << val->context <<
"(";
669 llvm::interleaveComma(val->sequence->args, os,
670 [&](
const ElaboratorValue &val) { os << val; });
671 os <<
") at " << val <<
">";
674static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
675 os <<
"<interleaved-sequence [";
676 llvm::interleaveComma(val->sequences, os,
677 [&](
const ElaboratorValue &val) { os << val; });
678 os <<
"] batch-size " << val->batchSize <<
" at " << val <<
">";
681static void print(ArrayStorage *val, llvm::raw_ostream &os) {
683 llvm::interleaveComma(val->array, os,
684 [&](
const ElaboratorValue &val) { os << val; });
685 os <<
"] at " << val <<
">";
688static void print(SetStorage *val, llvm::raw_ostream &os) {
690 llvm::interleaveComma(val->set, os,
691 [&](
const ElaboratorValue &val) { os << val; });
692 os <<
"} at " << val <<
">";
695static void print(
const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
696 os <<
"<virtual-register " << val <<
" " << val->allowedRegs <<
">";
699static void print(
const UniqueLabelStorage *val, llvm::raw_ostream &os) {
700 os <<
"<unique-label " << val <<
" " << val->name <<
">";
703static void print(
const LabelValue &val, llvm::raw_ostream &os) {
704 os <<
"<label " << val.name <<
">";
707static void print(
const TupleStorage *val, llvm::raw_ostream &os) {
709 llvm::interleaveComma(val->values, os,
710 [&](
const ElaboratorValue &val) { os << val; });
714static void print(
const MemoryStorage *val, llvm::raw_ostream &os) {
715 os <<
"<memory {" << ElaboratorValue(val->memoryBlock)
716 <<
", size=" << val->size <<
", alignment=" << val->alignment <<
"}>";
719static void print(
const MemoryBlockStorage *val, llvm::raw_ostream &os) {
720 os <<
"<memory-block {"
721 <<
", address-width=" << val->baseAddress.getBitWidth()
722 <<
", base-address=" << val->baseAddress
723 <<
", end-address=" << val->endAddress <<
"}>";
726static void print(
const ValidationValue *val, llvm::raw_ostream &os) {
727 os <<
"<validation-value {type=" << val->type <<
", ref=" << val->ref
728 <<
", defaultValue=" << val->defaultValue <<
"}>";
731static void print(
const ValidationMuxedValue *val, llvm::raw_ostream &os) {
732 os <<
"<validation-muxed-value (" << val->value <<
") at " << val->idx <<
">";
735static void print(
const ImmediateConcatStorage *val, llvm::raw_ostream &os) {
736 os <<
"<immediate-concat [";
737 llvm::interleaveComma(val->operands, os,
738 [&](
const ElaboratorValue &val) { os << val; });
742static void print(
const ImmediateSliceStorage *val, llvm::raw_ostream &os) {
743 os <<
"<immediate-slice " << val->input <<
" from " << val->lowBit <<
">";
747 const ElaboratorValue &value) {
748 std::visit([&](
auto val) {
print(val, os); }, value);
763 SharedState(SymbolTable &table,
unsigned seed) : table(table), rng(seed) {}
768 Internalizer internalizer;
778 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
786 Materializer(OpBuilder builder, TestState &testState,
787 SharedState &sharedState,
788 SmallVector<ElaboratorValue> &blockArgs)
789 : builder(builder), testState(testState), sharedState(sharedState),
790 blockArgs(blockArgs) {}
794 Value materialize(ElaboratorValue val, Location loc,
795 function_ref<InFlightDiagnostic()> emitError) {
796 auto iter = materializedValues.find(val);
797 if (iter != materializedValues.end())
800 LLVM_DEBUG(llvm::dbgs() <<
"Materializing " << val);
804 Value res = std::visit(
806 if constexpr (std::is_base_of_v<IdentityValue,
807 std::remove_pointer_t<
808 std::decay_t<
decltype(value)>>>) {
809 if (identityValueRoot.contains(value)) {
812 static_cast<IdentityValue *
>(value)->alreadyMaterialized;
813 assert(!materialized &&
"must not already be materialized");
817 return visit(value, loc, emitError);
820 Value arg = builder.getBlock()->addArgument(value->type, loc);
821 blockArgs.push_back(val);
822 blockArgTypes.push_back(arg.getType());
823 materializedValues[val] = arg;
827 return visit(value, loc, emitError);
831 LLVM_DEBUG(llvm::dbgs() <<
" to\n" << res <<
"\n\n");
842 LogicalResult materialize(Operation *op,
843 DenseMap<Value, ElaboratorValue> &state) {
844 if (op->getNumRegions() > 0)
845 return op->emitOpError(
"ops with nested regions must be elaborated away");
853 for (
auto res : op->getResults())
854 if (!res.use_empty() && !isa<ValidateOp>(op))
855 return op->emitOpError(
856 "ops with results that have uses are not supported");
858 if (op->getParentRegion() == builder.getBlock()->getParent()) {
861 deleteOpsUntil([&](
auto iter) {
return &*iter == op; });
863 if (builder.getInsertionPoint() == builder.getBlock()->end())
864 return op->emitError(
"operation did not occur after the current "
865 "materializer insertion point");
867 LLVM_DEBUG(llvm::dbgs() <<
"Modifying in-place: " << *op <<
"\n\n");
869 LLVM_DEBUG(llvm::dbgs() <<
"Materializing a clone of " << *op <<
"\n\n");
870 op = builder.clone(*op);
871 builder.setInsertionPoint(op);
874 for (
auto &operand : op->getOpOperands()) {
875 auto emitError = [&]() {
876 auto diag = op->emitError();
877 diag.attachNote(op->getLoc())
878 <<
"while materializing value for operand#"
879 << operand.getOperandNumber();
883 auto elabVal = state.at(operand.get());
884 Value val = materialize(elabVal, op->getLoc(), emitError);
888 state[val] = elabVal;
892 builder.setInsertionPointAfter(op);
899 deleteOpsUntil([](
auto iter) {
return false; });
901 for (
auto *op :
llvm::reverse(toDelete))
908 void registerIdentityValue(IdentityValue *val) {
909 identityValueRoot.insert(val);
912 ArrayRef<Type> getBlockArgTypes()
const {
return blockArgTypes; }
914 void map(ElaboratorValue eval, Value val) { materializedValues[eval] = val; }
916 template <
typename OpTy,
typename... Args>
917 OpTy create(Location location, Args &&...args) {
918 return OpTy::create(builder, location, std::forward<Args>(args)...);
922 SequenceOp elaborateSequence(
const RandomizedSequenceStorage *
seq,
923 SmallVector<ElaboratorValue> &elabArgs);
925 void deleteOpsUntil(function_ref<
bool(Block::iterator)> stop) {
926 auto ip = builder.getInsertionPoint();
927 while (ip != builder.getBlock()->end() && !stop(ip)) {
928 LLVM_DEBUG(llvm::dbgs() <<
"Marking to be deleted: " << *ip <<
"\n\n");
929 toDelete.push_back(&*ip);
931 builder.setInsertionPointAfter(&*ip);
932 ip = builder.getInsertionPoint();
936 Value visit(TypedAttr val, Location loc,
937 function_ref<InFlightDiagnostic()> emitError) {
940 if (
auto intAttr = dyn_cast<IntegerAttr>(val);
941 intAttr && isa<IndexType>(val.getType())) {
942 Value res = index::ConstantOp::create(builder, loc, intAttr);
943 materializedValues[val] = res;
950 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
952 emitError() <<
"materializer of dialect '"
953 << val.getDialect().getNamespace()
954 <<
"' unable to materialize value for attribute '" << val
959 Value res = op->getResult(0);
960 materializedValues[val] = res;
964 Value visit(
size_t val, Location loc,
965 function_ref<InFlightDiagnostic()> emitError) {
966 Value res = index::ConstantOp::create(builder, loc, val);
967 materializedValues[val] = res;
971 Value visit(
bool val, Location loc,
972 function_ref<InFlightDiagnostic()> emitError) {
973 Value res = index::BoolConstantOp::create(builder, loc, val);
974 materializedValues[val] = res;
978 Value visit(ArrayStorage *val, Location loc,
979 function_ref<InFlightDiagnostic()> emitError) {
980 SmallVector<Value> elements;
981 elements.reserve(val->array.size());
982 for (
auto el : val->array) {
983 auto materialized = materialize(el, loc, emitError);
987 elements.push_back(materialized);
990 Value res = ArrayCreateOp::create(builder, loc, val->type, elements);
991 materializedValues[val] = res;
995 Value visit(SetStorage *val, Location loc,
996 function_ref<InFlightDiagnostic()> emitError) {
997 SmallVector<Value> elements;
998 elements.reserve(val->set.size());
999 for (
auto el : val->set) {
1000 auto materialized = materialize(el, loc, emitError);
1004 elements.push_back(materialized);
1007 auto res = SetCreateOp::create(builder, loc, val->type, elements);
1008 materializedValues[val] = res;
1012 Value visit(BagStorage *val, Location loc,
1013 function_ref<InFlightDiagnostic()> emitError) {
1014 SmallVector<Value> values, weights;
1015 values.reserve(val->bag.size());
1016 weights.reserve(val->bag.size());
1017 for (
auto [val, weight] : val->bag) {
1018 auto materializedVal = materialize(val, loc, emitError);
1019 auto materializedWeight = materialize(weight, loc, emitError);
1020 if (!materializedVal || !materializedWeight)
1023 values.push_back(materializedVal);
1024 weights.push_back(materializedWeight);
1027 auto res = BagCreateOp::create(builder, loc, val->type, values, weights);
1028 materializedValues[val] = res;
1032 Value visit(MemoryBlockStorage *val, Location loc,
1033 function_ref<InFlightDiagnostic()> emitError) {
1034 auto intType = builder.getIntegerType(val->baseAddress.getBitWidth());
1035 Value res = MemoryBlockDeclareOp::create(
1036 builder, loc, val->type, IntegerAttr::get(intType, val->baseAddress),
1037 IntegerAttr::get(intType, val->endAddress));
1038 materializedValues[val] = res;
1042 Value visit(MemoryStorage *val, Location loc,
1043 function_ref<InFlightDiagnostic()> emitError) {
1044 auto memBlock = materialize(val->memoryBlock, loc, emitError);
1045 auto memSize = materialize(val->size, loc, emitError);
1046 auto memAlign = materialize(val->alignment, loc, emitError);
1047 if (!(memBlock && memSize && memAlign))
1051 MemoryAllocOp::create(builder, loc, memBlock, memSize, memAlign);
1052 materializedValues[val] = res;
1056 Value visit(SequenceStorage *val, Location loc,
1057 function_ref<InFlightDiagnostic()> emitError) {
1058 emitError() <<
"materializing a non-randomized sequence not supported yet";
1062 Value visit(RandomizedSequenceStorage *val, Location loc,
1063 function_ref<InFlightDiagnostic()> emitError) {
1069 SmallVector<ElaboratorValue> elabArgs;
1070 SequenceOp seqOp = elaborateSequence(val, elabArgs);
1076 SmallVector<Value> args;
1077 SmallVector<Type> argTypes;
1078 for (
auto arg : elabArgs) {
1079 Value materialized = materialize(arg, loc, emitError);
1083 args.push_back(materialized);
1084 argTypes.push_back(materialized.getType());
1087 Value res = GetSequenceOp::create(
1088 builder, loc, SequenceType::get(builder.getContext(), argTypes),
1089 seqOp.getSymName());
1094 res = SubstituteSequenceOp::create(builder, loc, res, args);
1096 res = RandomizeSequenceOp::create(builder, loc, res);
1098 materializedValues[val] = res;
1102 Value visit(InterleavedSequenceStorage *val, Location loc,
1103 function_ref<InFlightDiagnostic()> emitError) {
1104 SmallVector<Value> sequences;
1105 for (
auto seqVal : val->sequences) {
1106 Value materialized = materialize(seqVal, loc, emitError);
1110 sequences.push_back(materialized);
1113 if (sequences.size() == 1)
1114 return sequences[0];
1117 InterleaveSequencesOp::create(builder, loc, sequences, val->batchSize);
1118 materializedValues[val] = res;
1122 Value visit(VirtualRegisterStorage *val, Location loc,
1123 function_ref<InFlightDiagnostic()> emitError) {
1124 Value res = VirtualRegisterOp::create(builder, loc, val->allowedRegs);
1125 materializedValues[val] = res;
1129 Value visit(UniqueLabelStorage *val, Location loc,
1130 function_ref<InFlightDiagnostic()> emitError) {
1132 LabelUniqueDeclOp::create(builder, loc, val->name, ValueRange());
1133 materializedValues[val] = res;
1137 Value visit(
const LabelValue &val, Location loc,
1138 function_ref<InFlightDiagnostic()> emitError) {
1139 Value res = LabelDeclOp::create(builder, loc, val.name, ValueRange());
1140 materializedValues[val] = res;
1144 Value visit(TupleStorage *val, Location loc,
1145 function_ref<InFlightDiagnostic()> emitError) {
1146 SmallVector<Value> materialized;
1147 materialized.reserve(val->values.size());
1148 for (
auto v : val->values)
1149 materialized.push_back(materialize(v, loc, emitError));
1150 Value res = TupleCreateOp::create(builder, loc, materialized);
1151 materializedValues[val] = res;
1155 Value visit(ValidationValue *val, Location loc,
1156 function_ref<InFlightDiagnostic()> emitError) {
1157 SmallVector<Value> usedDefaultValues, elseValues;
1158 for (
auto [dfltVal, elseVal] :
1159 llvm::zip(val->defaultUsedValues, val->elseValues)) {
1160 auto dfltMat = materialize(dfltVal, loc, emitError);
1161 auto elseMat = materialize(elseVal, loc, emitError);
1162 if (!dfltMat || !elseMat)
1165 usedDefaultValues.push_back(dfltMat);
1166 usedDefaultValues.push_back(elseMat);
1169 auto validateOp = ValidateOp::create(
1170 builder, loc, val->type, materialize(val->ref, loc, emitError),
1171 materialize(val->defaultValue, loc, emitError), val->id,
1172 usedDefaultValues, elseValues);
1173 materializedValues[val] = validateOp.getValue();
1174 return validateOp.getValue();
1177 Value visit(ValidationMuxedValue *val, Location loc,
1178 function_ref<InFlightDiagnostic()> emitError) {
1179 Value validateValue =
1180 materialize(
const_cast<ValidationValue *
>(val->value), loc, emitError);
1183 auto validateOp = validateValue.getDefiningOp<ValidateOp>();
1185 auto *defOp = validateValue.getDefiningOp();
1187 <<
"expected validate op for validation muxed value, but found "
1188 << (defOp ? defOp->getName().getStringRef() :
"block argument");
1191 materializedValues[val] = validateOp.getValues()[val->idx];
1192 return validateOp.getValues()[val->idx];
1195 Value visit(ImmediateConcatStorage *val, Location loc,
1196 function_ref<InFlightDiagnostic()> emitError) {
1197 SmallVector<Value> operands;
1198 for (
auto operand : val->operands) {
1199 auto materialized = materialize(operand, loc, emitError);
1203 operands.push_back(materialized);
1206 Value res = ConcatImmediateOp::create(builder, loc, operands);
1207 materializedValues[val] = res;
1211 Value visit(ImmediateSliceStorage *val, Location loc,
1212 function_ref<InFlightDiagnostic()> emitError) {
1213 Value input = materialize(val->input, loc, emitError);
1218 SliceImmediateOp::create(builder, loc, val->type, input, val->lowBit);
1219 materializedValues[val] = res;
1229 DenseMap<ElaboratorValue, Value> materializedValues;
1235 SmallVector<Operation *> toDelete;
1237 TestState &testState;
1238 SharedState &sharedState;
1243 SmallVector<ElaboratorValue> &blockArgs;
1244 SmallVector<Type> blockArgTypes;
1249 DenseSet<IdentityValue *> identityValueRoot;
1258enum class DeletionKind { Keep, Delete };
1261class Elaborator :
public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
1264 using RTGBase::visitOp;
1266 Elaborator(SharedState &sharedState, TestState &testState,
1267 Materializer &materializer,
1268 ContextResourceAttrInterface currentContext = {})
1269 : sharedState(sharedState), testState(testState),
1270 materializer(materializer), currentContext(currentContext) {}
1272 template <
typename ValueTy>
1273 inline ValueTy
get(Value val)
const {
1274 return std::get<ValueTy>(state.at(val));
1277 FailureOr<DeletionKind> visitPureOp(Operation *op) {
1278 SmallVector<Attribute> operands;
1279 for (
auto operand : op->getOperands()) {
1280 auto evalValue = state[operand];
1281 if (std::holds_alternative<TypedAttr>(evalValue))
1282 operands.push_back(std::get<TypedAttr>(evalValue));
1287 SmallVector<OpFoldResult> results;
1288 if (failed(op->fold(operands, results)))
1292 if (results.size() != op->getNumResults())
1295 for (
auto [res, val] :
llvm::zip(results, op->getResults())) {
1296 auto attr = llvm::dyn_cast_or_null<TypedAttr>(res.dyn_cast<Attribute>());
1298 return op->emitError(
1299 "only typed attributes supported for constant-like operations");
1301 auto intAttr = dyn_cast<IntegerAttr>(attr);
1302 if (intAttr && isa<IndexType>(attr.getType()))
1303 state[op->getResult(0)] = size_t(intAttr.getInt());
1304 else if (intAttr && intAttr.getType().isSignlessInteger(1))
1305 state[op->getResult(0)] = bool(intAttr.getInt());
1307 state[op->getResult(0)] = attr;
1310 return DeletionKind::Delete;
1315 return op->emitOpError(
"elaboration not supported");
1319 auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
1320 if (op->hasTrait<OpTrait::ConstantLike>() || (memOp && memOp.hasNoEffect()))
1321 return visitPureOp(op);
1326 if (op->use_empty())
1327 return DeletionKind::Keep;
1332 FailureOr<DeletionKind> visitOp(ConstantOp op) {
return visitPureOp(op); }
1334 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
1335 SmallVector<ElaboratorValue> replacements;
1336 state[op.getResult()] =
1337 sharedState.internalizer.internalize<SequenceStorage>(
1338 op.getSequenceAttr().getAttr(), std::move(replacements));
1339 return DeletionKind::Delete;
1342 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
1343 auto *
seq = get<SequenceStorage *>(op.getSequence());
1345 SmallVector<ElaboratorValue> replacements(
seq->args);
1346 for (
auto replacement : op.getReplacements())
1347 replacements.push_back(state.at(replacement));
1349 state[op.getResult()] =
1350 sharedState.internalizer.internalize<SequenceStorage>(
1351 seq->familyName, std::move(replacements));
1353 return DeletionKind::Delete;
1356 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
1357 auto *
seq = get<SequenceStorage *>(op.getSequence());
1358 auto *randomizedSeq =
1359 sharedState.internalizer.create<RandomizedSequenceStorage>(
1360 currentContext,
seq);
1361 materializer.registerIdentityValue(randomizedSeq);
1362 state[op.getResult()] =
1363 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1365 return DeletionKind::Delete;
1368 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
1369 SmallVector<ElaboratorValue> sequences;
1370 for (
auto seq : op.getSequences())
1371 sequences.push_back(
get<InterleavedSequenceStorage *>(
seq));
1373 state[op.getResult()] =
1374 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1375 std::move(sequences), op.getBatchSize());
1376 return DeletionKind::Delete;
1380 LogicalResult isValidContext(ElaboratorValue value, Operation *op)
const {
1381 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
1382 auto *
seq = std::get<RandomizedSequenceStorage *>(value);
1383 if (
seq->context != currentContext) {
1384 auto err = op->emitError(
"attempting to place sequence derived from ")
1385 <<
seq->sequence->familyName.getValue() <<
" under context "
1387 <<
", but it was previously randomized for context ";
1389 err <<
seq->context;
1397 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
1398 for (
auto val : interVal->sequences)
1399 if (failed(isValidContext(val, op)))
1404 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
1405 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
1406 if (failed(isValidContext(seqVal, op)))
1409 return DeletionKind::Keep;
1412 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1413 SetVector<ElaboratorValue> set;
1414 for (
auto val : op.getElements())
1415 set.insert(state.at(val));
1417 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1418 std::move(set), op.getSet().getType());
1419 return DeletionKind::Delete;
// Replaces the op's result with an element of the operand set chosen at
// random. If the op carries an integer "rtg.elaboration_custom_seed" attribute,
// a local std::mt19937 seeded from it is used for the draw.
// NOTE(review): lines 1424-1425, 1427-1429 and 1432-1436 are missing here. They
// hold the empty-set guard, the start of the custom-seed `if`, and the
// computation of `selected`. Restore them before building.
1422 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
// NOTE(review): `auto` copies the whole SetVector; `const auto &` suffices.
1423 auto set = get<SetStorage *>(op.getSet())->set;
1426 return op->emitError(
"cannot select from an empty set");
1430 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1431 std::mt19937 customRng(intAttr.getInt());
1437 state[op.getResult()] = set[selected];
1438 return DeletionKind::Delete;
1441 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
1442 auto original = get<SetStorage *>(op.getOriginal())->set;
1443 auto diff = get<SetStorage *>(op.getDiff())->set;
1445 SetVector<ElaboratorValue> result(original);
1446 result.set_subtract(diff);
1448 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1449 std::move(result), op.getResult().getType());
1450 return DeletionKind::Delete;
1453 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
1454 SetVector<ElaboratorValue> result;
1455 for (
auto set : op.getSets())
1456 result.set_union(
get<SetStorage *>(set)->set);
1458 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1459 std::move(result), op.getType());
1460 return DeletionKind::Delete;
1463 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1464 auto size = get<SetStorage *>(op.getSet())->set.size();
1465 state[op.getResult()] = size;
1466 return DeletionKind::Delete;
1472 FailureOr<DeletionKind> visitOp(SetCartesianProductOp op) {
1473 SetVector<ElaboratorValue> result;
1474 SmallVector<SmallVector<ElaboratorValue>> tuples;
1475 tuples.push_back({});
1477 for (
auto input : op.getInputs()) {
1478 auto &set = get<SetStorage *>(input)->set;
1480 SetVector<ElaboratorValue>
empty;
1481 state[op.getResult()] =
1482 sharedState.internalizer.internalize<SetStorage>(std::move(
empty),
1484 return DeletionKind::Delete;
1487 for (
unsigned i = 0, e = tuples.size(); i < e; ++i) {
1488 for (
auto setEl : set.getArrayRef().drop_back()) {
1489 tuples.push_back(tuples[i]);
1490 tuples.back().push_back(setEl);
1492 tuples[i].push_back(set.back());
1496 for (
auto &tup : tuples)
1498 sharedState.internalizer.internalize<TupleStorage>(std::move(tup)));
1500 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1501 std::move(result), op.getType());
1502 return DeletionKind::Delete;
1505 FailureOr<DeletionKind> visitOp(SetConvertToBagOp op) {
1506 auto set = get<SetStorage *>(op.getInput())->set;
1507 MapVector<ElaboratorValue, uint64_t> bag;
1508 for (
auto val : set)
1509 bag.insert({val, 1});
1510 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1511 std::move(bag), op.getType());
1512 return DeletionKind::Delete;
1515 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1516 MapVector<ElaboratorValue, uint64_t> bag;
1517 for (
auto [val, multiple] :
1518 llvm::zip(op.getElements(), op.getMultiples())) {
1522 bag[state.at(val)] += get<size_t>(multiple);
1525 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1526 std::move(bag), op.getType());
1527 return DeletionKind::Delete;
// Selects a bag element at random, weighted by its multiplicity. Prefix sums
// of the multiplicities are built, and `llvm::upper_bound` finds the first
// entry whose cumulative weight exceeds the drawn number. The draw is
// presumably on missing line 1550. An integer "rtg.elaboration_custom_seed"
// attribute replaces the generator with one seeded from it.
// NOTE(review): lines 1532-1533, 1535, 1542-1543, 1545, 1548-1550, 1552 and
// 1555-1556 are missing here (empty-bag guard, `if (auto intAttr =`, the draw,
// upper_bound arguments). Restore them before building.
1530 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
// NOTE(review): `auto` copies the whole MapVector; `const auto &` suffices.
1531 auto bag = get<BagStorage *>(op.getBag())->bag;
1534 return op->emitError(
"cannot select from an empty bag");
// NOTE(review): multiplicities are uint64_t but the prefix sums are uint32_t,
// so total weights above UINT32_MAX wrap and skew the selection.
1536 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1537 prefixSum.reserve(bag.size());
1538 uint32_t accumulator = 0;
1539 for (
auto [val, weight] : bag) {
1540 accumulator += weight;
1541 prefixSum.push_back({val, accumulator});
// NOTE(review): this copies the shared generator by value (it is a std::mt19937,
// see the assignment below). Draws from `customRng` therefore never advance
// `sharedState.rng`, so unseeded bag selections may repeat. Confirm intent.
1544 auto customRng = sharedState.rng;
1546 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1547 customRng = std::mt19937(intAttr.getInt());
1551 auto *iter = llvm::upper_bound(
1553 [](uint32_t a,
const std::pair<ElaboratorValue, uint32_t> &b) {
1554 return a < b.second;
1557 state[op.getResult()] = iter->first;
1558 return DeletionKind::Delete;
// Subtracts the multiplicities of `diff` from those of `original`. Elements
// absent from `diff` are (presumably) carried over unchanged. Elements whose
// multiplicity is not larger than the subtrahend (`el.second <= toDiff`) are
// presumably dropped. The rest keep `el.second - toDiff`.
// NOTE(review): lines 1568-1574 and 1577-1578 are missing here. As listed,
// `if (el.second <= toDiff)` would guard the insert, inverting the logic, and
// the first `if` has no body. Restore them before building.
1561 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
// NOTE(review): both `auto` bindings copy a MapVector; `const auto &` suffices.
1562 auto original = get<BagStorage *>(op.getOriginal())->bag;
1563 auto diff = get<BagStorage *>(op.getDiff())->bag;
1565 MapVector<ElaboratorValue, uint64_t> result;
1566 for (
const auto &el : original) {
1567 if (!diff.contains(el.first)) {
1575 auto toDiff = diff.lookup(el.first);
1576 if (el.second <= toDiff)
1579 result.insert({el.first, el.second - toDiff});
1582 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1583 std::move(result), op.getType());
1584 return DeletionKind::Delete;
1587 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1588 MapVector<ElaboratorValue, uint64_t> result;
1589 for (
auto bag : op.getBags()) {
1590 auto val = get<BagStorage *>(bag)->bag;
1591 for (
auto [el, multiple] : val)
1592 result[el] += multiple;
1595 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1596 std::move(result), op.getType());
1597 return DeletionKind::Delete;
1600 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1601 auto size = get<BagStorage *>(op.getBag())->bag.size();
1602 state[op.getResult()] = size;
1603 return DeletionKind::Delete;
1606 FailureOr<DeletionKind> visitOp(BagConvertToSetOp op) {
1607 auto bag = get<BagStorage *>(op.getInput())->bag;
1608 SetVector<ElaboratorValue> set;
1609 for (
auto [k, v] : bag)
1611 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1612 std::move(set), op.getType());
1613 return DeletionKind::Delete;
1616 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1617 auto *val = sharedState.internalizer.create<VirtualRegisterStorage>(
1618 op.getAllowedRegsAttr(), op.getType());
1619 state[op.getResult()] = val;
1620 materializer.registerIdentityValue(val);
1621 return DeletionKind::Delete;
1625 ValueRange substitutes)
const {
// Presumably the member `substituteFormatString(StringAttr formatString,
// ValueRange substitutes)`. NOTE(review): the first line of the signature
// (1623-1624) is missing here. Restore it before building.
// Replaces every "{{i}}" placeholder in `formatString` with the decimal text of
// the i-th substitute, which must have elaborated to an index (size_t) value.
// Returns `formatString` unchanged when there are no substitutes or it is empty.
1626 if (substitutes.empty() || formatString.empty())
1627 return formatString;
1629 auto original = formatString.getValue().str();
1630 for (
auto [i, subst] :
llvm::enumerate(substitutes)) {
1631 size_t startPos = 0;
1632 std::string from =
"{{" + std::to_string(i) +
"}}";
// `startPos` is not advanced past the inserted text. The next search rescans
// the inserted digits, which cannot contain "{{", so the loop still terminates.
1633 while ((startPos = original.find(from, startPos)) != std::string::npos) {
1634 auto substString = std::to_string(get<size_t>(subst));
1635 original.replace(startPos, from.length(), substString);
1639 return StringAttr::get(formatString.getContext(), original);
1642 FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
1643 SmallVector<ElaboratorValue> array;
1644 array.reserve(op.getElements().size());
1645 for (
auto val : op.getElements())
1646 array.emplace_back(state.at(val));
1648 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1649 op.getResult().getType(), std::move(array));
1650 return DeletionKind::Delete;
1653 FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
1654 auto array = get<ArrayStorage *>(op.getArray())->array;
1655 size_t idx = get<size_t>(op.getIndex());
1657 if (array.size() <= idx)
1658 return op->emitError(
"invalid to access index ")
1659 << idx <<
" of an array with " << array.size() <<
" elements";
1661 state[op.getResult()] = array[idx];
1662 return DeletionKind::Delete;
1665 FailureOr<DeletionKind> visitOp(ArrayInjectOp op) {
1666 auto array = get<ArrayStorage *>(op.getArray())->array;
1667 size_t idx = get<size_t>(op.getIndex());
1669 if (array.size() <= idx)
1670 return op->emitError(
"invalid to access index ")
1671 << idx <<
" of an array with " << array.size() <<
" elements";
1673 array[idx] = state[op.getValue()];
1674 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1675 op.getResult().getType(), std::move(array));
1676 return DeletionKind::Delete;
1679 FailureOr<DeletionKind> visitOp(ArraySizeOp op) {
1680 auto array = get<ArrayStorage *>(op.getArray())->array;
1681 state[op.getResult()] = array.size();
1682 return DeletionKind::Delete;
// Elaborates a label declaration to a LabelValue holding `substituted`,
// presumably the label's format string after placeholder substitution.
// NOTE(review): `substituted` is defined on lines 1686-1687, which are missing
// here (presumably a substituteFormatString call on the op's format string and
// arguments). Restore them before building.
1685 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
1688 state[op.getLabel()] = LabelValue(substituted);
1689 return DeletionKind::Delete;
// Creates (not internalizes) a UniqueLabelStorage, so every unique-label
// declaration is a distinct identity value known to the materializer.
// NOTE(review): the constructor arguments (line 1694) are missing here.
// Restore them before building.
1692 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1693 auto *val = sharedState.internalizer.create<UniqueLabelStorage>(
1695 state[op.getLabel()] = val;
1696 materializer.registerIdentityValue(val);
1697 return DeletionKind::Delete;
1700 FailureOr<DeletionKind> visitOp(LabelOp op) {
return DeletionKind::Keep; }
1702 FailureOr<DeletionKind> visitOp(TestSuccessOp op) {
1703 return DeletionKind::Keep;
1706 FailureOr<DeletionKind> visitOp(TestFailureOp op) {
1707 return DeletionKind::Keep;
// Replaces the op with a random index drawn from the range given by the
// elaborated lower/upper bound operands. An integer
// "rtg.elaboration_custom_seed" attribute selects a locally seeded
// std::mt19937 instead of the shared generator.
// NOTE(review): lines 1713, 1715-1716, 1720-1721 and 1723-1725 are missing here
// (empty-range condition, custom-seed `if`, the draws). Restore them before
// building.
// NOTE(review): the bounds are size_t, while the file's getUniformlyInRange
// helper takes uint32_t. If that helper is used here, bounds above UINT32_MAX
// are silently truncated. Confirm.
1710 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1711 size_t lower = get<size_t>(op.getLowerBound());
1712 size_t upper = get<size_t>(op.getUpperBound());
1714 return op->emitError(
"cannot select a number from an empty range");
1717 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1718 std::mt19937 customRng(intAttr.getInt());
1719 state[op.getResult()] =
1722 state[op.getResult()] =
1726 return DeletionKind::Delete;
// Converts an elaborated index value to an ImmediateAttr of the result type's
// bit width. It is an error if the value does not fit in that many bits.
// NOTE(review): line 1735 (the end of the diagnostic) and 1736 are missing
// here. Restore them before building.
1729 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1730 size_t input = get<size_t>(op.getInput());
1731 auto width = op.getType().getWidth();
1732 auto emitError = [&]() {
return op->emitError(); };
// NOTE(review): APInt::getZExtValue asserts when the value needs more than 64
// bits, so this check aborts for immediates wider than 64 bits, where every
// size_t fits anyway. A shift-based check such as
// `width < 64 && (input >> width) != 0` avoids that.
1733 if (input > APInt::getAllOnes(width).getZExtValue())
1734 return emitError() <<
"cannot represent " << input <<
" with " << width
1737 state[op.getResult()] =
1738 ImmediateAttr::get(op.getContext(), APInt(width, input));
1739 return DeletionKind::Delete;
// Elaborates "run this sequence under context `to`". The source context is the
// current one, or DefaultContextAttr when none is set. In the visible early
// path (its guard is on missing lines 1753-1756), the sequence is materialized,
// randomized and embedded directly. Otherwise a registered context-switch
// sequence is looked up, in this order:
// (from, to), (from, any), (any, to), (any, any).
// The switch family is instantiated with (from, to, sequence) arguments,
// randomized for `to`, and embedded.
// NOTE(review): lines 1748, 1753-1756, 1759-1762, 1766-1770, 1772-1773,
// 1777-1778, 1782-1783, 1788-1792, 1796, 1802 and 1806-1808 are missing here,
// including the definitions of `randSeqVal` and `randSeq` and the lambda's end.
// Restore them before building.
1742 FailureOr<DeletionKind> visitOp(OnContextOp op) {
1743 ContextResourceAttrInterface from = currentContext,
1744 to = cast<ContextResourceAttrInterface>(
1745 get<TypedAttr>(op.getContext()));
1746 if (!currentContext)
1747 from = DefaultContextAttr::get(op->getContext(), to.getType());
// Diagnostic callback for materialization failures; adds a note pointing at
// this op.
1749 auto emitError = [&]() {
1750 auto diag = op.emitError();
1751 diag.attachNote(op.getLoc())
1752 <<
"while materializing value for context switching for " << op;
1757 Value seqVal = materializer.materialize(
1758 get<SequenceStorage *>(op.getSequence()), op.getLoc(), emitError);
1763 materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1764 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1765 return DeletionKind::Delete;
// Most specific registered transition wins; fall back to wildcard contexts.
1771 auto *iter = testState.contextSwitches.find({from, to});
1774 if (iter == testState.contextSwitches.end())
1775 iter = testState.contextSwitches.find(
1776 {from, AnyContextAttr::get(op->getContext(), to.getType())});
1779 if (iter == testState.contextSwitches.end())
1780 iter = testState.contextSwitches.find(
1781 {AnyContextAttr::get(op->getContext(), from.getType()), to});
1784 if (iter == testState.contextSwitches.end())
1785 iter = testState.contextSwitches.find(
1786 {AnyContextAttr::get(op->getContext(), from.getType()),
1787 AnyContextAttr::get(op->getContext(), to.getType())});
1793 if (iter == testState.contextSwitches.end())
1794 return op->emitError(
"no context transition registered to switch from ")
1795 << from <<
" to " << to;
// Instantiate the switch family with (from, to, user sequence) and randomize
// it for the target context.
1797 auto familyName = iter->second->familyName;
1798 SmallVector<ElaboratorValue> args{from, to,
1799 get<SequenceStorage *>(op.getSequence())};
1800 auto *
seq = sharedState.internalizer.internalize<SequenceStorage>(
1801 familyName, std::move(args));
1803 sharedState.internalizer.create<RandomizedSequenceStorage>(to,
seq);
1804 materializer.registerIdentityValue(randSeq);
1805 Value seqVal = materializer.materialize(randSeq, op.getLoc(), emitError);
1809 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1810 return DeletionKind::Delete;
1813 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1814 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1815 get<SequenceStorage *>(op.getSequence());
1816 return DeletionKind::Delete;
1819 FailureOr<DeletionKind> visitOp(MemoryBlockDeclareOp op) {
1820 auto *val = sharedState.internalizer.create<MemoryBlockStorage>(
1821 op.getBaseAddress(), op.getEndAddress(), op.getType());
1822 state[op.getResult()] = val;
1823 materializer.registerIdentityValue(val);
1824 return DeletionKind::Delete;
1827 FailureOr<DeletionKind> visitOp(MemoryAllocOp op) {
1828 size_t size = get<size_t>(op.getSize());
1829 size_t alignment = get<size_t>(op.getAlignment());
1830 auto *memBlock = get<MemoryBlockStorage *>(op.getMemoryBlock());
1831 auto *val = sharedState.internalizer.create<MemoryStorage>(memBlock, size,
1833 state[op.getResult()] = val;
1834 materializer.registerIdentityValue(val);
1835 return DeletionKind::Delete;
1838 FailureOr<DeletionKind> visitOp(MemorySizeOp op) {
1839 auto *memory = get<MemoryStorage *>(op.getMemory());
1840 state[op.getResult()] = memory->size;
1841 return DeletionKind::Delete;
1844 FailureOr<DeletionKind> visitOp(TupleCreateOp op) {
1845 SmallVector<ElaboratorValue> values;
1846 values.reserve(op.getElements().size());
1847 for (
auto el : op.getElements())
1848 values.push_back(state[el]);
1850 state[op.getResult()] =
1851 sharedState.internalizer.internalize<TupleStorage>(std::move(values));
1852 return DeletionKind::Delete;
1855 FailureOr<DeletionKind> visitOp(TupleExtractOp op) {
1856 auto *tuple = get<TupleStorage *>(op.getTuple());
1857 state[op.getResult()] = tuple->values[op.getIndex().getZExtValue()];
1858 return DeletionKind::Delete;
1861 FailureOr<DeletionKind> visitOp(CommentOp op) {
return DeletionKind::Keep; }
1863 FailureOr<DeletionKind> visitOp(rtg::YieldOp op) {
1864 return DeletionKind::Keep;
1867 FailureOr<DeletionKind> visitOp(ValidateOp op) {
1868 SmallVector<ElaboratorValue> defaultUsedValues, elseValues;
1869 for (
auto v : op.getDefaultUsedValues())
1870 defaultUsedValues.push_back(state.at(v));
1872 for (
auto v : op.getElseValues())
1873 elseValues.push_back(state.at(v));
1875 auto *validationVal = sharedState.internalizer.create<ValidationValue>(
1876 op.getValue().getType(), state[op.getRef()],
1877 state[op.getDefaultValue()], op.getIdAttr(),
1878 std::move(defaultUsedValues), std::move(elseValues));
1879 state[op.getValue()] = validationVal;
1880 materializer.registerIdentityValue(validationVal);
1881 materializer.map(validationVal, op.getValue());
1883 for (
auto [i, val] :
llvm::enumerate(op.getValues())) {
1884 auto *muxVal = sharedState.internalizer.create<ValidationMuxedValue>(
1885 val.getType(), validationVal, i);
1886 state[val] = muxVal;
1887 materializer.registerIdentityValue(muxVal);
1888 materializer.map(muxVal, val);
1891 return DeletionKind::Keep;
// Elaborates only the region selected by the elaborated boolean condition. An
// empty selected region just deletes the op. Otherwise the region is elaborated
// inline, and presumably its yielded values become the op's results.
// NOTE(review): lines 1899-1903, 1906-1908 and 1910-1911 are missing here. As
// listed, `if (failed(...))` has no body (the failure return is lost) and the
// result-mapping loop has no body. Restore them before building.
1894 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1895 bool cond = get<bool>(op.getCondition());
1896 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1897 if (toElaborate.empty())
1898 return DeletionKind::Delete;
1904 SmallVector<ElaboratorValue> yieldedVals;
1905 if (failed(elaborate(toElaborate, {}, yieldedVals)))
1909 for (
auto [res, out] :
llvm::zip(op.getResults(), yieldedVals))
1912 return DeletionKind::Delete;
// Fully unrolls an scf.for with index-valued (size_t) bounds and step. The body
// is elaborated once per iteration. The induction variable and iter_args are
// updated in `state` between iterations, and the final iter_args become the
// op's results.
// NOTE(review): lines 1920, 1924-1928, 1933-1934, 1939-1942 and 1947-1949 are
// missing here. As listed, the `if (failed(elaborate(...)))` has no body.
// Restore them before building.
1915 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1916 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1917 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1918 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1919 return op->emitOpError(
"can only elaborate index type iterator");
// NOTE(review): unless the missing lines 1924-1928 reject it, step == 0 makes
// the unrolling loop below run forever.
1921 auto lowerBound = get<size_t>(op.getLowerBound());
1922 auto step = get<size_t>(op.getStep());
1923 auto upperBound = get<size_t>(op.getUpperBound());
1929 state[op.getInductionVar()] = lowerBound;
// NOTE(review): in `state[a] = state.at(b)` the right operand is evaluated
// first (C++17). If `state[a]` then inserts and rehashes the DenseMap, the
// reference returned by `at` dangles. Copy into a local first. The same
// applies to line 1952.
1930 for (
auto [iterArg, initArg] :
1931 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1932 state[iterArg] = state.at(initArg);
1935 SmallVector<ElaboratorValue> yieldedVals;
1936 for (
size_t i = lowerBound; i < upperBound; i += step) {
1937 yieldedVals.clear();
1938 if (failed(elaborate(op.getBodyRegion(), {}, yieldedVals)))
// Prepare the next iteration: advance the induction variable and feed the
// yielded values back into the iter_args.
1943 state[op.getInductionVar()] = i + step;
1944 for (
auto [iterArg, prevIterArg] :
1945 llvm::zip(op.getRegionIterArgs(), yieldedVals))
1946 state[iterArg] = prevIterArg;
1950 for (
auto [res, iterArg] :
1951 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1952 state[res] = state.at(iterArg);
1954 return DeletionKind::Delete;
1957 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1958 return DeletionKind::Delete;
1961 FailureOr<DeletionKind> visitOp(arith::AddIOp op) {
1962 if (!isa<IndexType>(op.getType()))
1963 return op->emitError(
"only index operands supported");
1965 size_t lhs = get<size_t>(op.getLhs());
1966 size_t rhs = get<size_t>(op.getRhs());
1967 state[op.getResult()] = lhs + rhs;
1968 return DeletionKind::Delete;
1971 FailureOr<DeletionKind> visitOp(arith::AndIOp op) {
1972 if (!op.getType().isSignlessInteger(1))
1973 return op->emitError(
"only 'i1' operands supported");
1975 bool lhs = get<bool>(op.getLhs());
1976 bool rhs = get<bool>(op.getRhs());
1977 state[op.getResult()] = lhs && rhs;
1978 return DeletionKind::Delete;
1981 FailureOr<DeletionKind> visitOp(arith::XOrIOp op) {
1982 if (!op.getType().isSignlessInteger(1))
1983 return op->emitError(
"only 'i1' operands supported");
1985 bool lhs = get<bool>(op.getLhs());
1986 bool rhs = get<bool>(op.getRhs());
1987 state[op.getResult()] = lhs != rhs;
1988 return DeletionKind::Delete;
1991 FailureOr<DeletionKind> visitOp(arith::OrIOp op) {
1992 if (!op.getType().isSignlessInteger(1))
1993 return op->emitError(
"only 'i1' operands supported");
1995 bool lhs = get<bool>(op.getLhs());
1996 bool rhs = get<bool>(op.getRhs());
1997 state[op.getResult()] = lhs || rhs;
1998 return DeletionKind::Delete;
2001 FailureOr<DeletionKind> visitOp(arith::SelectOp op) {
2002 bool cond = get<bool>(op.getCondition());
2003 auto trueVal = state[op.getTrueValue()];
2004 auto falseVal = state[op.getFalseValue()];
2005 state[op.getResult()] = cond ? trueVal : falseVal;
2006 return DeletionKind::Delete;
2009 FailureOr<DeletionKind> visitOp(index::AddOp op) {
2010 size_t lhs = get<size_t>(op.getLhs());
2011 size_t rhs = get<size_t>(op.getRhs());
2012 state[op.getResult()] = lhs + rhs;
2013 return DeletionKind::Delete;
// Folds index.cmp on elaborated size_t operands. Cases for EQ, NE and the
// unsigned predicates are visible. Other predicates (presumably the signed
// ones) reach the "elaboration not supported" error.
// NOTE(review): lines 2019 (declaration of `result`), 2023/2026/2032/2035/2038
// (breaks), 2028-2029 and 2034 (ULT/UGT bodies) and 2039/2041 are missing here.
// As listed, the cases fall through. Restore them before building.
2016 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
2017 size_t lhs = get<size_t>(op.getLhs());
2018 size_t rhs = get<size_t>(op.getRhs());
2020 switch (op.getPred()) {
2021 case index::IndexCmpPredicate::EQ:
2022 result = lhs == rhs;
2024 case index::IndexCmpPredicate::NE:
2025 result = lhs != rhs;
2027 case index::IndexCmpPredicate::ULT:
2030 case index::IndexCmpPredicate::ULE:
2031 result = lhs <= rhs;
2033 case index::IndexCmpPredicate::UGT:
2036 case index::IndexCmpPredicate::UGE:
2037 result = lhs >= rhs;
2040 return op->emitOpError(
"elaboration not supported");
2042 state[op.getResult()] = result;
2043 return DeletionKind::Delete;
// Concatenates immediates. If any operand is a ValidationValue (not yet a
// concrete immediate), the concatenation is deferred as an
// ImmediateConcatStorage over the operand values. Otherwise the operands'
// APInts are folded into one ImmediateAttr.
// NOTE(review): the visible predicate tests only for ValidationValue *. Operands
// that are ValidationMuxedValue * or deferred concat/slice storages would take
// the constant path, where get<TypedAttr> fails, unless missing lines
// 2050-2053 extend it. Confirm.
// NOTE(review): lines 2050-2053, 2062-2063, 2066 (presumably the `concat`
// call) and 2068 are missing here. Restore them before building.
2046 FailureOr<DeletionKind> visitOp(ConcatImmediateOp op) {
2047 bool anyValidationValues =
2048 llvm::any_of(op.getOperands(), [&](
auto operand) {
2049 return std::holds_alternative<ValidationValue *>(state[operand]);
2054 if (anyValidationValues) {
2055 SmallVector<ElaboratorValue> operands;
2056 for (
auto operand : op.getOperands())
2057 operands.push_back(state[operand]);
2058 state[op.getResult()] =
2059 sharedState.internalizer.internalize<ImmediateConcatStorage>(
2060 std::move(operands));
2061 return DeletionKind::Delete;
// All operands are concrete immediates: fold them into one APInt.
2064 auto result = APInt::getZeroWidth();
2065 for (
auto operand : op.getOperands())
2067 cast<ImmediateAttr>(
get<TypedAttr>(operand)).getValue());
2069 state[op.getResult()] = ImmediateAttr::get(op.getContext(), result);
2070 return DeletionKind::Delete;
// Extracts a bit slice of an immediate, starting at the op's low bit and as
// wide as the result type. A ValidationValue input is deferred as an
// ImmediateSliceStorage. A concrete input is sliced with APInt::extractBits.
// NOTE(review): only ValidationValue * inputs are deferred (see the same
// concern on ConcatImmediateOp). Lines 2074-2075, 2081-2083 (start of the
// `inputValue` declaration) and 2086 (the low-bit argument) are missing here.
// Restore them before building.
2073 FailureOr<DeletionKind> visitOp(SliceImmediateOp op) {
2076 if (std::holds_alternative<ValidationValue *>(state[op.getInput()])) {
2077 state[op.getResult()] =
2078 sharedState.internalizer.internalize<ImmediateSliceStorage>(
2079 state[op.getInput()], op.getLowBit(), op.getResult().getType());
2080 return DeletionKind::Delete;
2084 cast<ImmediateAttr>(get<TypedAttr>(op.getInput())).getValue();
2085 auto sliced = inputValue.extractBits(op.getResult().getType().getWidth(),
2087 state[op.getResult()] = ImmediateAttr::get(op.getContext(), sliced);
2088 return DeletionKind::Delete;
// Routes the non-RTG ops this elaborator understands (arith, index and scf ops
// among them) to the matching visitOp overload. Every other op goes to
// RTGBase::dispatchOpVisitor.
// NOTE(review): parts of the type list (lines 2093-2094, 2096-2097, 2099) and
// the closing brace are missing here. Restore them before building.
2091 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
2092 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
2095 arith::AddIOp, arith::XOrIOp, arith::AndIOp, arith::OrIOp,
2098 index::AddOp, index::CmpOp,
2100 scf::IfOp, scf::ForOp, scf::YieldOp>(
2101 [&](
auto op) {
return visitOp(op); })
2102 .Default([&](Operation *op) {
return RTGBase::dispatchOpVisitor(op); });
// Elaborates a single-block region. Block arguments are bound to
// `regionArguments`. Each op is dispatched to its visitor, and ops that report
// DeletionKind::Keep are re-created through the materializer. If the block may
// have a terminator, its operands' elaborated values are appended to
// `terminatorOperands`. Regions with more than one block are rejected.
// NOTE(review): "®ion" on lines 2106 and 2117 is a mis-encoded "&region" (an
// HTML "&reg" entity). Also, lines 2120-2122, 2125-2127, 2129, 2133, 2135-2136,
// 2138-2140 and 2144-2146 are missing here (failure checks, the LLVM_DEBUG
// wrapper, the `else`, `return success();`). Restore both before building.
2106 LogicalResult elaborate(Region ®ion,
2107 ArrayRef<ElaboratorValue> regionArguments,
2108 SmallVector<ElaboratorValue> &terminatorOperands) {
2109 if (region.getBlocks().size() > 1)
2110 return region.getParentOp()->emitOpError(
2111 "regions with more than one block are not supported");
2113 for (
auto [arg, elabArg] :
2114 llvm::zip(region.getArguments(), regionArguments))
2115 state[arg] = elabArg;
2117 Block *block = ®ion.front();
2118 for (
auto &op : *block) {
2119 auto result = dispatchOpVisitor(&op);
// Kept ops are cloned into the output with operands resolved from `state`.
2123 if (*result == DeletionKind::Keep)
2124 if (failed(materializer.materialize(&op, state)))
// Debug trace: the op followed by the elaborated value of each result.
2128 llvm::dbgs() <<
"Elaborated " << op <<
" to\n[";
2130 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](
auto res) {
2131 if (state.contains(res))
2132 llvm::dbgs() << state.at(res);
2134 llvm::dbgs() <<
"unknown";
2137 llvm::dbgs() <<
"]\n\n";
2141 if (region.front().mightHaveTerminator())
2142 for (
auto val : region.front().getTerminator()->getOperands())
2143 terminatorOperands.push_back(state.at(val));
// State shared by the whole pass run. Used here for the symbol table, name
// generation, RNG and value internalizer.
2150 SharedState &sharedState;
// Per-test state, e.g. the registered context switches and the test's name.
2153 TestState &testState;
// Rebuilds kept ops and materializes elaborated values into the output IR.
2157 Materializer &materializer;
// Maps each SSA value in the region being elaborated to its elaborated value.
2160 DenseMap<Value, ElaboratorValue> state;
// Context the current region is elaborated under. Null means the default
// context (see the OnContextOp visitor).
2163 ContextResourceAttrInterface currentContext;
// Instantiates a randomized sequence. The sequence family op is cloned without
// regions under a fresh unique name and registered in the symbol table. The
// family body is elaborated into the clone with the sequence's arguments, under
// the context it was randomized for. The clone's sequence type is then set from
// the block argument types its materializer created, collected in `elabArgs`.
// NOTE(review): the return type (line 2167), 2170 (start of the `familyOp`
// declaration), 2172-2173, 2181, 2186, 2192-2194 and the function's tail are
// missing here. Restore them before building.
2168Materializer::elaborateSequence(
const RandomizedSequenceStorage *
seq,
2169 SmallVector<ElaboratorValue> &elabArgs) {
2171 sharedState.table.lookup<SequenceOp>(
seq->sequence->familyName);
2174 OpBuilder builder(familyOp);
2175 auto seqOp = builder.cloneWithoutRegions(familyOp);
2176 auto name = sharedState.names.newName(
seq->sequence->familyName.getValue());
2177 seqOp.setSymName(name);
2178 seqOp.getBodyRegion().emplaceBlock();
2179 sharedState.table.insert(seqOp);
2180 assert(seqOp.getSymName() == name &&
"should not have been renamed");
2182 LLVM_DEBUG(llvm::dbgs() <<
"\n=== Elaborating sequence family @"
2183 << familyOp.getSymName() <<
" into @"
2184 << seqOp.getSymName() <<
" under context "
2185 <<
seq->context <<
"\n\n");
// A nested materializer/elaborator pair writes into the clone's new block.
2187 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()), testState,
2188 sharedState, elabArgs);
2189 Elaborator elaborator(sharedState, testState, materializer,
seq->context);
2190 SmallVector<ElaboratorValue> yieldedVals;
2191 if (failed(elaborator.elaborate(familyOp.getBodyRegion(),
seq->sequence->args,
2195 seqOp.setSequenceType(
2196 SequenceType::get(builder.getContext(), materializer.getBlockArgTypes()));
2197 materializer.finalize();
// The RTG elaboration pass. It pairs tests with compatible targets and then
// elaborates targets, tests and the sequences they use.
// NOTE(review): lines 2209-2210 (presumably the inherited-constructor
// declaration needed for pass options) and the closing `};` are missing here.
// Restore them before building.
2207struct ElaborationPass
2208 :
public rtg::impl::ElaborationPassBase<ElaborationPass> {
2211 void runOnOperation()
override;
2212 void matchTestsAgainstTargets(SymbolTable &table);
2213 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
2217void ElaborationPass::runOnOperation() {
2218 auto moduleOp = getOperation();
2219 SymbolTable table(moduleOp);
2221 matchTestsAgainstTargets(table);
2223 if (failed(elaborateModule(moduleOp, table)))
2224 return signalPassFailure();
// For every test without an explicit target, checks each target for
// compatibility. Every entry of the test's target type must appear in the
// target with the same name and type. Each compatible target gets a clone of
// the test, named "<test>_<target>", with that target set and inserted into
// the symbol table. The walk advances `targetIdx` by name comparison, so it
// assumes both entry lists are sorted by name (confirm).
// NOTE(review): lines 2229, 2232-2233, 2235, 2237-2239, 2243-2245, 2248,
// 2252-2254, 2258-2265, 2267, 2271-2273, 2275, 2277-2279 and the tail after
// 2280 are missing here. They include the `continue`s, the `isSubtype`
// bookkeeping and presumably the erase of the original test. Restore them
// before building.
2227void ElaborationPass::matchTestsAgainstTargets(SymbolTable &table) {
2228 auto moduleOp = getOperation();
// Early-increment iteration because matched tests may be erased.
2230 for (
auto test :
llvm::make_early_inc_range(moduleOp.getOps<TestOp>())) {
2231 if (test.getTargetAttr())
2234 bool matched =
false;
2236 for (
auto target : moduleOp.getOps<TargetOp>()) {
2240 bool isSubtype =
true;
2241 auto testEntries = test.getTargetType().getEntries();
2242 auto targetEntries = target.getTarget().getEntries();
2246 size_t targetIdx = 0;
2247 for (
auto testEntry : testEntries) {
// Skip target entries whose names sort before the current test entry.
2249 while (targetIdx < targetEntries.size() &&
2250 targetEntries[targetIdx].name.getValue() <
2251 testEntry.name.getValue())
2255 if (targetIdx >= targetEntries.size() ||
2256 targetEntries[targetIdx].name != testEntry.name ||
2257 targetEntries[targetIdx].type != testEntry.type) {
2266 IRRewriter rewriter(test);
2268 auto newTest = cast<TestOp>(test->clone());
2269 newTest.setSymName(test.getSymName().str() +
"_" +
2270 target.getSymName().str());
2274 newTest.setTargetAttr(target.getSymNameAttr());
2276 table.insert(newTest, rewriter.getInsertionPoint());
2280 if (matched || deleteUnmatchedTests)
// Presumably the body of `onlyLegalToMaterializeInTarget(Type type)`; its
// signature (line 2285) and closing brace are missing here. Memory-block types
// and context-resource types are reported as materializable only inside a
// target.
2286 return isa<MemoryBlockType, ContextResourceTypeInterface>(type);
2289LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
2290 SymbolTable &table) {
2291 SharedState state(table, seed);
2294 state.names.add(moduleOp);
2296 struct TargetElabResult {
2297 DictType targetType;
2298 SmallVector<ElaboratorValue> yields;
2299 TestState testState;
2303 DenseMap<StringAttr, TargetElabResult> targetMap;
2304 for (
auto targetOp : moduleOp.getOps<TargetOp>()) {
2305 LLVM_DEBUG(llvm::dbgs() <<
"=== Elaborating target @"
2306 << targetOp.getSymName() <<
"\n\n");
2308 auto &result = targetMap[targetOp.getSymNameAttr()];
2309 result.targetType = targetOp.getTarget();
2311 SmallVector<ElaboratorValue> blockArgs;
2312 Materializer targetMaterializer(OpBuilder::atBlockBegin(targetOp.getBody()),
2313 result.testState, state, blockArgs);
2314 Elaborator targetElaborator(state, result.testState, targetMaterializer);
2317 if (failed(targetElaborator.elaborate(targetOp.getBodyRegion(), {},
2324 for (
auto testOp : moduleOp.getOps<TestOp>()) {
2328 if (!testOp.getTargetAttr())
2331 LLVM_DEBUG(llvm::dbgs()
2332 <<
"\n=== Elaborating test @" << testOp.getTemplateName()
2333 <<
" for target @" << *testOp.getTarget() <<
"\n\n");
2336 auto targetResult = targetMap[testOp.getTargetAttr()];
2337 TestState testState = targetResult.testState;
2338 testState.name = testOp.getSymNameAttr();
2340 SmallVector<ElaboratorValue> filteredYields;
2342 for (
auto [entry, yield] :
2343 llvm::zip(targetResult.targetType.getEntries(), targetResult.yields)) {
2344 if (i >= testOp.getTargetType().getEntries().size())
2347 if (entry.name == testOp.getTargetType().getEntries()[i].name) {
2348 filteredYields.push_back(yield);
2355 SmallVector<ElaboratorValue> blockArgs;
2356 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()),
2357 testState, state, blockArgs);
2359 for (
auto [arg, val] :
2360 llvm::zip(testOp.getBody()->getArguments(), filteredYields))
2362 materializer.map(val, arg);
2364 Elaborator elaborator(state, testState, materializer);
2365 SmallVector<ElaboratorValue> ignore;
2366 if (failed(elaborator.elaborate(testOp.getBodyRegion(), filteredYields,
2370 materializer.finalize();
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the in specified range.
static bool onlyLegalToMaterializeInTarget(Type type)
static void print(TypedAttr val, llvm::raw_ostream &os)
static StringAttr substituteFormatString(StringAttr formatString, ArrayRef< Attribute > substitutes)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
This helps visit TypeOp nodes.
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
static llvm::hash_code hash_value(const ModulePort &port)
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
llvm::hash_code hash_value(const DenseSet< T > &set)
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static LabelValue getEmptyKey()
static LabelValue getTombstoneKey()
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)
static unsigned getEmptyKey()