CIRCT 20.0.0git
Loading...
Searching...
No Matches
ElaborationPass.cpp
Go to the documentation of this file.
1//===- ElaborationPass.cpp - RTG ElaborationPass implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass elaborates the random parts of the RTG dialect.
10// It performs randomization top-down, i.e., random constructs in a sequence
11// that is invoked multiple times can yield different randomization results
12// for each invocation.
13//
14//===----------------------------------------------------------------------===//
15
20#include "mlir/Dialect/Index/IR/IndexDialect.h"
21#include "mlir/Dialect/Index/IR/IndexOps.h"
22#include "mlir/IR/IRMapping.h"
23#include "mlir/IR/PatternMatch.h"
24#include "llvm/Support/Debug.h"
25#include <queue>
26#include <random>
27
28namespace circt {
29namespace rtg {
30#define GEN_PASS_DEF_ELABORATIONPASS
31#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
32} // namespace rtg
33} // namespace circt
34
35using namespace mlir;
36using namespace circt;
37using namespace circt::rtg;
38using llvm::MapVector;
39
40#define DEBUG_TYPE "rtg-elaboration"
41
42//===----------------------------------------------------------------------===//
43// Uniform Distribution Helper
44//
45// Simplified version of
46// https://github.com/llvm/llvm-project/blob/main/libcxx/include/__random/uniform_int_distribution.h
47// We use our custom version here to get the same results when compiled with
48// different compiler versions and standard libraries.
49//===----------------------------------------------------------------------===//
50
/// Compute the mask selecting the low bits of a 32-bit random word that are
/// used when `w` random bits are spread evenly over `ceil(w / 32)` words
/// (modelled on the word splitting in libc++'s uniform_int_distribution). For
/// `0 < w <= 32` this is simply a mask of the low `w` bits.
///
/// Returns 0 for `w == 0`.
static uint32_t computeMask(size_t w) {
  // With `w == 0` the word count `n` below would be zero and `w / n` would
  // divide by zero, so handle this degenerate case up front.
  if (w == 0)
    return 0;

  size_t n = w / 32 + (w % 32 != 0);
  size_t w0 = w / n;
  return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0;
}
56
57/// Get a number uniformly at random in the in specified range.
58static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) {
59 const uint32_t diff = b - a + 1;
60 if (diff == 1)
61 return a;
62
63 const uint32_t digits = std::numeric_limits<uint32_t>::digits;
64 if (diff == 0)
65 return rng();
66
67 uint32_t width = digits - llvm::countl_zero(diff) - 1;
68 if ((diff & (std::numeric_limits<uint32_t>::max() >> (digits - width))) != 0)
69 ++width;
70
71 uint32_t mask = computeMask(diff);
72 uint32_t u;
73 do {
74 u = rng() & mask;
75 } while (u >= diff);
76
77 return u + a;
78}
79
80//===----------------------------------------------------------------------===//
81// Elaborator Values
82//===----------------------------------------------------------------------===//
83
84namespace {
85
86/// The abstract base class for elaborated values.
87struct ElaboratorValue {
88public:
89 enum class ValueKind { Attribute, Set, Bag, Sequence, Index, Bool };
90
91 ElaboratorValue(ValueKind kind) : kind(kind) {}
92 virtual ~ElaboratorValue() {}
93
94 virtual llvm::hash_code getHashValue() const = 0;
95 virtual bool isEqual(const ElaboratorValue &other) const = 0;
96
97#ifndef NDEBUG
98 virtual void print(llvm::raw_ostream &os) const = 0;
99#endif
100
101 ValueKind getKind() const { return kind; }
102
103private:
104 const ValueKind kind;
105};
106
107/// Holds any typed attribute. Wrapping around an MLIR `Attribute` allows us to
108/// use this elaborator value class for any values that have a corresponding
109/// MLIR attribute rather than one per kind of attribute. We only support typed
110/// attributes because for materialization we need to provide the type to the
111/// dialect's materializer.
112class AttributeValue : public ElaboratorValue {
113public:
114 AttributeValue(TypedAttr attr)
115 : ElaboratorValue(ValueKind::Attribute), attr(attr) {
116 assert(attr && "null attributes not allowed");
117 }
118
119 // Implement LLVMs RTTI
120 static bool classof(const ElaboratorValue *val) {
121 return val->getKind() == ValueKind::Attribute;
122 }
123
124 llvm::hash_code getHashValue() const override {
125 return llvm::hash_combine(attr);
126 }
127
128 bool isEqual(const ElaboratorValue &other) const override {
129 auto *attrValue = dyn_cast<AttributeValue>(&other);
130 if (!attrValue)
131 return false;
132
133 return attr == attrValue->attr;
134 }
135
136#ifndef NDEBUG
137 void print(llvm::raw_ostream &os) const override {
138 os << "<attr " << attr << " at " << this << ">";
139 }
140#endif
141
142 TypedAttr getAttr() const { return attr; }
143
144private:
145 const TypedAttr attr;
146};
147
148/// Holds an evaluated value of a `IndexType`'d value.
149class IndexValue : public ElaboratorValue {
150public:
151 IndexValue(size_t index) : ElaboratorValue(ValueKind::Index), index(index) {}
152
153 // Implement LLVMs RTTI
154 static bool classof(const ElaboratorValue *val) {
155 return val->getKind() == ValueKind::Index;
156 }
157
158 llvm::hash_code getHashValue() const override {
159 return llvm::hash_value(index);
160 }
161
162 bool isEqual(const ElaboratorValue &other) const override {
163 auto *indexValue = dyn_cast<IndexValue>(&other);
164 if (!indexValue)
165 return false;
166
167 return index == indexValue->index;
168 }
169
170#ifndef NDEBUG
171 void print(llvm::raw_ostream &os) const override {
172 os << "<index " << index << " at " << this << ">";
173 }
174#endif
175
176 size_t getIndex() const { return index; }
177
178private:
179 const size_t index;
180};
181
182/// Holds an evaluated value of an `i1` type'd value.
183class BoolValue : public ElaboratorValue {
184public:
185 BoolValue(bool value) : ElaboratorValue(ValueKind::Bool), value(value) {}
186
187 // Implement LLVMs RTTI
188 static bool classof(const ElaboratorValue *val) {
189 return val->getKind() == ValueKind::Bool;
190 }
191
192 llvm::hash_code getHashValue() const override {
193 return llvm::hash_value(value);
194 }
195
196 bool isEqual(const ElaboratorValue &other) const override {
197 auto *val = dyn_cast<BoolValue>(&other);
198 if (!val)
199 return false;
200
201 return value == val->value;
202 }
203
204#ifndef NDEBUG
205 void print(llvm::raw_ostream &os) const override {
206 os << "<bool " << (value ? "true" : "false") << " at " << this << ">";
207 }
208#endif
209
210 bool getBool() const { return value; }
211
212private:
213 const bool value;
214};
215
216/// Holds an evaluated value of a `SetType`'d value.
217class SetValue : public ElaboratorValue {
218public:
219 SetValue(SetVector<ElaboratorValue *> &&set, Type type)
220 : ElaboratorValue(ValueKind::Set), set(std::move(set)), type(type),
221 cachedHash(llvm::hash_combine(
222 llvm::hash_combine_range(set.begin(), set.end()), type)) {}
223
224 // Implement LLVMs RTTI
225 static bool classof(const ElaboratorValue *val) {
226 return val->getKind() == ValueKind::Set;
227 }
228
229 llvm::hash_code getHashValue() const override { return cachedHash; }
230
231 bool isEqual(const ElaboratorValue &other) const override {
232 auto *otherSet = dyn_cast<SetValue>(&other);
233 if (!otherSet)
234 return false;
235
236 if (cachedHash != otherSet->cachedHash)
237 return false;
238
239 // Make sure empty sets of different types are not considered equal
240 return set == otherSet->set && type == otherSet->type;
241 }
242
243#ifndef NDEBUG
244 void print(llvm::raw_ostream &os) const override {
245 os << "<set {";
246 llvm::interleaveComma(set, os, [&](ElaboratorValue *el) { el->print(os); });
247 os << "} at " << this << ">";
248 }
249#endif
250
251 const SetVector<ElaboratorValue *> &getSet() const { return set; }
252
253 Type getType() const { return type; }
254
255private:
256 // We currently use a sorted vector to represent sets. Note that it is sorted
257 // by the pointer value and thus non-deterministic.
258 // We probably want to do some profiling in the future to see if a DenseSet or
259 // other representation is better suited.
260 const SetVector<ElaboratorValue *> set;
261
262 // Store the set type such that we can materialize this evaluated value
263 // also in the case where the set is empty.
264 const Type type;
265
266 // Compute the hash only once at constructor time.
267 const llvm::hash_code cachedHash;
268};
269
270/// Holds an evaluated value of a `BagType`'d value.
271class BagValue : public ElaboratorValue {
272public:
273 BagValue(MapVector<ElaboratorValue *, uint64_t> &&bag, Type type)
274 : ElaboratorValue(ValueKind::Bag), bag(std::move(bag)), type(type),
275 cachedHash(llvm::hash_combine(
276 llvm::hash_combine_range(bag.begin(), bag.end()), type)) {}
277
278 // Implement LLVMs RTTI
279 static bool classof(const ElaboratorValue *val) {
280 return val->getKind() == ValueKind::Bag;
281 }
282
283 llvm::hash_code getHashValue() const override { return cachedHash; }
284
285 bool isEqual(const ElaboratorValue &other) const override {
286 auto *otherBag = dyn_cast<BagValue>(&other);
287 if (!otherBag)
288 return false;
289
290 if (cachedHash != otherBag->cachedHash)
291 return false;
292
293 return llvm::equal(bag, otherBag->bag) && type == otherBag->type;
294 }
295
296#ifndef NDEBUG
297 void print(llvm::raw_ostream &os) const override {
298 os << "<bag {";
299 llvm::interleaveComma(bag, os,
300 [&](std::pair<ElaboratorValue *, uint64_t> el) {
301 el.first->print(os);
302 os << " -> " << el.second;
303 });
304 os << "} at " << this << ">";
305 }
306#endif
307
308 const MapVector<ElaboratorValue *, uint64_t> &getBag() const { return bag; }
309
310 Type getType() const { return type; }
311
312private:
313 // Stores the elaborated values of the bag.
314 const MapVector<ElaboratorValue *, uint64_t> bag;
315
316 // Store the type of the bag such that we can materialize this evaluated value
317 // also in the case where the bag is empty.
318 const Type type;
319
320 // Compute the hash only once at constructor time.
321 const llvm::hash_code cachedHash;
322};
323
324/// Holds an evaluated value of a `SequenceType`'d value.
325class SequenceValue : public ElaboratorValue {
326public:
327 SequenceValue(StringRef name, StringAttr familyName,
328 SmallVector<ElaboratorValue *> &&args)
329 : ElaboratorValue(ValueKind::Sequence), name(name),
330 familyName(familyName), args(std::move(args)),
331 cachedHash(llvm::hash_combine(
332 llvm::hash_combine_range(this->args.begin(), this->args.end()),
333 name, familyName)) {}
334
335 // Implement LLVMs RTTI
336 static bool classof(const ElaboratorValue *val) {
337 return val->getKind() == ValueKind::Sequence;
338 }
339
340 llvm::hash_code getHashValue() const override { return cachedHash; }
341
342 bool isEqual(const ElaboratorValue &other) const override {
343 auto *otherSeq = dyn_cast<SequenceValue>(&other);
344 if (!otherSeq)
345 return false;
346
347 if (cachedHash != otherSeq->cachedHash)
348 return false;
349
350 return name == otherSeq->name && familyName == otherSeq->familyName &&
351 args == otherSeq->args;
352 }
353
354#ifndef NDEBUG
355 void print(llvm::raw_ostream &os) const override {
356 os << "<sequence @" << name << " derived from @" << familyName.getValue()
357 << "(";
358 llvm::interleaveComma(args, os,
359 [&](ElaboratorValue *val) { val->print(os); });
360 os << ") at " << this << ">";
361 }
362#endif
363
364 StringRef getName() const { return name; }
365 StringAttr getFamilyName() const { return familyName; }
366 ArrayRef<ElaboratorValue *> getArgs() const { return args; }
367
368private:
369 const StringRef name;
370 const StringAttr familyName;
371 const SmallVector<ElaboratorValue *> args;
372
373 // Compute the hash only once at constructor time.
374 const llvm::hash_code cachedHash;
375};
376} // namespace
377
378#ifndef NDEBUG
379static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
380 const ElaboratorValue &value) {
381 value.print(os);
382 return os;
383}
384#endif
385
386//===----------------------------------------------------------------------===//
387// Hash Map Helpers
388//===----------------------------------------------------------------------===//
389
390// NOLINTNEXTLINE(readability-identifier-naming)
391static llvm::hash_code hash_value(const ElaboratorValue &val) {
392 return val.getHashValue();
393}
394
395namespace {
396struct InternMapInfo : public DenseMapInfo<ElaboratorValue *> {
397 static unsigned getHashValue(const ElaboratorValue *value) {
398 assert(value != getTombstoneKey() && value != getEmptyKey());
399 return hash_value(*value);
400 }
401
402 static bool isEqual(const ElaboratorValue *lhs, const ElaboratorValue *rhs) {
403 if (lhs == rhs)
404 return true;
405
406 auto *tk = getTombstoneKey();
407 auto *ek = getEmptyKey();
408 if (lhs == tk || rhs == tk || lhs == ek || rhs == ek)
409 return false;
410
411 return lhs->isEqual(*rhs);
412 }
413};
414} // namespace
415
416//===----------------------------------------------------------------------===//
417// Main Elaborator Implementation
418//===----------------------------------------------------------------------===//
419
420namespace {
421
422/// Construct an SSA value from a given elaborated value.
423class Materializer {
424public:
425 Value materialize(ElaboratorValue *val, Location loc,
426 std::queue<SequenceValue *> &elabRequests,
427 function_ref<InFlightDiagnostic()> emitError) {
428 assert(block && "must call reset before calling this function");
429
430 auto iter = materializedValues.find(val);
431 if (iter != materializedValues.end())
432 return iter->second;
433
434 LLVM_DEBUG(llvm::dbgs() << "Materializing " << *val << "\n\n");
435
436 OpBuilder builder(block, insertionPoint);
437 return TypeSwitch<ElaboratorValue *, Value>(val)
438 .Case<AttributeValue, IndexValue, BoolValue, SetValue, BagValue,
439 SequenceValue>([&](auto val) {
440 return visit(val, builder, loc, elabRequests, emitError);
441 })
442 .Default([](auto val) {
443 assert(false && "all cases must be covered above");
444 return Value();
445 });
446 }
447
448 Materializer &reset(Block *block) {
449 materializedValues.clear();
450 integerValues.clear();
451 this->block = block;
452 insertionPoint = block->begin();
453 return *this;
454 }
455
456private:
457 Value visit(AttributeValue *val, OpBuilder &builder, Location loc,
458 std::queue<SequenceValue *> &elabRequests,
459 function_ref<InFlightDiagnostic()> emitError) {
460 auto attr = val->getAttr();
461
462 // For index attributes (and arithmetic operations on them) we use the
463 // index dialect.
464 if (auto intAttr = dyn_cast<IntegerAttr>(attr);
465 intAttr && isa<IndexType>(attr.getType())) {
466 Value res = builder.create<index::ConstantOp>(loc, intAttr);
467 materializedValues[val] = res;
468 return res;
469 }
470
471 // For any other attribute, we just call the materializer of the dialect
472 // defining that attribute.
473 auto *op = attr.getDialect().materializeConstant(builder, attr,
474 attr.getType(), loc);
475 if (!op) {
476 emitError() << "materializer of dialect '"
477 << attr.getDialect().getNamespace()
478 << "' unable to materialize value for attribute '" << attr
479 << "'";
480 return Value();
481 }
482
483 Value res = op->getResult(0);
484 materializedValues[val] = res;
485 return res;
486 }
487
488 Value visit(IndexValue *val, OpBuilder &builder, Location loc,
489 std::queue<SequenceValue *> &elabRequests,
490 function_ref<InFlightDiagnostic()> emitError) {
491 Value res = builder.create<index::ConstantOp>(loc, val->getIndex());
492 materializedValues[val] = res;
493 return res;
494 }
495
496 Value visit(BoolValue *val, OpBuilder &builder, Location loc,
497 std::queue<SequenceValue *> &elabRequests,
498 function_ref<InFlightDiagnostic()> emitError) {
499 Value res = builder.create<index::BoolConstantOp>(loc, val->getBool());
500 materializedValues[val] = res;
501 return res;
502 }
503
504 Value visit(SetValue *val, OpBuilder &builder, Location loc,
505 std::queue<SequenceValue *> &elabRequests,
506 function_ref<InFlightDiagnostic()> emitError) {
507 SmallVector<Value> elements;
508 elements.reserve(val->getSet().size());
509 for (auto *el : val->getSet()) {
510 auto materialized = materialize(el, loc, elabRequests, emitError);
511 if (!materialized)
512 return Value();
513
514 elements.push_back(materialized);
515 }
516
517 auto res = builder.create<SetCreateOp>(loc, val->getType(), elements);
518 materializedValues[val] = res;
519 return res;
520 }
521
522 Value visit(BagValue *val, OpBuilder &builder, Location loc,
523 std::queue<SequenceValue *> &elabRequests,
524 function_ref<InFlightDiagnostic()> emitError) {
525 SmallVector<Value> values, weights;
526 values.reserve(val->getBag().size());
527 weights.reserve(val->getBag().size());
528 for (auto [val, weight] : val->getBag()) {
529 auto materializedVal = materialize(val, loc, elabRequests, emitError);
530 if (!materializedVal)
531 return Value();
532
533 auto iter = integerValues.find(weight);
534 Value materializedWeight;
535 if (iter != integerValues.end()) {
536 materializedWeight = iter->second;
537 } else {
538 materializedWeight = builder.create<index::ConstantOp>(
539 loc, builder.getIndexAttr(weight));
540 integerValues[weight] = materializedWeight;
541 }
542
543 values.push_back(materializedVal);
544 weights.push_back(materializedWeight);
545 }
546
547 auto res =
548 builder.create<BagCreateOp>(loc, val->getType(), values, weights);
549 materializedValues[val] = res;
550 return res;
551 }
552
553 Value visit(SequenceValue *val, OpBuilder &builder, Location loc,
554 std::queue<SequenceValue *> &elabRequests,
555 function_ref<InFlightDiagnostic()> emitError) {
556 elabRequests.push(val);
557 return builder.create<SequenceClosureOp>(loc, val->getName(), ValueRange());
558 }
559
560private:
561 /// Cache values we have already materialized to reuse them later. We start
562 /// with an insertion point at the start of the block and cache the (updated)
563 /// insertion point such that future materializations can also reuse previous
564 /// materializations without running into dominance issues (or requiring
565 /// additional checks to avoid them).
566 DenseMap<ElaboratorValue *, Value> materializedValues;
567 DenseMap<uint64_t, Value> integerValues;
568
569 /// Cache the builders to continue insertions at their current insertion point
570 /// for the reason stated above.
571 Block *block;
572 Block::iterator insertionPoint;
573};
574
/// Tells the elaboration driver what to do with an operation once it has been
/// visited.
enum class DeletionKind {
  /// Keep the operation; the driver materializes its operands.
  Keep,
  /// Remove the operation; its results live on only as elaborated values.
  Delete
};
578
579/// Interprets the IR to perform and lower the represented randomizations.
580class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
581public:
583 using RTGBase::visitOp;
584 using RTGBase::visitRegisterOp;
585
586 Elaborator(SymbolTable &table, std::mt19937 &rng) : rng(rng), table(table) {}
587
588 /// Helper to perform internalization and keep track of interpreted value for
589 /// the given SSA value.
590 template <typename ValueTy, typename... Args>
591 void internalizeResult(Value val, Args &&...args) {
592 // TODO: this isn't the most efficient way to internalize
593 auto ptr = std::make_unique<ValueTy>(std::forward<Args>(args)...);
594 auto *e = ptr.get();
595 auto [iter, _] = interned.insert({e, std::move(ptr)});
596 state[val] = iter->second.get();
597 }
598
599 /// Print a nice error message for operations we don't support yet.
600 FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
601 return op->emitOpError("elaboration not supported");
602 }
603
604 FailureOr<DeletionKind> visitExternalOp(Operation *op) {
605 // TODO: we only have this to be able to write tests for this pass without
606 // having to add support for more operations for now, so it should be
607 // removed once it is not necessary anymore for writing tests
608 if (op->use_empty())
609 return DeletionKind::Keep;
610
611 return visitUnhandledOp(op);
612 }
613
614 FailureOr<DeletionKind> visitOp(SequenceClosureOp op) {
615 SmallVector<ElaboratorValue *> args;
616 for (auto arg : op.getArgs())
617 args.push_back(state.at(arg));
618
619 auto familyName = op.getSequenceAttr();
620 auto name = names.newName(familyName.getValue());
621 internalizeResult<SequenceValue>(op.getResult(), name, familyName,
622 std::move(args));
623 return DeletionKind::Delete;
624 }
625
626 FailureOr<DeletionKind> visitOp(InvokeSequenceOp op) {
627 return DeletionKind::Keep;
628 }
629
630 FailureOr<DeletionKind> visitOp(SetCreateOp op) {
631 SetVector<ElaboratorValue *> set;
632 for (auto val : op.getElements())
633 set.insert(state.at(val));
634
635 internalizeResult<SetValue>(op.getSet(), std::move(set),
636 op.getSet().getType());
637 return DeletionKind::Delete;
638 }
639
640 FailureOr<DeletionKind> visitOp(SetSelectRandomOp op) {
641 auto *set = cast<SetValue>(state.at(op.getSet()));
642
643 size_t selected;
644 if (auto intAttr =
645 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
646 std::mt19937 customRng(intAttr.getInt());
647 selected = getUniformlyInRange(customRng, 0, set->getSet().size() - 1);
648 } else {
649 selected = getUniformlyInRange(rng, 0, set->getSet().size() - 1);
650 }
651
652 state[op.getResult()] = set->getSet()[selected];
653 return DeletionKind::Delete;
654 }
655
656 FailureOr<DeletionKind> visitOp(SetDifferenceOp op) {
657 auto original = cast<SetValue>(state.at(op.getOriginal()))->getSet();
658 auto diff = cast<SetValue>(state.at(op.getDiff()))->getSet();
659
660 SetVector<ElaboratorValue *> result(original);
661 result.set_subtract(diff);
662
663 internalizeResult<SetValue>(op.getResult(), std::move(result),
664 op.getResult().getType());
665 return DeletionKind::Delete;
666 }
667
668 FailureOr<DeletionKind> visitOp(SetUnionOp op) {
669 SetVector<ElaboratorValue *> result;
670 for (auto set : op.getSets())
671 result.set_union(cast<SetValue>(state.at(set))->getSet());
672
673 internalizeResult<SetValue>(op.getResult(), std::move(result),
674 op.getType());
675 return DeletionKind::Delete;
676 }
677
678 FailureOr<DeletionKind> visitOp(SetSizeOp op) {
679 auto size = cast<SetValue>(state.at(op.getSet()))->getSet().size();
680 auto sizeAttr = IntegerAttr::get(IndexType::get(op->getContext()), size);
681 internalizeResult<AttributeValue>(op.getResult(), sizeAttr);
682 return DeletionKind::Delete;
683 }
684
685 FailureOr<DeletionKind> visitOp(BagCreateOp op) {
686 MapVector<ElaboratorValue *, uint64_t> bag;
687 for (auto [val, multiple] :
688 llvm::zip(op.getElements(), op.getMultiples())) {
689 auto *interpValue = state.at(val);
690 // If the multiple is not stored as an AttributeValue, the elaboration
691 // must have already failed earlier (since we don't have
692 // unevaluated/opaque values).
693 auto *interpMultiple = cast<IndexValue>(state.at(multiple));
694 bag[interpValue] += interpMultiple->getIndex();
695 }
696
697 internalizeResult<BagValue>(op.getBag(), std::move(bag), op.getType());
698 return DeletionKind::Delete;
699 }
700
701 FailureOr<DeletionKind> visitOp(BagSelectRandomOp op) {
702 auto *bag = cast<BagValue>(state.at(op.getBag()));
703
704 SmallVector<std::pair<ElaboratorValue *, uint32_t>> prefixSum;
705 prefixSum.reserve(bag->getBag().size());
706 uint32_t accumulator = 0;
707 for (auto [val, weight] : bag->getBag()) {
708 accumulator += weight;
709 prefixSum.push_back({val, accumulator});
710 }
711
712 auto customRng = rng;
713 if (auto intAttr =
714 op->getAttrOfType<IntegerAttr>("rtg.elaboration_custom_seed")) {
715 customRng = std::mt19937(intAttr.getInt());
716 }
717
718 auto idx = getUniformlyInRange(customRng, 0, accumulator - 1);
719 auto *iter = llvm::upper_bound(
720 prefixSum, idx,
721 [](uint32_t a, const std::pair<ElaboratorValue *, uint32_t> &b) {
722 return a < b.second;
723 });
724 state[op.getResult()] = iter->first;
725 return DeletionKind::Delete;
726 }
727
728 FailureOr<DeletionKind> visitOp(BagDifferenceOp op) {
729 auto *original = cast<BagValue>(state.at(op.getOriginal()));
730 auto *diff = cast<BagValue>(state.at(op.getDiff()));
731
732 MapVector<ElaboratorValue *, uint64_t> result;
733 for (const auto &el : original->getBag()) {
734 if (!diff->getBag().contains(el.first)) {
735 result.insert(el);
736 continue;
737 }
738
739 if (op.getInf())
740 continue;
741
742 auto toDiff = diff->getBag().lookup(el.first);
743 if (el.second <= toDiff)
744 continue;
745
746 result.insert({el.first, el.second - toDiff});
747 }
748
749 internalizeResult<BagValue>(op.getResult(), std::move(result),
750 op.getType());
751 return DeletionKind::Delete;
752 }
753
754 FailureOr<DeletionKind> visitOp(BagUnionOp op) {
755 MapVector<ElaboratorValue *, uint64_t> result;
756 for (auto bag : op.getBags()) {
757 auto *val = cast<BagValue>(state.at(bag));
758 for (auto [el, multiple] : val->getBag())
759 result[el] += multiple;
760 }
761
762 internalizeResult<BagValue>(op.getResult(), std::move(result),
763 op.getType());
764 return DeletionKind::Delete;
765 }
766
767 FailureOr<DeletionKind> visitOp(BagUniqueSizeOp op) {
768 auto size = cast<BagValue>(state.at(op.getBag()))->getBag().size();
769 auto sizeAttr = IntegerAttr::get(IndexType::get(op->getContext()), size);
770 internalizeResult<AttributeValue>(op.getResult(), sizeAttr);
771 return DeletionKind::Delete;
772 }
773
774 FailureOr<DeletionKind> visitOp(index::AddOp op) {
775 size_t lhs = cast<IndexValue>(state.at(op.getLhs()))->getIndex();
776 size_t rhs = cast<IndexValue>(state.at(op.getRhs()))->getIndex();
777 internalizeResult<IndexValue>(op.getResult(), lhs + rhs);
778 return DeletionKind::Delete;
779 }
780
781 FailureOr<DeletionKind> visitOp(index::CmpOp op) {
782 size_t lhs = cast<IndexValue>(state.at(op.getLhs()))->getIndex();
783 size_t rhs = cast<IndexValue>(state.at(op.getRhs()))->getIndex();
784 bool result;
785 switch (op.getPred()) {
786 case index::IndexCmpPredicate::EQ:
787 result = lhs == rhs;
788 break;
789 case index::IndexCmpPredicate::NE:
790 result = lhs != rhs;
791 break;
792 case index::IndexCmpPredicate::ULT:
793 result = lhs < rhs;
794 break;
795 case index::IndexCmpPredicate::ULE:
796 result = lhs <= rhs;
797 break;
798 case index::IndexCmpPredicate::UGT:
799 result = lhs > rhs;
800 break;
801 case index::IndexCmpPredicate::UGE:
802 result = lhs >= rhs;
803 break;
804 default:
805 return op->emitOpError("elaboration not supported");
806 }
807 internalizeResult<BoolValue>(op.getResult(), result);
808 return DeletionKind::Delete;
809 }
810
811 FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
812 if (op->hasTrait<OpTrait::ConstantLike>()) {
813 SmallVector<OpFoldResult, 1> result;
814 auto foldResult = op->fold(result);
815 (void)foldResult; // Make sure there is a user when assertions are off.
816 assert(succeeded(foldResult) &&
817 "constant folder of a constant-like must always succeed");
818 auto attr = dyn_cast<TypedAttr>(result[0].dyn_cast<Attribute>());
819 if (!attr)
820 return op->emitError(
821 "only typed attributes supported for constant-like operations");
822
823 auto intAttr = dyn_cast<IntegerAttr>(attr);
824 if (intAttr && isa<IndexType>(attr.getType()))
825 internalizeResult<IndexValue>(op->getResult(0), intAttr.getInt());
826 else if (intAttr && intAttr.getType().isSignlessInteger(1))
827 internalizeResult<BoolValue>(op->getResult(0), intAttr.getInt());
828 else
829 internalizeResult<AttributeValue>(op->getResult(0), attr);
830
831 return DeletionKind::Delete;
832 }
833
834 return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
835 .Case<index::AddOp, index::CmpOp>([&](auto op) { return visitOp(op); })
836 .Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); });
837 }
838
839 LogicalResult elaborate(SequenceOp family, SequenceOp dest,
840 ArrayRef<ElaboratorValue *> args) {
841 LLVM_DEBUG(llvm::dbgs() << "\n=== Elaborating " << family.getOperationName()
842 << " @" << family.getSymName() << " into @"
843 << dest.getSymName() << "\n\n");
844
845 // Reduce max memory consumption and make sure the values cannot be accessed
846 // anymore because we deleted the ops above. Clearing should lead to better
847 // performance than having them as a local here and pass via function
848 // argument.
849 state.clear();
850 materializer.reset(dest.getBody());
851 IRMapping mapping;
852
853 for (auto [arg, elabArg] :
854 llvm::zip(family.getBody()->getArguments(), args))
855 state[arg] = elabArg;
856
857 for (auto &op : *family.getBody()) {
858 if (op.getNumRegions() != 0)
859 return op.emitOpError("nested regions not supported");
860
861 auto result = dispatchOpVisitor(&op);
862 if (failed(result))
863 return failure();
864
865 if (*result == DeletionKind::Keep) {
866 for (auto &operand : op.getOpOperands()) {
867 if (mapping.contains(operand.get()))
868 continue;
869
870 auto emitError = [&]() {
871 auto diag = op.emitError();
872 diag.attachNote(op.getLoc())
873 << "while materializing value for operand#"
874 << operand.getOperandNumber();
875 return diag;
876 };
877 Value val = materializer.materialize(
878 state.at(operand.get()), op.getLoc(), worklist, emitError);
879 if (!val)
880 return failure();
881
882 mapping.map(operand.get(), val);
883 }
884
885 OpBuilder builder = OpBuilder::atBlockEnd(dest.getBody());
886 builder.clone(op, mapping);
887 }
888
889 LLVM_DEBUG({
890 llvm::dbgs() << "Elaborating " << op << " to\n[";
891
892 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) {
893 if (state.contains(res))
894 llvm::dbgs() << *state.at(res);
895 else
896 llvm::dbgs() << "unknown";
897 });
898
899 llvm::dbgs() << "]\n\n";
900 });
901 }
902
903 return success();
904 }
905
906 template <typename OpTy>
907 LogicalResult elaborateInPlace(OpTy op) {
908 LLVM_DEBUG(llvm::dbgs()
909 << "\n=== Elaborating (in place) " << op.getOperationName()
910 << " @" << op.getSymName() << "\n\n");
911
912 // Reduce max memory consumption and make sure the values cannot be accessed
913 // anymore because we deleted the ops above. Clearing should lead to better
914 // performance than having them as a local here and pass via function
915 // argument.
916 state.clear();
917 materializer.reset(op.getBody());
918
919 SmallVector<Operation *> toDelete;
920 for (auto &op : *op.getBody()) {
921 if (op.getNumRegions() != 0)
922 return op.emitOpError("nested regions not supported");
923
924 auto result = dispatchOpVisitor(&op);
925 if (failed(result))
926 return failure();
927
928 if (*result == DeletionKind::Keep) {
929 for (auto &operand : op.getOpOperands()) {
930 auto emitError = [&]() {
931 auto diag = op.emitError();
932 diag.attachNote(op.getLoc())
933 << "while materializing value for operand#"
934 << operand.getOperandNumber();
935 return diag;
936 };
937 Value val = materializer.materialize(
938 state.at(operand.get()), op.getLoc(), worklist, emitError);
939 if (!val)
940 return failure();
941 operand.set(val);
942 }
943 } else { // DeletionKind::Delete
944 toDelete.push_back(&op);
945 }
946
947 LLVM_DEBUG({
948 llvm::dbgs() << "Elaborating " << op << " to\n[";
949
950 llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) {
951 if (state.contains(res))
952 llvm::dbgs() << *state.at(res);
953 else
954 llvm::dbgs() << "unknown";
955 });
956
957 llvm::dbgs() << "]\n\n";
958 });
959 }
960
961 for (auto *op : llvm::reverse(toDelete))
962 op->erase();
963
964 return success();
965 }
966
967 LogicalResult inlineSequences(TestOp testOp) {
968 OpBuilder builder(testOp);
969 for (auto iter = testOp.getBody()->begin();
970 iter != testOp.getBody()->end();) {
971 auto invokeOp = dyn_cast<InvokeSequenceOp>(&*iter);
972 if (!invokeOp) {
973 ++iter;
974 continue;
975 }
976
977 auto seqClosureOp =
978 invokeOp.getSequence().getDefiningOp<SequenceClosureOp>();
979 if (!seqClosureOp)
980 return invokeOp->emitError(
981 "sequence operand not directly defined by sequence_closure op");
982
983 auto seqOp = table.lookup<SequenceOp>(seqClosureOp.getSequenceAttr());
984
985 builder.setInsertionPointAfter(invokeOp);
986 IRMapping mapping;
987 for (auto &op : *seqOp.getBody())
988 builder.clone(op, mapping);
989
990 (iter++)->erase();
991
992 if (seqClosureOp->use_empty())
993 seqClosureOp->erase();
994 }
995
996 return success();
997 }
998
999 LogicalResult elaborateModule(ModuleOp moduleOp) {
1000 // Update the name cache
1001 names.clear();
1002 names.add(moduleOp);
1003
1004 // Initialize the worklist with the test ops since they cannot be placed by
1005 // other ops.
1006 for (auto testOp : moduleOp.getOps<TestOp>())
1007 if (failed(elaborateInPlace(testOp)))
1008 return failure();
1009
1010 // Do top-down BFS traversal such that elaborating a sequence further down
1011 // does not fix the outcome for multiple placements.
1012 while (!worklist.empty()) {
1013 auto *curr = worklist.front();
1014 worklist.pop();
1015
1016 if (table.lookup<SequenceOp>(curr->getName()))
1017 continue;
1018
1019 auto familyOp = table.lookup<SequenceOp>(curr->getFamilyName());
1020 // TODO: use 'elaborateInPlace' and don't clone if this is the only
1021 // remaining reference to this sequence
1022 OpBuilder builder(familyOp);
1023 auto seqOp = builder.cloneWithoutRegions(familyOp);
1024 seqOp.getBodyRegion().emplaceBlock();
1025 seqOp.setSymName(curr->getName());
1026 table.insert(seqOp);
1027 assert(seqOp.getSymName() == curr->getName() &&
1028 "should not have been renamed");
1029
1030 if (failed(elaborate(familyOp, seqOp, curr->getArgs())))
1031 return failure();
1032 }
1033
1034 // Inline all sequences and remove the operations that place the sequences.
1035 for (auto testOp : moduleOp.getOps<TestOp>())
1036 if (failed(inlineSequences(testOp)))
1037 return failure();
1038
1039 // Remove all sequences since they are not accessible from the outside and
1040 // are not needed anymore since we fully inlined them.
1041 for (auto seqOp : llvm::make_early_inc_range(moduleOp.getOps<SequenceOp>()))
1042 seqOp->erase();
1043
1044 return success();
1045 }
1046
1047private:
1048 std::mt19937 rng;
1049 SymbolTable &table;
1050 Namespace names;
1051
1052 /// The worklist used to keep track of the test and sequence operations to
1053 /// make sure they are processed top-down (BFS traversal).
1054 std::queue<SequenceValue *> worklist;
1055
1056 // A map used to intern elaborator values. We do this such that we can
1057 // compare pointers when, e.g., computing set differences, uniquing the
1058 // elements in a set, etc. Otherwise, we'd need to do a deep value comparison
1059 // in those situations.
1060 // Use a pointer as the key with custom MapInfo because of object slicing when
1061 // inserting an object of a derived class of ElaboratorValue.
1062 // The custom MapInfo makes sure that we do a value comparison instead of
1063 // comparing the pointers.
1064 DenseMap<ElaboratorValue *, std::unique_ptr<ElaboratorValue>, InternMapInfo>
1065 interned;
1066
1067 // A map from SSA values to a pointer of an interned elaborator value.
1068 DenseMap<Value, ElaboratorValue *> state;
1069
1070 // Allows us to materialize ElaboratorValues to the IR operations necessary to
1071 // obtain an SSA value representing that elaborated value.
1072 Materializer materializer;
1073};
1074} // namespace
1075
1076//===----------------------------------------------------------------------===//
1077// Elaborator Pass
1078//===----------------------------------------------------------------------===//
1079
1080namespace {
1081struct ElaborationPass
1082 : public rtg::impl::ElaborationPassBase<ElaborationPass> {
1083 using Base::Base;
1084
1085 void runOnOperation() override;
1086 void cloneTargetsIntoTests(SymbolTable &table);
1087};
1088} // namespace
1089
1090void ElaborationPass::runOnOperation() {
1091 auto moduleOp = getOperation();
1092 SymbolTable table(moduleOp);
1093
1094 cloneTargetsIntoTests(table);
1095
1096 std::mt19937 rng(seed);
1097 Elaborator elaborator(table, rng);
1098 if (failed(elaborator.elaborateModule(moduleOp)))
1099 return signalPassFailure();
1100}
1101
1102void ElaborationPass::cloneTargetsIntoTests(SymbolTable &table) {
1103 auto moduleOp = getOperation();
1104 for (auto target : llvm::make_early_inc_range(moduleOp.getOps<TargetOp>())) {
1105 for (auto test : moduleOp.getOps<TestOp>()) {
1106 // If the test requires nothing from a target, we can always run it.
1107 if (test.getTarget().getEntries().empty())
1108 continue;
1109
1110 // If the target requirements do not match, skip this test
1111 // TODO: allow target refinements, just not coarsening
1112 if (target.getTarget() != test.getTarget())
1113 continue;
1114
1115 IRRewriter rewriter(test);
1116 // Create a new test for the matched target
1117 auto newTest = cast<TestOp>(test->clone());
1118 newTest.setSymName(test.getSymName().str() + "_" +
1119 target.getSymName().str());
1120 table.insert(newTest, rewriter.getInsertionPoint());
1121
1122 // Copy the target body into the newly created test
1123 IRMapping mapping;
1124 rewriter.setInsertionPointToStart(newTest.getBody());
1125 for (auto &op : target.getBody()->without_terminator())
1126 rewriter.clone(op, mapping);
1127
1128 for (auto [returnVal, result] :
1129 llvm::zip(target.getBody()->getTerminator()->getOperands(),
1130 newTest.getBody()->getArguments()))
1131 result.replaceAllUsesWith(mapping.lookup(returnVal));
1132
1133 newTest.getBody()->eraseArguments(0,
1134 newTest.getBody()->getNumArguments());
1135 newTest.setTarget(DictType::get(&getContext(), {}));
1136 }
1137
1138 target->erase();
1139 }
1140
1141 // Erase all remaining non-matched tests.
1142 for (auto test : llvm::make_early_inc_range(moduleOp.getOps<TestOp>()))
1143 if (!test.getTarget().getEntries().empty())
1144 test->erase();
1145}
assert(baseType &&"element must be base type")
static uint32_t computeMask(size_t w)
static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b)
Get a number uniformly at random in the specified range.
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
This helps visit TypeOp nodes.
Definition RTGVisitors.h:29
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:78
ResultType dispatchOpVisitor(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:31
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
static llvm::hash_code hash_value(const ModulePort &port)
Definition HWTypes.h:38
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition Utils.h:32
Definition rtg.py:1