22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/IRMapping.h"
25#include "mlir/IR/MLIRContext.h"
26#include "mlir/IR/Matchers.h"
27#include "mlir/IR/PatternMatch.h"
28#include "mlir/IR/Types.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#include "llvm/Support/Debug.h"
32#include "llvm/Support/raw_ostream.h"
35#define DEBUG_TYPE "find-initial-vectors"
39#define GEN_PASS_DEF_FINDINITIALVECTORS
40#include "circt/Dialect/Arc/ArcPasses.h.inc"
// Pass class built on the generated impl::FindInitialVectorsBase base.
// NOTE(review): only part of the struct is visible in this excerpt.
49struct FindInitialVectorsPass
50 :
public impl::FindInitialVectorsBase<FindInitialVectorsPass> {
// Pass entry point; defined out of line further down in this file.
51 void runOnOperation()
override;
// Counter bundle handed by reference to Vectorizer::vectorize. Code elsewhere
// in this file uses the members vecOps, savedOps, bigSeedVec and vecCreated;
// their declarations are not visible here.
53 struct StatisticVars {
// Holds a per-operation rank in `opRanks` (declaration not visible in this
// excerpt); the ranks are filled in by compute().
65struct TopologicalOrder {
// Computes the ranks for the operations of `block`. The out-of-line
// definition emits a "dependency cycle" error in one of its failure paths.
68 LogicalResult compute(Block *block);
// Looks up `op` in opRanks and asserts that an entry exists, i.e. callers
// must only ask for ops that compute() has already ranked.
// NOTE(review): the return statement is not visible here; presumably the
// stored rank is returned — confirm against the full file.
69 unsigned get(Operation *op)
const {
70 const auto *it = opRanks.find(op);
71 assert(it != opRanks.end() &&
"op has no rank");
// Assigns a rank to every operation in `block` and records it in opRanks.
//
// Each op not yet ranked is pushed onto a worklist together with an iterator
// over its users. An op receives its rank once that iterator reaches
// user_end(); while users are being visited, the pending rank is raised to at
// least (rank of an already-ranked user) + 1.
//
// NOTE(review): several statements of this function (worklist declaration,
// iterator advancement, pops, closing braces, final return) are not visible
// in this excerpt; the comments below describe only what is shown.
80LogicalResult TopologicalOrder::compute(Block *block) {
81 LLVM_DEBUG(llvm::dbgs() <<
"- Computing topological order in block " << block
// Worklist entry: starts iterating at the first user of `op`.
84 WorklistItem(Operation *op) : userIt(op->user_begin()) {}
85 Operation::user_iterator userIt;
// Seed the worklist with every op of the block that has no rank yet.
89 for (
auto &op : *block) {
90 if (opRanks.contains(&op))
92 worklist.insert({&op, WorklistItem(&op)});
93 while (!worklist.empty()) {
94 auto &[op, item] = worklist.back();
// StateOps with a latency greater than zero and MemoryWritePortOps have
// their user iterator forced to user_end(), so their users are not followed
// from here.
95 if (
auto stateOp = dyn_cast<StateOp>(op)) {
96 if (stateOp.getLatency() > 0)
97 item.userIt = op->user_end();
98 }
else if (
auto writeOp = dyn_cast<MemoryWritePortOp>(op)) {
99 item.userIt = op->user_end();
// All users handled: commit the accumulated rank for this op.
101 if (item.userIt == op->user_end()) {
102 opRanks.insert({op, item.rank});
// The current user already has a rank: make this op's rank at least one
// larger than it.
106 if (
auto *rankIt = opRanks.find(*item.userIt); rankIt != opRanks.end()) {
107 item.rank = std::max(item.rank, rankIt->second + 1);
// Otherwise push the user onto the worklist. If insert() reports that no
// insertion happened (`.second` is false), report a dependency cycle on the
// current op.
111 if (!worklist.insert({*item.userIt, WorklistItem(*item.userIt)}).second)
112 return op->emitError(
"dependency cycle");
// Grouping key for operations: (rank, operation name, operand types,
// result types, attribute dictionary).
119using Key = std::tuple<unsigned, StringRef, SmallVector<Type>,
120 SmallVector<Type>, DictionaryAttr>;
// Builds the Key for `op` from the given `rank`, the op's name string, copies
// of its operand and result type lists, and its attribute dictionary.
122Key computeKey(Operation *op,
unsigned rank) {
125 return std::make_tuple(
126 rank, op->getName().getStringRef(),
127 SmallVector<Type>(op->operand_type_begin(), op->operand_type_end()),
128 SmallVector<Type>(op->result_type_begin(), op->result_type_end()),
129 op->getAttrDictionary());
// NOTE(review): the enclosing struct header and the declarations of `block`
// and `candidates` are not visible in this excerpt.
// Stores the block this vectorizer operates on.
133 Vectorizer(Block *block) : block(block) {}
// Computes the topological order for `block`, then appends every ranked op
// to the `candidates` bucket selected by computeKey(op, rank).
// NOTE(review): the statements taken on failure and the final return are not
// visible here.
134 LogicalResult collectSeeds(Block *block) {
135 if (failed(order.compute(block)))
138 for (
auto &[op, rank] : order.opRanks)
139 candidates[computeKey(op, rank)].push_back(op);
// Performs the vectorization and updates the counters in `stat`; defined out
// of line.
144 LogicalResult vectorize(FindInitialVectorsPass::StatisticVars &stat);
// Rank information consumed by collectSeeds().
147 TopologicalOrder order;
// Key equality: defers to the tuple's operator==.
// NOTE(review): the enclosing type is not visible in this excerpt; this looks
// like part of a DenseMapInfo-style traits struct for Key — confirm.
161 static bool isEqual(
const Key &lhs,
const Key &rhs) {
return lhs == rhs; }
// Groups the ops of `block` into candidate sets (via collectSeeds) and, for
// each set, builds a VectorizeOp whose body holds a clone of the set's first
// op, then redirects the uses of every original result to the matching
// VectorizeOp result. Counters in `stat` are updated along the way.
//
// NOTE(review): the return type line, several early-exit statements, the
// per-op loop header, the declaration of `vectorizeOp` and the closing braces
// are not visible in this excerpt; the comments below describe only the
// statements that are shown.
168Vectorizer::vectorize(FindInitialVectorsPass::StatisticVars &stat) {
169 LLVM_DEBUG(llvm::dbgs() <<
"- Vectorizing the ops in block" << block <<
"\n");
171 if (failed(collectSeeds(block)))
175 if (candidates.empty())
179 for (
const auto &[key, ops] : candidates) {
// Filter condition: the group has exactly one op, or its first op does not
// have exactly one result, or its first op has no operands. The statement
// guarded by this condition is not visible (presumably the group is skipped
// — confirm).
184 if (ops.size() == 1 || ops[0]->getNumResults() != 1 ||
185 ops[0]->getNumOperands() == 0)
// Statistics: ops covered, ops saved (group size minus one), and the largest
// group size seen so far.
189 stat.vecOps += ops.size();
190 stat.savedOps += ops.size() - 1;
191 stat.bigSeedVec = std::max(ops.size(), stat.bigSeedVec);
// One value list per operand position of the first op; each list collects
// the operand at that position from an op `op`.
// NOTE(review): the loop that introduces `op` is not visible; presumably it
// iterates over `ops` — confirm.
197 SmallVector<SmallVector<Value, 4>> vectorOperands;
198 vectorOperands.resize(ops[0]->getNumOperands());
200 for (auto [into, operand] :
llvm::zip(vectorOperands, op->getOperands()))
201 into.push_back(operand);
// Re-expose the collected lists as ValueRanges for the builder call below.
202 SmallVector<ValueRange> operandValueRanges;
203 operandValueRanges.assign(vectorOperands.begin(), vectorOperands.end());
// One result per grouped op, all with the type of the first op's result.
205 SmallVector<Type> resultTypes(ops.size(), ops[0]->getResult(0).getType());
// Builder constructed from the first op's location and the first op itself.
208 ImplicitLocOpBuilder builder(ops[0]->
getLoc(), ops[0]);
210 VectorizeOp::create(builder, resultTypes, operandValueRanges);
// Create the body block of the vectorize op and build into its start.
217 auto &vectorizeBlock = vectorizeOp.getBody().emplaceBlock();
218 builder.setInsertionPointToStart(&vectorizeBlock);
// For each operand position, add a block argument typed like the first
// collected value and map the first op's operand to that argument.
226 IRMapping argMapping;
227 for (
auto [vecOperand, origOpernad] :
228 llvm::zip(vectorOperands, ops[0]->getOperands())) {
229 auto arg = vectorizeBlock.addArgument(vecOperand[0].getType(),
230 origOpernad.getLoc());
231 argMapping.map(origOpernad, arg);
// Clone the first op into the body using the operand-to-argument mapping and
// return the clone's single result from the body.
234 auto *clonedOp = builder.clone(*ops[0], argMapping);
236 VectorizeReturnOp::create(builder, clonedOp->getResult(0));
// Redirect every use of each original op's result to the corresponding
// result of the vectorize op.
239 for (
auto [op, result] :
llvm::zip(ops, vectorizeOp->getResults())) {
240 op->getResult(0).replaceAllUsesWith(result);
// Pass driver: for every hw::HWModuleOp in the operation this pass runs on,
// walks all blocks and runs a Vectorizer on each block whose parent region
// does not report SSA dominance. A failed vectorization interrupts the walk
// and fails the pass. Afterwards the counters accumulated in `stat` are
// copied to numOfVectorizedOps, numOfSavedOps, biggestSeedVector and
// numOfVectorsCreated.
// NOTE(review): the declaration of `stat` and the closing braces are not
// visible in this excerpt.
247void FindInitialVectorsPass::runOnOperation() {
248 for (
auto moduleOp : getOperation().getOps<
hw::HWModuleOp>()) {
249 auto result = moduleOp.walk([&](Block *block) {
// Only blocks in regions where mayHaveSSADominance() is false are vectorized.
250 if (!mayHaveSSADominance(*block->getParent()))
251 if (failed(Vectorizer(block).vectorize(stat)))
252 return WalkResult::interrupt();
253 return WalkResult::advance();
255 if (result.wasInterrupted())
256 return signalPassFailure();
259 numOfVectorizedOps = stat.vecOps;
260 numOfSavedOps = stat.savedOps;
261 biggestSeedVector = stat.bigSeedVec;
262 numOfVectorsCreated = stat.vecCreated;
assert(baseType &&"element must be base type")
static Location getLoc(DefSlot slot)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::hash_code hash_value(const DenseSet< T > &set)
static bool isEqual(const Key &lhs, const Key &rhs)
static unsigned getHashValue(const Key &key)