15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/Interfaces/FunctionImplementation.h"
22using namespace circt::loopschedule;
// Custom assembly parser for LoopSchedulePipelineOp. Visible grammar:
//   II = <attr> [trip_count = <integer-attr>] [iter_args <assignment-list>]
//   : <function-type> <condition-region> do <stages-region>
// NOTE(review): this listing is missing lines (the failure returns after each
// failed parse step, the declarations of `ii` and `type`, closing braces and
// the final return) — restore from upstream before relying on this code.
28ParseResult LoopSchedulePipelineOp::parse(OpAsmParser &parser,
29 OperationState &result) {
  // Mandatory initiation interval, stored as the "II" attribute.
32 if (parser.parseKeyword(
"II") || parser.parseEqual() ||
33 parser.parseAttribute(ii))
35 result.addAttribute(
"II", ii);
  // Optional trip count; only added as "tripCount" when the keyword is present.
38 if (succeeded(parser.parseOptionalKeyword(
"trip_count"))) {
39 IntegerAttr tripCount;
40 if (parser.parseEqual() || parser.parseAttribute(tripCount))
42 result.addAttribute(
"tripCount", tripCount);
  // Optional iter_args: an assignment list pairing region arguments with the
  // operand values that initialize them.
46 SmallVector<OpAsmParser::Argument> regionArgs;
47 SmallVector<OpAsmParser::UnresolvedOperand> operands;
48 if (succeeded(parser.parseOptionalKeyword(
"iter_args"))) {
49 if (parser.parseAssignmentList(regionArgs, operands))
  // Trailing function type: its results become the op's result types and its
  // inputs give the types of the iter_args.
55 if (parser.parseColon() || parser.parseType(type))
59 result.addTypes(type.getResults());
  // Assign each region argument its type and resolve the matching operand
  // against that same type.
62 for (
auto [regionArg, operand, type] :
63 llvm::zip(regionArgs, operands, type.getInputs())) {
64 regionArg.type = type;
65 if (parser.resolveOperand(operand, type, result.operands))
  // First region: the loop condition, whose entry block takes the iter_args.
70 Region *condition = result.addRegion();
71 if (parser.parseRegion(*condition, regionArgs))
  // The "do" keyword separates the condition from the stages region.
75 if (parser.parseKeyword(
"do"))
  // Second region: the pipeline stages, with the same entry arguments.
77 Region *stages = result.addRegion();
78 if (parser.parseRegion(*stages, regionArgs))
// Custom assembly printer for LoopSchedulePipelineOp. It emits the same
// pieces the parser reads: II, optional trip_count, iter_args pairs, the
// function type and the two regions.
// NOTE(review): several lines of this listing are missing (e.g. the
// trip-count presence check, the "iter_args(" / "do" keywords and the closing
// brace) — restore from upstream before relying on this code.
84void LoopSchedulePipelineOp::print(OpAsmPrinter &p) {
  // Initiation interval.
86 p <<
" II = " <<
' ' << getII();
  // Trip count. The optional is dereferenced here, so this presumably sits
  // under a presence check on a line missing from this listing — confirm.
90 p <<
" trip_count = " <<
' ' << *getTripCount();
  // iter_args printed as comma-separated "<stage block arg> = <operand>" pairs.
94 llvm::interleaveComma(
95 llvm::zip(getStages().getArguments(), getIterArgs()), p,
96 [&](
auto it) { p << std::get<0>(it) <<
" = " << std::get<1>(it); });
  // Function type built from the stages region's argument types.
100 auto type = FunctionType::get(getContext(), getStages().getArgumentTypes(),
  // Regions are printed without their entry block arguments (second argument
  // `false`); the parser re-creates them from the iter_args list.
106 p.printRegion(getCondition(),
false);
111 p.printRegion(getStages(),
false);
// Structural verifier for LoopSchedulePipelineOp. It checks that:
//  * the condition block contains only LoopSchedule-dialect ops or the
//    whitelisted integer arith ops (a "combinational" body);
//  * the condition's terminator has exactly one operand, of type i1;
//  * the stages block has at least two ops and holds only pipeline.stage /
//    terminator ops;
//  * stage 'start' values strictly increase in block order.
// NOTE(review): lines are missing from this listing (the `op` argument to
// `isa`, the found-count streamed after the single-result error, the body of
// the first-stage branch, closing braces, the final return) — restore first.
114LogicalResult LoopSchedulePipelineOp::verify() {
117 Block &conditionBlock = getCondition().front();
  // Only read after an interrupted walk, and only an interrupting callback
  // assigns it.
118 Operation *nonCombinational;
119 WalkResult conditionWalk = conditionBlock.walk([&](Operation *op) {
    // LoopSchedule ops are always allowed. The nonnull variant tolerates ops
    // whose dialect is null.
120 if (isa_and_nonnull<LoopScheduleDialect>(op->getDialect()))
121 return WalkResult::advance();
    // Any op outside this arith whitelist stops the walk and is reported.
123 if (!isa<arith::AddIOp, arith::AndIOp, arith::BitcastOp, arith::CmpIOp,
124 arith::ConstantOp, arith::IndexCastOp, arith::MulIOp, arith::OrIOp,
125 arith::SelectOp, arith::ShLIOp, arith::ExtSIOp, arith::CeilDivSIOp,
126 arith::DivSIOp, arith::FloorDivSIOp, arith::RemSIOp,
127 arith::ShRSIOp, arith::SubIOp, arith::TruncIOp, arith::DivUIOp,
128 arith::RemUIOp, arith::ShRUIOp, arith::XOrIOp, arith::ExtUIOp>(
130 nonCombinational = op;
131 return WalkResult::interrupt();
134 return WalkResult::advance();
137 if (conditionWalk.wasInterrupted())
138 return emitOpError(
"condition must have a combinational body, found ")
139 << *nonCombinational;
  // The condition's terminator operands act as the loop-continue value: there
  // must be exactly one, and it must be i1.
142 TypeRange conditionResults =
143 conditionBlock.getTerminator()->getOperandTypes();
144 if (conditionResults.size() != 1)
145 return emitOpError(
"condition must terminate with a single result, found ")
148 if (conditionResults.front() != IntegerType::get(getContext(), 1))
149 return emitOpError(
"condition must terminate with an i1 result, found ")
150 << conditionResults.front();
  // Fewer than two ops means there is no room for a stage alongside the
  // block's final op (per the error text, "at least one stage").
153 Block &stagesBlock = getStages().front();
154 if (stagesBlock.getOperations().size() < 2)
155 return emitOpError(
"stages must contain at least one stage");
  // -1 means "no stage seen yet". Stage 'start' values are separately
  // rejected when negative by the stage verifier's error path.
157 int64_t lastStartTime = -1;
158 for (Operation &inner : stagesBlock) {
161 if (!isa<LoopSchedulePipelineStageOp, LoopScheduleTerminatorOp>(inner))
163 "stages may only contain 'loopschedule.pipeline.stage' or "
164 "'loopschedule.terminator' ops, found ")
168 if (
auto stage = dyn_cast<LoopSchedulePipelineStageOp>(inner)) {
      // First stage: just record its start. Presumably followed by a
      // `continue` on a missing line — confirm.
169 if (lastStartTime == -1) {
170 lastStartTime = stage.getStart();
      // Later stages must start strictly after the previous one.
174 if (lastStartTime >= stage.getStart())
175 return stage.emitOpError(
"'start' must be after previous 'start' (")
176 << lastStartTime <<
')';
178 lastStartTime = stage.getStart();
// Builds a LoopSchedulePipelineOp with result types, the "II" attribute, an
// optional "tripCount" attribute and the iter_args operands. It also creates
// the two regions with skeleton terminators:
//  * condition region: one block with arguments mirroring iterArgs, ending in
//    a LoopScheduleRegisterOp;
//  * stages region: one block with the same arguments, ending in a
//    LoopScheduleTerminatorOp with empty iter_args and results.
// The InsertionGuard restores the builder's insertion point when build returns.
// NOTE(review): this listing is missing lines (the trip-count presence check,
// the remaining RegisterOp::create arguments, closing brace) — restore first.
185void LoopSchedulePipelineOp::build(OpBuilder &builder, OperationState &state,
186 TypeRange resultTypes, IntegerAttr ii,
187 std::optional<IntegerAttr> tripCount,
188 ValueRange iterArgs) {
189 OpBuilder::InsertionGuard g(builder);
191 state.addTypes(resultTypes);
192 state.addAttribute(
"II", ii);
  // `*tripCount` is only valid when the optional holds a value. The guarding
  // `if` appears to be on a line missing here — confirm before reuse.
194 state.addAttribute(
"tripCount", *tripCount);
195 state.addOperands(iterArgs);
197 Region *condRegion = state.addRegion();
198 Block &condBlock = condRegion->emplaceBlock();
  // Block arguments reuse each iter_arg's location; the same list is used for
  // both regions.
200 SmallVector<Location, 4> argLocs;
201 for (
auto arg : iterArgs)
202 argLocs.push_back(arg.
getLoc());
203 condBlock.addArguments(iterArgs.getTypes(), argLocs);
204 builder.setInsertionPointToEnd(&condBlock);
205 LoopScheduleRegisterOp::create(builder, builder.getUnknownLoc(),
208 Region *stagesRegion = state.addRegion();
209 Block &stagesBlock = stagesRegion->emplaceBlock();
210 stagesBlock.addArguments(iterArgs.getTypes(), argLocs);
211 builder.setInsertionPointToEnd(&stagesBlock);
212 LoopScheduleTerminatorOp::create(builder, builder.getUnknownLoc(),
213 ValueRange(), ValueRange());
// Verifier for LoopSchedulePipelineStageOp. Per its error message, it rejects
// a negative 'start' attribute.
// NOTE(review): the condition guarding this error, the success return and the
// closing brace are missing from this listing — restore from upstream.
220LogicalResult LoopSchedulePipelineStageOp::verify() {
222 return emitOpError(
"'start' must be non-negative");
// Builds a LoopSchedulePipelineStageOp with the given result types and "start"
// attribute. It creates a single-block body region ending in a
// LoopScheduleRegisterOp. The InsertionGuard restores the builder's insertion
// point on return.
// NOTE(review): the `start` parameter declaration, the remaining
// RegisterOp::create arguments and the closing brace are missing from this
// listing — restore from upstream.
227void LoopSchedulePipelineStageOp::build(OpBuilder &builder,
228 OperationState &state,
229 TypeRange resultTypes,
231 OpBuilder::InsertionGuard g(builder);
233 state.addTypes(resultTypes);
234 state.addAttribute(
"start", start);
236 Region *region = state.addRegion();
237 Block &block = region->emplaceBlock();
238 builder.setInsertionPointToEnd(&block);
239 LoopScheduleRegisterOp::create(builder, builder.getUnknownLoc(),
// Returns this stage's position within the parent pipeline's stages block.
// It walks forward from the block's first op until it reaches this op, or
// until it runs out of ops.
// NOTE(review): the counter declaration/increment, the `parent` lookup, the
// return and the closing braces are missing from this listing. Presumably it
// counts the ops visited before `op` — confirm against upstream.
243unsigned LoopSchedulePipelineStageOp::getStageNumber() {
245 auto *op = getOperation();
247 Operation *stage = &parent.getStagesBlock().front();
248 while (stage != op && stage->getNextNode()) {
250 stage = stage->getNextNode();
// Verifier for LoopScheduleRegisterOp. When the op is nested in a
// LoopSchedulePipelineStageOp, its operand types must equal that stage's
// result types.
// NOTE(review): a register can also appear outside any stage;
// LoopSchedulePipelineOp::build inserts one in the condition block. The
// consequent of the null check below, the closing `<< ")"` and the final
// return are missing from this listing — restore from upstream.
259LogicalResult LoopScheduleRegisterOp::verify() {
260 LoopSchedulePipelineStageOp stage =
261 (*this)->getParentOfType<LoopSchedulePipelineStageOp>();
264 if (stage ==
nullptr)
268 TypeRange registerTypes = getOperandTypes();
269 TypeRange resultTypes = stage.getResultTypes();
270 if (registerTypes != resultTypes)
271 return emitOpError(
"operand types (")
272 << registerTypes <<
") must match result types (" << resultTypes
// Verifier for LoopScheduleTerminatorOp. It checks the terminator against the
// enclosing pipeline:
//  * its 'iter_args' types equal the pipeline's iter_args types, and each is
//    defined by a loopschedule.pipeline.stage op;
//  * its 'results' types equal the pipeline's result types, and each is
//    defined by a loopschedule.pipeline.stage op.
// NOTE(review): the declaration of `pipeline`, the `return emitOpError(` heads
// of the two "must be defined by" errors and the final return are missing
// from this listing — restore from upstream.
282LogicalResult LoopScheduleTerminatorOp::verify() {
287 auto iterArgs = getIterArgs();
288 TypeRange terminatorArgTypes = iterArgs.getTypes();
289 TypeRange pipelineArgTypes = pipeline.getIterArgs().getTypes();
290 if (terminatorArgTypes != pipelineArgTypes)
291 return emitOpError(
"'iter_args' types (")
292 << terminatorArgTypes <<
") must match pipeline 'iter_args' types ("
293 << pipelineArgTypes <<
")";
  // A value with no stage defining op (e.g. a block argument) fails here.
296 for (
auto iterArg : iterArgs)
297 if (iterArg.getDefiningOp<LoopSchedulePipelineStageOp>() == nullptr)
299 "'iter_args' must be defined by a 'loopschedule.pipeline.stage'");
302 auto opResults = getResults();
303 TypeRange terminatorResultTypes = opResults.getTypes();
304 TypeRange pipelineResultTypes = pipeline.getResultTypes();
305 if (terminatorResultTypes != pipelineResultTypes)
306 return emitOpError(
"'results' types (")
307 << terminatorResultTypes <<
") must match pipeline result types ("
308 << pipelineResultTypes <<
")";
311 for (
auto result : opResults)
312 if (result.getDefiningOp<LoopSchedulePipelineStageOp>() == nullptr)
314 "'results' must be defined by a 'loopschedule.pipeline.stage'");
// Pull in the TableGen-generated op class definitions.
319#define GET_OP_CLASSES
320#include "circt/Dialect/LoopSchedule/LoopSchedule.cpp.inc"
// Dialect initialization. The generated .inc is included again here,
// presumably under GET_OP_LIST inside an addOperations<...>() call.
// NOTE(review): those surrounding lines and the closing brace are missing from
// this listing — restore from upstream.
322void LoopScheduleDialect::initialize() {
325#include "circt/Dialect/LoopSchedule/LoopSchedule.cpp.inc"
// Generated dialect class definitions.
329#include "circt/Dialect/LoopSchedule/LoopScheduleDialect.cpp.inc"
// NOTE(review): stray declaration with no body and no use anywhere in this
// listing. It does not appear to belong to the LoopSchedule dialect; it looks
// like residue from another file. Confirm, then remove or restore its body.
static Location getLoc(DefSlot slot)
/// The InstanceGraph op interface; see InstanceGraphInterface.td for more details.