11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/SCF/IR/SCF.h"
13 #include "mlir/Pass/Pass.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 #include "llvm/Support/Debug.h"
17 #define DEBUG_TYPE "arc-lower-clocks-to-funcs"
21 #define GEN_PASS_DEF_LOWERCLOCKSTOFUNCS
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
27 using namespace circt;
30 using mlir::OpTrait::ConstantLike;
// Pass that extracts the clock regions of arc::ModelOps (InitialOp and
// FinalOp) into standalone func::FuncOps.
// NOTE(review): this excerpt is missing lines (the embedded original line
// numbers skip), including the struct's closing brace.
37 struct LowerClocksToFuncsPass
38 :
public arc::impl::LowerClocksToFuncsBase<LowerClocksToFuncsPass> {
// The copy constructor does not copy any state from `pass`; it simply
// delegates to the default constructor (presumably because the Statistic
// members below are not copyable -- confirm).
39 LowerClocksToFuncsPass() =
default;
40 LowerClocksToFuncsPass(
const LowerClocksToFuncsPass &pass)
41 : LowerClocksToFuncsPass() {}
// Pass entry point: lowers every ModelOp in the root operation.
43 void runOnOperation()
override;
// Collects the InitialOp/FinalOp clock ops of one model and lowers each one.
44 LogicalResult lowerModel(ModelOp modelOp);
// Turns a single clock op's region into a function that takes the model
// storage as its only parameter.
45 LogicalResult lowerClock(Operation *clockOp, Value modelStorageArg,
46 OpBuilder &funcBuilder);
// Rewrites operand uses inside the clock region so the region only depends
// on its own storage argument and on constants brought into it.
47 LogicalResult isolateClock(Operation *clockOp, Value modelStorageArg,
48 Value clockStorageArg);
// Symbol table of the root operation; set in runOnOperation() and used by
// lowerClock() to register the newly created functions.
50 SymbolTable *symbolTable;
// Pass statistics: external ops copied vs. moved into clock trees.
52 Statistic numOpsCopied{
this,
"ops-copied",
"Ops copied into clock trees"};
53 Statistic numOpsMoved{
this,
"ops-moved",
"Ops moved into clock trees"};
// Lowers the clocks of every ModelOp found by a (non-recursive) getOps<ModelOp>()
// iteration over the root operation. The SymbolTable analysis is cached in
// `symbolTable` for use by lowerClock(). The pass is marked failed as soon as
// one model fails to lower; remaining models are not processed.
// NOTE(review): the closing brace is not present in this excerpt.
57 void LowerClocksToFuncsPass::runOnOperation() {
58 symbolTable = &getAnalysis<SymbolTable>();
59 for (
auto op : getOperation().getOps<ModelOp>())
60 if (failed(lowerModel(op)))
61 return signalPassFailure();
// Lowers the clock ops of one model.
// Walks `modelOp` (recursively) to collect its InitialOps and FinalOps. At
// most one of each is supported: extras are diagnosed with an op error plus
// one note per conflicting op. Each collected clock op is then handed to
// lowerClock().
// NOTE(review): this excerpt is missing lines (e.g. original 66-68, 77,
// 80-84, 90, 96, 98-100, 104+), so closing braces and return statements are
// not visible.
64 LogicalResult LowerClocksToFuncsPass::lowerModel(ModelOp modelOp) {
65 LLVM_DEBUG(llvm::dbgs() <<
"Lowering clocks in `" << modelOp.getName()
// Collect the clock ops to extract; `clocks` is the list actually lowered.
69 SmallVector<InitialOp, 1> initialOps;
70 SmallVector<FinalOp, 1> finalOps;
71 SmallVector<Operation *> clocks;
72 modelOp.walk([&](Operation *op) {
73 TypeSwitch<Operation *, void>(op)
74 .Case<InitialOp>([&](
auto initOp) {
75 initialOps.push_back(initOp);
76 clocks.push_back(initOp);
// NOTE(review): only `finalOps.push_back` is visible for FinalOp here.
// lowerClock() asserts it accepts FinalOp and names it "_final", so the
// missing original lines 80-84 presumably also add it to `clocks` -- confirm.
78 .Case<FinalOp>([&](
auto op) {
79 finalOps.push_back(op);
// Diagnose unsupported multiplicity. Both checks run before the combined
// bail-out below, so conflicts of both kinds are reported together.
85 if (initialOps.size() > 1) {
86 auto diag = modelOp.emitOpError()
87 <<
"containing multiple InitialOps is currently unsupported.";
88 for (
auto initOp : initialOps)
89 diag.attachNote(initOp.getLoc()) <<
"Conflicting InitialOp:";
91 if (finalOps.size() > 1) {
92 auto diag = modelOp.emitOpError()
93 <<
"containing multiple FinalOps is currently unsupported.";
94 for (
auto op : finalOps)
95 diag.attachNote(op.getLoc()) <<
"Conflicting FinalOp:";
// Bail out if either check failed (the return statement is not visible here).
97 if (initialOps.size() > 1 || finalOps.size() > 1)
// A builder constructed from an op inserts before that op, so the generated
// functions are placed ahead of `modelOp`. The first argument of the model
// body is passed to each clock as the model storage value.
101 OpBuilder funcBuilder(modelOp);
102 for (
auto *op : clocks)
103 if (failed(lowerClock(op, modelOp.getBody().getArgument(0), funcBuilder)))
// Converts one clock op (InitialOp or FinalOp) into a func::FuncOp.
// The function is named `<model name>_initial` or `<model name>_final`, takes
// the model storage type as its only parameter and returns nothing. It is
// created through `funcBuilder` and registered in the symbol table. The model's
// initial/final function attribute is updated (with a warning if one was
// already set). Finally, the clock region's blocks are moved into the
// function body.
// NOTE(review): this excerpt is missing lines (e.g. original 114-117, 121-122,
// 124-126, 145-147, 156-157, 163-166, 168+), including early returns and the
// arguments of the attribute setters.
109 LogicalResult LowerClocksToFuncsPass::lowerClock(Operation *clockOp,
110 Value modelStorageArg,
111 OpBuilder &funcBuilder) {
112 LLVM_DEBUG(llvm::dbgs() <<
"- Lowering clock " << clockOp->getName() <<
"\n");
113 assert((isa<InitialOp, FinalOp>(clockOp)));
// Give the clock region's entry block an argument of the storage type. Once
// the region becomes the function body, this is the function's parameter.
118 Region &clockRegion = clockOp->getRegion(0);
119 Value clockStorageArg = clockRegion.addArgument(modelStorageArg.getType(),
120 modelStorageArg.getLoc());
// Make the region self-contained. The failure return is not visible here.
123 if (failed(isolateClock(clockOp, modelStorageArg, clockStorageArg)))
// Terminate the region with func.return, since it is about to become a
// function body.
127 auto builder = OpBuilder::atBlockEnd(&clockRegion.front());
128 builder.create<func::ReturnOp>(clockOp->getLoc());
// Derive the function name from the enclosing model's name and the clock kind.
131 SmallString<32> funcName;
132 auto modelOp = clockOp->getParentOfType<ModelOp>();
133 funcName.append(modelOp.getName());
135 if (isa<InitialOp>(clockOp))
136 funcName.append(
"_initial");
137 else if (isa<FinalOp>(clockOp))
138 funcName.append(
"_final");
// Create `(storage) -> ()` and register it. SymbolTable::insert may rename the
// symbol to avoid a collision, so later code should use funcOp's symbol name
// rather than `funcName`.
140 auto funcOp = funcBuilder.create<func::FuncOp>(
141 clockOp->getLoc(), funcName,
142 builder.getFunctionType({modelStorageArg.getType()}, {}));
143 symbolTable->insert(funcOp);
144 LLVM_DEBUG(llvm::dbgs() <<
" - Created function `" << funcOp.getSymName()
// Subsequent builder insertions (not visible in this excerpt) go before
// clockOp.
148 builder.setInsertionPoint(clockOp);
// Record the new function on the model as its initializer or finalizer,
// warning if an existing one is being replaced.
149 TypeSwitch<Operation *, void>(clockOp)
150 .Case<InitialOp>([&](
auto) {
151 if (modelOp.getInitialFn().has_value())
152 modelOp.emitWarning() <<
"Existing model initializer '"
153 << modelOp.getInitialFnAttr().getValue()
154 <<
"' will be overridden.";
155 modelOp.setInitialFnAttr(
158 .Case<FinalOp>([&](
auto) {
159 if (modelOp.getFinalFn().has_value())
160 modelOp.emitWarning()
161 <<
"Existing model finalizer '"
162 << modelOp.getFinalFnAttr().getValue() <<
"' will be overridden.";
// Move (not copy) the clock region's blocks into the function body. This
// leaves `clockRegion` empty; presumably clockOp is erased in the missing
// lines that follow -- confirm.
167 funcOp.getBody().takeBody(clockRegion);
// Makes the clock region independent of values defined outside of it.
// For every operand of every op in the region:
//  - uses of `modelStorageArg` are redirected to `clockStorageArg`;
//  - any other block-argument operand is rejected with an error;
//  - values defined inside the region are left alone;
//  - external values whose defining op is not ConstantLike are rejected;
//  - external ConstantLike ops are moved to the start of the region if all of
//    their users are inside it, and cloned (without regions) there otherwise.
// Returns failure if the walk was interrupted by any of the errors above.
// NOTE(review): this excerpt is missing lines (e.g. original 184, 187-188,
// 196-198, 203-206, 209-212, 219-222, 225-227, 230-231, 233-234, 236,
// 242-244, 246), including `continue`s, the `clonedOp` declaration and the
// statistic increments.
176 LogicalResult LowerClocksToFuncsPass::isolateClock(Operation *clockOp,
177 Value modelStorageArg,
178 Value clockStorageArg) {
179 auto *clockRegion = &clockOp->getRegion(0);
// Constants brought into the clock are inserted at the start of its entry
// block.
180 auto builder = OpBuilder::atBlockBegin(&clockRegion->front());
// Maps each external result to its in-region replacement, so that each
// external constant is brought in at most once.
181 DenseMap<Value, Value> copiedValues;
182 auto result = clockRegion->walk([&](Operation *op) {
183 for (
auto &operand : op->getOpOperands()) {
// The model storage becomes the clock function's own parameter.
185 if (operand.get() == modelStorageArg) {
186 operand.set(clockStorageArg);
// NOTE(review): as visible, this fires for any block argument, including
// arguments of regions nested inside the clock region (e.g. loop bodies).
// Confirm whether the missing original lines 187-188 filter those out.
189 if (isa<BlockArgument>(operand.get())) {
190 auto d = op->emitError(
191 "operation in clock tree uses external block argument");
192 d.attachNote() <<
"clock trees can only use external constant values";
193 d.attachNote() <<
"see operand #" << operand.getOperandNumber();
194 d.attachNote(clockOp->getLoc()) <<
"clock tree:";
195 return WalkResult::interrupt();
199 auto *definingOp = operand.get().getDefiningOp();
200 assert(definingOp &&
"block arguments ruled out above");
// Region::isAncestor includes the region itself, so this matches values
// defined anywhere inside the clock region.
201 Region *definingRegion = definingOp->getParentRegion();
202 if (clockRegion->isAncestor(definingRegion))
// Reuse a replacement created for an earlier use of the same value.
207 if (
auto copiedValue = copiedValues.lookup(operand.get())) {
208 operand.set(copiedValue);
// Only constant-like external ops may be pulled into a clock tree.
213 if (!definingOp->hasTrait<ConstantLike>()) {
214 auto d = op->emitError(
"operation in clock tree uses external value");
215 d.attachNote() <<
"clock trees can only use external constant values";
216 d.attachNote(definingOp->getLoc()) <<
"external value defined here:";
217 d.attachNote(clockOp->getLoc()) <<
"clock tree:";
218 return WalkResult::interrupt();
// Move the op if nothing outside the clock region uses it; otherwise clone it
// so that outside users keep the original.
223 bool canMove = llvm::all_of(definingOp->getUsers(), [&](Operation *user) {
224 return clockRegion->isAncestor(user->getParentRegion());
228 definingOp->remove();
229 clonedOp = definingOp;
232 clonedOp = definingOp->cloneWithoutRegions();
235 builder.insert(clonedOp);
// Record every result's replacement and rewrite the current operand. When the
// op was moved, outer and inner results are the same value.
237 for (
auto [outerResult, innerResult] :
238 llvm::zip(definingOp->getResults(), clonedOp->getResults())) {
239 copiedValues.insert({outerResult, innerResult});
240 if (operand.get() == outerResult)
241 operand.set(innerResult);
245 return WalkResult::advance();
247 return success(!result.wasInterrupted());
// Pass factory body: returns a fresh, default-constructed pass instance as an
// owning pointer.
// NOTE(review): the enclosing function's signature is not part of this
// excerpt.
251 return std::make_unique<LowerClocksToFuncsPass>();
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createLowerClocksToFuncsPass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.