18 #include "mlir/IR/Builders.h"
19 #include "mlir/Pass/Pass.h"
20 #include "llvm/ADT/TypeSwitch.h"
23 #define GEN_PASS_DEF_PIPELINETOHW
24 #include "circt/Conversion/Passes.h.inc"
28 using namespace circt;
29 using namespace pipeline;
// PipelineLowering: abstract base for lowering a pipeline ScheduledPipelineOp
// into comb/seq operations. Subclasses implement run(), lowerStage() and
// getStagePrefix(); this base supplies the shared stage-body emission
// (emitStageBody) and stage egress naming (getStageEgressNames).
// NOTE(review): this is an excerpt with many original lines elided (see the
// gaps in the leading line numbers); it does not compile as shown, and the
// comments below describe only the visible statements.
33 class PipelineLowering {
// Constructs a lowering for `pipeline`. `pipelineID` is used for default
// naming (see getPipelineBaseName). `clockGateRegs` selects clock-gated stage
// registers. `enablePowerOnValues` is consulted by subclasses to attach
// initial values to stage-enable registers.
35 PipelineLowering(
size_t pipelineID, ScheduledPipelineOp pipeline,
36 OpBuilder &builder,
bool clockGateRegs,
37 bool enablePowerOnValues)
38 : pipelineID(pipelineID), pipeline(pipeline), builder(builder),
39 clockGateRegs(clockGateRegs), enablePowerOnValues(enablePowerOnValues) {
// Cache the pipeline's clock and reset signals.
40 parentClk = pipeline.getClock();
41 parentRst = pipeline.getReset();
44 virtual ~PipelineLowering() =
default;
// Performs the complete lowering of `pipeline`; implemented by subclasses.
46 virtual LogicalResult
run() = 0;
// StageReturns fields (the struct header is elided): the register outputs and
// passthrough values produced by a lowered stage.
62 llvm::SmallVector<Value> regs;
63 llvm::SmallVector<Value> passthroughs;
// Lowers the stage block `stage` (at index `stageIndex`) given the incoming
// stage arguments. `inputNames` optionally names the stage inputs.
72 virtual FailureOr<StageReturns>
73 lowerStage(
Block *stage, StageArgs args,
size_t stageIndex,
74 llvm::ArrayRef<Attribute> inputNames = {}) = 0;
// Emits the body of `stage` at the builder's current insertion point and
// returns what the stage produces:
//  - non-terminator ops are moved to the insertion point; LatencyOp bodies
//    are inlined and the LatencyOp's uses are redirected to the operands of
//    its body terminator;
//  - a stage-valid signal is derived from args.enable according to the stage
//    kind and is tagged with sv.namehint "<stage prefix>_valid";
//  - if the terminator is a StageOp, one register is created per StageOp
//    register operand, named from `registerNames` (one entry per register is
//    asserted).
// NOTE(review): the default `stageIndex = -1` converts to SIZE_MAX for size_t
// and is then passed to getStageKind(); the visible caller always passes an
// explicit index.
76 StageReturns emitStageBody(
Block *stage, StageArgs args,
77 llvm::ArrayRef<Attribute> registerNames,
78 size_t stageIndex = -1) {
79 assert(args.enable &&
"enable not set");
80 auto *terminator = stage->getTerminator();
// Move the stage's operations to the insertion point. The early-inc range
// keeps iteration valid while ops are moved out of the block.
83 for (
auto &op : llvm::make_early_inc_range(*stage)) {
// The terminator is special-cased here (body elided); it is inspected after
// the loop.
84 if (&op == terminator)
// LatencyOp: hoist its body ops (excluding the body terminator) to the
// insertion point, then forward the LatencyOp's results to the body
// terminator's operands.
87 if (
auto latencyOp = dyn_cast<LatencyOp>(op)) {
92 Block *latencyOpBody = latencyOp.getBodyBlock();
94 llvm::make_early_inc_range(latencyOpBody->without_terminator()))
95 innerOp.moveBefore(builder.getInsertionBlock(),
96 builder.getInsertionPoint());
97 latencyOp.replaceAllUsesWith(
98 latencyOpBody->getTerminator()->getOperands());
// Any other op is moved to the insertion point unchanged.
101 op.moveBefore(builder.getInsertionBlock(), builder.getInsertionPoint());
105 auto loc = terminator->getLoc();
// Returns a "not stalled" signal, creating it on first use (body elided).
// NOTE(review): presumably derived from args.stall -- confirm.
107 auto getOrSetNotStalled = [&]() {
// Compute the stage-valid signal according to the stage kind.
116 StageKind stageKind = pipeline.getStageKind(stageIndex);
118 StringAttr validSignalName =
119 builder.getStringAttr(getStagePrefix(stageIndex).strref() +
"_valid");
// (Continuous-stage handling is elided in this excerpt.)
121 case StageKind::Continuous:
// Non-stallable stages are valid whenever they are enabled.
123 case StageKind::NonStallable:
124 stageValid = args.enable;
// Stallable stages: valid = enable AND not-stalled.
126 case StageKind::Stallable:
128 builder.create<
comb::AndOp>(loc, args.enable, getOrSetNotStalled());
129 stageValid.getDefiningOp()->setAttr(
"sv.namehint", validSignalName);
// Runoff stages require an LNS signal. The OR of args.lnsEn and the
// not-stalled signal feeds stageValid (the enclosing expression is elided).
131 case StageKind::Runoff:
132 assert(args.lnsEn &&
"Expected an LNS signal if this was a runoff stage");
135 builder.create<
comb::OrOp>(loc, args.lnsEn, getOrSetNotStalled()));
136 stageValid.getDefiningOp()->setAttr(
"sv.namehint", validSignalName);
// Terminator handling: a StageOp feeds a next stage; otherwise a ReturnOp is
// expected.
141 auto stageOp = dyn_cast<StageOp>(terminator);
// Final stage: the ReturnOp operands become passthroughs, alongside the valid
// signal.
143 assert(isa<ReturnOp>(terminator) &&
"expected ReturnOp");
147 rets.passthroughs = terminator->getOperands();
148 rets.valid = stageValid;
// StageOp: create one register per register operand.
152 assert(registerNames.size() == stageOp.getRegisters().size() &&
153 "register names and registers must be the same size");
// Every stage kind other than Continuous gets enable-conditioned registers.
155 bool isStallablePipeline = stageKind != StageKind::Continuous;
156 Value notStalledClockGate;
// With clock gating enabled, one stage-level clock gate (enabled by
// stageValid) is built on args.clock.
157 if (this->clockGateRegs) {
159 notStalledClockGate = builder.create<seq::ClockGateOp>(
160 loc, args.clock, stageValid, Value(),
164 for (
auto it : llvm::enumerate(stageOp.getRegisters())) {
165 auto regIdx = it.index();
166 auto regIn = it.value();
168 StringAttr regName = cast<StringAttr>(registerNames[regIdx]);
// Clock-gated path: chain any per-register clock-gate enables on top of the
// stage gate, then clock a CompRegOp with the resulting gated clock.
170 if (this->clockGateRegs) {
172 Value currClockGate = notStalledClockGate;
173 for (
auto hierClockGateEnable : stageOp.getClockGatesForReg(regIdx)) {
175 currClockGate = builder.create<seq::ClockGateOp>(
176 loc, currClockGate, hierClockGateEnable,
180 dataReg = builder.create<
seq::CompRegOp>(stageOp->getLoc(), regIn,
181 currClockGate, regName);
// Non-gated path: stallable stages use a register enabled by stageValid;
// otherwise a plain register on args.clock (the op types are elided here).
186 if (isStallablePipeline) {
188 stageOp->getLoc(), regIn, args.clock, stageValid, regName);
191 args.clock, regName);
194 rets.regs.push_back(dataReg);
// Stage outputs: the valid signal, an LNS enable (only for non-stallable
// stages, forwarding args.enable), and the StageOp passthroughs.
197 rets.valid = stageValid;
198 if (stageKind == StageKind::NonStallable)
199 rets.lnsEn = args.enable;
201 rets.passthroughs = stageOp.getPassthroughs();
// Names for the registers, outputs and inputs that cross a stage boundary.
208 struct StageEgressNames {
209 llvm::SmallVector<Attribute> regNames;
210 llvm::SmallVector<Attribute> outNames;
211 llvm::SmallVector<Attribute> inNames;
// Populates `egressNames` for the terminator of stage `stageIndex`.
// For a StageOp terminator, register i is named:
//   reg: "<name>"     or "stage<stageIndex>_reg<i>"
//   out: "<name>_out" or "out<i>"
//   in:  "<name>_in"  or "in<i>"
// depending on whether the StageOp provides a register name. Passthrough i is
// named "<name>_out"/"<name>_in", or "pass<i>" for both when unnamed.
// If `withPipelinePrefix` is set and the pipeline base name is non-empty,
// every name gets a "<base>_" prefix. For other terminators (branch structure
// elided), the pipeline's output names are appended to outNames.
217 void getStageEgressNames(
size_t stageIndex, Operation *stageTerminator,
218 bool withPipelinePrefix,
219 StageEgressNames &egressNames) {
// NOTE(review): this local shadows the `pipelineName` data member.
220 StringAttr pipelineName;
221 if (withPipelinePrefix)
222 pipelineName = getPipelineBaseName();
224 if (
auto stageOp = dyn_cast<StageOp>(stageTerminator)) {
226 std::string assignedRegName, assignedOutName, assignedInName;
// Register names.
227 for (
size_t regi = 0; regi < stageOp.getRegisters().size(); ++regi) {
228 if (
auto regName = stageOp.getRegisterName(regi)) {
229 assignedRegName = regName.str();
230 assignedOutName = assignedRegName +
"_out";
231 assignedInName = assignedRegName +
"_in";
234 (
"stage" + Twine(stageIndex) +
"_reg" + Twine(regi)).str();
235 assignedOutName = (
"out" + Twine(regi)).str();
236 assignedInName = (
"in" + Twine(regi)).str();
239 if (pipelineName && !pipelineName.getValue().empty()) {
240 assignedRegName = pipelineName.str() +
"_" + assignedRegName;
241 assignedOutName = pipelineName.str() +
"_" + assignedOutName;
242 assignedInName = pipelineName.str() +
"_" + assignedInName;
245 egressNames.regNames.push_back(builder.getStringAttr(assignedRegName));
246 egressNames.outNames.push_back(builder.getStringAttr(assignedOutName));
247 egressNames.inNames.push_back(builder.getStringAttr(assignedInName));
// Passthrough names (these produce out/in names only, with no register name).
251 for (
size_t passi = 0; passi < stageOp.getPassthroughs().size();
253 if (
auto passName = stageOp.getPassthroughName(passi)) {
254 assignedOutName = (passName.strref() +
"_out").str();
255 assignedInName = (passName.strref() +
"_in").str();
// NOTE(review): unnamed passthroughs use the identical "pass<i>" string for
// both the out and in names, unlike registers (out<i>/in<i>) -- confirm this
// is intended.
257 assignedOutName = (
"pass" + Twine(passi)).str();
258 assignedInName = (
"pass" + Twine(passi)).str();
261 if (pipelineName && !pipelineName.getValue().empty()) {
262 assignedOutName = pipelineName.str() +
"_" + assignedOutName;
263 assignedInName = pipelineName.str() +
"_" + assignedInName;
266 egressNames.outNames.push_back(builder.getStringAttr(assignedOutName));
267 egressNames.inNames.push_back(builder.getStringAttr(assignedInName));
272 llvm::copy(pipeline.getOutputNames().getAsRange<StringAttr>(),
273 std::back_inserter(egressNames.outNames));
// Returns the naming prefix for stage `stageIdx` (used, e.g., to build the
// "<prefix>_valid" signal name); implemented by subclasses.
278 virtual StringAttr getStagePrefix(
size_t stageIdx) = 0;
// Returns the pipeline's name attribute when one is set (the return is
// elided); otherwise returns "p<pipelineID>".
283 StringAttr getPipelineBaseName() {
284 if (
auto nameAttr = pipeline.getNameAttr())
286 return StringAttr::get(pipeline.getContext(),
"p" + Twine(pipelineID));
// The pipeline being lowered.
296 ScheduledPipelineOp pipeline;
// If true, subclasses attach initial (power-on) values to stage-enable
// registers.
307 bool enablePowerOnValues;
// Pipeline base name used for stage and register naming; assigned by
// subclasses (e.g. in PipelineInlineLowering::run).
311 StringAttr pipelineName;
// PipelineInlineLowering: lowers a ScheduledPipelineOp by emitting all of its
// stages directly in front of the pipeline op in the enclosing block. The
// builder's insertion point is set to the pipeline op.
// NOTE(review): this is an excerpt with many original lines elided (see the
// gaps in the leading line numbers); it does not compile as shown.
314 class PipelineInlineLowering :
public PipelineLowering {
// Reuse the base-class constructor.
316 using PipelineLowering::PipelineLowering;
// Stage prefix: "<pipelineName>_stage<idx>" when the pipeline has a non-empty
// base name (the tail of the expression is elided), otherwise "stage<idx>".
318 StringAttr getStagePrefix(
size_t stageIdx)
override {
319 if (pipelineName && !pipelineName.getValue().empty())
320 return builder.getStringAttr(pipelineName.strref() +
"_stage" +
322 return builder.getStringAttr(
"stage" + Twine(stageIdx));
// Entry point of the lowering. Records the pipeline base name, rewires the
// inner inputs to the outer operands, seeds the stage-0 arguments from the
// pipeline operands, and lowers stages recursively from the entry stage.
// NOTE(review): failure handling and any cleanup of the pipeline op are
// elided in this excerpt.
325 LogicalResult
run()
override {
326 pipelineName = getPipelineBaseName();
// Uses of the pipeline's inner (block) inputs now refer to its outer operands.
329 for (
auto [outer, inner] :
330 llvm::zip(pipeline.getInputs(), pipeline.getInnerInputs()))
331 inner.replaceAllUsesWith(outer);
// Emit all lowered logic in front of the pipeline op.
335 builder.setInsertionPoint(pipeline);
337 args.data = pipeline.getInnerInputs();
338 args.enable = pipeline.getGo();
339 args.clock = pipeline.getClock();
340 args.reset = pipeline.getReset();
341 args.stall = pipeline.getStall();
342 if (failed(lowerStage(pipeline.getEntryStage(), args, 0)))
// Lowers `stage`, then recursively lowers every subsequent stage:
//  - non-entry stages: the stage block's data arguments are replaced by the
//    incoming `args.data`;
//  - stage 0 uses `args.enable` directly; later stages register the incoming
//    enable in a "<prefix>_enable" register whose form depends on the stage
//    kind;
//  - StageOp terminator: emits the stage body, forwards the registers followed
//    by the passthroughs as the next stage's data, and recurses with
//    stageIndex + 1;
//  - ReturnOp terminator: replaces the pipeline's results with the return
//    inputs followed by the final valid signal.
// The unnamed ArrayRef<Attribute> (input names) parameter is unused.
// NOTE(review): recursion depth grows linearly with the number of stages.
350 FailureOr<StageReturns>
351 lowerStage(Block *stage, StageArgs args,
size_t stageIndex,
352 llvm::ArrayRef<Attribute> = {})
override {
// Restores the caller's insertion point when this call returns.
353 OpBuilder::InsertionGuard guard(builder);
354 Operation *terminator = stage->getTerminator();
355 Location loc = terminator->getLoc();
// Bind this stage's data arguments to the values from the previous stage.
357 if (stage != pipeline.getEntryStage()) {
359 for (
auto [vInput, vArg] :
360 llvm::zip(pipeline.getStageDataArgs(stage), args.data))
361 vInput.replaceAllUsesWith(vArg);
// Compute the enable signal for this stage.
370 StageKind stageKind = pipeline.getStageKind(stageIndex);
372 if (stageIndex == 0) {
373 stageEnabled = args.enable;
// Later stages register the incoming enable. The register op types, the
// switch header and the enableRegResetVal definition are elided here.
375 auto stageRegPrefix = getStagePrefix(stageIndex);
376 auto enableRegName = (stageRegPrefix.strref() +
"_enable").str();
381 case StageKind::Continuous:
383 case StageKind::NonStallable:
385 loc, args.enable, args.clock, args.reset, enableRegResetVal,
388 case StageKind::Stallable:
390 loc, args.enable, args.clock,
392 enableRegResetVal, enableRegName);
// Runoff stages require an LNS signal (the assert condition is elided).
394 case StageKind::Runoff:
396 "Expected an LNS signal if this was a runoff stage");
398 loc, args.enable, args.clock,
402 args.reset, enableRegResetVal, enableRegName);
// Optionally give the enable register an initial (power-on) value (the
// TypeSwitch cases are elided).
406 if (enablePowerOnValues) {
407 llvm::TypeSwitch<Operation *, void>(stageEnabled.getDefiningOp())
409 op.getInitialValueMutable().assign(
// From here on, this stage and its users see the computed enable.
417 args.enable = stageEnabled;
418 pipeline.getStageEnableSignal(stage).replaceAllUsesWith(stageEnabled);
// Egress names are only computed for a StageOp terminator (the guard is
// elided).
421 auto nextStage = dyn_cast<StageOp>(terminator);
422 StageEgressNames egressNames;
424 getStageEgressNames(stageIndex, nextStage,
// Emit the stage body directly in front of the pipeline op.
428 builder.setInsertionPoint(pipeline);
429 StageReturns stageRets =
430 emitStageBody(stage, args, egressNames.regNames, stageIndex);
// Next-stage data: the registers followed by the passthroughs.
434 SmallVector<Value> nextStageArgs;
435 llvm::append_range(nextStageArgs, stageRets.regs);
436 llvm::append_range(nextStageArgs, stageRets.passthroughs);
437 args.enable = stageRets.valid;
// Propagate an LNS enable when the stage produced one (lines partly elided).
438 if (stageRets.lnsEn) {
441 args.lnsEn = stageRets.lnsEn;
443 args.data = nextStageArgs;
444 return lowerStage(nextStage.getNextStage(), args, stageIndex + 1);
// Final stage: the pipeline results become the return inputs followed by the
// valid signal.
448 auto returnOp = cast<pipeline::ReturnOp>(stage->getTerminator());
449 llvm::SmallVector<Value> pipelineReturns;
450 llvm::append_range(pipelineReturns, returnOp.getInputs());
452 pipelineReturns.push_back(stageRets.valid);
453 pipeline.replaceAllUsesWith(pipelineReturns);
// The PipelineToHW conversion pass. Its base class and options come from the
// generated PipelineToHWBase (see GEN_PASS_DEF_PIPELINETOHW above).
// NOTE(review): the remaining members are elided in this excerpt.
464 struct PipelineToHWPass
465 :
public circt::impl::PipelineToHWBase<PipelineToHWPass> {
// Inherit the generated constructors, including the one taking pass options.
466 using PipelineToHWBase::PipelineToHWBase;
467 void runOnOperation()
override;
// Runs the per-module lowering on every hw.module op directly nested in the
// pass's root operation.
476 void PipelineToHWPass::runOnOperation() {
477 for (
auto hwMod : getOperation().getOps<hw::HWModuleOp>())
478 runOnHWModule(hwMod);
// Body of the per-module lowering (its signature is elided in this excerpt).
// Each ScheduledPipelineOp in `mod` is lowered with PipelineInlineLowering,
// using `pipelinesSeen` as the pipeline ID and forwarding the pass's
// clockGateRegs / enablePowerOnValues options. The early-inc range keeps the
// iteration valid, presumably because the lowering erases the pipeline op.
// NOTE(review): where pipelinesSeen is incremented and how a failed lowering
// is reported are elided -- confirm.
482 OpBuilder builder(&getContext());
487 size_t pipelinesSeen = 0;
489 llvm::make_early_inc_range(mod.getOps<ScheduledPipelineOp>())) {
490 if (failed(PipelineInlineLowering(pipelinesSeen, pipeline, builder,
491 clockGateRegs, enablePowerOnValues)
// Factory for the PipelineToHW pass (the function-name line is elided in this
// excerpt); forwards `options` to the pass constructor.
502 std::unique_ptr<mlir::Pass>
504 return std::make_unique<PipelineToHWPass>(options);
assert(baseType &&"element must be base type")
def create(data_type, value)
def create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
def create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a 'Not' gate on a value.
mlir::TypedValue< seq::ImmutableType > createConstantInitialValue(OpBuilder builder, Location loc, mlir::IntegerAttr attr)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createPipelineToHWPass(const PipelineToHWOptions &options={})
Create a Pipeline to HW conversion pass.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)