15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Interfaces/FunctionImplementation.h"
21 using namespace circt;
22 using namespace circt::loopschedule;
// Custom assembly parser for the pipeline op. Reading top to bottom, the
// accepted form is: `II` `=` <attr>, an optional `trip_count` `=` <attr>, an
// optional `iter_args` assignment list, `:` <type>, the condition region, the
// `do` keyword, and finally the stages region.
// NOTE(review): the declarations of `ii` and `type`, and the statements taken
// when a parse step fails, are not visible in this listing - confirm them
// against the full file.
28 ParseResult LoopSchedulePipelineOp::parse(OpAsmParser &parser,
29 OperationState &result) {
// Mandatory initiation interval, stored under the attribute name `II`.
32 if (parser.parseKeyword(
"II") || parser.parseEqual() ||
33 parser.parseAttribute(ii))
35 result.addAttribute(
"II", ii);
// Optional trip count: only parsed when the `trip_count` keyword is present.
// Note the attribute is stored as `tripCount`, not under the keyword spelling.
38 if (succeeded(parser.parseOptionalKeyword(
"trip_count"))) {
39 IntegerAttr tripCount;
40 if (parser.parseEqual() || parser.parseAttribute(tripCount))
42 result.addAttribute(
"tripCount", tripCount);
// Optional `iter_args`: an assignment list that yields region arguments
// paired with the still-unresolved operands they are initialised from.
46 SmallVector<OpAsmParser::Argument> regionArgs;
47 SmallVector<OpAsmParser::UnresolvedOperand> operands;
48 if (succeeded(parser.parseOptionalKeyword(
"iter_args"))) {
49 if (parser.parseAssignmentList(regionArgs, operands))
// `type` exposes getInputs() and getResults(), so it is presumably a
// FunctionType - TODO confirm against its declaration.
55 if (parser.parseColon() || parser.parseType(type))
// The results of the parsed type become the result types of the op.
59 result.addTypes(type.getResults());
// Walk region arguments, operands and input types in lockstep. The range
// expression refers to the outer `type`; inside the loop body the structured
// binding named `type` shadows it and denotes a single input type.
62 for (
auto [regionArg, operand, type] :
63 llvm::zip(regionArgs, operands, type.getInputs())) {
// Give the region argument its type and resolve the matching operand into
// the operand list of the op.
64 regionArg.type = type;
65 if (parser.resolveOperand(operand, type, result.operands))
// First region: the condition. It is parsed with the iter_args as its
// arguments.
70 Region *condition = result.addRegion();
71 if (parser.parseRegion(*condition, regionArgs))
// The `do` keyword separates the condition region from the stages region.
75 if (parser.parseKeyword(
"do"))
// Second region: the stages. It receives the same argument list as the
// condition region.
77 Region *stages = result.addRegion();
78 if (parser.parseRegion(*stages, regionArgs))
// Custom assembly printer for the pipeline op. It emits the initiation
// interval, the trip count, the `arg = value` pairs for the iter args, and
// then the condition and stages regions.
// NOTE(review): any conditions guarding the trip-count and iter_args output,
// and the literal text printed between the pieces below, are not visible in
// this listing.
84 void LoopSchedulePipelineOp::print(OpAsmPrinter &p) {
// Prints the ` II = ` literal, then one additional space character, then the
// II value - so two spaces end up between `=` and the value.
86 p <<
" II = " <<
' ' << getII();
// getTripCount() is dereferenced here, so it returns something optional-like.
// NOTE(review): the check that it holds a value is not visible - confirm.
90 p <<
" trip_count = " <<
' ' << *getTripCount();
// Comma-separated list pairing each block argument of the stages region with
// the corresponding iter_args operand, each printed as `<arg> = <operand>`.
94 llvm::interleaveComma(
95 llvm::zip(getStages().getArguments(), getIterArgs()), p,
96 [&](
auto it) { p << std::get<0>(it) <<
" = " << std::get<1>(it); });
// Both regions are printed with `false` as the second argument - presumably
// the printEntryBlockArgs flag of OpAsmPrinter::printRegion, which would fit
// the arguments having been printed in the list above; TODO confirm.
106 p.printRegion(getCondition(),
false);
111 p.printRegion(getStages(),
false);
// Structural checks on a pipeline op, in order: (1) the condition block may
// only contain ops from this dialect or from a fixed list of arith ops,
// (2) the condition terminator must carry a single i1 operand, (3) the stages
// block must contain at least two ops, (4) the stages block may contain only
// pipeline-stage and terminator ops, and (5) stage start times must be
// strictly increasing.
// NOTE(review): the enclosing function header and several interior lines are
// not visible in this listing - presumably this is the body of
// LoopSchedulePipelineOp::verify; confirm against the full file.
117 Block &conditionBlock = getCondition().front();
// Declared without an initializer. In the visible code it is written only
// immediately before the walk is interrupted and read only when the walk
// reports an interruption.
118 Operation *nonCombinational;
119 WalkResult conditionWalk = conditionBlock.walk([&](Operation *op) {
// Ops belonging to the LoopSchedule dialect are accepted unconditionally.
120 if (isa<LoopScheduleDialect>(op->getDialect()))
121 return WalkResult::advance();
// Allow-list of arith ops. Any op outside this list is remembered and stops
// the walk.
123 if (!isa<arith::AddIOp, arith::AndIOp, arith::BitcastOp, arith::CmpIOp,
124 arith::ConstantOp, arith::IndexCastOp, arith::MulIOp, arith::OrIOp,
125 arith::SelectOp, arith::ShLIOp, arith::ExtSIOp, arith::CeilDivSIOp,
126 arith::DivSIOp, arith::FloorDivSIOp, arith::RemSIOp,
127 arith::ShRSIOp, arith::SubIOp, arith::TruncIOp, arith::DivUIOp,
128 arith::RemUIOp, arith::ShRUIOp, arith::XOrIOp, arith::ExtUIOp>(
130 nonCombinational = op;
131 return WalkResult::interrupt();
134 return WalkResult::advance();
// An interrupted walk means a disallowed op was found; it is appended to the
// diagnostic.
137 if (conditionWalk.wasInterrupted())
138 return emitOpError(
"condition must have a combinational body, found ")
139 << *nonCombinational;
// The operand types of the condition terminator are treated as the results of
// the condition: exactly one is required.
142 TypeRange conditionResults =
143 conditionBlock.getTerminator()->getOperandTypes();
144 if (conditionResults.size() != 1)
145 return emitOpError(
"condition must terminate with a single result, found ")
// NOTE(review): the test that precedes this i1 diagnostic is not visible.
149 return emitOpError(
"condition must terminate with an i1 result, found ")
150 << conditionResults.front();
// Fewer than two ops in the stages block is rejected - presumably one
// terminator plus at least one stage is the minimum; TODO confirm.
153 Block &stagesBlock = getStages().front();
154 if (stagesBlock.getOperations().size() < 2)
155 return emitOpError(
"stages must contain at least one stage");
// -1 serves as the sentinel for no stage seen yet.
157 int64_t lastStartTime = -1;
158 for (Operation &inner : stagesBlock) {
// Only stage ops and the terminator may appear directly in the stages block.
161 if (!isa<LoopSchedulePipelineStageOp, LoopScheduleTerminatorOp>(inner))
163 "stages may only contain 'loopschedule.pipeline.stage' or "
164 "'loopschedule.terminator' ops, found ")
168 if (
auto stage = dyn_cast<LoopSchedulePipelineStageOp>(inner)) {
// The first stage only seeds lastStartTime.
169 if (lastStartTime == -1) {
170 lastStartTime = stage.getStart();
// Every later stage must start strictly after the previous one; the error is
// emitted on the offending stage rather than on the pipeline.
174 if (lastStartTime >= stage.getStart())
175 return stage.emitOpError(
"'start' must be after previous 'start' (")
176 << lastStartTime <<
')';
178 lastStartTime = stage.getStart();
// Builder for the pipeline op. It records result types, the `II` attribute,
// the `tripCount` attribute and the iter_args operands on `state`, then
// creates two single-block regions (condition and stages) whose block
// arguments mirror the iter_args, and places a default terminator in each.
//   builder     - used to create the ops inside the new regions.
//   state       - operation state being populated.
//   resultTypes - result types added to the op.
//   ii          - value stored under the `II` attribute.
//   tripCount   - optional value stored under the `tripCount` attribute.
//   iterArgs    - operands of the op; their types and locations are reused
//                 for the block arguments of both regions.
185 void LoopSchedulePipelineOp::build(OpBuilder &builder, OperationState &state,
186 TypeRange resultTypes, IntegerAttr ii,
187 std::optional<IntegerAttr> tripCount,
188 ValueRange iterArgs) {
// The insertion point is moved into the new blocks below; the guard
// presumably restores the original insertion point on scope exit (per MLIR
// OpBuilder::InsertionGuard) - TODO confirm.
189 OpBuilder::InsertionGuard g(builder);
191 state.addTypes(resultTypes);
192 state.addAttribute(
"II", ii);
// NOTE(review): tripCount is a std::optional and is dereferenced here; the
// check that it holds a value is not visible in this listing - confirm.
194 state.addAttribute(
"tripCount", *tripCount);
195 state.addOperands(iterArgs);
// Condition region: one block with one argument per iter arg.
197 Region *condRegion = state.addRegion();
198 Block &condBlock = condRegion->emplaceBlock();
// Each block argument reuses the location of the iter arg it mirrors.
200 SmallVector<Location, 4> argLocs;
201 for (
auto arg : iterArgs)
202 argLocs.push_back(arg.getLoc());
203 condBlock.addArguments(iterArgs.getTypes(), argLocs);
// The condition block ends with a register op that has no operands.
204 builder.setInsertionPointToEnd(&condBlock);
205 builder.create<LoopScheduleRegisterOp>(builder.getUnknownLoc(), ValueRange());
// Stages region: same argument types and locations as the condition block,
// ending with a terminator op built from two empty value ranges.
207 Region *stagesRegion = state.addRegion();
208 Block &stagesBlock = stagesRegion->emplaceBlock();
209 stagesBlock.addArguments(iterArgs.getTypes(), argLocs);
210 builder.setInsertionPointToEnd(&stagesBlock);
211 builder.create<LoopScheduleTerminatorOp>(builder.getUnknownLoc(),
212 ValueRange(), ValueRange());
// Rejects the op with a diagnostic stating that `start` must be non-negative.
// NOTE(review): the enclosing function header and the condition guarding this
// return are not visible in this listing - presumably this belongs to the
// verifier of the pipeline stage op; confirm against the full file.
221 return emitOpError(
"'start' must be non-negative");
// Builder for the pipeline stage op. It records the result types and the
// `start` attribute on `state`, then creates a single-block region that ends
// with a register op with no operands.
// NOTE(review): the final parameter line, which declares `start`, is not
// visible in this listing.
226 void LoopSchedulePipelineStageOp::build(OpBuilder &builder,
227 OperationState &state,
228 TypeRange resultTypes,
// The insertion point is moved into the new block below; the guard
// presumably restores the original insertion point on scope exit (per MLIR
// OpBuilder::InsertionGuard) - TODO confirm.
230 OpBuilder::InsertionGuard g(builder);
232 state.addTypes(resultTypes);
233 state.addAttribute(
"start", start);
// Stage body: one block, no block arguments added here, ending with an empty
// register op.
235 Region *region = state.addRegion();
236 Block &block = region->emplaceBlock();
237 builder.setInsertionPointToEnd(&block);
238 builder.create<LoopScheduleRegisterOp>(builder.getUnknownLoc(), ValueRange());
// Locates this stage inside its parent pipeline by scanning the stages block
// from its first operation, following getNextNode() until this op is reached
// or there is no next node.
// NOTE(review): the counter that is presumably incremented per step and the
// return statement are not visible in this listing - confirm.
241 unsigned LoopSchedulePipelineStageOp::getStageNumber() {
243 auto *op = getOperation();
// The nearest enclosing pipeline op owns the block that holds the stages.
244 auto parent = op->getParentOfType<LoopSchedulePipelineOp>();
245 Operation *stage = &parent.getStagesBlock().front();
// The second condition stops the scan at the last op of the block so that
// `stage` never becomes null.
246 while (stage != op && stage->getNextNode()) {
248 stage = stage->getNextNode();
// Checks an op against its enclosing pipeline stage: the operand types of
// this op must equal the result types of the stage.
// NOTE(review): the enclosing function header is not visible in this listing
// - presumably this is the verifier of the register op; confirm against the
// full file.
258 LoopSchedulePipelineStageOp stage =
259 (*this)->getParentOfType<LoopSchedulePipelineStageOp>();
// A null `stage` means no enclosing pipeline stage op was found.
// NOTE(review): the statement taken in that case is not visible.
262 if (stage ==
nullptr)
// The two type ranges are compared as a whole; on mismatch both are included
// in the diagnostic.
266 TypeRange registerTypes = getOperandTypes();
267 TypeRange resultTypes = stage.getResultTypes();
268 if (registerTypes != resultTypes)
269 return emitOpError(
"operand types (")
270 << registerTypes <<
") must match result types (" << resultTypes
// Checks an op against its enclosing pipeline: the types of its iter_args
// must equal the types of the pipeline iter_args, the types of its results
// must equal the pipeline result types, and every iter_arg and every result
// value must be defined by a pipeline stage op.
// NOTE(review): the enclosing function header is not visible in this listing
// - presumably this is the verifier of the terminator op; confirm against the
// full file.
281 LoopSchedulePipelineOp pipeline =
282 (*this)->getParentOfType<LoopSchedulePipelineOp>();
// iter_args: the whole type range is compared first.
285 auto iterArgs = getIterArgs();
286 TypeRange terminatorArgTypes = iterArgs.getTypes();
287 TypeRange pipelineArgTypes = pipeline.getIterArgs().getTypes();
288 if (terminatorArgTypes != pipelineArgTypes)
289 return emitOpError(
"'iter_args' types (")
290 << terminatorArgTypes <<
") must match pipeline 'iter_args' types ("
291 << pipelineArgTypes <<
")";
// getDefiningOp<T>() compares equal to nullptr when the value is not
// produced by an op of type T; such values are rejected.
294 for (
auto iterArg : iterArgs)
295 if (iterArg.getDefiningOp<LoopSchedulePipelineStageOp>() ==
nullptr)
297 "'iter_args' must be defined by a 'loopschedule.pipeline.stage'");
// results: the same two checks as for iter_args, applied to what getResults()
// returns here.
300 auto opResults = getResults();
301 TypeRange terminatorResultTypes = opResults.getTypes();
302 TypeRange pipelineResultTypes = pipeline.getResultTypes();
303 if (terminatorResultTypes != pipelineResultTypes)
304 return emitOpError(
"'results' types (")
305 << terminatorResultTypes <<
") must match pipeline result types ("
306 << pipelineResultTypes <<
")";
309 for (
auto result : opResults)
310 if (result.getDefiningOp<LoopSchedulePipelineStageOp>() ==
nullptr)
312 "'results' must be defined by a 'loopschedule.pipeline.stage'");
// GET_OP_CLASSES is defined before including LoopSchedule.cpp.inc -
// presumably this selects the generated op class definitions from the
// TableGen output; TODO confirm.
317 #define GET_OP_CLASSES
318 #include "circt/Dialect/LoopSchedule/LoopSchedule.cpp.inc"
// Dialect initialization hook. The same generated file is included again
// inside the body.
// NOTE(review): the lines surrounding that include are not visible here -
// presumably they register the op list with the dialect; confirm.
320 void LoopScheduleDialect::initialize() {
323 #include "circt/Dialect/LoopSchedule/LoopSchedule.cpp.inc"
// Generated definitions for the dialect itself.
327 #include "circt/Dialect/LoopSchedule/LoopScheduleDialect.cpp.inc"
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.