CIRCT  20.0.0git
SFCCompat.cpp
Go to the documentation of this file.
1 //===- SFCCompat.cpp - SFC Compatible Pass ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // This pass makes a number of updates to the circuit that are required to match
9 // the behavior of the Scala FIRRTL Compiler (SFC). This pass removes invalid
10 // values from the circuit. This is a combination of the Scala FIRRTL
11 // Compiler's RemoveResets pass and RemoveValidIf. This is done to remove two
12 // "interpretations" of invalid. Namely: (1) registers that are initialized to
13 // an invalid value (module scoped and looking through wires and connects only)
14 // are converted to an uninitialized register and (2) invalid values are converted
15 // to zero (after rule 1 is applied). Additionally, this pass checks and
16 // disallows async reset registers that are not driven with a constant when
17 // looking through wires, connects, and nodes.
18 //
19 //===----------------------------------------------------------------------===//
20 
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "mlir/Pass/Pass.h"
27 #include "llvm/Support/Debug.h"
28 
29 #define DEBUG_TYPE "firrtl-remove-resets"
30 
31 namespace circt {
32 namespace firrtl {
33 #define GEN_PASS_DEF_SFCCOMPAT
34 #include "circt/Dialect/FIRRTL/Passes.h.inc"
35 } // namespace firrtl
36 } // namespace circt
37 
38 using namespace circt;
39 using namespace firrtl;
40 
42  : public circt::firrtl::impl::SFCCompatBase<SFCCompatPass> {
43  void runOnOperation() override;
44 };
45 
47  LLVM_DEBUG(
48  llvm::dbgs() << "==----- Running SFCCompat "
49  "---------------------------------------------------===\n"
50  << "Module: '" << getOperation().getName() << "'\n";);
51 
52  bool madeModifications = false;
53  SmallVector<InvalidValueOp> invalidOps;
54 
55  auto fullResetAttr = StringAttr::get(&getContext(), fullResetAnnoClass);
56  auto isFullResetAnno = [fullResetAttr](Annotation anno) {
57  auto annoClassAttr = anno.getClassAttr();
58  return annoClassAttr == fullResetAttr;
59  };
60  bool fullResetExists = AnnotationSet::removePortAnnotations(
61  getOperation(),
62  [&](unsigned argNum, Annotation anno) { return isFullResetAnno(anno); });
63  getOperation()->walk([isFullResetAnno, &fullResetExists](Operation *op) {
64  fullResetExists |= AnnotationSet::removeAnnotations(op, isFullResetAnno);
65  });
66  madeModifications |= fullResetExists;
67 
68  auto result = getOperation()->walk([&](Operation *op) {
69  // Populate invalidOps for later handling.
70  if (auto inv = dyn_cast<InvalidValueOp>(op)) {
71  invalidOps.push_back(inv);
72  return WalkResult::advance();
73  }
74  auto reg = dyn_cast<RegResetOp>(op);
75  if (!reg)
76  return WalkResult::advance();
77 
78  // If the `RegResetOp` has an invalidated initialization and we
79  // are not running FART, then replace it with a `RegOp`.
80  if (!fullResetExists && walkDrivers(reg.getResetValue(), true, true, false,
81  [](FieldRef dst, FieldRef src) {
82  return src.isa<InvalidValueOp>();
83  })) {
84  ImplicitLocOpBuilder builder(reg.getLoc(), reg);
85  RegOp newReg = builder.create<RegOp>(
86  reg.getResult().getType(), reg.getClockVal(), reg.getNameAttr(),
87  reg.getNameKindAttr(), reg.getAnnotationsAttr(),
88  reg.getInnerSymAttr(), reg.getForceableAttr());
89  reg.replaceAllUsesWith(newReg);
90  reg.erase();
91  madeModifications = true;
92  return WalkResult::advance();
93  }
94 
95  // If the `RegResetOp` has an asynchronous reset and the reset value is not
96  // a module-scoped constant when looking through wires and nodes, then
97  // generate an error. This implements the SFC's CheckResets pass.
98  if (!isa<AsyncResetType>(reg.getResetSignal().getType()))
99  return WalkResult::advance();
100  if (walkDrivers(
101  reg.getResetValue(), true, true, true,
102  [&](FieldRef dst, FieldRef src) {
103  if (src.isa<ConstantOp, InvalidValueOp, SpecialConstantOp,
104  AggregateConstantOp>())
105  return true;
106  auto diag = emitError(reg.getLoc());
107  auto [fieldName, rootKnown] = getFieldName(dst);
108  diag << "register " << reg.getNameAttr()
109  << " has an async reset, but its reset value";
110  if (rootKnown)
111  diag << " \"" << fieldName << "\"";
112  diag << " is not driven with a constant value through wires, "
113  "nodes, or connects";
114  std::tie(fieldName, rootKnown) = getFieldName(src);
115  diag.attachNote(src.getLoc())
116  << "reset driver is "
117  << (rootKnown ? ("\"" + fieldName + "\"") : "here");
118  return false;
119  }))
120  return WalkResult::advance();
121  return WalkResult::interrupt();
122  });
123 
124  if (result.wasInterrupted())
125  return signalPassFailure();
126 
127  // Convert all invalid values to zero.
128  for (auto inv : invalidOps) {
129  // Delete invalids which have no uses.
130  if (inv->getUses().empty()) {
131  inv->erase();
132  madeModifications = true;
133  continue;
134  }
135  ImplicitLocOpBuilder builder(inv.getLoc(), inv);
136  Value replacement =
138  .Case<ClockType, AsyncResetType, ResetType>(
139  [&](auto type) -> Value {
140  return builder.create<SpecialConstantOp>(
141  type, builder.getBoolAttr(false));
142  })
143  .Case<IntType>([&](IntType type) -> Value {
144  return builder.create<ConstantOp>(type, getIntZerosAttr(type));
145  })
146  .Case<BundleType, FVectorType>([&](auto type) -> Value {
147  auto width = circt::firrtl::getBitWidth(type);
148  assert(width && "width must be inferred");
149  auto zero = builder.create<ConstantOp>(APSInt(*width));
150  return builder.create<BitCastOp>(type, zero);
151  })
152  .Default([&](auto) {
153  llvm_unreachable("all types are supported");
154  return Value();
155  });
156  inv.replaceAllUsesWith(replacement);
157  inv.erase();
158  madeModifications = true;
159  }
160 
161  if (!madeModifications)
162  return markAllAnalysesPreserved();
163 }
164 
165 std::unique_ptr<mlir::Pass> circt::firrtl::createSFCCompatPass() {
166  return std::make_unique<SFCCompatPass>();
167 }
assert(baseType &&"element must be base type")
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
This is the common base class between SIntType and UIntType.
Definition: FIRRTLTypes.h:296
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
constexpr const char * fullResetAnnoClass
Annotation that marks a reset (port or wire) and domain.
std::unique_ptr< mlir::Pass > createSFCCompatPass()
Definition: SFCCompat.cpp:165
bool walkDrivers(FIRRTLBaseValue value, bool lookThroughWires, bool lookThroughNodes, bool lookThroughCasts, WalkDriverCallback callback)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:21
void runOnOperation() override
Definition: SFCCompat.cpp:46