CIRCT  19.0.0git
HoistPassthrough.cpp
Go to the documentation of this file.
1 //===- HoistPassthrough.cpp - Hoist basic passthrough ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the HoistPassthrough pass. This pass identifies basic
10 // drivers of output ports that can be pulled out of modules.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetails.h"
15 
23 #include "circt/Support/Debug.h"
24 #include "mlir/IR/BuiltinAttributes.h"
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "llvm/ADT/PostOrderIterator.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/FormatVariadic.h"
29 
30 #define DEBUG_TYPE "firrtl-hoist-passthrough"
31 
32 using namespace circt;
33 using namespace firrtl;
34 
35 using RefValue = mlir::TypedValue<RefType>;
36 
37 namespace {
38 
39 struct RefDriver;
40 struct HWDriver;
41 
42 //===----------------------------------------------------------------------===//
43 // (Rematerializable)Driver declaration.
44 //===----------------------------------------------------------------------===//
45 /// Statically known driver for a Value.
46 ///
47 /// Driver source expected to be rematerialized provided a mapping.
48 /// Generally takes form:
49 /// [source]----(static indexing?)---->DRIVE_OP---->[dest]
50 ///
51 /// However, only requirement is that the "driver" can be rematerialized
52 /// across a module/instance boundary in terms of mapping args<-->results.
53 ///
54 /// Driver can be reconstructed given a mapping in new location.
55 ///
56 /// "Update":
57 /// Map:
58 /// source -> A
59 /// dest -> B
60 ///
61 /// [source]---(indexing)--> SSA_DRIVE_OP ---> [dest]
62 /// + ([s']---> SSA_DRIVE_OP ---> [A])
63 /// =>
64 /// RAUW(B, [A]--(clone indexing))
65 /// (or RAUW(B, [s']--(clone indexing)))
66 ///
67 /// Update is safe if driver classification is ""equivalent"" for each context
68 /// on the other side. For hoisting U-Turns, this is safe in all cases,
69 /// for sinking n-turns the driver must be map-equivalent at all instantiation
70 /// sites.
71 /// Only UTurns are supported presently.
72 ///
73 /// The goal is to drop the destination port, so after replacing all users
74 /// on other side of the instantiation, drop the port driver and move
75 /// all its users to the driver (immediate) source.
76 /// This may not be safe if the driver source does not dominate all users of the
77 /// port, in which case either reject (unsafe) or insert a temporary wire to
78 /// drive instead.
79 ///
80 /// RAUW'ing may require insertion of conversion ops if types don't match.
81 //===----------------------------------------------------------------------===//
82 struct Driver {
83  //-- Data -----------------------------------------------------------------//
84 
85  /// Connect entirely and definitively driving the destination.
86  FConnectLike drivingConnect;
87  /// Source of LHS.
88  FieldRef source;
89 
90  //-- Constructors ---------------------------------------------------------//
91  Driver() = default;
92  Driver(FConnectLike connect, FieldRef source)
93  : drivingConnect(connect), source(source) {
94  assert((isa<RefDriver, HWDriver>(*this)));
95  }
96 
97  //-- Driver methods -------------------------------------------------------//
98 
99  // "Virtual" methods, either commonly defined or dispatched appropriately.
100 
101  /// Determine direct driver for the given value, empty Driver otherwise.
102  static Driver get(Value v, FieldRefCache &refs);
103 
104  /// Whether this can be rematerialized up through an instantiation.
105  bool canHoist() const { return isa<BlockArgument>(source.getValue()); }
106 
107  /// Simple mapping across instantiation by index.
108  using PortMappingFn = llvm::function_ref<Value(size_t)>;
109 
110  /// Rematerialize this driven value, using provided mapping function and
111  /// builder. New value is returned.
112  Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder);
113 
114  /// Drop uses of the destination, inserting temporary as necessary.
115  /// Erases the driving connection, invalidating this Driver.
116  void finalize(ImplicitLocOpBuilder &builder);
117 
118  //--- Helper methods -------------------------------------------------------//
119 
120  /// Return whether this driver is valid/non-null.
121  operator bool() const { return source; }
122 
123  /// Get driven destination value.
124  Value getDest() const {
125  // (const cast to workaround getDest() not being const, even if mutates the
126  // Operation* that's fine)
127  return const_cast<Driver *>(this)->drivingConnect.getDest();
128  }
129 
130  /// Whether this driver destination is a module port.
131  bool drivesModuleArg() const {
132  auto arg = dyn_cast<BlockArgument>(getDest());
133  assert(!arg || isa<firrtl::FModuleLike>(arg.getOwner()->getParentOp()));
134  return !!arg;
135  }
136 
137  /// Whether this driver destination is an instance result.
138  bool drivesInstanceResult() const {
139  return getDest().getDefiningOp<hw::HWInstanceLike>();
140  }
141 
142  /// Get destination as block argument.
143  BlockArgument getDestBlockArg() const {
144  assert(drivesModuleArg());
145  return dyn_cast<BlockArgument>(getDest());
146  }
147 
148  /// Get destination as operation result, must be instance result.
149  OpResult getDestOpResult() const {
150  assert(drivesInstanceResult());
151  return dyn_cast<OpResult>(getDest());
152  }
153 
154  /// Helper to obtain argument/result number of destination.
155  /// Must be block arg or op result.
156  static size_t getIndex(Value v) {
157  if (auto arg = dyn_cast<BlockArgument>(v))
158  return arg.getArgNumber();
159  auto result = dyn_cast<OpResult>(v);
160  assert(result);
161  return result.getResultNumber();
162  }
163 };
164 
165 /// Driver implementation for probes.
166 struct RefDriver : public Driver {
167  using Driver::Driver;
168 
169  static bool classof(const Driver *t) { return isa<RefValue>(t->getDest()); }
170 
171  static RefDriver get(Value v, FieldRefCache &refs);
172 
173  Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder);
174 };
175 static_assert(sizeof(RefDriver) == sizeof(Driver),
176  "passed by value, no slicing");
177 
178 // Driver implementation for HW signals.
179 // Split out because has more complexity re:safety + updating.
180 // And can't walk through temporaries in same way.
181 struct HWDriver : public Driver {
182  using Driver::Driver;
183 
184  static bool classof(const Driver *t) { return !isa<RefValue>(t->getDest()); }
185 
186  static HWDriver get(Value v, FieldRefCache &refs);
187 
188  Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder);
189 };
190 static_assert(sizeof(HWDriver) == sizeof(Driver),
191  "passed by value, no slicing");
192 
193 /// Print driver information.
194 template <typename T>
195 static inline T &operator<<(T &os, Driver &d) {
196  if (!d)
197  return os << "(null)";
198  return os << d.getDest() << " <-- " << d.drivingConnect << " <-- "
199  << d.source.getValue() << "@" << d.source.getFieldID();
200 }
201 
202 } // end anonymous namespace
203 
204 //===----------------------------------------------------------------------===//
205 // Driver implementation.
206 //===----------------------------------------------------------------------===//
207 
208 Driver Driver::get(Value v, FieldRefCache &refs) {
209  if (auto refDriver = RefDriver::get(v, refs))
210  return refDriver;
211  if (auto hwDriver = HWDriver::get(v, refs))
212  return hwDriver;
213  return {};
214 }
215 
216 Value Driver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder) {
217  return TypeSwitch<Driver *, Value>(this)
218  .Case<RefDriver, HWDriver>(
219  [&](auto *d) { return d->remat(mapPortFn, builder); })
220  .Default({});
221 }
222 
223 void Driver::finalize(ImplicitLocOpBuilder &builder) {
224  auto immSource = drivingConnect.getSrc();
225  auto dest = getDest();
226  assert(immSource.getType() == dest.getType() &&
227  "final connect must be strict");
228  if (dest.hasOneUse()) {
229  // Only use is the connect, just drop it.
230  drivingConnect.erase();
231  } else if (isa<BlockArgument>(immSource)) {
232  // Block argument dominates all, so drop connect and RAUW to it.
233  drivingConnect.erase();
234  dest.replaceAllUsesWith(immSource);
235  } else {
236  // Insert wire temporary.
237  // For hoisting use-case could also remat using cached indexing inside the
238  // module, but wires keep this simple.
239  auto temp = builder.create<WireOp>(immSource.getType());
240  dest.replaceAllUsesWith(temp.getDataRaw());
241  }
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // RefDriver implementation.
246 //===----------------------------------------------------------------------===//
247 
248 static RefDefineOp getRefDefine(Value result) {
249  for (auto *user : result.getUsers()) {
250  if (auto rd = dyn_cast<RefDefineOp>(user); rd && rd.getDest() == result)
251  return rd;
252  }
253  return {};
254 }
255 
256 RefDriver RefDriver::get(Value v, FieldRefCache &refs) {
257  auto refVal = dyn_cast<RefValue>(v);
258  if (!refVal)
259  return {};
260 
261  auto rd = getRefDefine(v);
262  if (!rd)
263  return {};
264 
265  auto ref = refs.getFieldRefFromValue(rd.getSrc(), true);
266  if (!ref)
267  return {};
268 
269  return RefDriver(rd, ref);
270 }
271 
272 Value RefDriver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder) {
273  auto mappedSource = mapPortFn(getIndex(source.getValue()));
274  auto newVal = getValueByFieldID(builder, mappedSource, source.getFieldID());
275  auto destType = getDest().getType();
276  if (newVal.getType() != destType)
277  newVal = builder.create<RefCastOp>(destType, newVal);
278  return newVal;
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // HWDriver implementation.
283 //===----------------------------------------------------------------------===//
284 
285 static bool hasDontTouchOrInnerSymOnResult(Operation *op) {
287  return true;
288  auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op);
289  return symOp && symOp.getTargetResultIndex() && symOp.getInnerSymAttr();
290 }
291 
292 static bool hasDontTouchOrInnerSymOnResult(Value value) {
293  if (auto *op = value.getDefiningOp())
295  auto arg = dyn_cast<BlockArgument>(value);
296  auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
297  return (module.getPortSymbolAttr(arg.getArgNumber())) ||
298  AnnotationSet::forPort(module, arg.getArgNumber()).hasDontTouch();
299 }
300 
301 HWDriver HWDriver::get(Value v, FieldRefCache &refs) {
302  auto baseValue = dyn_cast<FIRRTLBaseValue>(v);
303  if (!baseValue)
304  return {};
305 
306  // Output must be passive, for flow reasons.
307  // Reject aggregates for now, to be conservative re:aliasing writes/etc.
308  // before ExpandWhens.
309  if (!baseValue.getType().isPassive() || !baseValue.getType().isGround())
310  return {};
311 
313  if (!connect)
314  return {};
315 
316  auto ref = refs.getFieldRefFromValue(connect.getSrc());
317  if (!ref)
318  return {};
319 
320  // Reject if not all same block.
321  if (v.getParentBlock() != ref.getValue().getParentBlock() ||
322  v.getParentBlock() != connect->getBlock())
323  return {};
324 
325  // Reject if cannot reason through this.
326  // Use local "hasDontTouch" to distinguish inner symbols on results
327  // vs on the operation itself (like an instance).
329  hasDontTouchOrInnerSymOnResult(ref.getValue()))
330  return {};
331  if (auto fop = ref.getValue().getDefiningOp<Forceable>();
332  fop && fop.isForceable())
333  return {};
334 
335  // Limit to passive sources for now.
336  auto sourceType = type_dyn_cast<FIRRTLBaseType>(ref.getValue().getType());
337  if (!sourceType)
338  return {};
339  if (!sourceType.isPassive())
340  return {};
341 
342  assert(hw::FieldIdImpl::getFinalTypeByFieldID(sourceType, ref.getFieldID()) ==
343  baseValue.getType() &&
344  "unexpected type mismatch, cast or extension?");
345 
346  return HWDriver(connect, ref);
347 }
348 
349 Value HWDriver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder) {
350  auto mappedSource = mapPortFn(getIndex(source.getValue()));
351  // TODO: Cast if needed. For now only support matching.
352  // (No cast needed for current HWDriver's, getFieldRefFromValue and
353  // assert)
354  return getValueByFieldID(builder, mappedSource, source.getFieldID());
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // MustDrivenBy analysis.
359 //===----------------------------------------------------------------------===//
360 namespace {
361 /// Driver analysis, tracking values that "must be driven" by the specified
362 /// source (+fieldID), along with final complete driving connect.
363 class MustDrivenBy {
364 public:
365  MustDrivenBy() = default;
366  MustDrivenBy(FModuleOp mod) { run(mod); }
367 
368  /// Get direct driver, if computed, for the specified value.
369  Driver getDriverFor(Value v) const { return driverMap.lookup(v); }
370 
371  /// Get combined driver for the specified value.
372  /// Walks the driver "graph" from the value to its ultimate source.
373  Driver getCombinedDriverFor(Value v) const {
374  Driver driver = driverMap.lookup(v);
375  if (!driver)
376  return driver;
377 
378  // Chase and collapse.
379  Driver cur = driver;
380  size_t len = 1;
381  SmallPtrSet<Value, 8> seen;
382  while ((cur = driverMap.lookup(cur.source.getValue()))) {
383  // If re-encounter same value, bail.
384  if (!seen.insert(cur.source.getValue()).second)
385  return {};
386  driver.source = cur.source.getSubField(driver.source.getFieldID());
387  ++len;
388  }
389  (void)len;
390  LLVM_DEBUG(llvm::dbgs() << "Found driver for " << v << " (chain length = "
391  << len << "): " << driver << "\n");
392  return driver;
393  }
394 
395  /// Analyze the given module's ports and chase simple storage.
396  void run(FModuleOp mod) {
397  SmallVector<Value, 64> worklist(mod.getArguments());
398 
399  DenseSet<Value> enqueued;
400  enqueued.insert(worklist.begin(), worklist.end());
401  FieldRefCache refs;
402  while (!worklist.empty()) {
403  auto val = worklist.pop_back_val();
404  auto driver =
405  ignoreHWDrivers ? RefDriver::get(val, refs) : Driver::get(val, refs);
406  driverMap.insert({val, driver});
407  if (!driver)
408  continue;
409 
410  auto sourceVal = driver.source.getValue();
411 
412  // If already enqueued, ignore.
413  if (!enqueued.insert(sourceVal).second)
414  continue;
415 
416  // Only chase through atomic values for now.
417  // Here, atomic implies must be driven entirely.
418  // This is true for HW types, and is true for RefType's because
419  // while they can be indexed into, only RHS can have indexing.
420  if (hw::FieldIdImpl::getMaxFieldID(sourceVal.getType()) != 0)
421  continue;
422 
423  // Only through Wires, block arguments, instance results.
424  if (!isa<BlockArgument>(sourceVal) &&
425  !isa_and_nonnull<WireOp, InstanceOp>(sourceVal.getDefiningOp()))
426  continue;
427 
428  worklist.push_back(sourceVal);
429  }
430 
431  refs.verify();
432 
433  LLVM_DEBUG({
434  llvm::dbgs() << "Analyzed " << mod.getModuleName() << " and found "
435  << driverMap.size() << " drivers.\n";
436  refs.printStats(llvm::dbgs());
437  });
438  }
439 
440  /// Clear out analysis results and storage.
441  void clear() { driverMap.clear(); }
442 
443  /// Configure whether HW signals are analyzed.
444  void setIgnoreHWDrivers(bool ignore) { ignoreHWDrivers = ignore; }
445 
446 private:
447  /// Map of values to their computed direct must-drive source.
448  DenseMap<Value, Driver> driverMap;
449  bool ignoreHWDrivers = false;
450 };
451 
452 } // end anonymous namespace
453 
454 //===----------------------------------------------------------------------===//
455 // Pass Infrastructure
456 //===----------------------------------------------------------------------===//
457 
458 namespace {
459 
460 struct HoistPassthroughPass
461  : public HoistPassthroughBase<HoistPassthroughPass> {
462  using HoistPassthroughBase::HoistPassthroughBase;
463  void runOnOperation() override;
464 
465  using HoistPassthroughBase::hoistHWDrivers;
466 };
467 } // end anonymous namespace
468 
469 void HoistPassthroughPass::runOnOperation() {
470  LLVM_DEBUG(debugPassHeader(this) << "\n");
471  auto &instanceGraph = getAnalysis<InstanceGraph>();
472 
473  SmallVector<FModuleOp, 0> modules(llvm::make_filter_range(
474  llvm::map_range(
475  llvm::post_order(&instanceGraph),
476  [](auto *node) { return dyn_cast<FModuleOp>(*node->getModule()); }),
477  [](auto module) { return module; }));
478 
479  MustDrivenBy driverAnalysis;
480  driverAnalysis.setIgnoreHWDrivers(!hoistHWDrivers);
481 
482  bool anyChanged = false;
483 
484  // For each module (PO)...
485  for (auto module : modules) {
486  // TODO: Public means can't reason down into, or remove ports.
487  // Does not mean cannot clone out wires or optimize w.r.t its contents.
488  if (module.isPublic())
489  continue;
490 
491  // 1. Analyze.
492 
493  // What ports to delete.
494  // Hoisted drivers of output ports will be deleted.
495  BitVector deadPorts(module.getNumPorts());
496 
497  // Instance graph node, for walking instances of this module.
498  auto *igNode = instanceGraph.lookup(module);
499 
500  // Analyze all ports using current IR.
501  driverAnalysis.clear();
502  driverAnalysis.run(module);
503  auto notNullAndCanHoist = [](const Driver &d) -> bool {
504  return d && d.canHoist();
505  };
506  SmallVector<Driver, 16> drivers(llvm::make_filter_range(
507  llvm::map_range(module.getArguments(),
508  [&driverAnalysis](auto val) {
509  return driverAnalysis.getCombinedDriverFor(val);
510  }),
511  notNullAndCanHoist));
512 
513  // If no hoistable drivers found, nothing to do. Onwards!
514  if (drivers.empty())
515  continue;
516 
517  anyChanged = true;
518 
519  // 2. Rematerialize must-driven ports at instantiation sites.
520 
521  // Do this first, keep alive Driver state pointing to module.
522  for (auto &driver : drivers) {
523  std::optional<size_t> deadPort;
524  {
525  auto destArg = driver.getDestBlockArg();
526  auto index = destArg.getArgNumber();
527 
528  // Replace dest in all instantiations.
529  for (auto *record : igNode->uses()) {
530  auto inst = cast<InstanceOp>(record->getInstance());
531  ImplicitLocOpBuilder builder(inst.getLoc(), inst);
532  builder.setInsertionPointAfter(inst);
533 
534  auto mappedDest = inst.getResult(index);
535  mappedDest.replaceAllUsesWith(driver.remat(
536  [&inst](size_t index) { return inst.getResult(index); },
537  builder));
538  }
539  // The driven port has no external users, will soon be dead.
540  deadPort = index;
541  }
542  assert(deadPort.has_value());
543 
544  assert(!deadPorts.test(*deadPort));
545  deadPorts.set(*deadPort);
546 
547  // Update statistics.
548  TypeSwitch<Driver *, void>(&driver)
549  .Case<RefDriver>([&](auto *) { ++numRefDrivers; })
550  .Case<HWDriver>([&](auto *) { ++numHWDrivers; });
551  }
552 
553  // 3. Finalize stage. Ensure remat'd dest is unused on original side.
554 
555  ImplicitLocOpBuilder builder(module.getLoc(), module.getBody());
556  for (auto &driver : drivers) {
557  // Finalize. Invalidates the driver.
558  builder.setLoc(driver.getDest().getLoc());
559  driver.finalize(builder);
560  }
561 
562  // 4. Delete newly dead ports.
563 
564  // Drop dead ports at instantiation sites.
565  for (auto *record : llvm::make_early_inc_range(igNode->uses())) {
566  auto inst = cast<InstanceOp>(record->getInstance());
567  ImplicitLocOpBuilder builder(inst.getLoc(), inst);
568 
569  assert(inst.getNumResults() == deadPorts.size());
570  auto newInst = inst.erasePorts(builder, deadPorts);
571  instanceGraph.replaceInstance(inst, newInst);
572  inst.erase();
573  }
574 
575  // Drop dead ports from module.
576  module.erasePorts(deadPorts);
577 
578  numUTurnsHoisted += deadPorts.count();
579  }
580  markAnalysesPreserved<InstanceGraph>();
581 
582  if (!anyChanged)
583  markAllAnalysesPreserved();
584 }
585 
586 /// This is the pass constructor.
587 std::unique_ptr<mlir::Pass>
589  auto pass = std::make_unique<HoistPassthroughPass>();
590  pass->hoistHWDrivers = hoistHWDrivers;
591  return pass;
592 }
assert(baseType &&"element must be base type")
static RefDefineOp getRefDefine(Value result)
mlir::TypedValue< RefType > RefValue
static bool hasDontTouchOrInnerSymOnResult(Operation *op)
Builder builder
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
Value getValue() const
Get the Value which created this location.
Definition: FieldRef.h:37
bool hasDontTouch() const
firrtl.transforms.DontTouchAnnotation
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
Caching version of getFieldRefFromValue.
Definition: FieldRefCache.h:26
void printStats(llvm::raw_ostream &os) const
void verify() const
Verify cached fieldRefs against firrtl::getFieldRefFromValue.
Definition: FieldRefCache.h:51
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Caching version of getFieldRefFromValue.
def connect(destination, source)
Definition: support.py:37
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
T & operator<<(T &os, FIRVersion version)
Definition: FIRParser.h:116
std::unique_ptr< mlir::Pass > createHoistPassthroughPass(bool hoistHWDrivers=true)
This is the pass constructor.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
StrictConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
uint64_t getMaxFieldID(Type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition: Debug.cpp:31