12#include "mlir/Pass/Pass.h"
13#include "mlir/Transforms/DialectConversion.h"
14#include "llvm/ADT/APInt.h"
18#define GEN_PASS_DEF_HWAGGREGATETOCOMB
19#include "circt/Dialect/HW/Passes.h.inc"
// Lowers an array-building op to a single comb.concat of its converted
// (adaptor) inputs. Instantiated for hw::ArrayCreateOp and hw::ArrayConcatOp
// in the pattern registration list.
29template <
typename OpTy>
34 matchAndRewrite(OpTy op, OpAdaptor adaptor,
35 ConversionPatternRewriter &rewriter)
const override {
36 rewriter.replaceOpWithNewOp<
comb::ConcatOp>(op, adaptor.getInputs());
// Lowers hw.aggregate_constant by flattening its (possibly nested) field
// attributes into a single APInt as wide as the aggregate's total bit width.
41struct HWAggregateConstantOpConversion
// Walks `attr` depth-first using an explicit worklist. ArrayAttr elements are
// pushed in reverse so they pop in their original order. Each IntegerAttr leaf
// is written into `intVal` starting from the most-significant end:
// nextInsertion counts down from the full width.
45 static LogicalResult peelAttribute(Location loc, Attribute attr,
46 ConversionPatternRewriter &rewriter,
48 SmallVector<Attribute> worklist;
49 worklist.push_back(attr);
50 unsigned nextInsertion = intVal.getBitWidth();
52 while (!worklist.empty()) {
53 auto current = worklist.pop_back_val();
54 if (
auto innerArray = dyn_cast<ArrayAttr>(current)) {
55 for (
auto elem :
llvm::reverse(innerArray))
56 worklist.push_back(elem);
60 if (
auto intAttr = dyn_cast<IntegerAttr>(current)) {
61 auto chunk = intAttr.getValue();
// NOTE(review): nextInsertion is unsigned. If the leaf widths sum to more
// than intVal's width, this subtraction wraps instead of failing. Confirm
// that the attribute structure always matches the aggregate width.
62 nextInsertion -= chunk.getBitWidth();
63 intVal.insertBits(chunk, nextInsertion);
74 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
75 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): `results` is not referenced in the lines visible here.
// It may be dead; check the full body.
77 SmallVector<Value> results;
78 auto bitWidth = hw::getBitWidth(op.getType());
// A negative width means "unknown". A constant's width must be known;
// this is enforced only in assert-enabled builds.
79 assert(bitWidth >= 0 &&
"bit width must be known for constant");
80 APInt intVal(bitWidth, 0);
81 if (failed(peelAttribute(op.getLoc(), adaptor.getFieldsAttr(), rewriter,
// hw.array_get lowering on the flattened integer form of the array.
// A constant, in-range index becomes a single extract at
// index * elemWidth. Otherwise, per-element extracts (the loop header is not
// visible here) are collected into `results`, and comb::constructMuxTree
// selects among them using the index's individual bits.
94 ConversionPatternRewriter &rewriter)
const override {
95 SmallVector<Value> results;
96 auto arrayType = cast<hw::ArrayType>(op.getInput().getType());
97 auto elemType = arrayType.getElementType();
99 auto elemWidth = hw::getBitWidth(elemType);
101 return rewriter.notifyMatchFailure(op.getLoc(),
"unknown element width");
103 auto lowered = adaptor.getInput();
104 auto index = adaptor.getIndex();
106 if (matchPattern(index, m_ConstantInt(&constantIndex))) {
// Bound the constant index so that index * elemWidth stays within int32
// range.
// NOTE(review): the width guard above is only partially visible. If
// elemWidth can be 0 (a zero-width element type), this division traps.
// Confirm that the guard rejects width 0.
107 int64_t maxIndex = std::numeric_limits<int32_t>::max() / elemWidth;
108 if (constantIndex.isSingleWord() &&
109 constantIndex.getZExtValue() <=
static_cast<uint64_t
>(maxIndex)) {
111 op, lowered, constantIndex.getZExtValue() * elemWidth, elemWidth);
118 op.getLoc(), lowered, i * elemWidth, elemWidth));
// Dynamic (or out-of-range constant) index: split it into bits to drive
// the mux tree over the extracted elements.
120 SmallVector<Value> bits;
121 comb::extractBits(rewriter, index, bits);
122 auto result = comb::constructMuxTree(rewriter, op.getLoc(), bits, results,
125 rewriter.replaceOp(op, result);
// hw.array_inject lowering. For every possible injection index, build one
// candidate row: the original elements with that slot replaced by the
// injected element. Rows are collected in `arrayRows`, and the op is finally
// replaced by `arrayGetOp`. That value's construction is not visible here;
// presumably it selects a row by the inject index (confirm).
134 matchAndRewrite(hw::ArrayInjectOp op, OpAdaptor adaptor,
135 ConversionPatternRewriter &rewriter)
const override {
136 auto arrayType = cast<hw::ArrayType>(op.getInput().getType());
137 auto elemType = arrayType.getElementType();
139 auto elemWidth = hw::getBitWidth(elemType);
141 return rewriter.notifyMatchFailure(op.getLoc(),
"unknown element width");
143 Location loc = op.getLoc();
// Slice the flattened input into its individual elements.
146 SmallVector<Value> originalElements;
147 auto inputArray = adaptor.getInput();
150 loc, inputArray, i * elemWidth, elemWidth));
155 SmallVector<Value> arrayRows;
// Rows are built from the highest injection index down.
// NOTE(review): `int` counters are initialized from numElements - 1. If
// numElements is a wide unsigned count, this narrows. Confirm that arrays
// reaching this pattern are small enough.
157 for (
int injectIdx =
numElements - 1; injectIdx >= 0; --injectIdx) {
158 SmallVector<Value> rowElements;
// Within a row, elements are also emitted highest index first. The slot at
// injectIdx takes the new element; other slots take
// originalElements[originalIdx].
163 for (
int originalIdx =
numElements - 1; originalIdx >= 0; --originalIdx) {
164 if (originalIdx == injectIdx) {
165 rowElements.push_back(adaptor.getElement());
167 rowElements.push_back(originalElements[originalIdx]);
173 arrayRows.push_back(row);
185 rewriter.replaceOp(op, arrayGetOp);
// Struct creation lowers to a comb.concat of the converted field values.
195 ConversionPatternRewriter &rewriter)
const override {
198 rewriter.replaceOpWithNewOp<
comb::ConcatOp>(op, adaptor.getInput());
// hw.struct_extract lowering on the flattened integer form. Field 0 occupies
// the most-significant bits. The requested field's LSB offset is therefore
// the total width, minus the widths of all preceding fields, minus the
// field's own width.
208 ConversionPatternRewriter &rewriter)
const override {
209 auto structType = cast<hw::StructType>(op.getInput().getType());
210 auto fieldIndex = op.getFieldIndex();
211 auto elements = structType.getElements();
213 int64_t totalBitWidth = hw::getBitWidth(structType);
214 if (totalBitWidth < 0)
215 return rewriter.notifyMatchFailure(op.getLoc(),
"unknown struct width");
// Sum the widths of the fields that precede `fieldIndex`.
219 int64_t consumedBits = 0;
220 for (
size_t i = 0; i < fieldIndex; ++i) {
221 int64_t fieldWidth = hw::getBitWidth(elements[i].type);
223 "must be failed before if field width is unknown");
224 consumedBits += fieldWidth;
227 int64_t fieldWidth = hw::getBitWidth(elements[fieldIndex].type);
229 "must be failed before if field width is unknown");
// Convert the MSB-relative position into an LSB-relative extract offset.
232 int64_t bitOffset = totalBitWidth - consumedBits - fieldWidth;
234 bitOffset, fieldWidth);
// Rebuilds the mux on its converted operands. The pass target treats
// comb.mux as legal only when its type is an integer type, so muxes over
// aggregates are routed through this pattern.
244 ConversionPatternRewriter &rewriter)
const override {
247 op, adaptor.getCond(), adaptor.getTrueValue(), adaptor.getFalseValue());
// Type converter mapping hw arrays and structs to a plain integer of the
// aggregate's total bit width. All other types are left unchanged.
// Materializations bridge between the two representations; their bodies are
// only partially visible here.
254class AggregateTypeConverter :
public TypeConverter {
256 AggregateTypeConverter() {
// Identity fallback. MLIR tries the most recently added conversion first,
// so the array/struct conversions registered below take precedence.
257 addConversion([](Type type) -> Type {
return type; });
// NOTE(review): elsewhere in this file a negative hw::getBitWidth result
// means "unknown width". Here it is passed to IntegerType::get unchecked.
// Confirm that aggregates reaching this converter always have a known width.
258 addConversion([](hw::ArrayType t) -> Type {
259 return IntegerType::get(t.getContext(), hw::getBitWidth(t));
261 addConversion([](hw::StructType t) -> Type {
262 return IntegerType::get(t.getContext(), hw::getBitWidth(t));
264 addTargetMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
265 mlir::ValueRange inputs,
266 mlir::Location loc) -> mlir::Value {
// Only a single input value is supported. The handling of other input
// counts is not visible here.
267 if (inputs.size() != 1)
274 addSourceMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
275 mlir::ValueRange inputs,
276 mlir::Location loc) -> mlir::Value {
// Only a single input value is supported. The handling of other input
// counts is not visible here.
277 if (inputs.size() != 1)
// Registers every aggregate-lowering pattern. Each pattern is constructed
// with the shared type converter and the pattern set's MLIR context.
288 RewritePatternSet &
patterns, AggregateTypeConverter &typeConverter) {
290 HWArrayGetOpConversion, HWArrayCreateLikeOpConversion<hw::ArrayCreateOp>,
291 HWArrayCreateLikeOpConversion<hw::ArrayConcatOp>,
292 HWAggregateConstantOpConversion, HWArrayInjectOpConversion,
293 HWStructCreateOpConversion, HWStructExtractOpConversion, MuxOpConversion>(
294 typeConverter,
patterns.getContext());
// The HWAggregateToComb pass. Its base class is generated from the pass
// definition pulled in via GEN_PASS_DEF_HWAGGREGATETOCOMB, and the base
// constructors are inherited.
298struct HWAggregateToCombPass
299 :
public hw::impl::HWAggregateToCombBase<HWAggregateToCombPass> {
300 void runOnOperation()
override;
301 using HWAggregateToCombBase<HWAggregateToCombPass>::HWAggregateToCombBase;
// Configures the conversion target and runs a partial dialect conversion:
// - The aggregate ops listed below are presumably marked illegal; the exact
//   marking call is not visible here.
// - comb.mux is legal only on integer types.
// - The hw and comb dialects are otherwise legal.
305void HWAggregateToCombPass::runOnOperation() {
306 ConversionTarget target(getContext());
310 hw::AggregateConstantOp, hw::ArrayInjectOp,
313 [](
comb::MuxOp op) {
return hw::type_isa<IntegerType>(op.getType()); });
314 target.addLegalDialect<hw::HWDialect, comb::CombDialect>();
316 RewritePatternSet
patterns(&getContext());
317 AggregateTypeConverter typeConverter;
// Partial conversion lets ops not marked illegal remain. If any illegal op
// cannot be legalized, the conversion fails and the pass is marked failed.
320 if (failed(mlir::applyPartialConversion(getOperation(), target,
322 return signalPassFailure();
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
static void populateHWAggregateToCombOpConversionPatterns(RewritePatternSet &patterns, AggregateTypeConverter &typeConverter)
create(elements, Type result_type=None)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.