19 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
20 #include "mlir/Conversion/LLVMCommon/Pattern.h"
21 #include "mlir/Dialect/Arith/IR/Arith.h"
22 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/MemRef/IR/MemRef.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/IR/AsmState.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Support/LogicalResult.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/LogicalResult.h"
37 #define GEN_PASS_DEF_SCFTOCALYX
38 #include "circt/Conversion/Passes.h.inc"
43 using namespace mlir::arith;
44 using namespace mlir::cf;
47 class ComponentLoweringStateInterface;
48 namespace scftocalyx {
57 : calyx::WhileOpInterface<scf::WhileOp>(op) {}
60 return getOperation().getAfterArguments();
63 Block *
getBodyBlock()
override {
return &getOperation().getAfter().front(); }
66 return &getOperation().getBefore().front();
70 return getOperation().getConditionOp().getOperand(0);
73 std::optional<int64_t>
getBound()
override {
return std::nullopt; }
78 explicit ScfForOp(scf::ForOp op) : calyx::RepeatOpInterface<scf::ForOp>(op) {}
81 return getOperation().getRegion().getArguments();
85 return &getOperation().getRegion().getBlocks().front();
89 return constantTripCount(getOperation().getLowerBound(),
90 getOperation().getUpperBound(),
91 getOperation().getStep());
135 Operation *operation = op.getOperation();
136 assert(thenGroup.count(operation) == 0 &&
137 "A then group was already set for this scf::IfOp!\n");
138 thenGroup[operation] = group;
142 auto it = thenGroup.find(op.getOperation());
143 assert(it != thenGroup.end() &&
144 "No then group was set for this scf::IfOp!\n");
149 Operation *operation = op.getOperation();
150 assert(elseGroup.count(operation) == 0 &&
151 "An else group was already set for this scf::IfOp!\n");
152 elseGroup[operation] = group;
156 auto it = elseGroup.find(op.getOperation());
157 assert(it != elseGroup.end() &&
158 "No else group was set for this scf::IfOp!\n");
163 assert(resultRegs[op.getOperation()].count(idx) == 0 &&
164 "A register was already registered for the given yield result.\n");
165 assert(idx < op->getNumOperands());
166 resultRegs[op.getOperation()][idx] =
reg;
170 return resultRegs[op.getOperation()];
174 auto regs = getResultRegs(op);
175 auto it = regs.find(idx);
176 assert(it != regs.end() &&
"resultReg not found");
183 DenseMap<Operation *, DenseMap<unsigned, calyx::RegisterOp>>
resultRegs;
190 return getLoopInitGroups(std::move(op));
193 OpBuilder &builder,
ScfWhileOp op, calyx::ComponentOp componentOp,
194 Twine uniqueSuffix, MutableArrayRef<OpOperand> ops) {
195 return buildLoopIterArgAssignments(builder, std::move(op), componentOp,
199 return addLoopIterReg(std::move(op),
reg, idx);
201 const DenseMap<unsigned, calyx::RegisterOp> &
203 return getLoopIterRegs(std::move(op));
206 return setLoopLatchGroup(std::move(op), group);
209 return getLoopLatchGroup(std::move(op));
212 SmallVector<calyx::GroupOp> groups) {
213 return setLoopInitGroups(std::move(op), std::move(groups));
221 return getLoopInitGroups(std::move(op));
224 OpBuilder &builder,
ScfForOp op, calyx::ComponentOp componentOp,
225 Twine uniqueSuffix, MutableArrayRef<OpOperand> ops) {
226 return buildLoopIterArgAssignments(builder, std::move(op), componentOp,
230 return addLoopIterReg(std::move(op),
reg, idx);
233 return getLoopIterRegs(std::move(op));
236 return getLoopIterReg(std::move(op), idx);
239 return setLoopLatchGroup(std::move(op), group);
242 return getLoopLatchGroup(std::move(op));
245 return setLoopInitGroups(std::move(op), std::move(groups));
259 : calyx::ComponentLoweringStateInterface(component) {}
269 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
273 PatternRewriter &rewriter)
const override {
276 bool opBuiltSuccessfully =
true;
277 funcOp.walk([&](Operation *_op) {
278 opBuiltSuccessfully &=
279 TypeSwitch<mlir::Operation *, bool>(_op)
280 .template Case<arith::ConstantOp, ReturnOp, BranchOpInterface,
282 scf::YieldOp, scf::WhileOp, scf::ForOp, scf::IfOp,
283 scf::ParallelOp, scf::ReduceOp,
285 memref::AllocOp, memref::AllocaOp, memref::LoadOp,
288 AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
289 AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
290 MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
292 AddFOp, MulFOp, CmpFOp,
294 SelectOp, IndexCastOp, CallOp>(
295 [&](
auto op) {
return buildOp(rewriter, op).succeeded(); })
296 .
template Case<FuncOp, scf::ConditionOp>([&](
auto) {
300 .Default([&](
auto op) {
301 op->emitError() <<
"Unhandled operation during BuildOpGroups()";
305 return opBuiltSuccessfully ? WalkResult::advance()
306 : WalkResult::interrupt();
309 return success(opBuiltSuccessfully);
/// Per-operation lowering overloads. The TypeSwitch dispatch in this class's
/// funcOp walk selects one of these by overload resolution
/// (`buildOp(rewriter, op).succeeded()`). The returned LogicalResult decides
/// whether the walk advances or is interrupted.
// Terminators and control-flow edges: scf.yield (loop/if results) and
// branch ops (block-argument register assignments).
314 LogicalResult buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp)
const;
315 LogicalResult buildOp(PatternRewriter &rewriter,
316 BranchOpInterface brOp)
const;
// Constants: integer constants become hw constants; float constants become
// calyx::ConstantOp with an integer type of the same bit width.
317 LogicalResult buildOp(PatternRewriter &rewriter,
318 arith::ConstantOp constOp)
const;
// Combinational ops, each lowered through buildLibraryOp into a CombGroupOp
// (select -> MuxLibOp, addi -> AddLibOp, subi -> SubLibOp).
319 LogicalResult buildOp(PatternRewriter &rewriter, SelectOp op)
const;
320 LogicalResult buildOp(PatternRewriter &rewriter, AddIOp op)
const;
321 LogicalResult buildOp(PatternRewriter &rewriter, SubIOp op)
const;
// Multi-cycle integer ops, lowered via buildLibraryBinaryPipeOp onto
// pipelined library ops (MultPipe/DivUPipe/DivSPipe/RemUPipe/RemSPipe).
322 LogicalResult buildOp(PatternRewriter &rewriter, MulIOp op)
const;
323 LogicalResult buildOp(PatternRewriter &rewriter, DivUIOp op)
const;
324 LogicalResult buildOp(PatternRewriter &rewriter, DivSIOp op)
const;
325 LogicalResult buildOp(PatternRewriter &rewriter, RemUIOp op)
const;
326 LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op)
const;
// Floating-point ops mapped onto the IEEE754 library ops
// (AddFOpIEEE754, MulFOpIEEE754, CompareFOpIEEE754).
327 LogicalResult buildOp(PatternRewriter &rewriter, AddFOp op)
const;
328 LogicalResult buildOp(PatternRewriter &rewriter, MulFOp op)
const;
329 LogicalResult buildOp(PatternRewriter &rewriter, CmpFOp op)
const;
// Shifts and bitwise logic: combinational library ops in a CombGroupOp.
330 LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op)
const;
331 LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op)
const;
332 LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op)
const;
333 LogicalResult buildOp(PatternRewriter &rewriter, AndIOp op)
const;
334 LogicalResult buildOp(PatternRewriter &rewriter, OrIOp op)
const;
335 LogicalResult buildOp(PatternRewriter &rewriter, XOrIOp op)
const;
// Integer compare: the predicate selects one comparison library op.
336 LogicalResult buildOp(PatternRewriter &rewriter, CmpIOp op)
const;
// Width changes: trunci -> SliceLibOp, extui -> PadLibOp,
// extsi -> ExtSILibOp.
337 LogicalResult buildOp(PatternRewriter &rewriter, TruncIOp op)
const;
338 LogicalResult buildOp(PatternRewriter &rewriter, ExtUIOp op)
const;
339 LogicalResult buildOp(PatternRewriter &rewriter, ExtSIOp op)
const;
// func.return: assigns returned values into the component's return registers.
340 LogicalResult buildOp(PatternRewriter &rewriter, ReturnOp op)
const;
// index_cast: equal widths forward the operand directly; otherwise it is
// sliced or padded to the target width.
341 LogicalResult buildOp(PatternRewriter &rewriter, IndexCastOp op)
const;
// Memories: alloc/alloca become calyx::SeqMemoryOp (via buildAllocOp);
// load/store drive the memory interface's ports inside a group.
342 LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocOp op)
const;
343 LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocaOp op)
const;
344 LogicalResult buildOp(PatternRewriter &rewriter, memref::LoadOp op)
const;
345 LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op)
const;
// Structured control flow: registered as block scheduleables. scf.for
// requires a statically known trip count.
346 LogicalResult buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp)
const;
347 LogicalResult buildOp(PatternRewriter &rewriter, scf::ForOp forOp)
const;
348 LogicalResult buildOp(PatternRewriter &rewriter, scf::IfOp ifOp)
const;
// Parallel constructs and calls; their definitions are outside this view.
349 LogicalResult buildOp(PatternRewriter &rewriter,
350 scf::ReduceOp reduceOp)
const;
351 LogicalResult buildOp(PatternRewriter &rewriter,
352 scf::ParallelOp parallelOp)
const;
353 LogicalResult buildOp(PatternRewriter &rewriter, CallOp callOp)
const;
357 template <
typename TGroupOp,
typename TCalyxLibOp,
typename TSrcOp>
359 TypeRange srcTypes, TypeRange dstTypes)
const {
360 SmallVector<Type> types;
361 llvm::append_range(types, srcTypes);
362 llvm::append_range(types, dstTypes);
365 getState<ComponentLoweringState>().getNewLibraryOpInstance<TCalyxLibOp>(
366 rewriter, op.getLoc(), types);
368 auto directions = calyxOp.portDirections();
369 SmallVector<Value, 4> opInputPorts;
370 SmallVector<Value, 4> opOutputPorts;
371 for (
auto dir : enumerate(directions)) {
373 opInputPorts.push_back(calyxOp.getResult(dir.index()));
375 opOutputPorts.push_back(calyxOp.getResult(dir.index()));
378 opInputPorts.size() == op->getNumOperands() &&
379 opOutputPorts.size() == op->getNumResults() &&
380 "Expected an equal number of in/out ports in the Calyx library op with "
381 "respect to the number of operands/results of the source operation.");
384 auto group = createGroupForOp<TGroupOp>(rewriter, op);
385 rewriter.setInsertionPointToEnd(group.getBodyBlock());
386 for (
auto dstOp : enumerate(opInputPorts))
387 rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
388 op->getOperand(dstOp.index()));
391 for (
auto res : enumerate(opOutputPorts)) {
392 getState<ComponentLoweringState>().registerEvaluatingGroup(res.value(),
394 op->getResult(res.index()).replaceAllUsesWith(res.value());
401 template <
typename TGroupOp,
typename TCalyxLibOp,
typename TSrcOp>
403 return buildLibraryOp<TGroupOp, TCalyxLibOp, TSrcOp>(
404 rewriter, op, op.getOperandTypes(), op->getResultTypes());
408 template <
typename TGroupOp>
410 Block *block = op->getBlock();
411 auto groupName = getState<ComponentLoweringState>().getUniqueName(
413 return calyx::createGroup<TGroupOp>(
414 rewriter, getState<ComponentLoweringState>().getComponentOp(),
415 op->getLoc(), groupName);
420 template <
typename TOpType,
typename TSrcOp>
422 TOpType opPipe, Value out)
const {
423 StringRef opName = TSrcOp::getOperationName().split(
".").second;
424 Location loc = op.getLoc();
425 Type
width = op.getResult().getType();
427 op.getLoc(), rewriter, getComponent(),
width.getIntOrFloatBitWidth(),
428 getState<ComponentLoweringState>().getUniqueName(opName));
430 auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
431 OpBuilder builder(group->getRegion(0));
432 getState<ComponentLoweringState>().addBlockScheduleable(op->getBlock(),
435 rewriter.setInsertionPointToEnd(group.getBodyBlock());
436 rewriter.create<calyx::AssignOp>(loc, opPipe.getLeft(), op.getLhs());
437 rewriter.create<calyx::AssignOp>(loc, opPipe.getRight(), op.getRhs());
439 rewriter.create<calyx::AssignOp>(loc,
reg.getIn(), out);
441 rewriter.create<calyx::AssignOp>(loc,
reg.getWriteEn(), opPipe.getDone());
446 rewriter.create<calyx::AssignOp>(
447 loc, opPipe.getGo(), c1,
450 rewriter.create<calyx::GroupDoneOp>(loc,
reg.getDone());
454 op.getResult().replaceAllUsesWith(
reg.getOut());
456 if (isa<calyx::AddFOpIEEE754>(opPipe)) {
457 auto opFOp = cast<calyx::AddFOpIEEE754>(opPipe);
459 if (isa<arith::AddFOp>(op)) {
466 rewriter.create<calyx::AssignOp>(loc, opFOp.getSubOp(), subOp);
470 getState<ComponentLoweringState>().registerEvaluatingGroup(out, group);
471 getState<ComponentLoweringState>().registerEvaluatingGroup(opPipe.getLeft(),
473 getState<ComponentLoweringState>().registerEvaluatingGroup(
474 opPipe.getRight(), group);
482 calyx::GroupInterface group,
484 Operation::operand_range addressValues)
const {
485 IRRewriter::InsertionGuard guard(rewriter);
486 rewriter.setInsertionPointToEnd(group.getBody());
487 auto addrPorts = memoryInterface.
addrPorts();
488 if (addressValues.empty()) {
490 addrPorts.size() == 1 &&
491 "We expected a 1 dimensional memory of size 1 because there were no "
492 "address assignment values");
494 rewriter.create<calyx::AssignOp>(
498 assert(addrPorts.size() == addressValues.size() &&
499 "Mismatch between number of address ports of the provided memory "
500 "and address assignment values");
501 for (
auto address : enumerate(addressValues))
502 rewriter.create<calyx::AssignOp>(loc, addrPorts[address.index()],
508 Value signal,
bool invert,
509 StringRef nameSuffix,
510 calyx::CompareFOpIEEE754 calyxCmpFOp,
511 calyx::GroupOp group)
const {
512 Location loc = calyxCmpFOp.getLoc();
513 IntegerType one = rewriter.getI1Type();
514 auto component = getComponent();
515 OpBuilder builder(group->getRegion(0));
517 loc, rewriter, component, 1,
518 getState<ComponentLoweringState>().getUniqueName(nameSuffix));
519 rewriter.create<calyx::AssignOp>(loc,
reg.getWriteEn(),
520 calyxCmpFOp.getDone());
522 auto notLibOp = getState<ComponentLoweringState>()
523 .getNewLibraryOpInstance<calyx::NotLibOp>(
524 rewriter, loc, {one, one});
525 rewriter.create<calyx::AssignOp>(loc, notLibOp.getIn(), signal);
526 rewriter.create<calyx::AssignOp>(loc,
reg.getIn(), notLibOp.getOut());
527 getState<ComponentLoweringState>().registerEvaluatingGroup(
528 notLibOp.getOut(), group);
530 rewriter.create<calyx::AssignOp>(loc,
reg.getIn(), signal);
535 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
536 memref::LoadOp loadOp)
const {
537 Value memref = loadOp.getMemref();
538 auto memoryInterface =
539 getState<ComponentLoweringState>().getMemoryInterface(memref);
540 auto group = createGroupForOp<calyx::GroupOp>(rewriter, loadOp);
541 assignAddressPorts(rewriter, loadOp.getLoc(), group, memoryInterface,
542 loadOp.getIndices());
544 rewriter.setInsertionPointToEnd(group.getBodyBlock());
550 if (memoryInterface.readEnOpt().has_value()) {
553 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), memoryInterface.readEn(),
555 regWriteEn = memoryInterface.done();
562 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(),
563 memoryInterface.done());
573 res = loadOp.getResult();
575 }
else if (memoryInterface.contentEnOpt().has_value()) {
580 rewriter.create<calyx::AssignOp>(loadOp.getLoc(),
581 memoryInterface.contentEn(), oneI1);
582 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), memoryInterface.writeEn(),
584 regWriteEn = memoryInterface.done();
591 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(),
592 memoryInterface.done());
602 res = loadOp.getResult();
615 loadOp.getLoc(), rewriter, getComponent(),
616 loadOp.getMemRefType().getElementTypeBitWidth(),
617 getState<ComponentLoweringState>().getUniqueName(
"load"));
618 rewriter.setInsertionPointToEnd(group.getBodyBlock());
619 rewriter.create<calyx::AssignOp>(loadOp.getLoc(),
reg.getIn(),
620 memoryInterface.readData());
621 rewriter.create<calyx::AssignOp>(loadOp.getLoc(),
reg.getWriteEn(),
623 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(),
reg.getDone());
624 loadOp.getResult().replaceAllUsesWith(
reg.getOut());
628 getState<ComponentLoweringState>().registerEvaluatingGroup(res, group);
629 getState<ComponentLoweringState>().addBlockScheduleable(loadOp->getBlock(),
634 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
635 memref::StoreOp storeOp)
const {
636 auto memoryInterface = getState<ComponentLoweringState>().getMemoryInterface(
637 storeOp.getMemref());
638 auto group = createGroupForOp<calyx::GroupOp>(rewriter, storeOp);
642 getState<ComponentLoweringState>().addBlockScheduleable(storeOp->getBlock(),
644 assignAddressPorts(rewriter, storeOp.getLoc(), group, memoryInterface,
645 storeOp.getIndices());
646 rewriter.setInsertionPointToEnd(group.getBodyBlock());
647 rewriter.create<calyx::AssignOp>(
648 storeOp.getLoc(), memoryInterface.writeData(), storeOp.getValueToStore());
649 rewriter.create<calyx::AssignOp>(
650 storeOp.getLoc(), memoryInterface.writeEn(),
651 createConstant(storeOp.getLoc(), rewriter, getComponent(), 1, 1));
652 if (memoryInterface.contentEnOpt().has_value()) {
654 rewriter.create<calyx::AssignOp>(
655 storeOp.getLoc(), memoryInterface.contentEn(),
656 createConstant(storeOp.getLoc(), rewriter, getComponent(), 1, 1));
658 rewriter.create<calyx::GroupDoneOp>(storeOp.getLoc(), memoryInterface.done());
663 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
665 Location loc = mul.getLoc();
666 Type
width = mul.getResult().getType(), one = rewriter.getI1Type();
668 getState<ComponentLoweringState>()
669 .getNewLibraryOpInstance<calyx::MultPipeLibOp>(
671 return buildLibraryBinaryPipeOp<calyx::MultPipeLibOp>(
672 rewriter, mul, mulPipe,
676 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
678 Location loc = div.getLoc();
679 Type
width = div.getResult().getType(), one = rewriter.getI1Type();
681 getState<ComponentLoweringState>()
682 .getNewLibraryOpInstance<calyx::DivUPipeLibOp>(
684 return buildLibraryBinaryPipeOp<calyx::DivUPipeLibOp>(
685 rewriter, div, divPipe,
689 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
691 Location loc = div.getLoc();
692 Type
width = div.getResult().getType(), one = rewriter.getI1Type();
694 getState<ComponentLoweringState>()
695 .getNewLibraryOpInstance<calyx::DivSPipeLibOp>(
697 return buildLibraryBinaryPipeOp<calyx::DivSPipeLibOp>(
698 rewriter, div, divPipe,
702 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
704 Location loc = rem.getLoc();
705 Type
width = rem.getResult().getType(), one = rewriter.getI1Type();
707 getState<ComponentLoweringState>()
708 .getNewLibraryOpInstance<calyx::RemUPipeLibOp>(
710 return buildLibraryBinaryPipeOp<calyx::RemUPipeLibOp>(
711 rewriter, rem, remPipe,
715 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
717 Location loc = rem.getLoc();
718 Type
width = rem.getResult().getType(), one = rewriter.getI1Type();
720 getState<ComponentLoweringState>()
721 .getNewLibraryOpInstance<calyx::RemSPipeLibOp>(
723 return buildLibraryBinaryPipeOp<calyx::RemSPipeLibOp>(
724 rewriter, rem, remPipe,
728 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
730 Location loc = addf.getLoc();
731 IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
732 five = rewriter.getIntegerType(5),
733 width = rewriter.getIntegerType(
734 addf.getType().getIntOrFloatBitWidth());
736 getState<ComponentLoweringState>()
737 .getNewLibraryOpInstance<calyx::AddFOpIEEE754>(
739 {one, one, one, one, one,
width,
width, three,
width, five, one});
740 return buildLibraryBinaryPipeOp<calyx::AddFOpIEEE754>(rewriter, addf, addFOp,
744 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
746 Location loc = mulf.getLoc();
747 IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
748 five = rewriter.getIntegerType(5),
749 width = rewriter.getIntegerType(
750 mulf.getType().getIntOrFloatBitWidth());
752 getState<ComponentLoweringState>()
753 .getNewLibraryOpInstance<calyx::MulFOpIEEE754>(
756 return buildLibraryBinaryPipeOp<calyx::MulFOpIEEE754>(rewriter, mulf, mulFOp,
760 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
762 Location loc = cmpf.getLoc();
763 IntegerType one = rewriter.getI1Type(), five = rewriter.getIntegerType(5),
764 width = rewriter.getIntegerType(
765 cmpf.getLhs().getType().getIntOrFloatBitWidth());
766 auto calyxCmpFOp = getState<ComponentLoweringState>()
767 .getNewLibraryOpInstance<calyx::CompareFOpIEEE754>(
769 {one, one, one,
width,
width, one, one, one, one,
773 rewriter.setInsertionPointToStart(getComponent().
getBodyBlock());
776 using CombLogic = PredicateInfo::CombLogic;
777 using Port = PredicateInfo::InputPorts::Port;
779 if (info.logic == CombLogic::None) {
780 if (cmpf.getPredicate() == CmpFPredicate::AlwaysTrue) {
781 cmpf.getResult().replaceAllUsesWith(c1);
785 if (cmpf.getPredicate() == CmpFPredicate::AlwaysFalse) {
786 cmpf.getResult().replaceAllUsesWith(c0);
792 StringRef opName = cmpf.getOperationName().split(
".").second;
795 getState<ComponentLoweringState>().getUniqueName(opName));
798 auto group = createGroupForOp<calyx::GroupOp>(rewriter, cmpf);
799 OpBuilder builder(group->getRegion(0));
800 getState<ComponentLoweringState>().addBlockScheduleable(cmpf->getBlock(),
803 rewriter.setInsertionPointToEnd(group.getBodyBlock());
804 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getLeft(), cmpf.getLhs());
805 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getRight(), cmpf.getRhs());
807 bool signalingFlag =
false;
808 switch (cmpf.getPredicate()) {
809 case CmpFPredicate::UGT:
810 case CmpFPredicate::UGE:
811 case CmpFPredicate::ULT:
812 case CmpFPredicate::ULE:
813 case CmpFPredicate::OGT:
814 case CmpFPredicate::OGE:
815 case CmpFPredicate::OLT:
816 case CmpFPredicate::OLE:
817 signalingFlag =
true;
819 case CmpFPredicate::UEQ:
820 case CmpFPredicate::UNE:
821 case CmpFPredicate::OEQ:
822 case CmpFPredicate::ONE:
823 case CmpFPredicate::UNO:
824 case CmpFPredicate::ORD:
825 case CmpFPredicate::AlwaysTrue:
826 case CmpFPredicate::AlwaysFalse:
827 signalingFlag =
false;
833 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getSignaling(),
834 signalingFlag ? c1 : c0);
837 SmallVector<calyx::RegisterOp> inputRegs;
838 for (
const auto &input : info.inputPorts) {
840 switch (input.port) {
842 signal = calyxCmpFOp.getEq();
846 signal = calyxCmpFOp.getGt();
850 signal = calyxCmpFOp.getLt();
853 case Port::Unordered: {
854 signal = calyxCmpFOp.getUnordered();
858 std::string nameSuffix =
859 (input.port == PredicateInfo::InputPorts::Port::Unordered)
862 auto signalReg = createSignalRegister(rewriter, signal, input.invert,
863 nameSuffix, calyxCmpFOp, group);
864 inputRegs.push_back(signalReg);
868 Value outputValue, doneValue;
869 switch (info.logic) {
870 case CombLogic::None: {
872 outputValue = inputRegs[0].getOut();
873 doneValue = inputRegs[0].getOut();
876 case CombLogic::And: {
877 auto outputLibOp = getState<ComponentLoweringState>()
878 .getNewLibraryOpInstance<calyx::AndLibOp>(
879 rewriter, loc, {one, one, one});
880 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
881 inputRegs[0].getOut());
882 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
883 inputRegs[1].getOut());
885 outputValue = outputLibOp.getOut();
888 case CombLogic::Or: {
889 auto outputLibOp = getState<ComponentLoweringState>()
890 .getNewLibraryOpInstance<calyx::OrLibOp>(
891 rewriter, loc, {one, one, one});
892 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
893 inputRegs[0].getOut());
894 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
895 inputRegs[1].getOut());
897 outputValue = outputLibOp.getOut();
902 if (info.logic != CombLogic::None) {
903 auto doneLibOp = getState<ComponentLoweringState>()
904 .getNewLibraryOpInstance<calyx::AndLibOp>(
905 rewriter, loc, {one, one, one});
906 rewriter.create<calyx::AssignOp>(loc, doneLibOp.getLeft(),
907 inputRegs[0].getDone());
908 rewriter.create<calyx::AssignOp>(loc, doneLibOp.getRight(),
909 inputRegs[1].getDone());
910 doneValue = doneLibOp.getOut();
914 rewriter.create<calyx::AssignOp>(loc,
reg.getIn(), outputValue);
915 rewriter.create<calyx::AssignOp>(loc,
reg.getWriteEn(), doneValue);
918 rewriter.create<calyx::AssignOp>(
919 loc, calyxCmpFOp.getGo(), c1,
921 rewriter.create<calyx::GroupDoneOp>(loc,
reg.getDone());
923 cmpf.getResult().replaceAllUsesWith(
reg.getOut());
926 getState<ComponentLoweringState>().registerEvaluatingGroup(outputValue,
928 getState<ComponentLoweringState>().registerEvaluatingGroup(doneValue, group);
929 getState<ComponentLoweringState>().registerEvaluatingGroup(
930 calyxCmpFOp.getLeft(), group);
931 getState<ComponentLoweringState>().registerEvaluatingGroup(
932 calyxCmpFOp.getRight(), group);
937 template <
typename TAllocOp>
939 PatternRewriter &rewriter, TAllocOp allocOp) {
940 rewriter.setInsertionPointToStart(
942 MemRefType memtype = allocOp.getType();
943 SmallVector<int64_t> addrSizes;
944 SmallVector<int64_t> sizes;
945 for (int64_t dim : memtype.getShape()) {
946 sizes.push_back(dim);
951 if (sizes.empty() && addrSizes.empty()) {
953 addrSizes.push_back(1);
955 auto memoryOp = rewriter.create<calyx::SeqMemoryOp>(
957 memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
961 memoryOp->setAttr(
"external",
968 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
969 memref::AllocOp allocOp)
const {
970 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
973 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
974 memref::AllocaOp allocOp)
const {
975 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
978 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
979 scf::YieldOp yieldOp)
const {
980 if (yieldOp.getOperands().empty()) {
982 auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
983 assert(forOp &&
"Empty yieldOps should only be located within ForOps");
988 getState<ComponentLoweringState>().getForLoopIterReg(forOpInterface, 0);
990 Type regWidth = inductionReg.getOut().getType();
992 SmallVector<Type> types(3, regWidth);
993 auto addOp = getState<ComponentLoweringState>()
994 .getNewLibraryOpInstance<calyx::AddLibOp>(
995 rewriter, forOp.getLoc(), types);
997 auto directions = addOp.portDirections();
999 SmallVector<Value, 2> opInputPorts;
1001 for (
auto dir : enumerate(directions)) {
1002 switch (dir.value()) {
1004 opInputPorts.push_back(addOp.getResult(dir.index()));
1008 opOutputPort = addOp.getResult(dir.index());
1015 calyx::ComponentOp componentOp =
1016 getState<ComponentLoweringState>().getComponentOp();
1017 SmallVector<StringRef, 4> groupIdentifier = {
1018 "incr", getState<ComponentLoweringState>().getUniqueName(forOp),
1019 "induction",
"var"};
1020 auto groupOp = calyx::createGroup<calyx::GroupOp>(
1021 rewriter, componentOp, forOp.getLoc(),
1022 llvm::join(groupIdentifier,
"_"));
1023 rewriter.setInsertionPointToEnd(groupOp.getBodyBlock());
1026 Value leftOp = opInputPorts.front();
1027 rewriter.create<calyx::AssignOp>(forOp.getLoc(), leftOp,
1028 inductionReg.getOut());
1030 Value rightOp = opInputPorts.back();
1031 rewriter.create<calyx::AssignOp>(
1032 forOp.getLoc(), rightOp,
1034 regWidth.getIntOrFloatBitWidth(),
1035 forOp.getConstantStep().value().getSExtValue()));
1038 inductionReg, opOutputPort);
1040 getState<ComponentLoweringState>().setForLoopLatchGroup(forOpInterface,
1042 getState<ComponentLoweringState>().registerEvaluatingGroup(opOutputPort,
1047 if (dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
1048 return yieldOp.getOperation()->emitError()
1049 <<
"Currently do not support non-empty yield operations inside for "
1050 "loops. Run --scf-for-to-while before running --scf-to-calyx.";
1053 if (
auto whileOp = dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
1057 getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
1058 rewriter, whileOpInterface,
1059 getState<ComponentLoweringState>().getComponentOp(),
1060 getState<ComponentLoweringState>().getUniqueName(whileOp) +
1062 yieldOp->getOpOperands());
1063 getState<ComponentLoweringState>().setWhileLoopLatchGroup(whileOpInterface,
1068 if (
auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
1069 auto resultRegs = getState<ComponentLoweringState>().getResultRegs(ifOp);
1071 if (yieldOp->getParentRegion() == &ifOp.getThenRegion()) {
1072 auto thenGroup = getState<ComponentLoweringState>().getThenGroup(ifOp);
1073 for (
auto op : enumerate(yieldOp.getOperands())) {
1075 getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
1077 rewriter, thenGroup,
1078 getState<ComponentLoweringState>().getComponentOp(), resultReg,
1080 getState<ComponentLoweringState>().registerEvaluatingGroup(
1081 ifOp.getResult(op.index()), thenGroup);
1085 if (!ifOp.getElseRegion().empty() &&
1086 (yieldOp->getParentRegion() == &ifOp.getElseRegion())) {
1087 auto elseGroup = getState<ComponentLoweringState>().getElseGroup(ifOp);
1088 for (
auto op : enumerate(yieldOp.getOperands())) {
1090 getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
1092 rewriter, elseGroup,
1093 getState<ComponentLoweringState>().getComponentOp(), resultReg,
1095 getState<ComponentLoweringState>().registerEvaluatingGroup(
1096 ifOp.getResult(op.index()), elseGroup);
1103 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1104 BranchOpInterface brOp)
const {
1109 Block *srcBlock = brOp->getBlock();
1110 for (
auto succBlock : enumerate(brOp->getSuccessors())) {
1111 auto succOperands = brOp.getSuccessorOperands(succBlock.index());
1112 if (succOperands.empty())
1115 std::string groupName =
loweringState().blockName(srcBlock) +
"_to_" +
1117 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter, getComponent(),
1118 brOp.getLoc(), groupName);
1120 auto dstBlockArgRegs =
1121 getState<ComponentLoweringState>().getBlockArgRegs(succBlock.value());
1123 for (
auto arg : enumerate(succOperands.getForwardedOperands())) {
1124 auto reg = dstBlockArgRegs[arg.index()];
1127 getState<ComponentLoweringState>().getComponentOp(),
reg,
1132 getState<ComponentLoweringState>().addBlockArgGroup(
1133 srcBlock, succBlock.value(), groupOp);
1140 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1141 ReturnOp retOp)
const {
1142 if (retOp.getNumOperands() == 0)
1145 std::string groupName =
1146 getState<ComponentLoweringState>().getUniqueName(
"ret_assign");
1147 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter, getComponent(),
1148 retOp.getLoc(), groupName);
1149 for (
auto op : enumerate(retOp.getOperands())) {
1150 auto reg = getState<ComponentLoweringState>().getReturnReg(op.index());
1152 rewriter, groupOp, getState<ComponentLoweringState>().getComponentOp(),
1156 getState<ComponentLoweringState>().addBlockScheduleable(retOp->getBlock(),
1161 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1162 arith::ConstantOp constOp)
const {
1163 if (isa<IntegerType>(constOp.getType())) {
1169 hwConstOp->moveAfter(getComponent().getBodyBlock(),
1170 getComponent().getBodyBlock()->begin());
1172 std::string name = getState<ComponentLoweringState>().getUniqueName(
"cst");
1173 auto floatAttr = cast<FloatAttr>(constOp.getValueAttr());
1175 rewriter.getIntegerType(floatAttr.getType().getIntOrFloatBitWidth());
1176 auto calyxConstOp = rewriter.create<calyx::ConstantOp>(
1177 constOp.getLoc(), name, floatAttr, intType);
1178 calyxConstOp->moveAfter(getComponent().getBodyBlock(),
1179 getComponent().getBodyBlock()->begin());
1180 rewriter.replaceAllUsesWith(constOp, calyxConstOp.getOut());
1186 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1188 return buildLibraryOp<calyx::CombGroupOp, calyx::AddLibOp>(rewriter, op);
1190 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1192 return buildLibraryOp<calyx::CombGroupOp, calyx::SubLibOp>(rewriter, op);
1194 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1196 return buildLibraryOp<calyx::CombGroupOp, calyx::RshLibOp>(rewriter, op);
1198 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1200 return buildLibraryOp<calyx::CombGroupOp, calyx::SrshLibOp>(rewriter, op);
1202 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1204 return buildLibraryOp<calyx::CombGroupOp, calyx::LshLibOp>(rewriter, op);
1206 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1208 return buildLibraryOp<calyx::CombGroupOp, calyx::AndLibOp>(rewriter, op);
1210 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1212 return buildLibraryOp<calyx::CombGroupOp, calyx::OrLibOp>(rewriter, op);
1214 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1216 return buildLibraryOp<calyx::CombGroupOp, calyx::XorLibOp>(rewriter, op);
1218 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1219 SelectOp op)
const {
1220 return buildLibraryOp<calyx::CombGroupOp, calyx::MuxLibOp>(rewriter, op);
1223 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1225 switch (op.getPredicate()) {
1226 case CmpIPredicate::eq:
1227 return buildLibraryOp<calyx::CombGroupOp, calyx::EqLibOp>(rewriter, op);
1228 case CmpIPredicate::ne:
1229 return buildLibraryOp<calyx::CombGroupOp, calyx::NeqLibOp>(rewriter, op);
1230 case CmpIPredicate::uge:
1231 return buildLibraryOp<calyx::CombGroupOp, calyx::GeLibOp>(rewriter, op);
1232 case CmpIPredicate::ult:
1233 return buildLibraryOp<calyx::CombGroupOp, calyx::LtLibOp>(rewriter, op);
1234 case CmpIPredicate::ugt:
1235 return buildLibraryOp<calyx::CombGroupOp, calyx::GtLibOp>(rewriter, op);
1236 case CmpIPredicate::ule:
1237 return buildLibraryOp<calyx::CombGroupOp, calyx::LeLibOp>(rewriter, op);
1238 case CmpIPredicate::sge:
1239 return buildLibraryOp<calyx::CombGroupOp, calyx::SgeLibOp>(rewriter, op);
1240 case CmpIPredicate::slt:
1241 return buildLibraryOp<calyx::CombGroupOp, calyx::SltLibOp>(rewriter, op);
1242 case CmpIPredicate::sgt:
1243 return buildLibraryOp<calyx::CombGroupOp, calyx::SgtLibOp>(rewriter, op);
1244 case CmpIPredicate::sle:
1245 return buildLibraryOp<calyx::CombGroupOp, calyx::SleLibOp>(rewriter, op);
1247 llvm_unreachable(
"unsupported comparison predicate");
1249 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1250 TruncIOp op)
const {
1251 return buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
1252 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1254 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1256 return buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
1257 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1260 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1262 return buildLibraryOp<calyx::CombGroupOp, calyx::ExtSILibOp>(
1263 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1266 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1267 IndexCastOp op)
const {
1270 unsigned targetBits = targetType.getIntOrFloatBitWidth();
1271 unsigned sourceBits = sourceType.getIntOrFloatBitWidth();
1272 LogicalResult res = success();
1274 if (targetBits == sourceBits) {
1277 op.getResult().replaceAllUsesWith(op.getOperand());
1280 if (sourceBits > targetBits)
1281 res = buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
1282 rewriter, op, {sourceType}, {targetType});
1284 res = buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
1285 rewriter, op, {sourceType}, {targetType});
1287 rewriter.eraseOp(op);
1291 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1292 scf::WhileOp whileOp)
const {
1296 getState<ComponentLoweringState>().addBlockScheduleable(
1301 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1302 scf::ForOp forOp)
const {
1308 std::optional<uint64_t> bound = scfForOp.
getBound();
1309 if (!bound.has_value()) {
1311 <<
"Loop bound not statically known. Should "
1312 "transform into while loop using `--scf-for-to-while` before "
1313 "running --lower-scf-to-calyx.";
1315 getState<ComponentLoweringState>().addBlockScheduleable(
1323 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1324 scf::IfOp ifOp)
const {
1325 getState<ComponentLoweringState>().addBlockScheduleable(
1330 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1331 scf::ReduceOp reduceOp)
const {
1338 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1339 scf::ParallelOp parOp)
const {
1340 getState<ComponentLoweringState>().addBlockScheduleable(
1345 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1346 CallOp callOp)
const {
1348 calyx::InstanceOp instanceOp =
1349 getState<ComponentLoweringState>().getInstance(instanceName);
1350 SmallVector<Value, 4> outputPorts;
1351 auto portInfos = instanceOp.getReferencedComponent().getPortInfo();
1352 for (
auto [idx, portInfo] : enumerate(portInfos)) {
1354 outputPorts.push_back(instanceOp.getResult(idx));
1358 for (
auto [idx, result] : llvm::enumerate(callOp.getResults()))
1359 rewriter.replaceAllUsesWith(result, outputPorts[idx]);
1363 getState<ComponentLoweringState>().addBlockScheduleable(
1377 using OpRewritePattern::OpRewritePattern;
1380 PatternRewriter &rewriter)
const override {
1382 TypeRange yieldTypes = execOp.getResultTypes();
1386 rewriter.setInsertionPointAfter(execOp);
1387 auto *sinkBlock = rewriter.splitBlock(
1389 execOp.getOperation()->getIterator()->getNextNode()->getIterator());
1390 sinkBlock->addArguments(
1392 SmallVector<Location, 4>(yieldTypes.size(), rewriter.getUnknownLoc()));
1393 for (
auto res : enumerate(execOp.getResults()))
1394 res.value().replaceAllUsesWith(sinkBlock->getArgument(res.index()));
1398 make_early_inc_range(execOp.getRegion().getOps<scf::YieldOp>())) {
1399 rewriter.setInsertionPointAfter(yieldOp);
1400 rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, sinkBlock,
1401 yieldOp.getOperands());
1405 auto *preBlock = execOp->getBlock();
1406 auto *execOpEntryBlock = &execOp.getRegion().front();
1407 auto *postBlock = execOp->getBlock()->splitBlock(execOp);
1408 rewriter.inlineRegionBefore(execOp.getRegion(), postBlock);
1409 rewriter.mergeBlocks(postBlock, preBlock);
1410 rewriter.eraseOp(execOp);
1413 rewriter.mergeBlocks(execOpEntryBlock, preBlock);
1421 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1425 PatternRewriter &rewriter)
const override {
1428 DenseMap<Value, unsigned> funcOpArgRewrites;
1432 DenseMap<unsigned, unsigned> funcOpResultMapping;
1440 DenseMap<Value, std::pair<unsigned, unsigned>> extMemoryCompPortIndices;
1444 SmallVector<calyx::PortInfo> inPorts, outPorts;
1445 FunctionType funcType = funcOp.getFunctionType();
1446 for (
auto arg : enumerate(funcOp.getArguments())) {
1447 if (!isa<MemRefType>(arg.value().getType())) {
1450 if (
auto portNameAttr = funcOp.getArgAttrOfType<StringAttr>(
1452 inName = portNameAttr.str();
1454 inName =
"in" + std::to_string(arg.index());
1455 funcOpArgRewrites[arg.value()] = inPorts.size();
1457 rewriter.getStringAttr(inName),
1463 for (
auto res : enumerate(funcType.getResults())) {
1464 std::string resName;
1465 if (
auto portNameAttr = funcOp.getResultAttrOfType<StringAttr>(
1467 resName = portNameAttr.str();
1469 resName =
"out" + std::to_string(res.index());
1470 funcOpResultMapping[res.index()] = outPorts.size();
1473 rewriter.getStringAttr(resName),
1480 auto ports = inPorts;
1481 llvm::append_range(ports, outPorts);
1485 auto compOp = rewriter.create<calyx::ComponentOp>(
1486 funcOp.getLoc(), rewriter.getStringAttr(funcOp.getSymName()), ports);
1488 std::string funcName =
"func_" + funcOp.getSymName().str();
1489 rewriter.modifyOpInPlace(funcOp, [&]() { funcOp.setSymName(funcName); });
1493 if (compOp.getName() ==
loweringState().getTopLevelFunction())
1494 compOp->setAttr(
"toplevel", rewriter.getUnitAttr());
1497 functionMapping[funcOp] = compOp;
1501 unsigned extMemCounter = 0;
1502 for (
auto arg : enumerate(funcOp.getArguments())) {
1503 if (isa<MemRefType>(arg.value().getType())) {
1504 std::string memName =
1505 llvm::join_items(
"_",
"arg_mem", std::to_string(extMemCounter++));
1507 rewriter.setInsertionPointToStart(compOp.getBodyBlock());
1508 MemRefType memtype = cast<MemRefType>(arg.value().getType());
1509 SmallVector<int64_t> addrSizes;
1510 SmallVector<int64_t> sizes;
1511 for (int64_t dim : memtype.getShape()) {
1512 sizes.push_back(dim);
1515 if (sizes.empty() && addrSizes.empty()) {
1517 addrSizes.push_back(1);
1519 auto memOp = rewriter.create<calyx::SeqMemoryOp>(
1520 funcOp.getLoc(), memName,
1521 memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
1524 compState->registerMemoryInterface(arg.value(),
1530 for (
auto &mapping : funcOpArgRewrites)
1531 mapping.getFirst().replaceAllUsesWith(
1532 compOp.getArgument(mapping.getSecond()));
1543 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1547 PatternRewriter &rewriter)
const override {
1548 LogicalResult res = success();
1549 funcOp.walk([&](Operation *op) {
1551 if (!isa<scf::WhileOp>(op))
1552 return WalkResult::advance();
1554 auto scfWhileOp = cast<scf::WhileOp>(op);
1557 getState<ComponentLoweringState>().setUniqueName(whileOp.
getOperation(),
1567 enumerate(scfWhileOp.getBefore().front().getArguments())) {
1568 auto condOp = scfWhileOp.getConditionOp().getArgs()[barg.index()];
1569 if (barg.value() != condOp) {
1573 <<
"do-while loops not supported; expected iter-args to "
1574 "remain untransformed in the 'before' region of the "
1576 return WalkResult::interrupt();
1585 for (
auto arg : enumerate(whileOp.
getBodyArgs())) {
1586 std::string name = getState<ComponentLoweringState>()
1589 "_arg" + std::to_string(arg.index());
1592 arg.value().getType().getIntOrFloatBitWidth(), name);
1593 getState<ComponentLoweringState>().addWhileLoopIterReg(whileOp,
reg,
1595 arg.value().replaceAllUsesWith(
reg.getOut());
1599 ->getArgument(arg.index())
1600 .replaceAllUsesWith(
reg.getOut());
1604 SmallVector<calyx::GroupOp> initGroups;
1605 auto numOperands = whileOp.
getOperation()->getNumOperands();
1606 for (
size_t i = 0; i < numOperands; ++i) {
1608 getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
1610 getState<ComponentLoweringState>().getComponentOp(),
1611 getState<ComponentLoweringState>().getUniqueName(
1613 "_init_" + std::to_string(i),
1615 initGroups.push_back(initGroupOp);
1618 getState<ComponentLoweringState>().setWhileLoopInitGroups(whileOp,
1621 return WalkResult::advance();
1631 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1635 PatternRewriter &rewriter)
const override {
1636 LogicalResult res = success();
1637 funcOp.walk([&](Operation *op) {
1639 if (!isa<scf::ForOp>(op))
1640 return WalkResult::advance();
1642 auto scfForOp = cast<scf::ForOp>(op);
1645 getState<ComponentLoweringState>().setUniqueName(forOp.
getOperation(),
1650 auto inductionVar = forOp.
getOperation().getInductionVar();
1651 SmallVector<std::string, 3> inductionVarIdentifiers = {
1652 getState<ComponentLoweringState>()
1655 "induction",
"var"};
1656 std::string name = llvm::join(inductionVarIdentifiers,
"_");
1659 inductionVar.getType().getIntOrFloatBitWidth(), name);
1660 getState<ComponentLoweringState>().addForLoopIterReg(forOp,
reg, 0);
1661 inductionVar.replaceAllUsesWith(
reg.getOut());
1664 calyx::ComponentOp componentOp =
1665 getState<ComponentLoweringState>().getComponentOp();
1666 SmallVector<calyx::GroupOp> initGroups;
1667 SmallVector<std::string, 4> groupIdentifiers = {
1669 getState<ComponentLoweringState>()
1672 "induction",
"var"};
1673 std::string groupName = llvm::join(groupIdentifiers,
"_");
1674 auto groupOp = calyx::createGroup<calyx::GroupOp>(
1675 rewriter, componentOp, forOp.
getLoc(), groupName);
1678 initGroups.push_back(groupOp);
1679 getState<ComponentLoweringState>().setForLoopInitGroups(forOp,
1682 return WalkResult::advance();
1689 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1693 PatternRewriter &rewriter)
const override {
1694 LogicalResult res = success();
1695 funcOp.walk([&](Operation *op) {
1696 if (!isa<scf::IfOp>(op))
1697 return WalkResult::advance();
1699 auto scfIfOp = cast<scf::IfOp>(op);
1701 calyx::ComponentOp componentOp =
1702 getState<ComponentLoweringState>().getComponentOp();
1704 std::string thenGroupName =
1705 getState<ComponentLoweringState>().getUniqueName(
"then_br");
1706 auto thenGroupOp = calyx::createGroup<calyx::GroupOp>(
1707 rewriter, componentOp, scfIfOp.getLoc(), thenGroupName);
1708 getState<ComponentLoweringState>().setThenGroup(scfIfOp, thenGroupOp);
1710 if (!scfIfOp.getElseRegion().empty()) {
1711 std::string elseGroupName =
1712 getState<ComponentLoweringState>().getUniqueName(
"else_br");
1713 auto elseGroupOp = calyx::createGroup<calyx::GroupOp>(
1714 rewriter, componentOp, scfIfOp.getLoc(), elseGroupName);
1715 getState<ComponentLoweringState>().setElseGroup(scfIfOp, elseGroupOp);
1718 for (
auto ifOpRes : scfIfOp.getResults()) {
1720 scfIfOp.getLoc(), rewriter, getComponent(),
1721 ifOpRes.getType().getIntOrFloatBitWidth(),
1722 getState<ComponentLoweringState>().getUniqueName(
"if_res"));
1723 getState<ComponentLoweringState>().setResultRegs(
1724 scfIfOp,
reg, ifOpRes.getResultNumber());
1727 return WalkResult::advance();
1734 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1738 PatternRewriter &rewriter)
const override {
1739 WalkResult walkResult = funcOp.walk([&](scf::ParallelOp scfParOp) {
1740 if (!scfParOp.getResults().empty()) {
1742 "Reduce operations in scf.parallel is not supported yet");
1743 return WalkResult::interrupt();
1746 if (failed(partialEval(rewriter, scfParOp)))
1747 return WalkResult::interrupt();
1749 return WalkResult::advance();
1752 return walkResult.wasInterrupted() ? failure() : success();
1759 scf::ParallelOp scfParOp)
const {
1760 assert(scfParOp.getLoopSteps() &&
"Parallel loop must have steps");
1761 auto *body = scfParOp.getBody();
1762 auto parOpIVs = scfParOp.getInductionVars();
1763 auto steps = scfParOp.getStep();
1764 auto lowerBounds = scfParOp.getLowerBound();
1765 auto upperBounds = scfParOp.getUpperBound();
1766 rewriter.setInsertionPointAfter(scfParOp);
1767 scf::ParallelOp newParOp = scfParOp.cloneWithoutRegions();
1768 auto loc = newParOp.getLoc();
1769 rewriter.insert(newParOp);
1770 OpBuilder insideBuilder(newParOp);
1771 Block *currBlock =
nullptr;
1772 auto ®ion = newParOp.getRegion();
1773 IRMapping operandMap;
1776 SmallVector<int64_t> lbVals, ubVals, stepVals;
1777 for (
auto lb : lowerBounds) {
1778 auto lbOp = lb.getDefiningOp<arith::ConstantIndexOp>();
1780 "Lower bound must be a statically computable constant index");
1781 lbVals.push_back(lbOp.value());
1783 for (
auto ub : upperBounds) {
1784 auto ubOp = ub.getDefiningOp<arith::ConstantIndexOp>();
1786 "Upper bound must be a statically computable constant index");
1787 ubVals.push_back(ubOp.value());
1789 for (
auto step : steps) {
1790 auto stepOp = step.getDefiningOp<arith::ConstantIndexOp>();
1791 assert(stepOp &&
"Step must be a statically computable constant index");
1792 stepVals.push_back(stepOp.value());
1796 SmallVector<int64_t> indices = lbVals;
1800 currBlock = ®ion.emplaceBlock();
1801 insideBuilder.setInsertionPointToEnd(currBlock);
1804 for (
unsigned i = 0; i < indices.size(); ++i) {
1806 insideBuilder.create<arith::ConstantIndexOp>(loc, indices[i]);
1807 operandMap.map(parOpIVs[i], ivConstant);
1810 for (
auto it = body->begin(); it != std::prev(body->end()); ++it)
1811 insideBuilder.clone(*it, operandMap);
1815 for (
int dim = indices.size() - 1; dim >= 0; --dim) {
1816 indices[dim] += stepVals[dim];
1817 if (indices[dim] < ubVals[dim])
1819 indices[dim] = lbVals[dim];
1828 rewriter.replaceOp(scfParOp, newParOp);
1839 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1843 PatternRewriter &rewriter)
const override {
1844 auto *entryBlock = &funcOp.getBlocks().front();
1845 rewriter.setInsertionPointToStart(
1847 auto topLevelSeqOp = rewriter.create<calyx::SeqOp>(funcOp.getLoc());
1848 DenseSet<Block *> path;
1849 return buildCFGControl(path, rewriter, topLevelSeqOp.getBodyBlock(),
1850 nullptr, entryBlock);
1857 const DenseSet<Block *> &path,
1858 mlir::Block *parentCtrlBlock,
1859 mlir::Block *block)
const {
1860 auto compBlockScheduleables =
1861 getState<ComponentLoweringState>().getBlockScheduleables(block);
1862 auto loc = block->front().getLoc();
1864 if (compBlockScheduleables.size() > 1 &&
1865 !isa<scf::ParallelOp>(block->getParentOp())) {
1866 auto seqOp = rewriter.create<calyx::SeqOp>(loc);
1867 parentCtrlBlock = seqOp.getBodyBlock();
1870 for (
auto &group : compBlockScheduleables) {
1871 rewriter.setInsertionPointToEnd(parentCtrlBlock);
1872 if (
auto groupPtr = std::get_if<calyx::GroupOp>(&group); groupPtr) {
1873 rewriter.create<calyx::EnableOp>(groupPtr->getLoc(),
1874 groupPtr->getSymName());
1875 }
else if (
auto whileSchedPtr = std::get_if<WhileScheduleable>(&group);
1877 auto &whileOp = whileSchedPtr->whileOp;
1879 auto whileCtrlOp = buildWhileCtrlOp(
1881 getState<ComponentLoweringState>().getWhileLoopInitGroups(whileOp),
1883 rewriter.setInsertionPointToEnd(whileCtrlOp.getBodyBlock());
1885 rewriter.create<calyx::SeqOp>(whileOp.getOperation()->getLoc());
1886 auto *whileBodyOpBlock = whileBodyOp.getBodyBlock();
1890 if (LogicalResult result =
1891 buildCFGControl(path, rewriter, whileBodyOpBlock, block,
1892 whileOp.getBodyBlock());
1897 rewriter.setInsertionPointToEnd(whileBodyOpBlock);
1898 calyx::GroupOp whileLatchGroup =
1899 getState<ComponentLoweringState>().getWhileLoopLatchGroup(whileOp);
1900 rewriter.create<calyx::EnableOp>(whileLatchGroup.getLoc(),
1901 whileLatchGroup.getName());
1902 }
else if (
auto *parSchedPtr = std::get_if<ParScheduleable>(&group)) {
1903 auto parOp = parSchedPtr->parOp;
1904 auto calyxParOp = rewriter.create<calyx::ParOp>(parOp.getLoc());
1905 for (
auto &innerBlock : parOp.getRegion().getBlocks()) {
1906 rewriter.setInsertionPointToEnd(calyxParOp.getBodyBlock());
1907 auto seqOp = rewriter.create<calyx::SeqOp>(parOp.getLoc());
1908 rewriter.setInsertionPointToEnd(seqOp.getBodyBlock());
1909 if (LogicalResult res = scheduleBasicBlock(
1910 rewriter, path, seqOp.getBodyBlock(), &innerBlock);
1914 }
else if (
auto *forSchedPtr = std::get_if<ForScheduleable>(&group);
1916 auto forOp = forSchedPtr->forOp;
1918 auto forCtrlOp = buildForCtrlOp(
1920 getState<ComponentLoweringState>().getForLoopInitGroups(forOp),
1921 forSchedPtr->bound, rewriter);
1922 rewriter.setInsertionPointToEnd(forCtrlOp.getBodyBlock());
1924 rewriter.create<calyx::SeqOp>(forOp.getOperation()->getLoc());
1925 auto *forBodyOpBlock = forBodyOp.getBodyBlock();
1928 if (LogicalResult res = buildCFGControl(path, rewriter, forBodyOpBlock,
1929 block, forOp.getBodyBlock());
1934 rewriter.setInsertionPointToEnd(forBodyOpBlock);
1935 calyx::GroupOp forLatchGroup =
1936 getState<ComponentLoweringState>().getForLoopLatchGroup(forOp);
1937 rewriter.create<calyx::EnableOp>(forLatchGroup.getLoc(),
1938 forLatchGroup.getName());
1939 }
else if (
auto *ifSchedPtr = std::get_if<IfScheduleable>(&group);
1941 auto ifOp = ifSchedPtr->ifOp;
1943 Location loc = ifOp->getLoc();
1945 auto cond = ifOp.getCondition();
1946 auto condGroup = getState<ComponentLoweringState>()
1947 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
1952 bool initElse = !ifOp.getElseRegion().empty();
1953 auto ifCtrlOp = rewriter.create<calyx::IfOp>(
1954 loc, cond, symbolAttr, initElse);
1956 rewriter.setInsertionPointToEnd(ifCtrlOp.getBodyBlock());
1959 rewriter.create<calyx::SeqOp>(ifOp.getThenRegion().getLoc());
1960 auto *thenSeqOpBlock = thenSeqOp.getBodyBlock();
1962 auto *thenBlock = &ifOp.getThenRegion().front();
1963 LogicalResult res = buildCFGControl(path, rewriter, thenSeqOpBlock,
1968 rewriter.setInsertionPointToEnd(thenSeqOpBlock);
1969 calyx::GroupOp thenGroup =
1970 getState<ComponentLoweringState>().getThenGroup(ifOp);
1971 rewriter.create<calyx::EnableOp>(thenGroup.getLoc(),
1972 thenGroup.getName());
1974 if (!ifOp.getElseRegion().empty()) {
1975 rewriter.setInsertionPointToEnd(ifCtrlOp.getElseBody());
1978 rewriter.create<calyx::SeqOp>(ifOp.getElseRegion().getLoc());
1979 auto *elseSeqOpBlock = elseSeqOp.getBodyBlock();
1981 auto *elseBlock = &ifOp.getElseRegion().front();
1982 res = buildCFGControl(path, rewriter, elseSeqOpBlock,
1987 rewriter.setInsertionPointToEnd(elseSeqOpBlock);
1988 calyx::GroupOp elseGroup =
1989 getState<ComponentLoweringState>().getElseGroup(ifOp);
1990 rewriter.create<calyx::EnableOp>(elseGroup.getLoc(),
1991 elseGroup.getName());
1993 }
else if (
auto *callSchedPtr = std::get_if<CallScheduleable>(&group)) {
1994 auto instanceOp = callSchedPtr->instanceOp;
1995 OpBuilder::InsertionGuard g(rewriter);
1996 auto callBody = rewriter.create<calyx::SeqOp>(instanceOp.getLoc());
1997 rewriter.setInsertionPointToStart(callBody.getBodyBlock());
1998 std::string initGroupName =
"init_" + instanceOp.getSymName().str();
1999 rewriter.create<calyx::EnableOp>(instanceOp.getLoc(), initGroupName);
2001 auto callee = callSchedPtr->callOp.getCallee();
2002 auto *calleeOp = SymbolTable::lookupNearestSymbolFrom(
2003 callSchedPtr->callOp.getOperation()->getParentOp(),
2005 FuncOp calleeFunc = dyn_cast_or_null<FuncOp>(calleeOp);
2007 auto instanceOpComp =
2008 llvm::cast<calyx::ComponentOp>(instanceOp.getReferencedComponent());
2009 auto *instanceOpLoweringState =
2012 SmallVector<Value, 4> instancePorts;
2013 SmallVector<Value, 4> inputPorts;
2014 SmallVector<Attribute, 4> refCells;
2015 for (
auto operandEnum : enumerate(callSchedPtr->callOp.getOperands())) {
2016 auto operand = operandEnum.value();
2017 auto index = operandEnum.index();
2018 if (!isa<MemRefType>(operand.getType())) {
2019 inputPorts.push_back(operand);
2023 auto memOpName = getState<ComponentLoweringState>()
2024 .getMemoryInterface(operand)
2026 auto memOpNameAttr =
2028 Value argI = calleeFunc.getArgument(index);
2029 if (isa<MemRefType>(argI.getType())) {
2030 NamedAttrList namedAttrList;
2031 namedAttrList.append(
2032 rewriter.getStringAttr(
2033 instanceOpLoweringState->getMemoryInterface(argI)
2040 llvm::copy(instanceOp.getResults().take_front(inputPorts.size()),
2041 std::back_inserter(instancePorts));
2043 ArrayAttr refCellsAttr =
2046 rewriter.create<calyx::InvokeOp>(
2047 instanceOp.getLoc(), instanceOp.getSymName(), instancePorts,
2048 inputPorts, refCellsAttr,
ArrayAttr::get(rewriter.getContext(), {}),
2051 llvm_unreachable(
"Unknown scheduleable");
2062 const DenseSet<Block *> &path, Location loc,
2063 Block *from, Block *to,
2064 Block *parentCtrlBlock)
const {
2067 rewriter.setInsertionPointToEnd(parentCtrlBlock);
2068 auto preSeqOp = rewriter.create<calyx::SeqOp>(loc);
2069 rewriter.setInsertionPointToEnd(preSeqOp.getBodyBlock());
2071 getState<ComponentLoweringState>().getBlockArgGroups(from, to))
2072 rewriter.create<calyx::EnableOp>(barg.getLoc(), barg.getSymName());
2074 return buildCFGControl(path, rewriter, parentCtrlBlock, from, to);
2078 PatternRewriter &rewriter,
2079 mlir::Block *parentCtrlBlock,
2080 mlir::Block *preBlock,
2081 mlir::Block *block)
const {
2082 if (path.count(block) != 0)
2083 return preBlock->getTerminator()->emitError()
2084 <<
"CFG backedge detected. Loops must be raised to 'scf.while' or "
2085 "'scf.for' operations.";
2087 rewriter.setInsertionPointToEnd(parentCtrlBlock);
2088 LogicalResult bbSchedResult =
2089 scheduleBasicBlock(rewriter, path, parentCtrlBlock, block);
2090 if (bbSchedResult.failed())
2091 return bbSchedResult;
2094 auto successors = block->getSuccessors();
2095 auto nSuccessors = successors.size();
2096 if (nSuccessors > 0) {
2097 auto brOp = dyn_cast<BranchOpInterface>(block->getTerminator());
2099 if (nSuccessors > 1) {
2103 assert(nSuccessors == 2 &&
2104 "only conditional branches supported for now...");
2106 auto cond = brOp->getOperand(0);
2107 auto condGroup = getState<ComponentLoweringState>()
2108 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
2112 auto ifOp = rewriter.create<calyx::IfOp>(
2113 brOp->getLoc(), cond, symbolAttr,
true);
2114 rewriter.setInsertionPointToStart(ifOp.getThenBody());
2115 auto thenSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
2116 rewriter.setInsertionPointToStart(ifOp.getElseBody());
2117 auto elseSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
2119 bool trueBrSchedSuccess =
2120 schedulePath(rewriter, path, brOp.getLoc(), block, successors[0],
2121 thenSeqOp.getBodyBlock())
2123 bool falseBrSchedSuccess =
true;
2124 if (trueBrSchedSuccess) {
2125 falseBrSchedSuccess =
2126 schedulePath(rewriter, path, brOp.getLoc(), block, successors[1],
2127 elseSeqOp.getBodyBlock())
2131 return success(trueBrSchedSuccess && falseBrSchedSuccess);
2134 return schedulePath(rewriter, path, brOp.getLoc(), block,
2135 successors.front(), parentCtrlBlock);
2145 const SmallVector<calyx::GroupOp> &initGroups)
const {
2146 PatternRewriter::InsertionGuard g(rewriter);
2147 auto parOp = rewriter.create<calyx::ParOp>(loc);
2148 rewriter.setInsertionPointToStart(parOp.getBodyBlock());
2149 for (calyx::GroupOp group : initGroups)
2150 rewriter.create<calyx::EnableOp>(group.getLoc(), group.getName());
2154 SmallVector<calyx::GroupOp> initGroups,
2155 PatternRewriter &rewriter)
const {
2156 Location loc = whileOp.
getLoc();
2159 insertParInitGroups(rewriter, loc, initGroups);
2163 auto condGroup = getState<ComponentLoweringState>()
2164 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
2167 return rewriter.create<calyx::WhileOp>(loc, cond, symbolAttr);
2171 SmallVector<calyx::GroupOp>
const &initGroups,
2173 PatternRewriter &rewriter)
const {
2174 Location loc = forOp.
getLoc();
2177 insertParInitGroups(rewriter, loc, initGroups);
2180 return rewriter.create<calyx::RepeatOp>(loc, bound);
2187 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2190 PatternRewriter &)
const override {
2191 funcOp.walk([&](scf::IfOp op) {
2192 for (
auto res : getState<ComponentLoweringState>().getResultRegs(op))
2193 op.getOperation()->getResults()[res.first].replaceAllUsesWith(
2194 res.second.getOut());
2197 funcOp.walk([&](scf::WhileOp op) {
2206 getState<ComponentLoweringState>().getWhileLoopIterRegs(whileOp))
2207 whileOp.
getOperation()->getResults()[res.first].replaceAllUsesWith(
2208 res.second.getOut());
2211 funcOp.walk([&](memref::LoadOp loadOp) {
2217 loadOp.getResult().replaceAllUsesWith(
2218 getState<ComponentLoweringState>()
2219 .getMemoryInterface(loadOp.getMemref())
2230 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2233 PatternRewriter &rewriter)
const override {
2234 rewriter.eraseOp(funcOp);
2240 PatternRewriter &rewriter)
const override {
2254 class SCFToCalyxPass :
public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
2258 void runOnOperation()
override;
2260 LogicalResult setTopLevelFunction(mlir::ModuleOp moduleOp,
2261 std::string &topLevelFunction) {
2262 if (!topLevelFunctionOpt.empty()) {
2263 if (SymbolTable::lookupSymbolIn(moduleOp, topLevelFunctionOpt) ==
2265 moduleOp.emitError() <<
"Top level function '" << topLevelFunctionOpt
2266 <<
"' not found in module.";
2269 topLevelFunction = topLevelFunctionOpt;
2273 auto funcOps = moduleOp.getOps<FuncOp>();
2274 if (std::distance(funcOps.begin(), funcOps.end()) == 1)
2275 topLevelFunction = (*funcOps.begin()).getSymName().str();
2277 moduleOp.emitError()
2278 <<
"Module contains multiple functions, but no top level "
2279 "function was set. Please see --top-level-function";
2284 return createOptNewTopLevelFn(moduleOp, topLevelFunction);
2287 struct LoweringPattern {
2288 enum class Strategy { Once, Greedy };
2297 LogicalResult labelEntryPoint(StringRef topLevelFunction) {
2301 using OpRewritePattern::OpRewritePattern;
2302 LogicalResult matchAndRewrite(mlir::ModuleOp,
2303 PatternRewriter &)
const override {
2308 ConversionTarget target(getContext());
2309 target.addLegalDialect<calyx::CalyxDialect>();
2310 target.addLegalDialect<scf::SCFDialect>();
2311 target.addIllegalDialect<hw::HWDialect>();
2312 target.addIllegalDialect<comb::CombDialect>();
2315 target.addIllegalDialect<FuncDialect>();
2316 target.addIllegalDialect<ArithDialect>();
2317 target.addLegalOp<AddIOp, SelectOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp,
2318 ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
2319 CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
2320 RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
2321 ExtSIOp, CallOp, AddFOp, MulFOp, CmpFOp>();
2323 RewritePatternSet legalizePatterns(&getContext());
2324 legalizePatterns.add<DummyPattern>(&getContext());
2325 DenseSet<Operation *> legalizedOps;
2326 if (applyPartialConversion(getOperation(), target,
2327 std::move(legalizePatterns))
2338 template <
typename TPattern,
typename... PatternArgs>
2339 void addOncePattern(SmallVectorImpl<LoweringPattern> &
patterns,
2340 PatternArgs &&...args) {
2341 RewritePatternSet ps(&getContext());
2344 LoweringPattern{std::move(ps), LoweringPattern::Strategy::Once});
2347 template <
typename TPattern,
typename... PatternArgs>
2348 void addGreedyPattern(SmallVectorImpl<LoweringPattern> &
patterns,
2349 PatternArgs &&...args) {
2350 RewritePatternSet ps(&getContext());
2351 ps.add<TPattern>(&getContext(), args...);
2353 LoweringPattern{std::move(ps), LoweringPattern::Strategy::Greedy});
2356 LogicalResult runPartialPattern(RewritePatternSet &
pattern,
bool runOnce) {
2358 "Should only apply 1 partial lowering pattern at once");
2364 GreedyRewriteConfig config;
2365 config.enableRegionSimplification =
2366 mlir::GreedySimplifyRegionLevel::Disabled;
2368 config.maxIterations = 1;
2373 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(
pattern),
2383 FuncOp createNewTopLevelFn(ModuleOp moduleOp, std::string &baseName) {
2384 std::string newName =
"main";
2386 if (
auto *existingMainOp = SymbolTable::lookupSymbolIn(moduleOp, newName)) {
2387 auto existingMainFunc = dyn_cast<FuncOp>(existingMainOp);
2388 if (existingMainFunc ==
nullptr) {
2389 moduleOp.emitError() <<
"Symbol 'main' exists but is not a function";
2392 unsigned counter = 0;
2393 std::string newOldName = baseName;
2394 while (SymbolTable::lookupSymbolIn(moduleOp, newOldName))
2395 newOldName = llvm::join_items(
"_", baseName, std::to_string(++counter));
2396 existingMainFunc.setName(newOldName);
2397 if (baseName ==
"main")
2398 baseName = newOldName;
2402 OpBuilder builder(moduleOp.getContext());
2403 builder.setInsertionPointToStart(moduleOp.getBody());
2405 FunctionType funcType = builder.getFunctionType({}, {});
2408 builder.create<FuncOp>(moduleOp.getLoc(), newName, funcType))
2418 void insertCallFromNewTopLevel(OpBuilder &builder, FuncOp caller,
2420 if (caller.getBody().empty()) {
2421 caller.addEntryBlock();
2424 Block *callerEntryBlock = &caller.getBody().front();
2425 builder.setInsertionPointToStart(callerEntryBlock);
2429 SmallVector<Type, 4> nonMemRefCalleeArgTypes;
2430 for (
auto arg : callee.getArguments()) {
2431 if (!isa<MemRefType>(arg.getType())) {
2432 nonMemRefCalleeArgTypes.push_back(arg.getType());
2436 for (Type type : nonMemRefCalleeArgTypes) {
2437 callerEntryBlock->addArgument(type, caller.getLoc());
2440 FunctionType callerFnType = caller.getFunctionType();
2441 SmallVector<Type, 4> updatedCallerArgTypes(
2442 caller.getFunctionType().getInputs());
2443 updatedCallerArgTypes.append(nonMemRefCalleeArgTypes.begin(),
2444 nonMemRefCalleeArgTypes.end());
2446 callerFnType.getResults()));
2448 Block *calleeFnBody = &callee.getBody().front();
2449 unsigned originalCalleeArgNum = callee.getArguments().size();
2451 SmallVector<Value, 4> extraMemRefArgs;
2452 SmallVector<Type, 4> extraMemRefArgTypes;
2453 SmallVector<Value, 4> extraMemRefOperands;
2454 SmallVector<Operation *, 4> opsToModify;
2455 for (
auto &op : callee.getBody().getOps()) {
2456 if (isa<memref::AllocaOp, memref::AllocOp, memref::GetGlobalOp>(op))
2457 opsToModify.push_back(&op);
2462 builder.setInsertionPointToEnd(callerEntryBlock);
2463 for (
auto *op : opsToModify) {
2466 TypeSwitch<Operation *>(op)
2467 .Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
2468 newOpRes = builder.create<memref::AllocaOp>(callee.getLoc(),
2469 allocaOp.getType());
2471 .Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
2472 newOpRes = builder.create<memref::AllocOp>(callee.getLoc(),
2475 .Case<memref::GetGlobalOp>([&](memref::GetGlobalOp getGlobalOp) {
2476 newOpRes = builder.create<memref::GetGlobalOp>(
2477 caller.getLoc(), getGlobalOp.getType(), getGlobalOp.getName());
2479 .Default([&](Operation *defaultOp) {
2480 llvm::report_fatal_error(
"Unsupported operation in TypeSwitch");
2482 extraMemRefOperands.push_back(newOpRes);
2484 calleeFnBody->addArgument(newOpRes.getType(), callee.getLoc());
2485 BlockArgument newBodyArg = calleeFnBody->getArguments().back();
2486 op->getResult(0).replaceAllUsesWith(newBodyArg);
2488 extraMemRefArgs.push_back(newBodyArg);
2489 extraMemRefArgTypes.push_back(newBodyArg.getType());
2492 SmallVector<Type, 4> updatedCalleeArgTypes(
2493 callee.getFunctionType().getInputs());
2494 updatedCalleeArgTypes.append(extraMemRefArgTypes.begin(),
2495 extraMemRefArgTypes.end());
2497 callee.getFunctionType().getResults()));
2499 unsigned otherArgsCount = 0;
2500 SmallVector<Value, 4> calleeArgFnOperands;
2501 builder.setInsertionPointToStart(callerEntryBlock);
2502 for (
auto arg : callee.getArguments().take_front(originalCalleeArgNum)) {
2503 if (isa<MemRefType>(arg.getType())) {
2504 auto memrefType = cast<MemRefType>(arg.getType());
2506 builder.create<memref::AllocOp>(callee.getLoc(), memrefType);
2507 calleeArgFnOperands.push_back(allocOp);
2509 auto callerArg = callerEntryBlock->getArgument(otherArgsCount++);
2510 calleeArgFnOperands.push_back(callerArg);
2514 SmallVector<Value, 4> fnOperands;
2515 fnOperands.append(calleeArgFnOperands.begin(), calleeArgFnOperands.end());
2516 fnOperands.append(extraMemRefOperands.begin(), extraMemRefOperands.end());
2519 auto resultTypes = callee.getResultTypes();
2521 builder.setInsertionPointToEnd(callerEntryBlock);
2522 builder.create<CallOp>(caller.getLoc(), calleeName, resultTypes,
2529 LogicalResult createOptNewTopLevelFn(ModuleOp moduleOp,
2530 std::string &topLevelFunction) {
2531 auto hasMemrefArguments = [](FuncOp func) {
2533 func.getArguments().begin(), func.getArguments().end(),
2534 [](BlockArgument arg) { return isa<MemRefType>(arg.getType()); });
2540 auto funcOps = moduleOp.getOps<FuncOp>();
2541 bool hasMemrefArgsInTopLevel =
2542 std::any_of(funcOps.begin(), funcOps.end(), [&](
auto funcOp) {
2543 return funcOp.getName() == topLevelFunction &&
2544 hasMemrefArguments(funcOp);
2547 if (hasMemrefArgsInTopLevel) {
2548 auto newTopLevelFunc = createNewTopLevelFn(moduleOp, topLevelFunction);
2549 if (!newTopLevelFunc)
2552 OpBuilder builder(moduleOp.getContext());
2553 Operation *oldTopLevelFuncOp =
2554 SymbolTable::lookupSymbolIn(moduleOp, topLevelFunction);
2555 if (
auto oldTopLevelFunc = dyn_cast<FuncOp>(oldTopLevelFuncOp))
2556 insertCallFromNewTopLevel(builder, newTopLevelFunc, oldTopLevelFunc);
2558 moduleOp.emitOpError(
"Original top-level function not found!");
2561 topLevelFunction =
"main";
2568 void SCFToCalyxPass::runOnOperation() {
2573 std::string topLevelFunction;
2574 if (failed(setTopLevelFunction(getOperation(), topLevelFunction))) {
2575 signalPassFailure();
2580 if (failed(labelEntryPoint(topLevelFunction))) {
2581 signalPassFailure();
2584 loweringState = std::make_shared<calyx::CalyxLoweringState>(getOperation(),
2595 DenseMap<FuncOp, calyx::ComponentOp> funcMap;
2596 SmallVector<LoweringPattern, 8> loweringPatterns;
2600 addOncePattern<FuncOpConversion>(loweringPatterns, patternState, funcMap,
2604 addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);
2606 addOncePattern<BuildParGroups>(loweringPatterns, patternState, funcMap,
2610 addOncePattern<calyx::ConvertIndexTypes>(loweringPatterns, patternState,
2614 addOncePattern<calyx::BuildBasicBlockRegs>(loweringPatterns, patternState,
2617 addOncePattern<calyx::BuildCallInstance>(loweringPatterns, patternState,
2621 addOncePattern<calyx::BuildReturnRegs>(loweringPatterns, patternState,
2627 addOncePattern<BuildWhileGroups>(loweringPatterns, patternState, funcMap,
2633 addOncePattern<BuildForGroups>(loweringPatterns, patternState, funcMap,
2636 addOncePattern<BuildIfGroups>(loweringPatterns, patternState, funcMap,
2646 addOncePattern<BuildOpGroups>(loweringPatterns, patternState, funcMap,
2652 addOncePattern<BuildControl>(loweringPatterns, patternState, funcMap,
2657 addOncePattern<calyx::InlineCombGroups>(loweringPatterns, patternState,
2662 addOncePattern<LateSSAReplacement>(loweringPatterns, patternState, funcMap,
2668 addGreedyPattern<calyx::EliminateUnusedCombGroups>(loweringPatterns);
2672 addOncePattern<calyx::RewriteMemoryAccesses>(loweringPatterns, patternState,
2677 addOncePattern<CleanupFuncOps>(loweringPatterns, patternState, funcMap,
2681 for (
auto &pat : loweringPatterns) {
2684 pat.strategy == LoweringPattern::Strategy::Once);
2687 signalPassFailure();
2694 RewritePatternSet cleanupPatterns(&getContext());
2695 cleanupPatterns.add<calyx::MultipleGroupDonePattern,
2696 calyx::NonTerminatingGroupDonePattern>(&getContext());
2697 if (failed(applyPatternsAndFoldGreedily(getOperation(),
2698 std::move(cleanupPatterns)))) {
2699 signalPassFailure();
2703 if (ciderSourceLocationMetadata) {
2706 SmallVector<Attribute, 16> sourceLocations;
2707 getOperation()->walk([&](calyx::ComponentOp component) {
2711 MLIRContext *context = getOperation()->getContext();
2712 getOperation()->setAttr(
"calyx.metadata",
2723 return std::make_unique<SCFToCalyxPass>();
assert(baseType &&"element must be base type")
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
std::shared_ptr< calyx::CalyxLoweringState > loweringState
LogicalResult partialPatternRes
void setFuncOpResultMapping(const DenseMap< unsigned, unsigned > &mapping)
Assign a mapping between the source funcOp result indices and the corresponding output port indices o...
std::string getUniqueName(StringRef prefix)
Returns a unique name within compOp with the provided prefix.
void registerMemoryInterface(Value memref, const calyx::MemoryInterface &memoryInterface)
Registers a memory interface as being associated with a memory identified by 'memref'.
calyx::ComponentOp getComponentOp()
Returns the calyx::ComponentOp associated with this lowering state.
FuncOpPartialLoweringPatterns are patterns which intend to match on FuncOps and then perform their ow...
Location getLoc() override
Holds common utilities used for scheduling when lowering to Calyx.
Location getLoc() override
Builds a control schedule by traversing the CFG of the function and associating this with the previou...
calyx::RepeatOp buildForCtrlOp(ScfForOp forOp, SmallVector< calyx::GroupOp > const &initGroups, uint64_t bound, PatternRewriter &rewriter) const
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult schedulePath(PatternRewriter &rewriter, const DenseSet< Block * > &path, Location loc, Block *from, Block *to, Block *parentCtrlBlock) const
Schedules a block by inserting a branch argument assignment block (if any) before recursing into the ...
calyx::WhileOp buildWhileCtrlOp(ScfWhileOp whileOp, SmallVector< calyx::GroupOp > initGroups, PatternRewriter &rewriter) const
LogicalResult scheduleBasicBlock(PatternRewriter &rewriter, const DenseSet< Block * > &path, mlir::Block *parentCtrlBlock, mlir::Block *block) const
Sequentially schedules the groups that registered themselves with 'block'.
LogicalResult buildCFGControl(DenseSet< Block * > path, PatternRewriter &rewriter, mlir::Block *parentCtrlBlock, mlir::Block *preBlock, mlir::Block *block) const
void insertParInitGroups(PatternRewriter &rewriter, Location loc, const SmallVector< calyx::GroupOp > &initGroups) const
In BuildForGroups, a register is created for the iteration argument of the for op.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Iterate through the operations of a source function and instantiate components or primitives based on...
TGroupOp createGroupForOp(PatternRewriter &rewriter, Operation *op) const
Creates a group named by the basic block which the input op resides in.
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op) const
buildLibraryOp which provides in- and output types based on the operands and results of the op argume...
calyx::RegisterOp createSignalRegister(PatternRewriter &rewriter, Value signal, bool invert, StringRef nameSuffix, calyx::CompareFOpIEEE754 calyxCmpFOp, calyx::GroupOp group) const
void assignAddressPorts(PatternRewriter &rewriter, Location loc, calyx::GroupInterface group, calyx::MemoryInterface memoryInterface, Operation::operand_range addressValues) const
Creates assignments within the provided group to the address ports of the memoryOp based on the provi...
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op, TypeRange srcTypes, TypeRange dstTypes) const
buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the source operation TSrcOp.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult buildLibraryBinaryPipeOp(PatternRewriter &rewriter, TSrcOp op, TOpType opPipe, Value out) const
buildLibraryBinaryPipeOp will build a TCalyxLibBinaryPipeOp, to deal with MulIOp, DivUIOp and RemUIOp...
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partialEval(PatternRewriter &rewriter, scf::ParallelOp scfParOp) const
In BuildWhileGroups, a register is created for each iteration argumenet of the while op.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Erases FuncOp operations.
LogicalResult matchAndRewrite(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Handles the current state of lowering of a Calyx component.
ComponentLoweringState(calyx::ComponentOp component)
void setForLoopInitGroups(ScfForOp op, SmallVector< calyx::GroupOp > groups)
calyx::GroupOp buildForLoopIterArgAssignments(OpBuilder &builder, ScfForOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
void setForLoopLatchGroup(ScfForOp op, calyx::GroupOp group)
SmallVector< calyx::GroupOp > getForLoopInitGroups(ScfForOp op)
void addForLoopIterReg(ScfForOp op, calyx::RegisterOp reg, unsigned idx)
calyx::GroupOp getForLoopLatchGroup(ScfForOp op)
calyx::RegisterOp getForLoopIterReg(ScfForOp op, unsigned idx)
const DenseMap< unsigned, calyx::RegisterOp > & getForLoopIterRegs(ScfForOp op)
const DenseMap< unsigned, calyx::RegisterOp > & getResultRegs(scf::IfOp op)
DenseMap< Operation *, calyx::GroupOp > elseGroup
DenseMap< Operation *, calyx::GroupOp > thenGroup
void setElseGroup(scf::IfOp op, calyx::GroupOp group)
void setResultRegs(scf::IfOp op, calyx::RegisterOp reg, unsigned idx)
void setThenGroup(scf::IfOp op, calyx::GroupOp group)
DenseMap< Operation *, DenseMap< unsigned, calyx::RegisterOp > > resultRegs
calyx::RegisterOp getResultRegs(scf::IfOp op, unsigned idx)
calyx::GroupOp getThenGroup(scf::IfOp op)
calyx::GroupOp getElseGroup(scf::IfOp op)
Inlines Calyx ExecuteRegionOp operations within their parent blocks.
LogicalResult matchAndRewrite(scf::ExecuteRegionOp execOp, PatternRewriter &rewriter) const override
LateSSAReplacement contains various functions for replacing SSA values that were not replaced during ...
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &) const override
std::optional< int64_t > getBound() override
Block * getBodyBlock() override
Block::BlockArgListType getBodyArgs() override
ScfWhileOp(scf::WhileOp op)
Block::BlockArgListType getBodyArgs() override
Block * getConditionBlock() override
std::optional< int64_t > getBound() override
Block * getBodyBlock() override
Value getConditionValue() override
calyx::GroupOp buildWhileLoopIterArgAssignments(OpBuilder &builder, ScfWhileOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
void setWhileLoopInitGroups(ScfWhileOp op, SmallVector< calyx::GroupOp > groups)
SmallVector< calyx::GroupOp > getWhileLoopInitGroups(ScfWhileOp op)
void addWhileLoopIterReg(ScfWhileOp op, calyx::RegisterOp reg, unsigned idx)
const DenseMap< unsigned, calyx::RegisterOp > & getWhileLoopIterRegs(ScfWhileOp op)
void setWhileLoopLatchGroup(ScfWhileOp op, calyx::GroupOp group)
calyx::GroupOp getWhileLoopLatchGroup(ScfWhileOp op)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
void addMandatoryComponentPorts(PatternRewriter &rewriter, SmallVectorImpl< calyx::PortInfo > &ports)
void buildAssignmentsForRegisterWrite(OpBuilder &builder, calyx::GroupOp groupOp, calyx::ComponentOp componentOp, calyx::RegisterOp ®, Value inputValue)
Creates register assignment operations within the provided groupOp.
DenseMap< const mlir::RewritePattern *, SmallPtrSet< Operation *, 16 > > PatternApplicationState
Extra state that is passed to all PartialLoweringPatterns so they can record when they have run on an...
PredicateInfo getPredicateInfo(mlir::arith::CmpFPredicate pred)
Type convIndexType(OpBuilder &builder, Type type)
LogicalResult applyModuleOpConversion(mlir::ModuleOp, StringRef topLevelFunction)
Helper to update the top-level ModuleOp to set the entrypoing function.
WalkResult getCiderSourceLocationMetadata(calyx::ComponentOp component, SmallVectorImpl< Attribute > &sourceLocations)
bool matchConstantOp(Operation *op, APInt &value)
unsigned handleZeroWidth(int64_t dim)
hw::ConstantOp createConstant(Location loc, OpBuilder &builder, ComponentOp component, size_t width, size_t value)
A helper function to create constants in the HW dialect.
calyx::RegisterOp createRegister(Location loc, OpBuilder &builder, ComponentOp component, size_t width, Twine prefix)
Creates a RegisterOp, with input and output port bit widths defined by width.
bool noStoresToMemory(Value memoryReference)
bool singleLoadFromMemory(Value memoryReference)
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `‘Not’' gate on a value.
static constexpr std::string_view sPortNameAttr
static LogicalResult buildAllocOp(ComponentLoweringState &componentState, PatternRewriter &rewriter, TAllocOp allocOp)
std::variant< calyx::GroupOp, WhileScheduleable, ForScheduleable, IfScheduleable, CallScheduleable, ParScheduleable > Scheduleable
A variant of types representing scheduleable operations.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< OperationPass< ModuleOp > > createSCFToCalyxPass()
Create an SCF to Calyx conversion pass.
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
This holds information about the port for either a Component or Cell.
Predicate information for the floating point comparisons.
calyx::InstanceOp instanceOp
Instance for invoking.
ScfForOp forOp
For operation to schedule.
Creates a new Calyx component for each FuncOp in the program.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
scf::ParallelOp parOp
Parallel operation to schedule.
ScfWhileOp whileOp
While operation to schedule.