15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/IntegerSet.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/Interfaces/FunctionImplementation.h"
28 #include "mlir/Transforms/InliningUtils.h"
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/TypeSwitch.h"
35 using namespace circt;
40 #include "circt/Dialect/Handshake/HandshakeCanonicalization.h.inc"
43 if (
auto sostInterface = dyn_cast<SOSTInterface>(op); sostInterface)
44 return sostInterface.sostIsControl();
53 return "in" + std::to_string(idx);
57 if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
64 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
65 OperationState &result,
int &size, Type &type,
71 if (parser.parseOperandList(operands) ||
72 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
73 parser.parseType(type))
77 size = operands.size();
84 uint64_t numOperands) {
85 auto indexType = indexVal.getType();
89 if (
auto integerType = indexType.dyn_cast<IntegerType>())
90 indexWidth = integerType.getWidth();
91 else if (indexType.isIndex())
92 indexWidth = IndexType::kInternalStorageBitWidth;
94 return op->emitError(
"unsupported type for indexing value: ") << indexType;
97 if (indexWidth < 64) {
98 uint64_t maxNumOperands = (uint64_t)1 << indexWidth;
99 if (numOperands > maxNumOperands)
100 return op->emitError(
"bitwidth of indexing value is ")
101 << indexWidth <<
", which can index into " << maxNumOperands
102 <<
" operands, but found " << numOperands <<
" operands";
110 if (dataType.isa<NoneType>())
115 auto *defOp = operand.getDefiningOp();
116 return isa_and_nonnull<ControlMergeOp>(defOp) &&
117 operand == defOp->getResult(0);
120 template <
typename TMemOp>
121 llvm::SmallVector<handshake::MemLoadInterface>
getLoadPorts(TMemOp op) {
122 llvm::SmallVector<MemLoadInterface> ports;
129 unsigned stCount = op.getStCount();
130 unsigned ldCount = op.getLdCount();
131 for (
unsigned i = 0, e = ldCount; i != e; ++i) {
134 ldif.
addressIn = op.getInputs()[stCount * 2 + i];
135 ldif.
dataOut = op.getResult(i);
136 ldif.
doneOut = op.getResult(ldCount + stCount + i);
137 ports.push_back(ldif);
142 template <
typename TMemOp>
144 llvm::SmallVector<MemStoreInterface> ports;
151 unsigned ldCount = op.getLdCount();
152 for (
unsigned i = 0, e = op.getStCount(); i != e; ++i) {
155 stif.
dataIn = op.getInputs()[i * 2];
156 stif.
addressIn = op.getInputs()[i * 2 + 1];
157 stif.
doneOut = op.getResult(ldCount + i);
158 ports.push_back(stif);
163 unsigned ForkOp::getSize() {
return getResults().size(); }
165 static ParseResult
parseForkOp(OpAsmParser &parser, OperationState &result) {
166 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
168 ArrayRef<Type> operandTypes(type);
169 SmallVector<Type, 1> resultTypes;
170 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
175 resultTypes.assign(size, type);
176 result.addTypes(resultTypes);
177 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
183 ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
187 void ForkOp::print(OpAsmPrinter &p) { sostPrint(p,
true); }
194 LogicalResult matchAndRewrite(ForkOp op,
195 PatternRewriter &rewriter)
const override {
196 std::set<unsigned> unusedIndexes;
198 for (
auto res : llvm::enumerate(op.getResults()))
199 if (res.value().getUses().empty())
200 unusedIndexes.insert(res.index());
202 if (unusedIndexes.empty())
206 rewriter.setInsertionPoint(op);
207 auto operand = op.getOperand();
208 auto newFork = rewriter.create<ForkOp>(
209 op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
210 rewriter.updateRootInPlace(op, [&] {
212 for (
auto oldRes : llvm::enumerate(op.getResults()))
213 if (unusedIndexes.count(oldRes.index()) == 0)
214 oldRes.value().replaceAllUsesWith(newFork.getResults()[i++]);
216 rewriter.eraseOp(op);
224 LogicalResult matchAndRewrite(ForkOp op,
225 PatternRewriter &rewriter)
const override {
226 auto parentForkOp = op.getOperand().getDefiningOp<ForkOp>();
234 unsigned totalNumOuts = op.getSize() + parentForkOp.getSize();
235 rewriter.updateRootInPlace(parentForkOp, [&] {
238 auto newParentForkOp = rewriter.create<ForkOp>(
239 parentForkOp.getLoc(), parentForkOp.getOperand(), totalNumOuts);
242 llvm::zip(parentForkOp->getResults(), newParentForkOp.getResults()))
243 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
247 rewriter.replaceOp(op,
248 newParentForkOp.getResults().take_back(op.getSize()));
250 rewriter.eraseOp(parentForkOp);
257 void handshake::ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
258 MLIRContext *context) {
259 results.insert<circt::handshake::EliminateSimpleForksPattern,
260 EliminateForkToForkPattern, EliminateUnusedForkResultsPattern>(
264 unsigned LazyForkOp::getSize() {
return getResults().size(); }
266 bool LazyForkOp::sostIsControl() {
270 ParseResult LazyForkOp::parse(OpAsmParser &parser, OperationState &result) {
274 void LazyForkOp::print(OpAsmPrinter &p) { sostPrint(p,
true); }
276 ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
277 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
279 ArrayRef<Type> operandTypes(type);
280 SmallVector<Type, 1> resultTypes, dataOperandsTypes;
281 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
286 dataOperandsTypes.assign(size, type);
287 resultTypes.push_back(type);
288 result.addTypes(resultTypes);
289 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
295 void MergeOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
297 void MergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
298 MLIRContext *context) {
299 results.insert<circt::handshake::EliminateSimpleMergesPattern>(context);
305 Operation *parentOp = v.getDefiningOp();
309 return llvm::TypeSwitch<Operation *, Value>(parentOp)
314 .Default([&](
auto) {
return v; });
324 LogicalResult matchAndRewrite(MuxOp op,
325 PatternRewriter &rewriter)
const override {
327 if (!llvm::all_of(op.getDataOperands(), [&](Value operand) {
328 return getDematerialized(operand) == firstDataOperand;
331 rewriter.replaceOp(op, firstDataOperand);
338 LogicalResult matchAndRewrite(MuxOp op,
339 PatternRewriter &rewriter)
const override {
340 if (op.getDataOperands().size() != 1)
343 rewriter.replaceOp(op, op.getDataOperands()[0]);
350 LogicalResult matchAndRewrite(MuxOp op,
351 PatternRewriter &rewriter)
const override {
353 auto dataOperands = op.getDataOperands();
354 if (dataOperands.size() != 2)
358 ConditionalBranchOp firstParentCBranch =
359 dataOperands[0].getDefiningOp<ConditionalBranchOp>();
360 if (!firstParentCBranch)
362 auto secondParentCBranch =
363 dataOperands[1].getDefiningOp<ConditionalBranchOp>();
364 if (!secondParentCBranch || firstParentCBranch != secondParentCBranch)
367 rewriter.updateRootInPlace(firstParentCBranch, [&] {
369 rewriter.replaceOp(op, firstParentCBranch.getDataOperand());
378 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
379 MLIRContext *context) {
380 results.insert<EliminateSimpleMuxesPattern, EliminateUnaryMuxesPattern,
381 EliminateCBranchIntoMuxPattern>(context);
386 ValueRange operands, DictionaryAttr attributes,
387 mlir::OpaqueProperties properties,
388 mlir::RegionRange regions,
389 SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
392 if (operands.size() < 2)
395 inferredReturnTypes.push_back(operands[1].getType());
399 bool MuxOp::isControl() {
return getResult().getType().isa<NoneType>(); }
405 ParseResult MuxOp::parse(OpAsmParser &parser, OperationState &result) {
406 OpAsmParser::UnresolvedOperand selectOperand;
407 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
408 Type selectType, dataType;
409 SmallVector<Type, 1> dataOperandsTypes;
410 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
411 if (parser.parseOperand(selectOperand) || parser.parseLSquare() ||
412 parser.parseOperandList(allOperands) || parser.parseRSquare() ||
413 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
414 parser.parseType(selectType) || parser.parseComma() ||
415 parser.parseType(dataType))
418 int size = allOperands.size();
419 dataOperandsTypes.assign(size, dataType);
420 result.addTypes(dataType);
421 allOperands.insert(allOperands.begin(), selectOperand);
422 if (parser.resolveOperands(
424 llvm::concat<const Type>(ArrayRef<Type>(selectType),
425 ArrayRef<Type>(dataOperandsTypes)),
426 allOperandLoc, result.operands))
431 void MuxOp::print(OpAsmPrinter &p) {
432 Type selectType = getSelectOperand().getType();
433 auto ops = getOperands();
434 p <<
' ' << ops.front();
436 p.printOperands(ops.drop_front());
438 p.printOptionalAttrDict((*this)->getAttrs());
439 p <<
" : " << selectType <<
", " << getResult().getType();
442 LogicalResult MuxOp::verify() {
444 getDataOperands().size());
447 std::string handshake::ControlMergeOp::getResultName(
unsigned int idx) {
448 assert(idx == 0 || idx == 1);
449 return idx == 0 ?
"dataOut" :
"index";
452 ParseResult ControlMergeOp::parse(OpAsmParser &parser, OperationState &result) {
453 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
454 Type resultType, indexType;
455 SmallVector<Type> resultTypes, dataOperandsTypes;
456 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
461 if (parser.parseComma() || parser.parseType(indexType))
464 dataOperandsTypes.assign(size, resultType);
465 resultTypes.push_back(resultType);
466 resultTypes.push_back(indexType);
467 result.addTypes(resultTypes);
468 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
474 void ControlMergeOp::print(OpAsmPrinter &p) {
477 p <<
", " << getIndex().getType();
480 LogicalResult ControlMergeOp::verify() {
481 auto operands = getOperands();
482 if (operands.empty())
483 return emitOpError(
"operation must have at least one operand");
484 if (operands[0].getType() != getResult().getType())
485 return emitOpError(
"type of first result should match type of operands");
489 LogicalResult FuncOp::verify() {
497 auto fnInputTypes = getArgumentTypes();
498 Block &entryBlock = front();
500 for (
unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
501 if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
502 return emitOpError(
"type of entry block argument #")
503 << i <<
'(' << entryBlock.getArgument(i).getType()
504 <<
") must match the type of the corresponding argument in "
505 <<
"function signature(" << fnInputTypes[i] <<
')';
508 auto verifyPortNameAttr = [&](StringRef attrName,
509 unsigned numIOs) -> LogicalResult {
510 auto portNamesAttr = (*this)->getAttrOfType<ArrayAttr>(attrName);
513 return emitOpError() <<
"expected attribute '" << attrName <<
"'.";
515 auto portNames = portNamesAttr.getValue();
516 if (portNames.size() != numIOs)
517 return emitOpError() <<
"attribute '" << attrName <<
"' has "
519 <<
" entries but is expected to have " << numIOs
522 if (llvm::any_of(portNames,
523 [&](Attribute attr) {
return !attr.isa<StringAttr>(); }))
524 return emitOpError() <<
"expected all entries in attribute '" << attrName
525 <<
"' to be strings.";
529 if (failed(verifyPortNameAttr(
"argNames", getNumArguments())))
531 if (failed(verifyPortNameAttr(
"resNames", getNumResults())))
535 for (
auto arg : entryBlock.getArguments()) {
536 if (!arg.getType().isa<MemRefType>())
538 if (arg.getUsers().empty() ||
539 !isa<ExternalMemoryOp>(*arg.getUsers().begin()))
540 return emitOpError(
"expected that block argument #")
541 << arg.getArgNumber() <<
" is used by an 'extmemory' operation";
552 SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
553 SmallVectorImpl<Type> &resTypes,
554 SmallVectorImpl<DictionaryAttr> &resAttrs) {
556 if (mlir::function_interface_impl::parseFunctionSignature(
557 parser,
true, entryArgs, isVariadic, resTypes,
569 SmallVector<Attribute> resNames;
570 for (
unsigned i = 0; i < cnt; ++i)
571 resNames.push_back(
builder.getStringAttr(prefix + std::to_string(i)));
575 void handshake::FuncOp::build(OpBuilder &
builder, OperationState &state,
576 StringRef name, FunctionType type,
577 ArrayRef<NamedAttribute> attrs) {
578 state.addAttribute(SymbolTable::getSymbolAttrName(),
580 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
582 state.attributes.append(attrs.begin(), attrs.end());
584 if (
const auto *argNamesAttrIt = llvm::find_if(
585 attrs, [&](
auto attr) {
return attr.getName() ==
"argNames"; });
586 argNamesAttrIt == attrs.end())
587 state.addAttribute(
"argNames",
builder.getArrayAttr({}));
589 if (llvm::find_if(attrs, [&](
auto attr) {
590 return attr.getName() ==
"resNames";
592 state.addAttribute(
"resNames",
builder.getArrayAttr({}));
600 StringRef attrName, StringAttr str) {
601 llvm::SmallVector<Attribute> attrs;
602 llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
603 std::back_inserter(attrs));
604 attrs.push_back(str);
605 op->setAttr(attrName,
builder.getArrayAttr(attrs));
608 void handshake::FuncOp::resolveArgAndResNames() {
615 auto argNames = getArgNames().getValue();
616 auto resNames = getResNames().getValue();
619 auto resolveNames = [&](
auto &fallbackNames,
auto &actualNames,
620 StringRef attrName) {
621 for (
auto fallbackName : llvm::enumerate(fallbackNames)) {
622 if (actualNames.size() <= fallbackName.index())
624 builder, this->getOperation(), attrName,
625 fallbackName.value().template cast<StringAttr>());
628 resolveNames(fallbackArgNames, argNames,
"argNames");
629 resolveNames(fallbackResNames, resNames,
"resNames");
632 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
633 auto &
builder = parser.getBuilder();
635 SmallVector<OpAsmParser::Argument> args;
636 SmallVector<Type> resTypes;
637 SmallVector<DictionaryAttr> resAttributes;
638 SmallVector<Attribute> argNames;
641 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
644 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
645 result.attributes) ||
648 mlir::function_interface_impl::addArgAndResultAttrs(
649 builder, result, args, resAttributes,
650 handshake::FuncOp::getArgAttrsAttrName(result.name),
651 handshake::FuncOp::getResAttrsAttrName(result.name));
654 SmallVector<Type> argTypes;
655 for (
auto arg : args)
656 argTypes.push_back(arg.type);
659 handshake::FuncOp::getFunctionTypeAttrName(result.name),
665 llvm::any_of(args, [](
auto arg) {
return arg.ssaName.name.empty(); });
669 llvm::transform(args, std::back_inserter(argNames), [&](
auto arg) {
670 return builder.getStringAttr(arg.ssaName.name.drop_front());
675 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
680 if (!result.attributes.get(
"argNames"))
681 result.addAttribute(
"argNames",
builder.getArrayAttr(argNames));
682 if (!result.attributes.get(
"resNames")) {
684 result.addAttribute(
"resNames",
builder.getArrayAttr(resNames));
689 auto *body = result.addRegion();
690 llvm::SMLoc loc = parser.getCurrentLocation();
691 auto parseResult = parser.parseOptionalRegion(*body, args,
693 if (!parseResult.has_value())
696 if (failed(*parseResult))
700 return parser.emitError(loc,
"expected non-empty function body");
706 void FuncOp::print(OpAsmPrinter &p) {
707 mlir::function_interface_impl::printFunctionOp(
708 p, *
this,
true, getFunctionTypeAttrName(),
709 getArgAttrsAttrName(), getResAttrsAttrName());
713 struct EliminateSimpleControlMergesPattern
717 LogicalResult matchAndRewrite(ControlMergeOp op,
718 PatternRewriter &rewriter)
const override;
722 LogicalResult EliminateSimpleControlMergesPattern::matchAndRewrite(
723 ControlMergeOp op, PatternRewriter &rewriter)
const {
724 auto dataResult = op.getResult();
725 auto choiceResult = op.getIndex();
726 auto choiceUnused = choiceResult.use_empty();
727 if (!choiceUnused && !choiceResult.hasOneUse())
730 Operation *choiceUser =
nullptr;
731 if (choiceResult.hasOneUse()) {
732 choiceUser = choiceResult.getUses().begin().getUser();
733 if (!isa<SinkOp>(choiceUser))
737 auto merge = rewriter.create<MergeOp>(op.getLoc(), op.getDataOperands());
739 for (
auto &use : llvm::make_early_inc_range(dataResult.getUses())) {
740 auto *user = use.getOwner();
741 rewriter.updateRootInPlace(
742 user, [&]() { user->setOperand(use.getOperandNumber(), merge); });
746 rewriter.eraseOp(op);
750 rewriter.eraseOp(choiceUser);
751 rewriter.eraseOp(op);
755 void ControlMergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
756 MLIRContext *context) {
757 results.insert<EliminateSimpleControlMergesPattern>(context);
760 bool BranchOp::sostIsControl() {
764 void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
765 MLIRContext *context) {
766 results.insert<circt::handshake::EliminateSimpleBranchesPattern>(context);
769 ParseResult BranchOp::parse(OpAsmParser &parser, OperationState &result) {
770 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
772 ArrayRef<Type> operandTypes(type);
773 SmallVector<Type, 1> dataOperandsTypes;
774 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
779 dataOperandsTypes.assign(size, type);
780 result.addTypes({type});
781 if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
787 void BranchOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
789 ParseResult ConditionalBranchOp::parse(OpAsmParser &parser,
790 OperationState &result) {
791 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
793 SmallVector<Type> operandTypes;
794 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
795 if (parser.parseOperandList(allOperands) ||
796 parser.parseOptionalAttrDict(result.attributes) ||
797 parser.parseColonType(dataType))
800 if (allOperands.size() != 2)
801 return parser.emitError(parser.getCurrentLocation(),
802 "Expected exactly 2 operands");
804 result.addTypes({dataType, dataType});
806 operandTypes.push_back(dataType);
807 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
814 void ConditionalBranchOp::print(OpAsmPrinter &p) {
815 Type type = getDataOperand().getType();
816 p <<
" " << getOperands();
817 p.printOptionalAttrDict((*this)->getAttrs());
822 assert(idx == 0 || idx == 1);
823 return idx == 0 ?
"cond" :
"data";
826 std::string handshake::ConditionalBranchOp::getResultName(
unsigned int idx) {
827 assert(idx == 0 || idx == 1);
828 return idx == ConditionalBranchOp::falseIndex ?
"outFalse" :
"outTrue";
831 bool ConditionalBranchOp::isControl() {
836 ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
837 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
839 ArrayRef<Type> operandTypes(type);
840 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
845 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
851 void SinkOp::print(OpAsmPrinter &p) { sostPrint(p,
false); }
858 Type SourceOp::getDataType() {
return getResult().getType(); }
859 unsigned SourceOp::getSize() {
return 1; }
861 ParseResult SourceOp::parse(OpAsmParser &parser, OperationState &result) {
862 if (parser.parseOptionalAttrDict(result.attributes))
868 void SourceOp::print(OpAsmPrinter &p) {
869 p.printOptionalAttrDict((*this)->getAttrs());
872 LogicalResult ConstantOp::verify() {
874 auto typedValue = getValue().dyn_cast<mlir::TypedAttr>();
876 return emitOpError(
"constant value must be a typed attribute; value is ")
878 if (typedValue.getType() != getResult().getType())
879 return emitOpError() <<
"constant value type " << typedValue.getType()
880 <<
" differs from operation result type "
881 << getResult().getType();
886 void handshake::ConstantOp::getCanonicalizationPatterns(
887 RewritePatternSet &results, MLIRContext *context) {
888 results.insert<circt::handshake::EliminateSunkConstantsPattern>(context);
891 LogicalResult BufferOp::verify() {
894 if (
auto initVals = getInitValues()) {
897 <<
"only bufferType buffers are allowed to have initial values.";
899 auto nInits = initVals->size();
900 if (nInits != getSize())
901 return emitOpError() <<
"expected " << getSize()
902 <<
" init values but got " << nInits <<
".";
908 void handshake::BufferOp::getCanonicalizationPatterns(
909 RewritePatternSet &results, MLIRContext *context) {
910 results.insert<circt::handshake::EliminateSunkBuffersPattern>(context);
913 unsigned BufferOp::getSize() {
914 return (*this)->getAttrOfType<IntegerAttr>(
"slots").getValue().getZExtValue();
917 ParseResult BufferOp::parse(OpAsmParser &parser, OperationState &result) {
918 SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
920 ArrayRef<Type> operandTypes(type);
921 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
926 auto bufferTypeAttr = BufferTypeEnumAttr::parse(parser, {});
933 result.addAttribute(
"bufferType", bufferTypeAttr);
935 if (parser.parseOperandList(allOperands) ||
936 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
937 parser.parseType(type))
940 result.addTypes({type});
941 if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
947 void BufferOp::print(OpAsmPrinter &p) {
949 (*this)->getAttrOfType<IntegerAttr>(
"slots").getValue().getZExtValue();
950 p <<
" [" << size <<
"]";
951 p <<
" " << stringifyEnum(getBufferType());
952 p <<
" " << (*this)->getOperands();
953 p.printOptionalAttrDict((*this)->getAttrs(), {
"slots",
"bufferType"});
954 p <<
" : " << (*this).getDataType();
959 if (idx < nStores * 2) {
960 bool isData = idx % 2 == 0;
961 name = isData ?
"stData" + std::to_string(idx / 2)
962 :
"stAddr" + std::to_string(idx / 2);
965 name =
"ldAddr" + std::to_string(idx);
978 name =
"ldData" + std::to_string(idx);
979 else if (idx < nLoads + nStores)
980 name =
"stDone" + std::to_string(idx - nLoads);
982 name =
"ldDone" + std::to_string(idx - nLoads - nStores);
986 std::string handshake::MemoryOp::getResultName(
unsigned int idx) {
990 LogicalResult MemoryOp::verify() {
991 auto memrefType = getMemRefType();
993 if (memrefType.getNumDynamicDims() != 0)
995 <<
"memref dimensions for handshake.memory must be static.";
996 if (memrefType.getShape().size() != 1)
997 return emitOpError() <<
"memref must have only a single dimension.";
999 unsigned opStCount = getStCount();
1000 unsigned opLdCount = getLdCount();
1001 int addressCount = memrefType.getShape().size();
1003 auto inputType = getInputs().getType();
1004 auto outputType = getOutputs().getType();
1005 Type dataType = memrefType.getElementType();
1007 unsigned numOperands =
static_cast<int>(getInputs().size());
1008 unsigned numResults =
static_cast<int>(getOutputs().size());
1009 if (numOperands != (1 + addressCount) * opStCount + addressCount * opLdCount)
1010 return emitOpError(
"number of operands ")
1011 << numOperands <<
" does not match number expected of "
1012 << 2 * opStCount + opLdCount <<
" with " << addressCount
1013 <<
" address inputs per port";
1015 if (numResults != opStCount + 2 * opLdCount)
1016 return emitOpError(
"number of results ")
1017 << numResults <<
" does not match number expected of "
1018 << opStCount + 2 * opLdCount <<
" with " << addressCount
1019 <<
" address inputs per port";
1021 Type addressType = opStCount > 0 ? inputType[1] : inputType[0];
1023 for (
unsigned i = 0; i < opStCount; i++) {
1024 if (inputType[2 * i] != dataType)
1025 return emitOpError(
"data type for store port ")
1026 << i <<
":" << inputType[2 * i] <<
" doesn't match memory type "
1028 if (inputType[2 * i + 1] != addressType)
1029 return emitOpError(
"address type for store port ")
1030 << i <<
":" << inputType[2 * i + 1]
1031 <<
" doesn't match address type " << addressType;
1033 for (
unsigned i = 0; i < opLdCount; i++) {
1034 Type ldAddressType = inputType[2 * opStCount + i];
1035 if (ldAddressType != addressType)
1036 return emitOpError(
"address type for load port ")
1037 << i <<
":" << ldAddressType <<
" doesn't match address type "
1040 for (
unsigned i = 0; i < opLdCount; i++) {
1041 if (outputType[i] != dataType)
1042 return emitOpError(
"data type for load port ")
1043 << i <<
":" << outputType[i] <<
" doesn't match memory type "
1046 for (
unsigned i = 0; i < opStCount; i++) {
1047 Type syncType = outputType[opLdCount + i];
1048 if (!syncType.isa<NoneType>())
1049 return emitOpError(
"data type for sync port for store port ")
1050 << i <<
":" << syncType <<
" is not 'none'";
1052 for (
unsigned i = 0; i < opLdCount; i++) {
1053 Type syncType = outputType[opLdCount + opStCount + i];
1054 if (!syncType.isa<NoneType>())
1055 return emitOpError(
"data type for sync port for load port ")
1056 << i <<
":" << syncType <<
" is not 'none'";
1069 std::string handshake::ExternalMemoryOp::getResultName(
unsigned int idx) {
1073 void ExternalMemoryOp::build(OpBuilder &
builder, OperationState &result,
1074 Value memref, ValueRange
inputs,
int ldCount,
1075 int stCount,
int id) {
1076 SmallVector<Value> ops;
1077 ops.push_back(memref);
1078 llvm::append_range(ops,
inputs);
1079 result.addOperands(ops);
1081 auto memrefType = memref.getType().cast<MemRefType>();
1084 result.types.append(ldCount, memrefType.getElementType());
1087 result.types.append(stCount + ldCount,
builder.getNoneType());
1090 Type i32Type =
builder.getIntegerType(32);
1091 result.addAttribute(
"id",
builder.getIntegerAttr(i32Type,
id));
1092 result.addAttribute(
"ldCount",
builder.getIntegerAttr(i32Type, ldCount));
1093 result.addAttribute(
"stCount",
builder.getIntegerAttr(i32Type, stCount));
1096 llvm::SmallVector<handshake::MemLoadInterface>
1101 llvm::SmallVector<handshake::MemStoreInterface>
1106 void MemoryOp::build(OpBuilder &
builder, OperationState &result,
1107 ValueRange operands,
int outputs,
int controlOutputs,
1108 bool lsq,
int id, Value memref) {
1109 result.addOperands(operands);
1111 auto memrefType = memref.getType().cast<MemRefType>();
1114 result.types.append(
outputs, memrefType.getElementType());
1117 result.types.append(controlOutputs,
builder.getNoneType());
1118 result.addAttribute(
"lsq",
builder.getBoolAttr(lsq));
1119 result.addAttribute(
"memRefType",
TypeAttr::get(memrefType));
1122 Type i32Type =
builder.getIntegerType(32);
1123 result.addAttribute(
"id",
builder.getIntegerAttr(i32Type,
id));
1126 result.addAttribute(
"ldCount",
builder.getIntegerAttr(i32Type,
outputs));
1127 result.addAttribute(
1128 "stCount",
builder.getIntegerAttr(i32Type, controlOutputs -
outputs));
1140 bool handshake::MemoryOp::allocateMemory(
1141 llvm::DenseMap<unsigned, unsigned> &memoryMap,
1142 std::vector<std::vector<llvm::Any>> &store,
1143 std::vector<double> &storeTimes) {
1144 if (memoryMap.count(getId()))
1147 auto type = getMemRefType();
1148 std::vector<llvm::Any> in;
1150 ArrayRef<int64_t> shape = type.getShape();
1151 int allocationSize = 1;
1153 for (int64_t dim : shape) {
1155 allocationSize *= dim;
1157 assert(count < in.size());
1158 allocationSize *= llvm::any_cast<APInt>(in[count++]).getSExtValue();
1161 unsigned ptr = store.size();
1162 store.resize(ptr + 1);
1163 storeTimes.resize(ptr + 1);
1164 store[ptr].resize(allocationSize);
1165 storeTimes[ptr] = 0.0;
1168 for (
int i = 0; i < allocationSize; i++) {
1170 store[ptr][i] = APInt(
width, 0);
1172 store[ptr][i] = APFloat(0.0);
1174 llvm_unreachable(
"Unknown result type!\n");
1178 memoryMap[getId()] = ptr;
1183 unsigned nAddresses = getAddresses().size();
1185 if (idx < nAddresses)
1186 opName =
"addrIn" + std::to_string(idx);
1187 else if (idx == nAddresses)
1188 opName =
"dataFromMem";
1194 std::string handshake::LoadOp::getResultName(
unsigned int idx) {
1195 std::string resName;
1197 resName =
"dataOut";
1199 resName =
"addrOut" + std::to_string(idx - 1);
1203 void handshake::LoadOp::build(OpBuilder &
builder, OperationState &result,
1204 Value memref, ValueRange indices) {
1207 result.addOperands(indices);
1210 auto memrefType = memref.getType().cast<MemRefType>();
1213 result.types.push_back(memrefType.getElementType());
1216 result.types.append(indices.size(),
builder.getIndexType());
1220 OperationState &result) {
1221 SmallVector<OpAsmParser::UnresolvedOperand, 4> addressOperands,
1222 remainingOperands, allOperands;
1223 SmallVector<Type, 1> parsedTypes, allTypes;
1224 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1226 if (parser.parseLSquare() || parser.parseOperandList(addressOperands) ||
1227 parser.parseRSquare() || parser.parseOperandList(remainingOperands) ||
1228 parser.parseColon() || parser.parseTypeList(parsedTypes))
1233 Type dataType = parsedTypes.back();
1234 auto parsedTypesRef = ArrayRef(parsedTypes);
1235 result.addTypes(dataType);
1236 result.addTypes(parsedTypesRef.drop_back());
1237 allOperands.append(addressOperands);
1238 allOperands.append(remainingOperands);
1239 allTypes.append(parsedTypes);
1241 if (parser.resolveOperands(allOperands, allTypes, allOperandLoc,
1247 template <
typename MemOp>
1250 p << op.getAddresses();
1251 p <<
"] " << op.getData() <<
", " << op.getCtrl() <<
" : ";
1252 llvm::interleaveComma(op.getAddresses(), p,
1253 [&](Value v) { p << v.getType(); });
1254 p <<
", " << op.getData().getType();
1257 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
1264 unsigned nAddresses = getAddresses().size();
1266 if (idx < nAddresses)
1267 opName =
"addrIn" + std::to_string(idx);
1268 else if (idx == nAddresses)
1275 template <
typename TMemoryOp>
1277 if (op.getAddresses().size() == 0)
1278 return op.emitOpError() <<
"No addresses were specified";
1285 std::string handshake::StoreOp::getResultName(
unsigned int idx) {
1286 std::string resName;
1288 resName =
"dataToMem";
1290 resName =
"addrOut" + std::to_string(idx - 1);
1294 void handshake::StoreOp::build(OpBuilder &
builder, OperationState &result,
1295 Value valueToStore, ValueRange indices) {
1298 result.addOperands(indices);
1301 result.addOperands(valueToStore);
1304 result.types.push_back(valueToStore.getType());
1307 result.types.append(indices.size(),
builder.getIndexType());
1312 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
1318 bool JoinOp::isControl() {
return true; }
1320 ParseResult JoinOp::parse(OpAsmParser &parser, OperationState &result) {
1321 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1322 SmallVector<Type> types;
1324 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
1325 if (parser.parseOperandList(operands) ||
1326 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1327 parser.parseTypeList(types))
1330 if (parser.resolveOperands(operands, types, allOperandLoc, result.operands))
1337 void JoinOp::print(OpAsmPrinter &p) {
1338 p <<
" " << getData();
1339 p.printOptionalAttrDict((*this)->getAttrs(), {
"control"});
1340 p <<
" : " << getData().getTypes();
/// Verifies that this instance's 'module' symbol resolves to a handshake
/// FuncOp, and that the referenced function's signature matches this op:
/// same number and types of operands, and same number and types of results.
1344 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// The module attribute is asserted to be present rather than diagnosed.
1346 auto fnAttr = this->getModuleAttr();
1347 assert(fnAttr &&
"requires a 'module' symbol reference attribute");
// Resolve the symbol to a FuncOp via the nearest enclosing symbol table.
1349 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*
this, fnAttr);
// NOTE(review): the guard line preceding this return (presumably a check that
// `fn` is null) is not present in this extract.
1351 return emitOpError() <<
"'" << fnAttr.getValue()
1352 <<
"' does not reference a valid handshake function";
1355 auto fnType = fn.getFunctionType();
// Operand count must match the callee's input count.
// NOTE(review): the `return emitOpError(` line of this check is not present
// in this extract.
1356 if (fnType.getNumInputs() != getNumOperands())
1358 "incorrect number of operands for the referenced handshake function");
// Each operand type must equal the corresponding callee input type; the
// first mismatch is reported with its index.
1360 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
1361 if (getOperand(i).getType() != fnType.getInput(i))
1362 return emitOpError(
"operand type mismatch: expected operand type ")
1363 << fnType.getInput(i) <<
", but provided "
1364 << getOperand(i).getType() <<
" for operand number " << i;
// Result count must match the callee's result count.
// NOTE(review): the `return emitOpError(` line of this check is not present
// in this extract.
1366 if (fnType.getNumResults() != getNumResults())
1368 "incorrect number of results for the referenced handshake function");
// Each result type must equal the corresponding callee result type; the
// first mismatch is reported with its index.
1370 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
1371 if (getResult(i).getType() != fnType.getResult(i))
1372 return emitOpError(
"result type mismatch: expected result type ")
1373 << fnType.getResult(i) <<
", but provided "
1374 << getResult(i).getType() <<
" for result number " << i;
// NOTE(review): the success return and the end of this function are not
// present in this extract.
/// Parses an unpack written as `<operand> {optional attrs} : <type>`,
/// resolves the single operand against that type and adds one result per
/// element returned by `type.getTypes()`.
1383 ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1384 OpAsmParser::UnresolvedOperand tuple;
// NOTE(review): the declaration of `type` is not present in this extract;
// its use of getTypes() suggests a tuple-like type — confirm.
1387 if (parser.parseOperand(tuple) ||
1388 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1389 parser.parseType(type))
// NOTE(review): the failure-return statements after each check are not
// present in this extract.
1392 if (parser.resolveOperand(tuple, type, result.operands))
// One result per element type of the parsed type.
1395 result.addTypes(type.getTypes());
1400 void UnpackOp::print(OpAsmPrinter &p) {
1401 p <<
" " << getInput();
1402 p.printOptionalAttrDict((*this)->getAttrs());
1403 p <<
" : " << getInput().getType();
/// Parses a pack written as `<operand list> {optional attrs} : <type>`,
/// resolves the operands against the element types from `type.getTypes()`
/// and uses `type` itself as the single result type.
1406 ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1407 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
// Location passed to resolveOperands below for its diagnostics.
1408 llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
// NOTE(review): the declaration of `type` is not present in this extract;
// its use of getTypes() suggests a tuple-like type — confirm.
1411 if (parser.parseOperandList(operands) ||
1412 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1413 parser.parseType(type))
// NOTE(review): the failure returns and the trailing argument line of the
// resolveOperands call are not present in this extract.
1416 if (parser.resolveOperands(operands, type.getTypes(), allOperandLoc,
1420 result.addTypes(type);
1425 void PackOp::print(OpAsmPrinter &p) {
1426 p <<
" " << getInputs();
1427 p.printOptionalAttrDict((*this)->getAttrs());
1428 p <<
" : " << getResult().getType();
/// Verifies that a return sits directly inside a handshake.func and that its
/// operands match that function's result types in both count and type.
1435 LogicalResult ReturnOp::verify() {
1436 auto *parent = (*this)->getParentOp();
1437 auto function = dyn_cast<handshake::FuncOp>(parent);
// NOTE(review): the guard line for this return (presumably a null check on
// `function`) is not present in this extract.
1439 return emitOpError(
"must have a handshake.func parent");
// Operand count must match the enclosing function's result count.
1442 const auto &results =
function.getResultTypes();
1443 if (getNumOperands() != results.size())
// NOTE(review): the line that streams the function's result count to end
// this message is not present in this extract.
1444 return emitOpError(
"has ")
1445 << getNumOperands() <<
" operands, but enclosing function returns "
// Each operand type must equal the corresponding function result type.
// NOTE(review): this diagnostic uses emitError() while the checks above use
// emitOpError(), so its message format likely differs (no op-name prefix);
// consider making them consistent.
1448 for (
unsigned i = 0, e = results.size(); i != e; ++i)
1449 if (getOperand(i).getType() != results[i])
1450 return emitError() <<
"type of return operand " << i <<
" ("
1451 << getOperand(i).getType()
1452 <<
") doesn't match function result type ("
1453 << results[i] <<
")";
// NOTE(review): the success return and the end of this function are not
// present in this extract.
// Pull in the TableGen-generated op class definitions for the Handshake
// dialect. GET_OP_CLASSES selects the op-class section of the generated
// Handshake.cpp.inc file.
1458 #define GET_OP_CLASSES
1459 #include "circt/Dialect/Handshake/Handshake.cpp.inc"
assert(baseType &&"element must be base type")
static StringRef getOperandName(Value operand)
static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal, uint64_t numOperands)
Verifies whether an indexing value is wide enough to index into a provided number of operands.
static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result)
static ParseResult parseSostOperation(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, OperationState &result, int &size, Type &type, bool explicitSize)
static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op)
static ParseResult parseFuncOpArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Type > &resTypes, SmallVectorImpl< DictionaryAttr > &resAttrs)
Parses a FuncOp signature using mlir::function_interface_impl::parseFunctionSignature while getting a...
static Value getDematerialized(Value v)
Returns a dematerialized version of the value 'v', defined as the source of the value before passing ...
static bool isControlCheckTypeAndOperand(Type dataType, Value operand)
llvm::SmallVector< handshake::MemLoadInterface > getLoadPorts(TMemOp op)
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores, unsigned idx)
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v)
static LogicalResult verifyMemoryAccessOp(TMemoryOp op)
static void addStringToStringArrayAttr(Builder &builder, Operation *op, StringRef attrName, StringAttr str)
Helper function for appending a string to an array attribute, and rewriting the attribute back to the...
static std::string defaultOperandName(unsigned int idx)
static SmallVector< Attribute > getFuncOpNames(Builder &builder, unsigned cnt, StringRef prefix)
Generates names for a handshake.func input and output arguments, based on the number of args as well ...
static std::string getMemoryOperandName(unsigned nStores, unsigned idx)
llvm::SmallVector< handshake::MemStoreInterface > getStorePorts(TMemOp op)
static ParseResult parseMemoryAccessOp(OpAsmParser &parser, OperationState &result)
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
bool isControlOpImpl(Operation *op)
Default implementation for checking whether an operation is a control operation.
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...