19#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Pass/PassRegistry.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29#define GEN_PASS_DEF_LOWERCONSTRUCTS
30#include "circt/Dialect/MSFT/MSFTPasses.h.inc"
// NOTE(review): this listing is an extracted fragment. Original source line
// numbers are fused into the text, and original lines 47-49 and 54-61 are
// missing, so this struct does not compile as shown.
//
// Pass implementation built on the tablegen-generated LowerConstructsBase
// (see GEN_PASS_DEF_LOWERCONSTRUCTS in the include block).
44struct LowerConstructsPass
45 :
public circt::msft::impl::LowerConstructsBase<LowerConstructsPass> {
// Pass entry point; defined out of line as LowerConstructsPass::runOnOperation.
46 void runOnOperation()
override;
// Returns the circt::Namespace already cached for `mod`, if one exists.
// The code that handles a cache miss is on lines not shown in this fragment.
50 Namespace &getNamespaceFor(Operation *mod) {
51 auto ns = moduleNamespaces.find(mod);
52 if (ns != moduleNamespaces.end())
53 return ns->getSecond();
// Cache holding at most one circt::Namespace per Operation pointer.
62 DenseMap<Operation *, circt::Namespace> moduleNamespaces;
// Conversion pattern for SystolicArrayOp.
// NOTE(review): the enclosing class header and many original lines are
// missing from this extracted fragment; it does not compile as shown.
70 using OpConversionPattern::OpConversionPattern;
// Expands a SystolicArrayOp in place:
//  1. builds one value per row input and one per column input
//     (presumably extracted with index constants of rowIdxType/colIdxType;
//     the creating lines are not shown);
//  2. clones the body of the op's PE region once per (row, column) pair;
//  3. replaces the op (the replacement values are cut off at the end of this
//     fragment; presumably derived from peOutputs).
73 matchAndRewrite(SystolicArrayOp array, OpAdaptor adaptor,
74 ConversionPatternRewriter &rewriter)
const final {
75 MLIRContext *
ctxt = getContext();
76 Location loc = array.getLoc();
// Single block of the PE region; its ops are cloned for every PE.
77 Block &peBlock = array.getPe().front();
// All newly created ops are inserted immediately after the array op.
78 rewriter.setInsertionPointAfter(array);
// Row index width: ceil(log2(numRows)) bits, clamped to at least 1 so that a
// single-row array still gets a non-zero-width index type.
82 hw::ArrayType rowInputs =
83 hw::type_cast<hw::ArrayType>(array.getRowInputs().getType());
84 IntegerType rowIdxType = rewriter.getIntegerType(
85 std::max(1u, llvm::Log2_64_Ceil(rowInputs.getNumElements())));
86 SmallVector<Value> rowValues;
87 for (
size_t rowNum = 0, numRows = rowInputs.getNumElements();
88 rowNum < numRows; ++rowNum) {
// Give each per-row value an "sv.namehint" of row_<rowNum>.
93 rowValue->setAttr(
"sv.namehint",
94 StringAttr::get(ctxt,
"row_" + Twine(rowNum)));
95 rowValues.push_back(rowValue);
// Same treatment for columns: index width is max(1, ceil(log2(numCols))).
100 hw::ArrayType colInputs =
101 hw::type_cast<hw::ArrayType>(array.getColInputs().getType());
102 IntegerType colIdxType = rewriter.getIntegerType(
103 std::max(1u, llvm::Log2_64_Ceil(colInputs.getNumElements())));
104 SmallVector<Value> colValues;
105 for (
size_t colNum = 0, numCols = colInputs.getNumElements();
106 colNum < numCols; ++colNum) {
// Give each per-column value an "sv.namehint" of col_<colNum>.
111 colValue->setAttr(
"sv.namehint",
112 StringAttr::get(ctxt,
"col_" + Twine(colNum)));
113 colValues.push_back(colValue);
// Instantiate the PE body once per (row, column) pair, collecting each PE's
// output per row in colPEOutputs.
117 SmallVector<Value> peOutputs;
118 for (
size_t rowNum = 0, numRows = rowInputs.getNumElements();
119 rowNum < numRows; ++rowNum) {
120 Value rowValue = rowValues[rowNum];
121 SmallVector<Value> colPEOutputs;
122 for (
size_t colNum = 0, numCols = colInputs.getNumElements();
123 colNum < numCols; ++colNum) {
124 Value colValue = colValues[colNum];
// Bind PE block argument 0 to this row's value and argument 1 to this
// column's value. (`mapper` is declared on lines not shown.)
130 mapper.map(peBlock.getArgument(0), rowValue);
131 mapper.map(peBlock.getArgument(1), colValue);
// For a PEOutputOp, record its remapped operand as this PE's output.
// Other ops are cloned through `mapper`, presumably in an else-branch
// whose boundary (original line 137) is not shown.
132 for (Operation &peOperation : peBlock)
135 if (auto outputOp = dyn_cast<PEOutputOp>(peOperation)) {
136 colPEOutputs.push_back(mapper.lookup(outputOp.getOutput()));
138 Operation *clone = rewriter.clone(peOperation, mapper);
// Make the clone's name unique per PE. Prefer its "name" attribute and
// fall back to "sv.namehint" (presumably when "name" is absent; the guards
// on original lines 142 and 145-146 are not shown). Append "_<row>_..."
// to that name; the remainder of the suffix (presumably the column
// number) is cut off in this fragment.
140 StringRef nameSource =
"name";
141 auto name = clone->getAttrOfType<StringAttr>(nameSource);
143 nameSource =
"sv.namehint";
144 name = clone->getAttrOfType<StringAttr>(nameSource);
147 clone->setAttr(nameSource,
148 StringAttr::get(ctxt, name.getValue() +
"_" +
149 Twine(rowNum) +
"_" +
// NOTE(review): both output lists are reversed, presumably because the
// array-building op places its first operand at the highest index; confirm
// against the hw.array_create semantics.
155 std::reverse(colPEOutputs.begin(), colPEOutputs.end());
160 std::reverse(peOutputs.begin(), peOutputs.end());
// NOTE(review): the replacement values for `array` are cut off here.
161 rewriter.replaceOp(array,
// Runs a partial dialect conversion over the pass's root operation.
// SystolicArrayOp is marked illegal and must be rewritten; every other op is
// treated as legal. NOTE(review): original lines 174-176 (where `patterns`
// is built) and 180-183 (failure handling / end of function) are missing
// from this extracted fragment.
168void LowerConstructsPass::runOnOperation() {
169 auto top = getOperation();
170 auto *
ctxt = &getContext();
172 ConversionTarget target(*ctxt);
// Any op not explicitly configured on the target is considered legal.
173 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
177 target.addIllegalOp<SystolicArrayOp>();
// Partial conversion: only the illegal ops above must be converted.
179 if (failed(mlir::applyPartialConversion(top, target, std::move(
patterns))))
// Body of createLowerConstructsPass(): returns a new instance of this pass.
// NOTE(review): the function header is not part of this fragment.
184 return std::make_unique<LowerConstructsPass>();
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations from names (StringAttrs) to mlir::Operation instances.
std::unique_ptr< mlir::Pass > createLowerConstructsPass()
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.