11 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/DialectConversion.h"
27 #define GEN_PASS_DEF_KANAGAWACALLPREP
28 #include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
32 using namespace circt;
33 using namespace kanagawa;
43 DenseMap<hw::InnerRefAttr, std::pair<hw::StructType, Location>>
argTypes;
47 DenseMap<std::pair<ClassOp, StringAttr>, ClassOp>
instanceMap;
50 ClassOp
lookupNext(ClassOp scope, StringAttr instSym)
const {
51 auto entry = instanceMap.find(std::make_pair(scope, instSym));
52 if (entry == instanceMap.end())
58 static hw::InnerRefAttr
getSymbol(MethodOp method) {
59 auto design = method->getParentOfType<DesignOp>();
61 method.getInnerSym().getSymName());
66 auto *
ctxt = design.getContext();
67 StringAttr modName = design.getNameAttr();
70 for (
auto cls : design.getOps<ClassOp>())
72 modName, cls.getInnerSymAttr().getSymName())] = cls;
74 for (
auto cls : design.getOps<ClassOp>()) {
76 for (
auto method : cls.getOps<MethodOp>()) {
79 SmallVector<hw::StructType::FieldInfo> argFields;
80 for (
auto [argName, argType] :
81 llvm::zip(method.getArgNamesAttr().getAsRange<StringAttr>(),
82 cast<MethodLikeOpInterface>(method.getOperation())
84 argFields.push_back({argName, argType});
90 if (method->getNumRegions() > 0) {
91 SmallVector<Location> argLocs;
92 Block *body = &method.getBody().front();
93 for (
auto arg : body->getArguments())
94 argLocs.push_back(arg.getLoc());
100 std::make_pair(getSymbol(method), std::make_pair(argStruct, argLoc)));
104 for (
auto inst : cls.getOps<InstanceOp>()) {
105 auto clsEntry = classSymbols.find(inst.getTargetNameAttr());
106 assert(clsEntry != classSymbols.end() &&
107 "class being instantiated doesn't exist");
108 instanceMap[std::make_pair(cls, inst.getInnerSym().getSymName())] =
121 void rewrite(CallOp, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter)
const final;
123 LogicalResult match(CallOp)
const override {
return success(); }
130 void MergeCallArgs::rewrite(CallOp call, OpAdaptor adaptor,
131 ConversionPatternRewriter &rewriter)
const {
132 auto loc = call.getLoc();
133 rewriter.setInsertionPoint(call);
134 auto method = call->getParentOfType<kanagawa::MethodLikeOpInterface>();
137 auto argStructEntry = info.
argTypes.find(call.getCalleeAttr());
138 assert(argStructEntry != info.argTypes.end() &&
"Method symref not found!");
139 auto [argStruct, argLoc] = argStructEntry->second;
143 adaptor.getOperands());
144 newArg->setAttr(
"sv.namehint",
145 rewriter.getStringAttr(call.getCallee().getName().getValue() +
146 "_args_called_from_" +
147 method.getMethodName().getValue()));
150 rewriter.modifyOpInPlace(call, [&]() {
151 call.getOperandsMutable().clear();
152 call.getOperandsMutable().append(newArg.getResult());
163 void rewrite(MethodOp, OpAdaptor adaptor,
164 ConversionPatternRewriter &rewriter)
const final;
165 LogicalResult match(MethodOp)
const override {
return success(); }
172 void MergeMethodArgs::rewrite(MethodOp func, OpAdaptor adaptor,
173 ConversionPatternRewriter &rewriter)
const {
174 auto loc = func.getLoc();
175 auto *
ctxt = getContext();
178 auto argStructEntry =
180 assert(argStructEntry != info.argTypes.end() &&
"Cannot find symref!");
181 auto [argStruct, argLoc] = argStructEntry->second;
184 FunctionType funcType = func.getFunctionType();
185 FunctionType newFuncType =
189 rewriter.create<MethodOp>(loc, func.getInnerSym(), newFuncType,
190 newArgNames, ArrayAttr(), ArrayAttr());
192 if (func->getNumRegions() > 0) {
196 rewriter.createBlock(&newMethod.getRegion(), {}, {argStruct}, {argLoc});
197 rewriter.setInsertionPointToStart(b);
198 auto replacementArgs =
199 rewriter.create<hw::StructExplodeOp>(loc, b->getArgument(0));
202 Block *funcBody = &func.getBody().front();
203 rewriter.mergeBlocks(funcBody, b, replacementArgs.getResults());
206 rewriter.eraseOp(func);
212 :
public circt::kanagawa::impl::KanagawaCallPrepBase<CallPrepPass> {
213 void runOnOperation()
override;
221 void CallPrepPass::runOnOperation() {
224 if (failed(merge(info))) {
232 ConversionTarget target(getContext());
233 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
234 target.addDynamicallyLegalOp<CallOp>([](CallOp call) {
235 auto argValues = call.getArgOperands();
236 return argValues.size() == 1 &&
237 hw::type_isa<hw::StructType>(argValues.front().getType());
239 target.addDynamicallyLegalOp<MethodOp>([](MethodOp func) {
240 ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
241 return argTypes.size() == 1 &&
242 hw::type_isa<hw::StructType>(argTypes.front());
246 RewritePatternSet
patterns(&getContext());
247 patterns.insert<MergeCallArgs>(&getContext(), info);
248 patterns.insert<MergeMethodArgs>(&getContext(), info);
250 return applyPartialConversion(getOperation(), target, std::move(
patterns));
254 return std::make_unique<CallPrepPass>();
assert(baseType &&"element must be base type")
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createCallPrepPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Build indexes to make lookups faster. Create the new argument types as well.
ClassOp lookupNext(ClassOp scope, StringAttr instSym) const
DenseMap< std::pair< ClassOp, StringAttr >, ClassOp > instanceMap
DenseMap< hw::InnerRefAttr, std::pair< hw::StructType, Location > > argTypes
CallPrepPrecomputed(DesignOp)
DenseMap< hw::InnerRefAttr, ClassOp > classSymbols
static hw::InnerRefAttr getSymbol(MethodOp method)