Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ElaborationPass.cpp
Go to the documentation of this file.
1//===- ElaborationPass.cpp - RTG ElaborationPass implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass elaborates the random parts of the RTG dialect.
10// It performs randomization top-down, i.e., random constructs in a sequence
11// that is invoked multiple times can yield different randomization results
// for each invocation.
13//
14//===----------------------------------------------------------------------===//
15
21#include "mlir/Dialect/Index/IR/IndexDialect.h"
22#include "mlir/Dialect/Index/IR/IndexOps.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/IR/IRMapping.h"
25#include "mlir/IR/PatternMatch.h"
26#include "llvm/ADT/DenseMapInfoVariant.h"
27#include "llvm/Support/Debug.h"
28#include <queue>
29#include <random>
30
31namespace circt {
32namespace rtg {
33#define GEN_PASS_DEF_ELABORATIONPASS
34#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
35} // namespace rtg
36} // namespace circt
37
38using namespace mlir;
39using namespace circt;
40using namespace circt::rtg;
41using llvm::MapVector;
42
43#define DEBUG_TYPE "rtg-elaboration"
44
45//===----------------------------------------------------------------------===//
46// Uniform Distribution Helper
47//
48// Simplified version of
49// https://github.com/llvm/llvm-project/blob/main/libcxx/include/__random/uniform_int_distribution.h
50// We use our custom version here to get the same results when compiled with
51// different compiler versions and standard libraries.
52//===----------------------------------------------------------------------===//
53
/// Compute the mask applied to each 32-bit engine output when `w` random bits
/// are requested (mirrors libc++'s `__independent_bits_engine`).
///
/// The `w` bits are gathered in `n = ceil(w / 32)` chunks of `w / n` bits; the
/// returned mask keeps the low `w / n` bits of a 32-bit value. A width of 0
/// yields an empty mask instead of dividing by zero.
static uint32_t computeMask(size_t w) {
  if (w == 0)
    return 0;

  // Number of 32-bit chunks needed to gather `w` bits.
  size_t n = w / 32 + (w % 32 != 0);
  // Bits taken from each chunk; in [1, 32] for any w > 0.
  size_t w0 = w / n;
  return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
}
59
60/// Get a number uniformly at random in the in specified range.
61static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) {
62 const uint32_t diff = b - a + 1;
63 if (diff == 1)
64 return a;
65
66 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
67 if (diff == 0)
68 return rng();
69
70 uint32_t width = digits - llvm::countl_zero(diff) - 1;
71 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
72 ++width;
73
74 uint32_t mask = computeMask(diff);
75 uint32_t u;
76 do {
77 u = rng() & mask;
78 } while (u >= diff);
79
80 return u + a;
81}
82
83//===----------------------------------------------------------------------===//
84// Elaborator Value
85//===----------------------------------------------------------------------===//
86
87namespace {
// Storage types referenced by the pointer alternatives of `ElaboratorValue`.
// They are declared up front because the variant below names them before
// their definitions further down in this file.
struct ArrayStorage;
struct BagStorage;
struct SequenceStorage;
struct RandomizedSequenceStorage;
struct InterleavedSequenceStorage;
struct SetStorage;
struct VirtualRegisterStorage;
struct UniqueLabelStorage;
96
97/// Simple wrapper around a 'StringAttr' such that we know to materialize it as
98/// a label declaration instead of calling the builtin dialect constant
99/// materializer.
100struct LabelValue {
101 LabelValue(StringAttr name) : name(name) {}
102
103 bool operator==(const LabelValue &other) const { return name == other.name; }
104
105 /// The label name.
106 StringAttr name;
107};
108
/// An elaborated value: the elaboration-time representation of an SSA value.
/// It is a `std::variant`, not a class hierarchy. Index-typed integers are held
/// as `size_t`, `i1` integers as `bool`, other constants as `TypedAttr`, and
/// labels as `LabelValue`. All remaining values are held as pointers to the
/// `*Storage` objects defined below.
///
/// NOTE: the order of the alternatives is significant: `hash_value` below mixes
/// the alternative index into the hash.
using ElaboratorValue =
    std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
                 RandomizedSequenceStorage *, InterleavedSequenceStorage *,
                 SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
                 LabelValue, ArrayStorage *>;
115
116// NOLINTNEXTLINE(readability-identifier-naming)
117llvm::hash_code hash_value(const LabelValue &val) {
118 return llvm::hash_value(val.name);
119}
120
121// NOLINTNEXTLINE(readability-identifier-naming)
122llvm::hash_code hash_value(const ElaboratorValue &val) {
123 return std::visit(
124 [&val](const auto &alternative) {
125 // Include index in hash to make sure same value as different
126 // alternatives don't collide.
127 return llvm::hash_combine(val.index(), alternative);
128 },
129 val);
130}
131
132} // namespace
133
134namespace llvm {
135
136template <>
137struct DenseMapInfo<bool> {
138 static inline unsigned getEmptyKey() { return false; }
139 static inline unsigned getTombstoneKey() { return true; }
140 static unsigned getHashValue(const bool &val) { return val * 37U; }
141
142 static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; }
143};
144template <>
145struct DenseMapInfo<LabelValue> {
146 static inline LabelValue getEmptyKey() {
148 }
149 static inline LabelValue getTombstoneKey() {
151 }
152 static unsigned getHashValue(const LabelValue &val) {
153 return hash_value(val);
154 }
155
156 static bool isEqual(const LabelValue &lhs, const LabelValue &rhs) {
157 return lhs == rhs;
158 }
159};
160
161} // namespace llvm
162
163//===----------------------------------------------------------------------===//
164// Elaborator Value Storages and Internalization
165//===----------------------------------------------------------------------===//
166
167namespace {
168
/// Cheap key type for the internalization sets. It pairs a precomputed
/// hashcode with a pointer to the object it belongs to. While probing, the
/// pointer may still be null, so the real object only has to be allocated and
/// constructed once it is known not to be in the set already.
template <typename StorageTy>
struct HashedStorage {
  HashedStorage(unsigned hash = 0, StorageTy *ptr = nullptr)
      : hashcode(hash), storage(ptr) {}

  /// Precomputed hashcode of the (possibly not yet allocated) object.
  unsigned hashcode;
  /// The internalized object, or null while it has not been allocated.
  StorageTy *storage;
};
181
182/// A DenseMapInfo implementation to support 'insert_as' for the internalization
183/// sets. When comparing two 'HashedStorage's we can just compare the already
184/// internalized storage pointers, otherwise we have to call the costly
185/// 'isEqual' method.
186template <typename StorageTy>
187struct StorageKeyInfo {
188 static inline HashedStorage<StorageTy> getEmptyKey() {
189 return HashedStorage<StorageTy>(0,
190 DenseMapInfo<StorageTy *>::getEmptyKey());
191 }
192 static inline HashedStorage<StorageTy> getTombstoneKey() {
193 return HashedStorage<StorageTy>(
194 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
195 }
196
197 static inline unsigned getHashValue(const HashedStorage<StorageTy> &key) {
198 return key.hashcode;
199 }
200 static inline unsigned getHashValue(const StorageTy &key) {
201 return key.hashcode;
202 }
203
204 static inline bool isEqual(const HashedStorage<StorageTy> &lhs,
205 const HashedStorage<StorageTy> &rhs) {
206 return lhs.storage == rhs.storage;
207 }
208 static inline bool isEqual(const StorageTy &lhs,
209 const HashedStorage<StorageTy> &rhs) {
210 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
211 return false;
212
213 return lhs.isEqual(rhs.storage);
214 }
215};
216
217/// Storage object for an '!rtg.set<T>'.
218struct SetStorage {
219 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
220 : hashcode(llvm::hash_combine(
221 type, llvm::hash_combine_range(set.begin(), set.end()))),
222 set(std::move(set)), type(type) {}
223
224 bool isEqual(const SetStorage *other) const {
225 return hashcode == other->hashcode && set == other->set &&
226 type == other->type;
227 }
228
229 // The cached hashcode to avoid repeated computations.
230 const unsigned hashcode;
231
232 // Stores the elaborated values contained in the set.
233 const SetVector<ElaboratorValue> set;
234
235 // Store the set type such that we can materialize this evaluated value
236 // also in the case where the set is empty.
237 const Type type;
238};
239
240/// Storage object for an '!rtg.bag<T>'.
241struct BagStorage {
242 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
243 : hashcode(llvm::hash_combine(
244 type, llvm::hash_combine_range(bag.begin(), bag.end()))),
245 bag(std::move(bag)), type(type) {}
246
247 bool isEqual(const BagStorage *other) const {
248 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
249 type == other->type;
250 }
251
252 // The cached hashcode to avoid repeated computations.
253 const unsigned hashcode;
254
255 // Stores the elaborated values contained in the bag with their number of
256 // occurences.
257 const MapVector<ElaboratorValue, uint64_t> bag;
258
259 // Store the bag type such that we can materialize this evaluated value
260 // also in the case where the bag is empty.
261 const Type type;
262};
263
264/// Storage object for an '!rtg.sequence'.
265struct SequenceStorage {
266 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
267 : hashcode(llvm::hash_combine(
268 familyName, llvm::hash_combine_range(args.begin(), args.end()))),
269 familyName(familyName), args(std::move(args)) {}
270
271 bool isEqual(const SequenceStorage *other) const {
272 return hashcode == other->hashcode && familyName == other->familyName &&
273 args == other->args;
274 }
275
276 // The cached hashcode to avoid repeated computations.
277 const unsigned hashcode;
278
279 // The name of the sequence family this sequence is derived from.
280 const StringAttr familyName;
281
282 // The elaborator values used during substitution of the sequence family.
283 const SmallVector<ElaboratorValue> args;
284};
285
286/// Storage object for an '!rtg.randomized_sequence'.
287struct RandomizedSequenceStorage {
288 RandomizedSequenceStorage(StringRef name,
289 ContextResourceAttrInterface context,
290 StringAttr test, SequenceStorage *sequence)
291 : hashcode(llvm::hash_combine(name, context, test, sequence)), name(name),
292 context(context), test(test), sequence(sequence) {}
293
294 bool isEqual(const RandomizedSequenceStorage *other) const {
295 return hashcode == other->hashcode && name == other->name &&
296 context == other->context && test == other->test &&
297 sequence == other->sequence;
298 }
299
300 // The cached hashcode to avoid repeated computations.
301 const unsigned hashcode;
302
303 // The name of this fully substituted and elaborated sequence.
304 const StringRef name;
305
306 // The context under which this sequence is placed.
307 const ContextResourceAttrInterface context;
308
309 // The test in which this sequence is placed.
310 const StringAttr test;
311
312 const SequenceStorage *sequence;
313};
314
315/// Storage object for interleaved '!rtg.randomized_sequence'es.
316struct InterleavedSequenceStorage {
317 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
318 uint32_t batchSize)
319 : sequences(std::move(sequences)), batchSize(batchSize),
320 hashcode(llvm::hash_combine(
321 llvm::hash_combine_range(sequences.begin(), sequences.end()),
322 batchSize)) {}
323
324 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
325 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
326 hashcode(llvm::hash_combine(
327 llvm::hash_combine_range(sequences.begin(), sequences.end()),
328 batchSize)) {}
329
330 bool isEqual(const InterleavedSequenceStorage *other) const {
331 return hashcode == other->hashcode && sequences == other->sequences &&
332 batchSize == other->batchSize;
333 }
334
335 const SmallVector<ElaboratorValue> sequences;
336
337 const uint32_t batchSize;
338
339 // The cached hashcode to avoid repeated computations.
340 const unsigned hashcode;
341};
342
343/// Represents a unique virtual register.
344struct VirtualRegisterStorage {
345 VirtualRegisterStorage(ArrayAttr allowedRegs) : allowedRegs(allowedRegs) {}
346
347 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
348 // VirtualRegisters are never internalized.
349
350 // The list of fixed registers allowed to be selected for this virtual
351 // register.
352 const ArrayAttr allowedRegs;
353};
354
355struct UniqueLabelStorage {
356 UniqueLabelStorage(StringAttr name) : name(name) {}
357
358 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
359 // VirtualRegisters are never internalized.
360
361 /// The label name. For unique labels, this is just the prefix.
362 const StringAttr name;
363};
364
365/// Storage object for '!rtg.array`-typed values.
366struct ArrayStorage {
367 ArrayStorage(Type type, SmallVector<ElaboratorValue> &&array)
368 : hashcode(llvm::hash_combine(
369 type, llvm::hash_combine_range(array.begin(), array.end()))),
370 type(type), array(array) {}
371
372 bool isEqual(const ArrayStorage *other) const {
373 return hashcode == other->hashcode && type == other->type &&
374 array == other->array;
375 }
376
377 // The cached hashcode to avoid repeated computations.
378 const unsigned hashcode;
379
380 /// The type of the array. This is necessary because an array of size 0
381 /// cannot be reconstructed without knowing the original element type.
382 const Type type;
383
384 /// The label name. For unique labels, this is just the prefix.
385 const SmallVector<ElaboratorValue> array;
386};
387
388/// An 'Internalizer' object internalizes storages and takes ownership of them.
389/// When the initializer object is destroyed, all owned storages are also
390/// deallocated and thus must not be accessed anymore.
391class Internalizer {
392public:
393 /// Internalize a storage of type `StorageTy` constructed with arguments
394 /// `args`. The pointers returned by this method can be used to compare
395 /// objects when, e.g., computing set differences, uniquing the elements in a
396 /// set, etc. Otherwise, we'd need to do a deep value comparison in those
397 /// situations.
398 template <typename StorageTy, typename... Args>
399 StorageTy *internalize(Args &&...args) {
400 StorageTy storage(std::forward<Args>(args)...);
401
402 auto existing = getInternSet<StorageTy>().insert_as(
403 HashedStorage<StorageTy>(storage.hashcode), storage);
404 StorageTy *&storagePtr = existing.first->storage;
405 if (existing.second)
406 storagePtr =
407 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
408
409 return storagePtr;
410 }
411
412 template <typename StorageTy, typename... Args>
413 StorageTy *create(Args &&...args) {
414 return new (allocator.Allocate<StorageTy>())
415 StorageTy(std::forward<Args>(args)...);
416 }
417
418private:
419 template <typename StorageTy>
420 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
421 getInternSet() {
422 if constexpr (std::is_same_v<StorageTy, ArrayStorage>)
423 return internedArrays;
424 else if constexpr (std::is_same_v<StorageTy, SetStorage>)
425 return internedSets;
426 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
427 return internedBags;
428 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
429 return internedSequences;
430 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
431 return internedRandomizedSequences;
432 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
433 return internedInterleavedSequences;
434 else
435 static_assert(!sizeof(StorageTy),
436 "no intern set available for this storage type.");
437 }
438
439 // This allocator allocates on the heap. It automatically deallocates all
440 // objects it allocated once the allocator itself is destroyed.
441 llvm::BumpPtrAllocator allocator;
442
443 // The sets holding the internalized objects. We use one set per storage type
444 // such that we can have a simpler equality checking function (no need to
445 // compare some sort of TypeIDs).
446 DenseSet<HashedStorage<ArrayStorage>, StorageKeyInfo<ArrayStorage>>
447 internedArrays;
448 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
449 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
450 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
451 internedSequences;
452 DenseSet<HashedStorage<RandomizedSequenceStorage>,
453 StorageKeyInfo<RandomizedSequenceStorage>>
454 internedRandomizedSequences;
455 DenseSet<HashedStorage<InterleavedSequenceStorage>,
456 StorageKeyInfo<InterleavedSequenceStorage>>
457 internedInterleavedSequences;
458};
459
460} // namespace
461
462#ifndef NDEBUG
463
464static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
465 const ElaboratorValue &value);
466
467static void print(TypedAttr val, llvm::raw_ostream &os) {
468 os << "<attr " << val << ">";
469}
470
471static void print(BagStorage *val, llvm::raw_ostream &os) {
472 os << "<bag {";
473 llvm::interleaveComma(val->bag, os,
474 [&](const std::pair<ElaboratorValue, uint64_t> &el) {
475 os << el.first << " -> " << el.second;
476 });
477 os << "} at " << val << ">";
478}
479
480static void print(bool val, llvm::raw_ostream &os) {
481 os << "<bool " << (val ? "true" : "false") << ">";
482}
483
484static void print(size_t val, llvm::raw_ostream &os) {
485 os << "<index " << val << ">";
486}
487
488static void print(SequenceStorage *val, llvm::raw_ostream &os) {
489 os << "<sequence @" << val->familyName.getValue() << "(";
490 llvm::interleaveComma(val->args, os,
491 [&](const ElaboratorValue &val) { os << val; });
492 os << ") at " << val << ">";
493}
494
495static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
496 os << "<randomized-sequence @" << val->name << " derived from @"
497 << val->sequence->familyName.getValue() << " under context "
498 << val->context << " in test " << val->test << "(";
499 llvm::interleaveComma(val->sequence->args, os,
500 [&](const ElaboratorValue &val) { os << val; });
501 os << ") at " << val << ">";
502}
503
504static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
505 os << "<interleaved-sequence [";
506 llvm::interleaveComma(val->sequences, os,
507 [&](const ElaboratorValue &val) { os << val; });
508 os << "] batch-size " << val->batchSize << " at " << val << ">";
509}
510
511static void print(ArrayStorage *val, llvm::raw_ostream &os) {
512 os << "<array [";
513 llvm::interleaveComma(val->array, os,
514 [&](const ElaboratorValue &val) { os << val; });
515 os << "] at " << val << ">";
516}
517
518static void print(SetStorage *val, llvm::raw_ostream &os) {
519 os << "<set {";
520 llvm::interleaveComma(val->set, os,
521 [&](const ElaboratorValue &val) { os << val; });
522 os << "} at " << val << ">";
523}
524
525static void print(const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
526 os << "<virtual-register " << val << " " << val->allowedRegs << ">";
527}
528
529static void print(const UniqueLabelStorage *val, llvm::raw_ostream &os) {
530 os << "<unique-label " << val << " " << val->name << ">";
531}
532
533static void print(const LabelValue &val, llvm::raw_ostream &os) {
534 os << "<label " << val.name << ">";
535}
536
537static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
538 const ElaboratorValue &value) {
539 std::visit([&](auto val) { print(val, os); }, value);
540
541 return os;
542}
543
544#endif
545
546//===----------------------------------------------------------------------===//
547// Elaborator Value Materialization
548//===----------------------------------------------------------------------===//
549
550namespace {
551
552/// Construct an SSA value from a given elaborated value.
553class Materializer {
554public:
555 Materializer(OpBuilder builder) : builder(builder) {}
556
557 /// Materialize IR representing the provided `ElaboratorValue` and return the
558 /// `Value` or a null value on failure.
559 Value materialize(ElaboratorValue val, Location loc,
560 std::queue<RandomizedSequenceStorage *> &elabRequests,
561 function_ref<InFlightDiagnostic()> emitError) {
562 auto iter = materializedValues.find(val);
563 if (iter != materializedValues.end())
564 return iter->second;
565
566 LLVM_DEBUG(llvm::dbgs() << "Materializing " << val << "\n\n");
567
568 return std::visit(
569 [&](auto val) { return visit(val, loc, elabRequests, emitError); },
570 val);
571 }
572
573 /// If `op` is not in the same region as the materializer insertion point, a
574 /// clone is created at the materializer's insertion point by also
575 /// materializing the `ElaboratorValue`s for each operand just before it.
576 /// Otherwise, all operations after the materializer's insertion point are
577 /// deleted until `op` is reached. An error is returned if the operation is
578 /// before the insertion point.
579 LogicalResult
580 materialize(Operation *op, DenseMap<Value, ElaboratorValue> &state,
581 std::queue<RandomizedSequenceStorage *> &elabRequests) {
582 if (op->getNumRegions() > 0)
583 return op->emitOpError("ops with nested regions must be elaborated away");
584
585 // We don't support opaque values. If there is an SSA value that has a
586 // use-site it needs an equivalent ElaborationValue representation.
587 // NOTE: We could support cases where there is initially a use-site but that
588 // op is guaranteed to be deleted during elaboration. Or the use-sites are
589 // replaced with freshly materialized values from the ElaborationValue. But
590 // then, why can't we delete the value defining op?
591 for (auto res : op->getResults())
592 if (!res.use_empty())
593 return op->emitOpError(
594 "ops with results that have uses are not supported");
595
596 if (op->getParentRegion() == builder.getBlock()->getParent()) {
597 // We are doing in-place materialization, so mark all ops deleted until we
598 // reach the one to be materialized and modify it in-place.
599 deleteOpsUntil([&](auto iter) { return &*iter == op; });
600
601 if (builder.getInsertionPoint() == builder.getBlock()->end())
602 return op->emitError("operation did not occur after the current "
603 "materializer insertion point");
604
605 LLVM_DEBUG(llvm::dbgs() << "Modifying in-place: " << *op << "\n\n");
606 } else {
607 LLVM_DEBUG(llvm::dbgs() << "Materializing a clone of " << *op << "\n\n");
608 op = builder.clone(*op);
609 builder.setInsertionPoint(op);
610 }
611
612 for (auto &operand : op->getOpOperands()) {
613 auto emitError = [&]() {
614 auto diag = op->emitError();
615 diag.attachNote(op->getLoc())
616 << "while materializing value for operand#"
617 << operand.getOperandNumber();
618 return diag;
619 };
620
621 Value val = materialize(state.at(operand.get()), op->getLoc(),
622 elabRequests, emitError);
623 if (!val)
624 return failure();
625
626 operand.set(val);
627 }
628
629 builder.setInsertionPointAfter(op);
630 return success();
631 }
632
633 /// Should be called once the `Region` is successfully materialized. No calls
634 /// to `materialize` should happen after this anymore.
635 void finalize() {
636 deleteOpsUntil([](auto iter) { return false; });
637
638 for (auto *op : llvm::reverse(toDelete))
639 op->erase();
640 }
641
642 template <typename OpTy, typename... Args>
643 OpTy create(Location location, Args &&...args) {
644 return builder.create<OpTy>(location, std::forward<Args>(args)...);
645 }
646
647private:
648 void deleteOpsUntil(function_ref<bool(Block::iterator)> stop) {
649 auto ip = builder.getInsertionPoint();
650 while (ip != builder.getBlock()->end() && !stop(ip)) {
651 LLVM_DEBUG(llvm::dbgs() << "Marking to be deleted: " << *ip << "\n\n");
652 toDelete.push_back(&*ip);
653
654 builder.setInsertionPointAfter(&*ip);
655 ip = builder.getInsertionPoint();
656 }
657 }
658
659 Value visit(TypedAttr val, Location loc,
660 std::queue<RandomizedSequenceStorage *> &elabRequests,
661 function_ref<InFlightDiagnostic()> emitError) {
662 // For index attributes (and arithmetic operations on them) we use the
663 // index dialect.
664 if (auto intAttr = dyn_cast<IntegerAttr>(val);
665 intAttr && isa<IndexType>(val.getType())) {
666 Value res = builder.create<index::ConstantOp>(loc, intAttr);
667 materializedValues[val] = res;
668 return res;
669 }
670
671 // For any other attribute, we just call the materializer of the dialect
672 // defining that attribute.
673 auto *op =
674 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
675 if (!op) {
676 emitError() << "materializer of dialect '"
677 << val.getDialect().getNamespace()
678 << "' unable to materialize value for attribute '" << val
679 << "'";
680 return Value();
681 }
682
683 Value res = op->getResult(0);
684 materializedValues[val] = res;
685 return res;
686 }
687
688 Value visit(size_t val, Location loc,
689 std::queue<RandomizedSequenceStorage *> &elabRequests,
690 function_ref<InFlightDiagnostic()> emitError) {
691 Value res = builder.create<index::ConstantOp>(loc, val);
692 materializedValues[val] = res;
693 return res;
694 }
695
696 Value visit(bool val, Location loc,
697 std::queue<RandomizedSequenceStorage *> &elabRequests,
698 function_ref<InFlightDiagnostic()> emitError) {
699 Value res = builder.create<index::BoolConstantOp>(loc, val);
700 materializedValues[val] = res;
701 return res;
702 }
703
704 Value visit(ArrayStorage *val, Location loc,
705 std::queue<RandomizedSequenceStorage *> &elabRequests,
706 function_ref<InFlightDiagnostic()> emitError) {
707 SmallVector<Value> elements;
708 elements.reserve(val->array.size());
709 for (auto el : val->array) {
710 auto materialized = materialize(el, loc, elabRequests, emitError);
711 if (!materialized)
712 return Value();
713
714 elements.push_back(materialized);
715 }
716
717 auto res = builder.create<ArrayCreateOp>(loc, val->type, elements);
718 materializedValues[val] = res;
719 return res;
720 }
721
722 Value visit(SetStorage *val, Location loc,
723 std::queue<RandomizedSequenceStorage *> &elabRequests,
724 function_ref<InFlightDiagnostic()> emitError) {
725 SmallVector<Value> elements;
726 elements.reserve(val->set.size());
727 for (auto el : val->set) {
728 auto materialized = materialize(el, loc, elabRequests, emitError);
729 if (!materialized)
730 return Value();
731
732 elements.push_back(materialized);
733 }
734
735 auto res = builder.create<SetCreateOp>(loc, val->type, elements);
736 materializedValues[val] = res;
737 return res;
738 }
739
740 Value visit(BagStorage *val, Location loc,
741 std::queue<RandomizedSequenceStorage *> &elabRequests,
742 function_ref<InFlightDiagnostic()> emitError) {
743 SmallVector<Value> values, weights;
744 values.reserve(val->bag.size());
745 weights.reserve(val->bag.size());
746 for (auto [val, weight] : val->bag) {
747 auto materializedVal = materialize(val, loc, elabRequests, emitError);
748 auto materializedWeight =
749 materialize(weight, loc, elabRequests, emitError);
750 if (!materializedVal || !materializedWeight)
751 return Value();
752
753 values.push_back(materializedVal);
754 weights.push_back(materializedWeight);
755 }
756
757 auto res = builder.create<BagCreateOp>(loc, val->type, values, weights);
758 materializedValues[val] = res;
759 return res;
760 }
761
762 Value visit(SequenceStorage *val, Location loc,
763 std::queue<RandomizedSequenceStorage *> &elabRequests,
764 function_ref<InFlightDiagnostic()> emitError) {
765 emitError() << "materializing a non-randomized sequence not supported yet";
766 return Value();
767 }
768
769 Value visit(RandomizedSequenceStorage *val, Location loc,
770 std::queue<RandomizedSequenceStorage *> &elabRequests,
771 function_ref<InFlightDiagnostic()> emitError) {
772 elabRequests.push(val);
773 Value seq = builder.create<GetSequenceOp>(
774 loc, SequenceType::get(builder.getContext(), {}), val->name);
775 Value res = builder.create<RandomizeSequenceOp>(loc, seq);
776 materializedValues[val] = res;
777 return res;
778 }
779
780 Value visit(InterleavedSequenceStorage *val, Location loc,
781 std::queue<RandomizedSequenceStorage *> &elabRequests,
782 function_ref<InFlightDiagnostic()> emitError) {
783 SmallVector<Value> sequences;
784 for (auto seqVal : val->sequences)
785 sequences.push_back(materialize(seqVal, loc, elabRequests, emitError));
786
787 if (sequences.size() == 1)
788 return sequences[0];
789
790 Value res =
791 builder.create<InterleaveSequencesOp>(loc, sequences, val->batchSize);
792 materializedValues[val] = res;
793 return res;
794 }
795
796 Value visit(VirtualRegisterStorage *val, Location loc,
797 std::queue<RandomizedSequenceStorage *> &elabRequests,
798 function_ref<InFlightDiagnostic()> emitError) {
799 Value res = builder.create<VirtualRegisterOp>(loc, val->allowedRegs);
800 materializedValues[val] = res;
801 return res;
802 }
803
804 Value visit(UniqueLabelStorage *val, Location loc,
805 std::queue<RandomizedSequenceStorage *> &elabRequests,
806 function_ref<InFlightDiagnostic()> emitError) {
807 Value res = builder.create<LabelUniqueDeclOp>(loc, val->name, ValueRange());
808 materializedValues[val] = res;
809 return res;
810 }
811
812 Value visit(const LabelValue &val, Location loc,
813 std::queue<RandomizedSequenceStorage *> &elabRequests,
814 function_ref<InFlightDiagnostic()> emitError) {
815 Value res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
816 materializedValues[val] = res;
817 return res;
818 }
819
820private:
821 /// Cache values we have already materialized to reuse them later. We start
822 /// with an insertion point at the start of the block and cache the (updated)
823 /// insertion point such that future materializations can also reuse previous
824 /// materializations without running into dominance issues (or requiring
825 /// additional checks to avoid them).
826 DenseMap<ElaboratorValue, Value> materializedValues;
827
828 /// Cache the builder to continue insertions at their current insertion point
829 /// for the reason stated above.
830 OpBuilder builder;
831
832 SmallVector<Operation *> toDelete;
833};
834
835//===----------------------------------------------------------------------===//
836// Elaboration Visitor
837//===----------------------------------------------------------------------===//
838
/// Result of visiting an operation, telling the elaboration driver whether the
/// operation should be removed.
enum class DeletionKind {
  /// Leave the operation in the IR.
  Keep,
  /// Remove the operation.
  Delete
};
842
843/// Elaborator state that should be shared by all elaborator instances.
844struct ElaboratorSharedState {
845 ElaboratorSharedState(SymbolTable &table, unsigned seed)
846 : table(table), rng(seed) {}
847
848 SymbolTable &table;
849 std::mt19937 rng;
850 Namespace names;
851 Internalizer internalizer;
852
853 /// The worklist used to keep track of the test and sequence operations to
854 /// make sure they are processed top-down (BFS traversal).
855 std::queue<RandomizedSequenceStorage *> worklist;
856};
857
858/// A collection of state per RTG test.
859struct TestState {
860 /// The name of the test.
861 StringAttr name;
862
863 /// The context switches registered for this test.
864 MapVector<
865 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
866 SequenceStorage *>
867 contextSwitches;
868};
869
870/// Interprets the IR to perform and lower the represented randomizations.
871class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
872public:
874 using RTGBase::visitOp;
875
876 Elaborator(ElaboratorSharedState &sharedState, TestState &testState,
877 Materializer &materializer,
878 ContextResourceAttrInterface currentContext = {})
879 : sharedState(sharedState), testState(testState),
880 materializer(materializer), currentContext(currentContext) {}
881
882 template <typename ValueTy>
883 inline ValueTy get(Value val) const {
884 return std::get<ValueTy>(state.at(val));
885 }
886
887 FailureOr<DeletionKind> visitConstantLike(Operation *op) {
888 assert(op->hasTrait<OpTrait::ConstantLike>() &&
889 "op is expected to be constant-like");
890
891 SmallVector<OpFoldResult, 1> result;
892 auto foldResult = op->fold(result);
893 (void)foldResult; // Make sure there is a user when assertions are off.
894 assert(succeeded(foldResult) &&
895 "constant folder of a constant-like must always succeed");
896 auto attr = dyn_cast<TypedAttr>(result[0].dyn_cast<Attribute>());
897 if (!attr)
898 return op->emitError(
899 "only typed attributes supported for constant-like operations");
900
901 auto intAttr = dyn_cast<IntegerAttr>(attr);
902 if (intAttr && isa<IndexType>(attr.getType()))
903 state[op->getResult(0)] = size_t(intAttr.getInt());
904 else if (intAttr && intAttr.getType().isSignlessInteger(1))
905 state[op->getResult(0)] = bool(intAttr.getInt());
906 else
907 state[op->getResult(0)] = attr;
908
909 return DeletionKind::Delete;
910 }
911
912 /// Print a nice error message for operations we don't support yet.
913 FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
914 return op->emitOpError("elaboration not supported");
915 }
916
917 FailureOr<DeletionKind> visitExternalOp(Operation *op) {
918 if (op->hasTrait<OpTrait::ConstantLike>())
919 return visitConstantLike(op);
920
921 // TODO: we only have this to be able to write tests for this pass without
922 // having to add support for more operations for now, so it should be
923 // removed once it is not necessary anymore for writing tests
924 if (op->use_empty())
925 return DeletionKind::Keep;
926
927 return visitUnhandledOp(op);
928 }
929
930 FailureOr<DeletionKind> visitOp(ConstantOp op) {
931 return visitConstantLike(op);
932 }
933
934 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
935 SmallVector<ElaboratorValue> replacements;
936 state[op.getResult()] =
937 sharedState.internalizer.internalize<SequenceStorage>(
938 op.getSequenceAttr(), std::move(replacements));
939 return DeletionKind::Delete;
940 }
941
942 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
943 auto *seq = get<SequenceStorage *>(op.getSequence());
944
945 SmallVector<ElaboratorValue> replacements(seq->args);
946 for (auto replacement : op.getReplacements())
947 replacements.push_back(state.at(replacement));
948
949 state[op.getResult()] =
950 sharedState.internalizer.internalize<SequenceStorage>(
951 seq->familyName, std::move(replacements));
952
953 return DeletionKind::Delete;
954 }
955
956 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
957 auto *seq = get<SequenceStorage *>(op.getSequence());
958
959 auto name = sharedState.names.newName(seq->familyName.getValue());
960 auto *randomizedSeq =
961 sharedState.internalizer.internalize<RandomizedSequenceStorage>(
962 name, currentContext, testState.name, seq);
963 state[op.getResult()] =
964 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
965 randomizedSeq);
966 return DeletionKind::Delete;
967 }
968
969 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
970 SmallVector<ElaboratorValue> sequences;
971 for (auto seq : op.getSequences())
972 sequences.push_back(get<InterleavedSequenceStorage *>(seq));
973
974 state[op.getResult()] =
975 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
976 std::move(sequences), op.getBatchSize());
977 return DeletionKind::Delete;
978 }
979
980 // NOLINTNEXTLINE(misc-no-recursion)
981 LogicalResult isValidContext(ElaboratorValue value, Operation *op) const {
982 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
983 auto *seq = std::get<RandomizedSequenceStorage *>(value);
984 if (seq->context != currentContext) {
985 auto err = op->emitError("attempting to place sequence ")
986 << seq->name << " derived from "
987 << seq->sequence->familyName.getValue() << " under context "
988 << currentContext
989 << ", but it was previously randomized for context ";
990 if (seq->context)
991 err << seq->context;
992 else
993 err << "'default'";
994 return err;
995 }
996 return success();
997 }
998
999 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
1000 for (auto val : interVal->sequences)
1001 if (failed(isValidContext(val, op)))
1002 return failure();
1003 return success();
1004 }
1005
1006 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
1007 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
1008 if (failed(isValidContext(seqVal, op)))
1009 return failure();
1010
1011 return DeletionKind::Keep;
1012 }
1013
1014 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
1015 SetVector<ElaboratorValue> set;
1016 for (auto val : op.getElements())
1017 set.insert(state.at(val));
1018
1019 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
1020 std::move(set), op.getSet().getType());
1021 return DeletionKind::Delete;
1022 }
1023
1024 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
1025 auto set = get<SetStorage *>(op.getSet())->set;
1026
1027 if (set.empty())
1028 return op->emitError("cannot select from an empty set");
1029
1030 size_t selected;
1031 if (auto intAttr =
1032 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1033 std::mt19937 customRng(intAttr.getInt());
1034 selected = getUniformlyInRange(customRng, 0, set.size() - 1);
1035 } else {
1036 selected = getUniformlyInRange(sharedState.rng, 0, set.size() - 1);
1037 }
1038
1039 state[op.getResult()] = set[selected];
1040 return DeletionKind::Delete;
1041 }
1042
1043 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
1044 auto original = get<SetStorage *>(op.getOriginal())->set;
1045 auto diff = get<SetStorage *>(op.getDiff())->set;
1046
1047 SetVector<ElaboratorValue> result(original);
1048 result.set_subtract(diff);
1049
1050 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1051 std::move(result), op.getResult().getType());
1052 return DeletionKind::Delete;
1053 }
1054
1055 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
1056 SetVector<ElaboratorValue> result;
1057 for (auto set : op.getSets())
1058 result.set_union(get<SetStorage *>(set)->set);
1059
1060 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1061 std::move(result), op.getType());
1062 return DeletionKind::Delete;
1063 }
1064
1065 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1066 auto size = get<SetStorage *>(op.getSet())->set.size();
1067 state[op.getResult()] = size;
1068 return DeletionKind::Delete;
1069 }
1070
1071 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1072 MapVector<ElaboratorValue, uint64_t> bag;
1073 for (auto [val, multiple] :
1074 llvm::zip(op.getElements(), op.getMultiples())) {
1075 // If the multiple is not stored as an AttributeValue, the elaboration
1076 // must have already failed earlier (since we don't have
1077 // unevaluated/opaque values).
1078 bag[state.at(val)] += get<size_t>(multiple);
1079 }
1080
1081 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1082 std::move(bag), op.getType());
1083 return DeletionKind::Delete;
1084 }
1085
1086 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
1087 auto bag = get<BagStorage *>(op.getBag())->bag;
1088
1089 if (bag.empty())
1090 return op->emitError("cannot select from an empty bag");
1091
1092 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1093 prefixSum.reserve(bag.size());
1094 uint32_t accumulator = 0;
1095 for (auto [val, weight] : bag) {
1096 accumulator += weight;
1097 prefixSum.push_back({val, accumulator});
1098 }
1099
1100 auto customRng = sharedState.rng;
1101 if (auto intAttr =
1102 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1103 customRng = std::mt19937(intAttr.getInt());
1104 }
1105
1106 auto idx = getUniformlyInRange(customRng, 0, accumulator - 1);
1107 auto *iter = llvm::upper_bound(
1108 prefixSum, idx,
1109 [](uint32_t a, const std::pair<ElaboratorValue, uint32_t> &b) {
1110 return a < b.second;
1111 });
1112
1113 state[op.getResult()] = iter->first;
1114 return DeletionKind::Delete;
1115 }
1116
1117 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1118 auto original = get<BagStorage *>(op.getOriginal())->bag;
1119 auto diff = get<BagStorage *>(op.getDiff())->bag;
1120
1121 MapVector<ElaboratorValue, uint64_t> result;
1122 for (const auto &el : original) {
1123 if (!diff.contains(el.first)) {
1124 result.insert(el);
1125 continue;
1126 }
1127
1128 if (op.getInf())
1129 continue;
1130
1131 auto toDiff = diff.lookup(el.first);
1132 if (el.second <= toDiff)
1133 continue;
1134
1135 result.insert({el.first, el.second - toDiff});
1136 }
1137
1138 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1139 std::move(result), op.getType());
1140 return DeletionKind::Delete;
1141 }
1142
1143 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1144 MapVector<ElaboratorValue, uint64_t> result;
1145 for (auto bag : op.getBags()) {
1146 auto val = get<BagStorage *>(bag)->bag;
1147 for (auto [el, multiple] : val)
1148 result[el] += multiple;
1149 }
1150
1151 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1152 std::move(result), op.getType());
1153 return DeletionKind::Delete;
1154 }
1155
1156 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1157 auto size = get<BagStorage *>(op.getBag())->bag.size();
1158 state[op.getResult()] = size;
1159 return DeletionKind::Delete;
1160 }
1161
1162 FailureOr<DeletionKind> visitOp(FixedRegisterOp op) {
1163 return visitConstantLike(op);
1164 }
1165
1166 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1167 state[op.getResult()] =
1168 sharedState.internalizer.create<VirtualRegisterStorage>(
1169 op.getAllowedRegsAttr());
1170 return DeletionKind::Delete;
1171 }
1172
1173 StringAttr substituteFormatString(StringAttr formatString,
1174 ValueRange substitutes) const {
1175 if (substitutes.empty() || formatString.empty())
1176 return formatString;
1177
1178 auto original = formatString.getValue().str();
1179 for (auto [i, subst] : llvm::enumerate(substitutes)) {
1180 size_t startPos = 0;
1181 std::string from = "{{" + std::to_string(i) + "}}";
1182 while ((startPos = original.find(from, startPos)) != std::string::npos) {
1183 auto substString = std::to_string(get<size_t>(subst));
1184 original.replace(startPos, from.length(), substString);
1185 }
1186 }
1187
1188 return StringAttr::get(formatString.getContext(), original);
1189 }
1190
1191 FailureOr<DeletionKind> visitOp(ArrayCreateOp op) {
1192 SmallVector<ElaboratorValue> array;
1193 array.reserve(op.getElements().size());
1194 for (auto val : op.getElements())
1195 array.emplace_back(state.at(val));
1196
1197 state[op.getResult()] = sharedState.internalizer.internalize<ArrayStorage>(
1198 op.getResult().getType(), std::move(array));
1199 return DeletionKind::Delete;
1200 }
1201
1202 FailureOr<DeletionKind> visitOp(ArrayExtractOp op) {
1203 auto array = get<ArrayStorage *>(op.getArray())->array;
1204 size_t idx = get<size_t>(op.getIndex());
1205
1206 if (array.size() <= idx)
1207 return op->emitError("invalid to access index ")
1208 << idx << " of an array with " << array.size() << " elements";
1209
1210 state[op.getResult()] = array[idx];
1211 return DeletionKind::Delete;
1212 }
1213
1214 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
1215 auto substituted =
1216 substituteFormatString(op.getFormatStringAttr(), op.getArgs());
1217 state[op.getLabel()] = LabelValue(substituted);
1218 return DeletionKind::Delete;
1219 }
1220
1221 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1222 state[op.getLabel()] = sharedState.internalizer.create<UniqueLabelStorage>(
1223 substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
1224 return DeletionKind::Delete;
1225 }
1226
1227 FailureOr<DeletionKind> visitOp(LabelOp op) { return DeletionKind::Keep; }
1228
1229 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1230 size_t lower = get<size_t>(op.getLowerBound());
1231 size_t upper = get<size_t>(op.getUpperBound()) - 1;
1232 if (lower > upper)
1233 return op->emitError("cannot select a number from an empty range");
1234
1235 if (auto intAttr =
1236 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1237 std::mt19937 customRng(intAttr.getInt());
1238 state[op.getResult()] =
1239 size_t(getUniformlyInRange(customRng, lower, upper));
1240 } else {
1241 state[op.getResult()] =
1242 size_t(getUniformlyInRange(sharedState.rng, lower, upper));
1243 }
1244
1245 return DeletionKind::Delete;
1246 }
1247
1248 FailureOr<DeletionKind> visitOp(IntToImmediateOp op) {
1249 size_t input = get<size_t>(op.getInput());
1250 auto width = op.getType().getWidth();
1251 auto emitError = [&]() { return op->emitError(); };
1252 if (input > APInt::getAllOnes(width).getZExtValue())
1253 return emitError() << "cannot represent " << input << " with " << width
1254 << " bits";
1255
1256 state[op.getResult()] =
1257 ImmediateAttr::get(op.getContext(), APInt(width, input));
1258 return DeletionKind::Delete;
1259 }
1260
1261 FailureOr<DeletionKind> visitOp(OnContextOp op) {
1262 ContextResourceAttrInterface from = currentContext,
1263 to = cast<ContextResourceAttrInterface>(
1264 get<TypedAttr>(op.getContext()));
1265 if (!currentContext)
1266 from = DefaultContextAttr::get(op->getContext(), to.getType());
1267
1268 auto emitError = [&]() {
1269 auto diag = op.emitError();
1270 diag.attachNote(op.getLoc())
1271 << "while materializing value for context switching for " << op;
1272 return diag;
1273 };
1274
1275 if (from == to) {
1276 Value seqVal = materializer.materialize(
1277 get<SequenceStorage *>(op.getSequence()), op.getLoc(),
1278 sharedState.worklist, emitError);
1279 Value randSeqVal =
1280 materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1281 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1282 return DeletionKind::Delete;
1283 }
1284
1285 // Switch to the desired context.
1286 auto *iter = testState.contextSwitches.find({from, to});
1287 // NOTE: we could think about supporting context switching via intermediate
1288 // context, i.e., treat it as a transitive relation.
1289 if (iter == testState.contextSwitches.end())
1290 return op->emitError("no context transition registered to switch from ")
1291 << from << " to " << to;
1292
1293 auto familyName = iter->second->familyName;
1294 SmallVector<ElaboratorValue> args{from, to,
1295 get<SequenceStorage *>(op.getSequence())};
1296 auto *seq = sharedState.internalizer.internalize<SequenceStorage>(
1297 familyName, std::move(args));
1298 auto *randSeq =
1299 sharedState.internalizer.internalize<RandomizedSequenceStorage>(
1300 sharedState.names.newName(familyName.getValue()), to,
1301 testState.name, seq);
1302 Value seqVal = materializer.materialize(randSeq, op.getLoc(),
1303 sharedState.worklist, emitError);
1304 materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
1305
1306 return DeletionKind::Delete;
1307 }
1308
1309 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1310 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1311 get<SequenceStorage *>(op.getSequence());
1312 return DeletionKind::Delete;
1313 }
1314
1315 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1316 bool cond = get<bool>(op.getCondition());
1317 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1318 if (toElaborate.empty())
1319 return DeletionKind::Delete;
1320
1321 // Just reuse this elaborator for the nested region because we need access
1322 // to the elaborated values outside the nested region (since it is not
1323 // isolated from above) and we want to materialize the region inline, thus
1324 // don't need a new materializer instance.
1325 if (failed(elaborate(toElaborate)))
1326 return failure();
1327
1328 // Map the results of the 'scf.if' to the yielded values.
1329 for (auto [res, out] :
1330 llvm::zip(op.getResults(),
1331 toElaborate.front().getTerminator()->getOperands()))
1332 state[res] = state.at(out);
1333
1334 return DeletionKind::Delete;
1335 }
1336
1337 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1338 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1339 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1340 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1341 return op->emitOpError("can only elaborate index type iterator");
1342
1343 auto lowerBound = get<size_t>(op.getLowerBound());
1344 auto step = get<size_t>(op.getStep());
1345 auto upperBound = get<size_t>(op.getUpperBound());
1346
1347 // Prepare for first iteration by assigning the nested regions block
1348 // arguments. We can just reuse this elaborator because we need access to
1349 // values elaborated in the parent region anyway and materialize everything
1350 // inline (i.e., don't need a new materializer).
1351 state[op.getInductionVar()] = lowerBound;
1352 for (auto [iterArg, initArg] :
1353 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1354 state[iterArg] = state.at(initArg);
1355
1356 // This loop performs the actual 'scf.for' loop iterations.
1357 for (size_t i = lowerBound; i < upperBound; i += step) {
1358 if (failed(elaborate(op.getBodyRegion())))
1359 return failure();
1360
1361 // Prepare for the next iteration by updating the mapping of the nested
1362 // regions block arguments
1363 state[op.getInductionVar()] = i + step;
1364 for (auto [iterArg, prevIterArg] :
1365 llvm::zip(op.getRegionIterArgs(),
1366 op.getBody()->getTerminator()->getOperands()))
1367 state[iterArg] = state.at(prevIterArg);
1368 }
1369
1370 // Transfer the previously yielded values to the for loop result values.
1371 for (auto [res, iterArg] :
1372 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1373 state[res] = state.at(iterArg);
1374
1375 return DeletionKind::Delete;
1376 }
1377
1378 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1379 return DeletionKind::Delete;
1380 }
1381
1382 FailureOr<DeletionKind> visitOp(index::AddOp op) {
1383 size_t lhs = get<size_t>(op.getLhs());
1384 size_t rhs = get<size_t>(op.getRhs());
1385 state[op.getResult()] = lhs + rhs;
1386 return DeletionKind::Delete;
1387 }
1388
1389 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
1390 size_t lhs = get<size_t>(op.getLhs());
1391 size_t rhs = get<size_t>(op.getRhs());
1392 bool result;
1393 switch (op.getPred()) {
1394 case index::IndexCmpPredicate::EQ:
1395 result = lhs == rhs;
1396 break;
1397 case index::IndexCmpPredicate::NE:
1398 result = lhs != rhs;
1399 break;
1400 case index::IndexCmpPredicate::ULT:
1401 result = lhs < rhs;
1402 break;
1403 case index::IndexCmpPredicate::ULE:
1404 result = lhs <= rhs;
1405 break;
1406 case index::IndexCmpPredicate::UGT:
1407 result = lhs > rhs;
1408 break;
1409 case index::IndexCmpPredicate::UGE:
1410 result = lhs >= rhs;
1411 break;
1412 default:
1413 return op->emitOpError("elaboration not supported");
1414 }
1415 state[op.getResult()] = result;
1416 return DeletionKind::Delete;
1417 }
1418
1419 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
1420 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
1421 .Case<
1422 // Index ops
1423 index::AddOp, index::CmpOp,
1424 // SCF ops
1425 scf::IfOp, scf::ForOp, scf::YieldOp>(
1426 [&](auto op) { return visitOp(op); })
1427 .Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); });
1428 }
1429
1430 // NOLINTNEXTLINE(misc-no-recursion)
1431 LogicalResult elaborate(Region &region,
1432 ArrayRef<ElaboratorValue> regionArguments = {}) {
1433 if (region.getBlocks().size() > 1)
1434 return region.getParentOp()->emitOpError(
1435 "regions with more than one block are not supported");
1436
1437 for (auto [arg, elabArg] :
1438 llvm::zip(region.getArguments(), regionArguments))
1439 state[arg] = elabArg;
1440
1441 Block *block = &region.front();
1442 for (auto &op : *block) {
1443 auto result = dispatchOpVisitor(&op);
1444 if (failed(result))
1445 return failure();
1446
1447 if (*result == DeletionKind::Keep)
1448 if (failed(materializer.materialize(&op, state, sharedState.worklist)))
1449 return failure();
1450
1451 LLVM_DEBUG({
1452 llvm::dbgs() << "Elaborated " << op << " to\n[";
1453
1454 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) {
1455 if (state.contains(res))
1456 llvm::dbgs() << state.at(res);
1457 else
1458 llvm::dbgs() << "unknown";
1459 });
1460
1461 llvm::dbgs() << "]\n\n";
1462 });
1463 }
1464
1465 return success();
1466 }
1467
1468private:
1469 // State to be shared between all elaborator instances.
1470 ElaboratorSharedState &sharedState;
1471
1472 // State to a specific RTG test and the sequences placed within it.
1473 TestState &testState;
1474
1475 // Allows us to materialize ElaboratorValues to the IR operations necessary to
1476 // obtain an SSA value representing that elaborated value.
1477 Materializer &materializer;
1478
1479 // A map from SSA values to a pointer of an interned elaborator value.
1480 DenseMap<Value, ElaboratorValue> state;
1481
1482 // The current context we are elaborating under.
1483 ContextResourceAttrInterface currentContext;
1484};
1485} // namespace
1486
1487//===----------------------------------------------------------------------===//
1488// Elaborator Pass
1489//===----------------------------------------------------------------------===//
1490
1491namespace {
1492struct ElaborationPass
1493 : public rtg::impl::ElaborationPassBase<ElaborationPass> {
1494 using Base::Base;
1495
1496 void runOnOperation() override;
1497 void cloneTargetsIntoTests(SymbolTable &table);
1498 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
1499};
1500} // namespace
1501
1502void ElaborationPass::runOnOperation() {
1503 auto moduleOp = getOperation();
1504 SymbolTable table(moduleOp);
1505
1506 cloneTargetsIntoTests(table);
1507
1508 if (failed(elaborateModule(moduleOp, table)))
1509 return signalPassFailure();
1510}
1511
1512void ElaborationPass::cloneTargetsIntoTests(SymbolTable &table) {
1513 auto moduleOp = getOperation();
1514 for (auto target : llvm::make_early_inc_range(moduleOp.getOps<TargetOp>())) {
1515 for (auto test : moduleOp.getOps<TestOp>()) {
1516 // If the test requires nothing from a target, we can always run it.
1517 if (test.getTarget().getEntries().empty())
1518 continue;
1519
1520 // If the target requirements do not match, skip this test
1521 // TODO: allow target refinements, just not coarsening
1522 if (target.getTarget() != test.getTarget())
1523 continue;
1524
1525 IRRewriter rewriter(test);
1526 // Create a new test for the matched target
1527 auto newTest = cast<TestOp>(test->clone());
1528 newTest.setSymName(test.getSymName().str() + "_" +
1529 target.getSymName().str());
1530 table.insert(newTest, rewriter.getInsertionPoint());
1531
1532 // Copy the target body into the newly created test
1533 IRMapping mapping;
1534 rewriter.setInsertionPointToStart(newTest.getBody());
1535 for (auto &op : target.getBody()->without_terminator())
1536 rewriter.clone(op, mapping);
1537
1538 for (auto [returnVal, result] :
1539 llvm::zip(target.getBody()->getTerminator()->getOperands(),
1540 newTest.getBody()->getArguments()))
1541 result.replaceAllUsesWith(mapping.lookup(returnVal));
1542
1543 newTest.getBody()->eraseArguments(0,
1544 newTest.getBody()->getNumArguments());
1545 newTest.setTarget(DictType::get(&getContext(), {}));
1546 }
1547
1548 target->erase();
1549 }
1550
1551 // Erase all remaining non-matched tests.
1552 for (auto test : llvm::make_early_inc_range(moduleOp.getOps<TestOp>()))
1553 if (!test.getTarget().getEntries().empty())
1554 test->erase();
1555}
1556
1557LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
1558 SymbolTable &table) {
1559 ElaboratorSharedState state(table, seed);
1560
1561 // Update the name cache
1562 state.names.add(moduleOp);
1563
1564 // Initialize the worklist with the test ops since they cannot be placed by
1565 // other ops.
1566 DenseMap<StringAttr, TestState> testStates;
1567 for (auto testOp : moduleOp.getOps<TestOp>()) {
1568 LLVM_DEBUG(llvm::dbgs()
1569 << "\n=== Elaborating test @" << testOp.getSymName() << "\n\n");
1570 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()));
1571 testStates[testOp.getSymNameAttr()].name = testOp.getSymNameAttr();
1572 Elaborator elaborator(state, testStates[testOp.getSymNameAttr()],
1573 materializer);
1574 if (failed(elaborator.elaborate(testOp.getBodyRegion())))
1575 return failure();
1576
1577 materializer.finalize();
1578 }
1579
1580 // Do top-down BFS traversal such that elaborating a sequence further down
1581 // does not fix the outcome for multiple placements.
1582 while (!state.worklist.empty()) {
1583 auto *curr = state.worklist.front();
1584 state.worklist.pop();
1585
1586 if (table.lookup<SequenceOp>(curr->name))
1587 continue;
1588
1589 auto familyOp = table.lookup<SequenceOp>(curr->sequence->familyName);
1590 // TODO: don't clone if this is the only remaining reference to this
1591 // sequence
1592 OpBuilder builder(familyOp);
1593 auto seqOp = builder.cloneWithoutRegions(familyOp);
1594 seqOp.getBodyRegion().emplaceBlock();
1595 seqOp.setSymName(curr->name);
1596 seqOp.setSequenceType(
1597 SequenceType::get(builder.getContext(), ArrayRef<Type>{}));
1598 table.insert(seqOp);
1599 assert(seqOp.getSymName() == curr->name && "should not have been renamed");
1600
1601 LLVM_DEBUG(llvm::dbgs()
1602 << "\n=== Elaborating sequence family @" << familyOp.getSymName()
1603 << " into @" << seqOp.getSymName() << " under context "
1604 << curr->context << "\n\n");
1605
1606 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()));
1607 Elaborator elaborator(state, testStates[curr->test], materializer,
1608 curr->context);
1609 if (failed(elaborator.elaborate(familyOp.getBodyRegion(),
1610 curr->sequence->args)))
1611 return failure();
1612
1613 materializer.finalize();
1614 }
1615
1616 return success();
1617}
assert(baseType &&"element must be base type")
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the specified range.
static void print(TypedAttr val, llvm::raw_ostream &os)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
This helps visit TypeOp nodes.
Definition RTGVisitors.h:29
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:82
ResultType dispatchOpVisitor(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:31
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
static llvm::hash_code hash_value(const ModulePort &port)
Definition HWTypes.h:38
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition Utils.h:32
Definition rtg.py:1
Definition seq.py:1
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)