19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Pass/PassRegistry.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 using namespace circt;
// Pass that lowers SystolicArrayOp and ChannelOp; the rewrite patterns are
// registered in runOnOperation().
// NOTE(review): this listing is an excerpt -- the original line numbers
// embedded in each line skip, so parts of this definition are not shown.
37 struct LowerConstructsPass :
public LowerConstructsBase<LowerConstructsPass> {
38 void runOnOperation()
override;
// Return the Namespace cached for `mod`. On a cache hit the stored entry is
// returned by reference; the cache-miss path is not visible in this excerpt.
42 Namespace &getNamespaceFor(Operation *mod) {
43 auto ns = moduleNamespaces.find(mod);
44 if (ns != moduleNamespaces.end())
45 return ns->getSecond();
// Per-operation cache of namespaces, keyed by the operation passed to
// getNamespaceFor().
54 DenseMap<Operation *, circt::Namespace> moduleNamespaces;
// Reuse the base OpConversionPattern constructors.
62 using OpConversionPattern::OpConversionPattern;
// Rewrite a SystolicArrayOp. The visible code builds one value per row and per
// column, clones the operations of the PE block once for every (row, column)
// pair, collects each PE's output, and finally replaces the array op.
// NOTE(review): several statements are missing from this excerpt (the embedded
// line numbers skip), so the comments below describe only what is shown.
65 matchAndRewrite(SystolicArrayOp array, OpAdaptor adaptor,
66 ConversionPatternRewriter &rewriter)
const final {
67 MLIRContext *ctxt = getContext();
68 Location loc = array.getLoc();
// First block of the array's PE region; its operations are cloned below.
69 Block &peBlock = array.getPe().front();
// New operations are created right after the op being lowered.
70 rewriter.setInsertionPointAfter(array);
// Row inputs, viewed as an hw::ArrayType.
74 hw::ArrayType rowInputs =
75 hw::type_cast<hw::ArrayType>(array.getRowInputs().getType());
// Integer type for a row index: ceil(log2(numElements)) bits, but never fewer
// than 1 bit (Log2_64_Ceil yields 0 for a single-element array).
76 IntegerType rowIdxType = rewriter.getIntegerType(
77 std::max(1u, llvm::Log2_64_Ceil(rowInputs.getNumElements())));
// One entry per row. How `rowValue` is created is not shown in this excerpt;
// each one is tagged with an "sv.namehint" attribute before being stored.
78 SmallVector<Value> rowValues;
79 for (
size_t rowNum = 0, numRows = rowInputs.getNumElements();
80 rowNum < numRows; ++rowNum) {
85 rowValue->setAttr(
"sv.namehint",
87 rowValues.push_back(rowValue);
// Same treatment for the column inputs.
92 hw::ArrayType colInputs =
93 hw::type_cast<hw::ArrayType>(array.getColInputs().getType());
// Column index type, sized the same way as rowIdxType.
94 IntegerType colIdxType = rewriter.getIntegerType(
95 std::max(1u, llvm::Log2_64_Ceil(colInputs.getNumElements())));
96 SmallVector<Value> colValues;
97 for (
size_t colNum = 0, numCols = colInputs.getNumElements();
98 colNum < numCols; ++colNum) {
103 colValue->setAttr(
"sv.namehint",
105 colValues.push_back(colValue);
// Instantiate the PE body for every (row, column) combination.
109 SmallVector<Value> peOutputs;
110 for (
size_t rowNum = 0, numRows = rowInputs.getNumElements();
111 rowNum < numRows; ++rowNum) {
112 Value rowValue = rowValues[rowNum];
// Outputs of the PEs in the current row, one per column.
113 SmallVector<Value> colPEOutputs;
114 for (
size_t colNum = 0, numCols = colInputs.getNumElements();
115 colNum < numCols; ++colNum) {
116 Value colValue = colValues[colNum];
// Bind PE block argument 0 to the row value and argument 1 to the column
// value for this instance.
122 mapper.map(peBlock.getArgument(0), rowValue);
123 mapper.map(peBlock.getArgument(1), colValue);
124 for (Operation &peOperation : peBlock)
// For a PEOutputOp, record the value its output operand maps to as this PE's
// result.
127 if (
auto outputOp = dyn_cast<PEOutputOp>(peOperation)) {
128 colPEOutputs.push_back(mapper.lookup(outputOp.getOutput()));
// Clone the operation with its operands remapped through `mapper`.
130 Operation *clone = rewriter.clone(peOperation, mapper);
// Look for a string attribute called "name", and also "sv.namehint" (the
// condition between the two lookups is not shown -- presumably a fallback when
// "name" is absent; confirm against the full source).
132 StringRef nameSource =
"name";
133 auto name = clone->getAttrOfType<StringAttr>(nameSource);
135 nameSource =
"sv.namehint";
136 name = clone->getAttrOfType<StringAttr>(nameSource);
// Overwrite that attribute on the clone; the visible part of the new value
// contains rowNum followed by "_" (the remainder is not shown).
139 clone->setAttr(nameSource,
141 Twine(rowNum) +
"_" +
// The collected outputs are reversed at both levels before use.
// NOTE(review): presumably because the array-building op expects elements in
// highest-index-first order -- confirm.
147 std::reverse(colPEOutputs.begin(), colPEOutputs.end());
152 std::reverse(peOutputs.begin(), peOutputs.end());
// Replace the original array op (replacement values are not shown here).
153 rewriter.replaceOp(array,
// Pattern that lowers a ChannelOp. It keeps a reference to the owning pass so
// it can use the pass's per-module namespaces (see `pass` below).
// NOTE(review): this listing is an excerpt; the constructor body and parts of
// the loop are not shown.
164 ChannelOpLowering(MLIRContext *ctxt, LowerConstructsPass &pass)
// Replace `chan` with the value `v`, which starts as the channel's input and
// is carried through a loop that runs getDefaultStages() times.
168 matchAndRewrite(ChannelOp chan, OpAdaptor adaptor,
169 ConversionPatternRewriter &rewriter)
const final {
170 Location loc = chan.getLoc();
// The enclosing module is needed to pick the right namespace.
171 Operation *mod = chan->getParentOfType<hw::HWModuleLike>();
172 assert(mod &&
"ChannelOp must be contained by module");
173 Namespace &ns = pass.getNamespaceFor(mod);
174 Value
clk = chan.getClk();
175 Value v = chan.getInput();
// One iteration per default stage. Each iteration requests a fresh unique name
// derived from the channel's symbol name; the operation built per stage is not
// visible in this excerpt (presumably a clocked register using `clk` --
// confirm).
176 for (uint64_t stageNum = 0, e = chan.getDefaultStages(); stageNum < e;
179 ns.
newName(chan.getSymName()));
// The value left in `v` after the loop stands in for the channel's result.
180 rewriter.replaceOp(chan, {v});
// Owning pass; used in matchAndRewrite to call getNamespaceFor().
185 LowerConstructsPass &pass;
// Pass entry point: run a partial dialect conversion over the operation this
// pass is applied to, lowering SystolicArrayOp and ChannelOp.
// NOTE(review): excerpt -- the declaration of `patterns` and the failure
// handling after applyPartialConversion are not shown.
189 void LowerConstructsPass::runOnOperation() {
190 auto top = getOperation();
191 auto *ctxt = &getContext();
// Every op is treated as legal unless explicitly marked illegal below.
193 ConversionTarget target(*ctxt);
194 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
// Register each lowering pattern and mark the op it handles as illegal so the
// conversion must rewrite it.
197 patterns.insert<SystolicArrayOpLowering>(ctxt);
198 target.addIllegalOp<SystolicArrayOp>();
// ChannelOpLowering additionally receives this pass instance.
199 patterns.insert<ChannelOpLowering>(ctxt, *
this);
200 target.addIllegalOp<ChannelOp>();
202 if (failed(mlir::applyPartialConversion(top, target, std::move(
patterns))))
// Hand back a newly allocated LowerConstructsPass (the enclosing function
// header is not shown in this excerpt).
207 return std::make_unique<LowerConstructsPass>();
assert(baseType &&"element must be base type")
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(SymbolCache &symCache)
SymbolCache initializer; initialize from every key that is convertible to a StringAttr in the SymbolCache.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operation's.
def create(data_type, value)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createLowerConstructsPass()
This file defines an intermediate representation for circuits acting as an abstraction for constraints defined over an SMT's solver context.