CIRCT  19.0.0git
PipelineOps.cpp
Go to the documentation of this file.
1 //===- PipelineOps.cpp - Pipeline MLIR Operations ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Pipeline ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Interfaces/FunctionImplementation.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 
22 using namespace mlir;
23 using namespace circt;
24 using namespace circt::pipeline;
25 using namespace circt::parsing_util;
26 
27 #include "circt/Dialect/Pipeline/PipelineDialect.cpp.inc"
28 
29 #define DEBUG_TYPE "pipeline-ops"
30 
31 llvm::SmallVector<Value>
33  llvm::SetVector<Value> values;
34  region.walk([&](Operation *op) {
35  for (auto operand : op->getOperands()) {
36  if (region.isAncestor(operand.getParentRegion()))
37  continue;
38  values.insert(operand);
39  }
40  });
41  return values.takeVector();
42 }
43 
44 Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
45  Block *block) {
46  // Optional debug check - ensure that 'block' eventually leads to the
47  // pipeline.
48  LLVM_DEBUG({
49  Operation *directParent = block->getParentOp();
50  if (directParent != pipeline) {
51  auto indirectParent =
52  directParent->getParentOfType<ScheduledPipelineOp>();
53  assert(indirectParent == pipeline && "block is not in the pipeline");
54  }
55  });
56 
57  while (block && block->getParent() != &pipeline.getRegion()) {
58  // Go one level up.
59  block = block->getParent()->getParentOp()->getBlock();
60  }
61 
62  // This is a block within the pipeline region, so it must be a stage.
63  return block;
64 }
65 
66 Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
67  Operation *op) {
68  return getParentStageInPipeline(pipeline, op->getBlock());
69 }
70 
71 Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
72  Value v) {
73  if (isa<BlockArgument>(v))
74  return getParentStageInPipeline(pipeline,
75  cast<BlockArgument>(v).getOwner());
76  return getParentStageInPipeline(pipeline, v.getDefiningOp());
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // Fancy pipeline-like op printer/parser functions.
81 //===----------------------------------------------------------------------===//
82 
83 // Parses a list of operands on the format:
84 // (name : type, ...)
85 static ParseResult parseOutputList(OpAsmParser &parser,
86  llvm::SmallVector<Type> &inputTypes,
87  mlir::ArrayAttr &outputNames) {
88 
89  llvm::SmallVector<Attribute> names;
90  if (parser.parseCommaSeparatedList(
91  OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
92  StringRef name;
93  Type type;
94  if (parser.parseKeyword(&name) || parser.parseColonType(type))
95  return failure();
96 
97  inputTypes.push_back(type);
98  names.push_back(StringAttr::get(parser.getContext(), name));
99  return success();
100  }))
101  return failure();
102 
103  outputNames = ArrayAttr::get(parser.getContext(), names);
104  return success();
105 }
106 
107 static void printOutputList(OpAsmPrinter &p, TypeRange types, ArrayAttr names) {
108  p << "(";
109  llvm::interleaveComma(llvm::zip(types, names), p, [&](auto it) {
110  auto [type, name] = it;
111  p.printKeywordOrString(cast<StringAttr>(name).str());
112  p << " : " << type;
113  });
114  p << ")";
115 }
116 
117 static ParseResult parseKeywordAndOperand(OpAsmParser &p, StringRef keyword,
118  OpAsmParser::UnresolvedOperand &op) {
119  if (p.parseKeyword(keyword) || p.parseLParen() || p.parseOperand(op) ||
120  p.parseRParen())
121  return failure();
122  return success();
123 }
124 
125 // Assembly format is roughly:
126 // ( $name )? initializer-list stall (%stall = $stall)?
127 // clock (%clock) reset (%reset) go(%go) entryEnable(%en) {
128 // --- elided inner block ---
129 static ParseResult parsePipelineOp(mlir::OpAsmParser &parser,
130  mlir::OperationState &result) {
131  std::string name;
132  if (succeeded(parser.parseOptionalString(&name)))
133  result.addAttribute("name", parser.getBuilder().getStringAttr(name));
134 
135  llvm::SmallVector<OpAsmParser::UnresolvedOperand> inputOperands;
136  llvm::SmallVector<OpAsmParser::Argument> inputArguments;
137  llvm::SmallVector<Type> inputTypes;
138  ArrayAttr inputNames;
139  if (parseInitializerList(parser, inputArguments, inputOperands, inputTypes,
140  inputNames))
141  return failure();
142  result.addAttribute("inputNames", inputNames);
143 
144  Type i1 = parser.getBuilder().getI1Type();
145  // Parse optional 'stall (%stallArg)'
146  OpAsmParser::Argument stallArg;
147  OpAsmParser::UnresolvedOperand stallOperand;
148  bool withStall = false;
149  if (succeeded(parser.parseOptionalKeyword("stall"))) {
150  if (parser.parseLParen() || parser.parseOperand(stallOperand) ||
151  parser.parseRParen())
152  return failure();
153  withStall = true;
154  }
155 
156  // Parse clock, reset, and go.
157  OpAsmParser::UnresolvedOperand clockOperand, resetOperand, goOperand;
158  if (parseKeywordAndOperand(parser, "clock", clockOperand) ||
159  parseKeywordAndOperand(parser, "reset", resetOperand) ||
160  parseKeywordAndOperand(parser, "go", goOperand))
161  return failure();
162 
163  // Parse entry stage enable block argument.
164  OpAsmParser::Argument entryEnable;
165  entryEnable.type = i1;
166  if (parser.parseKeyword("entryEn") || parser.parseLParen() ||
167  parser.parseArgument(entryEnable) || parser.parseRParen())
168  return failure();
169 
170  // Optional attribute dict
171  if (parser.parseOptionalAttrDict(result.attributes))
172  return failure();
173 
174  // Parse the output assignment list
175  if (parser.parseArrow())
176  return failure();
177 
178  llvm::SmallVector<Type> outputTypes;
179  ArrayAttr outputNames;
180  if (parseOutputList(parser, outputTypes, outputNames))
181  return failure();
182  result.addTypes(outputTypes);
183  result.addAttribute("outputNames", outputNames);
184 
185  // And the implicit 'done' output.
186  result.addTypes({i1});
187 
188  // All operands have been parsed - resolve.
189  if (parser.resolveOperands(inputOperands, inputTypes, parser.getNameLoc(),
190  result.operands))
191  return failure();
192 
193  if (withStall) {
194  if (parser.resolveOperand(stallOperand, i1, result.operands))
195  return failure();
196  }
197 
198  Type clkType = seq::ClockType::get(parser.getContext());
199  if (parser.resolveOperand(clockOperand, clkType, result.operands) ||
200  parser.resolveOperand(resetOperand, i1, result.operands) ||
201  parser.resolveOperand(goOperand, i1, result.operands))
202  return failure();
203 
204  // Assemble the body region block arguments - this is where the magic happens
205  // and why we're doing a custom printer/parser - if the user had to magically
206  // know the order of these block arguments, we're asking for issues.
207  SmallVector<OpAsmParser::Argument> regionArgs;
208 
209  // First we add the input arguments.
210  llvm::append_range(regionArgs, inputArguments);
211  // Then the internal entry stage enable block argument.
212  regionArgs.push_back(entryEnable);
213 
214  // Parse the body region.
215  Region *body = result.addRegion();
216  if (parser.parseRegion(*body, regionArgs))
217  return failure();
218 
219  result.addAttribute(
220  "operandSegmentSizes",
221  parser.getBuilder().getDenseI32ArrayAttr(
222  {static_cast<int32_t>(inputTypes.size()),
223  static_cast<int32_t>(withStall ? 1 : 0),
224  /*clock*/ static_cast<int32_t>(1),
225  /*reset*/ static_cast<int32_t>(1), /*go*/ static_cast<int32_t>(1)}));
226 
227  return success();
228 }
229 
230 static void printKeywordOperand(OpAsmPrinter &p, StringRef keyword,
231  Value value) {
232  p << keyword << "(";
233  p.printOperand(value);
234  p << ")";
235 }
236 
237 template <typename TPipelineOp>
238 static void printPipelineOp(OpAsmPrinter &p, TPipelineOp op) {
239  if (auto name = op.getNameAttr()) {
240  p << " \"" << name.getValue() << "\"";
241  }
242 
243  // Print the input list.
244  printInitializerList(p, op.getInputs(), op.getInnerInputs());
245  p << " ";
246 
247  // Print the optional stall.
248  if (op.hasStall()) {
249  p << "stall(";
250  p.printOperand(op.getStall());
251  p << ") ";
252  }
253 
254  // Print the clock, reset, and go.
255  printKeywordOperand(p, "clock", op.getClock());
256  p << " ";
257  printKeywordOperand(p, "reset", op.getReset());
258  p << " ";
259  printKeywordOperand(p, "go", op.getGo());
260  p << " ";
261 
262  // Print the entry enable block argument.
263  p << "entryEn(";
264  p.printRegionArgument(
265  cast<BlockArgument>(op.getStageEnableSignal(static_cast<size_t>(0))), {},
266  /*omitType*/ true);
267  p << ") ";
268 
269  // Print the optional attribute dict.
270  p.printOptionalAttrDict(op->getAttrs(),
271  /*elidedAttrs=*/{"name", "operandSegmentSizes",
272  "outputNames", "inputNames"});
273  p << " -> ";
274 
275  // Print the output list.
276  printOutputList(p, op.getDataOutputs().getTypes(), op.getOutputNames());
277 
278  p << " ";
279 
280  // Print the inner region, eliding the entry block arguments - we've already
281  // defined these in our initializer lists.
282  p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false,
283  /*printBlockTerminators=*/true);
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // UnscheduledPipelineOp
288 //===----------------------------------------------------------------------===//
289 
290 static void buildPipelineLikeOp(OpBuilder &odsBuilder, OperationState &odsState,
291  TypeRange dataOutputs, ValueRange inputs,
292  ArrayAttr inputNames, ArrayAttr outputNames,
293  Value clock, Value reset, Value go, Value stall,
294  StringAttr name, ArrayAttr stallability) {
295  odsState.addOperands(inputs);
296  if (stall)
297  odsState.addOperands(stall);
298  odsState.addOperands(clock);
299  odsState.addOperands(reset);
300  odsState.addOperands(go);
301  if (name)
302  odsState.addAttribute("name", name);
303 
304  odsState.addAttribute(
305  "operandSegmentSizes",
306  odsBuilder.getDenseI32ArrayAttr(
307  {static_cast<int32_t>(inputs.size()),
308  static_cast<int32_t>(stall ? 1 : 0), static_cast<int32_t>(1),
309  static_cast<int32_t>(1), static_cast<int32_t>(1)}));
310 
311  odsState.addAttribute("inputNames", inputNames);
312  odsState.addAttribute("outputNames", outputNames);
313 
314  auto *region = odsState.addRegion();
315  odsState.addTypes(dataOutputs);
316 
317  // Add the implicit done output signal.
318  Type i1 = odsBuilder.getIntegerType(1);
319  odsState.addTypes({i1});
320 
321  // Add the entry stage - arguments order:
322  // 1. Inputs
323  // 2. Stall (opt)
324  // 3. Clock
325  // 4. Reset
326  // 5. Go
327  auto &entryBlock = region->emplaceBlock();
328  llvm::SmallVector<Location> entryArgLocs(inputs.size(), odsState.location);
329  entryBlock.addArguments(
330  inputs.getTypes(),
331  llvm::SmallVector<Location>(inputs.size(), odsState.location));
332  if (stall)
333  entryBlock.addArgument(i1, odsState.location);
334  entryBlock.addArgument(i1, odsState.location);
335  entryBlock.addArgument(i1, odsState.location);
336 
337  // entry stage valid signal.
338  entryBlock.addArgument(i1, odsState.location);
339 
340  if (stallability)
341  odsState.addAttribute("stallability", stallability);
342 }
343 
344 template <typename TPipelineOp>
345 static void getPipelineAsmResultNames(TPipelineOp op,
346  OpAsmSetValueNameFn setNameFn) {
347  for (auto [res, name] :
348  llvm::zip(op.getDataOutputs(),
349  op.getOutputNames().template getAsValueRange<StringAttr>()))
350  setNameFn(res, name);
351  setNameFn(op.getDone(), "done");
352 }
353 
354 template <typename TPipelineOp>
355 static void
356 getPipelineAsmBlockArgumentNames(TPipelineOp op, mlir::Region &region,
357  mlir::OpAsmSetValueNameFn setNameFn) {
358  for (auto [i, block] : llvm::enumerate(op.getRegion())) {
359  if (Block *predBlock = block.getSinglePredecessor()) {
360  // Predecessor stageOp might have register and passthrough names
361  // specified, which we can use to name the block arguments.
362  auto predStageOp = cast<StageOp>(predBlock->getTerminator());
363  size_t nRegs = predStageOp.getRegisters().size();
364  auto nPassthrough = predStageOp.getPassthroughs().size();
365 
366  auto regNames = predStageOp.getRegisterNames();
367  auto passthroughNames = predStageOp.getPassthroughNames();
368 
369  // Register naming...
370  for (size_t regI = 0; regI < nRegs; ++regI) {
371  auto arg = block.getArguments()[regI];
372 
373  if (regNames) {
374  auto nameAttr = dyn_cast<StringAttr>((*regNames)[regI]);
375  if (nameAttr && !nameAttr.strref().empty()) {
376  setNameFn(arg, nameAttr);
377  continue;
378  }
379  }
380  setNameFn(arg, llvm::formatv("s{0}_reg{1}", i, regI).str());
381  }
382 
383  // Passthrough naming...
384  for (size_t passthroughI = 0; passthroughI < nPassthrough;
385  ++passthroughI) {
386  auto arg = block.getArguments()[nRegs + passthroughI];
387 
388  if (passthroughNames) {
389  auto nameAttr =
390  dyn_cast<StringAttr>((*passthroughNames)[passthroughI]);
391  if (nameAttr && !nameAttr.strref().empty()) {
392  setNameFn(arg, nameAttr);
393  continue;
394  }
395  }
396  setNameFn(arg, llvm::formatv("s{0}_pass{1}", i, passthroughI).str());
397  }
398  } else {
399  // This is the entry stage - name the arguments according to the input
400  // names.
401  for (auto [inputArg, inputName] :
402  llvm::zip(op.getInnerInputs(),
403  op.getInputNames().template getAsValueRange<StringAttr>()))
404  setNameFn(inputArg, inputName);
405  }
406 
407  // Last argument in any stage is the stage enable signal.
408  setNameFn(block.getArguments().back(),
409  llvm::formatv("s{0}_enable", i).str());
410  }
411 }
412 
413 void UnscheduledPipelineOp::print(OpAsmPrinter &p) {
414  printPipelineOp(p, *this);
415 }
416 
417 ParseResult UnscheduledPipelineOp::parse(OpAsmParser &parser,
418  OperationState &result) {
419  return parsePipelineOp(parser, result);
420 }
421 
423  getPipelineAsmResultNames(*this, setNameFn);
424 }
425 
426 void UnscheduledPipelineOp::getAsmBlockArgumentNames(
427  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
428  getPipelineAsmBlockArgumentNames(*this, region, setNameFn);
429 }
430 
431 void UnscheduledPipelineOp::build(OpBuilder &odsBuilder,
432  OperationState &odsState,
433  TypeRange dataOutputs, ValueRange inputs,
434  ArrayAttr inputNames, ArrayAttr outputNames,
435  Value clock, Value reset, Value go,
436  Value stall, StringAttr name,
437  ArrayAttr stallability) {
438  buildPipelineLikeOp(odsBuilder, odsState, dataOutputs, inputs, inputNames,
439  outputNames, clock, reset, go, stall, name, stallability);
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // ScheduledPipelineOp
444 //===----------------------------------------------------------------------===//
445 
446 void ScheduledPipelineOp::print(OpAsmPrinter &p) { printPipelineOp(p, *this); }
447 
448 ParseResult ScheduledPipelineOp::parse(OpAsmParser &parser,
449  OperationState &result) {
450  return parsePipelineOp(parser, result);
451 }
452 
453 void ScheduledPipelineOp::build(OpBuilder &odsBuilder, OperationState &odsState,
454  TypeRange dataOutputs, ValueRange inputs,
455  ArrayAttr inputNames, ArrayAttr outputNames,
456  Value clock, Value reset, Value go, Value stall,
457  StringAttr name, ArrayAttr stallability) {
458  buildPipelineLikeOp(odsBuilder, odsState, dataOutputs, inputs, inputNames,
459  outputNames, clock, reset, go, stall, name, stallability);
460 }
461 
462 Block *ScheduledPipelineOp::addStage() {
463  OpBuilder builder(getContext());
464  Block *stage = builder.createBlock(&getRegion());
465 
466  // Add the stage valid signal.
467  stage->addArgument(builder.getIntegerType(1), getLoc());
468  return stage;
469 }
470 
471 void ScheduledPipelineOp::getAsmBlockArgumentNames(
472  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
473  getPipelineAsmBlockArgumentNames(*this, region, setNameFn);
474 }
475 
477  getPipelineAsmResultNames(*this, setNameFn);
478 }
479 
480 // Implementation of getOrderedStages which also produces an error if
481 // there are any cfg cycles in the pipeline.
483 getOrderedStagesFailable(ScheduledPipelineOp op) {
484  llvm::DenseSet<Block *> visited;
485  Block *currentStage = op.getEntryStage();
486  llvm::SmallVector<Block *> orderedStages;
487  do {
488  if (!visited.insert(currentStage).second)
489  return op.emitOpError("pipeline contains a cycle.");
490 
491  orderedStages.push_back(currentStage);
492  if (auto stageOp = dyn_cast<StageOp>(currentStage->getTerminator()))
493  currentStage = stageOp.getNextStage();
494  else
495  currentStage = nullptr;
496  } while (currentStage);
497 
498  return {orderedStages};
499 }
500 
501 llvm::SmallVector<Block *> ScheduledPipelineOp::getOrderedStages() {
502  // Should always be safe, seeing as the pipeline itself has already been
503  // verified.
504  return *getOrderedStagesFailable(*this);
505 }
506 
507 llvm::DenseMap<Block *, unsigned> ScheduledPipelineOp::getStageMap() {
508  llvm::DenseMap<Block *, unsigned> stageMap;
509  auto orderedStages = getOrderedStages();
510  for (auto [index, stage] : llvm::enumerate(orderedStages))
511  stageMap[stage] = index;
512 
513  return stageMap;
514 }
515 
516 Block *ScheduledPipelineOp::getLastStage() { return getOrderedStages().back(); }
517 
518 bool ScheduledPipelineOp::isMaterialized() {
519  // We determine materialization as if any pipeline stage has an explicit
520  // input (apart from the stage valid signal).
521  return llvm::any_of(getStages(), [this](Block &block) {
522  // The entry stage doesn't count since it'll always have arguments.
523  if (&block == getEntryStage())
524  return false;
525  return block.getNumArguments() > 1;
526  });
527 }
528 
529 LogicalResult ScheduledPipelineOp::verify() {
530  // Verify that all block are terminated properly.
531  auto &stages = getStages();
532  for (Block &stage : stages) {
533  if (stage.empty() || !isa<ReturnOp, StageOp>(stage.back()))
534  return emitOpError("all blocks must be terminated with a "
535  "`pipeline.stage` or `pipeline.return` op.");
536  }
537 
538  if (failed(getOrderedStagesFailable(*this)))
539  return failure();
540 
541  // Verify that every stage has a stage valid block argument.
542  for (auto [i, block] : llvm::enumerate(stages)) {
543  bool err = true;
544  if (block.getNumArguments() != 0) {
545  auto lastArgType =
546  dyn_cast<IntegerType>(block.getArguments().back().getType());
547  err = !lastArgType || lastArgType.getWidth() != 1;
548  }
549  if (err)
550  return emitOpError("block " + std::to_string(i) +
551  " must have an i1 argument as the last block argument "
552  "(stage valid signal).");
553  }
554 
555  // Cache external inputs in a set for fast lookup (also includes clock, reset,
556  // and stall).
557  llvm::DenseSet<Value> extLikeInputs;
558  for (auto extInput : getExtInputs())
559  extLikeInputs.insert(extInput);
560 
561  extLikeInputs.insert(getClock());
562  extLikeInputs.insert(getReset());
563  if (hasStall())
564  extLikeInputs.insert(getStall());
565 
566  // Phase invariant - if any block has arguments apart from the stage valid
567  // argument, we are in register materialized mode. Check that all values
568  // used within a stage are defined within the stage.
569  bool materialized = isMaterialized();
570  if (materialized) {
571  for (auto &stage : stages) {
572  for (auto &op : stage) {
573  for (auto [index, operand] : llvm::enumerate(op.getOperands())) {
574  bool err = false;
575  if (extLikeInputs.contains(operand)) {
576  // This is an external input; legal to reference everywhere.
577  continue;
578  }
579 
580  if (auto *definingOp = operand.getDefiningOp()) {
581  // Constants are allowed to be used across stages.
582  if (definingOp->hasTrait<OpTrait::ConstantLike>())
583  continue;
584  err = definingOp->getBlock() != &stage;
585  } else {
586  // This is a block argument;
587  err = !llvm::is_contained(stage.getArguments(), operand);
588  }
589 
590  if (err)
591  return op.emitOpError(
592  "Pipeline is in register materialized mode - operand ")
593  << index
594  << " is defined in a different stage, which is illegal.";
595  }
596  }
597  }
598  }
599 
600  if (auto stallability = getStallability()) {
601  // Only allow specifying stallability if there is a stall signal.
602  if (!hasStall())
603  return emitOpError("cannot specify stallability without a stall signal.");
604 
605  // Ensure that the # of stages is equal to the length of the stallability
606  // array - the exit stage is never stallable.
607  size_t nRegisterStages = stages.size() - 1;
608  if (stallability->size() != nRegisterStages)
609  return emitOpError("stallability array must be the same length as the "
610  "number of stages. Pipeline has ")
611  << nRegisterStages << " stages but array had "
612  << stallability->size() << " elements.";
613  }
614 
615  return success();
616 }
617 
618 StageKind ScheduledPipelineOp::getStageKind(size_t stageIndex) {
619  assert(stageIndex < getNumStages() && "invalid stage index");
620 
621  if (!hasStall())
622  return StageKind::Continuous;
623 
624  // There is a stall signal - also check whether stage-level stallability is
625  // specified.
626  std::optional<ArrayAttr> stallability = getStallability();
627  if (!stallability) {
628  // All stages are stallable.
629  return StageKind::Stallable;
630  }
631 
632  if (stageIndex < stallability->size()) {
633  bool stageIsStallable =
634  cast<BoolAttr>((*stallability)[stageIndex]).getValue();
635  if (!stageIsStallable) {
636  // This is a non-stallable stage.
637  return StageKind::NonStallable;
638  }
639  }
640 
641  // Walk backwards from this stage to see if any non-stallable stage exists.
642  // If so, this is a runoff stage.
643  // TODO: This should be a pre-computed property.
644  if (stageIndex == 0)
645  return StageKind::Stallable;
646 
647  for (size_t i = stageIndex - 1; i > 0; --i) {
648  if (getStageKind(i) == StageKind::NonStallable)
649  return StageKind::Runoff;
650  }
651  return StageKind::Stallable;
652 }
653 
654 //===----------------------------------------------------------------------===//
655 // ReturnOp
656 //===----------------------------------------------------------------------===//
657 
658 LogicalResult ReturnOp::verify() {
659  Operation *parent = getOperation()->getParentOp();
660  size_t nInputs = getInputs().size();
661  auto expectedResults = TypeRange(parent->getResultTypes()).drop_back();
662  size_t expectedNResults = expectedResults.size();
663  if (nInputs != expectedNResults)
664  return emitOpError("expected ")
665  << expectedNResults << " return values, got " << nInputs << ".";
666 
667  for (auto [inType, reqType] :
668  llvm::zip(getInputs().getTypes(), expectedResults)) {
669  if (inType != reqType)
670  return emitOpError("expected return value of type ")
671  << reqType << ", got " << inType << ".";
672  }
673 
674  return success();
675 }
676 
677 //===----------------------------------------------------------------------===//
678 // StageOp
679 //===----------------------------------------------------------------------===//
680 
681 // Parses the form:
682 // ($name `=`)? $register : type($register)
683 
684 static ParseResult
685 parseOptNamedTypedAssignment(OpAsmParser &parser,
686  OpAsmParser::UnresolvedOperand &v, Type &t,
687  StringAttr &name) {
688  // Parse optional name.
689  std::string nameref;
690  if (succeeded(parser.parseOptionalString(&nameref))) {
691  if (nameref.empty())
692  return parser.emitError(parser.getCurrentLocation(),
693  "name cannot be empty");
694 
695  if (failed(parser.parseEqual()))
696  return parser.emitError(parser.getCurrentLocation(),
697  "expected '=' after name");
698  name = parser.getBuilder().getStringAttr(nameref);
699  } else {
700  name = parser.getBuilder().getStringAttr("");
701  }
702 
703  // Parse mandatory value and type.
704  if (failed(parser.parseOperand(v)) || failed(parser.parseColonType(t)))
705  return failure();
706 
707  return success();
708 }
709 
710 // Parses the form:
711 // parseOptNamedTypedAssignment (`gated by` `[` $clockGates `]`)?
712 static ParseResult parseSingleStageRegister(
713  OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t,
714  llvm::SmallVector<OpAsmParser::UnresolvedOperand> &clockGates,
715  StringAttr &name) {
716  if (failed(parseOptNamedTypedAssignment(parser, v, t, name)))
717  return failure();
718 
719  // Parse optional gated-by clause.
720  if (failed(parser.parseOptionalKeyword("gated")))
721  return success();
722 
723  if (failed(parser.parseKeyword("by")) ||
724  failed(
725  parser.parseOperandList(clockGates, OpAsmParser::Delimiter::Square)))
726  return failure();
727 
728  return success();
729 }
730 
731 // Parses the form:
732 // regs( ($name `=`)? $register : type($register) (`gated by` `[` $clockGates
733 // `]`)?, ...)
735  OpAsmParser &parser,
736  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &registers,
737  llvm::SmallVector<mlir::Type, 1> &registerTypes,
738  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &clockGates,
739  ArrayAttr &clockGatesPerRegister, ArrayAttr &registerNames) {
740 
741  if (failed(parser.parseOptionalKeyword("regs"))) {
742  clockGatesPerRegister = parser.getBuilder().getI64ArrayAttr({});
743  return success(); // no registers to parse.
744  }
745 
746  llvm::SmallVector<int64_t> clockGatesPerRegisterList;
747  llvm::SmallVector<Attribute> registerNamesList;
748  bool withNames = false;
749  if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&]() {
750  OpAsmParser::UnresolvedOperand v;
751  Type t;
752  llvm::SmallVector<OpAsmParser::UnresolvedOperand> cgs;
753  StringAttr name;
754  if (parseSingleStageRegister(parser, v, t, cgs, name))
755  return failure();
756  registers.push_back(v);
757  registerTypes.push_back(t);
758  registerNamesList.push_back(name);
759  withNames |= static_cast<bool>(name);
760  llvm::append_range(clockGates, cgs);
761  clockGatesPerRegisterList.push_back(cgs.size());
762  return success();
763  })))
764  return failure();
765 
766  clockGatesPerRegister =
767  parser.getBuilder().getI64ArrayAttr(clockGatesPerRegisterList);
768  if (withNames)
769  registerNames = parser.getBuilder().getArrayAttr(registerNamesList);
770 
771  return success();
772 }
773 
774 void printStageRegisters(OpAsmPrinter &p, Operation *op, ValueRange registers,
775  TypeRange registerTypes, ValueRange clockGates,
776  ArrayAttr clockGatesPerRegister, ArrayAttr names) {
777  if (registers.empty())
778  return;
779 
780  p << "regs(";
781  size_t clockGateStartIdx = 0;
782  llvm::interleaveComma(
783  llvm::enumerate(
784  llvm::zip(registers, registerTypes, clockGatesPerRegister)),
785  p, [&](auto it) {
786  size_t idx = it.index();
787  auto &[reg, type, nClockGatesAttr] = it.value();
788  if (names) {
789  if (auto nameAttr = dyn_cast<StringAttr>(names[idx]);
790  nameAttr && !nameAttr.strref().empty())
791  p << nameAttr << " = ";
792  }
793 
794  p << reg << " : " << type;
795  int64_t nClockGates = cast<IntegerAttr>(nClockGatesAttr).getInt();
796  if (nClockGates == 0)
797  return;
798  p << " gated by [";
799  llvm::interleaveComma(clockGates.slice(clockGateStartIdx, nClockGates),
800  p);
801  p << "]";
802  clockGateStartIdx += nClockGates;
803  });
804  p << ")";
805 }
806 
807 void printPassthroughs(OpAsmPrinter &p, Operation *op, ValueRange passthroughs,
808  TypeRange passthroughTypes, ArrayAttr names) {
809 
810  if (passthroughs.empty())
811  return;
812 
813  p << "pass(";
814  llvm::interleaveComma(
815  llvm::enumerate(llvm::zip(passthroughs, passthroughTypes)), p,
816  [&](auto it) {
817  size_t idx = it.index();
818  auto &[reg, type] = it.value();
819  if (names) {
820  if (auto nameAttr = dyn_cast<StringAttr>(names[idx]);
821  nameAttr && !nameAttr.strref().empty())
822  p << nameAttr << " = ";
823  }
824  p << reg << " : " << type;
825  });
826  p << ")";
827 }
828 
829 // Parses the form:
830 // (`pass` `(` ($name `=`)? $register : type($register), ... `)` )?
831 ParseResult parsePassthroughs(
832  OpAsmParser &parser,
833  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &passthroughs,
834  llvm::SmallVector<mlir::Type, 1> &passthroughTypes,
835  ArrayAttr &passthroughNames) {
836  if (failed(parser.parseOptionalKeyword("pass")))
837  return success(); // no passthroughs to parse.
838 
839  llvm::SmallVector<Attribute> passthroughsNameList;
840  bool withNames = false;
841  if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&]() {
842  OpAsmParser::UnresolvedOperand v;
843  Type t;
844  StringAttr name;
845  if (parseOptNamedTypedAssignment(parser, v, t, name))
846  return failure();
847  passthroughs.push_back(v);
848  passthroughTypes.push_back(t);
849  passthroughsNameList.push_back(name);
850  withNames |= static_cast<bool>(name);
851  return success();
852  })))
853  return failure();
854 
855  if (withNames)
856  passthroughNames = parser.getBuilder().getArrayAttr(passthroughsNameList);
857 
858  return success();
859 }
860 
861 void StageOp::build(OpBuilder &odsBuilder, OperationState &odsState,
862  Block *dest, ValueRange registers,
863  ValueRange passthroughs) {
864  odsState.addSuccessors(dest);
865  odsState.addOperands(registers);
866  odsState.addOperands(passthroughs);
867  odsState.addAttribute("operandSegmentSizes",
868  odsBuilder.getDenseI32ArrayAttr(
869  {static_cast<int32_t>(registers.size()),
870  static_cast<int32_t>(passthroughs.size()),
871  /*clock gates*/ static_cast<int32_t>(0)}));
872  llvm::SmallVector<int64_t> clockGatesPerRegister(registers.size(), 0);
873  odsState.addAttribute("clockGatesPerRegister",
874  odsBuilder.getI64ArrayAttr(clockGatesPerRegister));
875 }
876 
877 void StageOp::build(OpBuilder &odsBuilder, OperationState &odsState,
878  Block *dest, ValueRange registers, ValueRange passthroughs,
879  llvm::ArrayRef<llvm::SmallVector<Value>> clockGateList,
880  mlir::ArrayAttr registerNames,
881  mlir::ArrayAttr passthroughNames) {
882  build(odsBuilder, odsState, dest, registers, passthroughs);
883 
884  llvm::SmallVector<Value> clockGates;
885  llvm::SmallVector<int64_t> clockGatesPerRegister(registers.size(), 0);
886  for (auto gates : clockGateList) {
887  llvm::append_range(clockGates, gates);
888  clockGatesPerRegister.push_back(gates.size());
889  }
890  odsState.attributes.set("clockGatesPerRegister",
891  odsBuilder.getI64ArrayAttr(clockGatesPerRegister));
892  odsState.addOperands(clockGates);
893 
894  if (registerNames)
895  odsState.addAttribute("registerNames", registerNames);
896 
897  if (passthroughNames)
898  odsState.addAttribute("passthroughNames", passthroughNames);
899 }
900 
901 ValueRange StageOp::getClockGatesForReg(unsigned regIdx) {
902  assert(regIdx < getRegisters().size() && "register index out of bounds.");
903 
904  // TODO: This could be optimized quite a bit if we didn't store clock
905  // gates per register as an array of sizes... look into using properties
906  // and maybe attaching a more complex datastructure to reduce compute
907  // here.
908 
909  unsigned clockGateStartIdx = 0;
910  for (auto [index, nClockGatesAttr] :
911  llvm::enumerate(getClockGatesPerRegister().getAsRange<IntegerAttr>())) {
912  int64_t nClockGates = nClockGatesAttr.getInt();
913  if (index == regIdx) {
914  // This is the register we are looking for.
915  return getClockGates().slice(clockGateStartIdx, nClockGates);
916  }
917  // Increment the start index by the number of clock gates for this
918  // register.
919  clockGateStartIdx += nClockGates;
920  }
921 
922  llvm_unreachable("register index out of bounds.");
923 }
924 
925 LogicalResult StageOp::verify() {
926  // Verify that the target block has the correct arguments as this stage
927  // op.
928  llvm::SmallVector<Type> expectedTargetArgTypes;
929  llvm::append_range(expectedTargetArgTypes, getRegisters().getTypes());
930  llvm::append_range(expectedTargetArgTypes, getPassthroughs().getTypes());
931  Block *targetStage = getNextStage();
932  // Expected types is everything but the stage valid signal.
933  TypeRange targetStageArgTypes =
934  TypeRange(targetStage->getArgumentTypes()).drop_back();
935 
936  if (targetStageArgTypes.size() != expectedTargetArgTypes.size())
937  return emitOpError("expected ") << expectedTargetArgTypes.size()
938  << " arguments in the target stage, got "
939  << targetStageArgTypes.size() << ".";
940 
941  for (auto [index, it] : llvm::enumerate(
942  llvm::zip(expectedTargetArgTypes, targetStageArgTypes))) {
943  auto [arg, barg] = it;
944  if (arg != barg)
945  return emitOpError("expected target stage argument ")
946  << index << " to have type " << arg << ", got " << barg << ".";
947  }
948 
949  // Verify that the clock gate index list is equally sized to the # of
950  // registers.
951  if (getClockGatesPerRegister().size() != getRegisters().size())
952  return emitOpError("expected clockGatesPerRegister to be equally sized to "
953  "the number of registers.");
954 
955  // Verify that, if provided, the list of register names is equally sized
956  // to the number of registers.
957  if (auto regNames = getRegisterNames()) {
958  if (regNames->size() != getRegisters().size())
959  return emitOpError("expected registerNames to be equally sized to "
960  "the number of registers.");
961  }
962 
963  // Verify that, if provided, the list of passthrough names is equally sized
964  // to the number of passthroughs.
965  if (auto passthroughNames = getPassthroughNames()) {
966  if (passthroughNames->size() != getPassthroughs().size())
967  return emitOpError("expected passthroughNames to be equally sized to "
968  "the number of passthroughs.");
969  }
970 
971  return success();
972 }
973 
974 //===----------------------------------------------------------------------===//
975 // LatencyOp
976 //===----------------------------------------------------------------------===//
977 
978 LogicalResult LatencyOp::verify() {
979  ScheduledPipelineOp scheduledPipelineParent =
980  dyn_cast<ScheduledPipelineOp>(getOperation()->getParentOp());
981 
982  if (!scheduledPipelineParent) {
983  // Nothing to verify, got to assume that anything goes in an unscheduled
984  // pipeline.
985  return success();
986  }
987 
988  // Verify that the resulting values aren't referenced before they are
989  // accessible.
990  size_t latency = getLatency();
991  Block *definingStage = getOperation()->getBlock();
992 
993  llvm::DenseMap<Block *, unsigned> stageMap =
994  scheduledPipelineParent.getStageMap();
995 
996  auto stageDistance = [&](Block *from, Block *to) {
997  assert(stageMap.count(from) && "stage 'from' not contained in pipeline");
998  assert(stageMap.count(to) && "stage 'to' not contained in pipeline");
999  int64_t fromStage = stageMap[from];
1000  int64_t toStage = stageMap[to];
1001  return toStage - fromStage;
1002  };
1003 
1004  for (auto [i, res] : llvm::enumerate(getResults())) {
1005  for (auto &use : res.getUses()) {
1006  auto *user = use.getOwner();
1007 
1008  // The user may reside within a block which is not a stage (e.g.
1009  // inside a pipeline.latency op). Determine the stage which this use
1010  // resides within.
1011  Block *userStage =
1012  getParentStageInPipeline(scheduledPipelineParent, user);
1013  unsigned useDistance = stageDistance(definingStage, userStage);
1014 
1015  // Is this a stage op and is the value passed through? if so, this is
1016  // a legal use.
1017  StageOp stageOp = dyn_cast<StageOp>(user);
1018  if (userStage == definingStage && stageOp) {
1019  if (llvm::is_contained(stageOp.getPassthroughs(), res))
1020  continue;
1021  }
1022 
1023  // The use is not a passthrough. Check that the distance between
1024  // the defining stage and the user stage is at least the latency of
1025  // the result.
1026  if (useDistance < latency) {
1027  auto diag = emitOpError("result ")
1028  << i << " is used before it is available.";
1029  diag.attachNote(user->getLoc())
1030  << "use was operand " << use.getOperandNumber()
1031  << ". The result is available " << latency - useDistance
1032  << " stages later than this use.";
1033  return diag;
1034  }
1035  }
1036  }
1037  return success();
1038 }
1039 
1040 //===----------------------------------------------------------------------===//
1041 // LatencyReturnOp
1042 //===----------------------------------------------------------------------===//
1043 
1044 LogicalResult LatencyReturnOp::verify() {
1045  LatencyOp parent = cast<LatencyOp>(getOperation()->getParentOp());
1046  size_t nInputs = getInputs().size();
1047  size_t nResults = parent->getNumResults();
1048  if (nInputs != nResults)
1049  return emitOpError("expected ")
1050  << nResults << " return values, got " << nInputs << ".";
1051 
1052  for (auto [inType, reqType] :
1053  llvm::zip(getInputs().getTypes(), parent->getResultTypes())) {
1054  if (inType != reqType)
1055  return emitOpError("expected return value of type ")
1056  << reqType << ", got " << inType << ".";
1057  }
1058 
1059  return success();
1060 }
1061 
1062 #define GET_OP_CLASSES
1063 #include "circt/Dialect/Pipeline/Pipeline.cpp.inc"
1064 
/// Registers the Pipeline dialect's operations with the dialect.
void PipelineDialect::initialize() {
  // The template argument list of addOperations<> is produced by re-including
  // the TableGen-generated Pipeline.cpp.inc with GET_OP_LIST defined. That
  // include expands to the list of generated op classes.
  addOperations<
#define GET_OP_LIST
#include "circt/Dialect/Pipeline/Pipeline.cpp.inc"
      >();
}
assert(baseType &&"element must be base type")
static InstancePath empty
llvm::SmallVector< StringAttr > inputs
Builder builder
static ParseResult parseSingleStageRegister(OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t, llvm::SmallVector< OpAsmParser::UnresolvedOperand > &clockGates, StringAttr &name)
static ParseResult parseKeywordAndOperand(OpAsmParser &p, StringRef keyword, OpAsmParser::UnresolvedOperand &op)
static void printOutputList(OpAsmPrinter &p, TypeRange types, ArrayAttr names)
ParseResult parsePassthroughs(OpAsmParser &parser, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &passthroughs, llvm::SmallVector< mlir::Type, 1 > &passthroughTypes, ArrayAttr &passthroughNames)
static void printPipelineOp(OpAsmPrinter &p, TPipelineOp op)
void printPassthroughs(OpAsmPrinter &p, Operation *op, ValueRange passthroughs, TypeRange passthroughTypes, ArrayAttr names)
static ParseResult parseOptNamedTypedAssignment(OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t, StringAttr &name)
static ParseResult parsePipelineOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
static FailureOr< llvm::SmallVector< Block * > > getOrderedStagesFailable(ScheduledPipelineOp op)
void printStageRegisters(OpAsmPrinter &p, Operation *op, ValueRange registers, TypeRange registerTypes, ValueRange clockGates, ArrayAttr clockGatesPerRegister, ArrayAttr names)
static void getPipelineAsmBlockArgumentNames(TPipelineOp op, mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn)
static ParseResult parseOutputList(OpAsmParser &parser, llvm::SmallVector< Type > &inputTypes, mlir::ArrayAttr &outputNames)
Definition: PipelineOps.cpp:85
static void printKeywordOperand(OpAsmPrinter &p, StringRef keyword, Value value)
ParseResult parseStageRegisters(OpAsmParser &parser, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &registers, llvm::SmallVector< mlir::Type, 1 > &registerTypes, llvm::SmallVector< OpAsmParser::UnresolvedOperand, 4 > &clockGates, ArrayAttr &clockGatesPerRegister, ArrayAttr &registerNames)
static void buildPipelineLikeOp(OpBuilder &odsBuilder, OperationState &odsState, TypeRange dataOutputs, ValueRange inputs, ArrayAttr inputNames, ArrayAttr outputNames, Value clock, Value reset, Value go, Value stall, StringAttr name, ArrayAttr stallability)
static void getPipelineAsmResultNames(TPipelineOp op, OpAsmSetValueNameFn setNameFn)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
ParseResult parseInitializerList(mlir::OpAsmParser &parser, llvm::SmallVector< mlir::OpAsmParser::Argument > &inputArguments, llvm::SmallVector< mlir::OpAsmParser::UnresolvedOperand > &inputOperands, llvm::SmallVector< Type > &inputTypes, ArrayAttr &inputNames)
Parses an initializer.
void printInitializerList(OpAsmPrinter &p, ValueRange ins, ArrayRef< BlockArgument > args)
llvm::SmallVector< Value > getValuesDefinedOutsideRegion(Region &region)
Definition: PipelineOps.cpp:32
Block * getParentStageInPipeline(ScheduledPipelineOp pipeline, Operation *op)
Definition: PipelineOps.cpp:66
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition: LLVM.h:186
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:20