CIRCT  20.0.0git
FIRRTLOpInterfaces.cpp
Go to the documentation of this file.
1 //===- FIRRTLOpInterfaces.cpp - Implement the FIRRTL op interfaces --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the FIRRTL operation interfaces.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "llvm/ADT/SmallBitVector.h"
19 #include "llvm/ADT/StringRef.h"
20 
21 using namespace mlir;
22 using namespace llvm;
23 using namespace circt::firrtl;
24 
25 static LogicalResult verifyNoInputProbes(FModuleLike module) {
26  // Helper to check for input-oriented refs.
27  std::function<bool(Type, bool)> hasInputRef = [&](Type type,
28  bool output) -> bool {
29  auto ftype = type_dyn_cast<FIRRTLType>(type);
30  if (!ftype || !ftype.containsReference())
31  return false;
33  .Case<RefType>([&](auto reftype) { return !output; })
34  .Case<OpenVectorType>([&](OpenVectorType ovt) {
35  return hasInputRef(ovt.getElementType(), output);
36  })
37  .Case<OpenBundleType>([&](OpenBundleType obt) {
38  for (auto field : obt.getElements())
39  if (hasInputRef(field.type, field.isFlip ^ output))
40  return true;
41  return false;
42  });
43  };
44 
45  for (auto &pi : module.getPorts())
46  if (hasInputRef(pi.type, pi.isOutput()))
47  return emitError(pi.loc, "input probe not allowed");
48  return success();
49 }
50 
51 LogicalResult circt::firrtl::verifyModuleLikeOpInterface(FModuleLike module) {
52  // Verify port types first. This is used as the basis for the number of
53  // ports required everywhere else.
54  auto portTypes = module.getPortTypesAttr();
55  if (!portTypes || llvm::any_of(portTypes.getValue(), [](Attribute attr) {
56  return !isa<TypeAttr>(attr);
57  }))
58  return module.emitOpError("requires valid port types");
59 
60  auto numPorts = portTypes.size();
61 
62  // Verify the port dirctions.
63  auto portDirections = module.getPortDirectionsAttr();
64  if (!portDirections)
65  return module.emitOpError("requires valid port direction");
66  // TODO: bitwidth is 1 when there are no ports, since APInt previously did not
67  // support 0 bit widths.
68  auto bitWidth = portDirections.size();
69  if (static_cast<size_t>(bitWidth) != numPorts)
70  return module.emitOpError("requires ") << numPorts << " port directions";
71 
72  // Verify the port names.
73  auto portNames = module.getPortNamesAttr();
74  if (!portNames)
75  return module.emitOpError("requires valid port names");
76  if (portNames.size() != numPorts)
77  return module.emitOpError("requires ") << numPorts << " port names";
78  if (llvm::any_of(portNames.getValue(),
79  [](Attribute attr) { return !isa<StringAttr>(attr); }))
80  return module.emitOpError("port names should all be string attributes");
81 
82  // Verify the port annotations.
83  auto portAnnotations = module.getPortAnnotationsAttr();
84  if (!portAnnotations)
85  return module.emitOpError("requires valid port annotations");
86  // TODO: it seems weird to allow empty port annotations.
87  if (!portAnnotations.empty() && portAnnotations.size() != numPorts)
88  return module.emitOpError("requires ") << numPorts << " port annotations";
89  // TODO: Move this into an annotation verifier.
90  for (auto annos : portAnnotations.getValue()) {
91  auto arrayAttr = dyn_cast<ArrayAttr>(annos);
92  if (!arrayAttr)
93  return module.emitOpError(
94  "requires port annotations be array attributes");
95  if (llvm::any_of(arrayAttr.getValue(),
96  [](Attribute attr) { return !isa<DictionaryAttr>(attr); }))
97  return module.emitOpError(
98  "annotations must be dictionaries or subannotations");
99  }
100 
101  // Verify the port symbols.
102  auto portSymbols = module.getPortSymbolsAttr();
103  if (!portSymbols)
104  return module.emitOpError("requires valid port symbols");
105  if (!portSymbols.empty() && portSymbols.size() != numPorts)
106  return module.emitOpError("requires ") << numPorts << " port symbols";
107  if (llvm::any_of(portSymbols.getValue(), [](Attribute attr) {
108  return !attr || !isa<hw::InnerSymAttr>(attr);
109  }))
110  return module.emitOpError("port symbols should all be InnerSym attributes");
111 
112  // Verify the port locations.
113  auto portLocs = module.getPortLocationsAttr();
114  if (!portLocs)
115  return module.emitOpError("requires valid port locations");
116  if (portLocs.size() != numPorts)
117  return module.emitOpError("requires ") << numPorts << " port locations";
118  if (llvm::any_of(portLocs.getValue(), [](Attribute attr) {
119  return !attr || !isa<LocationAttr>(attr);
120  }))
121  return module.emitOpError("port symbols should all be location attributes");
122 
123  // Verify the body.
124  if (module->getNumRegions() != 1)
125  return module.emitOpError("requires one region");
126 
127  if (failed(verifyNoInputProbes(module)))
128  return failure();
129 
130  return success();
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // Forceable
135 //===----------------------------------------------------------------------===//
136 
138  Type type) {
139  auto base = dyn_cast_or_null<FIRRTLBaseType>(type);
140  // TODO: Find a way to not check same things RefType::get/verify does.
141  if (!forceable || !base || base.containsConst())
142  return {};
143  return circt::firrtl::RefType::get(base.getPassiveType(), forceable);
144 }
145 
146 LogicalResult circt::firrtl::detail::verifyForceableOp(Forceable op) {
147  bool forceable = op.isForceable();
148  auto ref = op.getDataRef();
149  if ((bool)ref != forceable)
150  return op.emitOpError("must have ref result iff marked forceable");
151  if (!forceable)
152  return success();
153  auto data = op.getDataRaw();
154  auto baseType = type_dyn_cast<FIRRTLBaseType>(data.getType());
155  if (!baseType)
156  return op.emitOpError("has data that is not a base type");
157  if (baseType.containsConst())
158  return op.emitOpError("cannot force a declaration of constant type");
159  auto expectedRefType = getForceableResultType(forceable, baseType);
160  if (ref.getType() != expectedRefType)
161  return op.emitOpError("reference result of incorrect type, found ")
162  << ref.getType() << ", expected " << expectedRefType;
163  return success();
164 }
165 
166 namespace {
167 /// Simple wrapper to allow construction from a context for local use.
168 class TrivialPatternRewriter : public PatternRewriter {
169 public:
170  explicit TrivialPatternRewriter(MLIRContext *context)
171  : PatternRewriter(context) {}
172 };
173 } // end namespace
174 
175 Forceable
176 circt::firrtl::detail::replaceWithNewForceability(Forceable op, bool forceable,
177  PatternRewriter *rewriter) {
178  if (forceable == op.isForceable())
179  return op;
180 
181  assert(op->getNumRegions() == 0);
182 
183  // Create copy of this operation with/without the forceable marker + result
184  // type.
185 
186  TrivialPatternRewriter localRewriter(op.getContext());
187  PatternRewriter &rw = rewriter ? *rewriter : localRewriter;
188 
189  // Grab the current operation's results and attributes.
190  SmallVector<Type, 8> resultTypes(op->getResultTypes());
191  SmallVector<NamedAttribute, 16> attributes(op->getAttrs());
192 
193  // Add/remove the optional ref result.
194  auto refType = firrtl::detail::getForceableResultType(true, op.getDataType());
195  if (forceable)
196  resultTypes.push_back(refType);
197  else {
198  assert(resultTypes.back() == refType &&
199  "expected forceable type as last result");
200  resultTypes.pop_back();
201  }
202 
203  // Add/remove the forceable marker.
204  auto forceableMarker =
205  rw.getNamedAttr(op.getForceableAttrName(), rw.getUnitAttr());
206  if (forceable)
207  attributes.push_back(forceableMarker);
208  else {
209  llvm::erase(attributes, forceableMarker);
210  assert(attributes.size() != op->getAttrs().size());
211  }
212 
213  // Create the replacement operation.
214  OperationState state(op.getLoc(), op->getName(), op->getOperands(),
215  resultTypes, attributes, op->getSuccessors());
216  rw.setInsertionPoint(op);
217  auto *replace = rw.create(state);
218 
219  // Dropping forceability (!forceable) -> no uses of forceable ref handle.
220  assert(forceable || op.getDataRef().use_empty());
221 
222  // Replace results.
223  for (auto result : llvm::drop_end(op->getResults(), forceable ? 0 : 1))
224  rw.replaceAllUsesWith(result, replace->getResult(result.getResultNumber()));
225  rw.eraseOp(op);
226  return cast<Forceable>(replace);
227 }
228 
229 #include "circt/Dialect/FIRRTL/FIRRTLOpInterfaces.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyNoInputProbes(FModuleLike module)
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
LogicalResult verifyForceableOp(Forceable op)
Verify a Forceable op.
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
RefType getForceableResultType(bool forceable, Type type)
Return null or forceable reference result type.
LogicalResult verifyModuleLikeOpInterface(FModuleLike module)
Verification hook for verifying module like operations.