16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/SymbolTable.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/Interfaces/FunctionImplementation.h"
29 #include "mlir/Transforms/InliningUtils.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/SmallBitVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
36 using namespace circt;
41 #include "circt/Dialect/Handshake/HandshakeCanonicalization.h.inc"
44 if (
auto sostInterface = dyn_cast<SOSTInterface>(op); sostInterface)
45 return sostInterface.sostIsControl();
54 return "in" + std::to_string(idx);
58 if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
65 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
66 OperationState &result,
int &size, Type &type,
72 if (parser.parseOperandList(operands) ||
73 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
74 parser.parseType(type))
78 size = operands.size();
85 uint64_t numOperands) {
86 auto indexType = indexVal.getType();
90 if (
auto integerType = dyn_cast<IntegerType>(indexType))
91 indexWidth = integerType.getWidth();
92 else if (indexType.isIndex())
93 indexWidth = IndexType::kInternalStorageBitWidth;
95 return op->emitError(
"unsupported type for indexing value: ") << indexType;
98 if (indexWidth < 64) {
99 uint64_t maxNumOperands = (uint64_t)1 << indexWidth;
100 if (numOperands > maxNumOperands)
101 return op->emitError(
"bitwidth of indexing value is ")
102 << indexWidth <<
", which can index into " << maxNumOperands
103 <<
" operands, but found " << numOperands <<
" operands";
111 if (isa<NoneType>(dataType))
116 auto *defOp = operand.getDefiningOp();
117 return isa_and_nonnull<ControlMergeOp>(defOp) &&
118 operand == defOp->getResult(0);
121 template <
typename TMemOp>
122 llvm::SmallVector<handshake::MemLoadInterface>
getLoadPorts(TMemOp op) {
123 llvm::SmallVector<MemLoadInterface> ports;
130 unsigned stCount = op.getStCount();
131 unsigned ldCount = op.getLdCount();
132 for (
unsigned i = 0, e = ldCount; i != e; ++i) {
135 ldif.
addressIn = op.getInputs()[stCount * 2 + i];
136 ldif.
dataOut = op.getResult(i);
137 ldif.
doneOut = op.getResult(ldCount + stCount + i);
138 ports.push_back(ldif);
143 template <
typename TMemOp>
145 llvm::SmallVector<MemStoreInterface> ports;
152 unsigned ldCount = op.getLdCount();
153 for (
unsigned i = 0, e = op.getStCount(); i != e; ++i) {
156 stif.
dataIn = op.getInputs()[i * 2];
157 stif.
addressIn = op.getInputs()[i * 2 + 1];
158 stif.
doneOut = op.getResult(ldCount + i);
159 ports.push_back(stif);
164 unsigned ForkOp::getSize() {
return getResults().size(); }
166 static ParseResult
parseForkOp(OpAsmParser &parser, OperationState &result) {
167 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
169 ArrayRef<Type> operandTypes(type);
170 SmallVector<Type, 1> resultTypes;
171 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
176 resultTypes.assign(size, type);
177 result.addTypes(resultTypes);
178 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
184 ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
188 void ForkOp::print(OpAsmPrinter &p) { sostPrint(p,
true); }
195 LogicalResult matchAndRewrite(ForkOp op,
196 PatternRewriter &rewriter)
const override {
197 std::set<unsigned> unusedIndexes;
199 for (
auto res : llvm::enumerate(op.getResults()))
200 if (res.value().getUses().empty())
201 unusedIndexes.insert(res.index());
203 if (unusedIndexes.empty())
207 rewriter.setInsertionPoint(op);
208 auto operand = op.getOperand();
209 auto newFork = rewriter.create<ForkOp>(
210 op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
212 for (
auto oldRes : llvm::enumerate(op.getResults()))
213 if (unusedIndexes.count(oldRes.index()) == 0)
214 rewriter.replaceAllUsesWith(oldRes.value(), newFork.getResults()[i++]);
215 rewriter.eraseOp(op);
223 LogicalResult matchAndRewrite(ForkOp op,
224 PatternRewriter &rewriter)
const override {
225 auto parentForkOp = op.getOperand().getDefiningOp<ForkOp>();
233 unsigned totalNumOuts = op.getSize() + parentForkOp.getSize();
236 auto newParentForkOp = rewriter.create<ForkOp>(
237 parentForkOp.getLoc(), parentForkOp.getOperand(), totalNumOuts);
240 llvm::zip(parentForkOp->getResults(), newParentForkOp.getResults()))
241 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
245 rewriter.replaceOp(op,
246 newParentForkOp.getResults().take_back(op.getSize()));
247 rewriter.eraseOp(parentForkOp);
254 void handshake::ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
255 MLIRContext *context) {
256 results.insert<circt::handshake::EliminateSimpleForksPattern,
257 EliminateForkToForkPattern, EliminateUnusedForkResultsPattern>(
261 unsigned LazyForkOp::getSize() {
return getResults().size(); }
263 bool LazyForkOp::sostIsControl() {
267 ParseResult LazyForkOp::parse(OpAsmParser &parser, OperationState &result) {
271 void LazyForkOp::print(OpAsmPrinter &p) { sostPrint(p,
true); }
273 ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
274 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
276 ArrayRef<Type> operandTypes(type);
277 SmallVector<Type, 1> resultTypes, dataOperandsTypes;
278 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
283 dataOperandsTypes.assign(size, type);
284 resultTypes.push_back(type);
285 result.addTypes(resultTypes);
286 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
292 void MergeOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
294 void MergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
295 MLIRContext *context) {
296 results.insert<circt::handshake::EliminateSimpleMergesPattern>(context);
302 Operation *parentOp = v.getDefiningOp();
306 return llvm::TypeSwitch<Operation *, Value>(parentOp)
311 .Default([&](
auto) {
return v; });
321 LogicalResult matchAndRewrite(MuxOp op,
322 PatternRewriter &rewriter)
const override {
324 if (!llvm::all_of(op.getDataOperands(), [&](Value operand) {
325 return getDematerialized(operand) == firstDataOperand;
328 rewriter.replaceOp(op, firstDataOperand);
335 LogicalResult matchAndRewrite(MuxOp op,
336 PatternRewriter &rewriter)
const override {
337 if (op.getDataOperands().size() != 1)
340 rewriter.replaceOp(op, op.getDataOperands()[0]);
347 LogicalResult matchAndRewrite(MuxOp op,
348 PatternRewriter &rewriter)
const override {
350 auto dataOperands = op.getDataOperands();
351 if (dataOperands.size() != 2)
355 ConditionalBranchOp firstParentCBranch =
356 dataOperands[0].getDefiningOp<ConditionalBranchOp>();
357 if (!firstParentCBranch)
359 auto secondParentCBranch =
360 dataOperands[1].getDefiningOp<ConditionalBranchOp>();
361 if (!secondParentCBranch || firstParentCBranch != secondParentCBranch)
364 rewriter.modifyOpInPlace(firstParentCBranch, [&] {
366 rewriter.replaceOp(op, firstParentCBranch.getDataOperand());
375 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
376 MLIRContext *context) {
377 results.insert<EliminateSimpleMuxesPattern, EliminateUnaryMuxesPattern,
378 EliminateCBranchIntoMuxPattern>(context);
382 MuxOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
383 ValueRange operands, DictionaryAttr attributes,
384 mlir::OpaqueProperties properties,
385 mlir::RegionRange regions,
386 SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
389 if (operands.size() < 2)
392 inferredReturnTypes.push_back(operands[1].getType());
396 bool MuxOp::isControl() {
return isa<NoneType>(getResult().getType()); }
402 ParseResult MuxOp::parse(OpAsmParser &parser, OperationState &result) {
403 OpAsmParser::UnresolvedOperand selectOperand;
404 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
405 Type selectType, dataType;
406 SmallVector<Type, 1> dataOperandsTypes;
407 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
408 if (parser.parseOperand(selectOperand) || parser.parseLSquare() ||
409 parser.parseOperandList(allOperands) || parser.parseRSquare() ||
410 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
411 parser.parseType(selectType) || parser.parseComma() ||
412 parser.parseType(dataType))
415 int size = allOperands.size();
416 dataOperandsTypes.assign(size, dataType);
417 result.addTypes(dataType);
418 allOperands.insert(allOperands.begin(), selectOperand);
419 if (parser.resolveOperands(
421 llvm::concat<const Type>(ArrayRef<Type>(selectType),
422 ArrayRef<Type>(dataOperandsTypes)),
423 allOperandLoc, result.operands))
428 void MuxOp::print(OpAsmPrinter &p) {
429 Type selectType = getSelectOperand().getType();
430 auto ops = getOperands();
431 p <<
' ' << ops.front();
433 p.printOperands(ops.drop_front());
435 p.printOptionalAttrDict((*this)->getAttrs());
436 p <<
" : " << selectType <<
", " << getResult().getType();
441 getDataOperands().size());
444 std::string handshake::ControlMergeOp::getResultName(
unsigned int idx) {
445 assert(idx == 0 || idx == 1);
446 return idx == 0 ?
"dataOut" :
"index";
449 ParseResult ControlMergeOp::parse(OpAsmParser &parser, OperationState &result) {
450 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
451 Type resultType, indexType;
452 SmallVector<Type> resultTypes, dataOperandsTypes;
453 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
458 if (parser.parseComma() || parser.parseType(indexType))
461 dataOperandsTypes.assign(size, resultType);
462 resultTypes.push_back(resultType);
463 resultTypes.push_back(indexType);
464 result.addTypes(resultTypes);
465 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
471 void ControlMergeOp::print(OpAsmPrinter &p) {
474 p <<
", " << getIndex().getType();
478 auto operands = getOperands();
479 if (operands.empty())
480 return emitOpError(
"operation must have at least one operand");
481 if (operands[0].getType() != getResult().getType())
482 return emitOpError(
"type of first result should match type of operands");
494 auto fnInputTypes = getArgumentTypes();
495 Block &entryBlock = front();
497 for (
unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
498 if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
499 return emitOpError(
"type of entry block argument #")
500 << i <<
'(' << entryBlock.getArgument(i).getType()
501 <<
") must match the type of the corresponding argument in "
502 <<
"function signature(" << fnInputTypes[i] <<
')';
505 auto verifyPortNameAttr = [&](StringRef attrName,
506 unsigned numIOs) -> LogicalResult {
507 auto portNamesAttr = (*this)->getAttrOfType<ArrayAttr>(attrName);
510 return emitOpError() <<
"expected attribute '" << attrName <<
"'.";
512 auto portNames = portNamesAttr.getValue();
513 if (portNames.size() != numIOs)
514 return emitOpError() <<
"attribute '" << attrName <<
"' has "
516 <<
" entries but is expected to have " << numIOs
519 if (llvm::any_of(portNames,
520 [&](Attribute attr) {
return !isa<StringAttr>(attr); }))
521 return emitOpError() <<
"expected all entries in attribute '" << attrName
522 <<
"' to be strings.";
526 if (failed(verifyPortNameAttr(
"argNames", getNumArguments())))
528 if (failed(verifyPortNameAttr(
"resNames", getNumResults())))
532 for (
auto arg : entryBlock.getArguments()) {
533 if (!isa<MemRefType>(arg.getType()))
535 if (arg.getUsers().empty() ||
536 !isa<ExternalMemoryOp>(*arg.getUsers().begin()))
537 return emitOpError(
"expected that block argument #")
538 << arg.getArgNumber() <<
" is used by an 'extmemory' operation";
549 SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
550 SmallVectorImpl<Type> &resTypes,
551 SmallVectorImpl<DictionaryAttr> &resAttrs) {
553 if (mlir::function_interface_impl::parseFunctionSignature(
554 parser,
true, entryArgs, isVariadic, resTypes,
566 SmallVector<Attribute> resNames;
567 for (
unsigned i = 0; i < cnt; ++i)
568 resNames.push_back(builder.getStringAttr(prefix + std::to_string(i)));
572 void handshake::FuncOp::build(OpBuilder &builder, OperationState &state,
573 StringRef name, FunctionType type,
574 ArrayRef<NamedAttribute> attrs) {
575 state.addAttribute(SymbolTable::getSymbolAttrName(),
576 builder.getStringAttr(name));
577 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
579 state.attributes.append(attrs.begin(), attrs.end());
581 if (
const auto *argNamesAttrIt = llvm::find_if(
582 attrs, [&](
auto attr) {
return attr.getName() ==
"argNames"; });
583 argNamesAttrIt == attrs.end())
584 state.addAttribute(
"argNames", builder.getArrayAttr({}));
586 if (llvm::find_if(attrs, [&](
auto attr) {
587 return attr.getName() ==
"resNames";
589 state.addAttribute(
"resNames", builder.getArrayAttr({}));
597 StringRef attrName, StringAttr str) {
598 llvm::SmallVector<Attribute> attrs;
599 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
600 std::back_inserter(attrs));
601 attrs.push_back(str);
602 op->setAttr(attrName, builder.getArrayAttr(attrs));
605 void handshake::FuncOp::resolveArgAndResNames() {
606 Builder builder(getContext());
610 auto fallbackArgNames =
getFuncOpNames(builder, getNumArguments(),
"in");
611 auto fallbackResNames =
getFuncOpNames(builder, getNumResults(),
"out");
612 auto argNames = getArgNames().getValue();
613 auto resNames = getResNames().getValue();
616 auto resolveNames = [&](
auto &fallbackNames,
auto &actualNames,
617 StringRef attrName) {
618 for (
auto fallbackName : llvm::enumerate(fallbackNames)) {
619 if (actualNames.size() <= fallbackName.index())
621 cast<StringAttr>(fallbackName.value()));
624 resolveNames(fallbackArgNames, argNames,
"argNames");
625 resolveNames(fallbackResNames, resNames,
"resNames");
628 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
629 auto &builder = parser.getBuilder();
631 SmallVector<OpAsmParser::Argument> args;
632 SmallVector<Type> resTypes;
633 SmallVector<DictionaryAttr> resAttributes;
634 SmallVector<Attribute> argNames;
637 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
640 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
641 result.attributes) ||
644 mlir::function_interface_impl::addArgAndResultAttrs(
645 builder, result, args, resAttributes,
646 handshake::FuncOp::getArgAttrsAttrName(result.name),
647 handshake::FuncOp::getResAttrsAttrName(result.name));
650 SmallVector<Type> argTypes;
651 for (
auto arg : args)
652 argTypes.push_back(arg.type);
655 handshake::FuncOp::getFunctionTypeAttrName(result.name),
661 llvm::any_of(args, [](
auto arg) {
return arg.ssaName.name.empty(); });
665 llvm::transform(args, std::back_inserter(argNames), [&](
auto arg) {
666 return builder.getStringAttr(arg.ssaName.name.drop_front());
671 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
676 if (!result.attributes.get(
"argNames"))
677 result.addAttribute(
"argNames", builder.getArrayAttr(argNames));
678 if (!result.attributes.get(
"resNames")) {
680 result.addAttribute(
"resNames", builder.getArrayAttr(resNames));
685 auto *body = result.addRegion();
686 llvm::SMLoc loc = parser.getCurrentLocation();
687 auto parseResult = parser.parseOptionalRegion(*body, args,
689 if (!parseResult.has_value())
692 if (failed(*parseResult))
696 return parser.emitError(loc,
"expected non-empty function body");
702 void FuncOp::print(OpAsmPrinter &p) {
703 mlir::function_interface_impl::printFunctionOp(
704 p, *
this,
true, getFunctionTypeAttrName(),
705 getArgAttrsAttrName(), getResAttrsAttrName());
709 struct EliminateSimpleControlMergesPattern
713 LogicalResult matchAndRewrite(ControlMergeOp op,
714 PatternRewriter &rewriter)
const override;
718 LogicalResult EliminateSimpleControlMergesPattern::matchAndRewrite(
719 ControlMergeOp op, PatternRewriter &rewriter)
const {
720 auto dataResult = op.getResult();
721 auto choiceResult = op.getIndex();
722 auto choiceUnused = choiceResult.use_empty();
723 if (!choiceUnused && !choiceResult.hasOneUse())
726 Operation *choiceUser =
nullptr;
727 if (choiceResult.hasOneUse()) {
728 choiceUser = choiceResult.getUses().begin().getUser();
729 if (!isa<SinkOp>(choiceUser))
733 auto merge = rewriter.create<MergeOp>(op.getLoc(), op.getDataOperands());
735 for (
auto &use : llvm::make_early_inc_range(dataResult.getUses())) {
736 auto *user = use.getOwner();
737 rewriter.modifyOpInPlace(
738 user, [&]() { user->setOperand(use.getOperandNumber(), merge); });
742 rewriter.eraseOp(op);
746 rewriter.eraseOp(choiceUser);
747 rewriter.eraseOp(op);
751 void ControlMergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
752 MLIRContext *context) {
753 results.insert<EliminateSimpleControlMergesPattern>(context);
756 bool BranchOp::sostIsControl() {
760 void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
761 MLIRContext *context) {
762 results.insert<circt::handshake::EliminateSimpleBranchesPattern>(context);
765 ParseResult BranchOp::parse(OpAsmParser &parser, OperationState &result) {
766 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
768 ArrayRef<Type> operandTypes(type);
769 SmallVector<Type, 1> dataOperandsTypes;
770 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
775 dataOperandsTypes.assign(size, type);
776 result.addTypes({type});
777 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
783 void BranchOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
785 ParseResult ConditionalBranchOp::parse(OpAsmParser &parser,
786 OperationState &result) {
787 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
789 SmallVector<Type> operandTypes;
790 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
791 if (parser.parseOperandList(allOperands) ||
792 parser.parseOptionalAttrDict(result.attributes) ||
793 parser.parseColonType(dataType))
796 if (allOperands.size() != 2)
797 return parser.emitError(parser.getCurrentLocation(),
798 "Expected exactly 2 operands");
800 result.addTypes({dataType, dataType});
802 operandTypes.push_back(dataType);
803 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
810 void ConditionalBranchOp::print(OpAsmPrinter &p) {
811 Type type = getDataOperand().getType();
812 p <<
" " << getOperands();
813 p.printOptionalAttrDict((*this)->getAttrs());
818 assert(idx == 0 || idx == 1);
819 return idx == 0 ?
"cond" :
"data";
822 std::string handshake::ConditionalBranchOp::getResultName(
unsigned int idx) {
823 assert(idx == 0 || idx == 1);
824 return idx == ConditionalBranchOp::falseIndex ?
"outFalse" :
"outTrue";
827 bool ConditionalBranchOp::isControl() {
832 ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
833 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
835 ArrayRef<Type> operandTypes(type);
836 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
841 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
847 void SinkOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
854 Type SourceOp::getDataType() {
return getResult().getType(); }
855 unsigned SourceOp::getSize() {
return 1; }
857 ParseResult SourceOp::parse(OpAsmParser &parser, OperationState &result) {
858 if (parser.parseOptionalAttrDict(result.attributes))
864 void SourceOp::print(OpAsmPrinter &p) {
865 p.printOptionalAttrDict((*this)->getAttrs());
870 auto typedValue = dyn_cast<mlir::TypedAttr>(getValue());
872 return emitOpError(
"constant value must be a typed attribute; value is ")
874 if (typedValue.getType() != getResult().getType())
875 return emitOpError() <<
"constant value type " << typedValue.getType()
876 <<
" differs from operation result type "
877 << getResult().getType();
882 void handshake::ConstantOp::getCanonicalizationPatterns(
883 RewritePatternSet &results, MLIRContext *context) {
884 results.insert<circt::handshake::EliminateSunkConstantsPattern>(context);
890 if (
auto initVals = getInitValues()) {
893 <<
"only bufferType buffers are allowed to have initial values.";
895 auto nInits = initVals->size();
896 if (nInits != getSize())
897 return emitOpError() <<
"expected " << getSize()
898 <<
" init values but got " << nInits <<
".";
904 void handshake::BufferOp::getCanonicalizationPatterns(
905 RewritePatternSet &results, MLIRContext *context) {
906 results.insert<circt::handshake::EliminateSunkBuffersPattern>(context);
909 unsigned BufferOp::getSize() {
910 return (*this)->getAttrOfType<IntegerAttr>(
"slots").getValue().getZExtValue();
913 ParseResult BufferOp::parse(OpAsmParser &parser, OperationState &result) {
914 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
916 ArrayRef<Type> operandTypes(type);
917 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
922 auto bufferTypeAttr = BufferTypeEnumAttr::parse(parser, {});
929 result.addAttribute(
"bufferType", bufferTypeAttr);
931 if (parser.parseOperandList(allOperands) ||
932 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
933 parser.parseType(type))
936 result.addTypes({type});
937 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
943 void BufferOp::print(OpAsmPrinter &p) {
945 (*this)->getAttrOfType<IntegerAttr>(
"slots").getValue().getZExtValue();
946 p <<
" [" << size <<
"]";
947 p <<
" " << stringifyEnum(getBufferType());
948 p <<
" " << (*this)->getOperands();
949 p.printOptionalAttrDict((*this)->getAttrs(), {
"slots",
"bufferType"});
950 p <<
" : " << (*this).getDataType();
955 if (idx < nStores * 2) {
956 bool isData = idx % 2 == 0;
957 name = isData ?
"stData" + std::to_string(idx / 2)
958 :
"stAddr" + std::to_string(idx / 2);
961 name =
"ldAddr" + std::to_string(idx);
974 name =
"ldData" + std::to_string(idx);
975 else if (idx < nLoads + nStores)
976 name =
"stDone" + std::to_string(idx - nLoads);
978 name =
"ldDone" + std::to_string(idx - nLoads - nStores);
982 std::string handshake::MemoryOp::getResultName(
unsigned int idx) {
987 auto memrefType = getMemRefType();
989 if (memrefType.getNumDynamicDims() != 0)
991 <<
"memref dimensions for handshake.memory must be static.";
992 if (memrefType.getShape().size() != 1)
993 return emitOpError() <<
"memref must have only a single dimension.";
995 unsigned opStCount = getStCount();
996 unsigned opLdCount = getLdCount();
997 int addressCount = memrefType.getShape().size();
999 auto inputType = getInputs().getType();
1000 auto outputType = getOutputs().getType();
1001 Type dataType = memrefType.getElementType();
1003 unsigned numOperands =
static_cast<int>(getInputs().size());
1004 unsigned numResults =
static_cast<int>(getOutputs().size());
1005 if (numOperands != (1 + addressCount) * opStCount + addressCount * opLdCount)
1006 return emitOpError(
"number of operands ")
1007 << numOperands <<
" does not match number expected of "
1008 << 2 * opStCount + opLdCount <<
" with " << addressCount
1009 <<
" address inputs per port";
1011 if (numResults != opStCount + 2 * opLdCount)
1012 return emitOpError(
"number of results ")
1013 << numResults <<
" does not match number expected of "
1014 << opStCount + 2 * opLdCount <<
" with " << addressCount
1015 <<
" address inputs per port";
1017 Type addressType = opStCount > 0 ? inputType[1] : inputType[0];
1019 for (
unsigned i = 0; i < opStCount; i++) {
1020 if (inputType[2 * i] != dataType)
1021 return emitOpError(
"data type for store port ")
1022 << i <<
":" << inputType[2 * i] <<
" doesn't match memory type "
1024 if (inputType[2 * i + 1] != addressType)
1025 return emitOpError(
"address type for store port ")
1026 << i <<
":" << inputType[2 * i + 1]
1027 <<
" doesn't match address type " << addressType;
1029 for (
unsigned i = 0; i < opLdCount; i++) {
1030 Type ldAddressType = inputType[2 * opStCount + i];
1031 if (ldAddressType != addressType)
1032 return emitOpError(
"address type for load port ")
1033 << i <<
":" << ldAddressType <<
" doesn't match address type "
1036 for (
unsigned i = 0; i < opLdCount; i++) {
1037 if (outputType[i] != dataType)
1038 return emitOpError(
"data type for load port ")
1039 << i <<
":" << outputType[i] <<
" doesn't match memory type "
1042 for (
unsigned i = 0; i < opStCount; i++) {
1043 Type syncType = outputType[opLdCount + i];
1044 if (!isa<NoneType>(syncType))
1045 return emitOpError(
"data type for sync port for store port ")
1046 << i <<
":" << syncType <<
" is not 'none'";
1048 for (
unsigned i = 0; i < opLdCount; i++) {
1049 Type syncType = outputType[opLdCount + opStCount + i];
1050 if (!isa<NoneType>(syncType))
1051 return emitOpError(
"data type for sync port for load port ")
1052 << i <<
":" << syncType <<
" is not 'none'";
1065 std::string handshake::ExternalMemoryOp::getResultName(
unsigned int idx) {
1069 void ExternalMemoryOp::build(OpBuilder &builder, OperationState &result,
1070 Value memref, ValueRange inputs,
int ldCount,
1071 int stCount,
int id) {
1072 SmallVector<Value> ops;
1073 ops.push_back(memref);
1074 llvm::append_range(ops, inputs);
1075 result.addOperands(ops);
1077 auto memrefType = cast<MemRefType>(memref.getType());
1080 result.types.append(ldCount, memrefType.getElementType());
1083 result.types.append(stCount + ldCount, builder.getNoneType());
1086 Type i32Type = builder.getIntegerType(32);
1087 result.addAttribute(
"id", builder.getIntegerAttr(i32Type,
id));
1088 result.addAttribute(
"ldCount", builder.getIntegerAttr(i32Type, ldCount));
1089 result.addAttribute(
"stCount", builder.getIntegerAttr(i32Type, stCount));
1092 llvm::SmallVector<handshake::MemLoadInterface>
1097 llvm::SmallVector<handshake::MemStoreInterface>
1102 void MemoryOp::build(OpBuilder &builder, OperationState &result,
1103 ValueRange operands,
int outputs,
int controlOutputs,
1104 bool lsq,
int id, Value memref) {
1105 result.addOperands(operands);
1107 auto memrefType = cast<MemRefType>(memref.getType());
1110 result.types.append(outputs, memrefType.getElementType());
1113 result.types.append(controlOutputs, builder.getNoneType());
1114 result.addAttribute(
"lsq", builder.getBoolAttr(lsq));
1115 result.addAttribute(
"memRefType",
TypeAttr::get(memrefType));
1118 Type i32Type = builder.getIntegerType(32);
1119 result.addAttribute(
"id", builder.getIntegerAttr(i32Type,
id));
1122 result.addAttribute(
"ldCount", builder.getIntegerAttr(i32Type, outputs));
1123 result.addAttribute(
1124 "stCount", builder.getIntegerAttr(i32Type, controlOutputs - outputs));
1136 bool handshake::MemoryOp::allocateMemory(
1137 llvm::DenseMap<unsigned, unsigned> &memoryMap,
1138 std::vector<std::vector<llvm::Any>> &store,
1139 std::vector<double> &storeTimes) {
1140 if (memoryMap.count(getId()))
1143 auto type = getMemRefType();
1144 std::vector<llvm::Any> in;
1146 ArrayRef<int64_t> shape = type.getShape();
1147 int allocationSize = 1;
1149 for (int64_t dim : shape) {
1151 allocationSize *= dim;
1153 assert(count < in.size());
1154 allocationSize *= llvm::any_cast<APInt>(in[count++]).getSExtValue();
1157 unsigned ptr = store.size();
1158 store.resize(ptr + 1);
1159 storeTimes.resize(ptr + 1);
1160 store[ptr].resize(allocationSize);
1161 storeTimes[ptr] = 0.0;
1164 for (
int i = 0; i < allocationSize; i++) {
1166 store[ptr][i] = APInt(
width, 0);
1168 store[ptr][i] = APFloat(0.0);
1170 llvm_unreachable(
"Unknown result type!\n");
1174 memoryMap[getId()] = ptr;
1179 unsigned nAddresses = getAddresses().size();
1181 if (idx < nAddresses)
1182 opName =
"addrIn" + std::to_string(idx);
1183 else if (idx == nAddresses)
1184 opName =
"dataFromMem";
1190 std::string handshake::LoadOp::getResultName(
unsigned int idx) {
1191 std::string resName;
1193 resName =
"dataOut";
1195 resName =
"addrOut" + std::to_string(idx - 1);
1199 void handshake::LoadOp::build(OpBuilder &builder, OperationState &result,
1200 Value memref, ValueRange indices) {
1203 result.addOperands(indices);
1206 auto memrefType = cast<MemRefType>(memref.getType());
1209 result.types.push_back(memrefType.getElementType());
1212 result.types.append(indices.size(), builder.getIndexType());
1216 OperationState &result) {
1217 SmallVector<OpAsmParser::UnresolvedOperand, 4> addressOperands,
1218 remainingOperands, allOperands;
1219 SmallVector<Type, 1> parsedTypes, allTypes;
1220 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1222 if (parser.parseLSquare() || parser.parseOperandList(addressOperands) ||
1223 parser.parseRSquare() || parser.parseOperandList(remainingOperands) ||
1224 parser.parseColon() || parser.parseTypeList(parsedTypes))
1229 Type dataType = parsedTypes.back();
1230 auto parsedTypesRef = ArrayRef(parsedTypes);
1231 result.addTypes(dataType);
1232 result.addTypes(parsedTypesRef.drop_back());
1233 allOperands.append(addressOperands);
1234 allOperands.append(remainingOperands);
1235 allTypes.append(parsedTypes);
1237 if (parser.resolveOperands(allOperands, allTypes, allOperandLoc,
1243 template <
typename MemOp>
1246 p << op.getAddresses();
1247 p <<
"] " << op.getData() <<
", " << op.getCtrl() <<
" : ";
1248 llvm::interleaveComma(op.getAddresses(), p,
1249 [&](Value v) { p << v.getType(); });
1250 p <<
", " << op.getData().getType();
1253 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
1260 unsigned nAddresses = getAddresses().size();
1262 if (idx < nAddresses)
1263 opName =
"addrIn" + std::to_string(idx);
1264 else if (idx == nAddresses)
1271 template <
typename TMemoryOp>
1273 if (op.getAddresses().size() == 0)
1274 return op.emitOpError() <<
"No addresses were specified";
1281 std::string handshake::StoreOp::getResultName(
unsigned int idx) {
1282 std::string resName;
1284 resName =
"dataToMem";
1286 resName =
"addrOut" + std::to_string(idx - 1);
1290 void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
1291 Value valueToStore, ValueRange indices) {
1294 result.addOperands(indices);
1297 result.addOperands(valueToStore);
1300 result.types.push_back(valueToStore.getType());
1303 result.types.append(indices.size(), builder.getIndexType());
1308 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
1314 bool JoinOp::isControl() {
return true; }
1316 ParseResult JoinOp::parse(OpAsmParser &parser, OperationState &result) {
1317 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1318 SmallVector<Type> types;
1320 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1321 if (parser.parseOperandList(operands) ||
1322 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1323 parser.parseTypeList(types))
1326 if (parser.resolveOperands(operands, types, allOperandLoc, result.operands))
1333 void JoinOp::print(OpAsmPrinter &p) {
1334 p <<
" " << getData();
1335 p.printOptionalAttrDict((*this)->getAttrs(), {
"control"});
1336 p <<
" : " << getData().getTypes();
// Verifies the symbol use of an ESIInstanceOp against the handshake function
// it references.
//
// Shown here:
//  * the 'module' attribute is looked up as a FuncOp via `symbolTable`;
//  * the function's input count is compared with the operand count minus
//    NumFixedOperands, and its result count with the op's result count;
//  * every operand past the first NumFixedOperands, and every result, is cast
//    to esi::ChannelType and its inner type is compared with the matching
//    function input/result type.
// Each failed check reports through emitOpError.
//
// NOTE(review): several lines (the return type, some guarding conditions, and
// the end of the function) are not visible in this listing; the comments
// describe only what is shown.
1340 ESIInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// The symbol reference attribute is asserted to be present.
1342 auto fnAttr = this->getModuleAttr();
1343 assert(fnAttr &&
"requires a 'module' symbol reference attribute");
// Resolve the referenced symbol as a handshake FuncOp.
1345 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*
this, fnAttr);
1347 return emitOpError() <<
"'" << fnAttr.getValue()
1348 <<
"' does not reference a valid handshake function";
// Operand count check: the first NumFixedOperands operands are not function
// inputs, so they are subtracted before comparing.
1351 auto fnType = fn.getFunctionType();
1352 if (fnType.getNumInputs() != getNumOperands() - NumFixedOperands)
1354 "incorrect number of operands for the referenced handshake function");
// Per-operand type check; operand index is offset by NumFixedOperands.
1356 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1357 Type operandType = getOperand(i + NumFixedOperands).getType();
1358 auto channelType = dyn_cast<esi::ChannelType>(operandType);
1360 return emitOpError(
"operand type mismatch: expected channel type, but "
1362 << operandType <<
" for operand number " << i;
// NOTE(review): the diagnostic below prints getOperand(i).getType(), but the
// operand being checked is getOperand(i + NumFixedOperands) (see operandType
// above). The reported type therefore belongs to a different operand whenever
// NumFixedOperands is non-zero — likely should print operandType.
1363 if (channelType.getInner() != fnType.getInput(i))
1364 return emitOpError(
"operand type mismatch: expected operand type ")
1365 << fnType.getInput(i) <<
", but provided "
1366 << getOperand(i).getType() <<
" for operand number " << i;
// Result count check (no fixed-result offset is applied).
1369 if (fnType.getNumResults() != getNumResults())
1371 "incorrect number of results for the referenced handshake function");
// Per-result type check: the channel's inner type must equal the function's
// result type.
1373 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1374 Type resultType = getResult(i).getType();
1375 auto channelType = dyn_cast<esi::ChannelType>(resultType);
1377 return emitOpError(
"result type mismatch: expected channel type, but "
1379 << resultType <<
" for result number " << i;
1380 if (channelType.getInner() != fnType.getResult(i))
1381 return emitOpError(
"result type mismatch: expected result type ")
1382 << fnType.getResult(i) <<
", but provided "
1383 << getResult(i).getType() <<
" for result number " << i;
// Verifies the symbol use of an InstanceOp against the handshake function it
// references.
//
// Shown here:
//  * the 'module' attribute is looked up as a FuncOp via `symbolTable`;
//  * the function's input and result counts are compared with the op's
//    operand and result counts;
//  * each operand type and each result type is compared for exact equality
//    with the corresponding function input / result type.
// Each failed check reports through emitOpError.
//
// NOTE(review): some guarding conditions and the end of the function are not
// visible in this listing; the comments describe only what is shown.
1390 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// The symbol reference attribute is asserted to be present.
1392 auto fnAttr = this->getModuleAttr();
1393 assert(fnAttr &&
"requires a 'module' symbol reference attribute");
// Resolve the referenced symbol as a handshake FuncOp.
1395 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*
this, fnAttr);
1397 return emitOpError() <<
"'" << fnAttr.getValue()
1398 <<
"' does not reference a valid handshake function";
// Operand count must match the function's input count exactly.
1401 auto fnType = fn.getFunctionType();
1402 if (fnType.getNumInputs() != getNumOperands())
1404 "incorrect number of operands for the referenced handshake function");
// Operand i must have the same type as function input i.
1406 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
1407 if (getOperand(i).getType() != fnType.getInput(i))
1408 return emitOpError(
"operand type mismatch: expected operand type ")
1409 << fnType.getInput(i) <<
", but provided "
1410 << getOperand(i).getType() <<
" for operand number " << i;
// Result count must match the function's result count exactly.
1412 if (fnType.getNumResults() != getNumResults())
1414 "incorrect number of results for the referenced handshake function");
// Result i must have the same type as function result i.
1416 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
1417 if (getResult(i).getType() != fnType.getResult(i))
1418 return emitOpError(
"result type mismatch: expected result type ")
1419 << fnType.getResult(i) <<
", but provided "
1420 << getResult(i).getType() <<
" for result number " << i;
// Custom assembly parser for UnpackOp.
//
// Shown here: parses a single operand, an optional attribute dictionary, a
// colon and one type; resolves the operand against that type; then adds
// `type.getTypes()` as the op's result types.
//
// NOTE(review): the declaration of `type` and the failure-path statements are
// not visible in this listing. `type` is presumably a tuple type, since its
// getTypes() supplies the result types — confirm.
1429 ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1430 OpAsmParser::UnresolvedOperand tuple;
// Syntax: <operand> [attr-dict] `:` <type>
1433 if (parser.parseOperand(tuple) ||
1434 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1435 parser.parseType(type))
// The single operand takes the parsed type.
1438 if (parser.resolveOperand(tuple, type, result.operands))
// One result per element type reported by the parsed type.
1441 result.addTypes(type.getTypes());
// Custom assembly printer for UnpackOp.
//
// Emits, in order: a space and the input operand, the op's attribute
// dictionary, then " : " followed by the input operand's type.
1446 void UnpackOp::print(OpAsmPrinter &p) {
1447 p <<
" " << getInput();
1448 p.printOptionalAttrDict((*this)->getAttrs());
1449 p <<
" : " << getInput().getType();
// Custom assembly parser for PackOp.
//
// Shown here: parses an operand list, an optional attribute dictionary, a
// colon and one type; resolves the operands against `type.getTypes()`; then
// adds `type` itself as the op's single result type.
//
// NOTE(review): the declaration of `type`, the remainder of the
// resolveOperands call, and the failure-path statements are not visible in
// this listing.
1452 ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1453 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
// Location captured before the operand list; handed to resolveOperands below.
1454 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
// Syntax: <operands> [attr-dict] `:` <type>
1457 if (parser.parseOperandList(operands) ||
1458 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1459 parser.parseType(type))
// Operand types come from the element types of the parsed result type.
1462 if (parser.resolveOperands(operands, type.getTypes(), allOperandLoc,
// The parsed type is the type of the op's result.
1466 result.addTypes(type);
// Custom assembly printer for PackOp.
//
// Emits, in order: a space and the input operands, the op's attribute
// dictionary, then " : " followed by the result's type.
1471 void PackOp::print(OpAsmPrinter &p) {
1472 p <<
" " << getInputs();
1473 p.printOptionalAttrDict((*this)->getAttrs());
1474 p <<
" : " << getResult().getType();
// Verifier body: checks this op's operands against the result types of the
// enclosing handshake.func.
//
// NOTE(review): the signature line is not visible in this listing; judging by
// the diagnostics ("return operand", "enclosing function returns") this
// appears to be the return op's verifier — confirm.
//
// Shown here:
//  * the parent op is cast to handshake::FuncOp;
//  * the operand count is compared with the function's result count;
//  * each operand type is compared with the matching function result type.
1482 auto *parent = (*this)->getParentOp();
1483 auto function = dyn_cast<handshake::FuncOp>(parent);
1485 return emitOpError(
"must have a handshake.func parent");
// The operand count and types must match the function signature.
1488 const auto &results =
function.getResultTypes();
1489 if (getNumOperands() != results.size())
1490 return emitOpError(
"has ")
1491 << getNumOperands() <<
" operands, but enclosing function returns "
// NOTE(review): this diagnostic uses emitError() whereas the checks above use
// emitOpError() — confirm the difference is intentional.
1494 for (
unsigned i = 0, e = results.size(); i != e; ++i)
1495 if (getOperand(i).getType() != results[i])
1496 return emitError() <<
"type of return operand " << i <<
" ("
1497 << getOperand(i).getType()
1498 <<
") doesn't match function result type ("
1499 << results[i] <<
")";
1504 #define GET_OP_CLASSES
1505 #include "circt/Dialect/Handshake/Handshake.cpp.inc"
assert(baseType &&"element must be base type")
static StringRef getOperandName(Value operand)
static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal, uint64_t numOperands)
Verifies whether an indexing value is wide enough to index into a provided number of operands.
static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result)
static ParseResult parseSostOperation(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, OperationState &result, int &size, Type &type, bool explicitSize)
static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op)
static ParseResult parseFuncOpArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Type > &resTypes, SmallVectorImpl< DictionaryAttr > &resAttrs)
Parses a FuncOp signature using mlir::function_interface_impl::parseFunctionSignature while getting a...
static Value getDematerialized(Value v)
Returns a dematerialized version of the value 'v', defined as the source of the value before passing ...
static bool isControlCheckTypeAndOperand(Type dataType, Value operand)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores, unsigned idx)
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v)
static LogicalResult verifyMemoryAccessOp(TMemoryOp op)
static void addStringToStringArrayAttr(Builder &builder, Operation *op, StringRef attrName, StringAttr str)
Helper function for appending a string to an array attribute, and rewriting the attribute back to the...
static std::string defaultOperandName(unsigned int idx)
static SmallVector< Attribute > getFuncOpNames(Builder &builder, unsigned cnt, StringRef prefix)
Generates names for the input and output arguments of a handshake.func, based on the number of args as well ...
static std::string getMemoryOperandName(unsigned nStores, unsigned idx)
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
static ParseResult parseMemoryAccessOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
bool isControlOpImpl(Operation *op)
Default implementation for checking whether an operation is a control operation.
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.