CIRCT  19.0.0git
FindInitialVectors.cpp
Go to the documentation of this file.
1 //===- FindInitialVectors.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass implements a simple SLP vectorizer for Arc. Starting from
// `arc.state` operations as the seeds of every new vector, it follows the
// dependency graph and assigns a rank to every operation in the module. It
// then groups isomorphic operations together and puts each group into a
// vector.
14 //
15 //===----------------------------------------------------------------------===//
16 
20 #include "circt/Dialect/HW/HWOps.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/IRMapping.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/Types.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include <algorithm>
34 
35 #define DEBUG_TYPE "find-initial-vectors"
36 
37 namespace circt {
38 namespace arc {
39 #define GEN_PASS_DEF_FINDINITIALVECTORS
40 #include "circt/Dialect/Arc/ArcPasses.h.inc"
41 } // namespace arc
42 } // namespace circt
43 
44 using namespace circt;
45 using namespace arc;
46 using llvm::SmallMapVector;
47 
48 namespace {
49 struct FindInitialVectorsPass
50  : public impl::FindInitialVectorsBase<FindInitialVectorsPass> {
51  void runOnOperation() override;
52 
53  struct StatisticVars {
54  size_t vecOps{0};
55  size_t savedOps{0};
56  size_t bigSeedVec{0};
57  size_t vecCreated{0};
58  };
59 
60  StatisticVars stat;
61 };
62 } // namespace
63 
64 namespace {
65 struct TopologicalOrder {
66  /// An integer rank assigned to each operation.
67  SmallMapVector<Operation *, unsigned, 32> opRanks;
68  LogicalResult compute(Block *block);
69  unsigned get(Operation *op) const {
70  const auto *it = opRanks.find(op);
71  assert(it != opRanks.end() && "op has no rank");
72  return it->second;
73  }
74 };
75 } // namespace
76 
77 /// Assign each operation in the given block a topological rank. Stateful
78 /// elements are assigned rank 0. All other operations receive the maximum rank
79 /// of their users, plus one.
80 LogicalResult TopologicalOrder::compute(Block *block) {
81  LLVM_DEBUG(llvm::dbgs() << "- Computing topological order in block " << block
82  << "\n");
83  struct WorklistItem {
84  WorklistItem(Operation *op) : userIt(op->user_begin()) {}
85  Operation::user_iterator userIt;
86  unsigned rank = 0;
87  };
88  SmallMapVector<Operation *, WorklistItem, 16> worklist;
89  for (auto &op : *block) {
90  if (opRanks.contains(&op))
91  continue;
92  worklist.insert({&op, WorklistItem(&op)});
93  while (!worklist.empty()) {
94  auto &[op, item] = worklist.back();
95  if (auto stateOp = dyn_cast<StateOp>(op)) {
96  if (stateOp.getLatency() > 0)
97  item.userIt = op->user_end();
98  } else if (auto writeOp = dyn_cast<MemoryWritePortOp>(op)) {
99  item.userIt = op->user_end();
100  }
101  if (item.userIt == op->user_end()) {
102  opRanks.insert({op, item.rank});
103  worklist.pop_back();
104  continue;
105  }
106  if (auto *rankIt = opRanks.find(*item.userIt); rankIt != opRanks.end()) {
107  item.rank = std::max(item.rank, rankIt->second + 1);
108  ++item.userIt;
109  continue;
110  }
111  if (!worklist.insert({*item.userIt, WorklistItem(*item.userIt)}).second)
112  return op->emitError("dependency cycle");
113  }
114  }
115  return success();
116 }
117 
118 namespace {
119 using Key = std::tuple<unsigned, StringRef, SmallVector<Type>,
120  SmallVector<Type>, DictionaryAttr>;
121 
122 Key computeKey(Operation *op, unsigned rank) {
123  // The key = concat(op_rank, op_name, op_operands_types, op_result_types,
124  // op_attrs)
125  return std::make_tuple(
126  rank, op->getName().getStringRef(),
127  SmallVector<Type>(op->operand_type_begin(), op->operand_type_end()),
128  SmallVector<Type>(op->result_type_begin(), op->result_type_end()),
129  op->getAttrDictionary());
130 }
131 
132 struct Vectorizer {
133  Vectorizer(Block *block) : block(block) {}
134  LogicalResult collectSeeds(Block *block) {
135  if (failed(order.compute(block)))
136  return failure();
137 
138  for (auto &[op, rank] : order.opRanks)
139  candidates[computeKey(op, rank)].push_back(op);
140 
141  return success();
142  }
143 
144  LogicalResult vectorize(FindInitialVectorsPass::StatisticVars &stat);
145  // Store Isomorphic ops together
146  SmallMapVector<Key, SmallVector<Operation *>, 16> candidates;
147  TopologicalOrder order;
148  Block *block;
149 };
150 } // namespace
151 
152 namespace llvm {
153 template <>
154 struct DenseMapInfo<Key> {
155  static inline Key getEmptyKey() {
156  return Key(0, StringRef(), SmallVector<Type>(), SmallVector<Type>(),
157  DictionaryAttr());
158  }
159 
160  static inline Key getTombstoneKey() {
161  static StringRef tombStoneKeyOpName =
162  DenseMapInfo<StringRef>::getTombstoneKey();
163  return Key(1, tombStoneKeyOpName, SmallVector<Type>(), SmallVector<Type>(),
164  DictionaryAttr());
165  }
166 
167  static unsigned getHashValue(const Key &key) {
168  return hash_value(std::get<0>(key)) ^ hash_value(std::get<1>(key)) ^
169  hash_value(std::get<2>(key)) ^ hash_value(std::get<3>(key)) ^
170  hash_value(std::get<4>(key));
171  }
172 
173  static bool isEqual(const Key &lhs, const Key &rhs) { return lhs == rhs; }
174 };
175 } // namespace llvm
176 
177 // When calling this function we assume that we have the candidate groups of
178 // isomorphic ops so we need to feed them to the `VectorizeOp`
179 LogicalResult
180 Vectorizer::vectorize(FindInitialVectorsPass::StatisticVars &stat) {
181  LLVM_DEBUG(llvm::dbgs() << "- Vectorizing the ops in block" << block << "\n");
182 
183  if (failed(collectSeeds(block)))
184  return failure();
185 
186  // Unachievable?! just in case!
187  if (candidates.empty())
188  return success();
189 
190  // Iterate over every group of isomorphic ops
191  for (const auto &[key, ops] : candidates) {
192  // If the group has only one scalar then it doesn't worth vectorizing,
193  // We skip also ops with more than one result as `arc.vectorize` supports
194  // only one result in its body region. Ignore zero-result and zero operands
195  // ops as well.
196  if (ops.size() == 1 || ops[0]->getNumResults() != 1 ||
197  ops[0]->getNumOperands() == 0)
198  continue;
199 
200  // Collect Statistics
201  stat.vecOps += ops.size();
202  stat.savedOps += ops.size() - 1;
203  stat.bigSeedVec = std::max(ops.size(), stat.bigSeedVec);
204  ++stat.vecCreated;
205 
206  // Here, we have a bunch of isomorphic ops, we need to extract the operands
207  // results and attributes of every op and store them in a vector
208  // Holds the operands
209  SmallVector<SmallVector<Value, 4>> vectorOperands;
210  vectorOperands.resize(ops[0]->getNumOperands());
211  for (auto *op : ops)
212  for (auto [into, operand] : llvm::zip(vectorOperands, op->getOperands()))
213  into.push_back(operand);
214  SmallVector<ValueRange> operandValueRanges;
215  operandValueRanges.assign(vectorOperands.begin(), vectorOperands.end());
216  // Holds the results
217  SmallVector<Type> resultTypes(ops.size(), ops[0]->getResult(0).getType());
218 
219  // Now construct the `VectorizeOp`
220  ImplicitLocOpBuilder builder(ops[0]->getLoc(), ops[0]);
221  auto vectorizeOp =
222  builder.create<VectorizeOp>(resultTypes, operandValueRanges);
223 
224  // Now we have the operands, results and attributes, now we need to get
225  // the blocks.
226 
227  // There was no blocks so we need to create one and set the insertion point
228  // at the first of this region
229  auto &vectorizeBlock = vectorizeOp.getBody().emplaceBlock();
230  builder.setInsertionPointToStart(&vectorizeBlock);
231 
232  // Add the block arguments
233  // comb.and %x, %y
234  // comb.and %u, %v
235  // at this point the operands vector will be {{x, u}, {y, v}}
236  // we need to create an th block args, so we need the type and the location
237  // the type is a vector type
238  IRMapping argMapping;
239  for (auto [vecOperand, origOpernad] :
240  llvm::zip(vectorOperands, ops[0]->getOperands())) {
241  auto arg = vectorizeBlock.addArgument(vecOperand[0].getType(),
242  origOpernad.getLoc());
243  argMapping.map(origOpernad, arg);
244  }
245 
246  auto *clonedOp = builder.clone(*ops[0], argMapping);
247  // `VectorizeReturnOp`
248  builder.create<VectorizeReturnOp>(clonedOp->getResult(0));
249 
250  // Now replace the original ops with the vectorized ops
251  for (auto [op, result] : llvm::zip(ops, vectorizeOp->getResults())) {
252  op->getResult(0).replaceAllUsesWith(result);
253  op->erase();
254  }
255  }
256  return success();
257 }
258 
259 void FindInitialVectorsPass::runOnOperation() {
260  for (auto moduleOp : getOperation().getOps<hw::HWModuleOp>()) {
261  auto result = moduleOp.walk([&](Block *block) {
262  if (!mayHaveSSADominance(*block->getParent()))
263  if (failed(Vectorizer(block).vectorize(stat)))
264  return WalkResult::interrupt();
265  return WalkResult::advance();
266  });
267  if (result.wasInterrupted())
268  return signalPassFailure();
269  }
270 
271  numOfVectorizedOps = stat.vecOps;
272  numOfSavedOps = stat.savedOps;
273  biggestSeedVector = stat.bigSeedVec;
274  numOfVectorsCreated = stat.vecCreated;
275 }
276 
277 std::unique_ptr<Pass> arc::createFindInitialVectorsPass() {
278  return std::make_unique<FindInitialVectorsPass>();
279 }
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createFindInitialVectorsPass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
inline ::llvm::hash_code hash_value(const FieldRef &fieldRef)
Get a hash code for a FieldRef.
Definition: FieldRef.h:92
static bool isEqual(const Key &lhs, const Key &rhs)
static unsigned getHashValue(const Key &key)