18#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
19#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
20#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
22#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
23#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
24#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
25#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
26#include "mlir/Dialect/Func/IR/FuncOps.h"
27#include "mlir/Dialect/Index/IR/IndexOps.h"
28#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
29#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
30#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
31#include "mlir/Dialect/SCF/IR/SCF.h"
32#include "mlir/IR/BuiltinDialect.h"
33#include "mlir/Pass/Pass.h"
34#include "mlir/Transforms/DialectConversion.h"
35#include "llvm/Support/Debug.h"
37#define DEBUG_TYPE "lower-arc-to-llvm"
40#define GEN_PASS_DEF_LOWERARCTOLLVM
41#include "circt/Conversion/Passes.h.inc"
// Returns the symbol name of a model's eval function: `modelName` followed by
// "_eval".
// NOTE(review): per the symbol summary accompanying this listing, this function
// returns `llvm::Twine` by value. LLVM documents that Twines should not be
// stored or returned, because they may reference temporaries. Here the operands
// are the caller's string data and a string literal, so whether the result is
// safe depends on Twine storing a StringRef operand by data pointer rather than
// by the address of the by-value `modelName` parameter. Returning `std::string`
// would remove that dependency.
54 return modelName +
"_eval";
// Lowers `arc.model` to a `func.func`. A `func.return` terminator is appended to
// the model's body block, and the model's region is moved into a new function.
// The function's inputs are the body's block-argument types and it has no
// results. The function name (`funcName`) is computed on lines not shown here.
60 using OpConversionPattern::OpConversionPattern;
62 matchAndRewrite(arc::ModelOp op, OpAdaptor adaptor,
63 ConversionPatternRewriter &rewriter)
const final {
// Terminate the model body so it is a valid function body once moved. The
// guard restores the rewriter's insertion point when it goes out of scope.
65 IRRewriter::InsertionGuard guard(rewriter);
66 rewriter.setInsertionPointToEnd(&op.getBodyBlock());
67 rewriter.create<func::ReturnOp>(op.getLoc());
72 rewriter.getFunctionType(op.getBody().getArgumentTypes(), {});
74 rewriter.create<mlir::func::FuncOp>(op.getLoc(), funcName, funcType);
75 rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
/// Lowers `arc.alloc_storage` to an `llvm.getelementptr` with an `i8` element
/// type, i.e. a pointer `*op.getOffset()` bytes past the base pointer. The base
/// operand is passed on a line not shown here. How an op without an offset is
/// handled follows the `has_value()` check below, also on a line not shown.
81struct AllocStorageOpLowering
83 using OpConversionPattern::OpConversionPattern;
85 matchAndRewrite(arc::AllocStorageOp op, OpAdaptor adaptor,
86 ConversionPatternRewriter &rewriter)
const final {
87 auto type = typeConverter->convertType(op.getType());
88 if (!op.getOffset().has_value())
90 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, type, rewriter.getI8Type(),
92 LLVM::GEPArg(*op.getOffset()));
/// Shared lowering for ops that allocate state at an integer "offset" attribute
/// inside a storage. It is instantiated in runOnOperation for `arc.alloc_state`,
/// `arc.root_input` and `arc.root_output`. The op is replaced by a pointer
/// `offset` bytes past the converted storage pointer (GEP with an `i8` element
/// type). The offset is zero-extended to 64 bits.
/// NOTE(review): `offsetAttr.getValue()` dereferences the looked-up attribute.
/// The lines between the lookup and the GEP are not shown here; confirm they
/// reject ops that lack an "offset" attribute.
97template <
class ConcreteOp>
101 using OpAdaptor =
typename ConcreteOp::Adaptor;
104 matchAndRewrite(ConcreteOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter)
const final {
107 auto offsetAttr = op->template getAttrOfType<IntegerAttr>(
"offset");
110 Value ptr = rewriter.create<LLVM::GEPOp>(
111 op->getLoc(), adaptor.getStorage().getType(), rewriter.getI8Type(),
112 adaptor.getStorage(),
113 LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
114 rewriter.replaceOp(op, ptr);
// Lowers `arc.state_read` to an `llvm.load` of the converted result type from
// the (converted) state pointer.
120 using OpConversionPattern::OpConversionPattern;
122 matchAndRewrite(arc::StateReadOp op, OpAdaptor adaptor,
123 ConversionPatternRewriter &rewriter)
const final {
124 auto type = typeConverter->convertType(op.getType());
125 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, type, adaptor.getState());
// Lowers `arc.state_write` to an `llvm.store` of the new value. If the op has a
// condition operand, the store is wrapped in an `scf.if` on that condition.
// runOnOperation registers the SCF-to-ControlFlow patterns that later lower
// that `scf.if`. Without a condition, the store is emitted unconditionally.
131 using OpConversionPattern::OpConversionPattern;
133 matchAndRewrite(arc::StateWriteOp op, OpAdaptor adaptor,
134 ConversionPatternRewriter &rewriter)
const final {
135 if (adaptor.getCondition()) {
136 rewriter.replaceOpWithNewOp<scf::IfOp>(
137 op, adaptor.getCondition(), [&](
auto &builder,
auto loc) {
138 builder.template create<LLVM::StoreOp>(loc, adaptor.getValue(),
140 builder.template create<scf::YieldOp>(loc);
143 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
// Lowers `arc.alloc_memory` to a pointer `offset` bytes past the converted
// storage pointer (GEP with an `i8` element type, offset zero-extended).
// NOTE(review): the visible body matches AllocStateLikeOpLowering. If the lines
// not shown here add nothing memory-specific, this pattern could be replaced by
// AllocStateLikeOpLowering<arc::AllocMemoryOp>.
151 using OpConversionPattern::OpConversionPattern;
153 matchAndRewrite(arc::AllocMemoryOp op, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter)
const final {
155 auto offsetAttr = op->getAttrOfType<IntegerAttr>(
"offset");
158 Value ptr = rewriter.create<LLVM::GEPOp>(
159 op.getLoc(), adaptor.getStorage().getType(), rewriter.getI8Type(),
160 adaptor.getStorage(),
161 LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
163 rewriter.replaceOp(op, ptr);
// Lowers `arc.storage.get` to a pointer `offset` bytes past the parent storage
// pointer (GEP with an `i8` element type).
// NOTE(review): the alloc patterns in this file pass their offset to the GEP as
// a constant index (LLVM::GEPArg). This one instead materializes an `i32`
// `llvm.mlir.constant` from the op's offset attribute and uses it as a dynamic
// index. Confirm that attribute is compatible with an i32 result. Using
// LLVM::GEPArg would also avoid the extra constant op.
169 using OpConversionPattern::OpConversionPattern;
171 matchAndRewrite(arc::StorageGetOp op, OpAdaptor adaptor,
172 ConversionPatternRewriter &rewriter)
const final {
173 Value offset = rewriter.create<LLVM::ConstantOp>(
174 op.getLoc(), rewriter.getI32Type(), op.getOffsetAttr());
175 Value ptr = rewriter.create<LLVM::GEPOp>(
176 op.getLoc(), adaptor.getStorage().getType(), rewriter.getI8Type(),
177 adaptor.getStorage(), offset);
178 rewriter.replaceOp(op, ptr);
/// Computes the address of one word of an Arc memory and whether the access is
/// in bounds.
///
/// Returns `{ptr, withinBounds}`:
/// - `ptr` is a GEP into `memory` with element type `i<stride*8>`, indexed by
///   the zero-extended address;
/// - `withinBounds` is an `i1` that is true iff `address < type.getNumWords()`
///   (unsigned compare).
///
/// The pointer is computed unconditionally. The read/write lowerings in this
/// file dereference it only under an `scf.if` on `withinBounds`.
188static MemoryAccess prepareMemoryAccess(Location loc, Value memory,
189 Value address, MemoryType type,
190 ConversionPatternRewriter &rewriter) {
// Widen the address by one bit. This is presumably so the word count, which
// can be as large as 2^width for a `width`-bit address, is representable in
// the comparison type.
191 auto zextAddrType = rewriter.getIntegerType(
192 cast<IntegerType>(address.getType()).getWidth() + 1);
193 Value
addr = rewriter.create<LLVM::ZExtOp>(loc, zextAddrType, address);
// NOTE(review): the bound is built from an i32 attribute but given type
// `zextAddrType`. Confirm the LLVM dialect accepts this width mismatch for
// address widths other than 31, and that getNumWords() always fits in 32 bits.
194 Value addrLimit = rewriter.create<LLVM::ConstantOp>(
195 loc, zextAddrType, rewriter.getI32IntegerAttr(type.getNumWords()));
196 Value withinBounds = rewriter.create<LLVM::ICmpOp>(
197 loc, LLVM::ICmpPredicate::ult,
addr, addrLimit);
// Element type is one memory word (`stride` bytes), so the index is in words.
198 Value ptr = rewriter.create<LLVM::GEPOp>(
199 loc, LLVM::LLVMPointerType::get(memory.getContext()),
200 rewriter.getIntegerType(type.getStride() * 8), memory, ValueRange{
addr});
201 return {ptr, withinBounds};
// Lowers `arc.memory_read` to a bounds-checked load. An `scf.if` on the
// in-bounds flag yields the loaded word; the else branch yields a zero constant
// of the converted result type. Out-of-bounds reads therefore return 0.
// NOTE(review): the load uses `memoryType.getWordType()` while the zero uses
// the converted `type`; these are presumably the same type. The zero is built
// from an i64 attribute whatever `type`'s width is. Confirm the LLVM dialect
// accepts that mismatch.
205 using OpConversionPattern::OpConversionPattern;
207 matchAndRewrite(arc::MemoryReadOp op, OpAdaptor adaptor,
208 ConversionPatternRewriter &rewriter)
const final {
209 auto type = typeConverter->convertType(op.getType());
210 auto memoryType = cast<MemoryType>(op.getMemory().getType());
212 prepareMemoryAccess(op.getLoc(), adaptor.getMemory(),
213 adaptor.getAddress(), memoryType, rewriter);
217 rewriter.replaceOpWithNewOp<scf::IfOp>(
218 op, access.withinBounds,
219 [&](
auto &builder,
auto loc) {
// In bounds: load one memory word.
220 Value loadOp = builder.template create<LLVM::LoadOp>(
221 loc, memoryType.getWordType(), access.ptr);
222 builder.template create<scf::YieldOp>(loc, loadOp);
224 [&](
auto &builder,
auto loc) {
// Out of bounds: yield zero instead of touching memory.
225 Value zeroValue = builder.template create<LLVM::ConstantOp>(
226 loc, type, builder.getI64IntegerAttr(0));
227 builder.template create<scf::YieldOp>(loc, zeroValue);
// Lowers `arc.memory_write` to a store guarded by an `scf.if`. The guard is the
// in-bounds flag, ANDed with the optional enable operand when present. Disabled
// or out-of-bounds writes are therefore dropped.
234 using OpConversionPattern::OpConversionPattern;
236 matchAndRewrite(arc::MemoryWriteOp op, OpAdaptor adaptor,
237 ConversionPatternRewriter &rewriter)
const final {
238 auto access = prepareMemoryAccess(
239 op.getLoc(), adaptor.getMemory(), adaptor.getAddress(),
240 cast<MemoryType>(op.getMemory().getType()), rewriter);
241 auto enable = access.withinBounds;
242 if (adaptor.getEnable())
243 enable = rewriter.create<LLVM::AndOp>(op.getLoc(), adaptor.getEnable(),
247 rewriter.replaceOpWithNewOp<scf::IfOp>(
248 op, enable, [&](
auto &builder,
auto loc) {
249 builder.template create<LLVM::StoreOp>(loc, adaptor.getData(),
251 builder.template create<scf::YieldOp>(loc);
// Lowers `seq.clock_gate`. This pass's type converter maps clocks to `i1`, so
// gating is an `llvm.and` of the clock value and the enable.
// NOTE(review): only `input` and `enable` are used. If `seq.clock_gate` carries
// an optional test-enable operand in the CIRCT version in use, it is ignored
// here; confirm.
259 using OpConversionPattern::OpConversionPattern;
261 matchAndRewrite(seq::ClockGateOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const final {
263 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, adaptor.getInput(),
264 adaptor.getEnable());
// Lowers `seq.clock_inv`. With clocks represented as `i1`, inversion is an
// `llvm.xor` with the constant `true`.
271 using OpConversionPattern::OpConversionPattern;
273 matchAndRewrite(seq::ClockInverterOp op, OpAdaptor adaptor,
274 ConversionPatternRewriter &rewriter)
const final {
275 auto constTrue = rewriter.create<LLVM::ConstantOp>(op->getLoc(),
276 rewriter.getI1Type(), 1);
277 rewriter.replaceOpWithNewOp<LLVM::XOrOp>(op, adaptor.getInput(), constTrue);
// Lowers `arc.zero_count` to the LLVM count-leading-zeros intrinsic for the
// `leading` predicate, and to count-trailing-zeros otherwise. The result type
// is the input type.
283 using OpConversionPattern::OpConversionPattern;
285 matchAndRewrite(arc::ZeroCountOp op, OpAdaptor adaptor,
286 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): with `is_zero_poison` set to true, the result for an all-zero
// input is poison. Confirm that `arc.zero_count` leaves that case undefined;
// otherwise pass false, which yields the bit width for a zero input.
288 IntegerAttr isZeroPoison = rewriter.getBoolAttr(
true);
290 if (op.getPredicate() == arc::ZeroCountPredicate::leading) {
291 rewriter.replaceOpWithNewOp<LLVM::CountLeadingZerosOp>(
292 op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
296 rewriter.replaceOpWithNewOp<LLVM::CountTrailingZerosOp>(
297 op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
// Lowers `seq.const_clock` to an `i1` `llvm.mlir.constant` holding the clock's
// constant value. The value is cast to int64_t for the ConstantOp builder.
303 using OpConversionPattern::OpConversionPattern;
305 matchAndRewrite(seq::ConstClockOp op, OpAdaptor adaptor,
306 ConversionPatternRewriter &rewriter)
const override {
307 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
308 op, rewriter.getI1Type(),
static_cast<int64_t
>(op.getValue()));
/// Generic pattern that replaces an op with its (already converted) `input`
/// operand. runOnOperation uses it for `seq.to_clock` and `seq.from_clock`.
/// Those ops become no-ops once clocks are represented as `i1`.
313template <
typename OpTy>
316 using OpAdaptor =
typename OpTy::Adaptor;
318 matchAndRewrite(OpTy op, OpAdaptor adaptor,
319 ConversionPatternRewriter &rewriter)
const override {
320 rewriter.replaceOp(op, adaptor.getInput());
// Size in bytes of a model instance's state buffer. SimInstantiateOpLowering
// uses it as the malloc/memset size.
334 size_t numStateBytes;
// State descriptors keyed by state name. The arc.sim.* lowerings use it to
// find a port's offset.
// NOTE(review): the StringRef keys do not own their characters. runOnOperation
// builds them from the collected model info, which must outlive this map.
335 llvm::DenseMap<StringRef, StateInfo> states;
// Optional functions called with the state pointer right after allocation and
// right before deallocation. Each is null if the model has no such function.
336 mlir::FlatSymbolRefAttr initialFnSymbol;
337 mlir::FlatSymbolRefAttr finalFnSymbol;
/// Base class for the `arc.sim.*` lowerings that need per-model layout
/// information. It stores a reference to the model map built in
/// runOnOperation, which must outlive the pattern.
340template <
typename OpTy>
342 ModelAwarePattern(
const TypeConverter &typeConverter, MLIRContext *context,
343 llvm::DenseMap<StringRef, ModelInfoMap> &modelInfo)
345 modelInfo(modelInfo) {}
// Returns a pointer `port.offset` bytes past `state`, the start of an
// instance's state buffer (GEP with an `i8` element type).
348 Value createPtrToPortState(ConversionPatternRewriter &rewriter, Location loc,
349 Value state,
const StateInfo &port)
const {
350 MLIRContext *ctx = rewriter.getContext();
351 return rewriter.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(ctx),
352 IntegerType::get(ctx, 8), state,
353 LLVM::GEPArg(port.
offset));
// Model name -> layout information. Not owned.
356 llvm::DenseMap<StringRef, ModelInfoMap> &modelInfo;
/// Lowers `arc.sim.instantiate` by expanding it in place:
///   1. malloc a state buffer of the model's `numStateBytes` and zero it with
///      `llvm.intr.memset`;
///   2. call the model's initial function, if any, with the buffer pointer;
///   3. inline the op's body block before the op. The values that replace the
///      block arguments are passed on a line not shown here, presumably the
///      buffer pointer for the instance argument;
///   4. call the model's final function, if any, then free the buffer and
///      erase the op.
/// `malloc`/`free` declarations are looked up or created in the enclosing
/// module, using the converted index type for the size argument.
/// NOTE(review): no null check of the malloc result is visible in this chunk.
361struct SimInstantiateOpLowering
362 :
public ModelAwarePattern<arc::SimInstantiateOp> {
363 using ModelAwarePattern::ModelAwarePattern;
366 matchAndRewrite(arc::SimInstantiateOp op, OpAdaptor adaptor,
367 ConversionPatternRewriter &rewriter)
const final {
368 auto modelIt = modelInfo.find(
369 cast<SimModelInstanceType>(op.getBody().getArgument(0).getType())
372 ModelInfoMap &model = modelIt->second;
374 ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
378 ConversionPatternRewriter::InsertionGuard guard(rewriter);
382 Type convertedIndex = typeConverter->convertType(rewriter.getIndexType());
384 FailureOr<LLVM::LLVMFuncOp> mallocFunc =
385 LLVM::lookupOrCreateMallocFn(rewriter, moduleOp, convertedIndex);
386 if (failed(mallocFunc))
389 FailureOr<LLVM::LLVMFuncOp> freeFunc =
390 LLVM::lookupOrCreateFreeFn(rewriter, moduleOp);
391 if (failed(freeFunc))
394 Location loc = op.getLoc();
395 Value numStateBytes = rewriter.create<LLVM::ConstantOp>(
396 loc, convertedIndex, model.numStateBytes);
397 Value allocated = rewriter
398 .create<LLVM::CallOp>(loc, mallocFunc.value(),
399 ValueRange{numStateBytes})
402 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI8Type(), 0);
403 rewriter.create<LLVM::MemsetOp>(loc, allocated, zero, numStateBytes,
false);
406 if (model.initialFnSymbol) {
407 auto initialFnType = LLVM::LLVMFunctionType::get(
408 LLVM::LLVMVoidType::get(op.getContext()),
409 {LLVM::LLVMPointerType::get(op.getContext())});
410 rewriter.create<LLVM::CallOp>(loc, initialFnType, model.initialFnSymbol,
411 ValueRange{allocated});
415 rewriter.inlineBlockBefore(&adaptor.getBody().getBlocks().front(), op,
419 if (model.finalFnSymbol) {
420 auto finalFnType = LLVM::LLVMFunctionType::get(
421 LLVM::LLVMVoidType::get(op.getContext()),
422 {LLVM::LLVMPointerType::get(op.getContext())});
423 rewriter.create<LLVM::CallOp>(loc, finalFnType, model.finalFnSymbol,
424 ValueRange{allocated});
427 rewriter.create<LLVM::CallOp>(loc, freeFunc.value(), ValueRange{allocated});
428 rewriter.eraseOp(op);
/// Lowers `arc.sim.set_input` to an `llvm.store` of the new value at the input
/// port's offset in the instance's state buffer. If the model has no state with
/// the input's name, the op is erased and the write is dropped. Lines not shown
/// here may also report that case.
434struct SimSetInputOpLowering :
public ModelAwarePattern<arc::SimSetInputOp> {
435 using ModelAwarePattern::ModelAwarePattern;
438 matchAndRewrite(arc::SimSetInputOp op, OpAdaptor adaptor,
439 ConversionPatternRewriter &rewriter)
const final {
441 modelInfo.find(cast<SimModelInstanceType>(op.getInstance().getType())
444 ModelInfoMap &model = modelIt->second;
446 auto portIt = model.states.find(op.getInput());
447 if (portIt == model.states.end()) {
450 rewriter.eraseOp(op);
455 Value statePtr = createPtrToPortState(rewriter, op.getLoc(),
456 adaptor.getInstance(), port);
457 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
/// Lowers `arc.sim.get_port` to an `llvm.load` of the converted value type from
/// the port's offset in the instance's state buffer. A port that has no entry
/// in the model's state map reads as the constant 0.
464struct SimGetPortOpLowering :
public ModelAwarePattern<arc::SimGetPortOp> {
465 using ModelAwarePattern::ModelAwarePattern;
468 matchAndRewrite(arc::SimGetPortOp op, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter)
const final {
471 modelInfo.find(cast<SimModelInstanceType>(op.getInstance().getType())
474 ModelInfoMap &model = modelIt->second;
476 auto type = typeConverter->convertType(op.getValue().getType());
479 auto portIt = model.states.find(op.getPort());
480 if (portIt == model.states.end()) {
483 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, type, 0);
488 Value statePtr = createPtrToPortState(rewriter, op.getLoc(),
489 adaptor.getInstance(), port);
490 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, type, statePtr);
/// Lowers `arc.sim.step` to a call with no results (`std::nullopt` result type)
/// to the model's eval function, passing the instance's state pointer. The
/// callee symbol `evalFunc` is derived from the model name on a line not shown
/// here.
496struct SimStepOpLowering :
public ModelAwarePattern<arc::SimStepOp> {
497 using ModelAwarePattern::ModelAwarePattern;
500 matchAndRewrite(arc::SimStepOp op, OpAdaptor adaptor,
501 ConversionPatternRewriter &rewriter)
const final {
502 StringRef modelName = cast<SimModelInstanceType>(op.getInstance().getType())
506 StringAttr evalFunc =
508 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, std::nullopt, evalFunc,
509 adaptor.getInstance());
/// Lowers `arc.sim.emit` to a `printf` call with the format
/// "<value name> = %zx\n". The string "(truncated) " is appended after " = ",
/// presumably only when the value was narrowed; that condition is on a line not
/// shown here.
///
/// The integer value is resized to the bit width of the index type in the
/// closest data layout, which stands in for `size_t` here: wider values are
/// truncated and narrower ones zero-extended.
///
/// The NUL-terminated format string lives in an internal constant global named
/// "_arc_sim_emit_trunc_<name>" or "_arc_sim_emit_full_<name>". The global is
/// created at the start of the module only if it does not already exist.
/// NOTE(review): `printf` is declared here as variadic with a `void` result,
/// whereas C's printf returns `int`. Confirm this is acceptable for the
/// targeted ABIs.
518struct SimEmitValueOpLowering
520 using OpConversionPattern::OpConversionPattern;
523 matchAndRewrite(arc::SimEmitValueOp op, OpAdaptor adaptor,
524 ConversionPatternRewriter &rewriter)
const final {
525 auto valueType = dyn_cast<IntegerType>(adaptor.getValue().getType());
529 Location loc = op.getLoc();
531 ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
// Resize the value to the index (size_t) width so it matches the `%zx`
// conversion.
538 Value toPrint = adaptor.getValue();
539 DataLayout layout = DataLayout::closest(op);
540 llvm::TypeSize sizeOfSizeT =
541 layout.getTypeSizeInBits(rewriter.getIndexType());
542 assert(!sizeOfSizeT.isScalable() &&
543 sizeOfSizeT.getFixedValue() <= std::numeric_limits<unsigned>::max());
544 bool truncated =
false;
545 if (valueType.getWidth() > sizeOfSizeT) {
546 toPrint = rewriter.create<LLVM::TruncOp>(
547 loc, IntegerType::get(getContext(), sizeOfSizeT.getFixedValue()),
550 }
else if (valueType.getWidth() < sizeOfSizeT)
551 toPrint = rewriter.create<LLVM::ZExtOp>(
552 loc, IntegerType::get(getContext(), sizeOfSizeT.getFixedValue()),
556 auto printfFunc = LLVM::lookupOrCreateFn(
557 rewriter, moduleOp,
"printf", LLVM::LLVMPointerType::get(getContext()),
558 LLVM::LLVMVoidType::get(getContext()),
true);
559 if (failed(printfFunc))
// Reuse one format-string global per (truncated/full, value name) pair.
563 SmallString<16> formatStrName{
"_arc_sim_emit_"};
564 formatStrName.append(truncated ?
"trunc_" :
"full_");
565 formatStrName.append(adaptor.getValueName());
566 LLVM::GlobalOp formatStrGlobal;
567 if (!(formatStrGlobal =
568 moduleOp.lookupSymbol<LLVM::GlobalOp>(formatStrName))) {
569 ConversionPatternRewriter::InsertionGuard insertGuard(rewriter);
571 SmallString<16> formatStr = adaptor.getValueName();
572 formatStr.append(
" = ");
574 formatStr.append(
"(truncated) ");
575 formatStr.append(
"%zx\n");
// Copy the format string and append the terminating NUL byte for printf.
576 SmallVector<char> formatStrVec{formatStr.begin(), formatStr.end()};
577 formatStrVec.push_back(0);
579 rewriter.setInsertionPointToStart(moduleOp.getBody());
581 LLVM::LLVMArrayType::get(rewriter.getI8Type(), formatStrVec.size());
582 formatStrGlobal = rewriter.create<LLVM::GlobalOp>(
583 loc, globalType,
true, LLVM::Linkage::Internal,
584 formatStrName, rewriter.getStringAttr(formatStrVec),
588 Value formatStrGlobalPtr =
589 rewriter.create<LLVM::AddressOfOp>(loc, formatStrGlobal);
590 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
591 op, printfFunc.value(), ValueRange{formatStrGlobalPtr, toPrint});
/// The lower-arc-to-llvm pass, built on the tablegen-generated base class. All
/// of its work happens in runOnOperation.
604struct LowerArcToLLVMPass
605 :
public circt::impl::LowerArcToLLVMBase<LowerArcToLLVMPass> {
606 void runOnOperation()
override;
/// Lowers the whole module to the LLVM dialect in a single full conversion:
///  - Arc/Seq types map to LLVM types. Clocks become `i1`; storage, memory,
///    state and simulation-instance handles become opaque pointers.
///  - One pattern set combines the upstream SCF-to-CF, func, cf, arith and
///    index lowerings, the HW lowering, and the patterns defined in this file.
///  - Model layouts gathered from the module (`models`, collected on lines not
///    shown here) are indexed by model name for the `arc.sim.*` patterns.
/// The failure path after applyFullConversion is on a line not shown;
/// presumably it signals pass failure.
610void LowerArcToLLVMPass::runOnOperation() {
// NOTE(review): `scf.yield` is explicitly legal, presumably so the yields
// created inside the `scf.if` regions by the patterns above are not rejected
// before the SCF-to-CF patterns lower them. Confirm.
622 LLVMConversionTarget target(getContext());
623 target.addLegalOp<mlir::ModuleOp>();
624 target.addLegalOp<scf::YieldOp>();
// Arc/Seq type conversions on top of the standard LLVM type converter.
627 LLVMTypeConverter converter(&getContext());
628 converter.addConversion([&](seq::ClockType type) {
629 return IntegerType::get(type.getContext(), 1);
631 converter.addConversion([&](StorageType type) {
632 return LLVM::LLVMPointerType::get(type.getContext());
634 converter.addConversion([&](MemoryType type) {
635 return LLVM::LLVMPointerType::get(type.getContext());
637 converter.addConversion([&](StateType type) {
638 return LLVM::LLVMPointerType::get(type.getContext());
640 converter.addConversion([&](SimModelInstanceType type) {
641 return LLVM::LLVMPointerType::get(type.getContext());
// Upstream lowerings, so that everything reaches the LLVM dialect in one
// conversion.
645 RewritePatternSet
patterns(&getContext());
648 populateSCFToControlFlowConversionPatterns(
patterns);
649 populateFuncToLLVMConversionPatterns(converter,
patterns);
650 cf::populateControlFlowToLLVMConversionPatterns(converter,
patterns);
651 arith::populateArithToLLVMConversionPatterns(converter,
patterns);
652 index::populateIndexToLLVMConversionPatterns(converter,
patterns);
653 populateAnyFunctionOpInterfaceTypeConversionPattern(
patterns, converter);
656 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
658 constAggregateGlobalsMap);
666 AllocMemoryOpLowering,
667 AllocStateLikeOpLowering<arc::AllocStateOp>,
668 AllocStateLikeOpLowering<arc::RootInputOp>,
669 AllocStateLikeOpLowering<arc::RootOutputOp>,
670 AllocStorageOpLowering,
673 MemoryReadOpLowering,
674 MemoryWriteOpLowering,
676 ReplaceOpWithInputPattern<seq::ToClockOp>,
677 ReplaceOpWithInputPattern<seq::FromClockOp>,
678 SeqConstClockLowering,
679 SimEmitValueOpLowering,
681 StateWriteOpLowering,
682 StorageGetOpLowering,
684 >(converter, &getContext());
// Index each model's states by name. The StringRef keys of `modelMap` and of
// each `states` map point into `models`, which outlives the conversion below.
687 SmallVector<ModelInfo> models;
693 llvm::DenseMap<StringRef, ModelInfoMap> modelMap(models.size());
695 llvm::DenseMap<StringRef, StateInfo> states(modelInfo.states.size());
696 for (
StateInfo &stateInfo : modelInfo.states)
697 states.insert({stateInfo.name, stateInfo});
700 ModelInfoMap{modelInfo.numStateBytes, std::move(states),
701 modelInfo.initialFnSym, modelInfo.finalFnSym}});
// The arc.sim.* patterns need the model layouts, so they are added separately.
704 patterns.add<SimInstantiateOpLowering, SimSetInputOpLowering,
705 SimGetPortOpLowering, SimStepOpLowering>(
706 converter, &getContext(), modelMap);
709 if (failed(applyFullConversion(getOperation(), target, std::move(
patterns))))
// Creates a fresh instance of the lower-arc-to-llvm pass.
714 return std::make_unique<LowerArcToLLVMPass>();
assert(baseType &&"element must be base type")
static llvm::Twine evalSymbolFromModelName(StringRef modelName)
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
mlir::LogicalResult collectModels(mlir::ModuleOp module, llvm::SmallVector< ModelInfo > &models)
Collects information about all Arc models in the provided module, and adds it to models.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap)
Get the HW to LLVM conversion patterns.
void populateCombToArithConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
void populateCombToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to LLVM conversion patterns.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
std::unique_ptr< OperationPass< ModuleOp > > createLowerArcToLLVMPass()
Gathers information about a given Arc model.
Gathers information about a given Arc state.