12#include "mlir/IR/IRMapping.h"
13#include "mlir/IR/ImplicitLocOpBuilder.h"
14#include "mlir/Pass/Pass.h"
15#include "llvm/ADT/SetVector.h"
16#include "llvm/Support/Debug.h"
18#define DEBUG_TYPE "arc-split-loops"
22#define GEN_PASS_DEF_SPLITLOOPS
23#include "circt/Dialect/Arc/ArcPasses.h.inc"
30using mlir::CallOpInterface;
32using llvm::SmallSetVector;
/// Creates an empty split identified by `index` and `color`. The split owns a
/// freshly allocated `Block`, and its builder is positioned at the start of
/// that block.
// NOTE(review): the tail of the member-initializer list is not visible in this
// excerpt; `builder` is presumably initialized from `context` there - confirm.
55 Split(MLIRContext *context,
unsigned index,
const APInt &color)
56 : index(index), color(color), block(std::make_unique<
Block>()),
58 builder.setInsertionPointToStart(block.get());
/// Makes a block argument of the original block available inside this split.
/// An entry `{true, <argument number>, 0}` is appended to `importedValues`, a
/// block argument with the same type and location is added to this split's
/// block, and `mapping` is updated to translate `arg` to that new argument.
// NOTE(review): the initializer order is presumably {isInput, index, split};
// the `ImportedValue` definition is not visible here - confirm.
62 void importInput(BlockArgument arg) {
63 importedValues.push_back({
true, arg.getArgNumber(), 0});
64 mapping.map(arg, block->addArgument(arg.getType(), arg.getLoc()));
/// Makes `value`, which is produced in `otherSplit`, available inside this
/// split. `otherSplit` is asked to export the value; the index it returns and
/// `otherSplit.index` are recorded in `importedValues` behind a leading
/// `false` flag, and a new block argument of this split is registered in
/// `mapping` as the local stand-in for `value`.
68 void importFromOtherSplit(Value value, Split &otherSplit) {
69 auto resultIdx = otherSplit.exportValue(value);
70 importedValues.push_back({
false, resultIdx, otherSplit.index});
71 mapping.map(value, block->addArgument(value.getType(), value.getLoc()));
/// Registers `value` as an exported value of this split and returns its
/// export index. The value is first translated through `mapping`, so what is
/// stored is the split-local value. The index offered to
/// `exportedValueIndices.insert` is the current size of `exportedValues`; the
/// index returned is the one the container holds for the value afterwards.
// NOTE(review): as excerpted, the push_back below looks unconditional. If it
// is not guarded by the insert's "newly inserted" flag on a line that is not
// visible here, exporting the same value twice would append a duplicate and
// desynchronize indices - confirm.
76 unsigned exportValue(Value value) {
77 value = mapping.lookup(value);
78 auto result = exportedValueIndices.insert({value, exportedValues.size()});
80 exportedValues.push_back(value);
81 return result.first->second;
// Block owned by this split; the split's builder inserts into it, and
// ownership is later handed over via `release()` in splitArc.
87 std::unique_ptr<Block> block;
// One entry per value imported into this split. Every push in importInput /
// importFromOtherSplit is paired with one added block argument, so entries
// line up with the block's arguments.
92 SmallVector<ImportedValue> importedValues;
// Split-local values exported by this split, in export order; used as the
// operands of the `arc::OutputOp` that Splitter::run creates for the split.
94 SmallVector<Value> exportedValues;
/// Stores the context and location; `loc` is used for the output ops created
/// in run(), and `context` is forwarded to newly created splits.
100 Splitter(MLIRContext *context, Location loc) : context(context), loc(loc) {}
/// Distributes the operations of `block` into splits according to `coloring`.
101 void run(Block &block, DenseMap<Operation *, APInt> &coloring);
/// Returns the split associated with `color`.
102 Split &getSplit(
const APInt &color);
104 MLIRContext *context;
// Non-owning pointers to the splits in creation order; getSplit assigns each
// new split the current size of this vector as its index.
108 SmallVector<Split *> splits;
// One entry per operand of the original block's terminator, describing where
// that output value comes from (filled in by run()).
112 SmallVector<ImportedValue> outputs;
/// Distributes the non-terminator operations of `block` across splits, using
/// the color that `coloring` reports for each operation to pick the split.
///
/// Per operation: collect the operands (of the op itself and of anything
/// nested in it) that are defined directly in `block`, make each of them
/// available in the target split, then clone the op into that split. After
/// all ops are placed, record in `outputs` where each terminator operand comes
/// from, and terminate every split's block with an `arc::OutputOp`.
// NOTE(review): several closing braces and early-exit statements of this
// function are not visible in this excerpt; comments on skipped cases below
// describe the presumable intent - confirm against the full file.
117void Splitter::run(Block &block, DenseMap<Operation *, APInt> &coloring) {
118 for (
auto &op : block.without_terminator()) {
119 auto color = coloring.lookup(&op);
120 auto &split = getSplit(color);
// Operands used by `op` or by ops nested inside it. Only values whose parent
// block is `block` itself are kept.
123 SmallSetVector<Value, 4> operands;
124 op.walk([&](Operation *op) {
125 for (
auto operand : op->getOperands())
126 if (operand.getParentBlock() == &block)
127 operands.insert(operand);
// Give every collected operand a split-local counterpart in `split.mapping`.
133 for (
auto operand : operands) {
// Nothing to import when the split already has a mapping for the operand.
134 if (split.mapping.contains(operand))
// Arguments of the original block become inputs of the split.
136 if (
auto blockArg = dyn_cast<BlockArgument>(operand)) {
137 split.importInput(blockArg);
// Remaining case: the operand is an op result. The assert states that its
// defining op exists and has a different color than `op`, so the value is
// imported from the split that belongs to the defining op's color.
140 auto *operandOp = operand.getDefiningOp();
141 auto operandColor = coloring.lookup(operandOp);
142 assert(operandOp && color != operandColor);
143 auto &operandSplit = getSplit(operandColor);
144 split.importFromOtherSplit(operand, operandSplit);
// Clone the op into the split's block, passing the split's value mapping.
148 split.builder.clone(op, split.mapping);
// Describe each output of the original block: a terminator operand that is a
// block argument is recorded as `{true, <argument number>, 0}`; any other
// operand is exported from the split of its defining op and recorded as
// `{false, <export index>, <split index>}`.
152 for (
auto operand : block.getTerminator()->getOperands()) {
153 if (
auto blockArg = dyn_cast<BlockArgument>(operand)) {
154 outputs.push_back({
true, blockArg.getArgNumber(), 0});
157 auto &operandSplit = getSplit(coloring.lookup(operand.getDefiningOp()));
158 auto resultIdx = operandSplit.exportValue(operand);
159 outputs.push_back({
false, resultIdx, operandSplit.index});
// Close each split's block with an output op returning its exported values.
163 for (
auto &split : splits)
164 arc::OutputOp::create(split->builder, loc, split->exportedValues);
/// Returns the split associated with `color` in `splitsByColor`. A newly
/// created split receives the current size of `splits` as its index, is
/// constructed with the stored `context`, and is appended to `splits` as a
/// non-owning pointer while `splitsByColor` keeps ownership.
// NOTE(review): the guard that restricts creation to colors without an
// existing split, and the final return, are presumably on lines not visible
// in this excerpt - confirm.
168Split &Splitter::getSplit(
const APInt &color) {
169 auto &split = splitsByColor[color];
171 auto index = splits.size();
172 LLVM_DEBUG(llvm::dbgs()
173 <<
"- Creating split " << index <<
" for " << color <<
"\n");
174 split = std::make_unique<Split>(context, index, color);
175 splits.push_back(split.get());
/// Pass implementation derived from the generated `SplitLoopsBase`. It splits
/// selected arc definitions, rewrites their uses, and then checks that no
/// loops remain (see ensureNoLoops).
185struct SplitLoopsPass :
public arc::impl::SplitLoopsBase<SplitLoopsPass> {
186 void runOnOperation()
override;
/// Splits the arc `defOp` and rewrites each call in `arcUses` accordingly.
/// `arcNamespace` supplies unique names for the newly created arcs.
187 void splitArc(
Namespace &arcNamespace, DefineOp defOp,
188 ArrayRef<CallOpInterface> arcUses);
/// Replaces a single call `arcUse` with calls to `splitDefs`; `splits` and
/// `outputs` describe how inputs and results are wired between them.
189 void replaceArcUse(CallOpInterface arcUse, ArrayRef<DefineOp> splitDefs,
190 ArrayRef<Split *> splits, ArrayRef<ImportedValue> outputs);
/// Emits an error and returns failure if a loop is found among the tracked
/// arc uses.
191 LogicalResult ensureNoLoops();
// Arc uses tracked across the pass: replaceArcUse inserts the calls it
// creates and erases the call it replaced; ensureNoLoops iterates this set.
193 DenseSet<mlir::CallOpInterface> allArcUses;
/// Pass entry point. Collects all arc definitions of the module and the calls
/// that reference them, splits every arc selected during that scan, and
/// finally fails the pass if ensureNoLoops reports a remaining loop.
// NOTE(review): some statements of this function (e.g. the declaration of
// `arcNamespace` and the condition guarding the first early `advance()`) are
// not visible in this excerpt.
197void SplitLoopsPass::runOnOperation() {
198 auto module = getOperation();
// Register every existing arc name in the namespace, so names generated later
// via `newName` are unique with respect to them, and index the definitions by
// their symbol name attribute.
203 DenseMap<StringAttr, DefineOp> arcDefs;
204 for (
auto arcDef : module.getOps<DefineOp>()) {
205 arcNamespace.
newName(arcDef.getSymName());
206 arcDefs[arcDef.getSymNameAttr()] = arcDef;
210 SetVector<DefineOp> arcsToSplit;
211 DenseMap<DefineOp, SmallVector<CallOpInterface>> arcUses;
// NOTE(review): this local has the same name as the member `allArcUses` and
// shadows it inside this function. The insert in the walk below therefore
// fills the local set, not the member that ensureNoLoops iterates - confirm
// whether that is intended.
212 SetVector<CallOpInterface> allArcUses;
// Visit every call-like op; calls whose callee does not resolve to one of the
// collected arc definitions are skipped.
214 auto result =
module.walk([&](CallOpInterface callOp) -> WalkResult {
215 auto refSym = dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
219 return WalkResult::advance();
220 StringAttr leafRef = refSym.getLeafReference();
221 if (!arcDefs.contains(leafRef))
222 return WalkResult::advance();
224 auto defOp = arcDefs.lookup(leafRef);
225 arcUses[defOp].push_back(callOp);
226 allArcUses.insert(callOp);
// An arc is selected for splitting when one of its uses is either not a
// clocked op or has latency 0, and that use produces more than one result.
228 auto clockedOp = dyn_cast<ClockedOpInterface>(callOp.getOperation());
229 if ((!clockedOp || clockedOp.getLatency() == 0) &&
230 callOp->getNumResults() > 1)
231 arcsToSplit.insert(defOp);
233 return WalkResult::advance();
236 if (result.wasInterrupted())
237 return signalPassFailure();
// Split the selected arcs, passing along all recorded uses of each.
243 for (
auto defOp : arcsToSplit)
244 splitArc(arcNamespace, defOp, arcUses[defOp]);
// Final sanity check on the rewritten IR.
247 if (failed(ensureNoLoops()))
248 return signalPassFailure();
/// Splits the arc `defOp` into separate arcs and rewrites all `arcUses`.
///
/// Steps visible here: (1) compute a bit-vector "color" per value and per
/// operation of the arc body, with one bit per arc result; (2) let a
/// `Splitter` group the body's operations by color; (3) create one new
/// `DefineOp` per split and move the split's block into it; (4) replace every
/// use of the original arc via replaceArcUse.
// NOTE(review): several statements in this function are truncated in this
// excerpt (the debug print, the insert key for terminator operands, the loop
// header over the body ops, and the right-hand side of the `|=`); comments on
// those parts are hedged - confirm against the full file.
252void SplitLoopsPass::splitArc(
Namespace &arcNamespace, DefineOp defOp,
253 ArrayRef<CallOpInterface> arcUses) {
254 LLVM_DEBUG(llvm::dbgs() <<
"Splitting arc " << defOp.getSymNameAttr()
// Colors are APInts that are `numResults` bits wide.
259 auto numResults = defOp.getNumResults();
260 DenseMap<Value, APInt> valueColoring;
261 DenseMap<Operation *, APInt> opColoring;
// Seed: each terminator operand gets a color with exactly one bit set, at the
// position of its operand number (presumably keyed by the operand's value).
263 for (
auto &operand : defOp.
getBodyBlock().getTerminator()->getOpOperands())
264 valueColoring.insert(
266 APInt::getOneBitSet(numResults, operand.getOperandNumber())});
// For an operation `op` of the body (loop header not visible here): its color
// is the union of the colors already known for its results.
269 auto coloring = APInt::getZero(numResults);
270 for (
auto result : op.getResults())
271 if (auto it = valueColoring.find(result); it != valueColoring.end())
272 coloring |= it->second;
273 opColoring.insert({&op, coloring});
// Accumulate into the color of every operand used by `op` or by ops nested in
// it; operands without an entry start from an all-zero color of width
// `numResults`. The OR-ed value is presumably `coloring`.
274 op.walk([&](Operation *op) {
275 for (
auto &operand : op->getOpOperands())
276 valueColoring.try_emplace(operand.
get(), numResults, 0).first->second |=
// Group the body's operations into splits according to their color.
282 Splitter splitter(&getContext(), defOp.getLoc());
283 splitter.run(defOp.getBodyBlock(), opColoring);
// Materialize one arc definition per split, using a builder anchored at
// `defOp` with `defOp`'s location.
286 ImplicitLocOpBuilder builder(defOp.getLoc(), defOp);
287 SmallVector<DefineOp> splitArcs;
288 splitArcs.reserve(splitter.splits.size());
289 for (
auto &split : splitter.splits) {
// With a single split the original symbol name is reused; otherwise a unique
// name derived from "<original name>_split_<split index>" is requested from
// the namespace.
290 auto splitName = defOp.getSymName();
291 if (splitter.splits.size() > 1)
292 splitName = arcNamespace.
newName(defOp.getSymName() +
"_split_" +
293 Twine(split->index));
// The new arc's function type is derived from the split block's argument
// types and its terminator's operand types.
295 DefineOp::create(builder, splitName,
296 builder.getFunctionType(
297 split->block->getArgumentTypes(),
298 split->block->getTerminator()->getOperandTypes()));
// `release()` hands the block over to the new arc's body region; the split's
// unique_ptr no longer owns it afterwards.
299 splitArc.getBody().push_back(split->block.release());
300 splitArcs.push_back(splitArc);
// Rewrite every recorded use of the original arc.
305 for (
auto arcUse : arcUses)
306 replaceArcUse(arcUse, splitArcs, splitter.splits, splitter.outputs);
/// Replaces the call `arcUse` with one `CallOp` per split arc.
///
/// `splitDefs[i]` is the arc created for `splits[i]`. Each split's
/// `importedValues` describe the operands of its call: either an argument of
/// the original call or a result of another split's new call. Splits are
/// processed from a stack-like worklist so that a split's call is created
/// only once the calls it imports from exist. Finally the results of `arcUse`
/// are replaced according to `outputs`, and `arcUse` is removed from the
/// member set `allArcUses`.
// NOTE(review): the worklist pop statements, the `return` statements of the
// scheduling lambda, and the branch condition inside `getMappedValue` are not
// visible in this excerpt; comments about them are hedged.
312void SplitLoopsPass::replaceArcUse(CallOpInterface arcUse,
313 ArrayRef<DefineOp> splitDefs,
314 ArrayRef<Split *> splits,
315 ArrayRef<ImportedValue> outputs) {
// New calls are created through a builder anchored at the original call.
316 ImplicitLocOpBuilder builder(arcUse.getLoc(), arcUse);
// NOTE(review): `newUses` is sized by `splits.size()` but indexed by
// `split->index` / `value.split` below. That is in bounds as long as split
// indices are positions within `splits`, as assigned in Splitter::getSplit.
317 SmallVector<CallOp> newUses(splits.size());
// Resolves an ImportedValue to an SSA value: either the original call's
// argument operand at `value.index`, or result `value.index` of the new call
// created for split `value.split`. The selecting condition is presumably
// `value.isInput`.
321 auto getMappedValue = [&](ImportedValue value) {
323 return arcUse.getArgOperands()[value.index];
324 return newUses[value.split].getResult(value.index);
// Maps a split's `index` to its position in `splits` / `splitDefs`.
330 DenseMap<unsigned, unsigned> splitIdxMap;
331 for (
auto [i, split] :
llvm::enumerate(splits))
332 splitIdxMap[split->index] = i;
334 DenseSet<unsigned> splitsDone;
335 SmallVector<std::pair<const DefineOp, const Split *>> worklist;
// Appends the mapped value of each import to `operands`. If an import refers
// to a split whose call has not been created yet, that split is pushed onto
// the worklist instead. The caller treats the result as a success flag, so
// this case presumably returns false.
337 auto getMappedValuesOrSchedule = [&](ArrayRef<ImportedValue> importedValues,
338 SmallVector<Value> &operands) {
339 for (
auto importedValue : importedValues) {
340 if (!importedValue.isInput && !splitsDone.contains(importedValue.split)) {
341 unsigned idx = splitIdxMap[importedValue.split];
342 worklist.push_back({splitDefs[idx], splits[idx]});
346 operands.push_back(getMappedValue(importedValue));
// Seed the worklist in reverse, so that `back()` yields the first split first.
353 for (
auto [splitDef, split] :
llvm::reverse(
llvm::zip(splitDefs, splits)))
354 worklist.push_back({splitDef, split});
357 while (!worklist.empty()) {
358 auto [splitDef, split] = worklist.back();
// A split can be on the worklist more than once (seeded and scheduled as a
// dependency); one that already has its call is skipped.
360 if (splitsDone.contains(split->index)) {
365 SmallVector<Value> operands;
// If a dependency had to be scheduled, its entry is now on top of the
// worklist and this split is presumably revisited later.
366 if (!getMappedValuesOrSchedule(split->importedValues, operands))
// All operands are available: create the call to the split arc and record it.
369 auto newUse = CallOp::create(builder, splitDef, operands);
370 allArcUses.insert(newUse);
371 newUses[split->index] = newUse;
373 splitsDone.insert(split->index);
// Rewire users of the original call's results; `outputs` is zipped with the
// results, i.e. entry k describes result k.
378 for (
auto [result, importedValue] :
llvm::zip(arcUse->getResults(), outputs))
379 result.replaceAllUsesWith(getMappedValue(importedValue));
380 allArcUses.erase(arcUse);
/// Checks the operations in the member set `allArcUses` for loops by walking
/// from each operation to the defining ops of its operands, using an explicit
/// stack of (operation, next operand index) pairs.
///
/// Operand definitions that implement `ClockedOpInterface` with a latency
/// greater than 0 get special handling, presumably being skipped so that
/// paths through them do not count as loops. When a definition is reached
/// that is already in `seen` and not in `finished`, an error is emitted on it,
/// with one note per operation currently on the stack.
// NOTE(review): the statements that pop the worklist, insert into `finished`,
// `continue`, and return the result are not visible in this excerpt.
385LogicalResult SplitLoopsPass::ensureNoLoops() {
386 SmallVector<std::pair<Operation *, unsigned>, 0> worklist;
// `finished` is consulted to skip already-handled operations; `seen` records
// every definition that has been pushed onto the stack.
387 DenseSet<Operation *> finished;
388 DenseSet<Operation *> seen;
389 for (
auto op : allArcUses) {
390 if (finished.contains(op))
393 worklist.push_back({op, 0});
394 while (!worklist.empty()) {
// Read the current operand index, then advance the stored index so that the
// next visit of this stack entry looks at the following operand.
395 auto [op, idx] = worklist.back();
396 ++worklist.back().second;
// All operands of `op` have been visited.
397 if (idx == op->getNumOperands()) {
403 auto operand = op->getOperand(idx);
404 auto *def = operand.getDefiningOp();
// Operands without a defining op and definitions already in `finished` need
// no further traversal.
405 if (!def || finished.contains(def))
407 if (
auto clockedOp = dyn_cast<ClockedOpInterface>(def);
408 clockedOp && clockedOp.getLatency() > 0)
// `insert(...).second` is false when `def` was already in `seen`.
410 if (!seen.insert(def).second) {
411 auto d = def->emitError(
412 "loop splitting did not eliminate all loops; loop detected");
// Stored indices were already advanced past the operand being followed, hence
// the `idx - 1` in the note. These bindings shadow the outer `op` / `idx`.
413 for (
auto [op, idx] :
llvm::reverse(worklist)) {
414 d.attachNote(op->getLoc())
415 <<
"through operand " << (idx - 1) <<
" here:";
// Descend into the defining op, starting at its first operand.
421 worklist.push_back({def, 0});
/// Factory function: returns a newly allocated `SplitLoopsPass` instance.
427std::unique_ptr<Pass> arc::createSplitLoopsPass() {
428 return std::make_unique<SplitLoopsPass>();
assert(baseType &&"element must be base type")
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)