11#include "mlir/IR/Threading.h"
12#include "mlir/Transforms/DialectConversion.h"
13#include "llvm/ADT/DenseSet.h"
14#include "llvm/Support/Debug.h"
21#define DEBUG_TYPE "lower-seq-firreg"
24 [](
const Operation *op) ->
bool {
25 return (isa<comb::MuxOp, ArrayGetOp, ArrayCreateOp>(op));
30 return llvm::any_of(regOp.getResult().getUsers(), [&](Operation *user) {
31 if (!OpUserInfo::opAllowsReachability(user))
33 buildReachabilityFrom(user);
34 return reachableMuxes[user].contains(muxOp);
46 if (
visited.contains(startNode))
51 llvm::SmallVector<OpUserInfo, 16> stk;
53 stk.emplace_back(startNode);
55 while (!stk.empty()) {
56 auto &info = stk.back();
57 Operation *currentNode = info.op;
60 if (info.getAndSetUnvisited())
63 if (info.userIter != info.userEnd) {
64 Operation *child = *info.userIter;
67 stk.emplace_back(child);
73 for (
auto *childOp : llvm::make_filter_range(
82 iter->getSecond().end());
90 const std::function<
void()> &trueSide,
91 const std::function<
void()> &falseSide) {
92 auto op =
ifCache.lookup({builder.getBlock(), cond});
97 builder.create<sv::IfOp>(cond.getLoc(), cond, trueSide, falseSide);
98 ifCache.insert({{builder.getBlock(), cond}, newIfOp});
100 OpBuilder::InsertionGuard guard(builder);
101 builder.setInsertionPointToEnd(op.getThenBlock());
103 builder.setInsertionPointToEnd(op.getElseBlock());
110 bool disableRegRandomization,
111 bool emitSeparateAlwaysBlocks)
112 : typeConverter(typeConverter), module(module),
113 disableRegRandomization(disableRegRandomization),
114 emitSeparateAlwaysBlocks(emitSeparateAlwaysBlocks) {
121 auto regs =
module.getOps<seq::FirRegOp>();
126 SmallVector<RegLowerInfo> randomInit, presetInit;
127 llvm::MapVector<Value, SmallVector<RegLowerInfo>> asyncResets;
128 for (
auto reg : llvm::make_early_inc_range(regs)) {
131 presetInit.push_back(svReg);
133 randomInit.push_back(svReg);
135 if (svReg.asyncResetSignal)
136 asyncResets[svReg.asyncResetSignal].emplace_back(svReg);
144 for (
auto reg : randomInit)
145 if (
reg.randStart >= 0)
146 maxBit = std::max(maxBit, (uint64_t)
reg.randStart +
reg.width);
148 for (
auto &
reg : randomInit) {
149 if (
reg.randStart == -1) {
150 reg.randStart = maxBit;
167 if (randomInit.empty() && presetInit.empty() && asyncResets.empty())
172 auto loc =
module.getLoc();
173 MLIRContext *context =
module.getContext();
174 auto randInitRef = sv::MacroIdentAttr::get(context,
"RANDOMIZE_REG_INIT");
177 ImplicitLocOpBuilder::atBlockTerminator(loc, module.getBodyBlock());
179 builder.create<
sv::IfDefOp>(
"ENABLE_INITIAL_REG_", [&] {
180 builder.create<sv::OrderedOutputOp>([&] {
181 builder.create<
sv::IfDefOp>(
"FIRRTL_BEFORE_INITIAL", [&] {
182 builder.create<sv::VerbatimOp>(
"`FIRRTL_BEFORE_INITIAL");
185 builder.create<sv::InitialOp>([&] {
186 if (!randomInit.empty()) {
187 builder.create<sv::IfDefProceduralOp>(
"INIT_RANDOM_PROLOG_", [&] {
188 builder.create<sv::VerbatimOp>(
"`INIT_RANDOM_PROLOG_");
190 builder.create<sv::IfDefProceduralOp>(randInitRef, [&] {
192 SmallVector<Value> randValues;
193 auto numRandomCalls = (maxBit + 31) / 32;
194 auto logic = builder.create<sv::LogicOp>(
196 hw::UnpackedArrayType::get(builder.getIntegerType(32),
201 auto inducionVariableWidth = llvm::Log2_64_Ceil(numRandomCalls + 1);
202 auto arrayIndexWith = llvm::Log2_64_Ceil(numRandomCalls);
206 loc, APInt(inducionVariableWidth, numRandomCalls));
209 auto forLoop = builder.create<sv::ForOp>(
210 loc, lb, ub, step,
"i", [&](BlockArgument iter) {
211 auto rhs = builder.create<sv::MacroRefExprSEOp>(
212 loc, builder.getIntegerType(32),
"RANDOM");
213 Value iterValue = iter;
214 if (!iter.getType().isInteger(arrayIndexWith))
216 loc, iterValue, 0, arrayIndexWith);
217 auto lhs = builder.
create<sv::ArrayIndexInOutOp>(loc, logic,
219 builder.create<sv::BPAssignOp>(loc, lhs, rhs);
221 builder.setInsertionPointAfter(forLoop);
222 for (uint64_t x = 0; x < numRandomCalls; ++x) {
223 auto lhs = builder.create<sv::ArrayIndexInOutOp>(
226 randValues.push_back(lhs.getResult());
230 for (
auto &svReg : randomInit)
235 if (!presetInit.empty()) {
236 for (
auto &svReg : presetInit) {
237 auto loc = svReg.reg.getLoc();
238 auto elemTy = svReg.reg.getType().getElementType();
242 if (cst.getType() == elemTy)
247 builder.
create<sv::BPAssignOp>(loc, svReg.reg, rhs);
251 if (!asyncResets.empty()) {
255 for (
auto &reset : asyncResets) {
259 builder.create<sv::IfOp>(reset.first, [&] {
260 for (
auto &
reg : reset.second)
261 builder.create<sv::BPAssignOp>(
reg.reg.getLoc(),
reg.reg,
262 reg.asyncResetValue);
268 builder.create<
sv::IfDefOp>(
"FIRRTL_AFTER_INITIAL", [&] {
269 builder.create<sv::VerbatimOp>(
"`FIRRTL_AFTER_INITIAL");
274 module->removeAttr("firrtl.random_init_width");
290 return c1.getType() == c2.getType() &&
291 c1.getValue() == c2.getValue() &&
301 if (!andOp || !andOp.getTwoState()) {
302 llvm::SetVector<Value> ret;
307 return llvm::SetVector<Value>(andOp.getOperands().begin(),
308 andOp.getOperands().end());
312 auto constantIndex = value.template getDefiningOp<hw::ConstantOp>();
314 return constantIndex.getValue();
323std::optional<std::tuple<Value, Value, Value>>
327 SmallVector<Value> muxConditions;
330 SmallVector<Value> reverseOpValues(llvm::reverse(nextRegValue.getOperands()));
331 if (!llvm::all_of(llvm::enumerate(reverseOpValues), [&](
auto idxAndValue) {
333 auto [i, value] = idxAndValue;
334 auto mux = value.template getDefiningOp<comb::MuxOp>();
336 if (!mux || !mux.getTwoState())
339 if (trueVal && trueVal != mux.getTrueValue())
342 trueVal = mux.getTrueValue();
343 muxConditions.push_back(mux.getCond());
347 mux.getFalseValue().template getDefiningOp<hw::ArrayGetOp>();
356 llvm::SetVector<Value> commonConditions =
358 for (
auto condition : ArrayRef(muxConditions).drop_front()) {
360 commonConditions.remove_if([&](
auto v) {
return !cond.contains(v); });
363 for (
auto [idx, condition] : llvm::enumerate(muxConditions)) {
367 extractedConditions.remove_if(
368 [&](
auto v) {
return commonConditions.contains(v); });
369 if (extractedConditions.size() != 1)
373 (*extractedConditions.begin()).getDefiningOp<comb::ICmpOp>();
374 if (!indexCompare || !indexCompare.getTwoState() ||
375 indexCompare.getPredicate() != comb::ICmpPredicate::eq)
378 if (indexValue && indexValue != indexCompare.getLhs())
381 indexValue = indexCompare.getLhs();
386 OpBuilder::InsertionGuard guard(builder);
387 builder.setInsertionPointAfterValue(
reg);
388 Value commonConditionValue;
389 if (commonConditions.empty())
392 commonConditionValue = builder.createOrFold<
comb::AndOp>(
393 reg.getLoc(), builder.getI1Type(), commonConditions.takeVector(),
true);
394 return std::make_tuple(commonConditionValue, indexValue, trueVal);
406 auto firReg = term.getDefiningOp<seq::FirRegOp>();
408 SmallVector<std::tuple<Block *, Value, Value, Value>> worklist;
409 auto addToWorklist = [&](Value
reg, Value term, Value next) {
410 worklist.push_back({builder.getBlock(),
reg, term, next});
413 auto getArrayIndex = [&](Value
reg, Value idx) {
415 OpBuilder::InsertionGuard guard(builder);
416 builder.setInsertionPointAfterValue(
reg);
417 return builder.create<sv::ArrayIndexInOutOp>(
reg.getLoc(),
reg, idx);
420 SmallVector<Value, 8> opsToDelete;
421 addToWorklist(
reg, term, next);
422 while (!worklist.empty()) {
423 OpBuilder::InsertionGuard guard(builder);
425 Value
reg, term, next;
426 std::tie(block,
reg, term, next) = worklist.pop_back_val();
427 builder.setInsertionPointToEnd(block);
434 if (mux && mux.getTwoState() &&
437 builder, mux.getCond(),
438 [&]() { addToWorklist(reg, term, mux.getTrueValue()); },
439 [&]() { addToWorklist(reg, term, mux.getFalseValue()); });
446 if (
auto matchResultOpt =
448 Value cond, index, trueValue;
449 std::tie(cond, index, trueValue) = *matchResultOpt;
453 Value nextReg = getArrayIndex(
reg, index);
459 opsToDelete.push_back(termElement);
460 addToWorklist(nextReg, termElement, trueValue);
469 for (
auto [idx, value] : llvm::enumerate(array.getOperands())) {
470 idx = array.getOperands().size() - idx - 1;
474 APInt(std::max(1u, llvm::Log2_64_Ceil(array.getOperands().size())),
479 index = getArrayIndex(
reg, idxVal);
486 opsToDelete.push_back(termElement);
487 addToWorklist(index, termElement, value);
492 builder.create<sv::PAssignOp>(term.getLoc(),
reg, next);
495 while (!opsToDelete.empty()) {
496 auto value = opsToDelete.pop_back_val();
497 assert(value.use_empty());
498 value.getDefiningOp()->erase();
503 Location loc =
reg.getLoc();
506 ImplicitLocOpBuilder builder(
reg.getLoc(),
reg);
507 RegLowerInfo svReg{
nullptr,
reg.getPresetAttr(),
nullptr,
nullptr, -1, 0};
508 svReg.reg = builder.create<
sv::RegOp>(loc, regTy,
reg.getNameAttr());
511 if (
auto attr =
reg->getAttrOfType<IntegerAttr>(
"firrtl.random_init_start"))
512 svReg.randStart = attr.getUInt();
515 reg->removeAttr(
"firrtl.random_init_start");
518 svReg.reg->setDialectAttrs(
reg->getDialectAttrs());
520 if (
auto innerSymAttr =
reg.getInnerSymAttr())
521 svReg.reg.setInnerSymAttr(innerSymAttr);
525 if (
reg.hasReset()) {
527 module.getBodyBlock(), sv::EventControl::AtPosEdge,
reg.getClk(),
531 if (reg.getIsAsync() && areEquivalentValues(reg, reg.getNext()))
532 b.create<sv::PAssignOp>(reg.getLoc(), svReg.reg, reg);
534 createTree(b, svReg.reg, reg, reg.getNext());
536 reg.getIsAsync() ? sv::ResetType::AsyncReset :
sv::ResetType::SyncReset,
537 sv::EventControl::AtPosEdge,
reg.getReset(),
538 [&](OpBuilder &builder) {
539 builder.create<sv::PAssignOp>(loc, svReg.reg,
reg.getResetValue());
541 if (
reg.getIsAsync()) {
542 svReg.asyncResetSignal =
reg.getReset();
543 svReg.asyncResetValue =
reg.getResetValue();
547 module.getBodyBlock(), sv::EventControl::AtPosEdge,
reg.getClk(),
548 [&](OpBuilder &b) { createTree(b, svReg.reg, reg, reg.getNext()); });
551 reg.replaceAllUsesWith(regVal.getResult());
562 OpBuilder &builder, Value
reg,
565 auto type = cast<sv::InOutType>(
reg.getType()).getElementType();
566 if (
auto intTy = hw::type_dyn_cast<IntegerType>(type)) {
568 pos -= intTy.getWidth();
569 auto elem = builder.createOrFold<
comb::ExtractOp>(loc, randomSource, pos,
571 builder.
create<sv::BPAssignOp>(loc,
reg, elem);
572 }
else if (
auto array = hw::type_dyn_cast<hw::ArrayType>(type)) {
573 for (
unsigned i = 0, e = array.getNumElements(); i < e; ++i) {
576 loc, builder, builder.create<sv::ArrayIndexInOutOp>(loc,
reg, index),
579 }
else if (
auto structType = hw::type_dyn_cast<hw::StructType>(type)) {
580 for (
auto e : structType.getElements())
583 builder.create<sv::StructFieldInOutOp>(loc,
reg, e.name),
586 assert(
false &&
"unsupported type");
592 ArrayRef<Value> rands) {
593 auto loc =
reg.reg.getLoc();
594 SmallVector<Value> nibbles;
598 uint64_t width =
reg.width;
599 uint64_t offset =
reg.randStart;
601 auto index = offset / 32;
602 auto start = offset % 32;
603 auto nwidth = std::min(32 - start, width);
607 nibbles.push_back(elem);
612 unsigned pos =
reg.width;
618 Block *block, sv::EventControl clockEdge, Value clock,
619 const std::function<
void(OpBuilder &)> &body, sv::ResetType resetStyle,
620 sv::EventControl resetEdge, Value reset,
621 const std::function<
void(OpBuilder &)> &resetBody) {
622 auto loc = clock.getLoc();
623 auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, block);
625 resetStyle, resetEdge, reset};
627 sv::AlwaysOp alwaysOp;
635 assert(resetStyle != sv::ResetType::NoReset);
648 auto createIfOp = [&]() {
651 insideIfOp = builder.create<sv::IfOp>(
652 reset, []() {}, []() {});
654 if (resetStyle == sv::ResetType::AsyncReset) {
655 sv::EventControl events[] = {clockEdge, resetEdge};
656 Value clocks[] = {clock, reset};
658 alwaysOp = builder.create<sv::AlwaysOp>(events, clocks, [&]() {
659 if (resetEdge == sv::EventControl::AtNegEdge)
660 llvm_unreachable(
"negative edge for reset is not expected");
664 alwaysOp = builder.create<sv::AlwaysOp>(clockEdge, clock, createIfOp);
668 alwaysOp = builder.create<sv::AlwaysOp>(clockEdge, clock);
669 insideIfOp =
nullptr;
674 assert(insideIfOp &&
"reset body must be initialized before");
676 ImplicitLocOpBuilder::atBlockEnd(loc, insideIfOp.getThenBlock());
677 resetBody(resetBuilder);
680 ImplicitLocOpBuilder::atBlockEnd(loc, insideIfOp.getElseBlock());
684 ImplicitLocOpBuilder::atBlockEnd(loc, alwaysOp.getBodyBlock());
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static bool areEquivalentValues(Value term, Value next)
static std::optional< APInt > getConstantValue(Value value)
static llvm::SetVector< Value > extractConditions(Value value)
std::unique_ptr< ReachableMuxes > reachableMuxes
void initialize(OpBuilder &builder, RegLowerInfo reg, ArrayRef< Value > rands)
llvm::SmallDenseMap< std::pair< Value, unsigned >, Value > arrayIndexCache
FirRegLowering(TypeConverter &typeConverter, hw::HWModuleOp module, bool disableRegRandomization=false, bool emitSeparateAlwaysBlocks=false)
llvm::SmallDenseMap< IfKeyType, sv::IfOp > ifCache
void addToIfBlock(OpBuilder &builder, Value cond, const std::function< void()> &trueSide, const std::function< void()> &falseSide)
std::optional< std::tuple< Value, Value, Value > > tryRestoringSubaccess(OpBuilder &builder, Value reg, Value term, hw::ArrayCreateOp nextRegValue)
void createTree(OpBuilder &builder, Value reg, Value term, Value next)
unsigned numSubaccessRestored
hw::ConstantOp getOrCreateConstant(Location loc, const APInt &value)
void addToAlwaysBlock(Block *block, sv::EventControl clockEdge, Value clock, const std::function< void(OpBuilder &)> &body, sv::ResetType resetStyle={}, sv::EventControl resetEdge={}, Value reset={}, const std::function< void(OpBuilder &)> &resetBody={})
std::tuple< Block *, sv::EventControl, Value, sv::ResetType, sv::EventControl, Value > AlwaysKeyType
void initializeRegisterElements(Location loc, OpBuilder &builder, Value reg, Value rand, unsigned &pos)
TypeConverter & typeConverter
hw::HWModuleOp bool disableRegRandomization
bool emitSeparateAlwaysBlocks
llvm::SmallDenseMap< AlwaysKeyType, std::pair< sv::AlwaysOp, sv::IfOp > > alwaysBlocks
void buildReachabilityFrom(Operation *startNode)
llvm::SmallPtrSet< Operation *, 16 > visited
HWModuleOp llvm::DenseMap< Operation *, llvm::SmallDenseSet< Operation * > > reachableMuxes
bool isMuxReachableFrom(seq::FirRegOp regOp, comb::MuxOp muxOp)
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
static std::function< bool(const Operation *op)> opAllowsReachability