Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
FIRRTLOpInterfaces.cpp
Go to the documentation of this file.
1//===- FIRRTLOpInterfaces.cpp - Implement the FIRRTL op interfaces --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the FIRRTL operation interfaces.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/BuiltinAttributes.h"
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/PatternMatch.h"
18#include "llvm/ADT/SmallBitVector.h"
19#include "llvm/ADT/StringRef.h"
20
21using namespace mlir;
22using namespace llvm;
23using namespace circt::firrtl;
24
25static LogicalResult verifyNoInputProbes(FModuleLike module) {
26 // Helper to check for input-oriented refs.
27 std::function<bool(Type, bool)> hasInputRef = [&](Type type,
28 bool output) -> bool {
29 auto ftype = type_dyn_cast<FIRRTLType>(type);
30 if (!ftype || !ftype.containsReference())
31 return false;
33 .Case<RefType>([&](auto reftype) { return !output; })
34 .Case<OpenVectorType>([&](OpenVectorType ovt) {
35 return hasInputRef(ovt.getElementType(), output);
36 })
37 .Case<OpenBundleType>([&](OpenBundleType obt) {
38 for (auto field : obt.getElements())
39 if (hasInputRef(field.type, field.isFlip ^ output))
40 return true;
41 return false;
42 });
43 };
44
45 for (auto &pi : module.getPorts())
46 if (hasInputRef(pi.type, pi.isOutput()))
47 return emitError(pi.loc, "input probe not allowed");
48 return success();
49}
50
51LogicalResult circt::firrtl::verifyModuleLikeOpInterface(FModuleLike module) {
52 // Verify port types first. This is used as the basis for the number of
53 // ports required everywhere else.
54 auto portTypes = module.getPortTypesAttr();
55 if (!portTypes || llvm::any_of(portTypes.getValue(), [](Attribute attr) {
56 return !isa<TypeAttr>(attr);
57 }))
58 return module.emitOpError("requires valid port types");
59
60 auto numPorts = portTypes.size();
61
62 // Verify the port dirctions.
63 auto portDirections = module.getPortDirectionsAttr();
64 if (!portDirections)
65 return module.emitOpError("requires valid port direction");
66 // TODO: bitwidth is 1 when there are no ports, since APInt previously did not
67 // support 0 bit widths.
68 auto bitWidth = portDirections.size();
69 if (static_cast<size_t>(bitWidth) != numPorts)
70 return module.emitOpError("requires ") << numPorts << " port directions";
71
72 // Verify the port names.
73 auto portNames = module.getPortNamesAttr();
74 if (!portNames)
75 return module.emitOpError("requires valid port names");
76 if (portNames.size() != numPorts)
77 return module.emitOpError("requires ") << numPorts << " port names";
78 if (llvm::any_of(portNames.getValue(),
79 [](Attribute attr) { return !isa<StringAttr>(attr); }))
80 return module.emitOpError("port names should all be string attributes");
81
82 // Verify the port annotations.
83 auto portAnnotations = module.getPortAnnotationsAttr();
84 if (!portAnnotations)
85 return module.emitOpError("requires valid port annotations");
86 // TODO: it seems weird to allow empty port annotations.
87 if (!portAnnotations.empty() && portAnnotations.size() != numPorts)
88 return module.emitOpError("requires ") << numPorts << " port annotations";
89 // TODO: Move this into an annotation verifier.
90 for (auto annos : portAnnotations.getValue()) {
91 auto arrayAttr = dyn_cast<ArrayAttr>(annos);
92 if (!arrayAttr)
93 return module.emitOpError(
94 "requires port annotations be array attributes");
95 if (llvm::any_of(arrayAttr.getValue(),
96 [](Attribute attr) { return !isa<DictionaryAttr>(attr); }))
97 return module.emitOpError(
98 "annotations must be dictionaries or subannotations");
99 }
100
101 // Verify the port symbols.
102 auto portSymbols = module.getPortSymbolsAttr();
103 if (!portSymbols)
104 return module.emitOpError("requires valid port symbols");
105 if (!portSymbols.empty() && portSymbols.size() != numPorts)
106 return module.emitOpError("requires ") << numPorts << " port symbols";
107 if (llvm::any_of(portSymbols.getValue(), [](Attribute attr) {
108 return !attr || !isa<hw::InnerSymAttr>(attr);
109 }))
110 return module.emitOpError("port symbols should all be InnerSym attributes");
111
112 // Verify the port locations.
113 auto portLocs = module.getPortLocationsAttr();
114 if (!portLocs)
115 return module.emitOpError("requires valid port locations");
116 if (portLocs.size() != numPorts)
117 return module.emitOpError("requires ") << numPorts << " port locations";
118 if (llvm::any_of(portLocs.getValue(), [](Attribute attr) {
119 return !attr || !isa<LocationAttr>(attr);
120 }))
121 return module.emitOpError("port symbols should all be location attributes");
122
123 // Verify the port domains. This can be either:
124 // 1. An empty ArrayAttr.
125 // 2. An ArrayAttr, one entry-per-port, of:
126 // a. SymbolRefAttrs if the port is a domain type
127 // b. IntegerAttrs if the port is a non-domain type
128 auto domains = module.getDomainInfoAttr();
129 if (!domains)
130 return module.emitOpError("requires valid port domains");
131 if (!domains.empty() && domains.size() != numPorts)
132 return module.emitOpError("requires ")
133 << numPorts << " port domains, but has " << domains.size();
134 for (auto [index, domain] : llvm::enumerate(domains)) {
135 auto type = cast<TypeAttr>(portTypes[index]).getValue();
136 if (isa<DomainType>(type)) {
137 if (!isa<FlatSymbolRefAttr>(domain))
138 return module.emitOpError() << "domain information for domain port '"
139 << module.getPortName(index)
140 << "' must be a 'FlatSymbolRefAttr'";
141 continue;
142 }
143 auto arrayAttr = dyn_cast<ArrayAttr>(domain);
144 if (!arrayAttr)
145 return module.emitOpError()
146 << "domain information for non-domain port '"
147 << module.getPortName(index) << "' must be an 'ArrayAttr'";
148 if (llvm::any_of(arrayAttr,
149 [](Attribute attr) { return !isa<IntegerAttr>(attr); }))
150 return module.emitOpError() << "domain information for non-domain port '"
151 << module.getPortName(index)
152 << "' must be an 'ArrayAttr<IntegerAttr>'";
153 }
154
155 // Verify the body.
156 if (module->getNumRegions() != 1)
157 return module.emitOpError("requires one region");
158
159 if (failed(verifyNoInputProbes(module)))
160 return failure();
161
162 return success();
163}
164
165//===----------------------------------------------------------------------===//
166// Forceable
167//===----------------------------------------------------------------------===//
168
170 Type type) {
171 auto base = dyn_cast_or_null<FIRRTLBaseType>(type);
172 // TODO: Find a way to not check same things RefType::get/verify does.
173 if (!forceable || !base || base.containsConst())
174 return {};
175 return circt::firrtl::RefType::get(base.getPassiveType(), forceable);
176}
177
178LogicalResult circt::firrtl::detail::verifyForceableOp(Forceable op) {
179 bool forceable = op.isForceable();
180 auto ref = op.getDataRef();
181 if ((bool)ref != forceable)
182 return op.emitOpError("must have ref result iff marked forceable");
183 if (!forceable)
184 return success();
185 auto data = op.getDataRaw();
186 auto baseType = type_dyn_cast<FIRRTLBaseType>(data.getType());
187 if (!baseType)
188 return op.emitOpError("has data that is not a base type");
189 if (baseType.containsConst())
190 return op.emitOpError("cannot force a declaration of constant type");
191 auto expectedRefType = getForceableResultType(forceable, baseType);
192 if (ref.getType() != expectedRefType)
193 return op.emitOpError("reference result of incorrect type, found ")
194 << ref.getType() << ", expected " << expectedRefType;
195 return success();
196}
197
198namespace {
199/// Simple wrapper to allow construction from a context for local use.
200class TrivialPatternRewriter : public PatternRewriter {
201public:
202 explicit TrivialPatternRewriter(MLIRContext *context)
203 : PatternRewriter(context) {}
204};
205} // end namespace
206
207Forceable
208circt::firrtl::detail::replaceWithNewForceability(Forceable op, bool forceable,
209 PatternRewriter *rewriter) {
210 if (forceable == op.isForceable())
211 return op;
212
213 assert(op->getNumRegions() == 0);
214
215 // Create copy of this operation with/without the forceable marker + result
216 // type.
217
218 TrivialPatternRewriter localRewriter(op.getContext());
219 PatternRewriter &rw = rewriter ? *rewriter : localRewriter;
220
221 // Grab the current operation's results and attributes.
222 SmallVector<Type, 8> resultTypes(op->getResultTypes());
223 SmallVector<NamedAttribute, 16> attributes(op->getAttrs());
224
225 // Add/remove the optional ref result.
226 auto refType = firrtl::detail::getForceableResultType(true, op.getDataType());
227 if (forceable)
228 resultTypes.push_back(refType);
229 else {
230 assert(resultTypes.back() == refType &&
231 "expected forceable type as last result");
232 resultTypes.pop_back();
233 }
234
235 // Add/remove the forceable marker.
236 auto forceableMarker =
237 rw.getNamedAttr(op.getForceableAttrName(), rw.getUnitAttr());
238 if (forceable)
239 attributes.push_back(forceableMarker);
240 else {
241 llvm::erase(attributes, forceableMarker);
242 assert(attributes.size() != op->getAttrs().size());
243 }
244
245 // Create the replacement operation.
246 OperationState state(op.getLoc(), op->getName(), op->getOperands(),
247 resultTypes, attributes, op->getSuccessors());
248 rw.setInsertionPoint(op);
249 auto *replace = rw.create(state);
250
251 // Dropping forceability (!forceable) -> no uses of forceable ref handle.
252 assert(forceable || op.getDataRef().use_empty());
253
254 // Replace results.
255 for (auto result : llvm::drop_end(op->getResults(), forceable ? 0 : 1))
256 rw.replaceAllUsesWith(result, replace->getResult(result.getResultNumber()));
257 rw.eraseOp(op);
258 return cast<Forceable>(replace);
259}
260
261#include "circt/Dialect/FIRRTL/FIRRTLOpInterfaces.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyNoInputProbes(FModuleLike module)
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
LogicalResult verifyForceableOp(Forceable op)
Verify a Forceable op.
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
RefType getForceableResultType(bool forceable, Type type)
Return null or forceable reference result type.
LogicalResult verifyModuleLikeOpInterface(FModuleLike module)
Verification hook for verifying module like operations.