CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerArcToLLVM.cpp
Go to the documentation of this file.
1//===- LowerArcToLLVM.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
27#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
28#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
29#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
30#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
31#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
32#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
33#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
34#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
35#include "mlir/Dialect/Func/IR/FuncOps.h"
36#include "mlir/Dialect/Index/IR/IndexOps.h"
37#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
38#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
39#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
40#include "mlir/Dialect/SCF/IR/SCF.h"
41#include "mlir/IR/Builders.h"
42#include "mlir/IR/BuiltinDialect.h"
43#include "mlir/Pass/Pass.h"
44#include "mlir/Transforms/DialectConversion.h"
45#include "llvm/Support/Debug.h"
46#include "llvm/Support/FormatVariadic.h"
47
48#include <cstddef>
49
50#define DEBUG_TYPE "lower-arc-to-llvm"
51
52namespace circt {
53#define GEN_PASS_DEF_LOWERARCTOLLVM
54#include "circt/Conversion/Passes.h.inc"
55} // namespace circt
56
57using namespace mlir;
58using namespace circt;
59using namespace arc;
60using namespace hw;
61using namespace runtime;
62
63//===----------------------------------------------------------------------===//
64// Lowering Patterns
65//===----------------------------------------------------------------------===//
66
/// Builds the symbol name of the evaluation function generated for the model
/// called `modelName`: the model name followed by the suffix `_eval`.
///
/// NOTE(review): the result is an `llvm::Twine`, i.e. a lazy concatenation
/// that refers to the characters of `modelName` instead of owning a copy.
/// Presumably it must be consumed right away (as the callers in this file do
/// by passing it to `getStringAttr`) and never stored -- confirm before adding
/// new uses.
static llvm::Twine evalSymbolFromModelName(StringRef modelName) {
  return modelName + "_eval";
}
70
71namespace {
72
73struct ModelOpLowering : public OpConversionPattern<arc::ModelOp> {
74 using OpConversionPattern::OpConversionPattern;
75 LogicalResult
76 matchAndRewrite(arc::ModelOp op, OpAdaptor adaptor,
77 ConversionPatternRewriter &rewriter) const final {
78 {
79 IRRewriter::InsertionGuard guard(rewriter);
80 rewriter.setInsertionPointToEnd(&op.getBodyBlock());
81 func::ReturnOp::create(rewriter, op.getLoc());
82 }
83 auto funcName =
84 rewriter.getStringAttr(evalSymbolFromModelName(op.getName()));
85 auto funcType =
86 rewriter.getFunctionType(op.getBody().getArgumentTypes(), {});
87 auto func =
88 mlir::func::FuncOp::create(rewriter, op.getLoc(), funcName, funcType);
89 rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
90 rewriter.eraseOp(op);
91 return success();
92 }
93};
94
95struct AllocStorageOpLowering
96 : public OpConversionPattern<arc::AllocStorageOp> {
97 using OpConversionPattern::OpConversionPattern;
98 LogicalResult
99 matchAndRewrite(arc::AllocStorageOp op, OpAdaptor adaptor,
100 ConversionPatternRewriter &rewriter) const final {
101 auto type = typeConverter->convertType(op.getType());
102 if (!op.getOffset().has_value())
103 return failure();
104 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, type, rewriter.getI8Type(),
105 adaptor.getInput(),
106 LLVM::GEPArg(*op.getOffset()));
107 return success();
108 }
109};
110
111template <class ConcreteOp>
112struct AllocStateLikeOpLowering : public OpConversionPattern<ConcreteOp> {
114 using OpConversionPattern<ConcreteOp>::typeConverter;
115 using OpAdaptor = typename ConcreteOp::Adaptor;
116
117 LogicalResult
118 matchAndRewrite(ConcreteOp op, OpAdaptor adaptor,
119 ConversionPatternRewriter &rewriter) const final {
120 // Get a pointer to the correct offset in the storage.
121 auto offsetAttr = op->template getAttrOfType<IntegerAttr>("offset");
122 if (!offsetAttr)
123 return failure();
124 Value ptr = LLVM::GEPOp::create(
125 rewriter, op->getLoc(), adaptor.getStorage().getType(),
126 rewriter.getI8Type(), adaptor.getStorage(),
127 LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
128 rewriter.replaceOp(op, ptr);
129 return success();
130 }
131};
132
133struct StateReadOpLowering : public OpConversionPattern<arc::StateReadOp> {
134 using OpConversionPattern::OpConversionPattern;
135 LogicalResult
136 matchAndRewrite(arc::StateReadOp op, OpAdaptor adaptor,
137 ConversionPatternRewriter &rewriter) const final {
138 // Loading an ArrayRef is a no-op as ArrayRefs are accessed by reference.
139 if (isa<ArrayRefType>(op.getType())) {
140 rewriter.replaceOp(op, adaptor.getState());
141 return success();
142 }
143
144 auto type = typeConverter->convertType(op.getType());
145 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, type, adaptor.getState());
146 return success();
147 }
148};
149
150struct StateWriteOpLowering : public OpConversionPattern<arc::StateWriteOp> {
151 using OpConversionPattern::OpConversionPattern;
152 LogicalResult
153 matchAndRewrite(arc::StateWriteOp op, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter) const final {
155 if (!isa<ArrayRefType>(op.getValue().getType())) {
156 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
157 adaptor.getState());
158 return success();
159 }
160
161 int numBytes = op.getState().getType().getByteWidth();
162 Value size = LLVM::ConstantOp::create(rewriter, op.getLoc(),
163 rewriter.getI64Type(), numBytes);
164 rewriter.replaceOpWithNewOp<LLVM::MemcpyOp>(
165 op, adaptor.getState(), adaptor.getValue(), size, /*volatile=*/false);
166 return success();
167 }
168};
169
170//===----------------------------------------------------------------------===//
171// Time Operations Lowering
172//===----------------------------------------------------------------------===//
173
174struct CurrentTimeOpLowering : public OpConversionPattern<arc::CurrentTimeOp> {
175 using OpConversionPattern::OpConversionPattern;
176 LogicalResult
177 matchAndRewrite(arc::CurrentTimeOp op, OpAdaptor adaptor,
178 ConversionPatternRewriter &rewriter) const final {
179 // Time is stored at offset 0 in storage (no offset needed).
180 Value ptr = adaptor.getStorage();
181 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, rewriter.getI64Type(), ptr);
182 return success();
183 }
184};
185
186// Lower `llhd.constant_time` to an `i64` LLVM constant holding the time in
187// femtoseconds. Time attributes with non-zero delta or epsilon, units smaller
188// than `fs`, or values that overflow `i64` femtoseconds are rejected.
189struct ConstantTimeOpLowering
190 : public OpConversionPattern<llhd::ConstantTimeOp> {
191 using OpConversionPattern::OpConversionPattern;
192 LogicalResult
193 matchAndRewrite(llhd::ConstantTimeOp op, OpAdaptor adaptor,
194 ConversionPatternRewriter &rewriter) const final {
195 auto attr = op.getValue();
196 if (attr.getDelta() != 0 || attr.getEpsilon() != 0)
197 return rewriter.notifyMatchFailure(
198 op, "non-zero delta or epsilon time components are not supported");
199 uint64_t value = attr.getTime();
200 StringRef unit = attr.getTimeUnit();
201 uint64_t scale;
202 if (unit == "fs")
203 scale = 1;
204 else if (unit == "ps")
205 scale = 1'000ULL;
206 else if (unit == "ns")
207 scale = 1'000'000ULL;
208 else if (unit == "us")
209 scale = 1'000'000'000ULL;
210 else if (unit == "ms")
211 scale = 1'000'000'000'000ULL;
212 else if (unit == "s")
213 scale = 1'000'000'000'000'000ULL;
214 else
215 return rewriter.notifyMatchFailure(
216 op, "time units smaller than `fs` are not supported");
217 if (value > std::numeric_limits<uint64_t>::max() / scale)
218 return rewriter.notifyMatchFailure(
219 op, "time value does not fit into `i64` femtoseconds");
220 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, rewriter.getI64Type(),
221 value * scale);
222 return success();
223 }
224};
225
226// `llhd.int_to_time` is a no-op
227struct IntToTimeOpLowering : public OpConversionPattern<llhd::IntToTimeOp> {
228 using OpConversionPattern::OpConversionPattern;
229 LogicalResult
230 matchAndRewrite(llhd::IntToTimeOp op, OpAdaptor adaptor,
231 ConversionPatternRewriter &rewriter) const final {
232 rewriter.replaceOp(op, adaptor.getInput());
233 return success();
234 }
235};
236
237// `llhd.time_to_int` is a no-op
238struct TimeToIntOpLowering : public OpConversionPattern<llhd::TimeToIntOp> {
239 using OpConversionPattern::OpConversionPattern;
240 LogicalResult
241 matchAndRewrite(llhd::TimeToIntOp op, OpAdaptor adaptor,
242 ConversionPatternRewriter &rewriter) const final {
243 rewriter.replaceOp(op, adaptor.getInput());
244 return success();
245 }
246};
247
248//===----------------------------------------------------------------------===//
249// Memory and Storage Lowering
250//===----------------------------------------------------------------------===//
251
252struct AllocMemoryOpLowering : public OpConversionPattern<arc::AllocMemoryOp> {
253 using OpConversionPattern::OpConversionPattern;
254 LogicalResult
255 matchAndRewrite(arc::AllocMemoryOp op, OpAdaptor adaptor,
256 ConversionPatternRewriter &rewriter) const final {
257 auto offsetAttr = op->getAttrOfType<IntegerAttr>("offset");
258 if (!offsetAttr)
259 return failure();
260 Value ptr = LLVM::GEPOp::create(
261 rewriter, op.getLoc(), adaptor.getStorage().getType(),
262 rewriter.getI8Type(), adaptor.getStorage(),
263 LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
264
265 rewriter.replaceOp(op, ptr);
266 return success();
267 }
268};
269
270struct StorageGetOpLowering : public OpConversionPattern<arc::StorageGetOp> {
271 using OpConversionPattern::OpConversionPattern;
272 LogicalResult
273 matchAndRewrite(arc::StorageGetOp op, OpAdaptor adaptor,
274 ConversionPatternRewriter &rewriter) const final {
275 Value offset = LLVM::ConstantOp::create(
276 rewriter, op.getLoc(), rewriter.getI32Type(), op.getOffsetAttr());
277 Value ptr = LLVM::GEPOp::create(
278 rewriter, op.getLoc(), adaptor.getStorage().getType(),
279 rewriter.getI8Type(), adaptor.getStorage(), offset);
280 rewriter.replaceOp(op, ptr);
281 return success();
282 }
283};
284
/// Result of `prepareMemoryAccess`: everything a memory read or write needs
/// in order to access a single word of a memory.
struct MemoryAccess {
  /// Pointer to the addressed word inside the memory.
  Value ptr;
  /// Result of the unsigned "address < number of words" comparison; the
  /// access is only performed when this value is true.
  Value withinBounds;
};
289
290static MemoryAccess prepareMemoryAccess(Location loc, Value memory,
291 Value address, MemoryType type,
292 ConversionPatternRewriter &rewriter) {
293 auto zextAddrType = rewriter.getIntegerType(
294 cast<IntegerType>(address.getType()).getWidth() + 1);
295 Value addr = LLVM::ZExtOp::create(rewriter, loc, zextAddrType, address);
296 Value addrLimit =
297 LLVM::ConstantOp::create(rewriter, loc, zextAddrType,
298 rewriter.getI32IntegerAttr(type.getNumWords()));
299 Value withinBounds = LLVM::ICmpOp::create(
300 rewriter, loc, LLVM::ICmpPredicate::ult, addr, addrLimit);
301 Value ptr = LLVM::GEPOp::create(
302 rewriter, loc, LLVM::LLVMPointerType::get(memory.getContext()),
303 rewriter.getIntegerType(type.getStride() * 8), memory, ValueRange{addr});
304 return {ptr, withinBounds};
305}
306
307struct MemoryReadOpLowering : public OpConversionPattern<arc::MemoryReadOp> {
308 using OpConversionPattern::OpConversionPattern;
309 LogicalResult
310 matchAndRewrite(arc::MemoryReadOp op, OpAdaptor adaptor,
311 ConversionPatternRewriter &rewriter) const final {
312 auto type = typeConverter->convertType(op.getType());
313 auto memoryType = cast<MemoryType>(op.getMemory().getType());
314 auto access =
315 prepareMemoryAccess(op.getLoc(), adaptor.getMemory(),
316 adaptor.getAddress(), memoryType, rewriter);
317
318 // Only attempt to read the memory if the address is within bounds,
319 // otherwise produce a zero value.
320 rewriter.replaceOpWithNewOp<scf::IfOp>(
321 op, access.withinBounds,
322 [&](auto &builder, auto loc) {
323 Value loadOp = LLVM::LoadOp::create(
324 builder, loc, memoryType.getWordType(), access.ptr);
325 scf::YieldOp::create(builder, loc, loadOp);
326 },
327 [&](auto &builder, auto loc) {
328 Value zeroValue = LLVM::ConstantOp::create(
329 builder, loc, type, builder.getI64IntegerAttr(0));
330 scf::YieldOp::create(builder, loc, zeroValue);
331 });
332 return success();
333 }
334};
335
336struct MemoryWriteOpLowering : public OpConversionPattern<arc::MemoryWriteOp> {
337 using OpConversionPattern::OpConversionPattern;
338 LogicalResult
339 matchAndRewrite(arc::MemoryWriteOp op, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter) const final {
341 auto access = prepareMemoryAccess(
342 op.getLoc(), adaptor.getMemory(), adaptor.getAddress(),
343 cast<MemoryType>(op.getMemory().getType()), rewriter);
344 auto enable = access.withinBounds;
345
346 // Only attempt to write the memory if the address is within bounds.
347 rewriter.replaceOpWithNewOp<scf::IfOp>(
348 op, enable, [&](auto &builder, auto loc) {
349 LLVM::StoreOp::create(builder, loc, adaptor.getData(), access.ptr);
350 scf::YieldOp::create(builder, loc);
351 });
352 return success();
353 }
354};
355
356/// A dummy lowering for clock gates to an AND gate.
357struct ClockGateOpLowering : public OpConversionPattern<seq::ClockGateOp> {
358 using OpConversionPattern::OpConversionPattern;
359 LogicalResult
360 matchAndRewrite(seq::ClockGateOp op, OpAdaptor adaptor,
361 ConversionPatternRewriter &rewriter) const final {
362 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, adaptor.getInput(),
363 adaptor.getEnable());
364 return success();
365 }
366};
367
368/// Lower 'seq.clock_inv x' to 'llvm.xor x true'
369struct ClockInvOpLowering : public OpConversionPattern<seq::ClockInverterOp> {
370 using OpConversionPattern::OpConversionPattern;
371 LogicalResult
372 matchAndRewrite(seq::ClockInverterOp op, OpAdaptor adaptor,
373 ConversionPatternRewriter &rewriter) const final {
374 auto constTrue = LLVM::ConstantOp::create(rewriter, op->getLoc(),
375 rewriter.getI1Type(), 1);
376 rewriter.replaceOpWithNewOp<LLVM::XOrOp>(op, adaptor.getInput(), constTrue);
377 return success();
378 }
379};
380
381struct ZeroCountOpLowering : public OpConversionPattern<arc::ZeroCountOp> {
382 using OpConversionPattern::OpConversionPattern;
383 LogicalResult
384 matchAndRewrite(arc::ZeroCountOp op, OpAdaptor adaptor,
385 ConversionPatternRewriter &rewriter) const override {
386 // Use poison when input is zero.
387 IntegerAttr isZeroPoison = rewriter.getBoolAttr(true);
388
389 if (op.getPredicate() == arc::ZeroCountPredicate::leading) {
390 rewriter.replaceOpWithNewOp<LLVM::CountLeadingZerosOp>(
391 op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
392 return success();
393 }
394
395 rewriter.replaceOpWithNewOp<LLVM::CountTrailingZerosOp>(
396 op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
397 return success();
398 }
399};
400
401struct SeqConstClockLowering : public OpConversionPattern<seq::ConstClockOp> {
402 using OpConversionPattern::OpConversionPattern;
403 LogicalResult
404 matchAndRewrite(seq::ConstClockOp op, OpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter) const override {
406 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
407 op, rewriter.getI1Type(), static_cast<int64_t>(op.getValue()));
408 return success();
409 }
410};
411
412template <typename OpTy>
413struct ReplaceOpWithInputPattern : public OpConversionPattern<OpTy> {
415 using OpAdaptor = typename OpTy::Adaptor;
416 LogicalResult
417 matchAndRewrite(OpTy op, OpAdaptor adaptor,
418 ConversionPatternRewriter &rewriter) const override {
419 rewriter.replaceOp(op, adaptor.getInput());
420 return success();
421 }
422};
423
424} // namespace
425
426//===----------------------------------------------------------------------===//
427// Simulation Orchestration Lowering Patterns
428//===----------------------------------------------------------------------===//
429
430namespace {
431
/// Per-model information used by the simulation orchestration patterns.
struct ModelInfoMap {
  /// Size in bytes of the model's state; used as the allocation and memset
  /// size when an instance is created without the runtime library.
  size_t numStateBytes;
  /// State information keyed by name; the port accessors look up their port
  /// here to find its byte offset inside the instance state.
  llvm::DenseMap<StringRef, StateInfo> states;
  /// Symbol of the model's 'initial' function; null if there is none.
  mlir::FlatSymbolRefAttr initialFnSymbol;
  /// Symbol of the model's 'final' function; null if there is none.
  mlir::FlatSymbolRefAttr finalFnSymbol;
};
438
439template <typename OpTy>
440struct ModelAwarePattern : public OpConversionPattern<OpTy> {
441 ModelAwarePattern(const TypeConverter &typeConverter, MLIRContext *context,
442 llvm::DenseMap<StringRef, ModelInfoMap> &modelInfo)
443 : OpConversionPattern<OpTy>(typeConverter, context),
444 modelInfo(modelInfo) {}
445
446protected:
447 Value createPtrToPortState(ConversionPatternRewriter &rewriter, Location loc,
448 Value state, const StateInfo &port) const {
449 MLIRContext *ctx = rewriter.getContext();
450 return LLVM::GEPOp::create(rewriter, loc, LLVM::LLVMPointerType::get(ctx),
451 IntegerType::get(ctx, 8), state,
452 LLVM::GEPArg(port.offset));
453 }
454
455 llvm::DenseMap<StringRef, ModelInfoMap> &modelInfo;
456};
457
458/// Lowers SimInstantiateOp to a malloc and memset call. This pattern will
459/// mutate the global module.
460struct SimInstantiateOpLowering
461 : public ModelAwarePattern<arc::SimInstantiateOp> {
462 using ModelAwarePattern::ModelAwarePattern;
463
464 LogicalResult
465 matchAndRewrite(arc::SimInstantiateOp op, OpAdaptor adaptor,
466 ConversionPatternRewriter &rewriter) const final {
467 auto modelIt = modelInfo.find(
468 cast<SimModelInstanceType>(op.getBody().getArgument(0).getType())
469 .getModel()
470 .getValue());
471 ModelInfoMap &model = modelIt->second;
472
473 bool useRuntime = op.getRuntimeModel().has_value();
474
475 ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
476 if (!moduleOp)
477 return failure();
478
479 ConversionPatternRewriter::InsertionGuard guard(rewriter);
480
481 // FIXME: like the rest of MLIR, this assumes sizeof(intptr_t) ==
482 // sizeof(size_t) on the target architecture.
483 Type convertedIndex = typeConverter->convertType(rewriter.getIndexType());
484 Location loc = op.getLoc();
485 Value allocated;
486
487 if (useRuntime) {
488 // The instance is using the runtime library
489 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
490
491 Value runtimeArgs;
492 // If present, materialize the runtime argument string on the stack
493 if (op.getRuntimeArgs().has_value()) {
494 SmallVector<int8_t> argStringVec(op.getRuntimeArgsAttr().begin(),
495 op.getRuntimeArgsAttr().end());
496 argStringVec.push_back('\0');
497 auto strAttr = mlir::DenseElementsAttr::get(
498 mlir::RankedTensorType::get({(int64_t)argStringVec.size()},
499 rewriter.getI8Type()),
500 llvm::ArrayRef(argStringVec));
501
502 auto arrayCst = LLVM::ConstantOp::create(
503 rewriter, loc,
504 LLVM::LLVMArrayType::get(rewriter.getI8Type(), argStringVec.size()),
505 strAttr);
506 auto cst1 = LLVM::ConstantOp::create(rewriter, loc,
507 rewriter.getI32IntegerAttr(1));
508 runtimeArgs = LLVM::AllocaOp::create(rewriter, loc, ptrTy,
509 arrayCst.getType(), cst1);
510 LLVM::LifetimeStartOp::create(rewriter, loc, runtimeArgs);
511 LLVM::StoreOp::create(rewriter, loc, arrayCst, runtimeArgs);
512 } else {
513 runtimeArgs = LLVM::ZeroOp::create(rewriter, loc, ptrTy).getResult();
514 }
515 // Call the state allocation function
516 auto rtModelPtr = LLVM::AddressOfOp::create(rewriter, loc, ptrTy,
517 op.getRuntimeModelAttr())
518 .getResult();
519 allocated =
520 LLVM::CallOp::create(rewriter, loc, {ptrTy},
521 runtime::APICallbacks::symNameAllocInstance,
522 {rtModelPtr, runtimeArgs})
523 .getResult();
524
525 if (op.getRuntimeArgs().has_value())
526 LLVM::LifetimeEndOp::create(rewriter, loc, runtimeArgs);
527
528 } else {
529 // The instance is not using the runtime library
530 FailureOr<LLVM::LLVMFuncOp> mallocFunc =
531 LLVM::lookupOrCreateMallocFn(rewriter, moduleOp, convertedIndex);
532 if (failed(mallocFunc))
533 return mallocFunc;
534
535 Value numStateBytes = LLVM::ConstantOp::create(
536 rewriter, loc, convertedIndex, model.numStateBytes);
537 allocated = LLVM::CallOp::create(rewriter, loc, mallocFunc.value(),
538 ValueRange{numStateBytes})
539 .getResult();
540 Value zero =
541 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI8Type(), 0);
542 LLVM::MemsetOp::create(rewriter, loc, allocated, zero, numStateBytes,
543 false);
544 }
545
546 // Call the model's 'initial' function if present.
547 if (model.initialFnSymbol) {
548 auto initialFnType = LLVM::LLVMFunctionType::get(
549 LLVM::LLVMVoidType::get(op.getContext()),
550 {LLVM::LLVMPointerType::get(op.getContext())});
551 LLVM::CallOp::create(rewriter, loc, initialFnType, model.initialFnSymbol,
552 ValueRange{allocated});
553 }
554
555 // Call the runtime's 'onInitialized' function if present.
556 if (useRuntime)
557 LLVM::CallOp::create(rewriter, loc, TypeRange{},
558 runtime::APICallbacks::symNameOnInitialized,
559 {allocated});
560
561 // Execute the body.
562 rewriter.inlineBlockBefore(&adaptor.getBody().getBlocks().front(), op,
563 {allocated});
564
565 // Call the model's 'final' function if present.
566 if (model.finalFnSymbol) {
567 auto finalFnType = LLVM::LLVMFunctionType::get(
568 LLVM::LLVMVoidType::get(op.getContext()),
569 {LLVM::LLVMPointerType::get(op.getContext())});
570 LLVM::CallOp::create(rewriter, loc, finalFnType, model.finalFnSymbol,
571 ValueRange{allocated});
572 }
573
574 if (useRuntime) {
575 LLVM::CallOp::create(rewriter, loc, TypeRange{},
576 runtime::APICallbacks::symNameDeleteInstance,
577 {allocated});
578 } else {
579 FailureOr<LLVM::LLVMFuncOp> freeFunc =
580 LLVM::lookupOrCreateFreeFn(rewriter, moduleOp);
581 if (failed(freeFunc))
582 return freeFunc;
583
584 LLVM::CallOp::create(rewriter, loc, freeFunc.value(),
585 ValueRange{allocated});
586 }
587
588 rewriter.eraseOp(op);
589 return success();
590 }
591};
592
593struct SimSetInputOpLowering : public ModelAwarePattern<arc::SimSetInputOp> {
594 using ModelAwarePattern::ModelAwarePattern;
595
596 LogicalResult
597 matchAndRewrite(arc::SimSetInputOp op, OpAdaptor adaptor,
598 ConversionPatternRewriter &rewriter) const final {
599 auto modelIt =
600 modelInfo.find(cast<SimModelInstanceType>(op.getInstance().getType())
601 .getModel()
602 .getValue());
603 ModelInfoMap &model = modelIt->second;
604
605 auto portIt = model.states.find(op.getInput());
606 if (portIt == model.states.end()) {
607 // If the port is not found in the state, it means the model does not
608 // actually use it. Thus this operation is a no-op.
609 rewriter.eraseOp(op);
610 return success();
611 }
612
613 StateInfo &port = portIt->second;
614 Value statePtr = createPtrToPortState(rewriter, op.getLoc(),
615 adaptor.getInstance(), port);
616 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
617 statePtr);
618
619 return success();
620 }
621};
622
623struct SimGetPortOpLowering : public ModelAwarePattern<arc::SimGetPortOp> {
624 using ModelAwarePattern::ModelAwarePattern;
625
626 LogicalResult
627 matchAndRewrite(arc::SimGetPortOp op, OpAdaptor adaptor,
628 ConversionPatternRewriter &rewriter) const final {
629 auto modelIt =
630 modelInfo.find(cast<SimModelInstanceType>(op.getInstance().getType())
631 .getModel()
632 .getValue());
633 ModelInfoMap &model = modelIt->second;
634
635 auto type = typeConverter->convertType(op.getValue().getType());
636 if (!type)
637 return failure();
638 auto portIt = model.states.find(op.getPort());
639 if (portIt == model.states.end()) {
640 // If the port is not found in the state, it means the model does not
641 // actually set it. Thus this operation returns 0.
642 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, type, 0);
643 return success();
644 }
645
646 StateInfo &port = portIt->second;
647 Value statePtr = createPtrToPortState(rewriter, op.getLoc(),
648 adaptor.getInstance(), port);
649 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, type, statePtr);
650
651 return success();
652 }
653};
654
655struct SimStepOpLowering : public ModelAwarePattern<arc::SimStepOp> {
656 using ModelAwarePattern::ModelAwarePattern;
657
658 LogicalResult
659 matchAndRewrite(arc::SimStepOp op, OpAdaptor adaptor,
660 ConversionPatternRewriter &rewriter) const final {
661 StringRef modelName = cast<SimModelInstanceType>(op.getInstance().getType())
662 .getModel()
663 .getValue();
664
665 if (adaptor.getTimePostIncrement()) {
666 // Increment time after step
667 OpBuilder::InsertionGuard g(rewriter);
668 rewriter.setInsertionPointAfter(op);
669 auto oldTime =
670 arc::SimGetTimeOp::create(rewriter, op.getLoc(), op.getInstance());
671 auto newTime = LLVM::AddOp::create(rewriter, op.getLoc(), oldTime,
672 adaptor.getTimePostIncrement());
673 arc::SimSetTimeOp::create(rewriter, op.getLoc(), op.getInstance(),
674 newTime);
675 }
676
677 StringAttr evalFunc =
678 rewriter.getStringAttr(evalSymbolFromModelName(modelName));
679 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, mlir::TypeRange(), evalFunc,
680 adaptor.getInstance());
681
682 return success();
683 }
684};
685
686// Loads the simulation time (i64 femtoseconds) from byte offset 0 in the
687// model instance's state storage.
688struct SimGetTimeOpLowering : public OpConversionPattern<arc::SimGetTimeOp> {
689 using OpConversionPattern::OpConversionPattern;
690
691 LogicalResult
692 matchAndRewrite(arc::SimGetTimeOp op, OpAdaptor adaptor,
693 ConversionPatternRewriter &rewriter) const final {
694 // Time is stored at offset 0 in the instance storage.
695 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, rewriter.getI64Type(),
696 adaptor.getInstance());
697 return success();
698 }
699};
700
701// Stores the simulation time (i64 femtoseconds) to byte offset 0 in the
702// model instance's state storage.
703struct SimSetTimeOpLowering : public OpConversionPattern<arc::SimSetTimeOp> {
704 using OpConversionPattern::OpConversionPattern;
705
706 LogicalResult
707 matchAndRewrite(arc::SimSetTimeOp op, OpAdaptor adaptor,
708 ConversionPatternRewriter &rewriter) const final {
709 // Time is stored at offset 0 in the instance storage.
710 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getTime(),
711 adaptor.getInstance());
712 return success();
713 }
714};
715
716// Loads the next wakeup time (i64 femtoseconds) from `kNextWakeupOffset` of
717// the model instance's state storage.
718struct SimGetNextWakeupOpLowering
719 : public OpConversionPattern<arc::SimGetNextWakeupOp> {
720 using OpConversionPattern::OpConversionPattern;
721
722 LogicalResult
723 matchAndRewrite(arc::SimGetNextWakeupOp op, OpAdaptor adaptor,
724 ConversionPatternRewriter &rewriter) const final {
725 auto loc = op.getLoc();
726 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
727 Value slotPtr = LLVM::GEPOp::create(
728 rewriter, loc, ptrType, rewriter.getI8Type(), adaptor.getInstance(),
729 ArrayRef<LLVM::GEPArg>{arc::kNextWakeupOffset});
730 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, rewriter.getI64Type(),
731 slotPtr);
732 return success();
733 }
734};
735
736// Global string constants in the module.
737class StringCache {
738public:
739 Value getOrCreate(OpBuilder &b, StringRef formatStr) {
740 auto it = cache.find(formatStr);
741 if (it != cache.end()) {
742 return LLVM::AddressOfOp::create(b, b.getUnknownLoc(), it->second);
743 }
744
745 Location loc = b.getUnknownLoc();
746 LLVM::GlobalOp global;
747 {
748 OpBuilder::InsertionGuard guard(b);
749 ModuleOp m =
750 b.getInsertionBlock()->getParent()->getParentOfType<ModuleOp>();
751 b.setInsertionPointToStart(m.getBody());
752
753 SmallVector<char> strVec(formatStr.begin(), formatStr.end());
754 strVec.push_back(0);
755
756 auto name = llvm::formatv("_arc_str_{0}", cache.size()).str();
757 auto globalType = LLVM::LLVMArrayType::get(b.getI8Type(), strVec.size());
758 global = LLVM::GlobalOp::create(b, loc, globalType, /*isConstant=*/true,
759 LLVM::Linkage::Internal,
760 /*name=*/name, b.getStringAttr(strVec),
761 /*alignment=*/0);
762 }
763
764 cache[formatStr] = global;
765 return LLVM::AddressOfOp::create(b, loc, global);
766 }
767
768private:
769 llvm::StringMap<LLVM::GlobalOp> cache;
770};
771
772FailureOr<LLVM::CallOp> emitPrintfCall(OpBuilder &builder, Location loc,
773 StringCache &cache, StringRef formatStr,
774 ValueRange args) {
775 ModuleOp moduleOp =
776 builder.getInsertionBlock()->getParent()->getParentOfType<ModuleOp>();
777 // Lookup or create printf function symbol.
778 MLIRContext *ctx = builder.getContext();
779 auto printfFunc = LLVM::lookupOrCreateFn(builder, moduleOp, "printf",
780 LLVM::LLVMPointerType::get(ctx),
781 LLVM::LLVMVoidType::get(ctx), true);
782 if (failed(printfFunc))
783 return printfFunc;
784
785 Value formatStrPtr = cache.getOrCreate(builder, formatStr);
786 SmallVector<Value> argsVec(1, formatStrPtr);
787 argsVec.append(args.begin(), args.end());
788 return LLVM::CallOp::create(builder, loc, printfFunc.value(), argsVec);
789}
790
791/// Lowers SimEmitValueOp to a printf call. The integer will be printed in its
792/// entirety if it is of size up to size_t, and explicitly truncated otherwise.
793/// This pattern will mutate the global module.
794struct SimEmitValueOpLowering
795 : public OpConversionPattern<arc::SimEmitValueOp> {
796 SimEmitValueOpLowering(const TypeConverter &typeConverter,
797 MLIRContext *context, StringCache &formatStringCache)
798 : OpConversionPattern(typeConverter, context),
799 formatStringCache(formatStringCache) {}
800
801 LogicalResult
802 matchAndRewrite(arc::SimEmitValueOp op, OpAdaptor adaptor,
803 ConversionPatternRewriter &rewriter) const final {
804 auto valueType = dyn_cast<IntegerType>(adaptor.getValue().getType());
805 if (!valueType)
806 return failure();
807
808 Location loc = op.getLoc();
809
810 ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
811 if (!moduleOp)
812 return failure();
813
814 SmallVector<Value> printfVariadicArgs;
815 SmallString<16> printfFormatStr;
816 int remainingBits = valueType.getWidth();
817 Value value = adaptor.getValue();
818
819 // Assumes the target platform uses 64bit for long long ints (%llx
820 // formatter).
821 constexpr llvm::StringRef intFormatter = "llx";
822 auto intType = IntegerType::get(getContext(), 64);
823 Value shiftValue = LLVM::ConstantOp::create(
824 rewriter, loc, rewriter.getIntegerAttr(valueType, intType.getWidth()));
825
826 if (valueType.getWidth() < intType.getWidth()) {
827 int width = llvm::divideCeil(valueType.getWidth(), 4);
828 printfFormatStr = llvm::formatv("%0{0}{1}", width, intFormatter);
829 printfVariadicArgs.push_back(
830 LLVM::ZExtOp::create(rewriter, loc, intType, value));
831 } else {
832 // Process the value in 64 bit chunks, starting from the least significant
833 // bits. Since we append chunks in low-to-high order, we reverse the
834 // vector to print them in the correct high-to-low order.
835 int otherChunkWidth = intType.getWidth() / 4;
836 int firstChunkWidth =
837 llvm::divideCeil(valueType.getWidth() % intType.getWidth(), 4);
838 if (firstChunkWidth == 0) { // print the full 64-bit hex or a subset.
839 firstChunkWidth = otherChunkWidth;
840 }
841
842 std::string firstChunkFormat =
843 llvm::formatv("%0{0}{1}", firstChunkWidth, intFormatter);
844 std::string otherChunkFormat =
845 llvm::formatv("%0{0}{1}", otherChunkWidth, intFormatter);
846
847 for (int i = 0; remainingBits > 0; ++i) {
848 // Append 64-bit chunks to the printf arguments, in low-to-high
849 // order. The integer is printed in hex format with zero padding.
850 printfVariadicArgs.push_back(
851 LLVM::TruncOp::create(rewriter, loc, intType, value));
852
853 // Zero-padded format specifier for fixed width, e.g. %01llx for 4 bits.
854 printfFormatStr.append(i == 0 ? firstChunkFormat : otherChunkFormat);
855
856 value =
857 LLVM::LShrOp::create(rewriter, loc, value, shiftValue).getResult();
858 remainingBits -= intType.getWidth();
859 }
860 }
861
862 std::reverse(printfVariadicArgs.begin(), printfVariadicArgs.end());
863
864 SmallString<16> formatStr = adaptor.getValueName();
865 formatStr.append(" = ");
866 formatStr.append(printfFormatStr);
867 formatStr.append("\n");
868
869 auto callOp = emitPrintfCall(rewriter, op->getLoc(), formatStringCache,
870 formatStr, printfVariadicArgs);
871 if (failed(callOp))
872 return failure();
873 rewriter.replaceOp(op, *callOp);
874
875 return success();
876 }
877
878 StringCache &formatStringCache;
879};
880
881//===----------------------------------------------------------------------===//
882// `sim` dialect lowerings
883//===----------------------------------------------------------------------===//
884
885// Helper struct to hold the format string and arguments for arcRuntimeFormat.
886struct FormatInfo {
887 SmallVector<FmtDescriptor> descriptors;
888 SmallVector<Value> args;
889};
890
891// Copies the given integer value into an alloca, returning a pointer to it.
892//
893// The alloca is rounded up to a 64-bit boundary and is written as little-endian
894// words of size 64-bits, to be compatible with the constructor of APInt.
895static Value reg2mem(ConversionPatternRewriter &rewriter, Location loc,
896 Value value) {
897 // Round up the type size to a 64-bit boundary.
898 int64_t origBitwidth = cast<IntegerType>(value.getType()).getWidth();
899 int64_t bitwidth = llvm::divideCeil(origBitwidth, 64) * 64;
900 int64_t numWords = bitwidth / 64;
901
902 // Create an alloca for the rounded up type.
903 LLVM::ConstantOp alloca_size =
904 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), numWords);
905 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
906 auto allocaOp = LLVM::AllocaOp::create(rewriter, loc, ptrType,
907 rewriter.getI64Type(), alloca_size);
908 LLVM::LifetimeStartOp::create(rewriter, loc, allocaOp);
909
910 // Copy `value` into the alloca, 64-bits at a time from the least significant
911 // bits first.
912 for (int64_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
913 Value cst = LLVM::ConstantOp::create(
914 rewriter, loc, rewriter.getIntegerType(origBitwidth), wordIdx * 64);
915 Value v = LLVM::LShrOp::create(rewriter, loc, value, cst);
916 if (origBitwidth > 64) {
917 v = LLVM::TruncOp::create(rewriter, loc, rewriter.getI64Type(), v);
918 } else if (origBitwidth < 64) {
919 v = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), v);
920 }
921 Value gep = LLVM::GEPOp::create(rewriter, loc, ptrType,
922 rewriter.getI64Type(), allocaOp, {wordIdx});
923 LLVM::StoreOp::create(rewriter, loc, v, gep);
924 }
925
926 return allocaOp;
927}
928
929// Statically folds a value of type sim::FormatStringType to a FormatInfo.
930static FailureOr<FormatInfo>
931foldFormatString(ConversionPatternRewriter &rewriter, Value fstringValue,
932 StringCache &cache) {
933 Operation *op = fstringValue.getDefiningOp();
934 return llvm::TypeSwitch<Operation *, FailureOr<FormatInfo>>(op)
935 .Case<sim::FormatCharOp>(
936 [&](sim::FormatCharOp op) -> FailureOr<FormatInfo> {
937 FmtDescriptor d = FmtDescriptor::createChar();
938 return FormatInfo{{d}, {op.getValue()}};
939 })
940 .Case<sim::FormatDecOp>([&](sim::FormatDecOp op)
941 -> FailureOr<FormatInfo> {
942 FmtDescriptor d = FmtDescriptor::createInt(
943 op.getValue().getType().getWidth(), 10, op.getIsLeftAligned(),
944 op.getSpecifierWidth().value_or(-1), op.getPaddingChar(), false,
945 op.getIsSigned());
946 return FormatInfo{{d}, {reg2mem(rewriter, op.getLoc(), op.getValue())}};
947 })
948 .Case<sim::FormatHexOp>([&](sim::FormatHexOp op)
949 -> FailureOr<FormatInfo> {
950 FmtDescriptor d = FmtDescriptor::createInt(
951 op.getValue().getType().getWidth(), 16, op.getIsLeftAligned(),
952 op.getSpecifierWidth().value_or(-1), op.getPaddingChar(),
953 op.getIsHexUppercase(), false);
954 return FormatInfo{{d}, {reg2mem(rewriter, op.getLoc(), op.getValue())}};
955 })
956 .Case<sim::FormatOctOp>([&](sim::FormatOctOp op)
957 -> FailureOr<FormatInfo> {
958 FmtDescriptor d = FmtDescriptor::createInt(
959 op.getValue().getType().getWidth(), 8, op.getIsLeftAligned(),
960 op.getSpecifierWidth().value_or(-1), op.getPaddingChar(), false,
961 false);
962 return FormatInfo{{d}, {reg2mem(rewriter, op.getLoc(), op.getValue())}};
963 })
964 .Case<sim::FormatLiteralOp>(
965 [&](sim::FormatLiteralOp op) -> FailureOr<FormatInfo> {
966 if (op.getLiteral().size() < 8 &&
967 op.getLiteral().find('\0') == StringRef::npos) {
968 // We can use the small string optimization.
969 FmtDescriptor d =
970 FmtDescriptor::createSmallLiteral(op.getLiteral());
971 return FormatInfo{{d}, {}};
972 }
973 FmtDescriptor d =
974 FmtDescriptor::createLiteral(op.getLiteral().size());
975 Value value = cache.getOrCreate(rewriter, op.getLiteral());
976 return FormatInfo{{d}, {value}};
977 })
978 .Case<sim::FormatStringConcatOp>(
979 [&](sim::FormatStringConcatOp op) -> FailureOr<FormatInfo> {
980 auto fmt = foldFormatString(rewriter, op.getInputs()[0], cache);
981 if (failed(fmt))
982 return failure();
983 for (auto input : op.getInputs().drop_front()) {
984 auto next = foldFormatString(rewriter, input, cache);
985 if (failed(next))
986 return failure();
987 fmt->descriptors.append(next->descriptors);
988 fmt->args.append(next->args);
989 }
990 return fmt;
991 })
992 .Default(
993 [](Operation *op) -> FailureOr<FormatInfo> { return failure(); });
994}
995
996FailureOr<LLVM::CallOp> emitFmtCall(OpBuilder &builder, Location loc,
997 StringCache &stringCache,
998 ArrayRef<FmtDescriptor> descriptors,
999 ValueRange args) {
1000 ModuleOp moduleOp =
1001 builder.getInsertionBlock()->getParent()->getParentOfType<ModuleOp>();
1002 // Lookup or create the arcRuntimeFormat function symbol.
1003 MLIRContext *ctx = builder.getContext();
1004 auto func = LLVM::lookupOrCreateFn(
1005 builder, moduleOp, runtime::APICallbacks::symNameFormat,
1006 LLVM::LLVMPointerType::get(ctx), LLVM::LLVMVoidType::get(ctx), true);
1007 if (failed(func))
1008 return func;
1009
1010 StringRef rawDescriptors(reinterpret_cast<const char *>(descriptors.data()),
1011 descriptors.size() * sizeof(FmtDescriptor));
1012 Value fmtPtr = stringCache.getOrCreate(builder, rawDescriptors);
1013
1014 SmallVector<Value> argsVec(1, fmtPtr);
1015 argsVec.append(args.begin(), args.end());
1016 auto result = LLVM::CallOp::create(builder, loc, func.value(), argsVec);
1017
1018 for (Value arg : args) {
1019 Operation *definingOp = arg.getDefiningOp();
1020 if (auto alloca = dyn_cast_if_present<LLVM::AllocaOp>(definingOp)) {
1021 LLVM::LifetimeEndOp::create(builder, loc, arg);
1022 }
1023 }
1024
1025 return result;
1026}
1027
1028struct SimPrintFormattedProcOpLowering
1029 : public OpConversionPattern<sim::PrintFormattedProcOp> {
1030 SimPrintFormattedProcOpLowering(const TypeConverter &typeConverter,
1031 MLIRContext *context,
1032 StringCache &stringCache)
1033 : OpConversionPattern<sim::PrintFormattedProcOp>(typeConverter, context),
1034 stringCache(stringCache) {}
1035
1036 LogicalResult
1037 matchAndRewrite(sim::PrintFormattedProcOp op, OpAdaptor adaptor,
1038 ConversionPatternRewriter &rewriter) const override {
1039 auto formatInfo = foldFormatString(rewriter, op.getInput(), stringCache);
1040 if (failed(formatInfo))
1041 return rewriter.notifyMatchFailure(op, "unsupported format string");
1042
1043 // Add the end descriptor.
1044 formatInfo->descriptors.push_back(FmtDescriptor());
1045
1046 auto result = emitFmtCall(rewriter, op.getLoc(), stringCache,
1047 formatInfo->descriptors, formatInfo->args);
1048 if (failed(result))
1049 return failure();
1050 rewriter.replaceOp(op, result.value());
1051
1052 return success();
1053 }
1054
1055 StringCache &stringCache;
1056};
1057
1058struct TerminateOpLowering : public OpConversionPattern<arc::TerminateOp> {
1059 using OpConversionPattern::OpConversionPattern;
1060
1061 LogicalResult
1062 matchAndRewrite(arc::TerminateOp op, OpAdaptor adaptor,
1063 ConversionPatternRewriter &rewriter) const override {
1064 auto loc = op.getLoc();
1065
1066 auto i8Type = rewriter.getI8Type();
1067 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1068
1069 Value flagPtr = LLVM::GEPOp::create(
1070 rewriter, loc, ptrType, i8Type, adaptor.getStorage(),
1071 ArrayRef<LLVM::GEPArg>{arc::kTerminateFlagOffset});
1072
1073 uint8_t statusCode = op.getSuccess() ? 1 : 2;
1074 Value codeVal = LLVM::ConstantOp::create(
1075 rewriter, loc, i8Type, rewriter.getI8IntegerAttr(statusCode));
1076
1077 LLVM::StoreOp::create(rewriter, loc, codeVal, flagPtr);
1078
1079 rewriter.eraseOp(op);
1080 return success();
1081 }
1082};
1083
1084// Loads the next wakeup time (i64 femtoseconds) from the model's storage at
1085// `kNextWakeupOffset`.
1086struct GetNextWakeupOpLowering
1087 : public OpConversionPattern<arc::GetNextWakeupOp> {
1088 using OpConversionPattern::OpConversionPattern;
1089
1090 LogicalResult
1091 matchAndRewrite(arc::GetNextWakeupOp op, OpAdaptor adaptor,
1092 ConversionPatternRewriter &rewriter) const override {
1093 auto loc = op.getLoc();
1094 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1095 Value slotPtr = LLVM::GEPOp::create(
1096 rewriter, loc, ptrType, rewriter.getI8Type(), adaptor.getStorage(),
1097 ArrayRef<LLVM::GEPArg>{arc::kNextWakeupOffset});
1098 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, rewriter.getI64Type(),
1099 slotPtr);
1100 return success();
1101 }
1102};
1103
1104// Stores the next wakeup time (i64 femtoseconds) to the model's storage at
1105// `kNextWakeupOffset`.
1106struct SetNextWakeupOpLowering
1107 : public OpConversionPattern<arc::SetNextWakeupOp> {
1108 using OpConversionPattern::OpConversionPattern;
1109
1110 LogicalResult
1111 matchAndRewrite(arc::SetNextWakeupOp op, OpAdaptor adaptor,
1112 ConversionPatternRewriter &rewriter) const override {
1113 auto loc = op.getLoc();
1114 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1115 Value slotPtr = LLVM::GEPOp::create(
1116 rewriter, loc, ptrType, rewriter.getI8Type(), adaptor.getStorage(),
1117 ArrayRef<LLVM::GEPArg>{arc::kNextWakeupOffset});
1118 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getTime(), slotPtr);
1119 return success();
1120 }
1121};
1122
1123} // namespace
1124
1125static LogicalResult convert(arc::ExecuteOp op, arc::ExecuteOp::Adaptor adaptor,
1126 ConversionPatternRewriter &rewriter,
1127 const TypeConverter &converter) {
1128 // Convert the argument types in the body blocks.
1129 if (failed(rewriter.convertRegionTypes(&op.getBody(), converter)))
1130 return failure();
1131
1132 // Split the block at the current insertion point such that we can branch into
1133 // the `arc.execute` body region, and have `arc.output` branch back to the
1134 // point after the `arc.execute`.
1135 auto *blockBefore = rewriter.getInsertionBlock();
1136 auto *blockAfter =
1137 rewriter.splitBlock(blockBefore, rewriter.getInsertionPoint());
1138
1139 // Branch to the entry block.
1140 rewriter.setInsertionPointToEnd(blockBefore);
1141 mlir::cf::BranchOp::create(rewriter, op.getLoc(), &op.getBody().front(),
1142 adaptor.getInputs());
1143
1144 // Make all `arc.output` terminators branch to the block after the
1145 // `arc.execute` op.
1146 for (auto &block : op.getBody()) {
1147 auto outputOp = dyn_cast<arc::OutputOp>(block.getTerminator());
1148 if (!outputOp)
1149 continue;
1150 rewriter.setInsertionPointToEnd(&block);
1151 rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(outputOp, blockAfter,
1152 outputOp.getOperands());
1153 }
1154
1155 // Inline the body region between the before and after blocks.
1156 rewriter.inlineRegionBefore(op.getBody(), blockAfter);
1157
1158 // Add arguments to the block after the `arc.execute`, replace the op's
1159 // results with the arguments, then perform block signature conversion.
1160 SmallVector<Value> args;
1161 args.reserve(op.getNumResults());
1162 for (auto result : op.getResults())
1163 args.push_back(blockAfter->addArgument(result.getType(), result.getLoc()));
1164 rewriter.replaceOp(op, args);
1165 auto conversion = converter.convertBlockSignature(blockAfter);
1166 if (!conversion)
1167 return failure();
1168 rewriter.applySignatureConversion(blockAfter, *conversion, &converter);
1169 return success();
1170}
1171
1172//===----------------------------------------------------------------------===//
1173// Runtime Implementation
1174//===----------------------------------------------------------------------===//
1175
1176template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
1177static LLVM::GlobalOp
1178buildGlobalConstantIntArray(OpBuilder &builder, Location loc, Twine symName,
1179 SmallVectorImpl<T> &data,
1180 unsigned alignment = alignof(T)) {
1181 auto intType = builder.getIntegerType(8 * sizeof(T));
1182 Attribute denseAttr = mlir::DenseElementsAttr::get(
1183 mlir::RankedTensorType::get({(int64_t)data.size()}, intType),
1184 llvm::ArrayRef(data));
1185 auto globalOp = LLVM::GlobalOp::create(
1186 builder, loc, LLVM::LLVMArrayType::get(intType, data.size()),
1187 /*isConstant=*/true, LLVM::Linkage::Internal,
1188 builder.getStringAttr(symName), denseAttr);
1189 globalOp.setAlignmentAttr(builder.getI64IntegerAttr(alignment));
1190 return globalOp;
1191}
1192
1193// Construct a raw constant byte array from a vector of struct values
1194template <typename T>
1195static LLVM::GlobalOp
1196buildGlobalConstantRuntimeStructArray(OpBuilder &builder, Location loc,
1197 Twine symName,
1198 SmallVectorImpl<T> &array) {
1199 assert(!array.empty());
1200 static_assert(std::is_standard_layout<T>(),
1201 "Runtime struct must have standard layout");
1202 int64_t numBytes = sizeof(T) * array.size();
1203 Attribute denseAttr = mlir::DenseElementsAttr::get(
1204 mlir::RankedTensorType::get({numBytes}, builder.getI8Type()),
1205 llvm::ArrayRef(reinterpret_cast<uint8_t *>(array.data()), numBytes));
1206 auto globalOp = LLVM::GlobalOp::create(
1207 builder, loc, LLVM::LLVMArrayType::get(builder.getI8Type(), numBytes),
1208 /*isConstant=*/true, LLVM::Linkage::Internal,
1209 builder.getStringAttr(symName), denseAttr, alignof(T));
1210 return globalOp;
1211}
1212
1214 : public OpConversionPattern<arc::RuntimeModelOp> {
1215 using OpConversionPattern::OpConversionPattern;
1216
1217 static constexpr uint64_t runtimeApiVersion = ARC_RUNTIME_API_VERSION;
1218
1219 // Build the constant ArcModelTraceInfo struct and its members
1220 LLVM::GlobalOp
1221 buildTraceInfoStruct(arc::RuntimeModelOp &op,
1222 ConversionPatternRewriter &rewriter) const {
1223 if (!op.getTraceTaps().has_value() || op.getTraceTaps()->empty())
1224 return {};
1225 // Construct the array of tap names/aliases
1226 SmallVector<char> namesArray;
1227 SmallVector<ArcTraceTap> tapArray;
1228 tapArray.reserve(op.getTraceTaps()->size());
1229 for (auto attr : op.getTraceTapsAttr()) {
1230 auto tap = cast<TraceTapAttr>(attr);
1231 assert(!tap.getNames().empty() &&
1232 "Expected trace tap to have at least one name");
1233 for (auto alias : tap.getNames()) {
1234 auto aliasStr = cast<StringAttr>(alias);
1235 namesArray.append(aliasStr.begin(), aliasStr.end());
1236 namesArray.push_back('\0');
1237 }
1238 ArcTraceTap tapStruct;
1239 tapStruct.stateOffset = tap.getStateOffset();
1240 tapStruct.nameOffset = namesArray.size() - 1;
1241 tapStruct.typeBits = tap.getSigType().getValue().getIntOrFloatBitWidth();
1242 tapStruct.reserved = 0;
1243 tapArray.emplace_back(tapStruct);
1244 }
1245 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1246 auto namesGlobal = buildGlobalConstantIntArray(
1247 rewriter, op.getLoc(), "_arc_tap_names_" + op.getName(), namesArray);
1248 auto traceTapsArrayGlobal = buildGlobalConstantRuntimeStructArray(
1249 rewriter, op.getLoc(), "_arc_trace_taps_" + op.getName(), tapArray);
1250
1251 //
1252 // struct ArcModelTraceInfo {
1253 // uint64_t numTraceTaps;
1254 // struct ArcTraceTap *traceTaps;
1255 // const char *traceTapNames;
1256 // uint64_t traceBufferCapacity;
1257 // };
1258 //
1259 auto traceInfoStructType = LLVM::LLVMStructType::getLiteral(
1260 getContext(),
1261 {rewriter.getI64Type(), ptrTy, ptrTy, rewriter.getI64Type()});
1262 static_assert(sizeof(ArcModelTraceInfo) == 32 &&
1263 "Unexpected size of ArcModelTraceInfo struct");
1264
1265 auto globalSymName =
1266 rewriter.getStringAttr("_arc_trace_info_" + op.getName());
1267 auto traceInfoGlobalOp = LLVM::GlobalOp::create(
1268 rewriter, op.getLoc(), traceInfoStructType,
1269 /*isConstant=*/false, LLVM::Linkage::Internal, globalSymName,
1270 Attribute{}, alignof(ArcModelTraceInfo));
1271 OpBuilder::InsertionGuard g(rewriter);
1272
1273 // Struct Initializer
1274 Region &initRegion = traceInfoGlobalOp.getInitializerRegion();
1275 Block *initBlock = rewriter.createBlock(&initRegion);
1276 rewriter.setInsertionPointToStart(initBlock);
1277
1278 auto numTraceTapsCst = LLVM::ConstantOp::create(
1279 rewriter, op.getLoc(), rewriter.getI64IntegerAttr(tapArray.size()));
1280 auto traceTapArrayAddr =
1281 LLVM::AddressOfOp::create(rewriter, op.getLoc(), traceTapsArrayGlobal);
1282 auto tapNameArrayAddr =
1283 LLVM::AddressOfOp::create(rewriter, op.getLoc(), namesGlobal);
1284 auto bufferCapacityCst = LLVM::ConstantOp::create(
1285 rewriter, op.getLoc(),
1286 rewriter.getI64IntegerAttr(runtime::defaultTraceBufferCapacity));
1287
1288 Value initStruct =
1289 LLVM::PoisonOp::create(rewriter, op.getLoc(), traceInfoStructType);
1290
1291 // Field: uint64_t numTraceTaps
1292 initStruct =
1293 LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1294 numTraceTapsCst, ArrayRef<int64_t>{0});
1295 static_assert(offsetof(ArcModelTraceInfo, numTraceTaps) == 0,
1296 "Unexpected offset of field numTraceTaps");
1297 // Field: struct ArcTraceTap *traceTaps
1298 initStruct =
1299 LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1300 traceTapArrayAddr, ArrayRef<int64_t>{1});
1301 static_assert(offsetof(ArcModelTraceInfo, traceTaps) == 8,
1302 "Unexpected offset of field traceTaps");
1303 // Field: const char *traceTapNames
1304 initStruct =
1305 LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1306 tapNameArrayAddr, ArrayRef<int64_t>{2});
1307 static_assert(offsetof(ArcModelTraceInfo, traceTapNames) == 16,
1308 "Unexpected offset of field traceTapNames");
1309 // Field: uint64_t traceBufferCapacity
1310 initStruct =
1311 LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1312 bufferCapacityCst, ArrayRef<int64_t>{3});
1313 static_assert(offsetof(ArcModelTraceInfo, traceBufferCapacity) == 24,
1314 "Unexpected offset of field traceBufferCapacity");
1315 LLVM::ReturnOp::create(rewriter, op.getLoc(), initStruct);
1316
1317 return traceInfoGlobalOp;
1318 }
1319
1320 // Create a global LLVM struct containing the RuntimeModel metadata
1321 LogicalResult
1322 matchAndRewrite(arc::RuntimeModelOp op, OpAdaptor adaptor,
1323 ConversionPatternRewriter &rewriter) const final {
1324
1325 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1326 auto modelInfoStructType = LLVM::LLVMStructType::getLiteral(
1327 getContext(),
1328 {rewriter.getI64Type(), rewriter.getI64Type(), ptrTy, ptrTy});
1329 static_assert(sizeof(ArcRuntimeModelInfo) == 32 &&
1330 "Unexpected size of ArcRuntimeModelInfo struct");
1331
1332 rewriter.setInsertionPoint(op);
1333 auto traceInfoGlobal = buildTraceInfoStruct(op, rewriter);
1334
1335 // Construct the Model Name String GlobalOp
1336 SmallVector<char, 16> modNameArray(op.getName().begin(),
1337 op.getName().end());
1338 modNameArray.push_back('\0');
1339 auto nameGlobalType =
1340 LLVM::LLVMArrayType::get(rewriter.getI8Type(), modNameArray.size());
1341 auto globalSymName =
1342 rewriter.getStringAttr("_arc_mod_name_" + op.getName());
1343 auto nameGlobal = LLVM::GlobalOp::create(
1344 rewriter, op.getLoc(), nameGlobalType, /*isConstant=*/true,
1345 LLVM::Linkage::Internal,
1346 /*name=*/globalSymName, rewriter.getStringAttr(modNameArray),
1347 /*alignment=*/0);
1348
1349 // Construct the Model Info Struct GlobalOp
1350 // Note: The struct is supposed to be constant at runtime, but contains the
1351 // relocatable address of another symbol, so it should not be placed in the
1352 // "rodata" section.
1353 auto modInfoGlobalOp =
1354 LLVM::GlobalOp::create(rewriter, op.getLoc(), modelInfoStructType,
1355 /*isConstant=*/false, LLVM::Linkage::External,
1356 op.getSymName(), Attribute{});
1357
1358 // Struct Initializer
1359 Region &initRegion = modInfoGlobalOp.getInitializerRegion();
1360 Block *initBlock = rewriter.createBlock(&initRegion);
1361 rewriter.setInsertionPointToStart(initBlock);
1362 auto apiVersionCst = LLVM::ConstantOp::create(
1363 rewriter, op.getLoc(), rewriter.getI64IntegerAttr(runtimeApiVersion));
1364 auto numStateBytesCst = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1365 op.getNumStateBytesAttr());
1366 auto nameAddr =
1367 LLVM::AddressOfOp::create(rewriter, op.getLoc(), nameGlobal);
1368 Value traceInfoPtr;
1369 if (traceInfoGlobal)
1370 traceInfoPtr =
1371 LLVM::AddressOfOp::create(rewriter, op.getLoc(), traceInfoGlobal);
1372 else
1373 traceInfoPtr = LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy);
1374
1375 Value initStruct =
1376 LLVM::PoisonOp::create(rewriter, op.getLoc(), modelInfoStructType);
1377
1378 // Field: uint64_t apiVersion
1379 initStruct = LLVM::InsertValueOp::create(
1380 rewriter, op.getLoc(), initStruct, apiVersionCst, ArrayRef<int64_t>{0});
1381 static_assert(offsetof(ArcRuntimeModelInfo, apiVersion) == 0,
1382 "Unexpected offset of field apiVersion");
1383 // Field: uint64_t numStateBytes
1384 initStruct =
1385 LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1386 numStateBytesCst, ArrayRef<int64_t>{1});
1387 static_assert(offsetof(ArcRuntimeModelInfo, numStateBytes) == 8,
1388 "Unexpected offset of field numStateBytes");
1389 // Field: const char *modelName
1390 initStruct = LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
1391 nameAddr, ArrayRef<int64_t>{2});
1392 static_assert(offsetof(ArcRuntimeModelInfo, modelName) == 16,
1393 "Unexpected offset of field modelName");
1394 // Field: struct ArcModelTraceInfo *traceInfo
1395 initStruct = LLVM::InsertValueOp::create(
1396 rewriter, op.getLoc(), initStruct, traceInfoPtr, ArrayRef<int64_t>{3});
1397 static_assert(offsetof(ArcRuntimeModelInfo, traceInfo) == 24,
1398 "Unexpected offset of field traceInfo");
1399
1400 LLVM::ReturnOp::create(rewriter, op.getLoc(), initStruct);
1401
1402 rewriter.replaceOp(op, modInfoGlobalOp);
1403 return success();
1404 }
1405};
1406
1407//===----------------------------------------------------------------------===//
1408// ArrayRef patterns
1409//===----------------------------------------------------------------------===//
1410
1411size_t computeByteWidth(ArrayRefType type) {
1412 auto bitWidth = computeLLVMBitWidth(type);
1413 assert(bitWidth.has_value());
1414 return llvm::divideCeil(*bitWidth, 8);
1415}
1416
1417// Computes the padded bytewidth (stride) of each element.
1418size_t computeElementByteWidth(ArrayRefType arrayRefType) {
1419 auto arrayBitWidth = computeLLVMBitWidth(arrayRefType);
1420 assert(arrayBitWidth.has_value());
1421 assert(arrayRefType.getNumElements() > 0 &&
1422 "Cannot compute stride for zero sized array");
1423 size_t elementBitWidth = *arrayBitWidth / arrayRefType.getNumElements();
1424 return llvm::divideCeil(elementBitWidth, 8);
1425}
1426
1427struct ArrayRefAllocOpLowering : public OpConversionPattern<ArrayRefAllocOp> {
1428 using OpConversionPattern::OpConversionPattern;
1429
1430 LogicalResult
1431 matchAndRewrite(ArrayRefAllocOp op, OpAdaptor adaptor,
1432 ConversionPatternRewriter &rewriter) const override {
1433 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1434 auto i8Ty = rewriter.getI8Type();
1435 ArrayRefType arrayRefType = op.getType();
1436 size_t byteWidth = computeByteWidth(arrayRefType);
1437 auto size = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1438 rewriter.getI64Type(), byteWidth);
1439
1440 size_t alignment = computeAllocaAlignment(arrayRefType, op);
1441 auto alloc = LLVM::AllocaOp::create(rewriter, op.getLoc(), ptrTy, i8Ty,
1442 size, alignment);
1443
1444 if (op.getInitAttr()) {
1445 ArrayAttr initAttr = op.getInitAttr();
1446 if (isZero(initAttr)) {
1447 auto i8Ty = rewriter.getI8Type();
1448 auto zero = LLVM::ConstantOp::create(rewriter, op.getLoc(), i8Ty, 0);
1449 LLVM::MemsetOp::create(rewriter, op.getLoc(), alloc, zero, size,
1450 /*isVolatile=*/false);
1451 } else {
1452 initializeArray(rewriter, op.getLoc(), alloc, initAttr, arrayRefType);
1453 }
1454 }
1455
1456 rewriter.replaceOp(op, alloc);
1457 return success();
1458 }
1459
1460 // Computes the required alignment for an AllocaOp of the given type.
1461 // c.f. HWToLLVM.cpp.
1462 size_t computeAllocaAlignment(ArrayRefType type, Operation *op) const {
1463 if (alignmentCache.count(type)) {
1464 return alignmentCache[type];
1465 }
1466 auto dl = DataLayout::closest(op);
1467 auto hwType =
1468 hw::ArrayType::get(type.getElementType(), type.getNumElements());
1469 auto llvmType = getTypeConverter()->convertType(hwType);
1470 auto alignment =
1471 static_cast<unsigned>(dl.getTypePreferredAlignment(llvmType));
1472 alignment = std::max(4u, alignment);
1473 alignmentCache[type] = alignment;
1474 return alignment;
1475 }
1476
1477 bool isZero(ArrayAttr arrayAttr) const {
1478 return llvm::all_of(arrayAttr.getAsValueRange<IntegerAttr>(),
1479 [](APInt i) { return i.isZero(); });
1480 }
1481
1482 void initializeArray(ConversionPatternRewriter &rewriter, Location loc,
1483 Value alloc, ArrayAttr initAttr,
1484 ArrayRefType arrayRefType) const {
1485 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1486 Type ptrTy = LLVM::LLVMPointerType::get(getContext());
1487 Type i8Ty = rewriter.getI8Type();
1488 for (unsigned i = 0; i < arrayRefType.getNumElements(); ++i) {
1489 unsigned elemIndex = arrayRefType.getNumElements() - i - 1;
1490 Value elemOffset = LLVM::ConstantOp::create(
1491 rewriter, loc, rewriter.getI64Type(), elemIndex * elemByteWidth);
1492 auto elemAddr =
1493 LLVM::GEPOp::create(rewriter, loc, ptrTy, i8Ty, alloc, elemOffset);
1494 auto elem = LLVM::ConstantOp::create(
1495 rewriter, loc, arrayRefType.getElementType(), initAttr[i]);
1496 LLVM::StoreOp::create(rewriter, loc, elem, elemAddr);
1497 }
1498 }
1499
1500private:
1501 mutable DenseMap<ArrayRefType, size_t> alignmentCache;
1502};
1503
1504struct ArrayRefCreateOpLowering : public OpConversionPattern<ArrayRefCreateOp> {
1505 using OpConversionPattern::OpConversionPattern;
1506
1507 LogicalResult
1508 matchAndRewrite(ArrayRefCreateOp op, OpAdaptor adaptor,
1509 ConversionPatternRewriter &rewriter) const override {
1510 ArrayRefType arrayRefType = cast<ArrayRefType>(op.getType());
1511 Value alloc = adaptor.getInput();
1512 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1513 auto i8Ty = rewriter.getI8Type();
1514 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1515 auto elements = adaptor.getElements();
1516 for (unsigned i = 0; i < elements.size(); ++i) {
1517 // Note: hardcoded for little endian targets.
1518 unsigned elemIndex = arrayRefType.getNumElements() - i - 1;
1519 Value elemOffset =
1520 LLVM::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI64Type(),
1521 elemIndex * elemByteWidth);
1522 auto elemAddr = LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, i8Ty,
1523 alloc, elemOffset);
1524 LLVM::StoreOp::create(rewriter, op.getLoc(), elements[i], elemAddr);
1525 }
1526 rewriter.replaceOp(op, alloc);
1527 return success();
1528 }
1529};
1530
1531struct ArrayRefGetOpLowering : public OpConversionPattern<ArrayRefGetOp> {
1532 using OpConversionPattern::OpConversionPattern;
1533
1534 LogicalResult
1535 matchAndRewrite(ArrayRefGetOp op, OpAdaptor adaptor,
1536 ConversionPatternRewriter &rewriter) const override {
1537 auto loc = op.getLoc();
1538 ArrayRefType arrayRefType = cast<ArrayRefType>(op.getInput().getType());
1539 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1540 auto i8Ty = rewriter.getI8Type();
1541 auto i64Ty = rewriter.getI64Type();
1542 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1543 assert(!isa<ArrayRefType>(arrayRefType.getElementType()));
1544
1545 Value stride =
1546 LLVM::ConstantOp::create(rewriter, loc, i64Ty, elemByteWidth);
1547 Value byteOffset =
1548 LLVM::MulOp::create(rewriter, loc, adaptor.getIndex(), stride);
1549 // Defend against out-of-bounds accesses. What we return is undefined in the
1550 // case of OOB.
1551 size_t lastElementByteOffset =
1552 elemByteWidth * (arrayRefType.getNumElements() - 1);
1553 Value lastElementByteOffsetVal =
1554 LLVM::ConstantOp::create(rewriter, loc, i64Ty, lastElementByteOffset);
1555 Value clampedOffset = LLVM::UMinOp::create(rewriter, loc, i64Ty, byteOffset,
1556 lastElementByteOffsetVal);
1557 auto elemAddr = LLVM::GEPOp::create(rewriter, loc, ptrTy, i8Ty,
1558 adaptor.getInput(), clampedOffset);
1559 Value loaded = LLVM::LoadOp::create(
1560 rewriter, loc, typeConverter->convertType(op.getValue().getType()),
1561 elemAddr);
1562 rewriter.replaceOp(op, loaded);
1563 return success();
1564 }
1565};
1566
1567struct ArrayRefInjectOpLowering : public OpConversionPattern<ArrayRefInjectOp> {
1568 using OpConversionPattern::OpConversionPattern;
1569
1570 LogicalResult
1571 matchAndRewrite(ArrayRefInjectOp op, OpAdaptor adaptor,
1572 ConversionPatternRewriter &rewriter) const override {
1573 auto loc = op.getLoc();
1574 ArrayRefType arrayRefType = cast<ArrayRefType>(op.getInput().getType());
1575 assert(!isa<ArrayRefType>(arrayRefType.getElementType()));
1576 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1577 auto i8Ty = rewriter.getI8Type();
1578 auto i64Ty = rewriter.getI64Type();
1579 size_t byteWidth = computeByteWidth(arrayRefType);
1580 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1581
1582 Value stride =
1583 LLVM::ConstantOp::create(rewriter, loc, i64Ty, elemByteWidth);
1584 Value byteOffset =
1585 LLVM::MulOp::create(rewriter, loc, adaptor.getIndex(), stride);
1586 Value totalSize = LLVM::ConstantOp::create(rewriter, loc, i64Ty, byteWidth);
1587 // Defend against out-of-bounds accesses. We must avoid corrupting the
1588 // array.
1589 Value isInbounds = LLVM::ICmpOp::create(
1590 rewriter, loc, LLVM::ICmpPredicate::ult, byteOffset, totalSize);
1591 scf::IfOp::create(rewriter, loc, isInbounds, [&](OpBuilder &b, Location) {
1592 auto elemAddr = LLVM::GEPOp::create(b, loc, ptrTy, i8Ty,
1593 adaptor.getInput(), byteOffset);
1594 LLVM::StoreOp::create(b, loc, adaptor.getElement(), elemAddr);
1595 scf::YieldOp::create(b, loc);
1596 });
1597
1598 // Inject is pure; returns the same pointer (input buffer is modified
1599 // in-place and the pointer is forwarded as the result).
1600 rewriter.replaceOp(op, adaptor.getInput());
1601 return success();
1602 }
1603};
1604
1605struct ArrayRefSliceOpLowering : public OpConversionPattern<ArrayRefSliceOp> {
1606 using OpConversionPattern::OpConversionPattern;
1607
1608 LogicalResult
1609 matchAndRewrite(ArrayRefSliceOp op, OpAdaptor adaptor,
1610 ConversionPatternRewriter &rewriter) const override {
1611 auto loc = op.getLoc();
1612 // The result type is the sub-array type; use its element size.
1613 ArrayRefType inputType = cast<ArrayRefType>(op.getInput().getType());
1614 ArrayRefType resultType = cast<ArrayRefType>(op.getOutput().getType());
1615 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
1616 auto i8Ty = rewriter.getI8Type();
1617 auto i64Ty = rewriter.getI64Type();
1618 size_t elemByteWidth = computeElementByteWidth(resultType);
1619
1620 // Ensure the slice doesn't go out of bounds.
1621 size_t maxLowIndex =
1622 inputType.getNumElements() - resultType.getNumElements();
1623 Value maxLowIndexVal =
1624 LLVM::ConstantOp::create(rewriter, loc, i64Ty, maxLowIndex);
1625 Value clampedLowIndex = LLVM::UMinOp::create(
1626 rewriter, loc, i64Ty, adaptor.getLowIndex(), maxLowIndexVal);
1627
1628 // Byte offset = lowIndex * elemByteWidth.
1629 Value stride =
1630 LLVM::ConstantOp::create(rewriter, loc, i64Ty, elemByteWidth);
1631 Value byteOffset =
1632 LLVM::MulOp::create(rewriter, loc, clampedLowIndex, stride);
1633 auto sliceAddr = LLVM::GEPOp::create(rewriter, loc, ptrTy, i8Ty,
1634 adaptor.getInput(), byteOffset);
1635 rewriter.replaceOp(op, sliceAddr);
1636 return success();
1637 }
1638};
1639
1640struct ArrayRefCopyOpLowering : public OpConversionPattern<ArrayRefCopyOp> {
1641 using OpConversionPattern::OpConversionPattern;
1642
1643 LogicalResult
1644 matchAndRewrite(ArrayRefCopyOp op, OpAdaptor adaptor,
1645 ConversionPatternRewriter &rewriter) const override {
1646 auto loc = op.getLoc();
1647 ArrayRefType arrayRefType = cast<ArrayRefType>(op.getInput().getType());
1648 auto i64Ty = rewriter.getI64Type();
1649 size_t byteWidth = computeByteWidth(arrayRefType);
1650 Value size = LLVM::ConstantOp::create(rewriter, loc, i64Ty, byteWidth);
1651 // Use a memmove rather than a memcpy just in case the arrays alias.
1652 LLVM::MemmoveOp::create(rewriter, loc, adaptor.getInput(),
1653 adaptor.getSource(), size,
1654 /*isVolatile=*/false);
1655 rewriter.replaceOp(op, adaptor.getInput());
1656 return success();
1657 }
1658};
1659
1660static Value loadArrayRefAsArray(ImplicitLocOpBuilder &builder, Value arrayRef,
1661 ArrayRefType arrayRefType,
1662 LLVM::LLVMArrayType llvmType) {
1663 auto i8Ty = builder.getI8Type();
1664 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
1665 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1666 Value v = LLVM::PoisonOp::create(builder, llvmType);
1667 int32_t size = arrayRefType.getNumElements();
1668 for (int32_t i = 0; i < size; i++) {
1669 int32_t byteOffset = i * elemByteWidth;
1670 Value gep = LLVM::GEPOp::create(builder, ptrTy, i8Ty, arrayRef,
1671 LLVM::GEPArg{byteOffset});
1672 Value load = LLVM::LoadOp::create(builder, llvmType.getElementType(), gep);
1673 v = LLVM::InsertValueOp::create(builder, v, load, i);
1674 }
1675 return v;
1676}
1677
1678static void storeArrayAsArrayRef(ImplicitLocOpBuilder &builder, Value array,
1679 Value arrayRef, ArrayRefType arrayRefType) {
1680 auto i8Ty = builder.getI8Type();
1681 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
1682 size_t elemByteWidth = computeElementByteWidth(arrayRefType);
1683 int32_t size = arrayRefType.getNumElements();
1684 for (int32_t i = 0; i < size; i++) {
1685 int32_t byteOffset = i * elemByteWidth;
1686 Value gep = LLVM::GEPOp::create(builder, ptrTy, i8Ty, arrayRef,
1687 LLVM::GEPArg{byteOffset});
1688 Value val = LLVM::ExtractValueOp::create(builder, array, i);
1689 LLVM::StoreOp::create(builder, val, gep);
1690 }
1691}
1692
1694 : public OpConversionPattern<UnrealizedConversionCastOp> {
 // NOTE(review): the `struct <name>` line that opens this definition (listing
 // line 1693) is absent from this listing; only the base-class clause remains.
 //
 // Conversion pattern for `UnrealizedConversionCastOp`s whose first operand is
 // an ArrayRefType and whose first result is an LLVM array type. Such a cast
 // is replaced by an element-wise load of the referenced buffer.
1695 using OpConversionPattern::OpConversionPattern;
1696
1697 LogicalResult
1698 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
1699 ConversionPatternRewriter &rewriter) const override {
 // Only the arrayref -> LLVM array direction is handled; every other cast
 // is rejected so it is left untouched by this pattern.
1700 if (!isa<ArrayRefType>(op.getOperand(0).getType()) ||
1701 !isa<LLVM::LLVMArrayType>(op.getResult(0).getType())) {
1702 return failure();
1703 }
1704
 // The types are taken from the original op, while the pointer to read from
 // is the already-converted first input supplied by the adaptor.
1705 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1706 Value loaded = loadArrayRefAsArray(
1707 b, adaptor.getInputs().front(),
1708 cast<ArrayRefType>(op.getOperand(0).getType()),
1709 cast<LLVM::LLVMArrayType>(op.getResult(0).getType()));
1710 rewriter.replaceOp(op, loaded);
1711 return success();
1712 }
1713};
1714
1716 : public OpConversionPattern<ArrayRefToArrayOp> {
 // NOTE(review): the `struct <name>` line that opens this definition (listing
 // line 1715) is absent from this listing; only the base-class clause remains.
 //
 // Conversion pattern for `ArrayRefToArrayOp`: replaces the op with an array
 // value assembled by loading every element from the converted input pointer.
1717 using OpConversionPattern::OpConversionPattern;
1718
1719 LogicalResult
1720 matchAndRewrite(ArrayRefToArrayOp op, OpAdaptor adaptor,
1721 ConversionPatternRewriter &rewriter) const override {
 // The type of the loaded value is the converted form of the op's result
 // type. NOTE(review): the `cast` below assumes the type converter yields a
 // non-null LLVM array type here -- confirm.
1722 Type resultType = getTypeConverter()->convertType(op.getResult().getType());
1723 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1724 Value loaded = loadArrayRefAsArray(
1725 b, adaptor.getInput(), cast<ArrayRefType>(op.getInput().getType()),
1726 cast<LLVM::LLVMArrayType>(resultType));
1727 rewriter.replaceOp(op, loaded);
1728 return success();
1729 }
1730};
1731
1733 : public OpConversionPattern<ArrayRefFromArrayOp> {
 // NOTE(review): the `struct <name>` line that opens this definition (listing
 // line 1732) is absent from this listing; only the base-class clause remains.
 //
 // Conversion pattern for `ArrayRefFromArrayOp`: stores every element of the
 // converted `array` operand into the buffer behind the converted `input`
 // operand, then replaces the op with that same `input` pointer.
1734 using OpConversionPattern::OpConversionPattern;
1735
1736 LogicalResult
1737 matchAndRewrite(ArrayRefFromArrayOp op, OpAdaptor adaptor,
1738 ConversionPatternRewriter &rewriter) const override {
1739 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 // The element layout is derived from the original (unconverted) type of
 // `input`.
1740 storeArrayAsArrayRef(b, adaptor.getArray(), adaptor.getInput(),
1741 cast<ArrayRefType>(op.getInput().getType()));
 // The buffer is written in place, so the result is the input pointer.
1742 rewriter.replaceOp(op, adaptor.getInput());
1743 return success();
1744 }
1745};
1746
1747//===----------------------------------------------------------------------===//
1748// Pass Implementation
1749//===----------------------------------------------------------------------===//
1750
1751namespace {
1752struct LowerArcToLLVMPass
1753 : public circt::impl::LowerArcToLLVMBase<LowerArcToLLVMPass> {
1754 void runOnOperation() override;
1755};
1756} // namespace
1757
1758void LowerArcToLLVMPass::runOnOperation() {
 // Overview of the steps performed below:
 //  1. annotate ArrayRefType function arguments with `dereferenceable`,
 //  2. collect existing symbols so new globals get non-colliding names,
 //  3. configure the conversion target and the type converter,
 //  4. register the MLIR, CIRCT and Arc conversion patterns,
 //  5. run a full dialect conversion and signal pass failure if it fails.
1759 // Add `dereferenceable(<N>)` attributes to all function arguments that take
1760 // ArrayRefTypes.
1761 for (func::FuncOp func : getOperation().getOps<func::FuncOp>()) {
1762 for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
1763 if (auto arrayRefType =
1764 dyn_cast<ArrayRefType>(func.getArgumentTypes()[i])) {
 // <N> is the byte width reported by computeByteWidth for the argument's
 // type.
1765 size_t byteWidth = computeByteWidth(arrayRefType);
1766 Builder builder(&getContext());
1767 func.setArgAttr(i, LLVM::LLVMDialect::getDereferenceableAttrName(),
1768 builder.getI64IntegerAttr(byteWidth));
1769 }
1770 }
1771 }
1772
1773 // Collect the symbols in the root op such that the HW-to-LLVM lowering can
1774 // create LLVM globals with non-colliding names.
1775 Namespace globals;
1776 SymbolCache cache;
1777 cache.addDefinitions(getOperation());
1778 globals.add(cache);
1779
1780 // Setup the conversion target. Explicitly mark `scf.yield` legal since it
1781 // does not have a conversion itself, which would cause it to fail
1782 // legalization and for the conversion to abort. (It relies on its parent op's
1783 // conversion to remove it.)
1784 LLVMConversionTarget target(getContext());
1785 target.addLegalOp<mlir::ModuleOp>();
1786 target.addLegalOp<scf::YieldOp>(); // quirk of SCF dialect conversion
1787
1788 // Mark sim::Format*Op as legal. These are not converted to LLVM, but the
1789 // lowering of sim::PrintFormattedOp walks them to build up its format string.
1790 // They are all marked Pure so are removed after the conversion.
1791 target.addLegalOp<sim::FormatLiteralOp, sim::FormatDecOp, sim::FormatHexOp,
1792 sim::FormatBinOp, sim::FormatOctOp, sim::FormatCharOp,
1793 sim::FormatStringConcatOp>();
1794
1795 // Setup the arc dialect type conversion.
 // Clocks become i1 and LLHD time becomes i64; every other type registered
 // here (storage, memory, state, model instance, format string, arrayref)
 // becomes an LLVM pointer.
1796 LLVMTypeConverter converter(&getContext());
1797 converter.addConversion([&](seq::ClockType type) {
1798 return IntegerType::get(type.getContext(), 1);
1799 });
1800 converter.addConversion([&](StorageType type) {
1801 return LLVM::LLVMPointerType::get(type.getContext());
1802 });
1803 converter.addConversion([&](MemoryType type) {
1804 return LLVM::LLVMPointerType::get(type.getContext());
1805 });
1806 converter.addConversion([&](StateType type) {
1807 return LLVM::LLVMPointerType::get(type.getContext());
1808 });
1809 converter.addConversion([&](SimModelInstanceType type) {
1810 return LLVM::LLVMPointerType::get(type.getContext());
1811 });
1812 converter.addConversion([&](sim::FormatStringType type) {
1813 return LLVM::LLVMPointerType::get(type.getContext());
1814 });
1815 converter.addConversion([&](llhd::TimeType type) {
1816 // LLHD time is represented as i64 femtoseconds.
1817 return IntegerType::get(type.getContext(), 64);
1818 });
1819 converter.addConversion([&](ArrayRefType type) {
1820 return LLVM::LLVMPointerType::get(type.getContext());
1821 });
1822
1823 // Convert an UnrealizedConversionCastOp from !arc.arrayref<T> to
1824 // !llvm.array<T>. These are inserted by the InsertRuntime pass.
 // All other unrealized casts stay legal; only the arrayref -> LLVM array
 // form is reported as illegal so that a pattern has to rewrite it.
1825 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>([&](Operation *op) {
1826 Type src = op->getOperand(0).getType();
1827 Type dst = op->getResult(0).getType();
1828 bool needsConvert = isa<ArrayRefType>(src) && isa<LLVM::LLVMArrayType>(dst);
1829 return !needsConvert;
1830 });
1831
1832 // Setup the conversion patterns.
1833 ConversionPatternSet patterns(&getContext(), converter);
1834
1835 // MLIR patterns.
1836 populateSCFToControlFlowConversionPatterns(patterns);
1837 populateFuncToLLVMConversionPatterns(converter, patterns);
1838 cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
1839 arith::populateArithToLLVMConversionPatterns(converter, patterns);
1840 index::populateIndexToLLVMConversionPatterns(converter, patterns);
1841 populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter);
1842
1843 // CIRCT patterns.
 // NOTE(review): listing lines 1845 and 1847 are absent here, so the
 // initializer of `spillCacheOpt` is not visible in this listing.
1844 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
1846 std::optional<HWToLLVMArraySpillCache> spillCacheOpt =
1848 {
1849 OpBuilder spillBuilder(getOperation());
1850 spillCacheOpt->spillNonHWOps(spillBuilder, converter, getOperation());
1851 }
1852 populateHWToLLVMConversionPatterns(converter, patterns, globals,
1853 constAggregateGlobalsMap, spillCacheOpt);
1854
 // NOTE(review): listing lines 1855-1856 are absent here.
1857
 // NOTE(review): the pattern list below is incomplete in this listing (lines
 // 1877 and 1889-1897 are absent).
1858 // Arc patterns.
1859 // clang-format off
1860 patterns.add<
1861 AllocMemoryOpLowering,
1862 AllocStateLikeOpLowering<arc::AllocStateOp>,
1863 AllocStateLikeOpLowering<arc::RootInputOp>,
1864 AllocStateLikeOpLowering<arc::RootOutputOp>,
1865 AllocStorageOpLowering,
1866 ClockGateOpLowering,
1867 ClockInvOpLowering,
1868 ConstantTimeOpLowering,
1869 CurrentTimeOpLowering,
1870 GetNextWakeupOpLowering,
1871 IntToTimeOpLowering,
1872 MemoryReadOpLowering,
1873 MemoryWriteOpLowering,
1874 ModelOpLowering,
1875 ReplaceOpWithInputPattern<seq::ToClockOp>,
1876 ReplaceOpWithInputPattern<seq::FromClockOp>,
1878 SeqConstClockLowering,
1879 SetNextWakeupOpLowering,
1880 SimGetNextWakeupOpLowering,
1881 SimGetTimeOpLowering,
1882 SimSetTimeOpLowering,
1883 StateReadOpLowering,
1884 StateWriteOpLowering,
1885 StorageGetOpLowering,
1886 TerminateOpLowering,
1887 TimeToIntOpLowering,
1888 ZeroCountOpLowering,
1898 >(converter, &getContext());
1899 // clang-format on
 // ExecuteOp is lowered by the free function `convert` rather than by a
 // pattern class.
1900 patterns.add<ExecuteOp>(convert);
1901
 // These patterns additionally receive the shared `stringCache`.
1902 StringCache stringCache;
1903 patterns.add<SimEmitValueOpLowering, SimPrintFormattedProcOpLowering>(
1904 converter, &getContext(), stringCache);
1905
 // Re-index the model analysis result by model name, with a per-model map
 // from state name to StateInfo, for the sim.* patterns registered below.
1906 auto &modelInfo = getAnalysis<ModelInfoAnalysis>();
1907 llvm::DenseMap<StringRef, ModelInfoMap> modelMap(modelInfo.infoMap.size());
 // Note: the structured binding `modelInfo` shadows the analysis reference of
 // the same name inside the loop body.
1908 for (auto &[_, modelInfo] : modelInfo.infoMap) {
1909 llvm::DenseMap<StringRef, StateInfo> states(modelInfo.states.size());
1910 for (StateInfo &stateInfo : modelInfo.states)
1911 states.insert({stateInfo.name, stateInfo});
1912 modelMap.insert(
1913 {modelInfo.name,
1914 ModelInfoMap{modelInfo.numStateBytes, std::move(states),
1915 modelInfo.initialFnSym, modelInfo.finalFnSym}});
1916 }
1917
1918 patterns.add<SimInstantiateOpLowering, SimSetInputOpLowering,
1919 SimGetPortOpLowering, SimStepOpLowering>(
1920 converter, &getContext(), modelMap);
1921
1922 // Apply the conversion.
 // Pattern rollback is disabled in the config; a failed full conversion is
 // reported as a pass failure.
1923 ConversionConfig config;
1924 config.allowPatternRollback = false;
1925 if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
1926 config)))
1927 signalPassFailure();
1928}
1929
1930std::unique_ptr<OperationPass<ModuleOp>> circt::createLowerArcToLLVMPass() {
1931 return std::make_unique<LowerArcToLLVMPass>();
1932}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static LLVM::GlobalOp buildGlobalConstantIntArray(OpBuilder &builder, Location loc, Twine symName, SmallVectorImpl< T > &data, unsigned alignment=alignof(T))
static LLVM::GlobalOp buildGlobalConstantRuntimeStructArray(OpBuilder &builder, Location loc, Twine symName, SmallVectorImpl< T > &array)
static Value loadArrayRefAsArray(ImplicitLocOpBuilder &builder, Value arrayRef, ArrayRefType arrayRefType, LLVM::LLVMArrayType llvmType)
size_t computeByteWidth(ArrayRefType type)
static llvm::Twine evalSymbolFromModelName(StringRef modelName)
size_t computeElementByteWidth(ArrayRefType arrayRefType)
static void storeArrayAsArrayRef(ImplicitLocOpBuilder &builder, Value array, Value arrayRef, ArrayRefType arrayRefType)
static LogicalResult convert(arc::ExecuteOp op, arc::ExecuteOp::Adaptor adaptor, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
Extension of RewritePatternSet that allows adding matchAndRewrite functions with op adaptors and Conv...
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
#define ARC_RUNTIME_API_VERSION
Version of the combined public and internal API.
Definition Common.h:27
Definition arc.py:1
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateCombToArithConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
void populateCombToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to LLVM conversion patterns.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap, std::optional< HWToLLVMArraySpillCache > &spillCacheOpt)
Get the HW to LLVM conversion patterns.
std::unique_ptr< OperationPass< ModuleOp > > createLowerArcToLLVMPass()
Definition hw.py:1
Definition sim.py:1
Static information for a compiled hardware model, generated by the MLIR lowering.
Definition Common.h:70
uint32_t typeBits
Bit width of the traced signal.
Definition TraceTaps.h:28
uint64_t stateOffset
Byte offset of the traced value within the model state.
Definition TraceTaps.h:23
uint64_t nameOffset
Byte offset to the null terminator of this signal's last alias in the names array.
Definition TraceTaps.h:26
uint32_t reserved
Padding and reserved for future use.
Definition TraceTaps.h:30
void initializeArray(ConversionPatternRewriter &rewriter, Location loc, Value alloc, ArrayAttr initAttr, ArrayRefType arrayRefType) const
size_t computeAllocaAlignment(ArrayRefType type, Operation *op) const
LogicalResult matchAndRewrite(ArrayRefAllocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
bool isZero(ArrayAttr arrayAttr) const
DenseMap< ArrayRefType, size_t > alignmentCache
LogicalResult matchAndRewrite(ArrayRefCopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefCreateOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefFromArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefGetOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefInjectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ArrayRefToArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LLVM::GlobalOp buildTraceInfoStruct(arc::RuntimeModelOp &op, ConversionPatternRewriter &rewriter) const
static constexpr uint64_t runtimeApiVersion
LogicalResult matchAndRewrite(arc::RuntimeModelOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final
Helper class mapping array values (HW or LLVM Dialect) to pointers to buffers containing the array va...
Definition HWToLLVM.h:47