CIRCT 22.0.0git
Loading...
Searching...
No Matches
HandshakeOps.cpp
Go to the documentation of this file.
1//===- HandshakeOps.cpp - Handshake MLIR Operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file contains the implementation of the Handshake operations.
10//
11//===----------------------------------------------------------------------===//
12
15#include "circt/Support/LLVM.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/IntegerSet.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/OpDefinition.h"
24#include "mlir/IR/OpImplementation.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/IR/SymbolTable.h"
27#include "mlir/IR/Value.h"
28#include "mlir/Interfaces/FunctionImplementation.h"
29#include "mlir/Transforms/InliningUtils.h"
30#include "llvm/ADT/SetVector.h"
31#include "llvm/ADT/SmallBitVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34#include <set>
35
36using namespace circt;
37using namespace circt::handshake;
38
39namespace circt {
40namespace handshake {
41#include "circt/Dialect/Handshake/HandshakeCanonicalization.h.inc"
42
43bool isControlOpImpl(Operation *op) {
44 if (auto sostInterface = dyn_cast<SOSTInterface>(op); sostInterface)
45 return sostInterface.sostIsControl();
46
47 return false;
48}
49
50} // namespace handshake
51} // namespace circt
52
/// Builds the default name of the data operand at position `idx`, e.g. "in0".
static std::string defaultOperandName(unsigned int idx) {
  std::string name = "in";
  name += std::to_string(idx);
  return name;
}
56
57static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v) {
58 if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
59 return failure();
60 return success();
61}
62
63static ParseResult
64parseSostOperation(OpAsmParser &parser,
65 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
66 OperationState &result, int &size, Type &type,
67 bool explicitSize) {
68 if (explicitSize)
69 if (parseIntInSquareBrackets(parser, size))
70 return failure();
71
72 if (parser.parseOperandList(operands) ||
73 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
74 parser.parseType(type))
75 return failure();
76
77 if (!explicitSize)
78 size = operands.size();
79 return success();
80}
81
82/// Verifies whether an indexing value is wide enough to index into a provided
83/// number of operands.
84static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal,
85 uint64_t numOperands) {
86 auto indexType = indexVal.getType();
87 unsigned indexWidth;
88
89 // Determine the bitwidth of the indexing value
90 if (auto integerType = dyn_cast<IntegerType>(indexType))
91 indexWidth = integerType.getWidth();
92 else if (indexType.isIndex())
93 indexWidth = IndexType::kInternalStorageBitWidth;
94 else
95 return op->emitError("unsupported type for indexing value: ") << indexType;
96
97 // Check whether the bitwidth can support the provided number of operands
98 if (indexWidth < 64) {
99 uint64_t maxNumOperands = (uint64_t)1 << indexWidth;
100 if (numOperands > maxNumOperands)
101 return op->emitError("bitwidth of indexing value is ")
102 << indexWidth << ", which can index into " << maxNumOperands
103 << " operands, but found " << numOperands << " operands";
104 }
105 return success();
106}
107
108static bool isControlCheckTypeAndOperand(Type dataType, Value operand) {
109 // The operation is a control operation if its operand data type is a
110 // NoneType.
111 if (isa<NoneType>(dataType))
112 return true;
113
114 // Otherwise, the operation is a control operation if the operation's
115 // operand originates from the control network
116 auto *defOp = operand.getDefiningOp();
117 return isa_and_nonnull<ControlMergeOp>(defOp) &&
118 operand == defOp->getResult(0);
119}
120
121template <typename TMemOp>
122llvm::SmallVector<handshake::MemLoadInterface> getLoadPorts(TMemOp op) {
123 llvm::SmallVector<MemLoadInterface> ports;
124 // Memory interface refresher:
125 // Operands:
126 // all stores (stdata1, staddr1, stdata2, staddr2, ...)
127 // then all loads (ldaddr1, ldaddr2,...)
128 // Outputs: load addresses (lddata1, lddata2, ...), followed by all none
129 // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
130 unsigned stCount = op.getStCount();
131 unsigned ldCount = op.getLdCount();
132 for (unsigned i = 0, e = ldCount; i != e; ++i) {
133 MemLoadInterface ldif;
134 ldif.index = i;
135 ldif.addressIn = op.getInputs()[stCount * 2 + i];
136 ldif.dataOut = op.getResult(i);
137 ldif.doneOut = op.getResult(ldCount + stCount + i);
138 ports.push_back(ldif);
139 }
140 return ports;
141}
142
143template <typename TMemOp>
144llvm::SmallVector<handshake::MemStoreInterface> getStorePorts(TMemOp op) {
145 llvm::SmallVector<MemStoreInterface> ports;
146 // Memory interface refresher:
147 // Operands:
148 // all stores (stdata1, staddr1, stdata2, staddr2, ...)
149 // then all loads (ldaddr1, ldaddr2,...)
150 // Outputs: load data (lddata1, lddata2, ...), followed by all none
151 // outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
152 unsigned ldCount = op.getLdCount();
153 for (unsigned i = 0, e = op.getStCount(); i != e; ++i) {
155 stif.index = i;
156 stif.dataIn = op.getInputs()[i * 2];
157 stif.addressIn = op.getInputs()[i * 2 + 1];
158 stif.doneOut = op.getResult(ldCount + i);
159 ports.push_back(stif);
160 }
161 return ports;
162}
163
164unsigned ForkOp::getSize() { return getResults().size(); }
165
166static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result) {
167 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
168 Type type;
169 ArrayRef<Type> operandTypes(type);
170 SmallVector<Type, 1> resultTypes;
171 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
172 int size;
173 if (parseSostOperation(parser, allOperands, result, size, type, true))
174 return failure();
175
176 resultTypes.assign(size, type);
177 result.addTypes(resultTypes);
178 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
179 result.operands))
180 return failure();
181 return success();
182}
183
184ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
185 return parseForkOp(parser, result);
186}
187
188void ForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
189
190namespace {
191
192struct EliminateUnusedForkResultsPattern : mlir::OpRewritePattern<ForkOp> {
194
195 LogicalResult matchAndRewrite(ForkOp op,
196 PatternRewriter &rewriter) const override {
197 std::set<unsigned> unusedIndexes;
198
199 for (auto res : llvm::enumerate(op.getResults()))
200 if (res.value().getUses().empty())
201 unusedIndexes.insert(res.index());
202
203 if (unusedIndexes.empty())
204 return failure();
205
206 // Create a new fork op, dropping the unused results.
207 rewriter.setInsertionPoint(op);
208 auto operand = op.getOperand();
209 auto newFork = ForkOp::create(rewriter, op.getLoc(), operand,
210 op.getNumResults() - unusedIndexes.size());
211 unsigned i = 0;
212 for (auto oldRes : llvm::enumerate(op.getResults()))
213 if (unusedIndexes.count(oldRes.index()) == 0)
214 rewriter.replaceAllUsesWith(oldRes.value(), newFork.getResults()[i++]);
215 rewriter.eraseOp(op);
216 return success();
217 }
218};
219
220struct EliminateForkToForkPattern : mlir::OpRewritePattern<ForkOp> {
222
223 LogicalResult matchAndRewrite(ForkOp op,
224 PatternRewriter &rewriter) const override {
225 auto parentForkOp = op.getOperand().getDefiningOp<ForkOp>();
226 if (!parentForkOp)
227 return failure();
228
229 /// Create the fork with as many outputs as the two source forks.
230 /// Keeping the op.operand() output may or may not be redundant (dependning
231 /// on if op is the single user of the value), but we'll let
232 /// EliminateUnusedForkResultsPattern apply in that case.
233 unsigned totalNumOuts = op.getSize() + parentForkOp.getSize();
234 /// Create a new parent fork op which produces all of the fork outputs and
235 /// replace all of the uses of the old results.
236 auto newParentForkOp =
237 ForkOp::create(rewriter, parentForkOp.getLoc(),
238 parentForkOp.getOperand(), totalNumOuts);
239
240 for (auto it :
241 llvm::zip(parentForkOp->getResults(), newParentForkOp.getResults()))
242 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
243
244 /// Replace the results of the matches fork op with the corresponding
245 /// results of the new parent fork op.
246 rewriter.replaceOp(op,
247 newParentForkOp.getResults().take_back(op.getSize()));
248 rewriter.eraseOp(parentForkOp);
249 return success();
250 }
251};
252
253} // namespace
254
255void handshake::ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
256 MLIRContext *context) {
257 results.insert<circt::handshake::EliminateSimpleForksPattern,
258 EliminateForkToForkPattern, EliminateUnusedForkResultsPattern>(
259 context);
260}
261
262unsigned LazyForkOp::getSize() { return getResults().size(); }
263
264bool LazyForkOp::sostIsControl() {
265 return isControlCheckTypeAndOperand(getDataType(), getOperand());
266}
267
268ParseResult LazyForkOp::parse(OpAsmParser &parser, OperationState &result) {
269 return parseForkOp(parser, result);
270}
271
272void LazyForkOp::print(OpAsmPrinter &p) { sostPrint(p, true); }
273
274ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
275 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
276 Type type;
277 ArrayRef<Type> operandTypes(type);
278 SmallVector<Type, 1> resultTypes, dataOperandsTypes;
279 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
280 int size;
281 if (parseSostOperation(parser, allOperands, result, size, type, false))
282 return failure();
283
284 dataOperandsTypes.assign(size, type);
285 resultTypes.push_back(type);
286 result.addTypes(resultTypes);
287 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
288 result.operands))
289 return failure();
290 return success();
291}
292
293void MergeOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
294
295void MergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
296 MLIRContext *context) {
297 results.insert<circt::handshake::EliminateSimpleMergesPattern>(context);
298}
299
300/// Returns a dematerialized version of the value 'v', defined as the source of
301/// the value before passing through a buffer or fork operation.
302static Value getDematerialized(Value v) {
303 Operation *parentOp = v.getDefiningOp();
304 if (!parentOp)
305 return v;
306
307 return llvm::TypeSwitch<Operation *, Value>(parentOp)
308 .Case<ForkOp>(
309 [&](ForkOp op) { return getDematerialized(op.getOperand()); })
310 .Case<BufferOp>(
311 [&](BufferOp op) { return getDematerialized(op.getOperand()); })
312 .Default([&](auto) { return v; });
313}
314
315namespace {
316
317/// Eliminates muxes with identical data inputs. Data inputs are inspected as
318/// their dematerialized versions. This has the side effect of any subsequently
319/// unused buffers are DCE'd and forks are optimized to be narrower.
320struct EliminateSimpleMuxesPattern : mlir::OpRewritePattern<MuxOp> {
322 LogicalResult matchAndRewrite(MuxOp op,
323 PatternRewriter &rewriter) const override {
324 Value firstDataOperand = getDematerialized(op.getDataOperands()[0]);
325 if (!llvm::all_of(op.getDataOperands(), [&](Value operand) {
326 return getDematerialized(operand) == firstDataOperand;
327 }))
328 return failure();
329 rewriter.replaceOp(op, firstDataOperand);
330 return success();
331 }
332};
333
334struct EliminateUnaryMuxesPattern : OpRewritePattern<MuxOp> {
336 LogicalResult matchAndRewrite(MuxOp op,
337 PatternRewriter &rewriter) const override {
338 if (op.getDataOperands().size() != 1)
339 return failure();
340
341 rewriter.replaceOp(op, op.getDataOperands()[0]);
342 return success();
343 }
344};
345
346struct EliminateCBranchIntoMuxPattern : OpRewritePattern<MuxOp> {
348 LogicalResult matchAndRewrite(MuxOp op,
349 PatternRewriter &rewriter) const override {
350
351 auto dataOperands = op.getDataOperands();
352 if (dataOperands.size() != 2)
353 return failure();
354
355 // Both data operands must originate from the same cbranch
356 ConditionalBranchOp firstParentCBranch =
357 dataOperands[0].getDefiningOp<ConditionalBranchOp>();
358 if (!firstParentCBranch)
359 return failure();
360 auto secondParentCBranch =
361 dataOperands[1].getDefiningOp<ConditionalBranchOp>();
362 if (!secondParentCBranch || firstParentCBranch != secondParentCBranch)
363 return failure();
364
365 rewriter.modifyOpInPlace(firstParentCBranch, [&] {
366 // Replace uses of the mux's output with cbranch's data input
367 rewriter.replaceOp(op, firstParentCBranch.getDataOperand());
368 });
369
370 return success();
371 }
372};
373
374} // namespace
375
376void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
377 MLIRContext *context) {
378 results.insert<EliminateSimpleMuxesPattern, EliminateUnaryMuxesPattern,
379 EliminateCBranchIntoMuxPattern>(context);
380}
381
382LogicalResult
383MuxOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
384 ValueRange operands, DictionaryAttr attributes,
385 mlir::OpaqueProperties properties,
386 mlir::RegionRange regions,
387 SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
388 // MuxOp must have at least one data operand (in addition to the select
389 // operand)
390 if (operands.size() < 2)
391 return failure();
392 // Result type is type of any data operand
393 inferredReturnTypes.push_back(operands[1].getType());
394 return success();
395}
396
397bool MuxOp::isControl() { return isa<NoneType>(getResult().getType()); }
398
399std::string handshake::MuxOp::getOperandName(unsigned int idx) {
400 return idx == 0 ? "select" : defaultOperandName(idx - 1);
401}
402
403ParseResult MuxOp::parse(OpAsmParser &parser, OperationState &result) {
404 OpAsmParser::UnresolvedOperand selectOperand;
405 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
406 Type selectType, dataType;
407 SmallVector<Type, 1> dataOperandsTypes;
408 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
409 if (parser.parseOperand(selectOperand) || parser.parseLSquare() ||
410 parser.parseOperandList(allOperands) || parser.parseRSquare() ||
411 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
412 parser.parseType(selectType) || parser.parseComma() ||
413 parser.parseType(dataType))
414 return failure();
415
416 int size = allOperands.size();
417 dataOperandsTypes.assign(size, dataType);
418 result.addTypes(dataType);
419 allOperands.insert(allOperands.begin(), selectOperand);
420 if (parser.resolveOperands(
421 allOperands,
422 llvm::concat<const Type>(ArrayRef<Type>(selectType),
423 ArrayRef<Type>(dataOperandsTypes)),
424 allOperandLoc, result.operands))
425 return failure();
426 return success();
427}
428
429void MuxOp::print(OpAsmPrinter &p) {
430 Type selectType = getSelectOperand().getType();
431 auto ops = getOperands();
432 p << ' ' << ops.front();
433 p << " [";
434 p.printOperands(ops.drop_front());
435 p << "]";
436 p.printOptionalAttrDict((*this)->getAttrs());
437 p << " : " << selectType << ", " << getResult().getType();
438}
439
440LogicalResult MuxOp::verify() {
441 return verifyIndexWideEnough(*this, getSelectOperand(),
442 getDataOperands().size());
443}
444
445std::string handshake::ControlMergeOp::getResultName(unsigned int idx) {
446 assert(idx == 0 || idx == 1);
447 return idx == 0 ? "dataOut" : "index";
448}
449
450ParseResult ControlMergeOp::parse(OpAsmParser &parser, OperationState &result) {
451 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
452 Type resultType, indexType;
453 SmallVector<Type> resultTypes, dataOperandsTypes;
454 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
455 int size;
456 if (parseSostOperation(parser, allOperands, result, size, resultType, false))
457 return failure();
458 // Parse type of index result
459 if (parser.parseComma() || parser.parseType(indexType))
460 return failure();
461
462 dataOperandsTypes.assign(size, resultType);
463 resultTypes.push_back(resultType);
464 resultTypes.push_back(indexType);
465 result.addTypes(resultTypes);
466 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
467 result.operands))
468 return failure();
469 return success();
470}
471
472void ControlMergeOp::print(OpAsmPrinter &p) {
473 sostPrint(p, false);
474 // Print type of index result
475 p << ", " << getIndex().getType();
476}
477
478LogicalResult ControlMergeOp::verify() {
479 auto operands = getOperands();
480 if (operands.empty())
481 return emitOpError("operation must have at least one operand");
482 if (operands[0].getType() != getResult().getType())
483 return emitOpError("type of first result should match type of operands");
484 return verifyIndexWideEnough(*this, getIndex(), getNumOperands());
485}
486
487LogicalResult FuncOp::verify() {
488 // If this function is external there is nothing to do.
489 if (isExternal())
490 return success();
491
492 // Verify that the argument list of the function and the arg list of the
493 // entry block line up. The trait already verified that the number of
494 // arguments is the same between the signature and the block.
495 auto fnInputTypes = getArgumentTypes();
496 Block &entryBlock = front();
497
498 for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
499 if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
500 return emitOpError("type of entry block argument #")
501 << i << '(' << entryBlock.getArgument(i).getType()
502 << ") must match the type of the corresponding argument in "
503 << "function signature(" << fnInputTypes[i] << ')';
504
505 // Verify that we have a name for each argument and result of this function.
506 auto verifyPortNameAttr = [&](StringRef attrName,
507 unsigned numIOs) -> LogicalResult {
508 auto portNamesAttr = (*this)->getAttrOfType<ArrayAttr>(attrName);
509
510 if (!portNamesAttr)
511 return emitOpError() << "expected attribute '" << attrName << "'.";
512
513 auto portNames = portNamesAttr.getValue();
514 if (portNames.size() != numIOs)
515 return emitOpError() << "attribute '" << attrName << "' has "
516 << portNames.size()
517 << " entries but is expected to have " << numIOs
518 << ".";
519
520 if (llvm::any_of(portNames,
521 [&](Attribute attr) { return !isa<StringAttr>(attr); }))
522 return emitOpError() << "expected all entries in attribute '" << attrName
523 << "' to be strings.";
524
525 return success();
526 };
527 if (failed(verifyPortNameAttr("argNames", getNumArguments())))
528 return failure();
529 if (failed(verifyPortNameAttr("resNames", getNumResults())))
530 return failure();
531
532 // Verify that all memrefs have a corresponding extmemory operation
533 for (auto arg : entryBlock.getArguments()) {
534 if (!isa<MemRefType>(arg.getType()))
535 continue;
536 if (arg.getUsers().empty() ||
537 !isa<ExternalMemoryOp>(*arg.getUsers().begin()))
538 return emitOpError("expected that block argument #")
539 << arg.getArgNumber() << " is used by an 'extmemory' operation";
540 }
541
542 return success();
543}
544
545/// Parses a FuncOp signature using
546/// mlir::function_interface_impl::parseFunctionSignature while getting access
547/// to the parsed SSA names to store as attributes.
548static ParseResult
549parseFuncOpArgs(OpAsmParser &parser,
550 SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
551 SmallVectorImpl<Type> &resTypes,
552 SmallVectorImpl<DictionaryAttr> &resAttrs) {
553 bool isVariadic;
554 if (mlir::function_interface_impl::parseFunctionSignatureWithArguments(
555 parser, /*allowVariadic=*/true, entryArgs, isVariadic, resTypes,
556 resAttrs)
557 .failed())
558 return failure();
559
560 return success();
561}
562
563/// Generates names for a handshake.func input and output arguments, based on
564/// the number of args as well as a prefix.
565static SmallVector<Attribute> getFuncOpNames(Builder &builder, unsigned cnt,
566 StringRef prefix) {
567 SmallVector<Attribute> resNames;
568 for (unsigned i = 0; i < cnt; ++i)
569 resNames.push_back(builder.getStringAttr(prefix + std::to_string(i)));
570 return resNames;
571}
572
573void handshake::FuncOp::build(OpBuilder &builder, OperationState &state,
574 StringRef name, FunctionType type,
575 ArrayRef<NamedAttribute> attrs) {
576 state.addAttribute(SymbolTable::getSymbolAttrName(),
577 builder.getStringAttr(name));
578 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
579 TypeAttr::get(type));
580 state.attributes.append(attrs.begin(), attrs.end());
581
582 if (const auto *argNamesAttrIt = llvm::find_if(
583 attrs, [&](auto attr) { return attr.getName() == "argNames"; });
584 argNamesAttrIt == attrs.end())
585 state.addAttribute("argNames", builder.getArrayAttr({}));
586
587 if (llvm::find_if(attrs, [&](auto attr) {
588 return attr.getName() == "resNames";
589 }) == attrs.end())
590 state.addAttribute("resNames", builder.getArrayAttr({}));
591
592 state.addRegion();
593}
594
595/// Helper function for appending a string to an array attribute, and
596/// rewriting the attribute back to the operation.
597static void addStringToStringArrayAttr(Builder &builder, Operation *op,
598 StringRef attrName, StringAttr str) {
599 llvm::SmallVector<Attribute> attrs;
600 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
601 std::back_inserter(attrs));
602 attrs.push_back(str);
603 op->setAttr(attrName, builder.getArrayAttr(attrs));
604}
605
606void handshake::FuncOp::resolveArgAndResNames() {
607 Builder builder(getContext());
608
609 /// Generate a set of fallback names. These are used in case names are
610 /// missing from the currently set arg- and res name attributes.
611 auto fallbackArgNames = getFuncOpNames(builder, getNumArguments(), "in");
612 auto fallbackResNames = getFuncOpNames(builder, getNumResults(), "out");
613 auto argNames = getArgNames().getValue();
614 auto resNames = getResNames().getValue();
615
616 /// Use fallback names where actual names are missing.
617 auto resolveNames = [&](auto &fallbackNames, auto &actualNames,
618 StringRef attrName) {
619 for (auto fallbackName : llvm::enumerate(fallbackNames)) {
620 if (actualNames.size() <= fallbackName.index())
621 addStringToStringArrayAttr(builder, this->getOperation(), attrName,
622 cast<StringAttr>(fallbackName.value()));
623 }
624 };
625 resolveNames(fallbackArgNames, argNames, "argNames");
626 resolveNames(fallbackResNames, resNames, "resNames");
627}
628
629ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
630 auto &builder = parser.getBuilder();
631 StringAttr nameAttr;
632 SmallVector<OpAsmParser::Argument> args;
633 SmallVector<Type> resTypes;
634 SmallVector<DictionaryAttr> resAttributes;
635 SmallVector<Attribute> argNames;
636
637 // Parse visibility.
638 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
639
640 // Parse signature
641 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
642 result.attributes) ||
643 parseFuncOpArgs(parser, args, resTypes, resAttributes))
644 return failure();
645 mlir::call_interface_impl::addArgAndResultAttrs(
646 builder, result, args, resAttributes,
647 handshake::FuncOp::getArgAttrsAttrName(result.name),
648 handshake::FuncOp::getResAttrsAttrName(result.name));
649
650 // Set function type
651 SmallVector<Type> argTypes;
652 for (auto arg : args)
653 argTypes.push_back(arg.type);
654
655 result.addAttribute(
656 handshake::FuncOp::getFunctionTypeAttrName(result.name),
657 TypeAttr::get(builder.getFunctionType(argTypes, resTypes)));
658
659 // Determine the names of the arguments. If no SSA values are present, use
660 // fallback names.
661 bool noSSANames =
662 llvm::any_of(args, [](auto arg) { return arg.ssaName.name.empty(); });
663 if (noSSANames) {
664 argNames = getFuncOpNames(builder, args.size(), "in");
665 } else {
666 llvm::transform(args, std::back_inserter(argNames), [&](auto arg) {
667 return builder.getStringAttr(arg.ssaName.name.drop_front());
668 });
669 }
670
671 // Parse attributes
672 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
673 return failure();
674
675 // If argNames and resNames wasn't provided manually, infer argNames attribute
676 // from the parsed SSA names and resNames from our naming convention.
677 if (!result.attributes.get("argNames"))
678 result.addAttribute("argNames", builder.getArrayAttr(argNames));
679 if (!result.attributes.get("resNames")) {
680 auto resNames = getFuncOpNames(builder, resTypes.size(), "out");
681 result.addAttribute("resNames", builder.getArrayAttr(resNames));
682 }
683
684 // Parse the optional function body. The printer will not print the body if
685 // its empty, so disallow parsing of empty body in the parser.
686 auto *body = result.addRegion();
687 llvm::SMLoc loc = parser.getCurrentLocation();
688 auto parseResult = parser.parseOptionalRegion(*body, args,
689 /*enableNameShadowing=*/false);
690 if (!parseResult.has_value())
691 return success();
692
693 if (failed(*parseResult))
694 return failure();
695 // Function body was parsed, make sure its not empty.
696 if (body->empty())
697 return parser.emitError(loc, "expected non-empty function body");
698
699 // If a body was parsed, the arg and res names need to be resolved
700 return success();
701}
702
703void FuncOp::print(OpAsmPrinter &p) {
704 mlir::function_interface_impl::printFunctionOp(
705 p, *this, /*isVariadic=*/true, getFunctionTypeAttrName(),
706 getArgAttrsAttrName(), getResAttrsAttrName());
707}
708
709namespace {
710struct EliminateSimpleControlMergesPattern
711 : mlir::OpRewritePattern<ControlMergeOp> {
712 using mlir::OpRewritePattern<ControlMergeOp>::OpRewritePattern;
713
714 LogicalResult matchAndRewrite(ControlMergeOp op,
715 PatternRewriter &rewriter) const override;
716};
717} // namespace
718
719LogicalResult EliminateSimpleControlMergesPattern::matchAndRewrite(
720 ControlMergeOp op, PatternRewriter &rewriter) const {
721 auto dataResult = op.getResult();
722 auto choiceResult = op.getIndex();
723 auto choiceUnused = choiceResult.use_empty();
724 if (!choiceUnused && !choiceResult.hasOneUse())
725 return failure();
726
727 Operation *choiceUser = nullptr;
728 if (choiceResult.hasOneUse()) {
729 choiceUser = choiceResult.getUses().begin().getUser();
730 if (!isa<SinkOp>(choiceUser))
731 return failure();
732 }
733
734 auto merge = MergeOp::create(rewriter, op.getLoc(), op.getDataOperands());
735
736 for (auto &use : llvm::make_early_inc_range(dataResult.getUses())) {
737 auto *user = use.getOwner();
738 rewriter.modifyOpInPlace(
739 user, [&]() { user->setOperand(use.getOperandNumber(), merge); });
740 }
741
742 if (choiceUnused) {
743 rewriter.eraseOp(op);
744 return success();
745 }
746
747 rewriter.eraseOp(choiceUser);
748 rewriter.eraseOp(op);
749 return success();
750}
751
752void ControlMergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
753 MLIRContext *context) {
754 results.insert<EliminateSimpleControlMergesPattern>(context);
755}
756
757bool BranchOp::sostIsControl() {
758 return isControlCheckTypeAndOperand(getDataType(), getOperand());
759}
760
761void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
762 MLIRContext *context) {
763 results.insert<circt::handshake::EliminateSimpleBranchesPattern>(context);
764}
765
766ParseResult BranchOp::parse(OpAsmParser &parser, OperationState &result) {
767 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
768 Type type;
769 ArrayRef<Type> operandTypes(type);
770 SmallVector<Type, 1> dataOperandsTypes;
771 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
772 int size;
773 if (parseSostOperation(parser, allOperands, result, size, type, false))
774 return failure();
775
776 dataOperandsTypes.assign(size, type);
777 result.addTypes({type});
778 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
779 result.operands))
780 return failure();
781 return success();
782}
783
784void BranchOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
785
786ParseResult ConditionalBranchOp::parse(OpAsmParser &parser,
787 OperationState &result) {
788 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
789 Type dataType;
790 SmallVector<Type> operandTypes;
791 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
792 if (parser.parseOperandList(allOperands) ||
793 parser.parseOptionalAttrDict(result.attributes) ||
794 parser.parseColonType(dataType))
795 return failure();
796
797 if (allOperands.size() != 2)
798 return parser.emitError(parser.getCurrentLocation(),
799 "Expected exactly 2 operands");
800
801 result.addTypes({dataType, dataType});
802 operandTypes.push_back(IntegerType::get(parser.getContext(), 1));
803 operandTypes.push_back(dataType);
804 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
805 result.operands))
806 return failure();
807
808 return success();
809}
810
811void ConditionalBranchOp::print(OpAsmPrinter &p) {
812 Type type = getDataOperand().getType();
813 p << " " << getOperands();
814 p.printOptionalAttrDict((*this)->getAttrs());
815 p << " : " << type;
816}
817
818std::string handshake::ConditionalBranchOp::getOperandName(unsigned int idx) {
819 assert(idx == 0 || idx == 1);
820 return idx == 0 ? "cond" : "data";
821}
822
823std::string handshake::ConditionalBranchOp::getResultName(unsigned int idx) {
824 assert(idx == 0 || idx == 1);
825 return idx == ConditionalBranchOp::falseIndex ? "outFalse" : "outTrue";
826}
827
828bool ConditionalBranchOp::isControl() {
829 return isControlCheckTypeAndOperand(getDataOperand().getType(),
830 getDataOperand());
831}
832
833ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
834 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
835 Type type;
836 ArrayRef<Type> operandTypes(type);
837 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
838 int size;
839 if (parseSostOperation(parser, allOperands, result, size, type, false))
840 return failure();
841
842 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
843 result.operands))
844 return failure();
845 return success();
846}
847
848void SinkOp::print(OpAsmPrinter &p) { sostPrint(p, false); }
849
850std::string handshake::ConstantOp::getOperandName(unsigned int idx) {
851 assert(idx == 0);
852 return "ctrl";
853}
854
855Type SourceOp::getDataType() { return getResult().getType(); }
856unsigned SourceOp::getSize() { return 1; }
857
858ParseResult SourceOp::parse(OpAsmParser &parser, OperationState &result) {
859 if (parser.parseOptionalAttrDict(result.attributes))
860 return failure();
861 result.addTypes(NoneType::get(result.getContext()));
862 return success();
863}
864
865void SourceOp::print(OpAsmPrinter &p) {
866 p.printOptionalAttrDict((*this)->getAttrs());
867}
868
869LogicalResult ConstantOp::verify() {
870 // Verify that the type of the provided value is equal to the result type.
871 auto typedValue = dyn_cast<mlir::TypedAttr>(getValue());
872 if (!typedValue)
873 return emitOpError("constant value must be a typed attribute; value is ")
874 << getValue();
875 if (typedValue.getType() != getResult().getType())
876 return emitOpError() << "constant value type " << typedValue.getType()
877 << " differs from operation result type "
878 << getResult().getType();
879
880 return success();
881}
882
883void handshake::ConstantOp::getCanonicalizationPatterns(
884 RewritePatternSet &results, MLIRContext *context) {
885 results.insert<circt::handshake::EliminateSunkConstantsPattern>(context);
886}
887
888LogicalResult BufferOp::verify() {
889 // Verify that exactly 'size' number of initial values have been provided, if
890 // an initializer list have been provided.
891 if (auto initVals = getInitValues()) {
892 if (!isSequential())
893 return emitOpError()
894 << "only bufferType buffers are allowed to have initial values.";
895
896 auto nInits = initVals->size();
897 if (nInits != getSize())
898 return emitOpError() << "expected " << getSize()
899 << " init values but got " << nInits << ".";
900 }
901
902 return success();
903}
904
905void handshake::BufferOp::getCanonicalizationPatterns(
906 RewritePatternSet &results, MLIRContext *context) {
907 results.insert<circt::handshake::EliminateSunkBuffersPattern>(context);
908}
909
910unsigned BufferOp::getSize() {
911 return (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
912}
913
914ParseResult BufferOp::parse(OpAsmParser &parser, OperationState &result) {
915 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
916 Type type;
917 ArrayRef<Type> operandTypes(type);
918 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
919 int slots;
920 if (parseIntInSquareBrackets(parser, slots))
921 return failure();
922
923 auto bufferTypeAttr = BufferTypeEnumAttr::parse(parser, {});
924 if (!bufferTypeAttr)
925 return failure();
926
927 result.addAttribute(
928 "slots",
929 IntegerAttr::get(IntegerType::get(result.getContext(), 32), slots));
930 result.addAttribute("bufferType", bufferTypeAttr);
931
932 if (parser.parseOperandList(allOperands) ||
933 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
934 parser.parseType(type))
935 return failure();
936
937 result.addTypes({type});
938 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
939 result.operands))
940 return failure();
941 return success();
942}
943
944void BufferOp::print(OpAsmPrinter &p) {
945 int size =
946 (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
947 p << " [" << size << "]";
948 p << " " << stringifyEnum(getBufferType());
949 p << " " << (*this)->getOperands();
950 p.printOptionalAttrDict((*this)->getAttrs(), {"slots", "bufferType"});
951 p << " : " << (*this).getDataType();
952}
953
/// Returns the port name of operand `idx` for a memory with `nStores` store
/// ports. The first `2 * nStores` operands alternate stData/stAddr per store
/// port; every operand after that is an ldAddr, numbered from 0.
static std::string getMemoryOperandName(unsigned nStores, unsigned idx) {
  const unsigned storeOperands = 2 * nStores;
  if (idx >= storeOperands)
    return "ldAddr" + std::to_string(idx - storeOperands);
  const std::string port = std::to_string(idx / 2);
  return (idx % 2 == 0 ? "stData" : "stAddr") + port;
}
966
967std::string handshake::MemoryOp::getOperandName(unsigned int idx) {
968 return getMemoryOperandName(getStCount(), idx);
969}
970
/// Returns the name of result `idx` of a memory with `nLoads` load ports and
/// `nStores` store ports. Results are laid out as ldData (one per load), then
/// stDone (one per store), then ldDone (one per load).
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores,
                                       unsigned idx) {
  if (idx < nLoads)
    return "ldData" + std::to_string(idx);
  unsigned rel = idx - nLoads;
  if (rel < nStores)
    return "stDone" + std::to_string(rel);
  return "ldDone" + std::to_string(rel - nStores);
}
982
983std::string handshake::MemoryOp::getResultName(unsigned int idx) {
984 return getMemoryResultName(getLdCount(), getStCount(), idx);
985}
986
987LogicalResult MemoryOp::verify() {
988 auto memrefType = getMemRefType();
989
990 if (memrefType.getNumDynamicDims() != 0)
991 return emitOpError()
992 << "memref dimensions for handshake.memory must be static.";
993 if (memrefType.getShape().size() != 1)
994 return emitOpError() << "memref must have only a single dimension.";
995
996 unsigned opStCount = getStCount();
997 unsigned opLdCount = getLdCount();
998 int addressCount = memrefType.getShape().size();
999
1000 auto inputType = getInputs().getType();
1001 auto outputType = getOutputs().getType();
1002 Type dataType = memrefType.getElementType();
1003
1004 unsigned numOperands = static_cast<int>(getInputs().size());
1005 unsigned numResults = static_cast<int>(getOutputs().size());
1006 if (numOperands != (1 + addressCount) * opStCount + addressCount * opLdCount)
1007 return emitOpError("number of operands ")
1008 << numOperands << " does not match number expected of "
1009 << 2 * opStCount + opLdCount << " with " << addressCount
1010 << " address inputs per port";
1011
1012 if (numResults != opStCount + 2 * opLdCount)
1013 return emitOpError("number of results ")
1014 << numResults << " does not match number expected of "
1015 << opStCount + 2 * opLdCount << " with " << addressCount
1016 << " address inputs per port";
1017
1018 Type addressType = opStCount > 0 ? inputType[1] : inputType[0];
1019
1020 for (unsigned i = 0; i < opStCount; i++) {
1021 if (inputType[2 * i] != dataType)
1022 return emitOpError("data type for store port ")
1023 << i << ":" << inputType[2 * i] << " doesn't match memory type "
1024 << dataType;
1025 if (inputType[2 * i + 1] != addressType)
1026 return emitOpError("address type for store port ")
1027 << i << ":" << inputType[2 * i + 1]
1028 << " doesn't match address type " << addressType;
1029 }
1030 for (unsigned i = 0; i < opLdCount; i++) {
1031 Type ldAddressType = inputType[2 * opStCount + i];
1032 if (ldAddressType != addressType)
1033 return emitOpError("address type for load port ")
1034 << i << ":" << ldAddressType << " doesn't match address type "
1035 << addressType;
1036 }
1037 for (unsigned i = 0; i < opLdCount; i++) {
1038 if (outputType[i] != dataType)
1039 return emitOpError("data type for load port ")
1040 << i << ":" << outputType[i] << " doesn't match memory type "
1041 << dataType;
1042 }
1043 for (unsigned i = 0; i < opStCount; i++) {
1044 Type syncType = outputType[opLdCount + i];
1045 if (!isa<NoneType>(syncType))
1046 return emitOpError("data type for sync port for store port ")
1047 << i << ":" << syncType << " is not 'none'";
1048 }
1049 for (unsigned i = 0; i < opLdCount; i++) {
1050 Type syncType = outputType[opLdCount + opStCount + i];
1051 if (!isa<NoneType>(syncType))
1052 return emitOpError("data type for sync port for load port ")
1053 << i << ":" << syncType << " is not 'none'";
1054 }
1055
1056 return success();
1057}
1058
1059std::string handshake::ExternalMemoryOp::getOperandName(unsigned int idx) {
1060 if (idx == 0)
1061 return "extmem";
1062
1063 return getMemoryOperandName(getStCount(), idx - 1);
1064}
1065
1066std::string handshake::ExternalMemoryOp::getResultName(unsigned int idx) {
1067 return getMemoryResultName(getLdCount(), getStCount(), idx);
1068}
1069
1070void ExternalMemoryOp::build(OpBuilder &builder, OperationState &result,
1071 Value memref, ValueRange inputs, int ldCount,
1072 int stCount, int id) {
1073 SmallVector<Value> ops;
1074 ops.push_back(memref);
1075 llvm::append_range(ops, inputs);
1076 result.addOperands(ops);
1077
1078 auto memrefType = cast<MemRefType>(memref.getType());
1079
1080 // Data outputs (get their type from memref)
1081 result.types.append(ldCount, memrefType.getElementType());
1082
1083 // Control outputs
1084 result.types.append(stCount + ldCount, builder.getNoneType());
1085
1086 // Memory ID (individual ID for each MemoryOp)
1087 Type i32Type = builder.getIntegerType(32);
1088 result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1089 result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, ldCount));
1090 result.addAttribute("stCount", builder.getIntegerAttr(i32Type, stCount));
1091}
1092
1093llvm::SmallVector<handshake::MemLoadInterface>
1094ExternalMemoryOp::getLoadPorts() {
1095 return ::getLoadPorts(*this);
1096}
1097
1098llvm::SmallVector<handshake::MemStoreInterface>
1099ExternalMemoryOp::getStorePorts() {
1100 return ::getStorePorts(*this);
1101}
1102
1103void MemoryOp::build(OpBuilder &builder, OperationState &result,
1104 ValueRange operands, int outputs, int controlOutputs,
1105 bool lsq, int id, Value memref) {
1106 result.addOperands(operands);
1107
1108 auto memrefType = cast<MemRefType>(memref.getType());
1109
1110 // Data outputs (get their type from memref)
1111 result.types.append(outputs, memrefType.getElementType());
1112
1113 // Control outputs
1114 result.types.append(controlOutputs, builder.getNoneType());
1115 result.addAttribute("lsq", builder.getBoolAttr(lsq));
1116 result.addAttribute("memRefType", TypeAttr::get(memrefType));
1117
1118 // Memory ID (individual ID for each MemoryOp)
1119 Type i32Type = builder.getIntegerType(32);
1120 result.addAttribute("id", builder.getIntegerAttr(i32Type, id));
1121
1122 if (!lsq) {
1123 result.addAttribute("ldCount", builder.getIntegerAttr(i32Type, outputs));
1124 result.addAttribute(
1125 "stCount", builder.getIntegerAttr(i32Type, controlOutputs - outputs));
1126 }
1127}
1128
1129llvm::SmallVector<handshake::MemLoadInterface> MemoryOp::getLoadPorts() {
1130 return ::getLoadPorts(*this);
1131}
1132
1133llvm::SmallVector<handshake::MemStoreInterface> MemoryOp::getStorePorts() {
1134 return ::getStorePorts(*this);
1135}
1136
1137bool handshake::MemoryOp::allocateMemory(
1138 llvm::DenseMap<unsigned, unsigned> &memoryMap,
1139 std::vector<std::vector<llvm::Any>> &store,
1140 std::vector<double> &storeTimes) {
1141 if (memoryMap.count(getId()))
1142 return false;
1143
1144 auto type = getMemRefType();
1145 std::vector<llvm::Any> in;
1146
1147 ArrayRef<int64_t> shape = type.getShape();
1148 int allocationSize = 1;
1149 unsigned count = 0;
1150 for (int64_t dim : shape) {
1151 if (dim > 0)
1152 allocationSize *= dim;
1153 else {
1154 assert(count < in.size());
1155 allocationSize *= llvm::any_cast<APInt>(in[count++]).getSExtValue();
1156 }
1157 }
1158 unsigned ptr = store.size();
1159 store.resize(ptr + 1);
1160 storeTimes.resize(ptr + 1);
1161 store[ptr].resize(allocationSize);
1162 storeTimes[ptr] = 0.0;
1163 mlir::Type elementType = type.getElementType();
1164 int width = elementType.getIntOrFloatBitWidth();
1165 for (int i = 0; i < allocationSize; i++) {
1166 if (isa<mlir::IntegerType>(elementType)) {
1167 store[ptr][i] = APInt(width, 0);
1168 } else if (isa<mlir::FloatType>(elementType)) {
1169 store[ptr][i] = APFloat(0.0);
1170 } else {
1171 llvm_unreachable("Unknown result type!\n");
1172 }
1173 }
1174
1175 memoryMap[getId()] = ptr;
1176 return true;
1177}
1178
1179std::string handshake::LoadOp::getOperandName(unsigned int idx) {
1180 unsigned nAddresses = getAddresses().size();
1181 std::string opName;
1182 if (idx < nAddresses)
1183 opName = "addrIn" + std::to_string(idx);
1184 else if (idx == nAddresses)
1185 opName = "dataFromMem";
1186 else
1187 opName = "ctrl";
1188 return opName;
1189}
1190
1191std::string handshake::LoadOp::getResultName(unsigned int idx) {
1192 std::string resName;
1193 if (idx == 0)
1194 resName = "dataOut";
1195 else
1196 resName = "addrOut" + std::to_string(idx - 1);
1197 return resName;
1198}
1199
1200void handshake::LoadOp::build(OpBuilder &builder, OperationState &result,
1201 Value memref, ValueRange indices) {
1202 // Address indices
1203 // result.addOperands(memref);
1204 result.addOperands(indices);
1205
1206 // Data type
1207 auto memrefType = cast<MemRefType>(memref.getType());
1208
1209 // Data output (from load to successor ops)
1210 result.types.push_back(memrefType.getElementType());
1211
1212 // Address outputs (to lsq)
1213 result.types.append(indices.size(), builder.getIndexType());
1214}
1215
1216static ParseResult parseMemoryAccessOp(OpAsmParser &parser,
1217 OperationState &result) {
1218 SmallVector<OpAsmParser::UnresolvedOperand, 4> addressOperands,
1219 remainingOperands, allOperands;
1220 SmallVector<Type, 1> parsedTypes, allTypes;
1221 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1222
1223 if (parser.parseLSquare() || parser.parseOperandList(addressOperands) ||
1224 parser.parseRSquare() || parser.parseOperandList(remainingOperands) ||
1225 parser.parseColon() || parser.parseTypeList(parsedTypes))
1226 return failure();
1227
1228 // The last type will be the data type of the operation; the prior will be the
1229 // address types.
1230 Type dataType = parsedTypes.back();
1231 auto parsedTypesRef = ArrayRef(parsedTypes);
1232 result.addTypes(dataType);
1233 result.addTypes(parsedTypesRef.drop_back());
1234 allOperands.append(addressOperands);
1235 allOperands.append(remainingOperands);
1236 allTypes.append(parsedTypes);
1237 allTypes.push_back(NoneType::get(result.getContext()));
1238 if (parser.resolveOperands(allOperands, allTypes, allOperandLoc,
1239 result.operands))
1240 return failure();
1241 return success();
1242}
1243
1244template <typename MemOp>
1245static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op) {
1246 p << " [";
1247 p << op.getAddresses();
1248 p << "] " << op.getData() << ", " << op.getCtrl() << " : ";
1249 llvm::interleaveComma(op.getAddresses(), p,
1250 [&](Value v) { p << v.getType(); });
1251 p << ", " << op.getData().getType();
1252}
1253
1254ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
1255 return parseMemoryAccessOp(parser, result);
1256}
1257
1258void LoadOp::print(OpAsmPrinter &p) { printMemoryAccessOp(p, *this); }
1259
1260std::string handshake::StoreOp::getOperandName(unsigned int idx) {
1261 unsigned nAddresses = getAddresses().size();
1262 std::string opName;
1263 if (idx < nAddresses)
1264 opName = "addrIn" + std::to_string(idx);
1265 else if (idx == nAddresses)
1266 opName = "dataIn";
1267 else
1268 opName = "ctrl";
1269 return opName;
1270}
1271
1272template <typename TMemoryOp>
1273static LogicalResult verifyMemoryAccessOp(TMemoryOp op) {
1274 if (op.getAddresses().size() == 0)
1275 return op.emitOpError() << "No addresses were specified";
1276
1277 return success();
1278}
1279
1280LogicalResult LoadOp::verify() { return verifyMemoryAccessOp(*this); }
1281
1282std::string handshake::StoreOp::getResultName(unsigned int idx) {
1283 std::string resName;
1284 if (idx == 0)
1285 resName = "dataToMem";
1286 else
1287 resName = "addrOut" + std::to_string(idx - 1);
1288 return resName;
1289}
1290
1291void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
1292 Value valueToStore, ValueRange indices) {
1293
1294 // Address indices
1295 result.addOperands(indices);
1296
1297 // Data
1298 result.addOperands(valueToStore);
1299
1300 // Data output (from store to LSQ)
1301 result.types.push_back(valueToStore.getType());
1302
1303 // Address outputs (from store to lsq)
1304 result.types.append(indices.size(), builder.getIndexType());
1305}
1306
1307LogicalResult StoreOp::verify() { return verifyMemoryAccessOp(*this); }
1308
1309ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
1310 return parseMemoryAccessOp(parser, result);
1311}
1312
1313void StoreOp::print(OpAsmPrinter &p) { return printMemoryAccessOp(p, *this); }
1314
1315bool JoinOp::isControl() { return true; }
1316
1317ParseResult JoinOp::parse(OpAsmParser &parser, OperationState &result) {
1318 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1319 SmallVector<Type> types;
1320
1321 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1322 if (parser.parseOperandList(operands) ||
1323 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1324 parser.parseTypeList(types))
1325 return failure();
1326
1327 if (parser.resolveOperands(operands, types, allOperandLoc, result.operands))
1328 return failure();
1329
1330 result.addTypes(NoneType::get(result.getContext()));
1331 return success();
1332}
1333
1334void JoinOp::print(OpAsmPrinter &p) {
1335 p << " " << getData();
1336 p.printOptionalAttrDict((*this)->getAttrs(), {"control"});
1337 p << " : " << getData().getTypes();
1338}
1339
1340LogicalResult
1341ESIInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1342 // Check that the module attribute was specified.
1343 auto fnAttr = this->getModuleAttr();
1344 assert(fnAttr && "requires a 'module' symbol reference attribute");
1345
1346 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
1347 if (!fn)
1348 return emitOpError() << "'" << fnAttr.getValue()
1349 << "' does not reference a valid handshake function";
1350
1351 // Verify that the operand and result types match the callee.
1352 auto fnType = fn.getFunctionType();
1353 if (fnType.getNumInputs() != getNumOperands() - NumFixedOperands)
1354 return emitOpError(
1355 "incorrect number of operands for the referenced handshake function");
1356
1357 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1358 Type operandType = getOperand(i + NumFixedOperands).getType();
1359 auto channelType = dyn_cast<esi::ChannelType>(operandType);
1360 if (!channelType)
1361 return emitOpError("operand type mismatch: expected channel type, but "
1362 "provided ")
1363 << operandType << " for operand number " << i;
1364 if (channelType.getInner() != fnType.getInput(i))
1365 return emitOpError("operand type mismatch: expected operand type ")
1366 << fnType.getInput(i) << ", but provided "
1367 << getOperand(i).getType() << " for operand number " << i;
1368 }
1369
1370 if (fnType.getNumResults() != getNumResults())
1371 return emitOpError(
1372 "incorrect number of results for the referenced handshake function");
1373
1374 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1375 Type resultType = getResult(i).getType();
1376 auto channelType = dyn_cast<esi::ChannelType>(resultType);
1377 if (!channelType)
1378 return emitOpError("result type mismatch: expected channel type, but "
1379 "provided ")
1380 << resultType << " for result number " << i;
1381 if (channelType.getInner() != fnType.getResult(i))
1382 return emitOpError("result type mismatch: expected result type ")
1383 << fnType.getResult(i) << ", but provided "
1384 << getResult(i).getType() << " for result number " << i;
1385 }
1386
1387 return success();
1388}
1389
1390/// Based on mlir::func::CallOp::verifySymbolUses
1391LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1392 // Check that the module attribute was specified.
1393 auto fnAttr = this->getModuleAttr();
1394 assert(fnAttr && "requires a 'module' symbol reference attribute");
1395
1396 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
1397 if (!fn)
1398 return emitOpError() << "'" << fnAttr.getValue()
1399 << "' does not reference a valid handshake function";
1400
1401 // Verify that the operand and result types match the callee.
1402 auto fnType = fn.getFunctionType();
1403 if (fnType.getNumInputs() != getNumOperands())
1404 return emitOpError(
1405 "incorrect number of operands for the referenced handshake function");
1406
1407 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
1408 if (getOperand(i).getType() != fnType.getInput(i))
1409 return emitOpError("operand type mismatch: expected operand type ")
1410 << fnType.getInput(i) << ", but provided "
1411 << getOperand(i).getType() << " for operand number " << i;
1412
1413 if (fnType.getNumResults() != getNumResults())
1414 return emitOpError(
1415 "incorrect number of results for the referenced handshake function");
1416
1417 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
1418 if (getResult(i).getType() != fnType.getResult(i))
1419 return emitOpError("result type mismatch: expected result type ")
1420 << fnType.getResult(i) << ", but provided "
1421 << getResult(i).getType() << " for result number " << i;
1422
1423 return success();
1424}
1425
1426FunctionType InstanceOp::getModuleType() {
1427 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
1428}
1429
1430ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1431 OpAsmParser::UnresolvedOperand tuple;
1432 TupleType type;
1433
1434 if (parser.parseOperand(tuple) ||
1435 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1436 parser.parseType(type))
1437 return failure();
1438
1439 if (parser.resolveOperand(tuple, type, result.operands))
1440 return failure();
1441
1442 result.addTypes(type.getTypes());
1443
1444 return success();
1445}
1446
1447void UnpackOp::print(OpAsmPrinter &p) {
1448 p << " " << getInput();
1449 p.printOptionalAttrDict((*this)->getAttrs());
1450 p << " : " << getInput().getType();
1451}
1452
1453ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1454 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1455 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1456 TupleType type;
1457
1458 if (parser.parseOperandList(operands) ||
1459 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1460 parser.parseType(type))
1461 return failure();
1462
1463 if (parser.resolveOperands(operands, type.getTypes(), allOperandLoc,
1464 result.operands))
1465 return failure();
1466
1467 result.addTypes(type);
1468
1469 return success();
1470}
1471
1472void PackOp::print(OpAsmPrinter &p) {
1473 p << " " << getInputs();
1474 p.printOptionalAttrDict((*this)->getAttrs());
1475 p << " : " << getResult().getType();
1476}
1477
1478//===----------------------------------------------------------------------===//
1479// TableGen'd op method definitions
1480//===----------------------------------------------------------------------===//
1481
1482LogicalResult ReturnOp::verify() {
1483 auto *parent = (*this)->getParentOp();
1484 auto function = dyn_cast<handshake::FuncOp>(parent);
1485 if (!function)
1486 return emitOpError("must have a handshake.func parent");
1487
1488 // The operand number and types must match the function signature.
1489 const auto &results = function.getResultTypes();
1490 if (getNumOperands() != results.size())
1491 return emitOpError("has ")
1492 << getNumOperands() << " operands, but enclosing function returns "
1493 << results.size();
1494
1495 for (unsigned i = 0, e = results.size(); i != e; ++i)
1496 if (getOperand(i).getType() != results[i])
1497 return emitError() << "type of return operand " << i << " ("
1498 << getOperand(i).getType()
1499 << ") doesn't match function result type ("
1500 << results[i] << ")";
1501
1502 return success();
1503}
1504
1505#define GET_OP_CLASSES
1506#include "circt/Dialect/Handshake/Handshake.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal, uint64_t numOperands)
Verifies whether an indexing value is wide enough to index into a provided number of operands.
static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result)
static ParseResult parseSostOperation(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, OperationState &result, int &size, Type &type, bool explicitSize)
static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op)
static ParseResult parseFuncOpArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Type > &resTypes, SmallVectorImpl< DictionaryAttr > &resAttrs)
Parses a FuncOp signature using mlir::function_interface_impl::parseFunctionSignature while getting a...
static Value getDematerialized(Value v)
Returns a dematerialized version of the value 'v', defined as the source of the value before passing ...
static bool isControlCheckTypeAndOperand(Type dataType, Value operand)
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores, unsigned idx)
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v)
static LogicalResult verifyMemoryAccessOp(TMemoryOp op)
static SmallVector< Attribute > getFuncOpNames(Builder &builder, unsigned cnt, StringRef prefix)
Generates names for a handshake.func input and output arguments, based on the number of args as well ...
static void addStringToStringArrayAttr(Builder &builder, Operation *op, StringRef attrName, StringAttr str)
Helper function for appending a string to an array attribute, and rewriting the attribute back to the...
static std::string defaultOperandName(unsigned int idx)
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
static std::string getMemoryOperandName(unsigned nStores, unsigned idx)
static ParseResult parseMemoryAccessOp(OpAsmParser &parser, OperationState &result)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
bool isControlOpImpl(Operation *op)
Default implementation for checking whether an operation is a control operation.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.