14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17#include "mlir/Pass/Pass.h"
19#define DEBUG_TYPE "arc-insert-runtime"
23#define GEN_PASS_DEF_INSERTRUNTIME
24#include "circt/Dialect/Arc/ArcPasses.h.inc"
// Common base for the runtime API callback declarations that this pass
// inserts into the module. Derived types create `llvmFuncOp` in their
// constructors.
// NOTE(review): further members (presumably the `used` flag that
// RuntimeModelContext::lowerInstance sets) and the closing brace are not
// visible in this excerpt.
35struct RuntimeFunction {
// The LLVM dialect function op for this callback; a null op until a derived
// constructor assigns it.
36 LLVM::LLVMFuncOp llvmFuncOp = {};
// Declares the runtime's "allocate instance" callback
// (runtime::APICallbacks::symNameAllocInstance) with the LLVM signature
// `!llvm.ptr (!llvm.ptr, !llvm.ptr)`.
// NOTE(review): the meaning of the two pointer arguments is defined by the
// runtime API and is not visible here.
40struct AllocInstanceFunction :
public RuntimeFunction {
// Creates the function op at `builder`'s current insertion point;
// GlobalRuntimeContext passes a builder positioned at the start of the
// module body.
41 explicit AllocInstanceFunction(ImplicitLocOpBuilder &builder) {
// Opaque LLVM pointer type, used for the result and both arguments.
47 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
48 llvmFuncOp = LLVM::LLVMFuncOp::create(
49 builder, runtime::APICallbacks::symNameAllocInstance,
50 LLVM::LLVMFunctionType::get(ptrTy, {ptrTy, ptrTy}));
// Declares the runtime's "delete instance" callback
// (runtime::APICallbacks::symNameDeleteInstance) with the LLVM signature
// `void (!llvm.ptr)`.
54struct DeleteInstanceFunction :
public RuntimeFunction {
// Creates the function op at `builder`'s current insertion point.
55 explicit DeleteInstanceFunction(ImplicitLocOpBuilder &builder) {
59 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
60 auto voidTy = LLVM::LLVMVoidType::get(builder.getContext());
61 llvmFuncOp = LLVM::LLVMFuncOp::create(
62 builder, runtime::APICallbacks::symNameDeleteInstance,
63 LLVM::LLVMFunctionType::get(voidTy, {ptrTy}));
// Declares the runtime's "on eval" callback
// (runtime::APICallbacks::symNameOnEval) with the LLVM signature
// `void (!llvm.ptr)`. RuntimeModelContext::lowerInstance emits a call to it
// before every SimStepOp in an instance body.
// NOTE(review): the assignment of the created op (presumably to
// `llvmFuncOp`, which lowerInstance uses as the callee) is on a line not
// visible in this excerpt.
67struct OnEvalFunction :
public RuntimeFunction {
// Creates the function op at `builder`'s current insertion point.
68 explicit OnEvalFunction(ImplicitLocOpBuilder &builder) {
72 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
73 auto voidTy = LLVM::LLVMVoidType::get(builder.getContext());
75 LLVM::LLVMFuncOp::create(builder, runtime::APICallbacks::symNameOnEval,
76 LLVM::LLVMFunctionType::get(voidTy, {ptrTy}));
// Forward declaration: the per-model state is defined after
// GlobalRuntimeContext, which creates and owns instances of it (see
// GlobalRuntimeContext::addModel).
82struct RuntimeModelContext;
// Module-wide state for the pass: the module being transformed, a builder
// positioned at the start of its body, the three runtime API function
// declarations, and (in `models`, declared on a line not visible here) the
// per-model contexts.
84struct GlobalRuntimeContext {
85 GlobalRuntimeContext() =
delete;
// Creates the module-level builder and, through it, declares the runtime API
// functions at the start of `moduleOp`'s body.
// Why this works: members are initialized in declaration order, and
// `globalBuilder` is declared before the function members, so it is ready
// when their constructors use it.
89 explicit GlobalRuntimeContext(ModuleOp moduleOp)
90 : mlirModuleOp(moduleOp), globalBuilder(createBuilder(moduleOp)),
91 allocInstanceFn(globalBuilder), deleteInstanceFn(globalBuilder),
92 onEvalFn(globalBuilder) {}
// Erases runtime API declarations at the end of the pass.
// NOTE(review): the check on `fn->used` presumably sits on a line not
// visible in this excerpt; as shown, every declaration would be erased.
95 void deleteUnusedFunctions() {
96 for (
auto *fn : apiFunctions)
98 fn->llvmFuncOp->erase();
// Registers a model by creating a RuntimeModelContext for it.
102 void addModel(ModelOp &modelOp,
const ModelInfo &modelInfo);
// Creates one RuntimeModelOp per registered model.
104 LogicalResult buildRuntimeModelOps();
// Assigns each not-yet-processed SimInstantiateOp to its model's context.
106 LogicalResult collectInstances();
// The module this pass operates on.
109 ModuleOp mlirModuleOp;
// Builder for module-level ops; inserts at the start of the module body.
111 ImplicitLocOpBuilder globalBuilder;
// Runtime API function declarations.
114 AllocInstanceFunction allocInstanceFn;
115 DeleteInstanceFunction deleteInstanceFn;
116 OnEvalFunction onEvalFn;
// All API functions, for uniform iteration (e.g. in deleteUnusedFunctions).
// NOTE(review): these point at this object's own members, so they would be
// wrong if a GlobalRuntimeContext were ever copied or moved. runOnOperation
// only heap-allocates it via std::make_unique and never copies it.
117 const std::array<RuntimeFunction *, 3> apiFunctions = {
118 &allocInstanceFn, &deleteInstanceFn, &onEvalFn};
// Returns a builder that uses the module's location and has its insertion
// point at the start of the module body.
124 static ImplicitLocOpBuilder createBuilder(ModuleOp &moduleOp) {
125 auto builder = ImplicitLocOpBuilder(moduleOp.getLoc(), moduleOp);
126 builder.setInsertionPointToStart(moduleOp.getBody());
// Per-model lowering state: the model op and its ModelInfo, the
// RuntimeModelOp created for it, and the SimInstantiateOps that instantiate
// it.
131struct RuntimeModelContext {
132 RuntimeModelContext() =
delete;
// `modelInfo` is stored by reference. runOnOperation passes entries of the
// ModelInfoAnalysis info map, which is alive for the whole pass run.
// NOTE(review): the `modelOp` member declaration is not visible in this
// excerpt.
134 RuntimeModelContext(GlobalRuntimeContext &globalContext, ModelOp &modelOp,
135 const ModelInfo &modelInfo)
136 : globalContext(globalContext), modelOp(modelOp), modelInfo(modelInfo) {}
// Records `instantiateOp` as an instance of this model and points its
// runtime-model attribute at this model's RuntimeModelOp symbol.
// Preconditions: the op has no runtime-model attribute yet (asserted), and
// `runtimeModelOp` has already been created. runOnOperation calls
// buildRuntimeModelOps before collectInstances.
139 void addInstance(SimInstantiateOp &instantiateOp) {
140 assert(!instantiateOp.getRuntimeModelAttr());
142 instantiateOp.setRuntimeModelAttr(
143 FlatSymbolRefAttr::get(runtimeModelOp.getSymNameAttr()));
144 instances.push_back(instantiateOp);
// Lowers all recorded instances; fails if any instance fails.
148 LogicalResult lower();
// The module-wide context that owns this object.
151 GlobalRuntimeContext &globalContext;
// Model name and `numStateBytes` (read by buildRuntimeModelOps), from
// ModelInfoAnalysis.
155 const ModelInfo &modelInfo;
// Instances collected by addInstance, in the order they were found.
157 SmallVector<SimInstantiateOp> instances;
// Created by GlobalRuntimeContext::buildRuntimeModelOps.
159 RuntimeModelOp runtimeModelOp;
// Lowers a single recorded instance.
162 LogicalResult lowerInstance(SimInstantiateOp &instance);
// Pass driver. The base class is generated by tablegen into ArcPasses.h.inc
// (enabled by GEN_PASS_DEF_INSERTRUNTIME above).
164struct InsertRuntimePass
165 :
public arc::impl::InsertRuntimeBase<InsertRuntimePass> {
// Reuse the generated base-class constructors.
166 using InsertRuntimeBase::InsertRuntimeBase;
168 void runOnOperation()
override;
// Creates the lowering context for `modelOp` and stores it in `models`,
// keyed by the model's name attribute. With the usual map operator[]
// semantics, this replaces any context already registered under that name.
// NOTE(review): the declaration of `newModel` is on a line not visible in
// this excerpt.
173void GlobalRuntimeContext::addModel(ModelOp &modelOp,
174 const ModelInfo &modelInfo) {
176 std::make_unique<RuntimeModelContext>(*
this, modelOp, modelInfo);
177 models[modelOp.getNameAttr()] = std::move(newModel);
// Finds every SimInstantiateOp in the module that does not yet have a
// runtime model. For each one, it resolves the instantiated model from the
// SimModelInstanceType of the body's first argument and registers the op
// with that model's RuntimeModelContext. Instances of unknown models get an
// error. The walk keeps going after an error, and the function returns
// failure if any error occurred (tracked in `hasFailed`).
// NOTE(review): several lines are not visible in this excerpt:
//   - the rest of the model-symbol expression,
//   - the `hasFailed = true` assignment,
//   - the `else` branch.
// NOTE(review): Block::walk defaults to post-order. In post-order, nested
// ops are visited before their parent, and WalkResult::skip() does not
// prune nested regions. So the skips below do not stop the walk from
// descending into instance or model bodies. If pruning is intended, use
// walk<WalkOrder::PreOrder>.
182LogicalResult GlobalRuntimeContext::collectInstances() {
183 bool hasFailed =
false;
184 mlirModuleOp.getBody()->walk([&](Operation *op) -> WalkResult {
185 if (
auto instOp = dyn_cast<SimInstantiateOp>(op)) {
// Already has a runtime model (processed earlier); leave it untouched.
187 if (instOp.getRuntimeModel())
188 return WalkResult::skip();
189 auto instanceModelSym = llvm::cast<SimModelInstanceType>(
190 instOp.getBody().getArgument(0).getType())
193 auto modelContext = models.find(instanceModelSym);
194 if (modelContext == models.end()) {
// NOTE(review): emitOpError already prefixes the message with
// "'<op-name>' op ", so the leading space here produces a double space.
196 instOp->emitOpError(
" does not refer to a known Arc model.");
198 modelContext->second->addInstance(instOp);
200 return WalkResult::skip();
// NOTE(review): `instOp` is unused here; `isa<ModelOp>(op)` would express
// the intent (skip model bodies) without the dead binding.
202 if (
auto instOp = dyn_cast<ModelOp>(op))
203 return WalkResult::skip();
204 return WalkResult::advance();
206 return success(!hasFailed);
// Creates one RuntimeModelOp per registered model at the global builder's
// insertion point. Each op is named "arcRuntimeModel_<model name>" and
// carries the model name and `numStateBytes` (as uint64_t). Each op uses its
// ModelOp's location; the builder's previous location is restored
// afterwards.
// NOTE(review): emission order follows the iteration order of `models`,
// whose type is not visible here. If it is a hash map keyed by attribute
// pointers, the order of the generated ops may differ between runs.
// NOTE(review): the loop's closing brace and the trailing return are not
// visible in this excerpt.
210LogicalResult GlobalRuntimeContext::buildRuntimeModelOps() {
211 auto savedLoc = globalBuilder.getLoc();
212 for (
auto &[_, model] : models) {
// Attribute the new op to the model it describes.
213 globalBuilder.setLoc(model->modelOp.getLoc());
214 auto symName = globalBuilder.getStringAttr(Twine(
"arcRuntimeModel_") +
215 model->modelInfo.name);
216 model->runtimeModelOp = RuntimeModelOp::create(
217 globalBuilder, symName,
218 globalBuilder.getStringAttr(model->modelInfo.name),
219 static_cast<uint64_t
>(model->modelInfo.numStateBytes));
221 globalBuilder.setLoc(savedLoc);
// Lowers every instance recorded for this model. The loop does not stop at
// the first failure, so every instance is attempted; the result is failure
// if any instance failed.
// NOTE(review): the `hasFailed = true;` statement in the loop body is not
// visible in this excerpt.
226LogicalResult RuntimeModelContext::lower() {
227 bool hasFailed =
false;
228 for (
auto &instance : instances)
229 if (failed(lowerInstance(instance)))
231 return success(!hasFailed);
// Lowers one SimInstantiateOp. It:
//   - marks the alloc/delete API functions as used,
//   - creates an unrealized cast at the start of the instance body, from the
//     body's model-instance argument to an opaque !llvm.ptr,
//   - inserts a call to the on-eval callback, passing that pointer, before
//     every SimStepOp in the body.
// `onEvalFn` is only marked used when a SimStepOp is actually found.
// NOTE(review): these parts are not visible in this excerpt:
//   - the binding of the cast result to `runtimeInst`,
//   - the code after the walk (including any use of the alloc/delete
//     callbacks),
//   - the return statement.
234LogicalResult RuntimeModelContext::lowerInstance(SimInstantiateOp &instance) {
236 globalContext.allocInstanceFn.used =
true;
237 globalContext.deleteInstanceFn.used =
true;
// Insert new ops at the start of the instance body's first block.
240 OpBuilder instBodyBuilder(instance);
241 instBodyBuilder.setInsertionPointToStart(
242 &instance.getBody().getBlocks().front());
// View the model-instance block argument as an LLVM pointer for the calls.
244 UnrealizedConversionCastOp::create(
245 instBodyBuilder, instance.getLoc(),
246 LLVM::LLVMPointerType::get(instBodyBuilder.getContext()),
247 instance.getBody().getArgument(0))
// Block::walk also visits ops nested in regions inside the body.
// NOTE(review): this assumes every SimStepOp found steps this instance --
// TODO confirm nested instances cannot occur here.
250 instance.getBody().getBlocks().front().walk([&](SimStepOp stepOp) {
251 instBodyBuilder.setInsertionPoint(stepOp);
252 globalContext.onEvalFn.used =
true;
253 LLVM::CallOp::create(instBodyBuilder, stepOp.getLoc(),
254 globalContext.onEvalFn.llvmFuncOp, {runtimeInst});
// Pass entry point:
//  1. Register every model from ModelInfoAnalysis.
//  2. Create the RuntimeModelOps and assign instances to their models.
//  3. For each model, if the `extraArgs` option is non-empty, add it to each
//     instance's runtime arguments. Then lower that model's instances.
//  4. Remove unused runtime API declarations.
// ModelInfoAnalysis is marked preserved at the end.
// NOTE(review): these parts are not visible in this excerpt:
//   - the failure handling after step 2 and after lowering,
//   - the `newArgs` declaration,
//   - the `else` of the argument update.
260void InsertRuntimePass::runOnOperation() {
263 auto &modelInfo = getAnalysis<ModelInfoAnalysis>();
264 auto globalContext = std::make_unique<GlobalRuntimeContext>(getOperation());
265 for (
auto &[mOp, mInfo] : modelInfo.infoMap)
266 globalContext->addModel(mOp, mInfo);
// Runtime model ops must exist before instances are assigned to them.
267 if (failed(globalContext->buildRuntimeModelOps()) ||
268 failed(globalContext->collectInstances())) {
274 for (
auto &[_, model] : globalContext->models) {
// Use `extraArgs` as the runtime args when an instance has none yet;
// otherwise append it after a ";" separator.
276 if (!extraArgs.empty()) {
277 for (
auto &instance : model->instances) {
279 if (!instance.getRuntimeArgsAttr() ||
280 instance.getRuntimeArgsAttr().getValue().empty())
281 newArgs = StringAttr::get(&getContext(), Twine(extraArgs));
283 newArgs = StringAttr::get(&getContext(),
284 Twine(instance.getRuntimeArgsAttr()) +
285 Twine(
";") + Twine(extraArgs));
286 instance.setRuntimeArgsAttr(newArgs);
290 if (failed(model->lower()))
// Drop API declarations that no instance ended up needing.
294 globalContext->deleteUnusedFunctions();
295 markAnalysesPreserved<ModelInfoAnalysis>();
assert(baseType &&"element must be base type")
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.