CIRCT  18.0.0git
IbisCallPrep.cpp
Go to the documentation of this file.
1 //===- IbisCallPrep.cpp - Implementation of call prep lowering ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetails.h"
10 
15 
19 
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace circt;
23 using namespace ibis;
24 
25 /// Build indexes to make lookups faster. Create the new argument types as well.
/// The tables are built once from the module and then shared, by const
/// reference, with the MergeCallArgs and MergeMethodArgs conversion patterns.
/// Populate every table by walking the ClassOps in `mod`.
27  CallPrepPrecomputed(ModuleOp mod);
28 
29  // Lookup a class from its symbol.
30  DenseMap<StringAttr, ClassOp> classSymbols;
31 
32  // Mapping of method to argument type.
 // Keyed by the method's @Class::@method symbolref (as built by getSymbol).
 // The value is the struct type that packs the method's arguments, plus a
 // fused location of the method body's block arguments (UnknownLoc when the
 // method has no region).
33  DenseMap<SymbolRefAttr, std::pair<hw::StructType, Location>> argTypes;
34 
35  // Lookup the class to which a particular instance (in a particular class) is
36  // referring.
 // Key: (class containing the InstanceOp, the instance's inner symbol name).
 // Value: the ClassOp that the instance instantiates.
37  DenseMap<std::pair<ClassOp, StringAttr>, ClassOp> instanceMap;
38 
39  // Lookup an entry in instanceMap. If not found, return null.
 // Returns a null ClassOp when `scope` has no instance named `instSym`.
40  ClassOp lookupNext(ClassOp scope, StringAttr instSym) const {
41  auto entry = instanceMap.find(std::make_pair(scope, instSym));
42  if (entry == instanceMap.end())
43  return {};
44  return entry->second;
45  }
46 
47  // Given an instance path, get the class::func symbolref for it.
 // The path is interpreted relative to the class enclosing `scope`. The result
 // is null if any instance along the path cannot be found.
48  SymbolRefAttr resolveInstancePath(Operation *scope, SymbolRefAttr path) const;
49 
50  // Utility function to create a symbolref to a method.
 // Produces @Class::@method, the key format used by argTypes.
 // NOTE(review): dereferences getInnerName() without checking, so this assumes
 // every MethodOp carries an inner name.
51  static SymbolRefAttr getSymbol(MethodOp method) {
52  ClassOp cls = method.getParentOp();
53  return SymbolRefAttr::get(
54  cls.getSymNameAttr(),
55  {FlatSymbolRefAttr::get(method.getContext(), *method.getInnerName())});
56  }
57 };
58 
 // Build all lookup tables. Class symbols are indexed in a separate first pass,
 // so the instance table below can resolve instance targets regardless of the
 // order in which classes appear in the module.
60  auto *ctxt = mod.getContext();
61 
62  // Populate the class-symbol lookup table.
63  for (auto cls : mod.getOps<ClassOp>())
64  classSymbols[cls.getSymNameAttr()] = cls;
65 
66  for (auto cls : mod.getOps<ClassOp>()) {
67  // Compute new argument types for each method.
68  for (auto method : cls.getOps<MethodOp>()) {
69 
70  // Create the struct type.
 // One struct field per argument, named after the method's arg_names entry.
 // NOTE(review): llvm::zip stops at the shorter range, so this assumes
 // arg_names and the argument types have the same length.
71  SmallVector<hw::StructType::FieldInfo> argFields;
72  for (auto [argName, argType] :
73  llvm::zip(method.getArgNamesAttr().getAsRange<StringAttr>(),
74  cast<MethodLikeOpInterface>(method.getOperation())
75  .getArgumentTypes()))
76  argFields.push_back({argName, argType});
77  auto argStruct = hw::StructType::get(ctxt, argFields);
78 
79  // Later we're gonna want the block locations, so compute a fused location
80  // and store it.
 // MergeMethodArgs uses this location for the new struct block argument.
 // NOTE(review): getNumRegions() counts the op's regions; it does not check
 // whether the body has a block. If a MethodOp body could be empty, front()
 // below would assert. Confirm MethodOp always has a body block.
81  Location argLoc = UnknownLoc::get(ctxt);
82  if (method->getNumRegions() > 0) {
83  SmallVector<Location> argLocs;
84  Block *body = &method.getBody().front();
85  for (auto arg : body->getArguments())
86  argLocs.push_back(arg.getLoc());
87  argLoc = FusedLoc::get(ctxt, argLocs);
88  }
89 
90  // Add both to the lookup table.
 // DenseMap::insert keeps the existing entry if the symbol is already present.
91  argTypes.insert(
92  std::make_pair(getSymbol(method), std::make_pair(argStruct, argLoc)));
93  }
94 
95  // Populate the instances table.
96  for (auto inst : cls.getOps<InstanceOp>()) {
97  auto clsEntry = classSymbols.find(inst.getTargetNameAttr().getAttr());
 // NOTE(review): this is only checked by assert. With asserts disabled, a
 // missing class would make the read of clsEntry->second use the end
 // iterator.
98  assert(clsEntry != classSymbols.end() &&
99  "class being instantiated doesn't exist");
100  instanceMap[std::make_pair(cls, inst.getInnerSym().getSymName())] =
101  clsEntry->second;
102  }
103  }
104 }
105 
/// Resolve `path` (@inst0::@inst1::...::@method), interpreted relative to the
/// ClassOp enclosing `scope`, into @TargetClass::@method. Each instance hop is
/// followed through instanceMap. Returns a null SymbolRefAttr as soon as any
/// hop is unknown.
106 SymbolRefAttr
108  SymbolRefAttr path) const {
109  auto cls = scope->getParentOfType<ClassOp>();
110  assert(cls && "scope outside of ibis class");
111 
112  // SymbolRefAttr is rather silly. The start of the path is root reference...
 // (the first path element is stored separately as the root reference)
113  cls = lookupNext(cls, path.getRootReference());
114  if (!cls)
115  return {};
116 
117  // ... then the rest are the nested references. The last one is the function
118  // name rather than an instance.
 // NOTE(review): if `path` has no nested references and its root happens to
 // name an instance, drop_back() runs on an empty list. Confirm whether
 // ArrayRef::drop_back tolerates that (LLVM asserts on it, I believe).
119  for (auto instSym : path.getNestedReferences().drop_back()) {
120  cls = lookupNext(cls, instSym.getAttr());
121  if (!cls)
122  return {};
123  }
124 
125  // The last one is the function symbol.
126  return SymbolRefAttr::get(cls.getSymNameAttr(),
127  {FlatSymbolRefAttr::get(path.getLeafReference())});
128 }
129 
130 namespace {
131 /// For each CallOp, the corresponding method signature will have changed. Pack
132 /// all the operands into a struct.
133 struct MergeCallArgs : public OpConversionPattern<CallOp> {
134  MergeCallArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
135  : OpConversionPattern(ctxt), info(info) {}
136 
137  void rewrite(CallOp, OpAdaptor adaptor,
138  ConversionPatternRewriter &rewriter) const final;
139  LogicalResult match(CallOp) const override { return success(); }
140 
141 private:
142  const CallPrepPrecomputed &info;
143 };
144 } // anonymous namespace
145 
146 void MergeCallArgs::rewrite(CallOp call, OpAdaptor adaptor,
147  ConversionPatternRewriter &rewriter) const {
148  auto loc = call.getLoc();
149  rewriter.setInsertionPoint(call);
150  auto method = call->getParentOfType<ibis::MethodLikeOpInterface>();
151 
152  // Use the 'info' accelerator structures to find the argument type.
153  SymbolRefAttr calleeSym =
154  info.resolveInstancePath(method, adaptor.getCalleeAttr());
155  auto argStructEntry = info.argTypes.find(calleeSym);
156  assert(argStructEntry != info.argTypes.end() && "Method symref not found!");
157  auto [argStruct, argLoc] = argStructEntry->second;
158 
159  // Pack all of the operands into it.
160  auto newArg = rewriter.create<hw::StructCreateOp>(loc, argStruct,
161  adaptor.getOperands());
162  newArg->setAttr("sv.namehint",
163  rewriter.getStringAttr(
164  call.getCalleeAttr().getLeafReference().getValue() +
165  "_args_called_from_" +
166  method.getMethodName().getValue()));
167 
168  // Update the call to use just the new struct.
169  rewriter.updateRootInPlace(call, [&]() {
170  call.getOperandsMutable().clear();
171  call.getOperandsMutable().append(newArg.getResult());
172  });
173 }
174 
175 namespace {
176 /// Change the method signatures to only have one argument: a struct capturing
177 /// all of the original arguments.
178 struct MergeMethodArgs : public OpConversionPattern<MethodOp> {
179  MergeMethodArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
180  : OpConversionPattern(ctxt), info(info) {}
181 
182  void rewrite(MethodOp, OpAdaptor adaptor,
183  ConversionPatternRewriter &rewriter) const final;
184  LogicalResult match(MethodOp) const override { return success(); }
185 
186 private:
187  const CallPrepPrecomputed &info;
188 };
189 } // anonymous namespace
190 
191 void MergeMethodArgs::rewrite(MethodOp func, OpAdaptor adaptor,
192  ConversionPatternRewriter &rewriter) const {
193  auto loc = func.getLoc();
194  auto *ctxt = getContext();
195 
196  // Find the pre-computed arg struct for this method.
197  auto argStructEntry =
198  info.argTypes.find(CallPrepPrecomputed::getSymbol(func));
199  assert(argStructEntry != info.argTypes.end() && "Cannot find symref!");
200  auto [argStruct, argLoc] = argStructEntry->second;
201 
202  // Create a new method with the new signature.
203  FunctionType funcType = func.getFunctionType();
204  FunctionType newFuncType =
205  FunctionType::get(ctxt, {argStruct}, funcType.getResults());
206  auto newArgNames = ArrayAttr::get(ctxt, {StringAttr::get(ctxt, "arg")});
207  auto newMethod =
208  rewriter.create<MethodOp>(loc, func.getInnerSym(), newFuncType,
209  newArgNames, ArrayAttr(), ArrayAttr());
210 
211  if (func->getNumRegions() > 0) {
212  // Create a body block with a struct explode to the arg struct into the
213  // original arguments.
214  Block *b =
215  rewriter.createBlock(&newMethod.getRegion(), {}, {argStruct}, {argLoc});
216  rewriter.setInsertionPointToStart(b);
217  auto replacementArgs =
218  rewriter.create<hw::StructExplodeOp>(loc, b->getArgument(0));
219 
220  // Merge the original method body, rewiring the args.
221  Block *funcBody = &func.getBody().front();
222  rewriter.mergeBlocks(funcBody, b, replacementArgs.getResults());
223  }
224 
225  rewriter.eraseOp(func);
226 }
227 
228 namespace {
229 /// Run all the physical lowerings.
230 struct CallPrepPass : public IbisCallPrepBase<CallPrepPass> {
231  void runOnOperation() override;
232 
233 private:
234  // Merge the arguments into one struct.
235  LogicalResult merge(const CallPrepPrecomputed &);
236 };
237 } // anonymous namespace
238 
239 void CallPrepPass::runOnOperation() {
240  CallPrepPrecomputed info(getOperation());
241 
242  if (failed(merge(info))) {
243  signalPassFailure();
244  return;
245  }
246 }
247 
248 LogicalResult CallPrepPass::merge(const CallPrepPrecomputed &info) {
249  // Set up a conversion and give it a set of laws.
250  ConversionTarget target(getContext());
251  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
252  target.addDynamicallyLegalOp<CallOp>([](CallOp call) {
253  auto argValues = call.getArgOperands();
254  return argValues.size() == 1 &&
255  hw::type_isa<hw::StructType>(argValues.front().getType());
256  });
257  target.addDynamicallyLegalOp<MethodOp>([](MethodOp func) {
258  ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
259  return argTypes.size() == 1 &&
260  hw::type_isa<hw::StructType>(argTypes.front());
261  });
262 
263  // Add patterns to merge the args on both the call and method sides.
264  RewritePatternSet patterns(&getContext());
265  patterns.insert<MergeCallArgs>(&getContext(), info);
266  patterns.insert<MergeMethodArgs>(&getContext(), info);
267 
268  return applyPartialConversion(getOperation(), target, std::move(patterns));
269 }
270 
271 std::unique_ptr<Pass> circt::ibis::createCallPrepPass() {
272  return std::make_unique<CallPrepPass>();
273 }
assert(baseType &&"element must be base type")
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
std::unique_ptr< mlir::Pass > createCallPrepPass()
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
Build indexes to make lookups faster. Create the new argument types as well.
SymbolRefAttr resolveInstancePath(Operation *scope, SymbolRefAttr path) const
DenseMap< StringAttr, ClassOp > classSymbols
ClassOp lookupNext(ClassOp scope, StringAttr instSym) const
DenseMap< std::pair< ClassOp, StringAttr >, ClassOp > instanceMap
DenseMap< SymbolRefAttr, std::pair< hw::StructType, Location > > argTypes
CallPrepPrecomputed(ModuleOp mod)
static SymbolRefAttr getSymbol(MethodOp method)