47 ConversionPatternRewriter &rewriter)
const override {
48 if (affineParallelOp.getIVs().size() != 1)
49 return rewriter.notifyMatchFailure(affineParallelOp,
50 "currently only support single IV");
52 auto loc = affineParallelOp.getLoc();
53 auto upperBoundTuple = mlir::affine::expandAffineMap(
54 rewriter, loc, affineParallelOp.getUpperBoundsMap(),
55 affineParallelOp.getUpperBoundsOperands());
57 return rewriter.notifyMatchFailure(affineParallelOp,
58 "does not have upper bounds");
59 Value upperBound = (*upperBoundTuple)[0];
61 auto lowerBoundTuple = mlir::affine::expandAffineMap(
62 rewriter, loc, affineParallelOp.getLowerBoundsMap(),
63 affineParallelOp.getLowerBoundsOperands());
65 return rewriter.notifyMatchFailure(affineParallelOp,
66 "does not have lower bounds");
67 Value lowerBound = (*lowerBoundTuple)[0];
69 auto step = affineParallelOp.getSteps()[0];
72 affineParallelOp->getAttrOfType<IntegerAttr>(
"unparallelize.factor");
74 return rewriter.notifyMatchFailure(affineParallelOp,
75 "Missing 'unparallelize.factor'");
77 int64_t factor = factorAttr.getInt();
79 SmallVector<scf::IndexSwitchOp> simplifiableIndexSwitchOps =
82 auto outerLoop = rewriter.create<affine::AffineForOp>(
83 loc, lowerBound, rewriter.getDimIdentityMap(), upperBound,
84 rewriter.getDimIdentityMap(), step * factor);
86 rewriter.setInsertionPointToStart(outerLoop.getBody());
87 AffineMap lbMap = AffineMap::get(
89 rewriter.getAffineConstantExpr(0), rewriter.getContext());
90 AffineMap ubMap = AffineMap::get(
91 0, 0, rewriter.getAffineConstantExpr(factor), rewriter.getContext());
92 auto innerParallel = rewriter.create<affine::AffineParallelOp>(
94 SmallVector<arith::AtomicRMWKind>(),
95 lbMap, SmallVector<Value>(),
96 ubMap, SmallVector<Value>(),
97 SmallVector<int64_t>({step}));
99 if (!innerParallel.getBody()->empty()) {
100 Operation &lastOp = innerParallel.getBody()->back();
101 if (isa<affine::AffineYieldOp>(lastOp))
104 rewriter.setInsertionPointToStart(innerParallel.getBody());
108 auto addMap = AffineMap::get(
109 2, 0, rewriter.getAffineDimExpr(0) + rewriter.getAffineDimExpr(1),
110 rewriter.getContext());
112 auto newIndex = rewriter.create<affine::AffineApplyOp>(
114 ValueRange{outerLoop.getInductionVar(), innerParallel.getIVs()[0]});
116 Block *srcBlock = affineParallelOp.getBody();
117 Block *destBlock = innerParallel.getBody();
120 destBlock->getOperations().splice(
122 srcBlock->getOperations(),
124 std::prev(srcBlock->end())
128 destBlock->walk([&](Operation *op) {
129 for (OpOperand &operand : op->getOpOperands()) {
130 if (operand.get() == affineParallelOp.getIVs()[0])
131 operand.set(newIndex);
135 rewriter.setInsertionPointToEnd(destBlock);
136 rewriter.create<affine::AffineYieldOp>(loc);
138 for (
auto indexSwitchOp : simplifiableIndexSwitchOps) {
139 indexSwitchOp.setOperand(innerParallel.getIVs().front());
151 int64_t factor)
const {
152 SmallVector<scf::IndexSwitchOp> result;
153 affineParallelOp->walk([&](scf::IndexSwitchOp indexSwitchOp) {
154 auto switchArg = indexSwitchOp.getArg();
156 dyn_cast_or_null<affine::AffineApplyOp>(switchArg.getDefiningOp());
157 if (!affineApplyOp || affineApplyOp->getNumOperands() != 1 ||
158 affineApplyOp->getNumResults() != 1)
159 return WalkResult::advance();
161 auto affineMap = affineApplyOp.getAffineMap();
162 auto binExpr = dyn_cast<AffineBinaryOpExpr>(affineMap.getResult(0));
163 if (!binExpr || binExpr.getKind() != AffineExprKind::Mod)
164 return WalkResult::advance();
166 if (affineApplyOp.getOperand(0) != affineParallelOp.getIVs().front())
167 return WalkResult::advance();
169 auto rhs = binExpr.getRHS();
170 auto constRhs = dyn_cast<AffineConstantExpr>(rhs);
171 if (!constRhs || factor != constRhs.getValue())
172 return WalkResult::advance();
174 result.push_back(indexSwitchOp);
175 return WalkResult::advance();