CIRCT 20.0.0git
Loading...
Searching...
No Matches
PipelineOps.cpp
Go to the documentation of this file.
1//===- PipelineOps.cpp - Pipeline MLIR Operations ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Pipeline ops.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/Interfaces/FunctionImplementation.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/FormatVariadic.h"
21
22using namespace mlir;
23using namespace circt;
24using namespace circt::pipeline;
25using namespace circt::parsing_util;
26
27#include "circt/Dialect/Pipeline/PipelineDialect.cpp.inc"
28
29#define DEBUG_TYPE "pipeline-ops"
30
31llvm::SmallVector<Value>
33 llvm::SetVector<Value> values;
34 region.walk([&](Operation *op) {
35 for (auto operand : op->getOperands()) {
36 if (region.isAncestor(operand.getParentRegion()))
37 continue;
38 values.insert(operand);
39 }
40 });
41 return values.takeVector();
42}
43
44Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
45 Block *block) {
46 // Optional debug check - ensure that 'block' eventually leads to the
47 // pipeline.
48 LLVM_DEBUG({
49 Operation *directParent = block->getParentOp();
50 if (directParent != pipeline) {
51 auto indirectParent =
52 directParent->getParentOfType<ScheduledPipelineOp>();
53 assert(indirectParent == pipeline && "block is not in the pipeline");
54 }
55 });
56
57 while (block && block->getParent() != &pipeline.getRegion()) {
58 // Go one level up.
59 block = block->getParent()->getParentOp()->getBlock();
60 }
61
62 // This is a block within the pipeline region, so it must be a stage.
63 return block;
64}
65
66Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
67 Operation *op) {
68 return getParentStageInPipeline(pipeline, op->getBlock());
69}
70
71Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
72 Value v) {
73 if (isa<BlockArgument>(v))
74 return getParentStageInPipeline(pipeline,
75 cast<BlockArgument>(v).getOwner());
76 return getParentStageInPipeline(pipeline, v.getDefiningOp());
77}
78
79//===----------------------------------------------------------------------===//
80// Fancy pipeline-like op printer/parser functions.
81//===----------------------------------------------------------------------===//
82
83// Parses a list of operands on the format:
84// (name : type, ...)
85static ParseResult parseOutputList(OpAsmParser &parser,
86 llvm::SmallVector<Type> &inputTypes,
87 mlir::ArrayAttr &outputNames) {
88
89 llvm::SmallVector<Attribute> names;
90 if (parser.parseCommaSeparatedList(
91 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
92 StringRef name;
93 Type type;
94 if (parser.parseKeyword(&name) || parser.parseColonType(type))
95 return failure();
96
97 inputTypes.push_back(type);
98 names.push_back(StringAttr::get(parser.getContext(), name));
99 return success();
100 }))
101 return failure();
102
103 outputNames = ArrayAttr::get(parser.getContext(), names);
104 return success();
105}
106
107static void printOutputList(OpAsmPrinter &p, TypeRange types, ArrayAttr names) {
108 p << "(";
109 llvm::interleaveComma(llvm::zip(types, names), p, [&](auto it) {
110 auto [type, name] = it;
111 p.printKeywordOrString(cast<StringAttr>(name).str());
112 p << " : " << type;
113 });
114 p << ")";
115}
116
117static ParseResult parseKeywordAndOperand(OpAsmParser &p, StringRef keyword,
118 OpAsmParser::UnresolvedOperand &op) {
119 if (p.parseKeyword(keyword) || p.parseLParen() || p.parseOperand(op) ||
120 p.parseRParen())
121 return failure();
122 return success();
123}
124
125// Assembly format is roughly:
126// ( $name )? initializer-list stall (%stall = $stall)?
127// clock (%clock) reset (%reset) go(%go) entryEnable(%en) {
128// --- elided inner block ---
129static ParseResult parsePipelineOp(mlir::OpAsmParser &parser,
130 mlir::OperationState &result) {
131 std::string name;
132 if (succeeded(parser.parseOptionalString(&name)))
133 result.addAttribute("name", parser.getBuilder().getStringAttr(name));
134
135 llvm::SmallVector<OpAsmParser::UnresolvedOperand> inputOperands;
136 llvm::SmallVector<OpAsmParser::Argument> inputArguments;
137 llvm::SmallVector<Type> inputTypes;
138 ArrayAttr inputNames;
139 if (parseInitializerList(parser, inputArguments, inputOperands, inputTypes,
140 inputNames))
141 return failure();
142 result.addAttribute("inputNames", inputNames);
143
144 Type i1 = parser.getBuilder().getI1Type();
145
146 OpAsmParser::UnresolvedOperand stallOperand, clockOperand, resetOperand,
147 goOperand;
148
149 // Parse optional 'stall (%stallArg)'
150 bool withStall = false;
151 if (succeeded(parser.parseOptionalKeyword("stall"))) {
152 if (parser.parseLParen() || parser.parseOperand(stallOperand) ||
153 parser.parseRParen())
154 return failure();
155 withStall = true;
156 }
157
158 // Parse clock, reset, and go.
159 if (parseKeywordAndOperand(parser, "clock", clockOperand))
160 return failure();
161
162 // Parse optional 'reset (%resetArg)'
163 bool withReset = false;
164 if (succeeded(parser.parseOptionalKeyword("reset"))) {
165 if (parser.parseLParen() || parser.parseOperand(resetOperand) ||
166 parser.parseRParen())
167 return failure();
168 withReset = true;
169 }
170
171 if (parseKeywordAndOperand(parser, "go", goOperand))
172 return failure();
173
174 // Parse entry stage enable block argument.
175 OpAsmParser::Argument entryEnable;
176 entryEnable.type = i1;
177 if (parser.parseKeyword("entryEn") || parser.parseLParen() ||
178 parser.parseArgument(entryEnable) || parser.parseRParen())
179 return failure();
180
181 // Optional attribute dict
182 if (parser.parseOptionalAttrDict(result.attributes))
183 return failure();
184
185 // Parse the output assignment list
186 if (parser.parseArrow())
187 return failure();
188
189 llvm::SmallVector<Type> outputTypes;
190 ArrayAttr outputNames;
191 if (parseOutputList(parser, outputTypes, outputNames))
192 return failure();
193 result.addTypes(outputTypes);
194 result.addAttribute("outputNames", outputNames);
195
196 // And the implicit 'done' output.
197 result.addTypes({i1});
198
199 // All operands have been parsed - resolve.
200 if (parser.resolveOperands(inputOperands, inputTypes, parser.getNameLoc(),
201 result.operands))
202 return failure();
203
204 if (withStall) {
205 if (parser.resolveOperand(stallOperand, i1, result.operands))
206 return failure();
207 }
208
209 Type clkType = seq::ClockType::get(parser.getContext());
210 if (parser.resolveOperand(clockOperand, clkType, result.operands))
211 return failure();
212
213 if (withReset && parser.resolveOperand(resetOperand, i1, result.operands))
214 return failure();
215
216 if (parser.resolveOperand(goOperand, i1, result.operands))
217 return failure();
218
219 // Assemble the body region block arguments - this is where the magic happens
220 // and why we're doing a custom printer/parser - if the user had to magically
221 // know the order of these block arguments, we're asking for issues.
222 SmallVector<OpAsmParser::Argument> regionArgs;
223
224 // First we add the input arguments.
225 llvm::append_range(regionArgs, inputArguments);
226 // Then the internal entry stage enable block argument.
227 regionArgs.push_back(entryEnable);
228
229 // Parse the body region.
230 Region *body = result.addRegion();
231 if (parser.parseRegion(*body, regionArgs))
232 return failure();
233
234 result.addAttribute("operandSegmentSizes",
235 parser.getBuilder().getDenseI32ArrayAttr(
236 {static_cast<int32_t>(inputTypes.size()),
237 static_cast<int32_t>(withStall ? 1 : 0),
238 /*clock*/ static_cast<int32_t>(1),
239 /*reset*/ static_cast<int32_t>(withReset ? 1 : 0),
240 /*go*/ static_cast<int32_t>(1)}));
241
242 return success();
243}
244
245static void printKeywordOperand(OpAsmPrinter &p, StringRef keyword,
246 Value value) {
247 p << keyword << "(";
248 p.printOperand(value);
249 p << ")";
250}
251
252template <typename TPipelineOp>
253static void printPipelineOp(OpAsmPrinter &p, TPipelineOp op) {
254 if (auto name = op.getNameAttr()) {
255 p << " \"" << name.getValue() << "\"";
256 }
257
258 // Print the input list.
259 printInitializerList(p, op.getInputs(), op.getInnerInputs());
260 p << " ";
261
262 // Print the optional stall.
263 if (op.hasStall()) {
264 printKeywordOperand(p, "stall", op.getStall());
265 p << " ";
266 }
267
268 // Print the clock, reset, and go.
269 printKeywordOperand(p, "clock", op.getClock());
270 p << " ";
271 if (op.hasReset()) {
272 printKeywordOperand(p, "reset", op.getReset());
273 p << " ";
274 }
275 printKeywordOperand(p, "go", op.getGo());
276 p << " ";
277
278 // Print the entry enable block argument.
279 p << "entryEn(";
280 p.printRegionArgument(
281 cast<BlockArgument>(op.getStageEnableSignal(static_cast<size_t>(0))), {},
282 /*omitType*/ true);
283 p << ") ";
284
285 // Print the optional attribute dict.
286 p.printOptionalAttrDict(op->getAttrs(),
287 /*elidedAttrs=*/{"name", "operandSegmentSizes",
288 "outputNames", "inputNames"});
289 p << " -> ";
290
291 // Print the output list.
292 printOutputList(p, op.getDataOutputs().getTypes(), op.getOutputNames());
293
294 p << " ";
295
296 // Print the inner region, eliding the entry block arguments - we've already
297 // defined these in our initializer lists.
298 p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false,
299 /*printBlockTerminators=*/true);
300}
301
302//===----------------------------------------------------------------------===//
303// UnscheduledPipelineOp
304//===----------------------------------------------------------------------===//
305
306static void buildPipelineLikeOp(OpBuilder &odsBuilder, OperationState &odsState,
307 TypeRange dataOutputs, ValueRange inputs,
308 ArrayAttr inputNames, ArrayAttr outputNames,
309 Value clock, Value go, Value reset, Value stall,
310 StringAttr name, ArrayAttr stallability) {
311 odsState.addOperands(inputs);
312 if (stall)
313 odsState.addOperands(stall);
314 odsState.addOperands(clock);
315 odsState.addOperands(reset);
316 odsState.addOperands(go);
317 if (name)
318 odsState.addAttribute("name", name);
319
320 odsState.addAttribute(
321 "operandSegmentSizes",
322 odsBuilder.getDenseI32ArrayAttr(
323 {static_cast<int32_t>(inputs.size()),
324 static_cast<int32_t>(stall ? 1 : 0), static_cast<int32_t>(1),
325 static_cast<int32_t>(1), static_cast<int32_t>(1)}));
326
327 odsState.addAttribute("inputNames", inputNames);
328 odsState.addAttribute("outputNames", outputNames);
329
330 auto *region = odsState.addRegion();
331 odsState.addTypes(dataOutputs);
332
333 // Add the implicit done output signal.
334 Type i1 = odsBuilder.getIntegerType(1);
335 odsState.addTypes({i1});
336
337 // Add the entry stage - arguments order:
338 // 1. Inputs
339 // 2. Stall (opt)
340 // 3. Clock
341 // 4. Reset
342 // 5. Go
343 auto &entryBlock = region->emplaceBlock();
344 llvm::SmallVector<Location> entryArgLocs(inputs.size(), odsState.location);
345 entryBlock.addArguments(
346 inputs.getTypes(),
347 llvm::SmallVector<Location>(inputs.size(), odsState.location));
348 if (stall)
349 entryBlock.addArgument(i1, odsState.location);
350 entryBlock.addArgument(i1, odsState.location);
351 entryBlock.addArgument(i1, odsState.location);
352
353 // entry stage valid signal.
354 entryBlock.addArgument(i1, odsState.location);
355
356 if (stallability)
357 odsState.addAttribute("stallability", stallability);
358}
359
360template <typename TPipelineOp>
361static void getPipelineAsmResultNames(TPipelineOp op,
362 OpAsmSetValueNameFn setNameFn) {
363 for (auto [res, name] :
364 llvm::zip(op.getDataOutputs(),
365 op.getOutputNames().template getAsValueRange<StringAttr>()))
366 setNameFn(res, name);
367 setNameFn(op.getDone(), "done");
368}
369
370template <typename TPipelineOp>
371static void
372getPipelineAsmBlockArgumentNames(TPipelineOp op, mlir::Region &region,
373 mlir::OpAsmSetValueNameFn setNameFn) {
374 for (auto [i, block] : llvm::enumerate(op.getRegion())) {
375 if (Block *predBlock = block.getSinglePredecessor()) {
376 // Predecessor stageOp might have register and passthrough names
377 // specified, which we can use to name the block arguments.
378 auto predStageOp = cast<StageOp>(predBlock->getTerminator());
379 size_t nRegs = predStageOp.getRegisters().size();
380 auto nPassthrough = predStageOp.getPassthroughs().size();
381
382 auto regNames = predStageOp.getRegisterNames();
383 auto passthroughNames = predStageOp.getPassthroughNames();
384
385 // Register naming...
386 for (size_t regI = 0; regI < nRegs; ++regI) {
387 auto arg = block.getArguments()[regI];
388
389 if (regNames) {
390 auto nameAttr = dyn_cast<StringAttr>((*regNames)[regI]);
391 if (nameAttr && !nameAttr.strref().empty()) {
392 setNameFn(arg, nameAttr);
393 continue;
394 }
395 }
396 setNameFn(arg, llvm::formatv("s{0}_reg{1}", i, regI).str());
397 }
398
399 // Passthrough naming...
400 for (size_t passthroughI = 0; passthroughI < nPassthrough;
401 ++passthroughI) {
402 auto arg = block.getArguments()[nRegs + passthroughI];
403
404 if (passthroughNames) {
405 auto nameAttr =
406 dyn_cast<StringAttr>((*passthroughNames)[passthroughI]);
407 if (nameAttr && !nameAttr.strref().empty()) {
408 setNameFn(arg, nameAttr);
409 continue;
410 }
411 }
412 setNameFn(arg, llvm::formatv("s{0}_pass{1}", i, passthroughI).str());
413 }
414 } else {
415 // This is the entry stage - name the arguments according to the input
416 // names.
417 for (auto [inputArg, inputName] :
418 llvm::zip(op.getInnerInputs(),
419 op.getInputNames().template getAsValueRange<StringAttr>()))
420 setNameFn(inputArg, inputName);
421 }
422
423 // Last argument in any stage is the stage enable signal.
424 setNameFn(block.getArguments().back(),
425 llvm::formatv("s{0}_enable", i).str());
426 }
427}
428
429void UnscheduledPipelineOp::print(OpAsmPrinter &p) {
430 printPipelineOp(p, *this);
431}
432
433ParseResult UnscheduledPipelineOp::parse(OpAsmParser &parser,
434 OperationState &result) {
435 return parsePipelineOp(parser, result);
436}
437
438void UnscheduledPipelineOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
439 getPipelineAsmResultNames(*this, setNameFn);
440}
441
442void UnscheduledPipelineOp::getAsmBlockArgumentNames(
443 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
444 getPipelineAsmBlockArgumentNames(*this, region, setNameFn);
445}
446
447void UnscheduledPipelineOp::build(OpBuilder &odsBuilder,
448 OperationState &odsState,
449 TypeRange dataOutputs, ValueRange inputs,
450 ArrayAttr inputNames, ArrayAttr outputNames,
451 Value clock, Value go, Value reset,
452 Value stall, StringAttr name,
453 ArrayAttr stallability) {
454 buildPipelineLikeOp(odsBuilder, odsState, dataOutputs, inputs, inputNames,
455 outputNames, clock, go, reset, stall, name, stallability);
456}
457
458//===----------------------------------------------------------------------===//
459// ScheduledPipelineOp
460//===----------------------------------------------------------------------===//
461
462void ScheduledPipelineOp::print(OpAsmPrinter &p) { printPipelineOp(p, *this); }
463
464ParseResult ScheduledPipelineOp::parse(OpAsmParser &parser,
465 OperationState &result) {
466 return parsePipelineOp(parser, result);
467}
468
469void ScheduledPipelineOp::build(OpBuilder &odsBuilder, OperationState &odsState,
470 TypeRange dataOutputs, ValueRange inputs,
471 ArrayAttr inputNames, ArrayAttr outputNames,
472 Value clock, Value go, Value reset, Value stall,
473 StringAttr name, ArrayAttr stallability) {
474 buildPipelineLikeOp(odsBuilder, odsState, dataOutputs, inputs, inputNames,
475 outputNames, clock, go, reset, stall, name, stallability);
476}
477
478Block *ScheduledPipelineOp::addStage() {
479 OpBuilder builder(getContext());
480 Block *stage = builder.createBlock(&getRegion());
481
482 // Add the stage valid signal.
483 stage->addArgument(builder.getIntegerType(1), getLoc());
484 return stage;
485}
486
487void ScheduledPipelineOp::getAsmBlockArgumentNames(
488 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
489 getPipelineAsmBlockArgumentNames(*this, region, setNameFn);
490}
491
492void ScheduledPipelineOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
493 getPipelineAsmResultNames(*this, setNameFn);
494}
495
496// Implementation of getOrderedStages which also produces an error if
497// there are any cfg cycles in the pipeline.
498static FailureOr<llvm::SmallVector<Block *>>
499getOrderedStagesFailable(ScheduledPipelineOp op) {
500 llvm::DenseSet<Block *> visited;
501 Block *currentStage = op.getEntryStage();
502 llvm::SmallVector<Block *> orderedStages;
503 do {
504 if (!visited.insert(currentStage).second)
505 return op.emitOpError("pipeline contains a cycle.");
506
507 orderedStages.push_back(currentStage);
508 if (auto stageOp = dyn_cast<StageOp>(currentStage->getTerminator()))
509 currentStage = stageOp.getNextStage();
510 else
511 currentStage = nullptr;
512 } while (currentStage);
513
514 return {orderedStages};
515}
516
517llvm::SmallVector<Block *> ScheduledPipelineOp::getOrderedStages() {
518 // Should always be safe, seeing as the pipeline itself has already been
519 // verified.
520 return *getOrderedStagesFailable(*this);
521}
522
523llvm::DenseMap<Block *, unsigned> ScheduledPipelineOp::getStageMap() {
524 llvm::DenseMap<Block *, unsigned> stageMap;
525 auto orderedStages = getOrderedStages();
526 for (auto [index, stage] : llvm::enumerate(orderedStages))
527 stageMap[stage] = index;
528
529 return stageMap;
530}
531
532Block *ScheduledPipelineOp::getLastStage() { return getOrderedStages().back(); }
533
534bool ScheduledPipelineOp::isMaterialized() {
535 // We determine materialization as if any pipeline stage has an explicit
536 // input (apart from the stage valid signal).
537 return llvm::any_of(getStages(), [this](Block &block) {
538 // The entry stage doesn't count since it'll always have arguments.
539 if (&block == getEntryStage())
540 return false;
541 return block.getNumArguments() > 1;
542 });
543}
544
545LogicalResult ScheduledPipelineOp::verify() {
546 // Verify that all block are terminated properly.
547 auto &stages = getStages();
548 for (Block &stage : stages) {
549 if (stage.empty() || !isa<ReturnOp, StageOp>(stage.back()))
550 return emitOpError("all blocks must be terminated with a "
551 "`pipeline.stage` or `pipeline.return` op.");
552 }
553
554 if (failed(getOrderedStagesFailable(*this)))
555 return failure();
556
557 // Verify that every stage has a stage valid block argument.
558 for (auto [i, block] : llvm::enumerate(stages)) {
559 bool err = true;
560 if (block.getNumArguments() != 0) {
561 auto lastArgType =
562 dyn_cast<IntegerType>(block.getArguments().back().getType());
563 err = !lastArgType || lastArgType.getWidth() != 1;
564 }
565 if (err)
566 return emitOpError("block " + std::to_string(i) +
567 " must have an i1 argument as the last block argument "
568 "(stage valid signal).");
569 }
570
571 // Cache external inputs in a set for fast lookup (also includes clock, reset,
572 // and stall).
573 llvm::DenseSet<Value> extLikeInputs;
574 for (auto extInput : getExtInputs())
575 extLikeInputs.insert(extInput);
576
577 extLikeInputs.insert(getClock());
578 extLikeInputs.insert(getReset());
579 if (hasStall())
580 extLikeInputs.insert(getStall());
581
582 // Phase invariant - if any block has arguments apart from the stage valid
583 // argument, we are in register materialized mode. Check that all values
584 // used within a stage are defined within the stage.
585 bool materialized = isMaterialized();
586 if (materialized) {
587 for (auto &stage : stages) {
588 for (auto &op : stage) {
589 for (auto [index, operand] : llvm::enumerate(op.getOperands())) {
590 bool err = false;
591 if (extLikeInputs.contains(operand)) {
592 // This is an external input; legal to reference everywhere.
593 continue;
594 }
595
596 if (auto *definingOp = operand.getDefiningOp()) {
597 // Constants are allowed to be used across stages.
598 if (definingOp->hasTrait<OpTrait::ConstantLike>())
599 continue;
600 err = definingOp->getBlock() != &stage;
601 } else {
602 // This is a block argument;
603 err = !llvm::is_contained(stage.getArguments(), operand);
604 }
605
606 if (err)
607 return op.emitOpError(
608 "Pipeline is in register materialized mode - operand ")
609 << index
610 << " is defined in a different stage, which is illegal.";
611 }
612 }
613 }
614 }
615
616 if (auto stallability = getStallability()) {
617 // Only allow specifying stallability if there is a stall signal.
618 if (!hasStall())
619 return emitOpError("cannot specify stallability without a stall signal.");
620
621 // Ensure that the # of stages is equal to the length of the stallability
622 // array - the exit stage is never stallable.
623 size_t nRegisterStages = stages.size() - 1;
624 if (stallability->size() != nRegisterStages)
625 return emitOpError("stallability array must be the same length as the "
626 "number of stages. Pipeline has ")
627 << nRegisterStages << " stages but array had "
628 << stallability->size() << " elements.";
629 }
630
631 return success();
632}
633
634StageKind ScheduledPipelineOp::getStageKind(size_t stageIndex) {
635 assert(stageIndex < getNumStages() && "invalid stage index");
636
637 if (!hasStall())
638 return StageKind::Continuous;
639
640 // There is a stall signal - also check whether stage-level stallability is
641 // specified.
642 std::optional<ArrayAttr> stallability = getStallability();
643 if (!stallability) {
644 // All stages are stallable.
645 return StageKind::Stallable;
646 }
647
648 if (stageIndex < stallability->size()) {
649 bool stageIsStallable =
650 cast<BoolAttr>((*stallability)[stageIndex]).getValue();
651 if (!stageIsStallable) {
652 // This is a non-stallable stage.
653 return StageKind::NonStallable;
654 }
655 }
656
657 // Walk backwards from this stage to see if any non-stallable stage exists.
658 // If so, this is a runoff stage.
659 // TODO: This should be a pre-computed property.
660 if (stageIndex == 0)
661 return StageKind::Stallable;
662
663 for (size_t i = stageIndex - 1; i > 0; --i) {
664 if (getStageKind(i) == StageKind::NonStallable)
665 return StageKind::Runoff;
666 }
667 return StageKind::Stallable;
668}
669
670//===----------------------------------------------------------------------===//
671// ReturnOp
672//===----------------------------------------------------------------------===//
673
674LogicalResult ReturnOp::verify() {
675 Operation *parent = getOperation()->getParentOp();
676 size_t nInputs = getInputs().size();
677 auto expectedResults = TypeRange(parent->getResultTypes()).drop_back();
678 size_t expectedNResults = expectedResults.size();
679 if (nInputs != expectedNResults)
680 return emitOpError("expected ")
681 << expectedNResults << " return values, got " << nInputs << ".";
682
683 for (auto [inType, reqType] :
684 llvm::zip(getInputs().getTypes(), expectedResults)) {
685 if (inType != reqType)
686 return emitOpError("expected return value of type ")
687 << reqType << ", got " << inType << ".";
688 }
689
690 return success();
691}
692
693//===----------------------------------------------------------------------===//
694// StageOp
695//===----------------------------------------------------------------------===//
696
697// Parses the form:
698// ($name `=`)? $register : type($register)
699
700static ParseResult
702 OpAsmParser::UnresolvedOperand &v, Type &t,
703 StringAttr &name) {
704 // Parse optional name.
705 std::string nameref;
706 if (succeeded(parser.parseOptionalString(&nameref))) {
707 if (nameref.empty())
708 return parser.emitError(parser.getCurrentLocation(),
709 "name cannot be empty");
710
711 if (failed(parser.parseEqual()))
712 return parser.emitError(parser.getCurrentLocation(),
713 "expected '=' after name");
714 name = parser.getBuilder().getStringAttr(nameref);
715 } else {
716 name = parser.getBuilder().getStringAttr("");
717 }
718
719 // Parse mandatory value and type.
720 if (failed(parser.parseOperand(v)) || failed(parser.parseColonType(t)))
721 return failure();
722
723 return success();
724}
725
726// Parses the form:
727// parseOptNamedTypedAssignment (`gated by` `[` $clockGates `]`)?
728static ParseResult parseSingleStageRegister(
729 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t,
730 llvm::SmallVector<OpAsmParser::UnresolvedOperand> &clockGates,
731 StringAttr &name) {
732 if (failed(parseOptNamedTypedAssignment(parser, v, t, name)))
733 return failure();
734
735 // Parse optional gated-by clause.
736 if (failed(parser.parseOptionalKeyword("gated")))
737 return success();
738
739 if (failed(parser.parseKeyword("by")) ||
740 failed(
741 parser.parseOperandList(clockGates, OpAsmParser::Delimiter::Square)))
742 return failure();
743
744 return success();
745}
746
747// Parses the form:
748// regs( ($name `=`)? $register : type($register) (`gated by` `[` $clockGates
749// `]`)?, ...)
751 OpAsmParser &parser,
752 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &registers,
753 llvm::SmallVector<mlir::Type, 1> &registerTypes,
754 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &clockGates,
755 ArrayAttr &clockGatesPerRegister, ArrayAttr &registerNames) {
756
757 if (failed(parser.parseOptionalKeyword("regs"))) {
758 clockGatesPerRegister = parser.getBuilder().getI64ArrayAttr({});
759 return success(); // no registers to parse.
760 }
761
762 llvm::SmallVector<int64_t> clockGatesPerRegisterList;
763 llvm::SmallVector<Attribute> registerNamesList;
764 bool withNames = false;
765 if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&]() {
766 OpAsmParser::UnresolvedOperand v;
767 Type t;
768 llvm::SmallVector<OpAsmParser::UnresolvedOperand> cgs;
769 StringAttr name;
770 if (parseSingleStageRegister(parser, v, t, cgs, name))
771 return failure();
772 registers.push_back(v);
773 registerTypes.push_back(t);
774 registerNamesList.push_back(name);
775 withNames |= static_cast<bool>(name);
776 llvm::append_range(clockGates, cgs);
777 clockGatesPerRegisterList.push_back(cgs.size());
778 return success();
779 })))
780 return failure();
781
782 clockGatesPerRegister =
783 parser.getBuilder().getI64ArrayAttr(clockGatesPerRegisterList);
784 if (withNames)
785 registerNames = parser.getBuilder().getArrayAttr(registerNamesList);
786
787 return success();
788}
789
790void printStageRegisters(OpAsmPrinter &p, Operation *op, ValueRange registers,
791 TypeRange registerTypes, ValueRange clockGates,
792 ArrayAttr clockGatesPerRegister, ArrayAttr names) {
793 if (registers.empty())
794 return;
795
796 p << "regs(";
797 size_t clockGateStartIdx = 0;
798 llvm::interleaveComma(
799 llvm::enumerate(
800 llvm::zip(registers, registerTypes, clockGatesPerRegister)),
801 p, [&](auto it) {
802 size_t idx = it.index();
803 auto &[reg, type, nClockGatesAttr] = it.value();
804 if (names) {
805 if (auto nameAttr = dyn_cast<StringAttr>(names[idx]);
806 nameAttr && !nameAttr.strref().empty())
807 p << nameAttr << " = ";
808 }
809
810 p << reg << " : " << type;
811 int64_t nClockGates = cast<IntegerAttr>(nClockGatesAttr).getInt();
812 if (nClockGates == 0)
813 return;
814 p << " gated by [";
815 llvm::interleaveComma(clockGates.slice(clockGateStartIdx, nClockGates),
816 p);
817 p << "]";
818 clockGateStartIdx += nClockGates;
819 });
820 p << ")";
821}
822
823void printPassthroughs(OpAsmPrinter &p, Operation *op, ValueRange passthroughs,
824 TypeRange passthroughTypes, ArrayAttr names) {
825
826 if (passthroughs.empty())
827 return;
828
829 p << "pass(";
830 llvm::interleaveComma(
831 llvm::enumerate(llvm::zip(passthroughs, passthroughTypes)), p,
832 [&](auto it) {
833 size_t idx = it.index();
834 auto &[reg, type] = it.value();
835 if (names) {
836 if (auto nameAttr = dyn_cast<StringAttr>(names[idx]);
837 nameAttr && !nameAttr.strref().empty())
838 p << nameAttr << " = ";
839 }
840 p << reg << " : " << type;
841 });
842 p << ")";
843}
844
845// Parses the form:
846// (`pass` `(` ($name `=`)? $register : type($register), ... `)` )?
848 OpAsmParser &parser,
849 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &passthroughs,
850 llvm::SmallVector<mlir::Type, 1> &passthroughTypes,
851 ArrayAttr &passthroughNames) {
852 if (failed(parser.parseOptionalKeyword("pass")))
853 return success(); // no passthroughs to parse.
854
855 llvm::SmallVector<Attribute> passthroughsNameList;
856 bool withNames = false;
857 if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&]() {
858 OpAsmParser::UnresolvedOperand v;
859 Type t;
860 StringAttr name;
861 if (parseOptNamedTypedAssignment(parser, v, t, name))
862 return failure();
863 passthroughs.push_back(v);
864 passthroughTypes.push_back(t);
865 passthroughsNameList.push_back(name);
866 withNames |= static_cast<bool>(name);
867 return success();
868 })))
869 return failure();
870
871 if (withNames)
872 passthroughNames = parser.getBuilder().getArrayAttr(passthroughsNameList);
873
874 return success();
875}
876
877void StageOp::build(OpBuilder &odsBuilder, OperationState &odsState,
878 Block *dest, ValueRange registers,
879 ValueRange passthroughs) {
880 odsState.addSuccessors(dest);
881 odsState.addOperands(registers);
882 odsState.addOperands(passthroughs);
883 odsState.addAttribute("operandSegmentSizes",
884 odsBuilder.getDenseI32ArrayAttr(
885 {static_cast<int32_t>(registers.size()),
886 static_cast<int32_t>(passthroughs.size()),
887 /*clock gates*/ static_cast<int32_t>(0)}));
888 llvm::SmallVector<int64_t> clockGatesPerRegister(registers.size(), 0);
889 odsState.addAttribute("clockGatesPerRegister",
890 odsBuilder.getI64ArrayAttr(clockGatesPerRegister));
891}
892
893void StageOp::build(OpBuilder &odsBuilder, OperationState &odsState,
894 Block *dest, ValueRange registers, ValueRange passthroughs,
895 llvm::ArrayRef<llvm::SmallVector<Value>> clockGateList,
896 mlir::ArrayAttr registerNames,
897 mlir::ArrayAttr passthroughNames) {
898 build(odsBuilder, odsState, dest, registers, passthroughs);
899
900 llvm::SmallVector<Value> clockGates;
901 llvm::SmallVector<int64_t> clockGatesPerRegister(registers.size(), 0);
902 for (auto gates : clockGateList) {
903 llvm::append_range(clockGates, gates);
904 clockGatesPerRegister.push_back(gates.size());
905 }
906 odsState.attributes.set("clockGatesPerRegister",
907 odsBuilder.getI64ArrayAttr(clockGatesPerRegister));
908 odsState.addOperands(clockGates);
909
910 if (registerNames)
911 odsState.addAttribute("registerNames", registerNames);
912
913 if (passthroughNames)
914 odsState.addAttribute("passthroughNames", passthroughNames);
915}
916
917ValueRange StageOp::getClockGatesForReg(unsigned regIdx) {
918 assert(regIdx < getRegisters().size() && "register index out of bounds.");
919
920 // TODO: This could be optimized quite a bit if we didn't store clock
921 // gates per register as an array of sizes... look into using properties
922 // and maybe attaching a more complex datastructure to reduce compute
923 // here.
924
925 unsigned clockGateStartIdx = 0;
926 for (auto [index, nClockGatesAttr] :
927 llvm::enumerate(getClockGatesPerRegister().getAsRange<IntegerAttr>())) {
928 int64_t nClockGates = nClockGatesAttr.getInt();
929 if (index == regIdx) {
930 // This is the register we are looking for.
931 return getClockGates().slice(clockGateStartIdx, nClockGates);
932 }
933 // Increment the start index by the number of clock gates for this
934 // register.
935 clockGateStartIdx += nClockGates;
936 }
937
938 llvm_unreachable("register index out of bounds.");
939}
940
941LogicalResult StageOp::verify() {
942 // Verify that the target block has the correct arguments as this stage
943 // op.
944 llvm::SmallVector<Type> expectedTargetArgTypes;
945 llvm::append_range(expectedTargetArgTypes, getRegisters().getTypes());
946 llvm::append_range(expectedTargetArgTypes, getPassthroughs().getTypes());
947 Block *targetStage = getNextStage();
948 // Expected types is everything but the stage valid signal.
949 TypeRange targetStageArgTypes =
950 TypeRange(targetStage->getArgumentTypes()).drop_back();
951
952 if (targetStageArgTypes.size() != expectedTargetArgTypes.size())
953 return emitOpError("expected ") << expectedTargetArgTypes.size()
954 << " arguments in the target stage, got "
955 << targetStageArgTypes.size() << ".";
956
957 for (auto [index, it] : llvm::enumerate(
958 llvm::zip(expectedTargetArgTypes, targetStageArgTypes))) {
959 auto [arg, barg] = it;
960 if (arg != barg)
961 return emitOpError("expected target stage argument ")
962 << index << " to have type " << arg << ", got " << barg << ".";
963 }
964
965 // Verify that the clock gate index list is equally sized to the # of
966 // registers.
967 if (getClockGatesPerRegister().size() != getRegisters().size())
968 return emitOpError("expected clockGatesPerRegister to be equally sized to "
969 "the number of registers.");
970
971 // Verify that, if provided, the list of register names is equally sized
972 // to the number of registers.
973 if (auto regNames = getRegisterNames()) {
974 if (regNames->size() != getRegisters().size())
975 return emitOpError("expected registerNames to be equally sized to "
976 "the number of registers.");
977 }
978
979 // Verify that, if provided, the list of passthrough names is equally sized
980 // to the number of passthroughs.
981 if (auto passthroughNames = getPassthroughNames()) {
982 if (passthroughNames->size() != getPassthroughs().size())
983 return emitOpError("expected passthroughNames to be equally sized to "
984 "the number of passthroughs.");
985 }
986
987 return success();
988}
989
990//===----------------------------------------------------------------------===//
991// LatencyOp
992//===----------------------------------------------------------------------===//
993
994LogicalResult LatencyOp::verify() {
995 ScheduledPipelineOp scheduledPipelineParent =
996 dyn_cast<ScheduledPipelineOp>(getOperation()->getParentOp());
997
998 if (!scheduledPipelineParent) {
999 // Nothing to verify, got to assume that anything goes in an unscheduled
1000 // pipeline.
1001 return success();
1002 }
1003
1004 // Verify that there's at least one result type. Latency ops don't make sense
1005 // if they're not delaying anything, and we're not yet prepared to support
1006 // side-effectful bodies.
1007 if (getNumResults() == 0)
1008 return emitOpError("expected at least one result type.");
1009
1010 // Verify that the resulting values aren't referenced before they are
1011 // accessible.
1012 size_t latency = getLatency();
1013 Block *definingStage = getOperation()->getBlock();
1014
1015 llvm::DenseMap<Block *, unsigned> stageMap =
1016 scheduledPipelineParent.getStageMap();
1017
1018 auto stageDistance = [&](Block *from, Block *to) {
1019 assert(stageMap.count(from) && "stage 'from' not contained in pipeline");
1020 assert(stageMap.count(to) && "stage 'to' not contained in pipeline");
1021 int64_t fromStage = stageMap[from];
1022 int64_t toStage = stageMap[to];
1023 return toStage - fromStage;
1024 };
1025
1026 for (auto [i, res] : llvm::enumerate(getResults())) {
1027 for (auto &use : res.getUses()) {
1028 auto *user = use.getOwner();
1029
1030 // The user may reside within a block which is not a stage (e.g.
1031 // inside a pipeline.latency op). Determine the stage which this use
1032 // resides within.
1033 Block *userStage =
1034 getParentStageInPipeline(scheduledPipelineParent, user);
1035 unsigned useDistance = stageDistance(definingStage, userStage);
1036
1037 // Is this a stage op and is the value passed through? if so, this is
1038 // a legal use.
1039 StageOp stageOp = dyn_cast<StageOp>(user);
1040 if (userStage == definingStage && stageOp) {
1041 if (llvm::is_contained(stageOp.getPassthroughs(), res))
1042 continue;
1043 }
1044
1045 // The use is not a passthrough. Check that the distance between
1046 // the defining stage and the user stage is at least the latency of
1047 // the result.
1048 if (useDistance < latency) {
1049 auto diag = emitOpError("result ")
1050 << i << " is used before it is available.";
1051 diag.attachNote(user->getLoc())
1052 << "use was operand " << use.getOperandNumber()
1053 << ". The result is available " << latency - useDistance
1054 << " stages later than this use.";
1055 return diag;
1056 }
1057 }
1058 }
1059 return success();
1060}
1061
1062//===----------------------------------------------------------------------===//
1063// LatencyReturnOp
1064//===----------------------------------------------------------------------===//
1065
1066LogicalResult LatencyReturnOp::verify() {
1067 LatencyOp parent = cast<LatencyOp>(getOperation()->getParentOp());
1068 size_t nInputs = getInputs().size();
1069 size_t nResults = parent->getNumResults();
1070 if (nInputs != nResults)
1071 return emitOpError("expected ")
1072 << nResults << " return values, got " << nInputs << ".";
1073
1074 for (auto [inType, reqType] :
1075 llvm::zip(getInputs().getTypes(), parent->getResultTypes())) {
1076 if (inType != reqType)
1077 return emitOpError("expected return value of type ")
1078 << reqType << ", got " << inType << ".";
1079 }
1080
1081 return success();
1082}
1083
1084#define GET_OP_CLASSES
1085#include "circt/Dialect/Pipeline/Pipeline.cpp.inc"
1086
// Registers the operations of the Pipeline dialect. The generated include,
// with GET_OP_LIST defined, sits in the template-argument position of
// addOperations and is expected to supply the list of op classes. The
// directive order inside the angle brackets must therefore be kept as-is.
void PipelineDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "circt/Dialect/Pipeline/Pipeline.cpp.inc"
      >();
}
assert(baseType &&"element must be base type")
static InstancePath empty
static void buildPipelineLikeOp(OpBuilder &odsBuilder, OperationState &odsState, TypeRange dataOutputs, ValueRange inputs, ArrayAttr inputNames, ArrayAttr outputNames, Value clock, Value go, Value reset, Value stall, StringAttr name, ArrayAttr stallability)
static ParseResult parseSingleStageRegister(OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t, llvm::SmallVector< OpAsmParser::UnresolvedOperand > &clockGates, StringAttr &name)
static FailureOr< llvm::SmallVector< Block * > > getOrderedStagesFailable(ScheduledPipelineOp op)
static ParseResult parseKeywordAndOperand(OpAsmParser &p, StringRef keyword, OpAsmParser::UnresolvedOperand &op)
static void printOutputList(OpAsmPrinter &p, TypeRange types, ArrayAttr names)
ParseResult parsePassthroughs(OpAsmParser &parser, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &passthroughs, llvm::SmallVector< mlir::Type, 1 > &passthroughTypes, ArrayAttr &passthroughNames)
static void printPipelineOp(OpAsmPrinter &p, TPipelineOp op)
void printPassthroughs(OpAsmPrinter &p, Operation *op, ValueRange passthroughs, TypeRange passthroughTypes, ArrayAttr names)
static ParseResult parseOptNamedTypedAssignment(OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t, StringAttr &name)
static ParseResult parsePipelineOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
void printStageRegisters(OpAsmPrinter &p, Operation *op, ValueRange registers, TypeRange registerTypes, ValueRange clockGates, ArrayAttr clockGatesPerRegister, ArrayAttr names)
static void getPipelineAsmBlockArgumentNames(TPipelineOp op, mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn)
static ParseResult parseOutputList(OpAsmParser &parser, llvm::SmallVector< Type > &inputTypes, mlir::ArrayAttr &outputNames)
static void printKeywordOperand(OpAsmPrinter &p, StringRef keyword, Value value)
ParseResult parseStageRegisters(OpAsmParser &parser, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &registers, llvm::SmallVector< mlir::Type, 1 > &registerTypes, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &clockGates, ArrayAttr &clockGatesPerRegister, ArrayAttr &registerNames)
static void getPipelineAsmResultNames(TPipelineOp op, OpAsmSetValueNameFn setNameFn)
ParseResult parseInitializerList(mlir::OpAsmParser &parser, llvm::SmallVector< mlir::OpAsmParser::Argument > &inputArguments, llvm::SmallVector< mlir::OpAsmParser::UnresolvedOperand > &inputOperands, llvm::SmallVector< Type > &inputTypes, ArrayAttr &inputNames)
Parses an initializer.
void printInitializerList(OpAsmPrinter &p, ValueRange ins, ArrayRef< BlockArgument > args)
llvm::SmallVector< Value > getValuesDefinedOutsideRegion(Region &region)
Block * getParentStageInPipeline(ScheduledPipelineOp pipeline, Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:182