17 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
18 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
19 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
20 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
21 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
22 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
23 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
24 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"
26 #include "mlir/Dialect/Index/IR/IndexOps.h"
27 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
28 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
29 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
30 #include "mlir/Dialect/SCF/IR/SCF.h"
31 #include "mlir/IR/BuiltinDialect.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Transforms/DialectConversion.h"
34 #include "llvm/Support/Debug.h"
36 #define DEBUG_TYPE "lower-arc-to-llvm"
39 #define GEN_PASS_DEF_LOWERARCTOLLVM
40 #include "circt/Conversion/Passes.h.inc"
44 using namespace circt;
53 return modelName +
"_eval";
59 using OpConversionPattern::OpConversionPattern;
61 matchAndRewrite(arc::ModelOp op, OpAdaptor adaptor,
62 ConversionPatternRewriter &rewriter)
const final {
64 IRRewriter::InsertionGuard guard(rewriter);
65 rewriter.setInsertionPointToEnd(&op.getBodyBlock());
66 rewriter.create<func::ReturnOp>(op.getLoc());
71 rewriter.getFunctionType(op.getBody().getArgumentTypes(), {});
73 rewriter.create<mlir::func::FuncOp>(op.getLoc(), funcName, funcType);
74 rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
80 struct AllocStorageOpLowering
82 using OpConversionPattern::OpConversionPattern;
84 matchAndRewrite(arc::AllocStorageOp op, OpAdaptor adaptor,
85 ConversionPatternRewriter &rewriter)
const final {
86 auto type = typeConverter->convertType(op.getType());
87 if (!op.getOffset().has_value())
89 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, type, rewriter.getI8Type(),
91 LLVM::GEPArg(*op.getOffset()));
96 template <
class ConcreteOp>
100 using OpAdaptor =
typename ConcreteOp::Adaptor;
103 matchAndRewrite(ConcreteOp op, OpAdaptor adaptor,
104 ConversionPatternRewriter &rewriter)
const final {
106 auto offsetAttr = op->template getAttrOfType<IntegerAttr>(
"offset");
109 Value ptr = rewriter.create<LLVM::GEPOp>(
110 op->getLoc(), adaptor.getStorage().getType(), rewriter.getI8Type(),
111 adaptor.getStorage(),
112 LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
113 rewriter.replaceOp(op, ptr);
119 using OpConversionPattern::OpConversionPattern;
121 matchAndRewrite(arc::StateReadOp op, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter)
const final {
123 auto type = typeConverter->convertType(op.getType());
124 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, type, adaptor.getState());
130 using OpConversionPattern::OpConversionPattern;
132 matchAndRewrite(arc::StateWriteOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter)
const final {
134 if (adaptor.getCondition()) {
135 rewriter.replaceOpWithNewOp<scf::IfOp>(
136 op, adaptor.getCondition(), [&](
auto &builder,
auto loc) {
137 builder.template create<LLVM::StoreOp>(loc, adaptor.getValue(),
139 builder.template create<scf::YieldOp>(loc);
142 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
150 using OpConversionPattern::OpConversionPattern;
152 matchAndRewrite(arc::AllocMemoryOp op, OpAdaptor adaptor,
153 ConversionPatternRewriter &rewriter)
const final {
154 auto offsetAttr = op->getAttrOfType<IntegerAttr>(
"offset");
157 Value ptr = rewriter.create<LLVM::GEPOp>(
158 op.getLoc(), adaptor.getStorage().getType(), rewriter.getI8Type(),
159 adaptor.getStorage(),
160 LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
162 rewriter.replaceOp(op, ptr);
168 using OpConversionPattern::OpConversionPattern;
170 matchAndRewrite(arc::StorageGetOp op, OpAdaptor adaptor,
171 ConversionPatternRewriter &rewriter)
const final {
172 Value offset = rewriter.create<LLVM::ConstantOp>(
173 op.getLoc(), rewriter.getI32Type(), op.getOffsetAttr());
174 Value ptr = rewriter.create<LLVM::GEPOp>(
175 op.getLoc(), adaptor.getStorage().getType(), rewriter.getI8Type(),
176 adaptor.getStorage(), offset);
177 rewriter.replaceOp(op, ptr);
182 struct MemoryAccess {
187 static MemoryAccess prepareMemoryAccess(Location loc, Value memory,
188 Value address, MemoryType type,
189 ConversionPatternRewriter &rewriter) {
190 auto zextAddrType = rewriter.getIntegerType(
191 cast<IntegerType>(address.getType()).getWidth() + 1);
192 Value
addr = rewriter.create<LLVM::ZExtOp>(loc, zextAddrType, address);
193 Value addrLimit = rewriter.create<LLVM::ConstantOp>(
194 loc, zextAddrType, rewriter.getI32IntegerAttr(type.getNumWords()));
195 Value withinBounds = rewriter.create<LLVM::ICmpOp>(
196 loc, LLVM::ICmpPredicate::ult,
addr, addrLimit);
197 Value ptr = rewriter.create<LLVM::GEPOp>(
199 rewriter.getIntegerType(type.getStride() * 8), memory, ValueRange{
addr});
200 return {ptr, withinBounds};
204 using OpConversionPattern::OpConversionPattern;
206 matchAndRewrite(arc::MemoryReadOp op, OpAdaptor adaptor,
207 ConversionPatternRewriter &rewriter)
const final {
208 auto type = typeConverter->convertType(op.getType());
209 auto memoryType = cast<MemoryType>(op.getMemory().getType());
211 prepareMemoryAccess(op.getLoc(), adaptor.getMemory(),
212 adaptor.getAddress(), memoryType, rewriter);
216 rewriter.replaceOpWithNewOp<scf::IfOp>(
217 op, access.withinBounds,
218 [&](
auto &builder,
auto loc) {
219 Value loadOp = builder.template create<LLVM::LoadOp>(
220 loc, memoryType.getWordType(), access.ptr);
221 builder.template create<scf::YieldOp>(loc, loadOp);
223 [&](
auto &builder,
auto loc) {
224 Value zeroValue = builder.template create<LLVM::ConstantOp>(
225 loc, type, builder.getI64IntegerAttr(0));
226 builder.template create<scf::YieldOp>(loc, zeroValue);
233 using OpConversionPattern::OpConversionPattern;
235 matchAndRewrite(arc::MemoryWriteOp op, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter)
const final {
237 auto access = prepareMemoryAccess(
238 op.getLoc(), adaptor.getMemory(), adaptor.getAddress(),
239 cast<MemoryType>(op.getMemory().getType()), rewriter);
240 auto enable = access.withinBounds;
241 if (adaptor.getEnable())
242 enable = rewriter.create<LLVM::AndOp>(op.getLoc(), adaptor.getEnable(),
246 rewriter.replaceOpWithNewOp<scf::IfOp>(
247 op, enable, [&](
auto &builder,
auto loc) {
248 builder.template create<LLVM::StoreOp>(loc, adaptor.getData(),
250 builder.template create<scf::YieldOp>(loc);
258 using OpConversionPattern::OpConversionPattern;
260 matchAndRewrite(seq::ClockGateOp op, OpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter)
const final {
262 rewriter.replaceOpWithNewOp<
comb::AndOp>(op, adaptor.getInput(),
263 adaptor.getEnable(),
true);
269 using OpConversionPattern::OpConversionPattern;
271 matchAndRewrite(arc::ZeroCountOp op, OpAdaptor adaptor,
272 ConversionPatternRewriter &rewriter)
const override {
274 IntegerAttr isZeroPoison = rewriter.getBoolAttr(
true);
276 if (op.getPredicate() == arc::ZeroCountPredicate::leading) {
277 rewriter.replaceOpWithNewOp<LLVM::CountLeadingZerosOp>(
278 op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
282 rewriter.replaceOpWithNewOp<LLVM::CountTrailingZerosOp>(
283 op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
289 using OpConversionPattern::OpConversionPattern;
291 matchAndRewrite(seq::ConstClockOp op, OpAdaptor adaptor,
292 ConversionPatternRewriter &rewriter)
const override {
293 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
294 op, rewriter.getI1Type(),
static_cast<int64_t
>(op.getValue()));
299 template <
typename OpTy>
302 using OpAdaptor =
typename OpTy::Adaptor;
304 matchAndRewrite(OpTy op, OpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter)
const override {
306 rewriter.replaceOp(op, adaptor.getInput());
319 struct ModelInfoMap {
320 size_t numStateBytes;
321 llvm::DenseMap<StringRef, StateInfo> states;
322 mlir::FlatSymbolRefAttr initialFnSymbol;
323 mlir::FlatSymbolRefAttr finalFnSymbol;
326 template <
typename OpTy>
328 ModelAwarePattern(
const TypeConverter &typeConverter, MLIRContext *context,
329 llvm::DenseMap<StringRef, ModelInfoMap> &modelInfo)
331 modelInfo(modelInfo) {}
334 Value createPtrToPortState(ConversionPatternRewriter &rewriter, Location loc,
335 Value state,
const StateInfo &port)
const {
336 MLIRContext *ctx = rewriter.getContext();
339 LLVM::GEPArg(port.
offset));
342 llvm::DenseMap<StringRef, ModelInfoMap> &modelInfo;
347 struct SimInstantiateOpLowering
348 :
public ModelAwarePattern<arc::SimInstantiateOp> {
349 using ModelAwarePattern::ModelAwarePattern;
352 matchAndRewrite(arc::SimInstantiateOp op, OpAdaptor adaptor,
353 ConversionPatternRewriter &rewriter)
const final {
354 auto modelIt = modelInfo.find(
355 cast<SimModelInstanceType>(op.getBody().getArgument(0).getType())
358 ModelInfoMap &model = modelIt->second;
360 ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
364 ConversionPatternRewriter::InsertionGuard guard(rewriter);
368 Type convertedIndex = typeConverter->convertType(rewriter.getIndexType());
370 LLVM::LLVMFuncOp mallocFunc =
371 LLVM::lookupOrCreateMallocFn(moduleOp, convertedIndex);
372 LLVM::LLVMFuncOp freeFunc = LLVM::lookupOrCreateFreeFn(moduleOp);
374 Location loc = op.getLoc();
375 Value numStateBytes = rewriter.create<LLVM::ConstantOp>(
376 loc, convertedIndex, model.numStateBytes);
379 .create<LLVM::CallOp>(loc, mallocFunc, ValueRange{numStateBytes})
382 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI8Type(), 0);
383 rewriter.create<LLVM::MemsetOp>(loc, allocated, zero, numStateBytes,
false);
386 if (model.initialFnSymbol) {
389 {LLVM::LLVMPointerType::get(op.getContext())});
390 rewriter.create<LLVM::CallOp>(loc, initialFnType, model.initialFnSymbol,
391 ValueRange{allocated});
395 rewriter.inlineBlockBefore(&adaptor.getBody().getBlocks().front(), op,
399 if (model.finalFnSymbol) {
402 {LLVM::LLVMPointerType::get(op.getContext())});
403 rewriter.create<LLVM::CallOp>(loc, finalFnType, model.finalFnSymbol,
404 ValueRange{allocated});
407 rewriter.create<LLVM::CallOp>(loc, freeFunc, ValueRange{allocated});
408 rewriter.eraseOp(op);
414 struct SimSetInputOpLowering :
public ModelAwarePattern<arc::SimSetInputOp> {
415 using ModelAwarePattern::ModelAwarePattern;
418 matchAndRewrite(arc::SimSetInputOp op, OpAdaptor adaptor,
419 ConversionPatternRewriter &rewriter)
const final {
421 modelInfo.find(cast<SimModelInstanceType>(op.getInstance().getType())
424 ModelInfoMap &model = modelIt->second;
426 auto portIt = model.states.find(op.getInput());
427 if (portIt == model.states.end()) {
430 rewriter.eraseOp(op);
435 Value statePtr = createPtrToPortState(rewriter, op.getLoc(),
436 adaptor.getInstance(), port);
437 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
444 struct SimGetPortOpLowering :
public ModelAwarePattern<arc::SimGetPortOp> {
445 using ModelAwarePattern::ModelAwarePattern;
448 matchAndRewrite(arc::SimGetPortOp op, OpAdaptor adaptor,
449 ConversionPatternRewriter &rewriter)
const final {
451 modelInfo.find(cast<SimModelInstanceType>(op.getInstance().getType())
454 ModelInfoMap &model = modelIt->second;
456 auto type = typeConverter->convertType(op.getValue().getType());
459 auto portIt = model.states.find(op.getPort());
460 if (portIt == model.states.end()) {
463 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, type, 0);
468 Value statePtr = createPtrToPortState(rewriter, op.getLoc(),
469 adaptor.getInstance(), port);
470 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, type, statePtr);
476 struct SimStepOpLowering :
public ModelAwarePattern<arc::SimStepOp> {
477 using ModelAwarePattern::ModelAwarePattern;
480 matchAndRewrite(arc::SimStepOp op, OpAdaptor adaptor,
481 ConversionPatternRewriter &rewriter)
const final {
482 StringRef modelName = cast<SimModelInstanceType>(op.getInstance().getType())
486 StringAttr evalFunc =
488 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, std::nullopt, evalFunc,
489 adaptor.getInstance());
498 struct SimEmitValueOpLowering
500 using OpConversionPattern::OpConversionPattern;
503 matchAndRewrite(arc::SimEmitValueOp op, OpAdaptor adaptor,
504 ConversionPatternRewriter &rewriter)
const final {
505 auto valueType = dyn_cast<IntegerType>(adaptor.getValue().getType());
509 Location loc = op.getLoc();
511 ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
518 Value toPrint = adaptor.getValue();
519 DataLayout layout = DataLayout::closest(op);
520 llvm::TypeSize sizeOfSizeT =
521 layout.getTypeSizeInBits(rewriter.getIndexType());
522 assert(!sizeOfSizeT.isScalable() &&
523 sizeOfSizeT.getFixedValue() <= std::numeric_limits<unsigned>::max());
524 bool truncated =
false;
525 if (valueType.getWidth() > sizeOfSizeT) {
526 toPrint = rewriter.create<LLVM::TruncOp>(
530 }
else if (valueType.getWidth() < sizeOfSizeT)
531 toPrint = rewriter.create<LLVM::ZExtOp>(
536 auto printfFunc = LLVM::lookupOrCreateFn(
541 SmallString<16> formatStrName{
"_arc_sim_emit_"};
542 formatStrName.append(truncated ?
"trunc_" :
"full_");
543 formatStrName.append(adaptor.getValueName());
544 LLVM::GlobalOp formatStrGlobal;
545 if (!(formatStrGlobal =
546 moduleOp.lookupSymbol<LLVM::GlobalOp>(formatStrName))) {
547 ConversionPatternRewriter::InsertionGuard insertGuard(rewriter);
549 SmallString<16> formatStr = adaptor.getValueName();
550 formatStr.append(
" = ");
552 formatStr.append(
"(truncated) ");
553 formatStr.append(
"%zx\n");
554 SmallVector<char> formatStrVec{formatStr.begin(), formatStr.end()};
555 formatStrVec.push_back(0);
557 rewriter.setInsertionPointToStart(moduleOp.getBody());
560 formatStrGlobal = rewriter.create<LLVM::GlobalOp>(
561 loc, globalType,
true, LLVM::Linkage::Internal,
562 formatStrName, rewriter.getStringAttr(formatStrVec),
566 Value formatStrGlobalPtr =
567 rewriter.create<LLVM::AddressOfOp>(loc, formatStrGlobal);
568 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
569 op, printfFunc, ValueRange{formatStrGlobalPtr, toPrint});
582 struct LowerArcToLLVMPass
583 :
public circt::impl::LowerArcToLLVMBase<LowerArcToLLVMPass> {
584 void runOnOperation()
override;
588 void LowerArcToLLVMPass::runOnOperation() {
600 LLVMConversionTarget target(getContext());
601 target.addLegalOp<mlir::ModuleOp>();
602 target.addLegalOp<scf::YieldOp>();
605 LLVMTypeConverter converter(&getContext());
606 converter.addConversion([&](seq::ClockType type) {
609 converter.addConversion([&](StorageType type) {
612 converter.addConversion([&](MemoryType type) {
615 converter.addConversion([&](StateType type) {
618 converter.addConversion([&](SimModelInstanceType type) {
623 RewritePatternSet
patterns(&getContext());
626 populateSCFToControlFlowConversionPatterns(
patterns);
627 populateFuncToLLVMConversionPatterns(converter,
patterns);
628 cf::populateControlFlowToLLVMConversionPatterns(converter,
patterns);
629 arith::populateArithToLLVMConversionPatterns(converter,
patterns);
630 index::populateIndexToLLVMConversionPatterns(converter,
patterns);
631 populateAnyFunctionOpInterfaceTypeConversionPattern(
patterns, converter);
634 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
636 constAggregateGlobalsMap);
643 AllocMemoryOpLowering,
644 AllocStateLikeOpLowering<arc::AllocStateOp>,
645 AllocStateLikeOpLowering<arc::RootInputOp>,
646 AllocStateLikeOpLowering<arc::RootOutputOp>,
647 AllocStorageOpLowering,
649 MemoryReadOpLowering,
650 MemoryWriteOpLowering,
652 ReplaceOpWithInputPattern<seq::ToClockOp>,
653 ReplaceOpWithInputPattern<seq::FromClockOp>,
654 SeqConstClockLowering,
655 SimEmitValueOpLowering,
657 StateWriteOpLowering,
658 StorageGetOpLowering,
660 >(converter, &getContext());
663 SmallVector<ModelInfo> models;
669 llvm::DenseMap<StringRef, ModelInfoMap> modelMap(models.size());
671 llvm::DenseMap<StringRef, StateInfo> states(modelInfo.states.size());
672 for (
StateInfo &stateInfo : modelInfo.states)
673 states.insert({stateInfo.
name, stateInfo});
676 ModelInfoMap{modelInfo.numStateBytes, std::move(states),
677 modelInfo.initialFnSym, modelInfo.finalFnSym}});
680 patterns.add<SimInstantiateOpLowering, SimSetInputOpLowering,
681 SimGetPortOpLowering, SimStepOpLowering>(
682 converter, &getContext(), modelMap);
685 if (failed(applyFullConversion(getOperation(), target, std::move(
patterns))))
690 return std::make_unique<LowerArcToLLVMPass>();
assert(baseType &&"element must be base type")
static llvm::Twine evalSymbolFromModelName(StringRef modelName)
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
mlir::LogicalResult collectModels(mlir::ModuleOp module, llvm::SmallVector< ModelInfo > &models)
Collects information about all Arc models in the provided module, and adds it to models.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap)
Get the HW to LLVM conversion patterns.
void populateCombToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to LLVM conversion patterns.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
std::unique_ptr< OperationPass< ModuleOp > > createLowerArcToLLVMPass()
Gathers information about a given Arc model.
Gathers information about a given Arc state.