CIRCT  19.0.0git
IMConstProp.cpp
Go to the documentation of this file.
1 //===- IMConstProp.cpp - Intermodule ConstProp and DCE ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This implements SCCP:
10 // https://www.cs.wustl.edu/~cytron/531Pages/f11/Resources/Papers/cprop.pdf
11 //
12 //===----------------------------------------------------------------------===//
13 
21 #include "circt/Support/APInt.h"
22 #include "mlir/IR/Threading.h"
23 #include "mlir/Pass/Pass.h"
24 #include "llvm/ADT/APSInt.h"
25 #include "llvm/ADT/TinyPtrVector.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/ScopedPrinter.h"
28 
29 namespace circt {
30 namespace firrtl {
31 #define GEN_PASS_DEF_IMCONSTPROP
32 #include "circt/Dialect/FIRRTL/Passes.h.inc"
33 } // namespace firrtl
34 } // namespace circt
35 
36 using namespace circt;
37 using namespace firrtl;
38 
39 #define DEBUG_TYPE "IMCP"
40 
41 /// Return true if this is a wire or register.
42 static bool isWireOrReg(Operation *op) {
43  return isa<WireOp, RegResetOp, RegOp>(op);
44 }
45 
46 /// Return true if this is an aggregate indexer.
47 static bool isAggregate(Operation *op) {
48  return isa<SubindexOp, SubaccessOp, SubfieldOp, OpenSubfieldOp,
49  OpenSubindexOp, RefSubOp>(op);
50 }
51 
52 // Return true if this forwards input to output.
53 // Implies has appropriate visit method that propagates changes.
54 static bool isNodeLike(Operation *op) {
55  return isa<NodeOp, RefResolveOp, RefSendOp>(op);
56 }
57 
58 /// Return true if this is a wire or register we're allowed to delete.
59 static bool isDeletableWireOrRegOrNode(Operation *op) {
60  if (!isWireOrReg(op) && !isa<NodeOp>(op))
61  return false;
62 
63  // Always allow deleting wires of probe-type.
64  if (type_isa<RefType>(op->getResult(0).getType()))
65  return true;
66 
67  // Otherwise, don't delete if has anything keeping it around or unknown.
68  return AnnotationSet(op).canBeDeleted() && !hasDontTouch(op) &&
69  hasDroppableName(op) && !cast<Forceable>(op).isForceable();
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // Pass Infrastructure
74 //===----------------------------------------------------------------------===//
75 
76 namespace {
77 /// This class represents a single lattice value. A lattive value corresponds to
78 /// the various different states that a value in the SCCP dataflow analysis can
79 /// take. See 'Kind' below for more details on the different states a value can
80 /// take.
81 class LatticeValue {
82  enum Kind {
83  /// A value with a yet-to-be-determined value. This state may be changed to
84  /// anything, it hasn't been processed by IMConstProp.
85  Unknown,
86 
87  /// A value that is known to be a constant. This state may be changed to
88  /// overdefined.
89  Constant,
90 
91  /// A value that cannot statically be determined to be a constant. This
92  /// state cannot be changed.
93  Overdefined
94  };
95 
96 public:
97  /// Initialize a lattice value with "Unknown".
98  /*implicit*/ LatticeValue() : valueAndTag(nullptr, Kind::Unknown) {}
99  /// Initialize a lattice value with a constant.
100  /*implicit*/ LatticeValue(IntegerAttr attr)
101  : valueAndTag(attr, Kind::Constant) {}
102  /*implicit*/ LatticeValue(StringAttr attr)
103  : valueAndTag(attr, Kind::Constant) {}
104 
105  static LatticeValue getOverdefined() {
106  LatticeValue result;
107  result.markOverdefined();
108  return result;
109  }
110 
111  bool isUnknown() const { return valueAndTag.getInt() == Kind::Unknown; }
112  bool isConstant() const { return valueAndTag.getInt() == Kind::Constant; }
113  bool isOverdefined() const {
114  return valueAndTag.getInt() == Kind::Overdefined;
115  }
116 
117  /// Mark the lattice value as overdefined.
118  void markOverdefined() {
119  valueAndTag.setPointerAndInt(nullptr, Kind::Overdefined);
120  }
121 
122  /// Mark the lattice value as constant.
123  void markConstant(IntegerAttr value) {
124  valueAndTag.setPointerAndInt(value, Kind::Constant);
125  }
126 
127  /// If this lattice is constant or invalid value, return the attribute.
128  /// Returns nullptr otherwise.
129  Attribute getValue() const { return valueAndTag.getPointer(); }
130 
131  /// If this is in the constant state, return the attribute.
132  Attribute getConstant() const {
133  assert(isConstant());
134  return getValue();
135  }
136 
137  /// Merge in the value of the 'rhs' lattice into this one. Returns true if the
138  /// lattice value changed.
139  bool mergeIn(LatticeValue rhs) {
140  // If we are already overdefined, or rhs is unknown, there is nothing to do.
141  if (isOverdefined() || rhs.isUnknown())
142  return false;
143 
144  // If we are unknown, just take the value of rhs.
145  if (isUnknown()) {
146  valueAndTag = rhs.valueAndTag;
147  return true;
148  }
149 
150  // Otherwise, if this value doesn't match rhs go straight to overdefined.
151  // This happens when we merge "3" and "4" from two different instance sites
152  // for example.
153  if (valueAndTag != rhs.valueAndTag) {
154  markOverdefined();
155  return true;
156  }
157  return false;
158  }
159 
160  bool operator==(const LatticeValue &other) const {
161  return valueAndTag == other.valueAndTag;
162  }
163  bool operator!=(const LatticeValue &other) const {
164  return valueAndTag != other.valueAndTag;
165  }
166 
167 private:
168  /// The attribute value if this is a constant and the tag for the element
169  /// kind. The attribute is an IntegerAttr (or BoolAttr) or StringAttr.
170  llvm::PointerIntPair<Attribute, 2, Kind> valueAndTag;
171 };
172 } // end anonymous namespace
173 
174 LLVM_ATTRIBUTE_USED
175 static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
176  const LatticeValue &lattice) {
177  if (lattice.isUnknown())
178  return os << "<Unknown>";
179  if (lattice.isOverdefined())
180  return os << "<Overdefined>";
181  return os << "<" << lattice.getConstant() << ">";
182 }
183 
184 namespace {
185 struct IMConstPropPass
186  : public circt::firrtl::impl::IMConstPropBase<IMConstPropPass> {
187 
188  void runOnOperation() override;
189  void rewriteModuleBody(FModuleOp module);
190 
191  /// Returns true if the given block is executable.
192  bool isBlockExecutable(Block *block) const {
193  return executableBlocks.count(block);
194  }
195 
196  bool isOverdefined(FieldRef value) const {
197  auto it = latticeValues.find(value);
198  return it != latticeValues.end() && it->second.isOverdefined();
199  }
200 
201  // Mark the given value as overdefined. If the value is an aggregate,
202  // we mark all ground elements as overdefined.
203  void markOverdefined(Value value) {
204  FieldRef fieldRef = getOrCacheFieldRefFromValue(value);
205  auto firrtlType = type_dyn_cast<FIRRTLType>(value.getType());
206  if (!firrtlType || type_isa<PropertyType>(firrtlType)) {
207  markOverdefined(fieldRef);
208  return;
209  }
210 
211  walkGroundTypes(firrtlType, [&](uint64_t fieldID, auto, auto) {
212  markOverdefined(fieldRef.getSubField(fieldID));
213  });
214  }
215 
216  /// Mark the given value as overdefined. This means that we cannot refine a
217  /// specific constant for this value.
218  void markOverdefined(FieldRef value) {
219  auto &entry = latticeValues[value];
220  if (!entry.isOverdefined()) {
221  LLVM_DEBUG({
222  logger.getOStream()
223  << "Setting overdefined : (" << getFieldName(value).first << ")\n";
224  });
225  entry.markOverdefined();
226  changedLatticeValueWorklist.push_back(value);
227  }
228  }
229 
230  /// Merge information from the 'from' lattice value into value. If it
231  /// changes, then users of the value are added to the worklist for
232  /// revisitation.
233  void mergeLatticeValue(FieldRef value, LatticeValue &valueEntry,
234  LatticeValue source) {
235  if (valueEntry.mergeIn(source)) {
236  LLVM_DEBUG({
237  logger.getOStream()
238  << "Changed to " << valueEntry << " : (" << value << ")\n";
239  });
240  changedLatticeValueWorklist.push_back(value);
241  }
242  }
243 
244  void mergeLatticeValue(FieldRef value, LatticeValue source) {
245  // Don't even do a map lookup if from has no info in it.
246  if (source.isUnknown())
247  return;
248  mergeLatticeValue(value, latticeValues[value], source);
249  }
250 
251  void mergeLatticeValue(FieldRef result, FieldRef from) {
252  // If 'from' hasn't been computed yet, then it is unknown, don't do
253  // anything.
254  auto it = latticeValues.find(from);
255  if (it == latticeValues.end())
256  return;
257  mergeLatticeValue(result, it->second);
258  }
259 
260  void mergeLatticeValue(Value result, Value from) {
261  FieldRef fieldRefFrom = getOrCacheFieldRefFromValue(from);
262  FieldRef fieldRefResult = getOrCacheFieldRefFromValue(result);
263  if (!type_isa<FIRRTLType>(result.getType()))
264  return mergeLatticeValue(fieldRefResult, fieldRefFrom);
265  // Special-handle PropertyType's, walkGroundType's doesn't support.
266  if (type_isa<PropertyType>(result.getType()))
267  return mergeLatticeValue(fieldRefResult, fieldRefFrom);
268  walkGroundTypes(type_cast<FIRRTLType>(result.getType()),
269  [&](uint64_t fieldID, auto, auto) {
270  mergeLatticeValue(fieldRefResult.getSubField(fieldID),
271  fieldRefFrom.getSubField(fieldID));
272  });
273  }
274 
275  /// setLatticeValue - This is used when a new LatticeValue is computed for
276  /// the result of the specified value that replaces any previous knowledge,
277  /// e.g. because a fold() function on an op returned a new thing. This should
278  /// not be used on operations that have multiple contributors to it, e.g.
279  /// wires or ports.
280  void setLatticeValue(FieldRef value, LatticeValue source) {
281  // Don't even do a map lookup if from has no info in it.
282  if (source.isUnknown())
283  return;
284 
285  // If we've changed this value then revisit all the users.
286  auto &valueEntry = latticeValues[value];
287  if (valueEntry != source) {
288  changedLatticeValueWorklist.push_back(value);
289  valueEntry = source;
290  }
291  }
292 
293  // This function returns a field ref of the given value. This function caches
294  // the result to avoid extra IR traversal if the value is an aggregate
295  // element.
296  FieldRef getOrCacheFieldRefFromValue(Value value) {
297  if (!value.getDefiningOp() || !isAggregate(value.getDefiningOp()))
298  return FieldRef(value, 0);
299  auto &fieldRef = valueToFieldRef[value];
300  if (fieldRef)
301  return fieldRef;
302  return fieldRef = getFieldRefFromValue(value);
303  }
304 
305  /// Return the lattice value for the specified SSA value, extended to the
306  /// width of the specified destType. If allowTruncation is true, then this
307  /// allows truncating the lattice value to the specified type.
308  LatticeValue getExtendedLatticeValue(FieldRef value, FIRRTLType destType,
309  bool allowTruncation = false);
310 
311  /// Mark the given block as executable.
312  void markBlockExecutable(Block *block);
313  void markWireOp(WireOp wireOrReg);
314  void markMemOp(MemOp mem);
315 
316  void markInvalidValueOp(InvalidValueOp invalid);
317  void markAggregateConstantOp(AggregateConstantOp constant);
318  void markInstanceOp(InstanceOp instance);
319  void markObjectOp(ObjectOp object);
320  template <typename OpTy>
321  void markConstantValueOp(OpTy op);
322 
323  void visitConnectLike(FConnectLike connect, FieldRef changedFieldRef);
324  void visitRefSend(RefSendOp send, FieldRef changedFieldRef);
325  void visitRefResolve(RefResolveOp resolve, FieldRef changedFieldRef);
326  void mergeOnlyChangedLatticeValue(Value dest, Value src,
327  FieldRef changedFieldRef);
328  void visitNode(NodeOp node, FieldRef changedFieldRef);
329  void visitOperation(Operation *op, FieldRef changedFieldRef);
330 
331 private:
332  /// This is the current instance graph for the Circuit.
333  InstanceGraph *instanceGraph = nullptr;
334 
335  /// This keeps track of the current state of each tracked value.
336  DenseMap<FieldRef, LatticeValue> latticeValues;
337 
338  /// The set of blocks that are known to execute, or are intrinsically live.
339  SmallPtrSet<Block *, 16> executableBlocks;
340 
341  /// A worklist of values whose LatticeValue recently changed, indicating the
342  /// users need to be reprocessed.
343  SmallVector<FieldRef, 64> changedLatticeValueWorklist;
344 
345  // A map to give operations to be reprocessed.
346  DenseMap<FieldRef, llvm::TinyPtrVector<Operation *>> fieldRefToUsers;
347 
348  // A map to cache results of getFieldRefFromValue since it's costly traverse
349  // the IR.
350  llvm::DenseMap<Value, FieldRef> valueToFieldRef;
351 
352  /// This keeps track of users the instance results that correspond to output
353  /// ports.
354  DenseMap<BlockArgument, llvm::TinyPtrVector<Value>>
355  resultPortToInstanceResultMapping;
356 
357 #ifndef NDEBUG
358  /// A logger used to emit information during the application process.
359  llvm::ScopedPrinter logger{llvm::dbgs()};
360 #endif
361 };
362 } // end anonymous namespace
363 
364 // TODO: handle annotations: [[OptimizableExtModuleAnnotation]]
365 void IMConstPropPass::runOnOperation() {
366  auto circuit = getOperation();
367  LLVM_DEBUG(
368  { logger.startLine() << "IMConstProp : " << circuit.getName() << "\n"; });
369 
370  instanceGraph = &getAnalysis<InstanceGraph>();
371 
372  // Mark the input ports of public modules as being overdefined.
373  for (auto module : circuit.getBodyBlock()->getOps<FModuleOp>()) {
374  if (module.isPublic()) {
375  markBlockExecutable(module.getBodyBlock());
376  for (auto port : module.getBodyBlock()->getArguments())
377  markOverdefined(port);
378  }
379  }
380 
381  // If a value changed lattice state then reprocess any of its users.
382  while (!changedLatticeValueWorklist.empty()) {
383  FieldRef changedFieldRef = changedLatticeValueWorklist.pop_back_val();
384  for (Operation *user : fieldRefToUsers[changedFieldRef]) {
385  if (isBlockExecutable(user->getBlock()))
386  visitOperation(user, changedFieldRef);
387  }
388  }
389 
390  // Rewrite any constants in the modules.
391  mlir::parallelForEach(circuit.getContext(),
392  circuit.getBodyBlock()->getOps<FModuleOp>(),
393  [&](auto op) { rewriteModuleBody(op); });
394 
395  // Clean up our state for next time.
396  instanceGraph = nullptr;
397  latticeValues.clear();
398  executableBlocks.clear();
399  assert(changedLatticeValueWorklist.empty());
400  fieldRefToUsers.clear();
401  valueToFieldRef.clear();
402  resultPortToInstanceResultMapping.clear();
403 }
404 
405 /// Return the lattice value for the specified SSA value, extended to the width
406 /// of the specified destType. If allowTruncation is true, then this allows
407 /// truncating the lattice value to the specified type.
408 LatticeValue IMConstPropPass::getExtendedLatticeValue(FieldRef value,
409  FIRRTLType destType,
410  bool allowTruncation) {
411  // If 'value' hasn't been computed yet, then it is unknown.
412  auto it = latticeValues.find(value);
413  if (it == latticeValues.end())
414  return LatticeValue();
415 
416  auto result = it->second;
417  // Unknown/overdefined stay whatever they are.
418  if (result.isUnknown() || result.isOverdefined())
419  return result;
420 
421  // No extOrTrunc for property types. Return what we have.
422  if (isa<PropertyType>(destType))
423  return result;
424 
425  auto constant = result.getConstant();
426 
427  // If not property, only support integers.
428  auto intAttr = dyn_cast<IntegerAttr>(constant);
429  assert(intAttr && "unsupported lattice attribute kind");
430  if (!intAttr)
431  return result;
432 
433  // No extOrTrunc necessary for bools.
434  if (auto boolAttr = dyn_cast<BoolAttr>(intAttr))
435  return result;
436 
437  // Non-base (or non-ref) types are overdefined.
438  auto baseType = getBaseType(destType);
439  if (!baseType)
440  return LatticeValue::getOverdefined();
441 
442  // If destType is wider than the source constant type, extend it.
443  auto resultConstant = intAttr.getAPSInt();
444  auto destWidth = baseType.getBitWidthOrSentinel();
445  if (destWidth == -1) // We don't support unknown width FIRRTL.
446  return LatticeValue::getOverdefined();
447  if (resultConstant.getBitWidth() == (unsigned)destWidth)
448  return result; // Already the right width, we're done.
449 
450  // Otherwise, extend the constant using the signedness of the source.
451  resultConstant = extOrTruncZeroWidth(resultConstant, destWidth);
452  return LatticeValue(IntegerAttr::get(destType.getContext(), resultConstant));
453 }
454 
455 // NOLINTBEGIN(misc-no-recursion)
456 /// Mark a block executable if it isn't already. This does an initial scan of
457 /// the block, processing nullary operations like wires, instances, and
458 /// constants that only get processed once.
459 void IMConstPropPass::markBlockExecutable(Block *block) {
460  if (!executableBlocks.insert(block).second)
461  return; // Already executable.
462 
463  // Mark block arguments, which are module ports, with don't touch as
464  // overdefined.
465  for (auto ba : block->getArguments())
466  if (hasDontTouch(ba))
467  markOverdefined(ba);
468 
469  for (auto &op : *block) {
470  // Handle each of the special operations in the firrtl dialect.
471  TypeSwitch<Operation *>(&op)
472  .Case<RegOp, RegResetOp>(
473  [&](auto reg) { markOverdefined(op.getResult(0)); })
474  .Case<WireOp>([&](auto wire) { markWireOp(wire); })
475  .Case<ConstantOp, SpecialConstantOp, StringConstantOp,
476  FIntegerConstantOp, BoolConstantOp>(
477  [&](auto constOp) { markConstantValueOp(constOp); })
478  .Case<AggregateConstantOp>(
479  [&](auto aggConstOp) { markAggregateConstantOp(aggConstOp); })
480  .Case<InvalidValueOp>(
481  [&](auto invalid) { markInvalidValueOp(invalid); })
482  .Case<InstanceOp>([&](auto instance) { markInstanceOp(instance); })
483  .Case<ObjectOp>([&](auto obj) { markObjectOp(obj); })
484  .Case<MemOp>([&](auto mem) { markMemOp(mem); })
485  .Default([&](auto _) {
486  if (isa<mlir::UnrealizedConversionCastOp, VerbatimExprOp,
487  VerbatimWireOp, SubaccessOp>(op) ||
488  op.getNumOperands() == 0) {
489  // Mark operations whose results cannot be tracked as overdefined.
490  // Mark unhandled operations with no operand as well since otherwise
491  // they will remain unknown states until the end.
492  for (auto result : op.getResults())
493  markOverdefined(result);
494  } else if (
495  // Operations that are handled when propagating values, or chasing
496  // indexing.
497  !isAggregate(&op) && !isNodeLike(&op) && op.getNumResults() > 0) {
498  // If an unknown operation has an aggregate operand, mark results as
499  // overdefined since we cannot track the dataflow. Similarly if the
500  // operations create aggregate values, we mark them overdefined.
501 
502  // TODO: We should handle aggregate operations such as
503  // vector_create, bundle_create or vector operations.
504 
505  bool hasAggregateOperand =
506  llvm::any_of(op.getOperandTypes(), [](Type type) {
507  return type_isa<FVectorType, BundleType>(type);
508  });
509 
510  for (auto result : op.getResults())
511  if (hasAggregateOperand ||
512  type_isa<FVectorType, BundleType>(result.getType()))
513  markOverdefined(result);
514  }
515  });
516 
517  // This tracks a dependency from field refs to operations which need
518  // to be added to worklist when lattice values change.
519  if (!isAggregate(&op)) {
520  for (auto operand : op.getOperands()) {
521  auto fieldRef = getOrCacheFieldRefFromValue(operand);
522  auto firrtlType = type_dyn_cast<FIRRTLType>(operand.getType());
523  if (!firrtlType)
524  continue;
525  // Special-handle PropertyType's, walkGroundTypes doesn't support.
526  if (type_isa<PropertyType>(firrtlType)) {
527  fieldRefToUsers[fieldRef].push_back(&op);
528  continue;
529  }
530  walkGroundTypes(firrtlType, [&](uint64_t fieldID, auto type, auto) {
531  fieldRefToUsers[fieldRef.getSubField(fieldID)].push_back(&op);
532  });
533  }
534  }
535  }
536 }
537 // NOLINTEND(misc-no-recursion)
538 
539 void IMConstPropPass::markWireOp(WireOp wire) {
540  auto type = type_dyn_cast<FIRRTLType>(wire.getResult().getType());
541  if (!type || hasDontTouch(wire.getResult()) || wire.isForceable()) {
542  for (auto result : wire.getResults())
543  markOverdefined(result);
544  return;
545  }
546 
547  // Otherwise, this starts out as unknown and is upgraded by connects.
548 }
549 
550 void IMConstPropPass::markMemOp(MemOp mem) {
551  for (auto result : mem.getResults())
552  markOverdefined(result);
553 }
554 
555 template <typename OpTy>
556 void IMConstPropPass::markConstantValueOp(OpTy op) {
557  mergeLatticeValue(getOrCacheFieldRefFromValue(op),
558  LatticeValue(op.getValueAttr()));
559 }
560 
561 void IMConstPropPass::markAggregateConstantOp(AggregateConstantOp constant) {
562  walkGroundTypes(constant.getType(), [&](uint64_t fieldID, auto, auto) {
563  mergeLatticeValue(FieldRef(constant, fieldID),
564  LatticeValue(cast<IntegerAttr>(
565  constant.getAttributeFromFieldID(fieldID))));
566  });
567 }
568 
569 void IMConstPropPass::markInvalidValueOp(InvalidValueOp invalid) {
570  markOverdefined(invalid.getResult());
571 }
572 
573 /// Instances have no operands, so they are visited exactly once when their
574 /// enclosing block is marked live. This sets up the def-use edges for ports.
575 void IMConstPropPass::markInstanceOp(InstanceOp instance) {
576  // Get the module being reference or a null pointer if this is an extmodule.
577  Operation *op = instance.getReferencedModule(*instanceGraph);
578 
579  // If this is an extmodule, just remember that any results and inouts are
580  // overdefined.
581  if (!isa<FModuleOp>(op)) {
582  auto module = dyn_cast<FModuleLike>(op);
583  for (size_t resultNo = 0, e = instance.getNumResults(); resultNo != e;
584  ++resultNo) {
585  auto portVal = instance.getResult(resultNo);
586  // If this is an input to the extmodule, we can ignore it.
587  if (module.getPortDirection(resultNo) == Direction::In)
588  continue;
589 
590  // Otherwise this is a result from it or an inout, mark it as overdefined.
591  markOverdefined(portVal);
592  }
593  return;
594  }
595 
596  // Otherwise this is a defined module.
597  auto fModule = cast<FModuleOp>(op);
598  markBlockExecutable(fModule.getBodyBlock());
599 
600  // Ok, it is a normal internal module reference. Populate
601  // resultPortToInstanceResultMapping, and forward any already-computed values.
602  for (size_t resultNo = 0, e = instance.getNumResults(); resultNo != e;
603  ++resultNo) {
604  auto instancePortVal = instance.getResult(resultNo);
605  // If this is an input to the instance, it will
606  // get handled when any connects to it are processed.
607  if (fModule.getPortDirection(resultNo) == Direction::In)
608  continue;
609 
610  // Otherwise we have a result from the instance. We need to forward results
611  // from the body to this instance result's SSA value, so remember it.
612  BlockArgument modulePortVal = fModule.getArgument(resultNo);
613 
614  resultPortToInstanceResultMapping[modulePortVal].push_back(instancePortVal);
615 
616  // If there is already a value known for modulePortVal make sure to forward
617  // it here.
618  mergeLatticeValue(instancePortVal, modulePortVal);
619  }
620 }
621 
622 void IMConstPropPass::markObjectOp(ObjectOp obj) {
623  // Mark overdefined for now, not supported.
624  markOverdefined(obj);
625 }
626 
627 static std::optional<uint64_t>
628 getFieldIDOffset(FieldRef changedFieldRef, Type connectionType,
629  FieldRef connectedValueFieldRef) {
630  assert(!type_isa<RefType>(connectionType));
631  if (changedFieldRef.getValue() != connectedValueFieldRef.getValue())
632  return {};
633  if (changedFieldRef.getFieldID() >= connectedValueFieldRef.getFieldID() &&
634  changedFieldRef.getFieldID() <=
635  hw::FieldIdImpl::getMaxFieldID(connectionType) +
636  connectedValueFieldRef.getFieldID())
637  return changedFieldRef.getFieldID() - connectedValueFieldRef.getFieldID();
638  return {};
639 }
640 
641 void IMConstPropPass::mergeOnlyChangedLatticeValue(Value dest, Value src,
642  FieldRef changedFieldRef) {
643 
644  // Operate on inner type for refs.
645  auto destType = dest.getType();
646  if (auto refType = type_dyn_cast<RefType>(destType))
647  destType = refType.getType();
648 
649  if (!isa<FIRRTLType>(destType)) {
650  // If the dest is not FIRRTL type, conservatively mark
651  // all of them overdefined.
652  markOverdefined(src);
653  return markOverdefined(dest);
654  }
655 
656  auto fieldRefSrc = getOrCacheFieldRefFromValue(src);
657  auto fieldRefDest = getOrCacheFieldRefFromValue(dest);
658 
659  // If a changed field ref is included the source value, find an offset in the
660  // connection.
661  if (auto srcOffset = getFieldIDOffset(changedFieldRef, destType, fieldRefSrc))
662  mergeLatticeValue(fieldRefDest.getSubField(*srcOffset),
663  fieldRefSrc.getSubField(*srcOffset));
664 
665  // If a changed field ref is included the dest value, find an offset in the
666  // connection.
667  if (auto destOffset =
668  getFieldIDOffset(changedFieldRef, destType, fieldRefDest))
669  mergeLatticeValue(fieldRefDest.getSubField(*destOffset),
670  fieldRefSrc.getSubField(*destOffset));
671 }
672 
673 void IMConstPropPass::visitConnectLike(FConnectLike connect,
674  FieldRef changedFieldRef) {
675  // Operate on inner type for refs.
676  auto destType = connect.getDest().getType();
677  if (auto refType = type_dyn_cast<RefType>(destType))
678  destType = refType.getType();
679 
680  // Mark foreign types as overdefined.
681  if (!isa<FIRRTLType>(destType)) {
682  markOverdefined(connect.getSrc());
683  return markOverdefined(connect.getDest());
684  }
685 
686  auto fieldRefSrc = getOrCacheFieldRefFromValue(connect.getSrc());
687  auto fieldRefDest = getOrCacheFieldRefFromValue(connect.getDest());
688  if (auto subaccess = fieldRefDest.getValue().getDefiningOp<SubaccessOp>()) {
689  // If the destination is subaccess, we give up to precisely track
690  // lattice values and mark entire aggregate as overdefined. This code
691  // should be dead unless we stop lowering of subaccess in LowerTypes.
692  Value parent = subaccess.getInput();
693  while (parent.getDefiningOp() &&
694  parent.getDefiningOp()->getNumOperands() > 0)
695  parent = parent.getDefiningOp()->getOperand(0);
696  return markOverdefined(parent);
697  }
698 
699  auto propagateElementLattice = [&](uint64_t fieldID, FIRRTLType destType) {
700  auto fieldRefDestConnected = fieldRefDest.getSubField(fieldID);
701  assert(!firrtl::type_isa<FIRRTLBaseType>(destType) ||
702  firrtl::type_cast<FIRRTLBaseType>(destType).isGround());
703 
704  // Handle implicit extensions.
705  auto srcValue =
706  getExtendedLatticeValue(fieldRefSrc.getSubField(fieldID), destType);
707  if (srcValue.isUnknown())
708  return;
709 
710  // Driving result ports propagates the value to each instance using the
711  // module.
712  if (auto blockArg = dyn_cast<BlockArgument>(fieldRefDest.getValue())) {
713  for (auto userOfResultPort : resultPortToInstanceResultMapping[blockArg])
714  mergeLatticeValue(
715  FieldRef(userOfResultPort, fieldRefDestConnected.getFieldID()),
716  srcValue);
717  // Output ports are wire-like and may have users.
718  return mergeLatticeValue(fieldRefDestConnected, srcValue);
719  }
720 
721  auto dest = cast<mlir::OpResult>(fieldRefDest.getValue());
722 
723  // For wires and registers, we drive the value of the wire itself, which
724  // automatically propagates to users.
725  if (isWireOrReg(dest.getOwner()))
726  return mergeLatticeValue(fieldRefDestConnected, srcValue);
727 
728  // Driving an instance argument port drives the corresponding argument
729  // of the referenced module.
730  if (auto instance = dest.getDefiningOp<InstanceOp>()) {
731  // Update the dest, when its an instance op.
732  mergeLatticeValue(fieldRefDestConnected, srcValue);
733  auto mod = instance.getReferencedModule<FModuleOp>(*instanceGraph);
734  if (!mod)
735  return;
736 
737  BlockArgument modulePortVal = mod.getArgument(dest.getResultNumber());
738 
739  return mergeLatticeValue(
740  FieldRef(modulePortVal, fieldRefDestConnected.getFieldID()),
741  srcValue);
742  }
743 
744  // Driving a memory result is ignored because these are always treated
745  // as overdefined.
746  if (dest.getDefiningOp<MemOp>())
747  return;
748 
749  // For now, don't support const prop into object fields.
750  if (isa_and_nonnull<ObjectSubfieldOp>(dest.getDefiningOp()))
751  return;
752 
753  connect.emitError("connectlike operation unhandled by IMConstProp")
754  .attachNote(connect.getDest().getLoc())
755  << "connect destination is here";
756  };
757 
758  if (auto srcOffset = getFieldIDOffset(changedFieldRef, destType, fieldRefSrc))
759  propagateElementLattice(
760  *srcOffset,
761  firrtl::type_cast<FIRRTLType>(
762  hw::FieldIdImpl::getFinalTypeByFieldID(destType, *srcOffset)));
763 
764  if (auto relativeDest =
765  getFieldIDOffset(changedFieldRef, destType, fieldRefDest))
766  propagateElementLattice(
767  *relativeDest,
768  firrtl::type_cast<FIRRTLType>(
769  hw::FieldIdImpl::getFinalTypeByFieldID(destType, *relativeDest)));
770 }
771 
772 void IMConstPropPass::visitRefSend(RefSendOp send, FieldRef changedFieldRef) {
773  // Send connects the base value (source) to the result (dest).
774  return mergeOnlyChangedLatticeValue(send.getResult(), send.getBase(),
775  changedFieldRef);
776 }
777 
778 void IMConstPropPass::visitRefResolve(RefResolveOp resolve,
779  FieldRef changedFieldRef) {
780  // Resolve connects the ref value (source) to result (dest).
781  // If writes are ever supported, this will need to work differently!
782  return mergeOnlyChangedLatticeValue(resolve.getResult(), resolve.getRef(),
783  changedFieldRef);
784 }
785 
786 void IMConstPropPass::visitNode(NodeOp node, FieldRef changedFieldRef) {
787  if (hasDontTouch(node.getResult()) || node.isForceable()) {
788  for (auto result : node.getResults())
789  markOverdefined(result);
790  return;
791  }
792 
793  return mergeOnlyChangedLatticeValue(node.getResult(), node.getInput(),
794  changedFieldRef);
795 }
796 
797 /// This method is invoked when an operand of the specified op changes its
798 /// lattice value state and when the block containing the operation is first
799 /// noticed as being alive.
800 ///
801 /// This should update the lattice value state for any result values.
802 ///
803 void IMConstPropPass::visitOperation(Operation *op, FieldRef changedField) {
804  // If this is a operation with special handling, handle it specially.
805  if (auto connectLikeOp = dyn_cast<FConnectLike>(op))
806  return visitConnectLike(connectLikeOp, changedField);
807  if (auto sendOp = dyn_cast<RefSendOp>(op))
808  return visitRefSend(sendOp, changedField);
809  if (auto resolveOp = dyn_cast<RefResolveOp>(op))
810  return visitRefResolve(resolveOp, changedField);
811  if (auto nodeOp = dyn_cast<NodeOp>(op))
812  return visitNode(nodeOp, changedField);
813 
814  // The clock operand of regop changing doesn't change its result value. All
815  // other registers are over-defined. Aggregate operations also doesn't change
816  // its result value.
817  if (isa<RegOp, RegResetOp>(op) || isAggregate(op))
818  return;
819  // TODO: Handle 'when' operations.
820 
821  // If all of the results of this operation are already overdefined (or if
822  // there are no results) then bail out early: we've converged.
823  auto isOverdefinedFn = [&](Value value) {
824  return isOverdefined(getOrCacheFieldRefFromValue(value));
825  };
826  if (llvm::all_of(op->getResults(), isOverdefinedFn))
827  return;
828 
829  // To prevent regressions, mark values as overdefined when they are defined
830  // by operations with a large number of operands.
831  if (op->getNumOperands() > 128) {
832  for (auto value : op->getResults())
833  markOverdefined(value);
834  return;
835  }
836 
837  // Collect all of the constant operands feeding into this operation. If any
838  // are not ready to be resolved, bail out and wait for them to resolve.
839  SmallVector<Attribute, 8> operandConstants;
840  operandConstants.reserve(op->getNumOperands());
841  bool hasUnknown = false;
842  for (Value operand : op->getOperands()) {
843 
844  auto &operandLattice = latticeValues[getOrCacheFieldRefFromValue(operand)];
845 
846  // If the operand is an unknown value, then we generally don't want to
847  // process it - we want to wait until the value is resolved to by the SCCP
848  // algorithm.
849  if (operandLattice.isUnknown())
850  hasUnknown = true;
851 
852  // Otherwise, it must be constant, invalid, or overdefined. Translate them
853  // into attributes that the fold hook can look at.
854  if (operandLattice.isConstant())
855  operandConstants.push_back(operandLattice.getValue());
856  else
857  operandConstants.push_back({});
858  }
859 
860  // Simulate the result of folding this operation to a constant. If folding
861  // fails mark the results as overdefined.
862  SmallVector<OpFoldResult, 8> foldResults;
863  foldResults.reserve(op->getNumResults());
864  if (failed(op->fold(operandConstants, foldResults))) {
865  LLVM_DEBUG({
866  logger.startLine() << "Folding Failed operation : '" << op->getName()
867  << "\n";
868  op->dump();
869  });
870  // If we had unknown arguments, hold off on overdefining
871  if (!hasUnknown)
872  for (auto value : op->getResults())
873  markOverdefined(value);
874  return;
875  }
876 
877  LLVM_DEBUG({
878  logger.getOStream() << "\n";
879  logger.startLine() << "Folding operation : '" << op->getName() << "\n";
880  op->dump();
881  logger.getOStream() << "( ";
882  for (auto cst : operandConstants)
883  if (!cst)
884  logger.getOStream() << "{} ";
885  else
886  logger.getOStream() << cst << " ";
887  logger.unindent();
888  logger.getOStream() << ") -> { ";
889  logger.indent();
890  for (auto &r : foldResults) {
891  logger.getOStream() << r << " ";
892  }
893  logger.unindent();
894  logger.getOStream() << "}\n";
895  });
896 
897  // If the folding was in-place, keep going. This is surprising, but since
898  // only folder that will do in-place updates is the commutative folder, we
899  // aren't going to stop. We don't update the results, since they didn't
900  // change, the op just got shuffled around.
901  if (foldResults.empty())
902  return visitOperation(op, changedField);
903 
904  // Merge the fold results into the lattice for this operation.
905  assert(foldResults.size() == op->getNumResults() && "invalid result size");
906  for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
907  // Merge in the result of the fold, either a constant or a value.
908  LatticeValue resultLattice;
909  OpFoldResult foldResult = foldResults[i];
910  if (Attribute foldAttr = dyn_cast<Attribute>(foldResult)) {
911  if (auto intAttr = dyn_cast<IntegerAttr>(foldAttr))
912  resultLattice = LatticeValue(intAttr);
913  else if (auto strAttr = dyn_cast<StringAttr>(foldAttr))
914  resultLattice = LatticeValue(strAttr);
915  else // Treat unsupported constants as overdefined.
916  resultLattice = LatticeValue::getOverdefined();
917  } else { // Folding to an operand results in its value.
918  resultLattice =
919  latticeValues[getOrCacheFieldRefFromValue(foldResult.get<Value>())];
920  }
921 
922  mergeLatticeValue(getOrCacheFieldRefFromValue(op->getResult(i)),
923  resultLattice);
924  }
925 }
926 
927 void IMConstPropPass::rewriteModuleBody(FModuleOp module) {
928  auto *body = module.getBodyBlock();
929  // If a module is unreachable, just ignore it.
930  if (!executableBlocks.count(body))
931  return;
932 
933  auto builder = OpBuilder::atBlockBegin(body);
934 
935  // Separate the constants we insert from the instructions we are folding and
936  // processing. Leave these as-is until we're done.
937  auto cursor = builder.create<firrtl::ConstantOp>(module.getLoc(), APSInt(1));
938  builder.setInsertionPoint(cursor);
939 
940  // Unique constants per <Const,Type> pair, inserted at entry
941  DenseMap<std::pair<Attribute, Type>, Operation *> constPool;
942 
943  std::function<Value(Attribute, Type, Location)> getConst =
944  [&](Attribute constantValue, Type type, Location loc) -> Value {
945  auto constIt = constPool.find({constantValue, type});
946  if (constIt != constPool.end()) {
947  auto *cst = constIt->second;
948  // Add location to the constant
949  cst->setLoc(builder.getFusedLoc({cst->getLoc(), loc}));
950  return cst->getResult(0);
951  }
952  OpBuilder::InsertionGuard x(builder);
953  builder.setInsertionPoint(cursor);
954 
955  // Materialize reftype "constants" by materializing the constant
956  // and probing it.
957  Operation *cst;
958  if (auto refType = type_dyn_cast<RefType>(type)) {
959  assert(!type_cast<RefType>(type).getForceable() &&
960  "Attempting to materialize rwprobe of constant, shouldn't happen");
961  auto inner = getConst(constantValue, refType.getType(), loc);
962  assert(inner);
963  cst = builder.create<RefSendOp>(loc, inner);
964  } else
965  cst = module->getDialect()->materializeConstant(builder, constantValue,
966  type, loc);
967  assert(cst && "all FIRRTL constants can be materialized");
968  constPool.insert({{constantValue, type}, cst});
969  return cst->getResult(0);
970  };
971 
972  // If the lattice value for the specified value is a constant update it and
973  // return true. Otherwise return false.
974  auto replaceValueIfPossible = [&](Value value) -> bool {
975  // Lambda to replace all uses of this value a replacement, unless this is
976  // the destination of a connect. We leave connects alone to avoid upsetting
977  // flow, i.e., to avoid trying to connect to a constant.
978  auto replaceIfNotConnect = [&value](Value replacement) {
979  value.replaceUsesWithIf(replacement, [](OpOperand &operand) {
980  return !isa<FConnectLike>(operand.getOwner()) ||
981  operand.getOperandNumber() != 0;
982  });
983  };
984 
985  // TODO: Replace entire aggregate.
986  auto it = latticeValues.find(getFieldRefFromValue(value));
987  if (it == latticeValues.end() || it->second.isOverdefined() ||
988  it->second.isUnknown())
989  return false;
990 
991  // Cannot materialize constants for certain types.
992  // TODO: Let materializeConstant tell us what it supports instead of this.
993  // Presently it asserts on unsupported combinations, so check this here.
994  if (!type_isa<FIRRTLBaseType, RefType, FIntegerType, StringType, BoolType>(
995  value.getType()))
996  return false;
997 
998  auto cstValue =
999  getConst(it->second.getValue(), value.getType(), value.getLoc());
1000 
1001  replaceIfNotConnect(cstValue);
1002  return true;
1003  };
1004 
1005  // Constant propagate any ports that are always constant.
1006  for (auto &port : body->getArguments())
1007  replaceValueIfPossible(port);
1008 
1009  // TODO: Walk 'when's preorder with `walk`.
1010 
1011  // Walk the IR bottom-up when folding. We often fold entire chains of
1012  // operations into constants, which make the intermediate nodes dead. Going
1013  // bottom up eliminates the users of the intermediate ops, allowing us to
1014  // aggressively delete them.
1015  bool aboveCursor = false;
1016  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*body))) {
1017  auto dropIfDead = [&](Operation &op, const Twine &debugPrefix) {
1018  if (op.use_empty() &&
1019  (wouldOpBeTriviallyDead(&op) || isDeletableWireOrRegOrNode(&op))) {
1020  LLVM_DEBUG(
1021  { logger.getOStream() << debugPrefix << " : " << op << "\n"; });
1022  ++numErasedOp;
1023  op.erase();
1024  return true;
1025  }
1026  return false;
1027  };
1028 
1029  if (aboveCursor) {
1030  // Drop dead constants we materialized.
1031  dropIfDead(op, "Trivially dead materialized constant");
1032  continue;
1033  }
1034  // Stop once hit the generated constants.
1035  if (&op == cursor) {
1036  cursor.erase();
1037  aboveCursor = true;
1038  continue;
1039  }
1040 
1041  // Connects to values that we found to be constant can be dropped.
1042  if (auto connect = dyn_cast<FConnectLike>(op)) {
1043  if (auto *destOp = connect.getDest().getDefiningOp()) {
1044  auto fieldRef = getOrCacheFieldRefFromValue(connect.getDest());
1045  // Don't remove a field-level connection even if the src value is
1046  // constant. If other elements of the aggregate value are not constant,
1047  // the aggregate value cannot be replaced. We can forward the constant
1048  // to its users, so IMDCE (or SV/HW canonicalizer) should remove the
1049  // aggregate if entire aggregate is dead.
1050  auto type = type_dyn_cast<FIRRTLType>(connect.getDest().getType());
1051  if (!type)
1052  continue;
1053  auto baseType = type_dyn_cast<FIRRTLBaseType>(type);
1054  if (baseType && !baseType.isGround())
1055  continue;
1056  if (isDeletableWireOrRegOrNode(destOp) && !isOverdefined(fieldRef)) {
1057  connect.erase();
1058  ++numErasedOp;
1059  }
1060  }
1061  continue;
1062  }
1063 
1064  // We only fold single-result ops and instances in practice, because they
1065  // are the expressions.
1066  if (op.getNumResults() != 1 && !isa<InstanceOp>(op))
1067  continue;
1068 
1069  // If this operation is already dead, then go ahead and remove it.
1070  if (dropIfDead(op, "Trivially dead"))
1071  continue;
1072 
1073  // Don't "fold" constants (into equivalent), also because they
1074  // may have name hints we'd like to preserve.
1075  if (op.hasTrait<mlir::OpTrait::ConstantLike>())
1076  continue;
1077 
1078  // If the op had any constants folded, replace them.
1079  builder.setInsertionPoint(&op);
1080  bool foldedAny = false;
1081  for (auto result : op.getResults())
1082  foldedAny |= replaceValueIfPossible(result);
1083 
1084  if (foldedAny)
1085  ++numFoldedOp;
1086 
1087  // If the operation folded to a constant then we can probably nuke it.
1088  if (foldedAny && dropIfDead(op, "Made dead"))
1089  continue;
1090  }
1091 }
1092 
1093 std::unique_ptr<mlir::Pass> circt::firrtl::createIMConstPropPass() {
1094  return std::make_unique<IMConstPropPass>();
1095 }
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static bool isNodeLike(Operation *op)
Definition: IMConstProp.cpp:54
static std::optional< uint64_t > getFieldIDOffset(FieldRef changedFieldRef, Type connectionType, FieldRef connectedValueFieldRef)
static bool isWireOrReg(Operation *op)
Return true if this is a wire or register.
Definition: IMConstProp.cpp:42
static bool isAggregate(Operation *op)
Return true if this is an aggregate indexer.
Definition: IMConstProp.cpp:47
static bool isDeletableWireOrRegOrNode(Operation *op)
Return true if this is a wire or register we're allowed to delete.
Definition: IMConstProp.cpp:59
bool operator!=(const ResetDomain &a, const ResetDomain &b)
Definition: InferResets.cpp:81
bool operator==(const ResetDomain &a, const ResetDomain &b)
Definition: InferResets.cpp:78
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
FieldRef getSubField(unsigned subFieldID) const
Get a reference to a subfield.
Definition: FieldRef.h:62
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Definition: FieldRef.h:59
Value getValue() const
Get the Value which created this location.
Definition: FieldRef.h:37
Location getLoc() const
Get the location associated with the value of this field ref.
Definition: FieldRef.h:67
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool canBeDeleted() const
Check if every annotation can be deleted.
This graph tracks modules and where they are instantiated.
def connect(destination, source)
Definition: support.py:37
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
Definition: FIRRTLUtils.h:222
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
void walkGroundTypes(FIRRTLType firrtlType, llvm::function_ref< void(uint64_t, FIRRTLBaseType, bool)> fn)
Walk leaf ground types in the firrtlType and apply the function fn.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
Definition: FIRRTLOps.cpp:4571
std::unique_ptr< mlir::Pass > createIMConstPropPass()
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
Definition: FIRRTLOps.cpp:314
T & operator<<(T &os, FIRVersion version)
Definition: FIRParser.h:119
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
uint64_t getMaxFieldID(Type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition: APInt.cpp:22
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:20