17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/OperationSupport.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Support/IndentedOstream.h"
23 #include "llvm/ADT/TypeSwitch.h"
29 #define GEN_PASS_DEF_HANDSHAKEDOTPRINT
30 #define GEN_PASS_DEF_HANDSHAKEOPCOUNT
31 #define GEN_PASS_DEF_HANDSHAKEADDIDS
32 #include "circt/Dialect/Handshake/HandshakePasses.h.inc"
36 using namespace circt;
37 using namespace handshake;
41 auto controlInterface = dyn_cast<handshake::ControlInterface>(op);
42 return controlInterface && controlInterface.isControl();
46 struct HandshakeDotPrintPass
47 :
public circt::handshake::impl::HandshakeDotPrintBase<
48 HandshakeDotPrintPass> {
49 void runOnOperation()
override {
50 ModuleOp m = getOperation();
55 SmallVector<std::string> sortedFuncs;
61 handshake::FuncOp topLevelOp =
62 cast<handshake::FuncOp>(m.lookupSymbol(topLevel));
66 llvm::raw_fd_ostream outfile(topLevel +
".dot", ec);
67 mlir::raw_indented_ostream os(outfile);
69 os <<
"Digraph G {\n";
71 os <<
"splines=spline;\n";
72 os <<
"compound=true; // Allow edges between clusters\n";
73 dotPrint(os,
"TOP", topLevelOp,
true);
82 std::string dotPrint(mlir::raw_indented_ostream &os, StringRef parentName,
83 handshake::FuncOp f,
bool isTop);
88 std::map<std::string, unsigned> instanceIdMap;
91 DenseMap<Operation *, std::string> opNameMap;
94 DenseMap<Value, std::string> argNameMap;
96 void setUsedByMapping(Value v, Operation *op, StringRef node);
97 void setProducedByMapping(Value v, Operation *op, StringRef node);
100 std::string getUsedByNode(Value v, Operation *consumer);
102 std::string getProducedByNode(Value v, Operation *producer);
108 DenseMap<Value, std::map<Operation *, std::string>> usedByMapping;
110 DenseMap<Value, std::map<Operation *, std::string>> producedByMapping;
113 struct HandshakeOpCountPass
114 :
public circt::handshake::impl::HandshakeOpCountBase<
115 HandshakeOpCountPass> {
116 void runOnOperation()
override {
117 ModuleOp m = getOperation();
119 for (
auto func : m.getOps<handshake::FuncOp>()) {
120 std::map<std::string, int> cnts;
121 for (Operation &op : func.getOps()) {
122 llvm::TypeSwitch<Operation *, void>(&op)
123 .Case<handshake::ConstantOp>([&](
auto) { cnts[
"Constant"]++; })
124 .Case<handshake::MuxOp>([&](
auto) { cnts[
"Mux"]++; })
125 .Case<handshake::LoadOp>([&](
auto) { cnts[
"Load"]++; })
126 .Case<handshake::StoreOp>([&](
auto) { cnts[
"Store"]++; })
127 .Case<handshake::MergeOp>([&](
auto) { cnts[
"Merge"]++; })
128 .Case<handshake::ForkOp>([&](
auto) { cnts[
"Fork"]++; })
129 .Case<handshake::BranchOp>([&](
auto) { cnts[
"Branch"]++; })
130 .Case<handshake::MemoryOp, handshake::ExternalMemoryOp>(
131 [&](
auto) { cnts[
"Memory"]++; })
132 .Case<handshake::ControlMergeOp>(
133 [&](
auto) { cnts[
"CntrlMerge"]++; })
134 .Case<handshake::SinkOp>([&](
auto) { cnts[
"Sink"]++; })
135 .Case<handshake::SourceOp>([&](
auto) { cnts[
"Source"]++; })
136 .Case<handshake::JoinOp>([&](
auto) { cnts[
"Join"]++; })
137 .Case<handshake::BufferOp>([&](
auto) { cnts[
"Buffer"]++; })
138 .Case<handshake::ConditionalBranchOp>(
139 [&](
auto) { cnts[
"Branch"]++; })
140 .Case<arith::AddIOp>([&](
auto) { cnts[
"Add"]++; })
141 .Case<arith::SubIOp>([&](
auto) { cnts[
"Sub"]++; })
142 .Case<arith::AddIOp>([&](
auto) { cnts[
"Add"]++; })
143 .Case<arith::MulIOp>([&](
auto) { cnts[
"Mul"]++; })
144 .Case<arith::CmpIOp>([&](
auto) { cnts[
"Cmp"]++; })
145 .Case<arith::IndexCastOp, arith::ShLIOp, arith::ShRSIOp,
146 arith::ShRUIOp>([&](
auto) { cnts[
"Ext/Sh"]++; })
147 .Case<handshake::ReturnOp>([&](
auto) {})
148 .Default([&](
auto op) {
149 llvm::outs() <<
"Unhandled operation: " << *op <<
"\n";
154 llvm::outs() <<
"// RESOURCES"
157 llvm::outs() << it.first <<
"\t" << it.second <<
"\n";
158 llvm::outs() <<
"// END"
169 StringRef instanceName, Operation *op,
170 DenseMap<Operation *, unsigned> &opIDs) {
175 std::string opDialectName = op->getName().getStringRef().str();
176 std::replace(opDialectName.begin(), opDialectName.end(),
'.',
'_');
177 std::string opName = (instanceName +
"." + opDialectName).str();
180 auto idAttr = op->getAttrOfType<IntegerAttr>(
"handshake_id");
182 opName +=
"_id" + std::to_string(idAttr.getValue().getZExtValue());
184 opName += std::to_string(opIDs[op]);
186 outfile <<
"\"" << opName <<
"\""
190 outfile <<
"fillcolor = ";
192 << llvm::TypeSwitch<Operation *, std::string>(op)
193 .Case<handshake::ForkOp, handshake::LazyForkOp, handshake::MuxOp,
194 handshake::JoinOp>([&](
auto) {
return "lavender"; })
195 .Case<handshake::BufferOp>([&](
auto) {
return "lightgreen"; })
196 .Case<handshake::ReturnOp>([&](
auto) {
return "gold"; })
197 .Case<handshake::SinkOp, handshake::ConstantOp>(
198 [&](
auto) {
return "gainsboro"; })
199 .Case<handshake::MemoryOp, handshake::LoadOp, handshake::StoreOp>(
200 [&](
auto) {
return "coral"; })
201 .Case<handshake::MergeOp, handshake::ControlMergeOp,
202 handshake::BranchOp, handshake::ConditionalBranchOp>(
203 [&](
auto) {
return "lightblue"; })
204 .Default([&](
auto) {
return "moccasin"; });
207 outfile <<
", shape=";
208 if (op->getDialect()->getNamespace() ==
"handshake")
214 outfile <<
", label=\"";
215 outfile << llvm::TypeSwitch<Operation *, std::string>(op)
216 .Case<handshake::ConstantOp>([&](
auto op) {
217 return std::to_string(
218 op->template getAttrOfType<mlir::IntegerAttr>(
"value")
222 .Case<handshake::ControlMergeOp>(
223 [&](
auto) {
return "cmerge"; })
224 .Case<handshake::ConditionalBranchOp>(
225 [&](
auto) {
return "cbranch"; })
226 .Case<handshake::BufferOp>([&](
auto op) {
227 std::string n =
"buffer ";
228 n += stringifyEnum(op.getBufferType());
231 .Case<arith::AddIOp>([&](
auto) {
return "+"; })
232 .Case<arith::SubIOp>([&](
auto) {
return "-"; })
233 .Case<arith::AndIOp>([&](
auto) {
return "&"; })
234 .Case<arith::OrIOp>([&](
auto) {
return "|"; })
235 .Case<arith::XOrIOp>([&](
auto) {
return "^"; })
236 .Case<arith::MulIOp>([&](
auto) {
return "*"; })
237 .Case<arith::ShRSIOp, arith::ShRUIOp>(
238 [&](
auto) {
return ">>"; })
239 .Case<arith::ShLIOp>([&](
auto) {
return "<<"; })
240 .Case<arith::CmpIOp>([&](arith::CmpIOp op) {
241 switch (op.getPredicate()) {
242 case arith::CmpIPredicate::eq:
244 case arith::CmpIPredicate::ne:
246 case arith::CmpIPredicate::uge:
247 case arith::CmpIPredicate::sge:
249 case arith::CmpIPredicate::ugt:
250 case arith::CmpIPredicate::sgt:
252 case arith::CmpIPredicate::ule:
253 case arith::CmpIPredicate::sle:
255 case arith::CmpIPredicate::ult:
256 case arith::CmpIPredicate::slt:
259 llvm_unreachable(
"unhandled cmpi predicate");
261 .Default([&](
auto op) {
262 auto opDialect = op->getDialect()->getNamespace();
263 std::string label = op->getName().getStringRef().str();
264 if (opDialect ==
"handshake")
265 label.erase(0, StringLiteral(
"handshake.").size());
273 outfile <<
" [" << std::to_string(idAttr.getValue().getZExtValue()) <<
"]";
278 outfile <<
", style=\"filled";
280 outfile <<
", dashed";
292 return llvm::TypeSwitch<Operation *, bool>(op)
293 .Case<handshake::MuxOp, handshake::ConditionalBranchOp>(
294 [&](
auto op) {
return v == op.getOperand(0); })
295 .Case<handshake::ControlMergeOp>([&](
auto) {
return true; })
296 .Default([](
auto) {
return false; });
299 static std::string
getLocalName(StringRef instanceName, StringRef suffix) {
300 return (instanceName +
"." + suffix).str();
303 static std::string
getArgName(handshake::FuncOp op,
unsigned index) {
304 return op.getArgName(index).getValue().str();
308 handshake::FuncOp op,
unsigned index) {
312 static std::string
getResName(handshake::FuncOp op,
unsigned index) {
313 return op.getResName(index).getValue().str();
317 handshake::FuncOp op,
unsigned index) {
321 void HandshakeDotPrintPass::setUsedByMapping(Value v, Operation *op,
323 usedByMapping[v][op] = node;
325 void HandshakeDotPrintPass::setProducedByMapping(Value v, Operation *op,
327 producedByMapping[v][op] = node;
330 std::string HandshakeDotPrintPass::getUsedByNode(Value v, Operation *consumer) {
332 auto it = usedByMapping.find(v);
333 if (it != usedByMapping.end()) {
334 auto it2 = it->second.find(consumer);
335 if (it2 != it->second.end())
340 auto opNameIt = opNameMap.find(consumer);
341 assert(opNameIt != opNameMap.end() &&
342 "No name registered for the operation!");
343 return opNameIt->second;
346 std::string HandshakeDotPrintPass::getProducedByNode(Value v,
347 Operation *producer) {
349 auto it = producedByMapping.find(v);
350 if (it != producedByMapping.end()) {
351 auto it2 = it->second.find(producer);
352 if (it2 != it->second.end())
357 auto opNameIt = opNameMap.find(producer);
358 assert(opNameIt != opNameMap.end() &&
359 "No name registered for the operation!");
360 return opNameIt->second;
367 Value result, Operation *to) {
372 auto results = from->getResults();
374 std::distance(results.begin(), llvm::find(results, result));
375 auto fromNamedOpInterface = dyn_cast<handshake::NamedIOInterface>(from);
376 if (fromNamedOpInterface) {
377 auto resName = fromNamedOpInterface.getResultName(resIdx);
378 os <<
" output=\"" << resName <<
"\"";
380 os <<
" output=\"out" << resIdx <<
"\"";
385 auto ops = to->getOperands();
386 unsigned opIdx = std::distance(ops.begin(), llvm::find(ops, result));
387 auto toNamedOpInterface = dyn_cast<handshake::NamedIOInterface>(to);
388 if (toNamedOpInterface) {
389 auto opName = toNamedOpInterface.getOperandName(opIdx);
390 os <<
" input=\"" << opName <<
"\"";
392 os <<
" input=\"in" << opIdx <<
"\"";
396 std::string HandshakeDotPrintPass::dotPrint(mlir::raw_indented_ostream &os,
397 StringRef parentName,
398 handshake::FuncOp f,
bool isTop) {
400 DenseMap<Block *, unsigned> blockIDs;
401 std::map<std::string, unsigned> opTypeCntrs;
402 DenseMap<Operation *, unsigned> opIDs;
403 auto name = f.getName();
404 unsigned thisId = instanceIdMap[name.str()]++;
405 std::string instanceName = parentName.str() +
"." + name.str();
408 instanceName += std::to_string(thisId);
413 std::optional<std::string> anyArg, anyBody, anyRes;
419 for (Block &block : f) {
420 blockIDs[&block] = i++;
421 for (Operation &op : block)
422 opIDs[&op] = opTypeCntrs[op.getName().getStringRef().str()]++;
426 os <<
"// Subgraph for instance of " << name <<
"\n";
427 os <<
"subgraph \"cluster_" << instanceName <<
"\" {\n";
429 os <<
"label = \"" << name <<
"\"\n";
430 os <<
"labeljust=\"l\"\n";
431 os <<
"color = \"darkgreen\"\n";
433 os <<
"node [shape=box style=filled fillcolor=\"white\"]\n";
435 Block *bodyBlock = &f.getBody().front();
438 os <<
"// Function argument nodes\n";
439 std::string argsCluster =
"cluster_" + instanceName +
"_args";
440 os <<
"subgraph \"" << argsCluster <<
"\" {\n";
444 os <<
"label=\"\"\n";
445 os <<
"peripheries=0\n";
446 for (
const auto &barg : enumerate(bodyBlock->getArguments())) {
448 auto localArgName =
getLocalName(instanceName, argName);
449 os <<
"\"" << localArgName <<
"\" [shape=diamond";
450 if (barg.index() == bodyBlock->getNumArguments() - 1)
451 os <<
", style=dashed";
452 os <<
" label=\"" << argName <<
"\"";
454 if (!anyArg.has_value())
455 anyArg = localArgName;
460 os <<
"// Function return nodes\n";
461 std::string resCluster =
"cluster_" + instanceName +
"_res";
462 os <<
"subgraph \"" << resCluster <<
"\" {\n";
466 os <<
"label=\"\"\n";
467 os <<
"peripheries=0\n";
470 auto returnOp = *f.getBody().getOps<handshake::ReturnOp>().begin();
471 for (
const auto &res : llvm::enumerate(returnOp.getOperands())) {
474 os <<
"\"" << uniqueResName <<
"\" [shape=diamond";
475 if (res.index() == bodyBlock->getNumArguments() - 1)
476 os <<
", style=dashed";
477 os <<
" label=\"" << resName <<
"\"";
482 setUsedByMapping(res.value(), returnOp, uniqueResName);
484 if (!anyRes.has_value())
485 anyRes = uniqueResName;
491 std::string opsCluster =
"cluster_" + instanceName +
"_ops";
492 os <<
"subgraph \"" << opsCluster <<
"\" {\n";
496 os <<
"label=\"\"\n";
497 os <<
"peripheries=0\n";
498 for (Operation &op : *bodyBlock) {
499 if (!isa<handshake::InstanceOp, handshake::ReturnOp>(op)) {
501 opNameMap[&op] =
dotPrintNode(os, instanceName, &op, opIDs);
504 auto instOp = dyn_cast<handshake::InstanceOp>(op);
508 instOp->getParentOfType<ModuleOp>().lookupSymbol<handshake::FuncOp>(
511 auto subInstanceName = dotPrint(os, instanceName, calledFuncOp,
false);
515 for (
const auto &arg : llvm::enumerate(instOp.getOperands())) {
522 for (
const auto &res : llvm::enumerate(instOp.getResults())) {
523 setProducedByMapping(
529 if (!opNameMap.empty())
530 anyBody = opNameMap.begin()->second;
536 os <<
"// Operation result edges\n";
537 for (Operation &op : *bodyBlock) {
538 for (
auto result : op.getResults()) {
539 for (
auto &u : result.getUses()) {
540 Operation *useOp = u.getOwner();
541 if (useOp->getBlock() == bodyBlock) {
542 os <<
"\"" << getProducedByNode(result, &op);
544 os << getUsedByNode(result, useOp) <<
"\"";
546 os <<
" [style=\"dashed\"]";
561 os <<
"// Function argument edges\n";
562 for (
const auto &barg : enumerate(bodyBlock->getArguments())) {
564 os <<
"\"" <<
getLocalName(instanceName, argName) <<
"\" [shape=diamond";
565 if (barg.index() == bodyBlock->getNumArguments() - 1)
566 os <<
", style=dashed";
568 for (
auto *useOp : barg.value().getUsers()) {
569 os <<
"\"" <<
getLocalName(instanceName, argName) <<
"\" -> \""
570 << getUsedByNode(barg.value(), useOp) <<
"\"";
572 os <<
" [style=\"dashed\"]";
581 if (anyArg.has_value() && anyBody.has_value())
582 os <<
"\"" << anyArg.value() <<
"\" -> \"" << anyBody.value()
583 <<
"\" [lhead=\"" << opsCluster <<
"\" ltail=\"" << argsCluster
584 <<
"\" style=invis]\n";
585 if (anyBody.has_value() && anyRes.has_value())
586 os <<
"\"" << anyBody.value() <<
"\" -> \"" << anyRes.value()
587 <<
"\" [lhead=\"" << resCluster <<
"\" ltail=\"" << opsCluster
588 <<
"\" style=invis]\n";
595 struct HandshakeAddIDsPass
596 :
public circt::handshake::impl::HandshakeAddIDsBase<HandshakeAddIDsPass> {
597 void runOnOperation()
override {
598 handshake::FuncOp funcOp = getOperation();
599 auto *ctx = &getContext();
600 OpBuilder builder(funcOp);
601 funcOp.walk([&](Operation *op) {
602 if (op->hasAttr(
"handshake_id"))
604 llvm::SmallVector<NamedAttribute> attrs;
605 llvm::copy(op->getAttrs(), std::back_inserter(attrs));
606 attrs.push_back(builder.getNamedAttr(
609 opCounters[op->getName().getStringRef().str()]++)));
617 std::map<std::string, unsigned> opCounters;
621 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
623 return std::make_unique<HandshakeDotPrintPass>();
626 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
628 return std::make_unique<HandshakeOpCountPass>();
632 return std::make_unique<HandshakeAddIDsPass>();
static bool isControlOp(Operation *op)
static std::string getUniqueResName(StringRef instanceName, handshake::FuncOp op, unsigned index)
static std::string getLocalName(StringRef instanceName, StringRef suffix)
static std::string dotPrintNode(mlir::raw_indented_ostream &outfile, StringRef instanceName, Operation *op, DenseMap< Operation *, unsigned > &opIDs)
Prints an operation to the dot file and returns the unique name for the operation within the graph.
static std::string getArgName(handshake::FuncOp op, unsigned index)
static bool isControlOperand(Operation *op, Value v)
Returns true if v is used as a control operand in op.
static std::string getUniqueArgName(StringRef instanceName, handshake::FuncOp op, unsigned index)
static std::string getResName(handshake::FuncOp op, unsigned index)
static void tryAddExtraEdgeInfo(mlir::raw_indented_ostream &os, Operation *from, Value result, Operation *to)
Emits additional, non-graphviz information about the connection between from- and to.
assert(baseType &&"element must be base type")
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createHandshakeOpCountPass()
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
LogicalResult resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph, std::string &topLevel, SmallVectorImpl< std::string > &sortedFuncs)
Iterates over the handshake::FuncOp's in the program to build an instance graph.
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createHandshakeDotPrintPass()
std::unique_ptr< mlir::Pass > createHandshakeAddIDsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.