CIRCT 20.0.0git
Loading...
Searching...
No Matches
LoopScheduleOps.cpp
Go to the documentation of this file.
1//===- LoopScheduleOps.cpp - LoopSchedule CIRCT Operations ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the LoopSchedule ops.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/Interfaces/FunctionImplementation.h"
19
20using namespace mlir;
21using namespace circt;
22using namespace circt::loopschedule;
23
24//===----------------------------------------------------------------------===//
// LoopSchedulePipelineOp
26//===----------------------------------------------------------------------===//
27
28ParseResult LoopSchedulePipelineOp::parse(OpAsmParser &parser,
29 OperationState &result) {
30 // Parse initiation interval.
31 IntegerAttr ii;
32 if (parser.parseKeyword("II") || parser.parseEqual() ||
33 parser.parseAttribute(ii))
34 return failure();
35 result.addAttribute("II", ii);
36
37 // Parse optional trip count.
38 if (succeeded(parser.parseOptionalKeyword("trip_count"))) {
39 IntegerAttr tripCount;
40 if (parser.parseEqual() || parser.parseAttribute(tripCount))
41 return failure();
42 result.addAttribute("tripCount", tripCount);
43 }
44
45 // Parse iter_args assignment list.
46 SmallVector<OpAsmParser::Argument> regionArgs;
47 SmallVector<OpAsmParser::UnresolvedOperand> operands;
48 if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
49 if (parser.parseAssignmentList(regionArgs, operands))
50 return failure();
51 }
52
53 // Parse function type from iter_args to results.
54 FunctionType type;
55 if (parser.parseColon() || parser.parseType(type))
56 return failure();
57
58 // Function result type is the pipeline result type.
59 result.addTypes(type.getResults());
60
61 // Resolve iter_args operands.
62 for (auto [regionArg, operand, type] :
63 llvm::zip(regionArgs, operands, type.getInputs())) {
64 regionArg.type = type;
65 if (parser.resolveOperand(operand, type, result.operands))
66 return failure();
67 }
68
69 // Parse condition region.
70 Region *condition = result.addRegion();
71 if (parser.parseRegion(*condition, regionArgs))
72 return failure();
73
74 // Parse stages region.
75 if (parser.parseKeyword("do"))
76 return failure();
77 Region *stages = result.addRegion();
78 if (parser.parseRegion(*stages, regionArgs))
79 return failure();
80
81 return success();
82}
83
84void LoopSchedulePipelineOp::print(OpAsmPrinter &p) {
85 // Print the initiation interval.
86 p << " II = " << ' ' << getII();
87
88 // Print the optional tripCount.
89 if (getTripCount())
90 p << " trip_count = " << ' ' << *getTripCount();
91
92 // Print iter_args assignment list.
93 p << " iter_args(";
94 llvm::interleaveComma(
95 llvm::zip(getStages().getArguments(), getIterArgs()), p,
96 [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); });
97 p << ") : ";
98
99 // Print function type from iter_args to results.
100 auto type = FunctionType::get(getContext(), getStages().getArgumentTypes(),
101 getResultTypes());
102 p.printType(type);
103
104 // Print condition region.
105 p << ' ';
106 p.printRegion(getCondition(), /*printEntryBlockArgs=*/false);
107 p << " do";
108
109 // Print stages region.
110 p << ' ';
111 p.printRegion(getStages(), /*printEntryBlockArgs=*/false);
112}
113
114LogicalResult LoopSchedulePipelineOp::verify() {
115 // Verify the condition block is "combinational" based on an allowlist of
116 // Arithmetic ops.
117 Block &conditionBlock = getCondition().front();
118 Operation *nonCombinational;
119 WalkResult conditionWalk = conditionBlock.walk([&](Operation *op) {
120 if (isa<LoopScheduleDialect>(op->getDialect()))
121 return WalkResult::advance();
122
123 if (!isa<arith::AddIOp, arith::AndIOp, arith::BitcastOp, arith::CmpIOp,
124 arith::ConstantOp, arith::IndexCastOp, arith::MulIOp, arith::OrIOp,
125 arith::SelectOp, arith::ShLIOp, arith::ExtSIOp, arith::CeilDivSIOp,
126 arith::DivSIOp, arith::FloorDivSIOp, arith::RemSIOp,
127 arith::ShRSIOp, arith::SubIOp, arith::TruncIOp, arith::DivUIOp,
128 arith::RemUIOp, arith::ShRUIOp, arith::XOrIOp, arith::ExtUIOp>(
129 op)) {
130 nonCombinational = op;
131 return WalkResult::interrupt();
132 }
133
134 return WalkResult::advance();
135 });
136
137 if (conditionWalk.wasInterrupted())
138 return emitOpError("condition must have a combinational body, found ")
139 << *nonCombinational;
140
141 // Verify the condition block terminates with a value of type i1.
142 TypeRange conditionResults =
143 conditionBlock.getTerminator()->getOperandTypes();
144 if (conditionResults.size() != 1)
145 return emitOpError("condition must terminate with a single result, found ")
146 << conditionResults;
147
148 if (conditionResults.front() != IntegerType::get(getContext(), 1))
149 return emitOpError("condition must terminate with an i1 result, found ")
150 << conditionResults.front();
151
152 // Verify the stages block contains at least one stage and a terminator.
153 Block &stagesBlock = getStages().front();
154 if (stagesBlock.getOperations().size() < 2)
155 return emitOpError("stages must contain at least one stage");
156
157 int64_t lastStartTime = -1;
158 for (Operation &inner : stagesBlock) {
159 // Verify the stages block contains only `loopschedule.pipeline.stage` and
160 // `loopschedule.terminator` ops.
161 if (!isa<LoopSchedulePipelineStageOp, LoopScheduleTerminatorOp>(inner))
162 return emitOpError(
163 "stages may only contain 'loopschedule.pipeline.stage' or "
164 "'loopschedule.terminator' ops, found ")
165 << inner;
166
167 // Verify the stage start times are monotonically increasing.
168 if (auto stage = dyn_cast<LoopSchedulePipelineStageOp>(inner)) {
169 if (lastStartTime == -1) {
170 lastStartTime = stage.getStart();
171 continue;
172 }
173
174 if (lastStartTime >= stage.getStart())
175 return stage.emitOpError("'start' must be after previous 'start' (")
176 << lastStartTime << ')';
177
178 lastStartTime = stage.getStart();
179 }
180 }
181
182 return success();
183}
184
185void LoopSchedulePipelineOp::build(OpBuilder &builder, OperationState &state,
186 TypeRange resultTypes, IntegerAttr ii,
187 std::optional<IntegerAttr> tripCount,
188 ValueRange iterArgs) {
189 OpBuilder::InsertionGuard g(builder);
190
191 state.addTypes(resultTypes);
192 state.addAttribute("II", ii);
193 if (tripCount)
194 state.addAttribute("tripCount", *tripCount);
195 state.addOperands(iterArgs);
196
197 Region *condRegion = state.addRegion();
198 Block &condBlock = condRegion->emplaceBlock();
199
200 SmallVector<Location, 4> argLocs;
201 for (auto arg : iterArgs)
202 argLocs.push_back(arg.getLoc());
203 condBlock.addArguments(iterArgs.getTypes(), argLocs);
204 builder.setInsertionPointToEnd(&condBlock);
205 builder.create<LoopScheduleRegisterOp>(builder.getUnknownLoc(), ValueRange());
206
207 Region *stagesRegion = state.addRegion();
208 Block &stagesBlock = stagesRegion->emplaceBlock();
209 stagesBlock.addArguments(iterArgs.getTypes(), argLocs);
210 builder.setInsertionPointToEnd(&stagesBlock);
211 builder.create<LoopScheduleTerminatorOp>(builder.getUnknownLoc(),
212 ValueRange(), ValueRange());
213}
214
215//===----------------------------------------------------------------------===//
// LoopSchedulePipelineStageOp
217//===----------------------------------------------------------------------===//
218
219LogicalResult LoopSchedulePipelineStageOp::verify() {
220 if (getStart() < 0)
221 return emitOpError("'start' must be non-negative");
222
223 return success();
224}
225
226void LoopSchedulePipelineStageOp::build(OpBuilder &builder,
227 OperationState &state,
228 TypeRange resultTypes,
229 IntegerAttr start) {
230 OpBuilder::InsertionGuard g(builder);
231
232 state.addTypes(resultTypes);
233 state.addAttribute("start", start);
234
235 Region *region = state.addRegion();
236 Block &block = region->emplaceBlock();
237 builder.setInsertionPointToEnd(&block);
238 builder.create<LoopScheduleRegisterOp>(builder.getUnknownLoc(), ValueRange());
239}
240
241unsigned LoopSchedulePipelineStageOp::getStageNumber() {
242 unsigned number = 0;
243 auto *op = getOperation();
244 auto parent = op->getParentOfType<LoopSchedulePipelineOp>();
245 Operation *stage = &parent.getStagesBlock().front();
246 while (stage != op && stage->getNextNode()) {
247 ++number;
248 stage = stage->getNextNode();
249 }
250 return number;
251}
252
253//===----------------------------------------------------------------------===//
// LoopScheduleRegisterOp
255//===----------------------------------------------------------------------===//
256
257LogicalResult LoopScheduleRegisterOp::verify() {
258 LoopSchedulePipelineStageOp stage =
259 (*this)->getParentOfType<LoopSchedulePipelineStageOp>();
260
261 // If this doesn't terminate a stage, it is terminating the condition.
262 if (stage == nullptr)
263 return success();
264
265 // Verify stage terminates with the same types as the result types.
266 TypeRange registerTypes = getOperandTypes();
267 TypeRange resultTypes = stage.getResultTypes();
268 if (registerTypes != resultTypes)
269 return emitOpError("operand types (")
270 << registerTypes << ") must match result types (" << resultTypes
271 << ")";
272
273 return success();
274}
275
276//===----------------------------------------------------------------------===//
// LoopScheduleTerminatorOp
278//===----------------------------------------------------------------------===//
279
280LogicalResult LoopScheduleTerminatorOp::verify() {
281 LoopSchedulePipelineOp pipeline =
282 (*this)->getParentOfType<LoopSchedulePipelineOp>();
283
284 // Verify pipeline terminates with the same `iter_args` types as the pipeline.
285 auto iterArgs = getIterArgs();
286 TypeRange terminatorArgTypes = iterArgs.getTypes();
287 TypeRange pipelineArgTypes = pipeline.getIterArgs().getTypes();
288 if (terminatorArgTypes != pipelineArgTypes)
289 return emitOpError("'iter_args' types (")
290 << terminatorArgTypes << ") must match pipeline 'iter_args' types ("
291 << pipelineArgTypes << ")";
292
293 // Verify `iter_args` are defined by a pipeline stage.
294 for (auto iterArg : iterArgs)
295 if (iterArg.getDefiningOp<LoopSchedulePipelineStageOp>() == nullptr)
296 return emitOpError(
297 "'iter_args' must be defined by a 'loopschedule.pipeline.stage'");
298
299 // Verify pipeline terminates with the same result types as the pipeline.
300 auto opResults = getResults();
301 TypeRange terminatorResultTypes = opResults.getTypes();
302 TypeRange pipelineResultTypes = pipeline.getResultTypes();
303 if (terminatorResultTypes != pipelineResultTypes)
304 return emitOpError("'results' types (")
305 << terminatorResultTypes << ") must match pipeline result types ("
306 << pipelineResultTypes << ")";
307
308 // Verify `results` are defined by a pipeline stage.
309 for (auto result : opResults)
310 if (result.getDefiningOp<LoopSchedulePipelineStageOp>() == nullptr)
311 return emitOpError(
312 "'results' must be defined by a 'loopschedule.pipeline.stage'");
313
314 return success();
315}
316
// Pull in the ODS-generated op class definitions for the LoopSchedule dialect.
#define GET_OP_CLASSES
#include "circt/Dialect/LoopSchedule/LoopSchedule.cpp.inc"

/// Registers the LoopSchedule operations with the dialect. The template
/// argument list is expanded from the same generated file, this time with
/// GET_OP_LIST defined.
void LoopScheduleDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "circt/Dialect/LoopSchedule/LoopSchedule.cpp.inc"
      >();
}

// ODS-generated dialect class definitions.
#include "circt/Dialect/LoopSchedule/LoopScheduleDialect.cpp.inc"
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.