19 #include "mlir/Pass/Pass.h"
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/STLExtras.h"
25 #define GEN_PASS_DEF_EXPANDWHENS
26 #include "circt/Dialect/FIRRTL/Passes.h.inc"
30 using namespace circt;
31 using namespace firrtl;
/// Moves every operation of `source` into `destination`, inserting them
/// immediately before `insertPoint`; `source` is left with no operations.
// NOTE(review): this listing omits original line 36 (the `Block &source)`
// parameter line) and the closing brace.
35 static void mergeBlock(Block &destination, Block::iterator insertPoint,
37 destination.getOperations().splice(insertPoint, source.getOperations());
/// A stack of scopes, each an insertion-ordered llvm::MapVector from KeyT to
/// ValueT. `find` searches from the innermost (last-pushed) scope outward;
/// `operator[]` only touches the innermost scope.
// NOTE(review): this listing is a fragment. The class declaration (original
// line 46) and several member lines are missing.
45 template <
typename KeyT,
typename ValueT>
47 using ScopeT =
typename llvm::MapVector<KeyT, ValueT>;
48 using StackT =
typename llvm::SmallVector<ScopeT, 3>;
// Iterator over every (key, value) pair of every scope: `stackIt` selects the
// scope and `scopeIt` the entry within it.
52 typename ScopeT::iterator scopeIt)
53 : stackIt(stackIt), scopeIt(scopeIt) {}
// Dereferencing yields the current (key, value) pair of the current scope.
61 std::pair<KeyT, ValueT> &
operator*()
const {
return *scopeIt; }
// Increment step: once the current scope is exhausted, move to the first
// entry of the next scope.
// NOTE(review): as visible, exhausting the last scope dereferences the
// stack's end() iterator, and an empty intermediate scope is not skipped.
// Confirm against the omitted lines 62-63 and 66-67.
64 if (scopeIt == stackIt->end())
65 scopeIt = (++stackIt)->begin();
// NOTE(review): confirm StackT provides `first()`. llvm::SmallVector's
// accessor for the first element is `front()`. As a class-template member
// this only fails to compile if begin() is actually used.
83 return Iterator(mapStack.begin(), mapStack.first().begin());
// Lookup walks the scopes from innermost (highest index) to outermost and
// returns an iterator into the first scope containing `key`. The hit check
// (line 93) and the not-found return are omitted from this listing.
90 for (
auto i = mapStack.size(); i > 0; --i) {
91 auto &map = mapStack[i - 1];
92 auto it = map.find(key);
94 return Iterator(mapStack.begin() + i - 1, it);
// Removes and returns the innermost scope; asserts that the outermost scope
// is never popped.
104 assert(mapStack.size() > 1 &&
"Cannot pop the last scope");
105 return mapStack.pop_back_val();
// Accesses `key` in the innermost scope only, default-inserting it if absent.
// Outer scopes are not consulted.
109 ValueT &
operator[](
const KeyT &key) {
return mapStack.back()[key]; }
/// CRTP base visitor (ConcreteT is the derived visitor). It records, per
/// field, the operation currently driving that field in a ScopedDriverMap,
/// and it flattens `when` regions into ordinary connects.
// NOTE(review): this listing is a fragment; member declarations, closing
// braces and many statements are omitted.
127 template <
typename ConcreteT>
128 class LastConnectResolver :
public FIRRTLVisitor<ConcreteT> {
// The driver map is supplied by the caller by reference. The member
// declaration is not visible in this listing.
136 LastConnectResolver(
ScopedDriverMap &driverMap) : driverMap(driverMap) {}
/// Records `connection` as the driver of `dest` in the innermost scope.
///  - A static-single connect may only fill an entry that is absent or still
///    null (declared but undriven; see the nullptr entries in declareSinks).
///  - A last connect replaces any earlier driver recorded in this scope.
/// NOTE(review): the handling of `oldConnect` (lines 164-166) and the return
/// statements are omitted here. Presumably the result reports whether an
/// earlier connect was replaced (`changed`); confirm.
147 bool recordConnect(
FieldRef dest, Operation *connection) {
149 auto itAndInserted = driverMap.
getLastScope().insert({dest, connection});
150 if (isStaticSingleConnect(connection)) {
152 assert(itAndInserted.second || !itAndInserted.first->second);
153 if (!itAndInserted.second)
154 itAndInserted.first->second = connection;
157 assert(isLastConnect(connection));
158 if (!std::get<1>(itAndInserted)) {
159 auto iterator = std::get<0>(itAndInserted);
160 auto changed =
false;
163 if (
auto *oldConnect = iterator->second) {
167 iterator->second = connection;
/// Returns the destination operand of a connect-like op.
175 static Value getDestinationValue(Operation *op) {
176 return cast<FConnectLike>(op).getDest();
/// Returns the source operand of a connect-like op.
181 static Value getConnectedValue(Operation *op) {
182 return cast<FConnectLike>(op).getSrc();
/// True if the connect-like op has static-single-connect behavior.
186 static bool isStaticSingleConnect(Operation *op) {
187 return cast<FConnectLike>(op).hasStaticSingleConnectBehavior();
/// True if the connect-like op has last-connect behavior.
194 static bool isLastConnect(Operation *op) {
195 return cast<FConnectLike>(op).hasLastConnectBehavior();
/// Registers fields of `value` in the driver map with a null driver, walking
/// class, bundle and vector types recursively and keying entries by
/// {value, id}. For class and bundle elements, one branch recurses with
/// `flow` and the other with swapFlow(flow). Recursive calls always pass
/// `false` for the third argument; only the top-level call passes `local`.
/// NOTE(review): the declaration of the running field `id`, the lambda's
/// third parameter, the conditions choosing between the flow and swapped-flow
/// branches, and the guard on the final null-driver insertion (lines 202-204,
/// 206-207, 213-214, 234-235, 254-258) are omitted from this listing.
200 void declareSinks(Value value,
Flow flow,
bool local =
false) {
201 auto type = value.getType();
// Recursive walker. It is held in a std::function so the lambda can call
// itself through its by-reference capture.
205 std::function<void(Type,
Flow,
bool)> declare = [&](Type type,
Flow flow,
208 if (
auto classType = type_dyn_cast<ClassType>(type)) {
212 for (
auto &element : classType.getElements()) {
215 declare(element.type, flow,
false);
217 declare(element.type,
swapFlow(flow),
false);
223 driverMap[{value,
id}] =
nullptr;
226 id += classType.getMaxFieldID();
232 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
233 for (
auto &element : bundleType.getElements()) {
236 declare(element.type,
swapFlow(flow),
false);
238 declare(element.type, flow,
false);
244 if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
245 for (
unsigned i = 0; i < vectorType.getNumElements(); ++i) {
247 declare(vectorType.getElementType(), flow,
false);
// Analog fields are special-cased (body omitted). Remaining leaf fields get a
// null driver entry, subject to a guard on the omitted lines.
253 if (
auto analogType = type_dyn_cast<AnalogType>(type))
259 driverMap[{value,
id}] =
nullptr;
262 declare(type, flow, local);
/// Builds, at `b`'s insertion point, one ConnectOp to `dest` that merges two
/// last-connects under `cond` as `mux(cond, trueSrc, falseSrc)`.
///  - If exactly one side is driven by an InvalidValueOp, the other side's
///    source is connected directly instead.
///  - If both or neither side is invalid, a mux is built (createOrFold may
///    fold it).
/// The mux uses `fusedLoc` (declared on omitted line 272), which fuses `loc`
/// with both connects' locations. The new connect itself uses `loc`. The two
/// input connects are not erased here; mergeScopes erases them.
267 ConnectOp flattenConditionalConnections(OpBuilder &b, Location loc,
268 Value dest, Value cond,
269 Operation *whenTrueConn,
270 Operation *whenFalseConn) {
271 assert(isLastConnect(whenTrueConn) && isLastConnect(whenFalseConn));
273 b.getFusedLoc({loc, whenTrueConn->getLoc(), whenFalseConn->getLoc()});
274 auto whenTrue = getConnectedValue(whenTrueConn);
276 isa_and_nonnull<InvalidValueOp>(whenTrue.getDefiningOp());
277 auto whenFalse = getConnectedValue(whenFalseConn);
278 auto falseIsInvalid =
279 isa_and_nonnull<InvalidValueOp>(whenFalse.getDefiningOp());
286 Value newValue = whenTrue;
287 if (trueIsInvalid == falseIsInvalid)
288 newValue = b.createOrFold<MuxPrimOp>(fusedLoc, cond, whenTrue, whenFalse);
289 else if (trueIsInvalid)
290 newValue = whenFalse;
291 return b.create<ConnectOp>(loc, dest, newValue);
/// Wires: the result's fields are declared through declareSinks with Duplex
/// flow.
294 void visitDecl(WireOp op) { declareSinks(op.getResult(),
Flow::Duplex); }
/// Recursively expands `value` through bundle fields (SubfieldOp) and vector
/// elements (SubindexOp), creating those ops with `builder`. It calls `fn` on
/// every value whose type is neither a bundle nor a vector.
// NOTE(review): the type-switch head (line 300) and the `subfield`/`subindex`
// declarations (lines 303, 310) are omitted from this listing.
298 void foreachSubelement(OpBuilder &builder, Value value,
299 llvm::function_ref<
void(Value)> fn) {
301 .template Case<BundleType>([&](BundleType bundle) {
302 for (
auto i : llvm::seq(0u, (
unsigned)bundle.getNumElements())) {
304 builder.create<SubfieldOp>(value.getLoc(), value, i);
305 foreachSubelement(builder, subfield, fn);
308 .
template Case<FVectorType>([&](FVectorType vector) {
309 for (
auto i : llvm::seq((
size_t)0, vector.getNumElements())) {
311 builder.create<SubindexOp>(value.getLoc(), value, i);
312 foreachSubelement(builder, subindex, fn);
315 .Default([&](
auto) { fn(value); });
/// Registers: right after the op, each leaf field gets a self-connect
/// (`field <= field`), built via foreachSubelement.
// NOTE(review): what is done with `connect` (lines 324-325) is omitted.
// Presumably it is recorded as the field's initial driver; confirm.
318 void visitDecl(RegOp op) {
321 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
322 auto fn = [&](Value value) {
323 auto connect = builder.create<ConnectOp>(value.getLoc(), value, value);
326 foreachSubelement(builder, op.getResult(), fn);
/// Registers with reset: handled the same way as RegOp (leaf self-connects
/// right after the op; use of `connect` omitted from this listing).
329 void visitDecl(RegResetOp op) {
332 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
333 auto fn = [&](Value value) {
334 auto connect = builder.create<ConnectOp>(value.getLoc(), value, value);
337 foreachSubelement(builder, op.getResult(), fn);
/// Instances: iterates over every result together with its index. The
/// per-result body (lines 344-348) is omitted from this listing.
340 void visitDecl(InstanceOp op) {
343 for (
const auto &result : llvm::enumerate(op.getResults()))
/// Objects: body omitted from this listing.
350 void visitDecl(ObjectOp op) {
/// Memories: each result whose type is not a RefType is processed. The call
/// (line 358) is omitted; presumably it is declareSinks.
354 void visitDecl(MemOp op) {
356 for (
auto result : op.getResults())
357 if (!isa<RefType>(result.getType()))
// Connect-like statements. Bodies are omitted from this listing; presumably
// each one records the connection via recordConnect.
361 void visitStmt(ConnectOp op) {
365 void visitStmt(MatchingConnectOp op) {
369 void visitStmt(RefDefineOp op) {
373 void visitStmt(PropAssignOp op) {
/// Flattens `whenOp` into its parent block. `outerCondition` is the condition
/// of the enclosing region; at module level a null Value is passed.
377 void processWhenOp(WhenOp whenOp, Value outerCondition);
/// Merges the drivers recorded in a `when`'s then/else scopes into the
/// enclosing scope of `driverMap`. `thenCondition` selects the then-side
/// value in generated muxes.
/// Then-scope entries:
///  - dest not found in any enclosing scope: the then-connect is adopted as-is.
///  - dest also driven in the else scope: both connects are replaced by one
///    mux-connect built at the else-connect.
///  - otherwise the then-connect is muxed with the outer driver (built at the
///    then-connect). An earlier branch (guard on omitted line 432) erases a
///    last-connect or keeps a static-single connect. Presumably it is taken
///    when the outer entry has no driver yet; confirm.
/// Else-scope entries not already handled in the then scope are merged the
/// same way against the outer driver, with the outer value on the true side.
// NOTE(review): the start of the signature (before line 396) and the
// `continue`/brace lines are omitted from this listing.
396 Value thenCondition) {
399 for (
auto &destAndConnect : thenScope) {
400 auto dest = std::get<0>(destAndConnect);
401 auto thenConnect = std::get<1>(destAndConnect);
// Look the destination up in every enclosing scope.
403 auto outerIt = driverMap.
find(dest);
404 if (outerIt == driverMap.
end()) {
407 driverMap[dest] = thenConnect;
// Check whether the else branch drives the same destination.
411 auto elseIt = elseScope.
find(dest);
412 if (elseIt != elseScope.end()) {
417 auto &elseConnect = std::get<1>(*elseIt);
418 OpBuilder connectBuilder(elseConnect);
419 auto newConnect = flattenConditionalConnections(
420 connectBuilder, loc, getDestinationValue(thenConnect),
421 thenCondition, thenConnect, elseConnect);
424 thenConnect->erase();
425 elseConnect->erase();
426 recordConnect(dest, newConnect);
431 auto &outerConnect = std::get<1>(*outerIt);
433 if (isLastConnect(thenConnect)) {
436 thenConnect->erase();
438 assert(isStaticSingleConnect(thenConnect));
439 driverMap[dest] = thenConnect;
// Driven only in the then branch: mux against the enclosing driver.
446 OpBuilder connectBuilder(thenConnect);
447 auto newConnect = flattenConditionalConnections(
448 connectBuilder, loc, getDestinationValue(thenConnect), thenCondition,
449 thenConnect, outerConnect);
452 thenConnect->erase();
453 recordConnect(dest, newConnect);
457 for (
auto &destAndConnect : elseScope) {
458 auto dest = std::get<0>(destAndConnect);
459 auto elseConnect = std::get<1>(destAndConnect);
// Destinations driven in both branches were already merged above.
463 if (thenScope.contains(dest))
466 auto outerIt = driverMap.
find(dest);
467 if (outerIt == driverMap.
end()) {
470 driverMap[dest] = elseConnect;
474 auto &outerConnect = std::get<1>(*outerIt);
476 if (isLastConnect(elseConnect)) {
479 elseConnect->erase();
481 assert(isStaticSingleConnect(elseConnect));
482 driverMap[dest] = elseConnect;
// Driven only in the else branch: the enclosing driver is the true side.
489 OpBuilder connectBuilder(elseConnect);
490 auto newConnect = flattenConditionalConnections(
491 connectBuilder, loc, getDestinationValue(outerConnect), thenCondition,
492 outerConnect, elseConnect);
495 elseConnect->erase();
496 recordConnect(dest, newConnect);
/// Visitor for the body of one branch of a `when`. Besides the inherited
/// last-connect tracking, it folds its `condition` into the enable,
/// condition, predicate or property operands of the simulation and
/// verification statements it visits. It also flattens nested `when`s under
/// that condition.
// NOTE(review): the `public:` section, the constructor signature (lines
// 511-513), and the member declarations (`condition` and the created*Ops
// caches) are omitted from this listing.
510 class WhenOpVisitor :
public LastConnectResolver<WhenOpVisitor> {
514 : LastConnectResolver<WhenOpVisitor>(driverMap), condition(condition) {}
// Re-export the base overloads so the overloads declared below do not hide
// them.
516 using LastConnectResolver<WhenOpVisitor>::visitExpr;
517 using LastConnectResolver<WhenOpVisitor>::visitDecl;
518 using LastConnectResolver<WhenOpVisitor>::visitStmt;
519 using LastConnectResolver<WhenOpVisitor>::visitStmtExpr;
/// Visits every operation of `block` (defined out of line).
522 void process(Block &block);
// Visitors that fold `condition` into the op (defined out of line).
525 void visitStmt(VerifAssertIntrinsicOp op);
526 void visitStmt(VerifAssumeIntrinsicOp op);
527 void visitStmt(VerifCoverIntrinsicOp op);
528 void visitStmt(AssertOp op);
529 void visitStmt(AssumeOp op);
530 void visitStmt(UnclockedAssumeIntrinsicOp op);
531 void visitStmt(CoverOp op);
532 void visitStmt(ModuleOp op);
533 void visitStmt(PrintFOp op);
534 void visitStmt(StopOp op);
535 void visitStmt(WhenOp op);
536 void visitStmt(LayerBlockOp op);
537 void visitStmt(RefForceOp op);
538 void visitStmt(RefForceInitialOp op);
539 void visitStmt(RefReleaseOp op);
540 void visitStmt(RefReleaseInitialOp op);
541 void visitStmtExpr(DPICallIntrinsicOp op);
/// Returns `condition & value` (an AndPrimOp, folded when possible), created
/// immediately before `op`.
546 Value andWithCondition(Operation *op, Value value) {
548 return OpBuilder(op).createOrFold<AndPrimOp>(
549 condition.getLoc(), condition.getType(), condition, value);
/// Returns a property that holds only together with `condition`:
/// `ltl.and(condition, property)`. NodeOps are looked through. The AND is
/// pushed beneath an LTLClockIntrinsicOp by recursing on the clock's input
/// and cloning the clock op with the new input. Created ops are stored in
/// createdLTLClockOps / createdLTLAndOps, keyed by their inputs. The reuse
/// guards (lines 563, 572) are omitted; presumably they skip creation when
/// a cached op exists.
552 Value ltlAndWithCondition(Operation *op, Value property) {
556 while (
auto nodeOp = property.getDefiningOp<NodeOp>())
557 property = nodeOp.getInput();
560 if (
auto clockOp = property.getDefiningOp<LTLClockIntrinsicOp>()) {
561 auto input = ltlAndWithCondition(op, clockOp.getInput());
562 auto &newClockOp = createdLTLClockOps[{clockOp, input}];
564 newClockOp = OpBuilder(op).cloneWithoutRegions(clockOp);
565 newClockOp.getInputMutable().assign(input);
571 auto &newOp = createdLTLAndOps[{condition,
property}];
573 newOp = OpBuilder(op).createOrFold<LTLAndIntrinsicOp>(
574 condition.getLoc(),
property.getType(), condition, property);
/// Returns a property of the form `condition |-> property`. NodeOps are
/// looked through, and clocks are handled by recursing on the clock's input
/// and cloning the clock op. If the property is already an implication, its
/// antecedent is strengthened to `condition && lhs` (via ltlAndWithCondition
/// on a clone of the implication) instead of nesting a second implication.
/// Created ops are stored in createdLTLClockOps / createdLTLImplicationOps.
/// The reuse guards are on lines omitted from this listing.
581 Value ltlImplicationWithCondition(Operation *op, Value property) {
583 while (
auto nodeOp = property.getDefiningOp<NodeOp>())
584 property = nodeOp.getInput();
587 if (
auto clockOp = property.getDefiningOp<LTLClockIntrinsicOp>()) {
588 auto input = ltlImplicationWithCondition(op, clockOp.getInput());
589 auto &newClockOp = createdLTLClockOps[{clockOp, input}];
591 newClockOp = OpBuilder(op).cloneWithoutRegions(clockOp);
592 newClockOp.getInputMutable().assign(input);
598 if (
auto implOp = property.getDefiningOp<LTLImplicationIntrinsicOp>()) {
599 auto lhs = ltlAndWithCondition(op, implOp.getLhs());
600 auto &newImplOp = createdLTLImplicationOps[{lhs, implOp.getRhs()}];
602 auto clonedOp = OpBuilder(op).cloneWithoutRegions(implOp);
603 clonedOp.getLhsMutable().assign(lhs);
604 newImplOp = clonedOp;
610 auto &newImplOp = createdLTLImplicationOps[{condition,
property}];
612 newImplOp = OpBuilder(op).createOrFold<LTLImplicationIntrinsicOp>(
613 condition.getLoc(),
property.getType(), condition, property);
/// Visits every operation of `block`. Early-increment iteration is used
/// because visitors may erase the current op (e.g. connects replaced while
/// merging `when` scopes) or move ops into the block.
633 void WhenOpVisitor::process(Block &block) {
634 for (
auto &op : llvm::make_early_inc_range(block)) {
635 dispatchVisitor(&op);
/// printf: the op's condition becomes `condition & cond`.
639 void WhenOpVisitor::visitStmt(PrintFOp op) {
640 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
/// stop: the op's condition becomes `condition & cond`.
643 void WhenOpVisitor::visitStmt(StopOp op) {
644 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
/// Assert intrinsic: the property becomes `condition |-> property`, so the
/// check only applies while the enclosing `when` condition holds.
647 void WhenOpVisitor::visitStmt(VerifAssertIntrinsicOp op) {
648 op.getPropertyMutable().assign(
649 ltlImplicationWithCondition(op, op.getProperty()));
/// Assume intrinsic: same treatment as the assert intrinsic.
652 void WhenOpVisitor::visitStmt(VerifAssumeIntrinsicOp op) {
653 op.getPropertyMutable().assign(
654 ltlImplicationWithCondition(op, op.getProperty()));
/// Cover intrinsic: the property becomes `condition && property`, so only
/// hits under the enclosing condition count.
657 void WhenOpVisitor::visitStmt(VerifCoverIntrinsicOp op) {
658 op.getPropertyMutable().assign(ltlAndWithCondition(op, op.getProperty()));
/// assert: the enable becomes `condition & enable`.
661 void WhenOpVisitor::visitStmt(AssertOp op) {
662 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
/// assume: the enable becomes `condition & enable`.
665 void WhenOpVisitor::visitStmt(AssumeOp op) {
666 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
/// Unclocked assume intrinsic: the enable becomes `condition & enable`.
669 void WhenOpVisitor::visitStmt(UnclockedAssumeIntrinsicOp op) {
670 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
/// cover: the enable becomes `condition & enable`.
673 void WhenOpVisitor::visitStmt(CoverOp op) {
674 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
/// Nested when: flattened with this visitor's condition as the enclosing
/// (outer) condition.
677 void WhenOpVisitor::visitStmt(WhenOp whenOp) {
678 processWhenOp(whenOp, condition);
/// Layer block: its body is processed by this visitor, under the same
/// condition and in the current driver scope (no new scope is pushed).
682 void WhenOpVisitor::visitStmt(LayerBlockOp layerBlockOp) {
683 process(*layerBlockOp.getBody());
/// ref.force: the predicate becomes `condition & predicate`.
686 void WhenOpVisitor::visitStmt(RefForceOp op) {
687 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
/// ref.force_initial: the predicate becomes `condition & predicate`.
690 void WhenOpVisitor::visitStmt(RefForceInitialOp op) {
691 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
/// ref.release: the predicate becomes `condition & predicate`.
694 void WhenOpVisitor::visitStmt(RefReleaseOp op) {
695 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
/// ref.release_initial: the predicate becomes `condition & predicate`.
698 void WhenOpVisitor::visitStmt(RefReleaseInitialOp op) {
699 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
/// DPI call: an existing enable is AND-ed with `condition`; otherwise
/// `condition` itself becomes the enable.
// NOTE(review): the `if (...)` / `else` lines (703, 705) selecting between the
// two assignments are omitted from this listing.
702 void WhenOpVisitor::visitStmtExpr(DPICallIntrinsicOp op) {
704 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
706 op.getEnableMutable().assign(condition);
/// Flattens `whenOp` into its parent block:
///  1. The then block is visited in a fresh driver scope under the then
///     condition (the when's condition, AND-ed with `outerCondition` on line
///     732; the guarding statement is omitted). Its ops are then spliced
///     before `whenOp` and the scope is popped.
///  2. If present, the else block is handled the same way under
///     `not(condition)`, AND-ed with `outerCondition`.
///  3. The two popped scopes are merged into the enclosing scope by
///     mergeScopes, with the when's own `condition` as the mux selector.
/// This is defined out of line, after WhenOpVisitor, because it constructs a
/// WhenOpVisitor for each branch.
// NOTE(review): the builder `b`, the `elseScope` declaration, and the removal
// of `whenOp` itself are on lines omitted from this listing.
714 template <
typename ConcreteT>
715 void LastConnectResolver<ConcreteT>::processWhenOp(WhenOp whenOp,
716 Value outerCondition) {
718 auto loc = whenOp.getLoc();
719 Block *parentBlock = whenOp->getBlock();
720 auto condition = whenOp.getCondition();
721 auto ui1Type = condition.getType();
729 Value thenCondition = whenOp.getCondition();
732 b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition, thenCondition);
// Then branch: new scope, visit under thenCondition, splice ops before the
// when, then pop the scope.
734 auto &thenBlock = whenOp.getThenBlock();
735 driverMap.pushScope();
736 WhenOpVisitor(driverMap, thenCondition).process(thenBlock);
737 mergeBlock(*parentBlock, Block::iterator(whenOp), thenBlock);
738 auto thenScope = driverMap.popScope();
// Else branch (optional): same steps under the negated condition.
742 if (whenOp.hasElseRegion()) {
745 b.createOrFold<NotPrimOp>(loc, condition.getType(), condition);
748 elseCondition = b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition,
750 auto &elseBlock = whenOp.getElseBlock();
751 driverMap.pushScope();
752 WhenOpVisitor(driverMap, elseCondition).process(elseBlock);
753 mergeBlock(*parentBlock, Block::iterator(whenOp), elseBlock);
754 elseScope = driverMap.popScope();
// The raw when condition (not thenCondition) selects between branch drivers.
// Presumably any enclosing condition is applied when the enclosing scope is
// itself merged.
757 mergeScopes(loc, thenScope, elseScope, condition);
/// Visitor for a whole module or class body. It supplies the driver map to
/// the base, declares port sinks, and flattens top-level `when`s.
// NOTE(review): access specifiers, the `driverMap` member declaration and the
// closing brace are omitted from this listing.
769 class ModuleVisitor :
public LastConnectResolver<ModuleVisitor> {
// NOTE(review): `driverMap` presumably names a ModuleVisitor member. Base
// classes are initialized before members, so this passes a reference to a
// not-yet-constructed object. That is only safe because the base constructor
// merely binds a reference to it.
771 ModuleVisitor() : LastConnectResolver<ModuleVisitor>(driverMap) {}
// Re-export the base overloads so the overloads declared below do not hide
// them.
773 using LastConnectResolver<ModuleVisitor>::visitExpr;
774 using LastConnectResolver<ModuleVisitor>::visitDecl;
775 using LastConnectResolver<ModuleVisitor>::visitStmt;
776 void visitStmt(WhenOp whenOp);
777 void visitStmt(ConnectOp connectOp);
778 void visitStmt(MatchingConnectOp connectOp);
779 void visitStmt(LayerBlockOp layerBlockOp);
/// Lowers the whens in `op`; returns `anythingChanged`.
781 bool run(FModuleLike op);
/// Reports sinks in the outermost scope that were never driven.
782 LogicalResult checkInitialization();
// Set once any top-level `when` has been processed.
789 bool anythingChanged =
false;
// NOTE(review): the start of ModuleVisitor::run (`bool ModuleVisitor::run(
// FModuleLike op) {`) and the port-flow computation on line 806 are omitted
// from this listing.
// Only FModuleOp and ClassOp bodies are lowered. Any other FModuleLike is
// left untouched and reported as unchanged.
798 if (!isa<FModuleOp, ClassOp>(op))
799 return anythingChanged;
// Fixed: the scraped listing read `auto ®ion` (an HTML-entity garble of
// `&reg`), which does not compile.
801 for (
auto &region : op->getRegions()) {
802 for (
auto &block : region.getBlocks()) {
// Each block argument is a port. Its sinks are declared with a flow derived
// from the port direction.
804 for (
const auto &[index, value] : llvm::enumerate(block.getArguments())) {
805 auto direction = op.getPortDirection(index);
807 declareSinks(value, flow);
// Early-increment iteration, because visitors may erase or move ops.
// Renamed from `op` so the loop variable no longer shadows the FModuleLike
// parameter.
811 for (
auto &bodyOp : llvm::make_early_inc_range(block))
812 dispatchVisitor(&bodyOp);
816 return anythingChanged;
// Module-level connect visitors. Bodies are omitted from this listing;
// presumably each one records the connection via recordConnect.
819 void ModuleVisitor::visitStmt(ConnectOp op) {
823 void ModuleVisitor::visitStmt(MatchingConnectOp op) {
/// Top-level `when`: marks the IR as changed and flattens the when with no
/// enclosing condition (a null Value).
827 void ModuleVisitor::visitStmt(WhenOp whenOp) {
829 anythingChanged =
true;
830 processWhenOp(whenOp, {});
/// Layer blocks do not open a new scope. Their ops are visited directly in
/// the module's scope, with early increment because visitors may erase or
/// move ops.
833 void ModuleVisitor::visitStmt(LayerBlockOp layerBlockOp) {
834 for (
auto &op : llvm::make_early_inc_range(*layerBlockOp.getBody())) {
835 dispatchVisitor(&op);
/// Walks the outermost driver scope and emits a "not fully initialized"
/// error for each entry it treats as undriven.
// NOTE(review): several lines are omitted from this listing: the filtering of
// entries (846-849, presumably skipping those with a non-null `connect`), the
// `definingOp` lookup (852), the head of the non-port error (857-859), and
// the return. Presumably it returns failure if any error was emitted; confirm.
841 LogicalResult ModuleVisitor::checkInitialization() {
843 for (
auto destAndConnect : driverMap.getLastScope()) {
845 auto *
connect = std::get<1>(destAndConnect);
850 FieldRef dest = std::get<0>(destAndConnect);
851 auto loc = dest.
getValue().getLoc();
// Ports (the defining op is a module) are reported against that module.
// Other declarations are reported against their enclosing module.
853 if (
auto mod = dyn_cast<FModuleLike>(definingOp))
854 mlir::emitError(loc) <<
"port \"" <<
getFieldName(dest).first
855 <<
"\" not fully initialized in \""
856 << mod.getModuleName() <<
"\"";
860 <<
"\" not fully initialized in \""
861 << definingOp->getParentOfType<FModuleLike>().getModuleName() <<
"\"";
/// The ExpandWhens pass, built on the tablegen-generated ExpandWhensBase.
874 class ExpandWhensPass
875 :
public circt::firrtl::impl::ExpandWhensBase<ExpandWhensPass> {
876 void runOnOperation()
override;
/// Runs a fresh ModuleVisitor on the operation. If nothing changed, all
/// analyses are marked preserved. Undriven sinks are then diagnosed.
// NOTE(review): the action taken when checkInitialization fails (line 885) is
// omitted from this listing; presumably it is signalPassFailure().
880 void ExpandWhensPass::runOnOperation() {
881 ModuleVisitor visitor;
882 if (!visitor.run(getOperation()))
883 markAllAnalysesPreserved();
884 if (failed(visitor.checkInitialization()))
/// Body of `std::unique_ptr<mlir::Pass> createExpandWhensPass()`, whose
/// signature is omitted from this listing: returns a new ExpandWhensPass.
889 return std::make_unique<ExpandWhensPass>();
assert(baseType &&"element must be base type")
ScopedDriverMap::ScopeT DriverMap
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Move all operations from a source block into a destination block.
This class represents a reference to a specific field or element of an aggregate value.
Value getValue() const
Get the Value which created this location.
Operation * getDefiningOp() const
Get the operation which defines this field.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
def connect(destination, source)
Flow swapFlow(Flow flow)
Get a flow's reverse.
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
std::unique_ptr< mlir::Pass > createExpandWhensPass()
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
std::pair< KeyT, ValueT > & operator*() const
bool operator!=(const Iterator &rhs) const
bool operator==(const Iterator &rhs) const
Iterator(typename StackT::iterator stackIt, typename ScopeT::iterator scopeIt)
This is a stack of hashtables: if lookup fails in the top-most hashtable, it will attempt to look up the key in the hashtables beneath it.
typename llvm::MapVector< KeyT, ValueT > ScopeT
iterator find(const KeyT &key)
ValueT & operator[](const KeyT &key)
typename llvm::SmallVector< ScopeT, 3 > StackT