17 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/TypeSwitch.h"
25 #define GEN_PASS_DEF_CONVERTHWTOLLVM
26 #include "circt/Conversion/Passes.h.inc"
30 using namespace circt;
40 return TypeSwitch<Type, uint32_t>(type)
42 [&](hw::ArrayType ty) {
return ty.getNumElements() - index - 1; })
43 .Case<hw::StructType>([&](hw::StructType ty) {
44 return ty.getElements().size() - index - 1;
50 StringRef fieldName) {
51 auto fieldIter = type.getElements();
54 for (
const auto *iter = fieldIter.begin(); iter != fieldIter.end(); ++iter) {
55 if (iter->name == fieldName) {
56 return HWToLLVMEndianessConverter::convertToLLVMEndianess(type, index);
63 llvm_unreachable(
"Field name attribute of hw::StructExtractOp invalid");
74 static Value
zextByOne(Location loc, ConversionPatternRewriter &rewriter,
76 auto valueTy = value.getType();
78 valueTy.getIntOrFloatBitWidth() + 1);
79 return rewriter.create<LLVM::ZExtOp>(loc, zextTy, value);
90 struct StructExplodeOpConversion
91 :
public ConvertOpToLLVMPattern<hw::StructExplodeOp> {
92 using ConvertOpToLLVMPattern<hw::StructExplodeOp>::ConvertOpToLLVMPattern;
95 matchAndRewrite(hw::StructExplodeOp op, OpAdaptor adaptor,
96 ConversionPatternRewriter &rewriter)
const override {
98 SmallVector<Value> replacements;
101 e = cast<LLVM::LLVMStructType>(adaptor.getInput().getType())
106 replacements.push_back(rewriter.create<LLVM::ExtractValueOp>(
107 op->getLoc(), adaptor.getInput(),
108 HWToLLVMEndianessConverter::convertToLLVMEndianess(
109 op.getInput().getType(), i)));
111 rewriter.replaceOp(op, replacements);
121 struct StructExtractOpConversion
122 :
public ConvertOpToLLVMPattern<hw::StructExtractOp> {
123 using ConvertOpToLLVMPattern<hw::StructExtractOp>::ConvertOpToLLVMPattern;
127 ConversionPatternRewriter &rewriter)
const override {
129 uint32_t fieldIndex = HWToLLVMEndianessConverter::convertToLLVMEndianess(
130 op.getInput().getType(), op.getFieldIndex());
131 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, adaptor.getInput(),
142 struct ArrayGetOpConversion :
public ConvertOpToLLVMPattern<hw::ArrayGetOp> {
143 using ConvertOpToLLVMPattern<hw::ArrayGetOp>::ConvertOpToLLVMPattern;
147 ConversionPatternRewriter &rewriter)
const override {
150 if (
auto load = adaptor.getInput().getDefiningOp<LLVM::LoadOp>()) {
153 arrPtr = load.getAddr();
155 auto oneC = rewriter.create<LLVM::ConstantOp>(
157 rewriter.getI32IntegerAttr(1));
158 arrPtr = rewriter.create<LLVM::AllocaOp>(
160 adaptor.getInput().getType(), oneC,
162 rewriter.create<LLVM::StoreOp>(op->getLoc(), adaptor.getInput(), arrPtr);
165 auto arrTy = typeConverter->convertType(op.getInput().getType());
166 auto elemTy = typeConverter->convertType(op.getResult().getType());
167 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getIndex());
172 auto gep = rewriter.create<LLVM::GEPOp>(
174 arrPtr, ArrayRef<LLVM::GEPArg>{0, zextIndex});
175 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elemTy, gep);
186 struct ArraySliceOpConversion
187 :
public ConvertOpToLLVMPattern<hw::ArraySliceOp> {
188 using ConvertOpToLLVMPattern<hw::ArraySliceOp>::ConvertOpToLLVMPattern;
192 ConversionPatternRewriter &rewriter)
const override {
194 auto dstTy = typeConverter->convertType(op.getDst().getType());
196 auto oneC = rewriter.create<LLVM::ConstantOp>(
197 op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
199 auto arrPtr = rewriter.create<LLVM::AllocaOp>(
201 adaptor.getInput().getType(), oneC,
204 rewriter.create<LLVM::StoreOp>(op->getLoc(), adaptor.getInput(), arrPtr);
206 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getLowIndex());
211 auto gep = rewriter.create<LLVM::GEPOp>(
213 arrPtr, ArrayRef<LLVM::GEPArg>{0, zextIndex});
215 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dstTy, gep);
230 struct StructInjectOpConversion
231 :
public ConvertOpToLLVMPattern<hw::StructInjectOp> {
232 using ConvertOpToLLVMPattern<hw::StructInjectOp>::ConvertOpToLLVMPattern;
235 matchAndRewrite(hw::StructInjectOp op, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter)
const override {
238 uint32_t fieldIndex = HWToLLVMEndianessConverter::convertToLLVMEndianess(
239 op.getInput().getType(), op.getFieldIndex());
241 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
242 op, adaptor.getInput(), op.getNewValue(), fieldIndex);
255 struct ArrayConcatOpConversion
256 :
public ConvertOpToLLVMPattern<hw::ArrayConcatOp> {
257 using ConvertOpToLLVMPattern<hw::ArrayConcatOp>::ConvertOpToLLVMPattern;
261 ConversionPatternRewriter &rewriter)
const override {
263 hw::ArrayType arrTy = cast<hw::ArrayType>(op.getResult().getType());
264 Type resultTy = typeConverter->convertType(arrTy);
266 Value arr = rewriter.create<LLVM::UndefOp>(op->getLoc(), resultTy);
269 size_t j = op.getInputs().size() - 1, k = 0;
271 for (
size_t i = 0, e = arrTy.getNumElements(); i < e; ++i) {
272 Value element = rewriter.
create<LLVM::ExtractValueOp>(
273 op->getLoc(), adaptor.getInputs()[j], k);
274 arr = rewriter.create<LLVM::InsertValueOp>(op->getLoc(), arr, element, i);
278 cast<hw::ArrayType>(op.getInputs()[j].getType()).getNumElements()) {
284 rewriter.replaceOp(op, arr);
299 struct BitcastOpConversion :
public ConvertOpToLLVMPattern<hw::BitcastOp> {
300 using ConvertOpToLLVMPattern<hw::BitcastOp>::ConvertOpToLLVMPattern;
304 ConversionPatternRewriter &rewriter)
const override {
306 Type resultTy = typeConverter->convertType(op.getResult().getType());
308 auto oneC = rewriter.createOrFold<LLVM::ConstantOp>(
309 op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
311 auto ptr = rewriter.create<LLVM::AllocaOp>(
313 adaptor.getInput().getType(), oneC,
316 rewriter.create<LLVM::StoreOp>(op->getLoc(), adaptor.getInput(), ptr);
318 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, resultTy, ptr);
330 struct HWConstantOpConversion :
public ConvertToLLVMPattern {
331 explicit HWConstantOpConversion(MLIRContext *ctx,
332 LLVMTypeConverter &typeConverter)
333 : ConvertToLLVMPattern(
hw::ConstantOp::getOperationName(), ctx,
337 matchAndRewrite(Operation *op, ArrayRef<Value> operand,
338 ConversionPatternRewriter &rewriter)
const override {
340 auto constOp = cast<hw::ConstantOp>(op);
342 auto intType = typeConverter->convertType(constOp.getValueAttr().getType());
344 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, intType,
345 constOp.getValueAttr());
355 struct HWDynamicArrayCreateOpConversion
356 :
public ConvertOpToLLVMPattern<hw::ArrayCreateOp> {
357 using ConvertOpToLLVMPattern<hw::ArrayCreateOp>::ConvertOpToLLVMPattern;
361 ConversionPatternRewriter &rewriter)
const override {
362 auto arrayTy = typeConverter->convertType(op->getResult(0).getType());
365 Value arr = rewriter.create<LLVM::UndefOp>(op->getLoc(), arrayTy);
366 for (
size_t i = 0, e = op.getInputs().size(); i < e; ++i) {
369 .getInputs()[HWToLLVMEndianessConverter::convertToLLVMEndianess(
370 op.getResult().getType(), i)];
371 arr = rewriter.create<LLVM::InsertValueOp>(op->getLoc(), arr, input, i);
374 rewriter.replaceOp(op, arr);
384 class AggregateConstantOpConversion
385 :
public ConvertOpToLLVMPattern<hw::AggregateConstantOp> {
386 using ConvertOpToLLVMPattern<hw::AggregateConstantOp>::ConvertOpToLLVMPattern;
388 bool containsArrayAndStructAggregatesOnly(Type type)
const;
390 bool isMultiDimArrayOfIntegers(Type type,
391 SmallVectorImpl<int64_t> &dims)
const;
393 void flatten(Type type, Attribute attr,
394 SmallVectorImpl<Attribute> &output)
const;
396 Value constructAggregate(OpBuilder &builder,
397 const TypeConverter &typeConverter, Location loc,
398 Type type, Attribute data)
const;
401 explicit AggregateConstantOpConversion(
402 LLVMTypeConverter &typeConverter,
403 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
404 &constAggregateGlobalsMap,
406 : ConvertOpToLLVMPattern(typeConverter),
407 constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals) {}
410 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
411 ConversionPatternRewriter &rewriter)
const override;
414 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
415 &constAggregateGlobalsMap;
423 struct HWStructCreateOpConversion
424 :
public ConvertOpToLLVMPattern<hw::StructCreateOp> {
425 using ConvertOpToLLVMPattern<hw::StructCreateOp>::ConvertOpToLLVMPattern;
429 ConversionPatternRewriter &rewriter)
const override {
431 auto resTy = typeConverter->convertType(op.getResult().getType());
433 Value tup = rewriter.create<LLVM::UndefOp>(op->getLoc(), resTy);
434 for (
size_t i = 0, e = cast<LLVM::LLVMStructType>(resTy).getBody().size();
437 adaptor.getInput()[HWToLLVMEndianessConverter::convertToLLVMEndianess(
438 op.getResult().getType(), i)];
439 tup = rewriter.create<LLVM::InsertValueOp>(op->getLoc(), tup, input, i);
442 rewriter.replaceOp(op, tup);
452 bool AggregateConstantOpConversion::containsArrayAndStructAggregatesOnly(
454 if (
auto intType = dyn_cast<IntegerType>(type))
457 if (
auto arrTy = dyn_cast<hw::ArrayType>(type))
458 return containsArrayAndStructAggregatesOnly(arrTy.getElementType());
460 if (
auto structTy = dyn_cast<hw::StructType>(type)) {
461 SmallVector<Type> innerTypes;
462 structTy.getInnerTypes(innerTypes);
463 return llvm::all_of(innerTypes, [&](
auto ty) {
464 return containsArrayAndStructAggregatesOnly(ty);
471 bool AggregateConstantOpConversion::isMultiDimArrayOfIntegers(
472 Type type, SmallVectorImpl<int64_t> &dims)
const {
473 if (
auto intType = dyn_cast<IntegerType>(type))
476 if (
auto arrTy = dyn_cast<hw::ArrayType>(type)) {
477 dims.push_back(arrTy.getNumElements());
478 return isMultiDimArrayOfIntegers(arrTy.getElementType(), dims);
484 void AggregateConstantOpConversion::flatten(
485 Type type, Attribute attr, SmallVectorImpl<Attribute> &output)
const {
486 if (isa<IntegerType>(type)) {
487 assert(isa<IntegerAttr>(attr));
488 output.push_back(attr);
492 auto arrAttr = cast<ArrayAttr>(attr);
493 for (
size_t i = 0, e = arrAttr.size(); i < e; ++i) {
495 arrAttr[HWToLLVMEndianessConverter::convertToLLVMEndianess(type, i)];
497 flatten(cast<hw::ArrayType>(type).getElementType(), element, output);
501 Value AggregateConstantOpConversion::constructAggregate(
502 OpBuilder &builder,
const TypeConverter &typeConverter, Location loc,
503 Type type, Attribute data)
const {
504 Type llvmType = typeConverter.convertType(type);
506 auto getElementType = [](Type type,
size_t index) {
507 if (
auto arrTy = dyn_cast<hw::ArrayType>(type)) {
508 return arrTy.getElementType();
511 assert(isa<hw::StructType>(type));
512 auto structTy = cast<hw::StructType>(type);
513 SmallVector<Type> innerTypes;
514 structTy.getInnerTypes(innerTypes);
515 return innerTypes[index];
518 return TypeSwitch<Type, Value>(type)
519 .Case<IntegerType>([&](
auto ty) {
520 return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(data));
522 .Case<hw::ArrayType, hw::StructType>([&](
auto ty) {
523 Value aggVal = builder.create<LLVM::UndefOp>(loc, llvmType);
524 auto arrayAttr = cast<ArrayAttr>(data);
525 for (
size_t i = 0, e = arrayAttr.size(); i < e; ++i) {
527 HWToLLVMEndianessConverter::convertToLLVMEndianess(type, i);
528 Attribute input = arrayAttr[currIdx];
531 Value element = constructAggregate(builder, typeConverter, loc,
533 aggVal = builder.create<LLVM::InsertValueOp>(loc, aggVal, element, i);
540 LogicalResult AggregateConstantOpConversion::matchAndRewrite(
541 hw::AggregateConstantOp op, OpAdaptor adaptor,
542 ConversionPatternRewriter &rewriter)
const {
543 Type aggregateType = op.getResult().getType();
546 if (!containsArrayAndStructAggregatesOnly(aggregateType))
549 auto llvmTy = typeConverter->convertType(op.getResult().getType());
550 auto typeAttrPair = std::make_pair(aggregateType, adaptor.getFields());
552 if (!constAggregateGlobalsMap.count(typeAttrPair) ||
553 !constAggregateGlobalsMap[typeAttrPair]) {
554 auto ipSave = rewriter.saveInsertionPoint();
556 Operation *parent = op->getParentOp();
557 while (!isa<mlir::ModuleOp>(parent->getParentOp())) {
558 parent = parent->getParentOp();
561 rewriter.setInsertionPoint(parent);
564 auto name = globals.newName(
"_aggregate_const_global");
566 SmallVector<int64_t> dims;
567 if (isMultiDimArrayOfIntegers(aggregateType, dims)) {
568 SmallVector<Attribute> ints;
569 flatten(aggregateType, adaptor.getFields(), ints);
572 dims, cast<IntegerAttr>(ints.front()).getType());
575 constAggregateGlobalsMap[typeAttrPair] = rewriter.create<LLVM::GlobalOp>(
576 op.getLoc(), llvmTy,
true, LLVM::Linkage::Internal, name, denseAttr);
578 auto global = rewriter.create<LLVM::GlobalOp>(op.getLoc(), llvmTy,
false,
579 LLVM::Linkage::Internal,
582 global.getInitializerRegion().push_back(blk);
583 rewriter.setInsertionPointToStart(blk);
586 constructAggregate(rewriter, *typeConverter, op.getLoc(),
587 aggregateType, adaptor.getFields());
588 rewriter.create<LLVM::ReturnOp>(op.getLoc(), aggregate);
589 constAggregateGlobalsMap[typeAttrPair] = global;
592 rewriter.restoreInsertionPoint(ipSave);
596 auto addr = rewriter.create<LLVM::AddressOfOp>(
597 op->getLoc(), constAggregateGlobalsMap[typeAttrPair]);
598 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmTy,
addr);
608 auto elementTy = converter.convertType(type.getElementType());
613 LLVMTypeConverter &converter) {
614 llvm::SmallVector<Type, 8> elements;
615 mlir::SmallVector<mlir::Type> types;
616 type.getInnerTypes(types);
618 for (
int i = 0, e = types.size(); i < e; ++i)
619 elements.push_back(converter.convertType(
620 types[HWToLLVMEndianessConverter::convertToLLVMEndianess(type, i)]));
622 return LLVM::LLVMStructType::getLiteral(&converter.getContext(), elements);
630 struct HWToLLVMLoweringPass
631 :
public circt::impl::ConvertHWToLLVMBase<HWToLLVMLoweringPass> {
632 void runOnOperation()
override;
637 LLVMTypeConverter &converter, RewritePatternSet &
patterns,
639 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
640 &constAggregateGlobalsMap) {
641 MLIRContext *ctx = converter.getDialect()->getContext();
644 patterns.add<HWConstantOpConversion>(ctx, converter);
645 patterns.add<HWDynamicArrayCreateOpConversion, HWStructCreateOpConversion>(
647 patterns.add<AggregateConstantOpConversion>(
648 converter, constAggregateGlobalsMap, globals);
651 patterns.add<BitcastOpConversion>(converter);
654 patterns.add<ArrayGetOpConversion, ArraySliceOpConversion,
655 ArrayConcatOpConversion, StructExplodeOpConversion,
656 StructExtractOpConversion, StructInjectOpConversion>(converter);
660 converter.addConversion(
662 converter.addConversion(
666 void HWToLLVMLoweringPass::runOnOperation() {
667 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
673 RewritePatternSet
patterns(&getContext());
674 auto converter = mlir::LLVMTypeConverter(&getContext());
677 LLVMConversionTarget target(getContext());
678 target.addIllegalDialect<hw::HWDialect>();
682 constAggregateGlobalsMap);
686 applyPartialConversion(getOperation(), target, std::move(
patterns))))
692 return std::make_unique<HWToLLVMLoweringPass>();
assert(baseType &&"element must be base type")
static Type convertStructType(hw::StructType type, LLVMTypeConverter &converter)
static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter, Value value)
Create a zext operation by one bit on the given value.
static Type convertArrayType(hw::ArrayType type, LLVMTypeConverter &converter)
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap)
Get the HW to LLVM conversion patterns.
std::unique_ptr< OperationPass< ModuleOp > > createConvertHWToLLVMPass()
Create an HW to LLVM conversion pass.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
static uint32_t convertToLLVMEndianess(Type type, uint32_t index)
Convert an index into a HW ArrayType or StructType to LLVM Endianess.
static uint32_t llvmIndexOfStructField(hw::StructType type, StringRef fieldName)
Get the index of a specific StructType field in the LLVM lowering of the StructType.