11#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/DialectConversion.h"
27#define GEN_PASS_DEF_KANAGAWACALLPREP
28#include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
33using namespace kanagawa;
// Lookup table from a method's inner-symbol reference to a pair of
// (argument struct type, location). Both rewrite patterns in this file query
// it through `info.argTypes.find(...)`.
43 DenseMap<hw::InnerRefAttr, std::pair<hw::StructType, Location>>
argTypes;
// Lookup table keyed by (class, instance inner-symbol name) yielding a ClassOp.
// Populated per InstanceOp by the constructor and read by lookupNext.
47 DenseMap<std::pair<ClassOp, StringAttr>, ClassOp>
instanceMap;
// Searches instanceMap for the (scope, instSym) pair.
// NOTE(review): the rest of this function (how the find() result is turned
// into the returned ClassOp, including the not-found case) is not visible in
// this listing — confirm before relying on the miss behavior.
50 ClassOp
lookupNext(ClassOp scope, StringAttr instSym)
const {
51 auto entry =
instanceMap.find(std::make_pair(scope, instSym));
// Builds the hw::InnerRefAttr that names `method`: the symbol name of the
// DesignOp found by walking up from the method, paired with the method's own
// inner-symbol name.
// NOTE(review): `design` is used without a null check — assumes every
// MethodOp passed here is nested inside a DesignOp; confirm.
58 static hw::InnerRefAttr
getSymbol(MethodOp method) {
59 auto design = method->getParentOfType<DesignOp>();
60 return hw::InnerRefAttr::get(design.getSymNameAttr(),
61 method.getInnerSym().getSymName());
// Index-building code operating on `design`.
// NOTE(review): the signature line is missing from this listing — presumably
// the body of CallPrepPrecomputed(DesignOp); confirm.
66 auto *ctxt = design.getContext();
67 StringAttr modName = design.getNameAttr();
// Pass 1: register every ClassOp of the design under a key formed from
// modName and the class's inner-symbol name.
// NOTE(review): the container on the left-hand side is on a missing line —
// presumably classSymbols, which is read further down; confirm.
70 for (
auto cls : design.getOps<ClassOp>())
72 modName, cls.getInnerSymAttr().getSymName())] = cls;
// Pass 2: for each class, process its methods and then its instances.
74 for (
auto cls : design.getOps<ClassOp>()) {
76 for (
auto method : cls.getOps<MethodOp>()) {
// Collect one struct field per argument by zipping the method's argument
// names with a second range obtained through MethodLikeOpInterface.
// NOTE(review): the call producing that second range is on a missing line —
// presumably the argument types; confirm.
79 SmallVector<hw::StructType::FieldInfo> argFields;
80 for (
auto [argName, argType] :
81 llvm::zip(method.getArgNamesAttr().getAsRange<StringAttr>(),
82 cast<MethodLikeOpInterface>(method.getOperation())
84 argFields.push_back({argName, argType});
85 auto argStruct = hw::StructType::get(ctxt, argFields);
// Location to use for the merged argument: unknown by default; when the
// method has at least one region, the fusion of the locations of the
// arguments of the body's first block.
89 Location argLoc = UnknownLoc::get(ctxt);
90 if (method->getNumRegions() > 0) {
91 SmallVector<Location> argLocs;
92 Block *body = &method.getBody().front();
93 for (
auto arg : body->getArguments())
94 argLocs.push_back(arg.getLoc());
95 argLoc = FusedLoc::get(ctxt, argLocs);
// Pair the struct type and location under the method's symbol reference.
// NOTE(review): the call receiving this pair is on a missing line —
// presumably an insert into argTypes; confirm.
100 std::make_pair(
getSymbol(method), std::make_pair(argStruct, argLoc)));
// For each InstanceOp in this class, look up its target name in classSymbols
// and write an instanceMap entry keyed by (class, instance inner-symbol name).
// A missing target is only diagnosed by the assert whose message is visible
// below.
104 for (
auto inst : cls.getOps<InstanceOp>()) {
105 auto clsEntry =
classSymbols.find(inst.getTargetNameAttr());
107 "class being instantiated doesn't exist");
108 instanceMap[std::make_pair(cls, inst.getInnerSym().getSymName())] =
// Conversion hook for CallOp; the definition appears out-of-line as
// MergeCallArgs::matchAndRewrite. (Return type is on a line not shown here.)
122 matchAndRewrite(CallOp, OpAdaptor adaptor,
123 ConversionPatternRewriter &rewriter)
const override final;
// Rewrites `call` in place so that its operand list is replaced by the single
// value `newArg`, which is built from the adaptor's operands.
131MergeCallArgs::matchAndRewrite(CallOp call, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter)
const {
133 auto loc = call.getLoc();
134 rewriter.setInsertionPoint(call);
// Enclosing method-like op; its name feeds the name hint below.
135 auto method = call->getParentOfType<kanagawa::MethodLikeOpInterface>();
// Fetch the precomputed (struct type, location) entry for the callee. A
// missing entry is only caught by the assert.
138 auto argStructEntry =
info.argTypes.find(call.getCalleeAttr());
139 assert(argStructEntry !=
info.argTypes.end() &&
"Method symref not found!");
140 auto [argStruct, argLoc] = argStructEntry->second;
// NOTE(review): the statement that creates `newArg` is mostly on lines
// missing from this listing; only its trailing operand list is visible —
// presumably a struct-create op of type argStruct; confirm.
144 adaptor.getOperands());
// Attach a name hint of the form
// "<callee name>_args_called_from_<enclosing method name>".
145 newArg->setAttr(
"sv.namehint",
146 rewriter.getStringAttr(call.getCallee().getName().getValue() +
147 "_args_called_from_" +
148 method.getMethodName().getValue()));
// Swap the call's operands for the single packed value, notifying the
// rewriter of the in-place modification.
151 rewriter.modifyOpInPlace(call, [&]() {
152 call.getOperandsMutable().clear();
153 call.getOperandsMutable().append(newArg.getResult());
// Conversion hook for MethodOp; the definition appears out-of-line as
// MergeMethodArgs::matchAndRewrite. (Return type is on a line not shown here.)
167 matchAndRewrite(MethodOp, OpAdaptor adaptor,
168 ConversionPatternRewriter &rewriter)
const final override;
// Replaces `func` with a new MethodOp that takes one struct-typed argument.
// When `func` has a body, the struct is exploded at the top of the new body
// and the old body block is merged in with the exploded values standing in
// for its original block arguments.
176MergeMethodArgs::matchAndRewrite(MethodOp func, OpAdaptor adaptor,
177 ConversionPatternRewriter &rewriter)
const {
178 auto loc = func.getLoc();
179 auto *
ctxt = getContext();
// Fetch the precomputed (struct type, location) entry; a missing entry is
// only caught by the assert.
// NOTE(review): the lookup expression itself is on a missing line.
182 auto argStructEntry =
184 assert(argStructEntry !=
info.argTypes.end() &&
"Cannot find symref!");
185 auto [argStruct, argLoc] = argStructEntry->second;
// New signature: the struct as the sole input, the original result types, and
// a single argument name "arg".
188 FunctionType funcType = func.getFunctionType();
189 FunctionType newFuncType =
190 FunctionType::get(ctxt, {argStruct}, funcType.getResults());
191 auto newArgNames = ArrayAttr::get(ctxt, {StringAttr::get(ctxt,
"arg")});
// Create the replacement op, reusing the original inner symbol.
// NOTE(review): the variable receiving this op is declared on a missing
// line — presumably `newMethod`, which is used below; confirm.
193 rewriter.create<MethodOp>(loc, func.getInnerSym(), newFuncType,
194 newArgNames, ArrayAttr(), ArrayAttr());
// Body transfer happens only when `func` has at least one region.
196 if (func->getNumRegions() > 0) {
// New block with one argument of type argStruct located at argLoc.
// NOTE(review): the block is presumably bound to `b` on a missing line.
200 rewriter.createBlock(&newMethod.getRegion(), {}, {argStruct}, {argLoc});
201 rewriter.setInsertionPointToStart(b);
// Explode the struct argument into individual values.
202 auto replacementArgs =
203 rewriter.create<hw::StructExplodeOp>(loc, b->getArgument(0));
// Merge the old body's first block into `b`, substituting the exploded
// values for the old block arguments.
206 Block *funcBody = &func.getBody().front();
207 rewriter.mergeBlocks(funcBody, b, replacementArgs.getResults());
// Remove the original method op.
210 rewriter.eraseOp(func);
// Base-class list and runOnOperation override of the pass class (its
// `struct`/`class` header line is not visible in this listing). The base is
// the generated KanagawaCallPrepBase specialized on CallPrepPass.
218 :
public circt::kanagawa::impl::KanagawaCallPrepBase<CallPrepPass> {
219 void runOnOperation()
override;
// Pass entry point: invokes merge(info) and branches on failure.
// NOTE(review): the construction of `info` and the body of the failure branch
// are on lines missing from this listing.
227void CallPrepPass::runOnOperation() {
230 if (failed(merge(info))) {
// Legality rules and conversion driver.
// NOTE(review): the enclosing function header is missing from this listing.
// The value-returning `return` below suggests this is not runOnOperation
// (declared void) — presumably CallPrepPass::merge; confirm.
238 ConversionTarget target(getContext());
// Any op without a more specific rule is treated as legal.
239 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
// A CallOp is legal once it has exactly one argument operand and that operand
// is of hw struct type.
240 target.addDynamicallyLegalOp<CallOp>([](CallOp call) {
241 auto argValues = call.getArgOperands();
242 return argValues.size() == 1 &&
243 hw::type_isa<hw::StructType>(argValues.front().getType());
// A MethodOp is legal once its function type has exactly one input and that
// input is of hw struct type.
245 target.addDynamicallyLegalOp<MethodOp>([](MethodOp func) {
246 ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
247 return argTypes.size() == 1 &&
248 hw::type_isa<hw::StructType>(argTypes.front());
// NOTE(review): the lines that add patterns to this set are not visible here.
252 RewritePatternSet
patterns(&getContext());
// Run a partial conversion over the pass's operation with the rules above.
256 return applyPartialConversion(getOperation(), target, std::move(
patterns));
// Returns a newly allocated CallPrepPass.
// NOTE(review): the enclosing function header is missing from this listing —
// presumably the createCallPrepPass factory; confirm.
260 return std::make_unique<CallPrepPass>();
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createCallPrepPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Build indexes to make lookups faster. Create the new argument types as well.
ClassOp lookupNext(ClassOp scope, StringAttr instSym) const
DenseMap< std::pair< ClassOp, StringAttr >, ClassOp > instanceMap
DenseMap< hw::InnerRefAttr, std::pair< hw::StructType, Location > > argTypes
CallPrepPrecomputed(DesignOp)
DenseMap< hw::InnerRefAttr, ClassOp > classSymbols
static hw::InnerRefAttr getSymbol(MethodOp method)