CIRCT 23.0.0git
Loading...
Searching...
No Matches
ElaborationPass.cpp
Go to the documentation of this file.
1//===- ElaborationPass.cpp - RTG ElaborationPass implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass elaborates the random parts of the RTG dialect.
10// It performs randomization top-down, i.e., random constructs in a sequence
11// that is invoked multiple times can yield different randomization results
// for each invocation.
13//
14//===----------------------------------------------------------------------===//
15
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Index/IR/IndexDialect.h"
23#include "mlir/Dialect/Index/IR/IndexOps.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/DenseMapInfoVariant.h"
28#include "llvm/Support/Debug.h"
29#include <random>
30
31namespace circt {
32namespace rtg {
33#define GEN_PASS_DEF_ELABORATIONPASS
34#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
35} // namespace rtg
36} // namespace circt
37
38using namespace mlir;
39using namespace circt;
40using namespace circt::rtg;
41using llvm::MapVector;
42
43#define DEBUG_TYPE "rtg-elaboration"
44
namespace {
/// Helper class to generate random numbers supporting scopes of randomization.
struct RngScope {
  // Copying is forbidden so that two scopes never share (and thereby replay)
  // the same RNG stream; moving transfers the stream instead.
  RngScope() = delete;
  RngScope(const RngScope &) = delete;
  RngScope &operator=(const RngScope &) = delete;

  RngScope(RngScope &&) = default;
  RngScope &operator=(RngScope &&) = default;

  explicit RngScope(uint32_t seed) : rng(seed) {}

  //===--------------------------------------------------------------------===//
  // Uniform Distribution Helper
  //
  // Simplified version of
  // https://github.com/llvm/llvm-project/blob/main/libcxx/include/__random/uniform_int_distribution.h
  // We use our custom version here to get the same results when compiled with
  // different compiler versions and standard libraries.
  //===--------------------------------------------------------------------===//

  /// Get a number uniformly at random in the specified range, inclusive on
  /// both ends: [a, b].
  uint32_t getUniformlyInRange(uint32_t a, uint32_t b) {
    const uint32_t diff = b - a + 1;
    // Single-element range: no randomness needed (and no RNG draw consumed).
    if (diff == 1)
      return a;

    // 'diff' wrapped around to 0, i.e. the range covers all 2^32 values, so
    // any raw RNG output is a valid sample.
    if (diff == 0)
      return rng();

    // Rejection sampling: mask the RNG output down to the smallest
    // power-of-two range covering 'diff' and retry until the sample lands in
    // [0, diff). This avoids the modulo bias of a plain 'rng() % diff'.
    uint32_t mask =
        std::numeric_limits<uint32_t>::max() >> (32 - llvm::Log2_32_Ceil(diff));
    uint32_t u;
    do {
      u = rng() & mask;
    } while (u >= diff);

    return u + a;
  }

  /// Derive a child scope seeded from this scope's stream. Creating the child
  /// consumes exactly one draw from the parent stream.
  RngScope getNested() { return RngScope(rng()); }

private:
  std::mt19937 rng;
};
} // namespace
91
//===----------------------------------------------------------------------===//
// Elaborator Value
//===----------------------------------------------------------------------===//

namespace {
// Forward declarations of all storage kinds an elaborated value can refer to.
struct ArrayStorage;
struct BagStorage;
struct SequenceStorage;
struct RandomizedSequenceStorage;
struct InterleavedSequenceStorage;
struct SetStorage;
struct VirtualRegisterStorage;
struct UniqueLabelStorage;
struct TupleStorage;
struct MemoryStorage;
struct MemoryBlockStorage;
struct SymbolicComputationWithIdentityStorage;
struct SymbolicComputationWithIdentityValue;
struct SymbolicComputationStorage;

/// The result of elaborating an SSA value. Booleans and indices are stored
/// inline ('bool', 'size_t'), other constants as 'TypedAttr'; everything else
/// is a pointer to a storage object (allocated by the 'Internalizer' below),
/// so pointer comparison suffices for internalized values.
using ElaboratorValue =
    std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
                 RandomizedSequenceStorage *, InterleavedSequenceStorage *,
                 SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
                 ArrayStorage *, TupleStorage *, MemoryStorage *,
                 MemoryBlockStorage *, SymbolicComputationWithIdentityStorage *,
                 SymbolicComputationWithIdentityValue *,
                 SymbolicComputationStorage *>;
121
122// NOLINTNEXTLINE(readability-identifier-naming)
123llvm::hash_code hash_value(const ElaboratorValue &val) {
124 return std::visit(
125 [&val](const auto &alternative) {
126 // Include index in hash to make sure same value as different
127 // alternatives don't collide.
128 return llvm::hash_combine(val.index(), alternative);
129 },
130 val);
131}
132
133} // namespace
134
namespace llvm {

/// DenseMapInfo for 'bool' so the 'ElaboratorValue' variant can be hashed and
/// compared via 'DenseMapInfoVariant.h'.
///
/// NOTE(review): with 'false' as the empty key and 'true' as the tombstone
/// key, every real boolean value collides with a sentinel, so this
/// specialization is unsuitable for containers keyed directly on 'bool'.
/// It appears to exist only to satisfy the variant's DenseMapInfo
/// requirements -- confirm nothing is keyed on a plain 'bool'.
template <>
struct DenseMapInfo<bool> {
  static inline unsigned getEmptyKey() { return false; }
  static inline unsigned getTombstoneKey() { return true; }
  static unsigned getHashValue(const bool &val) { return val * 37U; }

  static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; }
};

} // namespace llvm
147
148//===----------------------------------------------------------------------===//
149// Elaborator Value Storages and Internalization
150//===----------------------------------------------------------------------===//
151
152namespace {
153
/// Lightweight object to be used as the key for internalization sets. It caches
/// the hashcode of the internalized object and a pointer to it. This allows a
/// delayed allocation and construction of the actual object and thus only has
/// to happen if the object is not already in the set.
template <typename StorageTy>
struct HashedStorage {
  HashedStorage(unsigned hashcode = 0, StorageTy *storage = nullptr)
      : hashcode(hashcode), storage(storage) {}

  // Precomputed hash of the (candidate) storage; avoids rehashing on every
  // set operation.
  unsigned hashcode;
  // Pointer to the internalized object; null while the entry is only a probe
  // key that has not been allocated yet.
  StorageTy *storage;
};

/// A DenseMapInfo implementation to support 'insert_as' for the internalization
/// sets. When comparing two 'HashedStorage's we can just compare the already
/// internalized storage pointers, otherwise we have to call the costly
/// 'isEqual' method.
template <typename StorageTy>
struct StorageKeyInfo {
  // The sentinel keys reuse the pointer sentinels of 'StorageTy *'; the
  // hashcode of a sentinel entry is irrelevant.
  static inline HashedStorage<StorageTy> getEmptyKey() {
    return HashedStorage<StorageTy>(0,
                                    DenseMapInfo<StorageTy *>::getEmptyKey());
  }
  static inline HashedStorage<StorageTy> getTombstoneKey() {
    return HashedStorage<StorageTy>(
        0, DenseMapInfo<StorageTy *>::getTombstoneKey());
  }

  static inline unsigned getHashValue(const HashedStorage<StorageTy> &key) {
    return key.hashcode;
  }
  // Overload used by 'insert_as': hash the not-yet-allocated candidate
  // directly via its cached hashcode.
  static inline unsigned getHashValue(const StorageTy &key) {
    return key.hashcode;
  }

  static inline bool isEqual(const HashedStorage<StorageTy> &lhs,
                             const HashedStorage<StorageTy> &rhs) {
    return lhs.storage == rhs.storage;
  }
  static inline bool isEqual(const StorageTy &lhs,
                             const HashedStorage<StorageTy> &rhs) {
    // Sentinel entries carry sentinel pointers that must not be dereferenced.
    if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
      return false;

    return lhs.isEqual(rhs.storage);
  }
};
201
// Values with structural equivalence intended to be internalized.
//===----------------------------------------------------------------------===//

/// Base class for all storage objects that can be materialized as attributes.
/// This provides a cache for the attribute that was materialized from this
/// storage.
struct CachableStorage {
  // Cached attribute; null until first recorded. The cache does not
  // participate in hashing or equality of the derived storages.
  TypedAttr attrCache;
};
211
212/// Storage object for an '!rtg.set<T>'.
213struct SetStorage : CachableStorage {
214 static unsigned computeHash(const SetVector<ElaboratorValue> &set,
215 Type type) {
216 llvm::hash_code setHash = 0;
217 for (auto el : set) {
218 // Just XOR all hashes because it's a commutative operation and
219 // `llvm::hash_combine_range` is not commutative.
220 // We don't want the order in which elements were added to influence the
221 // hash and thus the equivalence of sets.
222 setHash = setHash ^ llvm::hash_combine(el);
223 }
224 return llvm::hash_combine(type, setHash);
225 }
226
227 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
228 : hashcode(computeHash(set, type)), set(std::move(set)), type(type) {}
229
230 bool isEqual(const SetStorage *other) const {
231 // Note: we are not using the `==` operator of `SetVector` because it
232 // takes the order in which elements were added into account (since it's a
233 // vector after all). We just use it as a convenient way to keep track of a
234 // deterministic order for re-materialization.
235 bool allContained = true;
236 for (auto el : set)
237 allContained &= other->set.contains(el);
238
239 return hashcode == other->hashcode && set.size() == other->set.size() &&
240 allContained && type == other->type;
241 }
242
243 // The cached hashcode to avoid repeated computations.
244 const unsigned hashcode;
245
246 // Stores the elaborated values contained in the set.
247 const SetVector<ElaboratorValue> set;
248
249 // Store the set type such that we can materialize this evaluated value
250 // also in the case where the set is empty.
251 const Type type;
252};
253
/// Storage object for an '!rtg.bag<T>'.
struct BagStorage {
  // 'hashcode' is declared first, so it is initialized from the parameter
  // 'bag' before that parameter is moved into the member.
  BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
      : hashcode(llvm::hash_combine(
            type, llvm::hash_combine_range(bag.begin(), bag.end()))),
        bag(std::move(bag)), type(type) {}

  bool isEqual(const BagStorage *other) const {
    // NOTE(review): both the hash ('hash_combine_range') and this comparison
    // ('llvm::equal') are sensitive to the MapVector's insertion order, unlike
    // 'SetStorage' which compares order-independently. Hash and equality are
    // mutually consistent, but two bags with identical contents built in
    // different orders are not deduplicated -- confirm this is intended.
    return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
           type == other->type;
  }

  // The cached hashcode to avoid repeated computations.
  const unsigned hashcode;

  // Stores the elaborated values contained in the bag with their number of
  // occurences.
  const MapVector<ElaboratorValue, uint64_t> bag;

  // Store the bag type such that we can materialize this evaluated value
  // also in the case where the bag is empty.
  const Type type;
};
277
/// Storage object for an '!rtg.sequence'.
struct SequenceStorage {
  // 'hashcode' is the first declared member, so it is initialized first and
  // reads the parameter 'args' before the member initializer moves from it.
  SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
      : hashcode(llvm::hash_combine(
            familyName, llvm::hash_combine_range(args.begin(), args.end()))),
        familyName(familyName), args(std::move(args)) {}

  bool isEqual(const SequenceStorage *other) const {
    return hashcode == other->hashcode && familyName == other->familyName &&
           args == other->args;
  }

  // The cached hashcode to avoid repeated computations.
  const unsigned hashcode;

  // The name of the sequence family this sequence is derived from.
  const StringAttr familyName;

  // The elaborator values used during substitution of the sequence family.
  const SmallVector<ElaboratorValue> args;
};
299
300/// Storage object for interleaved '!rtg.randomized_sequence'es.
301struct InterleavedSequenceStorage {
302 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
303 uint32_t batchSize)
304 : sequences(std::move(sequences)), batchSize(batchSize),
305 hashcode(llvm::hash_combine(
306 llvm::hash_combine_range(sequences.begin(), sequences.end()),
307 batchSize)) {}
308
309 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
310 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
311 hashcode(llvm::hash_combine(
312 llvm::hash_combine_range(sequences.begin(), sequences.end()),
313 batchSize)) {}
314
315 bool isEqual(const InterleavedSequenceStorage *other) const {
316 return hashcode == other->hashcode && sequences == other->sequences &&
317 batchSize == other->batchSize;
318 }
319
320 const SmallVector<ElaboratorValue> sequences;
321
322 const uint32_t batchSize;
323
324 // The cached hashcode to avoid repeated computations.
325 const unsigned hashcode;
326};
327
328/// Storage object for '!rtg.array`-typed values.
329struct ArrayStorage {
330 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
331 : hashcode(llvm::hash_combine(
332 type, llvm::hash_combine_range(array.begin(), array.end()))),
333 type(type), array(array) {}
334
335 bool isEqual(const ArrayStorage *other) const {
336 return hashcode == other->hashcode && type == other->type &&
337 array == other->array;
338 }
339
340 // The cached hashcode to avoid repeated computations.
341 const unsigned hashcode;
342
343 /// The type of the array. This is necessary because an array of size 0
344 /// cannot be reconstructed without knowing the original element type.
345 const Type type;
346
347 /// The label name. For unique labels, this is just the prefix.
348 const SmallVector<ElaboratorValue> array;
349};
350
/// Storage object for 'tuple`-typed values.
struct TupleStorage : CachableStorage {
  // 'hashcode' is declared first, so it is initialized from the parameter
  // 'values' before the member initializer moves from it.
  TupleStorage(SmallVector<ElaboratorValue> &&values)
      : hashcode(llvm::hash_combine_range(values.begin(), values.end())),
        values(std::move(values)) {}

  bool isEqual(const TupleStorage *other) const {
    return hashcode == other->hashcode && values == other->values;
  }

  // The cached hashcode to avoid repeated computations.
  const unsigned hashcode;

  // The elaborated tuple elements, in order.
  const SmallVector<ElaboratorValue> values;
};
366
/// Storage object capturing an operation symbolically by its name, result
/// types, (already elaborated) operands, attributes, and properties. Two
/// captures that agree on all of these are structurally equivalent and are
/// internalized to the same storage.
struct SymbolicComputationStorage {
  /// Capture 'op', mapping each operand through 'state' to its elaborated
  /// value. The operation itself is not retained.
  SymbolicComputationStorage(const DenseMap<Value, ElaboratorValue> &state,
                             Operation *op)
      : name(op->getName()), resultTypes(op->getResultTypes()),
        operands(llvm::map_range(op->getOperands(),
                                 [&](Value v) { return state.lookup(v); })),
        attributes(op->getAttrDictionary()),
        properties(op->getPropertiesAsAttribute()),
        hashcode(llvm::hash_combine(name, llvm::hash_combine_range(resultTypes),
                                    llvm::hash_combine_range(operands),
                                    attributes, op->hashProperties())) {}

  bool isEqual(const SymbolicComputationStorage *other) const {
    return hashcode == other->hashcode && name == other->name &&
           resultTypes == other->resultTypes && operands == other->operands &&
           attributes == other->attributes && properties == other->properties;
  }

  const OperationName name;
  const SmallVector<Type> resultTypes;
  const SmallVector<ElaboratorValue> operands;
  const DictionaryAttr attributes;
  // Properties captured as an attribute for equality; the hash uses
  // 'op->hashProperties()' at capture time instead.
  const Attribute properties;
  const unsigned hashcode;
};
392
// Values with identity not intended to be internalized.
//===----------------------------------------------------------------------===//

/// Base class for storages that represent values with identity, i.e., two
/// values are not considered equivalent if they are structurally the same, but
/// each definition of such a value is unique. E.g., unique labels or virtual
/// registers. These cannot be materialized anew in each nested sequence, but
/// must be passed as arguments.
struct IdentityValue {

  IdentityValue(Type type, Location loc) : type(type), loc(loc) {}

#ifndef NDEBUG

  /// In debug mode, track whether this value was already materialized to
  /// assert if it's illegally materialized multiple times.
  ///
  /// Instead of deleting operations defining these values and materializing
  /// them again, we could retain the operations. However, we still need
  /// specific storages to represent these values in some cases, e.g., to get
  /// the size of a memory allocation. Also, elaboration of nested control-flow
  /// regions (e.g. `scf.for`) relies on materialization of such values lazily
  /// instead of cloning the operations eagerly.
  bool alreadyMaterialized = false;

#endif

  /// The type of the SSA value this storage will materialize as.
  const Type type;
  /// The location to attach to materialized operations.
  const Location loc;
};
423
/// Represents a unique virtual register.
struct VirtualRegisterStorage : IdentityValue {
  VirtualRegisterStorage(VirtualRegisterConfigAttr allowedRegs, Type type,
                         Location loc)
      : IdentityValue(type, loc), allowedRegs(allowedRegs) {}

  // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
  // VirtualRegisters are never internalized.

  // The list of fixed registers allowed to be selected for this virtual
  // register.
  const VirtualRegisterConfigAttr allowedRegs;
};

/// Represents a unique label; identity is the storage address, not the name.
struct UniqueLabelStorage : IdentityValue {
  UniqueLabelStorage(const ElaboratorValue &name, Location loc)
      : IdentityValue(LabelType::get(loc->getContext()), loc), name(name) {}

  // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
  // unique labels are never internalized.

  /// The label name. For unique labels, this is just the prefix.
  const ElaboratorValue name;
};
448
/// Storage object for '!rtg.isa.memoryblock`-typed values.
struct MemoryBlockStorage : IdentityValue {
  MemoryBlockStorage(const APInt &baseAddress, const APInt &endAddress,
                     Type type, Location loc)
      : IdentityValue(type, loc), baseAddress(baseAddress),
        endAddress(endAddress) {}

  // The base address of the memory. The width of the APInt also represents the
  // address width of the memory. This is an APInt to support memories of
  // >64-bit machines.
  const APInt baseAddress;

  // The last address of the memory.
  const APInt endAddress;
};

/// Storage object for '!rtg.isa.memory`-typed values.
struct MemoryStorage : IdentityValue {
  // The memory's type carries the address width of the memory block it was
  // allocated from.
  MemoryStorage(MemoryBlockStorage *memoryBlock, size_t size, size_t alignment,
                Location loc)
      : IdentityValue(MemoryType::get(memoryBlock->type.getContext(),
                                      memoryBlock->baseAddress.getBitWidth()),
                      loc),
        memoryBlock(memoryBlock), size(size), alignment(alignment) {}

  /// The memory block this memory was allocated from.
  MemoryBlockStorage *memoryBlock;
  const size_t size;
  const size_t alignment;
};

/// Storage object for an '!rtg.randomized_sequence'.
struct RandomizedSequenceStorage : IdentityValue {
  RandomizedSequenceStorage(ContextResourceAttrInterface context,
                            SequenceStorage *sequence, Location loc)
      : IdentityValue(
            RandomizedSequenceType::get(sequence->familyName.getContext()),
            loc),
        context(context), sequence(sequence) {}

  // The context under which this sequence is placed.
  const ContextResourceAttrInterface context;

  /// The (substituted) sequence this randomization was derived from.
  const SequenceStorage *sequence;
};
493
/// Symbolic capture of an operation whose results have identity (no
/// deduplication). This storage itself represents result #0; the operation
/// must have at least 1 result. Results with index > 0 are represented by
/// 'SymbolicComputationWithIdentityValue' below.
struct SymbolicComputationWithIdentityStorage : IdentityValue {
  SymbolicComputationWithIdentityStorage(
      const DenseMap<Value, ElaboratorValue> &state, Operation *op)
      : IdentityValue(op->getResult(0).getType(), op->getLoc()),
        name(op->getName()), resultTypes(op->getResultTypes()),
        operands(llvm::map_range(op->getOperands(),
                                 [&](Value v) { return state.lookup(v); })),
        attributes(op->getAttrDictionary()),
        properties(op->getPropertiesAsAttribute()) {}

  const OperationName name;
  const SmallVector<Type> resultTypes;
  const SmallVector<ElaboratorValue> operands;
  const DictionaryAttr attributes;
  const Attribute properties;
};

/// One non-zero result of a symbolically captured operation with identity.
struct SymbolicComputationWithIdentityValue : IdentityValue {
  SymbolicComputationWithIdentityValue(
      Type type, const SymbolicComputationWithIdentityStorage *storage,
      unsigned idx)
      : IdentityValue(type, storage->loc), storage(storage), idx(idx) {
    assert(
        idx != 0 &&
        "Use SymbolicComputationWithIdentityStorage for result with index 0.");
  }

  /// The capture of the defining operation (which also represents result #0).
  const SymbolicComputationWithIdentityStorage *storage;
  /// The result index this value corresponds to (always > 0).
  const unsigned idx;
};
525
526/// An 'Internalizer' object internalizes storages and takes ownership of them.
527/// When the initializer object is destroyed, all owned storages are also
528/// deallocated and thus must not be accessed anymore.
529class Internalizer {
530public:
531 /// Internalize a storage of type `StorageTy` constructed with arguments
532 /// `args`. The pointers returned by this method can be used to compare
533 /// objects when, e.g., computing set differences, uniquing the elements in a
534 /// set, etc. Otherwise, we'd need to do a deep value comparison in those
535 /// situations.
536 template <typename StorageTy, typename... Args>
537 StorageTy *internalize(Args &&...args) {
538 static_assert(!std::is_base_of_v<IdentityValue, StorageTy> &&
539 "values with identity must not be internalized");
540
541 StorageTy storage(std::forward<Args>(args)...);
542
543 auto existing = getInternSet<StorageTy>().insert_as(
544 HashedStorage<StorageTy>(storage.hashcode), storage);
545 StorageTy *&storagePtr = existing.first->storage;
546 if (existing.second)
547 storagePtr =
548 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
549
550 return storagePtr;
551 }
552
553 template <typename StorageTy, typename... Args>
554 StorageTy *create(Args &&...args) {
555 static_assert(std::is_base_of_v<IdentityValue, StorageTy> &&
556 "values with structural equivalence must be internalized");
557
558 return new (allocator.Allocate<StorageTy>())
559 StorageTy(std::forward<Args>(args)...);
560 }
561
562private:
563 template <typename StorageTy>
564 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
565 getInternSet() {
566 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
567 return internedArrays;
568 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
569 return internedSets;
570 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
571 return internedBags;
572 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
573 return internedSequences;
574 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
575 return internedRandomizedSequences;
576 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
577 return internedInterleavedSequences;
578 else if constexpr (std::is_same_v<StorageTy, TupleStorage>)
579 return internedTuples;
580 else if constexpr (std::is_same_v<StorageTy, SymbolicComputationStorage>)
581 return internedSymbolicComputationWithIdentityValues;
582 else
583 static_assert(!sizeof(StorageTy),
584 "no intern set available for this storage type.");
585 }
586
587 // This allocator allocates on the heap. It automatically deallocates all
588 // objects it allocated once the allocator itself is destroyed.
589 llvm::BumpPtrAllocator allocator;
590
591 // The sets holding the internalized objects. We use one set per storage type
592 // such that we can have a simpler equality checking function (no need to
593 // compare some sort of TypeIDs).
594 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
595 internedArrays;
596 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
597 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
598 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
599 internedSequences;
600 DenseSet<HashedStorage<RandomizedSequenceStorage>,
601 StorageKeyInfo<RandomizedSequenceStorage>>
602 internedRandomizedSequences;
603 DenseSet<HashedStorage<InterleavedSequenceStorage>,
604 StorageKeyInfo<InterleavedSequenceStorage>>
605 internedInterleavedSequences;
606 DenseSet<HashedStorage<TupleStorage>, StorageKeyInfo<TupleStorage>>
607 internedTuples;
608 DenseSet<HashedStorage<SymbolicComputationStorage>,
609 StorageKeyInfo<SymbolicComputationStorage>>
610 internedSymbolicComputationWithIdentityValues;
611};
612
613} // namespace
614
615#ifndef NDEBUG
616
// Debug printers for each 'ElaboratorValue' alternative. Only compiled in
// assertion-enabled builds (see the surrounding '#ifndef NDEBUG'). Pointers
// are printed alongside the contents so identical-looking values with
// different identities can be told apart in traces.
static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                                     const ElaboratorValue &value);

static void print(TypedAttr val, llvm::raw_ostream &os) {
  os << "<attr " << val << ">";
}

static void print(BagStorage *val, llvm::raw_ostream &os) {
  os << "<bag {";
  llvm::interleaveComma(val->bag, os,
                        [&](const std::pair<ElaboratorValue, uint64_t> &el) {
                          os << el.first << " -> " << el.second;
                        });
  os << "} at " << val << ">";
}

static void print(bool val, llvm::raw_ostream &os) {
  os << "<bool " << (val ? "true" : "false") << ">";
}

static void print(size_t val, llvm::raw_ostream &os) {
  os << "<index " << val << ">";
}

static void print(SequenceStorage *val, llvm::raw_ostream &os) {
  os << "<sequence @" << val->familyName.getValue() << "(";
  llvm::interleaveComma(val->args, os,
                        [&](const ElaboratorValue &val) { os << val; });
  os << ") at " << val << ">";
}

static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
  os << "<randomized-sequence derived from @"
     << val->sequence->familyName.getValue() << " under context "
     << val->context << "(";
  llvm::interleaveComma(val->sequence->args, os,
                        [&](const ElaboratorValue &val) { os << val; });
  os << ") at " << val << ">";
}

static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
  os << "<interleaved-sequence [";
  llvm::interleaveComma(val->sequences, os,
                        [&](const ElaboratorValue &val) { os << val; });
  os << "] batch-size " << val->batchSize << " at " << val << ">";
}

static void print(ArrayStorage *val, llvm::raw_ostream &os) {
  os << "<array [";
  llvm::interleaveComma(val->array, os,
                        [&](const ElaboratorValue &val) { os << val; });
  os << "] at " << val << ">";
}

static void print(SetStorage *val, llvm::raw_ostream &os) {
  os << "<set {";
  llvm::interleaveComma(val->set, os,
                        [&](const ElaboratorValue &val) { os << val; });
  os << "} at " << val << ">";
}

static void print(const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
  os << "<virtual-register " << val << " " << val->allowedRegs << ">";
}

static void print(const UniqueLabelStorage *val, llvm::raw_ostream &os) {
  os << "<unique-label " << val << " " << val->name << ">";
}

static void print(const TupleStorage *val, llvm::raw_ostream &os) {
  os << "<tuple (";
  llvm::interleaveComma(val->values, os,
                        [&](const ElaboratorValue &val) { os << val; });
  os << ")>";
}

static void print(const MemoryStorage *val, llvm::raw_ostream &os) {
  os << "<memory {" << ElaboratorValue(val->memoryBlock)
     << ", size=" << val->size << ", alignment=" << val->alignment << "}>";
}
697
698static void print(const MemoryBlockStorage *val, llvm::raw_ostream &os) {
699 os << "<memory-block {"
700 << ", address-width=" << val->baseAddress.getBitWidth()
701 << ", base-address=" << val->baseAddress
702 << ", end-address=" << val->endAddress << "}>";
703}
704
static void print(const SymbolicComputationWithIdentityValue *val,
                  llvm::raw_ostream &os) {
  os << "<symbolic-computation-with-identity-value (" << val->storage << ") at "
     << val->idx << ">";
}

static void print(const SymbolicComputationWithIdentityStorage *val,
                  llvm::raw_ostream &os) {
  os << "<symbolic-computation-with-identity " << val->name << "(";
  llvm::interleaveComma(val->operands, os,
                        [&](const ElaboratorValue &val) { os << val; });
  os << ") -> " << val->resultTypes << " with attributes " << val->attributes
     << " and properties " << val->properties;
  os << ">";
}

static void print(const SymbolicComputationStorage *val,
                  llvm::raw_ostream &os) {
  os << "<symbolic-computation " << val->name << "(";
  llvm::interleaveComma(val->operands, os,
                        [&](const ElaboratorValue &val) { os << val; });
  os << ") -> " << val->resultTypes << " with attributes " << val->attributes
     << " and properties " << val->properties;
  os << ">";
}

// Dispatch to the matching 'print' overload for the active variant
// alternative.
static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                                     const ElaboratorValue &value) {
  std::visit([&](auto val) { print(val, os); }, value);

  return os;
}
737
738#endif
739
740//===----------------------------------------------------------------------===//
741// Attribute <-> ElaboratorValue Converters
742//===----------------------------------------------------------------------===//
743
744namespace {
745
/// Convert an attribute to an ElaboratorValue. This handles nested attributes
/// like SetAttr and TupleAttr recursively.
class AttributeToElaboratorValueConverter {
public:
  AttributeToElaboratorValueConverter(Internalizer &internalizer)
      : internalizer(internalizer) {}

  /// Convert an attribute to an ElaboratorValue. Fails for attributes that
  /// are neither one of the specifically handled kinds nor a 'TypedAttr'.
  FailureOr<ElaboratorValue> convert(Attribute attr) {
    // Ordering matters: the specific cases must precede the generic
    // 'TypedAttr' case, which would otherwise swallow them.
    return llvm::TypeSwitch<Attribute, FailureOr<ElaboratorValue>>(attr)
        .Case<IntegerAttr, SetAttr, TupleAttr>(
            [&](auto attr) { return convert(attr); })
        .Case<TypedAttr>([&](auto typedAttr) -> FailureOr<ElaboratorValue> {
          return ElaboratorValue(typedAttr);
        })
        .Default(
            [&](Attribute) -> FailureOr<ElaboratorValue> { return failure(); });
  }

private:
  // i1 integers become inline 'bool's, index-typed integers inline 'size_t's;
  // any other integer attribute is kept as-is.
  FailureOr<ElaboratorValue> convert(IntegerAttr attr) {
    if (attr.getType().isSignlessInteger(1))
      return ElaboratorValue(bool(attr.getInt()));
    if (isa<IndexType>(attr.getType()))
      return ElaboratorValue(size_t(attr.getInt()));
    return ElaboratorValue(attr);
  }

  // Recursively convert the elements and internalize the resulting set.
  FailureOr<ElaboratorValue> convert(SetAttr setAttr) {
    SetVector<ElaboratorValue> set;
    for (auto element : *setAttr.getElements()) {
      auto converted = convert(element);
      if (failed(converted))
        return failure();
      set.insert(*converted);
    }
    auto *storage =
        internalizer.internalize<SetStorage>(std::move(set), setAttr.getType());
    // Cache the original attribute for efficient materialization
    storage->attrCache = setAttr;
    return ElaboratorValue(storage);
  }

  // Recursively convert the elements and internalize the resulting tuple.
  FailureOr<ElaboratorValue> convert(TupleAttr tupleAttr) {
    SmallVector<ElaboratorValue> values;
    for (auto element : tupleAttr.getElements()) {
      auto converted = convert(element);
      if (failed(converted))
        return failure();
      values.push_back(*converted);
    }
    auto *storage = internalizer.internalize<TupleStorage>(std::move(values));
    // Cache the original attribute for efficient materialization
    storage->attrCache = tupleAttr;
    return ElaboratorValue(storage);
  }

  Internalizer &internalizer;
};
805
806/// Convert an ElaboratorValue to an attribute when possible. Not all
807/// ElaboratorValues can be converted to attributes (e.g., values with
808/// identity).
809class ElaboratorValueToAttributeConverter {
810public:
811 ElaboratorValueToAttributeConverter(MLIRContext *context)
812 : context(context) {}
813
814 /// Convert an ElaboratorValue to an attribute. Returns a null attribute if
815 /// the conversion is not possible.
816 TypedAttr convert(const ElaboratorValue &value) {
817 return std::visit(
818 [&](auto val) -> TypedAttr {
819 if constexpr (std::is_base_of_v<CachableStorage,
820 std::remove_pointer_t<
821 std::decay_t<decltype(value)>>>) {
822 if (val->attrCache)
823 return val->attrCache;
824 }
825 return visit(val);
826 },
827 value);
828 }
829
830private:
831 TypedAttr visit(TypedAttr val) { return val; }
832
833 TypedAttr visit(bool val) {
834 return IntegerAttr::get(IntegerType::get(context, 1), val);
835 }
836
837 TypedAttr visit(size_t val) {
838 return IntegerAttr::get(IndexType::get(context), val);
839 }
840
841 TypedAttr visit(SetStorage *val) {
842 DenseSet<TypedAttr> elements;
843 for (auto element : val->set) {
844 auto converted = convert(element);
845 if (!converted)
846 return {};
847 auto typedAttr = dyn_cast<TypedAttr>(converted);
848 if (!typedAttr)
849 return {};
850 elements.insert(typedAttr);
851 }
852 return SetAttr::get(cast<SetType>(val->type), &elements);
853 }
854
855 TypedAttr visit(TupleStorage *val) {
856 SmallVector<TypedAttr> elements;
857 for (auto element : val->values) {
858 auto converted = convert(element);
859 if (!converted)
860 return {};
861 auto typedAttr = dyn_cast<TypedAttr>(converted);
862 if (!typedAttr)
863 return {};
864 elements.push_back(typedAttr);
865 }
866 return TupleAttr::get(context, elements);
867 }
868
  // Default implementation for storage types that cannot be converted to
  // attributes. Using a macro to avoid repetition. std::visit does not support
  // any kind of "default" clauses, unfortunately.
#define VISIT_UNSUPPORTED(STORAGETYPE)                                         \
  /* NOLINTNEXTLINE(bugprone-macro-parentheses)*/                              \
  TypedAttr visit(STORAGETYPE *val) { return {}; }

  VISIT_UNSUPPORTED(ArrayStorage)
  VISIT_UNSUPPORTED(BagStorage)
  VISIT_UNSUPPORTED(SequenceStorage)
  VISIT_UNSUPPORTED(RandomizedSequenceStorage)
  VISIT_UNSUPPORTED(InterleavedSequenceStorage)
  VISIT_UNSUPPORTED(VirtualRegisterStorage)
  VISIT_UNSUPPORTED(UniqueLabelStorage)
  VISIT_UNSUPPORTED(MemoryStorage)
  VISIT_UNSUPPORTED(MemoryBlockStorage)
  VISIT_UNSUPPORTED(SymbolicComputationWithIdentityStorage)
  VISIT_UNSUPPORTED(SymbolicComputationWithIdentityValue)
  VISIT_UNSUPPORTED(SymbolicComputationStorage)

#undef VISIT_UNSUPPORTED

  /// The MLIR context used to construct attributes and types above.
  MLIRContext *context;
};
893
894} // namespace
895
896//===----------------------------------------------------------------------===//
897// Elaborator Value Materialization
898//===----------------------------------------------------------------------===//
899
900namespace {
901
/// State that should be shared by all elaborator and materializer instances.
struct SharedState {
  SharedState(MLIRContext *ctxt, SymbolTable &table)
      : ctxt(ctxt), table(table) {}

  /// The MLIR context of the module being elaborated.
  MLIRContext *ctxt;
  /// Symbol table shared across all elaborations (not owned).
  SymbolTable &table;
  /// Namespace used to pick unique names.
  Namespace names;
  /// Uniquing table for elaborated value storages.
  Internalizer internalizer;
};
912
/// A collection of state per RTG test.
struct TestState {
  /// Seed the root RNG scope of this test.
  explicit TestState(unsigned seed) : rng(RngScope(seed)) {}

  /// The name of the test.
  StringAttr name;

  /// The context switches registered for this test.
  MapVector<
      std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
      SequenceStorage *>
      contextSwitches;

  /// The root RNG scope for this test.
  RngScope rng;
};
929
/// Construct an SSA value from a given elaborated value.
class Materializer {
public:
  /// `builder` positions the insertion point; `blockArgs` collects elaborated
  /// values this materializer requests via block arguments (see below).
  Materializer(OpBuilder builder, TestState &testState,
               SharedState &sharedState,
               SmallVector<ElaboratorValue> &blockArgs)
      : builder(builder), testState(testState), sharedState(sharedState),
        blockArgs(blockArgs), attrConverter(builder.getContext()) {}
938
  /// Materialize IR representing the provided `ElaboratorValue` and return the
  /// `Value` or a null value on failure.
  Value materialize(ElaboratorValue val, Location loc,
                    function_ref<InFlightDiagnostic()> emitError) {
    // Reuse a previously materialized SSA value if we have one.
    auto iter = materializedValues.find(val);
    if (iter != materializedValues.end())
      return iter->second;

    LLVM_DEBUG(llvm::dbgs() << "Materializing " << val);

    // Values convertible to attributes are emitted as constants.
    if (auto res = tryMaterializeAsConstant(val, loc))
      return res;

    // In debug mode, track whether values with identity were already
    // materialized before and assert in such a situation.
    Value res = std::visit(
        [&](auto value) {
          if constexpr (std::is_base_of_v<IdentityValue,
                                          std::remove_pointer_t<
                                              std::decay_t<decltype(value)>>>) {
            // Only materialize identity values this materializer is the root
            // of; otherwise request them from outside via a block argument.
            if (identityValueRoot.contains(value)) {
#ifndef NDEBUG
              bool &materialized =
                  static_cast<IdentityValue *>(value)->alreadyMaterialized;
              assert(!materialized && "must not already be materialized");
              materialized = true;
#endif

              return visit(value, loc, emitError);
            }

            Value arg = builder.getBlock()->addArgument(value->type, loc);
            blockArgs.push_back(val);
            blockArgTypes.push_back(arg.getType());
            materializedValues[val] = arg;
            return arg;
          }

          return visit(value, loc, emitError);
        },
        val);

    LLVM_DEBUG(llvm::dbgs() << " to\n" << res << "\n\n");

    return res;
  }
985
986 bool isInPlace(Operation *op) const {
987 return builder.getBlock()->getParent() == op->getParentRegion();
988 }
989
  /// If `op` is not in the same region as the materializer insertion point, a
  /// clone is created at the materializer's insertion point by also
  /// materializing the `ElaboratorValue`s for each operand just before it.
  /// Otherwise, all operations after the materializer's insertion point are
  /// deleted until `op` is reached. An error is returned if the operation is
  /// before the insertion point.
  LogicalResult materialize(Operation *op,
                            DenseMap<Value, ElaboratorValue> &state) {
    if (op->getNumRegions() > 0)
      return op->emitOpError("ops with nested regions must be elaborated away");

    // We don't support opaque values. If there is an SSA value that has a
    // use-site it needs an equivalent ElaborationValue representation.
    // NOTE: We could support cases where there is initially a use-site but that
    // op is guaranteed to be deleted during elaboration. Or the use-sites are
    // replaced with freshly materialized values from the ElaborationValue. But
    // then, why can't we delete the value defining op?
    for (auto res : op->getResults())
      if (!res.use_empty())
        return op->emitOpError(
            "ops with results that have uses are not supported");

    if (isInPlace(op)) {
      // We are doing in-place materialization, so mark all ops deleted until we
      // reach the one to be materialized and modify it in-place.
      deleteOpsUntil([&](auto iter) { return &*iter == op; });

      // Reaching the block end means `op` was never found after the insertion
      // point, i.e., it occurs before it.
      if (builder.getInsertionPoint() == builder.getBlock()->end())
        return op->emitError("operation did not occur after the current "
                             "materializer insertion point");

      LLVM_DEBUG(llvm::dbgs() << "Modifying in-place: " << *op << "\n\n");
    } else {
      LLVM_DEBUG(llvm::dbgs() << "Materializing a clone of " << *op << "\n\n");
      op = builder.clone(*op);
      builder.setInsertionPoint(op);
    }

    // Replace each operand with the materialization of its elaborated value.
    for (auto &operand : op->getOpOperands()) {
      auto emitError = [&]() {
        auto diag = op->emitError();
        diag.attachNote(op->getLoc())
            << "while materializing value for operand#"
            << operand.getOperandNumber();
        return diag;
      };

      auto elabVal = state.at(operand.get());
      Value val = materialize(elabVal, op->getLoc(), emitError);
      if (!val)
        return failure();

      // Record the new SSA value so later operands can be looked up too.
      state[val] = elabVal;
      operand.set(val);
    }

    builder.setInsertionPointAfter(op);
    return success();
  }
1049
  /// Should be called once the `Region` is successfully materialized. No calls
  /// to `materialize` should happen after this anymore.
  void finalize() {
    // Mark every remaining op after the insertion point for deletion, then
    // erase in reverse order so users go before their definitions.
    deleteOpsUntil([](auto iter) { return false; });

    for (auto *op : llvm::reverse(toDelete))
      op->erase();
  }

  /// Tell this materializer that it is responsible for materializing the given
  /// identity value at the earliest position it is needed, and shouldn't
  /// request the value via block argument.
  void registerIdentityValue(IdentityValue *val) {
    identityValueRoot.insert(val);
  }

  /// Types of the block arguments added by this materializer, in order.
  ArrayRef<Type> getBlockArgTypes() const { return blockArgTypes; }

  /// Pre-seed the materialization cache with an existing SSA value.
  void map(ElaboratorValue eval, Value val) { materializedValues[eval] = val; }

  /// Create an op of type `OpTy` at this materializer's insertion point.
  template <typename OpTy, typename... Args>
  OpTy create(Location location, Args &&...args) {
    return OpTy::create(builder, location, std::forward<Args>(args)...);
  }
1074
private:
  /// Emit `val` as a ConstantOp if it is convertible to an attribute; returns
  /// a null Value otherwise.
  Value tryMaterializeAsConstant(ElaboratorValue val, Location loc) {
    if (auto attr = attrConverter.convert(val)) {
      Value res = ConstantOp::create(builder, loc, attr);
      materializedValues[val] = res;
      return res;
    }

    return Value();
  }

  // Defined elsewhere; elaborates the sequence family of `seq` and reports
  // the elaborated values to pass as arguments via `elabArgs`.
  SequenceOp elaborateSequence(const RandomizedSequenceStorage *seq,
                               SmallVector<ElaboratorValue> &elabArgs);

  /// Mark ops for deletion starting at the insertion point until `stop`
  /// returns true or the block ends; advances the insertion point past them.
  void deleteOpsUntil(function_ref<bool(Block::iterator)> stop) {
    auto ip = builder.getInsertionPoint();
    while (ip != builder.getBlock()->end() && !stop(ip)) {
      LLVM_DEBUG(llvm::dbgs() << "Marking to be deleted: " << *ip << "\n\n");
      toDelete.push_back(&*ip);

      builder.setInsertionPointAfter(&*ip);
      ip = builder.getInsertionPoint();
    }
  }
1099
  // Attribute/bool/size_t values are handled by `tryMaterializeAsConstant`
  // in `materialize` before the visitor runs; presumably these overloads are
  // unreachable and returning a null Value signals failure — TODO confirm.
  Value visit(TypedAttr val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    return {};
  }

  Value visit(size_t val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    return {};
  }

  Value visit(bool val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    return {};
  }
1114
  /// Materialize an array by materializing each element and creating an
  /// ArrayCreateOp; null on any element failure.
  Value visit(ArrayStorage *val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    SmallVector<Value> elements;
    elements.reserve(val->array.size());
    for (auto el : val->array) {
      auto materialized = materialize(el, loc, emitError);
      if (!materialized)
        return Value();

      elements.push_back(materialized);
    }

    Value res = ArrayCreateOp::create(builder, loc, val->type, elements);
    materializedValues[val] = res;
    return res;
  }
1131
1132 Value visit(SetStorage *val, Location loc,
1133 function_ref<InFlightDiagnostic()> emitError) {
1134 SmallVector<Value> elements;
1135 elements.reserve(val->set.size());
1136 for (auto el : val->set) {
1137 auto materialized = materialize(el, loc, emitError);
1138 if (!materialized)
1139 return Value();
1140
1141 elements.push_back(materialized);
1142 }
1143
1144 auto res = SetCreateOp::create(builder, loc, val->type, elements);
1145 materializedValues[val] = res;
1146 return res;
1147 }
1148
1149 Value visit(BagStorage *val, Location loc,
1150 function_ref<InFlightDiagnostic()> emitError) {
1151 SmallVector<Value> values, weights;
1152 values.reserve(val->bag.size());
1153 weights.reserve(val->bag.size());
1154 for (auto [val, weight] : val->bag) {
1155 auto materializedVal = materialize(val, loc, emitError);
1156 auto materializedWeight = materialize(weight, loc, emitError);
1157 if (!materializedVal || !materializedWeight)
1158 return Value();
1159
1160 values.push_back(materializedVal);
1161 weights.push_back(materializedWeight);
1162 }
1163
1164 auto res = BagCreateOp::create(builder, loc, val->type, values, weights);
1165 materializedValues[val] = res;
1166 return res;
1167 }
1168
  /// Declare a memory block with its base/end addresses as integer attributes
  /// of the address bit width.
  Value visit(MemoryBlockStorage *val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    auto intType = builder.getIntegerType(val->baseAddress.getBitWidth());
    Value res = MemoryBlockDeclareOp::create(
        builder, val->loc, val->type,
        IntegerAttr::get(intType, val->baseAddress),
        IntegerAttr::get(intType, val->endAddress));
    materializedValues[val] = res;
    return res;
  }

  /// Allocate a memory from its block, size, and alignment; all three must
  /// materialize successfully.
  Value visit(MemoryStorage *val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    auto memBlock = materialize(val->memoryBlock, val->loc, emitError);
    auto memSize = materialize(val->size, val->loc, emitError);
    auto memAlign = materialize(val->alignment, val->loc, emitError);
    if (!(memBlock && memSize && memAlign))
      return {};

    Value res =
        MemoryAllocOp::create(builder, val->loc, memBlock, memSize, memAlign);
    materializedValues[val] = res;
    return res;
  }
1193
  /// Non-randomized sequences cannot be materialized directly yet; emit an
  /// error and fail.
  Value visit(SequenceStorage *val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    emitError() << "materializing a non-randomized sequence not supported yet";
    return Value();
  }

  Value visit(RandomizedSequenceStorage *val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    // To know which values we have to pass by argument (and not just pass all
    // that might be used eagerly), we have to elaborate the sequence family if
    // not already done so.
    // We need to get back the sequence to reference, and the list of elaborated
    // values to pass as arguments.
    SmallVector<ElaboratorValue> elabArgs;
    // NOTE: we wouldn't need to elaborate the sequence if it doesn't contain
    // randomness to be elaborated.
    SequenceOp seqOp = elaborateSequence(val, elabArgs);
    if (!seqOp)
      return {};

    // Materialize all the values we need to pass as arguments and collect their
    // types.
    SmallVector<Value> args;
    SmallVector<Type> argTypes;
    for (auto arg : elabArgs) {
      Value materialized = materialize(arg, val->loc, emitError);
      if (!materialized)
        return {};

      args.push_back(materialized);
      argTypes.push_back(materialized.getType());
    }

    Value res = GetSequenceOp::create(
        builder, val->loc, SequenceType::get(builder.getContext(), argTypes),
        seqOp.getSymName());

    // Only materialize a substitute_sequence op when we have arguments to
    // substitute since this op does not support 0 arguments.
    if (!args.empty())
      res = SubstituteSequenceOp::create(builder, val->loc, res, args);

    res = RandomizeSequenceOp::create(builder, val->loc, res);

    materializedValues[val] = res;
    return res;
  }
1241
  /// Materialize all inner sequences and interleave them. A single-element
  /// interleave is just the underlying sequence and is returned directly
  /// (without caching under `val`).
  Value visit(InterleavedSequenceStorage *val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    SmallVector<Value> sequences;
    for (auto seqVal : val->sequences) {
      Value materialized = materialize(seqVal, loc, emitError);
      if (!materialized)
        return {};

      sequences.push_back(materialized);
    }

    if (sequences.size() == 1)
      return sequences[0];

    Value res =
        InterleaveSequencesOp::create(builder, loc, sequences, val->batchSize);
    materializedValues[val] = res;
    return res;
  }
1261
  /// Virtual registers are identity values; emit a fresh VirtualRegisterOp.
  Value visit(VirtualRegisterStorage *val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    Value res = VirtualRegisterOp::create(builder, val->loc, val->allowedRegs);
    materializedValues[val] = res;
    return res;
  }

  /// Unique labels are identity values; materialize the name prefix first,
  /// then declare the label.
  Value visit(UniqueLabelStorage *val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    auto materialized = materialize(val->name, val->loc, emitError);
    if (!materialized)
      return {};
    Value res = LabelUniqueDeclOp::create(builder, val->loc, materialized);
    materializedValues[val] = res;
    return res;
  }
1278
1279 Value visit(TupleStorage *val, Location loc,
1280 function_ref<InFlightDiagnostic()> emitError) {
1281 SmallVector<Value> materialized;
1282 materialized.reserve(val->values.size());
1283 for (auto v : val->values)
1284 materialized.push_back(materialize(v, loc, emitError));
1285 Value res = TupleCreateOp::create(builder, loc, materialized);
1286 materializedValues[val] = res;
1287 return res;
1288 }
1289
  /// A value that is one particular result of a symbolic computation with
  /// identity: materialize the underlying op first, then pick result `idx`.
  Value visit(SymbolicComputationWithIdentityValue *val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    auto *noConstStorage =
        const_cast<SymbolicComputationWithIdentityStorage *>(val->storage);
    auto res0 = materialize(noConstStorage, loc, emitError);
    if (!res0)
      return {};

    // `res0` is result 0 of the recreated op; fetch the result we represent.
    auto *op = res0.getDefiningOp();
    auto res = op->getResults()[val->idx];
    materializedValues[val] = res;
    return res;
  }

  /// Recreate a symbolic computation op (with identity) from its recorded
  /// OperationState after materializing all operands.
  Value visit(SymbolicComputationWithIdentityStorage *val, Location loc,
              function_ref<InFlightDiagnostic()> emitError) {
    SmallVector<Value> operands;
    for (auto operand : val->operands) {
      auto materialized = materialize(operand, val->loc, emitError);
      if (!materialized)
        return {};

      operands.push_back(materialized);
    }

    OperationState state(val->loc, val->name);
    state.addTypes(val->resultTypes);
    state.attributes = val->attributes;
    state.propertiesAttr = val->properties;
    state.addOperands(operands);
    auto *op = builder.create(state);

    // Cache and return result 0; other results are accessed through
    // SymbolicComputationWithIdentityValue above.
    materializedValues[val] = op->getResult(0);
    return op->getResult(0);
  }
1325
1326 Value visit(SymbolicComputationStorage *val, Location loc,
1327 function_ref<InFlightDiagnostic()> emitError) {
1328 SmallVector<Value> operands;
1329 for (auto operand : val->operands) {
1330 auto materialized = materialize(operand, loc, emitError);
1331 if (!materialized)
1332 return {};
1333
1334 operands.push_back(materialized);
1335 }
1336
1337 OperationState state(loc, val->name);
1338 state.addTypes(val->resultTypes);
1339 state.attributes = val->attributes;
1340 state.propertiesAttr = val->properties;
1341 state.addOperands(operands);
1342 auto *op = builder.create(state);
1343
1344 for (auto res : op->getResults())
1345 materializedValues[val] = res;
1346
1347 return op->getResult(0);
1348 }
1349
private:
  /// Cache values we have already materialized to reuse them later. We start
  /// with an insertion point at the start of the block and cache the (updated)
  /// insertion point such that future materializations can also reuse previous
  /// materializations without running into dominance issues (or requiring
  /// additional checks to avoid them).
  DenseMap<ElaboratorValue, Value> materializedValues;

  /// Cache the builder to continue insertions at their current insertion point
  /// for the reason stated above.
  OpBuilder builder;

  /// Ops marked for deletion; erased (in reverse order) in `finalize`.
  SmallVector<Operation *> toDelete;

  TestState &testState;
  SharedState &sharedState;

  /// Keep track of the block arguments we had to add to this materializer's
  /// block for identity values and also remember which elaborator values are
  /// expected to be passed as arguments from outside.
  SmallVector<ElaboratorValue> &blockArgs;
  SmallVector<Type> blockArgTypes;

  /// Identity values in this set are materialized by this materializer,
  /// otherwise they are added as block arguments and the block that wants to
  /// embed this sequence is expected to provide a value for it.
  DenseSet<IdentityValue *> identityValueRoot;

  /// Helper to convert ElaboratorValues to attributes.
  ElaboratorValueToAttributeConverter attrConverter;
};
1381
1382//===----------------------------------------------------------------------===//
1383// Elaboration Visitor
1384//===----------------------------------------------------------------------===//
1385
/// Used to signal to the elaboration driver whether the operation should be
/// erased once elaborated (`Delete`) or kept in the IR (`Keep`).
enum class DeletionKind { Keep, Delete };
1389
1390/// Interprets the IR to perform and lower the represented randomizations.
1391class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
1392public:
1394 using RTGBase::visitOp;
1395
1396 Elaborator(SharedState &sharedState, TestState &testState,
1397 Materializer &materializer,
1398 ContextResourceAttrInterface currentContext = {})
1399 : sharedState(sharedState), testState(testState),
1400 materializer(materializer), currentContext(currentContext),
1401 attrConverter(sharedState.internalizer),
1402 elabValConverter(sharedState.ctxt) {}
1403
  /// Fetch the elaborated value of `val` and unwrap it as `ValueTy`;
  /// throws/asserts via std::get and map::at if absent or of another type.
  template <typename ValueTy>
  inline ValueTy get(Value val) const {
    return std::get<ValueTy>(state.at(val));
  }

  /// Operations we don't handle specially are delegated to `visitOpGeneric`
  /// (defined elsewhere in this file).
  FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
    return visitOpGeneric(op);
  }

  FailureOr<DeletionKind> visitExternalOp(Operation *op) {
    return visitOpGeneric(op);
  }
1417
1418 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
1419 SmallVector<ElaboratorValue> replacements;
1420 state[op.getResult()] =
1421 sharedState.internalizer.internalize<SequenceStorage>(
1422 op.getSequenceAttr().getAttr(), std::move(replacements));
1423 return DeletionKind::Delete;
1424 }
1425
  /// Append the elaborated replacement values to the sequence's argument list.
  /// Symbolic sequence values fall back to generic handling.
  FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
    if (isSymbolic(state.at(op.getSequence())))
      return visitOpGeneric(op);

    auto *seq = get<SequenceStorage *>(op.getSequence());

    SmallVector<ElaboratorValue> replacements(seq->args);
    for (auto replacement : op.getReplacements())
      replacements.push_back(state.at(replacement));

    state[op.getResult()] =
        sharedState.internalizer.internalize<SequenceStorage>(
            seq->familyName, std::move(replacements));

    return DeletionKind::Delete;
  }
1442
  /// Create a fresh randomized sequence (an identity value rooted at this
  /// materializer) for the current context and wrap it as a single-element
  /// interleaved sequence.
  FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
    auto *seq = get<SequenceStorage *>(op.getSequence());
    auto *randomizedSeq =
        sharedState.internalizer.create<RandomizedSequenceStorage>(
            currentContext, seq, op.getLoc());
    materializer.registerIdentityValue(randomizedSeq);
    state[op.getResult()] =
        sharedState.internalizer.internalize<InterleavedSequenceStorage>(
            randomizedSeq);
    return DeletionKind::Delete;
  }
1454
1455 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
1456 SmallVector<ElaboratorValue> sequences;
1457 for (auto seq : op.getSequences())
1458 sequences.push_back(state.at(seq));
1459
1460 state[op.getResult()] =
1461 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1462 std::move(sequences), op.getBatchSize());
1463 return DeletionKind::Delete;
1464 }
1465
  /// Verify that every randomized sequence reachable from `value` was
  /// randomized for the context we are currently elaborating under; emits an
  /// error on `op` otherwise. Recurses through interleaved sequences.
  // NOLINTNEXTLINE(misc-no-recursion)
  LogicalResult isValidContext(ElaboratorValue value, Operation *op) const {
    if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
      auto *seq = std::get<RandomizedSequenceStorage *>(value);
      if (seq->context != currentContext) {
        auto err = op->emitError("attempting to place sequence derived from ")
                   << seq->sequence->familyName.getValue() << " under context "
                   << currentContext
                   << ", but it was previously randomized for context ";
        if (seq->context)
          err << seq->context;
        else
          err << "'default'";
        return err;
      }
      return success();
    }

    // Otherwise this must be an interleaved sequence; check all parts.
    auto *interVal = std::get<InterleavedSequenceStorage *>(value);
    for (auto val : interVal->sequences)
      if (failed(isValidContext(val, op)))
        return failure();
    return success();
  }

  /// Embedded sequences stay in the IR (Keep) but must match the current
  /// context.
  FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
    auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
    if (failed(isValidContext(seqVal, op)))
      return failure();

    return DeletionKind::Keep;
  }
1498
1499 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1500 SetVector<ElaboratorValue> set;
1501 for (auto val : op.getElements())
1502 set.insert(state.at(val));
1503
1504 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1505 std::move(set), op.getSet().getType());
1506 return DeletionKind::Delete;
1507 }
1508
1509 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
1510 auto set = get<SetStorage *>(op.getSet())->set;
1511
1512 if (set.empty())
1513 return op->emitError("cannot select from an empty set");
1514
1515 size_t selected = testState.rng.getUniformlyInRange(0, set.size() - 1);
1516 state[op.getResult()] = set[selected];
1517 return DeletionKind::Delete;
1518 }
1519
  /// Set difference: original minus the elements of diff.
  FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
    auto original = get<SetStorage *>(op.getOriginal())->set;
    auto diff = get<SetStorage *>(op.getDiff())->set;

    SetVector<ElaboratorValue> result(original);
    result.set_subtract(diff);

    state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
        std::move(result), op.getResult().getType());
    return DeletionKind::Delete;
  }

  /// Union of all operand sets.
  FailureOr<DeletionKind> visitOp(SetUnionOp op) {
    SetVector<ElaboratorValue> result;
    for (auto set : op.getSets())
      result.set_union(get<SetStorage *>(set)->set);

    state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
        std::move(result), op.getType());
    return DeletionKind::Delete;
  }

  /// Number of (unique) elements in the set.
  FailureOr<DeletionKind> visitOp(SetSizeOp op) {
    auto size = get<SetStorage *>(op.getSet())->set.size();
    state[op.getResult()] = size;
    return DeletionKind::Delete;
  }
1547
  // {a0,a1} x {b0,b1} x {c0,c1} -> {(a0), (a1)} -> {(a0,b0), (a0,b1), (a1,b0),
  // (a1,b1)} -> {(a0,b0,c0), (a0,b0,c1), (a0,b1,c0), (a0,b1,c1), (a1,b0,c0),
  // (a1,b0,c1), (a1,b1,c0), (a1,b1,c1)}
  FailureOr<DeletionKind> visitOp(SetCartesianProductOp op) {
    SetVector<ElaboratorValue> result;
    SmallVector<SmallVector<ElaboratorValue>> tuples;
    tuples.push_back({});

    for (auto input : op.getInputs()) {
      auto &set = get<SetStorage *>(input)->set;
      // Any empty input makes the whole product empty.
      if (set.empty()) {
        SetVector<ElaboratorValue> empty;
        state[op.getResult()] =
            sharedState.internalizer.internalize<SetStorage>(std::move(empty),
                                                             op.getType());
        return DeletionKind::Delete;
      }

      // Extend every existing partial tuple with each element of this set.
      // The last element is appended to the original tuple in place; all
      // other elements get a fresh copy appended to `tuples`.
      for (unsigned i = 0, e = tuples.size(); i < e; ++i) {
        for (auto setEl : set.getArrayRef().drop_back()) {
          tuples.push_back(tuples[i]);
          tuples.back().push_back(setEl);
        }
        tuples[i].push_back(set.back());
      }
    }

    for (auto &tup : tuples)
      result.insert(
          sharedState.internalizer.internalize<TupleStorage>(std::move(tup)));

    state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
        std::move(result), op.getType());
    return DeletionKind::Delete;
  }
1583
1584 FailureOr<DeletionKind> visitOp(SetConvertToBagOp op) {
1585 auto set = get<SetStorage *>(op.getInput())->set;
1586 MapVector<ElaboratorValue, uint64_t> bag;
1587 for (auto val : set)
1588 bag.insert({val, 1});
1589 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1590 std::move(bag), op.getType());
1591 return DeletionKind::Delete;
1592 }
1593
  /// Build a bag from elements and their multiples; duplicate elements have
  /// their multiples summed.
  FailureOr<DeletionKind> visitOp(BagCreateOp op) {
    MapVector<ElaboratorValue, uint64_t> bag;
    for (auto [val, multiple] :
         llvm::zip(op.getElements(), op.getMultiples())) {
      // If the multiple is not stored as an AttributeValue, the elaboration
      // must have already failed earlier (since we don't have
      // unevaluated/opaque values).
      bag[state.at(val)] += get<size_t>(multiple);
    }

    state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
        std::move(bag), op.getType());
    return DeletionKind::Delete;
  }

  /// Weighted random selection: build a prefix sum over the weights, draw a
  /// uniform number below the total, and pick via upper_bound.
  FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
    auto bag = get<BagStorage *>(op.getBag())->bag;

    if (bag.empty())
      return op->emitError("cannot select from an empty bag");

    // NOTE(review): weights are stored as uint64_t but accumulated in a
    // uint32_t here — presumably total weights stay small; confirm no
    // overflow is possible for realistic inputs.
    SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
    prefixSum.reserve(bag.size());
    uint32_t accumulator = 0;
    for (auto [val, weight] : bag) {
      accumulator += weight;
      prefixSum.push_back({val, accumulator});
    }

    auto idx = testState.rng.getUniformlyInRange(0, accumulator - 1);
    auto *iter = llvm::upper_bound(
        prefixSum, idx,
        [](uint32_t a, const std::pair<ElaboratorValue, uint32_t> &b) {
          return a < b.second;
        });

    state[op.getResult()] = iter->first;
    return DeletionKind::Delete;
  }
1633
  /// Bag difference. With `inf` set, any element present in `diff` is removed
  /// entirely; otherwise multiplicities are subtracted (saturating at 0).
  FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
    auto original = get<BagStorage *>(op.getOriginal())->bag;
    auto diff = get<BagStorage *>(op.getDiff())->bag;

    MapVector<ElaboratorValue, uint64_t> result;
    for (const auto &el : original) {
      // Elements not in `diff` are kept unchanged.
      if (!diff.contains(el.first)) {
        result.insert(el);
        continue;
      }

      if (op.getInf())
        continue;

      // Subtract multiplicities; drop the element if nothing remains.
      auto toDiff = diff.lookup(el.first);
      if (el.second <= toDiff)
        continue;

      result.insert({el.first, el.second - toDiff});
    }

    state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
        std::move(result), op.getType());
    return DeletionKind::Delete;
  }
1659
  /// Bag union: multiplicities of equal elements are summed.
  FailureOr<DeletionKind> visitOp(BagUnionOp op) {
    MapVector<ElaboratorValue, uint64_t> result;
    for (auto bag : op.getBags()) {
      auto val = get<BagStorage *>(bag)->bag;
      for (auto [el, multiple] : val)
        result[el] += multiple;
    }

    state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
        std::move(result), op.getType());
    return DeletionKind::Delete;
  }

  /// Number of distinct elements (ignoring multiplicities).
  FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
    auto size = get<BagStorage *>(op.getBag())->bag.size();
    state[op.getResult()] = size;
    return DeletionKind::Delete;
  }

  /// Drop multiplicities, keeping one instance of each element.
  FailureOr<DeletionKind> visitOp(BagConvertToSetOp op) {
    auto bag = get<BagStorage *>(op.getInput())->bag;
    SetVector<ElaboratorValue> set;
    for (auto [k, v] : bag)
      set.insert(k);
    state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
        std::move(set), op.getType());
    return DeletionKind::Delete;
  }
1688
1689 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1690 auto *val = sharedState.internalizer.create<VirtualRegisterStorage>(
1691 op.getAllowedRegsAttr(), op.getType(), op.getLoc());
1692 state[op.getResult()] = val;
1693 materializer.registerIdentityValue(val);
1694 return DeletionKind::Delete;
1695 }
1696
  /// Build an ArrayStorage from the elaborated element values (order
  /// preserved, duplicates allowed).
  FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
    SmallVector<ElaboratorValue> array;
    array.reserve(op.getElements().size());
    for (auto val : op.getElements())
      array.emplace_back(state.at(val));

    state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
        op.getResult().getType(), std::move(array));
    return DeletionKind::Delete;
  }

  /// Bounds-checked element access.
  FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
    auto array = get<ArrayStorage *>(op.getArray())->array;
    size_t idx = get<size_t>(op.getIndex());

    if (array.size() <= idx)
      return op->emitError("invalid to access index ")
             << idx << " of an array with " << array.size() << " elements";

    state[op.getResult()] = array[idx];
    return DeletionKind::Delete;
  }
1719
  /// Functional update: produce a new array with `value` at `idx`. Falls back
  /// to generic handling when array or index is symbolic.
  FailureOr<DeletionKind> visitOp(ArrayInjectOp op) {
    auto arrayOpaque = state.at(op.getArray());
    auto idxOpaque = state.at(op.getIndex());
    if (isSymbolic(arrayOpaque) || isSymbolic(idxOpaque))
      return visitOpGeneric(op);

    // Work on a copy so the original interned array stays untouched.
    auto array = std::get<ArrayStorage *>(arrayOpaque)->array;
    size_t idx = std::get<size_t>(idxOpaque);

    if (array.size() <= idx)
      return op->emitError("invalid to access index ")
             << idx << " of an array with " << array.size() << " elements";

    array[idx] = state.at(op.getValue());
    state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
        op.getResult().getType(), std::move(array));
    return DeletionKind::Delete;
  }

  /// Functional append: copy the array and push the new element.
  FailureOr<DeletionKind> visitOp(ArrayAppendOp op) {
    auto array = std::get<ArrayStorage *>(state.at(op.getArray()))->array;
    array.push_back(state.at(op.getElement()));
    state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
        op.getResult().getType(), std::move(array));
    return DeletionKind::Delete;
  }

  /// Number of elements in the array.
  FailureOr<DeletionKind> visitOp(ArraySizeOp op) {
    auto array = get<ArrayStorage *>(op.getArray())->array;
    state[op.getResult()] = array.size();
    return DeletionKind::Delete;
  }
1752
  /// Unique labels get a fresh (non-interned) storage with identity, rooted
  /// at the current materializer.
  FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
    auto *val = sharedState.internalizer.create<UniqueLabelStorage>(
        state.at(op.getNamePrefix()), op.getLoc());
    state[op.getLabel()] = val;
    materializer.registerIdentityValue(val);
    return DeletionKind::Delete;
  }
1760
  /// Elaborate the body under a nested RNG scope (explicitly seeded if the op
  /// carries a seed), then restore the outer scope and forward the yielded
  /// values to the op's results.
  FailureOr<DeletionKind> visitOp(RandomScopeOp op) {
    auto getNestedRng = [&]() -> RngScope {
      if (op.getSeed())
        return RngScope(op.getSeed()->getZExtValue());
      return testState.rng.getNested();
    };

    RngScope nestedRng = getNestedRng();
    std::swap(testState.rng, nestedRng);

    // Elaborate the body region. The region has no arguments and yields values
    // that become the results of this operation.
    // We reuse the same elaborator for the nested region because we need access
    // to the elaborated values outside the nested region (since it is not
    // isolated from above) and we want to delete the random_scope operation
    // entirely, thus don't need a new materializer instance.
    SmallVector<ElaboratorValue> yieldedVals;
    if (failed(elaborate(op.getBodyRegion(), {}, /*keepTerminator=*/false,
                         yieldedVals)))
      return failure();

    // Restore the previous RNG scope.
    std::swap(testState.rng, nestedRng);

    // Map the results of the random_scope to the yielded values.
    for (auto [res, out] : llvm::zip(op.getResults(), yieldedVals))
      state[res] = out;

    return DeletionKind::Delete;
  }
1791
1792 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1793 size_t lower = get<size_t>(op.getLowerBound());
1794 size_t upper = get<size_t>(op.getUpperBound());
1795 if (lower > upper)
1796 return op->emitError("cannot select a number from an empty range");
1797
1798 state[op.getResult()] =
1799 size_t(testState.rng.getUniformlyInRange(lower, upper));
1800 return DeletionKind::Delete;
1801 }
1802
1803 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1804 size_t input = get<size_t>(op.getInput());
1805 auto width = op.getType().getWidth();
1806 auto emitError = [&]() { return op->emitError(); };
1807 if (input > APInt::getAllOnes(width).getZExtValue())
1808 return emitError() << "cannot represent " << input << " with " << width
1809 << " bits";
1810
1811 state[op.getResult()] =
1812 ImmediateAttr::get(op.getContext(), APInt(width, input));
1813 return DeletionKind::Delete;
1814 }
1815
  /// Handle 'rtg.on_context': embed the given sequence under the requested
  /// context. If the current and requested contexts differ, wrap the sequence
  /// in a user-registered context-switch sequence; fail when no suitable
  /// switch was registered via 'rtg.context_switch'.
  FailureOr<DeletionKind> visitOp(OnContextOp op) {
    ContextResourceAttrInterface from = currentContext,
                                 to = cast<ContextResourceAttrInterface>(
                                     get<TypedAttr>(op.getContext()));
    // Not elaborating under any context yet: treat the current context as the
    // default context of the requested context's type.
    if (!currentContext)
      from = DefaultContextAttr::get(op->getContext(), to.getType());

    auto emitError = [&]() {
      auto diag = op.emitError();
      diag.attachNote(op.getLoc())
          << "while materializing value for context switching for " << op;
      return diag;
    };

    // Already on the requested context: no switch necessary, randomize and
    // embed the sequence directly.
    if (from == to) {
      Value seqVal = materializer.materialize(
          get<SequenceStorage *>(op.getSequence()), op.getLoc(), emitError);
      if (!seqVal)
        return failure();

      Value randSeqVal =
          materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
      materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
      return DeletionKind::Delete;
    }

    // Switch to the desired context.
    // First, check if a context switch is registered that has the concrete
    // context as source and target.
    auto *iter = testState.contextSwitches.find({from, to});

    // Try with 'any' context as target and the concrete context as source.
    if (iter == testState.contextSwitches.end())
      iter = testState.contextSwitches.find(
          {from, AnyContextAttr::get(op->getContext(), to.getType())});

    // Try with 'any' context as source and the concrete context as target.
    if (iter == testState.contextSwitches.end())
      iter = testState.contextSwitches.find(
          {AnyContextAttr::get(op->getContext(), from.getType()), to});

    // Try with 'any' context for both the source and the target.
    if (iter == testState.contextSwitches.end())
      iter = testState.contextSwitches.find(
          {AnyContextAttr::get(op->getContext(), from.getType()),
           AnyContextAttr::get(op->getContext(), to.getType())});

    // Otherwise, fail with an error because we couldn't find a user
    // specification on how to switch between the requested contexts.
    // NOTE: we could think about supporting context switching via intermediate
    // context, i.e., treat it as a transitive relation.
    if (iter == testState.contextSwitches.end())
      return op->emitError("no context transition registered to switch from ")
             << from << " to " << to;

    // Build a sequence that invokes the registered switching sequence with
    // (from, to, inner sequence) as arguments, randomize it under the target
    // context, and embed it at the current insertion point.
    auto familyName = iter->second->familyName;
    SmallVector<ElaboratorValue> args{from, to,
                                      get<SequenceStorage *>(op.getSequence())};
    auto *seq = sharedState.internalizer.internalize<SequenceStorage>(
        familyName, std::move(args));
    auto *randSeq = sharedState.internalizer.create<RandomizedSequenceStorage>(
        to, seq, op.getLoc());
    materializer.registerIdentityValue(randSeq);
    Value seqVal = materializer.materialize(randSeq, op.getLoc(), emitError);
    if (!seqVal)
      return failure();

    materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
    return DeletionKind::Delete;
  }
1886
1887 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1888 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1889 get<SequenceStorage *>(op.getSequence());
1890 return DeletionKind::Delete;
1891 }
1892
1893 FailureOr<DeletionKind> visitOp(MemoryBlockDeclareOp op) {
1894 auto *val = sharedState.internalizer.create<MemoryBlockStorage>(
1895 op.getBaseAddress(), op.getEndAddress(), op.getType(), op.getLoc());
1896 state[op.getResult()] = val;
1897 materializer.registerIdentityValue(val);
1898 return DeletionKind::Delete;
1899 }
1900
1901 FailureOr<DeletionKind> visitOp(MemoryAllocOp op) {
1902 size_t size = get<size_t>(op.getSize());
1903 size_t alignment = get<size_t>(op.getAlignment());
1904 auto *memBlock = get<MemoryBlockStorage *>(op.getMemoryBlock());
1905 auto *val = sharedState.internalizer.create<MemoryStorage>(
1906 memBlock, size, alignment, op.getLoc());
1907 state[op.getResult()] = val;
1908 materializer.registerIdentityValue(val);
1909 return DeletionKind::Delete;
1910 }
1911
1912 FailureOr<DeletionKind> visitOp(MemorySizeOp op) {
1913 auto *memory = get<MemoryStorage *>(op.getMemory());
1914 state[op.getResult()] = memory->size;
1915 return DeletionKind::Delete;
1916 }
1917
1918 FailureOr<DeletionKind> visitOp(TupleCreateOp op) {
1919 SmallVector<ElaboratorValue> values;
1920 values.reserve(op.getElements().size());
1921 for (auto el : op.getElements())
1922 values.push_back(state.at(el));
1923
1924 state[op.getResult()] =
1925 sharedState.internalizer.internalize<TupleStorage>(std::move(values));
1926 return DeletionKind::Delete;
1927 }
1928
1929 FailureOr<DeletionKind> visitOp(TupleExtractOp op) {
1930 auto *tuple = get<TupleStorage *>(op.getTuple());
1931 state[op.getResult()] = tuple->values[op.getIndex().getZExtValue()];
1932 return DeletionKind::Delete;
1933 }
1934
1935 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1936 bool cond = get<bool>(op.getCondition());
1937 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1938 if (toElaborate.empty())
1939 return DeletionKind::Delete;
1940
1941 // Just reuse this elaborator for the nested region because we need access
1942 // to the elaborated values outside the nested region (since it is not
1943 // isolated from above) and we want to materialize the region inline, thus
1944 // don't need a new materializer instance.
1945 SmallVector<ElaboratorValue> yieldedVals;
1946 if (failed(
1947 elaborate(toElaborate, {}, /*keepTerminator=*/false, yieldedVals)))
1948 return failure();
1949
1950 // Map the results of the 'scf.if' to the yielded values.
1951 for (auto [res, out] : llvm::zip(op.getResults(), yieldedVals))
1952 state[res] = out;
1953
1954 return DeletionKind::Delete;
1955 }
1956
1957 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1958 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1959 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1960 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1961 return op->emitOpError("can only elaborate index type iterator");
1962
1963 auto lowerBound = get<size_t>(op.getLowerBound());
1964 auto step = get<size_t>(op.getStep());
1965 auto upperBound = get<size_t>(op.getUpperBound());
1966
1967 // Prepare for first iteration by assigning the nested regions block
1968 // arguments. We can just reuse this elaborator because we need access to
1969 // values elaborated in the parent region anyway and materialize everything
1970 // inline (i.e., don't need a new materializer).
1971 state[op.getInductionVar()] = lowerBound;
1972 for (auto [iterArg, initArg] :
1973 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1974 state[iterArg] = state.at(initArg);
1975
1976 // This loop performs the actual 'scf.for' loop iterations.
1977 SmallVector<ElaboratorValue> yieldedVals;
1978 for (size_t i = lowerBound; i < upperBound; i += step) {
1979 yieldedVals.clear();
1980 if (failed(elaborate(op.getBodyRegion(), {}, /*keepTerminator=*/false,
1981 yieldedVals)))
1982 return failure();
1983
1984 // Prepare for the next iteration by updating the mapping of the nested
1985 // regions block arguments
1986 state[op.getInductionVar()] = i + step;
1987 for (auto [iterArg, prevIterArg] :
1988 llvm::zip(op.getRegionIterArgs(), yieldedVals))
1989 state[iterArg] = prevIterArg;
1990 }
1991
1992 // Transfer the previously yielded values to the for loop result values.
1993 for (auto [res, iterArg] :
1994 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1995 state[res] = state.at(iterArg);
1996
1997 return DeletionKind::Delete;
1998 }
1999
2000 FailureOr<DeletionKind> visitOp(arith::AddIOp op) {
2001 if (!isa<IndexType>(op.getType()))
2002 return visitOpGeneric(op);
2003
2004 size_t lhs = get<size_t>(op.getLhs());
2005 size_t rhs = get<size_t>(op.getRhs());
2006 state[op.getResult()] = lhs + rhs;
2007 return DeletionKind::Delete;
2008 }
2009
2010 FailureOr<DeletionKind> visitOp(arith::AndIOp op) {
2011 if (!op.getType().isSignlessInteger(1))
2012 return visitOpGeneric(op);
2013
2014 bool lhs = get<bool>(op.getLhs());
2015 bool rhs = get<bool>(op.getRhs());
2016 state[op.getResult()] = lhs && rhs;
2017 return DeletionKind::Delete;
2018 }
2019
2020 FailureOr<DeletionKind> visitOp(arith::XOrIOp op) {
2021 if (!op.getType().isSignlessInteger(1))
2022 return visitOpGeneric(op);
2023
2024 bool lhs = get<bool>(op.getLhs());
2025 bool rhs = get<bool>(op.getRhs());
2026 state[op.getResult()] = lhs != rhs;
2027 return DeletionKind::Delete;
2028 }
2029
2030 FailureOr<DeletionKind> visitOp(arith::OrIOp op) {
2031 if (!op.getType().isSignlessInteger(1))
2032 return visitOpGeneric(op);
2033
2034 bool lhs = get<bool>(op.getLhs());
2035 bool rhs = get<bool>(op.getRhs());
2036 state[op.getResult()] = lhs || rhs;
2037 return DeletionKind::Delete;
2038 }
2039
2040 FailureOr<DeletionKind> visitOp(arith::SelectOp op) {
2041 auto condOpaque = state.at(op.getCondition());
2042 if (isSymbolic(condOpaque))
2043 return visitOpGeneric(op);
2044
2045 bool cond = std::get<bool>(condOpaque);
2046 auto trueVal = state.at(op.getTrueValue());
2047 auto falseVal = state.at(op.getFalseValue());
2048 state[op.getResult()] = cond ? trueVal : falseVal;
2049 return DeletionKind::Delete;
2050 }
2051
2052 FailureOr<DeletionKind> visitOp(index::AddOp op) {
2053 size_t lhs = get<size_t>(op.getLhs());
2054 size_t rhs = get<size_t>(op.getRhs());
2055 state[op.getResult()] = lhs + rhs;
2056 return DeletionKind::Delete;
2057 }
2058
2059 FailureOr<DeletionKind> visitOp(index::SubOp op) {
2060 size_t lhs = get<size_t>(op.getLhs());
2061 size_t rhs = get<size_t>(op.getRhs());
2062 state[op.getResult()] = lhs - rhs;
2063 return DeletionKind::Delete;
2064 }
2065
2066 FailureOr<DeletionKind> visitOp(index::MulOp op) {
2067 size_t lhs = get<size_t>(op.getLhs());
2068 size_t rhs = get<size_t>(op.getRhs());
2069 state[op.getResult()] = lhs * rhs;
2070 return DeletionKind::Delete;
2071 }
2072
2073 FailureOr<DeletionKind> visitOp(index::DivUOp op) {
2074 size_t lhs = get<size_t>(op.getLhs());
2075 size_t rhs = get<size_t>(op.getRhs());
2076
2077 if (rhs == 0)
2078 return op->emitOpError("attempted division by zero");
2079
2080 state[op.getResult()] = lhs / rhs;
2081 return DeletionKind::Delete;
2082 }
2083
2084 FailureOr<DeletionKind> visitOp(index::CeilDivUOp op) {
2085 size_t lhs = get<size_t>(op.getLhs());
2086 size_t rhs = get<size_t>(op.getRhs());
2087
2088 if (rhs == 0)
2089 return op->emitOpError("attempted division by zero");
2090
2091 if (lhs == 0)
2092 state[op.getResult()] = (lhs + rhs - 1) / rhs;
2093 else
2094 state[op.getResult()] = 1 + ((lhs - 1) / rhs);
2095
2096 return DeletionKind::Delete;
2097 }
2098
2099 FailureOr<DeletionKind> visitOp(index::RemUOp op) {
2100 size_t lhs = get<size_t>(op.getLhs());
2101 size_t rhs = get<size_t>(op.getRhs());
2102
2103 if (rhs == 0)
2104 return op->emitOpError("attempted division by zero");
2105
2106 state[op.getResult()] = lhs % rhs;
2107 return DeletionKind::Delete;
2108 }
2109
2110 FailureOr<DeletionKind> visitOp(index::AndOp op) {
2111 size_t lhs = get<size_t>(op.getLhs());
2112 size_t rhs = get<size_t>(op.getRhs());
2113 state[op.getResult()] = lhs & rhs;
2114 return DeletionKind::Delete;
2115 }
2116
2117 FailureOr<DeletionKind> visitOp(index::OrOp op) {
2118 size_t lhs = get<size_t>(op.getLhs());
2119 size_t rhs = get<size_t>(op.getRhs());
2120 state[op.getResult()] = lhs | rhs;
2121 return DeletionKind::Delete;
2122 }
2123
2124 FailureOr<DeletionKind> visitOp(index::XOrOp op) {
2125 size_t lhs = get<size_t>(op.getLhs());
2126 size_t rhs = get<size_t>(op.getRhs());
2127 state[op.getResult()] = lhs ^ rhs;
2128 return DeletionKind::Delete;
2129 }
2130
2131 FailureOr<DeletionKind> visitOp(index::ShlOp op) {
2132 size_t lhs = get<size_t>(op.getLhs());
2133 size_t rhs = get<size_t>(op.getRhs());
2134 state[op.getResult()] = lhs << rhs;
2135 return DeletionKind::Delete;
2136 }
2137
2138 FailureOr<DeletionKind> visitOp(index::ShrUOp op) {
2139 size_t lhs = get<size_t>(op.getLhs());
2140 size_t rhs = get<size_t>(op.getRhs());
2141 state[op.getResult()] = lhs >> rhs;
2142 return DeletionKind::Delete;
2143 }
2144
2145 FailureOr<DeletionKind> visitOp(index::MaxUOp op) {
2146 size_t lhs = get<size_t>(op.getLhs());
2147 size_t rhs = get<size_t>(op.getRhs());
2148 state[op.getResult()] = std::max(lhs, rhs);
2149 return DeletionKind::Delete;
2150 }
2151
2152 FailureOr<DeletionKind> visitOp(index::MinUOp op) {
2153 size_t lhs = get<size_t>(op.getLhs());
2154 size_t rhs = get<size_t>(op.getRhs());
2155 state[op.getResult()] = std::min(lhs, rhs);
2156 return DeletionKind::Delete;
2157 }
2158
2159 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
2160 size_t lhs = get<size_t>(op.getLhs());
2161 size_t rhs = get<size_t>(op.getRhs());
2162 bool result;
2163 switch (op.getPred()) {
2164 case index::IndexCmpPredicate::EQ:
2165 result = lhs == rhs;
2166 break;
2167 case index::IndexCmpPredicate::NE:
2168 result = lhs != rhs;
2169 break;
2170 case index::IndexCmpPredicate::ULT:
2171 result = lhs < rhs;
2172 break;
2173 case index::IndexCmpPredicate::ULE:
2174 result = lhs <= rhs;
2175 break;
2176 case index::IndexCmpPredicate::UGT:
2177 result = lhs > rhs;
2178 break;
2179 case index::IndexCmpPredicate::UGE:
2180 result = lhs >= rhs;
2181 break;
2182 default:
2183 return op->emitOpError("elaboration not supported");
2184 }
2185 state[op.getResult()] = result;
2186 return DeletionKind::Delete;
2187 }
2188
2189 bool isSymbolic(ElaboratorValue val) {
2190 return std::holds_alternative<SymbolicComputationWithIdentityValue *>(
2191 val) ||
2192 std::holds_alternative<SymbolicComputationWithIdentityStorage *>(
2193 val) ||
2194 std::holds_alternative<SymbolicComputationStorage *>(val);
2195 }
2196
2197 bool isSymbolic(Operation *op) {
2198 return llvm::any_of(op->getOperands(), [&](auto operand) {
2199 auto val = state.at(operand);
2200 return isSymbolic(val);
2201 });
2202 }
2203
2204 /// If all operands are constants, try to fold the operation and register its
2205 /// result values with the folded results in the elaborator state.
2206 /// Returns 'false' if operation has to be handled symbolically.
2207 bool attemptConcreteCase(Operation *op) {
2208 if (op->getNumResults() == 0)
2209 return false;
2210
2211 SmallVector<Attribute> operands;
2212 for (auto operand : op->getOperands()) {
2213 auto evalValue = state[operand];
2214 auto attr = elabValConverter.convert(evalValue);
2215 operands.push_back(attr);
2216 }
2217
2218 SmallVector<OpFoldResult> results;
2219 if (failed(op->fold(operands, results)))
2220 return false;
2221
2222 if (results.size() != op->getNumResults())
2223 return false;
2224
2225 for (auto [res, val] : llvm::zip(results, op->getResults())) {
2226 auto attr = llvm::dyn_cast_or_null<TypedAttr>(res.dyn_cast<Attribute>());
2227 if (!attr)
2228 return false;
2229
2230 if (attr.getType() != val.getType())
2231 return false;
2232
2233 // Try to convert the attribute to an ElaboratorValue
2234 auto converted = attrConverter.convert(attr);
2235 if (succeeded(converted)) {
2236 state[val] = *converted;
2237 continue;
2238 }
2239 return false;
2240 }
2241
2242 return true;
2243 }
2244
  /// Fallback handler for operations without a dedicated 'visitOp' overload
  /// and for operations with symbolic operands: fold to a concrete value when
  /// possible, otherwise track the computation symbolically.
  FailureOr<DeletionKind> visitOpGeneric(Operation *op) {
    // Ops without results carry no value to track; keep them in the output.
    if (op->getNumResults() == 0)
      return DeletionKind::Keep;

    // All operands concrete and foldable: results were recorded, op can go.
    if (attemptConcreteCase(op))
      return DeletionKind::Delete;

    if (mlir::isMemoryEffectFree(op)) {
      if (op->getNumResults() != 1)
        return op->emitOpError(
            "symbolic elaboration of memory-effect-free operations with "
            "multiple results not supported");

      // Side-effect-free computations are interned so identical symbolic
      // computations share one storage.
      state[op->getResult(0)] =
          sharedState.internalizer.internalize<SymbolicComputationStorage>(
              state, op);
      return DeletionKind::Delete;
    }

    // We assume that reordering operations with only allocate effects is
    // allowed.
    // FIXME: this is not how the MLIR MemoryEffects interface intends it.
    // We should create our own interface/trait for that use-case.
    // Or modify the elaboration pass to keep track of the ordering of such
    // instructions and materialize all operations that are not already
    // materialized but have to happen before the current alloc operation to be
    // materialized.
    bool onlyAlloc = mlir::hasSingleEffect<mlir::MemoryEffects::Allocate>(op);
    onlyAlloc |= isa<ValidateOp>(op);

    // Effectful ops get identity storage ('create', not 'internalize') so
    // each occurrence remains distinct; the first result owns the identity.
    auto *validationVal =
        sharedState.internalizer.create<SymbolicComputationWithIdentityStorage>(
            state, op);
    materializer.registerIdentityValue(validationVal);
    state[op->getResult(0)] = validationVal;

    // Remaining results refer back to the first result's identity storage by
    // their result index.
    for (auto [i, res] : llvm::enumerate(op->getResults())) {
      if (i == 0)
        continue;
      auto *val =
          sharedState.internalizer.create<SymbolicComputationWithIdentityValue>(
              res.getType(), validationVal, i);
      state[res] = val;
      materializer.registerIdentityValue(val);
    }
    return onlyAlloc ? DeletionKind::Delete : DeletionKind::Keep;
  }
2292
  /// Operations whose 'visitOp' overloads handle symbolic operand values
  /// themselves and therefore must not be diverted to 'visitOpGeneric' by
  /// 'dispatchOpVisitor'.
  bool supportsSymbolicValuesNonGenerically(Operation *op) {
    return isa<SubstituteSequenceOp, ArrayCreateOp, ArrayInjectOp,
               TupleCreateOp, arith::SelectOp>(op);
  }
2297
  /// Dispatch 'op' to the matching 'visitOp' overload. Ops with symbolic
  /// operands are routed to the generic handler unless they implement their
  /// own symbolic handling.
  FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
    if (isSymbolic(op) && !supportsSymbolicValuesNonGenerically(op))
      return visitOpGeneric(op);

    return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
        .Case<
            // Arith ops
            arith::AddIOp, arith::XOrIOp, arith::AndIOp, arith::OrIOp,
            arith::SelectOp,
            // Index ops
            index::AddOp, index::SubOp, index::MulOp, index::DivUOp,
            index::CeilDivUOp, index::RemUOp, index::AndOp, index::OrOp,
            index::XOrOp, index::ShlOp, index::ShrUOp, index::MaxUOp,
            index::MinUOp, index::CmpOp,
            // SCF ops
            scf::IfOp, scf::ForOp>([&](auto op) { return visitOp(op); })
        // Everything else (including the RTG dialect ops) is handled by the
        // base visitor's dispatch.
        .Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); });
  }
2316
  /// Elaborate all operations in the single block of 'region', mapping the
  /// region's block arguments to 'regionArguments'. The elaborated terminator
  /// operand values are appended to 'terminatorOperands'; the terminator
  /// itself is erased unless 'keepTerminator' is set.
  // NOLINTNEXTLINE(misc-no-recursion)
  LogicalResult elaborate(Region &region,
                          ArrayRef<ElaboratorValue> regionArguments,
                          bool keepTerminator,
                          SmallVector<ElaboratorValue> &terminatorOperands) {
    if (region.getBlocks().size() > 1)
      return region.getParentOp()->emitOpError(
          "regions with more than one block are not supported");

    for (auto [arg, elabArg] :
         llvm::zip(region.getArguments(), regionArguments))
      state[arg] = elabArg;

    Block *block = &region.front();
    auto iter = keepTerminator ? *block : block->without_terminator();
    for (auto &op : iter) {
      auto result = dispatchOpVisitor(&op);
      if (failed(result))
        return failure();

      // Ops the visitor wants to keep are emitted into the output through the
      // materializer (which resolves their operands from 'state').
      if (*result == DeletionKind::Keep)
        if (failed(materializer.materialize(&op, state)))
          return failure();

      LLVM_DEBUG({
        llvm::dbgs() << "Elaborated " << op << " to\n[";

        llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) {
          if (state.contains(res))
            llvm::dbgs() << state.at(res);
          else
            llvm::dbgs() << "unknown";
        });

        llvm::dbgs() << "]\n\n";
      });
    }

    // Collect the elaborated terminator operands for the caller, then drop
    // the terminator itself when it isn't requested to be kept and the
    // materializer works in-place on this block.
    if (!block->empty() && block->back().hasTrait<OpTrait::IsTerminator>()) {
      auto *terminator = block->getTerminator();
      for (auto val : terminator->getOperands())
        terminatorOperands.push_back(state.at(val));

      if (!keepTerminator && materializer.isInPlace(terminator))
        terminator->erase();
    }

    return success();
  }
2366
2367private:
2368 // State to be shared between all elaborator instances.
2369 SharedState &sharedState;
2370
2371 // State to a specific RTG test and the sequences placed within it.
2372 TestState &testState;
2373
2374 // Allows us to materialize ElaboratorValues to the IR operations necessary to
2375 // obtain an SSA value representing that elaborated value.
2376 Materializer &materializer;
2377
2378 // A map from SSA values to a pointer of an interned elaborator value.
2379 DenseMap<Value, ElaboratorValue> state;
2380
2381 // The current context we are elaborating under.
2382 ContextResourceAttrInterface currentContext;
2383
2384 // Allows us to convert attributes to ElaboratorValues.
2385 AttributeToElaboratorValueConverter attrConverter;
2386
2387 // Allows us to convert ElaboratorValues to attributes.
2388 ElaboratorValueToAttributeConverter elabValConverter;
2389};
2390} // namespace
2391
/// Elaborate a randomized sequence: clone the sequence family it was derived
/// from under a fresh unique symbol name and fill the clone's body with the
/// elaborated operations. Returns a null op on elaboration failure.
SequenceOp
Materializer::elaborateSequence(const RandomizedSequenceStorage *seq,
                                SmallVector<ElaboratorValue> &elabArgs) {
  auto familyOp =
      sharedState.table.lookup<SequenceOp>(seq->sequence->familyName);
  // TODO: don't clone if this is the only remaining reference to this
  // sequence
  OpBuilder builder(familyOp);
  // Clone without regions and rebuild the body from scratch below; the clone
  // gets a fresh name derived from the family name.
  auto seqOp = builder.cloneWithoutRegions(familyOp);
  auto name = sharedState.names.newName(seq->sequence->familyName.getValue());
  seqOp.setSymName(name);
  seqOp.getBodyRegion().emplaceBlock();
  sharedState.table.insert(seqOp);
  assert(seqOp.getSymName() == name && "should not have been renamed");

  LLVM_DEBUG(llvm::dbgs() << "\n=== Elaborating sequence family @"
                          << familyOp.getSymName() << " into @"
                          << seqOp.getSymName() << " under context "
                          << seq->context << "\n\n");

  // Elaborate the family's body with the stored sequence arguments, emitting
  // into the clone's (empty) body under the sequence's context.
  Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()), testState,
                            sharedState, elabArgs);
  Elaborator elaborator(sharedState, testState, materializer, seq->context);
  SmallVector<ElaboratorValue> yieldedVals;
  if (failed(elaborator.elaborate(familyOp.getBodyRegion(), seq->sequence->args,
                                  /*keepTerminator=*/false, yieldedVals)))
    return {};

  // The clone's type reflects only the block arguments the materializer
  // actually created during elaboration.
  seqOp.setSequenceType(
      SequenceType::get(builder.getContext(), materializer.getBlockArgTypes()));
  materializer.finalize();

  return seqOp;
}
2426
2427//===----------------------------------------------------------------------===//
2428// Elaborator Pass
2429//===----------------------------------------------------------------------===//
2430
2431namespace {
/// Pass driver: matches tests against targets, then elaborates all targets
/// and matched tests in the module.
struct ElaborationPass
    : public rtg::impl::ElaborationPassBase<ElaborationPass> {
  using Base::Base;

  void runOnOperation() override;
  /// Clone every untargeted test once per target that provides all entries
  /// the test requires.
  void matchTestsAgainstTargets(SymbolTable &table);
  /// Elaborate each target and each test that names a target.
  LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
};
2440} // namespace
2441
2442void ElaborationPass::runOnOperation() {
2443 auto moduleOp = getOperation();
2444 SymbolTable table(moduleOp);
2445
2446 matchTestsAgainstTargets(table);
2447
2448 if (failed(elaborateModule(moduleOp, table)))
2449 return signalPassFailure();
2450}
2451
/// For every test without an explicit target, clone it once per target whose
/// entries form a superset of the test's requirements; the clone records its
/// target symbol. The original template is erased when it matched at least
/// one target (or unconditionally with the 'deleteUnmatchedTests' option).
void ElaborationPass::matchTestsAgainstTargets(SymbolTable &table) {
  auto moduleOp = getOperation();

  // early_inc_range: the loop erases tests while iterating over them.
  for (auto test : llvm::make_early_inc_range(moduleOp.getOps<TestOp>())) {
    // Tests that already name a target are left as-is.
    if (test.getTargetAttr())
      continue;

    bool matched = false;

    for (auto target : moduleOp.getOps<TargetOp>()) {
      // Check if the target type is a subtype of the test's target type
      // This means that for each entry in the test's target type, there must be
      // a corresponding entry with the same name and type in the target's type
      bool isSubtype = true;
      auto testEntries = test.getTargetType().getEntries();
      auto targetEntries = target.getTarget().getEntries();

      // Check if target is a subtype of test requirements
      // Since entries are sorted by name, we can do this in a single pass
      size_t targetIdx = 0;
      for (auto testEntry : testEntries) {
        // Find the matching entry in target entries.
        while (targetIdx < targetEntries.size() &&
               targetEntries[targetIdx].name.getValue() <
                   testEntry.name.getValue())
          targetIdx++;

        // Check if we found a matching entry with the same name and type
        if (targetIdx >= targetEntries.size() ||
            targetEntries[targetIdx].name != testEntry.name ||
            targetEntries[targetIdx].type != testEntry.type) {
          isSubtype = false;
          break;
        }
      }

      if (!isSubtype)
        continue;

      IRRewriter rewriter(test);
      // Create a new test for the matched target
      auto newTest = cast<TestOp>(test->clone());
      // The clone's name encodes both the template and the target.
      newTest.setSymName(test.getSymName().str() + "_" +
                         target.getSymName().str());

      // Set the target symbol specifying that this test is only suitable for
      // that target.
      newTest.setTargetAttr(target.getSymNameAttr());

      table.insert(newTest, rewriter.getInsertionPoint());
      matched = true;
    }

    if (matched || deleteUnmatchedTests)
      test->erase();
  }
}
2509
2510static bool onlyLegalToMaterializeInTarget(Type type) {
2511 return isa<MemoryBlockType, ContextResourceTypeInterface>(type);
2512}
2513
/// Elaborate every target in the module, then every test that names a target,
/// feeding each test the values yielded by its target (filtered down to the
/// entries the test actually requires).
LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
                                               SymbolTable &table) {
  SharedState state(moduleOp.getContext(), table);

  // Update the name cache
  state.names.add(moduleOp);

  // Per-target elaboration result: the target's dict type, the values its
  // terminator yielded, and the test state (RNG, context switches) built up
  // while elaborating it.
  struct TargetElabResult {
    TargetElabResult(DictType targetType, uint32_t seed)
        : targetType(targetType), testState(seed) {}

    DictType targetType;
    SmallVector<ElaboratorValue> yields;
    TestState testState;
  };

  // Map to store elaborated targets
  DenseMap<StringAttr, TargetElabResult> targetMap;
  for (auto targetOp : moduleOp.getOps<TargetOp>()) {
    LLVM_DEBUG(llvm::dbgs() << "=== Elaborating target @"
                            << targetOp.getSymName() << "\n\n");

    auto [it, inserted] = targetMap.try_emplace(targetOp.getSymNameAttr(),
                                                targetOp.getTarget(), seed);
    auto &result = it->second;

    SmallVector<ElaboratorValue> blockArgs;
    Materializer targetMaterializer(OpBuilder::atBlockBegin(targetOp.getBody()),
                                    result.testState, state, blockArgs);
    Elaborator targetElaborator(state, result.testState, targetMaterializer);

    // Elaborate the target; keepTerminator so the target op stays well-formed
    // while its yielded values are captured for the tests below.
    if (failed(targetElaborator.elaborate(targetOp.getBodyRegion(), {},
                                          /*keepTerminator=*/true,
                                          result.yields)))
      return failure();

    targetMaterializer.finalize();
  }

  // Initialize the worklist with the test ops since they cannot be placed by
  // other ops.
  for (auto testOp : moduleOp.getOps<TestOp>()) {
    // Skip tests without a target attribute - these couldn't be matched
    // against any target but can be useful to keep around for reporting
    // purposes.
    if (!testOp.getTargetAttr())
      continue;

    LLVM_DEBUG(llvm::dbgs()
               << "\n=== Elaborating test @" << testOp.getTemplateName()
               << " for target @" << *testOp.getTarget() << "\n\n");

    // Get the target for this test
    auto &targetResult = targetMap.at(testOp.getTargetAttr());
    TestState testState(seed);
    // The test inherits the context switches registered by its target.
    testState.contextSwitches = targetResult.testState.contextSwitches;
    testState.name = testOp.getSymNameAttr();

    // Keep only the target yields whose entry names appear in the test's
    // (name-sorted) requirements; both entry lists are walked in lockstep.
    SmallVector<ElaboratorValue> filteredYields;
    unsigned i = 0;
    for (auto [entry, yield] :
         llvm::zip(targetResult.targetType.getEntries(), targetResult.yields)) {
      if (i >= testOp.getTargetType().getEntries().size())
        break;

      if (entry.name == testOp.getTargetType().getEntries()[i].name) {
        filteredYields.push_back(yield);
        ++i;
      }
    }

    // Now elaborate the test with the same state, passing the target yield
    // values as arguments
    SmallVector<ElaboratorValue> blockArgs;
    Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()),
                              testState, state, blockArgs);

    // Values that may only live in a target (memory blocks, context
    // resources) are mapped to the test's block arguments instead of being
    // re-materialized inside the test body.
    for (auto [arg, val] :
         llvm::zip(testOp.getBody()->getArguments(), filteredYields))
      if (onlyLegalToMaterializeInTarget(arg.getType()))
        materializer.map(val, arg);

    Elaborator elaborator(state, testState, materializer);
    SmallVector<ElaboratorValue> ignore;
    if (failed(elaborator.elaborate(testOp.getBodyRegion(), filteredYields,
                                    /*keepTerminator=*/false, ignore)))
      return failure();

    materializer.finalize();
  }

  return success();
}
assert(baseType &&"element must be base type")
static bool onlyLegalToMaterializeInTarget(Type type)
#define VISIT_UNSUPPORTED(STORAGETYPE)
static void print(TypedAttr val, llvm::raw_ostream &os)
static LogicalResult convert(arc::ExecuteOp op, arc::ExecuteOp::Adaptor adaptor, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
This helps visit TypeOp nodes.
Definition RTGVisitors.h:29
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:96
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static llvm::hash_code hash_value(const ModulePort &port)
Definition HWTypes.h:38
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition Utils.h:32
Definition rtg.py:1
Definition seq.py:1
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)