CIRCT 20.0.0git
Loading...
Searching...
No Matches
MSFTLowerConstructs.cpp
Go to the documentation of this file.
1//===- MSFTLowerConstructs.cpp - MSFT constructs lowerings ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PassDetails.h"
18
19#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Pass/PassRegistry.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27namespace circt {
28namespace msft {
29#define GEN_PASS_DEF_LOWERCONSTRUCTS
30#include "circt/Dialect/MSFT/MSFTPasses.h.inc"
31} // namespace msft
32} // namespace circt
33
34using namespace mlir;
35using namespace circt;
36using namespace msft;
37
38//===----------------------------------------------------------------------===//
39// Lower MSFT constructs
40//===----------------------------------------------------------------------===//
41
42namespace {
43
44struct LowerConstructsPass
45 : public circt::msft::impl::LowerConstructsBase<LowerConstructsPass> {
46 void runOnOperation() override;
47
48 /// For naming purposes, get the inner Namespace for a module, building it
49 /// lazily.
50 Namespace &getNamespaceFor(Operation *mod) {
51 auto ns = moduleNamespaces.find(mod);
52 if (ns != moduleNamespaces.end())
53 return ns->getSecond();
54 Namespace &nsNew = moduleNamespaces[mod];
55 SymbolCache syms;
56 syms.addDefinitions(mod);
57 nsNew.add(syms);
58 return nsNew;
59 }
60
61private:
62 DenseMap<Operation *, circt::Namespace> moduleNamespaces;
63};
64} // anonymous namespace
65
66namespace {
67/// Lower MSFT's OutputOp to HW's.
68struct SystolicArrayOpLowering : public OpConversionPattern<SystolicArrayOp> {
69public:
70 using OpConversionPattern::OpConversionPattern;
71
72 LogicalResult
73 matchAndRewrite(SystolicArrayOp array, OpAdaptor adaptor,
74 ConversionPatternRewriter &rewriter) const final {
75 MLIRContext *ctxt = getContext();
76 Location loc = array.getLoc();
77 Block &peBlock = array.getPe().front();
78 rewriter.setInsertionPointAfter(array);
79
80 // For the row broadcasts, break out the row values which must be broadcast
81 // to each PE.
82 hw::ArrayType rowInputs =
83 hw::type_cast<hw::ArrayType>(array.getRowInputs().getType());
84 IntegerType rowIdxType = rewriter.getIntegerType(
85 std::max(1u, llvm::Log2_64_Ceil(rowInputs.getNumElements())));
86 SmallVector<Value> rowValues;
87 for (size_t rowNum = 0, numRows = rowInputs.getNumElements();
88 rowNum < numRows; ++rowNum) {
89 Value rowNumVal =
90 rewriter.create<hw::ConstantOp>(loc, rowIdxType, rowNum);
91 auto rowValue =
92 rewriter.create<hw::ArrayGetOp>(loc, array.getRowInputs(), rowNumVal);
93 rowValue->setAttr("sv.namehint",
94 StringAttr::get(ctxt, "row_" + Twine(rowNum)));
95 rowValues.push_back(rowValue);
96 }
97
98 // For the column broadcasts, break out the column values which must be
99 // broadcast to each PE.
100 hw::ArrayType colInputs =
101 hw::type_cast<hw::ArrayType>(array.getColInputs().getType());
102 IntegerType colIdxType = rewriter.getIntegerType(
103 std::max(1u, llvm::Log2_64_Ceil(colInputs.getNumElements())));
104 SmallVector<Value> colValues;
105 for (size_t colNum = 0, numCols = colInputs.getNumElements();
106 colNum < numCols; ++colNum) {
107 Value colNumVal =
108 rewriter.create<hw::ConstantOp>(loc, colIdxType, colNum);
109 auto colValue =
110 rewriter.create<hw::ArrayGetOp>(loc, array.getColInputs(), colNumVal);
111 colValue->setAttr("sv.namehint",
112 StringAttr::get(ctxt, "col_" + Twine(colNum)));
113 colValues.push_back(colValue);
114 }
115
116 // Build the PE matrix.
117 SmallVector<Value> peOutputs;
118 for (size_t rowNum = 0, numRows = rowInputs.getNumElements();
119 rowNum < numRows; ++rowNum) {
120 Value rowValue = rowValues[rowNum];
121 SmallVector<Value> colPEOutputs;
122 for (size_t colNum = 0, numCols = colInputs.getNumElements();
123 colNum < numCols; ++colNum) {
124 Value colValue = colValues[colNum];
125 // Clone the PE block, substituting %row (arg 0) and %col (arg 1) for
126 // the corresponding row/column broadcast value.
127 // NOTE: the PE region is NOT a graph region so we don't have to deal
128 // with backedges.
129 IRMapping mapper;
130 mapper.map(peBlock.getArgument(0), rowValue);
131 mapper.map(peBlock.getArgument(1), colValue);
132 for (Operation &peOperation : peBlock)
133 // If we see the output op (which should be the block terminator), add
134 // its operand to the output matrix.
135 if (auto outputOp = dyn_cast<PEOutputOp>(peOperation)) {
136 colPEOutputs.push_back(mapper.lookup(outputOp.getOutput()));
137 } else {
138 Operation *clone = rewriter.clone(peOperation, mapper);
139
140 StringRef nameSource = "name";
141 auto name = clone->getAttrOfType<StringAttr>(nameSource);
142 if (!name) {
143 nameSource = "sv.namehint";
144 name = clone->getAttrOfType<StringAttr>(nameSource);
145 }
146 if (name)
147 clone->setAttr(nameSource,
148 StringAttr::get(ctxt, name.getValue() + "_" +
149 Twine(rowNum) + "_" +
150 Twine(colNum)));
151 }
152 }
153 // Reverse the vector since ArrayCreateOp has the opposite ordering to C
154 // vectors.
155 std::reverse(colPEOutputs.begin(), colPEOutputs.end());
156 peOutputs.push_back(
157 rewriter.create<hw::ArrayCreateOp>(loc, colPEOutputs));
158 }
159
160 std::reverse(peOutputs.begin(), peOutputs.end());
161 rewriter.replaceOp(array,
162 rewriter.create<hw::ArrayCreateOp>(loc, peOutputs));
163 return success();
164 }
165};
166} // anonymous namespace
167
168void LowerConstructsPass::runOnOperation() {
169 auto top = getOperation();
170 auto *ctxt = &getContext();
171
172 ConversionTarget target(*ctxt);
173 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
174
175 RewritePatternSet patterns(ctxt);
176 patterns.insert<SystolicArrayOpLowering>(ctxt);
177 target.addIllegalOp<SystolicArrayOp>();
178
179 if (failed(mlir::applyPartialConversion(top, target, std::move(patterns))))
180 signalPassFailure();
181}
182
184 return std::make_unique<LowerConstructsPass>();
185}
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
create(data_type, value)
Definition hw.py:433
std::unique_ptr< mlir::Pass > createLowerConstructsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition msft.py:1