14 #include "../PassDetail.h"
18 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
19 #include "mlir/Conversion/LLVMCommon/Pattern.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/TypeSwitch.h"
26 using namespace circt;
// Body fragment of HWToLLVMEndianessConverter::convertToLLVMEndianess(type,
// index). The element index is mirrored: for an hw::ArrayType with N elements,
// or an hw::StructType with N fields, HW index `index` maps to LLVM index
// N - index - 1.
36 return TypeSwitch<Type, uint32_t>(type)
38 [&](hw::ArrayType ty) {
return ty.getNumElements() - index - 1; })
39 .Case<hw::StructType>([&](hw::StructType ty) {
40 return ty.getElements().size() - index - 1;
// Fragment of HWToLLVMEndianessConverter::llvmIndexOfStructField(type,
// fieldName). It walks the struct's fields looking for the one named
// `fieldName` and returns that field's position after conversion with
// convertToLLVMEndianess. The `index` counter is declared and advanced on lines
// not shown in this listing. If no field matches, the name is treated as
// invalid (llvm_unreachable).
46 StringRef fieldName) {
47 auto fieldIter = type.getElements();
50 for (
const auto *iter = fieldIter.begin(); iter != fieldIter.end(); ++iter) {
51 if (iter->name == fieldName) {
52 return HWToLLVMEndianessConverter::convertToLLVMEndianess(type, index);
59 llvm_unreachable(
"Field name attribute of hw::StructExtractOp invalid");
// Zero-extends `value` to an integer type one bit wider than its own width
// (read via getIntOrFloatBitWidth) using LLVM::ZExtOp. The array lowerings use
// it to widen the index before building a GEP. Presumably this keeps an
// unsigned index with its top bit set from being read as negative, because GEP
// indices are signed. TODO confirm.
70 static Value
zextByOne(Location loc, ConversionPatternRewriter &rewriter,
72 auto valueTy = value.getType();
74 valueTy.getIntOrFloatBitWidth() + 1);
75 return rewriter.create<LLVM::ZExtOp>(loc, zextTy, value);
// Lowers hw.struct_explode. For every field of the converted LLVM struct it
// emits an LLVM::ExtractValueOp on the adaptor's input, then replaces all of
// the op's results with those values. Result i is read from LLVM position
// convertToLLVMEndianess(inputType, i) because the struct lowering stores the
// fields in mirrored order.
86 struct StructExplodeOpConversion
87 :
public ConvertOpToLLVMPattern<hw::StructExplodeOp> {
88 using ConvertOpToLLVMPattern<hw::StructExplodeOp>::ConvertOpToLLVMPattern;
91 matchAndRewrite(hw::StructExplodeOp op, OpAdaptor adaptor,
92 ConversionPatternRewriter &rewriter)
const override {
94 SmallVector<Value> replacements;
// One extracted value per field. The loop bound comes from the converted
// LLVMStructType; the rest of the bound expression is on lines not shown.
96 for (
size_t i = 0, e = adaptor.getInput()
98 .cast<LLVM::LLVMStructType>()
103 replacements.push_back(rewriter.create<LLVM::ExtractValueOp>(
104 op->getLoc(), adaptor.getInput(),
105 HWToLLVMEndianessConverter::convertToLLVMEndianess(
106 op.getInput().getType(), i)));
108 rewriter.replaceOp(op, replacements);
// Lowers hw.struct_extract to a single LLVM::ExtractValueOp on the converted
// struct. The HW field index is first mapped to the LLVM field position with
// convertToLLVMEndianess, because the LLVM struct stores fields in mirrored
// order.
118 struct StructExtractOpConversion
119 :
public ConvertOpToLLVMPattern<hw::StructExtractOp> {
120 using ConvertOpToLLVMPattern<hw::StructExtractOp>::ConvertOpToLLVMPattern;
124 ConversionPatternRewriter &rewriter)
const override {
126 uint32_t fieldIndex = HWToLLVMEndianessConverter::convertToLLVMEndianess(
127 op.getInput().getType(), op.getFieldIndex());
128 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, adaptor.getInput(),
// Lowers hw.array_get by addressing the array in memory. This is presumably
// because the index is a runtime value, while ExtractValueOp takes a constant
// position.
//  - If the converted input array is produced directly by an LLVM::LoadOp,
//    that load's address is reused.
//  - Otherwise the array is stored to a fresh one-element alloca (count is an
//    i32 constant 1).
// The index is widened by zextByOne and used in a GEP {0, index}. The element
// is then loaded with the converted result type.
// NOTE(review): reusing the existing load's address re-reads memory at a later
// point. This assumes nothing writes that address between the original load
// and the new one. Confirm this holds for all producers.
139 struct ArrayGetOpConversion :
public ConvertOpToLLVMPattern<hw::ArrayGetOp> {
140 using ConvertOpToLLVMPattern<hw::ArrayGetOp>::ConvertOpToLLVMPattern;
144 ConversionPatternRewriter &rewriter)
const override {
147 if (
auto load = adaptor.getInput().getDefiningOp<LLVM::LoadOp>()) {
150 arrPtr = load.getAddr();
152 auto oneC = rewriter.create<LLVM::ConstantOp>(
154 rewriter.getI32IntegerAttr(1));
155 arrPtr = rewriter.create<LLVM::AllocaOp>(
157 adaptor.getInput().getType(), oneC,
159 rewriter.create<LLVM::StoreOp>(op->getLoc(), adaptor.getInput(), arrPtr);
162 auto arrTy = typeConverter->convertType(op.getInput().getType());
163 auto elemTy = typeConverter->convertType(op.getResult().getType());
164 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getIndex());
169 auto gep = rewriter.create<LLVM::GEPOp>(
171 arrPtr, ArrayRef<LLVM::GEPArg>{0, zextIndex});
172 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elemTy, gep);
// Lowers hw.array_slice through memory, in three steps:
//  1. The converted input array is stored to a fresh one-element alloca (count
//     is an i32 constant 1).
//  2. A GEP {0, zextByOne(lowIndex)} addresses the first sliced element.
//  3. The result is loaded from that address with the converted destination
//     array type.
183 struct ArraySliceOpConversion
184 :
public ConvertOpToLLVMPattern<hw::ArraySliceOp> {
185 using ConvertOpToLLVMPattern<hw::ArraySliceOp>::ConvertOpToLLVMPattern;
189 ConversionPatternRewriter &rewriter)
const override {
191 auto dstTy = typeConverter->convertType(op.getDst().getType());
193 auto oneC = rewriter.create<LLVM::ConstantOp>(
194 op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
196 auto arrPtr = rewriter.create<LLVM::AllocaOp>(
198 adaptor.getInput().getType(), oneC,
201 rewriter.create<LLVM::StoreOp>(op->getLoc(), adaptor.getInput(), arrPtr);
203 auto zextIndex =
zextByOne(op->getLoc(), rewriter, op.getLowIndex());
208 auto gep = rewriter.create<LLVM::GEPOp>(
210 arrPtr, ArrayRef<LLVM::GEPArg>{0, zextIndex});
212 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dstTy, gep);
// Lowers hw.struct_inject to a single LLVM::InsertValueOp. The op writes the
// new field value into the converted struct at the LLVM position obtained from
// HWToLLVMEndianessConverter::convertToLLVMEndianess (fields are mirrored in
// the LLVM struct).
227 struct StructInjectOpConversion
228 :
public ConvertOpToLLVMPattern<hw::StructInjectOp> {
229 using ConvertOpToLLVMPattern<hw::StructInjectOp>::ConvertOpToLLVMPattern;
232 matchAndRewrite(hw::StructInjectOp op, OpAdaptor adaptor,
233 ConversionPatternRewriter &rewriter)
const override {
235 uint32_t fieldIndex = HWToLLVMEndianessConverter::convertToLLVMEndianess(
236 op.getInput().getType(), op.getFieldIndex());
// Take BOTH operands from the adaptor so they carry their converted (LLVM)
// types. The original op operand keeps its HW type. When the injected field is
// itself an hw array/struct, that type does not match the LLVM struct's
// element type, and the InsertValueOp would be ill-typed.
238 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
239 op, adaptor.getInput(), adaptor.getNewValue(), fieldIndex);
// Lowers hw.array_concat element by element. It starts from an LLVM::UndefOp
// of the converted result type. For each result position i, it extracts
// element k of converted input j and inserts it at position i. Inputs are
// consumed starting from the last operand (j = size - 1) with k = 0. The
// advance to the next input happens when k reaches input j's element count
// (the comparison on original line 275); that step is only partly visible in
// this listing.
252 struct ArrayConcatOpConversion
253 :
public ConvertOpToLLVMPattern<hw::ArrayConcatOp> {
254 using ConvertOpToLLVMPattern<hw::ArrayConcatOp>::ConvertOpToLLVMPattern;
258 ConversionPatternRewriter &rewriter)
const override {
260 hw::ArrayType arrTy = op.getResult().getType().cast<hw::ArrayType>();
261 Type resultTy = typeConverter->convertType(arrTy);
263 Value arr = rewriter.
create<LLVM::UndefOp>(op->getLoc(), resultTy);
266 size_t j = op.getInputs().size() - 1, k = 0;
268 for (
size_t i = 0, e = arrTy.getNumElements(); i < e; ++i) {
269 Value element = rewriter.
create<LLVM::ExtractValueOp>(
270 op->getLoc(), adaptor.getInputs()[j], k);
271 arr = rewriter.create<LLVM::InsertValueOp>(op->getLoc(), arr, element, i);
275 op.getInputs()[j].getType().cast<hw::ArrayType>().getNumElements()) {
281 rewriter.replaceOp(op, arr);
// Lowers hw.bitcast through memory. The converted input is stored into a
// one-element alloca (count is an i32 constant 1, built with createOrFold),
// then loaded back with the converted result type. This reinterprets the bits
// without any value conversion.
296 struct BitcastOpConversion :
public ConvertOpToLLVMPattern<hw::BitcastOp> {
297 using ConvertOpToLLVMPattern<hw::BitcastOp>::ConvertOpToLLVMPattern;
301 ConversionPatternRewriter &rewriter)
const override {
303 Type resultTy = typeConverter->convertType(op.getResult().getType());
305 auto oneC = rewriter.createOrFold<LLVM::ConstantOp>(
306 op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
308 auto ptr = rewriter.create<LLVM::AllocaOp>(
310 adaptor.getInput().getType(), oneC,
313 rewriter.create<LLVM::StoreOp>(op->getLoc(), adaptor.getInput(), ptr);
315 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, resultTy, ptr);
// Lowers hw.constant to an LLVM::ConstantOp. The new op carries the same value
// attribute, typed with the converted type of that attribute. This is a
// generic ConvertToLLVMPattern keyed on hw::ConstantOp's operation name, so it
// receives an Operation* and casts it back to hw::ConstantOp. The `operand`
// list is unused.
327 struct HWConstantOpConversion :
public ConvertToLLVMPattern {
328 explicit HWConstantOpConversion(MLIRContext *ctx,
329 LLVMTypeConverter &typeConverter)
330 : ConvertToLLVMPattern(
hw::ConstantOp::getOperationName(), ctx,
334 matchAndRewrite(Operation *op, ArrayRef<Value> operand,
335 ConversionPatternRewriter &rewriter)
const override {
337 auto constOp = cast<hw::ConstantOp>(op);
339 auto intType = typeConverter->convertType(constOp.getValueAttr().getType());
341 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, intType,
342 constOp.getValueAttr());
// Lowers hw.array_create. It starts from an LLVM::UndefOp of the converted
// result type, then inserts one input per LLVM position i. The input used is
// the one at HW position convertToLLVMEndianess(resultType, i), so elements
// land in mirrored order. The receiver of `.getInputs()` is on a line not
// shown.
352 struct HWDynamicArrayCreateOpConversion
353 :
public ConvertOpToLLVMPattern<hw::ArrayCreateOp> {
354 using ConvertOpToLLVMPattern<hw::ArrayCreateOp>::ConvertOpToLLVMPattern;
358 ConversionPatternRewriter &rewriter)
const override {
359 auto arrayTy = typeConverter->convertType(op->getResult(0).getType());
362 Value arr = rewriter.create<LLVM::UndefOp>(op->getLoc(), arrayTy);
363 for (
size_t i = 0, e = op.getInputs().size(); i < e; ++i) {
366 .getInputs()[HWToLLVMEndianessConverter::convertToLLVMEndianess(
367 op.getResult().getType(), i)];
368 arr = rewriter.create<LLVM::InsertValueOp>(op->getLoc(), arr, input, i);
371 rewriter.replaceOp(op, arr);
// Lowers hw.aggregate_constant to a load from an internal LLVM::GlobalOp. The
// global is cached per (type, fields) pair in the caller-owned
// constAggregateGlobalsMap. New globals are named through the `globals`
// Namespace. Helpers (defined out of line):
//  - containsArrayAndStructAggregatesOnly: the type is built from integers,
//    hw arrays and hw structs only.
//  - isMultiDimArrayOfIntegers: the type is nested hw arrays over an integer;
//    the array sizes are collected into `dims`.
//  - flatten: collects the integer leaves of an array constant into a flat
//    list.
//  - constructAggregate: builds the value with ConstantOp / UndefOp /
//    InsertValueOp.
381 class AggregateConstantOpConversion
382 :
public ConvertOpToLLVMPattern<hw::AggregateConstantOp> {
383 using ConvertOpToLLVMPattern<hw::AggregateConstantOp>::ConvertOpToLLVMPattern;
385 bool containsArrayAndStructAggregatesOnly(Type type)
const;
387 bool isMultiDimArrayOfIntegers(Type type,
388 SmallVectorImpl<int64_t> &dims)
const;
390 void flatten(Type type, Attribute attr,
391 SmallVectorImpl<Attribute> &output)
const;
393 Value constructAggregate(OpBuilder &
builder,
394 const TypeConverter &typeConverter, Location loc,
395 Type type, Attribute data)
const;
398 explicit AggregateConstantOpConversion(
399 LLVMTypeConverter &typeConverter,
400 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
401 &constAggregateGlobalsMap,
403 : ConvertOpToLLVMPattern(typeConverter),
404 constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals) {}
407 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
408 ConversionPatternRewriter &rewriter)
const override;
// Held by reference: the map is owned by the caller and shared across
// pattern applications.
411 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
412 &constAggregateGlobalsMap;
// Lowers hw.struct_create. It starts from an LLVM::UndefOp of the converted
// struct type. For each LLVM field position i, it inserts the adaptor operand
// found at HW position convertToLLVMEndianess(resultType, i), which mirrors the
// field order.
420 struct HWStructCreateOpConversion
421 :
public ConvertOpToLLVMPattern<hw::StructCreateOp> {
422 using ConvertOpToLLVMPattern<hw::StructCreateOp>::ConvertOpToLLVMPattern;
426 ConversionPatternRewriter &rewriter)
const override {
428 auto resTy = typeConverter->convertType(op.getResult().getType());
430 Value tup = rewriter.create<LLVM::UndefOp>(op->getLoc(), resTy);
431 for (
size_t i = 0, e = resTy.cast<LLVM::LLVMStructType>().getBody().size();
434 adaptor.getInput()[HWToLLVMEndianessConverter::convertToLLVMEndianess(
435 op.getResult().getType(), i)];
436 tup = rewriter.create<LLVM::InsertValueOp>(op->getLoc(), tup, input, i);
439 rewriter.replaceOp(op, tup);
// Checks that `type` is built only from integers, hw arrays and hw structs.
// Arrays recurse on their element type; structs require every field type to
// qualify. The integer case's return value and the final fallback are on lines
// not shown (presumably true and false respectively; confirm).
449 bool AggregateConstantOpConversion::containsArrayAndStructAggregatesOnly(
451 if (
auto intType = type.dyn_cast<IntegerType>())
454 if (
auto arrTy = type.dyn_cast<hw::ArrayType>())
455 return containsArrayAndStructAggregatesOnly(arrTy.getElementType());
457 if (
auto structTy = type.dyn_cast<hw::StructType>()) {
458 SmallVector<Type> innerTypes;
459 structTy.getInnerTypes(innerTypes);
460 return llvm::all_of(innerTypes, [&](
auto ty) {
461 return containsArrayAndStructAggregatesOnly(ty);
// Checks whether `type` is hw arrays nested over an integer element. Each
// array's element count is appended to `dims` on the way down (outermost
// first), then the check recurses into the element type. The integer base
// case's return line and the non-array fallback are not shown in this listing.
468 bool AggregateConstantOpConversion::isMultiDimArrayOfIntegers(
469 Type type, SmallVectorImpl<int64_t> &dims)
const {
470 if (
auto intType = type.dyn_cast<IntegerType>())
473 if (
auto arrTy = type.dyn_cast<hw::ArrayType>()) {
474 dims.push_back(arrTy.getNumElements());
475 return isMultiDimArrayOfIntegers(arrTy.getElementType(), dims);
// Appends the integer leaves of constant `attr` (whose HW type is `type`) to
// `output`. An integer type appends the IntegerAttr itself, which is asserted.
// Otherwise `attr` is treated as an ArrayAttr of an hw array. For
// i = 0..size-1 it visits element convertToLLVMEndianess(type, i), so the
// output follows the mirrored LLVM order, and flattens it recursively with the
// array's element type.
481 void AggregateConstantOpConversion::flatten(
482 Type type, Attribute attr, SmallVectorImpl<Attribute> &output)
const {
483 if (type.isa<IntegerType>()) {
484 assert(attr.isa<IntegerAttr>());
485 output.push_back(attr);
489 auto arrAttr = attr.cast<ArrayAttr>();
490 for (
size_t i = 0, e = arrAttr.size(); i < e; ++i) {
492 arrAttr[HWToLLVMEndianessConverter::convertToLLVMEndianess(type, i)];
494 flatten(type.cast<hw::ArrayType>().getElementType(), element, output);
// Builds an SSA value for constant `data` whose HW type is `type`:
//  - Integers become an LLVM::ConstantOp from the TypedAttr.
//  - Arrays and structs start from an LLVM::UndefOp of the converted type.
//    For each LLVM position i, they insert the recursively constructed value
//    for the attribute at HW position currIdx = convertToLLVMEndianess(type, i).
// getElementType(type, index) returns the array element type, or the struct
// field type at `index`. The argument passed to the recursive call is on a line
// not shown.
498 Value AggregateConstantOpConversion::constructAggregate(
499 OpBuilder &
builder,
const TypeConverter &typeConverter, Location loc,
500 Type type, Attribute data)
const {
501 Type llvmType = typeConverter.convertType(type);
503 auto getElementType = [](Type type,
size_t index) {
504 if (
auto arrTy = type.dyn_cast<hw::ArrayType>()) {
505 return arrTy.getElementType();
508 assert(type.isa<hw::StructType>());
509 auto structTy = type.cast<hw::StructType>();
510 SmallVector<Type> innerTypes;
511 structTy.getInnerTypes(innerTypes);
512 return innerTypes[index];
515 return TypeSwitch<Type, Value>(type)
516 .Case<IntegerType>([&](
auto ty) {
517 return builder.create<LLVM::ConstantOp>(loc,
data.cast<TypedAttr>());
519 .Case<hw::ArrayType, hw::StructType>([&](
auto ty) {
520 Value aggVal =
builder.create<LLVM::UndefOp>(loc, llvmType);
521 auto arrayAttr =
data.cast<ArrayAttr>();
522 for (
size_t i = 0, e = arrayAttr.size(); i < e; ++i) {
524 HWToLLVMEndianessConverter::convertToLLVMEndianess(type, i);
525 Attribute input = arrayAttr[currIdx];
528 Value element = constructAggregate(
builder, typeConverter, loc,
530 aggVal =
builder.create<LLVM::InsertValueOp>(loc, aggVal, element, i);
// Rewrites hw.aggregate_constant into llvm.mlir.addressof plus a load of an
// internal global. Types that are not purely integers/arrays/structs are
// rejected (the early-exit line is not shown).
// The global is created once per (type, fields) pair and cached in
// constAggregateGlobalsMap. It is inserted before the ancestor op that sits
// directly under the builtin module and named via
// globals.newName("_aggregate_const_global"). There are two forms:
//  - Multi-dimensional integer arrays become a global created with the `true`
//    (constant) flag. Its initializer `denseAttr` is built from the flattened
//    integers and `dims`; part of that construction is not shown.
//  - Everything else becomes a `false`-flag global whose initializer region
//    builds the value with constructAggregate and returns it.
537 LogicalResult AggregateConstantOpConversion::matchAndRewrite(
538 hw::AggregateConstantOp op, OpAdaptor adaptor,
539 ConversionPatternRewriter &rewriter)
const {
540 Type aggregateType = op.getResult().getType();
543 if (!containsArrayAndStructAggregatesOnly(aggregateType))
546 auto llvmTy = typeConverter->convertType(op.getResult().getType());
547 auto typeAttrPair = std::make_pair(aggregateType, adaptor.getFields());
549 if (!constAggregateGlobalsMap.count(typeAttrPair) ||
550 !constAggregateGlobalsMap[typeAttrPair]) {
551 auto ipSave = rewriter.saveInsertionPoint();
553 Operation *parent = op->getParentOp();
// Climb to the ancestor whose parent is the ModuleOp so the global is created
// at module scope, just before that ancestor.
// NOTE(review): if op->getParentOp() is already the top-level ModuleOp,
// parent->getParentOp() is null and isa<> on it asserts. Confirm such ops
// cannot reach this pattern.
554 while (!isa<mlir::ModuleOp>(parent->getParentOp())) {
555 parent = parent->getParentOp();
558 rewriter.setInsertionPoint(parent);
561 auto name = globals.newName(
"_aggregate_const_global");
563 SmallVector<int64_t> dims;
564 if (isMultiDimArrayOfIntegers(aggregateType, dims)) {
565 SmallVector<Attribute> ints;
566 flatten(aggregateType, adaptor.getFields(), ints);
// NOTE(review): ints.front() is undefined behaviour if `ints` is empty, for
// example with a zero-length array if the HW type system allows one. Confirm.
569 dims, ints.front().cast<IntegerAttr>().getType());
572 constAggregateGlobalsMap[typeAttrPair] = rewriter.create<LLVM::GlobalOp>(
573 op.getLoc(), llvmTy,
true, LLVM::Linkage::Internal, name, denseAttr);
575 auto global = rewriter.create<LLVM::GlobalOp>(op.getLoc(), llvmTy,
false,
576 LLVM::Linkage::Internal,
579 global.getInitializerRegion().push_back(blk);
580 rewriter.setInsertionPointToStart(blk);
583 constructAggregate(rewriter, *typeConverter, op.getLoc(),
584 aggregateType, adaptor.getFields());
585 rewriter.create<LLVM::ReturnOp>(op.getLoc(), aggregate);
586 constAggregateGlobalsMap[typeAttrPair] = global;
589 rewriter.restoreInsertionPoint(ipSave);
593 auto addr = rewriter.create<LLVM::AddressOfOp>(
594 op->getLoc(), constAggregateGlobalsMap[typeAttrPair]);
595 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmTy,
addr);
// Fragment of convertArrayType: converts the HW array's element type with the
// LLVM type converter.
605 auto elementTy = converter.convertType(type.getElementType());
// Fragment of convertStructType: converts an hw::StructType into a literal LLVM
// struct. LLVM field i is the converted HW field at
// convertToLLVMEndianess(type, i), so the field order is mirrored.
610 LLVMTypeConverter &converter) {
611 llvm::SmallVector<Type, 8> elements;
612 mlir::SmallVector<mlir::Type> types;
613 type.getInnerTypes(types);
615 for (
int i = 0, e = types.size(); i < e; ++i)
616 elements.push_back(converter.convertType(
617 types[HWToLLVMEndianessConverter::convertToLLVMEndianess(type, i)]));
619 return LLVM::LLVMStructType::getLiteral(&converter.getContext(), elements);
// Pass that lowers HW dialect operations and types to the LLVM dialect. All
// work happens in runOnOperation.
627 struct HWToLLVMLoweringPass :
public ConvertHWToLLVMBase<HWToLLVMLoweringPass> {
628 void runOnOperation()
override;
// Fragment of populateHWToLLVMConversionPatterns: registers the HW-to-LLVM
// rewrite patterns on `patterns`. HWConstantOpConversion is constructed with
// the converter's MLIRContext. AggregateConstantOpConversion additionally
// receives the shared constAggregateGlobalsMap cache and the `globals`
// Namespace used to name the globals it creates.
633 LLVMTypeConverter &converter, RewritePatternSet &
patterns,
635 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
636 &constAggregateGlobalsMap) {
637 MLIRContext *ctx = converter.getDialect()->getContext();
640 patterns.add<HWConstantOpConversion>(ctx, converter);
641 patterns.add<HWDynamicArrayCreateOpConversion, HWStructCreateOpConversion>(
643 patterns.add<AggregateConstantOpConversion>(
644 converter, constAggregateGlobalsMap, globals);
647 patterns.add<BitcastOpConversion>(converter);
650 patterns.add<ArrayGetOpConversion, ArraySliceOpConversion,
651 ArrayConcatOpConversion, StructExplodeOpConversion,
652 StructExtractOpConversion, StructInjectOpConversion>(converter);
// Fragment of populateHWToLLVMTypeConversions: registers two type-conversion
// callbacks on `converter`. The callbacks are on lines not shown; presumably
// they are convertArrayType and convertStructType (confirm).
656 converter.addConversion(
658 converter.addConversion(
// Runs the HW-to-LLVM lowering on the pass's operation. The target marks the
// HW dialect illegal and the LLVM dialect, ModuleOp and
// UnrealizedConversionCastOp legal. A partial conversion is then applied with
// the collected patterns. The aggregate-constant global cache
// (constAggregateGlobalsMap) is local to a single run.
662 void HWToLLVMLoweringPass::runOnOperation() {
663 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
669 RewritePatternSet
patterns(&getContext());
670 auto converter = mlir::LLVMTypeConverter(&getContext());
673 LLVMConversionTarget target(getContext());
674 target.addLegalOp<UnrealizedConversionCastOp>();
675 target.addLegalOp<ModuleOp>();
676 target.addLegalDialect<LLVM::LLVMDialect>();
677 target.addIllegalDialect<hw::HWDialect>();
681 constAggregateGlobalsMap);
685 applyPartialConversion(getOperation(), target, std::move(
patterns))))
// Fragment of createConvertHWToLLVMPass: returns a new HWToLLVMLoweringPass
// instance.
691 return std::make_unique<HWToLLVMLoweringPass>();
assert(baseType &&"element must be base type")
static Type convertStructType(hw::StructType type, LLVMTypeConverter &converter)
static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter, Value value)
Create a zext operation by one bit on the given value.
static Type convertArrayType(hw::ArrayType type, LLVMTypeConverter &converter)
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(SymbolCache &symCache)
SymbolCache initializer; initialize from every key that is convertible to a StringAttr in the SymbolC...
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap)
Get the HW to LLVM conversion patterns.
std::unique_ptr< OperationPass< ModuleOp > > createConvertHWToLLVMPass()
Create an HW to LLVM conversion pass.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
static uint32_t convertToLLVMEndianess(Type type, uint32_t index)
Convert an index into an HW ArrayType or StructType to the corresponding index in LLVM endianness (reversed element order).
static uint32_t llvmIndexOfStructField(hw::StructType type, StringRef fieldName)
Get the index of a specific StructType field in the LLVM lowering of the StructType.