CIRCT  19.0.0git
LowerState.cpp
Go to the documentation of this file.
1 //===- LowerState.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "circt/Dialect/HW/HWOps.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/SymbolTable.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/TopologicalSortUtils.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "arc-lower-state"
28 
29 namespace circt {
30 namespace arc {
31 #define GEN_PASS_DEF_LOWERSTATE
32 #include "circt/Dialect/Arc/ArcPasses.h.inc"
33 } // namespace arc
34 } // namespace circt
35 
36 using namespace circt;
37 using namespace arc;
38 using namespace hw;
39 using namespace mlir;
40 using llvm::SmallDenseSet;
41 
42 //===----------------------------------------------------------------------===//
43 // Data Structures
44 //===----------------------------------------------------------------------===//
45 
46 namespace {
47 
48 /// Statistics gathered throughout the execution of this pass.
49 struct Statistics {
50  Pass *parent;
51  Statistics(Pass *parent) : parent(parent) {}
52  using Statistic = Pass::Statistic;
53 
54  Statistic matOpsMoved{parent, "mat-ops-moved",
55  "Ops moved during value materialization"};
56  Statistic matOpsCloned{parent, "mat-ops-cloned",
57  "Ops cloned during value materialization"};
58  Statistic opsPruned{parent, "ops-pruned", "Ops removed as dead code"};
59 };
60 
/// Lowering info associated with a single primary clock.
///
/// Values needed inside the clock tree are created through `builder`, which
/// the constructor positions at the start of the tree op's body block.
struct ClockLowering {
  /// The root clock this lowering is for. Null for the pass-through lowering.
  Value clock;
  /// A `ClockTreeOp` or `PassThroughOp`.
  Operation *treeOp;
  /// Pass statistics.
  Statistics &stats;
  /// Builder that inserts ops into the body of `treeOp`.
  OpBuilder builder;
  /// A mapping from values outside the clock tree to their materialized form
  /// inside the clock tree.
  IRMapping materializedValues;
  /// A cache of AND gates created for aggregating enable conditions, keyed by
  /// the (lhs, rhs) operand pair.
  DenseMap<std::pair<Value, Value>, Value> andCache;
  /// A cache of OR gates created for aggregating enable conditions, keyed by
  /// the (lhs, rhs) operand pair.
  DenseMap<std::pair<Value, Value>, Value> orCache;

  /// Create the lowering for `clock`. `treeOp` must be a `ClockTreeOp` or a
  /// `PassThroughOp` whose first region already contains a block.
  ClockLowering(Value clock, Operation *treeOp, Statistics &stats)
      : clock(clock), treeOp(treeOp), stats(stats), builder(treeOp) {
    assert((isa<ClockTreeOp, PassThroughOp>(treeOp)));
    builder.setInsertionPointToStart(&treeOp->getRegion(0).front());
  }

  /// Make `value` available inside this clock tree.
  Value materializeValue(Value value);
  /// Return `lhs & rhs`, reusing cached gates; null operands are ignored.
  Value getOrCreateAnd(Value lhs, Value rhs, Location loc);
  /// Return `lhs | rhs`, reusing cached gates; null operands are ignored.
  Value getOrCreateOr(Value lhs, Value rhs, Location loc);
};
88 
/// A clock tree together with the enable condition accumulated from the clock
/// gates between a (possibly gated) clock and its primary clock.
struct GatedClockLowering {
  /// Lowering info of the primary clock.
  ClockLowering &clock;
  /// An optional enable condition of the primary clock. May be null.
  Value enable;
};
95 
/// State lowering for a single `HWModuleOp`.
struct ModuleLowering {
  HWModuleOp moduleOp;
  /// Pass statistics.
  Statistics &stats;
  MLIRContext *context;
  /// Lowerings of ungated root clocks. The null key holds the pass-through
  /// lowering created by `getOrCreatePassThrough`.
  DenseMap<Value, std::unique_ptr<ClockLowering>> clockLowerings;
  /// Cached lowerings of clocks produced by `seq::ClockGateOp`s.
  DenseMap<Value, GatedClockLowering> gatedClockLowerings;
  /// The storage block argument added by `addStorageArg`.
  Value storageArg;
  /// Builder used to create clock trees and clock edge detection logic.
  OpBuilder clockBuilder;
  /// Builder used to create state allocations and root input/output states.
  OpBuilder stateBuilder;

  ModuleLowering(HWModuleOp moduleOp, Statistics &stats)
      : moduleOp(moduleOp), stats(stats), context(moduleOp.getContext()),
        clockBuilder(moduleOp), stateBuilder(moduleOp) {}

  GatedClockLowering getOrCreateClockLowering(Value clock);
  ClockLowering &getOrCreatePassThrough();
  Value replaceValueWithStateRead(Value value, Value state);

  void addStorageArg();
  LogicalResult lowerPrimaryInputs();
  LogicalResult lowerPrimaryOutputs();
  LogicalResult lowerStates();
  LogicalResult lowerState(StateOp stateOp);
  LogicalResult lowerState(MemoryOp memOp);
  LogicalResult lowerState(MemoryWritePortOp memWriteOp);
  LogicalResult lowerState(TapOp tapOp);
  LogicalResult lowerExtModules(SymbolTable &symtbl);
  LogicalResult lowerExtModule(InstanceOp instOp);

  LogicalResult cleanup();
};
129 } // namespace
130 
131 //===----------------------------------------------------------------------===//
132 // Clock Lowering
133 //===----------------------------------------------------------------------===//
134 
135 static bool shouldMaterialize(Operation *op) {
136  // Don't materialize arc uses with latency >0, since we handle these in a
137  // second pass once all other operations have been moved to their respective
138  // clock trees.
139  return !isa<MemoryOp, AllocStateOp, AllocMemoryOp, AllocStorageOp,
140  ClockTreeOp, PassThroughOp, RootInputOp, RootOutputOp,
141  StateWriteOp, MemoryWritePortOp, igraph::InstanceOpInterface,
142  StateOp>(op);
143 }
144 
145 static bool shouldMaterialize(Value value) {
146  assert(value);
147 
148  // Block arguments are just used as they are.
149  auto *op = value.getDefiningOp();
150  if (!op)
151  return false;
152 
153  return shouldMaterialize(op);
154 }
155 
156 /// Materialize a value within this clock tree. This clones or moves all
157 /// operations required to produce this value inside the clock tree.
158 Value ClockLowering::materializeValue(Value value) {
159  if (!value)
160  return {};
161  if (auto mapped = materializedValues.lookupOrNull(value))
162  return mapped;
163  if (!shouldMaterialize(value))
164  return value;
165 
166  struct WorkItem {
167  Operation *op;
168  SmallVector<Value, 2> operands;
169  WorkItem(Operation *op) : op(op) {}
170  };
171 
172  SmallPtrSet<Operation *, 8> seen;
173  SmallVector<WorkItem> worklist;
174 
175  auto addToWorklist = [&](Operation *outerOp) {
176  SmallDenseSet<Value> seenOperands;
177  auto &workItem = worklist.emplace_back(outerOp);
178  outerOp->walk([&](Operation *innerOp) {
179  for (auto operand : innerOp->getOperands()) {
180  // Skip operands that are defined within the operation itself.
181  if (!operand.getParentBlock()->getParentOp()->isProperAncestor(outerOp))
182  continue;
183 
184  // Skip operands that we have already seen.
185  if (!seenOperands.insert(operand).second)
186  continue;
187 
188  // Skip operands that we have already materialized or that should not
189  // be materialized at all.
190  if (materializedValues.contains(operand) || !shouldMaterialize(operand))
191  continue;
192 
193  workItem.operands.push_back(operand);
194  }
195  });
196  };
197 
198  seen.insert(value.getDefiningOp());
199  addToWorklist(value.getDefiningOp());
200 
201  while (!worklist.empty()) {
202  auto &workItem = worklist.back();
203  if (!workItem.operands.empty()) {
204  auto operand = workItem.operands.pop_back_val();
205  if (materializedValues.contains(operand) || !shouldMaterialize(operand))
206  continue;
207  auto *defOp = operand.getDefiningOp();
208  if (!seen.insert(defOp).second) {
209  defOp->emitError("combinational loop detected");
210  return {};
211  }
212  addToWorklist(defOp);
213  } else {
214  builder.clone(*workItem.op, materializedValues);
215  seen.erase(workItem.op);
216  worklist.pop_back();
217  }
218  }
219 
220  return materializedValues.lookup(value);
221 }
222 
223 /// Create an AND gate if none with the given operands already exists. Note that
224 /// the operands may be null, in which case the function will return the
225 /// non-null operand, or null if both operands are null.
226 Value ClockLowering::getOrCreateAnd(Value lhs, Value rhs, Location loc) {
227  if (!lhs)
228  return rhs;
229  if (!rhs)
230  return lhs;
231  auto &slot = andCache[std::make_pair(lhs, rhs)];
232  if (!slot)
233  slot = builder.create<comb::AndOp>(loc, lhs, rhs);
234  return slot;
235 }
236 
237 /// Create an OR gate if none with the given operands already exists. Note that
238 /// the operands may be null, in which case the function will return the
239 /// non-null operand, or null if both operands are null.
240 Value ClockLowering::getOrCreateOr(Value lhs, Value rhs, Location loc) {
241  if (!lhs)
242  return rhs;
243  if (!rhs)
244  return lhs;
245  auto &slot = orCache[std::make_pair(lhs, rhs)];
246  if (!slot)
247  slot = builder.create<comb::OrOp>(loc, lhs, rhs);
248  return slot;
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // Module Lowering
253 //===----------------------------------------------------------------------===//
254 
255 GatedClockLowering ModuleLowering::getOrCreateClockLowering(Value clock) {
256  // Look through clock gates.
257  if (auto ckgOp = clock.getDefiningOp<seq::ClockGateOp>()) {
258  // Reuse the existing lowering for this clock gate if possible.
259  if (auto it = gatedClockLowerings.find(clock);
260  it != gatedClockLowerings.end())
261  return it->second;
262 
263  // Get the lowering for the parent clock gate's input clock. This will give
264  // us the clock tree to emit things into, alongside the compound enable
265  // condition of all the clock gates along the way to the primary clock. All
266  // we have to do is to add this clock gate's condition to that list.
267  auto info = getOrCreateClockLowering(ckgOp.getInput());
268  auto ckgEnable = info.clock.materializeValue(ckgOp.getEnable());
269  auto ckgTestEnable = info.clock.materializeValue(ckgOp.getTestEnable());
270  info.enable = info.clock.getOrCreateAnd(
271  info.enable,
272  info.clock.getOrCreateOr(ckgEnable, ckgTestEnable, ckgOp.getLoc()),
273  ckgOp.getLoc());
274  gatedClockLowerings.insert({clock, info});
275  return info;
276  }
277 
278  // Create the `ClockTreeOp` that corresponds to this ungated clock.
279  auto &slot = clockLowerings[clock];
280  if (!slot) {
281  auto newClock =
282  clockBuilder.createOrFold<seq::FromClockOp>(clock.getLoc(), clock);
283 
284  // Detect a rising edge on the clock, as `(old != new) & new`.
285  auto oldClockStorage = stateBuilder.create<AllocStateOp>(
286  clock.getLoc(), StateType::get(stateBuilder.getI1Type()), storageArg);
287  auto oldClock =
288  clockBuilder.create<StateReadOp>(clock.getLoc(), oldClockStorage);
289  clockBuilder.create<StateWriteOp>(clock.getLoc(), oldClockStorage, newClock,
290  Value{});
291  Value trigger = clockBuilder.create<comb::ICmpOp>(
292  clock.getLoc(), comb::ICmpPredicate::ne, oldClock, newClock);
293  trigger =
294  clockBuilder.create<comb::AndOp>(clock.getLoc(), trigger, newClock);
295 
296  // Create the tree op.
297  auto treeOp = clockBuilder.create<ClockTreeOp>(clock.getLoc(), trigger);
298  treeOp.getBody().emplaceBlock();
299  slot = std::make_unique<ClockLowering>(clock, treeOp, stats);
300  }
301  return GatedClockLowering{*slot, Value{}};
302 }
303 
304 ClockLowering &ModuleLowering::getOrCreatePassThrough() {
305  auto &slot = clockLowerings[Value{}];
306  if (!slot) {
307  auto treeOp = clockBuilder.create<PassThroughOp>(moduleOp.getLoc());
308  treeOp.getBody().emplaceBlock();
309  slot = std::make_unique<ClockLowering>(Value{}, treeOp, stats);
310  }
311  return *slot;
312 }
313 
314 /// Replace all uses of a value with a `StateReadOp` on a state.
315 Value ModuleLowering::replaceValueWithStateRead(Value value, Value state) {
316  OpBuilder builder(state.getContext());
317  builder.setInsertionPointAfterValue(state);
318  Value readOp = builder.create<StateReadOp>(value.getLoc(), state);
319  if (isa<seq::ClockType>(value.getType()))
320  readOp = builder.createOrFold<seq::ToClockOp>(value.getLoc(), readOp);
321  value.replaceAllUsesWith(readOp);
322  return readOp;
323 }
324 
325 /// Add the global state as an argument to the module's body block.
326 void ModuleLowering::addStorageArg() {
327  assert(!storageArg);
328  storageArg = moduleOp.getBodyBlock()->addArgument(
329  StorageType::get(context, {}), moduleOp.getLoc());
330 }
331 
332 /// Lower the primary inputs of the module to dedicated ops that allocate the
333 /// inputs in the model's storage.
334 LogicalResult ModuleLowering::lowerPrimaryInputs() {
335  for (auto blockArg : moduleOp.getBodyBlock()->getArguments()) {
336  if (blockArg == storageArg)
337  continue;
338  auto name = moduleOp.getArgName(blockArg.getArgNumber());
339  auto argTy = blockArg.getType();
340  IntegerType innerTy;
341  if (isa<seq::ClockType>(argTy)) {
342  innerTy = IntegerType::get(context, 1);
343  } else if (auto intType = dyn_cast<IntegerType>(argTy)) {
344  innerTy = intType;
345  } else {
346  return mlir::emitError(blockArg.getLoc(), "input ")
347  << name << " is of non-integer type " << blockArg.getType();
348  }
349  auto state = stateBuilder.create<RootInputOp>(
350  blockArg.getLoc(), StateType::get(innerTy), name, storageArg);
351  replaceValueWithStateRead(blockArg, state);
352  }
353  return success();
354 }
355 
356 /// Lower the primary outputs of the module to dedicated ops that allocate the
357 /// outputs in the model's storage.
358 LogicalResult ModuleLowering::lowerPrimaryOutputs() {
359  auto outputOp = cast<hw::OutputOp>(moduleOp.getBodyBlock()->getTerminator());
360  if (outputOp.getNumOperands() > 0) {
361  auto outputOperands = SmallVector<Value>(outputOp.getOperands());
362  outputOp->dropAllReferences();
363  auto &passThrough = getOrCreatePassThrough();
364  for (auto [outputArg, name] :
365  llvm::zip(outputOperands, moduleOp.getOutputNames())) {
366  IntegerType innerTy;
367  if (isa<seq::ClockType>(outputArg.getType())) {
368  innerTy = IntegerType::get(context, 1);
369  } else if (auto intType = dyn_cast<IntegerType>(outputArg.getType())) {
370  innerTy = intType;
371  } else {
372  return mlir::emitError(outputOp.getLoc(), "output ")
373  << name << " is of non-integer type " << outputArg.getType();
374  }
375  auto value = passThrough.materializeValue(outputArg);
376  auto state = stateBuilder.create<RootOutputOp>(
377  outputOp.getLoc(), StateType::get(innerTy), cast<StringAttr>(name),
378  storageArg);
379  if (isa<seq::ClockType>(value.getType()))
380  value = passThrough.builder.createOrFold<seq::FromClockOp>(
381  outputOp.getLoc(), value);
382  passThrough.builder.create<StateWriteOp>(outputOp.getLoc(), state, value,
383  Value{});
384  }
385  }
386  outputOp.erase();
387  return success();
388 }
389 
390 LogicalResult ModuleLowering::lowerStates() {
391  SmallVector<Operation *> opsToLower;
392  for (auto &op : *moduleOp.getBodyBlock())
393  if (isa<StateOp, MemoryOp, MemoryWritePortOp, TapOp>(&op))
394  opsToLower.push_back(&op);
395 
396  for (auto *op : opsToLower) {
397  LLVM_DEBUG(llvm::dbgs() << "- Lowering " << *op << "\n");
398  auto result = TypeSwitch<Operation *, LogicalResult>(op)
399  .Case<StateOp, MemoryOp, MemoryWritePortOp, TapOp>(
400  [&](auto op) { return lowerState(op); })
401  .Default(success());
402  if (failed(result))
403  return failure();
404  }
405  return success();
406 }
407 
408 LogicalResult ModuleLowering::lowerState(StateOp stateOp) {
409  // We don't support arcs beyond latency 1 yet. These should be easy to add in
410  // the future though.
411  if (stateOp.getLatency() > 1)
412  return stateOp.emitError("state with latency > 1 not supported");
413 
414  // Grab all operands from the state op and make it drop all its references.
415  // This allows `materializeValue` to move an operation if this state was the
416  // last user.
417  auto stateClock = stateOp.getClock();
418  auto stateEnable = stateOp.getEnable();
419  auto stateReset = stateOp.getReset();
420  auto stateInputs = SmallVector<Value>(stateOp.getInputs());
421 
422  // Get the clock tree and enable condition for this state's clock. If this arc
423  // carries an explicit enable condition, fold that into the enable provided by
424  // the clock gates in the arc's clock tree.
425  auto info = getOrCreateClockLowering(stateClock);
426  info.enable = info.clock.getOrCreateAnd(
427  info.enable, info.clock.materializeValue(stateEnable), stateOp.getLoc());
428 
429  // Allocate the necessary state within the model.
430  SmallVector<Value> allocatedStates;
431  for (unsigned stateIdx = 0; stateIdx < stateOp.getNumResults(); ++stateIdx) {
432  auto type = stateOp.getResult(stateIdx).getType();
433  auto intType = dyn_cast<IntegerType>(type);
434  if (!intType)
435  return stateOp.emitOpError("result ")
436  << stateIdx << " has non-integer type " << type
437  << "; only integer types are supported";
438  auto stateType = StateType::get(intType);
439  auto state = stateBuilder.create<AllocStateOp>(stateOp.getLoc(), stateType,
440  storageArg);
441  if (auto names = stateOp->getAttrOfType<ArrayAttr>("names"))
442  state->setAttr("name", names[stateIdx]);
443  allocatedStates.push_back(state);
444  }
445 
446  // Create a copy of the arc use with latency zero. This will effectively be
447  // the computation of the arc's transfer function, while the latency is
448  // implemented through read and write functions.
449  SmallVector<Value> materializedOperands;
450  materializedOperands.reserve(stateInputs.size());
451 
452  for (auto input : stateInputs)
453  materializedOperands.push_back(info.clock.materializeValue(input));
454 
455  OpBuilder nonResetBuilder = info.clock.builder;
456  if (stateReset) {
457  auto materializedReset = info.clock.materializeValue(stateReset);
458  auto ifOp = info.clock.builder.create<scf::IfOp>(stateOp.getLoc(),
459  materializedReset, true);
460 
461  for (auto [alloc, resTy] :
462  llvm::zip(allocatedStates, stateOp.getResultTypes())) {
463  if (!isa<IntegerType>(resTy))
464  stateOp->emitOpError("Non-integer result not supported yet!");
465 
466  auto thenBuilder = ifOp.getThenBodyBuilder();
467  Value constZero =
468  thenBuilder.create<hw::ConstantOp>(stateOp.getLoc(), resTy, 0);
469  thenBuilder.create<StateWriteOp>(stateOp.getLoc(), alloc, constZero,
470  Value());
471  }
472 
473  nonResetBuilder = ifOp.getElseBodyBuilder();
474  }
475 
476  stateOp->dropAllReferences();
477 
478  auto newStateOp = nonResetBuilder.create<CallOp>(
479  stateOp.getLoc(), stateOp.getResultTypes(), stateOp.getArcAttr(),
480  materializedOperands);
481 
482  // Create the write ops that write the result of the transfer function to the
483  // allocated state storage.
484  for (auto [alloc, result] :
485  llvm::zip(allocatedStates, newStateOp.getResults()))
486  nonResetBuilder.create<StateWriteOp>(stateOp.getLoc(), alloc, result,
487  info.enable);
488 
489  // Replace all uses of the arc with reads from the allocated state.
490  for (auto [alloc, result] : llvm::zip(allocatedStates, stateOp.getResults()))
491  replaceValueWithStateRead(result, alloc);
492  stateOp.erase();
493  return success();
494 }
495 
496 LogicalResult ModuleLowering::lowerState(MemoryOp memOp) {
497  auto allocMemOp = stateBuilder.create<AllocMemoryOp>(
498  memOp.getLoc(), memOp.getType(), storageArg, memOp->getAttrs());
499  memOp.replaceAllUsesWith(allocMemOp.getResult());
500  memOp.erase();
501  return success();
502 }
503 
504 LogicalResult ModuleLowering::lowerState(MemoryWritePortOp memWriteOp) {
505  if (memWriteOp.getLatency() > 1)
506  return memWriteOp->emitOpError("latencies > 1 not supported yet");
507 
508  // Get the clock tree and enable condition for this write port's clock. If the
509  // port carries an explicit enable condition, fold that into the enable
510  // provided by the clock gates in the port's clock tree.
511  auto info = getOrCreateClockLowering(memWriteOp.getClock());
512 
513  // Grab all operands from the op and make it drop all its references. This
514  // allows `materializeValue` to move an operation if this op was the last
515  // user.
516  auto writeMemory = memWriteOp.getMemory();
517  auto writeInputs = SmallVector<Value>(memWriteOp.getInputs());
518  auto arcResultTypes = memWriteOp.getArcResultTypes();
519  memWriteOp->dropAllReferences();
520 
521  SmallVector<Value> materializedInputs;
522  for (auto input : writeInputs)
523  materializedInputs.push_back(info.clock.materializeValue(input));
524  ValueRange results =
525  info.clock.builder
526  .create<CallOp>(memWriteOp.getLoc(), arcResultTypes,
527  memWriteOp.getArc(), materializedInputs)
528  ->getResults();
529 
530  auto enable =
531  memWriteOp.getEnable() ? results[memWriteOp.getEnableIdx()] : Value();
532  info.enable =
533  info.clock.getOrCreateAnd(info.enable, enable, memWriteOp.getLoc());
534 
535  // Materialize the operands for the write op within the surrounding clock
536  // tree.
537  auto address = results[memWriteOp.getAddressIdx()];
538  auto data = results[memWriteOp.getDataIdx()];
539  if (memWriteOp.getMask()) {
540  Value mask = results[memWriteOp.getMaskIdx(static_cast<bool>(enable))];
541  Value oldData = info.clock.builder.create<arc::MemoryReadOp>(
542  mask.getLoc(), data.getType(), writeMemory, address);
543  Value allOnes = info.clock.builder.create<hw::ConstantOp>(
544  mask.getLoc(), oldData.getType(), -1);
545  Value negatedMask = info.clock.builder.create<comb::XorOp>(
546  mask.getLoc(), mask, allOnes, true);
547  Value maskedOldData = info.clock.builder.create<comb::AndOp>(
548  mask.getLoc(), negatedMask, oldData, true);
549  Value maskedNewData =
550  info.clock.builder.create<comb::AndOp>(mask.getLoc(), mask, data, true);
551  data = info.clock.builder.create<comb::OrOp>(mask.getLoc(), maskedOldData,
552  maskedNewData, true);
553  }
554  info.clock.builder.create<MemoryWriteOp>(memWriteOp.getLoc(), writeMemory,
555  address, info.enable, data);
556  memWriteOp.erase();
557  return success();
558 }
559 
560 // Add state for taps into the passthrough block.
561 LogicalResult ModuleLowering::lowerState(TapOp tapOp) {
562  auto intType = dyn_cast<IntegerType>(tapOp.getValue().getType());
563  if (!intType)
564  return mlir::emitError(tapOp.getLoc(), "tapped value ")
565  << tapOp.getNameAttr() << " is of non-integer type "
566  << tapOp.getValue().getType();
567 
568  // Grab what we need from the tap op and then make it drop all its references.
569  // This will allow `materializeValue` to move ops instead of cloning them.
570  auto tapValue = tapOp.getValue();
571  tapOp->dropAllReferences();
572 
573  auto &passThrough = getOrCreatePassThrough();
574  auto materializedValue = passThrough.materializeValue(tapValue);
575  auto state = stateBuilder.create<AllocStateOp>(
576  tapOp.getLoc(), StateType::get(intType), storageArg, true);
577  state->setAttr("name", tapOp.getNameAttr());
578  passThrough.builder.create<StateWriteOp>(tapOp.getLoc(), state,
579  materializedValue, Value{});
580  tapOp.erase();
581  return success();
582 }
583 
584 /// Lower all instances of external modules to internal inputs/outputs to be
585 /// driven from outside of the design.
586 LogicalResult ModuleLowering::lowerExtModules(SymbolTable &symtbl) {
587  auto instOps = SmallVector<InstanceOp>(moduleOp.getOps<InstanceOp>());
588  for (auto op : instOps)
589  if (isa<HWModuleExternOp>(symtbl.lookup(op.getModuleNameAttr().getAttr())))
590  if (failed(lowerExtModule(op)))
591  return failure();
592  return success();
593 }
594 
595 LogicalResult ModuleLowering::lowerExtModule(InstanceOp instOp) {
596  LLVM_DEBUG(llvm::dbgs() << "- Lowering extmodule "
597  << instOp.getInstanceNameAttr() << "\n");
598 
599  SmallString<32> baseName(instOp.getInstanceName());
600  auto baseNameLen = baseName.size();
601 
602  // Lower the inputs of the extmodule as state that is only written.
603  for (auto [operand, name] :
604  llvm::zip(instOp.getOperands(), instOp.getArgNames())) {
605  LLVM_DEBUG(llvm::dbgs()
606  << " - Input " << name << " : " << operand.getType() << "\n");
607  auto intType = dyn_cast<IntegerType>(operand.getType());
608  if (!intType)
609  return mlir::emitError(operand.getLoc(), "input ")
610  << name << " of extern module " << instOp.getModuleNameAttr()
611  << " instance " << instOp.getInstanceNameAttr()
612  << " is of non-integer type " << operand.getType();
613  baseName.resize(baseNameLen);
614  baseName += '/';
615  baseName += cast<StringAttr>(name).getValue();
616  auto &passThrough = getOrCreatePassThrough();
617  auto state = stateBuilder.create<AllocStateOp>(
618  instOp.getLoc(), StateType::get(intType), storageArg);
619  state->setAttr("name", stateBuilder.getStringAttr(baseName));
620  passThrough.builder.create<StateWriteOp>(
621  instOp.getLoc(), state, passThrough.materializeValue(operand), Value{});
622  }
623 
624  // Lower the outputs of the extmodule as state that is only read.
625  for (auto [result, name] :
626  llvm::zip(instOp.getResults(), instOp.getResultNames())) {
627  LLVM_DEBUG(llvm::dbgs()
628  << " - Output " << name << " : " << result.getType() << "\n");
629  auto intType = dyn_cast<IntegerType>(result.getType());
630  if (!intType)
631  return mlir::emitError(result.getLoc(), "output ")
632  << name << " of extern module " << instOp.getModuleNameAttr()
633  << " instance " << instOp.getInstanceNameAttr()
634  << " is of non-integer type " << result.getType();
635  baseName.resize(baseNameLen);
636  baseName += '/';
637  baseName += cast<StringAttr>(name).getValue();
638  auto state = stateBuilder.create<AllocStateOp>(
639  result.getLoc(), StateType::get(intType), storageArg);
640  state->setAttr("name", stateBuilder.getStringAttr(baseName));
641  replaceValueWithStateRead(result, state);
642  }
643 
644  instOp.erase();
645  return success();
646 }
647 
648 LogicalResult ModuleLowering::cleanup() {
649  // Clean up dead ops in the model.
650  SetVector<Operation *> erasureWorklist;
651  auto isDead = [](Operation *op) {
652  if (isOpTriviallyDead(op))
653  return true;
654  if (!op->use_empty())
655  return false;
656  return false;
657  };
658  for (auto &op : *moduleOp.getBodyBlock())
659  if (isDead(&op))
660  erasureWorklist.insert(&op);
661  while (!erasureWorklist.empty()) {
662  auto *op = erasureWorklist.pop_back_val();
663  if (!isDead(op))
664  continue;
665  op->walk([&](Operation *innerOp) {
666  for (auto operand : innerOp->getOperands())
667  if (auto *defOp = operand.getDefiningOp())
668  if (!op->isProperAncestor(defOp))
669  erasureWorklist.insert(defOp);
670  });
671  op->erase();
672  }
673 
674  // Establish an order among all operations (to avoid an O(n²) pathological
675  // pattern with `moveBefore`) and replicate read operations into the blocks
676  // where they have uses. The established order is used to create the read
677  // operation as late in the block as possible, just before the first use.
678  DenseMap<Operation *, unsigned> opOrder;
679  SmallVector<StateReadOp, 0> readsToSink;
680  moduleOp.walk([&](Operation *op) {
681  opOrder.insert({op, opOrder.size()});
682  if (auto readOp = dyn_cast<StateReadOp>(op))
683  readsToSink.push_back(readOp);
684  });
685  for (auto readToSink : readsToSink) {
687  for (auto &use : llvm::make_early_inc_range(readToSink->getUses())) {
688  auto *user = use.getOwner();
689  auto userOrder = opOrder.lookup(user);
690  auto &localRead = readsByBlock[user->getBlock()];
691  if (!localRead.first) {
692  if (user->getBlock() == readToSink->getBlock()) {
693  localRead.first = readToSink;
694  readToSink->moveBefore(user);
695  } else {
696  localRead.first = OpBuilder(user).cloneWithoutRegions(readToSink);
697  }
698  localRead.second = userOrder;
699  } else if (userOrder < localRead.second) {
700  localRead.first->moveBefore(user);
701  localRead.second = userOrder;
702  }
703  use.set(localRead.first);
704  }
705  if (readToSink.use_empty())
706  readToSink.erase();
707  }
708  return success();
709 }
710 
711 //===----------------------------------------------------------------------===//
712 // Pass Infrastructure
713 //===----------------------------------------------------------------------===//
714 
715 namespace {
716 struct LowerStatePass : public arc::impl::LowerStateBase<LowerStatePass> {
717  LowerStatePass() = default;
718  LowerStatePass(const LowerStatePass &pass) : LowerStatePass() {}
719 
720  void runOnOperation() override;
721  LogicalResult runOnModule(HWModuleOp moduleOp, SymbolTable &symtbl);
722 
723  Statistics stats{this};
724 };
725 } // namespace
726 
727 void LowerStatePass::runOnOperation() {
728  auto &symtbl = getAnalysis<SymbolTable>();
729  SmallVector<HWModuleExternOp> extModules;
730  for (auto &op : llvm::make_early_inc_range(getOperation().getOps())) {
731  if (auto moduleOp = dyn_cast<HWModuleOp>(&op)) {
732  if (failed(runOnModule(moduleOp, symtbl)))
733  return signalPassFailure();
734  } else if (auto extModuleOp = dyn_cast<HWModuleExternOp>(&op)) {
735  extModules.push_back(extModuleOp);
736  }
737  }
738  for (auto op : extModules)
739  op.erase();
740 
741  // Lower remaining MemoryReadPort ops to MemoryRead ops. This can occur when
742  // the fan-in of a MemoryReadPortOp contains another such operation and is
743  // materialized before the one in the fan-in as the MemoryReadPortOp is not
744  // marked as a fan-in blocking/termination operation in `shouldMaterialize`.
745  // Adding it there can lead to dominance issues which would then have to be
746  // resolved instead.
747  SetVector<DefineOp> arcsToLower;
748  OpBuilder builder(getOperation());
749  getOperation()->walk([&](MemoryReadPortOp memReadOp) {
750  if (auto defOp = memReadOp->getParentOfType<DefineOp>())
751  arcsToLower.insert(defOp);
752 
753  builder.setInsertionPoint(memReadOp);
754  Value newRead = builder.create<MemoryReadOp>(
755  memReadOp.getLoc(), memReadOp.getMemory(), memReadOp.getAddress());
756  memReadOp.replaceAllUsesWith(newRead);
757  memReadOp.erase();
758  });
759 
760  SymbolTableCollection symbolTable;
761  mlir::SymbolUserMap userMap(symbolTable, getOperation());
762  for (auto defOp : arcsToLower) {
763  auto *terminator = defOp.getBodyBlock().getTerminator();
764  builder.setInsertionPoint(terminator);
765  builder.create<func::ReturnOp>(terminator->getLoc(),
766  terminator->getOperands());
767  terminator->erase();
768  builder.setInsertionPoint(defOp);
769  auto funcOp = builder.create<func::FuncOp>(defOp.getLoc(), defOp.getName(),
770  defOp.getFunctionType());
771  funcOp->setAttr("llvm.linkage",
772  LLVM::LinkageAttr::get(builder.getContext(),
773  LLVM::linkage::Linkage::Internal));
774  funcOp.getBody().takeBody(defOp.getBody());
775 
776  for (auto *user : userMap.getUsers(defOp)) {
777  builder.setInsertionPoint(user);
778  ValueRange results = builder
779  .create<func::CallOp>(
780  user->getLoc(), funcOp,
781  cast<CallOpInterface>(user).getArgOperands())
782  ->getResults();
783  user->replaceAllUsesWith(results);
784  user->erase();
785  }
786 
787  defOp.erase();
788  }
789 }
790 
791 LogicalResult LowerStatePass::runOnModule(HWModuleOp moduleOp,
792  SymbolTable &symtbl) {
793  LLVM_DEBUG(llvm::dbgs() << "Lowering state in `" << moduleOp.getModuleName()
794  << "`\n");
795  ModuleLowering lowering(moduleOp, stats);
796 
797  // Add sentinel ops to separate state allocations from clock trees.
798  lowering.stateBuilder.setInsertionPointToStart(moduleOp.getBodyBlock());
799 
800  Operation *stateSentinel =
801  lowering.stateBuilder.create<hw::OutputOp>(moduleOp.getLoc());
802  Operation *clockSentinel =
803  lowering.stateBuilder.create<hw::OutputOp>(moduleOp.getLoc());
804 
805  lowering.stateBuilder.setInsertionPoint(stateSentinel);
806  lowering.clockBuilder.setInsertionPoint(clockSentinel);
807 
808  lowering.addStorageArg();
809  if (failed(lowering.lowerPrimaryInputs()))
810  return failure();
811  if (failed(lowering.lowerPrimaryOutputs()))
812  return failure();
813  if (failed(lowering.lowerStates()))
814  return failure();
815  if (failed(lowering.lowerExtModules(symtbl)))
816  return failure();
817 
818  // Clean up the module body which contains a lot of operations that the
819  // pessimistic value materialization has left behind because it couldn't
820  // reliably determine that the ops were no longer needed.
821  if (failed(lowering.cleanup()))
822  return failure();
823 
824  // Erase the sentinel ops.
825  stateSentinel->erase();
826  clockSentinel->erase();
827 
828  // Replace the `HWModuleOp` with a `ModelOp`.
829  moduleOp.getBodyBlock()->eraseArguments(
830  [&](auto arg) { return arg != lowering.storageArg; });
831  ImplicitLocOpBuilder builder(moduleOp.getLoc(), moduleOp);
832  auto modelOp =
833  builder.create<ModelOp>(moduleOp.getLoc(), moduleOp.getModuleNameAttr(),
834  TypeAttr::get(moduleOp.getModuleType()));
835  modelOp.getBody().takeBody(moduleOp.getBody());
836  moduleOp->erase();
837  sortTopologically(&modelOp.getBodyBlock());
838 
839  return success();
840 }
841 
842 std::unique_ptr<Pass> arc::createLowerStatePass() {
843  return std::make_unique<LowerStatePass>();
844 }
assert(baseType &&"element must be base type")
static bool shouldMaterialize(Operation *op)
Definition: LowerState.cpp:135
Builder builder
def create(data_type, value)
Definition: hw.py:393
std::unique_ptr< mlir::Pass > createLowerStatePass()
Definition: LowerState.cpp:842
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: hw.py:1