CIRCT 22.0.0git
Loading...
Searching...
No Matches
FIRRTLOpInterfaces.cpp
Go to the documentation of this file.
1//===- FIRRTLOpInterfaces.cpp - Implement the FIRRTL op interfaces --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the FIRRTL operation interfaces.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/BuiltinAttributes.h"
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/PatternMatch.h"
18#include "llvm/ADT/SmallBitVector.h"
19#include "llvm/ADT/StringRef.h"
20
21using namespace mlir;
22using namespace llvm;
23using namespace circt::firrtl;
24
25static LogicalResult verifyNoInputProbes(FModuleLike module) {
26 // Helper to check for input-oriented refs.
27 std::function<bool(Type, bool)> hasInputRef = [&](Type type,
28 bool output) -> bool {
29 auto ftype = type_dyn_cast<FIRRTLType>(type);
30 if (!ftype || !ftype.containsReference())
31 return false;
33 .Case<RefType>([&](auto reftype) { return !output; })
34 .Case<OpenVectorType>([&](OpenVectorType ovt) {
35 return hasInputRef(ovt.getElementType(), output);
36 })
37 .Case<OpenBundleType>([&](OpenBundleType obt) {
38 for (auto field : obt.getElements())
39 if (hasInputRef(field.type, field.isFlip ^ output))
40 return true;
41 return false;
42 });
43 };
44
45 for (auto &pi : module.getPorts())
46 if (hasInputRef(pi.type, pi.isOutput()))
47 return emitError(pi.loc, "input probe not allowed");
48 return success();
49}
50
51LogicalResult circt::firrtl::verifyModuleLikeOpInterface(FModuleLike module) {
52 // Verify port types first. This is used as the basis for the number of
53 // ports required everywhere else.
54 auto portTypes = module.getPortTypesAttr();
55 if (!portTypes || llvm::any_of(portTypes.getValue(), [](Attribute attr) {
56 return !isa<TypeAttr>(attr);
57 }))
58 return module.emitOpError("requires valid port types");
59
60 auto numPorts = portTypes.size();
61
62 // Verify the port dirctions.
63 auto portDirections = module.getPortDirectionsAttr();
64 if (!portDirections)
65 return module.emitOpError("requires valid port direction");
66 // TODO: bitwidth is 1 when there are no ports, since APInt previously did not
67 // support 0 bit widths.
68 auto bitWidth = portDirections.size();
69 if (static_cast<size_t>(bitWidth) != numPorts)
70 return module.emitOpError("requires ") << numPorts << " port directions";
71
72 // Verify the port names.
73 auto portNames = module.getPortNamesAttr();
74 if (!portNames)
75 return module.emitOpError("requires valid port names");
76 if (portNames.size() != numPorts)
77 return module.emitOpError("requires ") << numPorts << " port names";
78 if (llvm::any_of(portNames.getValue(),
79 [](Attribute attr) { return !isa<StringAttr>(attr); }))
80 return module.emitOpError("port names should all be string attributes");
81
82 // Verify the port annotations.
83 auto portAnnotations = module.getPortAnnotationsAttr();
84 if (!portAnnotations)
85 return module.emitOpError("requires valid port annotations");
86 // TODO: it seems weird to allow empty port annotations.
87 if (!portAnnotations.empty() && portAnnotations.size() != numPorts)
88 return module.emitOpError("requires ") << numPorts << " port annotations";
89 // TODO: Move this into an annotation verifier.
90 for (auto annos : portAnnotations.getValue()) {
91 auto arrayAttr = dyn_cast<ArrayAttr>(annos);
92 if (!arrayAttr)
93 return module.emitOpError(
94 "requires port annotations be array attributes");
95 if (llvm::any_of(arrayAttr.getValue(),
96 [](Attribute attr) { return !isa<DictionaryAttr>(attr); }))
97 return module.emitOpError(
98 "annotations must be dictionaries or subannotations");
99 }
100
101 // Verify the port symbols.
102 auto portSymbols = module.getPortSymbolsAttr();
103 if (!portSymbols)
104 return module.emitOpError("requires valid port symbols");
105 if (!portSymbols.empty() && portSymbols.size() != numPorts)
106 return module.emitOpError("requires ") << numPorts << " port symbols";
107 if (llvm::any_of(portSymbols.getValue(), [](Attribute attr) {
108 return !attr || !isa<hw::InnerSymAttr>(attr);
109 }))
110 return module.emitOpError("port symbols should all be InnerSym attributes");
111
112 // Verify the port locations.
113 auto portLocs = module.getPortLocationsAttr();
114 if (!portLocs)
115 return module.emitOpError("requires valid port locations");
116 if (portLocs.size() != numPorts)
117 return module.emitOpError("requires ") << numPorts << " port locations";
118 if (llvm::any_of(portLocs.getValue(), [](Attribute attr) {
119 return !attr || !isa<LocationAttr>(attr);
120 }))
121 return module.emitOpError("port symbols should all be location attributes");
122
123 // Verify the port domains. This can be either:
124 // 1. An empty ArrayAttr.
125 // 2. An ArrayAttr, one entry-per-port, of:
126 // a. SymbolRefAttrs if the port is a domain type
127 // b. IntegerAttrs if the port is a non-domain type
128 //
129 // Note: error handling here intentionally does _not_ use port info locations.
130 // This is because if any of these fail, this is almost always a
131 // CIRCT-internal bug. These code paths are essoentially inaccessible from
132 // FIRRTL text.
133 auto domains = module.getDomainInfoAttr();
134 // Domain info cannot be null.
135 if (!domains)
136 return module.emitOpError("requires valid port domains");
137 // If non-empty, the one-entry-per-port.
138 if (!domains.empty() && domains.size() != numPorts)
139 return module.emitOpError("requires ")
140 << numPorts << " port domains, but has " << domains.size();
141
142 for (auto [index, port] : llvm::enumerate(module.getPorts())) {
143 auto type = cast<TypeAttr>(portTypes[index]).getValue();
144 // If this is a domain type port, then it _must_ refer to a domain. This
145 // cannot be empty.
146 if (isa<DomainType>(type)) {
147 if (domains.empty())
148 return module.emitOpError() << "has domain port '" << port.getName()
149 << "', but no domain information";
150 if (!isa<FlatSymbolRefAttr>(domains[index]))
151 return module.emitOpError() << "domain information for domain port '"
152 << module.getPortName(index)
153 << "' must be a 'FlatSymbolRefAttr'";
154 continue;
155 }
156
157 // Non-domain type ports can have no domain information.
158 if (domains.empty())
159 continue;
160
161 // If they do have domain information, then the domain info must be an array
162 // of integers and each integer must point to a domain type port.
163 auto domain = domains[index];
164 auto arrayAttr = dyn_cast<ArrayAttr>(domain);
165 if (!arrayAttr)
166 return module.emitOpError()
167 << "domain information for non-domain port '"
168 << module.getPortName(index) << "' must be an 'ArrayAttr'";
169 for (auto attr : arrayAttr) {
170 auto association = dyn_cast<IntegerAttr>(attr);
171 if (!association)
172 return module.emitOpError()
173 << "domain information for non-domain port '"
174 << module.getPortName(index)
175 << "' must be an 'ArrayAttr<IntegerAttr>'";
176 auto associationIdx = association.getValue().getZExtValue();
177 if (associationIdx >= numPorts)
178 return module.emitOpError()
179 << "has domain association " << associationIdx << " for port '"
180 << module.getPortName(index) << "', but the module only has "
181 << numPorts << " ports";
182 if (!type_isa<DomainType>(module.getPortType(associationIdx)))
183 return module.emitOpError()
184 << "has port '" << module.getPortName(index)
185 << "' which has a domain association with non-domain port '"
186 << module.getPortName(associationIdx) << "'";
187 }
188 }
189
190 // Verify the body.
191 if (module->getNumRegions() != 1)
192 return module.emitOpError("requires one region");
193
194 if (failed(verifyNoInputProbes(module)))
195 return failure();
196
197 return success();
198}
199
200//===----------------------------------------------------------------------===//
201// Forceable
202//===----------------------------------------------------------------------===//
203
205 Type type) {
206 auto base = dyn_cast_or_null<FIRRTLBaseType>(type);
207 // TODO: Find a way to not check same things RefType::get/verify does.
208 if (!forceable || !base || base.containsConst())
209 return {};
210 return circt::firrtl::RefType::get(base.getPassiveType(), forceable);
211}
212
213LogicalResult circt::firrtl::detail::verifyForceableOp(Forceable op) {
214 bool forceable = op.isForceable();
215 auto ref = op.getDataRef();
216 if ((bool)ref != forceable)
217 return op.emitOpError("must have ref result iff marked forceable");
218 if (!forceable)
219 return success();
220 auto data = op.getDataRaw();
221 auto baseType = type_dyn_cast<FIRRTLBaseType>(data.getType());
222 if (!baseType)
223 return op.emitOpError("has data that is not a base type");
224 if (baseType.containsConst())
225 return op.emitOpError("cannot force a declaration of constant type");
226 auto expectedRefType = getForceableResultType(forceable, baseType);
227 if (ref.getType() != expectedRefType)
228 return op.emitOpError("reference result of incorrect type, found ")
229 << ref.getType() << ", expected " << expectedRefType;
230 return success();
231}
232
233namespace {
234/// Simple wrapper to allow construction from a context for local use.
235class TrivialPatternRewriter : public PatternRewriter {
236public:
237 explicit TrivialPatternRewriter(MLIRContext *context)
238 : PatternRewriter(context) {}
239};
240} // end namespace
241
242Forceable
243circt::firrtl::detail::replaceWithNewForceability(Forceable op, bool forceable,
244 PatternRewriter *rewriter) {
245 if (forceable == op.isForceable())
246 return op;
247
248 assert(op->getNumRegions() == 0);
249
250 // Create copy of this operation with/without the forceable marker + result
251 // type.
252
253 TrivialPatternRewriter localRewriter(op.getContext());
254 PatternRewriter &rw = rewriter ? *rewriter : localRewriter;
255
256 // Grab the current operation's results and attributes.
257 SmallVector<Type, 8> resultTypes(op->getResultTypes());
258 SmallVector<NamedAttribute, 16> attributes(op->getAttrs());
259
260 // Add/remove the optional ref result.
261 auto refType = firrtl::detail::getForceableResultType(true, op.getDataType());
262 if (forceable)
263 resultTypes.push_back(refType);
264 else {
265 assert(resultTypes.back() == refType &&
266 "expected forceable type as last result");
267 resultTypes.pop_back();
268 }
269
270 // Add/remove the forceable marker.
271 auto forceableMarker =
272 rw.getNamedAttr(op.getForceableAttrName(), rw.getUnitAttr());
273 if (forceable)
274 attributes.push_back(forceableMarker);
275 else {
276 llvm::erase(attributes, forceableMarker);
277 assert(attributes.size() != op->getAttrs().size());
278 }
279
280 // Create the replacement operation.
281 OperationState state(op.getLoc(), op->getName(), op->getOperands(),
282 resultTypes, attributes, op->getSuccessors());
283 rw.setInsertionPoint(op);
284 auto *replace = rw.create(state);
285
286 // Dropping forceability (!forceable) -> no uses of forceable ref handle.
287 assert(forceable || op.getDataRef().use_empty());
288
289 // Replace results.
290 for (auto result : llvm::drop_end(op->getResults(), forceable ? 0 : 1))
291 rw.replaceAllUsesWith(result, replace->getResult(result.getResultNumber()));
292 rw.eraseOp(op);
293 return cast<Forceable>(replace);
294}
295
296#include "circt/Dialect/FIRRTL/FIRRTLOpInterfaces.cpp.inc"
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static LogicalResult verifyNoInputProbes(FModuleLike module)
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
LogicalResult verifyForceableOp(Forceable op)
Verify a Forceable op.
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
RefType getForceableResultType(bool forceable, Type type)
Return null or forceable reference result type.
LogicalResult verifyModuleLikeOpInterface(FModuleLike module)
Verification hook for verifying module like operations.