18 #include "mlir/IR/AsmState.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Diagnostics.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/SymbolTable.h"
25 #include "mlir/Interfaces/FunctionImplementation.h"
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/PriorityQueue.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Casting.h"
35 using namespace circt;
/// Compile-time predicate: true when `T` names exactly the same type as at
/// least one entry of the pack `Ts`, false otherwise (including when `Ts` is
/// empty). Because it derives from `std::bool_constant`, a value-initialized
/// instance converts to `bool`, so `IsAny<T, A, B>()` can be used directly as
/// the condition of a `static_assert`.
template <class T, class... Ts>
struct IsAny : std::bool_constant<(std::is_same_v<T, Ts> || ...)> {};
61 size_t numDirections = nIns + nOuts;
62 APInt portDirections(numDirections, 0);
63 for (
size_t i = nIns, e = numDirections; i != e; ++i)
64 portDirections.setBit(i);
75 template <
typename CtrlOp>
79 PatternRewriter &rewriter)
const override {
80 auto &ops = ctrlOp.getBodyBlock()->getOperations();
82 (ops.size() == 1) && isa<EnableOp>(ops.front()) &&
83 isa<SeqOp, ParOp, StaticSeqOp, StaticParOp>(ctrlOp->getParentOp());
87 ops.front().moveBefore(ctrlOp);
88 rewriter.eraseOp(ctrlOp);
101 template <
typename Op>
103 Operation *definingOp = op.getSrc().getDefiningOp();
104 if (definingOp ==
nullptr)
110 if (
auto dialect = definingOp->getDialect(); isa<comb::CombDialect>(dialect))
111 return op->emitOpError(
"has source that is not a port or constant. "
112 "Complex logic should be conducted in the guard.");
119 static std::string
valueName(Operation *scopeOp, Value v) {
121 llvm::raw_string_ostream os(s);
126 AsmState asmState(scopeOp, OpPrintingFlags().assumeVerified());
127 v.printAsOperand(os, asmState);
134 Operation *definingOp =
value.getDefiningOp();
135 return value.isa<BlockArgument>() ||
136 (definingOp && isa<CellInterface>(definingOp));
141 Operation *op = arg.getOwner()->getParentOp();
142 assert(isa<ComponentInterface>(op) &&
143 "Only ComponentInterface should support lookup by BlockArgument.");
144 return cast<ComponentInterface>(op).getPortInfo()[arg.getArgNumber()];
149 return isa<ControlOp, SeqOp, IfOp, RepeatOp, WhileOp, ParOp, StaticRepeatOp,
150 StaticParOp, StaticSeqOp, StaticIfOp>(op);
155 if (isa<EnableOp>(op)) {
157 auto component = op->getParentOfType<ComponentOp>();
158 auto enableOp = llvm::cast<EnableOp>(op);
159 StringRef groupName = enableOp.getGroupName();
160 auto group = component.getWiresOp().lookupSymbol<GroupInterface>(groupName);
161 return isa<StaticGroupOp>(group);
163 return isa<StaticIfOp, StaticSeqOp, StaticRepeatOp, StaticParOp>(op);
168 if (isa<SeqOp, ParOp, StaticSeqOp, StaticParOp>(op))
173 for (
auto ®ion : op->getRegions()) {
174 auto opsIt = region.getOps();
175 size_t numOperations = std::distance(opsIt.begin(), opsIt.end());
180 bool usesEnableAsCompositionOperator =
181 numOperations > 1 && llvm::any_of(region.front(), [](
auto &&bodyOp) {
182 return isa<EnableOp>(bodyOp);
184 if (usesEnableAsCompositionOperator)
185 return op->emitOpError(
186 "EnableOp is not a composition operator. It should be nested "
187 "in a control flow operation, such as \"calyx.seq\"");
191 size_t numControlFlowRegions =
193 if (numControlFlowRegions > 1)
194 return op->emitOpError(
195 "has an invalid control sequence. Multiple control flow operations "
196 "must all be nested in a single calyx.seq or calyx.par");
202 auto *opParent = op->getParentOp();
203 if (!isa<ModuleOp>(opParent))
204 return op->emitOpError()
205 <<
"has parent: " << opParent <<
", expected ModuleOp.";
210 auto opParent = op->getParentOp();
211 if (!isa<ComponentInterface>(opParent))
212 return op->emitOpError()
213 <<
"has parent: " << opParent <<
", expected ComponentInterface.";
218 auto parent = op->getParentOp();
220 if (isa<calyx::EnableOp>(op) &&
221 !isa<calyx::CalyxDialect>(parent->getDialect())) {
230 return op->emitOpError()
231 <<
"has parent: " << parent
232 <<
", which is not allowed for a control-like operation.";
234 if (op->getNumRegions() == 0)
237 auto ®ion = op->getRegion(0);
239 auto isValidBodyOp = [](Operation *operation) {
240 return isa<EnableOp, InvokeOp, SeqOp, IfOp, RepeatOp, WhileOp, ParOp,
241 StaticParOp, StaticRepeatOp, StaticSeqOp, StaticIfOp>(operation);
243 for (
auto &&bodyOp : region.front()) {
244 if (isValidBodyOp(&bodyOp))
247 return op->emitOpError()
248 <<
"has operation: " << bodyOp.getName()
249 <<
", which is not allowed in this control-like operation";
255 auto ifOp = dyn_cast<IfInterface>(op);
257 if (ifOp.elseBodyExists() && ifOp.getElseBody()->empty())
258 return ifOp->emitOpError() <<
"empty 'else' region.";
268 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
269 OpAsmParser::UnresolvedOperand guardOrSource;
270 if (parser.parseOperand(guardOrSource))
273 if (succeeded(parser.parseOptionalQuestion())) {
274 OpAsmParser::UnresolvedOperand source;
276 if (parser.parseOperand(source))
278 operandInfos.push_back(source);
281 operandInfos.push_back(guardOrSource);
286 if (parser.parseColonType(type) ||
287 parser.resolveOperands(operandInfos, type, result.operands))
294 template <
typename GroupPortType>
296 static_assert(IsAny<GroupPortType, GroupGoOp, GroupDoneOp>(),
297 "Should be a Calyx Group port.");
301 Value guard = op.getGuard(), source = op.getSrc();
304 p << source <<
" : " << source.getType();
309 template <
typename OpTy>
311 PatternRewriter &rewriter) {
312 static_assert(IsAny<OpTy, SeqOp, ParOp, StaticSeqOp, StaticParOp>(),
313 "Should be a SeqOp, ParOp, StaticSeqOp, or StaticParOp");
315 if (isa<OpTy>(controlOp->getParentOp())) {
316 Block *controlBody = controlOp.getBodyBlock();
317 for (
auto &op : make_early_inc_range(*controlBody))
318 op.moveBefore(controlOp);
320 rewriter.eraseOp(controlOp);
327 template <
typename OpTy>
328 static LogicalResult
emptyControl(OpTy controlOp, PatternRewriter &rewriter) {
329 if (controlOp.getBodyBlock()->empty()) {
330 rewriter.eraseOp(controlOp);
339 template <
typename OpTy>
341 PatternRewriter &rewriter) {
342 static_assert(IsAny<OpTy, IfOp, WhileOp>(),
343 "This is only applicable to WhileOp and IfOp.");
346 Value cond = op.getCond();
347 std::optional<StringRef> groupName = op.getGroupName();
348 auto component = op->template getParentOfType<ComponentOp>();
349 rewriter.eraseOp(op);
353 auto group = component.getWiresOp().template lookupSymbol<GroupInterface>(
355 if (SymbolTable::symbolKnownUseEmpty(group, component.getRegion()))
356 rewriter.eraseOp(group);
359 if (!cond.isa<BlockArgument>() && cond.getDefiningOp()->use_empty())
360 rewriter.eraseOp(cond.getDefiningOp());
366 template <
typename OpTy>
368 static_assert(std::is_same<OpTy, StaticIfOp>(),
369 "This is only applicable to StatifIfOp.");
372 Value cond = op.getCond();
373 rewriter.eraseOp(op);
376 if (!cond.isa<BlockArgument>() && cond.getDefiningOp()->use_empty())
377 rewriter.eraseOp(cond.getDefiningOp());
384 template <
typename ComponentTy>
386 auto componentName = comp->template getAttrOfType<StringAttr>(
387 ::mlir::SymbolTable::getSymbolAttrName())
390 p.printSymbolName(componentName);
393 auto printPortDefList = [&](
auto ports) {
395 llvm::interleaveComma(ports, p, [&](
const PortInfo &port) {
396 p <<
"%" << port.
name.getValue() <<
": " << port.
type;
399 p.printAttributeWithoutType(port.attributes);
404 printPortDefList(comp.getInputPortInfo());
406 printPortDefList(comp.getOutputPortInfo());
409 p.printRegion(*comp.getRegion(),
false,
413 SmallVector<StringRef> elidedAttrs = {
418 ComponentTy::getFunctionTypeAttrName(comp->getName()),
419 ComponentTy::getArgAttrsAttrName(comp->getName()),
420 ComponentTy::getResAttrsAttrName(comp->getName())};
421 p.printOptionalAttrDict(comp->getAttrs(), elidedAttrs);
428 SmallVectorImpl<OpAsmParser::Argument> &ports,
429 SmallVectorImpl<Type> &portTypes,
430 SmallVectorImpl<NamedAttrList> &portAttrs) {
432 OpAsmParser::Argument port;
435 if (parser.parseArgument(port) || parser.parseColon() ||
436 parser.parseType(portType))
438 port.type = portType;
439 ports.push_back(port);
440 portTypes.push_back(portType);
442 NamedAttrList portAttr;
443 portAttrs.push_back(succeeded(parser.parseOptionalAttrDict(portAttr))
449 return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
456 SmallVectorImpl<OpAsmParser::Argument> &ports,
457 SmallVectorImpl<Type> &portTypes) {
458 SmallVector<OpAsmParser::Argument> inPorts, outPorts;
459 SmallVector<Type> inPortTypes, outPortTypes;
460 SmallVector<NamedAttrList> portAttributes;
465 if (parser.parseArrow() ||
469 auto *context = parser.getBuilder().getContext();
472 SmallVector<Attribute> portNames;
473 auto getPortName = [context](
const auto &port) -> StringAttr {
474 StringRef name = port.ssaName.name;
475 if (name.startswith(
"%"))
476 name = name.drop_front();
479 llvm::transform(inPorts, std::back_inserter(portNames), getPortName);
480 llvm::transform(outPorts, std::back_inserter(portNames), getPortName);
482 result.addAttribute(
"portNames",
ArrayAttr::get(context, portNames));
487 ports.append(inPorts);
488 ports.append(outPorts);
489 portTypes.append(inPortTypes);
490 portTypes.append(outPortTypes);
492 SmallVector<Attribute> portAttrs;
493 llvm::transform(portAttributes, std::back_inserter(portAttrs),
494 [&](
auto attr) {
return attr.getDictionary(context); });
495 result.addAttribute(
"portAttributes",
ArrayAttr::get(context, portAttrs));
500 template <
typename ComponentTy>
502 OperationState &result) {
503 using namespace mlir::function_interface_impl;
505 StringAttr componentName;
506 if (parser.parseSymbolName(componentName,
507 ::mlir::SymbolTable::getSymbolAttrName(),
511 SmallVector<mlir::OpAsmParser::Argument> ports;
513 SmallVector<Type> portTypes;
519 auto type = parser.getBuilder().getFunctionType(portTypes, {});
520 result.addAttribute(ComponentTy::getFunctionTypeAttrName(result.name),
523 auto *body = result.addRegion();
524 if (parser.parseRegion(*body, ports))
528 body->push_back(
new Block());
530 if (parser.parseOptionalAttrDict(result.attributes))
537 template <
typename T>
538 static SmallVector<T>
concat(
const SmallVectorImpl<T> &a,
539 const SmallVectorImpl<T> &b) {
547 StringAttr name, ArrayRef<PortInfo> ports,
548 bool combinational) {
549 using namespace mlir::function_interface_impl;
551 result.addAttribute(::mlir::SymbolTable::getSymbolAttrName(), name);
553 std::pair<SmallVector<Type, 8>, SmallVector<Type, 8>> portIOTypes;
554 std::pair<SmallVector<Attribute, 8>, SmallVector<Attribute, 8>> portIONames;
555 std::pair<SmallVector<Attribute, 8>, SmallVector<Attribute, 8>>
557 SmallVector<Direction, 8> portDirections;
560 for (
auto &&port : ports) {
562 (isInput ? portIOTypes.first : portIOTypes.second).push_back(port.type);
563 (isInput ? portIONames.first : portIONames.second).push_back(port.name);
564 (isInput ? portIOAttributes.first : portIOAttributes.second)
565 .push_back(port.attributes);
567 auto portTypes =
concat(portIOTypes.first, portIOTypes.second);
568 auto portNames =
concat(portIONames.first, portIONames.second);
569 auto portAttributes =
concat(portIOAttributes.first, portIOAttributes.second);
572 auto functionType =
builder.getFunctionType(portTypes, {});
574 result.addAttribute(CombComponentOp::getFunctionTypeAttrName(result.name),
577 result.addAttribute(ComponentOp::getFunctionTypeAttrName(result.name),
582 result.addAttribute(
"portNames",
builder.getArrayAttr(portNames));
583 result.addAttribute(
"portDirections",
585 portIOTypes.first.size(),
586 portIOTypes.second.size()));
588 result.addAttribute(
"portAttributes",
builder.getArrayAttr(portAttributes));
591 Region *region = result.addRegion();
592 Block *body =
new Block();
593 region->push_back(body);
596 body->addArguments(portTypes, SmallVector<Location, 4>(
597 portTypes.size(),
builder.getUnknownLoc()));
600 IRRewriter::InsertionGuard guard(
builder);
601 builder.setInsertionPointToStart(body);
602 builder.create<WiresOp>(result.location);
604 builder.create<ControlOp>(result.location);
615 template <
typename Op>
617 auto *body = op.getBodyBlock();
620 auto opIt = body->getOps<Op>().begin();
627 ArrayAttr portNames = op.getPortNames();
629 for (
size_t i = 0, e = portNames.size(); i != e; ++i) {
630 auto portName = portNames[i].cast<StringAttr>();
631 if (portName.getValue() == name)
632 return op.getBodyBlock()->getArgument(i);
637 WiresOp calyx::ComponentOp::getWiresOp() {
638 return getControlOrWiresFrom<WiresOp>(*
this);
641 ControlOp calyx::ComponentOp::getControlOp() {
642 return getControlOrWiresFrom<ControlOp>(*
this);
645 Value calyx::ComponentOp::getGoPort() {
649 Value calyx::ComponentOp::getDonePort() {
653 Value calyx::ComponentOp::getClkPort() {
657 Value calyx::ComponentOp::getResetPort() {
662 auto portTypes = getArgumentTypes();
663 ArrayAttr portNamesAttr = getPortNames(), portAttrs = getPortAttributes();
664 APInt portDirectionsAttr = getPortDirections();
666 SmallVector<PortInfo> results;
667 for (
size_t i = 0, e = portNamesAttr.size(); i != e; ++i) {
668 results.push_back(
PortInfo{portNamesAttr[i].cast<StringAttr>(),
671 portAttrs[i].cast<DictionaryAttr>()});
677 template <
typename Pred>
679 SmallVector<PortInfo> ports = op.getPortInfo();
680 llvm::erase_if(ports, p);
684 SmallVector<PortInfo> ComponentOp::getInputPortInfo() {
689 SmallVector<PortInfo> ComponentOp::getOutputPortInfo() {
694 void ComponentOp::print(OpAsmPrinter &p) {
695 printComponentInterface<ComponentOp>(p, *
this);
698 ParseResult ComponentOp::parse(OpAsmParser &parser, OperationState &result) {
699 return parseComponentInterface<ComponentOp>(parser, result);
705 llvm::SmallVector<StringRef, 4> identifiers;
706 for (
PortInfo &port : op.getPortInfo()) {
707 auto portIds = port.getAllIdentifiers();
708 identifiers.append(portIds.begin(), portIds.end());
711 std::sort(identifiers.begin(), identifiers.end());
714 interfacePorts{
"clk",
"done",
"go",
"reset"};
716 std::set_intersection(interfacePorts.begin(), interfacePorts.end(),
717 identifiers.begin(), identifiers.end(),
723 SmallVector<StringRef, 4> difference;
724 std::set_difference(interfacePorts.begin(), interfacePorts.end(),
726 std::back_inserter(difference));
727 return op->emitOpError()
728 <<
"is missing the following required port attribute identifiers: "
732 LogicalResult ComponentOp::verify() {
734 auto wIt = getBodyBlock()->getOps<WiresOp>();
735 auto cIt = getBodyBlock()->getOps<ControlOp>();
736 if (std::distance(wIt.begin(), wIt.end()) +
737 std::distance(cIt.begin(), cIt.end()) !=
739 return emitOpError() <<
"requires exactly one of each: '"
740 << WiresOp::getOperationName() <<
"', '"
741 << ControlOp::getOperationName() <<
"'.";
748 bool hasNoControlConstructs =
749 getControlOp().getBodyBlock()->getOperations().empty();
750 bool hasNoAssignments =
751 getWiresOp().getBodyBlock()->getOps<AssignOp>().
empty();
752 if (hasNoControlConstructs && hasNoAssignments)
754 "The component currently does nothing. It needs to either have "
755 "continuous assignments in the Wires region or control constructs in "
756 "the Control region.");
761 void ComponentOp::build(OpBuilder &
builder, OperationState &result,
762 StringAttr name, ArrayRef<PortInfo> ports) {
766 void ComponentOp::getAsmBlockArgumentNames(
770 auto ports = getPortNames();
771 auto *block = &getRegion()->front();
772 for (
size_t i = 0, e = block->getNumArguments(); i != e; ++i)
773 setNameFn(block->getArgument(i), ports[i].cast<StringAttr>().getValue());
781 auto portTypes = getArgumentTypes();
782 ArrayAttr portNamesAttr = getPortNames(), portAttrs = getPortAttributes();
783 APInt portDirectionsAttr = getPortDirections();
785 SmallVector<PortInfo> results;
786 for (
size_t i = 0, e = portNamesAttr.size(); i != e; ++i) {
787 results.push_back(
PortInfo{portNamesAttr[i].cast<StringAttr>(),
790 portAttrs[i].cast<DictionaryAttr>()});
795 WiresOp calyx::CombComponentOp::getWiresOp() {
796 auto *body = getBodyBlock();
797 auto opIt = body->getOps<WiresOp>().begin();
802 template <
typename Pred>
804 SmallVector<PortInfo> ports = op.getPortInfo();
805 llvm::erase_if(ports, p);
809 SmallVector<PortInfo> CombComponentOp::getInputPortInfo() {
814 SmallVector<PortInfo> CombComponentOp::getOutputPortInfo() {
819 void CombComponentOp::print(OpAsmPrinter &p) {
820 printComponentInterface<CombComponentOp>(p, *
this);
823 ParseResult CombComponentOp::parse(OpAsmParser &parser,
824 OperationState &result) {
825 return parseComponentInterface<CombComponentOp>(parser, result);
828 LogicalResult CombComponentOp::verify() {
830 auto wIt = getBodyBlock()->getOps<WiresOp>();
831 if (std::distance(wIt.begin(), wIt.end()) != 1)
832 return emitOpError() <<
"requires exactly one "
833 << WiresOp::getOperationName() <<
" op.";
836 auto cIt = getBodyBlock()->getOps<ControlOp>();
837 if (std::distance(cIt.begin(), cIt.end()) != 0)
838 return emitOpError() <<
"must not have a `" << ControlOp::getOperationName()
842 bool hasNoAssignments =
843 getWiresOp().getBodyBlock()->getOps<AssignOp>().
empty();
844 if (hasNoAssignments)
846 "The component currently does nothing. It needs to either have "
847 "continuous assignments in the Wires region or control constructs in "
848 "the Control region.");
851 auto cells = getOps<CellInterface>();
852 for (
auto cell : cells) {
853 if (!cell.isCombinational())
854 return emitOpError() <<
"contains non-combinational cell "
855 << cell.instanceName();
859 auto groups = getWiresOp().getOps<GroupOp>();
861 return emitOpError() <<
"contains group " << (*groups.begin()).getSymName();
866 auto combGroups = getWiresOp().getOps<CombGroupOp>();
867 if (!combGroups.empty())
868 return emitOpError() <<
"contains comb group "
869 << (*combGroups.begin()).getSymName();
874 void CombComponentOp::build(OpBuilder &
builder, OperationState &result,
875 StringAttr name, ArrayRef<PortInfo> ports) {
879 void CombComponentOp::getAsmBlockArgumentNames(
883 auto ports = getPortNames();
884 auto *block = &getRegion()->front();
885 for (
size_t i = 0, e = block->getNumArguments(); i != e; ++i)
886 setNameFn(block->getArgument(i), ports[i].cast<StringAttr>().getValue());
895 SmallVector<InvokeOp, 4> ControlOp::getInvokeOps() {
896 SmallVector<InvokeOp, 4> ret;
897 this->walk([&](InvokeOp invokeOp) { ret.push_back(invokeOp); });
905 void SeqOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
906 MLIRContext *context) {
907 patterns.add(collapseControl<SeqOp>);
916 LogicalResult StaticSeqOp::verify() {
918 auto &ops = (*this).getBodyBlock()->getOperations();
919 if (!llvm::all_of(ops, [&](Operation &op) {
return isStaticControl(&op); })) {
920 return emitOpError(
"StaticSeqOp has non static control within it");
926 void StaticSeqOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
927 MLIRContext *context) {
928 patterns.add(collapseControl<StaticSeqOp>);
929 patterns.add(emptyControl<StaticSeqOp>);
937 LogicalResult ParOp::verify() {
942 for (EnableOp op : getBodyBlock()->getOps<EnableOp>()) {
943 StringRef groupName = op.getGroupName();
944 if (groupNames.count(groupName))
945 return emitOpError() <<
"cannot enable the same group: \"" << groupName
946 <<
"\" more than once.";
947 groupNames.insert(groupName);
953 void ParOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
954 MLIRContext *context) {
955 patterns.add(collapseControl<ParOp>);
964 LogicalResult StaticParOp::verify() {
969 for (EnableOp op : getBodyBlock()->getOps<EnableOp>()) {
970 StringRef groupName = op.getGroupName();
971 if (groupNames.count(groupName))
972 return emitOpError() <<
"cannot enable the same group: \"" << groupName
973 <<
"\" more than once.";
974 groupNames.insert(groupName);
978 auto &ops = (*this).getBodyBlock()->getOperations();
979 for (Operation &op : ops) {
981 return op.emitOpError(
"StaticParOp has non static control within it");
988 void StaticParOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
989 MLIRContext *context) {
990 patterns.add(collapseControl<StaticParOp>);
991 patterns.add(emptyControl<StaticParOp>);
998 LogicalResult WiresOp::verify() {
999 auto componentInterface = (*this)->getParentOfType<ComponentInterface>();
1000 if (llvm::isa<ComponentOp>(componentInterface)) {
1001 auto component = llvm::cast<ComponentOp>(componentInterface);
1002 auto control = component.getControlOp();
1005 for (
auto &&op : *getBodyBlock()) {
1006 if (!isa<GroupInterface>(op))
1008 auto group = cast<GroupInterface>(op);
1009 auto groupName = group.symName();
1010 if (mlir::SymbolTable::symbolKnownUseEmpty(groupName, control))
1011 return op.emitOpError()
1012 <<
"with name: " << groupName
1013 <<
" is unused in the control execution schedule";
1020 for (
auto thisAssignment : getBodyBlock()->getOps<AssignOp>()) {
1024 if (thisAssignment.getGuard())
1027 Value dest = thisAssignment.getDest();
1028 for (Operation *user : dest.getUsers()) {
1029 auto assignUser = dyn_cast<AssignOp>(user);
1030 if (!assignUser || assignUser.getDest() != dest ||
1031 assignUser == thisAssignment)
1034 return user->emitOpError() <<
"destination is already continuously "
1035 "driven. Other assignment is "
1049 Operation *definingOp =
value.getDefiningOp();
1050 if (definingOp ==
nullptr || definingOp->hasTrait<
Combinational>())
1056 if (isa<InstanceOp>(definingOp))
1060 if (isa<comb::CombDialect, hw::HWDialect>(definingOp->getDialect()))
1064 if (
auto r = dyn_cast<RegisterOp>(definingOp)) {
1065 return value == r.getOut()
1067 : group->emitOpError()
1068 <<
"with register: \"" << r.instanceName()
1069 <<
"\" is conducting a memory store. This is not "
1071 }
else if (
auto m = dyn_cast<MemoryOp>(definingOp)) {
1072 auto writePorts = {m.writeData(), m.writeEn()};
1073 return (llvm::none_of(writePorts, [&](Value p) {
return p ==
value; }))
1075 : group->emitOpError()
1076 <<
"with memory: \"" << m.instanceName()
1077 <<
"\" is conducting a memory store. This "
1078 "is not combinational.";
1081 std::string portName =
1083 return group->emitOpError() <<
"with port: " << portName
1084 <<
". This operation is not combinational.";
1089 LogicalResult CombGroupOp::verify() {
1090 for (
auto &&op : *getBodyBlock()) {
1091 auto assign = dyn_cast<AssignOp>(op);
1092 if (assign ==
nullptr)
1094 Value dst = assign.getDest(), src = assign.getSrc();
1105 GroupGoOp GroupOp::getGoOp() {
1106 auto goOps = getBodyBlock()->getOps<GroupGoOp>();
1107 size_t nOps = std::distance(goOps.begin(), goOps.end());
1108 return nOps ? *goOps.begin() : GroupGoOp();
1111 GroupDoneOp GroupOp::getDoneOp() {
1112 auto body = this->getBodyBlock();
1113 return cast<GroupDoneOp>(body->getTerminator());
1119 void CycleOp::print(OpAsmPrinter &p) {
1122 auto start = this->getStart();
1123 auto end = this->getEnd();
1124 if (
end.has_value()) {
1125 p <<
"[" << start <<
":" <<
end.value() <<
"]";
1131 ParseResult CycleOp::parse(OpAsmParser &parser, OperationState &result) {
1132 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
1134 uint32_t startLiteral;
1135 uint32_t endLiteral;
1137 auto hasEnd = succeeded(parser.parseOptionalLSquare());
1139 if (parser.parseInteger(startLiteral)) {
1140 parser.emitError(parser.getNameLoc(),
"Could not parse start cycle");
1144 auto start = parser.getBuilder().getI32IntegerAttr(startLiteral);
1145 result.addAttribute(getStartAttrName(result.name), start);
1148 if (parser.parseColon())
1151 if (
auto res = parser.parseOptionalInteger(endLiteral); res.has_value()) {
1152 auto end = parser.getBuilder().getI32IntegerAttr(endLiteral);
1153 result.addAttribute(getEndAttrName(result.name), end);
1156 if (parser.parseRSquare())
1160 result.addTypes(parser.getBuilder().getI1Type());
1165 LogicalResult CycleOp::verify() {
1166 uint32_t latency = this->getGroupLatency();
1168 if (this->getStart() >= latency) {
1169 emitOpError(
"start cycle must be less than the group latency");
1173 if (this->getEnd().has_value()) {
1174 if (this->getStart() >= this->getEnd().
value()) {
1175 emitOpError(
"start cycle must be less than end cycle");
1179 if (this->getEnd() >= latency) {
1180 emitOpError(
"end cycle must be less than the group latency");
1188 uint32_t CycleOp::getGroupLatency() {
1189 auto group = (*this)->getParentOfType<StaticGroupOp>();
1190 return group.getLatency();
1201 return llvm::any_of(port.getUses(), [&](
auto &&use) {
1202 auto assignOp = dyn_cast<AssignOp>(use.getOwner());
1203 if (assignOp == nullptr)
1206 Operation *parent = assignOp->getParentOp();
1207 if (isa<WiresOp>(parent))
1216 Value expected = isDriven ? assignOp.getDest() : assignOp.getSrc();
1217 return expected == port && group == parent;
1229 if (
auto cell = dyn_cast<CellInterface>(port.getDefiningOp());
1231 return groupOp.drivesAnyPort(cell.getInputPorts());
1236 LogicalResult GroupOp::drivesPort(Value port) {
1240 LogicalResult CombGroupOp::drivesPort(Value port) {
1244 LogicalResult StaticGroupOp::drivesPort(Value port) {
1251 return success(llvm::all_of(ports, [&](Value port) {
1256 LogicalResult GroupOp::drivesAllPorts(ValueRange ports) {
1260 LogicalResult CombGroupOp::drivesAllPorts(ValueRange ports) {
1264 LogicalResult StaticGroupOp::drivesAllPorts(ValueRange ports) {
1271 return success(llvm::any_of(ports, [&](Value port) {
1276 LogicalResult GroupOp::drivesAnyPort(ValueRange ports) {
1280 LogicalResult CombGroupOp::drivesAnyPort(ValueRange ports) {
1284 LogicalResult StaticGroupOp::drivesAnyPort(ValueRange ports) {
1291 return success(llvm::any_of(ports, [&](Value port) {
1296 LogicalResult GroupOp::readsAnyPort(ValueRange ports) {
1300 LogicalResult CombGroupOp::readsAnyPort(ValueRange ports) {
1304 LogicalResult StaticGroupOp::readsAnyPort(ValueRange ports) {
1311 GroupInterface group) {
1312 Operation *destDefiningOp = assign.getDest().getDefiningOp();
1313 if (destDefiningOp ==
nullptr)
1315 auto destCell = dyn_cast<CellInterface>(destDefiningOp);
1316 if (destCell ==
nullptr)
1319 LogicalResult verifyWrites =
1320 TypeSwitch<Operation *, LogicalResult>(destCell)
1321 .Case<RegisterOp>([&](
auto op) {
1324 return succeeded(group.drivesAnyPort({op.getWriteEn(), op.getIn()}))
1325 ? group.drivesAllPorts({op.getWriteEn(), op.getIn()})
1328 .Case<MemoryOp>([&](
auto op) {
1329 SmallVector<Value> requiredWritePorts;
1332 requiredWritePorts.push_back(op.writeEn());
1333 requiredWritePorts.push_back(op.writeData());
1334 for (Value address : op.addrPorts())
1335 requiredWritePorts.push_back(address);
1340 group.drivesAnyPort({op.writeData(), op.writeEn()}))
1341 ? group.drivesAllPorts(requiredWritePorts)
1344 .Case<AndLibOp, OrLibOp, XorLibOp, AddLibOp, SubLibOp, GtLibOp,
1345 LtLibOp, EqLibOp, NeqLibOp, GeLibOp, LeLibOp, LshLibOp,
1346 RshLibOp, SgtLibOp, SltLibOp, SeqLibOp, SneqLibOp, SgeLibOp,
1347 SleLibOp, SrshLibOp>([&](
auto op) {
1348 Value lhs = op.getLeft(), rhs = op.getRight();
1349 return succeeded(group.drivesAnyPort({lhs, rhs}))
1350 ? group.drivesAllPorts({lhs, rhs})
1353 .Default([&](
auto op) {
return success(); });
1355 if (failed(verifyWrites))
1356 return group->emitOpError()
1357 <<
"with cell: " << destCell->getName() <<
" \""
1358 << destCell.instanceName()
1359 <<
"\" is performing a write and failed to drive all necessary "
1362 Operation *srcDefiningOp = assign.getSrc().getDefiningOp();
1363 if (srcDefiningOp ==
nullptr)
1365 auto srcCell = dyn_cast<CellInterface>(srcDefiningOp);
1366 if (srcCell ==
nullptr)
1369 LogicalResult verifyReads =
1370 TypeSwitch<Operation *, LogicalResult>(srcCell)
1371 .Case<MemoryOp>([&](
auto op) {
1375 return succeeded(group.readsAnyPort({op.readData()}))
1376 ? group.drivesAllPorts(op.addrPorts())
1379 .Default([&](
auto op) {
return success(); });
1381 if (failed(verifyReads))
1382 return group->emitOpError() <<
"with cell: " << srcCell->getName() <<
" \""
1383 << srcCell.instanceName()
1384 <<
"\" is having a read performed upon it, and "
1385 "failed to drive all necessary ports.";
1391 auto group = dyn_cast<GroupInterface>(op);
1392 if (group ==
nullptr)
1395 for (
auto &&groupOp : *group.getBody()) {
1396 auto assign = dyn_cast<AssignOp>(groupOp);
1397 if (assign ==
nullptr)
1413 ArrayRef<StringRef> portNames) {
1414 auto cellInterface = dyn_cast<CellInterface>(op);
1415 assert(cellInterface &&
"must implement the Cell interface");
1417 std::string prefix = cellInterface.instanceName().str() +
".";
1418 for (
size_t i = 0, e = portNames.size(); i != e; ++i)
1419 setNameFn(op->getResult(i), prefix + portNames[i].str());
1430 bool isDestination) {
1431 Operation *definingOp =
value.getDefiningOp();
1432 bool isComponentPort =
value.isa<BlockArgument>(),
1433 isCellInterfacePort = definingOp && isa<CellInterface>(definingOp);
1434 assert((isComponentPort || isCellInterfacePort) &&
"Not a port.");
1438 : cast<CellInterface>(definingOp).portInfo(
value);
1440 bool isSource = !isDestination;
1443 (isDestination && isComponentPort) || (isSource && isCellInterfacePort)
1450 <<
"has a " << (isComponentPort ?
"component" :
"cell")
1452 << (isDestination ?
"destination" :
"source")
1453 <<
" with the incorrect direction.";
1460 bool isSource = !isDestination;
1461 Value
value = isDestination ? op.getDest() : op.getSrc();
1466 if (isDestination && !isa<GroupGoOp, GroupDoneOp>(
value.getDefiningOp()))
1467 return op->emitOpError(
1468 "has an invalid destination port. It must be drive-able.");
1475 LogicalResult AssignOp::verify() {
1476 bool isDestination =
true, isSource =
false;
1485 ParseResult AssignOp::parse(OpAsmParser &parser, OperationState &result) {
1486 OpAsmParser::UnresolvedOperand destination;
1487 if (parser.parseOperand(destination) || parser.parseEqual())
1493 OpAsmParser::UnresolvedOperand guardOrSource;
1494 if (parser.parseOperand(guardOrSource))
1499 OpAsmParser::UnresolvedOperand source;
1500 bool hasGuard = succeeded(parser.parseOptionalQuestion());
1503 if (parser.parseOperand(source))
1508 if (parser.parseColonType(type) ||
1509 parser.resolveOperand(destination, type, result.operands))
1513 Type i1Type = parser.getBuilder().getI1Type();
1516 if (parser.resolveOperand(source, type, result.operands) ||
1517 parser.resolveOperand(guardOrSource, i1Type, result.operands))
1521 if (parser.resolveOperand(guardOrSource, type, result.operands))
1528 void AssignOp::print(OpAsmPrinter &p) {
1529 p <<
" " << getDest() <<
" = ";
1531 Value bguard = getGuard(), source = getSrc();
1534 p << bguard <<
" ? ";
1538 p << source <<
" : " << source.getType();
1547 ComponentInterface InstanceOp::getReferencedComponent() {
1548 auto module = (*this)->getParentOfType<ModuleOp>();
1552 return module.lookupSymbol<ComponentInterface>(getComponentName());
1558 static LogicalResult
1560 ComponentInterface referencedComponent) {
1561 auto module = instance->getParentOfType<ModuleOp>();
1562 StringRef entryPointName =
1563 module->getAttrOfType<StringAttr>(
"calyx.entrypoint");
1564 if (instance.getComponentName() == entryPointName)
1565 return instance.emitOpError()
1566 <<
"cannot reference the entry-point component: '" << entryPointName
1570 SmallVector<PortInfo> componentPorts = referencedComponent.getPortInfo();
1571 size_t numPorts = componentPorts.size();
1573 size_t numResults = instance.getNumResults();
1574 if (numResults != numPorts)
1575 return instance.emitOpError()
1576 <<
"has a wrong number of results; expected: " << numPorts
1577 <<
" but got " << numResults;
1579 for (
size_t i = 0; i != numResults; ++i) {
1580 auto resultType = instance.getResult(i).getType();
1581 auto expectedType = componentPorts[i].type;
1582 if (resultType == expectedType)
1584 return instance.emitOpError()
1585 <<
"result type for " << componentPorts[i].name <<
" must be "
1586 << expectedType <<
", but got " << resultType;
1591 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1592 Operation *op = *
this;
1593 auto module = op->getParentOfType<ModuleOp>();
1594 Operation *referencedComponent =
1595 symbolTable.lookupNearestSymbolFrom(module, getComponentNameAttr());
1596 if (referencedComponent ==
nullptr)
1597 return emitError() <<
"referencing component: '" << getComponentName()
1598 <<
"', which does not exist.";
1600 Operation *shadowedComponentName =
1601 symbolTable.lookupNearestSymbolFrom(module, getSymNameAttr());
1602 if (shadowedComponentName !=
nullptr)
1603 return emitError() <<
"instance symbol: '" << instanceName()
1604 <<
"' is already a symbol for another component.";
1607 auto parentComponent = op->getParentOfType<ComponentOp>();
1608 if (parentComponent == referencedComponent)
1609 return emitError() <<
"recursive instantiation of its parent component: '"
1610 << getComponentName() <<
"'";
1612 assert(isa<ComponentInterface>(referencedComponent) &&
1613 "Should be a ComponentInterface.");
1615 cast<ComponentInterface>(referencedComponent));
1623 SmallVector<StringRef> InstanceOp::portNames() {
1624 SmallVector<StringRef> portNames;
1625 for (Attribute name : getReferencedComponent().getPortNames())
1626 portNames.push_back(name.cast<StringAttr>().getValue());
1630 SmallVector<Direction> InstanceOp::portDirections() {
1631 SmallVector<Direction> portDirections;
1633 portDirections.push_back(port.direction);
1634 return portDirections;
1637 SmallVector<DictionaryAttr> InstanceOp::portAttributes() {
1638 SmallVector<DictionaryAttr> portAttributes;
1640 portAttributes.push_back(port.attributes);
1641 return portAttributes;
1645 return isa<CombComponentOp>(getReferencedComponent());
1655 auto module = (*this)->getParentOfType<ModuleOp>();
1665 static LogicalResult
1668 auto module = instance->getParentOfType<ModuleOp>();
1669 StringRef entryPointName =
1670 module->getAttrOfType<StringAttr>(
"calyx.entrypoint");
1671 if (instance.getPrimitiveName() == entryPointName)
1672 return instance.emitOpError()
1673 <<
"cannot reference the entry-point component: '" << entryPointName
1677 auto primitivePorts = referencedPrimitive.getPortList();
1678 size_t numPorts = primitivePorts.size();
1680 size_t numResults = instance.getNumResults();
1681 if (numResults != numPorts)
1682 return instance.emitOpError()
1683 <<
"has a wrong number of results; expected: " << numPorts
1684 <<
" but got " << numResults;
1687 ArrayAttr modParameters = referencedPrimitive.getParameters();
1688 ArrayAttr parameters = instance.getParameters().value_or(ArrayAttr());
1689 size_t numExpected = modParameters.size();
1690 size_t numParams = parameters.size();
1691 if (numParams != numExpected)
1692 return instance.emitOpError()
1693 <<
"has the wrong number of parameters; expected: " << numExpected
1694 <<
" but got " << numParams;
1696 for (
size_t i = 0; i != numExpected; ++i) {
1697 auto param = parameters[i].cast<circt::hw::ParamDeclAttr>();
1698 auto modParam = modParameters[i].cast<circt::hw::ParamDeclAttr>();
1700 auto paramName = param.getName();
1701 if (paramName != modParam.getName())
1702 return instance.emitOpError()
1703 <<
"parameter #" << i <<
" should have name " << modParam.getName()
1704 <<
" but has name " << paramName;
1706 if (param.getType() != modParam.getType())
1707 return instance.emitOpError()
1708 <<
"parameter " << paramName <<
" should have type "
1709 << modParam.getType() <<
" but has type " << param.getType();
1713 if (!param.getValue())
1714 return instance.emitOpError(
"parameter ")
1715 << paramName <<
" must have a value";
1718 for (
size_t i = 0; i != numResults; ++i) {
1719 auto resultType = instance.getResult(i).getType();
1720 auto expectedType = primitivePorts[i].
type;
1722 instance.getLoc(), instance.getParametersAttr(), expectedType);
1723 if (failed(replacedType))
1725 if (resultType == replacedType)
1727 return instance.emitOpError()
1728 <<
"result type for " << primitivePorts[i].name <<
" must be "
1729 << expectedType <<
", but got " << resultType;
1735 PrimitiveOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1736 Operation *op = *
this;
1737 auto module = op->getParentOfType<ModuleOp>();
1738 Operation *referencedPrimitive =
1739 symbolTable.lookupNearestSymbolFrom(module, getPrimitiveNameAttr());
1740 if (referencedPrimitive ==
nullptr)
1741 return emitError() <<
"referencing primitive: '" << getPrimitiveName()
1742 <<
"', which does not exist.";
1744 Operation *shadowedPrimitiveName =
1745 symbolTable.lookupNearestSymbolFrom(module, getSymNameAttr());
1746 if (shadowedPrimitiveName !=
nullptr)
1747 return emitError() <<
"instance symbol: '" << instanceName()
1748 <<
"' is already a symbol for another primitive.";
1752 if (parentPrimitive == referencedPrimitive)
1753 return emitError() <<
"recursive instantiation of its parent primitive: '"
1754 << getPrimitiveName() <<
"'";
1756 assert(isa<hw::HWModuleExternOp>(referencedPrimitive) &&
1757 "Should be a HardwareModuleExternOp.");
1760 cast<hw::HWModuleExternOp>(referencedPrimitive));
1768 SmallVector<StringRef> PrimitiveOp::portNames() {
1769 SmallVector<StringRef> portNames;
1770 auto ports = getReferencedPrimitive().getPortList();
1771 for (
auto port : ports)
1772 portNames.push_back(port.name.getValue());
1778 switch (direction) {
1784 llvm_unreachable(
"InOut ports not supported by Calyx");
1786 llvm_unreachable(
"Impossible port type");
1789 SmallVector<Direction> PrimitiveOp::portDirections() {
1790 SmallVector<Direction> portDirections;
1791 auto ports = getReferencedPrimitive().getPortList();
1792 for (hw::PortInfo port : ports)
1794 return portDirections;
1803 DictionaryAttr dict) {
1807 llvm::SmallVector<NamedAttribute> attrs;
1808 for (NamedAttribute attr : dict) {
1809 Dialect *dialect = attr.getNameDialect();
1810 if (dialect ==
nullptr || !isa<CalyxDialect>(*dialect))
1812 StringRef name = attr.getName().strref();
1813 StringAttr newName =
builder.getStringAttr(std::get<1>(name.split(
".")));
1814 attr.setName(newName);
1815 attrs.push_back(attr);
1817 return builder.getDictionaryAttr(attrs);
1821 SmallVector<DictionaryAttr> PrimitiveOp::portAttributes() {
1822 SmallVector<DictionaryAttr> portAttributes;
1823 OpBuilder
builder(getContext());
1825 auto argAttrs = prim.getAllInputAttrs();
1826 auto resAttrs = prim.getAllOutputAttrs();
1827 for (
auto a : argAttrs)
1828 portAttributes.push_back(
1830 for (
auto a : resAttrs)
1831 portAttributes.push_back(
1833 return portAttributes;
1842 SmallVector<Attribute> ¶meters) {
1844 return parser.parseCommaSeparatedList(
1845 OpAsmParser::Delimiter::OptionalLessGreater, [&]() {
1850 if (parser.parseKeywordOrString(&name) || parser.parseColonType(type))
1854 if (succeeded(parser.parseOptionalEqual())) {
1855 if (parser.parseAttribute(value, type))
1859 auto &
builder = parser.getBuilder();
1868 ArrayAttr ¶meters) {
1869 SmallVector<Attribute> parseParameters;
1873 parameters =
ArrayAttr::get(parser.getContext(), parseParameters);
1880 ArrayAttr parameters) {
1881 if (parameters.empty())
1885 llvm::interleaveComma(parameters, p, [&](Attribute param) {
1886 auto paramAttr = param.cast<hw::ParamDeclAttr>();
1887 p << paramAttr.getName().getValue() <<
": " << paramAttr.getType();
1888 if (
auto value = paramAttr.getValue()) {
1890 p.printAttributeWithoutType(
value);
1896 //===----------------------------------------------------------------------===//
1898 //===----------------------------------------------------------------------===//
1900 LogicalResult GroupGoOp::verify() { return verifyNotComplexSource(*this); }
1903 void GroupGoOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
1904 auto parent = (*this)->getParentOfType<GroupOp>();
1905 StringRef name = parent.getSymName();
1906 std::string resultName = name.str() + ".go";
1907 setNameFn(getResult(), resultName);
1910 void GroupGoOp::print(OpAsmPrinter &p) { printGroupPort(p, *this); }
1912 ParseResult GroupGoOp::parse(OpAsmParser &parser, OperationState &result) {
1913 if (parseGroupPort(parser, result))
1916 result.addTypes(parser.getBuilder().getI1Type());
1920 //===----------------------------------------------------------------------===//
1922 //===----------------------------------------------------------------------===//
1924 LogicalResult GroupDoneOp::verify() {
1925 Operation *srcOp = getSrc().getDefiningOp();
1926 Value optionalGuard = getGuard();
1927 Operation *guardOp = optionalGuard ? optionalGuard.getDefiningOp() : nullptr;
1928 bool noGuard = (guardOp == nullptr);
1930 if (srcOp == nullptr)
1931 // This is a port of the parent component.
1934 if (isa<hw::ConstantOp>(srcOp) && (noGuard || isa<hw::ConstantOp>(guardOp)))
1935 return emitOpError() << "with constant source"
1936 << (noGuard ? "" : " and constant guard")
1937 << ". This should be a combinational group.";
1939 return verifyNotComplexSource(*this);
1942 void GroupDoneOp::print(OpAsmPrinter &p) { printGroupPort(p, *this); }
1944 ParseResult GroupDoneOp::parse(OpAsmParser &parser, OperationState &result) {
1945 return parseGroupPort(parser, result);
1948 //===----------------------------------------------------------------------===//
1950 //===----------------------------------------------------------------------===//
1953 void RegisterOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
1954 getCellAsmResultNames(setNameFn, *this, this->portNames());
1957 SmallVector<StringRef> RegisterOp::portNames() {
1958 return {"in", "write_en", "clk", "reset", "out", "done"};
1961 SmallVector<Direction> RegisterOp::portDirections() {
1962 return {Input, Input, Input, Input, Output, Output};
1965 SmallVector<DictionaryAttr> RegisterOp::portAttributes() {
1966 MLIRContext *context = getContext();
1967 IntegerAttr isSet = IntegerAttr::get(IntegerType::get(context, 1), 1);
1968 NamedAttrList writeEn, clk, reset, done;
1969 writeEn.append("go", isSet);
1970 clk.append("clk", isSet);
1971 reset.append("reset", isSet);
1972 done.append("done", isSet);
1974 DictionaryAttr::get(context), // In
1975 writeEn.getDictionary(context), // Write enable
1976 clk.getDictionary(context), // Clk
1977 reset.getDictionary(context), // Reset
1978 DictionaryAttr::get(context), // Out
1979 done.getDictionary(context) // Done
1983 bool RegisterOp::isCombinational() { return false; }
1985 //===----------------------------------------------------------------------===//
1987 //===----------------------------------------------------------------------===//
1990 void MemoryOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
1991 getCellAsmResultNames(setNameFn, *this, this->portNames());
1994 SmallVector<StringRef> MemoryOp::portNames() {
1995 SmallVector<StringRef> portNames;
1996 for (size_t i = 0, e = getAddrSizes().size(); i != e; ++i) {
1998 StringAttr::get(this->getContext(), "addr" + std::to_string(i));
1999 portNames.push_back(nameAttr.getValue());
2001 portNames.append({"write_data", "write_en", "clk", "read_data", "done"});
2005 SmallVector<Direction> MemoryOp::portDirections() {
2006 SmallVector<Direction> portDirections;
2007 for (size_t i = 0, e = getAddrSizes().size(); i != e; ++i)
2008 portDirections.push_back(Input);
2009 portDirections.append({Input, Input, Input, Output, Output});
2010 return portDirections;
2013 SmallVector<DictionaryAttr> MemoryOp::portAttributes() {
2014 SmallVector<DictionaryAttr> portAttributes;
2015 MLIRContext *context = getContext();
2016 for (size_t i = 0, e = getAddrSizes().size(); i != e; ++i)
2017 portAttributes.push_back(DictionaryAttr::get(context)); // Addresses
2019 // Use a boolean to indicate this attribute is used.
2020 IntegerAttr isSet = IntegerAttr::get(IntegerType::get(context, 1), 1);
2021 NamedAttrList writeEn, clk, reset, done;
2022 writeEn.append("go", isSet);
2023 clk.append("clk", isSet);
2024 done.append("done", isSet);
2025 portAttributes.append({DictionaryAttr::get(context), // In
2026 writeEn.getDictionary(context), // Write enable
2027 clk.getDictionary(context), // Clk
2028 DictionaryAttr::get(context), // Out
2029 done.getDictionary(context)} // Done
2031 return portAttributes;
2034 void MemoryOp::build(OpBuilder &builder, OperationState &state,
2035 StringRef instanceName, int64_t width,
2036 ArrayRef<int64_t> sizes, ArrayRef<int64_t> addrSizes) {
2037 state.addAttribute(SymbolTable::getSymbolAttrName(),
2038 builder.getStringAttr(instanceName));
2039 state.addAttribute("width", builder.getI64IntegerAttr(width));
2040 state.addAttribute("sizes", builder.getI64ArrayAttr(sizes));
2041 state.addAttribute("addrSizes", builder.getI64ArrayAttr(addrSizes));
2042 SmallVector<Type> types;
2043 for (int64_t size : addrSizes)
2044 types.push_back(builder.getIntegerType(size)); // Addresses
2045 types.push_back(builder.getIntegerType(width)); // Write data
2046 types.push_back(builder.getI1Type()); // Write enable
2047 types.push_back(builder.getI1Type()); // Clk
2048 types.push_back(builder.getIntegerType(width)); // Read data
2049 types.push_back(builder.getI1Type()); // Done
2050 state.addTypes(types);
2053 LogicalResult MemoryOp::verify() {
2054 ArrayRef<Attribute> opSizes = getSizes().getValue();
2055 ArrayRef<Attribute> opAddrSizes = getAddrSizes().getValue();
2056 size_t numDims = getSizes().size();
2057 size_t numAddrs = getAddrSizes().size();
2058 if (numDims != numAddrs)
2059 return emitOpError("mismatched number of dimensions (")
2060 << numDims << ") and address sizes (" << numAddrs << ")";
2062 size_t numExtraPorts = 5; // write data/enable, clk, and read data/done.
2063 if (getNumResults() != numAddrs + numExtraPorts)
2064 return emitOpError("incorrect number of address ports, expected ")
2067 for (size_t i = 0; i < numDims; ++i) {
2068 int64_t size = opSizes[i].cast<IntegerAttr>().getInt();
2069 int64_t addrSize = opAddrSizes[i].cast<IntegerAttr>().getInt();
2070 if (llvm::Log2_64_Ceil(size) > addrSize)
2071 return emitOpError("address size (")
2072 << addrSize << ") for dimension " << i
2073 << " can't address the entire range (
" << size << ")
";
2079 //===----------------------------------------------------------------------===//
2081 //===----------------------------------------------------------------------===//
2084 void SeqMemoryOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
2085 getCellAsmResultNames(setNameFn, *this, this->portNames());
2088 SmallVector<StringRef> SeqMemoryOp::portNames() {
2089 SmallVector<StringRef> portNames;
2090 for (size_t i = 0, e = getAddrSizes().size(); i != e; ++i) {
2092 StringAttr::get(this->getContext(), "addr
" + std::to_string(i));
2093 portNames.push_back(nameAttr.getValue());
2095 portNames.append({"write_data
", "write_en
", "write_done
", "clk
", "read_data
",
2096 "read_en
", "read_done
"});
2100 SmallVector<Direction> SeqMemoryOp::portDirections() {
2101 SmallVector<Direction> portDirections;
2102 for (size_t i = 0, e = getAddrSizes().size(); i != e; ++i)
2103 portDirections.push_back(Input);
2104 portDirections.append({Input, Input, Output, Input, Output, Input, Output});
2105 return portDirections;
2108 SmallVector<DictionaryAttr> SeqMemoryOp::portAttributes() {
2109 SmallVector<DictionaryAttr> portAttributes;
2110 MLIRContext *context = getContext();
2111 for (size_t i = 0, e = getAddrSizes().size(); i != e; ++i)
2112 portAttributes.push_back(DictionaryAttr::get(context)); // Addresses
2114 OpBuilder builder(context);
2115 // Use a boolean to indicate this attribute is used.
2116 IntegerAttr isSet = IntegerAttr::get(builder.getIndexType(), 1);
2117 IntegerAttr isTwo = IntegerAttr::get(builder.getIndexType(), 2);
2118 NamedAttrList writeEn, writeDone, clk, reset, readEn, readDone;
2119 writeEn.append("go
", isSet);
2120 writeDone.append("done
", isSet);
2121 clk.append("clk
", isSet);
2122 readEn.append("go
", isTwo);
2123 readDone.append("done
", isTwo);
2124 portAttributes.append({DictionaryAttr::get(context), // Write Data
2125 writeEn.getDictionary(context), // Write enable
2126 writeDone.getDictionary(context), // Write done
2127 clk.getDictionary(context), // Clk
2128 DictionaryAttr::get(context), // Out
2129 readEn.getDictionary(context), // Read enable
2130 readDone.getDictionary(context)} // Read done
2132 return portAttributes;
2135 void SeqMemoryOp::build(OpBuilder &builder, OperationState &state,
2136 StringRef instanceName, int64_t width,
2137 ArrayRef<int64_t> sizes, ArrayRef<int64_t> addrSizes) {
2138 state.addAttribute(SymbolTable::getSymbolAttrName(),
2139 builder.getStringAttr(instanceName));
2140 state.addAttribute("width", builder.getI64IntegerAttr(width));
2141 state.addAttribute("sizes
", builder.getI64ArrayAttr(sizes));
2142 state.addAttribute("addrSizes
", builder.getI64ArrayAttr(addrSizes));
2143 SmallVector<Type> types;
2144 for (int64_t size : addrSizes)
2145 types.push_back(builder.getIntegerType(size)); // Addresses
2146 types.push_back(builder.getIntegerType(width)); // Write data
2147 types.push_back(builder.getI1Type()); // Write enable
2148 types.push_back(builder.getI1Type()); // Write done
2149 types.push_back(builder.getI1Type()); // Clk
2150 types.push_back(builder.getIntegerType(width)); // Read data
2151 types.push_back(builder.getI1Type()); // Read enable
2152 types.push_back(builder.getI1Type()); // Read done
2153 state.addTypes(types);
2156 LogicalResult SeqMemoryOp::verify() {
2157 ArrayRef<Attribute> opSizes = getSizes().getValue();
2158 ArrayRef<Attribute> opAddrSizes = getAddrSizes().getValue();
2159 size_t numDims = getSizes().size();
2160 size_t numAddrs = getAddrSizes().size();
2161 if (numDims != numAddrs)
2162 return emitOpError("mismatched number of dimensions (
")
2163 << numDims << ") and address sizes (
" << numAddrs << ")
";
2165 size_t numExtraPorts =
2166 7; // write data/enable/done, clk, and read data/enable/done.
2167 if (getNumResults() != numAddrs + numExtraPorts)
2168 return emitOpError("incorrect number of address ports, expected
")
2171 for (size_t i = 0; i < numDims; ++i) {
2172 int64_t size = opSizes[i].cast<IntegerAttr>().getInt();
2173 int64_t addrSize = opAddrSizes[i].cast<IntegerAttr>().getInt();
2174 if (llvm::Log2_64_Ceil(size) > addrSize)
2175 return emitOpError("address size (
")
2176 << addrSize << ")
for dimension
" << i
2177 << " can
't address the entire range (" << size << ")";
2183 //===----------------------------------------------------------------------===//
2185 //===----------------------------------------------------------------------===//
2186 LogicalResult EnableOp::verify() {
2187 auto component = (*this)->getParentOfType<ComponentOp>();
2188 auto wiresOp = component.getWiresOp();
2189 StringRef name = getGroupName();
2191 auto groupOp = wiresOp.lookupSymbol<GroupInterface>(name);
2193 return emitOpError() << "with group '" << name
2194 << "', which does not exist.";
2196 if (isa<CombGroupOp>(groupOp))
2197 return emitOpError() << "with group '" << name
2198 << "', which is a combinational group.";
2203 //===----------------------------------------------------------------------===//
2205 //===----------------------------------------------------------------------===//
2207 LogicalResult IfOp::verify() {
2208 std::optional<StringRef> optGroupName = getGroupName();
2209 if (!optGroupName) {
2210 // No combinational group was provided.
2213 auto component = (*this)->getParentOfType<ComponentOp>();
2214 WiresOp wiresOp = component.getWiresOp();
2215 StringRef groupName = *optGroupName;
2216 auto groupOp = wiresOp.lookupSymbol<GroupInterface>(groupName);
2218 return emitOpError() << "with group '" << groupName
2219 << "', which does not exist.";
2221 if (isa<GroupOp>(groupOp))
2222 return emitOpError() << "with group '" << groupName
2223 << "', which is not a combinational group.";
2225 if (failed(groupOp.drivesPort(getCond())))
2226 return emitError() << "with conditional op: '"
2227 << valueName(component, getCond())
2228 << "' expected to be driven from group: '" << groupName
2229 << "' but no driver was found.";
2237 template <typename OpTy>
2238 static std::optional<EnableOp> getLastEnableOp(OpTy parent) {
2239 static_assert(IsAny<OpTy, SeqOp, StaticSeqOp>(),
2240 "Should be a StaticSeqOp or SeqOp.");
2241 auto &lastOp = parent.getBodyBlock()->back();
2242 if (auto enableOp = dyn_cast<EnableOp>(lastOp))
2244 if (auto seqOp = dyn_cast<SeqOp>(lastOp))
2245 return getLastEnableOp(seqOp);
2246 if (auto staticSeqOp = dyn_cast<StaticSeqOp>(lastOp))
2247 return getLastEnableOp(staticSeqOp);
2249 return std::nullopt;
2254 template <typename OpTy>
2255 static llvm::StringMap<EnableOp> getAllEnableOpsInImmediateBody(OpTy parent) {
2256 static_assert(IsAny<OpTy, ParOp, StaticParOp>(),
2257 "Should be a StaticParOp or ParOp.");
2259 llvm::StringMap<EnableOp> enables;
2260 Block *body = parent.getBodyBlock();
2261 for (EnableOp op : body->getOps<EnableOp>())
2262 enables.insert(std::pair(op.getGroupName(), op));
2274 template <typename IfOpTy, typename TailOpTy>
2275 static bool hasCommonTailPatternPreConditions(IfOpTy op) {
2276 static_assert(IsAny<TailOpTy, SeqOp, ParOp, StaticSeqOp, StaticParOp>(),
2277 "Should be a SeqOp, ParOp, StaticSeqOp, or StaticParOp.");
2278 static_assert(IsAny<IfOpTy, IfOp, StaticIfOp>(),
2279 "Should be a IfOp or StaticIfOp.");
2281 if (!op.thenBodyExists() || !op.elseBodyExists())
2283 if (op.getThenBody()->empty() || op.getElseBody()->empty())
2286 Block *thenBody = op.getThenBody(), *elseBody = op.getElseBody();
2287 return isa<TailOpTy>(thenBody->front()) && isa<TailOpTy>(elseBody->front());
2298 template <typename IfOpTy, typename SeqOpTy>
2299 static LogicalResult commonTailPatternWithSeq(IfOpTy ifOp,
2300 PatternRewriter &rewriter) {
2301 static_assert(IsAny<IfOpTy, IfOp, StaticIfOp>(),
2302 "Should be an IfOp or StaticIfOp.");
2303 static_assert(IsAny<SeqOpTy, SeqOp, StaticSeqOp>(),
2304 "Branches should be checking for an SeqOp or StaticSeqOp");
2305 if (!hasCommonTailPatternPreConditions<IfOpTy, SeqOpTy>(ifOp))
2307 auto thenControl = cast<SeqOpTy>(ifOp.getThenBody()->front()),
2308 elseControl = cast<SeqOpTy>(ifOp.getElseBody()->front());
2310 std::optional<EnableOp> lastThenEnableOp = getLastEnableOp(thenControl),
2311 lastElseEnableOp = getLastEnableOp(elseControl);
2313 if (!lastThenEnableOp || !lastElseEnableOp)
2315 if (lastThenEnableOp->getGroupName() != lastElseEnableOp->getGroupName())
2318 // Place the IfOp and pulled EnableOp inside a sequential region, in case
2319 // this IfOp is nested in a ParOp. This avoids unintentionally
2320 // parallelizing the pulled out EnableOps.
2321 rewriter.setInsertionPointAfter(ifOp);
2322 SeqOpTy seqOp = rewriter.create<SeqOpTy>(ifOp.getLoc());
2323 Block *body = seqOp.getBodyBlock();
2325 body->push_back(ifOp);
2326 rewriter.setInsertionPointToEnd(body);
2327 rewriter.create<EnableOp>(seqOp.getLoc(), lastThenEnableOp->getGroupName());
2329 // Erase the common EnableOp from the Then and Else regions.
2330 rewriter.eraseOp(*lastThenEnableOp);
2331 rewriter.eraseOp(*lastElseEnableOp);
2348 template <typename OpTy, typename ParOpTy>
2349 static LogicalResult commonTailPatternWithPar(OpTy controlOp,
2350 PatternRewriter &rewriter) {
2351 static_assert(IsAny<OpTy, IfOp, StaticIfOp>(),
2352 "Should be an IfOp or StaticIfOp.");
2353 static_assert(IsAny<ParOpTy, ParOp, StaticParOp>(),
2354 "Branches should be checking for an ParOp or StaticParOp");
2355 if (!hasCommonTailPatternPreConditions<OpTy, ParOpTy>(controlOp))
2357 auto thenControl = cast<ParOpTy>(controlOp.getThenBody()->front()),
2358 elseControl = cast<ParOpTy>(controlOp.getElseBody()->front());
2360 llvm::StringMap<EnableOp> a = getAllEnableOpsInImmediateBody(thenControl),
2361 b = getAllEnableOpsInImmediateBody(elseControl);
2362 // Compute the intersection between `A` and `B`.
2363 SmallVector<StringRef> groupNames;
2364 for (auto aIndex = a.begin(); aIndex != a.end(); ++aIndex) {
2365 StringRef groupName = aIndex->getKey();
2366 auto bIndex = b.find(groupName);
2367 if (bIndex == b.end())
2369 // This is also an element in B.
2370 groupNames.push_back(groupName);
2371 // Since these are being pulled out, erase them.
2372 rewriter.eraseOp(aIndex->getValue());
2373 rewriter.eraseOp(bIndex->getValue());
2376 // Place the IfOp and EnableOp(s) inside a parallel region, in case this
2377 // IfOp is nested in a SeqOp. This avoids unintentionally sequentializing
2378 // the pulled out EnableOps.
2379 rewriter.setInsertionPointAfter(controlOp);
2381 ParOpTy parOp = rewriter.create<ParOpTy>(controlOp.getLoc());
2382 Block *body = parOp.getBodyBlock();
2383 controlOp->remove();
2384 body->push_back(controlOp);
2385 // Pull out the intersection between these two sets, and erase their
2386 // counterparts in the Then and Else regions.
2387 rewriter.setInsertionPointToEnd(body);
2388 for (StringRef groupName : groupNames)
2389 rewriter.create<EnableOp>(parOp.getLoc(), groupName);
2397 struct EmptyIfBody : mlir::OpRewritePattern<IfOp> {
2398 using mlir::OpRewritePattern<IfOp>::OpRewritePattern;
2399 LogicalResult matchAndRewrite(IfOp ifOp,
2400 PatternRewriter &rewriter) const override {
2401 if (!ifOp.getThenBody()->empty())
2403 if (ifOp.elseBodyExists() && !ifOp.getElseBody()->empty())
2406 eraseControlWithGroupAndConditional(ifOp, rewriter);
2412 void IfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2413 MLIRContext *context) {
2414 patterns.add<EmptyIfBody>(context);
2415 patterns.add(commonTailPatternWithPar<IfOp, ParOp>);
2416 patterns.add(commonTailPatternWithSeq<IfOp, SeqOp>);
2419 //===----------------------------------------------------------------------===//
2421 //===----------------------------------------------------------------------===//
// Verifies a StaticIfOp: every operation directly inside the else-body (when
// one exists) and the then-body must be static control (isStaticControl).
// The first offending operation is reported through emitOpError.
// NOTE(review): the closing braces and final return (embedded lines 2431-2434
// and 2442-2450) are missing from this listing. The else-branch message
// literal is split across lines by the extraction.
2422 LogicalResult StaticIfOp::verify() {
2423 if (elseBodyExists()) {
2424 auto *elseBod = getElseBody();
2425 auto &elseOps = elseBod->getOperations();
2426 // every operation in the else branch must be static control (the number of operations is not checked here)
2427 for (Operation &op : elseOps) {
2428 if (!isStaticControl(&op)) {
2429 return op.emitOpError(
2430 "static if's
else branch has non
static control within it
");
2435 auto *thenBod = getThenBody();
2436 auto &thenOps = thenBod->getOperations();
2437 for (Operation &op : thenOps) {
2438 // every operation in the then branch must be static control (the number of operations is not checked here)
2439 if (!isStaticControl(&op)) {
2440 return op.emitOpError(
2441 "static if's then branch has non static control within it");
// Canonicalization pattern for StaticIfOp. It mirrors EmptyIfBody: the guards
// test for a non-empty then-body and a non-empty existing else-body. Otherwise
// the op is removed through eraseControlWithConditional, which only decides
// whether the conditional must be erased (no group-aware helper is used here).
// NOTE(review): the guarded statements (presumably `return failure();`), the
// final return and the closing braces (embedded lines 2456, 2458-2459 and
// 2461-2465) are missing from this listing.
2451 struct EmptyStaticIfBody : mlir::OpRewritePattern<StaticIfOp> {
2452 using mlir::OpRewritePattern<StaticIfOp>::OpRewritePattern;
2453 LogicalResult matchAndRewrite(StaticIfOp ifOp,
2454 PatternRewriter &rewriter) const override {
2455 if (!ifOp.getThenBody()->empty())
2457 if (ifOp.elseBodyExists() && !ifOp.getElseBody()->empty())
2460 eraseControlWithConditional(ifOp, rewriter);
// Registers the StaticIfOp canonicalizations: EmptyStaticIfBody, plus the
// common-tail patterns instantiated for the static variants StaticParOp and
// StaticSeqOp.
// NOTE(review): the closing brace (embedded line 2471) is missing from this
// listing.
2466 void StaticIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2467 MLIRContext *context) {
2468 patterns.add<EmptyStaticIfBody>(context);
2469 patterns.add(commonTailPatternWithPar<StaticIfOp, StaticParOp>);
2470 patterns.add(commonTailPatternWithSeq<StaticIfOp, StaticSeqOp>);
2473 //===----------------------------------------------------------------------===//
2475 //===----------------------------------------------------------------------===//
// Verifies a WhileOp against the enclosing ComponentOp's wires.
// - The case where no group name is given is handled by the `!optGroupName`
//   branch, whose body is missing from this listing.
// - Otherwise the named group must be found among the wires
//   ("does not exist" error).
// - The group must not be a GroupOp, i.e. it must be a combinational group.
// - The group must drive the value returned by getCond().
// NOTE(review): embedded lines 2479, 2482-2484, 2487 (presumably the
// `if (!groupOp)` guard of the "does not exist" error), 2490 and 2499-2502 are
// missing from this listing. One message literal is split across lines.
// NOTE(review): the final check reports via emitError() while the others use
// emitOpError(); confirm whether this inconsistency is intended.
2476 LogicalResult WhileOp::verify() {
2477 auto component = (*this)->getParentOfType<ComponentOp>();
2478 auto wiresOp = component.getWiresOp();
2480 std::optional<StringRef> optGroupName = getGroupName();
2481 if (!optGroupName) {
2485 StringRef groupName = *optGroupName;
2486 auto groupOp = wiresOp.lookupSymbol<GroupInterface>(groupName);
2488 return emitOpError() << "with group '
" << groupName
2489 << "', which does not exist.";
2491 if (isa<GroupOp>(groupOp))
2492 return emitOpError() << "with group '" << groupName
2493 << "', which is not a combinational group.";
2495 if (failed(groupOp.drivesPort(getCond())))
2496 return emitError() << "conditional op: '" << valueName(component, getCond())
2497 << "' expected to be driven from group: '" << groupName
2498 << "' but no driver was found.";
// Canonicalizes a WhileOp whose body block is empty by erasing it through
// eraseControlWithGroupAndConditional, which also decides whether the
// conditional and its group must be erased.
// NOTE(review): the return statements and closing braces (embedded lines
// 2507-2512) are missing from this listing.
2503 LogicalResult WhileOp::canonicalize(WhileOp whileOp,
2504 PatternRewriter &rewriter) {
2505 if (whileOp.getBodyBlock()->empty()) {
2506 eraseControlWithGroupAndConditional(whileOp, rewriter);
2513 //===----------------------------------------------------------------------===//
2515 //===----------------------------------------------------------------------===//
// Verifies that every operation in the first block of a StaticRepeatOp's
// region is static control (isStaticControl). The first offending operation
// is reported.
// NOTE(review): the closing braces and final return (embedded lines 2522-2527)
// are missing from this listing.
2516 LogicalResult StaticRepeatOp::verify() {
2517 for (auto &&bodyOp : (*this).getRegion().front()) {
2518 // each bodyOp must be static control (the number of bodyOps is not checked here)
2519 if (!isStaticControl(&bodyOp)) {
2520 return bodyOp.emitOpError(
2521 "static repeat has non static control within it");
// Canonicalization shared by RepeatOp and StaticRepeatOp (enforced by the
// static_assert). When the repeat count is 0, the operations in the body block
// are visited with an early-increment range (safe against erasure while
// iterating), and rewriter.eraseOp(op) is then called, presumably on the
// repeat itself.
// NOTE(review): the loop body and the return statements (embedded lines
// 2535-2536 and 2538-2543) are missing from this listing.
// NOTE(review): the loop variable `op` shadows the parameter `op`. Renaming it
// (e.g. `bodyOp`) would make it clear which op rewriter.eraseOp(op) refers to.
// NOTE(review): the static_assert message misspells "StaticRepeatOp" as
// "StaticPRepeatOp".
2528 template <typename OpTy>
2529 static LogicalResult zeroRepeat(OpTy op, PatternRewriter &rewriter) {
2530 static_assert(IsAny<OpTy, RepeatOp, StaticRepeatOp>(),
2531 "Should be a RepeatOp or StaticPRepeatOp");
2532 if (op.getCount() == 0) {
2533 Block *controlBody = op.getBodyBlock();
2534 for (auto &op : make_early_inc_range(*controlBody))
2537 rewriter.eraseOp(op);
// Registers the StaticRepeatOp canonicalizations:
// - emptyControl, which by its name targets an empty body;
// - zeroRepeat, which removes a repeat with a count of 0.
// NOTE(review): the closing brace (embedded line 2548) is missing from this
// listing.
2544 void StaticRepeatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2545 MLIRContext *context) {
2546 patterns.add(emptyControl<StaticRepeatOp>);
2547 patterns.add(zeroRepeat<StaticRepeatOp>);
2550 //===----------------------------------------------------------------------===//
2552 //===----------------------------------------------------------------------===//
// Registers the RepeatOp canonicalizations:
// - emptyControl, which by its name targets an empty body;
// - zeroRepeat, which removes a repeat with a count of 0.
// NOTE(review): the closing brace (embedded line 2557) is missing from this
// listing.
2553 void RepeatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2554 MLIRContext *context) {
2555 patterns.add(emptyControl<RepeatOp>);
2556 patterns.add(zeroRepeat<RepeatOp>);
2559 //===----------------------------------------------------------------------===//
2561 //===----------------------------------------------------------------------===//
2563 // Parse the parameter list of invoke.
// Accepted form: a parenthesized, comma-separated list of `port = input`
// operand pairs, then `->`, then a parenthesized, comma-separated list of
// types.
// - Each port and input operand is appended to `ports` / `inputs`.
// - Each operand's parsed name (UnresolvedOperand::name) is recorded as a
//   StringAttr in `portNames` / `inputNames`.
// - Each parsed type is appended to `types`.
// `result` is not used by any of the visible lines.
// NOTE(review): the return-type line, the declaration of `type`, the lambda
// failure/success returns and the trailing arguments of both
// parseCommaSeparatedList calls (embedded lines 2564, 2573, 2577, 2582-2583,
// 2585-2586, 2588, 2591, 2593-2594, 2596-2598) are missing from this listing.
2565 parseParameterList(OpAsmParser &parser, OperationState &result,
2566 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &ports,
2567 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inputs,
2568 SmallVectorImpl<Attribute> &portNames,
2569 SmallVectorImpl<Attribute> &inputNames,
2570 SmallVectorImpl<Type> &types) {
2571 OpAsmParser::UnresolvedOperand port;
2572 OpAsmParser::UnresolvedOperand input;
2574 auto parseParameter = [&]() -> ParseResult {
2575 if (parser.parseOperand(port) || parser.parseEqual() ||
2576 parser.parseOperand(input))
2578 ports.push_back(port);
2579 portNames.push_back(StringAttr::get(parser.getContext(), port.name));
2580 inputs.push_back(input);
2581 inputNames.push_back(StringAttr::get(parser.getContext(), input.name));
2584 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
2587 if (parser.parseArrow())
2589 auto parseType = [&]() -> ParseResult {
2590 if (parser.parseType(type))
2592 types.push_back(type);
2595 return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
// Parses the custom form of calyx.invoke: `@callee` followed by the parameter
// list handled by parseParameterList.
// - The callee is stored as the FlatSymbolRefAttr attribute "callee".
// - Ports are resolved first and inputs second, both against the same `types`
//   vector, so the i-th type is used for both the i-th port and the i-th input.
// - The parsed operand names are stored as the ArrayAttr attributes
//   "portNames" and "inputNames".
// `loc` is taken after the symbol name, so operand-resolution errors point at
// the start of the parameter list.
// NOTE(review): the failure returns, the trailing argument of the
// parseParameterList call, the final return and the closing brace (embedded
// lines 2607, 2612-2613, 2615, 2617 and 2622-2624) are missing from this
// listing.
2599 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
2600 StringAttr componentName;
2601 SmallVector<OpAsmParser::UnresolvedOperand, 4> ports;
2602 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs;
2603 SmallVector<Attribute> portNames;
2604 SmallVector<Attribute> inputNames;
2605 SmallVector<Type, 4> types;
2606 if (parser.parseSymbolName(componentName))
2608 FlatSymbolRefAttr callee = FlatSymbolRefAttr::get(componentName);
2609 SMLoc loc = parser.getCurrentLocation();
2610 result.addAttribute("callee", callee);
2611 if (parseParameterList(parser, result, ports, inputs, portNames, inputNames,
2614 if (parser.resolveOperands(ports, types, loc, result.operands))
2616 if (parser.resolveOperands(inputs, types, loc, result.operands))
2618 result.addAttribute("portNames",
2619 ArrayAttr::get(parser.getContext(), portNames));
2620 result.addAttribute("inputNames",
2621 ArrayAttr::get(parser.getContext(), inputNames));
// Prints the custom form of calyx.invoke: ` @callee(`, then the connections as
// a comma-separated list of `port = input`, then a comma-separated list of the
// port types (one type per connection, matching how InvokeOp::parse resolves
// both sides against the same type list).
// NOTE(review): the lines between the two lists (embedded 2631-2632,
// presumably closing the lambda and printing `) -> (`) and the trailing lines
// (2634-2636) are missing from this listing.
2625 void InvokeOp::print(OpAsmPrinter &p) {
2626 p << " @" << getCallee() << "(";
2627 auto ports = getPorts();
2628 auto inputs = getInputs();
2629 llvm::interleaveComma(llvm::zip(ports, inputs), p, [&](auto arg) {
2630 p << std::get<0>(arg) << " = " << std::get<1>(arg);
2633 llvm::interleaveComma(ports, p, [&](auto port) { p << port.getType(); });
2637 // Check the direction of one of the ports in one of the connections of an
// invoke op. The visible body delegates to verifyPortDirection, forwarding
// `isDestination` (true for the destination side of a connection, false for
// the source side, per the callers in InvokeOp::verify).
// NOTE(review): embedded line 2638 (the end of the comment above), line 2641
// (presumably a guard before delegating) and the tail 2643-2645 are missing
// from this listing.
2639 static LogicalResult verifyInvokeOpValue(InvokeOp &op, Value &value,
2640 bool isDestination) {
2642 return verifyPortDirection(op, value, isDestination);
2646 // Checks if the value comes from complex logic.
// A value counts as complex logic when its defining operation belongs to the
// comb dialect. Values with no defining operation (block arguments) take the
// branch at the null check.
// NOTE(review): the statements governed by both checks, the final return and
// the closing brace (embedded lines 2651 and 2653-2656) are missing from this
// listing.
2647 static LogicalResult verifyComplexLogic(InvokeOp &op, Value &value) {
2648 // Refer to the above function verifyNotComplexSource for its role.
2649 Operation *operation = value.getDefiningOp();
2650 if (operation == nullptr)
2652 if (auto *dialect = operation->getDialect(); isa<comb::CombDialect>(dialect))
2657 // Get the go port of the invoked component.
// Looks up the callee symbol in the enclosing ComponentOp and selects its go
// value by cell kind:
//   - RegisterOp: result 1.
//   - MemoryOp and the Div/Mult/Rem pipelined library ops: result 2.
//   - InstanceOp: the result whose referenced-component PortInfo has a "go"
//     attribute.
//   - PrimitiveOp: the result paired with an input-attribute dictionary of the
//     referenced hw.module.extern whose first entry is named "calyx.go".
// `ret` starts as a null Value and is presumably returned unchanged when
// nothing matches.
// NOTE(review): in the InstanceOp case, the structured binding `portInfo`
// shadows the outer `portInfo` vector; a distinct name would be clearer.
// NOTE(review): the PrimitiveOp case compares only the first entry of each
// dictionary, whereas getHwModuleExtGoOrDonePortNumber counts matching
// entries anywhere in the dictionary.
// NOTE(review): the assignments inside both loops and the closing/return lines
// (embedded 2672-2674 and 2682-2689) are missing from this listing.
2658 Value InvokeOp::getInstGoValue() {
2659 ComponentOp componentOp = (*this)->getParentOfType<ComponentOp>();
2660 Operation *operation = componentOp.lookupSymbol(getCallee());
2661 Value ret = nullptr;
2662 llvm::TypeSwitch<Operation *>(operation)
2663 .Case<RegisterOp>([&](auto op) { ret = operation->getResult(1); })
2664 .Case<MemoryOp, DivSPipeLibOp, DivUPipeLibOp, MultPipeLibOp,
2665 RemSPipeLibOp, RemUPipeLibOp>(
2666 [&](auto op) { ret = operation->getResult(2); })
2667 .Case<InstanceOp>([&](auto op) {
2668 auto portInfo = op.getReferencedComponent().getPortInfo();
2669 for (auto [portInfo, res] :
2670 llvm::zip(portInfo, operation->getResults())) {
2671 if (portInfo.hasAttribute("go"))
2675 .Case<PrimitiveOp>([&](auto op) {
2676 auto moduleExternOp = op.getReferencedPrimitive();
2677 auto argAttrs = moduleExternOp.getAllInputAttrs();
2678 for (auto [attr, res] : llvm::zip(argAttrs, op.getResults())) {
2679 if (DictionaryAttr dictAttr = dyn_cast<DictionaryAttr>(attr)) {
2680 if (!dictAttr.empty()) {
2681 if (dictAttr.begin()->getName().getValue() == "calyx.go")
2690 // Get the done port of the invoked component.
// Looks up the callee symbol in the enclosing ComponentOp and selects its done
// value by cell kind:
//   - RegisterOp, MemoryOp and the pipelined library ops: the last result.
//   - InstanceOp: the result whose referenced-component PortInfo has a "done"
//     attribute.
//   - PrimitiveOp: the result paired with an output-attribute dictionary of
//     the referenced hw.module.extern whose first entry is named "calyx.done".
// `ret` starts as a null Value.
// NOTE(review): `resAttrs` (the extern's *output* attributes) are zipped with
// primOp.getResults() starting at index 0. getInstGoValue zips the *input*
// attributes with the results from index 0 as well, which suggests results
// list inputs first. If so, output attribute i is paired with input result i
// here and the wrong value can be selected. Confirm the PrimitiveOp result
// order.
// NOTE(review): the `op` lambda parameters are unused; the InstanceOp and
// PrimitiveOp cases re-cast `operation` instead.
// NOTE(review): embedded lines 2700, 2707-2709 and 2718-2725 are missing from
// this listing.
2691 Value InvokeOp::getInstDoneValue() {
2692 ComponentOp componentOp = (*this)->getParentOfType<ComponentOp>();
2693 Operation *operation = componentOp.lookupSymbol(getCallee());
2694 Value ret = nullptr;
2695 llvm::TypeSwitch<Operation *>(operation)
2696 .Case<RegisterOp, MemoryOp, DivSPipeLibOp, DivUPipeLibOp, MultPipeLibOp,
2697 RemSPipeLibOp, RemUPipeLibOp>([&](auto op) {
2698 size_t doneIdx = operation->getResults().size() - 1;
2699 ret = operation->getResult(doneIdx);
2701 .Case<InstanceOp>([&](auto op) {
2702 InstanceOp instanceOp = cast<InstanceOp>(operation);
2703 auto portInfo = instanceOp.getReferencedComponent().getPortInfo();
2704 for (auto [portInfo, res] :
2705 llvm::zip(portInfo, operation->getResults())) {
2706 if (portInfo.hasAttribute("done"))
2710 .Case<PrimitiveOp>([&](auto op) {
2711 PrimitiveOp primOp = cast<PrimitiveOp>(operation);
2712 auto moduleExternOp = primOp.getReferencedPrimitive();
2713 auto resAttrs = moduleExternOp.getAllOutputAttrs();
2714 for (auto [attr, res] : llvm::zip(resAttrs, primOp.getResults())) {
2715 if (DictionaryAttr dictAttr = dyn_cast<DictionaryAttr>(attr)) {
2716 if (!dictAttr.empty()) {
2717 if (dictAttr.begin()->getName().getValue() == "calyx.done")
2726 // A helper function that gets the number of go or done ports in
2727 // hw.module.extern.
// `isGo` selects which attribute name is counted: "calyx.go" when true,
// "calyx.done" when false.
// NOTE(review): two likely defects:
//   1. `ret = llvm::count_if(...)` overwrites the count for every attribute
//      dictionary, so only the last dictionary's matches survive. The count
//      should be accumulated (`ret += ...`).
//   2. Only getAllInputAttrs() is scanned, even when counting "calyx.done".
//      getInstDoneValue looks for "calyx.done" among getAllOutputAttrs(), so
//      done ports declared on outputs are never counted here.
// NOTE(review): the return-type line, the `isGo` parameter line, the
// declaration of `ret` and the closing/return lines (embedded 2728, 2730-2731
// and 2737-2742) are missing from this listing.
2729 getHwModuleExtGoOrDonePortNumber(hw::HWModuleExternOp &moduleExternOp,
2732 std::string str = isGo ? "calyx.go" : "calyx.done";
2733 for (Attribute attr : moduleExternOp.getAllInputAttrs()) {
2734 if (DictionaryAttr dictAttr = dyn_cast<DictionaryAttr>(attr)) {
2735 ret = llvm::count_if(dictAttr, [&](NamedAttribute iter) {
2736 return iter.getName().getValue() == str;
// Verifies a calyx.invoke:
//   - the callee symbol must exist in the enclosing ComponentOp;
//   - at least one connection must be given;
//   - the callee must expose a go port and a done port. Registers, memories
//     and the pipelined library ops are taken to have one of each; instances
//     and primitives are counted from their port attributes;
//   - for every `port = input` connection:
//     - `port` must be a destination and must not be the callee's go value;
//     - `input` must be a source, must not be complex (comb) logic and must
//       not be the callee's done value;
//     - at least one side must be defined by the callee itself.
// NOTE(review): embedded lines 2748 (presumably the `if (!operation)` guard),
// 2767, 2769-2771, 2777, 2784 and 2824-2832 are missing from this listing.
// Several string literals are split across lines by the extraction.
2743 LogicalResult InvokeOp::verify() {
2744 ComponentOp componentOp = (*this)->getParentOfType<ComponentOp>();
2745 StringRef callee = getCallee();
2746 Operation *operation = componentOp.lookupSymbol(callee);
2747 // The referenced symbol does not exist.
2749 return emitOpError() << "with instance '@" << callee
2750 << "', which does not exist.";
2751 // The argument list of invoke is empty.
2752 if (getInputs().empty())
2753 return emitOpError() << "'@" << callee
2754 << "' has zero input and output port connections; "
2755 "expected at least one.";
2756 size_t goPortNum = 0, donePortNum = 0;
2757 // They both have a go port and a done port, but the "go" port for
2758 // registers and memories should be the "write_en" port.
2759 llvm::TypeSwitch<Operation *>(operation)
2760 .Case<RegisterOp, DivSPipeLibOp, DivUPipeLibOp, MemoryOp, MultPipeLibOp,
2761 RemSPipeLibOp, RemUPipeLibOp>(
2762 [&](auto op) { goPortNum = 1, donePortNum = 1; })
2763 .Case<InstanceOp>([&](auto op) {
2764 auto portInfo = op.getReferencedComponent().getPortInfo();
2765 for (PortInfo info : portInfo) {
2766 if (info.hasAttribute("go"))
2768 if (info.hasAttribute("done"))
2772 .Case<PrimitiveOp>([&](auto op) {
2773 auto moduleExternOp = op.getReferencedPrimitive();
2774 // Get the number of go ports and done ports by their attributes.
2775 goPortNum = getHwModuleExtGoOrDonePortNumber(moduleExternOp, true);
2776 donePortNum = getHwModuleExtGoOrDonePortNumber(moduleExternOp, false);
// NOTE(review): with `&&`, the callee is rejected only when BOTH counts differ
// from 1, although the message requires a single go port AND a single done
// port. This most likely should be `||`.
2778 // If the number of go ports and done ports is wrong.
2779 if (goPortNum != 1 && donePortNum != 1)
2780 return emitOpError()
2781 << "'@" << callee << "'"
2782 << " is a combinational component and cannot be invoked, which must "
2783 "have single go port and single done port.";
2785 auto ports = getPorts();
2786 auto inputs = getInputs();
// NOTE(review): the claim below holds only for one of the two ports under the
// `&&` check above. getInstGoValue()/getInstDoneValue() start from a null
// Value, so one of them may be null here.
2787 // We have verified earlier that the instance has a go and a done port.
2788 Value goValue = getInstGoValue();
2789 Value doneValue = getInstDoneValue();
2790 for (auto [port, input, portName, inputName] :
2791 llvm::zip(ports, inputs, getPortNames(), getInputNames())) {
2792 // Check the direction of these destination ports.
2793 // 'calyx.invoke
' op '@r0
' has input '%r.out
', which is a source port. The
2794 // inputs are required to be destination ports.
2795 if (failed(verifyInvokeOpValue(*this, port, true)))
2796 return emitOpError() << "'@" << callee << "' has input '"
2797 << portName.cast<StringAttr>().getValue()
2798 << "', which is a source port. The inputs are "
2799 "required to be destination ports.";
2800 // The go port should not appear in the parameter list.
2801 if (port == goValue)
2802 return emitOpError() << "the go or write_en port of '@" << callee
2803 << "' cannot appear here.";
2804 // Check the direction of these source ports.
2805 if (failed(verifyInvokeOpValue(*this, input, false)))
2806 return emitOpError() << "'@" << callee << "' has output '"
2807 << inputName.cast<StringAttr>().getValue()
2808 << "', which is a destination port. The inputs are "
2809 "required to be source ports.";
2810 if (failed(verifyComplexLogic(*this, input)))
2811 return emitOpError() << "'@" << callee << "' has '"
2812 << inputName.cast<StringAttr>().getValue()
2813 << "', which is not a port or constant. Complex "
2814 "logic should be conducted in the guard.";
2815 if (input == doneValue)
2816 return emitOpError() << "the done port of '@" << callee
2817 << "' cannot appear here.";
2818 // Check if the connection uses the callee's port.
2819 if (port.getDefiningOp() != operation && input.getDefiningOp() != operation)
2820 return emitOpError() <<
"the connection "
2821 << portName.cast<StringAttr>().getValue() <<
" = "
2822 << inputName.cast<StringAttr>().getValue()
2823 <<
" is not defined as an input port of '@" << callee
// Verifies a PadLibOp: the bit width of result 0 (the input side) must be
// strictly less than the bit width of result 1 (the output side).
// NOTE(review): the message literals are split across lines by the extraction,
// and the end of the message plus the closing lines (embedded 2839-2842) are
// missing from this listing.
2833 LogicalResult PadLibOp::verify() {
2834 unsigned inBits = getResult(0).getType().getIntOrFloatBitWidth();
2835 unsigned outBits = getResult(1).getType().getIntOrFloatBitWidth();
2836 if (inBits >= outBits)
2837 return emitOpError(
"expected input bits (")
2838 << inBits <<
')' <<
" to be less than output bits (" << outBits
// Verifies a SliceLibOp: the bit width of result 0 (the input side) must be
// strictly greater than the bit width of result 1 (the output side).
// NOTE(review): the message literals are split across lines by the extraction,
// and the end of the message plus the closing lines (embedded 2849-2852) are
// missing from this listing.
2843 LogicalResult SliceLibOp::verify() {
2844 unsigned inBits = getResult(0).getType().getIntOrFloatBitWidth();
2845 unsigned outBits = getResult(1).getType().getIntOrFloatBitWidth();
2846 if (inBits <= outBits)
2847 return emitOpError(
"expected input bits (")
2848 << inBits <<
')' <<
" to be greater than output bits (" << outBits
// Defines the cell-interface methods for a pipelined binary library op.
// - Ports, in order: clk, reset, go, left, right, <outName>, done. The first
//   five are inputs and the last two are outputs.
// - clk, reset, go and done each carry a same-named attribute set to a 1-bit
//   integer 1. left, right and <outName> get empty dictionaries.
// - Result names are derived from portNames(), and the op is not
//   combinational.
// NOTE(review): several lines of this macro (embedded 2856-2857, 2860-2861,
// 2864-2865, 2874 and 2882-2884) are missing from this listing, and the `\`
// continuation of embedded line 2875 is split onto its own line.
2853 #define ImplBinPipeOpCellInterface(OpType, outName) \
2854 SmallVector<StringRef> OpType::portNames() { \
2855 return {"clk", "reset", "go", "left", "right", outName, "done"}; \
2858 SmallVector<Direction> OpType::portDirections() { \
2859 return {Input, Input, Input, Input, Input, Output, Output}; \
2862 void OpType::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { \
2863 getCellAsmResultNames(setNameFn, *this, this->portNames()); \
2866 SmallVector<DictionaryAttr> OpType::portAttributes() { \
2867 MLIRContext *context = getContext(); \
2868 IntegerAttr isSet = IntegerAttr::get(IntegerType::get(context, 1), 1); \
2869 NamedAttrList go, clk, reset, done; \
2870 go.append("go", isSet); \
2871 clk.append("clk", isSet); \
2872 reset.append("reset", isSet); \
2873 done.append("done", isSet); \
2875 clk.getDictionary(context),
\
2876 reset.getDictionary(context), \
2877 go.getDictionary(context), \
2878 DictionaryAttr::get(context), \
2879 DictionaryAttr::get(context), \
2880 DictionaryAttr::get(context), \
2881 done.getDictionary(context) \
2885 bool OpType::isCombinational() { return false; }
// Defines the cell-interface methods for a unary combinational library op.
// - Ports: `in` (input) and `out` (output), both with empty attribute
//   dictionaries.
// - Result names are derived from portNames().
// NOTE(review): embedded lines 2893 and 2897 (presumably closing braces) are
// missing from this listing.
2887 #define ImplUnaryOpCellInterface(OpType) \
2888 SmallVector<StringRef> OpType::portNames() { return {"in", "out"}; } \
2889 SmallVector<Direction> OpType::portDirections() { return {Input, Output}; } \
2890 SmallVector<DictionaryAttr> OpType::portAttributes() { \
2891 return {DictionaryAttr::get(getContext()), \
2892 DictionaryAttr::get(getContext())}; \
2894 bool OpType::isCombinational() { return true; } \
2895 void OpType::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { \
2896 getCellAsmResultNames(setNameFn, *this, this->portNames()); \
// Defines the cell-interface methods for a binary combinational library op.
// - Ports: `left`, `right` (inputs) and `out` (output), all with empty
//   attribute dictionaries.
// - Result names are derived from portNames().
// NOTE(review): embedded lines 2902, 2905 and 2908 and the end of this macro
// (after embedded line 2913) are missing from this listing.
2899 #define ImplBinOpCellInterface(OpType) \
2900 SmallVector<StringRef> OpType::portNames() { \
2901 return {"left", "right", "out"}; \
2903 SmallVector<Direction> OpType::portDirections() { \
2904 return {Input, Input, Output}; \
2906 void OpType::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { \
2907 getCellAsmResultNames(setNameFn, *this, this->portNames()); \
2909 bool OpType::isCombinational() { return true; } \
2910 SmallVector<DictionaryAttr> OpType::portAttributes() { \
2911 return {DictionaryAttr::get(getContext()), \
2912 DictionaryAttr::get(getContext()), \
2913 DictionaryAttr::get(getContext())}; \
2957 #include "circt/Dialect/Calyx/CalyxInterfaces.cpp.inc"
2960 #define GET_OP_CLASSES
2961 #include "circt/Dialect/Calyx/Calyx.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyPrimitiveOpType(PrimitiveOp instance, hw::HWModuleExternOp referencedPrimitive)
Verifies the port information in comparison with the referenced component of an instance.
static ParseResult parseComponentSignature(OpAsmParser &parser, OperationState &result, SmallVectorImpl< OpAsmParser::Argument > &ports, SmallVectorImpl< Type > &portTypes)
Parses the signature of a Calyx component.
static LogicalResult verifyAssignOpValue(AssignOp op, bool isDestination)
Verifies the value of a given assignment operation.
static ParseResult parseParameterList(OpAsmParser &parser, SmallVector< Attribute > &parameters)
Parse a parameter list if present.
static SmallVector< PortInfo > getFilteredPorts(ComponentOp op, Pred p)
A helper function to return a filtered subset of a component's ports.
static Op getControlOrWiresFrom(ComponentOp op)
This is a helper function that should only be used to get the WiresOp or ControlOp of a ComponentOp,...
static LogicalResult verifyPrimitivePortDriving(AssignOp assign, GroupInterface group)
Verifies that certain ports of primitives are either driven or read together.
#define ImplBinPipeOpCellInterface(OpType, outName)
static bool portIsUsedInGroup(GroupInterface group, Value port, bool isDriven)
Determines whether the given port is used in the group.
static Value getBlockArgumentWithName(StringRef name, ComponentOp op)
Returns the Block argument with the given name from a ComponentOp.
static ParseResult parsePortDefList(OpAsmParser &parser, OperationState &result, SmallVectorImpl< OpAsmParser::Argument > &ports, SmallVectorImpl< Type > &portTypes, SmallVectorImpl< NamedAttrList > &portAttrs)
Parses the ports of a Calyx component signature, and adds the corresponding port names to attrName.
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
static LogicalResult verifyNotComplexSource(Op op)
Verify that the value is not a "complex" value.
static LogicalResult verifyInstanceOpType(InstanceOp instance, ComponentInterface referencedComponent)
Verifies the port information in comparison with the referenced component of an instance.
static LogicalResult collapseControl(OpTy controlOp, PatternRewriter &rewriter)
Direction convertHWDirectionToCalyx(hw::ModulePort::Direction direction)
static LogicalResult anyPortsReadByGroup(GroupInterface group, ValueRange ports)
Checks whether any ports are read within the group.
static bool hasControlRegion(Operation *op)
Returns whether the given operation has a control region.
static LogicalResult emptyControl(OpTy controlOp, PatternRewriter &rewriter)
static LogicalResult verifyControlBody(Operation *op)
Verifies the body of a ControlLikeOp.
static void eraseControlWithConditional(OpTy op, PatternRewriter &rewriter)
A helper function to check whether the conditional needs to be erased to maintain a valid state of a ...
static void eraseControlWithGroupAndConditional(OpTy op, PatternRewriter &rewriter)
A helper function to check whether the conditional and group (if it exists) needs to be erased to mai...
static ParseResult parseComponentInterface(OpAsmParser &parser, OperationState &result)
static void printComponentInterface(OpAsmPrinter &p, ComponentInterface comp)
static LogicalResult hasRequiredPorts(ComponentOp op)
Determines whether the given ComponentOp has all the required ports.
static LogicalResult anyPortsDrivenByGroup(GroupInterface group, ValueRange ports)
Checks whether any ports are driven within the group.
static bool isPort(Value value)
Returns whether this value is either (1) a port on a ComponentOp or (2) a port on a cell interface.
#define ImplBinOpCellInterface(OpType)
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static LogicalResult portDrivenByGroup(GroupInterface groupOp, Value port)
Checks whether port is driven from within groupOp.
static void buildComponentLike(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, bool combinational)
static void getCellAsmResultNames(OpAsmSetValueNameFn setNameFn, Operation *op, ArrayRef< StringRef > portNames)
Gives each result of the cell a meaningful name in the form: <instance-name>.
static LogicalResult allPortsDrivenByGroup(GroupInterface group, ValueRange ports)
Checks whether all ports are driven within the group.
static LogicalResult verifyPortDirection(Operation *op, Value value, bool isDestination)
Determines whether the given direction is valid with the given inputs.
static LogicalResult isCombinational(Value value, GroupInterface group)
Verifies the defining operation of a value is combinational.
static bool isStaticControl(Operation *op)
Returns whether the given operation is a static control operator.
static void printParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance. Same format as HW dialect.
static DictionaryAttr cleanCalyxPortAttrs(OpBuilder builder, DictionaryAttr dict)
Returns a new DictionaryAttr containing only the calyx dialect attrs in the input DictionaryAttr.
static void printGroupPort(OpAsmPrinter &p, GroupPortType op)
#define ImplUnaryOpCellInterface(OpType)
static ParseResult parseGroupPort(OpAsmParser &parser, OperationState &result)
static SmallVector< Block *, 8 > intersection(SmallVectorImpl< Block * > &v1, SmallVectorImpl< Block * > &v2)
Calculate intersection of two vectors, returns a new vector.
static InstancePath empty
static ParseResult parsePort(OpAsmParser &p, module_like_impl::PortParse &result)
Parse a single argument with the following syntax:
Signals that the following operation is combinational.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
IntegerAttr packAttribute(MLIRContext *context, size_t nIns, size_t nOuts)
Returns an IntegerAttr containing the packed representation of the direction counts.
LogicalResult verifyComponent(Operation *op)
A helper function to verify each operation with the Component trait.
LogicalResult verifyControlLikeOp(Operation *op)
A helper function to verify each control-like operation has a valid parent and, if applicable,...
LogicalResult verifyGroupInterface(Operation *op)
A helper function to verify each operation with the Group Interface trait.
LogicalResult verifyCell(Operation *op)
A helper function to verify each operation with the Cell trait.
Direction
The direction of a Component or Cell port.
LogicalResult verifyIf(Operation *op)
A helper function to verify each operation with the If trait.
PortInfo getPortInfo(BlockArgument arg)
Returns port information for the block argument provided.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
enum PEO uint32_t mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
This pattern collapses a calyx.seq or calyx.par operation when it contains exactly one calyx....
LogicalResult matchAndRewrite(CtrlOp ctrlOp, PatternRewriter &rewriter) const override
This holds information about the port for either a Component or Cell.
DictionaryAttr attributes