18#include "mlir/IR/Builders.h"
19#include "mlir/Pass/Pass.h"
20#include "llvm/ADT/TypeSwitch.h"
23#define GEN_PASS_DEF_PIPELINETOHW
24#include "circt/Conversion/Passes.h.inc"
29using namespace pipeline;
33class PipelineLowering {
35 PipelineLowering(
size_t pipelineID, ScheduledPipelineOp pipeline,
36 OpBuilder &builder,
bool clockGateRegs,
37 bool enablePowerOnValues)
38 : pipelineID(pipelineID), pipeline(pipeline), builder(builder),
39 clockGateRegs(clockGateRegs), enablePowerOnValues(enablePowerOnValues) {
40 parentClk = pipeline.getClock();
41 parentRst = pipeline.getReset();
44 virtual ~PipelineLowering() =
default;
46 virtual LogicalResult
run() = 0;
62 llvm::SmallVector<Value> regs;
63 llvm::SmallVector<Value> passthroughs;
72 virtual FailureOr<StageReturns>
73 lowerStage(
Block *stage, StageArgs args,
size_t stageIndex,
74 llvm::ArrayRef<Attribute> inputNames = {}) = 0;
76 StageReturns emitStageBody(
Block *stage, StageArgs args,
77 llvm::ArrayRef<Attribute> registerNames,
78 size_t stageIndex = -1) {
79 assert(args.enable &&
"enable not set");
80 auto *terminator = stage->getTerminator();
83 for (
auto &op : llvm::make_early_inc_range(*stage)) {
84 if (&op == terminator)
87 if (
auto latencyOp = dyn_cast<LatencyOp>(op)) {
92 Block *latencyOpBody = latencyOp.getBodyBlock();
94 llvm::make_early_inc_range(latencyOpBody->without_terminator()))
95 innerOp.moveBefore(builder.getInsertionBlock(),
96 builder.getInsertionPoint());
97 latencyOp.replaceAllUsesWith(
98 latencyOpBody->getTerminator()->getOperands());
101 op.moveBefore(builder.getInsertionBlock(), builder.getInsertionPoint());
105 auto loc = terminator->getLoc();
107 auto getOrSetNotStalled = [&]() {
109 notStalled = comb::createOrFoldNot(loc, args.stall, builder);
116 StageKind stageKind = pipeline.getStageKind(stageIndex);
118 StringAttr validSignalName =
119 builder.getStringAttr(getStagePrefix(stageIndex).strref() +
"_valid");
121 case StageKind::Continuous:
123 case StageKind::NonStallable:
124 stageValid = args.enable;
126 case StageKind::Stallable:
128 comb::AndOp::create(builder, loc, args.enable, getOrSetNotStalled());
129 stageValid.getDefiningOp()->setAttr(
"sv.namehint", validSignalName);
131 case StageKind::Runoff:
132 assert(args.lnsEn &&
"Expected an LNS signal if this was a runoff stage");
133 stageValid = comb::AndOp::create(
134 builder, loc, args.enable,
135 comb::OrOp::create(builder, loc, args.lnsEn, getOrSetNotStalled()));
136 stageValid.getDefiningOp()->setAttr(
"sv.namehint", validSignalName);
141 auto stageOp = dyn_cast<StageOp>(terminator);
143 assert(isa<ReturnOp>(terminator) &&
"expected ReturnOp");
147 rets.passthroughs = terminator->getOperands();
148 rets.valid = stageValid;
152 assert(registerNames.size() == stageOp.getRegisters().size() &&
153 "register names and registers must be the same size");
155 bool isStallablePipeline = stageKind != StageKind::Continuous;
156 Value notStalledClockGate;
157 if (this->clockGateRegs) {
159 notStalledClockGate = seq::ClockGateOp::create(
160 builder, loc, args.clock, stageValid, Value(),
164 for (
auto it : llvm::enumerate(stageOp.getRegisters())) {
165 auto regIdx = it.index();
166 auto regIn = it.value();
168 StringAttr regName = cast<StringAttr>(registerNames[regIdx]);
170 if (this->clockGateRegs) {
172 Value currClockGate = notStalledClockGate;
173 for (
auto hierClockGateEnable : stageOp.getClockGatesForReg(regIdx)) {
175 currClockGate = seq::ClockGateOp::create(
176 builder, loc, currClockGate, hierClockGateEnable,
181 currClockGate, regName);
186 if (isStallablePipeline) {
188 builder, stageOp->getLoc(), regIn, args.clock, stageValid,
192 args.clock, regName);
195 rets.regs.push_back(dataReg);
198 rets.valid = stageValid;
199 if (stageKind == StageKind::NonStallable)
200 rets.lnsEn = args.enable;
202 rets.passthroughs = stageOp.getPassthroughs();
209 struct StageEgressNames {
210 llvm::SmallVector<Attribute> regNames;
211 llvm::SmallVector<Attribute> outNames;
212 llvm::SmallVector<Attribute> inNames;
218 void getStageEgressNames(
size_t stageIndex, Operation *stageTerminator,
219 bool withPipelinePrefix,
220 StageEgressNames &egressNames) {
221 StringAttr pipelineName;
222 if (withPipelinePrefix)
223 pipelineName = getPipelineBaseName();
225 if (
auto stageOp = dyn_cast<StageOp>(stageTerminator)) {
227 std::string assignedRegName, assignedOutName, assignedInName;
228 for (
size_t regi = 0; regi < stageOp.getRegisters().size(); ++regi) {
229 if (
auto regName = stageOp.getRegisterName(regi)) {
230 assignedRegName = regName.str();
231 assignedOutName = assignedRegName +
"_out";
232 assignedInName = assignedRegName +
"_in";
235 (
"stage" + Twine(stageIndex) +
"_reg" + Twine(regi)).str();
236 assignedOutName = (
"out" + Twine(regi)).str();
237 assignedInName = (
"in" + Twine(regi)).str();
240 if (pipelineName && !pipelineName.getValue().empty()) {
241 assignedRegName = pipelineName.str() +
"_" + assignedRegName;
242 assignedOutName = pipelineName.str() +
"_" + assignedOutName;
243 assignedInName = pipelineName.str() +
"_" + assignedInName;
246 egressNames.regNames.push_back(builder.getStringAttr(assignedRegName));
247 egressNames.outNames.push_back(builder.getStringAttr(assignedOutName));
248 egressNames.inNames.push_back(builder.getStringAttr(assignedInName));
252 for (
size_t passi = 0; passi < stageOp.getPassthroughs().size();
254 if (
auto passName = stageOp.getPassthroughName(passi)) {
255 assignedOutName = (passName.strref() +
"_out").str();
256 assignedInName = (passName.strref() +
"_in").str();
258 assignedOutName = (
"pass" + Twine(passi)).str();
259 assignedInName = (
"pass" + Twine(passi)).str();
262 if (pipelineName && !pipelineName.getValue().empty()) {
263 assignedOutName = pipelineName.str() +
"_" + assignedOutName;
264 assignedInName = pipelineName.str() +
"_" + assignedInName;
267 egressNames.outNames.push_back(builder.getStringAttr(assignedOutName));
268 egressNames.inNames.push_back(builder.getStringAttr(assignedInName));
273 llvm::copy(pipeline.getOutputNames().getAsRange<StringAttr>(),
274 std::back_inserter(egressNames.outNames));
279 virtual StringAttr getStagePrefix(
size_t stageIdx) = 0;
284 StringAttr getPipelineBaseName() {
285 if (
auto nameAttr = pipeline.getNameAttr())
287 return StringAttr::get(pipeline.getContext(),
"p" + Twine(pipelineID));
297 ScheduledPipelineOp pipeline;
308 bool enablePowerOnValues;
312 StringAttr pipelineName;
315class PipelineInlineLowering :
public PipelineLowering {
317 using PipelineLowering::PipelineLowering;
319 StringAttr getStagePrefix(
size_t stageIdx)
override {
320 if (pipelineName && !pipelineName.getValue().empty())
321 return builder.getStringAttr(pipelineName.strref() +
"_stage" +
323 return builder.getStringAttr(
"stage" + Twine(stageIdx));
326 LogicalResult
run()
override {
327 pipelineName = getPipelineBaseName();
330 for (
auto [outer, inner] :
331 llvm::zip(pipeline.getInputs(), pipeline.getInnerInputs()))
332 inner.replaceAllUsesWith(outer);
336 builder.setInsertionPoint(pipeline);
338 args.data = pipeline.getInnerInputs();
339 args.enable = pipeline.getGo();
340 args.clock = pipeline.getClock();
341 args.reset = pipeline.getReset();
342 args.stall = pipeline.getStall();
343 if (failed(lowerStage(pipeline.getEntryStage(), args, 0)))
351 FailureOr<StageReturns>
352 lowerStage(Block *stage, StageArgs args,
size_t stageIndex,
353 llvm::ArrayRef<Attribute> = {})
override {
354 OpBuilder::InsertionGuard guard(builder);
355 Operation *terminator = stage->getTerminator();
356 Location loc = terminator->getLoc();
358 if (stage != pipeline.getEntryStage()) {
360 for (
auto [vInput, vArg] :
361 llvm::zip(pipeline.getStageDataArgs(stage), args.
data))
362 vInput.replaceAllUsesWith(vArg);
374 StageKind stageKind = pipeline.getStageKind(stageIndex);
376 if (stageIndex == 0) {
377 stageEnabled = args.enable;
379 auto stageRegPrefix = getStagePrefix(stageIndex);
380 auto enableRegName = (stageRegPrefix.strref() +
"_enable").str();
382 Value enableRegResetVal;
389 case StageKind::Continuous:
391 case StageKind::NonStallable:
393 args.clock, args.reset,
394 enableRegResetVal, enableRegName);
396 case StageKind::Stallable:
398 builder, loc, args.enable, args.clock,
399 comb::createOrFoldNot(loc, args.stall, builder), args.reset,
400 enableRegResetVal, enableRegName);
402 case StageKind::Runoff:
404 "Expected an LNS signal if this was a runoff stage");
406 builder, loc, args.enable, args.clock,
407 comb::OrOp::create(builder, loc, args.lnsEn,
408 comb::createOrFoldNot(loc, args.stall, builder)),
409 args.reset, enableRegResetVal, enableRegName);
413 if (enablePowerOnValues) {
414 llvm::TypeSwitch<Operation *, void>(stageEnabled.getDefiningOp())
416 op.getInitialValueMutable().assign(
419 builder.getIntegerAttr(builder.getI1Type(),
420 APInt(1, 0,
false))));
426 args.enable = stageEnabled;
427 pipeline.getStageEnableSignal(stage).replaceAllUsesWith(stageEnabled);
430 auto nextStage = dyn_cast<StageOp>(terminator);
431 StageEgressNames egressNames;
433 getStageEgressNames(stageIndex, nextStage,
437 builder.setInsertionPoint(pipeline);
438 StageReturns stageRets =
439 emitStageBody(stage, args, egressNames.regNames, stageIndex);
443 SmallVector<Value> nextStageArgs;
444 llvm::append_range(nextStageArgs, stageRets.regs);
445 llvm::append_range(nextStageArgs, stageRets.passthroughs);
446 args.enable = stageRets.valid;
447 if (stageRets.lnsEn) {
450 args.lnsEn = stageRets.lnsEn;
452 args.data = nextStageArgs;
453 return lowerStage(nextStage.getNextStage(), args, stageIndex + 1);
457 auto returnOp = cast<pipeline::ReturnOp>(stage->getTerminator());
458 llvm::SmallVector<Value> pipelineReturns;
459 llvm::append_range(pipelineReturns, returnOp.getInputs());
461 pipelineReturns.push_back(stageRets.valid);
462 pipeline.replaceAllUsesWith(pipelineReturns);
473struct PipelineToHWPass
474 :
public circt::impl::PipelineToHWBase<PipelineToHWPass> {
475 using PipelineToHWBase::PipelineToHWBase;
476 void runOnOperation()
override;
485void PipelineToHWPass::runOnOperation() {
486 for (
auto hwMod : getOperation().getOps<
hw::HWModuleOp>())
487 runOnHWModule(hwMod);
491 OpBuilder builder(&getContext());
496 size_t pipelinesSeen = 0;
498 llvm::make_early_inc_range(mod.getOps<ScheduledPipelineOp>())) {
499 if (failed(PipelineInlineLowering(pipelinesSeen, pipeline, builder,
500 clockGateRegs, enablePowerOnValues)
511std::unique_ptr<mlir::Pass>
513 return std::make_unique<PipelineToHWPass>(options);
assert(baseType &&"element must be base type")
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
mlir::TypedValue< seq::ImmutableType > createConstantInitialValue(OpBuilder builder, Location loc, mlir::IntegerAttr attr)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createPipelineToHWPass(const PipelineToHWOptions &options={})
Create an SCF to Calyx conversion pass.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)