21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Index/IR/IndexDialect.h"
23#include "mlir/Dialect/Index/IR/IndexOps.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/DenseMapInfoVariant.h"
28#include "llvm/Support/Debug.h"
34#define GEN_PASS_DEF_ELABORATIONPASS
35#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
44#define DEBUG_TYPE "rtg-elaboration"
56 size_t n = w / 32 + (w % 32 != 0);
58 return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
63 const uint32_t diff = b - a + 1;
67 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
71 uint32_t width = digits - llvm::countl_zero(diff) - 1;
72 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
91struct SequenceStorage;
92struct RandomizedSequenceStorage;
93struct InterleavedSequenceStorage;
95struct VirtualRegisterStorage;
96struct UniqueLabelStorage;
99struct MemoryBlockStorage;
100struct ValidationValue;
101struct ValidationMuxedValue;
102struct ImmediateConcatStorage;
103struct ImmediateSliceStorage;
109 LabelValue(StringAttr name) : name(name) {}
111 bool operator==(
const LabelValue &other)
const {
return name == other.name; }
118using ElaboratorValue = std::variant<
119 TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
120 RandomizedSequenceStorage *, InterleavedSequenceStorage *, SetStorage *,
121 VirtualRegisterStorage *, UniqueLabelStorage *, LabelValue, ArrayStorage *,
122 TupleStorage *, MemoryStorage *, MemoryBlockStorage *, ValidationValue *,
123 ValidationMuxedValue *, ImmediateConcatStorage *, ImmediateSliceStorage *>;
126llvm::hash_code
hash_value(
const LabelValue &val) {
127 return llvm::hash_value(val.name);
131llvm::hash_code
hash_value(
const ElaboratorValue &val) {
133 [&val](
const auto &alternative) {
136 return llvm::hash_combine(val.index(), alternative);
/// Two bool keys compare equal exactly when they hold the same value.
static bool isEqual(const bool &lhsKey, const bool &rhsKey) {
  const bool same = lhsKey == rhsKey;
  return same;
}
165 static bool isEqual(
const LabelValue &lhs,
const LabelValue &rhs) {
165 static bool isEqual(
const LabelValue &lhs,
const LabelValue &rhs) {
…}
182template <
typename StorageTy>
183struct HashedStorage {
184 HashedStorage(
unsigned hashcode = 0, StorageTy *storage =
nullptr)
185 : hashcode(hashcode), storage(storage) {}
195template <
typename StorageTy>
196struct StorageKeyInfo {
197 static inline HashedStorage<StorageTy> getEmptyKey() {
198 return HashedStorage<StorageTy>(0,
199 DenseMapInfo<StorageTy *>::getEmptyKey());
201 static inline HashedStorage<StorageTy> getTombstoneKey() {
202 return HashedStorage<StorageTy>(
203 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
206 static inline unsigned getHashValue(
const HashedStorage<StorageTy> &key) {
209 static inline unsigned getHashValue(
const StorageTy &key) {
213 static inline bool isEqual(
const HashedStorage<StorageTy> &lhs,
214 const HashedStorage<StorageTy> &rhs) {
215 return lhs.storage == rhs.storage;
217 static inline bool isEqual(
const StorageTy &lhs,
218 const HashedStorage<StorageTy> &rhs) {
219 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
222 return lhs.isEqual(rhs.storage);
231 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
233 type,
llvm::hash_combine_range(set.begin(), set.
end()))),
234 set(std::move(set)), type(type) {}
236 bool isEqual(
const SetStorage *other)
const {
237 return hashcode == other->hashcode && set == other->set &&
242 const unsigned hashcode;
245 const SetVector<ElaboratorValue> set;
254 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
256 type,
llvm::hash_combine_range(bag.begin(), bag.
end()))),
257 bag(std::move(bag)), type(type) {}
259 bool isEqual(
const BagStorage *other)
const {
260 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
265 const unsigned hashcode;
269 const MapVector<ElaboratorValue, uint64_t> bag;
277struct SequenceStorage {
278 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
280 familyName,
llvm::hash_combine_range(args.begin(), args.
end()))),
281 familyName(familyName), args(std::move(args)) {}
283 bool isEqual(
const SequenceStorage *other)
const {
284 return hashcode == other->hashcode && familyName == other->familyName &&
289 const unsigned hashcode;
292 const StringAttr familyName;
295 const SmallVector<ElaboratorValue> args;
299struct InterleavedSequenceStorage {
300 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
302 : sequences(std::move(sequences)), batchSize(batchSize),
304 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
307 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
308 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
310 llvm::hash_combine_range(sequences.begin(), sequences.
end()),
313 bool isEqual(
const InterleavedSequenceStorage *other)
const {
314 return hashcode == other->hashcode && sequences == other->sequences &&
315 batchSize == other->batchSize;
318 const SmallVector<ElaboratorValue> sequences;
320 const uint32_t batchSize;
323 const unsigned hashcode;
328 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
330 type,
llvm::hash_combine_range(array.begin(), array.
end()))),
331 type(type), array(array) {}
333 bool isEqual(
const ArrayStorage *other)
const {
334 return hashcode == other->hashcode && type == other->type &&
335 array == other->array;
339 const unsigned hashcode;
346 const SmallVector<ElaboratorValue> array;
351 TupleStorage(SmallVector<ElaboratorValue> &&values)
352 : hashcode(
llvm::hash_combine_range(values.begin(), values.
end())),
353 values(std::move(values)) {}
355 bool isEqual(
const TupleStorage *other)
const {
356 return hashcode == other->hashcode && values == other->values;
360 const unsigned hashcode;
362 const SmallVector<ElaboratorValue> values;
367struct ImmediateConcatStorage {
368 ImmediateConcatStorage(SmallVector<ElaboratorValue> &&operands)
369 : hashcode(
llvm::hash_combine_range(operands.begin(), operands.
end())),
370 operands(std::move(operands)) {}
372 bool isEqual(
const ImmediateConcatStorage *other)
const {
373 return hashcode == other->hashcode && operands == other->operands;
376 const unsigned hashcode;
377 const SmallVector<ElaboratorValue> operands;
382struct ImmediateSliceStorage {
383 ImmediateSliceStorage(ElaboratorValue input,
unsigned lowBit, Type type)
385 lowBit(lowBit), type(type) {}
387 bool isEqual(
const ImmediateSliceStorage *other)
const {
388 return hashcode == other->hashcode && input == other->input &&
389 lowBit == other->lowBit && type == other->type;
392 const unsigned hashcode;
393 const ElaboratorValue input;
394 const unsigned lowBit;
406struct IdentityValue {
408 IdentityValue(Type type) : type(type) {}
421 bool alreadyMaterialized =
false;
429struct VirtualRegisterStorage : IdentityValue {
430 VirtualRegisterStorage(ArrayAttr allowedRegs, Type type)
431 : IdentityValue(type), allowedRegs(allowedRegs) {}
438 const ArrayAttr allowedRegs;
441struct UniqueLabelStorage : IdentityValue {
442 UniqueLabelStorage(StringAttr name)
443 : IdentityValue(LabelType::
get(name.getContext())), name(name) {}
449 const StringAttr name;
453struct MemoryBlockStorage : IdentityValue {
454 MemoryBlockStorage(
const APInt &baseAddress,
const APInt &endAddress,
456 : IdentityValue(type), baseAddress(baseAddress), endAddress(endAddress) {}
461 const APInt baseAddress;
464 const APInt endAddress;
468struct MemoryStorage : IdentityValue {
469 MemoryStorage(MemoryBlockStorage *memoryBlock,
size_t size,
size_t alignment)
470 : IdentityValue(MemoryType::
get(memoryBlock->type.getContext(),
472 memoryBlock(memoryBlock), size(size), alignment(alignment) {}
474 MemoryBlockStorage *memoryBlock;
476 const size_t alignment;
480struct RandomizedSequenceStorage : IdentityValue {
481 RandomizedSequenceStorage(ContextResourceAttrInterface context,
482 SequenceStorage *sequence)
484 RandomizedSequenceType::
get(sequence->familyName.getContext())),
485 context(context), sequence(sequence) {}
488 const ContextResourceAttrInterface context;
490 const SequenceStorage *sequence;
494struct ValidationValue : IdentityValue {
495 ValidationValue(Type type,
const ElaboratorValue &ref,
496 const ElaboratorValue &defaultValue, StringAttr
id,
497 SmallVector<ElaboratorValue> &&defaultUsedValues,
498 SmallVector<ElaboratorValue> &&elseValues)
499 : IdentityValue(type), ref(ref), defaultValue(defaultValue), id(id),
500 defaultUsedValues(std::move(defaultUsedValues)),
501 elseValues(std::move(elseValues)) {}
503 const ElaboratorValue ref;
504 const ElaboratorValue defaultValue;
506 const SmallVector<ElaboratorValue> defaultUsedValues;
507 const SmallVector<ElaboratorValue> elseValues;
511struct ValidationMuxedValue : IdentityValue {
512 ValidationMuxedValue(Type type,
const ValidationValue *value,
unsigned idx)
513 : IdentityValue(type), value(value), idx(idx) {}
515 const ValidationValue *value;
529 template <
typename StorageTy,
typename... Args>
530 StorageTy *internalize(Args &&...args) {
531 static_assert(!std::is_base_of_v<IdentityValue, StorageTy> &&
532 "values with identity must not be internalized");
534 StorageTy storage(std::forward<Args>(args)...);
536 auto existing = getInternSet<StorageTy>().insert_as(
537 HashedStorage<StorageTy>(storage.hashcode), storage);
538 StorageTy *&storagePtr = existing.first->storage;
541 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
546 template <
typename StorageTy,
typename... Args>
547 StorageTy *create(Args &&...args) {
548 static_assert(std::is_base_of_v<IdentityValue, StorageTy> &&
549 "values with structural equivalence must be internalized");
551 return new (allocator.Allocate<StorageTy>())
552 StorageTy(std::forward<Args>(args)...);
556 template <
typename StorageTy>
557 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
559 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
560 return internedArrays;
561 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
563 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
565 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
566 return internedSequences;
567 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
568 return internedRandomizedSequences;
569 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
570 return internedInterleavedSequences;
571 else if constexpr (std::is_same_v<StorageTy, TupleStorage>)
572 return internedTuples;
573 else if constexpr (std::is_same_v<StorageTy, ImmediateConcatStorage>)
574 return internedImmediateConcatValues;
575 else if constexpr (std::is_same_v<StorageTy, ImmediateSliceStorage>)
576 return internedImmediateSliceValues;
578 static_assert(!
sizeof(StorageTy),
579 "no intern set available for this storage type.");
584 llvm::BumpPtrAllocator allocator;
589 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
591 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
592 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
593 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
595 DenseSet<HashedStorage<RandomizedSequenceStorage>,
596 StorageKeyInfo<RandomizedSequenceStorage>>
597 internedRandomizedSequences;
598 DenseSet<HashedStorage<InterleavedSequenceStorage>,
599 StorageKeyInfo<InterleavedSequenceStorage>>
600 internedInterleavedSequences;
601 DenseSet<HashedStorage<TupleStorage>, StorageKeyInfo<TupleStorage>>
603 DenseSet<HashedStorage<ImmediateConcatStorage>,
604 StorageKeyInfo<ImmediateConcatStorage>>
605 internedImmediateConcatValues;
606 DenseSet<HashedStorage<ImmediateSliceStorage>,
607 StorageKeyInfo<ImmediateSliceStorage>>
608 internedImmediateSliceValues;
615static llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
616 const ElaboratorValue &value);
618static void print(TypedAttr val, llvm::raw_ostream &os) {
619 os <<
"<attr " << val <<
">";
618static void print(TypedAttr val, llvm::raw_ostream &os) {
…}
622static void print(BagStorage *val, llvm::raw_ostream &os) {
624 llvm::interleaveComma(val->bag, os,
625 [&](
const std::pair<ElaboratorValue, uint64_t> &el) {
626 os << el.first <<
" -> " << el.second;
628 os <<
"} at " << val <<
">";
622static void print(BagStorage *val, llvm::raw_ostream &os) {
…}
631static void print(
bool val, llvm::raw_ostream &os) {
632 os <<
"<bool " << (val ?
"true" :
"false") <<
">";
631static void print(
bool val, llvm::raw_ostream &os) {
…}
635static void print(
size_t val, llvm::raw_ostream &os) {
636 os <<
"<index " << val <<
">";
635static void print(
size_t val, llvm::raw_ostream &os) {
…}
639static void print(SequenceStorage *val, llvm::raw_ostream &os) {
640 os <<
"<sequence @" << val->familyName.getValue() <<
"(";
641 llvm::interleaveComma(val->args, os,
642 [&](
const ElaboratorValue &val) { os << val; });
643 os <<
") at " << val <<
">";
639static void print(SequenceStorage *val, llvm::raw_ostream &os) {
…}
646static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
647 os <<
"<randomized-sequence derived from @"
648 << val->sequence->familyName.getValue() <<
" under context "
649 << val->context <<
"(";
650 llvm::interleaveComma(val->sequence->args, os,
651 [&](
const ElaboratorValue &val) { os << val; });
652 os <<
") at " << val <<
">";
646static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
…}
655static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
656 os <<
"<interleaved-sequence [";
657 llvm::interleaveComma(val->sequences, os,
658 [&](
const ElaboratorValue &val) { os << val; });
659 os <<
"] batch-size " << val->batchSize <<
" at " << val <<
">";
655static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
…}
662static void print(ArrayStorage *val, llvm::raw_ostream &os) {
664 llvm::interleaveComma(val->array, os,
665 [&](
const ElaboratorValue &val) { os << val; });
666 os <<
"] at " << val <<
">";
662static void print(ArrayStorage *val, llvm::raw_ostream &os) {
…}
669static void print(SetStorage *val, llvm::raw_ostream &os) {
671 llvm::interleaveComma(val->set, os,
672 [&](
const ElaboratorValue &val) { os << val; });
673 os <<
"} at " << val <<
">";
669static void print(SetStorage *val, llvm::raw_ostream &os) {
…}
676static void print(
const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
677 os <<
"<virtual-register " << val <<
" " << val->allowedRegs <<
">";
676static void print(
const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
…}
680static void print(
const UniqueLabelStorage *val, llvm::raw_ostream &os) {
681 os <<
"<unique-label " << val <<
" " << val->name <<
">";
680static void print(
const UniqueLabelStorage *val, llvm::raw_ostream &os) {
…}
684static void print(
const LabelValue &val, llvm::raw_ostream &os) {
685 os <<
"<label " << val.name <<
">";
684static void print(
const LabelValue &val, llvm::raw_ostream &os) {
…}
688static void print(
const TupleStorage *val, llvm::raw_ostream &os) {
690 llvm::interleaveComma(val->values, os,
691 [&](
const ElaboratorValue &val) { os << val; });
688static void print(
const TupleStorage *val, llvm::raw_ostream &os) {
…}
695static void print(
const MemoryStorage *val, llvm::raw_ostream &os) {
696 os <<
"<memory {" << ElaboratorValue(val->memoryBlock)
697 <<
", size=" << val->size <<
", alignment=" << val->alignment <<
"}>";
695static void print(
const MemoryStorage *val, llvm::raw_ostream &os) {
…}
700static void print(
const MemoryBlockStorage *val, llvm::raw_ostream &os) {
701 os <<
"<memory-block {"
702 <<
", address-width=" << val->baseAddress.getBitWidth()
703 <<
", base-address=" << val->baseAddress
704 <<
", end-address=" << val->endAddress <<
"}>";
700static void print(
const MemoryBlockStorage *val, llvm::raw_ostream &os) {
…}
707static void print(
const ValidationValue *val, llvm::raw_ostream &os) {
708 os <<
"<validation-value {type=" << val->type <<
", ref=" << val->ref
709 <<
", defaultValue=" << val->defaultValue <<
"}>";
707static void print(
const ValidationValue *val, llvm::raw_ostream &os) {
…}
712static void print(
const ValidationMuxedValue *val, llvm::raw_ostream &os) {
713 os <<
"<validation-muxed-value (" << val->value <<
") at " << val->idx <<
">";
712static void print(
const ValidationMuxedValue *val, llvm::raw_ostream &os) {
…}
716static void print(
const ImmediateConcatStorage *val, llvm::raw_ostream &os) {
717 os <<
"<immediate-concat [";
718 llvm::interleaveComma(val->operands, os,
719 [&](
const ElaboratorValue &val) { os << val; });
716static void print(
const ImmediateConcatStorage *val, llvm::raw_ostream &os) {
…}
723static void print(
const ImmediateSliceStorage *val, llvm::raw_ostream &os) {
724 os <<
"<immediate-slice " << val->input <<
" from " << val->lowBit <<
">";
723static void print(
const ImmediateSliceStorage *val, llvm::raw_ostream &os) {
…}
728 const ElaboratorValue &value) {
729 std::visit([&](
auto val) {
print(val, os); }, value);
744 SharedState(SymbolTable &table,
unsigned seed) : table(table), rng(seed) {}
749 Internalizer internalizer;
759 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
767 Materializer(OpBuilder builder, TestState &testState,
768 SharedState &sharedState,
769 SmallVector<ElaboratorValue> &blockArgs)
770 : builder(builder), testState(testState), sharedState(sharedState),
771 blockArgs(blockArgs) {}
775 Value materialize(ElaboratorValue val, Location loc,
776 function_ref<InFlightDiagnostic()> emitError) {
777 auto iter = materializedValues.find(val);
778 if (iter != materializedValues.end())
781 LLVM_DEBUG(llvm::dbgs() <<
"Materializing " << val);
785 Value res = std::visit(
787 if constexpr (std::is_base_of_v<IdentityValue,
788 std::remove_pointer_t<
789 std::decay_t<
decltype(value)>>>) {
790 if (identityValueRoot.contains(value)) {
793 static_cast<IdentityValue *
>(value)->alreadyMaterialized;
794 assert(!materialized &&
"must not already be materialized");
798 return visit(value, loc, emitError);
801 Value arg = builder.getBlock()->addArgument(value->type, loc);
802 blockArgs.push_back(val);
803 blockArgTypes.push_back(arg.getType());
804 materializedValues[val] = arg;
808 return visit(value, loc, emitError);
812 LLVM_DEBUG(llvm::dbgs() <<
" to\n" << res <<
"\n\n");
823 LogicalResult materialize(Operation *op,
824 DenseMap<Value, ElaboratorValue> &state) {
825 if (op->getNumRegions() > 0)
826 return op->emitOpError(
"ops with nested regions must be elaborated away");
834 for (
auto res : op->getResults())
835 if (!res.use_empty() && !isa<ValidateOp>(op))
836 return op->emitOpError(
837 "ops with results that have uses are not supported");
839 if (op->getParentRegion() == builder.getBlock()->getParent()) {
842 deleteOpsUntil([&](
auto iter) {
return &*iter == op; });
844 if (builder.getInsertionPoint() == builder.getBlock()->end())
845 return op->emitError(
"operation did not occur after the current "
846 "materializer insertion point");
848 LLVM_DEBUG(llvm::dbgs() <<
"Modifying in-place: " << *op <<
"\n\n");
850 LLVM_DEBUG(llvm::dbgs() <<
"Materializing a clone of " << *op <<
"\n\n");
851 op = builder.clone(*op);
852 builder.setInsertionPoint(op);
855 for (
auto &operand : op->getOpOperands()) {
856 auto emitError = [&]() {
857 auto diag = op->emitError();
858 diag.attachNote(op->getLoc())
859 <<
"while materializing value for operand#"
860 << operand.getOperandNumber();
864 auto elabVal = state.at(operand.get());
865 Value val = materialize(elabVal, op->getLoc(), emitError);
869 state[val] = elabVal;
873 builder.setInsertionPointAfter(op);
880 deleteOpsUntil([](
auto iter) {
return false; });
882 for (
auto *op :
llvm::reverse(toDelete))
889 void registerIdentityValue(IdentityValue *val) {
890 identityValueRoot.insert(val);
893 ArrayRef<Type> getBlockArgTypes()
const {
return blockArgTypes; }
895 void map(ElaboratorValue eval, Value val) { materializedValues[eval] = val; }
897 template <
typename OpTy,
typename... Args>
898 OpTy create(Location location, Args &&...args) {
899 return OpTy::create(builder, location, std::forward<Args>(args)...);
903 SequenceOp elaborateSequence(
const RandomizedSequenceStorage *
seq,
904 SmallVector<ElaboratorValue> &elabArgs);
906 void deleteOpsUntil(function_ref<
bool(Block::iterator)> stop) {
907 auto ip = builder.getInsertionPoint();
908 while (ip != builder.getBlock()->end() && !stop(ip)) {
909 LLVM_DEBUG(llvm::dbgs() <<
"Marking to be deleted: " << *ip <<
"\n\n");
910 toDelete.push_back(&*ip);
912 builder.setInsertionPointAfter(&*ip);
913 ip = builder.getInsertionPoint();
917 Value visit(TypedAttr val, Location loc,
918 function_ref<InFlightDiagnostic()> emitError) {
921 if (
auto intAttr = dyn_cast<IntegerAttr>(val);
922 intAttr && isa<IndexType>(val.getType())) {
923 Value res = index::ConstantOp::create(builder, loc, intAttr);
924 materializedValues[val] = res;
931 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
933 emitError() <<
"materializer of dialect '"
934 << val.getDialect().getNamespace()
935 <<
"' unable to materialize value for attribute '" << val
940 Value res = op->getResult(0);
941 materializedValues[val] = res;
945 Value visit(
size_t val, Location loc,
946 function_ref<InFlightDiagnostic()> emitError) {
947 Value res = index::ConstantOp::create(builder, loc, val);
948 materializedValues[val] = res;
952 Value visit(
bool val, Location loc,
953 function_ref<InFlightDiagnostic()> emitError) {
954 Value res = index::BoolConstantOp::create(builder, loc, val);
955 materializedValues[val] = res;
959 Value visit(ArrayStorage *val, Location loc,
960 function_ref<InFlightDiagnostic()> emitError) {
961 SmallVector<Value> elements;
962 elements.reserve(val->array.size());
963 for (
auto el : val->array) {
964 auto materialized = materialize(el, loc, emitError);
968 elements.push_back(materialized);
971 Value res = ArrayCreateOp::create(builder, loc, val->type, elements);
972 materializedValues[val] = res;
976 Value visit(SetStorage *val, Location loc,
977 function_ref<InFlightDiagnostic()> emitError) {
978 SmallVector<Value> elements;
979 elements.reserve(val->set.size());
980 for (
auto el : val->set) {
981 auto materialized = materialize(el, loc, emitError);
985 elements.push_back(materialized);
988 auto res = SetCreateOp::create(builder, loc, val->type, elements);
989 materializedValues[val] = res;
993 Value visit(BagStorage *val, Location loc,
994 function_ref<InFlightDiagnostic()> emitError) {
995 SmallVector<Value> values, weights;
996 values.reserve(val->bag.size());
997 weights.reserve(val->bag.size());
998 for (
auto [val, weight] : val->bag) {
999 auto materializedVal = materialize(val, loc, emitError);
1000 auto materializedWeight = materialize(weight, loc, emitError);
1001 if (!materializedVal || !materializedWeight)
1004 values.push_back(materializedVal);
1005 weights.push_back(materializedWeight);
1008 auto res = BagCreateOp::create(builder, loc, val->type, values, weights);
1009 materializedValues[val] = res;
1013 Value visit(MemoryBlockStorage *val, Location loc,
1014 function_ref<InFlightDiagnostic()> emitError) {
1015 auto intType = builder.getIntegerType(val->baseAddress.getBitWidth());
1016 Value res = MemoryBlockDeclareOp::create(
1017 builder, loc, val->type, IntegerAttr::get(intType, val->baseAddress),
1018 IntegerAttr::get(intType, val->endAddress));
1019 materializedValues[val] = res;
1023 Value visit(MemoryStorage *val, Location loc,
1024 function_ref<InFlightDiagnostic()> emitError) {
1025 auto memBlock = materialize(val->memoryBlock, loc, emitError);
1026 auto memSize = materialize(val->size, loc, emitError);
1027 auto memAlign = materialize(val->alignment, loc, emitError);
1028 if (!(memBlock && memSize && memAlign))
1032 MemoryAllocOp::create(builder, loc, memBlock, memSize, memAlign);
1033 materializedValues[val] = res;
1037 Value visit(SequenceStorage *val, Location loc,
1038 function_ref<InFlightDiagnostic()> emitError) {
1039 emitError() <<
"materializing a non-randomized sequence not supported yet";
1043 Value visit(RandomizedSequenceStorage *val, Location loc,
1044 function_ref<InFlightDiagnostic()> emitError) {
1050 SmallVector<ElaboratorValue> elabArgs;
1051 SequenceOp seqOp = elaborateSequence(val, elabArgs);
1057 SmallVector<Value> args;
1058 SmallVector<Type> argTypes;
1059 for (
auto arg : elabArgs) {
1060 Value materialized = materialize(arg, loc, emitError);
1064 args.push_back(materialized);
1065 argTypes.push_back(materialized.getType());
1068 Value res = GetSequenceOp::create(
1069 builder, loc, SequenceType::get(builder.getContext(), argTypes),
1070 seqOp.getSymName());
1075 res = SubstituteSequenceOp::create(builder, loc, res, args);
1077 res = RandomizeSequenceOp::create(builder, loc, res);
1079 materializedValues[val] = res;
1083 Value visit(InterleavedSequenceStorage *val, Location loc,
1084 function_ref<InFlightDiagnostic()> emitError) {
1085 SmallVector<Value> sequences;
1086 for (
auto seqVal : val->sequences) {
1087 Value materialized = materialize(seqVal, loc, emitError);
1091 sequences.push_back(materialized);
1094 if (sequences.size() == 1)
1095 return sequences[0];
1098 InterleaveSequencesOp::create(builder, loc, sequences, val->batchSize);
1099 materializedValues[val] = res;
1103 Value visit(VirtualRegisterStorage *val, Location loc,
1104 function_ref<InFlightDiagnostic()> emitError) {
1105 Value res = VirtualRegisterOp::create(builder, loc, val->allowedRegs);
1106 materializedValues[val] = res;
1110 Value visit(UniqueLabelStorage *val, Location loc,
1111 function_ref<InFlightDiagnostic()> emitError) {
1113 LabelUniqueDeclOp::create(builder, loc, val->name, ValueRange());
1114 materializedValues[val] = res;
1118 Value visit(
const LabelValue &val, Location loc,
1119 function_ref<InFlightDiagnostic()> emitError) {
1120 Value res = LabelDeclOp::create(builder, loc, val.name, ValueRange());
1121 materializedValues[val] = res;
1125 Value visit(TupleStorage *val, Location loc,
1126 function_ref<InFlightDiagnostic()> emitError) {
1127 SmallVector<Value> materialized;
1128 materialized.reserve(val->values.size());
1129 for (
auto v : val->values)
1130 materialized.push_back(materialize(v, loc, emitError));
1131 Value res = TupleCreateOp::create(builder, loc, materialized);
1132 materializedValues[val] = res;
1136 Value visit(ValidationValue *val, Location loc,
1137 function_ref<InFlightDiagnostic()> emitError) {
1138 SmallVector<Value> usedDefaultValues, elseValues;
1139 for (
auto [dfltVal, elseVal] :
1140 llvm::zip(val->defaultUsedValues, val->elseValues)) {
1141 auto dfltMat = materialize(dfltVal, loc, emitError);
1142 auto elseMat = materialize(elseVal, loc, emitError);
1143 if (!dfltMat || !elseMat)
1146 usedDefaultValues.push_back(dfltMat);
1147 usedDefaultValues.push_back(elseMat);
1150 auto validateOp = ValidateOp::create(
1151 builder, loc, val->type, materialize(val->ref, loc, emitError),
1152 materialize(val->defaultValue, loc, emitError), val->id,
1153 usedDefaultValues, elseValues);
1154 materializedValues[val] = validateOp.getValue();
1155 return validateOp.getValue();
1158 Value visit(ValidationMuxedValue *val, Location loc,
1159 function_ref<InFlightDiagnostic()> emitError) {
1160 Value validateValue =
1161 materialize(
const_cast<ValidationValue *
>(val->value), loc, emitError);
1164 auto validateOp = validateValue.getDefiningOp<ValidateOp>();
1166 auto *defOp = validateValue.getDefiningOp();
1168 <<
"expected validate op for validation muxed value, but found "
1169 << (defOp ? defOp->getName().getStringRef() :
"block argument");
1172 materializedValues[val] = validateOp.getValues()[val->idx];
1173 return validateOp.getValues()[val->idx];
1176 Value visit(ImmediateConcatStorage *val, Location loc,
1177 function_ref<InFlightDiagnostic()> emitError) {
1178 SmallVector<Value> operands;
1179 for (
auto operand : val->operands) {
1180 auto materialized = materialize(operand, loc, emitError);
1184 operands.push_back(materialized);
1187 Value res = ConcatImmediateOp::create(builder, loc, operands);
1188 materializedValues[val] = res;
1192 Value visit(ImmediateSliceStorage *val, Location loc,
1193 function_ref<InFlightDiagnostic()> emitError) {
1194 Value input = materialize(val->input, loc, emitError);
1199 SliceImmediateOp::create(builder, loc, val->type, input, val->lowBit);
1200 materializedValues[val] = res;
1210 DenseMap<ElaboratorValue, Value> materializedValues;
1216 SmallVector<Operation *> toDelete;
1218 TestState &testState;
1219 SharedState &sharedState;
1224 SmallVector<ElaboratorValue> &blockArgs;
1225 SmallVector<Type> blockArgTypes;
1230 DenseSet<IdentityValue *> identityValueRoot;
1239enum class DeletionKind { Keep, Delete };
1242class Elaborator :
public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
1245 using RTGBase::visitOp;
1247 Elaborator(SharedState &sharedState, TestState &testState,
1248 Materializer &materializer,
1249 ContextResourceAttrInterface currentContext = {})
1250 : sharedState(sharedState), testState(testState),
1251 materializer(materializer), currentContext(currentContext) {}
1253 template <
typename ValueTy>
1254 inline ValueTy
get(Value val)
const {
1255 return std::get<ValueTy>(state.at(val));
1258 FailureOr<DeletionKind> visitPureOp(Operation *op) {
1259 SmallVector<Attribute> operands;
1260 for (
auto operand : op->getOperands()) {
1261 auto evalValue = state[operand];
1262 if (std::holds_alternative<TypedAttr>(evalValue))
1263 operands.push_back(std::get<TypedAttr>(evalValue));
1268 SmallVector<OpFoldResult> results;
1269 if (failed(op->fold(operands, results)))
1273 if (results.size() != op->getNumResults())
1276 for (
auto [res, val] :
llvm::zip(results, op->getResults())) {
1277 auto attr = llvm::dyn_cast_or_null<TypedAttr>(res.dyn_cast<Attribute>());
1279 return op->emitError(
1280 "only typed attributes supported for constant-like operations");
1282 auto intAttr = dyn_cast<IntegerAttr>(attr);
1283 if (intAttr && isa<IndexType>(attr.getType()))
1284 state[op->getResult(0)] = size_t(intAttr.getInt());
1285 else if (intAttr && intAttr.getType().isSignlessInteger(1))
1286 state[op->getResult(0)] = bool(intAttr.getInt());
1288 state[op->getResult(0)] = attr;
1291 return DeletionKind::Delete;
1296 return op->emitOpError(
"elaboration not supported");
1300 auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
1301 if (op->hasTrait<OpTrait::ConstantLike>() || (memOp && memOp.hasNoEffect()))
1302 return visitPureOp(op);
1307 if (op->use_empty())
1308 return DeletionKind::Keep;
1313 FailureOr<DeletionKind> visitOp(ConstantOp op) {
return visitPureOp(op); }
1315 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
1316 SmallVector<ElaboratorValue> replacements;
1317 state[op.getResult()] =
1318 sharedState.internalizer.internalize<SequenceStorage>(
1319 op.getSequenceAttr(), std::move(replacements));
1320 return DeletionKind::Delete;
1323 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
1324 auto *
seq = get<SequenceStorage *>(op.getSequence());
1326 SmallVector<ElaboratorValue> replacements(
seq->args);
1327 for (
auto replacement : op.getReplacements())
1328 replacements.push_back(state.at(replacement));
1330 state[op.getResult()] =
1331 sharedState.internalizer.internalize<SequenceStorage>(
1332 seq->familyName, std::move(replacements));
1334 return DeletionKind::Delete;
1337 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
1338 auto *
seq = get<SequenceStorage *>(op.getSequence());
1339 auto *randomizedSeq =
1340 sharedState.internalizer.create<RandomizedSequenceStorage>(
1341 currentContext,
seq);
1342 materializer.registerIdentityValue(randomizedSeq);
1343 state[op.getResult()] =
1344 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1346 return DeletionKind::Delete;
1349 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
1350 SmallVector<ElaboratorValue> sequences;
1351 for (
auto seq : op.getSequences())
1352 sequences.push_back(
get<InterleavedSequenceStorage *>(
seq));
1354 state[op.getResult()] =
1355 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1356 std::move(sequences), op.getBatchSize());
1357 return DeletionKind::Delete;
// Check that every randomized sequence reachable from `value` was randomized
// for the context the elaborator is currently in (`currentContext`).
// - A RandomizedSequenceStorage is compared against `currentContext`
//   directly; on a mismatch an error is emitted on `op`.
// - An InterleavedSequenceStorage is checked by recursing into each entry of
//   its `sequences`.
// NOTE(review): the failure/success returns, the printing of the current
// context (line 1367), and the branch structure between the two cases are
// not visible in this view.
1361 LogicalResult isValidContext(ElaboratorValue value, Operation *op)
const {
1362 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
1363 auto *
seq = std::get<RandomizedSequenceStorage *>(value);
1364 if (
seq->context != currentContext) {
1365 auto err = op->emitError(
"attempting to place sequence derived from ")
1366 <<
seq->sequence->familyName.getValue() <<
" under context "
1368 <<
", but it was previously randomized for context ";
1370 err <<
seq->context;
// Interleaved sequences: validate each constituent sequence recursively.
1378 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
1379 for (
auto val : interVal->sequences)
1380 if (failed(isValidContext(val, op)))
// Embedding is kept in the output IR. `elaborate` materializes ops returning
// DeletionKind::Keep. First it checks that the embedded (interleaved)
// sequence was randomized for the current context.
// NOTE(review): the early-return on failure (line 1388) is elided in this view.
1385 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
1386 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
1387 if (failed(isValidContext(seqVal, op)))
1390 return DeletionKind::Keep;
1393 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1394 SetVector<ElaboratorValue> set;
1395 for (
auto val : op.getElements())
1396 set.insert(state.at(val));
1398 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1399 std::move(set), op.getSet().getType());
1400 return DeletionKind::Delete;
// Pick one element of a set at random. An empty set is an error.
// If the op carries an `rtg.elaboration_custom_seed` IntegerAttr, a local
// std::mt19937 seeded with it is used for this selection.
// NOTE(review): `auto set = ...->set` copies the whole SetVector; a const
// reference would suffice for indexing. The emptiness check, the non-custom
// RNG path, and the computation of `selected` are on elided lines
// (1405-1417).
1403 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
1404 auto set = get<SetStorage *>(op.getSet())->set;
1407 return op->emitError(
"cannot select from an empty set");
1411 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1412 std::mt19937 customRng(intAttr.getInt());
1418 state[op.getResult()] = set[selected];
1419 return DeletionKind::Delete;
1422 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
1423 auto original = get<SetStorage *>(op.getOriginal())->set;
1424 auto diff = get<SetStorage *>(op.getDiff())->set;
1426 SetVector<ElaboratorValue> result(original);
1427 result.set_subtract(diff);
1429 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1430 std::move(result), op.getResult().getType());
1431 return DeletionKind::Delete;
1434 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
1435 SetVector<ElaboratorValue> result;
1436 for (
auto set : op.getSets())
1437 result.set_union(
get<SetStorage *>(set)->set);
1439 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1440 std::move(result), op.getType());
1441 return DeletionKind::Delete;
1444 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1445 auto size = get<SetStorage *>(op.getSet())->set.size();
1446 state[op.getResult()] = size;
1447 return DeletionKind::Delete;
// Cartesian product of the input sets, as a set of TupleStorage values.
// `tuples` starts with one empty partial tuple. For each input set:
// - each existing partial tuple is copied once per element except the last,
//   and each copy has that element appended;
// - the original partial tuple is extended in place with the set's last
//   element.
// The finished tuples are internalized and collected into `result`.
// NOTE(review): L1461-1465 return an empty set as the result. The guarding
// condition is elided; presumably it fires when an input set is empty.
// Confirm.
1453 FailureOr<DeletionKind> visitOp(SetCartesianProductOp op) {
1454 SetVector<ElaboratorValue> result;
1455 SmallVector<SmallVector<ElaboratorValue>> tuples;
1456 tuples.push_back({});
1458 for (
auto input : op.getInputs()) {
1459 auto &set = get<SetStorage *>(input)->set;
1461 SetVector<ElaboratorValue>
empty;
1462 state[op.getResult()] =
1463 sharedState.internalizer.internalize<SetStorage>(std::move(
empty),
1465 return DeletionKind::Delete;
// Only iterate over the partial tuples that existed before this input
// (`e` is captured up front); copies pushed in the loop are already
// extended.
1468 for (
unsigned i = 0, e = tuples.size(); i < e; ++i) {
1469 for (
auto setEl : set.getArrayRef().drop_back()) {
1470 tuples.push_back(tuples[i]);
1471 tuples.back().push_back(setEl);
1473 tuples[i].push_back(set.back());
1477 for (
auto &tup : tuples)
1479 sharedState.internalizer.internalize<TupleStorage>(std::move(tup)));
1481 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1482 std::move(result), op.getType());
1483 return DeletionKind::Delete;
1486 FailureOr<DeletionKind> visitOp(SetConvertToBagOp op) {
1487 auto set = get<SetStorage *>(op.getInput())->set;
1488 MapVector<ElaboratorValue, uint64_t> bag;
1489 for (
auto val : set)
1490 bag.insert({val, 1});
1491 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1492 std::move(bag), op.getType());
1493 return DeletionKind::Delete;
1496 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1497 MapVector<ElaboratorValue, uint64_t> bag;
1498 for (
auto [val, multiple] :
1499 llvm::zip(op.getElements(), op.getMultiples())) {
1503 bag[state.at(val)] += get<size_t>(multiple);
1506 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1507 std::move(bag), op.getType());
1508 return DeletionKind::Delete;
// Pick one element of a bag at random, weighted by multiplicity.
// A running (prefix) sum of the multiplicities is built. `upper_bound` then
// finds the first entry whose running total exceeds the drawn value; the
// draw itself is on elided lines. An empty bag is an error.
// NOTE(review): `auto customRng = sharedState.rng;` deduces a value type, so
// it copies the shared generator. Unless elided lines write it back, this op
// never advances `sharedState.rng`. SetSelectRandomOp only uses a separate
// generator when a custom seed is given. Confirm whether repeated selections
// are meant to reuse the same random state.
// NOTE(review): `accumulator` and the prefix sums are uint32_t. Bags built in
// this file use uint64_t multiplicities, so large totals would wrap.
// Presumably BagStorage::bag has that type; confirm.
// NOTE(review): `auto bag = ...->bag` copies the whole MapVector.
1511 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
1512 auto bag = get<BagStorage *>(op.getBag())->bag;
1515 return op->emitError(
"cannot select from an empty bag");
1517 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1518 prefixSum.reserve(bag.size());
1519 uint32_t accumulator = 0;
1520 for (
auto [val, weight] : bag) {
1521 accumulator += weight;
1522 prefixSum.push_back({val, accumulator});
1525 auto customRng = sharedState.rng;
1527 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1528 customRng = std::mt19937(intAttr.getInt());
1532 auto *iter = llvm::upper_bound(
1534 [](uint32_t a,
const std::pair<ElaboratorValue, uint32_t> &b) {
1535 return a < b.second;
1538 state[op.getResult()] = iter->first;
1539 return DeletionKind::Delete;
// Multiset difference `original - diff`, keeping the element order of
// `original`.
// - Elements absent from `diff` are handled in an elided branch (presumably
//   copied unchanged).
// - Elements whose multiplicity is <= the subtracted count hit an elided
//   branch (presumably skipped).
// - Otherwise the remaining multiplicity `el.second - toDiff` is inserted.
//   It is only reached when el.second > toDiff, so the unsigned subtraction
//   cannot underflow, assuming the elided branch skips.
// NOTE(review): `auto original` / `auto diff` copy both input bags;
// const references would suffice.
1542 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1543 auto original = get<BagStorage *>(op.getOriginal())->bag;
1544 auto diff = get<BagStorage *>(op.getDiff())->bag;
1546 MapVector<ElaboratorValue, uint64_t> result;
1547 for (
const auto &el : original) {
1548 if (!diff.contains(el.first)) {
1556 auto toDiff = diff.lookup(el.first);
1557 if (el.second <= toDiff)
1560 result.insert({el.first, el.second - toDiff});
1563 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1564 std::move(result), op.getType());
1565 return DeletionKind::Delete;
1568 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1569 MapVector<ElaboratorValue, uint64_t> result;
1570 for (
auto bag : op.getBags()) {
1571 auto val = get<BagStorage *>(bag)->bag;
1572 for (
auto [el, multiple] : val)
1573 result[el] += multiple;
1576 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1577 std::move(result), op.getType());
1578 return DeletionKind::Delete;
1581 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1582 auto size = get<BagStorage *>(op.getBag())->bag.size();
1583 state[op.getResult()] = size;
1584 return DeletionKind::Delete;
// Convert a bag into a set of its distinct elements, dropping
// multiplicities.
// NOTE(review): the loop body (line 1591) is elided; presumably it inserts
// each key `k` into `set`. `auto bag = ...->bag` copies the input bag.
1587 FailureOr<DeletionKind> visitOp(BagConvertToSetOp op) {
1588 auto bag = get<BagStorage *>(op.getInput())->bag;
1589 SetVector<ElaboratorValue> set;
1590 for (
auto [k, v] : bag)
1592 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1593 std::move(set), op.getType());
1594 return DeletionKind::Delete;
1597 FailureOr<DeletionKind> visitOp(FixedRegisterOp op) {
1598 return visitPureOp(op);
1601 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1602 auto *val = sharedState.internalizer.create<VirtualRegisterStorage>(
1603 op.getAllowedRegsAttr(), op.getType());
1604 state[op.getResult()] = val;
1605 materializer.registerIdentityValue(val);
1606 return DeletionKind::Delete;
// Replace each `{{i}}` placeholder in `formatString` with the decimal text of
// the i-th substitute. Substitutes must be index values (`get<size_t>`).
// The input is returned unchanged when there are no substitutes or the
// string is empty.
// Placeholders are rewritten one index at a time (i = 0, 1, ...), in place
// on the same string.
// NOTE(review): because the passes run in sequence, a later pass can match
// text produced by an earlier one. For example, "{{{{0}}}}" with substitute 0
// == 1 becomes "{{1}}", which the i = 1 pass then rewrites. Confirm whether
// that is intended. Lines 1621-1623 are elided, so whether `startPos` is
// advanced past each replacement is not visible here.
1609 StringAttr substituteFormatString(StringAttr formatString,
1610 ValueRange substitutes)
const {
1611 if (substitutes.empty() || formatString.empty())
1612 return formatString;
1614 auto original = formatString.getValue().str();
1615 for (
auto [i, subst] :
llvm::enumerate(substitutes)) {
1616 size_t startPos = 0;
1617 std::string from =
"{{" + std::to_string(i) +
"}}";
1618 while ((startPos = original.find(from, startPos)) != std::string::npos) {
1619 auto substString = std::to_string(get<size_t>(subst));
1620 original.replace(startPos, from.length(), substString);
1624 return StringAttr::get(formatString.getContext(), original);
1627 FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
1628 SmallVector<ElaboratorValue> array;
1629 array.reserve(op.getElements().size());
1630 for (
auto val : op.getElements())
1631 array.emplace_back(state.at(val));
1633 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1634 op.getResult().getType(), std::move(array));
1635 return DeletionKind::Delete;
1638 FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
1639 auto array = get<ArrayStorage *>(op.getArray())->array;
1640 size_t idx = get<size_t>(op.getIndex());
1642 if (array.size() <= idx)
1643 return op->emitError(
"invalid to access index ")
1644 << idx <<
" of an array with " << array.size() <<
" elements";
1646 state[op.getResult()] = array[idx];
1647 return DeletionKind::Delete;
1650 FailureOr<DeletionKind> visitOp(ArrayInjectOp op) {
1651 auto array = get<ArrayStorage *>(op.getArray())->array;
1652 size_t idx = get<size_t>(op.getIndex());
1654 if (array.size() <= idx)
1655 return op->emitError(
"invalid to access index ")
1656 << idx <<
" of an array with " << array.size() <<
" elements";
1658 array[idx] = state[op.getValue()];
1659 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1660 op.getResult().getType(), std::move(array));
1661 return DeletionKind::Delete;
1664 FailureOr<DeletionKind> visitOp(ArraySizeOp op) {
1665 auto array = get<ArrayStorage *>(op.getArray())->array;
1666 state[op.getResult()] = array.size();
1667 return DeletionKind::Delete;
// Declare a label whose name is the format string with its `{{i}}`
// placeholders substituted. The label is a plain LabelValue (not an
// identity value) and the op is dropped.
// NOTE(review): the declaration line (1671, presumably `auto substituted =`)
// is elided in this view.
1670 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
1672 substituteFormatString(op.getFormatStringAttr(), op.getArgs());
1673 state[op.getLabel()] = LabelValue(substituted);
1674 return DeletionKind::Delete;
1677 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1678 auto *val = sharedState.internalizer.create<UniqueLabelStorage>(
1679 substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
1680 state[op.getLabel()] = val;
1681 materializer.registerIdentityValue(val);
1682 return DeletionKind::Delete;
1685 FailureOr<DeletionKind> visitOp(LabelOp op) {
return DeletionKind::Keep; }
1687 FailureOr<DeletionKind> visitOp(TestSuccessOp op) {
1688 return DeletionKind::Keep;
1691 FailureOr<DeletionKind> visitOp(TestFailureOp op) {
1692 return DeletionKind::Keep;
// Draw a random index value from [lower, upper]. An empty range is an
// error. An `rtg.elaboration_custom_seed` IntegerAttr on the op seeds a local
// std::mt19937 for this draw.
// NOTE(review): the range check and the draw expressions are elided
// (1698-1710). If they call getUniformlyInRange, note that it takes uint32_t
// bounds while `lower`/`upper` are size_t, so values above UINT32_MAX would
// be truncated. Confirm.
1695 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1696 size_t lower = get<size_t>(op.getLowerBound());
1697 size_t upper = get<size_t>(op.getUpperBound());
1699 return op->emitError(
"cannot select a number from an empty range");
1702 op->getAttrOfType<IntegerAttr>(
"rtg.elaboration_custom_seed")) {
1703 std::mt19937 customRng(intAttr.getInt());
1704 state[op.getResult()] =
1707 state[op.getResult()] =
1711 return DeletionKind::Delete;
// Convert a concrete index value into an ImmediateAttr of the result's
// width. Values that do not fit in `width` bits are rejected with an error.
// NOTE(review): APInt::getZExtValue asserts when the value needs more than
// 64 bits. If ImmediateType widths above 64 are possible, the range check on
// 1718 would trip that assertion; `llvm::isUIntN(width, input)` avoids it.
// Confirm the allowed widths.
1714 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1715 size_t input = get<size_t>(op.getInput());
1716 auto width = op.getType().getWidth();
1717 auto emitError = [&]() {
return op->emitError(); };
1718 if (input > APInt::getAllOnes(width).getZExtValue())
1719 return emitError() <<
"cannot represent " << input <<
" with " << width
1722 state[op.getResult()] =
1723 ImmediateAttr::get(op.getContext(), APInt(width, input));
1724 return DeletionKind::Delete;
// Elaborate a sequence under a different context.
// `from` is the current context; when none is set, it is the default context
// for the target context's type.
// The first path (1742-1750) materializes the sequence, then randomizes and
// embeds it directly. Its guarding condition is elided; presumably it is
// taken when no actual context switch is needed. Confirm.
// Otherwise a registered context-switch sequence is looked up, and the op is
// replaced by an embedded, randomized instance of that sequence.
1727 FailureOr<DeletionKind> visitOp(OnContextOp op) {
1728 ContextResourceAttrInterface from = currentContext,
1729 to = cast<ContextResourceAttrInterface>(
1730 get<TypedAttr>(op.getContext()));
1731 if (!currentContext)
1732 from = DefaultContextAttr::get(op->getContext(), to.getType());
1734 auto emitError = [&]() {
1735 auto diag = op.emitError();
1736 diag.attachNote(op.getLoc())
1737 <<
"while materializing value for context switching for " << op;
1742 Value seqVal = materializer.materialize(
1743 get<SequenceStorage *>(op.getSequence()), op.getLoc(), emitError);
1748 materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1749 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1750 return DeletionKind::Delete;
// Look up a registered transition, most specific first:
// (from, to), (from, any), (any, to), (any, any).
1756 auto *iter = testState.contextSwitches.find({from, to});
1759 if (iter == testState.contextSwitches.end())
1760 iter = testState.contextSwitches.find(
1761 {from, AnyContextAttr::get(op->getContext(), to.getType())});
1764 if (iter == testState.contextSwitches.end())
1765 iter = testState.contextSwitches.find(
1766 {AnyContextAttr::get(op->getContext(), from.getType()), to});
1769 if (iter == testState.contextSwitches.end())
1770 iter = testState.contextSwitches.find(
1771 {AnyContextAttr::get(op->getContext(), from.getType()),
1772 AnyContextAttr::get(op->getContext(), to.getType())});
1778 if (iter == testState.contextSwitches.end())
1779 return op->emitError(
"no context transition registered to switch from ")
1780 << from <<
" to " << to;
// Instantiate the transition family with (from, to, sequence) as arguments.
// Randomize it for the target context `to`, then embed it.
1782 auto familyName = iter->second->familyName;
1783 SmallVector<ElaboratorValue> args{from, to,
1784 get<SequenceStorage *>(op.getSequence())};
1785 auto *
seq = sharedState.internalizer.internalize<SequenceStorage>(
1786 familyName, std::move(args));
1788 sharedState.internalizer.create<RandomizedSequenceStorage>(to,
seq);
1789 materializer.registerIdentityValue(randSeq);
1790 Value seqVal = materializer.materialize(randSeq, op.getLoc(), emitError);
1794 materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
1795 return DeletionKind::Delete;
1798 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1799 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1800 get<SequenceStorage *>(op.getSequence());
1801 return DeletionKind::Delete;
1804 FailureOr<DeletionKind> visitOp(MemoryBlockDeclareOp op) {
1805 auto *val = sharedState.internalizer.create<MemoryBlockStorage>(
1806 op.getBaseAddress(), op.getEndAddress(), op.getType());
1807 state[op.getResult()] = val;
1808 materializer.registerIdentityValue(val);
1809 return DeletionKind::Delete;
// Allocate a memory of concrete `size` and `alignment` from a memory block.
// The result is its own MemoryStorage, obtained via `internalizer.create` and
// registered as an identity value.
// NOTE(review): the trailing argument line (1817, presumably `alignment`) is
// elided in this view.
1812 FailureOr<DeletionKind> visitOp(MemoryAllocOp op) {
1813 size_t size = get<size_t>(op.getSize());
1814 size_t alignment = get<size_t>(op.getAlignment());
1815 auto *memBlock = get<MemoryBlockStorage *>(op.getMemoryBlock());
1816 auto *val = sharedState.internalizer.create<MemoryStorage>(memBlock, size,
1818 state[op.getResult()] = val;
1819 materializer.registerIdentityValue(val);
1820 return DeletionKind::Delete;
1823 FailureOr<DeletionKind> visitOp(MemorySizeOp op) {
1824 auto *memory = get<MemoryStorage *>(op.getMemory());
1825 state[op.getResult()] = memory->size;
1826 return DeletionKind::Delete;
1829 FailureOr<DeletionKind> visitOp(TupleCreateOp op) {
1830 SmallVector<ElaboratorValue> values;
1831 values.reserve(op.getElements().size());
1832 for (
auto el : op.getElements())
1833 values.push_back(state[el]);
1835 state[op.getResult()] =
1836 sharedState.internalizer.internalize<TupleStorage>(std::move(values));
1837 return DeletionKind::Delete;
1840 FailureOr<DeletionKind> visitOp(TupleExtractOp op) {
1841 auto *tuple = get<TupleStorage *>(op.getTuple());
1842 state[op.getResult()] = tuple->values[op.getIndex().getZExtValue()];
1843 return DeletionKind::Delete;
1846 FailureOr<DeletionKind> visitOp(CommentOp op) {
return DeletionKind::Keep; }
1848 FailureOr<DeletionKind> visitOp(rtg::YieldOp op) {
1849 return DeletionKind::Keep;
1852 FailureOr<DeletionKind> visitOp(ValidateOp op) {
1853 SmallVector<ElaboratorValue> defaultUsedValues, elseValues;
1854 for (
auto v : op.getDefaultUsedValues())
1855 defaultUsedValues.push_back(state.at(v));
1857 for (
auto v : op.getElseValues())
1858 elseValues.push_back(state.at(v));
1860 auto *validationVal = sharedState.internalizer.create<ValidationValue>(
1861 op.getValue().getType(), state[op.getRef()],
1862 state[op.getDefaultValue()], op.getIdAttr(),
1863 std::move(defaultUsedValues), std::move(elseValues));
1864 state[op.getValue()] = validationVal;
1865 materializer.registerIdentityValue(validationVal);
1866 materializer.map(validationVal, op.getValue());
1868 for (
auto [i, val] :
llvm::enumerate(op.getValues())) {
1869 auto *muxVal = sharedState.internalizer.create<ValidationMuxedValue>(
1870 val.getType(), validationVal, i);
1871 state[val] = muxVal;
1872 materializer.registerIdentityValue(muxVal);
1873 materializer.map(muxVal, val);
1876 return DeletionKind::Keep;
// Statically resolve an scf.if. The condition must be a concrete bool.
// Only the selected region is elaborated (inline, into the current
// materializer), and its yielded values become the op's results. If the
// selected region is empty (e.g. a missing else), nothing is elaborated.
// The op is always dropped.
// NOTE(review): the failure return after `elaborate` and the loop body that
// assigns `state[res] = out` (presumably) are on elided lines.
1879 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1880 bool cond = get<bool>(op.getCondition());
1881 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1882 if (toElaborate.empty())
1883 return DeletionKind::Delete;
1889 SmallVector<ElaboratorValue> yieldedVals;
1890 if (failed(elaborate(toElaborate, {}, yieldedVals)))
1894 for (
auto [res, out] :
llvm::zip(op.getResults(), yieldedVals))
1897 return DeletionKind::Delete;
// Fully unroll an scf.for whose bounds and step are concrete index values.
// The body region is elaborated once per iteration. The induction variable
// and iter_args in `state` are updated between iterations, and the final
// iter_args become the op's results. The op is dropped.
// NOTE(review): with `step == 0` the loop below never terminates, unless the
// elided lines 1909-1913 reject it. `i += step` can also wrap around when
// `upperBound` is close to SIZE_MAX. Confirm both are guarded.
1900 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1901 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1902 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1903 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1904 return op->emitOpError(
"can only elaborate index type iterator");
1906 auto lowerBound = get<size_t>(op.getLowerBound());
1907 auto step = get<size_t>(op.getStep());
1908 auto upperBound = get<size_t>(op.getUpperBound());
// Seed the induction variable and iter_args for the first iteration.
1914 state[op.getInductionVar()] = lowerBound;
1915 for (
auto [iterArg, initArg] :
1916 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1917 state[iterArg] = state.at(initArg);
1920 SmallVector<ElaboratorValue> yieldedVals;
1921 for (
size_t i = lowerBound; i < upperBound; i += step) {
1922 yieldedVals.clear();
1923 if (failed(elaborate(op.getBodyRegion(), {}, yieldedVals)))
// Feed this iteration's yielded values into the next one.
1928 state[op.getInductionVar()] = i + step;
1929 for (
auto [iterArg, prevIterArg] :
1930 llvm::zip(op.getRegionIterArgs(), yieldedVals))
1931 state[iterArg] = prevIterArg;
1935 for (
auto [res, iterArg] :
1936 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1937 state[res] = state.at(iterArg);
1939 return DeletionKind::Delete;
1942 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1943 return DeletionKind::Delete;
1946 FailureOr<DeletionKind> visitOp(arith::AddIOp op) {
1947 if (!isa<IndexType>(op.getType()))
1948 return op->emitError(
"only index operands supported");
1950 size_t lhs = get<size_t>(op.getLhs());
1951 size_t rhs = get<size_t>(op.getRhs());
1952 state[op.getResult()] = lhs + rhs;
1953 return DeletionKind::Delete;
1956 FailureOr<DeletionKind> visitOp(arith::AndIOp op) {
1957 if (!op.getType().isSignlessInteger(1))
1958 return op->emitError(
"only 'i1' operands supported");
1960 bool lhs = get<bool>(op.getLhs());
1961 bool rhs = get<bool>(op.getRhs());
1962 state[op.getResult()] = lhs && rhs;
1963 return DeletionKind::Delete;
1966 FailureOr<DeletionKind> visitOp(arith::XOrIOp op) {
1967 if (!op.getType().isSignlessInteger(1))
1968 return op->emitError(
"only 'i1' operands supported");
1970 bool lhs = get<bool>(op.getLhs());
1971 bool rhs = get<bool>(op.getRhs());
1972 state[op.getResult()] = lhs != rhs;
1973 return DeletionKind::Delete;
1976 FailureOr<DeletionKind> visitOp(arith::OrIOp op) {
1977 if (!op.getType().isSignlessInteger(1))
1978 return op->emitError(
"only 'i1' operands supported");
1980 bool lhs = get<bool>(op.getLhs());
1981 bool rhs = get<bool>(op.getRhs());
1982 state[op.getResult()] = lhs || rhs;
1983 return DeletionKind::Delete;
1986 FailureOr<DeletionKind> visitOp(arith::SelectOp op) {
1987 bool cond = get<bool>(op.getCondition());
1988 auto trueVal = state[op.getTrueValue()];
1989 auto falseVal = state[op.getFalseValue()];
1990 state[op.getResult()] = cond ? trueVal : falseVal;
1991 return DeletionKind::Delete;
1994 FailureOr<DeletionKind> visitOp(index::AddOp op) {
1995 size_t lhs = get<size_t>(op.getLhs());
1996 size_t rhs = get<size_t>(op.getRhs());
1997 state[op.getResult()] = lhs + rhs;
1998 return DeletionKind::Delete;
// Constant-fold an `index.cmp` on two concrete index values.
// Equality and the unsigned predicates are evaluated on size_t operands.
// Any other predicate reaches the "elaboration not supported" error
// (presumably via an elided `default:`). Confirm whether the signed
// predicates should be supported.
// NOTE(review): the declaration of `result`, the `break`s, and the ULT/UGT
// bodies are on elided lines.
2001 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
2002 size_t lhs = get<size_t>(op.getLhs());
2003 size_t rhs = get<size_t>(op.getRhs());
2005 switch (op.getPred()) {
2006 case index::IndexCmpPredicate::EQ:
2007 result = lhs == rhs;
2009 case index::IndexCmpPredicate::NE:
2010 result = lhs != rhs;
2012 case index::IndexCmpPredicate::ULT:
2015 case index::IndexCmpPredicate::ULE:
2016 result = lhs <= rhs;
2018 case index::IndexCmpPredicate::UGT:
2021 case index::IndexCmpPredicate::UGE:
2022 result = lhs >= rhs;
2025 return op->emitOpError(
"elaboration not supported");
2027 state[op.getResult()] = result;
2028 return DeletionKind::Delete;
// Concatenate immediates.
// - If any operand is a ValidationValue (its concrete value is not known
//   yet), the concatenation is deferred as an internalized
//   ImmediateConcatStorage over the operand values.
// - Otherwise all operands are ImmediateAttrs, and their APInts are
//   concatenated starting from a zero-width APInt.
// NOTE(review): the visible predicate only tests `ValidationValue *`, and
// lines 2035-2038 are elided. Confirm that operands holding other deferred
// immediates (ValidationMuxedValue, ImmediateConcatStorage,
// ImmediateSliceStorage) are also routed to the deferred path. Otherwise
// `get<TypedAttr>` below would be applied to them.
2031 FailureOr<DeletionKind> visitOp(ConcatImmediateOp op) {
2032 bool anyValidationValues =
2033 llvm::any_of(op.getOperands(), [&](
auto operand) {
2034 return std::holds_alternative<ValidationValue *>(state[operand]);
2039 if (anyValidationValues) {
2040 SmallVector<ElaboratorValue> operands;
2041 for (
auto operand : op.getOperands())
2042 operands.push_back(state[operand]);
2043 state[op.getResult()] =
2044 sharedState.internalizer.internalize<ImmediateConcatStorage>(
2045 std::move(operands));
2046 return DeletionKind::Delete;
// All operands are concrete: fold the concatenation eagerly.
2049 auto result = APInt::getZeroWidth();
2050 for (
auto operand : op.getOperands())
2052 cast<ImmediateAttr>(
get<TypedAttr>(operand)).getValue());
2054 state[op.getResult()] = ImmediateAttr::get(op.getContext(), result);
2055 return DeletionKind::Delete;
// Slice bits out of an immediate.
// - A ValidationValue input defers the slice as an internalized
//   ImmediateSliceStorage (input, low bit, result type).
// - Otherwise the bits are extracted from the input ImmediateAttr's APInt.
// NOTE(review): as in ConcatImmediateOp, only `ValidationValue *` is tested
// on the visible lines. Confirm the handling of other deferred immediate
// kinds. The declaration of `inputValue` and the extractBits low-bit argument
// are on elided lines.
2058 FailureOr<DeletionKind> visitOp(SliceImmediateOp op) {
2061 if (std::holds_alternative<ValidationValue *>(state[op.getInput()])) {
2062 state[op.getResult()] =
2063 sharedState.internalizer.internalize<ImmediateSliceStorage>(
2064 state[op.getInput()], op.getLowBit(), op.getResult().getType());
2065 return DeletionKind::Delete;
2069 cast<ImmediateAttr>(get<TypedAttr>(op.getInput())).getValue();
2070 auto sliced = inputValue.extractBits(op.getResult().getType().getWidth(),
2072 state[op.getResult()] = ImmediateAttr::get(op.getContext(), sliced);
2073 return DeletionKind::Delete;
// Route the arith/index/scf ops listed in the TypeSwitch to this class's
// visitOp overloads. Part of the type list is on elided lines. Every other
// op falls back to RTGBase::dispatchOpVisitor.
2076 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
2077 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
2080 arith::AddIOp, arith::XOrIOp, arith::AndIOp, arith::OrIOp,
2083 index::AddOp, index::CmpOp,
2085 scf::IfOp, scf::ForOp, scf::YieldOp>(
2086 [&](
auto op) {
return visitOp(op); })
2087 .Default([&](Operation *op) {
return RTGBase::dispatchOpVisitor(op); });
// Elaborate the single block of `region`.
// 1. Bind the block arguments to `regionArguments`.
// 2. Visit every op in order. Ops whose visitor returns DeletionKind::Keep
//    are materialized through `materializer` with the current `state`.
// 3. If the block has a terminator, append the elaborated values of its
//    operands to `terminatorOperands`.
// Regions with more than one block are rejected.
// NOTE(review): `®ion` on this line and in `Block *block = ®ion.front()`
// appears to be a mis-encoding of `&region` (an HTML `&reg` entity). The
// parameter is used as `region` throughout. The failure returns and the
// LLVM_DEBUG wrapper around the dump are on elided lines.
2091 LogicalResult elaborate(Region ®ion,
2092 ArrayRef<ElaboratorValue> regionArguments,
2093 SmallVector<ElaboratorValue> &terminatorOperands) {
2094 if (region.getBlocks().size() > 1)
2095 return region.getParentOp()->emitOpError(
2096 "regions with more than one block are not supported");
2098 for (
auto [arg, elabArg] :
2099 llvm::zip(region.getArguments(), regionArguments))
2100 state[arg] = elabArg;
2102 Block *block = ®ion.front();
2103 for (
auto &op : *block) {
2104 auto result = dispatchOpVisitor(&op);
2108 if (*result == DeletionKind::Keep)
2109 if (failed(materializer.materialize(&op, state)))
// Debug dump: each result's elaborated value, or "unknown" if none was
// recorded.
2113 llvm::dbgs() <<
"Elaborated " << op <<
" to\n[";
2115 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](
auto res) {
2116 if (state.contains(res))
2117 llvm::dbgs() << state.at(res);
2119 llvm::dbgs() <<
"unknown";
2122 llvm::dbgs() <<
"]\n\n";
2126 if (region.front().mightHaveTerminator())
2127 for (
auto val : region.front().getTerminator()->getOperands())
2128 terminatorOperands.push_back(state.at(val));
// Pass-wide elaboration state. Used by this class for value internalization
// (`internalizer`) and random number generation (`rng`).
2135 SharedState &sharedState;
// Per-test state. Used here to record and look up registered context
// switches (`contextSwitches`).
2138 TestState &testState;
// Builds the elaborated IR. Ops kept by the visitors are materialized
// through it, and identity values are registered with it.
2142 Materializer &materializer;
// Maps each SSA value visited so far to its elaborated compile-time value.
2145 DenseMap<Value, ElaboratorValue> state;
// Context the current region is elaborated under. OnContextOp treats a null
// value as the default context.
2148 ContextResourceAttrInterface currentContext;
// Instantiate a randomized sequence as a fresh, uniquely named SequenceOp.
// 1. Clone the family op without its regions and give the clone a new name.
// 2. Elaborate the family body into the clone under the randomized
//    sequence's context, with the family's bound arguments as block
//    arguments.
// 3. Set the clone's sequence type from the block argument types the nested
//    materializer collected.
// NOTE(review): the return-type line (2152), the declaration of `familyOp`,
// the failure handling, and the function's tail are on elided lines.
2153Materializer::elaborateSequence(
const RandomizedSequenceStorage *
seq,
2154 SmallVector<ElaboratorValue> &elabArgs) {
2156 sharedState.table.lookup<SequenceOp>(
seq->sequence->familyName);
2159 OpBuilder builder(familyOp);
2160 auto seqOp = builder.cloneWithoutRegions(familyOp);
2161 auto name = sharedState.names.newName(
seq->sequence->familyName.getValue());
2162 seqOp.setSymName(name);
2163 seqOp.getBodyRegion().emplaceBlock();
2164 sharedState.table.insert(seqOp);
2165 assert(seqOp.getSymName() == name &&
"should not have been renamed");
2167 LLVM_DEBUG(llvm::dbgs() <<
"\n=== Elaborating sequence family @"
2168 << familyOp.getSymName() <<
" into @"
2169 << seqOp.getSymName() <<
" under context "
2170 <<
seq->context <<
"\n\n");
// A nested materializer/elaborator pair writes into the new sequence body.
2172 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()), testState,
2173 sharedState, elabArgs);
2174 Elaborator elaborator(sharedState, testState, materializer,
seq->context);
2175 SmallVector<ElaboratorValue> yieldedVals;
2176 if (failed(elaborator.elaborate(familyOp.getBodyRegion(),
seq->sequence->args,
2180 seqOp.setSequenceType(
2181 SequenceType::get(builder.getContext(), materializer.getBlockArgTypes()));
2182 materializer.finalize();
// The RTG elaboration pass.
// - runOnOperation: entry point.
// - matchTestsAgainstTargets: creates one copy of each untargeted test per
//   compatible target.
// - elaborateModule: elaborates all targets and the tests bound to them.
2192struct ElaborationPass
2193 :
public rtg::impl::ElaborationPassBase<ElaborationPass> {
2196 void runOnOperation()
override;
2197 void matchTestsAgainstTargets(SymbolTable &table);
2198 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
2202void ElaborationPass::runOnOperation() {
2203 auto moduleOp = getOperation();
2204 SymbolTable table(moduleOp);
2206 matchTestsAgainstTargets(table);
2208 if (failed(elaborateModule(moduleOp, table)))
2209 return signalPassFailure();
// For every test that is not yet bound to a target, create one copy per
// compatible target. A target is compatible when every test entry appears
// in the target with the same name and type.
// Each copy:
// - is named `<test>_<target>`;
// - has its target attribute set;
// - is inserted into the symbol table at the rewriter's insertion point.
// The original test is then erased (presumably; elided lines 2266-2268) if it
// matched at least one target or if `deleteUnmatchedTests` is set.
// NOTE(review): the entry comparison is a two-pointer merge that advances
// `targetIdx` while target names compare `<` the test entry's name. This
// assumes both entry lists are sorted by name. Confirm that DictType
// guarantees this ordering.
2212void ElaborationPass::matchTestsAgainstTargets(SymbolTable &table) {
2213 auto moduleOp = getOperation();
2215 for (
auto test :
llvm::make_early_inc_range(moduleOp.getOps<TestOp>())) {
2216 if (test.getTargetAttr())
2219 bool matched =
false;
2221 for (
auto target : moduleOp.getOps<TargetOp>()) {
// Check whether the test's entries are a subset of the target's entries.
2225 bool isSubtype =
true;
2226 auto testEntries = test.getTargetType().getEntries();
2227 auto targetEntries = target.getTarget().getEntries();
2231 size_t targetIdx = 0;
2232 for (
auto testEntry : testEntries) {
2234 while (targetIdx < targetEntries.size() &&
2235 targetEntries[targetIdx].name.getValue() <
2236 testEntry.name.getValue())
2240 if (targetIdx >= targetEntries.size() ||
2241 targetEntries[targetIdx].name != testEntry.name ||
2242 targetEntries[targetIdx].type != testEntry.type) {
// Compatible: clone the test and bind the clone to this target.
2251 IRRewriter rewriter(test);
2253 auto newTest = cast<TestOp>(test->clone());
2254 newTest.setSymName(test.getSymName().str() +
"_" +
2255 target.getSymName().str());
2259 newTest.setTargetAttr(target.getSymNameAttr());
2261 table.insert(newTest, rewriter.getInsertionPoint());
2265 if (matched || deleteUnmatchedTests)
2271 return isa<MemoryBlockType, ContextResourceTypeInterface>(type);
2274LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
2275 SymbolTable &table) {
2276 SharedState state(table, seed);
2279 state.names.add(moduleOp);
2281 struct TargetElabResult {
2282 DictType targetType;
2283 SmallVector<ElaboratorValue> yields;
2284 TestState testState;
2288 DenseMap<StringAttr, TargetElabResult> targetMap;
2289 for (
auto targetOp : moduleOp.getOps<TargetOp>()) {
2290 LLVM_DEBUG(llvm::dbgs() <<
"=== Elaborating target @"
2291 << targetOp.getSymName() <<
"\n\n");
2293 auto &result = targetMap[targetOp.getSymNameAttr()];
2294 result.targetType = targetOp.getTarget();
2296 SmallVector<ElaboratorValue> blockArgs;
2297 Materializer targetMaterializer(OpBuilder::atBlockBegin(targetOp.getBody()),
2298 result.testState, state, blockArgs);
2299 Elaborator targetElaborator(state, result.testState, targetMaterializer);
2302 if (failed(targetElaborator.elaborate(targetOp.getBodyRegion(), {},
2309 for (
auto testOp : moduleOp.getOps<TestOp>()) {
2313 if (!testOp.getTargetAttr())
2316 LLVM_DEBUG(llvm::dbgs()
2317 <<
"\n=== Elaborating test @" << testOp.getTemplateName()
2318 <<
" for target @" << *testOp.getTarget() <<
"\n\n");
2321 auto targetResult = targetMap[testOp.getTargetAttr()];
2322 TestState testState = targetResult.testState;
2323 testState.name = testOp.getSymNameAttr();
2325 SmallVector<ElaboratorValue> filteredYields;
2327 for (
auto [entry, yield] :
2328 llvm::zip(targetResult.targetType.getEntries(), targetResult.yields)) {
2329 if (i >= testOp.getTargetType().getEntries().size())
2332 if (entry.name == testOp.getTargetType().getEntries()[i].name) {
2333 filteredYields.push_back(yield);
2340 SmallVector<ElaboratorValue> blockArgs;
2341 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()),
2342 testState, state, blockArgs);
2344 for (
auto [arg, val] :
2345 llvm::zip(testOp.getBody()->getArguments(), filteredYields))
2347 materializer.map(val, arg);
2349 Elaborator elaborator(state, testState, materializer);
2350 SmallVector<ElaboratorValue> ignore;
2351 if (failed(elaborator.elaborate(testOp.getBodyRegion(), filteredYields,
2355 materializer.finalize();
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the in specified range.
static bool onlyLegalToMaterializeInTarget(Type type)
static void print(TypedAttr val, llvm::raw_ostream &os)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
This helps visit TypeOp nodes.
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
static llvm::hash_code hash_value(const ModulePort &port)
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static LabelValue getEmptyKey()
static LabelValue getTombstoneKey()
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)
static unsigned getEmptyKey()