CIRCT 21.0.0git
Loading...
Searching...
No Matches
KanagawaCallPrep.cpp
Go to the documentation of this file.
1//===- KanagawaCallPrep.cpp - Implementation of call prep lowering --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/Pass/Pass.h"
12
17
22
23#include "mlir/Transforms/DialectConversion.h"
24
25namespace circt {
26namespace kanagawa {
27#define GEN_PASS_DEF_KANAGAWACALLPREP
28#include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
29} // namespace kanagawa
30} // namespace circt
31
32using namespace circt;
33using namespace kanagawa;
34
35/// Build indexes to make lookups faster. Create the new argument types as well.
37 CallPrepPrecomputed(DesignOp);
38
39 // Lookup a class from its symbol.
40 DenseMap<hw::InnerRefAttr, ClassOp> classSymbols;
41
42 // Mapping of method to argument type.
43 DenseMap<hw::InnerRefAttr, std::pair<hw::StructType, Location>> argTypes;
44
45 // Lookup the class to which a particular instance (in a particular class) is
46 // referring.
47 DenseMap<std::pair<ClassOp, StringAttr>, ClassOp> instanceMap;
48
49 // Lookup an entry in instanceMap. If not found, return null.
50 ClassOp lookupNext(ClassOp scope, StringAttr instSym) const {
51 auto entry = instanceMap.find(std::make_pair(scope, instSym));
52 if (entry == instanceMap.end())
53 return {};
54 return entry->second;
55 }
56
57 // Utility function to create a hw::InnerRefAttr to a method.
58 static hw::InnerRefAttr getSymbol(MethodOp method) {
59 auto design = method->getParentOfType<DesignOp>();
60 return hw::InnerRefAttr::get(design.getSymNameAttr(),
61 method.getInnerSym().getSymName());
62 }
63};
64
66 auto *ctxt = design.getContext();
67 StringAttr modName = design.getNameAttr();
68
69 // Populate the class-symbol lookup table.
70 for (auto cls : design.getOps<ClassOp>())
71 classSymbols[hw::InnerRefAttr::get(
72 modName, cls.getInnerSymAttr().getSymName())] = cls;
73
74 for (auto cls : design.getOps<ClassOp>()) {
75 // Compute new argument types for each method.
76 for (auto method : cls.getOps<MethodOp>()) {
77
78 // Create the struct type.
79 SmallVector<hw::StructType::FieldInfo> argFields;
80 for (auto [argName, argType] :
81 llvm::zip(method.getArgNamesAttr().getAsRange<StringAttr>(),
82 cast<MethodLikeOpInterface>(method.getOperation())
83 .getArgumentTypes()))
84 argFields.push_back({argName, argType});
85 auto argStruct = hw::StructType::get(ctxt, argFields);
86
87 // Later we're gonna want the block locations, so compute a fused location
88 // and store it.
89 Location argLoc = UnknownLoc::get(ctxt);
90 if (method->getNumRegions() > 0) {
91 SmallVector<Location> argLocs;
92 Block *body = &method.getBody().front();
93 for (auto arg : body->getArguments())
94 argLocs.push_back(arg.getLoc());
95 argLoc = FusedLoc::get(ctxt, argLocs);
96 }
97
98 // Add both to the lookup table.
99 argTypes.insert(
100 std::make_pair(getSymbol(method), std::make_pair(argStruct, argLoc)));
101 }
102
103 // Populate the instances table.
104 for (auto inst : cls.getOps<InstanceOp>()) {
105 auto clsEntry = classSymbols.find(inst.getTargetNameAttr());
106 assert(clsEntry != classSymbols.end() &&
107 "class being instantiated doesn't exist");
108 instanceMap[std::make_pair(cls, inst.getInnerSym().getSymName())] =
109 clsEntry->second;
110 }
111 }
112}
113
114namespace {
115/// For each CallOp, the corresponding method signature will have changed. Pack
116/// all the operands into a struct.
117struct MergeCallArgs : public OpConversionPattern<CallOp> {
118 MergeCallArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
120
121 LogicalResult
122 matchAndRewrite(CallOp, OpAdaptor adaptor,
123 ConversionPatternRewriter &rewriter) const override final;
124
125private:
127};
128} // anonymous namespace
129
130LogicalResult
131MergeCallArgs::matchAndRewrite(CallOp call, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter) const {
133 auto loc = call.getLoc();
134 rewriter.setInsertionPoint(call);
135 auto method = call->getParentOfType<kanagawa::MethodLikeOpInterface>();
136
137 // Use the 'info' accelerator structures to find the argument type.
138 auto argStructEntry = info.argTypes.find(call.getCalleeAttr());
139 assert(argStructEntry != info.argTypes.end() && "Method symref not found!");
140 auto [argStruct, argLoc] = argStructEntry->second;
141
142 // Pack all of the operands into it.
143 auto newArg = rewriter.create<hw::StructCreateOp>(loc, argStruct,
144 adaptor.getOperands());
145 newArg->setAttr("sv.namehint",
146 rewriter.getStringAttr(call.getCallee().getName().getValue() +
147 "_args_called_from_" +
148 method.getMethodName().getValue()));
149
150 // Update the call to use just the new struct.
151 rewriter.modifyOpInPlace(call, [&]() {
152 call.getOperandsMutable().clear();
153 call.getOperandsMutable().append(newArg.getResult());
154 });
155
156 return success();
157}
158
159namespace {
160/// Change the method signatures to only have one argument: a struct capturing
161/// all of the original arguments.
162struct MergeMethodArgs : public OpConversionPattern<MethodOp> {
163 MergeMethodArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
165
166 LogicalResult
167 matchAndRewrite(MethodOp, OpAdaptor adaptor,
168 ConversionPatternRewriter &rewriter) const final override;
169
170private:
172};
173} // anonymous namespace
174
175LogicalResult
176MergeMethodArgs::matchAndRewrite(MethodOp func, OpAdaptor adaptor,
177 ConversionPatternRewriter &rewriter) const {
178 auto loc = func.getLoc();
179 auto *ctxt = getContext();
180
181 // Find the pre-computed arg struct for this method.
182 auto argStructEntry =
183 info.argTypes.find(CallPrepPrecomputed::getSymbol(func));
184 assert(argStructEntry != info.argTypes.end() && "Cannot find symref!");
185 auto [argStruct, argLoc] = argStructEntry->second;
186
187 // Create a new method with the new signature.
188 FunctionType funcType = func.getFunctionType();
189 FunctionType newFuncType =
190 FunctionType::get(ctxt, {argStruct}, funcType.getResults());
191 auto newArgNames = ArrayAttr::get(ctxt, {StringAttr::get(ctxt, "arg")});
192 auto newMethod =
193 rewriter.create<MethodOp>(loc, func.getInnerSym(), newFuncType,
194 newArgNames, ArrayAttr(), ArrayAttr());
195
196 if (func->getNumRegions() > 0) {
197 // Create a body block with a struct explode to the arg struct into the
198 // original arguments.
199 Block *b =
200 rewriter.createBlock(&newMethod.getRegion(), {}, {argStruct}, {argLoc});
201 rewriter.setInsertionPointToStart(b);
202 auto replacementArgs =
203 rewriter.create<hw::StructExplodeOp>(loc, b->getArgument(0));
204
205 // Merge the original method body, rewiring the args.
206 Block *funcBody = &func.getBody().front();
207 rewriter.mergeBlocks(funcBody, b, replacementArgs.getResults());
208 }
209
210 rewriter.eraseOp(func);
211
212 return success();
213}
214
215namespace {
216/// Run all the physical lowerings.
217struct CallPrepPass
218 : public circt::kanagawa::impl::KanagawaCallPrepBase<CallPrepPass> {
219 void runOnOperation() override;
220
221private:
222 // Merge the arguments into one struct.
223 LogicalResult merge(const CallPrepPrecomputed &);
224};
225} // anonymous namespace
226
227void CallPrepPass::runOnOperation() {
228 CallPrepPrecomputed info(getOperation());
229
230 if (failed(merge(info))) {
231 signalPassFailure();
232 return;
233 }
234}
235
236LogicalResult CallPrepPass::merge(const CallPrepPrecomputed &info) {
237 // Set up a conversion and give it a set of laws.
238 ConversionTarget target(getContext());
239 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
240 target.addDynamicallyLegalOp<CallOp>([](CallOp call) {
241 auto argValues = call.getArgOperands();
242 return argValues.size() == 1 &&
243 hw::type_isa<hw::StructType>(argValues.front().getType());
244 });
245 target.addDynamicallyLegalOp<MethodOp>([](MethodOp func) {
246 ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
247 return argTypes.size() == 1 &&
248 hw::type_isa<hw::StructType>(argTypes.front());
249 });
250
251 // Add patterns to merge the args on both the call and method sides.
252 RewritePatternSet patterns(&getContext());
253 patterns.insert<MergeCallArgs>(&getContext(), info);
254 patterns.insert<MergeMethodArgs>(&getContext(), info);
255
256 return applyPartialConversion(getOperation(), target, std::move(patterns));
257}
258
259std::unique_ptr<Pass> circt::kanagawa::createCallPrepPass() {
260 return std::make_unique<CallPrepPass>();
261}
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createCallPrepPass()
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Build indexes to make lookups faster. Create the new argument types as well.
ClassOp lookupNext(ClassOp scope, StringAttr instSym) const
DenseMap< std::pair< ClassOp, StringAttr >, ClassOp > instanceMap
DenseMap< hw::InnerRefAttr, std::pair< hw::StructType, Location > > argTypes
DenseMap< hw::InnerRefAttr, ClassOp > classSymbols
static hw::InnerRefAttr getSymbol(MethodOp method)