CIRCT 21.0.0git
Loading...
Searching...
No Matches
ElaborationPass.cpp
Go to the documentation of this file.
1//===- ElaborationPass.cpp - RTG ElaborationPass implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass elaborates the random parts of the RTG dialect.
10// It performs randomization top-down, i.e., random constructs in a sequence
11// that is invoked multiple times can yield different randomization results
12// for each invocation.
13//
14//===----------------------------------------------------------------------===//
15
21#include "mlir/Dialect/Index/IR/IndexDialect.h"
22#include "mlir/Dialect/Index/IR/IndexOps.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/IR/IRMapping.h"
25#include "mlir/IR/PatternMatch.h"
26#include "llvm/ADT/DenseMapInfoVariant.h"
27#include "llvm/Support/Debug.h"
28#include <queue>
29#include <random>
30
31namespace circt {
32namespace rtg {
33#define GEN_PASS_DEF_ELABORATIONPASS
34#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
35} // namespace rtg
36} // namespace circt
37
38using namespace mlir;
39using namespace circt;
40using namespace circt::rtg;
41using llvm::MapVector;
42
43#define DEBUG_TYPE "rtg-elaboration"
44
45//===----------------------------------------------------------------------===//
46// Uniform Distribution Helper
47//
48// Simplified version of
49// https://github.com/llvm/llvm-project/blob/main/libcxx/include/__random/uniform_int_distribution.h
50// We use our custom version here to get the same results when compiled with
51// different compiler versions and standard libraries.
52//===----------------------------------------------------------------------===//
53
/// Compute the mask applied to a single 32-bit engine draw when `w` random
/// bits are requested. This follows libc++'s `__independent_bits_engine`: the
/// `w` bits are split over `n = ceil(w / 32)` draws of `w / n` bits each, and
/// the returned mask selects the low `w / n` bits. For `1 <= w <= 32` this is
/// simply a mask of the low `w` bits. Returns 0 for `w == 0`.
static uint32_t computeMask(size_t w) {
  // Without this guard `n` below would be 0 and `w / n` would divide by zero.
  if (w == 0)
    return 0;

  size_t n = w / 32 + (w % 32 != 0);
  size_t w0 = w / n;
  return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
}
59
60/// Get a number uniformly at random in the in specified range.
61static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) {
62 const uint32_t diff = b - a + 1;
63 if (diff == 1)
64 return a;
65
66 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
67 if (diff == 0)
68 return rng();
69
70 uint32_t width = digits - llvm::countl_zero(diff) - 1;
71 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
72 ++width;
73
74 uint32_t mask = computeMask(diff);
75 uint32_t u;
76 do {
77 u = rng() & mask;
78 } while (u >= diff);
79
80 return u + a;
81}
82
83//===----------------------------------------------------------------------===//
84// Elaborator Value
85//===----------------------------------------------------------------------===//
86
87namespace {
88struct BagStorage;
89struct SequenceStorage;
90struct RandomizedSequenceStorage;
91struct InterleavedSequenceStorage;
92struct SetStorage;
93struct VirtualRegisterStorage;
94struct UniqueLabelStorage;
95
96/// Simple wrapper around a 'StringAttr' such that we know to materialize it as
97/// a label declaration instead of calling the builtin dialect constant
98/// materializer.
99struct LabelValue {
100 LabelValue(StringAttr name) : name(name) {}
101
102 bool operator==(const LabelValue &other) const { return name == other.name; }
103
104 /// The label name.
105 StringAttr name;
106};
107
108/// The abstract base class for elaborated values.
109using ElaboratorValue =
110 std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
111 RandomizedSequenceStorage *, InterleavedSequenceStorage *,
112 SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
113 LabelValue>;
114
115// NOLINTNEXTLINE(readability-identifier-naming)
116llvm::hash_code hash_value(const LabelValue &val) {
117 return llvm::hash_value(val.name);
118}
119
120// NOLINTNEXTLINE(readability-identifier-naming)
121llvm::hash_code hash_value(const ElaboratorValue &val) {
122 return std::visit(
123 [&val](const auto &alternative) {
124 // Include index in hash to make sure same value as different
125 // alternatives don't collide.
126 return llvm::hash_combine(val.index(), alternative);
127 },
128 val);
129}
130
131} // namespace
132
133namespace llvm {
134
135template <>
136struct DenseMapInfo<bool> {
137 static inline unsigned getEmptyKey() { return false; }
138 static inline unsigned getTombstoneKey() { return true; }
139 static unsigned getHashValue(const bool &val) { return val * 37U; }
140
141 static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; }
142};
143template <>
144struct DenseMapInfo<LabelValue> {
145 static inline LabelValue getEmptyKey() {
147 }
148 static inline LabelValue getTombstoneKey() {
150 }
151 static unsigned getHashValue(const LabelValue &val) {
152 return hash_value(val);
153 }
154
155 static bool isEqual(const LabelValue &lhs, const LabelValue &rhs) {
156 return lhs == rhs;
157 }
158};
159
160} // namespace llvm
161
162//===----------------------------------------------------------------------===//
163// Elaborator Value Storages and Internalization
164//===----------------------------------------------------------------------===//
165
166namespace {
167
/// Key type for the internalization sets. It pairs an object's precomputed
/// hash with a pointer to the (possibly not yet allocated) storage object.
/// Because the hash is available up front, a lookup can happen before anything
/// is allocated, and allocation/construction only happens when the object is
/// not already present in the set.
template <typename StorageTy>
struct HashedStorage {
  HashedStorage(unsigned hash = 0, StorageTy *ptr = nullptr)
      : hashcode{hash}, storage{ptr} {}

  /// Cached hash of the referenced storage object.
  unsigned hashcode;

  /// The internalized storage object (null until allocated).
  StorageTy *storage;
};
180
181/// A DenseMapInfo implementation to support 'insert_as' for the internalization
182/// sets. When comparing two 'HashedStorage's we can just compare the already
183/// internalized storage pointers, otherwise we have to call the costly
184/// 'isEqual' method.
185template <typename StorageTy>
186struct StorageKeyInfo {
187 static inline HashedStorage<StorageTy> getEmptyKey() {
188 return HashedStorage<StorageTy>(0,
189 DenseMapInfo<StorageTy *>::getEmptyKey());
190 }
191 static inline HashedStorage<StorageTy> getTombstoneKey() {
192 return HashedStorage<StorageTy>(
193 0, DenseMapInfo<StorageTy *>::getTombstoneKey());
194 }
195
196 static inline unsigned getHashValue(const HashedStorage<StorageTy> &key) {
197 return key.hashcode;
198 }
199 static inline unsigned getHashValue(const StorageTy &key) {
200 return key.hashcode;
201 }
202
203 static inline bool isEqual(const HashedStorage<StorageTy> &lhs,
204 const HashedStorage<StorageTy> &rhs) {
205 return lhs.storage == rhs.storage;
206 }
207 static inline bool isEqual(const StorageTy &lhs,
208 const HashedStorage<StorageTy> &rhs) {
209 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
210 return false;
211
212 return lhs.isEqual(rhs.storage);
213 }
214};
215
216/// Storage object for an '!rtg.set<T>'.
217struct SetStorage {
218 SetStorage(SetVector<ElaboratorValue> &&set, Type type)
219 : hashcode(llvm::hash_combine(
220 type, llvm::hash_combine_range(set.begin(), set.end()))),
221 set(std::move(set)), type(type) {}
222
223 bool isEqual(const SetStorage *other) const {
224 return hashcode == other->hashcode && set == other->set &&
225 type == other->type;
226 }
227
228 // The cached hashcode to avoid repeated computations.
229 const unsigned hashcode;
230
231 // Stores the elaborated values contained in the set.
232 const SetVector<ElaboratorValue> set;
233
234 // Store the set type such that we can materialize this evaluated value
235 // also in the case where the set is empty.
236 const Type type;
237};
238
239/// Storage object for an '!rtg.bag<T>'.
240struct BagStorage {
241 BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
242 : hashcode(llvm::hash_combine(
243 type, llvm::hash_combine_range(bag.begin(), bag.end()))),
244 bag(std::move(bag)), type(type) {}
245
246 bool isEqual(const BagStorage *other) const {
247 return hashcode == other->hashcode && llvm::equal(bag, other->bag) &&
248 type == other->type;
249 }
250
251 // The cached hashcode to avoid repeated computations.
252 const unsigned hashcode;
253
254 // Stores the elaborated values contained in the bag with their number of
255 // occurences.
256 const MapVector<ElaboratorValue, uint64_t> bag;
257
258 // Store the bag type such that we can materialize this evaluated value
259 // also in the case where the bag is empty.
260 const Type type;
261};
262
263/// Storage object for an '!rtg.sequence'.
264struct SequenceStorage {
265 SequenceStorage(StringAttr familyName, SmallVector<ElaboratorValue> &&args)
266 : hashcode(llvm::hash_combine(
267 familyName, llvm::hash_combine_range(args.begin(), args.end()))),
268 familyName(familyName), args(std::move(args)) {}
269
270 bool isEqual(const SequenceStorage *other) const {
271 return hashcode == other->hashcode && familyName == other->familyName &&
272 args == other->args;
273 }
274
275 // The cached hashcode to avoid repeated computations.
276 const unsigned hashcode;
277
278 // The name of the sequence family this sequence is derived from.
279 const StringAttr familyName;
280
281 // The elaborator values used during substitution of the sequence family.
282 const SmallVector<ElaboratorValue> args;
283};
284
285/// Storage object for an '!rtg.randomized_sequence'.
286struct RandomizedSequenceStorage {
287 RandomizedSequenceStorage(StringRef name,
288 ContextResourceAttrInterface context,
289 StringAttr test, SequenceStorage *sequence)
290 : hashcode(llvm::hash_combine(name, context, test, sequence)), name(name),
291 context(context), test(test), sequence(sequence) {}
292
293 bool isEqual(const RandomizedSequenceStorage *other) const {
294 return hashcode == other->hashcode && name == other->name &&
295 context == other->context && test == other->test &&
296 sequence == other->sequence;
297 }
298
299 // The cached hashcode to avoid repeated computations.
300 const unsigned hashcode;
301
302 // The name of this fully substituted and elaborated sequence.
303 const StringRef name;
304
305 // The context under which this sequence is placed.
306 const ContextResourceAttrInterface context;
307
308 // The test in which this sequence is placed.
309 const StringAttr test;
310
311 const SequenceStorage *sequence;
312};
313
314/// Storage object for interleaved '!rtg.randomized_sequence'es.
315struct InterleavedSequenceStorage {
316 InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
317 uint32_t batchSize)
318 : sequences(std::move(sequences)), batchSize(batchSize),
319 hashcode(llvm::hash_combine(
320 llvm::hash_combine_range(sequences.begin(), sequences.end()),
321 batchSize)) {}
322
323 explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
324 : sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
325 hashcode(llvm::hash_combine(
326 llvm::hash_combine_range(sequences.begin(), sequences.end()),
327 batchSize)) {}
328
329 bool isEqual(const InterleavedSequenceStorage *other) const {
330 return hashcode == other->hashcode && sequences == other->sequences &&
331 batchSize == other->batchSize;
332 }
333
334 const SmallVector<ElaboratorValue> sequences;
335
336 const uint32_t batchSize;
337
338 // The cached hashcode to avoid repeated computations.
339 const unsigned hashcode;
340};
341
342/// Represents a unique virtual register.
343struct VirtualRegisterStorage {
344 VirtualRegisterStorage(ArrayAttr allowedRegs) : allowedRegs(allowedRegs) {}
345
346 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
347 // VirtualRegisters are never internalized.
348
349 // The list of fixed registers allowed to be selected for this virtual
350 // register.
351 const ArrayAttr allowedRegs;
352};
353
354struct UniqueLabelStorage {
355 UniqueLabelStorage(StringAttr name) : name(name) {}
356
357 // NOTE: we don't need an 'isEqual' function and 'hashcode' here because
358 // VirtualRegisters are never internalized.
359
360 /// The label name. For unique labels, this is just the prefix.
361 const StringAttr name;
362};
363
364/// An 'Internalizer' object internalizes storages and takes ownership of them.
365/// When the initializer object is destroyed, all owned storages are also
366/// deallocated and thus must not be accessed anymore.
367class Internalizer {
368public:
369 /// Internalize a storage of type `StorageTy` constructed with arguments
370 /// `args`. The pointers returned by this method can be used to compare
371 /// objects when, e.g., computing set differences, uniquing the elements in a
372 /// set, etc. Otherwise, we'd need to do a deep value comparison in those
373 /// situations.
374 template <typename StorageTy, typename... Args>
375 StorageTy *internalize(Args &&...args) {
376 StorageTy storage(std::forward<Args>(args)...);
377
378 auto existing = getInternSet<StorageTy>().insert_as(
379 HashedStorage<StorageTy>(storage.hashcode), storage);
380 StorageTy *&storagePtr = existing.first->storage;
381 if (existing.second)
382 storagePtr =
383 new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
384
385 return storagePtr;
386 }
387
388 template <typename StorageTy, typename... Args>
389 StorageTy *create(Args &&...args) {
390 return new (allocator.Allocate<StorageTy>())
391 StorageTy(std::forward<Args>(args)...);
392 }
393
394private:
395 template <typename StorageTy>
396 DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
397 getInternSet() {
398 if constexpr (std::is_same_v<StorageTy, SetStorage>)
399 return internedSets;
400 else if constexpr (std::is_same_v<StorageTy, BagStorage>)
401 return internedBags;
402 else if constexpr (std::is_same_v<StorageTy, SequenceStorage>)
403 return internedSequences;
404 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
405 return internedRandomizedSequences;
406 else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
407 return internedInterleavedSequences;
408 else
409 static_assert(!sizeof(StorageTy),
410 "no intern set available for this storage type.");
411 }
412
413 // This allocator allocates on the heap. It automatically deallocates all
414 // objects it allocated once the allocator itself is destroyed.
415 llvm::BumpPtrAllocator allocator;
416
417 // The sets holding the internalized objects. We use one set per storage type
418 // such that we can have a simpler equality checking function (no need to
419 // compare some sort of TypeIDs).
420 DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
421 DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
422 DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>>
423 internedSequences;
424 DenseSet<HashedStorage<RandomizedSequenceStorage>,
425 StorageKeyInfo<RandomizedSequenceStorage>>
426 internedRandomizedSequences;
427 DenseSet<HashedStorage<InterleavedSequenceStorage>,
428 StorageKeyInfo<InterleavedSequenceStorage>>
429 internedInterleavedSequences;
430};
431
432} // namespace
433
434#ifndef NDEBUG
435
436static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
437 const ElaboratorValue &value);
438
439static void print(TypedAttr val, llvm::raw_ostream &os) {
440 os << "<attr " << val << ">";
441}
442
443static void print(BagStorage *val, llvm::raw_ostream &os) {
444 os << "<bag {";
445 llvm::interleaveComma(val->bag, os,
446 [&](const std::pair<ElaboratorValue, uint64_t> &el) {
447 os << el.first << " -> " << el.second;
448 });
449 os << "} at " << val << ">";
450}
451
452static void print(bool val, llvm::raw_ostream &os) {
453 os << "<bool " << (val ? "true" : "false") << ">";
454}
455
456static void print(size_t val, llvm::raw_ostream &os) {
457 os << "<index " << val << ">";
458}
459
460static void print(SequenceStorage *val, llvm::raw_ostream &os) {
461 os << "<sequence @" << val->familyName.getValue() << "(";
462 llvm::interleaveComma(val->args, os,
463 [&](const ElaboratorValue &val) { os << val; });
464 os << ") at " << val << ">";
465}
466
467static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
468 os << "<randomized-sequence @" << val->name << " derived from @"
469 << val->sequence->familyName.getValue() << " under context "
470 << val->context << " in test " << val->test << "(";
471 llvm::interleaveComma(val->sequence->args, os,
472 [&](const ElaboratorValue &val) { os << val; });
473 os << ") at " << val << ">";
474}
475
476static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
477 os << "<interleaved-sequence [";
478 llvm::interleaveComma(val->sequences, os,
479 [&](const ElaboratorValue &val) { os << val; });
480 os << "] batch-size " << val->batchSize << " at " << val << ">";
481}
482
483static void print(SetStorage *val, llvm::raw_ostream &os) {
484 os << "<set {";
485 llvm::interleaveComma(val->set, os,
486 [&](const ElaboratorValue &val) { os << val; });
487 os << "} at " << val << ">";
488}
489
490static void print(const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
491 os << "<virtual-register " << val << " " << val->allowedRegs << ">";
492}
493
494static void print(const UniqueLabelStorage *val, llvm::raw_ostream &os) {
495 os << "<unique-label " << val << " " << val->name << ">";
496}
497
498static void print(const LabelValue &val, llvm::raw_ostream &os) {
499 os << "<label " << val.name << ">";
500}
501
502static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
503 const ElaboratorValue &value) {
504 std::visit([&](auto val) { print(val, os); }, value);
505
506 return os;
507}
508
509#endif
510
511//===----------------------------------------------------------------------===//
512// Elaborator Value Materialization
513//===----------------------------------------------------------------------===//
514
515namespace {
516
517/// Construct an SSA value from a given elaborated value.
518class Materializer {
519public:
520 Materializer(OpBuilder builder) : builder(builder) {}
521
522 /// Materialize IR representing the provided `ElaboratorValue` and return the
523 /// `Value` or a null value on failure.
524 Value materialize(ElaboratorValue val, Location loc,
525 std::queue<RandomizedSequenceStorage *> &elabRequests,
526 function_ref<InFlightDiagnostic()> emitError) {
527 auto iter = materializedValues.find(val);
528 if (iter != materializedValues.end())
529 return iter->second;
530
531 LLVM_DEBUG(llvm::dbgs() << "Materializing " << val << "\n\n");
532
533 return std::visit(
534 [&](auto val) { return visit(val, loc, elabRequests, emitError); },
535 val);
536 }
537
538 /// If `op` is not in the same region as the materializer insertion point, a
539 /// clone is created at the materializer's insertion point by also
540 /// materializing the `ElaboratorValue`s for each operand just before it.
541 /// Otherwise, all operations after the materializer's insertion point are
542 /// deleted until `op` is reached. An error is returned if the operation is
543 /// before the insertion point.
544 LogicalResult
545 materialize(Operation *op, DenseMap<Value, ElaboratorValue> &state,
546 std::queue<RandomizedSequenceStorage *> &elabRequests) {
547 if (op->getNumRegions() > 0)
548 return op->emitOpError("ops with nested regions must be elaborated away");
549
550 // We don't support opaque values. If there is an SSA value that has a
551 // use-site it needs an equivalent ElaborationValue representation.
552 // NOTE: We could support cases where there is initially a use-site but that
553 // op is guaranteed to be deleted during elaboration. Or the use-sites are
554 // replaced with freshly materialized values from the ElaborationValue. But
555 // then, why can't we delete the value defining op?
556 for (auto res : op->getResults())
557 if (!res.use_empty())
558 return op->emitOpError(
559 "ops with results that have uses are not supported");
560
561 if (op->getParentRegion() == builder.getBlock()->getParent()) {
562 // We are doing in-place materialization, so mark all ops deleted until we
563 // reach the one to be materialized and modify it in-place.
564 deleteOpsUntil([&](auto iter) { return &*iter == op; });
565
566 if (builder.getInsertionPoint() == builder.getBlock()->end())
567 return op->emitError("operation did not occur after the current "
568 "materializer insertion point");
569
570 LLVM_DEBUG(llvm::dbgs() << "Modifying in-place: " << *op << "\n\n");
571 } else {
572 LLVM_DEBUG(llvm::dbgs() << "Materializing a clone of " << *op << "\n\n");
573 op = builder.clone(*op);
574 builder.setInsertionPoint(op);
575 }
576
577 for (auto &operand : op->getOpOperands()) {
578 auto emitError = [&]() {
579 auto diag = op->emitError();
580 diag.attachNote(op->getLoc())
581 << "while materializing value for operand#"
582 << operand.getOperandNumber();
583 return diag;
584 };
585
586 Value val = materialize(state.at(operand.get()), op->getLoc(),
587 elabRequests, emitError);
588 if (!val)
589 return failure();
590
591 operand.set(val);
592 }
593
594 builder.setInsertionPointAfter(op);
595 return success();
596 }
597
598 /// Should be called once the `Region` is successfully materialized. No calls
599 /// to `materialize` should happen after this anymore.
600 void finalize() {
601 deleteOpsUntil([](auto iter) { return false; });
602
603 for (auto *op : llvm::reverse(toDelete))
604 op->erase();
605 }
606
607 template <typename OpTy, typename... Args>
608 OpTy create(Location location, Args &&...args) {
609 return builder.create<OpTy>(location, std::forward<Args>(args)...);
610 }
611
612private:
613 void deleteOpsUntil(function_ref<bool(Block::iterator)> stop) {
614 auto ip = builder.getInsertionPoint();
615 while (ip != builder.getBlock()->end() && !stop(ip)) {
616 LLVM_DEBUG(llvm::dbgs() << "Marking to be deleted: " << *ip << "\n\n");
617 toDelete.push_back(&*ip);
618
619 builder.setInsertionPointAfter(&*ip);
620 ip = builder.getInsertionPoint();
621 }
622 }
623
624 Value visit(TypedAttr val, Location loc,
625 std::queue<RandomizedSequenceStorage *> &elabRequests,
626 function_ref<InFlightDiagnostic()> emitError) {
627 // For index attributes (and arithmetic operations on them) we use the
628 // index dialect.
629 if (auto intAttr = dyn_cast<IntegerAttr>(val);
630 intAttr && isa<IndexType>(val.getType())) {
631 Value res = builder.create<index::ConstantOp>(loc, intAttr);
632 materializedValues[val] = res;
633 return res;
634 }
635
636 // For any other attribute, we just call the materializer of the dialect
637 // defining that attribute.
638 auto *op =
639 val.getDialect().materializeConstant(builder, val, val.getType(), loc);
640 if (!op) {
641 emitError() << "materializer of dialect '"
642 << val.getDialect().getNamespace()
643 << "' unable to materialize value for attribute '" << val
644 << "'";
645 return Value();
646 }
647
648 Value res = op->getResult(0);
649 materializedValues[val] = res;
650 return res;
651 }
652
653 Value visit(size_t val, Location loc,
654 std::queue<RandomizedSequenceStorage *> &elabRequests,
655 function_ref<InFlightDiagnostic()> emitError) {
656 Value res = builder.create<index::ConstantOp>(loc, val);
657 materializedValues[val] = res;
658 return res;
659 }
660
661 Value visit(bool val, Location loc,
662 std::queue<RandomizedSequenceStorage *> &elabRequests,
663 function_ref<InFlightDiagnostic()> emitError) {
664 Value res = builder.create<index::BoolConstantOp>(loc, val);
665 materializedValues[val] = res;
666 return res;
667 }
668
669 Value visit(SetStorage *val, Location loc,
670 std::queue<RandomizedSequenceStorage *> &elabRequests,
671 function_ref<InFlightDiagnostic()> emitError) {
672 SmallVector<Value> elements;
673 elements.reserve(val->set.size());
674 for (auto el : val->set) {
675 auto materialized = materialize(el, loc, elabRequests, emitError);
676 if (!materialized)
677 return Value();
678
679 elements.push_back(materialized);
680 }
681
682 auto res = builder.create<SetCreateOp>(loc, val->type, elements);
683 materializedValues[val] = res;
684 return res;
685 }
686
687 Value visit(BagStorage *val, Location loc,
688 std::queue<RandomizedSequenceStorage *> &elabRequests,
689 function_ref<InFlightDiagnostic()> emitError) {
690 SmallVector<Value> values, weights;
691 values.reserve(val->bag.size());
692 weights.reserve(val->bag.size());
693 for (auto [val, weight] : val->bag) {
694 auto materializedVal = materialize(val, loc, elabRequests, emitError);
695 auto materializedWeight =
696 materialize(weight, loc, elabRequests, emitError);
697 if (!materializedVal || !materializedWeight)
698 return Value();
699
700 values.push_back(materializedVal);
701 weights.push_back(materializedWeight);
702 }
703
704 auto res = builder.create<BagCreateOp>(loc, val->type, values, weights);
705 materializedValues[val] = res;
706 return res;
707 }
708
709 Value visit(SequenceStorage *val, Location loc,
710 std::queue<RandomizedSequenceStorage *> &elabRequests,
711 function_ref<InFlightDiagnostic()> emitError) {
712 emitError() << "materializing a non-randomized sequence not supported yet";
713 return Value();
714 }
715
716 Value visit(RandomizedSequenceStorage *val, Location loc,
717 std::queue<RandomizedSequenceStorage *> &elabRequests,
718 function_ref<InFlightDiagnostic()> emitError) {
719 elabRequests.push(val);
720 Value seq = builder.create<GetSequenceOp>(
721 loc, SequenceType::get(builder.getContext(), {}), val->name);
722 Value res = builder.create<RandomizeSequenceOp>(loc, seq);
723 materializedValues[val] = res;
724 return res;
725 }
726
727 Value visit(InterleavedSequenceStorage *val, Location loc,
728 std::queue<RandomizedSequenceStorage *> &elabRequests,
729 function_ref<InFlightDiagnostic()> emitError) {
730 SmallVector<Value> sequences;
731 for (auto seqVal : val->sequences)
732 sequences.push_back(materialize(seqVal, loc, elabRequests, emitError));
733
734 if (sequences.size() == 1)
735 return sequences[0];
736
737 Value res =
738 builder.create<InterleaveSequencesOp>(loc, sequences, val->batchSize);
739 materializedValues[val] = res;
740 return res;
741 }
742
743 Value visit(VirtualRegisterStorage *val, Location loc,
744 std::queue<RandomizedSequenceStorage *> &elabRequests,
745 function_ref<InFlightDiagnostic()> emitError) {
746 Value res = builder.create<VirtualRegisterOp>(loc, val->allowedRegs);
747 materializedValues[val] = res;
748 return res;
749 }
750
751 Value visit(UniqueLabelStorage *val, Location loc,
752 std::queue<RandomizedSequenceStorage *> &elabRequests,
753 function_ref<InFlightDiagnostic()> emitError) {
754 Value res = builder.create<LabelUniqueDeclOp>(loc, val->name, ValueRange());
755 materializedValues[val] = res;
756 return res;
757 }
758
759 Value visit(const LabelValue &val, Location loc,
760 std::queue<RandomizedSequenceStorage *> &elabRequests,
761 function_ref<InFlightDiagnostic()> emitError) {
762 Value res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
763 materializedValues[val] = res;
764 return res;
765 }
766
767private:
768 /// Cache values we have already materialized to reuse them later. We start
769 /// with an insertion point at the start of the block and cache the (updated)
770 /// insertion point such that future materializations can also reuse previous
771 /// materializations without running into dominance issues (or requiring
772 /// additional checks to avoid them).
773 DenseMap<ElaboratorValue, Value> materializedValues;
774
775 /// Cache the builder to continue insertions at their current insertion point
776 /// for the reason stated above.
777 OpBuilder builder;
778
779 SmallVector<Operation *> toDelete;
780};
781
782//===----------------------------------------------------------------------===//
783// Elaboration Visitor
784//===----------------------------------------------------------------------===//
785
/// Tells the elaboration driver what should happen to an operation after it
/// has been visited.
enum class DeletionKind {
  /// Leave the operation in place.
  Keep,
  /// Remove the operation.
  Delete
};
789
790/// Elaborator state that should be shared by all elaborator instances.
791struct ElaboratorSharedState {
792 ElaboratorSharedState(SymbolTable &table, unsigned seed)
793 : table(table), rng(seed) {}
794
795 SymbolTable &table;
796 std::mt19937 rng;
797 Namespace names;
798 Internalizer internalizer;
799
800 /// The worklist used to keep track of the test and sequence operations to
801 /// make sure they are processed top-down (BFS traversal).
802 std::queue<RandomizedSequenceStorage *> worklist;
803};
804
805/// A collection of state per RTG test.
806struct TestState {
807 /// The name of the test.
808 StringAttr name;
809
810 /// The context switches registered for this test.
811 MapVector<
812 std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
813 SequenceStorage *>
814 contextSwitches;
815};
816
817/// Interprets the IR to perform and lower the represented randomizations.
818class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
819public:
821 using RTGBase::visitOp;
822
823 Elaborator(ElaboratorSharedState &sharedState, TestState &testState,
824 Materializer &materializer,
825 ContextResourceAttrInterface currentContext = {})
826 : sharedState(sharedState), testState(testState),
827 materializer(materializer), currentContext(currentContext) {}
828
829 template <typename ValueTy>
830 inline ValueTy get(Value val) const {
831 return std::get<ValueTy>(state.at(val));
832 }
833
834 FailureOr<DeletionKind> visitConstantLike(Operation *op) {
835 assert(op->hasTrait<OpTrait::ConstantLike>() &&
836 "op is expected to be constant-like");
837
838 SmallVector<OpFoldResult, 1> result;
839 auto foldResult = op->fold(result);
840 (void)foldResult; // Make sure there is a user when assertions are off.
841 assert(succeeded(foldResult) &&
842 "constant folder of a constant-like must always succeed");
843 auto attr = dyn_cast<TypedAttr>(result[0].dyn_cast<Attribute>());
844 if (!attr)
845 return op->emitError(
846 "only typed attributes supported for constant-like operations");
847
848 auto intAttr = dyn_cast<IntegerAttr>(attr);
849 if (intAttr && isa<IndexType>(attr.getType()))
850 state[op->getResult(0)] = size_t(intAttr.getInt());
851 else if (intAttr && intAttr.getType().isSignlessInteger(1))
852 state[op->getResult(0)] = bool(intAttr.getInt());
853 else
854 state[op->getResult(0)] = attr;
855
856 return DeletionKind::Delete;
857 }
858
859 /// Print a nice error message for operations we don't support yet.
860 FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
861 return op->emitOpError("elaboration not supported");
862 }
863
864 FailureOr<DeletionKind> visitExternalOp(Operation *op) {
865 if (op->hasTrait<OpTrait::ConstantLike>())
866 return visitConstantLike(op);
867
868 // TODO: we only have this to be able to write tests for this pass without
869 // having to add support for more operations for now, so it should be
870 // removed once it is not necessary anymore for writing tests
871 if (op->use_empty())
872 return DeletionKind::Keep;
873
874 return visitUnhandledOp(op);
875 }
876
877 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
878 SmallVector<ElaboratorValue> replacements;
879 state[op.getResult()] =
880 sharedState.internalizer.internalize<SequenceStorage>(
881 op.getSequenceAttr(), std::move(replacements));
882 return DeletionKind::Delete;
883 }
884
885 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
886 auto *seq = get<SequenceStorage *>(op.getSequence());
887
888 SmallVector<ElaboratorValue> replacements(seq->args);
889 for (auto replacement : op.getReplacements())
890 replacements.push_back(state.at(replacement));
891
892 state[op.getResult()] =
893 sharedState.internalizer.internalize<SequenceStorage>(
894 seq->familyName, std::move(replacements));
895
896 return DeletionKind::Delete;
897 }
898
899 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
900 auto *seq = get<SequenceStorage *>(op.getSequence());
901
902 auto name = sharedState.names.newName(seq->familyName.getValue());
903 auto *randomizedSeq =
904 sharedState.internalizer.internalize<RandomizedSequenceStorage>(
905 name, currentContext, testState.name, seq);
906 state[op.getResult()] =
907 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
908 randomizedSeq);
909 return DeletionKind::Delete;
910 }
911
912 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
913 SmallVector<ElaboratorValue> sequences;
914 for (auto seq : op.getSequences())
915 sequences.push_back(get<InterleavedSequenceStorage *>(seq));
916
917 state[op.getResult()] =
918 sharedState.internalizer.internalize<InterleavedSequenceStorage>(
919 std::move(sequences), op.getBatchSize());
920 return DeletionKind::Delete;
921 }
922
923 // NOLINTNEXTLINE(misc-no-recursion)
924 LogicalResult isValidContext(ElaboratorValue value, Operation *op) const {
925 if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
926 auto *seq = std::get<RandomizedSequenceStorage *>(value);
927 if (seq->context != currentContext) {
928 auto err = op->emitError("attempting to place sequence ")
929 << seq->name << " derived from "
930 << seq->sequence->familyName.getValue() << " under context "
931 << currentContext
932 << ", but it was previously randomized for context ";
933 if (seq->context)
934 err << seq->context;
935 else
936 err << "'default'";
937 return err;
938 }
939 return success();
940 }
941
942 auto *interVal = std::get<InterleavedSequenceStorage *>(value);
943 for (auto val : interVal->sequences)
944 if (failed(isValidContext(val, op)))
945 return failure();
946 return success();
947 }
948
949 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
950 auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
951 if (failed(isValidContext(seqVal, op)))
952 return failure();
953
954 return DeletionKind::Keep;
955 }
956
957 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
958 SetVector<ElaboratorValue> set;
959 for (auto val : op.getElements())
960 set.insert(state.at(val));
961
962 state[op.getSet()] = sharedState.internalizer.internalize<SetStorage>(
963 std::move(set), op.getSet().getType());
964 return DeletionKind::Delete;
965 }
966
967 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
968 auto set = get<SetStorage *>(op.getSet())->set;
969
970 if (set.empty())
971 return op->emitError("cannot select from an empty set");
972
973 size_t selected;
974 if (auto intAttr =
975 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
976 std::mt19937 customRng(intAttr.getInt());
977 selected = getUniformlyInRange(customRng, 0, set.size() - 1);
978 } else {
979 selected = getUniformlyInRange(sharedState.rng, 0, set.size() - 1);
980 }
981
982 state[op.getResult()] = set[selected];
983 return DeletionKind::Delete;
984 }
985
986 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
987 auto original = get<SetStorage *>(op.getOriginal())->set;
988 auto diff = get<SetStorage *>(op.getDiff())->set;
989
990 SetVector<ElaboratorValue> result(original);
991 result.set_subtract(diff);
992
993 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
994 std::move(result), op.getResult().getType());
995 return DeletionKind::Delete;
996 }
997
998 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
999 SetVector<ElaboratorValue> result;
1000 for (auto set : op.getSets())
1001 result.set_union(get<SetStorage *>(set)->set);
1002
1003 state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
1004 std::move(result), op.getType());
1005 return DeletionKind::Delete;
1006 }
1007
1008 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
1009 auto size = get<SetStorage *>(op.getSet())->set.size();
1010 state[op.getResult()] = size;
1011 return DeletionKind::Delete;
1012 }
1013
1014 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
1015 MapVector<ElaboratorValue, uint64_t> bag;
1016 for (auto [val, multiple] :
1017 llvm::zip(op.getElements(), op.getMultiples())) {
1018 // If the multiple is not stored as an AttributeValue, the elaboration
1019 // must have already failed earlier (since we don't have
1020 // unevaluated/opaque values).
1021 bag[state.at(val)] += get<size_t>(multiple);
1022 }
1023
1024 state[op.getBag()] = sharedState.internalizer.internalize<BagStorage>(
1025 std::move(bag), op.getType());
1026 return DeletionKind::Delete;
1027 }
1028
1029 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
1030 auto bag = get<BagStorage *>(op.getBag())->bag;
1031
1032 if (bag.empty())
1033 return op->emitError("cannot select from an empty bag");
1034
1035 SmallVector<std::pair<ElaboratorValue, uint32_t>> prefixSum;
1036 prefixSum.reserve(bag.size());
1037 uint32_t accumulator = 0;
1038 for (auto [val, weight] : bag) {
1039 accumulator += weight;
1040 prefixSum.push_back({val, accumulator});
1041 }
1042
1043 auto customRng = sharedState.rng;
1044 if (auto intAttr =
1045 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1046 customRng = std::mt19937(intAttr.getInt());
1047 }
1048
1049 auto idx = getUniformlyInRange(customRng, 0, accumulator - 1);
1050 auto *iter = llvm::upper_bound(
1051 prefixSum, idx,
1052 [](uint32_t a, const std::pair<ElaboratorValue, uint32_t> &b) {
1053 return a < b.second;
1054 });
1055
1056 state[op.getResult()] = iter->first;
1057 return DeletionKind::Delete;
1058 }
1059
1060 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
1061 auto original = get<BagStorage *>(op.getOriginal())->bag;
1062 auto diff = get<BagStorage *>(op.getDiff())->bag;
1063
1064 MapVector<ElaboratorValue, uint64_t> result;
1065 for (const auto &el : original) {
1066 if (!diff.contains(el.first)) {
1067 result.insert(el);
1068 continue;
1069 }
1070
1071 if (op.getInf())
1072 continue;
1073
1074 auto toDiff = diff.lookup(el.first);
1075 if (el.second <= toDiff)
1076 continue;
1077
1078 result.insert({el.first, el.second - toDiff});
1079 }
1080
1081 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1082 std::move(result), op.getType());
1083 return DeletionKind::Delete;
1084 }
1085
1086 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
1087 MapVector<ElaboratorValue, uint64_t> result;
1088 for (auto bag : op.getBags()) {
1089 auto val = get<BagStorage *>(bag)->bag;
1090 for (auto [el, multiple] : val)
1091 result[el] += multiple;
1092 }
1093
1094 state[op.getResult()] = sharedState.internalizer.internalize<BagStorage>(
1095 std::move(result), op.getType());
1096 return DeletionKind::Delete;
1097 }
1098
1099 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
1100 auto size = get<BagStorage *>(op.getBag())->bag.size();
1101 state[op.getResult()] = size;
1102 return DeletionKind::Delete;
1103 }
1104
1105 FailureOr<DeletionKind> visitOp(FixedRegisterOp op) {
1106 return visitConstantLike(op);
1107 }
1108
1109 FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1110 state[op.getResult()] =
1111 sharedState.internalizer.create<VirtualRegisterStorage>(
1112 op.getAllowedRegsAttr());
1113 return DeletionKind::Delete;
1114 }
1115
1116 StringAttr substituteFormatString(StringAttr formatString,
1117 ValueRange substitutes) const {
1118 if (substitutes.empty() || formatString.empty())
1119 return formatString;
1120
1121 auto original = formatString.getValue().str();
1122 for (auto [i, subst] : llvm::enumerate(substitutes)) {
1123 size_t startPos = 0;
1124 std::string from = "{{" + std::to_string(i) + "}}";
1125 while ((startPos = original.find(from, startPos)) != std::string::npos) {
1126 auto substString = std::to_string(get<size_t>(subst));
1127 original.replace(startPos, from.length(), substString);
1128 }
1129 }
1130
1131 return StringAttr::get(formatString.getContext(), original);
1132 }
1133
1134 FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
1135 auto substituted =
1136 substituteFormatString(op.getFormatStringAttr(), op.getArgs());
1137 state[op.getLabel()] = LabelValue(substituted);
1138 return DeletionKind::Delete;
1139 }
1140
1141 FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1142 state[op.getLabel()] = sharedState.internalizer.create<UniqueLabelStorage>(
1143 substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
1144 return DeletionKind::Delete;
1145 }
1146
1147 FailureOr<DeletionKind> visitOp(LabelOp op) { return DeletionKind::Keep; }
1148
1149 FailureOr<DeletionKind> visitOp(RandomNumberInRangeOp op) {
1150 size_t lower = get<size_t>(op.getLowerBound());
1151 size_t upper = get<size_t>(op.getUpperBound()) - 1;
1152 if (lower > upper)
1153 return op->emitError("cannot select a number from an empty range");
1154
1155 if (auto intAttr =
1156 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
1157 std::mt19937 customRng(intAttr.getInt());
1158 state[op.getResult()] =
1159 size_t(getUniformlyInRange(customRng, lower, upper));
1160 } else {
1161 state[op.getResult()] =
1162 size_t(getUniformlyInRange(sharedState.rng, lower, upper));
1163 }
1164
1165 return DeletionKind::Delete;
1166 }
1167
1168 FailureOr<DeletionKind> visitOp(OnContextOp op) {
1169 ContextResourceAttrInterface from = currentContext,
1170 to = cast<ContextResourceAttrInterface>(
1171 get<TypedAttr>(op.getContext()));
1172 if (!currentContext)
1173 from = DefaultContextAttr::get(op->getContext(), to.getType());
1174
1175 auto emitError = [&]() {
1176 auto diag = op.emitError();
1177 diag.attachNote(op.getLoc())
1178 << "while materializing value for context switching for " << op;
1179 return diag;
1180 };
1181
1182 if (from == to) {
1183 Value seqVal = materializer.materialize(
1184 get<SequenceStorage *>(op.getSequence()), op.getLoc(),
1185 sharedState.worklist, emitError);
1186 Value randSeqVal =
1187 materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1188 materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1189 return DeletionKind::Delete;
1190 }
1191
1192 // Switch to the desired context.
1193 auto *iter = testState.contextSwitches.find({from, to});
1194 // NOTE: we could think about supporting context switching via intermediate
1195 // context, i.e., treat it as a transitive relation.
1196 if (iter == testState.contextSwitches.end())
1197 return op->emitError("no context transition registered to switch from ")
1198 << from << " to " << to;
1199
1200 auto familyName = iter->second->familyName;
1201 SmallVector<ElaboratorValue> args{from, to,
1202 get<SequenceStorage *>(op.getSequence())};
1203 auto *seq = sharedState.internalizer.internalize<SequenceStorage>(
1204 familyName, std::move(args));
1205 auto *randSeq =
1206 sharedState.internalizer.internalize<RandomizedSequenceStorage>(
1207 sharedState.names.newName(familyName.getValue()), to,
1208 testState.name, seq);
1209 Value seqVal = materializer.materialize(randSeq, op.getLoc(),
1210 sharedState.worklist, emitError);
1211 materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
1212
1213 return DeletionKind::Delete;
1214 }
1215
1216 FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1217 testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1218 get<SequenceStorage *>(op.getSequence());
1219 return DeletionKind::Delete;
1220 }
1221
1222 FailureOr<DeletionKind> visitOp(scf::IfOp op) {
1223 bool cond = get<bool>(op.getCondition());
1224 auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
1225 if (toElaborate.empty())
1226 return DeletionKind::Delete;
1227
1228 // Just reuse this elaborator for the nested region because we need access
1229 // to the elaborated values outside the nested region (since it is not
1230 // isolated from above) and we want to materialize the region inline, thus
1231 // don't need a new materializer instance.
1232 if (failed(elaborate(toElaborate)))
1233 return failure();
1234
1235 // Map the results of the 'scf.if' to the yielded values.
1236 for (auto [res, out] :
1237 llvm::zip(op.getResults(),
1238 toElaborate.front().getTerminator()->getOperands()))
1239 state[res] = state.at(out);
1240
1241 return DeletionKind::Delete;
1242 }
1243
1244 FailureOr<DeletionKind> visitOp(scf::ForOp op) {
1245 if (!(std::holds_alternative<size_t>(state.at(op.getLowerBound())) &&
1246 std::holds_alternative<size_t>(state.at(op.getStep())) &&
1247 std::holds_alternative<size_t>(state.at(op.getUpperBound()))))
1248 return op->emitOpError("can only elaborate index type iterator");
1249
1250 auto lowerBound = get<size_t>(op.getLowerBound());
1251 auto step = get<size_t>(op.getStep());
1252 auto upperBound = get<size_t>(op.getUpperBound());
1253
1254 // Prepare for first iteration by assigning the nested regions block
1255 // arguments. We can just reuse this elaborator because we need access to
1256 // values elaborated in the parent region anyway and materialize everything
1257 // inline (i.e., don't need a new materializer).
1258 state[op.getInductionVar()] = lowerBound;
1259 for (auto [iterArg, initArg] :
1260 llvm::zip(op.getRegionIterArgs(), op.getInitArgs()))
1261 state[iterArg] = state.at(initArg);
1262
1263 // This loop performs the actual 'scf.for' loop iterations.
1264 for (size_t i = lowerBound; i < upperBound; i += step) {
1265 if (failed(elaborate(op.getBodyRegion())))
1266 return failure();
1267
1268 // Prepare for the next iteration by updating the mapping of the nested
1269 // regions block arguments
1270 state[op.getInductionVar()] = i + step;
1271 for (auto [iterArg, prevIterArg] :
1272 llvm::zip(op.getRegionIterArgs(),
1273 op.getBody()->getTerminator()->getOperands()))
1274 state[iterArg] = state.at(prevIterArg);
1275 }
1276
1277 // Transfer the previously yielded values to the for loop result values.
1278 for (auto [res, iterArg] :
1279 llvm::zip(op->getResults(), op.getRegionIterArgs()))
1280 state[res] = state.at(iterArg);
1281
1282 return DeletionKind::Delete;
1283 }
1284
1285 FailureOr<DeletionKind> visitOp(scf::YieldOp op) {
1286 return DeletionKind::Delete;
1287 }
1288
1289 FailureOr<DeletionKind> visitOp(index::AddOp op) {
1290 size_t lhs = get<size_t>(op.getLhs());
1291 size_t rhs = get<size_t>(op.getRhs());
1292 state[op.getResult()] = lhs + rhs;
1293 return DeletionKind::Delete;
1294 }
1295
1296 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
1297 size_t lhs = get<size_t>(op.getLhs());
1298 size_t rhs = get<size_t>(op.getRhs());
1299 bool result;
1300 switch (op.getPred()) {
1301 case index::IndexCmpPredicate::EQ:
1302 result = lhs == rhs;
1303 break;
1304 case index::IndexCmpPredicate::NE:
1305 result = lhs != rhs;
1306 break;
1307 case index::IndexCmpPredicate::ULT:
1308 result = lhs < rhs;
1309 break;
1310 case index::IndexCmpPredicate::ULE:
1311 result = lhs <= rhs;
1312 break;
1313 case index::IndexCmpPredicate::UGT:
1314 result = lhs > rhs;
1315 break;
1316 case index::IndexCmpPredicate::UGE:
1317 result = lhs >= rhs;
1318 break;
1319 default:
1320 return op->emitOpError("elaboration not supported");
1321 }
1322 state[op.getResult()] = result;
1323 return DeletionKind::Delete;
1324 }
1325
1326 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
1327 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
1328 .Case<
1329 // Index ops
1330 index::AddOp, index::CmpOp,
1331 // SCF ops
1332 scf::IfOp, scf::ForOp, scf::YieldOp>(
1333 [&](auto op) { return visitOp(op); })
1334 .Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); });
1335 }
1336
1337 // NOLINTNEXTLINE(misc-no-recursion)
1338 LogicalResult elaborate(Region &region,
1339 ArrayRef<ElaboratorValue> regionArguments = {}) {
1340 if (region.getBlocks().size() > 1)
1341 return region.getParentOp()->emitOpError(
1342 "regions with more than one block are not supported");
1343
1344 for (auto [arg, elabArg] :
1345 llvm::zip(region.getArguments(), regionArguments))
1346 state[arg] = elabArg;
1347
1348 Block *block = &region.front();
1349 for (auto &op : *block) {
1350 auto result = dispatchOpVisitor(&op);
1351 if (failed(result))
1352 return failure();
1353
1354 if (*result == DeletionKind::Keep)
1355 if (failed(materializer.materialize(&op, state, sharedState.worklist)))
1356 return failure();
1357
1358 LLVM_DEBUG({
1359 llvm::dbgs() << "Elaborated " << op << " to\n[";
1360
1361 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) {
1362 if (state.contains(res))
1363 llvm::dbgs() << state.at(res);
1364 else
1365 llvm::dbgs() << "unknown";
1366 });
1367
1368 llvm::dbgs() << "]\n\n";
1369 });
1370 }
1371
1372 return success();
1373 }
1374
1375private:
1376 // State to be shared between all elaborator instances.
1377 ElaboratorSharedState &sharedState;
1378
1379 // State to a specific RTG test and the sequences placed within it.
1380 TestState &testState;
1381
1382 // Allows us to materialize ElaboratorValues to the IR operations necessary to
1383 // obtain an SSA value representing that elaborated value.
1384 Materializer &materializer;
1385
1386 // A map from SSA values to a pointer of an interned elaborator value.
1387 DenseMap<Value, ElaboratorValue> state;
1388
1389 // The current context we are elaborating under.
1390 ContextResourceAttrInterface currentContext;
1391};
1392} // namespace
1393
1394//===----------------------------------------------------------------------===//
1395// Elaborator Pass
1396//===----------------------------------------------------------------------===//
1397
1398namespace {
1399struct ElaborationPass
1400 : public rtg::impl::ElaborationPassBase<ElaborationPass> {
1401 using Base::Base;
1402
1403 void runOnOperation() override;
1404 void cloneTargetsIntoTests(SymbolTable &table);
1405 LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
1406};
1407} // namespace
1408
1409void ElaborationPass::runOnOperation() {
1410 auto moduleOp = getOperation();
1411 SymbolTable table(moduleOp);
1412
1413 cloneTargetsIntoTests(table);
1414
1415 if (failed(elaborateModule(moduleOp, table)))
1416 return signalPassFailure();
1417}
1418
/// Specialize tests for the targets whose requirements they match.
///
/// For each target op, every test with a non-empty target requirement equal
/// to the target's type is cloned and renamed to `<test>_<target>`. The target
/// body, minus its terminator, is cloned into the start of the new test. Uses
/// of the new test's block arguments are replaced by the values the target
/// yields. The arguments are then erased and the requirement is reset to an
/// empty dictionary type. Each target is erased after processing. Finally,
/// every test that still has a non-empty requirement is erased; this includes
/// the originals that were cloned.
void ElaborationPass::cloneTargetsIntoTests(SymbolTable &table) {
  auto moduleOp = getOperation();
  // Early-inc range because `target` is erased at the end of each iteration.
  for (auto target : llvm::make_early_inc_range(moduleOp.getOps<TargetOp>())) {
    for (auto test : moduleOp.getOps<TestOp>()) {
      // If the test requires nothing from a target, we can always run it.
      if (test.getTarget().getEntries().empty())
        continue;

      // If the target requirements do not match, skip this test
      // TODO: allow target refinements, just not coarsening
      if (target.getTarget() != test.getTarget())
        continue;

      IRRewriter rewriter(test);
      // Create a new test for the matched target
      auto newTest = cast<TestOp>(test->clone());
      newTest.setSymName(test.getSymName().str() + "_" +
                         target.getSymName().str());
      // Inserted at the rewriter's insertion point (presumably right before
      // `test`). SymbolTable::insert renames the symbol if it would clash.
      table.insert(newTest, rewriter.getInsertionPoint());

      // Copy the target body into the newly created test
      IRMapping mapping;
      rewriter.setInsertionPointToStart(newTest.getBody());
      for (auto &op : target.getBody()->without_terminator())
        rewriter.clone(op, mapping);

      // The test's block arguments correspond positionally to the values
      // yielded by the target; rewire their uses to the cloned values.
      for (auto [returnVal, result] :
           llvm::zip(target.getBody()->getTerminator()->getOperands(),
                     newTest.getBody()->getArguments()))
        result.replaceAllUsesWith(mapping.lookup(returnVal));

      // The specialized test no longer needs arguments or requirements.
      newTest.getBody()->eraseArguments(0,
                                        newTest.getBody()->getNumArguments());
      newTest.setTarget(DictType::get(&getContext(), {}));
    }

    target->erase();
  }

  // Erase all remaining non-matched tests.
  for (auto test : llvm::make_early_inc_range(moduleOp.getOps<TestOp>()))
    if (!test.getTarget().getEntries().empty())
      test->erase();
}
1463
1464LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
1465 SymbolTable &table) {
1466 ElaboratorSharedState state(table, seed);
1467
1468 // Update the name cache
1469 state.names.add(moduleOp);
1470
1471 // Initialize the worklist with the test ops since they cannot be placed by
1472 // other ops.
1473 DenseMap<StringAttr, TestState> testStates;
1474 for (auto testOp : moduleOp.getOps<TestOp>()) {
1475 LLVM_DEBUG(llvm::dbgs()
1476 << "\n=== Elaborating test @" << testOp.getSymName() << "\n\n");
1477 Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()));
1478 testStates[testOp.getSymNameAttr()].name = testOp.getSymNameAttr();
1479 Elaborator elaborator(state, testStates[testOp.getSymNameAttr()],
1480 materializer);
1481 if (failed(elaborator.elaborate(testOp.getBodyRegion())))
1482 return failure();
1483
1484 materializer.finalize();
1485 }
1486
1487 // Do top-down BFS traversal such that elaborating a sequence further down
1488 // does not fix the outcome for multiple placements.
1489 while (!state.worklist.empty()) {
1490 auto *curr = state.worklist.front();
1491 state.worklist.pop();
1492
1493 if (table.lookup<SequenceOp>(curr->name))
1494 continue;
1495
1496 auto familyOp = table.lookup<SequenceOp>(curr->sequence->familyName);
1497 // TODO: don't clone if this is the only remaining reference to this
1498 // sequence
1499 OpBuilder builder(familyOp);
1500 auto seqOp = builder.cloneWithoutRegions(familyOp);
1501 seqOp.getBodyRegion().emplaceBlock();
1502 seqOp.setSymName(curr->name);
1503 seqOp.setSequenceType(
1504 SequenceType::get(builder.getContext(), ArrayRef<Type>{}));
1505 table.insert(seqOp);
1506 assert(seqOp.getSymName() == curr->name && "should not have been renamed");
1507
1508 LLVM_DEBUG(llvm::dbgs()
1509 << "\n=== Elaborating sequence family @" << familyOp.getSymName()
1510 << " into @" << seqOp.getSymName() << " under context "
1511 << curr->context << "\n\n");
1512
1513 Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()));
1514 Elaborator elaborator(state, testStates[curr->test], materializer,
1515 curr->context);
1516 if (failed(elaborator.elaborate(familyOp.getBodyRegion(),
1517 curr->sequence->args)))
1518 return failure();
1519
1520 materializer.finalize();
1521 }
1522
1523 return success();
1524}
assert(baseType &&"element must be base type")
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the specified range.
static void print(TypedAttr val, llvm::raw_ostream &os)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
This helps visit TypeOp nodes.
Definition RTGVisitors.h:29
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:76
ResultType dispatchOpVisitor(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:31
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
static llvm::hash_code hash_value(const ModulePort &port)
Definition HWTypes.h:38
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition Utils.h:32
Definition rtg.py:1
Definition seq.py:1
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs)
static unsigned getHashValue(const LabelValue &val)
static bool isEqual(const bool &lhs, const bool &rhs)
static unsigned getTombstoneKey()
static unsigned getHashValue(const bool &val)