CIRCT  20.0.0git
IMConstProp.cpp
Go to the documentation of this file.
1 //===- IMConstProp.cpp - Intermodule ConstProp and DCE ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This implements SCCP:
10 // https://www.cs.wustl.edu/~cytron/531Pages/f11/Resources/Papers/cprop.pdf
11 //
12 //===----------------------------------------------------------------------===//
13 
20 #include "circt/Support/APInt.h"
21 #include "mlir/IR/Iterators.h"
22 #include "mlir/IR/Threading.h"
23 #include "mlir/Pass/Pass.h"
24 #include "llvm/ADT/APSInt.h"
25 #include "llvm/ADT/TinyPtrVector.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/ScopedPrinter.h"
28 
29 namespace circt {
30 namespace firrtl {
31 #define GEN_PASS_DEF_IMCONSTPROP
32 #include "circt/Dialect/FIRRTL/Passes.h.inc"
33 } // namespace firrtl
34 } // namespace circt
35 
36 using namespace circt;
37 using namespace firrtl;
38 
39 #define DEBUG_TYPE "IMCP"
40 
41 /// Return true if this is a wire or register.
42 static bool isWireOrReg(Operation *op) {
43  return isa<WireOp, RegResetOp, RegOp>(op);
44 }
45 
46 /// Return true if this is an aggregate indexer.
47 static bool isAggregate(Operation *op) {
48  return isa<SubindexOp, SubaccessOp, SubfieldOp, OpenSubfieldOp,
49  OpenSubindexOp, RefSubOp>(op);
50 }
51 
52 // Return true if this forwards input to output.
53 // Implies has appropriate visit method that propagates changes.
54 static bool isNodeLike(Operation *op) {
55  return isa<NodeOp, RefResolveOp, RefSendOp>(op);
56 }
57 
58 /// Return true if this is a wire or register we're allowed to delete.
59 static bool isDeletableWireOrRegOrNode(Operation *op) {
60  if (!isWireOrReg(op) && !isa<NodeOp>(op))
61  return false;
62 
63  // Always allow deleting wires of probe-type.
64  if (type_isa<RefType>(op->getResult(0).getType()))
65  return true;
66 
67  // Otherwise, don't delete if has anything keeping it around or unknown.
68  return AnnotationSet(op).empty() && !hasDontTouch(op) &&
69  hasDroppableName(op) && !cast<Forceable>(op).isForceable();
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // Pass Infrastructure
74 //===----------------------------------------------------------------------===//
75 
76 namespace {
77 /// This class represents a single lattice value. A lattive value corresponds to
78 /// the various different states that a value in the SCCP dataflow analysis can
79 /// take. See 'Kind' below for more details on the different states a value can
80 /// take.
81 class LatticeValue {
82  enum Kind {
83  /// A value with a yet-to-be-determined value. This state may be changed to
84  /// anything, it hasn't been processed by IMConstProp.
85  Unknown,
86 
87  /// A value that is known to be a constant. This state may be changed to
88  /// overdefined.
89  Constant,
90 
91  /// A value that cannot statically be determined to be a constant. This
92  /// state cannot be changed.
93  Overdefined
94  };
95 
96 public:
97  /// Initialize a lattice value with "Unknown".
98  /*implicit*/ LatticeValue() : valueAndTag(nullptr, Kind::Unknown) {}
99  /// Initialize a lattice value with a constant.
100  /*implicit*/ LatticeValue(IntegerAttr attr)
101  : valueAndTag(attr, Kind::Constant) {}
102  /*implicit*/ LatticeValue(StringAttr attr)
103  : valueAndTag(attr, Kind::Constant) {}
104 
105  static LatticeValue getOverdefined() {
106  LatticeValue result;
107  result.markOverdefined();
108  return result;
109  }
110 
111  bool isUnknown() const { return valueAndTag.getInt() == Kind::Unknown; }
112  bool isConstant() const { return valueAndTag.getInt() == Kind::Constant; }
113  bool isOverdefined() const {
114  return valueAndTag.getInt() == Kind::Overdefined;
115  }
116 
117  /// Mark the lattice value as overdefined.
118  void markOverdefined() {
119  valueAndTag.setPointerAndInt(nullptr, Kind::Overdefined);
120  }
121 
122  /// Mark the lattice value as constant.
123  void markConstant(IntegerAttr value) {
124  valueAndTag.setPointerAndInt(value, Kind::Constant);
125  }
126 
127  /// If this lattice is constant or invalid value, return the attribute.
128  /// Returns nullptr otherwise.
129  Attribute getValue() const { return valueAndTag.getPointer(); }
130 
131  /// If this is in the constant state, return the attribute.
132  Attribute getConstant() const {
133  assert(isConstant());
134  return getValue();
135  }
136 
137  /// Merge in the value of the 'rhs' lattice into this one. Returns true if the
138  /// lattice value changed.
139  bool mergeIn(LatticeValue rhs) {
140  // If we are already overdefined, or rhs is unknown, there is nothing to do.
141  if (isOverdefined() || rhs.isUnknown())
142  return false;
143 
144  // If we are unknown, just take the value of rhs.
145  if (isUnknown()) {
146  valueAndTag = rhs.valueAndTag;
147  return true;
148  }
149 
150  // Otherwise, if this value doesn't match rhs go straight to overdefined.
151  // This happens when we merge "3" and "4" from two different instance sites
152  // for example.
153  if (valueAndTag != rhs.valueAndTag) {
154  markOverdefined();
155  return true;
156  }
157  return false;
158  }
159 
160  bool operator==(const LatticeValue &other) const {
161  return valueAndTag == other.valueAndTag;
162  }
163  bool operator!=(const LatticeValue &other) const {
164  return valueAndTag != other.valueAndTag;
165  }
166 
167 private:
168  /// The attribute value if this is a constant and the tag for the element
169  /// kind. The attribute is an IntegerAttr (or BoolAttr) or StringAttr.
170  llvm::PointerIntPair<Attribute, 2, Kind> valueAndTag;
171 };
172 } // end anonymous namespace
173 
174 LLVM_ATTRIBUTE_USED
175 static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
176  const LatticeValue &lattice) {
177  if (lattice.isUnknown())
178  return os << "<Unknown>";
179  if (lattice.isOverdefined())
180  return os << "<Overdefined>";
181  return os << "<" << lattice.getConstant() << ">";
182 }
183 
namespace {
/// The inter-module constant propagation pass. It runs an SCCP-style dataflow
/// analysis over the whole circuit, tracking one LatticeValue per FieldRef
/// (i.e. per ground element of a value), and then rewrites module bodies via
/// rewriteModuleBody (defined elsewhere in this file).
struct IMConstPropPass
    : public circt::firrtl::impl::IMConstPropBase<IMConstPropPass> {

  void runOnOperation() override;
  void rewriteModuleBody(FModuleOp module);

  /// Returns true if the given block is executable.
  bool isBlockExecutable(Block *block) const {
    return executableBlocks.count(block);
  }

  /// Returns true if `value` has a lattice entry in the overdefined state. A
  /// FieldRef with no entry yet is unknown, and therefore not overdefined.
  bool isOverdefined(FieldRef value) const {
    auto it = latticeValues.find(value);
    return it != latticeValues.end() && it->second.isOverdefined();
  }

  // Mark the given value as overdefined. If the value is an aggregate,
  // we mark all ground elements as overdefined. Non-FIRRTL and property
  // types are not walked; only the value's own FieldRef is marked.
  void markOverdefined(Value value) {
    FieldRef fieldRef = getOrCacheFieldRefFromValue(value);
    auto firrtlType = type_dyn_cast<FIRRTLType>(value.getType());
    if (!firrtlType || type_isa<PropertyType>(firrtlType)) {
      markOverdefined(fieldRef);
      return;
    }

    walkGroundTypes(firrtlType, [&](uint64_t fieldID, auto, auto) {
      markOverdefined(fieldRef.getSubField(fieldID));
    });
  }

  /// Mark the given value as overdefined. This means that we cannot refine a
  /// specific constant for this value. Users are queued for revisiting only
  /// on the first transition to overdefined.
  void markOverdefined(FieldRef value) {
    auto &entry = latticeValues[value];
    if (!entry.isOverdefined()) {
      LLVM_DEBUG({
        logger.getOStream()
            << "Setting overdefined : (" << getFieldName(value).first << ")\n";
      });
      entry.markOverdefined();
      changedLatticeValueWorklist.push_back(value);
    }
  }

  /// Merge information from the 'source' lattice value into `valueEntry`, the
  /// lattice entry for `value`. If it changes, `value` is pushed onto the
  /// worklist so its users are revisited.
  void mergeLatticeValue(FieldRef value, LatticeValue &valueEntry,
                         LatticeValue source) {
    if (valueEntry.mergeIn(source)) {
      LLVM_DEBUG({
        logger.getOStream()
            << "Changed to " << valueEntry << " : (" << value << ")\n";
      });
      changedLatticeValueWorklist.push_back(value);
    }
  }

  /// Merge `source` into the lattice entry for `value`, creating the entry if
  /// needed.
  void mergeLatticeValue(FieldRef value, LatticeValue source) {
    // Don't even do a map lookup if from has no info in it.
    if (source.isUnknown())
      return;
    mergeLatticeValue(value, latticeValues[value], source);
  }

  /// Merge the current lattice value of `from` into `result`.
  void mergeLatticeValue(FieldRef result, FieldRef from) {
    // If 'from' hasn't been computed yet, then it is unknown, don't do
    // anything.
    auto it = latticeValues.find(from);
    if (it == latticeValues.end())
      return;
    mergeLatticeValue(result, it->second);
  }

  /// Merge every ground element of `from` into the corresponding element of
  /// `result`. The field IDs are walked using `result`'s type.
  void mergeLatticeValue(Value result, Value from) {
    FieldRef fieldRefFrom = getOrCacheFieldRefFromValue(from);
    FieldRef fieldRefResult = getOrCacheFieldRefFromValue(result);
    if (!type_isa<FIRRTLType>(result.getType()))
      return mergeLatticeValue(fieldRefResult, fieldRefFrom);
    // Special-handle PropertyType's, walkGroundType's doesn't support.
    if (type_isa<PropertyType>(result.getType()))
      return mergeLatticeValue(fieldRefResult, fieldRefFrom);
    walkGroundTypes(type_cast<FIRRTLType>(result.getType()),
                    [&](uint64_t fieldID, auto, auto) {
                      mergeLatticeValue(fieldRefResult.getSubField(fieldID),
                                        fieldRefFrom.getSubField(fieldID));
                    });
  }

  /// setLatticeValue - This is used when a new LatticeValue is computed for
  /// the result of the specified value that replaces any previous knowledge,
  /// e.g. because a fold() function on an op returned a new thing. This should
  /// not be used on operations that have multiple contributors to it, e.g.
  /// wires or ports.
  void setLatticeValue(FieldRef value, LatticeValue source) {
    // Don't even do a map lookup if from has no info in it.
    if (source.isUnknown())
      return;

    // If we've changed this value then revisit all the users.
    auto &valueEntry = latticeValues[value];
    if (valueEntry != source) {
      changedLatticeValueWorklist.push_back(value);
      valueEntry = source;
    }
  }

  // This function returns a field ref of the given value. This function caches
  // the result to avoid extra IR traversal if the value is an aggregate
  // element. Values not defined by an aggregate-indexing op map to field 0 of
  // themselves without touching the cache.
  // NOTE(review): this mutates valueToFieldRef, so it must not be called from
  // the parallel rewrite phase — confirm rewriteModuleBody does not use it.
  FieldRef getOrCacheFieldRefFromValue(Value value) {
    if (!value.getDefiningOp() || !isAggregate(value.getDefiningOp()))
      return FieldRef(value, 0);
    auto &fieldRef = valueToFieldRef[value];
    if (fieldRef)
      return fieldRef;
    return fieldRef = getFieldRefFromValue(value);
  }

  /// Return the lattice value for the specified field, extended to the
  /// width of the specified destType. NOTE: the out-of-line definition does
  /// not currently read allowTruncation; constants wider than destType are
  /// truncated regardless.
  LatticeValue getExtendedLatticeValue(FieldRef value, FIRRTLType destType,
                                       bool allowTruncation = false);

  /// Mark the given block as executable.
  void markBlockExecutable(Block *block);
  void markWireOp(WireOp wireOrReg);
  void markMemOp(MemOp mem);

  void markInvalidValueOp(InvalidValueOp invalid);
  void markAggregateConstantOp(AggregateConstantOp constant);
  void markInstanceOp(InstanceOp instance);
  void markObjectOp(ObjectOp object);
  template <typename OpTy>
  void markConstantValueOp(OpTy op);

  void visitConnectLike(FConnectLike connect, FieldRef changedFieldRef);
  void visitRefSend(RefSendOp send, FieldRef changedFieldRef);
  void visitRefResolve(RefResolveOp resolve, FieldRef changedFieldRef);
  void mergeOnlyChangedLatticeValue(Value dest, Value src,
                                    FieldRef changedFieldRef);
  void visitNode(NodeOp node, FieldRef changedFieldRef);
  void visitOperation(Operation *op, FieldRef changedFieldRef);

private:
  /// This is the current instance graph for the Circuit.
  InstanceGraph *instanceGraph = nullptr;

  /// This keeps track of the current state of each tracked value. A FieldRef
  /// with no entry is treated as unknown.
  DenseMap<FieldRef, LatticeValue> latticeValues;

  /// The set of blocks that are known to execute, or are intrinsically live.
  SmallPtrSet<Block *, 16> executableBlocks;

  /// A worklist of values whose LatticeValue recently changed, indicating the
  /// users need to be reprocessed.
  SmallVector<FieldRef, 64> changedLatticeValueWorklist;

  // Maps each tracked FieldRef to the operations that use it; these are
  // revisited whenever that FieldRef's lattice value changes.
  DenseMap<FieldRef, llvm::TinyPtrVector<Operation *>> fieldRefToUsers;

  // A cache of getFieldRefFromValue results, since computing them requires
  // traversing the IR.
  llvm::DenseMap<Value, FieldRef> valueToFieldRef;

  /// Maps each module output-port block argument to the instance results that
  /// correspond to it (one per instantiation), so values driven onto the port
  /// are forwarded to every instance.
  DenseMap<BlockArgument, llvm::TinyPtrVector<Value>>
      resultPortToInstanceResultMapping;

#ifndef NDEBUG
  /// A logger used to emit information during the application process.
  llvm::ScopedPrinter logger{llvm::dbgs()};
#endif
};
} // end anonymous namespace
363 
364 // TODO: handle annotations: [[OptimizableExtModuleAnnotation]]
365 void IMConstPropPass::runOnOperation() {
366  auto circuit = getOperation();
367  LLVM_DEBUG(
368  { logger.startLine() << "IMConstProp : " << circuit.getName() << "\n"; });
369 
370  instanceGraph = &getAnalysis<InstanceGraph>();
371 
372  // Mark the input ports of public modules as being overdefined.
373  for (auto module : circuit.getBodyBlock()->getOps<FModuleOp>()) {
374  if (module.isPublic()) {
375  markBlockExecutable(module.getBodyBlock());
376  for (auto port : module.getBodyBlock()->getArguments())
377  markOverdefined(port);
378  }
379  }
380 
381  // If a value changed lattice state then reprocess any of its users.
382  while (!changedLatticeValueWorklist.empty()) {
383  FieldRef changedFieldRef = changedLatticeValueWorklist.pop_back_val();
384  for (Operation *user : fieldRefToUsers[changedFieldRef]) {
385  if (isBlockExecutable(user->getBlock()))
386  visitOperation(user, changedFieldRef);
387  }
388  }
389 
390  // Rewrite any constants in the modules.
391  mlir::parallelForEach(circuit.getContext(),
392  circuit.getBodyBlock()->getOps<FModuleOp>(),
393  [&](auto op) { rewriteModuleBody(op); });
394 
395  // Clean up our state for next time.
396  instanceGraph = nullptr;
397  latticeValues.clear();
398  executableBlocks.clear();
399  assert(changedLatticeValueWorklist.empty());
400  fieldRefToUsers.clear();
401  valueToFieldRef.clear();
402  resultPortToInstanceResultMapping.clear();
403 }
404 
405 /// Return the lattice value for the specified SSA value, extended to the width
406 /// of the specified destType. If allowTruncation is true, then this allows
407 /// truncating the lattice value to the specified type.
408 LatticeValue IMConstPropPass::getExtendedLatticeValue(FieldRef value,
409  FIRRTLType destType,
410  bool allowTruncation) {
411  // If 'value' hasn't been computed yet, then it is unknown.
412  auto it = latticeValues.find(value);
413  if (it == latticeValues.end())
414  return LatticeValue();
415 
416  auto result = it->second;
417  // Unknown/overdefined stay whatever they are.
418  if (result.isUnknown() || result.isOverdefined())
419  return result;
420 
421  // No extOrTrunc for property types. Return what we have.
422  if (isa<PropertyType>(destType))
423  return result;
424 
425  auto constant = result.getConstant();
426 
427  // If not property, only support integers.
428  auto intAttr = dyn_cast<IntegerAttr>(constant);
429  assert(intAttr && "unsupported lattice attribute kind");
430  if (!intAttr)
431  return result;
432 
433  // No extOrTrunc necessary for bools.
434  if (auto boolAttr = dyn_cast<BoolAttr>(intAttr))
435  return result;
436 
437  // Non-base (or non-ref) types are overdefined.
438  auto baseType = getBaseType(destType);
439  if (!baseType)
440  return LatticeValue::getOverdefined();
441 
442  // If destType is wider than the source constant type, extend it.
443  auto resultConstant = intAttr.getAPSInt();
444  auto destWidth = baseType.getBitWidthOrSentinel();
445  if (destWidth == -1) // We don't support unknown width FIRRTL.
446  return LatticeValue::getOverdefined();
447  if (resultConstant.getBitWidth() == (unsigned)destWidth)
448  return result; // Already the right width, we're done.
449 
450  // Otherwise, extend the constant using the signedness of the source.
451  resultConstant = extOrTruncZeroWidth(resultConstant, destWidth);
452  return LatticeValue(IntegerAttr::get(destType.getContext(), resultConstant));
453 }
454 
// NOLINTBEGIN(misc-no-recursion)
/// Mark a block executable if it isn't already. This does an initial scan of
/// the block, processing nullary operations like wires, instances, and
/// constants that only get processed once.
///
/// The scan also records dependency edges in fieldRefToUsers. For every op
/// that is not an aggregate-indexing op, each ground field of each
/// FIRRTL-typed operand gets an edge to that op, so the op is revisited when
/// the field's lattice value changes.
///
/// Recursion happens through instances (markInstanceOp marks the referenced
/// module body executable) and through layer blocks.
void IMConstPropPass::markBlockExecutable(Block *block) {
  if (!executableBlocks.insert(block).second)
    return; // Already executable.

  // Mark block arguments, which are module ports, with don't touch as
  // overdefined.
  for (auto ba : block->getArguments())
    if (hasDontTouch(ba))
      markOverdefined(ba);

  for (auto &op : *block) {
    // Handle each of the special operations in the firrtl dialect.
    TypeSwitch<Operation *>(&op)
        // Registers are never refined; their results are overdefined.
        .Case<RegOp, RegResetOp>(
            [&](auto reg) { markOverdefined(op.getResult(0)); })
        .Case<WireOp>([&](auto wire) { markWireOp(wire); })
        .Case<ConstantOp, SpecialConstantOp, StringConstantOp,
              FIntegerConstantOp, BoolConstantOp>(
            [&](auto constOp) { markConstantValueOp(constOp); })
        .Case<AggregateConstantOp>(
            [&](auto aggConstOp) { markAggregateConstantOp(aggConstOp); })
        .Case<InvalidValueOp>(
            [&](auto invalid) { markInvalidValueOp(invalid); })
        .Case<InstanceOp>([&](auto instance) { markInstanceOp(instance); })
        .Case<ObjectOp>([&](auto obj) { markObjectOp(obj); })
        .Case<MemOp>([&](auto mem) { markMemOp(mem); })
        // Layer bodies are live whenever the enclosing block is.
        .Case<LayerBlockOp>(
            [&](auto layer) { markBlockExecutable(layer.getBody(0)); })
        .Default([&](auto _) {
          if (isa<mlir::UnrealizedConversionCastOp, VerbatimExprOp,
                  VerbatimWireOp, SubaccessOp>(op) ||
              op.getNumOperands() == 0) {
            // Mark operations whose results cannot be tracked as overdefined.
            // Mark unhandled operations with no operand as well since otherwise
            // they will remain unknown states until the end.
            for (auto result : op.getResults())
              markOverdefined(result);
          } else if (
              // Operations that are handled when propagating values, or chasing
              // indexing.
              !isAggregate(&op) && !isNodeLike(&op) && op.getNumResults() > 0) {
            // If an unknown operation has an aggregate operand, mark results as
            // overdefined since we cannot track the dataflow. Similarly if the
            // operations create aggregate values, we mark them overdefined.

            // TODO: We should handle aggregate operations such as
            // vector_create, bundle_create or vector operations.

            bool hasAggregateOperand =
                llvm::any_of(op.getOperandTypes(), [](Type type) {
                  return type_isa<FVectorType, BundleType>(type);
                });

            for (auto result : op.getResults())
              if (hasAggregateOperand ||
                  type_isa<FVectorType, BundleType>(result.getType()))
                markOverdefined(result);
          }
        });

    // This tracks a dependency from field refs to operations which need
    // to be added to worklist when lattice values change. Aggregate-indexing
    // ops are skipped: users of their results are keyed by the resolved root
    // FieldRef (see getOrCacheFieldRefFromValue) instead.
    if (!isAggregate(&op)) {
      for (auto operand : op.getOperands()) {
        auto fieldRef = getOrCacheFieldRefFromValue(operand);
        auto firrtlType = type_dyn_cast<FIRRTLType>(operand.getType());
        if (!firrtlType)
          continue;
        // Special-handle PropertyType's, walkGroundTypes doesn't support.
        if (type_isa<PropertyType>(firrtlType)) {
          fieldRefToUsers[fieldRef].push_back(&op);
          continue;
        }
        walkGroundTypes(firrtlType, [&](uint64_t fieldID, auto type, auto) {
          fieldRefToUsers[fieldRef.getSubField(fieldID)].push_back(&op);
        });
      }
    }
  }
}
// NOLINTEND(misc-no-recursion)
540 
541 void IMConstPropPass::markWireOp(WireOp wire) {
542  auto type = type_dyn_cast<FIRRTLType>(wire.getResult().getType());
543  if (!type || hasDontTouch(wire.getResult()) || wire.isForceable()) {
544  for (auto result : wire.getResults())
545  markOverdefined(result);
546  return;
547  }
548 
549  // Otherwise, this starts out as unknown and is upgraded by connects.
550 }
551 
552 void IMConstPropPass::markMemOp(MemOp mem) {
553  for (auto result : mem.getResults())
554  markOverdefined(result);
555 }
556 
557 template <typename OpTy>
558 void IMConstPropPass::markConstantValueOp(OpTy op) {
559  mergeLatticeValue(getOrCacheFieldRefFromValue(op),
560  LatticeValue(op.getValueAttr()));
561 }
562 
563 void IMConstPropPass::markAggregateConstantOp(AggregateConstantOp constant) {
564  walkGroundTypes(constant.getType(), [&](uint64_t fieldID, auto, auto) {
565  mergeLatticeValue(FieldRef(constant, fieldID),
566  LatticeValue(cast<IntegerAttr>(
567  constant.getAttributeFromFieldID(fieldID))));
568  });
569 }
570 
571 void IMConstPropPass::markInvalidValueOp(InvalidValueOp invalid) {
572  markOverdefined(invalid.getResult());
573 }
574 
575 /// Instances have no operands, so they are visited exactly once when their
576 /// enclosing block is marked live. This sets up the def-use edges for ports.
577 void IMConstPropPass::markInstanceOp(InstanceOp instance) {
578  // Get the module being reference or a null pointer if this is an extmodule.
579  Operation *op = instance.getReferencedModule(*instanceGraph);
580 
581  // If this is an extmodule, just remember that any results and inouts are
582  // overdefined.
583  if (!isa<FModuleOp>(op)) {
584  auto module = dyn_cast<FModuleLike>(op);
585  for (size_t resultNo = 0, e = instance.getNumResults(); resultNo != e;
586  ++resultNo) {
587  auto portVal = instance.getResult(resultNo);
588  // If this is an input to the extmodule, we can ignore it.
589  if (module.getPortDirection(resultNo) == Direction::In)
590  continue;
591 
592  // Otherwise this is a result from it or an inout, mark it as overdefined.
593  markOverdefined(portVal);
594  }
595  return;
596  }
597 
598  // Otherwise this is a defined module.
599  auto fModule = cast<FModuleOp>(op);
600  markBlockExecutable(fModule.getBodyBlock());
601 
602  // Ok, it is a normal internal module reference. Populate
603  // resultPortToInstanceResultMapping, and forward any already-computed values.
604  for (size_t resultNo = 0, e = instance.getNumResults(); resultNo != e;
605  ++resultNo) {
606  auto instancePortVal = instance.getResult(resultNo);
607  // If this is an input to the instance, it will
608  // get handled when any connects to it are processed.
609  if (fModule.getPortDirection(resultNo) == Direction::In)
610  continue;
611 
612  // Otherwise we have a result from the instance. We need to forward results
613  // from the body to this instance result's SSA value, so remember it.
614  BlockArgument modulePortVal = fModule.getArgument(resultNo);
615 
616  resultPortToInstanceResultMapping[modulePortVal].push_back(instancePortVal);
617 
618  // If there is already a value known for modulePortVal make sure to forward
619  // it here.
620  mergeLatticeValue(instancePortVal, modulePortVal);
621  }
622 }
623 
624 void IMConstPropPass::markObjectOp(ObjectOp obj) {
625  // Mark overdefined for now, not supported.
626  markOverdefined(obj);
627 }
628 
629 static std::optional<uint64_t>
630 getFieldIDOffset(FieldRef changedFieldRef, Type connectionType,
631  FieldRef connectedValueFieldRef) {
632  assert(!type_isa<RefType>(connectionType));
633  if (changedFieldRef.getValue() != connectedValueFieldRef.getValue())
634  return {};
635  if (changedFieldRef.getFieldID() >= connectedValueFieldRef.getFieldID() &&
636  changedFieldRef.getFieldID() <=
637  hw::FieldIdImpl::getMaxFieldID(connectionType) +
638  connectedValueFieldRef.getFieldID())
639  return changedFieldRef.getFieldID() - connectedValueFieldRef.getFieldID();
640  return {};
641 }
642 
643 void IMConstPropPass::mergeOnlyChangedLatticeValue(Value dest, Value src,
644  FieldRef changedFieldRef) {
645 
646  // Operate on inner type for refs.
647  auto destType = dest.getType();
648  if (auto refType = type_dyn_cast<RefType>(destType))
649  destType = refType.getType();
650 
651  if (!isa<FIRRTLType>(destType)) {
652  // If the dest is not FIRRTL type, conservatively mark
653  // all of them overdefined.
654  markOverdefined(src);
655  return markOverdefined(dest);
656  }
657 
658  auto fieldRefSrc = getOrCacheFieldRefFromValue(src);
659  auto fieldRefDest = getOrCacheFieldRefFromValue(dest);
660 
661  // If a changed field ref is included the source value, find an offset in the
662  // connection.
663  if (auto srcOffset = getFieldIDOffset(changedFieldRef, destType, fieldRefSrc))
664  mergeLatticeValue(fieldRefDest.getSubField(*srcOffset),
665  fieldRefSrc.getSubField(*srcOffset));
666 
667  // If a changed field ref is included the dest value, find an offset in the
668  // connection.
669  if (auto destOffset =
670  getFieldIDOffset(changedFieldRef, destType, fieldRefDest))
671  mergeLatticeValue(fieldRefDest.getSubField(*destOffset),
672  fieldRefSrc.getSubField(*destOffset));
673 }
674 
675 void IMConstPropPass::visitConnectLike(FConnectLike connect,
676  FieldRef changedFieldRef) {
677  // Operate on inner type for refs.
678  auto destType = connect.getDest().getType();
679  if (auto refType = type_dyn_cast<RefType>(destType))
680  destType = refType.getType();
681 
682  // Mark foreign types as overdefined.
683  if (!isa<FIRRTLType>(destType)) {
684  markOverdefined(connect.getSrc());
685  return markOverdefined(connect.getDest());
686  }
687 
688  auto fieldRefSrc = getOrCacheFieldRefFromValue(connect.getSrc());
689  auto fieldRefDest = getOrCacheFieldRefFromValue(connect.getDest());
690  if (auto subaccess = fieldRefDest.getValue().getDefiningOp<SubaccessOp>()) {
691  // If the destination is subaccess, we give up to precisely track
692  // lattice values and mark entire aggregate as overdefined. This code
693  // should be dead unless we stop lowering of subaccess in LowerTypes.
694  Value parent = subaccess.getInput();
695  while (parent.getDefiningOp() &&
696  parent.getDefiningOp()->getNumOperands() > 0)
697  parent = parent.getDefiningOp()->getOperand(0);
698  return markOverdefined(parent);
699  }
700 
701  auto propagateElementLattice = [&](uint64_t fieldID, FIRRTLType destType) {
702  auto fieldRefDestConnected = fieldRefDest.getSubField(fieldID);
703  assert(!firrtl::type_isa<FIRRTLBaseType>(destType) ||
704  firrtl::type_cast<FIRRTLBaseType>(destType).isGround());
705 
706  // Handle implicit extensions.
707  auto srcValue =
708  getExtendedLatticeValue(fieldRefSrc.getSubField(fieldID), destType);
709  if (srcValue.isUnknown())
710  return;
711 
712  // Driving result ports propagates the value to each instance using the
713  // module.
714  if (auto blockArg = dyn_cast<BlockArgument>(fieldRefDest.getValue())) {
715  for (auto userOfResultPort : resultPortToInstanceResultMapping[blockArg])
716  mergeLatticeValue(
717  FieldRef(userOfResultPort, fieldRefDestConnected.getFieldID()),
718  srcValue);
719  // Output ports are wire-like and may have users.
720  return mergeLatticeValue(fieldRefDestConnected, srcValue);
721  }
722 
723  auto dest = cast<mlir::OpResult>(fieldRefDest.getValue());
724 
725  // For wires and registers, we drive the value of the wire itself, which
726  // automatically propagates to users.
727  if (isWireOrReg(dest.getOwner()))
728  return mergeLatticeValue(fieldRefDestConnected, srcValue);
729 
730  // Driving an instance argument port drives the corresponding argument
731  // of the referenced module.
732  if (auto instance = dest.getDefiningOp<InstanceOp>()) {
733  // Update the dest, when its an instance op.
734  mergeLatticeValue(fieldRefDestConnected, srcValue);
735  auto mod = instance.getReferencedModule<FModuleOp>(*instanceGraph);
736  if (!mod)
737  return;
738 
739  BlockArgument modulePortVal = mod.getArgument(dest.getResultNumber());
740 
741  return mergeLatticeValue(
742  FieldRef(modulePortVal, fieldRefDestConnected.getFieldID()),
743  srcValue);
744  }
745 
746  // Driving a memory result is ignored because these are always treated
747  // as overdefined.
748  if (dest.getDefiningOp<MemOp>())
749  return;
750 
751  // For now, don't support const prop into object fields.
752  if (isa_and_nonnull<ObjectSubfieldOp>(dest.getDefiningOp()))
753  return;
754 
755  connect.emitError("connectlike operation unhandled by IMConstProp")
756  .attachNote(connect.getDest().getLoc())
757  << "connect destination is here";
758  };
759 
760  if (auto srcOffset = getFieldIDOffset(changedFieldRef, destType, fieldRefSrc))
761  propagateElementLattice(
762  *srcOffset,
763  firrtl::type_cast<FIRRTLType>(
764  hw::FieldIdImpl::getFinalTypeByFieldID(destType, *srcOffset)));
765 
766  if (auto relativeDest =
767  getFieldIDOffset(changedFieldRef, destType, fieldRefDest))
768  propagateElementLattice(
769  *relativeDest,
770  firrtl::type_cast<FIRRTLType>(
771  hw::FieldIdImpl::getFinalTypeByFieldID(destType, *relativeDest)));
772 }
773 
774 void IMConstPropPass::visitRefSend(RefSendOp send, FieldRef changedFieldRef) {
775  // Send connects the base value (source) to the result (dest).
776  return mergeOnlyChangedLatticeValue(send.getResult(), send.getBase(),
777  changedFieldRef);
778 }
779 
780 void IMConstPropPass::visitRefResolve(RefResolveOp resolve,
781  FieldRef changedFieldRef) {
782  // Resolve connects the ref value (source) to result (dest).
783  // If writes are ever supported, this will need to work differently!
784  return mergeOnlyChangedLatticeValue(resolve.getResult(), resolve.getRef(),
785  changedFieldRef);
786 }
787 
788 void IMConstPropPass::visitNode(NodeOp node, FieldRef changedFieldRef) {
789  if (hasDontTouch(node.getResult()) || node.isForceable()) {
790  for (auto result : node.getResults())
791  markOverdefined(result);
792  return;
793  }
794 
795  return mergeOnlyChangedLatticeValue(node.getResult(), node.getInput(),
796  changedFieldRef);
797 }
798 
799 /// This method is invoked when an operand of the specified op changes its
800 /// lattice value state and when the block containing the operation is first
801 /// noticed as being alive.
802 ///
803 /// This should update the lattice value state for any result values.
804 ///
805 void IMConstPropPass::visitOperation(Operation *op, FieldRef changedField) {
806  // If this is a operation with special handling, handle it specially.
807  if (auto connectLikeOp = dyn_cast<FConnectLike>(op))
808  return visitConnectLike(connectLikeOp, changedField);
809  if (auto sendOp = dyn_cast<RefSendOp>(op))
810  return visitRefSend(sendOp, changedField);
811  if (auto resolveOp = dyn_cast<RefResolveOp>(op))
812  return visitRefResolve(resolveOp, changedField);
813  if (auto nodeOp = dyn_cast<NodeOp>(op))
814  return visitNode(nodeOp, changedField);
815 
816  // The clock operand of regop changing doesn't change its result value. All
817  // other registers are over-defined. Aggregate operations also doesn't change
818  // its result value.
819  if (isa<RegOp, RegResetOp>(op) || isAggregate(op))
820  return;
821  // TODO: Handle 'when' operations.
822 
823  // If all of the results of this operation are already overdefined (or if
824  // there are no results) then bail out early: we've converged.
825  auto isOverdefinedFn = [&](Value value) {
826  return isOverdefined(getOrCacheFieldRefFromValue(value));
827  };
828  if (llvm::all_of(op->getResults(), isOverdefinedFn))
829  return;
830 
831  // To prevent regressions, mark values as overdefined when they are defined
832  // by operations with a large number of operands.
833  if (op->getNumOperands() > 128) {
834  for (auto value : op->getResults())
835  markOverdefined(value);
836  return;
837  }
838 
839  // Collect all of the constant operands feeding into this operation. If any
840  // are not ready to be resolved, bail out and wait for them to resolve.
841  SmallVector<Attribute, 8> operandConstants;
842  operandConstants.reserve(op->getNumOperands());
843  bool hasUnknown = false;
844  for (Value operand : op->getOperands()) {
845 
846  auto &operandLattice = latticeValues[getOrCacheFieldRefFromValue(operand)];
847 
848  // If the operand is an unknown value, then we generally don't want to
849  // process it - we want to wait until the value is resolved to by the SCCP
850  // algorithm.
851  if (operandLattice.isUnknown())
852  hasUnknown = true;
853 
854  // Otherwise, it must be constant, invalid, or overdefined. Translate them
855  // into attributes that the fold hook can look at.
856  if (operandLattice.isConstant())
857  operandConstants.push_back(operandLattice.getValue());
858  else
859  operandConstants.push_back({});
860  }
861 
862  // Simulate the result of folding this operation to a constant. If folding
863  // fails mark the results as overdefined.
864  SmallVector<OpFoldResult, 8> foldResults;
865  foldResults.reserve(op->getNumResults());
866  if (failed(op->fold(operandConstants, foldResults))) {
867  LLVM_DEBUG({
868  logger.startLine() << "Folding Failed operation : '" << op->getName()
869  << "\n";
870  op->dump();
871  });
872  // If we had unknown arguments, hold off on overdefining
873  if (!hasUnknown)
874  for (auto value : op->getResults())
875  markOverdefined(value);
876  return;
877  }
878 
879  LLVM_DEBUG({
880  logger.getOStream() << "\n";
881  logger.startLine() << "Folding operation : '" << op->getName() << "\n";
882  op->dump();
883  logger.getOStream() << "( ";
884  for (auto cst : operandConstants)
885  if (!cst)
886  logger.getOStream() << "{} ";
887  else
888  logger.getOStream() << cst << " ";
889  logger.unindent();
890  logger.getOStream() << ") -> { ";
891  logger.indent();
892  for (auto &r : foldResults) {
893  logger.getOStream() << r << " ";
894  }
895  logger.unindent();
896  logger.getOStream() << "}\n";
897  });
898 
899  // If the folding was in-place, keep going. This is surprising, but since
900  // only folder that will do in-place updates is the commutative folder, we
901  // aren't going to stop. We don't update the results, since they didn't
902  // change, the op just got shuffled around.
903  if (foldResults.empty())
904  return visitOperation(op, changedField);
905 
906  // Merge the fold results into the lattice for this operation.
907  assert(foldResults.size() == op->getNumResults() && "invalid result size");
908  for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
909  // Merge in the result of the fold, either a constant or a value.
910  LatticeValue resultLattice;
911  OpFoldResult foldResult = foldResults[i];
912  if (Attribute foldAttr = dyn_cast<Attribute>(foldResult)) {
913  if (auto intAttr = dyn_cast<IntegerAttr>(foldAttr))
914  resultLattice = LatticeValue(intAttr);
915  else if (auto strAttr = dyn_cast<StringAttr>(foldAttr))
916  resultLattice = LatticeValue(strAttr);
917  else // Treat unsupported constants as overdefined.
918  resultLattice = LatticeValue::getOverdefined();
919  } else { // Folding to an operand results in its value.
920  resultLattice =
921  latticeValues[getOrCacheFieldRefFromValue(foldResult.get<Value>())];
922  }
923 
924  mergeLatticeValue(getOrCacheFieldRefFromValue(op->getResult(i)),
925  resultLattice);
926  }
927 }
928 
929 void IMConstPropPass::rewriteModuleBody(FModuleOp module) {
930  auto *body = module.getBodyBlock();
931  // If a module is unreachable, just ignore it.
932  if (!executableBlocks.count(body))
933  return;
934 
935  auto builder = OpBuilder::atBlockBegin(body);
936 
937  // Separate the constants we insert from the instructions we are folding and
938  // processing. Leave these as-is until we're done.
939  auto cursor = builder.create<firrtl::ConstantOp>(module.getLoc(), APSInt(1));
940  builder.setInsertionPoint(cursor);
941 
942  // Unique constants per <Const,Type> pair, inserted at entry
943  DenseMap<std::pair<Attribute, Type>, Operation *> constPool;
944 
945  std::function<Value(Attribute, Type, Location)> getConst =
946  [&](Attribute constantValue, Type type, Location loc) -> Value {
947  auto constIt = constPool.find({constantValue, type});
948  if (constIt != constPool.end()) {
949  auto *cst = constIt->second;
950  // Add location to the constant
951  cst->setLoc(builder.getFusedLoc({cst->getLoc(), loc}));
952  return cst->getResult(0);
953  }
954  OpBuilder::InsertionGuard x(builder);
955  builder.setInsertionPoint(cursor);
956 
957  // Materialize reftype "constants" by materializing the constant
958  // and probing it.
959  Operation *cst;
960  if (auto refType = type_dyn_cast<RefType>(type)) {
961  assert(!type_cast<RefType>(type).getForceable() &&
962  "Attempting to materialize rwprobe of constant, shouldn't happen");
963  auto inner = getConst(constantValue, refType.getType(), loc);
964  assert(inner);
965  cst = builder.create<RefSendOp>(loc, inner);
966  } else
967  cst = module->getDialect()->materializeConstant(builder, constantValue,
968  type, loc);
969  assert(cst && "all FIRRTL constants can be materialized");
970  constPool.insert({{constantValue, type}, cst});
971  return cst->getResult(0);
972  };
973 
974  // If the lattice value for the specified value is a constant update it and
975  // return true. Otherwise return false.
976  auto replaceValueIfPossible = [&](Value value) -> bool {
977  // Lambda to replace all uses of this value a replacement, unless this is
978  // the destination of a connect. We leave connects alone to avoid upsetting
979  // flow, i.e., to avoid trying to connect to a constant.
980  auto replaceIfNotConnect = [&value](Value replacement) {
981  value.replaceUsesWithIf(replacement, [](OpOperand &operand) {
982  return !isa<FConnectLike>(operand.getOwner()) ||
983  operand.getOperandNumber() != 0;
984  });
985  };
986 
987  // TODO: Replace entire aggregate.
988  auto it = latticeValues.find(getFieldRefFromValue(value));
989  if (it == latticeValues.end() || it->second.isOverdefined() ||
990  it->second.isUnknown())
991  return false;
992 
993  // Cannot materialize constants for certain types.
994  // TODO: Let materializeConstant tell us what it supports instead of this.
995  // Presently it asserts on unsupported combinations, so check this here.
996  if (!type_isa<FIRRTLBaseType, RefType, FIntegerType, StringType, BoolType>(
997  value.getType()))
998  return false;
999 
1000  auto cstValue =
1001  getConst(it->second.getValue(), value.getType(), value.getLoc());
1002 
1003  replaceIfNotConnect(cstValue);
1004  return true;
1005  };
1006 
1007  // Constant propagate any ports that are always constant.
1008  for (auto &port : body->getArguments())
1009  replaceValueIfPossible(port);
1010 
1011  // Walk the IR bottom-up when folding. We often fold entire chains of
1012  // operations into constants, which make the intermediate nodes dead. Going
1013  // bottom up eliminates the users of the intermediate ops, allowing us to
1014  // aggressively delete them.
1015  //
1016  // TODO: Handle WhenOps correctly.
1017  bool aboveCursor = false;
1018  module.walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
1019  [&](Operation *op) {
1020  auto dropIfDead = [&](Operation *op, const Twine &debugPrefix) {
1021  if (op->use_empty() &&
1022  (wouldOpBeTriviallyDead(op) || isDeletableWireOrRegOrNode(op))) {
1023  LLVM_DEBUG(
1024  { logger.getOStream() << debugPrefix << " : " << op << "\n"; });
1025  ++numErasedOp;
1026  op->erase();
1027  return true;
1028  }
1029  return false;
1030  };
1031 
1032  if (aboveCursor) {
1033  // Drop dead constants we materialized.
1034  dropIfDead(op, "Trivially dead materialized constant");
1035  return WalkResult::advance();
1036  }
1037  // Stop once hit the generated constants.
1038  if (op == cursor) {
1039  cursor.erase();
1040  aboveCursor = true;
1041  return WalkResult::advance();
1042  }
1043 
1044  // Connects to values that we found to be constant can be dropped.
1045  if (auto connect = dyn_cast<FConnectLike>(op)) {
1046  if (auto *destOp = connect.getDest().getDefiningOp()) {
1047  auto fieldRef = getOrCacheFieldRefFromValue(connect.getDest());
1048  // Don't remove a field-level connection even if the src value is
1049  // constant. If other elements of the aggregate value are not
1050  // constant, the aggregate value cannot be replaced. We can forward
1051  // the constant to its users, so IMDCE (or SV/HW canonicalizer)
1052  // should remove the aggregate if entire aggregate is dead.
1053  auto type = type_dyn_cast<FIRRTLType>(connect.getDest().getType());
1054  if (!type)
1055  return WalkResult::advance();
1056  auto baseType = type_dyn_cast<FIRRTLBaseType>(type);
1057  if (baseType && !baseType.isGround())
1058  return WalkResult::advance();
1059  if (isDeletableWireOrRegOrNode(destOp) &&
1060  !isOverdefined(fieldRef)) {
1061  connect.erase();
1062  ++numErasedOp;
1063  }
1064  }
1065  return WalkResult::advance();
1066  }
1067 
1068  // We only fold single-result ops and instances in practice, because
1069  // they are the expressions.
1070  if (op->getNumResults() != 1 && !isa<InstanceOp>(op))
1071  return WalkResult::advance();
1072 
1073  // If this operation is already dead, then go ahead and remove it.
1074  if (dropIfDead(op, "Trivially dead"))
1075  return WalkResult::advance();
1076 
1077  // Don't "fold" constants (into equivalent), also because they
1078  // may have name hints we'd like to preserve.
1079  if (op->hasTrait<mlir::OpTrait::ConstantLike>())
1080  return WalkResult::advance();
1081 
1082  // If the op had any constants folded, replace them.
1083  builder.setInsertionPoint(op);
1084  bool foldedAny = false;
1085  for (auto result : op->getResults())
1086  foldedAny |= replaceValueIfPossible(result);
1087 
1088  if (foldedAny)
1089  ++numFoldedOp;
1090 
1091  // If the operation folded to a constant then we can probably nuke it.
1092  if (foldedAny && dropIfDead(op, "Made dead"))
1093  return WalkResult::advance();
1094 
1095  return WalkResult::advance();
1096  });
1097 }
1098 
1099 std::unique_ptr<mlir::Pass> circt::firrtl::createIMConstPropPass() {
1100  return std::make_unique<IMConstPropPass>();
1101 }
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static bool isNodeLike(Operation *op)
Definition: IMConstProp.cpp:54
static std::optional< uint64_t > getFieldIDOffset(FieldRef changedFieldRef, Type connectionType, FieldRef connectedValueFieldRef)
static bool isWireOrReg(Operation *op)
Return true if this is a wire or register.
Definition: IMConstProp.cpp:42
static bool isAggregate(Operation *op)
Return true if this is an aggregate indexer.
Definition: IMConstProp.cpp:47
static bool isDeletableWireOrRegOrNode(Operation *op)
Return true if this is a wire or register we're allowed to delete.
Definition: IMConstProp.cpp:59
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
FieldRef getSubField(unsigned subFieldID) const
Get a reference to a subfield.
Definition: FieldRef.h:62
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Definition: FieldRef.h:59
Value getValue() const
Get the Value which created this location.
Definition: FieldRef.h:37
Location getLoc() const
Get the location associated with the value of this field ref.
Definition: FieldRef.h:67
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This graph tracks modules and where they are instantiated.
def connect(destination, source)
Definition: support.py:39
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
Definition: FIRRTLUtils.h:222
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
void walkGroundTypes(FIRRTLType firrtlType, llvm::function_ref< void(uint64_t, FIRRTLBaseType, bool)> fn)
Walk leaf ground types in the firrtlType and apply the function fn.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
Definition: FIRRTLOps.cpp:4588
std::unique_ptr< mlir::Pass > createIMConstPropPass()
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
Definition: FIRRTLOps.cpp:317
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
uint64_t getMaxFieldID(Type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition: APInt.cpp:22
bool operator==(uint64_t a, const FVInt &b)
Definition: FVInt.h:650
bool operator!=(uint64_t a, const FVInt &b)
Definition: FVInt.h:651
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:21