CIRCT  19.0.0git
MSFTLowerConstructs.cpp
Go to the documentation of this file.
1 //===- MSFTLowerConstructs.cpp - MSFT constructs lowerings ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetails.h"
11 #include "circt/Dialect/HW/HWOps.h"
18 
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Pass/PassRegistry.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 using namespace mlir;
28 using namespace circt;
29 using namespace msft;
30 
31 //===----------------------------------------------------------------------===//
32 // Lower MSFT constructs
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 
37 struct LowerConstructsPass : public LowerConstructsBase<LowerConstructsPass> {
38  void runOnOperation() override;
39 
40  /// For naming purposes, get the inner Namespace for a module, building it
41  /// lazily.
42  Namespace &getNamespaceFor(Operation *mod) {
43  auto ns = moduleNamespaces.find(mod);
44  if (ns != moduleNamespaces.end())
45  return ns->getSecond();
46  Namespace &nsNew = moduleNamespaces[mod];
47  SymbolCache syms;
48  syms.addDefinitions(mod);
49  nsNew.add(syms);
50  return nsNew;
51  }
52 
53 private:
54  DenseMap<Operation *, circt::Namespace> moduleNamespaces;
55 };
56 } // anonymous namespace
57 
58 namespace {
59 /// Lower MSFT's OutputOp to HW's.
60 struct SystolicArrayOpLowering : public OpConversionPattern<SystolicArrayOp> {
61 public:
62  using OpConversionPattern::OpConversionPattern;
63 
64  LogicalResult
65  matchAndRewrite(SystolicArrayOp array, OpAdaptor adaptor,
66  ConversionPatternRewriter &rewriter) const final {
67  MLIRContext *ctxt = getContext();
68  Location loc = array.getLoc();
69  Block &peBlock = array.getPe().front();
70  rewriter.setInsertionPointAfter(array);
71 
72  // For the row broadcasts, break out the row values which must be broadcast
73  // to each PE.
74  hw::ArrayType rowInputs =
75  hw::type_cast<hw::ArrayType>(array.getRowInputs().getType());
76  IntegerType rowIdxType = rewriter.getIntegerType(
77  std::max(1u, llvm::Log2_64_Ceil(rowInputs.getNumElements())));
78  SmallVector<Value> rowValues;
79  for (size_t rowNum = 0, numRows = rowInputs.getNumElements();
80  rowNum < numRows; ++rowNum) {
81  Value rowNumVal =
82  rewriter.create<hw::ConstantOp>(loc, rowIdxType, rowNum);
83  auto rowValue =
84  rewriter.create<hw::ArrayGetOp>(loc, array.getRowInputs(), rowNumVal);
85  rowValue->setAttr("sv.namehint",
86  StringAttr::get(ctxt, "row_" + Twine(rowNum)));
87  rowValues.push_back(rowValue);
88  }
89 
90  // For the column broadcasts, break out the column values which must be
91  // broadcast to each PE.
92  hw::ArrayType colInputs =
93  hw::type_cast<hw::ArrayType>(array.getColInputs().getType());
94  IntegerType colIdxType = rewriter.getIntegerType(
95  std::max(1u, llvm::Log2_64_Ceil(colInputs.getNumElements())));
96  SmallVector<Value> colValues;
97  for (size_t colNum = 0, numCols = colInputs.getNumElements();
98  colNum < numCols; ++colNum) {
99  Value colNumVal =
100  rewriter.create<hw::ConstantOp>(loc, colIdxType, colNum);
101  auto colValue =
102  rewriter.create<hw::ArrayGetOp>(loc, array.getColInputs(), colNumVal);
103  colValue->setAttr("sv.namehint",
104  StringAttr::get(ctxt, "col_" + Twine(colNum)));
105  colValues.push_back(colValue);
106  }
107 
108  // Build the PE matrix.
109  SmallVector<Value> peOutputs;
110  for (size_t rowNum = 0, numRows = rowInputs.getNumElements();
111  rowNum < numRows; ++rowNum) {
112  Value rowValue = rowValues[rowNum];
113  SmallVector<Value> colPEOutputs;
114  for (size_t colNum = 0, numCols = colInputs.getNumElements();
115  colNum < numCols; ++colNum) {
116  Value colValue = colValues[colNum];
117  // Clone the PE block, substituting %row (arg 0) and %col (arg 1) for
118  // the corresponding row/column broadcast value.
119  // NOTE: the PE region is NOT a graph region so we don't have to deal
120  // with backedges.
121  IRMapping mapper;
122  mapper.map(peBlock.getArgument(0), rowValue);
123  mapper.map(peBlock.getArgument(1), colValue);
124  for (Operation &peOperation : peBlock)
125  // If we see the output op (which should be the block terminator), add
126  // its operand to the output matrix.
127  if (auto outputOp = dyn_cast<PEOutputOp>(peOperation)) {
128  colPEOutputs.push_back(mapper.lookup(outputOp.getOutput()));
129  } else {
130  Operation *clone = rewriter.clone(peOperation, mapper);
131 
132  StringRef nameSource = "name";
133  auto name = clone->getAttrOfType<StringAttr>(nameSource);
134  if (!name) {
135  nameSource = "sv.namehint";
136  name = clone->getAttrOfType<StringAttr>(nameSource);
137  }
138  if (name)
139  clone->setAttr(nameSource,
140  StringAttr::get(ctxt, name.getValue() + "_" +
141  Twine(rowNum) + "_" +
142  Twine(colNum)));
143  }
144  }
145  // Reverse the vector since ArrayCreateOp has the opposite ordering to C
146  // vectors.
147  std::reverse(colPEOutputs.begin(), colPEOutputs.end());
148  peOutputs.push_back(
149  rewriter.create<hw::ArrayCreateOp>(loc, colPEOutputs));
150  }
151 
152  std::reverse(peOutputs.begin(), peOutputs.end());
153  rewriter.replaceOp(array,
154  rewriter.create<hw::ArrayCreateOp>(loc, peOutputs));
155  return success();
156  }
157 };
158 } // anonymous namespace
159 
160 void LowerConstructsPass::runOnOperation() {
161  auto top = getOperation();
162  auto *ctxt = &getContext();
163 
164  ConversionTarget target(*ctxt);
165  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
166 
167  RewritePatternSet patterns(ctxt);
168  patterns.insert<SystolicArrayOpLowering>(ctxt);
169  target.addIllegalOp<SystolicArrayOp>();
170 
171  if (failed(mlir::applyPartialConversion(top, target, std::move(patterns))))
172  signalPassFailure();
173 }
174 
175 std::unique_ptr<Pass> circt::msft::createLowerConstructsPass() {
176  return std::make_unique<LowerConstructsPass>();
177 }
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:29
void add(SymbolCache &symCache)
SymbolCache initializer; initialize from every key that is convertible to a StringAttr in the SymbolCache.
Definition: Namespace.h:47
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operation's.
Definition: SymCache.h:85
def create(data_type, value)
Definition: hw.py:393
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
std::unique_ptr< mlir::Pass > createLowerConstructsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: msft.py:1