CIRCT 20.0.0git
Loading...
Searching...
No Matches
HWAggregateToComb.cpp
Go to the documentation of this file.
1//===- HWAggregateToComb.cpp - HW aggregate to comb -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Pass/Pass.h"
13#include "mlir/Transforms/DialectConversion.h"
14#include "llvm/ADT/APInt.h"
15
16namespace circt {
17namespace hw {
18#define GEN_PASS_DEF_HWAGGREGATETOCOMB
19#include "circt/Dialect/HW/Passes.h.inc"
20} // namespace hw
21} // namespace circt
22
23using namespace mlir;
24using namespace circt;
25
26namespace {
27
28// Lower hw.array_create and hw.array_concat to comb.concat.
29template <typename OpTy>
30struct HWArrayCreateLikeOpConversion : OpConversionPattern<OpTy> {
32 using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
33 LogicalResult
34 matchAndRewrite(OpTy op, OpAdaptor adaptor,
35 ConversionPatternRewriter &rewriter) const override {
36 rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, adaptor.getInputs());
37 return success();
38 }
39};
40
41struct HWAggregateConstantOpConversion
42 : OpConversionPattern<hw::AggregateConstantOp> {
43 using OpConversionPattern<hw::AggregateConstantOp>::OpConversionPattern;
44
45 static LogicalResult peelAttribute(Location loc, Attribute attr,
46 ConversionPatternRewriter &rewriter,
47 APInt &intVal) {
48 SmallVector<Attribute> worklist;
49 worklist.push_back(attr);
50 unsigned nextInsertion = intVal.getBitWidth();
51
52 while (!worklist.empty()) {
53 auto current = worklist.pop_back_val();
54 if (auto innerArray = dyn_cast<ArrayAttr>(current)) {
55 for (auto elem : llvm::reverse(innerArray))
56 worklist.push_back(elem);
57 continue;
58 }
59
60 if (auto intAttr = dyn_cast<IntegerAttr>(current)) {
61 auto chunk = intAttr.getValue();
62 nextInsertion -= chunk.getBitWidth();
63 intVal.insertBits(chunk, nextInsertion);
64 continue;
65 }
66
67 return failure();
68 }
69
70 return success();
71 }
72
73 LogicalResult
74 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
75 ConversionPatternRewriter &rewriter) const override {
76 // Lower to concat.
77 SmallVector<Value> results;
78 auto bitWidth = hw::getBitWidth(op.getType());
79 assert(bitWidth >= 0 && "bit width must be known for constant");
80 APInt intVal(bitWidth, 0);
81 if (failed(peelAttribute(op.getLoc(), adaptor.getFieldsAttr(), rewriter,
82 intVal)))
83 return failure();
84 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, intVal);
85 return success();
86 }
87};
88
89struct HWArrayGetOpConversion : OpConversionPattern<hw::ArrayGetOp> {
91
92 LogicalResult
93 matchAndRewrite(hw::ArrayGetOp op, OpAdaptor adaptor,
94 ConversionPatternRewriter &rewriter) const override {
95 SmallVector<Value> results;
96 auto arrayType = cast<hw::ArrayType>(op.getInput().getType());
97 auto elemType = arrayType.getElementType();
98 auto numElements = arrayType.getNumElements();
99 auto elemWidth = hw::getBitWidth(elemType);
100 if (elemWidth < 0)
101 return rewriter.notifyMatchFailure(op.getLoc(), "unknown element width");
102
103 auto lowered = adaptor.getInput();
104 for (size_t i = 0; i < numElements; ++i)
105 results.push_back(rewriter.createOrFold<comb::ExtractOp>(
106 op.getLoc(), lowered, i * elemWidth, elemWidth));
107
108 SmallVector<Value> bits;
109 comb::extractBits(rewriter, op.getIndex(), bits);
110 auto result = comb::constructMuxTree(rewriter, op.getLoc(), bits, results,
111 results.back());
112
113 rewriter.replaceOp(op, result);
114 return success();
115 }
116};
117
118/// A type converter is needed to perform the in-flight materialization of
119/// aggregate types to integer types.
120class AggregateTypeConverter : public TypeConverter {
121public:
122 AggregateTypeConverter() {
123 addConversion([](Type type) -> Type { return type; });
124 addConversion([](hw::ArrayType t) -> Type {
125 return IntegerType::get(t.getContext(), hw::getBitWidth(t));
126 });
127 addTargetMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
128 mlir::ValueRange inputs,
129 mlir::Location loc) -> mlir::Value {
130 if (inputs.size() != 1)
131 return Value();
132
133 return builder.create<hw::BitcastOp>(loc, resultType, inputs[0])
134 ->getResult(0);
135 });
136
137 addSourceMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
138 mlir::ValueRange inputs,
139 mlir::Location loc) -> mlir::Value {
140 if (inputs.size() != 1)
141 return Value();
142
143 return builder.create<hw::BitcastOp>(loc, resultType, inputs[0])
144 ->getResult(0);
145 });
146 }
147};
148} // namespace
149
    RewritePatternSet &patterns, AggregateTypeConverter &typeConverter) {
  // Register every aggregate-lowering pattern in this file with the shared
  // type converter, so array-typed operands reach the patterns as integers
  // and the converter's hw.bitcast materializations bridge any remaining
  // aggregate-typed uses.
  patterns.add<HWArrayGetOpConversion,
               HWArrayCreateLikeOpConversion<hw::ArrayCreateOp>,
               HWArrayCreateLikeOpConversion<hw::ArrayConcatOp>,
               HWAggregateConstantOpConversion>(typeConverter,
                                                patterns.getContext());
}
158
159namespace {
160struct HWAggregateToCombPass
161 : public hw::impl::HWAggregateToCombBase<HWAggregateToCombPass> {
162 void runOnOperation() override;
163 using HWAggregateToCombBase<HWAggregateToCombPass>::HWAggregateToCombBase;
164};
165} // namespace
166
167void HWAggregateToCombPass::runOnOperation() {
168 ConversionTarget target(getContext());
169
170 // TODO: Add ArraySliceOp and struct operatons as well.
172 hw::AggregateConstantOp>();
173
174 target.addLegalDialect<hw::HWDialect, comb::CombDialect>();
175
176 RewritePatternSet patterns(&getContext());
177 AggregateTypeConverter typeConverter;
179
180 if (failed(mlir::applyPartialConversion(getOperation(), target,
181 std::move(patterns))))
182 return signalPassFailure();
183}
184
  // Hand the caller a freshly constructed instance of the pass defined above;
  // ownership transfers through the returned unique_ptr.
  return std::make_unique<HWAggregateToCombPass>();
}
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
static void populateHWAggregateToCombOpConversionPatterns(RewritePatternSet &patterns, AggregateTypeConverter &typeConverter)
std::unique_ptr< mlir::Pass > createHWAggregateToCombPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1