CIRCT 23.0.0git
Loading...
Searching...
No Matches
InsertRuntime.cpp
Go to the documentation of this file.
1//===- InsertRuntime.cpp --------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Pass/Pass.h"
19
20#include <filesystem>
21
22#define DEBUG_TYPE "arc-insert-runtime"
23
24namespace circt {
25namespace arc {
26#define GEN_PASS_DEF_INSERTRUNTIME
27#include "circt/Dialect/Arc/ArcPasses.h.inc"
28} // namespace arc
29} // namespace circt
30
31using namespace mlir;
32using namespace circt;
33using namespace arc;
34
35namespace {
36
37// API Helpers
38struct RuntimeFunction {
39 LLVM::LLVMFuncOp llvmFuncOp = {};
40 bool used = false;
41
42protected:
43 // Add attributes for passing the model state pointer to the runtime library
44 void setModelStateArgAttrs(OpBuilder &builder, unsigned argIndex,
45 bool isMutable) {
46 llvmFuncOp.setArgAttr(0, LLVM::LLVMDialect::getNoCaptureAttrName(),
47 builder.getUnitAttr());
48 llvmFuncOp.setArgAttr(0, LLVM::LLVMDialect::getNoFreeAttrName(),
49 builder.getUnitAttr());
50 llvmFuncOp.setArgAttr(0, LLVM::LLVMDialect::getNoAliasAttrName(),
51 builder.getUnitAttr());
52 if (!isMutable)
53 llvmFuncOp.setArgAttr(0, LLVM::LLVMDialect::getReadonlyAttrName(),
54 builder.getUnitAttr());
55 }
56};
57
58struct AllocInstanceFunction : public RuntimeFunction {
59 explicit AllocInstanceFunction(ImplicitLocOpBuilder &builder) {
60 /*
61 uint8_t *
62 arcRuntimeIR_allocInstance(const ArcRuntimeModelInfo *model, const char
63 *args);
64 */
65 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
66 llvmFuncOp = LLVM::LLVMFuncOp::create(
67 builder, runtime::APICallbacks::symNameAllocInstance,
68 LLVM::LLVMFunctionType::get(ptrTy, {ptrTy, ptrTy}));
69 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNoAliasAttrName(),
70 builder.getUnitAttr());
71 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNoUndefAttrName(),
72 builder.getUnitAttr());
73 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNonNullAttrName(),
74 builder.getUnitAttr());
75 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getAlignAttrName(),
76 builder.getI64IntegerAttr(16));
77 }
78};
79
80struct DeleteInstanceFunction : public RuntimeFunction {
81 explicit DeleteInstanceFunction(ImplicitLocOpBuilder &builder) {
82 /*
83 void arcRuntimeIR_deleteInstance(uint8_t *modelState);
84 */
85 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
86 auto voidTy = LLVM::LLVMVoidType::get(builder.getContext());
87 llvmFuncOp = LLVM::LLVMFuncOp::create(
88 builder, runtime::APICallbacks::symNameDeleteInstance,
89 LLVM::LLVMFunctionType::get(voidTy, {ptrTy}));
90 }
91};
92
93struct OnEvalFunction : public RuntimeFunction {
94 explicit OnEvalFunction(ImplicitLocOpBuilder &builder) {
95 /*
96 void arcRuntimeIR_onEval(uint8_t *modelState);
97 */
98 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
99 auto voidTy = LLVM::LLVMVoidType::get(builder.getContext());
100 llvmFuncOp =
101 LLVM::LLVMFuncOp::create(builder, runtime::APICallbacks::symNameOnEval,
102 LLVM::LLVMFunctionType::get(voidTy, {ptrTy}));
103 setModelStateArgAttrs(builder, 0, true);
104 }
105};
106
107struct OnInitializedFunction : public RuntimeFunction {
108 explicit OnInitializedFunction(ImplicitLocOpBuilder &builder) {
109 /*
110 void arcRuntimeIR_onInitialized(uint8_t *modelState);
111 */
112 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
113 auto voidTy = LLVM::LLVMVoidType::get(builder.getContext());
114 llvmFuncOp = LLVM::LLVMFuncOp::create(
115 builder, runtime::APICallbacks::symNameOnInitialized,
116 LLVM::LLVMFunctionType::get(voidTy, {ptrTy}));
117 setModelStateArgAttrs(builder, 0, true);
118 }
119};
120
121struct SwapTraceBufferFunction : public RuntimeFunction {
122 explicit SwapTraceBufferFunction(ImplicitLocOpBuilder &builder) {
123 /*
124 uint64_t *arcRuntimeIR_swapTraceBuffer(const uint8_t *modelState);
125 */
126 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
127 llvmFuncOp = LLVM::LLVMFuncOp::create(
128 builder, runtime::APICallbacks::symNameSwapTraceBuffer,
129 LLVM::LLVMFunctionType::get(ptrTy, {ptrTy}));
130 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNoAliasAttrName(),
131 builder.getUnitAttr());
132 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNoUndefAttrName(),
133 builder.getUnitAttr());
134 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getNonNullAttrName(),
135 builder.getUnitAttr());
136 llvmFuncOp.setResultAttr(0, LLVM::LLVMDialect::getAlignAttrName(),
137 builder.getI64IntegerAttr(8));
138 setModelStateArgAttrs(builder, 0, false);
139 }
140};
141
142// Lowering Helpers
143
144struct RuntimeModelContext; // Forward declaration
145
146struct GlobalRuntimeContext {
147 GlobalRuntimeContext() = delete;
148
149 /// Constructs a global context and adds the available runtime API function
150 /// declarations to the MLIR module
151 explicit GlobalRuntimeContext(ModuleOp moduleOp)
152 : mlirModuleOp(moduleOp), globalBuilder(createBuilder(moduleOp)),
153 allocInstanceFn(globalBuilder), deleteInstanceFn(globalBuilder),
154 onEvalFn(globalBuilder), onInitializedFn(globalBuilder),
155 swapTraceBufferFn(globalBuilder) {}
156
157 /// Delete all API functions that are never called
158 void deleteUnusedFunctions() {
159 for (auto *fn : apiFunctions)
160 if (!fn->used)
161 fn->llvmFuncOp->erase();
162 }
163
164 /// Map a type to its corresponding data type in the trace buffer
165 static Type getTraceExtendedType(Type stateType) {
166 auto numBits = stateType.getIntOrFloatBitWidth();
167 auto numQWords = std::max((numBits + 63) / 64, 1U);
168 return IntegerType::get(stateType.getContext(), numQWords * 64);
169 }
170
171 /// Add an Arc model to the global runtime context
172 void addModel(ModelOp &modelOp, const ModelInfo &modelInfo);
173 /// Build a RuntimeModelOp for each registered model
174 LogicalResult buildRuntimeModelOps();
175 /// Find and assign instances of the registered models within the root module
176 LogicalResult collectInstances();
177 /// Collect tapped StateWriteOps, assign them to their model, and build the
178 /// trace instrumentation functions for the required types
179 LogicalResult buildTraceInstrumentation();
180
181 /// Lookup the trace instrumentation function for the given (extended) type
182 LLVM::LLVMFuncOp getTraceInstrumentFn(Type ty) const {
183 assert(ty.getIntOrFloatBitWidth() % 64 == 0);
184 auto fn = traceInstrumentationFns.find(ty);
185 assert(fn != traceInstrumentationFns.end());
186 return fn->second;
187 }
188
189 /// The root module
190 ModuleOp mlirModuleOp;
191 /// Builder for global operations
192 ImplicitLocOpBuilder globalBuilder;
193
194 // API Functions
195 AllocInstanceFunction allocInstanceFn;
196 DeleteInstanceFunction deleteInstanceFn;
197 OnEvalFunction onEvalFn;
198 OnInitializedFunction onInitializedFn;
199 SwapTraceBufferFunction swapTraceBufferFn;
200 const std::array<RuntimeFunction *, 5> apiFunctions = {
201 &allocInstanceFn, &deleteInstanceFn, &onEvalFn, &onInitializedFn,
202 &swapTraceBufferFn};
203
204 // Maps model symbol name to model context
206
207private:
208 static ImplicitLocOpBuilder createBuilder(ModuleOp &moduleOp) {
209 auto builder = ImplicitLocOpBuilder(moduleOp.getLoc(), moduleOp);
210 builder.setInsertionPointToStart(moduleOp.getBody());
211 return builder;
212 }
213 void buildTraceInstrumentationFn(Type ty);
214
215 SmallDenseMap<Type, LLVM::LLVMFuncOp> traceInstrumentationFns;
216};
217
218struct RuntimeModelContext {
219 RuntimeModelContext() = delete;
220 /// Construct the local context for an Arc model within the global context
221 RuntimeModelContext(GlobalRuntimeContext &globalContext, ModelOp &modelOp,
222 const ModelInfo &modelInfo)
223 : globalContext(globalContext), modelOp(modelOp), modelInfo(modelInfo) {}
224
225 /// Register an MLIR defined instance of our model
226 void addInstance(SimInstantiateOp &instantiateOp) {
227 assert(!instantiateOp.getRuntimeModelAttr());
228 assert(!!runtimeModelOp);
229 instantiateOp.setRuntimeModelAttr(
230 FlatSymbolRefAttr::get(runtimeModelOp.getSymNameAttr()));
231 instances.push_back(instantiateOp);
232 }
233
234 void addTappedStateWrite(StateWriteOp &writeOp) {
235 assert(writeOp.getTraceTapModel().has_value() &&
236 writeOp.getTraceTapIndex().has_value());
237 assert(modelOp.getSymNameAttr() ==
238 writeOp.getTraceTapModelAttr().getAttr());
239 tappedWrites.push_back(writeOp);
240 }
241
242 bool hasTraceTaps() { return runtimeModelOp.getTraceTaps().has_value(); }
243
244 /// Insert calls to the trace instrumentation functions for tapped state
245 /// writes
246 LogicalResult insertTraceInstrumentation();
247 /// Insert runtime calls to the model and its instances
248 LogicalResult lower();
249
250 /// The global runtime context
251 GlobalRuntimeContext &globalContext;
252 /// This context's model
253 ModelOp modelOp;
254 /// Model metadata
255 const ModelInfo &modelInfo;
256 /// List of registered instances
257 SmallVector<SimInstantiateOp> instances;
258 /// The model's corresponding RuntimeModelOp
259 RuntimeModelOp runtimeModelOp;
260 // StateWrite ops referring to one of this model's trace taps
261 SmallVector<StateWriteOp> tappedWrites;
262
263private:
264 LogicalResult lowerInstance(SimInstantiateOp &instance);
265};
266struct InsertRuntimePass
267 : public arc::impl::InsertRuntimeBase<InsertRuntimePass> {
268 using InsertRuntimeBase::InsertRuntimeBase;
269
270 void runOnOperation() override;
271
272private:
273 // Construct the runtime argument string for an instance
274 SmallString<32> buildArgString(unsigned instIdx, StringAttr existingArgs) {
275 SmallString<32> str;
276 if (existingArgs)
277 str.append(existingArgs);
278 // If requested, append the trace file name
279 if (!traceFileName.empty()) {
280 if (!str.empty())
281 str += ';';
282 str += "traceFile=";
283 // Create a unique per-instance file name by adding a suffix before the
284 // the file extension
285 if (instIdx == 0) {
286 str += traceFileName;
287 } else {
288 auto extension =
289 std::filesystem::path(static_cast<std::string>(traceFileName))
290 .extension()
291 .string();
292 str += traceFileName.substr(0, traceFileName.size() - extension.size());
293 str += '_';
294 str += std::to_string(instIdx);
295 str += extension;
296 }
297 }
298 // Append extra arguments from pass option
299 if (!extraArgs.empty()) {
300 if (!str.empty())
301 str += ';';
302 str.append(extraArgs);
303 }
304 return str;
305 }
306};
307
308} // namespace
309
310void GlobalRuntimeContext::addModel(ModelOp &modelOp,
311 const ModelInfo &modelInfo) {
312 auto newModel =
313 std::make_unique<RuntimeModelContext>(*this, modelOp, modelInfo);
314 models[modelOp.getNameAttr()] = std::move(newModel);
315}
316
317// Find all instances in the MLIR Module and assign them to their
318// respective Arc Model
319LogicalResult GlobalRuntimeContext::collectInstances() {
320 bool hasFailed = false;
321 mlirModuleOp.getBody()->walk([&](Operation *op) -> WalkResult {
322 if (auto instOp = dyn_cast<SimInstantiateOp>(op)) {
323 // Don't touch instances which somehow already carry a runtime model
324 if (instOp.getRuntimeModel())
325 return WalkResult::skip();
326 auto instanceModelSym = llvm::cast<SimModelInstanceType>(
327 instOp.getBody().getArgument(0).getType())
328 .getModel()
329 .getAttr();
330 auto modelContext = models.find(instanceModelSym);
331 if (modelContext == models.end()) {
332 hasFailed = true;
333 instOp->emitOpError(" does not refer to a known Arc model.");
334 } else {
335 modelContext->second->addInstance(instOp);
336 }
337 return WalkResult::skip();
338 }
339 if (auto instOp = dyn_cast<ModelOp>(op))
340 return WalkResult::skip();
341 return WalkResult::advance();
342 });
343 return success(!hasFailed);
344}
345
346LogicalResult GlobalRuntimeContext::buildTraceInstrumentation() {
347 if (llvm::none_of(
348 models, [](auto &modelIt) { return modelIt.second->hasTraceTaps(); }))
349 return success();
350
351 swapTraceBufferFn.used = true;
352 SetVector<Type> tappedTypes;
353
354 mlirModuleOp.getBody()->walk([&](StateWriteOp writeOp) {
355 if (!writeOp.getTraceTapModel().has_value())
356 return;
357 auto modelCtxt = models.find(writeOp.getTraceTapModelAttr().getAttr());
358 assert(modelCtxt != models.end() && "Unknown referenced model");
359 modelCtxt->second->addTappedStateWrite(writeOp);
360 if (isa<IntegerType>(writeOp.getValue().getType()))
361 buildTraceInstrumentationFn(writeOp.getValue().getType());
362 else
363 writeOp->emitWarning("Tracing of non-integer type is not supported");
364 });
365
366 return success();
367}
368
369// Build a trace instrumentation function recording the change of a state
370// value to the trace buffer. Calls the runtime library if the current buffer
371// is running out of space.
372// Pseudocode of the constructed function:
373//
374//
375// void _arc_trace_instrument_i{BW}(uint8_t *modelState, uint64_t traceTapId,
376// uint{BW}_t newValue) {
377// // BB: "capcaityCheckBlock"
378// const uint32_t reqSize = {BW} / 64 + 1;
379// ArcState *runtimeState = (ArcState*)(modelState - sizeof(ArcState));
380// uint64_t *oldBuffer = runtimeState->traceBuffer;
381// const uint32_t oldSize = runtimeState->traceBufferSize;
382// uint32_t newSize = oldSize + reqSize;
383// uint64_t *storePtr = &oldBuffer[oldSize];
384// if (newSize >= runtime::defaultTraceBufferCapacity) [[unlikely]] {
385// // BB: "swapBufferBlock"
386// storePtr = arcRuntimeIR_swapTraceBuffer(modelState);
387// runtimeState->traceBuffer = storePtr;
388// newSize = reqSize;
389// }
390// // BB: "bufferStoreBlock"
391// storePtr[0] = traceTapId;
392// for (unsigned qword = 0; qword < {BW} / 64; ++qword) // Unrolled
393// storePtr[qword + 1] = (uint64_t)(newValue >> (64 * qword));
394// runtimeState->traceBufferSize = newSize;
395// }
396//
397
398void GlobalRuntimeContext::buildTraceInstrumentationFn(Type ty) {
399 assert(isa<IntegerType>(ty));
400 // Check if we've already built the function
401 auto traceTy = getTraceExtendedType(ty);
402 if (traceInstrumentationFns.contains(traceTy))
403 return;
404
405 // Build the function signature
406 auto typeQWords = traceTy.getIntOrFloatBitWidth() / 64;
407 assert(traceTy.getIntOrFloatBitWidth() % 64 == 0);
408 auto *ctx = ty.getContext();
409 auto i64Ty = IntegerType::get(ctx, 64);
410 auto i32Ty = IntegerType::get(ctx, 32);
411 auto llvmPtrTy = LLVM::LLVMPointerType::get(ctx);
412 auto llvmFnTy = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx),
413 {llvmPtrTy, i64Ty, traceTy});
414 auto symName = StringAttr::get(
415 ctx, "_arc_trace_instrument_i" + Twine(traceTy.getIntOrFloatBitWidth()));
416 auto funcOp = LLVM::LLVMFuncOp::create(globalBuilder, symName, llvmFnTy,
417 LLVM::Linkage::Private);
418 funcOp.setNoInline(true);
419 traceInstrumentationFns.insert({traceTy, funcOp});
420
421 // Build the body of the function
422 OpBuilder::InsertionGuard g(globalBuilder);
423 auto *capcaityCheckBlock = funcOp.addEntryBlock(globalBuilder);
424 auto *swapBufferBlock = &funcOp.getRegion().emplaceBlock();
425 auto *bufferStoreBlock = &funcOp.getRegion().emplaceBlock();
426 // storePtr
427 bufferStoreBlock->addArgument(llvmPtrTy, globalBuilder.getLoc());
428 // newSize
429 bufferStoreBlock->addArgument(i32Ty, globalBuilder.getLoc());
430
431 // --- capcaityCheckBlock ---
432 globalBuilder.setInsertionPointToStart(capcaityCheckBlock);
433 auto modelStatePtr = capcaityCheckBlock->getArgument(0);
434 auto bufferPtrPtr = LLVM::GEPOp::create(
435 globalBuilder, llvmPtrTy, globalBuilder.getI8Type(), modelStatePtr,
436 {LLVM::GEPArg(static_cast<int>(offsetof(ArcState, traceBuffer)) -
437 static_cast<int>(sizeof(ArcState)))});
438 auto bufferSizePtr = LLVM::GEPOp::create(
439 globalBuilder, llvmPtrTy, globalBuilder.getI8Type(), modelStatePtr,
440 {LLVM::GEPArg(static_cast<int>(offsetof(ArcState, traceBufferSize)) -
441 static_cast<int>(sizeof(ArcState)))});
442 // > const uint32_t reqSize = {BW} / 64 + 1;
443 auto requiredSize = typeQWords + 1;
444 auto reqSizeCst = LLVM::ConstantOp::create(
445 globalBuilder, globalBuilder.getI32IntegerAttr(requiredSize));
446 // > uint64_t *oldBuffer = runtimeState->traceBuffer;
447 auto bufferPtrVal =
448 LLVM::LoadOp::create(globalBuilder, llvmPtrTy, bufferPtrPtr);
449 // > const uint32_t oldSize = runtimeState->traceBufferSize;
450 auto bufferSizeVal =
451 LLVM::LoadOp::create(globalBuilder, i32Ty, bufferSizePtr);
452 auto capacityConstant = LLVM::ConstantOp::create(
453 globalBuilder,
454 globalBuilder.getI32IntegerAttr(runtime::defaultTraceBufferCapacity));
455 // > uint32_t newSize = oldSize + reqSize;
456 auto newSizeVal =
457 LLVM::AddOp::create(globalBuilder, bufferSizeVal, reqSizeCst);
458 // > uint64_t *storePtr = &oldBuffer[oldSize];
459 auto storePtr =
460 LLVM::GEPOp::create(globalBuilder, llvmPtrTy, i64Ty, bufferPtrVal,
461 {LLVM::GEPArg(bufferSizeVal)});
462 // > if (newSize >= runtime::defaultTraceBufferCapacity) [[unlikely]]
463 auto needsSwap = LLVM::ICmpOp::create(globalBuilder, LLVM::ICmpPredicate::ugt,
464 newSizeVal, capacityConstant);
465 LLVM::CondBrOp::create(
466 globalBuilder, needsSwap, swapBufferBlock, {}, bufferStoreBlock,
467 {storePtr, newSizeVal},
468 /*weights*/
469 std::pair<int32_t, int32_t>(0, std::numeric_limits<int32_t>::max()));
470
471 // --- swapBufferBlock ---
472 globalBuilder.setInsertionPointToStart(swapBufferBlock);
473 // > storePtr = arcRuntimeIR_swapTraceBuffer(modelState);
474 auto swapCall = LLVM::CallOp::create(
475 globalBuilder, swapTraceBufferFn.llvmFuncOp, {modelStatePtr});
476 // > runtimeState->traceBuffer = storePtr;
477 LLVM::StoreOp::create(globalBuilder, swapCall.getResult(), bufferPtrPtr);
478 LLVM::BrOp::create(globalBuilder, {swapCall.getResult(), reqSizeCst},
479 bufferStoreBlock);
480
481 // --- bufferStoreBlock ---
482 globalBuilder.setInsertionPointToStart(bufferStoreBlock);
483 // > storePtr[0] = traceTapId;
484 LLVM::StoreOp::create(globalBuilder, capcaityCheckBlock->getArgument(1),
485 bufferStoreBlock->getArgument(0));
486
487 // > for (unsigned qword = 0; qword < {BW} / 64; ++qword) // Unrolled
488 for (unsigned qWord = 0; qWord < typeQWords; ++qWord) {
489 // > storePtr[qword + 1] = (uint64_t)(newValue >> (64 * qword));
490 auto dataStorePtr = LLVM::GEPOp::create(globalBuilder, llvmPtrTy, i64Ty,
491 bufferStoreBlock->getArgument(0),
492 {LLVM::GEPArg(qWord + 1)});
493 Value storeVal = capcaityCheckBlock->getArgument(2);
494 if (qWord > 0) {
495 auto shiftCst = LLVM::ConstantOp::create(
496 globalBuilder,
497 globalBuilder.getIntegerAttr(storeVal.getType(), qWord * 64));
498 storeVal = LLVM::LShrOp::create(globalBuilder, storeVal, shiftCst);
499 }
500 if (storeVal.getType() != i64Ty)
501 storeVal = LLVM::TruncOp::create(globalBuilder, i64Ty, storeVal);
502 LLVM::StoreOp::create(globalBuilder, storeVal, dataStorePtr);
503 }
504 // > runtimeState->traceBufferSize = newSize;
505 LLVM::StoreOp::create(globalBuilder, bufferStoreBlock->getArgument(1),
506 bufferSizePtr);
507 LLVM::ReturnOp::create(globalBuilder, Value{});
508}
509
510// Build the global RuntimeModelOp for each model
511LogicalResult GlobalRuntimeContext::buildRuntimeModelOps() {
512 auto savedLoc = globalBuilder.getLoc();
513 for (auto &[_, model] : models) {
514 globalBuilder.setLoc(model->modelOp.getLoc());
515 auto symName = globalBuilder.getStringAttr(Twine("arcRuntimeModel_") +
516 model->modelInfo.name);
517 model->runtimeModelOp = RuntimeModelOp::create(
518 globalBuilder, symName,
519 globalBuilder.getStringAttr(model->modelInfo.name),
520 static_cast<uint64_t>(model->modelInfo.numStateBytes),
521 model->modelOp.getTraceTapsAttr());
522 model->modelOp.setTraceTapsAttr({});
523 }
524 globalBuilder.setLoc(savedLoc);
525 return success();
526}
527
528// Lower the model and all of its instances
529LogicalResult RuntimeModelContext::lower() {
530 bool hasFailed = false;
531 for (auto &instance : instances)
532 if (failed(lowerInstance(instance)))
533 hasFailed = true;
534 if (failed(insertTraceInstrumentation()))
535 hasFailed = true;
536 return success(!hasFailed);
537}
538
539// Insert call to the trace instrumentation function to each tapped write
540LogicalResult RuntimeModelContext::insertTraceInstrumentation() {
541 if (!hasTraceTaps() || tappedWrites.empty())
542 return success();
543 bool hasFailed = false;
544 ImplicitLocOpBuilder builder(runtimeModelOp.getLoc(),
545 runtimeModelOp.getContext());
546 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
547 for (auto writeOp : tappedWrites) {
548 builder.setInsertionPoint(writeOp);
549 builder.setLoc(writeOp.getLoc());
550 // Lookup the instrumentation function for the state's type
551 auto tapId = *writeOp.getTraceTapIndex();
552 assert(tapId < runtimeModelOp.getTraceTapsAttr().size());
553 auto tapAttr = cast<TraceTapAttr>(runtimeModelOp.getTraceTapsAttr()[tapId]);
554 auto traceTy = GlobalRuntimeContext::getTraceExtendedType(
555 writeOp.getValue().getType());
556 auto instrumentFn = globalContext.getTraceInstrumentFn(traceTy);
557 // Strip the tap annotation
558 writeOp.setTraceTapIndex(std::nullopt);
559 writeOp.setTraceTapModel(std::nullopt);
560 // Test if the new value differs from the old value
561 auto oldRead = StateReadOp::create(builder, writeOp.getState());
562 auto hasChanged = LLVM::ICmpOp::create(builder, LLVM::ICmpPredicate::ne,
563 writeOp.getValue(), oldRead);
564 scf::IfOp::create(
565 builder, hasChanged, [&](OpBuilder scfBuilder, Location loc) {
566 // Pull the state write itself under the condition
567 scfBuilder.clone(*writeOp.getOperation());
568 // Invoke the instrumentation function
569 auto statePtrCast = UnrealizedConversionCastOp::create(
570 scfBuilder, loc, ptrTy, writeOp.getState());
571 auto baseStatePtr = LLVM::GEPOp::create(
572 scfBuilder, loc, ptrTy, scfBuilder.getI8Type(),
573 statePtrCast.getResult(0),
574 {LLVM::GEPArg(-1 *
575 static_cast<int32_t>(tapAttr.getStateOffset()))});
576 auto tapIdxCst = LLVM::ConstantOp::create(
577 scfBuilder, loc, scfBuilder.getI64IntegerAttr(tapId));
578 Value storeVal = writeOp.getValue();
579 if (traceTy != storeVal.getType())
580 storeVal = LLVM::ZExtOp::create(scfBuilder, loc, traceTy, storeVal)
581 .getResult();
582 LLVM::CallOp::create(scfBuilder, loc, instrumentFn,
583 {baseStatePtr, tapIdxCst, storeVal});
584 scf::YieldOp::create(builder, loc);
585 });
586 writeOp.erase();
587 }
588 tappedWrites.clear();
589 return success(!hasFailed);
590}
591
592LogicalResult RuntimeModelContext::lowerInstance(SimInstantiateOp &instance) {
593 // For now, these get invoked by the lowering of SimInstantiateOp
594 globalContext.allocInstanceFn.used = true;
595 globalContext.onInitializedFn.used = true;
596 globalContext.deleteInstanceFn.used = true;
597
598 // Insert onEval call for every step call
599 OpBuilder instBodyBuilder(instance);
600 instBodyBuilder.setInsertionPointToStart(
601 &instance.getBody().getBlocks().front());
602 auto runtimeInst =
603 UnrealizedConversionCastOp::create(
604 instBodyBuilder, instance.getLoc(),
605 LLVM::LLVMPointerType::get(instBodyBuilder.getContext()),
606 instance.getBody().getArgument(0))
607 .getResult(0);
608
609 instance.getBody().getBlocks().front().walk([&](SimStepOp stepOp) {
610 instBodyBuilder.setInsertionPoint(stepOp);
611 globalContext.onEvalFn.used = true;
612 LLVM::CallOp::create(instBodyBuilder, stepOp.getLoc(),
613 globalContext.onEvalFn.llvmFuncOp, {runtimeInst});
614 });
615
616 return success();
617}
618
619void InsertRuntimePass::runOnOperation() {
620 // Construct the global context and collect information on all
621 // models and instances
622 auto &modelInfo = getAnalysis<ModelInfoAnalysis>();
623 auto globalContext = std::make_unique<GlobalRuntimeContext>(getOperation());
624 for (auto &[mOp, mInfo] : modelInfo.infoMap)
625 globalContext->addModel(mOp, mInfo);
626 if (failed(globalContext->buildRuntimeModelOps()) ||
627 failed(globalContext->buildTraceInstrumentation()) ||
628 failed(globalContext->collectInstances())) {
629 signalPassFailure();
630 return;
631 }
632
633 // Lower all models
634 for (auto &[_, model] : globalContext->models) {
635 // If provided, append extra instance arguments
636 if (!extraArgs.empty() || !traceFileName.empty()) {
637 for (auto [idx, instance] : llvm::enumerate(model->instances)) {
638 auto newArgs = buildArgString(idx, instance.getRuntimeArgsAttr());
639 auto newArgAttr = StringAttr::get(&getContext(), newArgs);
640 instance.setRuntimeArgsAttr(newArgAttr);
641 }
642 }
643
644 if (failed(model->lower()))
645 signalPassFailure();
646 }
647
648 globalContext->deleteUnusedFunctions();
649 markAnalysesPreserved<ModelInfoAnalysis>();
650}
assert(baseType &&"element must be base type")
Definition arc.py:1
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.