CIRCT  20.0.0git
InferStateProperties.cpp
Go to the documentation of this file.
1 //===- InferStateProperties.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Support/LLVM.h"
#include "mlir/Pass/Pass.h"

#include <functional>
#include <utility>
15 
16 #define DEBUG_TYPE "arc-infer-state-properties"
17 
18 namespace circt {
19 namespace arc {
20 #define GEN_PASS_DEF_INFERSTATEPROPERTIES
21 #include "circt/Dialect/Arc/ArcPasses.h.inc"
22 } // namespace arc
23 } // namespace circt
24 
25 using namespace circt;
26 using namespace arc;
27 
28 //===----------------------------------------------------------------------===//
29 // Helpers
30 //===----------------------------------------------------------------------===//
31 
32 static bool isConstZero(Value value) {
33  if (auto constOp = value.getDefiningOp<hw::ConstantOp>())
34  return constOp.getValue().isZero();
35 
36  return false;
37 }
38 
39 static bool isConstTrue(Value value) {
40  if (auto constOp = value.getDefiningOp<hw::ConstantOp>()) {
41  return constOp.getValue().getBitWidth() == 1 &&
42  constOp.getValue().isAllOnes();
43  }
44  return false;
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Reset and Enable property storages
49 //===----------------------------------------------------------------------===//
50 
51 namespace {
52 /// Contains all the information needed to pass a detected reset to the rewriter
53 /// function.
54 struct ResetInfo {
55  ResetInfo() = default;
56  ResetInfo(std::function<Value(OpBuilder &)> &&constructInput,
57  BlockArgument condition, bool isZeroReset)
58  : constructInput(constructInput), condition(condition),
59  isZeroReset(isZeroReset) {}
60 
61  ResetInfo(Value input, BlockArgument condition, bool isZeroReset)
62  : ResetInfo([=](OpBuilder &) { return input; }, condition, isZeroReset) {}
63 
64  std::function<Value(OpBuilder &)> constructInput;
65  BlockArgument condition;
66  bool isZeroReset;
67 
68  operator bool() { return constructInput && condition; }
69 };
70 
71 /// Contains all the information needed to pass a detected enable to the
72 /// rewriter function.
73 struct EnableInfo {
74  EnableInfo() = default;
75  EnableInfo(std::function<Value(OpBuilder &)> &&constructInput,
76  BlockArgument condition, BlockArgument selfArg, bool isDisable)
77  : constructInput(constructInput), condition(condition), selfArg(selfArg),
78  isDisable(isDisable) {}
79 
80  EnableInfo(Value input, BlockArgument condition, BlockArgument selfArg,
81  bool isDisable)
82  : EnableInfo([=](OpBuilder &) { return input; }, condition, selfArg,
83  isDisable) {}
84 
85  std::function<Value(OpBuilder &)> constructInput;
86  BlockArgument condition;
87  BlockArgument selfArg;
88  bool isDisable;
89 
90  operator bool() { return constructInput && condition && selfArg; }
91 };
92 } // namespace
93 
94 //===----------------------------------------------------------------------===//
95 // Rewriter functions
96 //===----------------------------------------------------------------------===//
97 
98 /// Take an arc and a detected reset per output value and apply it to the arc if
99 /// applicable (but does not change the state ops referring to the arc).
100 static LogicalResult applyResetTransformation(arc::DefineOp arcOp,
101  ArrayRef<ResetInfo> resetInfos) {
102  auto outputOp = cast<arc::OutputOp>(arcOp.getBodyBlock().getTerminator());
103 
104  assert(outputOp.getOutputs().size() == resetInfos.size() &&
105  "required to pass the same amount of resets as outputs of the arc");
106 
107  for (auto info : resetInfos) {
108  if (!info)
109  return failure();
110 
111  // We can only pull out the reset to the whole arc when all the output
112  // values have the same reset applied to them.
113  // TODO: split the arcs such that there is one for each reset kind, however,
114  // that requires a cost-model to not blow up binary-size too much
115  if (!resetInfos.empty() &&
116  (info.condition != resetInfos.back().condition ||
117  info.isZeroReset != resetInfos.back().isZeroReset))
118  return failure();
119 
120  // TODO: arc.state operation only supports resets to zero at the moment.
121  if (!info.isZeroReset)
122  return failure();
123  }
124 
125  if (resetInfos.empty())
126  return failure();
127 
128  OpBuilder builder(outputOp);
129 
130  for (size_t i = 0, e = outputOp.getOutputs().size(); i < e; ++i) {
131  auto *defOp = outputOp.getOperands()[i].getDefiningOp();
132  outputOp.getOperands()[i].replaceUsesWithIf(
133  resetInfos[i].constructInput(builder),
134  [](OpOperand &op) { return isa<arc::OutputOp>(op.getOwner()); });
135 
136  if (defOp && defOp->getResult(0).use_empty())
137  defOp->erase();
138  }
139 
140  return success();
141 }
142 
143 /// Transform the given state operation to match the changes done to the arc in
144 /// 'applyResetTransformation' without any additional checks.
145 static void setResetOperandOfStateOp(arc::StateOp stateOp,
146  unsigned resetConditionIndex) {
147  Value resetCond = stateOp.getInputs()[resetConditionIndex];
148  ImplicitLocOpBuilder builder(stateOp.getLoc(), stateOp);
149 
150  if (stateOp.getEnable())
151  resetCond = builder.create<comb::AndOp>(stateOp.getEnable(), resetCond);
152 
153  if (stateOp.getReset())
154  resetCond = builder.create<comb::OrOp>(stateOp.getReset(), resetCond);
155 
156  stateOp.getResetMutable().assign(resetCond);
157 }
158 
159 /// Take an arc and a detected enable per output value and apply it to the given
160 /// state if applicable (no changes required to the arc::DefineOp operation for
161 /// enables).
162 static LogicalResult
163 applyEnableTransformation(arc::DefineOp arcOp, arc::StateOp stateOp,
164  ArrayRef<EnableInfo> enableInfos) {
165  auto outputOp = cast<arc::OutputOp>(arcOp.getBodyBlock().getTerminator());
166 
167  assert(outputOp.getOutputs().size() == enableInfos.size() &&
168  "required to pass the same amount of enables as outputs of the arc");
169 
170  for (auto info : enableInfos) {
171  if (!info)
172  return failure();
173 
174  // We can only pull out the enable to the whole arc when all the output
175  // values have the same enable applied to them.
176  // TODO: split the arcs such that there is one for each enable kind,
177  // however, this requires a cost-model to not blow up binary-size too much.
178  if (!enableInfos.empty() &&
179  (info.condition != enableInfos.back().condition ||
180  info.isDisable != enableInfos.back().isDisable))
181  return failure();
182  }
183 
184  if (enableInfos.empty())
185  return failure();
186 
187  if (!enableInfos[0].condition.hasOneUse())
188  return failure();
189 
190  ImplicitLocOpBuilder builder(stateOp.getLoc(), stateOp);
191  SmallVector<Value> inputs(stateOp.getInputs());
192 
193  Value enableCond =
194  stateOp.getInputs()[enableInfos[0].condition.getArgNumber()];
195  Value one = builder.create<hw::ConstantOp>(builder.getI1Type(), -1);
196  if (enableInfos[0].isDisable) {
197  inputs[enableInfos[0].condition.getArgNumber()] =
198  builder.create<hw::ConstantOp>(builder.getI1Type(), 0);
199  enableCond = builder.create<comb::XorOp>(enableCond, one);
200  } else {
201  inputs[enableInfos[0].condition.getArgNumber()] = one;
202  }
203 
204  if (stateOp.getEnable())
205  enableCond = builder.create<comb::AndOp>(stateOp.getEnable(), enableCond);
206 
207  stateOp.getEnableMutable().assign(enableCond);
208 
209  for (size_t i = 0, e = outputOp.getOutputs().size(); i < e; ++i) {
210  if (enableInfos[i].selfArg.hasOneUse())
211  inputs[enableInfos[i].selfArg.getArgNumber()] =
212  builder.create<hw::ConstantOp>(stateOp.getLoc(),
213  enableInfos[i].selfArg.getType(), 0);
214  }
215 
216  stateOp.getInputsMutable().assign(inputs);
217  return success();
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // Pattern detectors
222 //===----------------------------------------------------------------------===//
223 
224 //===----------------------------------------------------------------------===//
225 // Reset Patterns
226 
227 /// A reset represented with a single mux operation.
228 /// out = mux(resetCondition, 0, arcArgument)
229 /// ==>
230 /// return arcArgument directly and add resetCondition to the StateOp
231 static ResetInfo getIfMuxBasedReset(OpOperand &output) {
232  assert(isa<arc::OutputOp>(output.getOwner()) &&
233  "value has to be returned by the arc");
234 
235  if (auto mux = output.get().getDefiningOp<comb::MuxOp>()) {
236  if (!isConstZero(mux.getTrueValue()))
237  return {};
238 
239  if (!mux.getResult().hasOneUse())
240  return {};
241 
242  if (auto condArg = dyn_cast<BlockArgument>(mux.getCond()))
243  return ResetInfo(mux.getFalseValue(), condArg, true);
244  }
245 
246  return {};
247 }
248 
249 /// A reset represented by an AND and XOR operation for i1 values only.
250 /// out = and(X); X being a list containing all of
251 /// {xor(resetCond, true), arcArgument}
252 /// ==>
253 /// out = and(X\xor(resetCond, true)) + add resetCond to StateOp
254 static ResetInfo getIfAndBasedReset(OpOperand &output) {
255  assert(isa<arc::OutputOp>(output.getOwner()) &&
256  "value has to be returned by the arc");
257 
258  if (auto andOp = output.get().getDefiningOp<comb::AndOp>()) {
259  if (!andOp.getResult().getType().isInteger(1))
260  return {};
261 
262  if (!andOp.getResult().hasOneUse())
263  return {};
264 
265  for (auto &operand : andOp->getOpOperands()) {
266  if (auto xorOp = operand.get().getDefiningOp<comb::XorOp>();
267  xorOp && xorOp->getNumOperands() == 2 &&
268  xorOp.getResult().hasOneUse()) {
269  if (auto condArg = dyn_cast<BlockArgument>(xorOp.getInputs()[0])) {
270  if (xorOp.getInputs().size() != 2 ||
271  !isConstTrue(xorOp.getInputs()[1]))
272  continue;
273 
274  const unsigned condOutputNumber = operand.getOperandNumber();
275  auto inputConstructor = [=](OpBuilder &builder) -> Value {
276  if (andOp->getNumOperands() > 2) {
277  builder.setInsertionPoint(andOp);
278  auto copy = cast<comb::AndOp>(builder.clone(*andOp));
279  copy.getInputsMutable().erase(condOutputNumber);
280  return copy->getResult(0);
281  }
282 
283  return andOp->getOperand(!condOutputNumber);
284  };
285 
286  return ResetInfo(inputConstructor, condArg, true);
287  }
288  }
289  }
290  }
291 
292  return {};
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // Enable Patterns
297 
298 /// Just a helper function for the following two patterns.
299 static EnableInfo checkOperandsForEnable(arc::StateOp stateOp, Value selfArg,
300  Value cond, unsigned outputNr,
301  bool isDisable) {
302  if (auto trueArg = dyn_cast<BlockArgument>(selfArg)) {
303  if (stateOp.getInputs()[trueArg.getArgNumber()] !=
304  stateOp.getResult(outputNr))
305  return {};
306 
307  if (auto condArg = dyn_cast<BlockArgument>(cond))
308  return EnableInfo(selfArg, condArg, trueArg, isDisable);
309  }
310 
311  return {};
312 }
313 
314 /// An enable represented by a single mux operation.
315 /// out = mux(enableCond, x, arcArgument) where x is the 'out' of the last cycle
316 /// ==>
317 /// out = arcArgument + set enableCond as enable operand to the StateOp
318 static EnableInfo getIfMuxBasedEnable(OpOperand &output, StateOp stateOp) {
319  assert(isa<arc::OutputOp>(output.getOwner()) &&
320  "value has to be returned by the arc");
321 
322  if (auto mux = output.get().getDefiningOp<comb::MuxOp>()) {
323  if (!mux.getResult().hasOneUse())
324  return {};
325 
326  return checkOperandsForEnable(stateOp, mux.getFalseValue(), mux.getCond(),
327  output.getOperandNumber(), false);
328  }
329 
330  return {};
331 }
332 
333 /// A negated enable represented by a single mux operation.
334 /// out = mux(enableCond, arcArgument, x) where x is the 'out' of the last cycle
335 /// ==>
336 /// out = arcArgument + set xor(enableCond, true) as enable operand to the
337 /// StateOp
338 static EnableInfo getIfMuxBasedDisable(OpOperand &output, StateOp stateOp) {
339  assert(isa<arc::OutputOp>(output.getOwner()) &&
340  "value has to be returned by the arc");
341 
342  if (auto mux = output.get().getDefiningOp<comb::MuxOp>()) {
343  if (!mux.getResult().hasOneUse())
344  return {};
345 
346  return checkOperandsForEnable(stateOp, mux.getTrueValue(), mux.getCond(),
347  output.getOperandNumber(), true);
348  }
349 
350  return {};
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // Combine all the patterns
355 //===----------------------------------------------------------------------===//
356 
357 /// Combine all the reset patterns to one.
358 ResetInfo computeResetInfoFromPattern(OpOperand &output) {
359  auto resetInfo = getIfMuxBasedReset(output);
360 
361  if (!resetInfo)
362  resetInfo = getIfAndBasedReset(output);
363 
364  return resetInfo;
365 }
366 
367 /// Combine all the enable patterns to one.
368 EnableInfo computeEnableInfoFromPattern(OpOperand &output, StateOp stateOp) {
369  auto enableInfo = getIfMuxBasedEnable(output, stateOp);
370 
371  if (!enableInfo)
372  enableInfo = getIfMuxBasedDisable(output, stateOp);
373 
374  return enableInfo;
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // DetectResets pass
379 //===----------------------------------------------------------------------===//
380 
381 namespace {
382 struct InferStatePropertiesPass
383  : public impl::InferStatePropertiesBase<InferStatePropertiesPass> {
384  using InferStatePropertiesBase::InferStatePropertiesBase;
385 
386  void runOnOperation() override;
387  void runOnStateOp(arc::StateOp stateOp, arc::DefineOp arc,
388  DenseMap<arc::DefineOp, unsigned> &resetConditionMap);
389 };
390 } // namespace
391 
392 void InferStatePropertiesPass::runOnOperation() {
393  SymbolTableCollection symbolTable;
394 
395  DenseMap<arc::DefineOp, unsigned> resetConditionMap;
396  getOperation()->walk([&](arc::StateOp stateOp) {
397  auto arc =
398  cast<arc::DefineOp>(cast<mlir::CallOpInterface>(stateOp.getOperation())
399  .resolveCallableInTable(&symbolTable));
400  runOnStateOp(stateOp, arc, resetConditionMap);
401  });
402 }
403 
404 void InferStatePropertiesPass::runOnStateOp(
405  arc::StateOp stateOp, arc::DefineOp arc,
406  DenseMap<arc::DefineOp, unsigned> &resetConditionMap) {
407 
408  auto outputOp = cast<arc::OutputOp>(arc.getBodyBlock().getTerminator());
409  static constexpr unsigned visitedNoChange = -1;
410 
411  if (detectResets) {
412  // Check for reset patterns, we only have to do this once per arc::DefineOp
413  // and store the result for later arc::StateOps referring to the same arc.
414  if (!resetConditionMap.count(arc)) {
415  SmallVector<ResetInfo> resetInfos;
416  int numResets = 0;
417  for (auto &output : outputOp->getOpOperands()) {
418  auto resetInfo = computeResetInfoFromPattern(output);
419  resetInfos.push_back(resetInfo);
420  if (resetInfo)
421  ++numResets;
422  }
423 
424  // Rewrite the arc::DefineOp if valid
425  auto result = applyResetTransformation(arc, resetInfos);
426  if ((succeeded(result) && resetInfos[0]))
427  resetConditionMap[arc] = resetInfos[0].condition.getArgNumber();
428  else
429  resetConditionMap[arc] = visitedNoChange;
430 
431  if (failed(result))
432  missedResets += numResets;
433  }
434 
435  // Apply resets to the state operation.
436  if (resetConditionMap.count(arc) &&
437  resetConditionMap[arc] != visitedNoChange) {
438  setResetOperandOfStateOp(stateOp, resetConditionMap[arc]);
439  ++addedResets;
440  }
441  }
442 
443  if (detectEnables) {
444  // Check for enable patterns.
445  SmallVector<EnableInfo> enableInfos;
446  int numEnables = 0;
447  for (OpOperand &output : outputOp->getOpOperands()) {
448  auto enableInfo = computeEnableInfoFromPattern(output, stateOp);
449  enableInfos.push_back(enableInfo);
450  if (enableInfo)
451  ++numEnables;
452  }
453 
454  // Apply enable patterns.
455  if (!failed(applyEnableTransformation(arc, stateOp, enableInfos)))
456  ++addedEnables;
457  else
458  missedEnables += numEnables;
459  }
460 }
assert(baseType &&"element must be base type")
static bool isConstTrue(Value value)
static EnableInfo checkOperandsForEnable(arc::StateOp stateOp, Value selfArg, Value cond, unsigned outputNr, bool isDisable)
Just a helper function for the following two patterns.
static EnableInfo getIfMuxBasedDisable(OpOperand &output, StateOp stateOp)
A negated enable represented by a single mux operation.
static ResetInfo getIfAndBasedReset(OpOperand &output)
A reset represented by an AND and XOR operation for i1 values only.
static LogicalResult applyEnableTransformation(arc::DefineOp arcOp, arc::StateOp stateOp, ArrayRef< EnableInfo > enableInfos)
Take an arc and a detected enable per output value and apply it to the given state if applicable (no ...
static bool isConstZero(Value value)
static LogicalResult applyResetTransformation(arc::DefineOp arcOp, ArrayRef< ResetInfo > resetInfos)
Take an arc and a detected reset per output value and apply it to the arc if applicable (but does not...
EnableInfo computeEnableInfoFromPattern(OpOperand &output, StateOp stateOp)
Combine all the enable patterns to one.
static void setResetOperandOfStateOp(arc::StateOp stateOp, unsigned resetConditionIndex)
Transform the given state operation to match the changes done to the arc in 'applyResetTransformation...
ResetInfo computeResetInfoFromPattern(OpOperand &output)
Combine all the reset patterns to one.
static EnableInfo getIfMuxBasedEnable(OpOperand &output, StateOp stateOp)
An enable represented by a single mux operation.
static ResetInfo getIfMuxBasedReset(OpOperand &output)
A reset represented with a single mux operation.
def create(data_type, value)
Definition: hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21