CIRCT  19.0.0git
DropConst.cpp
Go to the documentation of this file.
1 //===- DropConst.cpp - Check and remove const types -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the DropConst pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
19 #include "mlir/IR/Threading.h"
20 
21 using namespace circt;
22 using namespace firrtl;
23 
24 /// Returns null type if no conversion is needed.
26  auto nonConstType = type.getAllConstDroppedType();
27  return nonConstType != type ? nonConstType : FIRRTLBaseType{};
28 }
29 
30 /// Returns null type if no conversion is needed.
31 static Type convertType(Type type) {
32  if (auto base = type_dyn_cast<FIRRTLBaseType>(type)) {
33  return convertType(base);
34  }
35 
36  if (auto refType = type_dyn_cast<RefType>(type)) {
37  if (auto converted = convertType(refType.getType()))
38  return RefType::get(converted, refType.getForceable(),
39  refType.getLayer());
40  }
41 
42  return {};
43 }
44 
45 namespace {
46 class DropConstPass : public DropConstBase<DropConstPass> {
47  void runOnOperation() override {
48  mlir::parallelForEach(
49  &getContext(), getOperation().getOps<firrtl::FModuleLike>(),
50  [](auto module) {
51  // Convert the module body if present
52  module->walk([](Operation *op) {
53  if (auto constCastOp = dyn_cast<ConstCastOp>(op)) {
54  // Remove any `ConstCastOp`, replacing results with inputs
55  constCastOp.getResult().replaceAllUsesWith(
56  constCastOp.getInput());
57  constCastOp->erase();
58  return;
59  }
60 
61  // Convert any block arguments
62  for (auto &region : op->getRegions())
63  for (auto &block : region.getBlocks())
64  for (auto argument : block.getArguments())
65  if (auto convertedType = convertType(argument.getType()))
66  argument.setType(convertedType);
67 
68  for (auto result : op->getResults())
69  if (auto convertedType = convertType(result.getType()))
70  result.setType(convertedType);
71  });
72 
73  // Update the module signature with non-'const' ports
74  SmallVector<Attribute> portTypes;
75  portTypes.reserve(module.getNumPorts());
76  bool convertedAny = false;
77  llvm::transform(module.getPortTypes(), std::back_inserter(portTypes),
78  [&](Attribute type) -> Attribute {
79  if (auto convertedType = convertType(
80  cast<TypeAttr>(type).getValue())) {
81  convertedAny = true;
82  return TypeAttr::get(convertedType);
83  }
84  return type;
85  });
86  if (convertedAny)
87  module->setAttr(FModuleLike::getPortTypesAttrName(),
88  ArrayAttr::get(module.getContext(), portTypes));
89  });
90 
91  markAnalysesPreserved<InstanceGraph>();
92  }
93 };
94 } // namespace
95 
96 std::unique_ptr<mlir::Pass> circt::firrtl::createDropConstPass() {
97  return std::make_unique<DropConstPass>();
98 }
static FIRRTLBaseType convertType(FIRRTLBaseType type)
Returns null type if no conversion is needed.
Definition: DropConst.cpp:25
FIRRTLBaseType getAllConstDroppedType()
Return this type with a 'const' modifiers dropped.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
std::unique_ptr< mlir::Pass > createDropConstPass()
Definition: DropConst.cpp:96
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21