CIRCT  19.0.0git
HandshakeOps.cpp
Go to the documentation of this file.
1 //===- HandshakeOps.cpp - Handshake MLIR Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the declaration of the Handshake operations struct.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "circt/Support/LLVM.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/IntegerSet.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/Interfaces/FunctionImplementation.h"
28 #include "mlir/Transforms/InliningUtils.h"
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <set>
34 
35 using namespace circt;
36 using namespace circt::handshake;
37 
38 namespace circt {
39 namespace handshake {
40 #include "circt/Dialect/Handshake/HandshakeCanonicalization.h.inc"
41 
42 bool isControlOpImpl(Operation *op) {
43  if (auto sostInterface = dyn_cast<SOSTInterface>(op); sostInterface)
44  return sostInterface.sostIsControl();
45 
46  return false;
47 }
48 
49 } // namespace handshake
50 } // namespace circt
51 
/// Builds the default operand name: the prefix "in" followed by the decimal
/// operand index (e.g. 3 -> "in3").
static std::string defaultOperandName(unsigned int idx) {
  std::string name = "in";
  name += std::to_string(idx);
  return name;
}
55 
56 static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v) {
57  if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
58  return failure();
59  return success();
60 }
61 
62 static ParseResult
63 parseSostOperation(OpAsmParser &parser,
64  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
65  OperationState &result, int &size, Type &type,
66  bool explicitSize) {
67  if (explicitSize)
68  if (parseIntInSquareBrackets(parser, size))
69  return failure();
70 
71  if (parser.parseOperandList(operands) ||
72  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
73  parser.parseType(type))
74  return failure();
75 
76  if (!explicitSize)
77  size = operands.size();
78  return success();
79 }
80 
81 /// Verifies whether an indexing value is wide enough to index into a provided
82 /// number of operands.
83 static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal,
84  uint64_t numOperands) {
85  auto indexType = indexVal.getType();
86  unsigned indexWidth;
87 
88  // Determine the bitwidth of the indexing value
89  if (auto integerType = dyn_cast<IntegerType>(indexType))
90  indexWidth = integerType.getWidth();
91  else if (indexType.isIndex())
92  indexWidth = IndexType::kInternalStorageBitWidth;
93  else
94  return op->emitError("unsupported type for indexing value: ") << indexType;
95 
96  // Check whether the bitwidth can support the provided number of operands
97  if (indexWidth < 64) {
98  uint64_t maxNumOperands = (uint64_t)1 << indexWidth;
99  if (numOperands > maxNumOperands)
100  return op->emitError("bitwidth of indexing value is ")
101  << indexWidth << ", which can index into " << maxNumOperands
102  << " operands, but found " << numOperands << " operands";
103  }
104  return success();
105 }
106 
107 static bool isControlCheckTypeAndOperand(Type dataType, Value operand) {
108  // The operation is a control operation if its operand data type is a
109  // NoneType.
110  if (isa<NoneType>(dataType))
111  return true;
112 
113  // Otherwise, the operation is a control operation if the operation's
114  // operand originates from the control network
115  auto *defOp = operand.getDefiningOp();
116  return isa_and_nonnull<ControlMergeOp>(defOp) &&
117  operand == defOp->getResult(0);
118 }
119 
120 template <typename TMemOp>
121 llvm::SmallVector<handshake::MemLoadInterface> getLoadPorts(TMemOp op) {
122  llvm::SmallVector<MemLoadInterface> ports;
123  // Memory interface refresher:
124  // Operands:
125  // all stores (stdata1, staddr1, stdata2, staddr2, ...)
126  // then all loads (ldaddr1, ldaddr2,...)
127  // Outputs: load addresses (lddata1, lddata2, ...), followed by all none
128  // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
129  unsigned stCount = op.getStCount();
130  unsigned ldCount = op.getLdCount();
131  for (unsigned i = 0, e = ldCount; i != e; ++i) {
132  MemLoadInterface ldif;
133  ldif.index = i;
134  ldif.addressIn = op.getInputs()[stCount * 2 + i];
135  ldif.dataOut = op.getResult(i);
136  ldif.doneOut = op.getResult(ldCount + stCount + i);
137  ports.push_back(ldif);
138  }
139  return ports;
140 }
141 
142 template <typename TMemOp>
143 llvm::SmallVector<handshake::MemStoreInterface> getStorePorts(TMemOp op) {
144  llvm::SmallVector<MemStoreInterface> ports;
145  // Memory interface refresher:
146  // Operands:
147  // all stores (stdata1, staddr1, stdata2, staddr2, ...)
148  // then all loads (ldaddr1, ldaddr2,...)
149  // Outputs: load data (lddata1, lddata2, ...), followed by all none
150  // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
151  unsigned ldCount = op.getLdCount();
152  for (unsigned i = 0, e = op.getStCount(); i != e; ++i) {
153  MemStoreInterface stif;
154  stif.index = i;
155  stif.dataIn = op.getInputs()[i * 2];
156  stif.addressIn = op.getInputs()[i * 2 + 1];
157  stif.doneOut = op.getResult(ldCount + i);
158  ports.push_back(stif);
159  }
160  return ports;
161 }
162 
163 unsigned ForkOp::getSize() { return getResults().size(); }
164 
165 static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result) {
166  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
167  Type type;
168  ArrayRef<Type> operandTypes(type);
169  SmallVector<Type, 1> resultTypes;
170  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
171  int size;
172  if (parseSostOperation(parser, allOperands, result, size, type, true))
173  return failure();
174 
175  resultTypes.assign(size, type);
176  result.addTypes(resultTypes);
177  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
178  result.operands))
179  return failure();
180  return success();
181 }
182 
183 ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
184  return parseForkOp(parser, result);
185 }
186 
187 void ForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
188 
189 namespace {
190 
191 struct EliminateUnusedForkResultsPattern : mlir::OpRewritePattern<ForkOp> {
193 
194  LogicalResult matchAndRewrite(ForkOp op,
195  PatternRewriter &rewriter) const override {
196  std::set<unsigned> unusedIndexes;
197 
198  for (auto res : llvm::enumerate(op.getResults()))
199  if (res.value().getUses().empty())
200  unusedIndexes.insert(res.index());
201 
202  if (unusedIndexes.empty())
203  return failure();
204 
205  // Create a new fork op, dropping the unused results.
206  rewriter.setInsertionPoint(op);
207  auto operand = op.getOperand();
208  auto newFork = rewriter.create<ForkOp>(
209  op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
210  unsigned i = 0;
211  for (auto oldRes : llvm::enumerate(op.getResults()))
212  if (unusedIndexes.count(oldRes.index()) == 0)
213  rewriter.replaceAllUsesWith(oldRes.value(), newFork.getResults()[i++]);
214  rewriter.eraseOp(op);
215  return success();
216  }
217 };
218 
219 struct EliminateForkToForkPattern : mlir::OpRewritePattern<ForkOp> {
221 
222  LogicalResult matchAndRewrite(ForkOp op,
223  PatternRewriter &rewriter) const override {
224  auto parentForkOp = op.getOperand().getDefiningOp<ForkOp>();
225  if (!parentForkOp)
226  return failure();
227 
228  /// Create the fork with as many outputs as the two source forks.
229  /// Keeping the op.operand() output may or may not be redundant (dependning
230  /// on if op is the single user of the value), but we'll let
231  /// EliminateUnusedForkResultsPattern apply in that case.
232  unsigned totalNumOuts = op.getSize() + parentForkOp.getSize();
233  /// Create a new parent fork op which produces all of the fork outputs and
234  /// replace all of the uses of the old results.
235  auto newParentForkOp = rewriter.create<ForkOp>(
236  parentForkOp.getLoc(), parentForkOp.getOperand(), totalNumOuts);
237 
238  for (auto it :
239  llvm::zip(parentForkOp->getResults(), newParentForkOp.getResults()))
240  rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
241 
242  /// Replace the results of the matches fork op with the corresponding
243  /// results of the new parent fork op.
244  rewriter.replaceOp(op,
245  newParentForkOp.getResults().take_back(op.getSize()));
246  rewriter.eraseOp(parentForkOp);
247  return success();
248  }
249 };
250 
251 } // namespace
252 
253 void handshake::ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
254  MLIRContext *context) {
255  results.insert<circt::handshake::EliminateSimpleForksPattern,
256  EliminateForkToForkPattern, EliminateUnusedForkResultsPattern>(
257  context);
258 }
259 
260 unsigned LazyForkOp::getSize() { return getResults().size(); }
261 
262 bool LazyForkOp::sostIsControl() {
263  return isControlCheckTypeAndOperand(getDataType(), getOperand());
264 }
265 
266 ParseResult LazyForkOp::parse(OpAsmParser &parser, OperationState &result) {
267  return parseForkOp(parser, result);
268 }
269 
270 void LazyForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
271 
272 ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
273  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
274  Type type;
275  ArrayRef<Type> operandTypes(type);
276  SmallVector<Type, 1> resultTypes, dataOperandsTypes;
277  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
278  int size;
279  if (parseSostOperation(parser, allOperands, result, size, type, false))
280  return failure();
281 
282  dataOperandsTypes.assign(size, type);
283  resultTypes.push_back(type);
284  result.addTypes(resultTypes);
285  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
286  result.operands))
287  return failure();
288  return success();
289 }
290 
291 void MergeOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
292 
293 void MergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
294  MLIRContext *context) {
295  results.insert<circt::handshake::EliminateSimpleMergesPattern>(context);
296 }
297 
298 /// Returns a dematerialized version of the value 'v', defined as the source of
299 /// the value before passing through a buffer or fork operation.
300 static Value getDematerialized(Value v) {
301  Operation *parentOp = v.getDefiningOp();
302  if (!parentOp)
303  return v;
304 
305  return llvm::TypeSwitch<Operation *, Value>(parentOp)
306  .Case<ForkOp>(
307  [&](ForkOp op) { return getDematerialized(op.getOperand()); })
308  .Case<BufferOp>(
309  [&](BufferOp op) { return getDematerialized(op.getOperand()); })
310  .Default([&](auto) { return v; });
311 }
312 
313 namespace {
314 
315 /// Eliminates muxes with identical data inputs. Data inputs are inspected as
316 /// their dematerialized versions. This has the side effect of any subsequently
317 /// unused buffers are DCE'd and forks are optimized to be narrower.
318 struct EliminateSimpleMuxesPattern : mlir::OpRewritePattern<MuxOp> {
320  LogicalResult matchAndRewrite(MuxOp op,
321  PatternRewriter &rewriter) const override {
322  Value firstDataOperand = getDematerialized(op.getDataOperands()[0]);
323  if (!llvm::all_of(op.getDataOperands(), [&](Value operand) {
324  return getDematerialized(operand) == firstDataOperand;
325  }))
326  return failure();
327  rewriter.replaceOp(op, firstDataOperand);
328  return success();
329  }
330 };
331 
332 struct EliminateUnaryMuxesPattern : OpRewritePattern<MuxOp> {
334  LogicalResult matchAndRewrite(MuxOp op,
335  PatternRewriter &rewriter) const override {
336  if (op.getDataOperands().size() != 1)
337  return failure();
338 
339  rewriter.replaceOp(op, op.getDataOperands()[0]);
340  return success();
341  }
342 };
343 
344 struct EliminateCBranchIntoMuxPattern : OpRewritePattern<MuxOp> {
346  LogicalResult matchAndRewrite(MuxOp op,
347  PatternRewriter &rewriter) const override {
348 
349  auto dataOperands = op.getDataOperands();
350  if (dataOperands.size() != 2)
351  return failure();
352 
353  // Both data operands must originate from the same cbranch
354  ConditionalBranchOp firstParentCBranch =
355  dataOperands[0].getDefiningOp<ConditionalBranchOp>();
356  if (!firstParentCBranch)
357  return failure();
358  auto secondParentCBranch =
359  dataOperands[1].getDefiningOp<ConditionalBranchOp>();
360  if (!secondParentCBranch || firstParentCBranch != secondParentCBranch)
361  return failure();
362 
363  rewriter.modifyOpInPlace(firstParentCBranch, [&] {
364  // Replace uses of the mux's output with cbranch's data input
365  rewriter.replaceOp(op, firstParentCBranch.getDataOperand());
366  });
367 
368  return success();
369  }
370 };
371 
372 } // namespace
373 
374 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
375  MLIRContext *context) {
376  results.insert<EliminateSimpleMuxesPattern, EliminateUnaryMuxesPattern,
377  EliminateCBranchIntoMuxPattern>(context);
378 }
379 
380 LogicalResult
381 MuxOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
382  ValueRange operands, DictionaryAttr attributes,
383  mlir::OpaqueProperties properties,
384  mlir::RegionRange regions,
385  SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
386  // MuxOp must have at least one data operand (in addition to the select
387  // operand)
388  if (operands.size() < 2)
389  return failure();
390  // Result type is type of any data operand
391  inferredReturnTypes.push_back(operands[1].getType());
392  return success();
393 }
394 
395 bool MuxOp::isControl() { return isa<NoneType>(getResult().getType()); }
396 
397 std::string handshake::MuxOp::getOperandName(unsigned int idx) {
398  return idx == 0 ? "select" : defaultOperandName(idx - 1);
399 }
400 
401 ParseResult MuxOp::parse(OpAsmParser &parser, OperationState &result) {
402  OpAsmParser::UnresolvedOperand selectOperand;
403  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
404  Type selectType, dataType;
405  SmallVector<Type, 1> dataOperandsTypes;
406  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
407  if (parser.parseOperand(selectOperand) || parser.parseLSquare() ||
408  parser.parseOperandList(allOperands) || parser.parseRSquare() ||
409  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
410  parser.parseType(selectType) || parser.parseComma() ||
411  parser.parseType(dataType))
412  return failure();
413 
414  int size = allOperands.size();
415  dataOperandsTypes.assign(size, dataType);
416  result.addTypes(dataType);
417  allOperands.insert(allOperands.begin(), selectOperand);
418  if (parser.resolveOperands(
419  allOperands,
420  llvm::concat<const Type>(ArrayRef<Type>(selectType),
421  ArrayRef<Type>(dataOperandsTypes)),
422  allOperandLoc, result.operands))
423  return failure();
424  return success();
425 }
426 
427 void MuxOp::print(OpAsmPrinter &p) {
428  Type selectType = getSelectOperand().getType();
429  auto ops = getOperands();
430  p << ' ' << ops.front();
431  p << " [";
432  p.printOperands(ops.drop_front());
433  p << "]";
434  p.printOptionalAttrDict((*this)->getAttrs());
435  p << " : " << selectType << ", " << getResult().getType();
436 }
437 
438 LogicalResult MuxOp::verify() {
439  return verifyIndexWideEnough(*this, getSelectOperand(),
440  getDataOperands().size());
441 }
442 
443 std::string handshake::ControlMergeOp::getResultName(unsigned int idx) {
444  assert(idx == 0 || idx == 1);
445  return idx == 0 ? "dataOut" : "index";
446 }
447 
448 ParseResult ControlMergeOp::parse(OpAsmParser &parser, OperationState &result) {
449  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
450  Type resultType, indexType;
451  SmallVector<Type> resultTypes, dataOperandsTypes;
452  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
453  int size;
454  if (parseSostOperation(parser, allOperands, result, size, resultType, false))
455  return failure();
456  // Parse type of index result
457  if (parser.parseComma() || parser.parseType(indexType))
458  return failure();
459 
460  dataOperandsTypes.assign(size, resultType);
461  resultTypes.push_back(resultType);
462  resultTypes.push_back(indexType);
463  result.addTypes(resultTypes);
464  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
465  result.operands))
466  return failure();
467  return success();
468 }
469 
470 void ControlMergeOp::print(OpAsmPrinter &p) {
471  sostPrint(p, false);
472  // Print type of index result
473  p << ", " << getIndex().getType();
474 }
475 
476 LogicalResult ControlMergeOp::verify() {
477  auto operands = getOperands();
478  if (operands.empty())
479  return emitOpError("operation must have at least one operand");
480  if (operands[0].getType() != getResult().getType())
481  return emitOpError("type of first result should match type of operands");
482  return verifyIndexWideEnough(*this, getIndex(), getNumOperands());
483 }
484 
485 LogicalResult FuncOp::verify() {
486  // If this function is external there is nothing to do.
487  if (isExternal())
488  return success();
489 
490  // Verify that the argument list of the function and the arg list of the
491  // entry block line up. The trait already verified that the number of
492  // arguments is the same between the signature and the block.
493  auto fnInputTypes = getArgumentTypes();
494  Block &entryBlock = front();
495 
496  for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
497  if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
498  return emitOpError("type of entry block argument #")
499  << i << '(' << entryBlock.getArgument(i).getType()
500  << ") must match the type of the corresponding argument in "
501  << "function signature(" << fnInputTypes[i] << ')';
502 
503  // Verify that we have a name for each argument and result of this function.
504  auto verifyPortNameAttr = [&](StringRef attrName,
505  unsigned numIOs) -> LogicalResult {
506  auto portNamesAttr = (*this)->getAttrOfType<ArrayAttr>(attrName);
507 
508  if (!portNamesAttr)
509  return emitOpError() << "expected attribute '" << attrName << "'.";
510 
511  auto portNames = portNamesAttr.getValue();
512  if (portNames.size() != numIOs)
513  return emitOpError() << "attribute '" << attrName << "' has "
514  << portNames.size()
515  << " entries but is expected to have " << numIOs
516  << ".";
517 
518  if (llvm::any_of(portNames,
519  [&](Attribute attr) { return !isa<StringAttr>(attr); }))
520  return emitOpError() << "expected all entries in attribute '" << attrName
521  << "' to be strings.";
522 
523  return success();
524  };
525  if (failed(verifyPortNameAttr("argNames", getNumArguments())))
526  return failure();
527  if (failed(verifyPortNameAttr("resNames", getNumResults())))
528  return failure();
529 
530  // Verify that all memrefs have a corresponding extmemory operation
531  for (auto arg : entryBlock.getArguments()) {
532  if (!isa<MemRefType>(arg.getType()))
533  continue;
534  if (arg.getUsers().empty() ||
535  !isa<ExternalMemoryOp>(*arg.getUsers().begin()))
536  return emitOpError("expected that block argument #")
537  << arg.getArgNumber() << " is used by an 'extmemory' operation";
538  }
539 
540  return success();
541 }
542 
543 /// Parses a FuncOp signature using
544 /// mlir::function_interface_impl::parseFunctionSignature while getting access
545 /// to the parsed SSA names to store as attributes.
546 static ParseResult
547 parseFuncOpArgs(OpAsmParser &parser,
548  SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
549  SmallVectorImpl<Type> &resTypes,
550  SmallVectorImpl<DictionaryAttr> &resAttrs) {
551  bool isVariadic;
552  if (mlir::function_interface_impl::parseFunctionSignature(
553  parser, /*allowVariadic=*/true, entryArgs, isVariadic, resTypes,
554  resAttrs)
555  .failed())
556  return failure();
557 
558  return success();
559 }
560 
561 /// Generates names for a handshake.func input and output arguments, based on
562 /// the number of args as well as a prefix.
563 static SmallVector<Attribute> getFuncOpNames(Builder &builder, unsigned cnt,
564  StringRef prefix) {
565  SmallVector<Attribute> resNames;
566  for (unsigned i = 0; i < cnt; ++i)
567  resNames.push_back(builder.getStringAttr(prefix + std::to_string(i)));
568  return resNames;
569 }
570 
571 void handshake::FuncOp::build(OpBuilder &builder, OperationState &state,
572  StringRef name, FunctionType type,
573  ArrayRef<NamedAttribute> attrs) {
574  state.addAttribute(SymbolTable::getSymbolAttrName(),
575  builder.getStringAttr(name));
576  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
577  TypeAttr::get(type));
578  state.attributes.append(attrs.begin(), attrs.end());
579 
580  if (const auto *argNamesAttrIt = llvm::find_if(
581  attrs, [&](auto attr) { return attr.getName() == "argNames"; });
582  argNamesAttrIt == attrs.end())
583  state.addAttribute("argNames", builder.getArrayAttr({}));
584 
585  if (llvm::find_if(attrs, [&](auto attr) {
586  return attr.getName() == "resNames";
587  }) == attrs.end())
588  state.addAttribute("resNames", builder.getArrayAttr({}));
589 
590  state.addRegion();
591 }
592 
593 /// Helper function for appending a string to an array attribute, and
594 /// rewriting the attribute back to the operation.
595 static void addStringToStringArrayAttr(Builder &builder, Operation *op,
596  StringRef attrName, StringAttr str) {
597  llvm::SmallVector<Attribute> attrs;
598  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
599  std::back_inserter(attrs));
600  attrs.push_back(str);
601  op->setAttr(attrName, builder.getArrayAttr(attrs));
602 }
603 
604 void handshake::FuncOp::resolveArgAndResNames() {
605  Builder builder(getContext());
606 
607  /// Generate a set of fallback names. These are used in case names are
608  /// missing from the currently set arg- and res name attributes.
609  auto fallbackArgNames = getFuncOpNames(builder, getNumArguments(), "in");
610  auto fallbackResNames = getFuncOpNames(builder, getNumResults(), "out");
611  auto argNames = getArgNames().getValue();
612  auto resNames = getResNames().getValue();
613 
614  /// Use fallback names where actual names are missing.
615  auto resolveNames = [&](auto &fallbackNames, auto &actualNames,
616  StringRef attrName) {
617  for (auto fallbackName : llvm::enumerate(fallbackNames)) {
618  if (actualNames.size() <= fallbackName.index())
619  addStringToStringArrayAttr(builder, this->getOperation(), attrName,
620  cast<StringAttr>(fallbackName.value()));
621  }
622  };
623  resolveNames(fallbackArgNames, argNames, "argNames");
624  resolveNames(fallbackResNames, resNames, "resNames");
625 }
626 
627 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
628  auto &builder = parser.getBuilder();
629  StringAttr nameAttr;
630  SmallVector<OpAsmParser::Argument> args;
631  SmallVector<Type> resTypes;
632  SmallVector<DictionaryAttr> resAttributes;
633  SmallVector<Attribute> argNames;
634 
635  // Parse visibility.
636  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
637 
638  // Parse signature
639  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
640  result.attributes) ||
641  parseFuncOpArgs(parser, args, resTypes, resAttributes))
642  return failure();
643  mlir::function_interface_impl::addArgAndResultAttrs(
644  builder, result, args, resAttributes,
645  handshake::FuncOp::getArgAttrsAttrName(result.name),
646  handshake::FuncOp::getResAttrsAttrName(result.name));
647 
648  // Set function type
649  SmallVector<Type> argTypes;
650  for (auto arg : args)
651  argTypes.push_back(arg.type);
652 
653  result.addAttribute(
654  handshake::FuncOp::getFunctionTypeAttrName(result.name),
655  TypeAttr::get(builder.getFunctionType(argTypes, resTypes)));
656 
657  // Determine the names of the arguments. If no SSA values are present, use
658  // fallback names.
659  bool noSSANames =
660  llvm::any_of(args, [](auto arg) { return arg.ssaName.name.empty(); });
661  if (noSSANames) {
662  argNames = getFuncOpNames(builder, args.size(), "in");
663  } else {
664  llvm::transform(args, std::back_inserter(argNames), [&](auto arg) {
665  return builder.getStringAttr(arg.ssaName.name.drop_front());
666  });
667  }
668 
669  // Parse attributes
670  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
671  return failure();
672 
673  // If argNames and resNames wasn't provided manually, infer argNames attribute
674  // from the parsed SSA names and resNames from our naming convention.
675  if (!result.attributes.get("argNames"))
676  result.addAttribute("argNames", builder.getArrayAttr(argNames));
677  if (!result.attributes.get("resNames")) {
678  auto resNames = getFuncOpNames(builder, resTypes.size(), "out");
679  result.addAttribute("resNames", builder.getArrayAttr(resNames));
680  }
681 
682  // Parse the optional function body. The printer will not print the body if
683  // its empty, so disallow parsing of empty body in the parser.
684  auto *body = result.addRegion();
685  llvm::SMLoc loc = parser.getCurrentLocation();
686  auto parseResult = parser.parseOptionalRegion(*body, args,
687  /*enableNameShadowing=*/false);
688  if (!parseResult.has_value())
689  return success();
690 
691  if (failed(*parseResult))
692  return failure();
693  // Function body was parsed, make sure its not empty.
694  if (body->empty())
695  return parser.emitError(loc, "expected non-empty function body");
696 
697  // If a body was parsed, the arg and res names need to be resolved
698  return success();
699 }
700 
701 void FuncOp::print(OpAsmPrinter &p) {
702  mlir::function_interface_impl::printFunctionOp(
703  p, *this, /*isVariadic=*/true, getFunctionTypeAttrName(),
704  getArgAttrsAttrName(), getResAttrsAttrName());
705 }
706 
707 namespace {
708 struct EliminateSimpleControlMergesPattern
709  : mlir::OpRewritePattern<ControlMergeOp> {
711 
712  LogicalResult matchAndRewrite(ControlMergeOp op,
713  PatternRewriter &rewriter) const override;
714 };
715 } // namespace
716 
717 LogicalResult EliminateSimpleControlMergesPattern::matchAndRewrite(
718  ControlMergeOp op, PatternRewriter &rewriter) const {
719  auto dataResult = op.getResult();
720  auto choiceResult = op.getIndex();
721  auto choiceUnused = choiceResult.use_empty();
722  if (!choiceUnused && !choiceResult.hasOneUse())
723  return failure();
724 
725  Operation *choiceUser = nullptr;
726  if (choiceResult.hasOneUse()) {
727  choiceUser = choiceResult.getUses().begin().getUser();
728  if (!isa<SinkOp>(choiceUser))
729  return failure();
730  }
731 
732  auto merge = rewriter.create<MergeOp>(op.getLoc(), op.getDataOperands());
733 
734  for (auto &use : llvm::make_early_inc_range(dataResult.getUses())) {
735  auto *user = use.getOwner();
736  rewriter.modifyOpInPlace(
737  user, [&]() { user->setOperand(use.getOperandNumber(), merge); });
738  }
739 
740  if (choiceUnused) {
741  rewriter.eraseOp(op);
742  return success();
743  }
744 
745  rewriter.eraseOp(choiceUser);
746  rewriter.eraseOp(op);
747  return success();
748 }
749 
750 void ControlMergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
751  MLIRContext *context) {
752  results.insert<EliminateSimpleControlMergesPattern>(context);
753 }
754 
755 bool BranchOp::sostIsControl() {
756  return isControlCheckTypeAndOperand(getDataType(), getOperand());
757 }
758 
759 void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
760  MLIRContext *context) {
761  results.insert<circt::handshake::EliminateSimpleBranchesPattern>(context);
762 }
763 
764 ParseResult BranchOp::parse(OpAsmParser &parser, OperationState &result) {
765  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
766  Type type;
767  ArrayRef<Type> operandTypes(type);
768  SmallVector<Type, 1> dataOperandsTypes;
769  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
770  int size;
771  if (parseSostOperation(parser, allOperands, result, size, type, false))
772  return failure();
773 
774  dataOperandsTypes.assign(size, type);
775  result.addTypes({type});
776  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
777  result.operands))
778  return failure();
779  return success();
780 }
781 
782 void BranchOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
783 
784 ParseResult ConditionalBranchOp::parse(OpAsmParser &parser,
785  OperationState &result) {
786  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
787  Type dataType;
788  SmallVector<Type> operandTypes;
789  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
790  if (parser.parseOperandList(allOperands) ||
791  parser.parseOptionalAttrDict(result.attributes) ||
792  parser.parseColonType(dataType))
793  return failure();
794 
795  if (allOperands.size() != 2)
796  return parser.emitError(parser.getCurrentLocation(),
797  "Expected exactly 2 operands");
798 
799  result.addTypes({dataType, dataType});
800  operandTypes.push_back(IntegerType::get(parser.getContext(), 1));
801  operandTypes.push_back(dataType);
802  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
803  result.operands))
804  return failure();
805 
806  return success();
807 }
808 
809 void ConditionalBranchOp::print(OpAsmPrinter &p) {
810  Type type = getDataOperand().getType();
811  p << " " << getOperands();
812  p.printOptionalAttrDict((*this)->getAttrs());
813  p << " : " << type;
814 }
815 
816 std::string handshake::ConditionalBranchOp::getOperandName(unsigned int idx) {
817  assert(idx == 0 || idx == 1);
818  return idx == 0 ? "cond" : "data";
819 }
820 
821 std::string handshake::ConditionalBranchOp::getResultName(unsigned int idx) {
822  assert(idx == 0 || idx == 1);
823  return idx == ConditionalBranchOp::falseIndex ? "outFalse" : "outTrue";
824 }
825 
826 bool ConditionalBranchOp::isControl() {
827  return isControlCheckTypeAndOperand(getDataOperand().getType(),
828  getDataOperand());
829 }
830 
831 ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
832  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
833  Type type;
834  ArrayRef<Type> operandTypes(type);
835  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
836  int size;
837  if (parseSostOperation(parser, allOperands, result, size, type, false))
838  return failure();
839 
840  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
841  result.operands))
842  return failure();
843  return success();
844 }
845 
846 void SinkOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
847 
848 std::string handshake::ConstantOp::getOperandName(unsigned int idx) {
849  assert(idx == 0);
850  return "ctrl";
851 }
852 
853 Type SourceOp::getDataType() { return getResult().getType(); }
854 unsigned SourceOp::getSize() { return 1; }
855 
856 ParseResult SourceOp::parse(OpAsmParser &parser, OperationState &result) {
857  if (parser.parseOptionalAttrDict(result.attributes))
858  return failure();
859  result.addTypes(NoneType::get(result.getContext()));
860  return success();
861 }
862 
863 void SourceOp::print(OpAsmPrinter &p) {
864  p.printOptionalAttrDict((*this)->getAttrs());
865 }
866 
867 LogicalResult ConstantOp::verify() {
868  // Verify that the type of the provided value is equal to the result type.
869  auto typedValue = dyn_cast<mlir::TypedAttr>(getValue());
870  if (!typedValue)
871  return emitOpError("constant value must be a typed attribute; value is ")
872  << getValue();
873  if (typedValue.getType() != getResult().getType())
874  return emitOpError() << "constant value type " << typedValue.getType()
875  << " differs from operation result type "
876  << getResult().getType();
877 
878  return success();
879 }
880 
881 void handshake::ConstantOp::getCanonicalizationPatterns(
882  RewritePatternSet &results, MLIRContext *context) {
883  results.insert<circt::handshake::EliminateSunkConstantsPattern>(context);
884 }
885 
886 LogicalResult BufferOp::verify() {
887  // Verify that exactly 'size' number of initial values have been provided, if
888  // an initializer list have been provided.
889  if (auto initVals = getInitValues()) {
890  if (!isSequential())
891  return emitOpError()
892  << "only bufferType buffers are allowed to have initial values.";
893 
894  auto nInits = initVals->size();
895  if (nInits != getSize())
896  return emitOpError() << "expected " << getSize()
897  << " init values but got " << nInits << ".";
898  }
899 
900  return success();
901 }
902 
903 void handshake::BufferOp::getCanonicalizationPatterns(
904  RewritePatternSet &results, MLIRContext *context) {
905  results.insert<circt::handshake::EliminateSunkBuffersPattern>(context);
906 }
907 
908 unsigned BufferOp::getSize() {
909  return (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
910 }
911 
912 ParseResult BufferOp::parse(OpAsmParser &parser, OperationState &result) {
913  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
914  Type type;
915  ArrayRef<Type> operandTypes(type);
916  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
917  int slots;
918  if (parseIntInSquareBrackets(parser, slots))
919  return failure();
920 
921  auto bufferTypeAttr = BufferTypeEnumAttr::parse(parser, {});
922  if (!bufferTypeAttr)
923  return failure();
924 
925  result.addAttribute(
926  "slots",
927  IntegerAttr::get(IntegerType::get(result.getContext(), 32), slots));
928  result.addAttribute("bufferType", bufferTypeAttr);
929 
930  if (parser.parseOperandList(allOperands) ||
931  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
932  parser.parseType(type))
933  return failure();
934 
935  result.addTypes({type});
936  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
937  result.operands))
938  return failure();
939  return success();
940 }
941 
942 void BufferOp::print(OpAsmPrinter &p) {
943  int size =
944  (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
945  p << " [" << size << "]";
946  p << " " << stringifyEnum(getBufferType());
947  p << " " << (*this)->getOperands();
948  p.printOptionalAttrDict((*this)->getAttrs(), {"slots", "bufferType"});
949  p << " : " << (*this).getDataType();
950 }
951 
/// Names memory-port operands. The first 2 * nStores operands alternate
/// between store data ("stDataN") and store address ("stAddrN"). Every
/// operand after that is a load address ("ldAddrN").
static std::string getMemoryOperandName(unsigned nStores, unsigned idx) {
  const unsigned storeOperandCount = nStores * 2;
  if (idx >= storeOperandCount)
    return "ldAddr" + std::to_string(idx - storeOperandCount);

  const std::string port = std::to_string(idx / 2);
  return (idx % 2 == 0 ? "stData" : "stAddr") + port;
}
964 
965 std::string handshake::MemoryOp::getOperandName(unsigned int idx) {
966  return getMemoryOperandName(getStCount(), idx);
967 }
968 
/// Names memory-port results in this order: nLoads load-data results
/// ("ldDataN"), then nStores store-done results ("stDoneN"), then the
/// load-done results ("ldDoneN").
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores,
                                       unsigned idx) {
  if (idx < nLoads)
    return "ldData" + std::to_string(idx);

  const unsigned afterLoads = idx - nLoads;
  if (afterLoads < nStores)
    return "stDone" + std::to_string(afterLoads);

  return "ldDone" + std::to_string(afterLoads - nStores);
}
980 
981 std::string handshake::MemoryOp::getResultName(unsigned int idx) {
982  return getMemoryResultName(getLdCount(), getStCount(), idx);
983 }
984 
985 LogicalResult MemoryOp::verify() {
986  auto memrefType = getMemRefType();
987 
988  if (memrefType.getNumDynamicDims() != 0)
989  return emitOpError()
990  << "memref dimensions for handshake.memory must be static.";
991  if (memrefType.getShape().size() != 1)
992  return emitOpError() << "memref must have only a single dimension.";
993 
994  unsigned opStCount = getStCount();
995  unsigned opLdCount = getLdCount();
996  int addressCount = memrefType.getShape().size();
997 
998  auto inputType = getInputs().getType();
999  auto outputType = getOutputs().getType();
1000  Type dataType = memrefType.getElementType();
1001 
1002  unsigned numOperands = static_cast<int>(getInputs().size());
1003  unsigned numResults = static_cast<int>(getOutputs().size());
1004  if (numOperands != (1 + addressCount) * opStCount + addressCount * opLdCount)
1005  return emitOpError("number of operands ")
1006  << numOperands << " does not match number expected of "
1007  << 2 * opStCount + opLdCount << " with " << addressCount
1008  << " address inputs per port";
1009 
1010  if (numResults != opStCount + 2 * opLdCount)
1011  return emitOpError("number of results ")
1012  << numResults << " does not match number expected of "
1013  << opStCount + 2 * opLdCount << " with " << addressCount
1014  << " address inputs per port";
1015 
1016  Type addressType = opStCount > 0 ? inputType[1] : inputType[0];
1017 
1018  for (unsigned i = 0; i < opStCount; i++) {
1019  if (inputType[2 * i] != dataType)
1020  return emitOpError("data type for store port ")
1021  << i << ":" << inputType[2 * i] << " doesn't match memory type "
1022  << dataType;
1023  if (inputType[2 * i + 1] != addressType)
1024  return emitOpError("address type for store port ")
1025  << i << ":" << inputType[2 * i + 1]
1026  << " doesn't match address type " << addressType;
1027  }
1028  for (unsigned i = 0; i < opLdCount; i++) {
1029  Type ldAddressType = inputType[2 * opStCount + i];
1030  if (ldAddressType != addressType)
1031  return emitOpError("address type for load port ")
1032  << i << ":" << ldAddressType << " doesn't match address type "
1033  << addressType;
1034  }
1035  for (unsigned i = 0; i < opLdCount; i++) {
1036  if (outputType[i] != dataType)
1037  return emitOpError("data type for load port ")
1038  << i << ":" << outputType[i] << " doesn't match memory type "
1039  << dataType;
1040  }
1041  for (unsigned i = 0; i < opStCount; i++) {
1042  Type syncType = outputType[opLdCount + i];
1043  if (!isa<NoneType>(syncType))
1044  return emitOpError("data type for sync port for store port ")
1045  << i << ":" << syncType << " is not 'none'";
1046  }
1047  for (unsigned i = 0; i < opLdCount; i++) {
1048  Type syncType = outputType[opLdCount + opStCount + i];
1049  if (!isa<NoneType>(syncType))
1050  return emitOpError("data type for sync port for load port ")
1051  << i << ":" << syncType << " is not 'none'";
1052  }
1053 
1054  return success();
1055 }
1056 
1057 std::string handshake::ExternalMemoryOp::getOperandName(unsigned int idx) {
1058  if (idx == 0)
1059  return "extmem";
1060 
1061  return getMemoryOperandName(getStCount(), idx - 1);
1062 }
1063 
1064 std::string handshake::ExternalMemoryOp::getResultName(unsigned int idx) {
1065  return getMemoryResultName(getLdCount(), getStCount(), idx);
1066 }
1067 
1068 void ExternalMemoryOp::build(OpBuilder &builder, OperationState &result,
1069  Value memref, ValueRange inputs, int ldCount,
1070  int stCount, int id) {
1071  SmallVector<Value> ops;
1072  ops.push_back(memref);
1073  llvm::append_range(ops, inputs);
1074  result.addOperands(ops);
1075 
1076  auto memrefType = cast<MemRefType>(memref.getType());
1077 
1078  // Data outputs (get their type from memref)
1079  result.types.append(ldCount, memrefType.getElementType());
1080 
1081  // Control outputs
1082  result.types.append(stCount + ldCount, builder.getNoneType());
1083 
1084  // Memory ID (individual ID for each MemoryOp)
1085  Type i32Type = builder.getIntegerType(32);
1086  result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1087  result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, ldCount));
1088  result.addAttribute("stCount", builder.getIntegerAttr(i32Type, stCount));
1089 }
1090 
1091 llvm::SmallVector<handshake::MemLoadInterface>
1093  return ::getLoadPorts(*this);
1094 }
1095 
1096 llvm::SmallVector<handshake::MemStoreInterface>
1098  return ::getStorePorts(*this);
1099 }
1100 
1101 void MemoryOp::build(OpBuilder &builder, OperationState &result,
1102  ValueRange operands, int outputs, int controlOutputs,
1103  bool lsq, int id, Value memref) {
1104  result.addOperands(operands);
1105 
1106  auto memrefType = cast<MemRefType>(memref.getType());
1107 
1108  // Data outputs (get their type from memref)
1109  result.types.append(outputs, memrefType.getElementType());
1110 
1111  // Control outputs
1112  result.types.append(controlOutputs, builder.getNoneType());
1113  result.addAttribute("lsq", builder.getBoolAttr(lsq));
1114  result.addAttribute("memRefType", TypeAttr::get(memrefType));
1115 
1116  // Memory ID (individual ID for each MemoryOp)
1117  Type i32Type = builder.getIntegerType(32);
1118  result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1119 
1120  if (!lsq) {
1121  result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, outputs));
1122  result.addAttribute(
1123  "stCount", builder.getIntegerAttr(i32Type, controlOutputs - outputs));
1124  }
1125 }
1126 
1127 llvm::SmallVector<handshake::MemLoadInterface> MemoryOp::getLoadPorts() {
1128  return ::getLoadPorts(*this);
1129 }
1130 
1131 llvm::SmallVector<handshake::MemStoreInterface> MemoryOp::getStorePorts() {
1132  return ::getStorePorts(*this);
1133 }
1134 
1135 bool handshake::MemoryOp::allocateMemory(
1136  llvm::DenseMap<unsigned, unsigned> &memoryMap,
1137  std::vector<std::vector<llvm::Any>> &store,
1138  std::vector<double> &storeTimes) {
1139  if (memoryMap.count(getId()))
1140  return false;
1141 
1142  auto type = getMemRefType();
1143  std::vector<llvm::Any> in;
1144 
1145  ArrayRef<int64_t> shape = type.getShape();
1146  int allocationSize = 1;
1147  unsigned count = 0;
1148  for (int64_t dim : shape) {
1149  if (dim > 0)
1150  allocationSize *= dim;
1151  else {
1152  assert(count < in.size());
1153  allocationSize *= llvm::any_cast<APInt>(in[count++]).getSExtValue();
1154  }
1155  }
1156  unsigned ptr = store.size();
1157  store.resize(ptr + 1);
1158  storeTimes.resize(ptr + 1);
1159  store[ptr].resize(allocationSize);
1160  storeTimes[ptr] = 0.0;
1161  mlir::Type elementType = type.getElementType();
1162  int width = elementType.getIntOrFloatBitWidth();
1163  for (int i = 0; i < allocationSize; i++) {
1164  if (isa<mlir::IntegerType>(elementType)) {
1165  store[ptr][i] = APInt(width, 0);
1166  } else if (isa<mlir::FloatType>(elementType)) {
1167  store[ptr][i] = APFloat(0.0);
1168  } else {
1169  llvm_unreachable("Unknown result type!\n");
1170  }
1171  }
1172 
1173  memoryMap[getId()] = ptr;
1174  return true;
1175 }
1176 
1177 std::string handshake::LoadOp::getOperandName(unsigned int idx) {
1178  unsigned nAddresses = getAddresses().size();
1179  std::string opName;
1180  if (idx < nAddresses)
1181  opName = "addrIn" + std::to_string(idx);
1182  else if (idx == nAddresses)
1183  opName = "dataFromMem";
1184  else
1185  opName = "ctrl";
1186  return opName;
1187 }
1188 
1189 std::string handshake::LoadOp::getResultName(unsigned int idx) {
1190  std::string resName;
1191  if (idx == 0)
1192  resName = "dataOut";
1193  else
1194  resName = "addrOut" + std::to_string(idx - 1);
1195  return resName;
1196 }
1197 
1198 void handshake::LoadOp::build(OpBuilder &builder, OperationState &result,
1199  Value memref, ValueRange indices) {
1200  // Address indices
1201  // result.addOperands(memref);
1202  result.addOperands(indices);
1203 
1204  // Data type
1205  auto memrefType = cast<MemRefType>(memref.getType());
1206 
1207  // Data output (from load to successor ops)
1208  result.types.push_back(memrefType.getElementType());
1209 
1210  // Address outputs (to lsq)
1211  result.types.append(indices.size(), builder.getIndexType());
1212 }
1213 
1214 static ParseResult parseMemoryAccessOp(OpAsmParser &parser,
1215  OperationState &result) {
1216  SmallVector<OpAsmParser::UnresolvedOperand, 4> addressOperands,
1217  remainingOperands, allOperands;
1218  SmallVector<Type, 1> parsedTypes, allTypes;
1219  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1220 
1221  if (parser.parseLSquare() || parser.parseOperandList(addressOperands) ||
1222  parser.parseRSquare() || parser.parseOperandList(remainingOperands) ||
1223  parser.parseColon() || parser.parseTypeList(parsedTypes))
1224  return failure();
1225 
1226  // The last type will be the data type of the operation; the prior will be the
1227  // address types.
1228  Type dataType = parsedTypes.back();
1229  auto parsedTypesRef = ArrayRef(parsedTypes);
1230  result.addTypes(dataType);
1231  result.addTypes(parsedTypesRef.drop_back());
1232  allOperands.append(addressOperands);
1233  allOperands.append(remainingOperands);
1234  allTypes.append(parsedTypes);
1235  allTypes.push_back(NoneType::get(result.getContext()));
1236  if (parser.resolveOperands(allOperands, allTypes, allOperandLoc,
1237  result.operands))
1238  return failure();
1239  return success();
1240 }
1241 
1242 template <typename MemOp>
1243 static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op) {
1244  p << " [";
1245  p << op.getAddresses();
1246  p << "] " << op.getData() << ", " << op.getCtrl() << " : ";
1247  llvm::interleaveComma(op.getAddresses(), p,
1248  [&](Value v) { p << v.getType(); });
1249  p << ", " << op.getData().getType();
1250 }
1251 
1252 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
1253  return parseMemoryAccessOp(parser, result);
1254 }
1255 
1256 void LoadOp::print(OpAsmPrinter &p) { printMemoryAccessOp(p, *this); }
1257 
1258 std::string handshake::StoreOp::getOperandName(unsigned int idx) {
1259  unsigned nAddresses = getAddresses().size();
1260  std::string opName;
1261  if (idx < nAddresses)
1262  opName = "addrIn" + std::to_string(idx);
1263  else if (idx == nAddresses)
1264  opName = "dataIn";
1265  else
1266  opName = "ctrl";
1267  return opName;
1268 }
1269 
1270 template <typename TMemoryOp>
1271 static LogicalResult verifyMemoryAccessOp(TMemoryOp op) {
1272  if (op.getAddresses().size() == 0)
1273  return op.emitOpError() << "No addresses were specified";
1274 
1275  return success();
1276 }
1277 
1278 LogicalResult LoadOp::verify() { return verifyMemoryAccessOp(*this); }
1279 
1280 std::string handshake::StoreOp::getResultName(unsigned int idx) {
1281  std::string resName;
1282  if (idx == 0)
1283  resName = "dataToMem";
1284  else
1285  resName = "addrOut" + std::to_string(idx - 1);
1286  return resName;
1287 }
1288 
1289 void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
1290  Value valueToStore, ValueRange indices) {
1291 
1292  // Address indices
1293  result.addOperands(indices);
1294 
1295  // Data
1296  result.addOperands(valueToStore);
1297 
1298  // Data output (from store to LSQ)
1299  result.types.push_back(valueToStore.getType());
1300 
1301  // Address outputs (from store to lsq)
1302  result.types.append(indices.size(), builder.getIndexType());
1303 }
1304 
1305 LogicalResult StoreOp::verify() { return verifyMemoryAccessOp(*this); }
1306 
1307 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
1308  return parseMemoryAccessOp(parser, result);
1309 }
1310 
1311 void StoreOp::print(OpAsmPrinter &p) { return printMemoryAccessOp(p, *this); }
1312 
1313 bool JoinOp::isControl() { return true; }
1314 
1315 ParseResult JoinOp::parse(OpAsmParser &parser, OperationState &result) {
1316  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1317  SmallVector<Type> types;
1318 
1319  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1320  if (parser.parseOperandList(operands) ||
1321  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1322  parser.parseTypeList(types))
1323  return failure();
1324 
1325  if (parser.resolveOperands(operands, types, allOperandLoc, result.operands))
1326  return failure();
1327 
1328  result.addTypes(NoneType::get(result.getContext()));
1329  return success();
1330 }
1331 
1332 void JoinOp::print(OpAsmPrinter &p) {
1333  p << " " << getData();
1334  p.printOptionalAttrDict((*this)->getAttrs(), {"control"});
1335  p << " : " << getData().getTypes();
1336 }
1337 
1338 /// Based on mlir::func::CallOp::verifySymbolUses
1339 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1340  // Check that the module attribute was specified.
1341  auto fnAttr = this->getModuleAttr();
1342  assert(fnAttr && "requires a 'module' symbol reference attribute");
1343 
1344  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
1345  if (!fn)
1346  return emitOpError() << "'" << fnAttr.getValue()
1347  << "' does not reference a valid handshake function";
1348 
1349  // Verify that the operand and result types match the callee.
1350  auto fnType = fn.getFunctionType();
1351  if (fnType.getNumInputs() != getNumOperands())
1352  return emitOpError(
1353  "incorrect number of operands for the referenced handshake function");
1354 
1355  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
1356  if (getOperand(i).getType() != fnType.getInput(i))
1357  return emitOpError("operand type mismatch: expected operand type ")
1358  << fnType.getInput(i) << ", but provided "
1359  << getOperand(i).getType() << " for operand number " << i;
1360 
1361  if (fnType.getNumResults() != getNumResults())
1362  return emitOpError(
1363  "incorrect number of results for the referenced handshake function");
1364 
1365  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
1366  if (getResult(i).getType() != fnType.getResult(i))
1367  return emitOpError("result type mismatch: expected result type ")
1368  << fnType.getResult(i) << ", but provided "
1369  << getResult(i).getType() << " for result number " << i;
1370 
1371  return success();
1372 }
1373 
1374 FunctionType InstanceOp::getModuleType() {
1375  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
1376 }
1377 
1378 ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1379  OpAsmParser::UnresolvedOperand tuple;
1380  TupleType type;
1381 
1382  if (parser.parseOperand(tuple) ||
1383  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1384  parser.parseType(type))
1385  return failure();
1386 
1387  if (parser.resolveOperand(tuple, type, result.operands))
1388  return failure();
1389 
1390  result.addTypes(type.getTypes());
1391 
1392  return success();
1393 }
1394 
1395 void UnpackOp::print(OpAsmPrinter &p) {
1396  p << " " << getInput();
1397  p.printOptionalAttrDict((*this)->getAttrs());
1398  p << " : " << getInput().getType();
1399 }
1400 
1401 ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1402  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1403  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1404  TupleType type;
1405 
1406  if (parser.parseOperandList(operands) ||
1407  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1408  parser.parseType(type))
1409  return failure();
1410 
1411  if (parser.resolveOperands(operands, type.getTypes(), allOperandLoc,
1412  result.operands))
1413  return failure();
1414 
1415  result.addTypes(type);
1416 
1417  return success();
1418 }
1419 
1420 void PackOp::print(OpAsmPrinter &p) {
1421  p << " " << getInputs();
1422  p.printOptionalAttrDict((*this)->getAttrs());
1423  p << " : " << getResult().getType();
1424 }
1425 
1426 //===----------------------------------------------------------------------===//
1427 // TableGen'd op method definitions
1428 //===----------------------------------------------------------------------===//
1429 
1430 LogicalResult ReturnOp::verify() {
1431  auto *parent = (*this)->getParentOp();
1432  auto function = dyn_cast<handshake::FuncOp>(parent);
1433  if (!function)
1434  return emitOpError("must have a handshake.func parent");
1435 
1436  // The operand number and types must match the function signature.
1437  const auto &results = function.getResultTypes();
1438  if (getNumOperands() != results.size())
1439  return emitOpError("has ")
1440  << getNumOperands() << " operands, but enclosing function returns "
1441  << results.size();
1442 
1443  for (unsigned i = 0, e = results.size(); i != e; ++i)
1444  if (getOperand(i).getType() != results[i])
1445  return emitError() << "type of return operand " << i << " ("
1446  << getOperand(i).getType()
1447  << ") doesn't match function result type ("
1448  << results[i] << ")";
1449 
1450  return success();
1451 }
1452 
1453 #define GET_OP_CLASSES
1454 #include "circt/Dialect/Handshake/Handshake.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition: CHIRRTL.cpp:29
static StringRef getOperandName(Value operand)
int32_t width
Definition: FIRRTL.cpp:36
static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal, uint64_t numOperands)
Verifies whether an indexing value is wide enough to index into a provided number of operands.
static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result)
static ParseResult parseSostOperation(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, OperationState &result, int &size, Type &type, bool explicitSize)
static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op)
static ParseResult parseFuncOpArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Type > &resTypes, SmallVectorImpl< DictionaryAttr > &resAttrs)
Parses a FuncOp signature using mlir::function_interface_impl::parseFunctionSignature while getting a...
static Value getDematerialized(Value v)
Returns a dematerialized version of the value 'v', defined as the source of the value before passing ...
static bool isControlCheckTypeAndOperand(Type dataType, Value operand)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores, unsigned idx)
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v)
static LogicalResult verifyMemoryAccessOp(TMemoryOp op)
static void addStringToStringArrayAttr(Builder &builder, Operation *op, StringRef attrName, StringAttr str)
Helper function for appending a string to an array attribute, and rewriting the attribute back to the...
static std::string defaultOperandName(unsigned int idx)
static SmallVector< Attribute > getFuncOpNames(Builder &builder, unsigned cnt, StringRef prefix)
Generates names for a handshake.func input and output arguments, based on the number of args as well ...
static std::string getMemoryOperandName(unsigned nStores, unsigned idx)
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
static ParseResult parseMemoryAccessOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition: SVOps.cpp:2443
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
bool isControlOpImpl(Operation *op)
Default implementation for checking whether an operation is a control operation.
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition: HWOps.cpp:515
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21