22#include "mlir/Pass/Pass.h"
23#include "llvm/ADT/MapVector.h"
24#include "llvm/ADT/STLExtras.h"
28#define GEN_PASS_DEF_EXPANDWHENS
29#include "circt/Dialect/FIRRTL/Passes.h.inc"
34using namespace firrtl;
38static void mergeBlock(Block &destination, Block::iterator insertPoint,
40 destination.getOperations().splice(insertPoint, source.getOperations());
48template <
typename KeyT,
typename ValueT>
50 using ScopeT =
typename llvm::MapVector<KeyT, ValueT>;
51 using StackT =
typename llvm::SmallVector<ScopeT, 3>;
55 typename ScopeT::iterator
scopeIt)
93 for (
auto i =
mapStack.size(); i > 0; --i) {
95 auto it = map.find(key);
130template <
typename ConcreteT>
139 LastConnectResolver(
ScopedDriverMap &driverMap) : driverMap(driverMap) {}
150 bool recordConnect(
FieldRef dest, Operation *connection) {
152 auto itAndInserted = driverMap.getLastScope().insert({dest, connection});
153 if (isStaticSingleConnect(connection)) {
155 assert(itAndInserted.second || !itAndInserted.first->second);
156 if (!itAndInserted.second)
157 itAndInserted.first->second = connection;
160 assert(isLastConnect(connection));
161 if (!std::get<1>(itAndInserted)) {
162 auto iterator = std::get<0>(itAndInserted);
163 auto changed =
false;
166 if (
auto *oldConnect = iterator->second) {
170 iterator->second = connection;
178 static Value getDestinationValue(Operation *op) {
179 return cast<FConnectLike>(op).getDest();
184 static Value getConnectedValue(Operation *op) {
185 return cast<FConnectLike>(op).getSrc();
189 static bool isStaticSingleConnect(Operation *op) {
190 return cast<FConnectLike>(op).hasStaticSingleConnectBehavior();
197 static bool isLastConnect(Operation *op) {
198 return cast<FConnectLike>(op).hasLastConnectBehavior();
203 void declareSinks(Value value,
Flow flow,
bool local =
false) {
204 auto type = value.getType();
208 std::function<void(Type,
Flow,
bool)> declare = [&](Type type,
Flow flow,
212 if (type_isa<DomainType>(type))
216 if (
auto classType = type_dyn_cast<ClassType>(type)) {
220 for (
auto &element : classType.getElements()) {
222 if (element.direction == Direction::Out)
223 declare(element.type, flow,
false);
225 declare(element.type,
swapFlow(flow),
false);
230 if (flow != Flow::Source)
231 driverMap[{value,
id}] =
nullptr;
234 id += classType.getMaxFieldID();
240 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
241 for (
auto &element : bundleType.getElements()) {
244 declare(element.type,
swapFlow(flow),
false);
246 declare(element.type, flow,
false);
252 if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
253 for (
unsigned i = 0; i < vectorType.getNumElements(); ++i) {
255 declare(vectorType.getElementType(), flow,
false);
261 if (
auto analogType = type_dyn_cast<AnalogType>(type))
266 if (flow != Flow::Source)
267 driverMap[{value,
id}] =
nullptr;
270 declare(type, flow, local);
275 ConnectOp flattenConditionalConnections(OpBuilder &b, Location loc,
276 Value dest, Value cond,
277 Operation *whenTrueConn,
278 Operation *whenFalseConn) {
279 assert(isLastConnect(whenTrueConn) && isLastConnect(whenFalseConn));
281 b.getFusedLoc({loc, whenTrueConn->getLoc(), whenFalseConn->getLoc()});
282 auto whenTrue = getConnectedValue(whenTrueConn);
284 isa_and_nonnull<InvalidValueOp>(whenTrue.getDefiningOp());
285 auto whenFalse = getConnectedValue(whenFalseConn);
286 auto falseIsInvalid =
287 isa_and_nonnull<InvalidValueOp>(whenFalse.getDefiningOp());
294 Value newValue = whenTrue;
295 if (trueIsInvalid == falseIsInvalid)
296 newValue = b.createOrFold<MuxPrimOp>(fusedLoc, cond, whenTrue, whenFalse);
297 else if (trueIsInvalid)
298 newValue = whenFalse;
299 return ConnectOp::create(b, loc, dest, newValue);
302 void visitDecl(WireOp op) { declareSinks(op.getResult(), Flow::Duplex); }
306 void foreachSubelement(OpBuilder &builder, Value value,
307 llvm::function_ref<
void(Value)> fn) {
309 .template Case<BundleType>([&](BundleType bundle) {
310 for (
auto i : llvm::seq(0u, (
unsigned)bundle.getNumElements())) {
312 SubfieldOp::create(builder, value.getLoc(), value, i);
313 foreachSubelement(builder, subfield, fn);
316 .
template Case<FVectorType>([&](FVectorType vector) {
317 for (
auto i : llvm::seq((
size_t)0, vector.getNumElements())) {
319 SubindexOp::create(builder, value.getLoc(), value, i);
320 foreachSubelement(builder, subindex, fn);
323 .Default([&](
auto) { fn(value); });
326 void visitDecl(RegOp op) {
329 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
330 auto fn = [&](Value value) {
331 auto connect = ConnectOp::create(builder, value.getLoc(), value, value);
334 foreachSubelement(builder, op.getResult(), fn);
337 void visitDecl(RegResetOp op) {
340 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
341 auto fn = [&](Value value) {
342 auto connect = ConnectOp::create(builder, value.getLoc(), value, value);
345 foreachSubelement(builder, op.getResult(), fn);
348 template <
typename OpTy>
349 void visitInstanceDecl(OpTy op) {
352 for (
const auto &result : llvm::enumerate(op.getResults()))
356 declareSinks(result.value(), Flow::Sink);
359 void visitDecl(InstanceOp op) { visitInstanceDecl(op); }
360 void visitDecl(InstanceChoiceOp op) { visitInstanceDecl(op); }
362 void visitDecl(ObjectOp op) {
363 declareSinks(op, Flow::Source,
true);
366 void visitDecl(MemOp op) {
368 for (
auto result : op.getResults())
369 if (!isa<RefType>(result.getType()))
370 declareSinks(result, Flow::Sink);
373 void visitStmt(ConnectOp op) {
377 void visitStmt(MatchingConnectOp op) {
381 void visitStmt(RefDefineOp op) {
385 void visitStmt(PropAssignOp op) {
389 void visitStmt(DomainDefineOp op) {
393 void processWhenOp(WhenOp whenOp, Value outerCondition);
412 Value thenCondition) {
415 for (
auto &destAndConnect : thenScope) {
416 auto dest = std::get<0>(destAndConnect);
417 auto thenConnect = std::get<1>(destAndConnect);
419 auto outerIt = driverMap.find(dest);
420 if (outerIt == driverMap.end()) {
423 driverMap[dest] = thenConnect;
427 auto elseIt = elseScope.find(dest);
428 if (elseIt != elseScope.end()) {
433 auto &elseConnect = std::get<1>(*elseIt);
434 OpBuilder connectBuilder(elseConnect);
435 auto newConnect = flattenConditionalConnections(
436 connectBuilder, loc, getDestinationValue(thenConnect),
437 thenCondition, thenConnect, elseConnect);
440 thenConnect->erase();
441 elseConnect->erase();
442 recordConnect(dest, newConnect);
447 auto &outerConnect = std::get<1>(*outerIt);
449 if (isLastConnect(thenConnect)) {
452 thenConnect->erase();
454 assert(isStaticSingleConnect(thenConnect));
455 driverMap[dest] = thenConnect;
462 OpBuilder connectBuilder(thenConnect);
463 auto newConnect = flattenConditionalConnections(
464 connectBuilder, loc, getDestinationValue(thenConnect), thenCondition,
465 thenConnect, outerConnect);
468 thenConnect->erase();
469 recordConnect(dest, newConnect);
473 for (
auto &destAndConnect : elseScope) {
474 auto dest = std::get<0>(destAndConnect);
475 auto elseConnect = std::get<1>(destAndConnect);
479 if (thenScope.contains(dest))
482 auto outerIt = driverMap.find(dest);
483 if (outerIt == driverMap.end()) {
486 driverMap[dest] = elseConnect;
490 auto &outerConnect = std::get<1>(*outerIt);
492 if (isLastConnect(elseConnect)) {
495 elseConnect->erase();
497 assert(isStaticSingleConnect(elseConnect));
498 driverMap[dest] = elseConnect;
505 OpBuilder connectBuilder(elseConnect);
506 auto newConnect = flattenConditionalConnections(
507 connectBuilder, loc, getDestinationValue(outerConnect), thenCondition,
508 outerConnect, elseConnect);
511 elseConnect->erase();
512 recordConnect(dest, newConnect);
526class WhenOpVisitor :
public LastConnectResolver<WhenOpVisitor> {
530 : LastConnectResolver<WhenOpVisitor>(driverMap), condition(condition) {}
532 using LastConnectResolver<WhenOpVisitor>::visitExpr;
533 using LastConnectResolver<WhenOpVisitor>::visitDecl;
534 using LastConnectResolver<WhenOpVisitor>::visitStmt;
535 using LastConnectResolver<WhenOpVisitor>::visitStmtExpr;
538 void process(Block &block);
541 void visitStmt(VerifAssertIntrinsicOp op);
542 void visitStmt(VerifAssumeIntrinsicOp op);
543 void visitStmt(VerifCoverIntrinsicOp op);
544 void visitStmt(AssertOp op);
545 void visitStmt(AssumeOp op);
546 void visitStmt(UnclockedAssumeIntrinsicOp op);
547 void visitStmt(CoverOp op);
548 void visitStmt(ModuleOp op);
549 void visitStmt(PrintFOp op);
550 void visitStmt(FPrintFOp op);
551 void visitStmt(FFlushOp op);
552 void visitStmt(StopOp op);
553 void visitStmt(WhenOp op);
554 void visitStmt(LayerBlockOp op);
555 void visitStmt(RefForceOp op);
556 void visitStmt(RefForceInitialOp op);
557 void visitStmt(RefReleaseOp op);
558 void visitStmt(RefReleaseInitialOp op);
559 void visitStmtExpr(DPICallIntrinsicOp op);
564 Value andWithCondition(Operation *op, Value value) {
566 return OpBuilder(op).createOrFold<AndPrimOp>(
567 condition.getLoc(), condition.getType(), condition, value);
572 Value ltlAndWithCondition(Operation *op, Value property) {
574 while (
auto nodeOp = property.getDefiningOp<NodeOp>())
575 property = nodeOp.getInput();
578 if (
auto clockOp = property.getDefiningOp<LTLClockIntrinsicOp>()) {
579 auto input = ltlAndWithCondition(op, clockOp.getInput());
580 auto &newClockOp = createdLTLClockOps[{clockOp, input}];
582 newClockOp = OpBuilder(op).cloneWithoutRegions(clockOp);
583 newClockOp.getInputMutable().assign(input);
589 auto &newOp = createdLTLAndOps[{condition,
property}];
591 newOp = OpBuilder(op).createOrFold<LTLAndIntrinsicOp>(
592 condition.getLoc(),
property.getType(), condition, property);
599 Value ltlImplicationWithCondition(Operation *op, Value property) {
601 while (
auto nodeOp = property.getDefiningOp<NodeOp>())
602 property = nodeOp.getInput();
605 if (
auto clockOp = property.getDefiningOp<LTLClockIntrinsicOp>()) {
606 auto input = ltlImplicationWithCondition(op, clockOp.getInput());
607 auto &newClockOp = createdLTLClockOps[{clockOp, input}];
609 newClockOp = OpBuilder(op).cloneWithoutRegions(clockOp);
610 newClockOp.getInputMutable().assign(input);
616 if (
auto implOp = property.getDefiningOp<LTLImplicationIntrinsicOp>()) {
617 auto lhs = ltlAndWithCondition(op, implOp.getLhs());
618 auto &newImplOp = createdLTLImplicationOps[{lhs, implOp.getRhs()}];
620 auto clonedOp = OpBuilder(op).cloneWithoutRegions(implOp);
621 clonedOp.getLhsMutable().assign(lhs);
622 newImplOp = clonedOp;
628 auto &newImplOp = createdLTLImplicationOps[{condition,
property}];
630 newImplOp = OpBuilder(op).createOrFold<LTLImplicationIntrinsicOp>(
631 condition.getLoc(),
property.getType(), condition, property);
651void WhenOpVisitor::process(Block &block) {
652 for (
auto &op :
llvm::make_early_inc_range(block)) {
653 dispatchVisitor(&op);
657void WhenOpVisitor::visitStmt(PrintFOp op) {
658 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
661void WhenOpVisitor::visitStmt(FPrintFOp op) {
662 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
665void WhenOpVisitor::visitStmt(FFlushOp op) {
666 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
669void WhenOpVisitor::visitStmt(StopOp op) {
670 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
673void WhenOpVisitor::visitStmt(VerifAssertIntrinsicOp op) {
674 op.getPropertyMutable().assign(
675 ltlImplicationWithCondition(op, op.getProperty()));
678void WhenOpVisitor::visitStmt(VerifAssumeIntrinsicOp op) {
679 op.getPropertyMutable().assign(
680 ltlImplicationWithCondition(op, op.getProperty()));
683void WhenOpVisitor::visitStmt(VerifCoverIntrinsicOp op) {
684 op.getPropertyMutable().assign(ltlAndWithCondition(op, op.getProperty()));
687void WhenOpVisitor::visitStmt(AssertOp op) {
688 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
691void WhenOpVisitor::visitStmt(AssumeOp op) {
692 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
695void WhenOpVisitor::visitStmt(UnclockedAssumeIntrinsicOp op) {
696 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
699void WhenOpVisitor::visitStmt(CoverOp op) {
700 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
703void WhenOpVisitor::visitStmt(WhenOp whenOp) {
704 processWhenOp(whenOp, condition);
708void WhenOpVisitor::visitStmt(LayerBlockOp layerBlockOp) {
709 process(*layerBlockOp.getBody());
712void WhenOpVisitor::visitStmt(RefForceOp op) {
713 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
716void WhenOpVisitor::visitStmt(RefForceInitialOp op) {
717 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
720void WhenOpVisitor::visitStmt(RefReleaseOp op) {
721 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
724void WhenOpVisitor::visitStmt(RefReleaseInitialOp op) {
725 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
728void WhenOpVisitor::visitStmtExpr(DPICallIntrinsicOp op) {
730 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
732 op.getEnableMutable().assign(condition);
740template <
typename ConcreteT>
741void LastConnectResolver<ConcreteT>::processWhenOp(WhenOp whenOp,
742 Value outerCondition) {
744 auto loc = whenOp.getLoc();
745 Block *parentBlock = whenOp->getBlock();
746 auto condition = whenOp.getCondition();
747 auto ui1Type = condition.getType();
755 Value thenCondition = whenOp.getCondition();
758 b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition, thenCondition);
760 auto &thenBlock = whenOp.getThenBlock();
761 driverMap.pushScope();
762 WhenOpVisitor(driverMap, thenCondition).process(thenBlock);
763 mergeBlock(*parentBlock, Block::iterator(whenOp), thenBlock);
764 auto thenScope = driverMap.popScope();
768 if (whenOp.hasElseRegion()) {
771 b.createOrFold<NotPrimOp>(loc, condition.getType(), condition);
774 elseCondition =
b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition,
776 auto &elseBlock = whenOp.getElseBlock();
777 driverMap.pushScope();
778 WhenOpVisitor(driverMap, elseCondition).process(elseBlock);
779 mergeBlock(*parentBlock, Block::iterator(whenOp), elseBlock);
780 elseScope = driverMap.popScope();
783 mergeScopes(loc, thenScope, elseScope, condition);
795class ModuleVisitor :
public LastConnectResolver<ModuleVisitor> {
797 ModuleVisitor() : LastConnectResolver<ModuleVisitor>(driverMap) {}
799 using LastConnectResolver<ModuleVisitor>::visitExpr;
800 using LastConnectResolver<ModuleVisitor>::visitDecl;
801 using LastConnectResolver<ModuleVisitor>::visitStmt;
802 void visitStmt(WhenOp whenOp);
803 void visitStmt(ConnectOp connectOp);
804 void visitStmt(MatchingConnectOp connectOp);
805 void visitStmt(LayerBlockOp layerBlockOp);
807 bool run(FModuleLike op);
808 LogicalResult checkInitialization();
815 bool anythingChanged =
false;
822bool ModuleVisitor::run(FModuleLike op) {
824 if (!isa<FModuleOp, ClassOp>(op))
825 return anythingChanged;
827 for (
auto ®ion : op->getRegions()) {
828 for (
auto &block : region.getBlocks()) {
830 for (
const auto &[index, value] :
llvm::enumerate(block.getArguments())) {
831 auto direction = op.getPortDirection(index);
832 auto flow = direction == Direction::In ? Flow::Source : Flow::Sink;
833 declareSinks(value, flow);
837 for (
auto &op :
llvm::make_early_inc_range(block))
838 dispatchVisitor(&op);
842 return anythingChanged;
845void ModuleVisitor::visitStmt(ConnectOp op) {
849void ModuleVisitor::visitStmt(MatchingConnectOp op) {
853void ModuleVisitor::visitStmt(WhenOp whenOp) {
855 anythingChanged =
true;
856 processWhenOp(whenOp, {});
859void ModuleVisitor::visitStmt(LayerBlockOp layerBlockOp) {
860 for (
auto &op :
llvm::make_early_inc_range(*layerBlockOp.getBody())) {
861 dispatchVisitor(&op);
867LogicalResult ModuleVisitor::checkInitialization() {
869 for (
auto destAndConnect : driverMap.getLastScope()) {
871 auto *
connect = std::get<1>(destAndConnect);
876 FieldRef dest = std::get<0>(destAndConnect);
877 auto loc = dest.
getValue().getLoc();
879 if (
auto mod = dyn_cast<FModuleLike>(definingOp))
880 mlir::emitError(loc) <<
"port \"" <<
getFieldName(dest).first
881 <<
"\" not fully initialized in \""
882 << mod.getModuleName() <<
"\"";
886 <<
"\" not fully initialized in \""
887 << definingOp->getParentOfType<FModuleLike>().getModuleName() <<
"\"";
901 :
public circt::firrtl::impl::ExpandWhensBase<ExpandWhensPass> {
902 void runOnOperation()
override;
906void ExpandWhensPass::runOnOperation() {
907 ModuleVisitor visitor;
908 if (!visitor.run(getOperation()))
909 markAllAnalysesPreserved();
910 if (failed(visitor.checkInitialization()))
assert(baseType &&"element must be base type")
ScopedDriverMap::ScopeT DriverMap
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Move all operations from a source block into a destination block.
HashTableStack< FieldRef, Operation * > ScopedDriverMap
This is a deterministic mapping of a FieldRef to the last operation which set a value to it.
This class represents a reference to a specific field or element of an aggregate value.
Value getValue() const
Get the Value which created this location.
Operation * getDefiningOp() const
Get the operation which defines this field.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
connect(destination, source)
Flow swapFlow(Flow flow)
Get a flow's reverse.
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
bool operator!=(const Iterator &rhs) const
bool operator==(const Iterator &rhs) const
std::pair< KeyT, ValueT > & operator*() const
Iterator(typename StackT::iterator stackIt, typename ScopeT::iterator scopeIt)
This is a stack of hashtables; if lookup fails in the top-most hashtable, it will attempt to lookup i...
typename llvm::MapVector< KeyT, ValueT > ScopeT
iterator find(const KeyT &key)
typename llvm::SmallVector< ScopeT, 3 > StackT
ValueT & operator[](const KeyT &key)