CIRCT  19.0.0git
MSFTOps.cpp
Go to the documentation of this file.
1 //===- MSFTOps.cpp - Implement MSFT dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MSFT dialect operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "circt/Dialect/HW/HWOps.h"
19 
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/Interfaces/FunctionImplementation.h"
23 #include "mlir/Interfaces/FunctionInterfaces.h"
24 #include "llvm/ADT/BitVector.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 
28 using namespace circt;
29 using namespace msft;
30 
31 //===----------------------------------------------------------------------===//
32 // Custom directive parsers/printers
33 //===----------------------------------------------------------------------===//
34 
35 static ParseResult parsePhysLoc(OpAsmParser &p, PhysLocationAttr &attr) {
36  llvm::SMLoc loc = p.getCurrentLocation();
37  StringRef devTypeStr;
38  uint64_t x, y, num;
39 
40  if (p.parseKeyword(&devTypeStr) || p.parseKeyword("x") || p.parseColon() ||
41  p.parseInteger(x) || p.parseKeyword("y") || p.parseColon() ||
42  p.parseInteger(y) || p.parseKeyword("n") || p.parseColon() ||
43  p.parseInteger(num))
44  return failure();
45 
46  std::optional<PrimitiveType> devType = symbolizePrimitiveType(devTypeStr);
47  if (!devType) {
48  p.emitError(loc, "Unknown device type '" + devTypeStr + "'");
49  return failure();
50  }
51  PrimitiveTypeAttr devTypeAttr =
52  PrimitiveTypeAttr::get(p.getContext(), *devType);
53  attr = PhysLocationAttr::get(p.getContext(), devTypeAttr, x, y, num);
54  return success();
55 }
56 
57 static void printPhysLoc(OpAsmPrinter &p, Operation *, PhysLocationAttr loc) {
58  p << stringifyPrimitiveType(loc.getPrimitiveType().getValue())
59  << " x: " << loc.getX() << " y: " << loc.getY() << " n: " << loc.getNum();
60 }
61 
62 static ParseResult parseListOptionalRegLocList(OpAsmParser &p,
63  LocationVectorAttr &locs) {
64  SmallVector<PhysLocationAttr, 32> locArr;
65  TypeAttr type;
66  if (p.parseAttribute(type) || p.parseLSquare() ||
67  p.parseCommaSeparatedList(
68  [&]() { return parseOptionalRegLoc(locArr, p); }) ||
69  p.parseRSquare())
70  return failure();
71 
72  if (failed(LocationVectorAttr::verify(
73  [&p]() { return p.emitError(p.getNameLoc()); }, type, locArr)))
74  return failure();
75  locs = LocationVectorAttr::get(p.getContext(), type, locArr);
76  return success();
77 }
78 
79 static void printListOptionalRegLocList(OpAsmPrinter &p, Operation *,
80  LocationVectorAttr locs) {
81  p << locs.getType() << " [";
82  llvm::interleaveComma(locs.getLocs(), p, [&p](PhysLocationAttr loc) {
83  printOptionalRegLoc(loc, p);
84  });
85  p << "]";
86 }
87 
88 static ParseResult parseImplicitInnerRef(OpAsmParser &p,
89  hw::InnerRefAttr &innerRef) {
90  SymbolRefAttr sym;
91  if (p.parseAttribute(sym))
92  return failure();
93  auto loc = p.getCurrentLocation();
94  if (sym.getNestedReferences().size() != 1)
95  return p.emitError(loc, "expected <module sym>::<inner name>");
96  innerRef = hw::InnerRefAttr::get(
97  sym.getRootReference(),
98  sym.getNestedReferences().front().getRootReference());
99  return success();
100 }
101 void printImplicitInnerRef(OpAsmPrinter &p, Operation *,
102  hw::InnerRefAttr innerRef) {
103  MLIRContext *ctxt = innerRef.getContext();
104  StringRef innerRefNameStr, moduleStr;
105  if (innerRef.getName())
106  innerRefNameStr = innerRef.getName().getValue();
107  if (innerRef.getModule())
108  moduleStr = innerRef.getModule().getValue();
109  p << SymbolRefAttr::get(ctxt, moduleStr,
110  {FlatSymbolRefAttr::get(ctxt, innerRefNameStr)});
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // DynamicInstanceOp
115 //===----------------------------------------------------------------------===//
116 
117 ArrayAttr DynamicInstanceOp::getPath() {
118  SmallVector<Attribute, 16> path;
119  DynamicInstanceOp next = *this;
120  do {
121  path.push_back(next.getInstanceRefAttr());
122  next = next->getParentOfType<DynamicInstanceOp>();
123  } while (next);
124  std::reverse(path.begin(), path.end());
125  return ArrayAttr::get(getContext(), path);
126 }
127 
128 //===----------------------------------------------------------------------===//
129 // OutputOp
130 //===----------------------------------------------------------------------===//
131 
132 void OutputOp::build(OpBuilder &odsBuilder, OperationState &odsState) {}
133 
134 //===----------------------------------------------------------------------===//
135 // MSFT high level design constructs
136 //===----------------------------------------------------------------------===//
137 
138 //===----------------------------------------------------------------------===//
139 // SystolicArrayOp
140 //===----------------------------------------------------------------------===//
141 
142 ParseResult SystolicArrayOp::parse(OpAsmParser &parser,
143  OperationState &result) {
144  uint64_t numRows, numColumns;
145  Type rowType, columnType;
146  OpAsmParser::UnresolvedOperand rowInputs, columnInputs;
147  llvm::SMLoc loc = parser.getCurrentLocation();
148  if (parser.parseLSquare() || parser.parseOperand(rowInputs) ||
149  parser.parseColon() || parser.parseInteger(numRows) ||
150  parser.parseKeyword("x") || parser.parseType(rowType) ||
151  parser.parseRSquare() || parser.parseLSquare() ||
152  parser.parseOperand(columnInputs) || parser.parseColon() ||
153  parser.parseInteger(numColumns) || parser.parseKeyword("x") ||
154  parser.parseType(columnType) || parser.parseRSquare())
155  return failure();
156 
157  hw::ArrayType rowInputType = hw::ArrayType::get(rowType, numRows);
158  hw::ArrayType columnInputType = hw::ArrayType::get(columnType, numColumns);
159  SmallVector<Value> operands;
160  if (parser.resolveOperands({rowInputs, columnInputs},
161  {rowInputType, columnInputType}, loc, operands))
162  return failure();
163  result.addOperands(operands);
164 
165  Type peOutputType;
166  SmallVector<OpAsmParser::Argument> peArgs;
167  if (parser.parseKeyword("pe")) {
168  return failure();
169  }
170  llvm::SMLoc peLoc = parser.getCurrentLocation();
171  if (parser.parseArgumentList(peArgs, AsmParser::Delimiter::Paren)) {
172  return failure();
173  }
174  if (peArgs.size() != 2) {
175  return parser.emitError(peLoc, "expected two operands");
176  }
177 
178  peArgs[0].type = rowType;
179  peArgs[1].type = columnType;
180 
181  if (parser.parseArrow() || parser.parseLParen() ||
182  parser.parseType(peOutputType) || parser.parseRParen())
183  return failure();
184 
185  result.addTypes({hw::ArrayType::get(
186  hw::ArrayType::get(peOutputType, numColumns), numRows)});
187 
188  Region *pe = result.addRegion();
189 
190  peLoc = parser.getCurrentLocation();
191 
192  if (parser.parseRegion(*pe, peArgs))
193  return failure();
194 
195  if (pe->getBlocks().size() != 1)
196  return parser.emitError(peLoc, "expected one block for the PE");
197  Operation *peTerm = pe->getBlocks().front().getTerminator();
198  if (peTerm->getOperands().size() != 1)
199  return peTerm->emitOpError("expected one return value");
200  if (peTerm->getOperand(0).getType() != peOutputType)
201  return peTerm->emitOpError("expected return type as given in parent: ")
202  << peOutputType;
203 
204  return success();
205 }
206 
207 void SystolicArrayOp::print(OpAsmPrinter &p) {
208  hw::ArrayType rowInputType = cast<hw::ArrayType>(getRowInputs().getType());
209  hw::ArrayType columnInputType = cast<hw::ArrayType>(getColInputs().getType());
210  p << " [";
211  p.printOperand(getRowInputs());
212  p << " : " << rowInputType.getNumElements() << " x ";
213  p.printType(rowInputType.getElementType());
214  p << "] [";
215  p.printOperand(getColInputs());
216  p << " : " << columnInputType.getNumElements() << " x ";
217  p.printType(columnInputType.getElementType());
218 
219  p << "] pe (";
220  p.printOperand(getPe().getArgument(0));
221  p << ", ";
222  p.printOperand(getPe().getArgument(1));
223  p << ") -> (";
224  p.printType(
225  cast<hw::ArrayType>(
226  cast<hw::ArrayType>(getPeOutputs().getType()).getElementType())
227  .getElementType());
228  p << ") ";
229  p.printRegion(getPe(), false);
230 }
231 
232 //===----------------------------------------------------------------------===//
233 // LinearOp
234 //===----------------------------------------------------------------------===//
235 
236 LogicalResult LinearOp::verify() {
237 
238  for (auto &op : *getBodyBlock()) {
239  if (!isa<hw::HWDialect, comb::CombDialect, msft::MSFTDialect>(
240  op.getDialect()))
241  return emitOpError() << "expected only hw, comb, and msft dialect ops "
242  "inside the datapath.";
243  }
244 
245  return success();
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // PDMulticycleOp
250 //===----------------------------------------------------------------------===//
251 
252 Operation *PDMulticycleOp::getTopModule(hw::HWSymbolCache &cache) {
253  // Both symbols should reference the same top-level module in their respective
254  // HierPath ops.
255  Operation *srcTop = getHierPathTopModule(getLoc(), cache, getSourceAttr());
256  Operation *dstTop = getHierPathTopModule(getLoc(), cache, getDestAttr());
257  if (srcTop != dstTop) {
258  emitOpError("source and destination paths must refer to the same top-level "
259  "module.");
260  return nullptr;
261  }
262  return srcTop;
263 }
264 
265 #define GET_OP_CLASSES
266 #include "circt/Dialect/MSFT/MSFT.cpp.inc"
static void printListOptionalRegLocList(OpAsmPrinter &p, Operation *, LocationVectorAttr locs)
Definition: MSFTOps.cpp:79
static ParseResult parseListOptionalRegLocList(OpAsmParser &p, LocationVectorAttr &locs)
Definition: MSFTOps.cpp:62
static void printPhysLoc(OpAsmPrinter &p, Operation *, PhysLocationAttr loc)
Definition: MSFTOps.cpp:57
static ParseResult parseImplicitInnerRef(OpAsmParser &p, hw::InnerRefAttr &innerRef)
Definition: MSFTOps.cpp:88
void printImplicitInnerRef(OpAsmPrinter &p, Operation *, hw::InnerRefAttr innerRef)
Definition: MSFTOps.cpp:101
static ParseResult parsePhysLoc(OpAsmParser &p, PhysLocationAttr &attr)
Definition: MSFTOps.cpp:35
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
Operation * getHierPathTopModule(Location loc, circt::hw::HWSymbolCache &symCache, FlatSymbolRefAttr pathSym)
Returns the top-level module which the given HierPathOp that defines pathSym, refers to.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: msft.py:1