20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/STLExtras.h"
23 using namespace circt;
24 using namespace firrtl;
28 static void mergeBlock(Block &destination, Block::iterator insertPoint,
30 destination.getOperations().splice(insertPoint, source.getOperations());
38 template <
typename KeyT,
typename ValueT>
40 using ScopeT =
typename llvm::MapVector<KeyT, ValueT>;
41 using StackT =
typename llvm::SmallVector<ScopeT, 3>;
45 typename ScopeT::iterator scopeIt)
46 : stackIt(stackIt), scopeIt(scopeIt) {}
/// Dereference the iterator, yielding a mutable reference to the (key, value)
/// entry it currently points at inside the scope selected by `stackIt`.
54 std::pair<KeyT, ValueT> &
operator*()
const {
return *scopeIt; }
57 if (scopeIt == stackIt->end())
58 scopeIt = (++stackIt)->begin();
76 return Iterator(mapStack.begin(), mapStack.first().begin());
83 for (
auto i = mapStack.size(); i > 0; --i) {
84 auto &map = mapStack[i - 1];
85 auto it = map.find(key);
87 return Iterator(mapStack.begin() + i - 1, it);
97 assert(mapStack.size() > 1 &&
"Cannot pop the last scope");
98 return mapStack.pop_back_val();
/// Access the value for `key` in the last (top-most) scope of the stack only.
/// Outer scopes are not consulted. Like MapVector::operator[], this
/// default-inserts an entry into that top scope when `key` is absent there.
102 ValueT &
operator[](
const KeyT &key) {
return mapStack.back()[key]; }
120 template <
typename ConcreteT>
121 class LastConnectResolver :
public FIRRTLVisitor<ConcreteT> {
129 LastConnectResolver(
ScopedDriverMap &driverMap) : driverMap(driverMap) {}
140 bool recordConnect(
FieldRef dest, Operation *connection) {
142 auto itAndInserted = driverMap.
getLastScope().insert({dest, connection});
143 if (isStaticSingleConnect(connection)) {
145 assert(itAndInserted.second || !itAndInserted.first->second);
146 if (!itAndInserted.second)
147 itAndInserted.first->second = connection;
150 assert(isLastConnect(connection));
151 if (!std::get<1>(itAndInserted)) {
152 auto iterator = std::get<0>(itAndInserted);
153 auto changed =
false;
156 if (
auto *oldConnect = iterator->second) {
160 iterator->second = connection;
168 static Value getDestinationValue(Operation *op) {
169 return cast<FConnectLike>(op).getDest();
174 static Value getConnectedValue(Operation *op) {
175 return cast<FConnectLike>(op).getSrc();
179 static bool isStaticSingleConnect(Operation *op) {
180 return cast<FConnectLike>(op).hasStaticSingleConnectBehavior();
187 static bool isLastConnect(Operation *op) {
188 return cast<FConnectLike>(op).hasLastConnectBehavior();
193 void declareSinks(Value
value,
Flow flow,
bool local =
false) {
194 auto type =
value.getType();
198 std::function<void(Type,
Flow,
bool)> declare = [&](Type type,
Flow flow,
201 if (
auto classType = type_dyn_cast<ClassType>(type)) {
205 for (
auto &element : classType.getElements()) {
208 declare(element.type, flow,
false);
210 declare(element.type,
swapFlow(flow),
false);
216 driverMap[{
value,
id}] =
nullptr;
219 id += classType.getMaxFieldID();
225 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
226 for (
auto &element : bundleType.getElements()) {
229 declare(element.type,
swapFlow(flow),
false);
231 declare(element.type, flow,
false);
237 if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
238 for (
unsigned i = 0; i < vectorType.getNumElements(); ++i) {
240 declare(vectorType.getElementType(), flow,
false);
246 if (
auto analogType = type_dyn_cast<AnalogType>(type))
252 driverMap[{
value,
id}] =
nullptr;
255 declare(type, flow, local);
260 ConnectOp flattenConditionalConnections(OpBuilder &b, Location loc,
261 Value dest, Value cond,
262 Operation *whenTrueConn,
263 Operation *whenFalseConn) {
264 assert(isLastConnect(whenTrueConn) && isLastConnect(whenFalseConn));
266 b.getFusedLoc({loc, whenTrueConn->getLoc(), whenFalseConn->getLoc()});
267 auto whenTrue = getConnectedValue(whenTrueConn);
269 isa_and_nonnull<InvalidValueOp>(whenTrue.getDefiningOp());
270 auto whenFalse = getConnectedValue(whenFalseConn);
271 auto falseIsInvalid =
272 isa_and_nonnull<InvalidValueOp>(whenFalse.getDefiningOp());
279 Value newValue = whenTrue;
280 if (trueIsInvalid == falseIsInvalid)
281 newValue = b.createOrFold<MuxPrimOp>(fusedLoc, cond, whenTrue, whenFalse);
282 else if (trueIsInvalid)
283 newValue = whenFalse;
284 return b.create<ConnectOp>(loc, dest, newValue);
/// Declare a wire: register each of its sink fields in the driver map,
/// treating the wire's result as having duplex flow.
287 void visitDecl(WireOp op) { declareSinks(op.getResult(),
Flow::Duplex); }
292 llvm::function_ref<
void(Value)> fn) {
294 .template Case<BundleType>([&](BundleType bundle) {
295 for (
auto i : llvm::seq(0u, (
unsigned)bundle.getNumElements())) {
297 builder.create<SubfieldOp>(value.getLoc(), value, i);
298 foreachSubelement(builder, subfield, fn);
301 .
template Case<FVectorType>([&](FVectorType vector) {
302 for (
auto i : llvm::seq((
size_t)0, vector.getNumElements())) {
304 builder.create<SubindexOp>(value.getLoc(), value, i);
305 foreachSubelement(builder, subindex, fn);
308 .Default([&](
auto) { fn(
value); });
311 void visitDecl(RegOp op) {
314 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
315 auto fn = [&](Value
value) {
319 foreachSubelement(
builder, op.getResult(), fn);
322 void visitDecl(RegResetOp op) {
325 auto builder = OpBuilder(op->getBlock(), ++Block::iterator(op));
326 auto fn = [&](Value
value) {
330 foreachSubelement(
builder, op.getResult(), fn);
333 void visitDecl(InstanceOp op) {
336 for (
const auto &result : llvm::enumerate(op.getResults()))
343 void visitDecl(ObjectOp op) {
347 void visitDecl(MemOp op) {
349 for (
auto result : op.getResults())
350 if (!isa<RefType>(result.getType()))
354 void visitStmt(ConnectOp op) {
358 void visitStmt(StrictConnectOp op) {
362 void visitStmt(RefDefineOp op) {
366 void visitStmt(PropAssignOp op) {
370 void processWhenOp(WhenOp whenOp, Value outerCondition);
389 Value thenCondition) {
392 for (
auto &destAndConnect : thenScope) {
393 auto dest = std::get<0>(destAndConnect);
394 auto thenConnect = std::get<1>(destAndConnect);
396 auto outerIt = driverMap.
find(dest);
397 if (outerIt == driverMap.
end()) {
400 driverMap[dest] = thenConnect;
404 auto elseIt = elseScope.
find(dest);
405 if (elseIt != elseScope.end()) {
410 auto &elseConnect = std::get<1>(*elseIt);
411 OpBuilder connectBuilder(elseConnect);
412 auto newConnect = flattenConditionalConnections(
413 connectBuilder, loc, getDestinationValue(thenConnect),
414 thenCondition, thenConnect, elseConnect);
417 thenConnect->erase();
418 elseConnect->erase();
419 recordConnect(dest, newConnect);
424 auto &outerConnect = std::get<1>(*outerIt);
426 if (isLastConnect(thenConnect)) {
429 thenConnect->erase();
431 assert(isStaticSingleConnect(thenConnect));
432 driverMap[dest] = thenConnect;
439 OpBuilder connectBuilder(thenConnect);
440 auto newConnect = flattenConditionalConnections(
441 connectBuilder, loc, getDestinationValue(thenConnect), thenCondition,
442 thenConnect, outerConnect);
445 thenConnect->erase();
446 recordConnect(dest, newConnect);
450 for (
auto &destAndConnect : elseScope) {
451 auto dest = std::get<0>(destAndConnect);
452 auto elseConnect = std::get<1>(destAndConnect);
456 if (thenScope.contains(dest))
459 auto outerIt = driverMap.
find(dest);
460 if (outerIt == driverMap.
end()) {
463 driverMap[dest] = elseConnect;
467 auto &outerConnect = std::get<1>(*outerIt);
469 if (isLastConnect(elseConnect)) {
472 elseConnect->erase();
474 assert(isStaticSingleConnect(elseConnect));
475 driverMap[dest] = elseConnect;
482 OpBuilder connectBuilder(elseConnect);
483 auto newConnect = flattenConditionalConnections(
484 connectBuilder, loc, getDestinationValue(outerConnect), thenCondition,
485 outerConnect, elseConnect);
488 elseConnect->erase();
489 recordConnect(dest, newConnect);
503 class WhenOpVisitor :
public LastConnectResolver<WhenOpVisitor> {
507 : LastConnectResolver<WhenOpVisitor>(driverMap), condition(condition) {}
509 using LastConnectResolver<WhenOpVisitor>::visitExpr;
510 using LastConnectResolver<WhenOpVisitor>::visitDecl;
511 using LastConnectResolver<WhenOpVisitor>::visitStmt;
514 void process(Block &block);
517 void visitStmt(AssertOp op);
518 void visitStmt(AssumeOp op);
519 void visitStmt(CoverOp op);
520 void visitStmt(ModuleOp op);
521 void visitStmt(PrintFOp op);
522 void visitStmt(StopOp op);
523 void visitStmt(WhenOp op);
524 void visitStmt(RefForceOp op);
525 void visitStmt(RefForceInitialOp op);
526 void visitStmt(RefReleaseOp op);
527 void visitStmt(RefReleaseInitialOp op);
532 Value andWithCondition(Operation *op, Value
value) {
534 return OpBuilder(op).createOrFold<AndPrimOp>(
535 condition.getLoc(), condition.getType(), condition,
value);
544 void WhenOpVisitor::process(Block &block) {
545 for (
auto &op : llvm::make_early_inc_range(block)) {
546 dispatchVisitor(&op);
550 void WhenOpVisitor::visitStmt(PrintFOp op) {
551 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
554 void WhenOpVisitor::visitStmt(StopOp op) {
555 op.getCondMutable().assign(andWithCondition(op, op.getCond()));
558 void WhenOpVisitor::visitStmt(AssertOp op) {
559 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
562 void WhenOpVisitor::visitStmt(AssumeOp op) {
563 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
566 void WhenOpVisitor::visitStmt(CoverOp op) {
567 op.getEnableMutable().assign(andWithCondition(op, op.getEnable()));
570 void WhenOpVisitor::visitStmt(WhenOp whenOp) {
571 processWhenOp(whenOp, condition);
574 void WhenOpVisitor::visitStmt(RefForceOp op) {
575 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
578 void WhenOpVisitor::visitStmt(RefForceInitialOp op) {
579 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
582 void WhenOpVisitor::visitStmt(RefReleaseOp op) {
583 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
586 void WhenOpVisitor::visitStmt(RefReleaseInitialOp op) {
587 op.getPredicateMutable().assign(andWithCondition(op, op.getPredicate()));
595 template <
typename ConcreteT>
596 void LastConnectResolver<ConcreteT>::processWhenOp(WhenOp whenOp,
597 Value outerCondition) {
599 auto loc = whenOp.getLoc();
600 Block *parentBlock = whenOp->getBlock();
601 auto condition = whenOp.getCondition();
602 auto ui1Type = condition.getType();
610 Value thenCondition = whenOp.getCondition();
613 b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition, thenCondition);
615 auto &thenBlock = whenOp.getThenBlock();
616 driverMap.pushScope();
617 WhenOpVisitor(driverMap, thenCondition).process(thenBlock);
618 mergeBlock(*parentBlock, Block::iterator(whenOp), thenBlock);
619 auto thenScope = driverMap.popScope();
623 if (whenOp.hasElseRegion()) {
626 b.createOrFold<NotPrimOp>(loc, condition.getType(), condition);
629 elseCondition = b.createOrFold<AndPrimOp>(loc, ui1Type, outerCondition,
631 auto &elseBlock = whenOp.getElseBlock();
632 driverMap.pushScope();
633 WhenOpVisitor(driverMap, elseCondition).process(elseBlock);
634 mergeBlock(*parentBlock, Block::iterator(whenOp), elseBlock);
635 elseScope = driverMap.popScope();
638 mergeScopes(loc, thenScope, elseScope, condition);
650 class ModuleVisitor :
public LastConnectResolver<ModuleVisitor> {
652 ModuleVisitor() : LastConnectResolver<ModuleVisitor>(driverMap) {}
654 using LastConnectResolver<ModuleVisitor>::visitExpr;
655 using LastConnectResolver<ModuleVisitor>::visitDecl;
656 using LastConnectResolver<ModuleVisitor>::visitStmt;
657 void visitStmt(WhenOp whenOp);
658 void visitStmt(ConnectOp connectOp);
659 void visitStmt(StrictConnectOp connectOp);
660 void visitStmt(GroupOp groupOp);
662 bool run(FModuleLike op);
663 LogicalResult checkInitialization();
670 bool anythingChanged =
false;
677 bool ModuleVisitor::run(FModuleLike op) {
679 if (!isa<FModuleOp, ClassOp>(op))
680 return anythingChanged;
682 for (
auto ®ion : op->getRegions()) {
683 for (
auto &block : region.getBlocks()) {
685 for (
const auto &[index,
value] : llvm::enumerate(block.getArguments())) {
686 auto direction = op.getPortDirection(index);
688 declareSinks(
value, flow);
692 for (
auto &op : llvm::make_early_inc_range(block))
693 dispatchVisitor(&op);
697 return anythingChanged;
700 void ModuleVisitor::visitStmt(ConnectOp op) {
704 void ModuleVisitor::visitStmt(StrictConnectOp op) {
708 void ModuleVisitor::visitStmt(WhenOp whenOp) {
710 anythingChanged =
true;
711 processWhenOp(whenOp, {});
714 void ModuleVisitor::visitStmt(GroupOp groupOp) {
715 for (
auto &op : llvm::make_early_inc_range(*groupOp.getBody())) {
716 dispatchVisitor(&op);
722 LogicalResult ModuleVisitor::checkInitialization() {
724 for (
auto destAndConnect : driverMap.getLastScope()) {
726 auto *
connect = std::get<1>(destAndConnect);
731 FieldRef dest = std::get<0>(destAndConnect);
732 auto loc = dest.
getValue().getLoc();
734 if (
auto mod = dyn_cast<FModuleLike>(definingOp))
735 mlir::emitError(loc) <<
"port \"" <<
getFieldName(dest).first
736 <<
"\" not fully initialized in \""
737 << mod.getModuleName() <<
"\"";
741 <<
"\" not fully initialized in \""
742 << definingOp->getParentOfType<FModuleLike>().getModuleName() <<
"\"";
755 class ExpandWhensPass :
public ExpandWhensBase<ExpandWhensPass> {
756 void runOnOperation()
override;
760 void ExpandWhensPass::runOnOperation() {
761 ModuleVisitor visitor;
762 if (!visitor.run(getOperation()))
763 markAllAnalysesPreserved();
764 if (failed(visitor.checkInitialization()))
769 return std::make_unique<ExpandWhensPass>();
assert(baseType &&"element must be base type")
ScopedDriverMap::ScopeT DriverMap
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Move all operations from a source block into a destination block.
This class represents a reference to a specific field or element of an aggregate value.
Value getValue() const
Get the Value which created this location.
Operation * getDefiningOp() const
Get the operation which defines this field.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
def connect(destination, source)
Flow swapFlow(Flow flow)
Get a flow's reverse.
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
std::unique_ptr< mlir::Pass > createExpandWhensPass()
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
std::pair< KeyT, ValueT > & operator*() const
bool operator!=(const Iterator &rhs) const
bool operator==(const Iterator &rhs) const
Iterator(typename StackT::iterator stackIt, typename ScopeT::iterator scopeIt)
This is a stack of hashtables, if lookup fails in the top-most hashtable, it will attempt to lookup i...
typename llvm::MapVector< KeyT, ValueT > ScopeT
iterator find(const KeyT &key)
ValueT & operator[](const KeyT &key)
typename llvm::SmallVector< ScopeT, 3 > StackT