CIRCT  20.0.0git
IMConstProp.cpp
Go to the documentation of this file.
1 //===- IMConstProp.cpp - Intermodule ConstProp and DCE ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This implements SCCP:
10 // https://www.cs.wustl.edu/~cytron/531Pages/f11/Resources/Papers/cprop.pdf
11 //
12 //===----------------------------------------------------------------------===//
13 
21 #include "circt/Support/APInt.h"
22 #include "mlir/IR/Iterators.h"
23 #include "mlir/IR/Threading.h"
24 #include "mlir/Pass/Pass.h"
25 #include "llvm/ADT/APSInt.h"
26 #include "llvm/ADT/TinyPtrVector.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 
30 namespace circt {
31 namespace firrtl {
32 #define GEN_PASS_DEF_IMCONSTPROP
33 #include "circt/Dialect/FIRRTL/Passes.h.inc"
34 } // namespace firrtl
35 } // namespace circt
36 
37 using namespace circt;
38 using namespace firrtl;
39 
40 #define DEBUG_TYPE "IMCP"
41 
42 /// Return true if this is a wire or register.
43 static bool isWireOrReg(Operation *op) {
44  return isa<WireOp, RegResetOp, RegOp>(op);
45 }
46 
47 /// Return true if this is an aggregate indexer.
48 static bool isAggregate(Operation *op) {
49  return isa<SubindexOp, SubaccessOp, SubfieldOp, OpenSubfieldOp,
50  OpenSubindexOp, RefSubOp>(op);
51 }
52 
53 // Return true if this forwards input to output.
54 // Implies has appropriate visit method that propagates changes.
55 static bool isNodeLike(Operation *op) {
56  return isa<NodeOp, RefResolveOp, RefSendOp>(op);
57 }
58 
59 /// Return true if this is a wire or register we're allowed to delete.
60 static bool isDeletableWireOrRegOrNode(Operation *op) {
61  if (!isWireOrReg(op) && !isa<NodeOp>(op))
62  return false;
63 
64  // Always allow deleting wires of probe-type.
65  if (type_isa<RefType>(op->getResult(0).getType()))
66  return true;
67 
68  // Otherwise, don't delete if has anything keeping it around or unknown.
69  return AnnotationSet(op).canBeDeleted() && !hasDontTouch(op) &&
70  hasDroppableName(op) && !cast<Forceable>(op).isForceable();
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // Pass Infrastructure
75 //===----------------------------------------------------------------------===//
76 
77 namespace {
78 /// This class represents a single lattice value. A lattive value corresponds to
79 /// the various different states that a value in the SCCP dataflow analysis can
80 /// take. See 'Kind' below for more details on the different states a value can
81 /// take.
82 class LatticeValue {
83  enum Kind {
84  /// A value with a yet-to-be-determined value. This state may be changed to
85  /// anything, it hasn't been processed by IMConstProp.
86  Unknown,
87 
88  /// A value that is known to be a constant. This state may be changed to
89  /// overdefined.
90  Constant,
91 
92  /// A value that cannot statically be determined to be a constant. This
93  /// state cannot be changed.
94  Overdefined
95  };
96 
97 public:
98  /// Initialize a lattice value with "Unknown".
99  /*implicit*/ LatticeValue() : valueAndTag(nullptr, Kind::Unknown) {}
100  /// Initialize a lattice value with a constant.
101  /*implicit*/ LatticeValue(IntegerAttr attr)
102  : valueAndTag(attr, Kind::Constant) {}
103  /*implicit*/ LatticeValue(StringAttr attr)
104  : valueAndTag(attr, Kind::Constant) {}
105 
106  static LatticeValue getOverdefined() {
107  LatticeValue result;
108  result.markOverdefined();
109  return result;
110  }
111 
112  bool isUnknown() const { return valueAndTag.getInt() == Kind::Unknown; }
113  bool isConstant() const { return valueAndTag.getInt() == Kind::Constant; }
114  bool isOverdefined() const {
115  return valueAndTag.getInt() == Kind::Overdefined;
116  }
117 
118  /// Mark the lattice value as overdefined.
119  void markOverdefined() {
120  valueAndTag.setPointerAndInt(nullptr, Kind::Overdefined);
121  }
122 
123  /// Mark the lattice value as constant.
124  void markConstant(IntegerAttr value) {
125  valueAndTag.setPointerAndInt(value, Kind::Constant);
126  }
127 
128  /// If this lattice is constant or invalid value, return the attribute.
129  /// Returns nullptr otherwise.
130  Attribute getValue() const { return valueAndTag.getPointer(); }
131 
132  /// If this is in the constant state, return the attribute.
133  Attribute getConstant() const {
134  assert(isConstant());
135  return getValue();
136  }
137 
138  /// Merge in the value of the 'rhs' lattice into this one. Returns true if the
139  /// lattice value changed.
140  bool mergeIn(LatticeValue rhs) {
141  // If we are already overdefined, or rhs is unknown, there is nothing to do.
142  if (isOverdefined() || rhs.isUnknown())
143  return false;
144 
145  // If we are unknown, just take the value of rhs.
146  if (isUnknown()) {
147  valueAndTag = rhs.valueAndTag;
148  return true;
149  }
150 
151  // Otherwise, if this value doesn't match rhs go straight to overdefined.
152  // This happens when we merge "3" and "4" from two different instance sites
153  // for example.
154  if (valueAndTag != rhs.valueAndTag) {
155  markOverdefined();
156  return true;
157  }
158  return false;
159  }
160 
161  bool operator==(const LatticeValue &other) const {
162  return valueAndTag == other.valueAndTag;
163  }
164  bool operator!=(const LatticeValue &other) const {
165  return valueAndTag != other.valueAndTag;
166  }
167 
168 private:
169  /// The attribute value if this is a constant and the tag for the element
170  /// kind. The attribute is an IntegerAttr (or BoolAttr) or StringAttr.
171  llvm::PointerIntPair<Attribute, 2, Kind> valueAndTag;
172 };
173 } // end anonymous namespace
174 
175 LLVM_ATTRIBUTE_USED
176 static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
177  const LatticeValue &lattice) {
178  if (lattice.isUnknown())
179  return os << "<Unknown>";
180  if (lattice.isOverdefined())
181  return os << "<Overdefined>";
182  return os << "<" << lattice.getConstant() << ">";
183 }
184 
185 namespace {
186 struct IMConstPropPass
187  : public circt::firrtl::impl::IMConstPropBase<IMConstPropPass> {
188 
189  void runOnOperation() override;
190  void rewriteModuleBody(FModuleOp module);
191 
192  /// Returns true if the given block is executable.
193  bool isBlockExecutable(Block *block) const {
194  return executableBlocks.count(block);
195  }
196 
197  bool isOverdefined(FieldRef value) const {
198  auto it = latticeValues.find(value);
199  return it != latticeValues.end() && it->second.isOverdefined();
200  }
201 
202  // Mark the given value as overdefined. If the value is an aggregate,
203  // we mark all ground elements as overdefined.
204  void markOverdefined(Value value) {
205  FieldRef fieldRef = getOrCacheFieldRefFromValue(value);
206  auto firrtlType = type_dyn_cast<FIRRTLType>(value.getType());
207  if (!firrtlType || type_isa<PropertyType>(firrtlType)) {
208  markOverdefined(fieldRef);
209  return;
210  }
211 
212  walkGroundTypes(firrtlType, [&](uint64_t fieldID, auto, auto) {
213  markOverdefined(fieldRef.getSubField(fieldID));
214  });
215  }
216 
217  /// Mark the given value as overdefined. This means that we cannot refine a
218  /// specific constant for this value.
219  void markOverdefined(FieldRef value) {
220  auto &entry = latticeValues[value];
221  if (!entry.isOverdefined()) {
222  LLVM_DEBUG({
223  logger.getOStream()
224  << "Setting overdefined : (" << getFieldName(value).first << ")\n";
225  });
226  entry.markOverdefined();
227  changedLatticeValueWorklist.push_back(value);
228  }
229  }
230 
231  /// Merge information from the 'from' lattice value into value. If it
232  /// changes, then users of the value are added to the worklist for
233  /// revisitation.
234  void mergeLatticeValue(FieldRef value, LatticeValue &valueEntry,
235  LatticeValue source) {
236  if (valueEntry.mergeIn(source)) {
237  LLVM_DEBUG({
238  logger.getOStream()
239  << "Changed to " << valueEntry << " : (" << value << ")\n";
240  });
241  changedLatticeValueWorklist.push_back(value);
242  }
243  }
244 
245  void mergeLatticeValue(FieldRef value, LatticeValue source) {
246  // Don't even do a map lookup if from has no info in it.
247  if (source.isUnknown())
248  return;
249  mergeLatticeValue(value, latticeValues[value], source);
250  }
251 
252  void mergeLatticeValue(FieldRef result, FieldRef from) {
253  // If 'from' hasn't been computed yet, then it is unknown, don't do
254  // anything.
255  auto it = latticeValues.find(from);
256  if (it == latticeValues.end())
257  return;
258  mergeLatticeValue(result, it->second);
259  }
260 
261  void mergeLatticeValue(Value result, Value from) {
262  FieldRef fieldRefFrom = getOrCacheFieldRefFromValue(from);
263  FieldRef fieldRefResult = getOrCacheFieldRefFromValue(result);
264  if (!type_isa<FIRRTLType>(result.getType()))
265  return mergeLatticeValue(fieldRefResult, fieldRefFrom);
266  // Special-handle PropertyType's, walkGroundType's doesn't support.
267  if (type_isa<PropertyType>(result.getType()))
268  return mergeLatticeValue(fieldRefResult, fieldRefFrom);
269  walkGroundTypes(type_cast<FIRRTLType>(result.getType()),
270  [&](uint64_t fieldID, auto, auto) {
271  mergeLatticeValue(fieldRefResult.getSubField(fieldID),
272  fieldRefFrom.getSubField(fieldID));
273  });
274  }
275 
276  /// setLatticeValue - This is used when a new LatticeValue is computed for
277  /// the result of the specified value that replaces any previous knowledge,
278  /// e.g. because a fold() function on an op returned a new thing. This should
279  /// not be used on operations that have multiple contributors to it, e.g.
280  /// wires or ports.
281  void setLatticeValue(FieldRef value, LatticeValue source) {
282  // Don't even do a map lookup if from has no info in it.
283  if (source.isUnknown())
284  return;
285 
286  // If we've changed this value then revisit all the users.
287  auto &valueEntry = latticeValues[value];
288  if (valueEntry != source) {
289  changedLatticeValueWorklist.push_back(value);
290  valueEntry = source;
291  }
292  }
293 
294  // This function returns a field ref of the given value. This function caches
295  // the result to avoid extra IR traversal if the value is an aggregate
296  // element.
297  FieldRef getOrCacheFieldRefFromValue(Value value) {
298  if (!value.getDefiningOp() || !isAggregate(value.getDefiningOp()))
299  return FieldRef(value, 0);
300  auto &fieldRef = valueToFieldRef[value];
301  if (fieldRef)
302  return fieldRef;
303  return fieldRef = getFieldRefFromValue(value);
304  }
305 
306  /// Return the lattice value for the specified SSA value, extended to the
307  /// width of the specified destType. If allowTruncation is true, then this
308  /// allows truncating the lattice value to the specified type.
309  LatticeValue getExtendedLatticeValue(FieldRef value, FIRRTLType destType,
310  bool allowTruncation = false);
311 
312  /// Mark the given block as executable.
313  void markBlockExecutable(Block *block);
314  void markWireOp(WireOp wireOrReg);
315  void markMemOp(MemOp mem);
316 
317  void markInvalidValueOp(InvalidValueOp invalid);
318  void markAggregateConstantOp(AggregateConstantOp constant);
319  void markInstanceOp(InstanceOp instance);
320  void markObjectOp(ObjectOp object);
321  template <typename OpTy>
322  void markConstantValueOp(OpTy op);
323 
324  void visitConnectLike(FConnectLike connect, FieldRef changedFieldRef);
325  void visitRefSend(RefSendOp send, FieldRef changedFieldRef);
326  void visitRefResolve(RefResolveOp resolve, FieldRef changedFieldRef);
327  void mergeOnlyChangedLatticeValue(Value dest, Value src,
328  FieldRef changedFieldRef);
329  void visitNode(NodeOp node, FieldRef changedFieldRef);
330  void visitOperation(Operation *op, FieldRef changedFieldRef);
331 
332 private:
333  /// This is the current instance graph for the Circuit.
334  InstanceGraph *instanceGraph = nullptr;
335 
336  /// This keeps track of the current state of each tracked value.
337  DenseMap<FieldRef, LatticeValue> latticeValues;
338 
339  /// The set of blocks that are known to execute, or are intrinsically live.
340  SmallPtrSet<Block *, 16> executableBlocks;
341 
342  /// A worklist of values whose LatticeValue recently changed, indicating the
343  /// users need to be reprocessed.
344  SmallVector<FieldRef, 64> changedLatticeValueWorklist;
345 
346  // A map to give operations to be reprocessed.
347  DenseMap<FieldRef, llvm::TinyPtrVector<Operation *>> fieldRefToUsers;
348 
349  // A map to cache results of getFieldRefFromValue since it's costly traverse
350  // the IR.
351  llvm::DenseMap<Value, FieldRef> valueToFieldRef;
352 
353  /// This keeps track of users the instance results that correspond to output
354  /// ports.
355  DenseMap<BlockArgument, llvm::TinyPtrVector<Value>>
356  resultPortToInstanceResultMapping;
357 
358 #ifndef NDEBUG
359  /// A logger used to emit information during the application process.
360  llvm::ScopedPrinter logger{llvm::dbgs()};
361 #endif
362 };
363 } // end anonymous namespace
364 
365 // TODO: handle annotations: [[OptimizableExtModuleAnnotation]]
366 void IMConstPropPass::runOnOperation() {
367  auto circuit = getOperation();
368  LLVM_DEBUG(
369  { logger.startLine() << "IMConstProp : " << circuit.getName() << "\n"; });
370 
371  instanceGraph = &getAnalysis<InstanceGraph>();
372 
373  // Mark the input ports of public modules as being overdefined.
374  for (auto module : circuit.getBodyBlock()->getOps<FModuleOp>()) {
375  if (module.isPublic()) {
376  markBlockExecutable(module.getBodyBlock());
377  for (auto port : module.getBodyBlock()->getArguments())
378  markOverdefined(port);
379  }
380  }
381 
382  // If a value changed lattice state then reprocess any of its users.
383  while (!changedLatticeValueWorklist.empty()) {
384  FieldRef changedFieldRef = changedLatticeValueWorklist.pop_back_val();
385  for (Operation *user : fieldRefToUsers[changedFieldRef]) {
386  if (isBlockExecutable(user->getBlock()))
387  visitOperation(user, changedFieldRef);
388  }
389  }
390 
391  // Rewrite any constants in the modules.
392  mlir::parallelForEach(circuit.getContext(),
393  circuit.getBodyBlock()->getOps<FModuleOp>(),
394  [&](auto op) { rewriteModuleBody(op); });
395 
396  // Clean up our state for next time.
397  instanceGraph = nullptr;
398  latticeValues.clear();
399  executableBlocks.clear();
400  assert(changedLatticeValueWorklist.empty());
401  fieldRefToUsers.clear();
402  valueToFieldRef.clear();
403  resultPortToInstanceResultMapping.clear();
404 }
405 
406 /// Return the lattice value for the specified SSA value, extended to the width
407 /// of the specified destType. If allowTruncation is true, then this allows
408 /// truncating the lattice value to the specified type.
409 LatticeValue IMConstPropPass::getExtendedLatticeValue(FieldRef value,
410  FIRRTLType destType,
411  bool allowTruncation) {
412  // If 'value' hasn't been computed yet, then it is unknown.
413  auto it = latticeValues.find(value);
414  if (it == latticeValues.end())
415  return LatticeValue();
416 
417  auto result = it->second;
418  // Unknown/overdefined stay whatever they are.
419  if (result.isUnknown() || result.isOverdefined())
420  return result;
421 
422  // No extOrTrunc for property types. Return what we have.
423  if (isa<PropertyType>(destType))
424  return result;
425 
426  auto constant = result.getConstant();
427 
428  // If not property, only support integers.
429  auto intAttr = dyn_cast<IntegerAttr>(constant);
430  assert(intAttr && "unsupported lattice attribute kind");
431  if (!intAttr)
432  return result;
433 
434  // No extOrTrunc necessary for bools.
435  if (auto boolAttr = dyn_cast<BoolAttr>(intAttr))
436  return result;
437 
438  // Non-base (or non-ref) types are overdefined.
439  auto baseType = getBaseType(destType);
440  if (!baseType)
441  return LatticeValue::getOverdefined();
442 
443  // If destType is wider than the source constant type, extend it.
444  auto resultConstant = intAttr.getAPSInt();
445  auto destWidth = baseType.getBitWidthOrSentinel();
446  if (destWidth == -1) // We don't support unknown width FIRRTL.
447  return LatticeValue::getOverdefined();
448  if (resultConstant.getBitWidth() == (unsigned)destWidth)
449  return result; // Already the right width, we're done.
450 
451  // Otherwise, extend the constant using the signedness of the source.
452  resultConstant = extOrTruncZeroWidth(resultConstant, destWidth);
453  return LatticeValue(IntegerAttr::get(destType.getContext(), resultConstant));
454 }
455 
456 // NOLINTBEGIN(misc-no-recursion)
457 /// Mark a block executable if it isn't already. This does an initial scan of
458 /// the block, processing nullary operations like wires, instances, and
459 /// constants that only get processed once.
460 void IMConstPropPass::markBlockExecutable(Block *block) {
461  if (!executableBlocks.insert(block).second)
462  return; // Already executable.
463 
464  // Mark block arguments, which are module ports, with don't touch as
465  // overdefined.
466  for (auto ba : block->getArguments())
467  if (hasDontTouch(ba))
468  markOverdefined(ba);
469 
470  for (auto &op : *block) {
471  // Handle each of the special operations in the firrtl dialect.
472  TypeSwitch<Operation *>(&op)
473  .Case<RegOp, RegResetOp>(
474  [&](auto reg) { markOverdefined(op.getResult(0)); })
475  .Case<WireOp>([&](auto wire) { markWireOp(wire); })
476  .Case<ConstantOp, SpecialConstantOp, StringConstantOp,
477  FIntegerConstantOp, BoolConstantOp>(
478  [&](auto constOp) { markConstantValueOp(constOp); })
479  .Case<AggregateConstantOp>(
480  [&](auto aggConstOp) { markAggregateConstantOp(aggConstOp); })
481  .Case<InvalidValueOp>(
482  [&](auto invalid) { markInvalidValueOp(invalid); })
483  .Case<InstanceOp>([&](auto instance) { markInstanceOp(instance); })
484  .Case<ObjectOp>([&](auto obj) { markObjectOp(obj); })
485  .Case<MemOp>([&](auto mem) { markMemOp(mem); })
486  .Case<LayerBlockOp>(
487  [&](auto layer) { markBlockExecutable(layer.getBody(0)); })
488  .Default([&](auto _) {
489  if (isa<mlir::UnrealizedConversionCastOp, VerbatimExprOp,
490  VerbatimWireOp, SubaccessOp>(op) ||
491  op.getNumOperands() == 0) {
492  // Mark operations whose results cannot be tracked as overdefined.
493  // Mark unhandled operations with no operand as well since otherwise
494  // they will remain unknown states until the end.
495  for (auto result : op.getResults())
496  markOverdefined(result);
497  } else if (
498  // Operations that are handled when propagating values, or chasing
499  // indexing.
500  !isAggregate(&op) && !isNodeLike(&op) && op.getNumResults() > 0) {
501  // If an unknown operation has an aggregate operand, mark results as
502  // overdefined since we cannot track the dataflow. Similarly if the
503  // operations create aggregate values, we mark them overdefined.
504 
505  // TODO: We should handle aggregate operations such as
506  // vector_create, bundle_create or vector operations.
507 
508  bool hasAggregateOperand =
509  llvm::any_of(op.getOperandTypes(), [](Type type) {
510  return type_isa<FVectorType, BundleType>(type);
511  });
512 
513  for (auto result : op.getResults())
514  if (hasAggregateOperand ||
515  type_isa<FVectorType, BundleType>(result.getType()))
516  markOverdefined(result);
517  }
518  });
519 
520  // This tracks a dependency from field refs to operations which need
521  // to be added to worklist when lattice values change.
522  if (!isAggregate(&op)) {
523  for (auto operand : op.getOperands()) {
524  auto fieldRef = getOrCacheFieldRefFromValue(operand);
525  auto firrtlType = type_dyn_cast<FIRRTLType>(operand.getType());
526  if (!firrtlType)
527  continue;
528  // Special-handle PropertyType's, walkGroundTypes doesn't support.
529  if (type_isa<PropertyType>(firrtlType)) {
530  fieldRefToUsers[fieldRef].push_back(&op);
531  continue;
532  }
533  walkGroundTypes(firrtlType, [&](uint64_t fieldID, auto type, auto) {
534  fieldRefToUsers[fieldRef.getSubField(fieldID)].push_back(&op);
535  });
536  }
537  }
538  }
539 }
540 // NOLINTEND(misc-no-recursion)
541 
542 void IMConstPropPass::markWireOp(WireOp wire) {
543  auto type = type_dyn_cast<FIRRTLType>(wire.getResult().getType());
544  if (!type || hasDontTouch(wire.getResult()) || wire.isForceable()) {
545  for (auto result : wire.getResults())
546  markOverdefined(result);
547  return;
548  }
549 
550  // Otherwise, this starts out as unknown and is upgraded by connects.
551 }
552 
553 void IMConstPropPass::markMemOp(MemOp mem) {
554  for (auto result : mem.getResults())
555  markOverdefined(result);
556 }
557 
558 template <typename OpTy>
559 void IMConstPropPass::markConstantValueOp(OpTy op) {
560  mergeLatticeValue(getOrCacheFieldRefFromValue(op),
561  LatticeValue(op.getValueAttr()));
562 }
563 
564 void IMConstPropPass::markAggregateConstantOp(AggregateConstantOp constant) {
565  walkGroundTypes(constant.getType(), [&](uint64_t fieldID, auto, auto) {
566  mergeLatticeValue(FieldRef(constant, fieldID),
567  LatticeValue(cast<IntegerAttr>(
568  constant.getAttributeFromFieldID(fieldID))));
569  });
570 }
571 
572 void IMConstPropPass::markInvalidValueOp(InvalidValueOp invalid) {
573  markOverdefined(invalid.getResult());
574 }
575 
576 /// Instances have no operands, so they are visited exactly once when their
577 /// enclosing block is marked live. This sets up the def-use edges for ports.
578 void IMConstPropPass::markInstanceOp(InstanceOp instance) {
579  // Get the module being reference or a null pointer if this is an extmodule.
580  Operation *op = instance.getReferencedModule(*instanceGraph);
581 
582  // If this is an extmodule, just remember that any results and inouts are
583  // overdefined.
584  if (!isa<FModuleOp>(op)) {
585  auto module = dyn_cast<FModuleLike>(op);
586  for (size_t resultNo = 0, e = instance.getNumResults(); resultNo != e;
587  ++resultNo) {
588  auto portVal = instance.getResult(resultNo);
589  // If this is an input to the extmodule, we can ignore it.
590  if (module.getPortDirection(resultNo) == Direction::In)
591  continue;
592 
593  // Otherwise this is a result from it or an inout, mark it as overdefined.
594  markOverdefined(portVal);
595  }
596  return;
597  }
598 
599  // Otherwise this is a defined module.
600  auto fModule = cast<FModuleOp>(op);
601  markBlockExecutable(fModule.getBodyBlock());
602 
603  // Ok, it is a normal internal module reference. Populate
604  // resultPortToInstanceResultMapping, and forward any already-computed values.
605  for (size_t resultNo = 0, e = instance.getNumResults(); resultNo != e;
606  ++resultNo) {
607  auto instancePortVal = instance.getResult(resultNo);
608  // If this is an input to the instance, it will
609  // get handled when any connects to it are processed.
610  if (fModule.getPortDirection(resultNo) == Direction::In)
611  continue;
612 
613  // Otherwise we have a result from the instance. We need to forward results
614  // from the body to this instance result's SSA value, so remember it.
615  BlockArgument modulePortVal = fModule.getArgument(resultNo);
616 
617  resultPortToInstanceResultMapping[modulePortVal].push_back(instancePortVal);
618 
619  // If there is already a value known for modulePortVal make sure to forward
620  // it here.
621  mergeLatticeValue(instancePortVal, modulePortVal);
622  }
623 }
624 
625 void IMConstPropPass::markObjectOp(ObjectOp obj) {
626  // Mark overdefined for now, not supported.
627  markOverdefined(obj);
628 }
629 
630 static std::optional<uint64_t>
631 getFieldIDOffset(FieldRef changedFieldRef, Type connectionType,
632  FieldRef connectedValueFieldRef) {
633  assert(!type_isa<RefType>(connectionType));
634  if (changedFieldRef.getValue() != connectedValueFieldRef.getValue())
635  return {};
636  if (changedFieldRef.getFieldID() >= connectedValueFieldRef.getFieldID() &&
637  changedFieldRef.getFieldID() <=
638  hw::FieldIdImpl::getMaxFieldID(connectionType) +
639  connectedValueFieldRef.getFieldID())
640  return changedFieldRef.getFieldID() - connectedValueFieldRef.getFieldID();
641  return {};
642 }
643 
644 void IMConstPropPass::mergeOnlyChangedLatticeValue(Value dest, Value src,
645  FieldRef changedFieldRef) {
646 
647  // Operate on inner type for refs.
648  auto destType = dest.getType();
649  if (auto refType = type_dyn_cast<RefType>(destType))
650  destType = refType.getType();
651 
652  if (!isa<FIRRTLType>(destType)) {
653  // If the dest is not FIRRTL type, conservatively mark
654  // all of them overdefined.
655  markOverdefined(src);
656  return markOverdefined(dest);
657  }
658 
659  auto fieldRefSrc = getOrCacheFieldRefFromValue(src);
660  auto fieldRefDest = getOrCacheFieldRefFromValue(dest);
661 
662  // If a changed field ref is included the source value, find an offset in the
663  // connection.
664  if (auto srcOffset = getFieldIDOffset(changedFieldRef, destType, fieldRefSrc))
665  mergeLatticeValue(fieldRefDest.getSubField(*srcOffset),
666  fieldRefSrc.getSubField(*srcOffset));
667 
668  // If a changed field ref is included the dest value, find an offset in the
669  // connection.
670  if (auto destOffset =
671  getFieldIDOffset(changedFieldRef, destType, fieldRefDest))
672  mergeLatticeValue(fieldRefDest.getSubField(*destOffset),
673  fieldRefSrc.getSubField(*destOffset));
674 }
675 
676 void IMConstPropPass::visitConnectLike(FConnectLike connect,
677  FieldRef changedFieldRef) {
678  // Operate on inner type for refs.
679  auto destType = connect.getDest().getType();
680  if (auto refType = type_dyn_cast<RefType>(destType))
681  destType = refType.getType();
682 
683  // Mark foreign types as overdefined.
684  if (!isa<FIRRTLType>(destType)) {
685  markOverdefined(connect.getSrc());
686  return markOverdefined(connect.getDest());
687  }
688 
689  auto fieldRefSrc = getOrCacheFieldRefFromValue(connect.getSrc());
690  auto fieldRefDest = getOrCacheFieldRefFromValue(connect.getDest());
691  if (auto subaccess = fieldRefDest.getValue().getDefiningOp<SubaccessOp>()) {
692  // If the destination is subaccess, we give up to precisely track
693  // lattice values and mark entire aggregate as overdefined. This code
694  // should be dead unless we stop lowering of subaccess in LowerTypes.
695  Value parent = subaccess.getInput();
696  while (parent.getDefiningOp() &&
697  parent.getDefiningOp()->getNumOperands() > 0)
698  parent = parent.getDefiningOp()->getOperand(0);
699  return markOverdefined(parent);
700  }
701 
702  auto propagateElementLattice = [&](uint64_t fieldID, FIRRTLType destType) {
703  auto fieldRefDestConnected = fieldRefDest.getSubField(fieldID);
704  assert(!firrtl::type_isa<FIRRTLBaseType>(destType) ||
705  firrtl::type_cast<FIRRTLBaseType>(destType).isGround());
706 
707  // Handle implicit extensions.
708  auto srcValue =
709  getExtendedLatticeValue(fieldRefSrc.getSubField(fieldID), destType);
710  if (srcValue.isUnknown())
711  return;
712 
713  // Driving result ports propagates the value to each instance using the
714  // module.
715  if (auto blockArg = dyn_cast<BlockArgument>(fieldRefDest.getValue())) {
716  for (auto userOfResultPort : resultPortToInstanceResultMapping[blockArg])
717  mergeLatticeValue(
718  FieldRef(userOfResultPort, fieldRefDestConnected.getFieldID()),
719  srcValue);
720  // Output ports are wire-like and may have users.
721  return mergeLatticeValue(fieldRefDestConnected, srcValue);
722  }
723 
724  auto dest = cast<mlir::OpResult>(fieldRefDest.getValue());
725 
726  // For wires and registers, we drive the value of the wire itself, which
727  // automatically propagates to users.
728  if (isWireOrReg(dest.getOwner()))
729  return mergeLatticeValue(fieldRefDestConnected, srcValue);
730 
731  // Driving an instance argument port drives the corresponding argument
732  // of the referenced module.
733  if (auto instance = dest.getDefiningOp<InstanceOp>()) {
734  // Update the dest, when its an instance op.
735  mergeLatticeValue(fieldRefDestConnected, srcValue);
736  auto mod = instance.getReferencedModule<FModuleOp>(*instanceGraph);
737  if (!mod)
738  return;
739 
740  BlockArgument modulePortVal = mod.getArgument(dest.getResultNumber());
741 
742  return mergeLatticeValue(
743  FieldRef(modulePortVal, fieldRefDestConnected.getFieldID()),
744  srcValue);
745  }
746 
747  // Driving a memory result is ignored because these are always treated
748  // as overdefined.
749  if (dest.getDefiningOp<MemOp>())
750  return;
751 
752  // For now, don't support const prop into object fields.
753  if (isa_and_nonnull<ObjectSubfieldOp>(dest.getDefiningOp()))
754  return;
755 
756  connect.emitError("connectlike operation unhandled by IMConstProp")
757  .attachNote(connect.getDest().getLoc())
758  << "connect destination is here";
759  };
760 
761  if (auto srcOffset = getFieldIDOffset(changedFieldRef, destType, fieldRefSrc))
762  propagateElementLattice(
763  *srcOffset,
764  firrtl::type_cast<FIRRTLType>(
765  hw::FieldIdImpl::getFinalTypeByFieldID(destType, *srcOffset)));
766 
767  if (auto relativeDest =
768  getFieldIDOffset(changedFieldRef, destType, fieldRefDest))
769  propagateElementLattice(
770  *relativeDest,
771  firrtl::type_cast<FIRRTLType>(
772  hw::FieldIdImpl::getFinalTypeByFieldID(destType, *relativeDest)));
773 }
774 
775 void IMConstPropPass::visitRefSend(RefSendOp send, FieldRef changedFieldRef) {
776  // Send connects the base value (source) to the result (dest).
777  return mergeOnlyChangedLatticeValue(send.getResult(), send.getBase(),
778  changedFieldRef);
779 }
780 
781 void IMConstPropPass::visitRefResolve(RefResolveOp resolve,
782  FieldRef changedFieldRef) {
783  // Resolve connects the ref value (source) to result (dest).
784  // If writes are ever supported, this will need to work differently!
785  return mergeOnlyChangedLatticeValue(resolve.getResult(), resolve.getRef(),
786  changedFieldRef);
787 }
788 
789 void IMConstPropPass::visitNode(NodeOp node, FieldRef changedFieldRef) {
790  if (hasDontTouch(node.getResult()) || node.isForceable()) {
791  for (auto result : node.getResults())
792  markOverdefined(result);
793  return;
794  }
795 
796  return mergeOnlyChangedLatticeValue(node.getResult(), node.getInput(),
797  changedFieldRef);
798 }
799 
800 /// This method is invoked when an operand of the specified op changes its
801 /// lattice value state and when the block containing the operation is first
802 /// noticed as being alive.
803 ///
804 /// This should update the lattice value state for any result values.
805 ///
806 void IMConstPropPass::visitOperation(Operation *op, FieldRef changedField) {
807  // If this is a operation with special handling, handle it specially.
808  if (auto connectLikeOp = dyn_cast<FConnectLike>(op))
809  return visitConnectLike(connectLikeOp, changedField);
810  if (auto sendOp = dyn_cast<RefSendOp>(op))
811  return visitRefSend(sendOp, changedField);
812  if (auto resolveOp = dyn_cast<RefResolveOp>(op))
813  return visitRefResolve(resolveOp, changedField);
814  if (auto nodeOp = dyn_cast<NodeOp>(op))
815  return visitNode(nodeOp, changedField);
816 
817  // The clock operand of regop changing doesn't change its result value. All
818  // other registers are over-defined. Aggregate operations also doesn't change
819  // its result value.
820  if (isa<RegOp, RegResetOp>(op) || isAggregate(op))
821  return;
822  // TODO: Handle 'when' operations.
823 
824  // If all of the results of this operation are already overdefined (or if
825  // there are no results) then bail out early: we've converged.
826  auto isOverdefinedFn = [&](Value value) {
827  return isOverdefined(getOrCacheFieldRefFromValue(value));
828  };
829  if (llvm::all_of(op->getResults(), isOverdefinedFn))
830  return;
831 
832  // To prevent regressions, mark values as overdefined when they are defined
833  // by operations with a large number of operands.
834  if (op->getNumOperands() > 128) {
835  for (auto value : op->getResults())
836  markOverdefined(value);
837  return;
838  }
839 
840  // Collect all of the constant operands feeding into this operation. If any
841  // are not ready to be resolved, bail out and wait for them to resolve.
842  SmallVector<Attribute, 8> operandConstants;
843  operandConstants.reserve(op->getNumOperands());
844  bool hasUnknown = false;
845  for (Value operand : op->getOperands()) {
846 
847  auto &operandLattice = latticeValues[getOrCacheFieldRefFromValue(operand)];
848 
849  // If the operand is an unknown value, then we generally don't want to
850  // process it - we want to wait until the value is resolved to by the SCCP
851  // algorithm.
852  if (operandLattice.isUnknown())
853  hasUnknown = true;
854 
855  // Otherwise, it must be constant, invalid, or overdefined. Translate them
856  // into attributes that the fold hook can look at.
857  if (operandLattice.isConstant())
858  operandConstants.push_back(operandLattice.getValue());
859  else
860  operandConstants.push_back({});
861  }
862 
863  // Simulate the result of folding this operation to a constant. If folding
864  // fails mark the results as overdefined.
865  SmallVector<OpFoldResult, 8> foldResults;
866  foldResults.reserve(op->getNumResults());
867  if (failed(op->fold(operandConstants, foldResults))) {
868  LLVM_DEBUG({
869  logger.startLine() << "Folding Failed operation : '" << op->getName()
870  << "\n";
871  op->dump();
872  });
873  // If we had unknown arguments, hold off on overdefining
874  if (!hasUnknown)
875  for (auto value : op->getResults())
876  markOverdefined(value);
877  return;
878  }
879 
880  LLVM_DEBUG({
881  logger.getOStream() << "\n";
882  logger.startLine() << "Folding operation : '" << op->getName() << "\n";
883  op->dump();
884  logger.getOStream() << "( ";
885  for (auto cst : operandConstants)
886  if (!cst)
887  logger.getOStream() << "{} ";
888  else
889  logger.getOStream() << cst << " ";
890  logger.unindent();
891  logger.getOStream() << ") -> { ";
892  logger.indent();
893  for (auto &r : foldResults) {
894  logger.getOStream() << r << " ";
895  }
896  logger.unindent();
897  logger.getOStream() << "}\n";
898  });
899 
900  // If the folding was in-place, keep going. This is surprising, but since
901  // only folder that will do in-place updates is the commutative folder, we
902  // aren't going to stop. We don't update the results, since they didn't
903  // change, the op just got shuffled around.
904  if (foldResults.empty())
905  return visitOperation(op, changedField);
906 
907  // Merge the fold results into the lattice for this operation.
908  assert(foldResults.size() == op->getNumResults() && "invalid result size");
909  for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
910  // Merge in the result of the fold, either a constant or a value.
911  LatticeValue resultLattice;
912  OpFoldResult foldResult = foldResults[i];
913  if (Attribute foldAttr = dyn_cast<Attribute>(foldResult)) {
914  if (auto intAttr = dyn_cast<IntegerAttr>(foldAttr))
915  resultLattice = LatticeValue(intAttr);
916  else if (auto strAttr = dyn_cast<StringAttr>(foldAttr))
917  resultLattice = LatticeValue(strAttr);
918  else // Treat unsupported constants as overdefined.
919  resultLattice = LatticeValue::getOverdefined();
920  } else { // Folding to an operand results in its value.
921  resultLattice =
922  latticeValues[getOrCacheFieldRefFromValue(foldResult.get<Value>())];
923  }
924 
925  mergeLatticeValue(getOrCacheFieldRefFromValue(op->getResult(i)),
926  resultLattice);
927  }
928 }
929 
930 void IMConstPropPass::rewriteModuleBody(FModuleOp module) {
931  auto *body = module.getBodyBlock();
932  // If a module is unreachable, just ignore it.
933  if (!executableBlocks.count(body))
934  return;
935 
936  auto builder = OpBuilder::atBlockBegin(body);
937 
938  // Separate the constants we insert from the instructions we are folding and
939  // processing. Leave these as-is until we're done.
940  auto cursor = builder.create<firrtl::ConstantOp>(module.getLoc(), APSInt(1));
941  builder.setInsertionPoint(cursor);
942 
943  // Unique constants per <Const,Type> pair, inserted at entry
944  DenseMap<std::pair<Attribute, Type>, Operation *> constPool;
945 
946  std::function<Value(Attribute, Type, Location)> getConst =
947  [&](Attribute constantValue, Type type, Location loc) -> Value {
948  auto constIt = constPool.find({constantValue, type});
949  if (constIt != constPool.end()) {
950  auto *cst = constIt->second;
951  // Add location to the constant
952  cst->setLoc(builder.getFusedLoc({cst->getLoc(), loc}));
953  return cst->getResult(0);
954  }
955  OpBuilder::InsertionGuard x(builder);
956  builder.setInsertionPoint(cursor);
957 
958  // Materialize reftype "constants" by materializing the constant
959  // and probing it.
960  Operation *cst;
961  if (auto refType = type_dyn_cast<RefType>(type)) {
962  assert(!type_cast<RefType>(type).getForceable() &&
963  "Attempting to materialize rwprobe of constant, shouldn't happen");
964  auto inner = getConst(constantValue, refType.getType(), loc);
965  assert(inner);
966  cst = builder.create<RefSendOp>(loc, inner);
967  } else
968  cst = module->getDialect()->materializeConstant(builder, constantValue,
969  type, loc);
970  assert(cst && "all FIRRTL constants can be materialized");
971  constPool.insert({{constantValue, type}, cst});
972  return cst->getResult(0);
973  };
974 
975  // If the lattice value for the specified value is a constant update it and
976  // return true. Otherwise return false.
977  auto replaceValueIfPossible = [&](Value value) -> bool {
978  // Lambda to replace all uses of this value a replacement, unless this is
979  // the destination of a connect. We leave connects alone to avoid upsetting
980  // flow, i.e., to avoid trying to connect to a constant.
981  auto replaceIfNotConnect = [&value](Value replacement) {
982  value.replaceUsesWithIf(replacement, [](OpOperand &operand) {
983  return !isa<FConnectLike>(operand.getOwner()) ||
984  operand.getOperandNumber() != 0;
985  });
986  };
987 
988  // TODO: Replace entire aggregate.
989  auto it = latticeValues.find(getFieldRefFromValue(value));
990  if (it == latticeValues.end() || it->second.isOverdefined() ||
991  it->second.isUnknown())
992  return false;
993 
994  // Cannot materialize constants for certain types.
995  // TODO: Let materializeConstant tell us what it supports instead of this.
996  // Presently it asserts on unsupported combinations, so check this here.
997  if (!type_isa<FIRRTLBaseType, RefType, FIntegerType, StringType, BoolType>(
998  value.getType()))
999  return false;
1000 
1001  auto cstValue =
1002  getConst(it->second.getValue(), value.getType(), value.getLoc());
1003 
1004  replaceIfNotConnect(cstValue);
1005  return true;
1006  };
1007 
1008  // Constant propagate any ports that are always constant.
1009  for (auto &port : body->getArguments())
1010  replaceValueIfPossible(port);
1011 
1012  // Walk the IR bottom-up when folding. We often fold entire chains of
1013  // operations into constants, which make the intermediate nodes dead. Going
1014  // bottom up eliminates the users of the intermediate ops, allowing us to
1015  // aggressively delete them.
1016  //
1017  // TODO: Handle WhenOps correctly.
1018  bool aboveCursor = false;
1019  module.walk<mlir::WalkOrder::PreOrder, mlir::ReverseIterator>(
1020  [&](Operation *op) {
1021  auto dropIfDead = [&](Operation *op, const Twine &debugPrefix) {
1022  if (op->use_empty() &&
1023  (wouldOpBeTriviallyDead(op) || isDeletableWireOrRegOrNode(op))) {
1024  LLVM_DEBUG(
1025  { logger.getOStream() << debugPrefix << " : " << op << "\n"; });
1026  ++numErasedOp;
1027  op->erase();
1028  return true;
1029  }
1030  return false;
1031  };
1032 
1033  if (aboveCursor) {
1034  // Drop dead constants we materialized.
1035  dropIfDead(op, "Trivially dead materialized constant");
1036  return WalkResult::advance();
1037  }
1038  // Stop once hit the generated constants.
1039  if (op == cursor) {
1040  cursor.erase();
1041  aboveCursor = true;
1042  return WalkResult::advance();
1043  }
1044 
1045  // Connects to values that we found to be constant can be dropped.
1046  if (auto connect = dyn_cast<FConnectLike>(op)) {
1047  if (auto *destOp = connect.getDest().getDefiningOp()) {
1048  auto fieldRef = getOrCacheFieldRefFromValue(connect.getDest());
1049  // Don't remove a field-level connection even if the src value is
1050  // constant. If other elements of the aggregate value are not
1051  // constant, the aggregate value cannot be replaced. We can forward
1052  // the constant to its users, so IMDCE (or SV/HW canonicalizer)
1053  // should remove the aggregate if entire aggregate is dead.
1054  auto type = type_dyn_cast<FIRRTLType>(connect.getDest().getType());
1055  if (!type)
1056  return WalkResult::advance();
1057  auto baseType = type_dyn_cast<FIRRTLBaseType>(type);
1058  if (baseType && !baseType.isGround())
1059  return WalkResult::advance();
1060  if (isDeletableWireOrRegOrNode(destOp) &&
1061  !isOverdefined(fieldRef)) {
1062  connect.erase();
1063  ++numErasedOp;
1064  }
1065  }
1066  return WalkResult::advance();
1067  }
1068 
1069  // We only fold single-result ops and instances in practice, because
1070  // they are the expressions.
1071  if (op->getNumResults() != 1 && !isa<InstanceOp>(op))
1072  return WalkResult::advance();
1073 
1074  // If this operation is already dead, then go ahead and remove it.
1075  if (dropIfDead(op, "Trivially dead"))
1076  return WalkResult::advance();
1077 
1078  // Don't "fold" constants (into equivalent), also because they
1079  // may have name hints we'd like to preserve.
1080  if (op->hasTrait<mlir::OpTrait::ConstantLike>())
1081  return WalkResult::advance();
1082 
1083  // If the op had any constants folded, replace them.
1084  builder.setInsertionPoint(op);
1085  bool foldedAny = false;
1086  for (auto result : op->getResults())
1087  foldedAny |= replaceValueIfPossible(result);
1088 
1089  if (foldedAny)
1090  ++numFoldedOp;
1091 
1092  // If the operation folded to a constant then we can probably nuke it.
1093  if (foldedAny && dropIfDead(op, "Made dead"))
1094  return WalkResult::advance();
1095 
1096  return WalkResult::advance();
1097  });
1098 }
1099 
1100 std::unique_ptr<mlir::Pass> circt::firrtl::createIMConstPropPass() {
1101  return std::make_unique<IMConstPropPass>();
1102 }
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static bool isNodeLike(Operation *op)
Definition: IMConstProp.cpp:55
static std::optional< uint64_t > getFieldIDOffset(FieldRef changedFieldRef, Type connectionType, FieldRef connectedValueFieldRef)
static bool isWireOrReg(Operation *op)
Return true if this is a wire or register.
Definition: IMConstProp.cpp:43
static bool isAggregate(Operation *op)
Return true if this is an aggregate indexer.
Definition: IMConstProp.cpp:48
static bool isDeletableWireOrRegOrNode(Operation *op)
Return true if this is a wire, register, or node we're allowed to delete.
Definition: IMConstProp.cpp:60
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
FieldRef getSubField(unsigned subFieldID) const
Get a reference to a subfield.
Definition: FieldRef.h:62
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Definition: FieldRef.h:59
Value getValue() const
Get the Value which created this location.
Definition: FieldRef.h:37
Location getLoc() const
Get the location associated with the value of this field ref.
Definition: FieldRef.h:67
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool canBeDeleted() const
Check if every annotation can be deleted.
This graph tracks modules and where they are instantiated.
def connect(destination, source)
Definition: support.py:39
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
Definition: FIRRTLUtils.h:222
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
void walkGroundTypes(FIRRTLType firrtlType, llvm::function_ref< void(uint64_t, FIRRTLBaseType, bool)> fn)
Walk leaf ground types in the firrtlType and apply the function fn.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
Definition: FIRRTLOps.cpp:4604
std::unique_ptr< mlir::Pass > createIMConstPropPass()
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
Definition: FIRRTLOps.cpp:316
T & operator<<(T &os, FIRVersion version)
Definition: FIRParser.h:122
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
uint64_t getMaxFieldID(Type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition: APInt.cpp:22
bool operator==(uint64_t a, const FVInt &b)
Definition: FVInt.h:640
bool operator!=(uint64_t a, const FVInt &b)
Definition: FVInt.h:641
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:21