CIRCT 23.0.0git
Loading...
Searching...
No Matches
HWToLLVM.cpp
Go to the documentation of this file.
1//===- HWToLLVM.cpp - HW to LLVM Conversion Pass --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main HW to LLVM Conversion Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
15#include "circt/Support/LLVM.h"
17#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18#include "mlir/Conversion/LLVMCommon/Pattern.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/IR/Iterators.h"
23#include "mlir/Interfaces/DataLayoutInterfaces.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Transforms/DialectConversion.h"
26#include "llvm/ADT/TypeSwitch.h"
27
28namespace circt {
29#define GEN_PASS_DEF_CONVERTHWTOLLVM
30#include "circt/Conversion/Passes.h.inc"
31} // namespace circt
32
33using namespace mlir;
34using namespace circt;
35
36//===----------------------------------------------------------------------===//
37// Endianess Converter
38//===----------------------------------------------------------------------===//
39
/// Map an element index of an HW array or struct to the index of the same
/// element in the lowered LLVM aggregate by reversing it: with N elements
/// (array) or N fields (struct), index i becomes N - i - 1.
/// Only hw::ArrayType and hw::StructType are handled; the TypeSwitch has no
/// Default case.
/// NOTE(review): the line naming this function and its `type` parameter is
/// missing from this rendering of the file -- confirm against the source.
40uint32_t
42 uint32_t index) {
43 // This is hardcoded for little endian machines for now.
44 return TypeSwitch<Type, uint32_t>(type)
45 .Case<hw::ArrayType>(
46 [&](hw::ArrayType ty) { return ty.getNumElements() - index - 1; })
47 .Case<hw::StructType>([&](hw::StructType ty) {
48 return ty.getElements().size() - index - 1;
49 });
50}
51
/// Return the LLVM struct index of the field named `fieldName` in `type`.
/// The field's position in `type.getElements()` is found by a linear scan.
/// NOTE(review): the signature line and the statement executed on a match are
/// missing from this rendering (presumably the match returns an
/// endianness-converted `index`) -- confirm against the source.
/// Falling off the loop is treated as unreachable because the StructExtractOp
/// verifier is expected to guarantee that the field exists.
52uint32_t
54 StringRef fieldName) {
55 auto fieldIter = type.getElements();
56 size_t index = 0;
57
58 for (const auto *iter = fieldIter.begin(); iter != fieldIter.end(); ++iter) {
59 if (iter->name == fieldName) {
61 }
62 ++index;
63 }
64
65 // Verifier of StructExtractOp has to ensure that the field name is indeed
66 // present.
67 llvm_unreachable("Field name attribute of hw::StructExtractOp invalid");
68 return 0;
69}
70
71//===----------------------------------------------------------------------===//
72// Helpers
73//===----------------------------------------------------------------------===//
74
75/// Create a zext operation by one bit on the given value. This is useful when
76/// passing unsigned indexes to a GEP instruction, which treats indexes as
77/// signed values, to avoid unexpected "sign overflows".
78static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter,
79 Value value) {
80 auto valueTy = value.getType();
81 auto zextTy = IntegerType::get(valueTy.getContext(),
82 valueTy.getIntOrFloatBitWidth() + 1);
83 return LLVM::ZExtOp::create(rewriter, loc, zextTy, value);
84}
85
86//===----------------------------------------------------------------------===//
87// HWToLLVMArraySpillCache
88//===----------------------------------------------------------------------===//
89
90static Value spillValueOnStack(OpBuilder &builder, Location loc,
91 Value spillVal) {
92 auto oneC = LLVM::ConstantOp::create(
93 builder, loc, IntegerType::get(builder.getContext(), 32),
94 builder.getI32IntegerAttr(1));
95
96 Block *block = builder.getInsertionBlock();
97 assert(block && "expected an insertion block when spilling a value");
98
99 auto alignment =
100 static_cast<unsigned>(DataLayout::closest(block->getParentOp())
101 .getTypePreferredAlignment(spillVal.getType()));
102 alignment = std::max(4u, alignment);
103 Value ptr = LLVM::AllocaOp::create(
104 builder, loc, LLVM::LLVMPointerType::get(builder.getContext()),
105 spillVal.getType(), oneC, alignment);
106 LLVM::StoreOp::create(builder, loc, spillVal, ptr);
107 return ptr;
108}
109
/// Spill HW array values that are defined outside the HW dialect and have an
/// hw.array_get or hw.array_slice user.
///
/// Walks `containerOp` post-order (with mlir::ReverseIterator) and skips
/// operations of the HW dialect. For every other operation:
///  - hw.array block arguments of its regions are spilled at the start of
///    their block;
///  - hw.array results are spilled right after the operation.
/// Both cases go through spillHWArrayValue. The InsertionGuard restores the
/// builder's insertion point on return.
/// NOTE(review): the first line of this signature (method name and `builder`
/// parameter) is missing from this rendering -- confirm against the source.
111 LLVMTypeConverter &converter,
112 Operation *containerOp) {
113 OpBuilder::InsertionGuard g(builder);
114 containerOp->walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
115 [&](Operation *op) {
116 if (isa_and_nonnull<hw::HWDialect>(op->getDialect()))
117 return;
// Only arrays read through array_get / array_slice benefit from a spill.
118 auto hasSpillingUser = [](Value arrVal) -> bool {
119 for (auto user : arrVal.getUsers())
120 if (isa<hw::ArrayGetOp, hw::ArraySliceOp>(user))
121 return true;
122 return false;
123 };
124 // Spill Block arguments
125 for (auto &region : op->getRegions()) {
126 for (auto &block : region.getBlocks()) {
127 builder.setInsertionPointToStart(&block);
128 for (auto &arg : block.getArguments()) {
129 if (isa<hw::ArrayType>(arg.getType()) && hasSpillingUser(arg))
130 spillHWArrayValue(builder, arg.getLoc(), converter, arg);
131 }
132 }
133 }
134 // Spill Op Results
135 for (auto result : op->getResults()) {
136 if (isa<hw::ArrayType>(result.getType()) && hasSpillingUser(result)) {
137 builder.setInsertionPointAfter(op);
138 spillHWArrayValue(builder, op->getLoc(), converter, result);
139 }
140 }
141 });
142}
143
144void HWToLLVMArraySpillCache::map(Value arrayValue, Value bufferPtr) {
145 assert(isa<LLVM::LLVMArrayType>(arrayValue.getType()) &&
146 "Key is not an LLVM array.");
147 assert(isa<LLVM::LLVMPointerType>(bufferPtr.getType()) &&
148 "Value is not a pointer.");
149 spillMap.insert({arrayValue, bufferPtr});
150}
151
152Value HWToLLVMArraySpillCache::lookup(Value arrayValue) {
153 assert(isa<LLVM::LLVMArrayType>(arrayValue.getType()) ||
154 isa<hw::ArrayType>(arrayValue.getType()) && "Not an array value");
155 while (isa<LLVM::LLVMArrayType, hw::ArrayType>(arrayValue.getType())) {
156 if (isa<LLVM::LLVMArrayType>(arrayValue.getType())) {
157 auto mapVal = spillMap.lookup(arrayValue);
158 if (mapVal)
159 return mapVal;
160 }
161 if (auto castOp = arrayValue.getDefiningOp<UnrealizedConversionCastOp>())
162 arrayValue = castOp.getOperand(0);
163 else
164 break;
165 }
166 return {};
167}
168
169// Materialize a LLVM Array value in a stack allocated buffer.
// The array is stored to a fresh stack slot via spillValueOnStack and
// reloaded with llvm.load. The loaded value -> slot mapping is recorded in
// this cache, and the loaded value is returned.
// NOTE(review): the line naming this method and its `builder` parameter is
// missing from this rendering -- confirm against the source.
171 Location loc,
172 Value llvmArray) {
173 assert(isa<LLVM::LLVMArrayType>(llvmArray.getType()) &&
174 "Expected an LLVM array.");
175 auto spillBuffer = spillValueOnStack(builder, loc, llvmArray);
176 auto loadOp =
177 LLVM::LoadOp::create(builder, loc, llvmArray.getType(), spillBuffer);
178 map(loadOp.getResult(), spillBuffer);
179 return loadOp.getResult();
180}
181
182// Materialize a HW Array value in a stack allocated buffer. Replaces
183// all current uses of the SSA value with a new SSA representing the same
184// array value.
// The steps are:
//  1. Cast `hwArray` to its converted LLVM type with an
//     unrealized_conversion_cast.
//  2. Spill it with spillLLVMArrayValue.
//  3. Cast the reloaded value back to the HW type.
// Every use except the first cast is then redirected to the cast-back value,
// which is returned.
// NOTE(review): the line naming this method and its `builder` parameter is
// missing from this rendering -- confirm against the source.
186 Location loc,
187 LLVMTypeConverter &converter,
188 Value hwArray) {
189 assert(isa<hw::ArrayType>(hwArray.getType()) && "Expected an HW array");
190 auto targetType = converter.convertType(hwArray.getType());
191 auto hwToLLVMCast =
192 UnrealizedConversionCastOp::create(builder, loc, targetType, hwArray);
193 auto spilled = spillLLVMArrayValue(builder, loc, hwToLLVMCast.getResult(0));
194 auto llvmToHWCast = UnrealizedConversionCastOp::create(
195 builder, loc, hwArray.getType(), spilled);
// The first cast must keep consuming the original value.
196 hwArray.replaceAllUsesExcept(llvmToHWCast.getResult(0), hwToLLVMCast);
197 return llvmToHWCast.getResult(0);
198}
199
200namespace {
201// Helper for patterns using or creating buffers containing
202// HW array values.
203template <typename SourceOp>
204struct HWArrayOpToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
205
206 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
207 HWArrayOpToLLVMPattern(LLVMTypeConverter &converter,
208 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt)
209 : ConvertOpToLLVMPattern<SourceOp>(converter),
210 spillCacheOpt(spillCacheOpt) {}
211
212protected:
213 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt;
214};
215
216} // namespace
217
218//===----------------------------------------------------------------------===//
219// Extraction operation conversions
220//===----------------------------------------------------------------------===//
221
222namespace {
223/// Convert a StructExplodeOp to the LLVM dialect.
224/// Pattern: struct_explode(input) =>
225/// struct_extract(input, structElements_index(index)) ...
/// Every field of the lowered LLVM struct is read with llvm.extractvalue, and
/// the values replace the op's results in order.
/// NOTE(review): the extraction position for result `i` is computed from the
/// HW struct type and `i` on a line missing from this rendering (presumably an
/// endianness index conversion) -- confirm against the source.
226struct StructExplodeOpConversion
227 : public ConvertOpToLLVMPattern<hw::StructExplodeOp> {
228 using ConvertOpToLLVMPattern<hw::StructExplodeOp>::ConvertOpToLLVMPattern;
229
230 LogicalResult
231 matchAndRewrite(hw::StructExplodeOp op, OpAdaptor adaptor,
232 ConversionPatternRewriter &rewriter) const override {
233
234 SmallVector<Value> replacements;
235
236 for (size_t i = 0,
237 e = cast<LLVM::LLVMStructType>(adaptor.getInput().getType())
238 .getBody()
239 .size();
240 i < e; ++i)
241
242 replacements.push_back(LLVM::ExtractValueOp::create(
243 rewriter, op->getLoc(), adaptor.getInput(),
245 op.getInput().getType(), i)));
246
247 rewriter.replaceOp(op, replacements);
248 return success();
249 }
250};
251} // namespace
252
253namespace {
254/// Convert a StructExtractOp to LLVM dialect.
255/// Pattern: struct_extract(input, fieldname) =>
256/// extractvalue(input, fieldname_to_index(fieldname))
/// NOTE(review): the declaration of `fieldIndex` (derived from the HW struct
/// type and op.getFieldIndex()) is missing from this rendering; presumably it
/// applies an endianness index conversion -- confirm against the source.
257struct StructExtractOpConversion
258 : public ConvertOpToLLVMPattern<hw::StructExtractOp> {
259 using ConvertOpToLLVMPattern<hw::StructExtractOp>::ConvertOpToLLVMPattern;
260
261 LogicalResult
262 matchAndRewrite(hw::StructExtractOp op, OpAdaptor adaptor,
263 ConversionPatternRewriter &rewriter) const override {
264
266 op.getInput().getType(), op.getFieldIndex());
267 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(op, adaptor.getInput(),
268 fieldIndex);
269 return success();
270 }
271};
272} // namespace
273
274namespace {
275/// Convert an ArrayInjectOp to the LLVM dialect.
276/// Pattern: array_inject(input, element, index) =>
277/// store(gep(store(input, alloca), zext(index)), element)
278/// load(alloca)
279struct ArrayInjectOpConversion
280 : public HWArrayOpToLLVMPattern<hw::ArrayInjectOp> {
281 using HWArrayOpToLLVMPattern<hw::ArrayInjectOp>::HWArrayOpToLLVMPattern;
282
283 LogicalResult
284 matchAndRewrite(hw::ArrayInjectOp op, OpAdaptor adaptor,
285 ConversionPatternRewriter &rewriter) const override {
286 auto inputType = cast<hw::ArrayType>(op.getInput().getType());
287 auto oldArrTy = adaptor.getInput().getType();
288 auto newArrTy = oldArrTy;
289 const size_t arrElems = inputType.getNumElements();
290
291 if (arrElems == 0) {
292 rewriter.replaceOp(op, adaptor.getInput());
293 return success();
294 }
295
296 auto oneC =
297 LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI32Type(),
298 rewriter.getI32IntegerAttr(1));
299 auto zextIndex = zextByOne(op->getLoc(), rewriter, op.getIndex());
300
301 if (arrElems == 1 || !llvm::isPowerOf2_64(arrElems)) {
302 // Clamp index to prevent OOB access. We add an extra element to the
303 // array so that OOB access modifies this element, leaving the original
304 // array intact.
305 auto maxIndex =
306 LLVM::ConstantOp::create(rewriter, op->getLoc(), zextIndex.getType(),
307 rewriter.getI32IntegerAttr(arrElems));
308 zextIndex =
309 LLVM::UMinOp::create(rewriter, op->getLoc(), zextIndex, maxIndex);
310
311 newArrTy = typeConverter->convertType(
312 hw::ArrayType::get(inputType.getElementType(), arrElems + 1));
313 }
314 auto allocaAlignment = std::max(
315 4u, static_cast<unsigned>(DataLayout::closest(op.getOperation())
316 .getTypePreferredAlignment(newArrTy)));
317 Value arrPtr = LLVM::AllocaOp::create(
318 rewriter, op->getLoc(),
319 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, oneC,
320 allocaAlignment);
321
322 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getInput(), arrPtr);
323
324 auto gep = LLVM::GEPOp::create(
325 rewriter, op->getLoc(),
326 LLVM::LLVMPointerType::get(rewriter.getContext()), newArrTy, arrPtr,
327 ArrayRef<LLVM::GEPArg>{0, zextIndex});
328
329 LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getElement(), gep);
330 auto loadOp =
331 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, oldArrTy, arrPtr);
332 if (spillCacheOpt)
333 spillCacheOpt->map(loadOp, arrPtr);
334 return success();
335 }
336};
337} // namespace
338
339namespace {
340/// Convert an ArrayGetOp to the LLVM dialect.
341/// Pattern: array_get(input, index) =>
342/// load(gep(store(input, alloca), zext(index)))
343struct ArrayGetOpConversion : public HWArrayOpToLLVMPattern<hw::ArrayGetOp> {
344 using HWArrayOpToLLVMPattern<hw::ArrayGetOp>::HWArrayOpToLLVMPattern;
345
346 LogicalResult
347 matchAndRewrite(hw::ArrayGetOp op, OpAdaptor adaptor,
348 ConversionPatternRewriter &rewriter) const override {
349
350 Value arrPtr;
351 if (spillCacheOpt)
352 arrPtr = spillCacheOpt->lookup(adaptor.getInput());
353 if (!arrPtr)
354 arrPtr = spillValueOnStack(rewriter, op.getLoc(), adaptor.getInput());
355
356 auto arrTy = typeConverter->convertType(op.getInput().getType());
357 auto elemTy = typeConverter->convertType(op.getResult().getType());
358 auto zextIndex = zextByOne(op->getLoc(), rewriter, op.getIndex());
359
360 // During the ongoing migration to opaque types, use the constructor that
361 // accepts an element type when the array pointer type is opaque, and
362 // otherwise use the typed pointer constructor.
363 auto gep = LLVM::GEPOp::create(
364 rewriter, op->getLoc(),
365 LLVM::LLVMPointerType::get(rewriter.getContext()), arrTy, arrPtr,
366 ArrayRef<LLVM::GEPArg>{0, zextIndex});
367 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elemTy, gep);
368
369 return success();
370 }
371};
372} // namespace
373
374namespace {
375/// Convert an ArraySliceOp to the LLVM dialect.
376/// Pattern: array_slice(input, lowIndex) =>
377/// load(bitcast(gep(store(input, alloca), zext(lowIndex))))
378struct ArraySliceOpConversion
379 : public HWArrayOpToLLVMPattern<hw::ArraySliceOp> {
380 using HWArrayOpToLLVMPattern<hw::ArraySliceOp>::HWArrayOpToLLVMPattern;
381
382 LogicalResult
383 matchAndRewrite(hw::ArraySliceOp op, OpAdaptor adaptor,
384 ConversionPatternRewriter &rewriter) const override {
385
386 auto dstTy = typeConverter->convertType(op.getDst().getType());
387
388 Value arrPtr;
389 if (spillCacheOpt)
390 arrPtr = spillCacheOpt->lookup(adaptor.getInput());
391 if (!arrPtr)
392 arrPtr = spillValueOnStack(rewriter, op.getLoc(), adaptor.getInput());
393
394 auto zextIndex = zextByOne(op->getLoc(), rewriter, op.getLowIndex());
395
396 // During the ongoing migration to opaque types, use the constructor that
397 // accepts an element type when the array pointer type is opaque, and
398 // otherwise use the typed pointer constructor.
399 auto gep = LLVM::GEPOp::create(
400 rewriter, op->getLoc(),
401 LLVM::LLVMPointerType::get(rewriter.getContext()), dstTy, arrPtr,
402 ArrayRef<LLVM::GEPArg>{0, zextIndex});
403
404 auto loadOp = rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dstTy, gep);
405
406 if (spillCacheOpt)
407 spillCacheOpt->map(loadOp, gep);
408
409 return success();
410 }
411};
412} // namespace
413
414//===----------------------------------------------------------------------===//
415// Insertion operations conversion
416//===----------------------------------------------------------------------===//
417
418namespace {
419/// Convert a StructInjectOp to LLVM dialect.
420/// Pattern: struct_inject(input, index, value) =>
421/// insertvalue(input, value, index)
/// NOTE(review): the declaration of `fieldIndex` (derived from the HW struct
/// type and op.getFieldIndex()) is missing from this rendering; presumably it
/// applies an endianness index conversion -- confirm against the source.
422struct StructInjectOpConversion
423 : public ConvertOpToLLVMPattern<hw::StructInjectOp> {
424 using ConvertOpToLLVMPattern<hw::StructInjectOp>::ConvertOpToLLVMPattern;
425
426 LogicalResult
427 matchAndRewrite(hw::StructInjectOp op, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter) const override {
429
431 op.getInput().getType(), op.getFieldIndex());
432
433 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
434 op, adaptor.getInput(), adaptor.getNewValue(), fieldIndex);
435
436 return success();
437 }
438};
439} // namespace
440
441//===----------------------------------------------------------------------===//
442// Union operations conversion
443//===----------------------------------------------------------------------===//
444//
445// A `!hw.union` is lowered to a flat `!llvm.array<N x i8>` byte buffer that is
446// large enough to hold the LLVM representation of its widest member (see
447// `convertUnionType`). Because the members keep their natural LLVM layout
448// inside that buffer -- including any alignment padding -- the conversion works
449// uniformly for integer and aggregate members alike.
450//
451// Both directions go through a stack slot. Creating a union allocates a buffer,
452// stores the member value into it, and reads the whole buffer back. Extracting
453// a member stores the buffer and reads back only the member's bytes. Bytes not
454// covered by the active member are left undefined, which matches the union
455// semantics. Members are placed at the start of the buffer; the field offset is
456// not honored, as nothing downstream of this lowering interprets the buffer's
457// bit layout.
458
459namespace {
460/// Allocate a stack slot of the given type, suitably aligned to hold a value of
461/// `accessType`, and return the pointer to it.
462static Value allocateUnionBuffer(ConversionPatternRewriter &rewriter,
463 Location loc, Type bufferType,
464 Type accessType) {
465 auto *context = rewriter.getContext();
466 auto align =
467 std::max<uint64_t>(1, DataLayout().getTypePreferredAlignment(accessType));
468 Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
469 rewriter.getI32IntegerAttr(1));
470 return LLVM::AllocaOp::create(rewriter, loc,
471 LLVM::LLVMPointerType::get(context), bufferType,
472 one, align);
473}
474
475/// Convert a UnionCreateOp to the LLVM dialect by storing the member value into
476/// a fresh union buffer and reading the buffer back.
477struct UnionCreateOpConversion
478 : public ConvertOpToLLVMPattern<hw::UnionCreateOp> {
479 using ConvertOpToLLVMPattern<hw::UnionCreateOp>::ConvertOpToLLVMPattern;
480
481 LogicalResult
482 matchAndRewrite(hw::UnionCreateOp op, OpAdaptor adaptor,
483 ConversionPatternRewriter &rewriter) const override {
484 auto loc = op.getLoc();
485 auto bufferType = typeConverter->convertType(op.getType());
486 if (!bufferType)
487 return failure();
488 Value input = adaptor.getInput();
489 Value ptr = allocateUnionBuffer(rewriter, loc, bufferType, input.getType());
490 LLVM::StoreOp::create(rewriter, loc, input, ptr);
491 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, bufferType, ptr);
492 return success();
493 }
494};
495
496/// Convert a UnionExtractOp to the LLVM dialect by storing the union buffer and
497/// reading back the requested member's bytes.
498struct UnionExtractOpConversion
499 : public ConvertOpToLLVMPattern<hw::UnionExtractOp> {
500 using ConvertOpToLLVMPattern<hw::UnionExtractOp>::ConvertOpToLLVMPattern;
501
502 LogicalResult
503 matchAndRewrite(hw::UnionExtractOp op, OpAdaptor adaptor,
504 ConversionPatternRewriter &rewriter) const override {
505 auto loc = op.getLoc();
506 auto memberType = typeConverter->convertType(op.getType());
507 if (!memberType)
508 return failure();
509 Value input = adaptor.getInput();
510 Value ptr = allocateUnionBuffer(rewriter, loc, input.getType(), memberType);
511 LLVM::StoreOp::create(rewriter, loc, input, ptr);
512 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, memberType, ptr);
513 return success();
514 }
515};
516} // namespace
517
518//===----------------------------------------------------------------------===//
519// Concat operations conversion
520//===----------------------------------------------------------------------===//
521
522namespace {
523/// Lower an ArrayConcatOp operation to the LLVM dialect.
524struct ArrayConcatOpConversion
525 : public HWArrayOpToLLVMPattern<hw::ArrayConcatOp> {
526 using HWArrayOpToLLVMPattern<hw::ArrayConcatOp>::HWArrayOpToLLVMPattern;
527
528 LogicalResult
529 matchAndRewrite(hw::ArrayConcatOp op, OpAdaptor adaptor,
530 ConversionPatternRewriter &rewriter) const override {
531
532 hw::ArrayType arrTy = cast<hw::ArrayType>(op.getResult().getType());
533 Type resultTy = typeConverter->convertType(arrTy);
534 auto loc = op.getLoc();
535
536 Value arr = LLVM::UndefOp::create(rewriter, loc, resultTy);
537
538 // Attention: j is hardcoded for little endian machines.
539 size_t j = op.getInputs().size() - 1, k = 0;
540
541 for (size_t i = 0, e = arrTy.getNumElements(); i < e; ++i) {
542 Value element = LLVM::ExtractValueOp::create(rewriter, loc,
543 adaptor.getInputs()[j], k);
544 arr = LLVM::InsertValueOp::create(rewriter, loc, arr, element, i);
545
546 ++k;
547 if (k >=
548 cast<hw::ArrayType>(op.getInputs()[j].getType()).getNumElements()) {
549 k = 0;
550 --j;
551 }
552 }
553
554 rewriter.replaceOp(op, arr);
555
556 // If we've got a cache, spill the array right away.
557 if (spillCacheOpt) {
558 rewriter.setInsertionPointAfter(arr.getDefiningOp());
559 auto ptr = spillValueOnStack(rewriter, loc, arr);
560 spillCacheOpt->map(arr, ptr);
561 }
562 return success();
563 }
564};
565} // namespace
566
567//===----------------------------------------------------------------------===//
568// Value creation conversions
569//===----------------------------------------------------------------------===//
570
571namespace {
572struct HWConstantOpConversion : public ConvertToLLVMPattern {
573 explicit HWConstantOpConversion(MLIRContext *ctx,
574 LLVMTypeConverter &typeConverter)
575 : ConvertToLLVMPattern(hw::ConstantOp::getOperationName(), ctx,
576 typeConverter) {}
577
578 LogicalResult
579 matchAndRewrite(Operation *op, ArrayRef<Value> operand,
580 ConversionPatternRewriter &rewriter) const override {
581 // Get the ConstOp.
582 auto constOp = cast<hw::ConstantOp>(op);
583 // Get the converted llvm type.
584 auto intType = typeConverter->convertType(constOp.getValueAttr().getType());
585 // Replace the operation with an llvm constant op.
586 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, intType,
587 constOp.getValueAttr());
588
589 return success();
590 }
591};
592} // namespace
593
594namespace {
595/// Convert an ArrayCreateOp with dynamic elements to the LLVM dialect. An
596/// equivalent and initialized llvm dialect array type is generated.
/// The result starts as an llvm.mlir.undef of the converted array type.
/// Position `i` is then filled with llvm.insertvalue from the converted input
/// selected by an index computed from `i` and the result type.
/// NOTE(review): part of that index expression is missing from this rendering
/// (presumably an endianness conversion) -- confirm against the source.
597struct HWDynamicArrayCreateOpConversion
598 : public ConvertOpToLLVMPattern<hw::ArrayCreateOp> {
599 using ConvertOpToLLVMPattern<hw::ArrayCreateOp>::ConvertOpToLLVMPattern;
600
601 LogicalResult
602 matchAndRewrite(hw::ArrayCreateOp op, OpAdaptor adaptor,
603 ConversionPatternRewriter &rewriter) const override {
604 auto arrayTy = typeConverter->convertType(op->getResult(0).getType());
605 assert(arrayTy);
606
607 Value arr = LLVM::UndefOp::create(rewriter, op->getLoc(), arrayTy);
608 for (size_t i = 0, e = op.getInputs().size(); i < e; ++i) {
609 Value input =
610 adaptor
612 op.getResult().getType(), i)];
613 arr = LLVM::InsertValueOp::create(rewriter, op->getLoc(), arr, input, i);
614 }
615
616 rewriter.replaceOp(op, arr);
617 return success();
618 }
619};
620} // namespace
621
622namespace {
623
624/// Convert an ArrayCreateOp with constant elements to the LLVM dialect. An
625/// equivalent and initialized llvm dialect array type is generated.
626class AggregateConstantOpConversion
627 : public HWArrayOpToLLVMPattern<hw::AggregateConstantOp> {
628 using HWArrayOpToLLVMPattern<hw::AggregateConstantOp>::HWArrayOpToLLVMPattern;
629
630 bool containsArrayAndStructAggregatesOnly(Type type) const;
631
632 bool isMultiDimArrayOfIntegers(Type type,
633 SmallVectorImpl<int64_t> &dims) const;
634
635 void flatten(Type type, Attribute attr,
636 SmallVectorImpl<Attribute> &output) const;
637
638 Value constructAggregate(OpBuilder &builder,
639 const TypeConverter &typeConverter, Location loc,
640 Type type, Attribute data) const;
641
642public:
643 explicit AggregateConstantOpConversion(
644 LLVMTypeConverter &typeConverter,
645 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
646 &constAggregateGlobalsMap,
647 Namespace &globals, std::optional<HWToLLVMArraySpillCache> &spillCacheOpt)
648 : HWArrayOpToLLVMPattern(typeConverter, spillCacheOpt),
649 constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals) {}
650
651 LogicalResult
652 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
653 ConversionPatternRewriter &rewriter) const override;
654
655private:
656 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
657 &constAggregateGlobalsMap;
658 Namespace &globals;
659};
660} // namespace
661
662namespace {
663/// Convert a StructCreateOp operation to the LLVM dialect. An equivalent and
664/// initialized llvm dialect struct type is generated.
/// The result starts as an llvm.mlir.undef of the converted struct type. Each
/// body position `i` is then filled with llvm.insertvalue from the converted
/// input selected by an index computed from `i` and the result type.
/// NOTE(review): part of that index expression is missing from this rendering
/// (presumably an endianness conversion) -- confirm against the source.
665struct HWStructCreateOpConversion
666 : public ConvertOpToLLVMPattern<hw::StructCreateOp> {
667 using ConvertOpToLLVMPattern<hw::StructCreateOp>::ConvertOpToLLVMPattern;
668
669 LogicalResult
670 matchAndRewrite(hw::StructCreateOp op, OpAdaptor adaptor,
671 ConversionPatternRewriter &rewriter) const override {
672
673 auto resTy = typeConverter->convertType(op.getResult().getType());
674
675 Value tup = LLVM::UndefOp::create(rewriter, op->getLoc(), resTy);
676 for (size_t i = 0, e = cast<LLVM::LLVMStructType>(resTy).getBody().size();
677 i < e; ++i) {
678 Value input =
680 op.getResult().getType(), i)];
681 tup = LLVM::InsertValueOp::create(rewriter, op->getLoc(), tup, input, i);
682 }
683
684 rewriter.replaceOp(op, tup);
685 return success();
686 }
687};
688} // namespace
689
690//===----------------------------------------------------------------------===//
691// Pattern implementations
692//===----------------------------------------------------------------------===//
693
694bool AggregateConstantOpConversion::containsArrayAndStructAggregatesOnly(
695 Type type) const {
696 if (auto intType = dyn_cast<IntegerType>(type))
697 return true;
698
699 if (auto arrTy = dyn_cast<hw::ArrayType>(type))
700 return containsArrayAndStructAggregatesOnly(arrTy.getElementType());
701
702 if (auto structTy = dyn_cast<hw::StructType>(type)) {
703 SmallVector<Type> innerTypes;
704 structTy.getInnerTypes(innerTypes);
705 return llvm::all_of(innerTypes, [&](auto ty) {
706 return containsArrayAndStructAggregatesOnly(ty);
707 });
708 }
709
710 return false;
711}
712
713bool AggregateConstantOpConversion::isMultiDimArrayOfIntegers(
714 Type type, SmallVectorImpl<int64_t> &dims) const {
715 if (auto intType = dyn_cast<IntegerType>(type))
716 return true;
717
718 if (auto arrTy = dyn_cast<hw::ArrayType>(type)) {
719 dims.push_back(arrTy.getNumElements());
720 return isMultiDimArrayOfIntegers(arrTy.getElementType(), dims);
721 }
722
723 return false;
724}
725
/// Append the integer leaf attributes of `attr` to `output`.
///
/// If `type` is an integer type, `attr` is appended as-is (asserted to be an
/// IntegerAttr). Otherwise `type` is treated as an hw.array and `attr` as an
/// ArrayAttr, whose entries are flattened recursively with the array's element
/// type.
/// NOTE(review): the expression choosing the entry visited at step `i` is
/// missing from this rendering (presumably an endianness index conversion) --
/// confirm against the source.
726void AggregateConstantOpConversion::flatten(
727 Type type, Attribute attr, SmallVectorImpl<Attribute> &output) const {
728 if (isa<IntegerType>(type)) {
729 assert(isa<IntegerAttr>(attr));
730 output.push_back(attr);
731 return;
732 }
733
734 auto arrAttr = cast<ArrayAttr>(attr);
735 for (size_t i = 0, e = arrAttr.size(); i < e; ++i) {
736 auto element =
738
739 flatten(cast<hw::ArrayType>(type).getElementType(), element, output);
740 }
741}
742
/// Build the LLVM value of the HW constant `data` of type `type`, creating the
/// ops at `builder`'s insertion point.
///
/// Cases:
///  - Integers become an llvm.mlir.constant of `data`, which must be a
///    TypedAttr.
///  - Arrays and structs start from an llvm.mlir.undef of the converted type.
///    Position `i` receives the recursively built value of `data[currIdx]`.
///    Its HW type is the array element type, or the struct's `currIdx`-th
///    inner type.
/// NOTE(review): the computation of `currIdx` from `i` is missing from this
/// rendering (presumably an endianness index conversion) -- confirm.
743Value AggregateConstantOpConversion::constructAggregate(
744 OpBuilder &builder, const TypeConverter &typeConverter, Location loc,
745 Type type, Attribute data) const {
746 Type llvmType = typeConverter.convertType(type);
747
// Returns the HW type of element `index` of an array or struct type.
748 auto getElementType = [](Type type, size_t index) {
749 if (auto arrTy = dyn_cast<hw::ArrayType>(type)) {
750 return arrTy.getElementType();
751 }
752
753 assert(isa<hw::StructType>(type));
754 auto structTy = cast<hw::StructType>(type);
755 SmallVector<Type> innerTypes;
756 structTy.getInnerTypes(innerTypes);
757 return innerTypes[index];
758 };
759
760 return TypeSwitch<Type, Value>(type)
761 .Case<IntegerType>([&](auto ty) {
762 return LLVM::ConstantOp::create(builder, loc, cast<TypedAttr>(data));
763 })
764 .Case<hw::ArrayType, hw::StructType>([&](auto ty) {
765 Value aggVal = LLVM::UndefOp::create(builder, loc, llvmType);
766 auto arrayAttr = cast<ArrayAttr>(data);
767 for (size_t i = 0, e = arrayAttr.size(); i < e; ++i) {
768 size_t currIdx =
770 Attribute input = arrayAttr[currIdx];
771 Type elementType = getElementType(ty, currIdx);
772
773 Value element = constructAggregate(builder, typeConverter, loc,
774 elementType, input);
775 aggVal =
776 LLVM::InsertValueOp::create(builder, loc, aggVal, element, i);
777 }
778
779 return aggVal;
780 });
781}
782
783LogicalResult AggregateConstantOpConversion::matchAndRewrite(
784 hw::AggregateConstantOp op, OpAdaptor adaptor,
785 ConversionPatternRewriter &rewriter) const {
786 Type aggregateType = op.getResult().getType();
787
788 // TODO: Only arrays and structs supported at the moment.
789 if (!containsArrayAndStructAggregatesOnly(aggregateType))
790 return failure();
791
792 auto llvmTy = typeConverter->convertType(op.getResult().getType());
793 auto typeAttrPair = std::make_pair(aggregateType, adaptor.getFields());
794
795 if (!constAggregateGlobalsMap.count(typeAttrPair) ||
796 !constAggregateGlobalsMap[typeAttrPair]) {
797 auto ipSave = rewriter.saveInsertionPoint();
798
799 Operation *parent = op->getParentOp();
800 while (!isa<mlir::ModuleOp>(parent->getParentOp())) {
801 parent = parent->getParentOp();
802 }
803
804 rewriter.setInsertionPoint(parent);
805
806 // Create a global region for this static array.
807 auto name = globals.newName("_aggregate_const_global");
808
809 SmallVector<int64_t> dims;
810 if (isMultiDimArrayOfIntegers(aggregateType, dims)) {
811 SmallVector<Attribute> ints;
812 flatten(aggregateType, adaptor.getFields(), ints);
813 assert(!ints.empty());
814 auto shapedType = RankedTensorType::get(
815 dims, cast<IntegerAttr>(ints.front()).getType());
816 auto denseAttr = DenseElementsAttr::get(shapedType, ints);
817
818 constAggregateGlobalsMap[typeAttrPair] =
819 LLVM::GlobalOp::create(rewriter, op.getLoc(), llvmTy, true,
820 LLVM::Linkage::Internal, name, denseAttr);
821 } else {
822 auto global =
823 LLVM::GlobalOp::create(rewriter, op.getLoc(), llvmTy, false,
824 LLVM::Linkage::Internal, name, Attribute());
825 Block *blk = new Block();
826 global.getInitializerRegion().push_back(blk);
827 rewriter.setInsertionPointToStart(blk);
828
829 Value aggregate =
830 constructAggregate(rewriter, *typeConverter, op.getLoc(),
831 aggregateType, adaptor.getFields());
832 LLVM::ReturnOp::create(rewriter, op.getLoc(), aggregate);
833 constAggregateGlobalsMap[typeAttrPair] = global;
834 }
835
836 rewriter.restoreInsertionPoint(ipSave);
837 }
838
839 // Get the global array address and load it to return an array value.
840 auto addr = LLVM::AddressOfOp::create(rewriter, op->getLoc(),
841 constAggregateGlobalsMap[typeAttrPair]);
842 auto newOp = rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmTy, addr);
843
844 if (spillCacheOpt && llvm::isa<hw::ArrayType>(aggregateType))
845 spillCacheOpt->map(newOp.getResult(), addr);
846
847 return success();
848}
849
850//===----------------------------------------------------------------------===//
851// Type conversions
852//===----------------------------------------------------------------------===//
853
854static Type convertArrayType(hw::ArrayType type, LLVMTypeConverter &converter) {
855 auto elementTy = converter.convertType(type.getElementType());
856 return LLVM::LLVMArrayType::get(elementTy, type.getNumElements());
857}
858
/// Convert an HW struct to a literal LLVM struct whose body holds the
/// converted inner types of `type`.
/// NOTE(review): the index expression selecting which inner type is converted
/// for body position `i` is missing from this rendering (presumably an
/// endianness conversion) -- confirm against the source.
859static Type convertStructType(hw::StructType type,
860 LLVMTypeConverter &converter) {
861 llvm::SmallVector<Type, 8> elements;
862 mlir::SmallVector<mlir::Type> types;
863 type.getInnerTypes(types);
864
865 for (int i = 0, e = types.size(); i < e; ++i)
866 elements.push_back(converter.convertType(
868
869 return LLVM::LLVMStructType::getLiteral(&converter.getContext(), elements);
870}
871
872/// Convert a union to a flat byte buffer large enough to hold the LLVM
873/// representation of its widest member, including any alignment padding.
874static Type convertUnionType(hw::UnionType type, LLVMTypeConverter &converter) {
875 DataLayout layout;
876 uint64_t maxBytes = 0;
877 for (auto field : type.getElements()) {
878 auto llvmFieldTy = converter.convertType(field.type);
879 if (!llvmFieldTy)
880 return {};
881 maxBytes =
882 std::max(maxBytes, layout.getTypeSize(llvmFieldTy).getFixedValue());
883 }
884 return LLVM::LLVMArrayType::get(IntegerType::get(&converter.getContext(), 8),
885 maxBytes);
886}
887
888//===----------------------------------------------------------------------===//
889// Pass initialization
890//===----------------------------------------------------------------------===//
891
892namespace {
893struct HWToLLVMLoweringPass
894 : public circt::impl::ConvertHWToLLVMBase<HWToLLVMLoweringPass> {
895
896 using circt::impl::ConvertHWToLLVMBase<
897 HWToLLVMLoweringPass>::ConvertHWToLLVMBase;
898
899 void runOnOperation() override;
900};
901} // namespace
902
904 LLVMTypeConverter &converter, RewritePatternSet &patterns,
905 Namespace &globals,
906 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
907 &constAggregateGlobalsMap,
908 std::optional<HWToLLVMArraySpillCache> &spillCacheOpt) {
909 MLIRContext *ctx = converter.getDialect()->getContext();
910
911 // Value creation conversion patterns.
912 patterns.add<HWConstantOpConversion>(ctx, converter);
913 patterns.add<HWDynamicArrayCreateOpConversion, HWStructCreateOpConversion>(
914 converter);
915 patterns.add<AggregateConstantOpConversion>(
916 converter, constAggregateGlobalsMap, globals, spillCacheOpt);
917
918 // Extraction operation conversion patterns.
919 patterns.add<StructExplodeOpConversion, StructExtractOpConversion,
920 StructInjectOpConversion>(converter);
921
922 // Union operation conversion patterns.
923 patterns.add<UnionCreateOpConversion, UnionExtractOpConversion>(converter);
924
925 patterns.add<ArrayGetOpConversion, ArrayInjectOpConversion,
926 ArraySliceOpConversion, ArrayConcatOpConversion>(converter,
927 spillCacheOpt);
928}
929
930void circt::populateHWToLLVMTypeConversions(LLVMTypeConverter &converter) {
931 converter.addConversion(
932 [&](hw::ArrayType arr) { return convertArrayType(arr, converter); });
933 converter.addConversion(
934 [&](hw::StructType tup) { return convertStructType(tup, converter); });
935 converter.addConversion(
936 [&](hw::UnionType uni) { return convertUnionType(uni, converter); });
937}
938
939void HWToLLVMLoweringPass::runOnOperation() {
940 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
941 std::optional<HWToLLVMArraySpillCache> spillCacheOpt = {};
942 Namespace globals;
943 SymbolCache cache;
944 cache.addDefinitions(getOperation());
945 globals.add(cache);
946
947 RewritePatternSet patterns(&getContext());
948 auto converter = mlir::LLVMTypeConverter(&getContext());
950
951 if (spillArraysEarly) {
952 spillCacheOpt = HWToLLVMArraySpillCache();
953 OpBuilder spillBuilder(getOperation());
954 spillCacheOpt->spillNonHWOps(spillBuilder, converter, getOperation());
955 }
956
957 LLVMConversionTarget target(getContext());
958 target.addIllegalDialect<hw::HWDialect>();
959 // Don't touch non-HW operations
960 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
961
962 // Rewrite the aggregate HW types carried by function signatures and the
963 // `return`/`call` ops that wire values through them, so that no leftover
964 // values of HW type remain once the HW ops have been lowered. The op
965 // legality is keyed on the type converter.
966 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
967 return converter.isSignatureLegal(op.getFunctionType()) &&
968 converter.isLegal(&op.getBody());
969 });
970 target.addDynamicallyLegalOp<func::ReturnOp, func::CallOp>(
971 [&](Operation *op) { return converter.isLegal(op); });
972
973 // Setup the conversion.
975 constAggregateGlobalsMap, spillCacheOpt);
976 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
977 converter);
978 populateReturnOpTypeConversionPattern(patterns, converter);
979 populateCallOpTypeConversionPattern(patterns, converter);
980
981 // Apply the partial conversion.
982 ConversionConfig config;
983 config.allowPatternRollback = false;
984 if (failed(applyPartialConversion(getOperation(), target, std::move(patterns),
985 config)))
986 return signalPassFailure();
987
988 // Reconcile the temporary `unrealized_conversion_cast` ops the conversion
989 // left behind -- in particular the `hw -> llvm -> hw` round-trips that the
990 // framework materializes around spilled values whose signature was
991 // converted. This keeps the pass self-contained instead of relying on a
992 // downstream reconcile pass; genuine boundary casts are left in place.
993 SmallVector<UnrealizedConversionCastOp> castOps;
994 getOperation()->walk(
995 [&](UnrealizedConversionCastOp op) { castOps.push_back(op); });
996 reconcileUnrealizedCasts(castOps, /*remainingCastOps=*/nullptr);
997}
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static std::unique_ptr< Context > context
static Type convertStructType(hw::StructType type, LLVMTypeConverter &converter)
Definition HWToLLVM.cpp:859
static Type convertUnionType(hw::UnionType type, LLVMTypeConverter &converter)
Convert a union to a flat byte buffer large enough to hold the LLVM representation of its widest member, including any alignment padding.
Definition HWToLLVM.cpp:874
static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter, Value value)
Create a zext operation by one bit on the given value.
Definition HWToLLVM.cpp:78
static Type convertArrayType(hw::ArrayType type, LLVMTypeConverter &converter)
Definition HWToLLVM.cpp:854
static Value spillValueOnStack(OpBuilder &builder, Location loc, Value spillVal)
Definition HWToLLVM.cpp:90
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
Get the HW to LLVM type conversions.
void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap, std::optional< HWToLLVMArraySpillCache > &spillCacheOpt)
Get the HW to LLVM conversion patterns.
Definition hw.py:1
Helper class mapping array values (HW or LLVM Dialect) to pointers to buffers containing the array va...
Definition HWToLLVM.h:47
Value spillHWArrayValue(OpBuilder &builder, Location loc, mlir::LLVMTypeConverter &converter, Value hwArray)
Definition HWToLLVM.cpp:185
Value lookup(Value arrayValue)
Retrieve a pointer to a buffer containing the given array value (HW or LLVM Dialect).
Definition HWToLLVM.cpp:152
void spillNonHWOps(mlir::OpBuilder &builder, mlir::LLVMTypeConverter &converter, Operation *containerOp)
Spill HW array values produced by 'foreign' dialects on the stack.
Definition HWToLLVM.cpp:110
void map(mlir::Value arrayValue, mlir::Value bufferPtr)
Map an LLVM array value to an LLVM pointer.
Definition HWToLLVM.cpp:144
Value spillLLVMArrayValue(OpBuilder &builder, Location loc, Value llvmArray)
Definition HWToLLVM.cpp:170
llvm::DenseMap< Value, Value > spillMap
Definition HWToLLVM.h:70
static uint32_t convertToLLVMEndianess(Type type, uint32_t index)
Convert an index into a HW ArrayType or StructType to LLVM endianness.
Definition HWToLLVM.cpp:41
static uint32_t llvmIndexOfStructField(hw::StructType type, StringRef fieldName)
Get the index of a specific StructType field in the LLVM lowering of the StructType.
Definition HWToLLVM.cpp:53