CIRCT  20.0.0git
HWLegalizeModules.cpp
Go to the documentation of this file.
1 //===- HWLegalizeModulesPass.cpp - Lower unsupported IR features away -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass lowers away features in the SV/Comb/HW dialects that are
10 // unsupported by some tools (e.g. multidimensional arrays) as specified by
11 // LoweringOptions. This pass is run relatively late in the pipeline in
12 // preparation for emission. Any passes run after this (e.g. PrettifyVerilog)
13 // must be aware they cannot introduce new invalid constructs.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "circt/Dialect/HW/HWOps.h"
19 #include "circt/Dialect/SV/SVOps.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/Pass/Pass.h"
24 
25 namespace circt {
26 namespace sv {
27 #define GEN_PASS_DEF_HWLEGALIZEMODULES
28 #include "circt/Dialect/SV/SVPasses.h.inc"
29 } // namespace sv
30 } // namespace circt
31 
32 using namespace circt;
33 //===----------------------------------------------------------------------===//
34 // HWLegalizeModulesPass
35 //===----------------------------------------------------------------------===//
36 
37 namespace {
38 struct HWLegalizeModulesPass
39  : public circt::sv::impl::HWLegalizeModulesBase<HWLegalizeModulesPass> {
40  void runOnOperation() override;
41 
42 private:
43  void processPostOrder(Block &block);
44  bool tryLoweringPackedArrayOp(Operation &op);
45  Value lowerLookupToCasez(Operation &op, Value input, Value index,
46  mlir::Type elementType,
47  SmallVector<Value> caseValues);
48  bool processUsers(Operation &op, Value value, ArrayRef<Value> mapping);
49  std::optional<std::pair<uint64_t, unsigned>>
50  tryExtractIndexAndBitWidth(Value value);
51 
52  /// This is the current hw.module being processed.
53  hw::HWModuleOp thisHWModule;
54 
55  bool anythingChanged;
56 
57  /// This tells us what language features we're allowed to use in generated
58  /// Verilog.
59  LoweringOptions options;
60 
61  /// This pass will be run on multiple hw.modules, this keeps track of the
62  /// contents of LoweringOptions so we don't have to reparse the
63  /// LoweringOptions for every hw.module.
64  StringAttr lastParsedOptions;
65 };
66 } // end anonymous namespace
67 
68 bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
69  return TypeSwitch<Operation *, bool>(&op)
70  .Case<hw::AggregateConstantOp>([&](hw::AggregateConstantOp constOp) {
71  // Replace individual element uses (if any) with input fields.
72  SmallVector<Value> inputs;
73  OpBuilder builder(constOp);
74  for (auto field : llvm::reverse(constOp.getFields())) {
75  if (auto intAttr = dyn_cast<IntegerAttr>(field))
76  inputs.push_back(
77  builder.create<hw::ConstantOp>(constOp.getLoc(), intAttr));
78  else
79  inputs.push_back(builder.create<hw::AggregateConstantOp>(
80  constOp.getLoc(), constOp.getType(), cast<ArrayAttr>(field)));
81  }
82  if (!processUsers(op, constOp.getResult(), inputs))
83  return false;
84 
85  // Remove original op.
86  return true;
87  })
88  .Case<hw::ArrayConcatOp>([&](hw::ArrayConcatOp concatOp) {
89  // Redirect individual element uses (if any) to the input arguments.
90  SmallVector<std::pair<Value, uint64_t>> arrays;
91  for (auto array : llvm::reverse(concatOp.getInputs())) {
92  auto ty = hw::type_cast<hw::ArrayType>(array.getType());
93  arrays.emplace_back(array, ty.getNumElements());
94  }
95  for (auto *user :
96  llvm::make_early_inc_range(concatOp.getResult().getUsers())) {
97  if (TypeSwitch<Operation *, bool>(user)
98  .Case<hw::ArrayGetOp>([&](hw::ArrayGetOp getOp) {
99  if (auto indexAndBitWidth =
100  tryExtractIndexAndBitWidth(getOp.getIndex())) {
101  auto [indexValue, bitWidth] = *indexAndBitWidth;
102  // FIXME: More efficient search
103  for (const auto &[array, size] : arrays) {
104  if (indexValue >= size) {
105  indexValue -= size;
106  continue;
107  }
108  OpBuilder builder(getOp);
109  getOp.getInputMutable().set(array);
110  getOp.getIndexMutable().set(
111  builder.createOrFold<hw::ConstantOp>(
112  getOp.getLoc(), APInt(bitWidth, indexValue)));
113  return true;
114  }
115  }
116 
117  return false;
118  })
119  .Default([](auto op) { return false; }))
120  continue;
121 
122  op.emitError("unsupported packed array expression");
123  signalPassFailure();
124  }
125 
126  // Remove the original op.
127  return true;
128  })
129  .Case<hw::ArrayCreateOp>([&](hw::ArrayCreateOp createOp) {
130  // Replace individual element uses (if any) with input arguments.
131  SmallVector<Value> inputs(llvm::reverse(createOp.getInputs()));
132  if (!processUsers(op, createOp.getResult(), inputs))
133  return false;
134 
135  // Remove original op.
136  return true;
137  })
138  .Case<hw::ArrayGetOp>([&](hw::ArrayGetOp getOp) {
139  // Skip index ops with constant index.
140  auto index = getOp.getIndex();
141  if (auto *definingOp = index.getDefiningOp())
142  if (isa<hw::ConstantOp>(definingOp))
143  return false;
144 
145  // Generate case value element lookups.
146  auto ty = hw::type_cast<hw::ArrayType>(getOp.getInput().getType());
147  OpBuilder builder(getOp);
148  SmallVector<Value> caseValues;
149  for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
150  auto loc = op.getLoc();
151  auto index = builder.createOrFold<hw::ConstantOp>(
152  loc, APInt(llvm::Log2_64_Ceil(e), i));
153  auto element =
154  builder.create<hw::ArrayGetOp>(loc, getOp.getInput(), index);
155  caseValues.push_back(element);
156  }
157 
158  // Transform array index op into casez statement.
159  auto theWire = lowerLookupToCasez(op, getOp.getInput(), index,
160  ty.getElementType(), caseValues);
161 
162  // Emit the read from the wire, replace uses and clean up.
163  builder.setInsertionPoint(getOp);
164  auto readWire =
165  builder.create<sv::ReadInOutOp>(getOp.getLoc(), theWire);
166  getOp.getResult().replaceAllUsesWith(readWire);
167  return true;
168  })
169  .Case<sv::ArrayIndexInOutOp>([&](sv::ArrayIndexInOutOp indexOp) {
170  // Skip index ops with constant index.
171  auto index = indexOp.getIndex();
172  if (auto *definingOp = index.getDefiningOp())
173  if (isa<hw::ConstantOp>(definingOp))
174  return false;
175 
176  // Skip index ops with unpacked arrays.
177  auto inout = indexOp.getInput().getType();
178  if (hw::type_isa<hw::UnpackedArrayType>(inout.getElementType()))
179  return false;
180 
181  // Generate case value element lookups.
182  auto ty = hw::type_cast<hw::ArrayType>(inout.getElementType());
183  OpBuilder builder(&op);
184  SmallVector<Value> caseValues;
185  for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
186  auto loc = op.getLoc();
187  auto index = builder.createOrFold<hw::ConstantOp>(
188  loc, APInt(llvm::Log2_64_Ceil(e), i));
189  auto element = builder.create<sv::ArrayIndexInOutOp>(
190  loc, indexOp.getInput(), index);
191  auto readElement = builder.create<sv::ReadInOutOp>(loc, element);
192  caseValues.push_back(readElement);
193  }
194 
195  // Transform array index op into casez statement.
196  auto theWire = lowerLookupToCasez(op, indexOp.getInput(), index,
197  ty.getElementType(), caseValues);
198 
199  // Replace uses and clean up.
200  indexOp.getResult().replaceAllUsesWith(theWire);
201  return true;
202  })
203  .Case<sv::PAssignOp>([&](sv::PAssignOp assignOp) {
204  // Transform array assignment into individual assignments for each array
205  // element.
206  auto inout = assignOp.getDest().getType();
207  auto ty = hw::type_dyn_cast<hw::ArrayType>(inout.getElementType());
208  if (!ty)
209  return false;
210 
211  OpBuilder builder(assignOp);
212  for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
213  auto loc = op.getLoc();
214  auto index = builder.createOrFold<hw::ConstantOp>(
215  loc, APInt(llvm::Log2_64_Ceil(e), i));
216  auto dstElement = builder.create<sv::ArrayIndexInOutOp>(
217  loc, assignOp.getDest(), index);
218  auto srcElement =
219  builder.create<hw::ArrayGetOp>(loc, assignOp.getSrc(), index);
220  builder.create<sv::PAssignOp>(loc, dstElement, srcElement);
221  }
222 
223  // Remove original assignment.
224  return true;
225  })
226  .Case<sv::RegOp>([&](sv::RegOp regOp) {
227  // Transform array reg into individual regs for each array element.
228  auto ty = hw::type_dyn_cast<hw::ArrayType>(regOp.getElementType());
229  if (!ty)
230  return false;
231 
232  OpBuilder builder(regOp);
233  auto name = StringAttr::get(regOp.getContext(), "name");
234  SmallVector<Value> elements;
235  for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
236  auto loc = op.getLoc();
237  auto element = builder.create<sv::RegOp>(loc, ty.getElementType());
238  if (auto nameAttr = regOp->getAttrOfType<StringAttr>(name)) {
239  element.setNameAttr(
240  StringAttr::get(regOp.getContext(), nameAttr.getValue()));
241  }
242  elements.push_back(element);
243  }
244 
245  // Fix users to refer to individual element regs.
246  if (!processUsers(op, regOp.getResult(), elements))
247  return false;
248 
249  // Remove original reg.
250  return true;
251  })
252  .Default([&](auto op) { return false; });
253 }
254 
255 Value HWLegalizeModulesPass::lowerLookupToCasez(Operation &op, Value input,
256  Value index,
257  mlir::Type elementType,
258  SmallVector<Value> caseValues) {
259  // Create the wire for the result of the casez in the
260  // hw.module.
261  OpBuilder builder(&op);
262  auto theWire = builder.create<sv::RegOp>(op.getLoc(), elementType,
263  builder.getStringAttr("casez_tmp"));
264  builder.setInsertionPoint(&op);
265 
266  auto loc = input.getLoc();
267  // A casez is a procedural operation, so if we're in a
268  // non-procedural region we need to inject an always_comb
269  // block.
270  if (!op.getParentOp()->hasTrait<sv::ProceduralRegion>()) {
271  auto alwaysComb = builder.create<sv::AlwaysCombOp>(loc);
272  builder.setInsertionPointToEnd(alwaysComb.getBodyBlock());
273  }
274 
275  // If we are missing elements in the array (it is non-power of
276  // two), then add a default 'X' value.
277  if (1ULL << index.getType().getIntOrFloatBitWidth() != caseValues.size()) {
278  caseValues.push_back(builder.create<sv::ConstantXOp>(
279  op.getLoc(), op.getResult(0).getType()));
280  }
281 
282  APInt caseValue(index.getType().getIntOrFloatBitWidth(), 0);
283  auto *context = builder.getContext();
284 
285  // Create the casez itself.
286  builder.create<sv::CaseOp>(
287  loc, CaseStmtType::CaseZStmt, index, caseValues.size(),
288  [&](size_t caseIdx) -> std::unique_ptr<sv::CasePattern> {
289  // Use a default pattern for the last value, even if we
290  // are complete. This avoids tools thinking they need to
291  // insert a latch due to potentially incomplete case
292  // coverage.
293  bool isDefault = caseIdx == caseValues.size() - 1;
294  Value theValue = caseValues[caseIdx];
295  std::unique_ptr<sv::CasePattern> thePattern;
296 
297  if (isDefault)
298  thePattern = std::make_unique<sv::CaseDefaultPattern>(context);
299  else
300  thePattern = std::make_unique<sv::CaseBitPattern>(caseValue, context);
301  ++caseValue;
302  builder.create<sv::BPAssignOp>(loc, theWire, theValue);
303  return thePattern;
304  });
305 
306  return theWire;
307 }
308 
309 bool HWLegalizeModulesPass::processUsers(Operation &op, Value value,
310  ArrayRef<Value> mapping) {
311  for (auto *user : llvm::make_early_inc_range(value.getUsers())) {
312  if (TypeSwitch<Operation *, bool>(user)
313  .Case<hw::ArrayGetOp>([&](hw::ArrayGetOp getOp) {
314  if (auto indexAndBitWidth =
315  tryExtractIndexAndBitWidth(getOp.getIndex())) {
316  getOp.replaceAllUsesWith(mapping[indexAndBitWidth->first]);
317  return true;
318  }
319 
320  return false;
321  })
322  .Case<sv::ArrayIndexInOutOp>([&](sv::ArrayIndexInOutOp indexOp) {
323  if (auto indexAndBitWidth =
324  tryExtractIndexAndBitWidth(indexOp.getIndex())) {
325  indexOp.replaceAllUsesWith(mapping[indexAndBitWidth->first]);
326  return true;
327  }
328 
329  return false;
330  })
331  .Default([](auto op) { return false; })) {
332  user->erase();
333  continue;
334  }
335 
336  user->emitError("unsupported packed array expression");
337  signalPassFailure();
338  return false;
339  }
340 
341  return true;
342 }
343 
344 std::optional<std::pair<uint64_t, unsigned>>
345 HWLegalizeModulesPass::tryExtractIndexAndBitWidth(Value value) {
346  if (auto constantOp = dyn_cast<hw::ConstantOp>(value.getDefiningOp())) {
347  auto index = constantOp.getValue();
348  return std::make_optional(
349  std::make_pair(index.getZExtValue(), index.getBitWidth()));
350  }
351  return std::nullopt;
352 }
353 
354 void HWLegalizeModulesPass::processPostOrder(Block &body) {
355  if (body.empty())
356  return;
357 
358  // Walk the block bottom-up, processing the region tree inside out.
359  Block::iterator it = std::prev(body.end());
360  while (it != body.end()) {
361  auto &op = *it;
362 
363  // Advance the iterator, using the end iterator as a sentinel that we're at
364  // the top of the block.
365  if (it == body.begin())
366  it = body.end();
367  else
368  --it;
369 
370  if (op.getNumRegions()) {
371  for (auto &region : op.getRegions())
372  for (auto &regionBlock : region.getBlocks())
373  processPostOrder(regionBlock);
374  }
375 
376  if (options.disallowPackedArrays) {
377  // Try supported packed array op lowering.
378  if (tryLoweringPackedArrayOp(op)) {
379  it = --Block::iterator(op);
380  op.erase();
381  anythingChanged = true;
382  continue;
383  }
384 
385  // Otherwise, if the IR produces a packed array and we aren't allowing
386  // multi-dimensional arrays, reject the IR as invalid.
387  for (auto value : op.getResults()) {
388  if (isa<hw::ArrayType>(value.getType())) {
389  op.emitError("unsupported packed array expression");
390  signalPassFailure();
391  }
392  }
393  }
394  }
395 }
396 
397 void HWLegalizeModulesPass::runOnOperation() {
398  thisHWModule = getOperation();
399 
400  // Parse the lowering options if necessary.
401  auto optionsAttr = LoweringOptions::getAttributeFrom(
402  cast<ModuleOp>(thisHWModule->getParentOp()));
403  if (optionsAttr != lastParsedOptions) {
404  if (optionsAttr)
405  options = LoweringOptions(optionsAttr.getValue(), [&](Twine error) {
406  thisHWModule.emitError(error);
407  });
408  else
409  options = LoweringOptions();
410  lastParsedOptions = optionsAttr;
411  }
412 
413  // Keeps track if anything changed during this pass, used to determine if
414  // the analyses were preserved.
415  anythingChanged = false;
416 
417  // Walk the operations in post-order, transforming any that are interesting.
418  processPostOrder(*thisHWModule.getBodyBlock());
419 
420  // If we did not change anything in the IR mark all analysis as preserved.
421  if (!anythingChanged)
422  markAllAnalysesPreserved();
423 }
424 
425 std::unique_ptr<Pass> circt::sv::createHWLegalizeModulesPass() {
426  return std::make_unique<HWLegalizeModulesPass>();
427 }
MlirType elementType
Definition: CHIRRTL.cpp:29
def create(array_value, idx)
Definition: hw.py:450
def create(data_type, value)
Definition: hw.py:433
Definition: sv.py:68
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
std::unique_ptr< mlir::Pass > createHWLegalizeModulesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: sv.py:1
Options which control the emission from CIRCT to Verilog.