CIRCT  18.0.0git
LowerState.cpp
Go to the documentation of this file.
1 //===- LowerState.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "circt/Dialect/HW/HWOps.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/SymbolTable.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/TopologicalSortUtils.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "arc-lower-state"
28 
29 namespace circt {
30 namespace arc {
31 #define GEN_PASS_DEF_LOWERSTATE
32 #include "circt/Dialect/Arc/ArcPasses.h.inc"
33 } // namespace arc
34 } // namespace circt
35 
36 using namespace circt;
37 using namespace arc;
38 using namespace hw;
39 using namespace mlir;
40 using llvm::SmallDenseSet;
41 
42 //===----------------------------------------------------------------------===//
43 // Data Structures
44 //===----------------------------------------------------------------------===//
45 
46 namespace {
47 
48 /// Statistics gathered throughout the execution of this pass.
49 struct Statistics {
50  Pass *parent;
51  Statistics(Pass *parent) : parent(parent) {}
52  using Statistic = Pass::Statistic;
53 
54  Statistic matOpsMoved{parent, "mat-ops-moved",
55  "Ops moved during value materialization"};
56  Statistic matOpsCloned{parent, "mat-ops-cloned",
57  "Ops cloned during value materialization"};
58  Statistic opsPruned{parent, "ops-pruned", "Ops removed as dead code"};
59 };
60 
61 /// Lowering info associated with a single primary clock.
///
/// Bundles the op whose body receives the lowered logic (`treeOp`), a builder
/// positioned inside that body, and caches for values and AND gates that have
/// already been created there.
62 struct ClockLowering {
63  /// The root clock this lowering is for.
64  Value clock;
65  /// A `ClockTreeOp` or `PassThroughOp`.
66  Operation *treeOp;
67  /// Pass statistics.
68  Statistics &stats;
 /// Builder used to emit ops into `treeOp`; the constructor places it at the
 /// start of the first block of `treeOp`'s first region.
69  OpBuilder builder;
70  /// A mapping from values outside the clock tree to their materialized form
71  /// inside the clock tree.
72  IRMapping materializedValues;
73  /// A cache of AND gates created for aggregating enable conditions.
74  DenseMap<std::pair<Value, Value>, Value> andCache;
75 
 /// `treeOp` must be a `ClockTreeOp` or `PassThroughOp` (asserted) and its
 /// first region must already contain a block, since the builder is pointed
 /// at the start of that block.
76  ClockLowering(Value clock, Operation *treeOp, Statistics &stats)
77  : clock(clock), treeOp(treeOp), stats(stats), builder(treeOp) {
78  assert((isa<ClockTreeOp, PassThroughOp>(treeOp)));
79  builder.setInsertionPointToStart(&treeOp->getRegion(0).front());
80  }
81 
82  Value materializeValue(Value value);
83  Value getOrCreateAnd(Value lhs, Value rhs, Location loc);
84 };
85 
/// A primary clock lowering paired with an optional enable condition.
/// Instances are built by aggregate initialization, so the member order is
/// part of the interface.
86 struct GatedClockLowering {
87  /// Lowering info of the primary clock.
88  ClockLowering &clock;
89  /// An optional enable condition of the primary clock. May be null.
90  Value enable;
91 };
92 
93 /// State lowering for a single `HWModuleOp`.
94 struct ModuleLowering {
 /// The module being lowered.
95  HWModuleOp moduleOp;
96  /// Pass statistics.
97  Statistics &stats;
 /// The MLIR context of `moduleOp`.
98  MLIRContext *context;
 /// Lowerings keyed by ungated clock value; the null `Value` key holds the
 /// pass-through lowering.
99  DenseMap<Value, std::unique_ptr<ClockLowering>> clockLowerings;
 /// Cache of lowerings for clocks produced by `ClockGateOp`s.
100  DenseMap<Value, GatedClockLowering> gatedClockLowerings;
 /// The storage block argument added by `addStorageArg`; null until then.
101  Value storageArg;
 /// Builder used for clock edge detection and the clock tree / pass-through
 /// ops themselves.
102  OpBuilder clockBuilder;
 /// Builder used for state, memory, and root input/output allocations.
103  OpBuilder stateBuilder;
104 
105  ModuleLowering(HWModuleOp moduleOp, Statistics &stats)
106  : moduleOp(moduleOp), stats(stats), context(moduleOp.getContext()),
107  clockBuilder(moduleOp), stateBuilder(moduleOp) {}
108 
109  GatedClockLowering getOrCreateClockLowering(Value clock);
110  ClockLowering &getOrCreatePassThrough();
111  Value replaceValueWithStateRead(Value value, Value state);
112 
113  void addStorageArg();
114  LogicalResult lowerPrimaryInputs();
115  LogicalResult lowerPrimaryOutputs();
116  LogicalResult lowerStates();
117  LogicalResult lowerState(StateOp stateOp);
118  LogicalResult lowerState(MemoryOp memOp);
119  LogicalResult lowerState(MemoryWritePortOp memWriteOp);
120  LogicalResult lowerState(TapOp tapOp);
121  LogicalResult lowerExtModules(SymbolTable &symtbl);
122  LogicalResult lowerExtModule(InstanceOp instOp);
123 
124  LogicalResult cleanup();
125 };
126 } // namespace
127 
128 //===----------------------------------------------------------------------===//
129 // Clock Lowering
130 //===----------------------------------------------------------------------===//
131 
132 static bool shouldMaterialize(Operation *op) {
133  // Don't materialize arc uses with latency >0, since we handle these in a
134  // second pass once all other operations have been moved to their respective
135  // clock trees.
136  if (auto stateOp = dyn_cast<StateOp>(op); stateOp && stateOp.getLatency() > 0)
137  return false;
138 
139  if (isa<MemoryOp, AllocStateOp, AllocMemoryOp, AllocStorageOp, ClockTreeOp,
140  PassThroughOp, RootInputOp, RootOutputOp, StateWriteOp,
141  MemoryWritePortOp, igraph::InstanceOpInterface>(op))
142  return false;
143 
144  return true;
145 }
146 
147 static bool shouldMaterialize(Value value) {
148  assert(value);
149 
150  // Block arguments are just used as they are.
151  auto *op = value.getDefiningOp();
152  if (!op)
153  return false;
154 
155  return shouldMaterialize(op);
156 }
157 
158 /// Materialize a value within this clock tree. This clones or moves all
159 /// operations required to produce this value inside the clock tree.
160 Value ClockLowering::materializeValue(Value value) {
161  if (!value)
162  return {};
163  if (auto mapped = materializedValues.lookupOrNull(value))
164  return mapped;
165  if (!shouldMaterialize(value))
166  return value;
167 
168  struct WorkItem {
169  Operation *op;
170  SmallVector<Value, 2> operands;
171  WorkItem(Operation *op) : op(op) {}
172  };
173 
174  SmallPtrSet<Operation *, 8> seen;
175  SmallVector<WorkItem> worklist;
176 
177  auto addToWorklist = [&](Operation *outerOp) {
178  SmallDenseSet<Value> seenOperands;
179  auto &workItem = worklist.emplace_back(outerOp);
180  outerOp->walk([&](Operation *innerOp) {
181  for (auto operand : innerOp->getOperands()) {
182  // Skip operands that are defined within the operation itself.
183  if (!operand.getParentBlock()->getParentOp()->isProperAncestor(outerOp))
184  continue;
185 
186  // Skip operands that we have already seen.
187  if (!seenOperands.insert(operand).second)
188  continue;
189 
190  // Skip operands that we have already materialized or that should not
191  // be materialized at all.
192  if (materializedValues.contains(operand) || !shouldMaterialize(operand))
193  continue;
194 
195  workItem.operands.push_back(operand);
196  }
197  });
198  };
199 
200  seen.insert(value.getDefiningOp());
201  addToWorklist(value.getDefiningOp());
202 
203  while (!worklist.empty()) {
204  auto &workItem = worklist.back();
205  if (!workItem.operands.empty()) {
206  auto operand = workItem.operands.pop_back_val();
207  if (materializedValues.contains(operand) || !shouldMaterialize(operand))
208  continue;
209  auto *defOp = operand.getDefiningOp();
210  if (!seen.insert(defOp).second) {
211  defOp->emitError("combinational loop detected");
212  return {};
213  }
214  addToWorklist(defOp);
215  } else {
216  builder.clone(*workItem.op, materializedValues);
217  seen.erase(workItem.op);
218  worklist.pop_back();
219  }
220  }
221 
222  return materializedValues.lookup(value);
223 }
224 
225 /// Create an AND gate if none with the given operands already exists. Note that
226 /// the operands may be null, in which case the function will return the
227 /// non-null operand, or null if both operands are null.
228 Value ClockLowering::getOrCreateAnd(Value lhs, Value rhs, Location loc) {
229  if (!lhs)
230  return rhs;
231  if (!rhs)
232  return lhs;
233  auto &slot = andCache[std::make_pair(lhs, rhs)];
234  if (!slot)
235  slot = builder.create<comb::AndOp>(loc, lhs, rhs);
236  return slot;
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // Module Lowering
241 //===----------------------------------------------------------------------===//
242 
/// Return the lowering for `clock` together with the enable condition
/// accumulated from all `ClockGateOp`s between `clock` and its ungated source.
/// For an ungated clock the returned enable is null. The first request for an
/// ungated clock creates its rising-edge detection logic and `ClockTreeOp`.
243 GatedClockLowering ModuleLowering::getOrCreateClockLowering(Value clock) {
244  // Look through clock gates.
245  if (auto ckgOp = clock.getDefiningOp<ClockGateOp>()) {
246  // Reuse the existing lowering for this clock gate if possible.
247  if (auto it = gatedClockLowerings.find(clock);
248  it != gatedClockLowerings.end())
249  return it->second;
250 
251  // Get the lowering for the parent clock gate's input clock. This will give
252  // us the clock tree to emit things into, alongside the compound enable
253  // condition of all the clock gates along the way to the primary clock. All
254  // we have to do is to add this clock gate's condition to that list.
255  auto info = getOrCreateClockLowering(ckgOp.getInput());
256  auto ckgEnable = info.clock.materializeValue(ckgOp.getEnable());
257  info.enable =
258  info.clock.getOrCreateAnd(info.enable, ckgEnable, ckgOp.getLoc());
259  gatedClockLowerings.insert({clock, info});
260  return info;
261  }
262 
263  // Create the `ClockTreeOp` that corresponds to this ungated clock.
264  auto &slot = clockLowerings[clock];
265  if (!slot) {
266  auto newClock =
267  clockBuilder.createOrFold<seq::FromClockOp>(clock.getLoc(), clock);
268 
269  // Detect a rising edge on the clock, as `(old != new) & new`.
 // The previous clock value is kept in a dedicated i1 state: it is read
 // first and then overwritten with the current clock value.
270  auto oldClockStorage = stateBuilder.create<AllocStateOp>(
271  clock.getLoc(), StateType::get(stateBuilder.getI1Type()), storageArg);
272  auto oldClock =
273  clockBuilder.create<StateReadOp>(clock.getLoc(), oldClockStorage);
274  clockBuilder.create<StateWriteOp>(clock.getLoc(), oldClockStorage, newClock,
275  Value{});
276  Value trigger = clockBuilder.create<comb::ICmpOp>(
277  clock.getLoc(), comb::ICmpPredicate::ne, oldClock, newClock);
278  trigger =
279  clockBuilder.create<comb::AndOp>(clock.getLoc(), trigger, newClock);
280 
281  // Create the tree op.
282  auto treeOp = clockBuilder.create<ClockTreeOp>(clock.getLoc(), trigger);
283  treeOp.getBody().emplaceBlock();
284  slot = std::make_unique<ClockLowering>(clock, treeOp, stats);
285  }
286  return GatedClockLowering{*slot, Value{}};
287 }
288 
289 ClockLowering &ModuleLowering::getOrCreatePassThrough() {
290  auto &slot = clockLowerings[Value{}];
291  if (!slot) {
292  auto treeOp = clockBuilder.create<PassThroughOp>(moduleOp.getLoc());
293  treeOp.getBody().emplaceBlock();
294  slot = std::make_unique<ClockLowering>(Value{}, treeOp, stats);
295  }
296  return *slot;
297 }
298 
299 /// Replace all uses of a value with a `StateReadOp` on a state.
300 Value ModuleLowering::replaceValueWithStateRead(Value value, Value state) {
301  OpBuilder builder(state.getContext());
302  builder.setInsertionPointAfterValue(state);
303  Value readOp = builder.create<StateReadOp>(value.getLoc(), state);
304  if (isa<seq::ClockType>(value.getType()))
305  readOp = builder.createOrFold<seq::ToClockOp>(value.getLoc(), readOp);
306  value.replaceAllUsesWith(readOp);
307  return readOp;
308 }
309 
310 /// Add the global state as an argument to the module's body block.
311 void ModuleLowering::addStorageArg() {
312  assert(!storageArg);
313  storageArg = moduleOp.getBodyBlock()->addArgument(
314  StorageType::get(context, {}), moduleOp.getLoc());
315 }
316 
317 /// Lower the primary inputs of the module to dedicated ops that allocate the
318 /// inputs in the model's storage.
319 LogicalResult ModuleLowering::lowerPrimaryInputs() {
320  for (auto blockArg : moduleOp.getBodyBlock()->getArguments()) {
321  if (blockArg == storageArg)
322  continue;
323  auto name = moduleOp.getArgName(blockArg.getArgNumber());
324  auto argTy = blockArg.getType();
325  IntegerType innerTy;
326  if (argTy.isa<seq::ClockType>()) {
327  innerTy = IntegerType::get(context, 1);
328  } else if (auto intType = argTy.dyn_cast<IntegerType>()) {
329  innerTy = intType;
330  } else {
331  return mlir::emitError(blockArg.getLoc(), "input ")
332  << name << " is of non-integer type " << blockArg.getType();
333  }
334  auto state = stateBuilder.create<RootInputOp>(
335  blockArg.getLoc(), StateType::get(innerTy), name, storageArg);
336  replaceValueWithStateRead(blockArg, state);
337  }
338  return success();
339 }
340 
341 /// Lower the primary outputs of the module to dedicated ops that allocate the
342 /// outputs in the model's storage.
343 LogicalResult ModuleLowering::lowerPrimaryOutputs() {
344  auto outputOp = cast<hw::OutputOp>(moduleOp.getBodyBlock()->getTerminator());
345  if (outputOp.getNumOperands() > 0) {
346  auto outputOperands = SmallVector<Value>(outputOp.getOperands());
347  outputOp->dropAllReferences();
348  auto &passThrough = getOrCreatePassThrough();
349  for (auto [outputArg, name] :
350  llvm::zip(outputOperands, moduleOp.getOutputNames())) {
351  IntegerType innerTy;
352  if (outputArg.getType().isa<seq::ClockType>()) {
353  innerTy = IntegerType::get(context, 1);
354  } else if (auto intType = outputArg.getType().dyn_cast<IntegerType>()) {
355  innerTy = intType;
356  } else {
357  return mlir::emitError(outputOp.getLoc(), "output ")
358  << name << " is of non-integer type " << outputArg.getType();
359  }
360  auto value = passThrough.materializeValue(outputArg);
361  auto state = stateBuilder.create<RootOutputOp>(
362  outputOp.getLoc(), StateType::get(innerTy), name.cast<StringAttr>(),
363  storageArg);
364  if (isa<seq::ClockType>(value.getType()))
365  value = passThrough.builder.createOrFold<seq::FromClockOp>(
366  outputOp.getLoc(), value);
367  passThrough.builder.create<StateWriteOp>(outputOp.getLoc(), state, value,
368  Value{});
369  }
370  }
371  outputOp.erase();
372  return success();
373 }
374 
375 LogicalResult ModuleLowering::lowerStates() {
376  SmallVector<Operation *> opsToLower;
377  for (auto &op : *moduleOp.getBodyBlock()) {
378  auto stateOp = dyn_cast<StateOp>(&op);
379  if ((stateOp && stateOp.getLatency() > 0) ||
380  isa<MemoryOp, MemoryWritePortOp, TapOp>(&op))
381  opsToLower.push_back(&op);
382  }
383 
384  for (auto *op : opsToLower) {
385  LLVM_DEBUG(llvm::dbgs() << "- Lowering " << *op << "\n");
386  auto result = TypeSwitch<Operation *, LogicalResult>(op)
387  .Case<StateOp, MemoryOp, MemoryWritePortOp, TapOp>(
388  [&](auto op) { return lowerState(op); })
389  .Default(success());
390  if (failed(result))
391  return failure();
392  }
393  return success();
394 }
395 
396 LogicalResult ModuleLowering::lowerState(StateOp stateOp) {
397  // Latency zero arcs incur no state and remain in the IR unmodified.
398  if (stateOp.getLatency() == 0)
399  return success();
400 
401  // We don't support arcs beyond latency 1 yet. These should be easy to add in
402  // the future though.
403  if (stateOp.getLatency() > 1)
404  return stateOp.emitError("state with latency > 1 not supported");
405 
406  // Grab all operands from the state op and make it drop all its references.
407  // This allows `materializeValue` to move an operation if this state was the
408  // last user.
409  auto stateClock = stateOp.getClock();
410  auto stateEnable = stateOp.getEnable();
411  auto stateReset = stateOp.getReset();
412  auto stateInputs = SmallVector<Value>(stateOp.getInputs());
413  stateOp->dropAllReferences();
414 
415  // Get the clock tree and enable condition for this state's clock. If this arc
416  // carries an explicit enable condition, fold that into the enable provided by
417  // the clock gates in the arc's clock tree.
418  auto info = getOrCreateClockLowering(stateClock);
419  info.enable = info.clock.getOrCreateAnd(
420  info.enable, info.clock.materializeValue(stateEnable), stateOp.getLoc());
421 
422  // Allocate the necessary state within the model.
423  SmallVector<Value> allocatedStates;
424  for (unsigned stateIdx = 0; stateIdx < stateOp.getNumResults(); ++stateIdx) {
425  auto type = stateOp.getResult(stateIdx).getType();
426  auto intType = dyn_cast<IntegerType>(type);
427  if (!intType)
428  return stateOp.emitOpError("result ")
429  << stateIdx << " has non-integer type " << type
430  << "; only integer types are supported";
431  auto stateType = StateType::get(intType);
432  auto state = stateBuilder.create<AllocStateOp>(stateOp.getLoc(), stateType,
433  storageArg);
434  if (auto names = stateOp->getAttrOfType<ArrayAttr>("names"))
435  state->setAttr("name", names[stateIdx]);
436  allocatedStates.push_back(state);
437  }
438 
439  // Create a copy of the arc use with latency zero. This will effectively be
440  // the computation of the arc's transfer function, while the latency is
441  // implemented through read and write functions.
442  SmallVector<Value> materializedOperands;
443  materializedOperands.reserve(stateInputs.size());
444 
445  for (auto input : stateInputs)
446  materializedOperands.push_back(info.clock.materializeValue(input));
447 
448  OpBuilder nonResetBuilder = info.clock.builder;
449  if (stateReset) {
450  auto materializedReset = info.clock.materializeValue(stateReset);
451  auto ifOp = info.clock.builder.create<scf::IfOp>(stateOp.getLoc(),
452  materializedReset, true);
453 
454  for (auto [alloc, resTy] :
455  llvm::zip(allocatedStates, stateOp.getResultTypes())) {
456  if (!resTy.isa<IntegerType>())
457  stateOp->emitOpError("Non-integer result not supported yet!");
458 
459  auto thenBuilder = ifOp.getThenBodyBuilder();
460  Value constZero =
461  thenBuilder.create<hw::ConstantOp>(stateOp.getLoc(), resTy, 0);
462  thenBuilder.create<StateWriteOp>(stateOp.getLoc(), alloc, constZero,
463  Value());
464  }
465 
466  nonResetBuilder = ifOp.getElseBodyBuilder();
467  }
468 
469  auto newStateOp = nonResetBuilder.create<StateOp>(
470  stateOp.getLoc(), stateOp.getArcAttr(), stateOp.getResultTypes(), Value{},
471  Value{}, 0, materializedOperands);
472 
473  // Create the write ops that write the result of the transfer function to the
474  // allocated state storage.
475  for (auto [alloc, result] :
476  llvm::zip(allocatedStates, newStateOp.getResults()))
477  nonResetBuilder.create<StateWriteOp>(stateOp.getLoc(), alloc, result,
478  info.enable);
479 
480  // Replace all uses of the arc with reads from the allocated state.
481  for (auto [alloc, result] : llvm::zip(allocatedStates, stateOp.getResults()))
482  replaceValueWithStateRead(result, alloc);
483  stateOp.erase();
484  return success();
485 }
486 
487 LogicalResult ModuleLowering::lowerState(MemoryOp memOp) {
488  auto allocMemOp = stateBuilder.create<AllocMemoryOp>(
489  memOp.getLoc(), memOp.getType(), storageArg, memOp->getAttrs());
490  memOp.replaceAllUsesWith(allocMemOp.getResult());
491  memOp.erase();
492  return success();
493 }
494 
495 LogicalResult ModuleLowering::lowerState(MemoryWritePortOp memWriteOp) {
496  if (memWriteOp.getLatency() > 1)
497  return memWriteOp->emitOpError("latencies > 1 not supported yet");
498 
499  // Get the clock tree and enable condition for this write port's clock. If the
500  // port carries an explicit enable condition, fold that into the enable
501  // provided by the clock gates in the port's clock tree.
502  auto info = getOrCreateClockLowering(memWriteOp.getClock());
503 
504  // Grab all operands from the op and make it drop all its references. This
505  // allows `materializeValue` to move an operation if this op was the last
506  // user.
507  auto writeMemory = memWriteOp.getMemory();
508  auto writeInputs = SmallVector<Value>(memWriteOp.getInputs());
509  auto arcResultTypes = memWriteOp.getArcResultTypes();
510  memWriteOp->dropAllReferences();
511 
512  SmallVector<Value> materializedInputs;
513  for (auto input : writeInputs)
514  materializedInputs.push_back(info.clock.materializeValue(input));
515  ValueRange results =
516  info.clock.builder
517  .create<CallOp>(memWriteOp.getLoc(), arcResultTypes,
518  memWriteOp.getArc(), materializedInputs)
519  ->getResults();
520 
521  auto enable =
522  memWriteOp.getEnable() ? results[memWriteOp.getEnableIdx()] : Value();
523  info.enable =
524  info.clock.getOrCreateAnd(info.enable, enable, memWriteOp.getLoc());
525 
526  // Materialize the operands for the write op within the surrounding clock
527  // tree.
528  auto address = results[memWriteOp.getAddressIdx()];
529  auto data = results[memWriteOp.getDataIdx()];
530  if (memWriteOp.getMask()) {
531  Value mask = results[memWriteOp.getMaskIdx(static_cast<bool>(enable))];
532  Value oldData = info.clock.builder.create<arc::MemoryReadOp>(
533  mask.getLoc(), data.getType(), writeMemory, address);
534  Value allOnes = info.clock.builder.create<hw::ConstantOp>(
535  mask.getLoc(), oldData.getType(), -1);
536  Value negatedMask = info.clock.builder.create<comb::XorOp>(
537  mask.getLoc(), mask, allOnes, true);
538  Value maskedOldData = info.clock.builder.create<comb::AndOp>(
539  mask.getLoc(), negatedMask, oldData, true);
540  Value maskedNewData =
541  info.clock.builder.create<comb::AndOp>(mask.getLoc(), mask, data, true);
542  data = info.clock.builder.create<comb::OrOp>(mask.getLoc(), maskedOldData,
543  maskedNewData, true);
544  }
545  info.clock.builder.create<MemoryWriteOp>(memWriteOp.getLoc(), writeMemory,
546  address, info.enable, data);
547  memWriteOp.erase();
548  return success();
549 }
550 
551 // Add state for taps into the passthrough block.
552 LogicalResult ModuleLowering::lowerState(TapOp tapOp) {
553  auto intType = tapOp.getValue().getType().dyn_cast<IntegerType>();
554  if (!intType)
555  return mlir::emitError(tapOp.getLoc(), "tapped value ")
556  << tapOp.getNameAttr() << " is of non-integer type "
557  << tapOp.getValue().getType();
558 
559  // Grab what we need from the tap op and then make it drop all its references.
560  // This will allow `materializeValue` to move ops instead of cloning them.
561  auto tapValue = tapOp.getValue();
562  tapOp->dropAllReferences();
563 
564  auto &passThrough = getOrCreatePassThrough();
565  auto materializedValue = passThrough.materializeValue(tapValue);
566  auto state = stateBuilder.create<AllocStateOp>(
567  tapOp.getLoc(), StateType::get(intType), storageArg, true);
568  state->setAttr("name", tapOp.getNameAttr());
569  passThrough.builder.create<StateWriteOp>(tapOp.getLoc(), state,
570  materializedValue, Value{});
571  tapOp.erase();
572  return success();
573 }
574 
575 /// Lower all instances of external modules to internal inputs/outputs to be
576 /// driven from outside of the design.
577 LogicalResult ModuleLowering::lowerExtModules(SymbolTable &symtbl) {
578  auto instOps = SmallVector<InstanceOp>(moduleOp.getOps<InstanceOp>());
579  for (auto op : instOps)
580  if (isa<HWModuleExternOp>(symtbl.lookup(op.getModuleNameAttr().getAttr())))
581  if (failed(lowerExtModule(op)))
582  return failure();
583  return success();
584 }
585 
586 LogicalResult ModuleLowering::lowerExtModule(InstanceOp instOp) {
587  LLVM_DEBUG(llvm::dbgs() << "- Lowering extmodule "
588  << instOp.getInstanceNameAttr() << "\n");
589 
590  SmallString<32> baseName(instOp.getInstanceName());
591  auto baseNameLen = baseName.size();
592 
593  // Lower the inputs of the extmodule as state that is only written.
594  for (auto [operand, name] :
595  llvm::zip(instOp.getOperands(), instOp.getArgNames())) {
596  LLVM_DEBUG(llvm::dbgs()
597  << " - Input " << name << " : " << operand.getType() << "\n");
598  auto intType = operand.getType().dyn_cast<IntegerType>();
599  if (!intType)
600  return mlir::emitError(operand.getLoc(), "input ")
601  << name << " of extern module " << instOp.getModuleNameAttr()
602  << " instance " << instOp.getInstanceNameAttr()
603  << " is of non-integer type " << operand.getType();
604  baseName.resize(baseNameLen);
605  baseName += '/';
606  baseName += cast<StringAttr>(name).getValue();
607  auto &passThrough = getOrCreatePassThrough();
608  auto state = stateBuilder.create<AllocStateOp>(
609  instOp.getLoc(), StateType::get(intType), storageArg);
610  state->setAttr("name", stateBuilder.getStringAttr(baseName));
611  passThrough.builder.create<StateWriteOp>(
612  instOp.getLoc(), state, passThrough.materializeValue(operand), Value{});
613  }
614 
615  // Lower the outputs of the extmodule as state that is only read.
616  for (auto [result, name] :
617  llvm::zip(instOp.getResults(), instOp.getResultNames())) {
618  LLVM_DEBUG(llvm::dbgs()
619  << " - Output " << name << " : " << result.getType() << "\n");
620  auto intType = result.getType().dyn_cast<IntegerType>();
621  if (!intType)
622  return mlir::emitError(result.getLoc(), "output ")
623  << name << " of extern module " << instOp.getModuleNameAttr()
624  << " instance " << instOp.getInstanceNameAttr()
625  << " is of non-integer type " << result.getType();
626  baseName.resize(baseNameLen);
627  baseName += '/';
628  baseName += cast<StringAttr>(name).getValue();
629  auto state = stateBuilder.create<AllocStateOp>(
630  result.getLoc(), StateType::get(intType), storageArg);
631  state->setAttr("name", stateBuilder.getStringAttr(baseName));
632  replaceValueWithStateRead(result, state);
633  }
634 
635  instOp.erase();
636  return success();
637 }
638 
639 LogicalResult ModuleLowering::cleanup() {
640  // Clean up dead ops in the model.
641  SetVector<Operation *> erasureWorklist;
642  auto isDead = [](Operation *op) {
643  if (isOpTriviallyDead(op))
644  return true;
645  if (!op->use_empty())
646  return false;
647  if (auto stateOp = dyn_cast<StateOp>(op))
648  return stateOp.getLatency() == 0;
649  return false;
650  };
651  for (auto &op : *moduleOp.getBodyBlock())
652  if (isDead(&op))
653  erasureWorklist.insert(&op);
654  while (!erasureWorklist.empty()) {
655  auto *op = erasureWorklist.pop_back_val();
656  if (!isDead(op))
657  continue;
658  op->walk([&](Operation *innerOp) {
659  for (auto operand : innerOp->getOperands())
660  if (auto *defOp = operand.getDefiningOp())
661  if (!op->isProperAncestor(defOp))
662  erasureWorklist.insert(defOp);
663  });
664  op->erase();
665  }
666 
667  // Establish an order among all operations (to avoid an O(n²) pathological
668  // pattern with `moveBefore`) and replicate read operations into the blocks
669  // where they have uses. The established order is used to create the read
670  // operation as late in the block as possible, just before the first use.
671  DenseMap<Operation *, unsigned> opOrder;
672  SmallVector<StateReadOp, 0> readsToSink;
673  moduleOp.walk([&](Operation *op) {
674  opOrder.insert({op, opOrder.size()});
675  if (auto readOp = dyn_cast<StateReadOp>(op))
676  readsToSink.push_back(readOp);
677  });
678  for (auto readToSink : readsToSink) {
680  for (auto &use : llvm::make_early_inc_range(readToSink->getUses())) {
681  auto *user = use.getOwner();
682  auto userOrder = opOrder.lookup(user);
683  auto &localRead = readsByBlock[user->getBlock()];
684  if (!localRead.first) {
685  if (user->getBlock() == readToSink->getBlock()) {
686  localRead.first = readToSink;
687  readToSink->moveBefore(user);
688  } else {
689  localRead.first = OpBuilder(user).cloneWithoutRegions(readToSink);
690  }
691  localRead.second = userOrder;
692  } else if (userOrder < localRead.second) {
693  localRead.first->moveBefore(user);
694  localRead.second = userOrder;
695  }
696  use.set(localRead.first);
697  }
698  if (readToSink.use_empty())
699  readToSink.erase();
700  }
701  return success();
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // Pass Infrastructure
706 //===----------------------------------------------------------------------===//
707 
708 namespace {
/// The pass object. `runOnOperation` drives the lowering of every
/// `HWModuleOp` below the root operation via `runOnModule`.
709 struct LowerStatePass : public arc::impl::LowerStateBase<LowerStatePass> {
710  LowerStatePass() = default;
 // Copying delegates to the default constructor, so `stats` is rebuilt with
 // the new pass as its parent instead of being copied.
711  LowerStatePass(const LowerStatePass &pass) : LowerStatePass() {}
712 
713  void runOnOperation() override;
714  LogicalResult runOnModule(HWModuleOp moduleOp, SymbolTable &symtbl);
715 
 /// Statistics of this pass, constructed with this pass as parent.
716  Statistics stats{this};
717 };
718 } // namespace
719 
/// Pass entry point. Lowers every `HWModuleOp`, erases all
/// `HWModuleExternOp`s, rewrites leftover `MemoryReadPortOp`s into
/// `MemoryReadOp`s, and converts every arc definition that contained such a
/// read port into a `func::FuncOp` whose users become `func::CallOp`s.
720 void LowerStatePass::runOnOperation() {
721  auto &symtbl = getAnalysis<SymbolTable>();
722  SmallVector<HWModuleExternOp> extModules;
 // `runOnModule` erases the module it lowers, hence the early-inc range.
723  for (auto &op : llvm::make_early_inc_range(getOperation().getOps())) {
724  if (auto moduleOp = dyn_cast<HWModuleOp>(&op)) {
725  if (failed(runOnModule(moduleOp, symtbl)))
726  return signalPassFailure();
727  } else if (auto extModuleOp = dyn_cast<HWModuleExternOp>(&op)) {
728  extModules.push_back(extModuleOp);
729  }
730  }
 // Extern modules are erased only after all modules have been lowered.
731  for (auto op : extModules)
732  op.erase();
733 
734  // Lower remaining MemoryReadPort ops to MemoryRead ops. This can occur when
735  // the fan-in of a MemoryReadPortOp contains another such operation and is
736  // materialized before the one in the fan-in as the MemoryReadPortOp is not
737  // marked as a fan-in blocking/termination operation in `shouldMaterialize`.
738  // Adding it there can lead to dominance issues which would then have to be
739  // resolved instead.
740  SetVector<DefineOp> arcsToLower;
741  OpBuilder builder(getOperation());
742  getOperation()->walk([&](MemoryReadPortOp memReadOp) {
 // Remember arc definitions that contain a read port; they are turned into
 // functions below.
743  if (auto defOp = memReadOp->getParentOfType<DefineOp>())
744  arcsToLower.insert(defOp);
745 
746  builder.setInsertionPoint(memReadOp);
747  Value newRead = builder.create<MemoryReadOp>(
748  memReadOp.getLoc(), memReadOp.getMemory(), memReadOp.getAddress());
749  memReadOp.replaceAllUsesWith(newRead);
750  memReadOp.erase();
751  });
752 
753  SymbolTableCollection symbolTable;
754  mlir::SymbolUserMap userMap(symbolTable, getOperation());
755  for (auto defOp : arcsToLower) {
 // Replace the arc's terminator with a `func.return` of the same operands.
756  auto *terminator = defOp.getBodyBlock().getTerminator();
757  builder.setInsertionPoint(terminator);
758  builder.create<func::ReturnOp>(terminator->getLoc(),
759  terminator->getOperands());
760  terminator->erase();
 // Create a function with the arc's name and type, mark it with internal
 // LLVM linkage, and move the arc's body into it.
761  builder.setInsertionPoint(defOp);
762  auto funcOp = builder.create<func::FuncOp>(defOp.getLoc(), defOp.getName(),
763  defOp.getFunctionType());
764  funcOp->setAttr("llvm.linkage",
765  LLVM::LinkageAttr::get(builder.getContext(),
766  LLVM::linkage::Linkage::Internal));
767  funcOp.getBody().takeBody(defOp.getBody());
768 
 // Redirect every user of the arc to a call of the new function. Users are
 // cast to `CallOpInterface` to obtain their argument operands.
769  for (auto *user : userMap.getUsers(defOp)) {
770  builder.setInsertionPoint(user);
771  ValueRange results = builder
772  .create<func::CallOp>(
773  user->getLoc(), funcOp,
774  cast<CallOpInterface>(user).getArgOperands())
775  ->getResults();
776  user->replaceAllUsesWith(results);
777  user->erase();
778  }
779 
780  defOp.erase();
781  }
782 }
783 
/// Lower a single `HWModuleOp` into a `ModelOp`.
///
/// Runs the individual lowering steps of `ModuleLowering` in order (storage
/// argument, primary inputs, primary outputs, states, extern module
/// instances, cleanup), then moves the module body into a new `ModelOp`,
/// erases the module, and topologically sorts the model's body. Returns
/// failure as soon as one of the steps fails.
784 LogicalResult LowerStatePass::runOnModule(HWModuleOp moduleOp,
785  SymbolTable &symtbl) {
786  LLVM_DEBUG(llvm::dbgs() << "Lowering state in `" << moduleOp.getModuleName()
787  << "`\n");
788  ModuleLowering lowering(moduleOp, stats);
789 
790  // Add sentinel ops to separate state allocations from clock trees.
 // Both sentinels are placed at the start of the body. `stateBuilder` then
 // inserts before the first sentinel and `clockBuilder` before the second.
791  lowering.stateBuilder.setInsertionPointToStart(moduleOp.getBodyBlock());
792 
793  Operation *stateSentinel =
794  lowering.stateBuilder.create<hw::OutputOp>(moduleOp.getLoc());
795  Operation *clockSentinel =
796  lowering.stateBuilder.create<hw::OutputOp>(moduleOp.getLoc());
797 
798  lowering.stateBuilder.setInsertionPoint(stateSentinel);
799  lowering.clockBuilder.setInsertionPoint(clockSentinel);
800 
801  lowering.addStorageArg();
802  if (failed(lowering.lowerPrimaryInputs()))
803  return failure();
804  if (failed(lowering.lowerPrimaryOutputs()))
805  return failure();
806  if (failed(lowering.lowerStates()))
807  return failure();
808  if (failed(lowering.lowerExtModules(symtbl)))
809  return failure();
810 
811  // Clean up the module body which contains a lot of operations that the
812  // pessimistic value materialization has left behind because it couldn't
813  // reliably determine that the ops were no longer needed.
814  if (failed(lowering.cleanup()))
815  return failure();
816 
817  // Erase the sentinel ops.
818  stateSentinel->erase();
819  clockSentinel->erase();
820 
821  // Replace the `HWModuleOp` with a `ModelOp`.
 // All block arguments except the storage argument are removed first.
822  moduleOp.getBodyBlock()->eraseArguments(
823  [&](auto arg) { return arg != lowering.storageArg; });
824  ImplicitLocOpBuilder builder(moduleOp.getLoc(), moduleOp);
825  auto modelOp =
826  builder.create<ModelOp>(moduleOp.getLoc(), moduleOp.getModuleNameAttr());
827  modelOp.getBody().takeBody(moduleOp.getBody());
828  moduleOp->erase();
829  sortTopologically(&modelOp.getBodyBlock());
830 
831  return success();
832 }
833 
834 std::unique_ptr<Pass> arc::createLowerStatePass() {
835  return std::make_unique<LowerStatePass>();
836 }
lowerAnnotationsNoRefTypePorts FirtoolPreserveValuesMode value
Definition: Firtool.cpp:95
assert(baseType &&"element must be base type")
static bool shouldMaterialize(Operation *op)
Definition: LowerState.cpp:132
Builder builder
def create(data_type, value)
Definition: hw.py:397
std::unique_ptr< mlir::Pass > createLowerStatePass()
Definition: LowerState.cpp:834
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
Definition: hw.py:1
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28