Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ElaborationPass.cpp
Go to the documentation of this file.
1//===- ElaborationPass.cpp - RTG ElaborationPass implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass elaborates the random parts of the RTG dialect.
10// It performs randomization top-down, i.e., random constructs in a sequence
11// that is invoked multiple times can yield different randomization results
12// for each invocation.
13//
14//===----------------------------------------------------------------------===//
15
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Index/IR/IndexDialect.h"
23#include "mlir/Dialect/Index/IR/IndexOps.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/DenseMapInfoVariant.h"
28#include "llvm/Support/Debug.h"
29#include <queue>
30#include <random>
31
32namespace circt {
33namespace rtg {
34#define GEN_PASS_DEF_ELABORATIONPASS
35#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
36} // namespace rtg
37} // namespace circt
38
39using namespace mlir;
40using namespace circt;
41using namespace circt::rtg;
42using llvm::MapVector;
43
44#define DEBUG_TYPE "rtg-elaboration"
45
46//===----------------------------------------------------------------------===//
47// Uniform Distribution Helper
48//
49// Simplified version of
50// https://github.com/llvm/llvm-project/blob/main/libcxx/include/__random/uniform_int_distribution.h
51// We use our custom version here to get the same results when compiled with
52// different compiler versions and standard libraries.
53//===----------------------------------------------------------------------===//
54
/// Compute the mask used to draw `w` random bits from 32-bit engine outputs,
/// mirroring the setup of libc++'s `__independent_bits_engine` for a 32-bit
/// engine: the `w` bits are split evenly over `n = ceil(w / 32)` draws and the
/// per-draw mask (the low `w / n` bits) is returned.
///
/// For `w <= 32` this is simply a mask with the lowest `w` bits set. A width of
/// zero yields an empty mask.
static uint32_t computeMask(size_t w) {
  // Without this guard `n` would be 0 for `w == 0` and `w / n` below would
  // divide by zero.
  if (w == 0)
    return 0;

  size_t n = w / 32 + (w % 32 != 0);
  // For `w >= 1` we have `1 <= n <= w`, hence `w0` is in [1, 32] and the shift
  // amount below is in [0, 31].
  size_t w0 = w / n;
  return uint32_t(~0) >> (32 - w0);
}
60
61/// Get a number uniformly at random in the in specified range.
62static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) {
63 const uint32_t diff = b - a + 1;
64 if (diff == 1)
65 return a;
66
67 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
68 if (diff == 0)
69 return rng();
70
71 uint32_t width = digits - llvm::countl_zero(diff) - 1;
72 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
73 ++width;
74
75 uint32_t mask = computeMask(diff);
76 uint32_t u;
77 do {
78 u = rng() & mask;
79 } while (u >= diff);
80
81 return u + a;
82}
83
84//===----------------------------------------------------------------------===//
85// Elaborator Value
86//===----------------------------------------------------------------------===//
87
88namespace {
89struct ArrayStorage;
90struct BagStorage;
91struct SequenceStorage;
92struct RandomizedSequenceStorage;
93struct InterleavedSequenceStorage;
94struct SetStorage;
95struct VirtualRegisterStorage;
96struct UniqueLabelStorage;
97struct TupleStorage;
98struct MemoryStorage;
99struct MemoryBlockStorage;
100struct ValidationValue;
101
102/// Simple wrapper around a 'StringAttr' such that we know to materialize it as
103/// a label declaration instead of calling the builtin dialect constant
104/// materializer.
105struct LabelValue {
106 LabelValue(StringAttr name) : name(name) {}
107
108 bool operator==(const LabelValue &other) const { return name == other.name; }
109
110 /// The label name.
111 StringAttr name;
112};
113
114/// The abstract base class for elaborated values.
115using ElaboratorValue =
116 std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
117 RandomizedSequenceStorage *, InterleavedSequenceStorage *,
118 SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
119 LabelValue, ArrayStorage *, TupleStorage *, MemoryStorage *,
120 MemoryBlockStorage *, ValidationValue *>;
121
122// NOLINTNEXTLINE(readability-identifier-naming)
123llvm::hash_code hash_value(const LabelValue &val) {
124 return llvm::hash_value(val.name);
125}
126
127// NOLINTNEXTLINE(readability-identifier-naming)
128llvm::hash_code hash_value(const ElaboratorValue &val) {
129 return std::visit(
130 [&val](const auto &alternative) {
131 // Include index in hash to make sure same value as different
132 // alternatives don't collide.
133 return llvm::hash_combine(val.index(), alternative);
134 },
135 val);
136}
137
138} // namespace
139
140namespace llvm {
141
142template <>
143struct DenseMapInfo<bool> {
144 static inline unsigned getEmptyKey() { return false; }
145 static inline unsigned getTombstoneKey() { return true; }
146 static unsigned getHashValue(const bool &val) { return val * 37U; }
147
148 static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; }
149};
150template <>
151struct DenseMapInfo<LabelValue> {
152 static inline LabelValue getEmptyKey() {
154 }
155 static inline LabelValue getTombstoneKey() {
157 }
158 static unsigned getHashValue(const LabelValue &val) {
159 return hash_value(val);
160 }
161
162 static bool isEqual(const LabelValue &lhs, const LabelValue &rhs) {
163 return lhs == rhs;
164 }
165};
166
167} // namespace llvm
168
169//===----------------------------------------------------------------------===//
170// Elaborator Value Storages and Internalization
171//===----------------------------------------------------------------------===//
172
173namespace {
174
/// Cheap key type for the internalization sets: a precomputed hash of an
/// internalized object plus a pointer to it. Carrying the hash separately lets
/// the actual object be allocated and constructed lazily, i.e., only when no
/// equivalent object is present in the set yet.
template <typename StorageTy>
struct HashedStorage {
  HashedStorage(unsigned hash = 0, StorageTy *ptr = nullptr)
      : hashcode(hash), storage(ptr) {}

  /// Precomputed hash of the referenced storage.
  unsigned hashcode;

  /// The internalized storage; null until it has been allocated.
  StorageTy *storage;
};
187
188/// A DenseMapInfo implementation to support 'insert_as' for the internalization
189/// sets. When comparing two 'HashedStorage's we can just compare the already
190/// internalized storage pointers, otherwise we have to call the costly
191/// 'isEqual' method.
192template <typename StorageTy>
193struct StorageKeyInfo {
194 static inline HashedStorage<StorageTy> getEmptyKey() {
195 return HashedStorage<StorageTy>(0,
196 DenseMapInfo<StorageTy *>::getEmptyKey());
197 }
198 static inline HashedStorage<StorageTy> getTombstoneKey() {
199 return HashedStorage<StorageTy>(
200 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
201 }
202
203 static inline unsigned getHashValue(const HashedStorage<StorageTy> &key) {
204 return key.hashcode;
205 }
206 static inline unsigned getHashValue(const StorageTy &key) {
207 return key.hashcode;
208 }
209
210 static inline bool isEqual(const HashedStorage<StorageTy> &lhs,
211 const HashedStorage<StorageTy> &rhs) {
212 return lhs.storage == rhs.storage;
213 }
214 static inline bool isEqual(const StorageTy &lhs,
215 const HashedStorage<StorageTy> &rhs) {
216 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
217 return false;
218
219 return lhs.isEqual(rhs.storage);
220 }
221};
222
223// Values with structural equivalence intended to be internalized.
224//===----------------------------------------------------------------------===//
225
226/// Storage object for an '!rtg.set<T>'.
227struct SetStorage {
228 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
229 : hashcode(llvm::hash_combine(
230 type, llvm::hash_combine_range(set.begin(), set.end()))),
231 set(std::move(set)), type(type) {}
232
233 bool isEqual(const SetStorage *other) const {
234 return hashcode == other->hashcode && set == other->set &&
235 type == other->type;
236 }
237
238 // The cached hashcode to avoid repeated computations.
239 const unsigned hashcode;
240
241 // Stores the elaborated values contained in the set.
242 const SetVector<ElaboratorValue> set;
243
244 // Store the set type such that we can materialize this evaluated value
245 // also in the case where the set is empty.
246 const Type type;
247};
248
249/// Storage object for an '!rtg.bag<T>'.
250struct BagStorage {
251 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
252 : hashcode(llvm::hash_combine(
253 type, llvm::hash_combine_range(bag.begin(), bag.end()))),
254 bag(std::move(bag)), type(type) {}
255
256 bool isEqual(const BagStorage *other) const {
257 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
258 type == other->type;
259 }
260
261 // The cached hashcode to avoid repeated computations.
262 const unsigned hashcode;
263
264 // Stores the elaborated values contained in the bag with their number of
265 // occurences.
266 const MapVector<ElaboratorValue, uint64_t> bag;
267
268 // Store the bag type such that we can materialize this evaluated value
269 // also in the case where the bag is empty.
270 const Type type;
271};
272
273/// Storage object for an '!rtg.sequence'.
274struct SequenceStorage {
275 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
276 : hashcode(llvm::hash_combine(
277 familyName, llvm::hash_combine_range(args.begin(), args.end()))),
278 familyName(familyName), args(std::move(args)) {}
279
280 bool isEqual(const SequenceStorage *other) const {
281 return hashcode == other->hashcode && familyName == other->familyName &&
282 args == other->args;
283 }
284
285 // The cached hashcode to avoid repeated computations.
286 const unsigned hashcode;
287
288 // The name of the sequence family this sequence is derived from.
289 const StringAttr familyName;
290
291 // The elaborator values used during substitution of the sequence family.
292 const SmallVector<ElaboratorValue> args;
293};
294
295/// Storage object for interleaved '!rtg.randomized_sequence'es.
296struct InterleavedSequenceStorage {
297 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
298 uint32_t batchSize)
299 : sequences(std::move(sequences)), batchSize(batchSize),
300 hashcode(llvm::hash_combine(
301 llvm::hash_combine_range(sequences.begin(), sequences.end()),
302 batchSize)) {}
303
304 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
305 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
306 hashcode(llvm::hash_combine(
307 llvm::hash_combine_range(sequences.begin(), sequences.end()),
308 batchSize)) {}
309
310 bool isEqual(const InterleavedSequenceStorage *other) const {
311 return hashcode == other->hashcode && sequences == other->sequences &&
312 batchSize == other->batchSize;
313 }
314
315 const SmallVector<ElaboratorValue> sequences;
316
317 const uint32_t batchSize;
318
319 // The cached hashcode to avoid repeated computations.
320 const unsigned hashcode;
321};
322
323/// Storage object for '!rtg.array`-typed values.
324struct ArrayStorage {
325 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
326 : hashcode(llvm::hash_combine(
327 type, llvm::hash_combine_range(array.begin(), array.end()))),
328 type(type), array(array) {}
329
330 bool isEqual(const ArrayStorage *other) const {
331 return hashcode == other->hashcode && type == other->type &&
332 array == other->array;
333 }
334
335 // The cached hashcode to avoid repeated computations.
336 const unsigned hashcode;
337
338 /// The type of the array. This is necessary because an array of size 0
339 /// cannot be reconstructed without knowing the original element type.
340 const Type type;
341
342 /// The label name. For unique labels, this is just the prefix.
343 const SmallVector<ElaboratorValue> array;
344};
345
346/// Storage object for 'tuple`-typed values.
347struct TupleStorage {
348 TupleStorage(SmallVector<ElaboratorValue> &&values)
349 : hashcode(llvm::hash_combine_range(values.begin(), values.end())),
350 values(std::move(values)) {}
351
352 bool isEqual(const TupleStorage *other) const {
353 return hashcode == other->hashcode && values == other->values;
354 }
355
356 // The cached hashcode to avoid repeated computations.
357 const unsigned hashcode;
358
359 const SmallVector<ElaboratorValue> values;
360};
361
362// Values with identity not intended to be internalized.
363//===----------------------------------------------------------------------===//
364
365/// Base class for storages that represent values with identity, i.e., two
366/// values are not considered equivalent if they are structurally the same, but
367/// each definition of such a value is unique. E.g., unique labels or virtual
368/// registers. These cannot be materialized anew in each nested sequence, but
369/// must be passed as arguments.
370struct IdentityValue {
371
372 IdentityValue(Type type) : type(type) {}
373
374#ifndef NDEBUG
375
376 /// In debug mode, track whether this value was already materialized to
377 /// assert if it's illegally materialized multiple times.
378 ///
379 /// Instead of deleting operations defining these values and materializing
380 /// them again, we could retain the operations. However, we still need
381 /// specific storages to represent these values in some cases, e.g., to get
382 /// the size of a memory allocation. Also, elaboration of nested control-flow
383 /// regions (e.g. `scf.for`) relies on materialization of such values lazily
384 /// instead of cloning the operations eagerly.
385 bool alreadyMaterialized = false;
386
387#endif
388
389 const Type type;
390};
391
392/// Represents a unique virtual register.
393struct VirtualRegisterStorage : IdentityValue {
394 VirtualRegisterStorage(ArrayAttr allowedRegs, Type type)
395 : IdentityValue(type), allowedRegs(allowedRegs) {}
396
397 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
398 // VirtualRegisters are never internalized.
399
400 // The list of fixed registers allowed to be selected for this virtual
401 // register.
402 const ArrayAttr allowedRegs;
403};
404
405struct UniqueLabelStorage : IdentityValue {
406 UniqueLabelStorage(StringAttr name)
407 : IdentityValue(LabelType::get(name.getContext())), name(name) {}
408
409 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
410 // VirtualRegisters are never internalized.
411
412 /// The label name. For unique labels, this is just the prefix.
413 const StringAttr name;
414};
415
416/// Storage object for '!rtg.isa.memoryblock`-typed values.
417struct MemoryBlockStorage : IdentityValue {
418 MemoryBlockStorage(const APInt &baseAddress, const APInt &endAddress,
419 Type type)
420 : IdentityValue(type), baseAddress(baseAddress), endAddress(endAddress) {}
421
422 // The base address of the memory. The width of the APInt also represents the
423 // address width of the memory. This is an APInt to support memories of
424 // >64-bit machines.
425 const APInt baseAddress;
426
427 // The last address of the memory.
428 const APInt endAddress;
429};
430
431/// Storage object for '!rtg.isa.memory`-typed values.
432struct MemoryStorage : IdentityValue {
433 MemoryStorage(MemoryBlockStorage *memoryBlock, size_t size, size_t alignment)
434 : IdentityValue(MemoryType::get(memoryBlock->type.getContext(),
435 memoryBlock->baseAddress.getBitWidth())),
436 memoryBlock(memoryBlock), size(size), alignment(alignment) {}
437
438 MemoryBlockStorage *memoryBlock;
439 const size_t size;
440 const size_t alignment;
441};
442
443/// Storage object for an '!rtg.randomized_sequence'.
444struct RandomizedSequenceStorage : IdentityValue {
445 RandomizedSequenceStorage(ContextResourceAttrInterface context,
446 SequenceStorage *sequence)
447 : IdentityValue(
448 RandomizedSequenceType::get(sequence->familyName.getContext())),
449 context(context), sequence(sequence) {}
450
451 // The context under which this sequence is placed.
452 const ContextResourceAttrInterface context;
453
454 const SequenceStorage *sequence;
455};
456
457/// Storage object for an '!rtg.validate' result.
458struct ValidationValue : IdentityValue {
459 ValidationValue(Type type, const ElaboratorValue &ref,
460 const ElaboratorValue &defaultValue, StringAttr id)
461 : IdentityValue(type), ref(ref), defaultValue(defaultValue), id(id) {}
462
463 const ElaboratorValue ref;
464 const ElaboratorValue defaultValue;
465 const StringAttr id;
466};
467
468/// An 'Internalizer' object internalizes storages and takes ownership of them.
469/// When the initializer object is destroyed, all owned storages are also
470/// deallocated and thus must not be accessed anymore.
471class Internalizer {
472public:
473 /// Internalize a storage of type `StorageTy` constructed with arguments
474 /// `args`. The pointers returned by this method can be used to compare
475 /// objects when, e.g., computing set differences, uniquing the elements in a
476 /// set, etc. Otherwise, we'd need to do a deep value comparison in those
477 /// situations.
478 template <typename StorageTy, typename... Args>
479 StorageTy *internalize(Args &&...args) {
480 static_assert(!std::is_base_of_v<IdentityValue, StorageTy> &&
481 "values with identity must not be internalized");
482
483 StorageTy storage(std::forward<Args>(args)...);
484
485 auto existing = getInternSet<StorageTy>().insert_as(
486 HashedStorage<StorageTy>(storage.hashcode), storage);
487 StorageTy *&storagePtr = existing.first->storage;
488 if (existing.second)
489 storagePtr =
490 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
491
492 return storagePtr;
493 }
494
495 template <typename StorageTy, typename... Args>
496 StorageTy *create(Args &&...args) {
497 static_assert(std::is_base_of_v<IdentityValue, StorageTy> &&
498 "values with structural equivalence must be internalized");
499
500 return new (allocator.Allocate<StorageTy>())
501 StorageTy(std::forward<Args>(args)...);
502 }
503
504private:
505 template <typename StorageTy>
506 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
507 getInternSet() {
508 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
509 return internedArrays;
510 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
511 return internedSets;
512 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
513 return internedBags;
514 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
515 return internedSequences;
516 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
517 return internedRandomizedSequences;
518 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
519 return internedInterleavedSequences;
520 else if constexpr (std::is_same_v<StorageTy, TupleStorage>)
521 return internedTuples;
522 else
523 static_assert(!sizeof(StorageTy),
524 "no intern set available for this storage type.");
525 }
526
527 // This allocator allocates on the heap. It automatically deallocates all
528 // objects it allocated once the allocator itself is destroyed.
529 llvm::BumpPtrAllocator allocator;
530
531 // The sets holding the internalized objects. We use one set per storage type
532 // such that we can have a simpler equality checking function (no need to
533 // compare some sort of TypeIDs).
534 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
535 internedArrays;
536 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
537 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
538 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
539 internedSequences;
540 DenseSet<HashedStorage<RandomizedSequenceStorage>,
541 StorageKeyInfo<RandomizedSequenceStorage>>
542 internedRandomizedSequences;
543 DenseSet<HashedStorage<InterleavedSequenceStorage>,
544 StorageKeyInfo<InterleavedSequenceStorage>>
545 internedInterleavedSequences;
546 DenseSet<HashedStorage<TupleStorage>, StorageKeyInfo<TupleStorage>>
547 internedTuples;
548};
549
550} // namespace
551
552#ifndef NDEBUG
553
554static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
555 const ElaboratorValue &value);
556
557static void print(TypedAttr val, llvm::raw_ostream &os) {
558 os << "<attr " << val << ">";
559}
560
561static void print(BagStorage *val, llvm::raw_ostream &os) {
562 os << "<bag {";
563 llvm::interleaveComma(val->bag, os,
564 [&](const std::pair<ElaboratorValue, uint64_t> &el) {
565 os << el.first << " -> " << el.second;
566 });
567 os << "} at " << val << ">";
568}
569
570static void print(bool val, llvm::raw_ostream &os) {
571 os << "<bool " << (val ? "true" : "false") << ">";
572}
573
574static void print(size_t val, llvm::raw_ostream &os) {
575 os << "<index " << val << ">";
576}
577
578static void print(SequenceStorage *val, llvm::raw_ostream &os) {
579 os << "<sequence @" << val->familyName.getValue() << "(";
580 llvm::interleaveComma(val->args, os,
581 [&](const ElaboratorValue &val) { os << val; });
582 os << ") at " << val << ">";
583}
584
585static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
586 os << "<randomized-sequence derived from @"
587 << val->sequence->familyName.getValue() << " under context "
588 << val->context << "(";
589 llvm::interleaveComma(val->sequence->args, os,
590 [&](const ElaboratorValue &val) { os << val; });
591 os << ") at " << val << ">";
592}
593
594static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
595 os << "<interleaved-sequence [";
596 llvm::interleaveComma(val->sequences, os,
597 [&](const ElaboratorValue &val) { os << val; });
598 os << "] batch-size " << val->batchSize << " at " << val << ">";
599}
600
601static void print(ArrayStorage *val, llvm::raw_ostream &os) {
602 os << "<array [";
603 llvm::interleaveComma(val->array, os,
604 [&](const ElaboratorValue &val) { os << val; });
605 os << "] at " << val << ">";
606}
607
608static void print(SetStorage *val, llvm::raw_ostream &os) {
609 os << "<set {";
610 llvm::interleaveComma(val->set, os,
611 [&](const ElaboratorValue &val) { os << val; });
612 os << "} at " << val << ">";
613}
614
615static void print(const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
616 os << "<virtual-register " << val << " " << val->allowedRegs << ">";
617}
618
619static void print(const UniqueLabelStorage *val, llvm::raw_ostream &os) {
620 os << "<unique-label " << val << " " << val->name << ">";
621}
622
623static void print(const LabelValue &val, llvm::raw_ostream &os) {
624 os << "<label " << val.name << ">";
625}
626
627static void print(const TupleStorage *val, llvm::raw_ostream &os) {
628 os << "<tuple (";
629 llvm::interleaveComma(val->values, os,
630 [&](const ElaboratorValue &val) { os << val; });
631 os << ")>";
632}
633
634static void print(const MemoryStorage *val, llvm::raw_ostream &os) {
635 os << "<memory {" << ElaboratorValue(val->memoryBlock)
636 << ", size=" << val->size << ", alignment=" << val->alignment << "}>";
637}
638
639static void print(const MemoryBlockStorage *val, llvm::raw_ostream &os) {
640 os << "<memory-block {"
641 << ", address-width=" << val->baseAddress.getBitWidth()
642 << ", base-address=" << val->baseAddress
643 << ", end-address=" << val->endAddress << "}>";
644}
645
646static void print(const ValidationValue *val, llvm::raw_ostream &os) {
647 os << "<validation-value {type=" << val->type << ", ref=" << val->ref
648 << ", defaultValue=" << val->defaultValue << "}>";
649}
650
651static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
652 const ElaboratorValue &value) {
653 std::visit([&](auto val) { print(val, os); }, value);
654
655 return os;
656}
657
658#endif
659
660//===----------------------------------------------------------------------===//
661// Elaborator Value Materialization
662//===----------------------------------------------------------------------===//
663
664namespace {
665
666/// State that should be shared by all elaborator and materializer instances.
667struct SharedState {
668 SharedState(SymbolTable &table, unsigned seed) : table(table), rng(seed) {}
669
670 SymbolTable &table;
671 std::mt19937 rng;
672 Namespace names;
673 Internalizer internalizer;
674};
675
676/// A collection of state per RTG test.
677struct TestState {
678 /// The name of the test.
679 StringAttr name;
680
681 /// The context switches registered for this test.
682 MapVector<
683 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
684 SequenceStorage *>
685 contextSwitches;
686};
687
688/// Construct an SSA value from a given elaborated value.
689class Materializer {
690public:
691 Materializer(OpBuilder builder, TestState &testState,
692 SharedState &sharedState,
693 SmallVector<ElaboratorValue> &blockArgs)
694 : builder(builder), testState(testState), sharedState(sharedState),
695 blockArgs(blockArgs) {}
696
697 /// Materialize IR representing the provided `ElaboratorValue` and return the
698 /// `Value` or a null value on failure.
699 Value materialize(ElaboratorValue val, Location loc,
700 function_ref<InFlightDiagnostic()> emitError) {
701 auto iter = materializedValues.find(val);
702 if (iter != materializedValues.end())
703 return iter->second;
704
705 LLVM_DEBUG(llvm::dbgs() << "Materializing " << val);
706
707 // In debug mode, track whether values with identity were already
708 // materialized before and assert in such a situation.
709 Value res = std::visit(
710 [&](auto value) {
711 if constexpr (std::is_base_of_v<IdentityValue,
712 std::remove_pointer_t<
713 std::decay_t<decltype(value)>>>) {
714 if (identityValueRoot.contains(value)) {
715#ifndef NDEBUG
716 bool &materialized =
717 static_cast<IdentityValue *>(value)->alreadyMaterialized;
718 assert(!materialized && "must not already be materialized");
719 materialized = true;
720#endif
721
722 return visit(value, loc, emitError);
723 }
724
725 Value arg = builder.getBlock()->addArgument(value->type, loc);
726 blockArgs.push_back(val);
727 blockArgTypes.push_back(arg.getType());
728 materializedValues[val] = arg;
729 return arg;
730 }
731
732 return visit(value, loc, emitError);
733 },
734 val);
735
736 LLVM_DEBUG(llvm::dbgs() << " to\n" << res << "\n\n");
737
738 return res;
739 }
740
741 /// If `op` is not in the same region as the materializer insertion point, a
742 /// clone is created at the materializer's insertion point by also
743 /// materializing the `ElaboratorValue`s for each operand just before it.
744 /// Otherwise, all operations after the materializer's insertion point are
745 /// deleted until `op` is reached. An error is returned if the operation is
746 /// before the insertion point.
747 LogicalResult materialize(Operation *op,
748 DenseMap<Value, ElaboratorValue> &state) {
749 if (op->getNumRegions() > 0)
750 return op->emitOpError("ops with nested regions must be elaborated away");
751
752 // We don't support opaque values. If there is an SSA value that has a
753 // use-site it needs an equivalent ElaborationValue representation.
754 // NOTE: We could support cases where there is initially a use-site but that
755 // op is guaranteed to be deleted during elaboration. Or the use-sites are
756 // replaced with freshly materialized values from the ElaborationValue. But
757 // then, why can't we delete the value defining op?
758 for (auto res : op->getResults())
759 if (!res.use_empty() && !isa<ValidateOp>(op))
760 return op->emitOpError(
761 "ops with results that have uses are not supported");
762
763 if (op->getParentRegion() == builder.getBlock()->getParent()) {
764 // We are doing in-place materialization, so mark all ops deleted until we
765 // reach the one to be materialized and modify it in-place.
766 deleteOpsUntil([&](auto iter) { return &*iter == op; });
767
768 if (builder.getInsertionPoint() == builder.getBlock()->end())
769 return op->emitError("operation did not occur after the current "
770 "materializer insertion point");
771
772 LLVM_DEBUG(llvm::dbgs() << "Modifying in-place: " << *op << "\n\n");
773 } else {
774 LLVM_DEBUG(llvm::dbgs() << "Materializing a clone of " << *op << "\n\n");
775 op = builder.clone(*op);
776 builder.setInsertionPoint(op);
777 }
778
779 for (auto &operand : op->getOpOperands()) {
780 auto emitError = [&]() {
781 auto diag = op->emitError();
782 diag.attachNote(op->getLoc())
783 << "while materializing value for operand#"
784 << operand.getOperandNumber();
785 return diag;
786 };
787
788 auto elabVal = state.at(operand.get());
789 Value val = materialize(elabVal, op->getLoc(), emitError);
790 if (!val)
791 return failure();
792
793 state[val] = elabVal;
794 operand.set(val);
795 }
796
797 builder.setInsertionPointAfter(op);
798 return success();
799 }
800
801 /// Should be called once the `Region` is successfully materialized. No calls
802 /// to `materialize` should happen after this anymore.
803 void finalize() {
804 deleteOpsUntil([](auto iter) { return false; });
805
806 for (auto *op : llvm::reverse(toDelete))
807 op->erase();
808 }
809
810 /// Tell this materializer that it is responsible for materializing the given
811 /// identity value at the earliest position it is needed, and should't
812 /// request the value via block argument.
813 void registerIdentityValue(IdentityValue *val) {
814 identityValueRoot.insert(val);
815 }
816
817 ArrayRef<Type> getBlockArgTypes() const { return blockArgTypes; }
818
819 void map(ElaboratorValue eval, Value val) { materializedValues[eval] = val; }
820
821 template <typename OpTy, typename... Args>
822 OpTy create(Location location, Args &&...args) {
823 return builder.create<OpTy>(location, std::forward<Args>(args)...);
824 }
825
826private:
827 SequenceOp elaborateSequence(const RandomizedSequenceStorage *seq,
828 SmallVector<ElaboratorValue> &elabArgs);
829
830 void deleteOpsUntil(function_ref<bool(Block::iterator)> stop) {
831 auto ip = builder.getInsertionPoint();
832 while (ip != builder.getBlock()->end() && !stop(ip)) {
833 LLVM_DEBUG(llvm::dbgs() << "Marking to be deleted: " << *ip << "\n\n");
834 toDelete.push_back(&*ip);
835
836 builder.setInsertionPointAfter(&*ip);
837 ip = builder.getInsertionPoint();
838 }
839 }
840
841 Value visit(TypedAttr val, Location loc,
842 function_ref<InFlightDiagnostic()> emitError) {
843 // For index attributes (and arithmetic operations on them) we use the
844 // index dialect.
845 if (auto intAttr = dyn_cast<IntegerAttr>(val);
846 intAttr && isa<IndexType>(val.getType())) {
847 Value res = builder.create<index::ConstantOp>(loc, intAttr);
848 materializedValues[val] = res;
849 return res;
850 }
851
852 // For any other attribute, we just call the materializer of the dialect
853 // defining that attribute.
854 auto *op =
855 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
856 if (!op) {
857 emitError() << "materializer of dialect '"
858 << val.getDialect().getNamespace()
859 << "' unable to materialize value for attribute '" << val
860 << "'";
861 return Value();
862 }
863
864 Value res = op->getResult(0);
865 materializedValues[val] = res;
866 return res;
867 }
868
869 Value visit(size_t val, Location loc,
870 function_ref<InFlightDiagnostic()> emitError) {
871 Value res = builder.create<index::ConstantOp>(loc, val);
872 materializedValues[val] = res;
873 return res;
874 }
875
876 Value visit(bool val, Location loc,
877 function_ref<InFlightDiagnostic()> emitError) {
878 Value res = builder.create<index::BoolConstantOp>(loc, val);
879 materializedValues[val] = res;
880 return res;
881 }
882
883 Value visit(ArrayStorage *val, Location loc,
884 function_ref<InFlightDiagnostic()> emitError) {
885 SmallVector<Value> elements;
886 elements.reserve(val->array.size());
887 for (auto el : val->array) {
888 auto materialized = materialize(el, loc, emitError);
889 if (!materialized)
890 return Value();
891
892 elements.push_back(materialized);
893 }
894
895 Value res = builder.create<ArrayCreateOp>(loc, val->type, elements);
896 materializedValues[val] = res;
897 return res;
898 }
899
900 Value visit(SetStorage *val, Location loc,
901 function_ref<InFlightDiagnostic()> emitError) {
902 SmallVector<Value> elements;
903 elements.reserve(val->set.size());
904 for (auto el : val->set) {
905 auto materialized = materialize(el, loc, emitError);
906 if (!materialized)
907 return Value();
908
909 elements.push_back(materialized);
910 }
911
912 auto res = builder.create<SetCreateOp>(loc, val->type, elements);
913 materializedValues[val] = res;
914 return res;
915 }
916
917 Value visit(BagStorage *val, Location loc,
918 function_ref<InFlightDiagnostic()> emitError) {
919 SmallVector<Value> values, weights;
920 values.reserve(val->bag.size());
921 weights.reserve(val->bag.size());
922 for (auto [val, weight] : val->bag) {
923 auto materializedVal = materialize(val, loc, emitError);
924 auto materializedWeight = materialize(weight, loc, emitError);
925 if (!materializedVal || !materializedWeight)
926 return Value();
927
928 values.push_back(materializedVal);
929 weights.push_back(materializedWeight);
930 }
931
932 auto res = builder.create<BagCreateOp>(loc, val->type, values, weights);
933 materializedValues[val] = res;
934 return res;
935 }
936
937 Value visit(MemoryBlockStorage *val, Location loc,
938 function_ref<InFlightDiagnostic()> emitError) {
939 auto intType = builder.getIntegerType(val->baseAddress.getBitWidth());
940 Value res = builder.create<MemoryBlockDeclareOp>(
941 loc, val->type, IntegerAttr::get(intType, val->baseAddress),
942 IntegerAttr::get(intType, val->endAddress));
943 materializedValues[val] = res;
944 return res;
945 }
946
947 Value visit(MemoryStorage *val, Location loc,
948 function_ref<InFlightDiagnostic()> emitError) {
949 auto memBlock = materialize(val->memoryBlock, loc, emitError);
950 auto memSize = materialize(val->size, loc, emitError);
951 auto memAlign = materialize(val->alignment, loc, emitError);
952 if (!(memBlock && memSize && memAlign))
953 return {};
954
955 Value res = builder.create<MemoryAllocOp>(loc, memBlock, memSize, memAlign);
956 materializedValues[val] = res;
957 return res;
958 }
959
960 Value visit(SequenceStorage *val, Location loc,
961 function_ref<InFlightDiagnostic()> emitError) {
962 emitError() << "materializing a non-randomized sequence not supported yet";
963 return Value();
964 }
965
966 Value visit(RandomizedSequenceStorage *val, Location loc,
967 function_ref<InFlightDiagnostic()> emitError) {
968 // To know which values we have to pass by argument (and not just pass all
969 // that migth be used eagerly), we have to elaborate the sequence family if
970 // not already done so.
971 // We need to get back the sequence to reference, and the list of elaborated
972 // values to pass as arguments.
973 SmallVector<ElaboratorValue> elabArgs;
974 SequenceOp seqOp = elaborateSequence(val, elabArgs);
975 if (!seqOp)
976 return {};
977
978 // Materialize all the values we need to pass as arguments and collect their
979 // types.
980 SmallVector<Value> args;
981 SmallVector<Type> argTypes;
982 for (auto arg : elabArgs) {
983 Value materialized = materialize(arg, loc, emitError);
984 if (!materialized)
985 return {};
986
987 args.push_back(materialized);
988 argTypes.push_back(materialized.getType());
989 }
990
991 Value res = builder.create<GetSequenceOp>(
992 loc, SequenceType::get(builder.getContext(), argTypes),
993 seqOp.getSymName());
994
995 // Only materialize a substitute_sequence op when we have arguments to
996 // substitute since this op does not support 0 arguments.
997 if (!args.empty())
998 res = builder.create<SubstituteSequenceOp>(loc, res, args);
999
1000 res = builder.create<RandomizeSequenceOp>(loc, res);
1001
1002 materializedValues[val] = res;
1003 return res;
1004 }
1005
1006 Value visit(InterleavedSequenceStorage *val, Location loc,
1007 function_ref<InFlightDiagnostic()> emitError) {
1008 SmallVector<Value> sequences;
1009 for (auto seqVal : val->sequences) {
1010 Value materialized = materialize(seqVal, loc, emitError);
1011 if (!materialized)
1012 return {};
1013
1014 sequences.push_back(materialized);
1015 }
1016
1017 if (sequences.size() == 1)
1018 return sequences[0];
1019
1020 Value res =
1021 builder.create<InterleaveSequencesOp>(loc, sequences, val->batchSize);
1022 materializedValues[val] = res;
1023 return res;
1024 }
1025
1026 Value visit(VirtualRegisterStorage *val, Location loc,
1027 function_ref<InFlightDiagnostic()> emitError) {
1028 Value res = builder.create<VirtualRegisterOp>(loc, val->allowedRegs);
1029 materializedValues[val] = res;
1030 return res;
1031 }
1032
1033 Value visit(UniqueLabelStorage *val, Location loc,
1034 function_ref<InFlightDiagnostic()> emitError) {
1035 Value res = builder.create<LabelUniqueDeclOp>(loc, val->name, ValueRange());
1036 materializedValues[val] = res;
1037 return res;
1038 }
1039
1040 Value visit(const LabelValue &val, Location loc,
1041 function_ref<InFlightDiagnostic()> emitError) {
1042 Value res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
1043 materializedValues[val] = res;
1044 return res;
1045 }
1046
1047 Value visit(TupleStorage *val, Location loc,
1048 function_ref<InFlightDiagnostic()> emitError) {
1049 SmallVector<Value> materialized;
1050 materialized.reserve(val->values.size());
1051 for (auto v : val->values)
1052 materialized.push_back(materialize(v, loc, emitError));
1053 Value res = builder.create<TupleCreateOp>(loc, materialized);
1054 materializedValues[val] = res;
1055 return res;
1056 }
1057
1058 Value visit(ValidationValue *val, Location loc,
1059 function_ref<InFlightDiagnostic()> emitError) {
1060 Value res = builder.create<ValidateOp>(
1061 loc, val->type, materialize(val->ref, loc, emitError),
1062 materialize(val->defaultValue, loc, emitError), val->id);
1063 materializedValues[val] = res;
1064 return res;
1065 }
1066
1067private:
1068 /// Cache values we have already materialized to reuse them later. We start
1069 /// with an insertion point at the start of the block and cache the (updated)
1070 /// insertion point such that future materializations can also reuse previous
1071 /// materializations without running into dominance issues (or requiring
1072 /// additional checks to avoid them).
1073 DenseMap<ElaboratorValue, Value> materializedValues;
1074
1075 /// Cache the builder to continue insertions at their current insertion point
1076 /// for the reason stated above.
1077 OpBuilder builder;
1078
1079 SmallVector<Operation *> toDelete;
1080
1081 TestState &testState;
1082 SharedState &sharedState;
1083
1084 /// Keep track of the block arguments we had to add to this materializer's
1085 /// block for identity values and also remember which elaborator values are
1086 /// expected to be passed as arguments from outside.
1087 SmallVector<ElaboratorValue> &blockArgs;
1088 SmallVector<Type> blockArgTypes;
1089
1090 /// Identity values in this set are materialized by this materializer,
1091 /// otherwise they are added as block arguments and the block that wants to
1092 /// embed this sequence is expected to provide a value for it.
1093 DenseSet<IdentityValue *> identityValueRoot;
1094};
1095
1096//===----------------------------------------------------------------------===//
1097// Elaboration Visitor
1098//===----------------------------------------------------------------------===//
1099
/// Returned by the elaboration visitors to tell the elaboration driver what to
/// do with the visited operation.
enum class DeletionKind {
  /// The operation stays in the IR.
  Keep,
  /// The operation should be removed from the IR.
  Delete,
};
1105class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
1106public:
1108 using RTGBase::visitOp;
1109
1110 Elaborator(SharedState &sharedState, TestState &testState,
1111 Materializer &materializer,
1112 ContextResourceAttrInterface currentContext = {})
1113 : sharedState(sharedState), testState(testState),
1114 materializer(materializer), currentContext(currentContext) {}
1115
1116 template <typename ValueTy>
1117 inline ValueTy get(Value val) const {
1118 return std::get<ValueTy>(state.at(val));
1119 }
1120
1121 FailureOr<DeletionKind> visitPureOp(Operation *op) {
1122 SmallVector<Attribute> operands;
1123 for (auto operand : op->getOperands()) {
1124 auto evalValue = state[operand];
1125 if (std::holds_alternative<TypedAttr>(evalValue))
1126 operands.push_back(std::get<TypedAttr>(evalValue));
1127 else
1128 return visitUnhandledOp(op);
1129 }
1130
1131 SmallVector<OpFoldResult> results;
1132 if (failed(op->fold(operands, results)))
1133 return visitUnhandledOp(op);
1134
1135 // We don't support in-place folders.
1136 if (results.size() != op->getNumResults())
1137 return visitUnhandledOp(op);
1138
1139 for (auto [res, val] : llvm::zip(results, op->getResults())) {
1140 auto attr = llvm::dyn_cast_or_null<TypedAttr>(res.dyn_cast<Attribute>());
1141 if (!attr)
1142 return op->emitError(
1143 "only typed attributes supported for constant-like operations");
1144
1145 auto intAttr = dyn_cast<IntegerAttr>(attr);
1146 if (intAttr && isa<IndexType>(attr.getType()))
1147 state[op->getResult(0)] = size_t(intAttr.getInt());
1148 else if (intAttr && intAttr.getType().isSignlessInteger(1))
1149 state[op->getResult(0)] = bool(intAttr.getInt());
1150 else
1151 state[op->getResult(0)] = attr;
1152 }
1153
1154 return DeletionKind::Delete;
1155 }
1156
1157 /// Print a nice error message for operations we don't support yet.
1158 FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
1159 return op->emitOpError("elaboration not supported");
1160 }
1161
1162 FailureOr<DeletionKind> visitExternalOp(Operation *op) {
1163 auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
1164 if (op->hasTrait<OpTrait::ConstantLike>() || (memOp && memOp.hasNoEffect()))
1165 return visitPureOp(op);
1166
1167 // TODO: we only have this to be able to write tests for this pass without
1168 // having to add support for more operations for now, so it should be
1169 // removed once it is not necessary anymore for writing tests
1170 if (op->use_empty())
1171 return DeletionKind::Keep;
1172
1173 return visitUnhandledOp(op);
1174 }
1175
1176 FailureOr<DeletionKind> visitOp(ConstantOp op) { return visitPureOp(op); }
1177
1178 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
1179 SmallVector<ElaboratorValue> replacements;
1180 state[op.getResult()] =
1181 sharedState.internalizer.internalize<SequenceStorage>(
1182 op.getSequenceAttr(), std::move(replacements));
1183 return DeletionKind::Delete;
1184 }
1185
1186 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
1187 auto *seq = get<SequenceStorage *>(op.getSequence());
1188
1189 SmallVector<ElaboratorValue> replacements(seq->args);
1190 for (auto replacement : op.getReplacements())
1191 replacements.push_back(state.at(replacement));
1192
1193 state[op.getResult()] =
1194 sharedState.internalizer.internalize<SequenceStorage>(
1195 seq->familyName, std::move(replacements));
1196
1197 return DeletionKind::Delete;
1198 }
1199
1200 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
1201 auto *seq = get<SequenceStorage *>(op.getSequence());
1202 auto *randomizedSeq =
1203 sharedState.internalizer.create<RandomizedSequenceStorage>(
1204 currentContext, seq);
1205 materializer.registerIdentityValue(randomizedSeq);
1206 state[op.getResult()] =
1207 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1208 randomizedSeq);
1209 return DeletionKind::Delete;
1210 }
1211
1212 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
1213 SmallVector<ElaboratorValue> sequences;
1214 for (auto seq : op.getSequences())
1215 sequences.push_back(get<InterleavedSequenceStorage *>(seq));
1216
1217 state[op.getResult()] =
1218 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1219 std::move(sequences), op.getBatchSize());
1220 return DeletionKind::Delete;
1221 }
1222
1223 // NOLINTNEXTLINE(misc-no-recursion)
1224 LogicalResult isValidContext(ElaboratorValue value, Operation *op) const {
1225 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
1226 auto *seq = std::get<RandomizedSequenceStorage *>(value);
1227 if (seq->context != currentContext) {
1228 auto err = op->emitError("attempting to place sequence derived from ")
1229 << seq->sequence->familyName.getValue() << " under context "
1230 << currentContext
1231 << ", but it was previously randomized for context ";
1232 if (seq->context)
1233 err << seq->context;
1234 else
1235 err << "'default'";
1236 return err;
1237 }
1238 return success();
1239 }
1240
1241 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
1242 for (auto val : interVal->sequences)
1243 if (failed(isValidContext(val, op)))
1244 return failure();
1245 return success();
1246 }
1247
1248 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
1249 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
1250 if (failed(isValidContext(seqVal, op)))
1251 return failure();
1252
1253 return DeletionKind::Keep;
1254 }
1255
1256 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1257 SetVector<ElaboratorValue> set;
1258 for (auto val : op.getElements())
1259 set.insert(state.at(val));
1260
1261 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1262 std::move(set), op.getSet().getType());
1263 return DeletionKind::Delete;
1264 }
1265
1266 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
1267 auto set = get<SetStorage *>(op.getSet())->set;
1268
1269 if (set.empty())
1270 return op->emitError("cannot select from an empty set");
1271
1272 size_t selected;
1273 if (auto intAttr =
1274 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1275 std::mt19937 customRng(intAttr.getInt());
1276 selected = getUniformlyInRange(customRng, 0, set.size() - 1);
1277 } else {
1278 selected = getUniformlyInRange(sharedState.rng, 0, set.size() - 1);
1279 }
1280
1281 state[op.getResult()] = set[selected];
1282 return DeletionKind::Delete;
1283 }
1284
1285 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
1286 auto original = get<SetStorage *>(op.getOriginal())->set;
1287 auto diff = get<SetStorage *>(op.getDiff())->set;
1288
1289 SetVector<ElaboratorValue> result(original);
1290 result.set_subtract(diff);
1291
1292 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1293 std::move(result), op.getResult().getType());
1294 return DeletionKind::Delete;
1295 }
1296
1297 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
1298 SetVector<ElaboratorValue> result;
1299 for (auto set : op.getSets())
1300 result.set_union(get<SetStorage *>(set)->set);
1301
1302 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1303 std::move(result), op.getType());
1304 return DeletionKind::Delete;
1305 }
1306
1307 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1308 auto size = get<SetStorage *>(op.getSet())->set.size();
1309 state[op.getResult()] = size;
1310 return DeletionKind::Delete;
1311 }
1312
1313 // {a0,a1} x {b0,b1} x {c0,c1} -> {(a0), (a1)} -> {(a0,b0), (a0,b1), (a1,b0),
1314 // (a1,b1)} -> {(a0,b0,c0), (a0,b0,c1), (a0,b1,c0), (a0,b1,c1), (a1,b0,c0),
1315 // (a1,b0,c1), (a1,b1,c0), (a1,b1,c1)}
1316 FailureOr<DeletionKind> visitOp(SetCartesianProductOp op) {
1317 SetVector<ElaboratorValue> result;
1318 SmallVector<SmallVector<ElaboratorValue>> tuples;
1319 tuples.push_back({});
1320
1321 for (auto input : op.getInputs()) {
1322 auto &set = get<SetStorage *>(input)->set;
1323 if (set.empty()) {
1324 SetVector<ElaboratorValue> empty;
1325 state[op.getResult()] =
1326 sharedState.internalizer.internalize<SetStorage>(std::move(empty),
1327 op.getType());
1328 return DeletionKind::Delete;
1329 }
1330
1331 for (unsigned i = 0, e = tuples.size(); i < e; ++i) {
1332 for (auto setEl : set.getArrayRef().drop_back()) {
1333 tuples.push_back(tuples[i]);
1334 tuples.back().push_back(setEl);
1335 }
1336 tuples[i].push_back(set.back());
1337 }
1338 }
1339
1340 for (auto &tup : tuples)
1341 result.insert(
1342 sharedState.internalizer.internalize<TupleStorage>(std::move(tup)));
1343
1344 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1345 std::move(result), op.getType());
1346 return DeletionKind::Delete;
1347 }
1348
1349 FailureOr<DeletionKind> visitOp(SetConvertToBagOp op) {
1350 auto set = get<SetStorage *>(op.getInput())->set;
1351 MapVector<ElaboratorValue, uint64_t> bag;
1352 for (auto val : set)
1353 bag.insert({val, 1});
1354 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1355 std::move(bag), op.getType());
1356 return DeletionKind::Delete;
1357 }
1358
1359 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1360 MapVector<ElaboratorValue, uint64_t> bag;
1361 for (auto [val, multiple] :
1362 llvm::zip(op.getElements(), op.getMultiples())) {
1363 // If the multiple is not stored as an AttributeValue, the elaboration
1364 // must have already failed earlier (since we don't have
1365 // unevaluated/opaque values).
1366 bag[state.at(val)] += get<size_t>(multiple);
1367 }
1368
1369 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1370 std::move(bag), op.getType());
1371 return DeletionKind::Delete;
1372 }
1373
1374 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
1375 auto bag = get<BagStorage *>(op.getBag())->bag;
1376
1377 if (bag.empty())
1378 return op->emitError("cannot select from an empty bag");
1379
1380 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1381 prefixSum.reserve(bag.size());
1382 uint32_t accumulator = 0;
1383 for (auto [val, weight] : bag) {
1384 accumulator += weight;
1385 prefixSum.push_back({val, accumulator});
1386 }
1387
1388 auto customRng = sharedState.rng;
1389 if (auto intAttr =
1390 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1391 customRng = std::mt19937(intAttr.getInt());
1392 }
1393
1394 auto idx = getUniformlyInRange(customRng, 0, accumulator - 1);
1395 auto *iter = llvm::upper_bound(
1396 prefixSum, idx,
1397 [](uint32_t a, const std::pair<ElaboratorValue, uint32_t> &b) {
1398 return a < b.second;
1399 });
1400
1401 state[op.getResult()] = iter->first;
1402 return DeletionKind::Delete;
1403 }
1404
1405 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1406 auto original = get<BagStorage *>(op.getOriginal())->bag;
1407 auto diff = get<BagStorage *>(op.getDiff())->bag;
1408
1409 MapVector<ElaboratorValue, uint64_t> result;
1410 for (const auto &el : original) {
1411 if (!diff.contains(el.first)) {
1412 result.insert(el);
1413 continue;
1414 }
1415
1416 if (op.getInf())
1417 continue;
1418
1419 auto toDiff = diff.lookup(el.first);
1420 if (el.second <= toDiff)
1421 continue;
1422
1423 result.insert({el.first, el.second - toDiff});
1424 }
1425
1426 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1427 std::move(result), op.getType());
1428 return DeletionKind::Delete;
1429 }
1430
1431 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1432 MapVector<ElaboratorValue, uint64_t> result;
1433 for (auto bag : op.getBags()) {
1434 auto val = get<BagStorage *>(bag)->bag;
1435 for (auto [el, multiple] : val)
1436 result[el] += multiple;
1437 }
1438
1439 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1440 std::move(result), op.getType());
1441 return DeletionKind::Delete;
1442 }
1443
1444 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1445 auto size = get<BagStorage *>(op.getBag())->bag.size();
1446 state[op.getResult()] = size;
1447 return DeletionKind::Delete;
1448 }
1449
1450 FailureOr<DeletionKind> visitOp(BagConvertToSetOp op) {
1451 auto bag = get<BagStorage *>(op.getInput())->bag;
1452 SetVector<ElaboratorValue> set;
1453 for (auto [k, v] : bag)
1454 set.insert(k);
1455 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1456 std::move(set), op.getType());
1457 return DeletionKind::Delete;
1458 }
1459
1460 FailureOr<DeletionKind> visitOp(FixedRegisterOp op) {
1461 return visitPureOp(op);
1462 }
1463
1464 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1465 auto *val = sharedState.internalizer.create<VirtualRegisterStorage>(
1466 op.getAllowedRegsAttr(), op.getType());
1467 state[op.getResult()] = val;
1468 materializer.registerIdentityValue(val);
1469 return DeletionKind::Delete;
1470 }
1471
1472 StringAttr substituteFormatString(StringAttr formatString,
1473 ValueRange substitutes) const {
1474 if (substitutes.empty() || formatString.empty())
1475 return formatString;
1476
1477 auto original = formatString.getValue().str();
1478 for (auto [i, subst] : llvm::enumerate(substitutes)) {
1479 size_t startPos = 0;
1480 std::string from = "{{" + std::to_string(i) + "}}";
1481 while ((startPos = original.find(from, startPos)) != std::string::npos) {
1482 auto substString = std::to_string(get<size_t>(subst));
1483 original.replace(startPos, from.length(), substString);
1484 }
1485 }
1486
1487 return StringAttr::get(formatString.getContext(), original);
1488 }
1489
1490 FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
1491 SmallVector<ElaboratorValue> array;
1492 array.reserve(op.getElements().size());
1493 for (auto val : op.getElements())
1494 array.emplace_back(state.at(val));
1495
1496 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1497 op.getResult().getType(), std::move(array));
1498 return DeletionKind::Delete;
1499 }
1500
1501 FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
1502 auto array = get<ArrayStorage *>(op.getArray())->array;
1503 size_t idx = get<size_t>(op.getIndex());
1504
1505 if (array.size() <= idx)
1506 return op->emitError("invalid to access index ")
1507 << idx << " of an array with " << array.size() << " elements";
1508
1509 state[op.getResult()] = array[idx];
1510 return DeletionKind::Delete;
1511 }
1512
1513 FailureOr<DeletionKind> visitOp(ArrayInjectOp op) {
1514 auto array = get<ArrayStorage *>(op.getArray())->array;
1515 size_t idx = get<size_t>(op.getIndex());
1516
1517 if (array.size() <= idx)
1518 return op->emitError("invalid to access index ")
1519 << idx << " of an array with " << array.size() << " elements";
1520
1521 array[idx] = state[op.getValue()];
1522 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1523 op.getResult().getType(), std::move(array));
1524 return DeletionKind::Delete;
1525 }
1526
1527 FailureOr<DeletionKind> visitOp(ArraySizeOp op) {
1528 auto array = get<ArrayStorage *>(op.getArray())->array;
1529 state[op.getResult()] = array.size();
1530 return DeletionKind::Delete;
1531 }
1532
1533 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
1534 auto substituted =
1535 substituteFormatString(op.getFormatStringAttr(), op.getArgs());
1536 state[op.getLabel()] = LabelValue(substituted);
1537 return DeletionKind::Delete;
1538 }
1539
1540 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1541 auto *val = sharedState.internalizer.create<UniqueLabelStorage>(
1542 substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
1543 state[op.getLabel()] = val;
1544 materializer.registerIdentityValue(val);
1545 return DeletionKind::Delete;
1546 }
1547
1548 FailureOr<DeletionKind> visitOp(LabelOp op) { return DeletionKind::Keep; }
1549
1550 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1551 size_t lower = get<size_t>(op.getLowerBound());
1552 size_t upper = get<size_t>(op.getUpperBound()) - 1;
1553 if (lower > upper)
1554 return op->emitError("cannot select a number from an empty range");
1555
1556 if (auto intAttr =
1557 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1558 std::mt19937 customRng(intAttr.getInt());
1559 state[op.getResult()] =
1560 size_t(getUniformlyInRange(customRng, lower, upper));
1561 } else {
1562 state[op.getResult()] =
1563 size_t(getUniformlyInRange(sharedState.rng, lower, upper));
1564 }
1565
1566 return DeletionKind::Delete;
1567 }
1568
1569 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1570 size_t input = get<size_t>(op.getInput());
1571 auto width = op.getType().getWidth();
1572 auto emitError = [&]() { return op->emitError(); };
1573 if (input > APInt::getAllOnes(width).getZExtValue())
1574 return emitError() << "cannot represent " << input << " with " << width
1575 << " bits";
1576
1577 state[op.getResult()] =
1578 ImmediateAttr::get(op.getContext(), APInt(width, input));
1579 return DeletionKind::Delete;
1580 }
1581
1582 FailureOr<DeletionKind> visitOp(OnContextOp op) {
1583 ContextResourceAttrInterface from = currentContext,
1584 to = cast<ContextResourceAttrInterface>(
1585 get<TypedAttr>(op.getContext()));
1586 if (!currentContext)
1587 from = DefaultContextAttr::get(op->getContext(), to.getType());
1588
1589 auto emitError = [&]() {
1590 auto diag = op.emitError();
1591 diag.attachNote(op.getLoc())
1592 << "while materializing value for context switching for " << op;
1593 return diag;
1594 };
1595
1596 if (from == to) {
1597 Value seqVal = materializer.materialize(
1598 get<SequenceStorage *>(op.getSequence()), op.getLoc(), emitError);
1599 if (!seqVal)
1600 return failure();
1601
1602 Value randSeqVal =
1603 materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1604 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1605 return DeletionKind::Delete;
1606 }
1607
1608 // Switch to the desired context.
1609 // First, check if a context switch is registered that has the concrete
1610 // context as source and target.
1611 auto *iter = testState.contextSwitches.find({from, to});
1612
1613 // Try with 'any' context as target and the concrete context as source.
1614 if (iter == testState.contextSwitches.end())
1615 iter = testState.contextSwitches.find(
1616 {from, AnyContextAttr::get(op->getContext(), to.getType())});
1617
1618 // Try with 'any' context as source and the concrete context as target.
1619 if (iter == testState.contextSwitches.end())
1620 iter = testState.contextSwitches.find(
1621 {AnyContextAttr::get(op->getContext(), from.getType()), to});
1622
1623 // Try with 'any' context for both the source and the target.
1624 if (iter == testState.contextSwitches.end())
1625 iter = testState.contextSwitches.find(
1626 {AnyContextAttr::get(op->getContext(), from.getType()),
1627 AnyContextAttr::get(op->getContext(), to.getType())});
1628
1629 // Otherwise, fail with an error because we couldn't find a user
1630 // specification on how to switch between the requested contexts.
1631 // NOTE: we could think about supporting context switching via intermediate
1632 // context, i.e., treat it as a transitive relation.
1633 if (iter == testState.contextSwitches.end())
1634 return op->emitError("no context transition registered to switch from ")
1635 << from << " to " << to;
1636
1637 auto familyName = iter->second->familyName;
1638 SmallVector<ElaboratorValue> args{from, to,
1639 get<SequenceStorage *>(op.getSequence())};
1640 auto *seq = sharedState.internalizer.internalize<SequenceStorage>(
1641 familyName, std::move(args));
1642 auto *randSeq =
1643 sharedState.internalizer.create<RandomizedSequenceStorage>(to, seq);
1644 materializer.registerIdentityValue(randSeq);
1645 Value seqVal = materializer.materialize(randSeq, op.getLoc(), emitError);
1646 if (!seqVal)
1647 return failure();
1648
1649 materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
1650 return DeletionKind::Delete;
1651 }
1652
1653 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1654 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1655 get<SequenceStorage *>(op.getSequence());
1656 return DeletionKind::Delete;
1657 }
1658
1659 FailureOr<DeletionKind> visitOp(MemoryBlockDeclareOp op) {
1660 auto *val = sharedState.internalizer.create<MemoryBlockStorage>(
1661 op.getBaseAddress(), op.getEndAddress(), op.getType());
1662 state[op.getResult()] = val;
1663 materializer.registerIdentityValue(val);
1664 return DeletionKind::Delete;
1665 }
1666
1667 FailureOr<DeletionKind> visitOp(MemoryAllocOp op) {
1668 size_t size = get<size_t>(op.getSize());
1669 size_t alignment = get<size_t>(op.getAlignment());
1670 auto *memBlock = get<MemoryBlockStorage *>(op.getMemoryBlock());
1671 auto *val = sharedState.internalizer.create<MemoryStorage>(memBlock, size,
1672 alignment);
1673 state[op.getResult()] = val;
1674 materializer.registerIdentityValue(val);
1675 return DeletionKind::Delete;
1676 }
1677
1678 FailureOr<DeletionKind> visitOp(MemorySizeOp op) {
1679 auto *memory = get<MemoryStorage *>(op.getMemory());
1680 state[op.getResult()] = memory->size;
1681 return DeletionKind::Delete;
1682 }
1683
1684 FailureOr<DeletionKind> visitOp(TupleCreateOp op) {
1685 SmallVector<ElaboratorValue> values;
1686 values.reserve(op.getElements().size());
1687 for (auto el : op.getElements())
1688 values.push_back(state[el]);
1689
1690 state[op.getResult()] =
1691 sharedState.internalizer.internalize<TupleStorage>(std::move(values));
1692 return DeletionKind::Delete;
1693 }
1694
1695 FailureOr<DeletionKind> visitOp(TupleExtractOp op) {
1696 auto *tuple = get<TupleStorage *>(op.getTuple());
1697 state[op.getResult()] = tuple->values[op.getIndex().getZExtValue()];
1698 return DeletionKind::Delete;
1699 }
1700
1701 FailureOr<DeletionKind> visitOp(CommentOp op) { return DeletionKind::Keep; }
1702
1703 FailureOr<DeletionKind> visitOp(rtg::YieldOp op) {
1704 return DeletionKind::Keep;
1705 }
1706
1707 FailureOr<DeletionKind> visitOp(ValidateOp op) {
1708 auto *validationVal = sharedState.internalizer.create<ValidationValue>(
1709 op.getType(), state[op.getRef()], state[op.getDefaultValue()],
1710 op.getIdAttr());
1711 state[op.getValue()] = validationVal;
1712 materializer.registerIdentityValue(validationVal);
1713 materializer.map(validationVal, op.getValue());
1714 return DeletionKind::Keep;
1715 }
1716
1717 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1718 bool cond = get<bool>(op.getCondition());
1719 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1720 if (toElaborate.empty())
1721 return DeletionKind::Delete;
1722
1723 // Just reuse this elaborator for the nested region because we need access
1724 // to the elaborated values outside the nested region (since it is not
1725 // isolated from above) and we want to materialize the region inline, thus
1726 // don't need a new materializer instance.
1727 SmallVector<ElaboratorValue> yieldedVals;
1728 if (failed(elaborate(toElaborate, {}, yieldedVals)))
1729 return failure();
1730
1731 // Map the results of the 'scf.if' to the yielded values.
1732 for (auto [res, out] : llvm::zip(op.getResults(), yieldedVals))
1733 state[res] = out;
1734
1735 return DeletionKind::Delete;
1736 }
1737
1738 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1739 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1740 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1741 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1742 return op->emitOpError("can only elaborate index type iterator");
1743
1744 auto lowerBound = get<size_t>(op.getLowerBound());
1745 auto step = get<size_t>(op.getStep());
1746 auto upperBound = get<size_t>(op.getUpperBound());
1747
1748 // Prepare for first iteration by assigning the nested regions block
1749 // arguments. We can just reuse this elaborator because we need access to
1750 // values elaborated in the parent region anyway and materialize everything
1751 // inline (i.e., don't need a new materializer).
1752 state[op.getInductionVar()] = lowerBound;
1753 for (auto [iterArg, initArg] :
1754 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1755 state[iterArg] = state.at(initArg);
1756
1757 // This loop performs the actual 'scf.for' loop iterations.
1758 SmallVector<ElaboratorValue> yieldedVals;
1759 for (size_t i = lowerBound; i < upperBound; i += step) {
1760 yieldedVals.clear();
1761 if (failed(elaborate(op.getBodyRegion(), {}, yieldedVals)))
1762 return failure();
1763
1764 // Prepare for the next iteration by updating the mapping of the nested
1765 // regions block arguments
1766 state[op.getInductionVar()] = i + step;
1767 for (auto [iterArg, prevIterArg] :
1768 llvm::zip(op.getRegionIterArgs(), yieldedVals))
1769 state[iterArg] = prevIterArg;
1770 }
1771
1772 // Transfer the previously yielded values to the for loop result values.
1773 for (auto [res, iterArg] :
1774 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1775 state[res] = state.at(iterArg);
1776
1777 return DeletionKind::Delete;
1778 }
1779
1780 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1781 return DeletionKind::Delete;
1782 }
1783
1784 FailureOr<DeletionKind> visitOp(arith::AddIOp op) {
1785 if (!isa<IndexType>(op.getType()))
1786 return op->emitError("only index operands supported");
1787
1788 size_t lhs = get<size_t>(op.getLhs());
1789 size_t rhs = get<size_t>(op.getRhs());
1790 state[op.getResult()] = lhs + rhs;
1791 return DeletionKind::Delete;
1792 }
1793
1794 FailureOr<DeletionKind> visitOp(arith::AndIOp op) {
1795 if (!op.getType().isSignlessInteger(1))
1796 return op->emitError("only 'i1' operands supported");
1797
1798 bool lhs = get<bool>(op.getLhs());
1799 bool rhs = get<bool>(op.getRhs());
1800 state[op.getResult()] = lhs && rhs;
1801 return DeletionKind::Delete;
1802 }
1803
1804 FailureOr<DeletionKind> visitOp(arith::XOrIOp op) {
1805 if (!op.getType().isSignlessInteger(1))
1806 return op->emitError("only 'i1' operands supported");
1807
1808 bool lhs = get<bool>(op.getLhs());
1809 bool rhs = get<bool>(op.getRhs());
1810 state[op.getResult()] = lhs != rhs;
1811 return DeletionKind::Delete;
1812 }
1813
1814 FailureOr<DeletionKind> visitOp(arith::OrIOp op) {
1815 if (!op.getType().isSignlessInteger(1))
1816 return op->emitError("only 'i1' operands supported");
1817
1818 bool lhs = get<bool>(op.getLhs());
1819 bool rhs = get<bool>(op.getRhs());
1820 state[op.getResult()] = lhs || rhs;
1821 return DeletionKind::Delete;
1822 }
1823
1824 FailureOr<DeletionKind> visitOp(arith::SelectOp op) {
1825 bool cond = get<bool>(op.getCondition());
1826 auto trueVal = state[op.getTrueValue()];
1827 auto falseVal = state[op.getFalseValue()];
1828 state[op.getResult()] = cond ? trueVal : falseVal;
1829 return DeletionKind::Delete;
1830 }
1831
1832 FailureOr<DeletionKind> visitOp(index::AddOp op) {
1833 size_t lhs = get<size_t>(op.getLhs());
1834 size_t rhs = get<size_t>(op.getRhs());
1835 state[op.getResult()] = lhs + rhs;
1836 return DeletionKind::Delete;
1837 }
1838
1839 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
1840 size_t lhs = get<size_t>(op.getLhs());
1841 size_t rhs = get<size_t>(op.getRhs());
1842 bool result;
1843 switch (op.getPred()) {
1844 case index::IndexCmpPredicate::EQ:
1845 result = lhs == rhs;
1846 break;
1847 case index::IndexCmpPredicate::NE:
1848 result = lhs != rhs;
1849 break;
1850 case index::IndexCmpPredicate::ULT:
1851 result = lhs < rhs;
1852 break;
1853 case index::IndexCmpPredicate::ULE:
1854 result = lhs <= rhs;
1855 break;
1856 case index::IndexCmpPredicate::UGT:
1857 result = lhs > rhs;
1858 break;
1859 case index::IndexCmpPredicate::UGE:
1860 result = lhs >= rhs;
1861 break;
1862 default:
1863 return op->emitOpError("elaboration not supported");
1864 }
1865 state[op.getResult()] = result;
1866 return DeletionKind::Delete;
1867 }
1868
1869 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
1870 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
1871 .Case<
1872 // Arith ops
1873 arith::AddIOp, arith::XOrIOp, arith::AndIOp, arith::OrIOp,
1874 arith::SelectOp,
1875 // Index ops
1876 index::AddOp, index::CmpOp,
1877 // SCF ops
1878 scf::IfOp, scf::ForOp, scf::YieldOp>(
1879 [&](auto op) { return visitOp(op); })
1880 .Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); });
1881 }
1882
1883 // NOLINTNEXTLINE(misc-no-recursion)
1884 LogicalResult elaborate(Region &region,
1885 ArrayRef<ElaboratorValue> regionArguments,
1886 SmallVector<ElaboratorValue> &terminatorOperands) {
1887 if (region.getBlocks().size() > 1)
1888 return region.getParentOp()->emitOpError(
1889 "regions with more than one block are not supported");
1890
1891 for (auto [arg, elabArg] :
1892 llvm::zip(region.getArguments(), regionArguments))
1893 state[arg] = elabArg;
1894
1895 Block *block = &region.front();
1896 for (auto &op : *block) {
1897 auto result = dispatchOpVisitor(&op);
1898 if (failed(result))
1899 return failure();
1900
1901 if (*result == DeletionKind::Keep)
1902 if (failed(materializer.materialize(&op, state)))
1903 return failure();
1904
1905 LLVM_DEBUG({
1906 llvm::dbgs() << "Elaborated " << op << " to\n[";
1907
1908 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) {
1909 if (state.contains(res))
1910 llvm::dbgs() << state.at(res);
1911 else
1912 llvm::dbgs() << "unknown";
1913 });
1914
1915 llvm::dbgs() << "]\n\n";
1916 });
1917 }
1918
1919 if (region.front().mightHaveTerminator())
1920 for (auto val : region.front().getTerminator()->getOperands())
1921 terminatorOperands.push_back(state.at(val));
1922
1923 return success();
1924 }
1925
private:
  // State to be shared between all elaborator instances.
  SharedState &sharedState;

  // State specific to one RTG test and the sequences placed within it.
  TestState &testState;

  // Allows us to materialize ElaboratorValues to the IR operations necessary to
  // obtain an SSA value representing that elaborated value.
  Materializer &materializer;

  // A map from SSA values to their elaborated values.
  // NOTE: `state[...]` may insert and grow the map, which invalidates any
  // reference previously obtained via `state.at(...)`. Copy a value out before
  // assigning it through `operator[]`, i.e. avoid `state[a] = state.at(b);`.
  DenseMap<Value, ElaboratorValue> state;

  // The current context we are elaborating under. This is passed in when
  // elaborating a sequence; target and test elaborators are constructed
  // without one (presumably left null then -- confirm against the ctor).
  ContextResourceAttrInterface currentContext;
1942};
1943} // namespace
1944
1945SequenceOp
1946Materializer::elaborateSequence(const RandomizedSequenceStorage *seq,
1947 SmallVector<ElaboratorValue> &elabArgs) {
1948 auto familyOp =
1949 sharedState.table.lookup<SequenceOp>(seq->sequence->familyName);
1950 // TODO: don't clone if this is the only remaining reference to this
1951 // sequence
1952 OpBuilder builder(familyOp);
1953 auto seqOp = builder.cloneWithoutRegions(familyOp);
1954 auto name = sharedState.names.newName(seq->sequence->familyName.getValue());
1955 seqOp.setSymName(name);
1956 seqOp.getBodyRegion().emplaceBlock();
1957 sharedState.table.insert(seqOp);
1958 assert(seqOp.getSymName() == name && "should not have been renamed");
1959
1960 LLVM_DEBUG(llvm::dbgs() << "\n=== Elaborating sequence family @"
1961 << familyOp.getSymName() << " into @"
1962 << seqOp.getSymName() << " under context "
1963 << seq->context << "\n\n");
1964
1965 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()), testState,
1966 sharedState, elabArgs);
1967 Elaborator elaborator(sharedState, testState, materializer, seq->context);
1968 SmallVector<ElaboratorValue> yieldedVals;
1969 if (failed(elaborator.elaborate(familyOp.getBodyRegion(), seq->sequence->args,
1970 yieldedVals)))
1971 return {};
1972
1973 seqOp.setSequenceType(
1974 SequenceType::get(builder.getContext(), materializer.getBlockArgTypes()));
1975 materializer.finalize();
1976
1977 return seqOp;
1978}
1979
1980//===----------------------------------------------------------------------===//
1981// Elaborator Pass
1982//===----------------------------------------------------------------------===//
1983
1984namespace {
1985struct ElaborationPass
1986 : public rtg::impl::ElaborationPassBase<ElaborationPass> {
1987 using Base::Base;
1988
1989 void runOnOperation() override;
1990 void matchTestsAgainstTargets(SymbolTable &table);
1991 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
1992};
1993} // namespace
1994
1995void ElaborationPass::runOnOperation() {
1996 auto moduleOp = getOperation();
1997 SymbolTable table(moduleOp);
1998
1999 matchTestsAgainstTargets(table);
2000
2001 if (failed(elaborateModule(moduleOp, table)))
2002 return signalPassFailure();
2003}
2004
2005void ElaborationPass::matchTestsAgainstTargets(SymbolTable &table) {
2006 auto moduleOp = getOperation();
2007
2008 for (auto test : llvm::make_early_inc_range(moduleOp.getOps<TestOp>())) {
2009 if (test.getTargetAttr())
2010 continue;
2011
2012 bool matched = false;
2013
2014 for (auto target : moduleOp.getOps<TargetOp>()) {
2015 // Check if the target type is a subtype of the test's target type
2016 // This means that for each entry in the test's target type, there must be
2017 // a corresponding entry with the same name and type in the target's type
2018 bool isSubtype = true;
2019 auto testEntries = test.getTargetType().getEntries();
2020 auto targetEntries = target.getTarget().getEntries();
2021
2022 // Check if target is a subtype of test requirements
2023 // Since entries are sorted by name, we can do this in a single pass
2024 size_t targetIdx = 0;
2025 for (auto testEntry : testEntries) {
2026 // Find the matching entry in target entries.
2027 while (targetIdx < targetEntries.size() &&
2028 targetEntries[targetIdx].name.getValue() <
2029 testEntry.name.getValue())
2030 targetIdx++;
2031
2032 // Check if we found a matching entry with the same name and type
2033 if (targetIdx >= targetEntries.size() ||
2034 targetEntries[targetIdx].name != testEntry.name ||
2035 targetEntries[targetIdx].type != testEntry.type) {
2036 isSubtype = false;
2037 break;
2038 }
2039 }
2040
2041 if (!isSubtype)
2042 continue;
2043
2044 IRRewriter rewriter(test);
2045 // Create a new test for the matched target
2046 auto newTest = cast<TestOp>(test->clone());
2047 newTest.setSymName(test.getSymName().str() + "_" +
2048 target.getSymName().str());
2049
2050 // Set the target symbol specifying that this test is only suitable for
2051 // that target.
2052 newTest.setTargetAttr(target.getSymNameAttr());
2053
2054 table.insert(newTest, rewriter.getInsertionPoint());
2055 matched = true;
2056 }
2057
2058 if (matched || deleteUnmatchedTests)
2059 test->erase();
2060 }
2061}
2062
2063static bool onlyLegalToMaterializeInTarget(Type type) {
2064 return isa<MemoryBlockType, ContextResourceTypeInterface>(type);
2065}
2066
2067LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
2068 SymbolTable &table) {
2069 SharedState state(table, seed);
2070
2071 // Update the name cache
2072 state.names.add(moduleOp);
2073
2074 struct TargetElabResult {
2075 DictType targetType;
2076 SmallVector<ElaboratorValue> yields;
2077 TestState testState;
2078 };
2079
2080 // Map to store elaborated targets
2081 DenseMap<StringAttr, TargetElabResult> targetMap;
2082 for (auto targetOp : moduleOp.getOps<TargetOp>()) {
2083 LLVM_DEBUG(llvm::dbgs() << "=== Elaborating target @"
2084 << targetOp.getSymName() << "\n\n");
2085
2086 auto &result = targetMap[targetOp.getSymNameAttr()];
2087 result.targetType = targetOp.getTarget();
2088
2089 SmallVector<ElaboratorValue> blockArgs;
2090 Materializer targetMaterializer(OpBuilder::atBlockBegin(targetOp.getBody()),
2091 result.testState, state, blockArgs);
2092 Elaborator targetElaborator(state, result.testState, targetMaterializer);
2093
2094 // Elaborate the target
2095 if (failed(targetElaborator.elaborate(targetOp.getBodyRegion(), {},
2096 result.yields)))
2097 return failure();
2098 }
2099
2100 // Initialize the worklist with the test ops since they cannot be placed by
2101 // other ops.
2102 for (auto testOp : moduleOp.getOps<TestOp>()) {
2103 // Skip tests without a target attribute - these couldn't be matched
2104 // against any target but can be useful to keep around for reporting
2105 // purposes.
2106 if (!testOp.getTargetAttr())
2107 continue;
2108
2109 LLVM_DEBUG(llvm::dbgs()
2110 << "\n=== Elaborating test @" << testOp.getTemplateName()
2111 << " for target @" << *testOp.getTarget() << "\n\n");
2112
2113 // Get the target for this test
2114 auto targetResult = targetMap[testOp.getTargetAttr()];
2115 TestState testState = targetResult.testState;
2116 testState.name = testOp.getSymNameAttr();
2117
2118 SmallVector<ElaboratorValue> filteredYields;
2119 unsigned i = 0;
2120 for (auto [entry, yield] :
2121 llvm::zip(targetResult.targetType.getEntries(), targetResult.yields)) {
2122 if (i >= testOp.getTargetType().getEntries().size())
2123 break;
2124
2125 if (entry.name == testOp.getTargetType().getEntries()[i].name) {
2126 filteredYields.push_back(yield);
2127 ++i;
2128 }
2129 }
2130
2131 // Now elaborate the test with the same state, passing the target yield
2132 // values as arguments
2133 SmallVector<ElaboratorValue> blockArgs;
2134 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()),
2135 testState, state, blockArgs);
2136
2137 for (auto [arg, val] :
2138 llvm::zip(testOp.getBody()->getArguments(), filteredYields))
2139 if (onlyLegalToMaterializeInTarget(arg.getType()))
2140 materializer.map(val, arg);
2141
2142 Elaborator elaborator(state, testState, materializer);
2143 SmallVector<ElaboratorValue> ignore;
2144 if (failed(elaborator.elaborate(testOp.getBodyRegion(), filteredYields,
2145 ignore)))
2146 return failure();
2147
2148 materializer.finalize();
2149 }
2150
2151 return success();
2152}
assert(baseType &&"element must be base type")
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the specified range.
static bool onlyLegalToMaterializeInTarget(Type type)
static void print(TypedAttr val, llvm::raw_ostream &os)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
This helps visit TypeOp nodes.
Definition RTGVisitors.h:29
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:90
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
static llvm::hash_code hash_value(const ModulePort &port)
Definition HWTypes.h:38
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition Utils.h:32
Definition rtg.py:1
Definition seq.py:1
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)