16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/IntegerSet.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/OpDefinition.h"
24#include "mlir/IR/OpImplementation.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/IR/SymbolTable.h"
27#include "mlir/IR/Value.h"
28#include "mlir/Interfaces/FunctionImplementation.h"
29#include "mlir/Transforms/InliningUtils.h"
30#include "llvm/ADT/SetVector.h"
31#include "llvm/ADT/SmallBitVector.h"
32#include "llvm/ADT/TypeSwitch.h"
41#include "circt/Dialect/Handshake/HandshakeCanonicalization.h.inc"
44 if (
auto sostInterface = dyn_cast<SOSTInterface>(op); sostInterface)
45 return sostInterface.sostIsControl();
54 return "in" + std::to_string(idx);
58 if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
65 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
66 OperationState &result,
int &size, Type &type,
72 if (parser.parseOperandList(operands) ||
73 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
74 parser.parseType(type))
78 size = operands.size();
85 uint64_t numOperands) {
86 auto indexType = indexVal.getType();
90 if (
auto integerType = dyn_cast<IntegerType>(indexType))
91 indexWidth = integerType.getWidth();
92 else if (indexType.isIndex())
93 indexWidth = IndexType::kInternalStorageBitWidth;
95 return op->emitError(
"unsupported type for indexing value: ") << indexType;
98 if (indexWidth < 64) {
99 uint64_t maxNumOperands = (uint64_t)1 << indexWidth;
100 if (numOperands > maxNumOperands)
101 return op->emitError(
"bitwidth of indexing value is ")
102 << indexWidth <<
", which can index into " << maxNumOperands
103 <<
" operands, but found " << numOperands <<
" operands";
111 if (isa<NoneType>(dataType))
116 auto *defOp = operand.getDefiningOp();
117 return isa_and_nonnull<ControlMergeOp>(defOp) &&
118 operand == defOp->getResult(0);
121template <
typename TMemOp>
122llvm::SmallVector<handshake::MemLoadInterface>
getLoadPorts(TMemOp op) {
123 llvm::SmallVector<MemLoadInterface> ports;
130 unsigned stCount = op.getStCount();
131 unsigned ldCount = op.getLdCount();
132 for (
unsigned i = 0, e = ldCount; i != e; ++i) {
135 ldif.
addressIn = op.getInputs()[stCount * 2 + i];
136 ldif.
dataOut = op.getResult(i);
137 ldif.
doneOut = op.getResult(ldCount + stCount + i);
138 ports.push_back(ldif);
143template <
typename TMemOp>
145 llvm::SmallVector<MemStoreInterface> ports;
152 unsigned ldCount = op.getLdCount();
153 for (
unsigned i = 0, e = op.getStCount(); i != e; ++i) {
156 stif.
dataIn = op.getInputs()[i * 2];
157 stif.
addressIn = op.getInputs()[i * 2 + 1];
158 stif.
doneOut = op.getResult(ldCount + i);
159 ports.push_back(stif);
164unsigned ForkOp::getSize() {
return getResults().size(); }
166static ParseResult
parseForkOp(OpAsmParser &parser, OperationState &result) {
167 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
169 ArrayRef<Type> operandTypes(type);
170 SmallVector<Type, 1> resultTypes;
171 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
176 resultTypes.assign(size, type);
177 result.addTypes(resultTypes);
178 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
184ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
188void ForkOp::print(OpAsmPrinter &p) { sostPrint(p,
true); }
195 LogicalResult matchAndRewrite(ForkOp op,
196 PatternRewriter &rewriter)
const override {
197 std::set<unsigned> unusedIndexes;
199 for (
auto res :
llvm::enumerate(op.getResults()))
200 if (res.value().getUses().
empty())
201 unusedIndexes.insert(res.index());
203 if (unusedIndexes.empty())
207 rewriter.setInsertionPoint(op);
208 auto operand = op.getOperand();
209 auto newFork = ForkOp::create(rewriter, op.getLoc(), operand,
210 op.getNumResults() - unusedIndexes.size());
212 for (
auto oldRes :
llvm::enumerate(op.getResults()))
213 if (unusedIndexes.count(oldRes.index()) == 0)
214 rewriter.replaceAllUsesWith(oldRes.value(), newFork.getResults()[i++]);
215 rewriter.eraseOp(op);
223 LogicalResult matchAndRewrite(ForkOp op,
224 PatternRewriter &rewriter)
const override {
225 auto parentForkOp = op.getOperand().getDefiningOp<ForkOp>();
233 unsigned totalNumOuts = op.getSize() + parentForkOp.getSize();
236 auto newParentForkOp =
237 ForkOp::create(rewriter, parentForkOp.getLoc(),
238 parentForkOp.getOperand(), totalNumOuts);
241 llvm::zip(parentForkOp->getResults(), newParentForkOp.getResults()))
242 rewriter.replaceAllUsesWith(std::
get<0>(it), std::
get<1>(it));
246 rewriter.replaceOp(op,
247 newParentForkOp.getResults().take_back(op.getSize()));
248 rewriter.eraseOp(parentForkOp);
255void handshake::ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
256 MLIRContext *context) {
257 results.insert<circt::handshake::EliminateSimpleForksPattern,
258 EliminateForkToForkPattern, EliminateUnusedForkResultsPattern>(
262unsigned LazyForkOp::getSize() {
return getResults().size(); }
264bool LazyForkOp::sostIsControl() {
268ParseResult LazyForkOp::parse(OpAsmParser &parser, OperationState &result) {
272void LazyForkOp::print(OpAsmPrinter &p) { sostPrint(p,
true); }
274ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
275 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
277 ArrayRef<Type> operandTypes(type);
278 SmallVector<Type, 1> resultTypes, dataOperandsTypes;
279 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
284 dataOperandsTypes.assign(size, type);
285 resultTypes.push_back(type);
286 result.addTypes(resultTypes);
287 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
293void MergeOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
295void MergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
296 MLIRContext *context) {
297 results.insert<circt::handshake::EliminateSimpleMergesPattern>(context);
303 Operation *parentOp = v.getDefiningOp();
307 return llvm::TypeSwitch<Operation *, Value>(parentOp)
312 .Default([&](
auto) {
return v; });
322 LogicalResult matchAndRewrite(MuxOp op,
323 PatternRewriter &rewriter)
const override {
325 if (!llvm::all_of(op.getDataOperands(), [&](Value operand) {
326 return getDematerialized(operand) == firstDataOperand;
329 rewriter.replaceOp(op, firstDataOperand);
336 LogicalResult matchAndRewrite(MuxOp op,
337 PatternRewriter &rewriter)
const override {
338 if (op.getDataOperands().size() != 1)
341 rewriter.replaceOp(op, op.getDataOperands()[0]);
348 LogicalResult matchAndRewrite(MuxOp op,
349 PatternRewriter &rewriter)
const override {
351 auto dataOperands = op.getDataOperands();
352 if (dataOperands.size() != 2)
356 ConditionalBranchOp firstParentCBranch =
357 dataOperands[0].getDefiningOp<ConditionalBranchOp>();
358 if (!firstParentCBranch)
360 auto secondParentCBranch =
361 dataOperands[1].getDefiningOp<ConditionalBranchOp>();
362 if (!secondParentCBranch || firstParentCBranch != secondParentCBranch)
365 rewriter.modifyOpInPlace(firstParentCBranch, [&] {
367 rewriter.replaceOp(op, firstParentCBranch.getDataOperand());
376void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
377 MLIRContext *context) {
378 results.insert<EliminateSimpleMuxesPattern, EliminateUnaryMuxesPattern,
379 EliminateCBranchIntoMuxPattern>(context);
383MuxOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
384 ValueRange operands, DictionaryAttr attributes,
385 mlir::OpaqueProperties properties,
386 mlir::RegionRange regions,
387 SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
390 if (operands.size() < 2)
393 inferredReturnTypes.push_back(operands[1].getType());
397bool MuxOp::isControl() {
return isa<NoneType>(getResult().getType()); }
399std::string handshake::MuxOp::getOperandName(
unsigned int idx) {
403ParseResult MuxOp::parse(OpAsmParser &parser, OperationState &result) {
404 OpAsmParser::UnresolvedOperand selectOperand;
405 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
406 Type selectType, dataType;
407 SmallVector<Type, 1> dataOperandsTypes;
408 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
409 if (parser.parseOperand(selectOperand) || parser.parseLSquare() ||
410 parser.parseOperandList(allOperands) || parser.parseRSquare() ||
411 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
412 parser.parseType(selectType) || parser.parseComma() ||
413 parser.parseType(dataType))
416 int size = allOperands.size();
417 dataOperandsTypes.assign(size, dataType);
418 result.addTypes(dataType);
419 allOperands.insert(allOperands.begin(), selectOperand);
420 if (parser.resolveOperands(
422 llvm::concat<const Type>(ArrayRef<Type>(selectType),
423 ArrayRef<Type>(dataOperandsTypes)),
424 allOperandLoc, result.operands))
429void MuxOp::print(OpAsmPrinter &p) {
430 Type selectType = getSelectOperand().getType();
431 auto ops = getOperands();
432 p <<
' ' << ops.front();
434 p.printOperands(ops.drop_front());
436 p.printOptionalAttrDict((*this)->getAttrs());
437 p <<
" : " << selectType <<
", " << getResult().getType();
440LogicalResult MuxOp::verify() {
442 getDataOperands().size());
445std::string handshake::ControlMergeOp::getResultName(
unsigned int idx) {
446 assert(idx == 0 || idx == 1);
447 return idx == 0 ?
"dataOut" :
"index";
450ParseResult ControlMergeOp::parse(OpAsmParser &parser, OperationState &result) {
451 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
452 Type resultType, indexType;
453 SmallVector<Type> resultTypes, dataOperandsTypes;
454 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
459 if (parser.parseComma() || parser.parseType(indexType))
462 dataOperandsTypes.assign(size, resultType);
463 resultTypes.push_back(resultType);
464 resultTypes.push_back(indexType);
465 result.addTypes(resultTypes);
466 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
472void ControlMergeOp::print(OpAsmPrinter &p) {
475 p <<
", " << getIndex().getType();
478LogicalResult ControlMergeOp::verify() {
479 auto operands = getOperands();
480 if (operands.empty())
481 return emitOpError(
"operation must have at least one operand");
482 if (operands[0].getType() != getResult().getType())
483 return emitOpError(
"type of first result should match type of operands");
487LogicalResult FuncOp::verify() {
495 auto fnInputTypes = getArgumentTypes();
496 Block &entryBlock = front();
498 for (
unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
499 if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
500 return emitOpError(
"type of entry block argument #")
501 << i <<
'(' << entryBlock.getArgument(i).getType()
502 <<
") must match the type of the corresponding argument in "
503 <<
"function signature(" << fnInputTypes[i] <<
')';
506 auto verifyPortNameAttr = [&](StringRef attrName,
507 unsigned numIOs) -> LogicalResult {
508 auto portNamesAttr = (*this)->getAttrOfType<ArrayAttr>(attrName);
511 return emitOpError() <<
"expected attribute '" << attrName <<
"'.";
513 auto portNames = portNamesAttr.getValue();
514 if (portNames.size() != numIOs)
515 return emitOpError() <<
"attribute '" << attrName <<
"' has "
517 <<
" entries but is expected to have " << numIOs
520 if (llvm::any_of(portNames,
521 [&](Attribute attr) {
return !isa<StringAttr>(attr); }))
522 return emitOpError() <<
"expected all entries in attribute '" << attrName
523 <<
"' to be strings.";
527 if (failed(verifyPortNameAttr(
"argNames", getNumArguments())))
529 if (failed(verifyPortNameAttr(
"resNames", getNumResults())))
533 for (
auto arg : entryBlock.getArguments()) {
534 if (!isa<MemRefType>(arg.getType()))
536 if (arg.getUsers().empty() ||
537 !isa<ExternalMemoryOp>(*arg.getUsers().begin()))
538 return emitOpError(
"expected that block argument #")
539 << arg.getArgNumber() <<
" is used by an 'extmemory' operation";
550 SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
551 SmallVectorImpl<Type> &resTypes,
552 SmallVectorImpl<DictionaryAttr> &resAttrs) {
554 if (mlir::function_interface_impl::parseFunctionSignatureWithArguments(
555 parser,
true, entryArgs, isVariadic, resTypes,
567 SmallVector<Attribute> resNames;
568 for (
unsigned i = 0; i < cnt; ++i)
569 resNames.push_back(builder.getStringAttr(prefix + std::to_string(i)));
573void handshake::FuncOp::build(OpBuilder &builder, OperationState &state,
574 StringRef name, FunctionType type,
575 ArrayRef<NamedAttribute> attrs) {
576 state.addAttribute(SymbolTable::getSymbolAttrName(),
577 builder.getStringAttr(name));
578 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
579 TypeAttr::get(type));
580 state.attributes.append(attrs.begin(), attrs.end());
582 if (
const auto *argNamesAttrIt = llvm::find_if(
583 attrs, [&](
auto attr) {
return attr.getName() ==
"argNames"; });
584 argNamesAttrIt == attrs.end())
585 state.addAttribute(
"argNames", builder.getArrayAttr({}));
587 if (llvm::find_if(attrs, [&](
auto attr) {
588 return attr.getName() ==
"resNames";
590 state.addAttribute(
"resNames", builder.getArrayAttr({}));
598 StringRef attrName, StringAttr str) {
599 llvm::SmallVector<Attribute> attrs;
600 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
601 std::back_inserter(attrs));
602 attrs.push_back(str);
603 op->setAttr(attrName, builder.getArrayAttr(attrs));
606void handshake::FuncOp::resolveArgAndResNames() {
607 Builder builder(getContext());
611 auto fallbackArgNames =
getFuncOpNames(builder, getNumArguments(),
"in");
612 auto fallbackResNames =
getFuncOpNames(builder, getNumResults(),
"out");
613 auto argNames = getArgNames().getValue();
614 auto resNames = getResNames().getValue();
617 auto resolveNames = [&](
auto &fallbackNames,
auto &actualNames,
618 StringRef attrName) {
619 for (
auto fallbackName :
llvm::enumerate(fallbackNames)) {
620 if (actualNames.size() <= fallbackName.index())
622 cast<StringAttr>(fallbackName.value()));
625 resolveNames(fallbackArgNames, argNames,
"argNames");
626 resolveNames(fallbackResNames, resNames,
"resNames");
629ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
630 auto &builder = parser.getBuilder();
632 SmallVector<OpAsmParser::Argument> args;
633 SmallVector<Type> resTypes;
634 SmallVector<DictionaryAttr> resAttributes;
635 SmallVector<Attribute> argNames;
638 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
641 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
642 result.attributes) ||
645 mlir::call_interface_impl::addArgAndResultAttrs(
646 builder, result, args, resAttributes,
647 handshake::FuncOp::getArgAttrsAttrName(result.name),
648 handshake::FuncOp::getResAttrsAttrName(result.name));
651 SmallVector<Type> argTypes;
652 for (
auto arg : args)
653 argTypes.push_back(arg.type);
656 handshake::FuncOp::getFunctionTypeAttrName(result.name),
657 TypeAttr::get(builder.getFunctionType(argTypes, resTypes)));
662 llvm::any_of(args, [](
auto arg) {
return arg.ssaName.name.empty(); });
666 llvm::transform(args, std::back_inserter(argNames), [&](
auto arg) {
667 return builder.getStringAttr(arg.ssaName.name.drop_front());
672 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
677 if (!result.attributes.get(
"argNames"))
678 result.addAttribute(
"argNames", builder.getArrayAttr(argNames));
679 if (!result.attributes.get(
"resNames")) {
681 result.addAttribute(
"resNames", builder.getArrayAttr(resNames));
686 auto *body = result.addRegion();
687 llvm::SMLoc loc = parser.getCurrentLocation();
688 auto parseResult = parser.parseOptionalRegion(*body, args,
690 if (!parseResult.has_value())
693 if (failed(*parseResult))
697 return parser.emitError(loc,
"expected non-empty function body");
703void FuncOp::print(OpAsmPrinter &p) {
704 mlir::function_interface_impl::printFunctionOp(
705 p, *
this,
true, getFunctionTypeAttrName(),
706 getArgAttrsAttrName(), getResAttrsAttrName());
710struct EliminateSimpleControlMergesPattern
714 LogicalResult matchAndRewrite(ControlMergeOp op,
715 PatternRewriter &rewriter)
const override;
719LogicalResult EliminateSimpleControlMergesPattern::matchAndRewrite(
720 ControlMergeOp op, PatternRewriter &rewriter)
const {
721 auto dataResult = op.getResult();
722 auto choiceResult = op.getIndex();
723 auto choiceUnused = choiceResult.use_empty();
724 if (!choiceUnused && !choiceResult.hasOneUse())
727 Operation *choiceUser =
nullptr;
728 if (choiceResult.hasOneUse()) {
729 choiceUser = choiceResult.getUses().begin().getUser();
730 if (!isa<SinkOp>(choiceUser))
734 auto merge = MergeOp::create(rewriter, op.getLoc(), op.getDataOperands());
736 for (
auto &use :
llvm::make_early_inc_range(dataResult.getUses())) {
737 auto *user = use.getOwner();
738 rewriter.modifyOpInPlace(
739 user, [&]() { user->setOperand(use.getOperandNumber(), merge); });
743 rewriter.eraseOp(op);
747 rewriter.eraseOp(choiceUser);
748 rewriter.eraseOp(op);
752void ControlMergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
753 MLIRContext *context) {
754 results.insert<EliminateSimpleControlMergesPattern>(context);
757bool BranchOp::sostIsControl() {
761void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
762 MLIRContext *context) {
763 results.insert<circt::handshake::EliminateSimpleBranchesPattern>(context);
766ParseResult BranchOp::parse(OpAsmParser &parser, OperationState &result) {
767 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
769 ArrayRef<Type> operandTypes(type);
770 SmallVector<Type, 1> dataOperandsTypes;
771 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
776 dataOperandsTypes.assign(size, type);
777 result.addTypes({type});
778 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
784void BranchOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
786ParseResult ConditionalBranchOp::parse(OpAsmParser &parser,
787 OperationState &result) {
788 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
790 SmallVector<Type> operandTypes;
791 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
792 if (parser.parseOperandList(allOperands) ||
793 parser.parseOptionalAttrDict(result.attributes) ||
794 parser.parseColonType(dataType))
797 if (allOperands.size() != 2)
798 return parser.emitError(parser.getCurrentLocation(),
799 "Expected exactly 2 operands");
801 result.addTypes({dataType, dataType});
802 operandTypes.push_back(IntegerType::get(parser.getContext(), 1));
803 operandTypes.push_back(dataType);
804 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
811void ConditionalBranchOp::print(OpAsmPrinter &p) {
812 Type type = getDataOperand().getType();
813 p <<
" " << getOperands();
814 p.printOptionalAttrDict((*this)->getAttrs());
818std::string handshake::ConditionalBranchOp::getOperandName(
unsigned int idx) {
819 assert(idx == 0 || idx == 1);
820 return idx == 0 ?
"cond" :
"data";
823std::string handshake::ConditionalBranchOp::getResultName(
unsigned int idx) {
824 assert(idx == 0 || idx == 1);
825 return idx == ConditionalBranchOp::falseIndex ?
"outFalse" :
"outTrue";
828bool ConditionalBranchOp::isControl() {
833ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
834 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
836 ArrayRef<Type> operandTypes(type);
837 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
842 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
848void SinkOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
850std::string handshake::ConstantOp::getOperandName(
unsigned int idx) {
855Type SourceOp::getDataType() {
return getResult().getType(); }
856unsigned SourceOp::getSize() {
return 1; }
858ParseResult SourceOp::parse(OpAsmParser &parser, OperationState &result) {
859 if (parser.parseOptionalAttrDict(result.attributes))
861 result.addTypes(NoneType::get(result.getContext()));
865void SourceOp::print(OpAsmPrinter &p) {
866 p.printOptionalAttrDict((*this)->getAttrs());
869LogicalResult ConstantOp::verify() {
871 auto typedValue = dyn_cast<mlir::TypedAttr>(getValue());
873 return emitOpError(
"constant value must be a typed attribute; value is ")
875 if (typedValue.getType() != getResult().getType())
876 return emitOpError() <<
"constant value type " << typedValue.getType()
877 <<
" differs from operation result type "
878 << getResult().getType();
883void handshake::ConstantOp::getCanonicalizationPatterns(
884 RewritePatternSet &results, MLIRContext *context) {
885 results.insert<circt::handshake::EliminateSunkConstantsPattern>(context);
888LogicalResult BufferOp::verify() {
891 if (
auto initVals = getInitValues()) {
894 <<
"only bufferType buffers are allowed to have initial values.";
896 auto nInits = initVals->size();
897 if (nInits != getSize())
898 return emitOpError() <<
"expected " << getSize()
899 <<
" init values but got " << nInits <<
".";
905void handshake::BufferOp::getCanonicalizationPatterns(
906 RewritePatternSet &results, MLIRContext *context) {
907 results.insert<circt::handshake::EliminateSunkBuffersPattern>(context);
910unsigned BufferOp::getSize() {
911 return (*this)->getAttrOfType<IntegerAttr>(
"slots").getValue().getZExtValue();
914ParseResult BufferOp::parse(OpAsmParser &parser, OperationState &result) {
915 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
917 ArrayRef<Type> operandTypes(type);
918 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
923 auto bufferTypeAttr = BufferTypeEnumAttr::parse(parser, {});
929 IntegerAttr::get(IntegerType::get(result.getContext(), 32), slots));
930 result.addAttribute(
"bufferType", bufferTypeAttr);
932 if (parser.parseOperandList(allOperands) ||
933 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
934 parser.parseType(type))
937 result.addTypes({type});
938 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
944void BufferOp::print(OpAsmPrinter &p) {
946 (*this)->getAttrOfType<IntegerAttr>(
"slots").getValue().getZExtValue();
947 p <<
" [" << size <<
"]";
948 p <<
" " << stringifyEnum(getBufferType());
949 p <<
" " << (*this)->getOperands();
950 p.printOptionalAttrDict((*this)->getAttrs(), {
"slots",
"bufferType"});
951 p <<
" : " << (*this).getDataType();
956 if (idx < nStores * 2) {
957 bool isData = idx % 2 == 0;
958 name = isData ?
"stData" + std::to_string(idx / 2)
959 :
"stAddr" + std::to_string(idx / 2);
962 name =
"ldAddr" + std::to_string(idx);
967std::string handshake::MemoryOp::getOperandName(
unsigned int idx) {
975 name =
"ldData" + std::to_string(idx);
976 else if (idx < nLoads + nStores)
977 name =
"stDone" + std::to_string(idx - nLoads);
979 name =
"ldDone" + std::to_string(idx - nLoads - nStores);
983std::string handshake::MemoryOp::getResultName(
unsigned int idx) {
987LogicalResult MemoryOp::verify() {
988 auto memrefType = getMemRefType();
990 if (memrefType.getNumDynamicDims() != 0)
992 <<
"memref dimensions for handshake.memory must be static.";
993 if (memrefType.getShape().size() != 1)
994 return emitOpError() <<
"memref must have only a single dimension.";
996 unsigned opStCount = getStCount();
997 unsigned opLdCount = getLdCount();
998 int addressCount = memrefType.getShape().size();
1000 auto inputType = getInputs().getType();
1001 auto outputType = getOutputs().getType();
1002 Type dataType = memrefType.getElementType();
1004 unsigned numOperands =
static_cast<int>(getInputs().size());
1005 unsigned numResults =
static_cast<int>(getOutputs().size());
1006 if (numOperands != (1 + addressCount) * opStCount + addressCount * opLdCount)
1007 return emitOpError(
"number of operands ")
1008 << numOperands <<
" does not match number expected of "
1009 << 2 * opStCount + opLdCount <<
" with " << addressCount
1010 <<
" address inputs per port";
1012 if (numResults != opStCount + 2 * opLdCount)
1013 return emitOpError(
"number of results ")
1014 << numResults <<
" does not match number expected of "
1015 << opStCount + 2 * opLdCount <<
" with " << addressCount
1016 <<
" address inputs per port";
1018 Type addressType = opStCount > 0 ? inputType[1] : inputType[0];
1020 for (
unsigned i = 0; i < opStCount; i++) {
1021 if (inputType[2 * i] != dataType)
1022 return emitOpError(
"data type for store port ")
1023 << i <<
":" << inputType[2 * i] <<
" doesn't match memory type "
1025 if (inputType[2 * i + 1] != addressType)
1026 return emitOpError(
"address type for store port ")
1027 << i <<
":" << inputType[2 * i + 1]
1028 <<
" doesn't match address type " << addressType;
1030 for (
unsigned i = 0; i < opLdCount; i++) {
1031 Type ldAddressType = inputType[2 * opStCount + i];
1032 if (ldAddressType != addressType)
1033 return emitOpError(
"address type for load port ")
1034 << i <<
":" << ldAddressType <<
" doesn't match address type "
1037 for (
unsigned i = 0; i < opLdCount; i++) {
1038 if (outputType[i] != dataType)
1039 return emitOpError(
"data type for load port ")
1040 << i <<
":" << outputType[i] <<
" doesn't match memory type "
1043 for (
unsigned i = 0; i < opStCount; i++) {
1044 Type syncType = outputType[opLdCount + i];
1045 if (!isa<NoneType>(syncType))
1046 return emitOpError(
"data type for sync port for store port ")
1047 << i <<
":" << syncType <<
" is not 'none'";
1049 for (
unsigned i = 0; i < opLdCount; i++) {
1050 Type syncType = outputType[opLdCount + opStCount + i];
1051 if (!isa<NoneType>(syncType))
1052 return emitOpError(
"data type for sync port for load port ")
1053 << i <<
":" << syncType <<
" is not 'none'";
1059std::string handshake::ExternalMemoryOp::getOperandName(
unsigned int idx) {
1066std::string handshake::ExternalMemoryOp::getResultName(
unsigned int idx) {
1070void ExternalMemoryOp::build(OpBuilder &builder, OperationState &result,
1071 Value memref, ValueRange inputs,
int ldCount,
1072 int stCount,
int id) {
1073 SmallVector<Value> ops;
1074 ops.push_back(memref);
1075 llvm::append_range(ops, inputs);
1076 result.addOperands(ops);
1078 auto memrefType = cast<MemRefType>(memref.getType());
1081 result.types.append(ldCount, memrefType.getElementType());
1084 result.types.append(stCount + ldCount, builder.getNoneType());
1087 Type i32Type = builder.getIntegerType(32);
1088 result.addAttribute(
"id", builder.getIntegerAttr(i32Type,
id));
1089 result.addAttribute(
"ldCount", builder.getIntegerAttr(i32Type, ldCount));
1090 result.addAttribute(
"stCount", builder.getIntegerAttr(i32Type, stCount));
1093llvm::SmallVector<handshake::MemLoadInterface>
1094ExternalMemoryOp::getLoadPorts() {
1095 return ::getLoadPorts(*
this);
1098llvm::SmallVector<handshake::MemStoreInterface>
1099ExternalMemoryOp::getStorePorts() {
1100 return ::getStorePorts(*
this);
1103void MemoryOp::build(OpBuilder &builder, OperationState &result,
1104 ValueRange operands,
int outputs,
int controlOutputs,
1105 bool lsq,
int id, Value memref) {
1106 result.addOperands(operands);
1108 auto memrefType = cast<MemRefType>(memref.getType());
1111 result.types.append(outputs, memrefType.getElementType());
1114 result.types.append(controlOutputs, builder.getNoneType());
1115 result.addAttribute(
"lsq", builder.getBoolAttr(lsq));
1116 result.addAttribute(
"memRefType", TypeAttr::get(memrefType));
1119 Type i32Type = builder.getIntegerType(32);
1120 result.addAttribute(
"id", builder.getIntegerAttr(i32Type,
id));
1123 result.addAttribute(
"ldCount", builder.getIntegerAttr(i32Type, outputs));
1124 result.addAttribute(
1125 "stCount", builder.getIntegerAttr(i32Type, controlOutputs - outputs));
1129llvm::SmallVector<handshake::MemLoadInterface> MemoryOp::getLoadPorts() {
1130 return ::getLoadPorts(*
this);
1133llvm::SmallVector<handshake::MemStoreInterface> MemoryOp::getStorePorts() {
1134 return ::getStorePorts(*
this);
1137bool handshake::MemoryOp::allocateMemory(
1138 llvm::DenseMap<unsigned, unsigned> &memoryMap,
1139 std::vector<std::vector<llvm::Any>> &store,
1140 std::vector<double> &storeTimes) {
1141 if (memoryMap.count(getId()))
1144 auto type = getMemRefType();
1145 std::vector<llvm::Any> in;
1147 ArrayRef<int64_t> shape = type.getShape();
1148 int allocationSize = 1;
1150 for (int64_t dim : shape) {
1152 allocationSize *= dim;
1154 assert(count < in.size());
1155 allocationSize *= llvm::any_cast<APInt>(in[count++]).getSExtValue();
1158 unsigned ptr = store.size();
1159 store.resize(ptr + 1);
1160 storeTimes.resize(ptr + 1);
1161 store[ptr].resize(allocationSize);
1162 storeTimes[ptr] = 0.0;
1165 for (
int i = 0; i < allocationSize; i++) {
1167 store[ptr][i] = APInt(width, 0);
1169 store[ptr][i] = APFloat(0.0);
1171 llvm_unreachable(
"Unknown result type!\n");
1175 memoryMap[getId()] = ptr;
1179std::string handshake::LoadOp::getOperandName(
unsigned int idx) {
1180 unsigned nAddresses = getAddresses().size();
1182 if (idx < nAddresses)
1183 opName =
"addrIn" + std::to_string(idx);
1184 else if (idx == nAddresses)
1185 opName =
"dataFromMem";
1191std::string handshake::LoadOp::getResultName(
unsigned int idx) {
1192 std::string resName;
1194 resName =
"dataOut";
1196 resName =
"addrOut" + std::to_string(idx - 1);
1200void handshake::LoadOp::build(OpBuilder &builder, OperationState &result,
1201 Value memref, ValueRange indices) {
1204 result.addOperands(indices);
1207 auto memrefType = cast<MemRefType>(memref.getType());
1210 result.types.push_back(memrefType.getElementType());
1213 result.types.append(indices.size(), builder.getIndexType());
1217 OperationState &result) {
1218 SmallVector<OpAsmParser::UnresolvedOperand, 4> addressOperands,
1219 remainingOperands, allOperands;
1220 SmallVector<Type, 1> parsedTypes, allTypes;
1221 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1223 if (parser.parseLSquare() || parser.parseOperandList(addressOperands) ||
1224 parser.parseRSquare() || parser.parseOperandList(remainingOperands) ||
1225 parser.parseColon() || parser.parseTypeList(parsedTypes))
1230 Type dataType = parsedTypes.back();
1231 auto parsedTypesRef = ArrayRef(parsedTypes);
1232 result.addTypes(dataType);
1233 result.addTypes(parsedTypesRef.drop_back());
1234 allOperands.append(addressOperands);
1235 allOperands.append(remainingOperands);
1236 allTypes.append(parsedTypes);
1237 allTypes.push_back(NoneType::get(result.getContext()));
1238 if (parser.resolveOperands(allOperands, allTypes, allOperandLoc,
1244template <
typename MemOp>
1247 p << op.getAddresses();
1248 p <<
"] " << op.getData() <<
", " << op.getCtrl() <<
" : ";
1249 llvm::interleaveComma(op.getAddresses(), p,
1250 [&](Value v) { p << v.getType(); });
1251 p <<
", " << op.getData().getType();
1254ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
1260std::string handshake::StoreOp::getOperandName(
unsigned int idx) {
1261 unsigned nAddresses = getAddresses().size();
1263 if (idx < nAddresses)
1264 opName =
"addrIn" + std::to_string(idx);
1265 else if (idx == nAddresses)
1272template <
typename TMemoryOp>
1274 if (op.getAddresses().size() == 0)
1275 return op.emitOpError() <<
"No addresses were specified";
1282std::string handshake::StoreOp::getResultName(
unsigned int idx) {
1283 std::string resName;
1285 resName =
"dataToMem";
1287 resName =
"addrOut" + std::to_string(idx - 1);
1291void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
1292 Value valueToStore, ValueRange indices) {
1295 result.addOperands(indices);
1298 result.addOperands(valueToStore);
1301 result.types.push_back(valueToStore.getType());
1304 result.types.append(indices.size(), builder.getIndexType());
1309ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
1315bool JoinOp::isControl() {
return true; }
1317ParseResult JoinOp::parse(OpAsmParser &parser, OperationState &result) {
1318 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1319 SmallVector<Type> types;
1321 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1322 if (parser.parseOperandList(operands) ||
1323 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1324 parser.parseTypeList(types))
1327 if (parser.resolveOperands(operands, types, allOperandLoc, result.operands))
1330 result.addTypes(NoneType::get(result.getContext()));
1334void JoinOp::print(OpAsmPrinter &p) {
1335 p <<
" " << getData();
1336 p.printOptionalAttrDict((*this)->getAttrs(), {
"control"});
1337 p <<
" : " << getData().getTypes();
1341ESIInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1343 auto fnAttr = this->getModuleAttr();
1344 assert(fnAttr &&
"requires a 'module' symbol reference attribute");
1346 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*
this, fnAttr);
1348 return emitOpError() <<
"'" << fnAttr.getValue()
1349 <<
"' does not reference a valid handshake function";
1352 auto fnType = fn.getFunctionType();
1353 if (fnType.getNumInputs() != getNumOperands() - NumFixedOperands)
1355 "incorrect number of operands for the referenced handshake function");
1357 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1358 Type operandType = getOperand(i + NumFixedOperands).getType();
1359 auto channelType = dyn_cast<esi::ChannelType>(operandType);
1361 return emitOpError(
"operand type mismatch: expected channel type, but "
1363 << operandType <<
" for operand number " << i;
1364 if (channelType.getInner() != fnType.getInput(i))
1365 return emitOpError(
"operand type mismatch: expected operand type ")
1366 << fnType.getInput(i) <<
", but provided "
1367 << getOperand(i).getType() <<
" for operand number " << i;
1370 if (fnType.getNumResults() != getNumResults())
1372 "incorrect number of results for the referenced handshake function");
1374 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1375 Type resultType = getResult(i).getType();
1376 auto channelType = dyn_cast<esi::ChannelType>(resultType);
1378 return emitOpError(
"result type mismatch: expected channel type, but "
1380 << resultType <<
" for result number " << i;
1381 if (channelType.getInner() != fnType.getResult(i))
1382 return emitOpError(
"result type mismatch: expected result type ")
1383 << fnType.getResult(i) <<
", but provided "
1384 << getResult(i).getType() <<
" for result number " << i;
1391LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1393 auto fnAttr = this->getModuleAttr();
1394 assert(fnAttr &&
"requires a 'module' symbol reference attribute");
1396 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*
this, fnAttr);
1398 return emitOpError() <<
"'" << fnAttr.getValue()
1399 <<
"' does not reference a valid handshake function";
1402 auto fnType = fn.getFunctionType();
1403 if (fnType.getNumInputs() != getNumOperands())
1405 "incorrect number of operands for the referenced handshake function");
1407 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
1408 if (getOperand(i).getType() != fnType.getInput(i))
1409 return emitOpError(
"operand type mismatch: expected operand type ")
1410 << fnType.getInput(i) <<
", but provided "
1411 << getOperand(i).getType() <<
" for operand number " << i;
1413 if (fnType.getNumResults() != getNumResults())
1415 "incorrect number of results for the referenced handshake function");
1417 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
1418 if (getResult(i).getType() != fnType.getResult(i))
1419 return emitOpError(
"result type mismatch: expected result type ")
1420 << fnType.getResult(i) <<
", but provided "
1421 << getResult(i).getType() <<
" for result number " << i;
1426FunctionType InstanceOp::getModuleType() {
1427 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
1430ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1431 OpAsmParser::UnresolvedOperand tuple;
1434 if (parser.parseOperand(tuple) ||
1435 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1436 parser.parseType(type))
1439 if (parser.resolveOperand(tuple, type, result.operands))
1442 result.addTypes(type.getTypes());
1447void UnpackOp::print(OpAsmPrinter &p) {
1448 p <<
" " << getInput();
1449 p.printOptionalAttrDict((*this)->getAttrs());
1450 p <<
" : " << getInput().getType();
1453ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1454 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1455 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1458 if (parser.parseOperandList(operands) ||
1459 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1460 parser.parseType(type))
1463 if (parser.resolveOperands(operands, type.getTypes(), allOperandLoc,
1467 result.addTypes(type);
1472void PackOp::print(OpAsmPrinter &p) {
1473 p <<
" " << getInputs();
1474 p.printOptionalAttrDict((*this)->getAttrs());
1475 p <<
" : " << getResult().getType();
1482LogicalResult ReturnOp::verify() {
1483 auto *parent = (*this)->getParentOp();
1484 auto function = dyn_cast<handshake::FuncOp>(parent);
1486 return emitOpError(
"must have a handshake.func parent");
1489 const auto &results = function.getResultTypes();
1490 if (getNumOperands() != results.size())
1491 return emitOpError(
"has ")
1492 << getNumOperands() <<
" operands, but enclosing function returns "
1495 for (
unsigned i = 0, e = results.size(); i != e; ++i)
1496 if (getOperand(i).getType() != results[i])
1497 return emitError() <<
"type of return operand " << i <<
" ("
1498 << getOperand(i).getType()
1499 <<
") doesn't match function result type ("
1500 << results[i] <<
")";
1505#define GET_OP_CLASSES
1506#include "circt/Dialect/Handshake/Handshake.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal, uint64_t numOperands)
Verifies whether an indexing value is wide enough to index into a provided number of operands.
static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result)
static ParseResult parseSostOperation(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, OperationState &result, int &size, Type &type, bool explicitSize)
static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op)
static ParseResult parseFuncOpArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Type > &resTypes, SmallVectorImpl< DictionaryAttr > &resAttrs)
Parses a FuncOp signature using mlir::function_interface_impl::parseFunctionSignature while getting a...
static Value getDematerialized(Value v)
Returns a dematerialized version of the value 'v', defined as the source of the value before passing ...
static bool isControlCheckTypeAndOperand(Type dataType, Value operand)
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores, unsigned idx)
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v)
static LogicalResult verifyMemoryAccessOp(TMemoryOp op)
static SmallVector< Attribute > getFuncOpNames(Builder &builder, unsigned cnt, StringRef prefix)
Generates names for a handshake.func input and output arguments, based on the number of args as well ...
static void addStringToStringArrayAttr(Builder &builder, Operation *op, StringRef attrName, StringAttr str)
Helper function for appending a string to an array attribute, and rewriting the attribute back to the...
static std::string defaultOperandName(unsigned int idx)
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
static std::string getMemoryOperandName(unsigned nStores, unsigned idx)
static ParseResult parseMemoryAccessOp(OpAsmParser &parser, OperationState &result)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
bool isControlOpImpl(Operation *op)
Default implementation for checking whether an operation is a control operation.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.