15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/IntegerSet.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/Interfaces/FunctionImplementation.h"
28 #include "mlir/Transforms/InliningUtils.h"
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/TypeSwitch.h"
35 using namespace circt;
40 #include "circt/Dialect/Handshake/HandshakeCanonicalization.h.inc"
43 if (
auto sostInterface = dyn_cast<SOSTInterface>(op); sostInterface)
44 return sostInterface.sostIsControl();
53 return "in" + std::to_string(idx);
57 if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
64 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
65 OperationState &result,
int &size, Type &type,
71 if (parser.parseOperandList(operands) ||
72 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
73 parser.parseType(type))
77 size = operands.size();
84 uint64_t numOperands) {
85 auto indexType = indexVal.getType();
89 if (
auto integerType = dyn_cast<IntegerType>(indexType))
90 indexWidth = integerType.getWidth();
91 else if (indexType.isIndex())
92 indexWidth = IndexType::kInternalStorageBitWidth;
94 return op->emitError(
"unsupported type for indexing value: ") << indexType;
97 if (indexWidth < 64) {
98 uint64_t maxNumOperands = (uint64_t)1 << indexWidth;
99 if (numOperands > maxNumOperands)
100 return op->emitError(
"bitwidth of indexing value is ")
101 << indexWidth <<
", which can index into " << maxNumOperands
102 <<
" operands, but found " << numOperands <<
" operands";
110 if (isa<NoneType>(dataType))
115 auto *defOp = operand.getDefiningOp();
116 return isa_and_nonnull<ControlMergeOp>(defOp) &&
117 operand == defOp->getResult(0);
120 template <
typename TMemOp>
121 llvm::SmallVector<handshake::MemLoadInterface>
getLoadPorts(TMemOp op) {
122 llvm::SmallVector<MemLoadInterface> ports;
129 unsigned stCount = op.getStCount();
130 unsigned ldCount = op.getLdCount();
131 for (
unsigned i = 0, e = ldCount; i != e; ++i) {
134 ldif.
addressIn = op.getInputs()[stCount * 2 + i];
135 ldif.
dataOut = op.getResult(i);
136 ldif.
doneOut = op.getResult(ldCount + stCount + i);
137 ports.push_back(ldif);
142 template <
typename TMemOp>
144 llvm::SmallVector<MemStoreInterface> ports;
151 unsigned ldCount = op.getLdCount();
152 for (
unsigned i = 0, e = op.getStCount(); i != e; ++i) {
155 stif.
dataIn = op.getInputs()[i * 2];
156 stif.
addressIn = op.getInputs()[i * 2 + 1];
157 stif.
doneOut = op.getResult(ldCount + i);
158 ports.push_back(stif);
163 unsigned ForkOp::getSize() {
return getResults().size(); }
165 static ParseResult
parseForkOp(OpAsmParser &parser, OperationState &result) {
166 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
168 ArrayRef<Type> operandTypes(type);
169 SmallVector<Type, 1> resultTypes;
170 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
175 resultTypes.assign(size, type);
176 result.addTypes(resultTypes);
177 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
183 ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
187 void ForkOp::print(OpAsmPrinter &p) { sostPrint(p,
true); }
194 LogicalResult matchAndRewrite(ForkOp op,
195 PatternRewriter &rewriter)
const override {
196 std::set<unsigned> unusedIndexes;
198 for (
auto res : llvm::enumerate(op.getResults()))
199 if (res.value().getUses().empty())
200 unusedIndexes.insert(res.index());
202 if (unusedIndexes.empty())
206 rewriter.setInsertionPoint(op);
207 auto operand = op.getOperand();
208 auto newFork = rewriter.create<ForkOp>(
209 op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
211 for (
auto oldRes : llvm::enumerate(op.getResults()))
212 if (unusedIndexes.count(oldRes.index()) == 0)
213 rewriter.replaceAllUsesWith(oldRes.value(), newFork.getResults()[i++]);
214 rewriter.eraseOp(op);
222 LogicalResult matchAndRewrite(ForkOp op,
223 PatternRewriter &rewriter)
const override {
224 auto parentForkOp = op.getOperand().getDefiningOp<ForkOp>();
232 unsigned totalNumOuts = op.getSize() + parentForkOp.getSize();
235 auto newParentForkOp = rewriter.create<ForkOp>(
236 parentForkOp.getLoc(), parentForkOp.getOperand(), totalNumOuts);
239 llvm::zip(parentForkOp->getResults(), newParentForkOp.getResults()))
240 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
244 rewriter.replaceOp(op,
245 newParentForkOp.getResults().take_back(op.getSize()));
246 rewriter.eraseOp(parentForkOp);
253 void handshake::ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
254 MLIRContext *context) {
255 results.insert<circt::handshake::EliminateSimpleForksPattern,
256 EliminateForkToForkPattern, EliminateUnusedForkResultsPattern>(
260 unsigned LazyForkOp::getSize() {
return getResults().size(); }
262 bool LazyForkOp::sostIsControl() {
266 ParseResult LazyForkOp::parse(OpAsmParser &parser, OperationState &result) {
270 void LazyForkOp::print(OpAsmPrinter &p) { sostPrint(p,
true); }
272 ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
273 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
275 ArrayRef<Type> operandTypes(type);
276 SmallVector<Type, 1> resultTypes, dataOperandsTypes;
277 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
282 dataOperandsTypes.assign(size, type);
283 resultTypes.push_back(type);
284 result.addTypes(resultTypes);
285 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
291 void MergeOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
293 void MergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
294 MLIRContext *context) {
295 results.insert<circt::handshake::EliminateSimpleMergesPattern>(context);
301 Operation *parentOp = v.getDefiningOp();
305 return llvm::TypeSwitch<Operation *, Value>(parentOp)
310 .Default([&](
auto) {
return v; });
320 LogicalResult matchAndRewrite(MuxOp op,
321 PatternRewriter &rewriter)
const override {
323 if (!llvm::all_of(op.getDataOperands(), [&](Value operand) {
324 return getDematerialized(operand) == firstDataOperand;
327 rewriter.replaceOp(op, firstDataOperand);
334 LogicalResult matchAndRewrite(MuxOp op,
335 PatternRewriter &rewriter)
const override {
336 if (op.getDataOperands().size() != 1)
339 rewriter.replaceOp(op, op.getDataOperands()[0]);
346 LogicalResult matchAndRewrite(MuxOp op,
347 PatternRewriter &rewriter)
const override {
349 auto dataOperands = op.getDataOperands();
350 if (dataOperands.size() != 2)
354 ConditionalBranchOp firstParentCBranch =
355 dataOperands[0].getDefiningOp<ConditionalBranchOp>();
356 if (!firstParentCBranch)
358 auto secondParentCBranch =
359 dataOperands[1].getDefiningOp<ConditionalBranchOp>();
360 if (!secondParentCBranch || firstParentCBranch != secondParentCBranch)
363 rewriter.modifyOpInPlace(firstParentCBranch, [&] {
365 rewriter.replaceOp(op, firstParentCBranch.getDataOperand());
374 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
375 MLIRContext *context) {
376 results.insert<EliminateSimpleMuxesPattern, EliminateUnaryMuxesPattern,
377 EliminateCBranchIntoMuxPattern>(context);
382 ValueRange operands, DictionaryAttr attributes,
383 mlir::OpaqueProperties properties,
384 mlir::RegionRange regions,
385 SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
388 if (operands.size() < 2)
391 inferredReturnTypes.push_back(operands[1].getType());
395 bool MuxOp::isControl() {
return isa<NoneType>(getResult().getType()); }
401 ParseResult MuxOp::parse(OpAsmParser &parser, OperationState &result) {
402 OpAsmParser::UnresolvedOperand selectOperand;
403 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
404 Type selectType, dataType;
405 SmallVector<Type, 1> dataOperandsTypes;
406 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
407 if (parser.parseOperand(selectOperand) || parser.parseLSquare() ||
408 parser.parseOperandList(allOperands) || parser.parseRSquare() ||
409 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
410 parser.parseType(selectType) || parser.parseComma() ||
411 parser.parseType(dataType))
414 int size = allOperands.size();
415 dataOperandsTypes.assign(size, dataType);
416 result.addTypes(dataType);
417 allOperands.insert(allOperands.begin(), selectOperand);
418 if (parser.resolveOperands(
420 llvm::concat<const Type>(ArrayRef<Type>(selectType),
421 ArrayRef<Type>(dataOperandsTypes)),
422 allOperandLoc, result.operands))
427 void MuxOp::print(OpAsmPrinter &p) {
428 Type selectType = getSelectOperand().getType();
429 auto ops = getOperands();
430 p <<
' ' << ops.front();
432 p.printOperands(ops.drop_front());
434 p.printOptionalAttrDict((*this)->getAttrs());
435 p <<
" : " << selectType <<
", " << getResult().getType();
440 getDataOperands().size());
443 std::string handshake::ControlMergeOp::getResultName(
unsigned int idx) {
444 assert(idx == 0 || idx == 1);
445 return idx == 0 ?
"dataOut" :
"index";
448 ParseResult ControlMergeOp::parse(OpAsmParser &parser, OperationState &result) {
449 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
450 Type resultType, indexType;
451 SmallVector<Type> resultTypes, dataOperandsTypes;
452 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
457 if (parser.parseComma() || parser.parseType(indexType))
460 dataOperandsTypes.assign(size, resultType);
461 resultTypes.push_back(resultType);
462 resultTypes.push_back(indexType);
463 result.addTypes(resultTypes);
464 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
470 void ControlMergeOp::print(OpAsmPrinter &p) {
473 p <<
", " << getIndex().getType();
477 auto operands = getOperands();
478 if (operands.empty())
479 return emitOpError(
"operation must have at least one operand");
480 if (operands[0].getType() != getResult().getType())
481 return emitOpError(
"type of first result should match type of operands");
493 auto fnInputTypes = getArgumentTypes();
494 Block &entryBlock = front();
496 for (
unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
497 if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
498 return emitOpError(
"type of entry block argument #")
499 << i <<
'(' << entryBlock.getArgument(i).getType()
500 <<
") must match the type of the corresponding argument in "
501 <<
"function signature(" << fnInputTypes[i] <<
')';
504 auto verifyPortNameAttr = [&](StringRef attrName,
505 unsigned numIOs) -> LogicalResult {
506 auto portNamesAttr = (*this)->getAttrOfType<ArrayAttr>(attrName);
509 return emitOpError() <<
"expected attribute '" << attrName <<
"'.";
511 auto portNames = portNamesAttr.getValue();
512 if (portNames.size() != numIOs)
513 return emitOpError() <<
"attribute '" << attrName <<
"' has "
515 <<
" entries but is expected to have " << numIOs
518 if (llvm::any_of(portNames,
519 [&](Attribute attr) {
return !isa<StringAttr>(attr); }))
520 return emitOpError() <<
"expected all entries in attribute '" << attrName
521 <<
"' to be strings.";
525 if (failed(verifyPortNameAttr(
"argNames", getNumArguments())))
527 if (failed(verifyPortNameAttr(
"resNames", getNumResults())))
531 for (
auto arg : entryBlock.getArguments()) {
532 if (!isa<MemRefType>(arg.getType()))
534 if (arg.getUsers().empty() ||
535 !isa<ExternalMemoryOp>(*arg.getUsers().begin()))
536 return emitOpError(
"expected that block argument #")
537 << arg.getArgNumber() <<
" is used by an 'extmemory' operation";
548 SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
549 SmallVectorImpl<Type> &resTypes,
550 SmallVectorImpl<DictionaryAttr> &resAttrs) {
552 if (mlir::function_interface_impl::parseFunctionSignature(
553 parser,
true, entryArgs, isVariadic, resTypes,
565 SmallVector<Attribute> resNames;
566 for (
unsigned i = 0; i < cnt; ++i)
567 resNames.push_back(builder.getStringAttr(prefix + std::to_string(i)));
571 void handshake::FuncOp::build(OpBuilder &builder, OperationState &state,
572 StringRef name, FunctionType type,
573 ArrayRef<NamedAttribute> attrs) {
574 state.addAttribute(SymbolTable::getSymbolAttrName(),
575 builder.getStringAttr(name));
576 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
578 state.attributes.append(attrs.begin(), attrs.end());
580 if (
const auto *argNamesAttrIt = llvm::find_if(
581 attrs, [&](
auto attr) {
return attr.getName() ==
"argNames"; });
582 argNamesAttrIt == attrs.end())
583 state.addAttribute(
"argNames", builder.getArrayAttr({}));
585 if (llvm::find_if(attrs, [&](
auto attr) {
586 return attr.getName() ==
"resNames";
588 state.addAttribute(
"resNames", builder.getArrayAttr({}));
596 StringRef attrName, StringAttr str) {
597 llvm::SmallVector<Attribute> attrs;
598 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
599 std::back_inserter(attrs));
600 attrs.push_back(str);
601 op->setAttr(attrName, builder.getArrayAttr(attrs));
604 void handshake::FuncOp::resolveArgAndResNames() {
605 Builder builder(getContext());
609 auto fallbackArgNames =
getFuncOpNames(builder, getNumArguments(),
"in");
610 auto fallbackResNames =
getFuncOpNames(builder, getNumResults(),
"out");
611 auto argNames = getArgNames().getValue();
612 auto resNames = getResNames().getValue();
615 auto resolveNames = [&](
auto &fallbackNames,
auto &actualNames,
616 StringRef attrName) {
617 for (
auto fallbackName : llvm::enumerate(fallbackNames)) {
618 if (actualNames.size() <= fallbackName.index())
620 cast<StringAttr>(fallbackName.value()));
623 resolveNames(fallbackArgNames, argNames,
"argNames");
624 resolveNames(fallbackResNames, resNames,
"resNames");
627 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
628 auto &builder = parser.getBuilder();
630 SmallVector<OpAsmParser::Argument> args;
631 SmallVector<Type> resTypes;
632 SmallVector<DictionaryAttr> resAttributes;
633 SmallVector<Attribute> argNames;
636 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
639 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
640 result.attributes) ||
643 mlir::function_interface_impl::addArgAndResultAttrs(
644 builder, result, args, resAttributes,
645 handshake::FuncOp::getArgAttrsAttrName(result.name),
646 handshake::FuncOp::getResAttrsAttrName(result.name));
649 SmallVector<Type> argTypes;
650 for (
auto arg : args)
651 argTypes.push_back(arg.type);
654 handshake::FuncOp::getFunctionTypeAttrName(result.name),
660 llvm::any_of(args, [](
auto arg) {
return arg.ssaName.name.empty(); });
664 llvm::transform(args, std::back_inserter(argNames), [&](
auto arg) {
665 return builder.getStringAttr(arg.ssaName.name.drop_front());
670 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
675 if (!result.attributes.get(
"argNames"))
676 result.addAttribute(
"argNames", builder.getArrayAttr(argNames));
677 if (!result.attributes.get(
"resNames")) {
679 result.addAttribute(
"resNames", builder.getArrayAttr(resNames));
684 auto *body = result.addRegion();
685 llvm::SMLoc loc = parser.getCurrentLocation();
686 auto parseResult = parser.parseOptionalRegion(*body, args,
688 if (!parseResult.has_value())
691 if (failed(*parseResult))
695 return parser.emitError(loc,
"expected non-empty function body");
701 void FuncOp::print(OpAsmPrinter &p) {
702 mlir::function_interface_impl::printFunctionOp(
703 p, *
this,
true, getFunctionTypeAttrName(),
704 getArgAttrsAttrName(), getResAttrsAttrName());
708 struct EliminateSimpleControlMergesPattern
712 LogicalResult matchAndRewrite(ControlMergeOp op,
713 PatternRewriter &rewriter)
const override;
717 LogicalResult EliminateSimpleControlMergesPattern::matchAndRewrite(
718 ControlMergeOp op, PatternRewriter &rewriter)
const {
719 auto dataResult = op.getResult();
720 auto choiceResult = op.getIndex();
721 auto choiceUnused = choiceResult.use_empty();
722 if (!choiceUnused && !choiceResult.hasOneUse())
725 Operation *choiceUser =
nullptr;
726 if (choiceResult.hasOneUse()) {
727 choiceUser = choiceResult.getUses().begin().getUser();
728 if (!isa<SinkOp>(choiceUser))
732 auto merge = rewriter.create<MergeOp>(op.getLoc(), op.getDataOperands());
734 for (
auto &use : llvm::make_early_inc_range(dataResult.getUses())) {
735 auto *user = use.getOwner();
736 rewriter.modifyOpInPlace(
737 user, [&]() { user->setOperand(use.getOperandNumber(), merge); });
741 rewriter.eraseOp(op);
745 rewriter.eraseOp(choiceUser);
746 rewriter.eraseOp(op);
750 void ControlMergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
751 MLIRContext *context) {
752 results.insert<EliminateSimpleControlMergesPattern>(context);
755 bool BranchOp::sostIsControl() {
759 void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
760 MLIRContext *context) {
761 results.insert<circt::handshake::EliminateSimpleBranchesPattern>(context);
764 ParseResult BranchOp::parse(OpAsmParser &parser, OperationState &result) {
765 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
767 ArrayRef<Type> operandTypes(type);
768 SmallVector<Type, 1> dataOperandsTypes;
769 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
774 dataOperandsTypes.assign(size, type);
775 result.addTypes({type});
776 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
782 void BranchOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
784 ParseResult ConditionalBranchOp::parse(OpAsmParser &parser,
785 OperationState &result) {
786 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
788 SmallVector<Type> operandTypes;
789 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
790 if (parser.parseOperandList(allOperands) ||
791 parser.parseOptionalAttrDict(result.attributes) ||
792 parser.parseColonType(dataType))
795 if (allOperands.size() != 2)
796 return parser.emitError(parser.getCurrentLocation(),
797 "Expected exactly 2 operands");
799 result.addTypes({dataType, dataType});
801 operandTypes.push_back(dataType);
802 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
809 void ConditionalBranchOp::print(OpAsmPrinter &p) {
810 Type type = getDataOperand().getType();
811 p <<
" " << getOperands();
812 p.printOptionalAttrDict((*this)->getAttrs());
817 assert(idx == 0 || idx == 1);
818 return idx == 0 ?
"cond" :
"data";
821 std::string handshake::ConditionalBranchOp::getResultName(
unsigned int idx) {
822 assert(idx == 0 || idx == 1);
823 return idx == ConditionalBranchOp::falseIndex ?
"outFalse" :
"outTrue";
826 bool ConditionalBranchOp::isControl() {
831 ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
832 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
834 ArrayRef<Type> operandTypes(type);
835 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
840 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
846 void SinkOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
853 Type SourceOp::getDataType() {
return getResult().getType(); }
854 unsigned SourceOp::getSize() {
return 1; }
856 ParseResult SourceOp::parse(OpAsmParser &parser, OperationState &result) {
857 if (parser.parseOptionalAttrDict(result.attributes))
863 void SourceOp::print(OpAsmPrinter &p) {
864 p.printOptionalAttrDict((*this)->getAttrs());
869 auto typedValue = dyn_cast<mlir::TypedAttr>(getValue());
871 return emitOpError(
"constant value must be a typed attribute; value is ")
873 if (typedValue.getType() != getResult().getType())
874 return emitOpError() <<
"constant value type " << typedValue.getType()
875 <<
" differs from operation result type "
876 << getResult().getType();
881 void handshake::ConstantOp::getCanonicalizationPatterns(
882 RewritePatternSet &results, MLIRContext *context) {
883 results.insert<circt::handshake::EliminateSunkConstantsPattern>(context);
889 if (
auto initVals = getInitValues()) {
892 <<
"only bufferType buffers are allowed to have initial values.";
894 auto nInits = initVals->size();
895 if (nInits != getSize())
896 return emitOpError() <<
"expected " << getSize()
897 <<
" init values but got " << nInits <<
".";
903 void handshake::BufferOp::getCanonicalizationPatterns(
904 RewritePatternSet &results, MLIRContext *context) {
905 results.insert<circt::handshake::EliminateSunkBuffersPattern>(context);
908 unsigned BufferOp::getSize() {
909 return (*this)->getAttrOfType<IntegerAttr>(
"slots").getValue().getZExtValue();
912 ParseResult BufferOp::parse(OpAsmParser &parser, OperationState &result) {
913 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
915 ArrayRef<Type> operandTypes(type);
916 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
921 auto bufferTypeAttr = BufferTypeEnumAttr::parse(parser, {});
928 result.addAttribute(
"bufferType", bufferTypeAttr);
930 if (parser.parseOperandList(allOperands) ||
931 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
932 parser.parseType(type))
935 result.addTypes({type});
936 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
942 void BufferOp::print(OpAsmPrinter &p) {
944 (*this)->getAttrOfType<IntegerAttr>(
"slots").getValue().getZExtValue();
945 p <<
" [" << size <<
"]";
946 p <<
" " << stringifyEnum(getBufferType());
947 p <<
" " << (*this)->getOperands();
948 p.printOptionalAttrDict((*this)->getAttrs(), {
"slots",
"bufferType"});
949 p <<
" : " << (*this).getDataType();
954 if (idx < nStores * 2) {
955 bool isData = idx % 2 == 0;
956 name = isData ?
"stData" + std::to_string(idx / 2)
957 :
"stAddr" + std::to_string(idx / 2);
960 name =
"ldAddr" + std::to_string(idx);
973 name =
"ldData" + std::to_string(idx);
974 else if (idx < nLoads + nStores)
975 name =
"stDone" + std::to_string(idx - nLoads);
977 name =
"ldDone" + std::to_string(idx - nLoads - nStores);
981 std::string handshake::MemoryOp::getResultName(
unsigned int idx) {
986 auto memrefType = getMemRefType();
988 if (memrefType.getNumDynamicDims() != 0)
990 <<
"memref dimensions for handshake.memory must be static.";
991 if (memrefType.getShape().size() != 1)
992 return emitOpError() <<
"memref must have only a single dimension.";
994 unsigned opStCount = getStCount();
995 unsigned opLdCount = getLdCount();
996 int addressCount = memrefType.getShape().size();
998 auto inputType = getInputs().getType();
999 auto outputType = getOutputs().getType();
1000 Type dataType = memrefType.getElementType();
1002 unsigned numOperands =
static_cast<int>(getInputs().size());
1003 unsigned numResults =
static_cast<int>(getOutputs().size());
1004 if (numOperands != (1 + addressCount) * opStCount + addressCount * opLdCount)
1005 return emitOpError(
"number of operands ")
1006 << numOperands <<
" does not match number expected of "
1007 << 2 * opStCount + opLdCount <<
" with " << addressCount
1008 <<
" address inputs per port";
1010 if (numResults != opStCount + 2 * opLdCount)
1011 return emitOpError(
"number of results ")
1012 << numResults <<
" does not match number expected of "
1013 << opStCount + 2 * opLdCount <<
" with " << addressCount
1014 <<
" address inputs per port";
1016 Type addressType = opStCount > 0 ? inputType[1] : inputType[0];
1018 for (
unsigned i = 0; i < opStCount; i++) {
1019 if (inputType[2 * i] != dataType)
1020 return emitOpError(
"data type for store port ")
1021 << i <<
":" << inputType[2 * i] <<
" doesn't match memory type "
1023 if (inputType[2 * i + 1] != addressType)
1024 return emitOpError(
"address type for store port ")
1025 << i <<
":" << inputType[2 * i + 1]
1026 <<
" doesn't match address type " << addressType;
1028 for (
unsigned i = 0; i < opLdCount; i++) {
1029 Type ldAddressType = inputType[2 * opStCount + i];
1030 if (ldAddressType != addressType)
1031 return emitOpError(
"address type for load port ")
1032 << i <<
":" << ldAddressType <<
" doesn't match address type "
1035 for (
unsigned i = 0; i < opLdCount; i++) {
1036 if (outputType[i] != dataType)
1037 return emitOpError(
"data type for load port ")
1038 << i <<
":" << outputType[i] <<
" doesn't match memory type "
1041 for (
unsigned i = 0; i < opStCount; i++) {
1042 Type syncType = outputType[opLdCount + i];
1043 if (!isa<NoneType>(syncType))
1044 return emitOpError(
"data type for sync port for store port ")
1045 << i <<
":" << syncType <<
" is not 'none'";
1047 for (
unsigned i = 0; i < opLdCount; i++) {
1048 Type syncType = outputType[opLdCount + opStCount + i];
1049 if (!isa<NoneType>(syncType))
1050 return emitOpError(
"data type for sync port for load port ")
1051 << i <<
":" << syncType <<
" is not 'none'";
1064 std::string handshake::ExternalMemoryOp::getResultName(
unsigned int idx) {
1068 void ExternalMemoryOp::build(OpBuilder &builder, OperationState &result,
1069 Value memref, ValueRange inputs,
int ldCount,
1070 int stCount,
int id) {
1071 SmallVector<Value> ops;
1072 ops.push_back(memref);
1073 llvm::append_range(ops, inputs);
1074 result.addOperands(ops);
1076 auto memrefType = cast<MemRefType>(memref.getType());
1079 result.types.append(ldCount, memrefType.getElementType());
1082 result.types.append(stCount + ldCount, builder.getNoneType());
1085 Type i32Type = builder.getIntegerType(32);
1086 result.addAttribute(
"id", builder.getIntegerAttr(i32Type,
id));
1087 result.addAttribute(
"ldCount", builder.getIntegerAttr(i32Type, ldCount));
1088 result.addAttribute(
"stCount", builder.getIntegerAttr(i32Type, stCount));
1091 llvm::SmallVector<handshake::MemLoadInterface>
1096 llvm::SmallVector<handshake::MemStoreInterface>
1101 void MemoryOp::build(OpBuilder &builder, OperationState &result,
1102 ValueRange operands,
int outputs,
int controlOutputs,
1103 bool lsq,
int id, Value memref) {
1104 result.addOperands(operands);
1106 auto memrefType = cast<MemRefType>(memref.getType());
1109 result.types.append(outputs, memrefType.getElementType());
1112 result.types.append(controlOutputs, builder.getNoneType());
1113 result.addAttribute(
"lsq", builder.getBoolAttr(lsq));
1114 result.addAttribute(
"memRefType",
TypeAttr::get(memrefType));
1117 Type i32Type = builder.getIntegerType(32);
1118 result.addAttribute(
"id", builder.getIntegerAttr(i32Type,
id));
1121 result.addAttribute(
"ldCount", builder.getIntegerAttr(i32Type, outputs));
1122 result.addAttribute(
1123 "stCount", builder.getIntegerAttr(i32Type, controlOutputs - outputs));
1135 bool handshake::MemoryOp::allocateMemory(
1136 llvm::DenseMap<unsigned, unsigned> &memoryMap,
1137 std::vector<std::vector<llvm::Any>> &store,
1138 std::vector<double> &storeTimes) {
1139 if (memoryMap.count(getId()))
1142 auto type = getMemRefType();
1143 std::vector<llvm::Any> in;
1145 ArrayRef<int64_t> shape = type.getShape();
1146 int allocationSize = 1;
1148 for (int64_t dim : shape) {
1150 allocationSize *= dim;
1152 assert(count < in.size());
1153 allocationSize *= llvm::any_cast<APInt>(in[count++]).getSExtValue();
1156 unsigned ptr = store.size();
1157 store.resize(ptr + 1);
1158 storeTimes.resize(ptr + 1);
1159 store[ptr].resize(allocationSize);
1160 storeTimes[ptr] = 0.0;
1163 for (
int i = 0; i < allocationSize; i++) {
1165 store[ptr][i] = APInt(
width, 0);
1167 store[ptr][i] = APFloat(0.0);
1169 llvm_unreachable(
"Unknown result type!\n");
1173 memoryMap[getId()] = ptr;
1178 unsigned nAddresses = getAddresses().size();
1180 if (idx < nAddresses)
1181 opName =
"addrIn" + std::to_string(idx);
1182 else if (idx == nAddresses)
1183 opName =
"dataFromMem";
1189 std::string handshake::LoadOp::getResultName(
unsigned int idx) {
1190 std::string resName;
1192 resName =
"dataOut";
1194 resName =
"addrOut" + std::to_string(idx - 1);
1198 void handshake::LoadOp::build(OpBuilder &builder, OperationState &result,
1199 Value memref, ValueRange indices) {
1202 result.addOperands(indices);
1205 auto memrefType = cast<MemRefType>(memref.getType());
1208 result.types.push_back(memrefType.getElementType());
1211 result.types.append(indices.size(), builder.getIndexType());
1215 OperationState &result) {
1216 SmallVector<OpAsmParser::UnresolvedOperand, 4> addressOperands,
1217 remainingOperands, allOperands;
1218 SmallVector<Type, 1> parsedTypes, allTypes;
1219 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1221 if (parser.parseLSquare() || parser.parseOperandList(addressOperands) ||
1222 parser.parseRSquare() || parser.parseOperandList(remainingOperands) ||
1223 parser.parseColon() || parser.parseTypeList(parsedTypes))
1228 Type dataType = parsedTypes.back();
1229 auto parsedTypesRef = ArrayRef(parsedTypes);
1230 result.addTypes(dataType);
1231 result.addTypes(parsedTypesRef.drop_back());
1232 allOperands.append(addressOperands);
1233 allOperands.append(remainingOperands);
1234 allTypes.append(parsedTypes);
1236 if (parser.resolveOperands(allOperands, allTypes, allOperandLoc,
1242 template <
typename MemOp>
1245 p << op.getAddresses();
1246 p <<
"] " << op.getData() <<
", " << op.getCtrl() <<
" : ";
1247 llvm::interleaveComma(op.getAddresses(), p,
1248 [&](Value v) { p << v.getType(); });
1249 p <<
", " << op.getData().getType();
1252 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
1259 unsigned nAddresses = getAddresses().size();
1261 if (idx < nAddresses)
1262 opName =
"addrIn" + std::to_string(idx);
1263 else if (idx == nAddresses)
1270 template <
typename TMemoryOp>
1272 if (op.getAddresses().size() == 0)
1273 return op.emitOpError() <<
"No addresses were specified";
1280 std::string handshake::StoreOp::getResultName(
unsigned int idx) {
1281 std::string resName;
1283 resName =
"dataToMem";
1285 resName =
"addrOut" + std::to_string(idx - 1);
1289 void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
1290 Value valueToStore, ValueRange indices) {
1293 result.addOperands(indices);
1296 result.addOperands(valueToStore);
1299 result.types.push_back(valueToStore.getType());
1302 result.types.append(indices.size(), builder.getIndexType());
1307 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
1313 bool JoinOp::isControl() {
return true; }
1315 ParseResult JoinOp::parse(OpAsmParser &parser, OperationState &result) {
1316 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1317 SmallVector<Type> types;
1319 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1320 if (parser.parseOperandList(operands) ||
1321 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1322 parser.parseTypeList(types))
1325 if (parser.resolveOperands(operands, types, allOperandLoc, result.operands))
1332 void JoinOp::print(OpAsmPrinter &p) {
1333 p <<
" " << getData();
1334 p.printOptionalAttrDict((*this)->getAttrs(), {
"control"});
1335 p <<
" : " << getData().getTypes();
/// Checks this instance against the function named by its 'module' symbol
/// reference.
///
/// The symbol is looked up as a FuncOp through `symbolTable`; the instance's
/// operand count/types are compared with the function type's inputs, and its
/// result count/types with the function type's results. Each mismatch has its
/// own op error message.
1339 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// The 'module' attribute is mandatory; a missing attribute trips the assert
// instead of producing a diagnostic.
1341 auto fnAttr = this->getModuleAttr();
1342 assert(fnAttr &&
"requires a 'module' symbol reference attribute");
// Resolve the reference starting from this operation, expecting a FuncOp.
1344 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*
this, fnAttr);
1346 return emitOpError() <<
"'" << fnAttr.getValue()
1347 <<
"' does not reference a valid handshake function";
// All remaining checks compare against the referenced function's signature.
1350 auto fnType = fn.getFunctionType();
// Operand count versus number of function inputs.
1351 if (fnType.getNumInputs() != getNumOperands())
1353 "incorrect number of operands for the referenced handshake function");
// Operand types must equal the function input types position by position.
1355 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
1356 if (getOperand(i).getType() != fnType.getInput(i))
1357 return emitOpError(
"operand type mismatch: expected operand type ")
1358 << fnType.getInput(i) <<
", but provided "
1359 << getOperand(i).getType() <<
" for operand number " << i;
// Result count versus number of function results.
1361 if (fnType.getNumResults() != getNumResults())
1363 "incorrect number of results for the referenced handshake function");
// Result types must equal the function result types position by position.
1365 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
1366 if (getResult(i).getType() != fnType.getResult(i))
1367 return emitOpError(
"result type mismatch: expected result type ")
1368 << fnType.getResult(i) <<
", but provided "
1369 << getResult(i).getType() <<
" for result number " << i;
/// Custom assembly parser for UnpackOp.
///
/// Reads a single operand, an optional attribute dictionary (stored into
/// `result.attributes`), a colon, and one type. The operand is resolved
/// against that type, and the op's result types are the element types
/// returned by `type.getTypes()`.
1378 ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1379 OpAsmParser::UnresolvedOperand tuple;
// Textual form: operand, optional attr-dict, ':', type.
// NOTE(review): `type` is presumably a TupleType, given the getTypes() call
// below - confirm against its declaration.
1382 if (parser.parseOperand(tuple) ||
1383 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1384 parser.parseType(type))
// Bind the parsed operand to the parsed type and record it on the state.
1387 if (parser.resolveOperand(tuple, type, result.operands))
// One result per element type of the parsed type.
1390 result.addTypes(type.getTypes());
/// Custom assembly printer for UnpackOp. Emits a space and the input operand,
/// then the optional attribute dictionary (no attributes elided), then " : "
/// followed by the input operand's type.
1395 void UnpackOp::print(OpAsmPrinter &p) {
1396 p <<
" " << getInput();
1397 p.printOptionalAttrDict((*this)->getAttrs());
// Only the input type is printed; the parser derives the result types from
// it via getTypes().
1398 p <<
" : " << getInput().getType();
/// Custom assembly parser for PackOp.
///
/// Reads an operand list, an optional attribute dictionary (stored into
/// `result.attributes`), a colon, and one type. The operands are resolved
/// against the element types returned by `type.getTypes()`, and the parsed
/// type itself becomes the op's single result type.
1401 ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1402 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
// Capture the location before any operand is consumed; it is handed to
// resolveOperands below.
1403 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
// Textual form: operands, optional attr-dict, ':', type.
// NOTE(review): `type` is presumably a TupleType, given the getTypes() call
// below - confirm against its declaration.
1406 if (parser.parseOperandList(operands) ||
1407 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1408 parser.parseType(type))
// Operand types are the element types of the parsed type.
1411 if (parser.resolveOperands(operands, type.getTypes(), allOperandLoc,
// The parsed type is the result type.
1415 result.addTypes(type);
/// Custom assembly printer for PackOp. Emits a space and the input operands,
/// then the optional attribute dictionary (no attributes elided), then " : "
/// followed by the result type.
1420 void PackOp::print(OpAsmPrinter &p) {
1421 p <<
" " << getInputs();
1422 p.printOptionalAttrDict((*this)->getAttrs());
// Only the result type is printed; the parser derives the operand types from
// it via getTypes().
1423 p <<
" : " << getResult().getType();
1431 auto *parent = (*this)->getParentOp();
1432 auto function = dyn_cast<handshake::FuncOp>(parent);
1434 return emitOpError(
"must have a handshake.func parent");
1437 const auto &results =
function.getResultTypes();
1438 if (getNumOperands() != results.size())
1439 return emitOpError(
"has ")
1440 << getNumOperands() <<
" operands, but enclosing function returns "
1443 for (
unsigned i = 0, e = results.size(); i != e; ++i)
1444 if (getOperand(i).getType() != results[i])
1445 return emitError() <<
"type of return operand " << i <<
" ("
1446 << getOperand(i).getType()
1447 <<
") doesn't match function result type ("
1448 << results[i] <<
")";
1453 #define GET_OP_CLASSES
1454 #include "circt/Dialect/Handshake/Handshake.cpp.inc"
assert(baseType &&"element must be base type")
static StringRef getOperandName(Value operand)
static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal, uint64_t numOperands)
Verifies whether an indexing value is wide enough to index into a provided number of operands.
static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result)
static ParseResult parseSostOperation(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, OperationState &result, int &size, Type &type, bool explicitSize)
static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op)
static ParseResult parseFuncOpArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Type > &resTypes, SmallVectorImpl< DictionaryAttr > &resAttrs)
Parses a FuncOp signature using mlir::function_interface_impl::parseFunctionSignature while getting a...
static Value getDematerialized(Value v)
Returns a dematerialized version of the value 'v', defined as the source of the value before passing ...
static bool isControlCheckTypeAndOperand(Type dataType, Value operand)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores, unsigned idx)
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v)
static LogicalResult verifyMemoryAccessOp(TMemoryOp op)
static void addStringToStringArrayAttr(Builder &builder, Operation *op, StringRef attrName, StringAttr str)
Helper function for appending a string to an array attribute, and rewriting the attribute back to the operation.
static std::string defaultOperandName(unsigned int idx)
static SmallVector< Attribute > getFuncOpNames(Builder &builder, unsigned cnt, StringRef prefix)
Generates names for the input and output arguments of a handshake.func, based on the number of arguments as well as a prefix.
static std::string getMemoryOperandName(unsigned nStores, unsigned idx)
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
static ParseResult parseMemoryAccessOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
bool isControlOpImpl(Operation *op)
Default implementation for checking whether an operation is a control operation.
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.