24 #include "mlir/IR/BuiltinAttributes.h"
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "llvm/ADT/PostOrderIterator.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/FormatVariadic.h"
30 #define DEBUG_TYPE "firrtl-hoist-passthrough"
32 using namespace circt;
33 using namespace firrtl;
86 FConnectLike drivingConnect;
93 : drivingConnect(
connect), source(source) {
94 assert((isa<RefDriver, HWDriver>(*
this)));
105 bool canHoist()
const {
return isa<BlockArgument>(source.
getValue()); }
108 using PortMappingFn = llvm::function_ref<Value(
size_t)>;
112 Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder);
116 void finalize(ImplicitLocOpBuilder &
builder);
121 operator bool()
const {
return source; }
124 Value getDest()
const {
127 return const_cast<Driver *
>(
this)->drivingConnect.getDest();
131 bool drivesModuleArg()
const {
132 auto arg = dyn_cast<BlockArgument>(getDest());
133 assert(!arg || isa<firrtl::FModuleLike>(arg.getOwner()->getParentOp()));
138 bool drivesInstanceResult()
const {
139 return getDest().getDefiningOp<hw::HWInstanceLike>();
143 BlockArgument getDestBlockArg()
const {
144 assert(drivesModuleArg());
145 return dyn_cast<BlockArgument>(getDest());
149 OpResult getDestOpResult()
const {
150 assert(drivesInstanceResult());
151 return dyn_cast<OpResult>(getDest());
156 static size_t getIndex(Value v) {
157 if (
auto arg = dyn_cast<BlockArgument>(v))
158 return arg.getArgNumber();
159 auto result = dyn_cast<OpResult>(v);
161 return result.getResultNumber();
166 struct RefDriver :
public Driver {
167 using Driver::Driver;
169 static bool classof(
const Driver *t) {
return isa<RefValue>(t->getDest()); }
173 Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder);
175 static_assert(
sizeof(RefDriver) ==
sizeof(Driver),
176 "passed by value, no slicing");
181 struct HWDriver :
public Driver {
182 using Driver::Driver;
184 static bool classof(
const Driver *t) {
return !isa<RefValue>(t->getDest()); }
188 Value remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder);
190 static_assert(
sizeof(HWDriver) ==
sizeof(Driver),
191 "passed by value, no slicing");
194 template <
typename T>
195 static inline T &
operator<<(T &os, Driver &d) {
197 return os <<
"(null)";
198 return os << d.getDest() <<
" <-- " << d.drivingConnect <<
" <-- "
199 << d.source.getValue() <<
"@" << d.source.getFieldID();
216 Value Driver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder) {
217 return TypeSwitch<Driver *, Value>(
this)
218 .Case<RefDriver, HWDriver>(
219 [&](
auto *d) {
return d->remat(mapPortFn,
builder); })
223 void Driver::finalize(ImplicitLocOpBuilder &
builder) {
224 auto immSource = drivingConnect.getSrc();
225 auto dest = getDest();
226 assert(immSource.getType() == dest.getType() &&
227 "final connect must be strict");
228 if (dest.hasOneUse()) {
230 drivingConnect.erase();
231 }
else if (isa<BlockArgument>(immSource)) {
233 drivingConnect.erase();
234 dest.replaceAllUsesWith(immSource);
239 auto temp =
builder.create<WireOp>(immSource.getType());
240 dest.replaceAllUsesWith(temp.getDataRaw());
249 for (
auto *user : result.getUsers()) {
250 if (
auto rd = dyn_cast<RefDefineOp>(user); rd && rd.getDest() == result)
257 auto refVal = dyn_cast<RefValue>(v);
269 return RefDriver(rd, ref);
272 Value RefDriver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder) {
273 auto mappedSource = mapPortFn(getIndex(source.getValue()));
275 auto destType = getDest().getType();
276 if (newVal.getType() != destType)
277 newVal =
builder.create<RefCastOp>(destType, newVal);
288 auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op);
289 return symOp && symOp.getTargetResultIndex() && symOp.getInnerSymAttr();
293 if (
auto *op = value.getDefiningOp())
295 auto arg = dyn_cast<BlockArgument>(value);
296 auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
297 return (module.getPortSymbolAttr(arg.getArgNumber())) ||
302 auto baseValue = dyn_cast<FIRRTLBaseValue>(v);
309 if (!baseValue.getType().isPassive() || !baseValue.getType().isGround())
321 if (v.getParentBlock() != ref.getValue().getParentBlock() ||
322 v.getParentBlock() !=
connect->getBlock())
331 if (
auto fop = ref.getValue().getDefiningOp<Forceable>();
332 fop && fop.isForceable())
336 auto sourceType = type_dyn_cast<FIRRTLBaseType>(ref.getValue().getType());
339 if (!sourceType.isPassive())
343 baseValue.getType() &&
344 "unexpected type mismatch, cast or extension?");
349 Value HWDriver::remat(PortMappingFn mapPortFn, ImplicitLocOpBuilder &
builder) {
350 auto mappedSource = mapPortFn(getIndex(source.getValue()));
365 MustDrivenBy() =
default;
366 MustDrivenBy(FModuleOp mod) { run(mod); }
369 Driver getDriverFor(Value v)
const {
return driverMap.lookup(v); }
373 Driver getCombinedDriverFor(Value v)
const {
374 Driver driver = driverMap.lookup(v);
381 SmallPtrSet<Value, 8> seen;
382 while ((cur = driverMap.lookup(cur.source.getValue()))) {
384 if (!seen.insert(cur.source.getValue()).second)
386 driver.source = cur.source.getSubField(driver.source.getFieldID());
390 LLVM_DEBUG(llvm::dbgs() <<
"Found driver for " << v <<
" (chain length = "
391 << len <<
"): " << driver <<
"\n");
396 void run(FModuleOp mod) {
397 SmallVector<Value, 64> worklist(mod.getArguments());
399 DenseSet<Value> enqueued;
400 enqueued.insert(worklist.begin(), worklist.end());
402 while (!worklist.empty()) {
403 auto val = worklist.pop_back_val();
406 driverMap.insert({val, driver});
410 auto sourceVal = driver.source.getValue();
413 if (!enqueued.insert(sourceVal).second)
424 if (!isa<BlockArgument>(sourceVal) &&
425 !isa_and_nonnull<WireOp, InstanceOp>(sourceVal.getDefiningOp()))
428 worklist.push_back(sourceVal);
434 llvm::dbgs() <<
"Analyzed " << mod.getModuleName() <<
" and found "
435 << driverMap.size() <<
" drivers.\n";
441 void clear() { driverMap.clear(); }
444 void setIgnoreHWDrivers(
bool ignore) { ignoreHWDrivers = ignore; }
448 DenseMap<Value, Driver> driverMap;
449 bool ignoreHWDrivers =
false;
460 struct HoistPassthroughPass
461 :
public HoistPassthroughBase<HoistPassthroughPass> {
462 using HoistPassthroughBase::HoistPassthroughBase;
463 void runOnOperation()
override;
465 using HoistPassthroughBase::hoistHWDrivers;
469 void HoistPassthroughPass::runOnOperation() {
471 auto &instanceGraph = getAnalysis<InstanceGraph>();
473 SmallVector<FModuleOp, 0> modules(llvm::make_filter_range(
475 llvm::post_order(&instanceGraph),
476 [](
auto *node) {
return dyn_cast<FModuleOp>(*node->getModule()); }),
477 [](
auto module) {
return module; }));
479 MustDrivenBy driverAnalysis;
480 driverAnalysis.setIgnoreHWDrivers(!hoistHWDrivers);
482 bool anyChanged =
false;
485 for (
auto module : modules) {
488 if (module.isPublic())
495 BitVector deadPorts(module.getNumPorts());
498 auto *igNode = instanceGraph.lookup(module);
501 driverAnalysis.clear();
502 driverAnalysis.run(module);
503 auto notNullAndCanHoist = [](
const Driver &d) ->
bool {
504 return d && d.canHoist();
506 SmallVector<Driver, 16> drivers(llvm::make_filter_range(
507 llvm::map_range(module.getArguments(),
508 [&driverAnalysis](
auto val) {
509 return driverAnalysis.getCombinedDriverFor(val);
511 notNullAndCanHoist));
522 for (
auto &driver : drivers) {
523 std::optional<size_t> deadPort;
525 auto destArg = driver.getDestBlockArg();
526 auto index = destArg.getArgNumber();
529 for (
auto *record : igNode->uses()) {
530 auto inst = cast<InstanceOp>(record->getInstance());
531 ImplicitLocOpBuilder
builder(inst.getLoc(), inst);
532 builder.setInsertionPointAfter(inst);
534 auto mappedDest = inst.getResult(index);
535 mappedDest.replaceAllUsesWith(driver.remat(
536 [&inst](
size_t index) { return inst.getResult(index); },
542 assert(deadPort.has_value());
544 assert(!deadPorts.test(*deadPort));
545 deadPorts.set(*deadPort);
548 TypeSwitch<Driver *, void>(&driver)
549 .Case<RefDriver>([&](
auto *) { ++numRefDrivers; })
550 .Case<HWDriver>([&](
auto *) { ++numHWDrivers; });
555 ImplicitLocOpBuilder
builder(module.getLoc(), module.getBody());
556 for (
auto &driver : drivers) {
558 builder.setLoc(driver.getDest().getLoc());
565 for (
auto *record : llvm::make_early_inc_range(igNode->uses())) {
566 auto inst = cast<InstanceOp>(record->getInstance());
567 ImplicitLocOpBuilder
builder(inst.getLoc(), inst);
569 assert(inst.getNumResults() == deadPorts.size());
570 auto newInst = inst.erasePorts(
builder, deadPorts);
571 instanceGraph.replaceInstance(inst, newInst);
576 module.erasePorts(deadPorts);
578 numUTurnsHoisted += deadPorts.count();
580 markAnalysesPreserved<InstanceGraph>();
583 markAllAnalysesPreserved();
587 std::unique_ptr<mlir::Pass>
589 auto pass = std::make_unique<HoistPassthroughPass>();
590 pass->hoistHWDrivers = hoistHWDrivers;
assert(baseType &&"element must be base type")
static RefDefineOp getRefDefine(Value result)
mlir::TypedValue< RefType > RefValue
static bool hasDontTouchOrInnerSymOnResult(Operation *op)
This class represents a reference to a specific field or element of an aggregate value.
Value getValue() const
Get the Value which created this location.
bool hasDontTouch() const
firrtl.transforms.DontTouchAnnotation
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
Caching version of getFieldRefFromValue.
void printStats(llvm::raw_ostream &os) const
void verify() const
Verify cached fieldRefs against firrtl::getFieldRefFromValue.
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Caching version of getFieldRefFromValue.
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
T & operator<<(T &os, FIRVersion version)
std::unique_ptr< mlir::Pass > createHoistPassthroughPass(bool hoistHWDrivers=true)
This is the pass constructor.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
StrictConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
uint64_t getMaxFieldID(Type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.