22 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/Support/Debug.h"
25 #define DEBUG_TYPE "llhd-desequentialization"
29 #define GEN_PASS_DEF_DESEQUENTIALIZATION
30 #include "circt/Dialect/LLHD/Transforms/Passes.h.inc"
34 using namespace circt;
41 class CombInterpreter :
public comb::CombinationalVisitor<CombInterpreter> {
48 APInt compute(ArrayRef<Value> primitives, ArrayRef<APInt> truthTable,
49 uint64_t width, Value root) {
50 assert(primitives.size() == truthTable.size() &&
"must have same size");
52 for (
auto [p, t] : llvm::zip(primitives, truthTable))
57 DenseSet<Operation *> visited;
58 SmallVector<Value> worklist;
59 worklist.push_back(root);
61 while (!worklist.empty()) {
62 auto curr = worklist.back();
63 if (results.contains(curr)) {
68 auto *defOp = curr.getDefiningOp();
74 if (!visited.contains(defOp)) {
75 visited.insert(defOp);
77 bool addedToWorklist =
78 TypeSwitch<Operation *, bool>(defOp)
80 worklist.append(llvm::to_vector(op->getOperands()));
83 .Case([&](comb::ICmpOp op) {
84 if ((op.getPredicate() == circt::comb::ICmpPredicate::eq ||
85 op.getPredicate() == circt::comb::ICmpPredicate::ne) &&
86 op.getLhs().getType().isSignlessInteger(1)) {
87 worklist.append(llvm::to_vector(op->getOperands()));
93 results[op.getResult()] = op.getValue().getBoolValue()
94 ? APInt::getAllOnes(width)
103 dispatchCombinationalVisitor(defOp);
107 return results[root];
111 auto res = APInt::getAllOnes(width);
112 for (
auto operand : op->getOperands())
113 res &= results[operand];
114 results[op.getResult()] = res;
118 auto res = APInt(width, 0);
119 for (
auto operand : op->getOperands())
120 res |= results[operand];
121 results[op.getResult()] = res;
125 auto res = results[op->getOperands()[0]];
126 for (
auto operand : op->getOperands().drop_front())
127 res ^= results[operand];
128 results[op.getResult()] = res;
131 void visitComb(comb::ICmpOp op) {
132 auto res = results[op.getLhs()];
133 res ^= results[op.getRhs()];
134 if (op.getPredicate() == comb::ICmpPredicate::eq)
135 res ^= APInt::getAllOnes(width);
136 results[op.getResult()] = res;
141 void visitComb(
comb::AddOp op) { visitInvalidComb(op); }
142 void visitComb(
comb::SubOp op) { visitInvalidComb(op); }
143 void visitComb(
comb::MulOp op) { visitInvalidComb(op); }
144 void visitComb(
comb::DivUOp op) { visitInvalidComb(op); }
145 void visitComb(
comb::DivSOp op) { visitInvalidComb(op); }
146 void visitComb(
comb::ModUOp op) { visitInvalidComb(op); }
147 void visitComb(
comb::ModSOp op) { visitInvalidComb(op); }
148 void visitComb(
comb::ShlOp op) { visitInvalidComb(op); }
149 void visitComb(
comb::ShrUOp op) { visitInvalidComb(op); }
150 void visitComb(
comb::ShrSOp op) { visitInvalidComb(op); }
153 void visitComb(comb::ReplicateOp op) { visitInvalidComb(op); }
155 void visitComb(
comb::MuxOp op) { visitInvalidComb(op); }
158 DenseMap<Value, APInt> results;
171 static StringRef stringify(
const Trigger::Kind &kind) {
173 case Trigger::Kind::PosEdge:
175 case Trigger::Kind::NegEdge:
177 case Trigger::Kind::Edge:
180 llvm::llvm_unreachable_internal(
"all cases considered above");
184 SmallVector<Value> clocks;
188 SmallVector<Kind> kinds;
195 template <
typename T>
197 return os << Trigger::stringify(kind);
205 DnfAnalyzer(Value value, function_ref<
bool(Value)> sampledInPast)
207 assert(value.getType().isSignlessInteger(1) &&
208 "only 1-bit signless integers supported");
210 DenseSet<Value> alreadyAdded;
212 SmallVector<Value> worklist;
213 worklist.push_back(value);
215 while (!worklist.empty()) {
216 Value curr = worklist.pop_back_val();
217 auto *defOp = curr.getDefiningOp();
219 if (!alreadyAdded.contains(curr)) {
220 primitives.push_back(curr);
221 primitiveSampledInPast.push_back(sampledInPast(curr));
222 alreadyAdded.insert(curr);
227 TypeSwitch<Operation *>(defOp)
229 worklist.append(llvm::to_vector(op->getOperands()));
231 .Case([&](comb::ICmpOp op) {
232 if ((op.getPredicate() == circt::comb::ICmpPredicate::eq ||
233 op.getPredicate() == circt::comb::ICmpPredicate::ne) &&
234 op.getLhs().getType().isSignlessInteger(1)) {
235 worklist.append(llvm::to_vector(op->getOperands()));
237 if (!alreadyAdded.contains(curr)) {
238 primitives.push_back(curr);
239 primitiveSampledInPast.push_back(sampledInPast(curr));
240 alreadyAdded.insert(curr);
244 .Case<hw::ConstantOp>([](
auto op) { })
245 .Default([&](
auto op) {
246 if (!alreadyAdded.contains(curr)) {
247 primitives.push_back(curr);
248 primitiveSampledInPast.push_back(sampledInPast(curr));
249 alreadyAdded.insert(curr);
255 for (
auto val : primitives)
256 llvm::dbgs() <<
" - Primitive variable: " << val <<
"\n";
259 this->isClock = SmallVector<bool>(primitives.size(),
false);
260 this->dontCare = SmallVector<APInt>(primitives.size(),
261 APInt(1ULL << primitives.size(), 0));
269 computeTriggers(OpBuilder &builder, Location loc,
270 function_ref<
bool(Value, Value)> sampledFromSameSignal,
271 SmallVectorImpl<Trigger> &triggers,
unsigned maxPrimitives) {
272 if (primitives.size() > maxPrimitives) {
273 LLVM_DEBUG({ llvm::dbgs() <<
" Too many primitives, skipping...\n"; });
283 if (failed(computeClockValuePairs(sampledFromSameSignal)))
287 simplifyTruthTable();
290 llvm::dbgs() <<
" - Truth table:\n";
292 for (
auto [t, d] : llvm::zip(truthTable, dontCare))
293 llvm::dbgs() <<
" " <<
FVInt(std::move(t), std::move(d)) <<
"\n";
295 SmallVector<char> str;
299 for (
unsigned i = 0; i < result.getBitWidth() - str.size(); ++i)
302 llvm::dbgs() << str <<
"\n";
308 materializeTriggerEnables(builder, loc);
311 extractTriggerList(triggers);
314 canonicalizeTriggerList(triggers, builder, loc);
320 FVInt computeEnableKey(
unsigned tableRow) {
322 for (
unsigned k = 0; k < primitives.size(); ++k) {
323 if (dontCare[k][tableRow])
326 if (primitiveSampledInPast[k])
334 key.
setBit(k, truthTable[k][tableRow]);
340 void extractTriggerList(SmallVectorImpl<Trigger> &triggers) {
341 for (uint64_t i = 0, e = 1ULL << primitives.size(); i < e; ++i) {
345 auto key = computeEnableKey(i);
347 for (
auto clk : clockPairs) {
348 if (dontCare[
clk.second][i] && dontCare[
clk.first][i])
351 trigger.clocks.push_back(primitives[
clk.second]);
352 trigger.kinds.push_back(truthTable[
clk.second][i]
353 ? Trigger::Kind::PosEdge
354 : Trigger::Kind::NegEdge);
356 trigger.enable = enableMap[key];
358 if (!trigger.clocks.empty())
359 triggers.push_back(trigger);
363 void materializeTriggerEnables(OpBuilder &builder, Location loc) {
366 for (uint64_t i = 0, e = 1ULL << primitives.size(); i < e; ++i) {
370 auto key = computeEnableKey(i);
372 if (!enableMap.contains(key)) {
373 SmallVector<Value> conjuncts;
374 for (
unsigned k = 0; k < primitives.size(); ++k) {
378 if (primitiveSampledInPast[k])
386 if (truthTable[k][i]) {
387 conjuncts.push_back(primitives[k]);
391 builder.create<
comb::XorOp>(loc, primitives[k], trueVal));
393 if (!conjuncts.empty())
395 builder.createOrFold<
comb::AndOp>(loc, conjuncts,
false);
400 LogicalResult computeClockValuePairs(
401 function_ref<
bool(Value, Value)> sampledFromSameSignal) {
402 for (
unsigned k = 0; k < primitives.size(); ++k) {
406 for (
unsigned l = k + 1; l < primitives.size(); ++l) {
407 if (sampledFromSameSignal(primitives[k], primitives[l]) &&
408 (primitiveSampledInPast[k] != primitiveSampledInPast[l])) {
409 if (primitiveSampledInPast[k])
410 clockPairs.emplace_back(k, l);
412 clockPairs.emplace_back(l, k);
417 if (primitiveSampledInPast[k] && !isClock[k])
424 void simplifyTruthTable() {
425 uint64_t numEntries = 1 << primitives.size();
431 for (uint64_t i = 0; i < numEntries; ++i) {
435 for (uint64_t k = i + 1; k < numEntries; ++k) {
439 unsigned differenceCount = 0;
440 for (
unsigned l = 0; l < primitives.size(); ++l) {
441 if (truthTable[l][i] != truthTable[l][k])
443 if (differenceCount > 1)
447 if (differenceCount == 1) {
448 for (
unsigned l = 0; l < primitives.size(); ++l) {
449 dontCare[l].setBit(k);
450 if (truthTable[l][i] != truthTable[l][k])
451 dontCare[l].setBit(i);
458 void computeTruthTable() {
459 uint64_t numEntries = 1 << primitives.size();
460 for (
auto _ [[maybe_unused]] : primitives)
461 truthTable.push_back(APInt(numEntries, 0));
463 for (uint64_t i = 0; i < numEntries; ++i)
464 for (
unsigned k = 0; k < primitives.size(); ++k)
465 truthTable[k].setBitVal(i, APInt(64, i)[k]);
468 CombInterpreter().compute(primitives, truthTable, numEntries, root);
471 void canonicalizeTriggerList(SmallVectorImpl<Trigger> &triggers,
472 OpBuilder &builder, Location loc) {
473 for (
auto *iter1 = triggers.begin(); iter1 != triggers.end(); ++iter1) {
474 for (
auto *iter2 = iter1 + 1; iter2 != triggers.end(); ++iter2) {
475 if (iter1->clocks == iter2->clocks && iter1->kinds == iter2->kinds) {
477 builder.create<
comb::OrOp>(loc, iter1->enable, iter2->enable);
478 triggers.erase(iter2--);
488 SmallVector<Value> primitives;
489 SmallVector<bool> isClock;
490 SmallVector<bool> primitiveSampledInPast;
491 SmallVector<APInt> truthTable;
492 SmallVector<APInt> dontCare;
493 SmallVector<std::pair<unsigned, unsigned>> clockPairs;
494 DenseMap<FVInt, Value> enableMap;
498 struct DesequentializationPass
499 :
public llhd::impl::DesequentializationBase<DesequentializationPass> {
500 DesequentializationPass()
501 : llhd::impl::DesequentializationBase<DesequentializationPass>() {}
502 DesequentializationPass(
const llhd::DesequentializationOptions &options)
503 : llhd::impl::DesequentializationBase<DesequentializationPass>(options) {
504 maxPrimitives.setValue(options.maxPrimitives);
506 void runOnOperation()
override;
507 void runOnProcess(llhd::ProcessOp procOp)
const;
509 isSupportedSequentialProcess(llhd::ProcessOp procOp,
511 SmallVectorImpl<Value> &observed)
const;
515 LogicalResult DesequentializationPass::isSupportedSequentialProcess(
517 SmallVectorImpl<Value> &observed)
const {
526 llvm::dbgs() <<
" Combinational process -> no need to desequentialize\n";
531 if (numTRs > 2 || procOp.getBody().getBlocks().size() != 3) {
533 { llvm::dbgs() <<
" Complex sequential process -> not supported\n"; });
537 bool seenWait =
false;
538 WalkResult result = procOp.walk([&](llhd::WaitOp op) -> WalkResult {
539 LLVM_DEBUG({ llvm::dbgs() <<
" Analyzing Wait Operation:\n"; });
540 for (
auto obs : op.getObserved()) {
541 observed.push_back(obs);
542 LLVM_DEBUG({ llvm::dbgs() <<
" - Observes: " << obs <<
"\n"; });
544 LLVM_DEBUG({ llvm::dbgs() <<
"\n"; });
552 trAnalysis.
getBlockTR(op.getOperation()->getBlock())))
556 return WalkResult::advance();
559 if (result.wasInterrupted() || !seenWait) {
561 { llvm::dbgs() <<
" Complex sequential process -> not supported\n"; });
566 { llvm::dbgs() <<
" Sequential process, attempt lowering...\n"; });
571 void DesequentializationPass::runOnProcess(llhd::ProcessOp procOp)
const {
573 std::string line(74,
'-');
574 llvm::dbgs() <<
"\n===" << line <<
"===\n";
575 llvm::dbgs() <<
"=== Process\n";
576 llvm::dbgs() <<
"===" << line <<
"===\n";
582 SmallVector<Value> observed;
583 if (failed(isSupportedSequentialProcess(procOp, trAnalysis, observed)))
586 OpBuilder builder(procOp);
587 WalkResult result = procOp.walk([&](llhd::DrvOp op) {
588 LLVM_DEBUG({ llvm::dbgs() <<
"\n Lowering Drive Operation\n"; });
590 if (!op.getEnable()) {
591 LLVM_DEBUG({ llvm::dbgs() <<
" - No enable condition -> skip\n"; });
592 return WalkResult::advance();
595 Location loc = op.getLoc();
596 builder.setInsertionPoint(op);
597 int presentTR = trAnalysis.
getBlockTR(op.getOperation()->getBlock());
599 auto sampledInPast = [&](Value value) ->
bool {
600 if (isa<BlockArgument>(value))
603 if (!procOp->isAncestor(value.getDefiningOp()))
606 return trAnalysis.
getBlockTR(value.getDefiningOp()->getBlock()) !=
610 LLVM_DEBUG({ llvm::dbgs() <<
" - Analyzing enable condition...\n"; });
612 SmallVector<Trigger> triggers;
613 auto sampledFromSameSignal = [](Value val1, Value val2) ->
bool {
614 if (
auto prb1 = val1.getDefiningOp<llhd::PrbOp>())
615 if (
auto prb2 = val2.getDefiningOp<llhd::PrbOp>())
616 return prb1.getSignal() == prb2.getSignal();
623 if (failed(DnfAnalyzer(op.getEnable(), sampledInPast)
624 .computeTriggers(builder, loc, sampledFromSameSignal,
625 triggers, maxPrimitives))) {
627 llvm::dbgs() <<
" Unable to compute trigger list for drive condition, "
630 return WalkResult::interrupt();
634 if (triggers.empty())
635 llvm::dbgs() <<
" - no triggers found!\n";
639 for (
auto trigger : triggers) {
640 llvm::dbgs() <<
" - Trigger\n";
641 for (
auto [clk, kind] : llvm::zip(trigger.clocks, trigger.kinds))
642 llvm::dbgs() <<
" - " << kind <<
" "
643 <<
"clock: " <<
clk <<
"\n";
646 llvm::dbgs() <<
" with enable: " << trigger.enable <<
"\n";
651 if (triggers.size() > 2 || triggers.empty())
652 return WalkResult::interrupt();
655 if (triggers[0].clocks.size() != 1 || triggers[0].clocks.size() != 1)
656 return WalkResult::interrupt();
659 if (triggers[0].kinds[0] == Trigger::Kind::Edge)
660 return WalkResult::interrupt();
662 if (!llvm::any_of(observed, [&](Value val) {
663 return sampledFromSameSignal(val, triggers[0].clocks[0]) &&
664 val.getParentRegion() != procOp.getBody();
666 return WalkResult::interrupt();
668 Value clock = builder.create<seq::ToClockOp>(loc, triggers[0].clocks[0]);
669 Value reset, resetValue;
671 if (triggers[0].kinds[0] == Trigger::Kind::NegEdge)
672 clock = builder.create<seq::ClockInverterOp>(loc, clock);
674 if (triggers[0].enable)
675 clock = builder.create<seq::ClockGateOp>(loc, clock, triggers[0].enable);
677 if (triggers.size() == 2) {
679 if (triggers[1].clocks.size() != 1 || triggers[1].kinds.size() != 1)
680 return WalkResult::interrupt();
683 if (triggers[1].kinds[0] == Trigger::Kind::Edge)
684 return WalkResult::interrupt();
687 if (triggers[1].enable)
688 return WalkResult::interrupt();
690 if (!llvm::any_of(observed, [&](Value val) {
691 return sampledFromSameSignal(val, triggers[1].clocks[0]) &&
692 val.getParentRegion() != procOp.getBody();
694 return WalkResult::interrupt();
696 reset = triggers[1].clocks[0];
697 resetValue = op.getValue();
699 if (triggers[1].kinds[0] == Trigger::Kind::NegEdge) {
708 Value regOut = builder.create<
seq::CompRegOp>(loc, op.getValue(), clock,
711 op.getEnableMutable().clear();
712 op.getValueMutable().assign(regOut);
715 { llvm::dbgs() <<
" Lowered Drive Operation successfully!\n\n"; });
717 return WalkResult::advance();
720 if (result.wasInterrupted())
723 IRRewriter rewriter(builder);
724 auto &entryBlock = procOp.getBody().getBlocks().front();
727 for (Block &block : procOp.getBody().getBlocks()) {
728 block.getTerminator()->erase();
730 if (!block.isEntryBlock())
731 entryBlock.getOperations().splice(entryBlock.end(),
732 block.getOperations());
735 rewriter.inlineBlockBefore(&entryBlock, procOp);
738 LLVM_DEBUG({ llvm::dbgs() <<
"Lowered process successfully!\n"; });
741 void DesequentializationPass::runOnOperation() {
744 llvm::make_early_inc_range(moduleOp.getOps<llhd::ProcessOp>()))
745 runOnProcess(procOp);
assert(baseType &&"element must be base type")
Four-valued arbitrary precision integers.
SmallString< 16 > toString(unsigned radix=10, bool uppercase=true) const
Convert an FVInt to a string.
static FVInt getAllX(unsigned numBits)
Construct an FVInt with all bits set to X.
void setBit(unsigned index, Bit bit)
Set the value of an individual bit.
def create(data_type, value)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
raw_ostream & operator<<(raw_ostream &os, const FVInt &value)
int getBlockTR(Block *) const
unsigned getNumTemporalRegions() const
bool hasSingleExitBlock(int tr) const