CIRCT 23.0.0git
Loading...
Searching...
No Matches
ElaborationPass.cpp
Go to the documentation of this file.
1//===- ElaborationPass.cpp - RTG ElaborationPass implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass elaborates the random parts of the RTG dialect.
10// It performs randomization top-down, i.e., random constructs in a sequence
11// that is invoked multiple times can yield different randomization results
12// for each invocation.
13//
14//===----------------------------------------------------------------------===//
15
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/Index/IR/IndexDialect.h"
23#include "mlir/Dialect/Index/IR/IndexOps.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/DenseMapInfoVariant.h"
28#include "llvm/Support/Debug.h"
29#include <random>
30
31namespace circt {
32namespace rtg {
33#define GEN_PASS_DEF_ELABORATIONPASS
34#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
35} // namespace rtg
36} // namespace circt
37
38using namespace mlir;
39using namespace circt;
40using namespace circt::rtg;
41using llvm::MapVector;
42
43#define DEBUG_TYPE "rtg-elaboration"
44
45namespace {
46/// Helper class to generate random numbers supporting scopes of randomization.
47struct RngScope {
48 RngScope() = delete;
49 RngScope(const RngScope &) = delete;
50 RngScope &operator=(const RngScope &) = delete;
51
52 RngScope(RngScope &&) = default;
53 RngScope &operator=(RngScope &&) = default;
54
55 explicit RngScope(uint32_t seed) : rng(seed) {}
56
57 //===--------------------------------------------------------------------===//
58 // Uniform Distribution Helper
59 //
60 // Simplified version of
61 // https://github.com/llvm/llvm-project/blob/main/libcxx/include/__random/uniform_int_distribution.h
62 // We use our custom version here to get the same results when compiled with
63 // different compiler versions and standard libraries.
64 //===--------------------------------------------------------------------===//
65
66 /// Get a number uniformly at random in the in specified range.
67 uint32_t getUniformlyInRange(uint32_t a, uint32_t b) {
68 const uint32_t diff = b - a + 1;
69 if (diff == 1)
70 return a;
71
72 if (diff == 0)
73 return rng();
74
75 uint32_t mask =
76 std::numeric_limits<uint32_t>::max() >> (32 - llvm::Log2_32_Ceil(diff));
77 uint32_t u;
78 do {
79 u = rng() & mask;
80 } while (u >= diff);
81
82 return u + a;
83 }
84
85 RngScope getNested() { return RngScope(rng()); }
86
87private:
88 std::mt19937 rng;
89};
90} // namespace
91
92//===----------------------------------------------------------------------===//
93// Elaborator Value
94//===----------------------------------------------------------------------===//
95
96namespace {
97struct ArrayStorage;
98struct BagStorage;
99struct SequenceStorage;
100struct RandomizedSequenceStorage;
101struct InterleavedSequenceStorage;
102struct SetStorage;
103struct VirtualRegisterStorage;
104struct UniqueLabelStorage;
105struct TupleStorage;
106struct MemoryStorage;
107struct MemoryBlockStorage;
108struct SymbolicComputationWithIdentityStorage;
109struct SymbolicComputationWithIdentityValue;
110struct SymbolicComputationStorage;
111
112/// The abstract base class for elaborated values.
113using ElaboratorValue =
114 std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
115 RandomizedSequenceStorage *, InterleavedSequenceStorage *,
116 SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
117 ArrayStorage *, TupleStorage *, MemoryStorage *,
118 MemoryBlockStorage *, SymbolicComputationWithIdentityStorage *,
119 SymbolicComputationWithIdentityValue *,
120 SymbolicComputationStorage *>;
121
122// NOLINTNEXTLINE(readability-identifier-naming)
123llvm::hash_code hash_value(const ElaboratorValue &val) {
124 return std::visit(
125 [&val](const auto &alternative) {
126 // Include index in hash to make sure same value as different
127 // alternatives don't collide.
128 return llvm::hash_combine(val.index(), alternative);
129 },
130 val);
131}
132
133} // namespace
134
135namespace llvm {
136
137template <>
138struct DenseMapInfo<bool> {
139 static unsigned getHashValue(const bool &val) { return val * 37U; }
140
141 static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; }
142};
143
144} // namespace llvm
145
146//===----------------------------------------------------------------------===//
147// Elaborator Value Storages and Internalization
148//===----------------------------------------------------------------------===//
149
150namespace {
151
/// Key type of the internalization sets. It pairs the hashcode of a storage
/// object with a pointer to that object. Because the hashcode is cached here,
/// a key can be built before the storage is allocated, so allocation and
/// construction are only needed when the object is not in the set yet.
template <typename StorageTy>
struct HashedStorage {
  HashedStorage(unsigned hashcode = 0, StorageTy *storage = nullptr)
      : hashcode{hashcode}, storage{storage} {}

  /// Cached hashcode of the referenced storage.
  unsigned hashcode;

  /// The referenced storage; null when no storage has been attached.
  StorageTy *storage;
};
164
165/// A DenseMapInfo implementation to support 'insert_as' for the internalization
166/// sets. When comparing two 'HashedStorage's we can just compare the already
167/// internalized storage pointers, otherwise we have to call the costly
168/// 'isEqual' method.
169template <typename StorageTy>
170struct StorageKeyInfo {
171 static inline unsigned getHashValue(const HashedStorage<StorageTy> &key) {
172 return key.hashcode;
173 }
174 static inline unsigned getHashValue(const StorageTy &key) {
175 return key.hashcode;
176 }
177
178 static inline bool isEqual(const HashedStorage<StorageTy> &lhs,
179 const HashedStorage<StorageTy> &rhs) {
180 return lhs.storage == rhs.storage;
181 }
182 static inline bool isEqual(const StorageTy &lhs,
183 const HashedStorage<StorageTy> &rhs) {
184 if (!rhs.storage)
185 return false;
186
187 return lhs.isEqual(rhs.storage);
188 }
189};
190
191// Values with structural equivalence intended to be internalized.
192//===----------------------------------------------------------------------===//
193
194/// Base class for all storage objects that can be materialized as attributes.
195/// This provides a cache for the attribute that was materialized from this
196/// storage.
197struct CachableStorage {
198 TypedAttr attrCache;
199};
200
201/// Storage object for an '!rtg.set<T>'.
202struct SetStorage : CachableStorage {
203 static unsigned computeHash(const SetVector<ElaboratorValue> &set,
204 Type type) {
205 llvm::hash_code setHash = 0;
206 for (auto el : set) {
207 // Just XOR all hashes because it's a commutative operation and
208 // `llvm::hash_combine_range` is not commutative.
209 // We don't want the order in which elements were added to influence the
210 // hash and thus the equivalence of sets.
211 setHash = setHash ^ llvm::hash_combine(el);
212 }
213 return llvm::hash_combine(type, setHash);
214 }
215
216 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
217 : hashcode(computeHash(set, type)), set(std::move(set)), type(type) {}
218
219 bool isEqual(const SetStorage *other) const {
220 // Note: we are not using the `==` operator of `SetVector` because it
221 // takes the order in which elements were added into account (since it's a
222 // vector after all). We just use it as a convenient way to keep track of a
223 // deterministic order for re-materialization.
224 bool allContained = true;
225 for (auto el : set)
226 allContained &= other->set.contains(el);
227
228 return hashcode == other->hashcode && set.size() == other->set.size() &&
229 allContained && type == other->type;
230 }
231
232 // The cached hashcode to avoid repeated computations.
233 const unsigned hashcode;
234
235 // Stores the elaborated values contained in the set.
236 const SetVector<ElaboratorValue> set;
237
238 // Store the set type such that we can materialize this evaluated value
239 // also in the case where the set is empty.
240 const Type type;
241};
242
243/// Storage object for an '!rtg.bag<T>'.
244struct BagStorage {
245 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
246 : hashcode(llvm::hash_combine(
247 type, llvm::hash_combine_range(bag.begin(), bag.end()))),
248 bag(std::move(bag)), type(type) {}
249
250 bool isEqual(const BagStorage *other) const {
251 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
252 type == other->type;
253 }
254
255 // The cached hashcode to avoid repeated computations.
256 const unsigned hashcode;
257
258 // Stores the elaborated values contained in the bag with their number of
259 // occurences.
261
262 // Store the bag type such that we can materialize this evaluated value
263 // also in the case where the bag is empty.
264 const Type type;
265};
266
267/// Storage object for an '!rtg.sequence'.
268struct SequenceStorage {
269 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
270 : hashcode(llvm::hash_combine(
271 familyName, llvm::hash_combine_range(args.begin(), args.end()))),
272 familyName(familyName), args(std::move(args)) {}
273
274 bool isEqual(const SequenceStorage *other) const {
275 return hashcode == other->hashcode && familyName == other->familyName &&
276 args == other->args;
277 }
278
279 // The cached hashcode to avoid repeated computations.
280 const unsigned hashcode;
281
282 // The name of the sequence family this sequence is derived from.
283 const StringAttr familyName;
284
285 // The elaborator values used during substitution of the sequence family.
286 const SmallVector<ElaboratorValue> args;
287};
288
289/// Storage object for interleaved '!rtg.randomized_sequence'es.
290struct InterleavedSequenceStorage {
291 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
292 uint32_t batchSize)
293 : sequences(std::move(sequences)), batchSize(batchSize),
294 hashcode(llvm::hash_combine(
295 llvm::hash_combine_range(sequences.begin(), sequences.end()),
296 batchSize)) {}
297
298 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
299 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
300 hashcode(llvm::hash_combine(
301 llvm::hash_combine_range(sequences.begin(), sequences.end()),
302 batchSize)) {}
303
304 bool isEqual(const InterleavedSequenceStorage *other) const {
305 return hashcode == other->hashcode && sequences == other->sequences &&
306 batchSize == other->batchSize;
307 }
308
309 const SmallVector<ElaboratorValue> sequences;
310
311 const uint32_t batchSize;
312
313 // The cached hashcode to avoid repeated computations.
314 const unsigned hashcode;
315};
316
317/// Storage object for '!rtg.array`-typed values.
318struct ArrayStorage {
319 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
320 : hashcode(llvm::hash_combine(
321 type, llvm::hash_combine_range(array.begin(), array.end()))),
322 type(type), array(array) {}
323
324 bool isEqual(const ArrayStorage *other) const {
325 return hashcode == other->hashcode && type == other->type &&
326 array == other->array;
327 }
328
329 // The cached hashcode to avoid repeated computations.
330 const unsigned hashcode;
331
332 /// The type of the array. This is necessary because an array of size 0
333 /// cannot be reconstructed without knowing the original element type.
334 const Type type;
335
336 /// The label name. For unique labels, this is just the prefix.
337 const SmallVector<ElaboratorValue> array;
338};
339
340/// Storage object for 'tuple`-typed values.
341struct TupleStorage : CachableStorage {
342 TupleStorage(SmallVector<ElaboratorValue> &&values)
343 : hashcode(llvm::hash_combine_range(values.begin(), values.end())),
344 values(std::move(values)) {}
345
346 bool isEqual(const TupleStorage *other) const {
347 return hashcode == other->hashcode && values == other->values;
348 }
349
350 // The cached hashcode to avoid repeated computations.
351 const unsigned hashcode;
352
353 const SmallVector<ElaboratorValue> values;
354};
355
356struct SymbolicComputationStorage {
357 SymbolicComputationStorage(const DenseMap<Value, ElaboratorValue> &state,
358 Operation *op)
359 : name(op->getName()), resultTypes(op->getResultTypes()),
360 operands(llvm::map_range(op->getOperands(),
361 [&](Value v) { return state.lookup(v); })),
362 attributes(op->getAttrDictionary()),
363 properties(op->getPropertiesAsAttribute()),
364 hashcode(llvm::hash_combine(name, llvm::hash_combine_range(resultTypes),
365 llvm::hash_combine_range(operands),
366 attributes, op->hashProperties())) {}
367
368 bool isEqual(const SymbolicComputationStorage *other) const {
369 return hashcode == other->hashcode && name == other->name &&
370 resultTypes == other->resultTypes && operands == other->operands &&
371 attributes == other->attributes && properties == other->properties;
372 }
373
374 const OperationName name;
375 const SmallVector<Type> resultTypes;
376 const SmallVector<ElaboratorValue> operands;
377 const DictionaryAttr attributes;
378 const Attribute properties;
379 const unsigned hashcode;
380};
381
382// Values with identity not intended to be internalized.
383//===----------------------------------------------------------------------===//
384
385/// Base class for storages that represent values with identity, i.e., two
386/// values are not considered equivalent if they are structurally the same, but
387/// each definition of such a value is unique. E.g., unique labels or virtual
388/// registers. These cannot be materialized anew in each nested sequence, but
389/// must be passed as arguments.
390struct IdentityValue {
391
392 IdentityValue(Type type, Location loc) : type(type), loc(loc) {}
393
394#ifndef NDEBUG
395
396 /// In debug mode, track whether this value was already materialized to
397 /// assert if it's illegally materialized multiple times.
398 ///
399 /// Instead of deleting operations defining these values and materializing
400 /// them again, we could retain the operations. However, we still need
401 /// specific storages to represent these values in some cases, e.g., to get
402 /// the size of a memory allocation. Also, elaboration of nested control-flow
403 /// regions (e.g. `scf.for`) relies on materialization of such values lazily
404 /// instead of cloning the operations eagerly.
405 bool alreadyMaterialized = false;
406
407#endif
408
409 const Type type;
410 const Location loc;
411};
412
413/// Represents a unique virtual register.
414struct VirtualRegisterStorage : IdentityValue {
415 VirtualRegisterStorage(VirtualRegisterConfigAttr allowedRegs, Type type,
416 Location loc)
417 : IdentityValue(type, loc), allowedRegs(allowedRegs) {}
418
419 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
420 // VirtualRegisters are never internalized.
421
422 // The list of fixed registers allowed to be selected for this virtual
423 // register.
424 const VirtualRegisterConfigAttr allowedRegs;
425};
426
427struct UniqueLabelStorage : IdentityValue {
428 UniqueLabelStorage(const ElaboratorValue &name, Location loc)
429 : IdentityValue(LabelType::get(loc->getContext()), loc), name(name) {}
430
431 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
432 // VirtualRegisters are never internalized.
433
434 /// The label name. For unique labels, this is just the prefix.
435 const ElaboratorValue name;
436};
437
438/// Storage object for '!rtg.isa.memoryblock`-typed values.
439struct MemoryBlockStorage : IdentityValue {
440 MemoryBlockStorage(const APInt &baseAddress, const APInt &endAddress,
441 Type type, Location loc)
442 : IdentityValue(type, loc), baseAddress(baseAddress),
443 endAddress(endAddress) {}
444
445 // The base address of the memory. The width of the APInt also represents the
446 // address width of the memory. This is an APInt to support memories of
447 // >64-bit machines.
448 const APInt baseAddress;
449
450 // The last address of the memory.
451 const APInt endAddress;
452};
453
454/// Storage object for '!rtg.isa.memory`-typed values.
455struct MemoryStorage : IdentityValue {
456 MemoryStorage(MemoryBlockStorage *memoryBlock, size_t size, size_t alignment,
457 Location loc)
458 : IdentityValue(MemoryType::get(memoryBlock->type.getContext(),
459 memoryBlock->baseAddress.getBitWidth()),
460 loc),
461 memoryBlock(memoryBlock), size(size), alignment(alignment) {}
462
463 MemoryBlockStorage *memoryBlock;
464 const size_t size;
465 const size_t alignment;
466};
467
468/// Storage object for an '!rtg.randomized_sequence'.
469struct RandomizedSequenceStorage : IdentityValue {
470 RandomizedSequenceStorage(ContextResourceAttrInterface context,
471 SequenceStorage *sequence, Location loc)
472 : IdentityValue(
473 RandomizedSequenceType::get(sequence->familyName.getContext()),
474 loc),
475 context(context), sequence(sequence) {}
476
477 // The context under which this sequence is placed.
478 const ContextResourceAttrInterface context;
479
480 const SequenceStorage *sequence;
481};
482
483/// Operation must have at least 1 result.
484struct SymbolicComputationWithIdentityStorage : IdentityValue {
485 SymbolicComputationWithIdentityStorage(
486 const DenseMap<Value, ElaboratorValue> &state, Operation *op)
487 : IdentityValue(op->getResult(0).getType(), op->getLoc()),
488 name(op->getName()), resultTypes(op->getResultTypes()),
489 operands(llvm::map_range(op->getOperands(),
490 [&](Value v) { return state.lookup(v); })),
491 attributes(op->getAttrDictionary()),
492 properties(op->getPropertiesAsAttribute()) {}
493
494 const OperationName name;
495 const SmallVector<Type> resultTypes;
496 const SmallVector<ElaboratorValue> operands;
497 const DictionaryAttr attributes;
498 const Attribute properties;
499};
500
501struct SymbolicComputationWithIdentityValue : IdentityValue {
502 SymbolicComputationWithIdentityValue(
503 Type type, const SymbolicComputationWithIdentityStorage *storage,
504 unsigned idx)
505 : IdentityValue(type, storage->loc), storage(storage), idx(idx) {
506 assert(
507 idx != 0 &&
508 "Use SymbolicComputationWithIdentityStorage for result with index 0.");
509 }
510
511 const SymbolicComputationWithIdentityStorage *storage;
512 const unsigned idx;
513};
514
515/// An 'Internalizer' object internalizes storages and takes ownership of them.
516/// When the initializer object is destroyed, all owned storages are also
517/// deallocated and thus must not be accessed anymore.
518class Internalizer {
519public:
520 /// Internalize a storage of type `StorageTy` constructed with arguments
521 /// `args`. The pointers returned by this method can be used to compare
522 /// objects when, e.g., computing set differences, uniquing the elements in a
523 /// set, etc. Otherwise, we'd need to do a deep value comparison in those
524 /// situations.
525 template <typename StorageTy, typename... Args>
526 StorageTy *internalize(Args &&...args) {
527 static_assert(!std::is_base_of_v<IdentityValue, StorageTy> &&
528 "values with identity must not be internalized");
529
530 StorageTy storage(std::forward<Args>(args)...);
531
532 auto existing = getInternSet<StorageTy>().insert_as(
533 HashedStorage<StorageTy>(storage.hashcode), storage);
534 StorageTy *&storagePtr = existing.first->storage;
535 if (existing.second)
536 storagePtr =
537 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
538
539 return storagePtr;
540 }
541
542 template <typename StorageTy, typename... Args>
543 StorageTy *create(Args &&...args) {
544 static_assert(std::is_base_of_v<IdentityValue, StorageTy> &&
545 "values with structural equivalence must be internalized");
546
547 return new (allocator.Allocate<StorageTy>())
548 StorageTy(std::forward<Args>(args)...);
549 }
550
551private:
552 template <typename StorageTy>
553 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
554 getInternSet() {
555 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
556 return internedArrays;
557 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
558 return internedSets;
559 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
560 return internedBags;
561 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
562 return internedSequences;
563 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
564 return internedRandomizedSequences;
565 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
566 return internedInterleavedSequences;
567 else if constexpr (std::is_same_v<StorageTy, TupleStorage>)
568 return internedTuples;
569 else if constexpr (std::is_same_v<StorageTy, SymbolicComputationStorage>)
570 return internedSymbolicComputationWithIdentityValues;
571 else
572 static_assert(!sizeof(StorageTy),
573 "no intern set available for this storage type.");
574 }
575
576 // This allocator allocates on the heap. It automatically deallocates all
577 // objects it allocated once the allocator itself is destroyed.
578 llvm::BumpPtrAllocator allocator;
579
580 // The sets holding the internalized objects. We use one set per storage type
581 // such that we can have a simpler equality checking function (no need to
582 // compare some sort of TypeIDs).
583 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
584 internedArrays;
585 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
586 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
587 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
588 internedSequences;
589 DenseSet<HashedStorage<RandomizedSequenceStorage>,
590 StorageKeyInfo<RandomizedSequenceStorage>>
591 internedRandomizedSequences;
592 DenseSet<HashedStorage<InterleavedSequenceStorage>,
593 StorageKeyInfo<InterleavedSequenceStorage>>
594 internedInterleavedSequences;
595 DenseSet<HashedStorage<TupleStorage>, StorageKeyInfo<TupleStorage>>
596 internedTuples;
597 DenseSet<HashedStorage<SymbolicComputationStorage>,
598 StorageKeyInfo<SymbolicComputationStorage>>
599 internedSymbolicComputationWithIdentityValues;
600};
601
602} // namespace
603
604#ifndef NDEBUG
605
606static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
607 const ElaboratorValue &value);
608
609static void print(TypedAttr val, llvm::raw_ostream &os) {
610 os << "<attr " << val << ">";
611}
612
613static void print(BagStorage *val, llvm::raw_ostream &os) {
614 os << "<bag {";
615 llvm::interleaveComma(val->bag, os,
616 [&](const std::pair<ElaboratorValue, uint64_t> &el) {
617 os << el.first << " -> " << el.second;
618 });
619 os << "} at " << val << ">";
620}
621
622static void print(bool val, llvm::raw_ostream &os) {
623 os << "<bool " << (val ? "true" : "false") << ">";
624}
625
626static void print(size_t val, llvm::raw_ostream &os) {
627 os << "<index " << val << ">";
628}
629
630static void print(SequenceStorage *val, llvm::raw_ostream &os) {
631 os << "<sequence @" << val->familyName.getValue() << "(";
632 llvm::interleaveComma(val->args, os,
633 [&](const ElaboratorValue &val) { os << val; });
634 os << ") at " << val << ">";
635}
636
637static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
638 os << "<randomized-sequence derived from @"
639 << val->sequence->familyName.getValue() << " under context "
640 << val->context << "(";
641 llvm::interleaveComma(val->sequence->args, os,
642 [&](const ElaboratorValue &val) { os << val; });
643 os << ") at " << val << ">";
644}
645
646static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
647 os << "<interleaved-sequence [";
648 llvm::interleaveComma(val->sequences, os,
649 [&](const ElaboratorValue &val) { os << val; });
650 os << "] batch-size " << val->batchSize << " at " << val << ">";
651}
652
653static void print(ArrayStorage *val, llvm::raw_ostream &os) {
654 os << "<array [";
655 llvm::interleaveComma(val->array, os,
656 [&](const ElaboratorValue &val) { os << val; });
657 os << "] at " << val << ">";
658}
659
660static void print(SetStorage *val, llvm::raw_ostream &os) {
661 os << "<set {";
662 llvm::interleaveComma(val->set, os,
663 [&](const ElaboratorValue &val) { os << val; });
664 os << "} at " << val << ">";
665}
666
667static void print(const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
668 os << "<virtual-register " << val << " " << val->allowedRegs << ">";
669}
670
671static void print(const UniqueLabelStorage *val, llvm::raw_ostream &os) {
672 os << "<unique-label " << val << " " << val->name << ">";
673}
674
675static void print(const TupleStorage *val, llvm::raw_ostream &os) {
676 os << "<tuple (";
677 llvm::interleaveComma(val->values, os,
678 [&](const ElaboratorValue &val) { os << val; });
679 os << ")>";
680}
681
682static void print(const MemoryStorage *val, llvm::raw_ostream &os) {
683 os << "<memory {" << ElaboratorValue(val->memoryBlock)
684 << ", size=" << val->size << ", alignment=" << val->alignment << "}>";
685}
686
687static void print(const MemoryBlockStorage *val, llvm::raw_ostream &os) {
688 os << "<memory-block {"
689 << ", address-width=" << val->baseAddress.getBitWidth()
690 << ", base-address=" << val->baseAddress
691 << ", end-address=" << val->endAddress << "}>";
692}
693
694static void print(const SymbolicComputationWithIdentityValue *val,
695 llvm::raw_ostream &os) {
696 os << "<symbolic-computation-with-identity-value (" << val->storage << ") at "
697 << val->idx << ">";
698}
699
700static void print(const SymbolicComputationWithIdentityStorage *val,
701 llvm::raw_ostream &os) {
702 os << "<symbolic-computation-with-identity " << val->name << "(";
703 llvm::interleaveComma(val->operands, os,
704 [&](const ElaboratorValue &val) { os << val; });
705 os << ") -> " << val->resultTypes << " with attributes " << val->attributes
706 << " and properties " << val->properties;
707 os << ">";
708}
709
710static void print(const SymbolicComputationStorage *val,
711 llvm::raw_ostream &os) {
712 os << "<symbolic-computation " << val->name << "(";
713 llvm::interleaveComma(val->operands, os,
714 [&](const ElaboratorValue &val) { os << val; });
715 os << ") -> " << val->resultTypes << " with attributes " << val->attributes
716 << " and properties " << val->properties;
717 os << ">";
718}
719
720static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
721 const ElaboratorValue &value) {
722 std::visit([&](auto val) { print(val, os); }, value);
723
724 return os;
725}
726
727#endif
728
729//===----------------------------------------------------------------------===//
730// Attribute <-> ElaboratorValue Converters
731//===----------------------------------------------------------------------===//
732
733namespace {
734
735/// Convert an attribute to an ElaboratorValue. This handles nested attributes
736/// like SetAttr and TupleAttr recursively.
737class AttributeToElaboratorValueConverter {
738public:
739 AttributeToElaboratorValueConverter(Internalizer &internalizer)
740 : internalizer(internalizer) {}
741
742 /// Convert an attribute to an ElaboratorValue.
743 FailureOr<ElaboratorValue> convert(Attribute attr) {
744 return llvm::TypeSwitch<Attribute, FailureOr<ElaboratorValue>>(attr)
745 .Case<IntegerAttr, SetAttr, TupleAttr>(
746 [&](auto attr) { return convert(attr); })
747 .Case<TypedAttr>([&](auto typedAttr) -> FailureOr<ElaboratorValue> {
748 return ElaboratorValue(typedAttr);
749 })
750 .Default(
751 [&](Attribute) -> FailureOr<ElaboratorValue> { return failure(); });
752 }
753
754private:
755 FailureOr<ElaboratorValue> convert(IntegerAttr attr) {
756 if (attr.getType().isSignlessInteger(1))
757 return ElaboratorValue(bool(attr.getInt()));
758 if (isa<IndexType>(attr.getType()))
759 return ElaboratorValue(size_t(attr.getInt()));
760 return ElaboratorValue(attr);
761 }
762
763 FailureOr<ElaboratorValue> convert(SetAttr setAttr) {
764 SetVector<ElaboratorValue> set;
765 for (auto element : *setAttr.getElements()) {
766 auto converted = convert(element);
767 if (failed(converted))
768 return failure();
769 set.insert(*converted);
770 }
771 auto *storage =
772 internalizer.internalize<SetStorage>(std::move(set), setAttr.getType());
773 // Cache the original attribute for efficient materialization
774 storage->attrCache = setAttr;
775 return ElaboratorValue(storage);
776 }
777
778 FailureOr<ElaboratorValue> convert(TupleAttr tupleAttr) {
779 SmallVector<ElaboratorValue> values;
780 for (auto element : tupleAttr.getElements()) {
781 auto converted = convert(element);
782 if (failed(converted))
783 return failure();
784 values.push_back(*converted);
785 }
786 auto *storage = internalizer.internalize<TupleStorage>(std::move(values));
787 // Cache the original attribute for efficient materialization
788 storage->attrCache = tupleAttr;
789 return ElaboratorValue(storage);
790 }
791
792 Internalizer &internalizer;
793};
794
795/// Convert an ElaboratorValue to an attribute when possible. Not all
796/// ElaboratorValues can be converted to attributes (e.g., values with
797/// identity).
798class ElaboratorValueToAttributeConverter {
799public:
800 ElaboratorValueToAttributeConverter(MLIRContext *context)
801 : context(context) {}
802
803 /// Convert an ElaboratorValue to an attribute. Returns a null attribute if
804 /// the conversion is not possible.
805 TypedAttr convert(const ElaboratorValue &value) {
806 return std::visit(
807 [&](auto val) -> TypedAttr {
808 if constexpr (std::is_base_of_v<CachableStorage,
809 std::remove_pointer_t<
810 std::decay_t<decltype(value)>>>) {
811 if (val->attrCache)
812 return val->attrCache;
813 }
814 return visit(val);
815 },
816 value);
817 }
818
819private:
820 TypedAttr visit(TypedAttr val) { return val; }
821
822 TypedAttr visit(bool val) {
823 return IntegerAttr::get(IntegerType::get(context, 1), val);
824 }
825
826 TypedAttr visit(size_t val) {
827 return IntegerAttr::get(IndexType::get(context), val);
828 }
829
830 TypedAttr visit(SetStorage *val) {
831 DenseSet<TypedAttr> elements;
832 for (auto element : val->set) {
833 auto converted = convert(element);
834 if (!converted)
835 return {};
836 auto typedAttr = dyn_cast<TypedAttr>(converted);
837 if (!typedAttr)
838 return {};
839 elements.insert(typedAttr);
840 }
841 return SetAttr::get(cast<SetType>(val->type), &elements);
842 }
843
844 TypedAttr visit(TupleStorage *val) {
845 SmallVector<TypedAttr> elements;
846 for (auto element : val->values) {
847 auto converted = convert(element);
848 if (!converted)
849 return {};
850 auto typedAttr = dyn_cast<TypedAttr>(converted);
851 if (!typedAttr)
852 return {};
853 elements.push_back(typedAttr);
854 }
855 return TupleAttr::get(context, elements);
856 }
857
858 // Default implementation for storage types that cannot be converted to
859 // attributes. Using a macro to avoid repetition. std::visit does not support
860 // any kind of "default" clauses, unfortunately.
861#define VISIT_UNSUPPORTED(STORAGETYPE) \
862 /* NOLINTNEXTLINE(bugprone-macro-parentheses)*/ \
863 TypedAttr visit(STORAGETYPE *val) { return {}; }
864
865 VISIT_UNSUPPORTED(ArrayStorage)
866 VISIT_UNSUPPORTED(BagStorage)
867 VISIT_UNSUPPORTED(SequenceStorage)
868 VISIT_UNSUPPORTED(RandomizedSequenceStorage)
869 VISIT_UNSUPPORTED(InterleavedSequenceStorage)
870 VISIT_UNSUPPORTED(VirtualRegisterStorage)
871 VISIT_UNSUPPORTED(UniqueLabelStorage)
872 VISIT_UNSUPPORTED(MemoryStorage)
873 VISIT_UNSUPPORTED(MemoryBlockStorage)
874 VISIT_UNSUPPORTED(SymbolicComputationWithIdentityStorage)
875 VISIT_UNSUPPORTED(SymbolicComputationWithIdentityValue)
876 VISIT_UNSUPPORTED(SymbolicComputationStorage)
877
878#undef VISIT_UNSUPPORTED
879
880 MLIRContext *context;
881};
882
883} // namespace
884
885//===----------------------------------------------------------------------===//
886// Elaborator Value Materialization
887//===----------------------------------------------------------------------===//
888
889namespace {
890
891/// State that should be shared by all elaborator and materializer instances.
892struct SharedState {
893 SharedState(MLIRContext *ctxt, SymbolTable &table)
894 : ctxt(ctxt), table(table) {}
895
896 MLIRContext *ctxt;
897 SymbolTable &table;
898 Namespace names;
899 Internalizer internalizer;
900};
901
902/// A collection of state per RTG test.
903struct TestState {
904 explicit TestState(unsigned seed) : rng(RngScope(seed)) {}
905
906 /// The name of the test.
907 StringAttr name;
908
909 /// The context switches registered for this test.
910 MapVector<
911 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
912 SequenceStorage *>
913 contextSwitches;
914
915 /// The root RNG scope for this test.
916 RngScope rng;
917};
918
919/// Construct an SSA value from a given elaborated value.
920class Materializer {
921public:
922 Materializer(OpBuilder builder, TestState &testState,
923 SharedState &sharedState,
924 SmallVector<ElaboratorValue> &blockArgs)
925 : builder(builder), testState(testState), sharedState(sharedState),
926 blockArgs(blockArgs), attrConverter(builder.getContext()) {}
927
928 /// Materialize IR representing the provided `ElaboratorValue` and return the
929 /// `Value` or a null value on failure.
930 Value materialize(ElaboratorValue val, Location loc,
931 function_ref<InFlightDiagnostic()> emitError) {
932 auto iter = materializedValues.find(val);
933 if (iter != materializedValues.end())
934 return iter->second;
935
936 LLVM_DEBUG(llvm::dbgs() << "Materializing " << val);
937
938 if (auto res = tryMaterializeAsConstant(val, loc))
939 return res;
940
941 // In debug mode, track whether values with identity were already
942 // materialized before and assert in such a situation.
943 Value res = std::visit(
944 [&](auto value) {
945 if constexpr (std::is_base_of_v<IdentityValue,
946 std::remove_pointer_t<
947 std::decay_t<decltype(value)>>>) {
948 if (identityValueRoot.contains(value)) {
949#ifndef NDEBUG
950 bool &materialized =
951 static_cast<IdentityValue *>(value)->alreadyMaterialized;
952 assert(!materialized && "must not already be materialized");
953 materialized = true;
954#endif
955
956 return visit(value, loc, emitError);
957 }
958
959 Value arg = builder.getBlock()->addArgument(value->type, loc);
960 blockArgs.push_back(val);
961 blockArgTypes.push_back(arg.getType());
962 materializedValues[val] = arg;
963 return arg;
964 }
965
966 return visit(value, loc, emitError);
967 },
968 val);
969
970 LLVM_DEBUG(llvm::dbgs() << " to\n" << res << "\n\n");
971
972 return res;
973 }
974
975 bool isInPlace(Operation *op) const {
976 return builder.getBlock()->getParent() == op->getParentRegion();
977 }
978
979 /// If `op` is not in the same region as the materializer insertion point, a
980 /// clone is created at the materializer's insertion point by also
981 /// materializing the `ElaboratorValue`s for each operand just before it.
982 /// Otherwise, all operations after the materializer's insertion point are
983 /// deleted until `op` is reached. An error is returned if the operation is
984 /// before the insertion point.
985 LogicalResult materialize(Operation *op,
986 DenseMap<Value, ElaboratorValue> &state) {
987 if (op->getNumRegions() > 0)
988 return op->emitOpError("ops with nested regions must be elaborated away");
989
990 // We don't support opaque values. If there is an SSA value that has a
991 // use-site it needs an equivalent ElaborationValue representation.
992 // NOTE: We could support cases where there is initially a use-site but that
993 // op is guaranteed to be deleted during elaboration. Or the use-sites are
994 // replaced with freshly materialized values from the ElaborationValue. But
995 // then, why can't we delete the value defining op?
996 for (auto res : op->getResults())
997 if (!res.use_empty())
998 return op->emitOpError(
999 "ops with results that have uses are not supported");
1000
1001 if (isInPlace(op)) {
1002 // We are doing in-place materialization, so mark all ops deleted until we
1003 // reach the one to be materialized and modify it in-place.
1004 deleteOpsUntil([&](auto iter) { return &*iter == op; });
1005
1006 if (builder.getInsertionPoint() == builder.getBlock()->end())
1007 return op->emitError("operation did not occur after the current "
1008 "materializer insertion point");
1009
1010 LLVM_DEBUG(llvm::dbgs() << "Modifying in-place: " << *op << "\n\n");
1011 } else {
1012 LLVM_DEBUG(llvm::dbgs() << "Materializing a clone of " << *op << "\n\n");
1013 op = builder.clone(*op);
1014 builder.setInsertionPoint(op);
1015 }
1016
1017 for (auto &operand : op->getOpOperands()) {
1018 auto emitError = [&]() {
1019 auto diag = op->emitError();
1020 diag.attachNote(op->getLoc())
1021 << "while materializing value for operand#"
1022 << operand.getOperandNumber();
1023 return diag;
1024 };
1025
1026 auto elabVal = state.at(operand.get());
1027 Value val = materialize(elabVal, op->getLoc(), emitError);
1028 if (!val)
1029 return failure();
1030
1031 state[val] = elabVal;
1032 operand.set(val);
1033 }
1034
1035 builder.setInsertionPointAfter(op);
1036 return success();
1037 }
1038
1039 /// Should be called once the `Region` is successfully materialized. No calls
1040 /// to `materialize` should happen after this anymore.
1041 void finalize() {
1042 deleteOpsUntil([](auto iter) { return false; });
1043
1044 for (auto *op : llvm::reverse(toDelete))
1045 op->erase();
1046 }
1047
1048 /// Tell this materializer that it is responsible for materializing the given
1049 /// identity value at the earliest position it is needed, and should't
1050 /// request the value via block argument.
1051 void registerIdentityValue(IdentityValue *val) {
1052 identityValueRoot.insert(val);
1053 }
1054
1055 ArrayRef<Type> getBlockArgTypes() const { return blockArgTypes; }
1056
1057 void map(ElaboratorValue eval, Value val) { materializedValues[eval] = val; }
1058
1059 template <typename OpTy, typename... Args>
1060 OpTy create(Location location, Args &&...args) {
1061 return OpTy::create(builder, location, std::forward<Args>(args)...);
1062 }
1063
1064private:
1065 Value tryMaterializeAsConstant(ElaboratorValue val, Location loc) {
1066 if (auto attr = attrConverter.convert(val)) {
1067 Value res = ConstantOp::create(builder, loc, attr);
1068 materializedValues[val] = res;
1069 return res;
1070 }
1071
1072 return Value();
1073 }
1074
1075 SequenceOp elaborateSequence(const RandomizedSequenceStorage *seq,
1076 SmallVector<ElaboratorValue> &elabArgs);
1077
1078 void deleteOpsUntil(function_ref<bool(Block::iterator)> stop) {
1079 auto ip = builder.getInsertionPoint();
1080 while (ip != builder.getBlock()->end() && !stop(ip)) {
1081 LLVM_DEBUG(llvm::dbgs() << "Marking to be deleted: " << *ip << "\n\n");
1082 toDelete.push_back(&*ip);
1083
1084 builder.setInsertionPointAfter(&*ip);
1085 ip = builder.getInsertionPoint();
1086 }
1087 }
1088
1089 Value visit(TypedAttr val, Location loc,
1090 function_ref<InFlightDiagnostic()> emitError) {
1091 return {};
1092 }
1093
1094 Value visit(size_t val, Location loc,
1095 function_ref<InFlightDiagnostic()> emitError) {
1096 return {};
1097 }
1098
1099 Value visit(bool val, Location loc,
1100 function_ref<InFlightDiagnostic()> emitError) {
1101 return {};
1102 }
1103
1104 Value visit(ArrayStorage *val, Location loc,
1105 function_ref<InFlightDiagnostic()> emitError) {
1106 SmallVector<Value> elements;
1107 elements.reserve(val->array.size());
1108 for (auto el : val->array) {
1109 auto materialized = materialize(el, loc, emitError);
1110 if (!materialized)
1111 return Value();
1112
1113 elements.push_back(materialized);
1114 }
1115
1116 Value res = ArrayCreateOp::create(builder, loc, val->type, elements);
1117 materializedValues[val] = res;
1118 return res;
1119 }
1120
1121 Value visit(SetStorage *val, Location loc,
1122 function_ref<InFlightDiagnostic()> emitError) {
1123 SmallVector<Value> elements;
1124 elements.reserve(val->set.size());
1125 for (auto el : val->set) {
1126 auto materialized = materialize(el, loc, emitError);
1127 if (!materialized)
1128 return Value();
1129
1130 elements.push_back(materialized);
1131 }
1132
1133 auto res = SetCreateOp::create(builder, loc, val->type, elements);
1134 materializedValues[val] = res;
1135 return res;
1136 }
1137
1138 Value visit(BagStorage *val, Location loc,
1139 function_ref<InFlightDiagnostic()> emitError) {
1140 SmallVector<Value> values, weights;
1141 values.reserve(val->bag.size());
1142 weights.reserve(val->bag.size());
1143 for (auto [val, weight] : val->bag) {
1144 auto materializedVal = materialize(val, loc, emitError);
1145 auto materializedWeight = materialize(weight, loc, emitError);
1146 if (!materializedVal || !materializedWeight)
1147 return Value();
1148
1149 values.push_back(materializedVal);
1150 weights.push_back(materializedWeight);
1151 }
1152
1153 auto res = BagCreateOp::create(builder, loc, val->type, values, weights);
1154 materializedValues[val] = res;
1155 return res;
1156 }
1157
1158 Value visit(MemoryBlockStorage *val, Location loc,
1159 function_ref<InFlightDiagnostic()> emitError) {
1160 auto intType = builder.getIntegerType(val->baseAddress.getBitWidth());
1161 Value res = MemoryBlockDeclareOp::create(
1162 builder, val->loc, val->type,
1163 IntegerAttr::get(intType, val->baseAddress),
1164 IntegerAttr::get(intType, val->endAddress));
1165 materializedValues[val] = res;
1166 return res;
1167 }
1168
1169 Value visit(MemoryStorage *val, Location loc,
1170 function_ref<InFlightDiagnostic()> emitError) {
1171 auto memBlock = materialize(val->memoryBlock, val->loc, emitError);
1172 auto memSize = materialize(val->size, val->loc, emitError);
1173 auto memAlign = materialize(val->alignment, val->loc, emitError);
1174 if (!(memBlock && memSize && memAlign))
1175 return {};
1176
1177 Value res =
1178 MemoryAllocOp::create(builder, val->loc, memBlock, memSize, memAlign);
1179 materializedValues[val] = res;
1180 return res;
1181 }
1182
1183 Value visit(SequenceStorage *val, Location loc,
1184 function_ref<InFlightDiagnostic()> emitError) {
1185 emitError() << "materializing a non-randomized sequence not supported yet";
1186 return Value();
1187 }
1188
1189 Value visit(RandomizedSequenceStorage *val, Location loc,
1190 function_ref<InFlightDiagnostic()> emitError) {
1191 // To know which values we have to pass by argument (and not just pass all
1192 // that migth be used eagerly), we have to elaborate the sequence family if
1193 // not already done so.
1194 // We need to get back the sequence to reference, and the list of elaborated
1195 // values to pass as arguments.
1196 SmallVector<ElaboratorValue> elabArgs;
1197 // NOTE: we wouldn't need to elaborate the sequence if it doesn't contain
1198 // randomness to be elaborated.
1199 SequenceOp seqOp = elaborateSequence(val, elabArgs);
1200 if (!seqOp)
1201 return {};
1202
1203 // Materialize all the values we need to pass as arguments and collect their
1204 // types.
1205 SmallVector<Value> args;
1206 SmallVector<Type> argTypes;
1207 for (auto arg : elabArgs) {
1208 Value materialized = materialize(arg, val->loc, emitError);
1209 if (!materialized)
1210 return {};
1211
1212 args.push_back(materialized);
1213 argTypes.push_back(materialized.getType());
1214 }
1215
1216 Value res = GetSequenceOp::create(
1217 builder, val->loc, SequenceType::get(builder.getContext(), argTypes),
1218 seqOp.getSymName());
1219
1220 // Only materialize a substitute_sequence op when we have arguments to
1221 // substitute since this op does not support 0 arguments.
1222 if (!args.empty())
1223 res = SubstituteSequenceOp::create(builder, val->loc, res, args);
1224
1225 res = RandomizeSequenceOp::create(builder, val->loc, res);
1226
1227 materializedValues[val] = res;
1228 return res;
1229 }
1230
1231 Value visit(InterleavedSequenceStorage *val, Location loc,
1232 function_ref<InFlightDiagnostic()> emitError) {
1233 SmallVector<Value> sequences;
1234 for (auto seqVal : val->sequences) {
1235 Value materialized = materialize(seqVal, loc, emitError);
1236 if (!materialized)
1237 return {};
1238
1239 sequences.push_back(materialized);
1240 }
1241
1242 if (sequences.size() == 1)
1243 return sequences[0];
1244
1245 Value res =
1246 InterleaveSequencesOp::create(builder, loc, sequences, val->batchSize);
1247 materializedValues[val] = res;
1248 return res;
1249 }
1250
1251 Value visit(VirtualRegisterStorage *val, Location loc,
1252 function_ref<InFlightDiagnostic()> emitError) {
1253 Value res = VirtualRegisterOp::create(builder, val->loc, val->allowedRegs);
1254 materializedValues[val] = res;
1255 return res;
1256 }
1257
1258 Value visit(UniqueLabelStorage *val, Location loc,
1259 function_ref<InFlightDiagnostic()> emitError) {
1260 auto materialized = materialize(val->name, val->loc, emitError);
1261 if (!materialized)
1262 return {};
1263 Value res = LabelUniqueDeclOp::create(builder, val->loc, materialized);
1264 materializedValues[val] = res;
1265 return res;
1266 }
1267
1268 Value visit(TupleStorage *val, Location loc,
1269 function_ref<InFlightDiagnostic()> emitError) {
1270 SmallVector<Value> materialized;
1271 materialized.reserve(val->values.size());
1272 for (auto v : val->values)
1273 materialized.push_back(materialize(v, loc, emitError));
1274 Value res = TupleCreateOp::create(builder, loc, materialized);
1275 materializedValues[val] = res;
1276 return res;
1277 }
1278
1279 Value visit(SymbolicComputationWithIdentityValue *val, Location loc,
1280 function_ref<InFlightDiagnostic()> emitError) {
1281 auto *noConstStorage =
1282 const_cast<SymbolicComputationWithIdentityStorage *>(val->storage);
1283 auto res0 = materialize(noConstStorage, loc, emitError);
1284 if (!res0)
1285 return {};
1286
1287 auto *op = res0.getDefiningOp();
1288 auto res = op->getResults()[val->idx];
1289 materializedValues[val] = res;
1290 return res;
1291 }
1292
1293 Value visit(SymbolicComputationWithIdentityStorage *val, Location loc,
1294 function_ref<InFlightDiagnostic()> emitError) {
1295 SmallVector<Value> operands;
1296 for (auto operand : val->operands) {
1297 auto materialized = materialize(operand, val->loc, emitError);
1298 if (!materialized)
1299 return {};
1300
1301 operands.push_back(materialized);
1302 }
1303
1304 OperationState state(val->loc, val->name);
1305 state.addTypes(val->resultTypes);
1306 state.attributes = val->attributes;
1307 state.propertiesAttr = val->properties;
1308 state.addOperands(operands);
1309 auto *op = builder.create(state);
1310
1311 materializedValues[val] = op->getResult(0);
1312 return op->getResult(0);
1313 }
1314
1315 Value visit(SymbolicComputationStorage *val, Location loc,
1316 function_ref<InFlightDiagnostic()> emitError) {
1317 SmallVector<Value> operands;
1318 for (auto operand : val->operands) {
1319 auto materialized = materialize(operand, loc, emitError);
1320 if (!materialized)
1321 return {};
1322
1323 operands.push_back(materialized);
1324 }
1325
1326 OperationState state(loc, val->name);
1327 state.addTypes(val->resultTypes);
1328 state.attributes = val->attributes;
1329 state.propertiesAttr = val->properties;
1330 state.addOperands(operands);
1331 auto *op = builder.create(state);
1332
1333 for (auto res : op->getResults())
1334 materializedValues[val] = res;
1335
1336 return op->getResult(0);
1337 }
1338
1339private:
1340 /// Cache values we have already materialized to reuse them later. We start
1341 /// with an insertion point at the start of the block and cache the (updated)
1342 /// insertion point such that future materializations can also reuse previous
1343 /// materializations without running into dominance issues (or requiring
1344 /// additional checks to avoid them).
1345 DenseMap<ElaboratorValue, Value> materializedValues;
1346
1347 /// Cache the builder to continue insertions at their current insertion point
1348 /// for the reason stated above.
1349 OpBuilder builder;
1350
1351 SmallVector<Operation *> toDelete;
1352
1353 TestState &testState;
1354 SharedState &sharedState;
1355
1356 /// Keep track of the block arguments we had to add to this materializer's
1357 /// block for identity values and also remember which elaborator values are
1358 /// expected to be passed as arguments from outside.
1359 SmallVector<ElaboratorValue> &blockArgs;
1360 SmallVector<Type> blockArgTypes;
1361
1362 /// Identity values in this set are materialized by this materializer,
1363 /// otherwise they are added as block arguments and the block that wants to
1364 /// embed this sequence is expected to provide a value for it.
1365 DenseSet<IdentityValue *> identityValueRoot;
1366
1367 /// Helper to convert ElaboratorValues to attributes.
1368 ElaboratorValueToAttributeConverter attrConverter;
1369};
1370
1371//===----------------------------------------------------------------------===//
1372// Elaboration Visitor
1373//===----------------------------------------------------------------------===//
1374
/// Result of visiting an operation: tells the elaboration driver whether the
/// operation stays in the IR (`Keep`) or should be removed (`Delete`).
enum class DeletionKind {
  Keep,
  Delete,
};
1378
1379/// Interprets the IR to perform and lower the represented randomizations.
1380class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
1381public:
1383 using RTGBase::visitOp;
1384
1385 Elaborator(SharedState &sharedState, TestState &testState,
1386 Materializer &materializer,
1387 ContextResourceAttrInterface currentContext = {})
1388 : sharedState(sharedState), testState(testState),
1389 materializer(materializer), currentContext(currentContext),
1390 attrConverter(sharedState.internalizer),
1391 elabValConverter(sharedState.ctxt) {}
1392
1393 template <typename ValueTy>
1394 inline ValueTy get(Value val) const {
1395 return std::get<ValueTy>(state.at(val));
1396 }
1397
1398 /// Print a nice error message for operations we don't support yet.
1399 FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
1400 return visitOpGeneric(op);
1401 }
1402
1403 FailureOr<DeletionKind> visitExternalOp(Operation *op) {
1404 return visitOpGeneric(op);
1405 }
1406
1407 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
1408 SmallVector<ElaboratorValue> replacements;
1409 state[op.getResult()] =
1410 sharedState.internalizer.internalize<SequenceStorage>(
1411 op.getSequenceAttr().getAttr(), std::move(replacements));
1412 return DeletionKind::Delete;
1413 }
1414
1415 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
1416 if (isSymbolic(state.at(op.getSequence())))
1417 return visitOpGeneric(op);
1418
1419 auto *seq = get<SequenceStorage *>(op.getSequence());
1420
1421 SmallVector<ElaboratorValue> replacements(seq->args);
1422 for (auto replacement : op.getReplacements())
1423 replacements.push_back(state.at(replacement));
1424
1425 state[op.getResult()] =
1426 sharedState.internalizer.internalize<SequenceStorage>(
1427 seq->familyName, std::move(replacements));
1428
1429 return DeletionKind::Delete;
1430 }
1431
1432 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
1433 auto *seq = get<SequenceStorage *>(op.getSequence());
1434 auto *randomizedSeq =
1435 sharedState.internalizer.create<RandomizedSequenceStorage>(
1436 currentContext, seq, op.getLoc());
1437 materializer.registerIdentityValue(randomizedSeq);
1438 state[op.getResult()] =
1439 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1440 randomizedSeq);
1441 return DeletionKind::Delete;
1442 }
1443
1444 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
1445 SmallVector<ElaboratorValue> sequences;
1446 for (auto seq : op.getSequences())
1447 sequences.push_back(state.at(seq));
1448
1449 state[op.getResult()] =
1450 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
1451 std::move(sequences), op.getBatchSize());
1452 return DeletionKind::Delete;
1453 }
1454
1455 // NOLINTNEXTLINE(misc-no-recursion)
1456 LogicalResult isValidContext(ElaboratorValue value, Operation *op) const {
1457 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
1458 auto *seq = std::get<RandomizedSequenceStorage *>(value);
1459 if (seq->context != currentContext) {
1460 auto err = op->emitError("attempting to place sequence derived from ")
1461 << seq->sequence->familyName.getValue() << " under context "
1462 << currentContext
1463 << ", but it was previously randomized for context ";
1464 if (seq->context)
1465 err << seq->context;
1466 else
1467 err << "'default'";
1468 return err;
1469 }
1470 return success();
1471 }
1472
1473 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
1474 for (auto val : interVal->sequences)
1475 if (failed(isValidContext(val, op)))
1476 return failure();
1477 return success();
1478 }
1479
1480 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
1481 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
1482 if (failed(isValidContext(seqVal, op)))
1483 return failure();
1484
1485 return DeletionKind::Keep;
1486 }
1487
1488 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1489 SetVector<ElaboratorValue> set;
1490 for (auto val : op.getElements())
1491 set.insert(state.at(val));
1492
1493 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1494 std::move(set), op.getSet().getType());
1495 return DeletionKind::Delete;
1496 }
1497
1498 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
1499 auto set = get<SetStorage *>(op.getSet())->set;
1500
1501 if (set.empty())
1502 return op->emitError("cannot select from an empty set");
1503
1504 size_t selected = testState.rng.getUniformlyInRange(0, set.size() - 1);
1505 state[op.getResult()] = set[selected];
1506 return DeletionKind::Delete;
1507 }
1508
1509 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
1510 auto original = get<SetStorage *>(op.getOriginal())->set;
1511 auto diff = get<SetStorage *>(op.getDiff())->set;
1512
1513 SetVector<ElaboratorValue> result(original);
1514 result.set_subtract(diff);
1515
1516 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1517 std::move(result), op.getResult().getType());
1518 return DeletionKind::Delete;
1519 }
1520
1521 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
1522 SetVector<ElaboratorValue> result;
1523 for (auto set : op.getSets())
1524 result.set_union(get<SetStorage *>(set)->set);
1525
1526 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1527 std::move(result), op.getType());
1528 return DeletionKind::Delete;
1529 }
1530
1531 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1532 auto size = get<SetStorage *>(op.getSet())->set.size();
1533 state[op.getResult()] = size;
1534 return DeletionKind::Delete;
1535 }
1536
1537 // {a0,a1} x {b0,b1} x {c0,c1} -> {(a0), (a1)} -> {(a0,b0), (a0,b1), (a1,b0),
1538 // (a1,b1)} -> {(a0,b0,c0), (a0,b0,c1), (a0,b1,c0), (a0,b1,c1), (a1,b0,c0),
1539 // (a1,b0,c1), (a1,b1,c0), (a1,b1,c1)}
1540 FailureOr<DeletionKind> visitOp(SetCartesianProductOp op) {
1541 SetVector<ElaboratorValue> result;
1542 SmallVector<SmallVector<ElaboratorValue>> tuples;
1543 tuples.push_back({});
1544
1545 for (auto input : op.getInputs()) {
1546 auto &set = get<SetStorage *>(input)->set;
1547 if (set.empty()) {
1548 SetVector<ElaboratorValue> empty;
1549 state[op.getResult()] =
1550 sharedState.internalizer.internalize<SetStorage>(std::move(empty),
1551 op.getType());
1552 return DeletionKind::Delete;
1553 }
1554
1555 for (unsigned i = 0, e = tuples.size(); i < e; ++i) {
1556 for (auto setEl : set.getArrayRef().drop_back()) {
1557 tuples.push_back(tuples[i]);
1558 tuples.back().push_back(setEl);
1559 }
1560 tuples[i].push_back(set.back());
1561 }
1562 }
1563
1564 for (auto &tup : tuples)
1565 result.insert(
1566 sharedState.internalizer.internalize<TupleStorage>(std::move(tup)));
1567
1568 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1569 std::move(result), op.getType());
1570 return DeletionKind::Delete;
1571 }
1572
1573 FailureOr<DeletionKind> visitOp(SetConvertToBagOp op) {
1574 auto set = get<SetStorage *>(op.getInput())->set;
1576 for (auto val : set)
1577 bag.insert({val, 1});
1578 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1579 std::move(bag), op.getType());
1580 return DeletionKind::Delete;
1581 }
1582
1583 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1585 for (auto [val, multiple] :
1586 llvm::zip(op.getElements(), op.getMultiples())) {
1587 // If the multiple is not stored as an AttributeValue, the elaboration
1588 // must have already failed earlier (since we don't have
1589 // unevaluated/opaque values).
1590 bag[state.at(val)] += get<size_t>(multiple);
1591 }
1592
1593 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1594 std::move(bag), op.getType());
1595 return DeletionKind::Delete;
1596 }
1597
1598 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
1599 auto bag = get<BagStorage *>(op.getBag())->bag;
1600
1601 if (bag.empty())
1602 return op->emitError("cannot select from an empty bag");
1603
1604 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1605 prefixSum.reserve(bag.size());
1606 uint32_t accumulator = 0;
1607 for (auto [val, weight] : bag) {
1608 accumulator += weight;
1609 prefixSum.push_back({val, accumulator});
1610 }
1611
1612 auto idx = testState.rng.getUniformlyInRange(0, accumulator - 1);
1613 auto *iter = llvm::upper_bound(
1614 prefixSum, idx,
1615 [](uint32_t a, const std::pair<ElaboratorValue, uint32_t> &b) {
1616 return a < b.second;
1617 });
1618
1619 state[op.getResult()] = iter->first;
1620 return DeletionKind::Delete;
1621 }
1622
1623 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1624 auto original = get<BagStorage *>(op.getOriginal())->bag;
1625 auto diff = get<BagStorage *>(op.getDiff())->bag;
1626
1628 for (const auto &el : original) {
1629 if (!diff.contains(el.first)) {
1630 result.insert(el);
1631 continue;
1632 }
1633
1634 if (op.getInf())
1635 continue;
1636
1637 auto toDiff = diff.lookup(el.first);
1638 if (el.second <= toDiff)
1639 continue;
1640
1641 result.insert({el.first, el.second - toDiff});
1642 }
1643
1644 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1645 std::move(result), op.getType());
1646 return DeletionKind::Delete;
1647 }
1648
1649 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1651 for (auto bag : op.getBags()) {
1652 auto val = get<BagStorage *>(bag)->bag;
1653 for (auto [el, multiple] : val)
1654 result[el] += multiple;
1655 }
1656
1657 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1658 std::move(result), op.getType());
1659 return DeletionKind::Delete;
1660 }
1661
1662 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1663 auto size = get<BagStorage *>(op.getBag())->bag.size();
1664 state[op.getResult()] = size;
1665 return DeletionKind::Delete;
1666 }
1667
1668 FailureOr<DeletionKind> visitOp(BagConvertToSetOp op) {
1669 auto bag = get<BagStorage *>(op.getInput())->bag;
1670 SetVector<ElaboratorValue> set;
1671 for (auto [k, v] : bag)
1672 set.insert(k);
1673 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1674 std::move(set), op.getType());
1675 return DeletionKind::Delete;
1676 }
1677
1678 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1679 auto *val = sharedState.internalizer.create<VirtualRegisterStorage>(
1680 op.getAllowedRegsAttr(), op.getType(), op.getLoc());
1681 state[op.getResult()] = val;
1682 materializer.registerIdentityValue(val);
1683 return DeletionKind::Delete;
1684 }
1685
1686 FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
1687 SmallVector<ElaboratorValue> array;
1688 array.reserve(op.getElements().size());
1689 for (auto val : op.getElements())
1690 array.emplace_back(state.at(val));
1691
1692 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1693 op.getResult().getType(), std::move(array));
1694 return DeletionKind::Delete;
1695 }
1696
1697 FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
1698 auto array = get<ArrayStorage *>(op.getArray())->array;
1699 size_t idx = get<size_t>(op.getIndex());
1700
1701 if (array.size() <= idx)
1702 return op->emitError("invalid to access index ")
1703 << idx << " of an array with " << array.size() << " elements";
1704
1705 state[op.getResult()] = array[idx];
1706 return DeletionKind::Delete;
1707 }
1708
1709 FailureOr<DeletionKind> visitOp(ArrayInjectOp op) {
1710 auto arrayOpaque = state.at(op.getArray());
1711 auto idxOpaque = state.at(op.getIndex());
1712 if (isSymbolic(arrayOpaque) || isSymbolic(idxOpaque))
1713 return visitOpGeneric(op);
1714
1715 auto array = std::get<ArrayStorage *>(arrayOpaque)->array;
1716 size_t idx = std::get<size_t>(idxOpaque);
1717
1718 if (array.size() <= idx)
1719 return op->emitError("invalid to access index ")
1720 << idx << " of an array with " << array.size() << " elements";
1721
1722 array[idx] = state.at(op.getValue());
1723 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1724 op.getResult().getType(), std::move(array));
1725 return DeletionKind::Delete;
1726 }
1727
1728 FailureOr<DeletionKind> visitOp(ArrayAppendOp op) {
1729 auto array = std::get<ArrayStorage *>(state.at(op.getArray()))->array;
1730 array.push_back(state.at(op.getElement()));
1731 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1732 op.getResult().getType(), std::move(array));
1733 return DeletionKind::Delete;
1734 }
1735
1736 FailureOr<DeletionKind> visitOp(ArraySizeOp op) {
1737 auto array = get<ArrayStorage *>(op.getArray())->array;
1738 state[op.getResult()] = array.size();
1739 return DeletionKind::Delete;
1740 }
1741
1742 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1743 auto *val = sharedState.internalizer.create<UniqueLabelStorage>(
1744 state.at(op.getNamePrefix()), op.getLoc());
1745 state[op.getLabel()] = val;
1746 materializer.registerIdentityValue(val);
1747 return DeletionKind::Delete;
1748 }
1749
1750 FailureOr<DeletionKind> visitOp(RandomScopeOp op) {
1751 auto getNestedRng = [&]() -> RngScope {
1752 if (op.getSeed())
1753 return RngScope(op.getSeed()->getZExtValue());
1754 return testState.rng.getNested();
1755 };
1756
1757 RngScope nestedRng = getNestedRng();
1758 std::swap(testState.rng, nestedRng);
1759
1760 // Elaborate the body region. The region has no arguments and yields values
1761 // that become the results of this operation.
1762 // We reuse the same elaborator for the nested region because we need access
1763 // to the elaborated values outside the nested region (since it is not
1764 // isolated from above) and we want to delete the random_scope operation
1765 // entirely, thus don't need a new materializer instance.
1766 SmallVector<ElaboratorValue> yieldedVals;
1767 if (failed(elaborate(op.getBodyRegion(), {}, /*keepTerminator=*/false,
1768 yieldedVals)))
1769 return failure();
1770
1771 // Restore the previous RNG scope.
1772 std::swap(testState.rng, nestedRng);
1773
1774 // Map the results of the random_scope to the yielded values.
1775 for (auto [res, out] : llvm::zip(op.getResults(), yieldedVals))
1776 state[res] = out;
1777
1778 return DeletionKind::Delete;
1779 }
1780
1781 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1782 size_t lower = get<size_t>(op.getLowerBound());
1783 size_t upper = get<size_t>(op.getUpperBound());
1784 if (lower > upper)
1785 return op->emitError("cannot select a number from an empty range");
1786
1787 state[op.getResult()] =
1788 size_t(testState.rng.getUniformlyInRange(lower, upper));
1789 return DeletionKind::Delete;
1790 }
1791
1792 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1793 size_t input = get<size_t>(op.getInput());
1794 auto width = op.getType().getWidth();
1795 auto emitError = [&]() { return op->emitError(); };
1796 if (input > APInt::getAllOnes(width).getZExtValue())
1797 return emitError() << "cannot represent " << input << " with " << width
1798 << " bits";
1799
1800 state[op.getResult()] =
1801 ImmediateAttr::get(op.getContext(), APInt(width, input));
1802 return DeletionKind::Delete;
1803 }
1804
1805 FailureOr<DeletionKind> visitOp(OnContextOp op) {
1806 ContextResourceAttrInterface from = currentContext,
1807 to = cast<ContextResourceAttrInterface>(
1808 get<TypedAttr>(op.getContext()));
1809 if (!currentContext)
1810 from = DefaultContextAttr::get(op->getContext(), to.getType());
1811
1812 auto emitError = [&]() {
1813 auto diag = op.emitError();
1814 diag.attachNote(op.getLoc())
1815 << "while materializing value for context switching for " << op;
1816 return diag;
1817 };
1818
1819 if (from == to) {
1820 Value seqVal = materializer.materialize(
1821 get<SequenceStorage *>(op.getSequence()), op.getLoc(), emitError);
1822 if (!seqVal)
1823 return failure();
1824
1825 Value randSeqVal =
1826 materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1827 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1828 return DeletionKind::Delete;
1829 }
1830
1831 // Switch to the desired context.
1832 // First, check if a context switch is registered that has the concrete
1833 // context as source and target.
1834 auto *iter = testState.contextSwitches.find({from, to});
1835
1836 // Try with 'any' context as target and the concrete context as source.
1837 if (iter == testState.contextSwitches.end())
1838 iter = testState.contextSwitches.find(
1839 {from, AnyContextAttr::get(op->getContext(), to.getType())});
1840
1841 // Try with 'any' context as source and the concrete context as target.
1842 if (iter == testState.contextSwitches.end())
1843 iter = testState.contextSwitches.find(
1844 {AnyContextAttr::get(op->getContext(), from.getType()), to});
1845
1846 // Try with 'any' context for both the source and the target.
1847 if (iter == testState.contextSwitches.end())
1848 iter = testState.contextSwitches.find(
1849 {AnyContextAttr::get(op->getContext(), from.getType()),
1850 AnyContextAttr::get(op->getContext(), to.getType())});
1851
1852 // Otherwise, fail with an error because we couldn't find a user
1853 // specification on how to switch between the requested contexts.
1854 // NOTE: we could think about supporting context switching via intermediate
1855 // context, i.e., treat it as a transitive relation.
1856 if (iter == testState.contextSwitches.end())
1857 return op->emitError("no context transition registered to switch from ")
1858 << from << " to " << to;
1859
1860 auto familyName = iter->second->familyName;
1861 SmallVector<ElaboratorValue> args{from, to,
1862 get<SequenceStorage *>(op.getSequence())};
1863 auto *seq = sharedState.internalizer.internalize<SequenceStorage>(
1864 familyName, std::move(args));
1865 auto *randSeq = sharedState.internalizer.create<RandomizedSequenceStorage>(
1866 to, seq, op.getLoc());
1867 materializer.registerIdentityValue(randSeq);
1868 Value seqVal = materializer.materialize(randSeq, op.getLoc(), emitError);
1869 if (!seqVal)
1870 return failure();
1871
1872 materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
1873 return DeletionKind::Delete;
1874 }
1875
1876 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1877 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1878 get<SequenceStorage *>(op.getSequence());
1879 return DeletionKind::Delete;
1880 }
1881
1882 FailureOr<DeletionKind> visitOp(MemoryBlockDeclareOp op) {
1883 auto *val = sharedState.internalizer.create<MemoryBlockStorage>(
1884 op.getBaseAddress(), op.getEndAddress(), op.getType(), op.getLoc());
1885 state[op.getResult()] = val;
1886 materializer.registerIdentityValue(val);
1887 return DeletionKind::Delete;
1888 }
1889
1890 FailureOr<DeletionKind> visitOp(MemoryAllocOp op) {
1891 size_t size = get<size_t>(op.getSize());
1892 size_t alignment = get<size_t>(op.getAlignment());
1893 auto *memBlock = get<MemoryBlockStorage *>(op.getMemoryBlock());
1894 auto *val = sharedState.internalizer.create<MemoryStorage>(
1895 memBlock, size, alignment, op.getLoc());
1896 state[op.getResult()] = val;
1897 materializer.registerIdentityValue(val);
1898 return DeletionKind::Delete;
1899 }
1900
1901 FailureOr<DeletionKind> visitOp(MemorySizeOp op) {
1902 auto *memory = get<MemoryStorage *>(op.getMemory());
1903 state[op.getResult()] = memory->size;
1904 return DeletionKind::Delete;
1905 }
1906
1907 FailureOr<DeletionKind> visitOp(TupleCreateOp op) {
1908 SmallVector<ElaboratorValue> values;
1909 values.reserve(op.getElements().size());
1910 for (auto el : op.getElements())
1911 values.push_back(state.at(el));
1912
1913 state[op.getResult()] =
1914 sharedState.internalizer.internalize<TupleStorage>(std::move(values));
1915 return DeletionKind::Delete;
1916 }
1917
1918 FailureOr<DeletionKind> visitOp(TupleExtractOp op) {
1919 auto *tuple = get<TupleStorage *>(op.getTuple());
1920 state[op.getResult()] = tuple->values[op.getIndex().getZExtValue()];
1921 return DeletionKind::Delete;
1922 }
1923
1924 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1925 bool cond = get<bool>(op.getCondition());
1926 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1927 if (toElaborate.empty())
1928 return DeletionKind::Delete;
1929
1930 // Just reuse this elaborator for the nested region because we need access
1931 // to the elaborated values outside the nested region (since it is not
1932 // isolated from above) and we want to materialize the region inline, thus
1933 // don't need a new materializer instance.
1934 SmallVector<ElaboratorValue> yieldedVals;
1935 if (failed(
1936 elaborate(toElaborate, {}, /*keepTerminator=*/false, yieldedVals)))
1937 return failure();
1938
1939 // Map the results of the 'scf.if' to the yielded values.
1940 for (auto [res, out] : llvm::zip(op.getResults(), yieldedVals))
1941 state[res] = out;
1942
1943 return DeletionKind::Delete;
1944 }
1945
1946 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1947 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1948 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1949 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1950 return op->emitOpError("can only elaborate index type iterator");
1951
1952 auto lowerBound = get<size_t>(op.getLowerBound());
1953 auto step = get<size_t>(op.getStep());
1954 auto upperBound = get<size_t>(op.getUpperBound());
1955
1956 // Prepare for first iteration by assigning the nested regions block
1957 // arguments. We can just reuse this elaborator because we need access to
1958 // values elaborated in the parent region anyway and materialize everything
1959 // inline (i.e., don't need a new materializer).
1960 state[op.getInductionVar()] = lowerBound;
1961 for (auto [iterArg, initArg] :
1962 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1963 state[iterArg] = state.at(initArg);
1964
1965 // This loop performs the actual 'scf.for' loop iterations.
1966 SmallVector<ElaboratorValue> yieldedVals;
1967 for (size_t i = lowerBound; i < upperBound; i += step) {
1968 yieldedVals.clear();
1969 if (failed(elaborate(op.getBodyRegion(), {}, /*keepTerminator=*/false,
1970 yieldedVals)))
1971 return failure();
1972
1973 // Prepare for the next iteration by updating the mapping of the nested
1974 // regions block arguments
1975 state[op.getInductionVar()] = i + step;
1976 for (auto [iterArg, prevIterArg] :
1977 llvm::zip(op.getRegionIterArgs(), yieldedVals))
1978 state[iterArg] = prevIterArg;
1979 }
1980
1981 // Transfer the previously yielded values to the for loop result values.
1982 for (auto [res, iterArg] :
1983 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1984 state[res] = state.at(iterArg);
1985
1986 return DeletionKind::Delete;
1987 }
1988
1989 FailureOr<DeletionKind> visitOp(arith::AddIOp op) {
1990 if (!isa<IndexType>(op.getType()))
1991 return visitOpGeneric(op);
1992
1993 size_t lhs = get<size_t>(op.getLhs());
1994 size_t rhs = get<size_t>(op.getRhs());
1995 state[op.getResult()] = lhs + rhs;
1996 return DeletionKind::Delete;
1997 }
1998
1999 FailureOr<DeletionKind> visitOp(arith::AndIOp op) {
2000 if (!op.getType().isSignlessInteger(1))
2001 return visitOpGeneric(op);
2002
2003 bool lhs = get<bool>(op.getLhs());
2004 bool rhs = get<bool>(op.getRhs());
2005 state[op.getResult()] = lhs && rhs;
2006 return DeletionKind::Delete;
2007 }
2008
2009 FailureOr<DeletionKind> visitOp(arith::XOrIOp op) {
2010 if (!op.getType().isSignlessInteger(1))
2011 return visitOpGeneric(op);
2012
2013 bool lhs = get<bool>(op.getLhs());
2014 bool rhs = get<bool>(op.getRhs());
2015 state[op.getResult()] = lhs != rhs;
2016 return DeletionKind::Delete;
2017 }
2018
2019 FailureOr<DeletionKind> visitOp(arith::OrIOp op) {
2020 if (!op.getType().isSignlessInteger(1))
2021 return visitOpGeneric(op);
2022
2023 bool lhs = get<bool>(op.getLhs());
2024 bool rhs = get<bool>(op.getRhs());
2025 state[op.getResult()] = lhs || rhs;
2026 return DeletionKind::Delete;
2027 }
2028
2029 FailureOr<DeletionKind> visitOp(arith::SelectOp op) {
2030 auto condOpaque = state.at(op.getCondition());
2031 if (isSymbolic(condOpaque))
2032 return visitOpGeneric(op);
2033
2034 bool cond = std::get<bool>(condOpaque);
2035 auto trueVal = state.at(op.getTrueValue());
2036 auto falseVal = state.at(op.getFalseValue());
2037 state[op.getResult()] = cond ? trueVal : falseVal;
2038 return DeletionKind::Delete;
2039 }
2040
2041 FailureOr<DeletionKind> visitOp(index::AddOp op) {
2042 size_t lhs = get<size_t>(op.getLhs());
2043 size_t rhs = get<size_t>(op.getRhs());
2044 state[op.getResult()] = lhs + rhs;
2045 return DeletionKind::Delete;
2046 }
2047
2048 FailureOr<DeletionKind> visitOp(index::SubOp op) {
2049 size_t lhs = get<size_t>(op.getLhs());
2050 size_t rhs = get<size_t>(op.getRhs());
2051 state[op.getResult()] = lhs - rhs;
2052 return DeletionKind::Delete;
2053 }
2054
2055 FailureOr<DeletionKind> visitOp(index::MulOp op) {
2056 size_t lhs = get<size_t>(op.getLhs());
2057 size_t rhs = get<size_t>(op.getRhs());
2058 state[op.getResult()] = lhs * rhs;
2059 return DeletionKind::Delete;
2060 }
2061
2062 FailureOr<DeletionKind> visitOp(index::DivUOp op) {
2063 size_t lhs = get<size_t>(op.getLhs());
2064 size_t rhs = get<size_t>(op.getRhs());
2065
2066 if (rhs == 0)
2067 return op->emitOpError("attempted division by zero");
2068
2069 state[op.getResult()] = lhs / rhs;
2070 return DeletionKind::Delete;
2071 }
2072
2073 FailureOr<DeletionKind> visitOp(index::CeilDivUOp op) {
2074 size_t lhs = get<size_t>(op.getLhs());
2075 size_t rhs = get<size_t>(op.getRhs());
2076
2077 if (rhs == 0)
2078 return op->emitOpError("attempted division by zero");
2079
2080 if (lhs == 0)
2081 state[op.getResult()] = (lhs + rhs - 1) / rhs;
2082 else
2083 state[op.getResult()] = 1 + ((lhs - 1) / rhs);
2084
2085 return DeletionKind::Delete;
2086 }
2087
2088 FailureOr<DeletionKind> visitOp(index::RemUOp op) {
2089 size_t lhs = get<size_t>(op.getLhs());
2090 size_t rhs = get<size_t>(op.getRhs());
2091
2092 if (rhs == 0)
2093 return op->emitOpError("attempted division by zero");
2094
2095 state[op.getResult()] = lhs % rhs;
2096 return DeletionKind::Delete;
2097 }
2098
2099 FailureOr<DeletionKind> visitOp(index::AndOp op) {
2100 size_t lhs = get<size_t>(op.getLhs());
2101 size_t rhs = get<size_t>(op.getRhs());
2102 state[op.getResult()] = lhs & rhs;
2103 return DeletionKind::Delete;
2104 }
2105
2106 FailureOr<DeletionKind> visitOp(index::OrOp op) {
2107 size_t lhs = get<size_t>(op.getLhs());
2108 size_t rhs = get<size_t>(op.getRhs());
2109 state[op.getResult()] = lhs | rhs;
2110 return DeletionKind::Delete;
2111 }
2112
2113 FailureOr<DeletionKind> visitOp(index::XOrOp op) {
2114 size_t lhs = get<size_t>(op.getLhs());
2115 size_t rhs = get<size_t>(op.getRhs());
2116 state[op.getResult()] = lhs ^ rhs;
2117 return DeletionKind::Delete;
2118 }
2119
2120 FailureOr<DeletionKind> visitOp(index::ShlOp op) {
2121 size_t lhs = get<size_t>(op.getLhs());
2122 size_t rhs = get<size_t>(op.getRhs());
2123 state[op.getResult()] = lhs << rhs;
2124 return DeletionKind::Delete;
2125 }
2126
2127 FailureOr<DeletionKind> visitOp(index::ShrUOp op) {
2128 size_t lhs = get<size_t>(op.getLhs());
2129 size_t rhs = get<size_t>(op.getRhs());
2130 state[op.getResult()] = lhs >> rhs;
2131 return DeletionKind::Delete;
2132 }
2133
2134 FailureOr<DeletionKind> visitOp(index::MaxUOp op) {
2135 size_t lhs = get<size_t>(op.getLhs());
2136 size_t rhs = get<size_t>(op.getRhs());
2137 state[op.getResult()] = std::max(lhs, rhs);
2138 return DeletionKind::Delete;
2139 }
2140
2141 FailureOr<DeletionKind> visitOp(index::MinUOp op) {
2142 size_t lhs = get<size_t>(op.getLhs());
2143 size_t rhs = get<size_t>(op.getRhs());
2144 state[op.getResult()] = std::min(lhs, rhs);
2145 return DeletionKind::Delete;
2146 }
2147
2148 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
2149 size_t lhs = get<size_t>(op.getLhs());
2150 size_t rhs = get<size_t>(op.getRhs());
2151 bool result;
2152 switch (op.getPred()) {
2153 case index::IndexCmpPredicate::EQ:
2154 result = lhs == rhs;
2155 break;
2156 case index::IndexCmpPredicate::NE:
2157 result = lhs != rhs;
2158 break;
2159 case index::IndexCmpPredicate::ULT:
2160 result = lhs < rhs;
2161 break;
2162 case index::IndexCmpPredicate::ULE:
2163 result = lhs <= rhs;
2164 break;
2165 case index::IndexCmpPredicate::UGT:
2166 result = lhs > rhs;
2167 break;
2168 case index::IndexCmpPredicate::UGE:
2169 result = lhs >= rhs;
2170 break;
2171 default:
2172 return op->emitOpError("elaboration not supported");
2173 }
2174 state[op.getResult()] = result;
2175 return DeletionKind::Delete;
2176 }
2177
2178 bool isSymbolic(ElaboratorValue val) {
2179 return std::holds_alternative<SymbolicComputationWithIdentityValue *>(
2180 val) ||
2181 std::holds_alternative<SymbolicComputationWithIdentityStorage *>(
2182 val) ||
2183 std::holds_alternative<SymbolicComputationStorage *>(val);
2184 }
2185
2186 bool isSymbolic(Operation *op) {
2187 return llvm::any_of(op->getOperands(), [&](auto operand) {
2188 auto val = state.at(operand);
2189 return isSymbolic(val);
2190 });
2191 }
2192
2193 /// If all operands are constants, try to fold the operation and register its
2194 /// result values with the folded results in the elaborator state.
2195 /// Returns 'false' if operation has to be handled symbolically.
2196 bool attemptConcreteCase(Operation *op) {
2197 if (op->getNumResults() == 0)
2198 return false;
2199
2200 SmallVector<Attribute> operands;
2201 for (auto operand : op->getOperands()) {
2202 auto evalValue = state[operand];
2203 auto attr = elabValConverter.convert(evalValue);
2204 operands.push_back(attr);
2205 }
2206
2207 SmallVector<OpFoldResult> results;
2208 if (failed(op->fold(operands, results)))
2209 return false;
2210
2211 if (results.size() != op->getNumResults())
2212 return false;
2213
2214 for (auto [res, val] : llvm::zip(results, op->getResults())) {
2215 auto attr = llvm::dyn_cast_or_null<TypedAttr>(res.dyn_cast<Attribute>());
2216 if (!attr)
2217 return false;
2218
2219 if (attr.getType() != val.getType())
2220 return false;
2221
2222 // Try to convert the attribute to an ElaboratorValue
2223 auto converted = attrConverter.convert(attr);
2224 if (succeeded(converted)) {
2225 state[val] = *converted;
2226 continue;
2227 }
2228 return false;
2229 }
2230
2231 return true;
2232 }
2233
2234 FailureOr<DeletionKind> visitOpGeneric(Operation *op) {
2235 if (op->getNumResults() == 0)
2236 return DeletionKind::Keep;
2237
2238 if (attemptConcreteCase(op))
2239 return DeletionKind::Delete;
2240
2241 if (mlir::isMemoryEffectFree(op)) {
2242 if (op->getNumResults() != 1)
2243 return op->emitOpError(
2244 "symbolic elaboration of memory-effect-free operations with "
2245 "multiple results not supported");
2246
2247 state[op->getResult(0)] =
2248 sharedState.internalizer.internalize<SymbolicComputationStorage>(
2249 state, op);
2250 return DeletionKind::Delete;
2251 }
2252
2253 // We assume that reordering operations with only allocate effects is
2254 // allowed.
2255 // FIXME: this is not how the MLIR MemoryEffects interface intends it.
2256 // We should create our own interface/trait for that use-case.
2257 // Or modify the elaboration pass to keep track of the ordering of such
2258 // instructions and materialize all operations that are not already
2259 // materialized but have to happen before the current alloc operation to be
2260 // materialized.
2261 bool onlyAlloc = mlir::hasSingleEffect<mlir::MemoryEffects::Allocate>(op);
2262 onlyAlloc |= isa<ValidateOp>(op);
2263
2264 auto *validationVal =
2265 sharedState.internalizer.create<SymbolicComputationWithIdentityStorage>(
2266 state, op);
2267 materializer.registerIdentityValue(validationVal);
2268 state[op->getResult(0)] = validationVal;
2269
2270 for (auto [i, res] : llvm::enumerate(op->getResults())) {
2271 if (i == 0)
2272 continue;
2273 auto *val =
2274 sharedState.internalizer.create<SymbolicComputationWithIdentityValue>(
2275 res.getType(), validationVal, i);
2276 state[res] = val;
2277 materializer.registerIdentityValue(val);
2278 }
2279 return onlyAlloc ? DeletionKind::Delete : DeletionKind::Keep;
2280 }
2281
2282 bool supportsSymbolicValuesNonGenerically(Operation *op) {
2283 return isa<SubstituteSequenceOp, ArrayCreateOp, ArrayInjectOp,
2284 TupleCreateOp, arith::SelectOp>(op);
2285 }
2286
2287 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
2288 if (isSymbolic(op) && !supportsSymbolicValuesNonGenerically(op))
2289 return visitOpGeneric(op);
2290
2291 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
2292 .Case<
2293 // Arith ops
2294 arith::AddIOp, arith::XOrIOp, arith::AndIOp, arith::OrIOp,
2295 arith::SelectOp,
2296 // Index ops
2297 index::AddOp, index::SubOp, index::MulOp, index::DivUOp,
2298 index::CeilDivUOp, index::RemUOp, index::AndOp, index::OrOp,
2299 index::XOrOp, index::ShlOp, index::ShrUOp, index::MaxUOp,
2300 index::MinUOp, index::CmpOp,
2301 // SCF ops
2302 scf::IfOp, scf::ForOp>([&](auto op) { return visitOp(op); })
2303 .Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); });
2304 }
2305
2306 // NOLINTNEXTLINE(misc-no-recursion)
2307 LogicalResult elaborate(Region &region,
2308 ArrayRef<ElaboratorValue> regionArguments,
2309 bool keepTerminator,
2310 SmallVector<ElaboratorValue> &terminatorOperands) {
2311 if (region.getBlocks().size() > 1)
2312 return region.getParentOp()->emitOpError(
2313 "regions with more than one block are not supported");
2314
2315 for (auto [arg, elabArg] :
2316 llvm::zip(region.getArguments(), regionArguments))
2317 state[arg] = elabArg;
2318
2319 Block *block = &region.front();
2320 auto iter = keepTerminator ? *block : block->without_terminator();
2321 for (auto &op : iter) {
2322 auto result = dispatchOpVisitor(&op);
2323 if (failed(result))
2324 return failure();
2325
2326 if (*result == DeletionKind::Keep)
2327 if (failed(materializer.materialize(&op, state)))
2328 return failure();
2329
2330 LLVM_DEBUG({
2331 llvm::dbgs() << "Elaborated " << op << " to\n[";
2332
2333 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) {
2334 if (state.contains(res))
2335 llvm::dbgs() << state.at(res);
2336 else
2337 llvm::dbgs() << "unknown";
2338 });
2339
2340 llvm::dbgs() << "]\n\n";
2341 });
2342 }
2343
2344 if (!block->empty() && block->back().hasTrait<OpTrait::IsTerminator>()) {
2345 auto *terminator = block->getTerminator();
2346 for (auto val : terminator->getOperands())
2347 terminatorOperands.push_back(state.at(val));
2348
2349 if (!keepTerminator && materializer.isInPlace(terminator))
2350 terminator->erase();
2351 }
2352
2353 return success();
2354 }
2355
2356private:
2357 // State to be shared between all elaborator instances.
2358 SharedState &sharedState;
2359
2360 // State specific to an RTG test and the sequences placed within it.
2361 TestState &testState;
2362
2363 // Allows us to materialize ElaboratorValues to the IR operations necessary to
2364 // obtain an SSA value representing that elaborated value.
2365 Materializer &materializer;
2366
2367 // A map from SSA values to a pointer of an interned elaborator value.
2368 DenseMap<Value, ElaboratorValue> state;
2369
2370 // The current context we are elaborating under.
2371 ContextResourceAttrInterface currentContext;
2372
2373 // Allows us to convert attributes to ElaboratorValues.
2374 AttributeToElaboratorValueConverter attrConverter;
2375
2376 // Allows us to convert ElaboratorValues to attributes.
2377 ElaboratorValueToAttributeConverter elabValConverter;
2378};
2379} // namespace
2380
2381SequenceOp
2382Materializer::elaborateSequence(const RandomizedSequenceStorage *seq,
2383 SmallVector<ElaboratorValue> &elabArgs) {
2384 auto familyOp =
2385 sharedState.table.lookup<SequenceOp>(seq->sequence->familyName);
2386 // TODO: don't clone if this is the only remaining reference to this
2387 // sequence
2388 OpBuilder builder(familyOp);
2389 auto seqOp = builder.cloneWithoutRegions(familyOp);
2390 auto name = sharedState.names.newName(seq->sequence->familyName.getValue());
2391 seqOp.setSymName(name);
2392 seqOp.getBodyRegion().emplaceBlock();
2393 sharedState.table.insert(seqOp);
2394 assert(seqOp.getSymName() == name && "should not have been renamed");
2395
2396 LLVM_DEBUG(llvm::dbgs() << "\n=== Elaborating sequence family @"
2397 << familyOp.getSymName() << " into @"
2398 << seqOp.getSymName() << " under context "
2399 << seq->context << "\n\n");
2400
2401 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()), testState,
2402 sharedState, elabArgs);
2403 Elaborator elaborator(sharedState, testState, materializer, seq->context);
2404 SmallVector<ElaboratorValue> yieldedVals;
2405 if (failed(elaborator.elaborate(familyOp.getBodyRegion(), seq->sequence->args,
2406 /*keepTerminator=*/false, yieldedVals)))
2407 return {};
2408
2409 seqOp.setSequenceType(
2410 SequenceType::get(builder.getContext(), materializer.getBlockArgTypes()));
2411 materializer.finalize();
2412
2413 return seqOp;
2414}
2415
2416//===----------------------------------------------------------------------===//
2417// Elaborator Pass
2418//===----------------------------------------------------------------------===//
2419
2420namespace {
2421struct ElaborationPass
2422 : public rtg::impl::ElaborationPassBase<ElaborationPass> {
2423 using Base::Base;
2424
2425 void runOnOperation() override;
2426 void matchTestsAgainstTargets(SymbolTable &table);
2427 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
2428};
2429} // namespace
2430
2431void ElaborationPass::runOnOperation() {
2432 auto moduleOp = getOperation();
2433 SymbolTable table(moduleOp);
2434
2435 matchTestsAgainstTargets(table);
2436
2437 if (failed(elaborateModule(moduleOp, table)))
2438 return signalPassFailure();
2439}
2440
2441void ElaborationPass::matchTestsAgainstTargets(SymbolTable &table) {
2442 auto moduleOp = getOperation();
2443
2444 for (auto test : llvm::make_early_inc_range(moduleOp.getOps<TestOp>())) {
2445 if (test.getTargetAttr())
2446 continue;
2447
2448 bool matched = false;
2449
2450 for (auto target : moduleOp.getOps<TargetOp>()) {
2451 // Check if the target type is a subtype of the test's target type
2452 // This means that for each entry in the test's target type, there must be
2453 // a corresponding entry with the same name and type in the target's type
2454 bool isSubtype = true;
2455 auto testEntries = test.getTargetType().getEntries();
2456 auto targetEntries = target.getTarget().getEntries();
2457
2458 // Check if target is a subtype of test requirements
2459 // Since entries are sorted by name, we can do this in a single pass
2460 size_t targetIdx = 0;
2461 for (auto testEntry : testEntries) {
2462 // Find the matching entry in target entries.
2463 while (targetIdx < targetEntries.size() &&
2464 targetEntries[targetIdx].name.getValue() <
2465 testEntry.name.getValue())
2466 targetIdx++;
2467
2468 // Check if we found a matching entry with the same name and type
2469 if (targetIdx >= targetEntries.size() ||
2470 targetEntries[targetIdx].name != testEntry.name ||
2471 targetEntries[targetIdx].type != testEntry.type) {
2472 isSubtype = false;
2473 break;
2474 }
2475 }
2476
2477 if (!isSubtype)
2478 continue;
2479
2480 IRRewriter rewriter(test);
2481 // Create a new test for the matched target
2482 auto newTest = cast<TestOp>(test->clone());
2483 newTest.setSymName(test.getSymName().str() + "_" +
2484 target.getSymName().str());
2485
2486 // Set the target symbol specifying that this test is only suitable for
2487 // that target.
2488 newTest.setTargetAttr(target.getSymNameAttr());
2489
2490 table.insert(newTest, rewriter.getInsertionPoint());
2491 matched = true;
2492 }
2493
2494 if (matched || deleteUnmatchedTests)
2495 test->erase();
2496 }
2497}
2498
2499static bool onlyLegalToMaterializeInTarget(Type type) {
2500 return isa<MemoryBlockType, ContextResourceTypeInterface>(type);
2501}
2502
2503LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
2504 SymbolTable &table) {
2505 SharedState state(moduleOp.getContext(), table);
2506
2507 // Update the name cache
2508 state.names.add(moduleOp);
2509
2510 struct TargetElabResult {
2511 TargetElabResult(DictType targetType, uint32_t seed)
2512 : targetType(targetType), testState(seed) {}
2513
2514 DictType targetType;
2515 SmallVector<ElaboratorValue> yields;
2516 TestState testState;
2517 };
2518
2519 // Map to store elaborated targets
2520 DenseMap<StringAttr, TargetElabResult> targetMap;
2521 for (auto targetOp : moduleOp.getOps<TargetOp>()) {
2522 LLVM_DEBUG(llvm::dbgs() << "=== Elaborating target @"
2523 << targetOp.getSymName() << "\n\n");
2524
2525 auto [it, inserted] = targetMap.try_emplace(targetOp.getSymNameAttr(),
2526 targetOp.getTarget(), seed);
2527 auto &result = it->second;
2528
2529 SmallVector<ElaboratorValue> blockArgs;
2530 Materializer targetMaterializer(OpBuilder::atBlockBegin(targetOp.getBody()),
2531 result.testState, state, blockArgs);
2532 Elaborator targetElaborator(state, result.testState, targetMaterializer);
2533
2534 // Elaborate the target
2535 if (failed(targetElaborator.elaborate(targetOp.getBodyRegion(), {},
2536 /*keepTerminator=*/true,
2537 result.yields)))
2538 return failure();
2539
2540 targetMaterializer.finalize();
2541 }
2542
2543 // Initialize the worklist with the test ops since they cannot be placed by
2544 // other ops.
2545 for (auto testOp : moduleOp.getOps<TestOp>()) {
2546 // Skip tests without a target attribute - these couldn't be matched
2547 // against any target but can be useful to keep around for reporting
2548 // purposes.
2549 if (!testOp.getTargetAttr())
2550 continue;
2551
2552 LLVM_DEBUG(llvm::dbgs()
2553 << "\n=== Elaborating test @" << testOp.getTemplateName()
2554 << " for target @" << *testOp.getTarget() << "\n\n");
2555
2556 // Get the target for this test
2557 auto &targetResult = targetMap.at(testOp.getTargetAttr());
2558 TestState testState(seed);
2559 testState.contextSwitches = targetResult.testState.contextSwitches;
2560 testState.name = testOp.getSymNameAttr();
2561
2562 SmallVector<ElaboratorValue> filteredYields;
2563 unsigned i = 0;
2564 for (auto [entry, yield] :
2565 llvm::zip(targetResult.targetType.getEntries(), targetResult.yields)) {
2566 if (i >= testOp.getTargetType().getEntries().size())
2567 break;
2568
2569 if (entry.name == testOp.getTargetType().getEntries()[i].name) {
2570 filteredYields.push_back(yield);
2571 ++i;
2572 }
2573 }
2574
2575 // Now elaborate the test with the same state, passing the target yield
2576 // values as arguments
2577 SmallVector<ElaboratorValue> blockArgs;
2578 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()),
2579 testState, state, blockArgs);
2580
2581 for (auto [arg, val] :
2582 llvm::zip(testOp.getBody()->getArguments(), filteredYields))
2583 if (onlyLegalToMaterializeInTarget(arg.getType()))
2584 materializer.map(val, arg);
2585
2586 Elaborator elaborator(state, testState, materializer);
2587 SmallVector<ElaboratorValue> ignore;
2588 if (failed(elaborator.elaborate(testOp.getBodyRegion(), filteredYields,
2589 /*keepTerminator=*/false, ignore)))
2590 return failure();
2591
2592 materializer.finalize();
2593 }
2594
2595 return success();
2596}
assert(baseType &&"element must be base type")
static bool onlyLegalToMaterializeInTarget(Type type)
#define VISIT_UNSUPPORTED(STORAGETYPE)
static void print(TypedAttr val, llvm::raw_ostream &os)
static LogicalResult convert(arc::ExecuteOp op, arc::ExecuteOp::Adaptor adaptor, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:212
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
This helps visit TypeOp nodes.
Definition RTGVisitors.h:29
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:96
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:56
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static llvm::hash_code hash_value(const ModulePort &port)
Definition HWTypes.h:39
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
@ Delete
Erase the matched ops.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition Utils.h:36
Definition rtg.py:1
Definition seq.py:1
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getHashValue(const bool &val)