19#include "mlir/Pass/Pass.h"
20#include "llvm/ADT/MapVector.h"
21#include "llvm/ADT/STLExtras.h"
25#define GEN_PASS_DEF_EXPANDWHENS
26#include "circt/Dialect/FIRRTL/Passes.h.inc"
31using namespace firrtl;
35static void mergeBlock(Block &destination, Block::iterator insertPoint,
37 destination.getOperations().splice(insertPoint, source.getOperations());
45template <
typename KeyT,
typename ValueT>
47 using ScopeT =
typename llvm::MapVector<KeyT, ValueT>;
48 using StackT =
typename llvm::SmallVector<ScopeT, 3>;
52 typename ScopeT::iterator
scopeIt)
90 for (
auto i =
mapStack.size(); i > 0; --i) {
92 auto it = map.find(key);
127template <
typename ConcreteT>
136 LastConnectResolver(
ScopedDriverMap &driverMap) : driverMap(driverMap) {}
147 bool recordConnect(
FieldRef dest, Operation *connection) {
149 auto itAndInserted = driverMap.getLastScope().insert({dest, connection});
150 if (isStaticSingleConnect(connection)) {
152 assert(itAndInserted.second || !itAndInserted.first->second);
153 if (!itAndInserted.second)
154 itAndInserted.first->second = connection;
157 assert(isLastConnect(connection));
158 if (!std::get<1>(itAndInserted)) {
159 auto iterator = std::get<0>(itAndInserted);
160 auto changed =
false;
163 if (
auto *oldConnect = iterator->second) {
167 iterator->second = connection;
175 static Value getDestinationValue(Operation *op) {
176 return cast<FConnectLike>(op).getDest();
181 static Value getConnectedValue(Operation *op) {
182 return cast<FConnectLike>(op).getSrc();
186 static bool isStaticSingleConnect(Operation *op) {
187 return cast<FConnectLike>(op).hasStaticSingleConnectBehavior();
194 static bool isLastConnect(Operation *op) {
195 return cast<FConnectLike>(op).hasLastConnectBehavior();
200 void declareSinks(Value value,
Flow flow,
bool local =
false) {
201 auto type = value.getType();
205 std::function<void(Type,
Flow,
bool)> declare = [&](Type type,
Flow flow,
208 if (
auto classType = type_dyn_cast<ClassType>(type)) {
212 for (
auto &element : classType.getElements()) {
214 if (element.direction == Direction::Out)
215 declare(element.type, flow,
false);
217 declare(element.type,
swapFlow(flow),
false);
222 if (flow != Flow::Source)
223 driverMap[{value,
id}] =
nullptr;
226 id += classType.getMaxFieldID();
232 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
233 for (
auto &element : bundleType.getElements()) {
236 declare(element.type,
swapFlow(flow),
false);
238 declare(element.type, flow,
false);
244 if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
245 for (
unsigned i = 0; i < vectorType.getNumElements(); ++i) {
247 declare(vectorType.getElementType(), flow,
false);
253 if (
auto analogType = type_dyn_cast<AnalogType>(type))
258 if (flow != Flow::Source)
259 driverMap[{value,
id}] =
nullptr;
262 declare(type, flow, local);
267 ConnectOp flattenConditionalConnections(OpBuilder &b, Location loc,
268 Value dest, Value cond,
269 Operation *whenTrueConn,
270 Operation *whenFalseConn) {
271 assert(isLastConnect(whenTrueConn) && isLastConnect(whenFalseConn));
273 b.getFusedLoc({loc, whenTrueConn->getLoc(), whenFalseConn->getLoc()});
274 auto whenTrue = getConnectedValue(whenTrueConn);
276 isa_and_nonnull<InvalidValueOp>(whenTrue.getDefiningOp());
277 auto whenFalse = getConnectedValue(whenFalseConn);
278 auto falseIsInvalid =
279 isa_and_nonnull<InvalidValueOp>(whenFalse.getDefiningOp());
286 Value newValue = whenTrue;
287 if (trueIsInvalid == falseIsInvalid)
288 newValue = b.createOrFold<MuxPrimOp>(fusedLoc, cond, whenTrue, whenFalse);
289 else if (trueIsInvalid)
290 newValue = whenFalse;
291 return ConnectOp::create(b, loc, dest, newValue);
294 void visitDecl(WireOp op) { declareSinks(op.getResult(), Flow::Duplex); }
298 void foreachSubelement(OpBuilder &builder, Value value,
299 llvm::function_ref<
void(Value)> fn) {
301 .template Case<BundleType>([&](BundleType bundle) {
302 for (
auto i : llvm::seq(0u, (
unsigned)bundle.getNumElements())) {
304 SubfieldOp::create(builder, value.getLoc(), value, i);
305 foreachSubelement(builder, subfield, fn);
308 .
template Case<FVectorType>([&](FVectorType vector) {
309 for (
auto i : llvm::seq((
size_t)0, vector.getNumElements())) {
311 SubindexOp::create(builder, value.getLoc(), value, i);
312 foreachSubelement(builder, subindex, fn);
315 .Default([&](
auto) { fn(value); });
318 void visitDecl(RegOp op) {
321 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
322 auto fn = [&](Value value) {
323 auto connect = ConnectOp::create(builder, value.getLoc(), value, value);
326 foreachSubelement(builder, op.getResult(), fn);
329 void visitDecl(RegResetOp op) {
332 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
333 auto fn = [&](Value value) {
334 auto connect = ConnectOp::create(builder, value.getLoc(), value, value);
337 foreachSubelement(builder, op.getResult(), fn);
340 void visitDecl(InstanceOp op) {
343 for (
const auto &result : llvm::enumerate(op.getResults()))
347 declareSinks(result.value(), Flow::Sink);
350 void visitDecl(ObjectOp op) {
351 declareSinks(op, Flow::Source,
true);
354 void visitDecl(MemOp op) {
356 for (
auto result : op.getResults())
357 if (!isa<RefType>(result.getType()))
358 declareSinks(result, Flow::Sink);
361 void visitStmt(ConnectOp op) {
365 void visitStmt(MatchingConnectOp op) {
369 void visitStmt(RefDefineOp op) {
373 void visitStmt(PropAssignOp op) {
377 void visitStmt(DomainDefineOp op) {
381 void processWhenOp(WhenOp whenOp, Value outerCondition);
400 Value thenCondition) {
403 for (
auto &destAndConnect : thenScope) {
404 auto dest = std::get<0>(destAndConnect);
405 auto thenConnect = std::get<1>(destAndConnect);
407 auto outerIt = driverMap.find(dest);
408 if (outerIt == driverMap.end()) {
411 driverMap[dest] = thenConnect;
415 auto elseIt = elseScope.find(dest);
416 if (elseIt != elseScope.end()) {
421 auto &elseConnect = std::get<1>(*elseIt);
422 OpBuilder connectBuilder(elseConnect);
423 auto newConnect = flattenConditionalConnections(
424 connectBuilder, loc, getDestinationValue(thenConnect),
425 thenCondition, thenConnect, elseConnect);
428 thenConnect->erase();
429 elseConnect->erase();
430 recordConnect(dest, newConnect);
435 auto &outerConnect = std::get<1>(*outerIt);
437 if (isLastConnect(thenConnect)) {
440 thenConnect->erase();
442 assert(isStaticSingleConnect(thenConnect));
443 driverMap[dest] = thenConnect;
450 OpBuilder connectBuilder(thenConnect);
451 auto newConnect = flattenConditionalConnections(
452 connectBuilder, loc, getDestinationValue(thenConnect), thenCondition,
453 thenConnect, outerConnect);
456 thenConnect->erase();
457 recordConnect(dest, newConnect);
461 for (
auto &destAndConnect : elseScope) {
462 auto dest = std::get<0>(destAndConnect);
463 auto elseConnect = std::get<1>(destAndConnect);
467 if (thenScope.contains(dest))
470 auto outerIt = driverMap.find(dest);
471 if (outerIt == driverMap.end()) {
474 driverMap[dest] = elseConnect;
478 auto &outerConnect = std::get<1>(*outerIt);
480 if (isLastConnect(elseConnect)) {
483 elseConnect->erase();
485 assert(isStaticSingleConnect(elseConnect));
486 driverMap[dest] = elseConnect;
493 OpBuilder connectBuilder(elseConnect);
494 auto newConnect = flattenConditionalConnections(
495 connectBuilder, loc, getDestinationValue(outerConnect), thenCondition,
496 outerConnect, elseConnect);
499 elseConnect->erase();
500 recordConnect(dest, newConnect);
514class WhenOpVisitor :
public LastConnectResolver<WhenOpVisitor> {
518 : LastConnectResolver<WhenOpVisitor>(driverMap), condition(condition) {}
520 using LastConnectResolver<WhenOpVisitor>::visitExpr;
521 using LastConnectResolver<WhenOpVisitor>::visitDecl;
522 using LastConnectResolver<WhenOpVisitor>::visitStmt;
523 using LastConnectResolver<WhenOpVisitor>::visitStmtExpr;
526 void process(Block &block);
529 void visitStmt(VerifAssertIntrinsicOp op);
530 void visitStmt(VerifAssumeIntrinsicOp op);
531 void visitStmt(VerifCoverIntrinsicOp op);
532 void visitStmt(AssertOp op);
533 void visitStmt(AssumeOp op);
534 void visitStmt(UnclockedAssumeIntrinsicOp op);
535 void visitStmt(CoverOp op);
536 void visitStmt(ModuleOp op);
537 void visitStmt(PrintFOp op);
538 void visitStmt(FPrintFOp op);
539 void visitStmt(FFlushOp op);
540 void visitStmt(StopOp op);
541 void visitStmt(WhenOp op);
542 void visitStmt(LayerBlockOp op);
543 void visitStmt(RefForceOp op);
544 void visitStmt(RefForceInitialOp op);
545 void visitStmt(RefReleaseOp op);
546 void visitStmt(RefReleaseInitialOp op);
547 void visitStmtExpr(DPICallIntrinsicOp op);
552 Value andWithCondition(Operation *op, Value value) {
554 return OpBuilder(op).createOrFold<AndPrimOp>(
555 condition.getLoc(), condition.getType(), condition, value);
560 Value ltlAndWithCondition(Operation *op, Value property) {
562 while (
auto nodeOp = property.getDefiningOp<NodeOp>())
563 property = nodeOp.getInput();
566 if (
auto clockOp = property.getDefiningOp<LTLClockIntrinsicOp>()) {
567 auto input = ltlAndWithCondition(op, clockOp.getInput());
568 auto &newClockOp = createdLTLClockOps[{clockOp, input}];
570 newClockOp = OpBuilder(op).cloneWithoutRegions(clockOp);
571 newClockOp.getInputMutable().assign(input);
577 auto &newOp = createdLTLAndOps[{condition,
property}];
579 newOp = OpBuilder(op).createOrFold<LTLAndIntrinsicOp>(
580 condition.getLoc(),
property.getType(), condition, property);
587 Value ltlImplicationWithCondition(Operation *op, Value property) {
589 while (
auto nodeOp = property.getDefiningOp<NodeOp>())
590 property = nodeOp.getInput();
593 if (
auto clockOp = property.getDefiningOp<LTLClockIntrinsicOp>()) {
594 auto input = ltlImplicationWithCondition(op, clockOp.getInput());
595 auto &newClockOp = createdLTLClockOps[{clockOp, input}];
597 newClockOp = OpBuilder(op).cloneWithoutRegions(clockOp);
598 newClockOp.getInputMutable().assign(input);
604 if (
auto implOp = property.getDefiningOp<LTLImplicationIntrinsicOp>()) {
605 auto lhs = ltlAndWithCondition(op, implOp.getLhs());
606 auto &newImplOp = createdLTLImplicationOps[{lhs, implOp.getRhs()}];
608 auto clonedOp = OpBuilder(op).cloneWithoutRegions(implOp);
609 clonedOp.getLhsMutable().assign(lhs);
610 newImplOp = clonedOp;
616 auto &newImplOp = createdLTLImplicationOps[{condition,
property}];
618 newImplOp = OpBuilder(op).createOrFold<LTLImplicationIntrinsicOp>(
619 condition.getLoc(),
property.getType(), condition, property);
639void WhenOpVisitor::process(Block &block) {
640 for (
auto &op :
llvm::make_early_inc_range(block)) {
641 dispatchVisitor(&op);
645void WhenOpVisitor::visitStmt(PrintFOp op) {
646 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
649void WhenOpVisitor::visitStmt(FPrintFOp op) {
650 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
653void WhenOpVisitor::visitStmt(FFlushOp op) {
654 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
657void WhenOpVisitor::visitStmt(StopOp op) {
658 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
661void WhenOpVisitor::visitStmt(VerifAssertIntrinsicOp op) {
662 op.getPropertyMutable().assign(
663 ltlImplicationWithCondition(op, op.getProperty()));
666void WhenOpVisitor::visitStmt(VerifAssumeIntrinsicOp op) {
667 op.getPropertyMutable().assign(
668 ltlImplicationWithCondition(op, op.getProperty()));
671void WhenOpVisitor::visitStmt(VerifCoverIntrinsicOp op) {
672 op.getPropertyMutable().assign(ltlAndWithCondition(op, op.getProperty()));
675void WhenOpVisitor::visitStmt(AssertOp op) {
676 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
679void WhenOpVisitor::visitStmt(AssumeOp op) {
680 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
683void WhenOpVisitor::visitStmt(UnclockedAssumeIntrinsicOp op) {
684 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
687void WhenOpVisitor::visitStmt(CoverOp op) {
688 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
691void WhenOpVisitor::visitStmt(WhenOp whenOp) {
692 processWhenOp(whenOp, condition);
696void WhenOpVisitor::visitStmt(LayerBlockOp layerBlockOp) {
697 process(*layerBlockOp.getBody());
700void WhenOpVisitor::visitStmt(RefForceOp op) {
701 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
704void WhenOpVisitor::visitStmt(RefForceInitialOp op) {
705 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
708void WhenOpVisitor::visitStmt(RefReleaseOp op) {
709 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
712void WhenOpVisitor::visitStmt(RefReleaseInitialOp op) {
713 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
716void WhenOpVisitor::visitStmtExpr(DPICallIntrinsicOp op) {
718 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
720 op.getEnableMutable().assign(condition);
728template <
typename ConcreteT>
729void LastConnectResolver<ConcreteT>::processWhenOp(WhenOp whenOp,
730 Value outerCondition) {
732 auto loc = whenOp.getLoc();
733 Block *parentBlock = whenOp->getBlock();
734 auto condition = whenOp.getCondition();
735 auto ui1Type = condition.getType();
743 Value thenCondition = whenOp.getCondition();
746 b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition, thenCondition);
748 auto &thenBlock = whenOp.getThenBlock();
749 driverMap.pushScope();
750 WhenOpVisitor(driverMap, thenCondition).process(thenBlock);
751 mergeBlock(*parentBlock, Block::iterator(whenOp), thenBlock);
752 auto thenScope = driverMap.popScope();
756 if (whenOp.hasElseRegion()) {
759 b.createOrFold<NotPrimOp>(loc, condition.getType(), condition);
762 elseCondition = b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition,
764 auto &elseBlock = whenOp.getElseBlock();
765 driverMap.pushScope();
766 WhenOpVisitor(driverMap, elseCondition).process(elseBlock);
767 mergeBlock(*parentBlock, Block::iterator(whenOp), elseBlock);
768 elseScope = driverMap.popScope();
771 mergeScopes(loc, thenScope, elseScope, condition);
783class ModuleVisitor :
public LastConnectResolver<ModuleVisitor> {
785 ModuleVisitor() : LastConnectResolver<ModuleVisitor>(driverMap) {}
787 using LastConnectResolver<ModuleVisitor>::visitExpr;
788 using LastConnectResolver<ModuleVisitor>::visitDecl;
789 using LastConnectResolver<ModuleVisitor>::visitStmt;
790 void visitStmt(WhenOp whenOp);
791 void visitStmt(ConnectOp connectOp);
792 void visitStmt(MatchingConnectOp connectOp);
793 void visitStmt(LayerBlockOp layerBlockOp);
795 bool run(FModuleLike op);
796 LogicalResult checkInitialization();
803 bool anythingChanged =
false;
810bool ModuleVisitor::run(FModuleLike op) {
812 if (!isa<FModuleOp, ClassOp>(op))
813 return anythingChanged;
815 for (
auto ®ion : op->getRegions()) {
816 for (
auto &block : region.getBlocks()) {
818 for (
const auto &[index, value] :
llvm::enumerate(block.getArguments())) {
819 auto direction = op.getPortDirection(index);
820 auto flow = direction == Direction::In ? Flow::Source : Flow::Sink;
821 declareSinks(value, flow);
825 for (
auto &op :
llvm::make_early_inc_range(block))
826 dispatchVisitor(&op);
830 return anythingChanged;
833void ModuleVisitor::visitStmt(ConnectOp op) {
837void ModuleVisitor::visitStmt(MatchingConnectOp op) {
841void ModuleVisitor::visitStmt(WhenOp whenOp) {
843 anythingChanged =
true;
844 processWhenOp(whenOp, {});
847void ModuleVisitor::visitStmt(LayerBlockOp layerBlockOp) {
848 for (
auto &op :
llvm::make_early_inc_range(*layerBlockOp.getBody())) {
849 dispatchVisitor(&op);
855LogicalResult ModuleVisitor::checkInitialization() {
857 for (
auto destAndConnect : driverMap.getLastScope()) {
859 auto *
connect = std::get<1>(destAndConnect);
864 FieldRef dest = std::get<0>(destAndConnect);
865 auto loc = dest.
getValue().getLoc();
867 if (
auto mod = dyn_cast<FModuleLike>(definingOp))
868 mlir::emitError(loc) <<
"port \"" <<
getFieldName(dest).first
869 <<
"\" not fully initialized in \""
870 << mod.getModuleName() <<
"\"";
874 <<
"\" not fully initialized in \""
875 << definingOp->getParentOfType<FModuleLike>().getModuleName() <<
"\"";
889 :
public circt::firrtl::impl::ExpandWhensBase<ExpandWhensPass> {
890 void runOnOperation()
override;
894void ExpandWhensPass::runOnOperation() {
895 ModuleVisitor visitor;
896 if (!visitor.run(getOperation()))
897 markAllAnalysesPreserved();
898 if (failed(visitor.checkInitialization()))
assert(baseType &&"element must be base type")
ScopedDriverMap::ScopeT DriverMap
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Move all operations from a source block into a destination block.
HashTableStack< FieldRef, Operation * > ScopedDriverMap
This is a deterministic mapping of a FieldRef to the last operation which set a value to it.
This class represents a reference to a specific field or element of an aggregate value.
Value getValue() const
Get the Value which created this location.
Operation * getDefiningOp() const
Get the operation which defines this field.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
connect(destination, source)
Flow swapFlow(Flow flow)
Get a flow's reverse.
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
bool operator!=(const Iterator &rhs) const
bool operator==(const Iterator &rhs) const
std::pair< KeyT, ValueT > & operator*() const
Iterator(typename StackT::iterator stackIt, typename ScopeT::iterator scopeIt)
This is a stack of hashtables, if lookup fails in the top-most hashtable, it will attempt to lookup i...
typename llvm::MapVector< KeyT, ValueT > ScopeT
iterator find(const KeyT &key)
typename llvm::SmallVector< ScopeT, 3 > StackT
ValueT & operator[](const KeyT &key)