CIRCT 22.0.0git
Loading...
Searching...
No Matches
HWToLLVM.cpp
Go to the documentation of this file.
1//===- HWToLLVM.cpp - HW to LLVM Conversion Pass --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main HW to LLVM Conversion Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
15#include "circt/Support/LLVM.h"
17#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18#include "mlir/Conversion/LLVMCommon/Pattern.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/IR/Iterators.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/ADT/TypeSwitch.h"
24
25namespace circt {
26#define GEN_PASS_DEF_CONVERTHWTOLLVM
27#include "circt/Conversion/Passes.h.inc"
28} // namespace circt
29
30using namespace mlir;
31using namespace circt;
32
33//===----------------------------------------------------------------------===//
34// Endianess Converter
35//===----------------------------------------------------------------------===//
36
37uint32_t
39 uint32_t index) {
40 // This is hardcoded for little endian machines for now.
41 return TypeSwitch<Type, uint32_t>(type)
42 .Case<hw::ArrayType>(
43 [&](hw::ArrayType ty) { return ty.getNumElements() - index - 1; })
44 .Case<hw::StructType>([&](hw::StructType ty) {
45 return ty.getElements().size() - index - 1;
46 });
47}
48
49uint32_t
51 StringRef fieldName) {
52 auto fieldIter = type.getElements();
53 size_t index = 0;
54
55 for (const auto *iter = fieldIter.begin(); iter != fieldIter.end(); ++iter) {
56 if (iter->name == fieldName) {
58 }
59 ++index;
60 }
61
62 // Verifier of StructExtractOp has to ensure that the field name is indeed
63 // present.
64 llvm_unreachable("Field name attribute of hw::StructExtractOp invalid");
65 return 0;
66}
67
68//===----------------------------------------------------------------------===//
69// Helpers
70//===----------------------------------------------------------------------===//
71
72/// Create a zext operation by one bit on the given value. This is useful when
73/// passing unsigned indexes to a GEP instruction, which treats indexes as
74/// signed values, to avoid unexpected "sign overflows".
75static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter,
76 Value value) {
77 auto valueTy = value.getType();
78 auto zextTy = IntegerType::get(valueTy.getContext(),
79 valueTy.getIntOrFloatBitWidth() + 1);
80 return LLVM::ZExtOp::create(rewriter, loc, zextTy, value);
81}
82
83//===----------------------------------------------------------------------===//
84// HWToLLVMArraySpillCache
85//===----------------------------------------------------------------------===//
86
87static Value spillValueOnStack(OpBuilder &builder, Location loc,
88 Value spillVal) {
89 auto oneC = LLVM::ConstantOp::create(
90 builder, loc, IntegerType::get(builder.getContext(), 32),
91 builder.getI32IntegerAttr(1));
92 Value ptr = LLVM::AllocaOp::create(
93 builder, loc, LLVM::LLVMPointerType::get(builder.getContext()),
94 spillVal.getType(), oneC,
95 /*alignment=*/4);
96 LLVM::StoreOp::create(builder, loc, spillVal, ptr);
97 return ptr;
98}
99
101 LLVMTypeConverter &converter,
102 Operation *containerOp) {
103 OpBuilder::InsertionGuard g(builder);
104 containerOp->walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
105 [&](Operation *op) {
106 if (isa_and_nonnull<hw::HWDialect>(op->getDialect()))
107 return;
108 auto hasSpillingUser = [](Value arrVal) -> bool {
109 for (auto user : arrVal.getUsers())
110 if (isa<hw::ArrayGetOp, hw::ArraySliceOp>(user))
111 return true;
112 return false;
113 };
114 // Spill Block arguments
115 for (auto &region : op->getRegions()) {
116 for (auto &block : region.getBlocks()) {
117 builder.setInsertionPointToStart(&block);
118 for (auto &arg : block.getArguments()) {
119 if (isa<hw::ArrayType>(arg.getType()) && hasSpillingUser(arg))
120 spillHWArrayValue(builder, arg.getLoc(), converter, arg);
121 }
122 }
123 }
124 // Spill Op Results
125 for (auto result : op->getResults()) {
126 if (isa<hw::ArrayType>(result.getType()) && hasSpillingUser(result)) {
127 builder.setInsertionPointAfter(op);
128 spillHWArrayValue(builder, op->getLoc(), converter, result);
129 }
130 }
131 });
132}
133
134void HWToLLVMArraySpillCache::map(Value arrayValue, Value bufferPtr) {
135 assert(isa<LLVM::LLVMArrayType>(arrayValue.getType()) &&
136 "Key is not an LLVM array.");
137 assert(isa<LLVM::LLVMPointerType>(bufferPtr.getType()) &&
138 "Value is not a pointer.");
139 spillMap.insert({arrayValue, bufferPtr});
140}
141
142Value HWToLLVMArraySpillCache::lookup(Value arrayValue) {
143 assert(isa<LLVM::LLVMArrayType>(arrayValue.getType()) ||
144 isa<hw::ArrayType>(arrayValue.getType()) && "Not an array value");
145 while (isa<LLVM::LLVMArrayType, hw::ArrayType>(arrayValue.getType())) {
146 if (isa<LLVM::LLVMArrayType>(arrayValue.getType())) {
147 auto mapVal = spillMap.lookup(arrayValue);
148 if (mapVal)
149 return mapVal;
150 }
151 if (auto castOp = arrayValue.getDefiningOp<UnrealizedConversionCastOp>())
152 arrayValue = castOp.getOperand(0);
153 else
154 break;
155 }
156 return {};
157}
158
159// Materialize a LLVM Array value in a stack allocated buffer.
161 Location loc,
162 Value llvmArray) {
163 assert(isa<LLVM::LLVMArrayType>(llvmArray.getType()) &&
164 "Expected an LLVM array.");
165 auto spillBuffer = spillValueOnStack(builder, loc, llvmArray);
166 auto loadOp =
167 LLVM::LoadOp::create(builder, loc, llvmArray.getType(), spillBuffer);
168 map(loadOp.getResult(), spillBuffer);
169 return loadOp.getResult();
170}
171
172// Materialize a HW Array value in a stack allocated buffer. Replaces
173// all current uses of the SSA value with a new SSA representing the same
174// array value.
176 Location loc,
177 LLVMTypeConverter &converter,
178 Value hwArray) {
179 assert(isa<hw::ArrayType>(hwArray.getType()) && "Expected an HW array");
180 auto targetType = converter.convertType(hwArray.getType());
181 auto hwToLLVMCast =
182 UnrealizedConversionCastOp::create(builder, loc, targetType, hwArray);
183 auto spilled = spillLLVMArrayValue(builder, loc, hwToLLVMCast.getResult(0));
184 auto llvmToHWCast = UnrealizedConversionCastOp::create(
185 builder, loc, hwArray.getType(), spilled);
186 hwArray.replaceAllUsesExcept(llvmToHWCast.getResult(0), hwToLLVMCast);
187 return llvmToHWCast.getResult(0);
188}
189
190namespace {
191// Helper for patterns using or creating buffers containing
192// HW array values.
193template <typename SourceOp>
194struct HWArrayOpToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
195
196 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
197 HWArrayOpToLLVMPattern(LLVMTypeConverter &converter,
198 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt)
199 : ConvertOpToLLVMPattern<SourceOp>(converter),
200 spillCacheOpt(spillCacheOpt) {}
201
202protected:
203 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt;
204};
205
206} // namespace
207
208//===----------------------------------------------------------------------===//
209// Extraction operation conversions
210//===----------------------------------------------------------------------===//
211
212namespace {
213/// Convert a StructExplodeOp to the LLVM dialect.
214/// Pattern: struct_explode(input) =>
215/// struct_extract(input, structElements_index(index)) ...
216struct StructExplodeOpConversion
217 : public ConvertOpToLLVMPattern<hw::StructExplodeOp> {
218 using ConvertOpToLLVMPattern<hw::StructExplodeOp>::ConvertOpToLLVMPattern;
219
220 LogicalResult
221 matchAndRewrite(hw::StructExplodeOp op, OpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter) const override {
223
224 SmallVector<Value> replacements;
225
226 for (size_t i = 0,
227 e = cast<LLVM::LLVMStructType>(adaptor.getInput().getType())
228 .getBody()
229 .size();
230 i < e; ++i)
231
232 replacements.push_back(LLVM::ExtractValueOp::create(
233 rewriter, op->getLoc(), adaptor.getInput(),
235 op.getInput().getType(), i)));
236
237 rewriter.replaceOp(op, replacements);
238 return success();
239 }
240};
241} // namespace
242
243namespace {
244/// Convert a StructExtractOp to LLVM dialect.
245/// Pattern: struct_extract(input, fieldname) =>
246/// extractvalue(input, fieldname_to_index(fieldname))
247struct StructExtractOpConversion
248 : public ConvertOpToLLVMPattern<hw::StructExtractOp> {
249 using ConvertOpToLLVMPattern<hw::StructExtractOp>::ConvertOpToLLVMPattern;
250
251 LogicalResult
252 matchAndRewrite(hw::StructExtractOp op, OpAdaptor adaptor,
253 ConversionPatternRewriter &rewriter) const override {
254
256 op.getInput().getType(), op.getFieldIndex());
257 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, adaptor.getInput(),
258 fieldIndex);
259 return success();
260 }
261};
262} // namespace
263
264namespace {
265/// Convert an ArrayInjectOp to the LLVM dialect.
266/// Pattern: array_inject(input, element, index) =>
267/// store(gep(store(input, alloca), zext(index)), element)
268/// load(alloca)
269struct ArrayInjectOpConversion
270 : public HWArrayOpToLLVMPattern<hw::ArrayInjectOp> {
271 using HWArrayOpToLLVMPattern<hw::ArrayInjectOp>::HWArrayOpToLLVMPattern;
272
273 LogicalResult
274 matchAndRewrite(hw::ArrayInjectOp op, OpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter) const override {
276 auto inputType = cast<hw::ArrayType>(op.getInput().getType());
277 auto oldArrTy = adaptor.getInput().getType();
278 auto newArrTy = oldArrTy;
279 const size_t arrElems = inputType.getNumElements();
280
281 if (arrElems == 0) {
282 rewriter.replaceOp(op, adaptor.getInput());
283 return success();
284 }
285
286 auto oneC =
287 LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI32Type(),
288 rewriter.getI32IntegerAttr(1));
289 auto zextIndex = zextByOne(op->getLoc(), rewriter, op.getIndex());
290
291 Value arrPtr;
292 if (arrElems == 1 || !llvm::isPowerOf2_64(arrElems)) {
293 // Clamp index to prevent OOB access. We add an extra element to the
294 // array so that OOB access modifies this element, leaving the original
295 // array intact.
296 auto maxIndex =
297 LLVM::ConstantOp::create(rewriter, op->getLoc(), zextIndex.getType(),
298 rewriter.getI32IntegerAttr(arrElems));
299 zextIndex =
300 LLVM::UMinOp::create(rewriter, op->getLoc(), zextIndex, maxIndex);
301
302 newArrTy = typeConverter->convertType(
303 hw::ArrayType::get(inputType.getElementType(), arrElems + 1));
304 arrPtr = LLVM::AllocaOp::create(
305 rewriter, op->getLoc(),
306 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, oneC,
307 /*alignment=*/4);
308 } else {
309 arrPtr = LLVM::AllocaOp::create(
310 rewriter, op->getLoc(),
311 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, oneC,
312 /*alignment=*/4);
313 }
314
315 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getInput(), arrPtr);
316
317 auto gep = LLVM::GEPOp::create(
318 rewriter, op->getLoc(),
319 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, arrPtr,
320 ArrayRef<LLVM::GEPArg>{0, zextIndex});
321
322 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getElement(), gep);
323 auto loadOp =
324 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, oldArrTy, arrPtr);
325 if (spillCacheOpt)
326 spillCacheOpt->map(loadOp, arrPtr);
327 return success();
328 }
329};
330} // namespace
331
332namespace {
333/// Convert an ArrayGetOp to the LLVM dialect.
334/// Pattern: array_get(input, index) =>
335/// load(gep(store(input, alloca), zext(index)))
336struct ArrayGetOpConversion : public HWArrayOpToLLVMPattern<hw::ArrayGetOp> {
337 using HWArrayOpToLLVMPattern<hw::ArrayGetOp>::HWArrayOpToLLVMPattern;
338
339 LogicalResult
340 matchAndRewrite(hw::ArrayGetOp op, OpAdaptor adaptor,
341 ConversionPatternRewriter &rewriter) const override {
342
343 Value arrPtr;
344 if (spillCacheOpt)
345 arrPtr = spillCacheOpt->lookup(adaptor.getInput());
346 if (!arrPtr)
347 arrPtr = spillValueOnStack(rewriter, op.getLoc(), adaptor.getInput());
348
349 auto arrTy = typeConverter->convertType(op.getInput().getType());
350 auto elemTy = typeConverter->convertType(op.getResult().getType());
351 auto zextIndex = zextByOne(op->getLoc(), rewriter, op.getIndex());
352
353 // During the ongoing migration to opaque types, use the constructor that
354 // accepts an element type when the array pointer type is opaque, and
355 // otherwise use the typed pointer constructor.
356 auto gep = LLVM::GEPOp::create(
357 rewriter, op->getLoc(),
358 LLVM::LLVMPointerType::get(rewriter.getContext()), arrTy, arrPtr,
359 ArrayRef<LLVM::GEPArg>{0, zextIndex});
360 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elemTy, gep);
361
362 return success();
363 }
364};
365} // namespace
366
367namespace {
368/// Convert an ArraySliceOp to the LLVM dialect.
369/// Pattern: array_slice(input, lowIndex) =>
370/// load(bitcast(gep(store(input, alloca), zext(lowIndex))))
371struct ArraySliceOpConversion
372 : public HWArrayOpToLLVMPattern<hw::ArraySliceOp> {
373 using HWArrayOpToLLVMPattern<hw::ArraySliceOp>::HWArrayOpToLLVMPattern;
374
375 LogicalResult
376 matchAndRewrite(hw::ArraySliceOp op, OpAdaptor adaptor,
377 ConversionPatternRewriter &rewriter) const override {
378
379 auto dstTy = typeConverter->convertType(op.getDst().getType());
380
381 Value arrPtr;
382 if (spillCacheOpt)
383 arrPtr = spillCacheOpt->lookup(adaptor.getInput());
384 if (!arrPtr)
385 arrPtr = spillValueOnStack(rewriter, op.getLoc(), adaptor.getInput());
386
387 auto zextIndex = zextByOne(op->getLoc(), rewriter, op.getLowIndex());
388
389 // During the ongoing migration to opaque types, use the constructor that
390 // accepts an element type when the array pointer type is opaque, and
391 // otherwise use the typed pointer constructor.
392 auto gep = LLVM::GEPOp::create(
393 rewriter, op->getLoc(),
394 LLVM::LLVMPointerType::get(rewriter.getContext()), dstTy, arrPtr,
395 ArrayRef<LLVM::GEPArg>{0, zextIndex});
396
397 auto loadOp = rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dstTy, gep);
398
399 if (spillCacheOpt)
400 spillCacheOpt->map(loadOp, gep);
401
402 return success();
403 }
404};
405} // namespace
406
407//===----------------------------------------------------------------------===//
408// Insertion operations conversion
409//===----------------------------------------------------------------------===//
410
411namespace {
412/// Convert a StructInjectOp to LLVM dialect.
413/// Pattern: struct_inject(input, index, value) =>
414/// insertvalue(input, value, index)
415struct StructInjectOpConversion
416 : public ConvertOpToLLVMPattern<hw::StructInjectOp> {
417 using ConvertOpToLLVMPattern<hw::StructInjectOp>::ConvertOpToLLVMPattern;
418
419 LogicalResult
420 matchAndRewrite(hw::StructInjectOp op, OpAdaptor adaptor,
421 ConversionPatternRewriter &rewriter) const override {
422
424 op.getInput().getType(), op.getFieldIndex());
425
426 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
427 op, adaptor.getInput(), adaptor.getNewValue(), fieldIndex);
428
429 return success();
430 }
431};
432} // namespace
433
434//===----------------------------------------------------------------------===//
435// Concat operations conversion
436//===----------------------------------------------------------------------===//
437
438namespace {
439/// Lower an ArrayConcatOp operation to the LLVM dialect.
440struct ArrayConcatOpConversion
441 : public HWArrayOpToLLVMPattern<hw::ArrayConcatOp> {
442 using HWArrayOpToLLVMPattern<hw::ArrayConcatOp>::HWArrayOpToLLVMPattern;
443
444 LogicalResult
445 matchAndRewrite(hw::ArrayConcatOp op, OpAdaptor adaptor,
446 ConversionPatternRewriter &rewriter) const override {
447
448 hw::ArrayType arrTy = cast<hw::ArrayType>(op.getResult().getType());
449 Type resultTy = typeConverter->convertType(arrTy);
450 auto loc = op.getLoc();
451
452 Value arr = LLVM::UndefOp::create(rewriter, loc, resultTy);
453
454 // Attention: j is hardcoded for little endian machines.
455 size_t j = op.getInputs().size() - 1, k = 0;
456
457 for (size_t i = 0, e = arrTy.getNumElements(); i < e; ++i) {
458 Value element = LLVM::ExtractValueOp::create(rewriter, loc,
459 adaptor.getInputs()[j], k);
460 arr = LLVM::InsertValueOp::create(rewriter, loc, arr, element, i);
461
462 ++k;
463 if (k >=
464 cast<hw::ArrayType>(op.getInputs()[j].getType()).getNumElements()) {
465 k = 0;
466 --j;
467 }
468 }
469
470 rewriter.replaceOp(op, arr);
471
472 // If we've got a cache, spill the array right away.
473 if (spillCacheOpt) {
474 rewriter.setInsertionPointAfter(arr.getDefiningOp());
475 auto ptr = spillValueOnStack(rewriter, loc, arr);
476 spillCacheOpt->map(arr, ptr);
477 }
478 return success();
479 }
480};
481} // namespace
482
483//===----------------------------------------------------------------------===//
484// Bitwise conversions
485//===----------------------------------------------------------------------===//
486
487namespace {
488/// Lower an ArrayConcatOp operation to the LLVM dialect.
489/// Pattern: hw.bitcast(input) ==> load(bitcast_ptr(store(input, alloca)))
490/// This is necessary because we cannot bitcast aggregate types directly in
491/// LLVMIR.
492struct BitcastOpConversion : public ConvertOpToLLVMPattern<hw::BitcastOp> {
493 using ConvertOpToLLVMPattern<hw::BitcastOp>::ConvertOpToLLVMPattern;
494
495 LogicalResult
496 matchAndRewrite(hw::BitcastOp op, OpAdaptor adaptor,
497 ConversionPatternRewriter &rewriter) const override {
498
499 Type resultTy = typeConverter->convertType(op.getResult().getType());
500 auto ptr = spillValueOnStack(rewriter, op.getLoc(), adaptor.getInput());
501 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, resultTy, ptr);
502
503 return success();
504 }
505};
506} // namespace
507
508//===----------------------------------------------------------------------===//
509// Value creation conversions
510//===----------------------------------------------------------------------===//
511
512namespace {
513struct HWConstantOpConversion : public ConvertToLLVMPattern {
514 explicit HWConstantOpConversion(MLIRContext *ctx,
515 LLVMTypeConverter &typeConverter)
516 : ConvertToLLVMPattern(hw::ConstantOp::getOperationName(), ctx,
517 typeConverter) {}
518
519 LogicalResult
520 matchAndRewrite(Operation *op, ArrayRef<Value> operand,
521 ConversionPatternRewriter &rewriter) const override {
522 // Get the ConstOp.
523 auto constOp = cast<hw::ConstantOp>(op);
524 // Get the converted llvm type.
525 auto intType = typeConverter->convertType(constOp.getValueAttr().getType());
526 // Replace the operation with an llvm constant op.
527 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, intType,
528 constOp.getValueAttr());
529
530 return success();
531 }
532};
533} // namespace
534
535namespace {
536/// Convert an ArrayCreateOp with dynamic elements to the LLVM dialect. An
537/// equivalent and initialized llvm dialect array type is generated.
538struct HWDynamicArrayCreateOpConversion
539 : public ConvertOpToLLVMPattern<hw::ArrayCreateOp> {
540 using ConvertOpToLLVMPattern<hw::ArrayCreateOp>::ConvertOpToLLVMPattern;
541
542 LogicalResult
543 matchAndRewrite(hw::ArrayCreateOp op, OpAdaptor adaptor,
544 ConversionPatternRewriter &rewriter) const override {
545 auto arrayTy = typeConverter->convertType(op->getResult(0).getType());
546 assert(arrayTy);
547
548 Value arr = LLVM::UndefOp::create(rewriter, op->getLoc(), arrayTy);
549 for (size_t i = 0, e = op.getInputs().size(); i < e; ++i) {
550 Value input =
551 adaptor
553 op.getResult().getType(), i)];
554 arr = LLVM::InsertValueOp::create(rewriter, op->getLoc(), arr, input, i);
555 }
556
557 rewriter.replaceOp(op, arr);
558 return success();
559 }
560};
561} // namespace
562
563namespace {
564
565/// Convert an ArrayCreateOp with constant elements to the LLVM dialect. An
566/// equivalent and initialized llvm dialect array type is generated.
567class AggregateConstantOpConversion
568 : public HWArrayOpToLLVMPattern<hw::AggregateConstantOp> {
569 using HWArrayOpToLLVMPattern<hw::AggregateConstantOp>::HWArrayOpToLLVMPattern;
570
571 bool containsArrayAndStructAggregatesOnly(Type type) const;
572
573 bool isMultiDimArrayOfIntegers(Type type,
574 SmallVectorImpl<int64_t> &dims) const;
575
576 void flatten(Type type, Attribute attr,
577 SmallVectorImpl<Attribute> &output) const;
578
579 Value constructAggregate(OpBuilder &builder,
580 const TypeConverter &typeConverter, Location loc,
581 Type type, Attribute data) const;
582
583public:
584 explicit AggregateConstantOpConversion(
585 LLVMTypeConverter &typeConverter,
586 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
587 &constAggregateGlobalsMap,
588 Namespace &globals, std::optional<HWToLLVMArraySpillCache> &spillCacheOpt)
589 : HWArrayOpToLLVMPattern(typeConverter, spillCacheOpt),
590 constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals) {}
591
592 LogicalResult
593 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
594 ConversionPatternRewriter &rewriter) const override;
595
596private:
597 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
598 &constAggregateGlobalsMap;
599 Namespace &globals;
600};
601} // namespace
602
603namespace {
604/// Convert a StructCreateOp operation to the LLVM dialect. An equivalent and
605/// initialized llvm dialect struct type is generated.
606struct HWStructCreateOpConversion
607 : public ConvertOpToLLVMPattern<hw::StructCreateOp> {
608 using ConvertOpToLLVMPattern<hw::StructCreateOp>::ConvertOpToLLVMPattern;
609
610 LogicalResult
611 matchAndRewrite(hw::StructCreateOp op, OpAdaptor adaptor,
612 ConversionPatternRewriter &rewriter) const override {
613
614 auto resTy = typeConverter->convertType(op.getResult().getType());
615
616 Value tup = LLVM::UndefOp::create(rewriter, op->getLoc(), resTy);
617 for (size_t i = 0, e = cast<LLVM::LLVMStructType>(resTy).getBody().size();
618 i < e; ++i) {
619 Value input =
621 op.getResult().getType(), i)];
622 tup = LLVM::InsertValueOp::create(rewriter, op->getLoc(), tup, input, i);
623 }
624
625 rewriter.replaceOp(op, tup);
626 return success();
627 }
628};
629} // namespace
630
631//===----------------------------------------------------------------------===//
632// Pattern implementations
633//===----------------------------------------------------------------------===//
634
635bool AggregateConstantOpConversion::containsArrayAndStructAggregatesOnly(
636 Type type) const {
637 if (auto intType = dyn_cast<IntegerType>(type))
638 return true;
639
640 if (auto arrTy = dyn_cast<hw::ArrayType>(type))
641 return containsArrayAndStructAggregatesOnly(arrTy.getElementType());
642
643 if (auto structTy = dyn_cast<hw::StructType>(type)) {
644 SmallVector<Type> innerTypes;
645 structTy.getInnerTypes(innerTypes);
646 return llvm::all_of(innerTypes, [&](auto ty) {
647 return containsArrayAndStructAggregatesOnly(ty);
648 });
649 }
650
651 return false;
652}
653
654bool AggregateConstantOpConversion::isMultiDimArrayOfIntegers(
655 Type type, SmallVectorImpl<int64_t> &dims) const {
656 if (auto intType = dyn_cast<IntegerType>(type))
657 return true;
658
659 if (auto arrTy = dyn_cast<hw::ArrayType>(type)) {
660 dims.push_back(arrTy.getNumElements());
661 return isMultiDimArrayOfIntegers(arrTy.getElementType(), dims);
662 }
663
664 return false;
665}
666
667void AggregateConstantOpConversion::flatten(
668 Type type, Attribute attr, SmallVectorImpl<Attribute> &output) const {
669 if (isa<IntegerType>(type)) {
670 assert(isa<IntegerAttr>(attr));
671 output.push_back(attr);
672 return;
673 }
674
675 auto arrAttr = cast<ArrayAttr>(attr);
676 for (size_t i = 0, e = arrAttr.size(); i < e; ++i) {
677 auto element =
679
680 flatten(cast<hw::ArrayType>(type).getElementType(), element, output);
681 }
682}
683
684Value AggregateConstantOpConversion::constructAggregate(
685 OpBuilder &builder, const TypeConverter &typeConverter, Location loc,
686 Type type, Attribute data) const {
687 Type llvmType = typeConverter.convertType(type);
688
689 auto getElementType = [](Type type, size_t index) {
690 if (auto arrTy = dyn_cast<hw::ArrayType>(type)) {
691 return arrTy.getElementType();
692 }
693
694 assert(isa<hw::StructType>(type));
695 auto structTy = cast<hw::StructType>(type);
696 SmallVector<Type> innerTypes;
697 structTy.getInnerTypes(innerTypes);
698 return innerTypes[index];
699 };
700
701 return TypeSwitch<Type, Value>(type)
702 .Case<IntegerType>([&](auto ty) {
703 return LLVM::ConstantOp::create(builder, loc, cast<TypedAttr>(data));
704 })
705 .Case<hw::ArrayType, hw::StructType>([&](auto ty) {
706 Value aggVal = LLVM::UndefOp::create(builder, loc, llvmType);
707 auto arrayAttr = cast<ArrayAttr>(data);
708 for (size_t i = 0, e = arrayAttr.size(); i < e; ++i) {
709 size_t currIdx =
711 Attribute input = arrayAttr[currIdx];
712 Type elementType = getElementType(ty, currIdx);
713
714 Value element = constructAggregate(builder, typeConverter, loc,
715 elementType, input);
716 aggVal =
717 LLVM::InsertValueOp::create(builder, loc, aggVal, element, i);
718 }
719
720 return aggVal;
721 });
722}
723
724LogicalResult AggregateConstantOpConversion::matchAndRewrite(
725 hw::AggregateConstantOp op, OpAdaptor adaptor,
726 ConversionPatternRewriter &rewriter) const {
727 Type aggregateType = op.getResult().getType();
728
729 // TODO: Only arrays and structs supported at the moment.
730 if (!containsArrayAndStructAggregatesOnly(aggregateType))
731 return failure();
732
733 auto llvmTy = typeConverter->convertType(op.getResult().getType());
734 auto typeAttrPair = std::make_pair(aggregateType, adaptor.getFields());
735
736 if (!constAggregateGlobalsMap.count(typeAttrPair) ||
737 !constAggregateGlobalsMap[typeAttrPair]) {
738 auto ipSave = rewriter.saveInsertionPoint();
739
740 Operation *parent = op->getParentOp();
741 while (!isa<mlir::ModuleOp>(parent->getParentOp())) {
742 parent = parent->getParentOp();
743 }
744
745 rewriter.setInsertionPoint(parent);
746
747 // Create a global region for this static array.
748 auto name = globals.newName("_aggregate_const_global");
749
750 SmallVector<int64_t> dims;
751 if (isMultiDimArrayOfIntegers(aggregateType, dims)) {
752 SmallVector<Attribute> ints;
753 flatten(aggregateType, adaptor.getFields(), ints);
754 assert(!ints.empty());
755 auto shapedType = RankedTensorType::get(
756 dims, cast<IntegerAttr>(ints.front()).getType());
757 auto denseAttr = DenseElementsAttr::get(shapedType, ints);
758
759 constAggregateGlobalsMap[typeAttrPair] =
760 LLVM::GlobalOp::create(rewriter, op.getLoc(), llvmTy, true,
761 LLVM::Linkage::Internal, name, denseAttr);
762 } else {
763 auto global =
764 LLVM::GlobalOp::create(rewriter, op.getLoc(), llvmTy, false,
765 LLVM::Linkage::Internal, name, Attribute());
766 Block *blk = new Block();
767 global.getInitializerRegion().push_back(blk);
768 rewriter.setInsertionPointToStart(blk);
769
770 Value aggregate =
771 constructAggregate(rewriter, *typeConverter, op.getLoc(),
772 aggregateType, adaptor.getFields());
773 LLVM::ReturnOp::create(rewriter, op.getLoc(), aggregate);
774 constAggregateGlobalsMap[typeAttrPair] = global;
775 }
776
777 rewriter.restoreInsertionPoint(ipSave);
778 }
779
780 // Get the global array address and load it to return an array value.
781 auto addr = LLVM::AddressOfOp::create(rewriter, op->getLoc(),
782 constAggregateGlobalsMap[typeAttrPair]);
783 auto newOp = rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmTy, addr);
784
785 if (spillCacheOpt && llvm::isa<hw::ArrayType>(aggregateType))
786 spillCacheOpt->map(newOp.getResult(), addr);
787
788 return success();
789}
790
791//===----------------------------------------------------------------------===//
792// Type conversions
793//===----------------------------------------------------------------------===//
794
795static Type convertArrayType(hw::ArrayType type, LLVMTypeConverter &converter) {
796 auto elementTy = converter.convertType(type.getElementType());
797 return LLVM::LLVMArrayType::get(elementTy, type.getNumElements());
798}
799
800static Type convertStructType(hw::StructType type,
801 LLVMTypeConverter &converter) {
802 llvm::SmallVector<Type, 8> elements;
803 mlir::SmallVector<mlir::Type> types;
804 type.getInnerTypes(types);
805
806 for (int i = 0, e = types.size(); i < e; ++i)
807 elements.push_back(converter.convertType(
809
810 return LLVM::LLVMStructType::getLiteral(&converter.getContext(), elements);
811}
812
813//===----------------------------------------------------------------------===//
814// Pass initialization
815//===----------------------------------------------------------------------===//
816
817namespace {
818struct HWToLLVMLoweringPass
819 : public circt::impl::ConvertHWToLLVMBase<HWToLLVMLoweringPass> {
820
821 using circt::impl::ConvertHWToLLVMBase<
822 HWToLLVMLoweringPass>::ConvertHWToLLVMBase;
823
824 void runOnOperation() override;
825};
826} // namespace
827
829 LLVMTypeConverter &converter, RewritePatternSet &patterns,
830 Namespace &globals,
831 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
832 &constAggregateGlobalsMap,
833 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt) {
834 MLIRContext *ctx = converter.getDialect()->getContext();
835
836 // Value creation conversion patterns.
837 patterns.add<HWConstantOpConversion>(ctx, converter);
838 patterns.add<HWDynamicArrayCreateOpConversion, HWStructCreateOpConversion>(
839 converter);
840 patterns.add<AggregateConstantOpConversion>(
841 converter, constAggregateGlobalsMap, globals, spillCacheOpt);
842
843 // Bitwise conversion patterns.
844 patterns.add<BitcastOpConversion>(converter);
845
846 // Extraction operation conversion patterns.
847 patterns.add<StructExplodeOpConversion, StructExtractOpConversion,
848 StructInjectOpConversion>(converter);
849
850 patterns.add<ArrayGetOpConversion, ArrayInjectOpConversion,
851 ArraySliceOpConversion, ArrayConcatOpConversion>(converter,
852 spillCacheOpt);
853}
854
855void circt::populateHWToLLVMTypeConversions(LLVMTypeConverter &converter) {
856 converter.addConversion(
857 [&](hw::ArrayType arr) { return convertArrayType(arr, converter); });
858 converter.addConversion(
859 [&](hw::StructType tup) { return convertStructType(tup, converter); });
860}
861
862void HWToLLVMLoweringPass::runOnOperation() {
863 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
864 std::optional<HWToLLVMArraySpillCache> spillCacheOpt = {};
865 Namespace globals;
866 SymbolCache cache;
867 cache.addDefinitions(getOperation());
868 globals.add(cache);
869
870 RewritePatternSet patterns(&getContext());
871 auto converter = mlir::LLVMTypeConverter(&getContext());
873
874 if (spillArraysEarly) {
875 spillCacheOpt = HWToLLVMArraySpillCache();
876 OpBuilder spillBuilder(getOperation());
877 spillCacheOpt->spillNonHWOps(spillBuilder, converter, getOperation());
878 }
879
880 LLVMConversionTarget target(getContext());
881 target.addIllegalDialect<hw::HWDialect>();
882
883 // Setup the conversion.
885 constAggregateGlobalsMap, spillCacheOpt);
886
887 // Apply the partial conversion.
888 ConversionConfig config;
889 config.allowPatternRollback = false;
890 if (failed(applyPartialConversion(getOperation(), target, std::move(patterns),
891 config)))
892 signalPassFailure();
893}
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static Type convertStructType(hw::StructType type, LLVMTypeConverter &converter)
Definition HWToLLVM.cpp:800
static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter, Value value)
Create a zext operation by one bit on the given value.
Definition HWToLLVM.cpp:75
static Type convertArrayType(hw::ArrayType type, LLVMTypeConverter &converter)
Definition HWToLLVM.cpp:795
static Value spillValueOnStack(OpBuilder &builder, Location loc, Value spillVal)
Definition HWToLLVM.cpp:87
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap, std::optional< HWToLLVMArraySpillCache > &spillCacheOpt)
Get the HW to LLVM conversion patterns.
Definition hw.py:1
Helper class mapping array values (HW or LLVM Dialect) to pointers to buffers containing the array value.
Definition HWToLLVM.h:47
Value spillHWArrayValue(OpBuilder &builder, Location loc, mlir::LLVMTypeConverter &converter, Value hwArray)
Definition HWToLLVM.cpp:175
Value lookup(Value arrayValue)
Retrieve a pointer to a buffer containing the given array value (HW or LLVM Dialect).
Definition HWToLLVM.cpp:142
void spillNonHWOps(mlir::OpBuilder &builder, mlir::LLVMTypeConverter &converter, Operation *containerOp)
Spill HW array values produced by 'foreign' dialects on the stack.
Definition HWToLLVM.cpp:100
void map(mlir::Value arrayValue, mlir::Value bufferPtr)
Map an LLVM array value to an LLVM pointer.
Definition HWToLLVM.cpp:134
Value spillLLVMArrayValue(OpBuilder &builder, Location loc, Value llvmArray)
Definition HWToLLVM.cpp:160
llvm::DenseMap< Value, Value > spillMap
Definition HWToLLVM.h:70
static uint32_t convertToLLVMEndianess(Type type, uint32_t index)
Convert an index into an HW ArrayType or StructType to LLVM endianness.
Definition HWToLLVM.cpp:38
static uint32_t llvmIndexOfStructField(hw::StructType type, StringRef fieldName)
Get the index of a specific StructType field in the LLVM lowering of the StructType.
Definition HWToLLVM.cpp:50