22 #include "mlir/IR/BuiltinAttributes.h"
23 #include "mlir/IR/ImplicitLocOpBuilder.h"
24 #include "llvm/ADT/PostOrderIterator.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
28 #define DEBUG_TYPE "firrtl-hoist-passthrough"
30 using namespace circt;
31 using namespace firrtl;
84 FConnectLike drivingConnect;
91 : drivingConnect(
connect), source(source) {
92 assert((isa<RefDriver, HWDriver>(*
this)));
103 bool canHoist()
const {
return isa<BlockArgument>(source.
getValue()); }
106 using PortMappingFn = llvm::function_ref<Value(
size_t)>;
110 Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder);
114 void finalize(ImplicitLocOpBuilder &
builder);
119 operator bool()
const {
return source; }
122 Value getDest()
const {
125 return const_cast<Driver *
>(
this)->drivingConnect.getDest();
129 bool drivesModuleArg()
const {
130 auto arg = dyn_cast<BlockArgument>(getDest());
131 assert(!arg || isa<firrtl::FModuleLike>(arg.getOwner()->getParentOp()));
136 bool drivesInstanceResult()
const {
137 return getDest().getDefiningOp<hw::HWInstanceLike>();
141 BlockArgument getDestBlockArg()
const {
142 assert(drivesModuleArg());
143 return dyn_cast<BlockArgument>(getDest());
147 OpResult getDestOpResult()
const {
148 assert(drivesInstanceResult());
149 return dyn_cast<OpResult>(getDest());
154 static size_t getIndex(Value v) {
155 if (
auto arg = dyn_cast<BlockArgument>(v))
156 return arg.getArgNumber();
157 auto result = dyn_cast<OpResult>(v);
159 return result.getResultNumber();
164 struct RefDriver :
public Driver {
165 using Driver::Driver;
167 static bool classof(
const Driver *t) {
return isa<RefValue>(t->getDest()); }
171 Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder);
173 static_assert(
sizeof(RefDriver) ==
sizeof(Driver),
174 "passed by value, no slicing");
179 struct HWDriver :
public Driver {
180 using Driver::Driver;
182 static bool classof(
const Driver *t) {
return !isa<RefValue>(t->getDest()); }
186 Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder);
188 static_assert(
sizeof(HWDriver) ==
sizeof(Driver),
189 "passed by value, no slicing");
192 template <
typename T>
193 static inline T &
operator<<(T &os, Driver &d) {
195 return os <<
"(null)";
196 return os << d.getDest() <<
" <-- " << d.drivingConnect <<
" <-- "
197 << d.source.getValue() <<
"@" << d.source.getFieldID();
214 Value Driver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder) {
215 return TypeSwitch<Driver *, Value>(
this)
216 .Case<RefDriver, HWDriver>(
217 [&](
auto *d) {
return d->remat(mapPortFn,
builder); })
221 void Driver::finalize(ImplicitLocOpBuilder &
builder) {
222 auto immSource = drivingConnect.getSrc();
223 auto dest = getDest();
224 assert(immSource.getType() == dest.getType() &&
225 "final connect must be strict");
226 if (dest.hasOneUse()) {
228 drivingConnect.erase();
229 }
else if (isa<BlockArgument>(immSource)) {
231 drivingConnect.erase();
232 dest.replaceAllUsesWith(immSource);
237 auto temp =
builder.create<WireOp>(immSource.getType());
238 dest.replaceAllUsesWith(temp.getDataRaw());
247 for (
auto *user : result.getUsers()) {
248 if (
auto rd = dyn_cast<RefDefineOp>(user); rd && rd.getDest() == result)
255 auto refVal = dyn_cast<RefValue>(v);
267 return RefDriver(rd, ref);
270 Value RefDriver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder) {
271 auto mappedSource = mapPortFn(getIndex(source.getValue()));
273 auto destType = getDest().getType();
274 if (newVal.getType() != destType)
275 newVal =
builder.create<RefCastOp>(destType, newVal);
286 auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op);
287 return symOp && symOp.getTargetResultIndex() && symOp.getInnerSymAttr();
291 if (
auto *op =
value.getDefiningOp())
293 auto arg = dyn_cast<BlockArgument>(
value);
294 auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
295 return (module.getPortSymbolAttr(arg.getArgNumber())) ||
300 auto baseValue = dyn_cast<FIRRTLBaseValue>(v);
307 if (!baseValue.getType().isPassive() || !baseValue.getType().isGround())
319 if (v.getParentBlock() != ref.getValue().getParentBlock() ||
320 v.getParentBlock() !=
connect->getBlock())
329 if (
auto fop = ref.getValue().getDefiningOp<Forceable>();
330 fop && fop.isForceable())
334 auto sourceType = type_dyn_cast<FIRRTLBaseType>(ref.getValue().getType());
337 if (!sourceType.isPassive())
341 baseValue.getType() &&
342 "unexpected type mismatch, cast or extension?");
347 Value HWDriver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder) {
348 auto mappedSource = mapPortFn(getIndex(source.getValue()));
363 MustDrivenBy() =
default;
364 MustDrivenBy(FModuleOp mod) { run(mod); }
367 Driver getDriverFor(Value v)
const {
return driverMap.lookup(v); }
371 Driver getCombinedDriverFor(Value v)
const {
372 Driver driver = driverMap.lookup(v);
379 SmallPtrSet<Value, 8> seen;
380 while ((cur = driverMap.lookup(cur.source.getValue()))) {
382 if (!seen.insert(cur.source.getValue()).second)
384 driver.source = cur.source.getSubField(driver.source.getFieldID());
388 LLVM_DEBUG(
llvm::dbgs() <<
"Found driver for " << v <<
" (chain length = "
389 << len <<
"): " << driver <<
"\n");
394 void run(FModuleOp mod) {
395 SmallVector<Value, 64> worklist(mod.getArguments());
397 DenseSet<Value> enqueued;
398 enqueued.insert(worklist.begin(), worklist.end());
400 while (!worklist.empty()) {
401 auto val = worklist.pop_back_val();
404 driverMap.insert({val, driver});
408 auto sourceVal = driver.source.getValue();
411 if (!enqueued.insert(sourceVal).second)
422 if (!isa<BlockArgument>(sourceVal) &&
423 !isa_and_nonnull<WireOp, InstanceOp>(sourceVal.getDefiningOp()))
426 worklist.push_back(sourceVal);
432 llvm::dbgs() <<
"Analyzed " << mod.getModuleName() <<
" and found "
433 << driverMap.size() <<
" drivers.\n";
439 void clear() { driverMap.clear(); }
442 void setIgnoreHWDrivers(
bool ignore) { ignoreHWDrivers = ignore; }
446 DenseMap<Value, Driver> driverMap;
447 bool ignoreHWDrivers =
false;
458 struct HoistPassthroughPass
459 :
public HoistPassthroughBase<HoistPassthroughPass> {
460 using HoistPassthroughBase::HoistPassthroughBase;
461 void runOnOperation()
override;
463 using HoistPassthroughBase::hoistHWDrivers;
467 void HoistPassthroughPass::runOnOperation() {
468 LLVM_DEBUG(
llvm::dbgs() <<
"===- Running HoistPassthrough Pass "
469 "------------------------------------------===\n");
470 auto &instanceGraph = getAnalysis<InstanceGraph>();
472 SmallVector<FModuleOp, 0> modules(llvm::make_filter_range(
474 llvm::post_order(&instanceGraph),
475 [](
auto *node) {
return dyn_cast<FModuleOp>(*node->getModule()); }),
476 [](
auto module) {
return module; }));
478 MustDrivenBy driverAnalysis;
479 driverAnalysis.setIgnoreHWDrivers(!hoistHWDrivers);
481 bool anyChanged =
false;
484 for (
auto module : modules) {
487 if (module.isPublic())
494 BitVector deadPorts(module.getNumPorts());
497 auto *igNode = instanceGraph.lookup(module);
500 driverAnalysis.clear();
501 driverAnalysis.run(module);
502 auto notNullAndCanHoist = [](
const Driver &d) ->
bool {
503 return d && d.canHoist();
505 SmallVector<Driver, 16> drivers(llvm::make_filter_range(
506 llvm::map_range(module.getArguments(),
507 [&driverAnalysis](
auto val) {
508 return driverAnalysis.getCombinedDriverFor(val);
510 notNullAndCanHoist));
521 for (
auto &driver : drivers) {
522 std::optional<size_t> deadPort;
524 auto destArg = driver.getDestBlockArg();
525 auto index = destArg.getArgNumber();
528 for (
auto *record : igNode->uses()) {
529 auto inst = cast<InstanceOp>(record->getInstance());
530 ImplicitLocOpBuilder
builder(inst.getLoc(), inst);
531 builder.setInsertionPointAfter(inst);
533 auto mappedDest = inst.getResult(index);
534 mappedDest.replaceAllUsesWith(driver.remat(
535 [&inst](
size_t index) { return inst.getResult(index); },
541 assert(deadPort.has_value());
543 assert(!deadPorts.test(*deadPort));
544 deadPorts.set(*deadPort);
547 TypeSwitch<Driver *, void>(&driver)
548 .Case<RefDriver>([&](
auto *) { ++numRefDrivers; })
549 .Case<HWDriver>([&](
auto *) { ++numHWDrivers; });
554 ImplicitLocOpBuilder
builder(module.getLoc(), module.getBody());
555 for (
auto &driver : drivers) {
557 builder.setLoc(driver.getDest().getLoc());
564 for (
auto *record : llvm::make_early_inc_range(igNode->uses())) {
565 auto inst = cast<InstanceOp>(record->getInstance());
566 ImplicitLocOpBuilder
builder(inst.getLoc(), inst);
568 assert(inst.getNumResults() == deadPorts.size());
569 auto newInst = inst.erasePorts(
builder, deadPorts);
570 instanceGraph.replaceInstance(inst, newInst);
575 module.erasePorts(deadPorts);
577 numUTurnsHoisted += deadPorts.count();
579 markAnalysesPreserved<InstanceGraph>();
582 markAllAnalysesPreserved();
586 std::unique_ptr<mlir::Pass>
588 auto pass = std::make_unique<HoistPassthroughPass>();
589 pass->hoistHWDrivers = hoistHWDrivers;
assert(baseType &&"element must be base type")
static RefDefineOp getRefDefine(Value result)
mlir::TypedValue< RefType > RefValue
static bool hasDontTouchOrInnerSymOnResult(Operation *op)
This class represents a reference to a specific field or element of an aggregate value.
Value getValue() const
Get the Value which created this location.
bool hasDontTouch() const
firrtl.transforms.DontTouchAnnotation
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
Caching version of getFieldRefFromValue.
void printStats(llvm::raw_ostream &os) const
void verify() const
Verify cached fieldRefs against firrtl::getFieldRefFromValue.
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Caching version of getFieldRefFromValue.
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
T & operator<<(T &os, FIRVersion version)
std::unique_ptr< mlir::Pass > createHoistPassthroughPass(bool hoistHWDrivers=true)
This is the pass constructor.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
StrictConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
uint64_t getMaxFieldID(Type)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & dbgs()