CIRCT 23.0.0git
Loading...
Searching...
No Matches
RemoveI0Types.cpp
Go to the documentation of this file.
1//===- RemoveI0Types.cpp --------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "circt/Support/LLVM.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/Value.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Support/LogicalResult.h"
22#include "mlir/Transforms/DialectConversion.h"
23
24namespace circt {
25namespace arc {
26#define GEN_PASS_DEF_REMOVEI0TYPES
27#include "circt/Dialect/Arc/ArcPasses.h.inc"
28} // namespace arc
29} // namespace circt
30
31using namespace mlir;
32using namespace circt;
33using namespace arc;
34
35namespace {
36struct RemoveI0TypesPass
37 : public arc::impl::RemoveI0TypesBase<RemoveI0TypesPass> {
38 using RemoveI0TypesBase::RemoveI0TypesBase;
39 void runOnOperation() override;
40};
41
42bool isI0(Type type) {
43 auto intType = dyn_cast<IntegerType>(type);
44 return intType && intType.getWidth() == 0;
45}
46
47// Flattens a list of list of values into a list of values.
48SmallVector<Value> flatten(ArrayRef<ValueRange> ranges) {
49 SmallVector<Value> flat;
50 for (auto range : ranges) {
51 flat.insert(flat.end(), range.begin(), range.end());
52 }
53 return flat;
54}
55
56// Generic pattern for ops that are legalizable by flattening their remaining
57// operands after type conversion.
58template <typename T>
59struct LegalizeGeneric : public OpConversionPattern<T> {
61 using OneToNOpAdaptor = typename OpConversionPattern<T>::OneToNOpAdaptor;
62
63 LogicalResult
64 matchAndRewrite(T op, OneToNOpAdaptor adaptor,
65 ConversionPatternRewriter &rewriter) const override {
66 const TypeConverter &converter = *this->getTypeConverter();
67 auto result = convertOpResultTypes(op, flatten(adaptor.getOperands()),
68 converter, rewriter);
69 if (failed(result))
70 return failure();
71
72 // Map from old results to new results, assuming the results size may have
73 // changed.
74 Operation *newOp = *result;
75 auto newOpResultIt = newOp->result_begin();
76 SmallVector<Value> results;
77 for (auto oldType : op->getResultTypes()) {
78 if (!converter.convertType(oldType)) {
79 results.push_back(nullptr);
80 } else {
81 results.push_back(*newOpResultIt++);
82 }
83 }
84 assert(newOpResultIt == newOp->result_end() && "Didn't map all results!");
85 rewriter.replaceOp(op, results);
86 return success();
87 }
88};
89
90// As above, but if any converted operand is empty (i.e. the operand was i0
91// initially) or the number of results after conversion changes, then the op is
92// erased. This is used for all ops where we don't explicitly know the legality
93// of removing operands from its operand list.
94struct ConvertGeneric : public ConversionPattern {
95 ConvertGeneric(TypeConverter &converter, MLIRContext *context)
96 : ConversionPattern(converter, MatchAnyOpTypeTag{},
97 /*benefit=*/0, context) {}
98
99 LogicalResult
100 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
101 ConversionPatternRewriter &rewriter) const override {
102 const TypeConverter &converter = *getTypeConverter();
103
104 // If any operand wasn't converted (empty range), the op must be dead.
105 for (ValueRange range : operands) {
106 if (range.empty()) {
107 rewriter.eraseOp(op);
108 return success();
109 }
110 }
111
112 // If any result wasn't converted, the op must be dead.
113 SmallVector<Type> resultTypes;
114 if (failed(converter.convertTypes(op->getResultTypes(), resultTypes)))
115 return failure();
116 if (resultTypes.size() != op->getNumResults()) {
117 rewriter.eraseOp(op);
118 return success();
119 }
120
121 auto result =
122 convertOpResultTypes(op, flatten(operands), converter, rewriter);
123 if (failed(result))
124 return failure();
125 rewriter.replaceOp(op, *result);
126 return success();
127 }
128};
129
130// An array_get with i0 index just returns the array input, which will have been
131// scalarized.
132struct ConvertArrayGet : public OpConversionPattern<hw::ArrayGetOp> {
134 LogicalResult
135 matchAndRewrite(hw::ArrayGetOp op, OneToNOpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter) const override {
137 if (adaptor.getIndex().empty()) {
138 assert(adaptor.getInput().size() == 1);
139 rewriter.replaceOp(op, adaptor.getInput().front());
140 return success();
141 }
142 // Handled by ConvertGeneric.
143 return failure();
144 }
145};
146
147// Replaces array_create of a single element with the element.
148struct ConvertArrayCreate : public OpConversionPattern<hw::ArrayCreateOp> {
150 LogicalResult
151 matchAndRewrite(hw::ArrayCreateOp op, OneToNOpAdaptor adaptor,
152 ConversionPatternRewriter &rewriter) const override {
153 if (adaptor.getInputs().size() == 1) {
154 rewriter.replaceOp(op, adaptor.getInputs().front());
155 return success();
156 }
157 // Handled by ConvertGeneric.
158 return failure();
159 }
160};
161
162// Converts array_inject with i0 index to just return the element.
163struct ConvertArrayInject : public OpConversionPattern<hw::ArrayInjectOp> {
164 using OpConversionPattern<hw::ArrayInjectOp>::OpConversionPattern;
165 LogicalResult
166 matchAndRewrite(hw::ArrayInjectOp op, OneToNOpAdaptor adaptor,
167 ConversionPatternRewriter &rewriter) const override {
168 if (adaptor.getIndex().empty()) {
169 rewriter.replaceOp(op, adaptor.getElement());
170 return success();
171 }
172 // Handled by ConvertGeneric.
173 return failure();
174 }
175};
176
177// Converts an aggregate_constant by recursively rewriting its attribute.
178struct ConvertAggregateConstant
179 : public OpConversionPattern<hw::AggregateConstantOp> {
180 using OpConversionPattern<hw::AggregateConstantOp>::OpConversionPattern;
181 LogicalResult
182 matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
183 ConversionPatternRewriter &rewriter) const override {
184 Type resultType = getTypeConverter()->convertType(op.getResult().getType());
185
186 // Recursively rewrite the attribute.
187 Attribute newFields =
188 rewriteArrayAttr(op.getFields(), op.getResult().getType());
189
190 if (!isa<ArrayAttr>(newFields)) {
191 // Scalar result becomes hw.constant.
192 IntegerAttr attr = cast<IntegerAttr>(newFields);
193 auto result =
194 hw::ConstantOp::create(rewriter, op.getLoc(), resultType, attr);
195 rewriter.replaceOp(op, result);
196 return success();
197 }
198
199 // Composite result becomes a new aggregate_constant op.
200 auto newOp = hw::AggregateConstantOp::create(
201 rewriter, op.getLoc(), resultType, cast<ArrayAttr>(newFields));
202 rewriter.replaceOp(op, newOp);
203 return success();
204 }
205
206 // NOLINTNEXTLINE(misc-no-recursion): Bounded recursion.
207 Attribute rewriteArrayAttr(ArrayAttr array, Type type) const {
208 if (getTypeConverter()->convertType(type) == type)
209 return array;
210 if (auto arrayType = dyn_cast<hw::ArrayType>(type);
211 arrayType && arrayType.getNumElements() == 1) {
212 return *array.begin();
213 }
214
215 // Collect the immediate subtypes. FieldIDTypeInterface is supported by
216 // ArrayType, UnpackedArrayType, StructType, UnionType.
217 auto fieldIdInterface = cast<hw::FieldIDTypeInterface>(type);
218 SmallVector<Attribute> attrs;
219 for (auto [index, attr] : llvm::enumerate(array)) {
220 uint64_t fieldId = fieldIdInterface.getFieldID(index);
221 Type subType = fieldIdInterface.getSubTypeByFieldID(fieldId).first;
222 if (auto subArrayAttr = dyn_cast<ArrayAttr>(attr))
223 attrs.push_back(rewriteArrayAttr(subArrayAttr, subType));
224 else
225 attrs.push_back(attr);
226 }
227 return ArrayAttr::get(array.getContext(), attrs);
228 }
229};
230
231} // namespace
232
233void RemoveI0TypesPass::runOnOperation() {
234 TypeConverter converter;
235 ConversionTarget target(getContext());
236 RewritePatternSet patterns(&getContext());
237
238 // The conversions for types are 1:N, where N may be 1 or 0.
239 converter.addConversion([](Type type, SmallVectorImpl<Type> &types) {
240 if (!isI0(type))
241 types.push_back(type);
242 return success();
243 });
244
245 // Composite types - recursively apply type conversion to inner types.
246 converter.addConversion([&converter](hw::ArrayType type,
247 SmallVectorImpl<Type> &types) {
248 // If the array has only one element, we replace the array with the element.
249 if (type.getNumElements() == 1) {
250 if (Type converted = converter.convertType(type.getElementType()))
251 types.push_back(converted);
252 return success();
253 }
254 // Recursively apply type conversion to inner types.
255 types.push_back(hw::ArrayType::get(
256 converter.convertType(type.getElementType()), type.getNumElements()));
257 return success();
258 });
259 converter.addConversion([&converter](hw::StructType type) -> Type {
260 SmallVector<hw::StructType::FieldInfo> newMembers;
261 // Convert and filter out any i0 fields.
262 for (auto &field : type.getElements()) {
263 SmallVector<Type> convertedTypes;
264 if (failed(converter.convertType(field.type, convertedTypes)))
265 return Type();
266 if (!convertedTypes.empty()) {
267 assert(convertedTypes.size() == 1);
268 newMembers.push_back({field.name, convertedTypes[0]});
269 }
270 }
271 return hw::StructType::get(type.getContext(), newMembers);
272 });
273 converter.addConversion([&converter](hw::UnionType type) -> Type {
274 SmallVector<hw::UnionType::FieldInfo> newMembers;
275 // Convert and filter out any i0 fields.
276 for (auto &field : type.getElements()) {
277 SmallVector<Type> convertedTypes;
278 if (failed(converter.convertType(field.type, convertedTypes)))
279 return Type();
280 if (!convertedTypes.empty()) {
281 assert(convertedTypes.size() == 1);
282 newMembers.push_back({field.name, convertedTypes[0], field.offset});
283 }
284 }
285 return hw::UnionType::get(type.getContext(), newMembers);
286 });
287 converter.addConversion(
288 [&converter](hw::TypeAliasType type, SmallVectorImpl<Type> &types) {
289 return converter.convertType(type.getCanonicalType(), types);
290 });
291 converter.addConversion(
292 [&converter](arc::StateType type, SmallVectorImpl<Type> &types) {
293 if (failed(converter.convertType(type.getType(), types)))
294 return failure();
295 assert(types.size() == 1);
296 types[0] = arc::StateType::get(types[0]);
297 return success();
298 });
299
300 target.markUnknownOpDynamicallyLegal(
301 [&](Operation *op) { return converter.isLegal(op); });
302 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp func) {
303 FunctionType fty = func.getFunctionType();
304 return converter.isLegal(fty.getInputs()) &&
305 converter.isLegal(fty.getResults());
306 });
307
308 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
309 converter);
310
311 patterns.add<LegalizeGeneric<func::ReturnOp>, LegalizeGeneric<func::CallOp>,
312 LegalizeGeneric<hw::StructCreateOp>, ConvertArrayGet,
313 ConvertArrayCreate, ConvertArrayInject, ConvertGeneric,
314 ConvertAggregateConstant>(converter, &getContext());
315 ConversionConfig config;
316 config.allowPatternRollback = false;
317 if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
318 config)))
319 return signalPassFailure();
320}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static FIRRTLBaseType convertType(FIRRTLBaseType type)
Returns null type if no conversion is needed.
Definition DropConst.cpp:32
create(data_type, value)
Definition hw.py:433
Definition arc.py:1
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.