15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "llvm/ADT/SmallBitVector.h"
19 #include "llvm/ADT/StringRef.h"
// NOTE(review): excerpt of a numbered source listing. Several original lines
// (e.g. 31-32, 36, 40-44) are absent, so this fragment does not compile as-is.
//
// hasInputRef(type, output): recursive predicate that reports whether `type`
// carries a reference (probe) in the input direction. `output` is the
// direction of the current position; flipped bundle fields invert it.
27 std::function<bool(Type,
bool)> hasInputRef = [&](Type type,
28 bool output) ->
bool {
29 auto ftype = type_dyn_cast<FIRRTLType>(type);
// Non-FIRRTL types and types without any reference cannot contain a probe.
// The body of this early exit is on a missing line (presumably
// `return false;` -- TODO confirm).
30 if (!ftype || !ftype.containsReference())
// A bare reference counts as an input probe exactly when the current
// direction is not output.
33 .
Case<RefType>([&](
auto reftype) {
return !output; })
// Open vectors: recurse into the element type with the direction unchanged.
34 .Case<OpenVectorType>([&](OpenVectorType ovt) {
35 return hasInputRef(ovt.getElementType(), output);
// Open bundles: each field is checked with the direction inverted when the
// field is marked flip.
37 .Case<OpenBundleType>([&](OpenBundleType obt) {
38 for (
auto field : obt.getElements())
39 if (hasInputRef(field.type, field.isFlip ^ output))
// Public modules must not expose input probes: the first port whose type
// contains one produces an error at that port's location.
45 if (module.isPublic()) {
46 for (
auto &pi : module.getPorts()) {
47 if (hasInputRef(pi.type, pi.isOutput()))
48 return emitError(pi.loc,
"input probe not allowed on public module");
// Fragment of verifyModuleLikeOpInterface (the function header and several
// lines are missing from this excerpt). It checks the port attributes in a
// fixed order: types, directions, names, annotations, symbols and locations.
// It then checks that the op has exactly one region. Each failed check
// returns the corresponding emitOpError diagnostic.
//
// NOTE(review): several `if (!attr)` guard lines (e.g. 67, 77, 87, 106, 117)
// are absent. Presumably each null-checks the attribute fetched just above
// it -- TODO confirm.

// Port types: every entry must be a TypeAttr. The count of entries defines
// the expected number of ports for all later checks.
57 auto portTypes = module.getPortTypesAttr();
58 if (!portTypes || llvm::any_of(portTypes.getValue(), [](Attribute attr) {
59 return !isa<TypeAttr>(attr);
61 return module.emitOpError(
"requires valid port types");
63 auto numPorts = portTypes.size();
// Port directions: one entry per port is required.
66 auto portDirections = module.getPortDirectionsAttr();
68 return module.emitOpError(
"requires valid port direction");
71 auto bitWidth = portDirections.size();
72 if (
static_cast<size_t>(bitWidth) != numPorts)
73 return module.emitOpError(
"requires ") << numPorts <<
" port directions";
// Port names: exactly one StringAttr per port.
76 auto portNames = module.getPortNamesAttr();
78 return module.emitOpError(
"requires valid port names");
79 if (portNames.size() != numPorts)
80 return module.emitOpError(
"requires ") << numPorts <<
" port names";
81 if (llvm::any_of(portNames.getValue(),
82 [](Attribute attr) { return !isa<StringAttr>(attr); }))
83 return module.emitOpError(
"port names should all be string attributes");
// Port annotations: may be empty. Otherwise there must be one entry per port,
// and each entry must be an array of dictionaries.
86 auto portAnnotations = module.getPortAnnotationsAttr();
88 return module.emitOpError(
"requires valid port annotations");
90 if (!portAnnotations.empty() && portAnnotations.size() != numPorts)
91 return module.emitOpError(
"requires ") << numPorts <<
" port annotations";
93 for (
auto annos : portAnnotations.getValue()) {
94 auto arrayAttr = dyn_cast<ArrayAttr>(annos);
96 return module.emitOpError(
97 "requires port annotations be array attributes");
98 if (llvm::any_of(arrayAttr.getValue(),
99 [](Attribute attr) { return !isa<DictionaryAttr>(attr); }))
100 return module.emitOpError(
101 "annotations must be dictionaries or subannotations");
// Port symbols: may be empty. Otherwise there must be one entry per port,
// and each entry must be a non-null hw::InnerSymAttr.
105 auto portSymbols = module.getPortSymbolsAttr();
107 return module.emitOpError(
"requires valid port symbols");
108 if (!portSymbols.empty() && portSymbols.size() != numPorts)
109 return module.emitOpError(
"requires ") << numPorts <<
" port symbols";
110 if (llvm::any_of(portSymbols.getValue(), [](Attribute attr) {
111 return !attr || !isa<hw::InnerSymAttr>(attr);
113 return module.emitOpError(
"port symbols should all be InnerSym attributes");
// Port locations: exactly one non-null LocationAttr per port.
116 auto portLocs = module.getPortLocationsAttr();
118 return module.emitOpError(
"requires valid port locations");
119 if (portLocs.size() != numPorts)
120 return module.emitOpError(
"requires ") << numPorts <<
" port locations";
121 if (llvm::any_of(portLocs.getValue(), [](Attribute attr) {
122 return !attr || !isa<LocationAttr>(attr);
// NOTE(review): this diagnostic says "port symbols" but the check is about
// port *locations*. It looks copy-pasted from the symbol check above;
// consider "port locations should all be location attributes".
124 return module.emitOpError(
"port symbols should all be location attributes");
// The module-like op must own exactly one region.
127 if (module->getNumRegions() != 1)
128 return module.emitOpError(
"requires one region");
// Fragment of getForceableResultType(bool forceable, Type type), which
// returns either a null RefType or the reference type of a forceable result.
// The visible condition rejects three cases:
//   - non-forceable requests,
//   - non-base types (including a null `type`, via dyn_cast_or_null),
//   - types that contain const.
// Presumably this branch returns the null RefType. Its body and the
// successful return are on lines missing from this excerpt -- TODO confirm.
142 auto base = dyn_cast_or_null<FIRRTLBaseType>(type);
144 if (!forceable || !base || base.containsConst())
// Fragment of verifyForceableOp: verifies the invariants of a Forceable op.
// The op must have a reference result if and only if it is marked forceable.
// Lines 154-155 are missing; presumably they return success early when the
// op is not forceable -- TODO confirm.
150 bool forceable = op.isForceable();
151 auto ref = op.getDataRef();
152 if ((
bool)ref != forceable)
153 return op.emitOpError(
"must have ref result iff marked forceable");
// The forced data must be a FIRRTL base type. The null check on `baseType`
// is on missing line 158.
156 auto data = op.getDataRaw();
157 auto baseType = type_dyn_cast<FIRRTLBaseType>(
data.getType());
159 return op.emitOpError(
"has data that is not a base type");
// Constant-typed declarations cannot be forced.
160 if (baseType.containsConst())
161 return op.emitOpError(
"cannot force a declaration of constant type");
// The reference result must have the expected type. `expectedRefType` is
// computed on missing line 162 (presumably derived from `baseType` -- TODO
// confirm).
163 if (ref.getType() != expectedRefType)
164 return op.emitOpError(
"reference result of incorrect type, found ")
165 << ref.getType() <<
", expected " << expectedRefType;
// Minimal PatternRewriter subclass that is constructed directly from an
// MLIRContext. replaceWithNewForceability uses it as a local fallback when
// the caller passes no rewriter. Presumably it exists because PatternRewriter
// itself is not meant to be constructed directly -- confirm.
// NOTE(review): the access specifier (line 172) and the closing brace of the
// class are missing from this excerpt.
171 class TrivialPatternRewriter :
public PatternRewriter {
173 explicit TrivialPatternRewriter(MLIRContext *context)
174 : PatternRewriter(context) {}
// Fragment of replaceWithNewForceability(Forceable op, bool forceable,
// PatternRewriter *rewriter): replaces `op` with an equivalent op whose
// forceability is `forceable`. It adds or removes the trailing reference
// result and the forceable unit attribute, then returns the new op.
// The start of the signature (line 179) is missing from this excerpt.

// Nothing to do when the requested forceability already matches. The return
// on line 182 is missing from this excerpt.
180 PatternRewriter *rewriter) {
181 if (forceable == op.isForceable())
// Regions are not carried over to the new op, so none may exist.
184 assert(op->getNumRegions() == 0);
// Use the caller's rewriter when given; otherwise fall back to a local one.
189 TrivialPatternRewriter localRewriter(op.getContext());
190 PatternRewriter &rw = rewriter ? *rewriter : localRewriter;
// Start from the op's current result types and attributes.
193 SmallVector<Type, 8> resultTypes(op->getResultTypes());
194 SmallVector<NamedAttribute, 16> attributes(op->getAttrs());
// Becoming forceable: append the reference result type. `refType` is
// computed on lines missing from this excerpt.
199 resultTypes.push_back(refType);
// Dropping forceability: the reference result must be the last result, and
// it is removed.
201 assert(resultTypes.back() == refType &&
202 "expected forceable type as last result");
203 resultTypes.pop_back();
// Add or remove the unit-valued forceable marker attribute to match.
207 auto forceableMarker =
208 rw.getNamedAttr(op.getForceableAttrName(), rw.getUnitAttr());
210 attributes.push_back(forceableMarker);
212 llvm::erase(attributes, forceableMarker);
// The erase must actually have removed the marker.
213 assert(attributes.size() != op->getAttrs().size());
// Build the replacement at op's position, with the same name, operands and
// successors.
217 OperationState state(op.getLoc(), op->getName(), op->getOperands(),
218 resultTypes, attributes, op->getSuccessors());
219 rw.setInsertionPoint(op);
220 auto *replace = rw.create(state);
// When dropping forceability, the old reference result must have no uses,
// since it has no counterpart on the new op.
223 assert(forceable || op.getDataRef().use_empty());
// Redirect uses of every old result to the matching new result. The trailing
// reference result is skipped when forceability is dropped.
// NOTE(review): this excerpt does not show the old op being erased (line 228
// is missing) -- confirm.
226 for (
auto result : llvm::drop_end(op->getResults(), forceable ? 0 : 1))
227 rw.replaceAllUsesWith(result, replace->getResult(result.getResultNumber()));
229 return cast<Forceable>(replace);
232 #include "circt/Dialect/FIRRTL/FIRRTLOpInterfaces.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyNoInputProbes(FModuleLike module)
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
LogicalResult verifyForceableOp(Forceable op)
Verify a Forceable op.
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
RefType getForceableResultType(bool forceable, Type type)
Return null or forceable reference result type.
LogicalResult verifyModuleLikeOpInterface(FModuleLike module)
Verification hook for verifying module like operations.