CIRCT  18.0.0git
ConvertToArcs.cpp
Go to the documentation of this file.
1 //===- ConvertToArcs.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "../PassDetail.h"
12 #include "circt/Dialect/HW/HWOps.h"
15 #include "llvm/Support/Debug.h"
16 
17 #define DEBUG_TYPE "convert-to-arcs"
18 
19 using namespace circt;
20 using namespace arc;
21 using namespace hw;
22 using llvm::MapVector;
23 
24 static bool isArcBreakingOp(Operation *op) {
25  return op->hasTrait<OpTrait::ConstantLike>() ||
26  isa<hw::InstanceOp, seq::CompRegOp, ClockGateOp, MemoryOp,
27  ClockedOpInterface>(op) ||
28  op->getNumResults() > 1;
29 }
30 
31 //===----------------------------------------------------------------------===//
32 // Conversion
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 struct Converter {
37  LogicalResult run(ModuleOp module);
38  LogicalResult runOnModule(HWModuleOp module);
39  LogicalResult analyzeFanIn();
40  void extractArcs(HWModuleOp module);
41  LogicalResult absorbRegs(HWModuleOp module);
42 
43  /// The global namespace used to create unique definition names.
44  Namespace globalNamespace;
45 
46  /// All arc-breaking operations in the current module.
47  SmallVector<Operation *> arcBreakers;
48  SmallDenseMap<Operation *, unsigned> arcBreakerIndices;
49 
50  /// A post-order traversal of the operations in the current module.
51  SmallVector<Operation *> postOrder;
52 
53  /// The set of arc-breaking ops an operation in the current module
54  /// contributes to, represented as a bit mask.
55  MapVector<Operation *, APInt> faninMasks;
56 
57  /// The sets of operations that contribute to the same arc-breaking ops.
58  MapVector<APInt, DenseSet<Operation *>> faninMaskGroups;
59 
60  /// The arc uses generated by `extractArcs`.
61  SmallVector<StateOp> arcUses;
62 };
63 } // namespace
64 
65 LogicalResult Converter::run(ModuleOp module) {
66  for (auto &op : module.getOps())
67  if (auto sym =
68  op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
69  globalNamespace.newName(sym.getValue());
70  for (auto module : module.getOps<HWModuleOp>())
71  if (failed(runOnModule(module)))
72  return failure();
73  return success();
74 }
75 
76 LogicalResult Converter::runOnModule(HWModuleOp module) {
77  // Find all arc-breaking operations in this module and assign them an index.
78  arcBreakers.clear();
79  arcBreakerIndices.clear();
80  for (Operation &op : *module.getBodyBlock()) {
81  if (op.getNumRegions() > 0)
82  return op.emitOpError("has regions; not supported by ConvertToArcs");
83  if (!isArcBreakingOp(&op) && !isa<hw::OutputOp>(&op))
84  continue;
85  arcBreakerIndices[&op] = arcBreakers.size();
86  arcBreakers.push_back(&op);
87  }
88  // Skip modules with only `OutputOp`.
89  if (module.getBodyBlock()->without_terminator().empty() &&
90  isa<hw::OutputOp>(module.getBodyBlock()->getTerminator()))
91  return success();
92  LLVM_DEBUG(llvm::dbgs() << "Analyzing " << module.getModuleNameAttr() << " ("
93  << arcBreakers.size() << " breakers)\n");
94 
95  // For each operation, figure out the set of breaker ops it contributes to,
96  // in the form of a bit mask. Then group operations together that contribute
97  // to the same set of breaker ops.
98  if (failed(analyzeFanIn()))
99  return failure();
100 
101  // Extract the fanin mask groups into separate combinational arcs and
102  // combine them with the registers in the design.
103  extractArcs(module);
104  if (failed(absorbRegs(module)))
105  return failure();
106  return success();
107 }
108 
109 LogicalResult Converter::analyzeFanIn() {
110  SmallVector<std::tuple<Operation *, unsigned>> worklist;
111 
112  // Seed the worklist and fanin masks with the arc breaking operations.
113  faninMasks.clear();
114  for (auto *op : arcBreakers) {
115  unsigned index = arcBreakerIndices.lookup(op);
116  auto mask = APInt::getOneBitSet(arcBreakers.size(), index);
117  faninMasks[op] = mask;
118  worklist.push_back({op, 0});
119  }
120 
121  // Establish a post-order among the operations.
122  DenseSet<Operation *> seen;
123  DenseSet<Operation *> finished;
124  postOrder.clear();
125  while (!worklist.empty()) {
126  auto &[op, operandIdx] = worklist.back();
127  if (operandIdx == op->getNumOperands()) {
128  if (!isArcBreakingOp(op) && !isa<hw::OutputOp>(op))
129  postOrder.push_back(op);
130  finished.insert(op);
131  seen.erase(op);
132  worklist.pop_back();
133  continue;
134  }
135  auto operand = op->getOperand(operandIdx++); // advance to next operand
136  auto *definingOp = operand.getDefiningOp();
137  if (!definingOp || isArcBreakingOp(definingOp) ||
138  finished.contains(definingOp))
139  continue;
140  if (!seen.insert(definingOp).second) {
141  definingOp->emitError("combinational loop detected");
142  return failure();
143  }
144  worklist.push_back({definingOp, 0});
145  }
146  LLVM_DEBUG(llvm::dbgs() << "- Sorted " << postOrder.size() << " ops\n");
147 
148  // Compute fanin masks in reverse post-order, which will compute the mask
149  // for an operation's uses before it computes it for the operation itself.
150  // This allows us to compute the set of arc breakers an operation
151  // contributes to in one pass.
152  for (auto *op : llvm::reverse(postOrder)) {
153  auto mask = APInt::getZero(arcBreakers.size());
154  for (auto *user : op->getUsers()) {
155  auto it = faninMasks.find(user);
156  if (it != faninMasks.end())
157  mask |= it->second;
158  }
159 
160  auto duplicateOp = faninMasks.insert({op, mask});
161  (void)duplicateOp;
162  assert(duplicateOp.second && "duplicate op in order");
163  }
164 
165  // Group the operations by their fan-in mask.
166  faninMaskGroups.clear();
167  for (auto [op, mask] : faninMasks)
168  if (!isArcBreakingOp(op) && !isa<hw::OutputOp>(op))
169  faninMaskGroups[mask].insert(op);
170  LLVM_DEBUG(llvm::dbgs() << "- Found " << faninMaskGroups.size()
171  << " fanin mask groups\n");
172 
173  return success();
174 }
175 
176 void Converter::extractArcs(HWModuleOp module) {
177  DenseMap<Value, Value> valueMapping;
178  SmallVector<Value> inputs;
179  SmallVector<Value> outputs;
180  SmallVector<Type> inputTypes;
181  SmallVector<Type> outputTypes;
182  SmallVector<std::pair<OpOperand *, unsigned>> externalUses;
183 
184  arcUses.clear();
185  for (auto &group : faninMaskGroups) {
186  auto &opSet = group.second;
187  OpBuilder builder(module);
188 
189  auto block = std::make_unique<Block>();
190  builder.setInsertionPointToStart(block.get());
191  valueMapping.clear();
192  inputs.clear();
193  outputs.clear();
194  inputTypes.clear();
195  outputTypes.clear();
196  externalUses.clear();
197 
198  Operation *lastOp = nullptr;
199  // TODO: Remove the elements from the post order as we go.
200  for (auto *op : postOrder) {
201  if (!opSet.contains(op))
202  continue;
203  lastOp = op;
204  op->remove();
205  builder.insert(op);
206  for (auto &operand : op->getOpOperands()) {
207  if (opSet.contains(operand.get().getDefiningOp()))
208  continue;
209  auto &mapped = valueMapping[operand.get()];
210  if (!mapped) {
211  mapped = block->addArgument(operand.get().getType(),
212  operand.get().getLoc());
213  inputs.push_back(operand.get());
214  inputTypes.push_back(mapped.getType());
215  }
216  operand.set(mapped);
217  }
218  for (auto result : op->getResults()) {
219  bool anyExternal = false;
220  for (auto &use : result.getUses()) {
221  if (!opSet.contains(use.getOwner())) {
222  anyExternal = true;
223  externalUses.push_back({&use, outputs.size()});
224  }
225  }
226  if (anyExternal) {
227  outputs.push_back(result);
228  outputTypes.push_back(result.getType());
229  }
230  }
231  }
232  assert(lastOp);
233  builder.create<arc::OutputOp>(lastOp->getLoc(), outputs);
234 
235  // Create the arc definition.
236  builder.setInsertionPoint(module);
237  auto defOp = builder.create<DefineOp>(
238  lastOp->getLoc(),
239  builder.getStringAttr(
240  globalNamespace.newName(module.getModuleName() + "_arc")),
241  builder.getFunctionType(inputTypes, outputTypes));
242  defOp.getBody().push_back(block.release());
243 
244  // Create the call to the arc definition to replace the operations that
245  // we have just extracted.
246  builder.setInsertionPoint(module.getBodyBlock()->getTerminator());
247  auto arcOp = builder.create<StateOp>(lastOp->getLoc(), defOp, Value{},
248  Value{}, 0, inputs);
249  arcUses.push_back(arcOp);
250  for (auto [use, resultIdx] : externalUses)
251  use->set(arcOp.getResult(resultIdx));
252  }
253 }
254 
255 LogicalResult Converter::absorbRegs(HWModuleOp module) {
256  // Handle the trivial cases where all of an arc's results are used by
257  // exactly one register each.
258  unsigned outIdx = 0;
259  unsigned numTrivialRegs = 0;
260  for (auto &arc : arcUses) {
261  Value clock = arc.getClock();
262  Value reset;
263  SmallVector<seq::CompRegOp> absorbedRegs;
264  SmallVector<Attribute> absorbedNames(arc.getNumResults(), {});
265  if (auto names = arc->getAttrOfType<ArrayAttr>("names"))
266  absorbedNames.assign(names.getValue().begin(), names.getValue().end());
267 
268  // Go through all every arc result and collect the single register that uses
269  // it. If a result has multiple uses or is used by something other than a
270  // register, skip the arc for now and handle it later.
271  bool isTrivial = true;
272  for (auto result : arc.getResults()) {
273  if (!result.hasOneUse()) {
274  isTrivial = false;
275  break;
276  }
277  auto regOp = dyn_cast<seq::CompRegOp>(result.use_begin()->getOwner());
278  if (!regOp || regOp.getInput() != result ||
279  (clock && clock != regOp.getClk())) {
280  isTrivial = false;
281  break;
282  }
283 
284  clock = regOp.getClk();
285  reset = regOp.getReset();
286 
287  // Check that if the register has a reset, it is to a constant zero
288  if (reset) {
289  Value resetValue = regOp.getResetValue();
290  Operation *op = resetValue.getDefiningOp();
291  if (!op)
292  return regOp->emitOpError(
293  "is reset by an input; not supported by ConvertToArcs");
294  if (auto constant = dyn_cast<hw::ConstantOp>(op)) {
295  if (constant.getValue() != 0)
296  return regOp->emitOpError("is reset to a constant non-zero value; "
297  "not supported by ConvertToArcs");
298  } else {
299  return regOp->emitOpError("is reset to a value that is not clearly "
300  "constant; not supported by ConvertToArcs");
301  }
302  }
303 
304  absorbedRegs.push_back(regOp);
305  // If we absorb a register into the arc, the arc effectively produces that
306  // register's value. So if the register had a name, ensure that we assign
307  // that name to the arc's output.
308  absorbedNames[result.getResultNumber()] = regOp.getNameAttr();
309  }
310 
311  // If this wasn't a trivial case keep the arc around for a second iteration.
312  if (!isTrivial) {
313  arcUses[outIdx++] = arc;
314  continue;
315  }
316  ++numTrivialRegs;
317 
318  // Set the arc's clock to the clock of the registers we've absorbed, bump
319  // the latency up by one to account for the registers, add the reset if
320  // present and update the output names. Then replace the registers.
321  arc.getClockMutable().assign(clock);
322  arc.setLatency(arc.getLatency() + 1);
323  if (reset) {
324  if (arc.getReset())
325  return arc.emitError(
326  "StateOp tried to infer reset from CompReg, but already "
327  "had a reset.");
328  arc.getResetMutable().assign(reset);
329  }
330  if (llvm::any_of(absorbedNames, [](auto name) {
331  return !name.template cast<StringAttr>().getValue().empty();
332  }))
333  arc->setAttr("names", ArrayAttr::get(module.getContext(), absorbedNames));
334  for (auto [arcResult, reg] : llvm::zip(arc.getResults(), absorbedRegs)) {
335  auto it = arcBreakerIndices.find(reg);
336  arcBreakers[it->second] = {};
337  arcBreakerIndices.erase(it);
338  reg.replaceAllUsesWith(arcResult);
339  reg.erase();
340  }
341  }
342  if (numTrivialRegs > 0)
343  LLVM_DEBUG(llvm::dbgs() << "- Trivially converted " << numTrivialRegs
344  << " regs to arcs\n");
345  arcUses.truncate(outIdx);
346 
347  // Group the remaining registers by their clock, their reset and the operation
348  // they use as input. This will allow us to generally collapse registers
349  // derived from the same arc into one shuffling arc.
350  MapVector<std::tuple<Value, Value, Operation *>, SmallVector<seq::CompRegOp>>
351  regsByInput;
352  for (auto *op : arcBreakers)
353  if (auto regOp = dyn_cast_or_null<seq::CompRegOp>(op)) {
354  regsByInput[{regOp.getClk(), regOp.getReset(),
355  regOp.getInput().getDefiningOp()}]
356  .push_back(regOp);
357  }
358 
359  unsigned numMappedRegs = 0;
360  for (auto [clockAndResetAndOp, regOps] : regsByInput) {
361  numMappedRegs += regOps.size();
362  OpBuilder builder(module);
363  auto block = std::make_unique<Block>();
364  builder.setInsertionPointToStart(block.get());
365 
366  SmallVector<Value> inputs;
367  SmallVector<Value> outputs;
368  SmallVector<Attribute> names;
369  SmallVector<Type> types;
371  SmallVector<unsigned> regToOutputMapping;
372  for (auto regOp : regOps) {
373  auto it = mapping.find(regOp.getInput());
374  if (it == mapping.end()) {
375  it = mapping.insert({regOp.getInput(), inputs.size()}).first;
376  inputs.push_back(regOp.getInput());
377  types.push_back(regOp.getType());
378  outputs.push_back(block->addArgument(regOp.getType(), regOp.getLoc()));
379  names.push_back(regOp->getAttrOfType<StringAttr>("name"));
380  }
381  regToOutputMapping.push_back(it->second);
382  }
383 
384  auto loc = regOps.back().getLoc();
385  builder.create<arc::OutputOp>(loc, outputs);
386 
387  builder.setInsertionPoint(module);
388  auto defOp =
389  builder.create<DefineOp>(loc,
390  builder.getStringAttr(globalNamespace.newName(
391  module.getModuleName() + "_arc")),
392  builder.getFunctionType(types, types));
393  defOp.getBody().push_back(block.release());
394 
395  builder.setInsertionPoint(module.getBodyBlock()->getTerminator());
396  auto arcOp =
397  builder.create<StateOp>(loc, defOp, std::get<0>(clockAndResetAndOp),
398  /*enable=*/Value{}, 1, inputs);
399  auto reset = std::get<1>(clockAndResetAndOp);
400  if (reset)
401  arcOp.getResetMutable().assign(reset);
402  if (llvm::any_of(names, [](auto name) {
403  return !name.template cast<StringAttr>().getValue().empty();
404  }))
405  arcOp->setAttr("names", builder.getArrayAttr(names));
406  for (auto [reg, resultIdx] : llvm::zip(regOps, regToOutputMapping)) {
407  reg.replaceAllUsesWith(arcOp.getResult(resultIdx));
408  reg.erase();
409  }
410  }
411 
412  if (numMappedRegs > 0)
413  LLVM_DEBUG(llvm::dbgs() << "- Mapped " << numMappedRegs << " regs to "
414  << regsByInput.size() << " shuffling arcs\n");
415 
416  return success();
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // Pass Infrastructure
421 //===----------------------------------------------------------------------===//
422 
423 namespace {
424 struct ConvertToArcsPass : public ConvertToArcsBase<ConvertToArcsPass> {
425  void runOnOperation() override {
426  Converter converter;
427  if (failed(converter.run(getOperation())))
428  signalPassFailure();
429  }
430 };
431 } // namespace
432 
433 std::unique_ptr<OperationPass<ModuleOp>> circt::createConvertToArcsPass() {
434  return std::make_unique<ConvertToArcsPass>();
435 }
assert(baseType &&"element must be base type")
static bool isArcBreakingOp(Operation *op)
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:29
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
std::unique_ptr< OperationPass< ModuleOp > > createConvertToArcsPass()
Definition: hw.py:1
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:16