18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinOps.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/Location.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Support/LLVM.h"
27#include "mlir/Support/LogicalResult.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#define DEBUG_TYPE "arc-lower-arrays"
35#define GEN_PASS_DEF_LOWERARRAYS
36#include "circt/Dialect/Arc/ArcPasses.h.inc"
44using ::llvm::enumerate;
/// Pass implementation built on the generated `arc::impl::LowerArraysBase`
/// base class. The actual work happens in `runOnOperation()`.
51struct LowerArraysPass :
public arc::impl::LowerArraysBase<LowerArraysPass> {
52 LowerArraysPass() =
default;
// Copy construction delegates to the default constructor; nothing is copied
// from `pass`.
53 LowerArraysPass(
const LowerArraysPass &pass) : LowerArraysPass() {}
55 void runOnOperation()
override;
/// Emits an `arith::IndexCastUIOp` whose result has the builder's index type.
/// The new op takes the location of the op defining `value` when there is
/// one, and an unknown location otherwise.
58Value asIndex(Value value, OpBuilder &builder) {
59 Location loc = builder.getUnknownLoc();
60 if (Operation *parent = value.getDefiningOp()) {
61 loc = parent->getLoc();
// NOTE(review): the trailing arguments of this call are not visible in this
// view; presumably `value` is the operand being cast -- confirm.
63 return arith::IndexCastUIOp::create(builder, loc, builder.getIndexType(),
/// Creates a fresh `ArrayRefAllocOp` of the same type as `value` (with an
/// empty initializer) and returns the result of an `ArrayRefCopyOp` created
/// with the new allocation and `value` as its two operands.
67Value cloneArrayRef(Value value, OpBuilder &builder, Location loc) {
68 Value newAlloc = ArrayRefAllocOp::create(builder, loc, value.getType(), {});
69 return ArrayRefCopyOp::create(builder, loc, newAlloc, value);
/// Conversion pattern for `func::FuncOp`: rewrites the function signature and
/// body block argument types through the type converter.
/// NOTE(review): the enclosing struct header and several interior lines
/// (failure returns, closing braces) are not visible in this view.
73 using OpConversionPattern::OpConversionPattern;
76 matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
77 ConversionPatternRewriter &rewriter)
const override {
78 const TypeConverter &converter = *getTypeConverter();
79 TypeConverter::SignatureConversion conversion(op.getNumArguments());
81 SmallVector<Type> newArgTypes;
82 SmallVector<Type> newResultTypes;
// The terminator lookup below relies on the body having exactly one block.
83 assert(op.getBody().getBlocks().size() == 1);
// For every array-typed value returned by the terminator, add one extra
// input of the converted type. These are pushed before the converted
// original argument types, so they end up at the front of the argument list.
86 Operation *ret = op.getBody().front().getTerminator();
87 for (Value result : ret->getOperands()) {
88 if (isa<ArrayType>(result.getType())) {
89 Type newType = converter.convertType(result.getType());
90 conversion.addInputs(newType);
91 newArgTypes.push_back(newType);
// Append the converted original argument types and collect the converted
// result types.
95 if (failed(converter.convertTypes(op.getArgumentTypes(), newArgTypes)) ||
96 failed(converter.convertTypes(op.getResultTypes(), newResultTypes))) {
// Convert the entry block's argument types according to the signature
// conversion.
100 if (failed(converter.convertSignatureArgs(op.getArgumentTypes(),
102 failed(rewriter.convertRegionTypes(&op.getBody(), converter,
// Install the new function type on the existing op.
107 rewriter.modifyOpInPlace(op, [&] {
108 op.setType(FunctionType::get(getContext(), newArgTypes, newResultTypes));
/// Conversion pattern for `func::ReturnOp`: each returned operand of
/// `ArrayRefType` is routed through an `ArrayRefCopyOp` that involves one of
/// the parent function's arguments; other operands are passed through.
/// NOTE(review): the enclosing struct header, the declaration of `sretIndex`
/// and the declaration of `copy` are not visible in this view.
116 using OpConversionPattern::OpConversionPattern;
119 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
120 ConversionPatternRewriter &rewriter)
const override {
121 func::FuncOp func = op->getParentOfType<func::FuncOp>();
123 SmallVector<Value> newOperands;
124 for (Value operand : adaptor.getOperands()) {
125 if (isa<ArrayRefType>(operand.getType())) {
// Take the next function argument in order, one per array-ref operand.
126 Value arg = func.getArgument(sretIndex++);
128 ArrayRefCopyOp::create(rewriter, op.getLoc(), arg, operand);
129 newOperands.push_back(copy);
131 newOperands.push_back(operand);
// Update the existing return op's operands rather than creating a new op.
134 rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); });
/// Conversion pattern for `func::CallOp`: creates a replacement call whose
/// operand list starts with one freshly allocated array ref per array-typed
/// result, followed by the converted original operands.
/// NOTE(review): the enclosing struct header, the handling of non-array
/// results, and the failure paths are not visible in this view.
140 using OpConversionPattern::OpConversionPattern;
143 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
144 ConversionPatternRewriter &rewriter)
const override {
145 SmallVector<Value> newOperands;
146 for (Type resultType : op.getResultTypes()) {
147 auto arrayType = dyn_cast<ArrayType>(resultType);
// Build the array-ref type with the same element type and element count,
// and allocate storage for it with an empty initializer.
150 auto arrayRefType = ArrayRefType::get(arrayType.getElementType(),
151 arrayType.getNumElements());
153 ArrayRefAllocOp::create(rewriter, op.getLoc(), arrayRefType, {});
154 newOperands.push_back(alloc);
// The already-converted original operands follow the allocations.
156 for (Value operand : adaptor.getOperands()) {
157 newOperands.push_back(operand);
160 SmallVector<Type> resultTypes;
161 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
166 auto newCall = func::CallOp::create(rewriter, op.getLoc(), op.getCallee(),
167 resultTypes, newOperands);
// Carry over the discardable attributes of the original call.
168 newCall->setDiscardableAttrs(op->getDiscardableAttrDictionary());
170 rewriter.replaceOp(op, newCall);
/// Conversion pattern for `AggregateConstantOp`: replaces the op with an
/// `ArrayRefAllocOp` of the converted (array-ref) result type.
/// NOTE(review): the base-class list and the trailing arguments of the alloc
/// call are not visible in this view.
175struct ConvertAggregateConstant
177 using OpConversionPattern::OpConversionPattern;
180 matchAndRewrite(AggregateConstantOp op, OpAdaptor adaptor,
181 ConversionPatternRewriter &rewriter)
const override {
// `cast` asserts that the converted type is an ArrayRefType.
183 cast<ArrayRefType>(getTypeConverter()->
convertType(op.getType()));
184 Value newOp = ArrayRefAllocOp::create(rewriter, op.getLoc(), newType,
186 rewriter.replaceOp(op, newOp);
/// Conversion pattern that replaces the matched op with an `ArrayRefGetOp`
/// reading from the converted input at the index cast to the index type.
/// NOTE(review): the struct header and the first line of the signature (with
/// the matched op type) are not visible in this view.
192 using OpConversionPattern::OpConversionPattern;
196 ConversionPatternRewriter &rewriter)
const override {
197 Type resultType = getTypeConverter()->convertType(op.getType());
// The index comes from the original op (not the adaptor) and is cast to the
// index type.
198 Value index = asIndex(op.getIndex(), rewriter);
199 Value newOp = ArrayRefGetOp::create(rewriter, op.getLoc(), resultType,
200 adaptor.getInput(), index);
201 rewriter.replaceOp(op, newOp);
/// Conversion pattern for `ArrayInjectOp`: clones the converted input array
/// ref and creates an `ArrayRefInjectOp` on the clone.
/// NOTE(review): the enclosing struct header and the declaration of `newOp`
/// are not visible in this view.
207 using OpConversionPattern::OpConversionPattern;
210 matchAndRewrite(ArrayInjectOp op, OpAdaptor adaptor,
211 ConversionPatternRewriter &rewriter)
const override {
212 Value index = asIndex(op.getIndex(), rewriter);
// Inject into a fresh copy rather than into the converted input itself.
213 Value dest = cloneArrayRef(adaptor.getInput(), rewriter, op.getLoc());
215 ArrayRefInjectOp::create(rewriter, op.getLoc(), dest.getType(), dest,
216 index, adaptor.getElement());
217 rewriter.replaceOp(op, newOp);
/// Conversion pattern that replaces the matched op with an `ArrayRefSliceOp`
/// on the converted input, starting at the op's low index.
/// NOTE(review): the struct header and the first line of the signature (with
/// the matched op type) are not visible in this view.
223 using OpConversionPattern::OpConversionPattern;
227 ConversionPatternRewriter &rewriter)
const override {
230 Value index = asIndex(op.getLowIndex(), rewriter);
231 Type destType = getTypeConverter()->convertType(op.getType());
232 Value newOp = ArrayRefSliceOp::create(rewriter, op.getLoc(), destType,
233 adaptor.getInput(), index);
234 rewriter.replaceOp(op, newOp);
/// Conversion pattern that builds a concatenation: allocates a destination
/// array ref of the converted result type and copies every operand into its
/// own slice of that destination.
/// NOTE(review): the struct header, the first line of the signature and the
/// declaration of `index` are not visible in this view.
240 using OpConversionPattern::OpConversionPattern;
244 ConversionPatternRewriter &rewriter)
const override {
245 Type destType = getTypeConverter()->convertType(op.getType());
246 Value dest = ArrayRefAllocOp::create(rewriter, op.getLoc(), destType, {});
// `offset` starts at the destination's element count and is decremented by
// each operand's element count before use, so the first operand lands in the
// highest-indexed slice and later operands fill downward.
250 int offset = cast<ArrayRefType>(destType).getNumElements();
251 for (Value operand : adaptor.getOperands()) {
252 offset -= cast<ArrayRefType>(operand.getType()).getNumElements();
254 arith::ConstantIndexOp::create(rewriter, op.getLoc(), offset);
255 Value destSlice = ArrayRefSliceOp::create(rewriter, op.getLoc(),
256 operand.getType(), dest, index);
257 ArrayRefCopyOp::create(rewriter, op.getLoc(), destSlice, operand);
260 rewriter.replaceOp(op, dest);
/// Conversion pattern that replaces the matched op with an
/// `ArrayRefCreateOp` taking a freshly allocated array ref and the converted
/// inputs.
/// NOTE(review): the struct header and the first line of the signature (with
/// the matched op type) are not visible in this view.
266 using OpConversionPattern::OpConversionPattern;
270 ConversionPatternRewriter &rewriter)
const override {
271 Type newType = getTypeConverter()->convertType(op.getType());
272 Value alloc = ArrayRefAllocOp::create(rewriter, op.getLoc(), newType, {});
273 Value create = ArrayRefCreateOp::create(rewriter, op.getLoc(), newType,
274 alloc, adaptor.getInputs());
275 rewriter.replaceOp(op, create);
/// Conversion pattern for `StorageGetOp`: rebuilds the op via
/// `convertOpResultTypes` with the converted operands and replaces the
/// original with the result.
/// NOTE(review): the enclosing struct header and the check of `result` before
/// it is dereferenced are not visible in this view.
281 using OpConversionPattern::OpConversionPattern;
284 matchAndRewrite(StorageGetOp op, OpAdaptor adaptor,
285 ConversionPatternRewriter &rewriter)
const override {
286 auto result = convertOpResultTypes(op, adaptor.getOperands(),
287 *getTypeConverter(), rewriter);
291 rewriter.replaceOp(op, *result);
/// Conversion pattern that replaces the matched op with an `arith::SelectOp`
/// of the converted result type, built from the adaptor's condition, true
/// value and false value.
/// NOTE(review): the struct header and the first line of the signature (with
/// the matched op type) are not visible in this view.
297 using OpConversionPattern::OpConversionPattern;
301 ConversionPatternRewriter &rewriter)
const override {
302 Type newType = getTypeConverter()->convertType(op.getType());
303 Value newOp = arith::SelectOp::create(
304 rewriter, op.getLoc(), newType, adaptor.getCond(),
305 adaptor.getTrueValue(), adaptor.getFalseValue());
306 rewriter.replaceOp(op, newOp);
/// Rewrite pattern on `func::ReturnOp`. It looks for a returned array ref
/// that is produced by an `ArrayRefCopyOp` into a function argument, and whose
/// source traces back to an `ArrayRefAllocOp` without an initializer. When it
/// finds one, it replaces the allocation with the argument and erases the copy.
/// NOTE(review): the struct header, the declaration of `results`, the
/// `continue`/`return` statements after the guards and the update of `changed`
/// are not visible in this view.
314 using OpRewritePattern::OpRewritePattern;
316 LogicalResult matchAndRewrite(func::ReturnOp op,
317 PatternRewriter &rewriter)
const override {
318 auto funcOp = op->getParentOfType<func::FuncOp>();
319 bool changed =
false;
324 auto args = funcOp.getArguments();
// Only the return operands of ArrayRefType are considered.
326 llvm::make_filter_range(op->getOpOperands(), [](OpOperand &operand) {
327 return isa<ArrayRefType>(operand.get().getType());
// Pair function arguments with the filtered return operands in order; the
// zip ends when either range is exhausted.
330 for (
auto [arg, result] :
llvm::zip(args, results)) {
331 Value resultValue = result.get();
// The returned value must be a copy whose input is exactly this argument.
332 auto copy = resultValue.getDefiningOp<ArrayRefCopyOp>();
333 if (!copy || copy.getInput() != arg)
// The copy's source must trace back to an allocation with no initializer.
335 ArrayRefAllocOp alloc = getUltimatelyDefiningAlloc(copy.getSource());
336 if (!alloc || alloc.getInit())
// Return the copy's source directly, redirect all uses of the allocation to
// the argument, and erase the allocation and the copy.
340 rewriter.modifyOpInPlace(
341 op, [&, &result = result] { result.set(copy.getSource()); });
342 rewriter.replaceAllUsesWith(alloc, arg);
343 rewriter.eraseOp(alloc);
344 rewriter.eraseOp(copy);
// Report success only if at least one operand was rewritten.
347 return success(changed);
/// Follows `value` back through its defining ops looking for an
/// `ArrayRefAllocOp`. Copy, inject and from-array ops are stepped through via
/// their input; a call is stepped through via the call operand at the same
/// index as the result number of `value`. Any other op maps to null.
/// NOTE(review): the early returns and how the TypeSwitch result is consumed
/// (loop or recursion) are not visible in this view.
350 ArrayRefAllocOp getUltimatelyDefiningAlloc(Value value)
const {
351 if (!isa<ArrayRefType>(value.getType()))
354 Operation *op = value.getDefiningOp();
358 if (
auto alloc = dyn_cast<ArrayRefAllocOp>(op))
362 TypeSwitch<Operation *, Value>(op)
363 .Case<ArrayRefCopyOp>([&](
auto copy) {
return copy.getInput(); })
364 .Case<ArrayRefInjectOp>(
365 [&](
auto inject) {
return inject.getInput(); })
366 .Case<ArrayRefFromArrayOp>(
367 [&](
auto fromArray) {
return fromArray.getInput(); })
368 .Case<func::CallOp>([&](
auto call) {
// `cast` asserts that `value` is an op result here.
369 OpResult result = cast<OpResult>(value);
370 return call.getOperand(result.getResultNumber());
372 .Default([&](Operation *) {
return nullptr; });
/// Conversion pattern for `StateReadOp`: rebuilds the op with converted
/// result types, then replaces the original with a clone (fresh allocation
/// plus copy) of the rebuilt op's first result.
/// NOTE(review): the enclosing struct header and the guards between the
/// conversion and the replacement are not visible in this view.
382 matchAndRewrite(StateReadOp op, OpAdaptor adaptor,
383 ConversionPatternRewriter &rewriter)
const override {
384 auto result = convertOpResultTypes(op, adaptor.getOperands(),
385 *this->getTypeConverter(), rewriter);
390 Value resultValue = result.value()->getResult(0);
391 rewriter.replaceOp(op, cloneArrayRef(resultValue, rewriter, op.getLoc()));
/// Generic conversion pattern, parameterized on the op type `Op`: rebuilds
/// the op via `convertOpResultTypes` with the converted operands and replaces
/// the original with the rebuilt op.
/// NOTE(review): the struct header and the check of `result` before it is
/// dereferenced are not visible in this view.
396template <
typename Op>
401 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
402 ConversionPatternRewriter &rewriter)
const override {
403 auto result = convertOpResultTypes(op, adaptor.getOperands(),
404 *this->getTypeConverter(), rewriter);
407 rewriter.replaceOp(op, *result);
// Ops that are handled by the generic `ConvertTrivially` pattern.
412using ConvertAllocState = ConvertTrivially<arc::AllocStateOp>;
413using ConvertStateWrite = ConvertTrivially<arc::StateWriteOp>;
414using ConvertRootInput = ConvertTrivially<arc::RootInputOp>;
415using ConvertRootOutput = ConvertTrivially<arc::RootOutputOp>;
416using ConvertUnrealizedConversionCast =
417 ConvertTrivially<UnrealizedConversionCastOp>;
/// Pass entry point. Sets up the type converter, conversion target and
/// patterns, runs a partial dialect conversion, then runs a greedy cleanup
/// with `OptimizeReturnOfAlloc`. Either step failing signals pass failure.
/// NOTE(review): several lines (closing braces, some lambda headers and
/// `addLegalOp`/`addDynamicallyLegalOp` calls) are not visible in this view.
421void LowerArraysPass::runOnOperation() {
422 TypeConverter converter;
423 ConversionTarget target(getContext());
424 RewritePatternSet
patterns(&getContext());
// Identity conversion for any type.
426 converter.addConversion([](Type type) {
return type; });
// ArrayType becomes ArrayRefType with the converted element type and the
// same element count.
// NOTE(review): the body of the element-type-changed check is not visible.
427 converter.addConversion([&converter](ArrayType type) -> Type {
428 Type newElem = converter.convertType(type.getElementType());
430 if (newElem != type.getElementType())
432 return ArrayRefType::get(newElem, type.getNumElements());
// StateType is rebuilt around its converted inner type.
434 converter.addConversion([&converter](StateType type) {
435 return StateType::get(converter.convertType(type.getType()));
// The ops used by the materializations below are always legal.
438 target.addLegalOp<ArrayRefFromArrayOp, ArrayRefToArrayOp>();
441 StructExplodeOp, UnionCreateOp, UnionExtractOp>();
// Any op not otherwise registered is legal when the converter considers it
// legal.
442 target.markUnknownOpDynamicallyLegal(
443 [&](Operation *op) {
return converter.isLegal(op); });
// A function is legal once all its input and result types are legal.
444 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp func) {
445 FunctionType fty = func.getFunctionType();
446 return converter.isLegal(fty.getInputs()) &&
447 converter.isLegal(fty.getResults());
452 return converter.isLegal(op.getInput().getType());
456 return converter.isLegal(op->getResult(0).getType());
// Target materialization (array -> array ref): allocate an array ref and
// fill it from the first source value with ArrayRefFromArrayOp.
460 converter.addTargetMaterialization([&](OpBuilder &b, ArrayRefType type,
461 ValueRange values, Location loc,
462 Type fromType) -> Value {
463 assert(isa<ArrayType>(fromType));
464 Value alloc = ArrayRefAllocOp::create(b, loc, type, {});
465 return ArrayRefFromArrayOp::create(b, loc, type, alloc, values.front());
// Source materialization (array ref -> array) via ArrayRefToArrayOp.
469 converter.addSourceMaterialization([&](OpBuilder &b, ArrayType type,
471 Location loc) -> Value {
472 assert(isa<ArrayRefType>(values.front().getType()));
473 return ArrayRefToArrayOp::create(b, loc, type, values.front());
// Register all conversion patterns with the shared type converter.
476 patterns.add<ConvertFunc, ConvertReturn, ConvertCall,
477 ConvertAggregateConstant, ConvertArrayGet, ConvertArrayInject,
478 ConvertArraySlice, ConvertArrayConcat, ConvertArrayCreate,
479 ConvertStorageGet, ConvertMux, ConvertAllocState,
480 ConvertStateRead, ConvertStateWrite, ConvertRootInput,
481 ConvertRootOutput, ConvertUnrealizedConversionCast>(
482 converter, &getContext());
// Run the partial conversion with pattern rollback disabled.
484 ConversionConfig config;
485 config.allowPatternRollback =
false;
486 if (failed(applyPartialConversion(getOperation(), target, std::move(
patterns),
488 return signalPassFailure();
// Post-conversion cleanup using the greedy pattern driver.
492 RewritePatternSet cleanupPatterns(&getContext());
493 cleanupPatterns.add<OptimizeReturnOfAlloc>(&getContext());
495 applyPatternsGreedily(getOperation(), std::move(cleanupPatterns)))) {
496 return signalPassFailure();
assert(baseType &&"element must be base type")
static FIRRTLBaseType convertType(FIRRTLBaseType type)
Returns null type if no conversion is needed.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.