19#include "mlir/Pass/Pass.h"
20#include "llvm/ADT/MapVector.h"
21#include "llvm/ADT/STLExtras.h"
25#define GEN_PASS_DEF_EXPANDWHENS
26#include "circt/Dialect/FIRRTL/Passes.h.inc"
31using namespace firrtl;
35static void mergeBlock(Block &destination, Block::iterator insertPoint,
37 destination.getOperations().splice(insertPoint, source.getOperations());
35static void mergeBlock(Block &destination, Block::iterator insertPoint, {
…}
45template <
typename KeyT,
typename ValueT>
47 using ScopeT =
typename llvm::MapVector<KeyT, ValueT>;
48 using StackT =
typename llvm::SmallVector<ScopeT, 3>;
52 typename ScopeT::iterator
scopeIt)
90 for (
auto i =
mapStack.size(); i > 0; --i) {
92 auto it = map.find(key);
127template <
typename ConcreteT>
136 LastConnectResolver(
ScopedDriverMap &driverMap) : driverMap(driverMap) {}
147 bool recordConnect(
FieldRef dest, Operation *connection) {
149 auto itAndInserted = driverMap.getLastScope().insert({dest, connection});
150 if (isStaticSingleConnect(connection)) {
152 assert(itAndInserted.second || !itAndInserted.first->second);
153 if (!itAndInserted.second)
154 itAndInserted.first->second = connection;
157 assert(isLastConnect(connection));
158 if (!std::get<1>(itAndInserted)) {
159 auto iterator = std::get<0>(itAndInserted);
160 auto changed =
false;
163 if (
auto *oldConnect = iterator->second) {
167 iterator->second = connection;
175 static Value getDestinationValue(Operation *op) {
176 return cast<FConnectLike>(op).getDest();
181 static Value getConnectedValue(Operation *op) {
182 return cast<FConnectLike>(op).getSrc();
186 static bool isStaticSingleConnect(Operation *op) {
187 return cast<FConnectLike>(op).hasStaticSingleConnectBehavior();
194 static bool isLastConnect(Operation *op) {
195 return cast<FConnectLike>(op).hasLastConnectBehavior();
200 void declareSinks(Value value,
Flow flow,
bool local =
false) {
201 auto type = value.getType();
205 std::function<void(Type,
Flow,
bool)> declare = [&](Type type,
Flow flow,
208 if (
auto classType = type_dyn_cast<ClassType>(type)) {
212 for (
auto &element : classType.getElements()) {
214 if (element.direction == Direction::Out)
215 declare(element.type, flow,
false);
217 declare(element.type,
swapFlow(flow),
false);
222 if (flow != Flow::Source)
223 driverMap[{value,
id}] =
nullptr;
226 id += classType.getMaxFieldID();
232 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
233 for (
auto &element : bundleType.getElements()) {
236 declare(element.type,
swapFlow(flow),
false);
238 declare(element.type, flow,
false);
244 if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
245 for (
unsigned i = 0; i < vectorType.getNumElements(); ++i) {
247 declare(vectorType.getElementType(), flow,
false);
253 if (
auto analogType = type_dyn_cast<AnalogType>(type))
258 if (flow != Flow::Source)
259 driverMap[{value,
id}] =
nullptr;
262 declare(type, flow, local);
267 ConnectOp flattenConditionalConnections(OpBuilder &b, Location loc,
268 Value dest, Value cond,
269 Operation *whenTrueConn,
270 Operation *whenFalseConn) {
271 assert(isLastConnect(whenTrueConn) && isLastConnect(whenFalseConn));
273 b.getFusedLoc({loc, whenTrueConn->getLoc(), whenFalseConn->getLoc()});
274 auto whenTrue = getConnectedValue(whenTrueConn);
276 isa_and_nonnull<InvalidValueOp>(whenTrue.getDefiningOp());
277 auto whenFalse = getConnectedValue(whenFalseConn);
278 auto falseIsInvalid =
279 isa_and_nonnull<InvalidValueOp>(whenFalse.getDefiningOp());
286 Value newValue = whenTrue;
287 if (trueIsInvalid == falseIsInvalid)
288 newValue = b.createOrFold<MuxPrimOp>(fusedLoc, cond, whenTrue, whenFalse);
289 else if (trueIsInvalid)
290 newValue = whenFalse;
291 return b.create<ConnectOp>(loc, dest, newValue);
294 void visitDecl(WireOp op) { declareSinks(op.getResult(), Flow::Duplex); }
298 void foreachSubelement(OpBuilder &builder, Value value,
299 llvm::function_ref<
void(Value)> fn) {
301 .template Case<BundleType>([&](BundleType bundle) {
302 for (
auto i : llvm::seq(0u, (
unsigned)bundle.getNumElements())) {
304 builder.create<SubfieldOp>(value.getLoc(), value, i);
305 foreachSubelement(builder, subfield, fn);
308 .
template Case<FVectorType>([&](FVectorType vector) {
309 for (
auto i : llvm::seq((
size_t)0, vector.getNumElements())) {
311 builder.create<SubindexOp>(value.getLoc(), value, i);
312 foreachSubelement(builder, subindex, fn);
315 .Default([&](
auto) { fn(value); });
318 void visitDecl(RegOp op) {
321 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
322 auto fn = [&](Value value) {
323 auto connect = builder.create<ConnectOp>(value.getLoc(), value, value);
326 foreachSubelement(builder, op.getResult(), fn);
329 void visitDecl(RegResetOp op) {
332 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
333 auto fn = [&](Value value) {
334 auto connect = builder.create<ConnectOp>(value.getLoc(), value, value);
337 foreachSubelement(builder, op.getResult(), fn);
340 void visitDecl(InstanceOp op) {
343 for (
const auto &result : llvm::enumerate(op.getResults()))
347 declareSinks(result.value(), Flow::Sink);
350 void visitDecl(ObjectOp op) {
351 declareSinks(op, Flow::Source,
true);
354 void visitDecl(MemOp op) {
356 for (
auto result : op.getResults())
357 if (!isa<RefType>(result.getType()))
358 declareSinks(result, Flow::Sink);
361 void visitStmt(ConnectOp op) {
365 void visitStmt(MatchingConnectOp op) {
369 void visitStmt(RefDefineOp op) {
373 void visitStmt(PropAssignOp op) {
377 void processWhenOp(WhenOp whenOp, Value outerCondition);
396 Value thenCondition) {
399 for (
auto &destAndConnect : thenScope) {
400 auto dest = std::get<0>(destAndConnect);
401 auto thenConnect = std::get<1>(destAndConnect);
403 auto outerIt = driverMap.find(dest);
404 if (outerIt == driverMap.end()) {
407 driverMap[dest] = thenConnect;
411 auto elseIt = elseScope.find(dest);
412 if (elseIt != elseScope.end()) {
417 auto &elseConnect = std::get<1>(*elseIt);
418 OpBuilder connectBuilder(elseConnect);
419 auto newConnect = flattenConditionalConnections(
420 connectBuilder, loc, getDestinationValue(thenConnect),
421 thenCondition, thenConnect, elseConnect);
424 thenConnect->erase();
425 elseConnect->erase();
426 recordConnect(dest, newConnect);
431 auto &outerConnect = std::get<1>(*outerIt);
433 if (isLastConnect(thenConnect)) {
436 thenConnect->erase();
438 assert(isStaticSingleConnect(thenConnect));
439 driverMap[dest] = thenConnect;
446 OpBuilder connectBuilder(thenConnect);
447 auto newConnect = flattenConditionalConnections(
448 connectBuilder, loc, getDestinationValue(thenConnect), thenCondition,
449 thenConnect, outerConnect);
452 thenConnect->erase();
453 recordConnect(dest, newConnect);
457 for (
auto &destAndConnect : elseScope) {
458 auto dest = std::get<0>(destAndConnect);
459 auto elseConnect = std::get<1>(destAndConnect);
463 if (thenScope.contains(dest))
466 auto outerIt = driverMap.find(dest);
467 if (outerIt == driverMap.end()) {
470 driverMap[dest] = elseConnect;
474 auto &outerConnect = std::get<1>(*outerIt);
476 if (isLastConnect(elseConnect)) {
479 elseConnect->erase();
481 assert(isStaticSingleConnect(elseConnect));
482 driverMap[dest] = elseConnect;
489 OpBuilder connectBuilder(elseConnect);
490 auto newConnect = flattenConditionalConnections(
491 connectBuilder, loc, getDestinationValue(outerConnect), thenCondition,
492 outerConnect, elseConnect);
495 elseConnect->erase();
496 recordConnect(dest, newConnect);
510class WhenOpVisitor :
public LastConnectResolver<WhenOpVisitor> {
514 : LastConnectResolver<WhenOpVisitor>(driverMap), condition(condition) {}
516 using LastConnectResolver<WhenOpVisitor>::visitExpr;
517 using LastConnectResolver<WhenOpVisitor>::visitDecl;
518 using LastConnectResolver<WhenOpVisitor>::visitStmt;
519 using LastConnectResolver<WhenOpVisitor>::visitStmtExpr;
522 void process(Block &block);
525 void visitStmt(VerifAssertIntrinsicOp op);
526 void visitStmt(VerifAssumeIntrinsicOp op);
527 void visitStmt(VerifCoverIntrinsicOp op);
528 void visitStmt(AssertOp op);
529 void visitStmt(AssumeOp op);
530 void visitStmt(UnclockedAssumeIntrinsicOp op);
531 void visitStmt(CoverOp op);
532 void visitStmt(ModuleOp op);
533 void visitStmt(PrintFOp op);
534 void visitStmt(FPrintFOp op);
535 void visitStmt(FFlushOp op);
536 void visitStmt(StopOp op);
537 void visitStmt(WhenOp op);
538 void visitStmt(LayerBlockOp op);
539 void visitStmt(RefForceOp op);
540 void visitStmt(RefForceInitialOp op);
541 void visitStmt(RefReleaseOp op);
542 void visitStmt(RefReleaseInitialOp op);
543 void visitStmtExpr(DPICallIntrinsicOp op);
548 Value andWithCondition(Operation *op, Value value) {
550 return OpBuilder(op).createOrFold<AndPrimOp>(
551 condition.getLoc(), condition.getType(), condition, value);
556 Value ltlAndWithCondition(Operation *op, Value property) {
558 while (
auto nodeOp = property.getDefiningOp<NodeOp>())
559 property = nodeOp.getInput();
562 if (
auto clockOp = property.getDefiningOp<LTLClockIntrinsicOp>()) {
563 auto input = ltlAndWithCondition(op, clockOp.getInput());
564 auto &newClockOp = createdLTLClockOps[{clockOp, input}];
566 newClockOp = OpBuilder(op).cloneWithoutRegions(clockOp);
567 newClockOp.getInputMutable().assign(input);
573 auto &newOp = createdLTLAndOps[{condition,
property}];
575 newOp = OpBuilder(op).createOrFold<LTLAndIntrinsicOp>(
576 condition.getLoc(),
property.getType(), condition, property);
583 Value ltlImplicationWithCondition(Operation *op, Value property) {
585 while (
auto nodeOp = property.getDefiningOp<NodeOp>())
586 property = nodeOp.getInput();
589 if (
auto clockOp = property.getDefiningOp<LTLClockIntrinsicOp>()) {
590 auto input = ltlImplicationWithCondition(op, clockOp.getInput());
591 auto &newClockOp = createdLTLClockOps[{clockOp, input}];
593 newClockOp = OpBuilder(op).cloneWithoutRegions(clockOp);
594 newClockOp.getInputMutable().assign(input);
600 if (
auto implOp = property.getDefiningOp<LTLImplicationIntrinsicOp>()) {
601 auto lhs = ltlAndWithCondition(op, implOp.getLhs());
602 auto &newImplOp = createdLTLImplicationOps[{lhs, implOp.getRhs()}];
604 auto clonedOp = OpBuilder(op).cloneWithoutRegions(implOp);
605 clonedOp.getLhsMutable().assign(lhs);
606 newImplOp = clonedOp;
612 auto &newImplOp = createdLTLImplicationOps[{condition,
property}];
614 newImplOp = OpBuilder(op).createOrFold<LTLImplicationIntrinsicOp>(
615 condition.getLoc(),
property.getType(), condition, property);
635void WhenOpVisitor::process(Block &block) {
636 for (
auto &op :
llvm::make_early_inc_range(block)) {
637 dispatchVisitor(&op);
641void WhenOpVisitor::visitStmt(PrintFOp op) {
642 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
645void WhenOpVisitor::visitStmt(FPrintFOp op) {
646 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
649void WhenOpVisitor::visitStmt(FFlushOp op) {
650 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
653void WhenOpVisitor::visitStmt(StopOp op) {
654 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
657void WhenOpVisitor::visitStmt(VerifAssertIntrinsicOp op) {
658 op.getPropertyMutable().assign(
659 ltlImplicationWithCondition(op, op.getProperty()));
662void WhenOpVisitor::visitStmt(VerifAssumeIntrinsicOp op) {
663 op.getPropertyMutable().assign(
664 ltlImplicationWithCondition(op, op.getProperty()));
667void WhenOpVisitor::visitStmt(VerifCoverIntrinsicOp op) {
668 op.getPropertyMutable().assign(ltlAndWithCondition(op, op.getProperty()));
671void WhenOpVisitor::visitStmt(AssertOp op) {
672 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
675void WhenOpVisitor::visitStmt(AssumeOp op) {
676 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
679void WhenOpVisitor::visitStmt(UnclockedAssumeIntrinsicOp op) {
680 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
683void WhenOpVisitor::visitStmt(CoverOp op) {
684 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
687void WhenOpVisitor::visitStmt(WhenOp whenOp) {
688 processWhenOp(whenOp, condition);
692void WhenOpVisitor::visitStmt(LayerBlockOp layerBlockOp) {
693 process(*layerBlockOp.getBody());
696void WhenOpVisitor::visitStmt(RefForceOp op) {
697 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
700void WhenOpVisitor::visitStmt(RefForceInitialOp op) {
701 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
704void WhenOpVisitor::visitStmt(RefReleaseOp op) {
705 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
708void WhenOpVisitor::visitStmt(RefReleaseInitialOp op) {
709 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
712void WhenOpVisitor::visitStmtExpr(DPICallIntrinsicOp op) {
714 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
716 op.getEnableMutable().assign(condition);
724template <
typename ConcreteT>
725void LastConnectResolver<ConcreteT>::processWhenOp(WhenOp whenOp,
726 Value outerCondition) {
728 auto loc = whenOp.getLoc();
729 Block *parentBlock = whenOp->getBlock();
730 auto condition = whenOp.getCondition();
731 auto ui1Type = condition.getType();
739 Value thenCondition = whenOp.getCondition();
742 b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition, thenCondition);
744 auto &thenBlock = whenOp.getThenBlock();
745 driverMap.pushScope();
746 WhenOpVisitor(driverMap, thenCondition).process(thenBlock);
747 mergeBlock(*parentBlock, Block::iterator(whenOp), thenBlock);
748 auto thenScope = driverMap.popScope();
752 if (whenOp.hasElseRegion()) {
755 b.createOrFold<NotPrimOp>(loc, condition.getType(), condition);
758 elseCondition = b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition,
760 auto &elseBlock = whenOp.getElseBlock();
761 driverMap.pushScope();
762 WhenOpVisitor(driverMap, elseCondition).process(elseBlock);
763 mergeBlock(*parentBlock, Block::iterator(whenOp), elseBlock);
764 elseScope = driverMap.popScope();
767 mergeScopes(loc, thenScope, elseScope, condition);
779class ModuleVisitor :
public LastConnectResolver<ModuleVisitor> {
781 ModuleVisitor() : LastConnectResolver<ModuleVisitor>(driverMap) {}
783 using LastConnectResolver<ModuleVisitor>::visitExpr;
784 using LastConnectResolver<ModuleVisitor>::visitDecl;
785 using LastConnectResolver<ModuleVisitor>::visitStmt;
786 void visitStmt(WhenOp whenOp);
787 void visitStmt(ConnectOp connectOp);
788 void visitStmt(MatchingConnectOp connectOp);
789 void visitStmt(LayerBlockOp layerBlockOp);
791 bool run(FModuleLike op);
792 LogicalResult checkInitialization();
799 bool anythingChanged =
false;
806bool ModuleVisitor::run(FModuleLike op) {
808 if (!isa<FModuleOp, ClassOp>(op))
809 return anythingChanged;
811 for (
auto &region : op->getRegions()) {
812 for (
auto &block : region.getBlocks()) {
814 for (
const auto &[index, value] :
llvm::enumerate(block.getArguments())) {
815 auto direction = op.getPortDirection(index);
816 auto flow = direction == Direction::In ? Flow::Source : Flow::Sink;
817 declareSinks(value, flow);
821 for (
auto &op :
llvm::make_early_inc_range(block))
822 dispatchVisitor(&op);
826 return anythingChanged;
829void ModuleVisitor::visitStmt(ConnectOp op) {
833void ModuleVisitor::visitStmt(MatchingConnectOp op) {
837void ModuleVisitor::visitStmt(WhenOp whenOp) {
839 anythingChanged =
true;
840 processWhenOp(whenOp, {});
843void ModuleVisitor::visitStmt(LayerBlockOp layerBlockOp) {
844 for (
auto &op :
llvm::make_early_inc_range(*layerBlockOp.getBody())) {
845 dispatchVisitor(&op);
851LogicalResult ModuleVisitor::checkInitialization() {
853 for (
auto destAndConnect : driverMap.getLastScope()) {
855 auto *
connect = std::get<1>(destAndConnect);
860 FieldRef dest = std::get<0>(destAndConnect);
861 auto loc = dest.
getValue().getLoc();
863 if (
auto mod = dyn_cast<FModuleLike>(definingOp))
864 mlir::emitError(loc) <<
"port \"" <<
getFieldName(dest).first
865 <<
"\" not fully initialized in \""
866 << mod.getModuleName() <<
"\"";
870 <<
"\" not fully initialized in \""
871 << definingOp->getParentOfType<FModuleLike>().getModuleName() <<
"\"";
885 :
public circt::firrtl::impl::ExpandWhensBase<ExpandWhensPass> {
886 void runOnOperation()
override;
890void ExpandWhensPass::runOnOperation() {
891 ModuleVisitor visitor;
892 if (!visitor.run(getOperation()))
893 markAllAnalysesPreserved();
894 if (failed(visitor.checkInitialization()))
899 return std::make_unique<ExpandWhensPass>();
assert(baseType &&"element must be base type")
ScopedDriverMap::ScopeT DriverMap
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Move all operations from a source block into a destination block.
HashTableStack< FieldRef, Operation * > ScopedDriverMap
This is a deterministic mapping of a FieldRef to the last operation which set a value to it.
This class represents a reference to a specific field or element of an aggregate value.
Value getValue() const
Get the Value which created this location.
Operation * getDefiningOp() const
Get the operation which defines this field.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
connect(destination, source)
Flow swapFlow(Flow flow)
Get a flow's reverse.
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
std::unique_ptr< mlir::Pass > createExpandWhensPass()
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
bool operator!=(const Iterator &rhs) const
bool operator==(const Iterator &rhs) const
std::pair< KeyT, ValueT > & operator*() const
Iterator(typename StackT::iterator stackIt, typename ScopeT::iterator scopeIt)
This is a stack of hashtables, if lookup fails in the top-most hashtable, it will attempt to lookup i...
typename llvm::MapVector< KeyT, ValueT > ScopeT
iterator find(const KeyT &key)
typename llvm::SmallVector< ScopeT, 3 > StackT
ValueT & operator[](const KeyT &key)