CIRCT 21.0.0git
Loading...
Searching...
No Matches
PipelineOps.cpp
Go to the documentation of this file.
//===- PipelineOps.cpp - Pipeline MLIR Operations ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the Pipeline ops.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/Interfaces/FunctionImplementation.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/FormatVariadic.h"
21
22using namespace mlir;
23using namespace circt;
24using namespace circt::pipeline;
25using namespace circt::parsing_util;
26
27#include "circt/Dialect/Pipeline/PipelineDialect.cpp.inc"
28
29#define DEBUG_TYPE "pipeline-ops"
30
// Returns every value used as an operand by some operation nested in `region`
// that is not defined in `region` itself or in any region nested within it.
// Each such value appears once, in the order the walk first encounters it.
// NOTE(review): the signature line (original source line 32) is absent from
// this listing, so the exact name/parameter list cannot be confirmed here.
31llvm::SmallVector<Value>
33 llvm::SetVector<Value> values;
34 region.walk([&](Operation *op) {
35 for (auto operand : op->getOperands()) {
// `isAncestor` also accepts `region` itself, so values defined directly in
// `region` or in any nested region are skipped as internal.
36 if (region.isAncestor(operand.getParentRegion()))
37 continue;
// SetVector de-duplicates while preserving first-insertion order.
38 values.insert(operand);
39 }
40 });
41 return values.takeVector();
42}
43
44Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
45 Block *block) {
46 // Optional debug check - ensure that 'block' eventually leads to the
47 // pipeline.
48 LLVM_DEBUG({
49 Operation *directParent = block->getParentOp();
50 if (directParent != pipeline) {
51 auto indirectParent =
52 directParent->getParentOfType<ScheduledPipelineOp>();
53 assert(indirectParent == pipeline && "block is not in the pipeline");
54 }
55 });
56
57 while (block && block->getParent() != &pipeline.getRegion()) {
58 // Go one level up.
59 block = block->getParent()->getParentOp()->getBlock();
60 }
61
62 // This is a block within the pipeline region, so it must be a stage.
63 return block;
64}
65
66Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
67 Operation *op) {
68 return getParentStageInPipeline(pipeline, op->getBlock());
69}
70
71Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
72 Value v) {
73 if (isa<BlockArgument>(v))
74 return getParentStageInPipeline(pipeline,
75 cast<BlockArgument>(v).getOwner());
76 return getParentStageInPipeline(pipeline, v.getDefiningOp());
77}
78
79//===----------------------------------------------------------------------===//
80// Fancy pipeline-like op printer/parser functions.
81//===----------------------------------------------------------------------===//
82
83// Parses a list of operands on the format:
84// (name : type, ...)
85static ParseResult parseOutputList(OpAsmParser &parser,
86 llvm::SmallVector<Type> &inputTypes,
87 mlir::ArrayAttr &outputNames) {
88
89 llvm::SmallVector<Attribute> names;
90 if (parser.parseCommaSeparatedList(
91 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
92 StringRef name;
93 Type type;
94 if (parser.parseKeyword(&name) || parser.parseColonType(type))
95 return failure();
96
97 inputTypes.push_back(type);
98 names.push_back(StringAttr::get(parser.getContext(), name));
99 return success();
100 }))
101 return failure();
102
103 outputNames = ArrayAttr::get(parser.getContext(), names);
104 return success();
105}
106
107static void printOutputList(OpAsmPrinter &p, TypeRange types, ArrayAttr names) {
108 p << "(";
109 llvm::interleaveComma(llvm::zip(types, names), p, [&](auto it) {
110 auto [type, name] = it;
111 p.printKeywordOrString(cast<StringAttr>(name).str());
112 p << " : " << type;
113 });
114 p << ")";
115}
116
117static ParseResult parseKeywordAndOperand(OpAsmParser &p, StringRef keyword,
118 OpAsmParser::UnresolvedOperand &op) {
119 if (p.parseKeyword(keyword) || p.parseLParen() || p.parseOperand(op) ||
120 p.parseRParen())
121 return failure();
122 return success();
123}
124
125// Assembly format is roughly:
126// ( $name )? initializer-list stall (%stall = $stall)?
127// clock (%clock) reset (%reset) go(%go) entryEnable(%en) {
128// --- elided inner block ---
129static ParseResult parsePipelineOp(mlir::OpAsmParser &parser,
130 mlir::OperationState &result) {
131 std::string name;
132 if (succeeded(parser.parseOptionalString(&name)))
133 result.addAttribute("name", parser.getBuilder().getStringAttr(name));
134
135 llvm::SmallVector<OpAsmParser::UnresolvedOperand> inputOperands;
136 llvm::SmallVector<OpAsmParser::Argument> inputArguments;
137 llvm::SmallVector<Type> inputTypes;
138 ArrayAttr inputNames;
139 if (parseInitializerList(parser, inputArguments, inputOperands, inputTypes,
140 inputNames))
141 return failure();
142 result.addAttribute("inputNames", inputNames);
143
144 Type i1 = parser.getBuilder().getI1Type();
145
146 OpAsmParser::UnresolvedOperand stallOperand, clockOperand, resetOperand,
147 goOperand;
148
149 // Parse optional 'stall (%stallArg)'
150 bool withStall = false;
151 if (succeeded(parser.parseOptionalKeyword("stall"))) {
152 if (parser.parseLParen() || parser.parseOperand(stallOperand) ||
153 parser.parseRParen())
154 return failure();
155 withStall = true;
156 }
157
158 // Parse clock, reset, and go.
159 if (parseKeywordAndOperand(parser, "clock", clockOperand))
160 return failure();
161
162 // Parse optional 'reset (%resetArg)'
163 bool withReset = false;
164 if (succeeded(parser.parseOptionalKeyword("reset"))) {
165 if (parser.parseLParen() || parser.parseOperand(resetOperand) ||
166 parser.parseRParen())
167 return failure();
168 withReset = true;
169 }
170
171 if (parseKeywordAndOperand(parser, "go", goOperand))
172 return failure();
173
174 // Parse entry stage enable block argument.
175 OpAsmParser::Argument entryEnable;
176 entryEnable.type = i1;
177 if (parser.parseKeyword("entryEn") || parser.parseLParen() ||
178 parser.parseArgument(entryEnable) || parser.parseRParen())
179 return failure();
180
181 // Optional attribute dict
182 if (parser.parseOptionalAttrDict(result.attributes))
183 return failure();
184
185 // Parse the output assignment list
186 if (parser.parseArrow())
187 return failure();
188
189 llvm::SmallVector<Type> outputTypes;
190 ArrayAttr outputNames;
191 if (parseOutputList(parser, outputTypes, outputNames))
192 return failure();
193 result.addTypes(outputTypes);
194 result.addAttribute("outputNames", outputNames);
195
196 // And the implicit 'done' output.
197 result.addTypes({i1});
198
199 // All operands have been parsed - resolve.
200 if (parser.resolveOperands(inputOperands, inputTypes, parser.getNameLoc(),
201 result.operands))
202 return failure();
203
204 if (withStall) {
205 if (parser.resolveOperand(stallOperand, i1, result.operands))
206 return failure();
207 }
208
209 Type clkType = seq::ClockType::get(parser.getContext());
210 if (parser.resolveOperand(clockOperand, clkType, result.operands))
211 return failure();
212
213 if (withReset && parser.resolveOperand(resetOperand, i1, result.operands))
214 return failure();
215
216 if (parser.resolveOperand(goOperand, i1, result.operands))
217 return failure();
218
219 // Assemble the body region block arguments - this is where the magic happens
220 // and why we're doing a custom printer/parser - if the user had to magically
221 // know the order of these block arguments, we're asking for issues.
222 SmallVector<OpAsmParser::Argument> regionArgs;
223
224 // First we add the input arguments.
225 llvm::append_range(regionArgs, inputArguments);
226 // Then the internal entry stage enable block argument.
227 regionArgs.push_back(entryEnable);
228
229 // Parse the body region.
230 Region *body = result.addRegion();
231 if (parser.parseRegion(*body, regionArgs))
232 return failure();
233
234 result.addAttribute("operandSegmentSizes",
235 parser.getBuilder().getDenseI32ArrayAttr(
236 {static_cast<int32_t>(inputTypes.size()),
237 static_cast<int32_t>(withStall ? 1 : 0),
238 /*clock*/ static_cast<int32_t>(1),
239 /*reset*/ static_cast<int32_t>(withReset ? 1 : 0),
240 /*go*/ static_cast<int32_t>(1)}));
241
242 return success();
243}
244
245static void printKeywordOperand(OpAsmPrinter &p, StringRef keyword,
246 Value value) {
247 p << keyword << "(";
248 p.printOperand(value);
249 p << ")";
250}
251
252template <typename TPipelineOp>
253static void printPipelineOp(OpAsmPrinter &p, TPipelineOp op) {
254 if (auto name = op.getNameAttr()) {
255 p << " \"" << name.getValue() << "\"";
256 }
257
258 // Print the input list.
259 printInitializerList(p, op.getInputs(), op.getInnerInputs());
260 p << " ";
261
262 // Print the optional stall.
263 if (op.hasStall()) {
264 printKeywordOperand(p, "stall", op.getStall());
265 p << " ";
266 }
267
268 // Print the clock, reset, and go.
269 printKeywordOperand(p, "clock", op.getClock());
270 p << " ";
271 if (op.hasReset()) {
272 printKeywordOperand(p, "reset", op.getReset());
273 p << " ";
274 }
275 printKeywordOperand(p, "go", op.getGo());
276 p << " ";
277
278 // Print the entry enable block argument.
279 p << "entryEn(";
280 p.printRegionArgument(
281 cast<BlockArgument>(op.getStageEnableSignal(static_cast<size_t>(0))), {},
282 /*omitType*/ true);
283 p << ") ";
284
285 // Print the optional attribute dict.
286 p.printOptionalAttrDict(op->getAttrs(),
287 /*elidedAttrs=*/{"name", "operandSegmentSizes",
288 "outputNames", "inputNames"});
289 p << " -> ";
290
291 // Print the output list.
292 printOutputList(p, op.getDataOutputs().getTypes(), op.getOutputNames());
293
294 p << " ";
295
296 // Print the inner region, eliding the entry block arguments - we've already
297 // defined these in our initializer lists.
298 p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false,
299 /*printBlockTerminators=*/true);
300}
301
302//===----------------------------------------------------------------------===//
303// UnscheduledPipelineOp
304//===----------------------------------------------------------------------===//
305
306static void buildPipelineLikeOp(OpBuilder &odsBuilder, OperationState &odsState,
307 TypeRange dataOutputs, ValueRange inputs,
308 ArrayAttr inputNames, ArrayAttr outputNames,
309 Value clock, Value go, Value reset, Value stall,
310 StringAttr name, ArrayAttr stallability) {
311 odsState.addOperands(inputs);
312 if (stall)
313 odsState.addOperands(stall);
314 odsState.addOperands(clock);
315 odsState.addOperands(reset);
316 odsState.addOperands(go);
317 if (name)
318 odsState.addAttribute("name", name);
319
320 odsState.addAttribute(
321 "operandSegmentSizes",
322 odsBuilder.getDenseI32ArrayAttr(
323 {static_cast<int32_t>(inputs.size()),
324 static_cast<int32_t>(stall ? 1 : 0), static_cast<int32_t>(1),
325 static_cast<int32_t>(1), static_cast<int32_t>(1)}));
326
327 odsState.addAttribute("inputNames", inputNames);
328 odsState.addAttribute("outputNames", outputNames);
329
330 auto *region = odsState.addRegion();
331 odsState.addTypes(dataOutputs);
332
333 // Add the implicit done output signal.
334 Type i1 = odsBuilder.getIntegerType(1);
335 odsState.addTypes({i1});
336
337 // Add the entry stage - arguments order:
338 // 1. Inputs
339 // 2. Stall (opt)
340 // 3. Clock
341 // 4. Reset
342 // 5. Go
343 auto &entryBlock = region->emplaceBlock();
344 llvm::SmallVector<Location> entryArgLocs(inputs.size(), odsState.location);
345 entryBlock.addArguments(
346 inputs.getTypes(),
347 llvm::SmallVector<Location>(inputs.size(), odsState.location));
348 if (stall)
349 entryBlock.addArgument(i1, odsState.location);
350 entryBlock.addArgument(i1, odsState.location);
351 entryBlock.addArgument(i1, odsState.location);
352
353 // entry stage valid signal.
354 entryBlock.addArgument(i1, odsState.location);
355
356 if (stallability)
357 odsState.addAttribute("stallability", stallability);
358}
359
360template <typename TPipelineOp>
361static void getPipelineAsmResultNames(TPipelineOp op,
362 OpAsmSetValueNameFn setNameFn) {
363 for (auto [res, name] :
364 llvm::zip(op.getDataOutputs(),
365 op.getOutputNames().template getAsValueRange<StringAttr>()))
366 setNameFn(res, name);
367 setNameFn(op.getDone(), "done");
368}
369
370template <typename TPipelineOp>
371static void
372getPipelineAsmBlockArgumentNames(TPipelineOp op, mlir::Region &region,
373 mlir::OpAsmSetValueNameFn setNameFn) {
374 for (auto [i, block] : llvm::enumerate(op.getRegion())) {
375 if (Block *predBlock = block.getSinglePredecessor()) {
376 // Predecessor stageOp might have register and passthrough names
377 // specified, which we can use to name the block arguments.
378 auto predStageOp = cast<StageOp>(predBlock->getTerminator());
379 size_t nRegs = predStageOp.getRegisters().size();
380 auto nPassthrough = predStageOp.getPassthroughs().size();
381
382 auto regNames = predStageOp.getRegisterNames();
383 auto passthroughNames = predStageOp.getPassthroughNames();
384
385 // Register naming...
386 for (size_t regI = 0; regI < nRegs; ++regI) {
387 auto arg = block.getArguments()[regI];
388
389 if (regNames) {
390 auto nameAttr = dyn_cast<StringAttr>((*regNames)[regI]);
391 if (nameAttr && !nameAttr.strref().empty()) {
392 setNameFn(arg, nameAttr);
393 continue;
394 }
395 }
396 setNameFn(arg, llvm::formatv("s{0}_reg{1}", i, regI).str());
397 }
398
399 // Passthrough naming...
400 for (size_t passthroughI = 0; passthroughI < nPassthrough;
401 ++passthroughI) {
402 auto arg = block.getArguments()[nRegs + passthroughI];
403
404 if (passthroughNames) {
405 auto nameAttr =
406 dyn_cast<StringAttr>((*passthroughNames)[passthroughI]);
407 if (nameAttr && !nameAttr.strref().empty()) {
408 setNameFn(arg, nameAttr);
409 continue;
410 }
411 }
412 setNameFn(arg, llvm::formatv("s{0}_pass{1}", i, passthroughI).str());
413 }
414 } else {
415 // This is the entry stage - name the arguments according to the input
416 // names.
417 for (auto [inputArg, inputName] :
418 llvm::zip(op.getInnerInputs(),
419 op.getInputNames().template getAsValueRange<StringAttr>()))
420 setNameFn(inputArg, inputName);
421 }
422
423 // Last argument in any stage is the stage enable signal.
424 setNameFn(block.getArguments().back(),
425 llvm::formatv("s{0}_enable", i).str());
426 }
427}
428
429void UnscheduledPipelineOp::print(OpAsmPrinter &p) {
430 printPipelineOp(p, *this);
431}
432
433ParseResult UnscheduledPipelineOp::parse(OpAsmParser &parser,
434 OperationState &result) {
435 return parsePipelineOp(parser, result);
436}
437
438void UnscheduledPipelineOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
439 getPipelineAsmResultNames(*this, setNameFn);
440}
441
442void UnscheduledPipelineOp::getAsmBlockArgumentNames(
443 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
444 getPipelineAsmBlockArgumentNames(*this, region, setNameFn);
445}
446
447void UnscheduledPipelineOp::build(OpBuilder &odsBuilder,
448 OperationState &odsState,
449 TypeRange dataOutputs, ValueRange inputs,
450 ArrayAttr inputNames, ArrayAttr outputNames,
451 Value clock, Value go, Value reset,
452 Value stall, StringAttr name,
453 ArrayAttr stallability) {
454 buildPipelineLikeOp(odsBuilder, odsState, dataOutputs, inputs, inputNames,
455 outputNames, clock, go, reset, stall, name, stallability);
456}
457
458//===----------------------------------------------------------------------===//
459// ScheduledPipelineOp
460//===----------------------------------------------------------------------===//
461
462void ScheduledPipelineOp::print(OpAsmPrinter &p) { printPipelineOp(p, *this); }
463
464ParseResult ScheduledPipelineOp::parse(OpAsmParser &parser,
465 OperationState &result) {
466 return parsePipelineOp(parser, result);
467}
468
469void ScheduledPipelineOp::build(OpBuilder &odsBuilder, OperationState &odsState,
470 TypeRange dataOutputs, ValueRange inputs,
471 ArrayAttr inputNames, ArrayAttr outputNames,
472 Value clock, Value go, Value reset, Value stall,
473 StringAttr name, ArrayAttr stallability) {
474 buildPipelineLikeOp(odsBuilder, odsState, dataOutputs, inputs, inputNames,
475 outputNames, clock, go, reset, stall, name, stallability);
476}
477
478Block *ScheduledPipelineOp::addStage() {
479 OpBuilder builder(getContext());
480 Block *stage = builder.createBlock(&getRegion());
481
482 // Add the stage valid signal.
483 stage->addArgument(builder.getIntegerType(1), getLoc());
484 return stage;
485}
486
487void ScheduledPipelineOp::getAsmBlockArgumentNames(
488 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
489 getPipelineAsmBlockArgumentNames(*this, region, setNameFn);
490}
491
492void ScheduledPipelineOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
493 getPipelineAsmResultNames(*this, setNameFn);
494}
495
496// Implementation of getOrderedStages which also produces an error if
497// there are any cfg cycles in the pipeline.
498static FailureOr<llvm::SmallVector<Block *>>
499getOrderedStagesFailable(ScheduledPipelineOp op) {
500 llvm::DenseSet<Block *> visited;
501 Block *currentStage = op.getEntryStage();
502 llvm::SmallVector<Block *> orderedStages;
503 do {
504 if (!visited.insert(currentStage).second)
505 return op.emitOpError("pipeline contains a cycle.");
506
507 orderedStages.push_back(currentStage);
508 if (auto stageOp = dyn_cast<StageOp>(currentStage->getTerminator()))
509 currentStage = stageOp.getNextStage();
510 else
511 currentStage = nullptr;
512 } while (currentStage);
513
514 return {orderedStages};
515}
516
517llvm::SmallVector<Block *> ScheduledPipelineOp::getOrderedStages() {
518 // Should always be safe, seeing as the pipeline itself has already been
519 // verified.
520 return *getOrderedStagesFailable(*this);
521}
522
523llvm::DenseMap<Block *, unsigned> ScheduledPipelineOp::getStageMap() {
524 llvm::DenseMap<Block *, unsigned> stageMap;
525 auto orderedStages = getOrderedStages();
526 for (auto [index, stage] : llvm::enumerate(orderedStages))
527 stageMap[stage] = index;
528
529 return stageMap;
530}
531
532Block *ScheduledPipelineOp::getLastStage() { return getOrderedStages().back(); }
533
534bool ScheduledPipelineOp::isMaterialized() {
535 // We determine materialization as if any pipeline stage has an explicit
536 // input (apart from the stage enable signal).
537 return llvm::any_of(getStages(), [this](Block &block) {
538 // The entry stage doesn't count since it'll always have arguments.
539 if (&block == getEntryStage())
540 return false;
541 return block.getNumArguments() > 1;
542 });
543}
544
545// Returns true if 'current' is nested somewhere within the 'parent' block,
546// or current == parent.
547// `stopAt` is provided as a termination condition for the recursive lookup.
548// Once stopAt is encountered, `isNestedBlock` will return false.
549static bool isNestedBlock(Block *stopAt, Block *parent, Block *current) {
550 while (current) {
551 if (current == stopAt)
552 return false;
553 if (current == parent)
554 return true;
555 current = current->getParentOp()->getBlock();
556 }
557 return false;
558}
559
560// Check whether the value referenced by `use` is defined within the provided
561// `stage`. It is assumed that the OpOperand `use` (i.e. the operation that owns
562// `use`) is defined within `stage`.
563// `stopAt` is provided as a termination condition for the recursive lookup.
564// Once stopAt is encountered, `isNestedBlock` will return false.
565static bool useDefinedInStage(Block *stopAt, Block *stage, OpOperand &use) {
566 Block *useBlock = use.getOwner()->getBlock();
567 Block *definingBlock = use.get().getParentBlock();
568
569 assert(isNestedBlock(stopAt, stage, useBlock) &&
570 "use` must originate from within `stage`");
571
572 // Common-case checks...
573 if (useBlock == definingBlock || stage == definingBlock)
574 return true;
575
576 // Else, recurse upwards from the defining block to see if we can find the
577 // stage.
578 Block *currBlock = definingBlock;
579 return isNestedBlock(stopAt, stage, currBlock);
580}
581
582LogicalResult ScheduledPipelineOp::verify() {
583 // Verify that all block are terminated properly.
584 auto &stages = getStages();
585 for (Block &stage : stages) {
586 if (stage.empty() || !isa<ReturnOp, StageOp>(stage.back()))
587 return emitOpError("all blocks must be terminated with a "
588 "`pipeline.stage` or `pipeline.return` op.");
589 }
590
591 if (failed(getOrderedStagesFailable(*this)))
592 return failure();
593
594 // Verify that every stage has a stage valid block argument.
595 for (auto [i, block] : llvm::enumerate(stages)) {
596 bool err = true;
597 if (block.getNumArguments() != 0) {
598 auto lastArgType =
599 dyn_cast<IntegerType>(block.getArguments().back().getType());
600 err = !lastArgType || lastArgType.getWidth() != 1;
601 }
602 if (err)
603 return emitOpError("block " + std::to_string(i) +
604 " must have an i1 argument as the last block argument "
605 "(stage valid signal).");
606 }
607
608 // Cache external inputs in a set for fast lookup (also includes clock, reset,
609 // and stall).
610 llvm::DenseSet<Value> extLikeInputs;
611 for (auto extInput : getExtInputs())
612 extLikeInputs.insert(extInput);
613
614 extLikeInputs.insert(getClock());
615 extLikeInputs.insert(getReset());
616 if (hasStall())
617 extLikeInputs.insert(getStall());
618
619 // Phase invariant - Check that all values used within a stage are valid
620 // based on the materialization mode. This is a walk, since this condition
621 // should also apply to nested operations.
622 bool materialized = isMaterialized();
623 Block *parentBlock = getOperation()->getBlock();
624 for (auto &stage : stages) {
625 auto walkRes = stage.walk([&](Operation *op) {
626 // Skip pipeline.src operations in non-materialized mode
627 if (isa<SourceOp>(op)) {
628 if (materialized) {
629 op->emitOpError(
630 "Pipeline is in register materialized mode - pipeline.src "
631 "operations are not allowed");
632 return WalkResult::interrupt();
633 }
634
635 // In non-materialized mode, pipeline.src operations are required, and
636 // is what is implicitly allowing cross-stage referenced by not
637 // reaching the below verification code.
638 return WalkResult::advance();
639 }
640
641 for (auto [index, operand] : llvm::enumerate(op->getOpOperands())) {
642 // External inputs (including clock, reset, stall) are allowed
643 // everywhere
644 if (extLikeInputs.contains(operand.get()))
645 continue;
646
647 // Constant-like inputs are allowed everywhere
648 if (auto *definingOp = operand.get().getDefiningOp()) {
649 // Constants are allowed to be used across stages.
650 if (definingOp->hasTrait<OpTrait::ConstantLike>())
651 continue;
652 }
653
654 // Values must always be defined in the same stage.
655 // Materialization mode defines the actual mitigation method.
656 if (!useDefinedInStage(parentBlock, &stage, operand)) {
657 auto err = op->emitOpError("operand ")
658 << index << " is defined in a different stage. ";
659 if (materialized) {
660 err << "Value should have been passed through block arguments";
661 } else {
662 err << "Value should have been passed through a `pipeline.src` "
663 "op";
664 }
665 return WalkResult::interrupt();
666 }
667 }
668
669 return WalkResult::advance();
670 });
671
672 if (walkRes.wasInterrupted())
673 return failure();
674 }
675
676 if (auto stallability = getStallability()) {
677 // Only allow specifying stallability if there is a stall signal.
678 if (!hasStall())
679 return emitOpError("cannot specify stallability without a stall signal.");
680
681 // Ensure that the # of stages is equal to the length of the stallability
682 // array - the exit stage is never stallable.
683 size_t nRegisterStages = stages.size() - 1;
684 if (stallability->size() != nRegisterStages)
685 return emitOpError("stallability array must be the same length as the "
686 "number of stages. Pipeline has ")
687 << nRegisterStages << " stages but array had "
688 << stallability->size() << " elements.";
689 }
690
691 return success();
692}
693
694StageKind ScheduledPipelineOp::getStageKind(size_t stageIndex) {
695 assert(stageIndex < getNumStages() && "invalid stage index");
696
697 if (!hasStall())
698 return StageKind::Continuous;
699
700 // There is a stall signal - also check whether stage-level stallability is
701 // specified.
702 std::optional<ArrayAttr> stallability = getStallability();
703 if (!stallability) {
704 // All stages are stallable.
705 return StageKind::Stallable;
706 }
707
708 if (stageIndex < stallability->size()) {
709 bool stageIsStallable =
710 cast<BoolAttr>((*stallability)[stageIndex]).getValue();
711 if (!stageIsStallable) {
712 // This is a non-stallable stage.
713 return StageKind::NonStallable;
714 }
715 }
716
717 // Walk backwards from this stage to see if any non-stallable stage exists.
718 // If so, this is a runoff stage.
719 // TODO: This should be a pre-computed property.
720 if (stageIndex == 0)
721 return StageKind::Stallable;
722
723 for (size_t i = stageIndex - 1; i > 0; --i) {
724 if (getStageKind(i) == StageKind::NonStallable)
725 return StageKind::Runoff;
726 }
727 return StageKind::Stallable;
728}
729
730//===----------------------------------------------------------------------===//
731// ReturnOp
732//===----------------------------------------------------------------------===//
733
734LogicalResult ReturnOp::verify() {
735 Operation *parent = getOperation()->getParentOp();
736 size_t nInputs = getInputs().size();
737 auto expectedResults = TypeRange(parent->getResultTypes()).drop_back();
738 size_t expectedNResults = expectedResults.size();
739 if (nInputs != expectedNResults)
740 return emitOpError("expected ")
741 << expectedNResults << " return values, got " << nInputs << ".";
742
743 for (auto [inType, reqType] :
744 llvm::zip(getInputs().getTypes(), expectedResults)) {
745 if (inType != reqType)
746 return emitOpError("expected return value of type ")
747 << reqType << ", got " << inType << ".";
748 }
749
750 return success();
751}
752
753//===----------------------------------------------------------------------===//
754// StageOp
755//===----------------------------------------------------------------------===//
756
757// Parses the form:
758// ($name `=`)? $register : type($register)
// That is: an optional, non-empty quoted name followed by `=`, then an operand
// and its `:`-prefixed type. On success `name` holds the parsed name, or an
// empty StringAttr when no name was written, so it is never null.
// NOTE(review): the line with the function name and the leading parser
// parameter (original source line 761) is missing from this listing. Call
// sites use `parseOptNamedTypedAssignment(parser, v, t, name)`, so it is
// presumably `parseOptNamedTypedAssignment(OpAsmParser &parser,`.
759
760static ParseResult
762 OpAsmParser::UnresolvedOperand &v, Type &t,
763 StringAttr &name) {
764 // Parse optional name.
765 std::string nameref;
766 if (succeeded(parser.parseOptionalString(&nameref))) {
// An explicitly empty quoted name ("") is rejected rather than treated as
// "no name".
767 if (nameref.empty())
768 return parser.emitError(parser.getCurrentLocation(),
769 "name cannot be empty");
770
771 if (failed(parser.parseEqual()))
772 return parser.emitError(parser.getCurrentLocation(),
773 "expected '=' after name");
774 name = parser.getBuilder().getStringAttr(nameref);
775 } else {
// No name given: an empty (but non-null) StringAttr marks "unnamed".
776 name = parser.getBuilder().getStringAttr("");
777 }
778
779 // Parse mandatory value and type.
780 if (failed(parser.parseOperand(v)) || failed(parser.parseColonType(t)))
781 return failure();
782
783 return success();
784}
785
786// Parses the form:
787// parseOptNamedTypedAssignment (`gated by` `[` $clockGates `]`)?
788static ParseResult parseSingleStageRegister(
789 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t,
790 llvm::SmallVector<OpAsmParser::UnresolvedOperand> &clockGates,
791 StringAttr &name) {
792 if (failed(parseOptNamedTypedAssignment(parser, v, t, name)))
793 return failure();
794
795 // Parse optional gated-by clause.
796 if (failed(parser.parseOptionalKeyword("gated")))
797 return success();
798
799 if (failed(parser.parseKeyword("by")) ||
800 failed(
801 parser.parseOperandList(clockGates, OpAsmParser::Delimiter::Square)))
802 return failure();
803
804 return success();
805}
806
807// Parses the form:
808// regs( ($name `=`)? $register : type($register) (`gated by` `[` $clockGates
809// `]`)?, ...)
// If the `regs` keyword is absent, nothing else is parsed:
// `clockGatesPerRegister` is set to an empty array and `registerNames` is left
// untouched. Otherwise each register's operand and type are appended to
// `registers`/`registerTypes`. Its clock gates are appended, flattened, to
// `clockGates`, and the per-register gate count is recorded in
// `clockGatesPerRegister`.
// NOTE(review): the line with the function name (original source line 810)
// is missing from this listing.
811 OpAsmParser &parser,
812 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &registers,
813 llvm::SmallVector<mlir::Type, 1> &registerTypes,
814 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &clockGates,
815 ArrayAttr &clockGatesPerRegister, ArrayAttr &registerNames) {
816
817 if (failed(parser.parseOptionalKeyword("regs"))) {
818 clockGatesPerRegister = parser.getBuilder().getI64ArrayAttr({});
819 return success(); // no registers to parse.
820 }
821
822 llvm::SmallVector<int64_t> clockGatesPerRegisterList;
823 llvm::SmallVector<Attribute> registerNamesList;
824 bool withNames = false;
825 if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&]() {
826 OpAsmParser::UnresolvedOperand v;
827 Type t;
828 llvm::SmallVector<OpAsmParser::UnresolvedOperand> cgs;
829 StringAttr name;
830 if (parseSingleStageRegister(parser, v, t, cgs, name))
831 return failure();
832 registers.push_back(v);
833 registerTypes.push_back(t);
834 registerNamesList.push_back(name);
// NOTE(review): parseOptNamedTypedAssignment sets `name` to an empty but
// non-null StringAttr when no name is written. This is therefore true for
// every register, and `registerNames` gets set whenever `regs(...)` is present,
// even if no register was named. Confirm whether that is intended.
835 withNames |= static_cast<bool>(name);
836 llvm::append_range(clockGates, cgs);
837 clockGatesPerRegisterList.push_back(cgs.size());
838 return success();
839 })))
840 return failure();
841
842 clockGatesPerRegister =
843 parser.getBuilder().getI64ArrayAttr(clockGatesPerRegisterList);
844 if (withNames)
845 registerNames = parser.getBuilder().getArrayAttr(registerNamesList);
846
847 return success();
848}
849
850void printStageRegisters(OpAsmPrinter &p, Operation *op, ValueRange registers,
851 TypeRange registerTypes, ValueRange clockGates,
852 ArrayAttr clockGatesPerRegister, ArrayAttr names) {
853 if (registers.empty())
854 return;
855
856 p << "regs(";
857 size_t clockGateStartIdx = 0;
858 llvm::interleaveComma(
859 llvm::enumerate(
860 llvm::zip(registers, registerTypes, clockGatesPerRegister)),
861 p, [&](auto it) {
862 size_t idx = it.index();
863 auto &[reg, type, nClockGatesAttr] = it.value();
864 if (names) {
865 if (auto nameAttr = dyn_cast<StringAttr>(names[idx]);
866 nameAttr && !nameAttr.strref().empty())
867 p << nameAttr << " = ";
868 }
869
870 p << reg << " : " << type;
871 int64_t nClockGates = cast<IntegerAttr>(nClockGatesAttr).getInt();
872 if (nClockGates == 0)
873 return;
874 p << " gated by [";
875 llvm::interleaveComma(clockGates.slice(clockGateStartIdx, nClockGates),
876 p);
877 p << "]";
878 clockGateStartIdx += nClockGates;
879 });
880 p << ")";
881}
882
883void printPassthroughs(OpAsmPrinter &p, Operation *op, ValueRange passthroughs,
884 TypeRange passthroughTypes, ArrayAttr names) {
885
886 if (passthroughs.empty())
887 return;
888
889 p << "pass(";
890 llvm::interleaveComma(
891 llvm::enumerate(llvm::zip(passthroughs, passthroughTypes)), p,
892 [&](auto it) {
893 size_t idx = it.index();
894 auto &[reg, type] = it.value();
895 if (names) {
896 if (auto nameAttr = dyn_cast<StringAttr>(names[idx]);
897 nameAttr && !nameAttr.strref().empty())
898 p << nameAttr << " = ";
899 }
900 p << reg << " : " << type;
901 });
902 p << ")";
903}
904
905// Parses the form:
906// (`pass` `(` ($name `=`)? $register : type($register), ... `)` )?
// If the `pass` keyword is absent, nothing is parsed and `passthroughNames` is
// left untouched. Otherwise each entry's operand and type are appended to
// `passthroughs`/`passthroughTypes`.
// NOTE(review): the line with the function name (original source line 907)
// is missing from this listing.
908 OpAsmParser &parser,
909 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &passthroughs,
910 llvm::SmallVector<mlir::Type, 1> &passthroughTypes,
911 ArrayAttr &passthroughNames) {
912 if (failed(parser.parseOptionalKeyword("pass")))
913 return success(); // no passthroughs to parse.
914
915 llvm::SmallVector<Attribute> passthroughsNameList;
916 bool withNames = false;
917 if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&]() {
918 OpAsmParser::UnresolvedOperand v;
919 Type t;
920 StringAttr name;
921 if (parseOptNamedTypedAssignment(parser, v, t, name))
922 return failure();
923 passthroughs.push_back(v);
924 passthroughTypes.push_back(t);
925 passthroughsNameList.push_back(name);
// NOTE(review): parseOptNamedTypedAssignment yields an empty but non-null
// StringAttr for unnamed entries. This is therefore true for every entry, and
// `passthroughNames` is set whenever `pass(...)` is present. Confirm whether
// that is intended.
926 withNames |= static_cast<bool>(name);
927 return success();
928 })))
929 return failure();
930
931 if (withNames)
932 passthroughNames = parser.getBuilder().getArrayAttr(passthroughsNameList);
933
934 return success();
935}
936
937void StageOp::build(OpBuilder &odsBuilder, OperationState &odsState,
938 Block *dest, ValueRange registers,
939 ValueRange passthroughs) {
940 odsState.addSuccessors(dest);
941 odsState.addOperands(registers);
942 odsState.addOperands(passthroughs);
943 odsState.addAttribute("operandSegmentSizes",
944 odsBuilder.getDenseI32ArrayAttr(
945 {static_cast<int32_t>(registers.size()),
946 static_cast<int32_t>(passthroughs.size()),
947 /*clock gates*/ static_cast<int32_t>(0)}));
948 llvm::SmallVector<int64_t> clockGatesPerRegister(registers.size(), 0);
949 odsState.addAttribute("clockGatesPerRegister",
950 odsBuilder.getI64ArrayAttr(clockGatesPerRegister));
951}
952
953void StageOp::build(OpBuilder &odsBuilder, OperationState &odsState,
954 Block *dest, ValueRange registers, ValueRange passthroughs,
955 llvm::ArrayRef<llvm::SmallVector<Value>> clockGateList,
956 mlir::ArrayAttr registerNames,
957 mlir::ArrayAttr passthroughNames) {
958 build(odsBuilder, odsState, dest, registers, passthroughs);
959
960 llvm::SmallVector<Value> clockGates;
961 llvm::SmallVector<int64_t> clockGatesPerRegister(registers.size(), 0);
962 for (auto gates : clockGateList) {
963 llvm::append_range(clockGates, gates);
964 clockGatesPerRegister.push_back(gates.size());
965 }
966 odsState.attributes.set("clockGatesPerRegister",
967 odsBuilder.getI64ArrayAttr(clockGatesPerRegister));
968 odsState.addOperands(clockGates);
969
970 if (registerNames)
971 odsState.addAttribute("registerNames", registerNames);
972
973 if (passthroughNames)
974 odsState.addAttribute("passthroughNames", passthroughNames);
975}
976
977ValueRange StageOp::getClockGatesForReg(unsigned regIdx) {
978 assert(regIdx < getRegisters().size() && "register index out of bounds.");
979
980 // TODO: This could be optimized quite a bit if we didn't store clock
981 // gates per register as an array of sizes... look into using properties
982 // and maybe attaching a more complex datastructure to reduce compute
983 // here.
984
985 unsigned clockGateStartIdx = 0;
986 for (auto [index, nClockGatesAttr] :
987 llvm::enumerate(getClockGatesPerRegister().getAsRange<IntegerAttr>())) {
988 int64_t nClockGates = nClockGatesAttr.getInt();
989 if (index == regIdx) {
990 // This is the register we are looking for.
991 return getClockGates().slice(clockGateStartIdx, nClockGates);
992 }
993 // Increment the start index by the number of clock gates for this
994 // register.
995 clockGateStartIdx += nClockGates;
996 }
997
998 llvm_unreachable("register index out of bounds.");
999}
1000
1001LogicalResult StageOp::verify() {
1002 // Verify that the target block has the correct arguments as this stage
1003 // op.
1004 llvm::SmallVector<Type> expectedTargetArgTypes;
1005 llvm::append_range(expectedTargetArgTypes, getRegisters().getTypes());
1006 llvm::append_range(expectedTargetArgTypes, getPassthroughs().getTypes());
1007 Block *targetStage = getNextStage();
1008 // Expected types is everything but the stage valid signal.
1009 TypeRange targetStageArgTypes =
1010 TypeRange(targetStage->getArgumentTypes()).drop_back();
1011
1012 if (targetStageArgTypes.size() != expectedTargetArgTypes.size())
1013 return emitOpError("expected ") << expectedTargetArgTypes.size()
1014 << " arguments in the target stage, got "
1015 << targetStageArgTypes.size() << ".";
1016
1017 for (auto [index, it] : llvm::enumerate(
1018 llvm::zip(expectedTargetArgTypes, targetStageArgTypes))) {
1019 auto [arg, barg] = it;
1020 if (arg != barg)
1021 return emitOpError("expected target stage argument ")
1022 << index << " to have type " << arg << ", got " << barg << ".";
1023 }
1024
1025 // Verify that the clock gate index list is equally sized to the # of
1026 // registers.
1027 if (getClockGatesPerRegister().size() != getRegisters().size())
1028 return emitOpError("expected clockGatesPerRegister to be equally sized to "
1029 "the number of registers.");
1030
1031 // Verify that, if provided, the list of register names is equally sized
1032 // to the number of registers.
1033 if (auto regNames = getRegisterNames()) {
1034 if (regNames->size() != getRegisters().size())
1035 return emitOpError("expected registerNames to be equally sized to "
1036 "the number of registers.");
1037 }
1038
1039 // Verify that, if provided, the list of passthrough names is equally sized
1040 // to the number of passthroughs.
1041 if (auto passthroughNames = getPassthroughNames()) {
1042 if (passthroughNames->size() != getPassthroughs().size())
1043 return emitOpError("expected passthroughNames to be equally sized to "
1044 "the number of passthroughs.");
1045 }
1046
1047 return success();
1048}
1049
1050//===----------------------------------------------------------------------===//
1051// LatencyOp
1052//===----------------------------------------------------------------------===//
1053
1054LogicalResult LatencyOp::verify() {
1055 ScheduledPipelineOp scheduledPipelineParent =
1056 dyn_cast<ScheduledPipelineOp>(getOperation()->getParentOp());
1057
1058 if (!scheduledPipelineParent) {
1059 // Nothing to verify, got to assume that anything goes in an unscheduled
1060 // pipeline.
1061 return success();
1062 }
1063
1064 // Verify that there's at least one result type. Latency ops don't make
1065 // sense if they're not delaying anything, and we're not yet prepared to
1066 // support side-effectful bodies.
1067 if (getNumResults() == 0)
1068 return emitOpError("expected at least one result type.");
1069
1070 // Verify that the resulting values aren't referenced before they are
1071 // accessible.
1072 size_t latency = getLatency();
1073 Block *definingStage = getOperation()->getBlock();
1074
1075 llvm::DenseMap<Block *, unsigned> stageMap =
1076 scheduledPipelineParent.getStageMap();
1077
1078 auto stageDistance = [&](Block *from, Block *to) {
1079 assert(stageMap.count(from) && "stage 'from' not contained in pipeline");
1080 assert(stageMap.count(to) && "stage 'to' not contained in pipeline");
1081 int64_t fromStage = stageMap[from];
1082 int64_t toStage = stageMap[to];
1083 return toStage - fromStage;
1084 };
1085
1086 for (auto [i, res] : llvm::enumerate(getResults())) {
1087 for (auto &use : res.getUses()) {
1088 auto *user = use.getOwner();
1089
1090 // The user may reside within a block which is not a stage (e.g.
1091 // inside a pipeline.latency op). Determine the stage which this use
1092 // resides within.
1093 Block *userStage =
1094 getParentStageInPipeline(scheduledPipelineParent, user);
1095 unsigned useDistance = stageDistance(definingStage, userStage);
1096
1097 // Is this a stage op and is the value passed through? if so, this is
1098 // a legal use.
1099 StageOp stageOp = dyn_cast<StageOp>(user);
1100 if (userStage == definingStage && stageOp) {
1101 if (llvm::is_contained(stageOp.getPassthroughs(), res))
1102 continue;
1103 }
1104
1105 // The use is not a passthrough. Check that the distance between
1106 // the defining stage and the user stage is at least the latency of
1107 // the result.
1108 if (useDistance < latency) {
1109 auto diag = emitOpError("result ")
1110 << i << " is used before it is available.";
1111 diag.attachNote(user->getLoc())
1112 << "use was operand " << use.getOperandNumber()
1113 << ". The result is available " << latency - useDistance
1114 << " stages later than this use.";
1115 return diag;
1116 }
1117 }
1118 }
1119 return success();
1120}
1121
1122//===----------------------------------------------------------------------===//
1123// LatencyReturnOp
1124//===----------------------------------------------------------------------===//
1125
1126LogicalResult LatencyReturnOp::verify() {
1127 LatencyOp parent = cast<LatencyOp>(getOperation()->getParentOp());
1128 size_t nInputs = getInputs().size();
1129 size_t nResults = parent->getNumResults();
1130 if (nInputs != nResults)
1131 return emitOpError("expected ")
1132 << nResults << " return values, got " << nInputs << ".";
1133
1134 for (auto [inType, reqType] :
1135 llvm::zip(getInputs().getTypes(), parent->getResultTypes())) {
1136 if (inType != reqType)
1137 return emitOpError("expected return value of type ")
1138 << reqType << ", got " << inType << ".";
1139 }
1140
1141 return success();
1142}
1143
// Pull in the ODS-generated op class definitions for the Pipeline dialect.
#define GET_OP_CLASSES
#include "circt/Dialect/Pipeline/Pipeline.cpp.inc"

/// Dialect initialization: registers the Pipeline dialect's operations.
void PipelineDialect::initialize() {
  // The GET_OP_LIST section of the generated file supplies the template
  // argument list for addOperations (presumably the comma-separated op class
  // names, per the usual ODS convention).
  addOperations<
#define GET_OP_LIST
#include "circt/Dialect/Pipeline/Pipeline.cpp.inc"
      >();
}
assert(baseType &&"element must be base type")
static InstancePath empty
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:217
static void buildPipelineLikeOp(OpBuilder &odsBuilder, OperationState &odsState, TypeRange dataOutputs, ValueRange inputs, ArrayAttr inputNames, ArrayAttr outputNames, Value clock, Value go, Value reset, Value stall, StringAttr name, ArrayAttr stallability)
static ParseResult parseSingleStageRegister(OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t, llvm::SmallVector< OpAsmParser::UnresolvedOperand > &clockGates, StringAttr &name)
static FailureOr< llvm::SmallVector< Block * > > getOrderedStagesFailable(ScheduledPipelineOp op)
static ParseResult parseKeywordAndOperand(OpAsmParser &p, StringRef keyword, OpAsmParser::UnresolvedOperand &op)
static void printOutputList(OpAsmPrinter &p, TypeRange types, ArrayAttr names)
ParseResult parsePassthroughs(OpAsmParser &parser, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &passthroughs, llvm::SmallVector< mlir::Type, 1 > &passthroughTypes, ArrayAttr &passthroughNames)
static void printPipelineOp(OpAsmPrinter &p, TPipelineOp op)
void printPassthroughs(OpAsmPrinter &p, Operation *op, ValueRange passthroughs, TypeRange passthroughTypes, ArrayAttr names)
static ParseResult parseOptNamedTypedAssignment(OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t, StringAttr &name)
static ParseResult parsePipelineOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
static bool useDefinedInStage(Block *stopAt, Block *stage, OpOperand &use)
void printStageRegisters(OpAsmPrinter &p, Operation *op, ValueRange registers, TypeRange registerTypes, ValueRange clockGates, ArrayAttr clockGatesPerRegister, ArrayAttr names)
static void getPipelineAsmBlockArgumentNames(TPipelineOp op, mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn)
static ParseResult parseOutputList(OpAsmParser &parser, llvm::SmallVector< Type > &inputTypes, mlir::ArrayAttr &outputNames)
static bool isNestedBlock(Block *stopAt, Block *parent, Block *current)
static void printKeywordOperand(OpAsmPrinter &p, StringRef keyword, Value value)
ParseResult parseStageRegisters(OpAsmParser &parser, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &registers, llvm::SmallVector< mlir::Type, 1 > &registerTypes, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &clockGates, ArrayAttr &clockGatesPerRegister, ArrayAttr &registerNames)
static void getPipelineAsmResultNames(TPipelineOp op, OpAsmSetValueNameFn setNameFn)
ParseResult parseInitializerList(mlir::OpAsmParser &parser, llvm::SmallVector< mlir::OpAsmParser::Argument > &inputArguments, llvm::SmallVector< mlir::OpAsmParser::UnresolvedOperand > &inputOperands, llvm::SmallVector< Type > &inputTypes, ArrayAttr &inputNames)
Parses an initializer.
void printInitializerList(OpAsmPrinter &p, ValueRange ins, ArrayRef< BlockArgument > args)
llvm::SmallVector< Value > getValuesDefinedOutsideRegion(Region &region)
Block * getParentStageInPipeline(ScheduledPipelineOp pipeline, Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183