345 PatternRewriter &rewriter)
const override {
346 auto inputWidth = op.getOperand(0).getType().getIntOrFloatBitWidth();
347 auto lhs =
isSext(op.getOperand(0));
348 auto rhs =
isSext(op.getOperand(1));
349 if (failed(lhs) || failed(rhs))
352 size_t lhsWidth = (*lhs).getType().getIntOrFloatBitWidth();
353 size_t rhsWidth = (*rhs).getType().getIntOrFloatBitWidth();
355 size_t maxRows = std::max(lhsWidth, rhsWidth) - 1;
359 if (lhsWidth != rhsWidth || lhsWidth <= 1 || rhsWidth <= 1)
363 if (maxRows >= op.getNumResults())
367 auto lhsBaseWidth = lhsWidth - 1;
368 auto rhsBaseWidth = rhsWidth - 1;
380 comb::createZExt(rewriter, op.getLoc(), lhsBase, inputWidth);
382 comb::createZExt(rewriter, op.getLoc(), rhsBase, inputWidth);
383 auto newPP = datapath::PartialProductOp::create(
384 rewriter, op.getLoc(), ValueRange{lhsBaseZext, rhsBaseZext}, maxRows);
393 auto lhsSignReplicate = comb::ReplicateOp::create(rewriter, op.getLoc(),
394 lhsSignBit, rhsBaseWidth);
396 comb::AndOp::create(rewriter, op.getLoc(), lhsSignReplicate, rhsBase);
397 auto lhsSignCorrection =
398 comb::createOrFoldNot(op.getLoc(), lhsSignAndRhs, rewriter,
true);
401 auto alignLhsSignCorrection =
zeroPad(
402 rewriter, op.getLoc(), lhsSignCorrection, inputWidth, lhsBaseWidth);
405 auto rhsSignReplicate = comb::ReplicateOp::create(rewriter, op.getLoc(),
406 rhsSignBit, lhsBaseWidth);
408 comb::AndOp::create(rewriter, op.getLoc(), rhsSignReplicate, lhsBase);
409 auto rhsSignCorrection =
410 comb::createOrFoldNot(op.getLoc(), rhsSignAndLhs, rewriter,
true);
413 auto alignRhsSignCorrection =
zeroPad(
414 rewriter, op.getLoc(), rhsSignCorrection, inputWidth, rhsBaseWidth);
419 comb::AndOp::create(rewriter, op.getLoc(), lhsSignBit, rhsSignBit);
421 auto alignSignAndZext =
zeroPad(rewriter, op.getLoc(), signAnd, inputWidth,
422 lhsBaseWidth + rhsBaseWidth);
426 auto ones = APInt::getAllOnes(inputWidth);
427 auto lowerLhs = APInt::getOneBitSet(inputWidth, lhsBaseWidth);
428 auto lowerRhs = APInt::getOneBitSet(inputWidth, rhsBaseWidth);
429 auto msbCorrection = ones << (lhsBaseWidth + rhsBaseWidth);
430 auto correction = lowerLhs + lowerRhs + 2 * msbCorrection;
432 auto constantCorrection =
436 APInt::getZero(inputWidth));
438 SmallVector<Value> newResults(newPP.getResults().begin(),
439 newPP.getResults().end());
442 newResults.push_back(alignLhsSignCorrection);
444 newResults.push_back(alignRhsSignCorrection);
446 newResults.push_back(alignSignAndZext);
448 newResults.push_back(constantCorrection);
450 newResults.append(op.getNumResults() - newResults.size(), zero);
452 rewriter.replaceOp(op, newResults);