CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerArcToLLVM.cpp
Go to the documentation of this file.
1//===- LowerArcToLLVM.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
27#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
28#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
29#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
30#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
31#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
32#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
33#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
34#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
35#include "mlir/Dialect/Func/IR/FuncOps.h"
36#include "mlir/Dialect/Index/IR/IndexOps.h"
37#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
38#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
39#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
40#include "mlir/Dialect/SCF/IR/SCF.h"
41#include "mlir/IR/Builders.h"
42#include "mlir/IR/BuiltinDialect.h"
43#include "mlir/Pass/Pass.h"
44#include "mlir/Transforms/DialectConversion.h"
45#include "llvm/Support/Debug.h"
46#include "llvm/Support/FormatVariadic.h"
47
48#include <cstddef>
49
50#define DEBUG_TYPE "lower-arc-to-llvm"
51
52namespace circt {
53#define GEN_PASS_DEF_LOWERARCTOLLVM
54#include "circt/Conversion/Passes.h.inc"
55} // namespace circt
56
57using namespace mlir;
58using namespace circt;
59using namespace arc;
60using namespace hw;
61using namespace runtime;
62
63//===----------------------------------------------------------------------===//
64// Lowering Patterns
65//===----------------------------------------------------------------------===//
66
67static llvm::Twine evalSymbolFromModelName(StringRef modelName) {
68 return modelName + "_eval";
69}
70
71namespace {
72
73struct ModelOpLowering : public OpConversionPattern<arc::ModelOp> {
74 using OpConversionPattern::OpConversionPattern;
75 LogicalResult
76 matchAndRewrite(arc::ModelOp op, OpAdaptor adaptor,
77 ConversionPatternRewriter &rewriter) const final {
78 {
79 IRRewriter::InsertionGuard guard(rewriter);
80 rewriter.setInsertionPointToEnd(&op.getBodyBlock());
81 func::ReturnOp::create(rewriter, op.getLoc());
82 }
83 auto funcName =
84 rewriter.getStringAttr(evalSymbolFromModelName(op.getName()));
85 auto funcType =
86 rewriter.getFunctionType(op.getBody().getArgumentTypes(), {});
87 auto func =
88 mlir::func::FuncOp::create(rewriter, op.getLoc(), funcName, funcType);
89 rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
90 rewriter.eraseOp(op);
91 return success();
92 }
93};
94
95struct AllocStorageOpLowering
96 : public OpConversionPattern<arc::AllocStorageOp> {
97 using OpConversionPattern::OpConversionPattern;
98 LogicalResult
99 matchAndRewrite(arc::AllocStorageOp op, OpAdaptor adaptor,
100 ConversionPatternRewriter &rewriter) const final {
101 auto type = typeConverter->convertType(op.getType());
102 if (!op.getOffset().has_value())
103 return failure();
104 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, type, rewriter.getI8Type(),
105 adaptor.getInput(),
106 LLVM::GEPArg(*op.getOffset()));
107 return success();
108 }
109};
110
111template <class ConcreteOp>
112struct AllocStateLikeOpLowering : public OpConversionPattern<ConcreteOp> {
114 using OpConversionPattern<ConcreteOp>::typeConverter;
115 using OpAdaptor = typename ConcreteOp::Adaptor;
116
117 LogicalResult
118 matchAndRewrite(ConcreteOp op, OpAdaptor adaptor,
119 ConversionPatternRewriter &rewriter) const final {
120 // Get a pointer to the correct offset in the storage.
121 auto offsetAttr = op->template getAttrOfType<IntegerAttr>("offset");
122 if (!offsetAttr)
123 return failure();
124 Value ptr = LLVM::GEPOp::create(
125 rewriter, op->getLoc(), adaptor.getStorage().getType(),
126 rewriter.getI8Type(), adaptor.getStorage(),
127 LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
128 rewriter.replaceOp(op, ptr);
129 return success();
130 }
131};
132
133struct StateReadOpLowering : public OpConversionPattern<arc::StateReadOp> {
134 using OpConversionPattern::OpConversionPattern;
135 LogicalResult
136 matchAndRewrite(arc::StateReadOp op, OpAdaptor adaptor,
137 ConversionPatternRewriter &rewriter) const final {
138 // Loading an ArrayRef is a no-op as ArrayRefs are accessed by reference.
139 if (isa<ArrayRefType>(op.getType())) {
140 rewriter.replaceOp(op, adaptor.getState());
141 return success();
142 }
143
144 auto type = typeConverter->convertType(op.getType());
145 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, type, adaptor.getState());
146 return success();
147 }
148};
149
150struct StateWriteOpLowering : public OpConversionPattern<arc::StateWriteOp> {
151 using OpConversionPattern::OpConversionPattern;
152 LogicalResult
153 matchAndRewrite(arc::StateWriteOp op, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter) const final {
155 if (!isa<ArrayRefType>(op.getValue().getType())) {
156 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
157 adaptor.getState());
158 return success();
159 }
160
161 int numBytes = op.getState().getType().getByteWidth();
162 Value size = LLVM::ConstantOp::create(rewriter, op.getLoc(),
163 rewriter.getI64Type(), numBytes);
164 rewriter.replaceOpWithNewOp<LLVM::MemcpyOp>(
165 op, adaptor.getState(), adaptor.getValue(), size, /*volatile=*/false);
166 return success();
167 }
168};
169
170//===----------------------------------------------------------------------===//
171// Time Operations Lowering
172//===----------------------------------------------------------------------===//
173
174struct CurrentTimeOpLowering : public OpConversionPattern<arc::CurrentTimeOp> {
175 using OpConversionPattern::OpConversionPattern;
176 LogicalResult
177 matchAndRewrite(arc::CurrentTimeOp op, OpAdaptor adaptor,
178 ConversionPatternRewriter &rewriter) const final {
179 // Time is stored at offset 0 in storage (no offset needed).
180 Value ptr = adaptor.getStorage();
181 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, rewriter.getI64Type(), ptr);
182 return success();
183 }
184};
185
186// Lower `llhd.constant_time` to an `i64` LLVM constant holding the time in
187// femtoseconds. Time attributes with non-zero delta or epsilon, units smaller
188// than `fs`, or values that overflow `i64` femtoseconds are rejected.
189struct ConstantTimeOpLowering
190 : public OpConversionPattern<llhd::ConstantTimeOp> {
191 using OpConversionPattern::OpConversionPattern;
192 LogicalResult
193 matchAndRewrite(llhd::ConstantTimeOp op, OpAdaptor adaptor,
194 ConversionPatternRewriter &rewriter) const final {
195 auto attr = op.getValue();
196 if (attr.getDelta() != 0 || attr.getEpsilon() != 0)
197 return rewriter.notifyMatchFailure(
198 op, "non-zero delta or epsilon time components are not supported");
199 uint64_t value = attr.getTime();
200 StringRef unit = attr.getTimeUnit();
201 uint64_t scale;
202 if (unit == "fs")
203 scale = 1;
204 else if (unit == "ps")
205 scale = 1'000ULL;
206 else if (unit == "ns")
207 scale = 1'000'000ULL;
208 else if (unit == "us")
209 scale = 1'000'000'000ULL;
210 else if (unit == "ms")
211 scale = 1'000'000'000'000ULL;
212 else if (unit == "s")
213 scale = 1'000'000'000'000'000ULL;
214 else
215 return rewriter.notifyMatchFailure(
216 op, "time units smaller than `fs` are not supported");
217 if (value > std::numeric_limits<uint64_t>::max() / scale)
218 return rewriter.notifyMatchFailure(
219 op, "time value does not fit into `i64` femtoseconds");
220 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, rewriter.getI64Type(),
221 value * scale);
222 return success();
223 }
224};
225
226// `llhd.int_to_time` is a no-op
227struct IntToTimeOpLowering : public OpConversionPattern<llhd::IntToTimeOp> {
228 using OpConversionPattern::OpConversionPattern;
229 LogicalResult
230 matchAndRewrite(llhd::IntToTimeOp op, OpAdaptor adaptor,
231 ConversionPatternRewriter &rewriter) const final {
232 rewriter.replaceOp(op, adaptor.getInput());
233 return success();
234 }
235};
236
237// `llhd.time_to_int` is a no-op
238struct TimeToIntOpLowering : public OpConversionPattern<llhd::TimeToIntOp> {
239 using OpConversionPattern::OpConversionPattern;
240 LogicalResult
241 matchAndRewrite(llhd::TimeToIntOp op, OpAdaptor adaptor,
242 ConversionPatternRewriter &rewriter) const final {
243 rewriter.replaceOp(op, adaptor.getInput());
244 return success();
245 }
246};
247
248//===----------------------------------------------------------------------===//
249// Memory and Storage Lowering
250//===----------------------------------------------------------------------===//
251
252struct AllocMemoryOpLowering : public OpConversionPattern<arc::AllocMemoryOp> {
253 using OpConversionPattern::OpConversionPattern;
254 LogicalResult
255 matchAndRewrite(arc::AllocMemoryOp op, OpAdaptor adaptor,
256 ConversionPatternRewriter &rewriter) const final {
257 auto offsetAttr = op->getAttrOfType<IntegerAttr>("offset");
258 if (!offsetAttr)
259 return failure();
260 Value ptr = LLVM::GEPOp::create(
261 rewriter, op.getLoc(), adaptor.getStorage().getType(),
262 rewriter.getI8Type(), adaptor.getStorage(),
263 LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
264
265 rewriter.replaceOp(op, ptr);
266 return success();
267 }
268};
269
270struct StorageGetOpLowering : public OpConversionPattern<arc::StorageGetOp> {
271 using OpConversionPattern::OpConversionPattern;
272 LogicalResult
273 matchAndRewrite(arc::StorageGetOp op, OpAdaptor adaptor,
274 ConversionPatternRewriter &rewriter) const final {
275 Value offset = LLVM::ConstantOp::create(
276 rewriter, op.getLoc(), rewriter.getI32Type(), op.getOffsetAttr());
277 Value ptr = LLVM::GEPOp::create(
278 rewriter, op.getLoc(), adaptor.getStorage().getType(),
279 rewriter.getI8Type(), adaptor.getStorage(), offset);
280 rewriter.replaceOp(op, ptr);
281 return success();
282 }
283};
284
285struct MemoryAccess {
286 Value ptr;
287 Value withinBounds;
288};
289
290static MemoryAccess prepareMemoryAccess(Location loc, Value memory,
291 Value address, MemoryType type,
292 ConversionPatternRewriter &rewriter) {
293 auto zextAddrType = rewriter.getIntegerType(
294 cast<IntegerType>(address.getType()).getWidth() + 1);
295 Value addr = LLVM::ZExtOp::create(rewriter, loc, zextAddrType, address);
296 Value addrLimit =
297 LLVM::ConstantOp::create(rewriter, loc, zextAddrType,
298 rewriter.getI32IntegerAttr(type.getNumWords()));
299 Value withinBounds = LLVM::ICmpOp::create(
300 rewriter, loc, LLVM::ICmpPredicate::ult, addr, addrLimit);
301 Value ptr = LLVM::GEPOp::create(
302 rewriter, loc, LLVM::LLVMPointerType::get(memory.getContext()),
303 rewriter.getIntegerType(type.getStride() * 8), memory, ValueRange{addr});
304 return {ptr, withinBounds};
305}
306
307struct MemoryReadOpLowering : public OpConversionPattern<arc::MemoryReadOp> {
308 using OpConversionPattern::OpConversionPattern;
309 LogicalResult
310 matchAndRewrite(arc::MemoryReadOp op, OpAdaptor adaptor,
311 ConversionPatternRewriter &rewriter) const final {
312 auto type = typeConverter->convertType(op.getType());
313 auto memoryType = cast<MemoryType>(op.getMemory().getType());
314 auto access =
315 prepareMemoryAccess(op.getLoc(), adaptor.getMemory(),
316 adaptor.getAddress(), memoryType, rewriter);
317
318 // Only attempt to read the memory if the address is within bounds,
319 // otherwise produce a zero value.
320 rewriter.replaceOpWithNewOp<scf::IfOp>(
321 op, access.withinBounds,
322 [&](auto &builder, auto loc) {
323 Value loadOp = LLVM::LoadOp::create(
324 builder, loc, memoryType.getWordType(), access.ptr);
325 scf::YieldOp::create(builder, loc, loadOp);
326 },
327 [&](auto &builder, auto loc) {
328 Value zeroValue = LLVM::ConstantOp::create(
329 builder, loc, type, builder.getI64IntegerAttr(0));
330 scf::YieldOp::create(builder, loc, zeroValue);
331 });
332 return success();
333 }
334};
335
336struct MemoryWriteOpLowering : public OpConversionPattern<arc::MemoryWriteOp> {
337 using OpConversionPattern::OpConversionPattern;
338 LogicalResult
339 matchAndRewrite(arc::MemoryWriteOp op, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter) const final {
341 auto access = prepareMemoryAccess(
342 op.getLoc(), adaptor.getMemory(), adaptor.getAddress(),
343 cast<MemoryType>(op.getMemory().getType()), rewriter);
344 auto enable = access.withinBounds;
345
346 // Only attempt to write the memory if the address is within bounds.
347 rewriter.replaceOpWithNewOp<scf::IfOp>(
348 op, enable, [&](auto &builder, auto loc) {
349 LLVM::StoreOp::create(builder, loc, adaptor.getData(), access.ptr);
350 scf::YieldOp::create(builder, loc);
351 });
352 return success();
353 }
354};
355
356/// A dummy lowering for clock gates to an AND gate.
357struct ClockGateOpLowering : public OpConversionPattern<seq::ClockGateOp> {
358 using OpConversionPattern::OpConversionPattern;
359 LogicalResult
360 matchAndRewrite(seq::ClockGateOp op, OpAdaptor adaptor,
361 ConversionPatternRewriter &rewriter) const final {
362 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, adaptor.getInput(),
363 adaptor.getEnable());
364 return success();
365 }
366};
367
368/// Lower 'seq.clock_inv x' to 'llvm.xor x true'
369struct ClockInvOpLowering : public OpConversionPattern<seq::ClockInverterOp> {
370 using OpConversionPattern::OpConversionPattern;
371 LogicalResult
372 matchAndRewrite(seq::ClockInverterOp op, OpAdaptor adaptor,
373 ConversionPatternRewriter &rewriter) const final {
374 auto constTrue = LLVM::ConstantOp::create(rewriter, op->getLoc(),
375 rewriter.getI1Type(), 1);
376 rewriter.replaceOpWithNewOp<LLVM::XOrOp>(op, adaptor.getInput(), constTrue);
377 return success();
378 }
379};
380
381struct ZeroCountOpLowering : public OpConversionPattern<arc::ZeroCountOp> {
382 using OpConversionPattern::OpConversionPattern;
383 LogicalResult
384 matchAndRewrite(arc::ZeroCountOp op, OpAdaptor adaptor,
385 ConversionPatternRewriter &rewriter) const override {
386 // Use poison when input is zero.
387 IntegerAttr isZeroPoison = rewriter.getBoolAttr(true);
388
389 if (op.getPredicate() == arc::ZeroCountPredicate::leading) {
390 rewriter.replaceOpWithNewOp<LLVM::CountLeadingZerosOp>(
391 op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
392 return success();
393 }
394
395 rewriter.replaceOpWithNewOp<LLVM::CountTrailingZerosOp>(
396 op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
397 return success();
398 }
399};
400
401struct SeqConstClockLowering : public OpConversionPattern<seq::ConstClockOp> {
402 using OpConversionPattern::OpConversionPattern;
403 LogicalResult
404 matchAndRewrite(seq::ConstClockOp op, OpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter) const override {
406 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
407 op, rewriter.getI1Type(), static_cast<int64_t>(op.getValue()));
408 return success();
409 }
410};
411
412template <typename OpTy>
413struct ReplaceOpWithInputPattern : public OpConversionPattern<OpTy> {
415 using OpAdaptor = typename OpTy::Adaptor;
416 LogicalResult
417 matchAndRewrite(OpTy op, OpAdaptor adaptor,
418 ConversionPatternRewriter &rewriter) const override {
419 rewriter.replaceOp(op, adaptor.getInput());
420 return success();
421 }
422};
423
424} // namespace
425
426//===----------------------------------------------------------------------===//
427// Simulation Orchestration Lowering Patterns
428//===----------------------------------------------------------------------===//
429
430namespace {
431
432struct ModelInfoMap {
433 size_t numStateBytes;
434 llvm::DenseMap<StringRef, StateInfo> states;
435 mlir::FlatSymbolRefAttr initialFnSymbol;
436 mlir::FlatSymbolRefAttr finalFnSymbol;
437};
438
439template <typename OpTy>
440struct ModelAwarePattern : public OpConversionPattern<OpTy> {
441 ModelAwarePattern(const TypeConverter &typeConverter, MLIRContext *context,
442 llvm::DenseMap<StringRef, ModelInfoMap> &modelInfo)
443 : OpConversionPattern<OpTy>(typeConverter, context),
444 modelInfo(modelInfo) {}
445
446protected:
447 Value createPtrToPortState(ConversionPatternRewriter &rewriter, Location loc,
448 Value state, const StateInfo &port) const {
449 MLIRContext *ctx = rewriter.getContext();
450 return LLVM::GEPOp::create(rewriter, loc, LLVM::LLVMPointerType::get(ctx),
451 IntegerType::get(ctx, 8), state,
452 LLVM::GEPArg(port.offset));
453 }
454
455 llvm::DenseMap<StringRef, ModelInfoMap> &modelInfo;
456};
457
458/// Lowers SimInstantiateOp to a malloc and memset call. This pattern will
459/// mutate the global module.
460struct SimInstantiateOpLowering
461 : public ModelAwarePattern<arc::SimInstantiateOp> {
462 using ModelAwarePattern::ModelAwarePattern;
463
464 LogicalResult
465 matchAndRewrite(arc::SimInstantiateOp op, OpAdaptor adaptor,
466 ConversionPatternRewriter &rewriter) const final {
467 auto modelIt = modelInfo.find(
468 cast<SimModelInstanceType>(op.getBody().getArgument(0).getType())
469 .getModel()
470 .getValue());
471 ModelInfoMap &model = modelIt->second;
472
473 bool useRuntime = op.getRuntimeModel().has_value();
474
475 ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
476 if (!moduleOp)
477 return failure();
478
479 ConversionPatternRewriter::InsertionGuard guard(rewriter);
480
481 // FIXME: like the rest of MLIR, this assumes sizeof(intptr_t) ==
482 // sizeof(size_t) on the target architecture.
483 Type convertedIndex = typeConverter->convertType(rewriter.getIndexType());
484 Location loc = op.getLoc();
485 Value allocated;
486
487 if (useRuntime) {
488 // The instance is using the runtime library
489 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
490
491 Value runtimeArgs;
492 // If present, materialize the runtime argument string on the stack
493 if (op.getRuntimeArgs().has_value()) {
494 SmallVector<int8_t> argStringVec(op.getRuntimeArgsAttr().begin(),
495 op.getRuntimeArgsAttr().end());
496 argStringVec.push_back('\0');
497 auto strAttr = mlir::DenseElementsAttr::get(
498 mlir::RankedTensorType::get({(int64_t)argStringVec.size()},
499 rewriter.getI8Type()),
500 llvm::ArrayRef(argStringVec));
501
502 auto arrayCst = LLVM::ConstantOp::create(
503 rewriter, loc,
504 LLVM::LLVMArrayType::get(rewriter.getI8Type(), argStringVec.size()),
505 strAttr);
506 auto cst1 = LLVM::ConstantOp::create(rewriter, loc,
507 rewriter.getI32IntegerAttr(1));
508 runtimeArgs = LLVM::AllocaOp::create(rewriter, loc, ptrTy,
509 arrayCst.getType(), cst1);
510 LLVM::LifetimeStartOp::create(rewriter, loc, runtimeArgs);
511 LLVM::StoreOp::create(rewriter, loc, arrayCst, runtimeArgs);
512 } else {
513 runtimeArgs = LLVM::ZeroOp::create(rewriter, loc, ptrTy).getResult();
514 }
515 // Call the state allocation function
516 auto rtModelPtr = LLVM::AddressOfOp::create(rewriter, loc, ptrTy,
517 op.getRuntimeModelAttr())
518 .getResult();
519 allocated =
520 LLVM::CallOp::create(rewriter, loc, {ptrTy},
521 runtime::APICallbacks::symNameAllocInstance,
522 {rtModelPtr, runtimeArgs})
523 .getResult();
524
525 if (op.getRuntimeArgs().has_value())
526 LLVM::LifetimeEndOp::create(rewriter, loc, runtimeArgs);
527
528 } else {
529 // The instance is not using the runtime library
530 FailureOr<LLVM::LLVMFuncOp> mallocFunc =
531 LLVM::lookupOrCreateMallocFn(rewriter, moduleOp, convertedIndex);
532 if (failed(mallocFunc))
533 return mallocFunc;
534
535 Value numStateBytes = LLVM::ConstantOp::create(
536 rewriter, loc, convertedIndex, model.numStateBytes);
537 allocated = LLVM::CallOp::create(rewriter, loc, mallocFunc.value(),
538 ValueRange{numStateBytes})
539 .getResult();
540 Value zero =
541 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI8Type(), 0);
542 LLVM::MemsetOp::create(rewriter, loc, allocated, zero, numStateBytes,
543 false);
544 }
545
546 // Call the model's 'initial' function if present.
547 if (model.initialFnSymbol) {
548 auto initialFnType = LLVM::LLVMFunctionType::get(
549 LLVM::LLVMVoidType::get(op.getContext()),
550 {LLVM::LLVMPointerType::get(op.getContext())});
551 LLVM::CallOp::create(rewriter, loc, initialFnType, model.initialFnSymbol,
552 ValueRange{allocated});
553 }
554
555 // Call the runtime's 'onInitialized' function if present.
556 if (useRuntime)
557 LLVM::CallOp::create(rewriter, loc, TypeRange{},
558 runtime::APICallbacks::symNameOnInitialized,
559 {allocated});
560
561 // Execute the body.
562 rewriter.inlineBlockBefore(&adaptor.getBody().getBlocks().front(), op,
563 {allocated});
564
565 // Call the model's 'final' function if present.
566 if (model.finalFnSymbol) {
567 auto finalFnType = LLVM::LLVMFunctionType::get(
568 LLVM::LLVMVoidType::get(op.getContext()),
569 {LLVM::LLVMPointerType::get(op.getContext())});
570 LLVM::CallOp::create(rewriter, loc, finalFnType, model.finalFnSymbol,
571 ValueRange{allocated});
572 }
573
574 if (useRuntime) {
575 LLVM::CallOp::create(rewriter, loc, TypeRange{},
576 runtime::APICallbacks::symNameDeleteInstance,
577 {allocated});
578 } else {
579 FailureOr<LLVM::LLVMFuncOp> freeFunc =
580 LLVM::lookupOrCreateFreeFn(rewriter, moduleOp);
581 if (failed(freeFunc))
582 return freeFunc;
583
584 LLVM::CallOp::create(rewriter, loc, freeFunc.value(),
585 ValueRange{allocated});
586 }
587
588 rewriter.eraseOp(op);
589 return success();
590 }
591};
592
593struct SimSetInputOpLowering : public ModelAwarePattern<arc::SimSetInputOp> {
594 using ModelAwarePattern::ModelAwarePattern;
595
596 LogicalResult
597 matchAndRewrite(arc::SimSetInputOp op, OpAdaptor adaptor,
598 ConversionPatternRewriter &rewriter) const final {
599 auto modelIt =
600 modelInfo.find(cast<SimModelInstanceType>(op.getInstance().getType())
601 .getModel()
602 .getValue());
603 ModelInfoMap &model = modelIt->second;
604
605 auto portIt = model.states.find(op.getInput());
606 if (portIt == model.states.end()) {
607 // If the port is not found in the state, it means the model does not
608 // actually use it. Thus this operation is a no-op.
609 rewriter.eraseOp(op);
610 return success();
611 }
612
613 StateInfo &port = portIt->second;
614 Value statePtr = createPtrToPortState(rewriter, op.getLoc(),
615 adaptor.getInstance(), port);
616 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
617 statePtr);
618
619 return success();
620 }
621};
622
623struct SimGetPortOpLowering : public ModelAwarePattern<arc::SimGetPortOp> {
624 using ModelAwarePattern::ModelAwarePattern;
625
626 LogicalResult
627 matchAndRewrite(arc::SimGetPortOp op, OpAdaptor adaptor,
628 ConversionPatternRewriter &rewriter) const final {
629 auto modelIt =
630 modelInfo.find(cast<SimModelInstanceType>(op.getInstance().getType())
631 .getModel()
632 .getValue());
633 ModelInfoMap &model = modelIt->second;
634
635 auto type = typeConverter->convertType(op.getValue().getType());
636 if (!type)
637 return failure();
638 auto portIt = model.states.find(op.getPort());
639 if (portIt == model.states.end()) {
640 // If the port is not found in the state, it means the model does not
641 // actually set it. Thus this operation returns 0.
642 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, type, 0);
643 return success();
644 }
645
646 StateInfo &port = portIt->second;
647 Value statePtr = createPtrToPortState(rewriter, op.getLoc(),
648 adaptor.getInstance(), port);
649 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, type, statePtr);
650
651 return success();
652 }
653};
654
655struct SimStepOpLowering : public ModelAwarePattern<arc::SimStepOp> {
656 using ModelAwarePattern::ModelAwarePattern;
657
658 LogicalResult
659 matchAndRewrite(arc::SimStepOp op, OpAdaptor adaptor,
660 ConversionPatternRewriter &rewriter) const final {
661 StringRef modelName = cast<SimModelInstanceType>(op.getInstance().getType())
662 .getModel()
663 .getValue();
664
665 if (adaptor.getTimePostIncrement()) {
666 // Increment time after step
667 OpBuilder::InsertionGuard g(rewriter);
668 rewriter.setInsertionPointAfter(op);
669 auto oldTime =
670 arc::SimGetTimeOp::create(rewriter, op.getLoc(), op.getInstance());
671 auto newTime = LLVM::AddOp::create(rewriter, op.getLoc(), oldTime,
672 adaptor.getTimePostIncrement());
673 arc::SimSetTimeOp::create(rewriter, op.getLoc(), op.getInstance(),
674 newTime);
675 }
676
677 StringAttr evalFunc =
678 rewriter.getStringAttr(evalSymbolFromModelName(modelName));
679 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, mlir::TypeRange(), evalFunc,
680 adaptor.getInstance());
681
682 return success();
683 }
684};
685
686// Loads the simulation time (i64 femtoseconds) from byte offset 0 in the
687// model instance's state storage.
688struct SimGetTimeOpLowering : public OpConversionPattern<arc::SimGetTimeOp> {
689 using OpConversionPattern::OpConversionPattern;
690
691 LogicalResult
692 matchAndRewrite(arc::SimGetTimeOp op, OpAdaptor adaptor,
693 ConversionPatternRewriter &rewriter) const final {
694 // Time is stored at offset 0 in the instance storage.
695 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, rewriter.getI64Type(),
696 adaptor.getInstance());
697 return success();
698 }
699};
700
701// Stores the simulation time (i64 femtoseconds) to byte offset 0 in the
702// model instance's state storage.
703struct SimSetTimeOpLowering : public OpConversionPattern<arc::SimSetTimeOp> {
704 using OpConversionPattern::OpConversionPattern;
705
706 LogicalResult
707 matchAndRewrite(arc::SimSetTimeOp op, OpAdaptor adaptor,
708 ConversionPatternRewriter &rewriter) const final {
709 // Time is stored at offset 0 in the instance storage.
710 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getTime(),
711 adaptor.getInstance());
712 return success();
713 }
714};
715
716// Loads the next wakeup time (i64 femtoseconds) from `kNextWakeupOffset` of
717// the model instance's state storage.
718struct SimGetNextWakeupOpLowering
719 : public OpConversionPattern<arc::SimGetNextWakeupOp> {
720 using OpConversionPattern::OpConversionPattern;
721
722 LogicalResult
723 matchAndRewrite(arc::SimGetNextWakeupOp op, OpAdaptor adaptor,
724 ConversionPatternRewriter &rewriter) const final {
725 auto loc = op.getLoc();
726 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
727 Value slotPtr = LLVM::GEPOp::create(
728 rewriter, loc, ptrType, rewriter.getI8Type(), adaptor.getInstance(),
729 ArrayRef<LLVM::GEPArg>{arc::kNextWakeupOffset});
730 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, rewriter.getI64Type(),
731 slotPtr);
732 return success();
733 }
734};
735
736// Global string constants in the module.
737class StringCache {
738public:
739 Value getOrCreate(OpBuilder &b, StringRef formatStr) {
740 auto it = cache.find(formatStr);
741 if (it != cache.end()) {
742 return LLVM::AddressOfOp::create(b, b.getUnknownLoc(), it->second);
743 }
744
745 Location loc = b.getUnknownLoc();
746 LLVM::GlobalOp global;
747 {
748 OpBuilder::InsertionGuard guard(b);
749 ModuleOp m =
750 b.getInsertionBlock()->getParent()->getParentOfType<ModuleOp>();
751 b.setInsertionPointToStart(m.getBody());
752
753 SmallVector<char> strVec(formatStr.begin(), formatStr.end());
754 strVec.push_back(0);
755
756 auto name = llvm::formatv("_arc_str_{0}", cache.size()).str();
757 auto globalType = LLVM::LLVMArrayType::get(b.getI8Type(), strVec.size());
758 global = LLVM::GlobalOp::create(b, loc, globalType, /*isConstant=*/true,
759 LLVM::Linkage::Internal,
760 /*name=*/name, b.getStringAttr(strVec),
761 /*alignment=*/0);
762 }
763
764 cache[formatStr] = global;
765 return LLVM::AddressOfOp::create(b, loc, global);
766 }
767
768private:
769 llvm::StringMap<LLVM::GlobalOp> cache;
770};
771
772FailureOr<LLVM::CallOp> emitPrintfCall(OpBuilder &builder, Location loc,
773 StringCache &cache, StringRef formatStr,
774 ValueRange args) {
775 ModuleOp moduleOp =
776 builder.getInsertionBlock()->getParent()->getParentOfType<ModuleOp>();
777 // Lookup or create printf function symbol.
778 MLIRContext *ctx = builder.getContext();
779 auto printfFunc = LLVM::lookupOrCreateFn(builder, moduleOp, "printf",
780 LLVM::LLVMPointerType::get(ctx),
781 LLVM::LLVMVoidType::get(ctx), true);
782 if (failed(printfFunc))
783 return printfFunc;
784
785 Value formatStrPtr = cache.getOrCreate(builder, formatStr);
786 SmallVector<Value> argsVec(1, formatStrPtr);
787 argsVec.append(args.begin(), args.end());
788 return LLVM::CallOp::create(builder, loc, printfFunc.value(), argsVec);
789}
790
791/// Lowers SimEmitValueOp to a printf call. The integer will be printed in its
792/// entirety if it is of size up to size_t, and explicitly truncated otherwise.
793/// This pattern will mutate the global module.
794struct SimEmitValueOpLowering
795 : public OpConversionPattern<arc::SimEmitValueOp> {
796 SimEmitValueOpLowering(const TypeConverter &typeConverter,
797 MLIRContext *context, StringCache &formatStringCache)
798 : OpConversionPattern(typeConverter, context),
799 formatStringCache(formatStringCache) {}
800
801 LogicalResult
802 matchAndRewrite(arc::SimEmitValueOp op, OpAdaptor adaptor,
803 ConversionPatternRewriter &rewriter) const final {
804 auto valueType = dyn_cast<IntegerType>(adaptor.getValue().getType());
805 if (!valueType)
806 return failure();
807
808 Location loc = op.getLoc();
809
810 ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
811 if (!moduleOp)
812 return failure();
813
814 SmallVector<Value> printfVariadicArgs;
815 SmallString<16> printfFormatStr;
816 int remainingBits = valueType.getWidth();
817 Value value = adaptor.getValue();
818
819 // Assumes the target platform uses 64bit for long long ints (%llx
820 // formatter).
821 constexpr llvm::StringRef intFormatter = "llx";
822 auto intType = IntegerType::get(getContext(), 64);
823 Value shiftValue = LLVM::ConstantOp::create(
824 rewriter, loc, rewriter.getIntegerAttr(valueType, intType.getWidth()));
825
826 if (valueType.getWidth() < intType.getWidth()) {
827 int width = llvm::divideCeil(valueType.getWidth(), 4);
828 printfFormatStr = llvm::formatv("%0{0}{1}", width, intFormatter);
829 printfVariadicArgs.push_back(
830 LLVM::ZExtOp::create(rewriter, loc, intType, value));
831 } else {
832 // Process the value in 64 bit chunks, starting from the least significant
833 // bits. Since we append chunks in low-to-high order, we reverse the
834 // vector to print them in the correct high-to-low order.
835 int otherChunkWidth = intType.getWidth() / 4;
836 int firstChunkWidth =
837 llvm::divideCeil(valueType.getWidth() % intType.getWidth(), 4);
838 if (firstChunkWidth == 0) { // print the full 64-bit hex or a subset.
839 firstChunkWidth = otherChunkWidth;
840 }
841
842 std::string firstChunkFormat =
843 llvm::formatv("%0{0}{1}", firstChunkWidth, intFormatter);
844 std::string otherChunkFormat =
845 llvm::formatv("%0{0}{1}", otherChunkWidth, intFormatter);
846
847 for (int i = 0; remainingBits > 0; ++i) {
848 // Append 64-bit chunks to the printf arguments, in low-to-high
849 // order. The integer is printed in hex format with zero padding.
850 printfVariadicArgs.push_back(
851 LLVM::TruncOp::create(rewriter, loc, intType, value));
852
853 // Zero-padded format specifier for fixed width, e.g. %01llx for 4 bits.
854 printfFormatStr.append(i == 0 ? firstChunkFormat : otherChunkFormat);
855
856 value =
857 LLVM::LShrOp::create(rewriter, loc, value, shiftValue).getResult();
858 remainingBits -= intType.getWidth();
859 }
860 }
861
862 std::reverse(printfVariadicArgs.begin(), printfVariadicArgs.end());
863
864 SmallString<16> formatStr = adaptor.getValueName();
865 formatStr.append(" = ");
866 formatStr.append(printfFormatStr);
867 formatStr.append("\n");
868
869 auto callOp = emitPrintfCall(rewriter, op->getLoc(), formatStringCache,
870 formatStr, printfVariadicArgs);
871 if (failed(callOp))
872 return failure();
873 rewriter.replaceOp(op, *callOp);
874
875 return success();
876 }
877
878 StringCache &formatStringCache;
879};
880
881//===----------------------------------------------------------------------===//
882// `sim` dialect lowerings
883//===----------------------------------------------------------------------===//
884
885// Helper struct to hold the format string and arguments for arcRuntimeFormat.
886struct FormatInfo {
887 SmallVector<FmtDescriptor> descriptors;
888 SmallVector<Value> args;
889};
890
891// Copies the given integer value into an alloca, returning a pointer to it.
892//
893// The alloca is rounded up to a 64-bit boundary and is written as little-endian
894// words of size 64-bits, to be compatible with the constructor of APInt.
895static Value reg2mem(ConversionPatternRewriter &rewriter, Location loc,
896 Value value) {
897 // Round up the type size to a 64-bit boundary.
898 int64_t origBitwidth = cast<IntegerType>(value.getType()).getWidth();
899 int64_t bitwidth = llvm::divideCeil(origBitwidth, 64) * 64;
900 int64_t numWords = bitwidth / 64;
901
902 // Create an alloca for the rounded up type.
903 LLVM::ConstantOp alloca_size =
904 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), numWords);
905 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
906 auto allocaOp = LLVM::AllocaOp::create(rewriter, loc, ptrType,
907 rewriter.getI64Type(), alloca_size);
908 LLVM::LifetimeStartOp::create(rewriter, loc, allocaOp);
909
910 // Copy `value` into the alloca, 64-bits at a time from the least significant
911 // bits first.
912 for (int64_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
913 Value cst = LLVM::ConstantOp::create(
914 rewriter, loc, rewriter.getIntegerType(origBitwidth), wordIdx * 64);
915 Value v = LLVM::LShrOp::create(rewriter, loc, value, cst);
916 if (origBitwidth > 64) {
917 v = LLVM::TruncOp::create(rewriter, loc, rewriter.getI64Type(), v);
918 } else if (origBitwidth < 64) {
919 v = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), v);
920 }
921 Value gep = LLVM::GEPOp::create(rewriter, loc, ptrType,
922 rewriter.getI64Type(), allocaOp, {wordIdx});
923 LLVM::StoreOp::create(rewriter, loc, v, gep);
924 }
925
926 return allocaOp;
927}
928
929// Statically folds a value of type sim::FormatStringType to a FormatInfo.
930static FailureOr<FormatInfo>
931foldFormatString(ConversionPatternRewriter &rewriter, Value fstringValue,
932 StringCache &cache) {
933 Operation *op = fstringValue.getDefiningOp();
934 return llvm::TypeSwitch<Operation *, FailureOr<FormatInfo>>(op)
935 .Case<sim::FormatCharOp>(
936 [&](sim::FormatCharOp op) -> FailureOr<FormatInfo> {
937 FmtDescriptor d = FmtDescriptor::createChar();
938 return FormatInfo{{d}, {op.getValue()}};
939 })
940 .Case<sim::FormatDecOp>([&](sim::FormatDecOp op)
941 -> FailureOr<FormatInfo> {
942 FmtDescriptor d = FmtDescriptor::createInt(
943 op.getValue().getType().getWidth(), 10, op.getIsLeftAligned(),
944 op.getSpecifierWidth().value_or(-1), false, op.getIsSigned());
945 return FormatInfo{{d}, {reg2mem(rewriter, op.getLoc(), op.getValue())}};
946 })
947 .Case<sim::FormatHexOp>([&](sim::FormatHexOp op)
948 -> FailureOr<FormatInfo> {
949 FmtDescriptor d = FmtDescriptor::createInt(
950 op.getValue().getType().getWidth(), 16, op.getIsLeftAligned(),
951 op.getSpecifierWidth().value_or(-1), op.getIsHexUppercase(), false);
952 return FormatInfo{{d}, {reg2mem(rewriter, op.getLoc(), op.getValue())}};
953 })
954 .Case<sim::FormatOctOp>([&](sim::FormatOctOp op)
955 -> FailureOr<FormatInfo> {
956 FmtDescriptor d = FmtDescriptor::createInt(
957 op.getValue().getType().getWidth(), 8, op.getIsLeftAligned(),
958 op.getSpecifierWidth().value_or(-1), false, false);
959 return FormatInfo{{d}, {reg2mem(rewriter, op.getLoc(), op.getValue())}};
960 })
961 .Case<sim::FormatLiteralOp>(
962 [&](sim::FormatLiteralOp op) -> FailureOr<FormatInfo> {
963 if (op.getLiteral().size() < 8 &&
964 op.getLiteral().find('\0') == StringRef::npos) {
965 // We can use the small string optimization.
966 FmtDescriptor d =
967 FmtDescriptor::createSmallLiteral(op.getLiteral());
968 return FormatInfo{{d}, {}};
969 }
970 FmtDescriptor d =
971 FmtDescriptor::createLiteral(op.getLiteral().size());
972 Value value = cache.getOrCreate(rewriter, op.getLiteral());
973 return FormatInfo{{d}, {value}};
974 })
975 .Case<sim::FormatStringConcatOp>(
976 [&](sim::FormatStringConcatOp op) -> FailureOr<FormatInfo> {
977 auto fmt = foldFormatString(rewriter, op.getInputs()[0], cache);
978 if (failed(fmt))
979 return failure();
980 for (auto input : op.getInputs().drop_front()) {
981 auto next = foldFormatString(rewriter, input, cache);
982 if (failed(next))
983 return failure();
984 fmt->descriptors.append(next->descriptors);
985 fmt->args.append(next->args);
986 }
987 return fmt;
988 })
989 .Default(
990 [](Operation *op) -> FailureOr<FormatInfo> { return failure(); });
991}
992
993FailureOr<LLVM::CallOp> emitFmtCall(OpBuilder &builder, Location loc,
994 StringCache &stringCache,
995 ArrayRef<FmtDescriptor> descriptors,
996 ValueRange args) {
997 ModuleOp moduleOp =
998 builder.getInsertionBlock()->getParent()->getParentOfType<ModuleOp>();
999 // Lookup or create the arcRuntimeFormat function symbol.
1000 MLIRContext *ctx = builder.getContext();
1001 auto func = LLVM::lookupOrCreateFn(
1002 builder, moduleOp, runtime::APICallbacks::symNameFormat,
1003 LLVM::LLVMPointerType::get(ctx), LLVM::LLVMVoidType::get(ctx), true);
1004 if (failed(func))
1005 return func;
1006
1007 StringRef rawDescriptors(reinterpret_cast<const char *>(descriptors.data()),
1008 descriptors.size() * sizeof(FmtDescriptor));
1009 Value fmtPtr = stringCache.getOrCreate(builder, rawDescriptors);
1010
1011 SmallVector<Value> argsVec(1, fmtPtr);
1012 argsVec.append(args.begin(), args.end());
1013 auto result = LLVM::CallOp::create(builder, loc, func.value(), argsVec);
1014
1015 for (Value arg : args) {
1016 Operation *definingOp = arg.getDefiningOp();
1017 if (auto alloca = dyn_cast_if_present<LLVM::AllocaOp>(definingOp)) {
1018 LLVM::LifetimeEndOp::create(builder, loc, arg);
1019 }
1020 }
1021
1022 return result;
1023}
1024
1025struct SimPrintFormattedProcOpLowering
1026 : public OpConversionPattern<sim::PrintFormattedProcOp> {
1027 SimPrintFormattedProcOpLowering(const TypeConverter &typeConverter,
1028 MLIRContext *context,
1029 StringCache &stringCache)
1030 : OpConversionPattern<sim::PrintFormattedProcOp>(typeConverter, context),
1031 stringCache(stringCache) {}
1032
1033 LogicalResult
1034 matchAndRewrite(sim::PrintFormattedProcOp op, OpAdaptor adaptor,
1035 ConversionPatternRewriter &rewriter) const override {
1036 auto formatInfo = foldFormatString(rewriter, op.getInput(), stringCache);
1037 if (failed(formatInfo))
1038 return rewriter.notifyMatchFailure(op, "unsupported format string");
1039
1040 // Add the end descriptor.
1041 formatInfo->descriptors.push_back(FmtDescriptor());
1042
1043 auto result = emitFmtCall(rewriter, op.getLoc(), stringCache,
1044 formatInfo->descriptors, formatInfo->args);
1045 if (failed(result))
1046 return failure();
1047 rewriter.replaceOp(op, result.value());
1048
1049 return success();
1050 }
1051
1052 StringCache &stringCache;
1053};
1054
1055struct TerminateOpLowering : public OpConversionPattern<arc::TerminateOp> {
1056 using OpConversionPattern::OpConversionPattern;
1057
1058 LogicalResult
1059 matchAndRewrite(arc::TerminateOp op, OpAdaptor adaptor,
1060 ConversionPatternRewriter &rewriter) const override {
1061 auto loc = op.getLoc();
1062
1063 auto i8Type = rewriter.getI8Type();
1064 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1065
1066 Value flagPtr = LLVM::GEPOp::create(
1067 rewriter, loc, ptrType, i8Type, adaptor.getStorage(),
1068 ArrayRef<LLVM::GEPArg>{arc::kTerminateFlagOffset});
1069
1070 uint8_t statusCode = op.getSuccess() ? 1 : 2;
1071 Value codeVal = LLVM::ConstantOp::create(
1072 rewriter, loc, i8Type, rewriter.getI8IntegerAttr(statusCode));
1073
1074 LLVM::StoreOp::create(rewriter, loc, codeVal, flagPtr);
1075
1076 rewriter.eraseOp(op);
1077 return success();
1078 }
1079};
1080
1081// Loads the next wakeup time (i64 femtoseconds) from the model's storage at
1082// `kNextWakeupOffset`.
1083struct GetNextWakeupOpLowering
1084 : public OpConversionPattern<arc::GetNextWakeupOp> {
1085 using OpConversionPattern::OpConversionPattern;
1086
1087 LogicalResult
1088 matchAndRewrite(arc::GetNextWakeupOp op, OpAdaptor adaptor,
1089 ConversionPatternRewriter &rewriter) const override {
1090 auto loc = op.getLoc();
1091 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1092 Value slotPtr = LLVM::GEPOp::create(
1093 rewriter, loc, ptrType, rewriter.getI8Type(), adaptor.getStorage(),
1094 ArrayRef<LLVM::GEPArg>{arc::kNextWakeupOffset});
1095 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, rewriter.getI64Type(),
1096 slotPtr);
1097 return success();
1098 }
1099};
1100
1101// Stores the next wakeup time (i64 femtoseconds) to the model's storage at
1102// `kNextWakeupOffset`.
1103struct SetNextWakeupOpLowering
1104 : public OpConversionPattern<arc::SetNextWakeupOp> {
1105 using OpConversionPattern::OpConversionPattern;
1106
1107 LogicalResult
1108 matchAndRewrite(arc::SetNextWakeupOp op, OpAdaptor adaptor,
1109 ConversionPatternRewriter &rewriter) const override {
1110 auto loc = op.getLoc();
1111 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1112 Value slotPtr = LLVM::GEPOp::create(
1113 rewriter, loc, ptrType, rewriter.getI8Type(), adaptor.getStorage(),
1114 ArrayRef<LLVM::GEPArg>{arc::kNextWakeupOffset});
1115 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getTime(), slotPtr);
1116 return success();
1117 }
1118};
1119
1120} // namespace
1121
1122static LogicalResult convert(arc::ExecuteOp op, arc::ExecuteOp::Adaptor adaptor,
1123 ConversionPatternRewriter &rewriter,
1124 const TypeConverter &converter) {
1125 // Convert the argument types in the body blocks.
1126 if (failed(rewriter.convertRegionTypes(&op.getBody(), converter)))
1127 return failure();
1128
1129 // Split the block at the current insertion point such that we can branch into
1130 // the `arc.execute` body region, and have `arc.output` branch back to the
1131 // point after the `arc.execute`.
1132 auto *blockBefore = rewriter.getInsertionBlock();
1133 auto *blockAfter =
1134 rewriter.splitBlock(blockBefore, rewriter.getInsertionPoint());
1135
1136 // Branch to the entry block.
1137 rewriter.setInsertionPointToEnd(blockBefore);
1138 mlir::cf::BranchOp::create(rewriter, op.getLoc(), &op.getBody().front(),
1139 adaptor.getInputs());
1140
1141 // Make all `arc.output` terminators branch to the block after the
1142 // `arc.execute` op.
1143 for (auto &block : op.getBody()) {
1144 auto outputOp = dyn_cast<arc::OutputOp>(block.getTerminator());
1145 if (!outputOp)
1146 continue;
1147 rewriter.setInsertionPointToEnd(&block);
1148 rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(outputOp, blockAfter,
1149 outputOp.getOperands());
1150 }
1151
1152 // Inline the body region between the before and after blocks.
1153 rewriter.inlineRegionBefore(op.getBody(), blockAfter);
1154
1155 // Add arguments to the block after the `arc.execute`, replace the op's
1156 // results with the arguments, then perform block signature conversion.
1157 SmallVector<Value> args;
1158 args.reserve(op.getNumResults());
1159 for (auto result : op.getResults())
1160 args.push_back(blockAfter->addArgument(result.getType(), result.getLoc()));
1161 rewriter.replaceOp(op, args);
1162 auto conversion = converter.convertBlockSignature(blockAfter);
1163 if (!conversion)
1164 return failure();
1165 rewriter.applySignatureConversion(blockAfter, *conversion, &converter);
1166 return success();
1167}
1168
1169//===----------------------------------------------------------------------===//
1170// Runtime Implementation
1171//===----------------------------------------------------------------------===//
1172
1173template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
1174static LLVM::GlobalOp
1175buildGlobalConstantIntArray(OpBuilder &builder, Location loc, Twine symName,
1176 SmallVectorImpl<T> &data,
1177 unsigned alignment = alignof(T)) {
1178 auto intType = builder.getIntegerType(8 * sizeof(T));
1179 Attribute denseAttr = mlir::DenseElementsAttr::get(
1180 mlir::RankedTensorType::get({(int64_t)data.size()}, intType),
1181 llvm::ArrayRef(data));
1182 auto globalOp = LLVM::GlobalOp::create(
1183 builder, loc, LLVM::LLVMArrayType::get(intType, data.size()),
1184 /*isConstant=*/true, LLVM::Linkage::Internal,
1185 builder.getStringAttr(symName), denseAttr);
1186 globalOp.setAlignmentAttr(builder.getI64IntegerAttr(alignment));
1187 return globalOp;
1188}
1189
1190// Construct a raw constant byte array from a vector of struct values
1191template <typename T>
1192static LLVM::GlobalOp
1193buildGlobalConstantRuntimeStructArray(OpBuilder &builder, Location loc,
1194 Twine symName,
1195 SmallVectorImpl<T> &array) {
1196 assert(!array.empty());
1197 static_assert(std::is_standard_layout<T>(),
1198 "Runtime struct must have standard layout");
1199 int64_t numBytes = sizeof(T) * array.size();
1200 Attribute denseAttr = mlir::DenseElementsAttr::get(
1201 mlir::RankedTensorType::get({numBytes}, builder.getI8Type()),
1202 llvm::ArrayRef(reinterpret_cast<uint8_t *>(array.data()), numBytes));
1203 auto globalOp = LLVM::GlobalOp::create(
1204 builder, loc, LLVM::LLVMArrayType::get(builder.getI8Type(), numBytes),
1205 /*isConstant=*/true, LLVM::Linkage::Internal,
1206 builder.getStringAttr(symName), denseAttr, alignof(T));
1207 return globalOp;
1208}
1209
1211 : public OpConversionPattern<arc::RuntimeModelOp> {
1212 using OpConversionPattern::OpConversionPattern;
1213
1214 static constexpr uint64_t runtimeApiVersion = ARC_RUNTIME_API_VERSION;
1215
1216 // Build the constant ArcModelTraceInfo struct and its members
1217 LLVM::GlobalOp
1218 buildTraceInfoStruct(arc::RuntimeModelOp &op,
1219 ConversionPatternRewriter &rewriter) const {
1220 if (!op.getTraceTaps().has_value() || op.getTraceTaps()->empty())
1221 return {};
1222 // Construct the array of tap names/aliases
1223 SmallVector<char> namesArray;
1224 SmallVector<ArcTraceTap> tapArray;
1225 tapArray.reserve(op.getTraceTaps()->size());
1226 for (auto attr : op.getTraceTapsAttr()) {
1227 auto tap = cast<TraceTapAttr>(attr);
1228 assert(!tap.getNames().empty() &&
1229 "Expected trace tap to have at least one name");
1230 for (auto alias : tap.getNames()) {
1231 auto aliasStr = cast<StringAttr>(alias);
1232 namesArray.append(aliasStr.begin(), aliasStr.end());
1233 namesArray.push_back('\0');
1234 }
1235 ArcTraceTap tapStruct;
1236 tapStruct.stateOffset = tap.getStateOffset();
1237 tapStruct.nameOffset = namesArray.size() - 1;
1238 tapStruct.typeBits = tap.getSigType().getValue().getIntOrFloatBitWidth();
1239 tapStruct.reserved = 0;
1240 tapArray.emplace_back(tapStruct);
1241 }
1242 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1243 auto namesGlobal = buildGlobalConstantIntArray(
1244 rewriter, op.getLoc(), "_arc_tap_names_" + op.getName(), namesArray);
1245 auto traceTapsArrayGlobal = buildGlobalConstantRuntimeStructArray(
1246 rewriter, op.getLoc(), "_arc_trace_taps_" + op.getName(), tapArray);
1247
1248 //
1249 // struct ArcModelTraceInfo {
1250 // uint64_t numTraceTaps;
1251 // struct ArcTraceTap *traceTaps;
1252 // const char *traceTapNames;
1253 // uint64_t traceBufferCapacity;
1254 // };
1255 //
1256 auto traceInfoStructType = LLVM::LLVMStructType::getLiteral(
1257 getContext(),
1258 {rewriter.getI64Type(), ptrTy, ptrTy, rewriter.getI64Type()});
1259 static_assert(sizeof(ArcModelTraceInfo) == 32 &&
1260 "Unexpected size of ArcModelTraceInfo struct");
1261
1262 auto globalSymName =
1263 rewriter.getStringAttr("_arc_trace_info_" + op.getName());
1264 auto traceInfoGlobalOp = LLVM::GlobalOp::create(
1265 rewriter, op.getLoc(), traceInfoStructType,
1266 /*isConstant=*/false, LLVM::Linkage::Internal, globalSymName,
1267 Attribute{}, alignof(ArcModelTraceInfo));
1268 OpBuilder::InsertionGuard g(rewriter);
1269
1270 // Struct Initializer
1271 Region &initRegion = traceInfoGlobalOp.getInitializerRegion();
1272 Block *initBlock = rewriter.createBlock(&initRegion);
1273 rewriter.setInsertionPointToStart(initBlock);
1274
1275 auto numTraceTapsCst = LLVM::ConstantOp::create(
1276 rewriter, op.getLoc(), rewriter.getI64IntegerAttr(tapArray.size()));
1277 auto traceTapArrayAddr =
1278 LLVM::AddressOfOp::create(rewriter, op.getLoc(), traceTapsArrayGlobal);
1279 auto tapNameArrayAddr =
1280 LLVM::AddressOfOp::create(rewriter, op.getLoc(), namesGlobal);
1281 auto bufferCapacityCst = LLVM::ConstantOp::create(
1282 rewriter, op.getLoc(),
1283 rewriter.getI64IntegerAttr(runtime::defaultTraceBufferCapacity));
1284
1285 Value initStruct =
1286 LLVM::PoisonOp::create(rewriter, op.getLoc(), traceInfoStructType);
1287
1288 // Field: uint64_t numTraceTaps
1289 initStruct =
1290 LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1291 numTraceTapsCst, ArrayRef<int64_t>{0});
1292 static_assert(offsetof(ArcModelTraceInfo, numTraceTaps) == 0,
1293 "Unexpected offset of field numTraceTaps");
1294 // Field: struct ArcTraceTap *traceTaps
1295 initStruct =
1296 LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1297 traceTapArrayAddr, ArrayRef<int64_t>{1});
1298 static_assert(offsetof(ArcModelTraceInfo, traceTaps) == 8,
1299 "Unexpected offset of field traceTaps");
1300 // Field: const char *traceTapNames
1301 initStruct =
1302 LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1303 tapNameArrayAddr, ArrayRef<int64_t>{2});
1304 static_assert(offsetof(ArcModelTraceInfo, traceTapNames) == 16,
1305 "Unexpected offset of field traceTapNames");
1306 // Field: uint64_t traceBufferCapacity
1307 initStruct =
1308 LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1309 bufferCapacityCst, ArrayRef<int64_t>{3});
1310 static_assert(offsetof(ArcModelTraceInfo, traceBufferCapacity) == 24,
1311 "Unexpected offset of field traceBufferCapacity");
1312 LLVM::ReturnOp::create(rewriter, op.getLoc(), initStruct);
1313
1314 return traceInfoGlobalOp;
1315 }
1316
1317 // Create a global LLVM struct containing the RuntimeModel metadata
1318 LogicalResult
1319 matchAndRewrite(arc::RuntimeModelOp op, OpAdaptor adaptor,
1320 ConversionPatternRewriter &rewriter) const final {
1321
1322 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1323 auto modelInfoStructType = LLVM::LLVMStructType::getLiteral(
1324 getContext(),
1325 {rewriter.getI64Type(), rewriter.getI64Type(), ptrTy, ptrTy});
1326 static_assert(sizeof(ArcRuntimeModelInfo) == 32 &&
1327 "Unexpected size of ArcRuntimeModelInfo struct");
1328
1329 rewriter.setInsertionPoint(op);
1330 auto traceInfoGlobal = buildTraceInfoStruct(op, rewriter);
1331
1332 // Construct the Model Name String GlobalOp
1333 SmallVector<char, 16> modNameArray(op.getName().begin(),
1334 op.getName().end());
1335 modNameArray.push_back('\0');
1336 auto nameGlobalType =
1337 LLVM::LLVMArrayType::get(rewriter.getI8Type(), modNameArray.size());
1338 auto globalSymName =
1339 rewriter.getStringAttr("_arc_mod_name_" + op.getName());
1340 auto nameGlobal = LLVM::GlobalOp::create(
1341 rewriter, op.getLoc(), nameGlobalType, /*isConstant=*/true,
1342 LLVM::Linkage::Internal,
1343 /*name=*/globalSymName, rewriter.getStringAttr(modNameArray),
1344 /*alignment=*/0);
1345
1346 // Construct the Model Info Struct GlobalOp
1347 // Note: The struct is supposed to be constant at runtime, but contains the
1348 // relocatable address of another symbol, so it should not be placed in the
1349 // "rodata" section.
1350 auto modInfoGlobalOp =
1351 LLVM::GlobalOp::create(rewriter, op.getLoc(), modelInfoStructType,
1352 /*isConstant=*/false, LLVM::Linkage::External,
1353 op.getSymName(), Attribute{});
1354
1355 // Struct Initializer
1356 Region &initRegion = modInfoGlobalOp.getInitializerRegion();
1357 Block *initBlock = rewriter.createBlock(&initRegion);
1358 rewriter.setInsertionPointToStart(initBlock);
1359 auto apiVersionCst = LLVM::ConstantOp::create(
1360 rewriter, op.getLoc(), rewriter.getI64IntegerAttr(runtimeApiVersion));
1361 auto numStateBytesCst = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1362 op.getNumStateBytesAttr());
1363 auto nameAddr =
1364 LLVM::AddressOfOp::create(rewriter, op.getLoc(), nameGlobal);
1365 Value traceInfoPtr;
1366 if (traceInfoGlobal)
1367 traceInfoPtr =
1368 LLVM::AddressOfOp::create(rewriter, op.getLoc(), traceInfoGlobal);
1369 else
1370 traceInfoPtr = LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy);
1371
1372 Value initStruct =
1373 LLVM::PoisonOp::create(rewriter, op.getLoc(), modelInfoStructType);
1374
1375 // Field: uint64_t apiVersion
1376 initStruct = LLVM::InsertValueOp::create(
1377 rewriter, op.getLoc(), initStruct, apiVersionCst, ArrayRef<int64_t>{0});
1378 static_assert(offsetof(ArcRuntimeModelInfo, apiVersion) == 0,
1379 "Unexpected offset of field apiVersion");
1380 // Field: uint64_t numStateBytes
1381 initStruct =
1382 LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1383 numStateBytesCst, ArrayRef<int64_t>{1});
1384 static_assert(offsetof(ArcRuntimeModelInfo, numStateBytes) == 8,
1385 "Unexpected offset of field numStateBytes");
1386 // Field: const char *modelName
1387 initStruct = LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1388 nameAddr, ArrayRef<int64_t>{2});
1389 static_assert(offsetof(ArcRuntimeModelInfo, modelName) == 16,
1390 "Unexpected offset of field modelName");
1391 // Field: struct ArcModelTraceInfo *traceInfo
1392 initStruct = LLVM::InsertValueOp::create(
1393 rewriter, op.getLoc(), initStruct, traceInfoPtr, ArrayRef<int64_t>{3});
1394 static_assert(offsetof(ArcRuntimeModelInfo, traceInfo) == 24,
1395 "Unexpected offset of field traceInfo");
1396
1397 LLVM::ReturnOp::create(rewriter, op.getLoc(), initStruct);
1398
1399 rewriter.replaceOp(op, modInfoGlobalOp);
1400 return success();
1401 }
1402};
1403
1404//===----------------------------------------------------------------------===//
1405// ArrayRef patterns
1406//===----------------------------------------------------------------------===//
1407
1408size_t computeByteWidth(ArrayRefType type) {
1409 auto bitWidth = computeLLVMBitWidth(type);
1410 assert(bitWidth.has_value());
1411 return llvm::divideCeil(*bitWidth, 8);
1412}
1413
1414// Computes the padded bytewidth (stride) of each element.
1415size_t computeElementByteWidth(ArrayRefType arrayRefType) {
1416 auto arrayBitWidth = computeLLVMBitWidth(arrayRefType);
1417 assert(arrayBitWidth.has_value());
1418 assert(arrayRefType.getNumElements() > 0 &&
1419 "Cannot compute stride for zero sized array");
1420 size_t elementBitWidth = *arrayBitWidth / arrayRefType.getNumElements();
1421 return llvm::divideCeil(elementBitWidth, 8);
1422}
1423
1424struct ArrayRefAllocOpLowering : public OpConversionPattern<ArrayRefAllocOp> {
1425 using OpConversionPattern::OpConversionPattern;
1426
1427 LogicalResult
1428 matchAndRewrite(ArrayRefAllocOp op, OpAdaptor adaptor,
1429 ConversionPatternRewriter &rewriter) const override {
1430 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1431 auto i8Ty = rewriter.getI8Type();
1432 ArrayRefType arrayRefType = op.getType();
1433 size_t byteWidth = computeByteWidth(arrayRefType);
1434 auto size = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1435 rewriter.getI64Type(), byteWidth);
1436
1437 size_t alignment = computeAllocaAlignment(arrayRefType, op);
1438 auto alloc = LLVM::AllocaOp::create(rewriter, op.getLoc(), ptrTy, i8Ty,
1439 size, alignment);
1440
1441 if (op.getInitAttr()) {
1442 ArrayAttr initAttr = op.getInitAttr();
1443 if (isZero(initAttr)) {
1444 auto i8Ty = rewriter.getI8Type();
1445 auto zero = LLVM::ConstantOp::create(rewriter, op.getLoc(), i8Ty, 0);
1446 LLVM::MemsetOp::create(rewriter, op.getLoc(), alloc, zero, size,
1447 /*isVolatile=*/false);
1448 } else {
1449 initializeArray(rewriter, op.getLoc(), alloc, initAttr, arrayRefType);
1450 }
1451 }
1452
1453 rewriter.replaceOp(op, alloc);
1454 return success();
1455 }
1456
1457 // Computes the required alignment for an AllocaOp of the given type.
1458 // c.f. HWToLLVM.cpp.
1459 size_t computeAllocaAlignment(ArrayRefType type, Operation *op) const {
1460 if (alignmentCache.count(type)) {
1461 return alignmentCache[type];
1462 }
1463 auto dl = DataLayout::closest(op);
1464 auto hwType =
1465 hw::ArrayType::get(type.getElementType(), type.getNumElements());
1466 auto llvmType = getTypeConverter()->convertType(hwType);
1467 auto alignment =
1468 static_cast<unsigned>(dl.getTypePreferredAlignment(llvmType));
1469 alignment = std::max(4u, alignment);
1470 alignmentCache[type] = alignment;
1471 return alignment;
1472 }
1473
1474 bool isZero(ArrayAttr arrayAttr) const {
1475 return llvm::all_of(arrayAttr.getAsValueRange<IntegerAttr>(),
1476 [](APInt i) { return i.isZero(); });
1477 }
1478
1479 void initializeArray(ConversionPatternRewriter &rewriter, Location loc,
1480 Value alloc, ArrayAttr initAttr,
1481 ArrayRefType arrayRefType) const {
1482 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1483 Type ptrTy = LLVM::LLVMPointerType::get(getContext());
1484 Type i8Ty = rewriter.getI8Type();
1485 for (unsigned i = 0; i < arrayRefType.getNumElements(); ++i) {
1486 unsigned elemIndex = arrayRefType.getNumElements() - i - 1;
1487 Value elemOffset = LLVM::ConstantOp::create(
1488 rewriter, loc, rewriter.getI64Type(), elemIndex * elemByteWidth);
1489 auto elemAddr =
1490 LLVM::GEPOp::create(rewriter, loc, ptrTy, i8Ty, alloc, elemOffset);
1491 auto elem = LLVM::ConstantOp::create(
1492 rewriter, loc, arrayRefType.getElementType(), initAttr[i]);
1493 LLVM::StoreOp::create(rewriter, loc, elem, elemAddr);
1494 }
1495 }
1496
1497private:
1498 mutable DenseMap<ArrayRefType, size_t> alignmentCache;
1499};
1500
1501struct ArrayRefCreateOpLowering : public OpConversionPattern<ArrayRefCreateOp> {
1502 using OpConversionPattern::OpConversionPattern;
1503
1504 LogicalResult
1505 matchAndRewrite(ArrayRefCreateOp op, OpAdaptor adaptor,
1506 ConversionPatternRewriter &rewriter) const override {
1507 ArrayRefType arrayRefType = cast<ArrayRefType>(op.getType());
1508 Value alloc = adaptor.getInput();
1509 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1510 auto i8Ty = rewriter.getI8Type();
1511 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1512 auto elements = adaptor.getElements();
1513 for (unsigned i = 0; i < elements.size(); ++i) {
1514 // Note: hardcoded for little endian targets.
1515 unsigned elemIndex = arrayRefType.getNumElements() - i - 1;
1516 Value elemOffset =
1517 LLVM::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI64Type(),
1518 elemIndex * elemByteWidth);
1519 auto elemAddr = LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, i8Ty,
1520 alloc, elemOffset);
1521 LLVM::StoreOp::create(rewriter, op.getLoc(), elements[i], elemAddr);
1522 }
1523 rewriter.replaceOp(op, alloc);
1524 return success();
1525 }
1526};
1527
1528struct ArrayRefGetOpLowering : public OpConversionPattern<ArrayRefGetOp> {
1529 using OpConversionPattern::OpConversionPattern;
1530
1531 LogicalResult
1532 matchAndRewrite(ArrayRefGetOp op, OpAdaptor adaptor,
1533 ConversionPatternRewriter &rewriter) const override {
1534 auto loc = op.getLoc();
1535 ArrayRefType arrayRefType = cast<ArrayRefType>(op.getInput().getType());
1536 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1537 auto i8Ty = rewriter.getI8Type();
1538 auto i64Ty = rewriter.getI64Type();
1539 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1540 assert(!isa<ArrayRefType>(arrayRefType.getElementType()));
1541
1542 Value stride =
1543 LLVM::ConstantOp::create(rewriter, loc, i64Ty, elemByteWidth);
1544 Value byteOffset =
1545 LLVM::MulOp::create(rewriter, loc, adaptor.getIndex(), stride);
1546 // Defend against out-of-bounds accesses. What we return is undefined in the
1547 // case of OOB.
1548 size_t lastElementByteOffset =
1549 elemByteWidth * (arrayRefType.getNumElements() - 1);
1550 Value lastElementByteOffsetVal =
1551 LLVM::ConstantOp::create(rewriter, loc, i64Ty, lastElementByteOffset);
1552 Value clampedOffset = LLVM::UMinOp::create(rewriter, loc, i64Ty, byteOffset,
1553 lastElementByteOffsetVal);
1554 auto elemAddr = LLVM::GEPOp::create(rewriter, loc, ptrTy, i8Ty,
1555 adaptor.getInput(), clampedOffset);
1556 Value loaded = LLVM::LoadOp::create(
1557 rewriter, loc, typeConverter->convertType(op.getValue().getType()),
1558 elemAddr);
1559 rewriter.replaceOp(op, loaded);
1560 return success();
1561 }
1562};
1563
1564struct ArrayRefInjectOpLowering : public OpConversionPattern<ArrayRefInjectOp> {
1565 using OpConversionPattern::OpConversionPattern;
1566
1567 LogicalResult
1568 matchAndRewrite(ArrayRefInjectOp op, OpAdaptor adaptor,
1569 ConversionPatternRewriter &rewriter) const override {
1570 auto loc = op.getLoc();
1571 ArrayRefType arrayRefType = cast<ArrayRefType>(op.getInput().getType());
1572 assert(!isa<ArrayRefType>(arrayRefType.getElementType()));
1573 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1574 auto i8Ty = rewriter.getI8Type();
1575 auto i64Ty = rewriter.getI64Type();
1576 size_t byteWidth = computeByteWidth(arrayRefType);
1577 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1578
1579 Value stride =
1580 LLVM::ConstantOp::create(rewriter, loc, i64Ty, elemByteWidth);
1581 Value byteOffset =
1582 LLVM::MulOp::create(rewriter, loc, adaptor.getIndex(), stride);
1583 Value totalSize = LLVM::ConstantOp::create(rewriter, loc, i64Ty, byteWidth);
1584 // Defend against out-of-bounds accesses. We must avoid corrupting the
1585 // array.
1586 Value isInbounds = LLVM::ICmpOp::create(
1587 rewriter, loc, LLVM::ICmpPredicate::ult, byteOffset, totalSize);
1588 scf::IfOp::create(rewriter, loc, isInbounds, [&](OpBuilder &b, Location) {
1589 auto elemAddr = LLVM::GEPOp::create(b, loc, ptrTy, i8Ty,
1590 adaptor.getInput(), byteOffset);
1591 LLVM::StoreOp::create(b, loc, adaptor.getElement(), elemAddr);
1592 scf::YieldOp::create(b, loc);
1593 });
1594
1595 // Inject is pure; returns the same pointer (input buffer is modified
1596 // in-place and the pointer is forwarded as the result).
1597 rewriter.replaceOp(op, adaptor.getInput());
1598 return success();
1599 }
1600};
1601
1602struct ArrayRefSliceOpLowering : public OpConversionPattern<ArrayRefSliceOp> {
1603 using OpConversionPattern::OpConversionPattern;
1604
1605 LogicalResult
1606 matchAndRewrite(ArrayRefSliceOp op, OpAdaptor adaptor,
1607 ConversionPatternRewriter &rewriter) const override {
1608 auto loc = op.getLoc();
1609 // The result type is the sub-array type; use its element size.
1610 ArrayRefType inputType = cast<ArrayRefType>(op.getInput().getType());
1611 ArrayRefType resultType = cast<ArrayRefType>(op.getOutput().getType());
1612 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1613 auto i8Ty = rewriter.getI8Type();
1614 auto i64Ty = rewriter.getI64Type();
1615 size_t elemByteWidth = computeElementByteWidth(resultType);
1616
1617 // Ensure the slice doesn't go out of bounds.
1618 size_t maxLowIndex =
1619 inputType.getNumElements() - resultType.getNumElements();
1620 Value maxLowIndexVal =
1621 LLVM::ConstantOp::create(rewriter, loc, i64Ty, maxLowIndex);
1622 Value clampedLowIndex = LLVM::UMinOp::create(
1623 rewriter, loc, i64Ty, adaptor.getLowIndex(), maxLowIndexVal);
1624
1625 // Byte offset = lowIndex * elemByteWidth.
1626 Value stride =
1627 LLVM::ConstantOp::create(rewriter, loc, i64Ty, elemByteWidth);
1628 Value byteOffset =
1629 LLVM::MulOp::create(rewriter, loc, clampedLowIndex, stride);
1630 auto sliceAddr = LLVM::GEPOp::create(rewriter, loc, ptrTy, i8Ty,
1631 adaptor.getInput(), byteOffset);
1632 rewriter.replaceOp(op, sliceAddr);
1633 return success();
1634 }
1635};
1636
1637struct ArrayRefCopyOpLowering : public OpConversionPattern<ArrayRefCopyOp> {
1638 using OpConversionPattern::OpConversionPattern;
1639
1640 LogicalResult
1641 matchAndRewrite(ArrayRefCopyOp op, OpAdaptor adaptor,
1642 ConversionPatternRewriter &rewriter) const override {
1643 auto loc = op.getLoc();
1644 ArrayRefType arrayRefType = cast<ArrayRefType>(op.getInput().getType());
1645 auto i64Ty = rewriter.getI64Type();
1646 size_t byteWidth = computeByteWidth(arrayRefType);
1647 Value size = LLVM::ConstantOp::create(rewriter, loc, i64Ty, byteWidth);
1648 // Use a memmove rather than a memcpy just in case the arrays alias.
1649 LLVM::MemmoveOp::create(rewriter, loc, adaptor.getInput(),
1650 adaptor.getSource(), size,
1651 /*isVolatile=*/false);
1652 rewriter.replaceOp(op, adaptor.getInput());
1653 return success();
1654 }
1655};
1656
1657static Value loadArrayRefAsArray(ImplicitLocOpBuilder &builder, Value arrayRef,
1658 ArrayRefType arrayRefType,
1659 LLVM::LLVMArrayType llvmType) {
1660 auto i8Ty = builder.getI8Type();
1661 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
1662 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1663 Value v = LLVM::PoisonOp::create(builder, llvmType);
1664 int32_t size = arrayRefType.getNumElements();
1665 for (int32_t i = 0; i < size; i++) {
1666 int32_t byteOffset = i * elemByteWidth;
1667 Value gep = LLVM::GEPOp::create(builder, ptrTy, i8Ty, arrayRef,
1668 LLVM::GEPArg{byteOffset});
1669 Value load = LLVM::LoadOp::create(builder, llvmType.getElementType(), gep);
1670 v = LLVM::InsertValueOp::create(builder, v, load, i);
1671 }
1672 return v;
1673}
1674
1675static void storeArrayAsArrayRef(ImplicitLocOpBuilder &builder, Value array,
1676 Value arrayRef, ArrayRefType arrayRefType) {
1677 auto i8Ty = builder.getI8Type();
1678 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
1679 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1680 int32_t size = arrayRefType.getNumElements();
1681 for (int32_t i = 0; i < size; i++) {
1682 int32_t byteOffset = i * elemByteWidth;
1683 Value gep = LLVM::GEPOp::create(builder, ptrTy, i8Ty, arrayRef,
1684 LLVM::GEPArg{byteOffset});
1685 Value val = LLVM::ExtractValueOp::create(builder, array, i);
1686 LLVM::StoreOp::create(builder, val, gep);
1687 }
1688}
1689
1691 : public OpConversionPattern<UnrealizedConversionCastOp> {
1692 using OpConversionPattern::OpConversionPattern;
1693
1694 LogicalResult
1695 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
1696 ConversionPatternRewriter &rewriter) const override {
1697 if (!isa<ArrayRefType>(op.getOperand(0).getType()) ||
1698 !isa<LLVM::LLVMArrayType>(op.getResult(0).getType())) {
1699 return failure();
1700 }
1701
1702 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1703 Value loaded = loadArrayRefAsArray(
1704 b, adaptor.getInputs().front(),
1705 cast<ArrayRefType>(op.getOperand(0).getType()),
1706 cast<LLVM::LLVMArrayType>(op.getResult(0).getType()));
1707 rewriter.replaceOp(op, loaded);
1708 return success();
1709 }
1710};
1711
1713 : public OpConversionPattern<ArrayRefToArrayOp> {
1714 using OpConversionPattern::OpConversionPattern;
1715
1716 LogicalResult
1717 matchAndRewrite(ArrayRefToArrayOp op, OpAdaptor adaptor,
1718 ConversionPatternRewriter &rewriter) const override {
1719 Type resultType = getTypeConverter()->convertType(op.getResult().getType());
1720 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1721 Value loaded = loadArrayRefAsArray(
1722 b, adaptor.getInput(), cast<ArrayRefType>(op.getInput().getType()),
1723 cast<LLVM::LLVMArrayType>(resultType));
1724 rewriter.replaceOp(op, loaded);
1725 return success();
1726 }
1727};
1728
1730 : public OpConversionPattern<ArrayRefFromArrayOp> {
1731 using OpConversionPattern::OpConversionPattern;
1732
1733 LogicalResult
1734 matchAndRewrite(ArrayRefFromArrayOp op, OpAdaptor adaptor,
1735 ConversionPatternRewriter &rewriter) const override {
1736 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1737 storeArrayAsArrayRef(b, adaptor.getArray(), adaptor.getInput(),
1738 cast<ArrayRefType>(op.getInput().getType()));
1739 rewriter.replaceOp(op, adaptor.getInput());
1740 return success();
1741 }
1742};
1743
1744//===----------------------------------------------------------------------===//
1745// Pass Implementation
1746//===----------------------------------------------------------------------===//
1747
1748namespace {
1749struct LowerArcToLLVMPass
1750 : public circt::impl::LowerArcToLLVMBase<LowerArcToLLVMPass> {
1751 void runOnOperation() override;
1752};
1753} // namespace
1754
/// Runs the lowering on the pass's root operation.
///
/// Steps, in order: tag `ArrayRefType` function arguments as dereferenceable,
/// collect existing symbols into a namespace, configure the conversion target
/// and the type converter, assemble the MLIR/CIRCT/Arc/Sim pattern set, and
/// finally apply a full conversion. The pass is marked failed if the
/// conversion does not succeed.
///
/// NOTE(review): this listing carries embedded line numbers and skips several
/// of them (1842, 1844, 1852-1853, 1874, 1886-1894), so some statements and
/// pattern-list entries are not visible here; consult the original file.
1755void LowerArcToLLVMPass::runOnOperation() {
1756 // Add `dereferenceable(<N>)` attributes to all function arguments that take
1757 // ArrayRefTypes.
// <N> is the byte width reported by computeByteWidth for the argument's type.
1758 for (func::FuncOp func : getOperation().getOps<func::FuncOp>()) {
1759 for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
1760 if (auto arrayRefType =
1761 dyn_cast<ArrayRefType>(func.getArgumentTypes()[i])) {
1762 size_t byteWidth = computeByteWidth(arrayRefType);
1763 Builder builder(&getContext());
1764 func.setArgAttr(i, LLVM::LLVMDialect::getDereferenceableAttrName(),
1765 builder.getI64IntegerAttr(byteWidth));
1766 }
1767 }
1768 }
1769
1770 // Collect the symbols in the root op such that the HW-to-LLVM lowering can
1771 // create LLVM globals with non-colliding names.
1772 Namespace globals;
1773 SymbolCache cache;
1774 cache.addDefinitions(getOperation());
1775 globals.add(cache);
1776
1777 // Setup the conversion target. Explicitly mark `scf.yield` legal since it
1778 // does not have a conversion itself, which would cause it to fail
1779 // legalization and for the conversion to abort. (It relies on its parent op's
1780 // conversion to remove it.)
1781 LLVMConversionTarget target(getContext());
1782 target.addLegalOp<mlir::ModuleOp>();
1783 target.addLegalOp<scf::YieldOp>(); // quirk of SCF dialect conversion
1784
1785 // Mark sim::Format*Op as legal. These are not converted to LLVM, but the
1786 // lowering of sim::PrintFormattedOp walks them to build up its format string.
1787 // They are all marked Pure so are removed after the conversion.
1788 target.addLegalOp<sim::FormatLiteralOp, sim::FormatDecOp, sim::FormatHexOp,
1789 sim::FormatBinOp, sim::FormatOctOp, sim::FormatCharOp,
1790 sim::FormatStringConcatOp>();
1791
1792 // Setup the arc dialect type conversion.
// Clocks become i1, LLHD time becomes i64, and the storage, memory, state,
// model-instance, format-string and array-ref types all become LLVM pointers.
1793 LLVMTypeConverter converter(&getContext());
1794 converter.addConversion([&](seq::ClockType type) {
1795 return IntegerType::get(type.getContext(), 1);
1796 });
1797 converter.addConversion([&](StorageType type) {
1798 return LLVM::LLVMPointerType::get(type.getContext());
1799 });
1800 converter.addConversion([&](MemoryType type) {
1801 return LLVM::LLVMPointerType::get(type.getContext());
1802 });
1803 converter.addConversion([&](StateType type) {
1804 return LLVM::LLVMPointerType::get(type.getContext());
1805 });
1806 converter.addConversion([&](SimModelInstanceType type) {
1807 return LLVM::LLVMPointerType::get(type.getContext());
1808 });
1809 converter.addConversion([&](sim::FormatStringType type) {
1810 return LLVM::LLVMPointerType::get(type.getContext());
1811 });
1812 converter.addConversion([&](llhd::TimeType type) {
1813 // LLHD time is represented as i64 femtoseconds.
1814 return IntegerType::get(type.getContext(), 64);
1815 });
1816 converter.addConversion([&](ArrayRefType type) {
1817 return LLVM::LLVMPointerType::get(type.getContext());
1818 });
1819
1820 // Convert an UnrealizedConversionCastOp from !arc.arrayref<T> to
1821 // !llvm.array<T>. These are inserted by the InsertRuntime pass.
// Only casts of exactly that shape are reported illegal; every other
// unrealized cast is treated as legal and left alone.
// NOTE(review): the callback reads operand 0 and result 0 unconditionally;
// presumably every cast reaching it has at least one of each - confirm.
1822 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>([&](Operation *op) {
1823 Type src = op->getOperand(0).getType();
1824 Type dst = op->getResult(0).getType();
1825 bool needsConvert = isa<ArrayRefType>(src) && isa<LLVM::LLVMArrayType>(dst);
1826 return !needsConvert;
1827 });
1828
1829 // Setup the conversion patterns.
1830 ConversionPatternSet patterns(&getContext(), converter);
1831
1832 // MLIR patterns.
1833 populateSCFToControlFlowConversionPatterns(patterns);
1834 populateFuncToLLVMConversionPatterns(converter, patterns);
1835 cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
1836 arith::populateArithToLLVMConversionPatterns(converter, patterns);
1837 index::populateIndexToLLVMConversionPatterns(converter, patterns);
1838 populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter);
1839
1840 // CIRCT patterns.
// NOTE(review): listing lines 1842 and 1844 are absent below, so the
// initializer of `spillCacheOpt` is not shown in this listing.
1841 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
1843 std::optional<HWToLLVMArraySpillCache> spillCacheOpt =
1845 {
1846 OpBuilder spillBuilder(getOperation());
1847 spillCacheOpt->spillNonHWOps(spillBuilder, converter, getOperation());
1848 }
1849 populateHWToLLVMConversionPatterns(converter, patterns, globals,
1850 constAggregateGlobalsMap, spillCacheOpt);
1851
// NOTE(review): listing lines 1852-1853 are absent here.
1854
1855 // Arc patterns.
// NOTE(review): listing lines 1874 and 1886-1894 are absent from the list
// below, so it does not show every registered pattern.
1856 // clang-format off
1857 patterns.add<
1858 AllocMemoryOpLowering,
1859 AllocStateLikeOpLowering<arc::AllocStateOp>,
1860 AllocStateLikeOpLowering<arc::RootInputOp>,
1861 AllocStateLikeOpLowering<arc::RootOutputOp>,
1862 AllocStorageOpLowering,
1863 ClockGateOpLowering,
1864 ClockInvOpLowering,
1865 ConstantTimeOpLowering,
1866 CurrentTimeOpLowering,
1867 GetNextWakeupOpLowering,
1868 IntToTimeOpLowering,
1869 MemoryReadOpLowering,
1870 MemoryWriteOpLowering,
1871 ModelOpLowering,
1872 ReplaceOpWithInputPattern<seq::ToClockOp>,
1873 ReplaceOpWithInputPattern<seq::FromClockOp>,
1875 SeqConstClockLowering,
1876 SetNextWakeupOpLowering,
1877 SimGetNextWakeupOpLowering,
1878 SimGetTimeOpLowering,
1879 SimSetTimeOpLowering,
1880 StateReadOpLowering,
1881 StateWriteOpLowering,
1882 StorageGetOpLowering,
1883 TerminateOpLowering,
1884 TimeToIntOpLowering,
1885 ZeroCountOpLowering,
1895 >(converter, &getContext());
1896 // clang-format on
// `arc::ExecuteOp` is lowered by the free function `convert` rather than by
// a pattern class.
1897 patterns.add<ExecuteOp>(convert);
1898
// These two Sim patterns additionally receive the shared `stringCache`.
1899 StringCache stringCache;
1900 patterns.add<SimEmitValueOpLowering, SimPrintFormattedProcOpLowering>(
1901 converter, &getContext(), stringCache);
1902
// Re-index the ModelInfoAnalysis results by model name, with each model's
// states keyed by state name, for the Sim patterns registered below.
// The loop binding `modelInfo` shadows the analysis reference of the same
// name inside the loop body.
1903 auto &modelInfo = getAnalysis<ModelInfoAnalysis>();
1904 llvm::DenseMap<StringRef, ModelInfoMap> modelMap(modelInfo.infoMap.size());
1905 for (auto &[_, modelInfo] : modelInfo.infoMap) {
1906 llvm::DenseMap<StringRef, StateInfo> states(modelInfo.states.size());
1907 for (StateInfo &stateInfo : modelInfo.states)
1908 states.insert({stateInfo.name, stateInfo});
1909 modelMap.insert(
1910 {modelInfo.name,
1911 ModelInfoMap{modelInfo.numStateBytes, std::move(states),
1912 modelInfo.initialFnSym, modelInfo.finalFnSym}});
1913 }
1914
1915 patterns.add<SimInstantiateOpLowering, SimSetInputOpLowering,
1916 SimGetPortOpLowering, SimStepOpLowering>(
1917 converter, &getContext(), modelMap);
1918
1919 // Apply the conversion.
// Pattern rollback is disabled in the conversion config.
1920 ConversionConfig config;
1921 config.allowPatternRollback = false;
1922 if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
1923 config)))
1924 signalPassFailure();
1925}
1926
1927std::unique_ptr<OperationPass<ModuleOp>> circt::createLowerArcToLLVMPass() {
1928 return std::make_unique<LowerArcToLLVMPass>();
1929}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static LLVM::GlobalOp buildGlobalConstantIntArray(OpBuilder &builder, Location loc, Twine symName, SmallVectorImpl< T > &data, unsigned alignment=alignof(T))
static LLVM::GlobalOp buildGlobalConstantRuntimeStructArray(OpBuilder &builder, Location loc, Twine symName, SmallVectorImpl< T > &array)
static Value loadArrayRefAsArray(ImplicitLocOpBuilder &builder, Value arrayRef, ArrayRefType arrayRefType, LLVM::LLVMArrayType llvmType)
size_t computeByteWidth(ArrayRefType type)
static llvm::Twine evalSymbolFromModelName(StringRef modelName)
size_t computeElementByteWidth(ArrayRefType arrayRefType)
static void storeArrayAsArrayRef(ImplicitLocOpBuilder &builder, Value array, Value arrayRef, ArrayRefType arrayRefType)
static LogicalResult convert(arc::ExecuteOp op, arc::ExecuteOp::Adaptor adaptor, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
Extension of RewritePatternSet that allows adding matchAndRewrite functions with op adaptors and Conv...
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
#define ARC_RUNTIME_API_VERSION
Version of the combined public and internal API.
Definition Common.h:27
Definition arc.py:1
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateCombToArithConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
void populateCombToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to LLVM conversion patterns.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap, std::optional< HWToLLVMArraySpillCache > &spillCacheOpt)
Get the HW to LLVM conversion patterns.
std::unique_ptr< OperationPass< ModuleOp > > createLowerArcToLLVMPass()
Definition hw.py:1
Definition sim.py:1
Static information for a compiled hardware model, generated by the MLIR lowering.
Definition Common.h:70
uint32_t typeBits
Bit width of the traced signal.
Definition TraceTaps.h:28
uint64_t stateOffset
Byte offset of the traced value within the model state.
Definition TraceTaps.h:23
uint64_t nameOffset
Byte offset to the null terminator of this signal's last alias in the names array.
Definition TraceTaps.h:26
uint32_t reserved
Padding and reserved for future use.
Definition TraceTaps.h:30
void initializeArray(ConversionPatternRewriter &rewriter, Location loc, Value alloc, ArrayAttr initAttr, ArrayRefType arrayRefType) const
size_t computeAllocaAlignment(ArrayRefType type, Operation *op) const
LogicalResult matchAndRewrite(ArrayRefAllocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
bool isZero(ArrayAttr arrayAttr) const
DenseMap< ArrayRefType, size_t > alignmentCache
LogicalResult matchAndRewrite(ArrayRefCopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefCreateOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefFromArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefGetOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefInjectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefToArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LLVM::GlobalOp buildTraceInfoStruct(arc::RuntimeModelOp &op, ConversionPatternRewriter &rewriter) const
static constexpr uint64_t runtimeApiVersion
LogicalResult matchAndRewrite(arc::RuntimeModelOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final
Helper class mapping array values (HW or LLVM Dialect) to pointers to buffers containing the array va...
Definition HWToLLVM.h:47