14 #include "../PassDetail.h"
18 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
19 #include "mlir/Conversion/LLVMCommon/Pattern.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/TypeSwitch.h"
26 using namespace circt;
36 return TypeSwitch<Type, uint32_t>(type)
38 [&](hw::ArrayType ty) {
return ty.getNumElements() - index - 1; })
39 .Case<hw::StructType>([&](hw::StructType ty) {
40 return ty.getElements().size() - index - 1;
46 StringRef fieldName) {
47 auto fieldIter = type.getElements();
50 for (
const auto *iter = fieldIter.begin(); iter != fieldIter.end(); ++iter) {
51 if (iter->name == fieldName) {
52 return HWToLLVMEndianessConverter::convertToLLVMEndianess(type, index);
59 llvm_unreachable(
"Field name attribute of hw::StructExtractOp invalid");
70 static Value
zextByOne(Location loc, ConversionPatternRewriter &rewriter,
72 auto valueTy = value.getType();
74 valueTy.getIntOrFloatBitWidth() + 1);
75 return rewriter.create<LLVM::ZExtOp>(loc, zextTy, value);
86 struct StructExplodeOpConversion
87 :
public ConvertOpToLLVMPattern<hw::StructExplodeOp> {
88 using ConvertOpToLLVMPattern<hw::StructExplodeOp>::ConvertOpToLLVMPattern;
91 matchAndRewrite(hw::StructExplodeOp op, OpAdaptor adaptor,
92 ConversionPatternRewriter &rewriter)
const override {
94 SmallVector<Value> replacements;
96 for (
size_t i = 0, e = adaptor.getInput()
98 .cast<LLVM::LLVMStructType>()
103 replacements.push_back(rewriter.create<LLVM::ExtractValueOp>(
104 op->getLoc(), adaptor.getInput(),
105 HWToLLVMEndianessConverter::convertToLLVMEndianess(
106 op.getInput().getType(), i)));
108 rewriter.replaceOp(op, replacements);
118 struct StructExtractOpConversion
119 :
public ConvertOpToLLVMPattern<hw::StructExtractOp> {
120 using ConvertOpToLLVMPattern<hw::StructExtractOp>::ConvertOpToLLVMPattern;
123 matchAndRewrite(hw::StructExtractOp op, OpAdaptor adaptor,
124 ConversionPatternRewriter &rewriter)
const override {
126 uint32_t fieldIndex = HWToLLVMEndianessConverter::llvmIndexOfStructField(
127 op.getInput().getType().cast<hw::StructType>(), op.getField());
128 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, adaptor.getInput(),
139 struct ArrayGetOpConversion :
public ConvertOpToLLVMPattern<hw::ArrayGetOp> {
140 using ConvertOpToLLVMPattern<hw::ArrayGetOp>::ConvertOpToLLVMPattern;
143 matchAndRewrite(hw::ArrayGetOp op, OpAdaptor adaptor,
144 ConversionPatternRewriter &rewriter)
const override {
147 if (
auto load = adaptor.getInput().getDefiningOp<LLVM::LoadOp>()) {
150 arrPtr = load.getAddr();
152 auto oneC = rewriter.create<LLVM::ConstantOp>(
154 rewriter.getI32IntegerAttr(1));
155 arrPtr = rewriter.create<LLVM::AllocaOp>(
159 rewriter.create<LLVM::StoreOp>(op->getLoc(), adaptor.getInput(), arrPtr);
162 auto elemTy = typeConverter->convertType(op.getResult().getType());
164 auto zeroC = rewriter.create<LLVM::ConstantOp>(
166 rewriter.getI32IntegerAttr(0));
167 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getIndex());
168 auto gep = rewriter.create<LLVM::GEPOp>(
170 ArrayRef<Value>({zeroC, zextIndex}));
171 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elemTy, gep);
182 struct ArraySliceOpConversion
183 :
public ConvertOpToLLVMPattern<hw::ArraySliceOp> {
184 using ConvertOpToLLVMPattern<hw::ArraySliceOp>::ConvertOpToLLVMPattern;
187 matchAndRewrite(hw::ArraySliceOp op, OpAdaptor adaptor,
188 ConversionPatternRewriter &rewriter)
const override {
190 auto dstTy = typeConverter->convertType(op.getDst().getType());
191 auto elemTy = typeConverter->convertType(
192 op.getDst().getType().cast<hw::ArrayType>().getElementType());
194 auto zeroC = rewriter.create<LLVM::ConstantOp>(
195 op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
196 auto oneC = rewriter.create<LLVM::ConstantOp>(
197 op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
199 auto arrPtr = rewriter.create<LLVM::AllocaOp>(
204 rewriter.create<LLVM::StoreOp>(op->getLoc(), adaptor.getInput(), arrPtr);
206 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getLowIndex());
208 auto gep = rewriter.create<LLVM::GEPOp>(
210 ArrayRef<Value>({zeroC, zextIndex}));
212 auto cast = rewriter.create<LLVM::BitcastOp>(
215 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dstTy, cast);
230 struct StructInjectOpConversion
231 :
public ConvertOpToLLVMPattern<hw::StructInjectOp> {
232 using ConvertOpToLLVMPattern<hw::StructInjectOp>::ConvertOpToLLVMPattern;
235 matchAndRewrite(hw::StructInjectOp op, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter)
const override {
238 uint32_t fieldIndex = HWToLLVMEndianessConverter::llvmIndexOfStructField(
239 op.getInput().getType().cast<hw::StructType>(),
240 op.getFieldAttr().getValue());
242 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
243 op, adaptor.getInput(), op.getNewValue(), fieldIndex);
256 struct ArrayConcatOpConversion
257 :
public ConvertOpToLLVMPattern<hw::ArrayConcatOp> {
258 using ConvertOpToLLVMPattern<hw::ArrayConcatOp>::ConvertOpToLLVMPattern;
261 matchAndRewrite(hw::ArrayConcatOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const override {
264 hw::ArrayType arrTy = op.getResult().getType().cast<hw::ArrayType>();
265 Type resultTy = typeConverter->convertType(arrTy);
267 Value arr = rewriter.create<LLVM::UndefOp>(op->getLoc(), resultTy);
270 size_t j = op.getInputs().size() - 1, k = 0;
272 for (
size_t i = 0, e = arrTy.getNumElements(); i < e; ++i) {
273 Value element = rewriter.create<LLVM::ExtractValueOp>(
274 op->getLoc(), adaptor.getInputs()[j], k);
275 arr = rewriter.create<LLVM::InsertValueOp>(op->getLoc(), arr, element, i);
279 op.getInputs()[j].getType().cast<hw::ArrayType>().getNumElements()) {
285 rewriter.replaceOp(op, arr);
300 struct BitcastOpConversion :
public ConvertOpToLLVMPattern<hw::BitcastOp> {
301 using ConvertOpToLLVMPattern<hw::BitcastOp>::ConvertOpToLLVMPattern;
304 matchAndRewrite(hw::BitcastOp op, OpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter)
const override {
307 Type resultTy = typeConverter->convertType(op.getResult().getType());
309 auto oneC = rewriter.createOrFold<LLVM::ConstantOp>(
310 op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
312 auto ptr = rewriter.create<LLVM::AllocaOp>(
317 rewriter.create<LLVM::StoreOp>(op->getLoc(), adaptor.getInput(), ptr);
319 auto cast = rewriter.create<LLVM::BitcastOp>(
322 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, resultTy, cast);
334 struct HWConstantOpConversion :
public ConvertToLLVMPattern {
335 explicit HWConstantOpConversion(MLIRContext *ctx,
336 LLVMTypeConverter &typeConverter)
337 : ConvertToLLVMPattern(
hw::ConstantOp::getOperationName(), ctx,
341 matchAndRewrite(Operation *op, ArrayRef<Value> operand,
342 ConversionPatternRewriter &rewriter)
const override {
344 auto constOp = cast<hw::ConstantOp>(op);
346 auto intType = typeConverter->convertType(constOp.getValueAttr().getType());
348 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, intType,
349 constOp.getValueAttr());
359 struct HWDynamicArrayCreateOpConversion
360 :
public ConvertOpToLLVMPattern<hw::ArrayCreateOp> {
361 using ConvertOpToLLVMPattern<hw::ArrayCreateOp>::ConvertOpToLLVMPattern;
364 matchAndRewrite(hw::ArrayCreateOp op, OpAdaptor adaptor,
365 ConversionPatternRewriter &rewriter)
const override {
366 auto arrayTy = typeConverter->convertType(op->getResult(0).getType());
369 Value arr = rewriter.create<LLVM::UndefOp>(op->getLoc(), arrayTy);
370 for (
size_t i = 0, e = op.getInputs().size(); i < e; ++i) {
373 .getInputs()[HWToLLVMEndianessConverter::convertToLLVMEndianess(
374 op.getResult().getType(), i)];
375 arr = rewriter.create<LLVM::InsertValueOp>(op->getLoc(), arr, input, i);
378 rewriter.replaceOp(op, arr);
388 class AggregateConstantOpConversion
389 :
public ConvertOpToLLVMPattern<hw::AggregateConstantOp> {
390 using ConvertOpToLLVMPattern<hw::AggregateConstantOp>::ConvertOpToLLVMPattern;
392 bool containsArrayAndStructAggregatesOnly(Type type)
const;
394 bool isMultiDimArrayOfIntegers(Type type,
395 SmallVectorImpl<int64_t> &dims)
const;
397 void flatten(Type type, Attribute attr,
398 SmallVectorImpl<Attribute> &output)
const;
400 Value constructAggregate(OpBuilder &
builder,
401 const TypeConverter &typeConverter, Location loc,
402 Type type, Attribute data)
const;
405 explicit AggregateConstantOpConversion(
406 LLVMTypeConverter &typeConverter,
407 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
408 &constAggregateGlobalsMap,
410 : ConvertOpToLLVMPattern(typeConverter),
411 constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals) {}
414 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
415 ConversionPatternRewriter &rewriter)
const override;
418 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
419 &constAggregateGlobalsMap;
427 struct HWStructCreateOpConversion
428 :
public ConvertOpToLLVMPattern<hw::StructCreateOp> {
429 using ConvertOpToLLVMPattern<hw::StructCreateOp>::ConvertOpToLLVMPattern;
432 matchAndRewrite(hw::StructCreateOp op, OpAdaptor adaptor,
433 ConversionPatternRewriter &rewriter)
const override {
435 auto resTy = typeConverter->convertType(op.getResult().getType());
437 Value tup = rewriter.create<LLVM::UndefOp>(op->getLoc(), resTy);
438 for (
size_t i = 0, e = resTy.cast<LLVM::LLVMStructType>().getBody().size();
441 adaptor.getInput()[HWToLLVMEndianessConverter::convertToLLVMEndianess(
442 op.getResult().getType(), i)];
443 tup = rewriter.create<LLVM::InsertValueOp>(op->getLoc(), tup, input, i);
446 rewriter.replaceOp(op, tup);
456 bool AggregateConstantOpConversion::containsArrayAndStructAggregatesOnly(
458 if (
auto intType = type.dyn_cast<IntegerType>())
461 if (
auto arrTy = type.dyn_cast<hw::ArrayType>())
462 return containsArrayAndStructAggregatesOnly(arrTy.getElementType());
464 if (
auto structTy = type.dyn_cast<hw::StructType>()) {
465 SmallVector<Type> innerTypes;
466 structTy.getInnerTypes(innerTypes);
467 return llvm::all_of(innerTypes, [&](
auto ty) {
468 return containsArrayAndStructAggregatesOnly(ty);
475 bool AggregateConstantOpConversion::isMultiDimArrayOfIntegers(
476 Type type, SmallVectorImpl<int64_t> &dims)
const {
477 if (
auto intType = type.dyn_cast<IntegerType>())
480 if (
auto arrTy = type.dyn_cast<hw::ArrayType>()) {
481 dims.push_back(arrTy.getNumElements());
482 return isMultiDimArrayOfIntegers(arrTy.getElementType(), dims);
488 void AggregateConstantOpConversion::flatten(
489 Type type, Attribute attr, SmallVectorImpl<Attribute> &output)
const {
490 if (type.isa<IntegerType>()) {
491 assert(attr.isa<IntegerAttr>());
492 output.push_back(attr);
496 auto arrAttr = attr.cast<ArrayAttr>();
497 for (
size_t i = 0, e = arrAttr.size(); i < e; ++i) {
499 arrAttr[HWToLLVMEndianessConverter::convertToLLVMEndianess(type, i)];
501 flatten(type.cast<hw::ArrayType>().getElementType(), element, output);
505 Value AggregateConstantOpConversion::constructAggregate(
506 OpBuilder &
builder,
const TypeConverter &typeConverter, Location loc,
507 Type type, Attribute data)
const {
508 Type llvmType = typeConverter.convertType(type);
510 auto getElementType = [](Type type,
size_t index) {
511 if (
auto arrTy = type.dyn_cast<hw::ArrayType>()) {
512 return arrTy.getElementType();
515 assert(type.isa<hw::StructType>());
516 auto structTy = type.cast<hw::StructType>();
517 SmallVector<Type> innerTypes;
518 structTy.getInnerTypes(innerTypes);
519 return innerTypes[index];
522 return TypeSwitch<Type, Value>(type)
523 .Case<IntegerType>([&](
auto ty) {
524 return builder.create<LLVM::ConstantOp>(loc,
data.cast<TypedAttr>());
526 .Case<hw::ArrayType, hw::StructType>([&](
auto ty) {
527 Value aggVal =
builder.create<LLVM::UndefOp>(loc, llvmType);
528 auto arrayAttr =
data.cast<ArrayAttr>();
529 for (
size_t i = 0, e = arrayAttr.size(); i < e; ++i) {
531 HWToLLVMEndianessConverter::convertToLLVMEndianess(type, i);
532 Attribute input = arrayAttr[currIdx];
535 Value element = constructAggregate(
builder, typeConverter, loc,
537 aggVal =
builder.create<LLVM::InsertValueOp>(loc, aggVal, element, i);
544 LogicalResult AggregateConstantOpConversion::matchAndRewrite(
545 hw::AggregateConstantOp op, OpAdaptor adaptor,
546 ConversionPatternRewriter &rewriter)
const {
547 Type aggregateType = op.getResult().getType();
550 if (!containsArrayAndStructAggregatesOnly(aggregateType))
553 auto llvmTy = typeConverter->convertType(op.getResult().getType());
554 auto typeAttrPair = std::make_pair(aggregateType, adaptor.getFields());
556 if (!constAggregateGlobalsMap.count(typeAttrPair) ||
557 !constAggregateGlobalsMap[typeAttrPair]) {
558 auto ipSave = rewriter.saveInsertionPoint();
560 Operation *parent = op->getParentOp();
561 while (!isa<mlir::ModuleOp>(parent->getParentOp())) {
562 parent = parent->getParentOp();
565 rewriter.setInsertionPoint(parent);
568 auto name = globals.newName(
"_aggregate_const_global");
570 SmallVector<int64_t> dims;
571 if (isMultiDimArrayOfIntegers(aggregateType, dims)) {
572 SmallVector<Attribute> ints;
573 flatten(aggregateType, adaptor.getFields(), ints);
576 dims, ints.front().cast<IntegerAttr>().getType());
579 constAggregateGlobalsMap[typeAttrPair] = rewriter.create<LLVM::GlobalOp>(
580 op.getLoc(), llvmTy,
true, LLVM::Linkage::Internal, name, denseAttr);
582 auto global = rewriter.create<LLVM::GlobalOp>(op.getLoc(), llvmTy,
false,
583 LLVM::Linkage::Internal,
586 global.getInitializerRegion().push_back(blk);
587 rewriter.setInsertionPointToStart(blk);
590 constructAggregate(rewriter, *typeConverter, op.getLoc(),
591 aggregateType, adaptor.getFields());
592 rewriter.create<LLVM::ReturnOp>(op.getLoc(), aggregate);
593 constAggregateGlobalsMap[typeAttrPair] = global;
596 rewriter.restoreInsertionPoint(ipSave);
600 auto addr = rewriter.create<LLVM::AddressOfOp>(
601 op->getLoc(), constAggregateGlobalsMap[typeAttrPair]);
602 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmTy,
addr);
612 auto elementTy = converter.convertType(type.getElementType());
617 LLVMTypeConverter &converter) {
618 llvm::SmallVector<Type, 8> elements;
619 mlir::SmallVector<mlir::Type> types;
620 type.getInnerTypes(types);
622 for (
int i = 0, e = types.size(); i < e; ++i)
623 elements.push_back(converter.convertType(
624 types[HWToLLVMEndianessConverter::convertToLLVMEndianess(type, i)]));
626 return LLVM::LLVMStructType::getLiteral(&converter.getContext(), elements);
634 struct HWToLLVMLoweringPass :
public ConvertHWToLLVMBase<HWToLLVMLoweringPass> {
635 void runOnOperation()
override;
640 LLVMTypeConverter &converter, RewritePatternSet &
patterns,
642 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
643 &constAggregateGlobalsMap) {
644 MLIRContext *ctx = converter.getDialect()->getContext();
647 patterns.add<HWConstantOpConversion>(ctx, converter);
648 patterns.add<HWDynamicArrayCreateOpConversion, HWStructCreateOpConversion>(
650 patterns.add<AggregateConstantOpConversion>(
651 converter, constAggregateGlobalsMap, globals);
654 patterns.add<BitcastOpConversion>(converter);
657 patterns.add<ArrayGetOpConversion, ArraySliceOpConversion,
658 ArrayConcatOpConversion, StructExplodeOpConversion,
659 StructExtractOpConversion, StructInjectOpConversion>(converter);
663 converter.addConversion(
665 converter.addConversion(
669 void HWToLLVMLoweringPass::runOnOperation() {
670 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
676 RewritePatternSet
patterns(&getContext());
677 auto converter = mlir::LLVMTypeConverter(&getContext());
680 LLVMConversionTarget target(getContext());
681 target.addLegalOp<UnrealizedConversionCastOp>();
682 target.addLegalOp<ModuleOp>();
683 target.addLegalDialect<LLVM::LLVMDialect>();
684 target.addIllegalDialect<hw::HWDialect>();
688 constAggregateGlobalsMap);
692 applyPartialConversion(getOperation(), target, std::move(
patterns))))
698 return std::make_unique<HWToLLVMLoweringPass>();
assert(baseType &&"element must be base type")
static Type convertStructType(hw::StructType type, LLVMTypeConverter &converter)
static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter, Value value)
Create a zext operation by one bit on the given value.
static Type convertArrayType(hw::ArrayType type, LLVMTypeConverter &converter)
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(SymbolCache &symCache)
SymbolCache initializer; initialize from every key that is convertible to a StringAttr in the SymbolC...
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap)
Get the HW to LLVM conversion patterns.
std::unique_ptr< OperationPass< ModuleOp > > createConvertHWToLLVMPass()
Create an HW to LLVM conversion pass.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
static uint32_t convertToLLVMEndianess(Type type, uint32_t index)
Convert an index into a HW ArrayType or StructType to LLVM Endianess.
static uint32_t llvmIndexOfStructField(hw::StructType type, StringRef fieldName)
Get the index of a specific StructType field in the LLVM lowering of the StructType.