19#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
20#include "mlir/Conversion/LLVMCommon/Pattern.h"
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/MemRef/IR/MemRef.h"
25#include "mlir/Dialect/SCF/IR/SCF.h"
26#include "mlir/IR/AsmState.h"
27#include "mlir/IR/Matchers.h"
28#include "mlir/Pass/Pass.h"
29#include "mlir/Support/LogicalResult.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/LogicalResult.h"
33#include "llvm/Support/raw_os_ostream.h"
34#include "llvm/Support/raw_ostream.h"
44#define GEN_PASS_DEF_SCFTOCALYX
45#include "circt/Conversion/Passes.h.inc"
50using namespace mlir::arith;
51using namespace mlir::cf;
54class ComponentLoweringStateInterface;
80 std::optional<int64_t>
getBound()
override {
return std::nullopt; }
142 Operation *operation = op.getOperation();
144 "A then group was already set for this scf::IfOp!\n");
149 auto it =
thenGroup.find(op.getOperation());
151 "No then group was set for this scf::IfOp!\n");
156 Operation *operation = op.getOperation();
158 "An else group was already set for this scf::IfOp!\n");
163 auto it =
elseGroup.find(op.getOperation());
165 "No else group was set for this scf::IfOp!\n");
171 "A register was already registered for the given yield result.\n");
172 assert(idx < op->getNumOperands());
182 auto it = regs.find(idx);
183 assert(it != regs.end() &&
"resultReg not found");
190 DenseMap<Operation *, DenseMap<unsigned, calyx::RegisterOp>>
resultRegs;
200 OpBuilder &builder,
ScfWhileOp op, calyx::ComponentOp componentOp,
201 Twine uniqueSuffix, MutableArrayRef<OpOperand> ops) {
208 const DenseMap<unsigned, calyx::RegisterOp> &
219 SmallVector<calyx::GroupOp> groups) {
231 OpBuilder &builder,
ScfForOp op, calyx::ComponentOp componentOp,
232 Twine uniqueSuffix, MutableArrayRef<OpOperand> ops) {
279 DenseMap<mlir::func::FuncOp, calyx::ComponentOp> &map,
281 mlir::Pass::Option<std::string> &writeJsonOpt)
284 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
288 PatternRewriter &rewriter)
const override {
291 bool opBuiltSuccessfully =
true;
292 funcOp.walk([&](Operation *_op) {
293 opBuiltSuccessfully &=
294 TypeSwitch<mlir::Operation *, bool>(_op)
295 .template Case<arith::ConstantOp, ReturnOp, BranchOpInterface,
297 scf::YieldOp, scf::WhileOp, scf::ForOp, scf::IfOp,
298 scf::ParallelOp, scf::ReduceOp,
299 scf::ExecuteRegionOp,
301 memref::AllocOp, memref::AllocaOp, memref::LoadOp,
302 memref::StoreOp, memref::GetGlobalOp,
304 AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
305 AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
306 MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
308 AddFOp, MulFOp, CmpFOp,
310 SelectOp, IndexCastOp, CallOp>(
311 [&](
auto op) {
return buildOp(rewriter, op).succeeded(); })
312 .
template Case<FuncOp, scf::ConditionOp>([&](
auto) {
316 .Default([&](
auto op) {
317 op->emitError() <<
"Unhandled operation during BuildOpGroups()";
321 return opBuiltSuccessfully ? WalkResult::advance()
322 : WalkResult::interrupt();
326 if (
auto fileLoc = dyn_cast<mlir::FileLineColLoc>(funcOp->getLoc())) {
327 std::string filename = fileLoc.getFilename().str();
328 std::filesystem::path path(filename);
329 std::string jsonFileName =
writeJson.getValue() +
".json";
330 auto outFileName = path.parent_path().append(jsonFileName);
331 std::ofstream outFile(outFileName);
333 if (!outFile.is_open()) {
334 llvm::errs() <<
"Unable to open file: " << outFileName.string()
338 llvm::raw_os_ostream llvmOut(outFile);
339 llvm::json::OStream jsonOS(llvmOut, 2);
340 jsonOS.value(getState<ComponentLoweringState>().getExtMemData());
346 return success(opBuiltSuccessfully);
352 LogicalResult
buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp)
const;
353 LogicalResult
buildOp(PatternRewriter &rewriter,
354 BranchOpInterface brOp)
const;
355 LogicalResult
buildOp(PatternRewriter &rewriter,
356 arith::ConstantOp constOp)
const;
357 LogicalResult
buildOp(PatternRewriter &rewriter, SelectOp op)
const;
358 LogicalResult
buildOp(PatternRewriter &rewriter, AddIOp op)
const;
359 LogicalResult
buildOp(PatternRewriter &rewriter, SubIOp op)
const;
360 LogicalResult
buildOp(PatternRewriter &rewriter, MulIOp op)
const;
361 LogicalResult
buildOp(PatternRewriter &rewriter, DivUIOp op)
const;
362 LogicalResult
buildOp(PatternRewriter &rewriter, DivSIOp op)
const;
363 LogicalResult
buildOp(PatternRewriter &rewriter, RemUIOp op)
const;
364 LogicalResult
buildOp(PatternRewriter &rewriter, RemSIOp op)
const;
365 LogicalResult
buildOp(PatternRewriter &rewriter, AddFOp op)
const;
366 LogicalResult
buildOp(PatternRewriter &rewriter, MulFOp op)
const;
367 LogicalResult
buildOp(PatternRewriter &rewriter, CmpFOp op)
const;
368 LogicalResult
buildOp(PatternRewriter &rewriter, ShRUIOp op)
const;
369 LogicalResult
buildOp(PatternRewriter &rewriter, ShRSIOp op)
const;
370 LogicalResult
buildOp(PatternRewriter &rewriter, ShLIOp op)
const;
371 LogicalResult
buildOp(PatternRewriter &rewriter, AndIOp op)
const;
372 LogicalResult
buildOp(PatternRewriter &rewriter, OrIOp op)
const;
373 LogicalResult
buildOp(PatternRewriter &rewriter, XOrIOp op)
const;
374 LogicalResult
buildOp(PatternRewriter &rewriter, CmpIOp op)
const;
375 LogicalResult
buildOp(PatternRewriter &rewriter, TruncIOp op)
const;
376 LogicalResult
buildOp(PatternRewriter &rewriter, ExtUIOp op)
const;
377 LogicalResult
buildOp(PatternRewriter &rewriter, ExtSIOp op)
const;
378 LogicalResult
buildOp(PatternRewriter &rewriter, ReturnOp op)
const;
379 LogicalResult
buildOp(PatternRewriter &rewriter, IndexCastOp op)
const;
380 LogicalResult
buildOp(PatternRewriter &rewriter, memref::AllocOp op)
const;
381 LogicalResult
buildOp(PatternRewriter &rewriter, memref::AllocaOp op)
const;
382 LogicalResult
buildOp(PatternRewriter &rewriter,
383 memref::GetGlobalOp op)
const;
384 LogicalResult
buildOp(PatternRewriter &rewriter, memref::LoadOp op)
const;
385 LogicalResult
buildOp(PatternRewriter &rewriter, memref::StoreOp op)
const;
386 LogicalResult
buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp)
const;
387 LogicalResult
buildOp(PatternRewriter &rewriter, scf::ForOp forOp)
const;
388 LogicalResult
buildOp(PatternRewriter &rewriter, scf::IfOp ifOp)
const;
389 LogicalResult
buildOp(PatternRewriter &rewriter,
390 scf::ReduceOp reduceOp)
const;
391 LogicalResult
buildOp(PatternRewriter &rewriter,
392 scf::ParallelOp parallelOp)
const;
393 LogicalResult
buildOp(PatternRewriter &rewriter,
394 scf::ExecuteRegionOp executeRegionOp)
const;
395 LogicalResult
buildOp(PatternRewriter &rewriter, CallOp callOp)
const;
399 template <
typename TGroupOp,
typename TCalyxLibOp,
typename TSrcOp>
401 TypeRange srcTypes, TypeRange dstTypes)
const {
402 SmallVector<Type> types;
403 for (Type srcType : srcTypes)
405 for (Type dstType : dstTypes)
409 getState<ComponentLoweringState>().getNewLibraryOpInstance<TCalyxLibOp>(
410 rewriter, op.getLoc(), types);
412 auto directions = calyxOp.portDirections();
413 SmallVector<Value, 4> opInputPorts;
414 SmallVector<Value, 4> opOutputPorts;
415 for (
auto dir : enumerate(directions)) {
417 opInputPorts.push_back(calyxOp.getResult(dir.index()));
419 opOutputPorts.push_back(calyxOp.getResult(dir.index()));
422 opInputPorts.size() == op->getNumOperands() &&
423 opOutputPorts.size() == op->getNumResults() &&
424 "Expected an equal number of in/out ports in the Calyx library op with "
425 "respect to the number of operands/results of the source operation.");
428 auto group = createGroupForOp<TGroupOp>(rewriter, op);
429 rewriter.setInsertionPointToEnd(group.getBodyBlock());
430 for (
auto dstOp : enumerate(opInputPorts))
431 rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
432 op->getOperand(dstOp.index()));
435 for (
auto res : enumerate(opOutputPorts)) {
436 getState<ComponentLoweringState>().registerEvaluatingGroup(res.value(),
438 op->getResult(res.index()).replaceAllUsesWith(res.value());
445 template <
typename TGroupOp,
typename TCalyxLibOp,
typename TSrcOp>
447 return buildLibraryOp<TGroupOp, TCalyxLibOp, TSrcOp>(
448 rewriter, op, op.getOperandTypes(), op->getResultTypes());
452 template <
typename TGroupOp>
454 Block *block = op->getBlock();
455 auto groupName = getState<ComponentLoweringState>().getUniqueName(
457 return calyx::createGroup<TGroupOp>(
458 rewriter, getState<ComponentLoweringState>().getComponentOp(),
459 op->getLoc(), groupName);
464 template <
typename TOpType,
typename TSrcOp>
466 TOpType opPipe, Value out)
const {
467 StringRef opName = TSrcOp::getOperationName().split(
".").second;
468 Location loc = op.getLoc();
469 Type width = op.getResult().getType();
470 auto reg = createRegister(
471 op.getLoc(), rewriter,
getComponent(), width.getIntOrFloatBitWidth(),
472 getState<ComponentLoweringState>().getUniqueName(opName));
474 auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
475 OpBuilder builder(group->getRegion(0));
476 getState<ComponentLoweringState>().addBlockScheduleable(op->getBlock(),
479 rewriter.setInsertionPointToEnd(group.getBodyBlock());
480 rewriter.create<calyx::AssignOp>(loc, opPipe.getLeft(), op.getLhs());
481 rewriter.create<calyx::AssignOp>(loc, opPipe.getRight(), op.getRhs());
483 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), out);
485 rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), opPipe.getDone());
490 rewriter.create<calyx::AssignOp>(
491 loc, opPipe.getGo(), c1,
494 rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());
498 op.getResult().replaceAllUsesWith(reg.getOut());
500 if (isa<calyx::AddFOpIEEE754>(opPipe)) {
501 auto opFOp = cast<calyx::AddFOpIEEE754>(opPipe);
503 if (isa<arith::AddFOp>(op)) {
504 subOp = createConstant(loc, rewriter,
getComponent(), 1,
507 subOp = createConstant(loc, rewriter,
getComponent(), 1,
510 rewriter.create<calyx::AssignOp>(loc, opFOp.getSubOp(), subOp);
514 getState<ComponentLoweringState>().registerEvaluatingGroup(out, group);
515 getState<ComponentLoweringState>().registerEvaluatingGroup(opPipe.getLeft(),
517 getState<ComponentLoweringState>().registerEvaluatingGroup(
518 opPipe.getRight(), group);
526 calyx::GroupInterface group,
528 Operation::operand_range addressValues)
const {
529 IRRewriter::InsertionGuard guard(rewriter);
530 rewriter.setInsertionPointToEnd(group.getBody());
531 auto addrPorts = memoryInterface.
addrPorts();
532 if (addressValues.empty()) {
534 addrPorts.size() == 1 &&
535 "We expected a 1 dimensional memory of size 1 because there were no "
536 "address assignment values");
538 rewriter.create<calyx::AssignOp>(
542 assert(addrPorts.size() == addressValues.size() &&
543 "Mismatch between number of address ports of the provided memory "
544 "and address assignment values");
545 for (
auto address : enumerate(addressValues))
546 rewriter.create<calyx::AssignOp>(loc, addrPorts[address.index()],
552 Value signal,
bool invert,
553 StringRef nameSuffix,
554 calyx::CompareFOpIEEE754 calyxCmpFOp,
555 calyx::GroupOp group)
const {
556 Location loc = calyxCmpFOp.getLoc();
557 IntegerType one = rewriter.getI1Type();
559 OpBuilder builder(group->getRegion(0));
560 auto reg = createRegister(
561 loc, rewriter, component, 1,
562 getState<ComponentLoweringState>().getUniqueName(nameSuffix));
563 rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(),
564 calyxCmpFOp.getDone());
566 auto notLibOp = getState<ComponentLoweringState>()
567 .getNewLibraryOpInstance<calyx::NotLibOp>(
568 rewriter, loc, {one, one});
569 rewriter.create<calyx::AssignOp>(loc, notLibOp.getIn(), signal);
570 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), notLibOp.getOut());
571 getState<ComponentLoweringState>().registerEvaluatingGroup(
572 notLibOp.getOut(), group);
574 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), signal);
580 memref::LoadOp loadOp)
const {
581 Value memref = loadOp.getMemref();
582 auto memoryInterface =
583 getState<ComponentLoweringState>().getMemoryInterface(memref);
584 auto group = createGroupForOp<calyx::GroupOp>(rewriter, loadOp);
586 loadOp.getIndices());
588 rewriter.setInsertionPointToEnd(group.getBodyBlock());
593 createConstant(loadOp.getLoc(), rewriter,
getComponent(), 1, 1);
594 if (memoryInterface.readEnOpt().has_value()) {
597 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), memoryInterface.readEn(),
599 regWriteEn = memoryInterface.done();
606 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(),
607 memoryInterface.done());
617 res = loadOp.getResult();
619 }
else if (memoryInterface.contentEnOpt().has_value()) {
624 rewriter.create<calyx::AssignOp>(loadOp.getLoc(),
625 memoryInterface.contentEn(), oneI1);
626 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), memoryInterface.writeEn(),
628 regWriteEn = memoryInterface.done();
635 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(),
636 memoryInterface.done());
646 res = loadOp.getResult();
658 auto reg = createRegister(
660 loadOp.getMemRefType().getElementTypeBitWidth(),
661 getState<ComponentLoweringState>().getUniqueName(
"load"));
662 rewriter.setInsertionPointToEnd(group.getBodyBlock());
663 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), reg.getIn(),
664 memoryInterface.readData());
665 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), reg.getWriteEn(),
667 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(), reg.getDone());
668 loadOp.getResult().replaceAllUsesWith(reg.getOut());
672 getState<ComponentLoweringState>().registerEvaluatingGroup(res, group);
673 getState<ComponentLoweringState>().addBlockScheduleable(loadOp->getBlock(),
679 memref::StoreOp storeOp)
const {
680 auto memoryInterface = getState<ComponentLoweringState>().getMemoryInterface(
681 storeOp.getMemref());
682 auto group = createGroupForOp<calyx::GroupOp>(rewriter, storeOp);
686 getState<ComponentLoweringState>().addBlockScheduleable(storeOp->getBlock(),
689 storeOp.getIndices());
690 rewriter.setInsertionPointToEnd(group.getBodyBlock());
691 rewriter.create<calyx::AssignOp>(
692 storeOp.getLoc(), memoryInterface.writeData(), storeOp.getValueToStore());
693 rewriter.create<calyx::AssignOp>(
694 storeOp.getLoc(), memoryInterface.writeEn(),
695 createConstant(storeOp.getLoc(), rewriter,
getComponent(), 1, 1));
696 if (memoryInterface.contentEnOpt().has_value()) {
698 rewriter.create<calyx::AssignOp>(
699 storeOp.getLoc(), memoryInterface.contentEn(),
700 createConstant(storeOp.getLoc(), rewriter,
getComponent(), 1, 1));
702 rewriter.create<calyx::GroupDoneOp>(storeOp.getLoc(), memoryInterface.done());
709 Location loc = mul.getLoc();
710 Type width = mul.getResult().getType(), one = rewriter.getI1Type();
712 getState<ComponentLoweringState>()
713 .getNewLibraryOpInstance<calyx::MultPipeLibOp>(
714 rewriter, loc, {one, one, one, width, width, width, one});
715 return buildLibraryBinaryPipeOp<calyx::MultPipeLibOp>(
716 rewriter, mul, mulPipe,
722 Location loc = div.getLoc();
723 Type width = div.getResult().getType(), one = rewriter.getI1Type();
725 getState<ComponentLoweringState>()
726 .getNewLibraryOpInstance<calyx::DivUPipeLibOp>(
727 rewriter, loc, {one, one, one, width, width, width, one});
728 return buildLibraryBinaryPipeOp<calyx::DivUPipeLibOp>(
729 rewriter, div, divPipe,
735 Location loc = div.getLoc();
736 Type width = div.getResult().getType(), one = rewriter.getI1Type();
738 getState<ComponentLoweringState>()
739 .getNewLibraryOpInstance<calyx::DivSPipeLibOp>(
740 rewriter, loc, {one, one, one, width, width, width, one});
741 return buildLibraryBinaryPipeOp<calyx::DivSPipeLibOp>(
742 rewriter, div, divPipe,
748 Location loc = rem.getLoc();
749 Type width = rem.getResult().getType(), one = rewriter.getI1Type();
751 getState<ComponentLoweringState>()
752 .getNewLibraryOpInstance<calyx::RemUPipeLibOp>(
753 rewriter, loc, {one, one, one, width, width, width, one});
754 return buildLibraryBinaryPipeOp<calyx::RemUPipeLibOp>(
755 rewriter, rem, remPipe,
761 Location loc = rem.getLoc();
762 Type width = rem.getResult().getType(), one = rewriter.getI1Type();
764 getState<ComponentLoweringState>()
765 .getNewLibraryOpInstance<calyx::RemSPipeLibOp>(
766 rewriter, loc, {one, one, one, width, width, width, one});
767 return buildLibraryBinaryPipeOp<calyx::RemSPipeLibOp>(
768 rewriter, rem, remPipe,
774 Location loc = addf.getLoc();
775 IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
776 five = rewriter.getIntegerType(5),
777 width = rewriter.getIntegerType(
778 addf.getType().getIntOrFloatBitWidth());
780 getState<ComponentLoweringState>()
781 .getNewLibraryOpInstance<calyx::AddFOpIEEE754>(
783 {one, one, one, one, one, width, width, three, width, five, one});
784 return buildLibraryBinaryPipeOp<calyx::AddFOpIEEE754>(rewriter, addf, addFOp,
790 Location loc = mulf.getLoc();
791 IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
792 five = rewriter.getIntegerType(5),
793 width = rewriter.getIntegerType(
794 mulf.getType().getIntOrFloatBitWidth());
796 getState<ComponentLoweringState>()
797 .getNewLibraryOpInstance<calyx::MulFOpIEEE754>(
799 {one, one, one, one, width, width, three, width, five, one});
800 return buildLibraryBinaryPipeOp<calyx::MulFOpIEEE754>(rewriter, mulf, mulFOp,
806 Location loc = cmpf.getLoc();
807 IntegerType one = rewriter.getI1Type(), five = rewriter.getIntegerType(5),
808 width = rewriter.getIntegerType(
809 cmpf.getLhs().getType().getIntOrFloatBitWidth());
810 auto calyxCmpFOp = getState<ComponentLoweringState>()
811 .getNewLibraryOpInstance<calyx::CompareFOpIEEE754>(
813 {one, one, one, width, width, one, one, one, one,
820 using CombLogic = PredicateInfo::CombLogic;
821 using Port = PredicateInfo::InputPorts::Port;
823 if (info.logic == CombLogic::None) {
824 if (cmpf.getPredicate() == CmpFPredicate::AlwaysTrue) {
825 cmpf.getResult().replaceAllUsesWith(c1);
829 if (cmpf.getPredicate() == CmpFPredicate::AlwaysFalse) {
830 cmpf.getResult().replaceAllUsesWith(c0);
836 StringRef opName = cmpf.getOperationName().split(
".").second;
839 getState<ComponentLoweringState>().getUniqueName(opName));
842 auto group = createGroupForOp<calyx::GroupOp>(rewriter, cmpf);
843 OpBuilder builder(group->getRegion(0));
844 getState<ComponentLoweringState>().addBlockScheduleable(cmpf->getBlock(),
847 rewriter.setInsertionPointToEnd(group.getBodyBlock());
848 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getLeft(), cmpf.getLhs());
849 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getRight(), cmpf.getRhs());
851 bool signalingFlag =
false;
852 switch (cmpf.getPredicate()) {
853 case CmpFPredicate::UGT:
854 case CmpFPredicate::UGE:
855 case CmpFPredicate::ULT:
856 case CmpFPredicate::ULE:
857 case CmpFPredicate::OGT:
858 case CmpFPredicate::OGE:
859 case CmpFPredicate::OLT:
860 case CmpFPredicate::OLE:
861 signalingFlag =
true;
863 case CmpFPredicate::UEQ:
864 case CmpFPredicate::UNE:
865 case CmpFPredicate::OEQ:
866 case CmpFPredicate::ONE:
867 case CmpFPredicate::UNO:
868 case CmpFPredicate::ORD:
869 case CmpFPredicate::AlwaysTrue:
870 case CmpFPredicate::AlwaysFalse:
871 signalingFlag =
false;
877 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getSignaling(),
878 signalingFlag ? c1 : c0);
881 SmallVector<calyx::RegisterOp> inputRegs;
882 for (
const auto &input : info.inputPorts) {
884 switch (input.port) {
886 signal = calyxCmpFOp.getEq();
890 signal = calyxCmpFOp.getGt();
894 signal = calyxCmpFOp.getLt();
897 case Port::Unordered: {
898 signal = calyxCmpFOp.getUnordered();
902 std::string nameSuffix =
903 (input.port == PredicateInfo::InputPorts::Port::Unordered)
907 nameSuffix, calyxCmpFOp, group);
908 inputRegs.push_back(signalReg);
912 Value outputValue, doneValue;
913 switch (info.logic) {
914 case CombLogic::None: {
916 outputValue = inputRegs[0].getOut();
917 doneValue = inputRegs[0].getOut();
920 case CombLogic::And: {
921 auto outputLibOp = getState<ComponentLoweringState>()
922 .getNewLibraryOpInstance<calyx::AndLibOp>(
923 rewriter, loc, {one, one, one});
924 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
925 inputRegs[0].getOut());
926 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
927 inputRegs[1].getOut());
929 outputValue = outputLibOp.getOut();
932 case CombLogic::Or: {
933 auto outputLibOp = getState<ComponentLoweringState>()
934 .getNewLibraryOpInstance<calyx::OrLibOp>(
935 rewriter, loc, {one, one, one});
936 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
937 inputRegs[0].getOut());
938 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
939 inputRegs[1].getOut());
941 outputValue = outputLibOp.getOut();
946 if (info.logic != CombLogic::None) {
947 auto doneLibOp = getState<ComponentLoweringState>()
948 .getNewLibraryOpInstance<calyx::AndLibOp>(
949 rewriter, loc, {one, one, one});
950 rewriter.create<calyx::AssignOp>(loc, doneLibOp.getLeft(),
951 inputRegs[0].getDone());
952 rewriter.create<calyx::AssignOp>(loc, doneLibOp.getRight(),
953 inputRegs[1].getDone());
954 doneValue = doneLibOp.getOut();
958 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), outputValue);
959 rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), doneValue);
962 rewriter.create<calyx::AssignOp>(
963 loc, calyxCmpFOp.getGo(), c1,
965 rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());
967 cmpf.getResult().replaceAllUsesWith(reg.getOut());
970 getState<ComponentLoweringState>().registerEvaluatingGroup(outputValue,
972 getState<ComponentLoweringState>().registerEvaluatingGroup(doneValue, group);
973 getState<ComponentLoweringState>().registerEvaluatingGroup(
974 calyxCmpFOp.getLeft(), group);
975 getState<ComponentLoweringState>().registerEvaluatingGroup(
976 calyxCmpFOp.getRight(), group);
981template <
typename TAllocOp>
983 PatternRewriter &rewriter, TAllocOp allocOp) {
984 rewriter.setInsertionPointToStart(
986 MemRefType memtype = allocOp.getType();
987 SmallVector<int64_t> addrSizes;
988 SmallVector<int64_t> sizes;
989 for (int64_t dim : memtype.getShape()) {
990 sizes.push_back(dim);
995 if (sizes.empty() && addrSizes.empty()) {
997 addrSizes.push_back(1);
999 auto memoryOp = rewriter.create<calyx::SeqMemoryOp>(
1001 memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
1005 memoryOp->setAttr(
"external",
1006 IntegerAttr::get(rewriter.getI1Type(), llvm::APInt(1, 1)));
1010 unsigned elmTyBitWidth = memtype.getElementTypeBitWidth();
1011 assert(elmTyBitWidth <= 64 &&
"element bitwidth should not exceed 64");
1012 bool isFloat = !memtype.getElementType().isInteger();
1014 auto shape = allocOp.getType().getShape();
1016 std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
1023 if (!(shape.size() <= 1 || totalSize <= 1)) {
1024 allocOp.emitError(
"input memory dimension must be empty or one.");
1028 std::vector<uint64_t> flattenedVals(totalSize, 0);
1029 if (isa<memref::GetGlobalOp>(allocOp)) {
1030 auto getGlobalOp = cast<memref::GetGlobalOp>(allocOp);
1031 auto *symbolTableOp =
1032 getGlobalOp->template getParentWithTrait<mlir::OpTrait::SymbolTable>();
1033 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
1034 SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
1036 auto cstAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(
1037 globalOp.getConstantInitValue());
1039 for (
auto attr : cstAttr.template getValues<Attribute>()) {
1040 assert((isa<mlir::FloatAttr, mlir::IntegerAttr>(attr)) &&
1041 "memory attributes must be float or int");
1042 if (
auto fltAttr = dyn_cast<mlir::FloatAttr>(attr)) {
1043 flattenedVals[sizeCount++] =
1044 bit_cast<uint64_t>(fltAttr.getValueAsDouble());
1046 auto intAttr = dyn_cast<mlir::IntegerAttr>(attr);
1047 APInt value = intAttr.getValue();
1048 flattenedVals[sizeCount++] = *value.getRawData();
1052 rewriter.eraseOp(globalOp);
1055 llvm::json::Array result;
1056 result.reserve(std::max(
static_cast<int>(shape.size()), 1));
1058 Type elemType = memtype.getElementType();
1060 !elemType.isSignlessInteger() && !elemType.isUnsignedInteger();
1061 for (uint64_t bitValue : flattenedVals) {
1062 llvm::json::Value value = 0;
1066 value = bit_cast<double>(bitValue);
1068 APInt apInt(elmTyBitWidth, bitValue, isSigned,
1073 value =
static_cast<int64_t
>(apInt.getSExtValue());
1075 value = apInt.getZExtValue();
1077 result.push_back(std::move(value));
1080 componentState.
setDataField(memoryOp.getName(), result);
1081 std::string numType =
1082 memtype.getElementType().isInteger() ?
"bitnum" :
"ieee754_float";
1083 componentState.
setFormat(memoryOp.getName(), numType, isSigned,
1090 memref::AllocOp allocOp)
const {
1091 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
1095 memref::AllocaOp allocOp)
const {
1096 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
1100 memref::GetGlobalOp getGlobalOp)
const {
1101 return buildAllocOp(getState<ComponentLoweringState>(), rewriter,
1106 scf::YieldOp yieldOp)
const {
1107 if (yieldOp.getOperands().empty()) {
1108 if (
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
1112 auto inductionReg = getState<ComponentLoweringState>().getForLoopIterReg(
1115 Type regWidth = inductionReg.getOut().getType();
1117 SmallVector<Type> types(3, regWidth);
1118 auto addOp = getState<ComponentLoweringState>()
1119 .getNewLibraryOpInstance<calyx::AddLibOp>(
1120 rewriter, forOp.getLoc(), types);
1122 auto directions = addOp.portDirections();
1124 SmallVector<Value, 2> opInputPorts;
1126 for (
auto dir : enumerate(directions)) {
1127 switch (dir.value()) {
1129 opInputPorts.push_back(addOp.getResult(dir.index()));
1133 opOutputPort = addOp.getResult(dir.index());
1141 getState<ComponentLoweringState>().getComponentOp();
1142 SmallVector<StringRef, 4> groupIdentifier = {
1143 "incr", getState<ComponentLoweringState>().getUniqueName(forOp),
1144 "induction",
"var"};
1145 auto groupOp = calyx::createGroup<calyx::GroupOp>(
1147 llvm::join(groupIdentifier,
"_"));
1148 rewriter.setInsertionPointToEnd(groupOp.getBodyBlock());
1151 Value leftOp = opInputPorts.front();
1152 rewriter.create<calyx::AssignOp>(forOp.getLoc(), leftOp,
1153 inductionReg.getOut());
1155 Value rightOp = opInputPorts.back();
1156 rewriter.create<calyx::AssignOp>(
1157 forOp.getLoc(), rightOp,
1158 createConstant(forOp->getLoc(), rewriter,
componentOp,
1159 regWidth.getIntOrFloatBitWidth(),
1160 forOp.getConstantStep().value().getSExtValue()));
1162 buildAssignmentsForRegisterWrite(rewriter, groupOp,
componentOp,
1163 inductionReg, opOutputPort);
1165 getState<ComponentLoweringState>().setForLoopLatchGroup(forOpInterface,
1167 getState<ComponentLoweringState>().registerEvaluatingGroup(opOutputPort,
1171 if (
auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp()))
1174 return yieldOp.getOperation()->emitError()
1175 <<
"Unsupported empty yieldOp outside ForOp or IfOp.";
1178 if (dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
1179 return yieldOp.getOperation()->emitError()
1180 <<
"Currently do not support non-empty yield operations inside for "
1181 "loops. Run --scf-for-to-while before running --scf-to-calyx.";
1184 if (
auto whileOp = dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
1188 getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
1189 rewriter, whileOpInterface,
1190 getState<ComponentLoweringState>().getComponentOp(),
1191 getState<ComponentLoweringState>().getUniqueName(whileOp) +
1193 yieldOp->getOpOperands());
1194 getState<ComponentLoweringState>().setWhileLoopLatchGroup(whileOpInterface,
1199 if (
auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
1200 auto resultRegs = getState<ComponentLoweringState>().getResultRegs(ifOp);
1202 if (yieldOp->getParentRegion() == &ifOp.getThenRegion()) {
1203 auto thenGroup = getState<ComponentLoweringState>().getThenGroup(ifOp);
1204 for (
auto op : enumerate(yieldOp.getOperands())) {
1206 getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
1207 buildAssignmentsForRegisterWrite(
1208 rewriter, thenGroup,
1209 getState<ComponentLoweringState>().getComponentOp(), resultReg,
1211 getState<ComponentLoweringState>().registerEvaluatingGroup(
1212 ifOp.getResult(op.index()), thenGroup);
1216 if (!ifOp.getElseRegion().empty() &&
1217 (yieldOp->getParentRegion() == &ifOp.getElseRegion())) {
1218 auto elseGroup = getState<ComponentLoweringState>().getElseGroup(ifOp);
1219 for (
auto op : enumerate(yieldOp.getOperands())) {
1221 getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
1222 buildAssignmentsForRegisterWrite(
1223 rewriter, elseGroup,
1224 getState<ComponentLoweringState>().getComponentOp(), resultReg,
1226 getState<ComponentLoweringState>().registerEvaluatingGroup(
1227 ifOp.getResult(op.index()), elseGroup);
1235 BranchOpInterface brOp)
const {
1240 Block *srcBlock = brOp->getBlock();
1241 for (
auto succBlock : enumerate(brOp->getSuccessors())) {
1242 auto succOperands = brOp.getSuccessorOperands(succBlock.index());
1243 if (succOperands.empty())
1248 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter,
getComponent(),
1249 brOp.getLoc(), groupName);
1251 auto dstBlockArgRegs =
1252 getState<ComponentLoweringState>().getBlockArgRegs(succBlock.value());
1254 for (
auto arg : enumerate(succOperands.getForwardedOperands())) {
1255 auto reg = dstBlockArgRegs[arg.index()];
1258 getState<ComponentLoweringState>().getComponentOp(), reg,
1263 getState<ComponentLoweringState>().addBlockArgGroup(
1264 srcBlock, succBlock.value(), groupOp);
1272 ReturnOp retOp)
const {
1273 if (retOp.getNumOperands() == 0)
1276 std::string groupName =
1277 getState<ComponentLoweringState>().getUniqueName(
"ret_assign");
1278 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter,
getComponent(),
1279 retOp.getLoc(), groupName);
1280 for (
auto op : enumerate(retOp.getOperands())) {
1281 auto reg = getState<ComponentLoweringState>().getReturnReg(op.index());
1283 rewriter, groupOp, getState<ComponentLoweringState>().getComponentOp(),
1287 getState<ComponentLoweringState>().addBlockScheduleable(retOp->getBlock(),
1293 arith::ConstantOp constOp)
const {
1294 if (isa<IntegerType>(constOp.getType())) {
1303 std::string name = getState<ComponentLoweringState>().getUniqueName(
"cst");
1304 auto floatAttr = cast<FloatAttr>(constOp.getValueAttr());
1306 rewriter.getIntegerType(floatAttr.getType().getIntOrFloatBitWidth());
1307 auto calyxConstOp = rewriter.create<calyx::ConstantOp>(
1308 constOp.getLoc(), name, floatAttr, intType);
1311 rewriter.replaceAllUsesWith(constOp, calyxConstOp.getOut());
1319 return buildLibraryOp<calyx::CombGroupOp, calyx::AddLibOp>(rewriter, op);
1323 return buildLibraryOp<calyx::CombGroupOp, calyx::SubLibOp>(rewriter, op);
1327 return buildLibraryOp<calyx::CombGroupOp, calyx::RshLibOp>(rewriter, op);
1331 return buildLibraryOp<calyx::CombGroupOp, calyx::SrshLibOp>(rewriter, op);
1335 return buildLibraryOp<calyx::CombGroupOp, calyx::LshLibOp>(rewriter, op);
1339 return buildLibraryOp<calyx::CombGroupOp, calyx::AndLibOp>(rewriter, op);
1343 return buildLibraryOp<calyx::CombGroupOp, calyx::OrLibOp>(rewriter, op);
1347 return buildLibraryOp<calyx::CombGroupOp, calyx::XorLibOp>(rewriter, op);
1350 SelectOp op)
const {
1351 return buildLibraryOp<calyx::CombGroupOp, calyx::MuxLibOp>(rewriter, op);
1356 switch (op.getPredicate()) {
1357 case CmpIPredicate::eq:
1358 return buildLibraryOp<calyx::CombGroupOp, calyx::EqLibOp>(rewriter, op);
1359 case CmpIPredicate::ne:
1360 return buildLibraryOp<calyx::CombGroupOp, calyx::NeqLibOp>(rewriter, op);
1361 case CmpIPredicate::uge:
1362 return buildLibraryOp<calyx::CombGroupOp, calyx::GeLibOp>(rewriter, op);
1363 case CmpIPredicate::ult:
1364 return buildLibraryOp<calyx::CombGroupOp, calyx::LtLibOp>(rewriter, op);
1365 case CmpIPredicate::ugt:
1366 return buildLibraryOp<calyx::CombGroupOp, calyx::GtLibOp>(rewriter, op);
1367 case CmpIPredicate::ule:
1368 return buildLibraryOp<calyx::CombGroupOp, calyx::LeLibOp>(rewriter, op);
1369 case CmpIPredicate::sge:
1370 return buildLibraryOp<calyx::CombGroupOp, calyx::SgeLibOp>(rewriter, op);
1371 case CmpIPredicate::slt:
1372 return buildLibraryOp<calyx::CombGroupOp, calyx::SltLibOp>(rewriter, op);
1373 case CmpIPredicate::sgt:
1374 return buildLibraryOp<calyx::CombGroupOp, calyx::SgtLibOp>(rewriter, op);
1375 case CmpIPredicate::sle:
1376 return buildLibraryOp<calyx::CombGroupOp, calyx::SleLibOp>(rewriter, op);
1378 llvm_unreachable(
"unsupported comparison predicate");
1381 TruncIOp op)
const {
1382 return buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
1383 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1387 return buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
1388 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1393 return buildLibraryOp<calyx::CombGroupOp, calyx::ExtSILibOp>(
1394 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1398 IndexCastOp op)
const {
1401 unsigned targetBits = targetType.getIntOrFloatBitWidth();
1402 unsigned sourceBits = sourceType.getIntOrFloatBitWidth();
1403 LogicalResult res = success();
1405 if (targetBits == sourceBits) {
1408 op.getResult().replaceAllUsesWith(op.getOperand());
1411 if (sourceBits > targetBits)
1412 res = buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
1413 rewriter, op, {sourceType}, {targetType});
1415 res = buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
1416 rewriter, op, {sourceType}, {targetType});
1418 rewriter.eraseOp(op);
1423 scf::WhileOp whileOp)
const {
1427 getState<ComponentLoweringState>().addBlockScheduleable(
1433 scf::ForOp forOp)
const {
1439 std::optional<uint64_t> bound = scfForOp.
getBound();
1440 if (!bound.has_value()) {
1442 <<
"Loop bound not statically known. Should "
1443 "transform into while loop using `--scf-for-to-while` before "
1444 "running --lower-scf-to-calyx.";
1446 getState<ComponentLoweringState>().addBlockScheduleable(
1455 scf::IfOp ifOp)
const {
1456 getState<ComponentLoweringState>().addBlockScheduleable(
1462 scf::ReduceOp reduceOp)
const {
1470 scf::ParallelOp parOp)
const {
1471 getState<ComponentLoweringState>().addBlockScheduleable(
1478 scf::ExecuteRegionOp executeRegionOp)
const {
1486 CallOp callOp)
const {
1488 calyx::InstanceOp instanceOp =
1489 getState<ComponentLoweringState>().getInstance(instanceName);
1490 SmallVector<Value, 4> outputPorts;
1491 auto portInfos = instanceOp.getReferencedComponent().getPortInfo();
1492 for (
auto [idx, portInfo] : enumerate(portInfos)) {
1494 outputPorts.push_back(instanceOp.getResult(idx));
1498 for (
auto [idx, result] : llvm::enumerate(callOp.getResults()))
1499 rewriter.replaceAllUsesWith(result, outputPorts[idx]);
1503 getState<ComponentLoweringState>().addBlockScheduleable(
1517 using OpRewritePattern::OpRewritePattern;
1520 PatternRewriter &rewriter)
const override {
1522 TypeRange yieldTypes = execOp.getResultTypes();
1526 rewriter.setInsertionPointAfter(execOp);
1527 auto *sinkBlock = rewriter.splitBlock(
1529 execOp.getOperation()->getIterator()->getNextNode()->getIterator());
1530 sinkBlock->addArguments(
1532 SmallVector<Location, 4>(yieldTypes.size(), rewriter.getUnknownLoc()));
1533 for (
auto res : enumerate(execOp.getResults()))
1534 res.value().replaceAllUsesWith(sinkBlock->getArgument(res.index()));
1538 make_early_inc_range(execOp.getRegion().getOps<scf::YieldOp>())) {
1539 rewriter.setInsertionPointAfter(yieldOp);
1540 rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, sinkBlock,
1541 yieldOp.getOperands());
1545 auto *preBlock = execOp->getBlock();
1546 auto *execOpEntryBlock = &execOp.getRegion().front();
1547 auto *postBlock = execOp->getBlock()->splitBlock(execOp);
1548 rewriter.inlineRegionBefore(execOp.getRegion(), postBlock);
1549 rewriter.mergeBlocks(postBlock, preBlock);
1550 rewriter.eraseOp(execOp);
1553 rewriter.mergeBlocks(execOpEntryBlock, preBlock);
1561 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1565 PatternRewriter &rewriter)
const override {
1568 DenseMap<Value, unsigned> funcOpArgRewrites;
1572 DenseMap<unsigned, unsigned> funcOpResultMapping;
1580 DenseMap<Value, std::pair<unsigned, unsigned>> extMemoryCompPortIndices;
1584 SmallVector<calyx::PortInfo> inPorts, outPorts;
1585 FunctionType funcType = funcOp.getFunctionType();
1586 for (
auto arg : enumerate(funcOp.getArguments())) {
1587 if (!isa<MemRefType>(arg.value().getType())) {
1590 if (
auto portNameAttr = funcOp.getArgAttrOfType<StringAttr>(
1592 inName = portNameAttr.str();
1594 inName =
"in" + std::to_string(arg.index());
1595 funcOpArgRewrites[arg.value()] = inPorts.size();
1597 rewriter.getStringAttr(inName),
1600 DictionaryAttr::get(rewriter.getContext(), {})});
1603 for (
auto res : enumerate(funcType.getResults())) {
1604 std::string resName;
1605 if (
auto portNameAttr = funcOp.getResultAttrOfType<StringAttr>(
1607 resName = portNameAttr.str();
1609 resName =
"out" + std::to_string(res.index());
1610 funcOpResultMapping[res.index()] = outPorts.size();
1613 rewriter.getStringAttr(resName),
1615 DictionaryAttr::get(rewriter.getContext(), {})});
1620 auto ports = inPorts;
1621 llvm::append_range(ports, outPorts);
1625 auto compOp = rewriter.create<calyx::ComponentOp>(
1626 funcOp.getLoc(), rewriter.getStringAttr(funcOp.getSymName()), ports);
1628 std::string funcName =
"func_" + funcOp.getSymName().str();
1629 rewriter.modifyOpInPlace(funcOp, [&]() { funcOp.setSymName(funcName); });
1634 compOp->setAttr(
"toplevel", rewriter.getUnitAttr());
1641 unsigned extMemCounter = 0;
1642 for (
auto arg : enumerate(funcOp.getArguments())) {
1643 if (isa<MemRefType>(arg.value().getType())) {
1644 std::string memName =
1645 llvm::join_items(
"_",
"arg_mem", std::to_string(extMemCounter++));
1647 rewriter.setInsertionPointToStart(compOp.getBodyBlock());
1648 MemRefType memtype = cast<MemRefType>(arg.value().getType());
1649 SmallVector<int64_t> addrSizes;
1650 SmallVector<int64_t> sizes;
1651 for (int64_t dim : memtype.getShape()) {
1652 sizes.push_back(dim);
1655 if (sizes.empty() && addrSizes.empty()) {
1657 addrSizes.push_back(1);
1659 auto memOp = rewriter.create<calyx::SeqMemoryOp>(
1660 funcOp.getLoc(), memName,
1661 memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
1664 compState->registerMemoryInterface(arg.value(),
1670 for (
auto &mapping : funcOpArgRewrites)
1671 mapping.getFirst().replaceAllUsesWith(
1672 compOp.getArgument(mapping.getSecond()));
1683 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1687 PatternRewriter &rewriter)
const override {
1688 LogicalResult res = success();
1689 funcOp.walk([&](Operation *op) {
1691 if (!isa<scf::WhileOp>(op))
1692 return WalkResult::advance();
1694 auto scfWhileOp = cast<scf::WhileOp>(op);
1697 getState<ComponentLoweringState>().setUniqueName(whileOp.
getOperation(),
1707 enumerate(scfWhileOp.getBefore().front().getArguments())) {
1708 auto condOp = scfWhileOp.getConditionOp().getArgs()[barg.index()];
1709 if (barg.value() != condOp) {
1713 <<
"do-while loops not supported; expected iter-args to "
1714 "remain untransformed in the 'before' region of the "
1716 return WalkResult::interrupt();
1725 for (
auto arg : enumerate(whileOp.
getBodyArgs())) {
1726 std::string name = getState<ComponentLoweringState>()
1729 "_arg" + std::to_string(arg.index());
1731 createRegister(arg.value().getLoc(), rewriter,
getComponent(),
1732 arg.value().getType().getIntOrFloatBitWidth(), name);
1733 getState<ComponentLoweringState>().addWhileLoopIterReg(whileOp, reg,
1735 arg.value().replaceAllUsesWith(reg.getOut());
1739 ->getArgument(arg.index())
1740 .replaceAllUsesWith(reg.getOut());
1744 SmallVector<calyx::GroupOp> initGroups;
1745 auto numOperands = whileOp.
getOperation()->getNumOperands();
1746 for (
size_t i = 0; i < numOperands; ++i) {
1748 getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
1750 getState<ComponentLoweringState>().getComponentOp(),
1751 getState<ComponentLoweringState>().getUniqueName(
1753 "_init_" + std::to_string(i),
1755 initGroups.push_back(initGroupOp);
1758 getState<ComponentLoweringState>().setWhileLoopInitGroups(whileOp,
1761 return WalkResult::advance();
1771 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1775 PatternRewriter &rewriter)
const override {
1776 LogicalResult res = success();
1777 funcOp.walk([&](Operation *op) {
1779 if (!isa<scf::ForOp>(op))
1780 return WalkResult::advance();
1782 auto scfForOp = cast<scf::ForOp>(op);
1785 getState<ComponentLoweringState>().setUniqueName(forOp.
getOperation(),
1790 auto inductionVar = forOp.
getOperation().getInductionVar();
1791 SmallVector<std::string, 3> inductionVarIdentifiers = {
1792 getState<ComponentLoweringState>()
1795 "induction",
"var"};
1796 std::string name = llvm::join(inductionVarIdentifiers,
"_");
1798 createRegister(inductionVar.getLoc(), rewriter,
getComponent(),
1799 inductionVar.getType().getIntOrFloatBitWidth(), name);
1800 getState<ComponentLoweringState>().addForLoopIterReg(forOp, reg, 0);
1801 inductionVar.replaceAllUsesWith(reg.getOut());
1805 getState<ComponentLoweringState>().getComponentOp();
1806 SmallVector<calyx::GroupOp> initGroups;
1807 SmallVector<std::string, 4> groupIdentifiers = {
1809 getState<ComponentLoweringState>()
1812 "induction",
"var"};
1813 std::string groupName = llvm::join(groupIdentifiers,
"_");
1814 auto groupOp = calyx::createGroup<calyx::GroupOp>(
1816 buildAssignmentsForRegisterWrite(rewriter, groupOp,
componentOp, reg,
1818 initGroups.push_back(groupOp);
1819 getState<ComponentLoweringState>().setForLoopInitGroups(forOp,
1822 return WalkResult::advance();
1829 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1833 PatternRewriter &rewriter)
const override {
1834 LogicalResult res = success();
1835 funcOp.walk([&](Operation *op) {
1836 if (!isa<scf::IfOp>(op))
1837 return WalkResult::advance();
1839 auto scfIfOp = cast<scf::IfOp>(op);
1844 if (scfIfOp.getResults().empty())
1845 return WalkResult::advance();
1848 getState<ComponentLoweringState>().getComponentOp();
1850 std::string thenGroupName =
1851 getState<ComponentLoweringState>().getUniqueName(
"then_br");
1852 auto thenGroupOp = calyx::createGroup<calyx::GroupOp>(
1853 rewriter,
componentOp, scfIfOp.getLoc(), thenGroupName);
1854 getState<ComponentLoweringState>().setThenGroup(scfIfOp, thenGroupOp);
1856 if (!scfIfOp.getElseRegion().empty()) {
1857 std::string elseGroupName =
1858 getState<ComponentLoweringState>().getUniqueName(
"else_br");
1859 auto elseGroupOp = calyx::createGroup<calyx::GroupOp>(
1860 rewriter,
componentOp, scfIfOp.getLoc(), elseGroupName);
1861 getState<ComponentLoweringState>().setElseGroup(scfIfOp, elseGroupOp);
1864 for (
auto ifOpRes : scfIfOp.getResults()) {
1865 auto reg = createRegister(
1867 ifOpRes.getType().getIntOrFloatBitWidth(),
1868 getState<ComponentLoweringState>().getUniqueName(
"if_res"));
1869 getState<ComponentLoweringState>().setResultRegs(
1870 scfIfOp, reg, ifOpRes.getResultNumber());
1873 return WalkResult::advance();
1880 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1884 PatternRewriter &rewriter)
const override {
1885 WalkResult walkResult = funcOp.walk([&](scf::ParallelOp scfParOp) {
1886 if (!scfParOp.getResults().empty()) {
1888 "Reduce operations in scf.parallel is not supported yet");
1889 return WalkResult::interrupt();
1893 return WalkResult::interrupt();
1895 return WalkResult::advance();
1898 return walkResult.wasInterrupted() ? failure() : success();
1905 scf::ParallelOp scfParOp)
const {
1906 assert(scfParOp.getLoopSteps() &&
"Parallel loop must have steps");
1907 auto *body = scfParOp.getBody();
1908 auto parOpIVs = scfParOp.getInductionVars();
1909 auto steps = scfParOp.getStep();
1910 auto lowerBounds = scfParOp.getLowerBound();
1911 auto upperBounds = scfParOp.getUpperBound();
1912 rewriter.setInsertionPointAfter(scfParOp);
1913 scf::ParallelOp newParOp = scfParOp.cloneWithoutRegions();
1914 auto loc = newParOp.getLoc();
1915 rewriter.insert(newParOp);
1916 OpBuilder insideBuilder(newParOp);
1917 auto ®ion = newParOp.getRegion();
1918 auto *newParBodyBlock = ®ion.emplaceBlock();
1921 SmallVector<int64_t> lbVals, ubVals, stepVals;
1922 for (
auto lb : lowerBounds) {
1923 auto lbOp = lb.getDefiningOp<arith::ConstantIndexOp>();
1925 "Lower bound must be a statically computable constant index");
1926 lbVals.push_back(lbOp.value());
1928 for (
auto ub : upperBounds) {
1929 auto ubOp = ub.getDefiningOp<arith::ConstantIndexOp>();
1931 "Upper bound must be a statically computable constant index");
1932 ubVals.push_back(ubOp.value());
1934 for (
auto step : steps) {
1935 auto stepOp = step.getDefiningOp<arith::ConstantIndexOp>();
1936 assert(stepOp &&
"Step must be a statically computable constant index");
1937 stepVals.push_back(stepOp.value());
1941 SmallVector<int64_t> indices = lbVals;
1944 insideBuilder.setInsertionPointToEnd(newParBodyBlock);
1948 insideBuilder.create<scf::ExecuteRegionOp>(loc, TypeRange{});
1949 auto &execRegion = execRegionOp.getRegion();
1950 Block *execBlock = &execRegion.emplaceBlock();
1951 OpBuilder regionBuilder(execRegionOp);
1955 IRMapping operandMap;
1957 regionBuilder.setInsertionPointToEnd(execBlock);
1959 for (
unsigned i = 0; i < indices.size(); ++i) {
1961 regionBuilder.create<arith::ConstantIndexOp>(loc, indices[i]);
1962 operandMap.map(parOpIVs[i], ivConstant);
1965 for (
auto it = body->begin(); it != std::prev(body->end()); ++it)
1966 regionBuilder.clone(*it, operandMap);
1969 regionBuilder.create<scf::ReduceOp>(loc);
1972 for (
int dim = indices.size() - 1; dim >= 0; --dim) {
1973 indices[dim] += stepVals[dim];
1974 if (indices[dim] < ubVals[dim])
1976 indices[dim] = lbVals[dim];
1985 rewriter.setInsertionPointToEnd(newParOp.getBody());
1986 rewriter.create<scf::ReduceOp>(newParOp.getLoc());
1988 rewriter.replaceOp(scfParOp, newParOp);
1999 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2003 PatternRewriter &rewriter)
const override {
2004 auto *entryBlock = &funcOp.getBlocks().front();
2005 rewriter.setInsertionPointToStart(
2007 auto topLevelSeqOp = rewriter.create<calyx::SeqOp>(funcOp.getLoc());
2008 DenseSet<Block *> path;
2010 nullptr, entryBlock);
2017 const DenseSet<Block *> &path,
2018 mlir::Block *parentCtrlBlock,
2019 mlir::Block *block)
const {
2020 auto compBlockScheduleables =
2021 getState<ComponentLoweringState>().getBlockScheduleables(block);
2022 auto loc = block->front().getLoc();
2024 if (compBlockScheduleables.size() > 1 &&
2025 !isa<scf::ParallelOp>(block->getParentOp())) {
2026 auto seqOp = rewriter.create<calyx::SeqOp>(loc);
2027 parentCtrlBlock = seqOp.getBodyBlock();
2030 for (
auto &group : compBlockScheduleables) {
2031 rewriter.setInsertionPointToEnd(parentCtrlBlock);
2032 if (
auto groupPtr = std::get_if<calyx::GroupOp>(&group); groupPtr) {
2033 rewriter.create<calyx::EnableOp>(groupPtr->getLoc(),
2034 groupPtr->getSymName());
2035 }
else if (
auto whileSchedPtr = std::get_if<WhileScheduleable>(&group);
2037 auto &whileOp = whileSchedPtr->whileOp;
2041 getState<ComponentLoweringState>().getWhileLoopInitGroups(whileOp),
2043 rewriter.setInsertionPointToEnd(whileCtrlOp.getBodyBlock());
2045 rewriter.create<calyx::SeqOp>(whileOp.getOperation()->getLoc());
2046 auto *whileBodyOpBlock = whileBodyOp.getBodyBlock();
2050 if (LogicalResult result =
2052 whileOp.getBodyBlock());
2057 rewriter.setInsertionPointToEnd(whileBodyOpBlock);
2058 calyx::GroupOp whileLatchGroup =
2059 getState<ComponentLoweringState>().getWhileLoopLatchGroup(whileOp);
2060 rewriter.create<calyx::EnableOp>(whileLatchGroup.getLoc(),
2061 whileLatchGroup.getName());
2062 }
else if (
auto *parSchedPtr = std::get_if<ParScheduleable>(&group)) {
2063 auto parOp = parSchedPtr->parOp;
2064 auto calyxParOp = rewriter.create<calyx::ParOp>(parOp.getLoc());
2066 WalkResult walkResult =
2067 parOp.walk([&](scf::ExecuteRegionOp execRegion) {
2068 rewriter.setInsertionPointToEnd(calyxParOp.getBodyBlock());
2069 auto seqOp = rewriter.create<calyx::SeqOp>(execRegion.getLoc());
2070 rewriter.setInsertionPointToEnd(seqOp.getBodyBlock());
2072 for (
auto &execBlock : execRegion.getRegion().getBlocks()) {
2074 rewriter, path, seqOp.getBodyBlock(), &execBlock);
2076 return WalkResult::interrupt();
2079 return WalkResult::advance();
2082 if (walkResult.wasInterrupted())
2085 }
else if (
auto *forSchedPtr = std::get_if<ForScheduleable>(&group);
2087 auto forOp = forSchedPtr->forOp;
2091 getState<ComponentLoweringState>().getForLoopInitGroups(forOp),
2092 forSchedPtr->bound, rewriter);
2093 rewriter.setInsertionPointToEnd(forCtrlOp.getBodyBlock());
2095 rewriter.create<calyx::SeqOp>(forOp.getOperation()->getLoc());
2096 auto *forBodyOpBlock = forBodyOp.getBodyBlock();
2099 if (LogicalResult res =
buildCFGControl(path, rewriter, forBodyOpBlock,
2100 block, forOp.getBodyBlock());
2105 rewriter.setInsertionPointToEnd(forBodyOpBlock);
2106 calyx::GroupOp forLatchGroup =
2107 getState<ComponentLoweringState>().getForLoopLatchGroup(forOp);
2108 rewriter.create<calyx::EnableOp>(forLatchGroup.getLoc(),
2109 forLatchGroup.getName());
2110 }
else if (
auto *ifSchedPtr = std::get_if<IfScheduleable>(&group);
2112 auto ifOp = ifSchedPtr->ifOp;
2114 Location loc = ifOp->getLoc();
2116 auto cond = ifOp.getCondition();
2117 auto condGroup = getState<ComponentLoweringState>()
2118 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
2120 auto symbolAttr = FlatSymbolRefAttr::get(
2121 StringAttr::get(getContext(), condGroup.getSymName()));
2123 bool initElse = !ifOp.getElseRegion().empty();
2124 auto ifCtrlOp = rewriter.create<calyx::IfOp>(
2125 loc, cond, symbolAttr, initElse);
2127 rewriter.setInsertionPointToEnd(ifCtrlOp.getBodyBlock());
2130 rewriter.create<calyx::SeqOp>(ifOp.getThenRegion().getLoc());
2131 auto *thenSeqOpBlock = thenSeqOp.getBodyBlock();
2133 auto *thenBlock = &ifOp.getThenRegion().front();
2141 if (!ifOp.getResults().empty()) {
2142 rewriter.setInsertionPointToEnd(thenSeqOpBlock);
2143 calyx::GroupOp thenGroup =
2144 getState<ComponentLoweringState>().getThenGroup(ifOp);
2145 rewriter.create<calyx::EnableOp>(thenGroup.getLoc(),
2146 thenGroup.getName());
2149 if (!ifOp.getElseRegion().empty()) {
2150 rewriter.setInsertionPointToEnd(ifCtrlOp.getElseBody());
2153 rewriter.create<calyx::SeqOp>(ifOp.getElseRegion().getLoc());
2154 auto *elseSeqOpBlock = elseSeqOp.getBodyBlock();
2156 auto *elseBlock = &ifOp.getElseRegion().front();
2162 if (!ifOp.getResults().empty()) {
2163 rewriter.setInsertionPointToEnd(elseSeqOpBlock);
2164 calyx::GroupOp elseGroup =
2165 getState<ComponentLoweringState>().getElseGroup(ifOp);
2166 rewriter.create<calyx::EnableOp>(elseGroup.getLoc(),
2167 elseGroup.getName());
2170 }
else if (
auto *callSchedPtr = std::get_if<CallScheduleable>(&group)) {
2171 auto instanceOp = callSchedPtr->instanceOp;
2172 OpBuilder::InsertionGuard g(rewriter);
2173 auto callBody = rewriter.create<calyx::SeqOp>(instanceOp.getLoc());
2174 rewriter.setInsertionPointToStart(callBody.getBodyBlock());
2175 std::string initGroupName =
"init_" + instanceOp.getSymName().str();
2176 rewriter.create<calyx::EnableOp>(instanceOp.getLoc(), initGroupName);
2178 auto callee = callSchedPtr->callOp.getCallee();
2179 auto *calleeOp = SymbolTable::lookupNearestSymbolFrom(
2180 callSchedPtr->callOp.getOperation()->getParentOp(),
2181 StringAttr::get(rewriter.getContext(),
"func_" + callee.str()));
2182 FuncOp calleeFunc = dyn_cast_or_null<FuncOp>(calleeOp);
2184 auto instanceOpComp =
2185 llvm::cast<calyx::ComponentOp>(instanceOp.getReferencedComponent());
2186 auto *instanceOpLoweringState =
2189 SmallVector<Value, 4> instancePorts;
2190 SmallVector<Value, 4> inputPorts;
2191 SmallVector<Attribute, 4> refCells;
2192 for (
auto operandEnum : enumerate(callSchedPtr->callOp.getOperands())) {
2193 auto operand = operandEnum.value();
2194 auto index = operandEnum.index();
2195 if (!isa<MemRefType>(operand.getType())) {
2196 inputPorts.push_back(operand);
2200 auto memOpName = getState<ComponentLoweringState>()
2201 .getMemoryInterface(operand)
2203 auto memOpNameAttr =
2204 SymbolRefAttr::get(rewriter.getContext(), memOpName);
2205 Value argI = calleeFunc.getArgument(index);
2206 if (isa<MemRefType>(argI.getType())) {
2207 NamedAttrList namedAttrList;
2208 namedAttrList.append(
2209 rewriter.getStringAttr(
2210 instanceOpLoweringState->getMemoryInterface(argI)
2214 DictionaryAttr::get(rewriter.getContext(), namedAttrList));
2217 llvm::copy(instanceOp.getResults().take_front(inputPorts.size()),
2218 std::back_inserter(instancePorts));
2220 ArrayAttr refCellsAttr =
2221 ArrayAttr::get(rewriter.getContext(), refCells);
2223 rewriter.create<calyx::InvokeOp>(
2224 instanceOp.getLoc(), instanceOp.getSymName(), instancePorts,
2225 inputPorts, refCellsAttr, ArrayAttr::get(rewriter.getContext(), {}),
2226 ArrayAttr::get(rewriter.getContext(), {}));
2228 llvm_unreachable(
"Unknown scheduleable");
2239 const DenseSet<Block *> &path, Location loc,
2240 Block *from, Block *to,
2241 Block *parentCtrlBlock)
const {
2244 rewriter.setInsertionPointToEnd(parentCtrlBlock);
2245 auto preSeqOp = rewriter.create<calyx::SeqOp>(loc);
2246 rewriter.setInsertionPointToEnd(preSeqOp.getBodyBlock());
2248 getState<ComponentLoweringState>().getBlockArgGroups(from, to))
2249 rewriter.create<calyx::EnableOp>(barg.getLoc(), barg.getSymName());
2255 PatternRewriter &rewriter,
2256 mlir::Block *parentCtrlBlock,
2257 mlir::Block *preBlock,
2258 mlir::Block *block)
const {
2259 if (path.count(block) != 0)
2260 return preBlock->getTerminator()->emitError()
2261 <<
"CFG backedge detected. Loops must be raised to 'scf.while' or "
2262 "'scf.for' operations.";
2264 rewriter.setInsertionPointToEnd(parentCtrlBlock);
2265 LogicalResult bbSchedResult =
2267 if (bbSchedResult.failed())
2268 return bbSchedResult;
2271 auto successors = block->getSuccessors();
2272 auto nSuccessors = successors.size();
2273 if (nSuccessors > 0) {
2274 auto brOp = dyn_cast<BranchOpInterface>(block->getTerminator());
2276 if (nSuccessors > 1) {
2280 assert(nSuccessors == 2 &&
2281 "only conditional branches supported for now...");
2283 auto cond = brOp->getOperand(0);
2284 auto condGroup = getState<ComponentLoweringState>()
2285 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
2286 auto symbolAttr = FlatSymbolRefAttr::get(
2287 StringAttr::get(getContext(), condGroup.getSymName()));
2289 auto ifOp = rewriter.create<calyx::IfOp>(
2290 brOp->getLoc(), cond, symbolAttr,
true);
2291 rewriter.setInsertionPointToStart(ifOp.getThenBody());
2292 auto thenSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
2293 rewriter.setInsertionPointToStart(ifOp.getElseBody());
2294 auto elseSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
2296 bool trueBrSchedSuccess =
2297 schedulePath(rewriter, path, brOp.getLoc(), block, successors[0],
2298 thenSeqOp.getBodyBlock())
2300 bool falseBrSchedSuccess =
true;
2301 if (trueBrSchedSuccess) {
2302 falseBrSchedSuccess =
2303 schedulePath(rewriter, path, brOp.getLoc(), block, successors[1],
2304 elseSeqOp.getBodyBlock())
2308 return success(trueBrSchedSuccess && falseBrSchedSuccess);
2311 return schedulePath(rewriter, path, brOp.getLoc(), block,
2312 successors.front(), parentCtrlBlock);
2322 const SmallVector<calyx::GroupOp> &initGroups)
const {
2323 PatternRewriter::InsertionGuard g(rewriter);
2324 auto parOp = rewriter.create<calyx::ParOp>(loc);
2325 rewriter.setInsertionPointToStart(parOp.getBodyBlock());
2326 for (calyx::GroupOp group : initGroups)
2327 rewriter.create<calyx::EnableOp>(group.getLoc(), group.getName());
2331 SmallVector<calyx::GroupOp> initGroups,
2332 PatternRewriter &rewriter)
const {
2333 Location loc = whileOp.
getLoc();
2340 auto condGroup = getState<ComponentLoweringState>()
2341 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
2342 auto symbolAttr = FlatSymbolRefAttr::get(
2343 StringAttr::get(getContext(), condGroup.getSymName()));
2344 return rewriter.create<calyx::WhileOp>(loc, cond, symbolAttr);
2348 SmallVector<calyx::GroupOp>
const &initGroups,
2350 PatternRewriter &rewriter)
const {
2351 Location loc = forOp.
getLoc();
2357 return rewriter.create<calyx::RepeatOp>(loc, bound);
2364 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2367 PatternRewriter &)
const override {
2368 funcOp.walk([&](scf::IfOp op) {
2369 for (
auto res : getState<ComponentLoweringState>().getResultRegs(op))
2370 op.getOperation()->getResults()[res.first].replaceAllUsesWith(
2371 res.second.getOut());
2374 funcOp.walk([&](scf::WhileOp op) {
2383 getState<ComponentLoweringState>().getWhileLoopIterRegs(whileOp))
2384 whileOp.
getOperation()->getResults()[res.first].replaceAllUsesWith(
2385 res.second.getOut());
2388 funcOp.walk([&](memref::LoadOp loadOp) {
2394 loadOp.getResult().replaceAllUsesWith(
2395 getState<ComponentLoweringState>()
2396 .getMemoryInterface(loadOp.getMemref())
2407 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2410 PatternRewriter &rewriter)
const override {
2411 rewriter.eraseOp(funcOp);
2417 PatternRewriter &rewriter)
const override {
2431class SCFToCalyxPass :
public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
2435 void runOnOperation()
override;
2437 LogicalResult setTopLevelFunction(mlir::ModuleOp moduleOp,
2438 std::string &topLevelFunction) {
2439 if (!topLevelFunctionOpt.empty()) {
2440 if (SymbolTable::lookupSymbolIn(moduleOp, topLevelFunctionOpt) ==
2442 moduleOp.emitError() <<
"Top level function '" << topLevelFunctionOpt
2443 <<
"' not found in module.";
2446 topLevelFunction = topLevelFunctionOpt;
2450 auto funcOps = moduleOp.getOps<FuncOp>();
2451 if (std::distance(funcOps.begin(), funcOps.end()) == 1)
2452 topLevelFunction = (*funcOps.begin()).getSymName().str();
2454 moduleOp.emitError()
2455 <<
"Module contains multiple functions, but no top level "
2456 "function was set. Please see --top-level-function";
2461 return createOptNewTopLevelFn(moduleOp, topLevelFunction);
2464 struct LoweringPattern {
2465 enum class Strategy { Once, Greedy };
2474 LogicalResult labelEntryPoint(StringRef topLevelFunction) {
2478 using OpRewritePattern::OpRewritePattern;
2479 LogicalResult matchAndRewrite(mlir::ModuleOp,
2480 PatternRewriter &)
const override {
2485 ConversionTarget target(getContext());
2486 target.addLegalDialect<calyx::CalyxDialect>();
2487 target.addLegalDialect<scf::SCFDialect>();
2488 target.addIllegalDialect<hw::HWDialect>();
2489 target.addIllegalDialect<comb::CombDialect>();
2492 target.addIllegalDialect<FuncDialect>();
2493 target.addIllegalDialect<ArithDialect>();
2494 target.addLegalOp<AddIOp, SelectOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp,
2495 ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
2496 CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
2497 RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
2498 ExtSIOp, CallOp, AddFOp, MulFOp, CmpFOp>();
2500 RewritePatternSet legalizePatterns(&getContext());
2501 legalizePatterns.add<DummyPattern>(&getContext());
2502 DenseSet<Operation *> legalizedOps;
2503 if (applyPartialConversion(getOperation(), target,
2504 std::move(legalizePatterns))
2515 template <
typename TPattern,
typename... PatternArgs>
2516 void addOncePattern(SmallVectorImpl<LoweringPattern> &
patterns,
2517 PatternArgs &&...args) {
2518 RewritePatternSet ps(&getContext());
2521 LoweringPattern{std::move(ps), LoweringPattern::Strategy::Once});
2524 template <
typename TPattern,
typename... PatternArgs>
2525 void addGreedyPattern(SmallVectorImpl<LoweringPattern> &
patterns,
2526 PatternArgs &&...args) {
2527 RewritePatternSet ps(&getContext());
2528 ps.add<TPattern>(&getContext(), args...);
2530 LoweringPattern{std::move(ps), LoweringPattern::Strategy::Greedy});
2533 LogicalResult runPartialPattern(RewritePatternSet &
pattern,
bool runOnce) {
2535 "Should only apply 1 partial lowering pattern at once");
2541 GreedyRewriteConfig config;
2542 config.enableRegionSimplification =
2543 mlir::GreedySimplifyRegionLevel::Disabled;
2545 config.maxIterations = 1;
2550 (void)applyPatternsGreedily(getOperation(), std::move(
pattern), config);
2559 FuncOp createNewTopLevelFn(ModuleOp moduleOp, std::string &baseName) {
2560 std::string newName =
"main";
2562 if (
auto *existingMainOp = SymbolTable::lookupSymbolIn(moduleOp, newName)) {
2563 auto existingMainFunc = dyn_cast<FuncOp>(existingMainOp);
2564 if (existingMainFunc ==
nullptr) {
2565 moduleOp.emitError() <<
"Symbol 'main' exists but is not a function";
2568 unsigned counter = 0;
2569 std::string newOldName = baseName;
2570 while (SymbolTable::lookupSymbolIn(moduleOp, newOldName))
2571 newOldName = llvm::join_items(
"_", baseName, std::to_string(++counter));
2572 existingMainFunc.setName(newOldName);
2573 if (baseName ==
"main")
2574 baseName = newOldName;
2578 OpBuilder builder(moduleOp.getContext());
2579 builder.setInsertionPointToStart(moduleOp.getBody());
2581 FunctionType funcType = builder.getFunctionType({}, {});
2584 builder.create<FuncOp>(moduleOp.getLoc(), newName, funcType))
2594 void insertCallFromNewTopLevel(OpBuilder &builder, FuncOp caller,
2596 if (caller.getBody().empty()) {
2597 caller.addEntryBlock();
2600 Block *callerEntryBlock = &caller.getBody().front();
2601 builder.setInsertionPointToStart(callerEntryBlock);
2605 SmallVector<Type, 4> nonMemRefCalleeArgTypes;
2606 for (
auto arg : callee.getArguments()) {
2607 if (!isa<MemRefType>(arg.getType())) {
2608 nonMemRefCalleeArgTypes.push_back(arg.getType());
2612 for (Type type : nonMemRefCalleeArgTypes) {
2613 callerEntryBlock->addArgument(type, caller.getLoc());
2616 FunctionType callerFnType = caller.getFunctionType();
2617 SmallVector<Type, 4> updatedCallerArgTypes(
2618 caller.getFunctionType().getInputs());
2619 updatedCallerArgTypes.append(nonMemRefCalleeArgTypes.begin(),
2620 nonMemRefCalleeArgTypes.end());
2621 caller.setType(FunctionType::get(caller.getContext(), updatedCallerArgTypes,
2622 callerFnType.getResults()));
2624 Block *calleeFnBody = &callee.getBody().front();
2625 unsigned originalCalleeArgNum = callee.getArguments().size();
2627 SmallVector<Value, 4> extraMemRefArgs;
2628 SmallVector<Type, 4> extraMemRefArgTypes;
2629 SmallVector<Value, 4> extraMemRefOperands;
2630 SmallVector<Operation *, 4> opsToModify;
2631 for (
auto &op : callee.getBody().getOps()) {
2632 if (isa<memref::AllocaOp, memref::AllocOp, memref::GetGlobalOp>(op))
2633 opsToModify.push_back(&op);
2638 builder.setInsertionPointToEnd(callerEntryBlock);
2639 for (
auto *op : opsToModify) {
2642 TypeSwitch<Operation *>(op)
2643 .Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
2644 newOpRes = builder.create<memref::AllocaOp>(callee.getLoc(),
2645 allocaOp.getType());
2647 .Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
2648 newOpRes = builder.create<memref::AllocOp>(callee.getLoc(),
2651 .Case<memref::GetGlobalOp>([&](memref::GetGlobalOp getGlobalOp) {
2652 newOpRes = builder.create<memref::GetGlobalOp>(
2653 caller.getLoc(), getGlobalOp.getType(), getGlobalOp.getName());
2655 .Default([&](Operation *defaultOp) {
2656 llvm::report_fatal_error(
"Unsupported operation in TypeSwitch");
2658 extraMemRefOperands.push_back(newOpRes);
2660 calleeFnBody->addArgument(newOpRes.getType(), callee.getLoc());
2661 BlockArgument newBodyArg = calleeFnBody->getArguments().back();
2662 op->getResult(0).replaceAllUsesWith(newBodyArg);
2664 extraMemRefArgs.push_back(newBodyArg);
2665 extraMemRefArgTypes.push_back(newBodyArg.getType());
2668 SmallVector<Type, 4> updatedCalleeArgTypes(
2669 callee.getFunctionType().getInputs());
2670 updatedCalleeArgTypes.append(extraMemRefArgTypes.begin(),
2671 extraMemRefArgTypes.end());
2672 callee.setType(FunctionType::get(callee.getContext(), updatedCalleeArgTypes,
2673 callee.getFunctionType().getResults()));
2675 unsigned otherArgsCount = 0;
2676 SmallVector<Value, 4> calleeArgFnOperands;
2677 builder.setInsertionPointToStart(callerEntryBlock);
2678 for (
auto arg : callee.getArguments().take_front(originalCalleeArgNum)) {
2679 if (isa<MemRefType>(arg.getType())) {
2680 auto memrefType = cast<MemRefType>(arg.getType());
2682 builder.create<memref::AllocOp>(callee.getLoc(), memrefType);
2683 calleeArgFnOperands.push_back(allocOp);
2685 auto callerArg = callerEntryBlock->getArgument(otherArgsCount++);
2686 calleeArgFnOperands.push_back(callerArg);
2690 SmallVector<Value, 4> fnOperands;
2691 fnOperands.append(calleeArgFnOperands.begin(), calleeArgFnOperands.end());
2692 fnOperands.append(extraMemRefOperands.begin(), extraMemRefOperands.end());
2694 SymbolRefAttr::get(builder.getContext(), callee.getSymName());
2695 auto resultTypes = callee.getResultTypes();
2697 builder.setInsertionPointToEnd(callerEntryBlock);
2698 builder.create<CallOp>(caller.getLoc(), calleeName, resultTypes,
2700 builder.create<ReturnOp>(caller.getLoc());
2706 LogicalResult createOptNewTopLevelFn(ModuleOp moduleOp,
2707 std::string &topLevelFunction) {
2708 auto hasMemrefArguments = [](FuncOp func) {
2710 func.getArguments().begin(), func.getArguments().end(),
2711 [](BlockArgument arg) { return isa<MemRefType>(arg.getType()); });
2717 auto funcOps = moduleOp.getOps<FuncOp>();
2718 bool hasMemrefArgsInTopLevel =
2719 std::any_of(funcOps.begin(), funcOps.end(), [&](
auto funcOp) {
2720 return funcOp.getName() == topLevelFunction &&
2721 hasMemrefArguments(funcOp);
2724 if (hasMemrefArgsInTopLevel) {
2725 auto newTopLevelFunc = createNewTopLevelFn(moduleOp, topLevelFunction);
2726 if (!newTopLevelFunc)
2729 OpBuilder builder(moduleOp.getContext());
2730 Operation *oldTopLevelFuncOp =
2731 SymbolTable::lookupSymbolIn(moduleOp, topLevelFunction);
2732 if (
auto oldTopLevelFunc = dyn_cast<FuncOp>(oldTopLevelFuncOp))
2733 insertCallFromNewTopLevel(builder, newTopLevelFunc, oldTopLevelFunc);
2735 moduleOp.emitOpError(
"Original top-level function not found!");
2738 topLevelFunction =
"main";
2745void SCFToCalyxPass::runOnOperation() {
2750 std::string topLevelFunction;
2751 if (failed(setTopLevelFunction(getOperation(), topLevelFunction))) {
2752 signalPassFailure();
2757 if (failed(labelEntryPoint(topLevelFunction))) {
2758 signalPassFailure();
2761 loweringState = std::make_shared<calyx::CalyxLoweringState>(getOperation(),
2772 DenseMap<FuncOp, calyx::ComponentOp> funcMap;
2773 SmallVector<LoweringPattern, 8> loweringPatterns;
2777 addOncePattern<FuncOpConversion>(loweringPatterns, patternState, funcMap,
2781 addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);
2783 addOncePattern<BuildParGroups>(loweringPatterns, patternState, funcMap,
2787 addOncePattern<calyx::ConvertIndexTypes>(loweringPatterns, patternState,
2791 addOncePattern<calyx::BuildBasicBlockRegs>(loweringPatterns, patternState,
2794 addOncePattern<calyx::BuildCallInstance>(loweringPatterns, patternState,
2798 addOncePattern<calyx::BuildReturnRegs>(loweringPatterns, patternState,
2804 addOncePattern<BuildWhileGroups>(loweringPatterns, patternState, funcMap,
2810 addOncePattern<BuildForGroups>(loweringPatterns, patternState, funcMap,
2813 addOncePattern<BuildIfGroups>(loweringPatterns, patternState, funcMap,
2823 addOncePattern<BuildOpGroups>(loweringPatterns, patternState, funcMap,
2829 addOncePattern<BuildControl>(loweringPatterns, patternState, funcMap,
2834 addOncePattern<calyx::InlineCombGroups>(loweringPatterns, patternState,
2839 addOncePattern<LateSSAReplacement>(loweringPatterns, patternState, funcMap,
2845 addGreedyPattern<calyx::EliminateUnusedCombGroups>(loweringPatterns);
2849 addOncePattern<calyx::RewriteMemoryAccesses>(loweringPatterns, patternState,
2854 addOncePattern<CleanupFuncOps>(loweringPatterns, patternState, funcMap,
2858 for (
auto &pat : loweringPatterns) {
2861 pat.strategy == LoweringPattern::Strategy::Once);
2864 signalPassFailure();
2871 RewritePatternSet cleanupPatterns(&getContext());
2875 applyPatternsGreedily(getOperation(), std::move(cleanupPatterns)))) {
2876 signalPassFailure();
2880 if (ciderSourceLocationMetadata) {
2883 SmallVector<Attribute, 16> sourceLocations;
2884 getOperation()->walk([&](calyx::ComponentOp component) {
2888 MLIRContext *context = getOperation()->getContext();
2889 getOperation()->setAttr(
"calyx.metadata",
2890 ArrayAttr::get(context, sourceLocations));
2900 return std::make_unique<SCFToCalyxPass>();
assert(baseType &&"element must be base type")
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
std::shared_ptr< calyx::CalyxLoweringState > loweringState
LogicalResult partialPatternRes
An interface for conversion passes that lower Calyx programs.
std::string irName(ValueOrBlock &v)
Returns a meaningful name for a value within the program scope.
std::string blockName(Block *b)
Returns a meaningful name for a block within the program scope (removes the ^ prefix from block names...
StringRef getTopLevelFunction() const
Returns the name of the top-level function in the source program.
T * getState(calyx::ComponentOp op)
Returns the component lowering state associated with op.
void setFuncOpResultMapping(const DenseMap< unsigned, unsigned > &mapping)
Assign a mapping between the source funcOp result indices and the corresponding output port indices o...
std::string getUniqueName(StringRef prefix)
Returns a unique name within compOp with the provided prefix.
calyx::ComponentOp component
The component which this lowering state is associated to.
void registerMemoryInterface(Value memref, const calyx::MemoryInterface &memoryInterface)
Registers a memory interface as being associated with a memory identified by 'memref'.
calyx::ComponentOp getComponentOp()
Returns the calyx::ComponentOp associated with this lowering state.
void setDataField(StringRef name, llvm::json::Array data)
ComponentLoweringStateInterface(calyx::ComponentOp component)
void setFormat(StringRef name, std::string numType, bool isSigned, unsigned width)
FuncOpPartialLoweringPatterns are patterns which intend to match on FuncOps and then perform their ow...
calyx::ComponentOp getComponent() const
Returns the component operation associated with the currently executing partial lowering.
DenseMap< mlir::func::FuncOp, calyx::ComponentOp > & functionMapping
CalyxLoweringState & loweringState() const
Return the calyx lowering state for this pattern.
FuncOpPartialLoweringPattern(MLIRContext *context, LogicalResult &resRef, PatternApplicationState &patternState, DenseMap< mlir::func::FuncOp, calyx::ComponentOp > &map, calyx::CalyxLoweringState &state)
calyx::GroupOp getLoopLatchGroup(ScfWhileOp op)
Retrieve the loop latch group registered for op.
void setLoopLatchGroup(ScfWhileOp op, calyx::GroupOp group)
Registers grp to be the loop latch group of op.
calyx::RegisterOp getLoopIterReg(ScfForOp op, unsigned idx)
Return a mapping of block argument indices to block argument.
void addLoopIterReg(ScfWhileOp op, calyx::RegisterOp reg, unsigned idx)
Register reg as being the idx'th iter_args register for 'op'.
void setLoopInitGroups(ScfWhileOp op, SmallVector< calyx::GroupOp > groups)
Registers groups to be the loop init groups of op.
SmallVector< calyx::GroupOp > getLoopInitGroups(ScfWhileOp op)
Retrieve the loop init groups registered for op.
calyx::GroupOp buildLoopIterArgAssignments(OpBuilder &builder, ScfWhileOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
Creates a new group that assigns the 'ops' values to the iter arg registers of the loop operation.
const DenseMap< unsigned, calyx::RegisterOp > & getLoopIterRegs(ScfWhileOp op)
Return a mapping of block argument indices to block argument.
PatternApplicationState & patternState
scf::ForOp getOperation()
Location getLoc() override
RepeatOpInterface(scf::ForOp op)
Holds common utilities used for scheduling when lowering to Calyx.
scf::WhileOp getOperation()
WhileOpInterface(scf::WhileOp op)
Location getLoc() override
Builds a control schedule by traversing the CFG of the function and associating this with the previou...
calyx::RepeatOp buildForCtrlOp(ScfForOp forOp, SmallVector< calyx::GroupOp > const &initGroups, uint64_t bound, PatternRewriter &rewriter) const
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult schedulePath(PatternRewriter &rewriter, const DenseSet< Block * > &path, Location loc, Block *from, Block *to, Block *parentCtrlBlock) const
Schedules a block by inserting a branch argument assignment block (if any) before recursing into the ...
calyx::WhileOp buildWhileCtrlOp(ScfWhileOp whileOp, SmallVector< calyx::GroupOp > initGroups, PatternRewriter &rewriter) const
LogicalResult scheduleBasicBlock(PatternRewriter &rewriter, const DenseSet< Block * > &path, mlir::Block *parentCtrlBlock, mlir::Block *block) const
Sequentially schedules the groups that registered themselves with 'block'.
LogicalResult buildCFGControl(DenseSet< Block * > path, PatternRewriter &rewriter, mlir::Block *parentCtrlBlock, mlir::Block *preBlock, mlir::Block *block) const
void insertParInitGroups(PatternRewriter &rewriter, Location loc, const SmallVector< calyx::GroupOp > &initGroups) const
In BuildForGroups, a register is created for the iteration argument of the for op.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Iterate through the operations of a source function and instantiate components or primitives based on...
BuildOpGroups(MLIRContext *context, LogicalResult &resRef, calyx::PatternApplicationState &patternState, DenseMap< mlir::func::FuncOp, calyx::ComponentOp > &map, calyx::CalyxLoweringState &state, mlir::Pass::Option< std::string > &writeJsonOpt)
TGroupOp createGroupForOp(PatternRewriter &rewriter, Operation *op) const
Creates a group named by the basic block which the input op resides in.
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op) const
buildLibraryOp which provides in- and output types based on the operands and results of the op argume...
LogicalResult buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp) const
Op builder specializations.
calyx::RegisterOp createSignalRegister(PatternRewriter &rewriter, Value signal, bool invert, StringRef nameSuffix, calyx::CompareFOpIEEE754 calyxCmpFOp, calyx::GroupOp group) const
void assignAddressPorts(PatternRewriter &rewriter, Location loc, calyx::GroupInterface group, calyx::MemoryInterface memoryInterface, Operation::operand_range addressValues) const
Creates assignments within the provided group to the address ports of the memoryOp based on the provi...
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op, TypeRange srcTypes, TypeRange dstTypes) const
buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the source operation TSrcOp.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult buildLibraryBinaryPipeOp(PatternRewriter &rewriter, TSrcOp op, TOpType opPipe, Value out) const
buildLibraryBinaryPipeOp will build a TCalyxLibBinaryPipeOp, to deal with MulIOp, DivUIOp and RemUIOp...
mlir::Pass::Option< std::string > & writeJson
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partialEval(PatternRewriter &rewriter, scf::ParallelOp scfParOp) const
In BuildWhileGroups, a register is created for each iteration argumenet of the while op.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Erases FuncOp operations.
LogicalResult matchAndRewrite(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Handles the current state of lowering of a Calyx component.
ComponentLoweringState(calyx::ComponentOp component)
void setForLoopInitGroups(ScfForOp op, SmallVector< calyx::GroupOp > groups)
calyx::GroupOp buildForLoopIterArgAssignments(OpBuilder &builder, ScfForOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
void setForLoopLatchGroup(ScfForOp op, calyx::GroupOp group)
SmallVector< calyx::GroupOp > getForLoopInitGroups(ScfForOp op)
void addForLoopIterReg(ScfForOp op, calyx::RegisterOp reg, unsigned idx)
calyx::GroupOp getForLoopLatchGroup(ScfForOp op)
calyx::RegisterOp getForLoopIterReg(ScfForOp op, unsigned idx)
const DenseMap< unsigned, calyx::RegisterOp > & getForLoopIterRegs(ScfForOp op)
DenseMap< Operation *, calyx::GroupOp > elseGroup
DenseMap< Operation *, calyx::GroupOp > thenGroup
const DenseMap< unsigned, calyx::RegisterOp > & getResultRegs(scf::IfOp op)
void setElseGroup(scf::IfOp op, calyx::GroupOp group)
void setResultRegs(scf::IfOp op, calyx::RegisterOp reg, unsigned idx)
void setThenGroup(scf::IfOp op, calyx::GroupOp group)
DenseMap< Operation *, DenseMap< unsigned, calyx::RegisterOp > > resultRegs
calyx::RegisterOp getResultRegs(scf::IfOp op, unsigned idx)
calyx::GroupOp getThenGroup(scf::IfOp op)
calyx::GroupOp getElseGroup(scf::IfOp op)
Inlines Calyx ExecuteRegionOp operations within their parent blocks.
LogicalResult matchAndRewrite(scf::ExecuteRegionOp execOp, PatternRewriter &rewriter) const override
LateSSAReplacement contains various functions for replacing SSA values that were not replaced during ...
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &) const override
std::optional< int64_t > getBound() override
Block::BlockArgListType getBodyArgs() override
Block * getBodyBlock() override
Block * getBodyBlock() override
ScfWhileOp(scf::WhileOp op)
Block::BlockArgListType getBodyArgs() override
Value getConditionValue() override
std::optional< int64_t > getBound() override
Block * getConditionBlock() override
calyx::GroupOp buildWhileLoopIterArgAssignments(OpBuilder &builder, ScfWhileOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
void setWhileLoopInitGroups(ScfWhileOp op, SmallVector< calyx::GroupOp > groups)
SmallVector< calyx::GroupOp > getWhileLoopInitGroups(ScfWhileOp op)
void addWhileLoopIterReg(ScfWhileOp op, calyx::RegisterOp reg, unsigned idx)
void setWhileLoopLatchGroup(ScfWhileOp op, calyx::GroupOp group)
const DenseMap< unsigned, calyx::RegisterOp > & getWhileLoopIterRegs(ScfWhileOp op)
calyx::GroupOp getWhileLoopLatchGroup(ScfWhileOp op)
void addMandatoryComponentPorts(PatternRewriter &rewriter, SmallVectorImpl< calyx::PortInfo > &ports)
void buildAssignmentsForRegisterWrite(OpBuilder &builder, calyx::GroupOp groupOp, calyx::ComponentOp componentOp, calyx::RegisterOp ®, Value inputValue)
Creates register assignment operations within the provided groupOp.
DenseMap< const mlir::RewritePattern *, SmallPtrSet< Operation *, 16 > > PatternApplicationState
Extra state that is passed to all PartialLoweringPatterns so they can record when they have run on an...
PredicateInfo getPredicateInfo(mlir::arith::CmpFPredicate pred)
Type normalizeType(OpBuilder &builder, Type type)
LogicalResult applyModuleOpConversion(mlir::ModuleOp, StringRef topLevelFunction)
Helper to update the top-level ModuleOp to set the entrypoing function.
WalkResult getCiderSourceLocationMetadata(calyx::ComponentOp component, SmallVectorImpl< Attribute > &sourceLocations)
bool matchConstantOp(Operation *op, APInt &value)
unsigned handleZeroWidth(int64_t dim)
hw::ConstantOp createConstant(Location loc, OpBuilder &builder, ComponentOp component, size_t width, size_t value)
A helper function to create constants in the HW dialect.
bool noStoresToMemory(Value memoryReference)
bool singleLoadFromMemory(Value memoryReference)
Type toBitVector(T type)
Performs a bit cast from a non-signless integer type value, such as a floating point value,...
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `‘Not’' gate on a value.
static constexpr std::string_view sPortNameAttr
static LogicalResult buildAllocOp(ComponentLoweringState &componentState, PatternRewriter &rewriter, TAllocOp allocOp)
std::variant< calyx::GroupOp, WhileScheduleable, ForScheduleable, IfScheduleable, CallScheduleable, ParScheduleable > Scheduleable
A variant of types representing scheduleable operations.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< OperationPass< ModuleOp > > createSCFToCalyxPass()
Create an SCF to Calyx conversion pass.
When building groups which contain accesses to multiple sequential components, a group_done op is cre...
GroupDoneOp's are terminator operations and should therefore be the last operator in a group.
This holds information about the port for either a Component or Cell.
Predicate information for the floating point comparisons.
calyx::InstanceOp instanceOp
Instance for invoking.
ScfForOp forOp
For operation to schedule.
Creates a new Calyx component for each FuncOp in the program.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
scf::ParallelOp parOp
Parallel operation to schedule.
ScfWhileOp whileOp
While operation to schedule.