Loading [MathJax]/jax/output/HTML-CSS/config.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
InlineSequencesPass.cpp
Go to the documentation of this file.
1//===- InlineSequencesPass.cpp - RTG InlineSequencesPass implementation ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass inlines sequences and computes sequence interleavings.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/IR/IRMapping.h"
18#include "llvm/Support/Debug.h"
19
20namespace circt {
21namespace rtg {
22#define GEN_PASS_DEF_INLINESEQUENCESPASS
23#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
24} // namespace rtg
25} // namespace circt
26
27using namespace mlir;
28using namespace circt;
29using namespace circt::rtg;
30
31#define DEBUG_TYPE "rtg-inline-sequences"
32
33//===----------------------------------------------------------------------===//
34// Inline Sequences Pass
35//===----------------------------------------------------------------------===//
36
37namespace {
38struct InlineSequencesPass
39 : public rtg::impl::InlineSequencesPassBase<InlineSequencesPass> {
40 using Base::Base;
41
42 void runOnOperation() override;
43};
44
45/// Enum to indicate to the visitor driver whether the operation should be
46/// deleted.
47enum class DeletionKind { Delete, Keep };
48
49/// The SequenceInliner computes sequence interleavings and inlines them.
50struct SequenceInliner
51 : public RTGOpVisitor<SequenceInliner, FailureOr<DeletionKind>> {
52 using RTGOpVisitor<SequenceInliner, FailureOr<DeletionKind>>::visitOp;
53
54 SequenceInliner(ModuleOp moduleOp) : table(moduleOp) {}
55
56 LogicalResult inlineSequences(TestOp testOp);
57 void materializeInterleavedSequence(Value value, ArrayRef<Block *> blocks,
58 uint32_t batchSize);
59
60 // Visitor methods
61
62 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
63 SmallVector<Block *> blocks;
64 for (auto [i, seq] : llvm::enumerate(op.getSequences())) {
65 auto iter = materializedSequences.find(seq);
66 if (iter == materializedSequences.end())
67 return op->emitError()
68 << "sequence operand #" << i
69 << " could not be resolved; it was likely produced by an op or "
70 "block argument not supported by this pass";
71
72 blocks.push_back(iter->getSecond().first);
73 }
74
75 LLVM_DEBUG(llvm::dbgs()
76 << " - Computing sequence interleaving: " << op << "\n");
77
78 materializeInterleavedSequence(op.getInterleavedSequence(), blocks,
79 op.getBatchSize());
80 return DeletionKind::Delete;
81 }
82
83 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
84 auto seqOp = table.lookup<SequenceOp>(op.getSequenceAttr());
85 if (!seqOp)
86 return op->emitError() << "referenced sequence not found";
87
88 LLVM_DEBUG(llvm::dbgs() << " - Registering existing sequence: "
89 << op.getSequence() << "\n");
90
91 materializedSequences[op.getResult()] =
92 std::make_pair(seqOp.getBody(), IRMapping());
93 return DeletionKind::Delete;
94 }
95
96 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
97 LLVM_DEBUG(llvm::dbgs() << " - Substitute sequence: " << op << "\n");
98
99 auto iter = materializedSequences.find(op.getSequence());
100 if (iter == materializedSequences.end())
101 return op->emitError() << "sequence operand could not be resolved; it "
102 "was likely produced by an op or block "
103 "argument not supported by this pass";
104
105 IRMapping mapping = iter->getSecond().second;
106 Block *block = iter->getSecond().first;
107 for (auto [arg, repl] :
108 llvm::zip(block->getArguments(), op.getReplacements())) {
109 LLVM_DEBUG(llvm::dbgs()
110 << " - Mapping " << arg << " to " << repl << "\n");
111 mapping.map(arg, repl);
112 }
113
114 materializedSequences[op.getResult()] = std::make_pair(block, mapping);
115 return DeletionKind::Delete;
116 }
117
118 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
119 LLVM_DEBUG(llvm::dbgs() << " - Randomize sequence: " << op << "\n");
120
121 auto iter = materializedSequences.find(op.getSequence());
122 if (iter == materializedSequences.end())
123 return op->emitError() << "sequence operand could not be resolved; it "
124 "was likely produced by an op or block "
125 "argument not supported by this pass";
126
127 // It's important to force a copy here. Without the temporary variable, we'd
128 // assign an lvalue.
129 auto value = iter->getSecond();
130 materializedSequences[op.getResult()] = value;
131 return DeletionKind::Delete;
132 }
133
134 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
135 LLVM_DEBUG(llvm::dbgs() << " - Inlining sequence: " << op << "\n");
136
137 auto iter = materializedSequences.find(op.getSequence());
138 if (iter == materializedSequences.end())
139 return op->emitError() << "sequence operand could not be resolved; it "
140 "was likely produced by an op or block "
141 "argument not supported by this pass";
142
143 OpBuilder builder(op);
144 builder.setInsertionPointAfter(op);
145 IRMapping mapping = iter->getSecond().second;
146
147 LLVM_DEBUG({
148 for (auto [k, v] : mapping.getValueMap())
149 llvm::dbgs() << " - Maps " << k << " to " << v << "\n";
150 });
151
152 for (auto &op : *iter->getSecond().first) {
153 Operation *o = builder.clone(op, mapping);
154 (void)o;
155 LLVM_DEBUG(llvm::dbgs() << " - Inlined " << *o << "\n");
156 }
157
158 ++numSequencesInlined;
159
160 return DeletionKind::Delete;
161 }
162
163 FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
164 return DeletionKind::Keep;
165 }
166
167 FailureOr<DeletionKind> visitExternalOp(Operation *op) {
168 return DeletionKind::Keep;
169 }
170
171 SymbolTable table;
172 DenseMap<Value, std::pair<Block *, IRMapping>> materializedSequences;
173 SmallVector<std::unique_ptr<Block>> blockStorage;
174 size_t numSequencesInlined = 0;
175 size_t numSequencesInterleaved = 0;
176};
177
178} // namespace
179
180void SequenceInliner::materializeInterleavedSequence(Value value,
181 ArrayRef<Block *> blocks,
182 uint32_t batchSize) {
183 auto *interleavedBlock =
184 blockStorage.emplace_back(std::make_unique<Block>()).get();
185 IRMapping mapping;
186 OpBuilder builder(value.getContext());
187 builder.setInsertionPointToStart(interleavedBlock);
188
189 SmallVector<Block::iterator> iters(blocks.size());
190 for (auto [i, block] : llvm::enumerate(blocks))
191 iters[i] = block->begin();
192
193 llvm::BitVector finishedBlocks(blocks.size());
194 for (unsigned i = 0; !finishedBlocks.all(); i = (i + 1) % blocks.size()) {
195 if (finishedBlocks[i])
196 continue;
197 for (unsigned k = 0; k < batchSize;) {
198 if (iters[i] == blocks[i]->end()) {
199 finishedBlocks.set(i);
200 break;
201 }
202 auto *op = builder.clone(*iters[i], mapping);
203 if (isa<InstructionOpInterface>(op))
204 ++k;
205 ++iters[i];
206 }
207 }
208
209 materializedSequences[value] = std::make_pair(interleavedBlock, IRMapping());
210 numSequencesInterleaved += blocks.size();
211}
212
213LogicalResult SequenceInliner::inlineSequences(TestOp testOp) {
214 LLVM_DEBUG(llvm::dbgs() << "\n=== Processing test @" << testOp.getSymName()
215 << "\n\n");
216
217 SmallVector<Operation *> toDelete;
218 for (auto &op : *testOp.getBody()) {
219 auto result = dispatchOpVisitor(&op);
220 if (failed(result))
221 return failure();
222
223 if (*result == DeletionKind::Delete)
224 toDelete.push_back(&op);
225 }
226
227 for (auto *op : llvm::reverse(toDelete))
228 op->erase();
229
230 return success();
231}
232
233void InlineSequencesPass::runOnOperation() {
234 auto moduleOp = getOperation();
235 SequenceInliner inliner(moduleOp);
236
237 // Inline all sequences and remove the operations that place the sequences.
238 for (auto testOp : moduleOp.getOps<TestOp>())
239 if (failed(inliner.inlineSequences(testOp)))
240 return signalPassFailure();
241
242 numSequencesInlined = inliner.numSequencesInlined;
243 numSequencesInterleaved = inliner.numSequencesInterleaved;
244}
This helps visit TypeOp nodes.
Definition RTGVisitors.h:29
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:90
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition rtg.py:1
Definition seq.py:1