17#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18#include "mlir/Conversion/LLVMCommon/Pattern.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/IR/Iterators.h"
23#include "mlir/Interfaces/DataLayoutInterfaces.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Transforms/DialectConversion.h"
26#include "llvm/ADT/TypeSwitch.h"
29#define GEN_PASS_DEF_CONVERTHWTOLLVM
30#include "circt/Conversion/Passes.h.inc"
44 return TypeSwitch<Type, uint32_t>(type)
46 [&](hw::ArrayType ty) {
return ty.getNumElements() - index - 1; })
47 .Case<hw::StructType>([&](hw::StructType ty) {
48 return ty.getElements().size() - index - 1;
54 StringRef fieldName) {
55 auto fieldIter = type.getElements();
58 for (
const auto *iter = fieldIter.begin(); iter != fieldIter.end(); ++iter) {
59 if (iter->name == fieldName) {
67 llvm_unreachable(
"Field name attribute of hw::StructExtractOp invalid");
78static Value
zextByOne(Location loc, ConversionPatternRewriter &rewriter,
80 auto valueTy = value.getType();
81 auto zextTy = IntegerType::get(valueTy.getContext(),
82 valueTy.getIntOrFloatBitWidth() + 1);
83 return LLVM::ZExtOp::create(rewriter, loc, zextTy, value);
92 auto oneC = LLVM::ConstantOp::create(
93 builder, loc, IntegerType::get(builder.getContext(), 32),
94 builder.getI32IntegerAttr(1));
96 Block *block = builder.getInsertionBlock();
97 assert(block &&
"expected an insertion block when spilling a value");
100 static_cast<unsigned>(DataLayout::closest(block->getParentOp())
101 .getTypePreferredAlignment(spillVal.getType()));
102 alignment = std::max(4u, alignment);
103 Value ptr = LLVM::AllocaOp::create(
104 builder, loc, LLVM::LLVMPointerType::get(builder.getContext()),
105 spillVal.getType(), oneC, alignment);
106 LLVM::StoreOp::create(builder, loc, spillVal, ptr);
111 LLVMTypeConverter &converter,
112 Operation *containerOp) {
113 OpBuilder::InsertionGuard g(builder);
114 containerOp->walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
116 if (isa_and_nonnull<hw::HWDialect>(op->getDialect()))
118 auto hasSpillingUser = [](Value arrVal) ->
bool {
119 for (
auto user : arrVal.getUsers())
120 if (isa<hw::ArrayGetOp, hw::ArraySliceOp>(user))
125 for (
auto ®ion : op->getRegions()) {
126 for (
auto &block : region.getBlocks()) {
127 builder.setInsertionPointToStart(&block);
128 for (
auto &arg : block.getArguments()) {
129 if (isa<hw::ArrayType>(arg.getType()) && hasSpillingUser(arg))
135 for (
auto result : op->getResults()) {
136 if (isa<hw::ArrayType>(result.getType()) && hasSpillingUser(result)) {
137 builder.setInsertionPointAfter(op);
145 assert(isa<LLVM::LLVMArrayType>(arrayValue.getType()) &&
146 "Key is not an LLVM array.");
147 assert(isa<LLVM::LLVMPointerType>(bufferPtr.getType()) &&
148 "Value is not a pointer.");
149 spillMap.insert({arrayValue, bufferPtr});
153 assert(isa<LLVM::LLVMArrayType>(arrayValue.getType()) ||
154 isa<hw::ArrayType>(arrayValue.getType()) &&
"Not an array value");
155 while (isa<LLVM::LLVMArrayType, hw::ArrayType>(arrayValue.getType())) {
156 if (isa<LLVM::LLVMArrayType>(arrayValue.getType())) {
157 auto mapVal =
spillMap.lookup(arrayValue);
161 if (
auto castOp = arrayValue.getDefiningOp<UnrealizedConversionCastOp>())
162 arrayValue = castOp.getOperand(0);
173 assert(isa<LLVM::LLVMArrayType>(llvmArray.getType()) &&
174 "Expected an LLVM array.");
177 LLVM::LoadOp::create(builder, loc, llvmArray.getType(), spillBuffer);
178 map(loadOp.getResult(), spillBuffer);
179 return loadOp.getResult();
187 LLVMTypeConverter &converter,
189 assert(isa<hw::ArrayType>(hwArray.getType()) &&
"Expected an HW array");
190 auto targetType = converter.convertType(hwArray.getType());
192 UnrealizedConversionCastOp::create(builder, loc, targetType, hwArray);
194 auto llvmToHWCast = UnrealizedConversionCastOp::create(
195 builder, loc, hwArray.getType(), spilled);
196 hwArray.replaceAllUsesExcept(llvmToHWCast.getResult(0), hwToLLVMCast);
197 return llvmToHWCast.getResult(0);
203template <
typename SourceOp>
204struct HWArrayOpToLLVMPattern :
public ConvertOpToLLVMPattern<SourceOp> {
206 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
207 HWArrayOpToLLVMPattern(LLVMTypeConverter &converter,
208 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt)
209 : ConvertOpToLLVMPattern<SourceOp>(converter),
210 spillCacheOpt(spillCacheOpt) {}
213 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt;
226struct StructExplodeOpConversion
227 :
public ConvertOpToLLVMPattern<hw::StructExplodeOp> {
228 using ConvertOpToLLVMPattern<hw::StructExplodeOp>::ConvertOpToLLVMPattern;
231 matchAndRewrite(hw::StructExplodeOp op, OpAdaptor adaptor,
232 ConversionPatternRewriter &rewriter)
const override {
234 SmallVector<Value> replacements;
237 e = cast<LLVM::LLVMStructType>(adaptor.getInput().getType())
242 replacements.push_back(LLVM::ExtractValueOp::create(
243 rewriter, op->getLoc(), adaptor.getInput(),
245 op.getInput().getType(), i)));
247 rewriter.replaceOp(op, replacements);
257struct StructExtractOpConversion
258 :
public ConvertOpToLLVMPattern<hw::StructExtractOp> {
263 ConversionPatternRewriter &rewriter)
const override {
266 op.getInput().getType(), op.getFieldIndex());
267 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, adaptor.getInput(),
279struct ArrayInjectOpConversion
280 :
public HWArrayOpToLLVMPattern<hw::ArrayInjectOp> {
281 using HWArrayOpToLLVMPattern<hw::ArrayInjectOp>::HWArrayOpToLLVMPattern;
284 matchAndRewrite(hw::ArrayInjectOp op, OpAdaptor adaptor,
285 ConversionPatternRewriter &rewriter)
const override {
286 auto inputType = cast<hw::ArrayType>(op.getInput().getType());
287 auto oldArrTy = adaptor.getInput().getType();
288 auto newArrTy = oldArrTy;
289 const size_t arrElems = inputType.getNumElements();
292 rewriter.replaceOp(op, adaptor.getInput());
297 LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI32Type(),
298 rewriter.getI32IntegerAttr(1));
299 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getIndex());
301 if (arrElems == 1 || !llvm::isPowerOf2_64(arrElems)) {
306 LLVM::ConstantOp::create(rewriter, op->getLoc(), zextIndex.getType(),
307 rewriter.getI32IntegerAttr(arrElems));
309 LLVM::UMinOp::create(rewriter, op->getLoc(), zextIndex, maxIndex);
311 newArrTy = typeConverter->convertType(
312 hw::ArrayType::get(inputType.getElementType(), arrElems + 1));
314 auto allocaAlignment = std::max(
315 4u,
static_cast<unsigned>(DataLayout::closest(op.getOperation())
316 .getTypePreferredAlignment(newArrTy)));
317 Value arrPtr = LLVM::AllocaOp::create(
318 rewriter, op->getLoc(),
319 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, oneC,
322 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getInput(), arrPtr);
324 auto gep = LLVM::GEPOp::create(
325 rewriter, op->getLoc(),
326 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, arrPtr,
327 ArrayRef<LLVM::GEPArg>{0, zextIndex});
329 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getElement(), gep);
331 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, oldArrTy, arrPtr);
333 spillCacheOpt->map(loadOp, arrPtr);
343struct ArrayGetOpConversion :
public HWArrayOpToLLVMPattern<hw::ArrayGetOp> {
344 using HWArrayOpToLLVMPattern<
hw::ArrayGetOp>::HWArrayOpToLLVMPattern;
348 ConversionPatternRewriter &rewriter)
const override {
352 arrPtr = spillCacheOpt->lookup(adaptor.getInput());
356 auto arrTy = typeConverter->convertType(op.getInput().getType());
357 auto elemTy = typeConverter->convertType(op.getResult().getType());
358 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getIndex());
363 auto gep = LLVM::GEPOp::create(
364 rewriter, op->getLoc(),
365 LLVM::LLVMPointerType::get(rewriter.getContext()), arrTy, arrPtr,
366 ArrayRef<LLVM::GEPArg>{0, zextIndex});
367 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elemTy, gep);
378struct ArraySliceOpConversion
379 :
public HWArrayOpToLLVMPattern<hw::ArraySliceOp> {
384 ConversionPatternRewriter &rewriter)
const override {
386 auto dstTy = typeConverter->convertType(op.getDst().getType());
390 arrPtr = spillCacheOpt->lookup(adaptor.getInput());
394 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getLowIndex());
399 auto gep = LLVM::GEPOp::create(
400 rewriter, op->getLoc(),
401 LLVM::LLVMPointerType::get(rewriter.getContext()), dstTy, arrPtr,
402 ArrayRef<LLVM::GEPArg>{0, zextIndex});
404 auto loadOp = rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dstTy, gep);
407 spillCacheOpt->map(loadOp, gep);
422struct StructInjectOpConversion
423 :
public ConvertOpToLLVMPattern<hw::StructInjectOp> {
424 using ConvertOpToLLVMPattern<hw::StructInjectOp>::ConvertOpToLLVMPattern;
427 matchAndRewrite(hw::StructInjectOp op, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter)
const override {
431 op.getInput().getType(), op.getFieldIndex());
433 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
434 op, adaptor.getInput(), adaptor.getNewValue(), fieldIndex);
462static Value allocateUnionBuffer(ConversionPatternRewriter &rewriter,
463 Location loc, Type bufferType,
465 auto *
context = rewriter.getContext();
467 std::max<uint64_t>(1, DataLayout().getTypePreferredAlignment(accessType));
468 Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
469 rewriter.getI32IntegerAttr(1));
470 return LLVM::AllocaOp::create(rewriter, loc,
471 LLVM::LLVMPointerType::get(
context), bufferType,
477struct UnionCreateOpConversion
478 :
public ConvertOpToLLVMPattern<hw::UnionCreateOp> {
479 using ConvertOpToLLVMPattern<hw::UnionCreateOp>::ConvertOpToLLVMPattern;
482 matchAndRewrite(hw::UnionCreateOp op, OpAdaptor adaptor,
483 ConversionPatternRewriter &rewriter)
const override {
484 auto loc = op.getLoc();
485 auto bufferType = typeConverter->convertType(op.getType());
488 Value input = adaptor.getInput();
489 Value ptr = allocateUnionBuffer(rewriter, loc, bufferType, input.getType());
490 LLVM::StoreOp::create(rewriter, loc, input, ptr);
491 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, bufferType, ptr);
498struct UnionExtractOpConversion
499 :
public ConvertOpToLLVMPattern<hw::UnionExtractOp> {
500 using ConvertOpToLLVMPattern<hw::UnionExtractOp>::ConvertOpToLLVMPattern;
503 matchAndRewrite(hw::UnionExtractOp op, OpAdaptor adaptor,
504 ConversionPatternRewriter &rewriter)
const override {
505 auto loc = op.getLoc();
506 auto memberType = typeConverter->convertType(op.getType());
509 Value input = adaptor.getInput();
510 Value ptr = allocateUnionBuffer(rewriter, loc, input.getType(), memberType);
511 LLVM::StoreOp::create(rewriter, loc, input, ptr);
512 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, memberType, ptr);
524struct ArrayConcatOpConversion
525 :
public HWArrayOpToLLVMPattern<hw::ArrayConcatOp> {
530 ConversionPatternRewriter &rewriter)
const override {
532 hw::ArrayType arrTy = cast<hw::ArrayType>(op.getResult().getType());
533 Type resultTy = typeConverter->convertType(arrTy);
534 auto loc = op.getLoc();
536 Value arr = LLVM::UndefOp::create(rewriter, loc, resultTy);
539 size_t j = op.getInputs().size() - 1, k = 0;
541 for (
size_t i = 0, e = arrTy.getNumElements(); i < e; ++i) {
542 Value element = LLVM::ExtractValueOp::create(rewriter, loc,
543 adaptor.getInputs()[j], k);
544 arr = LLVM::InsertValueOp::create(rewriter, loc, arr, element, i);
548 cast<hw::ArrayType>(op.getInputs()[j].getType()).getNumElements()) {
554 rewriter.replaceOp(op, arr);
558 rewriter.setInsertionPointAfter(arr.getDefiningOp());
560 spillCacheOpt->map(arr, ptr);
572struct HWConstantOpConversion :
public ConvertToLLVMPattern {
573 explicit HWConstantOpConversion(MLIRContext *ctx,
574 LLVMTypeConverter &typeConverter)
575 : ConvertToLLVMPattern(
hw::ConstantOp::getOperationName(), ctx,
579 matchAndRewrite(Operation *op, ArrayRef<Value> operand,
580 ConversionPatternRewriter &rewriter)
const override {
582 auto constOp = cast<hw::ConstantOp>(op);
584 auto intType = typeConverter->convertType(constOp.getValueAttr().getType());
586 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, intType,
587 constOp.getValueAttr());
597struct HWDynamicArrayCreateOpConversion
598 :
public ConvertOpToLLVMPattern<hw::ArrayCreateOp> {
603 ConversionPatternRewriter &rewriter)
const override {
604 auto arrayTy = typeConverter->convertType(op->getResult(0).getType());
607 Value arr = LLVM::UndefOp::create(rewriter, op->getLoc(), arrayTy);
608 for (
size_t i = 0, e = op.getInputs().size(); i < e; ++i) {
612 op.getResult().getType(), i)];
613 arr = LLVM::InsertValueOp::create(rewriter, op->getLoc(), arr, input, i);
616 rewriter.replaceOp(op, arr);
626class AggregateConstantOpConversion
627 :
public HWArrayOpToLLVMPattern<hw::AggregateConstantOp> {
628 using HWArrayOpToLLVMPattern<hw::AggregateConstantOp>::HWArrayOpToLLVMPattern;
630 bool containsArrayAndStructAggregatesOnly(Type type)
const;
632 bool isMultiDimArrayOfIntegers(Type type,
633 SmallVectorImpl<int64_t> &dims)
const;
635 void flatten(Type type, Attribute attr,
636 SmallVectorImpl<Attribute> &output)
const;
638 Value constructAggregate(OpBuilder &builder,
639 const TypeConverter &typeConverter, Location loc,
640 Type type, Attribute data)
const;
643 explicit AggregateConstantOpConversion(
644 LLVMTypeConverter &typeConverter,
645 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
646 &constAggregateGlobalsMap,
647 Namespace &globals, std::optional<HWToLLVMArraySpillCache> &spillCacheOpt)
648 : HWArrayOpToLLVMPattern(typeConverter, spillCacheOpt),
649 constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals) {}
652 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
653 ConversionPatternRewriter &rewriter)
const override;
656 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
657 &constAggregateGlobalsMap;
665struct HWStructCreateOpConversion
666 :
public ConvertOpToLLVMPattern<hw::StructCreateOp> {
671 ConversionPatternRewriter &rewriter)
const override {
673 auto resTy = typeConverter->convertType(op.getResult().getType());
675 Value tup = LLVM::UndefOp::create(rewriter, op->getLoc(), resTy);
676 for (
size_t i = 0, e = cast<LLVM::LLVMStructType>(resTy).getBody().size();
680 op.getResult().getType(), i)];
681 tup = LLVM::InsertValueOp::create(rewriter, op->getLoc(), tup, input, i);
684 rewriter.replaceOp(op, tup);
694bool AggregateConstantOpConversion::containsArrayAndStructAggregatesOnly(
696 if (
auto intType = dyn_cast<IntegerType>(type))
699 if (
auto arrTy = dyn_cast<hw::ArrayType>(type))
700 return containsArrayAndStructAggregatesOnly(arrTy.getElementType());
702 if (
auto structTy = dyn_cast<hw::StructType>(type)) {
703 SmallVector<Type> innerTypes;
704 structTy.getInnerTypes(innerTypes);
705 return llvm::all_of(innerTypes, [&](
auto ty) {
706 return containsArrayAndStructAggregatesOnly(ty);
713bool AggregateConstantOpConversion::isMultiDimArrayOfIntegers(
714 Type type, SmallVectorImpl<int64_t> &dims)
const {
715 if (
auto intType = dyn_cast<IntegerType>(type))
718 if (
auto arrTy = dyn_cast<hw::ArrayType>(type)) {
719 dims.push_back(arrTy.getNumElements());
720 return isMultiDimArrayOfIntegers(arrTy.getElementType(), dims);
726void AggregateConstantOpConversion::flatten(
727 Type type, Attribute attr, SmallVectorImpl<Attribute> &output)
const {
728 if (isa<IntegerType>(type)) {
729 assert(isa<IntegerAttr>(attr));
730 output.push_back(attr);
734 auto arrAttr = cast<ArrayAttr>(attr);
735 for (
size_t i = 0, e = arrAttr.size(); i < e; ++i) {
739 flatten(cast<hw::ArrayType>(type).getElementType(), element, output);
743Value AggregateConstantOpConversion::constructAggregate(
744 OpBuilder &builder,
const TypeConverter &typeConverter, Location loc,
745 Type type, Attribute data)
const {
746 Type llvmType = typeConverter.convertType(type);
748 auto getElementType = [](Type type,
size_t index) {
749 if (
auto arrTy = dyn_cast<hw::ArrayType>(type)) {
750 return arrTy.getElementType();
753 assert(isa<hw::StructType>(type));
754 auto structTy = cast<hw::StructType>(type);
755 SmallVector<Type> innerTypes;
756 structTy.getInnerTypes(innerTypes);
757 return innerTypes[index];
760 return TypeSwitch<Type, Value>(type)
761 .Case<IntegerType>([&](
auto ty) {
762 return LLVM::ConstantOp::create(builder, loc, cast<TypedAttr>(data));
764 .Case<hw::ArrayType, hw::StructType>([&](
auto ty) {
765 Value aggVal = LLVM::UndefOp::create(builder, loc, llvmType);
766 auto arrayAttr = cast<ArrayAttr>(data);
767 for (
size_t i = 0, e = arrayAttr.size(); i < e; ++i) {
770 Attribute input = arrayAttr[currIdx];
773 Value element = constructAggregate(builder, typeConverter, loc,
776 LLVM::InsertValueOp::create(builder, loc, aggVal, element, i);
783LogicalResult AggregateConstantOpConversion::matchAndRewrite(
784 hw::AggregateConstantOp op, OpAdaptor adaptor,
785 ConversionPatternRewriter &rewriter)
const {
786 Type aggregateType = op.getResult().getType();
789 if (!containsArrayAndStructAggregatesOnly(aggregateType))
792 auto llvmTy = typeConverter->convertType(op.getResult().getType());
793 auto typeAttrPair = std::make_pair(aggregateType, adaptor.getFields());
795 if (!constAggregateGlobalsMap.count(typeAttrPair) ||
796 !constAggregateGlobalsMap[typeAttrPair]) {
797 auto ipSave = rewriter.saveInsertionPoint();
799 Operation *parent = op->getParentOp();
800 while (!isa<mlir::ModuleOp>(parent->getParentOp())) {
801 parent = parent->getParentOp();
804 rewriter.setInsertionPoint(parent);
807 auto name = globals.newName(
"_aggregate_const_global");
809 SmallVector<int64_t> dims;
810 if (isMultiDimArrayOfIntegers(aggregateType, dims)) {
811 SmallVector<Attribute> ints;
812 flatten(aggregateType, adaptor.getFields(), ints);
814 auto shapedType = RankedTensorType::get(
815 dims, cast<IntegerAttr>(ints.front()).getType());
816 auto denseAttr = DenseElementsAttr::get(shapedType, ints);
818 constAggregateGlobalsMap[typeAttrPair] =
819 LLVM::GlobalOp::create(rewriter, op.getLoc(), llvmTy,
true,
820 LLVM::Linkage::Internal, name, denseAttr);
823 LLVM::GlobalOp::create(rewriter, op.getLoc(), llvmTy,
false,
824 LLVM::Linkage::Internal, name, Attribute());
826 global.getInitializerRegion().push_back(blk);
827 rewriter.setInsertionPointToStart(blk);
830 constructAggregate(rewriter, *typeConverter, op.getLoc(),
831 aggregateType, adaptor.getFields());
832 LLVM::ReturnOp::create(rewriter, op.getLoc(), aggregate);
833 constAggregateGlobalsMap[typeAttrPair] = global;
836 rewriter.restoreInsertionPoint(ipSave);
840 auto addr = LLVM::AddressOfOp::create(rewriter, op->getLoc(),
841 constAggregateGlobalsMap[typeAttrPair]);
842 auto newOp = rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmTy,
addr);
844 if (spillCacheOpt && llvm::isa<hw::ArrayType>(aggregateType))
845 spillCacheOpt->map(newOp.getResult(),
addr);
855 auto elementTy = converter.convertType(type.getElementType());
856 return LLVM::LLVMArrayType::get(elementTy, type.getNumElements());
860 LLVMTypeConverter &converter) {
861 llvm::SmallVector<Type, 8> elements;
862 mlir::SmallVector<mlir::Type> types;
863 type.getInnerTypes(types);
865 for (
int i = 0, e = types.size(); i < e; ++i)
866 elements.push_back(converter.convertType(
869 return LLVM::LLVMStructType::getLiteral(&converter.getContext(), elements);
876 uint64_t maxBytes = 0;
877 for (
auto field : type.getElements()) {
878 auto llvmFieldTy = converter.convertType(field.type);
882 std::max(maxBytes, layout.getTypeSize(llvmFieldTy).getFixedValue());
884 return LLVM::LLVMArrayType::get(IntegerType::get(&converter.getContext(), 8),
893struct HWToLLVMLoweringPass
894 :
public circt::impl::ConvertHWToLLVMBase<HWToLLVMLoweringPass> {
896 using circt::impl::ConvertHWToLLVMBase<
897 HWToLLVMLoweringPass>::ConvertHWToLLVMBase;
899 void runOnOperation()
override;
904 LLVMTypeConverter &converter, RewritePatternSet &
patterns,
906 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
907 &constAggregateGlobalsMap,
908 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt) {
909 MLIRContext *ctx = converter.getDialect()->getContext();
912 patterns.add<HWConstantOpConversion>(ctx, converter);
913 patterns.add<HWDynamicArrayCreateOpConversion, HWStructCreateOpConversion>(
915 patterns.add<AggregateConstantOpConversion>(
916 converter, constAggregateGlobalsMap, globals, spillCacheOpt);
919 patterns.add<StructExplodeOpConversion, StructExtractOpConversion,
920 StructInjectOpConversion>(converter);
923 patterns.add<UnionCreateOpConversion, UnionExtractOpConversion>(converter);
925 patterns.add<ArrayGetOpConversion, ArrayInjectOpConversion,
926 ArraySliceOpConversion, ArrayConcatOpConversion>(converter,
931 converter.addConversion(
933 converter.addConversion(
935 converter.addConversion(
939void HWToLLVMLoweringPass::runOnOperation() {
940 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
941 std::optional<HWToLLVMArraySpillCache> spillCacheOpt = {};
947 RewritePatternSet
patterns(&getContext());
948 auto converter = mlir::LLVMTypeConverter(&getContext());
951 if (spillArraysEarly) {
953 OpBuilder spillBuilder(getOperation());
954 spillCacheOpt->spillNonHWOps(spillBuilder, converter, getOperation());
957 LLVMConversionTarget target(getContext());
958 target.addIllegalDialect<hw::HWDialect>();
960 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
966 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
967 return converter.isSignatureLegal(op.getFunctionType()) &&
968 converter.isLegal(&op.getBody());
970 target.addDynamicallyLegalOp<func::ReturnOp, func::CallOp>(
971 [&](Operation *op) {
return converter.isLegal(op); });
975 constAggregateGlobalsMap, spillCacheOpt);
976 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
978 populateReturnOpTypeConversionPattern(
patterns, converter);
979 populateCallOpTypeConversionPattern(
patterns, converter);
982 ConversionConfig config;
983 config.allowPatternRollback =
false;
984 if (failed(applyPartialConversion(getOperation(), target, std::move(
patterns),
986 return signalPassFailure();
993 SmallVector<UnrealizedConversionCastOp> castOps;
994 getOperation()->walk(
995 [&](UnrealizedConversionCastOp op) { castOps.push_back(op); });
996 reconcileUnrealizedCasts(castOps,
nullptr);
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Type convertStructType(hw::StructType type, LLVMTypeConverter &converter)
static Type convertUnionType(hw::UnionType type, LLVMTypeConverter &converter)
Convert a union to a flat byte buffer large enough to hold the LLVM representation of its widest memb...
static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter, Value value)
Create a zext operation by one bit on the given value.
static Type convertArrayType(hw::ArrayType type, LLVMTypeConverter &converter)
static Value spillValueOnStack(OpBuilder &builder, Location loc, Value spillVal)
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap, std::optional< HWToLLVMArraySpillCache > &spillCacheOpt)
Get the HW to LLVM conversion patterns.
Helper class mapping array values (HW or LLVM Dialect) to pointers to buffers containing the array va...
Value spillHWArrayValue(OpBuilder &builder, Location loc, mlir::LLVMTypeConverter &converter, Value hwArray)
Value lookup(Value arrayValue)
Retrieve a pointer to a buffer containing the given array value (HW or LLVM Dialect).
void spillNonHWOps(mlir::OpBuilder &builder, mlir::LLVMTypeConverter &converter, Operation *containerOp)
Spill HW array values produced by 'foreign' dialects on the stack.
void map(mlir::Value arrayValue, mlir::Value bufferPtr)
Map an LLVM array value to an LLVM pointer.
Value spillLLVMArrayValue(OpBuilder &builder, Location loc, Value llvmArray)
llvm::DenseMap< Value, Value > spillMap
static uint32_t convertToLLVMEndianess(Type type, uint32_t index)
Convert an index into a HW ArrayType or StructType to LLVM Endianess.
static uint32_t llvmIndexOfStructField(hw::StructType type, StringRef fieldName)
Get the index of a specific StructType field in the LLVM lowering of the StructType.