CIRCT  19.0.0git
HandshakeOps.cpp
Go to the documentation of this file.
1 //===- HandshakeOps.cpp - Handshake MLIR Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the declaration of the Handshake operations struct.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "circt/Support/LLVM.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/IntegerSet.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/Interfaces/FunctionImplementation.h"
28 #include "mlir/Transforms/InliningUtils.h"
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <set>
34 
35 using namespace circt;
36 using namespace circt::handshake;
37 
38 namespace circt {
39 namespace handshake {
40 #include "circt/Dialect/Handshake/HandshakeCanonicalization.h.inc"
41 
42 bool isControlOpImpl(Operation *op) {
43  if (auto sostInterface = dyn_cast<SOSTInterface>(op); sostInterface)
44  return sostInterface.sostIsControl();
45 
46  return false;
47 }
48 
49 } // namespace handshake
50 } // namespace circt
51 
/// Builds the default printable name of the data operand at position `idx`,
/// i.e. "in" followed by the decimal index ("in0", "in1", ...).
static std::string defaultOperandName(unsigned int idx) {
  std::string name = "in";
  name += std::to_string(idx);
  return name;
}
55 
56 static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v) {
57  if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
58  return failure();
59  return success();
60 }
61 
62 static ParseResult
63 parseSostOperation(OpAsmParser &parser,
64  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
65  OperationState &result, int &size, Type &type,
66  bool explicitSize) {
67  if (explicitSize)
68  if (parseIntInSquareBrackets(parser, size))
69  return failure();
70 
71  if (parser.parseOperandList(operands) ||
72  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
73  parser.parseType(type))
74  return failure();
75 
76  if (!explicitSize)
77  size = operands.size();
78  return success();
79 }
80 
81 /// Verifies whether an indexing value is wide enough to index into a provided
82 /// number of operands.
83 static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal,
84  uint64_t numOperands) {
85  auto indexType = indexVal.getType();
86  unsigned indexWidth;
87 
88  // Determine the bitwidth of the indexing value
89  if (auto integerType = dyn_cast<IntegerType>(indexType))
90  indexWidth = integerType.getWidth();
91  else if (indexType.isIndex())
92  indexWidth = IndexType::kInternalStorageBitWidth;
93  else
94  return op->emitError("unsupported type for indexing value: ") << indexType;
95 
96  // Check whether the bitwidth can support the provided number of operands
97  if (indexWidth < 64) {
98  uint64_t maxNumOperands = (uint64_t)1 << indexWidth;
99  if (numOperands > maxNumOperands)
100  return op->emitError("bitwidth of indexing value is ")
101  << indexWidth << ", which can index into " << maxNumOperands
102  << " operands, but found " << numOperands << " operands";
103  }
104  return success();
105 }
106 
107 static bool isControlCheckTypeAndOperand(Type dataType, Value operand) {
108  // The operation is a control operation if its operand data type is a
109  // NoneType.
110  if (isa<NoneType>(dataType))
111  return true;
112 
113  // Otherwise, the operation is a control operation if the operation's
114  // operand originates from the control network
115  auto *defOp = operand.getDefiningOp();
116  return isa_and_nonnull<ControlMergeOp>(defOp) &&
117  operand == defOp->getResult(0);
118 }
119 
120 template <typename TMemOp>
121 llvm::SmallVector<handshake::MemLoadInterface> getLoadPorts(TMemOp op) {
122  llvm::SmallVector<MemLoadInterface> ports;
123  // Memory interface refresher:
124  // Operands:
125  // all stores (stdata1, staddr1, stdata2, staddr2, ...)
126  // then all loads (ldaddr1, ldaddr2,...)
127  // Outputs: load addresses (lddata1, lddata2, ...), followed by all none
128  // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
129  unsigned stCount = op.getStCount();
130  unsigned ldCount = op.getLdCount();
131  for (unsigned i = 0, e = ldCount; i != e; ++i) {
132  MemLoadInterface ldif;
133  ldif.index = i;
134  ldif.addressIn = op.getInputs()[stCount * 2 + i];
135  ldif.dataOut = op.getResult(i);
136  ldif.doneOut = op.getResult(ldCount + stCount + i);
137  ports.push_back(ldif);
138  }
139  return ports;
140 }
141 
142 template <typename TMemOp>
143 llvm::SmallVector<handshake::MemStoreInterface> getStorePorts(TMemOp op) {
144  llvm::SmallVector<MemStoreInterface> ports;
145  // Memory interface refresher:
146  // Operands:
147  // all stores (stdata1, staddr1, stdata2, staddr2, ...)
148  // then all loads (ldaddr1, ldaddr2,...)
149  // Outputs: load data (lddata1, lddata2, ...), followed by all none
150  // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
151  unsigned ldCount = op.getLdCount();
152  for (unsigned i = 0, e = op.getStCount(); i != e; ++i) {
153  MemStoreInterface stif;
154  stif.index = i;
155  stif.dataIn = op.getInputs()[i * 2];
156  stif.addressIn = op.getInputs()[i * 2 + 1];
157  stif.doneOut = op.getResult(ldCount + i);
158  ports.push_back(stif);
159  }
160  return ports;
161 }
162 
163 unsigned ForkOp::getSize() { return getResults().size(); }
164 
165 static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result) {
166  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
167  Type type;
168  ArrayRef<Type> operandTypes(type);
169  SmallVector<Type, 1> resultTypes;
170  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
171  int size;
172  if (parseSostOperation(parser, allOperands, result, size, type, true))
173  return failure();
174 
175  resultTypes.assign(size, type);
176  result.addTypes(resultTypes);
177  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
178  result.operands))
179  return failure();
180  return success();
181 }
182 
183 ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
184  return parseForkOp(parser, result);
185 }
186 
187 void ForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
188 
189 namespace {
190 
191 struct EliminateUnusedForkResultsPattern : mlir::OpRewritePattern<ForkOp> {
193 
194  LogicalResult matchAndRewrite(ForkOp op,
195  PatternRewriter &rewriter) const override {
196  std::set<unsigned> unusedIndexes;
197 
198  for (auto res : llvm::enumerate(op.getResults()))
199  if (res.value().getUses().empty())
200  unusedIndexes.insert(res.index());
201 
202  if (unusedIndexes.empty())
203  return failure();
204 
205  // Create a new fork op, dropping the unused results.
206  rewriter.setInsertionPoint(op);
207  auto operand = op.getOperand();
208  auto newFork = rewriter.create<ForkOp>(
209  op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
210  rewriter.modifyOpInPlace(op, [&] {
211  unsigned i = 0;
212  for (auto oldRes : llvm::enumerate(op.getResults()))
213  if (unusedIndexes.count(oldRes.index()) == 0)
214  oldRes.value().replaceAllUsesWith(newFork.getResults()[i++]);
215  });
216  rewriter.eraseOp(op);
217  return success();
218  }
219 };
220 
221 struct EliminateForkToForkPattern : mlir::OpRewritePattern<ForkOp> {
223 
224  LogicalResult matchAndRewrite(ForkOp op,
225  PatternRewriter &rewriter) const override {
226  auto parentForkOp = op.getOperand().getDefiningOp<ForkOp>();
227  if (!parentForkOp)
228  return failure();
229 
230  /// Create the fork with as many outputs as the two source forks.
231  /// Keeping the op.operand() output may or may not be redundant (dependning
232  /// on if op is the single user of the value), but we'll let
233  /// EliminateUnusedForkResultsPattern apply in that case.
234  unsigned totalNumOuts = op.getSize() + parentForkOp.getSize();
235  rewriter.modifyOpInPlace(parentForkOp, [&] {
236  /// Create a new parent fork op which produces all of the fork outputs and
237  /// replace all of the uses of the old results.
238  auto newParentForkOp = rewriter.create<ForkOp>(
239  parentForkOp.getLoc(), parentForkOp.getOperand(), totalNumOuts);
240 
241  for (auto it :
242  llvm::zip(parentForkOp->getResults(), newParentForkOp.getResults()))
243  std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
244 
245  /// Replace the results of the matches fork op with the corresponding
246  /// results of the new parent fork op.
247  rewriter.replaceOp(op,
248  newParentForkOp.getResults().take_back(op.getSize()));
249  });
250  rewriter.eraseOp(parentForkOp);
251  return success();
252  }
253 };
254 
255 } // namespace
256 
257 void handshake::ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
258  MLIRContext *context) {
259  results.insert<circt::handshake::EliminateSimpleForksPattern,
260  EliminateForkToForkPattern, EliminateUnusedForkResultsPattern>(
261  context);
262 }
263 
264 unsigned LazyForkOp::getSize() { return getResults().size(); }
265 
266 bool LazyForkOp::sostIsControl() {
267  return isControlCheckTypeAndOperand(getDataType(), getOperand());
268 }
269 
270 ParseResult LazyForkOp::parse(OpAsmParser &parser, OperationState &result) {
271  return parseForkOp(parser, result);
272 }
273 
274 void LazyForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
275 
276 ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
277  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
278  Type type;
279  ArrayRef<Type> operandTypes(type);
280  SmallVector<Type, 1> resultTypes, dataOperandsTypes;
281  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
282  int size;
283  if (parseSostOperation(parser, allOperands, result, size, type, false))
284  return failure();
285 
286  dataOperandsTypes.assign(size, type);
287  resultTypes.push_back(type);
288  result.addTypes(resultTypes);
289  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
290  result.operands))
291  return failure();
292  return success();
293 }
294 
295 void MergeOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
296 
297 void MergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
298  MLIRContext *context) {
299  results.insert<circt::handshake::EliminateSimpleMergesPattern>(context);
300 }
301 
302 /// Returns a dematerialized version of the value 'v', defined as the source of
303 /// the value before passing through a buffer or fork operation.
304 static Value getDematerialized(Value v) {
305  Operation *parentOp = v.getDefiningOp();
306  if (!parentOp)
307  return v;
308 
309  return llvm::TypeSwitch<Operation *, Value>(parentOp)
310  .Case<ForkOp>(
311  [&](ForkOp op) { return getDematerialized(op.getOperand()); })
312  .Case<BufferOp>(
313  [&](BufferOp op) { return getDematerialized(op.getOperand()); })
314  .Default([&](auto) { return v; });
315 }
316 
317 namespace {
318 
319 /// Eliminates muxes with identical data inputs. Data inputs are inspected as
320 /// their dematerialized versions. This has the side effect of any subsequently
321 /// unused buffers are DCE'd and forks are optimized to be narrower.
322 struct EliminateSimpleMuxesPattern : mlir::OpRewritePattern<MuxOp> {
324  LogicalResult matchAndRewrite(MuxOp op,
325  PatternRewriter &rewriter) const override {
326  Value firstDataOperand = getDematerialized(op.getDataOperands()[0]);
327  if (!llvm::all_of(op.getDataOperands(), [&](Value operand) {
328  return getDematerialized(operand) == firstDataOperand;
329  }))
330  return failure();
331  rewriter.replaceOp(op, firstDataOperand);
332  return success();
333  }
334 };
335 
336 struct EliminateUnaryMuxesPattern : OpRewritePattern<MuxOp> {
338  LogicalResult matchAndRewrite(MuxOp op,
339  PatternRewriter &rewriter) const override {
340  if (op.getDataOperands().size() != 1)
341  return failure();
342 
343  rewriter.replaceOp(op, op.getDataOperands()[0]);
344  return success();
345  }
346 };
347 
348 struct EliminateCBranchIntoMuxPattern : OpRewritePattern<MuxOp> {
350  LogicalResult matchAndRewrite(MuxOp op,
351  PatternRewriter &rewriter) const override {
352 
353  auto dataOperands = op.getDataOperands();
354  if (dataOperands.size() != 2)
355  return failure();
356 
357  // Both data operands must originate from the same cbranch
358  ConditionalBranchOp firstParentCBranch =
359  dataOperands[0].getDefiningOp<ConditionalBranchOp>();
360  if (!firstParentCBranch)
361  return failure();
362  auto secondParentCBranch =
363  dataOperands[1].getDefiningOp<ConditionalBranchOp>();
364  if (!secondParentCBranch || firstParentCBranch != secondParentCBranch)
365  return failure();
366 
367  rewriter.modifyOpInPlace(firstParentCBranch, [&] {
368  // Replace uses of the mux's output with cbranch's data input
369  rewriter.replaceOp(op, firstParentCBranch.getDataOperand());
370  });
371 
372  return success();
373  }
374 };
375 
376 } // namespace
377 
378 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
379  MLIRContext *context) {
380  results.insert<EliminateSimpleMuxesPattern, EliminateUnaryMuxesPattern,
381  EliminateCBranchIntoMuxPattern>(context);
382 }
383 
384 LogicalResult
385 MuxOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
386  ValueRange operands, DictionaryAttr attributes,
387  mlir::OpaqueProperties properties,
388  mlir::RegionRange regions,
389  SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
390  // MuxOp must have at least one data operand (in addition to the select
391  // operand)
392  if (operands.size() < 2)
393  return failure();
394  // Result type is type of any data operand
395  inferredReturnTypes.push_back(operands[1].getType());
396  return success();
397 }
398 
399 bool MuxOp::isControl() { return isa<NoneType>(getResult().getType()); }
400 
401 std::string handshake::MuxOp::getOperandName(unsigned int idx) {
402  return idx == 0 ? "select" : defaultOperandName(idx - 1);
403 }
404 
405 ParseResult MuxOp::parse(OpAsmParser &parser, OperationState &result) {
406  OpAsmParser::UnresolvedOperand selectOperand;
407  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
408  Type selectType, dataType;
409  SmallVector<Type, 1> dataOperandsTypes;
410  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
411  if (parser.parseOperand(selectOperand) || parser.parseLSquare() ||
412  parser.parseOperandList(allOperands) || parser.parseRSquare() ||
413  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
414  parser.parseType(selectType) || parser.parseComma() ||
415  parser.parseType(dataType))
416  return failure();
417 
418  int size = allOperands.size();
419  dataOperandsTypes.assign(size, dataType);
420  result.addTypes(dataType);
421  allOperands.insert(allOperands.begin(), selectOperand);
422  if (parser.resolveOperands(
423  allOperands,
424  llvm::concat<const Type>(ArrayRef<Type>(selectType),
425  ArrayRef<Type>(dataOperandsTypes)),
426  allOperandLoc, result.operands))
427  return failure();
428  return success();
429 }
430 
431 void MuxOp::print(OpAsmPrinter &p) {
432  Type selectType = getSelectOperand().getType();
433  auto ops = getOperands();
434  p << ' ' << ops.front();
435  p << " [";
436  p.printOperands(ops.drop_front());
437  p << "]";
438  p.printOptionalAttrDict((*this)->getAttrs());
439  p << " : " << selectType << ", " << getResult().getType();
440 }
441 
442 LogicalResult MuxOp::verify() {
443  return verifyIndexWideEnough(*this, getSelectOperand(),
444  getDataOperands().size());
445 }
446 
447 std::string handshake::ControlMergeOp::getResultName(unsigned int idx) {
448  assert(idx == 0 || idx == 1);
449  return idx == 0 ? "dataOut" : "index";
450 }
451 
452 ParseResult ControlMergeOp::parse(OpAsmParser &parser, OperationState &result) {
453  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
454  Type resultType, indexType;
455  SmallVector<Type> resultTypes, dataOperandsTypes;
456  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
457  int size;
458  if (parseSostOperation(parser, allOperands, result, size, resultType, false))
459  return failure();
460  // Parse type of index result
461  if (parser.parseComma() || parser.parseType(indexType))
462  return failure();
463 
464  dataOperandsTypes.assign(size, resultType);
465  resultTypes.push_back(resultType);
466  resultTypes.push_back(indexType);
467  result.addTypes(resultTypes);
468  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
469  result.operands))
470  return failure();
471  return success();
472 }
473 
474 void ControlMergeOp::print(OpAsmPrinter &p) {
475  sostPrint(p, false);
476  // Print type of index result
477  p << ", " << getIndex().getType();
478 }
479 
480 LogicalResult ControlMergeOp::verify() {
481  auto operands = getOperands();
482  if (operands.empty())
483  return emitOpError("operation must have at least one operand");
484  if (operands[0].getType() != getResult().getType())
485  return emitOpError("type of first result should match type of operands");
486  return verifyIndexWideEnough(*this, getIndex(), getNumOperands());
487 }
488 
489 LogicalResult FuncOp::verify() {
490  // If this function is external there is nothing to do.
491  if (isExternal())
492  return success();
493 
494  // Verify that the argument list of the function and the arg list of the
495  // entry block line up. The trait already verified that the number of
496  // arguments is the same between the signature and the block.
497  auto fnInputTypes = getArgumentTypes();
498  Block &entryBlock = front();
499 
500  for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
501  if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
502  return emitOpError("type of entry block argument #")
503  << i << '(' << entryBlock.getArgument(i).getType()
504  << ") must match the type of the corresponding argument in "
505  << "function signature(" << fnInputTypes[i] << ')';
506 
507  // Verify that we have a name for each argument and result of this function.
508  auto verifyPortNameAttr = [&](StringRef attrName,
509  unsigned numIOs) -> LogicalResult {
510  auto portNamesAttr = (*this)->getAttrOfType<ArrayAttr>(attrName);
511 
512  if (!portNamesAttr)
513  return emitOpError() << "expected attribute '" << attrName << "'.";
514 
515  auto portNames = portNamesAttr.getValue();
516  if (portNames.size() != numIOs)
517  return emitOpError() << "attribute '" << attrName << "' has "
518  << portNames.size()
519  << " entries but is expected to have " << numIOs
520  << ".";
521 
522  if (llvm::any_of(portNames,
523  [&](Attribute attr) { return !isa<StringAttr>(attr); }))
524  return emitOpError() << "expected all entries in attribute '" << attrName
525  << "' to be strings.";
526 
527  return success();
528  };
529  if (failed(verifyPortNameAttr("argNames", getNumArguments())))
530  return failure();
531  if (failed(verifyPortNameAttr("resNames", getNumResults())))
532  return failure();
533 
534  // Verify that all memrefs have a corresponding extmemory operation
535  for (auto arg : entryBlock.getArguments()) {
536  if (!isa<MemRefType>(arg.getType()))
537  continue;
538  if (arg.getUsers().empty() ||
539  !isa<ExternalMemoryOp>(*arg.getUsers().begin()))
540  return emitOpError("expected that block argument #")
541  << arg.getArgNumber() << " is used by an 'extmemory' operation";
542  }
543 
544  return success();
545 }
546 
547 /// Parses a FuncOp signature using
548 /// mlir::function_interface_impl::parseFunctionSignature while getting access
549 /// to the parsed SSA names to store as attributes.
550 static ParseResult
551 parseFuncOpArgs(OpAsmParser &parser,
552  SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
553  SmallVectorImpl<Type> &resTypes,
554  SmallVectorImpl<DictionaryAttr> &resAttrs) {
555  bool isVariadic;
556  if (mlir::function_interface_impl::parseFunctionSignature(
557  parser, /*allowVariadic=*/true, entryArgs, isVariadic, resTypes,
558  resAttrs)
559  .failed())
560  return failure();
561 
562  return success();
563 }
564 
565 /// Generates names for a handshake.func input and output arguments, based on
566 /// the number of args as well as a prefix.
567 static SmallVector<Attribute> getFuncOpNames(Builder &builder, unsigned cnt,
568  StringRef prefix) {
569  SmallVector<Attribute> resNames;
570  for (unsigned i = 0; i < cnt; ++i)
571  resNames.push_back(builder.getStringAttr(prefix + std::to_string(i)));
572  return resNames;
573 }
574 
575 void handshake::FuncOp::build(OpBuilder &builder, OperationState &state,
576  StringRef name, FunctionType type,
577  ArrayRef<NamedAttribute> attrs) {
578  state.addAttribute(SymbolTable::getSymbolAttrName(),
579  builder.getStringAttr(name));
580  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
581  TypeAttr::get(type));
582  state.attributes.append(attrs.begin(), attrs.end());
583 
584  if (const auto *argNamesAttrIt = llvm::find_if(
585  attrs, [&](auto attr) { return attr.getName() == "argNames"; });
586  argNamesAttrIt == attrs.end())
587  state.addAttribute("argNames", builder.getArrayAttr({}));
588 
589  if (llvm::find_if(attrs, [&](auto attr) {
590  return attr.getName() == "resNames";
591  }) == attrs.end())
592  state.addAttribute("resNames", builder.getArrayAttr({}));
593 
594  state.addRegion();
595 }
596 
597 /// Helper function for appending a string to an array attribute, and
598 /// rewriting the attribute back to the operation.
599 static void addStringToStringArrayAttr(Builder &builder, Operation *op,
600  StringRef attrName, StringAttr str) {
601  llvm::SmallVector<Attribute> attrs;
602  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
603  std::back_inserter(attrs));
604  attrs.push_back(str);
605  op->setAttr(attrName, builder.getArrayAttr(attrs));
606 }
607 
608 void handshake::FuncOp::resolveArgAndResNames() {
609  Builder builder(getContext());
610 
611  /// Generate a set of fallback names. These are used in case names are
612  /// missing from the currently set arg- and res name attributes.
613  auto fallbackArgNames = getFuncOpNames(builder, getNumArguments(), "in");
614  auto fallbackResNames = getFuncOpNames(builder, getNumResults(), "out");
615  auto argNames = getArgNames().getValue();
616  auto resNames = getResNames().getValue();
617 
618  /// Use fallback names where actual names are missing.
619  auto resolveNames = [&](auto &fallbackNames, auto &actualNames,
620  StringRef attrName) {
621  for (auto fallbackName : llvm::enumerate(fallbackNames)) {
622  if (actualNames.size() <= fallbackName.index())
623  addStringToStringArrayAttr(builder, this->getOperation(), attrName,
624  cast<StringAttr>(fallbackName.value()));
625  }
626  };
627  resolveNames(fallbackArgNames, argNames, "argNames");
628  resolveNames(fallbackResNames, resNames, "resNames");
629 }
630 
631 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
632  auto &builder = parser.getBuilder();
633  StringAttr nameAttr;
634  SmallVector<OpAsmParser::Argument> args;
635  SmallVector<Type> resTypes;
636  SmallVector<DictionaryAttr> resAttributes;
637  SmallVector<Attribute> argNames;
638 
639  // Parse visibility.
640  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
641 
642  // Parse signature
643  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
644  result.attributes) ||
645  parseFuncOpArgs(parser, args, resTypes, resAttributes))
646  return failure();
647  mlir::function_interface_impl::addArgAndResultAttrs(
648  builder, result, args, resAttributes,
649  handshake::FuncOp::getArgAttrsAttrName(result.name),
650  handshake::FuncOp::getResAttrsAttrName(result.name));
651 
652  // Set function type
653  SmallVector<Type> argTypes;
654  for (auto arg : args)
655  argTypes.push_back(arg.type);
656 
657  result.addAttribute(
658  handshake::FuncOp::getFunctionTypeAttrName(result.name),
659  TypeAttr::get(builder.getFunctionType(argTypes, resTypes)));
660 
661  // Determine the names of the arguments. If no SSA values are present, use
662  // fallback names.
663  bool noSSANames =
664  llvm::any_of(args, [](auto arg) { return arg.ssaName.name.empty(); });
665  if (noSSANames) {
666  argNames = getFuncOpNames(builder, args.size(), "in");
667  } else {
668  llvm::transform(args, std::back_inserter(argNames), [&](auto arg) {
669  return builder.getStringAttr(arg.ssaName.name.drop_front());
670  });
671  }
672 
673  // Parse attributes
674  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
675  return failure();
676 
677  // If argNames and resNames wasn't provided manually, infer argNames attribute
678  // from the parsed SSA names and resNames from our naming convention.
679  if (!result.attributes.get("argNames"))
680  result.addAttribute("argNames", builder.getArrayAttr(argNames));
681  if (!result.attributes.get("resNames")) {
682  auto resNames = getFuncOpNames(builder, resTypes.size(), "out");
683  result.addAttribute("resNames", builder.getArrayAttr(resNames));
684  }
685 
686  // Parse the optional function body. The printer will not print the body if
687  // its empty, so disallow parsing of empty body in the parser.
688  auto *body = result.addRegion();
689  llvm::SMLoc loc = parser.getCurrentLocation();
690  auto parseResult = parser.parseOptionalRegion(*body, args,
691  /*enableNameShadowing=*/false);
692  if (!parseResult.has_value())
693  return success();
694 
695  if (failed(*parseResult))
696  return failure();
697  // Function body was parsed, make sure its not empty.
698  if (body->empty())
699  return parser.emitError(loc, "expected non-empty function body");
700 
701  // If a body was parsed, the arg and res names need to be resolved
702  return success();
703 }
704 
705 void FuncOp::print(OpAsmPrinter &p) {
706  mlir::function_interface_impl::printFunctionOp(
707  p, *this, /*isVariadic=*/true, getFunctionTypeAttrName(),
708  getArgAttrsAttrName(), getResAttrsAttrName());
709 }
710 
711 namespace {
712 struct EliminateSimpleControlMergesPattern
713  : mlir::OpRewritePattern<ControlMergeOp> {
715 
716  LogicalResult matchAndRewrite(ControlMergeOp op,
717  PatternRewriter &rewriter) const override;
718 };
719 } // namespace
720 
721 LogicalResult EliminateSimpleControlMergesPattern::matchAndRewrite(
722  ControlMergeOp op, PatternRewriter &rewriter) const {
723  auto dataResult = op.getResult();
724  auto choiceResult = op.getIndex();
725  auto choiceUnused = choiceResult.use_empty();
726  if (!choiceUnused && !choiceResult.hasOneUse())
727  return failure();
728 
729  Operation *choiceUser = nullptr;
730  if (choiceResult.hasOneUse()) {
731  choiceUser = choiceResult.getUses().begin().getUser();
732  if (!isa<SinkOp>(choiceUser))
733  return failure();
734  }
735 
736  auto merge = rewriter.create<MergeOp>(op.getLoc(), op.getDataOperands());
737 
738  for (auto &use : llvm::make_early_inc_range(dataResult.getUses())) {
739  auto *user = use.getOwner();
740  rewriter.modifyOpInPlace(
741  user, [&]() { user->setOperand(use.getOperandNumber(), merge); });
742  }
743 
744  if (choiceUnused) {
745  rewriter.eraseOp(op);
746  return success();
747  }
748 
749  rewriter.eraseOp(choiceUser);
750  rewriter.eraseOp(op);
751  return success();
752 }
753 
754 void ControlMergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
755  MLIRContext *context) {
756  results.insert<EliminateSimpleControlMergesPattern>(context);
757 }
758 
759 bool BranchOp::sostIsControl() {
760  return isControlCheckTypeAndOperand(getDataType(), getOperand());
761 }
762 
763 void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
764  MLIRContext *context) {
765  results.insert<circt::handshake::EliminateSimpleBranchesPattern>(context);
766 }
767 
768 ParseResult BranchOp::parse(OpAsmParser &parser, OperationState &result) {
769  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
770  Type type;
771  ArrayRef<Type> operandTypes(type);
772  SmallVector<Type, 1> dataOperandsTypes;
773  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
774  int size;
775  if (parseSostOperation(parser, allOperands, result, size, type, false))
776  return failure();
777 
778  dataOperandsTypes.assign(size, type);
779  result.addTypes({type});
780  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
781  result.operands))
782  return failure();
783  return success();
784 }
785 
786 void BranchOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
787 
788 ParseResult ConditionalBranchOp::parse(OpAsmParser &parser,
789  OperationState &result) {
790  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
791  Type dataType;
792  SmallVector<Type> operandTypes;
793  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
794  if (parser.parseOperandList(allOperands) ||
795  parser.parseOptionalAttrDict(result.attributes) ||
796  parser.parseColonType(dataType))
797  return failure();
798 
799  if (allOperands.size() != 2)
800  return parser.emitError(parser.getCurrentLocation(),
801  "Expected exactly 2 operands");
802 
803  result.addTypes({dataType, dataType});
804  operandTypes.push_back(IntegerType::get(parser.getContext(), 1));
805  operandTypes.push_back(dataType);
806  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
807  result.operands))
808  return failure();
809 
810  return success();
811 }
812 
813 void ConditionalBranchOp::print(OpAsmPrinter &p) {
814  Type type = getDataOperand().getType();
815  p << " " << getOperands();
816  p.printOptionalAttrDict((*this)->getAttrs());
817  p << " : " << type;
818 }
819 
820 std::string handshake::ConditionalBranchOp::getOperandName(unsigned int idx) {
821  assert(idx == 0 || idx == 1);
822  return idx == 0 ? "cond" : "data";
823 }
824 
825 std::string handshake::ConditionalBranchOp::getResultName(unsigned int idx) {
826  assert(idx == 0 || idx == 1);
827  return idx == ConditionalBranchOp::falseIndex ? "outFalse" : "outTrue";
828 }
829 
830 bool ConditionalBranchOp::isControl() {
831  return isControlCheckTypeAndOperand(getDataOperand().getType(),
832  getDataOperand());
833 }
834 
835 ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
836  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
837  Type type;
838  ArrayRef<Type> operandTypes(type);
839  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
840  int size;
841  if (parseSostOperation(parser, allOperands, result, size, type, false))
842  return failure();
843 
844  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
845  result.operands))
846  return failure();
847  return success();
848 }
849 
850 void SinkOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
851 
852 std::string handshake::ConstantOp::getOperandName(unsigned int idx) {
853  assert(idx == 0);
854  return "ctrl";
855 }
856 
857 Type SourceOp::getDataType() { return getResult().getType(); }
858 unsigned SourceOp::getSize() { return 1; }
859 
860 ParseResult SourceOp::parse(OpAsmParser &parser, OperationState &result) {
861  if (parser.parseOptionalAttrDict(result.attributes))
862  return failure();
863  result.addTypes(NoneType::get(result.getContext()));
864  return success();
865 }
866 
867 void SourceOp::print(OpAsmPrinter &p) {
868  p.printOptionalAttrDict((*this)->getAttrs());
869 }
870 
871 LogicalResult ConstantOp::verify() {
872  // Verify that the type of the provided value is equal to the result type.
873  auto typedValue = dyn_cast<mlir::TypedAttr>(getValue());
874  if (!typedValue)
875  return emitOpError("constant value must be a typed attribute; value is ")
876  << getValue();
877  if (typedValue.getType() != getResult().getType())
878  return emitOpError() << "constant value type " << typedValue.getType()
879  << " differs from operation result type "
880  << getResult().getType();
881 
882  return success();
883 }
884 
885 void handshake::ConstantOp::getCanonicalizationPatterns(
886  RewritePatternSet &results, MLIRContext *context) {
887  results.insert<circt::handshake::EliminateSunkConstantsPattern>(context);
888 }
889 
890 LogicalResult BufferOp::verify() {
891  // Verify that exactly 'size' number of initial values have been provided, if
892  // an initializer list have been provided.
893  if (auto initVals = getInitValues()) {
894  if (!isSequential())
895  return emitOpError()
896  << "only bufferType buffers are allowed to have initial values.";
897 
898  auto nInits = initVals->size();
899  if (nInits != getSize())
900  return emitOpError() << "expected " << getSize()
901  << " init values but got " << nInits << ".";
902  }
903 
904  return success();
905 }
906 
907 void handshake::BufferOp::getCanonicalizationPatterns(
908  RewritePatternSet &results, MLIRContext *context) {
909  results.insert<circt::handshake::EliminateSunkBuffersPattern>(context);
910 }
911 
912 unsigned BufferOp::getSize() {
913  return (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
914 }
915 
916 ParseResult BufferOp::parse(OpAsmParser &parser, OperationState &result) {
917  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
918  Type type;
919  ArrayRef<Type> operandTypes(type);
920  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
921  int slots;
922  if (parseIntInSquareBrackets(parser, slots))
923  return failure();
924 
925  auto bufferTypeAttr = BufferTypeEnumAttr::parse(parser, {});
926  if (!bufferTypeAttr)
927  return failure();
928 
929  result.addAttribute(
930  "slots",
931  IntegerAttr::get(IntegerType::get(result.getContext(), 32), slots));
932  result.addAttribute("bufferType", bufferTypeAttr);
933 
934  if (parser.parseOperandList(allOperands) ||
935  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
936  parser.parseType(type))
937  return failure();
938 
939  result.addTypes({type});
940  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
941  result.operands))
942  return failure();
943  return success();
944 }
945 
946 void BufferOp::print(OpAsmPrinter &p) {
947  int size =
948  (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
949  p << " [" << size << "]";
950  p << " " << stringifyEnum(getBufferType());
951  p << " " << (*this)->getOperands();
952  p.printOptionalAttrDict((*this)->getAttrs(), {"slots", "bufferType"});
953  p << " : " << (*this).getDataType();
954 }
955 
/// Names a memory operand by index. The first 2*nStores operands are
/// interleaved (stDataN, stAddrN) pairs, one pair per store port. Every
/// operand after them is a load address, named ldAddrN.
static std::string getMemoryOperandName(unsigned nStores, unsigned idx) {
  const unsigned nStoreOperands = nStores * 2;
  if (idx >= nStoreOperands)
    return "ldAddr" + std::to_string(idx - nStoreOperands);

  const std::string port = std::to_string(idx / 2);
  return (idx % 2 == 0 ? "stData" : "stAddr") + port;
}
968 
969 std::string handshake::MemoryOp::getOperandName(unsigned int idx) {
970  return getMemoryOperandName(getStCount(), idx);
971 }
972 
/// Names a memory result by index. The layout is one ldDataN per load port,
/// then one stDoneN per store port, then one ldDoneN per load port.
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores,
                                       unsigned idx) {
  const unsigned storeDoneEnd = nLoads + nStores;
  if (idx < nLoads)
    return "ldData" + std::to_string(idx);
  if (idx < storeDoneEnd)
    return "stDone" + std::to_string(idx - nLoads);
  return "ldDone" + std::to_string(idx - storeDoneEnd);
}
984 
985 std::string handshake::MemoryOp::getResultName(unsigned int idx) {
986  return getMemoryResultName(getLdCount(), getStCount(), idx);
987 }
988 
989 LogicalResult MemoryOp::verify() {
990  auto memrefType = getMemRefType();
991 
992  if (memrefType.getNumDynamicDims() != 0)
993  return emitOpError()
994  << "memref dimensions for handshake.memory must be static.";
995  if (memrefType.getShape().size() != 1)
996  return emitOpError() << "memref must have only a single dimension.";
997 
998  unsigned opStCount = getStCount();
999  unsigned opLdCount = getLdCount();
1000  int addressCount = memrefType.getShape().size();
1001 
1002  auto inputType = getInputs().getType();
1003  auto outputType = getOutputs().getType();
1004  Type dataType = memrefType.getElementType();
1005 
1006  unsigned numOperands = static_cast<int>(getInputs().size());
1007  unsigned numResults = static_cast<int>(getOutputs().size());
1008  if (numOperands != (1 + addressCount) * opStCount + addressCount * opLdCount)
1009  return emitOpError("number of operands ")
1010  << numOperands << " does not match number expected of "
1011  << 2 * opStCount + opLdCount << " with " << addressCount
1012  << " address inputs per port";
1013 
1014  if (numResults != opStCount + 2 * opLdCount)
1015  return emitOpError("number of results ")
1016  << numResults << " does not match number expected of "
1017  << opStCount + 2 * opLdCount << " with " << addressCount
1018  << " address inputs per port";
1019 
1020  Type addressType = opStCount > 0 ? inputType[1] : inputType[0];
1021 
1022  for (unsigned i = 0; i < opStCount; i++) {
1023  if (inputType[2 * i] != dataType)
1024  return emitOpError("data type for store port ")
1025  << i << ":" << inputType[2 * i] << " doesn't match memory type "
1026  << dataType;
1027  if (inputType[2 * i + 1] != addressType)
1028  return emitOpError("address type for store port ")
1029  << i << ":" << inputType[2 * i + 1]
1030  << " doesn't match address type " << addressType;
1031  }
1032  for (unsigned i = 0; i < opLdCount; i++) {
1033  Type ldAddressType = inputType[2 * opStCount + i];
1034  if (ldAddressType != addressType)
1035  return emitOpError("address type for load port ")
1036  << i << ":" << ldAddressType << " doesn't match address type "
1037  << addressType;
1038  }
1039  for (unsigned i = 0; i < opLdCount; i++) {
1040  if (outputType[i] != dataType)
1041  return emitOpError("data type for load port ")
1042  << i << ":" << outputType[i] << " doesn't match memory type "
1043  << dataType;
1044  }
1045  for (unsigned i = 0; i < opStCount; i++) {
1046  Type syncType = outputType[opLdCount + i];
1047  if (!isa<NoneType>(syncType))
1048  return emitOpError("data type for sync port for store port ")
1049  << i << ":" << syncType << " is not 'none'";
1050  }
1051  for (unsigned i = 0; i < opLdCount; i++) {
1052  Type syncType = outputType[opLdCount + opStCount + i];
1053  if (!isa<NoneType>(syncType))
1054  return emitOpError("data type for sync port for load port ")
1055  << i << ":" << syncType << " is not 'none'";
1056  }
1057 
1058  return success();
1059 }
1060 
1061 std::string handshake::ExternalMemoryOp::getOperandName(unsigned int idx) {
1062  if (idx == 0)
1063  return "extmem";
1064 
1065  return getMemoryOperandName(getStCount(), idx - 1);
1066 }
1067 
1068 std::string handshake::ExternalMemoryOp::getResultName(unsigned int idx) {
1069  return getMemoryResultName(getLdCount(), getStCount(), idx);
1070 }
1071 
1072 void ExternalMemoryOp::build(OpBuilder &builder, OperationState &result,
1073  Value memref, ValueRange inputs, int ldCount,
1074  int stCount, int id) {
1075  SmallVector<Value> ops;
1076  ops.push_back(memref);
1077  llvm::append_range(ops, inputs);
1078  result.addOperands(ops);
1079 
1080  auto memrefType = cast<MemRefType>(memref.getType());
1081 
1082  // Data outputs (get their type from memref)
1083  result.types.append(ldCount, memrefType.getElementType());
1084 
1085  // Control outputs
1086  result.types.append(stCount + ldCount, builder.getNoneType());
1087 
1088  // Memory ID (individual ID for each MemoryOp)
1089  Type i32Type = builder.getIntegerType(32);
1090  result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1091  result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, ldCount));
1092  result.addAttribute("stCount", builder.getIntegerAttr(i32Type, stCount));
1093 }
1094 
1095 llvm::SmallVector<handshake::MemLoadInterface>
1097  return ::getLoadPorts(*this);
1098 }
1099 
1100 llvm::SmallVector<handshake::MemStoreInterface>
1102  return ::getStorePorts(*this);
1103 }
1104 
1105 void MemoryOp::build(OpBuilder &builder, OperationState &result,
1106  ValueRange operands, int outputs, int controlOutputs,
1107  bool lsq, int id, Value memref) {
1108  result.addOperands(operands);
1109 
1110  auto memrefType = cast<MemRefType>(memref.getType());
1111 
1112  // Data outputs (get their type from memref)
1113  result.types.append(outputs, memrefType.getElementType());
1114 
1115  // Control outputs
1116  result.types.append(controlOutputs, builder.getNoneType());
1117  result.addAttribute("lsq", builder.getBoolAttr(lsq));
1118  result.addAttribute("memRefType", TypeAttr::get(memrefType));
1119 
1120  // Memory ID (individual ID for each MemoryOp)
1121  Type i32Type = builder.getIntegerType(32);
1122  result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1123 
1124  if (!lsq) {
1125  result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, outputs));
1126  result.addAttribute(
1127  "stCount", builder.getIntegerAttr(i32Type, controlOutputs - outputs));
1128  }
1129 }
1130 
1131 llvm::SmallVector<handshake::MemLoadInterface> MemoryOp::getLoadPorts() {
1132  return ::getLoadPorts(*this);
1133 }
1134 
1135 llvm::SmallVector<handshake::MemStoreInterface> MemoryOp::getStorePorts() {
1136  return ::getStorePorts(*this);
1137 }
1138 
1139 bool handshake::MemoryOp::allocateMemory(
1140  llvm::DenseMap<unsigned, unsigned> &memoryMap,
1141  std::vector<std::vector<llvm::Any>> &store,
1142  std::vector<double> &storeTimes) {
1143  if (memoryMap.count(getId()))
1144  return false;
1145 
1146  auto type = getMemRefType();
1147  std::vector<llvm::Any> in;
1148 
1149  ArrayRef<int64_t> shape = type.getShape();
1150  int allocationSize = 1;
1151  unsigned count = 0;
1152  for (int64_t dim : shape) {
1153  if (dim > 0)
1154  allocationSize *= dim;
1155  else {
1156  assert(count < in.size());
1157  allocationSize *= llvm::any_cast<APInt>(in[count++]).getSExtValue();
1158  }
1159  }
1160  unsigned ptr = store.size();
1161  store.resize(ptr + 1);
1162  storeTimes.resize(ptr + 1);
1163  store[ptr].resize(allocationSize);
1164  storeTimes[ptr] = 0.0;
1165  mlir::Type elementType = type.getElementType();
1166  int width = elementType.getIntOrFloatBitWidth();
1167  for (int i = 0; i < allocationSize; i++) {
1168  if (isa<mlir::IntegerType>(elementType)) {
1169  store[ptr][i] = APInt(width, 0);
1170  } else if (isa<mlir::FloatType>(elementType)) {
1171  store[ptr][i] = APFloat(0.0);
1172  } else {
1173  llvm_unreachable("Unknown result type!\n");
1174  }
1175  }
1176 
1177  memoryMap[getId()] = ptr;
1178  return true;
1179 }
1180 
1181 std::string handshake::LoadOp::getOperandName(unsigned int idx) {
1182  unsigned nAddresses = getAddresses().size();
1183  std::string opName;
1184  if (idx < nAddresses)
1185  opName = "addrIn" + std::to_string(idx);
1186  else if (idx == nAddresses)
1187  opName = "dataFromMem";
1188  else
1189  opName = "ctrl";
1190  return opName;
1191 }
1192 
1193 std::string handshake::LoadOp::getResultName(unsigned int idx) {
1194  std::string resName;
1195  if (idx == 0)
1196  resName = "dataOut";
1197  else
1198  resName = "addrOut" + std::to_string(idx - 1);
1199  return resName;
1200 }
1201 
1202 void handshake::LoadOp::build(OpBuilder &builder, OperationState &result,
1203  Value memref, ValueRange indices) {
1204  // Address indices
1205  // result.addOperands(memref);
1206  result.addOperands(indices);
1207 
1208  // Data type
1209  auto memrefType = cast<MemRefType>(memref.getType());
1210 
1211  // Data output (from load to successor ops)
1212  result.types.push_back(memrefType.getElementType());
1213 
1214  // Address outputs (to lsq)
1215  result.types.append(indices.size(), builder.getIndexType());
1216 }
1217 
1218 static ParseResult parseMemoryAccessOp(OpAsmParser &parser,
1219  OperationState &result) {
1220  SmallVector<OpAsmParser::UnresolvedOperand, 4> addressOperands,
1221  remainingOperands, allOperands;
1222  SmallVector<Type, 1> parsedTypes, allTypes;
1223  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1224 
1225  if (parser.parseLSquare() || parser.parseOperandList(addressOperands) ||
1226  parser.parseRSquare() || parser.parseOperandList(remainingOperands) ||
1227  parser.parseColon() || parser.parseTypeList(parsedTypes))
1228  return failure();
1229 
1230  // The last type will be the data type of the operation; the prior will be the
1231  // address types.
1232  Type dataType = parsedTypes.back();
1233  auto parsedTypesRef = ArrayRef(parsedTypes);
1234  result.addTypes(dataType);
1235  result.addTypes(parsedTypesRef.drop_back());
1236  allOperands.append(addressOperands);
1237  allOperands.append(remainingOperands);
1238  allTypes.append(parsedTypes);
1239  allTypes.push_back(NoneType::get(result.getContext()));
1240  if (parser.resolveOperands(allOperands, allTypes, allOperandLoc,
1241  result.operands))
1242  return failure();
1243  return success();
1244 }
1245 
1246 template <typename MemOp>
1247 static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op) {
1248  p << " [";
1249  p << op.getAddresses();
1250  p << "] " << op.getData() << ", " << op.getCtrl() << " : ";
1251  llvm::interleaveComma(op.getAddresses(), p,
1252  [&](Value v) { p << v.getType(); });
1253  p << ", " << op.getData().getType();
1254 }
1255 
1256 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
1257  return parseMemoryAccessOp(parser, result);
1258 }
1259 
1260 void LoadOp::print(OpAsmPrinter &p) { printMemoryAccessOp(p, *this); }
1261 
1262 std::string handshake::StoreOp::getOperandName(unsigned int idx) {
1263  unsigned nAddresses = getAddresses().size();
1264  std::string opName;
1265  if (idx < nAddresses)
1266  opName = "addrIn" + std::to_string(idx);
1267  else if (idx == nAddresses)
1268  opName = "dataIn";
1269  else
1270  opName = "ctrl";
1271  return opName;
1272 }
1273 
1274 template <typename TMemoryOp>
1275 static LogicalResult verifyMemoryAccessOp(TMemoryOp op) {
1276  if (op.getAddresses().size() == 0)
1277  return op.emitOpError() << "No addresses were specified";
1278 
1279  return success();
1280 }
1281 
1282 LogicalResult LoadOp::verify() { return verifyMemoryAccessOp(*this); }
1283 
1284 std::string handshake::StoreOp::getResultName(unsigned int idx) {
1285  std::string resName;
1286  if (idx == 0)
1287  resName = "dataToMem";
1288  else
1289  resName = "addrOut" + std::to_string(idx - 1);
1290  return resName;
1291 }
1292 
1293 void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
1294  Value valueToStore, ValueRange indices) {
1295 
1296  // Address indices
1297  result.addOperands(indices);
1298 
1299  // Data
1300  result.addOperands(valueToStore);
1301 
1302  // Data output (from store to LSQ)
1303  result.types.push_back(valueToStore.getType());
1304 
1305  // Address outputs (from store to lsq)
1306  result.types.append(indices.size(), builder.getIndexType());
1307 }
1308 
1309 LogicalResult StoreOp::verify() { return verifyMemoryAccessOp(*this); }
1310 
1311 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
1312  return parseMemoryAccessOp(parser, result);
1313 }
1314 
1315 void StoreOp::print(OpAsmPrinter &p) { return printMemoryAccessOp(p, *this); }
1316 
1317 bool JoinOp::isControl() { return true; }
1318 
1319 ParseResult JoinOp::parse(OpAsmParser &parser, OperationState &result) {
1320  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1321  SmallVector<Type> types;
1322 
1323  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1324  if (parser.parseOperandList(operands) ||
1325  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1326  parser.parseTypeList(types))
1327  return failure();
1328 
1329  if (parser.resolveOperands(operands, types, allOperandLoc, result.operands))
1330  return failure();
1331 
1332  result.addTypes(NoneType::get(result.getContext()));
1333  return success();
1334 }
1335 
1336 void JoinOp::print(OpAsmPrinter &p) {
1337  p << " " << getData();
1338  p.printOptionalAttrDict((*this)->getAttrs(), {"control"});
1339  p << " : " << getData().getTypes();
1340 }
1341 
1342 /// Based on mlir::func::CallOp::verifySymbolUses
1343 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1344  // Check that the module attribute was specified.
1345  auto fnAttr = this->getModuleAttr();
1346  assert(fnAttr && "requires a 'module' symbol reference attribute");
1347 
1348  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
1349  if (!fn)
1350  return emitOpError() << "'" << fnAttr.getValue()
1351  << "' does not reference a valid handshake function";
1352 
1353  // Verify that the operand and result types match the callee.
1354  auto fnType = fn.getFunctionType();
1355  if (fnType.getNumInputs() != getNumOperands())
1356  return emitOpError(
1357  "incorrect number of operands for the referenced handshake function");
1358 
1359  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
1360  if (getOperand(i).getType() != fnType.getInput(i))
1361  return emitOpError("operand type mismatch: expected operand type ")
1362  << fnType.getInput(i) << ", but provided "
1363  << getOperand(i).getType() << " for operand number " << i;
1364 
1365  if (fnType.getNumResults() != getNumResults())
1366  return emitOpError(
1367  "incorrect number of results for the referenced handshake function");
1368 
1369  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
1370  if (getResult(i).getType() != fnType.getResult(i))
1371  return emitOpError("result type mismatch: expected result type ")
1372  << fnType.getResult(i) << ", but provided "
1373  << getResult(i).getType() << " for result number " << i;
1374 
1375  return success();
1376 }
1377 
1378 FunctionType InstanceOp::getModuleType() {
1379  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
1380 }
1381 
1382 ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1383  OpAsmParser::UnresolvedOperand tuple;
1384  TupleType type;
1385 
1386  if (parser.parseOperand(tuple) ||
1387  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1388  parser.parseType(type))
1389  return failure();
1390 
1391  if (parser.resolveOperand(tuple, type, result.operands))
1392  return failure();
1393 
1394  result.addTypes(type.getTypes());
1395 
1396  return success();
1397 }
1398 
1399 void UnpackOp::print(OpAsmPrinter &p) {
1400  p << " " << getInput();
1401  p.printOptionalAttrDict((*this)->getAttrs());
1402  p << " : " << getInput().getType();
1403 }
1404 
1405 ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1406  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1407  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1408  TupleType type;
1409 
1410  if (parser.parseOperandList(operands) ||
1411  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1412  parser.parseType(type))
1413  return failure();
1414 
1415  if (parser.resolveOperands(operands, type.getTypes(), allOperandLoc,
1416  result.operands))
1417  return failure();
1418 
1419  result.addTypes(type);
1420 
1421  return success();
1422 }
1423 
1424 void PackOp::print(OpAsmPrinter &p) {
1425  p << " " << getInputs();
1426  p.printOptionalAttrDict((*this)->getAttrs());
1427  p << " : " << getResult().getType();
1428 }
1429 
1430 //===----------------------------------------------------------------------===//
1431 // TableGen'd op method definitions
1432 //===----------------------------------------------------------------------===//
1433 
1434 LogicalResult ReturnOp::verify() {
1435  auto *parent = (*this)->getParentOp();
1436  auto function = dyn_cast<handshake::FuncOp>(parent);
1437  if (!function)
1438  return emitOpError("must have a handshake.func parent");
1439 
1440  // The operand number and types must match the function signature.
1441  const auto &results = function.getResultTypes();
1442  if (getNumOperands() != results.size())
1443  return emitOpError("has ")
1444  << getNumOperands() << " operands, but enclosing function returns "
1445  << results.size();
1446 
1447  for (unsigned i = 0, e = results.size(); i != e; ++i)
1448  if (getOperand(i).getType() != results[i])
1449  return emitError() << "type of return operand " << i << " ("
1450  << getOperand(i).getType()
1451  << ") doesn't match function result type ("
1452  << results[i] << ")";
1453 
1454  return success();
1455 }
1456 
1457 #define GET_OP_CLASSES
1458 #include "circt/Dialect/Handshake/Handshake.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition: CHIRRTL.cpp:29
static StringRef getOperandName(Value operand)
int32_t width
Definition: FIRRTL.cpp:36
static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal, uint64_t numOperands)
Verifies whether an indexing value is wide enough to index into a provided number of operands.
static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result)
static ParseResult parseSostOperation(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, OperationState &result, int &size, Type &type, bool explicitSize)
static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op)
static ParseResult parseFuncOpArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Type > &resTypes, SmallVectorImpl< DictionaryAttr > &resAttrs)
Parses a FuncOp signature using mlir::function_interface_impl::parseFunctionSignature while getting a...
static Value getDematerialized(Value v)
Returns a dematerialized version of the value 'v', defined as the source of the value before passing ...
static bool isControlCheckTypeAndOperand(Type dataType, Value operand)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores, unsigned idx)
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v)
static LogicalResult verifyMemoryAccessOp(TMemoryOp op)
static void addStringToStringArrayAttr(Builder &builder, Operation *op, StringRef attrName, StringAttr str)
Helper function for appending a string to an array attribute, and rewriting the attribute back to the...
static std::string defaultOperandName(unsigned int idx)
static SmallVector< Attribute > getFuncOpNames(Builder &builder, unsigned cnt, StringRef prefix)
Generates names for a handshake.func input and output arguments, based on the number of args as well ...
static std::string getMemoryOperandName(unsigned nStores, unsigned idx)
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
static ParseResult parseMemoryAccessOp(OpAsmParser &parser, OperationState &result)
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
bool isControlOpImpl(Operation *op)
Default implementation for checking whether an operation is a control operation.
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition: HWOps.cpp:515
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21