CIRCT 23.0.0git
Loading...
Searching...
No Matches
HWConvertBitcasts.cpp
Go to the documentation of this file.
1//===- ConvertBitcasts.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/AttrTypeSubElements.h"
13#include "mlir/IR/Iterators.h"
14#include "mlir/Pass/Pass.h"
15
16#define DEBUG_TYPE "hw-convert-bitcasts"
17
18namespace circt {
19namespace hw {
20#define GEN_PASS_DEF_HWCONVERTBITCASTS
21#include "circt/Dialect/HW/Passes.h.inc"
22} // namespace hw
23} // namespace circt
24
25using namespace mlir;
26using namespace circt;
27using namespace hw;
28
29namespace {
30struct HWConvertBitcastsPass
31 : circt::hw::impl::HWConvertBitcastsBase<HWConvertBitcastsPass> {
32 using circt::hw::impl::HWConvertBitcastsBase<
33 HWConvertBitcastsPass>::HWConvertBitcastsBase;
34
35 void runOnOperation() override;
36 static bool isTypeSupported(Type ty);
37 LogicalResult convertBitcastOp(OpBuilder builder, BitcastOp bitcastOp);
38};
39} // namespace
40
41// Array conversion: Lower bits correspond to lower array index
42// Struct conversion: Higher bits correspond to lower field index
43
44// NOLINTNEXTLINE(misc-no-recursion)
45bool HWConvertBitcastsPass::isTypeSupported(Type ty) {
46 if (isa<IntegerType>(ty))
47 return true;
48 if (auto arrayTy = hw::type_dyn_cast<hw::ArrayType>(ty))
49 return isTypeSupported(arrayTy.getElementType());
50 if (auto structTy = hw::type_dyn_cast<hw::StructType>(ty))
51 return llvm::all_of(structTy.getElements(),
52 [](StructType::FieldInfo field) {
53 return isTypeSupported(field.type);
54 });
55 // TODO: Add support for: union, packed array, enum
56 return false;
57}
58
59// Collect integer fields of aggregates for concatenation
60// NOLINTNEXTLINE(misc-no-recursion)
61static void collectIntegersRecursively(OpBuilder builder, Location loc,
62 Value inputVal,
63 SmallVectorImpl<Value> &accumulator) {
64 // End of recursion: Integer value
65 if (isa<IntegerType>(inputVal.getType())) {
66 accumulator.push_back(inputVal);
67 return;
68 }
69
70 auto numBits = getBitWidth(inputVal.getType());
71 assert(numBits >= 0 && "Bitwidth of input must be known");
72
73 // Array Type
74 if (auto arrayTy = dyn_cast<ArrayType>(inputVal.getType())) {
75 unsigned numElements = arrayTy.getNumElements();
76 // Avoid creating zero-width indices to dodge lowering issues
77 auto indexType =
78 builder.getIntegerType(std::max(1u, llvm::Log2_64_Ceil(numElements)));
79 for (unsigned i = 0; i < numElements; ++i) {
80 // Process from high to low array index
81 auto indexCst = ConstantOp::create(
82 builder, loc, builder.getIntegerAttr(indexType, numElements - i - 1));
83 auto getOp = ArrayGetOp::create(builder, loc, inputVal, indexCst);
84 collectIntegersRecursively(builder, loc, getOp.getResult(), accumulator);
85 }
86 return;
87 }
88
89 // Struct Type
90 if (auto structTy = dyn_cast<StructType>(inputVal.getType())) {
91 // Process from low to high field index
92 auto explodeOp = StructExplodeOp::create(builder, loc, inputVal);
93 for (auto elt : explodeOp.getResults())
94 collectIntegersRecursively(builder, loc, elt, accumulator);
95 return;
96 }
97
98 assert(false && "Unsupported type");
99}
100
101// Convert integer to aggregate
102// NOLINTNEXTLINE(misc-no-recursion)
103static Value constructAggregateRecursively(OpBuilder builder, Location loc,
104 Value rawInteger, Type targetType) {
105 auto numBits = getBitWidth(targetType);
106 assert(numBits >= 0 && "Bitwidth of target must be known");
107 assert(numBits == rawInteger.getType().getIntOrFloatBitWidth());
108
109 // End of recursion: Integer value
110 if (isa<IntegerType>(targetType))
111 return rawInteger;
112
113 SmallVector<Value> elements;
114
115 // Array Type
116 if (auto arrayTy = type_dyn_cast<ArrayType>(targetType)) {
117 auto numElements = arrayTy.getNumElements();
118 auto sliceWidth = getBitWidth(arrayTy.getElementType());
119 assert(sliceWidth >= 0);
120 auto sliceTy = builder.getIntegerType(sliceWidth);
121 elements.reserve(numElements);
122 for (unsigned i = 0; i < numElements; ++i) {
123 // Process bits from MSB to LSB
124 auto offset = sliceWidth * (numElements - i - 1);
125 Value slice;
126 if (sliceWidth == 0)
127 slice = ConstantOp::create(builder, loc,
128 builder.getIntegerAttr(sliceTy, 0));
129 else
130 slice =
131 comb::ExtractOp::create(builder, loc, sliceTy, rawInteger, offset);
132 auto elt = constructAggregateRecursively(builder, loc, slice,
133 arrayTy.getElementType());
134 elements.push_back(elt);
135 }
136 // Array create reverses order:
137 // Higher bits are added first and become high indices.
138 return ArrayCreateOp::create(builder, loc, targetType, elements);
139 }
140
141 // Struct Type
142 if (auto structTy = type_dyn_cast<StructType>(targetType)) {
143 auto numElements = structTy.getElements().size();
144 unsigned consumedBits = 0;
145 for (unsigned i = 0; i < numElements; ++i) {
146
147 auto eltBits = getBitWidth(structTy.getElements()[i].type);
148 assert(eltBits >= 0);
149 Value slice;
150 // Process bits from MSB to LSB
151 if (eltBits == 0)
152 slice = ConstantOp::create(
153 builder, loc, builder.getIntegerAttr(builder.getIntegerType(0), 0));
154 else
156 builder, loc, builder.getIntegerType(eltBits), rawInteger,
157 numBits - consumedBits - eltBits);
158 auto elt = constructAggregateRecursively(builder, loc, slice,
159 structTy.getElements()[i].type);
160 elements.push_back(elt);
161 consumedBits += eltBits;
162 }
163 assert(consumedBits == numBits);
164 // Struct create does not reverse order:
165 // Higher bits are added first and become low indices.
166 return StructCreateOp::create(builder, loc, targetType, elements);
167 }
168
169 assert(false && "Unsupported type");
170 return {};
171}
172
173LogicalResult HWConvertBitcastsPass::convertBitcastOp(OpBuilder builder,
174 BitcastOp bitcastOp) {
175 bool inputSupported = isTypeSupported(bitcastOp.getInput().getType());
176 bool outputSupported = isTypeSupported(bitcastOp.getType());
177 if (!allowPartialConversion) {
178 if (!inputSupported)
179 bitcastOp.emitOpError("has unsupported input type");
180 if (!outputSupported)
181 bitcastOp.emitOpError("has unsupported output type");
182 }
183 if (!(inputSupported && outputSupported))
184 return failure();
185
186 builder.setInsertionPoint(bitcastOp);
187
188 // Convert input value to a packed integer
189 SmallVector<Value> integers;
190 collectIntegersRecursively(builder, bitcastOp.getLoc(), bitcastOp.getInput(),
191 integers);
192 Value concat;
193 if (integers.size() == 1)
194 concat = integers.front();
195 else
196 concat = comb::ConcatOp::create(builder, bitcastOp.getLoc(), integers)
197 .getResult();
198
199 // Convert packed integer to the target type
200 auto result = constructAggregateRecursively(builder, bitcastOp.getLoc(),
201 concat, bitcastOp.getType());
202
203 // Replace operation
204 bitcastOp.getResult().replaceAllUsesWith(result);
205 bitcastOp.erase();
206 return success();
207}
208
209void HWConvertBitcastsPass::runOnOperation() {
210 OpBuilder builder(getOperation());
211 bool anyFailed = false;
212 bool anyChanged = false;
213 getOperation().getBody()->walk<WalkOrder::PostOrder, ReverseIterator>(
214 [&](BitcastOp bitcastOp) {
215 anyChanged = true;
216 if (failed(convertBitcastOp(builder, bitcastOp)))
217 anyFailed = true;
218 });
219
220 if (!anyChanged) {
221 markAllAnalysesPreserved();
222 return;
223 }
224
225 if (anyFailed && !allowPartialConversion)
226 signalPassFailure();
227}
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
static void collectIntegersRecursively(OpBuilder builder, Location loc, Value inputVal, SmallVectorImpl< Value > &accumulator)
static Value constructAggregateRecursively(OpBuilder builder, Location loc, Value rawInteger, Type targetType)
create(low_bit, result_type, input=None)
Definition comb.py:187
create(elements, Type result_type=None)
Definition hw.py:483
create(array_value, idx)
Definition hw.py:450
create(data_type, value)
Definition hw.py:433
create(elements, Type result_type=None)
Definition hw.py:544
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1