19#include "mlir/Pass/Pass.h" 
   20#include "llvm/ADT/MapVector.h" 
   21#include "llvm/ADT/STLExtras.h" 
   25#define GEN_PASS_DEF_EXPANDWHENS 
   26#include "circt/Dialect/FIRRTL/Passes.h.inc" 
   31using namespace firrtl;
 
   35static void mergeBlock(Block &destination, Block::iterator insertPoint,
 
   37  destination.getOperations().splice(insertPoint, source.getOperations());
 
 
   45template <
typename KeyT, 
typename ValueT>
 
   47  using ScopeT = 
typename llvm::MapVector<KeyT, ValueT>;
 
   48  using StackT = 
typename llvm::SmallVector<ScopeT, 3>;
 
   52             typename ScopeT::iterator 
scopeIt)
 
 
 
   90    for (
auto i = 
mapStack.size(); i > 0; --i) {
 
   92      auto it = map.find(key);
 
 
 
  127template <
typename ConcreteT>
 
  136  LastConnectResolver(
ScopedDriverMap &driverMap) : driverMap(driverMap) {}
 
  147  bool recordConnect(
FieldRef dest, Operation *connection) {
 
  149    auto itAndInserted = driverMap.getLastScope().insert({dest, connection});
 
  150    if (isStaticSingleConnect(connection)) {
 
  152      assert(itAndInserted.second || !itAndInserted.first->second);
 
  153      if (!itAndInserted.second)
 
  154        itAndInserted.first->second = connection;
 
  157    assert(isLastConnect(connection));
 
  158    if (!std::get<1>(itAndInserted)) {
 
  159      auto iterator = std::get<0>(itAndInserted);
 
  160      auto changed = 
false;
 
  163      if (
auto *oldConnect = iterator->second) {
 
  167      iterator->second = connection;
 
  175  static Value getDestinationValue(Operation *op) {
 
  176    return cast<FConnectLike>(op).getDest();
 
  181  static Value getConnectedValue(Operation *op) {
 
  182    return cast<FConnectLike>(op).getSrc();
 
  186  static bool isStaticSingleConnect(Operation *op) {
 
  187    return cast<FConnectLike>(op).hasStaticSingleConnectBehavior();
 
  194  static bool isLastConnect(Operation *op) {
 
  195    return cast<FConnectLike>(op).hasLastConnectBehavior();
 
  200  void declareSinks(Value value, 
Flow flow, 
bool local = 
false) {
 
  201    auto type = value.getType();
 
  205    std::function<void(Type, 
Flow, 
bool)> declare = [&](Type type, 
Flow flow,
 
  208      if (
auto classType = type_dyn_cast<ClassType>(type)) {
 
  212          for (
auto &element : classType.getElements()) {
 
  214            if (element.direction == Direction::Out)
 
  215              declare(element.type, flow, 
false);
 
  217              declare(element.type, 
swapFlow(flow), 
false);
 
  222          if (flow != Flow::Source)
 
  223            driverMap[{value, 
id}] = 
nullptr;
 
  226          id += classType.getMaxFieldID();
 
  232      if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
 
  233        for (
auto &element : bundleType.getElements()) {
 
  236            declare(element.type, 
swapFlow(flow), 
false);
 
  238            declare(element.type, flow, 
false);
 
  244      if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
 
  245        for (
unsigned i = 0; i < vectorType.getNumElements(); ++i) {
 
  247          declare(vectorType.getElementType(), flow, 
false);
 
  253      if (
auto analogType = type_dyn_cast<AnalogType>(type))
 
  258      if (flow != Flow::Source)
 
  259        driverMap[{value, 
id}] = 
nullptr;
 
  262    declare(type, flow, local);
 
  267  ConnectOp flattenConditionalConnections(OpBuilder &b, Location loc,
 
  268                                          Value dest, Value cond,
 
  269                                          Operation *whenTrueConn,
 
  270                                          Operation *whenFalseConn) {
 
  271    assert(isLastConnect(whenTrueConn) && isLastConnect(whenFalseConn));
 
  273        b.getFusedLoc({loc, whenTrueConn->getLoc(), whenFalseConn->getLoc()});
 
  274    auto whenTrue = getConnectedValue(whenTrueConn);
 
  276        isa_and_nonnull<InvalidValueOp>(whenTrue.getDefiningOp());
 
  277    auto whenFalse = getConnectedValue(whenFalseConn);
 
  278    auto falseIsInvalid =
 
  279        isa_and_nonnull<InvalidValueOp>(whenFalse.getDefiningOp());
 
  286    Value newValue = whenTrue;
 
  287    if (trueIsInvalid == falseIsInvalid)
 
  288      newValue = b.createOrFold<MuxPrimOp>(fusedLoc, cond, whenTrue, whenFalse);
 
  289    else if (trueIsInvalid)
 
  290      newValue = whenFalse;
 
  291    return ConnectOp::create(b, loc, dest, newValue);
 
  294  void visitDecl(WireOp op) { declareSinks(op.getResult(), Flow::Duplex); }
 
  298  void foreachSubelement(OpBuilder &builder, Value value,
 
  299                         llvm::function_ref<
void(Value)> fn) {
 
  301        .template Case<BundleType>([&](BundleType bundle) {
 
  302          for (
auto i : llvm::seq(0u, (
unsigned)bundle.getNumElements())) {
 
  304                SubfieldOp::create(builder, value.getLoc(), value, i);
 
  305            foreachSubelement(builder, subfield, fn);
 
  308        .
template Case<FVectorType>([&](FVectorType vector) {
 
  309          for (
auto i : llvm::seq((
size_t)0, vector.getNumElements())) {
 
  311                SubindexOp::create(builder, value.getLoc(), value, i);
 
  312            foreachSubelement(builder, subindex, fn);
 
  315        .Default([&](
auto) { fn(value); });
 
  318  void visitDecl(RegOp op) {
 
  321    auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
 
  322    auto fn = [&](Value value) {
 
  323      auto connect = ConnectOp::create(builder, value.getLoc(), value, value);
 
  326    foreachSubelement(builder, op.getResult(), fn);
 
  329  void visitDecl(RegResetOp op) {
 
  332    auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
 
  333    auto fn = [&](Value value) {
 
  334      auto connect = ConnectOp::create(builder, value.getLoc(), value, value);
 
  337    foreachSubelement(builder, op.getResult(), fn);
 
  340  void visitDecl(InstanceOp op) {
 
  343    for (
const auto &result : llvm::enumerate(op.getResults()))
 
  347        declareSinks(result.value(), Flow::Sink);
 
  350  void visitDecl(ObjectOp op) {
 
  351    declareSinks(op, Flow::Source, 
true);
 
  354  void visitDecl(MemOp op) {
 
  356    for (
auto result : op.getResults())
 
  357      if (!isa<RefType>(result.getType()))
 
  358        declareSinks(result, Flow::Sink);
 
  361  void visitStmt(ConnectOp op) {
 
  365  void visitStmt(MatchingConnectOp op) {
 
  369  void visitStmt(RefDefineOp op) {
 
  373  void visitStmt(PropAssignOp op) {
 
  377  void visitStmt(DomainDefineOp op) {
 
  381  void processWhenOp(WhenOp whenOp, Value outerCondition);
 
  400                   Value thenCondition) {
 
  403    for (
auto &destAndConnect : thenScope) {
 
  404      auto dest = std::get<0>(destAndConnect);
 
  405      auto thenConnect = std::get<1>(destAndConnect);
 
  407      auto outerIt = driverMap.find(dest);
 
  408      if (outerIt == driverMap.end()) {
 
  411        driverMap[dest] = thenConnect;
 
  415      auto elseIt = elseScope.find(dest);
 
  416      if (elseIt != elseScope.end()) {
 
  421        auto &elseConnect = std::get<1>(*elseIt);
 
  422        OpBuilder connectBuilder(elseConnect);
 
  423        auto newConnect = flattenConditionalConnections(
 
  424            connectBuilder, loc, getDestinationValue(thenConnect),
 
  425            thenCondition, thenConnect, elseConnect);
 
  428        thenConnect->erase();
 
  429        elseConnect->erase();
 
  430        recordConnect(dest, newConnect);
 
  435      auto &outerConnect = std::get<1>(*outerIt);
 
  437        if (isLastConnect(thenConnect)) {
 
  440          thenConnect->erase();
 
  442          assert(isStaticSingleConnect(thenConnect));
 
  443          driverMap[dest] = thenConnect;
 
  450      OpBuilder connectBuilder(thenConnect);
 
  451      auto newConnect = flattenConditionalConnections(
 
  452          connectBuilder, loc, getDestinationValue(thenConnect), thenCondition,
 
  453          thenConnect, outerConnect);
 
  456      thenConnect->erase();
 
  457      recordConnect(dest, newConnect);
 
  461    for (
auto &destAndConnect : elseScope) {
 
  462      auto dest = std::get<0>(destAndConnect);
 
  463      auto elseConnect = std::get<1>(destAndConnect);
 
  467      if (thenScope.contains(dest))
 
  470      auto outerIt = driverMap.find(dest);
 
  471      if (outerIt == driverMap.end()) {
 
  474        driverMap[dest] = elseConnect;
 
  478      auto &outerConnect = std::get<1>(*outerIt);
 
  480        if (isLastConnect(elseConnect)) {
 
  483          elseConnect->erase();
 
  485          assert(isStaticSingleConnect(elseConnect));
 
  486          driverMap[dest] = elseConnect;
 
  493      OpBuilder connectBuilder(elseConnect);
 
  494      auto newConnect = flattenConditionalConnections(
 
  495          connectBuilder, loc, getDestinationValue(outerConnect), thenCondition,
 
  496          outerConnect, elseConnect);
 
  499      elseConnect->erase();
 
  500      recordConnect(dest, newConnect);
 
  514class WhenOpVisitor : 
public LastConnectResolver<WhenOpVisitor> {
 
  518      : LastConnectResolver<WhenOpVisitor>(driverMap), condition(condition) {}
 
  520  using LastConnectResolver<WhenOpVisitor>::visitExpr;
 
  521  using LastConnectResolver<WhenOpVisitor>::visitDecl;
 
  522  using LastConnectResolver<WhenOpVisitor>::visitStmt;
 
  523  using LastConnectResolver<WhenOpVisitor>::visitStmtExpr;
 
  526  void process(Block &block);
 
  529  void visitStmt(VerifAssertIntrinsicOp op);
 
  530  void visitStmt(VerifAssumeIntrinsicOp op);
 
  531  void visitStmt(VerifCoverIntrinsicOp op);
 
  532  void visitStmt(AssertOp op);
 
  533  void visitStmt(AssumeOp op);
 
  534  void visitStmt(UnclockedAssumeIntrinsicOp op);
 
  535  void visitStmt(CoverOp op);
 
  536  void visitStmt(ModuleOp op);
 
  537  void visitStmt(PrintFOp op);
 
  538  void visitStmt(FPrintFOp op);
 
  539  void visitStmt(FFlushOp op);
 
  540  void visitStmt(StopOp op);
 
  541  void visitStmt(WhenOp op);
 
  542  void visitStmt(LayerBlockOp op);
 
  543  void visitStmt(RefForceOp op);
 
  544  void visitStmt(RefForceInitialOp op);
 
  545  void visitStmt(RefReleaseOp op);
 
  546  void visitStmt(RefReleaseInitialOp op);
 
  547  void visitStmtExpr(DPICallIntrinsicOp op);
 
  552  Value andWithCondition(Operation *op, Value value) {
 
  554    return OpBuilder(op).createOrFold<AndPrimOp>(
 
  555        condition.getLoc(), condition.getType(), condition, value);
 
  560  Value ltlAndWithCondition(Operation *op, Value property) {
 
  562    while (
auto nodeOp = property.getDefiningOp<NodeOp>())
 
  563      property = nodeOp.getInput();
 
  566    if (
auto clockOp = property.getDefiningOp<LTLClockIntrinsicOp>()) {
 
  567      auto input = ltlAndWithCondition(op, clockOp.getInput());
 
  568      auto &newClockOp = createdLTLClockOps[{clockOp, input}];
 
  570        newClockOp = OpBuilder(op).cloneWithoutRegions(clockOp);
 
  571        newClockOp.getInputMutable().assign(input);
 
  577    auto &newOp = createdLTLAndOps[{condition, 
property}];
 
  579      newOp = OpBuilder(op).createOrFold<LTLAndIntrinsicOp>(
 
  580          condition.getLoc(), 
property.getType(), condition, property);
 
  587  Value ltlImplicationWithCondition(Operation *op, Value property) {
 
  589    while (
auto nodeOp = property.getDefiningOp<NodeOp>())
 
  590      property = nodeOp.getInput();
 
  593    if (
auto clockOp = property.getDefiningOp<LTLClockIntrinsicOp>()) {
 
  594      auto input = ltlImplicationWithCondition(op, clockOp.getInput());
 
  595      auto &newClockOp = createdLTLClockOps[{clockOp, input}];
 
  597        newClockOp = OpBuilder(op).cloneWithoutRegions(clockOp);
 
  598        newClockOp.getInputMutable().assign(input);
 
  604    if (
auto implOp = property.getDefiningOp<LTLImplicationIntrinsicOp>()) {
 
  605      auto lhs = ltlAndWithCondition(op, implOp.getLhs());
 
  606      auto &newImplOp = createdLTLImplicationOps[{lhs, implOp.getRhs()}];
 
  608        auto clonedOp = OpBuilder(op).cloneWithoutRegions(implOp);
 
  609        clonedOp.getLhsMutable().assign(lhs);
 
  610        newImplOp = clonedOp;
 
  616    auto &newImplOp = createdLTLImplicationOps[{condition, 
property}];
 
  618      newImplOp = OpBuilder(op).createOrFold<LTLImplicationIntrinsicOp>(
 
  619          condition.getLoc(), 
property.getType(), condition, property);
 
  639void WhenOpVisitor::process(Block &block) {
 
  640  for (
auto &op : 
llvm::make_early_inc_range(block)) {
 
  641    dispatchVisitor(&op);
 
  645void WhenOpVisitor::visitStmt(PrintFOp op) {
 
  646  op.getCondMutable().assign(andWithCondition(op, op.getCond()));
 
  649void WhenOpVisitor::visitStmt(FPrintFOp op) {
 
  650  op.getCondMutable().assign(andWithCondition(op, op.getCond()));
 
  653void WhenOpVisitor::visitStmt(FFlushOp op) {
 
  654  op.getCondMutable().assign(andWithCondition(op, op.getCond()));
 
  657void WhenOpVisitor::visitStmt(StopOp op) {
 
  658  op.getCondMutable().assign(andWithCondition(op, op.getCond()));
 
  661void WhenOpVisitor::visitStmt(VerifAssertIntrinsicOp op) {
 
  662  op.getPropertyMutable().assign(
 
  663      ltlImplicationWithCondition(op, op.getProperty()));
 
  666void WhenOpVisitor::visitStmt(VerifAssumeIntrinsicOp op) {
 
  667  op.getPropertyMutable().assign(
 
  668      ltlImplicationWithCondition(op, op.getProperty()));
 
  671void WhenOpVisitor::visitStmt(VerifCoverIntrinsicOp op) {
 
  672  op.getPropertyMutable().assign(ltlAndWithCondition(op, op.getProperty()));
 
  675void WhenOpVisitor::visitStmt(AssertOp op) {
 
  676  op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
 
  679void WhenOpVisitor::visitStmt(AssumeOp op) {
 
  680  op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
 
  683void WhenOpVisitor::visitStmt(UnclockedAssumeIntrinsicOp op) {
 
  684  op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
 
  687void WhenOpVisitor::visitStmt(CoverOp op) {
 
  688  op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
 
  691void WhenOpVisitor::visitStmt(WhenOp whenOp) {
 
  692  processWhenOp(whenOp, condition);
 
  696void WhenOpVisitor::visitStmt(LayerBlockOp layerBlockOp) {
 
  697  process(*layerBlockOp.getBody());
 
  700void WhenOpVisitor::visitStmt(RefForceOp op) {
 
  701  op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
 
  704void WhenOpVisitor::visitStmt(RefForceInitialOp op) {
 
  705  op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
 
  708void WhenOpVisitor::visitStmt(RefReleaseOp op) {
 
  709  op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
 
  712void WhenOpVisitor::visitStmt(RefReleaseInitialOp op) {
 
  713  op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
 
  716void WhenOpVisitor::visitStmtExpr(DPICallIntrinsicOp op) {
 
  718    op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
 
  720    op.getEnableMutable().assign(condition);
 
  728template <
typename ConcreteT>
 
  729void LastConnectResolver<ConcreteT>::processWhenOp(WhenOp whenOp,
 
  730                                                   Value outerCondition) {
 
  732  auto loc = whenOp.getLoc();
 
  733  Block *parentBlock = whenOp->getBlock();
 
  734  auto condition = whenOp.getCondition();
 
  735  auto ui1Type = condition.getType();
 
  743  Value thenCondition = whenOp.getCondition();
 
  746        b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition, thenCondition);
 
  748  auto &thenBlock = whenOp.getThenBlock();
 
  749  driverMap.pushScope();
 
  750  WhenOpVisitor(driverMap, thenCondition).process(thenBlock);
 
  751  mergeBlock(*parentBlock, Block::iterator(whenOp), thenBlock);
 
  752  auto thenScope = driverMap.popScope();
 
  756  if (whenOp.hasElseRegion()) {
 
  759        b.createOrFold<NotPrimOp>(loc, condition.getType(), condition);
 
  762      elseCondition = b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition,
 
  764    auto &elseBlock = whenOp.getElseBlock();
 
  765    driverMap.pushScope();
 
  766    WhenOpVisitor(driverMap, elseCondition).process(elseBlock);
 
  767    mergeBlock(*parentBlock, Block::iterator(whenOp), elseBlock);
 
  768    elseScope = driverMap.popScope();
 
  771  mergeScopes(loc, thenScope, elseScope, condition);
 
  783class ModuleVisitor : 
public LastConnectResolver<ModuleVisitor> {
 
  785  ModuleVisitor() : LastConnectResolver<ModuleVisitor>(driverMap) {}
 
  787  using LastConnectResolver<ModuleVisitor>::visitExpr;
 
  788  using LastConnectResolver<ModuleVisitor>::visitDecl;
 
  789  using LastConnectResolver<ModuleVisitor>::visitStmt;
 
  790  void visitStmt(WhenOp whenOp);
 
  791  void visitStmt(ConnectOp connectOp);
 
  792  void visitStmt(MatchingConnectOp connectOp);
 
  793  void visitStmt(LayerBlockOp layerBlockOp);
 
  795  bool run(FModuleLike op);
 
  796  LogicalResult checkInitialization();
 
  803  bool anythingChanged = 
false;
 
  810bool ModuleVisitor::run(FModuleLike op) {
 
  812  if (!isa<FModuleOp, ClassOp>(op))
 
  813    return anythingChanged;
 
  815  for (
auto ®ion : op->getRegions()) {
 
  816    for (
auto &block : region.getBlocks()) {
 
  818      for (
const auto &[index, value] : 
llvm::enumerate(block.getArguments())) {
 
  819        auto direction = op.getPortDirection(index);
 
  820        auto flow = direction == Direction::In ? Flow::Source : Flow::Sink;
 
  821        declareSinks(value, flow);
 
  825      for (
auto &op : 
llvm::make_early_inc_range(block))
 
  826        dispatchVisitor(&op);
 
  830  return anythingChanged;
 
  833void ModuleVisitor::visitStmt(ConnectOp op) {
 
  837void ModuleVisitor::visitStmt(MatchingConnectOp op) {
 
  841void ModuleVisitor::visitStmt(WhenOp whenOp) {
 
  843  anythingChanged = 
true;
 
  844  processWhenOp(whenOp, {});
 
  847void ModuleVisitor::visitStmt(LayerBlockOp layerBlockOp) {
 
  848  for (
auto &op : 
llvm::make_early_inc_range(*layerBlockOp.getBody())) {
 
  849    dispatchVisitor(&op);
 
  855LogicalResult ModuleVisitor::checkInitialization() {
 
  857  for (
auto destAndConnect : driverMap.getLastScope()) {
 
  859    auto *
connect = std::get<1>(destAndConnect);
 
  864    FieldRef dest = std::get<0>(destAndConnect);
 
  865    auto loc = dest.
getValue().getLoc();
 
  867    if (
auto mod = dyn_cast<FModuleLike>(definingOp))
 
  868      mlir::emitError(loc) << 
"port \"" << 
getFieldName(dest).first
 
  869                           << 
"\" not fully initialized in \"" 
  870                           << mod.getModuleName() << 
"\"";
 
  874          << 
"\" not fully initialized in \"" 
  875          << definingOp->getParentOfType<FModuleLike>().getModuleName() << 
"\"";
 
  889    : 
public circt::firrtl::impl::ExpandWhensBase<ExpandWhensPass> {
 
  890  void runOnOperation() 
override;
 
  894void ExpandWhensPass::runOnOperation() {
 
  895  ModuleVisitor visitor;
 
  896  if (!visitor.run(getOperation()))
 
  897    markAllAnalysesPreserved();
 
  898  if (failed(visitor.checkInitialization()))
 
assert(baseType &&"element must be base type")
ScopedDriverMap::ScopeT DriverMap
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Move all operations from a source block into a destination block.
HashTableStack< FieldRef, Operation * > ScopedDriverMap
This is a deterministic mapping of a FieldRef to the last operation which set a value to it.
This class represents a reference to a specific field or element of an aggregate value.
Value getValue() const
Get the Value which created this location.
Operation * getDefiningOp() const
Get the operation which defines this field.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
connect(destination, source)
Flow swapFlow(Flow flow)
Get a flow's reverse.
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
bool operator!=(const Iterator &rhs) const
bool operator==(const Iterator &rhs) const
std::pair< KeyT, ValueT > & operator*() const
Iterator(typename StackT::iterator stackIt, typename ScopeT::iterator scopeIt)
This is a stack of hashtables, if lookup fails in the top-most hashtable, it will attempt to lookup i...
typename llvm::MapVector< KeyT, ValueT > ScopeT
iterator find(const KeyT &key)
typename llvm::SmallVector< ScopeT, 3 > StackT
ValueT & operator[](const KeyT &key)