19 #include "mlir/Pass/Pass.h"
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/STLExtras.h"
25 #define GEN_PASS_DEF_EXPANDWHENS
26 #include "circt/Dialect/FIRRTL/Passes.h.inc"
30 using namespace circt;
31 using namespace firrtl;
// Move every operation of `source` into `destination`, in front of
// `insertPoint`, by splicing the operation lists (no ops are cloned).
// NOTE(review): the parameter line declaring `source` is elided in this
// listing.
35 static void mergeBlock(Block &destination, Block::iterator insertPoint,
37 destination.getOperations().splice(insertPoint, source.getOperations());
// Scoped map fragment: a stack (StackT) of insertion-ordered scopes (ScopeT,
// an llvm::MapVector). Lookups search from the innermost scope outwards.
// NOTE(review): this listing elides lines of the original definition; the
// comments below describe only the visible lines.
45 template <
typename KeyT,
typename ValueT>
47 using ScopeT =
typename llvm::MapVector<KeyT, ValueT>;
48 using StackT =
typename llvm::SmallVector<ScopeT, 3>;
// Iterator state: which scope in the stack, and which entry in that scope.
52 typename ScopeT::iterator scopeIt)
53 : stackIt(stackIt), scopeIt(scopeIt) {}
// Dereference yields the current key/value entry of the current scope.
61 std::pair<KeyT, ValueT> &
operator*()
const {
return *scopeIt; }
// Increment: once the current scope is exhausted, continue at the first entry
// of the next scope.
// NOTE(review): the visible lines do not show what happens when that next
// scope is empty or lies past the last scope. Confirm against the full
// definition.
64 if (scopeIt == stackIt->end())
65 scopeIt = (++stackIt)->begin();
// Iterator to the first entry of the outermost (first-pushed) scope
// (presumably begin()).
83 return Iterator(mapStack.begin(), mapStack.first().begin());
// find(): search scopes from the innermost (last-pushed) to the outermost.
// Returns an iterator into the first scope that yields a hit; the hit test
// itself is elided here.
90 for (
auto i = mapStack.size(); i > 0; --i) {
91 auto &map = mapStack[i - 1];
92 auto it = map.find(key);
94 return Iterator(mapStack.begin() + i - 1, it);
// Pop (presumably popScope()): remove and return the innermost scope.
// Popping the outermost scope is a programming error (asserted).
104 assert(mapStack.size() > 1 &&
"Cannot pop the last scope");
105 return mapStack.pop_back_val();
// operator[]: accesses the innermost scope only. Per MapVector semantics, a
// default value is inserted if `key` is absent.
109 ValueT &
operator[](
const KeyT &key) {
return mapStack.back()[key]; }
// CRTP base (ConcreteT is the derived visitor) layered on FIRRTLVisitor.
// It keeps a reference to the ScopedDriverMap it is constructed with. Per
// FieldRef, that map records the connect operation that currently drives it.
// NOTE(review): lines of this definition are elided in this listing.
127 template <
typename ConcreteT>
128 class LastConnectResolver :
public FIRRTLVisitor<ConcreteT> {
136 LastConnectResolver(
ScopedDriverMap &driverMap) : driverMap(driverMap) {}
// Record `connection` as the driver of `dest` in the innermost scope.
// - Static-single-connect ops: the slot must be freshly inserted or still
//   empty (asserted). An existing empty slot is overwritten.
// - Otherwise the op must have last-connect semantics (asserted). When the
//   slot already existed, any previously recorded connect (`oldConnect`) is
//   examined in elided lines, and the slot is then overwritten with
//   `connection`.
// NOTE(review): the bool result presumably reports the `changed` flag. The
// return statement is elided in this listing.
147 bool recordConnect(
FieldRef dest, Operation *connection) {
149 auto itAndInserted = driverMap.
getLastScope().insert({dest, connection});
150 if (isStaticSingleConnect(connection)) {
152 assert(itAndInserted.second || !itAndInserted.first->second);
153 if (!itAndInserted.second)
154 itAndInserted.first->second = connection;
157 assert(isLastConnect(connection));
158 if (!std::get<1>(itAndInserted)) {
159 auto iterator = std::get<0>(itAndInserted);
160 auto changed =
false;
163 if (
auto *oldConnect = iterator->second) {
167 iterator->second = connection;
// The value being driven by a connect-like op. `cast` asserts that `op`
// implements FConnectLike.
175 static Value getDestinationValue(Operation *op) {
176 return cast<FConnectLike>(op).getDest();
// The value a connect-like op drives its destination with.
181 static Value getConnectedValue(Operation *op) {
182 return cast<FConnectLike>(op).getSrc();
// True if the op has static-single-connect behavior, as reported by the
// FConnectLike interface.
186 static bool isStaticSingleConnect(Operation *op) {
187 return cast<FConnectLike>(op).hasStaticSingleConnectBehavior();
// True if the op has last-connect behavior, as reported by the FConnectLike
// interface.
194 static bool isLastConnect(Operation *op) {
195 return cast<FConnectLike>(op).hasLastConnectBehavior();
// Enter fields reachable from `value` into the driver map with no driver yet
// (a nullptr entry), walking the value's type recursively via `declare`.
// `flow` is swapped (swapFlow) when descending through selected bundle and
// class elements.
// NOTE(review): the field-ID counter `id`, the checks that decide which
// fields count as sinks, and the use of `local` are elided in this listing.
200 void declareSinks(Value value,
Flow flow,
bool local =
false) {
201 auto type = value.getType();
205 std::function<void(Type,
Flow,
bool)> declare = [&](Type type,
Flow flow,
// Class types: elements are declared recursively, with `flow` or its swap
// (the selecting condition is elided). The visible lines also record a
// driver-less entry for {value, id} and advance `id` by the class's max
// field ID.
208 if (
auto classType = type_dyn_cast<ClassType>(type)) {
212 for (
auto &element : classType.getElements()) {
215 declare(element.type, flow,
false);
217 declare(element.type,
swapFlow(flow),
false);
223 driverMap[{value,
id}] =
nullptr;
226 id += classType.getMaxFieldID();
// Bundle types: each field is declared recursively. Some fields
// (presumably flipped ones; the condition is elided) use the swapped flow.
232 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
233 for (
auto &element : bundleType.getElements()) {
236 declare(element.type,
swapFlow(flow),
false);
238 declare(element.type, flow,
false);
// Vector types: every element is declared with the same flow.
244 if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
245 for (
unsigned i = 0; i < vectorType.getNumElements(); ++i) {
247 declare(vectorType.getElementType(), flow,
false);
// Analog types take a separate path (body elided in this listing).
253 if (
auto analogType = type_dyn_cast<AnalogType>(type))
// Leaf case (guard elided): record a driver-less entry for {value, id}.
259 driverMap[{value,
id}] =
nullptr;
// Start the recursion at the value's own type.
262 declare(type, flow, local);
// Build a single ConnectOp driving `dest` that replaces two last-connects:
// `whenTrueConn` (selected when `cond` is true) and `whenFalseConn`.
// - If exactly one side's source is produced by an InvalidValueOp, the other
//   side's source is used directly.
// - Otherwise the two sources are muxed on `cond` (folded when possible).
// The mux is located at a location fused from `loc` and both connects. Both
// inputs must have last-connect semantics (asserted). The original connects
// are not erased here; the callers in this file erase them.
// NOTE(review): the returned ConnectOp is created at `loc`, not at the fused
// location. Confirm that is intended. Lines 272/275 (`fusedLoc`,
// `trueIsInvalid`) are partially elided in this listing.
267 ConnectOp flattenConditionalConnections(OpBuilder &b, Location loc,
268 Value dest, Value cond,
269 Operation *whenTrueConn,
270 Operation *whenFalseConn) {
271 assert(isLastConnect(whenTrueConn) && isLastConnect(whenFalseConn));
273 b.getFusedLoc({loc, whenTrueConn->getLoc(), whenFalseConn->getLoc()});
274 auto whenTrue = getConnectedValue(whenTrueConn);
276 isa_and_nonnull<InvalidValueOp>(whenTrue.getDefiningOp());
277 auto whenFalse = getConnectedValue(whenFalseConn);
278 auto falseIsInvalid =
279 isa_and_nonnull<InvalidValueOp>(whenFalse.getDefiningOp());
// Both valid or both invalid: keep a mux. Exactly one invalid: forward the
// other side.
286 Value newValue = whenTrue;
287 if (trueIsInvalid == falseIsInvalid)
288 newValue = b.createOrFold<MuxPrimOp>(fusedLoc, cond, whenTrue, whenFalse);
289 else if (trueIsInvalid)
290 newValue = whenFalse;
291 return b.create<ConnectOp>(loc, dest, newValue);
// Wires: declare all of their fields as sinks with Duplex flow.
294 void visitDecl(WireOp op) { declareSinks(op.getResult(),
Flow::Duplex); }
// Walk `value` down to its non-aggregate leaves and call `fn` on each leaf.
// `builder` materializes one SubfieldOp per bundle field and one SubindexOp
// per vector element; any other type is itself a leaf.
// NOTE(review): the type-switch head and the lines binding `subfield` /
// `subindex` are elided in this listing.
298 void foreachSubelement(OpBuilder &builder, Value value,
299 llvm::function_ref<
void(Value)> fn) {
301 .template Case<BundleType>([&](BundleType bundle) {
302 for (
auto i : llvm::seq(0u, (
unsigned)bundle.getNumElements())) {
304 builder.create<SubfieldOp>(value.getLoc(), value, i);
305 foreachSubelement(builder, subfield, fn);
308 .
template Case<FVectorType>([&](FVectorType vector) {
309 for (
auto i : llvm::seq((
size_t)0, vector.getNumElements())) {
311 builder.create<SubindexOp>(value.getLoc(), value, i);
312 foreachSubelement(builder, subindex, fn);
315 .Default([&](
auto) { fn(value); });
// Registers: right after the register op, create a self-connect
// (value <= value) for every leaf field.
// NOTE(review): how each new connect is used is elided. It is presumably
// recorded as the field's initial driver.
318 void visitDecl(RegOp op) {
321 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
322 auto fn = [&](Value value) {
323 auto connect = builder.create<ConnectOp>(value.getLoc(), value, value);
326 foreachSubelement(builder, op.getResult(), fn);
// Registers with reset: same self-connect treatment as plain registers.
329 void visitDecl(RegResetOp op) {
332 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
333 auto fn = [&](Value value) {
334 auto connect = builder.create<ConnectOp>(value.getLoc(), value, value);
337 foreachSubelement(builder, op.getResult(), fn);
// Instances: each result is visited together with its index (body elided).
340 void visitDecl(InstanceOp op) {
343 for (
const auto &result : llvm::enumerate(op.getResults()))
// Objects (body elided in this listing).
350 void visitDecl(ObjectOp op) {
// Memories: every result that is not a RefType is processed (body elided).
354 void visitDecl(MemOp op) {
356 for (
auto result : op.getResults())
357 if (!isa<RefType>(result.getType()))
// Connect-like statements (bodies elided in this listing).
361 void visitStmt(ConnectOp op) {
365 void visitStmt(MatchingConnectOp op) {
369 void visitStmt(RefDefineOp op) {
373 void visitStmt(PropAssignOp op) {
// Lowers one WhenOp; the template definition appears later in the file.
377 void processWhenOp(WhenOp whenOp, Value outerCondition);
// mergeScopes fragment (signature head elided): fold the driver scopes
// recorded for the then- and else-regions back into the enclosing scopes.
396 Value thenCondition) {
// Pass 1: every destination driven in the then-region.
399 for (
auto &destAndConnect : thenScope) {
400 auto dest = std::get<0>(destAndConnect);
401 auto thenConnect = std::get<1>(destAndConnect);
// Unknown to every enclosing scope: the then-connect simply becomes its
// driver.
403 auto outerIt = driverMap.
find(dest);
404 if (outerIt == driverMap.
end()) {
407 driverMap[dest] = thenConnect;
// Also driven in the else-region: build one muxed connect at the
// else-connect, erase both originals, and record the merged connect.
411 auto elseIt = elseScope.
find(dest);
412 if (elseIt != elseScope.end()) {
417 auto &elseConnect = std::get<1>(*elseIt);
418 OpBuilder connectBuilder(elseConnect);
419 auto newConnect = flattenConditionalConnections(
420 connectBuilder, loc, getDestinationValue(thenConnect),
421 thenCondition, thenConnect, elseConnect);
424 thenConnect->erase();
425 elseConnect->erase();
426 recordConnect(dest, newConnect);
// Driven before the when.
// NOTE(review): the guard for the next branch is elided; presumably it tests
// for a missing (null) outer driver. In that case a last-connect from the
// then-region is erased, while a static-single-connect becomes the driver.
431 auto &outerConnect = std::get<1>(*outerIt);
433 if (isLastConnect(thenConnect)) {
436 thenConnect->erase();
438 assert(isStaticSingleConnect(thenConnect));
439 driverMap[dest] = thenConnect;
// Otherwise mux the then-connect against the outer driver, erase the
// then-connect, and record the merged connect.
446 OpBuilder connectBuilder(thenConnect);
447 auto newConnect = flattenConditionalConnections(
448 connectBuilder, loc, getDestinationValue(thenConnect), thenCondition,
449 thenConnect, outerConnect);
452 thenConnect->erase();
453 recordConnect(dest, newConnect);
// Pass 2: destinations driven in the else-region. Destinations also in
// thenScope were handled in pass 1; the body of that check is elided,
// presumably a `continue`.
457 for (
auto &destAndConnect : elseScope) {
458 auto dest = std::get<0>(destAndConnect);
459 auto elseConnect = std::get<1>(destAndConnect);
463 if (thenScope.contains(dest))
466 auto outerIt = driverMap.
find(dest);
467 if (outerIt == driverMap.
end()) {
470 driverMap[dest] = elseConnect;
// Mirror of pass 1. A missing outer driver (guard elided) is handled as
// above. Otherwise the outer driver is the "true" side of the mux and the
// else-connect the "false" side.
474 auto &outerConnect = std::get<1>(*outerIt);
476 if (isLastConnect(elseConnect)) {
479 elseConnect->erase();
481 assert(isStaticSingleConnect(elseConnect));
482 driverMap[dest] = elseConnect;
489 OpBuilder connectBuilder(elseConnect);
490 auto newConnect = flattenConditionalConnections(
491 connectBuilder, loc, getDestinationValue(outerConnect), thenCondition,
492 outerConnect, elseConnect);
495 elseConnect->erase();
496 recordConnect(dest, newConnect);
// Visitor for the statements of one WhenOp region. It shares the scoped
// driver map and stores the region's effective `condition`. That condition
// is folded into the enables, predicates and properties of the statements
// it visits.
// NOTE(review): lines of this class are elided in this listing.
510 class WhenOpVisitor :
public LastConnectResolver<WhenOpVisitor> {
514 : LastConnectResolver<WhenOpVisitor>(driverMap), condition(condition) {}
516 using LastConnectResolver<WhenOpVisitor>::visitExpr;
517 using LastConnectResolver<WhenOpVisitor>::visitDecl;
518 using LastConnectResolver<WhenOpVisitor>::visitStmt;
519 using LastConnectResolver<WhenOpVisitor>::visitStmtExpr;
522 void process(Block &block);
525 void visitStmt(VerifAssertIntrinsicOp op);
526 void visitStmt(VerifAssumeIntrinsicOp op);
527 void visitStmt(VerifCoverIntrinsicOp op);
528 void visitStmt(AssertOp op);
529 void visitStmt(AssumeOp op);
530 void visitStmt(UnclockedAssumeIntrinsicOp op);
531 void visitStmt(CoverOp op);
// NOTE(review): no out-of-line definition of visitStmt(ModuleOp) appears in
// this listing. Confirm one exists if this overload can be dispatched.
532 void visitStmt(ModuleOp op);
533 void visitStmt(PrintFOp op);
534 void visitStmt(StopOp op);
535 void visitStmt(WhenOp op);
536 void visitStmt(LayerBlockOp op);
537 void visitStmt(RefForceOp op);
538 void visitStmt(RefForceInitialOp op);
539 void visitStmt(RefReleaseOp op);
540 void visitStmt(RefReleaseInitialOp op);
541 void visitStmtExpr(DPICallIntrinsicOp op);
// `condition AND value` as an AndPrimOp inserted right before `op`.
// Folded when possible; located at the condition and typed like it.
546 Value andWithCondition(Operation *op, Value value) {
548 return OpBuilder(op).createOrFold<AndPrimOp>(
549 condition.getLoc(), condition.getType(), condition, value);
// LTL conjunction of the condition and `value`, inserted right before `op`.
554 Value ltlAndWithCondition(Operation *op, Value value) {
556 return OpBuilder(op).createOrFold<LTLAndIntrinsicOp>(
557 condition.getLoc(), condition.getType(), condition, value);
// LTL implication "condition implies `value`", inserted right before `op`.
563 Value ltlImplicationWithCondition(Operation *op, Value value) {
565 return OpBuilder(op).createOrFold<LTLImplicationIntrinsicOp>(
566 condition.getLoc(), condition.getType(), condition, value);
// Visit every op of `block`. Early-increment iteration tolerates visitors
// that erase the current op.
575 void WhenOpVisitor::process(Block &block) {
576 for (
auto &op : llvm::make_early_inc_range(block)) {
577 dispatchVisitor(&op);
// printf/stop: only fire when both their own condition and the region
// condition hold.
581 void WhenOpVisitor::visitStmt(PrintFOp op) {
582 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
585 void WhenOpVisitor::visitStmt(StopOp op) {
586 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
// Verif assert/assume intrinsics: the property becomes
// "condition implies property".
589 void WhenOpVisitor::visitStmt(VerifAssertIntrinsicOp op) {
590 op.getPropertyMutable().assign(
591 ltlImplicationWithCondition(op, op.getProperty()));
594 void WhenOpVisitor::visitStmt(VerifAssumeIntrinsicOp op) {
595 op.getPropertyMutable().assign(
596 ltlImplicationWithCondition(op, op.getProperty()));
// Verif cover intrinsic: the property becomes "condition AND property".
599 void WhenOpVisitor::visitStmt(VerifCoverIntrinsicOp op) {
600 op.getPropertyMutable().assign(ltlAndWithCondition(op, op.getProperty()));
// assert/assume/unclocked-assume/cover: AND the enable with the region
// condition.
603 void WhenOpVisitor::visitStmt(AssertOp op) {
604 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
607 void WhenOpVisitor::visitStmt(AssumeOp op) {
608 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
611 void WhenOpVisitor::visitStmt(UnclockedAssumeIntrinsicOp op) {
612 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
615 void WhenOpVisitor::visitStmt(CoverOp op) {
616 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
// Nested when: lower it with this region's condition as the outer condition.
619 void WhenOpVisitor::visitStmt(WhenOp whenOp) {
620 processWhenOp(whenOp, condition);
// Layer blocks add no condition; their bodies are visited under this one.
624 void WhenOpVisitor::visitStmt(LayerBlockOp layerBlockOp) {
625 process(*layerBlockOp.getBody());
// Reference force/release ops: AND the predicate with the region condition.
628 void WhenOpVisitor::visitStmt(RefForceOp op) {
629 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
632 void WhenOpVisitor::visitStmt(RefForceInitialOp op) {
633 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
636 void WhenOpVisitor::visitStmt(RefReleaseOp op) {
637 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
640 void WhenOpVisitor::visitStmt(RefReleaseInitialOp op) {
641 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
// DPI call: AND an existing enable with the condition. Otherwise
// (presumably when no enable is present; the guard is elided) use the
// condition itself as the enable.
644 void WhenOpVisitor::visitStmtExpr(DPICallIntrinsicOp op) {
646 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
648 op.getEnableMutable().assign(condition);
// Lower one WhenOp in place. For each present region:
// 1. Process its statements with a WhenOpVisitor under the region's
//    effective condition, in a fresh driver scope.
// 2. Splice the region's ops into the parent block right before the WhenOp.
// 3. Pop and keep the scope.
// Finally, merge the drivers recorded for both regions into the enclosing
// scope. `outerCondition` is the enclosing when's condition (ModuleVisitor
// passes a null Value at top level).
// NOTE(review): several lines are elided in this listing, including the
// declarations of `b`, `elseCondition` and `elseScope`.
656 template <
typename ConcreteT>
657 void LastConnectResolver<ConcreteT>::processWhenOp(WhenOp whenOp,
658 Value outerCondition) {
660 auto loc = whenOp.getLoc();
661 Block *parentBlock = whenOp->getBlock();
662 auto condition = whenOp.getCondition();
663 auto ui1Type = condition.getType();
// Then-region condition: the when's condition, ANDed with the outer
// condition. The guard is elided; presumably the AND is skipped when there
// is no outer condition.
671 Value thenCondition = whenOp.getCondition();
674 b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition, thenCondition);
676 auto &thenBlock = whenOp.getThenBlock();
677 driverMap.pushScope();
678 WhenOpVisitor(driverMap, thenCondition).process(thenBlock);
679 mergeBlock(*parentBlock, Block::iterator(whenOp), thenBlock);
680 auto thenScope = driverMap.popScope();
// Else-region, only if present: NOT(condition), ANDed with the outer
// condition (guard elided).
684 if (whenOp.hasElseRegion()) {
687 b.createOrFold<NotPrimOp>(loc, condition.getType(), condition);
690 elseCondition = b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition,
692 auto &elseBlock = whenOp.getElseBlock();
693 driverMap.pushScope();
694 WhenOpVisitor(driverMap, elseCondition).process(elseBlock);
695 mergeBlock(*parentBlock, Block::iterator(whenOp), elseBlock);
696 elseScope = driverMap.popScope();
// Merge both regions' drivers. The mux select is the when's own condition,
// not the outer-ANDed thenCondition.
699 mergeScopes(loc, thenScope, elseScope, condition);
// Top-level visitor for one module. It:
// - declares the ports (block arguments) as sinks,
// - dispatches every op of the body,
// - lowers each top-level WhenOp.
// NOTE(review): `driverMap` presumably names a ModuleVisitor member whose
// declaration is elided in this listing. If so, the base receives a
// reference to it before that member is constructed. This is only safe
// because the base constructor merely stores the reference.
711 class ModuleVisitor :
public LastConnectResolver<ModuleVisitor> {
713 ModuleVisitor() : LastConnectResolver<ModuleVisitor>(driverMap) {}
715 using LastConnectResolver<ModuleVisitor>::visitExpr;
716 using LastConnectResolver<ModuleVisitor>::visitDecl;
717 using LastConnectResolver<ModuleVisitor>::visitStmt;
718 void visitStmt(WhenOp whenOp);
719 void visitStmt(ConnectOp connectOp);
720 void visitStmt(MatchingConnectOp connectOp);
721 void visitStmt(LayerBlockOp layerBlockOp);
// Returns whether the module was modified (e.g. a WhenOp was lowered).
723 bool run(FModuleLike op);
// Reports destinations left without a driver.
724 LogicalResult checkInitialization();
// Set once any transformation happens.
731 bool anythingChanged =
false;
// Only FModuleOp and ClassOp bodies are processed; anything else is left
// untouched.
740 if (!isa<FModuleOp, ClassOp>(op))
741 return anythingChanged;
743 for (
auto &region : op->getRegions()) {
744 for (
auto &block : region.getBlocks()) {
// Declare each port as a sink. Its flow is derived from the port direction;
// the conversion line is elided.
746 for (
const auto &[index, value] : llvm::enumerate(block.getArguments())) {
747 auto direction = op.getPortDirection(index);
749 declareSinks(value, flow);
// Early-increment iteration tolerates visitors that erase the current op.
753 for (
auto &op : llvm::make_early_inc_range(block))
754 dispatchVisitor(&op);
758 return anythingChanged;
761 void ModuleVisitor::visitStmt(ConnectOp op) {
765 void ModuleVisitor::visitStmt(MatchingConnectOp op) {
// A top-level when has no enclosing condition, hence the null Value.
769 void ModuleVisitor::visitStmt(WhenOp whenOp) {
771 anythingChanged =
true;
772 processWhenOp(whenOp, {});
// Layer blocks add no condition: visit their body ops in place.
775 void ModuleVisitor::visitStmt(LayerBlockOp layerBlockOp) {
776 for (
auto &op : llvm::make_early_inc_range(*layerBlockOp.getBody())) {
777 dispatchVisitor(&op);
// Emit an error for every entry of the driver map's last scope that is not
// fully initialized. After the module walk that scope is presumably the only
// one left.
// - Entries whose value is defined by an FModuleLike (module ports) get a
//   "port ... not fully initialized" error.
// - Other entries get an error naming the enclosing module; the message
//   prefix is elided.
// NOTE(review): the check on `connect`, the binding of `definingOp`, and the
// return statements are elided in this listing.
783 LogicalResult ModuleVisitor::checkInitialization() {
785 for (
auto destAndConnect : driverMap.getLastScope()) {
787 auto *
connect = std::get<1>(destAndConnect);
792 FieldRef dest = std::get<0>(destAndConnect);
793 auto loc = dest.
getValue().getLoc();
795 if (
auto mod = dyn_cast<FModuleLike>(definingOp))
796 mlir::emitError(loc) <<
"port \"" <<
getFieldName(dest).first
797 <<
"\" not fully initialized in \""
798 << mod.getModuleName() <<
"\"";
802 <<
"\" not fully initialized in \""
803 << definingOp->getParentOfType<FModuleLike>().getModuleName() <<
"\"";
// The ExpandWhens pass: lowers when-blocks of the operated-on module with a
// ModuleVisitor.
816 class ExpandWhensPass
817 :
public circt::firrtl::impl::ExpandWhensBase<ExpandWhensPass> {
818 void runOnOperation()
override;
// Run the visitor; if nothing changed, all analyses are preserved.
// Afterwards check that every destination is initialized. The failure
// handling is elided; presumably it signals pass failure.
822 void ExpandWhensPass::runOnOperation() {
823 ModuleVisitor visitor;
824 if (!visitor.run(getOperation()))
825 markAllAnalysesPreserved();
826 if (failed(visitor.checkInitialization()))
// Factory body (presumably createExpandWhensPass(); the signature is elided).
831 return std::make_unique<ExpandWhensPass>();
assert(baseType &&"element must be base type")
ScopedDriverMap::ScopeT DriverMap
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Move all operations from a source block into a destination block.
This class represents a reference to a specific field or element of an aggregate value.
Value getValue() const
Get the Value which created this location.
Operation * getDefiningOp() const
Get the operation which defines this field.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
def connect(destination, source)
Flow swapFlow(Flow flow)
Get a flow's reverse.
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
std::unique_ptr< mlir::Pass > createExpandWhensPass()
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
std::pair< KeyT, ValueT > & operator*() const
bool operator!=(const Iterator &rhs) const
bool operator==(const Iterator &rhs) const
Iterator(typename StackT::iterator stackIt, typename ScopeT::iterator scopeIt)
This is a stack of hashtables, if lookup fails in the top-most hashtable, it will attempt to lookup i...
typename llvm::MapVector< KeyT, ValueT > ScopeT
iterator find(const KeyT &key)
ValueT & operator[](const KeyT &key)
typename llvm::SmallVector< ScopeT, 3 > StackT