CIRCT  18.0.0git
HoistPassthrough.cpp
Go to the documentation of this file.
1 //===- HoistPassthrough.cpp - Hoist basic passthrough ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the HoistPassthrough pass. This pass identifies basic
10 // drivers of output ports that can be pulled out of modules.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetails.h"
22 #include "mlir/IR/BuiltinAttributes.h"
23 #include "mlir/IR/ImplicitLocOpBuilder.h"
24 #include "llvm/ADT/PostOrderIterator.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
// Debug type name; LLVM_DEBUG output in this file is enabled with
// -debug-only=firrtl-hoist-passthrough.
#define DEBUG_TYPE "firrtl-hoist-passthrough"

using namespace circt;
using namespace firrtl;

/// A Value statically typed as a FIRRTL RefType (the probe values handled by
/// RefDriver below).
using RefValue = mlir::TypedValue<RefType>;
34 
35 namespace {
36 
37 struct RefDriver;
38 struct HWDriver;
39 
40 //===----------------------------------------------------------------------===//
41 // (Rematerializable)Driver declaration.
42 //===----------------------------------------------------------------------===//
43 /// Statically known driver for a Value.
44 ///
45 /// Driver source expected to be rematerialized provided a mapping.
46 /// Generally takes form:
47 /// [source]----(static indexing?)---->DRIVE_OP---->[dest]
48 ///
49 /// However, only requirement is that the "driver" can be rematerialized
50 /// across a module/instance boundary in terms of mapping args<-->results.
51 ///
52 /// Driver can be reconstructed given a mapping in new location.
53 ///
54 /// "Update":
55 /// Map:
56 /// source -> A
57 /// dest -> B
58 ///
59 /// [source]---(indexing)--> SSA_DRIVE_OP ---> [dest]
60 /// + ([s']---> SSA_DRIVE_OP ---> [A])
61 /// =>
62 /// RAUW(B, [A]--(clone indexing))
63 /// (or RAUW(B, [s']--(clone indexing)))
64 ///
65 /// Update is safe if driver classification is ""equivalent"" for each context
66 /// on the other side. For hoisting U-Turns, this is safe in all cases,
67 /// for sinking n-turns the driver must be map-equivalent at all instantiation
68 /// sites.
69 /// Only UTurns are supported presently.
70 ///
71 /// The goal is to drop the destination port, so after replacing all users
72 /// on other side of the instantiation, drop the port driver and move
73 /// all its users to the driver (immediate) source.
74 /// This may not be safe if the driver source does not dominate all users of the
75 /// port, in which case either reject (unsafe) or insert a temporary wire to
76 /// drive instead.
77 ///
78 /// RAUW'ing may require insertion of conversion ops if types don't match.
79 //===----------------------------------------------------------------------===//
80 struct Driver {
81  //-- Data -----------------------------------------------------------------//
82 
83  /// Connect entirely and definitively driving the destination.
84  FConnectLike drivingConnect;
85  /// Source of LHS.
86  FieldRef source;
87 
88  //-- Constructors ---------------------------------------------------------//
89  Driver() = default;
90  Driver(FConnectLike connect, FieldRef source)
91  : drivingConnect(connect), source(source) {
92  assert((isa<RefDriver, HWDriver>(*this)));
93  }
94 
95  //-- Driver methods -------------------------------------------------------//
96 
97  // "Virtual" methods, either commonly defined or dispatched appropriately.
98 
99  /// Determine direct driver for the given value, empty Driver otherwise.
100  static Driver get(Value v, FieldRefCache &refs);
101 
102  /// Whether this can be rematerialized up through an instantiation.
103  bool canHoist() const { return isa<BlockArgument>(source.getValue()); }
104 
105  /// Simple mapping across instantiation by index.
106  using PortMappingFn = llvm::function_ref<Value(size_t)>;
107 
108  /// Rematerialize this driven value, using provided mapping function and
109  /// builder. New value is returned.
110  Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder);
111 
112  /// Drop uses of the destination, inserting temporary as necessary.
113  /// Erases the driving connection, invalidating this Driver.
114  void finalize(ImplicitLocOpBuilder &builder);
115 
116  //--- Helper methods -------------------------------------------------------//
117 
118  /// Return whether this driver is valid/non-null.
119  operator bool() const { return source; }
120 
121  /// Get driven destination value.
122  Value getDest() const {
123  // (const cast to workaround getDest() not being const, even if mutates the
124  // Operation* that's fine)
125  return const_cast<Driver *>(this)->drivingConnect.getDest();
126  }
127 
128  /// Whether this driver destination is a module port.
129  bool drivesModuleArg() const {
130  auto arg = dyn_cast<BlockArgument>(getDest());
131  assert(!arg || isa<firrtl::FModuleLike>(arg.getOwner()->getParentOp()));
132  return !!arg;
133  }
134 
135  /// Whether this driver destination is an instance result.
136  bool drivesInstanceResult() const {
137  return getDest().getDefiningOp<hw::HWInstanceLike>();
138  }
139 
140  /// Get destination as block argument.
141  BlockArgument getDestBlockArg() const {
142  assert(drivesModuleArg());
143  return dyn_cast<BlockArgument>(getDest());
144  }
145 
146  /// Get destination as operation result, must be instance result.
147  OpResult getDestOpResult() const {
148  assert(drivesInstanceResult());
149  return dyn_cast<OpResult>(getDest());
150  }
151 
152  /// Helper to obtain argument/result number of destination.
153  /// Must be block arg or op result.
154  static size_t getIndex(Value v) {
155  if (auto arg = dyn_cast<BlockArgument>(v))
156  return arg.getArgNumber();
157  auto result = dyn_cast<OpResult>(v);
158  assert(result);
159  return result.getResultNumber();
160  }
161 };
162 
163 /// Driver implementation for probes.
164 struct RefDriver : public Driver {
165  using Driver::Driver;
166 
167  static bool classof(const Driver *t) { return isa<RefValue>(t->getDest()); }
168 
169  static RefDriver get(Value v, FieldRefCache &refs);
170 
171  Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder);
172 };
173 static_assert(sizeof(RefDriver) == sizeof(Driver),
174  "passed by value, no slicing");
175 
176 // Driver implementation for HW signals.
177 // Split out because has more complexity re:safety + updating.
178 // And can't walk through temporaries in same way.
179 struct HWDriver : public Driver {
180  using Driver::Driver;
181 
182  static bool classof(const Driver *t) { return !isa<RefValue>(t->getDest()); }
183 
184  static HWDriver get(Value v, FieldRefCache &refs);
185 
186  Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder);
187 };
188 static_assert(sizeof(HWDriver) == sizeof(Driver),
189  "passed by value, no slicing");
190 
191 /// Print driver information.
192 template <typename T>
193 static inline T &operator<<(T &os, Driver &d) {
194  if (!d)
195  return os << "(null)";
196  return os << d.getDest() << " <-- " << d.drivingConnect << " <-- "
197  << d.source.getValue() << "@" << d.source.getFieldID();
198 }
199 
200 } // end anonymous namespace
201 
202 //===----------------------------------------------------------------------===//
203 // Driver implementation.
204 //===----------------------------------------------------------------------===//
205 
206 Driver Driver::get(Value v, FieldRefCache &refs) {
207  if (auto refDriver = RefDriver::get(v, refs))
208  return refDriver;
209  if (auto hwDriver = HWDriver::get(v, refs))
210  return hwDriver;
211  return {};
212 }
213 
214 Value Driver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder) {
215  return TypeSwitch<Driver *, Value>(this)
216  .Case<RefDriver, HWDriver>(
217  [&](auto *d) { return d->remat(mapPortFn, builder); })
218  .Default({});
219 }
220 
221 void Driver::finalize(ImplicitLocOpBuilder &builder) {
222  auto immSource = drivingConnect.getSrc();
223  auto dest = getDest();
224  assert(immSource.getType() == dest.getType() &&
225  "final connect must be strict");
226  if (dest.hasOneUse()) {
227  // Only use is the connect, just drop it.
228  drivingConnect.erase();
229  } else if (isa<BlockArgument>(immSource)) {
230  // Block argument dominates all, so drop connect and RAUW to it.
231  drivingConnect.erase();
232  dest.replaceAllUsesWith(immSource);
233  } else {
234  // Insert wire temporary.
235  // For hoisting use-case could also remat using cached indexing inside the
236  // module, but wires keep this simple.
237  auto temp = builder.create<WireOp>(immSource.getType());
238  dest.replaceAllUsesWith(temp.getDataRaw());
239  }
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // RefDriver implementation.
244 //===----------------------------------------------------------------------===//
245 
246 static RefDefineOp getRefDefine(Value result) {
247  for (auto *user : result.getUsers()) {
248  if (auto rd = dyn_cast<RefDefineOp>(user); rd && rd.getDest() == result)
249  return rd;
250  }
251  return {};
252 }
253 
254 RefDriver RefDriver::get(Value v, FieldRefCache &refs) {
255  auto refVal = dyn_cast<RefValue>(v);
256  if (!refVal)
257  return {};
258 
259  auto rd = getRefDefine(v);
260  if (!rd)
261  return {};
262 
263  auto ref = refs.getFieldRefFromValue(rd.getSrc(), true);
264  if (!ref)
265  return {};
266 
267  return RefDriver(rd, ref);
268 }
269 
270 Value RefDriver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder) {
271  auto mappedSource = mapPortFn(getIndex(source.getValue()));
272  auto newVal = getValueByFieldID(builder, mappedSource, source.getFieldID());
273  auto destType = getDest().getType();
274  if (newVal.getType() != destType)
275  newVal = builder.create<RefCastOp>(destType, newVal);
276  return newVal;
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // HWDriver implementation.
281 //===----------------------------------------------------------------------===//
282 
283 static bool hasDontTouchOrInnerSymOnResult(Operation *op) {
285  return true;
286  auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op);
287  return symOp && symOp.getTargetResultIndex() && symOp.getInnerSymAttr();
288 }
289 
291  if (auto *op = value.getDefiningOp())
293  auto arg = dyn_cast<BlockArgument>(value);
294  auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
295  return (module.getPortSymbolAttr(arg.getArgNumber())) ||
296  AnnotationSet::forPort(module, arg.getArgNumber()).hasDontTouch();
297 }
298 
299 HWDriver HWDriver::get(Value v, FieldRefCache &refs) {
300  auto baseValue = dyn_cast<FIRRTLBaseValue>(v);
301  if (!baseValue)
302  return {};
303 
304  // Output must be passive, for flow reasons.
305  // Reject aggregates for now, to be conservative re:aliasing writes/etc.
306  // before ExpandWhens.
307  if (!baseValue.getType().isPassive() || !baseValue.getType().isGround())
308  return {};
309 
311  if (!connect)
312  return {};
313 
314  auto ref = refs.getFieldRefFromValue(connect.getSrc());
315  if (!ref)
316  return {};
317 
318  // Reject if not all same block.
319  if (v.getParentBlock() != ref.getValue().getParentBlock() ||
320  v.getParentBlock() != connect->getBlock())
321  return {};
322 
323  // Reject if cannot reason through this.
324  // Use local "hasDontTouch" to distinguish inner symbols on results
325  // vs on the operation itself (like an instance).
327  hasDontTouchOrInnerSymOnResult(ref.getValue()))
328  return {};
329  if (auto fop = ref.getValue().getDefiningOp<Forceable>();
330  fop && fop.isForceable())
331  return {};
332 
333  // Limit to passive sources for now.
334  auto sourceType = type_dyn_cast<FIRRTLBaseType>(ref.getValue().getType());
335  if (!sourceType)
336  return {};
337  if (!sourceType.isPassive())
338  return {};
339 
340  assert(hw::FieldIdImpl::getFinalTypeByFieldID(sourceType, ref.getFieldID()) ==
341  baseValue.getType() &&
342  "unexpected type mismatch, cast or extension?");
343 
344  return HWDriver(connect, ref);
345 }
346 
347 Value HWDriver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &builder) {
348  auto mappedSource = mapPortFn(getIndex(source.getValue()));
349  // TODO: Cast if needed. For now only support matching.
350  // (No cast needed for current HWDriver's, getFieldRefFromValue and
351  // assert)
352  return getValueByFieldID(builder, mappedSource, source.getFieldID());
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // MustDrivenBy analysis.
357 //===----------------------------------------------------------------------===//
358 namespace {
359 /// Driver analysis, tracking values that "must be driven" by the specified
360 /// source (+fieldID), along with final complete driving connect.
361 class MustDrivenBy {
362 public:
363  MustDrivenBy() = default;
364  MustDrivenBy(FModuleOp mod) { run(mod); }
365 
366  /// Get direct driver, if computed, for the specified value.
367  Driver getDriverFor(Value v) const { return driverMap.lookup(v); }
368 
369  /// Get combined driver for the specified value.
370  /// Walks the driver "graph" from the value to its ultimate source.
371  Driver getCombinedDriverFor(Value v) const {
372  Driver driver = driverMap.lookup(v);
373  if (!driver)
374  return driver;
375 
376  // Chase and collapse.
377  Driver cur = driver;
378  size_t len = 1;
379  SmallPtrSet<Value, 8> seen;
380  while ((cur = driverMap.lookup(cur.source.getValue()))) {
381  // If re-encounter same value, bail.
382  if (!seen.insert(cur.source.getValue()).second)
383  return {};
384  driver.source = cur.source.getSubField(driver.source.getFieldID());
385  ++len;
386  }
387  (void)len;
388  LLVM_DEBUG(llvm::dbgs() << "Found driver for " << v << " (chain length = "
389  << len << "): " << driver << "\n");
390  return driver;
391  }
392 
393  /// Analyze the given module's ports and chase simple storage.
394  void run(FModuleOp mod) {
395  SmallVector<Value, 64> worklist(mod.getArguments());
396 
397  DenseSet<Value> enqueued;
398  enqueued.insert(worklist.begin(), worklist.end());
399  FieldRefCache refs;
400  while (!worklist.empty()) {
401  auto val = worklist.pop_back_val();
402  auto driver =
403  ignoreHWDrivers ? RefDriver::get(val, refs) : Driver::get(val, refs);
404  driverMap.insert({val, driver});
405  if (!driver)
406  continue;
407 
408  auto sourceVal = driver.source.getValue();
409 
410  // If already enqueued, ignore.
411  if (!enqueued.insert(sourceVal).second)
412  continue;
413 
414  // Only chase through atomic values for now.
415  // Here, atomic implies must be driven entirely.
416  // This is true for HW types, and is true for RefType's because
417  // while they can be indexed into, only RHS can have indexing.
418  if (hw::FieldIdImpl::getMaxFieldID(sourceVal.getType()) != 0)
419  continue;
420 
421  // Only through Wires, block arguments, instance results.
422  if (!isa<BlockArgument>(sourceVal) &&
423  !isa_and_nonnull<WireOp, InstanceOp>(sourceVal.getDefiningOp()))
424  continue;
425 
426  worklist.push_back(sourceVal);
427  }
428 
429  refs.verify();
430 
431  LLVM_DEBUG({
432  llvm::dbgs() << "Analyzed " << mod.getModuleName() << " and found "
433  << driverMap.size() << " drivers.\n";
434  refs.printStats(llvm::dbgs());
435  });
436  }
437 
438  /// Clear out analysis results and storage.
439  void clear() { driverMap.clear(); }
440 
441  /// Configure whether HW signals are analyzed.
442  void setIgnoreHWDrivers(bool ignore) { ignoreHWDrivers = ignore; }
443 
444 private:
445  /// Map of values to their computed direct must-drive source.
446  DenseMap<Value, Driver> driverMap;
447  bool ignoreHWDrivers = false;
448 };
449 
450 } // end anonymous namespace
451 
452 //===----------------------------------------------------------------------===//
453 // Pass Infrastructure
454 //===----------------------------------------------------------------------===//
455 
456 namespace {
457 
458 struct HoistPassthroughPass
459  : public HoistPassthroughBase<HoistPassthroughPass> {
460  using HoistPassthroughBase::HoistPassthroughBase;
461  void runOnOperation() override;
462 
463  using HoistPassthroughBase::hoistHWDrivers;
464 };
465 } // end anonymous namespace
466 
467 void HoistPassthroughPass::runOnOperation() {
468  LLVM_DEBUG(llvm::dbgs() << "===- Running HoistPassthrough Pass "
469  "------------------------------------------===\n");
470  auto &instanceGraph = getAnalysis<InstanceGraph>();
471 
472  SmallVector<FModuleOp, 0> modules(llvm::make_filter_range(
473  llvm::map_range(
474  llvm::post_order(&instanceGraph),
475  [](auto *node) { return dyn_cast<FModuleOp>(*node->getModule()); }),
476  [](auto module) { return module; }));
477 
478  MustDrivenBy driverAnalysis;
479  driverAnalysis.setIgnoreHWDrivers(!hoistHWDrivers);
480 
481  bool anyChanged = false;
482 
483  // For each module (PO)...
484  for (auto module : modules) {
485  // TODO: Public means can't reason down into, or remove ports.
486  // Does not mean cannot clone out wires or optimize w.r.t its contents.
487  if (module.isPublic())
488  continue;
489 
490  // 1. Analyze.
491 
492  // What ports to delete.
493  // Hoisted drivers of output ports will be deleted.
494  BitVector deadPorts(module.getNumPorts());
495 
496  // Instance graph node, for walking instances of this module.
497  auto *igNode = instanceGraph.lookup(module);
498 
499  // Analyze all ports using current IR.
500  driverAnalysis.clear();
501  driverAnalysis.run(module);
502  auto notNullAndCanHoist = [](const Driver &d) -> bool {
503  return d && d.canHoist();
504  };
505  SmallVector<Driver, 16> drivers(llvm::make_filter_range(
506  llvm::map_range(module.getArguments(),
507  [&driverAnalysis](auto val) {
508  return driverAnalysis.getCombinedDriverFor(val);
509  }),
510  notNullAndCanHoist));
511 
512  // If no hoistable drivers found, nothing to do. Onwards!
513  if (drivers.empty())
514  continue;
515 
516  anyChanged = true;
517 
518  // 2. Rematerialize must-driven ports at instantiation sites.
519 
520  // Do this first, keep alive Driver state pointing to module.
521  for (auto &driver : drivers) {
522  std::optional<size_t> deadPort;
523  {
524  auto destArg = driver.getDestBlockArg();
525  auto index = destArg.getArgNumber();
526 
527  // Replace dest in all instantiations.
528  for (auto *record : igNode->uses()) {
529  auto inst = cast<InstanceOp>(record->getInstance());
530  ImplicitLocOpBuilder builder(inst.getLoc(), inst);
531  builder.setInsertionPointAfter(inst);
532 
533  auto mappedDest = inst.getResult(index);
534  mappedDest.replaceAllUsesWith(driver.remat(
535  [&inst](size_t index) { return inst.getResult(index); },
536  builder));
537  }
538  // The driven port has no external users, will soon be dead.
539  deadPort = index;
540  }
541  assert(deadPort.has_value());
542 
543  assert(!deadPorts.test(*deadPort));
544  deadPorts.set(*deadPort);
545 
546  // Update statistics.
547  TypeSwitch<Driver *, void>(&driver)
548  .Case<RefDriver>([&](auto *) { ++numRefDrivers; })
549  .Case<HWDriver>([&](auto *) { ++numHWDrivers; });
550  }
551 
552  // 3. Finalize stage. Ensure remat'd dest is unused on original side.
553 
554  ImplicitLocOpBuilder builder(module.getLoc(), module.getBody());
555  for (auto &driver : drivers) {
556  // Finalize. Invalidates the driver.
557  builder.setLoc(driver.getDest().getLoc());
558  driver.finalize(builder);
559  }
560 
561  // 4. Delete newly dead ports.
562 
563  // Drop dead ports at instantiation sites.
564  for (auto *record : llvm::make_early_inc_range(igNode->uses())) {
565  auto inst = cast<InstanceOp>(record->getInstance());
566  ImplicitLocOpBuilder builder(inst.getLoc(), inst);
567 
568  assert(inst.getNumResults() == deadPorts.size());
569  auto newInst = inst.erasePorts(builder, deadPorts);
570  instanceGraph.replaceInstance(inst, newInst);
571  inst.erase();
572  }
573 
574  // Drop dead ports from module.
575  module.erasePorts(deadPorts);
576 
577  numUTurnsHoisted += deadPorts.count();
578  }
579  markAnalysesPreserved<InstanceGraph>();
580 
581  if (!anyChanged)
582  markAllAnalysesPreserved();
583 }
584 
585 /// This is the pass constructor.
586 std::unique_ptr<mlir::Pass>
588  auto pass = std::make_unique<HoistPassthroughPass>();
589  pass->hoistHWDrivers = hoistHWDrivers;
590  return pass;
591 }
lowerAnnotationsNoRefTypePorts FirtoolPreserveValuesMode value
Definition: Firtool.cpp:95
assert(baseType &&"element must be base type")
static RefDefineOp getRefDefine(Value result)
mlir::TypedValue< RefType > RefValue
static bool hasDontTouchOrInnerSymOnResult(Operation *op)
Builder builder
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
Value getValue() const
Get the Value which created this location.
Definition: FieldRef.h:37
bool hasDontTouch() const
firrtl.transforms.DontTouchAnnotation
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
Caching version of getFieldRefFromValue.
Definition: FieldRefCache.h:26
void printStats(llvm::raw_ostream &os) const
void verify() const
Verify cached fieldRefs against firrtl::getFieldRefFromValue.
Definition: FieldRefCache.h:51
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Caching version of getFieldRefFromValue.
def connect(destination, source)
Definition: support.py:37
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
T & operator<<(T &os, FIRVersion version)
Definition: FIRParser.h:115
std::unique_ptr< mlir::Pass > createHoistPassthroughPass(bool hoistHWDrivers=true)
This is the pass constructor.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
StrictConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
uint64_t getMaxFieldID(Type)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28