15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "llvm/ADT/SmallBitVector.h"
19 #include "llvm/ADT/StringRef.h"
28 auto portTypes = module.getPortTypesAttr();
29 if (!portTypes || llvm::any_of(portTypes.getValue(), [](Attribute attr) {
30 return !isa<TypeAttr>(attr);
32 return module.emitOpError(
"requires valid port types");
34 auto numPorts = portTypes.size();
37 auto portDirections = module.getPortDirectionsAttr();
39 return module.emitOpError(
"requires valid port direction");
42 auto bitWidth = portDirections.getValue().getBitWidth();
43 if (bitWidth != numPorts)
44 return module.emitOpError(
"requires ") << numPorts <<
" port directions";
47 auto portNames = module.getPortNamesAttr();
49 return module.emitOpError(
"requires valid port names");
50 if (portNames.size() != numPorts)
51 return module.emitOpError(
"requires ") << numPorts <<
" port names";
52 if (llvm::any_of(portNames.getValue(),
53 [](Attribute attr) { return !isa<StringAttr>(attr); }))
54 return module.emitOpError(
"port names should all be string attributes");
57 auto portAnnotations = module.getPortAnnotationsAttr();
59 return module.emitOpError(
"requires valid port annotations");
61 if (!portAnnotations.empty() && portAnnotations.size() != numPorts)
62 return module.emitOpError(
"requires ") << numPorts <<
" port annotations";
64 for (
auto annos : portAnnotations.getValue()) {
65 auto arrayAttr = dyn_cast<ArrayAttr>(annos);
67 return module.emitOpError(
68 "requires port annotations be array attributes");
69 if (llvm::any_of(arrayAttr.getValue(),
70 [](Attribute attr) { return !isa<DictionaryAttr>(attr); }))
71 return module.emitOpError(
72 "annotations must be dictionaries or subannotations");
76 auto portSymbols = module.getPortSymbolsAttr();
78 return module.emitOpError(
"requires valid port symbols");
79 if (!portSymbols.empty() && portSymbols.size() != numPorts)
80 return module.emitOpError(
"requires ") << numPorts <<
" port symbols";
81 if (llvm::any_of(portSymbols.getValue(), [](Attribute attr) {
82 return !attr || !isa<hw::InnerSymAttr>(attr);
84 return module.emitOpError(
"port symbols should all be InnerSym attributes");
87 auto portLocs = module.getPortLocationsAttr();
89 return module.emitOpError(
"requires valid port locations");
90 if (portLocs.size() != numPorts)
91 return module.emitOpError(
"requires ") << numPorts <<
" port locations";
92 if (llvm::any_of(portLocs.getValue(), [](Attribute attr) {
93 return !attr || !isa<LocationAttr>(attr);
95 return module.emitOpError(
"port symbols should all be location attributes");
98 if (module->getNumRegions() != 1)
99 return module.emitOpError(
"requires one region");
110 auto base = dyn_cast_or_null<FIRRTLBaseType>(type);
112 if (!forceable || !base || base.containsConst())
118 bool forceable = op.isForceable();
119 auto ref = op.getDataRef();
120 if ((
bool)ref != forceable)
121 return op.emitOpError(
"must have ref result iff marked forceable");
124 auto data = op.getDataRaw();
125 auto baseType = type_dyn_cast<FIRRTLBaseType>(
data.getType());
127 return op.emitOpError(
"has data that is not a base type");
128 if (baseType.containsConst())
129 return op.emitOpError(
"cannot force a declaration of constant type");
131 if (ref.getType() != expectedRefType)
132 return op.emitOpError(
"reference result of incorrect type, found ")
133 << ref.getType() <<
", expected " << expectedRefType;
139 class TrivialPatternRewriter :
public PatternRewriter {
141 explicit TrivialPatternRewriter(MLIRContext *context)
142 : PatternRewriter(context) {}
148 PatternRewriter *rewriter) {
149 if (forceable == op.isForceable())
152 assert(op->getNumRegions() == 0);
157 TrivialPatternRewriter localRewriter(op.getContext());
158 PatternRewriter &rw = rewriter ? *rewriter : localRewriter;
161 SmallVector<Type, 8> resultTypes(op->getResultTypes());
162 SmallVector<NamedAttribute, 16> attributes(op->getAttrs());
167 resultTypes.push_back(refType);
169 assert(resultTypes.back() == refType &&
170 "expected forceable type as last result");
171 resultTypes.pop_back();
175 auto forceableMarker =
176 rw.getNamedAttr(op.getForceableAttrName(), rw.getUnitAttr());
178 attributes.push_back(forceableMarker);
180 llvm::erase(attributes, forceableMarker);
181 assert(attributes.size() != op->getAttrs().size());
185 OperationState state(op.getLoc(), op->getName(), op->getOperands(),
186 resultTypes, attributes, op->getSuccessors());
187 rw.setInsertionPoint(op);
188 auto *replace = rw.create(state);
191 assert(forceable || op.getDataRef().use_empty());
194 for (
auto result : llvm::drop_end(op->getResults(), forceable ? 0 : 1))
195 rw.replaceAllUsesWith(result, replace->getResult(result.getResultNumber()));
197 return cast<Forceable>(replace);
200 #include "circt/Dialect/FIRRTL/FIRRTLOpInterfaces.cpp.inc"
assert(baseType &&"element must be base type")
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
LogicalResult verifyForceableOp(Forceable op)
Verify a Forceable op.
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
RefType getForceableResultType(bool forceable, Type type)
Return null or forceable reference result type.
LogicalResult verifyModuleLikeOpInterface(FModuleLike module)
Verification hook for verifying module like operations.