15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "llvm/ADT/SmallBitVector.h"
19 #include "llvm/ADT/StringRef.h"
27 std::function<bool(Type,
bool)> hasInputRef = [&](Type type,
28 bool output) ->
bool {
29 auto ftype = type_dyn_cast<FIRRTLType>(type);
30 if (!ftype || !ftype.containsReference())
33 .
Case<RefType>([&](
auto reftype) {
return !output; })
34 .Case<OpenVectorType>([&](OpenVectorType ovt) {
35 return hasInputRef(ovt.getElementType(), output);
37 .Case<OpenBundleType>([&](OpenBundleType obt) {
38 for (
auto field : obt.getElements())
39 if (hasInputRef(field.type, field.isFlip ^ output))
45 for (
auto &pi : module.getPorts())
46 if (hasInputRef(pi.type, pi.isOutput()))
47 return emitError(pi.loc,
"input probe not allowed");
54 auto portTypes = module.getPortTypesAttr();
55 if (!portTypes || llvm::any_of(portTypes.getValue(), [](Attribute attr) {
56 return !isa<TypeAttr>(attr);
58 return module.emitOpError(
"requires valid port types");
60 auto numPorts = portTypes.size();
63 auto portDirections = module.getPortDirectionsAttr();
65 return module.emitOpError(
"requires valid port direction");
68 auto bitWidth = portDirections.size();
69 if (
static_cast<size_t>(bitWidth) != numPorts)
70 return module.emitOpError(
"requires ") << numPorts <<
" port directions";
73 auto portNames = module.getPortNamesAttr();
75 return module.emitOpError(
"requires valid port names");
76 if (portNames.size() != numPorts)
77 return module.emitOpError(
"requires ") << numPorts <<
" port names";
78 if (llvm::any_of(portNames.getValue(),
79 [](Attribute attr) { return !isa<StringAttr>(attr); }))
80 return module.emitOpError(
"port names should all be string attributes");
83 auto portAnnotations = module.getPortAnnotationsAttr();
85 return module.emitOpError(
"requires valid port annotations");
87 if (!portAnnotations.empty() && portAnnotations.size() != numPorts)
88 return module.emitOpError(
"requires ") << numPorts <<
" port annotations";
90 for (
auto annos : portAnnotations.getValue()) {
91 auto arrayAttr = dyn_cast<ArrayAttr>(annos);
93 return module.emitOpError(
94 "requires port annotations be array attributes");
95 if (llvm::any_of(arrayAttr.getValue(),
96 [](Attribute attr) { return !isa<DictionaryAttr>(attr); }))
97 return module.emitOpError(
98 "annotations must be dictionaries or subannotations");
102 auto portSymbols = module.getPortSymbolsAttr();
104 return module.emitOpError(
"requires valid port symbols");
105 if (!portSymbols.empty() && portSymbols.size() != numPorts)
106 return module.emitOpError(
"requires ") << numPorts <<
" port symbols";
107 if (llvm::any_of(portSymbols.getValue(), [](Attribute attr) {
108 return !attr || !isa<hw::InnerSymAttr>(attr);
110 return module.emitOpError(
"port symbols should all be InnerSym attributes");
113 auto portLocs = module.getPortLocationsAttr();
115 return module.emitOpError(
"requires valid port locations");
116 if (portLocs.size() != numPorts)
117 return module.emitOpError(
"requires ") << numPorts <<
" port locations";
118 if (llvm::any_of(portLocs.getValue(), [](Attribute attr) {
119 return !attr || !isa<LocationAttr>(attr);
121 return module.emitOpError(
"port symbols should all be location attributes");
124 if (module->getNumRegions() != 1)
125 return module.emitOpError(
"requires one region");
139 auto base = dyn_cast_or_null<FIRRTLBaseType>(type);
141 if (!forceable || !base || base.containsConst())
147 bool forceable = op.isForceable();
148 auto ref = op.getDataRef();
149 if ((
bool)ref != forceable)
150 return op.emitOpError(
"must have ref result iff marked forceable");
153 auto data = op.getDataRaw();
154 auto baseType = type_dyn_cast<FIRRTLBaseType>(
data.getType());
156 return op.emitOpError(
"has data that is not a base type");
157 if (baseType.containsConst())
158 return op.emitOpError(
"cannot force a declaration of constant type");
160 if (ref.getType() != expectedRefType)
161 return op.emitOpError(
"reference result of incorrect type, found ")
162 << ref.getType() <<
", expected " << expectedRefType;
168 class TrivialPatternRewriter :
public PatternRewriter {
170 explicit TrivialPatternRewriter(MLIRContext *context)
171 : PatternRewriter(context) {}
177 PatternRewriter *rewriter) {
178 if (forceable == op.isForceable())
181 assert(op->getNumRegions() == 0);
186 TrivialPatternRewriter localRewriter(op.getContext());
187 PatternRewriter &rw = rewriter ? *rewriter : localRewriter;
190 SmallVector<Type, 8> resultTypes(op->getResultTypes());
191 SmallVector<NamedAttribute, 16> attributes(op->getAttrs());
196 resultTypes.push_back(refType);
198 assert(resultTypes.back() == refType &&
199 "expected forceable type as last result");
200 resultTypes.pop_back();
204 auto forceableMarker =
205 rw.getNamedAttr(op.getForceableAttrName(), rw.getUnitAttr());
207 attributes.push_back(forceableMarker);
209 llvm::erase(attributes, forceableMarker);
210 assert(attributes.size() != op->getAttrs().size());
214 OperationState state(op.getLoc(), op->getName(), op->getOperands(),
215 resultTypes, attributes, op->getSuccessors());
216 rw.setInsertionPoint(op);
217 auto *replace = rw.create(state);
220 assert(forceable || op.getDataRef().use_empty());
223 for (
auto result : llvm::drop_end(op->getResults(), forceable ? 0 : 1))
224 rw.replaceAllUsesWith(result, replace->getResult(result.getResultNumber()));
226 return cast<Forceable>(replace);
229 #include "circt/Dialect/FIRRTL/FIRRTLOpInterfaces.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyNoInputProbes(FModuleLike module)
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
LogicalResult verifyForceableOp(Forceable op)
Verify a Forceable op.
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
RefType getForceableResultType(bool forceable, Type type)
Return null or forceable reference result type.
LogicalResult verifyModuleLikeOpInterface(FModuleLike module)
Verification hook for verifying module like operations.