CIRCT 23.0.0git
Loading...
Searching...
No Matches
HWAggregateToComb.cpp
Go to the documentation of this file.
1//===- HWAggregateToComb.cpp - HW aggregate to comb -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Pass/Pass.h"
13#include "mlir/Transforms/DialectConversion.h"
14#include "llvm/ADT/APInt.h"
15
16namespace circt {
17namespace hw {
18#define GEN_PASS_DEF_HWAGGREGATETOCOMB
19#include "circt/Dialect/HW/Passes.h.inc"
20} // namespace hw
21} // namespace circt
22
23using namespace mlir;
24using namespace circt;
25
26namespace {
27
28// Lower hw.array_create and hw.array_concat to comb.concat.
29template <typename OpTy>
30struct HWArrayCreateLikeOpConversion : OpConversionPattern<OpTy> {
32 using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
33 LogicalResult
34 matchAndRewrite(OpTy op, OpAdaptor adaptor,
35 ConversionPatternRewriter &rewriter) const override {
36 rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, adaptor.getInputs());
37 return success();
38 }
39};
40
41struct HWAggregateConstantOpConversion
42 : OpConversionPattern<hw::AggregateConstantOp> {
43 using OpConversionPattern<hw::AggregateConstantOp>::OpConversionPattern;
44
45 static LogicalResult peelAttribute(Location loc, Attribute attr,
46 ConversionPatternRewriter &rewriter,
47 APInt &intVal) {
48 SmallVector<Attribute> worklist;
49 worklist.push_back(attr);
50 unsigned nextInsertion = intVal.getBitWidth();
51
52 while (!worklist.empty()) {
53 auto current = worklist.pop_back_val();
54 if (auto innerArray = dyn_cast<ArrayAttr>(current)) {
55 for (auto elem : llvm::reverse(innerArray))
56 worklist.push_back(elem);
57 continue;
58 }
59
60 if (auto intAttr = dyn_cast<IntegerAttr>(current)) {
61 auto chunk = intAttr.getValue();
62 nextInsertion -= chunk.getBitWidth();
63 intVal.insertBits(chunk, nextInsertion);
64 continue;
65 }
66
67 return failure();
68 }
69
70 return success();
71 }
72
73 LogicalResult
74 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
75 ConversionPatternRewriter &rewriter) const override {
76 // Lower to concat.
77 SmallVector<Value> results;
78 auto bitWidth = hw::getBitWidth(op.getType());
79 assert(bitWidth >= 0 && "bit width must be known for constant");
80 APInt intVal(bitWidth, 0);
81 if (failed(peelAttribute(op.getLoc(), adaptor.getFieldsAttr(), rewriter,
82 intVal)))
83 return failure();
84 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, intVal);
85 return success();
86 }
87};
88
89struct HWArrayGetOpConversion : OpConversionPattern<hw::ArrayGetOp> {
91
92 LogicalResult
93 matchAndRewrite(hw::ArrayGetOp op, OpAdaptor adaptor,
94 ConversionPatternRewriter &rewriter) const override {
95 SmallVector<Value> results;
96 auto arrayType = cast<hw::ArrayType>(op.getInput().getType());
97 auto elemType = arrayType.getElementType();
98 auto numElements = arrayType.getNumElements();
99 auto elemWidth = hw::getBitWidth(elemType);
100 if (elemWidth < 0)
101 return rewriter.notifyMatchFailure(op.getLoc(), "unknown element width");
102
103 auto lowered = adaptor.getInput();
104 auto index = adaptor.getIndex();
105 APInt constantIndex;
106 if (matchPattern(index, m_ConstantInt(&constantIndex))) {
107 int64_t maxIndex = std::numeric_limits<int32_t>::max() / elemWidth;
108 if (constantIndex.isSingleWord() &&
109 constantIndex.getZExtValue() <= static_cast<uint64_t>(maxIndex)) {
110 rewriter.replaceOpWithNewOp<comb::ExtractOp>(
111 op, lowered, constantIndex.getZExtValue() * elemWidth, elemWidth);
112 return success();
113 }
114 }
115
116 for (size_t i = 0; i < numElements; ++i)
117 results.push_back(rewriter.createOrFold<comb::ExtractOp>(
118 op.getLoc(), lowered, i * elemWidth, elemWidth));
119
120 SmallVector<Value> bits;
121 comb::extractBits(rewriter, index, bits);
122 auto result = comb::constructMuxTree(rewriter, op.getLoc(), bits, results,
123 results.back());
124
125 rewriter.replaceOp(op, result);
126 return success();
127 }
128};
129
130struct HWArrayInjectOpConversion : OpConversionPattern<hw::ArrayInjectOp> {
131 using OpConversionPattern<hw::ArrayInjectOp>::OpConversionPattern;
132
133 LogicalResult
134 matchAndRewrite(hw::ArrayInjectOp op, OpAdaptor adaptor,
135 ConversionPatternRewriter &rewriter) const override {
136 auto arrayType = cast<hw::ArrayType>(op.getInput().getType());
137 auto elemType = arrayType.getElementType();
138 auto numElements = arrayType.getNumElements();
139 auto elemWidth = hw::getBitWidth(elemType);
140 if (elemWidth < 0)
141 return rewriter.notifyMatchFailure(op.getLoc(), "unknown element width");
142
143 Location loc = op.getLoc();
144
145 // Extract all elements from the input array
146 SmallVector<Value> originalElements;
147 auto inputArray = adaptor.getInput();
148 for (size_t i = 0; i < numElements; ++i) {
149 originalElements.push_back(rewriter.createOrFold<comb::ExtractOp>(
150 loc, inputArray, i * elemWidth, elemWidth));
151 }
152
153 // Create 2D array: each row represents what the array would look like
154 // if injection happened at that specific index
155 SmallVector<Value> arrayRows;
156 arrayRows.reserve(numElements);
157 for (int injectIdx = numElements - 1; injectIdx >= 0; --injectIdx) {
158 SmallVector<Value> rowElements;
159 rowElements.reserve(numElements);
160
161 // Build the row: array[n-1], array[n-2], ..., but replace element at
162 // injectIdx with newVal
163 for (int originalIdx = numElements - 1; originalIdx >= 0; --originalIdx) {
164 if (originalIdx == injectIdx) {
165 rowElements.push_back(adaptor.getElement());
166 } else {
167 rowElements.push_back(originalElements[originalIdx]);
168 }
169 }
170
171 // Concatenate elements to form this row
172 Value row = hw::ArrayCreateOp::create(rewriter, loc, rowElements);
173 arrayRows.push_back(row);
174 }
175
176 // Create the 2D array by concatenating all rows
177 // arrayRows[0] corresponds to injection at index 0
178 // arrayRows[1] corresponds to injection at index 1, etc.
179 Value array2D = hw::ArrayCreateOp::create(rewriter, loc, arrayRows);
180
181 // Create array_get operation to select the row
182 auto arrayGetOp =
183 hw::ArrayGetOp::create(rewriter, loc, array2D, adaptor.getIndex());
184
185 rewriter.replaceOp(op, arrayGetOp);
186 return success();
187 }
188};
189
190struct HWStructCreateOpConversion : OpConversionPattern<hw::StructCreateOp> {
192
193 LogicalResult
194 matchAndRewrite(hw::StructCreateOp op, OpAdaptor adaptor,
195 ConversionPatternRewriter &rewriter) const override {
196 // Lower struct_create to comb.concat. The first field occupies the MSBs, so
197 // we concatenate fields in order (comb.concat places first operand at MSB).
198 rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, adaptor.getInput());
199 return success();
200 }
201};
202
203struct HWStructExtractOpConversion : OpConversionPattern<hw::StructExtractOp> {
205
206 LogicalResult
207 matchAndRewrite(hw::StructExtractOp op, OpAdaptor adaptor,
208 ConversionPatternRewriter &rewriter) const override {
209 auto structType = cast<hw::StructType>(op.getInput().getType());
210 auto fieldIndex = op.getFieldIndex();
211 auto elements = structType.getElements();
212
213 int64_t totalBitWidth = hw::getBitWidth(structType);
214 if (totalBitWidth < 0)
215 return rewriter.notifyMatchFailure(op.getLoc(), "unknown struct width");
216
217 // Compute the bit offset from the MSB by summing the widths of all
218 // preceding fields. The first field occupies the MSBs.
219 int64_t consumedBits = 0;
220 for (size_t i = 0; i < fieldIndex; ++i) {
221 int64_t fieldWidth = hw::getBitWidth(elements[i].type);
222 assert(fieldWidth >= 0 &&
223 "must be failed before if field width is unknown");
224 consumedBits += fieldWidth;
225 }
226
227 int64_t fieldWidth = hw::getBitWidth(elements[fieldIndex].type);
228 assert(fieldWidth >= 0 &&
229 "must be failed before if field width is unknown");
230
231 // Extract the field using comb.extract. Offset is from LSB.
232 int64_t bitOffset = totalBitWidth - consumedBits - fieldWidth;
233 rewriter.replaceOpWithNewOp<comb::ExtractOp>(op, adaptor.getInput(),
234 bitOffset, fieldWidth);
235 return success();
236 }
237};
238
239struct MuxOpConversion : OpConversionPattern<comb::MuxOp> {
241
242 LogicalResult
243 matchAndRewrite(comb::MuxOp op, OpAdaptor adaptor,
244 ConversionPatternRewriter &rewriter) const override {
245 // Re-create Mux with legalized types.
246 rewriter.replaceOpWithNewOp<comb::MuxOp>(
247 op, adaptor.getCond(), adaptor.getTrueValue(), adaptor.getFalseValue());
248 return success();
249 }
250};
251
252/// A type converter is needed to perform the in-flight materialization of
253/// aggregate types to integer types.
254class AggregateTypeConverter : public TypeConverter {
255public:
256 AggregateTypeConverter() {
257 addConversion([](Type type) -> Type { return type; });
258 addConversion([](hw::ArrayType t) -> Type {
259 return IntegerType::get(t.getContext(), hw::getBitWidth(t));
260 });
261 addConversion([](hw::StructType t) -> Type {
262 return IntegerType::get(t.getContext(), hw::getBitWidth(t));
263 });
264 addTargetMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
265 mlir::ValueRange inputs,
266 mlir::Location loc) -> mlir::Value {
267 if (inputs.size() != 1)
268 return Value();
269
270 return hw::BitcastOp::create(builder, loc, resultType, inputs[0])
271 ->getResult(0);
272 });
273
274 addSourceMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
275 mlir::ValueRange inputs,
276 mlir::Location loc) -> mlir::Value {
277 if (inputs.size() != 1)
278 return Value();
279
280 return hw::BitcastOp::create(builder, loc, resultType, inputs[0])
281 ->getResult(0);
282 });
283 }
284};
285} // namespace
286
288 RewritePatternSet &patterns, AggregateTypeConverter &typeConverter) {
289 patterns.add<
290 HWArrayGetOpConversion, HWArrayCreateLikeOpConversion<hw::ArrayCreateOp>,
291 HWArrayCreateLikeOpConversion<hw::ArrayConcatOp>,
292 HWAggregateConstantOpConversion, HWArrayInjectOpConversion,
293 HWStructCreateOpConversion, HWStructExtractOpConversion, MuxOpConversion>(
294 typeConverter, patterns.getContext());
295}
296
297namespace {
298struct HWAggregateToCombPass
299 : public hw::impl::HWAggregateToCombBase<HWAggregateToCombPass> {
300 void runOnOperation() override;
301 using HWAggregateToCombBase<HWAggregateToCombPass>::HWAggregateToCombBase;
302};
303} // namespace
304
/// Run a partial dialect conversion over the current operation that rewrites
/// aggregate hw ops into comb/hw integer ops; signals pass failure if the
/// conversion fails.
305void HWAggregateToCombPass::runOnOperation() {
306 ConversionTarget target(getContext());
307
308 // TODO: Add ArraySliceOp and struct operations as well.
// NOTE(review): the opening and closing lines of the illegal-op list
// (original lines 309 and 311) are not visible here; only its middle line is.
310 hw::AggregateConstantOp, hw::ArrayInjectOp,
// comb.mux is legal only when its result type is already an integer; muxes
// over aggregates must be rebuilt on converted types by MuxOpConversion.
312 target.addDynamicallyLegalOp<comb::MuxOp>(
313 [](comb::MuxOp op) { return hw::type_isa<IntegerType>(op.getType()); });
// Remaining hw and comb ops are legal.
314 target.addLegalDialect<hw::HWDialect, comb::CombDialect>();
315
316 RewritePatternSet patterns(&getContext());
317 AggregateTypeConverter typeConverter;
// NOTE(review): original line 318 is not visible here; presumably it calls
// populateHWAggregateToCombOpConversionPatterns(patterns, typeConverter).
319
// Partial conversion: ops not marked illegal may remain in the IR.
320 if (failed(mlir::applyPartialConversion(getOperation(), target,
321 std::move(patterns))))
322 return signalPassFailure();
323}
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
static void populateHWAggregateToCombOpConversionPatterns(RewritePatternSet &patterns, AggregateTypeConverter &typeConverter)
create(elements, Type result_type=None)
Definition hw.py:483
create(array_value, idx)
Definition hw.py:450
create(data_type, value)
Definition hw.py:441
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1