CIRCT  19.0.0git
MSFTLowerConstructs.cpp
Go to the documentation of this file.
1 //===- MSFTLowerConstructs.cpp - MSFT constructs lowerings ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetails.h"
11 #include "circt/Dialect/HW/HWOps.h"
18 
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Pass/PassRegistry.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 namespace circt {
28 namespace msft {
29 #define GEN_PASS_DEF_LOWERCONSTRUCTS
30 #include "circt/Dialect/MSFT/MSFTPasses.h.inc"
31 } // namespace msft
32 } // namespace circt
33 
34 using namespace mlir;
35 using namespace circt;
36 using namespace msft;
37 
38 //===----------------------------------------------------------------------===//
39 // Lower MSFT constructs
40 //===----------------------------------------------------------------------===//
41 
42 namespace {
43 
44 struct LowerConstructsPass
45  : public circt::msft::impl::LowerConstructsBase<LowerConstructsPass> {
46  void runOnOperation() override;
47 
48  /// For naming purposes, get the inner Namespace for a module, building it
49  /// lazily.
50  Namespace &getNamespaceFor(Operation *mod) {
51  auto ns = moduleNamespaces.find(mod);
52  if (ns != moduleNamespaces.end())
53  return ns->getSecond();
54  Namespace &nsNew = moduleNamespaces[mod];
55  SymbolCache syms;
56  syms.addDefinitions(mod);
57  nsNew.add(syms);
58  return nsNew;
59  }
60 
61 private:
62  DenseMap<Operation *, circt::Namespace> moduleNamespaces;
63 };
64 } // anonymous namespace
65 
66 namespace {
67 /// Lower MSFT's OutputOp to HW's.
68 struct SystolicArrayOpLowering : public OpConversionPattern<SystolicArrayOp> {
69 public:
70  using OpConversionPattern::OpConversionPattern;
71 
72  LogicalResult
73  matchAndRewrite(SystolicArrayOp array, OpAdaptor adaptor,
74  ConversionPatternRewriter &rewriter) const final {
75  MLIRContext *ctxt = getContext();
76  Location loc = array.getLoc();
77  Block &peBlock = array.getPe().front();
78  rewriter.setInsertionPointAfter(array);
79 
80  // For the row broadcasts, break out the row values which must be broadcast
81  // to each PE.
82  hw::ArrayType rowInputs =
83  hw::type_cast<hw::ArrayType>(array.getRowInputs().getType());
84  IntegerType rowIdxType = rewriter.getIntegerType(
85  std::max(1u, llvm::Log2_64_Ceil(rowInputs.getNumElements())));
86  SmallVector<Value> rowValues;
87  for (size_t rowNum = 0, numRows = rowInputs.getNumElements();
88  rowNum < numRows; ++rowNum) {
89  Value rowNumVal =
90  rewriter.create<hw::ConstantOp>(loc, rowIdxType, rowNum);
91  auto rowValue =
92  rewriter.create<hw::ArrayGetOp>(loc, array.getRowInputs(), rowNumVal);
93  rowValue->setAttr("sv.namehint",
94  StringAttr::get(ctxt, "row_" + Twine(rowNum)));
95  rowValues.push_back(rowValue);
96  }
97 
98  // For the column broadcasts, break out the column values which must be
99  // broadcast to each PE.
100  hw::ArrayType colInputs =
101  hw::type_cast<hw::ArrayType>(array.getColInputs().getType());
102  IntegerType colIdxType = rewriter.getIntegerType(
103  std::max(1u, llvm::Log2_64_Ceil(colInputs.getNumElements())));
104  SmallVector<Value> colValues;
105  for (size_t colNum = 0, numCols = colInputs.getNumElements();
106  colNum < numCols; ++colNum) {
107  Value colNumVal =
108  rewriter.create<hw::ConstantOp>(loc, colIdxType, colNum);
109  auto colValue =
110  rewriter.create<hw::ArrayGetOp>(loc, array.getColInputs(), colNumVal);
111  colValue->setAttr("sv.namehint",
112  StringAttr::get(ctxt, "col_" + Twine(colNum)));
113  colValues.push_back(colValue);
114  }
115 
116  // Build the PE matrix.
117  SmallVector<Value> peOutputs;
118  for (size_t rowNum = 0, numRows = rowInputs.getNumElements();
119  rowNum < numRows; ++rowNum) {
120  Value rowValue = rowValues[rowNum];
121  SmallVector<Value> colPEOutputs;
122  for (size_t colNum = 0, numCols = colInputs.getNumElements();
123  colNum < numCols; ++colNum) {
124  Value colValue = colValues[colNum];
125  // Clone the PE block, substituting %row (arg 0) and %col (arg 1) for
126  // the corresponding row/column broadcast value.
127  // NOTE: the PE region is NOT a graph region so we don't have to deal
128  // with backedges.
129  IRMapping mapper;
130  mapper.map(peBlock.getArgument(0), rowValue);
131  mapper.map(peBlock.getArgument(1), colValue);
132  for (Operation &peOperation : peBlock)
133  // If we see the output op (which should be the block terminator), add
134  // its operand to the output matrix.
135  if (auto outputOp = dyn_cast<PEOutputOp>(peOperation)) {
136  colPEOutputs.push_back(mapper.lookup(outputOp.getOutput()));
137  } else {
138  Operation *clone = rewriter.clone(peOperation, mapper);
139 
140  StringRef nameSource = "name";
141  auto name = clone->getAttrOfType<StringAttr>(nameSource);
142  if (!name) {
143  nameSource = "sv.namehint";
144  name = clone->getAttrOfType<StringAttr>(nameSource);
145  }
146  if (name)
147  clone->setAttr(nameSource,
148  StringAttr::get(ctxt, name.getValue() + "_" +
149  Twine(rowNum) + "_" +
150  Twine(colNum)));
151  }
152  }
153  // Reverse the vector since ArrayCreateOp has the opposite ordering to C
154  // vectors.
155  std::reverse(colPEOutputs.begin(), colPEOutputs.end());
156  peOutputs.push_back(
157  rewriter.create<hw::ArrayCreateOp>(loc, colPEOutputs));
158  }
159 
160  std::reverse(peOutputs.begin(), peOutputs.end());
161  rewriter.replaceOp(array,
162  rewriter.create<hw::ArrayCreateOp>(loc, peOutputs));
163  return success();
164  }
165 };
166 } // anonymous namespace
167 
168 void LowerConstructsPass::runOnOperation() {
169  auto top = getOperation();
170  auto *ctxt = &getContext();
171 
172  ConversionTarget target(*ctxt);
173  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
174 
175  RewritePatternSet patterns(ctxt);
176  patterns.insert<SystolicArrayOpLowering>(ctxt);
177  target.addIllegalOp<SystolicArrayOp>();
178 
179  if (failed(mlir::applyPartialConversion(top, target, std::move(patterns))))
180  signalPassFailure();
181 }
182 
183 std::unique_ptr<Pass> circt::msft::createLowerConstructsPass() {
184  return std::make_unique<LowerConstructsPass>();
185 }
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
void add(mlir::ModuleOp module)
Definition: Namespace.h:46
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition: SymCache.h:85
def create(data_type, value)
Definition: hw.py:393
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
std::unique_ptr< mlir::Pass > createLowerConstructsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: msft.py:1