CIRCT  19.0.0git
HandshakeOps.cpp
Go to the documentation of this file.
1 //===- HandshakeOps.cpp - Handshake MLIR Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the declaration of the Handshake operations struct.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "circt/Support/LLVM.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/IntegerSet.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/Interfaces/FunctionImplementation.h"
28 #include "mlir/Transforms/InliningUtils.h"
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <set>
34 
35 using namespace circt;
36 using namespace circt::handshake;
37 
38 namespace circt {
39 namespace handshake {
40 #include "circt/Dialect/Handshake/HandshakeCanonicalization.h.inc"
41 
42 bool isControlOpImpl(Operation *op) {
43  if (auto sostInterface = dyn_cast<SOSTInterface>(op); sostInterface)
44  return sostInterface.sostIsControl();
45 
46  return false;
47 }
48 
49 } // namespace handshake
50 } // namespace circt
51 
/// Builds the default name of the data operand at position `idx`, e.g. "in0".
static std::string defaultOperandName(unsigned int idx) {
  std::string name = "in";
  name += std::to_string(idx);
  return name;
}
55 
56 static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v) {
57  if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
58  return failure();
59  return success();
60 }
61 
62 static ParseResult
63 parseSostOperation(OpAsmParser &parser,
64  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
65  OperationState &result, int &size, Type &type,
66  bool explicitSize) {
67  if (explicitSize)
68  if (parseIntInSquareBrackets(parser, size))
69  return failure();
70 
71  if (parser.parseOperandList(operands) ||
72  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
73  parser.parseType(type))
74  return failure();
75 
76  if (!explicitSize)
77  size = operands.size();
78  return success();
79 }
80 
81 /// Verifies whether an indexing value is wide enough to index into a provided
82 /// number of operands.
83 static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal,
84  uint64_t numOperands) {
85  auto indexType = indexVal.getType();
86  unsigned indexWidth;
87 
88  // Determine the bitwidth of the indexing value
89  if (auto integerType = indexType.dyn_cast<IntegerType>())
90  indexWidth = integerType.getWidth();
91  else if (indexType.isIndex())
92  indexWidth = IndexType::kInternalStorageBitWidth;
93  else
94  return op->emitError("unsupported type for indexing value: ") << indexType;
95 
96  // Check whether the bitwidth can support the provided number of operands
97  if (indexWidth < 64) {
98  uint64_t maxNumOperands = (uint64_t)1 << indexWidth;
99  if (numOperands > maxNumOperands)
100  return op->emitError("bitwidth of indexing value is ")
101  << indexWidth << ", which can index into " << maxNumOperands
102  << " operands, but found " << numOperands << " operands";
103  }
104  return success();
105 }
106 
107 static bool isControlCheckTypeAndOperand(Type dataType, Value operand) {
108  // The operation is a control operation if its operand data type is a
109  // NoneType.
110  if (dataType.isa<NoneType>())
111  return true;
112 
113  // Otherwise, the operation is a control operation if the operation's
114  // operand originates from the control network
115  auto *defOp = operand.getDefiningOp();
116  return isa_and_nonnull<ControlMergeOp>(defOp) &&
117  operand == defOp->getResult(0);
118 }
119 
120 template <typename TMemOp>
121 llvm::SmallVector<handshake::MemLoadInterface> getLoadPorts(TMemOp op) {
122  llvm::SmallVector<MemLoadInterface> ports;
123  // Memory interface refresher:
124  // Operands:
125  // all stores (stdata1, staddr1, stdata2, staddr2, ...)
126  // then all loads (ldaddr1, ldaddr2,...)
127  // Outputs: load addresses (lddata1, lddata2, ...), followed by all none
128  // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
129  unsigned stCount = op.getStCount();
130  unsigned ldCount = op.getLdCount();
131  for (unsigned i = 0, e = ldCount; i != e; ++i) {
132  MemLoadInterface ldif;
133  ldif.index = i;
134  ldif.addressIn = op.getInputs()[stCount * 2 + i];
135  ldif.dataOut = op.getResult(i);
136  ldif.doneOut = op.getResult(ldCount + stCount + i);
137  ports.push_back(ldif);
138  }
139  return ports;
140 }
141 
142 template <typename TMemOp>
143 llvm::SmallVector<handshake::MemStoreInterface> getStorePorts(TMemOp op) {
144  llvm::SmallVector<MemStoreInterface> ports;
145  // Memory interface refresher:
146  // Operands:
147  // all stores (stdata1, staddr1, stdata2, staddr2, ...)
148  // then all loads (ldaddr1, ldaddr2,...)
149  // Outputs: load data (lddata1, lddata2, ...), followed by all none
150  // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
151  unsigned ldCount = op.getLdCount();
152  for (unsigned i = 0, e = op.getStCount(); i != e; ++i) {
153  MemStoreInterface stif;
154  stif.index = i;
155  stif.dataIn = op.getInputs()[i * 2];
156  stif.addressIn = op.getInputs()[i * 2 + 1];
157  stif.doneOut = op.getResult(ldCount + i);
158  ports.push_back(stif);
159  }
160  return ports;
161 }
162 
163 unsigned ForkOp::getSize() { return getResults().size(); }
164 
165 static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result) {
166  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
167  Type type;
168  ArrayRef<Type> operandTypes(type);
169  SmallVector<Type, 1> resultTypes;
170  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
171  int size;
172  if (parseSostOperation(parser, allOperands, result, size, type, true))
173  return failure();
174 
175  resultTypes.assign(size, type);
176  result.addTypes(resultTypes);
177  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
178  result.operands))
179  return failure();
180  return success();
181 }
182 
183 ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
184  return parseForkOp(parser, result);
185 }
186 
187 void ForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
188 
189 namespace {
190 
191 struct EliminateUnusedForkResultsPattern : mlir::OpRewritePattern<ForkOp> {
193 
194  LogicalResult matchAndRewrite(ForkOp op,
195  PatternRewriter &rewriter) const override {
196  std::set<unsigned> unusedIndexes;
197 
198  for (auto res : llvm::enumerate(op.getResults()))
199  if (res.value().getUses().empty())
200  unusedIndexes.insert(res.index());
201 
202  if (unusedIndexes.empty())
203  return failure();
204 
205  // Create a new fork op, dropping the unused results.
206  rewriter.setInsertionPoint(op);
207  auto operand = op.getOperand();
208  auto newFork = rewriter.create<ForkOp>(
209  op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
210  rewriter.modifyOpInPlace(op, [&] {
211  unsigned i = 0;
212  for (auto oldRes : llvm::enumerate(op.getResults()))
213  if (unusedIndexes.count(oldRes.index()) == 0)
214  oldRes.value().replaceAllUsesWith(newFork.getResults()[i++]);
215  });
216  rewriter.eraseOp(op);
217  return success();
218  }
219 };
220 
221 struct EliminateForkToForkPattern : mlir::OpRewritePattern<ForkOp> {
223 
224  LogicalResult matchAndRewrite(ForkOp op,
225  PatternRewriter &rewriter) const override {
226  auto parentForkOp = op.getOperand().getDefiningOp<ForkOp>();
227  if (!parentForkOp)
228  return failure();
229 
230  /// Create the fork with as many outputs as the two source forks.
231  /// Keeping the op.operand() output may or may not be redundant (dependning
232  /// on if op is the single user of the value), but we'll let
233  /// EliminateUnusedForkResultsPattern apply in that case.
234  unsigned totalNumOuts = op.getSize() + parentForkOp.getSize();
235  rewriter.modifyOpInPlace(parentForkOp, [&] {
236  /// Create a new parent fork op which produces all of the fork outputs and
237  /// replace all of the uses of the old results.
238  auto newParentForkOp = rewriter.create<ForkOp>(
239  parentForkOp.getLoc(), parentForkOp.getOperand(), totalNumOuts);
240 
241  for (auto it :
242  llvm::zip(parentForkOp->getResults(), newParentForkOp.getResults()))
243  std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
244 
245  /// Replace the results of the matches fork op with the corresponding
246  /// results of the new parent fork op.
247  rewriter.replaceOp(op,
248  newParentForkOp.getResults().take_back(op.getSize()));
249  });
250  rewriter.eraseOp(parentForkOp);
251  return success();
252  }
253 };
254 
255 } // namespace
256 
257 void handshake::ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
258  MLIRContext *context) {
259  results.insert<circt::handshake::EliminateSimpleForksPattern,
260  EliminateForkToForkPattern, EliminateUnusedForkResultsPattern>(
261  context);
262 }
263 
264 unsigned LazyForkOp::getSize() { return getResults().size(); }
265 
266 bool LazyForkOp::sostIsControl() {
267  return isControlCheckTypeAndOperand(getDataType(), getOperand());
268 }
269 
270 ParseResult LazyForkOp::parse(OpAsmParser &parser, OperationState &result) {
271  return parseForkOp(parser, result);
272 }
273 
274 void LazyForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
275 
276 ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
277  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
278  Type type;
279  ArrayRef<Type> operandTypes(type);
280  SmallVector<Type, 1> resultTypes, dataOperandsTypes;
281  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
282  int size;
283  if (parseSostOperation(parser, allOperands, result, size, type, false))
284  return failure();
285 
286  dataOperandsTypes.assign(size, type);
287  resultTypes.push_back(type);
288  result.addTypes(resultTypes);
289  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
290  result.operands))
291  return failure();
292  return success();
293 }
294 
295 void MergeOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
296 
297 void MergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
298  MLIRContext *context) {
299  results.insert<circt::handshake::EliminateSimpleMergesPattern>(context);
300 }
301 
302 /// Returns a dematerialized version of the value 'v', defined as the source of
303 /// the value before passing through a buffer or fork operation.
304 static Value getDematerialized(Value v) {
305  Operation *parentOp = v.getDefiningOp();
306  if (!parentOp)
307  return v;
308 
309  return llvm::TypeSwitch<Operation *, Value>(parentOp)
310  .Case<ForkOp>(
311  [&](ForkOp op) { return getDematerialized(op.getOperand()); })
312  .Case<BufferOp>(
313  [&](BufferOp op) { return getDematerialized(op.getOperand()); })
314  .Default([&](auto) { return v; });
315 }
316 
317 namespace {
318 
319 /// Eliminates muxes with identical data inputs. Data inputs are inspected as
320 /// their dematerialized versions. This has the side effect of any subsequently
321 /// unused buffers are DCE'd and forks are optimized to be narrower.
322 struct EliminateSimpleMuxesPattern : mlir::OpRewritePattern<MuxOp> {
324  LogicalResult matchAndRewrite(MuxOp op,
325  PatternRewriter &rewriter) const override {
326  Value firstDataOperand = getDematerialized(op.getDataOperands()[0]);
327  if (!llvm::all_of(op.getDataOperands(), [&](Value operand) {
328  return getDematerialized(operand) == firstDataOperand;
329  }))
330  return failure();
331  rewriter.replaceOp(op, firstDataOperand);
332  return success();
333  }
334 };
335 
336 struct EliminateUnaryMuxesPattern : OpRewritePattern<MuxOp> {
338  LogicalResult matchAndRewrite(MuxOp op,
339  PatternRewriter &rewriter) const override {
340  if (op.getDataOperands().size() != 1)
341  return failure();
342 
343  rewriter.replaceOp(op, op.getDataOperands()[0]);
344  return success();
345  }
346 };
347 
348 struct EliminateCBranchIntoMuxPattern : OpRewritePattern<MuxOp> {
350  LogicalResult matchAndRewrite(MuxOp op,
351  PatternRewriter &rewriter) const override {
352 
353  auto dataOperands = op.getDataOperands();
354  if (dataOperands.size() != 2)
355  return failure();
356 
357  // Both data operands must originate from the same cbranch
358  ConditionalBranchOp firstParentCBranch =
359  dataOperands[0].getDefiningOp<ConditionalBranchOp>();
360  if (!firstParentCBranch)
361  return failure();
362  auto secondParentCBranch =
363  dataOperands[1].getDefiningOp<ConditionalBranchOp>();
364  if (!secondParentCBranch || firstParentCBranch != secondParentCBranch)
365  return failure();
366 
367  rewriter.modifyOpInPlace(firstParentCBranch, [&] {
368  // Replace uses of the mux's output with cbranch's data input
369  rewriter.replaceOp(op, firstParentCBranch.getDataOperand());
370  });
371 
372  return success();
373  }
374 };
375 
376 } // namespace
377 
378 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
379  MLIRContext *context) {
380  results.insert<EliminateSimpleMuxesPattern, EliminateUnaryMuxesPattern,
381  EliminateCBranchIntoMuxPattern>(context);
382 }
383 
384 LogicalResult
385 MuxOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
386  ValueRange operands, DictionaryAttr attributes,
387  mlir::OpaqueProperties properties,
388  mlir::RegionRange regions,
389  SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
390  // MuxOp must have at least one data operand (in addition to the select
391  // operand)
392  if (operands.size() < 2)
393  return failure();
394  // Result type is type of any data operand
395  inferredReturnTypes.push_back(operands[1].getType());
396  return success();
397 }
398 
399 bool MuxOp::isControl() { return getResult().getType().isa<NoneType>(); }
400 
401 std::string handshake::MuxOp::getOperandName(unsigned int idx) {
402  return idx == 0 ? "select" : defaultOperandName(idx - 1);
403 }
404 
405 ParseResult MuxOp::parse(OpAsmParser &parser, OperationState &result) {
406  OpAsmParser::UnresolvedOperand selectOperand;
407  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
408  Type selectType, dataType;
409  SmallVector<Type, 1> dataOperandsTypes;
410  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
411  if (parser.parseOperand(selectOperand) || parser.parseLSquare() ||
412  parser.parseOperandList(allOperands) || parser.parseRSquare() ||
413  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
414  parser.parseType(selectType) || parser.parseComma() ||
415  parser.parseType(dataType))
416  return failure();
417 
418  int size = allOperands.size();
419  dataOperandsTypes.assign(size, dataType);
420  result.addTypes(dataType);
421  allOperands.insert(allOperands.begin(), selectOperand);
422  if (parser.resolveOperands(
423  allOperands,
424  llvm::concat<const Type>(ArrayRef<Type>(selectType),
425  ArrayRef<Type>(dataOperandsTypes)),
426  allOperandLoc, result.operands))
427  return failure();
428  return success();
429 }
430 
431 void MuxOp::print(OpAsmPrinter &p) {
432  Type selectType = getSelectOperand().getType();
433  auto ops = getOperands();
434  p << ' ' << ops.front();
435  p << " [";
436  p.printOperands(ops.drop_front());
437  p << "]";
438  p.printOptionalAttrDict((*this)->getAttrs());
439  p << " : " << selectType << ", " << getResult().getType();
440 }
441 
442 LogicalResult MuxOp::verify() {
443  return verifyIndexWideEnough(*this, getSelectOperand(),
444  getDataOperands().size());
445 }
446 
447 std::string handshake::ControlMergeOp::getResultName(unsigned int idx) {
448  assert(idx == 0 || idx == 1);
449  return idx == 0 ? "dataOut" : "index";
450 }
451 
452 ParseResult ControlMergeOp::parse(OpAsmParser &parser, OperationState &result) {
453  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
454  Type resultType, indexType;
455  SmallVector<Type> resultTypes, dataOperandsTypes;
456  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
457  int size;
458  if (parseSostOperation(parser, allOperands, result, size, resultType, false))
459  return failure();
460  // Parse type of index result
461  if (parser.parseComma() || parser.parseType(indexType))
462  return failure();
463 
464  dataOperandsTypes.assign(size, resultType);
465  resultTypes.push_back(resultType);
466  resultTypes.push_back(indexType);
467  result.addTypes(resultTypes);
468  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
469  result.operands))
470  return failure();
471  return success();
472 }
473 
474 void ControlMergeOp::print(OpAsmPrinter &p) {
475  sostPrint(p, false);
476  // Print type of index result
477  p << ", " << getIndex().getType();
478 }
479 
480 LogicalResult ControlMergeOp::verify() {
481  auto operands = getOperands();
482  if (operands.empty())
483  return emitOpError("operation must have at least one operand");
484  if (operands[0].getType() != getResult().getType())
485  return emitOpError("type of first result should match type of operands");
486  return verifyIndexWideEnough(*this, getIndex(), getNumOperands());
487 }
488 
489 LogicalResult FuncOp::verify() {
490  // If this function is external there is nothing to do.
491  if (isExternal())
492  return success();
493 
494  // Verify that the argument list of the function and the arg list of the
495  // entry block line up. The trait already verified that the number of
496  // arguments is the same between the signature and the block.
497  auto fnInputTypes = getArgumentTypes();
498  Block &entryBlock = front();
499 
500  for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
501  if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
502  return emitOpError("type of entry block argument #")
503  << i << '(' << entryBlock.getArgument(i).getType()
504  << ") must match the type of the corresponding argument in "
505  << "function signature(" << fnInputTypes[i] << ')';
506 
507  // Verify that we have a name for each argument and result of this function.
508  auto verifyPortNameAttr = [&](StringRef attrName,
509  unsigned numIOs) -> LogicalResult {
510  auto portNamesAttr = (*this)->getAttrOfType<ArrayAttr>(attrName);
511 
512  if (!portNamesAttr)
513  return emitOpError() << "expected attribute '" << attrName << "'.";
514 
515  auto portNames = portNamesAttr.getValue();
516  if (portNames.size() != numIOs)
517  return emitOpError() << "attribute '" << attrName << "' has "
518  << portNames.size()
519  << " entries but is expected to have " << numIOs
520  << ".";
521 
522  if (llvm::any_of(portNames,
523  [&](Attribute attr) { return !attr.isa<StringAttr>(); }))
524  return emitOpError() << "expected all entries in attribute '" << attrName
525  << "' to be strings.";
526 
527  return success();
528  };
529  if (failed(verifyPortNameAttr("argNames", getNumArguments())))
530  return failure();
531  if (failed(verifyPortNameAttr("resNames", getNumResults())))
532  return failure();
533 
534  // Verify that all memrefs have a corresponding extmemory operation
535  for (auto arg : entryBlock.getArguments()) {
536  if (!arg.getType().isa<MemRefType>())
537  continue;
538  if (arg.getUsers().empty() ||
539  !isa<ExternalMemoryOp>(*arg.getUsers().begin()))
540  return emitOpError("expected that block argument #")
541  << arg.getArgNumber() << " is used by an 'extmemory' operation";
542  }
543 
544  return success();
545 }
546 
547 /// Parses a FuncOp signature using
548 /// mlir::function_interface_impl::parseFunctionSignature while getting access
549 /// to the parsed SSA names to store as attributes.
550 static ParseResult
551 parseFuncOpArgs(OpAsmParser &parser,
552  SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
553  SmallVectorImpl<Type> &resTypes,
554  SmallVectorImpl<DictionaryAttr> &resAttrs) {
555  bool isVariadic;
556  if (mlir::function_interface_impl::parseFunctionSignature(
557  parser, /*allowVariadic=*/true, entryArgs, isVariadic, resTypes,
558  resAttrs)
559  .failed())
560  return failure();
561 
562  return success();
563 }
564 
565 /// Generates names for a handshake.func input and output arguments, based on
566 /// the number of args as well as a prefix.
567 static SmallVector<Attribute> getFuncOpNames(Builder &builder, unsigned cnt,
568  StringRef prefix) {
569  SmallVector<Attribute> resNames;
570  for (unsigned i = 0; i < cnt; ++i)
571  resNames.push_back(builder.getStringAttr(prefix + std::to_string(i)));
572  return resNames;
573 }
574 
575 void handshake::FuncOp::build(OpBuilder &builder, OperationState &state,
576  StringRef name, FunctionType type,
577  ArrayRef<NamedAttribute> attrs) {
578  state.addAttribute(SymbolTable::getSymbolAttrName(),
579  builder.getStringAttr(name));
580  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
581  TypeAttr::get(type));
582  state.attributes.append(attrs.begin(), attrs.end());
583 
584  if (const auto *argNamesAttrIt = llvm::find_if(
585  attrs, [&](auto attr) { return attr.getName() == "argNames"; });
586  argNamesAttrIt == attrs.end())
587  state.addAttribute("argNames", builder.getArrayAttr({}));
588 
589  if (llvm::find_if(attrs, [&](auto attr) {
590  return attr.getName() == "resNames";
591  }) == attrs.end())
592  state.addAttribute("resNames", builder.getArrayAttr({}));
593 
594  state.addRegion();
595 }
596 
597 /// Helper function for appending a string to an array attribute, and
598 /// rewriting the attribute back to the operation.
599 static void addStringToStringArrayAttr(Builder &builder, Operation *op,
600  StringRef attrName, StringAttr str) {
601  llvm::SmallVector<Attribute> attrs;
602  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
603  std::back_inserter(attrs));
604  attrs.push_back(str);
605  op->setAttr(attrName, builder.getArrayAttr(attrs));
606 }
607 
608 void handshake::FuncOp::resolveArgAndResNames() {
609  Builder builder(getContext());
610 
611  /// Generate a set of fallback names. These are used in case names are
612  /// missing from the currently set arg- and res name attributes.
613  auto fallbackArgNames = getFuncOpNames(builder, getNumArguments(), "in");
614  auto fallbackResNames = getFuncOpNames(builder, getNumResults(), "out");
615  auto argNames = getArgNames().getValue();
616  auto resNames = getResNames().getValue();
617 
618  /// Use fallback names where actual names are missing.
619  auto resolveNames = [&](auto &fallbackNames, auto &actualNames,
620  StringRef attrName) {
621  for (auto fallbackName : llvm::enumerate(fallbackNames)) {
622  if (actualNames.size() <= fallbackName.index())
624  builder, this->getOperation(), attrName,
625  fallbackName.value().template cast<StringAttr>());
626  }
627  };
628  resolveNames(fallbackArgNames, argNames, "argNames");
629  resolveNames(fallbackResNames, resNames, "resNames");
630 }
631 
632 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
633  auto &builder = parser.getBuilder();
634  StringAttr nameAttr;
635  SmallVector<OpAsmParser::Argument> args;
636  SmallVector<Type> resTypes;
637  SmallVector<DictionaryAttr> resAttributes;
638  SmallVector<Attribute> argNames;
639 
640  // Parse visibility.
641  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
642 
643  // Parse signature
644  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
645  result.attributes) ||
646  parseFuncOpArgs(parser, args, resTypes, resAttributes))
647  return failure();
648  mlir::function_interface_impl::addArgAndResultAttrs(
649  builder, result, args, resAttributes,
650  handshake::FuncOp::getArgAttrsAttrName(result.name),
651  handshake::FuncOp::getResAttrsAttrName(result.name));
652 
653  // Set function type
654  SmallVector<Type> argTypes;
655  for (auto arg : args)
656  argTypes.push_back(arg.type);
657 
658  result.addAttribute(
659  handshake::FuncOp::getFunctionTypeAttrName(result.name),
660  TypeAttr::get(builder.getFunctionType(argTypes, resTypes)));
661 
662  // Determine the names of the arguments. If no SSA values are present, use
663  // fallback names.
664  bool noSSANames =
665  llvm::any_of(args, [](auto arg) { return arg.ssaName.name.empty(); });
666  if (noSSANames) {
667  argNames = getFuncOpNames(builder, args.size(), "in");
668  } else {
669  llvm::transform(args, std::back_inserter(argNames), [&](auto arg) {
670  return builder.getStringAttr(arg.ssaName.name.drop_front());
671  });
672  }
673 
674  // Parse attributes
675  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
676  return failure();
677 
678  // If argNames and resNames wasn't provided manually, infer argNames attribute
679  // from the parsed SSA names and resNames from our naming convention.
680  if (!result.attributes.get("argNames"))
681  result.addAttribute("argNames", builder.getArrayAttr(argNames));
682  if (!result.attributes.get("resNames")) {
683  auto resNames = getFuncOpNames(builder, resTypes.size(), "out");
684  result.addAttribute("resNames", builder.getArrayAttr(resNames));
685  }
686 
687  // Parse the optional function body. The printer will not print the body if
688  // its empty, so disallow parsing of empty body in the parser.
689  auto *body = result.addRegion();
690  llvm::SMLoc loc = parser.getCurrentLocation();
691  auto parseResult = parser.parseOptionalRegion(*body, args,
692  /*enableNameShadowing=*/false);
693  if (!parseResult.has_value())
694  return success();
695 
696  if (failed(*parseResult))
697  return failure();
698  // Function body was parsed, make sure its not empty.
699  if (body->empty())
700  return parser.emitError(loc, "expected non-empty function body");
701 
702  // If a body was parsed, the arg and res names need to be resolved
703  return success();
704 }
705 
706 void FuncOp::print(OpAsmPrinter &p) {
707  mlir::function_interface_impl::printFunctionOp(
708  p, *this, /*isVariadic=*/true, getFunctionTypeAttrName(),
709  getArgAttrsAttrName(), getResAttrsAttrName());
710 }
711 
712 namespace {
713 struct EliminateSimpleControlMergesPattern
714  : mlir::OpRewritePattern<ControlMergeOp> {
716 
717  LogicalResult matchAndRewrite(ControlMergeOp op,
718  PatternRewriter &rewriter) const override;
719 };
720 } // namespace
721 
722 LogicalResult EliminateSimpleControlMergesPattern::matchAndRewrite(
723  ControlMergeOp op, PatternRewriter &rewriter) const {
724  auto dataResult = op.getResult();
725  auto choiceResult = op.getIndex();
726  auto choiceUnused = choiceResult.use_empty();
727  if (!choiceUnused && !choiceResult.hasOneUse())
728  return failure();
729 
730  Operation *choiceUser = nullptr;
731  if (choiceResult.hasOneUse()) {
732  choiceUser = choiceResult.getUses().begin().getUser();
733  if (!isa<SinkOp>(choiceUser))
734  return failure();
735  }
736 
737  auto merge = rewriter.create<MergeOp>(op.getLoc(), op.getDataOperands());
738 
739  for (auto &use : llvm::make_early_inc_range(dataResult.getUses())) {
740  auto *user = use.getOwner();
741  rewriter.modifyOpInPlace(
742  user, [&]() { user->setOperand(use.getOperandNumber(), merge); });
743  }
744 
745  if (choiceUnused) {
746  rewriter.eraseOp(op);
747  return success();
748  }
749 
750  rewriter.eraseOp(choiceUser);
751  rewriter.eraseOp(op);
752  return success();
753 }
754 
755 void ControlMergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
756  MLIRContext *context) {
757  results.insert<EliminateSimpleControlMergesPattern>(context);
758 }
759 
760 bool BranchOp::sostIsControl() {
761  return isControlCheckTypeAndOperand(getDataType(), getOperand());
762 }
763 
764 void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
765  MLIRContext *context) {
766  results.insert<circt::handshake::EliminateSimpleBranchesPattern>(context);
767 }
768 
769 ParseResult BranchOp::parse(OpAsmParser &parser, OperationState &result) {
770  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
771  Type type;
772  ArrayRef<Type> operandTypes(type);
773  SmallVector<Type, 1> dataOperandsTypes;
774  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
775  int size;
776  if (parseSostOperation(parser, allOperands, result, size, type, false))
777  return failure();
778 
779  dataOperandsTypes.assign(size, type);
780  result.addTypes({type});
781  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
782  result.operands))
783  return failure();
784  return success();
785 }
786 
787 void BranchOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
788 
789 ParseResult ConditionalBranchOp::parse(OpAsmParser &parser,
790  OperationState &result) {
791  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
792  Type dataType;
793  SmallVector<Type> operandTypes;
794  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
795  if (parser.parseOperandList(allOperands) ||
796  parser.parseOptionalAttrDict(result.attributes) ||
797  parser.parseColonType(dataType))
798  return failure();
799 
800  if (allOperands.size() != 2)
801  return parser.emitError(parser.getCurrentLocation(),
802  "Expected exactly 2 operands");
803 
804  result.addTypes({dataType, dataType});
805  operandTypes.push_back(IntegerType::get(parser.getContext(), 1));
806  operandTypes.push_back(dataType);
807  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
808  result.operands))
809  return failure();
810 
811  return success();
812 }
813 
814 void ConditionalBranchOp::print(OpAsmPrinter &p) {
815  Type type = getDataOperand().getType();
816  p << " " << getOperands();
817  p.printOptionalAttrDict((*this)->getAttrs());
818  p << " : " << type;
819 }
820 
821 std::string handshake::ConditionalBranchOp::getOperandName(unsigned int idx) {
822  assert(idx == 0 || idx == 1);
823  return idx == 0 ? "cond" : "data";
824 }
825 
826 std::string handshake::ConditionalBranchOp::getResultName(unsigned int idx) {
827  assert(idx == 0 || idx == 1);
828  return idx == ConditionalBranchOp::falseIndex ? "outFalse" : "outTrue";
829 }
830 
831 bool ConditionalBranchOp::isControl() {
832  return isControlCheckTypeAndOperand(getDataOperand().getType(),
833  getDataOperand());
834 }
835 
836 ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
837  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
838  Type type;
839  ArrayRef<Type> operandTypes(type);
840  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
841  int size;
842  if (parseSostOperation(parser, allOperands, result, size, type, false))
843  return failure();
844 
845  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
846  result.operands))
847  return failure();
848  return success();
849 }
850 
851 void SinkOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
852 
853 std::string handshake::ConstantOp::getOperandName(unsigned int idx) {
854  assert(idx == 0);
855  return "ctrl";
856 }
857 
858 Type SourceOp::getDataType() { return getResult().getType(); }
859 unsigned SourceOp::getSize() { return 1; }
860 
861 ParseResult SourceOp::parse(OpAsmParser &parser, OperationState &result) {
862  if (parser.parseOptionalAttrDict(result.attributes))
863  return failure();
864  result.addTypes(NoneType::get(result.getContext()));
865  return success();
866 }
867 
868 void SourceOp::print(OpAsmPrinter &p) {
869  p.printOptionalAttrDict((*this)->getAttrs());
870 }
871 
872 LogicalResult ConstantOp::verify() {
873  // Verify that the type of the provided value is equal to the result type.
874  auto typedValue = getValue().dyn_cast<mlir::TypedAttr>();
875  if (!typedValue)
876  return emitOpError("constant value must be a typed attribute; value is ")
877  << getValue();
878  if (typedValue.getType() != getResult().getType())
879  return emitOpError() << "constant value type " << typedValue.getType()
880  << " differs from operation result type "
881  << getResult().getType();
882 
883  return success();
884 }
885 
886 void handshake::ConstantOp::getCanonicalizationPatterns(
887  RewritePatternSet &results, MLIRContext *context) {
888  results.insert<circt::handshake::EliminateSunkConstantsPattern>(context);
889 }
890 
891 LogicalResult BufferOp::verify() {
892  // Verify that exactly 'size' number of initial values have been provided, if
893  // an initializer list have been provided.
894  if (auto initVals = getInitValues()) {
895  if (!isSequential())
896  return emitOpError()
897  << "only bufferType buffers are allowed to have initial values.";
898 
899  auto nInits = initVals->size();
900  if (nInits != getSize())
901  return emitOpError() << "expected " << getSize()
902  << " init values but got " << nInits << ".";
903  }
904 
905  return success();
906 }
907 
908 void handshake::BufferOp::getCanonicalizationPatterns(
909  RewritePatternSet &results, MLIRContext *context) {
910  results.insert<circt::handshake::EliminateSunkBuffersPattern>(context);
911 }
912 
913 unsigned BufferOp::getSize() {
914  return (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
915 }
916 
917 ParseResult BufferOp::parse(OpAsmParser &parser, OperationState &result) {
918  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
919  Type type;
920  ArrayRef<Type> operandTypes(type);
921  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
922  int slots;
923  if (parseIntInSquareBrackets(parser, slots))
924  return failure();
925 
926  auto bufferTypeAttr = BufferTypeEnumAttr::parse(parser, {});
927  if (!bufferTypeAttr)
928  return failure();
929 
930  result.addAttribute(
931  "slots",
932  IntegerAttr::get(IntegerType::get(result.getContext(), 32), slots));
933  result.addAttribute("bufferType", bufferTypeAttr);
934 
935  if (parser.parseOperandList(allOperands) ||
936  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
937  parser.parseType(type))
938  return failure();
939 
940  result.addTypes({type});
941  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
942  result.operands))
943  return failure();
944  return success();
945 }
946 
947 void BufferOp::print(OpAsmPrinter &p) {
948  int size =
949  (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
950  p << " [" << size << "]";
951  p << " " << stringifyEnum(getBufferType());
952  p << " " << (*this)->getOperands();
953  p.printOptionalAttrDict((*this)->getAttrs(), {"slots", "bufferType"});
954  p << " : " << (*this).getDataType();
955 }
956 
/// Returns the name of operand `idx` of a memory with `nStores` store ports.
/// Store ports come first as interleaved (data, address) pairs, followed by
/// one address per load port:
///   stData0, stAddr0, stData1, stAddr1, ..., ldAddr0, ldAddr1, ...
static std::string getMemoryOperandName(unsigned nStores, unsigned idx) {
  unsigned numStoreOperands = nStores * 2;
  if (idx >= numStoreOperands)
    return "ldAddr" + std::to_string(idx - numStoreOperands);

  std::string port = std::to_string(idx / 2);
  return (idx % 2 == 0 ? "stData" : "stAddr") + port;
}
969 
970 std::string handshake::MemoryOp::getOperandName(unsigned int idx) {
971  return getMemoryOperandName(getStCount(), idx);
972 }
973 
/// Returns the name of result `idx` of a memory with `nLoads` load ports and
/// `nStores` store ports. Results are laid out as
///   ldData0..ldData{nLoads-1}, stDone0..stDone{nStores-1},
///   ldDone0..ldDone{nLoads-1}.
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores,
                                       unsigned idx) {
  if (idx < nLoads)
    return "ldData" + std::to_string(idx);
  if (idx < nLoads + nStores)
    return "stDone" + std::to_string(idx - nLoads);
  return "ldDone" + std::to_string(idx - nLoads - nStores);
}
985 
986 std::string handshake::MemoryOp::getResultName(unsigned int idx) {
987  return getMemoryResultName(getLdCount(), getStCount(), idx);
988 }
989 
990 LogicalResult MemoryOp::verify() {
991  auto memrefType = getMemRefType();
992 
993  if (memrefType.getNumDynamicDims() != 0)
994  return emitOpError()
995  << "memref dimensions for handshake.memory must be static.";
996  if (memrefType.getShape().size() != 1)
997  return emitOpError() << "memref must have only a single dimension.";
998 
999  unsigned opStCount = getStCount();
1000  unsigned opLdCount = getLdCount();
1001  int addressCount = memrefType.getShape().size();
1002 
1003  auto inputType = getInputs().getType();
1004  auto outputType = getOutputs().getType();
1005  Type dataType = memrefType.getElementType();
1006 
1007  unsigned numOperands = static_cast<int>(getInputs().size());
1008  unsigned numResults = static_cast<int>(getOutputs().size());
1009  if (numOperands != (1 + addressCount) * opStCount + addressCount * opLdCount)
1010  return emitOpError("number of operands ")
1011  << numOperands << " does not match number expected of "
1012  << 2 * opStCount + opLdCount << " with " << addressCount
1013  << " address inputs per port";
1014 
1015  if (numResults != opStCount + 2 * opLdCount)
1016  return emitOpError("number of results ")
1017  << numResults << " does not match number expected of "
1018  << opStCount + 2 * opLdCount << " with " << addressCount
1019  << " address inputs per port";
1020 
1021  Type addressType = opStCount > 0 ? inputType[1] : inputType[0];
1022 
1023  for (unsigned i = 0; i < opStCount; i++) {
1024  if (inputType[2 * i] != dataType)
1025  return emitOpError("data type for store port ")
1026  << i << ":" << inputType[2 * i] << " doesn't match memory type "
1027  << dataType;
1028  if (inputType[2 * i + 1] != addressType)
1029  return emitOpError("address type for store port ")
1030  << i << ":" << inputType[2 * i + 1]
1031  << " doesn't match address type " << addressType;
1032  }
1033  for (unsigned i = 0; i < opLdCount; i++) {
1034  Type ldAddressType = inputType[2 * opStCount + i];
1035  if (ldAddressType != addressType)
1036  return emitOpError("address type for load port ")
1037  << i << ":" << ldAddressType << " doesn't match address type "
1038  << addressType;
1039  }
1040  for (unsigned i = 0; i < opLdCount; i++) {
1041  if (outputType[i] != dataType)
1042  return emitOpError("data type for load port ")
1043  << i << ":" << outputType[i] << " doesn't match memory type "
1044  << dataType;
1045  }
1046  for (unsigned i = 0; i < opStCount; i++) {
1047  Type syncType = outputType[opLdCount + i];
1048  if (!syncType.isa<NoneType>())
1049  return emitOpError("data type for sync port for store port ")
1050  << i << ":" << syncType << " is not 'none'";
1051  }
1052  for (unsigned i = 0; i < opLdCount; i++) {
1053  Type syncType = outputType[opLdCount + opStCount + i];
1054  if (!syncType.isa<NoneType>())
1055  return emitOpError("data type for sync port for load port ")
1056  << i << ":" << syncType << " is not 'none'";
1057  }
1058 
1059  return success();
1060 }
1061 
1062 std::string handshake::ExternalMemoryOp::getOperandName(unsigned int idx) {
1063  if (idx == 0)
1064  return "extmem";
1065 
1066  return getMemoryOperandName(getStCount(), idx - 1);
1067 }
1068 
1069 std::string handshake::ExternalMemoryOp::getResultName(unsigned int idx) {
1070  return getMemoryResultName(getLdCount(), getStCount(), idx);
1071 }
1072 
1073 void ExternalMemoryOp::build(OpBuilder &builder, OperationState &result,
1074  Value memref, ValueRange inputs, int ldCount,
1075  int stCount, int id) {
1076  SmallVector<Value> ops;
1077  ops.push_back(memref);
1078  llvm::append_range(ops, inputs);
1079  result.addOperands(ops);
1080 
1081  auto memrefType = memref.getType().cast<MemRefType>();
1082 
1083  // Data outputs (get their type from memref)
1084  result.types.append(ldCount, memrefType.getElementType());
1085 
1086  // Control outputs
1087  result.types.append(stCount + ldCount, builder.getNoneType());
1088 
1089  // Memory ID (individual ID for each MemoryOp)
1090  Type i32Type = builder.getIntegerType(32);
1091  result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1092  result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, ldCount));
1093  result.addAttribute("stCount", builder.getIntegerAttr(i32Type, stCount));
1094 }
1095 
1096 llvm::SmallVector<handshake::MemLoadInterface>
1098  return ::getLoadPorts(*this);
1099 }
1100 
1101 llvm::SmallVector<handshake::MemStoreInterface>
1103  return ::getStorePorts(*this);
1104 }
1105 
1106 void MemoryOp::build(OpBuilder &builder, OperationState &result,
1107  ValueRange operands, int outputs, int controlOutputs,
1108  bool lsq, int id, Value memref) {
1109  result.addOperands(operands);
1110 
1111  auto memrefType = memref.getType().cast<MemRefType>();
1112 
1113  // Data outputs (get their type from memref)
1114  result.types.append(outputs, memrefType.getElementType());
1115 
1116  // Control outputs
1117  result.types.append(controlOutputs, builder.getNoneType());
1118  result.addAttribute("lsq", builder.getBoolAttr(lsq));
1119  result.addAttribute("memRefType", TypeAttr::get(memrefType));
1120 
1121  // Memory ID (individual ID for each MemoryOp)
1122  Type i32Type = builder.getIntegerType(32);
1123  result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1124 
1125  if (!lsq) {
1126  result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, outputs));
1127  result.addAttribute(
1128  "stCount", builder.getIntegerAttr(i32Type, controlOutputs - outputs));
1129  }
1130 }
1131 
1132 llvm::SmallVector<handshake::MemLoadInterface> MemoryOp::getLoadPorts() {
1133  return ::getLoadPorts(*this);
1134 }
1135 
1136 llvm::SmallVector<handshake::MemStoreInterface> MemoryOp::getStorePorts() {
1137  return ::getStorePorts(*this);
1138 }
1139 
1140 bool handshake::MemoryOp::allocateMemory(
1141  llvm::DenseMap<unsigned, unsigned> &memoryMap,
1142  std::vector<std::vector<llvm::Any>> &store,
1143  std::vector<double> &storeTimes) {
1144  if (memoryMap.count(getId()))
1145  return false;
1146 
1147  auto type = getMemRefType();
1148  std::vector<llvm::Any> in;
1149 
1150  ArrayRef<int64_t> shape = type.getShape();
1151  int allocationSize = 1;
1152  unsigned count = 0;
1153  for (int64_t dim : shape) {
1154  if (dim > 0)
1155  allocationSize *= dim;
1156  else {
1157  assert(count < in.size());
1158  allocationSize *= llvm::any_cast<APInt>(in[count++]).getSExtValue();
1159  }
1160  }
1161  unsigned ptr = store.size();
1162  store.resize(ptr + 1);
1163  storeTimes.resize(ptr + 1);
1164  store[ptr].resize(allocationSize);
1165  storeTimes[ptr] = 0.0;
1166  mlir::Type elementType = type.getElementType();
1167  int width = elementType.getIntOrFloatBitWidth();
1168  for (int i = 0; i < allocationSize; i++) {
1169  if (elementType.isa<mlir::IntegerType>()) {
1170  store[ptr][i] = APInt(width, 0);
1171  } else if (elementType.isa<mlir::FloatType>()) {
1172  store[ptr][i] = APFloat(0.0);
1173  } else {
1174  llvm_unreachable("Unknown result type!\n");
1175  }
1176  }
1177 
1178  memoryMap[getId()] = ptr;
1179  return true;
1180 }
1181 
1182 std::string handshake::LoadOp::getOperandName(unsigned int idx) {
1183  unsigned nAddresses = getAddresses().size();
1184  std::string opName;
1185  if (idx < nAddresses)
1186  opName = "addrIn" + std::to_string(idx);
1187  else if (idx == nAddresses)
1188  opName = "dataFromMem";
1189  else
1190  opName = "ctrl";
1191  return opName;
1192 }
1193 
1194 std::string handshake::LoadOp::getResultName(unsigned int idx) {
1195  std::string resName;
1196  if (idx == 0)
1197  resName = "dataOut";
1198  else
1199  resName = "addrOut" + std::to_string(idx - 1);
1200  return resName;
1201 }
1202 
1203 void handshake::LoadOp::build(OpBuilder &builder, OperationState &result,
1204  Value memref, ValueRange indices) {
1205  // Address indices
1206  // result.addOperands(memref);
1207  result.addOperands(indices);
1208 
1209  // Data type
1210  auto memrefType = memref.getType().cast<MemRefType>();
1211 
1212  // Data output (from load to successor ops)
1213  result.types.push_back(memrefType.getElementType());
1214 
1215  // Address outputs (to lsq)
1216  result.types.append(indices.size(), builder.getIndexType());
1217 }
1218 
1219 static ParseResult parseMemoryAccessOp(OpAsmParser &parser,
1220  OperationState &result) {
1221  SmallVector<OpAsmParser::UnresolvedOperand, 4> addressOperands,
1222  remainingOperands, allOperands;
1223  SmallVector<Type, 1> parsedTypes, allTypes;
1224  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1225 
1226  if (parser.parseLSquare() || parser.parseOperandList(addressOperands) ||
1227  parser.parseRSquare() || parser.parseOperandList(remainingOperands) ||
1228  parser.parseColon() || parser.parseTypeList(parsedTypes))
1229  return failure();
1230 
1231  // The last type will be the data type of the operation; the prior will be the
1232  // address types.
1233  Type dataType = parsedTypes.back();
1234  auto parsedTypesRef = ArrayRef(parsedTypes);
1235  result.addTypes(dataType);
1236  result.addTypes(parsedTypesRef.drop_back());
1237  allOperands.append(addressOperands);
1238  allOperands.append(remainingOperands);
1239  allTypes.append(parsedTypes);
1240  allTypes.push_back(NoneType::get(result.getContext()));
1241  if (parser.resolveOperands(allOperands, allTypes, allOperandLoc,
1242  result.operands))
1243  return failure();
1244  return success();
1245 }
1246 
1247 template <typename MemOp>
1248 static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op) {
1249  p << " [";
1250  p << op.getAddresses();
1251  p << "] " << op.getData() << ", " << op.getCtrl() << " : ";
1252  llvm::interleaveComma(op.getAddresses(), p,
1253  [&](Value v) { p << v.getType(); });
1254  p << ", " << op.getData().getType();
1255 }
1256 
1257 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
1258  return parseMemoryAccessOp(parser, result);
1259 }
1260 
1261 void LoadOp::print(OpAsmPrinter &p) { printMemoryAccessOp(p, *this); }
1262 
1263 std::string handshake::StoreOp::getOperandName(unsigned int idx) {
1264  unsigned nAddresses = getAddresses().size();
1265  std::string opName;
1266  if (idx < nAddresses)
1267  opName = "addrIn" + std::to_string(idx);
1268  else if (idx == nAddresses)
1269  opName = "dataIn";
1270  else
1271  opName = "ctrl";
1272  return opName;
1273 }
1274 
1275 template <typename TMemoryOp>
1276 static LogicalResult verifyMemoryAccessOp(TMemoryOp op) {
1277  if (op.getAddresses().size() == 0)
1278  return op.emitOpError() << "No addresses were specified";
1279 
1280  return success();
1281 }
1282 
1283 LogicalResult LoadOp::verify() { return verifyMemoryAccessOp(*this); }
1284 
1285 std::string handshake::StoreOp::getResultName(unsigned int idx) {
1286  std::string resName;
1287  if (idx == 0)
1288  resName = "dataToMem";
1289  else
1290  resName = "addrOut" + std::to_string(idx - 1);
1291  return resName;
1292 }
1293 
1294 void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
1295  Value valueToStore, ValueRange indices) {
1296 
1297  // Address indices
1298  result.addOperands(indices);
1299 
1300  // Data
1301  result.addOperands(valueToStore);
1302 
1303  // Data output (from store to LSQ)
1304  result.types.push_back(valueToStore.getType());
1305 
1306  // Address outputs (from store to lsq)
1307  result.types.append(indices.size(), builder.getIndexType());
1308 }
1309 
1310 LogicalResult StoreOp::verify() { return verifyMemoryAccessOp(*this); }
1311 
1312 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
1313  return parseMemoryAccessOp(parser, result);
1314 }
1315 
1316 void StoreOp::print(OpAsmPrinter &p) { return printMemoryAccessOp(p, *this); }
1317 
1318 bool JoinOp::isControl() { return true; }
1319 
1320 ParseResult JoinOp::parse(OpAsmParser &parser, OperationState &result) {
1321  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1322  SmallVector<Type> types;
1323 
1324  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1325  if (parser.parseOperandList(operands) ||
1326  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1327  parser.parseTypeList(types))
1328  return failure();
1329 
1330  if (parser.resolveOperands(operands, types, allOperandLoc, result.operands))
1331  return failure();
1332 
1333  result.addTypes(NoneType::get(result.getContext()));
1334  return success();
1335 }
1336 
1337 void JoinOp::print(OpAsmPrinter &p) {
1338  p << " " << getData();
1339  p.printOptionalAttrDict((*this)->getAttrs(), {"control"});
1340  p << " : " << getData().getTypes();
1341 }
1342 
1343 /// Based on mlir::func::CallOp::verifySymbolUses
1344 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1345  // Check that the module attribute was specified.
1346  auto fnAttr = this->getModuleAttr();
1347  assert(fnAttr && "requires a 'module' symbol reference attribute");
1348 
1349  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
1350  if (!fn)
1351  return emitOpError() << "'" << fnAttr.getValue()
1352  << "' does not reference a valid handshake function";
1353 
1354  // Verify that the operand and result types match the callee.
1355  auto fnType = fn.getFunctionType();
1356  if (fnType.getNumInputs() != getNumOperands())
1357  return emitOpError(
1358  "incorrect number of operands for the referenced handshake function");
1359 
1360  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
1361  if (getOperand(i).getType() != fnType.getInput(i))
1362  return emitOpError("operand type mismatch: expected operand type ")
1363  << fnType.getInput(i) << ", but provided "
1364  << getOperand(i).getType() << " for operand number " << i;
1365 
1366  if (fnType.getNumResults() != getNumResults())
1367  return emitOpError(
1368  "incorrect number of results for the referenced handshake function");
1369 
1370  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
1371  if (getResult(i).getType() != fnType.getResult(i))
1372  return emitOpError("result type mismatch: expected result type ")
1373  << fnType.getResult(i) << ", but provided "
1374  << getResult(i).getType() << " for result number " << i;
1375 
1376  return success();
1377 }
1378 
1379 FunctionType InstanceOp::getModuleType() {
1380  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
1381 }
1382 
1383 ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1384  OpAsmParser::UnresolvedOperand tuple;
1385  TupleType type;
1386 
1387  if (parser.parseOperand(tuple) ||
1388  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1389  parser.parseType(type))
1390  return failure();
1391 
1392  if (parser.resolveOperand(tuple, type, result.operands))
1393  return failure();
1394 
1395  result.addTypes(type.getTypes());
1396 
1397  return success();
1398 }
1399 
1400 void UnpackOp::print(OpAsmPrinter &p) {
1401  p << " " << getInput();
1402  p.printOptionalAttrDict((*this)->getAttrs());
1403  p << " : " << getInput().getType();
1404 }
1405 
1406 ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1407  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1408  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1409  TupleType type;
1410 
1411  if (parser.parseOperandList(operands) ||
1412  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1413  parser.parseType(type))
1414  return failure();
1415 
1416  if (parser.resolveOperands(operands, type.getTypes(), allOperandLoc,
1417  result.operands))
1418  return failure();
1419 
1420  result.addTypes(type);
1421 
1422  return success();
1423 }
1424 
1425 void PackOp::print(OpAsmPrinter &p) {
1426  p << " " << getInputs();
1427  p.printOptionalAttrDict((*this)->getAttrs());
1428  p << " : " << getResult().getType();
1429 }
1430 
1431 //===----------------------------------------------------------------------===//
1432 // TableGen'd op method definitions
1433 //===----------------------------------------------------------------------===//
1434 
1435 LogicalResult ReturnOp::verify() {
1436  auto *parent = (*this)->getParentOp();
1437  auto function = dyn_cast<handshake::FuncOp>(parent);
1438  if (!function)
1439  return emitOpError("must have a handshake.func parent");
1440 
1441  // The operand number and types must match the function signature.
1442  const auto &results = function.getResultTypes();
1443  if (getNumOperands() != results.size())
1444  return emitOpError("has ")
1445  << getNumOperands() << " operands, but enclosing function returns "
1446  << results.size();
1447 
1448  for (unsigned i = 0, e = results.size(); i != e; ++i)
1449  if (getOperand(i).getType() != results[i])
1450  return emitError() << "type of return operand " << i << " ("
1451  << getOperand(i).getType()
1452  << ") doesn't match function result type ("
1453  << results[i] << ")";
1454 
1455  return success();
1456 }
1457 
1458 #define GET_OP_CLASSES
1459 #include "circt/Dialect/Handshake/Handshake.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition: CHIRRTL.cpp:29
static StringRef getOperandName(Value operand)
int32_t width
Definition: FIRRTL.cpp:36
static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal, uint64_t numOperands)
Verifies whether an indexing value is wide enough to index into a provided number of operands.
static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result)
static ParseResult parseSostOperation(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, OperationState &result, int &size, Type &type, bool explicitSize)
static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op)
static ParseResult parseFuncOpArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Type > &resTypes, SmallVectorImpl< DictionaryAttr > &resAttrs)
Parses a FuncOp signature using mlir::function_interface_impl::parseFunctionSignature while getting a...
static Value getDematerialized(Value v)
Returns a dematerialized version of the value 'v', defined as the source of the value before passing ...
static bool isControlCheckTypeAndOperand(Type dataType, Value operand)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores, unsigned idx)
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v)
static LogicalResult verifyMemoryAccessOp(TMemoryOp op)
static void addStringToStringArrayAttr(Builder &builder, Operation *op, StringRef attrName, StringAttr str)
Helper function for appending a string to an array attribute, and rewriting the attribute back to the...
static std::string defaultOperandName(unsigned int idx)
static SmallVector< Attribute > getFuncOpNames(Builder &builder, unsigned cnt, StringRef prefix)
Generates names for a handshake.func input and output arguments, based on the number of args as well ...
static std::string getMemoryOperandName(unsigned nStores, unsigned idx)
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
static ParseResult parseMemoryAccessOp(OpAsmParser &parser, OperationState &result)
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
bool isControlOpImpl(Operation *op)
Default implementation for checking whether an operation is a control operation.
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition: HWOps.cpp:515
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21