CIRCT  19.0.0git
IbisCallPrep.cpp
Go to the documentation of this file.
1 //===- IbisCallPrep.cpp - Implementation of call prep lowering ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetails.h"
10 
15 
19 
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace circt;
23 using namespace ibis;
24 
25 /// Build indexes to make lookups faster. Create the new argument types as well.
27  CallPrepPrecomputed(DesignOp);
28 
29  // Lookup a class from its symbol.
30  DenseMap<hw::InnerRefAttr, ClassOp> classSymbols;
31 
32  // Mapping of method to argument type.
33  DenseMap<hw::InnerRefAttr, std::pair<hw::StructType, Location>> argTypes;
34 
35  // Lookup the class to which a particular instance (in a particular class) is
36  // referring.
37  DenseMap<std::pair<ClassOp, StringAttr>, ClassOp> instanceMap;
38 
39  // Lookup an entry in instanceMap. If not found, return null.
40  ClassOp lookupNext(ClassOp scope, StringAttr instSym) const {
41  auto entry = instanceMap.find(std::make_pair(scope, instSym));
42  if (entry == instanceMap.end())
43  return {};
44  return entry->second;
45  }
46 
47  // Utility function to create a hw::InnerRefAttr to a method.
48  static hw::InnerRefAttr getSymbol(MethodOp method) {
49  auto design = method->getParentOfType<DesignOp>();
50  return hw::InnerRefAttr::get(design.getSymNameAttr(),
51  method.getInnerSym().getSymName());
52  }
53 };
54 
56  auto *ctxt = design.getContext();
57  StringAttr modName = design.getNameAttr();
58 
59  // Populate the class-symbol lookup table.
60  for (auto cls : design.getOps<ClassOp>())
61  classSymbols[hw::InnerRefAttr::get(
62  modName, cls.getInnerSymAttr().getSymName())] = cls;
63 
64  for (auto cls : design.getOps<ClassOp>()) {
65  // Compute new argument types for each method.
66  for (auto method : cls.getOps<MethodOp>()) {
67 
68  // Create the struct type.
69  SmallVector<hw::StructType::FieldInfo> argFields;
70  for (auto [argName, argType] :
71  llvm::zip(method.getArgNamesAttr().getAsRange<StringAttr>(),
72  cast<MethodLikeOpInterface>(method.getOperation())
73  .getArgumentTypes()))
74  argFields.push_back({argName, argType});
75  auto argStruct = hw::StructType::get(ctxt, argFields);
76 
77  // Later we're gonna want the block locations, so compute a fused location
78  // and store it.
79  Location argLoc = UnknownLoc::get(ctxt);
80  if (method->getNumRegions() > 0) {
81  SmallVector<Location> argLocs;
82  Block *body = &method.getBody().front();
83  for (auto arg : body->getArguments())
84  argLocs.push_back(arg.getLoc());
85  argLoc = FusedLoc::get(ctxt, argLocs);
86  }
87 
88  // Add both to the lookup table.
89  argTypes.insert(
90  std::make_pair(getSymbol(method), std::make_pair(argStruct, argLoc)));
91  }
92 
93  // Populate the instances table.
94  for (auto inst : cls.getOps<InstanceOp>()) {
95  auto clsEntry = classSymbols.find(inst.getTargetNameAttr());
96  assert(clsEntry != classSymbols.end() &&
97  "class being instantiated doesn't exist");
98  instanceMap[std::make_pair(cls, inst.getInnerSym().getSymName())] =
99  clsEntry->second;
100  }
101  }
102 }
103 
104 namespace {
105 /// For each CallOp, the corresponding method signature will have changed. Pack
106 /// all the operands into a struct.
107 struct MergeCallArgs : public OpConversionPattern<CallOp> {
108  MergeCallArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
109  : OpConversionPattern(ctxt), info(info) {}
110 
111  void rewrite(CallOp, OpAdaptor adaptor,
112  ConversionPatternRewriter &rewriter) const final;
113  LogicalResult match(CallOp) const override { return success(); }
114 
115 private:
116  const CallPrepPrecomputed &info;
117 };
118 } // anonymous namespace
119 
120 void MergeCallArgs::rewrite(CallOp call, OpAdaptor adaptor,
121  ConversionPatternRewriter &rewriter) const {
122  auto loc = call.getLoc();
123  rewriter.setInsertionPoint(call);
124  auto method = call->getParentOfType<ibis::MethodLikeOpInterface>();
125 
126  // Use the 'info' accelerator structures to find the argument type.
127  auto argStructEntry = info.argTypes.find(call.getCalleeAttr());
128  assert(argStructEntry != info.argTypes.end() && "Method symref not found!");
129  auto [argStruct, argLoc] = argStructEntry->second;
130 
131  // Pack all of the operands into it.
132  auto newArg = rewriter.create<hw::StructCreateOp>(loc, argStruct,
133  adaptor.getOperands());
134  newArg->setAttr("sv.namehint",
135  rewriter.getStringAttr(call.getCallee().getName().getValue() +
136  "_args_called_from_" +
137  method.getMethodName().getValue()));
138 
139  // Update the call to use just the new struct.
140  rewriter.modifyOpInPlace(call, [&]() {
141  call.getOperandsMutable().clear();
142  call.getOperandsMutable().append(newArg.getResult());
143  });
144 }
145 
146 namespace {
147 /// Change the method signatures to only have one argument: a struct capturing
148 /// all of the original arguments.
149 struct MergeMethodArgs : public OpConversionPattern<MethodOp> {
150  MergeMethodArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
151  : OpConversionPattern(ctxt), info(info) {}
152 
153  void rewrite(MethodOp, OpAdaptor adaptor,
154  ConversionPatternRewriter &rewriter) const final;
155  LogicalResult match(MethodOp) const override { return success(); }
156 
157 private:
158  const CallPrepPrecomputed &info;
159 };
160 } // anonymous namespace
161 
162 void MergeMethodArgs::rewrite(MethodOp func, OpAdaptor adaptor,
163  ConversionPatternRewriter &rewriter) const {
164  auto loc = func.getLoc();
165  auto *ctxt = getContext();
166 
167  // Find the pre-computed arg struct for this method.
168  auto argStructEntry =
169  info.argTypes.find(CallPrepPrecomputed::getSymbol(func));
170  assert(argStructEntry != info.argTypes.end() && "Cannot find symref!");
171  auto [argStruct, argLoc] = argStructEntry->second;
172 
173  // Create a new method with the new signature.
174  FunctionType funcType = func.getFunctionType();
175  FunctionType newFuncType =
176  FunctionType::get(ctxt, {argStruct}, funcType.getResults());
177  auto newArgNames = ArrayAttr::get(ctxt, {StringAttr::get(ctxt, "arg")});
178  auto newMethod =
179  rewriter.create<MethodOp>(loc, func.getInnerSym(), newFuncType,
180  newArgNames, ArrayAttr(), ArrayAttr());
181 
182  if (func->getNumRegions() > 0) {
183  // Create a body block with a struct explode to the arg struct into the
184  // original arguments.
185  Block *b =
186  rewriter.createBlock(&newMethod.getRegion(), {}, {argStruct}, {argLoc});
187  rewriter.setInsertionPointToStart(b);
188  auto replacementArgs =
189  rewriter.create<hw::StructExplodeOp>(loc, b->getArgument(0));
190 
191  // Merge the original method body, rewiring the args.
192  Block *funcBody = &func.getBody().front();
193  rewriter.mergeBlocks(funcBody, b, replacementArgs.getResults());
194  }
195 
196  rewriter.eraseOp(func);
197 }
198 
199 namespace {
200 /// Run all the physical lowerings.
201 struct CallPrepPass : public IbisCallPrepBase<CallPrepPass> {
202  void runOnOperation() override;
203 
204 private:
205  // Merge the arguments into one struct.
206  LogicalResult merge(const CallPrepPrecomputed &);
207 };
208 } // anonymous namespace
209 
210 void CallPrepPass::runOnOperation() {
211  CallPrepPrecomputed info(getOperation());
212 
213  if (failed(merge(info))) {
214  signalPassFailure();
215  return;
216  }
217 }
218 
219 LogicalResult CallPrepPass::merge(const CallPrepPrecomputed &info) {
220  // Set up a conversion and give it a set of laws.
221  ConversionTarget target(getContext());
222  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
223  target.addDynamicallyLegalOp<CallOp>([](CallOp call) {
224  auto argValues = call.getArgOperands();
225  return argValues.size() == 1 &&
226  hw::type_isa<hw::StructType>(argValues.front().getType());
227  });
228  target.addDynamicallyLegalOp<MethodOp>([](MethodOp func) {
229  ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
230  return argTypes.size() == 1 &&
231  hw::type_isa<hw::StructType>(argTypes.front());
232  });
233 
234  // Add patterns to merge the args on both the call and method sides.
235  RewritePatternSet patterns(&getContext());
236  patterns.insert<MergeCallArgs>(&getContext(), info);
237  patterns.insert<MergeMethodArgs>(&getContext(), info);
238 
239  return applyPartialConversion(getOperation(), target, std::move(patterns));
240 }
241 
242 std::unique_ptr<Pass> circt::ibis::createCallPrepPass() {
243  return std::make_unique<CallPrepPass>();
244 }
assert(baseType &&"element must be base type")
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
std::unique_ptr< mlir::Pass > createCallPrepPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Build indexes to make lookups faster. Create the new argument types as well.
ClassOp lookupNext(ClassOp scope, StringAttr instSym) const
DenseMap< std::pair< ClassOp, StringAttr >, ClassOp > instanceMap
DenseMap< hw::InnerRefAttr, std::pair< hw::StructType, Location > > argTypes
CallPrepPrecomputed(DesignOp)
DenseMap< hw::InnerRefAttr, ClassOp > classSymbols
static hw::InnerRefAttr getSymbol(MethodOp method)