CIRCT 23.0.0git
Loading...
Searching...
No Matches
HWToLLVM.cpp
Go to the documentation of this file.
1//===- HWToLLVM.cpp - HW to LLVM Conversion Pass --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main HW to LLVM Conversion Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
15#include "circt/Support/LLVM.h"
17#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18#include "mlir/Conversion/LLVMCommon/Pattern.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/IR/Iterators.h"
21#include "mlir/Interfaces/DataLayoutInterfaces.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/ADT/TypeSwitch.h"
25
26namespace circt {
27#define GEN_PASS_DEF_CONVERTHWTOLLVM
28#include "circt/Conversion/Passes.h.inc"
29} // namespace circt
30
31using namespace mlir;
32using namespace circt;
33
34//===----------------------------------------------------------------------===//
35// Endianess Converter
36//===----------------------------------------------------------------------===//
37
38uint32_t
40 uint32_t index) {
41 // This is hardcoded for little endian machines for now.
42 return TypeSwitch<Type, uint32_t>(type)
43 .Case<hw::ArrayType>(
44 [&](hw::ArrayType ty) { return ty.getNumElements() - index - 1; })
45 .Case<hw::StructType>([&](hw::StructType ty) {
46 return ty.getElements().size() - index - 1;
47 });
48}
49
50uint32_t
52 StringRef fieldName) {
53 auto fieldIter = type.getElements();
54 size_t index = 0;
55
56 for (const auto *iter = fieldIter.begin(); iter != fieldIter.end(); ++iter) {
57 if (iter->name == fieldName) {
59 }
60 ++index;
61 }
62
63 // Verifier of StructExtractOp has to ensure that the field name is indeed
64 // present.
65 llvm_unreachable("Field name attribute of hw::StructExtractOp invalid");
66 return 0;
67}
68
69//===----------------------------------------------------------------------===//
70// Helpers
71//===----------------------------------------------------------------------===//
72
73/// Create a zext operation by one bit on the given value. This is useful when
74/// passing unsigned indexes to a GEP instruction, which treats indexes as
75/// signed values, to avoid unexpected "sign overflows".
76static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter,
77 Value value) {
78 auto valueTy = value.getType();
79 auto zextTy = IntegerType::get(valueTy.getContext(),
80 valueTy.getIntOrFloatBitWidth() + 1);
81 return LLVM::ZExtOp::create(rewriter, loc, zextTy, value);
82}
83
84//===----------------------------------------------------------------------===//
85// HWToLLVMArraySpillCache
86//===----------------------------------------------------------------------===//
87
88static Value spillValueOnStack(OpBuilder &builder, Location loc,
89 Value spillVal) {
90 auto oneC = LLVM::ConstantOp::create(
91 builder, loc, IntegerType::get(builder.getContext(), 32),
92 builder.getI32IntegerAttr(1));
93
94 Block *block = builder.getInsertionBlock();
95 assert(block && "expected an insertion block when spilling a value");
96
97 auto alignment =
98 static_cast<unsigned>(DataLayout::closest(block->getParentOp())
99 .getTypePreferredAlignment(spillVal.getType()));
100 alignment = std::max(4u, alignment);
101 Value ptr = LLVM::AllocaOp::create(
102 builder, loc, LLVM::LLVMPointerType::get(builder.getContext()),
103 spillVal.getType(), oneC, alignment);
104 LLVM::StoreOp::create(builder, loc, spillVal, ptr);
105 return ptr;
106}
107
109 LLVMTypeConverter &converter,
110 Operation *containerOp) {
111 OpBuilder::InsertionGuard g(builder);
112 containerOp->walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
113 [&](Operation *op) {
114 if (isa_and_nonnull<hw::HWDialect>(op->getDialect()))
115 return;
116 auto hasSpillingUser = [](Value arrVal) -> bool {
117 for (auto user : arrVal.getUsers())
118 if (isa<hw::ArrayGetOp, hw::ArraySliceOp>(user))
119 return true;
120 return false;
121 };
122 // Spill Block arguments
123 for (auto &region : op->getRegions()) {
124 for (auto &block : region.getBlocks()) {
125 builder.setInsertionPointToStart(&block);
126 for (auto &arg : block.getArguments()) {
127 if (isa<hw::ArrayType>(arg.getType()) && hasSpillingUser(arg))
128 spillHWArrayValue(builder, arg.getLoc(), converter, arg);
129 }
130 }
131 }
132 // Spill Op Results
133 for (auto result : op->getResults()) {
134 if (isa<hw::ArrayType>(result.getType()) && hasSpillingUser(result)) {
135 builder.setInsertionPointAfter(op);
136 spillHWArrayValue(builder, op->getLoc(), converter, result);
137 }
138 }
139 });
140}
141
142void HWToLLVMArraySpillCache::map(Value arrayValue, Value bufferPtr) {
143 assert(isa<LLVM::LLVMArrayType>(arrayValue.getType()) &&
144 "Key is not an LLVM array.");
145 assert(isa<LLVM::LLVMPointerType>(bufferPtr.getType()) &&
146 "Value is not a pointer.");
147 spillMap.insert({arrayValue, bufferPtr});
148}
149
150Value HWToLLVMArraySpillCache::lookup(Value arrayValue) {
151 assert(isa<LLVM::LLVMArrayType>(arrayValue.getType()) ||
152 isa<hw::ArrayType>(arrayValue.getType()) && "Not an array value");
153 while (isa<LLVM::LLVMArrayType, hw::ArrayType>(arrayValue.getType())) {
154 if (isa<LLVM::LLVMArrayType>(arrayValue.getType())) {
155 auto mapVal = spillMap.lookup(arrayValue);
156 if (mapVal)
157 return mapVal;
158 }
159 if (auto castOp = arrayValue.getDefiningOp<UnrealizedConversionCastOp>())
160 arrayValue = castOp.getOperand(0);
161 else
162 break;
163 }
164 return {};
165}
166
167// Materialize a LLVM Array value in a stack allocated buffer.
169 Location loc,
170 Value llvmArray) {
171 assert(isa<LLVM::LLVMArrayType>(llvmArray.getType()) &&
172 "Expected an LLVM array.");
173 auto spillBuffer = spillValueOnStack(builder, loc, llvmArray);
174 auto loadOp =
175 LLVM::LoadOp::create(builder, loc, llvmArray.getType(), spillBuffer);
176 map(loadOp.getResult(), spillBuffer);
177 return loadOp.getResult();
178}
179
180// Materialize a HW Array value in a stack allocated buffer. Replaces
181// all current uses of the SSA value with a new SSA representing the same
182// array value.
184 Location loc,
185 LLVMTypeConverter &converter,
186 Value hwArray) {
187 assert(isa<hw::ArrayType>(hwArray.getType()) && "Expected an HW array");
188 auto targetType = converter.convertType(hwArray.getType());
189 auto hwToLLVMCast =
190 UnrealizedConversionCastOp::create(builder, loc, targetType, hwArray);
191 auto spilled = spillLLVMArrayValue(builder, loc, hwToLLVMCast.getResult(0));
192 auto llvmToHWCast = UnrealizedConversionCastOp::create(
193 builder, loc, hwArray.getType(), spilled);
194 hwArray.replaceAllUsesExcept(llvmToHWCast.getResult(0), hwToLLVMCast);
195 return llvmToHWCast.getResult(0);
196}
197
198namespace {
199// Helper for patterns using or creating buffers containing
200// HW array values.
201template <typename SourceOp>
202struct HWArrayOpToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
203
204 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
205 HWArrayOpToLLVMPattern(LLVMTypeConverter &converter,
206 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt)
207 : ConvertOpToLLVMPattern<SourceOp>(converter),
208 spillCacheOpt(spillCacheOpt) {}
209
210protected:
211 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt;
212};
213
214} // namespace
215
216//===----------------------------------------------------------------------===//
217// Extraction operation conversions
218//===----------------------------------------------------------------------===//
219
220namespace {
221/// Convert a StructExplodeOp to the LLVM dialect.
222/// Pattern: struct_explode(input) =>
223/// struct_extract(input, structElements_index(index)) ...
224struct StructExplodeOpConversion
225 : public ConvertOpToLLVMPattern<hw::StructExplodeOp> {
226 using ConvertOpToLLVMPattern<hw::StructExplodeOp>::ConvertOpToLLVMPattern;
227
228 LogicalResult
229 matchAndRewrite(hw::StructExplodeOp op, OpAdaptor adaptor,
230 ConversionPatternRewriter &rewriter) const override {
231
232 SmallVector<Value> replacements;
233
234 for (size_t i = 0,
235 e = cast<LLVM::LLVMStructType>(adaptor.getInput().getType())
236 .getBody()
237 .size();
238 i < e; ++i)
239
240 replacements.push_back(LLVM::ExtractValueOp::create(
241 rewriter, op->getLoc(), adaptor.getInput(),
243 op.getInput().getType(), i)));
244
245 rewriter.replaceOp(op, replacements);
246 return success();
247 }
248};
249} // namespace
250
251namespace {
252/// Convert a StructExtractOp to LLVM dialect.
253/// Pattern: struct_extract(input, fieldname) =>
254/// extractvalue(input, fieldname_to_index(fieldname))
255struct StructExtractOpConversion
256 : public ConvertOpToLLVMPattern<hw::StructExtractOp> {
257 using ConvertOpToLLVMPattern<hw::StructExtractOp>::ConvertOpToLLVMPattern;
258
259 LogicalResult
260 matchAndRewrite(hw::StructExtractOp op, OpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter) const override {
262
264 op.getInput().getType(), op.getFieldIndex());
265 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, adaptor.getInput(),
266 fieldIndex);
267 return success();
268 }
269};
270} // namespace
271
272namespace {
273/// Convert an ArrayInjectOp to the LLVM dialect.
274/// Pattern: array_inject(input, element, index) =>
275/// store(gep(store(input, alloca), zext(index)), element)
276/// load(alloca)
277struct ArrayInjectOpConversion
278 : public HWArrayOpToLLVMPattern<hw::ArrayInjectOp> {
279 using HWArrayOpToLLVMPattern<hw::ArrayInjectOp>::HWArrayOpToLLVMPattern;
280
281 LogicalResult
282 matchAndRewrite(hw::ArrayInjectOp op, OpAdaptor adaptor,
283 ConversionPatternRewriter &rewriter) const override {
284 auto inputType = cast<hw::ArrayType>(op.getInput().getType());
285 auto oldArrTy = adaptor.getInput().getType();
286 auto newArrTy = oldArrTy;
287 const size_t arrElems = inputType.getNumElements();
288
289 if (arrElems == 0) {
290 rewriter.replaceOp(op, adaptor.getInput());
291 return success();
292 }
293
294 auto oneC =
295 LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI32Type(),
296 rewriter.getI32IntegerAttr(1));
297 auto zextIndex = zextByOne(op->getLoc(), rewriter, op.getIndex());
298
299 if (arrElems == 1 || !llvm::isPowerOf2_64(arrElems)) {
300 // Clamp index to prevent OOB access. We add an extra element to the
301 // array so that OOB access modifies this element, leaving the original
302 // array intact.
303 auto maxIndex =
304 LLVM::ConstantOp::create(rewriter, op->getLoc(), zextIndex.getType(),
305 rewriter.getI32IntegerAttr(arrElems));
306 zextIndex =
307 LLVM::UMinOp::create(rewriter, op->getLoc(), zextIndex, maxIndex);
308
309 newArrTy = typeConverter->convertType(
310 hw::ArrayType::get(inputType.getElementType(), arrElems + 1));
311 }
312 auto allocaAlignment = std::max(
313 4u, static_cast<unsigned>(DataLayout::closest(op.getOperation())
314 .getTypePreferredAlignment(newArrTy)));
315 Value arrPtr = LLVM::AllocaOp::create(
316 rewriter, op->getLoc(),
317 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, oneC,
318 allocaAlignment);
319
320 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getInput(), arrPtr);
321
322 auto gep = LLVM::GEPOp::create(
323 rewriter, op->getLoc(),
324 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, arrPtr,
325 ArrayRef<LLVM::GEPArg>{0, zextIndex});
326
327 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getElement(), gep);
328 auto loadOp =
329 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, oldArrTy, arrPtr);
330 if (spillCacheOpt)
331 spillCacheOpt->map(loadOp, arrPtr);
332 return success();
333 }
334};
335} // namespace
336
337namespace {
338/// Convert an ArrayGetOp to the LLVM dialect.
339/// Pattern: array_get(input, index) =>
340/// load(gep(store(input, alloca), zext(index)))
341struct ArrayGetOpConversion : public HWArrayOpToLLVMPattern<hw::ArrayGetOp> {
342 using HWArrayOpToLLVMPattern<hw::ArrayGetOp>::HWArrayOpToLLVMPattern;
343
344 LogicalResult
345 matchAndRewrite(hw::ArrayGetOp op, OpAdaptor adaptor,
346 ConversionPatternRewriter &rewriter) const override {
347
348 Value arrPtr;
349 if (spillCacheOpt)
350 arrPtr = spillCacheOpt->lookup(adaptor.getInput());
351 if (!arrPtr)
352 arrPtr = spillValueOnStack(rewriter, op.getLoc(), adaptor.getInput());
353
354 auto arrTy = typeConverter->convertType(op.getInput().getType());
355 auto elemTy = typeConverter->convertType(op.getResult().getType());
356 auto zextIndex = zextByOne(op->getLoc(), rewriter, op.getIndex());
357
358 // During the ongoing migration to opaque types, use the constructor that
359 // accepts an element type when the array pointer type is opaque, and
360 // otherwise use the typed pointer constructor.
361 auto gep = LLVM::GEPOp::create(
362 rewriter, op->getLoc(),
363 LLVM::LLVMPointerType::get(rewriter.getContext()), arrTy, arrPtr,
364 ArrayRef<LLVM::GEPArg>{0, zextIndex});
365 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elemTy, gep);
366
367 return success();
368 }
369};
370} // namespace
371
372namespace {
373/// Convert an ArraySliceOp to the LLVM dialect.
374/// Pattern: array_slice(input, lowIndex) =>
375/// load(bitcast(gep(store(input, alloca), zext(lowIndex))))
376struct ArraySliceOpConversion
377 : public HWArrayOpToLLVMPattern<hw::ArraySliceOp> {
378 using HWArrayOpToLLVMPattern<hw::ArraySliceOp>::HWArrayOpToLLVMPattern;
379
380 LogicalResult
381 matchAndRewrite(hw::ArraySliceOp op, OpAdaptor adaptor,
382 ConversionPatternRewriter &rewriter) const override {
383
384 auto dstTy = typeConverter->convertType(op.getDst().getType());
385
386 Value arrPtr;
387 if (spillCacheOpt)
388 arrPtr = spillCacheOpt->lookup(adaptor.getInput());
389 if (!arrPtr)
390 arrPtr = spillValueOnStack(rewriter, op.getLoc(), adaptor.getInput());
391
392 auto zextIndex = zextByOne(op->getLoc(), rewriter, op.getLowIndex());
393
394 // During the ongoing migration to opaque types, use the constructor that
395 // accepts an element type when the array pointer type is opaque, and
396 // otherwise use the typed pointer constructor.
397 auto gep = LLVM::GEPOp::create(
398 rewriter, op->getLoc(),
399 LLVM::LLVMPointerType::get(rewriter.getContext()), dstTy, arrPtr,
400 ArrayRef<LLVM::GEPArg>{0, zextIndex});
401
402 auto loadOp = rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dstTy, gep);
403
404 if (spillCacheOpt)
405 spillCacheOpt->map(loadOp, gep);
406
407 return success();
408 }
409};
410} // namespace
411
412//===----------------------------------------------------------------------===//
413// Insertion operations conversion
414//===----------------------------------------------------------------------===//
415
416namespace {
417/// Convert a StructInjectOp to LLVM dialect.
418/// Pattern: struct_inject(input, index, value) =>
419/// insertvalue(input, value, index)
420struct StructInjectOpConversion
421 : public ConvertOpToLLVMPattern<hw::StructInjectOp> {
422 using ConvertOpToLLVMPattern<hw::StructInjectOp>::ConvertOpToLLVMPattern;
423
424 LogicalResult
425 matchAndRewrite(hw::StructInjectOp op, OpAdaptor adaptor,
426 ConversionPatternRewriter &rewriter) const override {
427
429 op.getInput().getType(), op.getFieldIndex());
430
431 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
432 op, adaptor.getInput(), adaptor.getNewValue(), fieldIndex);
433
434 return success();
435 }
436};
437} // namespace
438
439//===----------------------------------------------------------------------===//
440// Concat operations conversion
441//===----------------------------------------------------------------------===//
442
443namespace {
444/// Lower an ArrayConcatOp operation to the LLVM dialect.
445struct ArrayConcatOpConversion
446 : public HWArrayOpToLLVMPattern<hw::ArrayConcatOp> {
447 using HWArrayOpToLLVMPattern<hw::ArrayConcatOp>::HWArrayOpToLLVMPattern;
448
449 LogicalResult
450 matchAndRewrite(hw::ArrayConcatOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter) const override {
452
453 hw::ArrayType arrTy = cast<hw::ArrayType>(op.getResult().getType());
454 Type resultTy = typeConverter->convertType(arrTy);
455 auto loc = op.getLoc();
456
457 Value arr = LLVM::UndefOp::create(rewriter, loc, resultTy);
458
459 // Attention: j is hardcoded for little endian machines.
460 size_t j = op.getInputs().size() - 1, k = 0;
461
462 for (size_t i = 0, e = arrTy.getNumElements(); i < e; ++i) {
463 Value element = LLVM::ExtractValueOp::create(rewriter, loc,
464 adaptor.getInputs()[j], k);
465 arr = LLVM::InsertValueOp::create(rewriter, loc, arr, element, i);
466
467 ++k;
468 if (k >=
469 cast<hw::ArrayType>(op.getInputs()[j].getType()).getNumElements()) {
470 k = 0;
471 --j;
472 }
473 }
474
475 rewriter.replaceOp(op, arr);
476
477 // If we've got a cache, spill the array right away.
478 if (spillCacheOpt) {
479 rewriter.setInsertionPointAfter(arr.getDefiningOp());
480 auto ptr = spillValueOnStack(rewriter, loc, arr);
481 spillCacheOpt->map(arr, ptr);
482 }
483 return success();
484 }
485};
486} // namespace
487
488//===----------------------------------------------------------------------===//
489// Value creation conversions
490//===----------------------------------------------------------------------===//
491
492namespace {
493struct HWConstantOpConversion : public ConvertToLLVMPattern {
494 explicit HWConstantOpConversion(MLIRContext *ctx,
495 LLVMTypeConverter &typeConverter)
496 : ConvertToLLVMPattern(hw::ConstantOp::getOperationName(), ctx,
497 typeConverter) {}
498
499 LogicalResult
500 matchAndRewrite(Operation *op, ArrayRef<Value> operand,
501 ConversionPatternRewriter &rewriter) const override {
502 // Get the ConstOp.
503 auto constOp = cast<hw::ConstantOp>(op);
504 // Get the converted llvm type.
505 auto intType = typeConverter->convertType(constOp.getValueAttr().getType());
506 // Replace the operation with an llvm constant op.
507 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, intType,
508 constOp.getValueAttr());
509
510 return success();
511 }
512};
513} // namespace
514
515namespace {
516/// Convert an ArrayCreateOp with dynamic elements to the LLVM dialect. An
517/// equivalent and initialized llvm dialect array type is generated.
518struct HWDynamicArrayCreateOpConversion
519 : public ConvertOpToLLVMPattern<hw::ArrayCreateOp> {
520 using ConvertOpToLLVMPattern<hw::ArrayCreateOp>::ConvertOpToLLVMPattern;
521
522 LogicalResult
523 matchAndRewrite(hw::ArrayCreateOp op, OpAdaptor adaptor,
524 ConversionPatternRewriter &rewriter) const override {
525 auto arrayTy = typeConverter->convertType(op->getResult(0).getType());
526 assert(arrayTy);
527
528 Value arr = LLVM::UndefOp::create(rewriter, op->getLoc(), arrayTy);
529 for (size_t i = 0, e = op.getInputs().size(); i < e; ++i) {
530 Value input =
531 adaptor
533 op.getResult().getType(), i)];
534 arr = LLVM::InsertValueOp::create(rewriter, op->getLoc(), arr, input, i);
535 }
536
537 rewriter.replaceOp(op, arr);
538 return success();
539 }
540};
541} // namespace
542
543namespace {
544
545/// Convert an ArrayCreateOp with constant elements to the LLVM dialect. An
546/// equivalent and initialized llvm dialect array type is generated.
547class AggregateConstantOpConversion
548 : public HWArrayOpToLLVMPattern<hw::AggregateConstantOp> {
549 using HWArrayOpToLLVMPattern<hw::AggregateConstantOp>::HWArrayOpToLLVMPattern;
550
551 bool containsArrayAndStructAggregatesOnly(Type type) const;
552
553 bool isMultiDimArrayOfIntegers(Type type,
554 SmallVectorImpl<int64_t> &dims) const;
555
556 void flatten(Type type, Attribute attr,
557 SmallVectorImpl<Attribute> &output) const;
558
559 Value constructAggregate(OpBuilder &builder,
560 const TypeConverter &typeConverter, Location loc,
561 Type type, Attribute data) const;
562
563public:
564 explicit AggregateConstantOpConversion(
565 LLVMTypeConverter &typeConverter,
566 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
567 &constAggregateGlobalsMap,
568 Namespace &globals, std::optional<HWToLLVMArraySpillCache> &spillCacheOpt)
569 : HWArrayOpToLLVMPattern(typeConverter, spillCacheOpt),
570 constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals) {}
571
572 LogicalResult
573 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
574 ConversionPatternRewriter &rewriter) const override;
575
576private:
577 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
578 &constAggregateGlobalsMap;
579 Namespace &globals;
580};
581} // namespace
582
583namespace {
584/// Convert a StructCreateOp operation to the LLVM dialect. An equivalent and
585/// initialized llvm dialect struct type is generated.
586struct HWStructCreateOpConversion
587 : public ConvertOpToLLVMPattern<hw::StructCreateOp> {
588 using ConvertOpToLLVMPattern<hw::StructCreateOp>::ConvertOpToLLVMPattern;
589
590 LogicalResult
591 matchAndRewrite(hw::StructCreateOp op, OpAdaptor adaptor,
592 ConversionPatternRewriter &rewriter) const override {
593
594 auto resTy = typeConverter->convertType(op.getResult().getType());
595
596 Value tup = LLVM::UndefOp::create(rewriter, op->getLoc(), resTy);
597 for (size_t i = 0, e = cast<LLVM::LLVMStructType>(resTy).getBody().size();
598 i < e; ++i) {
599 Value input =
601 op.getResult().getType(), i)];
602 tup = LLVM::InsertValueOp::create(rewriter, op->getLoc(), tup, input, i);
603 }
604
605 rewriter.replaceOp(op, tup);
606 return success();
607 }
608};
609} // namespace
610
611//===----------------------------------------------------------------------===//
612// Pattern implementations
613//===----------------------------------------------------------------------===//
614
615bool AggregateConstantOpConversion::containsArrayAndStructAggregatesOnly(
616 Type type) const {
617 if (auto intType = dyn_cast<IntegerType>(type))
618 return true;
619
620 if (auto arrTy = dyn_cast<hw::ArrayType>(type))
621 return containsArrayAndStructAggregatesOnly(arrTy.getElementType());
622
623 if (auto structTy = dyn_cast<hw::StructType>(type)) {
624 SmallVector<Type> innerTypes;
625 structTy.getInnerTypes(innerTypes);
626 return llvm::all_of(innerTypes, [&](auto ty) {
627 return containsArrayAndStructAggregatesOnly(ty);
628 });
629 }
630
631 return false;
632}
633
634bool AggregateConstantOpConversion::isMultiDimArrayOfIntegers(
635 Type type, SmallVectorImpl<int64_t> &dims) const {
636 if (auto intType = dyn_cast<IntegerType>(type))
637 return true;
638
639 if (auto arrTy = dyn_cast<hw::ArrayType>(type)) {
640 dims.push_back(arrTy.getNumElements());
641 return isMultiDimArrayOfIntegers(arrTy.getElementType(), dims);
642 }
643
644 return false;
645}
646
647void AggregateConstantOpConversion::flatten(
648 Type type, Attribute attr, SmallVectorImpl<Attribute> &output) const {
649 if (isa<IntegerType>(type)) {
650 assert(isa<IntegerAttr>(attr));
651 output.push_back(attr);
652 return;
653 }
654
655 auto arrAttr = cast<ArrayAttr>(attr);
656 for (size_t i = 0, e = arrAttr.size(); i < e; ++i) {
657 auto element =
659
660 flatten(cast<hw::ArrayType>(type).getElementType(), element, output);
661 }
662}
663
664Value AggregateConstantOpConversion::constructAggregate(
665 OpBuilder &builder, const TypeConverter &typeConverter, Location loc,
666 Type type, Attribute data) const {
667 Type llvmType = typeConverter.convertType(type);
668
669 auto getElementType = [](Type type, size_t index) {
670 if (auto arrTy = dyn_cast<hw::ArrayType>(type)) {
671 return arrTy.getElementType();
672 }
673
674 assert(isa<hw::StructType>(type));
675 auto structTy = cast<hw::StructType>(type);
676 SmallVector<Type> innerTypes;
677 structTy.getInnerTypes(innerTypes);
678 return innerTypes[index];
679 };
680
681 return TypeSwitch<Type, Value>(type)
682 .Case<IntegerType>([&](auto ty) {
683 return LLVM::ConstantOp::create(builder, loc, cast<TypedAttr>(data));
684 })
685 .Case<hw::ArrayType, hw::StructType>([&](auto ty) {
686 Value aggVal = LLVM::UndefOp::create(builder, loc, llvmType);
687 auto arrayAttr = cast<ArrayAttr>(data);
688 for (size_t i = 0, e = arrayAttr.size(); i < e; ++i) {
689 size_t currIdx =
691 Attribute input = arrayAttr[currIdx];
692 Type elementType = getElementType(ty, currIdx);
693
694 Value element = constructAggregate(builder, typeConverter, loc,
695 elementType, input);
696 aggVal =
697 LLVM::InsertValueOp::create(builder, loc, aggVal, element, i);
698 }
699
700 return aggVal;
701 });
702}
703
704LogicalResult AggregateConstantOpConversion::matchAndRewrite(
705 hw::AggregateConstantOp op, OpAdaptor adaptor,
706 ConversionPatternRewriter &rewriter) const {
707 Type aggregateType = op.getResult().getType();
708
709 // TODO: Only arrays and structs supported at the moment.
710 if (!containsArrayAndStructAggregatesOnly(aggregateType))
711 return failure();
712
713 auto llvmTy = typeConverter->convertType(op.getResult().getType());
714 auto typeAttrPair = std::make_pair(aggregateType, adaptor.getFields());
715
716 if (!constAggregateGlobalsMap.count(typeAttrPair) ||
717 !constAggregateGlobalsMap[typeAttrPair]) {
718 auto ipSave = rewriter.saveInsertionPoint();
719
720 Operation *parent = op->getParentOp();
721 while (!isa<mlir::ModuleOp>(parent->getParentOp())) {
722 parent = parent->getParentOp();
723 }
724
725 rewriter.setInsertionPoint(parent);
726
727 // Create a global region for this static array.
728 auto name = globals.newName("_aggregate_const_global");
729
730 SmallVector<int64_t> dims;
731 if (isMultiDimArrayOfIntegers(aggregateType, dims)) {
732 SmallVector<Attribute> ints;
733 flatten(aggregateType, adaptor.getFields(), ints);
734 assert(!ints.empty());
735 auto shapedType = RankedTensorType::get(
736 dims, cast<IntegerAttr>(ints.front()).getType());
737 auto denseAttr = DenseElementsAttr::get(shapedType, ints);
738
739 constAggregateGlobalsMap[typeAttrPair] =
740 LLVM::GlobalOp::create(rewriter, op.getLoc(), llvmTy, true,
741 LLVM::Linkage::Internal, name, denseAttr);
742 } else {
743 auto global =
744 LLVM::GlobalOp::create(rewriter, op.getLoc(), llvmTy, false,
745 LLVM::Linkage::Internal, name, Attribute());
746 Block *blk = new Block();
747 global.getInitializerRegion().push_back(blk);
748 rewriter.setInsertionPointToStart(blk);
749
750 Value aggregate =
751 constructAggregate(rewriter, *typeConverter, op.getLoc(),
752 aggregateType, adaptor.getFields());
753 LLVM::ReturnOp::create(rewriter, op.getLoc(), aggregate);
754 constAggregateGlobalsMap[typeAttrPair] = global;
755 }
756
757 rewriter.restoreInsertionPoint(ipSave);
758 }
759
760 // Get the global array address and load it to return an array value.
761 auto addr = LLVM::AddressOfOp::create(rewriter, op->getLoc(),
762 constAggregateGlobalsMap[typeAttrPair]);
763 auto newOp = rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmTy, addr);
764
765 if (spillCacheOpt && llvm::isa<hw::ArrayType>(aggregateType))
766 spillCacheOpt->map(newOp.getResult(), addr);
767
768 return success();
769}
770
771//===----------------------------------------------------------------------===//
772// Type conversions
773//===----------------------------------------------------------------------===//
774
775static Type convertArrayType(hw::ArrayType type, LLVMTypeConverter &converter) {
776 auto elementTy = converter.convertType(type.getElementType());
777 return LLVM::LLVMArrayType::get(elementTy, type.getNumElements());
778}
779
780static Type convertStructType(hw::StructType type,
781 LLVMTypeConverter &converter) {
782 llvm::SmallVector<Type, 8> elements;
783 mlir::SmallVector<mlir::Type> types;
784 type.getInnerTypes(types);
785
786 for (int i = 0, e = types.size(); i < e; ++i)
787 elements.push_back(converter.convertType(
789
790 return LLVM::LLVMStructType::getLiteral(&converter.getContext(), elements);
791}
792
793//===----------------------------------------------------------------------===//
794// Pass initialization
795//===----------------------------------------------------------------------===//
796
797namespace {
798struct HWToLLVMLoweringPass
799 : public circt::impl::ConvertHWToLLVMBase<HWToLLVMLoweringPass> {
800
801 using circt::impl::ConvertHWToLLVMBase<
802 HWToLLVMLoweringPass>::ConvertHWToLLVMBase;
803
804 void runOnOperation() override;
805};
806} // namespace
807
809 LLVMTypeConverter &converter, RewritePatternSet &patterns,
810 Namespace &globals,
811 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
812 &constAggregateGlobalsMap,
813 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt) {
814 MLIRContext *ctx = converter.getDialect()->getContext();
815
816 // Value creation conversion patterns.
817 patterns.add<HWConstantOpConversion>(ctx, converter);
818 patterns.add<HWDynamicArrayCreateOpConversion, HWStructCreateOpConversion>(
819 converter);
820 patterns.add<AggregateConstantOpConversion>(
821 converter, constAggregateGlobalsMap, globals, spillCacheOpt);
822
823 // Extraction operation conversion patterns.
824 patterns.add<StructExplodeOpConversion, StructExtractOpConversion,
825 StructInjectOpConversion>(converter);
826
827 patterns.add<ArrayGetOpConversion, ArrayInjectOpConversion,
828 ArraySliceOpConversion, ArrayConcatOpConversion>(converter,
829 spillCacheOpt);
830}
831
832void circt::populateHWToLLVMTypeConversions(LLVMTypeConverter &converter) {
833 converter.addConversion(
834 [&](hw::ArrayType arr) { return convertArrayType(arr, converter); });
835 converter.addConversion(
836 [&](hw::StructType tup) { return convertStructType(tup, converter); });
837}
838
839void HWToLLVMLoweringPass::runOnOperation() {
840 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
841 std::optional<HWToLLVMArraySpillCache> spillCacheOpt = {};
842 Namespace globals;
843 SymbolCache cache;
844 cache.addDefinitions(getOperation());
845 globals.add(cache);
846
847 RewritePatternSet patterns(&getContext());
848 auto converter = mlir::LLVMTypeConverter(&getContext());
850
851 if (spillArraysEarly) {
852 spillCacheOpt = HWToLLVMArraySpillCache();
853 OpBuilder spillBuilder(getOperation());
854 spillCacheOpt->spillNonHWOps(spillBuilder, converter, getOperation());
855 }
856
857 LLVMConversionTarget target(getContext());
858 target.addIllegalDialect<hw::HWDialect>();
859 // Don't touch non-HW operations
860 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
861
862 // Setup the conversion.
864 constAggregateGlobalsMap, spillCacheOpt);
865
866 // Apply the partial conversion.
867 ConversionConfig config;
868 config.allowPatternRollback = false;
869 if (failed(applyPartialConversion(getOperation(), target, std::move(patterns),
870 config)))
871 signalPassFailure();
872}
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static Type convertStructType(hw::StructType type, LLVMTypeConverter &converter)
Definition HWToLLVM.cpp:780
static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter, Value value)
Create a zext operation by one bit on the given value.
Definition HWToLLVM.cpp:76
static Type convertArrayType(hw::ArrayType type, LLVMTypeConverter &converter)
Definition HWToLLVM.cpp:775
static Value spillValueOnStack(OpBuilder &builder, Location loc, Value spillVal)
Definition HWToLLVM.cpp:88
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap, std::optional< HWToLLVMArraySpillCache > &spillCacheOpt)
Get the HW to LLVM conversion patterns.
Definition hw.py:1
Helper class mapping array values (HW or LLVM Dialect) to pointers to buffers containing the array va...
Definition HWToLLVM.h:47
Value spillHWArrayValue(OpBuilder &builder, Location loc, mlir::LLVMTypeConverter &converter, Value hwArray)
Definition HWToLLVM.cpp:183
Value lookup(Value arrayValue)
Retrieve a pointer to a buffer containing the given array value (HW or LLVM Dialect).
Definition HWToLLVM.cpp:150
void spillNonHWOps(mlir::OpBuilder &builder, mlir::LLVMTypeConverter &converter, Operation *containerOp)
Spill HW array values produced by 'foreign' dialects on the stack.
Definition HWToLLVM.cpp:108
void map(mlir::Value arrayValue, mlir::Value bufferPtr)
Map an LLVM array value to an LLVM pointer.
Definition HWToLLVM.cpp:142
Value spillLLVMArrayValue(OpBuilder &builder, Location loc, Value llvmArray)
Definition HWToLLVM.cpp:168
llvm::DenseMap< Value, Value > spillMap
Definition HWToLLVM.h:70
static uint32_t convertToLLVMEndianess(Type type, uint32_t index)
Convert an index into a HW ArrayType or StructType to LLVM Endianess.
Definition HWToLLVM.cpp:39
static uint32_t llvmIndexOfStructField(hw::StructType type, StringRef fieldName)
Get the index of a specific StructType field in the LLVM lowering of the StructType.
Definition HWToLLVM.cpp:51