12 #include "mlir/IR/IRMapping.h"
13 #include "mlir/IR/ImplicitLocOpBuilder.h"
14 #include "mlir/Pass/Pass.h"
15 #include "llvm/ADT/SetVector.h"
16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "arc-split-loops"
22 #define GEN_PASS_DEF_SPLITLOOPS
23 #include "circt/Dialect/Arc/ArcPasses.h.inc"
27 using namespace circt;
31 using llvm::SmallSetVector;
// A value imported into a split. Judging by its uses in this file
// (`importInput`, `importFromOtherSplit`, `getMappedValue`), it is either:
//  - an input of the original arc (`isInput` true, `index` = argument number), or
//  - a result exported by another split (`isInput` false, `index` = that split's
//    output number, `split` = that split's index).
// NOTE(review): the field declarations (original lines 40-53) are missing from
// this listing; the field names above are inferred from their uses.
39 struct ImportedValue {
// Create an empty split numbered `index` that collects the operations colored
// `color`. The split owns a fresh block. Its builder is positioned at the start
// of that (initially empty) block, so cloned operations are appended in order.
// NOTE(review): original line 56 (presumably the `builder` member initializer)
// is missing from this listing.
54 Split(MLIRContext *context,
unsigned index,
const APInt &color)
55 : index(index), color(color), block(std::make_unique<
Block>()),
57 builder.setInsertionPointToStart(block.get());
// Map argument `arg` of the original arc into this split. The origin is recorded
// as an original-arc input, and `arg` is mapped to a new block argument with the
// same type and location. `importedValues` and the block arguments therefore
// stay in lockstep.
61 void importInput(BlockArgument arg) {
62 importedValues.push_back({
true, arg.getArgNumber(), 0});
63 mapping.map(arg, block->addArgument(arg.getType(), arg.getLoc()));
// Map `value`, which is produced in `otherSplit`, into this split.
// `otherSplit` is asked to export the value; an existing export is reused. The
// resulting (output index, split index) pair is recorded, and `value` is mapped
// to a new block argument of this split.
67 void importFromOtherSplit(Value
value, Split &otherSplit) {
68 auto resultIdx = otherSplit.exportValue(
value);
69 importedValues.push_back({
false, resultIdx, otherSplit.index});
70 mapping.map(
value, block->addArgument(
value.getType(),
value.getLoc()));
// Expose `value` as an output of this split and return its output index.
// `exportedValueIndices.insert` deduplicates exports: if the value was already
// exported, the index recorded at that time is returned.
75 unsigned exportValue(Value
value) {
77 auto result = exportedValueIndices.insert({
value, exportedValues.size()});
// NOTE(review): original lines 76 and 78 are missing from this listing.
// Presumably line 78 guards this push_back so that only newly inserted values
// are appended. Line 76 may translate `value` into this split (e.g. through
// `mapping`). Confirm against the full source.
79 exportedValues.push_back(
value);
80 return result.first->second;
// The block holding the operations cloned into this split.
86 std::unique_ptr<Block> block;
// Origin of each block argument, in argument order (appended together with the
// argument by importInput / importFromOtherSplit).
91 SmallVector<ImportedValue> importedValues;
// Values of this split exposed as outputs, in output order.
93 SmallVector<Value> exportedValues;
// Splitter: distributes the operations of one arc body over several splits, one
// split per distinct operation coloring.
// NOTE(review): the `loc` member and `splitsByColor` (original lines 104 and
// 108) are missing from this listing.
99 Splitter(MLIRContext *context, Location loc) : context(context), loc(loc) {}
100 void run(Block &block, DenseMap<Operation *, APInt> &coloring);
101 Split &getSplit(
const APInt &color);
103 MLIRContext *context;
// Splits in creation order. A split's position here equals its `index`
// (see getSplit).
107 SmallVector<Split *> splits;
// For each terminator operand of the original arc, where that value comes from
// after splitting.
111 SmallVector<ImportedValue>
outputs;
// Distribute the non-terminator operations of `block` over splits according to
// `coloring`. For each operation:
//  - fetch (or create) the split for its color;
//  - import the operands it needs, either as arc inputs or from other splits;
//  - clone it into that split.
// Afterwards, record where each terminator operand ends up, and terminate every
// split with an `arc.output` of its exported values.
// NOTE(review): several original lines (e.g. 120-121, 127-131, 134, 137-138,
// 144-146, 148-150, 154-155, 159-161) are missing from this listing, so the
// exact control flow between the visible statements cannot be confirmed here.
116 void Splitter::run(Block &block, DenseMap<Operation *, APInt> &coloring) {
117 for (
auto &op : block.without_terminator()) {
118 auto color = coloring.lookup(&op);
119 auto &split = getSplit(color);
// Collect the values used by `op` or any op nested in it that are defined
// directly in `block`. Values defined inside `op`'s own regions have a
// different parent block and are cloned along with `op`.
122 SmallSetVector<Value, 4> operands;
123 op.walk([&](Operation *op) {
124 for (
auto operand : op->getOperands())
125 if (operand.getParentBlock() == &block)
126 operands.insert(operand);
// Make each operand not yet mapped in this split available to it. Block
// arguments are imported as arc inputs; everything else comes from the split
// that defines it.
132 for (
auto operand : operands) {
133 if (split.mapping.contains(operand))
135 if (
auto blockArg = dyn_cast<BlockArgument>(operand)) {
136 split.importInput(blockArg);
// NOTE(review): `operandOp` is used as a lookup key before the assert checks
// that it is non-null; asserting first would be clearer.
139 auto *operandOp = operand.getDefiningOp();
140 auto operandColor = coloring.lookup(operandOp);
141 assert(operandOp && color != operandColor);
142 auto &operandSplit = getSplit(operandColor);
143 split.importFromOtherSplit(operand, operandSplit);
// Clone `op` into the split, remapping its operands through the split's mapping.
147 split.builder.clone(op, split.mapping);
// Record where each output of the original arc comes from: an arc input
// directly, or an output exported by the split that defines it.
151 for (
auto operand : block.getTerminator()->getOperands()) {
152 if (
auto blockArg = dyn_cast<BlockArgument>(operand)) {
153 outputs.push_back({
true, blockArg.getArgNumber(), 0});
156 auto &operandSplit = getSplit(coloring.lookup(operand.getDefiningOp()));
157 auto resultIdx = operandSplit.exportValue(operand);
158 outputs.push_back({
false, resultIdx, operandSplit.index});
// Terminate every split with an `arc.output` of the values it exports.
162 for (
auto &split : splits)
163 split->builder.create<arc::OutputOp>(loc, split->exportedValues);
// Return the split for operations colored `color`.
// A new split is numbered by creation order (its index is the current size of
// `splits`). It is owned by `splitsByColor` and also appended to `splits`.
// NOTE(review): original lines 169, 171 and 175-177 are missing from this
// listing. Presumably they contain the `if (!split)` guard around the creation
// code, the start of the LLVM_DEBUG statement, and `return *split;`.
167 Split &Splitter::getSplit(
const APInt &color) {
168 auto &split = splitsByColor[color];
170 auto index = splits.size();
172 <<
"- Creating split " << index <<
" for " << color <<
"\n");
173 split = std::make_unique<Split>(context, index, color);
174 splits.push_back(split.get());
// Pass that splits arcs used with zero latency and more than one result into
// one arc per operation coloring, and then checks that no loops through arcs
// remain.
184 struct SplitLoopsPass :
public arc::impl::SplitLoopsBase<SplitLoopsPass> {
185 void runOnOperation()
override;
186 void splitArc(
Namespace &arcNamespace, DefineOp defOp,
187 ArrayRef<StateOp> arcUses);
188 void replaceArcUse(StateOp arcUse, ArrayRef<DefineOp> splitDefs,
189 ArrayRef<Split *> splits, ArrayRef<ImportedValue>
outputs);
190 LogicalResult ensureNoLoops();
// Roots of the loop check in ensureNoLoops. replaceArcUse inserts each newly
// created split use and erases the use it replaces.
// NOTE(review): runOnOperation must populate this member; a local variable of
// the same name there would shadow it.
192 DenseSet<StateOp> allArcUses;
196 void SplitLoopsPass::runOnOperation() {
197 auto module = getOperation();
202 DenseMap<StringAttr, DefineOp> arcDefs;
203 for (
auto arcDef : module.getOps<DefineOp>()) {
204 arcNamespace.
newName(arcDef.getSymName());
205 arcDefs[arcDef.getSymNameAttr()] = arcDef;
209 SetVector<DefineOp> arcsToSplit;
210 DenseMap<DefineOp, SmallVector<StateOp>> arcUses;
211 SetVector<StateOp> allArcUses;
213 module.walk([&](StateOp stateOp) {
214 auto sym = stateOp.getArcAttr().
getAttr();
215 auto defOp = arcDefs.lookup(sym);
216 arcUses[defOp].push_back(stateOp);
217 allArcUses.insert(stateOp);
218 if (stateOp.getLatency() == 0 && stateOp.getNumResults() > 1)
219 arcsToSplit.insert(defOp);
226 for (
auto defOp : arcsToSplit)
227 splitArc(arcNamespace, defOp, arcUses[defOp]);
230 if (failed(ensureNoLoops()))
231 return signalPassFailure();
// Split `defOp` into one arc per distinct operation coloring, then rewrite each
// use in `arcUses` to use the new arcs.
// A coloring is a bit set with one bit per arc result. Each terminator operand
// is seeded with the single bit of its operand number. Walking the body in
// reverse, an op's coloring is the OR of its results' colorings, and is then
// propagated to the operands of the op and of its nested ops.
// NOTE(review): original lines 238-241, 248 and 260 are missing from this
// listing. Line 248 presumably opens the `{operand.get(), ...}` pair, and line
// 260 presumably supplies `coloring` as the right-hand side of `|=`.
235 void SplitLoopsPass::splitArc(
Namespace &arcNamespace, DefineOp defOp,
236 ArrayRef<StateOp> arcUses) {
237 LLVM_DEBUG(
llvm::dbgs() <<
"Splitting arc " << defOp.getSymNameAttr()
242 auto numResults = defOp.getNumResults();
243 DenseMap<Value, APInt> valueColoring;
244 DenseMap<Operation *, APInt> opColoring;
// Seed: terminator operand N gets only bit N set.
246 for (
auto &operand : defOp.getBodyBlock().getTerminator()->getOpOperands())
247 valueColoring.insert(
249 APInt::getOneBitSet(numResults, operand.getOperandNumber())});
// Propagate colorings from users back to the values they use.
251 for (
auto &op : llvm::reverse(defOp.getBodyBlock().without_terminator())) {
252 auto coloring = APInt::getZero(numResults);
253 for (
auto result : op.getResults())
254 if (
auto it = valueColoring.find(result); it != valueColoring.end())
255 coloring |= it->second;
256 opColoring.insert({&op, coloring});
257 op.walk([&](Operation *op) {
258 for (
auto &operand : op->getOpOperands())
259 valueColoring.try_emplace(operand.get(), numResults, 0).first->second |=
// Distribute the body's operations over splits according to their coloring.
265 Splitter splitter(&getContext(), defOp.getLoc());
266 splitter.run(defOp.getBodyBlock(), opColoring);
// Create one DefineOp per split, inserted before the original arc. With a
// single split the original name is reused. Otherwise each split gets a unique
// name derived from "<name>_split_<index>".
269 ImplicitLocOpBuilder
builder(defOp.getLoc(), defOp);
270 SmallVector<DefineOp> splitArcs;
271 splitArcs.reserve(splitter.splits.size());
272 for (
auto &split : splitter.splits) {
273 auto splitName = defOp.getSymName();
274 if (splitter.splits.size() > 1)
275 splitName = arcNamespace.
newName(defOp.getSymName() +
"_split_" +
276 Twine(split->index));
277 auto splitArc =
builder.create<DefineOp>(
278 splitName,
builder.getFunctionType(
279 split->block->getArgumentTypes(),
280 split->block->getTerminator()->getOperandTypes()));
// Transfer ownership of the split's block into the new arc's body.
281 splitArc.getBody().push_back(split->block.release());
282 splitArcs.push_back(splitArc);
// Rewrite every use of the original arc in terms of the split arcs.
286 for (
auto arcUse : arcUses)
287 replaceArcUse(arcUse, splitArcs, splitter.splits, splitter.outputs);
// Replace `arcUse` with one StateOp per split arc in `splitDefs`/`splits`.
// A split is instantiated only after every split it imports from exists.
// Uses of `arcUse`'s results are then redirected according to `outputs`.
// NOTE(review): original lines 297-299, 301, 304-308, 322-324, 326-331,
// 338, 340-343, 346-348 and 352-357 are missing from this listing.
292 void SplitLoopsPass::replaceArcUse(StateOp arcUse, ArrayRef<DefineOp> splitDefs,
293 ArrayRef<Split *> splits,
294 ArrayRef<ImportedValue>
outputs) {
295 ImplicitLocOpBuilder
builder(arcUse.getLoc(), arcUse);
296 SmallVector<StateOp> newUses(splits.size());
// Resolve an ImportedValue to a concrete value: an input of `arcUse`, or a
// result of an already created split use. (Line 301, presumably
// `if (value.isInput)`, is missing from this listing.)
300 auto getMappedValue = [&](ImportedValue
value) {
302 return arcUse.getInputs()[
value.index];
303 return newUses[
value.split].getResult(
value.index);
// Position of each split in `splits`/`splitDefs`, keyed by the split's index.
309 DenseMap<unsigned, unsigned> splitIdxMap;
310 for (
auto [i, split] : llvm::enumerate(splits))
311 splitIdxMap[split->index] = i;
313 DenseSet<unsigned> splitsDone;
314 SmallVector<std::pair<const DefineOp, const Split *>> worklist;
// Gather the operands for a split. Any producing split that has not been
// created yet is pushed onto the worklist. The return statements are missing
// from this listing; presumably this returns false when something was
// scheduled.
316 auto getMappedValuesOrSchedule = [&](ArrayRef<ImportedValue> importedValues,
317 SmallVector<Value> &operands) {
318 for (
auto importedValue : importedValues) {
319 if (!importedValue.isInput && !splitsDone.contains(importedValue.split)) {
320 unsigned idx = splitIdxMap[importedValue.split];
321 worklist.push_back({splitDefs[idx], splits[idx]});
325 operands.push_back(getMappedValue(importedValue));
// Seed the worklist in reverse so that splits are popped in their original order.
332 for (
auto [splitDef, split] : llvm::reverse(llvm::zip(splitDefs, splits)))
333 worklist.push_back({splitDef, split});
336 while (!worklist.empty()) {
337 auto [splitDef, split] = worklist.back();
339 if (splitsDone.contains(split->index)) {
// Create the use of this split once all of its operands are available, and
// register it as a root for the loop check. The two null `Value{}` arguments
// and the `0` are passed as-is. Presumably they are the clock/enable operands
// and the latency; confirm against the StateOp builder signature.
344 SmallVector<Value> operands;
345 if (!getMappedValuesOrSchedule(split->importedValues, operands))
349 builder.create<StateOp>(splitDef, Value{}, Value{}, 0, operands);
350 allArcUses.insert(newUse);
351 newUses[split->index] = newUse;
353 splitsDone.insert(split->index);
// Redirect the original results, and drop the replaced use from the loop-check
// roots.
358 for (
auto [result, importedValue] : llvm::zip(arcUse.getResults(),
outputs))
359 result.replaceAllUsesWith(getMappedValue(importedValue));
360 allArcUses.erase(arcUse);
// Depth-first search from every op in `allArcUses` through the defining ops of
// operands, looking for a cycle. On a cycle, an error is emitted at the
// revisited op, with one note per step of the current path.
// NOTE(review): original lines 371-372, 378-382, 386, 389 and 396-400 are
// missing from this listing. They presumably hold the `continue`s and the
// finished/seen bookkeeping, so the exact skip behaviour can't be confirmed here.
// NOTE(review): `allArcUses` is a DenseSet keyed by op pointer. Its iteration
// order, and therefore which loop is reported first, may differ between runs.
365 LogicalResult SplitLoopsPass::ensureNoLoops() {
// Each entry is (op, index of the next operand to visit).
366 SmallVector<std::pair<Operation *, unsigned>, 0> worklist;
367 DenseSet<Operation *> finished;
368 DenseSet<Operation *> seen;
369 for (
auto op : allArcUses) {
370 if (finished.contains(op))
373 worklist.push_back({op, 0});
374 while (!worklist.empty()) {
375 auto [op, idx] = worklist.back();
// Advance the stored index now. The stored value therefore stays one ahead
// of the operand being followed, which is why notes print `idx - 1`.
376 ++worklist.back().second;
377 if (idx == op->getNumOperands()) {
383 auto operand = op->getOperand(idx);
384 auto *def = operand.getDefiningOp();
// Candidates to skip: block arguments (no defining op), ops already fully
// explored, and StateOps with non-zero latency. The guarded statements are on
// missing lines.
385 if (!def || finished.contains(def))
387 if (
auto stateOp = dyn_cast<StateOp>(def);
388 stateOp && stateOp.getLatency() > 0)
// Reaching an op that is already in `seen` is reported as a loop.
390 if (!seen.insert(def).second) {
391 auto d = def->emitError(
392 "loop splitting did not eliminate all loops; loop detected");
393 for (
auto [op, idx] : llvm::reverse(worklist)) {
394 d.attachNote(op->getLoc())
395 <<
"through operand " << (idx - 1) <<
" here:";
// Descend into the defining op.
401 worklist.push_back({def, 0});
// Body of the pass factory (`createSplitLoopsPass`); returns a fresh
// SplitLoopsPass instance. The enclosing function header is missing from this
// listing.
408 return std::make_unique<SplitLoopsPass>();
assert(baseType &&"element must be base type")
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
llvm::SmallVector< StringAttr > outputs
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
std::unique_ptr< mlir::Pass > createSplitLoopsPass()
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & dbgs()