CIRCT  20.0.0git
HandshakeOps.cpp
Go to the documentation of this file.
1 //===- HandshakeOps.cpp - Handshake MLIR Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the declaration of the Handshake operations struct.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "circt/Support/LLVM.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/SymbolTable.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/Interfaces/FunctionImplementation.h"
29 #include "mlir/Transforms/InliningUtils.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/SmallBitVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 
34 #include <set>
35 
36 using namespace circt;
37 using namespace circt::handshake;
38 
39 namespace circt {
40 namespace handshake {
41 #include "circt/Dialect/Handshake/HandshakeCanonicalization.h.inc"
42 
43 bool isControlOpImpl(Operation *op) {
44  if (auto sostInterface = dyn_cast<SOSTInterface>(op); sostInterface)
45  return sostInterface.sostIsControl();
46 
47  return false;
48 }
49 
50 } // namespace handshake
51 } // namespace circt
52 
/// Builds the fallback printable name "in<idx>" for a data operand.
static std::string defaultOperandName(unsigned int idx) {
  std::string name("in");
  name.append(std::to_string(idx));
  return name;
}
56 
57 static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v) {
58  if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
59  return failure();
60  return success();
61 }
62 
63 static ParseResult
64 parseSostOperation(OpAsmParser &parser,
65  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
66  OperationState &result, int &size, Type &type,
67  bool explicitSize) {
68  if (explicitSize)
69  if (parseIntInSquareBrackets(parser, size))
70  return failure();
71 
72  if (parser.parseOperandList(operands) ||
73  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
74  parser.parseType(type))
75  return failure();
76 
77  if (!explicitSize)
78  size = operands.size();
79  return success();
80 }
81 
82 /// Verifies whether an indexing value is wide enough to index into a provided
83 /// number of operands.
84 static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal,
85  uint64_t numOperands) {
86  auto indexType = indexVal.getType();
87  unsigned indexWidth;
88 
89  // Determine the bitwidth of the indexing value
90  if (auto integerType = dyn_cast<IntegerType>(indexType))
91  indexWidth = integerType.getWidth();
92  else if (indexType.isIndex())
93  indexWidth = IndexType::kInternalStorageBitWidth;
94  else
95  return op->emitError("unsupported type for indexing value: ") << indexType;
96 
97  // Check whether the bitwidth can support the provided number of operands
98  if (indexWidth < 64) {
99  uint64_t maxNumOperands = (uint64_t)1 << indexWidth;
100  if (numOperands > maxNumOperands)
101  return op->emitError("bitwidth of indexing value is ")
102  << indexWidth << ", which can index into " << maxNumOperands
103  << " operands, but found " << numOperands << " operands";
104  }
105  return success();
106 }
107 
108 static bool isControlCheckTypeAndOperand(Type dataType, Value operand) {
109  // The operation is a control operation if its operand data type is a
110  // NoneType.
111  if (isa<NoneType>(dataType))
112  return true;
113 
114  // Otherwise, the operation is a control operation if the operation's
115  // operand originates from the control network
116  auto *defOp = operand.getDefiningOp();
117  return isa_and_nonnull<ControlMergeOp>(defOp) &&
118  operand == defOp->getResult(0);
119 }
120 
121 template <typename TMemOp>
122 llvm::SmallVector<handshake::MemLoadInterface> getLoadPorts(TMemOp op) {
123  llvm::SmallVector<MemLoadInterface> ports;
124  // Memory interface refresher:
125  // Operands:
126  // all stores (stdata1, staddr1, stdata2, staddr2, ...)
127  // then all loads (ldaddr1, ldaddr2,...)
128  // Outputs: load addresses (lddata1, lddata2, ...), followed by all none
129  // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
130  unsigned stCount = op.getStCount();
131  unsigned ldCount = op.getLdCount();
132  for (unsigned i = 0, e = ldCount; i != e; ++i) {
133  MemLoadInterface ldif;
134  ldif.index = i;
135  ldif.addressIn = op.getInputs()[stCount * 2 + i];
136  ldif.dataOut = op.getResult(i);
137  ldif.doneOut = op.getResult(ldCount + stCount + i);
138  ports.push_back(ldif);
139  }
140  return ports;
141 }
142 
143 template <typename TMemOp>
144 llvm::SmallVector<handshake::MemStoreInterface> getStorePorts(TMemOp op) {
145  llvm::SmallVector<MemStoreInterface> ports;
146  // Memory interface refresher:
147  // Operands:
148  // all stores (stdata1, staddr1, stdata2, staddr2, ...)
149  // then all loads (ldaddr1, ldaddr2,...)
150  // Outputs: load data (lddata1, lddata2, ...), followed by all none
151  // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
152  unsigned ldCount = op.getLdCount();
153  for (unsigned i = 0, e = op.getStCount(); i != e; ++i) {
154  MemStoreInterface stif;
155  stif.index = i;
156  stif.dataIn = op.getInputs()[i * 2];
157  stif.addressIn = op.getInputs()[i * 2 + 1];
158  stif.doneOut = op.getResult(ldCount + i);
159  ports.push_back(stif);
160  }
161  return ports;
162 }
163 
164 unsigned ForkOp::getSize() { return getResults().size(); }
165 
166 static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result) {
167  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
168  Type type;
169  ArrayRef<Type> operandTypes(type);
170  SmallVector<Type, 1> resultTypes;
171  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
172  int size;
173  if (parseSostOperation(parser, allOperands, result, size, type, true))
174  return failure();
175 
176  resultTypes.assign(size, type);
177  result.addTypes(resultTypes);
178  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
179  result.operands))
180  return failure();
181  return success();
182 }
183 
184 ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
185  return parseForkOp(parser, result);
186 }
187 
188 void ForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
189 
190 namespace {
191 
192 struct EliminateUnusedForkResultsPattern : mlir::OpRewritePattern<ForkOp> {
194 
195  LogicalResult matchAndRewrite(ForkOp op,
196  PatternRewriter &rewriter) const override {
197  std::set<unsigned> unusedIndexes;
198 
199  for (auto res : llvm::enumerate(op.getResults()))
200  if (res.value().getUses().empty())
201  unusedIndexes.insert(res.index());
202 
203  if (unusedIndexes.empty())
204  return failure();
205 
206  // Create a new fork op, dropping the unused results.
207  rewriter.setInsertionPoint(op);
208  auto operand = op.getOperand();
209  auto newFork = rewriter.create<ForkOp>(
210  op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
211  unsigned i = 0;
212  for (auto oldRes : llvm::enumerate(op.getResults()))
213  if (unusedIndexes.count(oldRes.index()) == 0)
214  rewriter.replaceAllUsesWith(oldRes.value(), newFork.getResults()[i++]);
215  rewriter.eraseOp(op);
216  return success();
217  }
218 };
219 
220 struct EliminateForkToForkPattern : mlir::OpRewritePattern<ForkOp> {
222 
223  LogicalResult matchAndRewrite(ForkOp op,
224  PatternRewriter &rewriter) const override {
225  auto parentForkOp = op.getOperand().getDefiningOp<ForkOp>();
226  if (!parentForkOp)
227  return failure();
228 
229  /// Create the fork with as many outputs as the two source forks.
230  /// Keeping the op.operand() output may or may not be redundant (dependning
231  /// on if op is the single user of the value), but we'll let
232  /// EliminateUnusedForkResultsPattern apply in that case.
233  unsigned totalNumOuts = op.getSize() + parentForkOp.getSize();
234  /// Create a new parent fork op which produces all of the fork outputs and
235  /// replace all of the uses of the old results.
236  auto newParentForkOp = rewriter.create<ForkOp>(
237  parentForkOp.getLoc(), parentForkOp.getOperand(), totalNumOuts);
238 
239  for (auto it :
240  llvm::zip(parentForkOp->getResults(), newParentForkOp.getResults()))
241  rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
242 
243  /// Replace the results of the matches fork op with the corresponding
244  /// results of the new parent fork op.
245  rewriter.replaceOp(op,
246  newParentForkOp.getResults().take_back(op.getSize()));
247  rewriter.eraseOp(parentForkOp);
248  return success();
249  }
250 };
251 
252 } // namespace
253 
254 void handshake::ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
255  MLIRContext *context) {
256  results.insert<circt::handshake::EliminateSimpleForksPattern,
257  EliminateForkToForkPattern, EliminateUnusedForkResultsPattern>(
258  context);
259 }
260 
261 unsigned LazyForkOp::getSize() { return getResults().size(); }
262 
263 bool LazyForkOp::sostIsControl() {
264  return isControlCheckTypeAndOperand(getDataType(), getOperand());
265 }
266 
267 ParseResult LazyForkOp::parse(OpAsmParser &parser, OperationState &result) {
268  return parseForkOp(parser, result);
269 }
270 
271 void LazyForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
272 
273 ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
274  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
275  Type type;
276  ArrayRef<Type> operandTypes(type);
277  SmallVector<Type, 1> resultTypes, dataOperandsTypes;
278  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
279  int size;
280  if (parseSostOperation(parser, allOperands, result, size, type, false))
281  return failure();
282 
283  dataOperandsTypes.assign(size, type);
284  resultTypes.push_back(type);
285  result.addTypes(resultTypes);
286  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
287  result.operands))
288  return failure();
289  return success();
290 }
291 
292 void MergeOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
293 
294 void MergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
295  MLIRContext *context) {
296  results.insert<circt::handshake::EliminateSimpleMergesPattern>(context);
297 }
298 
299 /// Returns a dematerialized version of the value 'v', defined as the source of
300 /// the value before passing through a buffer or fork operation.
301 static Value getDematerialized(Value v) {
302  Operation *parentOp = v.getDefiningOp();
303  if (!parentOp)
304  return v;
305 
306  return llvm::TypeSwitch<Operation *, Value>(parentOp)
307  .Case<ForkOp>(
308  [&](ForkOp op) { return getDematerialized(op.getOperand()); })
309  .Case<BufferOp>(
310  [&](BufferOp op) { return getDematerialized(op.getOperand()); })
311  .Default([&](auto) { return v; });
312 }
313 
314 namespace {
315 
316 /// Eliminates muxes with identical data inputs. Data inputs are inspected as
317 /// their dematerialized versions. This has the side effect of any subsequently
318 /// unused buffers are DCE'd and forks are optimized to be narrower.
319 struct EliminateSimpleMuxesPattern : mlir::OpRewritePattern<MuxOp> {
321  LogicalResult matchAndRewrite(MuxOp op,
322  PatternRewriter &rewriter) const override {
323  Value firstDataOperand = getDematerialized(op.getDataOperands()[0]);
324  if (!llvm::all_of(op.getDataOperands(), [&](Value operand) {
325  return getDematerialized(operand) == firstDataOperand;
326  }))
327  return failure();
328  rewriter.replaceOp(op, firstDataOperand);
329  return success();
330  }
331 };
332 
333 struct EliminateUnaryMuxesPattern : OpRewritePattern<MuxOp> {
335  LogicalResult matchAndRewrite(MuxOp op,
336  PatternRewriter &rewriter) const override {
337  if (op.getDataOperands().size() != 1)
338  return failure();
339 
340  rewriter.replaceOp(op, op.getDataOperands()[0]);
341  return success();
342  }
343 };
344 
345 struct EliminateCBranchIntoMuxPattern : OpRewritePattern<MuxOp> {
347  LogicalResult matchAndRewrite(MuxOp op,
348  PatternRewriter &rewriter) const override {
349 
350  auto dataOperands = op.getDataOperands();
351  if (dataOperands.size() != 2)
352  return failure();
353 
354  // Both data operands must originate from the same cbranch
355  ConditionalBranchOp firstParentCBranch =
356  dataOperands[0].getDefiningOp<ConditionalBranchOp>();
357  if (!firstParentCBranch)
358  return failure();
359  auto secondParentCBranch =
360  dataOperands[1].getDefiningOp<ConditionalBranchOp>();
361  if (!secondParentCBranch || firstParentCBranch != secondParentCBranch)
362  return failure();
363 
364  rewriter.modifyOpInPlace(firstParentCBranch, [&] {
365  // Replace uses of the mux's output with cbranch's data input
366  rewriter.replaceOp(op, firstParentCBranch.getDataOperand());
367  });
368 
369  return success();
370  }
371 };
372 
373 } // namespace
374 
375 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
376  MLIRContext *context) {
377  results.insert<EliminateSimpleMuxesPattern, EliminateUnaryMuxesPattern,
378  EliminateCBranchIntoMuxPattern>(context);
379 }
380 
381 LogicalResult
382 MuxOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
383  ValueRange operands, DictionaryAttr attributes,
384  mlir::OpaqueProperties properties,
385  mlir::RegionRange regions,
386  SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
387  // MuxOp must have at least one data operand (in addition to the select
388  // operand)
389  if (operands.size() < 2)
390  return failure();
391  // Result type is type of any data operand
392  inferredReturnTypes.push_back(operands[1].getType());
393  return success();
394 }
395 
396 bool MuxOp::isControl() { return isa<NoneType>(getResult().getType()); }
397 
398 std::string handshake::MuxOp::getOperandName(unsigned int idx) {
399  return idx == 0 ? "select" : defaultOperandName(idx - 1);
400 }
401 
402 ParseResult MuxOp::parse(OpAsmParser &parser, OperationState &result) {
403  OpAsmParser::UnresolvedOperand selectOperand;
404  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
405  Type selectType, dataType;
406  SmallVector<Type, 1> dataOperandsTypes;
407  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
408  if (parser.parseOperand(selectOperand) || parser.parseLSquare() ||
409  parser.parseOperandList(allOperands) || parser.parseRSquare() ||
410  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
411  parser.parseType(selectType) || parser.parseComma() ||
412  parser.parseType(dataType))
413  return failure();
414 
415  int size = allOperands.size();
416  dataOperandsTypes.assign(size, dataType);
417  result.addTypes(dataType);
418  allOperands.insert(allOperands.begin(), selectOperand);
419  if (parser.resolveOperands(
420  allOperands,
421  llvm::concat<const Type>(ArrayRef<Type>(selectType),
422  ArrayRef<Type>(dataOperandsTypes)),
423  allOperandLoc, result.operands))
424  return failure();
425  return success();
426 }
427 
428 void MuxOp::print(OpAsmPrinter &p) {
429  Type selectType = getSelectOperand().getType();
430  auto ops = getOperands();
431  p << ' ' << ops.front();
432  p << " [";
433  p.printOperands(ops.drop_front());
434  p << "]";
435  p.printOptionalAttrDict((*this)->getAttrs());
436  p << " : " << selectType << ", " << getResult().getType();
437 }
438 
439 LogicalResult MuxOp::verify() {
440  return verifyIndexWideEnough(*this, getSelectOperand(),
441  getDataOperands().size());
442 }
443 
444 std::string handshake::ControlMergeOp::getResultName(unsigned int idx) {
445  assert(idx == 0 || idx == 1);
446  return idx == 0 ? "dataOut" : "index";
447 }
448 
449 ParseResult ControlMergeOp::parse(OpAsmParser &parser, OperationState &result) {
450  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
451  Type resultType, indexType;
452  SmallVector<Type> resultTypes, dataOperandsTypes;
453  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
454  int size;
455  if (parseSostOperation(parser, allOperands, result, size, resultType, false))
456  return failure();
457  // Parse type of index result
458  if (parser.parseComma() || parser.parseType(indexType))
459  return failure();
460 
461  dataOperandsTypes.assign(size, resultType);
462  resultTypes.push_back(resultType);
463  resultTypes.push_back(indexType);
464  result.addTypes(resultTypes);
465  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
466  result.operands))
467  return failure();
468  return success();
469 }
470 
471 void ControlMergeOp::print(OpAsmPrinter &p) {
472  sostPrint(p, false);
473  // Print type of index result
474  p << ", " << getIndex().getType();
475 }
476 
477 LogicalResult ControlMergeOp::verify() {
478  auto operands = getOperands();
479  if (operands.empty())
480  return emitOpError("operation must have at least one operand");
481  if (operands[0].getType() != getResult().getType())
482  return emitOpError("type of first result should match type of operands");
483  return verifyIndexWideEnough(*this, getIndex(), getNumOperands());
484 }
485 
486 LogicalResult FuncOp::verify() {
487  // If this function is external there is nothing to do.
488  if (isExternal())
489  return success();
490 
491  // Verify that the argument list of the function and the arg list of the
492  // entry block line up. The trait already verified that the number of
493  // arguments is the same between the signature and the block.
494  auto fnInputTypes = getArgumentTypes();
495  Block &entryBlock = front();
496 
497  for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
498  if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
499  return emitOpError("type of entry block argument #")
500  << i << '(' << entryBlock.getArgument(i).getType()
501  << ") must match the type of the corresponding argument in "
502  << "function signature(" << fnInputTypes[i] << ')';
503 
504  // Verify that we have a name for each argument and result of this function.
505  auto verifyPortNameAttr = [&](StringRef attrName,
506  unsigned numIOs) -> LogicalResult {
507  auto portNamesAttr = (*this)->getAttrOfType<ArrayAttr>(attrName);
508 
509  if (!portNamesAttr)
510  return emitOpError() << "expected attribute '" << attrName << "'.";
511 
512  auto portNames = portNamesAttr.getValue();
513  if (portNames.size() != numIOs)
514  return emitOpError() << "attribute '" << attrName << "' has "
515  << portNames.size()
516  << " entries but is expected to have " << numIOs
517  << ".";
518 
519  if (llvm::any_of(portNames,
520  [&](Attribute attr) { return !isa<StringAttr>(attr); }))
521  return emitOpError() << "expected all entries in attribute '" << attrName
522  << "' to be strings.";
523 
524  return success();
525  };
526  if (failed(verifyPortNameAttr("argNames", getNumArguments())))
527  return failure();
528  if (failed(verifyPortNameAttr("resNames", getNumResults())))
529  return failure();
530 
531  // Verify that all memrefs have a corresponding extmemory operation
532  for (auto arg : entryBlock.getArguments()) {
533  if (!isa<MemRefType>(arg.getType()))
534  continue;
535  if (arg.getUsers().empty() ||
536  !isa<ExternalMemoryOp>(*arg.getUsers().begin()))
537  return emitOpError("expected that block argument #")
538  << arg.getArgNumber() << " is used by an 'extmemory' operation";
539  }
540 
541  return success();
542 }
543 
544 /// Parses a FuncOp signature using
545 /// mlir::function_interface_impl::parseFunctionSignature while getting access
546 /// to the parsed SSA names to store as attributes.
547 static ParseResult
548 parseFuncOpArgs(OpAsmParser &parser,
549  SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
550  SmallVectorImpl<Type> &resTypes,
551  SmallVectorImpl<DictionaryAttr> &resAttrs) {
552  bool isVariadic;
553  if (mlir::function_interface_impl::parseFunctionSignature(
554  parser, /*allowVariadic=*/true, entryArgs, isVariadic, resTypes,
555  resAttrs)
556  .failed())
557  return failure();
558 
559  return success();
560 }
561 
562 /// Generates names for a handshake.func input and output arguments, based on
563 /// the number of args as well as a prefix.
564 static SmallVector<Attribute> getFuncOpNames(Builder &builder, unsigned cnt,
565  StringRef prefix) {
566  SmallVector<Attribute> resNames;
567  for (unsigned i = 0; i < cnt; ++i)
568  resNames.push_back(builder.getStringAttr(prefix + std::to_string(i)));
569  return resNames;
570 }
571 
572 void handshake::FuncOp::build(OpBuilder &builder, OperationState &state,
573  StringRef name, FunctionType type,
574  ArrayRef<NamedAttribute> attrs) {
575  state.addAttribute(SymbolTable::getSymbolAttrName(),
576  builder.getStringAttr(name));
577  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
578  TypeAttr::get(type));
579  state.attributes.append(attrs.begin(), attrs.end());
580 
581  if (const auto *argNamesAttrIt = llvm::find_if(
582  attrs, [&](auto attr) { return attr.getName() == "argNames"; });
583  argNamesAttrIt == attrs.end())
584  state.addAttribute("argNames", builder.getArrayAttr({}));
585 
586  if (llvm::find_if(attrs, [&](auto attr) {
587  return attr.getName() == "resNames";
588  }) == attrs.end())
589  state.addAttribute("resNames", builder.getArrayAttr({}));
590 
591  state.addRegion();
592 }
593 
594 /// Helper function for appending a string to an array attribute, and
595 /// rewriting the attribute back to the operation.
596 static void addStringToStringArrayAttr(Builder &builder, Operation *op,
597  StringRef attrName, StringAttr str) {
598  llvm::SmallVector<Attribute> attrs;
599  llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
600  std::back_inserter(attrs));
601  attrs.push_back(str);
602  op->setAttr(attrName, builder.getArrayAttr(attrs));
603 }
604 
605 void handshake::FuncOp::resolveArgAndResNames() {
606  Builder builder(getContext());
607 
608  /// Generate a set of fallback names. These are used in case names are
609  /// missing from the currently set arg- and res name attributes.
610  auto fallbackArgNames = getFuncOpNames(builder, getNumArguments(), "in");
611  auto fallbackResNames = getFuncOpNames(builder, getNumResults(), "out");
612  auto argNames = getArgNames().getValue();
613  auto resNames = getResNames().getValue();
614 
615  /// Use fallback names where actual names are missing.
616  auto resolveNames = [&](auto &fallbackNames, auto &actualNames,
617  StringRef attrName) {
618  for (auto fallbackName : llvm::enumerate(fallbackNames)) {
619  if (actualNames.size() <= fallbackName.index())
620  addStringToStringArrayAttr(builder, this->getOperation(), attrName,
621  cast<StringAttr>(fallbackName.value()));
622  }
623  };
624  resolveNames(fallbackArgNames, argNames, "argNames");
625  resolveNames(fallbackResNames, resNames, "resNames");
626 }
627 
628 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
629  auto &builder = parser.getBuilder();
630  StringAttr nameAttr;
631  SmallVector<OpAsmParser::Argument> args;
632  SmallVector<Type> resTypes;
633  SmallVector<DictionaryAttr> resAttributes;
634  SmallVector<Attribute> argNames;
635 
636  // Parse visibility.
637  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
638 
639  // Parse signature
640  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
641  result.attributes) ||
642  parseFuncOpArgs(parser, args, resTypes, resAttributes))
643  return failure();
644  mlir::function_interface_impl::addArgAndResultAttrs(
645  builder, result, args, resAttributes,
646  handshake::FuncOp::getArgAttrsAttrName(result.name),
647  handshake::FuncOp::getResAttrsAttrName(result.name));
648 
649  // Set function type
650  SmallVector<Type> argTypes;
651  for (auto arg : args)
652  argTypes.push_back(arg.type);
653 
654  result.addAttribute(
655  handshake::FuncOp::getFunctionTypeAttrName(result.name),
656  TypeAttr::get(builder.getFunctionType(argTypes, resTypes)));
657 
658  // Determine the names of the arguments. If no SSA values are present, use
659  // fallback names.
660  bool noSSANames =
661  llvm::any_of(args, [](auto arg) { return arg.ssaName.name.empty(); });
662  if (noSSANames) {
663  argNames = getFuncOpNames(builder, args.size(), "in");
664  } else {
665  llvm::transform(args, std::back_inserter(argNames), [&](auto arg) {
666  return builder.getStringAttr(arg.ssaName.name.drop_front());
667  });
668  }
669 
670  // Parse attributes
671  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
672  return failure();
673 
674  // If argNames and resNames wasn't provided manually, infer argNames attribute
675  // from the parsed SSA names and resNames from our naming convention.
676  if (!result.attributes.get("argNames"))
677  result.addAttribute("argNames", builder.getArrayAttr(argNames));
678  if (!result.attributes.get("resNames")) {
679  auto resNames = getFuncOpNames(builder, resTypes.size(), "out");
680  result.addAttribute("resNames", builder.getArrayAttr(resNames));
681  }
682 
683  // Parse the optional function body. The printer will not print the body if
684  // its empty, so disallow parsing of empty body in the parser.
685  auto *body = result.addRegion();
686  llvm::SMLoc loc = parser.getCurrentLocation();
687  auto parseResult = parser.parseOptionalRegion(*body, args,
688  /*enableNameShadowing=*/false);
689  if (!parseResult.has_value())
690  return success();
691 
692  if (failed(*parseResult))
693  return failure();
694  // Function body was parsed, make sure its not empty.
695  if (body->empty())
696  return parser.emitError(loc, "expected non-empty function body");
697 
698  // If a body was parsed, the arg and res names need to be resolved
699  return success();
700 }
701 
702 void FuncOp::print(OpAsmPrinter &p) {
703  mlir::function_interface_impl::printFunctionOp(
704  p, *this, /*isVariadic=*/true, getFunctionTypeAttrName(),
705  getArgAttrsAttrName(), getResAttrsAttrName());
706 }
707 
708 namespace {
709 struct EliminateSimpleControlMergesPattern
710  : mlir::OpRewritePattern<ControlMergeOp> {
712 
713  LogicalResult matchAndRewrite(ControlMergeOp op,
714  PatternRewriter &rewriter) const override;
715 };
716 } // namespace
717 
718 LogicalResult EliminateSimpleControlMergesPattern::matchAndRewrite(
719  ControlMergeOp op, PatternRewriter &rewriter) const {
720  auto dataResult = op.getResult();
721  auto choiceResult = op.getIndex();
722  auto choiceUnused = choiceResult.use_empty();
723  if (!choiceUnused && !choiceResult.hasOneUse())
724  return failure();
725 
726  Operation *choiceUser = nullptr;
727  if (choiceResult.hasOneUse()) {
728  choiceUser = choiceResult.getUses().begin().getUser();
729  if (!isa<SinkOp>(choiceUser))
730  return failure();
731  }
732 
733  auto merge = rewriter.create<MergeOp>(op.getLoc(), op.getDataOperands());
734 
735  for (auto &use : llvm::make_early_inc_range(dataResult.getUses())) {
736  auto *user = use.getOwner();
737  rewriter.modifyOpInPlace(
738  user, [&]() { user->setOperand(use.getOperandNumber(), merge); });
739  }
740 
741  if (choiceUnused) {
742  rewriter.eraseOp(op);
743  return success();
744  }
745 
746  rewriter.eraseOp(choiceUser);
747  rewriter.eraseOp(op);
748  return success();
749 }
750 
751 void ControlMergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
752  MLIRContext *context) {
753  results.insert<EliminateSimpleControlMergesPattern>(context);
754 }
755 
756 bool BranchOp::sostIsControl() {
757  return isControlCheckTypeAndOperand(getDataType(), getOperand());
758 }
759 
760 void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
761  MLIRContext *context) {
762  results.insert<circt::handshake::EliminateSimpleBranchesPattern>(context);
763 }
764 
765 ParseResult BranchOp::parse(OpAsmParser &parser, OperationState &result) {
766  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
767  Type type;
768  ArrayRef<Type> operandTypes(type);
769  SmallVector<Type, 1> dataOperandsTypes;
770  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
771  int size;
772  if (parseSostOperation(parser, allOperands, result, size, type, false))
773  return failure();
774 
775  dataOperandsTypes.assign(size, type);
776  result.addTypes({type});
777  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
778  result.operands))
779  return failure();
780  return success();
781 }
782 
783 void BranchOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
784 
785 ParseResult ConditionalBranchOp::parse(OpAsmParser &parser,
786  OperationState &result) {
787  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
788  Type dataType;
789  SmallVector<Type> operandTypes;
790  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
791  if (parser.parseOperandList(allOperands) ||
792  parser.parseOptionalAttrDict(result.attributes) ||
793  parser.parseColonType(dataType))
794  return failure();
795 
796  if (allOperands.size() != 2)
797  return parser.emitError(parser.getCurrentLocation(),
798  "Expected exactly 2 operands");
799 
800  result.addTypes({dataType, dataType});
801  operandTypes.push_back(IntegerType::get(parser.getContext(), 1));
802  operandTypes.push_back(dataType);
803  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
804  result.operands))
805  return failure();
806 
807  return success();
808 }
809 
810 void ConditionalBranchOp::print(OpAsmPrinter &p) {
811  Type type = getDataOperand().getType();
812  p << " " << getOperands();
813  p.printOptionalAttrDict((*this)->getAttrs());
814  p << " : " << type;
815 }
816 
817 std::string handshake::ConditionalBranchOp::getOperandName(unsigned int idx) {
818  assert(idx == 0 || idx == 1);
819  return idx == 0 ? "cond" : "data";
820 }
821 
822 std::string handshake::ConditionalBranchOp::getResultName(unsigned int idx) {
823  assert(idx == 0 || idx == 1);
824  return idx == ConditionalBranchOp::falseIndex ? "outFalse" : "outTrue";
825 }
826 
827 bool ConditionalBranchOp::isControl() {
828  return isControlCheckTypeAndOperand(getDataOperand().getType(),
829  getDataOperand());
830 }
831 
832 ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
833  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
834  Type type;
835  ArrayRef<Type> operandTypes(type);
836  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
837  int size;
838  if (parseSostOperation(parser, allOperands, result, size, type, false))
839  return failure();
840 
841  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
842  result.operands))
843  return failure();
844  return success();
845 }
846 
847 void SinkOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
848 
849 std::string handshake::ConstantOp::getOperandName(unsigned int idx) {
850  assert(idx == 0);
851  return "ctrl";
852 }
853 
854 Type SourceOp::getDataType() { return getResult().getType(); }
855 unsigned SourceOp::getSize() { return 1; }
856 
857 ParseResult SourceOp::parse(OpAsmParser &parser, OperationState &result) {
858  if (parser.parseOptionalAttrDict(result.attributes))
859  return failure();
860  result.addTypes(NoneType::get(result.getContext()));
861  return success();
862 }
863 
864 void SourceOp::print(OpAsmPrinter &p) {
865  p.printOptionalAttrDict((*this)->getAttrs());
866 }
867 
868 LogicalResult ConstantOp::verify() {
869  // Verify that the type of the provided value is equal to the result type.
870  auto typedValue = dyn_cast<mlir::TypedAttr>(getValue());
871  if (!typedValue)
872  return emitOpError("constant value must be a typed attribute; value is ")
873  << getValue();
874  if (typedValue.getType() != getResult().getType())
875  return emitOpError() << "constant value type " << typedValue.getType()
876  << " differs from operation result type "
877  << getResult().getType();
878 
879  return success();
880 }
881 
882 void handshake::ConstantOp::getCanonicalizationPatterns(
883  RewritePatternSet &results, MLIRContext *context) {
884  results.insert<circt::handshake::EliminateSunkConstantsPattern>(context);
885 }
886 
887 LogicalResult BufferOp::verify() {
888  // Verify that exactly 'size' number of initial values have been provided, if
889  // an initializer list have been provided.
890  if (auto initVals = getInitValues()) {
891  if (!isSequential())
892  return emitOpError()
893  << "only bufferType buffers are allowed to have initial values.";
894 
895  auto nInits = initVals->size();
896  if (nInits != getSize())
897  return emitOpError() << "expected " << getSize()
898  << " init values but got " << nInits << ".";
899  }
900 
901  return success();
902 }
903 
904 void handshake::BufferOp::getCanonicalizationPatterns(
905  RewritePatternSet &results, MLIRContext *context) {
906  results.insert<circt::handshake::EliminateSunkBuffersPattern>(context);
907 }
908 
909 unsigned BufferOp::getSize() {
910  return (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
911 }
912 
913 ParseResult BufferOp::parse(OpAsmParser &parser, OperationState &result) {
914  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
915  Type type;
916  ArrayRef<Type> operandTypes(type);
917  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
918  int slots;
919  if (parseIntInSquareBrackets(parser, slots))
920  return failure();
921 
922  auto bufferTypeAttr = BufferTypeEnumAttr::parse(parser, {});
923  if (!bufferTypeAttr)
924  return failure();
925 
926  result.addAttribute(
927  "slots",
928  IntegerAttr::get(IntegerType::get(result.getContext(), 32), slots));
929  result.addAttribute("bufferType", bufferTypeAttr);
930 
931  if (parser.parseOperandList(allOperands) ||
932  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
933  parser.parseType(type))
934  return failure();
935 
936  result.addTypes({type});
937  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
938  result.operands))
939  return failure();
940  return success();
941 }
942 
943 void BufferOp::print(OpAsmPrinter &p) {
944  int size =
945  (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
946  p << " [" << size << "]";
947  p << " " << stringifyEnum(getBufferType());
948  p << " " << (*this)->getOperands();
949  p.printOptionalAttrDict((*this)->getAttrs(), {"slots", "bufferType"});
950  p << " : " << (*this).getDataType();
951 }
952 
/// Names a memory operand given the number of store ports.
/// The first `2 * nStores` operands alternate store data / store address for
/// each store port (`stData<k>`, `stAddr<k>`). All following operands are load
/// addresses (`ldAddr<k>`), numbered from 0.
static std::string getMemoryOperandName(unsigned nStores, unsigned idx) {
  const unsigned numStoreOperands = 2 * nStores;
  if (idx >= numStoreOperands)
    return "ldAddr" + std::to_string(idx - numStoreOperands);

  const std::string port = std::to_string(idx / 2);
  return (idx % 2 == 0 ? "stData" : "stAddr") + port;
}
965 
966 std::string handshake::MemoryOp::getOperandName(unsigned int idx) {
967  return getMemoryOperandName(getStCount(), idx);
968 }
969 
/// Names a memory result. Results are laid out as follows:
///  - `nLoads` load data outputs (`ldData<k>`);
///  - then `nStores` store sync outputs (`stDone<k>`);
///  - then the load sync outputs (`ldDone<k>`).
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores,
                                       unsigned idx) {
  if (idx < nLoads)
    return "ldData" + std::to_string(idx);

  // Re-base the index past the load data outputs.
  idx -= nLoads;
  if (idx < nStores)
    return "stDone" + std::to_string(idx);

  return "ldDone" + std::to_string(idx - nStores);
}
981 
982 std::string handshake::MemoryOp::getResultName(unsigned int idx) {
983  return getMemoryResultName(getLdCount(), getStCount(), idx);
984 }
985 
986 LogicalResult MemoryOp::verify() {
987  auto memrefType = getMemRefType();
988 
989  if (memrefType.getNumDynamicDims() != 0)
990  return emitOpError()
991  << "memref dimensions for handshake.memory must be static.";
992  if (memrefType.getShape().size() != 1)
993  return emitOpError() << "memref must have only a single dimension.";
994 
995  unsigned opStCount = getStCount();
996  unsigned opLdCount = getLdCount();
997  int addressCount = memrefType.getShape().size();
998 
999  auto inputType = getInputs().getType();
1000  auto outputType = getOutputs().getType();
1001  Type dataType = memrefType.getElementType();
1002 
1003  unsigned numOperands = static_cast<int>(getInputs().size());
1004  unsigned numResults = static_cast<int>(getOutputs().size());
1005  if (numOperands != (1 + addressCount) * opStCount + addressCount * opLdCount)
1006  return emitOpError("number of operands ")
1007  << numOperands << " does not match number expected of "
1008  << 2 * opStCount + opLdCount << " with " << addressCount
1009  << " address inputs per port";
1010 
1011  if (numResults != opStCount + 2 * opLdCount)
1012  return emitOpError("number of results ")
1013  << numResults << " does not match number expected of "
1014  << opStCount + 2 * opLdCount << " with " << addressCount
1015  << " address inputs per port";
1016 
1017  Type addressType = opStCount > 0 ? inputType[1] : inputType[0];
1018 
1019  for (unsigned i = 0; i < opStCount; i++) {
1020  if (inputType[2 * i] != dataType)
1021  return emitOpError("data type for store port ")
1022  << i << ":" << inputType[2 * i] << " doesn't match memory type "
1023  << dataType;
1024  if (inputType[2 * i + 1] != addressType)
1025  return emitOpError("address type for store port ")
1026  << i << ":" << inputType[2 * i + 1]
1027  << " doesn't match address type " << addressType;
1028  }
1029  for (unsigned i = 0; i < opLdCount; i++) {
1030  Type ldAddressType = inputType[2 * opStCount + i];
1031  if (ldAddressType != addressType)
1032  return emitOpError("address type for load port ")
1033  << i << ":" << ldAddressType << " doesn't match address type "
1034  << addressType;
1035  }
1036  for (unsigned i = 0; i < opLdCount; i++) {
1037  if (outputType[i] != dataType)
1038  return emitOpError("data type for load port ")
1039  << i << ":" << outputType[i] << " doesn't match memory type "
1040  << dataType;
1041  }
1042  for (unsigned i = 0; i < opStCount; i++) {
1043  Type syncType = outputType[opLdCount + i];
1044  if (!isa<NoneType>(syncType))
1045  return emitOpError("data type for sync port for store port ")
1046  << i << ":" << syncType << " is not 'none'";
1047  }
1048  for (unsigned i = 0; i < opLdCount; i++) {
1049  Type syncType = outputType[opLdCount + opStCount + i];
1050  if (!isa<NoneType>(syncType))
1051  return emitOpError("data type for sync port for load port ")
1052  << i << ":" << syncType << " is not 'none'";
1053  }
1054 
1055  return success();
1056 }
1057 
1058 std::string handshake::ExternalMemoryOp::getOperandName(unsigned int idx) {
1059  if (idx == 0)
1060  return "extmem";
1061 
1062  return getMemoryOperandName(getStCount(), idx - 1);
1063 }
1064 
1065 std::string handshake::ExternalMemoryOp::getResultName(unsigned int idx) {
1066  return getMemoryResultName(getLdCount(), getStCount(), idx);
1067 }
1068 
1069 void ExternalMemoryOp::build(OpBuilder &builder, OperationState &result,
1070  Value memref, ValueRange inputs, int ldCount,
1071  int stCount, int id) {
1072  SmallVector<Value> ops;
1073  ops.push_back(memref);
1074  llvm::append_range(ops, inputs);
1075  result.addOperands(ops);
1076 
1077  auto memrefType = cast<MemRefType>(memref.getType());
1078 
1079  // Data outputs (get their type from memref)
1080  result.types.append(ldCount, memrefType.getElementType());
1081 
1082  // Control outputs
1083  result.types.append(stCount + ldCount, builder.getNoneType());
1084 
1085  // Memory ID (individual ID for each MemoryOp)
1086  Type i32Type = builder.getIntegerType(32);
1087  result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1088  result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, ldCount));
1089  result.addAttribute("stCount", builder.getIntegerAttr(i32Type, stCount));
1090 }
1091 
1092 llvm::SmallVector<handshake::MemLoadInterface>
1094  return ::getLoadPorts(*this);
1095 }
1096 
1097 llvm::SmallVector<handshake::MemStoreInterface>
1099  return ::getStorePorts(*this);
1100 }
1101 
1102 void MemoryOp::build(OpBuilder &builder, OperationState &result,
1103  ValueRange operands, int outputs, int controlOutputs,
1104  bool lsq, int id, Value memref) {
1105  result.addOperands(operands);
1106 
1107  auto memrefType = cast<MemRefType>(memref.getType());
1108 
1109  // Data outputs (get their type from memref)
1110  result.types.append(outputs, memrefType.getElementType());
1111 
1112  // Control outputs
1113  result.types.append(controlOutputs, builder.getNoneType());
1114  result.addAttribute("lsq", builder.getBoolAttr(lsq));
1115  result.addAttribute("memRefType", TypeAttr::get(memrefType));
1116 
1117  // Memory ID (individual ID for each MemoryOp)
1118  Type i32Type = builder.getIntegerType(32);
1119  result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1120 
1121  if (!lsq) {
1122  result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, outputs));
1123  result.addAttribute(
1124  "stCount", builder.getIntegerAttr(i32Type, controlOutputs - outputs));
1125  }
1126 }
1127 
1128 llvm::SmallVector<handshake::MemLoadInterface> MemoryOp::getLoadPorts() {
1129  return ::getLoadPorts(*this);
1130 }
1131 
1132 llvm::SmallVector<handshake::MemStoreInterface> MemoryOp::getStorePorts() {
1133  return ::getStorePorts(*this);
1134 }
1135 
1136 bool handshake::MemoryOp::allocateMemory(
1137  llvm::DenseMap<unsigned, unsigned> &memoryMap,
1138  std::vector<std::vector<llvm::Any>> &store,
1139  std::vector<double> &storeTimes) {
1140  if (memoryMap.count(getId()))
1141  return false;
1142 
1143  auto type = getMemRefType();
1144  std::vector<llvm::Any> in;
1145 
1146  ArrayRef<int64_t> shape = type.getShape();
1147  int allocationSize = 1;
1148  unsigned count = 0;
1149  for (int64_t dim : shape) {
1150  if (dim > 0)
1151  allocationSize *= dim;
1152  else {
1153  assert(count < in.size());
1154  allocationSize *= llvm::any_cast<APInt>(in[count++]).getSExtValue();
1155  }
1156  }
1157  unsigned ptr = store.size();
1158  store.resize(ptr + 1);
1159  storeTimes.resize(ptr + 1);
1160  store[ptr].resize(allocationSize);
1161  storeTimes[ptr] = 0.0;
1162  mlir::Type elementType = type.getElementType();
1163  int width = elementType.getIntOrFloatBitWidth();
1164  for (int i = 0; i < allocationSize; i++) {
1165  if (isa<mlir::IntegerType>(elementType)) {
1166  store[ptr][i] = APInt(width, 0);
1167  } else if (isa<mlir::FloatType>(elementType)) {
1168  store[ptr][i] = APFloat(0.0);
1169  } else {
1170  llvm_unreachable("Unknown result type!\n");
1171  }
1172  }
1173 
1174  memoryMap[getId()] = ptr;
1175  return true;
1176 }
1177 
1178 std::string handshake::LoadOp::getOperandName(unsigned int idx) {
1179  unsigned nAddresses = getAddresses().size();
1180  std::string opName;
1181  if (idx < nAddresses)
1182  opName = "addrIn" + std::to_string(idx);
1183  else if (idx == nAddresses)
1184  opName = "dataFromMem";
1185  else
1186  opName = "ctrl";
1187  return opName;
1188 }
1189 
1190 std::string handshake::LoadOp::getResultName(unsigned int idx) {
1191  std::string resName;
1192  if (idx == 0)
1193  resName = "dataOut";
1194  else
1195  resName = "addrOut" + std::to_string(idx - 1);
1196  return resName;
1197 }
1198 
1199 void handshake::LoadOp::build(OpBuilder &builder, OperationState &result,
1200  Value memref, ValueRange indices) {
1201  // Address indices
1202  // result.addOperands(memref);
1203  result.addOperands(indices);
1204 
1205  // Data type
1206  auto memrefType = cast<MemRefType>(memref.getType());
1207 
1208  // Data output (from load to successor ops)
1209  result.types.push_back(memrefType.getElementType());
1210 
1211  // Address outputs (to lsq)
1212  result.types.append(indices.size(), builder.getIndexType());
1213 }
1214 
1215 static ParseResult parseMemoryAccessOp(OpAsmParser &parser,
1216  OperationState &result) {
1217  SmallVector<OpAsmParser::UnresolvedOperand, 4> addressOperands,
1218  remainingOperands, allOperands;
1219  SmallVector<Type, 1> parsedTypes, allTypes;
1220  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1221 
1222  if (parser.parseLSquare() || parser.parseOperandList(addressOperands) ||
1223  parser.parseRSquare() || parser.parseOperandList(remainingOperands) ||
1224  parser.parseColon() || parser.parseTypeList(parsedTypes))
1225  return failure();
1226 
1227  // The last type will be the data type of the operation; the prior will be the
1228  // address types.
1229  Type dataType = parsedTypes.back();
1230  auto parsedTypesRef = ArrayRef(parsedTypes);
1231  result.addTypes(dataType);
1232  result.addTypes(parsedTypesRef.drop_back());
1233  allOperands.append(addressOperands);
1234  allOperands.append(remainingOperands);
1235  allTypes.append(parsedTypes);
1236  allTypes.push_back(NoneType::get(result.getContext()));
1237  if (parser.resolveOperands(allOperands, allTypes, allOperandLoc,
1238  result.operands))
1239  return failure();
1240  return success();
1241 }
1242 
1243 template <typename MemOp>
1244 static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op) {
1245  p << " [";
1246  p << op.getAddresses();
1247  p << "] " << op.getData() << ", " << op.getCtrl() << " : ";
1248  llvm::interleaveComma(op.getAddresses(), p,
1249  [&](Value v) { p << v.getType(); });
1250  p << ", " << op.getData().getType();
1251 }
1252 
1253 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
1254  return parseMemoryAccessOp(parser, result);
1255 }
1256 
1257 void LoadOp::print(OpAsmPrinter &p) { printMemoryAccessOp(p, *this); }
1258 
1259 std::string handshake::StoreOp::getOperandName(unsigned int idx) {
1260  unsigned nAddresses = getAddresses().size();
1261  std::string opName;
1262  if (idx < nAddresses)
1263  opName = "addrIn" + std::to_string(idx);
1264  else if (idx == nAddresses)
1265  opName = "dataIn";
1266  else
1267  opName = "ctrl";
1268  return opName;
1269 }
1270 
1271 template <typename TMemoryOp>
1272 static LogicalResult verifyMemoryAccessOp(TMemoryOp op) {
1273  if (op.getAddresses().size() == 0)
1274  return op.emitOpError() << "No addresses were specified";
1275 
1276  return success();
1277 }
1278 
1279 LogicalResult LoadOp::verify() { return verifyMemoryAccessOp(*this); }
1280 
1281 std::string handshake::StoreOp::getResultName(unsigned int idx) {
1282  std::string resName;
1283  if (idx == 0)
1284  resName = "dataToMem";
1285  else
1286  resName = "addrOut" + std::to_string(idx - 1);
1287  return resName;
1288 }
1289 
1290 void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
1291  Value valueToStore, ValueRange indices) {
1292 
1293  // Address indices
1294  result.addOperands(indices);
1295 
1296  // Data
1297  result.addOperands(valueToStore);
1298 
1299  // Data output (from store to LSQ)
1300  result.types.push_back(valueToStore.getType());
1301 
1302  // Address outputs (from store to lsq)
1303  result.types.append(indices.size(), builder.getIndexType());
1304 }
1305 
1306 LogicalResult StoreOp::verify() { return verifyMemoryAccessOp(*this); }
1307 
1308 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
1309  return parseMemoryAccessOp(parser, result);
1310 }
1311 
1312 void StoreOp::print(OpAsmPrinter &p) { return printMemoryAccessOp(p, *this); }
1313 
1314 bool JoinOp::isControl() { return true; }
1315 
1316 ParseResult JoinOp::parse(OpAsmParser &parser, OperationState &result) {
1317  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1318  SmallVector<Type> types;
1319 
1320  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1321  if (parser.parseOperandList(operands) ||
1322  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1323  parser.parseTypeList(types))
1324  return failure();
1325 
1326  if (parser.resolveOperands(operands, types, allOperandLoc, result.operands))
1327  return failure();
1328 
1329  result.addTypes(NoneType::get(result.getContext()));
1330  return success();
1331 }
1332 
1333 void JoinOp::print(OpAsmPrinter &p) {
1334  p << " " << getData();
1335  p.printOptionalAttrDict((*this)->getAttrs(), {"control"});
1336  p << " : " << getData().getTypes();
1337 }
1338 
1339 LogicalResult
1340 ESIInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1341  // Check that the module attribute was specified.
1342  auto fnAttr = this->getModuleAttr();
1343  assert(fnAttr && "requires a 'module' symbol reference attribute");
1344 
1345  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
1346  if (!fn)
1347  return emitOpError() << "'" << fnAttr.getValue()
1348  << "' does not reference a valid handshake function";
1349 
1350  // Verify that the operand and result types match the callee.
1351  auto fnType = fn.getFunctionType();
1352  if (fnType.getNumInputs() != getNumOperands() - NumFixedOperands)
1353  return emitOpError(
1354  "incorrect number of operands for the referenced handshake function");
1355 
1356  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1357  Type operandType = getOperand(i + NumFixedOperands).getType();
1358  auto channelType = dyn_cast<esi::ChannelType>(operandType);
1359  if (!channelType)
1360  return emitOpError("operand type mismatch: expected channel type, but "
1361  "provided ")
1362  << operandType << " for operand number " << i;
1363  if (channelType.getInner() != fnType.getInput(i))
1364  return emitOpError("operand type mismatch: expected operand type ")
1365  << fnType.getInput(i) << ", but provided "
1366  << getOperand(i).getType() << " for operand number " << i;
1367  }
1368 
1369  if (fnType.getNumResults() != getNumResults())
1370  return emitOpError(
1371  "incorrect number of results for the referenced handshake function");
1372 
1373  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1374  Type resultType = getResult(i).getType();
1375  auto channelType = dyn_cast<esi::ChannelType>(resultType);
1376  if (!channelType)
1377  return emitOpError("result type mismatch: expected channel type, but "
1378  "provided ")
1379  << resultType << " for result number " << i;
1380  if (channelType.getInner() != fnType.getResult(i))
1381  return emitOpError("result type mismatch: expected result type ")
1382  << fnType.getResult(i) << ", but provided "
1383  << getResult(i).getType() << " for result number " << i;
1384  }
1385 
1386  return success();
1387 }
1388 
1389 /// Based on mlir::func::CallOp::verifySymbolUses
1390 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1391  // Check that the module attribute was specified.
1392  auto fnAttr = this->getModuleAttr();
1393  assert(fnAttr && "requires a 'module' symbol reference attribute");
1394 
1395  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
1396  if (!fn)
1397  return emitOpError() << "'" << fnAttr.getValue()
1398  << "' does not reference a valid handshake function";
1399 
1400  // Verify that the operand and result types match the callee.
1401  auto fnType = fn.getFunctionType();
1402  if (fnType.getNumInputs() != getNumOperands())
1403  return emitOpError(
1404  "incorrect number of operands for the referenced handshake function");
1405 
1406  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
1407  if (getOperand(i).getType() != fnType.getInput(i))
1408  return emitOpError("operand type mismatch: expected operand type ")
1409  << fnType.getInput(i) << ", but provided "
1410  << getOperand(i).getType() << " for operand number " << i;
1411 
1412  if (fnType.getNumResults() != getNumResults())
1413  return emitOpError(
1414  "incorrect number of results for the referenced handshake function");
1415 
1416  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
1417  if (getResult(i).getType() != fnType.getResult(i))
1418  return emitOpError("result type mismatch: expected result type ")
1419  << fnType.getResult(i) << ", but provided "
1420  << getResult(i).getType() << " for result number " << i;
1421 
1422  return success();
1423 }
1424 
1425 FunctionType InstanceOp::getModuleType() {
1426  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
1427 }
1428 
1429 ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1430  OpAsmParser::UnresolvedOperand tuple;
1431  TupleType type;
1432 
1433  if (parser.parseOperand(tuple) ||
1434  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1435  parser.parseType(type))
1436  return failure();
1437 
1438  if (parser.resolveOperand(tuple, type, result.operands))
1439  return failure();
1440 
1441  result.addTypes(type.getTypes());
1442 
1443  return success();
1444 }
1445 
1446 void UnpackOp::print(OpAsmPrinter &p) {
1447  p << " " << getInput();
1448  p.printOptionalAttrDict((*this)->getAttrs());
1449  p << " : " << getInput().getType();
1450 }
1451 
1452 ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1453  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1454  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1455  TupleType type;
1456 
1457  if (parser.parseOperandList(operands) ||
1458  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1459  parser.parseType(type))
1460  return failure();
1461 
1462  if (parser.resolveOperands(operands, type.getTypes(), allOperandLoc,
1463  result.operands))
1464  return failure();
1465 
1466  result.addTypes(type);
1467 
1468  return success();
1469 }
1470 
1471 void PackOp::print(OpAsmPrinter &p) {
1472  p << " " << getInputs();
1473  p.printOptionalAttrDict((*this)->getAttrs());
1474  p << " : " << getResult().getType();
1475 }
1476 
1477 //===----------------------------------------------------------------------===//
1478 // TableGen'd op method definitions
1479 //===----------------------------------------------------------------------===//
1480 
1481 LogicalResult ReturnOp::verify() {
1482  auto *parent = (*this)->getParentOp();
1483  auto function = dyn_cast<handshake::FuncOp>(parent);
1484  if (!function)
1485  return emitOpError("must have a handshake.func parent");
1486 
1487  // The operand number and types must match the function signature.
1488  const auto &results = function.getResultTypes();
1489  if (getNumOperands() != results.size())
1490  return emitOpError("has ")
1491  << getNumOperands() << " operands, but enclosing function returns "
1492  << results.size();
1493 
1494  for (unsigned i = 0, e = results.size(); i != e; ++i)
1495  if (getOperand(i).getType() != results[i])
1496  return emitError() << "type of return operand " << i << " ("
1497  << getOperand(i).getType()
1498  << ") doesn't match function result type ("
1499  << results[i] << ")";
1500 
1501  return success();
1502 }
1503 
1504 #define GET_OP_CLASSES
1505 #include "circt/Dialect/Handshake/Handshake.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition: CHIRRTL.cpp:29
static StringRef getOperandName(Value operand)
static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal, uint64_t numOperands)
Verifies whether an indexing value is wide enough to index into a provided number of operands.
static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result)
static ParseResult parseSostOperation(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, OperationState &result, int &size, Type &type, bool explicitSize)
static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op)
static ParseResult parseFuncOpArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Type > &resTypes, SmallVectorImpl< DictionaryAttr > &resAttrs)
Parses a FuncOp signature using mlir::function_interface_impl::parseFunctionSignature while getting a...
static Value getDematerialized(Value v)
Returns a dematerialized version of the value 'v', defined as the source of the value before passing ...
static bool isControlCheckTypeAndOperand(Type dataType, Value operand)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores, unsigned idx)
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v)
static LogicalResult verifyMemoryAccessOp(TMemoryOp op)
static void addStringToStringArrayAttr(Builder &builder, Operation *op, StringRef attrName, StringAttr str)
Helper function for appending a string to an array attribute, and rewriting the attribute back to the...
static std::string defaultOperandName(unsigned int idx)
static SmallVector< Attribute > getFuncOpNames(Builder &builder, unsigned cnt, StringRef prefix)
Generates names for a handshake.func input and output arguments, based on the number of args as well ...
static std::string getMemoryOperandName(unsigned nStores, unsigned idx)
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
static ParseResult parseMemoryAccessOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition: SVOps.cpp:2467
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
bool isControlOpImpl(Operation *op)
Default implementation for checking whether an operation is a control operation.
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition: HWOps.cpp:521
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21