CIRCT 23.0.0git
ElaborationPass.cpp
1//===- ElaborationPass.cpp - RTG ElaborationPass implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass elaborates the random parts of the RTG dialect.
10// It performs randomization top-down, i.e., random constructs in a sequence
11// that is invoked multiple times can yield different randomization results
12// for each invocation.
13//
14//===----------------------------------------------------------------------===//
15
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Index/IR/IndexDialect.h"
23#include "mlir/Dialect/Index/IR/IndexOps.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/DenseMapInfoVariant.h"
28#include "llvm/Support/Debug.h"
29#include <random>
30
31namespace circt {
32namespace rtg {
33#define GEN_PASS_DEF_ELABORATIONPASS
34#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
35} // namespace rtg
36} // namespace circt
37
38using namespace mlir;
39using namespace circt;
40using namespace circt::rtg;
41using llvm::MapVector;
42
43#define DEBUG_TYPE "rtg-elaboration"
44
45//===----------------------------------------------------------------------===//
46// Uniform Distribution Helper
47//
48// Simplified version of
49// https://github.com/llvm/llvm-project/blob/main/libcxx/include/__random/uniform_int_distribution.h
50// We use our custom version here to get the same results when compiled with
51// different compiler versions and standard libraries.
52//===----------------------------------------------------------------------===//
53
54static uint32_t computeMask(size_t w) {
55 size_t n = w / 32 + (w % 32 != 0);
56 size_t w0 = w / n;
57 return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
58}
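// Illustrative values (for exposition only): computeMask(1) == 0x1,
// computeMask(4) == 0xF, and computeMask(32) == 0xFFFFFFFF, i.e. for widths
// in [1, 32] the returned mask covers exactly the requested number of low
// bits.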
59
60/// Get a number uniformly at random in the specified range [a, b] (both bounds inclusive).
61static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) {
62 const uint32_t diff = b - a + 1;
63 if (diff == 1)
64 return a;
65
66 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
67 if (diff == 0)
68 return rng();
69
70 uint32_t width = digits - llvm::countl_zero(diff) - 1;
71 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
72 ++width;
73
74 uint32_t mask = computeMask(width);
75 uint32_t u;
76 do {
77 u = rng() & mask;
78 } while (u >= diff);
79
80 return u + a;
81}
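// Illustrative usage sketch (the seed 42 and the range are arbitrary):
//
//   std::mt19937 rng(42);
//   uint32_t r = getUniformlyInRange(rng, 1, 6); // uniform in [1, 6]
//
// Because the helper uses its own rejection sampling instead of
// std::uniform_int_distribution, the same seed yields the same result across
// compilers and standard libraries.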
82
83//===----------------------------------------------------------------------===//
84// Elaborator Value
85//===----------------------------------------------------------------------===//
86
87namespace {
88struct ArrayStorage;
89struct BagStorage;
90struct SequenceStorage;
91struct RandomizedSequenceStorage;
92struct InterleavedSequenceStorage;
93struct SetStorage;
94struct VirtualRegisterStorage;
95struct UniqueLabelStorage;
96struct TupleStorage;
97struct MemoryStorage;
98struct MemoryBlockStorage;
99struct SymbolicComputationWithIdentityStorage;
100struct SymbolicComputationWithIdentityValue;
101struct SymbolicComputationStorage;
102
103/// The variant type holding all possible elaborated values.
104using ElaboratorValue =
105 std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
106 RandomizedSequenceStorage *, InterleavedSequenceStorage *,
107 SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
108 ArrayStorage *, TupleStorage *, MemoryStorage *,
109 MemoryBlockStorage *, SymbolicComputationWithIdentityStorage *,
110 SymbolicComputationWithIdentityValue *,
111 SymbolicComputationStorage *>;
112
113// NOLINTNEXTLINE(readability-identifier-naming)
114llvm::hash_code hash_value(const ElaboratorValue &val) {
115 return std::visit(
116 [&val](const auto &alternative) {
117 // Include the variant index in the hash to make sure the same value
118 // held by different alternatives doesn't collide.
119 return llvm::hash_combine(val.index(), alternative);
120 },
121 val);
122}
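// For illustration (values chosen arbitrarily): the same numeric value held
// by different variant alternatives is expected to hash differently because
// the variant index is mixed in.
//
//   ElaboratorValue a = size_t(1); // 'size_t' alternative
//   ElaboratorValue b = true;      // 'bool' alternative
//   // hash_value(a) and hash_value(b) combine different indices and thus
//   // almost certainly differ.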
123
124} // namespace
125
126namespace llvm {
127
128template <>
129struct DenseMapInfo<bool> {
130 static inline unsigned getEmptyKey() { return false; }
131 static inline unsigned getTombstoneKey() { return true; }
132 static unsigned getHashValue(const bool &val) { return val * 37U; }
133
134 static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; }
135};
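// Note: the specialization above is provided because 'bool' is one of the
// ElaboratorValue alternatives and ElaboratorValue is used as a DenseMap key
// below; the variant's DenseMapInfo (see llvm/ADT/DenseMapInfoVariant.h) needs
// hashing and equality support for every alternative, while its empty and
// tombstone keys are taken from the first alternative only.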
136
137} // namespace llvm
138
139//===----------------------------------------------------------------------===//
140// Elaborator Value Storages and Internalization
141//===----------------------------------------------------------------------===//
142
143namespace {
144
145/// Lightweight object to be used as the key for internalization sets. It caches
146/// the hashcode of the internalized object and a pointer to it. This allows the
147/// allocation and construction of the actual object to be delayed so that it
148/// only happens if the object is not already in the set.
149template <typename StorageTy>
150struct HashedStorage {
151 HashedStorage(unsigned hashcode = 0, StorageTy *storage = nullptr)
152 : hashcode(hashcode), storage(storage) {}
153
154 unsigned hashcode;
155 StorageTy *storage;
156};
157
158/// A DenseMapInfo implementation to support 'insert_as' for the internalization
159/// sets. When comparing two 'HashedStorage's we can just compare the already
160/// internalized storage pointers, otherwise we have to call the costly
161/// 'isEqual' method.
162template <typename StorageTy>
163struct StorageKeyInfo {
164 static inline HashedStorage<StorageTy> getEmptyKey() {
165 return HashedStorage<StorageTy>(0,
166 DenseMapInfo<StorageTy *>::getEmptyKey());
167 }
168 static inline HashedStorage<StorageTy> getTombstoneKey() {
169 return HashedStorage<StorageTy>(
170 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
171 }
172
173 static inline unsigned getHashValue(const HashedStorage<StorageTy> &key) {
174 return key.hashcode;
175 }
176 static inline unsigned getHashValue(const StorageTy &key) {
177 return key.hashcode;
178 }
179
180 static inline bool isEqual(const HashedStorage<StorageTy> &lhs,
181 const HashedStorage<StorageTy> &rhs) {
182 return lhs.storage == rhs.storage;
183 }
184 static inline bool isEqual(const StorageTy &lhs,
185 const HashedStorage<StorageTy> &rhs) {
186 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
187 return false;
188
189 return lhs.isEqual(rhs.storage);
190 }
191};
192
193// Values with structural equivalence intended to be internalized.
194//===----------------------------------------------------------------------===//
195
196/// Base class for all storage objects that can be materialized as attributes.
197/// This provides a cache for the attribute that was materialized from this
198/// storage.
199struct CachableStorage {
200 TypedAttr attrCache;
201};
202
203/// Storage object for an '!rtg.set<T>'.
204struct SetStorage : CachableStorage {
205 static unsigned computeHash(const SetVector<ElaboratorValue> &set,
206 Type type) {
207 llvm::hash_code setHash = 0;
208 for (auto el : set) {
209 // Just XOR all hashes because it's a commutative operation and
210 // `llvm::hash_combine_range` is not commutative.
211 // We don't want the order in which elements were added to influence the
212 // hash and thus the equivalence of sets.
213 setHash = setHash ^ llvm::hash_combine(el);
214 }
215 return llvm::hash_combine(type, setHash);
216 }
217
218 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
219 : hashcode(computeHash(set, type)), set(std::move(set)), type(type) {}
220
221 bool isEqual(const SetStorage *other) const {
222 // Note: we are not using the `==` operator of `SetVector` because it
223 // takes the order in which elements were added into account (since it's a
224 // vector after all). We just use it as a convenient way to keep track of a
225 // deterministic order for re-materialization.
226 bool allContained = true;
227 for (auto el : set)
228 allContained &= other->set.contains(el);
229
230 return hashcode == other->hashcode && set.size() == other->set.size() &&
231 allContained && type == other->type;
232 }
233
234 // The cached hashcode to avoid repeated computations.
235 const unsigned hashcode;
236
237 // Stores the elaborated values contained in the set.
238 const SetVector<ElaboratorValue> set;
239
240 // Store the set type such that we can materialize this evaluated value
241 // even when the set is empty.
242 const Type type;
243};
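// Illustrative sketch ('internalizer' and 'setType' are assumed to be in
// scope): because the hash XORs the element hashes and 'isEqual' only checks
// size and containment, two sets populated in different orders intern to the
// same storage.
//
//   SetVector<ElaboratorValue> s1, s2;
//   s1.insert(size_t(1)); s1.insert(size_t(2));
//   s2.insert(size_t(2)); s2.insert(size_t(1));
//   // internalizer.internalize<SetStorage>(std::move(s1), setType) ==
//   // internalizer.internalize<SetStorage>(std::move(s2), setType)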
244
245/// Storage object for an '!rtg.bag<T>'.
246struct BagStorage {
247 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
248 : hashcode(llvm::hash_combine(
249 type, llvm::hash_combine_range(bag.begin(), bag.end()))),
250 bag(std::move(bag)), type(type) {}
251
252 bool isEqual(const BagStorage *other) const {
253 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
254 type == other->type;
255 }
256
257 // The cached hashcode to avoid repeated computations.
258 const unsigned hashcode;
259
260 // Stores the elaborated values contained in the bag with their number of
261 // occurrences.
262 const MapVector<ElaboratorValue, uint64_t> bag;
263
264 // Store the bag type such that we can materialize this evaluated value
265 // even when the bag is empty.
266 const Type type;
267};
268
269/// Storage object for an '!rtg.sequence'.
270struct SequenceStorage {
271 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
272 : hashcode(llvm::hash_combine(
273 familyName, llvm::hash_combine_range(args.begin(), args.end()))),
274 familyName(familyName), args(std::move(args)) {}
275
276 bool isEqual(const SequenceStorage *other) const {
277 return hashcode == other->hashcode && familyName == other->familyName &&
278 args == other->args;
279 }
280
281 // The cached hashcode to avoid repeated computations.
282 const unsigned hashcode;
283
284 // The name of the sequence family this sequence is derived from.
285 const StringAttr familyName;
286
287 // The elaborator values used during substitution of the sequence family.
288 const SmallVector<ElaboratorValue> args;
289};
290
291/// Storage object for interleaved '!rtg.randomized_sequence' values.
292struct InterleavedSequenceStorage {
293 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
294 uint32_t batchSize)
295 : sequences(std::move(sequences)), batchSize(batchSize),
296 hashcode(llvm::hash_combine(
297 llvm::hash_combine_range(sequences.begin(), sequences.end()),
298 batchSize)) {}
299
300 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
301 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
302 hashcode(llvm::hash_combine(
303 llvm::hash_combine_range(sequences.begin(), sequences.end()),
304 batchSize)) {}
305
306 bool isEqual(const InterleavedSequenceStorage *other) const {
307 return hashcode == other->hashcode && sequences == other->sequences &&
308 batchSize == other->batchSize;
309 }
310
311 const SmallVector<ElaboratorValue> sequences;
312
313 const uint32_t batchSize;
314
315 // The cached hashcode to avoid repeated computations.
316 const unsigned hashcode;
317};
318
319/// Storage object for '!rtg.array'-typed values.
320struct ArrayStorage {
321 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
322 : hashcode(llvm::hash_combine(
323 type, llvm::hash_combine_range(array.begin(), array.end()))),
324 type(type), array(array) {}
325
326 bool isEqual(const ArrayStorage *other) const {
327 return hashcode == other->hashcode && type == other->type &&
328 array == other->array;
329 }
330
331 // The cached hashcode to avoid repeated computations.
332 const unsigned hashcode;
333
334 /// The type of the array. This is necessary because an array of size 0
335 /// cannot be reconstructed without knowing the original element type.
336 const Type type;
337
338 /// The elaborated values contained in the array.
339 const SmallVector<ElaboratorValue> array;
340};
341
342/// Storage object for 'tuple'-typed values.
343struct TupleStorage : CachableStorage {
344 TupleStorage(SmallVector<ElaboratorValue> &&values)
345 : hashcode(llvm::hash_combine_range(values.begin(), values.end())),
346 values(std::move(values)) {}
347
348 bool isEqual(const TupleStorage *other) const {
349 return hashcode == other->hashcode && values == other->values;
350 }
351
352 // The cached hashcode to avoid repeated computations.
353 const unsigned hashcode;
354
355 const SmallVector<ElaboratorValue> values;
356};
357
358struct SymbolicComputationStorage {
359 SymbolicComputationStorage(const DenseMap<Value, ElaboratorValue> &state,
360 Operation *op)
361 : name(op->getName()), resultTypes(op->getResultTypes()),
362 operands(llvm::map_range(op->getOperands(),
363 [&](Value v) { return state.lookup(v); })),
364 attributes(op->getAttrDictionary()),
365 properties(op->getPropertiesAsAttribute()),
366 hashcode(llvm::hash_combine(name, llvm::hash_combine_range(resultTypes),
367 llvm::hash_combine_range(operands),
368 attributes, op->hashProperties())) {}
369
370 bool isEqual(const SymbolicComputationStorage *other) const {
371 return hashcode == other->hashcode && name == other->name &&
372 resultTypes == other->resultTypes && operands == other->operands &&
373 attributes == other->attributes && properties == other->properties;
374 }
375
376 const OperationName name;
377 const SmallVector<Type> resultTypes;
378 const SmallVector<ElaboratorValue> operands;
379 const DictionaryAttr attributes;
380 const Attribute properties;
381 const unsigned hashcode;
382};
383
384// Values with identity not intended to be internalized.
385//===----------------------------------------------------------------------===//
386
387/// Base class for storages that represent values with identity, i.e., two
388/// values are not considered equivalent if they are structurally the same, but
389/// each definition of such a value is unique. E.g., unique labels or virtual
390/// registers. These cannot be materialized anew in each nested sequence, but
391/// must be passed as arguments.
392struct IdentityValue {
393
394 IdentityValue(Type type, Location loc) : type(type), loc(loc) {}
395
396#ifndef NDEBUG
397
398 /// In debug mode, track whether this value was already materialized to
399 /// assert if it's illegally materialized multiple times.
400 ///
401 /// Instead of deleting operations defining these values and materializing
402 /// them again, we could retain the operations. However, we still need
403 /// specific storages to represent these values in some cases, e.g., to get
404 /// the size of a memory allocation. Also, elaboration of nested control-flow
405 /// regions (e.g. `scf.for`) relies on materialization of such values lazily
406 /// instead of cloning the operations eagerly.
407 bool alreadyMaterialized = false;
408
409#endif
410
411 const Type type;
412 const Location loc;
413};
414
415/// Represents a unique virtual register.
416struct VirtualRegisterStorage : IdentityValue {
417 VirtualRegisterStorage(VirtualRegisterConfigAttr allowedRegs, Type type,
418 Location loc)
419 : IdentityValue(type, loc), allowedRegs(allowedRegs) {}
420
421 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
422 // VirtualRegisters are never internalized.
423
424 // The list of fixed registers allowed to be selected for this virtual
425 // register.
426 const VirtualRegisterConfigAttr allowedRegs;
427};
428
429struct UniqueLabelStorage : IdentityValue {
430 UniqueLabelStorage(const ElaboratorValue &name, Location loc)
431 : IdentityValue(LabelType::get(loc->getContext()), loc), name(name) {}
432
433 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
434 // unique labels are never internalized.
435
436 /// The label name. For unique labels, this is just the prefix.
437 const ElaboratorValue name;
438};
439
440/// Storage object for '!rtg.isa.memoryblock'-typed values.
441struct MemoryBlockStorage : IdentityValue {
442 MemoryBlockStorage(const APInt &baseAddress, const APInt &endAddress,
443 Type type, Location loc)
444 : IdentityValue(type, loc), baseAddress(baseAddress),
445 endAddress(endAddress) {}
446
447 // The base address of the memory. The width of the APInt also represents the
448 // address width of the memory. An APInt is used to support machines with
449 // address widths greater than 64 bits.
450 const APInt baseAddress;
451
452 // The last address of the memory.
453 const APInt endAddress;
454};
455
456/// Storage object for '!rtg.isa.memory'-typed values.
457struct MemoryStorage : IdentityValue {
458 MemoryStorage(MemoryBlockStorage *memoryBlock, size_t size, size_t alignment,
459 Location loc)
460 : IdentityValue(MemoryType::get(memoryBlock->type.getContext(),
461 memoryBlock->baseAddress.getBitWidth()),
462 loc),
463 memoryBlock(memoryBlock), size(size), alignment(alignment) {}
464
465 MemoryBlockStorage *memoryBlock;
466 const size_t size;
467 const size_t alignment;
468};
469
470/// Storage object for an '!rtg.randomized_sequence'.
471struct RandomizedSequenceStorage : IdentityValue {
472 RandomizedSequenceStorage(ContextResourceAttrInterface context,
473 SequenceStorage *sequence, Location loc)
474 : IdentityValue(
475 RandomizedSequenceType::get(sequence->familyName.getContext()),
476 loc),
477 context(context), sequence(sequence) {}
478
479 // The context under which this sequence is placed.
480 const ContextResourceAttrInterface context;
481
482 const SequenceStorage *sequence;
483};
484
485/// Storage object for a symbolic computation with identity. The operation must have at least one result.
486struct SymbolicComputationWithIdentityStorage : IdentityValue {
487 SymbolicComputationWithIdentityStorage(
488 const DenseMap<Value, ElaboratorValue> &state, Operation *op)
489 : IdentityValue(op->getResult(0).getType(), op->getLoc()),
490 name(op->getName()), resultTypes(op->getResultTypes()),
491 operands(llvm::map_range(op->getOperands(),
492 [&](Value v) { return state.lookup(v); })),
493 attributes(op->getAttrDictionary()),
494 properties(op->getPropertiesAsAttribute()) {}
495
496 const OperationName name;
497 const SmallVector<Type> resultTypes;
498 const SmallVector<ElaboratorValue> operands;
499 const DictionaryAttr attributes;
500 const Attribute properties;
501};
502
503struct SymbolicComputationWithIdentityValue : IdentityValue {
504 SymbolicComputationWithIdentityValue(
505 Type type, const SymbolicComputationWithIdentityStorage *storage,
506 unsigned idx)
507 : IdentityValue(type, storage->loc), storage(storage), idx(idx) {
508 assert(
509 idx != 0 &&
510 "Use SymbolicComputationWithIdentityStorage for result with index 0.");
511 }
512
513 const SymbolicComputationWithIdentityStorage *storage;
514 const unsigned idx;
515};
516
517/// An 'Internalizer' object internalizes storages and takes ownership of them.
518/// When the internalizer object is destroyed, all owned storages are also
519/// deallocated and thus must not be accessed anymore.
520class Internalizer {
521public:
522 /// Internalize a storage of type `StorageTy` constructed with arguments
523 /// `args`. The pointers returned by this method can be used to compare
524 /// objects when, e.g., computing set differences, uniquing the elements in a
525 /// set, etc. Otherwise, we'd need to do a deep value comparison in those
526 /// situations.
527 template <typename StorageTy, typename... Args>
528 StorageTy *internalize(Args &&...args) {
529 static_assert(!std::is_base_of_v<IdentityValue, StorageTy> &&
530 "values with identity must not be internalized");
531
532 StorageTy storage(std::forward<Args>(args)...);
533
534 auto existing = getInternSet<StorageTy>().insert_as(
535 HashedStorage<StorageTy>(storage.hashcode), storage);
536 StorageTy *&storagePtr = existing.first->storage;
537 if (existing.second)
538 storagePtr =
539 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
540
541 return storagePtr;
542 }
543
544 template <typename StorageTy, typename... Args>
545 StorageTy *create(Args &&...args) {
546 static_assert(std::is_base_of_v<IdentityValue, StorageTy> &&
547 "values with structural equivalence must be internalized");
548
549 return new (allocator.Allocate<StorageTy>())
550 StorageTy(std::forward<Args>(args)...);
551 }
552
553private:
554 template <typename StorageTy>
555 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
556 getInternSet() {
557 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
558 return internedArrays;
559 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
560 return internedSets;
561 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
562 return internedBags;
563 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
564 return internedSequences;
565 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
566 return internedRandomizedSequences;
567 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
568 return internedInterleavedSequences;
569 else if constexpr (std::is_same_v<StorageTy, TupleStorage>)
570 return internedTuples;
571 else if constexpr (std::is_same_v<StorageTy, SymbolicComputationStorage>)
572 return internedSymbolicComputationWithIdentityValues;
573 else
574 static_assert(!sizeof(StorageTy),
575 "no intern set available for this storage type.");
576 }
577
578 // This allocator allocates on the heap. It automatically deallocates all
579 // objects it allocated once the allocator itself is destroyed.
580 llvm::BumpPtrAllocator allocator;
581
582 // The sets holding the internalized objects. We use one set per storage type
583 // such that we can have a simpler equality checking function (no need to
584 // compare some sort of TypeIDs).
585 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
586 internedArrays;
587 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
588 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
589 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
590 internedSequences;
591 DenseSet<HashedStorage<RandomizedSequenceStorage>,
592 StorageKeyInfo<RandomizedSequenceStorage>>
593 internedRandomizedSequences;
594 DenseSet<HashedStorage<InterleavedSequenceStorage>,
595 StorageKeyInfo<InterleavedSequenceStorage>>
596 internedInterleavedSequences;
597 DenseSet<HashedStorage<TupleStorage>, StorageKeyInfo<TupleStorage>>
598 internedTuples;
599 DenseSet<HashedStorage<SymbolicComputationStorage>,
600 StorageKeyInfo<SymbolicComputationStorage>>
601 internedSymbolicComputationWithIdentityValues;
602};
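// Illustrative usage sketch ('labelName' and 'loc' are hypothetical values
// assumed to be in scope):
//
//   Internalizer internalizer;
//   auto *t1 = internalizer.internalize<TupleStorage>(
//       SmallVector<ElaboratorValue>{size_t(1), true});
//   auto *t2 = internalizer.internalize<TupleStorage>(
//       SmallVector<ElaboratorValue>{size_t(1), true});
//   // t1 == t2: structurally equivalent values share one storage.
//
//   auto *l1 = internalizer.create<UniqueLabelStorage>(labelName, loc);
//   auto *l2 = internalizer.create<UniqueLabelStorage>(labelName, loc);
//   // l1 != l2: identity values always get a fresh storage.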
603
604} // namespace
605
606#ifndef NDEBUG
607
608static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
609 const ElaboratorValue &value);
610
611static void print(TypedAttr val, llvm::raw_ostream &os) {
612 os << "<attr " << val << ">";
613}
614
615static void print(BagStorage *val, llvm::raw_ostream &os) {
616 os << "<bag {";
617 llvm::interleaveComma(val->bag, os,
618 [&](const std::pair<ElaboratorValue, uint64_t> &el) {
619 os << el.first << " -> " << el.second;
620 });
621 os << "} at " << val << ">";
622}
623
624static void print(bool val, llvm::raw_ostream &os) {
625 os << "<bool " << (val ? "true" : "false") << ">";
626}
627
628static void print(size_t val, llvm::raw_ostream &os) {
629 os << "<index " << val << ">";
630}
631
632static void print(SequenceStorage *val, llvm::raw_ostream &os) {
633 os << "<sequence @" << val->familyName.getValue() << "(";
634 llvm::interleaveComma(val->args, os,
635 [&](const ElaboratorValue &val) { os << val; });
636 os << ") at " << val << ">";
637}
638
639static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
640 os << "<randomized-sequence derived from @"
641 << val->sequence->familyName.getValue() << " under context "
642 << val->context << "(";
643 llvm::interleaveComma(val->sequence->args, os,
644 [&](const ElaboratorValue &val) { os << val; });
645 os << ") at " << val << ">";
646}
647
648static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
649 os << "<interleaved-sequence [";
650 llvm::interleaveComma(val->sequences, os,
651 [&](const ElaboratorValue &val) { os << val; });
652 os << "] batch-size " << val->batchSize << " at " << val << ">";
653}
654
655static void print(ArrayStorage *val, llvm::raw_ostream &os) {
656 os << "<array [";
657 llvm::interleaveComma(val->array, os,
658 [&](const ElaboratorValue &val) { os << val; });
659 os << "] at " << val << ">";
660}
661
662static void print(SetStorage *val, llvm::raw_ostream &os) {
663 os << "<set {";
664 llvm::interleaveComma(val->set, os,
665 [&](const ElaboratorValue &val) { os << val; });
666 os << "} at " << val << ">";
667}
668
669static void print(const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
670 os << "<virtual-register " << val << " " << val->allowedRegs << ">";
671}
672
673static void print(const UniqueLabelStorage *val, llvm::raw_ostream &os) {
674 os << "<unique-label " << val << " " << val->name << ">";
675}
676
677static void print(const TupleStorage *val, llvm::raw_ostream &os) {
678 os << "<tuple (";
679 llvm::interleaveComma(val->values, os,
680 [&](const ElaboratorValue &val) { os << val; });
681 os << ")>";
682}
683
684static void print(const MemoryStorage *val, llvm::raw_ostream &os) {
685 os << "<memory {" << ElaboratorValue(val->memoryBlock)
686 << ", size=" << val->size << ", alignment=" << val->alignment << "}>";
687}
688
689static void print(const MemoryBlockStorage *val, llvm::raw_ostream &os) {
690 os << "<memory-block {"
691 << ", address-width=" << val->baseAddress.getBitWidth()
692 << ", base-address=" << val->baseAddress
693 << ", end-address=" << val->endAddress << "}>";
694}
695
696static void print(const SymbolicComputationWithIdentityValue *val,
697 llvm::raw_ostream &os) {
698 os << "<symbolic-computation-with-identity-value (" << val->storage << ") at "
699 << val->idx << ">";
700}
701
702static void print(const SymbolicComputationWithIdentityStorage *val,
703 llvm::raw_ostream &os) {
704 os << "<symbolic-computation-with-identity " << val->name << "(";
705 llvm::interleaveComma(val->operands, os,
706 [&](const ElaboratorValue &val) { os << val; });
707 os << ") -> " << val->resultTypes << " with attributes " << val->attributes
708 << " and properties " << val->properties;
709 os << ">";
710}
711
712static void print(const SymbolicComputationStorage *val,
713 llvm::raw_ostream &os) {
714 os << "<symbolic-computation " << val->name << "(";
715 llvm::interleaveComma(val->operands, os,
716 [&](const ElaboratorValue &val) { os << val; });
717 os << ") -> " << val->resultTypes << " with attributes " << val->attributes
718 << " and properties " << val->properties;
719 os << ">";
720}
721
722static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
723 const ElaboratorValue &value) {
724 std::visit([&](auto val) { print(val, os); }, value);
725
726 return os;
727}
728
729#endif
730
731//===----------------------------------------------------------------------===//
732// Attribute <-> ElaboratorValue Converters
733//===----------------------------------------------------------------------===//
734
735namespace {
736
737/// Convert an attribute to an ElaboratorValue. This handles nested attributes
738/// like SetAttr and TupleAttr recursively.
739class AttributeToElaboratorValueConverter {
740public:
741 AttributeToElaboratorValueConverter(Internalizer &internalizer)
742 : internalizer(internalizer) {}
743
744 /// Convert an attribute to an ElaboratorValue.
745 FailureOr<ElaboratorValue> convert(Attribute attr) {
746 return llvm::TypeSwitch<Attribute, FailureOr<ElaboratorValue>>(attr)
747 .Case<IntegerAttr, SetAttr, TupleAttr>(
748 [&](auto attr) { return convert(attr); })
749 .Case<TypedAttr>([&](auto typedAttr) -> FailureOr<ElaboratorValue> {
750 return ElaboratorValue(typedAttr);
751 })
752 .Default(
753 [&](Attribute) -> FailureOr<ElaboratorValue> { return failure(); });
754 }
755
756private:
757 FailureOr<ElaboratorValue> convert(IntegerAttr attr) {
758 if (attr.getType().isSignlessInteger(1))
759 return ElaboratorValue(bool(attr.getInt()));
760 if (isa<IndexType>(attr.getType()))
761 return ElaboratorValue(size_t(attr.getInt()));
762 return ElaboratorValue(attr);
763 }
764
765 FailureOr<ElaboratorValue> convert(SetAttr setAttr) {
766 SetVector<ElaboratorValue> set;
767 for (auto element : *setAttr.getElements()) {
768 auto converted = convert(element);
769 if (failed(converted))
770 return failure();
771 set.insert(*converted);
772 }
773 auto *storage =
774 internalizer.internalize<SetStorage>(std::move(set), setAttr.getType());
775 // Cache the original attribute for efficient materialization
776 storage->attrCache = setAttr;
777 return ElaboratorValue(storage);
778 }
779
780 FailureOr<ElaboratorValue> convert(TupleAttr tupleAttr) {
781 SmallVector<ElaboratorValue> values;
782 for (auto element : tupleAttr.getElements()) {
783 auto converted = convert(element);
784 if (failed(converted))
785 return failure();
786 values.push_back(*converted);
787 }
788 auto *storage = internalizer.internalize<TupleStorage>(std::move(values));
789 // Cache the original attribute for efficient materialization
790 storage->attrCache = tupleAttr;
791 return ElaboratorValue(storage);
792 }
793
794 Internalizer &internalizer;
795};
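// Illustrative sketch ('ctx' is an MLIRContext* and 'internalizer' an
// Internalizer, both assumed to be in scope):
//
//   AttributeToElaboratorValueConverter conv(internalizer);
//   auto v = conv.convert(IntegerAttr::get(IndexType::get(ctx), 5));
//   // succeeded(v) && std::holds_alternative<size_t>(*v)
//   auto b = conv.convert(IntegerAttr::get(IntegerType::get(ctx, 1), 1));
//   // succeeded(b) && std::holds_alternative<bool>(*b)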
796
797/// Convert an ElaboratorValue to an attribute when possible. Not all
798/// ElaboratorValues can be converted to attributes (e.g., values with
799/// identity).
800class ElaboratorValueToAttributeConverter {
801public:
802 ElaboratorValueToAttributeConverter(MLIRContext *context)
803 : context(context) {}
804
805 /// Convert an ElaboratorValue to an attribute. Returns a null attribute if
806 /// the conversion is not possible.
807 TypedAttr convert(const ElaboratorValue &value) {
808 return std::visit(
809 [&](auto val) -> TypedAttr {
810 if constexpr (std::is_base_of_v<CachableStorage,
811 std::remove_pointer_t<
812 std::decay_t<decltype(val)>>>) {
813 if (val->attrCache)
814 return val->attrCache;
815 }
816 return visit(val);
817 },
818 value);
819 }
820
821private:
822 TypedAttr visit(TypedAttr val) { return val; }
823
824 TypedAttr visit(bool val) {
825 return IntegerAttr::get(IntegerType::get(context, 1), val);
826 }
827
828 TypedAttr visit(size_t val) {
829 return IntegerAttr::get(IndexType::get(context), val);
830 }
831
832 TypedAttr visit(SetStorage *val) {
833 DenseSet<TypedAttr> elements;
834 for (auto element : val->set) {
835 auto converted = convert(element);
836 if (!converted)
837 return {};
838 auto typedAttr = dyn_cast<TypedAttr>(converted);
839 if (!typedAttr)
840 return {};
841 elements.insert(typedAttr);
842 }
843 return SetAttr::get(cast<SetType>(val->type), &elements);
844 }
845
846 TypedAttr visit(TupleStorage *val) {
847 SmallVector<TypedAttr> elements;
848 for (auto element : val->values) {
849 auto converted = convert(element);
850 if (!converted)
851 return {};
852 auto typedAttr = dyn_cast<TypedAttr>(converted);
853 if (!typedAttr)
854 return {};
855 elements.push_back(typedAttr);
856 }
857 return TupleAttr::get(context, elements);
858 }
859
860 // Default implementation for storage types that cannot be converted to
861 // attributes. Using a macro to avoid repetition. std::visit does not support
862 // any kind of "default" clauses, unfortunately.
863#define VISIT_UNSUPPORTED(STORAGETYPE) \
864 /* NOLINTNEXTLINE(bugprone-macro-parentheses)*/ \
865 TypedAttr visit(STORAGETYPE *val) { return {}; }
866
867 VISIT_UNSUPPORTED(ArrayStorage)
868 VISIT_UNSUPPORTED(BagStorage)
869 VISIT_UNSUPPORTED(SequenceStorage)
870 VISIT_UNSUPPORTED(RandomizedSequenceStorage)
871 VISIT_UNSUPPORTED(InterleavedSequenceStorage)
872 VISIT_UNSUPPORTED(VirtualRegisterStorage)
873 VISIT_UNSUPPORTED(UniqueLabelStorage)
874 VISIT_UNSUPPORTED(MemoryStorage)
875 VISIT_UNSUPPORTED(MemoryBlockStorage)
876 VISIT_UNSUPPORTED(SymbolicComputationWithIdentityStorage)
877 VISIT_UNSUPPORTED(SymbolicComputationWithIdentityValue)
878 VISIT_UNSUPPORTED(SymbolicComputationStorage)
879
880#undef VISIT_UNSUPPORTED
881
882 MLIRContext *context;
883};
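// Illustrative sketch ('ctx' assumed to be in scope): values with an attribute
// form convert back to an attribute that can later be materialized as a
// constant, while identity values (e.g. virtual registers) yield a null
// attribute and must be materialized as operations instead.
//
//   ElaboratorValueToAttributeConverter toAttr(ctx);
//   TypedAttr five = toAttr.convert(ElaboratorValue(size_t(5)));
//   // 'five' is an index-typed IntegerAttr holding 5.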
884
885} // namespace
886
887//===----------------------------------------------------------------------===//
888// Elaborator Value Materialization
889//===----------------------------------------------------------------------===//
890
891namespace {
892
893/// State that should be shared by all elaborator and materializer instances.
894struct SharedState {
895 SharedState(MLIRContext *ctxt, SymbolTable &table, unsigned seed)
896 : ctxt(ctxt), table(table), rng(seed) {}
897
898 MLIRContext *ctxt;
899 SymbolTable &table;
900 std::mt19937 rng;
901 Namespace names;
902 Internalizer internalizer;
903};
904
905/// A collection of state per RTG test.
906struct TestState {
907 /// The name of the test.
908 StringAttr name;
909
910 /// The context switches registered for this test.
911 MapVector<
912 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
913 SequenceStorage *>
914 contextSwitches;
915};
916
917/// Construct an SSA value from a given elaborated value.
918class Materializer {
919public:
920 Materializer(OpBuilder builder, TestState &testState,
921 SharedState &sharedState,
922 SmallVector<ElaboratorValue> &blockArgs)
923 : builder(builder), testState(testState), sharedState(sharedState),
924 blockArgs(blockArgs), attrConverter(builder.getContext()) {}
925
926 /// Materialize IR representing the provided `ElaboratorValue` and return the
927 /// `Value` or a null value on failure.
928 Value materialize(ElaboratorValue val, Location loc,
929 function_ref<InFlightDiagnostic()> emitError) {
930 auto iter = materializedValues.find(val);
931 if (iter != materializedValues.end())
932 return iter->second;
933
934 LLVM_DEBUG(llvm::dbgs() << "Materializing " << val);
935
936 if (auto res = tryMaterializeAsConstant(val, loc))
937 return res;
938
939 // In debug mode, track whether values with identity were already
940 // materialized before and assert in such a situation.
941 Value res = std::visit(
942 [&](auto value) {
943 if constexpr (std::is_base_of_v<IdentityValue,
944 std::remove_pointer_t<
945 std::decay_t<decltype(value)>>>) {
946 if (identityValueRoot.contains(value)) {
947#ifndef NDEBUG
948 bool &materialized =
949 static_cast<IdentityValue *>(value)->alreadyMaterialized;
950 assert(!materialized && "must not already be materialized");
951 materialized = true;
952#endif
953
954 return visit(value, loc, emitError);
955 }
956
957 Value arg = builder.getBlock()->addArgument(value->type, loc);
958 blockArgs.push_back(val);
959 blockArgTypes.push_back(arg.getType());
960 materializedValues[val] = arg;
961 return arg;
962 }
963
964 return visit(value, loc, emitError);
965 },
966 val);
967
968 LLVM_DEBUG(llvm::dbgs() << " to\n" << res << "\n\n");
969
970 return res;
971 }
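// Behavior sketch (names are hypothetical): for an identity value 'vreg' that
// was announced via registerIdentityValue(vreg), the call below emits the
// defining operation inline exactly once and caches the result; for an
// identity value that was not registered with this materializer, a new block
// argument is added instead and the surrounding sequence is expected to pass
// the value in.
//
//   Value inlined  = materializer.materialize(vreg, loc, emitError);
//   Value blockArg = materializer.materialize(foreignVreg, loc, emitError);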
972
973 /// If `op` is not in the same region as the materializer insertion point, a
974 /// clone is created at the materializer's insertion point by also
975 /// materializing the `ElaboratorValue`s for each operand just before it.
976 /// Otherwise, all operations after the materializer's insertion point are
977 /// deleted until `op` is reached. An error is returned if the operation is
978 /// before the insertion point.
979 LogicalResult materialize(Operation *op,
980 DenseMap<Value, ElaboratorValue> &state) {
981 if (op->getNumRegions() > 0)
982 return op->emitOpError("ops with nested regions must be elaborated away");
983
984 // We don't support opaque values. If there is an SSA value that has a
985 // use-site it needs an equivalent ElaborationValue representation.
986 // NOTE: We could support cases where there is initially a use-site but that
987 // op is guaranteed to be deleted during elaboration. Or the use-sites are
988 // replaced with freshly materialized values from the ElaborationValue. But
989 // then, why can't we delete the value defining op?
990 for (auto res : op->getResults())
991 if (!res.use_empty())
992 return op->emitOpError(
993 "ops with results that have uses are not supported");
994
995 if (op->getParentRegion() == builder.getBlock()->getParent()) {
996 // We are doing in-place materialization, so mark all ops deleted until we
997 // reach the one to be materialized and modify it in-place.
998 deleteOpsUntil([&](auto iter) { return &*iter == op; });
999
1000 if (builder.getInsertionPoint() == builder.getBlock()->end())
1001 return op->emitError("operation did not occur after the current "
1002 "materializer insertion point");
1003
1004 LLVM_DEBUG(llvm::dbgs() << "Modifying in-place: " << *op << "\n\n");
1005 } else {
1006 LLVM_DEBUG(llvm::dbgs() << "Materializing a clone of " << *op << "\n\n");
1007 op = builder.clone(*op);
1008 builder.setInsertionPoint(op);
1009 }
1010
1011 for (auto &operand : op->getOpOperands()) {
1012 auto emitError = [&]() {
1013 auto diag = op->emitError();
1014 diag.attachNote(op->getLoc())
1015 << "while materializing value for operand#"
1016 << operand.getOperandNumber();
1017 return diag;
1018 };
1019
1020 auto elabVal = state.at(operand.get());
1021 Value val = materialize(elabVal, op->getLoc(), emitError);
1022 if (!val)
1023 return failure();
1024
1025 state[val] = elabVal;
1026 operand.set(val);
1027 }
1028
1029 builder.setInsertionPointAfter(op);
1030 return success();
1031 }
1032
1033 /// Should be called once the `Region` is successfully materialized. No calls
1034 /// to `materialize` should happen after this anymore.
1035 void finalize() {
1036 deleteOpsUntil([](auto iter) { return false; });
1037
1038 for (auto *op : llvm::reverse(toDelete))
1039 op->erase();
1040 }
1041
1042 /// Tell this materializer that it is responsible for materializing the given
1043 /// identity value at the earliest position it is needed, and that it
1044 /// shouldn't request the value via a block argument.
1045 void registerIdentityValue(IdentityValue *val) {
1046 identityValueRoot.insert(val);
1047 }
1048
1049 ArrayRef<Type> getBlockArgTypes() const { return blockArgTypes; }
1050
1051 void map(ElaboratorValue eval, Value val) { materializedValues[eval] = val; }
1052
1053 template <typename OpTy, typename... Args>
1054 OpTy create(Location location, Args &&...args) {
1055 return OpTy::create(builder, location, std::forward<Args>(args)...);
1056 }
1057
1058private:
1059 Value tryMaterializeAsConstant(ElaboratorValue val, Location loc) {
1060 if (auto attr = attrConverter.convert(val)) {
1061 Value res = ConstantOp::create(builder, loc, attr);
1062 materializedValues[val] = res;
1063 return res;
1064 }
1065
1066 return Value();
1067 }
1068
1069 SequenceOp elaborateSequence(const RandomizedSequenceStorage *seq,
1070 SmallVector<ElaboratorValue> &elabArgs);
1071
1072 void deleteOpsUntil(function_ref<bool(Block::iterator)> stop) {
1073 auto ip = builder.getInsertionPoint();
1074 while (ip != builder.getBlock()->end() && !stop(ip)) {
1075 LLVM_DEBUG(llvm::dbgs() << "Marking to be deleted: " << *ip << "\n\n");
1076 toDelete.push_back(&*ip);
1077
1078 builder.setInsertionPointAfter(&*ip);
1079 ip = builder.getInsertionPoint();
1080 }
1081 }
1082
1083 Value visit(TypedAttr val, Location loc,
1084 function_ref<InFlightDiagnostic()> emitError) {
1085 return {};
1086 }
1087
1088 Value visit(size_t val, Location loc,
1089 function_ref<InFlightDiagnostic()> emitError) {
1090 return {};
1091 }
1092
1093 Value visit(bool val, Location loc,
1094 function_ref<InFlightDiagnostic()> emitError) {
1095 return {};
1096 }
1097
1098 Value visit(ArrayStorage *val, Location loc,
1099 function_ref<InFlightDiagnostic()> emitError) {
1100 SmallVector<Value> elements;
1101 elements.reserve(val->array.size());
1102 for (auto el : val->array) {
1103 auto materialized = materialize(el, loc, emitError);
1104 if (!materialized)
1105 return Value();
1106
1107 elements.push_back(materialized);
1108 }
1109
1110 Value res = ArrayCreateOp::create(builder, loc, val->type, elements);
1111 materializedValues[val] = res;
1112 return res;
1113 }
1114
1115 Value visit(SetStorage *val, Location loc,
1116 function_ref<InFlightDiagnostic()> emitError) {
1117 SmallVector<Value> elements;
1118 elements.reserve(val->set.size());
1119 for (auto el : val->set) {
1120 auto materialized = materialize(el, loc, emitError);
1121 if (!materialized)
1122 return Value();
1123
1124 elements.push_back(materialized);
1125 }
1126
1127 auto res = SetCreateOp::create(builder, loc, val->type, elements);
1128 materializedValues[val] = res;
1129 return res;
1130 }
1131
1132 Value visit(BagStorage *val, Location loc,
1133 function_ref<InFlightDiagnostic()> emitError) {
1134 SmallVector<Value> values, weights;
1135 values.reserve(val->bag.size());
1136 weights.reserve(val->bag.size());
1137 for (auto [val, weight] : val->bag) {
1138 auto materializedVal = materialize(val, loc, emitError);
1139 auto materializedWeight = materialize(weight, loc, emitError);
1140 if (!materializedVal || !materializedWeight)
1141 return Value();
1142
1143 values.push_back(materializedVal);
1144 weights.push_back(materializedWeight);
1145 }
1146
1147 auto res = BagCreateOp::create(builder, loc, val->type, values, weights);
1148 materializedValues[val] = res;
1149 return res;
1150 }
1151
1152 Value visit(MemoryBlockStorage *val, Location loc,
1153 function_ref<InFlightDiagnostic()> emitError) {
1154 auto intType = builder.getIntegerType(val->baseAddress.getBitWidth());
1155 Value res = MemoryBlockDeclareOp::create(
1156 builder, val->loc, val->type,
1157 IntegerAttr::get(intType, val->baseAddress),
1158 IntegerAttr::get(intType, val->endAddress));
1159 materializedValues[val] = res;
1160 return res;
1161 }
1162
1163 Value visit(MemoryStorage *val, Location loc,
1164 function_ref<InFlightDiagnostic()> emitError) {
1165 auto memBlock = materialize(val->memoryBlock, val->loc, emitError);
1166 auto memSize = materialize(val->size, val->loc, emitError);
1167 auto memAlign = materialize(val->alignment, val->loc, emitError);
1168 if (!(memBlock && memSize && memAlign))
1169 return {};
1170
1171 Value res =
1172 MemoryAllocOp::create(builder, val->loc, memBlock, memSize, memAlign);
1173 materializedValues[val] = res;
1174 return res;
1175 }
1176
1177 Value visit(SequenceStorage *val, Location loc,
1178 function_ref<InFlightDiagnostic()> emitError) {
1179 emitError() << "materializing a non-randomized sequence not supported yet";
1180 return Value();
1181 }
1182
1183 Value visit(RandomizedSequenceStorage *val, Location loc,
1184 function_ref<InFlightDiagnostic()> emitError) {
1185 // To know which values we have to pass by argument (instead of eagerly
1186 // passing everything that might be used), we have to elaborate the sequence
1187 // family if that hasn't been done already.
1188 // We need to get back the sequence to reference, and the list of elaborated
1189 // values to pass as arguments.
1190 SmallVector<ElaboratorValue> elabArgs;
1191 // NOTE: we wouldn't need to elaborate the sequence if it doesn't contain
1192 // randomness to be elaborated.
1193 SequenceOp seqOp = elaborateSequence(val, elabArgs);
1194 if (!seqOp)
1195 return {};
1196
1197 // Materialize all the values we need to pass as arguments and collect their
1198 // types.
1199 SmallVector<Value> args;
1200 SmallVector<Type> argTypes;
1201 for (auto arg : elabArgs) {
1202 Value materialized = materialize(arg, val->loc, emitError);
1203 if (!materialized)
1204 return {};
1205
1206 args.push_back(materialized);
1207 argTypes.push_back(materialized.getType());
1208 }
1209
1210 Value res = GetSequenceOp::create(
1211 builder, val->loc, SequenceType::get(builder.getContext(), argTypes),
1212 seqOp.getSymName());
1213
1214 // Only materialize a substitute_sequence op when we have arguments to
1215 // substitute since this op does not support 0 arguments.
1216 if (!args.empty())
1217 res = SubstituteSequenceOp::create(builder, val->loc, res, args);
1218
1219 res = RandomizeSequenceOp::create(builder, val->loc, res);
1220
1221 materializedValues[val] = res;
1222 return res;
1223 }
1224
1225 Value visit(InterleavedSequenceStorage *val, Location loc,
1226 function_ref<InFlightDiagnostic()> emitError) {
1227 SmallVector<Value> sequences;
1228 for (auto seqVal : val->sequences) {
1229 Value materialized = materialize(seqVal, loc, emitError);
1230 if (!materialized)
1231 return {};
1232
1233 sequences.push_back(materialized);
1234 }
1235
1236 if (sequences.size() == 1)
1237 return sequences[0];
1238
1239 Value res =
1240 InterleaveSequencesOp::create(builder, loc, sequences, val->batchSize);
1241 materializedValues[val] = res;
1242 return res;
1243 }
1244
1245 Value visit(VirtualRegisterStorage *val, Location loc,
1246 function_ref<InFlightDiagnostic()> emitError) {
1247 Value res = VirtualRegisterOp::create(builder, val->loc, val->allowedRegs);
1248 materializedValues[val] = res;
1249 return res;
1250 }
1251
1252 Value visit(UniqueLabelStorage *val, Location loc,
1253 function_ref<InFlightDiagnostic()> emitError) {
1254 auto materialized = materialize(val->name, val->loc, emitError);
1255 if (!materialized)
1256 return {};
1257 Value res = LabelUniqueDeclOp::create(builder, val->loc, materialized);
1258 materializedValues[val] = res;
1259 return res;
1260 }
1261
1262 Value visit(TupleStorage *val, Location loc,
1263 function_ref<InFlightDiagnostic()> emitError) {
1264 SmallVector<Value> materialized;
1265 materialized.reserve(val->values.size());
1266 for (auto v : val->values)
1267 materialized.push_back(materialize(v, loc, emitError));
1268 Value res = TupleCreateOp::create(builder, loc, materialized);
1269 materializedValues[val] = res;
1270 return res;
1271 }
1272
1273 Value visit(SymbolicComputationWithIdentityValue *val, Location loc,
1274 function_ref<InFlightDiagnostic()> emitError) {
1275 auto *noConstStorage =
1276 const_cast<SymbolicComputationWithIdentityStorage *>(val->storage);
1277 auto res0 = materialize(noConstStorage, loc, emitError);
1278 if (!res0)
1279 return {};
1280
1281 auto *op = res0.getDefiningOp();
1282 auto res = op->getResults()[val->idx];
1283 materializedValues[val] = res;
1284 return res;
1285 }
1286
1287 Value visit(SymbolicComputationWithIdentityStorage *val, Location loc,
1288 function_ref<InFlightDiagnostic()> emitError) {
1289 SmallVector<Value> operands;
1290 for (auto operand : val->operands) {
1291 auto materialized = materialize(operand, val->loc, emitError);
1292 if (!materialized)
1293 return {};
1294
1295 operands.push_back(materialized);
1296 }
1297
1298 OperationState state(val->loc, val->name);
1299 state.addTypes(val->resultTypes);
1300 state.attributes = val->attributes;
1301 state.propertiesAttr = val->properties;
1302 state.addOperands(operands);
1303 auto *op = builder.create(state);
1304
1305 materializedValues[val] = op->getResult(0);
1306 return op->getResult(0);
1307 }
1308
1309 Value visit(SymbolicComputationStorage *val, Location loc,
1310 function_ref<InFlightDiagnostic()> emitError) {
1311 SmallVector<Value> operands;
1312 for (auto operand : val->operands) {
1313 auto materialized = materialize(operand, loc, emitError);
1314 if (!materialized)
1315 return {};
1316
1317 operands.push_back(materialized);
1318 }
1319
1320 OperationState state(loc, val->name);
1321 state.addTypes(val->resultTypes);
1322 state.attributes = val->attributes;
1323 state.propertiesAttr = val->properties;
1324 state.addOperands(operands);
1325 auto *op = builder.create(state);
1326
1327 for (auto res : op->getResults())
1328 materializedValues[val] = res;
1329
1330 return op->getResult(0);
1331 }
1332
1333private:
1334 /// Cache values we have already materialized to reuse them later. We start
1335 /// with an insertion point at the start of the block and cache the (updated)
1336 /// insertion point such that future materializations can also reuse previous
1337 /// materializations without running into dominance issues (or requiring
1338 /// additional checks to avoid them).
1339 DenseMap<ElaboratorValue, Value> materializedValues;
1340
1341 /// Cache the builder to continue insertions at their current insertion point
1342 /// for the reason stated above.
1343 OpBuilder builder;
1344
1345 SmallVector<Operation *> toDelete;
1346
1347 TestState &testState;
1348 SharedState &sharedState;
1349
1350 /// Keep track of the block arguments we had to add to this materializer's
1351 /// block for identity values and also remember which elaborator values are
1352 /// expected to be passed as arguments from outside.
1353 SmallVector<ElaboratorValue> &blockArgs;
1354 SmallVector<Type> blockArgTypes;
1355
1356 /// Identity values in this set are materialized by this materializer,
1357 /// otherwise they are added as block arguments and the block that wants to
1358 /// embed this sequence is expected to provide a value for it.
1359 DenseSet<IdentityValue *> identityValueRoot;
1360
1361 /// Helper to convert ElaboratorValues to attributes.
1362 ElaboratorValueToAttributeConverter attrConverter;
1363};
1364
1365//===----------------------------------------------------------------------===//
1366// Elaboration Visitor
1367//===----------------------------------------------------------------------===//
1368
1369/// Used to signal to the elaboration driver whether the operation should be
1370/// removed.
1371enum class DeletionKind { Keep, Delete };
1372
1373/// Interprets the IR to perform and lower the represented randomizations.
1374class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
1375public:
1376 using RTGBase = RTGOpVisitor<Elaborator, FailureOr<DeletionKind>>;
1377 using RTGBase::visitOp;
1378
1379 Elaborator(SharedState &sharedState, TestState &testState,
1380 Materializer &materializer,
1381 ContextResourceAttrInterface currentContext = {})
1382 : sharedState(sharedState), testState(testState),
1383 materializer(materializer), currentContext(currentContext),
1384 attrConverter(sharedState.internalizer),
1385 elabValConverter(sharedState.ctxt) {}
1386
1387 template <typename ValueTy>
1388 inline ValueTy get(Value val) const {
1389 return std::get<ValueTy>(state.at(val));
1390 }
1391
1392 /// Print a nice error message for operations we don't support yet.
1393 FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
1394 return visitOpGeneric(op);
1395 }
1396
1397 FailureOr<DeletionKind> visitExternalOp(Operation *op) {
1398 return visitOpGeneric(op);
1399 }
1400
1401 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
1402 SmallVector<ElaboratorValue> replacements;
1403 state[op.getResult()] =
1404 sharedState.internalizer.internalize<SequenceStorage>(
1405 op.getSequenceAttr().getAttr(), std::move(replacements));
1406 return DeletionKind::Delete;
1407 }
1408
1409 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
1410 if (isSymbolic(state.at(op.getSequence())))
1411 return visitOpGeneric(op);
1412
1413 auto *seq = get<SequenceStorage *>(op.getSequence());
1414
1415 SmallVector<ElaboratorValue> replacements(seq->args);
1416 for (auto replacement : op.getReplacements())
1417 replacements.push_back(state.at(replacement));
1418
1419 state[op.getResult()] =
1420 sharedState.internalizer.internalize<SequenceStorage>(
1421 seq->familyName, std::move(replacements));
1422
1423 return DeletionKind::Delete;
1424 }
1425
1426 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
1427 auto *seq = get<SequenceStorage *>(op.getSequence());
1428 auto *randomizedSeq =
1429 sharedState.internalizer.create<RandomizedSequenceStorage>(
1430 currentContext, seq, op.getLoc());
1431 materializer.registerIdentityValue(randomizedSeq);
1432 state[op.getResult()] =
1433 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1434 randomizedSeq);
1435 return DeletionKind::Delete;
1436 }
1437
1438 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
1439 SmallVector<ElaboratorValue> sequences;
1440 for (auto seq : op.getSequences())
1441 sequences.push_back(state.at(seq));
1442
1443 state[op.getResult()] =
1444 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1445 std::move(sequences), op.getBatchSize());
1446 return DeletionKind::Delete;
1447 }
1448
1449 // NOLINTNEXTLINE(misc-no-recursion)
1450 LogicalResult isValidContext(ElaboratorValue value, Operation *op) const {
1451 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
1452 auto *seq = std::get<RandomizedSequenceStorage *>(value);
1453 if (seq->context != currentContext) {
1454 auto err = op->emitError("attempting to place sequence derived from ")
1455 << seq->sequence->familyName.getValue() << " under context "
1456 << currentContext
1457 << ", but it was previously randomized for context ";
1458 if (seq->context)
1459 err << seq->context;
1460 else
1461 err << "'default'";
1462 return err;
1463 }
1464 return success();
1465 }
1466
1467 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
1468 for (auto val : interVal->sequences)
1469 if (failed(isValidContext(val, op)))
1470 return failure();
1471 return success();
1472 }
1473
1474 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
1475 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
1476 if (failed(isValidContext(seqVal, op)))
1477 return failure();
1478
1479 return DeletionKind::Keep;
1480 }
1481
1482 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1483 SetVector<ElaboratorValue> set;
1484 for (auto val : op.getElements())
1485 set.insert(state.at(val));
1486
1487 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1488 std::move(set), op.getSet().getType());
1489 return DeletionKind::Delete;
1490 }
1491
1492 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
1493 auto set = get<SetStorage *>(op.getSet())->set;
1494
1495 if (set.empty())
1496 return op->emitError("cannot select from an empty set");
1497
1498 size_t selected;
1499 if (auto intAttr =
1500 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1501 std::mt19937 customRng(intAttr.getInt());
1502 selected = getUniformlyInRange(customRng, 0, set.size() - 1);
1503 } else {
1504 selected = getUniformlyInRange(sharedState.rng, 0, set.size() - 1);
1505 }
1506
1507 state[op.getResult()] = set[selected];
1508 return DeletionKind::Delete;
1509 }
1510
1511 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
1512 auto original = get<SetStorage *>(op.getOriginal())->set;
1513 auto diff = get<SetStorage *>(op.getDiff())->set;
1514
1515 SetVector<ElaboratorValue> result(original);
1516 result.set_subtract(diff);
1517
1518 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1519 std::move(result), op.getResult().getType());
1520 return DeletionKind::Delete;
1521 }
1522
1523 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
1524 SetVector<ElaboratorValue> result;
1525 for (auto set : op.getSets())
1526 result.set_union(get<SetStorage *>(set)->set);
1527
1528 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1529 std::move(result), op.getType());
1530 return DeletionKind::Delete;
1531 }
1532
1533 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1534 auto size = get<SetStorage *>(op.getSet())->set.size();
1535 state[op.getResult()] = size;
1536 return DeletionKind::Delete;
1537 }
1538
1539 // {a0,a1} x {b0,b1} x {c0,c1} -> {(a0), (a1)} -> {(a0,b0), (a0,b1), (a1,b0),
1540 // (a1,b1)} -> {(a0,b0,c0), (a0,b0,c1), (a0,b1,c0), (a0,b1,c1), (a1,b0,c0),
1541 // (a1,b0,c1), (a1,b1,c0), (a1,b1,c1)}
1542 FailureOr<DeletionKind> visitOp(SetCartesianProductOp op) {
1543 SetVector<ElaboratorValue> result;
1544 SmallVector<SmallVector<ElaboratorValue>> tuples;
1545 tuples.push_back({});
1546
1547 for (auto input : op.getInputs()) {
1548 auto &set = get<SetStorage *>(input)->set;
1549 if (set.empty()) {
1550 SetVector<ElaboratorValue> empty;
1551 state[op.getResult()] =
1552 sharedState.internalizer.internalize<SetStorage>(std::move(empty),
1553 op.getType());
1554 return DeletionKind::Delete;
1555 }
1556
1557 for (unsigned i = 0, e = tuples.size(); i < e; ++i) {
1558 for (auto setEl : set.getArrayRef().drop_back()) {
1559 tuples.push_back(tuples[i]);
1560 tuples.back().push_back(setEl);
1561 }
1562 tuples[i].push_back(set.back());
1563 }
1564 }
1565
1566 for (auto &tup : tuples)
1567 result.insert(
1568 sharedState.internalizer.internalize<TupleStorage>(std::move(tup)));
1569
1570 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1571 std::move(result), op.getType());
1572 return DeletionKind::Delete;
1573 }
1574
1575 FailureOr<DeletionKind> visitOp(SetConvertToBagOp op) {
1576 auto set = get<SetStorage *>(op.getInput())->set;
1577 MapVector<ElaboratorValue, uint64_t> bag;
1578 for (auto val : set)
1579 bag.insert({val, 1});
1580 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1581 std::move(bag), op.getType());
1582 return DeletionKind::Delete;
1583 }
1584
1585 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1586 MapVector<ElaboratorValue, uint64_t> bag;
1587 for (auto [val, multiple] :
1588 llvm::zip(op.getElements(), op.getMultiples())) {
1589 // If the multiple is not stored as a concrete 'size_t', the elaboration
1590 // must have already failed earlier (since we don't have
1591 // unevaluated/opaque values at this point).
1592 bag[state.at(val)] += get<size_t>(multiple);
1593 }
1594
1595 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1596 std::move(bag), op.getType());
1597 return DeletionKind::Delete;
1598 }
1599
1600 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
1601 auto bag = get<BagStorage *>(op.getBag())->bag;
1602
1603 if (bag.empty())
1604 return op->emitError("cannot select from an empty bag");
1605
1606 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1607 prefixSum.reserve(bag.size());
1608 uint32_t accumulator = 0;
1609 for (auto [val, weight] : bag) {
1610 accumulator += weight;
1611 prefixSum.push_back({val, accumulator});
1612 }
1613
1614 auto customRng = sharedState.rng;
1615 if (auto intAttr =
1616 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1617 customRng = std::mt19937(intAttr.getInt());
1618 }
1619
1620 auto idx = getUniformlyInRange(customRng, 0, accumulator - 1);
1621 auto *iter = llvm::upper_bound(
1622 prefixSum, idx,
1623 [](uint32_t a, const std::pair<ElaboratorValue, uint32_t> &b) {
1624 return a < b.second;
1625 });
1626
1627 state[op.getResult()] = iter->first;
1628 return DeletionKind::Delete;
1629 }
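// Illustrative example of the weighted selection above (multiplicities are
// hypothetical): for a bag {a: 2, b: 1} the prefix sums are [(a, 2), (b, 3)]
// and an index is drawn uniformly from [0, 2]; llvm::upper_bound then maps
//   idx 0, 1 -> a   (probability 2/3)
//   idx 2    -> b   (probability 1/3)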
1630
1631 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1632 auto original = get<BagStorage *>(op.getOriginal())->bag;
1633 auto diff = get<BagStorage *>(op.getDiff())->bag;
1634
1635 MapVector<ElaboratorValue, uint64_t> result;
1636 for (const auto &el : original) {
1637 if (!diff.contains(el.first)) {
1638 result.insert(el);
1639 continue;
1640 }
1641
1642 if (op.getInf())
1643 continue;
1644
1645 auto toDiff = diff.lookup(el.first);
1646 if (el.second <= toDiff)
1647 continue;
1648
1649 result.insert({el.first, el.second - toDiff});
1650 }
1651
1652 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1653 std::move(result), op.getType());
1654 return DeletionKind::Delete;
1655 }
1656
1657 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1658 MapVector<ElaboratorValue, uint64_t> result;
1659 for (auto bag : op.getBags()) {
1660 auto val = get<BagStorage *>(bag)->bag;
1661 for (auto [el, multiple] : val)
1662 result[el] += multiple;
1663 }
1664
1665 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1666 std::move(result), op.getType());
1667 return DeletionKind::Delete;
1668 }
1669
1670 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1671 auto size = get<BagStorage *>(op.getBag())->bag.size();
1672 state[op.getResult()] = size;
1673 return DeletionKind::Delete;
1674 }
1675
1676 FailureOr<DeletionKind> visitOp(BagConvertToSetOp op) {
1677 auto bag = get<BagStorage *>(op.getInput())->bag;
1678 SetVector<ElaboratorValue> set;
1679 for (auto [k, v] : bag)
1680 set.insert(k);
1681 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1682 std::move(set), op.getType());
1683 return DeletionKind::Delete;
1684 }
1685
1686 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1687 auto *val = sharedState.internalizer.create<VirtualRegisterStorage>(
1688 op.getAllowedRegsAttr(), op.getType(), op.getLoc());
1689 state[op.getResult()] = val;
1690 materializer.registerIdentityValue(val);
1691 return DeletionKind::Delete;
1692 }
1693
1694 FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
1695 SmallVector<ElaboratorValue> array;
1696 array.reserve(op.getElements().size());
1697 for (auto val : op.getElements())
1698 array.emplace_back(state.at(val));
1699
1700 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1701 op.getResult().getType(), std::move(array));
1702 return DeletionKind::Delete;
1703 }
1704
1705 FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
1706 auto array = get<ArrayStorage *>(op.getArray())->array;
1707 size_t idx = get<size_t>(op.getIndex());
1708
1709 if (array.size() <= idx)
1710 return op->emitError("invalid to access index ")
1711 << idx << " of an array with " << array.size() << " elements";
1712
1713 state[op.getResult()] = array[idx];
1714 return DeletionKind::Delete;
1715 }
1716
1717 FailureOr<DeletionKind> visitOp(ArrayInjectOp op) {
1718 auto arrayOpaque = state.at(op.getArray());
1719 auto idxOpaque = state.at(op.getIndex());
1720 if (isSymbolic(arrayOpaque) || isSymbolic(idxOpaque))
1721 return visitOpGeneric(op);
1722
1723 auto array = std::get<ArrayStorage *>(arrayOpaque)->array;
1724 size_t idx = std::get<size_t>(idxOpaque);
1725
1726 if (array.size() <= idx)
1727 return op->emitError("invalid to access index ")
1728 << idx << " of an array with " << array.size() << " elements";
1729
1730 array[idx] = state.at(op.getValue());
1731 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1732 op.getResult().getType(), std::move(array));
1733 return DeletionKind::Delete;
1734 }
1735
1736 FailureOr<DeletionKind> visitOp(ArrayAppendOp op) {
1737 auto array = std::get<ArrayStorage *>(state.at(op.getArray()))->array;
1738 array.push_back(state.at(op.getElement()));
1739 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1740 op.getResult().getType(), std::move(array));
1741 return DeletionKind::Delete;
1742 }
1743
1744 FailureOr<DeletionKind> visitOp(ArraySizeOp op) {
1745 auto array = get<ArrayStorage *>(op.getArray())->array;
1746 state[op.getResult()] = array.size();
1747 return DeletionKind::Delete;
1748 }
1749
1750 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1751 auto *val = sharedState.internalizer.create<UniqueLabelStorage>(
1752 state.at(op.getNamePrefix()), op.getLoc());
1753 state[op.getLabel()] = val;
1754 materializer.registerIdentityValue(val);
1755 return DeletionKind::Delete;
1756 }
1757
1758 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1759 size_t lower = get<size_t>(op.getLowerBound());
1760 size_t upper = get<size_t>(op.getUpperBound());
1761 if (lower > upper)
1762 return op->emitError("cannot select a number from an empty range");
1763
1764 if (auto intAttr =
1765 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1766 std::mt19937 customRng(intAttr.getInt());
1767 state[op.getResult()] =
1768 size_t(getUniformlyInRange(customRng, lower, upper));
1769 } else {
1770 state[op.getResult()] =
1771 size_t(getUniformlyInRange(sharedState.rng, lower, upper));
1772 }
1773
1774 return DeletionKind::Delete;
1775 }
1776
1777 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1778 size_t input = get<size_t>(op.getInput());
1779 auto width = op.getType().getWidth();
1780 auto emitError = [&]() { return op->emitError(); };
1781 if (input > APInt::getAllOnes(width).getZExtValue())
1782 return emitError() << "cannot represent " << input << " with " << width
1783 << " bits";
1784
1785 state[op.getResult()] =
1786 ImmediateAttr::get(op.getContext(), APInt(width, input));
1787 return DeletionKind::Delete;
1788 }
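// Quick illustration of the range check above (widths and values are
// hypothetical): for a 4-bit immediate type, APInt::getAllOnes(4) is 15, so an
// input of 20 triggers "cannot represent 20 with 4 bits", while an input of 15
// becomes an ImmediateAttr wrapping APInt(4, 15).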
1789
1790 FailureOr<DeletionKind> visitOp(OnContextOp op) {
1791 ContextResourceAttrInterface from = currentContext,
1792 to = cast<ContextResourceAttrInterface>(
1793 get<TypedAttr>(op.getContext()));
1794 if (!currentContext)
1795 from = DefaultContextAttr::get(op->getContext(), to.getType());
1796
1797 auto emitError = [&]() {
1798 auto diag = op.emitError();
1799 diag.attachNote(op.getLoc())
1800 << "while materializing value for context switching for " << op;
1801 return diag;
1802 };
1803
1804 if (from == to) {
1805 Value seqVal = materializer.materialize(
1806 get<SequenceStorage *>(op.getSequence()), op.getLoc(), emitError);
1807 if (!seqVal)
1808 return failure();
1809
1810 Value randSeqVal =
1811 materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1812 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1813 return DeletionKind::Delete;
1814 }
1815
1816 // Switch to the desired context.
1817 // First, check if a context switch is registered that has the concrete
1818 // context as source and target.
1819 auto *iter = testState.contextSwitches.find({from, to});
1820
1821 // Try with 'any' context as target and the concrete context as source.
1822 if (iter == testState.contextSwitches.end())
1823 iter = testState.contextSwitches.find(
1824 {from, AnyContextAttr::get(op->getContext(), to.getType())});
1825
1826 // Try with 'any' context as source and the concrete context as target.
1827 if (iter == testState.contextSwitches.end())
1828 iter = testState.contextSwitches.find(
1829 {AnyContextAttr::get(op->getContext(), from.getType()), to});
1830
1831 // Try with 'any' context for both the source and the target.
1832 if (iter == testState.contextSwitches.end())
1833 iter = testState.contextSwitches.find(
1834 {AnyContextAttr::get(op->getContext(), from.getType()),
1835 AnyContextAttr::get(op->getContext(), to.getType())});
1836
1837 // Otherwise, fail with an error because we couldn't find a user
1838 // specification on how to switch between the requested contexts.
1839 // NOTE: we could think about supporting context switching via an
1840 // intermediate context, i.e., treating the relation as transitive.
1841 if (iter == testState.contextSwitches.end())
1842 return op->emitError("no context transition registered to switch from ")
1843 << from << " to " << to;
1844
1845 auto familyName = iter->second->familyName;
1846 SmallVector<ElaboratorValue> args{from, to,
1847 get<SequenceStorage *>(op.getSequence())};
1848 auto *seq = sharedState.internalizer.internalize<SequenceStorage>(
1849 familyName, std::move(args));
1850 auto *randSeq = sharedState.internalizer.create<RandomizedSequenceStorage>(
1851 to, seq, op.getLoc());
1852 materializer.registerIdentityValue(randSeq);
1853 Value seqVal = materializer.materialize(randSeq, op.getLoc(), emitError);
1854 if (!seqVal)
1855 return failure();
1856
1857 materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
1858 return DeletionKind::Delete;
1859 }
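// Sketch of the context-switch lookup order above, with hypothetical contexts
// #cpu0 and #cpu1: switching from #cpu0 to #cpu1 probes the registered
// switches in the order
//   (#cpu0, #cpu1) -> (#cpu0, any) -> (any, #cpu1) -> (any, any)
// and the first hit provides the sequence family that wraps the embedded
// sequence; if none is registered, the error above is emitted.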
1860
1861 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1862 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1863 get<SequenceStorage *>(op.getSequence());
1864 return DeletionKind::Delete;
1865 }
1866
1867 FailureOr<DeletionKind> visitOp(MemoryBlockDeclareOp op) {
1868 auto *val = sharedState.internalizer.create<MemoryBlockStorage>(
1869 op.getBaseAddress(), op.getEndAddress(), op.getType(), op.getLoc());
1870 state[op.getResult()] = val;
1871 materializer.registerIdentityValue(val);
1872 return DeletionKind::Delete;
1873 }
1874
1875 FailureOr<DeletionKind> visitOp(MemoryAllocOp op) {
1876 size_t size = get<size_t>(op.getSize());
1877 size_t alignment = get<size_t>(op.getAlignment());
1878 auto *memBlock = get<MemoryBlockStorage *>(op.getMemoryBlock());
1879 auto *val = sharedState.internalizer.create<MemoryStorage>(
1880 memBlock, size, alignment, op.getLoc());
1881 state[op.getResult()] = val;
1882 materializer.registerIdentityValue(val);
1883 return DeletionKind::Delete;
1884 }
1885
1886 FailureOr<DeletionKind> visitOp(MemorySizeOp op) {
1887 auto *memory = get<MemoryStorage *>(op.getMemory());
1888 state[op.getResult()] = memory->size;
1889 return DeletionKind::Delete;
1890 }
1891
1892 FailureOr<DeletionKind> visitOp(TupleCreateOp op) {
1893 SmallVector<ElaboratorValue> values;
1894 values.reserve(op.getElements().size());
1895 for (auto el : op.getElements())
1896 values.push_back(state.at(el));
1897
1898 state[op.getResult()] =
1899 sharedState.internalizer.internalize<TupleStorage>(std::move(values));
1900 return DeletionKind::Delete;
1901 }
1902
1903 FailureOr<DeletionKind> visitOp(TupleExtractOp op) {
1904 auto *tuple = get<TupleStorage *>(op.getTuple());
1905 state[op.getResult()] = tuple->values[op.getIndex().getZExtValue()];
1906 return DeletionKind::Delete;
1907 }
1908
1909 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1910 bool cond = get<bool>(op.getCondition());
1911 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1912 if (toElaborate.empty())
1913 return DeletionKind::Delete;
1914
1915 // Just reuse this elaborator for the nested region because we need access
1916 // to the elaborated values outside the nested region (since it is not
1917 // isolated from above) and we want to materialize the region inline, so we
1918 // don't need a new materializer instance.
1919 SmallVector<ElaboratorValue> yieldedVals;
1920 if (failed(elaborate(toElaborate, {}, yieldedVals)))
1921 return failure();
1922
1923 // Map the results of the 'scf.if' to the yielded values.
1924 for (auto [res, out] : llvm::zip(op.getResults(), yieldedVals))
1925 state[res] = out;
1926
1927 return DeletionKind::Delete;
1928 }
1929
1930 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1931 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1932 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1933 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1934 return op->emitOpError("can only elaborate index type iterator");
1935
1936 auto lowerBound = get<size_t>(op.getLowerBound());
1937 auto step = get<size_t>(op.getStep());
1938 auto upperBound = get<size_t>(op.getUpperBound());
1939
1940 // Prepare for the first iteration by assigning the nested region's block
1941 // arguments. We can just reuse this elaborator because we need access to
1942 // values elaborated in the parent region anyway and materialize everything
1943 // inline (i.e., we don't need a new materializer).
1944 state[op.getInductionVar()] = lowerBound;
1945 for (auto [iterArg, initArg] :
1946 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1947 state[iterArg] = state.at(initArg);
1948
1949 // This loop performs the actual 'scf.for' loop iterations.
1950 SmallVector<ElaboratorValue> yieldedVals;
1951 for (size_t i = lowerBound; i < upperBound; i += step) {
1952 yieldedVals.clear();
1953 if (failed(elaborate(op.getBodyRegion(), {}, yieldedVals)))
1954 return failure();
1955
1956 // Prepare for the next iteration by updating the mapping of the nested
1957 // region's block arguments.
1958 state[op.getInductionVar()] = i + step;
1959 for (auto [iterArg, prevIterArg] :
1960 llvm::zip(op.getRegionIterArgs(), yieldedVals))
1961 state[iterArg] = prevIterArg;
1962 }
1963
1964 // Transfer the previously yielded values to the for loop result values.
1965 for (auto [res, iterArg] :
1966 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1967 state[res] = state.at(iterArg);
1968
1969 return DeletionKind::Delete;
1970 }
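// Illustrative unrolling of the loop above (bounds are hypothetical): for
// lower bound 0, upper bound 6, and step 2, the body is elaborated three times
// with the induction variable bound to 0, 2, and 4; the values yielded by each
// iteration become the next iteration's iter_args, and the values yielded by
// the last iteration become the op's results.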
1971
1972 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1973 return DeletionKind::Delete;
1974 }
1975
1976 FailureOr<DeletionKind> visitOp(arith::AddIOp op) {
1977 if (!isa<IndexType>(op.getType()))
1978 return visitOpGeneric(op);
1979
1980 size_t lhs = get<size_t>(op.getLhs());
1981 size_t rhs = get<size_t>(op.getRhs());
1982 state[op.getResult()] = lhs + rhs;
1983 return DeletionKind::Delete;
1984 }
1985
1986 FailureOr<DeletionKind> visitOp(arith::AndIOp op) {
1987 if (!op.getType().isSignlessInteger(1))
1988 return visitOpGeneric(op);
1989
1990 bool lhs = get<bool>(op.getLhs());
1991 bool rhs = get<bool>(op.getRhs());
1992 state[op.getResult()] = lhs && rhs;
1993 return DeletionKind::Delete;
1994 }
1995
1996 FailureOr<DeletionKind> visitOp(arith::XOrIOp op) {
1997 if (!op.getType().isSignlessInteger(1))
1998 return visitOpGeneric(op);
1999
2000 bool lhs = get<bool>(op.getLhs());
2001 bool rhs = get<bool>(op.getRhs());
2002 state[op.getResult()] = lhs != rhs;
2003 return DeletionKind::Delete;
2004 }
2005
2006 FailureOr<DeletionKind> visitOp(arith::OrIOp op) {
2007 if (!op.getType().isSignlessInteger(1))
2008 return visitOpGeneric(op);
2009
2010 bool lhs = get<bool>(op.getLhs());
2011 bool rhs = get<bool>(op.getRhs());
2012 state[op.getResult()] = lhs || rhs;
2013 return DeletionKind::Delete;
2014 }
2015
2016 FailureOr<DeletionKind> visitOp(arith::SelectOp op) {
2017 auto condOpaque = state.at(op.getCondition());
2018 if (isSymbolic(condOpaque))
2019 return visitOpGeneric(op);
2020
2021 bool cond = std::get<bool>(condOpaque);
2022 auto trueVal = state.at(op.getTrueValue());
2023 auto falseVal = state.at(op.getFalseValue());
2024 state[op.getResult()] = cond ? trueVal : falseVal;
2025 return DeletionKind::Delete;
2026 }
2027
2028 FailureOr<DeletionKind> visitOp(index::AddOp op) {
2029 size_t lhs = get<size_t>(op.getLhs());
2030 size_t rhs = get<size_t>(op.getRhs());
2031 state[op.getResult()] = lhs + rhs;
2032 return DeletionKind::Delete;
2033 }
2034
2035 FailureOr<DeletionKind> visitOp(index::SubOp op) {
2036 size_t lhs = get<size_t>(op.getLhs());
2037 size_t rhs = get<size_t>(op.getRhs());
2038 state[op.getResult()] = lhs - rhs;
2039 return DeletionKind::Delete;
2040 }
2041
2042 FailureOr<DeletionKind> visitOp(index::MulOp op) {
2043 size_t lhs = get<size_t>(op.getLhs());
2044 size_t rhs = get<size_t>(op.getRhs());
2045 state[op.getResult()] = lhs * rhs;
2046 return DeletionKind::Delete;
2047 }
2048
2049 FailureOr<DeletionKind> visitOp(index::DivUOp op) {
2050 size_t lhs = get<size_t>(op.getLhs());
2051 size_t rhs = get<size_t>(op.getRhs());
2052
2053 if (rhs == 0)
2054 return op->emitOpError("attempted division by zero");
2055
2056 state[op.getResult()] = lhs / rhs;
2057 return DeletionKind::Delete;
2058 }
2059
2060 FailureOr<DeletionKind> visitOp(index::CeilDivUOp op) {
2061 size_t lhs = get<size_t>(op.getLhs());
2062 size_t rhs = get<size_t>(op.getRhs());
2063
2064 if (rhs == 0)
2065 return op->emitOpError("attempted division by zero");
2066
2067 if (lhs == 0)
2068 state[op.getResult()] = (lhs + rhs - 1) / rhs;
2069 else
2070 state[op.getResult()] = 1 + ((lhs - 1) / rhs);
2071
2072 return DeletionKind::Delete;
2073 }
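// Sanity check of the ceiling division above with illustrative values:
// ceil(7 / 2) = 1 + ((7 - 1) / 2) = 4 and ceil(8 / 2) = 1 + ((8 - 1) / 2) = 4;
// the lhs == 0 special case evaluates to 0 and avoids the unsigned wrap-around
// of lhs - 1.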
2074
2075 FailureOr<DeletionKind> visitOp(index::RemUOp op) {
2076 size_t lhs = get<size_t>(op.getLhs());
2077 size_t rhs = get<size_t>(op.getRhs());
2078
2079 if (rhs == 0)
2080 return op->emitOpError("attempted division by zero");
2081
2082 state[op.getResult()] = lhs % rhs;
2083 return DeletionKind::Delete;
2084 }
2085
2086 FailureOr<DeletionKind> visitOp(index::AndOp op) {
2087 size_t lhs = get<size_t>(op.getLhs());
2088 size_t rhs = get<size_t>(op.getRhs());
2089 state[op.getResult()] = lhs & rhs;
2090 return DeletionKind::Delete;
2091 }
2092
2093 FailureOr<DeletionKind> visitOp(index::OrOp op) {
2094 size_t lhs = get<size_t>(op.getLhs());
2095 size_t rhs = get<size_t>(op.getRhs());
2096 state[op.getResult()] = lhs | rhs;
2097 return DeletionKind::Delete;
2098 }
2099
2100 FailureOr<DeletionKind> visitOp(index::XOrOp op) {
2101 size_t lhs = get<size_t>(op.getLhs());
2102 size_t rhs = get<size_t>(op.getRhs());
2103 state[op.getResult()] = lhs ^ rhs;
2104 return DeletionKind::Delete;
2105 }
2106
2107 FailureOr<DeletionKind> visitOp(index::ShlOp op) {
2108 size_t lhs = get<size_t>(op.getLhs());
2109 size_t rhs = get<size_t>(op.getRhs());
2110 state[op.getResult()] = lhs << rhs;
2111 return DeletionKind::Delete;
2112 }
2113
2114 FailureOr<DeletionKind> visitOp(index::ShrUOp op) {
2115 size_t lhs = get<size_t>(op.getLhs());
2116 size_t rhs = get<size_t>(op.getRhs());
2117 state[op.getResult()] = lhs >> rhs;
2118 return DeletionKind::Delete;
2119 }
2120
2121 FailureOr<DeletionKind> visitOp(index::MaxUOp op) {
2122 size_t lhs = get<size_t>(op.getLhs());
2123 size_t rhs = get<size_t>(op.getRhs());
2124 state[op.getResult()] = std::max(lhs, rhs);
2125 return DeletionKind::Delete;
2126 }
2127
2128 FailureOr<DeletionKind> visitOp(index::MinUOp op) {
2129 size_t lhs = get<size_t>(op.getLhs());
2130 size_t rhs = get<size_t>(op.getRhs());
2131 state[op.getResult()] = std::min(lhs, rhs);
2132 return DeletionKind::Delete;
2133 }
2134
2135 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
2136 size_t lhs = get<size_t>(op.getLhs());
2137 size_t rhs = get<size_t>(op.getRhs());
2138 bool result;
2139 switch (op.getPred()) {
2140 case index::IndexCmpPredicate::EQ:
2141 result = lhs == rhs;
2142 break;
2143 case index::IndexCmpPredicate::NE:
2144 result = lhs != rhs;
2145 break;
2146 case index::IndexCmpPredicate::ULT:
2147 result = lhs < rhs;
2148 break;
2149 case index::IndexCmpPredicate::ULE:
2150 result = lhs <= rhs;
2151 break;
2152 case index::IndexCmpPredicate::UGT:
2153 result = lhs > rhs;
2154 break;
2155 case index::IndexCmpPredicate::UGE:
2156 result = lhs >= rhs;
2157 break;
2158 default:
2159 return op->emitOpError("elaboration not supported");
2160 }
2161 state[op.getResult()] = result;
2162 return DeletionKind::Delete;
2163 }
2164
2165 bool isSymbolic(ElaboratorValue val) {
2166 return std::holds_alternative<SymbolicComputationWithIdentityValue *>(
2167 val) ||
2168 std::holds_alternative<SymbolicComputationWithIdentityStorage *>(
2169 val) ||
2170 std::holds_alternative<SymbolicComputationStorage *>(val);
2171 }
2172
2173 bool isSymbolic(Operation *op) {
2174 return llvm::any_of(op->getOperands(), [&](auto operand) {
2175 auto val = state.at(operand);
2176 return isSymbolic(val);
2177 });
2178 }
2179
2180 /// If all operands are concrete, try to fold the operation and register the
2181 /// folded results for its result values in the elaborator state.
2182 /// Returns 'false' if the operation has to be handled symbolically.
2183 bool attemptConcreteCase(Operation *op) {
2184 if (op->getNumResults() == 0)
2185 return false;
2186
2187 SmallVector<Attribute> operands;
2188 for (auto operand : op->getOperands()) {
2189 auto evalValue = state[operand];
2190 auto attr = elabValConverter.convert(evalValue);
2191 operands.push_back(attr);
2192 }
2193
2194 SmallVector<OpFoldResult> results;
2195 if (failed(op->fold(operands, results)))
2196 return false;
2197
2198 if (results.size() != op->getNumResults())
2199 return false;
2200
2201 for (auto [res, val] : llvm::zip(results, op->getResults())) {
2202 auto attr = llvm::dyn_cast_or_null<TypedAttr>(res.dyn_cast<Attribute>());
2203 if (!attr)
2204 return false;
2205
2206 if (attr.getType() != val.getType())
2207 return false;
2208
2209 // Try to convert the attribute to an ElaboratorValue
2210 auto converted = attrConverter.convert(attr);
2211 if (succeeded(converted)) {
2212 state[val] = *converted;
2213 continue;
2214 }
2215 return false;
2216 }
2217
2218 return true;
2219 }
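// Hedged illustration of the concrete path above: if both operands of, say, an
// i32 arith.addi map to concrete elaborated values, they are converted to
// attributes, Operation::fold produces an IntegerAttr result, and that
// attribute is converted back into an ElaboratorValue for the op's result; any
// failure along the way falls back to the symbolic handling below.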
2220
2221 FailureOr<DeletionKind> visitOpGeneric(Operation *op) {
2222 if (op->getNumResults() == 0)
2223 return DeletionKind::Keep;
2224
2225 if (attemptConcreteCase(op))
2226 return DeletionKind::Delete;
2227
2228 if (mlir::isMemoryEffectFree(op)) {
2229 if (op->getNumResults() != 1)
2230 return op->emitOpError(
2231 "symbolic elaboration of memory-effect-free operations with "
2232 "multiple results not supported");
2233
2234 state[op->getResult(0)] =
2235 sharedState.internalizer.internalize<SymbolicComputationStorage>(
2236 state, op);
2237 return DeletionKind::Delete;
2238 }
2239
2240 // We assume that reordering operations with only allocate effects is
2241 // allowed.
2242 // FIXME: this is not how the MLIR MemoryEffects interface intends it.
2243 // We should create our own interface/trait for that use-case.
2244 // Alternatively, the elaboration pass could keep track of the ordering of
2245 // such operations and, before materializing an alloc-like operation,
2246 // materialize every not-yet-materialized operation that has to happen
2247 // before it.
2248 bool onlyAlloc = mlir::hasSingleEffect<mlir::MemoryEffects::Allocate>(op);
2249 onlyAlloc |= isa<ValidateOp>(op);
2250
2251 auto *validationVal =
2252 sharedState.internalizer.create<SymbolicComputationWithIdentityStorage>(
2253 state, op);
2254 materializer.registerIdentityValue(validationVal);
2255 state[op->getResult(0)] = validationVal;
2256
2257 for (auto [i, res] : llvm::enumerate(op->getResults())) {
2258 if (i == 0)
2259 continue;
2260 auto *val =
2261 sharedState.internalizer.create<SymbolicComputationWithIdentityValue>(
2262 res.getType(), validationVal, i);
2263 state[res] = val;
2264 materializer.registerIdentityValue(val);
2265 }
2266 return onlyAlloc ? DeletionKind::Delete : DeletionKind::Keep;
2267 }
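// In short (a rough summary of the branches above): unfoldable but
// memory-effect-free single-result ops become interned symbolic computations
// keyed on their elaborated operands; ops with identity get a
// SymbolicComputationWithIdentityStorage for result 0 plus one
// SymbolicComputationWithIdentityValue per additional result, and only
// alloc-like ops and ValidateOp are deleted here, while other side-effecting
// ops are kept in place.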
2268
2269 bool supportsSymbolicValuesNonGenerically(Operation *op) {
2270 return isa<SubstituteSequenceOp, ArrayCreateOp, ArrayInjectOp,
2271 TupleCreateOp, arith::SelectOp>(op);
2272 }
2273
2274 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
2275 if (isSymbolic(op) && !supportsSymbolicValuesNonGenerically(op))
2276 return visitOpGeneric(op);
2277
2278 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
2279 .Case<
2280 // Arith ops
2281 arith::AddIOp, arith::XOrIOp, arith::AndIOp, arith::OrIOp,
2282 arith::SelectOp,
2283 // Index ops
2284 index::AddOp, index::SubOp, index::MulOp, index::DivUOp,
2285 index::CeilDivUOp, index::RemUOp, index::AndOp, index::OrOp,
2286 index::XOrOp, index::ShlOp, index::ShrUOp, index::MaxUOp,
2287 index::MinUOp, index::CmpOp,
2288 // SCF ops
2289 scf::IfOp, scf::ForOp, scf::YieldOp>(
2290 [&](auto op) { return visitOp(op); })
2291 .Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); });
2292 }
2293
2294 // NOLINTNEXTLINE(misc-no-recursion)
2295 LogicalResult elaborate(Region &region,
2296 ArrayRef<ElaboratorValue> regionArguments,
2297 SmallVector<ElaboratorValue> &terminatorOperands) {
2298 if (region.getBlocks().size() > 1)
2299 return region.getParentOp()->emitOpError(
2300 "regions with more than one block are not supported");
2301
2302 for (auto [arg, elabArg] :
2303 llvm::zip(region.getArguments(), regionArguments))
2304 state[arg] = elabArg;
2305
2306 Block *block = &region.front();
2307 for (auto &op : *block) {
2308 auto result = dispatchOpVisitor(&op);
2309 if (failed(result))
2310 return failure();
2311
2312 if (*result == DeletionKind::Keep)
2313 if (failed(materializer.materialize(&op, state)))
2314 return failure();
2315
2316 LLVM_DEBUG({
2317 llvm::dbgs() << "Elaborated " << op << " to\n[";
2318
2319 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) {
2320 if (state.contains(res))
2321 llvm::dbgs() << state.at(res);
2322 else
2323 llvm::dbgs() << "unknown";
2324 });
2325
2326 llvm::dbgs() << "]\n\n";
2327 });
2328 }
2329
2330 if (region.front().mightHaveTerminator())
2331 for (auto val : region.front().getTerminator()->getOperands())
2332 terminatorOperands.push_back(state.at(val));
2333
2334 return success();
2335 }
2336
2337private:
2338 // State to be shared between all elaborator instances.
2339 SharedState &sharedState;
2340
2341 // State specific to a single RTG test and the sequences placed within it.
2342 TestState &testState;
2343
2344 // Allows us to materialize ElaboratorValues into the IR operations necessary
2345 // to obtain an SSA value representing that elaborated value.
2346 Materializer &materializer;
2347
2348 // A map from SSA values to their (possibly interned) elaborated values.
2349 DenseMap<Value, ElaboratorValue> state;
2350
2351 // The current context we are elaborating under.
2352 ContextResourceAttrInterface currentContext;
2353
2354 // Allows us to convert attributes to ElaboratorValues.
2355 AttributeToElaboratorValueConverter attrConverter;
2356
2357 // Allows us to convert ElaboratorValues to attributes.
2358 ElaboratorValueToAttributeConverter elabValConverter;
2359};
2360} // namespace
2361
2362SequenceOp
2363Materializer::elaborateSequence(const RandomizedSequenceStorage *seq,
2364 SmallVector<ElaboratorValue> &elabArgs) {
2365 auto familyOp =
2366 sharedState.table.lookup<SequenceOp>(seq->sequence->familyName);
2367 // TODO: don't clone if this is the only remaining reference to this
2368 // sequence
2369 OpBuilder builder(familyOp);
2370 auto seqOp = builder.cloneWithoutRegions(familyOp);
2371 auto name = sharedState.names.newName(seq->sequence->familyName.getValue());
2372 seqOp.setSymName(name);
2373 seqOp.getBodyRegion().emplaceBlock();
2374 sharedState.table.insert(seqOp);
2375 assert(seqOp.getSymName() == name && "should not have been renamed");
2376
2377 LLVM_DEBUG(llvm::dbgs() << "\n=== Elaborating sequence family @"
2378 << familyOp.getSymName() << " into @"
2379 << seqOp.getSymName() << " under context "
2380 << seq->context << "\n\n");
2381
2382 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()), testState,
2383 sharedState, elabArgs);
2384 Elaborator elaborator(sharedState, testState, materializer, seq->context);
2385 SmallVector<ElaboratorValue> yieldedVals;
2386 if (failed(elaborator.elaborate(familyOp.getBodyRegion(), seq->sequence->args,
2387 yieldedVals)))
2388 return {};
2389
2390 seqOp.setSequenceType(
2391 SequenceType::get(builder.getContext(), materializer.getBlockArgTypes()));
2392 materializer.finalize();
2393
2394 return seqOp;
2395}
2396
2397//===----------------------------------------------------------------------===//
2398// Elaborator Pass
2399//===----------------------------------------------------------------------===//
2400
2401namespace {
2402struct ElaborationPass
2403 : public rtg::impl::ElaborationPassBase<ElaborationPass> {
2404 using Base::Base;
2405
2406 void runOnOperation() override;
2407 void matchTestsAgainstTargets(SymbolTable &table);
2408 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
2409};
2410} // namespace
2411
2412void ElaborationPass::runOnOperation() {
2413 auto moduleOp = getOperation();
2414 SymbolTable table(moduleOp);
2415
2416 matchTestsAgainstTargets(table);
2417
2418 if (failed(elaborateModule(moduleOp, table)))
2419 return signalPassFailure();
2420}
2421
2422void ElaborationPass::matchTestsAgainstTargets(SymbolTable &table) {
2423 auto moduleOp = getOperation();
2424
2425 for (auto test : llvm::make_early_inc_range(moduleOp.getOps<TestOp>())) {
2426 if (test.getTargetAttr())
2427 continue;
2428
2429 bool matched = false;
2430
2431 for (auto target : moduleOp.getOps<TargetOp>()) {
2432 // Check if the target's type is a subtype of the test's target type.
2433 // This means that for each entry in the test's target type, there must be
2434 // a corresponding entry with the same name and type in the target's type.
2435 bool isSubtype = true;
2436 auto testEntries = test.getTargetType().getEntries();
2437 auto targetEntries = target.getTarget().getEntries();
2438
2439 // Since the entries are sorted by name, the subtype check can be done in a
2440 // single pass over both entry lists.
2441 size_t targetIdx = 0;
2442 for (auto testEntry : testEntries) {
2443 // Find the matching entry in target entries.
2444 while (targetIdx < targetEntries.size() &&
2445 targetEntries[targetIdx].name.getValue() <
2446 testEntry.name.getValue())
2447 targetIdx++;
2448
2449 // Check if we found a matching entry with the same name and type
2450 if (targetIdx >= targetEntries.size() ||
2451 targetEntries[targetIdx].name != testEntry.name ||
2452 targetEntries[targetIdx].type != testEntry.type) {
2453 isSubtype = false;
2454 break;
2455 }
2456 }
2457
2458 if (!isSubtype)
2459 continue;
2460
2461 IRRewriter rewriter(test);
2462 // Create a new test for the matched target
2463 auto newTest = cast<TestOp>(test->clone());
2464 newTest.setSymName(test.getSymName().str() + "_" +
2465 target.getSymName().str());
2466
2467 // Set the target symbol specifying that this test is only suitable for
2468 // that target.
2469 newTest.setTargetAttr(target.getSymNameAttr());
2470
2471 table.insert(newTest, rewriter.getInsertionPoint());
2472 matched = true;
2473 }
2474
2475 if (matched || deleteUnmatchedTests)
2476 test->erase();
2477 }
2478}
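// Illustrative example of the matching above (test and target names are
// hypothetical): a test @test0 requiring entries {cpu, mem} is cloned once for
// every target whose name-sorted entries contain matching {cpu, mem} entries
// of the same types; a target @soc providing {cpu, mem, uart} thus yields a
// clone named "test0_soc" whose target attribute is set to @soc. The original
// template @test0 is erased once it has been matched (unmatched tests are
// erased only when 'deleteUnmatchedTests' is set).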
2479
2480static bool onlyLegalToMaterializeInTarget(Type type) {
2481 return isa<MemoryBlockType, ContextResourceTypeInterface>(type);
2482}
2483
2484LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
2485 SymbolTable &table) {
2486 SharedState state(moduleOp.getContext(), table, seed);
2487
2488 // Update the name cache
2489 state.names.add(moduleOp);
2490
2491 struct TargetElabResult {
2492 DictType targetType;
2493 SmallVector<ElaboratorValue> yields;
2494 TestState testState;
2495 };
2496
2497 // Map to store elaborated targets
2498 DenseMap<StringAttr, TargetElabResult> targetMap;
2499 for (auto targetOp : moduleOp.getOps<TargetOp>()) {
2500 LLVM_DEBUG(llvm::dbgs() << "=== Elaborating target @"
2501 << targetOp.getSymName() << "\n\n");
2502
2503 auto &result = targetMap[targetOp.getSymNameAttr()];
2504 result.targetType = targetOp.getTarget();
2505
2506 SmallVector<ElaboratorValue> blockArgs;
2507 Materializer targetMaterializer(OpBuilder::atBlockBegin(targetOp.getBody()),
2508 result.testState, state, blockArgs);
2509 Elaborator targetElaborator(state, result.testState, targetMaterializer);
2510
2511 // Elaborate the target
2512 if (failed(targetElaborator.elaborate(targetOp.getBodyRegion(), {},
2513 result.yields)))
2514 return failure();
2515
2516 targetMaterializer.finalize();
2517 }
2518
2519 // Elaborate each test that has been matched against a target; tests cannot
2520 // be placed by other ops, so they are the elaboration roots.
2521 for (auto testOp : moduleOp.getOps<TestOp>()) {
2522 // Skip tests without a target attribute - these couldn't be matched
2523 // against any target but can be useful to keep around for reporting
2524 // purposes.
2525 if (!testOp.getTargetAttr())
2526 continue;
2527
2528 LLVM_DEBUG(llvm::dbgs()
2529 << "\n=== Elaborating test @" << testOp.getTemplateName()
2530 << " for target @" << *testOp.getTarget() << "\n\n");
2531
2532 // Get the target for this test
2533 auto targetResult = targetMap[testOp.getTargetAttr()];
2534 TestState testState = targetResult.testState;
2535 testState.name = testOp.getSymNameAttr();
2536
2537 SmallVector<ElaboratorValue> filteredYields;
2538 unsigned i = 0;
2539 for (auto [entry, yield] :
2540 llvm::zip(targetResult.targetType.getEntries(), targetResult.yields)) {
2541 if (i >= testOp.getTargetType().getEntries().size())
2542 break;
2543
2544 if (entry.name == testOp.getTargetType().getEntries()[i].name) {
2545 filteredYields.push_back(yield);
2546 ++i;
2547 }
2548 }
2549
2550 // Now elaborate the test with the same state, passing the target yield
2551 // values as arguments
2552 SmallVector<ElaboratorValue> blockArgs;
2553 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()),
2554 testState, state, blockArgs);
2555
2556 for (auto [arg, val] :
2557 llvm::zip(testOp.getBody()->getArguments(), filteredYields))
2558 if (onlyLegalToMaterializeInTarget(arg.getType()))
2559 materializer.map(val, arg);
2560
2561 Elaborator elaborator(state, testState, materializer);
2562 SmallVector<ElaboratorValue> ignore;
2563 if (failed(elaborator.elaborate(testOp.getBodyRegion(), filteredYields,
2564 ignore)))
2565 return failure();
2566
2567 materializer.finalize();
2568 }
2569
2570 return success();
2571}