Loading [MathJax]/extensions/MathMenu.js
CIRCT 20.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
HandshakeOps.cpp
Go to the documentation of this file.
1//===- HandshakeOps.cpp - Handshake MLIR Operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the Handshake dialect operations.
10//
11//===----------------------------------------------------------------------===//
12
15#include "circt/Support/LLVM.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/IntegerSet.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/OpDefinition.h"
24#include "mlir/IR/OpImplementation.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/IR/SymbolTable.h"
27#include "mlir/IR/Value.h"
28#include "mlir/Interfaces/FunctionImplementation.h"
29#include "mlir/Transforms/InliningUtils.h"
30#include "llvm/ADT/SetVector.h"
31#include "llvm/ADT/SmallBitVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34#include <set>
35
36using namespace circt;
37using namespace circt::handshake;
38
39namespace circt {
40namespace handshake {
41#include "circt/Dialect/Handshake/HandshakeCanonicalization.h.inc"
42
43bool isControlOpImpl(Operation *op) {
44 if (auto sostInterface = dyn_cast<SOSTInterface>(op); sostInterface)
45 return sostInterface.sostIsControl();
46
47 return false;
48}
49
50} // namespace handshake
51} // namespace circt
52
/// Builds the default name of the operand at position `idx`: the literal
/// "in" followed by the decimal index (e.g. "in0", "in12").
static std::string defaultOperandName(unsigned int idx) {
  std::string name("in");
  name += std::to_string(idx);
  return name;
}
56
57static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v) {
58 if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
59 return failure();
60 return success();
61}
62
63static ParseResult
64parseSostOperation(OpAsmParser &parser,
65 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
66 OperationState &result, int &size, Type &type,
67 bool explicitSize) {
68 if (explicitSize)
69 if (parseIntInSquareBrackets(parser, size))
70 return failure();
71
72 if (parser.parseOperandList(operands) ||
73 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
74 parser.parseType(type))
75 return failure();
76
77 if (!explicitSize)
78 size = operands.size();
79 return success();
80}
81
82/// Verifies whether an indexing value is wide enough to index into a provided
83/// number of operands.
84static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal,
85 uint64_t numOperands) {
86 auto indexType = indexVal.getType();
87 unsigned indexWidth;
88
89 // Determine the bitwidth of the indexing value
90 if (auto integerType = dyn_cast<IntegerType>(indexType))
91 indexWidth = integerType.getWidth();
92 else if (indexType.isIndex())
93 indexWidth = IndexType::kInternalStorageBitWidth;
94 else
95 return op->emitError("unsupported type for indexing value: ") << indexType;
96
97 // Check whether the bitwidth can support the provided number of operands
98 if (indexWidth < 64) {
99 uint64_t maxNumOperands = (uint64_t)1 << indexWidth;
100 if (numOperands > maxNumOperands)
101 return op->emitError("bitwidth of indexing value is ")
102 << indexWidth << ", which can index into " << maxNumOperands
103 << " operands, but found " << numOperands << " operands";
104 }
105 return success();
106}
107
108static bool isControlCheckTypeAndOperand(Type dataType, Value operand) {
109 // The operation is a control operation if its operand data type is a
110 // NoneType.
111 if (isa<NoneType>(dataType))
112 return true;
113
114 // Otherwise, the operation is a control operation if the operation's
115 // operand originates from the control network
116 auto *defOp = operand.getDefiningOp();
117 return isa_and_nonnull<ControlMergeOp>(defOp) &&
118 operand == defOp->getResult(0);
119}
120
121template <typename TMemOp>
122llvm::SmallVector<handshake::MemLoadInterface> getLoadPorts(TMemOp op) {
123 llvm::SmallVector<MemLoadInterface> ports;
124 // Memory interface refresher:
125 // Operands:
126 // all stores (stdata1, staddr1, stdata2, staddr2, ...)
127 // then all loads (ldaddr1, ldaddr2,...)
128 // Outputs: load addresses (lddata1, lddata2, ...), followed by all none
129 // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
130 unsigned stCount = op.getStCount();
131 unsigned ldCount = op.getLdCount();
132 for (unsigned i = 0, e = ldCount; i != e; ++i) {
133 MemLoadInterface ldif;
134 ldif.index = i;
135 ldif.addressIn = op.getInputs()[stCount * 2 + i];
136 ldif.dataOut = op.getResult(i);
137 ldif.doneOut = op.getResult(ldCount + stCount + i);
138 ports.push_back(ldif);
139 }
140 return ports;
141}
142
143template <typename TMemOp>
144llvm::SmallVector<handshake::MemStoreInterface> getStorePorts(TMemOp op) {
145 llvm::SmallVector<MemStoreInterface> ports;
146 // Memory interface refresher:
147 // Operands:
148 // all stores (stdata1, staddr1, stdata2, staddr2, ...)
149 // then all loads (ldaddr1, ldaddr2,...)
150 // Outputs: load data (lddata1, lddata2, ...), followed by all none
151 // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
152 unsigned ldCount = op.getLdCount();
153 for (unsigned i = 0, e = op.getStCount(); i != e; ++i) {
155 stif.index = i;
156 stif.dataIn = op.getInputs()[i * 2];
157 stif.addressIn = op.getInputs()[i * 2 + 1];
158 stif.doneOut = op.getResult(ldCount + i);
159 ports.push_back(stif);
160 }
161 return ports;
162}
163
164unsigned ForkOp::getSize() { return getResults().size(); }
165
166static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result) {
167 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
168 Type type;
169 ArrayRef<Type> operandTypes(type);
170 SmallVector<Type, 1> resultTypes;
171 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
172 int size;
173 if (parseSostOperation(parser, allOperands, result, size, type, true))
174 return failure();
175
176 resultTypes.assign(size, type);
177 result.addTypes(resultTypes);
178 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
179 result.operands))
180 return failure();
181 return success();
182}
183
184ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
185 return parseForkOp(parser, result);
186}
187
188void ForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
189
190namespace {
191
192struct EliminateUnusedForkResultsPattern : mlir::OpRewritePattern<ForkOp> {
194
195 LogicalResult matchAndRewrite(ForkOp op,
196 PatternRewriter &rewriter) const override {
197 std::set<unsigned> unusedIndexes;
198
199 for (auto res : llvm::enumerate(op.getResults()))
200 if (res.value().getUses().empty())
201 unusedIndexes.insert(res.index());
202
203 if (unusedIndexes.empty())
204 return failure();
205
206 // Create a new fork op, dropping the unused results.
207 rewriter.setInsertionPoint(op);
208 auto operand = op.getOperand();
209 auto newFork = rewriter.create<ForkOp>(
210 op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
211 unsigned i = 0;
212 for (auto oldRes : llvm::enumerate(op.getResults()))
213 if (unusedIndexes.count(oldRes.index()) == 0)
214 rewriter.replaceAllUsesWith(oldRes.value(), newFork.getResults()[i++]);
215 rewriter.eraseOp(op);
216 return success();
217 }
218};
219
220struct EliminateForkToForkPattern : mlir::OpRewritePattern<ForkOp> {
222
223 LogicalResult matchAndRewrite(ForkOp op,
224 PatternRewriter &rewriter) const override {
225 auto parentForkOp = op.getOperand().getDefiningOp<ForkOp>();
226 if (!parentForkOp)
227 return failure();
228
229 /// Create the fork with as many outputs as the two source forks.
230 /// Keeping the op.operand() output may or may not be redundant (dependning
231 /// on if op is the single user of the value), but we'll let
232 /// EliminateUnusedForkResultsPattern apply in that case.
233 unsigned totalNumOuts = op.getSize() + parentForkOp.getSize();
234 /// Create a new parent fork op which produces all of the fork outputs and
235 /// replace all of the uses of the old results.
236 auto newParentForkOp = rewriter.create<ForkOp>(
237 parentForkOp.getLoc(), parentForkOp.getOperand(), totalNumOuts);
238
239 for (auto it :
240 llvm::zip(parentForkOp->getResults(), newParentForkOp.getResults()))
241 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
242
243 /// Replace the results of the matches fork op with the corresponding
244 /// results of the new parent fork op.
245 rewriter.replaceOp(op,
246 newParentForkOp.getResults().take_back(op.getSize()));
247 rewriter.eraseOp(parentForkOp);
248 return success();
249 }
250};
251
252} // namespace
253
254void handshake::ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
255 MLIRContext *context) {
256 results.insert<circt::handshake::EliminateSimpleForksPattern,
257 EliminateForkToForkPattern, EliminateUnusedForkResultsPattern>(
258 context);
259}
260
261unsigned LazyForkOp::getSize() { return getResults().size(); }
262
263bool LazyForkOp::sostIsControl() {
264 return isControlCheckTypeAndOperand(getDataType(), getOperand());
265}
266
267ParseResult LazyForkOp::parse(OpAsmParser &parser, OperationState &result) {
268 return parseForkOp(parser, result);
269}
270
271void LazyForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
272
273ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
274 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
275 Type type;
276 ArrayRef<Type> operandTypes(type);
277 SmallVector<Type, 1> resultTypes, dataOperandsTypes;
278 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
279 int size;
280 if (parseSostOperation(parser, allOperands, result, size, type, false))
281 return failure();
282
283 dataOperandsTypes.assign(size, type);
284 resultTypes.push_back(type);
285 result.addTypes(resultTypes);
286 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
287 result.operands))
288 return failure();
289 return success();
290}
291
292void MergeOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
293
294void MergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
295 MLIRContext *context) {
296 results.insert<circt::handshake::EliminateSimpleMergesPattern>(context);
297}
298
299/// Returns a dematerialized version of the value 'v', defined as the source of
300/// the value before passing through a buffer or fork operation.
301static Value getDematerialized(Value v) {
302 Operation *parentOp = v.getDefiningOp();
303 if (!parentOp)
304 return v;
305
306 return llvm::TypeSwitch<Operation *, Value>(parentOp)
307 .Case<ForkOp>(
308 [&](ForkOp op) { return getDematerialized(op.getOperand()); })
309 .Case<BufferOp>(
310 [&](BufferOp op) { return getDematerialized(op.getOperand()); })
311 .Default([&](auto) { return v; });
312}
313
314namespace {
315
316/// Eliminates muxes with identical data inputs. Data inputs are inspected as
317/// their dematerialized versions. This has the side effect of any subsequently
318/// unused buffers are DCE'd and forks are optimized to be narrower.
319struct EliminateSimpleMuxesPattern : mlir::OpRewritePattern<MuxOp> {
321 LogicalResult matchAndRewrite(MuxOp op,
322 PatternRewriter &rewriter) const override {
323 Value firstDataOperand = getDematerialized(op.getDataOperands()[0]);
324 if (!llvm::all_of(op.getDataOperands(), [&](Value operand) {
325 return getDematerialized(operand) == firstDataOperand;
326 }))
327 return failure();
328 rewriter.replaceOp(op, firstDataOperand);
329 return success();
330 }
331};
332
333struct EliminateUnaryMuxesPattern : OpRewritePattern<MuxOp> {
335 LogicalResult matchAndRewrite(MuxOp op,
336 PatternRewriter &rewriter) const override {
337 if (op.getDataOperands().size() != 1)
338 return failure();
339
340 rewriter.replaceOp(op, op.getDataOperands()[0]);
341 return success();
342 }
343};
344
345struct EliminateCBranchIntoMuxPattern : OpRewritePattern<MuxOp> {
347 LogicalResult matchAndRewrite(MuxOp op,
348 PatternRewriter &rewriter) const override {
349
350 auto dataOperands = op.getDataOperands();
351 if (dataOperands.size() != 2)
352 return failure();
353
354 // Both data operands must originate from the same cbranch
355 ConditionalBranchOp firstParentCBranch =
356 dataOperands[0].getDefiningOp<ConditionalBranchOp>();
357 if (!firstParentCBranch)
358 return failure();
359 auto secondParentCBranch =
360 dataOperands[1].getDefiningOp<ConditionalBranchOp>();
361 if (!secondParentCBranch || firstParentCBranch != secondParentCBranch)
362 return failure();
363
364 rewriter.modifyOpInPlace(firstParentCBranch, [&] {
365 // Replace uses of the mux's output with cbranch's data input
366 rewriter.replaceOp(op, firstParentCBranch.getDataOperand());
367 });
368
369 return success();
370 }
371};
372
373} // namespace
374
375void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
376 MLIRContext *context) {
377 results.insert<EliminateSimpleMuxesPattern, EliminateUnaryMuxesPattern,
378 EliminateCBranchIntoMuxPattern>(context);
379}
380
381LogicalResult
382MuxOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
383 ValueRange operands, DictionaryAttr attributes,
384 mlir::OpaqueProperties properties,
385 mlir::RegionRange regions,
386 SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
387 // MuxOp must have at least one data operand (in addition to the select
388 // operand)
389 if (operands.size() < 2)
390 return failure();
391 // Result type is type of any data operand
392 inferredReturnTypes.push_back(operands[1].getType());
393 return success();
394}
395
396bool MuxOp::isControl() { return isa<NoneType>(getResult().getType()); }
397
398std::string handshake::MuxOp::getOperandName(unsigned int idx) {
399 return idx == 0 ? "select" : defaultOperandName(idx - 1);
400}
401
402ParseResult MuxOp::parse(OpAsmParser &parser, OperationState &result) {
403 OpAsmParser::UnresolvedOperand selectOperand;
404 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
405 Type selectType, dataType;
406 SmallVector<Type, 1> dataOperandsTypes;
407 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
408 if (parser.parseOperand(selectOperand) || parser.parseLSquare() ||
409 parser.parseOperandList(allOperands) || parser.parseRSquare() ||
410 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
411 parser.parseType(selectType) || parser.parseComma() ||
412 parser.parseType(dataType))
413 return failure();
414
415 int size = allOperands.size();
416 dataOperandsTypes.assign(size, dataType);
417 result.addTypes(dataType);
418 allOperands.insert(allOperands.begin(), selectOperand);
419 if (parser.resolveOperands(
420 allOperands,
421 llvm::concat<const Type>(ArrayRef<Type>(selectType),
422 ArrayRef<Type>(dataOperandsTypes)),
423 allOperandLoc, result.operands))
424 return failure();
425 return success();
426}
427
428void MuxOp::print(OpAsmPrinter &p) {
429 Type selectType = getSelectOperand().getType();
430 auto ops = getOperands();
431 p << ' ' << ops.front();
432 p << " [";
433 p.printOperands(ops.drop_front());
434 p << "]";
435 p.printOptionalAttrDict((*this)->getAttrs());
436 p << " : " << selectType << ", " << getResult().getType();
437}
438
439LogicalResult MuxOp::verify() {
440 return verifyIndexWideEnough(*this, getSelectOperand(),
441 getDataOperands().size());
442}
443
444std::string handshake::ControlMergeOp::getResultName(unsigned int idx) {
445 assert(idx == 0 || idx == 1);
446 return idx == 0 ? "dataOut" : "index";
447}
448
449ParseResult ControlMergeOp::parse(OpAsmParser &parser, OperationState &result) {
450 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
451 Type resultType, indexType;
452 SmallVector<Type> resultTypes, dataOperandsTypes;
453 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
454 int size;
455 if (parseSostOperation(parser, allOperands, result, size, resultType, false))
456 return failure();
457 // Parse type of index result
458 if (parser.parseComma() || parser.parseType(indexType))
459 return failure();
460
461 dataOperandsTypes.assign(size, resultType);
462 resultTypes.push_back(resultType);
463 resultTypes.push_back(indexType);
464 result.addTypes(resultTypes);
465 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
466 result.operands))
467 return failure();
468 return success();
469}
470
471void ControlMergeOp::print(OpAsmPrinter &p) {
472 sostPrint(p, false);
473 // Print type of index result
474 p << ", " << getIndex().getType();
475}
476
477LogicalResult ControlMergeOp::verify() {
478 auto operands = getOperands();
479 if (operands.empty())
480 return emitOpError("operation must have at least one operand");
481 if (operands[0].getType() != getResult().getType())
482 return emitOpError("type of first result should match type of operands");
483 return verifyIndexWideEnough(*this, getIndex(), getNumOperands());
484}
485
486LogicalResult FuncOp::verify() {
487 // If this function is external there is nothing to do.
488 if (isExternal())
489 return success();
490
491 // Verify that the argument list of the function and the arg list of the
492 // entry block line up. The trait already verified that the number of
493 // arguments is the same between the signature and the block.
494 auto fnInputTypes = getArgumentTypes();
495 Block &entryBlock = front();
496
497 for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
498 if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
499 return emitOpError("type of entry block argument #")
500 << i << '(' << entryBlock.getArgument(i).getType()
501 << ") must match the type of the corresponding argument in "
502 << "function signature(" << fnInputTypes[i] << ')';
503
504 // Verify that we have a name for each argument and result of this function.
505 auto verifyPortNameAttr = [&](StringRef attrName,
506 unsigned numIOs) -> LogicalResult {
507 auto portNamesAttr = (*this)->getAttrOfType<ArrayAttr>(attrName);
508
509 if (!portNamesAttr)
510 return emitOpError() << "expected attribute '" << attrName << "'.";
511
512 auto portNames = portNamesAttr.getValue();
513 if (portNames.size() != numIOs)
514 return emitOpError() << "attribute '" << attrName << "' has "
515 << portNames.size()
516 << " entries but is expected to have " << numIOs
517 << ".";
518
519 if (llvm::any_of(portNames,
520 [&](Attribute attr) { return !isa<StringAttr>(attr); }))
521 return emitOpError() << "expected all entries in attribute '" << attrName
522 << "' to be strings.";
523
524 return success();
525 };
526 if (failed(verifyPortNameAttr("argNames", getNumArguments())))
527 return failure();
528 if (failed(verifyPortNameAttr("resNames", getNumResults())))
529 return failure();
530
531 // Verify that all memrefs have a corresponding extmemory operation
532 for (auto arg : entryBlock.getArguments()) {
533 if (!isa<MemRefType>(arg.getType()))
534 continue;
535 if (arg.getUsers().empty() ||
536 !isa<ExternalMemoryOp>(*arg.getUsers().begin()))
537 return emitOpError("expected that block argument #")
538 << arg.getArgNumber() << " is used by an 'extmemory' operation";
539 }
540
541 return success();
542}
543
544/// Parses a FuncOp signature using
545/// mlir::function_interface_impl::parseFunctionSignature while getting access
546/// to the parsed SSA names to store as attributes.
547static ParseResult
548parseFuncOpArgs(OpAsmParser &parser,
549 SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
550 SmallVectorImpl<Type> &resTypes,
551 SmallVectorImpl<DictionaryAttr> &resAttrs) {
552 bool isVariadic;
553 if (mlir::function_interface_impl::parseFunctionSignature(
554 parser, /*allowVariadic=*/true, entryArgs, isVariadic, resTypes,
555 resAttrs)
556 .failed())
557 return failure();
558
559 return success();
560}
561
562/// Generates names for a handshake.func input and output arguments, based on
563/// the number of args as well as a prefix.
564static SmallVector<Attribute> getFuncOpNames(Builder &builder, unsigned cnt,
565 StringRef prefix) {
566 SmallVector<Attribute> resNames;
567 for (unsigned i = 0; i < cnt; ++i)
568 resNames.push_back(builder.getStringAttr(prefix + std::to_string(i)));
569 return resNames;
570}
571
572void handshake::FuncOp::build(OpBuilder &builder, OperationState &state,
573 StringRef name, FunctionType type,
574 ArrayRef<NamedAttribute> attrs) {
575 state.addAttribute(SymbolTable::getSymbolAttrName(),
576 builder.getStringAttr(name));
577 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
578 TypeAttr::get(type));
579 state.attributes.append(attrs.begin(), attrs.end());
580
581 if (const auto *argNamesAttrIt = llvm::find_if(
582 attrs, [&](auto attr) { return attr.getName() == "argNames"; });
583 argNamesAttrIt == attrs.end())
584 state.addAttribute("argNames", builder.getArrayAttr({}));
585
586 if (llvm::find_if(attrs, [&](auto attr) {
587 return attr.getName() == "resNames";
588 }) == attrs.end())
589 state.addAttribute("resNames", builder.getArrayAttr({}));
590
591 state.addRegion();
592}
593
594/// Helper function for appending a string to an array attribute, and
595/// rewriting the attribute back to the operation.
596static void addStringToStringArrayAttr(Builder &builder, Operation *op,
597 StringRef attrName, StringAttr str) {
598 llvm::SmallVector<Attribute> attrs;
599 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
600 std::back_inserter(attrs));
601 attrs.push_back(str);
602 op->setAttr(attrName, builder.getArrayAttr(attrs));
603}
604
605void handshake::FuncOp::resolveArgAndResNames() {
606 Builder builder(getContext());
607
608 /// Generate a set of fallback names. These are used in case names are
609 /// missing from the currently set arg- and res name attributes.
610 auto fallbackArgNames = getFuncOpNames(builder, getNumArguments(), "in");
611 auto fallbackResNames = getFuncOpNames(builder, getNumResults(), "out");
612 auto argNames = getArgNames().getValue();
613 auto resNames = getResNames().getValue();
614
615 /// Use fallback names where actual names are missing.
616 auto resolveNames = [&](auto &fallbackNames, auto &actualNames,
617 StringRef attrName) {
618 for (auto fallbackName : llvm::enumerate(fallbackNames)) {
619 if (actualNames.size() <= fallbackName.index())
620 addStringToStringArrayAttr(builder, this->getOperation(), attrName,
621 cast<StringAttr>(fallbackName.value()));
622 }
623 };
624 resolveNames(fallbackArgNames, argNames, "argNames");
625 resolveNames(fallbackResNames, resNames, "resNames");
626}
627
628ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
629 auto &builder = parser.getBuilder();
630 StringAttr nameAttr;
631 SmallVector<OpAsmParser::Argument> args;
632 SmallVector<Type> resTypes;
633 SmallVector<DictionaryAttr> resAttributes;
634 SmallVector<Attribute> argNames;
635
636 // Parse visibility.
637 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
638
639 // Parse signature
640 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
641 result.attributes) ||
642 parseFuncOpArgs(parser, args, resTypes, resAttributes))
643 return failure();
644 mlir::function_interface_impl::addArgAndResultAttrs(
645 builder, result, args, resAttributes,
646 handshake::FuncOp::getArgAttrsAttrName(result.name),
647 handshake::FuncOp::getResAttrsAttrName(result.name));
648
649 // Set function type
650 SmallVector<Type> argTypes;
651 for (auto arg : args)
652 argTypes.push_back(arg.type);
653
654 result.addAttribute(
655 handshake::FuncOp::getFunctionTypeAttrName(result.name),
656 TypeAttr::get(builder.getFunctionType(argTypes, resTypes)));
657
658 // Determine the names of the arguments. If no SSA values are present, use
659 // fallback names.
660 bool noSSANames =
661 llvm::any_of(args, [](auto arg) { return arg.ssaName.name.empty(); });
662 if (noSSANames) {
663 argNames = getFuncOpNames(builder, args.size(), "in");
664 } else {
665 llvm::transform(args, std::back_inserter(argNames), [&](auto arg) {
666 return builder.getStringAttr(arg.ssaName.name.drop_front());
667 });
668 }
669
670 // Parse attributes
671 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
672 return failure();
673
674 // If argNames and resNames wasn't provided manually, infer argNames attribute
675 // from the parsed SSA names and resNames from our naming convention.
676 if (!result.attributes.get("argNames"))
677 result.addAttribute("argNames", builder.getArrayAttr(argNames));
678 if (!result.attributes.get("resNames")) {
679 auto resNames = getFuncOpNames(builder, resTypes.size(), "out");
680 result.addAttribute("resNames", builder.getArrayAttr(resNames));
681 }
682
683 // Parse the optional function body. The printer will not print the body if
684 // its empty, so disallow parsing of empty body in the parser.
685 auto *body = result.addRegion();
686 llvm::SMLoc loc = parser.getCurrentLocation();
687 auto parseResult = parser.parseOptionalRegion(*body, args,
688 /*enableNameShadowing=*/false);
689 if (!parseResult.has_value())
690 return success();
691
692 if (failed(*parseResult))
693 return failure();
694 // Function body was parsed, make sure its not empty.
695 if (body->empty())
696 return parser.emitError(loc, "expected non-empty function body");
697
698 // If a body was parsed, the arg and res names need to be resolved
699 return success();
700}
701
702void FuncOp::print(OpAsmPrinter &p) {
703 mlir::function_interface_impl::printFunctionOp(
704 p, *this, /*isVariadic=*/true, getFunctionTypeAttrName(),
705 getArgAttrsAttrName(), getResAttrsAttrName());
706}
707
708namespace {
709struct EliminateSimpleControlMergesPattern
710 : mlir::OpRewritePattern<ControlMergeOp> {
711 using mlir::OpRewritePattern<ControlMergeOp>::OpRewritePattern;
712
713 LogicalResult matchAndRewrite(ControlMergeOp op,
714 PatternRewriter &rewriter) const override;
715};
716} // namespace
717
718LogicalResult EliminateSimpleControlMergesPattern::matchAndRewrite(
719 ControlMergeOp op, PatternRewriter &rewriter) const {
720 auto dataResult = op.getResult();
721 auto choiceResult = op.getIndex();
722 auto choiceUnused = choiceResult.use_empty();
723 if (!choiceUnused && !choiceResult.hasOneUse())
724 return failure();
725
726 Operation *choiceUser = nullptr;
727 if (choiceResult.hasOneUse()) {
728 choiceUser = choiceResult.getUses().begin().getUser();
729 if (!isa<SinkOp>(choiceUser))
730 return failure();
731 }
732
733 auto merge = rewriter.create<MergeOp>(op.getLoc(), op.getDataOperands());
734
735 for (auto &use : llvm::make_early_inc_range(dataResult.getUses())) {
736 auto *user = use.getOwner();
737 rewriter.modifyOpInPlace(
738 user, [&]() { user->setOperand(use.getOperandNumber(), merge); });
739 }
740
741 if (choiceUnused) {
742 rewriter.eraseOp(op);
743 return success();
744 }
745
746 rewriter.eraseOp(choiceUser);
747 rewriter.eraseOp(op);
748 return success();
749}
750
751void ControlMergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
752 MLIRContext *context) {
753 results.insert<EliminateSimpleControlMergesPattern>(context);
754}
755
756bool BranchOp::sostIsControl() {
757 return isControlCheckTypeAndOperand(getDataType(), getOperand());
758}
759
760void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
761 MLIRContext *context) {
762 results.insert<circt::handshake::EliminateSimpleBranchesPattern>(context);
763}
764
765ParseResult BranchOp::parse(OpAsmParser &parser, OperationState &result) {
766 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
767 Type type;
768 ArrayRef<Type> operandTypes(type);
769 SmallVector<Type, 1> dataOperandsTypes;
770 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
771 int size;
772 if (parseSostOperation(parser, allOperands, result, size, type, false))
773 return failure();
774
775 dataOperandsTypes.assign(size, type);
776 result.addTypes({type});
777 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
778 result.operands))
779 return failure();
780 return success();
781}
782
783void BranchOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
784
785ParseResult ConditionalBranchOp::parse(OpAsmParser &parser,
786 OperationState &result) {
787 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
788 Type dataType;
789 SmallVector<Type> operandTypes;
790 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
791 if (parser.parseOperandList(allOperands) ||
792 parser.parseOptionalAttrDict(result.attributes) ||
793 parser.parseColonType(dataType))
794 return failure();
795
796 if (allOperands.size() != 2)
797 return parser.emitError(parser.getCurrentLocation(),
798 "Expected exactly 2 operands");
799
800 result.addTypes({dataType, dataType});
801 operandTypes.push_back(IntegerType::get(parser.getContext(), 1));
802 operandTypes.push_back(dataType);
803 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
804 result.operands))
805 return failure();
806
807 return success();
808}
809
810void ConditionalBranchOp::print(OpAsmPrinter &p) {
811 Type type = getDataOperand().getType();
812 p << " " << getOperands();
813 p.printOptionalAttrDict((*this)->getAttrs());
814 p << " : " << type;
815}
816
817std::string handshake::ConditionalBranchOp::getOperandName(unsigned int idx) {
818 assert(idx == 0 || idx == 1);
819 return idx == 0 ? "cond" : "data";
820}
821
822std::string handshake::ConditionalBranchOp::getResultName(unsigned int idx) {
823 assert(idx == 0 || idx == 1);
824 return idx == ConditionalBranchOp::falseIndex ? "outFalse" : "outTrue";
825}
826
827bool ConditionalBranchOp::isControl() {
828 return isControlCheckTypeAndOperand(getDataOperand().getType(),
829 getDataOperand());
830}
831
832ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
833 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
834 Type type;
835 ArrayRef<Type> operandTypes(type);
836 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
837 int size;
838 if (parseSostOperation(parser, allOperands, result, size, type, false))
839 return failure();
840
841 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
842 result.operands))
843 return failure();
844 return success();
845}
846
847void SinkOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
848
849std::string handshake::ConstantOp::getOperandName(unsigned int idx) {
850 assert(idx == 0);
851 return "ctrl";
852}
853
854Type SourceOp::getDataType() { return getResult().getType(); }
855unsigned SourceOp::getSize() { return 1; }
856
857ParseResult SourceOp::parse(OpAsmParser &parser, OperationState &result) {
858 if (parser.parseOptionalAttrDict(result.attributes))
859 return failure();
860 result.addTypes(NoneType::get(result.getContext()));
861 return success();
862}
863
864void SourceOp::print(OpAsmPrinter &p) {
865 p.printOptionalAttrDict((*this)->getAttrs());
866}
867
868LogicalResult ConstantOp::verify() {
869 // Verify that the type of the provided value is equal to the result type.
870 auto typedValue = dyn_cast<mlir::TypedAttr>(getValue());
871 if (!typedValue)
872 return emitOpError("constant value must be a typed attribute; value is ")
873 << getValue();
874 if (typedValue.getType() != getResult().getType())
875 return emitOpError() << "constant value type " << typedValue.getType()
876 << " differs from operation result type "
877 << getResult().getType();
878
879 return success();
880}
881
882void handshake::ConstantOp::getCanonicalizationPatterns(
883 RewritePatternSet &results, MLIRContext *context) {
884 results.insert<circt::handshake::EliminateSunkConstantsPattern>(context);
885}
886
887LogicalResult BufferOp::verify() {
888 // Verify that exactly 'size' number of initial values have been provided, if
889 // an initializer list have been provided.
890 if (auto initVals = getInitValues()) {
891 if (!isSequential())
892 return emitOpError()
893 << "only bufferType buffers are allowed to have initial values.";
894
895 auto nInits = initVals->size();
896 if (nInits != getSize())
897 return emitOpError() << "expected " << getSize()
898 << " init values but got " << nInits << ".";
899 }
900
901 return success();
902}
903
904void handshake::BufferOp::getCanonicalizationPatterns(
905 RewritePatternSet &results, MLIRContext *context) {
906 results.insert<circt::handshake::EliminateSunkBuffersPattern>(context);
907}
908
909unsigned BufferOp::getSize() {
910 return (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
911}
912
913ParseResult BufferOp::parse(OpAsmParser &parser, OperationState &result) {
914 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
915 Type type;
916 ArrayRef<Type> operandTypes(type);
917 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
918 int slots;
919 if (parseIntInSquareBrackets(parser, slots))
920 return failure();
921
922 auto bufferTypeAttr = BufferTypeEnumAttr::parse(parser, {});
923 if (!bufferTypeAttr)
924 return failure();
925
926 result.addAttribute(
927 "slots",
928 IntegerAttr::get(IntegerType::get(result.getContext(), 32), slots));
929 result.addAttribute("bufferType", bufferTypeAttr);
930
931 if (parser.parseOperandList(allOperands) ||
932 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
933 parser.parseType(type))
934 return failure();
935
936 result.addTypes({type});
937 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
938 result.operands))
939 return failure();
940 return success();
941}
942
943void BufferOp::print(OpAsmPrinter &p) {
944 int size =
945 (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
946 p << " [" << size << "]";
947 p << " " << stringifyEnum(getBufferType());
948 p << " " << (*this)->getOperands();
949 p.printOptionalAttrDict((*this)->getAttrs(), {"slots", "bufferType"});
950 p << " : " << (*this).getDataType();
951}
952
/// Returns the port name of operand `idx` of a memory with `nStores` store
/// ports. The first 2 * nStores operands alternate between "stData<k>" (even
/// index) and "stAddr<k>" (odd index) for store port k; every later operand is
/// "ldAddr<j>", with j counted from the end of the store operands.
static std::string getMemoryOperandName(unsigned nStores, unsigned idx) {
  const unsigned numStoreOperands = nStores * 2;
  if (idx >= numStoreOperands)
    return "ldAddr" + std::to_string(idx - numStoreOperands);

  const char *prefix = (idx % 2 == 0) ? "stData" : "stAddr";
  return prefix + std::to_string(idx / 2);
}
965
966std::string handshake::MemoryOp::getOperandName(unsigned int idx) {
967 return getMemoryOperandName(getStCount(), idx);
968}
969
/// Returns the port name of result `idx` of a memory with `nLoads` load ports
/// and `nStores` store ports. Results are ordered as "ldData<k>" for each load
/// port, then "stDone<k>" for each store port, then "ldDone<k>" for everything
/// after that.
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores,
                                       unsigned idx) {
  if (idx < nLoads)
    return "ldData" + std::to_string(idx);
  if (idx < nLoads + nStores)
    return "stDone" + std::to_string(idx - nLoads);
  return "ldDone" + std::to_string(idx - nLoads - nStores);
}
981
982std::string handshake::MemoryOp::getResultName(unsigned int idx) {
983 return getMemoryResultName(getLdCount(), getStCount(), idx);
984}
985
986LogicalResult MemoryOp::verify() {
987 auto memrefType = getMemRefType();
988
989 if (memrefType.getNumDynamicDims() != 0)
990 return emitOpError()
991 << "memref dimensions for handshake.memory must be static.";
992 if (memrefType.getShape().size() != 1)
993 return emitOpError() << "memref must have only a single dimension.";
994
995 unsigned opStCount = getStCount();
996 unsigned opLdCount = getLdCount();
997 int addressCount = memrefType.getShape().size();
998
999 auto inputType = getInputs().getType();
1000 auto outputType = getOutputs().getType();
1001 Type dataType = memrefType.getElementType();
1002
1003 unsigned numOperands = static_cast<int>(getInputs().size());
1004 unsigned numResults = static_cast<int>(getOutputs().size());
1005 if (numOperands != (1 + addressCount) * opStCount + addressCount * opLdCount)
1006 return emitOpError("number of operands ")
1007 << numOperands << " does not match number expected of "
1008 << 2 * opStCount + opLdCount << " with " << addressCount
1009 << " address inputs per port";
1010
1011 if (numResults != opStCount + 2 * opLdCount)
1012 return emitOpError("number of results ")
1013 << numResults << " does not match number expected of "
1014 << opStCount + 2 * opLdCount << " with " << addressCount
1015 << " address inputs per port";
1016
1017 Type addressType = opStCount > 0 ? inputType[1] : inputType[0];
1018
1019 for (unsigned i = 0; i < opStCount; i++) {
1020 if (inputType[2 * i] != dataType)
1021 return emitOpError("data type for store port ")
1022 << i << ":" << inputType[2 * i] << " doesn't match memory type "
1023 << dataType;
1024 if (inputType[2 * i + 1] != addressType)
1025 return emitOpError("address type for store port ")
1026 << i << ":" << inputType[2 * i + 1]
1027 << " doesn't match address type " << addressType;
1028 }
1029 for (unsigned i = 0; i < opLdCount; i++) {
1030 Type ldAddressType = inputType[2 * opStCount + i];
1031 if (ldAddressType != addressType)
1032 return emitOpError("address type for load port ")
1033 << i << ":" << ldAddressType << " doesn't match address type "
1034 << addressType;
1035 }
1036 for (unsigned i = 0; i < opLdCount; i++) {
1037 if (outputType[i] != dataType)
1038 return emitOpError("data type for load port ")
1039 << i << ":" << outputType[i] << " doesn't match memory type "
1040 << dataType;
1041 }
1042 for (unsigned i = 0; i < opStCount; i++) {
1043 Type syncType = outputType[opLdCount + i];
1044 if (!isa<NoneType>(syncType))
1045 return emitOpError("data type for sync port for store port ")
1046 << i << ":" << syncType << " is not 'none'";
1047 }
1048 for (unsigned i = 0; i < opLdCount; i++) {
1049 Type syncType = outputType[opLdCount + opStCount + i];
1050 if (!isa<NoneType>(syncType))
1051 return emitOpError("data type for sync port for load port ")
1052 << i << ":" << syncType << " is not 'none'";
1053 }
1054
1055 return success();
1056}
1057
1058std::string handshake::ExternalMemoryOp::getOperandName(unsigned int idx) {
1059 if (idx == 0)
1060 return "extmem";
1061
1062 return getMemoryOperandName(getStCount(), idx - 1);
1063}
1064
1065std::string handshake::ExternalMemoryOp::getResultName(unsigned int idx) {
1066 return getMemoryResultName(getLdCount(), getStCount(), idx);
1067}
1068
1069void ExternalMemoryOp::build(OpBuilder &builder, OperationState &result,
1070 Value memref, ValueRange inputs, int ldCount,
1071 int stCount, int id) {
1072 SmallVector<Value> ops;
1073 ops.push_back(memref);
1074 llvm::append_range(ops, inputs);
1075 result.addOperands(ops);
1076
1077 auto memrefType = cast<MemRefType>(memref.getType());
1078
1079 // Data outputs (get their type from memref)
1080 result.types.append(ldCount, memrefType.getElementType());
1081
1082 // Control outputs
1083 result.types.append(stCount + ldCount, builder.getNoneType());
1084
1085 // Memory ID (individual ID for each MemoryOp)
1086 Type i32Type = builder.getIntegerType(32);
1087 result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1088 result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, ldCount));
1089 result.addAttribute("stCount", builder.getIntegerAttr(i32Type, stCount));
1090}
1091
1092llvm::SmallVector<handshake::MemLoadInterface>
1093ExternalMemoryOp::getLoadPorts() {
1094 return ::getLoadPorts(*this);
1095}
1096
1097llvm::SmallVector<handshake::MemStoreInterface>
1098ExternalMemoryOp::getStorePorts() {
1099 return ::getStorePorts(*this);
1100}
1101
1102void MemoryOp::build(OpBuilder &builder, OperationState &result,
1103 ValueRange operands, int outputs, int controlOutputs,
1104 bool lsq, int id, Value memref) {
1105 result.addOperands(operands);
1106
1107 auto memrefType = cast<MemRefType>(memref.getType());
1108
1109 // Data outputs (get their type from memref)
1110 result.types.append(outputs, memrefType.getElementType());
1111
1112 // Control outputs
1113 result.types.append(controlOutputs, builder.getNoneType());
1114 result.addAttribute("lsq", builder.getBoolAttr(lsq));
1115 result.addAttribute("memRefType", TypeAttr::get(memrefType));
1116
1117 // Memory ID (individual ID for each MemoryOp)
1118 Type i32Type = builder.getIntegerType(32);
1119 result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1120
1121 if (!lsq) {
1122 result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, outputs));
1123 result.addAttribute(
1124 "stCount", builder.getIntegerAttr(i32Type, controlOutputs - outputs));
1125 }
1126}
1127
1128llvm::SmallVector<handshake::MemLoadInterface> MemoryOp::getLoadPorts() {
1129 return ::getLoadPorts(*this);
1130}
1131
1132llvm::SmallVector<handshake::MemStoreInterface> MemoryOp::getStorePorts() {
1133 return ::getStorePorts(*this);
1134}
1135
1136bool handshake::MemoryOp::allocateMemory(
1137 llvm::DenseMap<unsigned, unsigned> &memoryMap,
1138 std::vector<std::vector<llvm::Any>> &store,
1139 std::vector<double> &storeTimes) {
1140 if (memoryMap.count(getId()))
1141 return false;
1142
1143 auto type = getMemRefType();
1144 std::vector<llvm::Any> in;
1145
1146 ArrayRef<int64_t> shape = type.getShape();
1147 int allocationSize = 1;
1148 unsigned count = 0;
1149 for (int64_t dim : shape) {
1150 if (dim > 0)
1151 allocationSize *= dim;
1152 else {
1153 assert(count < in.size());
1154 allocationSize *= llvm::any_cast<APInt>(in[count++]).getSExtValue();
1155 }
1156 }
1157 unsigned ptr = store.size();
1158 store.resize(ptr + 1);
1159 storeTimes.resize(ptr + 1);
1160 store[ptr].resize(allocationSize);
1161 storeTimes[ptr] = 0.0;
1162 mlir::Type elementType = type.getElementType();
1163 int width = elementType.getIntOrFloatBitWidth();
1164 for (int i = 0; i < allocationSize; i++) {
1165 if (isa<mlir::IntegerType>(elementType)) {
1166 store[ptr][i] = APInt(width, 0);
1167 } else if (isa<mlir::FloatType>(elementType)) {
1168 store[ptr][i] = APFloat(0.0);
1169 } else {
1170 llvm_unreachable("Unknown result type!\n");
1171 }
1172 }
1173
1174 memoryMap[getId()] = ptr;
1175 return true;
1176}
1177
1178std::string handshake::LoadOp::getOperandName(unsigned int idx) {
1179 unsigned nAddresses = getAddresses().size();
1180 std::string opName;
1181 if (idx < nAddresses)
1182 opName = "addrIn" + std::to_string(idx);
1183 else if (idx == nAddresses)
1184 opName = "dataFromMem";
1185 else
1186 opName = "ctrl";
1187 return opName;
1188}
1189
1190std::string handshake::LoadOp::getResultName(unsigned int idx) {
1191 std::string resName;
1192 if (idx == 0)
1193 resName = "dataOut";
1194 else
1195 resName = "addrOut" + std::to_string(idx - 1);
1196 return resName;
1197}
1198
1199void handshake::LoadOp::build(OpBuilder &builder, OperationState &result,
1200 Value memref, ValueRange indices) {
1201 // Address indices
1202 // result.addOperands(memref);
1203 result.addOperands(indices);
1204
1205 // Data type
1206 auto memrefType = cast<MemRefType>(memref.getType());
1207
1208 // Data output (from load to successor ops)
1209 result.types.push_back(memrefType.getElementType());
1210
1211 // Address outputs (to lsq)
1212 result.types.append(indices.size(), builder.getIndexType());
1213}
1214
1215static ParseResult parseMemoryAccessOp(OpAsmParser &parser,
1216 OperationState &result) {
1217 SmallVector<OpAsmParser::UnresolvedOperand, 4> addressOperands,
1218 remainingOperands, allOperands;
1219 SmallVector<Type, 1> parsedTypes, allTypes;
1220 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1221
1222 if (parser.parseLSquare() || parser.parseOperandList(addressOperands) ||
1223 parser.parseRSquare() || parser.parseOperandList(remainingOperands) ||
1224 parser.parseColon() || parser.parseTypeList(parsedTypes))
1225 return failure();
1226
1227 // The last type will be the data type of the operation; the prior will be the
1228 // address types.
1229 Type dataType = parsedTypes.back();
1230 auto parsedTypesRef = ArrayRef(parsedTypes);
1231 result.addTypes(dataType);
1232 result.addTypes(parsedTypesRef.drop_back());
1233 allOperands.append(addressOperands);
1234 allOperands.append(remainingOperands);
1235 allTypes.append(parsedTypes);
1236 allTypes.push_back(NoneType::get(result.getContext()));
1237 if (parser.resolveOperands(allOperands, allTypes, allOperandLoc,
1238 result.operands))
1239 return failure();
1240 return success();
1241}
1242
1243template <typename MemOp>
1244static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op) {
1245 p << " [";
1246 p << op.getAddresses();
1247 p << "] " << op.getData() << ", " << op.getCtrl() << " : ";
1248 llvm::interleaveComma(op.getAddresses(), p,
1249 [&](Value v) { p << v.getType(); });
1250 p << ", " << op.getData().getType();
1251}
1252
1253ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
1254 return parseMemoryAccessOp(parser, result);
1255}
1256
1257void LoadOp::print(OpAsmPrinter &p) { printMemoryAccessOp(p, *this); }
1258
1259std::string handshake::StoreOp::getOperandName(unsigned int idx) {
1260 unsigned nAddresses = getAddresses().size();
1261 std::string opName;
1262 if (idx < nAddresses)
1263 opName = "addrIn" + std::to_string(idx);
1264 else if (idx == nAddresses)
1265 opName = "dataIn";
1266 else
1267 opName = "ctrl";
1268 return opName;
1269}
1270
1271template <typename TMemoryOp>
1272static LogicalResult verifyMemoryAccessOp(TMemoryOp op) {
1273 if (op.getAddresses().size() == 0)
1274 return op.emitOpError() << "No addresses were specified";
1275
1276 return success();
1277}
1278
1279LogicalResult LoadOp::verify() { return verifyMemoryAccessOp(*this); }
1280
1281std::string handshake::StoreOp::getResultName(unsigned int idx) {
1282 std::string resName;
1283 if (idx == 0)
1284 resName = "dataToMem";
1285 else
1286 resName = "addrOut" + std::to_string(idx - 1);
1287 return resName;
1288}
1289
1290void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
1291 Value valueToStore, ValueRange indices) {
1292
1293 // Address indices
1294 result.addOperands(indices);
1295
1296 // Data
1297 result.addOperands(valueToStore);
1298
1299 // Data output (from store to LSQ)
1300 result.types.push_back(valueToStore.getType());
1301
1302 // Address outputs (from store to lsq)
1303 result.types.append(indices.size(), builder.getIndexType());
1304}
1305
1306LogicalResult StoreOp::verify() { return verifyMemoryAccessOp(*this); }
1307
1308ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
1309 return parseMemoryAccessOp(parser, result);
1310}
1311
1312void StoreOp::print(OpAsmPrinter &p) { return printMemoryAccessOp(p, *this); }
1313
1314bool JoinOp::isControl() { return true; }
1315
1316ParseResult JoinOp::parse(OpAsmParser &parser, OperationState &result) {
1317 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1318 SmallVector<Type> types;
1319
1320 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1321 if (parser.parseOperandList(operands) ||
1322 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1323 parser.parseTypeList(types))
1324 return failure();
1325
1326 if (parser.resolveOperands(operands, types, allOperandLoc, result.operands))
1327 return failure();
1328
1329 result.addTypes(NoneType::get(result.getContext()));
1330 return success();
1331}
1332
1333void JoinOp::print(OpAsmPrinter &p) {
1334 p << " " << getData();
1335 p.printOptionalAttrDict((*this)->getAttrs(), {"control"});
1336 p << " : " << getData().getTypes();
1337}
1338
1339LogicalResult
1340ESIInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1341 // Check that the module attribute was specified.
1342 auto fnAttr = this->getModuleAttr();
1343 assert(fnAttr && "requires a 'module' symbol reference attribute");
1344
1345 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
1346 if (!fn)
1347 return emitOpError() << "'" << fnAttr.getValue()
1348 << "' does not reference a valid handshake function";
1349
1350 // Verify that the operand and result types match the callee.
1351 auto fnType = fn.getFunctionType();
1352 if (fnType.getNumInputs() != getNumOperands() - NumFixedOperands)
1353 return emitOpError(
1354 "incorrect number of operands for the referenced handshake function");
1355
1356 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1357 Type operandType = getOperand(i + NumFixedOperands).getType();
1358 auto channelType = dyn_cast<esi::ChannelType>(operandType);
1359 if (!channelType)
1360 return emitOpError("operand type mismatch: expected channel type, but "
1361 "provided ")
1362 << operandType << " for operand number " << i;
1363 if (channelType.getInner() != fnType.getInput(i))
1364 return emitOpError("operand type mismatch: expected operand type ")
1365 << fnType.getInput(i) << ", but provided "
1366 << getOperand(i).getType() << " for operand number " << i;
1367 }
1368
1369 if (fnType.getNumResults() != getNumResults())
1370 return emitOpError(
1371 "incorrect number of results for the referenced handshake function");
1372
1373 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1374 Type resultType = getResult(i).getType();
1375 auto channelType = dyn_cast<esi::ChannelType>(resultType);
1376 if (!channelType)
1377 return emitOpError("result type mismatch: expected channel type, but "
1378 "provided ")
1379 << resultType << " for result number " << i;
1380 if (channelType.getInner() != fnType.getResult(i))
1381 return emitOpError("result type mismatch: expected result type ")
1382 << fnType.getResult(i) << ", but provided "
1383 << getResult(i).getType() << " for result number " << i;
1384 }
1385
1386 return success();
1387}
1388
1389/// Based on mlir::func::CallOp::verifySymbolUses
1390LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1391 // Check that the module attribute was specified.
1392 auto fnAttr = this->getModuleAttr();
1393 assert(fnAttr && "requires a 'module' symbol reference attribute");
1394
1395 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
1396 if (!fn)
1397 return emitOpError() << "'" << fnAttr.getValue()
1398 << "' does not reference a valid handshake function";
1399
1400 // Verify that the operand and result types match the callee.
1401 auto fnType = fn.getFunctionType();
1402 if (fnType.getNumInputs() != getNumOperands())
1403 return emitOpError(
1404 "incorrect number of operands for the referenced handshake function");
1405
1406 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
1407 if (getOperand(i).getType() != fnType.getInput(i))
1408 return emitOpError("operand type mismatch: expected operand type ")
1409 << fnType.getInput(i) << ", but provided "
1410 << getOperand(i).getType() << " for operand number " << i;
1411
1412 if (fnType.getNumResults() != getNumResults())
1413 return emitOpError(
1414 "incorrect number of results for the referenced handshake function");
1415
1416 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
1417 if (getResult(i).getType() != fnType.getResult(i))
1418 return emitOpError("result type mismatch: expected result type ")
1419 << fnType.getResult(i) << ", but provided "
1420 << getResult(i).getType() << " for result number " << i;
1421
1422 return success();
1423}
1424
1425FunctionType InstanceOp::getModuleType() {
1426 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
1427}
1428
1429ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1430 OpAsmParser::UnresolvedOperand tuple;
1431 TupleType type;
1432
1433 if (parser.parseOperand(tuple) ||
1434 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1435 parser.parseType(type))
1436 return failure();
1437
1438 if (parser.resolveOperand(tuple, type, result.operands))
1439 return failure();
1440
1441 result.addTypes(type.getTypes());
1442
1443 return success();
1444}
1445
1446void UnpackOp::print(OpAsmPrinter &p) {
1447 p << " " << getInput();
1448 p.printOptionalAttrDict((*this)->getAttrs());
1449 p << " : " << getInput().getType();
1450}
1451
1452ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1453 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1454 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1455 TupleType type;
1456
1457 if (parser.parseOperandList(operands) ||
1458 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1459 parser.parseType(type))
1460 return failure();
1461
1462 if (parser.resolveOperands(operands, type.getTypes(), allOperandLoc,
1463 result.operands))
1464 return failure();
1465
1466 result.addTypes(type);
1467
1468 return success();
1469}
1470
1471void PackOp::print(OpAsmPrinter &p) {
1472 p << " " << getInputs();
1473 p.printOptionalAttrDict((*this)->getAttrs());
1474 p << " : " << getResult().getType();
1475}
1476
1477//===----------------------------------------------------------------------===//
1478// TableGen'd op method definitions
1479//===----------------------------------------------------------------------===//
1480
1481LogicalResult ReturnOp::verify() {
1482 auto *parent = (*this)->getParentOp();
1483 auto function = dyn_cast<handshake::FuncOp>(parent);
1484 if (!function)
1485 return emitOpError("must have a handshake.func parent");
1486
1487 // The operand number and types must match the function signature.
1488 const auto &results = function.getResultTypes();
1489 if (getNumOperands() != results.size())
1490 return emitOpError("has ")
1491 << getNumOperands() << " operands, but enclosing function returns "
1492 << results.size();
1493
1494 for (unsigned i = 0, e = results.size(); i != e; ++i)
1495 if (getOperand(i).getType() != results[i])
1496 return emitError() << "type of return operand " << i << " ("
1497 << getOperand(i).getType()
1498 << ") doesn't match function result type ("
1499 << results[i] << ")";
1500
1501 return success();
1502}
1503
1504#define GET_OP_CLASSES
1505#include "circt/Dialect/Handshake/Handshake.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal, uint64_t numOperands)
Verifies whether an indexing value is wide enough to index into a provided number of operands.
static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result)
static ParseResult parseSostOperation(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, OperationState &result, int &size, Type &type, bool explicitSize)
static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op)
static ParseResult parseFuncOpArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Type > &resTypes, SmallVectorImpl< DictionaryAttr > &resAttrs)
Parses a FuncOp signature using mlir::function_interface_impl::parseFunctionSignature while getting a...
static Value getDematerialized(Value v)
Returns a dematerialized version of the value 'v', defined as the source of the value before passing ...
static bool isControlCheckTypeAndOperand(Type dataType, Value operand)
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores, unsigned idx)
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v)
static LogicalResult verifyMemoryAccessOp(TMemoryOp op)
static SmallVector< Attribute > getFuncOpNames(Builder &builder, unsigned cnt, StringRef prefix)
Generates names for a handshake.func input and output arguments, based on the number of args as well ...
static void addStringToStringArrayAttr(Builder &builder, Operation *op, StringRef attrName, StringAttr str)
Helper function for appending a string to an array attribute, and rewriting the attribute back to the...
static std::string defaultOperandName(unsigned int idx)
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
static std::string getMemoryOperandName(unsigned nStores, unsigned idx)
static ParseResult parseMemoryAccessOp(OpAsmParser &parser, OperationState &result)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
bool isControlOpImpl(Operation *op)
Default implementation for checking whether an operation is a control operation.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.