17#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18#include "mlir/Conversion/LLVMCommon/Pattern.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "llvm/ADT/TypeSwitch.h"
25#define GEN_PASS_DEF_CONVERTHWTOLLVM
26#include "circt/Conversion/Passes.h.inc"
// Mirror `index` within the aggregate: for an hw::ArrayType the result is
// getNumElements() - index - 1, for an hw::StructType it is
// getElements().size() - index - 1. The TypeSwitch picks the formula from the
// concrete type of `type`; the result is a uint32_t.
40 return TypeSwitch<Type, uint32_t>(type)
42 [&](hw::ArrayType ty) {
return ty.getNumElements() - index - 1; })
43 .Case<hw::StructType>([&](hw::StructType ty) {
44 return ty.getElements().size() - index - 1;
50 StringRef fieldName) {
// Walk the struct's elements in declaration order and compare each element's
// name against `fieldName`. What is returned on a match is on lines that are
// not visible in this view.
51 auto fieldIter = type.getElements();
54 for (
const auto *iter = fieldIter.begin(); iter != fieldIter.end(); ++iter) {
55 if (iter->name == fieldName) {
// NOTE(review): presumably only reached when no element is named `fieldName`;
// the loop's closing lines are not visible here - confirm.
63 llvm_unreachable(
"Field name attribute of hw::StructExtractOp invalid");
/// Zero-extend `value` by exactly one bit.
///
/// The result type is an IntegerType (in the context of `value`'s type) whose
/// width is `value`'s bit width plus one; the extension itself is emitted as
/// an LLVM::ZExtOp at `loc` through `rewriter`, and that op's result is
/// returned.
74static Value
zextByOne(Location loc, ConversionPatternRewriter &rewriter,
76 auto valueTy = value.getType();
77 auto zextTy = IntegerType::get(valueTy.getContext(),
78 valueTy.getIntOrFloatBitWidth() + 1);
79 return LLVM::ZExtOp::create(rewriter, loc, zextTy, value);
/// Conversion pattern for hw::StructExplodeOp.
///
/// LLVM::ExtractValueOp results taken from the converted input are gathered
/// into `replacements`, and the op is then replaced by that list of values.
/// The loop bound is derived from the converted input's LLVM struct type; the
/// remainder of the loop header is not visible in this view.
90struct StructExplodeOpConversion
91 :
public ConvertOpToLLVMPattern<hw::StructExplodeOp> {
92 using ConvertOpToLLVMPattern<hw::StructExplodeOp>::ConvertOpToLLVMPattern;
95 matchAndRewrite(hw::StructExplodeOp op, OpAdaptor adaptor,
96 ConversionPatternRewriter &rewriter)
const override {
// Values that will stand in for the op's results.
98 SmallVector<Value> replacements;
101 e = cast<LLVM::LLVMStructType>(adaptor.getInput().getType())
106 replacements.push_back(LLVM::ExtractValueOp::create(
107 rewriter, op->getLoc(), adaptor.getInput(),
109 op.getInput().getType(), i)));
111 rewriter.replaceOp(op, replacements);
/// Conversion pattern for hw::StructExtractOp.
///
/// The op is replaced by an LLVM::ExtractValueOp on the converted input. A
/// helper (its name is on a line not visible here) is called with the input's
/// HW type and the op's field index before the replacement is built.
121struct StructExtractOpConversion
122 :
public ConvertOpToLLVMPattern<hw::StructExtractOp> {
127 ConversionPatternRewriter &rewriter)
const override {
130 op.getInput().getType(), op.getFieldIndex());
131 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, adaptor.getInput(),
/// Conversion pattern for hw::ArrayInjectOp that goes through memory.
///
/// Visible steps: the converted input array is stored into a slot created by
/// LLVM::AllocaOp, the new element is stored at an address computed with
/// LLVM::GEPOp (indices {0, widened index}), and the op is replaced by an
/// LLVM::LoadOp of the slot using the original converted array type
/// (`oldArrTy`). When the element count is 1 or not a power of two, the slot's
/// array type is built with one extra element (`arrElems + 1`).
143struct ArrayInjectOpConversion
144 :
public ConvertOpToLLVMPattern<hw::ArrayInjectOp> {
145 using ConvertOpToLLVMPattern<hw::ArrayInjectOp>::ConvertOpToLLVMPattern;
148 matchAndRewrite(hw::ArrayInjectOp op, OpAdaptor adaptor,
149 ConversionPatternRewriter &rewriter)
const override {
// `oldArrTy` is the converted type of the input; `newArrTy` starts out equal
// to it and is only changed in the non-power-of-two branch below.
150 auto inputType = cast<hw::ArrayType>(op.getInput().getType());
151 auto oldArrTy = adaptor.getInput().getType();
152 auto newArrTy = oldArrTy;
153 const size_t arrElems = inputType.getNumElements();
// In one path the op is simply replaced by its (converted) input; the
// condition guarding this replacement is not visible in this view.
156 rewriter.replaceOp(op, adaptor.getInput());
161 LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI32Type(),
162 rewriter.getI32IntegerAttr(1));
// The index is widened by one bit before it is used for addressing.
163 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getIndex());
// Single-element and non-power-of-two arrays: a constant is built from
// arrElems, an unsigned minimum is taken between the widened index and
// `maxIndex`, and the slot is allocated with an array type that has one
// additional element.
// NOTE(review): the constant below is typed as zextIndex's type but its
// attribute comes from getI32IntegerAttr - confirm the two are meant to agree.
166 if (arrElems == 1 || !llvm::isPowerOf2_64(arrElems)) {
171 LLVM::ConstantOp::create(rewriter, op->getLoc(), zextIndex.getType(),
172 rewriter.getI32IntegerAttr(arrElems));
174 LLVM::UMinOp::create(rewriter, op->getLoc(), zextIndex, maxIndex);
176 newArrTy = typeConverter->convertType(
177 hw::ArrayType::get(inputType.getElementType(), arrElems + 1));
178 arrPtr = LLVM::AllocaOp::create(
179 rewriter, op->getLoc(),
180 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, oneC,
183 arrPtr = LLVM::AllocaOp::create(
184 rewriter, op->getLoc(),
185 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, oneC,
189 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getInput(), arrPtr);
// Address of the element to overwrite inside the slot.
191 auto gep = LLVM::GEPOp::create(
192 rewriter, op->getLoc(),
193 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, arrPtr,
194 ArrayRef<LLVM::GEPArg>{0, zextIndex});
// Overwrite the element, then read the array back with the original converted
// type (oldArrTy), not newArrTy.
196 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getElement(), gep);
197 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, oldArrTy, arrPtr);
/// Conversion pattern for hw::ArrayGetOp: address the element with an
/// LLVM::GEPOp and replace the op with an LLVM::LoadOp of the element type.
///
/// If the converted input is itself the result of an LLVM::LoadOp, that load's
/// address is used as the array pointer. The visible code also contains a path
/// that creates a slot with LLVM::AllocaOp and stores the input into it.
207struct ArrayGetOpConversion :
public ConvertOpToLLVMPattern<hw::ArrayGetOp> {
208 using ConvertOpToLLVMPattern<
hw::ArrayGetOp>::ConvertOpToLLVMPattern;
212 ConversionPatternRewriter &rewriter)
const override {
// Reuse the address of an existing load of the array when there is one.
215 if (
auto load = adaptor.getInput().getDefiningOp<LLVM::LoadOp>()) {
218 arrPtr = load.getAddr();
220 auto oneC = LLVM::ConstantOp::create(
221 rewriter, op->getLoc(), IntegerType::get(rewriter.getContext(), 32),
222 rewriter.getI32IntegerAttr(1));
223 arrPtr = LLVM::AllocaOp::create(
224 rewriter, op->getLoc(),
225 LLVM::LLVMPointerType::get(rewriter.getContext()),
226 adaptor.getInput().getType(), oneC,
228 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getInput(), arrPtr);
// Converted array type (for the GEP) and converted element type (for the
// final load).
231 auto arrTy = typeConverter->convertType(op.getInput().getType());
232 auto elemTy = typeConverter->convertType(op.getResult().getType());
// The index is widened by one bit before it is used as a GEP index.
233 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getIndex());
238 auto gep = LLVM::GEPOp::create(
239 rewriter, op->getLoc(),
240 LLVM::LLVMPointerType::get(rewriter.getContext()), arrTy, arrPtr,
241 ArrayRef<LLVM::GEPArg>{0, zextIndex});
// The op becomes a load of the addressed element.
242 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elemTy, gep);
/// Conversion pattern for hw::ArraySliceOp that goes through memory.
///
/// The converted input array is stored into a slot created by LLVM::AllocaOp;
/// an LLVM::GEPOp with indices {0, widened low index} computes an address
/// inside it, and the op is replaced by an LLVM::LoadOp of the converted
/// destination type from that address.
253struct ArraySliceOpConversion
254 :
public ConvertOpToLLVMPattern<hw::ArraySliceOp> {
259 ConversionPatternRewriter &rewriter)
const override {
// Converted type of the slice result; used both as the GEP's element type and
// as the type of the final load.
261 auto dstTy = typeConverter->convertType(op.getDst().getType());
264 LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI32Type(),
265 rewriter.getI32IntegerAttr(1));
267 auto arrPtr = LLVM::AllocaOp::create(
268 rewriter, op->getLoc(),
269 LLVM::LLVMPointerType::get(rewriter.getContext()),
270 adaptor.getInput().getType(), oneC,
273 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getInput(), arrPtr);
// The low index is widened by one bit before it is used as a GEP index.
275 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getLowIndex());
280 auto gep = LLVM::GEPOp::create(
281 rewriter, op->getLoc(),
282 LLVM::LLVMPointerType::get(rewriter.getContext()), dstTy, arrPtr,
283 ArrayRef<LLVM::GEPArg>{0, zextIndex});
// The op becomes a load of the whole slice from the computed address.
285 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dstTy, gep);
/// Conversion pattern for hw::StructInjectOp.
///
/// The op is replaced by an LLVM::InsertValueOp that inserts the converted new
/// value into the converted input at `fieldIndex`. `fieldIndex` is computed by
/// a helper call (its name is on a line not visible here) from the input's HW
/// type and the op's field index.
300struct StructInjectOpConversion
301 :
public ConvertOpToLLVMPattern<hw::StructInjectOp> {
302 using ConvertOpToLLVMPattern<hw::StructInjectOp>::ConvertOpToLLVMPattern;
305 matchAndRewrite(hw::StructInjectOp op, OpAdaptor adaptor,
306 ConversionPatternRewriter &rewriter)
const override {
309 op.getInput().getType(), op.getFieldIndex());
311 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
312 op, adaptor.getInput(), adaptor.getNewValue(), fieldIndex);
/// Conversion pattern for hw::ArrayConcatOp.
///
/// The result is assembled element by element, starting from an
/// LLVM::UndefOp of the converted result type: each iteration extracts element
/// `k` of converted input `j` and inserts it at position `i`. `j` starts at
/// the last input and `k` at 0; how they advance is on lines not visible in
/// this view. The op is finally replaced by `arr`.
325struct ArrayConcatOpConversion
326 :
public ConvertOpToLLVMPattern<hw::ArrayConcatOp> {
331 ConversionPatternRewriter &rewriter)
const override {
333 hw::ArrayType arrTy = cast<hw::ArrayType>(op.getResult().getType());
334 Type resultTy = typeConverter->convertType(arrTy);
336 Value arr = LLVM::UndefOp::create(rewriter, op->getLoc(), resultTy);
// j: index of the input currently being read (starts at the last input);
// k: element position inside that input.
339 size_t j = op.getInputs().size() - 1, k = 0;
// One iteration per element of the result array.
341 for (
size_t i = 0, e = arrTy.getNumElements(); i < e; ++i) {
342 Value element = LLVM::ExtractValueOp::create(rewriter, op->getLoc(),
343 adaptor.getInputs()[j], k);
345 LLVM::InsertValueOp::create(rewriter, op->getLoc(), arr, element, i);
349 cast<hw::ArrayType>(op.getInputs()[j].getType()).getNumElements()) {
355 rewriter.replaceOp(op, arr);
/// Conversion pattern for hw::BitcastOp that reinterprets through memory.
///
/// The converted input is stored into a slot created by LLVM::AllocaOp (typed
/// with the input's converted type), and the op is replaced by an
/// LLVM::LoadOp of the converted result type from the same pointer.
370struct BitcastOpConversion :
public ConvertOpToLLVMPattern<hw::BitcastOp> {
371 using ConvertOpToLLVMPattern<
hw::BitcastOp>::ConvertOpToLLVMPattern;
375 ConversionPatternRewriter &rewriter)
const override {
377 Type resultTy = typeConverter->convertType(op.getResult().getType());
// i32 constant 1, passed to the alloca below.
379 auto oneC = rewriter.createOrFold<LLVM::ConstantOp>(
380 op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
382 auto ptr = LLVM::AllocaOp::create(
383 rewriter, op->getLoc(),
384 LLVM::LLVMPointerType::get(rewriter.getContext()),
385 adaptor.getInput().getType(), oneC,
388 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getInput(), ptr);
// Read the stored value back as the result type.
390 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, resultTy, ptr);
/// Conversion pattern for hw::ConstantOp, registered by operation name.
///
/// The op is replaced by an LLVM::ConstantOp that carries the same value
/// attribute; the result type is obtained by converting the type of that
/// attribute.
402struct HWConstantOpConversion :
public ConvertToLLVMPattern {
403 explicit HWConstantOpConversion(MLIRContext *ctx,
404 LLVMTypeConverter &typeConverter)
405 : ConvertToLLVMPattern(
hw::ConstantOp::getOperationName(), ctx,
409 matchAndRewrite(Operation *op, ArrayRef<Value> operand,
410 ConversionPatternRewriter &rewriter)
const override {
// The pattern receives a generic Operation*, so cast to the concrete op first.
412 auto constOp = cast<hw::ConstantOp>(op);
414 auto intType = typeConverter->convertType(constOp.getValueAttr().getType());
416 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, intType,
417 constOp.getValueAttr());
/// Conversion pattern for hw::ArrayCreateOp.
///
/// Starting from an LLVM::UndefOp of the converted array type, one
/// LLVM::InsertValueOp per input places `input` at position `i`; the op is
/// then replaced by the fully populated value. Which operand is selected as
/// `input` for a given `i` is computed by a call whose name is on a line not
/// visible here (its arguments are the result's HW type and `i`).
427struct HWDynamicArrayCreateOpConversion
428 :
public ConvertOpToLLVMPattern<hw::ArrayCreateOp> {
433 ConversionPatternRewriter &rewriter)
const override {
434 auto arrayTy = typeConverter->convertType(op->getResult(0).getType());
437 Value arr = LLVM::UndefOp::create(rewriter, op->getLoc(), arrayTy);
// One iteration per input operand.
438 for (
size_t i = 0, e = op.getInputs().size(); i < e; ++i) {
442 op.getResult().getType(), i)];
443 arr = LLVM::InsertValueOp::create(rewriter, op->getLoc(), arr, input, i);
446 rewriter.replaceOp(op, arr);
/// Conversion pattern for hw::AggregateConstantOp.
///
/// The pattern keeps a reference to a map from (type, fields attribute) pairs
/// to LLVM::GlobalOp, and to a `globals` object used for naming; both are
/// supplied through the constructor.
456class AggregateConstantOpConversion
457 :
public ConvertOpToLLVMPattern<hw::AggregateConstantOp> {
458 using ConvertOpToLLVMPattern<hw::AggregateConstantOp>::ConvertOpToLLVMPattern;
// Recursive type check over array element types and struct inner types.
460 bool containsArrayAndStructAggregatesOnly(Type type)
const;
// Recursive check over nested hw::ArrayType; appends each level's element
// count to `dims`.
462 bool isMultiDimArrayOfIntegers(Type type,
463 SmallVectorImpl<int64_t> &dims)
const;
// Appends the integer attributes nested inside `attr` to `output`.
465 void flatten(Type type, Attribute attr,
466 SmallVectorImpl<Attribute> &output)
const;
// Emits the ops that build a value of `type` from the attribute `data`.
468 Value constructAggregate(OpBuilder &builder,
469 const TypeConverter &typeConverter, Location loc,
470 Type type, Attribute data)
const;
473 explicit AggregateConstantOpConversion(
474 LLVMTypeConverter &typeConverter,
475 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
476 &constAggregateGlobalsMap,
478 : ConvertOpToLLVMPattern(typeConverter),
479 constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals) {}
482 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
483 ConversionPatternRewriter &rewriter)
const override;
// Held by reference: the map object itself lives outside the pattern.
486 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
487 &constAggregateGlobalsMap;
/// Conversion pattern for hw::StructCreateOp.
///
/// Starting from an LLVM::UndefOp of the converted struct type, one
/// LLVM::InsertValueOp per member of the LLVM struct body places `input` at
/// position `i`; the op is then replaced by the populated value. Which operand
/// is selected as `input` for a given `i` is computed by a call whose name is
/// on a line not visible here (its arguments are the result's HW type and
/// `i`).
495struct HWStructCreateOpConversion
496 :
public ConvertOpToLLVMPattern<hw::StructCreateOp> {
501 ConversionPatternRewriter &rewriter)
const override {
503 auto resTy = typeConverter->convertType(op.getResult().getType());
505 Value tup = LLVM::UndefOp::create(rewriter, op->getLoc(), resTy);
// The trip count is the number of members in the converted LLVM struct type.
506 for (
size_t i = 0, e = cast<LLVM::LLVMStructType>(resTy).getBody().size();
510 op.getResult().getType(), i)];
511 tup = LLVM::InsertValueOp::create(rewriter, op->getLoc(), tup, input, i);
514 rewriter.replaceOp(op, tup);
/// Recursive predicate over `type`.
///
/// IntegerType is tested first (the value returned for it is on a line not
/// visible here). An hw::ArrayType yields the result for its element type; an
/// hw::StructType yields true only if the predicate holds for every inner
/// type.
524bool AggregateConstantOpConversion::containsArrayAndStructAggregatesOnly(
526 if (
auto intType = dyn_cast<IntegerType>(type))
529 if (
auto arrTy = dyn_cast<hw::ArrayType>(type))
530 return containsArrayAndStructAggregatesOnly(arrTy.getElementType());
// A struct qualifies when all of its inner types qualify.
532 if (
auto structTy = dyn_cast<hw::StructType>(type)) {
533 SmallVector<Type> innerTypes;
534 structTy.getInnerTypes(innerTypes);
535 return llvm::all_of(innerTypes, [&](
auto ty) {
536 return containsArrayAndStructAggregatesOnly(ty);
/// Recursive check that also records array extents.
///
/// IntegerType is tested first (the value returned for it is on a line not
/// visible here). For an hw::ArrayType the element count is appended to `dims`
/// and the check continues with the element type, so `dims` is filled from the
/// outermost array level inwards.
543bool AggregateConstantOpConversion::isMultiDimArrayOfIntegers(
544 Type type, SmallVectorImpl<int64_t> &dims)
const {
545 if (
auto intType = dyn_cast<IntegerType>(type))
548 if (
auto arrTy = dyn_cast<hw::ArrayType>(type)) {
549 dims.push_back(arrTy.getNumElements());
550 return isMultiDimArrayOfIntegers(arrTy.getElementType(), dims);
/// Append the integer attributes nested in `attr` to `output`.
///
/// For an IntegerType, `attr` is asserted to be an IntegerAttr and pushed onto
/// `output`. Otherwise `attr` is cast to an ArrayAttr and `type` to an
/// hw::ArrayType, and the function recurses with the array's element type for
/// the entries of the attribute. How `element` is picked for a given `i` is on
/// lines not visible in this view.
556void AggregateConstantOpConversion::flatten(
557 Type type, Attribute attr, SmallVectorImpl<Attribute> &output)
const {
558 if (isa<IntegerType>(type)) {
559 assert(isa<IntegerAttr>(attr));
560 output.push_back(attr);
564 auto arrAttr = cast<ArrayAttr>(attr);
565 for (
size_t i = 0, e = arrAttr.size(); i < e; ++i) {
569 flatten(cast<hw::ArrayType>(type).getElementType(), element, output);
/// Emit the ops that materialise a value of `type` from the attribute `data`.
///
/// - IntegerType: a single LLVM::ConstantOp built from `data` (cast to
///   TypedAttr).
/// - hw::ArrayType / hw::StructType: start from an LLVM::UndefOp of the
///   converted type, treat `data` as an ArrayAttr, build each element by a
///   recursive call and insert it at position `i` with LLVM::InsertValueOp.
///
/// `currIdx`, the position read from the attribute for a given `i`, is
/// computed on a line not visible in this view.
573Value AggregateConstantOpConversion::constructAggregate(
574 OpBuilder &builder,
const TypeConverter &typeConverter, Location loc,
575 Type type, Attribute data)
const {
576 Type llvmType = typeConverter.convertType(type);
// Helper: the type of the `index`-th member. Arrays return their element type
// regardless of `index`; anything else is asserted to be a struct and indexes
// its inner types.
578 auto getElementType = [](Type type,
size_t index) {
579 if (
auto arrTy = dyn_cast<hw::ArrayType>(type)) {
580 return arrTy.getElementType();
583 assert(isa<hw::StructType>(type));
584 auto structTy = cast<hw::StructType>(type);
585 SmallVector<Type> innerTypes;
586 structTy.getInnerTypes(innerTypes);
587 return innerTypes[index];
590 return TypeSwitch<Type, Value>(type)
591 .Case<IntegerType>([&](
auto ty) {
592 return LLVM::ConstantOp::create(builder, loc, cast<TypedAttr>(data));
594 .Case<hw::ArrayType, hw::StructType>([&](
auto ty) {
595 Value aggVal = LLVM::UndefOp::create(builder, loc, llvmType);
596 auto arrayAttr = cast<ArrayAttr>(data);
597 for (
size_t i = 0, e = arrayAttr.size(); i < e; ++i) {
600 Attribute input = arrayAttr[currIdx];
603 Value element = constructAggregate(builder, typeConverter, loc,
606 LLVM::InsertValueOp::create(builder, loc, aggVal, element, i);
/// Lower hw::AggregateConstantOp to a load from an LLVM global.
///
/// The (result type, fields attribute) pair is looked up in
/// `constAggregateGlobalsMap`. If there is no entry, or the entry is null, a
/// new LLVM::GlobalOp with internal linkage is created and recorded in the
/// map:
///  - when the type is a multi-dimensional array of integers, the fields are
///    flattened into a DenseElementsAttr and used as the global's value;
///  - otherwise the global is created with an empty attribute and an
///    initializer block is filled via constructAggregate, ending in an
///    LLVM::ReturnOp.
/// The op itself is replaced by an LLVM::LoadOp from the LLVM::AddressOfOp of
/// the global. What happens when containsArrayAndStructAggregatesOnly fails is
/// on a line not visible in this view.
613LogicalResult AggregateConstantOpConversion::matchAndRewrite(
614 hw::AggregateConstantOp op, OpAdaptor adaptor,
615 ConversionPatternRewriter &rewriter)
const {
616 Type aggregateType = op.getResult().getType();
619 if (!containsArrayAndStructAggregatesOnly(aggregateType))
622 auto llvmTy = typeConverter->convertType(op.getResult().getType());
// Key under which the global for this constant is stored in the map.
623 auto typeAttrPair = std::make_pair(aggregateType, adaptor.getFields());
625 if (!constAggregateGlobalsMap.count(typeAttrPair) ||
626 !constAggregateGlobalsMap[typeAttrPair]) {
// The insertion point is saved here and restored once the global exists.
627 auto ipSave = rewriter.saveInsertionPoint();
// Climb to the ancestor whose parent is the mlir::ModuleOp and insert the
// global right before it.
629 Operation *parent = op->getParentOp();
630 while (!isa<mlir::ModuleOp>(parent->getParentOp())) {
631 parent = parent->getParentOp();
634 rewriter.setInsertionPoint(parent);
// Ask `globals` for a fresh name based on this prefix.
637 auto name = globals.newName(
"_aggregate_const_global");
639 SmallVector<int64_t> dims;
640 if (isMultiDimArrayOfIntegers(aggregateType, dims)) {
641 SmallVector<Attribute> ints;
642 flatten(aggregateType, adaptor.getFields(), ints);
// The tensor element type is taken from the first flattened attribute.
644 auto shapedType = RankedTensorType::get(
645 dims, cast<IntegerAttr>(ints.front()).getType());
646 auto denseAttr = DenseElementsAttr::get(shapedType, ints);
648 constAggregateGlobalsMap[typeAttrPair] =
649 LLVM::GlobalOp::create(rewriter, op.getLoc(), llvmTy,
true,
650 LLVM::Linkage::Internal, name, denseAttr);
653 LLVM::GlobalOp::create(rewriter, op.getLoc(), llvmTy,
false,
654 LLVM::Linkage::Internal, name, Attribute());
656 global.getInitializerRegion().push_back(blk);
657 rewriter.setInsertionPointToStart(blk);
660 constructAggregate(rewriter, *typeConverter, op.getLoc(),
661 aggregateType, adaptor.getFields());
662 LLVM::ReturnOp::create(rewriter, op.getLoc(), aggregate);
663 constAggregateGlobalsMap[typeAttrPair] = global;
666 rewriter.restoreInsertionPoint(ipSave);
// Take the global's address and load the aggregate value from it.
670 auto addr = LLVM::AddressOfOp::create(rewriter, op->getLoc(),
671 constAggregateGlobalsMap[typeAttrPair]);
672 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmTy,
addr);
// Convert the array's element type through `converter`, then build an
// LLVM::LLVMArrayType with that element type and the same element count.
682 auto elementTy = converter.convertType(type.getElementType());
683 return LLVM::LLVMArrayType::get(elementTy, type.getNumElements());
687 LLVMTypeConverter &converter) {
// Collect the struct's inner types, convert them one by one into `elements`,
// and return a literal LLVM struct type built from `elements`. The expression
// selecting which inner type is converted at step `i` is on lines not visible
// in this view.
688 llvm::SmallVector<Type, 8> elements;
689 mlir::SmallVector<mlir::Type> types;
690 type.getInnerTypes(types);
692 for (
int i = 0, e = types.size(); i < e; ++i)
693 elements.push_back(converter.convertType(
696 return LLVM::LLVMStructType::getLiteral(&converter.getContext(), elements);
/// Pass class derived from the generated ConvertHWToLLVMBase; its only
/// visible member is the runOnOperation override, defined out of line.
704struct HWToLLVMLoweringPass
705 :
public circt::impl::ConvertHWToLLVMBase<HWToLLVMLoweringPass> {
706 void runOnOperation()
override;
711 LLVMTypeConverter &converter, RewritePatternSet &
patterns,
713 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
714 &constAggregateGlobalsMap) {
// Register every HW lowering pattern on `patterns`. HWConstantOpConversion is
// constructed with (ctx, converter); AggregateConstantOpConversion
// additionally receives the shared globals map and `globals`; the bitcast,
// array and struct patterns listed below are constructed from the converter
// alone.
715 MLIRContext *ctx = converter.getDialect()->getContext();
718 patterns.add<HWConstantOpConversion>(ctx, converter);
719 patterns.add<HWDynamicArrayCreateOpConversion, HWStructCreateOpConversion>(
721 patterns.add<AggregateConstantOpConversion>(
722 converter, constAggregateGlobalsMap, globals);
725 patterns.add<BitcastOpConversion>(converter);
728 patterns.add<ArrayInjectOpConversion, ArrayGetOpConversion,
729 ArraySliceOpConversion, ArrayConcatOpConversion,
730 StructExplodeOpConversion, StructExtractOpConversion,
731 StructInjectOpConversion>(converter);
// Two conversion callbacks are registered on `converter`; their arguments are
// on lines not visible in this view.
735 converter.addConversion(
737 converter.addConversion(
/// Run the HW-to-LLVM lowering on the pass's operation.
///
/// Visible steps: create the map shared by aggregate-constant lowering, build
/// a pattern set and an LLVMTypeConverter for the current context, declare the
/// whole hw dialect illegal on an LLVMConversionTarget, disable pattern
/// rollback in the conversion config, and call applyPartialConversion. The
/// handling of a failed conversion is on lines not visible in this view.
741void HWToLLVMLoweringPass::runOnOperation() {
// Maps (type, fields attribute) pairs to the LLVM::GlobalOp created for them.
742 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
748 RewritePatternSet
patterns(&getContext());
749 auto converter = mlir::LLVMTypeConverter(&getContext());
// Every op of the hw dialect must be rewritten for the conversion to succeed.
752 LLVMConversionTarget target(getContext());
753 target.addIllegalDialect<hw::HWDialect>();
757 constAggregateGlobalsMap);
760 ConversionConfig config;
761 config.allowPatternRollback =
false;
762 if (failed(applyPartialConversion(getOperation(), target, std::move(
patterns),
// Hand back a freshly allocated HWToLLVMLoweringPass owned by a unique_ptr.
769 return std::make_unique<HWToLLVMLoweringPass>();
assert(baseType &&"element must be base type")
static Type convertStructType(hw::StructType type, LLVMTypeConverter &converter)
static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter, Value value)
Create a zext operation by one bit on the given value.
static Type convertArrayType(hw::ArrayType type, LLVMTypeConverter &converter)
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations from names (StringAttrs) to mlir::Operations.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap)
Get the HW to LLVM conversion patterns.
std::unique_ptr< OperationPass< ModuleOp > > createConvertHWToLLVMPass()
Create an HW to LLVM conversion pass.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
static uint32_t convertToLLVMEndianess(Type type, uint32_t index)
Convert an index into a HW ArrayType or StructType to LLVM endianness.
static uint32_t llvmIndexOfStructField(hw::StructType type, StringRef fieldName)
Get the index of a specific StructType field in the LLVM lowering of the StructType.