CIRCT 22.0.0git
Loading...
Searching...
No Matches
ElaborationPass.cpp
Go to the documentation of this file.
1//===- ElaborationPass.cpp - RTG ElaborationPass implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass elaborates the random parts of the RTG dialect.
10// It performs randomization top-down, i.e., random constructs in a sequence
11// that is invoked multiple times can yield different randomization results
// for each invocation.
13//
14//===----------------------------------------------------------------------===//
15
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Index/IR/IndexDialect.h"
23#include "mlir/Dialect/Index/IR/IndexOps.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/DenseMapInfoVariant.h"
28#include "llvm/Support/Debug.h"
29#include <queue>
30#include <random>
31
32namespace circt {
33namespace rtg {
34#define GEN_PASS_DEF_ELABORATIONPASS
35#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
36} // namespace rtg
37} // namespace circt
38
39using namespace mlir;
40using namespace circt;
41using namespace circt::rtg;
42using llvm::MapVector;
43
44#define DEBUG_TYPE "rtg-elaboration"
45
46//===----------------------------------------------------------------------===//
47// Uniform Distribution Helper
48//
49// Simplified version of
50// https://github.com/llvm/llvm-project/blob/main/libcxx/include/__random/uniform_int_distribution.h
51// We use our custom version here to get the same results when compiled with
52// different compiler versions and standard libraries.
53//===----------------------------------------------------------------------===//
54
/// Compute the mask selecting the low bits of one 32-bit engine output that
/// are needed to produce a `w`-bit random number. This mirrors the word
/// splitting of libc++'s `__independent_bits_engine` for a single 32-bit
/// engine call; callers in this file only pass widths in [1, 32], for which
/// the result is `2^w - 1`. A width of zero yields an empty mask instead of
/// dividing by zero.
static uint32_t computeMask(size_t w) {
  if (w == 0)
    return 0;
  size_t n = w / 32 + (w % 32 != 0);
  size_t w0 = w / n;
  return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
}

/// Get a number uniformly at random in the specified inclusive range [a, b].
///
/// Like libc++'s `uniform_int_distribution`, this draws `ceil(log2(b - a + 1))`
/// random bits per attempt and rejects out-of-range values. Because the mask
/// covers fewer than twice as many values as the range, the expected number of
/// engine calls per sample is below two.
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) {
  // Number of values in the range; wraps to 0 for the full 32-bit range.
  const uint32_t diff = b - a + 1;
  if (diff == 1)
    return a;

  // Full 32-bit range: every engine output is a valid result.
  if (diff == 0)
    return rng();

  // Number of bits needed to represent every offset in [0, diff), i.e.,
  // ceil(log2(diff)). This is the same width libc++ derives via
  // count-leading-zeros.
  uint32_t width = 0;
  for (uint32_t maxOffset = diff - 1; maxOffset != 0; maxOffset >>= 1)
    ++width;

  // The mask has to be derived from the bit width, not from `diff` itself;
  // otherwise far too many bits are drawn and almost every sample is rejected.
  const uint32_t mask = computeMask(width);
  uint32_t u;
  do {
    u = rng() & mask;
  } while (u >= diff);

  return u + a;
}
83
84//===----------------------------------------------------------------------===//
85// Elaborator Value
86//===----------------------------------------------------------------------===//
87
namespace {
// Forward declarations of the storage objects referenced by the pointer
// alternatives of 'ElaboratorValue'. They are defined further down in this
// file.
struct ArrayStorage;
struct BagStorage;
struct SequenceStorage;
struct RandomizedSequenceStorage;
struct InterleavedSequenceStorage;
struct SetStorage;
struct VirtualRegisterStorage;
struct UniqueLabelStorage;
struct TupleStorage;
struct MemoryStorage;
struct MemoryBlockStorage;
struct ValidationValue;
struct ValidationMuxedValue;
struct ImmediateConcatStorage;
struct ImmediateSliceStorage;

/// Simple wrapper around a 'StringAttr' such that we know to materialize it as
/// a label declaration instead of calling the builtin dialect constant
/// materializer.
struct LabelValue {
  // Not 'explicit': a 'StringAttr' converts implicitly to a 'LabelValue'.
  // NOTE(review): confirm whether callers rely on this before changing it.
  LabelValue(StringAttr name) : name(name) {}

  /// Two label values are equal iff they wrap the same name attribute.
  bool operator==(const LabelValue &other) const { return name == other.name; }

  /// The label name.
  StringAttr name;
};

/// A value computed during elaboration. It is either an immediate value
/// (attribute, bool, index, label name) or a pointer to one of the storage
/// objects declared above. Storages with structural equivalence are
/// internalized, so comparing their pointers is sufficient.
///
/// NOTE: 'hash_value' below mixes 'index()' into the hash, so the order of the
/// alternatives is significant: reordering them changes all hash values.
using ElaboratorValue = std::variant<
    TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
    RandomizedSequenceStorage *, InterleavedSequenceStorage *, SetStorage *,
    VirtualRegisterStorage *, UniqueLabelStorage *, LabelValue, ArrayStorage *,
    TupleStorage *, MemoryStorage *, MemoryBlockStorage *, ValidationValue *,
    ValidationMuxedValue *, ImmediateConcatStorage *, ImmediateSliceStorage *>;

/// Hash a 'LabelValue' by its name attribute. Found via argument-dependent
/// lookup, e.g., by 'llvm::hash_combine' when hashing an 'ElaboratorValue'.
// NOLINTNEXTLINE(readability-identifier-naming)
llvm::hash_code hash_value(const LabelValue &val) {
  return llvm::hash_value(val.name);
}

/// Hash an 'ElaboratorValue' by combining the index of its active alternative
/// with the hash of the contained value.
// NOLINTNEXTLINE(readability-identifier-naming)
llvm::hash_code hash_value(const ElaboratorValue &val) {
  return std::visit(
      [&val](const auto &alternative) {
        // Include index in hash to make sure same value as different
        // alternatives don't collide.
        return llvm::hash_combine(val.index(), alternative);
      },
      val);
}

} // namespace
142
143namespace llvm {
144
145template <>
146struct DenseMapInfo<bool> {
147 static inline unsigned getEmptyKey() { return false; }
148 static inline unsigned getTombstoneKey() { return true; }
149 static unsigned getHashValue(const bool &val) { return val * 37U; }
150
151 static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; }
152};
153template <>
154struct DenseMapInfo<LabelValue> {
155 static inline LabelValue getEmptyKey() {
157 }
158 static inline LabelValue getTombstoneKey() {
160 }
161 static unsigned getHashValue(const LabelValue &val) {
162 return hash_value(val);
163 }
164
165 static bool isEqual(const LabelValue &lhs, const LabelValue &rhs) {
166 return lhs == rhs;
167 }
168};
169
170} // namespace llvm
171
172//===----------------------------------------------------------------------===//
173// Elaborator Value Storages and Internalization
174//===----------------------------------------------------------------------===//
175
176namespace {
177
/// Cheap key type for the internalization sets. It pairs the precomputed
/// hashcode of an internalized object with a pointer to that object. This lets
/// the (potentially expensive) allocation and construction of the object be
/// deferred until it is known that no equivalent object is in the set yet.
template <typename StorageTy>
struct HashedStorage {
  HashedStorage(unsigned hashcode = 0, StorageTy *storage = nullptr)
      : hashcode{hashcode}, storage{storage} {}

  /// Hashcode of the referenced storage object.
  unsigned hashcode;

  /// The internalized object. This is null for a freshly inserted key until
  /// the newly allocated object has been assigned to it.
  StorageTy *storage;
};
190
191/// A DenseMapInfo implementation to support 'insert_as' for the internalization
192/// sets. When comparing two 'HashedStorage's we can just compare the already
193/// internalized storage pointers, otherwise we have to call the costly
194/// 'isEqual' method.
195template <typename StorageTy>
196struct StorageKeyInfo {
197 static inline HashedStorage<StorageTy> getEmptyKey() {
198 return HashedStorage<StorageTy>(0,
199 DenseMapInfo<StorageTy *>::getEmptyKey());
200 }
201 static inline HashedStorage<StorageTy> getTombstoneKey() {
202 return HashedStorage<StorageTy>(
203 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
204 }
205
206 static inline unsigned getHashValue(const HashedStorage<StorageTy> &key) {
207 return key.hashcode;
208 }
209 static inline unsigned getHashValue(const StorageTy &key) {
210 return key.hashcode;
211 }
212
213 static inline bool isEqual(const HashedStorage<StorageTy> &lhs,
214 const HashedStorage<StorageTy> &rhs) {
215 return lhs.storage == rhs.storage;
216 }
217 static inline bool isEqual(const StorageTy &lhs,
218 const HashedStorage<StorageTy> &rhs) {
219 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
220 return false;
221
222 return lhs.isEqual(rhs.storage);
223 }
224};
225
226// Values with structural equivalence intended to be internalized.
227//===----------------------------------------------------------------------===//
228
229/// Storage object for an '!rtg.set<T>'.
230struct SetStorage {
231 static unsigned computeHash(const SetVector<ElaboratorValue> &set,
232 Type type) {
233 llvm::hash_code setHash = 0;
234 for (auto el : set) {
235 // Just XOR all hashes because it's a commutative operation and
236 // `llvm::hash_combine_range` is not commutative.
237 // We don't want the order in which elements were added to influence the
238 // hash and thus the equivalence of sets.
239 setHash = setHash ^ llvm::hash_combine(el);
240 }
241 return llvm::hash_combine(type, setHash);
242 }
243
244 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
245 : hashcode(computeHash(set, type)), set(std::move(set)), type(type) {}
246
247 bool isEqual(const SetStorage *other) const {
248 // Note: we are not using the `==` operator of `SetVector` because it
249 // takes the order in which elements were added into account (since it's a
250 // vector after all). We just use it as a convenient way to keep track of a
251 // deterministic order for re-materialization.
252 bool allContained = true;
253 for (auto el : set)
254 allContained &= other->set.contains(el);
255
256 return hashcode == other->hashcode && set.size() == other->set.size() &&
257 allContained && type == other->type;
258 }
259
260 // The cached hashcode to avoid repeated computations.
261 const unsigned hashcode;
262
263 // Stores the elaborated values contained in the set.
264 const SetVector<ElaboratorValue> set;
265
266 // Store the set type such that we can materialize this evaluated value
267 // also in the case where the set is empty.
268 const Type type;
269};
270
271/// Storage object for an '!rtg.bag<T>'.
272struct BagStorage {
273 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
274 : hashcode(llvm::hash_combine(
275 type, llvm::hash_combine_range(bag.begin(), bag.end()))),
276 bag(std::move(bag)), type(type) {}
277
278 bool isEqual(const BagStorage *other) const {
279 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
280 type == other->type;
281 }
282
283 // The cached hashcode to avoid repeated computations.
284 const unsigned hashcode;
285
286 // Stores the elaborated values contained in the bag with their number of
287 // occurences.
288 const MapVector<ElaboratorValue, uint64_t> bag;
289
290 // Store the bag type such that we can materialize this evaluated value
291 // also in the case where the bag is empty.
292 const Type type;
293};
294
295/// Storage object for an '!rtg.sequence'.
296struct SequenceStorage {
297 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
298 : hashcode(llvm::hash_combine(
299 familyName, llvm::hash_combine_range(args.begin(), args.end()))),
300 familyName(familyName), args(std::move(args)) {}
301
302 bool isEqual(const SequenceStorage *other) const {
303 return hashcode == other->hashcode && familyName == other->familyName &&
304 args == other->args;
305 }
306
307 // The cached hashcode to avoid repeated computations.
308 const unsigned hashcode;
309
310 // The name of the sequence family this sequence is derived from.
311 const StringAttr familyName;
312
313 // The elaborator values used during substitution of the sequence family.
314 const SmallVector<ElaboratorValue> args;
315};
316
317/// Storage object for interleaved '!rtg.randomized_sequence'es.
318struct InterleavedSequenceStorage {
319 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
320 uint32_t batchSize)
321 : sequences(std::move(sequences)), batchSize(batchSize),
322 hashcode(llvm::hash_combine(
323 llvm::hash_combine_range(sequences.begin(), sequences.end()),
324 batchSize)) {}
325
326 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
327 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
328 hashcode(llvm::hash_combine(
329 llvm::hash_combine_range(sequences.begin(), sequences.end()),
330 batchSize)) {}
331
332 bool isEqual(const InterleavedSequenceStorage *other) const {
333 return hashcode == other->hashcode && sequences == other->sequences &&
334 batchSize == other->batchSize;
335 }
336
337 const SmallVector<ElaboratorValue> sequences;
338
339 const uint32_t batchSize;
340
341 // The cached hashcode to avoid repeated computations.
342 const unsigned hashcode;
343};
344
345/// Storage object for '!rtg.array`-typed values.
346struct ArrayStorage {
347 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
348 : hashcode(llvm::hash_combine(
349 type, llvm::hash_combine_range(array.begin(), array.end()))),
350 type(type), array(array) {}
351
352 bool isEqual(const ArrayStorage *other) const {
353 return hashcode == other->hashcode && type == other->type &&
354 array == other->array;
355 }
356
357 // The cached hashcode to avoid repeated computations.
358 const unsigned hashcode;
359
360 /// The type of the array. This is necessary because an array of size 0
361 /// cannot be reconstructed without knowing the original element type.
362 const Type type;
363
364 /// The label name. For unique labels, this is just the prefix.
365 const SmallVector<ElaboratorValue> array;
366};
367
368/// Storage object for 'tuple`-typed values.
369struct TupleStorage {
370 TupleStorage(SmallVector<ElaboratorValue> &&values)
371 : hashcode(llvm::hash_combine_range(values.begin(), values.end())),
372 values(std::move(values)) {}
373
374 bool isEqual(const TupleStorage *other) const {
375 return hashcode == other->hashcode && values == other->values;
376 }
377
378 // The cached hashcode to avoid repeated computations.
379 const unsigned hashcode;
380
381 const SmallVector<ElaboratorValue> values;
382};
383
384/// Storage object for immediate concatenation operations.
385/// FIXME: opaque values should be handled generically
386struct ImmediateConcatStorage {
387 ImmediateConcatStorage(SmallVector<ElaboratorValue> &&operands)
388 : hashcode(llvm::hash_combine_range(operands.begin(), operands.end())),
389 operands(std::move(operands)) {}
390
391 bool isEqual(const ImmediateConcatStorage *other) const {
392 return hashcode == other->hashcode && operands == other->operands;
393 }
394
395 const unsigned hashcode;
396 const SmallVector<ElaboratorValue> operands;
397};
398
399/// Storage object for immediate slice operations.
400/// FIXME: opaque values should be handled generically
401struct ImmediateSliceStorage {
402 ImmediateSliceStorage(ElaboratorValue input, unsigned lowBit, Type type)
403 : hashcode(llvm::hash_combine(input, lowBit, type)), input(input),
404 lowBit(lowBit), type(type) {}
405
406 bool isEqual(const ImmediateSliceStorage *other) const {
407 return hashcode == other->hashcode && input == other->input &&
408 lowBit == other->lowBit && type == other->type;
409 }
410
411 const unsigned hashcode;
412 const ElaboratorValue input;
413 const unsigned lowBit;
414 const Type type;
415};
416
417// Values with identity not intended to be internalized.
418//===----------------------------------------------------------------------===//
419
420/// Base class for storages that represent values with identity, i.e., two
421/// values are not considered equivalent if they are structurally the same, but
422/// each definition of such a value is unique. E.g., unique labels or virtual
423/// registers. These cannot be materialized anew in each nested sequence, but
424/// must be passed as arguments.
425struct IdentityValue {
426
427 IdentityValue(Type type) : type(type) {}
428
429#ifndef NDEBUG
430
431 /// In debug mode, track whether this value was already materialized to
432 /// assert if it's illegally materialized multiple times.
433 ///
434 /// Instead of deleting operations defining these values and materializing
435 /// them again, we could retain the operations. However, we still need
436 /// specific storages to represent these values in some cases, e.g., to get
437 /// the size of a memory allocation. Also, elaboration of nested control-flow
438 /// regions (e.g. `scf.for`) relies on materialization of such values lazily
439 /// instead of cloning the operations eagerly.
440 bool alreadyMaterialized = false;
441
442#endif
443
444 const Type type;
445};
446
447/// Represents a unique virtual register.
448struct VirtualRegisterStorage : IdentityValue {
449 VirtualRegisterStorage(VirtualRegisterConfigAttr allowedRegs, Type type)
450 : IdentityValue(type), allowedRegs(allowedRegs) {}
451
452 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
453 // VirtualRegisters are never internalized.
454
455 // The list of fixed registers allowed to be selected for this virtual
456 // register.
457 const VirtualRegisterConfigAttr allowedRegs;
458};
459
460struct UniqueLabelStorage : IdentityValue {
461 UniqueLabelStorage(StringAttr name)
462 : IdentityValue(LabelType::get(name.getContext())), name(name) {}
463
464 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
465 // VirtualRegisters are never internalized.
466
467 /// The label name. For unique labels, this is just the prefix.
468 const StringAttr name;
469};
470
471/// Storage object for '!rtg.isa.memoryblock`-typed values.
472struct MemoryBlockStorage : IdentityValue {
473 MemoryBlockStorage(const APInt &baseAddress, const APInt &endAddress,
474 Type type)
475 : IdentityValue(type), baseAddress(baseAddress), endAddress(endAddress) {}
476
477 // The base address of the memory. The width of the APInt also represents the
478 // address width of the memory. This is an APInt to support memories of
479 // >64-bit machines.
480 const APInt baseAddress;
481
482 // The last address of the memory.
483 const APInt endAddress;
484};
485
486/// Storage object for '!rtg.isa.memory`-typed values.
487struct MemoryStorage : IdentityValue {
488 MemoryStorage(MemoryBlockStorage *memoryBlock, size_t size, size_t alignment)
489 : IdentityValue(MemoryType::get(memoryBlock->type.getContext(),
490 memoryBlock->baseAddress.getBitWidth())),
491 memoryBlock(memoryBlock), size(size), alignment(alignment) {}
492
493 MemoryBlockStorage *memoryBlock;
494 const size_t size;
495 const size_t alignment;
496};
497
498/// Storage object for an '!rtg.randomized_sequence'.
499struct RandomizedSequenceStorage : IdentityValue {
500 RandomizedSequenceStorage(ContextResourceAttrInterface context,
501 SequenceStorage *sequence)
502 : IdentityValue(
503 RandomizedSequenceType::get(sequence->familyName.getContext())),
504 context(context), sequence(sequence) {}
505
506 // The context under which this sequence is placed.
507 const ContextResourceAttrInterface context;
508
509 const SequenceStorage *sequence;
510};
511
512/// Storage object for an '!rtg.validate' result.
513struct ValidationValue : IdentityValue {
514 ValidationValue(Type type, const ElaboratorValue &ref,
515 const ElaboratorValue &defaultValue, StringAttr id,
516 SmallVector<ElaboratorValue> &&defaultUsedValues,
517 SmallVector<ElaboratorValue> &&elseValues)
518 : IdentityValue(type), ref(ref), defaultValue(defaultValue), id(id),
519 defaultUsedValues(std::move(defaultUsedValues)),
520 elseValues(std::move(elseValues)) {}
521
522 const ElaboratorValue ref;
523 const ElaboratorValue defaultValue;
524 const StringAttr id;
525 const SmallVector<ElaboratorValue> defaultUsedValues;
526 const SmallVector<ElaboratorValue> elseValues;
527};
528
529/// Storage object for the 'values' results of 'rtg.validate'.
530struct ValidationMuxedValue : IdentityValue {
531 ValidationMuxedValue(Type type, const ValidationValue *value, unsigned idx)
532 : IdentityValue(type), value(value), idx(idx) {}
533
534 const ValidationValue *value;
535 const unsigned idx;
536};
537
538/// An 'Internalizer' object internalizes storages and takes ownership of them.
539/// When the initializer object is destroyed, all owned storages are also
540/// deallocated and thus must not be accessed anymore.
541class Internalizer {
542public:
543 /// Internalize a storage of type `StorageTy` constructed with arguments
544 /// `args`. The pointers returned by this method can be used to compare
545 /// objects when, e.g., computing set differences, uniquing the elements in a
546 /// set, etc. Otherwise, we'd need to do a deep value comparison in those
547 /// situations.
548 template <typename StorageTy, typename... Args>
549 StorageTy *internalize(Args &&...args) {
550 static_assert(!std::is_base_of_v<IdentityValue, StorageTy> &&
551 "values with identity must not be internalized");
552
553 StorageTy storage(std::forward<Args>(args)...);
554
555 auto existing = getInternSet<StorageTy>().insert_as(
556 HashedStorage<StorageTy>(storage.hashcode), storage);
557 StorageTy *&storagePtr = existing.first->storage;
558 if (existing.second)
559 storagePtr =
560 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
561
562 return storagePtr;
563 }
564
565 template <typename StorageTy, typename... Args>
566 StorageTy *create(Args &&...args) {
567 static_assert(std::is_base_of_v<IdentityValue, StorageTy> &&
568 "values with structural equivalence must be internalized");
569
570 return new (allocator.Allocate<StorageTy>())
571 StorageTy(std::forward<Args>(args)...);
572 }
573
574private:
575 template <typename StorageTy>
576 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
577 getInternSet() {
578 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
579 return internedArrays;
580 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
581 return internedSets;
582 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
583 return internedBags;
584 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
585 return internedSequences;
586 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
587 return internedRandomizedSequences;
588 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
589 return internedInterleavedSequences;
590 else if constexpr (std::is_same_v<StorageTy, TupleStorage>)
591 return internedTuples;
592 else if constexpr (std::is_same_v<StorageTy, ImmediateConcatStorage>)
593 return internedImmediateConcatValues;
594 else if constexpr (std::is_same_v<StorageTy, ImmediateSliceStorage>)
595 return internedImmediateSliceValues;
596 else
597 static_assert(!sizeof(StorageTy),
598 "no intern set available for this storage type.");
599 }
600
601 // This allocator allocates on the heap. It automatically deallocates all
602 // objects it allocated once the allocator itself is destroyed.
603 llvm::BumpPtrAllocator allocator;
604
605 // The sets holding the internalized objects. We use one set per storage type
606 // such that we can have a simpler equality checking function (no need to
607 // compare some sort of TypeIDs).
608 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
609 internedArrays;
610 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
611 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
612 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
613 internedSequences;
614 DenseSet<HashedStorage<RandomizedSequenceStorage>,
615 StorageKeyInfo<RandomizedSequenceStorage>>
616 internedRandomizedSequences;
617 DenseSet<HashedStorage<InterleavedSequenceStorage>,
618 StorageKeyInfo<InterleavedSequenceStorage>>
619 internedInterleavedSequences;
620 DenseSet<HashedStorage<TupleStorage>, StorageKeyInfo<TupleStorage>>
621 internedTuples;
622 DenseSet<HashedStorage<ImmediateConcatStorage>,
623 StorageKeyInfo<ImmediateConcatStorage>>
624 internedImmediateConcatValues;
625 DenseSet<HashedStorage<ImmediateSliceStorage>,
626 StorageKeyInfo<ImmediateSliceStorage>>
627 internedImmediateSliceValues;
628};
629
630} // namespace
631
632#ifndef NDEBUG
633
634static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
635 const ElaboratorValue &value);
636
637static void print(TypedAttr val, llvm::raw_ostream &os) {
638 os << "<attr " << val << ">";
639}
640
641static void print(BagStorage *val, llvm::raw_ostream &os) {
642 os << "<bag {";
643 llvm::interleaveComma(val->bag, os,
644 [&](const std::pair<ElaboratorValue, uint64_t> &el) {
645 os << el.first << " -> " << el.second;
646 });
647 os << "} at " << val << ">";
648}
649
650static void print(bool val, llvm::raw_ostream &os) {
651 os << "<bool " << (val ? "true" : "false") << ">";
652}
653
654static void print(size_t val, llvm::raw_ostream &os) {
655 os << "<index " << val << ">";
656}
657
658static void print(SequenceStorage *val, llvm::raw_ostream &os) {
659 os << "<sequence @" << val->familyName.getValue() << "(";
660 llvm::interleaveComma(val->args, os,
661 [&](const ElaboratorValue &val) { os << val; });
662 os << ") at " << val << ">";
663}
664
665static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
666 os << "<randomized-sequence derived from @"
667 << val->sequence->familyName.getValue() << " under context "
668 << val->context << "(";
669 llvm::interleaveComma(val->sequence->args, os,
670 [&](const ElaboratorValue &val) { os << val; });
671 os << ") at " << val << ">";
672}
673
674static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
675 os << "<interleaved-sequence [";
676 llvm::interleaveComma(val->sequences, os,
677 [&](const ElaboratorValue &val) { os << val; });
678 os << "] batch-size " << val->batchSize << " at " << val << ">";
679}
680
681static void print(ArrayStorage *val, llvm::raw_ostream &os) {
682 os << "<array [";
683 llvm::interleaveComma(val->array, os,
684 [&](const ElaboratorValue &val) { os << val; });
685 os << "] at " << val << ">";
686}
687
688static void print(SetStorage *val, llvm::raw_ostream &os) {
689 os << "<set {";
690 llvm::interleaveComma(val->set, os,
691 [&](const ElaboratorValue &val) { os << val; });
692 os << "} at " << val << ">";
693}
694
695static void print(const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
696 os << "<virtual-register " << val << " " << val->allowedRegs << ">";
697}
698
699static void print(const UniqueLabelStorage *val, llvm::raw_ostream &os) {
700 os << "<unique-label " << val << " " << val->name << ">";
701}
702
703static void print(const LabelValue &val, llvm::raw_ostream &os) {
704 os << "<label " << val.name << ">";
705}
706
707static void print(const TupleStorage *val, llvm::raw_ostream &os) {
708 os << "<tuple (";
709 llvm::interleaveComma(val->values, os,
710 [&](const ElaboratorValue &val) { os << val; });
711 os << ")>";
712}
713
714static void print(const MemoryStorage *val, llvm::raw_ostream &os) {
715 os << "<memory {" << ElaboratorValue(val->memoryBlock)
716 << ", size=" << val->size << ", alignment=" << val->alignment << "}>";
717}
718
719static void print(const MemoryBlockStorage *val, llvm::raw_ostream &os) {
720 os << "<memory-block {"
721 << ", address-width=" << val->baseAddress.getBitWidth()
722 << ", base-address=" << val->baseAddress
723 << ", end-address=" << val->endAddress << "}>";
724}
725
726static void print(const ValidationValue *val, llvm::raw_ostream &os) {
727 os << "<validation-value {type=" << val->type << ", ref=" << val->ref
728 << ", defaultValue=" << val->defaultValue << "}>";
729}
730
731static void print(const ValidationMuxedValue *val, llvm::raw_ostream &os) {
732 os << "<validation-muxed-value (" << val->value << ") at " << val->idx << ">";
733}
734
735static void print(const ImmediateConcatStorage *val, llvm::raw_ostream &os) {
736 os << "<immediate-concat [";
737 llvm::interleaveComma(val->operands, os,
738 [&](const ElaboratorValue &val) { os << val; });
739 os << "]>";
740}
741
742static void print(const ImmediateSliceStorage *val, llvm::raw_ostream &os) {
743 os << "<immediate-slice " << val->input << " from " << val->lowBit << ">";
744}
745
746static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
747 const ElaboratorValue &value) {
748 std::visit([&](auto val) { print(val, os); }, value);
749
750 return os;
751}
752
753#endif
754
755//===----------------------------------------------------------------------===//
756// Elaborator Value Materialization
757//===----------------------------------------------------------------------===//
758
759namespace {
760
761/// State that should be shared by all elaborator and materializer instances.
762struct SharedState {
763 SharedState(SymbolTable &table, unsigned seed) : table(table), rng(seed) {}
764
765 SymbolTable &table;
766 std::mt19937 rng;
767 Namespace names;
768 Internalizer internalizer;
769};
770
771/// A collection of state per RTG test.
772struct TestState {
773 /// The name of the test.
774 StringAttr name;
775
776 /// The context switches registered for this test.
777 MapVector<
778 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
779 SequenceStorage *>
780 contextSwitches;
781};
782
783/// Construct an SSA value from a given elaborated value.
784class Materializer {
785public:
786 Materializer(OpBuilder builder, TestState &testState,
787 SharedState &sharedState,
788 SmallVector<ElaboratorValue> &blockArgs)
789 : builder(builder), testState(testState), sharedState(sharedState),
790 blockArgs(blockArgs) {}
791
792 /// Materialize IR representing the provided `ElaboratorValue` and return the
793 /// `Value` or a null value on failure.
794 Value materialize(ElaboratorValue val, Location loc,
795 function_ref<InFlightDiagnostic()> emitError) {
796 auto iter = materializedValues.find(val);
797 if (iter != materializedValues.end())
798 return iter->second;
799
800 LLVM_DEBUG(llvm::dbgs() << "Materializing " << val);
801
802 // In debug mode, track whether values with identity were already
803 // materialized before and assert in such a situation.
804 Value res = std::visit(
805 [&](auto value) {
806 if constexpr (std::is_base_of_v<IdentityValue,
807 std::remove_pointer_t<
808 std::decay_t<decltype(value)>>>) {
809 if (identityValueRoot.contains(value)) {
810#ifndef NDEBUG
811 bool &materialized =
812 static_cast<IdentityValue *>(value)->alreadyMaterialized;
813 assert(!materialized && "must not already be materialized");
814 materialized = true;
815#endif
816
817 return visit(value, loc, emitError);
818 }
819
820 Value arg = builder.getBlock()->addArgument(value->type, loc);
821 blockArgs.push_back(val);
822 blockArgTypes.push_back(arg.getType());
823 materializedValues[val] = arg;
824 return arg;
825 }
826
827 return visit(value, loc, emitError);
828 },
829 val);
830
831 LLVM_DEBUG(llvm::dbgs() << " to\n" << res << "\n\n");
832
833 return res;
834 }
835
/// Materialize `op` at the current insertion point of this materializer.
///
/// If `op` is not in the same region as the materializer insertion point, a
/// clone is created at the materializer's insertion point by also
/// materializing the `ElaboratorValue`s for each operand just before it.
/// Otherwise, all operations after the materializer's insertion point are
/// marked for deletion (they are only erased in `finalize`) until `op` is
/// reached, and `op` is modified in place. An error is returned if the
/// operation is before the insertion point.
///
/// In both cases every operand is replaced by a value materialized from the
/// operand's `ElaboratorValue` in `state`, `state` is updated to map that new
/// value to the same `ElaboratorValue`, and the insertion point is moved to
/// just after the (possibly cloned) op.
///
/// `op` must not have regions and its results must not have uses (except
/// for `ValidateOp`). Every operand of `op` must have an entry in `state`.
LogicalResult materialize(Operation *op,
                          DenseMap<Value, ElaboratorValue> &state) {
  if (op->getNumRegions() > 0)
    return op->emitOpError("ops with nested regions must be elaborated away");

  // We don't support opaque values. If there is an SSA value that has a
  // use-site it needs an equivalent ElaborationValue representation.
  // NOTE: We could support cases where there is initially a use-site but that
  // op is guaranteed to be deleted during elaboration. Or the use-sites are
  // replaced with freshly materialized values from the ElaborationValue. But
  // then, why can't we delete the value defining op?
  // `ValidateOp` results are exempt from this restriction.
  for (auto res : op->getResults())
    if (!res.use_empty() && !isa<ValidateOp>(op))
      return op->emitOpError(
          "ops with results that have uses are not supported");

  if (op->getParentRegion() == builder.getBlock()->getParent()) {
    // We are doing in-place materialization, so mark all ops deleted until we
    // reach the one to be materialized and modify it in-place.
    deleteOpsUntil([&](auto iter) { return &*iter == op; });

    // Running off the end of the block means `op` was not found at or after
    // the insertion point.
    if (builder.getInsertionPoint() == builder.getBlock()->end())
      return op->emitError("operation did not occur after the current "
                           "materializer insertion point");

    LLVM_DEBUG(llvm::dbgs() << "Modifying in-place: " << *op << "\n\n");
  } else {
    LLVM_DEBUG(llvm::dbgs() << "Materializing a clone of " << *op << "\n\n");
    op = builder.clone(*op);
    // Operand values are materialized right before the clone.
    builder.setInsertionPoint(op);
  }

  for (auto &operand : op->getOpOperands()) {
    // Diagnostics emitted while materializing this operand get a note with
    // the operand number attached.
    auto emitError = [&]() {
      auto diag = op->emitError();
      diag.attachNote(op->getLoc())
          << "while materializing value for operand#"
          << operand.getOperandNumber();
      return diag;
    };

    // `state.at` requires the operand to have been elaborated already.
    auto elabVal = state.at(operand.get());
    Value val = materialize(elabVal, op->getLoc(), emitError);
    if (!val)
      return failure();

    // Record the elaborated counterpart of the freshly materialized value.
    state[val] = elabVal;
    operand.set(val);
  }

  builder.setInsertionPointAfter(op);
  return success();
}
895
896 /// Should be called once the `Region` is successfully materialized. No calls
897 /// to `materialize` should happen after this anymore.
898 void finalize() {
899 deleteOpsUntil([](auto iter) { return false; });
900
901 for (auto *op : llvm::reverse(toDelete))
902 op->erase();
903 }
904
905 /// Tell this materializer that it is responsible for materializing the given
906 /// identity value at the earliest position it is needed, and should't
907 /// request the value via block argument.
908 void registerIdentityValue(IdentityValue *val) {
909 identityValueRoot.insert(val);
910 }
911
912 ArrayRef<Type> getBlockArgTypes() const { return blockArgTypes; }
913
914 void map(ElaboratorValue eval, Value val) { materializedValues[eval] = val; }
915
916 template <typename OpTy, typename... Args>
917 OpTy create(Location location, Args &&...args) {
918 return OpTy::create(builder, location, std::forward<Args>(args)...);
919 }
920
private:
/// Elaborate the sequence family that `seq` was randomized from, unless that
/// was already done (defined elsewhere in this file). Returns the
/// `SequenceOp` to reference and fills `elabArgs` with the elaborated values
/// that have to be passed to it as arguments. A null op presumably signals
/// failure (callers check `!seqOp`).
SequenceOp elaborateSequence(const RandomizedSequenceStorage *seq,
                             SmallVector<ElaboratorValue> &elabArgs);
924
925 void deleteOpsUntil(function_ref<bool(Block::iterator)> stop) {
926 auto ip = builder.getInsertionPoint();
927 while (ip != builder.getBlock()->end() && !stop(ip)) {
928 LLVM_DEBUG(llvm::dbgs() << "Marking to be deleted: " << *ip << "\n\n");
929 toDelete.push_back(&*ip);
930
931 builder.setInsertionPointAfter(&*ip);
932 ip = builder.getInsertionPoint();
933 }
934 }
935
936 Value visit(TypedAttr val, Location loc,
937 function_ref<InFlightDiagnostic()> emitError) {
938 // For index attributes (and arithmetic operations on them) we use the
939 // index dialect.
940 if (auto intAttr = dyn_cast<IntegerAttr>(val);
941 intAttr && isa<IndexType>(val.getType())) {
942 Value res = index::ConstantOp::create(builder, loc, intAttr);
943 materializedValues[val] = res;
944 return res;
945 }
946
947 // For any other attribute, we just call the materializer of the dialect
948 // defining that attribute.
949 auto *op =
950 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
951 if (!op) {
952 emitError() << "materializer of dialect '"
953 << val.getDialect().getNamespace()
954 << "' unable to materialize value for attribute '" << val
955 << "'";
956 return Value();
957 }
958
959 Value res = op->getResult(0);
960 materializedValues[val] = res;
961 return res;
962 }
963
964 Value visit(size_t val, Location loc,
965 function_ref<InFlightDiagnostic()> emitError) {
966 Value res = index::ConstantOp::create(builder, loc, val);
967 materializedValues[val] = res;
968 return res;
969 }
970
971 Value visit(bool val, Location loc,
972 function_ref<InFlightDiagnostic()> emitError) {
973 Value res = index::BoolConstantOp::create(builder, loc, val);
974 materializedValues[val] = res;
975 return res;
976 }
977
978 Value visit(ArrayStorage *val, Location loc,
979 function_ref<InFlightDiagnostic()> emitError) {
980 SmallVector<Value> elements;
981 elements.reserve(val->array.size());
982 for (auto el : val->array) {
983 auto materialized = materialize(el, loc, emitError);
984 if (!materialized)
985 return Value();
986
987 elements.push_back(materialized);
988 }
989
990 Value res = ArrayCreateOp::create(builder, loc, val->type, elements);
991 materializedValues[val] = res;
992 return res;
993 }
994
995 Value visit(SetStorage *val, Location loc,
996 function_ref<InFlightDiagnostic()> emitError) {
997 SmallVector<Value> elements;
998 elements.reserve(val->set.size());
999 for (auto el : val->set) {
1000 auto materialized = materialize(el, loc, emitError);
1001 if (!materialized)
1002 return Value();
1003
1004 elements.push_back(materialized);
1005 }
1006
1007 auto res = SetCreateOp::create(builder, loc, val->type, elements);
1008 materializedValues[val] = res;
1009 return res;
1010 }
1011
1012 Value visit(BagStorage *val, Location loc,
1013 function_ref<InFlightDiagnostic()> emitError) {
1014 SmallVector<Value> values, weights;
1015 values.reserve(val->bag.size());
1016 weights.reserve(val->bag.size());
1017 for (auto [val, weight] : val->bag) {
1018 auto materializedVal = materialize(val, loc, emitError);
1019 auto materializedWeight = materialize(weight, loc, emitError);
1020 if (!materializedVal || !materializedWeight)
1021 return Value();
1022
1023 values.push_back(materializedVal);
1024 weights.push_back(materializedWeight);
1025 }
1026
1027 auto res = BagCreateOp::create(builder, loc, val->type, values, weights);
1028 materializedValues[val] = res;
1029 return res;
1030 }
1031
1032 Value visit(MemoryBlockStorage *val, Location loc,
1033 function_ref<InFlightDiagnostic()> emitError) {
1034 auto intType = builder.getIntegerType(val->baseAddress.getBitWidth());
1035 Value res = MemoryBlockDeclareOp::create(
1036 builder, loc, val->type, IntegerAttr::get(intType, val->baseAddress),
1037 IntegerAttr::get(intType, val->endAddress));
1038 materializedValues[val] = res;
1039 return res;
1040 }
1041
1042 Value visit(MemoryStorage *val, Location loc,
1043 function_ref<InFlightDiagnostic()> emitError) {
1044 auto memBlock = materialize(val->memoryBlock, loc, emitError);
1045 auto memSize = materialize(val->size, loc, emitError);
1046 auto memAlign = materialize(val->alignment, loc, emitError);
1047 if (!(memBlock && memSize && memAlign))
1048 return {};
1049
1050 Value res =
1051 MemoryAllocOp::create(builder, loc, memBlock, memSize, memAlign);
1052 materializedValues[val] = res;
1053 return res;
1054 }
1055
1056 Value visit(SequenceStorage *val, Location loc,
1057 function_ref<InFlightDiagnostic()> emitError) {
1058 emitError() << "materializing a non-randomized sequence not supported yet";
1059 return Value();
1060 }
1061
1062 Value visit(RandomizedSequenceStorage *val, Location loc,
1063 function_ref<InFlightDiagnostic()> emitError) {
1064 // To know which values we have to pass by argument (and not just pass all
1065 // that migth be used eagerly), we have to elaborate the sequence family if
1066 // not already done so.
1067 // We need to get back the sequence to reference, and the list of elaborated
1068 // values to pass as arguments.
1069 SmallVector<ElaboratorValue> elabArgs;
1070 SequenceOp seqOp = elaborateSequence(val, elabArgs);
1071 if (!seqOp)
1072 return {};
1073
1074 // Materialize all the values we need to pass as arguments and collect their
1075 // types.
1076 SmallVector<Value> args;
1077 SmallVector<Type> argTypes;
1078 for (auto arg : elabArgs) {
1079 Value materialized = materialize(arg, loc, emitError);
1080 if (!materialized)
1081 return {};
1082
1083 args.push_back(materialized);
1084 argTypes.push_back(materialized.getType());
1085 }
1086
1087 Value res = GetSequenceOp::create(
1088 builder, loc, SequenceType::get(builder.getContext(), argTypes),
1089 seqOp.getSymName());
1090
1091 // Only materialize a substitute_sequence op when we have arguments to
1092 // substitute since this op does not support 0 arguments.
1093 if (!args.empty())
1094 res = SubstituteSequenceOp::create(builder, loc, res, args);
1095
1096 res = RandomizeSequenceOp::create(builder, loc, res);
1097
1098 materializedValues[val] = res;
1099 return res;
1100 }
1101
1102 Value visit(InterleavedSequenceStorage *val, Location loc,
1103 function_ref<InFlightDiagnostic()> emitError) {
1104 SmallVector<Value> sequences;
1105 for (auto seqVal : val->sequences) {
1106 Value materialized = materialize(seqVal, loc, emitError);
1107 if (!materialized)
1108 return {};
1109
1110 sequences.push_back(materialized);
1111 }
1112
1113 if (sequences.size() == 1)
1114 return sequences[0];
1115
1116 Value res =
1117 InterleaveSequencesOp::create(builder, loc, sequences, val->batchSize);
1118 materializedValues[val] = res;
1119 return res;
1120 }
1121
1122 Value visit(VirtualRegisterStorage *val, Location loc,
1123 function_ref<InFlightDiagnostic()> emitError) {
1124 Value res = VirtualRegisterOp::create(builder, loc, val->allowedRegs);
1125 materializedValues[val] = res;
1126 return res;
1127 }
1128
1129 Value visit(UniqueLabelStorage *val, Location loc,
1130 function_ref<InFlightDiagnostic()> emitError) {
1131 Value res =
1132 LabelUniqueDeclOp::create(builder, loc, val->name, ValueRange());
1133 materializedValues[val] = res;
1134 return res;
1135 }
1136
1137 Value visit(const LabelValue &val, Location loc,
1138 function_ref<InFlightDiagnostic()> emitError) {
1139 Value res = LabelDeclOp::create(builder, loc, val.name, ValueRange());
1140 materializedValues[val] = res;
1141 return res;
1142 }
1143
1144 Value visit(TupleStorage *val, Location loc,
1145 function_ref<InFlightDiagnostic()> emitError) {
1146 SmallVector<Value> materialized;
1147 materialized.reserve(val->values.size());
1148 for (auto v : val->values)
1149 materialized.push_back(materialize(v, loc, emitError));
1150 Value res = TupleCreateOp::create(builder, loc, materialized);
1151 materializedValues[val] = res;
1152 return res;
1153 }
1154
1155 Value visit(ValidationValue *val, Location loc,
1156 function_ref<InFlightDiagnostic()> emitError) {
1157 SmallVector<Value> usedDefaultValues, elseValues;
1158 for (auto [dfltVal, elseVal] :
1159 llvm::zip(val->defaultUsedValues, val->elseValues)) {
1160 auto dfltMat = materialize(dfltVal, loc, emitError);
1161 auto elseMat = materialize(elseVal, loc, emitError);
1162 if (!dfltMat || !elseMat)
1163 return {};
1164
1165 usedDefaultValues.push_back(dfltMat);
1166 usedDefaultValues.push_back(elseMat);
1167 }
1168
1169 auto validateOp = ValidateOp::create(
1170 builder, loc, val->type, materialize(val->ref, loc, emitError),
1171 materialize(val->defaultValue, loc, emitError), val->id,
1172 usedDefaultValues, elseValues);
1173 materializedValues[val] = validateOp.getValue();
1174 return validateOp.getValue();
1175 }
1176
1177 Value visit(ValidationMuxedValue *val, Location loc,
1178 function_ref<InFlightDiagnostic()> emitError) {
1179 Value validateValue =
1180 materialize(const_cast<ValidationValue *>(val->value), loc, emitError);
1181 if (!validateValue)
1182 return {};
1183 auto validateOp = validateValue.getDefiningOp<ValidateOp>();
1184 if (!validateOp) {
1185 auto *defOp = validateValue.getDefiningOp();
1186 emitError()
1187 << "expected validate op for validation muxed value, but found "
1188 << (defOp ? defOp->getName().getStringRef() : "block argument");
1189 return {};
1190 }
1191 materializedValues[val] = validateOp.getValues()[val->idx];
1192 return validateOp.getValues()[val->idx];
1193 }
1194
1195 Value visit(ImmediateConcatStorage *val, Location loc,
1196 function_ref<InFlightDiagnostic()> emitError) {
1197 SmallVector<Value> operands;
1198 for (auto operand : val->operands) {
1199 auto materialized = materialize(operand, loc, emitError);
1200 if (!materialized)
1201 return {};
1202
1203 operands.push_back(materialized);
1204 }
1205
1206 Value res = ConcatImmediateOp::create(builder, loc, operands);
1207 materializedValues[val] = res;
1208 return res;
1209 }
1210
1211 Value visit(ImmediateSliceStorage *val, Location loc,
1212 function_ref<InFlightDiagnostic()> emitError) {
1213 Value input = materialize(val->input, loc, emitError);
1214 if (!input)
1215 return {};
1216
1217 Value res =
1218 SliceImmediateOp::create(builder, loc, val->type, input, val->lowBit);
1219 materializedValues[val] = res;
1220 return res;
1221 }
1222
private:
/// Cache values we have already materialized to reuse them later. We start
/// with an insertion point at the start of the block and cache the (updated)
/// insertion point such that future materializations can also reuse previous
/// materializations without running into dominance issues (or requiring
/// additional checks to avoid them).
DenseMap<ElaboratorValue, Value> materializedValues;

/// Cache the builder to continue insertions at their current insertion point
/// for the reason stated above.
OpBuilder builder;

/// Operations marked for deletion by `deleteOpsUntil`, in marking order;
/// `finalize` erases them in reverse order.
SmallVector<Operation *> toDelete;

/// Elaboration state held by reference (not owned by this materializer):
/// per-test state and state shared across materializers.
TestState &testState;
SharedState &sharedState;

/// Keep track of the block arguments we had to add to this materializer's
/// block for identity values and also remember which elaborator values are
/// expected to be passed as arguments from outside. `blockArgs` is a
/// reference, so the list itself is owned elsewhere.
SmallVector<ElaboratorValue> &blockArgs;
SmallVector<Type> blockArgTypes;

/// Identity values in this set are materialized by this materializer,
/// otherwise they are added as block arguments and the block that wants to
/// embed this sequence is expected to provide a value for it.
DenseSet<IdentityValue *> identityValueRoot;
};
1251
1252//===----------------------------------------------------------------------===//
1253// Elaboration Visitor
1254//===----------------------------------------------------------------------===//
1255
/// Tells the elaboration driver what should happen to an operation after it
/// has been visited.
enum class DeletionKind {
  /// The operation should be kept in the IR.
  Keep,
  /// The operation should be removed from the IR.
  Delete
};
1259
1260/// Interprets the IR to perform and lower the represented randomizations.
1261class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
1262public:
1264 using RTGBase::visitOp;
1265
1266 Elaborator(SharedState &sharedState, TestState &testState,
1267 Materializer &materializer,
1268 ContextResourceAttrInterface currentContext = {})
1269 : sharedState(sharedState), testState(testState),
1270 materializer(materializer), currentContext(currentContext) {}
1271
1272 template <typename ValueTy>
1273 inline ValueTy get(Value val) const {
1274 return std::get<ValueTy>(state.at(val));
1275 }
1276
1277 FailureOr<DeletionKind> visitPureOp(Operation *op) {
1278 SmallVector<Attribute> operands;
1279 for (auto operand : op->getOperands()) {
1280 auto evalValue = state[operand];
1281 if (std::holds_alternative<TypedAttr>(evalValue))
1282 operands.push_back(std::get<TypedAttr>(evalValue));
1283 else
1284 return visitUnhandledOp(op);
1285 }
1286
1287 SmallVector<OpFoldResult> results;
1288 if (failed(op->fold(operands, results)))
1289 return visitUnhandledOp(op);
1290
1291 // We don't support in-place folders.
1292 if (results.size() != op->getNumResults())
1293 return visitUnhandledOp(op);
1294
1295 for (auto [res, val] : llvm::zip(results, op->getResults())) {
1296 auto attr = llvm::dyn_cast_or_null<TypedAttr>(res.dyn_cast<Attribute>());
1297 if (!attr)
1298 return op->emitError(
1299 "only typed attributes supported for constant-like operations");
1300
1301 auto intAttr = dyn_cast<IntegerAttr>(attr);
1302 if (intAttr && isa<IndexType>(attr.getType()))
1303 state[op->getResult(0)] = size_t(intAttr.getInt());
1304 else if (intAttr && intAttr.getType().isSignlessInteger(1))
1305 state[op->getResult(0)] = bool(intAttr.getInt());
1306 else
1307 state[op->getResult(0)] = attr;
1308 }
1309
1310 return DeletionKind::Delete;
1311 }
1312
1313 /// Print a nice error message for operations we don't support yet.
1314 FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
1315 return op->emitOpError("elaboration not supported");
1316 }
1317
1318 FailureOr<DeletionKind> visitExternalOp(Operation *op) {
1319 auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
1320 if (op->hasTrait<OpTrait::ConstantLike>() || (memOp && memOp.hasNoEffect()))
1321 return visitPureOp(op);
1322
1323 // TODO: we only have this to be able to write tests for this pass without
1324 // having to add support for more operations for now, so it should be
1325 // removed once it is not necessary anymore for writing tests
1326 if (op->use_empty())
1327 return DeletionKind::Keep;
1328
1329 return visitUnhandledOp(op);
1330 }
1331
1332 FailureOr<DeletionKind> visitOp(ConstantOp op) { return visitPureOp(op); }
1333
1334 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
1335 SmallVector<ElaboratorValue> replacements;
1336 state[op.getResult()] =
1337 sharedState.internalizer.internalize<SequenceStorage>(
1338 op.getSequenceAttr().getAttr(), std::move(replacements));
1339 return DeletionKind::Delete;
1340 }
1341
1342 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
1343 auto *seq = get<SequenceStorage *>(op.getSequence());
1344
1345 SmallVector<ElaboratorValue> replacements(seq->args);
1346 for (auto replacement : op.getReplacements())
1347 replacements.push_back(state.at(replacement));
1348
1349 state[op.getResult()] =
1350 sharedState.internalizer.internalize<SequenceStorage>(
1351 seq->familyName, std::move(replacements));
1352
1353 return DeletionKind::Delete;
1354 }
1355
1356 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
1357 auto *seq = get<SequenceStorage *>(op.getSequence());
1358 auto *randomizedSeq =
1359 sharedState.internalizer.create<RandomizedSequenceStorage>(
1360 currentContext, seq);
1361 materializer.registerIdentityValue(randomizedSeq);
1362 state[op.getResult()] =
1363 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1364 randomizedSeq);
1365 return DeletionKind::Delete;
1366 }
1367
1368 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
1369 SmallVector<ElaboratorValue> sequences;
1370 for (auto seq : op.getSequences())
1371 sequences.push_back(get<InterleavedSequenceStorage *>(seq));
1372
1373 state[op.getResult()] =
1374 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1375 std::move(sequences), op.getBatchSize());
1376 return DeletionKind::Delete;
1377 }
1378
1379 // NOLINTNEXTLINE(misc-no-recursion)
1380 LogicalResult isValidContext(ElaboratorValue value, Operation *op) const {
1381 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
1382 auto *seq = std::get<RandomizedSequenceStorage *>(value);
1383 if (seq->context != currentContext) {
1384 auto err = op->emitError("attempting to place sequence derived from ")
1385 << seq->sequence->familyName.getValue() << " under context "
1386 << currentContext
1387 << ", but it was previously randomized for context ";
1388 if (seq->context)
1389 err << seq->context;
1390 else
1391 err << "'default'";
1392 return err;
1393 }
1394 return success();
1395 }
1396
1397 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
1398 for (auto val : interVal->sequences)
1399 if (failed(isValidContext(val, op)))
1400 return failure();
1401 return success();
1402 }
1403
1404 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
1405 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
1406 if (failed(isValidContext(seqVal, op)))
1407 return failure();
1408
1409 return DeletionKind::Keep;
1410 }
1411
1412 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1413 SetVector<ElaboratorValue> set;
1414 for (auto val : op.getElements())
1415 set.insert(state.at(val));
1416
1417 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1418 std::move(set), op.getSet().getType());
1419 return DeletionKind::Delete;
1420 }
1421
1422 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
1423 auto set = get<SetStorage *>(op.getSet())->set;
1424
1425 if (set.empty())
1426 return op->emitError("cannot select from an empty set");
1427
1428 size_t selected;
1429 if (auto intAttr =
1430 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1431 std::mt19937 customRng(intAttr.getInt());
1432 selected = getUniformlyInRange(customRng, 0, set.size() - 1);
1433 } else {
1434 selected = getUniformlyInRange(sharedState.rng, 0, set.size() - 1);
1435 }
1436
1437 state[op.getResult()] = set[selected];
1438 return DeletionKind::Delete;
1439 }
1440
1441 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
1442 auto original = get<SetStorage *>(op.getOriginal())->set;
1443 auto diff = get<SetStorage *>(op.getDiff())->set;
1444
1445 SetVector<ElaboratorValue> result(original);
1446 result.set_subtract(diff);
1447
1448 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1449 std::move(result), op.getResult().getType());
1450 return DeletionKind::Delete;
1451 }
1452
1453 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
1454 SetVector<ElaboratorValue> result;
1455 for (auto set : op.getSets())
1456 result.set_union(get<SetStorage *>(set)->set);
1457
1458 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1459 std::move(result), op.getType());
1460 return DeletionKind::Delete;
1461 }
1462
1463 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1464 auto size = get<SetStorage *>(op.getSet())->set.size();
1465 state[op.getResult()] = size;
1466 return DeletionKind::Delete;
1467 }
1468
// Compute the cartesian product of the input sets as a set of tuples, e.g.:
// {a0,a1} x {b0,b1} x {c0,c1} -> {(a0), (a1)} -> {(a0,b0), (a0,b1), (a1,b0),
// (a1,b1)} -> {(a0,b0,c0), (a0,b0,c1), (a0,b1,c0), (a0,b1,c1), (a1,b0,c0),
// (a1,b0,c1), (a1,b1,c0), (a1,b1,c1)}
// The tuples are not necessarily produced in the order shown above. If any
// input set is empty, the result is the empty set.
FailureOr<DeletionKind> visitOp(SetCartesianProductOp op) {
  SetVector<ElaboratorValue> result;
  // Partial tuples built so far; starts with a single empty tuple.
  SmallVector<SmallVector<ElaboratorValue>> tuples;
  tuples.push_back({});

  for (auto input : op.getInputs()) {
    auto &set = get<SetStorage *>(input)->set;
    // The product with an empty set is empty.
    if (set.empty()) {
      SetVector<ElaboratorValue> empty;
      state[op.getResult()] =
          sharedState.internalizer.internalize<SetStorage>(std::move(empty),
                                                           op.getType());
      return DeletionKind::Delete;
    }

    // Extend every existing partial tuple by every element of `set`: copies
    // of tuple `i` extended by all but the last element are appended to
    // `tuples`, and tuple `i` itself is extended in place by the last
    // element. `e` is fixed up front so the appended copies are not visited
    // again for this input.
    for (unsigned i = 0, e = tuples.size(); i < e; ++i) {
      for (auto setEl : set.getArrayRef().drop_back()) {
        tuples.push_back(tuples[i]);
        tuples.back().push_back(setEl);
      }
      tuples[i].push_back(set.back());
    }
  }

  // Turn every complete tuple into an internalized tuple value.
  for (auto &tup : tuples)
    result.insert(
        sharedState.internalizer.internalize<TupleStorage>(std::move(tup)));

  state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
      std::move(result), op.getType());
  return DeletionKind::Delete;
}
1504
1505 FailureOr<DeletionKind> visitOp(SetConvertToBagOp op) {
1506 auto set = get<SetStorage *>(op.getInput())->set;
1507 MapVector<ElaboratorValue, uint64_t> bag;
1508 for (auto val : set)
1509 bag.insert({val, 1});
1510 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1511 std::move(bag), op.getType());
1512 return DeletionKind::Delete;
1513 }
1514
1515 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1516 MapVector<ElaboratorValue, uint64_t> bag;
1517 for (auto [val, multiple] :
1518 llvm::zip(op.getElements(), op.getMultiples())) {
1519 // If the multiple is not stored as an AttributeValue, the elaboration
1520 // must have already failed earlier (since we don't have
1521 // unevaluated/opaque values).
1522 bag[state.at(val)] += get<size_t>(multiple);
1523 }
1524
1525 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1526 std::move(bag), op.getType());
1527 return DeletionKind::Delete;
1528 }
1529
1530 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
1531 auto bag = get<BagStorage *>(op.getBag())->bag;
1532
1533 if (bag.empty())
1534 return op->emitError("cannot select from an empty bag");
1535
1536 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1537 prefixSum.reserve(bag.size());
1538 uint32_t accumulator = 0;
1539 for (auto [val, weight] : bag) {
1540 accumulator += weight;
1541 prefixSum.push_back({val, accumulator});
1542 }
1543
1544 auto customRng = sharedState.rng;
1545 if (auto intAttr =
1546 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1547 customRng = std::mt19937(intAttr.getInt());
1548 }
1549
1550 auto idx = getUniformlyInRange(customRng, 0, accumulator - 1);
1551 auto *iter = llvm::upper_bound(
1552 prefixSum, idx,
1553 [](uint32_t a, const std::pair<ElaboratorValue, uint32_t> &b) {
1554 return a < b.second;
1555 });
1556
1557 state[op.getResult()] = iter->first;
1558 return DeletionKind::Delete;
1559 }
1560
1561 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1562 auto original = get<BagStorage *>(op.getOriginal())->bag;
1563 auto diff = get<BagStorage *>(op.getDiff())->bag;
1564
1565 MapVector<ElaboratorValue, uint64_t> result;
1566 for (const auto &el : original) {
1567 if (!diff.contains(el.first)) {
1568 result.insert(el);
1569 continue;
1570 }
1571
1572 if (op.getInf())
1573 continue;
1574
1575 auto toDiff = diff.lookup(el.first);
1576 if (el.second <= toDiff)
1577 continue;
1578
1579 result.insert({el.first, el.second - toDiff});
1580 }
1581
1582 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1583 std::move(result), op.getType());
1584 return DeletionKind::Delete;
1585 }
1586
1587 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1588 MapVector<ElaboratorValue, uint64_t> result;
1589 for (auto bag : op.getBags()) {
1590 auto val = get<BagStorage *>(bag)->bag;
1591 for (auto [el, multiple] : val)
1592 result[el] += multiple;
1593 }
1594
1595 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1596 std::move(result), op.getType());
1597 return DeletionKind::Delete;
1598 }
1599
1600 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1601 auto size = get<BagStorage *>(op.getBag())->bag.size();
1602 state[op.getResult()] = size;
1603 return DeletionKind::Delete;
1604 }
1605
1606 FailureOr<DeletionKind> visitOp(BagConvertToSetOp op) {
1607 auto bag = get<BagStorage *>(op.getInput())->bag;
1608 SetVector<ElaboratorValue> set;
1609 for (auto [k, v] : bag)
1610 set.insert(k);
1611 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1612 std::move(set), op.getType());
1613 return DeletionKind::Delete;
1614 }
1615
1616 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1617 auto *val = sharedState.internalizer.create<VirtualRegisterStorage>(
1618 op.getAllowedRegsAttr(), op.getType());
1619 state[op.getResult()] = val;
1620 materializer.registerIdentityValue(val);
1621 return DeletionKind::Delete;
1622 }
1623
1624 StringAttr substituteFormatString(StringAttr formatString,
1625 ValueRange substitutes) const {
1626 if (substitutes.empty() || formatString.empty())
1627 return formatString;
1628
1629 auto original = formatString.getValue().str();
1630 for (auto [i, subst] : llvm::enumerate(substitutes)) {
1631 size_t startPos = 0;
1632 std::string from = "{{" + std::to_string(i) + "}}";
1633 while ((startPos = original.find(from, startPos)) != std::string::npos) {
1634 auto substString = std::to_string(get<size_t>(subst));
1635 original.replace(startPos, from.length(), substString);
1636 }
1637 }
1638
1639 return StringAttr::get(formatString.getContext(), original);
1640 }
1641
1642 FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
1643 SmallVector<ElaboratorValue> array;
1644 array.reserve(op.getElements().size());
1645 for (auto val : op.getElements())
1646 array.emplace_back(state.at(val));
1647
1648 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1649 op.getResult().getType(), std::move(array));
1650 return DeletionKind::Delete;
1651 }
1652
1653 FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
1654 auto array = get<ArrayStorage *>(op.getArray())->array;
1655 size_t idx = get<size_t>(op.getIndex());
1656
1657 if (array.size() <= idx)
1658 return op->emitError("invalid to access index ")
1659 << idx << " of an array with " << array.size() << " elements";
1660
1661 state[op.getResult()] = array[idx];
1662 return DeletionKind::Delete;
1663 }
1664
1665 FailureOr<DeletionKind> visitOp(ArrayInjectOp op) {
1666 auto array = get<ArrayStorage *>(op.getArray())->array;
1667 size_t idx = get<size_t>(op.getIndex());
1668
1669 if (array.size() <= idx)
1670 return op->emitError("invalid to access index ")
1671 << idx << " of an array with " << array.size() << " elements";
1672
1673 array[idx] = state[op.getValue()];
1674 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1675 op.getResult().getType(), std::move(array));
1676 return DeletionKind::Delete;
1677 }
1678
1679 FailureOr<DeletionKind> visitOp(ArraySizeOp op) {
1680 auto array = get<ArrayStorage *>(op.getArray())->array;
1681 state[op.getResult()] = array.size();
1682 return DeletionKind::Delete;
1683 }
1684
1685 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
1686 auto substituted =
1687 substituteFormatString(op.getFormatStringAttr(), op.getArgs());
1688 state[op.getLabel()] = LabelValue(substituted);
1689 return DeletionKind::Delete;
1690 }
1691
1692 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1693 auto *val = sharedState.internalizer.create<UniqueLabelStorage>(
1694 substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
1695 state[op.getLabel()] = val;
1696 materializer.registerIdentityValue(val);
1697 return DeletionKind::Delete;
1698 }
1699
1700 FailureOr<DeletionKind> visitOp(LabelOp op) { return DeletionKind::Keep; }
1701
1702 FailureOr<DeletionKind> visitOp(TestSuccessOp op) {
1703 return DeletionKind::Keep;
1704 }
1705
1706 FailureOr<DeletionKind> visitOp(TestFailureOp op) {
1707 return DeletionKind::Keep;
1708 }
1709
1710 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1711 size_t lower = get<size_t>(op.getLowerBound());
1712 size_t upper = get<size_t>(op.getUpperBound());
1713 if (lower > upper)
1714 return op->emitError("cannot select a number from an empty range");
1715
1716 if (auto intAttr =
1717 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1718 std::mt19937 customRng(intAttr.getInt());
1719 state[op.getResult()] =
1720 size_t(getUniformlyInRange(customRng, lower, upper));
1721 } else {
1722 state[op.getResult()] =
1723 size_t(getUniformlyInRange(sharedState.rng, lower, upper));
1724 }
1725
1726 return DeletionKind::Delete;
1727 }
1728
1729 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1730 size_t input = get<size_t>(op.getInput());
1731 auto width = op.getType().getWidth();
1732 auto emitError = [&]() { return op->emitError(); };
1733 if (input > APInt::getAllOnes(width).getZExtValue())
1734 return emitError() << "cannot represent " << input << " with " << width
1735 << " bits";
1736
1737 state[op.getResult()] =
1738 ImmediateAttr::get(op.getContext(), APInt(width, input));
1739 return DeletionKind::Delete;
1740 }
1741
/// Embed the given sequence so that it runs under the target context.
///
/// A null current context is treated as the default context of the target's
/// type. If source and target contexts are equal, the sequence is randomized
/// and embedded directly. Otherwise a registered context switch is looked up
/// in this order: (from, to), (from, any), (any, to), (any, any). The switch
/// sequence is instantiated with (from, to, sequence) as arguments,
/// randomized for the target context, and embedded. Fails if no matching
/// switch is registered. The op itself is always deleted on success.
FailureOr<DeletionKind> visitOp(OnContextOp op) {
  ContextResourceAttrInterface from = currentContext,
                               to = cast<ContextResourceAttrInterface>(
                                   get<TypedAttr>(op.getContext()));
  if (!currentContext)
    from = DefaultContextAttr::get(op->getContext(), to.getType());

  auto emitError = [&]() {
    auto diag = op.emitError();
    diag.attachNote(op.getLoc())
        << "while materializing value for context switching for " << op;
    return diag;
  };

  // Already in the requested context: no switch sequence is needed.
  if (from == to) {
    Value seqVal = materializer.materialize(
        get<SequenceStorage *>(op.getSequence()), op.getLoc(), emitError);
    if (!seqVal)
      return failure();

    Value randSeqVal =
        materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
    materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
    return DeletionKind::Delete;
  }

  // Switch to the desired context.
  // First, check if a context switch is registered that has the concrete
  // context as source and target.
  auto *iter = testState.contextSwitches.find({from, to});

  // Try with 'any' context as target and the concrete context as source.
  if (iter == testState.contextSwitches.end())
    iter = testState.contextSwitches.find(
        {from, AnyContextAttr::get(op->getContext(), to.getType())});

  // Try with 'any' context as source and the concrete context as target.
  if (iter == testState.contextSwitches.end())
    iter = testState.contextSwitches.find(
        {AnyContextAttr::get(op->getContext(), from.getType()), to});

  // Try with 'any' context for both the source and the target.
  if (iter == testState.contextSwitches.end())
    iter = testState.contextSwitches.find(
        {AnyContextAttr::get(op->getContext(), from.getType()),
         AnyContextAttr::get(op->getContext(), to.getType())});

  // Otherwise, fail with an error because we couldn't find a user
  // specification on how to switch between the requested contexts.
  // NOTE: we could think about supporting context switching via intermediate
  // context, i.e., treat it as a transitive relation.
  if (iter == testState.contextSwitches.end())
    return op->emitError("no context transition registered to switch from ")
           << from << " to " << to;

  // Instantiate the switch sequence family with (from, to, sequence) and
  // randomize it as a fresh identity value bound to the target context.
  auto familyName = iter->second->familyName;
  SmallVector<ElaboratorValue> args{from, to,
                                    get<SequenceStorage *>(op.getSequence())};
  auto *seq = sharedState.internalizer.internalize<SequenceStorage>(
      familyName, std::move(args));
  auto *randSeq =
      sharedState.internalizer.create<RandomizedSequenceStorage>(to, seq);
  materializer.registerIdentityValue(randSeq);
  Value seqVal = materializer.materialize(randSeq, op.getLoc(), emitError);
  if (!seqVal)
    return failure();

  materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
  return DeletionKind::Delete;
}
1812
1813 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1814 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1815 get<SequenceStorage *>(op.getSequence());
1816 return DeletionKind::Delete;
1817 }
1818
1819 FailureOr<DeletionKind> visitOp(MemoryBlockDeclareOp op) {
1820 auto *val = sharedState.internalizer.create<MemoryBlockStorage>(
1821 op.getBaseAddress(), op.getEndAddress(), op.getType());
1822 state[op.getResult()] = val;
1823 materializer.registerIdentityValue(val);
1824 return DeletionKind::Delete;
1825 }
1826
1827 FailureOr<DeletionKind> visitOp(MemoryAllocOp op) {
1828 size_t size = get<size_t>(op.getSize());
1829 size_t alignment = get<size_t>(op.getAlignment());
1830 auto *memBlock = get<MemoryBlockStorage *>(op.getMemoryBlock());
1831 auto *val = sharedState.internalizer.create<MemoryStorage>(memBlock, size,
1832 alignment);
1833 state[op.getResult()] = val;
1834 materializer.registerIdentityValue(val);
1835 return DeletionKind::Delete;
1836 }
1837
1838 FailureOr<DeletionKind> visitOp(MemorySizeOp op) {
1839 auto *memory = get<MemoryStorage *>(op.getMemory());
1840 state[op.getResult()] = memory->size;
1841 return DeletionKind::Delete;
1842 }
1843
1844 FailureOr<DeletionKind> visitOp(TupleCreateOp op) {
1845 SmallVector<ElaboratorValue> values;
1846 values.reserve(op.getElements().size());
1847 for (auto el : op.getElements())
1848 values.push_back(state[el]);
1849
1850 state[op.getResult()] =
1851 sharedState.internalizer.internalize<TupleStorage>(std::move(values));
1852 return DeletionKind::Delete;
1853 }
1854
1855 FailureOr<DeletionKind> visitOp(TupleExtractOp op) {
1856 auto *tuple = get<TupleStorage *>(op.getTuple());
1857 state[op.getResult()] = tuple->values[op.getIndex().getZExtValue()];
1858 return DeletionKind::Delete;
1859 }
1860
1861 FailureOr<DeletionKind> visitOp(CommentOp op) { return DeletionKind::Keep; }
1862
1863 FailureOr<DeletionKind> visitOp(rtg::YieldOp op) {
1864 return DeletionKind::Keep;
1865 }
1866
1867 FailureOr<DeletionKind> visitOp(ValidateOp op) {
1868 SmallVector<ElaboratorValue> defaultUsedValues, elseValues;
1869 for (auto v : op.getDefaultUsedValues())
1870 defaultUsedValues.push_back(state.at(v));
1871
1872 for (auto v : op.getElseValues())
1873 elseValues.push_back(state.at(v));
1874
1875 auto *validationVal = sharedState.internalizer.create<ValidationValue>(
1876 op.getValue().getType(), state[op.getRef()],
1877 state[op.getDefaultValue()], op.getIdAttr(),
1878 std::move(defaultUsedValues), std::move(elseValues));
1879 state[op.getValue()] = validationVal;
1880 materializer.registerIdentityValue(validationVal);
1881 materializer.map(validationVal, op.getValue());
1882
1883 for (auto [i, val] : llvm::enumerate(op.getValues())) {
1884 auto *muxVal = sharedState.internalizer.create<ValidationMuxedValue>(
1885 val.getType(), validationVal, i);
1886 state[val] = muxVal;
1887 materializer.registerIdentityValue(muxVal);
1888 materializer.map(muxVal, val);
1889 }
1890
1891 return DeletionKind::Keep;
1892 }
1893
1894 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1895 bool cond = get<bool>(op.getCondition());
1896 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1897 if (toElaborate.empty())
1898 return DeletionKind::Delete;
1899
1900 // Just reuse this elaborator for the nested region because we need access
1901 // to the elaborated values outside the nested region (since it is not
1902 // isolated from above) and we want to materialize the region inline, thus
1903 // don't need a new materializer instance.
1904 SmallVector<ElaboratorValue> yieldedVals;
1905 if (failed(elaborate(toElaborate, {}, yieldedVals)))
1906 return failure();
1907
1908 // Map the results of the 'scf.if' to the yielded values.
1909 for (auto [res, out] : llvm::zip(op.getResults(), yieldedVals))
1910 state[res] = out;
1911
1912 return DeletionKind::Delete;
1913 }
1914
1915 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1916 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1917 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1918 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1919 return op->emitOpError("can only elaborate index type iterator");
1920
1921 auto lowerBound = get<size_t>(op.getLowerBound());
1922 auto step = get<size_t>(op.getStep());
1923 auto upperBound = get<size_t>(op.getUpperBound());
1924
1925 // Prepare for first iteration by assigning the nested regions block
1926 // arguments. We can just reuse this elaborator because we need access to
1927 // values elaborated in the parent region anyway and materialize everything
1928 // inline (i.e., don't need a new materializer).
1929 state[op.getInductionVar()] = lowerBound;
1930 for (auto [iterArg, initArg] :
1931 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1932 state[iterArg] = state.at(initArg);
1933
1934 // This loop performs the actual 'scf.for' loop iterations.
1935 SmallVector<ElaboratorValue> yieldedVals;
1936 for (size_t i = lowerBound; i < upperBound; i += step) {
1937 yieldedVals.clear();
1938 if (failed(elaborate(op.getBodyRegion(), {}, yieldedVals)))
1939 return failure();
1940
1941 // Prepare for the next iteration by updating the mapping of the nested
1942 // regions block arguments
1943 state[op.getInductionVar()] = i + step;
1944 for (auto [iterArg, prevIterArg] :
1945 llvm::zip(op.getRegionIterArgs(), yieldedVals))
1946 state[iterArg] = prevIterArg;
1947 }
1948
1949 // Transfer the previously yielded values to the for loop result values.
1950 for (auto [res, iterArg] :
1951 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1952 state[res] = state.at(iterArg);
1953
1954 return DeletionKind::Delete;
1955 }
1956
1957 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1958 return DeletionKind::Delete;
1959 }
1960
1961 FailureOr<DeletionKind> visitOp(arith::AddIOp op) {
1962 if (!isa<IndexType>(op.getType()))
1963 return op->emitError("only index operands supported");
1964
1965 size_t lhs = get<size_t>(op.getLhs());
1966 size_t rhs = get<size_t>(op.getRhs());
1967 state[op.getResult()] = lhs + rhs;
1968 return DeletionKind::Delete;
1969 }
1970
1971 FailureOr<DeletionKind> visitOp(arith::AndIOp op) {
1972 if (!op.getType().isSignlessInteger(1))
1973 return op->emitError("only 'i1' operands supported");
1974
1975 bool lhs = get<bool>(op.getLhs());
1976 bool rhs = get<bool>(op.getRhs());
1977 state[op.getResult()] = lhs && rhs;
1978 return DeletionKind::Delete;
1979 }
1980
1981 FailureOr<DeletionKind> visitOp(arith::XOrIOp op) {
1982 if (!op.getType().isSignlessInteger(1))
1983 return op->emitError("only 'i1' operands supported");
1984
1985 bool lhs = get<bool>(op.getLhs());
1986 bool rhs = get<bool>(op.getRhs());
1987 state[op.getResult()] = lhs != rhs;
1988 return DeletionKind::Delete;
1989 }
1990
1991 FailureOr<DeletionKind> visitOp(arith::OrIOp op) {
1992 if (!op.getType().isSignlessInteger(1))
1993 return op->emitError("only 'i1' operands supported");
1994
1995 bool lhs = get<bool>(op.getLhs());
1996 bool rhs = get<bool>(op.getRhs());
1997 state[op.getResult()] = lhs || rhs;
1998 return DeletionKind::Delete;
1999 }
2000
2001 FailureOr<DeletionKind> visitOp(arith::SelectOp op) {
2002 bool cond = get<bool>(op.getCondition());
2003 auto trueVal = state[op.getTrueValue()];
2004 auto falseVal = state[op.getFalseValue()];
2005 state[op.getResult()] = cond ? trueVal : falseVal;
2006 return DeletionKind::Delete;
2007 }
2008
2009 FailureOr<DeletionKind> visitOp(index::AddOp op) {
2010 size_t lhs = get<size_t>(op.getLhs());
2011 size_t rhs = get<size_t>(op.getRhs());
2012 state[op.getResult()] = lhs + rhs;
2013 return DeletionKind::Delete;
2014 }
2015
2016 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
2017 size_t lhs = get<size_t>(op.getLhs());
2018 size_t rhs = get<size_t>(op.getRhs());
2019 bool result;
2020 switch (op.getPred()) {
2021 case index::IndexCmpPredicate::EQ:
2022 result = lhs == rhs;
2023 break;
2024 case index::IndexCmpPredicate::NE:
2025 result = lhs != rhs;
2026 break;
2027 case index::IndexCmpPredicate::ULT:
2028 result = lhs < rhs;
2029 break;
2030 case index::IndexCmpPredicate::ULE:
2031 result = lhs <= rhs;
2032 break;
2033 case index::IndexCmpPredicate::UGT:
2034 result = lhs > rhs;
2035 break;
2036 case index::IndexCmpPredicate::UGE:
2037 result = lhs >= rhs;
2038 break;
2039 default:
2040 return op->emitOpError("elaboration not supported");
2041 }
2042 state[op.getResult()] = result;
2043 return DeletionKind::Delete;
2044 }
2045
2046 FailureOr<DeletionKind> visitOp(ConcatImmediateOp op) {
2047 bool anyValidationValues =
2048 llvm::any_of(op.getOperands(), [&](auto operand) {
2049 return std::holds_alternative<ValidationValue *>(state[operand]);
2050 });
2051
2052 // FIXME: this is a hack and this case should be handled generically for all
2053 // operations
2054 if (anyValidationValues) {
2055 SmallVector<ElaboratorValue> operands;
2056 for (auto operand : op.getOperands())
2057 operands.push_back(state[operand]);
2058 state[op.getResult()] =
2059 sharedState.internalizer.internalize<ImmediateConcatStorage>(
2060 std::move(operands));
2061 return DeletionKind::Delete;
2062 }
2063
2064 auto result = APInt::getZeroWidth();
2065 for (auto operand : op.getOperands())
2066 result = result.concat(
2067 cast<ImmediateAttr>(get<TypedAttr>(operand)).getValue());
2068
2069 state[op.getResult()] = ImmediateAttr::get(op.getContext(), result);
2070 return DeletionKind::Delete;
2071 }
2072
2073 FailureOr<DeletionKind> visitOp(SliceImmediateOp op) {
2074 // FIXME: this is a hack and this case should be handled generically for all
2075 // operations
2076 if (std::holds_alternative<ValidationValue *>(state[op.getInput()])) {
2077 state[op.getResult()] =
2078 sharedState.internalizer.internalize<ImmediateSliceStorage>(
2079 state[op.getInput()], op.getLowBit(), op.getResult().getType());
2080 return DeletionKind::Delete;
2081 }
2082
2083 auto inputValue =
2084 cast<ImmediateAttr>(get<TypedAttr>(op.getInput())).getValue();
2085 auto sliced = inputValue.extractBits(op.getResult().getType().getWidth(),
2086 op.getLowBit());
2087 state[op.getResult()] = ImmediateAttr::get(op.getContext(), sliced);
2088 return DeletionKind::Delete;
2089 }
2090
2091 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
2092 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
2093 .Case<
2094 // Arith ops
2095 arith::AddIOp, arith::XOrIOp, arith::AndIOp, arith::OrIOp,
2096 arith::SelectOp,
2097 // Index ops
2098 index::AddOp, index::CmpOp,
2099 // SCF ops
2100 scf::IfOp, scf::ForOp, scf::YieldOp>(
2101 [&](auto op) { return visitOp(op); })
2102 .Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); });
2103 }
2104
2105 // NOLINTNEXTLINE(misc-no-recursion)
2106 LogicalResult elaborate(Region &region,
2107 ArrayRef<ElaboratorValue> regionArguments,
2108 SmallVector<ElaboratorValue> &terminatorOperands) {
2109 if (region.getBlocks().size() > 1)
2110 return region.getParentOp()->emitOpError(
2111 "regions with more than one block are not supported");
2112
2113 for (auto [arg, elabArg] :
2114 llvm::zip(region.getArguments(), regionArguments))
2115 state[arg] = elabArg;
2116
2117 Block *block = &region.front();
2118 for (auto &op : *block) {
2119 auto result = dispatchOpVisitor(&op);
2120 if (failed(result))
2121 return failure();
2122
2123 if (*result == DeletionKind::Keep)
2124 if (failed(materializer.materialize(&op, state)))
2125 return failure();
2126
2127 LLVM_DEBUG({
2128 llvm::dbgs() << "Elaborated " << op << " to\n[";
2129
2130 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) {
2131 if (state.contains(res))
2132 llvm::dbgs() << state.at(res);
2133 else
2134 llvm::dbgs() << "unknown";
2135 });
2136
2137 llvm::dbgs() << "]\n\n";
2138 });
2139 }
2140
2141 if (region.front().mightHaveTerminator())
2142 for (auto val : region.front().getTerminator()->getOperands())
2143 terminatorOperands.push_back(state.at(val));
2144
2145 return success();
2146 }
2147
private:
  // State shared between all elaborator instances.
  SharedState &sharedState;

  // State specific to one RTG test and the sequences placed within it (e.g.,
  // the context switches registered via `rtg.context_switch`).
  TestState &testState;

  // Allows us to materialize ElaboratorValues to the IR operations necessary to
  // obtain an SSA value representing that elaborated value.
  Materializer &materializer;

  // A map from SSA values to their elaborated values. An elaborated value may
  // be a plain bool or index (size_t), a TypedAttr, or a pointer to interned /
  // created storage (e.g. SequenceStorage *, TupleStorage *).
  // NOTE: references returned by `state[...]` / `state.at(...)` are
  // invalidated when inserting a new key grows the map, so copy a value out
  // before assigning it to a key that may not exist yet.
  DenseMap<Value, ElaboratorValue> state;

  // The current context we are elaborating under. A null value means the
  // default context (see the OnContextOp visitor).
  ContextResourceAttrInterface currentContext;
};
} // namespace
2166
/// Elaborate one randomized instance of a sequence family into a new
/// `SequenceOp` and return it, or a null op if elaboration failed.
///
/// The family op (looked up by name in the shared symbol table) is cloned
/// without its body, inserted right before the family, and given a fresh
/// unique name. The family body is then elaborated into the clone under
/// `seq->context`, with `seq->sequence->args` bound to the family's block
/// arguments.
///
/// NOTE(review): `elabArgs` is handed to the inner Materializer; presumably it
/// receives the elaborated values backing the new sequence's block arguments.
/// Confirm against the Materializer implementation.
SequenceOp
Materializer::elaborateSequence(const RandomizedSequenceStorage *seq,
                                SmallVector<ElaboratorValue> &elabArgs) {
  auto familyOp =
      sharedState.table.lookup<SequenceOp>(seq->sequence->familyName);
  // TODO: don't clone if this is the only remaining reference to this
  // sequence
  OpBuilder builder(familyOp);
  // Clone the signature only (placed before `familyOp`); the body is rebuilt
  // by elaboration into a fresh, empty block.
  auto seqOp = builder.cloneWithoutRegions(familyOp);
  auto name = sharedState.names.newName(seq->sequence->familyName.getValue());
  seqOp.setSymName(name);
  seqOp.getBodyRegion().emplaceBlock();
  // `name` comes from the shared namespace, so the symbol table must not need
  // to rename it.
  sharedState.table.insert(seqOp);
  assert(seqOp.getSymName() == name && "should not have been renamed");

  LLVM_DEBUG(llvm::dbgs() << "\n=== Elaborating sequence family @"
                          << familyOp.getSymName() << " into @"
                          << seqOp.getSymName() << " under context "
                          << seq->context << "\n\n");

  // A dedicated materializer writes into the clone's body, while the
  // elaborator reads the family's body.
  Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()), testState,
                            sharedState, elabArgs);
  Elaborator elaborator(sharedState, testState, materializer, seq->context);
  SmallVector<ElaboratorValue> yieldedVals;
  // NOTE(review): on failure the partially built clone stays in the module and
  // symbol table; callers are presumably expected to abort the pass.
  if (failed(elaborator.elaborate(familyOp.getBodyRegion(), seq->sequence->args,
                                  yieldedVals)))
    return {};

  // The sequence type is derived from the block arguments the materializer
  // ended up creating.
  seqOp.setSequenceType(
      SequenceType::get(builder.getContext(), materializer.getBlockArgTypes()));
  materializer.finalize();

  return seqOp;
}
2201
2202//===----------------------------------------------------------------------===//
2203// Elaborator Pass
2204//===----------------------------------------------------------------------===//
2205
namespace {
/// Pass driver: first expands target-less tests into one clone per compatible
/// target, then elaborates all targets and all tests in the module.
struct ElaborationPass
    : public rtg::impl::ElaborationPassBase<ElaborationPass> {
  using Base::Base;

  void runOnOperation() override;
  /// Clone each test without a target attribute once per target whose type
  /// provides every entry (same name and type) of the test's target type.
  void matchTestsAgainstTargets(SymbolTable &table);
  /// Elaborate every target, then every test that has a target assigned.
  LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
};
} // namespace
2216
2217void ElaborationPass::runOnOperation() {
2218 auto moduleOp = getOperation();
2219 SymbolTable table(moduleOp);
2220
2221 matchTestsAgainstTargets(table);
2222
2223 if (failed(elaborateModule(moduleOp, table)))
2224 return signalPassFailure();
2225}
2226
2227void ElaborationPass::matchTestsAgainstTargets(SymbolTable &table) {
2228 auto moduleOp = getOperation();
2229
2230 for (auto test : llvm::make_early_inc_range(moduleOp.getOps<TestOp>())) {
2231 if (test.getTargetAttr())
2232 continue;
2233
2234 bool matched = false;
2235
2236 for (auto target : moduleOp.getOps<TargetOp>()) {
2237 // Check if the target type is a subtype of the test's target type
2238 // This means that for each entry in the test's target type, there must be
2239 // a corresponding entry with the same name and type in the target's type
2240 bool isSubtype = true;
2241 auto testEntries = test.getTargetType().getEntries();
2242 auto targetEntries = target.getTarget().getEntries();
2243
2244 // Check if target is a subtype of test requirements
2245 // Since entries are sorted by name, we can do this in a single pass
2246 size_t targetIdx = 0;
2247 for (auto testEntry : testEntries) {
2248 // Find the matching entry in target entries.
2249 while (targetIdx < targetEntries.size() &&
2250 targetEntries[targetIdx].name.getValue() <
2251 testEntry.name.getValue())
2252 targetIdx++;
2253
2254 // Check if we found a matching entry with the same name and type
2255 if (targetIdx >= targetEntries.size() ||
2256 targetEntries[targetIdx].name != testEntry.name ||
2257 targetEntries[targetIdx].type != testEntry.type) {
2258 isSubtype = false;
2259 break;
2260 }
2261 }
2262
2263 if (!isSubtype)
2264 continue;
2265
2266 IRRewriter rewriter(test);
2267 // Create a new test for the matched target
2268 auto newTest = cast<TestOp>(test->clone());
2269 newTest.setSymName(test.getSymName().str() + "_" +
2270 target.getSymName().str());
2271
2272 // Set the target symbol specifying that this test is only suitable for
2273 // that target.
2274 newTest.setTargetAttr(target.getSymNameAttr());
2275
2276 table.insert(newTest, rewriter.getInsertionPoint());
2277 matched = true;
2278 }
2279
2280 if (matched || deleteUnmatchedTests)
2281 test->erase();
2282 }
2283}
2284
2285static bool onlyLegalToMaterializeInTarget(Type type) {
2286 return isa<MemoryBlockType, ContextResourceTypeInterface>(type);
2287}
2288
2289LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
2290 SymbolTable &table) {
2291 SharedState state(table, seed);
2292
2293 // Update the name cache
2294 state.names.add(moduleOp);
2295
2296 struct TargetElabResult {
2297 DictType targetType;
2298 SmallVector<ElaboratorValue> yields;
2299 TestState testState;
2300 };
2301
2302 // Map to store elaborated targets
2303 DenseMap<StringAttr, TargetElabResult> targetMap;
2304 for (auto targetOp : moduleOp.getOps<TargetOp>()) {
2305 LLVM_DEBUG(llvm::dbgs() << "=== Elaborating target @"
2306 << targetOp.getSymName() << "\n\n");
2307
2308 auto &result = targetMap[targetOp.getSymNameAttr()];
2309 result.targetType = targetOp.getTarget();
2310
2311 SmallVector<ElaboratorValue> blockArgs;
2312 Materializer targetMaterializer(OpBuilder::atBlockBegin(targetOp.getBody()),
2313 result.testState, state, blockArgs);
2314 Elaborator targetElaborator(state, result.testState, targetMaterializer);
2315
2316 // Elaborate the target
2317 if (failed(targetElaborator.elaborate(targetOp.getBodyRegion(), {},
2318 result.yields)))
2319 return failure();
2320 }
2321
2322 // Initialize the worklist with the test ops since they cannot be placed by
2323 // other ops.
2324 for (auto testOp : moduleOp.getOps<TestOp>()) {
2325 // Skip tests without a target attribute - these couldn't be matched
2326 // against any target but can be useful to keep around for reporting
2327 // purposes.
2328 if (!testOp.getTargetAttr())
2329 continue;
2330
2331 LLVM_DEBUG(llvm::dbgs()
2332 << "\n=== Elaborating test @" << testOp.getTemplateName()
2333 << " for target @" << *testOp.getTarget() << "\n\n");
2334
2335 // Get the target for this test
2336 auto targetResult = targetMap[testOp.getTargetAttr()];
2337 TestState testState = targetResult.testState;
2338 testState.name = testOp.getSymNameAttr();
2339
2340 SmallVector<ElaboratorValue> filteredYields;
2341 unsigned i = 0;
2342 for (auto [entry, yield] :
2343 llvm::zip(targetResult.targetType.getEntries(), targetResult.yields)) {
2344 if (i >= testOp.getTargetType().getEntries().size())
2345 break;
2346
2347 if (entry.name == testOp.getTargetType().getEntries()[i].name) {
2348 filteredYields.push_back(yield);
2349 ++i;
2350 }
2351 }
2352
2353 // Now elaborate the test with the same state, passing the target yield
2354 // values as arguments
2355 SmallVector<ElaboratorValue> blockArgs;
2356 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()),
2357 testState, state, blockArgs);
2358
2359 for (auto [arg, val] :
2360 llvm::zip(testOp.getBody()->getArguments(), filteredYields))
2361 if (onlyLegalToMaterializeInTarget(arg.getType()))
2362 materializer.map(val, arg);
2363
2364 Elaborator elaborator(state, testState, materializer);
2365 SmallVector<ElaboratorValue> ignore;
2366 if (failed(elaborator.elaborate(testOp.getBodyRegion(), filteredYields,
2367 ignore)))
2368 return failure();
2369
2370 materializer.finalize();
2371 }
2372
2373 return success();
2374}
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the in specified range.
static bool onlyLegalToMaterializeInTarget(Type type)
static void print(TypedAttr val, llvm::raw_ostream &os)
static StringAttr substituteFormatString(StringAttr formatString, ArrayRef< Attribute > substitutes)
Definition RTGOps.cpp:895
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
This helps visit TypeOp nodes.
Definition RTGVisitors.h:29
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:90
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
static llvm::hash_code hash_value(const ModulePort &port)
Definition HWTypes.h:38
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition Utils.h:32
llvm::hash_code hash_value(const DenseSet< T > &set)
Definition rtg.py:1
Definition seq.py:1
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)