CIRCT  19.0.0git
PipelineOps.cpp
Go to the documentation of this file.
1 //===- PipelineOps.cpp - Pipeline MLIR Operations ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Pipeline ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Interfaces/FunctionImplementation.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 
22 using namespace mlir;
23 using namespace circt;
24 using namespace circt::pipeline;
25 using namespace circt::parsing_util;
26 
27 #include "circt/Dialect/Pipeline/PipelineDialect.cpp.inc"
28 
29 #define DEBUG_TYPE "pipeline-ops"
30 
31 llvm::SmallVector<Value>
33  llvm::SetVector<Value> values;
34  region.walk([&](Operation *op) {
35  for (auto operand : op->getOperands()) {
36  if (region.isAncestor(operand.getParentRegion()))
37  continue;
38  values.insert(operand);
39  }
40  });
41  return values.takeVector();
42 }
43 
44 Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
45  Block *block) {
46  // Optional debug check - ensure that 'block' eventually leads to the
47  // pipeline.
48  LLVM_DEBUG({
49  Operation *directParent = block->getParentOp();
50  if (directParent != pipeline) {
51  auto indirectParent =
52  directParent->getParentOfType<ScheduledPipelineOp>();
53  assert(indirectParent == pipeline && "block is not in the pipeline");
54  }
55  });
56 
57  while (block && block->getParent() != &pipeline.getRegion()) {
58  // Go one level up.
59  block = block->getParent()->getParentOp()->getBlock();
60  }
61 
62  // This is a block within the pipeline region, so it must be a stage.
63  return block;
64 }
65 
66 Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
67  Operation *op) {
68  return getParentStageInPipeline(pipeline, op->getBlock());
69 }
70 
71 Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
72  Value v) {
73  if (v.isa<BlockArgument>())
74  return getParentStageInPipeline(pipeline,
75  v.cast<BlockArgument>().getOwner());
76  return getParentStageInPipeline(pipeline, v.getDefiningOp());
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // Fancy pipeline-like op printer/parser functions.
81 //===----------------------------------------------------------------------===//
82 
83 // Parses a list of operands on the format:
84 // (name : type, ...)
85 static ParseResult parseOutputList(OpAsmParser &parser,
86  llvm::SmallVector<Type> &inputTypes,
87  mlir::ArrayAttr &outputNames) {
88 
89  llvm::SmallVector<Attribute> names;
90  if (parser.parseCommaSeparatedList(
91  OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
92  StringRef name;
93  Type type;
94  if (parser.parseKeyword(&name) || parser.parseColonType(type))
95  return failure();
96 
97  inputTypes.push_back(type);
98  names.push_back(StringAttr::get(parser.getContext(), name));
99  return success();
100  }))
101  return failure();
102 
103  outputNames = ArrayAttr::get(parser.getContext(), names);
104  return success();
105 }
106 
107 static void printOutputList(OpAsmPrinter &p, TypeRange types, ArrayAttr names) {
108  p << "(";
109  llvm::interleaveComma(llvm::zip(types, names), p, [&](auto it) {
110  auto [type, name] = it;
111  p.printKeywordOrString(name.template cast<StringAttr>().str());
112  p << " : " << type;
113  });
114  p << ")";
115 }
116 
117 static ParseResult parseKeywordAndOperand(OpAsmParser &p, StringRef keyword,
118  OpAsmParser::UnresolvedOperand &op) {
119  if (p.parseKeyword(keyword) || p.parseLParen() || p.parseOperand(op) ||
120  p.parseRParen())
121  return failure();
122  return success();
123 }
124 
125 // Assembly format is roughly:
126 // ( $name )? initializer-list stall (%stall = $stall)?
127 // clock (%clock) reset (%reset) go(%go) entryEnable(%en) {
128 // --- elided inner block ---
129 static ParseResult parsePipelineOp(mlir::OpAsmParser &parser,
130  mlir::OperationState &result) {
131  std::string name;
132  if (succeeded(parser.parseOptionalString(&name)))
133  result.addAttribute("name", parser.getBuilder().getStringAttr(name));
134 
135  llvm::SmallVector<OpAsmParser::UnresolvedOperand> inputOperands;
136  llvm::SmallVector<OpAsmParser::Argument> inputArguments;
137  llvm::SmallVector<Type> inputTypes;
138  ArrayAttr inputNames;
139  if (parseInitializerList(parser, inputArguments, inputOperands, inputTypes,
140  inputNames))
141  return failure();
142  result.addAttribute("inputNames", inputNames);
143 
144  Type i1 = parser.getBuilder().getI1Type();
145  // Parse optional 'stall (%stallArg)'
146  OpAsmParser::Argument stallArg;
147  OpAsmParser::UnresolvedOperand stallOperand;
148  bool withStall = false;
149  if (succeeded(parser.parseOptionalKeyword("stall"))) {
150  if (parser.parseLParen() || parser.parseOperand(stallOperand) ||
151  parser.parseRParen())
152  return failure();
153  withStall = true;
154  }
155 
156  // Parse clock, reset, and go.
157  OpAsmParser::UnresolvedOperand clockOperand, resetOperand, goOperand;
158  if (parseKeywordAndOperand(parser, "clock", clockOperand) ||
159  parseKeywordAndOperand(parser, "reset", resetOperand) ||
160  parseKeywordAndOperand(parser, "go", goOperand))
161  return failure();
162 
163  // Parse entry stage enable block argument.
164  OpAsmParser::Argument entryEnable;
165  entryEnable.type = i1;
166  if (parser.parseKeyword("entryEn") || parser.parseLParen() ||
167  parser.parseArgument(entryEnable) || parser.parseRParen())
168  return failure();
169 
170  // Optional attribute dict
171  if (parser.parseOptionalAttrDict(result.attributes))
172  return failure();
173 
174  // Parse the output assignment list
175  if (parser.parseArrow())
176  return failure();
177 
178  llvm::SmallVector<Type> outputTypes;
179  ArrayAttr outputNames;
180  if (parseOutputList(parser, outputTypes, outputNames))
181  return failure();
182  result.addTypes(outputTypes);
183  result.addAttribute("outputNames", outputNames);
184 
185  // And the implicit 'done' output.
186  result.addTypes({i1});
187 
188  // All operands have been parsed - resolve.
189  if (parser.resolveOperands(inputOperands, inputTypes, parser.getNameLoc(),
190  result.operands))
191  return failure();
192 
193  if (withStall) {
194  if (parser.resolveOperand(stallOperand, i1, result.operands))
195  return failure();
196  }
197 
198  Type clkType = seq::ClockType::get(parser.getContext());
199  if (parser.resolveOperand(clockOperand, clkType, result.operands) ||
200  parser.resolveOperand(resetOperand, i1, result.operands) ||
201  parser.resolveOperand(goOperand, i1, result.operands))
202  return failure();
203 
204  // Assemble the body region block arguments - this is where the magic happens
205  // and why we're doing a custom printer/parser - if the user had to magically
206  // know the order of these block arguments, we're asking for issues.
207  SmallVector<OpAsmParser::Argument> regionArgs;
208 
209  // First we add the input arguments.
210  llvm::append_range(regionArgs, inputArguments);
211  // Then the internal entry stage enable block argument.
212  regionArgs.push_back(entryEnable);
213 
214  // Parse the body region.
215  Region *body = result.addRegion();
216  if (parser.parseRegion(*body, regionArgs))
217  return failure();
218 
219  result.addAttribute(
220  "operandSegmentSizes",
221  parser.getBuilder().getDenseI32ArrayAttr(
222  {static_cast<int32_t>(inputTypes.size()),
223  static_cast<int32_t>(withStall ? 1 : 0),
224  /*clock*/ static_cast<int32_t>(1),
225  /*reset*/ static_cast<int32_t>(1), /*go*/ static_cast<int32_t>(1)}));
226 
227  return success();
228 }
229 
230 static void printKeywordOperand(OpAsmPrinter &p, StringRef keyword,
231  Value value) {
232  p << keyword << "(";
233  p.printOperand(value);
234  p << ")";
235 }
236 
237 template <typename TPipelineOp>
238 static void printPipelineOp(OpAsmPrinter &p, TPipelineOp op) {
239  if (auto name = op.getNameAttr()) {
240  p << " \"" << name.getValue() << "\"";
241  }
242 
243  // Print the input list.
244  printInitializerList(p, op.getInputs(), op.getInnerInputs());
245  p << " ";
246 
247  // Print the optional stall.
248  if (op.hasStall()) {
249  p << "stall(";
250  p.printOperand(op.getStall());
251  p << ") ";
252  }
253 
254  // Print the clock, reset, and go.
255  printKeywordOperand(p, "clock", op.getClock());
256  p << " ";
257  printKeywordOperand(p, "reset", op.getReset());
258  p << " ";
259  printKeywordOperand(p, "go", op.getGo());
260  p << " ";
261 
262  // Print the entry enable block argument.
263  p << "entryEn(";
264  p.printRegionArgument(
265  cast<BlockArgument>(op.getStageEnableSignal(static_cast<size_t>(0))), {},
266  /*omitType*/ true);
267  p << ") ";
268 
269  // Print the optional attribute dict.
270  p.printOptionalAttrDict(op->getAttrs(),
271  /*elidedAttrs=*/{"name", "operandSegmentSizes",
272  "outputNames", "inputNames"});
273  p << " -> ";
274 
275  // Print the output list.
276  printOutputList(p, op.getDataOutputs().getTypes(), op.getOutputNames());
277 
278  p << " ";
279 
280  // Print the inner region, eliding the entry block arguments - we've already
281  // defined these in our initializer lists.
282  p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false,
283  /*printBlockTerminators=*/true);
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // UnscheduledPipelineOp
288 //===----------------------------------------------------------------------===//
289 
290 static void buildPipelineLikeOp(OpBuilder &odsBuilder, OperationState &odsState,
291  TypeRange dataOutputs, ValueRange inputs,
292  ArrayAttr inputNames, ArrayAttr outputNames,
293  Value clock, Value reset, Value go, Value stall,
294  StringAttr name, ArrayAttr stallability) {
295  odsState.addOperands(inputs);
296  if (stall)
297  odsState.addOperands(stall);
298  odsState.addOperands(clock);
299  odsState.addOperands(reset);
300  odsState.addOperands(go);
301  if (name)
302  odsState.addAttribute("name", name);
303 
304  odsState.addAttribute(
305  "operandSegmentSizes",
306  odsBuilder.getDenseI32ArrayAttr(
307  {static_cast<int32_t>(inputs.size()),
308  static_cast<int32_t>(stall ? 1 : 0), static_cast<int32_t>(1),
309  static_cast<int32_t>(1), static_cast<int32_t>(1)}));
310 
311  odsState.addAttribute("inputNames", inputNames);
312  odsState.addAttribute("outputNames", outputNames);
313 
314  auto *region = odsState.addRegion();
315  odsState.addTypes(dataOutputs);
316 
317  // Add the implicit done output signal.
318  Type i1 = odsBuilder.getIntegerType(1);
319  odsState.addTypes({i1});
320 
321  // Add the entry stage - arguments order:
322  // 1. Inputs
323  // 2. Stall (opt)
324  // 3. Clock
325  // 4. Reset
326  // 5. Go
327  auto &entryBlock = region->emplaceBlock();
328  llvm::SmallVector<Location> entryArgLocs(inputs.size(), odsState.location);
329  entryBlock.addArguments(
330  inputs.getTypes(),
331  llvm::SmallVector<Location>(inputs.size(), odsState.location));
332  if (stall)
333  entryBlock.addArgument(i1, odsState.location);
334  entryBlock.addArgument(i1, odsState.location);
335  entryBlock.addArgument(i1, odsState.location);
336 
337  // entry stage valid signal.
338  entryBlock.addArgument(i1, odsState.location);
339 
340  if (stallability)
341  odsState.addAttribute("stallability", stallability);
342 }
343 
344 template <typename TPipelineOp>
345 static void getPipelineAsmResultNames(TPipelineOp op,
346  OpAsmSetValueNameFn setNameFn) {
347  for (auto [res, name] :
348  llvm::zip(op.getDataOutputs(),
349  op.getOutputNames().template getAsValueRange<StringAttr>()))
350  setNameFn(res, name);
351  setNameFn(op.getDone(), "done");
352 }
353 
354 template <typename TPipelineOp>
355 static void
356 getPipelineAsmBlockArgumentNames(TPipelineOp op, mlir::Region &region,
357  mlir::OpAsmSetValueNameFn setNameFn) {
358  for (auto [i, block] : llvm::enumerate(op.getRegion())) {
359  if (Block *predBlock = block.getSinglePredecessor()) {
360  // Predecessor stageOp might have register and passthrough names
361  // specified, which we can use to name the block arguments.
362  auto predStageOp = cast<StageOp>(predBlock->getTerminator());
363  size_t nRegs = predStageOp.getRegisters().size();
364  auto nPassthrough = predStageOp.getPassthroughs().size();
365 
366  auto regNames = predStageOp.getRegisterNames();
367  auto passthroughNames = predStageOp.getPassthroughNames();
368 
369  // Register naming...
370  for (size_t regI = 0; regI < nRegs; ++regI) {
371  auto arg = block.getArguments()[regI];
372 
373  if (regNames) {
374  auto nameAttr = (*regNames)[regI].dyn_cast<StringAttr>();
375  if (nameAttr && !nameAttr.strref().empty()) {
376  setNameFn(arg, nameAttr);
377  continue;
378  }
379  }
380  setNameFn(arg, llvm::formatv("s{0}_reg{1}", i, regI).str());
381  }
382 
383  // Passthrough naming...
384  for (size_t passthroughI = 0; passthroughI < nPassthrough;
385  ++passthroughI) {
386  auto arg = block.getArguments()[nRegs + passthroughI];
387 
388  if (passthroughNames) {
389  auto nameAttr =
390  (*passthroughNames)[passthroughI].dyn_cast<StringAttr>();
391  if (nameAttr && !nameAttr.strref().empty()) {
392  setNameFn(arg, nameAttr);
393  continue;
394  }
395  }
396  setNameFn(arg, llvm::formatv("s{0}_pass{1}", i, passthroughI).str());
397  }
398  } else {
399  // This is the entry stage - name the arguments according to the input
400  // names.
401  for (auto [inputArg, inputName] :
402  llvm::zip(op.getInnerInputs(),
403  op.getInputNames().template getAsValueRange<StringAttr>()))
404  setNameFn(inputArg, inputName);
405  }
406 
407  // Last argument in any stage is the stage enable signal.
408  setNameFn(block.getArguments().back(),
409  llvm::formatv("s{0}_enable", i).str());
410  }
411 }
412 
413 void UnscheduledPipelineOp::print(OpAsmPrinter &p) {
414  printPipelineOp(p, *this);
415 }
416 
417 ParseResult UnscheduledPipelineOp::parse(OpAsmParser &parser,
418  OperationState &result) {
419  return parsePipelineOp(parser, result);
420 }
421 
423  getPipelineAsmResultNames(*this, setNameFn);
424 }
425 
426 void UnscheduledPipelineOp::getAsmBlockArgumentNames(
427  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
428  getPipelineAsmBlockArgumentNames(*this, region, setNameFn);
429 }
430 
431 void UnscheduledPipelineOp::build(OpBuilder &odsBuilder,
432  OperationState &odsState,
433  TypeRange dataOutputs, ValueRange inputs,
434  ArrayAttr inputNames, ArrayAttr outputNames,
435  Value clock, Value reset, Value go,
436  Value stall, StringAttr name,
437  ArrayAttr stallability) {
438  buildPipelineLikeOp(odsBuilder, odsState, dataOutputs, inputs, inputNames,
439  outputNames, clock, reset, go, stall, name, stallability);
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // ScheduledPipelineOp
444 //===----------------------------------------------------------------------===//
445 
446 void ScheduledPipelineOp::print(OpAsmPrinter &p) { printPipelineOp(p, *this); }
447 
448 ParseResult ScheduledPipelineOp::parse(OpAsmParser &parser,
449  OperationState &result) {
450  return parsePipelineOp(parser, result);
451 }
452 
453 void ScheduledPipelineOp::build(OpBuilder &odsBuilder, OperationState &odsState,
454  TypeRange dataOutputs, ValueRange inputs,
455  ArrayAttr inputNames, ArrayAttr outputNames,
456  Value clock, Value reset, Value go, Value stall,
457  StringAttr name, ArrayAttr stallability) {
458  buildPipelineLikeOp(odsBuilder, odsState, dataOutputs, inputs, inputNames,
459  outputNames, clock, reset, go, stall, name, stallability);
460 }
461 
462 Block *ScheduledPipelineOp::addStage() {
463  OpBuilder builder(getContext());
464  Block *stage = builder.createBlock(&getRegion());
465 
466  // Add the stage valid signal.
467  stage->addArgument(builder.getIntegerType(1), getLoc());
468  return stage;
469 }
470 
471 void ScheduledPipelineOp::getAsmBlockArgumentNames(
472  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
473  getPipelineAsmBlockArgumentNames(*this, region, setNameFn);
474 }
475 
477  getPipelineAsmResultNames(*this, setNameFn);
478 }
479 
480 // Implementation of getOrderedStages which also produces an error if
481 // there are any cfg cycles in the pipeline.
483 getOrderedStagesFailable(ScheduledPipelineOp op) {
484  llvm::DenseSet<Block *> visited;
485  Block *currentStage = op.getEntryStage();
486  llvm::SmallVector<Block *> orderedStages;
487  do {
488  if (!visited.insert(currentStage).second)
489  return op.emitOpError("pipeline contains a cycle.");
490 
491  orderedStages.push_back(currentStage);
492  if (auto stageOp = dyn_cast<StageOp>(currentStage->getTerminator()))
493  currentStage = stageOp.getNextStage();
494  else
495  currentStage = nullptr;
496  } while (currentStage);
497 
498  return {orderedStages};
499 }
500 
501 llvm::SmallVector<Block *> ScheduledPipelineOp::getOrderedStages() {
502  // Should always be safe, seeing as the pipeline itself has already been
503  // verified.
504  return *getOrderedStagesFailable(*this);
505 }
506 
507 llvm::DenseMap<Block *, unsigned> ScheduledPipelineOp::getStageMap() {
508  llvm::DenseMap<Block *, unsigned> stageMap;
509  auto orderedStages = getOrderedStages();
510  for (auto [index, stage] : llvm::enumerate(orderedStages))
511  stageMap[stage] = index;
512 
513  return stageMap;
514 }
515 
516 Block *ScheduledPipelineOp::getLastStage() { return getOrderedStages().back(); }
517 
518 bool ScheduledPipelineOp::isMaterialized() {
519  // We determine materialization as if any pipeline stage has an explicit
520  // input (apart from the stage valid signal).
521  return llvm::any_of(getStages(), [this](Block &block) {
522  // The entry stage doesn't count since it'll always have arguments.
523  if (&block == getEntryStage())
524  return false;
525  return block.getNumArguments() > 1;
526  });
527 }
528 
529 LogicalResult ScheduledPipelineOp::verify() {
530  // Verify that all block are terminated properly.
531  auto &stages = getStages();
532  for (Block &stage : stages) {
533  if (stage.empty() || !isa<ReturnOp, StageOp>(stage.back()))
534  return emitOpError("all blocks must be terminated with a "
535  "`pipeline.stage` or `pipeline.return` op.");
536  }
537 
538  if (failed(getOrderedStagesFailable(*this)))
539  return failure();
540 
541  // Verify that every stage has a stage valid block argument.
542  for (auto [i, block] : llvm::enumerate(stages)) {
543  bool err = true;
544  if (block.getNumArguments() != 0) {
545  auto lastArgType =
546  block.getArguments().back().getType().dyn_cast<IntegerType>();
547  err = !lastArgType || lastArgType.getWidth() != 1;
548  }
549  if (err)
550  return emitOpError("block " + std::to_string(i) +
551  " must have an i1 argument as the last block argument "
552  "(stage valid signal).");
553  }
554 
555  // Cache external inputs in a set for fast lookup (also includes clock, reset,
556  // and stall).
557  llvm::DenseSet<Value> extLikeInputs;
558  for (auto extInput : getExtInputs())
559  extLikeInputs.insert(extInput);
560 
561  extLikeInputs.insert(getClock());
562  extLikeInputs.insert(getReset());
563  if (hasStall())
564  extLikeInputs.insert(getStall());
565 
566  // Phase invariant - if any block has arguments apart from the stage valid
567  // argument, we are in register materialized mode. Check that all values
568  // used within a stage are defined within the stage.
569  bool materialized = isMaterialized();
570  if (materialized) {
571  for (auto &stage : stages) {
572  for (auto &op : stage) {
573  for (auto [index, operand] : llvm::enumerate(op.getOperands())) {
574  bool err = false;
575  if (extLikeInputs.contains(operand)) {
576  // This is an external input; legal to reference everywhere.
577  continue;
578  }
579 
580  if (auto *definingOp = operand.getDefiningOp()) {
581  // Constants are allowed to be used across stages.
582  if (definingOp->hasTrait<OpTrait::ConstantLike>())
583  continue;
584  err = definingOp->getBlock() != &stage;
585  } else {
586  // This is a block argument;
587  err = !llvm::is_contained(stage.getArguments(), operand);
588  }
589 
590  if (err)
591  return op.emitOpError(
592  "Pipeline is in register materialized mode - operand ")
593  << index
594  << " is defined in a different stage, which is illegal.";
595  }
596  }
597  }
598  }
599 
600  if (auto stallability = getStallability()) {
601  // Only allow specifying stallability if there is a stall signal.
602  if (!hasStall())
603  return emitOpError("cannot specify stallability without a stall signal.");
604 
605  // Ensure that the # of stages is equal to the length of the stallability
606  // array - the exit stage is never stallable.
607  size_t nRegisterStages = stages.size() - 1;
608  if (stallability->size() != nRegisterStages)
609  return emitOpError("stallability array must be the same length as the "
610  "number of stages. Pipeline has ")
611  << nRegisterStages << " stages but array had "
612  << stallability->size() << " elements.";
613  }
614 
615  return success();
616 }
617 
618 StageKind ScheduledPipelineOp::getStageKind(size_t stageIndex) {
619  assert(stageIndex < getNumStages() && "invalid stage index");
620 
621  if (!hasStall())
622  return StageKind::Continuous;
623 
624  // There is a stall signal - also check whether stage-level stallability is
625  // specified.
626  std::optional<ArrayAttr> stallability = getStallability();
627  if (!stallability) {
628  // All stages are stallable.
629  return StageKind::Stallable;
630  }
631 
632  if (stageIndex < stallability->size()) {
633  bool stageIsStallable =
634  (*stallability)[stageIndex].cast<BoolAttr>().getValue();
635  if (!stageIsStallable) {
636  // This is a non-stallable stage.
637  return StageKind::NonStallable;
638  }
639  }
640 
641  // Walk backwards from this stage to see if any non-stallable stage exists.
642  // If so, this is a runoff stage.
643  // TODO: This should be a pre-computed property.
644  if (stageIndex == 0)
645  return StageKind::Stallable;
646 
647  for (size_t i = stageIndex - 1; i > 0; --i) {
648  if (getStageKind(i) == StageKind::NonStallable)
649  return StageKind::Runoff;
650  }
651  return StageKind::Stallable;
652 }
653 
654 //===----------------------------------------------------------------------===//
655 // ReturnOp
656 //===----------------------------------------------------------------------===//
657 
658 LogicalResult ReturnOp::verify() {
659  Operation *parent = getOperation()->getParentOp();
660  size_t nInputs = getInputs().size();
661  auto expectedResults = TypeRange(parent->getResultTypes()).drop_back();
662  size_t expectedNResults = expectedResults.size();
663  if (nInputs != expectedNResults)
664  return emitOpError("expected ")
665  << expectedNResults << " return values, got " << nInputs << ".";
666 
667  for (auto [inType, reqType] :
668  llvm::zip(getInputs().getTypes(), expectedResults)) {
669  if (inType != reqType)
670  return emitOpError("expected return value of type ")
671  << reqType << ", got " << inType << ".";
672  }
673 
674  return success();
675 }
676 
677 //===----------------------------------------------------------------------===//
678 // StageOp
679 //===----------------------------------------------------------------------===//
680 
681 // Parses the form:
682 // ($name `=`)? $register : type($register)
683 
684 static ParseResult
685 parseOptNamedTypedAssignment(OpAsmParser &parser,
686  OpAsmParser::UnresolvedOperand &v, Type &t,
687  StringAttr &name) {
688  // Parse optional name.
689  std::string nameref;
690  if (succeeded(parser.parseOptionalString(&nameref))) {
691  if (nameref.empty())
692  return parser.emitError(parser.getCurrentLocation(),
693  "name cannot be empty");
694 
695  if (failed(parser.parseEqual()))
696  return parser.emitError(parser.getCurrentLocation(),
697  "expected '=' after name");
698  name = parser.getBuilder().getStringAttr(nameref);
699  } else {
700  name = parser.getBuilder().getStringAttr("");
701  }
702 
703  // Parse mandatory value and type.
704  if (failed(parser.parseOperand(v)) || failed(parser.parseColonType(t)))
705  return failure();
706 
707  return success();
708 }
709 
710 // Parses the form:
711 // parseOptNamedTypedAssignment (`gated by` `[` $clockGates `]`)?
712 static ParseResult parseSingleStageRegister(
713  OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t,
714  llvm::SmallVector<OpAsmParser::UnresolvedOperand> &clockGates,
715  StringAttr &name) {
716  if (failed(parseOptNamedTypedAssignment(parser, v, t, name)))
717  return failure();
718 
719  // Parse optional gated-by clause.
720  if (failed(parser.parseOptionalKeyword("gated")))
721  return success();
722 
723  if (failed(parser.parseKeyword("by")) ||
724  failed(
725  parser.parseOperandList(clockGates, OpAsmParser::Delimiter::Square)))
726  return failure();
727 
728  return success();
729 }
730 
731 // Parses the form:
732 // regs( ($name `=`)? $register : type($register) (`gated by` `[` $clockGates
733 // `]`)?, ...)
735  OpAsmParser &parser,
736  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &registers,
737  llvm::SmallVector<mlir::Type, 1> &registerTypes,
738  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &clockGates,
739  ArrayAttr &clockGatesPerRegister, ArrayAttr &registerNames) {
740 
741  if (failed(parser.parseOptionalKeyword("regs"))) {
742  clockGatesPerRegister = parser.getBuilder().getI64ArrayAttr({});
743  return success(); // no registers to parse.
744  }
745 
746  llvm::SmallVector<int64_t> clockGatesPerRegisterList;
747  llvm::SmallVector<Attribute> registerNamesList;
748  bool withNames = false;
749  if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&]() {
750  OpAsmParser::UnresolvedOperand v;
751  Type t;
752  llvm::SmallVector<OpAsmParser::UnresolvedOperand> cgs;
753  StringAttr name;
754  if (parseSingleStageRegister(parser, v, t, cgs, name))
755  return failure();
756  registers.push_back(v);
757  registerTypes.push_back(t);
758  registerNamesList.push_back(name);
759  withNames |= static_cast<bool>(name);
760  llvm::append_range(clockGates, cgs);
761  clockGatesPerRegisterList.push_back(cgs.size());
762  return success();
763  })))
764  return failure();
765 
766  clockGatesPerRegister =
767  parser.getBuilder().getI64ArrayAttr(clockGatesPerRegisterList);
768  if (withNames)
769  registerNames = parser.getBuilder().getArrayAttr(registerNamesList);
770 
771  return success();
772 }
773 
774 void printStageRegisters(OpAsmPrinter &p, Operation *op, ValueRange registers,
775  TypeRange registerTypes, ValueRange clockGates,
776  ArrayAttr clockGatesPerRegister, ArrayAttr names) {
777  if (registers.empty())
778  return;
779 
780  p << "regs(";
781  size_t clockGateStartIdx = 0;
782  llvm::interleaveComma(
783  llvm::enumerate(
784  llvm::zip(registers, registerTypes, clockGatesPerRegister)),
785  p, [&](auto it) {
786  size_t idx = it.index();
787  auto &[reg, type, nClockGatesAttr] = it.value();
788  if (names) {
789  if (auto nameAttr = names[idx].dyn_cast<StringAttr>();
790  nameAttr && !nameAttr.strref().empty())
791  p << nameAttr << " = ";
792  }
793 
794  p << reg << " : " << type;
795  int64_t nClockGates =
796  nClockGatesAttr.template cast<IntegerAttr>().getInt();
797  if (nClockGates == 0)
798  return;
799  p << " gated by [";
800  llvm::interleaveComma(clockGates.slice(clockGateStartIdx, nClockGates),
801  p);
802  p << "]";
803  clockGateStartIdx += nClockGates;
804  });
805  p << ")";
806 }
807 
808 void printPassthroughs(OpAsmPrinter &p, Operation *op, ValueRange passthroughs,
809  TypeRange passthroughTypes, ArrayAttr names) {
810 
811  if (passthroughs.empty())
812  return;
813 
814  p << "pass(";
815  llvm::interleaveComma(
816  llvm::enumerate(llvm::zip(passthroughs, passthroughTypes)), p,
817  [&](auto it) {
818  size_t idx = it.index();
819  auto &[reg, type] = it.value();
820  if (names) {
821  if (auto nameAttr = names[idx].dyn_cast<StringAttr>();
822  nameAttr && !nameAttr.strref().empty())
823  p << nameAttr << " = ";
824  }
825  p << reg << " : " << type;
826  });
827  p << ")";
828 }
829 
830 // Parses the form:
831 // (`pass` `(` ($name `=`)? $register : type($register), ... `)` )?
832 ParseResult parsePassthroughs(
833  OpAsmParser &parser,
834  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &passthroughs,
835  llvm::SmallVector<mlir::Type, 1> &passthroughTypes,
836  ArrayAttr &passthroughNames) {
837  if (failed(parser.parseOptionalKeyword("pass")))
838  return success(); // no passthroughs to parse.
839 
840  llvm::SmallVector<Attribute> passthroughsNameList;
841  bool withNames = false;
842  if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&]() {
843  OpAsmParser::UnresolvedOperand v;
844  Type t;
845  StringAttr name;
846  if (parseOptNamedTypedAssignment(parser, v, t, name))
847  return failure();
848  passthroughs.push_back(v);
849  passthroughTypes.push_back(t);
850  passthroughsNameList.push_back(name);
851  withNames |= static_cast<bool>(name);
852  return success();
853  })))
854  return failure();
855 
856  if (withNames)
857  passthroughNames = parser.getBuilder().getArrayAttr(passthroughsNameList);
858 
859  return success();
860 }
861 
862 void StageOp::build(OpBuilder &odsBuilder, OperationState &odsState,
863  Block *dest, ValueRange registers,
864  ValueRange passthroughs) {
865  odsState.addSuccessors(dest);
866  odsState.addOperands(registers);
867  odsState.addOperands(passthroughs);
868  odsState.addAttribute("operandSegmentSizes",
869  odsBuilder.getDenseI32ArrayAttr(
870  {static_cast<int32_t>(registers.size()),
871  static_cast<int32_t>(passthroughs.size()),
872  /*clock gates*/ static_cast<int32_t>(0)}));
873  llvm::SmallVector<int64_t> clockGatesPerRegister(registers.size(), 0);
874  odsState.addAttribute("clockGatesPerRegister",
875  odsBuilder.getI64ArrayAttr(clockGatesPerRegister));
876 }
877 
878 void StageOp::build(OpBuilder &odsBuilder, OperationState &odsState,
879  Block *dest, ValueRange registers, ValueRange passthroughs,
880  llvm::ArrayRef<llvm::SmallVector<Value>> clockGateList,
881  mlir::ArrayAttr registerNames,
882  mlir::ArrayAttr passthroughNames) {
883  build(odsBuilder, odsState, dest, registers, passthroughs);
884 
885  llvm::SmallVector<Value> clockGates;
886  llvm::SmallVector<int64_t> clockGatesPerRegister(registers.size(), 0);
887  for (auto gates : clockGateList) {
888  llvm::append_range(clockGates, gates);
889  clockGatesPerRegister.push_back(gates.size());
890  }
891  odsState.attributes.set("clockGatesPerRegister",
892  odsBuilder.getI64ArrayAttr(clockGatesPerRegister));
893  odsState.addOperands(clockGates);
894 
895  if (registerNames)
896  odsState.addAttribute("registerNames", registerNames);
897 
898  if (passthroughNames)
899  odsState.addAttribute("passthroughNames", passthroughNames);
900 }
901 
902 ValueRange StageOp::getClockGatesForReg(unsigned regIdx) {
903  assert(regIdx < getRegisters().size() && "register index out of bounds.");
904 
905  // TODO: This could be optimized quite a bit if we didn't store clock
906  // gates per register as an array of sizes... look into using properties
907  // and maybe attaching a more complex datastructure to reduce compute
908  // here.
909 
910  unsigned clockGateStartIdx = 0;
911  for (auto [index, nClockGatesAttr] :
912  llvm::enumerate(getClockGatesPerRegister().getAsRange<IntegerAttr>())) {
913  int64_t nClockGates = nClockGatesAttr.getInt();
914  if (index == regIdx) {
915  // This is the register we are looking for.
916  return getClockGates().slice(clockGateStartIdx, nClockGates);
917  }
918  // Increment the start index by the number of clock gates for this
919  // register.
920  clockGateStartIdx += nClockGates;
921  }
922 
923  llvm_unreachable("register index out of bounds.");
924 }
925 
926 LogicalResult StageOp::verify() {
927  // Verify that the target block has the correct arguments as this stage
928  // op.
929  llvm::SmallVector<Type> expectedTargetArgTypes;
930  llvm::append_range(expectedTargetArgTypes, getRegisters().getTypes());
931  llvm::append_range(expectedTargetArgTypes, getPassthroughs().getTypes());
932  Block *targetStage = getNextStage();
933  // Expected types is everything but the stage valid signal.
934  TypeRange targetStageArgTypes =
935  TypeRange(targetStage->getArgumentTypes()).drop_back();
936 
937  if (targetStageArgTypes.size() != expectedTargetArgTypes.size())
938  return emitOpError("expected ") << expectedTargetArgTypes.size()
939  << " arguments in the target stage, got "
940  << targetStageArgTypes.size() << ".";
941 
942  for (auto [index, it] : llvm::enumerate(
943  llvm::zip(expectedTargetArgTypes, targetStageArgTypes))) {
944  auto [arg, barg] = it;
945  if (arg != barg)
946  return emitOpError("expected target stage argument ")
947  << index << " to have type " << arg << ", got " << barg << ".";
948  }
949 
950  // Verify that the clock gate index list is equally sized to the # of
951  // registers.
952  if (getClockGatesPerRegister().size() != getRegisters().size())
953  return emitOpError("expected clockGatesPerRegister to be equally sized to "
954  "the number of registers.");
955 
956  // Verify that, if provided, the list of register names is equally sized
957  // to the number of registers.
958  if (auto regNames = getRegisterNames()) {
959  if (regNames->size() != getRegisters().size())
960  return emitOpError("expected registerNames to be equally sized to "
961  "the number of registers.");
962  }
963 
964  // Verify that, if provided, the list of passthrough names is equally sized
965  // to the number of passthroughs.
966  if (auto passthroughNames = getPassthroughNames()) {
967  if (passthroughNames->size() != getPassthroughs().size())
968  return emitOpError("expected passthroughNames to be equally sized to "
969  "the number of passthroughs.");
970  }
971 
972  return success();
973 }
974 
975 //===----------------------------------------------------------------------===//
976 // LatencyOp
977 //===----------------------------------------------------------------------===//
978 
979 LogicalResult LatencyOp::verify() {
980  ScheduledPipelineOp scheduledPipelineParent =
981  dyn_cast<ScheduledPipelineOp>(getOperation()->getParentOp());
982 
983  if (!scheduledPipelineParent) {
984  // Nothing to verify, got to assume that anything goes in an unscheduled
985  // pipeline.
986  return success();
987  }
988 
989  // Verify that the resulting values aren't referenced before they are
990  // accessible.
991  size_t latency = getLatency();
992  Block *definingStage = getOperation()->getBlock();
993 
994  llvm::DenseMap<Block *, unsigned> stageMap =
995  scheduledPipelineParent.getStageMap();
996 
997  auto stageDistance = [&](Block *from, Block *to) {
998  assert(stageMap.count(from) && "stage 'from' not contained in pipeline");
999  assert(stageMap.count(to) && "stage 'to' not contained in pipeline");
1000  int64_t fromStage = stageMap[from];
1001  int64_t toStage = stageMap[to];
1002  return toStage - fromStage;
1003  };
1004 
1005  for (auto [i, res] : llvm::enumerate(getResults())) {
1006  for (auto &use : res.getUses()) {
1007  auto *user = use.getOwner();
1008 
1009  // The user may reside within a block which is not a stage (e.g.
1010  // inside a pipeline.latency op). Determine the stage which this use
1011  // resides within.
1012  Block *userStage =
1013  getParentStageInPipeline(scheduledPipelineParent, user);
1014  unsigned useDistance = stageDistance(definingStage, userStage);
1015 
1016  // Is this a stage op and is the value passed through? if so, this is
1017  // a legal use.
1018  StageOp stageOp = dyn_cast<StageOp>(user);
1019  if (userStage == definingStage && stageOp) {
1020  if (llvm::is_contained(stageOp.getPassthroughs(), res))
1021  continue;
1022  }
1023 
1024  // The use is not a passthrough. Check that the distance between
1025  // the defining stage and the user stage is at least the latency of
1026  // the result.
1027  if (useDistance < latency) {
1028  auto diag = emitOpError("result ")
1029  << i << " is used before it is available.";
1030  diag.attachNote(user->getLoc())
1031  << "use was operand " << use.getOperandNumber()
1032  << ". The result is available " << latency - useDistance
1033  << " stages later than this use.";
1034  return diag;
1035  }
1036  }
1037  }
1038  return success();
1039 }
1040 
1041 //===----------------------------------------------------------------------===//
1042 // LatencyReturnOp
1043 //===----------------------------------------------------------------------===//
1044 
1045 LogicalResult LatencyReturnOp::verify() {
1046  LatencyOp parent = cast<LatencyOp>(getOperation()->getParentOp());
1047  size_t nInputs = getInputs().size();
1048  size_t nResults = parent->getNumResults();
1049  if (nInputs != nResults)
1050  return emitOpError("expected ")
1051  << nResults << " return values, got " << nInputs << ".";
1052 
1053  for (auto [inType, reqType] :
1054  llvm::zip(getInputs().getTypes(), parent->getResultTypes())) {
1055  if (inType != reqType)
1056  return emitOpError("expected return value of type ")
1057  << reqType << ", got " << inType << ".";
1058  }
1059 
1060  return success();
1061 }
1062 
1063 #define GET_OP_CLASSES
1064 #include "circt/Dialect/Pipeline/Pipeline.cpp.inc"
1065 
1066 void PipelineDialect::initialize() {
1067  addOperations<
1068 #define GET_OP_LIST
1069 #include "circt/Dialect/Pipeline/Pipeline.cpp.inc"
1070  >();
1071 }
assert(baseType &&"element must be base type")
static InstancePath empty
llvm::SmallVector< StringAttr > inputs
Builder builder
static ParseResult parseSingleStageRegister(OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t, llvm::SmallVector< OpAsmParser::UnresolvedOperand > &clockGates, StringAttr &name)
static ParseResult parseKeywordAndOperand(OpAsmParser &p, StringRef keyword, OpAsmParser::UnresolvedOperand &op)
static void printOutputList(OpAsmPrinter &p, TypeRange types, ArrayAttr names)
ParseResult parsePassthroughs(OpAsmParser &parser, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &passthroughs, llvm::SmallVector< mlir::Type, 1 > &passthroughTypes, ArrayAttr &passthroughNames)
static void printPipelineOp(OpAsmPrinter &p, TPipelineOp op)
void printPassthroughs(OpAsmPrinter &p, Operation *op, ValueRange passthroughs, TypeRange passthroughTypes, ArrayAttr names)
static ParseResult parseOptNamedTypedAssignment(OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t, StringAttr &name)
static ParseResult parsePipelineOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
static FailureOr< llvm::SmallVector< Block * > > getOrderedStagesFailable(ScheduledPipelineOp op)
void printStageRegisters(OpAsmPrinter &p, Operation *op, ValueRange registers, TypeRange registerTypes, ValueRange clockGates, ArrayAttr clockGatesPerRegister, ArrayAttr names)
static void getPipelineAsmBlockArgumentNames(TPipelineOp op, mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn)
static ParseResult parseOutputList(OpAsmParser &parser, llvm::SmallVector< Type > &inputTypes, mlir::ArrayAttr &outputNames)
Definition: PipelineOps.cpp:85
static void printKeywordOperand(OpAsmPrinter &p, StringRef keyword, Value value)
ParseResult parseStageRegisters(OpAsmParser &parser, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &registers, llvm::SmallVector< mlir::Type, 1 > &registerTypes, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &clockGates, ArrayAttr &clockGatesPerRegister, ArrayAttr &registerNames)
static void buildPipelineLikeOp(OpBuilder &odsBuilder, OperationState &odsState, TypeRange dataOutputs, ValueRange inputs, ArrayAttr inputNames, ArrayAttr outputNames, Value clock, Value reset, Value go, Value stall, StringAttr name, ArrayAttr stallability)
static void getPipelineAsmResultNames(TPipelineOp op, OpAsmSetValueNameFn setNameFn)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
ParseResult parseInitializerList(mlir::OpAsmParser &parser, llvm::SmallVector< mlir::OpAsmParser::Argument > &inputArguments, llvm::SmallVector< mlir::OpAsmParser::UnresolvedOperand > &inputOperands, llvm::SmallVector< Type > &inputTypes, ArrayAttr &inputNames)
Parses an initializer.
void printInitializerList(OpAsmPrinter &p, ValueRange ins, ArrayRef< BlockArgument > args)
llvm::SmallVector< Value > getValuesDefinedOutsideRegion(Region &region)
Definition: PipelineOps.cpp:32
Block * getParentStageInPipeline(ScheduledPipelineOp pipeline, Operation *op)
Definition: PipelineOps.cpp:66
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition: LLVM.h:186
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:20