471 PatternRewriter &rewriter)
const override {
472 auto inputWidth = op.getOperand(0).getType().getIntOrFloatBitWidth();
475 if (!matchPattern(op.getOperand(0), comb::m_Sext(m_Any(&lhs))) ||
476 !matchPattern(op.getOperand(1), comb::m_Sext(m_Any(&rhs))))
479 size_t lhsWidth = lhs.getType().getIntOrFloatBitWidth();
480 size_t rhsWidth = rhs.getType().getIntOrFloatBitWidth();
482 size_t maxRows = std::max(lhsWidth, rhsWidth) - 1;
486 if (lhsWidth != rhsWidth || lhsWidth <= 1 || rhsWidth <= 1)
490 if (maxRows >= op.getNumResults())
494 auto lhsBaseWidth = lhsWidth - 1;
495 auto rhsBaseWidth = rhsWidth - 1;
507 comb::createZExt(rewriter, op.getLoc(), lhsBase, inputWidth);
509 comb::createZExt(rewriter, op.getLoc(), rhsBase, inputWidth);
510 auto newPP = datapath::PartialProductOp::create(
511 rewriter, op.getLoc(), ValueRange{lhsBaseZext, rhsBaseZext}, maxRows);
520 auto lhsSignReplicate = comb::ReplicateOp::create(rewriter, op.getLoc(),
521 lhsSignBit, rhsBaseWidth);
523 comb::AndOp::create(rewriter, op.getLoc(), lhsSignReplicate, rhsBase);
524 auto lhsSignCorrection =
525 comb::createOrFoldNot(op.getLoc(), lhsSignAndRhs, rewriter,
true);
528 auto alignLhsSignCorrection =
zeroPad(
529 rewriter, op.getLoc(), lhsSignCorrection, inputWidth, lhsBaseWidth);
532 auto rhsSignReplicate = comb::ReplicateOp::create(rewriter, op.getLoc(),
533 rhsSignBit, lhsBaseWidth);
535 comb::AndOp::create(rewriter, op.getLoc(), rhsSignReplicate, lhsBase);
536 auto rhsSignCorrection =
537 comb::createOrFoldNot(op.getLoc(), rhsSignAndLhs, rewriter,
true);
540 auto alignRhsSignCorrection =
zeroPad(
541 rewriter, op.getLoc(), rhsSignCorrection, inputWidth, rhsBaseWidth);
546 comb::AndOp::create(rewriter, op.getLoc(), lhsSignBit, rhsSignBit);
548 auto alignSignAndZext =
zeroPad(rewriter, op.getLoc(), signAnd, inputWidth,
549 lhsBaseWidth + rhsBaseWidth);
553 auto ones = APInt::getAllOnes(inputWidth);
554 auto lowerLhs = APInt::getOneBitSet(inputWidth, lhsBaseWidth);
555 auto lowerRhs = APInt::getOneBitSet(inputWidth, rhsBaseWidth);
556 auto msbCorrection = ones << (lhsBaseWidth + rhsBaseWidth);
557 auto correction = lowerLhs + lowerRhs + 2 * msbCorrection;
559 auto constantCorrection =
563 APInt::getZero(inputWidth));
565 SmallVector<Value> newResults(newPP.getResults().begin(),
566 newPP.getResults().end());
569 newResults.push_back(alignLhsSignCorrection);
571 newResults.push_back(alignRhsSignCorrection);
573 newResults.push_back(alignSignAndZext);
575 newResults.push_back(constantCorrection);
577 newResults.append(op.getNumResults() - newResults.size(), zero);
579 rewriter.replaceOp(op, newResults);