CIRCT  19.0.0git
LowerDPI.cpp
Go to the documentation of this file.
1 //===- LowerDPI.cpp - Lower DPI intrinsics to the Sim dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the LowerDPI pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/Threading.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Support/LogicalResult.h"
24 #include "llvm/ADT/MapVector.h"
25 
26 namespace circt {
27 namespace firrtl {
28 #define GEN_PASS_DEF_LOWERDPI
29 #include "circt/Dialect/FIRRTL/Passes.h.inc"
30 } // namespace firrtl
31 } // namespace circt
32 
33 using namespace mlir;
34 using namespace llvm;
35 using namespace circt;
36 using namespace circt::firrtl;
37 
38 namespace {
39 struct LowerDPIPass : public circt::firrtl::impl::LowerDPIBase<LowerDPIPass> {
40  void runOnOperation() override;
41 };
42 
43 // A helper struct to lower DPI intrinsics in the circuit.
44 struct LowerDPI {
45  LowerDPI(CircuitOp circuitOp) : circuitOp(circuitOp), nameSpace(circuitOp) {}
46  // Tte main logic.
47  LogicalResult run();
48  bool changed() const { return !funcNameToCallSites.empty(); }
49 
50 private:
51  // Walk all modules and peel `funcNameToCallSites`.
52  void collectIntrinsics();
53 
54  // Lower intrinsics recorded in `funcNameToCallSites`.
55  LogicalResult lower();
56 
57  sim::DPIFuncOp getOrCreateDPIFuncDecl(DPICallIntrinsicOp op);
58  LogicalResult lowerDPIIntrinsic(DPICallIntrinsicOp op);
59 
60  MapVector<StringAttr, SmallVector<DPICallIntrinsicOp>> funcNameToCallSites;
61 
62  // A map stores DPI func op for its function name and type.
63  llvm::DenseMap<std::pair<StringAttr, Type>, sim::DPIFuncOp>
64  functionSignatureToDPIFuncOp;
65 
66  firrtl::CircuitOp circuitOp;
67  CircuitNamespace nameSpace;
68 };
69 } // namespace
70 
71 void LowerDPI::collectIntrinsics() {
72  // A helper struct to collect DPI calls in the circuit.
73  struct DpiCallCollections {
74  FModuleOp module;
75  SmallVector<DPICallIntrinsicOp> dpiOps;
76  };
77 
78  SmallVector<DpiCallCollections, 0> collections;
79  collections.reserve(64);
80 
81  for (auto module : circuitOp.getOps<FModuleOp>())
82  collections.push_back(DpiCallCollections{module, {}});
83 
84  parallelForEach(circuitOp.getContext(), collections, [](auto &result) {
85  result.module.walk(
86  [&](DPICallIntrinsicOp dpi) { result.dpiOps.push_back(dpi); });
87  });
88 
89  for (auto &collection : collections)
90  for (auto dpi : collection.dpiOps)
91  funcNameToCallSites[dpi.getFunctionNameAttr()].push_back(dpi);
92 }
93 
94 LogicalResult LowerDPI::lower() {
95  for (auto [name, calls] : funcNameToCallSites) {
96  auto firstDPICallop = calls.front();
97  // Construct DPI func op.
98  auto firstDPIDecl = getOrCreateDPIFuncDecl(firstDPICallop);
99 
100  auto inputTypes = firstDPICallop.getInputs().getTypes();
101  auto outputTypes = firstDPICallop.getResultTypes();
102 
103  ImplicitLocOpBuilder builder(firstDPICallop.getLoc(),
104  circuitOp.getOperation());
105  auto lowerCall = [&](DPICallIntrinsicOp dpiOp) {
106  auto getLowered = [&](Value value) -> Value {
107  // Insert an unrealized conversion to cast FIRRTL type to HW type.
108  if (!value)
109  return value;
110  auto type = lowerType(value.getType());
111  return builder.create<mlir::UnrealizedConversionCastOp>(type, value)
112  ->getResult(0);
113  };
114  builder.setInsertionPoint(dpiOp);
115  auto clock = getLowered(dpiOp.getClock());
116  auto enable = getLowered(dpiOp.getEnable());
117  SmallVector<Value, 4> inputs;
118  inputs.reserve(dpiOp.getInputs().size());
119  for (auto input : dpiOp.getInputs())
120  inputs.push_back(getLowered(input));
121 
122  SmallVector<Type> outputTypes;
123  if (dpiOp.getResult())
124  outputTypes.push_back(lowerType(dpiOp.getResult().getType()));
125 
126  auto call = builder.create<sim::DPICallOp>(
127  outputTypes, firstDPIDecl.getSymNameAttr(), clock, enable, inputs);
128  if (!call.getResults().empty()) {
129  // Insert unrealized conversion cast HW type to FIRRTL type.
130  auto result = builder
131  .create<mlir::UnrealizedConversionCastOp>(
132  dpiOp.getResult().getType(), call.getResult(0))
133  ->getResult(0);
134  dpiOp.getResult().replaceAllUsesWith(result);
135  }
136  return success();
137  };
138 
139  if (failed(lowerCall(firstDPICallop)))
140  return failure();
141 
142  for (auto dpiOp : llvm::ArrayRef(calls).drop_front()) {
143  // Check that all DPI declaration match.
144  // TODO: This should be implemented as a verifier once function is added
145  // to FIRRTL.
146  if (dpiOp.getInputs().getTypes() != inputTypes) {
147  auto diag = firstDPICallop.emitOpError()
148  << "DPI function " << firstDPICallop.getFunctionNameAttr()
149  << " input types don't match ";
150  diag.attachNote(dpiOp.getLoc()) << " mismatched caller is here";
151  return failure();
152  }
153 
154  if (dpiOp.getResultTypes() != outputTypes) {
155  auto diag = firstDPICallop.emitOpError()
156  << "DPI function " << firstDPICallop.getFunctionNameAttr()
157  << " output types don't match";
158  diag.attachNote(dpiOp.getLoc()) << " mismatched caller is here";
159  return failure();
160  }
161 
162  if (failed(lowerCall(dpiOp)))
163  return failure();
164  }
165 
166  for (auto callOp : calls)
167  callOp.erase();
168  }
169 
170  return success();
171 }
172 
173 sim::DPIFuncOp LowerDPI::getOrCreateDPIFuncDecl(DPICallIntrinsicOp op) {
174  ImplicitLocOpBuilder builder(op.getLoc(), circuitOp.getOperation());
175  builder.setInsertionPointToStart(circuitOp.getBodyBlock());
176  auto inputTypes = op.getInputs().getTypes();
177  auto outputTypes = op.getResultTypes();
178  ArrayAttr inputNames = op.getInputNamesAttr();
179  StringAttr outputName = op.getOutputNameAttr();
180  assert(outputTypes.size() <= 1);
181 
182  SmallVector<hw::ModulePort> ports;
183  ports.reserve(inputTypes.size() + outputTypes.size());
184 
185  // Add input arguments.
186  for (auto [idx, inType] : llvm::enumerate(inputTypes)) {
187  hw::ModulePort port;
189  port.name = inputNames ? cast<StringAttr>(inputNames[idx])
190  : builder.getStringAttr(Twine("in_") + Twine(idx));
191  port.type = lowerType(inType);
192  ports.push_back(port);
193  }
194 
195  // Add output arguments.
196  for (auto [idx, outType] : llvm::enumerate(outputTypes)) {
197  hw::ModulePort port;
199  port.name = outputName ? outputName
200  : builder.getStringAttr(Twine("out_") + Twine(idx));
201  port.type = lowerType(outType);
202  ports.push_back(port);
203  }
204 
205  auto modType = hw::ModuleType::get(builder.getContext(), ports);
206  auto it =
207  functionSignatureToDPIFuncOp.find({op.getFunctionNameAttr(), modType});
208  if (it != functionSignatureToDPIFuncOp.end())
209  return it->second;
210 
211  auto funcSymbol = nameSpace.newName(op.getFunctionNameAttr().getValue());
212  auto funcOp = builder.create<sim::DPIFuncOp>(
213  funcSymbol, modType, ArrayAttr(), ArrayAttr(), op.getFunctionNameAttr());
214  // External function must have a private linkage.
215  funcOp.setPrivate();
216  functionSignatureToDPIFuncOp[{op.getFunctionNameAttr(), modType}] = funcOp;
217  return funcOp;
218 }
219 
220 LogicalResult LowerDPI::run() {
221  collectIntrinsics();
222  return lower();
223 }
224 
225 void LowerDPIPass::runOnOperation() {
226  auto circuitOp = getOperation();
227  LowerDPI lowerDPI(circuitOp);
228  if (failed(lowerDPI.run()))
229  return signalPassFailure();
230  if (!lowerDPI.changed())
231  return markAllAnalysesPreserved();
232 }
233 
234 std::unique_ptr<mlir::Pass> circt::firrtl::createLowerDPIPass() {
235  return std::make_unique<LowerDPIPass>();
236 }
assert(baseType &&"element must be base type")
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
llvm::SmallVector< StringAttr > inputs
Builder builder
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
std::unique_ptr< mlir::Pass > createLowerDPIPass()
Definition: LowerDPI.cpp:234
Type lowerType(Type type, std::optional< Location > loc={}, llvm::function_ref< hw::TypeAliasType(Type, BaseTypeAliasType, Location)> getTypeDeclFn={})
Given a type, return the corresponding lowered type for the HW dialect.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
The namespace of a CircuitOp, generally inhabited by modules.
Definition: Namespace.h:24