Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ElaborationPass.cpp
Go to the documentation of this file.
1//===- ElaborationPass.cpp - RTG ElaborationPass implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass elaborates the random parts of the RTG dialect.
10// It performs randomization top-down, i.e., random constructs in a sequence
11// that is invoked multiple times can yield different randomization results
12// for each invocation.
13//
14//===----------------------------------------------------------------------===//
15
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Index/IR/IndexDialect.h"
23#include "mlir/Dialect/Index/IR/IndexOps.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/DenseMapInfoVariant.h"
28#include "llvm/Support/Debug.h"
29#include <queue>
30#include <random>
31
namespace circt {
namespace rtg {
// Request the definition of the ElaborationPass base from the generated
// RTGPasses.h.inc; the GEN_PASS_DEF_ELABORATIONPASS macro selects which part
// of the generated file is emitted.
#define GEN_PASS_DEF_ELABORATIONPASS
#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
} // namespace rtg
} // namespace circt
38
39using namespace mlir;
40using namespace circt;
41using namespace circt::rtg;
42using llvm::MapVector;
43
44#define DEBUG_TYPE "rtg-elaboration"
45
46//===----------------------------------------------------------------------===//
47// Uniform Distribution Helper
48//
49// Simplified version of
50// https://github.com/llvm/llvm-project/blob/main/libcxx/include/__random/uniform_int_distribution.h
51// We use our custom version here to get the same results when compiled with
52// different compiler versions and standard libraries.
53//===----------------------------------------------------------------------===//
54
/// Compute the mask selecting the low bits of a 32-bit engine word needed to
/// produce a `w`-bit random value. This mirrors the word-width computation of
/// libc++'s `__independent_bits_engine`: the `w` bits are drawn from
/// `n = ceil(w / 32)` words of `w0 = w / n` bits each, and the mask keeps the
/// low `w0` bits of a word. Returns 0 for `w == 0`.
static uint32_t computeMask(size_t w) {
  // No bits requested: empty mask. This also avoids dividing by `n == 0`.
  if (w == 0)
    return 0;
  size_t n = w / 32 + (w % 32 != 0);
  // 1 <= w0 <= 32 holds for every w >= 1, so the shift below is in [0, 31].
  size_t w0 = w / n;
  return uint32_t(~0) >> (32 - w0);
}

/// Get a number uniformly at random in the inclusive range [a, b].
///
/// Uses rejection sampling: draw `width` random bits, where `width` is the
/// number of bits needed to represent `b - a`, and retry until the drawn
/// offset lies inside the range. Since `2^(width-1) < b - a + 1 <= 2^width`,
/// each draw is accepted with probability greater than 1/2.
///
/// If [a, b] spans all 2^32 values (`b - a + 1` wraps to 0), a raw 32-bit
/// engine output is returned.
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) {
  const uint32_t diff = b - a + 1;
  if (diff == 1)
    return a;

  if (diff == 0)
    return rng();

  // Bit width of the largest offset `diff - 1`, i.e., ceil(log2(diff)) for
  // diff >= 2.
  uint32_t width = 0;
  for (uint32_t rest = diff - 1; rest != 0; rest >>= 1)
    ++width;

  // The mask must be derived from the bit width, not from `diff` itself:
  // masking to `diff` bits makes small ranges reject nearly every draw and
  // can make the top of very large ranges unreachable.
  uint32_t mask = computeMask(width);
  uint32_t u;
  do {
    u = rng() & mask;
  } while (u >= diff);

  return u + a;
}
83
84//===----------------------------------------------------------------------===//
85// Elaborator Value
86//===----------------------------------------------------------------------===//
87
88namespace {
89struct ArrayStorage;
90struct BagStorage;
91struct SequenceStorage;
92struct RandomizedSequenceStorage;
93struct InterleavedSequenceStorage;
94struct SetStorage;
95struct VirtualRegisterStorage;
96struct UniqueLabelStorage;
97struct TupleStorage;
98struct MemoryStorage;
99struct MemoryBlockStorage;
100struct ValidationValue;
101struct ValidationMuxedValue;
102struct ImmediateConcatStorage;
103struct ImmediateSliceStorage;
104
105/// Simple wrapper around a 'StringAttr' such that we know to materialize it as
106/// a label declaration instead of calling the builtin dialect constant
107/// materializer.
108struct LabelValue {
109 LabelValue(StringAttr name) : name(name) {}
110
111 bool operator==(const LabelValue &other) const { return name == other.name; }
112
113 /// The label name.
114 StringAttr name;
115};
116
117/// The abstract base class for elaborated values.
118using ElaboratorValue = std::variant<
119 TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
120 RandomizedSequenceStorage *, InterleavedSequenceStorage *, SetStorage *,
121 VirtualRegisterStorage *, UniqueLabelStorage *, LabelValue, ArrayStorage *,
122 TupleStorage *, MemoryStorage *, MemoryBlockStorage *, ValidationValue *,
123 ValidationMuxedValue *, ImmediateConcatStorage *, ImmediateSliceStorage *>;
124
125// NOLINTNEXTLINE(readability-identifier-naming)
126llvm::hash_code hash_value(const LabelValue &val) {
127 return llvm::hash_value(val.name);
128}
129
130// NOLINTNEXTLINE(readability-identifier-naming)
131llvm::hash_code hash_value(const ElaboratorValue &val) {
132 return std::visit(
133 [&val](const auto &alternative) {
134 // Include index in hash to make sure same value as different
135 // alternatives don't collide.
136 return llvm::hash_combine(val.index(), alternative);
137 },
138 val);
139}
140
141} // namespace
142
143namespace llvm {
144
145template <>
146struct DenseMapInfo<bool> {
147 static inline unsigned getEmptyKey() { return false; }
148 static inline unsigned getTombstoneKey() { return true; }
149 static unsigned getHashValue(const bool &val) { return val * 37U; }
150
151 static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; }
152};
153template <>
154struct DenseMapInfo<LabelValue> {
155 static inline LabelValue getEmptyKey() {
157 }
158 static inline LabelValue getTombstoneKey() {
160 }
161 static unsigned getHashValue(const LabelValue &val) {
162 return hash_value(val);
163 }
164
165 static bool isEqual(const LabelValue &lhs, const LabelValue &rhs) {
166 return lhs == rhs;
167 }
168};
169
170} // namespace llvm
171
172//===----------------------------------------------------------------------===//
173// Elaborator Value Storages and Internalization
174//===----------------------------------------------------------------------===//
175
176namespace {
177
/// Cheap key type for the internalization sets: a cached hash code plus a
/// pointer to the internalized storage. Keeping the hash next to the pointer
/// allows probing the set before the storage itself has to be allocated and
/// constructed, which only happens if no equal storage exists yet.
template <typename StorageTy>
struct HashedStorage {
  HashedStorage(unsigned hash = 0, StorageTy *ptr = nullptr)
      : hashcode(hash), storage(ptr) {}

  unsigned hashcode;
  StorageTy *storage;
};
190
191/// A DenseMapInfo implementation to support 'insert_as' for the internalization
192/// sets. When comparing two 'HashedStorage's we can just compare the already
193/// internalized storage pointers, otherwise we have to call the costly
194/// 'isEqual' method.
195template <typename StorageTy>
196struct StorageKeyInfo {
197 static inline HashedStorage<StorageTy> getEmptyKey() {
198 return HashedStorage<StorageTy>(0,
199 DenseMapInfo<StorageTy *>::getEmptyKey());
200 }
201 static inline HashedStorage<StorageTy> getTombstoneKey() {
202 return HashedStorage<StorageTy>(
203 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
204 }
205
206 static inline unsigned getHashValue(const HashedStorage<StorageTy> &key) {
207 return key.hashcode;
208 }
209 static inline unsigned getHashValue(const StorageTy &key) {
210 return key.hashcode;
211 }
212
213 static inline bool isEqual(const HashedStorage<StorageTy> &lhs,
214 const HashedStorage<StorageTy> &rhs) {
215 return lhs.storage == rhs.storage;
216 }
217 static inline bool isEqual(const StorageTy &lhs,
218 const HashedStorage<StorageTy> &rhs) {
219 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
220 return false;
221
222 return lhs.isEqual(rhs.storage);
223 }
224};
225
226// Values with structural equivalence intended to be internalized.
227//===----------------------------------------------------------------------===//
228
229/// Storage object for an '!rtg.set<T>'.
230struct SetStorage {
231 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
232 : hashcode(llvm::hash_combine(
233 type, llvm::hash_combine_range(set.begin(), set.end()))),
234 set(std::move(set)), type(type) {}
235
236 bool isEqual(const SetStorage *other) const {
237 return hashcode == other->hashcode && set == other->set &&
238 type == other->type;
239 }
240
241 // The cached hashcode to avoid repeated computations.
242 const unsigned hashcode;
243
244 // Stores the elaborated values contained in the set.
245 const SetVector<ElaboratorValue> set;
246
247 // Store the set type such that we can materialize this evaluated value
248 // also in the case where the set is empty.
249 const Type type;
250};
251
252/// Storage object for an '!rtg.bag<T>'.
253struct BagStorage {
254 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
255 : hashcode(llvm::hash_combine(
256 type, llvm::hash_combine_range(bag.begin(), bag.end()))),
257 bag(std::move(bag)), type(type) {}
258
259 bool isEqual(const BagStorage *other) const {
260 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
261 type == other->type;
262 }
263
264 // The cached hashcode to avoid repeated computations.
265 const unsigned hashcode;
266
267 // Stores the elaborated values contained in the bag with their number of
268 // occurences.
269 const MapVector<ElaboratorValue, uint64_t> bag;
270
271 // Store the bag type such that we can materialize this evaluated value
272 // also in the case where the bag is empty.
273 const Type type;
274};
275
276/// Storage object for an '!rtg.sequence'.
277struct SequenceStorage {
278 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
279 : hashcode(llvm::hash_combine(
280 familyName, llvm::hash_combine_range(args.begin(), args.end()))),
281 familyName(familyName), args(std::move(args)) {}
282
283 bool isEqual(const SequenceStorage *other) const {
284 return hashcode == other->hashcode && familyName == other->familyName &&
285 args == other->args;
286 }
287
288 // The cached hashcode to avoid repeated computations.
289 const unsigned hashcode;
290
291 // The name of the sequence family this sequence is derived from.
292 const StringAttr familyName;
293
294 // The elaborator values used during substitution of the sequence family.
295 const SmallVector<ElaboratorValue> args;
296};
297
298/// Storage object for interleaved '!rtg.randomized_sequence'es.
299struct InterleavedSequenceStorage {
300 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
301 uint32_t batchSize)
302 : sequences(std::move(sequences)), batchSize(batchSize),
303 hashcode(llvm::hash_combine(
304 llvm::hash_combine_range(sequences.begin(), sequences.end()),
305 batchSize)) {}
306
307 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
308 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
309 hashcode(llvm::hash_combine(
310 llvm::hash_combine_range(sequences.begin(), sequences.end()),
311 batchSize)) {}
312
313 bool isEqual(const InterleavedSequenceStorage *other) const {
314 return hashcode == other->hashcode && sequences == other->sequences &&
315 batchSize == other->batchSize;
316 }
317
318 const SmallVector<ElaboratorValue> sequences;
319
320 const uint32_t batchSize;
321
322 // The cached hashcode to avoid repeated computations.
323 const unsigned hashcode;
324};
325
326/// Storage object for '!rtg.array`-typed values.
327struct ArrayStorage {
328 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
329 : hashcode(llvm::hash_combine(
330 type, llvm::hash_combine_range(array.begin(), array.end()))),
331 type(type), array(array) {}
332
333 bool isEqual(const ArrayStorage *other) const {
334 return hashcode == other->hashcode && type == other->type &&
335 array == other->array;
336 }
337
338 // The cached hashcode to avoid repeated computations.
339 const unsigned hashcode;
340
341 /// The type of the array. This is necessary because an array of size 0
342 /// cannot be reconstructed without knowing the original element type.
343 const Type type;
344
345 /// The label name. For unique labels, this is just the prefix.
346 const SmallVector<ElaboratorValue> array;
347};
348
349/// Storage object for 'tuple`-typed values.
350struct TupleStorage {
351 TupleStorage(SmallVector<ElaboratorValue> &&values)
352 : hashcode(llvm::hash_combine_range(values.begin(), values.end())),
353 values(std::move(values)) {}
354
355 bool isEqual(const TupleStorage *other) const {
356 return hashcode == other->hashcode && values == other->values;
357 }
358
359 // The cached hashcode to avoid repeated computations.
360 const unsigned hashcode;
361
362 const SmallVector<ElaboratorValue> values;
363};
364
365/// Storage object for immediate concatenation operations.
366/// FIXME: opaque values should be handled generically
367struct ImmediateConcatStorage {
368 ImmediateConcatStorage(SmallVector<ElaboratorValue> &&operands)
369 : hashcode(llvm::hash_combine_range(operands.begin(), operands.end())),
370 operands(std::move(operands)) {}
371
372 bool isEqual(const ImmediateConcatStorage *other) const {
373 return hashcode == other->hashcode && operands == other->operands;
374 }
375
376 const unsigned hashcode;
377 const SmallVector<ElaboratorValue> operands;
378};
379
380/// Storage object for immediate slice operations.
381/// FIXME: opaque values should be handled generically
382struct ImmediateSliceStorage {
383 ImmediateSliceStorage(ElaboratorValue input, unsigned lowBit, Type type)
384 : hashcode(llvm::hash_combine(input, lowBit, type)), input(input),
385 lowBit(lowBit), type(type) {}
386
387 bool isEqual(const ImmediateSliceStorage *other) const {
388 return hashcode == other->hashcode && input == other->input &&
389 lowBit == other->lowBit && type == other->type;
390 }
391
392 const unsigned hashcode;
393 const ElaboratorValue input;
394 const unsigned lowBit;
395 const Type type;
396};
397
398// Values with identity not intended to be internalized.
399//===----------------------------------------------------------------------===//
400
401/// Base class for storages that represent values with identity, i.e., two
402/// values are not considered equivalent if they are structurally the same, but
403/// each definition of such a value is unique. E.g., unique labels or virtual
404/// registers. These cannot be materialized anew in each nested sequence, but
405/// must be passed as arguments.
406struct IdentityValue {
407
408 IdentityValue(Type type) : type(type) {}
409
410#ifndef NDEBUG
411
412 /// In debug mode, track whether this value was already materialized to
413 /// assert if it's illegally materialized multiple times.
414 ///
415 /// Instead of deleting operations defining these values and materializing
416 /// them again, we could retain the operations. However, we still need
417 /// specific storages to represent these values in some cases, e.g., to get
418 /// the size of a memory allocation. Also, elaboration of nested control-flow
419 /// regions (e.g. `scf.for`) relies on materialization of such values lazily
420 /// instead of cloning the operations eagerly.
421 bool alreadyMaterialized = false;
422
423#endif
424
425 const Type type;
426};
427
428/// Represents a unique virtual register.
429struct VirtualRegisterStorage : IdentityValue {
430 VirtualRegisterStorage(ArrayAttr allowedRegs, Type type)
431 : IdentityValue(type), allowedRegs(allowedRegs) {}
432
433 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
434 // VirtualRegisters are never internalized.
435
436 // The list of fixed registers allowed to be selected for this virtual
437 // register.
438 const ArrayAttr allowedRegs;
439};
440
441struct UniqueLabelStorage : IdentityValue {
442 UniqueLabelStorage(StringAttr name)
443 : IdentityValue(LabelType::get(name.getContext())), name(name) {}
444
445 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
446 // VirtualRegisters are never internalized.
447
448 /// The label name. For unique labels, this is just the prefix.
449 const StringAttr name;
450};
451
452/// Storage object for '!rtg.isa.memoryblock`-typed values.
453struct MemoryBlockStorage : IdentityValue {
454 MemoryBlockStorage(const APInt &baseAddress, const APInt &endAddress,
455 Type type)
456 : IdentityValue(type), baseAddress(baseAddress), endAddress(endAddress) {}
457
458 // The base address of the memory. The width of the APInt also represents the
459 // address width of the memory. This is an APInt to support memories of
460 // >64-bit machines.
461 const APInt baseAddress;
462
463 // The last address of the memory.
464 const APInt endAddress;
465};
466
467/// Storage object for '!rtg.isa.memory`-typed values.
468struct MemoryStorage : IdentityValue {
469 MemoryStorage(MemoryBlockStorage *memoryBlock, size_t size, size_t alignment)
470 : IdentityValue(MemoryType::get(memoryBlock->type.getContext(),
471 memoryBlock->baseAddress.getBitWidth())),
472 memoryBlock(memoryBlock), size(size), alignment(alignment) {}
473
474 MemoryBlockStorage *memoryBlock;
475 const size_t size;
476 const size_t alignment;
477};
478
479/// Storage object for an '!rtg.randomized_sequence'.
480struct RandomizedSequenceStorage : IdentityValue {
481 RandomizedSequenceStorage(ContextResourceAttrInterface context,
482 SequenceStorage *sequence)
483 : IdentityValue(
484 RandomizedSequenceType::get(sequence->familyName.getContext())),
485 context(context), sequence(sequence) {}
486
487 // The context under which this sequence is placed.
488 const ContextResourceAttrInterface context;
489
490 const SequenceStorage *sequence;
491};
492
493/// Storage object for an '!rtg.validate' result.
494struct ValidationValue : IdentityValue {
495 ValidationValue(Type type, const ElaboratorValue &ref,
496 const ElaboratorValue &defaultValue, StringAttr id,
497 SmallVector<ElaboratorValue> &&defaultUsedValues,
498 SmallVector<ElaboratorValue> &&elseValues)
499 : IdentityValue(type), ref(ref), defaultValue(defaultValue), id(id),
500 defaultUsedValues(std::move(defaultUsedValues)),
501 elseValues(std::move(elseValues)) {}
502
503 const ElaboratorValue ref;
504 const ElaboratorValue defaultValue;
505 const StringAttr id;
506 const SmallVector<ElaboratorValue> defaultUsedValues;
507 const SmallVector<ElaboratorValue> elseValues;
508};
509
510/// Storage object for the 'values' results of 'rtg.validate'.
511struct ValidationMuxedValue : IdentityValue {
512 ValidationMuxedValue(Type type, const ValidationValue *value, unsigned idx)
513 : IdentityValue(type), value(value), idx(idx) {}
514
515 const ValidationValue *value;
516 const unsigned idx;
517};
518
519/// An 'Internalizer' object internalizes storages and takes ownership of them.
520/// When the initializer object is destroyed, all owned storages are also
521/// deallocated and thus must not be accessed anymore.
522class Internalizer {
523public:
524 /// Internalize a storage of type `StorageTy` constructed with arguments
525 /// `args`. The pointers returned by this method can be used to compare
526 /// objects when, e.g., computing set differences, uniquing the elements in a
527 /// set, etc. Otherwise, we'd need to do a deep value comparison in those
528 /// situations.
529 template <typename StorageTy, typename... Args>
530 StorageTy *internalize(Args &&...args) {
531 static_assert(!std::is_base_of_v<IdentityValue, StorageTy> &&
532 "values with identity must not be internalized");
533
534 StorageTy storage(std::forward<Args>(args)...);
535
536 auto existing = getInternSet<StorageTy>().insert_as(
537 HashedStorage<StorageTy>(storage.hashcode), storage);
538 StorageTy *&storagePtr = existing.first->storage;
539 if (existing.second)
540 storagePtr =
541 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
542
543 return storagePtr;
544 }
545
546 template <typename StorageTy, typename... Args>
547 StorageTy *create(Args &&...args) {
548 static_assert(std::is_base_of_v<IdentityValue, StorageTy> &&
549 "values with structural equivalence must be internalized");
550
551 return new (allocator.Allocate<StorageTy>())
552 StorageTy(std::forward<Args>(args)...);
553 }
554
555private:
556 template <typename StorageTy>
557 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
558 getInternSet() {
559 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
560 return internedArrays;
561 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
562 return internedSets;
563 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
564 return internedBags;
565 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
566 return internedSequences;
567 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
568 return internedRandomizedSequences;
569 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
570 return internedInterleavedSequences;
571 else if constexpr (std::is_same_v<StorageTy, TupleStorage>)
572 return internedTuples;
573 else if constexpr (std::is_same_v<StorageTy, ImmediateConcatStorage>)
574 return internedImmediateConcatValues;
575 else if constexpr (std::is_same_v<StorageTy, ImmediateSliceStorage>)
576 return internedImmediateSliceValues;
577 else
578 static_assert(!sizeof(StorageTy),
579 "no intern set available for this storage type.");
580 }
581
582 // This allocator allocates on the heap. It automatically deallocates all
583 // objects it allocated once the allocator itself is destroyed.
584 llvm::BumpPtrAllocator allocator;
585
586 // The sets holding the internalized objects. We use one set per storage type
587 // such that we can have a simpler equality checking function (no need to
588 // compare some sort of TypeIDs).
589 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
590 internedArrays;
591 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
592 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
593 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
594 internedSequences;
595 DenseSet<HashedStorage<RandomizedSequenceStorage>,
596 StorageKeyInfo<RandomizedSequenceStorage>>
597 internedRandomizedSequences;
598 DenseSet<HashedStorage<InterleavedSequenceStorage>,
599 StorageKeyInfo<InterleavedSequenceStorage>>
600 internedInterleavedSequences;
601 DenseSet<HashedStorage<TupleStorage>, StorageKeyInfo<TupleStorage>>
602 internedTuples;
603 DenseSet<HashedStorage<ImmediateConcatStorage>,
604 StorageKeyInfo<ImmediateConcatStorage>>
605 internedImmediateConcatValues;
606 DenseSet<HashedStorage<ImmediateSliceStorage>,
607 StorageKeyInfo<ImmediateSliceStorage>>
608 internedImmediateSliceValues;
609};
610
611} // namespace
612
613#ifndef NDEBUG
614
615static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
616 const ElaboratorValue &value);
617
618static void print(TypedAttr val, llvm::raw_ostream &os) {
619 os << "<attr " << val << ">";
620}
621
622static void print(BagStorage *val, llvm::raw_ostream &os) {
623 os << "<bag {";
624 llvm::interleaveComma(val->bag, os,
625 [&](const std::pair<ElaboratorValue, uint64_t> &el) {
626 os << el.first << " -> " << el.second;
627 });
628 os << "} at " << val << ">";
629}
630
631static void print(bool val, llvm::raw_ostream &os) {
632 os << "<bool " << (val ? "true" : "false") << ">";
633}
634
635static void print(size_t val, llvm::raw_ostream &os) {
636 os << "<index " << val << ">";
637}
638
639static void print(SequenceStorage *val, llvm::raw_ostream &os) {
640 os << "<sequence @" << val->familyName.getValue() << "(";
641 llvm::interleaveComma(val->args, os,
642 [&](const ElaboratorValue &val) { os << val; });
643 os << ") at " << val << ">";
644}
645
646static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
647 os << "<randomized-sequence derived from @"
648 << val->sequence->familyName.getValue() << " under context "
649 << val->context << "(";
650 llvm::interleaveComma(val->sequence->args, os,
651 [&](const ElaboratorValue &val) { os << val; });
652 os << ") at " << val << ">";
653}
654
655static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
656 os << "<interleaved-sequence [";
657 llvm::interleaveComma(val->sequences, os,
658 [&](const ElaboratorValue &val) { os << val; });
659 os << "] batch-size " << val->batchSize << " at " << val << ">";
660}
661
662static void print(ArrayStorage *val, llvm::raw_ostream &os) {
663 os << "<array [";
664 llvm::interleaveComma(val->array, os,
665 [&](const ElaboratorValue &val) { os << val; });
666 os << "] at " << val << ">";
667}
668
669static void print(SetStorage *val, llvm::raw_ostream &os) {
670 os << "<set {";
671 llvm::interleaveComma(val->set, os,
672 [&](const ElaboratorValue &val) { os << val; });
673 os << "} at " << val << ">";
674}
675
676static void print(const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
677 os << "<virtual-register " << val << " " << val->allowedRegs << ">";
678}
679
680static void print(const UniqueLabelStorage *val, llvm::raw_ostream &os) {
681 os << "<unique-label " << val << " " << val->name << ">";
682}
683
684static void print(const LabelValue &val, llvm::raw_ostream &os) {
685 os << "<label " << val.name << ">";
686}
687
688static void print(const TupleStorage *val, llvm::raw_ostream &os) {
689 os << "<tuple (";
690 llvm::interleaveComma(val->values, os,
691 [&](const ElaboratorValue &val) { os << val; });
692 os << ")>";
693}
694
695static void print(const MemoryStorage *val, llvm::raw_ostream &os) {
696 os << "<memory {" << ElaboratorValue(val->memoryBlock)
697 << ", size=" << val->size << ", alignment=" << val->alignment << "}>";
698}
699
700static void print(const MemoryBlockStorage *val, llvm::raw_ostream &os) {
701 os << "<memory-block {"
702 << ", address-width=" << val->baseAddress.getBitWidth()
703 << ", base-address=" << val->baseAddress
704 << ", end-address=" << val->endAddress << "}>";
705}
706
707static void print(const ValidationValue *val, llvm::raw_ostream &os) {
708 os << "<validation-value {type=" << val->type << ", ref=" << val->ref
709 << ", defaultValue=" << val->defaultValue << "}>";
710}
711
712static void print(const ValidationMuxedValue *val, llvm::raw_ostream &os) {
713 os << "<validation-muxed-value (" << val->value << ") at " << val->idx << ">";
714}
715
716static void print(const ImmediateConcatStorage *val, llvm::raw_ostream &os) {
717 os << "<immediate-concat [";
718 llvm::interleaveComma(val->operands, os,
719 [&](const ElaboratorValue &val) { os << val; });
720 os << "]>";
721}
722
723static void print(const ImmediateSliceStorage *val, llvm::raw_ostream &os) {
724 os << "<immediate-slice " << val->input << " from " << val->lowBit << ">";
725}
726
727static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
728 const ElaboratorValue &value) {
729 std::visit([&](auto val) { print(val, os); }, value);
730
731 return os;
732}
733
734#endif
735
736//===----------------------------------------------------------------------===//
737// Elaborator Value Materialization
738//===----------------------------------------------------------------------===//
739
740namespace {
741
742/// State that should be shared by all elaborator and materializer instances.
743struct SharedState {
744 SharedState(SymbolTable &table, unsigned seed) : table(table), rng(seed) {}
745
746 SymbolTable &table;
747 std::mt19937 rng;
748 Namespace names;
749 Internalizer internalizer;
750};
751
/// A collection of state per RTG test.
struct TestState {
  /// The name of the test.
  StringAttr name;

  /// The context switches registered for this test, keyed by a pair of
  /// contexts. NOTE(review): presumably (from, to) — confirm against the code
  /// that inserts entries.
  MapVector<
      std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
      SequenceStorage *>
      contextSwitches;
};
763
764/// Construct an SSA value from a given elaborated value.
765class Materializer {
766public:
  /// Create a materializer that inserts IR using `builder`. Elaborated values
  /// that have to be passed in as block arguments are appended to `blockArgs`.
  Materializer(OpBuilder builder, TestState &testState,
               SharedState &sharedState,
               SmallVector<ElaboratorValue> &blockArgs)
      : builder(builder), testState(testState), sharedState(sharedState),
        blockArgs(blockArgs) {}
772
  /// Materialize IR representing the provided `ElaboratorValue` and return the
  /// `Value` or a null value on failure.
  ///
  /// Previously materialized values are returned from `materializedValues`.
  /// Values with identity (storages deriving from `IdentityValue`) are only
  /// materialized here if they were registered via `registerIdentityValue`;
  /// otherwise a new argument is added to the current block and the value is
  /// recorded in `blockArgs`/`blockArgTypes` so the caller can pass it in.
  Value materialize(ElaboratorValue val, Location loc,
                    function_ref<InFlightDiagnostic()> emitError) {
    auto iter = materializedValues.find(val);
    if (iter != materializedValues.end())
      return iter->second;

    LLVM_DEBUG(llvm::dbgs() << "Materializing " << val);

    // In debug mode, track whether values with identity were already
    // materialized before and assert in such a situation.
    Value res = std::visit(
        [&](auto value) {
          if constexpr (std::is_base_of_v<IdentityValue,
                                          std::remove_pointer_t<
                                              std::decay_t<decltype(value)>>>) {
            // This materializer owns the value: build it here.
            if (identityValueRoot.contains(value)) {
#ifndef NDEBUG
              bool &materialized =
                  static_cast<IdentityValue *>(value)->alreadyMaterialized;
              assert(!materialized && "must not already be materialized");
              materialized = true;
#endif

              return visit(value, loc, emitError);
            }

            // Not owned: request the value through a new block argument.
            Value arg = builder.getBlock()->addArgument(value->type, loc);
            blockArgs.push_back(val);
            blockArgTypes.push_back(arg.getType());
            materializedValues[val] = arg;
            return arg;
          }

          return visit(value, loc, emitError);
        },
        val);

    LLVM_DEBUG(llvm::dbgs() << " to\n" << res << "\n\n");

    return res;
  }
816
  /// If `op` is not in the same region as the materializer insertion point, a
  /// clone is created at the materializer's insertion point by also
  /// materializing the `ElaboratorValue`s for each operand just before it.
  /// Otherwise, all operations after the materializer's insertion point are
  /// deleted until `op` is reached. An error is returned if the operation is
  /// before the insertion point.
  ///
  /// `state` maps SSA values to their elaborated values. It is expected to have
  /// an entry for every operand of `op`, and it is updated with the freshly
  /// materialized operand values. On success, the insertion point is moved to
  /// just after `op` (or its clone).
  LogicalResult materialize(Operation *op,
                            DenseMap<Value, ElaboratorValue> &state) {
    if (op->getNumRegions() > 0)
      return op->emitOpError("ops with nested regions must be elaborated away");

    // We don't support opaque values. If there is an SSA value that has a
    // use-site it needs an equivalent ElaborationValue representation.
    // NOTE: We could support cases where there is initially a use-site but that
    // op is guaranteed to be deleted during elaboration. Or the use-sites are
    // replaced with freshly materialized values from the ElaborationValue. But
    // then, why can't we delete the value defining op?
    // Results of `ValidateOp` are exempt from this check.
    for (auto res : op->getResults())
      if (!res.use_empty() && !isa<ValidateOp>(op))
        return op->emitOpError(
            "ops with results that have uses are not supported");

    if (op->getParentRegion() == builder.getBlock()->getParent()) {
      // We are doing in-place materialization, so mark all ops deleted until we
      // reach the one to be materialized and modify it in-place.
      deleteOpsUntil([&](auto iter) { return &*iter == op; });

      // Reaching the block end means `op` was not found at or after the
      // insertion point.
      if (builder.getInsertionPoint() == builder.getBlock()->end())
        return op->emitError("operation did not occur after the current "
                             "materializer insertion point");

      LLVM_DEBUG(llvm::dbgs() << "Modifying in-place: " << *op << "\n\n");
    } else {
      LLVM_DEBUG(llvm::dbgs() << "Materializing a clone of " << *op << "\n\n");
      op = builder.clone(*op);
      builder.setInsertionPoint(op);
    }

    // The insertion point is now right before `op`, so operand values are
    // materialized ahead of their use.
    for (auto &operand : op->getOpOperands()) {
      auto emitError = [&]() {
        auto diag = op->emitError();
        diag.attachNote(op->getLoc())
            << "while materializing value for operand#"
            << operand.getOperandNumber();
        return diag;
      };

      auto elabVal = state.at(operand.get());
      Value val = materialize(elabVal, op->getLoc(), emitError);
      if (!val)
        return failure();

      state[val] = elabVal;
      operand.set(val);
    }

    builder.setInsertionPointAfter(op);
    return success();
  }
876
877 /// Should be called once the `Region` is successfully materialized. No calls
878 /// to `materialize` should happen after this anymore.
879 void finalize() {
880 deleteOpsUntil([](auto iter) { return false; });
881
882 for (auto *op : llvm::reverse(toDelete))
883 op->erase();
884 }
885
886 /// Tell this materializer that it is responsible for materializing the given
887 /// identity value at the earliest position it is needed, and should't
888 /// request the value via block argument.
889 void registerIdentityValue(IdentityValue *val) {
890 identityValueRoot.insert(val);
891 }
892
893 ArrayRef<Type> getBlockArgTypes() const { return blockArgTypes; }
894
895 void map(ElaboratorValue eval, Value val) { materializedValues[eval] = val; }
896
897 template <typename OpTy, typename... Args>
898 OpTy create(Location location, Args &&...args) {
899 return OpTy::create(builder, location, std::forward<Args>(args)...);
900 }
901
private:
  /// Elaborates the sequence family referenced by `seq` if that has not been
  /// done yet. Returns the `SequenceOp` to reference and fills `elabArgs` with
  /// the elaborated values that have to be passed to it as arguments. Callers
  /// in this class treat a null op as failure. (Defined elsewhere in the file.)
  SequenceOp elaborateSequence(const RandomizedSequenceStorage *seq,
                               SmallVector<ElaboratorValue> &elabArgs);
905
906 void deleteOpsUntil(function_ref<bool(Block::iterator)> stop) {
907 auto ip = builder.getInsertionPoint();
908 while (ip != builder.getBlock()->end() && !stop(ip)) {
909 LLVM_DEBUG(llvm::dbgs() << "Marking to be deleted: " << *ip << "\n\n");
910 toDelete.push_back(&*ip);
911
912 builder.setInsertionPointAfter(&*ip);
913 ip = builder.getInsertionPoint();
914 }
915 }
916
917 Value visit(TypedAttr val, Location loc,
918 function_ref<InFlightDiagnostic()> emitError) {
919 // For index attributes (and arithmetic operations on them) we use the
920 // index dialect.
921 if (auto intAttr = dyn_cast<IntegerAttr>(val);
922 intAttr && isa<IndexType>(val.getType())) {
923 Value res = index::ConstantOp::create(builder, loc, intAttr);
924 materializedValues[val] = res;
925 return res;
926 }
927
928 // For any other attribute, we just call the materializer of the dialect
929 // defining that attribute.
930 auto *op =
931 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
932 if (!op) {
933 emitError() << "materializer of dialect '"
934 << val.getDialect().getNamespace()
935 << "' unable to materialize value for attribute '" << val
936 << "'";
937 return Value();
938 }
939
940 Value res = op->getResult(0);
941 materializedValues[val] = res;
942 return res;
943 }
944
945 Value visit(size_t val, Location loc,
946 function_ref<InFlightDiagnostic()> emitError) {
947 Value res = index::ConstantOp::create(builder, loc, val);
948 materializedValues[val] = res;
949 return res;
950 }
951
952 Value visit(bool val, Location loc,
953 function_ref<InFlightDiagnostic()> emitError) {
954 Value res = index::BoolConstantOp::create(builder, loc, val);
955 materializedValues[val] = res;
956 return res;
957 }
958
959 Value visit(ArrayStorage *val, Location loc,
960 function_ref<InFlightDiagnostic()> emitError) {
961 SmallVector<Value> elements;
962 elements.reserve(val->array.size());
963 for (auto el : val->array) {
964 auto materialized = materialize(el, loc, emitError);
965 if (!materialized)
966 return Value();
967
968 elements.push_back(materialized);
969 }
970
971 Value res = ArrayCreateOp::create(builder, loc, val->type, elements);
972 materializedValues[val] = res;
973 return res;
974 }
975
976 Value visit(SetStorage *val, Location loc,
977 function_ref<InFlightDiagnostic()> emitError) {
978 SmallVector<Value> elements;
979 elements.reserve(val->set.size());
980 for (auto el : val->set) {
981 auto materialized = materialize(el, loc, emitError);
982 if (!materialized)
983 return Value();
984
985 elements.push_back(materialized);
986 }
987
988 auto res = SetCreateOp::create(builder, loc, val->type, elements);
989 materializedValues[val] = res;
990 return res;
991 }
992
993 Value visit(BagStorage *val, Location loc,
994 function_ref<InFlightDiagnostic()> emitError) {
995 SmallVector<Value> values, weights;
996 values.reserve(val->bag.size());
997 weights.reserve(val->bag.size());
998 for (auto [val, weight] : val->bag) {
999 auto materializedVal = materialize(val, loc, emitError);
1000 auto materializedWeight = materialize(weight, loc, emitError);
1001 if (!materializedVal || !materializedWeight)
1002 return Value();
1003
1004 values.push_back(materializedVal);
1005 weights.push_back(materializedWeight);
1006 }
1007
1008 auto res = BagCreateOp::create(builder, loc, val->type, values, weights);
1009 materializedValues[val] = res;
1010 return res;
1011 }
1012
1013 Value visit(MemoryBlockStorage *val, Location loc,
1014 function_ref<InFlightDiagnostic()> emitError) {
1015 auto intType = builder.getIntegerType(val->baseAddress.getBitWidth());
1016 Value res = MemoryBlockDeclareOp::create(
1017 builder, loc, val->type, IntegerAttr::get(intType, val->baseAddress),
1018 IntegerAttr::get(intType, val->endAddress));
1019 materializedValues[val] = res;
1020 return res;
1021 }
1022
1023 Value visit(MemoryStorage *val, Location loc,
1024 function_ref<InFlightDiagnostic()> emitError) {
1025 auto memBlock = materialize(val->memoryBlock, loc, emitError);
1026 auto memSize = materialize(val->size, loc, emitError);
1027 auto memAlign = materialize(val->alignment, loc, emitError);
1028 if (!(memBlock && memSize && memAlign))
1029 return {};
1030
1031 Value res =
1032 MemoryAllocOp::create(builder, loc, memBlock, memSize, memAlign);
1033 materializedValues[val] = res;
1034 return res;
1035 }
1036
1037 Value visit(SequenceStorage *val, Location loc,
1038 function_ref<InFlightDiagnostic()> emitError) {
1039 emitError() << "materializing a non-randomized sequence not supported yet";
1040 return Value();
1041 }
1042
1043 Value visit(RandomizedSequenceStorage *val, Location loc,
1044 function_ref<InFlightDiagnostic()> emitError) {
1045 // To know which values we have to pass by argument (and not just pass all
1046 // that migth be used eagerly), we have to elaborate the sequence family if
1047 // not already done so.
1048 // We need to get back the sequence to reference, and the list of elaborated
1049 // values to pass as arguments.
1050 SmallVector<ElaboratorValue> elabArgs;
1051 SequenceOp seqOp = elaborateSequence(val, elabArgs);
1052 if (!seqOp)
1053 return {};
1054
1055 // Materialize all the values we need to pass as arguments and collect their
1056 // types.
1057 SmallVector<Value> args;
1058 SmallVector<Type> argTypes;
1059 for (auto arg : elabArgs) {
1060 Value materialized = materialize(arg, loc, emitError);
1061 if (!materialized)
1062 return {};
1063
1064 args.push_back(materialized);
1065 argTypes.push_back(materialized.getType());
1066 }
1067
1068 Value res = GetSequenceOp::create(
1069 builder, loc, SequenceType::get(builder.getContext(), argTypes),
1070 seqOp.getSymName());
1071
1072 // Only materialize a substitute_sequence op when we have arguments to
1073 // substitute since this op does not support 0 arguments.
1074 if (!args.empty())
1075 res = SubstituteSequenceOp::create(builder, loc, res, args);
1076
1077 res = RandomizeSequenceOp::create(builder, loc, res);
1078
1079 materializedValues[val] = res;
1080 return res;
1081 }
1082
1083 Value visit(InterleavedSequenceStorage *val, Location loc,
1084 function_ref<InFlightDiagnostic()> emitError) {
1085 SmallVector<Value> sequences;
1086 for (auto seqVal : val->sequences) {
1087 Value materialized = materialize(seqVal, loc, emitError);
1088 if (!materialized)
1089 return {};
1090
1091 sequences.push_back(materialized);
1092 }
1093
1094 if (sequences.size() == 1)
1095 return sequences[0];
1096
1097 Value res =
1098 InterleaveSequencesOp::create(builder, loc, sequences, val->batchSize);
1099 materializedValues[val] = res;
1100 return res;
1101 }
1102
1103 Value visit(VirtualRegisterStorage *val, Location loc,
1104 function_ref<InFlightDiagnostic()> emitError) {
1105 Value res = VirtualRegisterOp::create(builder, loc, val->allowedRegs);
1106 materializedValues[val] = res;
1107 return res;
1108 }
1109
1110 Value visit(UniqueLabelStorage *val, Location loc,
1111 function_ref<InFlightDiagnostic()> emitError) {
1112 Value res =
1113 LabelUniqueDeclOp::create(builder, loc, val->name, ValueRange());
1114 materializedValues[val] = res;
1115 return res;
1116 }
1117
1118 Value visit(const LabelValue &val, Location loc,
1119 function_ref<InFlightDiagnostic()> emitError) {
1120 Value res = LabelDeclOp::create(builder, loc, val.name, ValueRange());
1121 materializedValues[val] = res;
1122 return res;
1123 }
1124
1125 Value visit(TupleStorage *val, Location loc,
1126 function_ref<InFlightDiagnostic()> emitError) {
1127 SmallVector<Value> materialized;
1128 materialized.reserve(val->values.size());
1129 for (auto v : val->values)
1130 materialized.push_back(materialize(v, loc, emitError));
1131 Value res = TupleCreateOp::create(builder, loc, materialized);
1132 materializedValues[val] = res;
1133 return res;
1134 }
1135
1136 Value visit(ValidationValue *val, Location loc,
1137 function_ref<InFlightDiagnostic()> emitError) {
1138 SmallVector<Value> usedDefaultValues, elseValues;
1139 for (auto [dfltVal, elseVal] :
1140 llvm::zip(val->defaultUsedValues, val->elseValues)) {
1141 auto dfltMat = materialize(dfltVal, loc, emitError);
1142 auto elseMat = materialize(elseVal, loc, emitError);
1143 if (!dfltMat || !elseMat)
1144 return {};
1145
1146 usedDefaultValues.push_back(dfltMat);
1147 usedDefaultValues.push_back(elseMat);
1148 }
1149
1150 auto validateOp = ValidateOp::create(
1151 builder, loc, val->type, materialize(val->ref, loc, emitError),
1152 materialize(val->defaultValue, loc, emitError), val->id,
1153 usedDefaultValues, elseValues);
1154 materializedValues[val] = validateOp.getValue();
1155 return validateOp.getValue();
1156 }
1157
1158 Value visit(ValidationMuxedValue *val, Location loc,
1159 function_ref<InFlightDiagnostic()> emitError) {
1160 Value validateValue =
1161 materialize(const_cast<ValidationValue *>(val->value), loc, emitError);
1162 if (!validateValue)
1163 return {};
1164 auto validateOp = validateValue.getDefiningOp<ValidateOp>();
1165 if (!validateOp) {
1166 auto *defOp = validateValue.getDefiningOp();
1167 emitError()
1168 << "expected validate op for validation muxed value, but found "
1169 << (defOp ? defOp->getName().getStringRef() : "block argument");
1170 return {};
1171 }
1172 materializedValues[val] = validateOp.getValues()[val->idx];
1173 return validateOp.getValues()[val->idx];
1174 }
1175
1176 Value visit(ImmediateConcatStorage *val, Location loc,
1177 function_ref<InFlightDiagnostic()> emitError) {
1178 SmallVector<Value> operands;
1179 for (auto operand : val->operands) {
1180 auto materialized = materialize(operand, loc, emitError);
1181 if (!materialized)
1182 return {};
1183
1184 operands.push_back(materialized);
1185 }
1186
1187 Value res = ConcatImmediateOp::create(builder, loc, operands);
1188 materializedValues[val] = res;
1189 return res;
1190 }
1191
1192 Value visit(ImmediateSliceStorage *val, Location loc,
1193 function_ref<InFlightDiagnostic()> emitError) {
1194 Value input = materialize(val->input, loc, emitError);
1195 if (!input)
1196 return {};
1197
1198 Value res =
1199 SliceImmediateOp::create(builder, loc, val->type, input, val->lowBit);
1200 materializedValues[val] = res;
1201 return res;
1202 }
1203
private:
  /// Cache of values we have already materialized, so later uses reuse them.
  /// We start with an insertion point at the start of the block and cache the
  /// (updated) insertion point, so that future materializations can also
  /// reuse previous materializations without running into dominance issues
  /// (or requiring additional checks to avoid them).
  DenseMap<ElaboratorValue, Value> materializedValues;

  /// Cache the builder to continue insertions at their current insertion point
  /// for the reason stated above.
  OpBuilder builder;

  /// Operations queued for deletion (in block order); `finalize` erases them
  /// in reverse order.
  SmallVector<Operation *> toDelete;

  /// Per-test and shared elaboration state. Held by reference, so they are
  /// owned elsewhere and must outlive this materializer.
  TestState &testState;
  SharedState &sharedState;

  /// Keep track of the block arguments we had to add to this materializer's
  /// block for identity values, and also remember which elaborator values are
  /// expected to be passed as arguments from outside. `blockArgs` is a
  /// reference, i.e., the vector is owned by whoever created this object.
  SmallVector<ElaboratorValue> &blockArgs;
  SmallVector<Type> blockArgTypes;

  /// Identity values in this set are materialized by this materializer.
  /// Otherwise they are added as block arguments, and the block that wants to
  /// embed this sequence is expected to provide a value for them.
  DenseSet<IdentityValue *> identityValueRoot;
};
1232
1233//===----------------------------------------------------------------------===//
1234// Elaboration Visitor
1235//===----------------------------------------------------------------------===//
1236
/// Outcome of visiting an operation during elaboration. It tells the
/// elaboration driver whether the visited operation stays in the IR.
enum class DeletionKind {
  /// Leave the operation in place.
  Keep,
  /// Remove the operation.
  Delete
};
1240
1241/// Interprets the IR to perform and lower the represented randomizations.
1242class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
1243public:
1245 using RTGBase::visitOp;
1246
1247 Elaborator(SharedState &sharedState, TestState &testState,
1248 Materializer &materializer,
1249 ContextResourceAttrInterface currentContext = {})
1250 : sharedState(sharedState), testState(testState),
1251 materializer(materializer), currentContext(currentContext) {}
1252
1253 template <typename ValueTy>
1254 inline ValueTy get(Value val) const {
1255 return std::get<ValueTy>(state.at(val));
1256 }
1257
1258 FailureOr<DeletionKind> visitPureOp(Operation *op) {
1259 SmallVector<Attribute> operands;
1260 for (auto operand : op->getOperands()) {
1261 auto evalValue = state[operand];
1262 if (std::holds_alternative<TypedAttr>(evalValue))
1263 operands.push_back(std::get<TypedAttr>(evalValue));
1264 else
1265 return visitUnhandledOp(op);
1266 }
1267
1268 SmallVector<OpFoldResult> results;
1269 if (failed(op->fold(operands, results)))
1270 return visitUnhandledOp(op);
1271
1272 // We don't support in-place folders.
1273 if (results.size() != op->getNumResults())
1274 return visitUnhandledOp(op);
1275
1276 for (auto [res, val] : llvm::zip(results, op->getResults())) {
1277 auto attr = llvm::dyn_cast_or_null<TypedAttr>(res.dyn_cast<Attribute>());
1278 if (!attr)
1279 return op->emitError(
1280 "only typed attributes supported for constant-like operations");
1281
1282 auto intAttr = dyn_cast<IntegerAttr>(attr);
1283 if (intAttr && isa<IndexType>(attr.getType()))
1284 state[op->getResult(0)] = size_t(intAttr.getInt());
1285 else if (intAttr && intAttr.getType().isSignlessInteger(1))
1286 state[op->getResult(0)] = bool(intAttr.getInt());
1287 else
1288 state[op->getResult(0)] = attr;
1289 }
1290
1291 return DeletionKind::Delete;
1292 }
1293
1294 /// Print a nice error message for operations we don't support yet.
1295 FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
1296 return op->emitOpError("elaboration not supported");
1297 }
1298
1299 FailureOr<DeletionKind> visitExternalOp(Operation *op) {
1300 auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
1301 if (op->hasTrait<OpTrait::ConstantLike>() || (memOp && memOp.hasNoEffect()))
1302 return visitPureOp(op);
1303
1304 // TODO: we only have this to be able to write tests for this pass without
1305 // having to add support for more operations for now, so it should be
1306 // removed once it is not necessary anymore for writing tests
1307 if (op->use_empty())
1308 return DeletionKind::Keep;
1309
1310 return visitUnhandledOp(op);
1311 }
1312
1313 FailureOr<DeletionKind> visitOp(ConstantOp op) { return visitPureOp(op); }
1314
1315 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
1316 SmallVector<ElaboratorValue> replacements;
1317 state[op.getResult()] =
1318 sharedState.internalizer.internalize<SequenceStorage>(
1319 op.getSequenceAttr(), std::move(replacements));
1320 return DeletionKind::Delete;
1321 }
1322
1323 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
1324 auto *seq = get<SequenceStorage *>(op.getSequence());
1325
1326 SmallVector<ElaboratorValue> replacements(seq->args);
1327 for (auto replacement : op.getReplacements())
1328 replacements.push_back(state.at(replacement));
1329
1330 state[op.getResult()] =
1331 sharedState.internalizer.internalize<SequenceStorage>(
1332 seq->familyName, std::move(replacements));
1333
1334 return DeletionKind::Delete;
1335 }
1336
1337 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
1338 auto *seq = get<SequenceStorage *>(op.getSequence());
1339 auto *randomizedSeq =
1340 sharedState.internalizer.create<RandomizedSequenceStorage>(
1341 currentContext, seq);
1342 materializer.registerIdentityValue(randomizedSeq);
1343 state[op.getResult()] =
1344 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1345 randomizedSeq);
1346 return DeletionKind::Delete;
1347 }
1348
1349 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
1350 SmallVector<ElaboratorValue> sequences;
1351 for (auto seq : op.getSequences())
1352 sequences.push_back(get<InterleavedSequenceStorage *>(seq));
1353
1354 state[op.getResult()] =
1355 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1356 std::move(sequences), op.getBatchSize());
1357 return DeletionKind::Delete;
1358 }
1359
1360 // NOLINTNEXTLINE(misc-no-recursion)
1361 LogicalResult isValidContext(ElaboratorValue value, Operation *op) const {
1362 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
1363 auto *seq = std::get<RandomizedSequenceStorage *>(value);
1364 if (seq->context != currentContext) {
1365 auto err = op->emitError("attempting to place sequence derived from ")
1366 << seq->sequence->familyName.getValue() << " under context "
1367 << currentContext
1368 << ", but it was previously randomized for context ";
1369 if (seq->context)
1370 err << seq->context;
1371 else
1372 err << "'default'";
1373 return err;
1374 }
1375 return success();
1376 }
1377
1378 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
1379 for (auto val : interVal->sequences)
1380 if (failed(isValidContext(val, op)))
1381 return failure();
1382 return success();
1383 }
1384
1385 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
1386 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
1387 if (failed(isValidContext(seqVal, op)))
1388 return failure();
1389
1390 return DeletionKind::Keep;
1391 }
1392
1393 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1394 SetVector<ElaboratorValue> set;
1395 for (auto val : op.getElements())
1396 set.insert(state.at(val));
1397
1398 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1399 std::move(set), op.getSet().getType());
1400 return DeletionKind::Delete;
1401 }
1402
1403 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
1404 auto set = get<SetStorage *>(op.getSet())->set;
1405
1406 if (set.empty())
1407 return op->emitError("cannot select from an empty set");
1408
1409 size_t selected;
1410 if (auto intAttr =
1411 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1412 std::mt19937 customRng(intAttr.getInt());
1413 selected = getUniformlyInRange(customRng, 0, set.size() - 1);
1414 } else {
1415 selected = getUniformlyInRange(sharedState.rng, 0, set.size() - 1);
1416 }
1417
1418 state[op.getResult()] = set[selected];
1419 return DeletionKind::Delete;
1420 }
1421
1422 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
1423 auto original = get<SetStorage *>(op.getOriginal())->set;
1424 auto diff = get<SetStorage *>(op.getDiff())->set;
1425
1426 SetVector<ElaboratorValue> result(original);
1427 result.set_subtract(diff);
1428
1429 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1430 std::move(result), op.getResult().getType());
1431 return DeletionKind::Delete;
1432 }
1433
1434 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
1435 SetVector<ElaboratorValue> result;
1436 for (auto set : op.getSets())
1437 result.set_union(get<SetStorage *>(set)->set);
1438
1439 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1440 std::move(result), op.getType());
1441 return DeletionKind::Delete;
1442 }
1443
1444 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1445 auto size = get<SetStorage *>(op.getSet())->set.size();
1446 state[op.getResult()] = size;
1447 return DeletionKind::Delete;
1448 }
1449
  // Computes the cartesian product of all input sets as a set of tuples. If
  // any input set is empty, the result is the empty set.
  //
  // Example (elements listed lexicographically for readability):
  //   {a0,a1} x {b0,b1} x {c0,c1}
  //   -> {(a0), (a1)}
  //   -> {(a0,b0), (a0,b1), (a1,b0), (a1,b1)}
  //   -> {(a0,b0,c0), (a0,b0,c1), (a0,b1,c0), (a0,b1,c1), (a1,b0,c0),
  //       (a1,b0,c1), (a1,b1,c0), (a1,b1,c1)}
  // NOTE: the in-place extension below does not produce this lexicographic
  // order. For example, after the first input the tuples are ordered (a1),
  // (a0). The order is deterministic, and it matters because the result is an
  // ordered `SetVector` that is indexed, e.g., by random selection.
  FailureOr<DeletionKind> visitOp(SetCartesianProductOp op) {
    SetVector<ElaboratorValue> result;
    SmallVector<SmallVector<ElaboratorValue>> tuples;
    // Start with a single empty tuple; each input extends every tuple.
    tuples.push_back({});

    for (auto input : op.getInputs()) {
      auto &set = get<SetStorage *>(input)->set;
      // The product with an empty set is empty.
      if (set.empty()) {
        SetVector<ElaboratorValue> empty;
        state[op.getResult()] =
            sharedState.internalizer.internalize<SetStorage>(std::move(empty),
                                                             op.getType());
        return DeletionKind::Delete;
      }

      // Only the tuples that existed before this input (indices [0, e)) are
      // extended. For each one, a copy is appended per element except the
      // last (the copy is extended by that element). The original tuple is
      // then extended in place by the last element.
      for (unsigned i = 0, e = tuples.size(); i < e; ++i) {
        for (auto setEl : set.getArrayRef().drop_back()) {
          tuples.push_back(tuples[i]);
          tuples.back().push_back(setEl);
        }
        tuples[i].push_back(set.back());
      }
    }

    // Each completed tuple becomes an internalized tuple value.
    for (auto &tup : tuples)
      result.insert(
          sharedState.internalizer.internalize<TupleStorage>(std::move(tup)));

    state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
        std::move(result), op.getType());
    return DeletionKind::Delete;
  }
1485
1486 FailureOr<DeletionKind> visitOp(SetConvertToBagOp op) {
1487 auto set = get<SetStorage *>(op.getInput())->set;
1488 MapVector<ElaboratorValue, uint64_t> bag;
1489 for (auto val : set)
1490 bag.insert({val, 1});
1491 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1492 std::move(bag), op.getType());
1493 return DeletionKind::Delete;
1494 }
1495
1496 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1497 MapVector<ElaboratorValue, uint64_t> bag;
1498 for (auto [val, multiple] :
1499 llvm::zip(op.getElements(), op.getMultiples())) {
1500 // If the multiple is not stored as an AttributeValue, the elaboration
1501 // must have already failed earlier (since we don't have
1502 // unevaluated/opaque values).
1503 bag[state.at(val)] += get<size_t>(multiple);
1504 }
1505
1506 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1507 std::move(bag), op.getType());
1508 return DeletionKind::Delete;
1509 }
1510
1511 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
1512 auto bag = get<BagStorage *>(op.getBag())->bag;
1513
1514 if (bag.empty())
1515 return op->emitError("cannot select from an empty bag");
1516
1517 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1518 prefixSum.reserve(bag.size());
1519 uint32_t accumulator = 0;
1520 for (auto [val, weight] : bag) {
1521 accumulator += weight;
1522 prefixSum.push_back({val, accumulator});
1523 }
1524
1525 auto customRng = sharedState.rng;
1526 if (auto intAttr =
1527 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1528 customRng = std::mt19937(intAttr.getInt());
1529 }
1530
1531 auto idx = getUniformlyInRange(customRng, 0, accumulator - 1);
1532 auto *iter = llvm::upper_bound(
1533 prefixSum, idx,
1534 [](uint32_t a, const std::pair<ElaboratorValue, uint32_t> &b) {
1535 return a < b.second;
1536 });
1537
1538 state[op.getResult()] = iter->first;
1539 return DeletionKind::Delete;
1540 }
1541
1542 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1543 auto original = get<BagStorage *>(op.getOriginal())->bag;
1544 auto diff = get<BagStorage *>(op.getDiff())->bag;
1545
1546 MapVector<ElaboratorValue, uint64_t> result;
1547 for (const auto &el : original) {
1548 if (!diff.contains(el.first)) {
1549 result.insert(el);
1550 continue;
1551 }
1552
1553 if (op.getInf())
1554 continue;
1555
1556 auto toDiff = diff.lookup(el.first);
1557 if (el.second <= toDiff)
1558 continue;
1559
1560 result.insert({el.first, el.second - toDiff});
1561 }
1562
1563 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1564 std::move(result), op.getType());
1565 return DeletionKind::Delete;
1566 }
1567
1568 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1569 MapVector<ElaboratorValue, uint64_t> result;
1570 for (auto bag : op.getBags()) {
1571 auto val = get<BagStorage *>(bag)->bag;
1572 for (auto [el, multiple] : val)
1573 result[el] += multiple;
1574 }
1575
1576 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1577 std::move(result), op.getType());
1578 return DeletionKind::Delete;
1579 }
1580
1581 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1582 auto size = get<BagStorage *>(op.getBag())->bag.size();
1583 state[op.getResult()] = size;
1584 return DeletionKind::Delete;
1585 }
1586
1587 FailureOr<DeletionKind> visitOp(BagConvertToSetOp op) {
1588 auto bag = get<BagStorage *>(op.getInput())->bag;
1589 SetVector<ElaboratorValue> set;
1590 for (auto [k, v] : bag)
1591 set.insert(k);
1592 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1593 std::move(set), op.getType());
1594 return DeletionKind::Delete;
1595 }
1596
1597 FailureOr<DeletionKind> visitOp(FixedRegisterOp op) {
1598 return visitPureOp(op);
1599 }
1600
1601 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1602 auto *val = sharedState.internalizer.create<VirtualRegisterStorage>(
1603 op.getAllowedRegsAttr(), op.getType());
1604 state[op.getResult()] = val;
1605 materializer.registerIdentityValue(val);
1606 return DeletionKind::Delete;
1607 }
1608
1609 StringAttr substituteFormatString(StringAttr formatString,
1610 ValueRange substitutes) const {
1611 if (substitutes.empty() || formatString.empty())
1612 return formatString;
1613
1614 auto original = formatString.getValue().str();
1615 for (auto [i, subst] : llvm::enumerate(substitutes)) {
1616 size_t startPos = 0;
1617 std::string from = "{{" + std::to_string(i) + "}}";
1618 while ((startPos = original.find(from, startPos)) != std::string::npos) {
1619 auto substString = std::to_string(get<size_t>(subst));
1620 original.replace(startPos, from.length(), substString);
1621 }
1622 }
1623
1624 return StringAttr::get(formatString.getContext(), original);
1625 }
1626
1627 FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
1628 SmallVector<ElaboratorValue> array;
1629 array.reserve(op.getElements().size());
1630 for (auto val : op.getElements())
1631 array.emplace_back(state.at(val));
1632
1633 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1634 op.getResult().getType(), std::move(array));
1635 return DeletionKind::Delete;
1636 }
1637
1638 FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
1639 auto array = get<ArrayStorage *>(op.getArray())->array;
1640 size_t idx = get<size_t>(op.getIndex());
1641
1642 if (array.size() <= idx)
1643 return op->emitError("invalid to access index ")
1644 << idx << " of an array with " << array.size() << " elements";
1645
1646 state[op.getResult()] = array[idx];
1647 return DeletionKind::Delete;
1648 }
1649
1650 FailureOr<DeletionKind> visitOp(ArrayInjectOp op) {
1651 auto array = get<ArrayStorage *>(op.getArray())->array;
1652 size_t idx = get<size_t>(op.getIndex());
1653
1654 if (array.size() <= idx)
1655 return op->emitError("invalid to access index ")
1656 << idx << " of an array with " << array.size() << " elements";
1657
1658 array[idx] = state[op.getValue()];
1659 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1660 op.getResult().getType(), std::move(array));
1661 return DeletionKind::Delete;
1662 }
1663
1664 FailureOr<DeletionKind> visitOp(ArraySizeOp op) {
1665 auto array = get<ArrayStorage *>(op.getArray())->array;
1666 state[op.getResult()] = array.size();
1667 return DeletionKind::Delete;
1668 }
1669
1670 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
1671 auto substituted =
1672 substituteFormatString(op.getFormatStringAttr(), op.getArgs());
1673 state[op.getLabel()] = LabelValue(substituted);
1674 return DeletionKind::Delete;
1675 }
1676
1677 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1678 auto *val = sharedState.internalizer.create<UniqueLabelStorage>(
1679 substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
1680 state[op.getLabel()] = val;
1681 materializer.registerIdentityValue(val);
1682 return DeletionKind::Delete;
1683 }
1684
1685 FailureOr<DeletionKind> visitOp(LabelOp op) { return DeletionKind::Keep; }
1686
1687 FailureOr<DeletionKind> visitOp(TestSuccessOp op) {
1688 return DeletionKind::Keep;
1689 }
1690
1691 FailureOr<DeletionKind> visitOp(TestFailureOp op) {
1692 return DeletionKind::Keep;
1693 }
1694
1695 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1696 size_t lower = get<size_t>(op.getLowerBound());
1697 size_t upper = get<size_t>(op.getUpperBound());
1698 if (lower > upper)
1699 return op->emitError("cannot select a number from an empty range");
1700
1701 if (auto intAttr =
1702 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1703 std::mt19937 customRng(intAttr.getInt());
1704 state[op.getResult()] =
1705 size_t(getUniformlyInRange(customRng, lower, upper));
1706 } else {
1707 state[op.getResult()] =
1708 size_t(getUniformlyInRange(sharedState.rng, lower, upper));
1709 }
1710
1711 return DeletionKind::Delete;
1712 }
1713
1714 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1715 size_t input = get<size_t>(op.getInput());
1716 auto width = op.getType().getWidth();
1717 auto emitError = [&]() { return op->emitError(); };
1718 if (input > APInt::getAllOnes(width).getZExtValue())
1719 return emitError() << "cannot represent " << input << " with " << width
1720 << " bits";
1721
1722 state[op.getResult()] =
1723 ImmediateAttr::get(op.getContext(), APInt(width, input));
1724 return DeletionKind::Delete;
1725 }
1726
  /// Elaborates an `on_context` op: embeds the given sequence so that it runs
  /// under the target context `to`. The source context `from` is the current
  /// context, or the 'default' context of `to`'s type when none is set.
  /// - If `from == to`, the sequence is randomized and embedded directly.
  /// - Otherwise a registered context-switch sequence is looked up (most
  ///   specific match first, 'any' wildcards after) and embedded instead.
  /// The op is deleted in both cases.
  FailureOr<DeletionKind> visitOp(OnContextOp op) {
    ContextResourceAttrInterface from = currentContext,
                                 to = cast<ContextResourceAttrInterface>(
                                     get<TypedAttr>(op.getContext()));
    if (!currentContext)
      from = DefaultContextAttr::get(op->getContext(), to.getType());

    // Diagnostic factory handed to the materializer; it attaches a note that
    // mentions this op.
    auto emitError = [&]() {
      auto diag = op.emitError();
      diag.attachNote(op.getLoc())
          << "while materializing value for context switching for " << op;
      return diag;
    };

    // Already in the requested context: no switch needed.
    // NOTE(review): this materializes a non-randomized `SequenceStorage`. The
    // Materializer's `visit(SequenceStorage *)` reports "materializing a
    // non-randomized sequence not supported yet", so this path seems to rely
    // on the value already being cached (e.g., via `Materializer::map`).
    // Confirm.
    if (from == to) {
      Value seqVal = materializer.materialize(
          get<SequenceStorage *>(op.getSequence()), op.getLoc(), emitError);
      if (!seqVal)
        return failure();

      Value randSeqVal =
          materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
      materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
      return DeletionKind::Delete;
    }

    // Switch to the desired context.
    // First, check if a context switch is registered that has the concrete
    // context as source and target.
    auto *iter = testState.contextSwitches.find({from, to});

    // Try with 'any' context as target and the concrete context as source.
    if (iter == testState.contextSwitches.end())
      iter = testState.contextSwitches.find(
          {from, AnyContextAttr::get(op->getContext(), to.getType())});

    // Try with 'any' context as source and the concrete context as target.
    if (iter == testState.contextSwitches.end())
      iter = testState.contextSwitches.find(
          {AnyContextAttr::get(op->getContext(), from.getType()), to});

    // Try with 'any' context for both the source and the target.
    if (iter == testState.contextSwitches.end())
      iter = testState.contextSwitches.find(
          {AnyContextAttr::get(op->getContext(), from.getType()),
           AnyContextAttr::get(op->getContext(), to.getType())});

    // Otherwise, fail with an error because we couldn't find a user
    // specification on how to switch between the requested contexts.
    // NOTE: we could think about supporting context switching via intermediate
    // context, i.e., treat it as a transitive relation.
    if (iter == testState.contextSwitches.end())
      return op->emitError("no context transition registered to switch from ")
             << from << " to " << to;

    // Instantiate the registered switch sequence family with the arguments
    // (from, to, sequence to run). Randomize it for the target context as a
    // fresh identity value owned by this materializer, then embed it.
    auto familyName = iter->second->familyName;
    SmallVector<ElaboratorValue> args{from, to,
                                      get<SequenceStorage *>(op.getSequence())};
    auto *seq = sharedState.internalizer.internalize<SequenceStorage>(
        familyName, std::move(args));
    auto *randSeq =
        sharedState.internalizer.create<RandomizedSequenceStorage>(to, seq);
    materializer.registerIdentityValue(randSeq);
    Value seqVal = materializer.materialize(randSeq, op.getLoc(), emitError);
    if (!seqVal)
      return failure();

    materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
    return DeletionKind::Delete;
  }
1797
1798 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1799 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1800 get<SequenceStorage *>(op.getSequence());
1801 return DeletionKind::Delete;
1802 }
1803
1804 FailureOr<DeletionKind> visitOp(MemoryBlockDeclareOp op) {
1805 auto *val = sharedState.internalizer.create<MemoryBlockStorage>(
1806 op.getBaseAddress(), op.getEndAddress(), op.getType());
1807 state[op.getResult()] = val;
1808 materializer.registerIdentityValue(val);
1809 return DeletionKind::Delete;
1810 }
1811
1812 FailureOr<DeletionKind> visitOp(MemoryAllocOp op) {
1813 size_t size = get<size_t>(op.getSize());
1814 size_t alignment = get<size_t>(op.getAlignment());
1815 auto *memBlock = get<MemoryBlockStorage *>(op.getMemoryBlock());
1816 auto *val = sharedState.internalizer.create<MemoryStorage>(memBlock, size,
1817 alignment);
1818 state[op.getResult()] = val;
1819 materializer.registerIdentityValue(val);
1820 return DeletionKind::Delete;
1821 }
1822
1823 FailureOr<DeletionKind> visitOp(MemorySizeOp op) {
1824 auto *memory = get<MemoryStorage *>(op.getMemory());
1825 state[op.getResult()] = memory->size;
1826 return DeletionKind::Delete;
1827 }
1828
1829 FailureOr<DeletionKind> visitOp(TupleCreateOp op) {
1830 SmallVector<ElaboratorValue> values;
1831 values.reserve(op.getElements().size());
1832 for (auto el : op.getElements())
1833 values.push_back(state[el]);
1834
1835 state[op.getResult()] =
1836 sharedState.internalizer.internalize<TupleStorage>(std::move(values));
1837 return DeletionKind::Delete;
1838 }
1839
1840 FailureOr<DeletionKind> visitOp(TupleExtractOp op) {
1841 auto *tuple = get<TupleStorage *>(op.getTuple());
1842 state[op.getResult()] = tuple->values[op.getIndex().getZExtValue()];
1843 return DeletionKind::Delete;
1844 }
1845
1846 FailureOr<DeletionKind> visitOp(CommentOp op) { return DeletionKind::Keep; }
1847
1848 FailureOr<DeletionKind> visitOp(rtg::YieldOp op) {
1849 return DeletionKind::Keep;
1850 }
1851
1852 FailureOr<DeletionKind> visitOp(ValidateOp op) {
1853 SmallVector<ElaboratorValue> defaultUsedValues, elseValues;
1854 for (auto v : op.getDefaultUsedValues())
1855 defaultUsedValues.push_back(state.at(v));
1856
1857 for (auto v : op.getElseValues())
1858 elseValues.push_back(state.at(v));
1859
1860 auto *validationVal = sharedState.internalizer.create<ValidationValue>(
1861 op.getValue().getType(), state[op.getRef()],
1862 state[op.getDefaultValue()], op.getIdAttr(),
1863 std::move(defaultUsedValues), std::move(elseValues));
1864 state[op.getValue()] = validationVal;
1865 materializer.registerIdentityValue(validationVal);
1866 materializer.map(validationVal, op.getValue());
1867
1868 for (auto [i, val] : llvm::enumerate(op.getValues())) {
1869 auto *muxVal = sharedState.internalizer.create<ValidationMuxedValue>(
1870 val.getType(), validationVal, i);
1871 state[val] = muxVal;
1872 materializer.registerIdentityValue(muxVal);
1873 materializer.map(muxVal, val);
1874 }
1875
1876 return DeletionKind::Keep;
1877 }
1878
1879 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1880 bool cond = get<bool>(op.getCondition());
1881 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1882 if (toElaborate.empty())
1883 return DeletionKind::Delete;
1884
1885 // Just reuse this elaborator for the nested region because we need access
1886 // to the elaborated values outside the nested region (since it is not
1887 // isolated from above) and we want to materialize the region inline, thus
1888 // don't need a new materializer instance.
1889 SmallVector<ElaboratorValue> yieldedVals;
1890 if (failed(elaborate(toElaborate, {}, yieldedVals)))
1891 return failure();
1892
1893 // Map the results of the 'scf.if' to the yielded values.
1894 for (auto [res, out] : llvm::zip(op.getResults(), yieldedVals))
1895 state[res] = out;
1896
1897 return DeletionKind::Delete;
1898 }
1899
1900 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1901 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1902 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1903 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1904 return op->emitOpError("can only elaborate index type iterator");
1905
1906 auto lowerBound = get<size_t>(op.getLowerBound());
1907 auto step = get<size_t>(op.getStep());
1908 auto upperBound = get<size_t>(op.getUpperBound());
1909
1910 // Prepare for first iteration by assigning the nested regions block
1911 // arguments. We can just reuse this elaborator because we need access to
1912 // values elaborated in the parent region anyway and materialize everything
1913 // inline (i.e., don't need a new materializer).
1914 state[op.getInductionVar()] = lowerBound;
1915 for (auto [iterArg, initArg] :
1916 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1917 state[iterArg] = state.at(initArg);
1918
1919 // This loop performs the actual 'scf.for' loop iterations.
1920 SmallVector<ElaboratorValue> yieldedVals;
1921 for (size_t i = lowerBound; i < upperBound; i += step) {
1922 yieldedVals.clear();
1923 if (failed(elaborate(op.getBodyRegion(), {}, yieldedVals)))
1924 return failure();
1925
1926 // Prepare for the next iteration by updating the mapping of the nested
1927 // regions block arguments
1928 state[op.getInductionVar()] = i + step;
1929 for (auto [iterArg, prevIterArg] :
1930 llvm::zip(op.getRegionIterArgs(), yieldedVals))
1931 state[iterArg] = prevIterArg;
1932 }
1933
1934 // Transfer the previously yielded values to the for loop result values.
1935 for (auto [res, iterArg] :
1936 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1937 state[res] = state.at(iterArg);
1938
1939 return DeletionKind::Delete;
1940 }
1941
1942 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1943 return DeletionKind::Delete;
1944 }
1945
1946 FailureOr<DeletionKind> visitOp(arith::AddIOp op) {
1947 if (!isa<IndexType>(op.getType()))
1948 return op->emitError("only index operands supported");
1949
1950 size_t lhs = get<size_t>(op.getLhs());
1951 size_t rhs = get<size_t>(op.getRhs());
1952 state[op.getResult()] = lhs + rhs;
1953 return DeletionKind::Delete;
1954 }
1955
1956 FailureOr<DeletionKind> visitOp(arith::AndIOp op) {
1957 if (!op.getType().isSignlessInteger(1))
1958 return op->emitError("only 'i1' operands supported");
1959
1960 bool lhs = get<bool>(op.getLhs());
1961 bool rhs = get<bool>(op.getRhs());
1962 state[op.getResult()] = lhs && rhs;
1963 return DeletionKind::Delete;
1964 }
1965
1966 FailureOr<DeletionKind> visitOp(arith::XOrIOp op) {
1967 if (!op.getType().isSignlessInteger(1))
1968 return op->emitError("only 'i1' operands supported");
1969
1970 bool lhs = get<bool>(op.getLhs());
1971 bool rhs = get<bool>(op.getRhs());
1972 state[op.getResult()] = lhs != rhs;
1973 return DeletionKind::Delete;
1974 }
1975
1976 FailureOr<DeletionKind> visitOp(arith::OrIOp op) {
1977 if (!op.getType().isSignlessInteger(1))
1978 return op->emitError("only 'i1' operands supported");
1979
1980 bool lhs = get<bool>(op.getLhs());
1981 bool rhs = get<bool>(op.getRhs());
1982 state[op.getResult()] = lhs || rhs;
1983 return DeletionKind::Delete;
1984 }
1985
1986 FailureOr<DeletionKind> visitOp(arith::SelectOp op) {
1987 bool cond = get<bool>(op.getCondition());
1988 auto trueVal = state[op.getTrueValue()];
1989 auto falseVal = state[op.getFalseValue()];
1990 state[op.getResult()] = cond ? trueVal : falseVal;
1991 return DeletionKind::Delete;
1992 }
1993
1994 FailureOr<DeletionKind> visitOp(index::AddOp op) {
1995 size_t lhs = get<size_t>(op.getLhs());
1996 size_t rhs = get<size_t>(op.getRhs());
1997 state[op.getResult()] = lhs + rhs;
1998 return DeletionKind::Delete;
1999 }
2000
2001 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
2002 size_t lhs = get<size_t>(op.getLhs());
2003 size_t rhs = get<size_t>(op.getRhs());
2004 bool result;
2005 switch (op.getPred()) {
2006 case index::IndexCmpPredicate::EQ:
2007 result = lhs == rhs;
2008 break;
2009 case index::IndexCmpPredicate::NE:
2010 result = lhs != rhs;
2011 break;
2012 case index::IndexCmpPredicate::ULT:
2013 result = lhs < rhs;
2014 break;
2015 case index::IndexCmpPredicate::ULE:
2016 result = lhs <= rhs;
2017 break;
2018 case index::IndexCmpPredicate::UGT:
2019 result = lhs > rhs;
2020 break;
2021 case index::IndexCmpPredicate::UGE:
2022 result = lhs >= rhs;
2023 break;
2024 default:
2025 return op->emitOpError("elaboration not supported");
2026 }
2027 state[op.getResult()] = result;
2028 return DeletionKind::Delete;
2029 }
2030
2031 FailureOr<DeletionKind> visitOp(ConcatImmediateOp op) {
2032 bool anyValidationValues =
2033 llvm::any_of(op.getOperands(), [&](auto operand) {
2034 return std::holds_alternative<ValidationValue *>(state[operand]);
2035 });
2036
2037 // FIXME: this is a hack and this case should be handled generically for all
2038 // operations
2039 if (anyValidationValues) {
2040 SmallVector<ElaboratorValue> operands;
2041 for (auto operand : op.getOperands())
2042 operands.push_back(state[operand]);
2043 state[op.getResult()] =
2044 sharedState.internalizer.internalize<ImmediateConcatStorage>(
2045 std::move(operands));
2046 return DeletionKind::Delete;
2047 }
2048
2049 auto result = APInt::getZeroWidth();
2050 for (auto operand : op.getOperands())
2051 result = result.concat(
2052 cast<ImmediateAttr>(get<TypedAttr>(operand)).getValue());
2053
2054 state[op.getResult()] = ImmediateAttr::get(op.getContext(), result);
2055 return DeletionKind::Delete;
2056 }
2057
2058 FailureOr<DeletionKind> visitOp(SliceImmediateOp op) {
2059 // FIXME: this is a hack and this case should be handled generically for all
2060 // operations
2061 if (std::holds_alternative<ValidationValue *>(state[op.getInput()])) {
2062 state[op.getResult()] =
2063 sharedState.internalizer.internalize<ImmediateSliceStorage>(
2064 state[op.getInput()], op.getLowBit(), op.getResult().getType());
2065 return DeletionKind::Delete;
2066 }
2067
2068 auto inputValue =
2069 cast<ImmediateAttr>(get<TypedAttr>(op.getInput())).getValue();
2070 auto sliced = inputValue.extractBits(op.getResult().getType().getWidth(),
2071 op.getLowBit());
2072 state[op.getResult()] = ImmediateAttr::get(op.getContext(), sliced);
2073 return DeletionKind::Delete;
2074 }
2075
2076 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
2077 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
2078 .Case<
2079 // Arith ops
2080 arith::AddIOp, arith::XOrIOp, arith::AndIOp, arith::OrIOp,
2081 arith::SelectOp,
2082 // Index ops
2083 index::AddOp, index::CmpOp,
2084 // SCF ops
2085 scf::IfOp, scf::ForOp, scf::YieldOp>(
2086 [&](auto op) { return visitOp(op); })
2087 .Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); });
2088 }
2089
2090 // NOLINTNEXTLINE(misc-no-recursion)
2091 LogicalResult elaborate(Region &region,
2092 ArrayRef<ElaboratorValue> regionArguments,
2093 SmallVector<ElaboratorValue> &terminatorOperands) {
2094 if (region.getBlocks().size() > 1)
2095 return region.getParentOp()->emitOpError(
2096 "regions with more than one block are not supported");
2097
2098 for (auto [arg, elabArg] :
2099 llvm::zip(region.getArguments(), regionArguments))
2100 state[arg] = elabArg;
2101
2102 Block *block = &region.front();
2103 for (auto &op : *block) {
2104 auto result = dispatchOpVisitor(&op);
2105 if (failed(result))
2106 return failure();
2107
2108 if (*result == DeletionKind::Keep)
2109 if (failed(materializer.materialize(&op, state)))
2110 return failure();
2111
2112 LLVM_DEBUG({
2113 llvm::dbgs() << "Elaborated " << op << " to\n[";
2114
2115 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) {
2116 if (state.contains(res))
2117 llvm::dbgs() << state.at(res);
2118 else
2119 llvm::dbgs() << "unknown";
2120 });
2121
2122 llvm::dbgs() << "]\n\n";
2123 });
2124 }
2125
2126 if (region.front().mightHaveTerminator())
2127 for (auto val : region.front().getTerminator()->getOperands())
2128 terminatorOperands.push_back(state.at(val));
2129
2130 return success();
2131 }
2132
2133private:
2134 // State to be shared between all elaborator instances.
2135 SharedState &sharedState;
2136
2137 // State to a specific RTG test and the sequences placed within it.
2138 TestState &testState;
2139
2140 // Allows us to materialize ElaboratorValues to the IR operations necessary to
2141 // obtain an SSA value representing that elaborated value.
2142 Materializer &materializer;
2143
2144 // A map from SSA values to a pointer of an interned elaborator value.
2145 DenseMap<Value, ElaboratorValue> state;
2146
2147 // The current context we are elaborating under.
2148 ContextResourceAttrInterface currentContext;
2149};
2150} // namespace
2151
2152SequenceOp
2153Materializer::elaborateSequence(const RandomizedSequenceStorage *seq,
2154 SmallVector<ElaboratorValue> &elabArgs) {
2155 auto familyOp =
2156 sharedState.table.lookup<SequenceOp>(seq->sequence->familyName);
2157 // TODO: don't clone if this is the only remaining reference to this
2158 // sequence
2159 OpBuilder builder(familyOp);
2160 auto seqOp = builder.cloneWithoutRegions(familyOp);
2161 auto name = sharedState.names.newName(seq->sequence->familyName.getValue());
2162 seqOp.setSymName(name);
2163 seqOp.getBodyRegion().emplaceBlock();
2164 sharedState.table.insert(seqOp);
2165 assert(seqOp.getSymName() == name && "should not have been renamed");
2166
2167 LLVM_DEBUG(llvm::dbgs() << "\n=== Elaborating sequence family @"
2168 << familyOp.getSymName() << " into @"
2169 << seqOp.getSymName() << " under context "
2170 << seq->context << "\n\n");
2171
2172 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()), testState,
2173 sharedState, elabArgs);
2174 Elaborator elaborator(sharedState, testState, materializer, seq->context);
2175 SmallVector<ElaboratorValue> yieldedVals;
2176 if (failed(elaborator.elaborate(familyOp.getBodyRegion(), seq->sequence->args,
2177 yieldedVals)))
2178 return {};
2179
2180 seqOp.setSequenceType(
2181 SequenceType::get(builder.getContext(), materializer.getBlockArgTypes()));
2182 materializer.finalize();
2183
2184 return seqOp;
2185}
2186
2187//===----------------------------------------------------------------------===//
2188// Elaborator Pass
2189//===----------------------------------------------------------------------===//
2190
2191namespace {
2192struct ElaborationPass
2193 : public rtg::impl::ElaborationPassBase<ElaborationPass> {
2194 using Base::Base;
2195
2196 void runOnOperation() override;
2197 void matchTestsAgainstTargets(SymbolTable &table);
2198 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
2199};
2200} // namespace
2201
2202void ElaborationPass::runOnOperation() {
2203 auto moduleOp = getOperation();
2204 SymbolTable table(moduleOp);
2205
2206 matchTestsAgainstTargets(table);
2207
2208 if (failed(elaborateModule(moduleOp, table)))
2209 return signalPassFailure();
2210}
2211
2212void ElaborationPass::matchTestsAgainstTargets(SymbolTable &table) {
2213 auto moduleOp = getOperation();
2214
2215 for (auto test : llvm::make_early_inc_range(moduleOp.getOps<TestOp>())) {
2216 if (test.getTargetAttr())
2217 continue;
2218
2219 bool matched = false;
2220
2221 for (auto target : moduleOp.getOps<TargetOp>()) {
2222 // Check if the target type is a subtype of the test's target type
2223 // This means that for each entry in the test's target type, there must be
2224 // a corresponding entry with the same name and type in the target's type
2225 bool isSubtype = true;
2226 auto testEntries = test.getTargetType().getEntries();
2227 auto targetEntries = target.getTarget().getEntries();
2228
2229 // Check if target is a subtype of test requirements
2230 // Since entries are sorted by name, we can do this in a single pass
2231 size_t targetIdx = 0;
2232 for (auto testEntry : testEntries) {
2233 // Find the matching entry in target entries.
2234 while (targetIdx < targetEntries.size() &&
2235 targetEntries[targetIdx].name.getValue() <
2236 testEntry.name.getValue())
2237 targetIdx++;
2238
2239 // Check if we found a matching entry with the same name and type
2240 if (targetIdx >= targetEntries.size() ||
2241 targetEntries[targetIdx].name != testEntry.name ||
2242 targetEntries[targetIdx].type != testEntry.type) {
2243 isSubtype = false;
2244 break;
2245 }
2246 }
2247
2248 if (!isSubtype)
2249 continue;
2250
2251 IRRewriter rewriter(test);
2252 // Create a new test for the matched target
2253 auto newTest = cast<TestOp>(test->clone());
2254 newTest.setSymName(test.getSymName().str() + "_" +
2255 target.getSymName().str());
2256
2257 // Set the target symbol specifying that this test is only suitable for
2258 // that target.
2259 newTest.setTargetAttr(target.getSymNameAttr());
2260
2261 table.insert(newTest, rewriter.getInsertionPoint());
2262 matched = true;
2263 }
2264
2265 if (matched || deleteUnmatchedTests)
2266 test->erase();
2267 }
2268}
2269
2270static bool onlyLegalToMaterializeInTarget(Type type) {
2271 return isa<MemoryBlockType, ContextResourceTypeInterface>(type);
2272}
2273
2274LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
2275 SymbolTable &table) {
2276 SharedState state(table, seed);
2277
2278 // Update the name cache
2279 state.names.add(moduleOp);
2280
2281 struct TargetElabResult {
2282 DictType targetType;
2283 SmallVector<ElaboratorValue> yields;
2284 TestState testState;
2285 };
2286
2287 // Map to store elaborated targets
2288 DenseMap<StringAttr, TargetElabResult> targetMap;
2289 for (auto targetOp : moduleOp.getOps<TargetOp>()) {
2290 LLVM_DEBUG(llvm::dbgs() << "=== Elaborating target @"
2291 << targetOp.getSymName() << "\n\n");
2292
2293 auto &result = targetMap[targetOp.getSymNameAttr()];
2294 result.targetType = targetOp.getTarget();
2295
2296 SmallVector<ElaboratorValue> blockArgs;
2297 Materializer targetMaterializer(OpBuilder::atBlockBegin(targetOp.getBody()),
2298 result.testState, state, blockArgs);
2299 Elaborator targetElaborator(state, result.testState, targetMaterializer);
2300
2301 // Elaborate the target
2302 if (failed(targetElaborator.elaborate(targetOp.getBodyRegion(), {},
2303 result.yields)))
2304 return failure();
2305 }
2306
2307 // Initialize the worklist with the test ops since they cannot be placed by
2308 // other ops.
2309 for (auto testOp : moduleOp.getOps<TestOp>()) {
2310 // Skip tests without a target attribute - these couldn't be matched
2311 // against any target but can be useful to keep around for reporting
2312 // purposes.
2313 if (!testOp.getTargetAttr())
2314 continue;
2315
2316 LLVM_DEBUG(llvm::dbgs()
2317 << "\n=== Elaborating test @" << testOp.getTemplateName()
2318 << " for target @" << *testOp.getTarget() << "\n\n");
2319
2320 // Get the target for this test
2321 auto targetResult = targetMap[testOp.getTargetAttr()];
2322 TestState testState = targetResult.testState;
2323 testState.name = testOp.getSymNameAttr();
2324
2325 SmallVector<ElaboratorValue> filteredYields;
2326 unsigned i = 0;
2327 for (auto [entry, yield] :
2328 llvm::zip(targetResult.targetType.getEntries(), targetResult.yields)) {
2329 if (i >= testOp.getTargetType().getEntries().size())
2330 break;
2331
2332 if (entry.name == testOp.getTargetType().getEntries()[i].name) {
2333 filteredYields.push_back(yield);
2334 ++i;
2335 }
2336 }
2337
2338 // Now elaborate the test with the same state, passing the target yield
2339 // values as arguments
2340 SmallVector<ElaboratorValue> blockArgs;
2341 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()),
2342 testState, state, blockArgs);
2343
2344 for (auto [arg, val] :
2345 llvm::zip(testOp.getBody()->getArguments(), filteredYields))
2346 if (onlyLegalToMaterializeInTarget(arg.getType()))
2347 materializer.map(val, arg);
2348
2349 Elaborator elaborator(state, testState, materializer);
2350 SmallVector<ElaboratorValue> ignore;
2351 if (failed(elaborator.elaborate(testOp.getBodyRegion(), filteredYields,
2352 ignore)))
2353 return failure();
2354
2355 materializer.finalize();
2356 }
2357
2358 return success();
2359}
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the in specified range.
static bool onlyLegalToMaterializeInTarget(Type type)
static void print(TypedAttr val, llvm::raw_ostream &os)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
This helps visit TypeOp nodes.
Definition RTGVisitors.h:29
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:90
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
static llvm::hash_code hash_value(const ModulePort &port)
Definition HWTypes.h:38
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition Utils.h:32
Definition rtg.py:1
Definition seq.py:1
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)