CIRCT  20.0.0git
KanagawaCallPrep.cpp
Go to the documentation of this file.
1 //===- KanagawaCallPrep.cpp - Implementation of call prep lowering --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/Pass/Pass.h"
12 
17 
19 #include "circt/Dialect/HW/HWOps.h"
22 
23 #include "mlir/Transforms/DialectConversion.h"
24 
25 namespace circt {
26 namespace kanagawa {
27 #define GEN_PASS_DEF_KANAGAWACALLPREP
28 #include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
29 } // namespace kanagawa
30 } // namespace circt
31 
32 using namespace circt;
33 using namespace kanagawa;
34 
35 /// Build indexes to make lookups faster. Create the new argument types as well.
37  CallPrepPrecomputed(DesignOp);
38 
39  // Lookup a class from its symbol.
40  DenseMap<hw::InnerRefAttr, ClassOp> classSymbols;
41 
42  // Mapping of method to argument type.
43  DenseMap<hw::InnerRefAttr, std::pair<hw::StructType, Location>> argTypes;
44 
45  // Lookup the class to which a particular instance (in a particular class) is
46  // referring.
47  DenseMap<std::pair<ClassOp, StringAttr>, ClassOp> instanceMap;
48 
49  // Lookup an entry in instanceMap. If not found, return null.
50  ClassOp lookupNext(ClassOp scope, StringAttr instSym) const {
51  auto entry = instanceMap.find(std::make_pair(scope, instSym));
52  if (entry == instanceMap.end())
53  return {};
54  return entry->second;
55  }
56 
57  // Utility function to create a hw::InnerRefAttr to a method.
58  static hw::InnerRefAttr getSymbol(MethodOp method) {
59  auto design = method->getParentOfType<DesignOp>();
60  return hw::InnerRefAttr::get(design.getSymNameAttr(),
61  method.getInnerSym().getSymName());
62  }
63 };
64 
66  auto *ctxt = design.getContext();
67  StringAttr modName = design.getNameAttr();
68 
69  // Populate the class-symbol lookup table.
70  for (auto cls : design.getOps<ClassOp>())
71  classSymbols[hw::InnerRefAttr::get(
72  modName, cls.getInnerSymAttr().getSymName())] = cls;
73 
74  for (auto cls : design.getOps<ClassOp>()) {
75  // Compute new argument types for each method.
76  for (auto method : cls.getOps<MethodOp>()) {
77 
78  // Create the struct type.
79  SmallVector<hw::StructType::FieldInfo> argFields;
80  for (auto [argName, argType] :
81  llvm::zip(method.getArgNamesAttr().getAsRange<StringAttr>(),
82  cast<MethodLikeOpInterface>(method.getOperation())
83  .getArgumentTypes()))
84  argFields.push_back({argName, argType});
85  auto argStruct = hw::StructType::get(ctxt, argFields);
86 
87  // Later we're gonna want the block locations, so compute a fused location
88  // and store it.
89  Location argLoc = UnknownLoc::get(ctxt);
90  if (method->getNumRegions() > 0) {
91  SmallVector<Location> argLocs;
92  Block *body = &method.getBody().front();
93  for (auto arg : body->getArguments())
94  argLocs.push_back(arg.getLoc());
95  argLoc = FusedLoc::get(ctxt, argLocs);
96  }
97 
98  // Add both to the lookup table.
99  argTypes.insert(
100  std::make_pair(getSymbol(method), std::make_pair(argStruct, argLoc)));
101  }
102 
103  // Populate the instances table.
104  for (auto inst : cls.getOps<InstanceOp>()) {
105  auto clsEntry = classSymbols.find(inst.getTargetNameAttr());
106  assert(clsEntry != classSymbols.end() &&
107  "class being instantiated doesn't exist");
108  instanceMap[std::make_pair(cls, inst.getInnerSym().getSymName())] =
109  clsEntry->second;
110  }
111  }
112 }
113 
114 namespace {
115 /// For each CallOp, the corresponding method signature will have changed. Pack
116 /// all the operands into a struct.
117 struct MergeCallArgs : public OpConversionPattern<CallOp> {
118  MergeCallArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
119  : OpConversionPattern(ctxt), info(info) {}
120 
121  void rewrite(CallOp, OpAdaptor adaptor,
122  ConversionPatternRewriter &rewriter) const final;
123  LogicalResult match(CallOp) const override { return success(); }
124 
125 private:
126  const CallPrepPrecomputed &info;
127 };
128 } // anonymous namespace
129 
130 void MergeCallArgs::rewrite(CallOp call, OpAdaptor adaptor,
131  ConversionPatternRewriter &rewriter) const {
132  auto loc = call.getLoc();
133  rewriter.setInsertionPoint(call);
134  auto method = call->getParentOfType<kanagawa::MethodLikeOpInterface>();
135 
136  // Use the 'info' accelerator structures to find the argument type.
137  auto argStructEntry = info.argTypes.find(call.getCalleeAttr());
138  assert(argStructEntry != info.argTypes.end() && "Method symref not found!");
139  auto [argStruct, argLoc] = argStructEntry->second;
140 
141  // Pack all of the operands into it.
142  auto newArg = rewriter.create<hw::StructCreateOp>(loc, argStruct,
143  adaptor.getOperands());
144  newArg->setAttr("sv.namehint",
145  rewriter.getStringAttr(call.getCallee().getName().getValue() +
146  "_args_called_from_" +
147  method.getMethodName().getValue()));
148 
149  // Update the call to use just the new struct.
150  rewriter.modifyOpInPlace(call, [&]() {
151  call.getOperandsMutable().clear();
152  call.getOperandsMutable().append(newArg.getResult());
153  });
154 }
155 
156 namespace {
157 /// Change the method signatures to only have one argument: a struct capturing
158 /// all of the original arguments.
159 struct MergeMethodArgs : public OpConversionPattern<MethodOp> {
160  MergeMethodArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
161  : OpConversionPattern(ctxt), info(info) {}
162 
163  void rewrite(MethodOp, OpAdaptor adaptor,
164  ConversionPatternRewriter &rewriter) const final;
165  LogicalResult match(MethodOp) const override { return success(); }
166 
167 private:
168  const CallPrepPrecomputed &info;
169 };
170 } // anonymous namespace
171 
172 void MergeMethodArgs::rewrite(MethodOp func, OpAdaptor adaptor,
173  ConversionPatternRewriter &rewriter) const {
174  auto loc = func.getLoc();
175  auto *ctxt = getContext();
176 
177  // Find the pre-computed arg struct for this method.
178  auto argStructEntry =
179  info.argTypes.find(CallPrepPrecomputed::getSymbol(func));
180  assert(argStructEntry != info.argTypes.end() && "Cannot find symref!");
181  auto [argStruct, argLoc] = argStructEntry->second;
182 
183  // Create a new method with the new signature.
184  FunctionType funcType = func.getFunctionType();
185  FunctionType newFuncType =
186  FunctionType::get(ctxt, {argStruct}, funcType.getResults());
187  auto newArgNames = ArrayAttr::get(ctxt, {StringAttr::get(ctxt, "arg")});
188  auto newMethod =
189  rewriter.create<MethodOp>(loc, func.getInnerSym(), newFuncType,
190  newArgNames, ArrayAttr(), ArrayAttr());
191 
192  if (func->getNumRegions() > 0) {
193  // Create a body block with a struct explode to the arg struct into the
194  // original arguments.
195  Block *b =
196  rewriter.createBlock(&newMethod.getRegion(), {}, {argStruct}, {argLoc});
197  rewriter.setInsertionPointToStart(b);
198  auto replacementArgs =
199  rewriter.create<hw::StructExplodeOp>(loc, b->getArgument(0));
200 
201  // Merge the original method body, rewiring the args.
202  Block *funcBody = &func.getBody().front();
203  rewriter.mergeBlocks(funcBody, b, replacementArgs.getResults());
204  }
205 
206  rewriter.eraseOp(func);
207 }
208 
209 namespace {
210 /// Run all the physical lowerings.
211 struct CallPrepPass
212  : public circt::kanagawa::impl::KanagawaCallPrepBase<CallPrepPass> {
213  void runOnOperation() override;
214 
215 private:
216  // Merge the arguments into one struct.
217  LogicalResult merge(const CallPrepPrecomputed &);
218 };
219 } // anonymous namespace
220 
221 void CallPrepPass::runOnOperation() {
222  CallPrepPrecomputed info(getOperation());
223 
224  if (failed(merge(info))) {
225  signalPassFailure();
226  return;
227  }
228 }
229 
230 LogicalResult CallPrepPass::merge(const CallPrepPrecomputed &info) {
231  // Set up a conversion and give it a set of laws.
232  ConversionTarget target(getContext());
233  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
234  target.addDynamicallyLegalOp<CallOp>([](CallOp call) {
235  auto argValues = call.getArgOperands();
236  return argValues.size() == 1 &&
237  hw::type_isa<hw::StructType>(argValues.front().getType());
238  });
239  target.addDynamicallyLegalOp<MethodOp>([](MethodOp func) {
240  ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
241  return argTypes.size() == 1 &&
242  hw::type_isa<hw::StructType>(argTypes.front());
243  });
244 
245  // Add patterns to merge the args on both the call and method sides.
246  RewritePatternSet patterns(&getContext());
247  patterns.insert<MergeCallArgs>(&getContext(), info);
248  patterns.insert<MergeMethodArgs>(&getContext(), info);
249 
250  return applyPartialConversion(getOperation(), target, std::move(patterns));
251 }
252 
253 std::unique_ptr<Pass> circt::kanagawa::createCallPrepPass() {
254  return std::make_unique<CallPrepPass>();
255 }
assert(baseType &&"element must be base type")
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
std::unique_ptr< mlir::Pass > createCallPrepPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Build indexes to make lookups faster. Create the new argument types as well.
ClassOp lookupNext(ClassOp scope, StringAttr instSym) const
DenseMap< std::pair< ClassOp, StringAttr >, ClassOp > instanceMap
DenseMap< hw::InnerRefAttr, std::pair< hw::StructType, Location > > argTypes
DenseMap< hw::InnerRefAttr, ClassOp > classSymbols
static hw::InnerRefAttr getSymbol(MethodOp method)