15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Pass/Pass.h"
22#define DEBUG_TYPE "arc-insert-runtime"
26#define GEN_PASS_DEF_INSERTRUNTIME
27#include "circt/Dialect/Arc/ArcPasses.h.inc"
/// Base for the runtime API callback declarations built by this pass. Holds
/// the declared LLVM function; other code in this file also reads and writes
/// a `used` flag on it, declared on a line not visible here.
38struct RuntimeFunction {
39 LLVM::LLVMFuncOp llvmFuncOp = {};
/// Attach the attributes describing a model-state pointer argument to
/// argument `argIndex` (not a hard-coded argument 0): nocapture, nofree,
/// noalias and readonly. NOTE(review): readonly is presumably guarded by the
/// trailing boolean parameter whose declaration is on a line not visible
/// here; callers pass `true` or `false` for it.
44 void setModelStateArgAttrs(OpBuilder &builder,
unsigned argIndex,
46 llvmFuncOp.setArgAttr(argIndex, LLVM::LLVMDialect::getNoCaptureAttrName(),
47 builder.getUnitAttr());
48 llvmFuncOp.setArgAttr(argIndex, LLVM::LLVMDialect::getNoFreeAttrName(),
49 builder.getUnitAttr());
50 llvmFuncOp.setArgAttr(argIndex, LLVM::LLVMDialect::getNoAliasAttrName(),
51 builder.getUnitAttr());
53 llvmFuncOp.setArgAttr(argIndex, LLVM::LLVMDialect::getReadonlyAttrName(),
54 builder.getUnitAttr());
/// Declaration of the runtime's instance-allocation callback
/// (`runtime::APICallbacks::symNameAllocInstance`) with LLVM signature
/// `ptr (ptr, ptr)`. The meaning of the two pointer arguments is not visible
/// here.
58struct AllocInstanceFunction :
public RuntimeFunction {
59 explicit AllocInstanceFunction(ImplicitLocOpBuilder &builder) {
65 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
66 llvmFuncOp = LLVM::LLVMFuncOp::create(
67 builder, runtime::APICallbacks::symNameAllocInstance,
68 LLVM::LLVMFunctionType::get(ptrTy, {ptrTy, ptrTy}));
// The returned pointer is declared noalias, noundef, nonnull and 16-byte
// aligned.
69 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNoAliasAttrName(),
70 builder.getUnitAttr());
71 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNoUndefAttrName(),
72 builder.getUnitAttr());
73 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNonNullAttrName(),
74 builder.getUnitAttr());
75 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getAlignAttrName(),
76 builder.getI64IntegerAttr(16));
/// Declaration of the runtime's instance-deletion callback
/// (`runtime::APICallbacks::symNameDeleteInstance`) with LLVM signature
/// `void (ptr)`.
80struct DeleteInstanceFunction :
public RuntimeFunction {
81 explicit DeleteInstanceFunction(ImplicitLocOpBuilder &builder) {
85 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
86 auto voidTy = LLVM::LLVMVoidType::get(builder.getContext());
87 llvmFuncOp = LLVM::LLVMFuncOp::create(
88 builder, runtime::APICallbacks::symNameDeleteInstance,
89 LLVM::LLVMFunctionType::get(voidTy, {ptrTy}));
/// Declaration of the runtime's on-eval callback
/// (`runtime::APICallbacks::symNameOnEval`) with LLVM signature `void (ptr)`.
/// Argument 0 gets the model-state pointer attributes with the flag set to
/// `true`.
93struct OnEvalFunction :
public RuntimeFunction {
94 explicit OnEvalFunction(ImplicitLocOpBuilder &builder) {
98 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
99 auto voidTy = LLVM::LLVMVoidType::get(builder.getContext());
// NOTE(review): the created op is presumably assigned to `llvmFuncOp` on the
// preceding (not visible) line, as in the sibling callback structs.
101 LLVM::LLVMFuncOp::create(builder, runtime::APICallbacks::symNameOnEval,
102 LLVM::LLVMFunctionType::get(voidTy, {ptrTy}));
103 setModelStateArgAttrs(builder, 0,
true);
/// Declaration of the runtime's on-initialized callback
/// (`runtime::APICallbacks::symNameOnInitialized`) with LLVM signature
/// `void (ptr)`. Argument 0 gets the model-state pointer attributes with the
/// flag set to `true`.
107struct OnInitializedFunction :
public RuntimeFunction {
108 explicit OnInitializedFunction(ImplicitLocOpBuilder &builder) {
112 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
113 auto voidTy = LLVM::LLVMVoidType::get(builder.getContext());
114 llvmFuncOp = LLVM::LLVMFuncOp::create(
115 builder, runtime::APICallbacks::symNameOnInitialized,
116 LLVM::LLVMFunctionType::get(voidTy, {ptrTy}));
117 setModelStateArgAttrs(builder, 0,
true);
/// Declaration of the runtime's swap-trace-buffer callback
/// (`runtime::APICallbacks::symNameSwapTraceBuffer`) with LLVM signature
/// `ptr (ptr)`. The trace instrumentation functions call it with the model
/// state pointer and store the returned pointer as the new trace buffer.
121struct SwapTraceBufferFunction :
public RuntimeFunction {
122 explicit SwapTraceBufferFunction(ImplicitLocOpBuilder &builder) {
126 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
127 llvmFuncOp = LLVM::LLVMFuncOp::create(
128 builder, runtime::APICallbacks::symNameSwapTraceBuffer,
129 LLVM::LLVMFunctionType::get(ptrTy, {ptrTy}));
// The returned buffer pointer is declared noalias, noundef, nonnull and
// 8-byte aligned.
130 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNoAliasAttrName(),
131 builder.getUnitAttr());
132 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNoUndefAttrName(),
133 builder.getUnitAttr());
134 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNonNullAttrName(),
135 builder.getUnitAttr());
136 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getAlignAttrName(),
137 builder.getI64IntegerAttr(8));
// Unlike the on-eval/on-initialized callbacks, the flag is `false` here.
138 setModelStateArgAttrs(builder, 0,
false);
// Forward declaration: GlobalRuntimeContext (defined next) refers to
// RuntimeModelContext, whose definition follows it.
144struct RuntimeModelContext;
/// Module-wide state of the InsertRuntime pass: a builder for module-level
/// declarations, the runtime API callback declarations (all created eagerly
/// by the constructor), the per-model RuntimeModelContexts and the per-type
/// trace instrumentation functions.
146struct GlobalRuntimeContext {
147 GlobalRuntimeContext() =
delete;
// Declares every runtime API callback up front through `globalBuilder`;
// deleteUnusedFunctions() erases the ones that end up unused.
151 explicit GlobalRuntimeContext(ModuleOp moduleOp)
152 : mlirModuleOp(moduleOp), globalBuilder(createBuilder(moduleOp)),
153 allocInstanceFn(globalBuilder), deleteInstanceFn(globalBuilder),
154 onEvalFn(globalBuilder), onInitializedFn(globalBuilder),
155 swapTraceBufferFn(globalBuilder) {}
// Erase the declarations of API callbacks that were not used.
// NOTE(review): the `used` check is on a line not visible here.
158 void deleteUnusedFunctions() {
159 for (
auto *fn : apiFunctions)
161 fn->llvmFuncOp->erase();
/// Integer type that a traced value of `stateType` is widened to: the bit
/// width rounded up to a multiple of 64, and at least 64 bits (so even a
/// zero-width value occupies one 64-bit word).
165 static Type getTraceExtendedType(Type stateType) {
166 auto numBits = stateType.getIntOrFloatBitWidth();
167 auto numQWords = std::max((numBits + 63) / 64, 1U);
168 return IntegerType::get(stateType.getContext(), numQWords * 64);
/// Create and register the RuntimeModelContext for `modelOp`.
172 void addModel(ModelOp &modelOp,
const ModelInfo &modelInfo);
/// Create one RuntimeModelOp per registered model.
174 LogicalResult buildRuntimeModelOps();
/// Bind SimInstantiateOps to the context of the model they instantiate.
176 LogicalResult collectInstances();
/// Queue tapped state writes and build the instrumentation functions they
/// need.
179 LogicalResult buildTraceInstrumentation();
/// Look up the instrumentation function for trace type `ty`. `ty` must be a
/// multiple of 64 bits wide and its function must already have been built;
/// both are only checked by assertions.
182 LLVM::LLVMFuncOp getTraceInstrumentFn(Type ty)
const {
183 assert(ty.getIntOrFloatBitWidth() % 64 == 0);
184 auto fn = traceInstrumentationFns.find(ty);
185 assert(fn != traceInstrumentationFns.end());
// Module being transformed.
190 ModuleOp mlirModuleOp;
// Builder for module-level declarations (see createBuilder()).
192 ImplicitLocOpBuilder globalBuilder;
// Runtime API callback declarations. They are constructed from
// `globalBuilder`, so they must stay declared after it.
195 AllocInstanceFunction allocInstanceFn;
196 DeleteInstanceFunction deleteInstanceFn;
197 OnEvalFunction onEvalFn;
198 OnInitializedFunction onInitializedFn;
199 SwapTraceBufferFunction swapTraceBufferFn;
// Pointers to the callback declarations above, for bulk processing.
200 const std::array<RuntimeFunction *, 5> apiFunctions = {
201 &allocInstanceFn, &deleteInstanceFn, &onEvalFn, &onInitializedFn,
// Builder located at the module's location, inserting at the start of the
// module body.
208 static ImplicitLocOpBuilder createBuilder(ModuleOp &moduleOp) {
209 auto builder = ImplicitLocOpBuilder(moduleOp.getLoc(), moduleOp);
210 builder.setInsertionPointToStart(moduleOp.getBody());
// Build (once per trace type) the private function that appends a tapped
// value to the trace buffer.
213 void buildTraceInstrumentationFn(Type ty);
/// Per-model state: the Arc model and its ModelInfo, the RuntimeModelOp
/// created for it, and the instances and tapped state writes that refer to
/// it and still need processing.
218struct RuntimeModelContext {
219 RuntimeModelContext() =
delete;
221 RuntimeModelContext(GlobalRuntimeContext &globalContext, ModelOp &modelOp,
222 const ModelInfo &modelInfo)
223 : globalContext(globalContext), modelOp(modelOp), modelInfo(modelInfo) {}
/// Point `instantiateOp` at this model's RuntimeModelOp (through its
/// runtime-model attribute) and queue it for lowering. The RuntimeModelOp
/// must already exist, and the instance must not be bound yet (asserted).
226 void addInstance(SimInstantiateOp &instantiateOp) {
227 assert(!instantiateOp.getRuntimeModelAttr());
229 instantiateOp.setRuntimeModelAttr(
230 FlatSymbolRefAttr::get(runtimeModelOp.getSymNameAttr()));
231 instances.push_back(instantiateOp);
/// Queue a state write that carries a trace tap of this model for
/// instrumentation.
234 void addTappedStateWrite(StateWriteOp &writeOp) {
235 assert(writeOp.getTraceTapModel().has_value() &&
236 writeOp.getTraceTapIndex().has_value());
237 assert(modelOp.getSymNameAttr() ==
238 writeOp.getTraceTapModelAttr().getAttr());
239 tappedWrites.push_back(writeOp);
/// True if this model's RuntimeModelOp carries trace taps.
242 bool hasTraceTaps() {
return runtimeModelOp.getTraceTaps().has_value(); }
/// Instrument all queued tapped writes.
246 LogicalResult insertTraceInstrumentation();
/// Lower all queued instances, then instrument tapped writes.
248 LogicalResult lower();
251 GlobalRuntimeContext &globalContext;
255 const ModelInfo &modelInfo;
// Instances bound by addInstance().
257 SmallVector<SimInstantiateOp> instances;
// Created by GlobalRuntimeContext::buildRuntimeModelOps().
259 RuntimeModelOp runtimeModelOp;
// Writes queued by addTappedStateWrite(); cleared by
// insertTraceInstrumentation().
261 SmallVector<StateWriteOp> tappedWrites;
// Lower a single queued instance to runtime callback calls.
264 LogicalResult lowerInstance(SimInstantiateOp &instance);
/// The InsertRuntime pass. `traceFileName` and `extraArgs` are not declared
/// in the visible code; presumably they are pass options from the generated
/// InsertRuntimeBase.
266struct InsertRuntimePass
267 :
public arc::impl::InsertRuntimeBase<InsertRuntimePass> {
268 using InsertRuntimeBase::InsertRuntimeBase;
270 void runOnOperation()
override;
/// Build the runtime argument string for instance number `instIdx`.
/// It starts from the instance's existing arguments. If a trace file name
/// is set, it then adds either the name itself or the name with `instIdx`
/// inserted before its extension. Finally `extraArgs` is appended.
/// NOTE(review): the condition choosing between the two file-name forms and
/// the separators are on lines not visible here.
274 SmallString<32> buildArgString(
unsigned instIdx, StringAttr existingArgs) {
277 str.append(existingArgs);
279 if (!traceFileName.empty()) {
286 str += traceFileName;
289 std::filesystem::path(
static_cast<std::string
>(traceFileName))
// Append the file name without its extension, followed by the instance
// index.
292 str += traceFileName.substr(0, traceFileName.size() - extension.size());
294 str += std::to_string(instIdx);
299 if (!extraArgs.empty()) {
302 str.append(extraArgs);
/// Create the RuntimeModelContext for `modelOp` and register it under the
/// model's name attribute, replacing any context previously registered
/// under that name.
310void GlobalRuntimeContext::addModel(ModelOp &modelOp,
311 const ModelInfo &modelInfo) {
313 std::make_unique<RuntimeModelContext>(*
this, modelOp, modelInfo);
314 models[modelOp.getNameAttr()] = std::move(newModel);
/// Walk the module and bind every SimInstantiateOp that has no runtime model
/// yet to the context of the Arc model it instantiates. Instances that
/// already reference a runtime model are left untouched, and the walk does
/// not descend into instance or ModelOp bodies. Returns failure when
/// `hasFailed` is set, presumably on the unknown-model path (that assignment
/// is on a line not visible here).
319LogicalResult GlobalRuntimeContext::collectInstances() {
320 bool hasFailed =
false;
321 mlirModuleOp.getBody()->walk([&](Operation *op) -> WalkResult {
322 if (
auto instOp = dyn_cast<SimInstantiateOp>(op)) {
324 if (instOp.getRuntimeModel())
325 return WalkResult::skip();
// The instantiated model is named by the type of the body's first argument.
326 auto instanceModelSym = llvm::cast<SimModelInstanceType>(
327 instOp.getBody().getArgument(0).getType())
330 auto modelContext = models.find(instanceModelSym);
331 if (modelContext == models.end()) {
// emitOpError already prints "'<op>' op " before the message, so the message
// must not start with a space.
333 instOp->emitOpError(
"does not refer to a known Arc model.");
335 modelContext->second->addInstance(instOp);
337 return WalkResult::skip();
// Do not look for instances inside model bodies.
339 if (isa<ModelOp>(op))
340 return WalkResult::skip();
341 return WalkResult::advance();
343 return success(!hasFailed);
/// Set up trace instrumentation. If any model has trace taps, the
/// swap-trace-buffer callback is marked as used. Then every state write that
/// carries a trace tap is queued on its model's context, and the
/// instrumentation function for the written value's type is built.
/// Non-integer values only produce a warning.
/// NOTE(review): a non-integer tapped write has already been queued by
/// addTappedStateWrite() before the type check. insertTraceInstrumentation()
/// later calls getTraceExtendedType() and getTraceInstrumentFn() on it, and
/// both assert for such types (bit width, and the missing function). Unless
/// lines not visible here filter these writes out, consider queuing only
/// integer-typed writes.
/// NOTE(review): `tappedTypes` is not used on any visible line.
346LogicalResult GlobalRuntimeContext::buildTraceInstrumentation() {
348 models, [](
auto &modelIt) {
return modelIt.second->hasTraceTaps(); }))
351 swapTraceBufferFn.used =
true;
352 SetVector<Type> tappedTypes;
354 mlirModuleOp.getBody()->walk([&](StateWriteOp writeOp) {
355 if (!writeOp.getTraceTapModel().has_value())
357 auto modelCtxt = models.find(writeOp.getTraceTapModelAttr().getAttr());
358 assert(modelCtxt != models.end() &&
"Unknown referenced model");
359 modelCtxt->second->addTappedStateWrite(writeOp);
360 if (isa<IntegerType>(writeOp.getValue().getType()))
361 buildTraceInstrumentationFn(writeOp.getValue().getType());
363 writeOp->emitWarning(
"Tracing of non-integer type is not supported");
/// Build, once per trace type, the private noinline function
/// `_arc_trace_instrument_i<N>` with signature `void (ptr, i64, iN)`, where
/// iN = getTraceExtendedType(ty).
///
/// The function appends one record to the model's trace buffer: the i64 tap
/// index followed by N/64 data words, least significant word first. The
/// buffer pointer and its current size (counted in 64-bit words) live in the
/// ArcState header. They are addressed at
/// `offsetof(ArcState, field) - sizeof(ArcState)` bytes from the model state
/// pointer, i.e. the header sits just before the model state.
///
/// If the record would push the size above
/// runtime::defaultTraceBufferCapacity, the swap-trace-buffer callback runs
/// first. Its returned buffer is stored back into the header, and the record
/// is written at the start of that buffer.
398void GlobalRuntimeContext::buildTraceInstrumentationFn(Type ty) {
399 assert(isa<IntegerType>(ty));
401 auto traceTy = getTraceExtendedType(ty);
402 if (traceInstrumentationFns.contains(traceTy))
// Number of 64-bit data words per record.
406 auto typeQWords = traceTy.getIntOrFloatBitWidth() / 64;
407 assert(traceTy.getIntOrFloatBitWidth() % 64 == 0);
408 auto *ctx = ty.getContext();
409 auto i64Ty = IntegerType::get(ctx, 64);
410 auto i32Ty = IntegerType::get(ctx, 32);
411 auto llvmPtrTy = LLVM::LLVMPointerType::get(ctx);
412 auto llvmFnTy = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx),
413 {llvmPtrTy, i64Ty, traceTy});
414 auto symName = StringAttr::get(
415 ctx,
"_arc_trace_instrument_i" + Twine(traceTy.getIntOrFloatBitWidth()));
416 auto funcOp = LLVM::LLVMFuncOp::create(globalBuilder, symName, llvmFnTy,
417 LLVM::Linkage::Private);
418 funcOp.setNoInline(
true);
419 traceInstrumentationFns.insert({traceTy, funcOp});
// Control flow: entry (capacity check) -> optional swap block -> store block.
422 OpBuilder::InsertionGuard g(globalBuilder);
423 auto *capcaityCheckBlock = funcOp.addEntryBlock(globalBuilder);
424 auto *swapBufferBlock = &funcOp.getRegion().emplaceBlock();
425 auto *bufferStoreBlock = &funcOp.getRegion().emplaceBlock();
// Store block arguments: (record pointer, new buffer size).
427 bufferStoreBlock->addArgument(llvmPtrTy, globalBuilder.getLoc());
429 bufferStoreBlock->addArgument(i32Ty, globalBuilder.getLoc());
// Entry block arguments: (model state pointer, tap index, value).
432 globalBuilder.setInsertionPointToStart(capcaityCheckBlock);
433 auto modelStatePtr = capcaityCheckBlock->getArgument(0);
434 auto bufferPtrPtr = LLVM::GEPOp::create(
435 globalBuilder, llvmPtrTy, globalBuilder.getI8Type(), modelStatePtr,
436 {LLVM::GEPArg(static_cast<int>(offsetof(ArcState, traceBuffer)) -
437 static_cast<int>(sizeof(ArcState)))});
438 auto bufferSizePtr = LLVM::GEPOp::create(
439 globalBuilder, llvmPtrTy, globalBuilder.getI8Type(), modelStatePtr,
440 {LLVM::GEPArg(static_cast<int>(offsetof(ArcState, traceBufferSize)) -
441 static_cast<int>(sizeof(ArcState)))});
// Record size in 64-bit words: one for the tap index plus the data words.
443 auto requiredSize = typeQWords + 1;
444 auto reqSizeCst = LLVM::ConstantOp::create(
445 globalBuilder, globalBuilder.getI32IntegerAttr(requiredSize));
448 LLVM::LoadOp::create(globalBuilder, llvmPtrTy, bufferPtrPtr);
451 LLVM::LoadOp::create(globalBuilder, i32Ty, bufferSizePtr);
452 auto capacityConstant = LLVM::ConstantOp::create(
454 globalBuilder.getI32IntegerAttr(runtime::defaultTraceBufferCapacity));
457 LLVM::AddOp::create(globalBuilder, bufferSizeVal, reqSizeCst);
460 LLVM::GEPOp::create(globalBuilder, llvmPtrTy, i64Ty, bufferPtrVal,
461 {LLVM::GEPArg(bufferSizeVal)});
// Swap buffers when the new size would exceed the capacity. The (0, INT32_MAX)
// pair presumably gives branch weights that mark the swap path as cold.
463 auto needsSwap = LLVM::ICmpOp::create(globalBuilder, LLVM::ICmpPredicate::ugt,
464 newSizeVal, capacityConstant);
465 LLVM::CondBrOp::create(
466 globalBuilder, needsSwap, swapBufferBlock, {}, bufferStoreBlock,
467 {storePtr, newSizeVal},
469 std::pair<int32_t, int32_t>(0, std::numeric_limits<int32_t>::max()));
// Swap block: get a new buffer from the runtime, store it in the header, and
// write the record at its start (new size = record size).
472 globalBuilder.setInsertionPointToStart(swapBufferBlock);
474 auto swapCall = LLVM::CallOp::create(
475 globalBuilder, swapTraceBufferFn.llvmFuncOp, {modelStatePtr});
477 LLVM::StoreOp::create(globalBuilder, swapCall.getResult(), bufferPtrPtr);
478 LLVM::BrOp::create(globalBuilder, {swapCall.getResult(), reqSizeCst},
// Store block: write the tap index, then the data words, then the new size.
482 globalBuilder.setInsertionPointToStart(bufferStoreBlock);
484 LLVM::StoreOp::create(globalBuilder, capcaityCheckBlock->getArgument(1),
485 bufferStoreBlock->getArgument(0));
488 for (
unsigned qWord = 0; qWord < typeQWords; ++qWord) {
490 auto dataStorePtr = LLVM::GEPOp::create(globalBuilder, llvmPtrTy, i64Ty,
491 bufferStoreBlock->getArgument(0),
492 {LLVM::GEPArg(qWord + 1)});
493 Value storeVal = capcaityCheckBlock->getArgument(2);
495 auto shiftCst = LLVM::ConstantOp::create(
497 globalBuilder.getIntegerAttr(storeVal.getType(), qWord * 64));
498 storeVal = LLVM::LShrOp::create(globalBuilder, storeVal, shiftCst);
500 if (storeVal.getType() != i64Ty)
501 storeVal = LLVM::TruncOp::create(globalBuilder, i64Ty, storeVal);
502 LLVM::StoreOp::create(globalBuilder, storeVal, dataStorePtr);
505 LLVM::StoreOp::create(globalBuilder, bufferStoreBlock->getArgument(1),
507 LLVM::ReturnOp::create(globalBuilder, Value{});
/// Create, through `globalBuilder`, a RuntimeModelOp named
/// `arcRuntimeModel_<model name>` for every registered model. Each op records
/// the model name, its number of state bytes and its trace taps. The trace
/// taps attribute is moved from the ModelOp to the RuntimeModelOp (cleared
/// on the ModelOp). While building, the builder's location is each model's
/// location; it is restored afterwards.
511LogicalResult GlobalRuntimeContext::buildRuntimeModelOps() {
512 auto savedLoc = globalBuilder.getLoc();
513 for (
auto &[_, model] : models) {
514 globalBuilder.setLoc(model->modelOp.getLoc());
515 auto symName = globalBuilder.getStringAttr(Twine(
"arcRuntimeModel_") +
516 model->modelInfo.name);
517 model->runtimeModelOp = RuntimeModelOp::create(
518 globalBuilder, symName,
519 globalBuilder.getStringAttr(model->modelInfo.name),
520 static_cast<uint64_t
>(model->modelInfo.numStateBytes),
521 model->modelOp.getTraceTapsAttr());
522 model->modelOp.setTraceTapsAttr({});
524 globalBuilder.setLoc(savedLoc);
/// Lower every queued instance of this model, then insert trace
/// instrumentation. Returns failure if `hasFailed` was set. Presumably it is
/// set whenever an instance or the instrumentation fails, and later
/// instances are still attempted; the bookkeeping lines are not visible here.
529LogicalResult RuntimeModelContext::lower() {
530 bool hasFailed =
false;
531 for (
auto &instance : instances)
532 if (failed(lowerInstance(instance)))
534 if (failed(insertTraceInstrumentation()))
536 return success(!hasFailed);
/// Instrument the tapped state writes queued for this model.
///
/// For each queued write, the trace tap attributes are removed. The current
/// state value is read and compared against the value being written. An
/// `scf.if` on "value changed" is built that clones the write and calls this
/// trace type's instrumentation function with three arguments:
///   - the tap's state pointer (state pointer plus the tap's state offset),
///   - the tap index,
///   - the written value, zero-extended to the trace type when narrower.
/// The queue is cleared afterwards.
///
/// NOTE(review): the early-exit value for models without queued taps, the
/// erasure of the original write, and any assignment to `hasFailed` are on
/// lines not visible here.
540LogicalResult RuntimeModelContext::insertTraceInstrumentation() {
541 if (!hasTraceTaps() || tappedWrites.empty())
543 bool hasFailed =
false;
544 ImplicitLocOpBuilder builder(runtimeModelOp.getLoc(),
545 runtimeModelOp.getContext());
546 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
547 for (
auto writeOp : tappedWrites) {
// Build the instrumentation right before the original write.
548 builder.setInsertionPoint(writeOp);
549 builder.setLoc(writeOp.getLoc());
551 auto tapId = *writeOp.getTraceTapIndex();
552 assert(tapId < runtimeModelOp.getTraceTapsAttr().size());
553 auto tapAttr = cast<TraceTapAttr>(runtimeModelOp.getTraceTapsAttr()[tapId]);
554 auto traceTy = GlobalRuntimeContext::getTraceExtendedType(
555 writeOp.getValue().getType());
556 auto instrumentFn = globalContext.getTraceInstrumentFn(traceTy);
// Once instrumented, the write no longer carries a trace tap.
558 writeOp.setTraceTapIndex(std::nullopt);
559 writeOp.setTraceTapModel(std::nullopt);
// Only trace writes that actually change the stored value.
561 auto oldRead = StateReadOp::create(builder, writeOp.getState());
562 auto hasChanged = LLVM::ICmpOp::create(builder, LLVM::ICmpPredicate::ne,
563 writeOp.getValue(), oldRead);
565 builder, hasChanged, [&](OpBuilder scfBuilder, Location loc) {
567 scfBuilder.clone(*writeOp.getOperation());
569 auto statePtrCast = UnrealizedConversionCastOp::create(
570 scfBuilder, loc, ptrTy, writeOp.getState());
571 auto baseStatePtr = LLVM::GEPOp::create(
572 scfBuilder, loc, ptrTy, scfBuilder.getI8Type(),
573 statePtrCast.getResult(0),
575 static_cast<int32_t>(tapAttr.getStateOffset()))});
576 auto tapIdxCst = LLVM::ConstantOp::create(
577 scfBuilder, loc, scfBuilder.getI64IntegerAttr(tapId));
578 Value storeVal = writeOp.getValue();
579 if (traceTy != storeVal.getType())
580 storeVal = LLVM::ZExtOp::create(scfBuilder, loc, traceTy, storeVal)
582 LLVM::CallOp::create(scfBuilder, loc, instrumentFn,
583 {baseStatePtr, tapIdxCst, storeVal});
// Terminate the then-region with the region builder passed to this callback,
// like every other op in it. The outer `builder` also points into this block
// here only because IfOp::create builds with that same builder; relying on
// that is fragile.
584 scf::YieldOp::create(scfBuilder, loc);
588 tappedWrites.clear();
589 return success(!hasFailed);
/// Lower one simulation instance to runtime callback calls:
///   - mark the alloc, on-initialized and delete-instance callbacks as used,
///   - create an unrealized cast of the instance's model argument to an LLVM
///     pointer at the start of the instance body,
///   - insert a call to the on-eval callback (marked used) before every
///     SimStepOp in the body.
/// NOTE(review): `runtimeInst` is presumably the pointer produced by that cast
/// (its declaration is on a line not visible here). Where alloc/delete calls
/// are emitted is also not visible.
592LogicalResult RuntimeModelContext::lowerInstance(SimInstantiateOp &instance) {
594 globalContext.allocInstanceFn.used =
true;
595 globalContext.onInitializedFn.used =
true;
596 globalContext.deleteInstanceFn.used =
true;
599 OpBuilder instBodyBuilder(instance);
600 instBodyBuilder.setInsertionPointToStart(
601 &instance.getBody().getBlocks().front());
603 UnrealizedConversionCastOp::create(
604 instBodyBuilder, instance.getLoc(),
605 LLVM::LLVMPointerType::get(instBodyBuilder.getContext()),
606 instance.getBody().getArgument(0))
609 instance.getBody().getBlocks().front().walk([&](SimStepOp stepOp) {
610 instBodyBuilder.setInsertionPoint(stepOp);
611 globalContext.onEvalFn.used =
true;
612 LLVM::CallOp::create(instBodyBuilder, stepOp.getLoc(),
613 globalContext.onEvalFn.llvmFuncOp, {runtimeInst});
/// Pass entry point. The steps are:
///   1. Register every model from ModelInfoAnalysis.
///   2. Create the RuntimeModelOps, set up trace instrumentation, and bind
///      instances.
///   3. For each model, rewrite instance runtime arguments when `extraArgs`
///      or `traceFileName` is set, then lower the model.
///   4. Erase unused runtime callback declarations and mark
///      ModelInfoAnalysis as preserved.
619void InsertRuntimePass::runOnOperation() {
622 auto &modelInfo = getAnalysis<ModelInfoAnalysis>();
623 auto globalContext = std::make_unique<GlobalRuntimeContext>(getOperation());
624 for (
auto &[mOp, mInfo] : modelInfo.infoMap)
625 globalContext->addModel(mOp, mInfo);
// Order matters: buildTraceInstrumentation() (via hasTraceTaps()) and
// collectInstances() (via addInstance()) read each model's RuntimeModelOp,
// which buildRuntimeModelOps() creates.
626 if (failed(globalContext->buildRuntimeModelOps()) ||
627 failed(globalContext->buildTraceInstrumentation()) ||
628 failed(globalContext->collectInstances())) {
634 for (
auto &[_, model] : globalContext->models) {
636 if (!extraArgs.empty() || !traceFileName.empty()) {
637 for (
auto [idx, instance] :
llvm::enumerate(model->instances)) {
638 auto newArgs = buildArgString(idx, instance.getRuntimeArgsAttr());
639 auto newArgAttr = StringAttr::get(&getContext(), newArgs);
640 instance.setRuntimeArgsAttr(newArgAttr);
644 if (failed(model->lower()))
648 globalContext->deleteUnusedFunctions();
649 markAnalysesPreserved<ModelInfoAnalysis>();
// NOTE(review): stray statement at file scope, outside any function. No
// `baseType` is declared anywhere in the visible code, and the statement has
// no terminating `;`. It looks like a fragment pasted from another file;
// confirm and remove.
assert(baseType &&"element must be base type")
/// The InstanceGraph op interface; see InstanceGraphInterface.td for more details.