15 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16 #include "mlir/Conversion/LLVMCommon/Pattern.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/IR/BuiltinDialect.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/ImplicitLocOpBuilder.h"
24 #include "mlir/IR/OperationSupport.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/Support/MathExtras.h"
30 using namespace circt;
33 return memref.getShape().size() == 1;
45 static Value
flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
46 ValueRange indices, MemRefType memrefType) {
47 assert(memrefType.hasStaticShape() &&
"expected statically shaped memref");
48 Location loc = op->getLoc();
50 if (indices.empty()) {
52 return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
56 Value finalIdx = indices.front();
57 for (
auto memIdx : llvm::enumerate(indices.drop_front())) {
58 Value partialIdx = memIdx.value();
59 int64_t indexMulFactor = 1;
62 for (
unsigned i = 0; i <= memIdx.index(); ++i) {
63 int64_t dimSize = memrefType.getShape()[i];
64 indexMulFactor *= dimSize;
68 if (llvm::isPowerOf2_64(indexMulFactor)) {
71 .create<arith::ConstantOp>(
72 loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
75 rewriter.create<arith::ShLIOp>(loc, partialIdx, constant).getResult();
77 auto constant = rewriter
78 .create<arith::ConstantOp>(
79 loc, rewriter.getIndexAttr(indexMulFactor))
82 rewriter.create<arith::MulIOp>(loc, partialIdx, constant).getResult();
86 auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
87 finalIdx = sumOp.getResult();
93 return llvm::any_of(values, [](Value v) {
94 auto memref = v.getType().dyn_cast<MemRefType>();
104 using OpConversionPattern::OpConversionPattern;
107 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
108 ConversionPatternRewriter &rewriter)
const override {
109 MemRefType type = op.getMemRefType();
111 op.getIndices().size() == 1)
114 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
115 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
117 SmallVector<Value>{finalIdx});
123 using OpConversionPattern::OpConversionPattern;
126 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
127 ConversionPatternRewriter &rewriter)
const override {
128 MemRefType type = op.getMemRefType();
130 op.getIndices().size() == 1)
133 flattenIndices(rewriter, op, adaptor.getIndices(), op.getMemRefType());
134 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, adaptor.getValue(),
136 SmallVector<Value>{finalIdx});
142 using OpConversionPattern::OpConversionPattern;
145 matchAndRewrite(memref::AllocOp op, OpAdaptor ,
146 ConversionPatternRewriter &rewriter)
const override {
147 MemRefType type = op.getType();
151 SmallVector<int64_t>{type.getNumElements()}, type.getElementType());
152 rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
159 template <
typename TOp>
162 using OpAdaptor =
typename TOp::Adaptor;
164 matchAndRewrite(TOp op, OpAdaptor adaptor,
165 ConversionPatternRewriter &rewriter)
const override {
166 rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
167 adaptor.getOperands(), op->getAttrs());
174 struct CondBranchOpConversion
176 using OpConversionPattern::OpConversionPattern;
179 matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
180 ConversionPatternRewriter &rewriter)
const override {
181 rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
182 op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
183 adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
192 CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
193 bool rewriteFunctions =
false)
195 rewriteFunctions(rewriteFunctions) {}
198 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
199 ConversionPatternRewriter &rewriter)
const override {
200 llvm::SmallVector<Type> convResTypes;
201 if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
203 auto newCallOp = rewriter.replaceOpWithNewOp<func::CallOp>(
204 op, adaptor.getCallee(), convResTypes, adaptor.getOperands());
206 if (!rewriteFunctions)
212 rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
213 auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
215 op.getContext(), newCallOp.getOperandTypes(), convResTypes);
216 func::FuncOp newFuncOp;
218 newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
219 calledFunction, op.getCallee(), funcType);
222 rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
223 newFuncOp.setVisibility(SymbolTable::Visibility::Private);
229 bool rewriteFunctions;
232 template <
typename... TOp>
233 void addGenericLegalityConstraint(ConversionTarget &target) {
234 (target.addDynamicallyLegalOp<TOp>([](TOp op) {
241 static void populateFlattenMemRefsLegality(ConversionTarget &target) {
242 target.addLegalDialect<arith::ArithDialect>();
243 target.addDynamicallyLegalOp<memref::AllocOp>(
245 target.addDynamicallyLegalOp<memref::StoreOp>(
246 [](memref::StoreOp op) {
return op.getIndices().size() == 1; });
247 target.addDynamicallyLegalOp<memref::LoadOp>(
248 [](memref::LoadOp op) {
return op.getIndices().size() == 1; });
250 addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
251 func::CallOp, func::ReturnOp, memref::DeallocOp,
252 memref::CopyOp>(target);
254 target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
255 auto argsConverted = llvm::none_of(op.getBlocks(), [](
auto &block) {
256 return hasMultiDimMemRef(block.getArguments());
259 auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
260 if (auto memref = type.dyn_cast<MemRefType>())
261 return isUniDimensional(memref);
265 return argsConverted && resultsConverted;
272 static Value materializeSubViewFlattening(OpBuilder &
builder, MemRefType type,
273 ValueRange
inputs, Location loc) {
274 assert(type.hasStaticShape() &&
275 "Can only subview flatten memref's with static shape (for now...).");
276 MemRefType sourceType =
inputs[0].getType().cast<MemRefType>();
277 int64_t memSize = sourceType.getNumElements();
278 unsigned dims = sourceType.getShape().size();
281 SmallVector<OpFoldResult> sizes(dims,
builder.getIndexAttr(0));
282 SmallVector<OpFoldResult> offsets(dims,
builder.getIndexAttr(1));
283 offsets[offsets.size() - 1] =
builder.getIndexAttr(memSize);
284 SmallVector<OpFoldResult> strides(dims,
builder.getIndexAttr(1));
287 MemRefType outType =
MemRefType::get({memSize}, type.getElementType());
288 return builder.create<memref::SubViewOp>(loc, outType,
inputs[0], sizes,
292 static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
294 typeConverter.addConversion([](Type type) {
return type; });
296 typeConverter.addConversion([](MemRefType memref) {
299 return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
300 memref.getElementType());
304 struct FlattenMemRefPass :
public FlattenMemRefBase<FlattenMemRefPass> {
306 void runOnOperation()
override {
308 auto *ctx = &getContext();
309 TypeConverter typeConverter;
310 populateTypeConversionPatterns(typeConverter);
313 SetVector<StringRef> rewrittenCallees;
315 OperandConversionPattern<func::ReturnOp>,
316 OperandConversionPattern<memref::DeallocOp>,
317 CondBranchOpConversion,
318 OperandConversionPattern<memref::DeallocOp>,
319 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
321 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
324 ConversionTarget target(*ctx);
325 populateFlattenMemRefsLegality(target);
327 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
335 struct FlattenMemRefCallsPass
336 :
public FlattenMemRefCallsBase<FlattenMemRefCallsPass> {
338 void runOnOperation()
override {
339 auto *ctx = &getContext();
340 TypeConverter typeConverter;
341 populateTypeConversionPatterns(typeConverter);
350 patterns.add<CallOpConversion>(typeConverter, ctx,
353 ConversionTarget target(*ctx);
354 target.addLegalDialect<memref::MemRefDialect, mlir::BuiltinDialect>();
355 addGenericLegalityConstraint<func::CallOp>(target);
356 addGenericLegalityConstraint<func::FuncOp>(target);
360 typeConverter.addTargetMaterialization(materializeSubViewFlattening);
362 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
374 return std::make_unique<FlattenMemRefPass>();
378 return std::make_unique<FlattenMemRefCallsPass>();
assert(baseType && "element must be base type");
static bool hasMultiDimMemRef(ValueRange values)
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op, ValueRange indices, MemRefType memrefType)
OneToOneConvertToLLVMPattern< llhd::LoadOp, LLVM::LoadOp > LoadOpConversion
llvm::SmallVector< StringAttr > inputs
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
std::unique_ptr< mlir::Pass > createFlattenMemRefPass()
bool isUniDimensional(mlir::MemRefType memref)
std::unique_ptr< mlir::Pass > createFlattenMemRefCallsPass()
A struct for maintaining function declarations which needs to be rewritten, if they contain memref ar...