CIRCT 23.0.0git
Loading...
Searching...
No Matches
InlineSequencesPass.cpp
Go to the documentation of this file.
1//===- InlineSequencesPass.cpp - RTG InlineSequencesPass implementation ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass inlines sequences and computes sequence interleavings.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/IR/IRMapping.h"
18#include "llvm/Support/Debug.h"
19
20namespace circt {
21namespace rtg {
22#define GEN_PASS_DEF_INLINESEQUENCESPASS
23#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc"
24} // namespace rtg
25} // namespace circt
26
27using namespace mlir;
28using namespace circt;
29using namespace circt::rtg;
30
31#define DEBUG_TYPE "rtg-inline-sequences"
32
33//===----------------------------------------------------------------------===//
34// Inline Sequences Pass
35//===----------------------------------------------------------------------===//
36
37namespace {
38struct InlineSequencesPass
39 : public rtg::impl::InlineSequencesPassBase<InlineSequencesPass> {
40 using Base::Base;
41
42 void runOnOperation() override;
43};
44
/// Result reported by each visitor method, telling the driver whether the
/// visited operation should be removed from the IR.
enum class DeletionKind {
  /// The driver should erase the operation.
  Delete,
  /// The driver should leave the operation in place.
  Keep
};
48
49/// The SequenceInliner computes sequence interleavings and inlines them.
50struct SequenceInliner
51 : public RTGOpVisitor<SequenceInliner, FailureOr<DeletionKind>> {
52 using RTGOpVisitor<SequenceInliner, FailureOr<DeletionKind>>::visitOp;
53
54 SequenceInliner(ModuleOp moduleOp, bool failOnRemaining)
55 : table(moduleOp), failOnRemaining(failOnRemaining) {}
56
57 LogicalResult inlineSequences(Block &block);
58 void materializeInterleavedSequence(Value value, ArrayRef<Block *> blocks,
59 uint32_t batchSize);
60
61 FailureOr<std::pair<Block *, IRMapping>>
62 getMaterializedSequence(Value seq, Location loc) {
63 auto iter = materializedSequences.find(seq);
64 if (iter == materializedSequences.end()) {
65 StringLiteral msg = "sequence operand could not be resolved; it "
66 "was likely produced by an op or block "
67 "argument not supported by this pass";
68 if (failOnRemaining)
69 return mlir::emitError(loc, msg);
70
71 LLVM_DEBUG(llvm::dbgs() << msg << "\n");
72 return failure();
73 }
74
75 return iter->getSecond();
76 }
77
78 // Visitor methods
79
80 FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
81 SmallVector<Block *> blocks;
82 for (auto [i, seq] : llvm::enumerate(op.getSequences())) {
83 auto res = getMaterializedSequence(seq, op.getLoc());
84 if (failed(res))
85 return failure();
86
87 blocks.push_back(res->first);
88 }
89
90 LLVM_DEBUG(llvm::dbgs()
91 << " - Computing sequence interleaving: " << op << "\n");
92
93 materializeInterleavedSequence(op.getInterleavedSequence(), blocks,
94 op.getBatchSize());
95 return DeletionKind::Delete;
96 }
97
98 FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
99 auto seqOp = table.lookup<SequenceOp>(op.getSequenceAttr().getAttr());
100 if (!seqOp)
101 return op->emitError() << "referenced sequence not found";
102
103 LLVM_DEBUG(llvm::dbgs() << " - Registering existing sequence: "
104 << op.getSequence() << "\n");
105
106 materializedSequences[op.getResult()] =
107 std::make_pair(seqOp.getBody(), IRMapping());
108 return DeletionKind::Delete;
109 }
110
111 FailureOr<DeletionKind> visitOp(SubstituteSequenceOp op) {
112 LLVM_DEBUG(llvm::dbgs() << " - Substitute sequence: " << op << "\n");
113
114 auto res = getMaterializedSequence(op.getSequence(), op.getLoc());
115 if (failed(res))
116 return failure();
117
118 IRMapping mapping = res->second;
119 Block *block = res->first;
120 for (auto [arg, repl] :
121 llvm::zip(block->getArguments(), op.getReplacements())) {
122 LLVM_DEBUG(llvm::dbgs()
123 << " - Mapping " << arg << " to " << repl << "\n");
124 mapping.map(arg, repl);
125 }
126
127 materializedSequences[op.getResult()] = std::make_pair(block, mapping);
128 return DeletionKind::Delete;
129 }
130
131 FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
132 LLVM_DEBUG(llvm::dbgs() << " - Randomize sequence: " << op << "\n");
133
134 auto res = getMaterializedSequence(op.getSequence(), op.getLoc());
135 if (failed(res))
136 return failure();
137
138 // It's important to force a copy here. Without the temporary variable, we'd
139 // assign an lvalue.
140 materializedSequences[op.getResult()] = *res;
141 return DeletionKind::Delete;
142 }
143
144 FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
145 LLVM_DEBUG(llvm::dbgs() << " - Inlining sequence: " << op << "\n");
146
147 auto res = getMaterializedSequence(op.getSequence(), op.getLoc());
148 if (failed(res))
149 return failure();
150
151 OpBuilder builder(op);
152 builder.setInsertionPointAfter(op);
153 IRMapping mapping = res->second;
154
155 LLVM_DEBUG({
156 for (auto [k, v] : mapping.getValueMap())
157 llvm::dbgs() << " - Maps " << k << " to " << v << "\n";
158 });
159
160 for (auto &opToInline : *res->first) {
161 Operation *o = builder.clone(opToInline, mapping);
162 o->setLoc(op.getLoc());
163 (void)o;
164 LLVM_DEBUG(llvm::dbgs() << " - Inlined " << *o << "\n");
165 }
166
167 ++numSequencesInlined;
168
169 return DeletionKind::Delete;
170 }
171
172 FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
173 return DeletionKind::Keep;
174 }
175
176 FailureOr<DeletionKind> visitExternalOp(Operation *op) {
177 return DeletionKind::Keep;
178 }
179
180 SymbolTable table;
181 DenseMap<Value, std::pair<Block *, IRMapping>> materializedSequences;
182 SmallVector<std::unique_ptr<Block>> blockStorage;
183 size_t numSequencesInlined = 0;
184 size_t numSequencesInterleaved = 0;
185 bool failOnRemaining;
186};
187
188} // namespace
189
190void SequenceInliner::materializeInterleavedSequence(Value value,
191 ArrayRef<Block *> blocks,
192 uint32_t batchSize) {
193 auto *interleavedBlock =
194 blockStorage.emplace_back(std::make_unique<Block>()).get();
195 IRMapping mapping;
196 OpBuilder builder(value.getContext());
197 builder.setInsertionPointToStart(interleavedBlock);
198
199 SmallVector<Block::iterator> iters(blocks.size());
200 for (auto [i, block] : llvm::enumerate(blocks))
201 iters[i] = block->begin();
202
203 llvm::BitVector finishedBlocks(blocks.size());
204 for (unsigned i = 0; !finishedBlocks.all(); i = (i + 1) % blocks.size()) {
205 if (finishedBlocks[i])
206 continue;
207 for (unsigned k = 0; k < batchSize;) {
208 if (iters[i] == blocks[i]->end()) {
209 finishedBlocks.set(i);
210 break;
211 }
212 auto *op = builder.clone(*iters[i], mapping);
213 if (isa<InstructionOpInterface>(op))
214 ++k;
215 ++iters[i];
216 }
217 }
218
219 materializedSequences[value] = std::make_pair(interleavedBlock, IRMapping());
220 numSequencesInterleaved += blocks.size();
221}
222
223// NOLINTNEXTLINE(misc-no-recursion)
224LogicalResult SequenceInliner::inlineSequences(Block &block) {
225 // Make sure we inline sequences in nested regions. Walk doesn't work here
226 // because it uses 'early_inc_range' which means we'd skip sequence
227 // embeddings that we added with the previous inlining.
228 SmallVector<Operation *> toDelete;
229 for (auto &op : block) {
230 for (auto &region : op.getRegions())
231 for (auto &block : region)
232 if (failOnRemaining && failed(inlineSequences(block)))
233 return failure();
234
235 auto result = dispatchOpVisitor(&op);
236 if (failed(result))
237 return failure();
238
239 if (*result == DeletionKind::Delete)
240 toDelete.push_back(&op);
241 }
242
243 for (auto *op : llvm::reverse(toDelete))
244 op->erase();
245
246 return success();
247}
248
249void InlineSequencesPass::runOnOperation() {
250 auto moduleOp = getOperation();
251 SequenceInliner inliner(moduleOp, failOnRemaining);
252
253 // Fast-path: no sequences are defined.
254 if (moduleOp.getOps<SequenceOp>().empty())
255 return;
256
257 // Inline all sequences and remove the operations that place the sequences.
258 for (auto testOp : moduleOp.getOps<TestOp>()) {
259 auto res = inliner.inlineSequences(*testOp.getBody());
260 if (failOnRemaining && failed(res))
261 return signalPassFailure();
262 }
263
264 numSequencesInlined = inliner.numSequencesInlined;
265 numSequencesInterleaved = inliner.numSequencesInterleaved;
266}
This helps visit TypeOp nodes.
Definition RTGVisitors.h:29
ResultType visitExternalOp(Operation *op, ExtraArgs... args)
Definition RTGVisitors.h:95
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
This callback is invoked on any operations that are not handled by the concrete visitor.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition rtg.py:1
Definition seq.py:1