19#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
20#include "mlir/Conversion/LLVMCommon/Pattern.h"
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Math/IR/Math.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/IR/AsmState.h"
28#include "mlir/IR/Matchers.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Support/LogicalResult.h"
31#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/LogicalResult.h"
34#include "llvm/Support/raw_os_ostream.h"
35#include "llvm/Support/raw_ostream.h"
45#define GEN_PASS_DEF_SCFTOCALYX
46#include "circt/Conversion/Passes.h.inc"
51using namespace mlir::arith;
52using namespace mlir::cf;
55class ComponentLoweringStateInterface;
83 std::optional<int64_t>
getBound()
override {
return std::nullopt; }
145 Operation *operation = op.getOperation();
146 auto [it, succeeded] =
condReg.insert(std::make_pair(operation, regOp));
148 "A condition register was already set for this scf::IfOp!");
152 auto it =
condReg.find(op.getOperation());
159 Operation *operation = op.getOperation();
161 "A then group was already set for this scf::IfOp!\n");
166 auto it =
thenGroup.find(op.getOperation());
168 "No then group was set for this scf::IfOp!\n");
173 Operation *operation = op.getOperation();
175 "An else group was already set for this scf::IfOp!\n");
180 auto it =
elseGroup.find(op.getOperation());
182 "No else group was set for this scf::IfOp!\n");
188 "A register was already registered for the given yield result.\n");
189 assert(idx < op->getNumOperands());
199 auto it = regs.find(idx);
200 assert(it != regs.end() &&
"resultReg not found");
206 DenseMap<Operation *, calyx::RegisterOp>
condReg;
209 DenseMap<Operation *, DenseMap<unsigned, calyx::RegisterOp>>
resultRegs;
219 OpBuilder &builder,
ScfWhileOp op, calyx::ComponentOp componentOp,
220 Twine uniqueSuffix, MutableArrayRef<OpOperand> ops) {
227 const DenseMap<unsigned, calyx::RegisterOp> &
238 SmallVector<calyx::GroupOp> groups) {
250 OpBuilder &builder,
ScfForOp op, calyx::ComponentOp componentOp,
251 Twine uniqueSuffix, MutableArrayRef<OpOperand> ops) {
280 auto cellOp = dyn_cast<calyx::CellInterface>(op);
281 assert(cellOp && !cellOp.isCombinational());
282 auto [it, succeeded] =
resultRegs.insert(std::make_pair(op, reg));
284 "A register was already set for this sequential operation!");
290 "No register was set for this sequential operation!");
324 DenseMap<mlir::func::FuncOp, calyx::ComponentOp> &map,
326 mlir::Pass::Option<std::string> &writeJsonOpt)
329 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
333 PatternRewriter &rewriter)
const override {
336 bool opBuiltSuccessfully =
true;
337 funcOp.walk([&](Operation *_op) {
338 opBuiltSuccessfully &=
339 TypeSwitch<mlir::Operation *, bool>(_op)
340 .template Case<arith::ConstantOp, ReturnOp, BranchOpInterface,
342 scf::YieldOp, scf::WhileOp, scf::ForOp, scf::IfOp,
343 scf::ParallelOp, scf::ReduceOp,
344 scf::ExecuteRegionOp,
346 memref::AllocOp, memref::AllocaOp, memref::LoadOp,
347 memref::StoreOp, memref::GetGlobalOp,
349 AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
350 AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
351 MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
353 AddFOp, SubFOp, MulFOp, CmpFOp, FPToSIOp, SIToFPOp,
354 DivFOp, math::SqrtOp, math::AbsFOp,
356 SelectOp, IndexCastOp, BitcastOp, CallOp>(
357 [&](
auto op) {
return buildOp(rewriter, op).succeeded(); })
358 .
template Case<FuncOp, scf::ConditionOp>([&](
auto) {
362 .Default([&](
auto op) {
363 op->emitError() <<
"Unhandled operation during BuildOpGroups()";
367 return opBuiltSuccessfully ? WalkResult::advance()
368 : WalkResult::interrupt();
372 auto &extMemData = getState<ComponentLoweringState>().getExtMemData();
373 if (extMemData.getAsObject()->empty())
376 if (
auto fileLoc = dyn_cast<mlir::FileLineColLoc>(funcOp->getLoc())) {
377 std::string filename = fileLoc.getFilename().str();
378 std::filesystem::path path(filename);
379 std::string jsonFileName =
writeJson.getValue() +
".json";
380 auto outFileName = path.parent_path().append(jsonFileName);
381 std::ofstream outFile(outFileName);
383 if (!outFile.is_open()) {
384 llvm::errs() <<
"Unable to open file: " << outFileName.string()
388 llvm::raw_os_ostream llvmOut(outFile);
389 llvm::json::OStream jsonOS(llvmOut, 2);
390 jsonOS.value(extMemData);
396 return success(opBuiltSuccessfully);
// Per-operation lowering entry points. Each overload is selected by the
// TypeSwitch in this pattern's rewrite method, which calls
// `buildOp(rewriter, op).succeeded()` for every listed op kind and folds the
// results into `opBuiltSuccessfully`. Returning failure() therefore aborts the
// walk over the function body.
//
// Control flow / constants.
402 LogicalResult
buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp)
const;
403 LogicalResult
buildOp(PatternRewriter &rewriter,
404 BranchOpInterface brOp)
const;
405 LogicalResult
buildOp(PatternRewriter &rewriter,
406 arith::ConstantOp constOp)
const;
// Integer arithmetic. Mul/Div/Rem map to pipelined Calyx library ops
// (MultPipeLibOp, DivUPipeLibOp, DivSPipeLibOp, RemUPipeLibOp, RemSPipeLibOp)
// via buildLibraryBinaryPipeOp.
407 LogicalResult
buildOp(PatternRewriter &rewriter, SelectOp op)
const;
408 LogicalResult
buildOp(PatternRewriter &rewriter, AddIOp op)
const;
409 LogicalResult
buildOp(PatternRewriter &rewriter, SubIOp op)
const;
410 LogicalResult
buildOp(PatternRewriter &rewriter, MulIOp op)
const;
411 LogicalResult
buildOp(PatternRewriter &rewriter, DivUIOp op)
const;
412 LogicalResult
buildOp(PatternRewriter &rewriter, DivSIOp op)
const;
413 LogicalResult
buildOp(PatternRewriter &rewriter, RemUIOp op)
const;
414 LogicalResult
buildOp(PatternRewriter &rewriter, RemSIOp op)
const;
// Floating point. These map to the IEEE-754 Calyx library ops
// (AddFOpIEEE754, MulFOpIEEE754, CompareFOpIEEE754, FpToIntOpIEEE754,
// IntToFpOpIEEE754, DivSqrtOpIEEE754). math::AbsFOp is instead lowered to a
// combinational AND with a sign-clearing mask.
415 LogicalResult
buildOp(PatternRewriter &rewriter, AddFOp op)
const;
416 LogicalResult
buildOp(PatternRewriter &rewriter, SubFOp op)
const;
417 LogicalResult
buildOp(PatternRewriter &rewriter, MulFOp op)
const;
418 LogicalResult
buildOp(PatternRewriter &rewriter, CmpFOp op)
const;
419 LogicalResult
buildOp(PatternRewriter &rewriter, FPToSIOp op)
const;
420 LogicalResult
buildOp(PatternRewriter &rewriter, SIToFPOp op)
const;
421 LogicalResult
buildOp(PatternRewriter &rewriter, DivFOp op)
const;
422 LogicalResult
buildOp(PatternRewriter &rewriter, math::SqrtOp op)
const;
423 LogicalResult
buildOp(PatternRewriter &rewriter, math::AbsFOp op)
const;
// Shifts, bitwise logic, comparisons and width/type casts.
424 LogicalResult
buildOp(PatternRewriter &rewriter, ShRUIOp op)
const;
425 LogicalResult
buildOp(PatternRewriter &rewriter, ShRSIOp op)
const;
426 LogicalResult
buildOp(PatternRewriter &rewriter, ShLIOp op)
const;
427 LogicalResult
buildOp(PatternRewriter &rewriter, AndIOp op)
const;
428 LogicalResult
buildOp(PatternRewriter &rewriter, OrIOp op)
const;
429 LogicalResult
buildOp(PatternRewriter &rewriter, XOrIOp op)
const;
430 LogicalResult
buildOp(PatternRewriter &rewriter, CmpIOp op)
const;
431 LogicalResult
buildOp(PatternRewriter &rewriter, TruncIOp op)
const;
432 LogicalResult
buildOp(PatternRewriter &rewriter, ExtUIOp op)
const;
433 LogicalResult
buildOp(PatternRewriter &rewriter, ExtSIOp op)
const;
434 LogicalResult
buildOp(PatternRewriter &rewriter, ReturnOp op)
const;
435 LogicalResult
buildOp(PatternRewriter &rewriter, IndexCastOp op)
const;
436 LogicalResult
buildOp(PatternRewriter &rewriter, BitcastOp op)
const;
// Memories. Alloc, Alloca and GetGlobal all delegate to buildAllocOp;
// Load/Store drive the memory interface ports of the accessed memref.
437 LogicalResult
buildOp(PatternRewriter &rewriter, memref::AllocOp op)
const;
438 LogicalResult
buildOp(PatternRewriter &rewriter, memref::AllocaOp op)
const;
439 LogicalResult
buildOp(PatternRewriter &rewriter,
440 memref::GetGlobalOp op)
const;
441 LogicalResult
buildOp(PatternRewriter &rewriter, memref::LoadOp op)
const;
442 LogicalResult
buildOp(PatternRewriter &rewriter, memref::StoreOp op)
const;
// Structured control flow and calls. FuncOp and scf::ConditionOp have no
// overload; the TypeSwitch handles them in a separate case.
443 LogicalResult
buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp)
const;
444 LogicalResult
buildOp(PatternRewriter &rewriter, scf::ForOp forOp)
const;
445 LogicalResult
buildOp(PatternRewriter &rewriter, scf::IfOp ifOp)
const;
446 LogicalResult
buildOp(PatternRewriter &rewriter,
447 scf::ReduceOp reduceOp)
const;
448 LogicalResult
buildOp(PatternRewriter &rewriter,
449 scf::ParallelOp parallelOp)
const;
450 LogicalResult
buildOp(PatternRewriter &rewriter,
451 scf::ExecuteRegionOp executeRegionOp)
const;
452 LogicalResult
buildOp(PatternRewriter &rewriter, CallOp callOp)
const;
457 template <
typename TCalyxLibOp>
458 void setupCmpIOp(PatternRewriter &rewriter, CmpIOp cmpIOp, Operation *group,
459 calyx::RegisterOp &condReg, calyx::RegisterOp &resReg,
460 TCalyxLibOp calyxOp)
const {
464 StringRef opName = cmpIOp.getOperationName().split(
".").second;
465 Type width = cmpIOp.getResult().getType();
467 condReg = createRegister(
469 width.getIntOrFloatBitWidth(),
470 getState<ComponentLoweringState>().getUniqueName(opName));
472 for (
auto *user : cmpIOp->getUsers()) {
473 if (
auto ifOp = dyn_cast<scf::IfOp>(user))
474 getState<ComponentLoweringState>().setCondReg(ifOp, condReg);
478 lhsIsSeqOp != rhsIsSeqOp &&
479 "unexpected sequential operation on both sides; please open an issue");
483 cast<calyx::RegisterOp>(lhsIsSeqOp ? cmpIOp.getLhs().getDefiningOp()
484 : cmpIOp.getRhs().getDefiningOp());
486 auto groupOp = cast<calyx::GroupOp>(group);
487 getState<ComponentLoweringState>().addBlockScheduleable(cmpIOp->getBlock(),
490 rewriter.setInsertionPointToEnd(groupOp.getBodyBlock());
491 auto loc = cmpIOp.getLoc();
493 (isa<calyx::EqLibOp, calyx::NeqLibOp, calyx::SleLibOp, calyx::SltLibOp,
494 calyx::LeLibOp, calyx::LtLibOp, calyx::GeLibOp, calyx::GtLibOp,
495 calyx::SgeLibOp, calyx::SgtLibOp>(calyxOp.getOperation())) &&
496 "Must be a Calyx comparison library operation.");
497 int64_t outputIndex = 2;
498 rewriter.create<calyx::AssignOp>(loc, condReg.getIn(),
499 calyxOp.getResult(outputIndex));
500 rewriter.create<calyx::AssignOp>(
501 loc, condReg.getWriteEn(),
502 createConstant(loc, rewriter,
503 getState<ComponentLoweringState>().getComponentOp(), 1,
505 rewriter.create<calyx::GroupDoneOp>(loc, condReg.getDone());
507 getState<ComponentLoweringState>().addSeqGuardCmpLibOp(cmpIOp);
458 void setupCmpIOp(PatternRewriter &rewriter, CmpIOp cmpIOp, Operation *group, {
…}
510 template <
typename CmpILibOp>
512 bool isIfOpGuard = std::any_of(op->getUsers().begin(), op->getUsers().end(),
513 [](
auto op) { return isa<scf::IfOp>(op); });
518 return buildLibraryOp<calyx::GroupOp, CmpILibOp>(rewriter, op);
519 return buildLibraryOp<calyx::CombGroupOp, CmpILibOp>(rewriter, op);
524 template <
typename TGroupOp,
typename TCalyxLibOp,
typename TSrcOp>
526 TypeRange srcTypes, TypeRange dstTypes)
const {
527 SmallVector<Type> types;
528 for (Type srcType : srcTypes)
530 for (Type dstType : dstTypes)
534 getState<ComponentLoweringState>().getNewLibraryOpInstance<TCalyxLibOp>(
535 rewriter, op.getLoc(), types);
537 auto directions = calyxOp.portDirections();
538 SmallVector<Value, 4> opInputPorts;
539 SmallVector<Value, 4> opOutputPorts;
540 for (
auto dir : enumerate(directions)) {
542 opInputPorts.push_back(calyxOp.getResult(dir.index()));
544 opOutputPorts.push_back(calyxOp.getResult(dir.index()));
547 opInputPorts.size() == op->getNumOperands() &&
548 opOutputPorts.size() == op->getNumResults() &&
549 "Expected an equal number of in/out ports in the Calyx library op with "
550 "respect to the number of operands/results of the source operation.");
553 auto group = createGroupForOp<TGroupOp>(rewriter, op);
555 bool isSeqCondCheck = isa<calyx::GroupOp>(group);
556 calyx::RegisterOp condReg =
nullptr, resReg =
nullptr;
557 if (isa<CmpIOp>(op) && isSeqCondCheck) {
558 auto cmpIOp = cast<CmpIOp>(op);
559 setupCmpIOp(rewriter, cmpIOp, group, condReg, resReg, calyxOp);
562 rewriter.setInsertionPointToEnd(group.getBodyBlock());
564 for (
auto dstOp : enumerate(opInputPorts)) {
567 : op->getOperand(dstOp.index());
568 rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(), srcOp);
572 for (
auto res : enumerate(opOutputPorts)) {
573 getState<ComponentLoweringState>().registerEvaluatingGroup(res.value(),
575 auto dstOp = isSeqCondCheck ? condReg.getOut() : res.value();
576 op->getResult(res.index()).replaceAllUsesWith(dstOp);
584 template <
typename TGroupOp,
typename TCalyxLibOp,
typename TSrcOp>
586 return buildLibraryOp<TGroupOp, TCalyxLibOp, TSrcOp>(
587 rewriter, op, op.getOperandTypes(), op->getResultTypes());
591 template <
typename TGroupOp>
593 Block *block = op->getBlock();
594 auto groupName = getState<ComponentLoweringState>().getUniqueName(
596 return calyx::createGroup<TGroupOp>(
597 rewriter, getState<ComponentLoweringState>().getComponentOp(),
598 op->getLoc(), groupName);
603 template <
typename TOpType,
typename TSrcOp>
605 TOpType opPipe, Value out)
const {
606 StringRef opName = TSrcOp::getOperationName().split(
".").second;
607 Location loc = op.getLoc();
608 Type width = op.getResult().getType();
609 auto reg = createRegister(
610 op.getLoc(), rewriter,
getComponent(), width.getIntOrFloatBitWidth(),
611 getState<ComponentLoweringState>().getUniqueName(opName));
614 auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
615 OpBuilder builder(group->getRegion(0));
616 getState<ComponentLoweringState>().addBlockScheduleable(op->getBlock(),
619 rewriter.setInsertionPointToEnd(group.getBodyBlock());
620 if constexpr (std::is_same_v<TSrcOp, math::SqrtOp>)
623 rewriter.create<calyx::AssignOp>(loc, opPipe.getLeft(), op.getOperand());
625 rewriter.create<calyx::AssignOp>(loc, opPipe.getLeft(), op.getLhs());
626 rewriter.create<calyx::AssignOp>(loc, opPipe.getRight(), op.getRhs());
629 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), out);
631 rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), opPipe.getDone());
636 rewriter.create<calyx::AssignOp>(
637 loc, opPipe.getGo(), c1,
640 rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());
644 op.getResult().replaceAllUsesWith(reg.getOut());
646 if (isa<calyx::AddFOpIEEE754>(opPipe)) {
647 auto opFOp = cast<calyx::AddFOpIEEE754>(opPipe);
649 if (isa<arith::AddFOp>(op)) {
650 subOp = createConstant(loc, rewriter,
getComponent(), 1,
653 subOp = createConstant(loc, rewriter,
getComponent(), 1,
656 rewriter.create<calyx::AssignOp>(loc, opFOp.getSubOp(), subOp);
657 }
else if (
auto opFOp =
658 dyn_cast<calyx::DivSqrtOpIEEE754>(opPipe.getOperation())) {
659 bool isSqrt = !isa<arith::DivFOp>(op);
661 createConstant(loc, rewriter,
getComponent(), 1, isSqrt);
662 rewriter.create<calyx::AssignOp>(loc, opFOp.getSqrtOp(), sqrtOp);
666 getState<ComponentLoweringState>().registerEvaluatingGroup(out, group);
667 getState<ComponentLoweringState>().registerEvaluatingGroup(opPipe.getLeft(),
669 getState<ComponentLoweringState>().registerEvaluatingGroup(
670 opPipe.getRight(), group);
672 getState<ComponentLoweringState>().setSeqResReg(out.getDefiningOp(), reg);
677 template <
typename TCalyxLibOp,
typename TSrcOp>
679 unsigned inputWidth,
unsigned outputWidth,
680 StringRef signedPort)
const {
681 Location loc = op.getLoc();
682 IntegerType one = rewriter.getI1Type(),
683 inWidth = rewriter.getIntegerType(inputWidth),
684 outWidth = rewriter.getIntegerType(outputWidth);
686 getState<ComponentLoweringState>().getNewLibraryOpInstance<TCalyxLibOp>(
687 rewriter, loc, {one, one, one, inWidth, one, outWidth, one});
689 StringRef opName = op.getOperationName().split(
".").second;
691 auto reg = createRegister(
692 loc, rewriter,
getComponent(), outWidth.getIntOrFloatBitWidth(),
693 getState<ComponentLoweringState>().getUniqueName(opName));
695 auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
696 OpBuilder builder(group->getRegion(0));
697 getState<ComponentLoweringState>().addBlockScheduleable(op->getBlock(),
700 rewriter.setInsertionPointToEnd(group.getBodyBlock());
701 rewriter.create<calyx::AssignOp>(loc, calyxOp.getIn(), op.getIn());
702 if (isa<calyx::FpToIntOpIEEE754>(calyxOp)) {
703 rewriter.create<calyx::AssignOp>(
704 loc, cast<calyx::FpToIntOpIEEE754>(calyxOp).getSignedOut(), c1);
705 }
else if (isa<calyx::IntToFpOpIEEE754>(calyxOp)) {
706 rewriter.create<calyx::AssignOp>(
707 loc, cast<calyx::IntToFpOpIEEE754>(calyxOp).getSignedIn(), c1);
709 op.getResult().replaceAllUsesWith(reg.getOut());
711 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), calyxOp.getOut());
712 rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), c1);
714 rewriter.create<calyx::AssignOp>(
715 loc, calyxOp.getGo(), c1,
717 rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());
725 calyx::GroupInterface group,
727 Operation::operand_range addressValues)
const {
728 IRRewriter::InsertionGuard guard(rewriter);
729 rewriter.setInsertionPointToEnd(group.getBody());
730 auto addrPorts = memoryInterface.
addrPorts();
731 if (addressValues.empty()) {
733 addrPorts.size() == 1 &&
734 "We expected a 1 dimensional memory of size 1 because there were no "
735 "address assignment values");
737 rewriter.create<calyx::AssignOp>(
741 assert(addrPorts.size() == addressValues.size() &&
742 "Mismatch between number of address ports of the provided memory "
743 "and address assignment values");
744 for (
auto address : enumerate(addressValues))
745 rewriter.create<calyx::AssignOp>(loc, addrPorts[address.index()],
751 Value signal,
bool invert,
752 StringRef nameSuffix,
753 calyx::CompareFOpIEEE754 calyxCmpFOp,
754 calyx::GroupOp group)
const {
755 Location loc = calyxCmpFOp.getLoc();
756 IntegerType one = rewriter.getI1Type();
758 OpBuilder builder(group->getRegion(0));
759 auto reg = createRegister(
760 loc, rewriter, component, 1,
761 getState<ComponentLoweringState>().getUniqueName(nameSuffix));
762 rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(),
763 calyxCmpFOp.getDone());
765 auto notLibOp = getState<ComponentLoweringState>()
766 .getNewLibraryOpInstance<calyx::NotLibOp>(
767 rewriter, loc, {one, one});
768 rewriter.create<calyx::AssignOp>(loc, notLibOp.getIn(), signal);
769 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), notLibOp.getOut());
770 getState<ComponentLoweringState>().registerEvaluatingGroup(
771 notLibOp.getOut(), group);
773 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), signal);
779 memref::LoadOp loadOp)
const {
780 Value memref = loadOp.getMemref();
781 auto memoryInterface =
782 getState<ComponentLoweringState>().getMemoryInterface(memref);
783 auto group = createGroupForOp<calyx::GroupOp>(rewriter, loadOp);
785 loadOp.getIndices());
787 rewriter.setInsertionPointToEnd(group.getBodyBlock());
792 createConstant(loadOp.getLoc(), rewriter,
getComponent(), 1, 1);
793 if (memoryInterface.readEnOpt().has_value()) {
796 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), memoryInterface.readEn(),
798 regWriteEn = memoryInterface.done();
805 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(),
806 memoryInterface.done());
816 res = loadOp.getResult();
818 }
else if (memoryInterface.contentEnOpt().has_value()) {
823 rewriter.create<calyx::AssignOp>(loadOp.getLoc(),
824 memoryInterface.contentEn(), oneI1);
825 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), memoryInterface.writeEn(),
827 regWriteEn = memoryInterface.done();
834 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(),
835 memoryInterface.done());
845 res = loadOp.getResult();
857 auto reg = createRegister(
859 loadOp.getMemRefType().getElementTypeBitWidth(),
860 getState<ComponentLoweringState>().getUniqueName(
"load"));
861 rewriter.setInsertionPointToEnd(group.getBodyBlock());
862 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), reg.getIn(),
863 memoryInterface.readData());
864 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), reg.getWriteEn(),
866 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(), reg.getDone());
867 loadOp.getResult().replaceAllUsesWith(reg.getOut());
871 getState<ComponentLoweringState>().registerEvaluatingGroup(res, group);
872 getState<ComponentLoweringState>().addBlockScheduleable(loadOp->getBlock(),
878 memref::StoreOp storeOp)
const {
879 auto memoryInterface = getState<ComponentLoweringState>().getMemoryInterface(
880 storeOp.getMemref());
881 auto group = createGroupForOp<calyx::GroupOp>(rewriter, storeOp);
885 getState<ComponentLoweringState>().addBlockScheduleable(storeOp->getBlock(),
888 storeOp.getIndices());
889 rewriter.setInsertionPointToEnd(group.getBodyBlock());
890 rewriter.create<calyx::AssignOp>(
891 storeOp.getLoc(), memoryInterface.writeData(), storeOp.getValueToStore());
892 rewriter.create<calyx::AssignOp>(
893 storeOp.getLoc(), memoryInterface.writeEn(),
894 createConstant(storeOp.getLoc(), rewriter,
getComponent(), 1, 1));
895 if (memoryInterface.contentEnOpt().has_value()) {
897 rewriter.create<calyx::AssignOp>(
898 storeOp.getLoc(), memoryInterface.contentEn(),
899 createConstant(storeOp.getLoc(), rewriter,
getComponent(), 1, 1));
901 rewriter.create<calyx::GroupDoneOp>(storeOp.getLoc(), memoryInterface.done());
908 Location loc = mul.getLoc();
909 Type width = mul.getResult().getType(), one = rewriter.getI1Type();
911 getState<ComponentLoweringState>()
912 .getNewLibraryOpInstance<calyx::MultPipeLibOp>(
913 rewriter, loc, {one, one, one, width, width, width, one});
914 return buildLibraryBinaryPipeOp<calyx::MultPipeLibOp>(
915 rewriter, mul, mulPipe,
921 Location loc = div.getLoc();
922 Type width = div.getResult().getType(), one = rewriter.getI1Type();
924 getState<ComponentLoweringState>()
925 .getNewLibraryOpInstance<calyx::DivUPipeLibOp>(
926 rewriter, loc, {one, one, one, width, width, width, one});
927 return buildLibraryBinaryPipeOp<calyx::DivUPipeLibOp>(
928 rewriter, div, divPipe,
934 Location loc = div.getLoc();
935 Type width = div.getResult().getType(), one = rewriter.getI1Type();
937 getState<ComponentLoweringState>()
938 .getNewLibraryOpInstance<calyx::DivSPipeLibOp>(
939 rewriter, loc, {one, one, one, width, width, width, one});
940 return buildLibraryBinaryPipeOp<calyx::DivSPipeLibOp>(
941 rewriter, div, divPipe,
947 Location loc = rem.getLoc();
948 Type width = rem.getResult().getType(), one = rewriter.getI1Type();
950 getState<ComponentLoweringState>()
951 .getNewLibraryOpInstance<calyx::RemUPipeLibOp>(
952 rewriter, loc, {one, one, one, width, width, width, one});
953 return buildLibraryBinaryPipeOp<calyx::RemUPipeLibOp>(
954 rewriter, rem, remPipe,
960 Location loc = rem.getLoc();
961 Type width = rem.getResult().getType(), one = rewriter.getI1Type();
963 getState<ComponentLoweringState>()
964 .getNewLibraryOpInstance<calyx::RemSPipeLibOp>(
965 rewriter, loc, {one, one, one, width, width, width, one});
966 return buildLibraryBinaryPipeOp<calyx::RemSPipeLibOp>(
967 rewriter, rem, remPipe,
973 Location loc = addf.getLoc();
974 IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
975 five = rewriter.getIntegerType(5),
976 width = rewriter.getIntegerType(
977 addf.getType().getIntOrFloatBitWidth());
979 getState<ComponentLoweringState>()
980 .getNewLibraryOpInstance<calyx::AddFOpIEEE754>(
982 {one, one, one, one, one, width, width, three, width, five, one});
983 return buildLibraryBinaryPipeOp<calyx::AddFOpIEEE754>(rewriter, addf, addFOp,
989 Location loc = subf.getLoc();
990 IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
991 five = rewriter.getIntegerType(5),
992 width = rewriter.getIntegerType(
993 subf.getType().getIntOrFloatBitWidth());
995 getState<ComponentLoweringState>()
996 .getNewLibraryOpInstance<calyx::AddFOpIEEE754>(
998 {one, one, one, one, one, width, width, three, width, five, one});
999 return buildLibraryBinaryPipeOp<calyx::AddFOpIEEE754>(rewriter, subf, subFOp,
1004 MulFOp mulf)
const {
1005 Location loc = mulf.getLoc();
1006 IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
1007 five = rewriter.getIntegerType(5),
1008 width = rewriter.getIntegerType(
1009 mulf.getType().getIntOrFloatBitWidth());
1011 getState<ComponentLoweringState>()
1012 .getNewLibraryOpInstance<calyx::MulFOpIEEE754>(
1014 {one, one, one, one, width, width, three, width, five, one});
1015 return buildLibraryBinaryPipeOp<calyx::MulFOpIEEE754>(rewriter, mulf, mulFOp,
1020 CmpFOp cmpf)
const {
1021 Location loc = cmpf.getLoc();
1022 IntegerType one = rewriter.getI1Type(), five = rewriter.getIntegerType(5),
1023 width = rewriter.getIntegerType(
1024 cmpf.getLhs().getType().getIntOrFloatBitWidth());
1025 auto calyxCmpFOp = getState<ComponentLoweringState>()
1026 .getNewLibraryOpInstance<calyx::CompareFOpIEEE754>(
1028 {one, one, one, width, width, one, one, one, one,
1035 using CombLogic = PredicateInfo::CombLogic;
1036 using Port = PredicateInfo::InputPorts::Port;
1038 if (info.logic == CombLogic::None) {
1039 if (cmpf.getPredicate() == CmpFPredicate::AlwaysTrue) {
1040 cmpf.getResult().replaceAllUsesWith(c1);
1044 if (cmpf.getPredicate() == CmpFPredicate::AlwaysFalse) {
1045 cmpf.getResult().replaceAllUsesWith(c0);
1051 StringRef opName = cmpf.getOperationName().split(
".").second;
1054 getState<ComponentLoweringState>().getUniqueName(opName));
1057 auto group = createGroupForOp<calyx::GroupOp>(rewriter, cmpf);
1058 OpBuilder builder(group->getRegion(0));
1059 getState<ComponentLoweringState>().addBlockScheduleable(cmpf->getBlock(),
1062 rewriter.setInsertionPointToEnd(group.getBodyBlock());
1063 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getLeft(), cmpf.getLhs());
1064 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getRight(), cmpf.getRhs());
1066 bool signalingFlag =
false;
1067 switch (cmpf.getPredicate()) {
1068 case CmpFPredicate::UGT:
1069 case CmpFPredicate::UGE:
1070 case CmpFPredicate::ULT:
1071 case CmpFPredicate::ULE:
1072 case CmpFPredicate::OGT:
1073 case CmpFPredicate::OGE:
1074 case CmpFPredicate::OLT:
1075 case CmpFPredicate::OLE:
1076 signalingFlag =
true;
1078 case CmpFPredicate::UEQ:
1079 case CmpFPredicate::UNE:
1080 case CmpFPredicate::OEQ:
1081 case CmpFPredicate::ONE:
1082 case CmpFPredicate::UNO:
1083 case CmpFPredicate::ORD:
1084 case CmpFPredicate::AlwaysTrue:
1085 case CmpFPredicate::AlwaysFalse:
1086 signalingFlag =
false;
1092 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getSignaling(),
1093 signalingFlag ? c1 : c0);
1096 SmallVector<calyx::RegisterOp> inputRegs;
1097 for (
const auto &input : info.inputPorts) {
1099 switch (input.port) {
1101 signal = calyxCmpFOp.getEq();
1105 signal = calyxCmpFOp.getGt();
1109 signal = calyxCmpFOp.getLt();
1112 case Port::Unordered: {
1113 signal = calyxCmpFOp.getUnordered();
1117 std::string nameSuffix =
1118 (input.port == PredicateInfo::InputPorts::Port::Unordered)
1122 nameSuffix, calyxCmpFOp, group);
1123 inputRegs.push_back(signalReg);
1127 Value outputValue, doneValue;
1128 switch (info.logic) {
1129 case CombLogic::None: {
1131 outputValue = inputRegs[0].getOut();
1132 doneValue = inputRegs[0].getDone();
1135 case CombLogic::And: {
1136 auto outputLibOp = getState<ComponentLoweringState>()
1137 .getNewLibraryOpInstance<calyx::AndLibOp>(
1138 rewriter, loc, {one, one, one});
1139 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
1140 inputRegs[0].getOut());
1141 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
1142 inputRegs[1].getOut());
1144 outputValue = outputLibOp.getOut();
1147 case CombLogic::Or: {
1148 auto outputLibOp = getState<ComponentLoweringState>()
1149 .getNewLibraryOpInstance<calyx::OrLibOp>(
1150 rewriter, loc, {one, one, one});
1151 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
1152 inputRegs[0].getOut());
1153 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
1154 inputRegs[1].getOut());
1156 outputValue = outputLibOp.getOut();
1161 if (info.logic != CombLogic::None) {
1162 auto doneLibOp = getState<ComponentLoweringState>()
1163 .getNewLibraryOpInstance<calyx::AndLibOp>(
1164 rewriter, loc, {one, one, one});
1165 rewriter.create<calyx::AssignOp>(loc, doneLibOp.getLeft(),
1166 inputRegs[0].getDone());
1167 rewriter.create<calyx::AssignOp>(loc, doneLibOp.getRight(),
1168 inputRegs[1].getDone());
1169 doneValue = doneLibOp.getOut();
1173 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), outputValue);
1174 rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), doneValue);
1177 rewriter.create<calyx::AssignOp>(
1178 loc, calyxCmpFOp.getGo(), c1,
1180 rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());
1182 cmpf.getResult().replaceAllUsesWith(reg.getOut());
1185 getState<ComponentLoweringState>().registerEvaluatingGroup(outputValue,
1187 getState<ComponentLoweringState>().registerEvaluatingGroup(doneValue, group);
1188 getState<ComponentLoweringState>().registerEvaluatingGroup(
1189 calyxCmpFOp.getLeft(), group);
1190 getState<ComponentLoweringState>().registerEvaluatingGroup(
1191 calyxCmpFOp.getRight(), group);
1197 FPToSIOp fptosi)
const {
1198 return buildFpIntTypeCastOp<calyx::FpToIntOpIEEE754>(
1199 rewriter, fptosi, fptosi.getIn().getType().getIntOrFloatBitWidth(),
1200 fptosi.getOut().getType().getIntOrFloatBitWidth(),
"signedOut");
1204 SIToFPOp sitofp)
const {
1205 return buildFpIntTypeCastOp<calyx::IntToFpOpIEEE754>(
1206 rewriter, sitofp, sitofp.getIn().getType().getIntOrFloatBitWidth(),
1207 sitofp.getOut().getType().getIntOrFloatBitWidth(),
"signedIn");
1211 DivFOp divf)
const {
1212 Location loc = divf.getLoc();
1213 IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
1214 five = rewriter.getIntegerType(5),
1215 width = rewriter.getIntegerType(
1216 divf.getType().getIntOrFloatBitWidth());
1217 auto divFOp = getState<ComponentLoweringState>()
1218 .getNewLibraryOpInstance<calyx::DivSqrtOpIEEE754>(
1222 width, three, width,
1224 return buildLibraryBinaryPipeOp<calyx::DivSqrtOpIEEE754>(
1225 rewriter, divf, divFOp, divFOp.getOut());
1229 math::SqrtOp sqrt)
const {
1230 Location loc = sqrt.getLoc();
1231 IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
1232 five = rewriter.getIntegerType(5),
1233 width = rewriter.getIntegerType(
1234 sqrt.getType().getIntOrFloatBitWidth());
1235 auto sqrtOp = getState<ComponentLoweringState>()
1236 .getNewLibraryOpInstance<calyx::DivSqrtOpIEEE754>(
1240 width, three, width,
1242 return buildLibraryBinaryPipeOp<calyx::DivSqrtOpIEEE754>(
1243 rewriter, sqrt, sqrtOp, sqrtOp.getOut());
1247 math::AbsFOp absFOp)
const {
1248 Location loc = absFOp.getLoc();
1249 auto input = absFOp.getOperand();
1251 unsigned bitwidth = input.getType().getIntOrFloatBitWidth();
1252 Type intTy = rewriter.getIntegerType(bitwidth);
1254 uint64_t signBit = 1ULL << (bitwidth - 1);
1255 uint64_t absMask = ~signBit & ((1ULL << bitwidth) - 1);
1257 Value maskOp = rewriter.create<arith::ConstantIntOp>(loc, absMask, intTy);
1259 auto combGroup = createGroupForOp<calyx::CombGroupOp>(rewriter, absFOp);
1260 rewriter.setInsertionPointToStart(combGroup.getBodyBlock());
1262 auto andLibOp = getState<ComponentLoweringState>()
1263 .getNewLibraryOpInstance<calyx::AndLibOp>(
1264 rewriter, loc, {intTy, intTy, intTy});
1265 rewriter.create<calyx::AssignOp>(loc, andLibOp.getLeft(), maskOp);
1266 rewriter.create<calyx::AssignOp>(loc, andLibOp.getRight(), input);
1268 getState<ComponentLoweringState>().registerEvaluatingGroup(andLibOp.getOut(),
1270 rewriter.replaceAllUsesWith(absFOp, andLibOp.getOut());
1275template <
typename TAllocOp>
1277 PatternRewriter &rewriter, TAllocOp allocOp) {
1278 rewriter.setInsertionPointToStart(
1280 MemRefType memtype = allocOp.getType();
1281 SmallVector<int64_t> addrSizes;
1282 SmallVector<int64_t> sizes;
1283 for (int64_t dim : memtype.getShape()) {
1284 sizes.push_back(dim);
1289 if (sizes.empty() && addrSizes.empty()) {
1291 addrSizes.push_back(1);
1293 auto memoryOp = rewriter.create<calyx::SeqMemoryOp>(
1295 memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
1299 memoryOp->setAttr(
"external",
1300 IntegerAttr::get(rewriter.getI1Type(), llvm::APInt(1, 1)));
1304 unsigned elmTyBitWidth = memtype.getElementTypeBitWidth();
1305 assert(elmTyBitWidth <= 64 &&
"element bitwidth should not exceed 64");
1306 bool isFloat = !memtype.getElementType().isInteger();
1308 auto shape = allocOp.getType().getShape();
1310 std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
1317 if (!(shape.size() <= 1 || totalSize <= 1)) {
1318 allocOp.emitError(
"input memory dimension must be empty or one.");
1322 std::vector<uint64_t> flattenedVals(totalSize, 0);
1323 if (isa<memref::GetGlobalOp>(allocOp)) {
1324 auto getGlobalOp = cast<memref::GetGlobalOp>(allocOp);
1325 auto *symbolTableOp =
1326 getGlobalOp->template getParentWithTrait<mlir::OpTrait::SymbolTable>();
1327 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
1328 SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
1330 auto cstAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(
1331 globalOp.getConstantInitValue());
1333 for (
auto attr : cstAttr.template getValues<Attribute>()) {
1334 assert((isa<mlir::FloatAttr, mlir::IntegerAttr>(attr)) &&
1335 "memory attributes must be float or int");
1336 if (
auto fltAttr = dyn_cast<mlir::FloatAttr>(attr)) {
1337 flattenedVals[sizeCount++] =
1338 bit_cast<uint64_t>(fltAttr.getValueAsDouble());
1340 auto intAttr = dyn_cast<mlir::IntegerAttr>(attr);
1341 APInt value = intAttr.getValue();
1342 flattenedVals[sizeCount++] = *value.getRawData();
1346 rewriter.eraseOp(globalOp);
1349 llvm::json::Array result;
1350 result.reserve(std::max(
static_cast<int>(shape.size()), 1));
1352 Type elemType = memtype.getElementType();
1354 !elemType.isSignlessInteger() && !elemType.isUnsignedInteger();
1355 for (uint64_t bitValue : flattenedVals) {
1356 llvm::json::Value value = 0;
1360 value = bit_cast<double>(bitValue);
1362 APInt apInt(elmTyBitWidth, bitValue, isSigned,
1367 value =
static_cast<int64_t
>(apInt.getSExtValue());
1369 value = apInt.getZExtValue();
1371 result.push_back(std::move(value));
1374 componentState.
setDataField(memoryOp.getName(), result);
1375 std::string numType =
1376 memtype.getElementType().isInteger() ?
"bitnum" :
"ieee754_float";
1377 componentState.
setFormat(memoryOp.getName(), numType, isSigned,
1384 memref::AllocOp allocOp)
const {
1385 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
1389 memref::AllocaOp allocOp)
const {
1390 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
1394 memref::GetGlobalOp getGlobalOp)
const {
1395 return buildAllocOp(getState<ComponentLoweringState>(), rewriter,
1400 scf::YieldOp yieldOp)
const {
1401 if (yieldOp.getOperands().empty()) {
1402 if (
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
1406 auto inductionReg = getState<ComponentLoweringState>().getForLoopIterReg(
1409 Type regWidth = inductionReg.getOut().getType();
1411 SmallVector<Type> types(3, regWidth);
1412 auto addOp = getState<ComponentLoweringState>()
1413 .getNewLibraryOpInstance<calyx::AddLibOp>(
1414 rewriter, forOp.getLoc(), types);
1416 auto directions = addOp.portDirections();
1418 SmallVector<Value, 2> opInputPorts;
1420 for (
auto dir : enumerate(directions)) {
1421 switch (dir.value()) {
1423 opInputPorts.push_back(addOp.getResult(dir.index()));
1427 opOutputPort = addOp.getResult(dir.index());
1435 getState<ComponentLoweringState>().getComponentOp();
1436 SmallVector<StringRef, 4> groupIdentifier = {
1437 "incr", getState<ComponentLoweringState>().getUniqueName(forOp),
1438 "induction",
"var"};
1439 auto groupOp = calyx::createGroup<calyx::GroupOp>(
1441 llvm::join(groupIdentifier,
"_"));
1442 rewriter.setInsertionPointToEnd(groupOp.getBodyBlock());
1445 Value leftOp = opInputPorts.front();
1446 rewriter.create<calyx::AssignOp>(forOp.getLoc(), leftOp,
1447 inductionReg.getOut());
1449 Value rightOp = opInputPorts.back();
1450 rewriter.create<calyx::AssignOp>(
1451 forOp.getLoc(), rightOp,
1452 createConstant(forOp->getLoc(), rewriter,
componentOp,
1453 regWidth.getIntOrFloatBitWidth(),
1454 forOp.getConstantStep().value().getSExtValue()));
1456 buildAssignmentsForRegisterWrite(rewriter, groupOp,
componentOp,
1457 inductionReg, opOutputPort);
1459 getState<ComponentLoweringState>().setForLoopLatchGroup(forOpInterface,
1461 getState<ComponentLoweringState>().registerEvaluatingGroup(opOutputPort,
1465 if (
auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp()))
1468 if (
auto executeRegionOp =
1469 dyn_cast<scf::ExecuteRegionOp>(yieldOp->getParentOp()))
1472 return yieldOp.getOperation()->emitError()
1473 <<
"Unsupported empty yieldOp outside ForOp or IfOp.";
1476 if (dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
1477 return yieldOp.getOperation()->emitError()
1478 <<
"Currently do not support non-empty yield operations inside for "
1479 "loops. Run --scf-for-to-while before running --scf-to-calyx.";
1482 if (
auto whileOp = dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
1486 getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
1487 rewriter, whileOpInterface,
1488 getState<ComponentLoweringState>().getComponentOp(),
1489 getState<ComponentLoweringState>().getUniqueName(whileOp) +
1491 yieldOp->getOpOperands());
1492 getState<ComponentLoweringState>().setWhileLoopLatchGroup(whileOpInterface,
1497 if (
auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
1498 auto resultRegs = getState<ComponentLoweringState>().getResultRegs(ifOp);
1500 if (yieldOp->getParentRegion() == &ifOp.getThenRegion()) {
1501 auto thenGroup = getState<ComponentLoweringState>().getThenGroup(ifOp);
1502 for (
auto op : enumerate(yieldOp.getOperands())) {
1504 getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
1505 buildAssignmentsForRegisterWrite(
1506 rewriter, thenGroup,
1507 getState<ComponentLoweringState>().getComponentOp(), resultReg,
1509 getState<ComponentLoweringState>().registerEvaluatingGroup(
1510 ifOp.getResult(op.index()), thenGroup);
1514 if (!ifOp.getElseRegion().empty() &&
1515 (yieldOp->getParentRegion() == &ifOp.getElseRegion())) {
1516 auto elseGroup = getState<ComponentLoweringState>().getElseGroup(ifOp);
1517 for (
auto op : enumerate(yieldOp.getOperands())) {
1519 getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
1520 buildAssignmentsForRegisterWrite(
1521 rewriter, elseGroup,
1522 getState<ComponentLoweringState>().getComponentOp(), resultReg,
1524 getState<ComponentLoweringState>().registerEvaluatingGroup(
1525 ifOp.getResult(op.index()), elseGroup);
1533 BranchOpInterface brOp)
const {
1538 Block *srcBlock = brOp->getBlock();
1539 for (
auto succBlock : enumerate(brOp->getSuccessors())) {
1540 auto succOperands = brOp.getSuccessorOperands(succBlock.index());
1541 if (succOperands.empty())
1546 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter,
getComponent(),
1547 brOp.getLoc(), groupName);
1549 auto dstBlockArgRegs =
1550 getState<ComponentLoweringState>().getBlockArgRegs(succBlock.value());
1552 for (
auto arg : enumerate(succOperands.getForwardedOperands())) {
1553 auto reg = dstBlockArgRegs[arg.index()];
1556 getState<ComponentLoweringState>().getComponentOp(), reg,
1561 getState<ComponentLoweringState>().addBlockArgGroup(
1562 srcBlock, succBlock.value(), groupOp);
1570 ReturnOp retOp)
const {
1571 if (retOp.getNumOperands() == 0)
1574 std::string groupName =
1575 getState<ComponentLoweringState>().getUniqueName(
"ret_assign");
1576 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter,
getComponent(),
1577 retOp.getLoc(), groupName);
1578 for (
auto op : enumerate(retOp.getOperands())) {
1579 auto reg = getState<ComponentLoweringState>().getReturnReg(op.index());
1581 rewriter, groupOp, getState<ComponentLoweringState>().getComponentOp(),
1585 getState<ComponentLoweringState>().addBlockScheduleable(retOp->getBlock(),
1591 arith::ConstantOp constOp)
const {
1592 if (isa<IntegerType>(constOp.getType())) {
1601 std::string name = getState<ComponentLoweringState>().getUniqueName(
"cst");
1602 auto floatAttr = cast<FloatAttr>(constOp.getValueAttr());
1604 rewriter.getIntegerType(floatAttr.getType().getIntOrFloatBitWidth());
1605 auto calyxConstOp = rewriter.create<calyx::ConstantOp>(
1606 constOp.getLoc(), name, floatAttr, intType);
1609 rewriter.replaceAllUsesWith(constOp, calyxConstOp.getOut());
1617 return buildLibraryOp<calyx::CombGroupOp, calyx::AddLibOp>(rewriter, op);
1621 return buildLibraryOp<calyx::CombGroupOp, calyx::SubLibOp>(rewriter, op);
1625 return buildLibraryOp<calyx::CombGroupOp, calyx::RshLibOp>(rewriter, op);
1629 return buildLibraryOp<calyx::CombGroupOp, calyx::SrshLibOp>(rewriter, op);
1633 return buildLibraryOp<calyx::CombGroupOp, calyx::LshLibOp>(rewriter, op);
1637 return buildLibraryOp<calyx::CombGroupOp, calyx::AndLibOp>(rewriter, op);
1641 return buildLibraryOp<calyx::CombGroupOp, calyx::OrLibOp>(rewriter, op);
1645 return buildLibraryOp<calyx::CombGroupOp, calyx::XorLibOp>(rewriter, op);
1648 SelectOp op)
const {
1649 return buildLibraryOp<calyx::CombGroupOp, calyx::MuxLibOp>(rewriter, op);
1654 switch (op.getPredicate()) {
1655 case CmpIPredicate::eq:
1656 return buildCmpIOpHelper<calyx::EqLibOp>(rewriter, op);
1657 case CmpIPredicate::ne:
1658 return buildCmpIOpHelper<calyx::NeqLibOp>(rewriter, op);
1659 case CmpIPredicate::uge:
1660 return buildCmpIOpHelper<calyx::GeLibOp>(rewriter, op);
1661 case CmpIPredicate::ult:
1662 return buildCmpIOpHelper<calyx::LtLibOp>(rewriter, op);
1663 case CmpIPredicate::ugt:
1664 return buildCmpIOpHelper<calyx::GtLibOp>(rewriter, op);
1665 case CmpIPredicate::ule:
1666 return buildCmpIOpHelper<calyx::LeLibOp>(rewriter, op);
1667 case CmpIPredicate::sge:
1668 return buildCmpIOpHelper<calyx::SgeLibOp>(rewriter, op);
1669 case CmpIPredicate::slt:
1670 return buildCmpIOpHelper<calyx::SltLibOp>(rewriter, op);
1671 case CmpIPredicate::sgt:
1672 return buildCmpIOpHelper<calyx::SgtLibOp>(rewriter, op);
1673 case CmpIPredicate::sle:
1674 return buildCmpIOpHelper<calyx::SleLibOp>(rewriter, op);
1676 llvm_unreachable(
"unsupported comparison predicate");
1680 TruncIOp op)
const {
1681 return buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
1682 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1686 return buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
1687 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1692 return buildLibraryOp<calyx::CombGroupOp, calyx::ExtSILibOp>(
1693 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1697 IndexCastOp op)
const {
1700 unsigned targetBits = targetType.getIntOrFloatBitWidth();
1701 unsigned sourceBits = sourceType.getIntOrFloatBitWidth();
1702 LogicalResult res = success();
1704 if (targetBits == sourceBits) {
1707 op.getResult().replaceAllUsesWith(op.getOperand());
1710 if (sourceBits > targetBits)
1711 res = buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
1712 rewriter, op, {sourceType}, {targetType});
1714 res = buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
1715 rewriter, op, {sourceType}, {targetType});
1717 rewriter.eraseOp(op);
1724 BitcastOp op)
const {
1725 rewriter.replaceAllUsesWith(op.getOut(), op.getIn());
1730 scf::WhileOp whileOp)
const {
1734 getState<ComponentLoweringState>().addBlockScheduleable(
1740 scf::ForOp forOp)
const {
1746 std::optional<uint64_t> bound = scfForOp.
getBound();
1747 if (!bound.has_value()) {
1749 <<
"Loop bound not statically known. Should "
1750 "transform into while loop using `--scf-for-to-while` before "
1751 "running --lower-scf-to-calyx.";
1753 getState<ComponentLoweringState>().addBlockScheduleable(
1762 scf::IfOp ifOp)
const {
1763 getState<ComponentLoweringState>().addBlockScheduleable(
1769 scf::ReduceOp reduceOp)
const {
1777 scf::ParallelOp parOp)
const {
1780 "AffineParallelUnroll must be run in order to lower scf.parallel");
1783 getState<ComponentLoweringState>().addBlockScheduleable(
1790 scf::ExecuteRegionOp executeRegionOp)
const {
1798 CallOp callOp)
const {
1800 calyx::InstanceOp instanceOp =
1801 getState<ComponentLoweringState>().getInstance(instanceName);
1802 SmallVector<Value, 4> outputPorts;
1803 auto portInfos = instanceOp.getReferencedComponent().getPortInfo();
1804 for (
auto [idx, portInfo] : enumerate(portInfos)) {
1806 outputPorts.push_back(instanceOp.getResult(idx));
1810 for (
auto [idx, result] : llvm::enumerate(callOp.getResults()))
1811 rewriter.replaceAllUsesWith(result, outputPorts[idx]);
1815 getState<ComponentLoweringState>().addBlockScheduleable(
1829 using OpRewritePattern::OpRewritePattern;
1832 PatternRewriter &rewriter)
const override {
1833 if (
auto parOp = dyn_cast_or_null<scf::ParallelOp>(execOp->getParentOp())) {
1834 if (
auto boolAttr = dyn_cast_or_null<mlir::BoolAttr>(
1842 TypeRange yieldTypes = execOp.getResultTypes();
1846 rewriter.setInsertionPointAfter(execOp);
1847 auto *sinkBlock = rewriter.splitBlock(
1849 execOp.getOperation()->getIterator()->getNextNode()->getIterator());
1850 sinkBlock->addArguments(
1852 SmallVector<Location, 4>(yieldTypes.size(), rewriter.getUnknownLoc()));
1853 for (
auto res : enumerate(execOp.getResults()))
1854 res.value().replaceAllUsesWith(sinkBlock->getArgument(res.index()));
1858 make_early_inc_range(execOp.getRegion().getOps<scf::YieldOp>())) {
1859 rewriter.setInsertionPointAfter(yieldOp);
1860 rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, sinkBlock,
1861 yieldOp.getOperands());
1865 auto *preBlock = execOp->getBlock();
1866 auto *execOpEntryBlock = &execOp.getRegion().front();
1867 auto *postBlock = execOp->getBlock()->splitBlock(execOp);
1868 rewriter.inlineRegionBefore(execOp.getRegion(), postBlock);
1869 rewriter.mergeBlocks(postBlock, preBlock);
1870 rewriter.eraseOp(execOp);
1873 rewriter.mergeBlocks(execOpEntryBlock, preBlock);
1881 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1885 PatternRewriter &rewriter)
const override {
1888 DenseMap<Value, unsigned> funcOpArgRewrites;
1892 DenseMap<unsigned, unsigned> funcOpResultMapping;
1900 DenseMap<Value, std::pair<unsigned, unsigned>> extMemoryCompPortIndices;
1904 SmallVector<calyx::PortInfo> inPorts, outPorts;
1905 FunctionType funcType = funcOp.getFunctionType();
1906 for (
auto arg : enumerate(funcOp.getArguments())) {
1907 if (!isa<MemRefType>(arg.value().getType())) {
1910 if (
auto portNameAttr = funcOp.getArgAttrOfType<StringAttr>(
1912 inName = portNameAttr.str();
1914 inName =
"in" + std::to_string(arg.index());
1915 funcOpArgRewrites[arg.value()] = inPorts.size();
1917 rewriter.getStringAttr(inName),
1920 DictionaryAttr::get(rewriter.getContext(), {})});
1923 for (
auto res : enumerate(funcType.getResults())) {
1924 std::string resName;
1925 if (
auto portNameAttr = funcOp.getResultAttrOfType<StringAttr>(
1927 resName = portNameAttr.str();
1929 resName =
"out" + std::to_string(res.index());
1930 funcOpResultMapping[res.index()] = outPorts.size();
1933 rewriter.getStringAttr(resName),
1935 DictionaryAttr::get(rewriter.getContext(), {})});
1940 auto ports = inPorts;
1941 llvm::append_range(ports, outPorts);
1945 auto compOp = rewriter.create<calyx::ComponentOp>(
1946 funcOp.getLoc(), rewriter.getStringAttr(funcOp.getSymName()), ports);
1948 std::string funcName =
"func_" + funcOp.getSymName().str();
1949 rewriter.modifyOpInPlace(funcOp, [&]() { funcOp.setSymName(funcName); });
1954 compOp->setAttr(
"toplevel", rewriter.getUnitAttr());
1961 unsigned extMemCounter = 0;
1962 for (
auto arg : enumerate(funcOp.getArguments())) {
1963 if (isa<MemRefType>(arg.value().getType())) {
1964 std::string memName =
1965 llvm::join_items(
"_",
"arg_mem", std::to_string(extMemCounter++));
1967 rewriter.setInsertionPointToStart(compOp.getBodyBlock());
1968 MemRefType memtype = cast<MemRefType>(arg.value().getType());
1969 SmallVector<int64_t> addrSizes;
1970 SmallVector<int64_t> sizes;
1971 for (int64_t dim : memtype.getShape()) {
1972 sizes.push_back(dim);
1975 if (sizes.empty() && addrSizes.empty()) {
1977 addrSizes.push_back(1);
1979 auto memOp = rewriter.create<calyx::SeqMemoryOp>(
1980 funcOp.getLoc(), memName,
1981 memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
1984 compState->registerMemoryInterface(arg.value(),
1990 for (
auto &mapping : funcOpArgRewrites)
1991 mapping.getFirst().replaceAllUsesWith(
1992 compOp.getArgument(mapping.getSecond()));
2003 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2007 PatternRewriter &rewriter)
const override {
2008 LogicalResult res = success();
2009 funcOp.walk([&](Operation *op) {
2011 if (!isa<scf::WhileOp>(op))
2012 return WalkResult::advance();
2014 auto scfWhileOp = cast<scf::WhileOp>(op);
2017 getState<ComponentLoweringState>().setUniqueName(whileOp.
getOperation(),
2027 enumerate(scfWhileOp.getBefore().front().getArguments())) {
2028 auto condOp = scfWhileOp.getConditionOp().getArgs()[barg.index()];
2029 if (barg.value() != condOp) {
2033 <<
"do-while loops not supported; expected iter-args to "
2034 "remain untransformed in the 'before' region of the "
2036 return WalkResult::interrupt();
2045 for (
auto arg : enumerate(whileOp.
getBodyArgs())) {
2046 std::string name = getState<ComponentLoweringState>()
2049 "_arg" + std::to_string(arg.index());
2051 createRegister(arg.value().getLoc(), rewriter,
getComponent(),
2052 arg.value().getType().getIntOrFloatBitWidth(), name);
2053 getState<ComponentLoweringState>().addWhileLoopIterReg(whileOp, reg,
2055 arg.value().replaceAllUsesWith(reg.getOut());
2059 ->getArgument(arg.index())
2060 .replaceAllUsesWith(reg.getOut());
2064 SmallVector<calyx::GroupOp> initGroups;
2065 auto numOperands = whileOp.
getOperation()->getNumOperands();
2066 for (
size_t i = 0; i < numOperands; ++i) {
2068 getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
2070 getState<ComponentLoweringState>().getComponentOp(),
2071 getState<ComponentLoweringState>().getUniqueName(
2073 "_init_" + std::to_string(i),
2075 initGroups.push_back(initGroupOp);
2078 getState<ComponentLoweringState>().setWhileLoopInitGroups(whileOp,
2081 return WalkResult::advance();
2091 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2095 PatternRewriter &rewriter)
const override {
2096 LogicalResult res = success();
2097 funcOp.walk([&](Operation *op) {
2099 if (!isa<scf::ForOp>(op))
2100 return WalkResult::advance();
2102 auto scfForOp = cast<scf::ForOp>(op);
2105 getState<ComponentLoweringState>().setUniqueName(forOp.
getOperation(),
2110 auto inductionVar = forOp.
getOperation().getInductionVar();
2111 SmallVector<std::string, 3> inductionVarIdentifiers = {
2112 getState<ComponentLoweringState>()
2115 "induction",
"var"};
2116 std::string name = llvm::join(inductionVarIdentifiers,
"_");
2118 createRegister(inductionVar.getLoc(), rewriter,
getComponent(),
2119 inductionVar.getType().getIntOrFloatBitWidth(), name);
2120 getState<ComponentLoweringState>().addForLoopIterReg(forOp, reg, 0);
2121 inductionVar.replaceAllUsesWith(reg.getOut());
2125 getState<ComponentLoweringState>().getComponentOp();
2126 SmallVector<calyx::GroupOp> initGroups;
2127 SmallVector<std::string, 4> groupIdentifiers = {
2129 getState<ComponentLoweringState>()
2132 "induction",
"var"};
2133 std::string groupName = llvm::join(groupIdentifiers,
"_");
2134 auto groupOp = calyx::createGroup<calyx::GroupOp>(
2136 buildAssignmentsForRegisterWrite(rewriter, groupOp,
componentOp, reg,
2138 initGroups.push_back(groupOp);
2139 getState<ComponentLoweringState>().setForLoopInitGroups(forOp,
2142 return WalkResult::advance();
2149 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2153 PatternRewriter &rewriter)
const override {
2154 LogicalResult res = success();
2155 funcOp.walk([&](Operation *op) {
2156 if (!isa<scf::IfOp>(op))
2157 return WalkResult::advance();
2159 auto scfIfOp = cast<scf::IfOp>(op);
2164 if (scfIfOp.getResults().empty())
2165 return WalkResult::advance();
2168 getState<ComponentLoweringState>().getComponentOp();
2170 std::string thenGroupName =
2171 getState<ComponentLoweringState>().getUniqueName(
"then_br");
2172 auto thenGroupOp = calyx::createGroup<calyx::GroupOp>(
2173 rewriter,
componentOp, scfIfOp.getLoc(), thenGroupName);
2174 getState<ComponentLoweringState>().setThenGroup(scfIfOp, thenGroupOp);
2176 if (!scfIfOp.getElseRegion().empty()) {
2177 std::string elseGroupName =
2178 getState<ComponentLoweringState>().getUniqueName(
"else_br");
2179 auto elseGroupOp = calyx::createGroup<calyx::GroupOp>(
2180 rewriter,
componentOp, scfIfOp.getLoc(), elseGroupName);
2181 getState<ComponentLoweringState>().setElseGroup(scfIfOp, elseGroupOp);
2184 for (
auto ifOpRes : scfIfOp.getResults()) {
2185 auto reg = createRegister(
2187 ifOpRes.getType().getIntOrFloatBitWidth(),
2188 getState<ComponentLoweringState>().getUniqueName(
"if_res"));
2189 getState<ComponentLoweringState>().setResultRegs(
2190 scfIfOp, reg, ifOpRes.getResultNumber());
2193 return WalkResult::advance();
2205 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2209 PatternRewriter &rewriter)
const override {
2210 auto *entryBlock = &funcOp.getBlocks().front();
2211 rewriter.setInsertionPointToStart(
2213 auto topLevelSeqOp = rewriter.create<calyx::SeqOp>(funcOp.getLoc());
2214 DenseSet<Block *> path;
2216 nullptr, entryBlock);
2223 const DenseSet<Block *> &path,
2224 mlir::Block *parentCtrlBlock,
2225 mlir::Block *block)
const {
2226 auto compBlockScheduleables =
2227 getState<ComponentLoweringState>().getBlockScheduleables(block);
2228 auto loc = block->front().getLoc();
2230 if (compBlockScheduleables.size() > 1 &&
2231 !isa<scf::ParallelOp>(block->getParentOp())) {
2232 auto seqOp = rewriter.create<calyx::SeqOp>(loc);
2233 parentCtrlBlock = seqOp.getBodyBlock();
2236 for (
auto &group : compBlockScheduleables) {
2237 rewriter.setInsertionPointToEnd(parentCtrlBlock);
2238 if (
auto groupPtr = std::get_if<calyx::GroupOp>(&group); groupPtr) {
2239 rewriter.create<calyx::EnableOp>(groupPtr->getLoc(),
2240 groupPtr->getSymName());
2241 }
else if (
auto whileSchedPtr = std::get_if<WhileScheduleable>(&group);
2243 auto &whileOp = whileSchedPtr->whileOp;
2247 getState<ComponentLoweringState>().getWhileLoopInitGroups(whileOp),
2249 rewriter.setInsertionPointToEnd(whileCtrlOp.getBodyBlock());
2251 rewriter.create<calyx::SeqOp>(whileOp.getOperation()->getLoc());
2252 auto *whileBodyOpBlock = whileBodyOp.getBodyBlock();
2256 if (LogicalResult result =
2258 whileOp.getBodyBlock());
2263 rewriter.setInsertionPointToEnd(whileBodyOpBlock);
2264 calyx::GroupOp whileLatchGroup =
2265 getState<ComponentLoweringState>().getWhileLoopLatchGroup(whileOp);
2266 rewriter.create<calyx::EnableOp>(whileLatchGroup.getLoc(),
2267 whileLatchGroup.getName());
2268 }
else if (
auto *parSchedPtr = std::get_if<ParScheduleable>(&group)) {
2269 auto parOp = parSchedPtr->parOp;
2270 auto calyxParOp = rewriter.create<calyx::ParOp>(parOp.getLoc());
2272 WalkResult walkResult =
2273 parOp.walk([&](scf::ExecuteRegionOp execRegion) {
2274 rewriter.setInsertionPointToEnd(calyxParOp.getBodyBlock());
2275 auto seqOp = rewriter.create<calyx::SeqOp>(execRegion.getLoc());
2276 rewriter.setInsertionPointToEnd(seqOp.getBodyBlock());
2278 for (
auto &execBlock : execRegion.getRegion().getBlocks()) {
2280 rewriter, path, seqOp.getBodyBlock(), &execBlock);
2282 return WalkResult::interrupt();
2285 return WalkResult::advance();
2288 if (walkResult.wasInterrupted())
2290 }
else if (
auto *forSchedPtr = std::get_if<ForScheduleable>(&group);
2292 auto forOp = forSchedPtr->forOp;
2296 getState<ComponentLoweringState>().getForLoopInitGroups(forOp),
2297 forSchedPtr->bound, rewriter);
2298 rewriter.setInsertionPointToEnd(forCtrlOp.getBodyBlock());
2300 rewriter.create<calyx::SeqOp>(forOp.getOperation()->getLoc());
2301 auto *forBodyOpBlock = forBodyOp.getBodyBlock();
2304 if (LogicalResult res =
buildCFGControl(path, rewriter, forBodyOpBlock,
2305 block, forOp.getBodyBlock());
2310 rewriter.setInsertionPointToEnd(forBodyOpBlock);
2311 calyx::GroupOp forLatchGroup =
2312 getState<ComponentLoweringState>().getForLoopLatchGroup(forOp);
2313 rewriter.create<calyx::EnableOp>(forLatchGroup.getLoc(),
2314 forLatchGroup.getName());
2315 }
else if (
auto *ifSchedPtr = std::get_if<IfScheduleable>(&group);
2317 auto ifOp = ifSchedPtr->ifOp;
2319 Location loc = ifOp->getLoc();
2321 auto cond = ifOp.getCondition();
2323 FlatSymbolRefAttr symbolAttr =
nullptr;
2324 auto condReg = getState<ComponentLoweringState>().getCondReg(ifOp);
2326 auto condGroup = getState<ComponentLoweringState>()
2327 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
2329 symbolAttr = FlatSymbolRefAttr::get(
2330 StringAttr::get(getContext(), condGroup.getSymName()));
2333 bool initElse = !ifOp.getElseRegion().empty();
2334 auto ifCtrlOp = rewriter.create<calyx::IfOp>(
2335 loc, cond, symbolAttr, initElse);
2337 rewriter.setInsertionPointToEnd(ifCtrlOp.getBodyBlock());
2340 rewriter.create<calyx::SeqOp>(ifOp.getThenRegion().getLoc());
2341 auto *thenSeqOpBlock = thenSeqOp.getBodyBlock();
2343 auto *thenBlock = &ifOp.getThenRegion().front();
2351 if (!ifOp.getResults().empty()) {
2352 rewriter.setInsertionPointToEnd(thenSeqOpBlock);
2353 calyx::GroupOp thenGroup =
2354 getState<ComponentLoweringState>().getThenGroup(ifOp);
2355 rewriter.create<calyx::EnableOp>(thenGroup.getLoc(),
2356 thenGroup.getName());
2359 if (!ifOp.getElseRegion().empty()) {
2360 rewriter.setInsertionPointToEnd(ifCtrlOp.getElseBody());
2363 rewriter.create<calyx::SeqOp>(ifOp.getElseRegion().getLoc());
2364 auto *elseSeqOpBlock = elseSeqOp.getBodyBlock();
2366 auto *elseBlock = &ifOp.getElseRegion().front();
2372 if (!ifOp.getResults().empty()) {
2373 rewriter.setInsertionPointToEnd(elseSeqOpBlock);
2374 calyx::GroupOp elseGroup =
2375 getState<ComponentLoweringState>().getElseGroup(ifOp);
2376 rewriter.create<calyx::EnableOp>(elseGroup.getLoc(),
2377 elseGroup.getName());
2380 }
else if (
auto *callSchedPtr = std::get_if<CallScheduleable>(&group)) {
2381 auto instanceOp = callSchedPtr->instanceOp;
2382 OpBuilder::InsertionGuard g(rewriter);
2383 auto callBody = rewriter.create<calyx::SeqOp>(instanceOp.getLoc());
2384 rewriter.setInsertionPointToStart(callBody.getBodyBlock());
2386 auto callee = callSchedPtr->callOp.getCallee();
2387 auto *calleeOp = SymbolTable::lookupNearestSymbolFrom(
2388 callSchedPtr->callOp.getOperation()->getParentOp(),
2389 StringAttr::get(rewriter.getContext(),
"func_" + callee.str()));
2390 FuncOp calleeFunc = dyn_cast_or_null<FuncOp>(calleeOp);
2392 auto instanceOpComp =
2393 llvm::cast<calyx::ComponentOp>(instanceOp.getReferencedComponent());
2394 auto *instanceOpLoweringState =
2397 SmallVector<Value, 4> instancePorts;
2398 SmallVector<Value, 4> inputPorts;
2399 SmallVector<Attribute, 4> refCells;
2400 for (
auto operandEnum : enumerate(callSchedPtr->callOp.getOperands())) {
2401 auto operand = operandEnum.value();
2402 auto index = operandEnum.index();
2403 if (!isa<MemRefType>(operand.getType())) {
2404 inputPorts.push_back(operand);
2408 auto memOpName = getState<ComponentLoweringState>()
2409 .getMemoryInterface(operand)
2411 auto memOpNameAttr =
2412 SymbolRefAttr::get(rewriter.getContext(), memOpName);
2413 Value argI = calleeFunc.getArgument(index);
2414 if (isa<MemRefType>(argI.getType())) {
2415 NamedAttrList namedAttrList;
2416 namedAttrList.append(
2417 rewriter.getStringAttr(
2418 instanceOpLoweringState->getMemoryInterface(argI)
2422 DictionaryAttr::get(rewriter.getContext(), namedAttrList));
2425 llvm::copy(instanceOp.getResults().take_front(inputPorts.size()),
2426 std::back_inserter(instancePorts));
2428 ArrayAttr refCellsAttr =
2429 ArrayAttr::get(rewriter.getContext(), refCells);
2431 rewriter.create<calyx::InvokeOp>(
2432 instanceOp.getLoc(), instanceOp.getSymName(), instancePorts,
2433 inputPorts, refCellsAttr, ArrayAttr::get(rewriter.getContext(), {}),
2434 ArrayAttr::get(rewriter.getContext(), {}));
2436 llvm_unreachable(
"Unknown scheduleable");
2447 const DenseSet<Block *> &path, Location loc,
2448 Block *from, Block *to,
2449 Block *parentCtrlBlock)
const {
2452 rewriter.setInsertionPointToEnd(parentCtrlBlock);
2453 auto preSeqOp = rewriter.create<calyx::SeqOp>(loc);
2454 rewriter.setInsertionPointToEnd(preSeqOp.getBodyBlock());
2456 getState<ComponentLoweringState>().getBlockArgGroups(from, to))
2457 rewriter.create<calyx::EnableOp>(barg.getLoc(), barg.getSymName());
2463 PatternRewriter &rewriter,
2464 mlir::Block *parentCtrlBlock,
2465 mlir::Block *preBlock,
2466 mlir::Block *block)
const {
2467 if (path.count(block) != 0)
2468 return preBlock->getTerminator()->emitError()
2469 <<
"CFG backedge detected. Loops must be raised to 'scf.while' or "
2470 "'scf.for' operations.";
2472 rewriter.setInsertionPointToEnd(parentCtrlBlock);
2473 LogicalResult bbSchedResult =
2475 if (bbSchedResult.failed())
2476 return bbSchedResult;
2479 auto successors = block->getSuccessors();
2480 auto nSuccessors = successors.size();
2481 if (nSuccessors > 0) {
2482 auto brOp = dyn_cast<BranchOpInterface>(block->getTerminator());
2484 if (nSuccessors > 1) {
2488 assert(nSuccessors == 2 &&
2489 "only conditional branches supported for now...");
2491 auto cond = brOp->getOperand(0);
2492 auto condGroup = getState<ComponentLoweringState>()
2493 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
2494 auto symbolAttr = FlatSymbolRefAttr::get(
2495 StringAttr::get(getContext(), condGroup.getSymName()));
2497 auto ifOp = rewriter.create<calyx::IfOp>(
2498 brOp->getLoc(), cond, symbolAttr,
true);
2499 rewriter.setInsertionPointToStart(ifOp.getThenBody());
2500 auto thenSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
2501 rewriter.setInsertionPointToStart(ifOp.getElseBody());
2502 auto elseSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
2504 bool trueBrSchedSuccess =
2505 schedulePath(rewriter, path, brOp.getLoc(), block, successors[0],
2506 thenSeqOp.getBodyBlock())
2508 bool falseBrSchedSuccess =
true;
2509 if (trueBrSchedSuccess) {
2510 falseBrSchedSuccess =
2511 schedulePath(rewriter, path, brOp.getLoc(), block, successors[1],
2512 elseSeqOp.getBodyBlock())
2516 return success(trueBrSchedSuccess && falseBrSchedSuccess);
2519 return schedulePath(rewriter, path, brOp.getLoc(), block,
2520 successors.front(), parentCtrlBlock);
2530 const SmallVector<calyx::GroupOp> &initGroups)
const {
2531 PatternRewriter::InsertionGuard g(rewriter);
2532 auto parOp = rewriter.create<calyx::ParOp>(loc);
2533 rewriter.setInsertionPointToStart(parOp.getBodyBlock());
2534 for (calyx::GroupOp group : initGroups)
2535 rewriter.create<calyx::EnableOp>(group.getLoc(), group.getName());
2539 SmallVector<calyx::GroupOp> initGroups,
2540 PatternRewriter &rewriter)
const {
2541 Location loc = whileOp.
getLoc();
2548 auto condGroup = getState<ComponentLoweringState>()
2549 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
2550 auto symbolAttr = FlatSymbolRefAttr::get(
2551 StringAttr::get(getContext(), condGroup.getSymName()));
2552 return rewriter.create<calyx::WhileOp>(loc, cond, symbolAttr);
2556 SmallVector<calyx::GroupOp>
const &initGroups,
2558 PatternRewriter &rewriter)
const {
2559 Location loc = forOp.
getLoc();
2565 return rewriter.create<calyx::RepeatOp>(loc, bound);
2572 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2575 PatternRewriter &)
const override {
2576 funcOp.walk([&](scf::IfOp op) {
2577 for (
auto res : getState<ComponentLoweringState>().getResultRegs(op))
2578 op.getOperation()->getResults()[res.first].replaceAllUsesWith(
2579 res.second.getOut());
2582 funcOp.walk([&](scf::WhileOp op) {
2591 getState<ComponentLoweringState>().getWhileLoopIterRegs(whileOp))
2592 whileOp.
getOperation()->getResults()[res.first].replaceAllUsesWith(
2593 res.second.getOut());
2596 funcOp.walk([&](memref::LoadOp loadOp) {
2602 loadOp.getResult().replaceAllUsesWith(
2603 getState<ComponentLoweringState>()
2604 .getMemoryInterface(loadOp.getMemref())
2615 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2618 PatternRewriter &rewriter)
const override {
2619 rewriter.eraseOp(funcOp);
2625 PatternRewriter &rewriter)
const override {
56namespace scftocalyx {
…}
2639class SCFToCalyxPass :
public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
2641 SCFToCalyxPass(std::string topLevelFunction)
2643 this->topLevelFunctionOpt = topLevelFunction;
2645 void runOnOperation()
override;
2647 LogicalResult setTopLevelFunction(mlir::ModuleOp moduleOp,
2648 std::string &topLevelFunction) {
2649 if (!topLevelFunctionOpt.empty()) {
2650 if (SymbolTable::lookupSymbolIn(moduleOp, topLevelFunctionOpt) ==
2652 moduleOp.emitError() <<
"Top level function '" << topLevelFunctionOpt
2653 <<
"' not found in module.";
2656 topLevelFunction = topLevelFunctionOpt;
2660 auto funcOps = moduleOp.getOps<FuncOp>();
2661 if (std::distance(funcOps.begin(), funcOps.end()) == 1)
2662 topLevelFunction = (*funcOps.begin()).getSymName().str();
2664 moduleOp.emitError()
2665 <<
"Module contains multiple functions, but no top level "
2666 "function was set. Please see --top-level-function";
2671 return createOptNewTopLevelFn(moduleOp, topLevelFunction);
2674 struct LoweringPattern {
2675 enum class Strategy { Once, Greedy };
2684 LogicalResult labelEntryPoint(StringRef topLevelFunction) {
2688 using OpRewritePattern::OpRewritePattern;
2689 LogicalResult matchAndRewrite(mlir::ModuleOp,
2690 PatternRewriter &)
const override {
2695 ConversionTarget target(getContext());
2696 target.addLegalDialect<calyx::CalyxDialect>();
2697 target.addLegalDialect<scf::SCFDialect>();
2698 target.addIllegalDialect<hw::HWDialect>();
2699 target.addIllegalDialect<comb::CombDialect>();
2702 target.addIllegalDialect<FuncDialect>();
2703 target.addIllegalDialect<ArithDialect>();
2705 AddIOp, SelectOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp, AndIOp,
2706 XOrIOp, OrIOp, ExtUIOp, TruncIOp, CondBranchOp, BranchOp, MulIOp,
2707 DivUIOp, DivSIOp, RemUIOp, RemSIOp, ReturnOp, arith::ConstantOp,
2708 IndexCastOp, BitcastOp, FuncOp, ExtSIOp, CallOp, AddFOp, SubFOp, MulFOp,
2709 CmpFOp, FPToSIOp, SIToFPOp, DivFOp, math::SqrtOp>();
2711 RewritePatternSet legalizePatterns(&getContext());
2712 legalizePatterns.add<DummyPattern>(&getContext());
2713 DenseSet<Operation *> legalizedOps;
2714 if (applyPartialConversion(getOperation(), target,
2715 std::move(legalizePatterns))
2726 template <
typename TPattern,
typename... PatternArgs>
2727 void addOncePattern(SmallVectorImpl<LoweringPattern> &
patterns,
2728 PatternArgs &&...args) {
2729 RewritePatternSet ps(&getContext());
2732 LoweringPattern{std::move(ps), LoweringPattern::Strategy::Once});
2735 template <
typename TPattern,
typename... PatternArgs>
2736 void addGreedyPattern(SmallVectorImpl<LoweringPattern> &
patterns,
2737 PatternArgs &&...args) {
2738 RewritePatternSet ps(&getContext());
2739 ps.add<TPattern>(&getContext(), args...);
2741 LoweringPattern{std::move(ps), LoweringPattern::Strategy::Greedy});
2744 LogicalResult runPartialPattern(RewritePatternSet &
pattern,
bool runOnce) {
2746 "Should only apply 1 partial lowering pattern at once");
2752 GreedyRewriteConfig config;
2753 config.setRegionSimplificationLevel(
2754 mlir::GreedySimplifyRegionLevel::Disabled);
2756 config.setMaxIterations(1);
2761 (void)applyPatternsGreedily(getOperation(), std::move(
pattern), config);
2770 FuncOp createNewTopLevelFn(ModuleOp moduleOp, std::string &baseName) {
2771 std::string newName =
"main";
2773 if (
auto *existingMainOp = SymbolTable::lookupSymbolIn(moduleOp, newName)) {
2774 auto existingMainFunc = dyn_cast<FuncOp>(existingMainOp);
2775 if (existingMainFunc ==
nullptr) {
2776 moduleOp.emitError() <<
"Symbol 'main' exists but is not a function";
2779 unsigned counter = 0;
2780 std::string newOldName = baseName;
2781 while (SymbolTable::lookupSymbolIn(moduleOp, newOldName))
2782 newOldName = llvm::join_items(
"_", baseName, std::to_string(++counter));
2783 existingMainFunc.setName(newOldName);
2784 if (baseName ==
"main")
2785 baseName = newOldName;
2789 OpBuilder builder(moduleOp.getContext());
2790 builder.setInsertionPointToStart(moduleOp.getBody());
2792 FunctionType funcType = builder.getFunctionType({}, {});
2795 builder.create<FuncOp>(moduleOp.getLoc(), newName, funcType))
2805 void insertCallFromNewTopLevel(OpBuilder &builder, FuncOp caller,
2807 if (caller.getBody().empty()) {
2808 caller.addEntryBlock();
2811 Block *callerEntryBlock = &caller.getBody().front();
2812 builder.setInsertionPointToStart(callerEntryBlock);
2816 SmallVector<Type, 4> nonMemRefCalleeArgTypes;
2817 for (
auto arg : callee.getArguments()) {
2818 if (!isa<MemRefType>(arg.getType())) {
2819 nonMemRefCalleeArgTypes.push_back(arg.getType());
2823 for (Type type : nonMemRefCalleeArgTypes) {
2824 callerEntryBlock->addArgument(type, caller.getLoc());
2827 FunctionType callerFnType = caller.getFunctionType();
2828 SmallVector<Type, 4> updatedCallerArgTypes(
2829 caller.getFunctionType().getInputs());
2830 updatedCallerArgTypes.append(nonMemRefCalleeArgTypes.begin(),
2831 nonMemRefCalleeArgTypes.end());
2832 caller.setType(FunctionType::get(caller.getContext(), updatedCallerArgTypes,
2833 callerFnType.getResults()));
2835 Block *calleeFnBody = &callee.getBody().front();
2836 unsigned originalCalleeArgNum = callee.getArguments().size();
2838 SmallVector<Value, 4> extraMemRefArgs;
2839 SmallVector<Type, 4> extraMemRefArgTypes;
2840 SmallVector<Value, 4> extraMemRefOperands;
2841 SmallVector<Operation *, 4> opsToModify;
2842 for (
auto &op : callee.getBody().getOps()) {
2843 if (isa<memref::AllocaOp, memref::AllocOp, memref::GetGlobalOp>(op))
2844 opsToModify.push_back(&op);
2849 builder.setInsertionPointToEnd(callerEntryBlock);
2850 for (
auto *op : opsToModify) {
2853 TypeSwitch<Operation *>(op)
2854 .Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
2855 newOpRes = builder.create<memref::AllocaOp>(callee.getLoc(),
2856 allocaOp.getType());
2858 .Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
2859 newOpRes = builder.create<memref::AllocOp>(callee.getLoc(),
2862 .Case<memref::GetGlobalOp>([&](memref::GetGlobalOp getGlobalOp) {
2863 newOpRes = builder.create<memref::GetGlobalOp>(
2864 caller.getLoc(), getGlobalOp.getType(), getGlobalOp.getName());
2866 .Default([&](Operation *defaultOp) {
2867 llvm::report_fatal_error(
"Unsupported operation in TypeSwitch");
2869 extraMemRefOperands.push_back(newOpRes);
2871 calleeFnBody->addArgument(newOpRes.getType(), callee.getLoc());
2872 BlockArgument newBodyArg = calleeFnBody->getArguments().back();
2873 op->getResult(0).replaceAllUsesWith(newBodyArg);
2875 extraMemRefArgs.push_back(newBodyArg);
2876 extraMemRefArgTypes.push_back(newBodyArg.getType());
2879 SmallVector<Type, 4> updatedCalleeArgTypes(
2880 callee.getFunctionType().getInputs());
2881 updatedCalleeArgTypes.append(extraMemRefArgTypes.begin(),
2882 extraMemRefArgTypes.end());
2883 callee.setType(FunctionType::get(callee.getContext(), updatedCalleeArgTypes,
2884 callee.getFunctionType().getResults()));
2886 unsigned otherArgsCount = 0;
2887 SmallVector<Value, 4> calleeArgFnOperands;
2888 builder.setInsertionPointToStart(callerEntryBlock);
2889 for (
auto arg : callee.getArguments().take_front(originalCalleeArgNum)) {
2890 if (isa<MemRefType>(arg.getType())) {
2891 auto memrefType = cast<MemRefType>(arg.getType());
2893 builder.create<memref::AllocOp>(callee.getLoc(), memrefType);
2894 calleeArgFnOperands.push_back(allocOp);
2896 auto callerArg = callerEntryBlock->getArgument(otherArgsCount++);
2897 calleeArgFnOperands.push_back(callerArg);
2901 SmallVector<Value, 4> fnOperands;
2902 fnOperands.append(calleeArgFnOperands.begin(), calleeArgFnOperands.end());
2903 fnOperands.append(extraMemRefOperands.begin(), extraMemRefOperands.end());
2905 SymbolRefAttr::get(builder.getContext(), callee.getSymName());
2906 auto resultTypes = callee.getResultTypes();
2908 builder.setInsertionPointToEnd(callerEntryBlock);
2909 builder.create<CallOp>(caller.getLoc(), calleeName, resultTypes,
2911 builder.create<ReturnOp>(caller.getLoc());
2917 LogicalResult createOptNewTopLevelFn(ModuleOp moduleOp,
2918 std::string &topLevelFunction) {
2919 auto hasMemrefArguments = [](FuncOp func) {
2921 func.getArguments().begin(), func.getArguments().end(),
2922 [](BlockArgument arg) { return isa<MemRefType>(arg.getType()); });
2928 auto funcOps = moduleOp.getOps<FuncOp>();
2929 bool hasMemrefArgsInTopLevel =
2930 std::any_of(funcOps.begin(), funcOps.end(), [&](
auto funcOp) {
2931 return funcOp.getName() == topLevelFunction &&
2932 hasMemrefArguments(funcOp);
2935 if (hasMemrefArgsInTopLevel) {
2936 auto newTopLevelFunc = createNewTopLevelFn(moduleOp, topLevelFunction);
2937 if (!newTopLevelFunc)
2940 OpBuilder builder(moduleOp.getContext());
2941 Operation *oldTopLevelFuncOp =
2942 SymbolTable::lookupSymbolIn(moduleOp, topLevelFunction);
2943 if (
auto oldTopLevelFunc = dyn_cast<FuncOp>(oldTopLevelFuncOp))
2944 insertCallFromNewTopLevel(builder, newTopLevelFunc, oldTopLevelFunc);
2946 moduleOp.emitOpError(
"Original top-level function not found!");
2949 topLevelFunction =
"main";
2956void SCFToCalyxPass::runOnOperation() {
2961 std::string topLevelFunction;
2962 if (failed(setTopLevelFunction(getOperation(), topLevelFunction))) {
2963 signalPassFailure();
2968 if (failed(labelEntryPoint(topLevelFunction))) {
2969 signalPassFailure();
2972 loweringState = std::make_shared<calyx::CalyxLoweringState>(getOperation(),
2983 DenseMap<FuncOp, calyx::ComponentOp> funcMap;
2984 SmallVector<LoweringPattern, 8> loweringPatterns;
2988 addOncePattern<FuncOpConversion>(loweringPatterns, patternState, funcMap,
2992 addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);
2995 addOncePattern<calyx::ConvertIndexTypes>(loweringPatterns, patternState,
2999 addOncePattern<calyx::BuildBasicBlockRegs>(loweringPatterns, patternState,
3002 addOncePattern<calyx::BuildCallInstance>(loweringPatterns, patternState,
3006 addOncePattern<calyx::BuildReturnRegs>(loweringPatterns, patternState,
3012 addOncePattern<BuildWhileGroups>(loweringPatterns, patternState, funcMap,
3018 addOncePattern<BuildForGroups>(loweringPatterns, patternState, funcMap,
3021 addOncePattern<BuildIfGroups>(loweringPatterns, patternState, funcMap,
3031 addOncePattern<BuildOpGroups>(loweringPatterns, patternState, funcMap,
3037 addOncePattern<BuildControl>(loweringPatterns, patternState, funcMap,
3042 addOncePattern<calyx::InlineCombGroups>(loweringPatterns, patternState,
3047 addOncePattern<LateSSAReplacement>(loweringPatterns, patternState, funcMap,
3053 addGreedyPattern<calyx::EliminateUnusedCombGroups>(loweringPatterns);
3057 addOncePattern<calyx::RewriteMemoryAccesses>(loweringPatterns, patternState,
3062 addOncePattern<CleanupFuncOps>(loweringPatterns, patternState, funcMap,
3066 for (
auto &pat : loweringPatterns) {
3069 pat.strategy == LoweringPattern::Strategy::Once);
3072 signalPassFailure();
3079 RewritePatternSet cleanupPatterns(&getContext());
3083 applyPatternsGreedily(getOperation(), std::move(cleanupPatterns)))) {
3084 signalPassFailure();
3088 if (ciderSourceLocationMetadata) {
3091 SmallVector<Attribute, 16> sourceLocations;
3092 getOperation()->walk([&](calyx::ComponentOp component) {
3096 MLIRContext *context = getOperation()->getContext();
3097 getOperation()->setAttr(
"calyx.metadata",
3098 ArrayAttr::get(context, sourceLocations));
3107std::unique_ptr<OperationPass<ModuleOp>>
3109 return std::make_unique<SCFToCalyxPass>(topLevelFunction);
assert(baseType &&"element must be base type")
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
std::shared_ptr< calyx::CalyxLoweringState > loweringState
LogicalResult partialPatternRes
An interface for conversion passes that lower Calyx programs.
std::string irName(ValueOrBlock &v)
Returns a meaningful name for a value within the program scope.
std::string blockName(Block *b)
Returns a meaningful name for a block within the program scope (removes the ^ prefix from block names...
StringRef getTopLevelFunction() const
Returns the name of the top-level function in the source program.
T * getState(calyx::ComponentOp op)
Returns the component lowering state associated with op.
void setFuncOpResultMapping(const DenseMap< unsigned, unsigned > &mapping)
Assign a mapping between the source funcOp result indices and the corresponding output port indices o...
std::string getUniqueName(StringRef prefix)
Returns a unique name within compOp with the provided prefix.
calyx::ComponentOp component
The component which this lowering state is associated to.
void registerMemoryInterface(Value memref, const calyx::MemoryInterface &memoryInterface)
Registers a memory interface as being associated with a memory identified by 'memref'.
calyx::ComponentOp getComponentOp()
Returns the calyx::ComponentOp associated with this lowering state.
void setDataField(StringRef name, llvm::json::Array data)
ComponentLoweringStateInterface(calyx::ComponentOp component)
void setFormat(StringRef name, std::string numType, bool isSigned, unsigned width)
FuncOpPartialLoweringPatterns are patterns which intend to match on FuncOps and then perform their ow...
calyx::ComponentOp getComponent() const
Returns the component operation associated with the currently executing partial lowering.
DenseMap< mlir::func::FuncOp, calyx::ComponentOp > & functionMapping
CalyxLoweringState & loweringState() const
Return the calyx lowering state for this pattern.
FuncOpPartialLoweringPattern(MLIRContext *context, LogicalResult &resRef, PatternApplicationState &patternState, DenseMap< mlir::func::FuncOp, calyx::ComponentOp > &map, calyx::CalyxLoweringState &state)
calyx::GroupOp getLoopLatchGroup(ScfWhileOp op)
Retrieve the loop latch group registered for op.
void setLoopLatchGroup(ScfWhileOp op, calyx::GroupOp group)
Registers grp to be the loop latch group of op.
calyx::RegisterOp getLoopIterReg(ScfForOp op, unsigned idx)
Return a mapping of block argument indices to block argument.
void addLoopIterReg(ScfWhileOp op, calyx::RegisterOp reg, unsigned idx)
Register reg as being the idx'th iter_args register for 'op'.
void setLoopInitGroups(ScfWhileOp op, SmallVector< calyx::GroupOp > groups)
Registers groups to be the loop init groups of op.
SmallVector< calyx::GroupOp > getLoopInitGroups(ScfWhileOp op)
Retrieve the loop init groups registered for op.
calyx::GroupOp buildLoopIterArgAssignments(OpBuilder &builder, ScfWhileOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
Creates a new group that assigns the 'ops' values to the iter arg registers of the loop operation.
const DenseMap< unsigned, calyx::RegisterOp > & getLoopIterRegs(ScfWhileOp op)
Return a mapping of block argument indices to block argument.
PatternApplicationState & patternState
scf::ForOp getOperation()
Location getLoc() override
RepeatOpInterface(scf::ForOp op)
Holds common utilities used for scheduling when lowering to Calyx.
scf::WhileOp getOperation()
WhileOpInterface(scf::WhileOp op)
Location getLoc() override
Builds a control schedule by traversing the CFG of the function and associating this with the previou...
calyx::RepeatOp buildForCtrlOp(ScfForOp forOp, SmallVector< calyx::GroupOp > const &initGroups, uint64_t bound, PatternRewriter &rewriter) const
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult schedulePath(PatternRewriter &rewriter, const DenseSet< Block * > &path, Location loc, Block *from, Block *to, Block *parentCtrlBlock) const
Schedules a block by inserting a branch argument assignment block (if any) before recursing into the ...
calyx::WhileOp buildWhileCtrlOp(ScfWhileOp whileOp, SmallVector< calyx::GroupOp > initGroups, PatternRewriter &rewriter) const
LogicalResult scheduleBasicBlock(PatternRewriter &rewriter, const DenseSet< Block * > &path, mlir::Block *parentCtrlBlock, mlir::Block *block) const
Sequentially schedules the groups that registered themselves with 'block'.
LogicalResult buildCFGControl(DenseSet< Block * > path, PatternRewriter &rewriter, mlir::Block *parentCtrlBlock, mlir::Block *preBlock, mlir::Block *block) const
void insertParInitGroups(PatternRewriter &rewriter, Location loc, const SmallVector< calyx::GroupOp > &initGroups) const
In BuildForGroups, a register is created for the iteration argument of the for op.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Iterate through the operations of a source function and instantiate components or primitives based on...
BuildOpGroups(MLIRContext *context, LogicalResult &resRef, calyx::PatternApplicationState &patternState, DenseMap< mlir::func::FuncOp, calyx::ComponentOp > &map, calyx::CalyxLoweringState &state, mlir::Pass::Option< std::string > &writeJsonOpt)
LogicalResult buildCmpIOpHelper(PatternRewriter &rewriter, CmpIOp op) const
void setupCmpIOp(PatternRewriter &rewriter, CmpIOp cmpIOp, Operation *group, calyx::RegisterOp &condReg, calyx::RegisterOp &resReg, TCalyxLibOp calyxOp) const
LogicalResult buildFpIntTypeCastOp(PatternRewriter &rewriter, TSrcOp op, unsigned inputWidth, unsigned outputWidth, StringRef signedPort) const
TGroupOp createGroupForOp(PatternRewriter &rewriter, Operation *op) const
Creates a group named by the basic block which the input op resides in.
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op) const
buildLibraryOp which provides in- and output types based on the operands and results of the op argume...
LogicalResult buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp) const
Op builder specializations.
calyx::RegisterOp createSignalRegister(PatternRewriter &rewriter, Value signal, bool invert, StringRef nameSuffix, calyx::CompareFOpIEEE754 calyxCmpFOp, calyx::GroupOp group) const
void assignAddressPorts(PatternRewriter &rewriter, Location loc, calyx::GroupInterface group, calyx::MemoryInterface memoryInterface, Operation::operand_range addressValues) const
Creates assignments within the provided group to the address ports of the memoryOp based on the provi...
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op, TypeRange srcTypes, TypeRange dstTypes) const
buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the source operation TSrcOp.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult buildLibraryBinaryPipeOp(PatternRewriter &rewriter, TSrcOp op, TOpType opPipe, Value out) const
buildLibraryBinaryPipeOp will build a TCalyxLibBinaryPipeOp, to deal with MulIOp, DivUIOp and RemUIOp...
mlir::Pass::Option< std::string > & writeJson
In BuildWhileGroups, a register is created for each iteration argumenet of the while op.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Erases FuncOp operations.
LogicalResult matchAndRewrite(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Handles the current state of lowering of a Calyx component.
ComponentLoweringState(calyx::ComponentOp component)
void setForLoopInitGroups(ScfForOp op, SmallVector< calyx::GroupOp > groups)
calyx::GroupOp buildForLoopIterArgAssignments(OpBuilder &builder, ScfForOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
void setForLoopLatchGroup(ScfForOp op, calyx::GroupOp group)
SmallVector< calyx::GroupOp > getForLoopInitGroups(ScfForOp op)
void addForLoopIterReg(ScfForOp op, calyx::RegisterOp reg, unsigned idx)
calyx::GroupOp getForLoopLatchGroup(ScfForOp op)
calyx::RegisterOp getForLoopIterReg(ScfForOp op, unsigned idx)
const DenseMap< unsigned, calyx::RegisterOp > & getForLoopIterRegs(ScfForOp op)
DenseMap< Operation *, calyx::GroupOp > elseGroup
DenseMap< Operation *, calyx::GroupOp > thenGroup
void setCondReg(scf::IfOp op, calyx::RegisterOp regOp)
const DenseMap< unsigned, calyx::RegisterOp > & getResultRegs(scf::IfOp op)
void setElseGroup(scf::IfOp op, calyx::GroupOp group)
void setResultRegs(scf::IfOp op, calyx::RegisterOp reg, unsigned idx)
void setThenGroup(scf::IfOp op, calyx::GroupOp group)
DenseMap< Operation *, DenseMap< unsigned, calyx::RegisterOp > > resultRegs
calyx::RegisterOp getResultRegs(scf::IfOp op, unsigned idx)
calyx::RegisterOp getCondReg(scf::IfOp op)
calyx::GroupOp getThenGroup(scf::IfOp op)
DenseMap< Operation *, calyx::RegisterOp > condReg
calyx::GroupOp getElseGroup(scf::IfOp op)
Inlines Calyx ExecuteRegionOp operations within their parent blocks.
LogicalResult matchAndRewrite(scf::ExecuteRegionOp execOp, PatternRewriter &rewriter) const override
LateSSAReplacement contains various functions for replacing SSA values that were not replaced during ...
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &) const override
std::optional< int64_t > getBound() override
Block::BlockArgListType getBodyArgs() override
Block * getBodyBlock() override
Block * getBodyBlock() override
ScfWhileOp(scf::WhileOp op)
Block::BlockArgListType getBodyArgs() override
Value getConditionValue() override
std::optional< int64_t > getBound() override
Block * getConditionBlock() override
Stores the state information for condition checks involving sequential computation.
void setSeqResReg(Operation *op, calyx::RegisterOp reg)
calyx::RegisterOp getSeqResReg(Operation *op)
DenseMap< Operation *, calyx::RegisterOp > resultRegs
calyx::GroupOp buildWhileLoopIterArgAssignments(OpBuilder &builder, ScfWhileOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
void setWhileLoopInitGroups(ScfWhileOp op, SmallVector< calyx::GroupOp > groups)
SmallVector< calyx::GroupOp > getWhileLoopInitGroups(ScfWhileOp op)
void addWhileLoopIterReg(ScfWhileOp op, calyx::RegisterOp reg, unsigned idx)
void setWhileLoopLatchGroup(ScfWhileOp op, calyx::GroupOp group)
const DenseMap< unsigned, calyx::RegisterOp > & getWhileLoopIterRegs(ScfWhileOp op)
calyx::GroupOp getWhileLoopLatchGroup(ScfWhileOp op)
bool parentIsSeqCell(Value value)
void addMandatoryComponentPorts(PatternRewriter &rewriter, SmallVectorImpl< calyx::PortInfo > &ports)
void buildAssignmentsForRegisterWrite(OpBuilder &builder, calyx::GroupOp groupOp, calyx::ComponentOp componentOp, calyx::RegisterOp ®, Value inputValue)
Creates register assignment operations within the provided groupOp.
DenseMap< const mlir::RewritePattern *, SmallPtrSet< Operation *, 16 > > PatternApplicationState
Extra state that is passed to all PartialLoweringPatterns so they can record when they have run on an...
PredicateInfo getPredicateInfo(mlir::arith::CmpFPredicate pred)
Type normalizeType(OpBuilder &builder, Type type)
LogicalResult applyModuleOpConversion(mlir::ModuleOp, StringRef topLevelFunction)
Helper to update the top-level ModuleOp to set the entrypoing function.
WalkResult getCiderSourceLocationMetadata(calyx::ComponentOp component, SmallVectorImpl< Attribute > &sourceLocations)
bool matchConstantOp(Operation *op, APInt &value)
unsigned handleZeroWidth(int64_t dim)
hw::ConstantOp createConstant(Location loc, OpBuilder &builder, ComponentOp component, size_t width, size_t value)
A helper function to create constants in the HW dialect.
bool noStoresToMemory(Value memoryReference)
bool singleLoadFromMemory(Value memoryReference)
Type toBitVector(T type)
Performs a bit cast from a non-signless integer type value, such as a floating point value,...
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `‘Not’' gate on a value.
static constexpr std::string_view sPortNameAttr
static constexpr std::string_view unrolledParallelAttr
static LogicalResult buildAllocOp(ComponentLoweringState &componentState, PatternRewriter &rewriter, TAllocOp allocOp)
std::variant< calyx::GroupOp, WhileScheduleable, ForScheduleable, IfScheduleable, CallScheduleable, ParScheduleable > Scheduleable
A variant of types representing scheduleable operations.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< OperationPass< ModuleOp > > createSCFToCalyxPass(std::string topLevelFunction="")
Create an SCF to Calyx conversion pass.
When building groups which contain accesses to multiple sequential components, a group_done op is cre...
GroupDoneOp's are terminator operations and should therefore be the last operator in a group.
This holds information about the port for either a Component or Cell.
Predicate information for the floating point comparisons.
calyx::InstanceOp instanceOp
Instance for invoking.
ScfForOp forOp
For operation to schedule.
Creates a new Calyx component for each FuncOp in the program.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
scf::ParallelOp parOp
Parallel operation to schedule.
ScfWhileOp whileOp
While operation to schedule.