CIRCT  20.0.0git
LoopScheduleOps.cpp
Go to the documentation of this file.
1 //===- LoopScheduleOps.cpp - LoopSchedule CIRCT Operations ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the LoopSchedule ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Interfaces/FunctionImplementation.h"
19 
20 using namespace mlir;
21 using namespace circt;
22 using namespace circt::loopschedule;
23 
24 //===----------------------------------------------------------------------===//
// LoopSchedulePipelineOp
26 //===----------------------------------------------------------------------===//
27 
28 ParseResult LoopSchedulePipelineOp::parse(OpAsmParser &parser,
29  OperationState &result) {
30  // Parse initiation interval.
31  IntegerAttr ii;
32  if (parser.parseKeyword("II") || parser.parseEqual() ||
33  parser.parseAttribute(ii))
34  return failure();
35  result.addAttribute("II", ii);
36 
37  // Parse optional trip count.
38  if (succeeded(parser.parseOptionalKeyword("trip_count"))) {
39  IntegerAttr tripCount;
40  if (parser.parseEqual() || parser.parseAttribute(tripCount))
41  return failure();
42  result.addAttribute("tripCount", tripCount);
43  }
44 
45  // Parse iter_args assignment list.
46  SmallVector<OpAsmParser::Argument> regionArgs;
47  SmallVector<OpAsmParser::UnresolvedOperand> operands;
48  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
49  if (parser.parseAssignmentList(regionArgs, operands))
50  return failure();
51  }
52 
53  // Parse function type from iter_args to results.
54  FunctionType type;
55  if (parser.parseColon() || parser.parseType(type))
56  return failure();
57 
58  // Function result type is the pipeline result type.
59  result.addTypes(type.getResults());
60 
61  // Resolve iter_args operands.
62  for (auto [regionArg, operand, type] :
63  llvm::zip(regionArgs, operands, type.getInputs())) {
64  regionArg.type = type;
65  if (parser.resolveOperand(operand, type, result.operands))
66  return failure();
67  }
68 
69  // Parse condition region.
70  Region *condition = result.addRegion();
71  if (parser.parseRegion(*condition, regionArgs))
72  return failure();
73 
74  // Parse stages region.
75  if (parser.parseKeyword("do"))
76  return failure();
77  Region *stages = result.addRegion();
78  if (parser.parseRegion(*stages, regionArgs))
79  return failure();
80 
81  return success();
82 }
83 
84 void LoopSchedulePipelineOp::print(OpAsmPrinter &p) {
85  // Print the initiation interval.
86  p << " II = " << ' ' << getII();
87 
88  // Print the optional tripCount.
89  if (getTripCount())
90  p << " trip_count = " << ' ' << *getTripCount();
91 
92  // Print iter_args assignment list.
93  p << " iter_args(";
94  llvm::interleaveComma(
95  llvm::zip(getStages().getArguments(), getIterArgs()), p,
96  [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); });
97  p << ") : ";
98 
99  // Print function type from iter_args to results.
100  auto type = FunctionType::get(getContext(), getStages().getArgumentTypes(),
101  getResultTypes());
102  p.printType(type);
103 
104  // Print condition region.
105  p << ' ';
106  p.printRegion(getCondition(), /*printEntryBlockArgs=*/false);
107  p << " do";
108 
109  // Print stages region.
110  p << ' ';
111  p.printRegion(getStages(), /*printEntryBlockArgs=*/false);
112 }
113 
114 LogicalResult LoopSchedulePipelineOp::verify() {
115  // Verify the condition block is "combinational" based on an allowlist of
116  // Arithmetic ops.
117  Block &conditionBlock = getCondition().front();
118  Operation *nonCombinational;
119  WalkResult conditionWalk = conditionBlock.walk([&](Operation *op) {
120  if (isa<LoopScheduleDialect>(op->getDialect()))
121  return WalkResult::advance();
122 
123  if (!isa<arith::AddIOp, arith::AndIOp, arith::BitcastOp, arith::CmpIOp,
124  arith::ConstantOp, arith::IndexCastOp, arith::MulIOp, arith::OrIOp,
125  arith::SelectOp, arith::ShLIOp, arith::ExtSIOp, arith::CeilDivSIOp,
126  arith::DivSIOp, arith::FloorDivSIOp, arith::RemSIOp,
127  arith::ShRSIOp, arith::SubIOp, arith::TruncIOp, arith::DivUIOp,
128  arith::RemUIOp, arith::ShRUIOp, arith::XOrIOp, arith::ExtUIOp>(
129  op)) {
130  nonCombinational = op;
131  return WalkResult::interrupt();
132  }
133 
134  return WalkResult::advance();
135  });
136 
137  if (conditionWalk.wasInterrupted())
138  return emitOpError("condition must have a combinational body, found ")
139  << *nonCombinational;
140 
141  // Verify the condition block terminates with a value of type i1.
142  TypeRange conditionResults =
143  conditionBlock.getTerminator()->getOperandTypes();
144  if (conditionResults.size() != 1)
145  return emitOpError("condition must terminate with a single result, found ")
146  << conditionResults;
147 
148  if (conditionResults.front() != IntegerType::get(getContext(), 1))
149  return emitOpError("condition must terminate with an i1 result, found ")
150  << conditionResults.front();
151 
152  // Verify the stages block contains at least one stage and a terminator.
153  Block &stagesBlock = getStages().front();
154  if (stagesBlock.getOperations().size() < 2)
155  return emitOpError("stages must contain at least one stage");
156 
157  int64_t lastStartTime = -1;
158  for (Operation &inner : stagesBlock) {
159  // Verify the stages block contains only `loopschedule.pipeline.stage` and
160  // `loopschedule.terminator` ops.
161  if (!isa<LoopSchedulePipelineStageOp, LoopScheduleTerminatorOp>(inner))
162  return emitOpError(
163  "stages may only contain 'loopschedule.pipeline.stage' or "
164  "'loopschedule.terminator' ops, found ")
165  << inner;
166 
167  // Verify the stage start times are monotonically increasing.
168  if (auto stage = dyn_cast<LoopSchedulePipelineStageOp>(inner)) {
169  if (lastStartTime == -1) {
170  lastStartTime = stage.getStart();
171  continue;
172  }
173 
174  if (lastStartTime >= stage.getStart())
175  return stage.emitOpError("'start' must be after previous 'start' (")
176  << lastStartTime << ')';
177 
178  lastStartTime = stage.getStart();
179  }
180  }
181 
182  return success();
183 }
184 
185 void LoopSchedulePipelineOp::build(OpBuilder &builder, OperationState &state,
186  TypeRange resultTypes, IntegerAttr ii,
187  std::optional<IntegerAttr> tripCount,
188  ValueRange iterArgs) {
189  OpBuilder::InsertionGuard g(builder);
190 
191  state.addTypes(resultTypes);
192  state.addAttribute("II", ii);
193  if (tripCount)
194  state.addAttribute("tripCount", *tripCount);
195  state.addOperands(iterArgs);
196 
197  Region *condRegion = state.addRegion();
198  Block &condBlock = condRegion->emplaceBlock();
199 
200  SmallVector<Location, 4> argLocs;
201  for (auto arg : iterArgs)
202  argLocs.push_back(arg.getLoc());
203  condBlock.addArguments(iterArgs.getTypes(), argLocs);
204  builder.setInsertionPointToEnd(&condBlock);
205  builder.create<LoopScheduleRegisterOp>(builder.getUnknownLoc(), ValueRange());
206 
207  Region *stagesRegion = state.addRegion();
208  Block &stagesBlock = stagesRegion->emplaceBlock();
209  stagesBlock.addArguments(iterArgs.getTypes(), argLocs);
210  builder.setInsertionPointToEnd(&stagesBlock);
211  builder.create<LoopScheduleTerminatorOp>(builder.getUnknownLoc(),
212  ValueRange(), ValueRange());
213 }
214 
215 //===----------------------------------------------------------------------===//
// LoopSchedulePipelineStageOp
217 //===----------------------------------------------------------------------===//
218 
219 LogicalResult LoopSchedulePipelineStageOp::verify() {
220  if (getStart() < 0)
221  return emitOpError("'start' must be non-negative");
222 
223  return success();
224 }
225 
226 void LoopSchedulePipelineStageOp::build(OpBuilder &builder,
227  OperationState &state,
228  TypeRange resultTypes,
229  IntegerAttr start) {
230  OpBuilder::InsertionGuard g(builder);
231 
232  state.addTypes(resultTypes);
233  state.addAttribute("start", start);
234 
235  Region *region = state.addRegion();
236  Block &block = region->emplaceBlock();
237  builder.setInsertionPointToEnd(&block);
238  builder.create<LoopScheduleRegisterOp>(builder.getUnknownLoc(), ValueRange());
239 }
240 
241 unsigned LoopSchedulePipelineStageOp::getStageNumber() {
242  unsigned number = 0;
243  auto *op = getOperation();
244  auto parent = op->getParentOfType<LoopSchedulePipelineOp>();
245  Operation *stage = &parent.getStagesBlock().front();
246  while (stage != op && stage->getNextNode()) {
247  ++number;
248  stage = stage->getNextNode();
249  }
250  return number;
251 }
252 
253 //===----------------------------------------------------------------------===//
// LoopScheduleRegisterOp
255 //===----------------------------------------------------------------------===//
256 
257 LogicalResult LoopScheduleRegisterOp::verify() {
258  LoopSchedulePipelineStageOp stage =
259  (*this)->getParentOfType<LoopSchedulePipelineStageOp>();
260 
261  // If this doesn't terminate a stage, it is terminating the condition.
262  if (stage == nullptr)
263  return success();
264 
265  // Verify stage terminates with the same types as the result types.
266  TypeRange registerTypes = getOperandTypes();
267  TypeRange resultTypes = stage.getResultTypes();
268  if (registerTypes != resultTypes)
269  return emitOpError("operand types (")
270  << registerTypes << ") must match result types (" << resultTypes
271  << ")";
272 
273  return success();
274 }
275 
276 //===----------------------------------------------------------------------===//
// LoopScheduleTerminatorOp
278 //===----------------------------------------------------------------------===//
279 
280 LogicalResult LoopScheduleTerminatorOp::verify() {
281  LoopSchedulePipelineOp pipeline =
282  (*this)->getParentOfType<LoopSchedulePipelineOp>();
283 
284  // Verify pipeline terminates with the same `iter_args` types as the pipeline.
285  auto iterArgs = getIterArgs();
286  TypeRange terminatorArgTypes = iterArgs.getTypes();
287  TypeRange pipelineArgTypes = pipeline.getIterArgs().getTypes();
288  if (terminatorArgTypes != pipelineArgTypes)
289  return emitOpError("'iter_args' types (")
290  << terminatorArgTypes << ") must match pipeline 'iter_args' types ("
291  << pipelineArgTypes << ")";
292 
293  // Verify `iter_args` are defined by a pipeline stage.
294  for (auto iterArg : iterArgs)
295  if (iterArg.getDefiningOp<LoopSchedulePipelineStageOp>() == nullptr)
296  return emitOpError(
297  "'iter_args' must be defined by a 'loopschedule.pipeline.stage'");
298 
299  // Verify pipeline terminates with the same result types as the pipeline.
300  auto opResults = getResults();
301  TypeRange terminatorResultTypes = opResults.getTypes();
302  TypeRange pipelineResultTypes = pipeline.getResultTypes();
303  if (terminatorResultTypes != pipelineResultTypes)
304  return emitOpError("'results' types (")
305  << terminatorResultTypes << ") must match pipeline result types ("
306  << pipelineResultTypes << ")";
307 
308  // Verify `results` are defined by a pipeline stage.
309  for (auto result : opResults)
310  if (result.getDefiningOp<LoopSchedulePipelineStageOp>() == nullptr)
311  return emitOpError(
312  "'results' must be defined by a 'loopschedule.pipeline.stage'");
313 
314  return success();
315 }
316 
317 #define GET_OP_CLASSES
318 #include "circt/Dialect/LoopSchedule/LoopSchedule.cpp.inc"
319 
320 void LoopScheduleDialect::initialize() {
321  addOperations<
322 #define GET_OP_LIST
323 #include "circt/Dialect/LoopSchedule/LoopSchedule.cpp.inc"
324  >();
325 }
326 
327 #include "circt/Dialect/LoopSchedule/LoopScheduleDialect.cpp.inc"
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition: SVOps.cpp:2443
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21