17 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
18 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
19 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
22 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
23 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"
25 #include "mlir/Dialect/Index/IR/IndexOps.h"
26 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
27 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
28 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
29 #include "mlir/Dialect/SCF/IR/SCF.h"
30 #include "mlir/IR/BuiltinDialect.h"
31 #include "mlir/Pass/Pass.h"
32 #include "mlir/Transforms/DialectConversion.h"
33 #include "llvm/Support/Debug.h"
35 #define DEBUG_TYPE "lower-arc-to-llvm"
38 #define GEN_PASS_DEF_LOWERARCTOLLVM
39 #include "circt/Conversion/Passes.h.inc"
43 using namespace circt;
// Derives the symbol name of a model's eval function by appending "_eval" to
// the model name (the enclosing signature is not visible in this listing).
// NOTE(review): if this returns an llvm::Twine by value, callers must consume
// it immediately (e.g. via .str()); LLVM documents Twine as a transient type
// that should not be stored.
52 return modelName +
"_eval";
// Conversion pattern for arc::ModelOp: terminates the model body with a
// func::ReturnOp and moves the model's region into a newly created
// func::FuncOp whose inputs are the body's argument types and which has no
// results.
58 using OpConversionPattern::OpConversionPattern;
60 matchAndRewrite(arc::ModelOp op, OpAdaptor adaptor,
61 ConversionPatternRewriter &rewriter)
const final {
// The guard restores the rewriter's insertion point when it goes out of
// scope.
63 IRRewriter::InsertionGuard guard(rewriter);
64 rewriter.setInsertionPointToEnd(&op.getBodyBlock());
65 rewriter.create<func::ReturnOp>(op.getLoc());
// Function type: the model body's block argument types, no results.
70 rewriter.getFunctionType(op.getBody().getArgumentTypes(), {});
72 rewriter.create<mlir::func::FuncOp>(op.getLoc(), funcName, funcType);
// Move the model body (now ending in the return op) into the new function.
73 rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
// Lowers arc::AllocStorageOp to an LLVM GEP with i8 element type indexed by
// the op's constant offset (i.e. a byte offset); the result type comes from
// the type converter.
79 struct AllocStorageOpLowering
81 using OpConversionPattern::OpConversionPattern;
83 matchAndRewrite(arc::AllocStorageOp op, OpAdaptor adaptor,
84 ConversionPatternRewriter &rewriter)
const final {
85 auto type = typeConverter->convertType(op.getType());
// Ops without a constant offset take a separate path; the statement guarded
// by this check is not visible in this listing.
86 if (!op.getOffset().has_value())
88 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, type, rewriter.getI8Type(),
90 LLVM::GEPArg(*op.getOffset()));
// Shared lowering for ops that are turned into a pointer into a storage
// buffer at their constant "offset" attribute (instantiated for
// arc::AllocStateOp, arc::RootInputOp and arc::RootOutputOp in
// runOnOperation). The op is replaced by an i8-element GEP from the converted
// storage pointer, so the offset is a byte offset.
95 template <
class ConcreteOp>
99 using OpAdaptor =
typename ConcreteOp::Adaptor;
102 matchAndRewrite(ConcreteOp op, OpAdaptor adaptor,
103 ConversionPatternRewriter &rewriter)
const final {
105 auto offsetAttr = op->template getAttrOfType<IntegerAttr>(
"offset");
// NOTE(review): the lines between the attribute lookup and the GEP are not
// visible here; presumably they fail the match when "offset" is missing,
// since offsetAttr is dereferenced unconditionally below — confirm.
108 Value ptr = rewriter.create<LLVM::GEPOp>(
109 op->getLoc(), adaptor.getStorage().getType(), rewriter.getI8Type(),
110 adaptor.getStorage(),
111 LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
112 rewriter.replaceOp(op, ptr);
// Lowers arc::StateReadOp to an LLVM load of the converted result type from
// the converted state pointer.
118 using OpConversionPattern::OpConversionPattern;
120 matchAndRewrite(arc::StateReadOp op, OpAdaptor adaptor,
121 ConversionPatternRewriter &rewriter)
const final {
122 auto type = typeConverter->convertType(op.getType());
123 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, type, adaptor.getState());
// Lowers arc::StateWriteOp to an LLVM store of the written value. When the op
// carries a condition, the store is wrapped in an scf.if on that condition
// (then-region only, terminated by scf.yield); otherwise it is a plain store.
129 using OpConversionPattern::OpConversionPattern;
131 matchAndRewrite(arc::StateWriteOp op, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter)
const final {
133 if (adaptor.getCondition()) {
134 rewriter.replaceOpWithNewOp<scf::IfOp>(
135 op, adaptor.getCondition(), [&](
auto &builder,
auto loc) {
136 builder.template create<LLVM::StoreOp>(loc, adaptor.getValue(),
138 builder.template create<scf::YieldOp>(loc);
// Unconditional write: replace directly with a store.
141 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
// Lowers arc::AllocMemoryOp to an i8-element GEP from the converted storage
// pointer at the op's "offset" attribute (a byte offset), mirroring
// AllocStateLikeOpLowering.
149 using OpConversionPattern::OpConversionPattern;
151 matchAndRewrite(arc::AllocMemoryOp op, OpAdaptor adaptor,
152 ConversionPatternRewriter &rewriter)
const final {
153 auto offsetAttr = op->getAttrOfType<IntegerAttr>(
"offset");
// NOTE(review): any null check for offsetAttr sits on lines not visible in
// this listing; offsetAttr is dereferenced unconditionally below — confirm.
156 Value ptr = rewriter.create<LLVM::GEPOp>(
157 op.getLoc(), adaptor.getStorage().getType(), rewriter.getI8Type(),
158 adaptor.getStorage(),
159 LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
161 rewriter.replaceOp(op, ptr);
// Lowers arc::StorageGetOp to an i8-element GEP from the converted storage
// pointer. Unlike the alloc lowerings, the offset is materialized as an
// explicit i32 LLVM constant built from the op's offset attribute.
// NOTE(review): GEP indices are interpreted as signed, so an i32 index caps
// reachable offsets at 2^31-1 bytes; confirm storage sizes stay below that.
167 using OpConversionPattern::OpConversionPattern;
169 matchAndRewrite(arc::StorageGetOp op, OpAdaptor adaptor,
170 ConversionPatternRewriter &rewriter)
const final {
171 Value offset = rewriter.create<LLVM::ConstantOp>(
172 op.getLoc(), rewriter.getI32Type(), op.getOffsetAttr());
173 Value ptr = rewriter.create<LLVM::GEPOp>(
174 op.getLoc(), adaptor.getStorage().getType(), rewriter.getI8Type(),
175 adaptor.getStorage(), offset);
176 rewriter.replaceOp(op, ptr);
// Result of prepareMemoryAccess: the word pointer (ptr) and the i1
// bounds-check flag (withinBounds). The member declarations are not visible
// in this listing; the names are taken from their uses.
181 struct MemoryAccess {
// Computes the address of a memory word and whether that address is in range.
// The address is zero-extended by one bit and compared (unsigned less-than)
// against the memory's word count; presumably the extra bit keeps a word
// count equal to 2^width representable. The pointer is a GEP whose element
// type is an integer of stride*8 bits, indexed by the widened address.
186 static MemoryAccess prepareMemoryAccess(Location loc, Value memory,
187 Value address, MemoryType type,
188 ConversionPatternRewriter &rewriter) {
189 auto zextAddrType = rewriter.getIntegerType(
190 cast<IntegerType>(address.getType()).getWidth() + 1);
191 Value
addr = rewriter.create<LLVM::ZExtOp>(loc, zextAddrType, address);
// NOTE(review): the word-count constant is built from an i32 attribute while
// its result type is zextAddrType; confirm the LLVM constant verifier and
// translation accept the width mismatch, or use
// rewriter.getIntegerAttr(zextAddrType, ...).
192 Value addrLimit = rewriter.create<LLVM::ConstantOp>(
193 loc, zextAddrType, rewriter.getI32IntegerAttr(type.getNumWords()));
194 Value withinBounds = rewriter.create<LLVM::ICmpOp>(
195 loc, LLVM::ICmpPredicate::ult,
addr, addrLimit);
196 Value ptr = rewriter.create<LLVM::GEPOp>(
198 rewriter.getIntegerType(type.getStride() * 8), memory, ValueRange{
addr});
199 return {ptr, withinBounds};
// Lowers arc::MemoryReadOp to a bounds-checked load: an scf.if on the
// in-bounds flag loads a word of the memory's word type in the then-region,
// and the else-region yields a zero constant of the converted result type.
// NOTE(review): the two regions yield memoryType.getWordType() and the
// converted result type respectively; these are presumably identical — confirm.
203 using OpConversionPattern::OpConversionPattern;
205 matchAndRewrite(arc::MemoryReadOp op, OpAdaptor adaptor,
206 ConversionPatternRewriter &rewriter)
const final {
207 auto type = typeConverter->convertType(op.getType());
208 auto memoryType = cast<MemoryType>(op.getMemory().getType());
210 prepareMemoryAccess(op.getLoc(), adaptor.getMemory(),
211 adaptor.getAddress(), memoryType, rewriter);
215 rewriter.replaceOpWithNewOp<scf::IfOp>(
216 op, access.withinBounds,
217 [&](
auto &builder,
auto loc) {
218 Value loadOp = builder.template create<LLVM::LoadOp>(
219 loc, memoryType.getWordType(), access.ptr);
220 builder.template create<scf::YieldOp>(loc, loadOp);
222 [&](
auto &builder,
auto loc) {
// NOTE(review): the zero uses an i64 attribute regardless of the result
// width; confirm the LLVM constant verifier accepts an attribute whose width
// differs from the result type.
223 Value zeroValue = builder.template create<LLVM::ConstantOp>(
224 loc, type, builder.getI64IntegerAttr(0));
225 builder.template create<scf::YieldOp>(loc, zeroValue);
// Lowers arc::MemoryWriteOp to a guarded store: the in-bounds flag, ANDed
// with the op's enable when one is present, gates an scf.if whose
// then-region stores the data word and yields.
232 using OpConversionPattern::OpConversionPattern;
234 matchAndRewrite(arc::MemoryWriteOp op, OpAdaptor adaptor,
235 ConversionPatternRewriter &rewriter)
const final {
236 auto access = prepareMemoryAccess(
237 op.getLoc(), adaptor.getMemory(), adaptor.getAddress(),
238 cast<MemoryType>(op.getMemory().getType()), rewriter);
239 auto enable = access.withinBounds;
240 if (adaptor.getEnable())
241 enable = rewriter.create<LLVM::AndOp>(op.getLoc(), adaptor.getEnable(),
245 rewriter.replaceOpWithNewOp<scf::IfOp>(
246 op, enable, [&](
auto &builder,
auto loc) {
247 builder.template create<LLVM::StoreOp>(loc, adaptor.getData(),
249 builder.template create<scf::YieldOp>(loc);
// Lowers seq::ClockGateOp to a comb::AndOp of the converted clock input and
// the enable; the trailing `true` is presumably the two-state flag — confirm.
257 using OpConversionPattern::OpConversionPattern;
259 matchAndRewrite(seq::ClockGateOp op, OpAdaptor adaptor,
260 ConversionPatternRewriter &rewriter)
const final {
261 rewriter.replaceOpWithNewOp<
comb::AndOp>(op, adaptor.getInput(),
262 adaptor.getEnable(),
true);
// Lowers arc::ZeroCountOp to an LLVM count-leading-zeros op for the leading
// predicate and to count-trailing-zeros otherwise; the result type equals the
// input type.
268 using OpConversionPattern::OpConversionPattern;
270 matchAndRewrite(arc::ZeroCountOp op, OpAdaptor adaptor,
271 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): is_zero_poison is true, so an all-zero input yields poison
// in LLVM; confirm this matches arc::ZeroCountOp's semantics for zero.
273 IntegerAttr isZeroPoison = rewriter.getBoolAttr(
true);
275 if (op.getPredicate() == arc::ZeroCountPredicate::leading) {
276 rewriter.replaceOpWithNewOp<LLVM::CountLeadingZerosOp>(
277 op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
281 rewriter.replaceOpWithNewOp<LLVM::CountTrailingZerosOp>(
282 op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
// Lowers seq::ConstClockOp to an i1 LLVM constant holding the clock's value
// (converted to int64_t for the constant builder).
288 using OpConversionPattern::OpConversionPattern;
290 matchAndRewrite(seq::ConstClockOp op, OpAdaptor adaptor,
291 ConversionPatternRewriter &rewriter)
const override {
292 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
293 op, rewriter.getI1Type(),
static_cast<int64_t
>(op.getValue()));
// Generic pattern that replaces an op with its converted input operand. In
// runOnOperation it is instantiated for seq::ToClockOp and seq::FromClockOp.
298 template <
typename OpTy>
301 using OpAdaptor =
typename OpTy::Adaptor;
303 matchAndRewrite(OpTy op, OpAdaptor adaptor,
304 ConversionPatternRewriter &rewriter)
const override {
305 rewriter.replaceOp(op, adaptor.getInput());
// Per-model lookup data used by the simulation-op lowerings: the total state
// size in bytes and the model's state entries keyed by name.
// NOTE(review): the StringRef keys do not own their characters; in
// runOnOperation they presumably point at the names stored in the collected
// ModelInfo vector, which must therefore outlive this map.
318 struct ModelInfoMap {
319 size_t numStateBytes;
320 llvm::DenseMap<StringRef, StateInfo> states;
// Base class for conversion patterns that need per-model layout data. It holds
// a reference to a model map owned by the caller (so the map must outlive the
// pattern) and offers a helper that computes a pointer to a port's state.
323 template <
typename OpTy>
325 ModelAwarePattern(
const TypeConverter &typeConverter, MLIRContext *context,
326 llvm::DenseMap<StringRef, ModelInfoMap> &modelInfo)
328 modelInfo(modelInfo) {}
// Returns a GEP from `state` indexed by the port's recorded offset.
331 Value createPtrToPortState(ConversionPatternRewriter &rewriter, Location loc,
332 Value state,
const StateInfo &port)
const {
333 MLIRContext *ctx = rewriter.getContext();
336 LLVM::GEPArg(port.
offset));
// Non-owning: the map is owned by the caller.
339 llvm::DenseMap<StringRef, ModelInfoMap> &modelInfo;
// Lowers arc::SimInstantiateOp. It mallocs numStateBytes for the model's
// state and zero-fills it with llvm.memset. It then inlines the op's body
// block before the op, frees the state, and erases the op.
344 struct SimInstantiateOpLowering
345 :
public ModelAwarePattern<arc::SimInstantiateOp> {
346 using ModelAwarePattern::ModelAwarePattern;
349 matchAndRewrite(arc::SimInstantiateOp op, OpAdaptor adaptor,
350 ConversionPatternRewriter &rewriter)
const final {
// The model is identified by the instance type of the body's first argument.
351 auto modelIt = modelInfo.find(
352 cast<SimModelInstanceType>(op.getBody().getArgument(0).getType())
355 ModelInfoMap &model = modelIt->second;
357 ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
361 ConversionPatternRewriter::InsertionGuard guard(rewriter);
// malloc's size parameter uses the converted index type.
365 Type convertedIndex = typeConverter->convertType(rewriter.getIndexType());
367 LLVM::LLVMFuncOp mallocFunc =
368 LLVM::lookupOrCreateMallocFn(moduleOp, convertedIndex);
369 LLVM::LLVMFuncOp freeFunc = LLVM::lookupOrCreateFreeFn(moduleOp);
371 Location loc = op.getLoc();
372 Value numStateBytes = rewriter.create<LLVM::ConstantOp>(
373 loc, convertedIndex, model.numStateBytes);
376 .create<LLVM::CallOp>(loc, mallocFunc, ValueRange{numStateBytes})
// Zero-initialize the freshly allocated state (non-volatile memset).
379 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI8Type(), 0);
380 rewriter.create<LLVM::MemsetOp>(loc, allocated, zero, numStateBytes,
false);
381 rewriter.inlineBlockBefore(&adaptor.getBody().getBlocks().front(), op,
383 rewriter.create<LLVM::CallOp>(loc, freeFunc, ValueRange{allocated});
384 rewriter.eraseOp(op);
// Lowers arc::SimSetInputOp to a store of the value into the input's state
// within the model instance. If the input has no state entry, the op is
// erased.
390 struct SimSetInputOpLowering :
public ModelAwarePattern<arc::SimSetInputOp> {
391 using ModelAwarePattern::ModelAwarePattern;
394 matchAndRewrite(arc::SimSetInputOp op, OpAdaptor adaptor,
395 ConversionPatternRewriter &rewriter)
const final {
397 modelInfo.find(cast<SimModelInstanceType>(op.getInstance().getType())
400 ModelInfoMap &model = modelIt->second;
402 auto portIt = model.states.find(op.getInput());
403 if (portIt == model.states.end()) {
406 rewriter.eraseOp(op);
411 Value statePtr = createPtrToPortState(rewriter, op.getLoc(),
412 adaptor.getInstance(), port);
413 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
// Lowers arc::SimGetPortOp to a load from the port's state within the model
// instance. If the port has no state entry, the op is replaced by a zero
// constant of the converted result type.
420 struct SimGetPortOpLowering :
public ModelAwarePattern<arc::SimGetPortOp> {
421 using ModelAwarePattern::ModelAwarePattern;
424 matchAndRewrite(arc::SimGetPortOp op, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter)
const final {
427 modelInfo.find(cast<SimModelInstanceType>(op.getInstance().getType())
430 ModelInfoMap &model = modelIt->second;
432 auto portIt = model.states.find(op.getPort());
433 if (portIt == model.states.end()) {
436 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
437 op, typeConverter->convertType(op.getValue().getType()), 0);
442 Value statePtr = createPtrToPortState(rewriter, op.getLoc(),
443 adaptor.getInstance(), port);
// NOTE(review): this load uses the unconverted result type, while the
// fallback above converts it. The two are equivalent for integer ports but
// inconsistent; consider using the converted type in both places.
444 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, op.getValue().getType(),
// Lowers arc::SimStepOp to a call, with no results, of the model's eval
// function. The function symbol is presumably derived from the instance
// type's model name via evalSymbolFromModelName. The instance state is the
// only argument.
451 struct SimStepOpLowering :
public ModelAwarePattern<arc::SimStepOp> {
452 using ModelAwarePattern::ModelAwarePattern;
455 matchAndRewrite(arc::SimStepOp op, OpAdaptor adaptor,
456 ConversionPatternRewriter &rewriter)
const final {
457 StringRef modelName = cast<SimModelInstanceType>(op.getInstance().getType())
461 StringAttr evalFunc =
463 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, std::nullopt, evalFunc,
464 adaptor.getInstance());
// Lowers arc::SimEmitValueOp to a printf call with the format
// "<name> = %zx\n", with "(truncated) " inserted (presumably only when the
// value was truncated). The value is first truncated or zero-extended to the
// bit width the data layout gives the index type (used as size_t). The format
// string lives in an internal, constant, null-terminated global. It is created
// at the start of the module once per (trunc/full, value name) pair and reused
// afterwards.
473 struct SimEmitValueOpLowering
475 using OpConversionPattern::OpConversionPattern;
478 matchAndRewrite(arc::SimEmitValueOp op, OpAdaptor adaptor,
479 ConversionPatternRewriter &rewriter)
const final {
480 auto valueType = dyn_cast<IntegerType>(adaptor.getValue().getType());
484 Location loc = op.getLoc();
486 ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
493 Value toPrint = adaptor.getValue();
// %zx expects a size_t; take its width from the data layout's index type.
494 DataLayout layout = DataLayout::closest(op);
495 llvm::TypeSize sizeOfSizeT =
496 layout.getTypeSizeInBits(rewriter.getIndexType());
497 assert(!sizeOfSizeT.isScalable() &&
498 sizeOfSizeT.getFixedValue() <= std::numeric_limits<unsigned>::max());
499 bool truncated =
false;
500 if (valueType.getWidth() > sizeOfSizeT) {
501 toPrint = rewriter.create<LLVM::TruncOp>(
505 }
else if (valueType.getWidth() < sizeOfSizeT)
506 toPrint = rewriter.create<LLVM::ZExtOp>(
511 auto printfFunc = LLVM::lookupOrCreateFn(
// Global name: "_arc_sim_emit_" + ("trunc_" | "full_") + value name.
516 SmallString<16> formatStrName{
"_arc_sim_emit_"};
517 formatStrName.append(truncated ?
"trunc_" :
"full_");
518 formatStrName.append(adaptor.getValueName());
519 LLVM::GlobalOp formatStrGlobal;
// Reuse an existing format-string global; otherwise create it below.
520 if (!(formatStrGlobal =
521 moduleOp.lookupSymbol<LLVM::GlobalOp>(formatStrName))) {
522 ConversionPatternRewriter::InsertionGuard insertGuard(rewriter);
524 SmallString<16> formatStr = adaptor.getValueName();
525 formatStr.append(
" = ");
527 formatStr.append(
"(truncated) ");
528 formatStr.append(
"%zx\n");
// printf needs a C string, so append an explicit NUL terminator.
529 SmallVector<char> formatStrVec{formatStr.begin(), formatStr.end()};
530 formatStrVec.push_back(0);
532 rewriter.setInsertionPointToStart(moduleOp.getBody());
535 formatStrGlobal = rewriter.create<LLVM::GlobalOp>(
536 loc, globalType,
true, LLVM::Linkage::Internal,
537 formatStrName, rewriter.getStringAttr(formatStrVec),
541 Value formatStrGlobalPtr =
542 rewriter.create<LLVM::AddressOfOp>(loc, formatStrGlobal);
543 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
544 op, printfFunc, ValueRange{formatStrGlobalPtr, toPrint});
// Pass that lowers Arc models and simulation ops to the LLVM dialect; the
// conversion is configured and run in runOnOperation.
557 struct LowerArcToLLVMPass
558 :
public circt::impl::LowerArcToLLVMBase<LowerArcToLLVMPass> {
559 void runOnOperation()
override;
// Configures an LLVM conversion target and a type converter extended with the
// Seq/Arc types. It then assembles the upstream SCF, Func, CF and Arith
// lowerings plus the Arc patterns, and collects per-model layout information
// for the simulation-op patterns. Finally it applies a full conversion and
// fails the pass if that does not succeed.
563 void LowerArcToLLVMPass::runOnOperation() {
// The module itself and scf.yield are explicitly legal.
575 LLVMConversionTarget target(getContext());
576 target.addLegalOp<mlir::ModuleOp>();
577 target.addLegalOp<scf::YieldOp>();
// Type conversions for clock, storage, memory, state and model-instance types.
580 LLVMTypeConverter converter(&getContext());
581 converter.addConversion([&](seq::ClockType type) {
584 converter.addConversion([&](StorageType type) {
587 converter.addConversion([&](MemoryType type) {
590 converter.addConversion([&](StateType type) {
593 converter.addConversion([&](SimModelInstanceType type) {
598 RewritePatternSet
patterns(&getContext());
// Upstream dialect lowerings.
601 populateSCFToControlFlowConversionPatterns(
patterns);
602 populateFuncToLLVMConversionPatterns(converter,
patterns);
603 cf::populateControlFlowToLLVMConversionPatterns(converter,
patterns);
604 arith::populateArithToLLVMConversionPatterns(converter,
patterns);
605 populateAnyFunctionOpInterfaceTypeConversionPattern(
patterns, converter);
608 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
610 constAggregateGlobalsMap);
// Arc-specific patterns that do not need model layout information.
617 AllocMemoryOpLowering,
618 AllocStateLikeOpLowering<arc::AllocStateOp>,
619 AllocStateLikeOpLowering<arc::RootInputOp>,
620 AllocStateLikeOpLowering<arc::RootOutputOp>,
621 AllocStorageOpLowering,
623 MemoryReadOpLowering,
624 MemoryWriteOpLowering,
626 ReplaceOpWithInputPattern<seq::ToClockOp>,
627 ReplaceOpWithInputPattern<seq::FromClockOp>,
628 SeqConstClockLowering,
629 SimEmitValueOpLowering,
631 StateWriteOpLowering,
632 StorageGetOpLowering,
634 >(converter, &getContext());
// Build the name-keyed model map consumed by the ModelAwarePattern subclasses.
// NOTE(review): the map's StringRef keys refer to names inside `models`, so
// `models` must stay alive until the conversion below has finished.
637 SmallVector<ModelInfo> models;
643 llvm::DenseMap<StringRef, ModelInfoMap> modelMap(models.size());
645 llvm::DenseMap<StringRef, StateInfo> states(modelInfo.states.size());
646 for (
StateInfo &stateInfo : modelInfo.states)
647 states.insert({stateInfo.
name, stateInfo});
648 modelMap.insert({modelInfo.name,
649 ModelInfoMap{modelInfo.numStateBytes, std::move(states)}});
652 patterns.add<SimInstantiateOpLowering, SimSetInputOpLowering,
653 SimGetPortOpLowering, SimStepOpLowering>(
654 converter, &getContext(), modelMap);
657 if (failed(applyFullConversion(getOperation(), target, std::move(
patterns))))
// Body of createLowerArcToLLVMPass: returns a new, owning instance of the pass.
662 return std::make_unique<LowerArcToLLVMPass>();
assert(baseType &&"element must be base type")
static llvm::Twine evalSymbolFromModelName(StringRef modelName)
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
mlir::LogicalResult collectModels(mlir::ModuleOp module, llvm::SmallVector< ModelInfo > &models)
Collects information about all Arc models in the provided module, and adds it to models.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap)
Get the HW to LLVM conversion patterns.
void populateCombToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to LLVM conversion patterns.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
std::unique_ptr< OperationPass< ModuleOp > > createLowerArcToLLVMPass()
Gathers information about a given Arc model.
Gathers information about a given Arc state.