11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/SCF/IR/SCF.h"
13 #include "mlir/Pass/Pass.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 #include "llvm/Support/Debug.h"
17 #define DEBUG_TYPE "arc-lower-clocks-to-funcs"
21 #define GEN_PASS_DEF_LOWERCLOCKSTOFUNCS
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
27 using namespace circt;
30 using mlir::OpTrait::ConstantLike;
// Pass that moves the body of every ClockTreeOp, PassThroughOp and InitialOp
// found inside a ModelOp into its own `func::FuncOp` (see lowerClock).
37 struct LowerClocksToFuncsPass
38 :
public arc::impl::LowerClocksToFuncsBase<LowerClocksToFuncsPass> {
// Default construction. `symbolTable` and `hasPassthroughOp` have no in-class
// initializer, so they start indeterminate; runOnOperation assigns
// `symbolTable` and lowerModel assigns `hasPassthroughOp` before either is read.
39 LowerClocksToFuncsPass() =
default;
// The copy constructor copies no state from `pass`: it delegates to the default
// constructor, so a copy starts with the same indeterminate members.
// NOTE(review): consider `= nullptr` / `= false` in-class initializers for the
// two members above so copies never hold indeterminate values.
40 LowerClocksToFuncsPass(
const LowerClocksToFuncsPass &pass)
41 : LowerClocksToFuncsPass() {}
// Pass entry point: lowers each ModelOp in the root operation.
43 void runOnOperation()
override;
// Collects and validates the clock-like ops of one model, then outlines each.
44 LogicalResult lowerModel(ModelOp modelOp);
// Outlines a single clock-like op into a new function created via
// `funcBuilder`; `modelStorageArg` is the model's storage value.
45 LogicalResult lowerClock(Operation *clockOp, Value modelStorageArg,
46 OpBuilder &funcBuilder);
// Rewrites the clock op's region so it no longer references values defined
// outside it: uses of `modelStorageArg` become `clockStorageArg`, and external
// constant-like definitions are moved or copied inside.
47 LogicalResult isolateClock(Operation *clockOp, Value modelStorageArg,
48 Value clockStorageArg);
// Symbol table of the root operation; set in runOnOperation and used by
// lowerClock to register the functions it creates.
50 SymbolTable *symbolTable;
// Pass statistics. The increments are not visible in this excerpt
// (presumably in isolateClock when an op is copied or moved) — TODO confirm.
52 Statistic numOpsCopied{
this,
"ops-copied",
"Ops copied into clock trees"};
53 Statistic numOpsMoved{
this,
"ops-moved",
"Ops moved into clock trees"};
// True when the model currently being lowered contains a PassThroughOp. Set
// in lowerModel; read in lowerClock to decide whether the initializer
// function must also call the passthrough function.
56 bool hasPassthroughOp;
// Lowers every ModelOp yielded by `getOperation().getOps<ModelOp>()`. The
// first model that fails to lower aborts the pass via signalPassFailure().
60 void LowerClocksToFuncsPass::runOnOperation() {
// Cache the symbol table so lowerClock can register the functions it creates.
61 symbolTable = &getAnalysis<SymbolTable>();
62 for (
auto op : getOperation().getOps<ModelOp>())
63 if (failed(lowerModel(op)))
64 return signalPassFailure();
// Lowers the clock-like ops of a single model.
//
// The function works in three steps:
// 1. It collects every ClockTreeOp, InitialOp and PassThroughOp nested in
//    `modelOp`, in the order visited by `modelOp.walk`.
// 2. It rejects models with more than one PassThroughOp or more than one
//    InitialOp, emitting one error per violated constraint.
// 3. It outlines each collected op via lowerClock.
72 is the original start of the collection step shown below.
67 LogicalResult LowerClocksToFuncsPass::lowerModel(ModelOp modelOp) {
68 LLVM_DEBUG(llvm::dbgs() <<
"Lowering clocks in `" << modelOp.getName()
// `clocks` receives all three kinds of op; the typed vectors are kept
// separately only for the multiplicity checks below.
72 SmallVector<InitialOp, 1> initialOps;
73 SmallVector<PassThroughOp, 1> passthroughOps;
74 SmallVector<Operation *> clocks;
75 modelOp.walk([&](Operation *op) {
76 TypeSwitch<Operation *, void>(op)
77 .Case<ClockTreeOp>([&](
auto) { clocks.push_back(op); })
78 .Case<InitialOp>([&](
auto initOp) {
79 initialOps.push_back(initOp);
80 clocks.push_back(initOp);
82 .Case<PassThroughOp>([&](
auto ptOp) {
83 passthroughOps.push_back(ptOp);
84 clocks.push_back(ptOp);
// Stored on the pass so lowerClock knows whether the initializer function must
// also call the passthrough function.
87 hasPassthroughOp = !passthroughOps.empty();
// Both constraints are checked before bailing out, so a model violating both
// gets both diagnostics, each with a note per offending op.
90 if (passthroughOps.size() > 1) {
91 auto diag = modelOp.emitOpError()
92 <<
"containing multiple PassThroughOps cannot be lowered.";
93 for (
auto ptOp : passthroughOps)
94 diag.attachNote(ptOp.getLoc()) <<
"Conflicting PassThroughOp:";
96 if (initialOps.size() > 1) {
97 auto diag = modelOp.emitOpError()
98 <<
"containing multiple InitialOps is currently unsupported.";
99 for (
auto initOp : initialOps)
100 diag.attachNote(initOp.getLoc()) <<
"Conflicting InitialOp:";
// Bail out if either constraint was violated (the early return itself is on a
// line not shown here). Otherwise outline each collected op. `funcBuilder` is
// constructed from `modelOp`, so new functions are inserted before the model
// (OpBuilder(Operation*) positions the builder before that op). The model
// body's first argument is passed to lowerClock as the storage value.
102 if (passthroughOps.size() > 1 || initialOps.size() > 1)
106 OpBuilder funcBuilder(modelOp);
107 for (
auto *op : clocks)
108 if (failed(lowerClock(op, modelOp.getBody().getArgument(0), funcBuilder)))
// Outlines one ClockTreeOp, PassThroughOp or InitialOp (asserted below) into a
// new function whose only parameter has the model storage type and which
// returns nothing.
//
// The visible steps are:
// - Add a storage block argument to the op's region.
// - Isolate the region from outside values (isolateClock).
// - Terminate the region with func.return.
// - Create and register the function.
// - Replace the op's role at its original location.
// - Move the region body into the function.
114 LogicalResult LowerClocksToFuncsPass::lowerClock(Operation *clockOp,
115 Value modelStorageArg,
116 OpBuilder &funcBuilder) {
117 LLVM_DEBUG(llvm::dbgs() <<
"- Lowering clock " << clockOp->getName() <<
"\n");
118 assert((isa<ClockTreeOp, PassThroughOp, InitialOp>(clockOp)));
// Give the clock's region a block argument of the storage type. isolateClock
// redirects uses of the model's storage value to this new argument, which
// later becomes the function's parameter.
123 Region &clockRegion = clockOp->getRegion(0);
124 Value clockStorageArg = clockRegion.addArgument(modelStorageArg.getType(),
125 modelStorageArg.getLoc());
128 if (failed(isolateClock(clockOp, modelStorageArg, clockStorageArg)))
// Terminate the region's first block so it can serve as a function body.
132 auto builder = OpBuilder::atBlockEnd(&clockRegion.front());
133 builder.create<func::ReturnOp>(clockOp->getLoc());
// Function name is "<model>_passthrough", "<model>_initial", or "<model>_clock".
// The "_clock" append is presumably guarded by an `else` on a line not shown
// here (the original numbering skips a line) — TODO confirm.
136 SmallString<32> funcName;
137 auto modelOp = clockOp->getParentOfType<ModelOp>();
138 funcName.append(modelOp.getName());
140 if (isa<PassThroughOp>(clockOp))
141 funcName.append(
"_passthrough");
142 else if (isa<InitialOp>(clockOp))
143 funcName.append(
"_initial");
145 funcName.append(
"_clock");
// Create the function with signature (storage type) -> () and register it.
// NOTE(review): SymbolTable::insert renames a symbol whose name collides with
// an existing one, so the final symbol name may differ from `funcName`.
147 auto funcOp = funcBuilder.create<func::FuncOp>(
148 clockOp->getLoc(), funcName,
149 builder.getFunctionType({modelStorageArg.getType()}, {}));
150 symbolTable->insert(funcOp);
151 LLVM_DEBUG(llvm::dbgs() <<
" - Created function `" << funcOp.getSymName()
// Replace the op's role at its original location:
// - ClockTreeOp: emit an scf.if on the tree's clock whose then-branch calls the
//   function (the `false` argument presumably means "no else region").
// - PassThroughOp: call the function unconditionally.
// - InitialOp: make the function the model's initializer, warning if one
//   already exists.
155 builder.setInsertionPoint(clockOp);
156 TypeSwitch<Operation *, void>(clockOp)
157 .Case<ClockTreeOp>([&](
auto treeOp) {
158 auto ifOp = builder.create<scf::IfOp>(clockOp->getLoc(),
159 treeOp.getClock(),
false);
// This local `builder` shadows the outer one; it inserts into the scf.if's
// then-region.
160 auto builder = ifOp.getThenBodyBuilder();
161 builder.template create<func::CallOp>(clockOp->getLoc(), funcOp,
162 ValueRange{modelStorageArg});
164 .Case<PassThroughOp>([&](
auto) {
165 builder.template create<func::CallOp>(clockOp->getLoc(), funcOp,
166 ValueRange{modelStorageArg});
168 .Case<InitialOp>([&](
auto) {
169 if (modelOp.getInitialFn().has_value())
170 modelOp.emitWarning() <<
"Existing model initializer '"
171 << modelOp.getInitialFnAttr().getValue()
172 <<
"' will be overridden.";
173 modelOp.setInitialFnAttr(
// Move the (now isolated and terminated) region into the function body.
178 funcOp.getBody().takeBody(clockRegion);
// When the model also has a PassThroughOp, the initializer function calls
// "<model>_passthrough" right before its terminator.
// NOTE(review): at this point `funcName` still holds this function's own name
// (e.g. "<model>_initial") unless it is cleared on a line not shown here (the
// original numbering skips a line); verify, otherwise the callee name would be
// wrong. The callee is also referenced by a reconstructed name, which would
// not match if SymbolTable::insert renamed the passthrough function.
180 if (isa<InitialOp>(clockOp) && hasPassthroughOp) {
182 builder.setInsertionPoint(funcOp.getBlocks().front().getTerminator());
184 funcName.append(modelOp.getName());
185 funcName.append(
"_passthrough");
186 builder.create<func::CallOp>(clockOp->getLoc(), funcName, TypeRange{},
187 ValueRange{funcOp.getBody().getArgument(0)});
// Makes the clock op's region self-contained.
//
// Walks every op in the region and inspects each operand:
// - Uses of `modelStorageArg` are redirected to `clockStorageArg`.
// - Any other block argument is reported as an error.
// - Values defined by ops inside the region are left alone.
// - Values defined outside by ConstantLike ops are moved into the region (if
//   every user is inside it) or cloned.
// - Any other external value is reported as an error.
//
// Returns failure iff the walk was interrupted by one of those errors.
197 LogicalResult LowerClocksToFuncsPass::isolateClock(Operation *clockOp,
198 Value modelStorageArg,
199 Value clockStorageArg) {
200 auto *clockRegion = &clockOp->getRegion(0);
// Moved/cloned definitions are inserted at this builder's insertion point,
// which starts at the beginning of the region's first block.
201 auto builder = OpBuilder::atBlockBegin(&clockRegion->front());
// Maps an outside value to its in-region replacement so later uses reuse it.
202 DenseMap<Value, Value> copiedValues;
203 auto result = clockRegion->walk([&](Operation *op) {
204 for (
auto &operand : op->getOpOperands()) {
// The model storage value is replaced by the region's own storage argument.
206 if (operand.get() == modelStorageArg) {
207 operand.set(clockStorageArg);
// Any other block-argument operand is rejected.
// NOTE(review): this check runs before the "defined inside the region" check,
// so it also rejects arguments of blocks nested inside the clock region;
// confirm that is intended.
210 if (isa<BlockArgument>(operand.get())) {
211 auto d = op->emitError(
212 "operation in clock tree uses external block argument");
213 d.attachNote() <<
"clock trees can only use external constant values";
214 d.attachNote() <<
"see operand #" << operand.getOperandNumber();
215 d.attachNote(clockOp->getLoc()) <<
"clock tree:";
216 return WalkResult::interrupt();
// Values defined by an op whose region lies within the clock region are
// already local.
220 auto *definingOp = operand.get().getDefiningOp();
221 assert(definingOp &&
"block arguments ruled out above");
222 Region *definingRegion = definingOp->getParentRegion();
223 if (clockRegion->isAncestor(definingRegion))
// Reuse a replacement created for an earlier use of the same outside value.
228 if (
auto copiedValue = copiedValues.lookup(operand.get())) {
229 operand.set(copiedValue);
// Only ConstantLike definitions may be brought inside; anything else is an
// error.
234 if (!definingOp->hasTrait<ConstantLike>()) {
235 auto d = op->emitError(
"operation in clock tree uses external value");
236 d.attachNote() <<
"clock trees can only use external constant values";
237 d.attachNote(definingOp->getLoc()) <<
"external value defined here:";
238 d.attachNote(clockOp->getLoc()) <<
"clock tree:";
239 return WalkResult::interrupt();
// The definition can be moved (rather than cloned) only if every user is
// inside the clock region. A clone is made with cloneWithoutRegions, so any
// regions the constant op has are not copied. The branch structure selecting
// between these two paths is on lines not shown here.
244 bool canMove = llvm::all_of(definingOp->getUsers(), [&](Operation *user) {
245 return clockRegion->isAncestor(user->getParentRegion());
249 definingOp->remove();
250 clonedOp = definingOp;
253 clonedOp = definingOp->cloneWithoutRegions();
// Record each outer result -> inner result mapping and rewrite the current
// operand. For a moved op both sides are the same value. A guard on a line not
// shown here may restrict this loop to the clone case — TODO confirm.
256 builder.insert(clonedOp);
258 for (
auto [outerResult, innerResult] :
259 llvm::zip(definingOp->getResults(), clonedOp->getResults())) {
260 copiedValues.insert({outerResult, innerResult});
261 if (operand.get() == outerResult)
262 operand.set(innerResult);
266 return WalkResult::advance();
268 return success(!result.wasInterrupted());
// Returns a freshly constructed instance of the pass (presumably the body of
// createLowerClocksToFuncsPass, whose signature appears separately below).
272 return std::make_unique<LowerClocksToFuncsPass>();
assert(baseType &&"element must be base type")
std::unique_ptr< mlir::Pass > createLowerClocksToFuncsPass()
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.