22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/IRMapping.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/Types.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
35 #define DEBUG_TYPE "find-initial-vectors"
39 #define GEN_PASS_DEF_FINDINITIALVECTORS
40 #include "circt/Dialect/Arc/ArcPasses.h.inc"
44 using namespace circt;
46 using llvm::SmallMapVector;
/// Pass that merges structurally equivalent operations into `VectorizeOp`s,
/// one block at a time (see `Vectorizer`). Derives from the generated
/// `impl::FindInitialVectorsBase`.
49 struct FindInitialVectorsPass
50 :
public impl::FindInitialVectorsBase<FindInitialVectorsPass> {
/// Vectorizes the blocks of the `hw::HWModuleOp`s in the root operation.
51 void runOnOperation()
override;
/// Counters accumulated across all vectorized blocks. The fields used by the
/// pass are `vecOps`, `savedOps`, `bigSeedVec` and `vecCreated`; their
/// declarations are not shown here.
53 struct StatisticVars {
/// Ranks the operations of a block so that each operation is ranked strictly
/// above every user that `compute` follows from it. Operations sharing a rank
/// are therefore not linked by any followed use chain.
65 struct TopologicalOrder {
/// Rank of every operation ranked so far, in the order they were ranked.
67 SmallMapVector<Operation *, unsigned, 32> opRanks;
/// Ranks all not-yet-ranked operations of `block`; fails on a dependency cycle.
68 LogicalResult compute(Block *block);
/// Looks up the rank recorded for `op`; asserts that `op` has been ranked.
69 unsigned get(Operation *op)
const {
70 const auto *it = opRanks.find(op);
71 assert(it != opRanks.end() &&
"op has no rank");
/// Ranks every not-yet-ranked operation in `block` with an iterative
/// depth-first walk along def -> user edges. An operation's rank is raised to
/// one more than the rank of each user it follows, so it ends up above all of
/// them. Operations with no followed users keep their initial rank
/// (NOTE(review): the `rank` field's initializer is not shown here,
/// presumably 0). Emits a "dependency cycle" error and returns failure if a
/// followed use chain loops back on itself.
80 LogicalResult TopologicalOrder::compute(Block *block) {
81 LLVM_DEBUG(llvm::dbgs() <<
"- Computing topological order in block " << block
// Per-operation walk state: the next user still to visit (plus the `rank`
// accumulated so far, declared alongside it).
84 WorklistItem(Operation *op) : userIt(op->user_begin()) {}
85 Operation::user_iterator userIt;
// Insertion-ordered map, so `back()` is the most recently pushed operation,
// i.e. the top of the DFS stack.
88 SmallMapVector<Operation *, WorklistItem, 16> worklist;
// Start a walk from every operation that an earlier walk has not ranked yet.
89 for (
auto &op : *block) {
90 if (opRanks.contains(&op))
92 worklist.insert({&op, WorklistItem(&op)});
93 while (!worklist.empty()) {
94 auto &[op, item] = worklist.back();
// Do not follow the users of a StateOp with non-zero latency or of a
// MemoryWritePortOp. Jumping `userIt` to the end makes them leaves of the
// walk, so dependencies through them neither raise ranks nor count as
// cycles.
// NOTE(review): `writeOp` is unused; `isa<MemoryWritePortOp>(op)` would state
// the check directly.
95 if (
auto stateOp = dyn_cast<StateOp>(op)) {
96 if (stateOp.getLatency() > 0)
97 item.userIt = op->user_end();
98 }
else if (
auto writeOp = dyn_cast<MemoryWritePortOp>(op)) {
99 item.userIt = op->user_end();
// Every followed user has been accounted for: record this operation's rank.
101 if (item.userIt == op->user_end()) {
102 opRanks.insert({op, item.rank});
// The current user is already ranked: fold its rank into this operation's.
106 if (
auto *rankIt = opRanks.find(*item.userIt); rankIt != opRanks.end()) {
107 item.rank = std::max(item.rank, rankIt->second + 1);
// Otherwise descend into the user. If the user is already on the worklist it
// is still waiting on its own users, so reaching it again means the use chain
// loops. (Assumes ranked entries are popped from the worklist on lines not
// shown here; TODO confirm.)
111 if (!worklist.insert({*item.userIt, WorklistItem(*item.userIt)}).second)
112 return op->emitError(
"dependency cycle");
/// Grouping key for vectorization candidates: (topological rank, operation
/// name, operand types, result types, attribute dictionary). Operations with
/// equal keys are candidates for being merged into a single vector.
119 using Key = std::tuple<unsigned, StringRef, SmallVector<Type>,
120 SmallVector<Type>, DictionaryAttr>;
/// Builds the `Key` of `op`, given the topological `rank` computed for it.
122 Key computeKey(Operation *op,
unsigned rank) {
125 return std::make_tuple(
126 rank, op->getName().getStringRef(),
127 SmallVector<Type>(op->operand_type_begin(), op->operand_type_end()),
128 SmallVector<Type>(op->result_type_begin(), op->result_type_end()),
129 op->getAttrDictionary());
/// Creates a vectorizer for the operations of `block`.
/// NOTE(review): consider marking this single-argument constructor `explicit`.
133 Vectorizer(Block *block) : block(block) {}
/// Ranks the operations of `block` and buckets them by `computeKey`. Each
/// entry of `candidates` then lists, in ranking order, the operations that
/// could share one vector. Fails if ranking fails.
/// NOTE(review): the `block` parameter shadows the `block` member that the
/// constructor initializes.
134 LogicalResult collectSeeds(Block *block) {
135 if (failed(order.compute(block)))
138 for (
auto &[op, rank] : order.opRanks)
139 candidates[computeKey(op, rank)].push_back(op);
/// Replaces each eligible candidate group with one VectorizeOp and updates
/// `stat` accordingly.
144 LogicalResult vectorize(FindInitialVectorsPass::StatisticVars &stat);
/// Operations bucketed by key. `SmallMapVector` iterates in insertion order,
/// which keeps the rewrite deterministic.
146 SmallMapVector<Key, SmallVector<Operation *>, 16> candidates;
/// Topological ranks of the block's operations.
147 TopologicalOrder order;
// NOTE(review): these members presumably belong to an
// `llvm::DenseMapInfo<Key>` specialization whose opening lines are not shown.
// DenseMap-based containers such as `SmallMapVector<Key, ...>` need reserved
// empty and tombstone keys that never collide with real keys.
/// Empty key (presumably `getEmptyKey`): rank 0 with an empty operation name.
156 return Key(0, StringRef(), SmallVector<Type>(), SmallVector<Type>(),
/// Tombstone key: rank 1 with a dedicated placeholder operation name (its
/// initializer is not shown here).
161 static StringRef tombStoneKeyOpName =
163 return Key(1, tombStoneKeyOpName, SmallVector<Type>(), SmallVector<Type>(),
/// Two keys are equal when all five tuple components compare equal.
173 static bool isEqual(
const Key &lhs,
const Key &rhs) {
return lhs == rhs; }
/// Groups the operations of `block` that share a `Key` (same rank, op name,
/// operand types, result types and attributes). Each eligible group is
/// replaced by a single VectorizeOp whose body is a clone of the group's first
/// operation, and `stat` is updated for every vectorized group. Fails if the
/// block cannot be topologically ranked.
180 Vectorizer::vectorize(FindInitialVectorsPass::StatisticVars &stat) {
// NOTE(review): unlike the message in `TopologicalOrder::compute`, this one has
// no space between "block" and the printed pointer.
181 LLVM_DEBUG(llvm::dbgs() <<
"- Vectorizing the ops in block" << block <<
"\n");
183 if (failed(collectSeeds(block)))
187 if (candidates.empty())
191 for (
const auto &[key, ops] : candidates) {
// Only groups of at least two ops qualify, where each op has exactly one
// result and at least one operand.
196 if (ops.size() == 1 || ops[0]->getNumResults() != 1 ||
197 ops[0]->getNumOperands() == 0)
// Every op in the group is vectorized; all but one are saved.
201 stat.vecOps += ops.size();
202 stat.savedOps += ops.size() - 1;
203 stat.bigSeedVec = std::max(ops.size(), stat.bigSeedVec);
// Transpose the operands: vectorOperands[i] collects operand i of every op
// in the group, in group order.
209 SmallVector<SmallVector<Value, 4>> vectorOperands;
210 vectorOperands.resize(ops[0]->getNumOperands());
212 for (
auto [into, operand] : llvm::zip(vectorOperands, op->getOperands()))
213 into.push_back(operand);
// `ValueRange` is a non-owning view; these ranges refer into
// `vectorOperands`, which must outlive their use below.
214 SmallVector<ValueRange> operandValueRanges;
215 operandValueRanges.assign(vectorOperands.begin(), vectorOperands.end());
// One result per grouped op. They all share the first op's result type
// because result types are part of the key.
217 SmallVector<Type> resultTypes(ops.size(), ops[0]->getResult(0).getType());
// Insert the new op immediately before the group's first op, reusing its
// location.
220 ImplicitLocOpBuilder builder(ops[0]->getLoc(), ops[0]);
222 builder.create<VectorizeOp>(resultTypes, operandValueRanges);
// Build the body block and insert into it from here on.
229 auto &vectorizeBlock = vectorizeOp.getBody().emplaceBlock();
230 builder.setInsertionPointToStart(&vectorizeBlock);
// Add one block argument per operand position, typed like that position's
// operand (identical across the group because operand types are part of the
// key). Map the first op's operand to it so the clone below reads the
// argument instead of the outer value.
// NOTE(review): `origOpernad` is a misspelling of `origOperand`.
238 IRMapping argMapping;
239 for (
auto [vecOperand, origOpernad] :
240 llvm::zip(vectorOperands, ops[0]->getOperands())) {
241 auto arg = vectorizeBlock.addArgument(vecOperand[0].getType(),
242 origOpernad.getLoc());
243 argMapping.map(origOpernad, arg);
// The body is a clone of the group's first op, returning its single result.
246 auto *clonedOp = builder.clone(*ops[0], argMapping);
248 builder.create<VectorizeReturnOp>(clonedOp->getResult(0));
// Redirect all uses of each grouped op's result to the matching
// VectorizeOp result.
251 for (
auto [op, result] : llvm::zip(ops, vectorizeOp->getResults())) {
252 op->getResult(0).replaceAllUsesWith(result);
/// Runs a `Vectorizer` over every block of each top-level `hw::HWModuleOp`.
/// Only blocks whose region may not have SSA dominance are processed
/// (presumably because the VectorizeOp is inserted before the group's first
/// op, which need not dominate the other grouped ops' operands).
/// Stops at the first failing block and signals pass failure; later modules
/// are skipped and the statistics are not updated. On success, copies the
/// totals from `stat` into the pass statistics (presumably declared in the
/// generated pass base).
259 void FindInitialVectorsPass::runOnOperation() {
260 for (
auto moduleOp : getOperation().getOps<hw::HWModuleOp>()) {
261 auto result = moduleOp.walk([&](Block *block) {
// Nested `if`s: blocks in regions that may have SSA dominance are skipped,
// not treated as failures.
262 if (!mayHaveSSADominance(*block->getParent()))
263 if (failed(Vectorizer(block).vectorize(stat)))
264 return WalkResult::interrupt();
265 return WalkResult::advance();
267 if (result.wasInterrupted())
268 return signalPassFailure();
// Publish the accumulated counters.
271 numOfVectorizedOps = stat.vecOps;
272 numOfSavedOps = stat.savedOps;
273 biggestSeedVector = stat.bigSeedVec;
274 numOfVectorsCreated = stat.vecCreated;
/// Pass factory body (presumably of `createFindInitialVectorsPass()`): returns
/// a fresh, heap-allocated instance of the pass.
278 return std::make_unique<FindInitialVectorsPass>();
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createFindInitialVectorsPass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
inline ::llvm::hash_code hash_value(const FieldRef &fieldRef)
Get a hash code for a FieldRef.
static bool isEqual(const Key &lhs, const Key &rhs)
static unsigned getHashValue(const Key &key)
static Key getTombstoneKey()