12#include "mlir/IR/AttrTypeSubElements.h"
13#include "mlir/IR/Iterators.h"
14#include "mlir/Pass/Pass.h"
16#define DEBUG_TYPE "hw-convert-bitcasts"
20#define GEN_PASS_DEF_HWCONVERTBITCASTS
21#include "circt/Dialect/HW/Passes.h.inc"
30struct HWConvertBitcastsPass
31 : circt::hw::impl::HWConvertBitcastsBase<HWConvertBitcastsPass> {
32 using circt::hw::impl::HWConvertBitcastsBase<
33 HWConvertBitcastsPass>::HWConvertBitcastsBase;
35 void runOnOperation()
override;
36 static bool isTypeSupported(Type ty);
37 LogicalResult convertBitcastOp(OpBuilder builder,
BitcastOp bitcastOp);
45bool HWConvertBitcastsPass::isTypeSupported(Type ty) {
46 if (isa<IntegerType>(ty))
48 if (
auto arrayTy = hw::type_dyn_cast<hw::ArrayType>(ty))
49 return isTypeSupported(arrayTy.getElementType());
50 if (
auto structTy = hw::type_dyn_cast<hw::StructType>(ty))
51 return llvm::all_of(structTy.getElements(),
52 [](StructType::FieldInfo field) {
53 return isTypeSupported(field.type);
63 SmallVectorImpl<Value> &accumulator) {
65 if (isa<IntegerType>(inputVal.getType())) {
66 accumulator.push_back(inputVal);
71 assert(numBits >= 0 &&
"Bitwidth of input must be known");
74 if (
auto arrayTy = dyn_cast<ArrayType>(inputVal.getType())) {
78 builder.getIntegerType(std::max(1u, llvm::Log2_64_Ceil(
numElements)));
82 builder, loc, builder.getIntegerAttr(indexType,
numElements - i - 1));
90 if (
auto structTy = dyn_cast<StructType>(inputVal.getType())) {
92 auto explodeOp = StructExplodeOp::create(builder, loc, inputVal);
93 for (
auto elt : explodeOp.getResults())
98 assert(
false &&
"Unsupported type");
104 Value rawInteger, Type targetType) {
106 assert(numBits >= 0 &&
"Bitwidth of target must be known");
107 assert(numBits == rawInteger.getType().getIntOrFloatBitWidth());
110 if (isa<IntegerType>(targetType))
113 SmallVector<Value> elements;
116 if (
auto arrayTy = type_dyn_cast<ArrayType>(targetType)) {
118 auto sliceWidth =
getBitWidth(arrayTy.getElementType());
120 auto sliceTy = builder.getIntegerType(sliceWidth);
128 builder.getIntegerAttr(sliceTy, 0));
133 arrayTy.getElementType());
134 elements.push_back(elt);
142 if (
auto structTy = type_dyn_cast<StructType>(targetType)) {
144 unsigned consumedBits = 0;
147 auto eltBits =
getBitWidth(structTy.getElements()[i].type);
153 builder, loc, builder.getIntegerAttr(builder.getIntegerType(0), 0));
156 builder, loc, builder.getIntegerType(eltBits), rawInteger,
157 numBits - consumedBits - eltBits);
159 structTy.getElements()[i].type);
160 elements.push_back(elt);
161 consumedBits += eltBits;
163 assert(consumedBits == numBits);
169 assert(
false &&
"Unsupported type");
173LogicalResult HWConvertBitcastsPass::convertBitcastOp(OpBuilder builder,
175 bool inputSupported = isTypeSupported(bitcastOp.getInput().getType());
176 bool outputSupported = isTypeSupported(bitcastOp.getType());
177 if (!allowPartialConversion) {
179 bitcastOp.emitOpError(
"has unsupported input type");
180 if (!outputSupported)
181 bitcastOp.emitOpError(
"has unsupported output type");
183 if (!(inputSupported && outputSupported))
186 builder.setInsertionPoint(bitcastOp);
189 SmallVector<Value> integers;
193 if (integers.size() == 1)
194 concat = integers.front();
196 concat = comb::ConcatOp::create(builder, bitcastOp.getLoc(), integers)
201 concat, bitcastOp.getType());
204 bitcastOp.getResult().replaceAllUsesWith(result);
209void HWConvertBitcastsPass::runOnOperation() {
210 OpBuilder builder(getOperation());
211 bool anyFailed =
false;
212 bool anyChanged =
false;
213 getOperation().getBody()->walk<WalkOrder::PostOrder, ReverseIterator>(
216 if (failed(convertBitcastOp(builder, bitcastOp)))
221 markAllAnalysesPreserved();
225 if (anyFailed && !allowPartialConversion)
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
static void collectIntegersRecursively(OpBuilder builder, Location loc, Value inputVal, SmallVectorImpl< Value > &accumulator)
static Value constructAggregateRecursively(OpBuilder builder, Location loc, Value rawInteger, Type targetType)
create(elements, Type result_type=None)
create(elements, Type result_type=None)
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.