CIRCT 20.0.0git
Loading...
Searching...
No Matches
KanagawaCallPrep.cpp
Go to the documentation of this file.
1//===- KanagawaCallPrep.cpp - Implementation of call prep lowering --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/Pass/Pass.h"
12
17
22
23#include "mlir/Transforms/DialectConversion.h"
24
25namespace circt {
26namespace kanagawa {
27#define GEN_PASS_DEF_KANAGAWACALLPREP
28#include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
29} // namespace kanagawa
30} // namespace circt
31
32using namespace circt;
33using namespace kanagawa;
34
35/// Build indexes to make lookups faster. Create the new argument types as well.
37 CallPrepPrecomputed(DesignOp);
38
39 // Lookup a class from its symbol.
40 DenseMap<hw::InnerRefAttr, ClassOp> classSymbols;
41
42 // Mapping of method to argument type.
43 DenseMap<hw::InnerRefAttr, std::pair<hw::StructType, Location>> argTypes;
44
45 // Lookup the class to which a particular instance (in a particular class) is
46 // referring.
47 DenseMap<std::pair<ClassOp, StringAttr>, ClassOp> instanceMap;
48
49 // Lookup an entry in instanceMap. If not found, return null.
50 ClassOp lookupNext(ClassOp scope, StringAttr instSym) const {
51 auto entry = instanceMap.find(std::make_pair(scope, instSym));
52 if (entry == instanceMap.end())
53 return {};
54 return entry->second;
55 }
56
57 // Utility function to create a hw::InnerRefAttr to a method.
58 static hw::InnerRefAttr getSymbol(MethodOp method) {
59 auto design = method->getParentOfType<DesignOp>();
60 return hw::InnerRefAttr::get(design.getSymNameAttr(),
61 method.getInnerSym().getSymName());
62 }
63};
64
66 auto *ctxt = design.getContext();
67 StringAttr modName = design.getNameAttr();
68
69 // Populate the class-symbol lookup table.
70 for (auto cls : design.getOps<ClassOp>())
71 classSymbols[hw::InnerRefAttr::get(
72 modName, cls.getInnerSymAttr().getSymName())] = cls;
73
74 for (auto cls : design.getOps<ClassOp>()) {
75 // Compute new argument types for each method.
76 for (auto method : cls.getOps<MethodOp>()) {
77
78 // Create the struct type.
79 SmallVector<hw::StructType::FieldInfo> argFields;
80 for (auto [argName, argType] :
81 llvm::zip(method.getArgNamesAttr().getAsRange<StringAttr>(),
82 cast<MethodLikeOpInterface>(method.getOperation())
83 .getArgumentTypes()))
84 argFields.push_back({argName, argType});
85 auto argStruct = hw::StructType::get(ctxt, argFields);
86
87 // Later we're gonna want the block locations, so compute a fused location
88 // and store it.
89 Location argLoc = UnknownLoc::get(ctxt);
90 if (method->getNumRegions() > 0) {
91 SmallVector<Location> argLocs;
92 Block *body = &method.getBody().front();
93 for (auto arg : body->getArguments())
94 argLocs.push_back(arg.getLoc());
95 argLoc = FusedLoc::get(ctxt, argLocs);
96 }
97
98 // Add both to the lookup table.
99 argTypes.insert(
100 std::make_pair(getSymbol(method), std::make_pair(argStruct, argLoc)));
101 }
102
103 // Populate the instances table.
104 for (auto inst : cls.getOps<InstanceOp>()) {
105 auto clsEntry = classSymbols.find(inst.getTargetNameAttr());
106 assert(clsEntry != classSymbols.end() &&
107 "class being instantiated doesn't exist");
108 instanceMap[std::make_pair(cls, inst.getInnerSym().getSymName())] =
109 clsEntry->second;
110 }
111 }
112}
113
114namespace {
115/// For each CallOp, the corresponding method signature will have changed. Pack
116/// all the operands into a struct.
117struct MergeCallArgs : public OpConversionPattern<CallOp> {
118 MergeCallArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
119 : OpConversionPattern(ctxt), info(info) {}
120
121 void rewrite(CallOp, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter) const final;
123 LogicalResult match(CallOp) const override { return success(); }
124
125private:
126 const CallPrepPrecomputed &info;
127};
128} // anonymous namespace
129
130void MergeCallArgs::rewrite(CallOp call, OpAdaptor adaptor,
131 ConversionPatternRewriter &rewriter) const {
132 auto loc = call.getLoc();
133 rewriter.setInsertionPoint(call);
134 auto method = call->getParentOfType<kanagawa::MethodLikeOpInterface>();
135
136 // Use the 'info' accelerator structures to find the argument type.
137 auto argStructEntry = info.argTypes.find(call.getCalleeAttr());
138 assert(argStructEntry != info.argTypes.end() && "Method symref not found!");
139 auto [argStruct, argLoc] = argStructEntry->second;
140
141 // Pack all of the operands into it.
142 auto newArg = rewriter.create<hw::StructCreateOp>(loc, argStruct,
143 adaptor.getOperands());
144 newArg->setAttr("sv.namehint",
145 rewriter.getStringAttr(call.getCallee().getName().getValue() +
146 "_args_called_from_" +
147 method.getMethodName().getValue()));
148
149 // Update the call to use just the new struct.
150 rewriter.modifyOpInPlace(call, [&]() {
151 call.getOperandsMutable().clear();
152 call.getOperandsMutable().append(newArg.getResult());
153 });
154}
155
156namespace {
157/// Change the method signatures to only have one argument: a struct capturing
158/// all of the original arguments.
159struct MergeMethodArgs : public OpConversionPattern<MethodOp> {
160 MergeMethodArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
161 : OpConversionPattern(ctxt), info(info) {}
162
163 void rewrite(MethodOp, OpAdaptor adaptor,
164 ConversionPatternRewriter &rewriter) const final;
165 LogicalResult match(MethodOp) const override { return success(); }
166
167private:
168 const CallPrepPrecomputed &info;
169};
170} // anonymous namespace
171
172void MergeMethodArgs::rewrite(MethodOp func, OpAdaptor adaptor,
173 ConversionPatternRewriter &rewriter) const {
174 auto loc = func.getLoc();
175 auto *ctxt = getContext();
176
177 // Find the pre-computed arg struct for this method.
178 auto argStructEntry =
179 info.argTypes.find(CallPrepPrecomputed::getSymbol(func));
180 assert(argStructEntry != info.argTypes.end() && "Cannot find symref!");
181 auto [argStruct, argLoc] = argStructEntry->second;
182
183 // Create a new method with the new signature.
184 FunctionType funcType = func.getFunctionType();
185 FunctionType newFuncType =
186 FunctionType::get(ctxt, {argStruct}, funcType.getResults());
187 auto newArgNames = ArrayAttr::get(ctxt, {StringAttr::get(ctxt, "arg")});
188 auto newMethod =
189 rewriter.create<MethodOp>(loc, func.getInnerSym(), newFuncType,
190 newArgNames, ArrayAttr(), ArrayAttr());
191
192 if (func->getNumRegions() > 0) {
193 // Create a body block with a struct explode to the arg struct into the
194 // original arguments.
195 Block *b =
196 rewriter.createBlock(&newMethod.getRegion(), {}, {argStruct}, {argLoc});
197 rewriter.setInsertionPointToStart(b);
198 auto replacementArgs =
199 rewriter.create<hw::StructExplodeOp>(loc, b->getArgument(0));
200
201 // Merge the original method body, rewiring the args.
202 Block *funcBody = &func.getBody().front();
203 rewriter.mergeBlocks(funcBody, b, replacementArgs.getResults());
204 }
205
206 rewriter.eraseOp(func);
207}
208
209namespace {
210/// Run all the physical lowerings.
211struct CallPrepPass
212 : public circt::kanagawa::impl::KanagawaCallPrepBase<CallPrepPass> {
213 void runOnOperation() override;
214
215private:
216 // Merge the arguments into one struct.
217 LogicalResult merge(const CallPrepPrecomputed &);
218};
219} // anonymous namespace
220
221void CallPrepPass::runOnOperation() {
222 CallPrepPrecomputed info(getOperation());
223
224 if (failed(merge(info))) {
225 signalPassFailure();
226 return;
227 }
228}
229
230LogicalResult CallPrepPass::merge(const CallPrepPrecomputed &info) {
231 // Set up a conversion and give it a set of laws.
232 ConversionTarget target(getContext());
233 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
234 target.addDynamicallyLegalOp<CallOp>([](CallOp call) {
235 auto argValues = call.getArgOperands();
236 return argValues.size() == 1 &&
237 hw::type_isa<hw::StructType>(argValues.front().getType());
238 });
239 target.addDynamicallyLegalOp<MethodOp>([](MethodOp func) {
240 ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
241 return argTypes.size() == 1 &&
242 hw::type_isa<hw::StructType>(argTypes.front());
243 });
244
245 // Add patterns to merge the args on both the call and method sides.
246 RewritePatternSet patterns(&getContext());
247 patterns.insert<MergeCallArgs>(&getContext(), info);
248 patterns.insert<MergeMethodArgs>(&getContext(), info);
249
250 return applyPartialConversion(getOperation(), target, std::move(patterns));
251}
252
253std::unique_ptr<Pass> circt::kanagawa::createCallPrepPass() {
254 return std::make_unique<CallPrepPass>();
255}
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createCallPrepPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Build indexes to make lookups faster. Create the new argument types as well.
ClassOp lookupNext(ClassOp scope, StringAttr instSym) const
DenseMap< std::pair< ClassOp, StringAttr >, ClassOp > instanceMap
DenseMap< hw::InnerRefAttr, std::pair< hw::StructType, Location > > argTypes
DenseMap< hw::InnerRefAttr, ClassOp > classSymbols
static hw::InnerRefAttr getSymbol(MethodOp method)