CIRCT  19.0.0git
IMConstProp.cpp
Go to the documentation of this file.
1 //===- IMConstProp.cpp - Intermodule ConstProp and DCE ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This implements SCCP:
10 // https://www.cs.wustl.edu/~cytron/531Pages/f11/Resources/Papers/cprop.pdf
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetails.h"
21 #include "circt/Support/APInt.h"
22 #include "mlir/IR/Threading.h"
23 #include "llvm/ADT/APSInt.h"
24 #include "llvm/ADT/TinyPtrVector.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/ScopedPrinter.h"
27 
28 using namespace circt;
29 using namespace firrtl;
30 
31 #define DEBUG_TYPE "IMCP"
32 
33 /// Return true if this is a wire or register.
34 static bool isWireOrReg(Operation *op) {
35  return isa<WireOp, RegResetOp, RegOp>(op);
36 }
37 
38 /// Return true if this is an aggregate indexer.
39 static bool isAggregate(Operation *op) {
40  return isa<SubindexOp, SubaccessOp, SubfieldOp, OpenSubfieldOp,
41  OpenSubindexOp, RefSubOp>(op);
42 }
43 
44 // Return true if this forwards input to output.
45 // Implies has appropriate visit method that propagates changes.
46 static bool isNodeLike(Operation *op) {
47  return isa<NodeOp, RefResolveOp, RefSendOp>(op);
48 }
49 
50 /// Return true if this is a wire or register we're allowed to delete.
51 static bool isDeletableWireOrRegOrNode(Operation *op) {
52  if (!isWireOrReg(op) && !isa<NodeOp>(op))
53  return false;
54 
55  // Always allow deleting wires of probe-type.
56  if (type_isa<RefType>(op->getResult(0).getType()))
57  return true;
58 
59  // Otherwise, don't delete if has anything keeping it around or unknown.
60  return AnnotationSet(op).canBeDeleted() && !hasDontTouch(op) &&
61  hasDroppableName(op) && !cast<Forceable>(op).isForceable();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // Pass Infrastructure
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
69 /// This class represents a single lattice value. A lattive value corresponds to
70 /// the various different states that a value in the SCCP dataflow analysis can
71 /// take. See 'Kind' below for more details on the different states a value can
72 /// take.
73 class LatticeValue {
74  enum Kind {
75  /// A value with a yet-to-be-determined value. This state may be changed to
76  /// anything, it hasn't been processed by IMConstProp.
77  Unknown,
78 
79  /// A value that is known to be a constant. This state may be changed to
80  /// overdefined.
81  Constant,
82 
83  /// A value that cannot statically be determined to be a constant. This
84  /// state cannot be changed.
85  Overdefined
86  };
87 
88 public:
89  /// Initialize a lattice value with "Unknown".
90  /*implicit*/ LatticeValue() : valueAndTag(nullptr, Kind::Unknown) {}
91  /// Initialize a lattice value with a constant.
92  /*implicit*/ LatticeValue(IntegerAttr attr)
93  : valueAndTag(attr, Kind::Constant) {}
94  /*implicit*/ LatticeValue(StringAttr attr)
95  : valueAndTag(attr, Kind::Constant) {}
96 
97  static LatticeValue getOverdefined() {
98  LatticeValue result;
99  result.markOverdefined();
100  return result;
101  }
102 
103  bool isUnknown() const { return valueAndTag.getInt() == Kind::Unknown; }
104  bool isConstant() const { return valueAndTag.getInt() == Kind::Constant; }
105  bool isOverdefined() const {
106  return valueAndTag.getInt() == Kind::Overdefined;
107  }
108 
109  /// Mark the lattice value as overdefined.
110  void markOverdefined() {
111  valueAndTag.setPointerAndInt(nullptr, Kind::Overdefined);
112  }
113 
114  /// Mark the lattice value as constant.
115  void markConstant(IntegerAttr value) {
116  valueAndTag.setPointerAndInt(value, Kind::Constant);
117  }
118 
119  /// If this lattice is constant or invalid value, return the attribute.
120  /// Returns nullptr otherwise.
121  Attribute getValue() const { return valueAndTag.getPointer(); }
122 
123  /// If this is in the constant state, return the attribute.
124  Attribute getConstant() const {
125  assert(isConstant());
126  return getValue();
127  }
128 
129  /// Merge in the value of the 'rhs' lattice into this one. Returns true if the
130  /// lattice value changed.
131  bool mergeIn(LatticeValue rhs) {
132  // If we are already overdefined, or rhs is unknown, there is nothing to do.
133  if (isOverdefined() || rhs.isUnknown())
134  return false;
135 
136  // If we are unknown, just take the value of rhs.
137  if (isUnknown()) {
138  valueAndTag = rhs.valueAndTag;
139  return true;
140  }
141 
142  // Otherwise, if this value doesn't match rhs go straight to overdefined.
143  // This happens when we merge "3" and "4" from two different instance sites
144  // for example.
145  if (valueAndTag != rhs.valueAndTag) {
146  markOverdefined();
147  return true;
148  }
149  return false;
150  }
151 
152  bool operator==(const LatticeValue &other) const {
153  return valueAndTag == other.valueAndTag;
154  }
155  bool operator!=(const LatticeValue &other) const {
156  return valueAndTag != other.valueAndTag;
157  }
158 
159 private:
160  /// The attribute value if this is a constant and the tag for the element
161  /// kind. The attribute is an IntegerAttr (or BoolAttr) or StringAttr.
162  llvm::PointerIntPair<Attribute, 2, Kind> valueAndTag;
163 };
164 } // end anonymous namespace
165 
166 LLVM_ATTRIBUTE_USED
167 static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
168  const LatticeValue &lattice) {
169  if (lattice.isUnknown())
170  return os << "<Unknown>";
171  if (lattice.isOverdefined())
172  return os << "<Overdefined>";
173  return os << "<" << lattice.getConstant() << ">";
174 }
175 
176 namespace {
177 struct IMConstPropPass : public IMConstPropBase<IMConstPropPass> {
178 
179  void runOnOperation() override;
180  void rewriteModuleBody(FModuleOp module);
181 
182  /// Returns true if the given block is executable.
183  bool isBlockExecutable(Block *block) const {
184  return executableBlocks.count(block);
185  }
186 
187  bool isOverdefined(FieldRef value) const {
188  auto it = latticeValues.find(value);
189  return it != latticeValues.end() && it->second.isOverdefined();
190  }
191 
192  // Mark the given value as overdefined. If the value is an aggregate,
193  // we mark all ground elements as overdefined.
194  void markOverdefined(Value value) {
195  FieldRef fieldRef = getOrCacheFieldRefFromValue(value);
196  auto firrtlType = type_dyn_cast<FIRRTLType>(value.getType());
197  if (!firrtlType || type_isa<PropertyType>(firrtlType)) {
198  markOverdefined(fieldRef);
199  return;
200  }
201 
202  walkGroundTypes(firrtlType, [&](uint64_t fieldID, auto, auto) {
203  markOverdefined(fieldRef.getSubField(fieldID));
204  });
205  }
206 
207  /// Mark the given value as overdefined. This means that we cannot refine a
208  /// specific constant for this value.
209  void markOverdefined(FieldRef value) {
210  auto &entry = latticeValues[value];
211  if (!entry.isOverdefined()) {
212  LLVM_DEBUG({
213  logger.getOStream()
214  << "Setting overdefined : (" << getFieldName(value).first << ")\n";
215  });
216  entry.markOverdefined();
217  changedLatticeValueWorklist.push_back(value);
218  }
219  }
220 
221  /// Merge information from the 'from' lattice value into value. If it
222  /// changes, then users of the value are added to the worklist for
223  /// revisitation.
224  void mergeLatticeValue(FieldRef value, LatticeValue &valueEntry,
225  LatticeValue source) {
226  if (valueEntry.mergeIn(source)) {
227  LLVM_DEBUG({
228  logger.getOStream()
229  << "Changed to " << valueEntry << " : (" << value << ")\n";
230  });
231  changedLatticeValueWorklist.push_back(value);
232  }
233  }
234 
235  void mergeLatticeValue(FieldRef value, LatticeValue source) {
236  // Don't even do a map lookup if from has no info in it.
237  if (source.isUnknown())
238  return;
239  mergeLatticeValue(value, latticeValues[value], source);
240  }
241 
242  void mergeLatticeValue(FieldRef result, FieldRef from) {
243  // If 'from' hasn't been computed yet, then it is unknown, don't do
244  // anything.
245  auto it = latticeValues.find(from);
246  if (it == latticeValues.end())
247  return;
248  mergeLatticeValue(result, it->second);
249  }
250 
251  void mergeLatticeValue(Value result, Value from) {
252  FieldRef fieldRefFrom = getOrCacheFieldRefFromValue(from);
253  FieldRef fieldRefResult = getOrCacheFieldRefFromValue(result);
254  if (!type_isa<FIRRTLType>(result.getType()))
255  return mergeLatticeValue(fieldRefResult, fieldRefFrom);
256  // Special-handle PropertyType's, walkGroundType's doesn't support.
257  if (type_isa<PropertyType>(result.getType()))
258  return mergeLatticeValue(fieldRefResult, fieldRefFrom);
259  walkGroundTypes(type_cast<FIRRTLType>(result.getType()),
260  [&](uint64_t fieldID, auto, auto) {
261  mergeLatticeValue(fieldRefResult.getSubField(fieldID),
262  fieldRefFrom.getSubField(fieldID));
263  });
264  }
265 
266  /// setLatticeValue - This is used when a new LatticeValue is computed for
267  /// the result of the specified value that replaces any previous knowledge,
268  /// e.g. because a fold() function on an op returned a new thing. This should
269  /// not be used on operations that have multiple contributors to it, e.g.
270  /// wires or ports.
271  void setLatticeValue(FieldRef value, LatticeValue source) {
272  // Don't even do a map lookup if from has no info in it.
273  if (source.isUnknown())
274  return;
275 
276  // If we've changed this value then revisit all the users.
277  auto &valueEntry = latticeValues[value];
278  if (valueEntry != source) {
279  changedLatticeValueWorklist.push_back(value);
280  valueEntry = source;
281  }
282  }
283 
284  // This function returns a field ref of the given value. This function caches
285  // the result to avoid extra IR traversal if the value is an aggregate
286  // element.
287  FieldRef getOrCacheFieldRefFromValue(Value value) {
288  if (!value.getDefiningOp() || !isAggregate(value.getDefiningOp()))
289  return FieldRef(value, 0);
290  auto &fieldRef = valueToFieldRef[value];
291  if (fieldRef)
292  return fieldRef;
293  return fieldRef = getFieldRefFromValue(value);
294  }
295 
296  /// Return the lattice value for the specified SSA value, extended to the
297  /// width of the specified destType. If allowTruncation is true, then this
298  /// allows truncating the lattice value to the specified type.
299  LatticeValue getExtendedLatticeValue(FieldRef value, FIRRTLType destType,
300  bool allowTruncation = false);
301 
302  /// Mark the given block as executable.
303  void markBlockExecutable(Block *block);
304  void markWireOp(WireOp wireOrReg);
305  void markMemOp(MemOp mem);
306 
307  void markInvalidValueOp(InvalidValueOp invalid);
308  void markAggregateConstantOp(AggregateConstantOp constant);
309  void markInstanceOp(InstanceOp instance);
310  void markObjectOp(ObjectOp object);
311  template <typename OpTy>
312  void markConstantValueOp(OpTy op);
313 
314  void visitConnectLike(FConnectLike connect, FieldRef changedFieldRef);
315  void visitRefSend(RefSendOp send, FieldRef changedFieldRef);
316  void visitRefResolve(RefResolveOp resolve, FieldRef changedFieldRef);
317  void mergeOnlyChangedLatticeValue(Value dest, Value src,
318  FieldRef changedFieldRef);
319  void visitNode(NodeOp node, FieldRef changedFieldRef);
320  void visitOperation(Operation *op, FieldRef changedFieldRef);
321 
322 private:
323  /// This is the current instance graph for the Circuit.
324  InstanceGraph *instanceGraph = nullptr;
325 
326  /// This keeps track of the current state of each tracked value.
327  DenseMap<FieldRef, LatticeValue> latticeValues;
328 
329  /// The set of blocks that are known to execute, or are intrinsically live.
330  SmallPtrSet<Block *, 16> executableBlocks;
331 
332  /// A worklist of values whose LatticeValue recently changed, indicating the
333  /// users need to be reprocessed.
334  SmallVector<FieldRef, 64> changedLatticeValueWorklist;
335 
336  // A map to give operations to be reprocessed.
337  DenseMap<FieldRef, llvm::TinyPtrVector<Operation *>> fieldRefToUsers;
338 
339  // A map to cache results of getFieldRefFromValue since it's costly traverse
340  // the IR.
341  llvm::DenseMap<Value, FieldRef> valueToFieldRef;
342 
343  /// This keeps track of users the instance results that correspond to output
344  /// ports.
345  DenseMap<BlockArgument, llvm::TinyPtrVector<Value>>
346  resultPortToInstanceResultMapping;
347 
348 #ifndef NDEBUG
349  /// A logger used to emit information during the application process.
350  llvm::ScopedPrinter logger{llvm::dbgs()};
351 #endif
352 };
353 } // end anonymous namespace
354 
355 // TODO: handle annotations: [[OptimizableExtModuleAnnotation]]
356 void IMConstPropPass::runOnOperation() {
357  auto circuit = getOperation();
358  LLVM_DEBUG(
359  { logger.startLine() << "IMConstProp : " << circuit.getName() << "\n"; });
360 
361  instanceGraph = &getAnalysis<InstanceGraph>();
362 
363  // Mark the input ports of public modules as being overdefined.
364  for (auto module : circuit.getBodyBlock()->getOps<FModuleOp>()) {
365  if (module.isPublic()) {
366  markBlockExecutable(module.getBodyBlock());
367  for (auto port : module.getBodyBlock()->getArguments())
368  markOverdefined(port);
369  }
370  }
371 
372  // If a value changed lattice state then reprocess any of its users.
373  while (!changedLatticeValueWorklist.empty()) {
374  FieldRef changedFieldRef = changedLatticeValueWorklist.pop_back_val();
375  for (Operation *user : fieldRefToUsers[changedFieldRef]) {
376  if (isBlockExecutable(user->getBlock()))
377  visitOperation(user, changedFieldRef);
378  }
379  }
380 
381  // Rewrite any constants in the modules.
382  mlir::parallelForEach(circuit.getContext(),
383  circuit.getBodyBlock()->getOps<FModuleOp>(),
384  [&](auto op) { rewriteModuleBody(op); });
385 
386  // Clean up our state for next time.
387  instanceGraph = nullptr;
388  latticeValues.clear();
389  executableBlocks.clear();
390  assert(changedLatticeValueWorklist.empty());
391  fieldRefToUsers.clear();
392  valueToFieldRef.clear();
393  resultPortToInstanceResultMapping.clear();
394 }
395 
396 /// Return the lattice value for the specified SSA value, extended to the width
397 /// of the specified destType. If allowTruncation is true, then this allows
398 /// truncating the lattice value to the specified type.
399 LatticeValue IMConstPropPass::getExtendedLatticeValue(FieldRef value,
400  FIRRTLType destType,
401  bool allowTruncation) {
402  // If 'value' hasn't been computed yet, then it is unknown.
403  auto it = latticeValues.find(value);
404  if (it == latticeValues.end())
405  return LatticeValue();
406 
407  auto result = it->second;
408  // Unknown/overdefined stay whatever they are.
409  if (result.isUnknown() || result.isOverdefined())
410  return result;
411 
412  // No extOrTrunc for property types. Return what we have.
413  if (isa<PropertyType>(destType))
414  return result;
415 
416  auto constant = result.getConstant();
417 
418  // If not property, only support integers.
419  auto intAttr = dyn_cast<IntegerAttr>(constant);
420  assert(intAttr && "unsupported lattice attribute kind");
421  if (!intAttr)
422  return result;
423 
424  // No extOrTrunc necessary for bools.
425  if (auto boolAttr = dyn_cast<BoolAttr>(intAttr))
426  return result;
427 
428  // Non-base (or non-ref) types are overdefined.
429  auto baseType = getBaseType(destType);
430  if (!baseType)
431  return LatticeValue::getOverdefined();
432 
433  // If destType is wider than the source constant type, extend it.
434  auto resultConstant = intAttr.getAPSInt();
435  auto destWidth = baseType.getBitWidthOrSentinel();
436  if (destWidth == -1) // We don't support unknown width FIRRTL.
437  return LatticeValue::getOverdefined();
438  if (resultConstant.getBitWidth() == (unsigned)destWidth)
439  return result; // Already the right width, we're done.
440 
441  // Otherwise, extend the constant using the signedness of the source.
442  resultConstant = extOrTruncZeroWidth(resultConstant, destWidth);
443  return LatticeValue(IntegerAttr::get(destType.getContext(), resultConstant));
444 }
445 
446 // NOLINTBEGIN(misc-no-recursion)
447 /// Mark a block executable if it isn't already. This does an initial scan of
448 /// the block, processing nullary operations like wires, instances, and
449 /// constants that only get processed once.
450 void IMConstPropPass::markBlockExecutable(Block *block) {
451  if (!executableBlocks.insert(block).second)
452  return; // Already executable.
453 
454  // Mark block arguments, which are module ports, with don't touch as
455  // overdefined.
456  for (auto ba : block->getArguments())
457  if (hasDontTouch(ba))
458  markOverdefined(ba);
459 
460  for (auto &op : *block) {
461  // Handle each of the special operations in the firrtl dialect.
462  TypeSwitch<Operation *>(&op)
463  .Case<RegOp, RegResetOp>(
464  [&](auto reg) { markOverdefined(op.getResult(0)); })
465  .Case<WireOp>([&](auto wire) { markWireOp(wire); })
466  .Case<ConstantOp, SpecialConstantOp, StringConstantOp,
467  FIntegerConstantOp, BoolConstantOp>(
468  [&](auto constOp) { markConstantValueOp(constOp); })
469  .Case<AggregateConstantOp>(
470  [&](auto aggConstOp) { markAggregateConstantOp(aggConstOp); })
471  .Case<InvalidValueOp>(
472  [&](auto invalid) { markInvalidValueOp(invalid); })
473  .Case<InstanceOp>([&](auto instance) { markInstanceOp(instance); })
474  .Case<ObjectOp>([&](auto obj) { markObjectOp(obj); })
475  .Case<MemOp>([&](auto mem) { markMemOp(mem); })
476  .Default([&](auto _) {
477  if (isa<mlir::UnrealizedConversionCastOp, VerbatimExprOp,
478  VerbatimWireOp, SubaccessOp>(op) ||
479  op.getNumOperands() == 0) {
480  // Mark operations whose results cannot be tracked as overdefined.
481  // Mark unhandled operations with no operand as well since otherwise
482  // they will remain unknown states until the end.
483  for (auto result : op.getResults())
484  markOverdefined(result);
485  } else if (
486  // Operations that are handled when propagating values, or chasing
487  // indexing.
488  !isAggregate(&op) && !isNodeLike(&op) && op.getNumResults() > 0) {
489  // If an unknown operation has an aggregate operand, mark results as
490  // overdefined since we cannot track the dataflow. Similarly if the
491  // operations create aggregate values, we mark them overdefined.
492 
493  // TODO: We should handle aggregate operations such as
494  // vector_create, bundle_create or vector operations.
495 
496  bool hasAggregateOperand =
497  llvm::any_of(op.getOperandTypes(), [](Type type) {
498  return type_isa<FVectorType, BundleType>(type);
499  });
500 
501  for (auto result : op.getResults())
502  if (hasAggregateOperand ||
503  type_isa<FVectorType, BundleType>(result.getType()))
504  markOverdefined(result);
505  }
506  });
507 
508  // This tracks a dependency from field refs to operations which need
509  // to be added to worklist when lattice values change.
510  if (!isAggregate(&op)) {
511  for (auto operand : op.getOperands()) {
512  auto fieldRef = getOrCacheFieldRefFromValue(operand);
513  auto firrtlType = type_dyn_cast<FIRRTLType>(operand.getType());
514  if (!firrtlType)
515  continue;
516  // Special-handle PropertyType's, walkGroundTypes doesn't support.
517  if (type_isa<PropertyType>(firrtlType)) {
518  fieldRefToUsers[fieldRef].push_back(&op);
519  continue;
520  }
521  walkGroundTypes(firrtlType, [&](uint64_t fieldID, auto type, auto) {
522  fieldRefToUsers[fieldRef.getSubField(fieldID)].push_back(&op);
523  });
524  }
525  }
526  }
527 }
528 // NOLINTEND(misc-no-recursion)
529 
530 void IMConstPropPass::markWireOp(WireOp wire) {
531  auto type = type_dyn_cast<FIRRTLType>(wire.getResult().getType());
532  if (!type || hasDontTouch(wire.getResult()) || wire.isForceable()) {
533  for (auto result : wire.getResults())
534  markOverdefined(result);
535  return;
536  }
537 
538  // Otherwise, this starts out as unknown and is upgraded by connects.
539 }
540 
541 void IMConstPropPass::markMemOp(MemOp mem) {
542  for (auto result : mem.getResults())
543  markOverdefined(result);
544 }
545 
546 template <typename OpTy>
547 void IMConstPropPass::markConstantValueOp(OpTy op) {
548  mergeLatticeValue(getOrCacheFieldRefFromValue(op),
549  LatticeValue(op.getValueAttr()));
550 }
551 
552 void IMConstPropPass::markAggregateConstantOp(AggregateConstantOp constant) {
553  walkGroundTypes(constant.getType(), [&](uint64_t fieldID, auto, auto) {
554  mergeLatticeValue(FieldRef(constant, fieldID),
555  LatticeValue(cast<IntegerAttr>(
556  constant.getAttributeFromFieldID(fieldID))));
557  });
558 }
559 
560 void IMConstPropPass::markInvalidValueOp(InvalidValueOp invalid) {
561  markOverdefined(invalid.getResult());
562 }
563 
564 /// Instances have no operands, so they are visited exactly once when their
565 /// enclosing block is marked live. This sets up the def-use edges for ports.
566 void IMConstPropPass::markInstanceOp(InstanceOp instance) {
567  // Get the module being reference or a null pointer if this is an extmodule.
568  Operation *op = instance.getReferencedModule(*instanceGraph);
569 
570  // If this is an extmodule, just remember that any results and inouts are
571  // overdefined.
572  if (!isa<FModuleOp>(op)) {
573  auto module = dyn_cast<FModuleLike>(op);
574  for (size_t resultNo = 0, e = instance.getNumResults(); resultNo != e;
575  ++resultNo) {
576  auto portVal = instance.getResult(resultNo);
577  // If this is an input to the extmodule, we can ignore it.
578  if (module.getPortDirection(resultNo) == Direction::In)
579  continue;
580 
581  // Otherwise this is a result from it or an inout, mark it as overdefined.
582  markOverdefined(portVal);
583  }
584  return;
585  }
586 
587  // Otherwise this is a defined module.
588  auto fModule = cast<FModuleOp>(op);
589  markBlockExecutable(fModule.getBodyBlock());
590 
591  // Ok, it is a normal internal module reference. Populate
592  // resultPortToInstanceResultMapping, and forward any already-computed values.
593  for (size_t resultNo = 0, e = instance.getNumResults(); resultNo != e;
594  ++resultNo) {
595  auto instancePortVal = instance.getResult(resultNo);
596  // If this is an input to the instance, it will
597  // get handled when any connects to it are processed.
598  if (fModule.getPortDirection(resultNo) == Direction::In)
599  continue;
600 
601  // Otherwise we have a result from the instance. We need to forward results
602  // from the body to this instance result's SSA value, so remember it.
603  BlockArgument modulePortVal = fModule.getArgument(resultNo);
604 
605  resultPortToInstanceResultMapping[modulePortVal].push_back(instancePortVal);
606 
607  // If there is already a value known for modulePortVal make sure to forward
608  // it here.
609  mergeLatticeValue(instancePortVal, modulePortVal);
610  }
611 }
612 
613 void IMConstPropPass::markObjectOp(ObjectOp obj) {
614  // Mark overdefined for now, not supported.
615  markOverdefined(obj);
616 }
617 
618 static std::optional<uint64_t>
619 getFieldIDOffset(FieldRef changedFieldRef, Type connectionType,
620  FieldRef connectedValueFieldRef) {
621  assert(!type_isa<RefType>(connectionType));
622  if (changedFieldRef.getValue() != connectedValueFieldRef.getValue())
623  return {};
624  if (changedFieldRef.getFieldID() >= connectedValueFieldRef.getFieldID() &&
625  changedFieldRef.getFieldID() <=
626  hw::FieldIdImpl::getMaxFieldID(connectionType) +
627  connectedValueFieldRef.getFieldID())
628  return changedFieldRef.getFieldID() - connectedValueFieldRef.getFieldID();
629  return {};
630 }
631 
632 void IMConstPropPass::mergeOnlyChangedLatticeValue(Value dest, Value src,
633  FieldRef changedFieldRef) {
634 
635  // Operate on inner type for refs.
636  auto destType = dest.getType();
637  if (auto refType = type_dyn_cast<RefType>(destType))
638  destType = refType.getType();
639 
640  if (!isa<FIRRTLType>(destType)) {
641  // If the dest is not FIRRTL type, conservatively mark
642  // all of them overdefined.
643  markOverdefined(src);
644  return markOverdefined(dest);
645  }
646 
647  auto fieldRefSrc = getOrCacheFieldRefFromValue(src);
648  auto fieldRefDest = getOrCacheFieldRefFromValue(dest);
649 
650  // If a changed field ref is included the source value, find an offset in the
651  // connection.
652  if (auto srcOffset = getFieldIDOffset(changedFieldRef, destType, fieldRefSrc))
653  mergeLatticeValue(fieldRefDest.getSubField(*srcOffset),
654  fieldRefSrc.getSubField(*srcOffset));
655 
656  // If a changed field ref is included the dest value, find an offset in the
657  // connection.
658  if (auto destOffset =
659  getFieldIDOffset(changedFieldRef, destType, fieldRefDest))
660  mergeLatticeValue(fieldRefDest.getSubField(*destOffset),
661  fieldRefSrc.getSubField(*destOffset));
662 }
663 
664 void IMConstPropPass::visitConnectLike(FConnectLike connect,
665  FieldRef changedFieldRef) {
666  // Operate on inner type for refs.
667  auto destType = connect.getDest().getType();
668  if (auto refType = type_dyn_cast<RefType>(destType))
669  destType = refType.getType();
670 
671  // Mark foreign types as overdefined.
672  if (!isa<FIRRTLType>(destType)) {
673  markOverdefined(connect.getSrc());
674  return markOverdefined(connect.getDest());
675  }
676 
677  auto fieldRefSrc = getOrCacheFieldRefFromValue(connect.getSrc());
678  auto fieldRefDest = getOrCacheFieldRefFromValue(connect.getDest());
679  if (auto subaccess = fieldRefDest.getValue().getDefiningOp<SubaccessOp>()) {
680  // If the destination is subaccess, we give up to precisely track
681  // lattice values and mark entire aggregate as overdefined. This code
682  // should be dead unless we stop lowering of subaccess in LowerTypes.
683  Value parent = subaccess.getInput();
684  while (parent.getDefiningOp() &&
685  parent.getDefiningOp()->getNumOperands() > 0)
686  parent = parent.getDefiningOp()->getOperand(0);
687  return markOverdefined(parent);
688  }
689 
690  auto propagateElementLattice = [&](uint64_t fieldID, FIRRTLType destType) {
691  auto fieldRefDestConnected = fieldRefDest.getSubField(fieldID);
692  assert(!firrtl::type_isa<FIRRTLBaseType>(destType) ||
693  firrtl::type_cast<FIRRTLBaseType>(destType).isGround());
694 
695  // Handle implicit extensions.
696  auto srcValue =
697  getExtendedLatticeValue(fieldRefSrc.getSubField(fieldID), destType);
698  if (srcValue.isUnknown())
699  return;
700 
701  // Driving result ports propagates the value to each instance using the
702  // module.
703  if (auto blockArg = dyn_cast<BlockArgument>(fieldRefDest.getValue())) {
704  for (auto userOfResultPort : resultPortToInstanceResultMapping[blockArg])
705  mergeLatticeValue(
706  FieldRef(userOfResultPort, fieldRefDestConnected.getFieldID()),
707  srcValue);
708  // Output ports are wire-like and may have users.
709  return mergeLatticeValue(fieldRefDestConnected, srcValue);
710  }
711 
712  auto dest = cast<mlir::OpResult>(fieldRefDest.getValue());
713 
714  // For wires and registers, we drive the value of the wire itself, which
715  // automatically propagates to users.
716  if (isWireOrReg(dest.getOwner()))
717  return mergeLatticeValue(fieldRefDestConnected, srcValue);
718 
719  // Driving an instance argument port drives the corresponding argument
720  // of the referenced module.
721  if (auto instance = dest.getDefiningOp<InstanceOp>()) {
722  // Update the dest, when its an instance op.
723  mergeLatticeValue(fieldRefDestConnected, srcValue);
724  auto mod = instance.getReferencedModule<FModuleOp>(*instanceGraph);
725  if (!mod)
726  return;
727 
728  BlockArgument modulePortVal = mod.getArgument(dest.getResultNumber());
729 
730  return mergeLatticeValue(
731  FieldRef(modulePortVal, fieldRefDestConnected.getFieldID()),
732  srcValue);
733  }
734 
735  // Driving a memory result is ignored because these are always treated
736  // as overdefined.
737  if (dest.getDefiningOp<MemOp>())
738  return;
739 
740  // For now, don't support const prop into object fields.
741  if (isa_and_nonnull<ObjectSubfieldOp>(dest.getDefiningOp()))
742  return;
743 
744  connect.emitError("connectlike operation unhandled by IMConstProp")
745  .attachNote(connect.getDest().getLoc())
746  << "connect destination is here";
747  };
748 
749  if (auto srcOffset = getFieldIDOffset(changedFieldRef, destType, fieldRefSrc))
750  propagateElementLattice(
751  *srcOffset,
752  firrtl::type_cast<FIRRTLType>(
753  hw::FieldIdImpl::getFinalTypeByFieldID(destType, *srcOffset)));
754 
755  if (auto relativeDest =
756  getFieldIDOffset(changedFieldRef, destType, fieldRefDest))
757  propagateElementLattice(
758  *relativeDest,
759  firrtl::type_cast<FIRRTLType>(
760  hw::FieldIdImpl::getFinalTypeByFieldID(destType, *relativeDest)));
761 }
762 
763 void IMConstPropPass::visitRefSend(RefSendOp send, FieldRef changedFieldRef) {
764  // Send connects the base value (source) to the result (dest).
765  return mergeOnlyChangedLatticeValue(send.getResult(), send.getBase(),
766  changedFieldRef);
767 }
768 
769 void IMConstPropPass::visitRefResolve(RefResolveOp resolve,
770  FieldRef changedFieldRef) {
771  // Resolve connects the ref value (source) to result (dest).
772  // If writes are ever supported, this will need to work differently!
773  return mergeOnlyChangedLatticeValue(resolve.getResult(), resolve.getRef(),
774  changedFieldRef);
775 }
776 
777 void IMConstPropPass::visitNode(NodeOp node, FieldRef changedFieldRef) {
778  if (hasDontTouch(node.getResult()) || node.isForceable()) {
779  for (auto result : node.getResults())
780  markOverdefined(result);
781  return;
782  }
783 
784  return mergeOnlyChangedLatticeValue(node.getResult(), node.getInput(),
785  changedFieldRef);
786 }
787 
788 /// This method is invoked when an operand of the specified op changes its
789 /// lattice value state and when the block containing the operation is first
790 /// noticed as being alive.
791 ///
792 /// This should update the lattice value state for any result values.
793 ///
794 void IMConstPropPass::visitOperation(Operation *op, FieldRef changedField) {
795  // If this is a operation with special handling, handle it specially.
796  if (auto connectLikeOp = dyn_cast<FConnectLike>(op))
797  return visitConnectLike(connectLikeOp, changedField);
798  if (auto sendOp = dyn_cast<RefSendOp>(op))
799  return visitRefSend(sendOp, changedField);
800  if (auto resolveOp = dyn_cast<RefResolveOp>(op))
801  return visitRefResolve(resolveOp, changedField);
802  if (auto nodeOp = dyn_cast<NodeOp>(op))
803  return visitNode(nodeOp, changedField);
804 
805  // The clock operand of regop changing doesn't change its result value. All
806  // other registers are over-defined. Aggregate operations also doesn't change
807  // its result value.
808  if (isa<RegOp, RegResetOp>(op) || isAggregate(op))
809  return;
810  // TODO: Handle 'when' operations.
811 
812  // If all of the results of this operation are already overdefined (or if
813  // there are no results) then bail out early: we've converged.
814  auto isOverdefinedFn = [&](Value value) {
815  return isOverdefined(getOrCacheFieldRefFromValue(value));
816  };
817  if (llvm::all_of(op->getResults(), isOverdefinedFn))
818  return;
819 
820  // To prevent regressions, mark values as overdefined when they are defined
821  // by operations with a large number of operands.
822  if (op->getNumOperands() > 128) {
823  for (auto value : op->getResults())
824  markOverdefined(value);
825  return;
826  }
827 
828  // Collect all of the constant operands feeding into this operation. If any
829  // are not ready to be resolved, bail out and wait for them to resolve.
830  SmallVector<Attribute, 8> operandConstants;
831  operandConstants.reserve(op->getNumOperands());
832  bool hasUnknown = false;
833  for (Value operand : op->getOperands()) {
834 
835  auto &operandLattice = latticeValues[getOrCacheFieldRefFromValue(operand)];
836 
837  // If the operand is an unknown value, then we generally don't want to
838  // process it - we want to wait until the value is resolved to by the SCCP
839  // algorithm.
840  if (operandLattice.isUnknown())
841  hasUnknown = true;
842 
843  // Otherwise, it must be constant, invalid, or overdefined. Translate them
844  // into attributes that the fold hook can look at.
845  if (operandLattice.isConstant())
846  operandConstants.push_back(operandLattice.getValue());
847  else
848  operandConstants.push_back({});
849  }
850 
851  // Simulate the result of folding this operation to a constant. If folding
852  // fails mark the results as overdefined.
853  SmallVector<OpFoldResult, 8> foldResults;
854  foldResults.reserve(op->getNumResults());
855  if (failed(op->fold(operandConstants, foldResults))) {
856  LLVM_DEBUG({
857  logger.startLine() << "Folding Failed operation : '" << op->getName()
858  << "\n";
859  op->dump();
860  });
861  // If we had unknown arguments, hold off on overdefining
862  if (!hasUnknown)
863  for (auto value : op->getResults())
864  markOverdefined(value);
865  return;
866  }
867 
868  LLVM_DEBUG({
869  logger.getOStream() << "\n";
870  logger.startLine() << "Folding operation : '" << op->getName() << "\n";
871  op->dump();
872  logger.getOStream() << "( ";
873  for (auto cst : operandConstants)
874  if (!cst)
875  logger.getOStream() << "{} ";
876  else
877  logger.getOStream() << cst << " ";
878  logger.unindent();
879  logger.getOStream() << ") -> { ";
880  logger.indent();
881  for (auto &r : foldResults) {
882  logger.getOStream() << r << " ";
883  }
884  logger.unindent();
885  logger.getOStream() << "}\n";
886  });
887 
888  // If the folding was in-place, keep going. This is surprising, but since
889  // only folder that will do in-place updates is the commutative folder, we
890  // aren't going to stop. We don't update the results, since they didn't
891  // change, the op just got shuffled around.
892  if (foldResults.empty())
893  return visitOperation(op, changedField);
894 
895  // Merge the fold results into the lattice for this operation.
896  assert(foldResults.size() == op->getNumResults() && "invalid result size");
897  for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
898  // Merge in the result of the fold, either a constant or a value.
899  LatticeValue resultLattice;
900  OpFoldResult foldResult = foldResults[i];
901  if (Attribute foldAttr = dyn_cast<Attribute>(foldResult)) {
902  if (auto intAttr = dyn_cast<IntegerAttr>(foldAttr))
903  resultLattice = LatticeValue(intAttr);
904  else if (auto strAttr = dyn_cast<StringAttr>(foldAttr))
905  resultLattice = LatticeValue(strAttr);
906  else // Treat unsupported constants as overdefined.
907  resultLattice = LatticeValue::getOverdefined();
908  } else { // Folding to an operand results in its value.
909  resultLattice =
910  latticeValues[getOrCacheFieldRefFromValue(foldResult.get<Value>())];
911  }
912 
913  mergeLatticeValue(getOrCacheFieldRefFromValue(op->getResult(i)),
914  resultLattice);
915  }
916 }
917 
918 void IMConstPropPass::rewriteModuleBody(FModuleOp module) {
919  auto *body = module.getBodyBlock();
920  // If a module is unreachable, just ignore it.
921  if (!executableBlocks.count(body))
922  return;
923 
924  auto builder = OpBuilder::atBlockBegin(body);
925 
926  // Separate the constants we insert from the instructions we are folding and
927  // processing. Leave these as-is until we're done.
928  auto cursor = builder.create<firrtl::ConstantOp>(module.getLoc(), APSInt(1));
929  builder.setInsertionPoint(cursor);
930 
931  // Unique constants per <Const,Type> pair, inserted at entry
932  DenseMap<std::pair<Attribute, Type>, Operation *> constPool;
933 
934  std::function<Value(Attribute, Type, Location)> getConst =
935  [&](Attribute constantValue, Type type, Location loc) -> Value {
936  auto constIt = constPool.find({constantValue, type});
937  if (constIt != constPool.end()) {
938  auto *cst = constIt->second;
939  // Add location to the constant
940  cst->setLoc(builder.getFusedLoc({cst->getLoc(), loc}));
941  return cst->getResult(0);
942  }
943  OpBuilder::InsertionGuard x(builder);
944  builder.setInsertionPoint(cursor);
945 
946  // Materialize reftype "constants" by materializing the constant
947  // and probing it.
948  Operation *cst;
949  if (auto refType = type_dyn_cast<RefType>(type)) {
950  assert(!type_cast<RefType>(type).getForceable() &&
951  "Attempting to materialize rwprobe of constant, shouldn't happen");
952  auto inner = getConst(constantValue, refType.getType(), loc);
953  assert(inner);
954  cst = builder.create<RefSendOp>(loc, inner);
955  } else
956  cst = module->getDialect()->materializeConstant(builder, constantValue,
957  type, loc);
958  assert(cst && "all FIRRTL constants can be materialized");
959  constPool.insert({{constantValue, type}, cst});
960  return cst->getResult(0);
961  };
962 
963  // If the lattice value for the specified value is a constant update it and
964  // return true. Otherwise return false.
965  auto replaceValueIfPossible = [&](Value value) -> bool {
966  // Lambda to replace all uses of this value a replacement, unless this is
967  // the destination of a connect. We leave connects alone to avoid upsetting
968  // flow, i.e., to avoid trying to connect to a constant.
969  auto replaceIfNotConnect = [&value](Value replacement) {
970  value.replaceUsesWithIf(replacement, [](OpOperand &operand) {
971  return !isa<FConnectLike>(operand.getOwner()) ||
972  operand.getOperandNumber() != 0;
973  });
974  };
975 
976  // TODO: Replace entire aggregate.
977  auto it = latticeValues.find(getFieldRefFromValue(value));
978  if (it == latticeValues.end() || it->second.isOverdefined() ||
979  it->second.isUnknown())
980  return false;
981 
982  // Cannot materialize constants for certain types.
983  // TODO: Let materializeConstant tell us what it supports instead of this.
984  // Presently it asserts on unsupported combinations, so check this here.
985  if (!type_isa<FIRRTLBaseType, RefType, FIntegerType, StringType, BoolType>(
986  value.getType()))
987  return false;
988 
989  auto cstValue =
990  getConst(it->second.getValue(), value.getType(), value.getLoc());
991 
992  replaceIfNotConnect(cstValue);
993  return true;
994  };
995 
996  // Constant propagate any ports that are always constant.
997  for (auto &port : body->getArguments())
998  replaceValueIfPossible(port);
999 
1000  // TODO: Walk 'when's preorder with `walk`.
1001 
1002  // Walk the IR bottom-up when folding. We often fold entire chains of
1003  // operations into constants, which make the intermediate nodes dead. Going
1004  // bottom up eliminates the users of the intermediate ops, allowing us to
1005  // aggressively delete them.
1006  bool aboveCursor = false;
1007  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*body))) {
1008  auto dropIfDead = [&](Operation &op, const Twine &debugPrefix) {
1009  if (op.use_empty() &&
1010  (wouldOpBeTriviallyDead(&op) || isDeletableWireOrRegOrNode(&op))) {
1011  LLVM_DEBUG(
1012  { logger.getOStream() << debugPrefix << " : " << op << "\n"; });
1013  ++numErasedOp;
1014  op.erase();
1015  return true;
1016  }
1017  return false;
1018  };
1019 
1020  if (aboveCursor) {
1021  // Drop dead constants we materialized.
1022  dropIfDead(op, "Trivially dead materialized constant");
1023  continue;
1024  }
1025  // Stop once hit the generated constants.
1026  if (&op == cursor) {
1027  cursor.erase();
1028  aboveCursor = true;
1029  continue;
1030  }
1031 
1032  // Connects to values that we found to be constant can be dropped.
1033  if (auto connect = dyn_cast<FConnectLike>(op)) {
1034  if (auto *destOp = connect.getDest().getDefiningOp()) {
1035  auto fieldRef = getOrCacheFieldRefFromValue(connect.getDest());
1036  // Don't remove a field-level connection even if the src value is
1037  // constant. If other elements of the aggregate value are not constant,
1038  // the aggregate value cannot be replaced. We can forward the constant
1039  // to its users, so IMDCE (or SV/HW canonicalizer) should remove the
1040  // aggregate if entire aggregate is dead.
1041  auto type = type_dyn_cast<FIRRTLType>(connect.getDest().getType());
1042  if (!type)
1043  continue;
1044  auto baseType = type_dyn_cast<FIRRTLBaseType>(type);
1045  if (baseType && !baseType.isGround())
1046  continue;
1047  if (isDeletableWireOrRegOrNode(destOp) && !isOverdefined(fieldRef)) {
1048  connect.erase();
1049  ++numErasedOp;
1050  }
1051  }
1052  continue;
1053  }
1054 
1055  // We only fold single-result ops and instances in practice, because they
1056  // are the expressions.
1057  if (op.getNumResults() != 1 && !isa<InstanceOp>(op))
1058  continue;
1059 
1060  // If this operation is already dead, then go ahead and remove it.
1061  if (dropIfDead(op, "Trivially dead"))
1062  continue;
1063 
1064  // Don't "fold" constants (into equivalent), also because they
1065  // may have name hints we'd like to preserve.
1066  if (op.hasTrait<mlir::OpTrait::ConstantLike>())
1067  continue;
1068 
1069  // If the op had any constants folded, replace them.
1070  builder.setInsertionPoint(&op);
1071  bool foldedAny = false;
1072  for (auto result : op.getResults())
1073  foldedAny |= replaceValueIfPossible(result);
1074 
1075  if (foldedAny)
1076  ++numFoldedOp;
1077 
1078  // If the operation folded to a constant then we can probably nuke it.
1079  if (foldedAny && dropIfDead(op, "Made dead"))
1080  continue;
1081  }
1082 }
1083 
1084 std::unique_ptr<mlir::Pass> circt::firrtl::createIMConstPropPass() {
1085  return std::make_unique<IMConstPropPass>();
1086 }
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static bool isNodeLike(Operation *op)
Definition: IMConstProp.cpp:46
static std::optional< uint64_t > getFieldIDOffset(FieldRef changedFieldRef, Type connectionType, FieldRef connectedValueFieldRef)
static bool isWireOrReg(Operation *op)
Return true if this is a wire or register.
Definition: IMConstProp.cpp:34
static bool isAggregate(Operation *op)
Return true if this is an aggregate indexer.
Definition: IMConstProp.cpp:39
static bool isDeletableWireOrRegOrNode(Operation *op)
Return true if this is a wire or register we're allowed to delete.
Definition: IMConstProp.cpp:51
bool operator!=(const ResetDomain &a, const ResetDomain &b)
Definition: InferResets.cpp:74
bool operator==(const ResetDomain &a, const ResetDomain &b)
Definition: InferResets.cpp:71
Builder builder
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
FieldRef getSubField(unsigned subFieldID) const
Get a reference to a subfield.
Definition: FieldRef.h:62
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Definition: FieldRef.h:59
Value getValue() const
Get the Value which created this location.
Definition: FieldRef.h:37
Location getLoc() const
Get the location associated with the value of this field ref.
Definition: FieldRef.h:67
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool canBeDeleted() const
Check if every annotation can be deleted.
This graph tracks modules and where they are instantiated.
def connect(destination, source)
Definition: support.py:37
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
Definition: FIRRTLUtils.h:217
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
void walkGroundTypes(FIRRTLType firrtlType, llvm::function_ref< void(uint64_t, FIRRTLBaseType, bool)> fn)
Walk leaf ground types in the firrtlType and apply the function fn.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
Definition: FIRRTLOps.cpp:4494
std::unique_ptr< mlir::Pass > createIMConstPropPass()
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
Definition: FIRRTLOps.cpp:287
T & operator<<(T &os, FIRVersion version)
Definition: FIRParser.h:116
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
uint64_t getMaxFieldID(Type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition: APInt.cpp:22
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:20