12 #include "mlir/IR/IRMapping.h"
13 #include "mlir/IR/ImplicitLocOpBuilder.h"
14 #include "mlir/Pass/Pass.h"
15 #include "llvm/ADT/SetVector.h"
16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "arc-split-loops"
22 #define GEN_PASS_DEF_SPLITLOOPS
23 #include "circt/Dialect/Arc/ArcPasses.h.inc"
27 using namespace circt;
30 using mlir::CallOpInterface;
32 using llvm::SmallSetVector;
40 struct ImportedValue {
55 Split(MLIRContext *context,
unsigned index,
const APInt &color)
56 : index(index), color(color), block(std::make_unique<
Block>()),
58 builder.setInsertionPointToStart(block.get());
62 void importInput(BlockArgument arg) {
63 importedValues.push_back({
true, arg.getArgNumber(), 0});
64 mapping.map(arg, block->addArgument(arg.getType(), arg.getLoc()));
68 void importFromOtherSplit(Value value, Split &otherSplit) {
69 auto resultIdx = otherSplit.exportValue(value);
70 importedValues.push_back({
false, resultIdx, otherSplit.index});
71 mapping.map(value, block->addArgument(value.getType(), value.getLoc()));
76 unsigned exportValue(Value value) {
77 value = mapping.lookup(value);
78 auto result = exportedValueIndices.insert({value, exportedValues.size()});
80 exportedValues.push_back(value);
81 return result.first->second;
87 std::unique_ptr<Block> block;
92 SmallVector<ImportedValue> importedValues;
94 SmallVector<Value> exportedValues;
100 Splitter(MLIRContext *context, Location loc) : context(context), loc(loc) {}
101 void run(Block &block, DenseMap<Operation *, APInt> &coloring);
102 Split &getSplit(
const APInt &color);
104 MLIRContext *context;
108 SmallVector<Split *> splits;
112 SmallVector<ImportedValue> outputs;
117 void Splitter::run(Block &block, DenseMap<Operation *, APInt> &coloring) {
118 for (
auto &op : block.without_terminator()) {
119 auto color = coloring.lookup(&op);
120 auto &split = getSplit(color);
123 SmallSetVector<Value, 4> operands;
124 op.walk([&](Operation *op) {
125 for (
auto operand : op->getOperands())
126 if (operand.getParentBlock() == &block)
127 operands.insert(operand);
133 for (
auto operand : operands) {
134 if (split.mapping.contains(operand))
136 if (
auto blockArg = dyn_cast<BlockArgument>(operand)) {
137 split.importInput(blockArg);
140 auto *operandOp = operand.getDefiningOp();
141 auto operandColor = coloring.lookup(operandOp);
142 assert(operandOp && color != operandColor);
143 auto &operandSplit = getSplit(operandColor);
144 split.importFromOtherSplit(operand, operandSplit);
148 split.builder.clone(op, split.mapping);
152 for (
auto operand : block.getTerminator()->getOperands()) {
153 if (
auto blockArg = dyn_cast<BlockArgument>(operand)) {
154 outputs.push_back({
true, blockArg.getArgNumber(), 0});
157 auto &operandSplit = getSplit(coloring.lookup(operand.getDefiningOp()));
158 auto resultIdx = operandSplit.exportValue(operand);
159 outputs.push_back({
false, resultIdx, operandSplit.index});
163 for (
auto &split : splits)
164 split->builder.create<arc::OutputOp>(loc, split->exportedValues);
168 Split &Splitter::getSplit(
const APInt &color) {
169 auto &split = splitsByColor[color];
171 auto index = splits.size();
172 LLVM_DEBUG(llvm::dbgs()
173 <<
"- Creating split " << index <<
" for " << color <<
"\n");
174 split = std::make_unique<Split>(context, index, color);
175 splits.push_back(split.get());
185 struct SplitLoopsPass :
public arc::impl::SplitLoopsBase<SplitLoopsPass> {
186 void runOnOperation()
override;
187 void splitArc(
Namespace &arcNamespace, DefineOp defOp,
188 ArrayRef<CallOpInterface> arcUses);
189 void replaceArcUse(CallOpInterface arcUse, ArrayRef<DefineOp> splitDefs,
190 ArrayRef<Split *> splits, ArrayRef<ImportedValue> outputs);
191 LogicalResult ensureNoLoops();
193 DenseSet<mlir::CallOpInterface> allArcUses;
197 void SplitLoopsPass::runOnOperation() {
198 auto module = getOperation();
203 DenseMap<StringAttr, DefineOp> arcDefs;
204 for (
auto arcDef : module.getOps<DefineOp>()) {
205 arcNamespace.
newName(arcDef.getSymName());
206 arcDefs[arcDef.getSymNameAttr()] = arcDef;
210 SetVector<DefineOp> arcsToSplit;
211 DenseMap<DefineOp, SmallVector<CallOpInterface>> arcUses;
212 SetVector<CallOpInterface> allArcUses;
214 auto result = module.walk([&](CallOpInterface callOp) -> WalkResult {
215 auto refSym = dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
219 return WalkResult::advance();
220 StringAttr leafRef = refSym.getLeafReference();
221 if (!arcDefs.contains(leafRef))
222 return WalkResult::advance();
224 auto defOp = arcDefs.lookup(leafRef);
225 arcUses[defOp].push_back(callOp);
226 allArcUses.insert(callOp);
228 auto clockedOp = dyn_cast<ClockedOpInterface>(callOp.getOperation());
229 if ((!clockedOp || clockedOp.getLatency() == 0) &&
230 callOp->getNumResults() > 1)
231 arcsToSplit.insert(defOp);
233 return WalkResult::advance();
236 if (result.wasInterrupted())
237 return signalPassFailure();
243 for (
auto defOp : arcsToSplit)
244 splitArc(arcNamespace, defOp, arcUses[defOp]);
247 if (failed(ensureNoLoops()))
248 return signalPassFailure();
252 void SplitLoopsPass::splitArc(
Namespace &arcNamespace, DefineOp defOp,
253 ArrayRef<CallOpInterface> arcUses) {
254 LLVM_DEBUG(llvm::dbgs() <<
"Splitting arc " << defOp.getSymNameAttr()
259 auto numResults = defOp.getNumResults();
260 DenseMap<Value, APInt> valueColoring;
261 DenseMap<Operation *, APInt> opColoring;
263 for (
auto &operand : defOp.getBodyBlock().getTerminator()->getOpOperands())
264 valueColoring.insert(
266 APInt::getOneBitSet(numResults, operand.getOperandNumber())});
268 for (
auto &op : llvm::reverse(defOp.getBodyBlock().without_terminator())) {
269 auto coloring = APInt::getZero(numResults);
270 for (
auto result : op.getResults())
271 if (
auto it = valueColoring.find(result); it != valueColoring.end())
272 coloring |= it->second;
273 opColoring.insert({&op, coloring});
274 op.walk([&](Operation *op) {
275 for (
auto &operand : op->getOpOperands())
276 valueColoring.try_emplace(operand.get(), numResults, 0).first->second |=
282 Splitter splitter(&getContext(), defOp.getLoc());
283 splitter.run(defOp.getBodyBlock(), opColoring);
286 ImplicitLocOpBuilder builder(defOp.getLoc(), defOp);
287 SmallVector<DefineOp> splitArcs;
288 splitArcs.reserve(splitter.splits.size());
289 for (
auto &split : splitter.splits) {
290 auto splitName = defOp.getSymName();
291 if (splitter.splits.size() > 1)
292 splitName = arcNamespace.
newName(defOp.getSymName() +
"_split_" +
293 Twine(split->index));
294 auto splitArc = builder.create<DefineOp>(
295 splitName, builder.getFunctionType(
296 split->block->getArgumentTypes(),
297 split->block->getTerminator()->getOperandTypes()));
298 splitArc.getBody().push_back(split->block.release());
299 splitArcs.push_back(splitArc);
304 for (
auto arcUse : arcUses)
305 replaceArcUse(arcUse, splitArcs, splitter.splits, splitter.outputs);
311 void SplitLoopsPass::replaceArcUse(CallOpInterface arcUse,
312 ArrayRef<DefineOp> splitDefs,
313 ArrayRef<Split *> splits,
314 ArrayRef<ImportedValue> outputs) {
315 ImplicitLocOpBuilder builder(arcUse.getLoc(), arcUse);
316 SmallVector<CallOp> newUses(splits.size());
320 auto getMappedValue = [&](ImportedValue value) {
322 return arcUse.getArgOperands()[value.index];
323 return newUses[value.split].getResult(value.index);
329 DenseMap<unsigned, unsigned> splitIdxMap;
330 for (
auto [i, split] : llvm::enumerate(splits))
331 splitIdxMap[split->index] = i;
333 DenseSet<unsigned> splitsDone;
334 SmallVector<std::pair<const DefineOp, const Split *>> worklist;
336 auto getMappedValuesOrSchedule = [&](ArrayRef<ImportedValue> importedValues,
337 SmallVector<Value> &operands) {
338 for (
auto importedValue : importedValues) {
339 if (!importedValue.isInput && !splitsDone.contains(importedValue.split)) {
340 unsigned idx = splitIdxMap[importedValue.split];
341 worklist.push_back({splitDefs[idx], splits[idx]});
345 operands.push_back(getMappedValue(importedValue));
352 for (
auto [splitDef, split] : llvm::reverse(llvm::zip(splitDefs, splits)))
353 worklist.push_back({splitDef, split});
356 while (!worklist.empty()) {
357 auto [splitDef, split] = worklist.back();
359 if (splitsDone.contains(split->index)) {
364 SmallVector<Value> operands;
365 if (!getMappedValuesOrSchedule(split->importedValues, operands))
368 auto newUse = builder.create<CallOp>(splitDef, operands);
369 allArcUses.insert(newUse);
370 newUses[split->index] = newUse;
372 splitsDone.insert(split->index);
377 for (
auto [result, importedValue] : llvm::zip(arcUse->getResults(), outputs))
378 result.replaceAllUsesWith(getMappedValue(importedValue));
379 allArcUses.erase(arcUse);
384 LogicalResult SplitLoopsPass::ensureNoLoops() {
385 SmallVector<std::pair<Operation *, unsigned>, 0> worklist;
386 DenseSet<Operation *> finished;
387 DenseSet<Operation *> seen;
388 for (
auto op : allArcUses) {
389 if (finished.contains(op))
392 worklist.push_back({op, 0});
393 while (!worklist.empty()) {
394 auto [op, idx] = worklist.back();
395 ++worklist.back().second;
396 if (idx == op->getNumOperands()) {
402 auto operand = op->getOperand(idx);
403 auto *def = operand.getDefiningOp();
404 if (!def || finished.contains(def))
406 if (
auto clockedOp = dyn_cast<ClockedOpInterface>(def);
407 clockedOp && clockedOp.getLatency() > 0)
409 if (!seen.insert(def).second) {
410 auto d = def->emitError(
411 "loop splitting did not eliminate all loops; loop detected");
412 for (
auto [op, idx] : llvm::reverse(worklist)) {
413 d.attachNote(op->getLoc())
414 <<
"through operand " << (idx - 1) <<
" here:";
420 worklist.push_back({def, 0});
427 return std::make_unique<SplitLoopsPass>();
assert(baseType &&"element must be base type")
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
std::unique_ptr< mlir::Pass > createSplitLoopsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)