CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerArrays.cpp
Go to the documentation of this file.
1//===- LowerArrays.cpp ----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <utility>
11
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinOps.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/Location.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Support/LLVM.h"
27#include "mlir/Support/LogicalResult.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30
31#define DEBUG_TYPE "arc-lower-arrays"
32
33namespace circt {
34namespace arc {
35#define GEN_PASS_DEF_LOWERARRAYS
36#include "circt/Dialect/Arc/ArcPasses.h.inc"
37} // namespace arc
38} // namespace circt
39
40using namespace mlir;
41using namespace circt;
42using namespace arc;
43using namespace hw;
44using ::llvm::enumerate;
45
46//===----------------------------------------------------------------------===//
47// Pass Implementation
48//===----------------------------------------------------------------------===//
49
50namespace {
51struct LowerArraysPass : public arc::impl::LowerArraysBase<LowerArraysPass> {
52 LowerArraysPass() = default;
53 LowerArraysPass(const LowerArraysPass &pass) : LowerArraysPass() {}
54
55 void runOnOperation() override;
56};
57
58Value asIndex(Value value, OpBuilder &builder) {
59 Location loc = builder.getUnknownLoc();
60 if (Operation *parent = value.getDefiningOp()) {
61 loc = parent->getLoc();
62 }
63 return arith::IndexCastUIOp::create(builder, loc, builder.getIndexType(),
64 value);
65}
66
67Value cloneArrayRef(Value value, OpBuilder &builder, Location loc) {
68 Value newAlloc = ArrayRefAllocOp::create(builder, loc, value.getType(), {});
69 return ArrayRefCopyOp::create(builder, loc, newAlloc, value);
70}
71
72struct ConvertFunc : public OpConversionPattern<func::FuncOp> {
73 using OpConversionPattern::OpConversionPattern;
74
75 LogicalResult
76 matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
77 ConversionPatternRewriter &rewriter) const override {
78 const TypeConverter &converter = *getTypeConverter();
79 TypeConverter::SignatureConversion conversion(op.getNumArguments());
80
81 SmallVector<Type> newArgTypes;
82 SmallVector<Type> newResultTypes;
83 assert(op.getBody().getBlocks().size() == 1);
84
85 // Any array-typed results become parameters.
86 Operation *ret = op.getBody().front().getTerminator();
87 for (Value result : ret->getOperands()) {
88 if (isa<ArrayType>(result.getType())) {
89 Type newType = converter.convertType(result.getType());
90 conversion.addInputs(newType);
91 newArgTypes.push_back(newType);
92 }
93 }
94
95 if (failed(converter.convertTypes(op.getArgumentTypes(), newArgTypes)) ||
96 failed(converter.convertTypes(op.getResultTypes(), newResultTypes))) {
97 return failure();
98 }
99
100 if (failed(converter.convertSignatureArgs(op.getArgumentTypes(),
101 conversion)) ||
102 failed(rewriter.convertRegionTypes(&op.getBody(), converter,
103 &conversion))) {
104 return failure();
105 }
106
107 rewriter.modifyOpInPlace(op, [&] {
108 op.setType(FunctionType::get(getContext(), newArgTypes, newResultTypes));
109 });
110
111 return success();
112 }
113};
114
115struct ConvertReturn : public OpConversionPattern<func::ReturnOp> {
116 using OpConversionPattern::OpConversionPattern;
117
118 LogicalResult
119 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
120 ConversionPatternRewriter &rewriter) const override {
121 func::FuncOp func = op->getParentOfType<func::FuncOp>();
122 int sretIndex = 0;
123 SmallVector<Value> newOperands;
124 for (Value operand : adaptor.getOperands()) {
125 if (isa<ArrayRefType>(operand.getType())) {
126 Value arg = func.getArgument(sretIndex++);
127 Value copy =
128 ArrayRefCopyOp::create(rewriter, op.getLoc(), arg, operand);
129 newOperands.push_back(copy);
130 } else {
131 newOperands.push_back(operand);
132 }
133 }
134 rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); });
135 return success();
136 }
137};
138
139struct ConvertCall : public OpConversionPattern<func::CallOp> {
140 using OpConversionPattern::OpConversionPattern;
141
142 LogicalResult
143 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
144 ConversionPatternRewriter &rewriter) const override {
145 SmallVector<Value> newOperands;
146 for (Type resultType : op.getResultTypes()) {
147 auto arrayType = dyn_cast<ArrayType>(resultType);
148 if (!arrayType)
149 continue;
150 auto arrayRefType = ArrayRefType::get(arrayType.getElementType(),
151 arrayType.getNumElements());
152 Value alloc =
153 ArrayRefAllocOp::create(rewriter, op.getLoc(), arrayRefType, {});
154 newOperands.push_back(alloc);
155 }
156 for (Value operand : adaptor.getOperands()) {
157 newOperands.push_back(operand);
158 }
159
160 SmallVector<Type> resultTypes;
161 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
162 resultTypes))) {
163 return failure();
164 }
165
166 auto newCall = func::CallOp::create(rewriter, op.getLoc(), op.getCallee(),
167 resultTypes, newOperands);
168 newCall->setDiscardableAttrs(op->getDiscardableAttrDictionary());
169
170 rewriter.replaceOp(op, newCall);
171 return success();
172 }
173};
174
175struct ConvertAggregateConstant
176 : public OpConversionPattern<AggregateConstantOp> {
177 using OpConversionPattern::OpConversionPattern;
178
179 LogicalResult
180 matchAndRewrite(AggregateConstantOp op, OpAdaptor adaptor,
181 ConversionPatternRewriter &rewriter) const override {
182 auto newType =
183 cast<ArrayRefType>(getTypeConverter()->convertType(op.getType()));
184 Value newOp = ArrayRefAllocOp::create(rewriter, op.getLoc(), newType,
185 op.getFieldsAttr());
186 rewriter.replaceOp(op, newOp);
187 return success();
188 }
189};
190
191struct ConvertArrayGet : public OpConversionPattern<ArrayGetOp> {
192 using OpConversionPattern::OpConversionPattern;
193
194 LogicalResult
195 matchAndRewrite(hw::ArrayGetOp op, OpAdaptor adaptor,
196 ConversionPatternRewriter &rewriter) const override {
197 Type resultType = getTypeConverter()->convertType(op.getType());
198 Value index = asIndex(op.getIndex(), rewriter);
199 Value newOp = ArrayRefGetOp::create(rewriter, op.getLoc(), resultType,
200 adaptor.getInput(), index);
201 rewriter.replaceOp(op, newOp);
202 return success();
203 }
204};
205
206struct ConvertArrayInject : public OpConversionPattern<ArrayInjectOp> {
207 using OpConversionPattern::OpConversionPattern;
208
209 LogicalResult
210 matchAndRewrite(ArrayInjectOp op, OpAdaptor adaptor,
211 ConversionPatternRewriter &rewriter) const override {
212 Value index = asIndex(op.getIndex(), rewriter);
213 Value dest = cloneArrayRef(adaptor.getInput(), rewriter, op.getLoc());
214 Value newOp =
215 ArrayRefInjectOp::create(rewriter, op.getLoc(), dest.getType(), dest,
216 index, adaptor.getElement());
217 rewriter.replaceOp(op, newOp);
218 return success();
219 }
220};
221
222struct ConvertArraySlice : public OpConversionPattern<ArraySliceOp> {
223 using OpConversionPattern::OpConversionPattern;
224
225 LogicalResult
226 matchAndRewrite(ArraySliceOp op, OpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter) const override {
228 // Because we're converting from value semantics we can assume that the
229 // input buffer is immutable, so we don't need to copy it.
230 Value index = asIndex(op.getLowIndex(), rewriter);
231 Type destType = getTypeConverter()->convertType(op.getType());
232 Value newOp = ArrayRefSliceOp::create(rewriter, op.getLoc(), destType,
233 adaptor.getInput(), index);
234 rewriter.replaceOp(op, newOp);
235 return success();
236 }
237};
238
239struct ConvertArrayConcat : public OpConversionPattern<ArrayConcatOp> {
240 using OpConversionPattern::OpConversionPattern;
241
242 LogicalResult
243 matchAndRewrite(ArrayConcatOp op, OpAdaptor adaptor,
244 ConversionPatternRewriter &rewriter) const override {
245 Type destType = getTypeConverter()->convertType(op.getType());
246 Value dest = ArrayRefAllocOp::create(rewriter, op.getLoc(), destType, {});
247
248 // ArrayConcatOp's operands are ordered from most significant to least
249 // significant.
250 int offset = cast<ArrayRefType>(destType).getNumElements();
251 for (Value operand : adaptor.getOperands()) {
252 offset -= cast<ArrayRefType>(operand.getType()).getNumElements();
253 Value index =
254 arith::ConstantIndexOp::create(rewriter, op.getLoc(), offset);
255 Value destSlice = ArrayRefSliceOp::create(rewriter, op.getLoc(),
256 operand.getType(), dest, index);
257 ArrayRefCopyOp::create(rewriter, op.getLoc(), destSlice, operand);
258 }
259 assert(offset == 0);
260 rewriter.replaceOp(op, dest);
261 return success();
262 }
263};
264
265struct ConvertArrayCreate : public OpConversionPattern<ArrayCreateOp> {
266 using OpConversionPattern::OpConversionPattern;
267
268 LogicalResult
269 matchAndRewrite(ArrayCreateOp op, OpAdaptor adaptor,
270 ConversionPatternRewriter &rewriter) const override {
271 Type newType = getTypeConverter()->convertType(op.getType());
272 Value alloc = ArrayRefAllocOp::create(rewriter, op.getLoc(), newType, {});
273 Value create = ArrayRefCreateOp::create(rewriter, op.getLoc(), newType,
274 alloc, adaptor.getInputs());
275 rewriter.replaceOp(op, create);
276 return success();
277 }
278};
279
280struct ConvertStorageGet : public OpConversionPattern<StorageGetOp> {
281 using OpConversionPattern::OpConversionPattern;
282
283 LogicalResult
284 matchAndRewrite(StorageGetOp op, OpAdaptor adaptor,
285 ConversionPatternRewriter &rewriter) const override {
286 auto result = convertOpResultTypes(op, adaptor.getOperands(),
287 *getTypeConverter(), rewriter);
288 if (failed(result))
289 return failure();
290
291 rewriter.replaceOp(op, *result);
292 return success();
293 }
294};
295
296struct ConvertMux : public OpConversionPattern<comb::MuxOp> {
297 using OpConversionPattern::OpConversionPattern;
298
299 LogicalResult
300 matchAndRewrite(comb::MuxOp op, OpAdaptor adaptor,
301 ConversionPatternRewriter &rewriter) const override {
302 Type newType = getTypeConverter()->convertType(op.getType());
303 Value newOp = arith::SelectOp::create(
304 rewriter, op.getLoc(), newType, adaptor.getCond(),
305 adaptor.getTrueValue(), adaptor.getFalseValue());
306 rewriter.replaceOp(op, newOp);
307 return success();
308 }
309};
310
311// Identifies a return of an ArrayRef that is defined by an ArrayRefAllocOp,
312// and replaces the alloc with the sret buffer. Also removes the copy.
313struct OptimizeReturnOfAlloc : public OpRewritePattern<func::ReturnOp> {
314 using OpRewritePattern::OpRewritePattern;
315
316 LogicalResult matchAndRewrite(func::ReturnOp op,
317 PatternRewriter &rewriter) const override {
318 auto funcOp = op->getParentOfType<func::FuncOp>();
319 bool changed = false;
320
321 // Iterate over all pairs of (result, sret-buffer). The sret buffers are
322 // always the initial function arguments, and every !arc.arrayref<T>
323 // result has one sret buffer associated with it.
324 auto args = funcOp.getArguments();
325 auto results =
326 llvm::make_filter_range(op->getOpOperands(), [](OpOperand &operand) {
327 return isa<ArrayRefType>(operand.get().getType());
328 });
329
330 for (auto [arg, result] : llvm::zip(args, results)) {
331 Value resultValue = result.get();
332 auto copy = resultValue.getDefiningOp<ArrayRefCopyOp>();
333 if (!copy || copy.getInput() != arg)
334 continue;
335 ArrayRefAllocOp alloc = getUltimatelyDefiningAlloc(copy.getSource());
336 if (!alloc || alloc.getInit())
337 continue;
338 // Note that `result` is a structured binding, which cannot be implicitly
339 // captured until C++20.
340 rewriter.modifyOpInPlace(
341 op, [&, &result = result] { result.set(copy.getSource()); });
342 rewriter.replaceAllUsesWith(alloc, arg);
343 rewriter.eraseOp(alloc);
344 rewriter.eraseOp(copy);
345 changed = true;
346 }
347 return success(changed);
348 }
349
350 ArrayRefAllocOp getUltimatelyDefiningAlloc(Value value) const {
351 if (!isa<ArrayRefType>(value.getType()))
352 return nullptr;
353 while (value) {
354 Operation *op = value.getDefiningOp();
355 if (!op)
356 return nullptr;
357
358 if (auto alloc = dyn_cast<ArrayRefAllocOp>(op))
359 return alloc;
360
361 value =
362 TypeSwitch<Operation *, Value>(op)
363 .Case<ArrayRefCopyOp>([&](auto copy) { return copy.getInput(); })
364 .Case<ArrayRefInjectOp>(
365 [&](auto inject) { return inject.getInput(); })
366 .Case<ArrayRefFromArrayOp>(
367 [&](auto fromArray) { return fromArray.getInput(); })
368 .Case<func::CallOp>([&](auto call) {
369 OpResult result = cast<OpResult>(value);
370 return call.getOperand(result.getResultNumber());
371 })
372 .Default([&](Operation *) { return nullptr; });
373 }
374 return nullptr;
375 }
376};
377
378struct ConvertStateRead : public OpConversionPattern<StateReadOp> {
380
381 LogicalResult
382 matchAndRewrite(StateReadOp op, OpAdaptor adaptor,
383 ConversionPatternRewriter &rewriter) const override {
384 auto result = convertOpResultTypes(op, adaptor.getOperands(),
385 *this->getTypeConverter(), rewriter);
386 if (failed(result))
387 return failure();
388
389 // For correctness we need to copy the read state into a local array.
390 Value resultValue = result.value()->getResult(0);
391 rewriter.replaceOp(op, cloneArrayRef(resultValue, rewriter, op.getLoc()));
392 return success();
393 }
394};
395
396template <typename Op>
397struct ConvertTrivially : public OpConversionPattern<Op> {
399
400 LogicalResult
401 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
402 ConversionPatternRewriter &rewriter) const override {
403 auto result = convertOpResultTypes(op, adaptor.getOperands(),
404 *this->getTypeConverter(), rewriter);
405 if (failed(result))
406 return failure();
407 rewriter.replaceOp(op, *result);
408 return success();
409 }
410};
411
412using ConvertAllocState = ConvertTrivially<arc::AllocStateOp>;
413using ConvertStateWrite = ConvertTrivially<arc::StateWriteOp>;
414using ConvertRootInput = ConvertTrivially<arc::RootInputOp>;
415using ConvertRootOutput = ConvertTrivially<arc::RootOutputOp>;
416using ConvertUnrealizedConversionCast =
417 ConvertTrivially<UnrealizedConversionCastOp>;
418
419} // namespace
420
421void LowerArraysPass::runOnOperation() {
422 TypeConverter converter;
423 ConversionTarget target(getContext());
424 RewritePatternSet patterns(&getContext());
425
426 converter.addConversion([](Type type) { return type; });
427 converter.addConversion([&converter](ArrayType type) -> Type {
428 Type newElem = converter.convertType(type.getElementType());
429 // Don't convert nested array types.
430 if (newElem != type.getElementType())
431 return type;
432 return ArrayRefType::get(newElem, type.getNumElements());
433 });
434 converter.addConversion([&converter](StateType type) {
435 return StateType::get(converter.convertType(type.getType()));
436 });
437
438 target.addLegalOp<ArrayRefFromArrayOp, ArrayRefToArrayOp>();
439 // Arrays within structs or unions always use !hw.array.
440 target.addLegalOp<StructCreateOp, StructInjectOp, StructExtractOp,
441 StructExplodeOp, UnionCreateOp, UnionExtractOp>();
442 target.markUnknownOpDynamicallyLegal(
443 [&](Operation *op) { return converter.isLegal(op); });
444 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp func) {
445 FunctionType fty = func.getFunctionType();
446 return converter.isLegal(fty.getInputs()) &&
447 converter.isLegal(fty.getResults());
448 });
449 // An ArrayGetOp may legally return an array if the input type was a nested
450 // array. Similarly for ArrayCreateOp and ArrayInjectOp.
451 target.addDynamicallyLegalOp<ArrayGetOp>([&](ArrayGetOp op) {
452 return converter.isLegal(op.getInput().getType());
453 });
454 target.addDynamicallyLegalOp<ArrayCreateOp, ArrayInjectOp>(
455 [&](Operation *op) {
456 return converter.isLegal(op->getResult(0).getType());
457 });
458
459 // Produces an ArrayRefType from an ArrayType.
460 converter.addTargetMaterialization([&](OpBuilder &b, ArrayRefType type,
461 ValueRange values, Location loc,
462 Type fromType) -> Value {
463 assert(isa<ArrayType>(fromType));
464 Value alloc = ArrayRefAllocOp::create(b, loc, type, {});
465 return ArrayRefFromArrayOp::create(b, loc, type, alloc, values.front());
466 });
467
468 // Produces an ArrayType from an ArrayRefType.
469 converter.addSourceMaterialization([&](OpBuilder &b, ArrayType type,
470 ValueRange values,
471 Location loc) -> Value {
472 assert(isa<ArrayRefType>(values.front().getType()));
473 return ArrayRefToArrayOp::create(b, loc, type, values.front());
474 });
475
476 patterns.add<ConvertFunc, ConvertReturn, ConvertCall,
477 ConvertAggregateConstant, ConvertArrayGet, ConvertArrayInject,
478 ConvertArraySlice, ConvertArrayConcat, ConvertArrayCreate,
479 ConvertStorageGet, ConvertMux, ConvertAllocState,
480 ConvertStateRead, ConvertStateWrite, ConvertRootInput,
481 ConvertRootOutput, ConvertUnrealizedConversionCast>(
482 converter, &getContext());
483
484 ConversionConfig config;
485 config.allowPatternRollback = false;
486 if (failed(applyPartialConversion(getOperation(), target, std::move(patterns),
487 config))) {
488 return signalPassFailure();
489 }
490
491 // Apply some cleanup patterns to optimize away ArrayRefAllocOps.
492 RewritePatternSet cleanupPatterns(&getContext());
493 cleanupPatterns.add<OptimizeReturnOfAlloc>(&getContext());
494 if (failed(
495 applyPatternsGreedily(getOperation(), std::move(cleanupPatterns)))) {
496 return signalPassFailure();
497 }
498}
assert(baseType &&"element must be base type")
static FIRRTLBaseType convertType(FIRRTLBaseType type)
Returns null type if no conversion is needed.
Definition DropConst.cpp:32
Definition arc.py:1
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1