CIRCT  19.0.0git
HWLegalizeModules.cpp
Go to the documentation of this file.
//===- HWLegalizeModules.cpp - Lower unsupported IR features away ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass lowers away features in the SV/Comb/HW dialects that are
10 // unsupported by some tools (e.g. multidimensional arrays) as specified by
11 // LoweringOptions. This pass is run relatively late in the pipeline in
12 // preparation for emission. Any passes run after this (e.g. PrettifyVerilog)
13 // must be aware they cannot introduce new invalid constructs.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "PassDetail.h"
18 #include "circt/Dialect/HW/HWOps.h"
22 #include "mlir/IR/Builders.h"
23 
24 using namespace circt;
25 
26 //===----------------------------------------------------------------------===//
27 // HWLegalizeModulesPass
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 struct HWLegalizeModulesPass
32  : public sv::HWLegalizeModulesBase<HWLegalizeModulesPass> {
33  void runOnOperation() override;
34 
35 private:
36  void processPostOrder(Block &block);
37  bool tryLoweringPackedArrayOp(Operation &op);
38  Value lowerLookupToCasez(Operation &op, Value input, Value index,
39  mlir::Type elementType,
40  SmallVector<Value> caseValues);
41  bool processUsers(Operation &op, Value value, ArrayRef<Value> mapping);
42  std::optional<std::pair<uint64_t, unsigned>>
43  tryExtractIndexAndBitWidth(Value value);
44 
45  /// This is the current hw.module being processed.
46  hw::HWModuleOp thisHWModule;
47 
48  bool anythingChanged;
49 
50  /// This tells us what language features we're allowed to use in generated
51  /// Verilog.
52  LoweringOptions options;
53 
54  /// This pass will be run on multiple hw.modules, this keeps track of the
55  /// contents of LoweringOptions so we don't have to reparse the
56  /// LoweringOptions for every hw.module.
57  StringAttr lastParsedOptions;
58 };
59 } // end anonymous namespace
60 
61 bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
62  return TypeSwitch<Operation *, bool>(&op)
63  .Case<hw::AggregateConstantOp>([&](hw::AggregateConstantOp constOp) {
64  // Replace individual element uses (if any) with input fields.
65  SmallVector<Value> inputs;
66  OpBuilder builder(constOp);
67  for (auto field : llvm::reverse(constOp.getFields())) {
68  if (auto intAttr = dyn_cast<IntegerAttr>(field))
69  inputs.push_back(
70  builder.create<hw::ConstantOp>(constOp.getLoc(), intAttr));
71  else
72  inputs.push_back(builder.create<hw::AggregateConstantOp>(
73  constOp.getLoc(), constOp.getType(), field.cast<ArrayAttr>()));
74  }
75  if (!processUsers(op, constOp.getResult(), inputs))
76  return false;
77 
78  // Remove original op.
79  return true;
80  })
81  .Case<hw::ArrayConcatOp>([&](hw::ArrayConcatOp concatOp) {
82  // Redirect individual element uses (if any) to the input arguments.
83  SmallVector<std::pair<Value, uint64_t>> arrays;
84  for (auto array : llvm::reverse(concatOp.getInputs())) {
85  auto ty = hw::type_cast<hw::ArrayType>(array.getType());
86  arrays.emplace_back(array, ty.getNumElements());
87  }
88  for (auto *user :
89  llvm::make_early_inc_range(concatOp.getResult().getUsers())) {
90  if (TypeSwitch<Operation *, bool>(user)
91  .Case<hw::ArrayGetOp>([&](hw::ArrayGetOp getOp) {
92  if (auto indexAndBitWidth =
93  tryExtractIndexAndBitWidth(getOp.getIndex())) {
94  auto [indexValue, bitWidth] = *indexAndBitWidth;
95  // FIXME: More efficient search
96  for (const auto &[array, size] : arrays) {
97  if (indexValue >= size) {
98  indexValue -= size;
99  continue;
100  }
101  OpBuilder builder(getOp);
102  getOp.getInputMutable().set(array);
103  getOp.getIndexMutable().set(
104  builder.createOrFold<hw::ConstantOp>(
105  getOp.getLoc(), APInt(bitWidth, indexValue)));
106  return true;
107  }
108  }
109 
110  return false;
111  })
112  .Default([](auto op) { return false; }))
113  continue;
114 
115  op.emitError("unsupported packed array expression");
116  signalPassFailure();
117  }
118 
119  // Remove the original op.
120  return true;
121  })
122  .Case<hw::ArrayCreateOp>([&](hw::ArrayCreateOp createOp) {
123  // Replace individual element uses (if any) with input arguments.
124  SmallVector<Value> inputs(llvm::reverse(createOp.getInputs()));
125  if (!processUsers(op, createOp.getResult(), inputs))
126  return false;
127 
128  // Remove original op.
129  return true;
130  })
131  .Case<hw::ArrayGetOp>([&](hw::ArrayGetOp getOp) {
132  // Skip index ops with constant index.
133  auto index = getOp.getIndex();
134  if (auto *definingOp = index.getDefiningOp())
135  if (isa<hw::ConstantOp>(definingOp))
136  return false;
137 
138  // Generate case value element lookups.
139  auto ty = hw::type_cast<hw::ArrayType>(getOp.getInput().getType());
140  OpBuilder builder(getOp);
141  SmallVector<Value> caseValues;
142  for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
143  auto loc = op.getLoc();
144  auto index = builder.createOrFold<hw::ConstantOp>(
145  loc, APInt(llvm::Log2_64_Ceil(e), i));
146  auto element =
147  builder.create<hw::ArrayGetOp>(loc, getOp.getInput(), index);
148  caseValues.push_back(element);
149  }
150 
151  // Transform array index op into casez statement.
152  auto theWire = lowerLookupToCasez(op, getOp.getInput(), index,
153  ty.getElementType(), caseValues);
154 
155  // Emit the read from the wire, replace uses and clean up.
156  builder.setInsertionPoint(getOp);
157  auto readWire =
158  builder.create<sv::ReadInOutOp>(getOp.getLoc(), theWire);
159  getOp.getResult().replaceAllUsesWith(readWire);
160  return true;
161  })
162  .Case<sv::ArrayIndexInOutOp>([&](sv::ArrayIndexInOutOp indexOp) {
163  // Skip index ops with constant index.
164  auto index = indexOp.getIndex();
165  if (auto *definingOp = index.getDefiningOp())
166  if (isa<hw::ConstantOp>(definingOp))
167  return false;
168 
169  // Skip index ops with unpacked arrays.
170  auto inout = indexOp.getInput().getType();
171  if (hw::type_isa<hw::UnpackedArrayType>(inout.getElementType()))
172  return false;
173 
174  // Generate case value element lookups.
175  auto ty = hw::type_cast<hw::ArrayType>(inout.getElementType());
176  OpBuilder builder(&op);
177  SmallVector<Value> caseValues;
178  for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
179  auto loc = op.getLoc();
180  auto index = builder.createOrFold<hw::ConstantOp>(
181  loc, APInt(llvm::Log2_64_Ceil(e), i));
182  auto element = builder.create<sv::ArrayIndexInOutOp>(
183  loc, indexOp.getInput(), index);
184  auto readElement = builder.create<sv::ReadInOutOp>(loc, element);
185  caseValues.push_back(readElement);
186  }
187 
188  // Transform array index op into casez statement.
189  auto theWire = lowerLookupToCasez(op, indexOp.getInput(), index,
190  ty.getElementType(), caseValues);
191 
192  // Replace uses and clean up.
193  indexOp.getResult().replaceAllUsesWith(theWire);
194  return true;
195  })
196  .Case<sv::PAssignOp>([&](sv::PAssignOp assignOp) {
197  // Transform array assignment into individual assignments for each array
198  // element.
199  auto inout = assignOp.getDest().getType();
200  auto ty = hw::type_dyn_cast<hw::ArrayType>(inout.getElementType());
201  if (!ty)
202  return false;
203 
204  OpBuilder builder(assignOp);
205  for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
206  auto loc = op.getLoc();
207  auto index = builder.createOrFold<hw::ConstantOp>(
208  loc, APInt(llvm::Log2_64_Ceil(e), i));
209  auto dstElement = builder.create<sv::ArrayIndexInOutOp>(
210  loc, assignOp.getDest(), index);
211  auto srcElement =
212  builder.create<hw::ArrayGetOp>(loc, assignOp.getSrc(), index);
213  builder.create<sv::PAssignOp>(loc, dstElement, srcElement);
214  }
215 
216  // Remove original assignment.
217  return true;
218  })
219  .Case<sv::RegOp>([&](sv::RegOp regOp) {
220  // Transform array reg into individual regs for each array element.
221  auto ty = hw::type_dyn_cast<hw::ArrayType>(regOp.getElementType());
222  if (!ty)
223  return false;
224 
225  OpBuilder builder(regOp);
226  auto name = StringAttr::get(regOp.getContext(), "name");
227  SmallVector<Value> elements;
228  for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
229  auto loc = op.getLoc();
230  auto element = builder.create<sv::RegOp>(loc, ty.getElementType());
231  if (auto nameAttr = regOp->getAttrOfType<StringAttr>(name)) {
232  element.setNameAttr(
233  StringAttr::get(regOp.getContext(), nameAttr.getValue()));
234  }
235  elements.push_back(element);
236  }
237 
238  // Fix users to refer to individual element regs.
239  if (!processUsers(op, regOp.getResult(), elements))
240  return false;
241 
242  // Remove original reg.
243  return true;
244  })
245  .Default([&](auto op) { return false; });
246 }
247 
248 Value HWLegalizeModulesPass::lowerLookupToCasez(Operation &op, Value input,
249  Value index,
250  mlir::Type elementType,
251  SmallVector<Value> caseValues) {
252  // Create the wire for the result of the casez in the
253  // hw.module.
254  OpBuilder builder(&op);
255  auto theWire = builder.create<sv::RegOp>(op.getLoc(), elementType,
256  builder.getStringAttr("casez_tmp"));
257  builder.setInsertionPoint(&op);
258 
259  auto loc = input.getDefiningOp()->getLoc();
260  // A casez is a procedural operation, so if we're in a
261  // non-procedural region we need to inject an always_comb
262  // block.
263  if (!op.getParentOp()->hasTrait<sv::ProceduralRegion>()) {
264  auto alwaysComb = builder.create<sv::AlwaysCombOp>(loc);
265  builder.setInsertionPointToEnd(alwaysComb.getBodyBlock());
266  }
267 
268  // If we are missing elements in the array (it is non-power of
269  // two), then add a default 'X' value.
270  if (1ULL << index.getType().getIntOrFloatBitWidth() != caseValues.size()) {
271  caseValues.push_back(builder.create<sv::ConstantXOp>(
272  op.getLoc(), op.getResult(0).getType()));
273  }
274 
275  APInt caseValue(index.getType().getIntOrFloatBitWidth(), 0);
276  auto *context = builder.getContext();
277 
278  // Create the casez itself.
279  builder.create<sv::CaseOp>(
280  loc, CaseStmtType::CaseZStmt, index, caseValues.size(),
281  [&](size_t caseIdx) -> std::unique_ptr<sv::CasePattern> {
282  // Use a default pattern for the last value, even if we
283  // are complete. This avoids tools thinking they need to
284  // insert a latch due to potentially incomplete case
285  // coverage.
286  bool isDefault = caseIdx == caseValues.size() - 1;
287  Value theValue = caseValues[caseIdx];
288  std::unique_ptr<sv::CasePattern> thePattern;
289 
290  if (isDefault)
291  thePattern = std::make_unique<sv::CaseDefaultPattern>(context);
292  else
293  thePattern = std::make_unique<sv::CaseBitPattern>(caseValue, context);
294  ++caseValue;
295  builder.create<sv::BPAssignOp>(loc, theWire, theValue);
296  return thePattern;
297  });
298 
299  return theWire;
300 }
301 
302 bool HWLegalizeModulesPass::processUsers(Operation &op, Value value,
303  ArrayRef<Value> mapping) {
304  for (auto *user : llvm::make_early_inc_range(value.getUsers())) {
305  if (TypeSwitch<Operation *, bool>(user)
306  .Case<hw::ArrayGetOp>([&](hw::ArrayGetOp getOp) {
307  if (auto indexAndBitWidth =
308  tryExtractIndexAndBitWidth(getOp.getIndex())) {
309  getOp.replaceAllUsesWith(mapping[indexAndBitWidth->first]);
310  return true;
311  }
312 
313  return false;
314  })
315  .Case<sv::ArrayIndexInOutOp>([&](sv::ArrayIndexInOutOp indexOp) {
316  if (auto indexAndBitWidth =
317  tryExtractIndexAndBitWidth(indexOp.getIndex())) {
318  indexOp.replaceAllUsesWith(mapping[indexAndBitWidth->first]);
319  return true;
320  }
321 
322  return false;
323  })
324  .Default([](auto op) { return false; })) {
325  user->erase();
326  continue;
327  }
328 
329  user->emitError("unsupported packed array expression");
330  signalPassFailure();
331  return false;
332  }
333 
334  return true;
335 }
336 
337 std::optional<std::pair<uint64_t, unsigned>>
338 HWLegalizeModulesPass::tryExtractIndexAndBitWidth(Value value) {
339  if (auto constantOp = dyn_cast<hw::ConstantOp>(value.getDefiningOp())) {
340  auto index = constantOp.getValue();
341  return std::make_optional(
342  std::make_pair(index.getZExtValue(), index.getBitWidth()));
343  }
344  return std::nullopt;
345 }
346 
347 void HWLegalizeModulesPass::processPostOrder(Block &body) {
348  if (body.empty())
349  return;
350 
351  // Walk the block bottom-up, processing the region tree inside out.
352  Block::iterator it = std::prev(body.end());
353  while (it != body.end()) {
354  auto &op = *it;
355 
356  // Advance the iterator, using the end iterator as a sentinel that we're at
357  // the top of the block.
358  if (it == body.begin())
359  it = body.end();
360  else
361  --it;
362 
363  if (op.getNumRegions()) {
364  for (auto &region : op.getRegions())
365  for (auto &regionBlock : region.getBlocks())
366  processPostOrder(regionBlock);
367  }
368 
369  if (options.disallowPackedArrays) {
370  // Try supported packed array op lowering.
371  if (tryLoweringPackedArrayOp(op)) {
372  it = --Block::iterator(op);
373  op.erase();
374  anythingChanged = true;
375  continue;
376  }
377 
378  // Otherwise, if the IR produces a packed array and we aren't allowing
379  // multi-dimensional arrays, reject the IR as invalid.
380  for (auto value : op.getResults()) {
381  if (value.getType().isa<hw::ArrayType>()) {
382  op.emitError("unsupported packed array expression");
383  signalPassFailure();
384  }
385  }
386  }
387  }
388 }
389 
390 void HWLegalizeModulesPass::runOnOperation() {
391  thisHWModule = getOperation();
392 
393  // Parse the lowering options if necessary.
394  auto optionsAttr = LoweringOptions::getAttributeFrom(
395  cast<ModuleOp>(thisHWModule->getParentOp()));
396  if (optionsAttr != lastParsedOptions) {
397  if (optionsAttr)
398  options = LoweringOptions(optionsAttr.getValue(), [&](Twine error) {
399  thisHWModule.emitError(error);
400  });
401  else
402  options = LoweringOptions();
403  lastParsedOptions = optionsAttr;
404  }
405 
406  // Keeps track if anything changed during this pass, used to determine if
407  // the analyses were preserved.
408  anythingChanged = false;
409 
410  // Walk the operations in post-order, transforming any that are interesting.
411  processPostOrder(*thisHWModule.getBodyBlock());
412 
413  // If we did not change anything in the IR mark all analysis as preserved.
414  if (!anythingChanged)
415  markAllAnalysesPreserved();
416 }
417 
418 std::unique_ptr<Pass> circt::sv::createHWLegalizeModulesPass() {
419  return std::make_unique<HWLegalizeModulesPass>();
420 }
MlirType elementType
Definition: CHIRRTL.cpp:29
llvm::SmallVector< StringAttr > inputs
Builder builder
Definition: sv.py:68
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
std::unique_ptr< mlir::Pass > createHWLegalizeModulesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Options which control the emission from CIRCT to Verilog.