CIRCT  19.0.0git
ConvertToArcs.cpp
Go to the documentation of this file.
1 //===- ConvertToArcs.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "../PassDetail.h"
12 #include "circt/Dialect/HW/HWOps.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "convert-to-arcs"
19 
20 using namespace circt;
21 using namespace arc;
22 using namespace hw;
23 using llvm::MapVector;
24 
25 static bool isArcBreakingOp(Operation *op) {
26  return op->hasTrait<OpTrait::ConstantLike>() ||
27  isa<hw::InstanceOp, seq::CompRegOp, MemoryOp, ClockedOpInterface,
28  seq::ClockGateOp>(op) ||
29  op->getNumResults() > 1;
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // Conversion
34 //===----------------------------------------------------------------------===//
35 
36 namespace {
37 struct Converter {
38  LogicalResult run(ModuleOp module);
39  LogicalResult runOnModule(HWModuleOp module);
40  LogicalResult analyzeFanIn();
41  void extractArcs(HWModuleOp module);
42  LogicalResult absorbRegs(HWModuleOp module);
43 
44  /// The global namespace used to create unique definition names.
45  Namespace globalNamespace;
46 
47  /// All arc-breaking operations in the current module.
48  SmallVector<Operation *> arcBreakers;
49  SmallDenseMap<Operation *, unsigned> arcBreakerIndices;
50 
51  /// A post-order traversal of the operations in the current module.
52  SmallVector<Operation *> postOrder;
53 
54  /// The set of arc-breaking ops an operation in the current module
55  /// contributes to, represented as a bit mask.
56  MapVector<Operation *, APInt> faninMasks;
57 
58  /// The sets of operations that contribute to the same arc-breaking ops.
59  MapVector<APInt, DenseSet<Operation *>> faninMaskGroups;
60 
61  /// The arc uses generated by `extractArcs`.
62  SmallVector<mlir::CallOpInterface> arcUses;
63 
64  /// Whether registers should be made observable by assigning their arcs a
65  /// "name" attribute.
66  bool tapRegisters;
67 };
68 } // namespace
69 
70 LogicalResult Converter::run(ModuleOp module) {
71  for (auto &op : module.getOps())
72  if (auto sym =
73  op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
74  globalNamespace.newName(sym.getValue());
75  for (auto module : module.getOps<HWModuleOp>())
76  if (failed(runOnModule(module)))
77  return failure();
78  return success();
79 }
80 
81 LogicalResult Converter::runOnModule(HWModuleOp module) {
82  // Find all arc-breaking operations in this module and assign them an index.
83  arcBreakers.clear();
84  arcBreakerIndices.clear();
85  for (Operation &op : *module.getBodyBlock()) {
86  if (op.getNumRegions() > 0)
87  return op.emitOpError("has regions; not supported by ConvertToArcs");
88  if (!isArcBreakingOp(&op) && !isa<hw::OutputOp>(&op))
89  continue;
90  arcBreakerIndices[&op] = arcBreakers.size();
91  arcBreakers.push_back(&op);
92  }
93  // Skip modules with only `OutputOp`.
94  if (module.getBodyBlock()->without_terminator().empty() &&
95  isa<hw::OutputOp>(module.getBodyBlock()->getTerminator()))
96  return success();
97  LLVM_DEBUG(llvm::dbgs() << "Analyzing " << module.getModuleNameAttr() << " ("
98  << arcBreakers.size() << " breakers)\n");
99 
100  // For each operation, figure out the set of breaker ops it contributes to,
101  // in the form of a bit mask. Then group operations together that contribute
102  // to the same set of breaker ops.
103  if (failed(analyzeFanIn()))
104  return failure();
105 
106  // Extract the fanin mask groups into separate combinational arcs and
107  // combine them with the registers in the design.
108  extractArcs(module);
109  if (failed(absorbRegs(module)))
110  return failure();
111  return success();
112 }
113 
114 LogicalResult Converter::analyzeFanIn() {
115  SmallVector<std::tuple<Operation *, unsigned>> worklist;
116 
117  // Seed the worklist and fanin masks with the arc breaking operations.
118  faninMasks.clear();
119  for (auto *op : arcBreakers) {
120  unsigned index = arcBreakerIndices.lookup(op);
121  auto mask = APInt::getOneBitSet(arcBreakers.size(), index);
122  faninMasks[op] = mask;
123  worklist.push_back({op, 0});
124  }
125 
126  // Establish a post-order among the operations.
127  DenseSet<Operation *> seen;
128  DenseSet<Operation *> finished;
129  postOrder.clear();
130  while (!worklist.empty()) {
131  auto &[op, operandIdx] = worklist.back();
132  if (operandIdx == op->getNumOperands()) {
133  if (!isArcBreakingOp(op) && !isa<hw::OutputOp>(op))
134  postOrder.push_back(op);
135  finished.insert(op);
136  seen.erase(op);
137  worklist.pop_back();
138  continue;
139  }
140  auto operand = op->getOperand(operandIdx++); // advance to next operand
141  auto *definingOp = operand.getDefiningOp();
142  if (!definingOp || isArcBreakingOp(definingOp) ||
143  finished.contains(definingOp))
144  continue;
145  if (!seen.insert(definingOp).second) {
146  definingOp->emitError("combinational loop detected");
147  return failure();
148  }
149  worklist.push_back({definingOp, 0});
150  }
151  LLVM_DEBUG(llvm::dbgs() << "- Sorted " << postOrder.size() << " ops\n");
152 
153  // Compute fanin masks in reverse post-order, which will compute the mask
154  // for an operation's uses before it computes it for the operation itself.
155  // This allows us to compute the set of arc breakers an operation
156  // contributes to in one pass.
157  for (auto *op : llvm::reverse(postOrder)) {
158  auto mask = APInt::getZero(arcBreakers.size());
159  for (auto *user : op->getUsers()) {
160  auto it = faninMasks.find(user);
161  if (it != faninMasks.end())
162  mask |= it->second;
163  }
164 
165  auto duplicateOp = faninMasks.insert({op, mask});
166  (void)duplicateOp;
167  assert(duplicateOp.second && "duplicate op in order");
168  }
169 
170  // Group the operations by their fan-in mask.
171  faninMaskGroups.clear();
172  for (auto [op, mask] : faninMasks)
173  if (!isArcBreakingOp(op) && !isa<hw::OutputOp>(op))
174  faninMaskGroups[mask].insert(op);
175  LLVM_DEBUG(llvm::dbgs() << "- Found " << faninMaskGroups.size()
176  << " fanin mask groups\n");
177 
178  return success();
179 }
180 
181 void Converter::extractArcs(HWModuleOp module) {
182  DenseMap<Value, Value> valueMapping;
183  SmallVector<Value> inputs;
184  SmallVector<Value> outputs;
185  SmallVector<Type> inputTypes;
186  SmallVector<Type> outputTypes;
187  SmallVector<std::pair<OpOperand *, unsigned>> externalUses;
188 
189  arcUses.clear();
190  for (auto &group : faninMaskGroups) {
191  auto &opSet = group.second;
192  OpBuilder builder(module);
193 
194  auto block = std::make_unique<Block>();
195  builder.setInsertionPointToStart(block.get());
196  valueMapping.clear();
197  inputs.clear();
198  outputs.clear();
199  inputTypes.clear();
200  outputTypes.clear();
201  externalUses.clear();
202 
203  Operation *lastOp = nullptr;
204  // TODO: Remove the elements from the post order as we go.
205  for (auto *op : postOrder) {
206  if (!opSet.contains(op))
207  continue;
208  lastOp = op;
209  op->remove();
210  builder.insert(op);
211  for (auto &operand : op->getOpOperands()) {
212  if (opSet.contains(operand.get().getDefiningOp()))
213  continue;
214  auto &mapped = valueMapping[operand.get()];
215  if (!mapped) {
216  mapped = block->addArgument(operand.get().getType(),
217  operand.get().getLoc());
218  inputs.push_back(operand.get());
219  inputTypes.push_back(mapped.getType());
220  }
221  operand.set(mapped);
222  }
223  for (auto result : op->getResults()) {
224  bool anyExternal = false;
225  for (auto &use : result.getUses()) {
226  if (!opSet.contains(use.getOwner())) {
227  anyExternal = true;
228  externalUses.push_back({&use, outputs.size()});
229  }
230  }
231  if (anyExternal) {
232  outputs.push_back(result);
233  outputTypes.push_back(result.getType());
234  }
235  }
236  }
237  assert(lastOp);
238  builder.create<arc::OutputOp>(lastOp->getLoc(), outputs);
239 
240  // Create the arc definition.
241  builder.setInsertionPoint(module);
242  auto defOp = builder.create<DefineOp>(
243  lastOp->getLoc(),
244  builder.getStringAttr(
245  globalNamespace.newName(module.getModuleName() + "_arc")),
246  builder.getFunctionType(inputTypes, outputTypes));
247  defOp.getBody().push_back(block.release());
248 
249  // Create the call to the arc definition to replace the operations that
250  // we have just extracted.
251  builder.setInsertionPoint(module.getBodyBlock()->getTerminator());
252  auto arcOp = builder.create<CallOp>(lastOp->getLoc(), defOp, inputs);
253  arcUses.push_back(arcOp);
254  for (auto [use, resultIdx] : externalUses)
255  use->set(arcOp.getResult(resultIdx));
256  }
257 }
258 
259 LogicalResult Converter::absorbRegs(HWModuleOp module) {
260  // Handle the trivial cases where all of an arc's results are used by
261  // exactly one register each.
262  unsigned outIdx = 0;
263  unsigned numTrivialRegs = 0;
264  for (auto callOp : arcUses) {
265  auto stateOp = dyn_cast<StateOp>(callOp.getOperation());
266  Value clock = stateOp ? stateOp.getClock() : Value{};
267  Value reset;
268  SmallVector<seq::CompRegOp> absorbedRegs;
269  SmallVector<Attribute> absorbedNames(callOp->getNumResults(), {});
270  if (auto names = callOp->getAttrOfType<ArrayAttr>("names"))
271  absorbedNames.assign(names.getValue().begin(), names.getValue().end());
272 
273  // Go through all every arc result and collect the single register that uses
274  // it. If a result has multiple uses or is used by something other than a
275  // register, skip the arc for now and handle it later.
276  bool isTrivial = true;
277  for (auto result : callOp->getResults()) {
278  if (!result.hasOneUse()) {
279  isTrivial = false;
280  break;
281  }
282  auto regOp = dyn_cast<seq::CompRegOp>(result.use_begin()->getOwner());
283  if (!regOp || regOp.getInput() != result ||
284  (clock && clock != regOp.getClk())) {
285  isTrivial = false;
286  break;
287  }
288 
289  clock = regOp.getClk();
290  reset = regOp.getReset();
291 
292  // Check that if the register has a reset, it is to a constant zero
293  if (reset) {
294  Value resetValue = regOp.getResetValue();
295  Operation *op = resetValue.getDefiningOp();
296  if (!op)
297  return regOp->emitOpError(
298  "is reset by an input; not supported by ConvertToArcs");
299  if (auto constant = dyn_cast<hw::ConstantOp>(op)) {
300  if (constant.getValue() != 0)
301  return regOp->emitOpError("is reset to a constant non-zero value; "
302  "not supported by ConvertToArcs");
303  } else {
304  return regOp->emitOpError("is reset to a value that is not clearly "
305  "constant; not supported by ConvertToArcs");
306  }
307  }
308 
309  absorbedRegs.push_back(regOp);
310  // If we absorb a register into the arc, the arc effectively produces that
311  // register's value. So if the register had a name, ensure that we assign
312  // that name to the arc's output.
313  absorbedNames[result.getResultNumber()] = regOp.getNameAttr();
314  }
315 
316  // If this wasn't a trivial case keep the arc around for a second iteration.
317  if (!isTrivial) {
318  arcUses[outIdx++] = callOp;
319  continue;
320  }
321  ++numTrivialRegs;
322 
323  // Set the arc's clock to the clock of the registers we've absorbed, bump
324  // the latency up by one to account for the registers, add the reset if
325  // present and update the output names. Then replace the registers.
326 
327  auto arc = dyn_cast<StateOp>(callOp.getOperation());
328  if (arc) {
329  arc.getClockMutable().assign(clock);
330  arc.setLatency(arc.getLatency() + 1);
331  } else {
332  mlir::IRRewriter rewriter(module->getContext());
333  rewriter.setInsertionPoint(callOp);
334  arc = rewriter.replaceOpWithNewOp<StateOp>(
335  callOp.getOperation(),
336  callOp.getCallableForCallee().get<SymbolRefAttr>(),
337  callOp->getResultTypes(), clock, Value{}, 1, callOp.getArgOperands());
338  }
339 
340  if (reset) {
341  if (arc.getReset())
342  return arc.emitError(
343  "StateOp tried to infer reset from CompReg, but already "
344  "had a reset.");
345  arc.getResetMutable().assign(reset);
346  }
347  if (tapRegisters && llvm::any_of(absorbedNames, [](auto name) {
348  return !name.template cast<StringAttr>().getValue().empty();
349  }))
350  arc->setAttr("names", ArrayAttr::get(module.getContext(), absorbedNames));
351  for (auto [arcResult, reg] : llvm::zip(arc.getResults(), absorbedRegs)) {
352  auto it = arcBreakerIndices.find(reg);
353  arcBreakers[it->second] = {};
354  arcBreakerIndices.erase(it);
355  reg.replaceAllUsesWith(arcResult);
356  reg.erase();
357  }
358  }
359  if (numTrivialRegs > 0)
360  LLVM_DEBUG(llvm::dbgs() << "- Trivially converted " << numTrivialRegs
361  << " regs to arcs\n");
362  arcUses.truncate(outIdx);
363 
364  // Group the remaining registers by their clock, their reset and the operation
365  // they use as input. This will allow us to generally collapse registers
366  // derived from the same arc into one shuffling arc.
367  MapVector<std::tuple<Value, Value, Operation *>, SmallVector<seq::CompRegOp>>
368  regsByInput;
369  for (auto *op : arcBreakers)
370  if (auto regOp = dyn_cast_or_null<seq::CompRegOp>(op)) {
371  regsByInput[{regOp.getClk(), regOp.getReset(),
372  regOp.getInput().getDefiningOp()}]
373  .push_back(regOp);
374  }
375 
376  unsigned numMappedRegs = 0;
377  for (auto [clockAndResetAndOp, regOps] : regsByInput) {
378  numMappedRegs += regOps.size();
379  OpBuilder builder(module);
380  auto block = std::make_unique<Block>();
381  builder.setInsertionPointToStart(block.get());
382 
383  SmallVector<Value> inputs;
384  SmallVector<Value> outputs;
385  SmallVector<Attribute> names;
386  SmallVector<Type> types;
388  SmallVector<unsigned> regToOutputMapping;
389  for (auto regOp : regOps) {
390  auto it = mapping.find(regOp.getInput());
391  if (it == mapping.end()) {
392  it = mapping.insert({regOp.getInput(), inputs.size()}).first;
393  inputs.push_back(regOp.getInput());
394  types.push_back(regOp.getType());
395  outputs.push_back(block->addArgument(regOp.getType(), regOp.getLoc()));
396  names.push_back(regOp->getAttrOfType<StringAttr>("name"));
397  }
398  regToOutputMapping.push_back(it->second);
399  }
400 
401  auto loc = regOps.back().getLoc();
402  builder.create<arc::OutputOp>(loc, outputs);
403 
404  builder.setInsertionPoint(module);
405  auto defOp =
406  builder.create<DefineOp>(loc,
407  builder.getStringAttr(globalNamespace.newName(
408  module.getModuleName() + "_arc")),
409  builder.getFunctionType(types, types));
410  defOp.getBody().push_back(block.release());
411 
412  builder.setInsertionPoint(module.getBodyBlock()->getTerminator());
413  auto arcOp =
414  builder.create<StateOp>(loc, defOp, std::get<0>(clockAndResetAndOp),
415  /*enable=*/Value{}, 1, inputs);
416  auto reset = std::get<1>(clockAndResetAndOp);
417  if (reset)
418  arcOp.getResetMutable().assign(reset);
419  if (tapRegisters && llvm::any_of(names, [](auto name) {
420  return !name.template cast<StringAttr>().getValue().empty();
421  }))
422  arcOp->setAttr("names", builder.getArrayAttr(names));
423  for (auto [reg, resultIdx] : llvm::zip(regOps, regToOutputMapping)) {
424  reg.replaceAllUsesWith(arcOp.getResult(resultIdx));
425  reg.erase();
426  }
427  }
428 
429  if (numMappedRegs > 0)
430  LLVM_DEBUG(llvm::dbgs() << "- Mapped " << numMappedRegs << " regs to "
431  << regsByInput.size() << " shuffling arcs\n");
432 
433  return success();
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // Pass Infrastructure
438 //===----------------------------------------------------------------------===//
439 
440 namespace circt {
441 #define GEN_PASS_DEF_CONVERTTOARCS
442 #include "circt/Conversion/Passes.h.inc"
443 } // namespace circt
444 
445 namespace {
446 struct ConvertToArcsPass : public impl::ConvertToArcsBase<ConvertToArcsPass> {
447  using ConvertToArcsBase::ConvertToArcsBase;
448 
449  void runOnOperation() override {
450  Converter converter;
451  converter.tapRegisters = tapRegisters;
452  if (failed(converter.run(getOperation())))
453  signalPassFailure();
454  }
455 };
456 } // namespace
457 
458 std::unique_ptr<OperationPass<ModuleOp>>
459 circt::createConvertToArcsPass(const ConvertToArcsOptions &options) {
460  return std::make_unique<ConvertToArcsPass>(options);
461 }
assert(baseType &&"element must be base type")
static bool isArcBreakingOp(Operation *op)
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:29
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
std::unique_ptr< OperationPass< ModuleOp > > createConvertToArcsPass(const ConvertToArcsOptions &options={})
Definition: hw.py:1
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:20