CIRCT  20.0.0git
ConvertToArcs.cpp
Go to the documentation of this file.
1 //===- ConvertToArcs.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "circt/Dialect/HW/HWOps.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17 #include "llvm/Support/Debug.h"
18 
19 #define DEBUG_TYPE "convert-to-arcs"
20 
21 using namespace circt;
22 using namespace arc;
23 using namespace hw;
24 using llvm::MapVector;
25 
26 static bool isArcBreakingOp(Operation *op) {
27  return op->hasTrait<OpTrait::ConstantLike>() ||
28  isa<hw::InstanceOp, seq::CompRegOp, MemoryOp, MemoryReadPortOp,
29  ClockedOpInterface, seq::InitialOp, seq::ClockGateOp,
30  sim::DPICallOp>(op) ||
31  op->getNumResults() > 1;
32 }
33 
34 static LogicalResult convertInitialValue(seq::CompRegOp reg,
35  SmallVectorImpl<Value> &values) {
36  if (!reg.getInitialValue())
37  return values.push_back({}), success();
38 
39  // Use from_immutable cast to convert the seq.immutable type to the reg's
40  // type.
41  OpBuilder builder(reg);
42  auto init = builder.create<seq::FromImmutableOp>(reg.getLoc(), reg.getType(),
43  reg.getInitialValue());
44 
45  values.push_back(init);
46  return success();
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // Conversion
51 //===----------------------------------------------------------------------===//
52 
53 namespace {
54 struct Converter {
55  LogicalResult run(ModuleOp module);
56  LogicalResult runOnModule(HWModuleOp module);
57  LogicalResult analyzeFanIn();
58  void extractArcs(HWModuleOp module);
59  LogicalResult absorbRegs(HWModuleOp module);
60 
61  /// The global namespace used to create unique definition names.
62  Namespace globalNamespace;
63 
64  /// All arc-breaking operations in the current module.
65  SmallVector<Operation *> arcBreakers;
66  SmallDenseMap<Operation *, unsigned> arcBreakerIndices;
67 
68  /// A post-order traversal of the operations in the current module.
69  SmallVector<Operation *> postOrder;
70 
71  /// The set of arc-breaking ops an operation in the current module
72  /// contributes to, represented as a bit mask.
73  MapVector<Operation *, APInt> faninMasks;
74 
75  /// The sets of operations that contribute to the same arc-breaking ops.
76  MapVector<APInt, DenseSet<Operation *>> faninMaskGroups;
77 
78  /// The arc uses generated by `extractArcs`.
79  SmallVector<mlir::CallOpInterface> arcUses;
80 
81  /// Whether registers should be made observable by assigning their arcs a
82  /// "name" attribute.
83  bool tapRegisters;
84 };
85 } // namespace
86 
87 LogicalResult Converter::run(ModuleOp module) {
88  for (auto &op : module.getOps())
89  if (auto sym =
90  op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
91  globalNamespace.newName(sym.getValue());
92  for (auto module : module.getOps<HWModuleOp>())
93  if (failed(runOnModule(module)))
94  return failure();
95  return success();
96 }
97 
98 LogicalResult Converter::runOnModule(HWModuleOp module) {
99  // Find all arc-breaking operations in this module and assign them an index.
100  arcBreakers.clear();
101  arcBreakerIndices.clear();
102  for (Operation &op : *module.getBodyBlock()) {
103  if (isa<seq::InitialOp>(&op))
104  continue;
105  if (op.getNumRegions() > 0)
106  return op.emitOpError("has regions; not supported by ConvertToArcs");
107  if (!isArcBreakingOp(&op) && !isa<hw::OutputOp>(&op))
108  continue;
109  arcBreakerIndices[&op] = arcBreakers.size();
110  arcBreakers.push_back(&op);
111  }
112  // Skip modules with only `OutputOp`.
113  if (module.getBodyBlock()->without_terminator().empty() &&
114  isa<hw::OutputOp>(module.getBodyBlock()->getTerminator()))
115  return success();
116  LLVM_DEBUG(llvm::dbgs() << "Analyzing " << module.getModuleNameAttr() << " ("
117  << arcBreakers.size() << " breakers)\n");
118 
119  // For each operation, figure out the set of breaker ops it contributes to,
120  // in the form of a bit mask. Then group operations together that contribute
121  // to the same set of breaker ops.
122  if (failed(analyzeFanIn()))
123  return failure();
124 
125  // Extract the fanin mask groups into separate combinational arcs and
126  // combine them with the registers in the design.
127  extractArcs(module);
128  if (failed(absorbRegs(module)))
129  return failure();
130 
131  return success();
132 }
133 
134 LogicalResult Converter::analyzeFanIn() {
135  SmallVector<std::tuple<Operation *, unsigned>> worklist;
136 
137  // Seed the worklist and fanin masks with the arc breaking operations.
138  faninMasks.clear();
139  for (auto *op : arcBreakers) {
140  unsigned index = arcBreakerIndices.lookup(op);
141  auto mask = APInt::getOneBitSet(arcBreakers.size(), index);
142  faninMasks[op] = mask;
143  worklist.push_back({op, 0});
144  }
145 
146  // Establish a post-order among the operations.
147  DenseSet<Operation *> seen;
148  DenseSet<Operation *> finished;
149  postOrder.clear();
150  while (!worklist.empty()) {
151  auto &[op, operandIdx] = worklist.back();
152  if (operandIdx == op->getNumOperands()) {
153  if (!isArcBreakingOp(op) && !isa<hw::OutputOp>(op))
154  postOrder.push_back(op);
155  finished.insert(op);
156  seen.erase(op);
157  worklist.pop_back();
158  continue;
159  }
160  auto operand = op->getOperand(operandIdx++); // advance to next operand
161  auto *definingOp = operand.getDefiningOp();
162  if (!definingOp || isArcBreakingOp(definingOp) ||
163  finished.contains(definingOp))
164  continue;
165  if (!seen.insert(definingOp).second) {
166  definingOp->emitError("combinational loop detected");
167  return failure();
168  }
169  worklist.push_back({definingOp, 0});
170  }
171  LLVM_DEBUG(llvm::dbgs() << "- Sorted " << postOrder.size() << " ops\n");
172 
173  // Compute fanin masks in reverse post-order, which will compute the mask
174  // for an operation's uses before it computes it for the operation itself.
175  // This allows us to compute the set of arc breakers an operation
176  // contributes to in one pass.
177  for (auto *op : llvm::reverse(postOrder)) {
178  auto mask = APInt::getZero(arcBreakers.size());
179  for (auto *user : op->getUsers()) {
180  auto it = faninMasks.find(user);
181  if (it != faninMasks.end())
182  mask |= it->second;
183  }
184 
185  auto duplicateOp = faninMasks.insert({op, mask});
186  (void)duplicateOp;
187  assert(duplicateOp.second && "duplicate op in order");
188  }
189 
190  // Group the operations by their fan-in mask.
191  faninMaskGroups.clear();
192  for (auto [op, mask] : faninMasks)
193  if (!isArcBreakingOp(op) && !isa<hw::OutputOp>(op))
194  faninMaskGroups[mask].insert(op);
195  LLVM_DEBUG(llvm::dbgs() << "- Found " << faninMaskGroups.size()
196  << " fanin mask groups\n");
197 
198  return success();
199 }
200 
201 void Converter::extractArcs(HWModuleOp module) {
202  DenseMap<Value, Value> valueMapping;
203  SmallVector<Value> inputs;
204  SmallVector<Value> outputs;
205  SmallVector<Type> inputTypes;
206  SmallVector<Type> outputTypes;
207  SmallVector<std::pair<OpOperand *, unsigned>> externalUses;
208 
209  arcUses.clear();
210  for (auto &group : faninMaskGroups) {
211  auto &opSet = group.second;
212  OpBuilder builder(module);
213 
214  auto block = std::make_unique<Block>();
215  builder.setInsertionPointToStart(block.get());
216  valueMapping.clear();
217  inputs.clear();
218  outputs.clear();
219  inputTypes.clear();
220  outputTypes.clear();
221  externalUses.clear();
222 
223  Operation *lastOp = nullptr;
224  // TODO: Remove the elements from the post order as we go.
225  for (auto *op : postOrder) {
226  if (!opSet.contains(op))
227  continue;
228  lastOp = op;
229  op->remove();
230  builder.insert(op);
231  for (auto &operand : op->getOpOperands()) {
232  if (opSet.contains(operand.get().getDefiningOp()))
233  continue;
234  auto &mapped = valueMapping[operand.get()];
235  if (!mapped) {
236  mapped = block->addArgument(operand.get().getType(),
237  operand.get().getLoc());
238  inputs.push_back(operand.get());
239  inputTypes.push_back(mapped.getType());
240  }
241  operand.set(mapped);
242  }
243  for (auto result : op->getResults()) {
244  bool anyExternal = false;
245  for (auto &use : result.getUses()) {
246  if (!opSet.contains(use.getOwner())) {
247  anyExternal = true;
248  externalUses.push_back({&use, outputs.size()});
249  }
250  }
251  if (anyExternal) {
252  outputs.push_back(result);
253  outputTypes.push_back(result.getType());
254  }
255  }
256  }
257  assert(lastOp);
258  builder.create<arc::OutputOp>(lastOp->getLoc(), outputs);
259 
260  // Create the arc definition.
261  builder.setInsertionPoint(module);
262  auto defOp = builder.create<DefineOp>(
263  lastOp->getLoc(),
264  builder.getStringAttr(
265  globalNamespace.newName(module.getModuleName() + "_arc")),
266  builder.getFunctionType(inputTypes, outputTypes));
267  defOp.getBody().push_back(block.release());
268 
269  // Create the call to the arc definition to replace the operations that
270  // we have just extracted.
271  builder.setInsertionPoint(module.getBodyBlock()->getTerminator());
272  auto arcOp = builder.create<CallOp>(lastOp->getLoc(), defOp, inputs);
273  arcUses.push_back(arcOp);
274  for (auto [use, resultIdx] : externalUses)
275  use->set(arcOp.getResult(resultIdx));
276  }
277 }
278 
279 LogicalResult Converter::absorbRegs(HWModuleOp module) {
280  // Handle the trivial cases where all of an arc's results are used by
281  // exactly one register each.
282  unsigned outIdx = 0;
283  unsigned numTrivialRegs = 0;
284  for (auto callOp : arcUses) {
285  auto stateOp = dyn_cast<StateOp>(callOp.getOperation());
286  Value clock = stateOp ? stateOp.getClock() : Value{};
287  Value reset;
288  SmallVector<Value> initialValues;
289  SmallVector<seq::CompRegOp> absorbedRegs;
290  SmallVector<Attribute> absorbedNames(callOp->getNumResults(), {});
291  if (auto names = callOp->getAttrOfType<ArrayAttr>("names"))
292  absorbedNames.assign(names.getValue().begin(), names.getValue().end());
293 
294  // Go through all every arc result and collect the single register that uses
295  // it. If a result has multiple uses or is used by something other than a
296  // register, skip the arc for now and handle it later.
297  bool isTrivial = true;
298  for (auto result : callOp->getResults()) {
299  if (!result.hasOneUse()) {
300  isTrivial = false;
301  break;
302  }
303  auto regOp = dyn_cast<seq::CompRegOp>(result.use_begin()->getOwner());
304  if (!regOp || regOp.getInput() != result ||
305  (clock && clock != regOp.getClk())) {
306  isTrivial = false;
307  break;
308  }
309 
310  clock = regOp.getClk();
311  reset = regOp.getReset();
312 
313  // Check that if the register has a reset, it is to a constant zero
314  if (reset) {
315  Value resetValue = regOp.getResetValue();
316  Operation *op = resetValue.getDefiningOp();
317  if (!op)
318  return regOp->emitOpError(
319  "is reset by an input; not supported by ConvertToArcs");
320  if (auto constant = dyn_cast<hw::ConstantOp>(op)) {
321  if (constant.getValue() != 0)
322  return regOp->emitOpError("is reset to a constant non-zero value; "
323  "not supported by ConvertToArcs");
324  } else {
325  return regOp->emitOpError("is reset to a value that is not clearly "
326  "constant; not supported by ConvertToArcs");
327  }
328  }
329 
330  if (failed(convertInitialValue(regOp, initialValues)))
331  return failure();
332 
333  absorbedRegs.push_back(regOp);
334  // If we absorb a register into the arc, the arc effectively produces that
335  // register's value. So if the register had a name, ensure that we assign
336  // that name to the arc's output.
337  absorbedNames[result.getResultNumber()] = regOp.getNameAttr();
338  }
339 
340  // If this wasn't a trivial case keep the arc around for a second iteration.
341  if (!isTrivial) {
342  arcUses[outIdx++] = callOp;
343  continue;
344  }
345  ++numTrivialRegs;
346 
347  // Set the arc's clock to the clock of the registers we've absorbed, bump
348  // the latency up by one to account for the registers, add the reset if
349  // present and update the output names. Then replace the registers.
350 
351  auto arc = dyn_cast<StateOp>(callOp.getOperation());
352  if (arc) {
353  arc.getClockMutable().assign(clock);
354  arc.setLatency(arc.getLatency() + 1);
355  } else {
356  mlir::IRRewriter rewriter(module->getContext());
357  rewriter.setInsertionPoint(callOp);
358  arc = rewriter.replaceOpWithNewOp<StateOp>(
359  callOp.getOperation(),
360  callOp.getCallableForCallee().get<SymbolRefAttr>(),
361  callOp->getResultTypes(), clock, Value{}, 1, callOp.getArgOperands());
362  }
363 
364  if (reset) {
365  if (arc.getReset())
366  return arc.emitError(
367  "StateOp tried to infer reset from CompReg, but already "
368  "had a reset.");
369  arc.getResetMutable().assign(reset);
370  }
371 
372  bool onlyDefaultInitializers =
373  llvm::all_of(initialValues, [](auto val) -> bool { return !val; });
374 
375  if (!onlyDefaultInitializers) {
376  if (!arc.getInitials().empty()) {
377  return arc.emitError(
378  "StateOp tried to infer initial values from CompReg, but already "
379  "had an initial value.");
380  }
381  // Create 0 constants for default initialization
382  for (unsigned i = 0; i < initialValues.size(); ++i) {
383  if (!initialValues[i]) {
384  OpBuilder zeroBuilder(arc);
385  initialValues[i] = zeroBuilder.createOrFold<hw::ConstantOp>(
386  arc.getLoc(),
387  zeroBuilder.getIntegerAttr(arc.getResult(i).getType(), 0));
388  }
389  }
390  arc.getInitialsMutable().assign(initialValues);
391  }
392 
393  if (tapRegisters && llvm::any_of(absorbedNames, [](auto name) {
394  return !cast<StringAttr>(name).getValue().empty();
395  }))
396  arc->setAttr("names", ArrayAttr::get(module.getContext(), absorbedNames));
397  for (auto [arcResult, reg] : llvm::zip(arc.getResults(), absorbedRegs)) {
398  auto it = arcBreakerIndices.find(reg);
399  arcBreakers[it->second] = {};
400  arcBreakerIndices.erase(it);
401  reg.replaceAllUsesWith(arcResult);
402  reg.erase();
403  }
404  }
405  if (numTrivialRegs > 0)
406  LLVM_DEBUG(llvm::dbgs() << "- Trivially converted " << numTrivialRegs
407  << " regs to arcs\n");
408  arcUses.truncate(outIdx);
409 
410  // Group the remaining registers by their clock, their reset and the operation
411  // they use as input. This will allow us to generally collapse registers
412  // derived from the same arc into one shuffling arc.
413  MapVector<std::tuple<Value, Value, Operation *>, SmallVector<seq::CompRegOp>>
414  regsByInput;
415  for (auto *op : arcBreakers)
416  if (auto regOp = dyn_cast_or_null<seq::CompRegOp>(op)) {
417  regsByInput[{regOp.getClk(), regOp.getReset(),
418  regOp.getInput().getDefiningOp()}]
419  .push_back(regOp);
420  }
421 
422  unsigned numMappedRegs = 0;
423  for (auto [clockAndResetAndOp, regOps] : regsByInput) {
424  numMappedRegs += regOps.size();
425  OpBuilder builder(module);
426  auto block = std::make_unique<Block>();
427  builder.setInsertionPointToStart(block.get());
428 
429  SmallVector<Value> inputs;
430  SmallVector<Value> outputs;
431  SmallVector<Attribute> names;
432  SmallVector<Type> types;
433  SmallVector<Value> initialValues;
435  SmallVector<unsigned> regToOutputMapping;
436  for (auto regOp : regOps) {
437  auto it = mapping.find(regOp.getInput());
438  if (it == mapping.end()) {
439  it = mapping.insert({regOp.getInput(), inputs.size()}).first;
440  inputs.push_back(regOp.getInput());
441  types.push_back(regOp.getType());
442  outputs.push_back(block->addArgument(regOp.getType(), regOp.getLoc()));
443  names.push_back(regOp->getAttrOfType<StringAttr>("name"));
444  if (failed(convertInitialValue(regOp, initialValues)))
445  return failure();
446  }
447  regToOutputMapping.push_back(it->second);
448  }
449 
450  auto loc = regOps.back().getLoc();
451  builder.create<arc::OutputOp>(loc, outputs);
452 
453  builder.setInsertionPoint(module);
454  auto defOp =
455  builder.create<DefineOp>(loc,
456  builder.getStringAttr(globalNamespace.newName(
457  module.getModuleName() + "_arc")),
458  builder.getFunctionType(types, types));
459  defOp.getBody().push_back(block.release());
460 
461  builder.setInsertionPoint(module.getBodyBlock()->getTerminator());
462 
463  bool onlyDefaultInitializers =
464  llvm::all_of(initialValues, [](auto val) -> bool { return !val; });
465 
466  if (onlyDefaultInitializers)
467  initialValues.clear();
468  else
469  for (unsigned i = 0; i < initialValues.size(); ++i) {
470  if (!initialValues[i])
471  initialValues[i] = builder.createOrFold<hw::ConstantOp>(
472  loc, builder.getIntegerAttr(types[i], 0));
473  }
474 
475  auto arcOp =
476  builder.create<StateOp>(loc, defOp, std::get<0>(clockAndResetAndOp),
477  /*enable=*/Value{}, 1, inputs, initialValues);
478  auto reset = std::get<1>(clockAndResetAndOp);
479  if (reset)
480  arcOp.getResetMutable().assign(reset);
481  if (tapRegisters && llvm::any_of(names, [](auto name) {
482  return !cast<StringAttr>(name).getValue().empty();
483  }))
484  arcOp->setAttr("names", builder.getArrayAttr(names));
485  for (auto [reg, resultIdx] : llvm::zip(regOps, regToOutputMapping)) {
486  reg.replaceAllUsesWith(arcOp.getResult(resultIdx));
487  reg.erase();
488  }
489  }
490 
491  if (numMappedRegs > 0)
492  LLVM_DEBUG(llvm::dbgs() << "- Mapped " << numMappedRegs << " regs to "
493  << regsByInput.size() << " shuffling arcs\n");
494 
495  return success();
496 }
497 
498 //===----------------------------------------------------------------------===//
499 // Pass Infrastructure
500 //===----------------------------------------------------------------------===//
501 
502 namespace circt {
503 #define GEN_PASS_DEF_CONVERTTOARCS
504 #include "circt/Conversion/Passes.h.inc"
505 } // namespace circt
506 
507 namespace {
508 struct ConvertToArcsPass : public impl::ConvertToArcsBase<ConvertToArcsPass> {
509  using ConvertToArcsBase::ConvertToArcsBase;
510 
511  void runOnOperation() override {
512  Converter converter;
513  converter.tapRegisters = tapRegisters;
514  if (failed(converter.run(getOperation())))
515  signalPassFailure();
516  }
517 };
518 } // namespace
519 
520 std::unique_ptr<OperationPass<ModuleOp>>
521 circt::createConvertToArcsPass(const ConvertToArcsOptions &options) {
522  return std::make_unique<ConvertToArcsPass>(options);
523 }
assert(baseType &&"element must be base type")
static LogicalResult convertInitialValue(seq::CompRegOp reg, SmallVectorImpl< Value > &values)
static bool isArcBreakingOp(Operation *op)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
def create(data_type, value)
Definition: hw.py:393
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
std::unique_ptr< OperationPass< ModuleOp > > createConvertToArcsPass(const ConvertToArcsOptions &options={})
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition: codegen.py:121
Definition: hw.py:1
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:21