14 #include "../PassDetail.h"
20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21 #include "mlir/Conversion/LLVMCommon/Pattern.h"
22 #include "mlir/Dialect/Arith/IR/Arith.h"
23 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/Dialect/SCF/IR/SCF.h"
27 #include "mlir/IR/AsmState.h"
28 #include "mlir/IR/Matchers.h"
29 #include "mlir/Support/LogicalResult.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "llvm/ADT/TypeSwitch.h"
42 class ComponentLoweringStateInterface;
43 namespace scftocalyx {
52 : calyx::WhileOpInterface<scf::WhileOp>(op) {}
55 return getOperation().getAfterArguments();
58 Block *
getBodyBlock()
override {
return &getOperation().getAfter().front(); }
61 return &getOperation().getBefore().front();
65 return getOperation().getConditionOp().getOperand(0);
68 std::optional<int64_t>
getBound()
override {
return std::nullopt; }
73 explicit ScfForOp(scf::ForOp op) : calyx::RepeatOpInterface<scf::ForOp>(op) {}
76 return getOperation().getRegion().getArguments();
80 return &getOperation().getRegion().getBlocks().front();
84 return constantTripCount(getOperation().getLowerBound(),
85 getOperation().getUpperBound(),
86 getOperation().getStep());
121 return getLoopInitGroups(std::move(op));
125 Twine uniqueSuffix, MutableArrayRef<OpOperand> ops) {
126 return buildLoopIterArgAssignments(
builder, std::move(op), componentOp,
130 return addLoopIterReg(std::move(op),
reg, idx);
132 const DenseMap<unsigned, calyx::RegisterOp> &
134 return getLoopIterRegs(std::move(op));
137 return setLoopLatchGroup(std::move(op), group);
140 return getLoopLatchGroup(std::move(op));
143 SmallVector<calyx::GroupOp> groups) {
144 return setLoopInitGroups(std::move(op), std::move(groups));
152 return getLoopInitGroups(std::move(op));
156 Twine uniqueSuffix, MutableArrayRef<OpOperand> ops) {
157 return buildLoopIterArgAssignments(
builder, std::move(op), componentOp,
161 return addLoopIterReg(std::move(op),
reg, idx);
164 return getLoopIterRegs(std::move(op));
167 return getLoopIterReg(std::move(op), idx);
170 return setLoopLatchGroup(std::move(op), group);
173 return getLoopLatchGroup(std::move(op));
176 return setLoopInitGroups(std::move(op), std::move(groups));
189 : calyx::ComponentLoweringStateInterface(component) {}
199 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
203 PatternRewriter &rewriter)
const override {
206 bool opBuiltSuccessfully =
true;
207 funcOp.walk([&](Operation *_op) {
208 opBuiltSuccessfully &=
209 TypeSwitch<mlir::Operation *, bool>(_op)
210 .template Case<arith::ConstantOp, ReturnOp, BranchOpInterface,
212 scf::YieldOp, scf::WhileOp, scf::ForOp,
214 memref::AllocOp, memref::AllocaOp, memref::LoadOp,
217 AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
218 AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
219 MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
220 SelectOp, IndexCastOp, CallOp>(
221 [&](
auto op) {
return buildOp(rewriter, op).succeeded(); })
222 .
template Case<FuncOp, scf::ConditionOp>([&](
auto) {
226 .Default([&](
auto op) {
227 op->emitError() <<
"Unhandled operation during BuildOpGroups()";
231 return opBuiltSuccessfully ? WalkResult::advance()
232 : WalkResult::interrupt();
235 return success(opBuiltSuccessfully);
// Per-operation group builders: one `buildOp` overload for each source
// operation type this pass lowers. The TypeSwitch in the funcOp walk above
// calls `buildOp(rewriter, op)` for each matched operation. A failed
// LogicalResult clears `opBuiltSuccessfully`, which interrupts the walk and
// makes the pattern report failure.
// Control-flow / structural ops: loop yields, branches and constants.
240 LogicalResult buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp)
const;
241 LogicalResult buildOp(PatternRewriter &rewriter,
242 BranchOpInterface brOp)
const;
243 LogicalResult buildOp(PatternRewriter &rewriter,
244 arith::ConstantOp constOp)
const;
// Select and add/sub: the out-of-line definitions build combinational
// library ops placed in a calyx::CombGroupOp.
245 LogicalResult buildOp(PatternRewriter &rewriter, SelectOp op)
const;
246 LogicalResult buildOp(PatternRewriter &rewriter, AddIOp op)
const;
247 LogicalResult buildOp(PatternRewriter &rewriter, SubIOp op)
const;
// Multiply / divide / remainder: the out-of-line definitions instantiate
// pipelined library ops (MultPipe, DivUPipe, DivSPipe, RemUPipe, RemSPipe)
// and lower through buildLibraryBinaryPipeOp.
248 LogicalResult buildOp(PatternRewriter &rewriter, MulIOp op)
const;
249 LogicalResult buildOp(PatternRewriter &rewriter, DivUIOp op)
const;
250 LogicalResult buildOp(PatternRewriter &rewriter, DivSIOp op)
const;
251 LogicalResult buildOp(PatternRewriter &rewriter, RemUIOp op)
const;
252 LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op)
const;
// Shifts, bitwise logic, comparisons and width casts.
253 LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op)
const;
254 LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op)
const;
255 LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op)
const;
256 LogicalResult buildOp(PatternRewriter &rewriter, AndIOp op)
const;
257 LogicalResult buildOp(PatternRewriter &rewriter, OrIOp op)
const;
258 LogicalResult buildOp(PatternRewriter &rewriter, XOrIOp op)
const;
259 LogicalResult buildOp(PatternRewriter &rewriter, CmpIOp op)
const;
260 LogicalResult buildOp(PatternRewriter &rewriter, TruncIOp op)
const;
261 LogicalResult buildOp(PatternRewriter &rewriter, ExtUIOp op)
const;
262 LogicalResult buildOp(PatternRewriter &rewriter, ExtSIOp op)
const;
263 LogicalResult buildOp(PatternRewriter &rewriter, ReturnOp op)
const;
264 LogicalResult buildOp(PatternRewriter &rewriter, IndexCastOp op)
const;
// Memory ops: alloc/alloca go through buildAllocOp, which creates a
// calyx::SeqMemoryOp. Load/store assign that memory's address and data ports.
265 LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocOp op)
const;
266 LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocaOp op)
const;
267 LogicalResult buildOp(PatternRewriter &rewriter, memref::LoadOp op)
const;
268 LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op)
const;
// Structured loops are registered as block scheduleables. The ForOp overload
// fails unless the loop bound is statically known.
269 LogicalResult buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp)
const;
270 LogicalResult buildOp(PatternRewriter &rewriter, scf::ForOp forOp)
const;
// Calls are wired to a calyx::InstanceOp looked up by instance name.
271 LogicalResult buildOp(PatternRewriter &rewriter, CallOp callOp)
const;
275 template <
typename TGroupOp,
typename TCalyxLibOp,
typename TSrcOp>
277 TypeRange srcTypes, TypeRange dstTypes)
const {
278 SmallVector<Type> types;
279 llvm::append_range(types, srcTypes);
280 llvm::append_range(types, dstTypes);
283 getState<ComponentLoweringState>().getNewLibraryOpInstance<TCalyxLibOp>(
284 rewriter, op.getLoc(), types);
286 auto directions = calyxOp.portDirections();
287 SmallVector<Value, 4> opInputPorts;
288 SmallVector<Value, 4> opOutputPorts;
289 for (
auto dir : enumerate(directions)) {
291 opInputPorts.push_back(calyxOp.getResult(dir.index()));
293 opOutputPorts.push_back(calyxOp.getResult(dir.index()));
296 opInputPorts.size() == op->getNumOperands() &&
297 opOutputPorts.size() == op->getNumResults() &&
298 "Expected an equal number of in/out ports in the Calyx library op with "
299 "respect to the number of operands/results of the source operation.");
302 auto group = createGroupForOp<TGroupOp>(rewriter, op);
303 rewriter.setInsertionPointToEnd(group.getBodyBlock());
304 for (
auto dstOp : enumerate(opInputPorts))
305 rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
306 op->getOperand(dstOp.index()));
309 for (
auto res : enumerate(opOutputPorts)) {
310 getState<ComponentLoweringState>().registerEvaluatingGroup(res.value(),
312 op->getResult(res.index()).replaceAllUsesWith(res.value());
319 template <
typename TGroupOp,
typename TCalyxLibOp,
typename TSrcOp>
321 return buildLibraryOp<TGroupOp, TCalyxLibOp, TSrcOp>(
322 rewriter, op, op.getOperandTypes(), op->getResultTypes());
326 template <
typename TGroupOp>
328 Block *block = op->getBlock();
329 auto groupName = getState<ComponentLoweringState>().getUniqueName(
330 loweringState().blockName(block));
331 return calyx::createGroup<TGroupOp>(
332 rewriter, getState<ComponentLoweringState>().getComponentOp(),
333 op->getLoc(), groupName);
338 template <
typename TOpType,
typename TSrcOp>
340 TOpType opPipe, Value out)
const {
341 StringRef opName = TSrcOp::getOperationName().split(
".").second;
342 Location loc = op.getLoc();
343 Type
width = op.getResult().getType();
345 op.getResult().replaceAllUsesWith(out);
347 op.getLoc(), rewriter, getComponent(),
width.getIntOrFloatBitWidth(),
348 getState<ComponentLoweringState>().getUniqueName(opName));
350 auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
351 OpBuilder
builder(group->getRegion(0));
352 getState<ComponentLoweringState>().addBlockScheduleable(op->getBlock(),
355 rewriter.setInsertionPointToEnd(group.getBodyBlock());
356 rewriter.create<calyx::AssignOp>(loc, opPipe.getLeft(), op.getLhs());
357 rewriter.create<calyx::AssignOp>(loc, opPipe.getRight(), op.getRhs());
359 rewriter.create<calyx::AssignOp>(loc,
reg.getIn(), out);
361 rewriter.create<calyx::AssignOp>(loc,
reg.getWriteEn(), opPipe.getDone());
366 rewriter.create<calyx::AssignOp>(
367 loc, opPipe.getGo(), c1,
370 rewriter.create<calyx::GroupDoneOp>(loc,
reg.getDone());
373 getState<ComponentLoweringState>().registerEvaluatingGroup(out, group);
374 getState<ComponentLoweringState>().registerEvaluatingGroup(opPipe.getLeft(),
376 getState<ComponentLoweringState>().registerEvaluatingGroup(
377 opPipe.getRight(), group);
385 calyx::GroupInterface group,
387 Operation::operand_range addressValues)
const {
388 IRRewriter::InsertionGuard guard(rewriter);
389 rewriter.setInsertionPointToEnd(group.getBody());
390 auto addrPorts = memoryInterface.
addrPorts();
391 if (addressValues.empty()) {
393 addrPorts.size() == 1 &&
394 "We expected a 1 dimensional memory of size 1 because there were no "
395 "address assignment values");
397 rewriter.create<calyx::AssignOp>(
401 assert(addrPorts.size() == addressValues.size() &&
402 "Mismatch between number of address ports of the provided memory "
403 "and address assignment values");
404 for (
auto address : enumerate(addressValues))
405 rewriter.create<calyx::AssignOp>(loc, addrPorts[address.index()],
411 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
412 memref::LoadOp loadOp)
const {
413 Value memref = loadOp.getMemref();
414 auto memoryInterface =
415 getState<ComponentLoweringState>().getMemoryInterface(memref);
416 auto group = createGroupForOp<calyx::GroupOp>(rewriter, loadOp);
417 assignAddressPorts(rewriter, loadOp.getLoc(), group, memoryInterface,
418 loadOp.getIndices());
420 rewriter.setInsertionPointToEnd(group.getBodyBlock());
426 if (memoryInterface.readEnOpt().has_value()) {
429 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), memoryInterface.readEn(),
431 regWriteEn = memoryInterface.readDone();
438 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(),
439 memoryInterface.readDone());
449 res = loadOp.getResult();
462 loadOp.getLoc(), rewriter, getComponent(),
463 loadOp.getMemRefType().getElementTypeBitWidth(),
464 getState<ComponentLoweringState>().getUniqueName(
"load"));
465 rewriter.setInsertionPointToEnd(group.getBodyBlock());
466 rewriter.create<calyx::AssignOp>(loadOp.getLoc(),
reg.getIn(),
467 memoryInterface.readData());
468 rewriter.create<calyx::AssignOp>(loadOp.getLoc(),
reg.getWriteEn(),
470 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(),
reg.getDone());
471 loadOp.getResult().replaceAllUsesWith(
reg.getOut());
475 getState<ComponentLoweringState>().registerEvaluatingGroup(res, group);
476 getState<ComponentLoweringState>().addBlockScheduleable(loadOp->getBlock(),
481 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
482 memref::StoreOp storeOp)
const {
483 auto memoryInterface = getState<ComponentLoweringState>().getMemoryInterface(
484 storeOp.getMemref());
485 auto group = createGroupForOp<calyx::GroupOp>(rewriter, storeOp);
489 getState<ComponentLoweringState>().addBlockScheduleable(storeOp->getBlock(),
491 assignAddressPorts(rewriter, storeOp.getLoc(), group, memoryInterface,
492 storeOp.getIndices());
493 rewriter.setInsertionPointToEnd(group.getBodyBlock());
494 rewriter.create<calyx::AssignOp>(
495 storeOp.getLoc(), memoryInterface.writeData(), storeOp.getValueToStore());
496 rewriter.create<calyx::AssignOp>(
497 storeOp.getLoc(), memoryInterface.writeEn(),
498 createConstant(storeOp.getLoc(), rewriter, getComponent(), 1, 1));
499 rewriter.create<calyx::GroupDoneOp>(storeOp.getLoc(),
500 memoryInterface.writeDone());
505 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
507 Location loc = mul.getLoc();
508 Type
width = mul.getResult().getType(), one = rewriter.getI1Type();
510 getState<ComponentLoweringState>()
511 .getNewLibraryOpInstance<calyx::MultPipeLibOp>(
513 return buildLibraryBinaryPipeOp<calyx::MultPipeLibOp>(
514 rewriter, mul, mulPipe,
518 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
520 Location loc = div.getLoc();
521 Type
width = div.getResult().getType(), one = rewriter.getI1Type();
523 getState<ComponentLoweringState>()
524 .getNewLibraryOpInstance<calyx::DivUPipeLibOp>(
526 return buildLibraryBinaryPipeOp<calyx::DivUPipeLibOp>(
527 rewriter, div, divPipe,
531 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
533 Location loc = div.getLoc();
534 Type
width = div.getResult().getType(), one = rewriter.getI1Type();
536 getState<ComponentLoweringState>()
537 .getNewLibraryOpInstance<calyx::DivSPipeLibOp>(
539 return buildLibraryBinaryPipeOp<calyx::DivSPipeLibOp>(
540 rewriter, div, divPipe,
544 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
546 Location loc = rem.getLoc();
547 Type
width = rem.getResult().getType(), one = rewriter.getI1Type();
549 getState<ComponentLoweringState>()
550 .getNewLibraryOpInstance<calyx::RemUPipeLibOp>(
552 return buildLibraryBinaryPipeOp<calyx::RemUPipeLibOp>(
553 rewriter, rem, remPipe,
557 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
559 Location loc = rem.getLoc();
560 Type
width = rem.getResult().getType(), one = rewriter.getI1Type();
562 getState<ComponentLoweringState>()
563 .getNewLibraryOpInstance<calyx::RemSPipeLibOp>(
565 return buildLibraryBinaryPipeOp<calyx::RemSPipeLibOp>(
566 rewriter, rem, remPipe,
570 template <
typename TAllocOp>
572 PatternRewriter &rewriter, TAllocOp allocOp) {
573 rewriter.setInsertionPointToStart(
575 MemRefType memtype = allocOp.getType();
576 SmallVector<int64_t> addrSizes;
577 SmallVector<int64_t> sizes;
578 for (int64_t dim : memtype.getShape()) {
579 sizes.push_back(dim);
584 if (sizes.empty() && addrSizes.empty()) {
586 addrSizes.push_back(1);
588 auto memoryOp = rewriter.create<calyx::SeqMemoryOp>(
590 memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
593 memoryOp->setAttr(
"external",
600 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
601 memref::AllocOp allocOp)
const {
602 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
605 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
606 memref::AllocaOp allocOp)
const {
607 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
610 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
611 scf::YieldOp yieldOp)
const {
612 if (yieldOp.getOperands().empty()) {
614 auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
615 assert(forOp &&
"Empty yieldOps should only be located within ForOps");
620 getState<ComponentLoweringState>().getForLoopIterReg(forOpInterface, 0);
622 Type regWidth = inductionReg.getOut().getType();
624 SmallVector<Type> types(3, regWidth);
625 auto addOp = getState<ComponentLoweringState>()
626 .getNewLibraryOpInstance<calyx::AddLibOp>(
627 rewriter, forOp.getLoc(), types);
629 auto directions = addOp.portDirections();
631 SmallVector<Value, 2> opInputPorts;
633 for (
auto dir : enumerate(directions)) {
634 switch (dir.value()) {
636 opInputPorts.push_back(addOp.getResult(dir.index()));
640 opOutputPort = addOp.getResult(dir.index());
647 calyx::ComponentOp componentOp =
648 getState<ComponentLoweringState>().getComponentOp();
649 SmallVector<StringRef, 4> groupIdentifier = {
650 "incr", getState<ComponentLoweringState>().getUniqueName(forOp),
652 auto groupOp = calyx::createGroup<calyx::GroupOp>(
653 rewriter, componentOp, forOp.getLoc(),
654 llvm::join(groupIdentifier,
"_"));
655 rewriter.setInsertionPointToEnd(groupOp.getBodyBlock());
658 Value leftOp = opInputPorts.front();
659 rewriter.create<calyx::AssignOp>(forOp.getLoc(), leftOp,
660 inductionReg.getOut());
662 Value rightOp = opInputPorts.back();
663 rewriter.create<calyx::AssignOp>(
664 forOp.getLoc(), rightOp,
666 regWidth.getIntOrFloatBitWidth(),
667 forOp.getConstantStep().value().getSExtValue()));
670 inductionReg, opOutputPort);
672 getState<ComponentLoweringState>().setForLoopLatchGroup(forOpInterface,
674 getState<ComponentLoweringState>().registerEvaluatingGroup(opOutputPort,
679 if (dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
680 return yieldOp.getOperation()->emitError()
681 <<
"Currently do not support non-empty yield operations inside for "
682 "loops. Run --scf-for-to-while before running --scf-to-calyx.";
685 auto whileOp = dyn_cast<scf::WhileOp>(yieldOp->getParentOp());
687 return yieldOp.getOperation()->emitError()
688 <<
"Currently only support yield operations inside for and while "
694 getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
695 rewriter, whileOpInterface,
696 getState<ComponentLoweringState>().getComponentOp(),
697 getState<ComponentLoweringState>().getUniqueName(whileOp) +
"_latch",
698 yieldOp->getOpOperands());
699 getState<ComponentLoweringState>().setWhileLoopLatchGroup(whileOpInterface,
704 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
705 BranchOpInterface brOp)
const {
710 Block *srcBlock = brOp->getBlock();
711 for (
auto succBlock : enumerate(brOp->getSuccessors())) {
712 auto succOperands = brOp.getSuccessorOperands(succBlock.index());
713 if (succOperands.empty())
716 std::string groupName = loweringState().blockName(srcBlock) +
"_to_" +
717 loweringState().blockName(succBlock.value());
718 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter, getComponent(),
719 brOp.getLoc(), groupName);
721 auto dstBlockArgRegs =
722 getState<ComponentLoweringState>().getBlockArgRegs(succBlock.value());
724 for (
auto arg : enumerate(succOperands.getForwardedOperands())) {
725 auto reg = dstBlockArgRegs[arg.index()];
728 getState<ComponentLoweringState>().getComponentOp(),
reg,
733 getState<ComponentLoweringState>().addBlockArgGroup(
734 srcBlock, succBlock.value(), groupOp);
741 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
742 ReturnOp retOp)
const {
743 if (retOp.getNumOperands() == 0)
746 std::string groupName =
747 getState<ComponentLoweringState>().getUniqueName(
"ret_assign");
748 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter, getComponent(),
749 retOp.getLoc(), groupName);
750 for (
auto op : enumerate(retOp.getOperands())) {
751 auto reg = getState<ComponentLoweringState>().getReturnReg(op.index());
753 rewriter, groupOp, getState<ComponentLoweringState>().getComponentOp(),
757 getState<ComponentLoweringState>().addBlockScheduleable(retOp->getBlock(),
762 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
763 arith::ConstantOp constOp)
const {
768 hwConstOp->moveAfter(getComponent().getBodyBlock(),
769 getComponent().getBodyBlock()->begin());
773 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
775 return buildLibraryOp<calyx::CombGroupOp, calyx::AddLibOp>(rewriter, op);
777 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
779 return buildLibraryOp<calyx::CombGroupOp, calyx::SubLibOp>(rewriter, op);
781 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
783 return buildLibraryOp<calyx::CombGroupOp, calyx::RshLibOp>(rewriter, op);
785 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
787 return buildLibraryOp<calyx::CombGroupOp, calyx::SrshLibOp>(rewriter, op);
789 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
791 return buildLibraryOp<calyx::CombGroupOp, calyx::LshLibOp>(rewriter, op);
793 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
795 return buildLibraryOp<calyx::CombGroupOp, calyx::AndLibOp>(rewriter, op);
797 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
799 return buildLibraryOp<calyx::CombGroupOp, calyx::OrLibOp>(rewriter, op);
801 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
803 return buildLibraryOp<calyx::CombGroupOp, calyx::XorLibOp>(rewriter, op);
805 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
807 return buildLibraryOp<calyx::CombGroupOp, calyx::MuxLibOp>(rewriter, op);
810 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
812 switch (op.getPredicate()) {
813 case CmpIPredicate::eq:
814 return buildLibraryOp<calyx::CombGroupOp, calyx::EqLibOp>(rewriter, op);
815 case CmpIPredicate::ne:
816 return buildLibraryOp<calyx::CombGroupOp, calyx::NeqLibOp>(rewriter, op);
817 case CmpIPredicate::uge:
818 return buildLibraryOp<calyx::CombGroupOp, calyx::GeLibOp>(rewriter, op);
819 case CmpIPredicate::ult:
820 return buildLibraryOp<calyx::CombGroupOp, calyx::LtLibOp>(rewriter, op);
821 case CmpIPredicate::ugt:
822 return buildLibraryOp<calyx::CombGroupOp, calyx::GtLibOp>(rewriter, op);
823 case CmpIPredicate::ule:
824 return buildLibraryOp<calyx::CombGroupOp, calyx::LeLibOp>(rewriter, op);
825 case CmpIPredicate::sge:
826 return buildLibraryOp<calyx::CombGroupOp, calyx::SgeLibOp>(rewriter, op);
827 case CmpIPredicate::slt:
828 return buildLibraryOp<calyx::CombGroupOp, calyx::SltLibOp>(rewriter, op);
829 case CmpIPredicate::sgt:
830 return buildLibraryOp<calyx::CombGroupOp, calyx::SgtLibOp>(rewriter, op);
831 case CmpIPredicate::sle:
832 return buildLibraryOp<calyx::CombGroupOp, calyx::SleLibOp>(rewriter, op);
834 llvm_unreachable(
"unsupported comparison predicate");
836 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
838 return buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
839 rewriter, op, {op.getOperand().getType()}, {op.getType()});
841 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
843 return buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
844 rewriter, op, {op.getOperand().getType()}, {op.getType()});
847 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
849 return buildLibraryOp<calyx::CombGroupOp, calyx::ExtSILibOp>(
850 rewriter, op, {op.getOperand().getType()}, {op.getType()});
853 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
854 IndexCastOp op)
const {
857 unsigned targetBits = targetType.getIntOrFloatBitWidth();
858 unsigned sourceBits = sourceType.getIntOrFloatBitWidth();
859 LogicalResult res = success();
861 if (targetBits == sourceBits) {
864 op.getResult().replaceAllUsesWith(op.getOperand());
867 if (sourceBits > targetBits)
868 res = buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
869 rewriter, op, {sourceType}, {targetType});
871 res = buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
872 rewriter, op, {sourceType}, {targetType});
874 rewriter.eraseOp(op);
878 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
879 scf::WhileOp whileOp)
const {
883 getState<ComponentLoweringState>().addBlockScheduleable(
888 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
889 scf::ForOp forOp)
const {
895 std::optional<uint64_t> bound = scfForOp.
getBound();
896 if (!bound.has_value()) {
898 <<
"Loop bound not statically known. Should "
899 "transform into while loop using `--scf-for-to-while` before "
900 "running --lower-scf-to-calyx.";
902 getState<ComponentLoweringState>().addBlockScheduleable(
910 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
911 CallOp callOp)
const {
913 calyx::InstanceOp instanceOp =
914 getState<ComponentLoweringState>().getInstance(instanceName);
915 SmallVector<Value, 4> outputPorts;
916 auto portInfos = instanceOp.getReferencedComponent().getPortInfo();
917 for (
auto [idx, portInfo] : enumerate(portInfos)) {
919 outputPorts.push_back(instanceOp.getResult(idx));
923 for (
auto [idx, result] : llvm::enumerate(callOp.getResults()))
924 rewriter.replaceAllUsesWith(result, outputPorts[idx]);
928 getState<ComponentLoweringState>().addBlockScheduleable(
942 using OpRewritePattern::OpRewritePattern;
945 PatternRewriter &rewriter)
const override {
947 TypeRange yieldTypes = execOp.getResultTypes();
951 rewriter.setInsertionPointAfter(execOp);
952 auto *sinkBlock = rewriter.splitBlock(
954 execOp.getOperation()->getIterator()->getNextNode()->getIterator());
955 sinkBlock->addArguments(
957 SmallVector<Location, 4>(yieldTypes.size(), rewriter.getUnknownLoc()));
958 for (
auto res : enumerate(execOp.getResults()))
959 res.value().replaceAllUsesWith(sinkBlock->getArgument(res.index()));
963 make_early_inc_range(execOp.getRegion().getOps<scf::YieldOp>())) {
964 rewriter.setInsertionPointAfter(yieldOp);
965 rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, sinkBlock,
966 yieldOp.getOperands());
970 auto *preBlock = execOp->getBlock();
971 auto *execOpEntryBlock = &execOp.getRegion().front();
972 auto *postBlock = execOp->getBlock()->splitBlock(execOp);
973 rewriter.inlineRegionBefore(execOp.getRegion(), postBlock);
974 rewriter.mergeBlocks(postBlock, preBlock);
975 rewriter.eraseOp(execOp);
978 rewriter.mergeBlocks(execOpEntryBlock, preBlock);
986 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
990 PatternRewriter &rewriter)
const override {
993 DenseMap<Value, unsigned> funcOpArgRewrites;
997 DenseMap<unsigned, unsigned> funcOpResultMapping;
1005 DenseMap<Value, std::pair<unsigned, unsigned>> extMemoryCompPortIndices;
1009 SmallVector<calyx::PortInfo> inPorts, outPorts;
1010 FunctionType funcType = funcOp.getFunctionType();
1011 unsigned extMemCounter = 0;
1012 for (
auto arg : enumerate(funcOp.getArguments())) {
1013 if (arg.value().getType().isa<MemRefType>()) {
1016 "ext_mem" + std::to_string(extMemoryCompPortIndices.size());
1017 extMemoryCompPortIndices[arg.value()] = {inPorts.size(),
1020 extMemCounter++, inPorts, outPorts);
1024 if (
auto portNameAttr = funcOp.getArgAttrOfType<StringAttr>(
1026 inName = portNameAttr.str();
1028 inName =
"in" + std::to_string(arg.index());
1029 funcOpArgRewrites[arg.value()] = inPorts.size();
1031 rewriter.getStringAttr(inName),
1037 for (
auto res : enumerate(funcType.getResults())) {
1038 std::string resName;
1039 if (
auto portNameAttr = funcOp.getResultAttrOfType<StringAttr>(
1041 resName = portNameAttr.str();
1043 resName =
"out" + std::to_string(res.index());
1044 funcOpResultMapping[res.index()] = outPorts.size();
1046 rewriter.getStringAttr(resName),
1053 auto ports = inPorts;
1054 llvm::append_range(ports, outPorts);
1058 auto compOp = rewriter.create<calyx::ComponentOp>(
1059 funcOp.getLoc(), rewriter.getStringAttr(funcOp.getSymName()), ports);
1061 std::string funcName =
"func_" + funcOp.getSymName().str();
1062 rewriter.updateRootInPlace(funcOp, [&]() { funcOp.setSymName(funcName); });
1065 compOp->setAttr(
"toplevel", rewriter.getUnitAttr());
1068 functionMapping[funcOp] = compOp;
1073 for (
auto &mapping : funcOpArgRewrites)
1074 mapping.getFirst().replaceAllUsesWith(
1075 compOp.getArgument(mapping.getSecond()));
1078 for (
auto extMemPortIndices : extMemoryCompPortIndices) {
1082 unsigned inPortsIt = extMemPortIndices.getSecond().first;
1083 unsigned outPortsIt = extMemPortIndices.getSecond().second +
1084 compOp.getInputPortInfo().size();
1085 extMemPorts.
readData = compOp.getArgument(inPortsIt++);
1086 extMemPorts.
writeDone = compOp.getArgument(inPortsIt);
1087 extMemPorts.
writeData = compOp.getArgument(outPortsIt++);
1088 unsigned nAddresses = extMemPortIndices.getFirst()
1093 for (
unsigned j = 0; j < nAddresses; ++j)
1094 extMemPorts.
addrPorts.push_back(compOp.getArgument(outPortsIt++));
1095 extMemPorts.
writeEn = compOp.getArgument(outPortsIt);
1099 compState->registerMemoryInterface(extMemPortIndices.getFirst(),
1112 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1116 PatternRewriter &rewriter)
const override {
1117 LogicalResult res = success();
1118 funcOp.walk([&](Operation *op) {
1120 if (!isa<scf::WhileOp>(op))
1121 return WalkResult::advance();
1123 auto scfWhileOp = cast<scf::WhileOp>(op);
1126 getState<ComponentLoweringState>().setUniqueName(whileOp.
getOperation(),
1136 enumerate(scfWhileOp.getBefore().front().getArguments())) {
1137 auto condOp = scfWhileOp.getConditionOp().getArgs()[barg.index()];
1138 if (barg.value() != condOp) {
1140 << loweringState().irName(barg.value())
1141 <<
" != " << loweringState().irName(condOp)
1142 <<
"do-while loops not supported; expected iter-args to "
1143 "remain untransformed in the 'before' region of the "
1145 return WalkResult::interrupt();
1154 for (
auto arg : enumerate(whileOp.
getBodyArgs())) {
1155 std::string name = getState<ComponentLoweringState>()
1158 "_arg" + std::to_string(arg.index());
1161 arg.value().getType().getIntOrFloatBitWidth(), name);
1162 getState<ComponentLoweringState>().addWhileLoopIterReg(whileOp,
reg,
1164 arg.value().replaceAllUsesWith(
reg.getOut());
1168 ->getArgument(arg.index())
1169 .replaceAllUsesWith(
reg.getOut());
1173 SmallVector<calyx::GroupOp> initGroups;
1174 auto numOperands = whileOp.
getOperation()->getNumOperands();
1175 for (
size_t i = 0; i < numOperands; ++i) {
1177 getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
1179 getState<ComponentLoweringState>().getComponentOp(),
1180 getState<ComponentLoweringState>().getUniqueName(
1182 "_init_" + std::to_string(i),
1184 initGroups.push_back(initGroupOp);
1187 getState<ComponentLoweringState>().setWhileLoopInitGroups(whileOp,
1190 return WalkResult::advance();
1200 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1204 PatternRewriter &rewriter)
const override {
1205 LogicalResult res = success();
1206 funcOp.walk([&](Operation *op) {
1208 if (!isa<scf::ForOp>(op))
1209 return WalkResult::advance();
1211 auto scfForOp = cast<scf::ForOp>(op);
1214 getState<ComponentLoweringState>().setUniqueName(forOp.
getOperation(),
1219 auto inductionVar = forOp.
getOperation().getInductionVar();
1220 SmallVector<std::string, 3> inductionVarIdentifiers = {
1221 getState<ComponentLoweringState>()
1224 "induction",
"var"};
1225 std::string name = llvm::join(inductionVarIdentifiers,
"_");
1228 inductionVar.getType().getIntOrFloatBitWidth(), name);
1229 getState<ComponentLoweringState>().addForLoopIterReg(forOp,
reg, 0);
1230 inductionVar.replaceAllUsesWith(
reg.getOut());
1233 calyx::ComponentOp componentOp =
1234 getState<ComponentLoweringState>().getComponentOp();
1235 SmallVector<calyx::GroupOp> initGroups;
1236 SmallVector<std::string, 4> groupIdentifiers = {
1238 getState<ComponentLoweringState>()
1241 "induction",
"var"};
1242 std::string groupName = llvm::join(groupIdentifiers,
"_");
1243 auto groupOp = calyx::createGroup<calyx::GroupOp>(
1244 rewriter, componentOp, forOp.
getLoc(), groupName);
1247 initGroups.push_back(groupOp);
1248 getState<ComponentLoweringState>().setForLoopInitGroups(forOp,
1251 return WalkResult::advance();
1263 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1267 PatternRewriter &rewriter)
const override {
1268 auto *entryBlock = &funcOp.getBlocks().front();
1269 rewriter.setInsertionPointToStart(
1270 getComponent().getControlOp().getBodyBlock());
1271 auto topLevelSeqOp = rewriter.create<calyx::SeqOp>(funcOp.getLoc());
1272 DenseSet<Block *> path;
1273 return buildCFGControl(path, rewriter, topLevelSeqOp.getBodyBlock(),
1274 nullptr, entryBlock);
1281 const DenseSet<Block *> &path,
1282 mlir::Block *parentCtrlBlock,
1283 mlir::Block *block)
const {
1284 auto compBlockScheduleables =
1285 getState<ComponentLoweringState>().getBlockScheduleables(block);
1286 auto loc = block->front().getLoc();
1288 if (compBlockScheduleables.size() > 1) {
1289 auto seqOp = rewriter.create<calyx::SeqOp>(loc);
1290 parentCtrlBlock = seqOp.getBodyBlock();
1293 for (
auto &group : compBlockScheduleables) {
1294 rewriter.setInsertionPointToEnd(parentCtrlBlock);
1295 if (
auto groupPtr = std::get_if<calyx::GroupOp>(&group); groupPtr) {
1296 rewriter.create<calyx::EnableOp>(groupPtr->getLoc(),
1297 groupPtr->getSymName());
1298 }
else if (
auto whileSchedPtr = std::get_if<WhileScheduleable>(&group);
1300 auto &whileOp = whileSchedPtr->whileOp;
1302 auto whileCtrlOp = buildWhileCtrlOp(
1304 getState<ComponentLoweringState>().getWhileLoopInitGroups(whileOp),
1306 rewriter.setInsertionPointToEnd(whileCtrlOp.getBodyBlock());
1308 rewriter.create<calyx::SeqOp>(whileOp.getOperation()->getLoc());
1309 auto *whileBodyOpBlock = whileBodyOp.getBodyBlock();
1313 LogicalResult res = buildCFGControl(path, rewriter, whileBodyOpBlock,
1314 block, whileOp.getBodyBlock());
1317 rewriter.setInsertionPointToEnd(whileBodyOpBlock);
1318 calyx::GroupOp whileLatchGroup =
1319 getState<ComponentLoweringState>().getWhileLoopLatchGroup(whileOp);
1320 rewriter.create<calyx::EnableOp>(whileLatchGroup.getLoc(),
1321 whileLatchGroup.getName());
1325 }
else if (
auto *forSchedPtr = std::get_if<ForScheduleable>(&group);
1327 auto forOp = forSchedPtr->forOp;
1329 auto forCtrlOp = buildForCtrlOp(
1331 getState<ComponentLoweringState>().getForLoopInitGroups(forOp),
1332 forSchedPtr->bound, rewriter);
1333 rewriter.setInsertionPointToEnd(forCtrlOp.getBodyBlock());
1335 rewriter.create<calyx::SeqOp>(forOp.getOperation()->getLoc());
1336 auto *forBodyOpBlock = forBodyOp.getBodyBlock();
1339 LogicalResult res = buildCFGControl(path, rewriter, forBodyOpBlock,
1340 block, forOp.getBodyBlock());
1343 rewriter.setInsertionPointToEnd(forBodyOpBlock);
1344 calyx::GroupOp forLatchGroup =
1345 getState<ComponentLoweringState>().getForLoopLatchGroup(forOp);
1346 rewriter.create<calyx::EnableOp>(forLatchGroup.getLoc(),
1347 forLatchGroup.getName());
1350 }
else if (
auto *callSchedPtr = std::get_if<CallScheduleable>(&group)) {
1351 auto instanceOp = callSchedPtr->instanceOp;
1352 OpBuilder::InsertionGuard g(rewriter);
1353 auto callBody = rewriter.create<calyx::SeqOp>(instanceOp.getLoc());
1354 rewriter.setInsertionPointToStart(callBody.getBodyBlock());
1355 std::string initGroupName =
"init_" + instanceOp.getSymName().str();
1356 rewriter.create<calyx::EnableOp>(instanceOp.getLoc(), initGroupName);
1357 SmallVector<Value, 4> instancePorts;
1358 auto inputPorts = callSchedPtr->callOp.getOperands();
1359 llvm::copy(instanceOp.getResults().take_front(inputPorts.size()),
1360 std::back_inserter(instancePorts));
1361 rewriter.create<calyx::InvokeOp>(
1362 instanceOp.getLoc(), instanceOp.getSymName(), instancePorts,
1366 llvm_unreachable(
"Unknown scheduleable");
1377 const DenseSet<Block *> &path, Location loc,
1378 Block *from, Block *to,
1379 Block *parentCtrlBlock)
const {
1382 rewriter.setInsertionPointToEnd(parentCtrlBlock);
1383 auto preSeqOp = rewriter.create<calyx::SeqOp>(loc);
1384 rewriter.setInsertionPointToEnd(preSeqOp.getBodyBlock());
1386 getState<ComponentLoweringState>().getBlockArgGroups(from, to))
1387 rewriter.create<calyx::EnableOp>(barg.getLoc(), barg.getSymName());
1389 return buildCFGControl(path, rewriter, parentCtrlBlock, from, to);
1393 PatternRewriter &rewriter,
1394 mlir::Block *parentCtrlBlock,
1395 mlir::Block *preBlock,
1396 mlir::Block *block)
const {
1397 if (path.count(block) != 0)
1398 return preBlock->getTerminator()->emitError()
1399 <<
"CFG backedge detected. Loops must be raised to 'scf.while' or "
1400 "'scf.for' operations.";
1402 rewriter.setInsertionPointToEnd(parentCtrlBlock);
1403 LogicalResult bbSchedResult =
1404 scheduleBasicBlock(rewriter, path, parentCtrlBlock, block);
1405 if (bbSchedResult.failed())
1406 return bbSchedResult;
1409 auto successors = block->getSuccessors();
1410 auto nSuccessors = successors.size();
1411 if (nSuccessors > 0) {
1412 auto brOp = dyn_cast<BranchOpInterface>(block->getTerminator());
1414 if (nSuccessors > 1) {
1418 assert(nSuccessors == 2 &&
1419 "only conditional branches supported for now...");
1421 auto cond = brOp->getOperand(0);
1422 auto condGroup = getState<ComponentLoweringState>()
1423 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
1427 auto ifOp = rewriter.create<calyx::IfOp>(
1428 brOp->getLoc(), cond, symbolAttr,
true);
1429 rewriter.setInsertionPointToStart(ifOp.getThenBody());
1430 auto thenSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
1431 rewriter.setInsertionPointToStart(ifOp.getElseBody());
1432 auto elseSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
1434 bool trueBrSchedSuccess =
1435 schedulePath(rewriter, path, brOp.getLoc(), block, successors[0],
1436 thenSeqOp.getBodyBlock())
1438 bool falseBrSchedSuccess =
true;
1439 if (trueBrSchedSuccess) {
1440 falseBrSchedSuccess =
1441 schedulePath(rewriter, path, brOp.getLoc(), block, successors[1],
1442 elseSeqOp.getBodyBlock())
1446 return success(trueBrSchedSuccess && falseBrSchedSuccess);
1449 return schedulePath(rewriter, path, brOp.getLoc(), block,
1450 successors.front(), parentCtrlBlock);
1460 const SmallVector<calyx::GroupOp> &initGroups)
const {
1461 PatternRewriter::InsertionGuard g(rewriter);
1462 auto parOp = rewriter.create<calyx::ParOp>(loc);
1463 rewriter.setInsertionPointToStart(parOp.getBodyBlock());
1464 for (calyx::GroupOp group : initGroups)
1465 rewriter.create<calyx::EnableOp>(group.getLoc(), group.getName());
1469 SmallVector<calyx::GroupOp> initGroups,
1470 PatternRewriter &rewriter)
const {
1471 Location loc = whileOp.
getLoc();
1474 insertParInitGroups(rewriter, loc, initGroups);
1478 auto condGroup = getState<ComponentLoweringState>()
1479 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
1482 return rewriter.create<calyx::WhileOp>(loc, cond, symbolAttr);
1486 SmallVector<calyx::GroupOp>
const &initGroups,
1488 PatternRewriter &rewriter)
const {
1489 Location loc = forOp.
getLoc();
1492 insertParInitGroups(rewriter, loc, initGroups);
1495 return rewriter.create<calyx::RepeatOp>(loc, bound);
1502 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1505 PatternRewriter &)
const override {
1506 funcOp.walk([&](scf::WhileOp op) {
1515 getState<ComponentLoweringState>().getWhileLoopIterRegs(whileOp))
1516 whileOp.
getOperation()->getResults()[res.first].replaceAllUsesWith(
1517 res.second.getOut());
1520 funcOp.walk([&](memref::LoadOp loadOp) {
1526 loadOp.getResult().replaceAllUsesWith(
1527 getState<ComponentLoweringState>()
1528 .getMemoryInterface(loadOp.getMemref())
1539 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1542 PatternRewriter &rewriter)
const override {
1543 rewriter.eraseOp(funcOp);
1549 PatternRewriter &rewriter)
const override {
1561 void runOnOperation()
override;
1564 std::string &topLevelFunction) {
1565 if (!topLevelFunctionOpt.empty()) {
1566 if (SymbolTable::lookupSymbolIn(moduleOp, topLevelFunctionOpt) ==
1568 moduleOp.emitError() <<
"Top level function '" << topLevelFunctionOpt
1569 <<
"' not found in module.";
1572 topLevelFunction = topLevelFunctionOpt;
1576 auto funcOps = moduleOp.getOps<FuncOp>();
1577 if (std::distance(funcOps.begin(), funcOps.end()) == 1)
1578 topLevelFunction = (*funcOps.begin()).getSymName().str();
1580 moduleOp.emitError()
1581 <<
"Module contains multiple functions, but no top level "
1582 "function was set. Please see --top-level-function";
1603 using OpRewritePattern::OpRewritePattern;
1604 LogicalResult matchAndRewrite(mlir::ModuleOp,
1605 PatternRewriter &)
const override {
1610 ConversionTarget target(getContext());
1611 target.addLegalDialect<calyx::CalyxDialect>();
1612 target.addLegalDialect<scf::SCFDialect>();
1613 target.addIllegalDialect<hw::HWDialect>();
1614 target.addIllegalDialect<comb::CombDialect>();
1617 target.addIllegalDialect<FuncDialect>();
1618 target.addIllegalDialect<ArithDialect>();
1619 target.addLegalOp<AddIOp, SelectOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp,
1620 ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
1621 CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
1622 RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
1625 RewritePatternSet legalizePatterns(&getContext());
1626 legalizePatterns.add<DummyPattern>(&getContext());
1627 DenseSet<Operation *> legalizedOps;
1628 if (applyPartialConversion(getOperation(), target,
1629 std::move(legalizePatterns))
1640 template <
typename TPattern,
typename... PatternArgs>
1642 PatternArgs &&...args) {
1643 RewritePatternSet ps(&getContext());
1644 ps.add<TPattern>(&getContext(), partialPatternRes, args...);
1649 template <
typename TPattern,
typename... PatternArgs>
1651 PatternArgs &&...args) {
1652 RewritePatternSet ps(&getContext());
1653 ps.add<TPattern>(&getContext(), args...);
1659 assert(pattern.getNativePatterns().size() == 1 &&
1660 "Should only apply 1 partial lowering pattern at once");
1666 GreedyRewriteConfig config;
1667 config.enableRegionSimplification =
false;
1669 config.maxIterations = 1;
1674 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(pattern),
1676 return partialPatternRes;
1681 std::shared_ptr<calyx::CalyxLoweringState> loweringState =
nullptr;
1684 void SCFToCalyxPass::runOnOperation() {
1686 loweringState.reset();
1687 partialPatternRes = LogicalResult::failure();
1689 std::string topLevelFunction;
1690 if (failed(setTopLevelFunction(getOperation(), topLevelFunction))) {
1691 signalPassFailure();
1696 if (failed(labelEntryPoint(topLevelFunction))) {
1697 signalPassFailure();
1700 loweringState = std::make_shared<calyx::CalyxLoweringState>(getOperation(),
1711 DenseMap<FuncOp, calyx::ComponentOp> funcMap;
1712 SmallVector<LoweringPattern, 8> loweringPatterns;
1716 addOncePattern<FuncOpConversion>(loweringPatterns, patternState, funcMap,
1720 addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);
1723 addOncePattern<calyx::ConvertIndexTypes>(loweringPatterns, patternState,
1724 funcMap, *loweringState);
1727 addOncePattern<calyx::BuildBasicBlockRegs>(loweringPatterns, patternState,
1728 funcMap, *loweringState);
1730 addOncePattern<calyx::BuildCallInstance>(loweringPatterns, patternState,
1731 funcMap, *loweringState);
1734 addOncePattern<calyx::BuildReturnRegs>(loweringPatterns, patternState,
1735 funcMap, *loweringState);
1740 addOncePattern<BuildWhileGroups>(loweringPatterns, patternState, funcMap,
1746 addOncePattern<BuildForGroups>(loweringPatterns, patternState, funcMap,
1756 addOncePattern<BuildOpGroups>(loweringPatterns, patternState, funcMap,
1762 addOncePattern<BuildControl>(loweringPatterns, patternState, funcMap,
1767 addOncePattern<calyx::InlineCombGroups>(loweringPatterns, patternState,
1772 addOncePattern<LateSSAReplacement>(loweringPatterns, patternState, funcMap,
1778 addGreedyPattern<calyx::EliminateUnusedCombGroups>(loweringPatterns);
1782 addOncePattern<calyx::RewriteMemoryAccesses>(loweringPatterns, patternState,
1787 addOncePattern<CleanupFuncOps>(loweringPatterns, patternState, funcMap,
1791 for (
auto &pat : loweringPatterns) {
1792 LogicalResult partialPatternRes = runPartialPattern(
1794 pat.strategy == LoweringPattern::Strategy::Once);
1795 if (succeeded(partialPatternRes))
1797 signalPassFailure();
1804 RewritePatternSet cleanupPatterns(&getContext());
1807 if (failed(applyPatternsAndFoldGreedily(getOperation(),
1808 std::move(cleanupPatterns)))) {
1809 signalPassFailure();
1813 if (ciderSourceLocationMetadata) {
1816 SmallVector<Attribute, 16> sourceLocations;
1817 getOperation()->walk([&](calyx::ComponentOp component) {
1821 MLIRContext *context = getOperation()->getContext();
1822 getOperation()->setAttr(
"calyx.metadata",
1834 return std::make_unique<scftocalyx::SCFToCalyxPass>();
assert(baseType &&"element must be base type")
void setFuncOpResultMapping(const DenseMap< unsigned, unsigned > &mapping)
Assign a mapping between the source funcOp result indices and the corresponding output port indices o...
std::string getUniqueName(StringRef prefix)
Returns a unique name within compOp with the provided prefix.
void registerMemoryInterface(Value memref, const calyx::MemoryInterface &memoryInterface)
Registers a memory interface as being associated with a memory identified by 'memref'.
calyx::ComponentOp getComponentOp()
Returns the calyx::ComponentOp associated with this lowering state.
FuncOpPartialLoweringPatterns are patterns which intend to match on FuncOps and then perform their ow...
Location getLoc() override
Holds common utilities used for scheduling when lowering to Calyx.
Location getLoc() override
Builds a control schedule by traversing the CFG of the function and associating this with the previou...
calyx::RepeatOp buildForCtrlOp(ScfForOp forOp, SmallVector< calyx::GroupOp > const &initGroups, uint64_t bound, PatternRewriter &rewriter) const
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult schedulePath(PatternRewriter &rewriter, const DenseSet< Block * > &path, Location loc, Block *from, Block *to, Block *parentCtrlBlock) const
Schedules a block by inserting a branch argument assignment block (if any) before recursing into the ...
calyx::WhileOp buildWhileCtrlOp(ScfWhileOp whileOp, SmallVector< calyx::GroupOp > initGroups, PatternRewriter &rewriter) const
LogicalResult scheduleBasicBlock(PatternRewriter &rewriter, const DenseSet< Block * > &path, mlir::Block *parentCtrlBlock, mlir::Block *block) const
Sequentially schedules the groups that registered themselves with 'block'.
LogicalResult buildCFGControl(DenseSet< Block * > path, PatternRewriter &rewriter, mlir::Block *parentCtrlBlock, mlir::Block *preBlock, mlir::Block *block) const
void insertParInitGroups(PatternRewriter &rewriter, Location loc, const SmallVector< calyx::GroupOp > &initGroups) const
In BuildForGroups, a register is created for the iteration argument of the for op.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Iterate through the operations of a source function and instantiate components or primitives based on...
TGroupOp createGroupForOp(PatternRewriter &rewriter, Operation *op) const
Creates a group named by the basic block which the input op resides in.
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op) const
buildLibraryOp which provides in- and output types based on the operands and results of the op argume...
void assignAddressPorts(PatternRewriter &rewriter, Location loc, calyx::GroupInterface group, calyx::MemoryInterface memoryInterface, Operation::operand_range addressValues) const
Creates assignments within the provided group to the address ports of the memoryOp based on the provi...
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op, TypeRange srcTypes, TypeRange dstTypes) const
buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the source operation TSrcOp.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult buildLibraryBinaryPipeOp(PatternRewriter &rewriter, TSrcOp op, TOpType opPipe, Value out) const
buildLibraryBinaryPipeOp will build a TCalyxLibBinaryPipeOp, to deal with MulIOp, DivUIOp and RemUIOp...
In BuildWhileGroups, a register is created for each iteration argumenet of the while op.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Erases FuncOp operations.
LogicalResult matchAndRewrite(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Handles the current state of lowering of a Calyx component.
ComponentLoweringState(calyx::ComponentOp component)
void setForLoopInitGroups(ScfForOp op, SmallVector< calyx::GroupOp > groups)
calyx::GroupOp buildForLoopIterArgAssignments(OpBuilder &builder, ScfForOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
void setForLoopLatchGroup(ScfForOp op, calyx::GroupOp group)
SmallVector< calyx::GroupOp > getForLoopInitGroups(ScfForOp op)
void addForLoopIterReg(ScfForOp op, calyx::RegisterOp reg, unsigned idx)
calyx::GroupOp getForLoopLatchGroup(ScfForOp op)
calyx::RegisterOp getForLoopIterReg(ScfForOp op, unsigned idx)
const DenseMap< unsigned, calyx::RegisterOp > & getForLoopIterRegs(ScfForOp op)
Inlines Calyx ExecuteRegionOp operations within their parent blocks.
LogicalResult matchAndRewrite(scf::ExecuteRegionOp execOp, PatternRewriter &rewriter) const override
LateSSAReplacement contains various functions for replacing SSA values that were not replaced during ...
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &) const override
LogicalResult labelEntryPoint(StringRef topLevelFunction)
Labels the entry point of a Calyx program.
LogicalResult runPartialPattern(RewritePatternSet &pattern, bool runOnce)
LogicalResult setTopLevelFunction(mlir::ModuleOp moduleOp, std::string &topLevelFunction)
void addGreedyPattern(SmallVectorImpl< LoweringPattern > &patterns, PatternArgs &&...args)
LogicalResult partialPatternRes
void addOncePattern(SmallVectorImpl< LoweringPattern > &patterns, PatternArgs &&...args)
'Once' patterns are expected to take an additional LogicalResult& argument, to forward their result s...
std::optional< int64_t > getBound() override
Block * getBodyBlock() override
Block::BlockArgListType getBodyArgs() override
ScfWhileOp(scf::WhileOp op)
Block::BlockArgListType getBodyArgs() override
Block * getConditionBlock() override
std::optional< int64_t > getBound() override
Block * getBodyBlock() override
Value getConditionValue() override
calyx::GroupOp buildWhileLoopIterArgAssignments(OpBuilder &builder, ScfWhileOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
void setWhileLoopInitGroups(ScfWhileOp op, SmallVector< calyx::GroupOp > groups)
SmallVector< calyx::GroupOp > getWhileLoopInitGroups(ScfWhileOp op)
void addWhileLoopIterReg(ScfWhileOp op, calyx::RegisterOp reg, unsigned idx)
const DenseMap< unsigned, calyx::RegisterOp > & getWhileLoopIterRegs(ScfWhileOp op)
void setWhileLoopLatchGroup(ScfWhileOp op, calyx::GroupOp group)
calyx::GroupOp getWhileLoopLatchGroup(ScfWhileOp op)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
void addMandatoryComponentPorts(PatternRewriter &rewriter, SmallVectorImpl< calyx::PortInfo > &ports)
void appendPortsForExternalMemref(PatternRewriter &rewriter, StringRef memName, Value memref, unsigned memoryID, SmallVectorImpl< calyx::PortInfo > &inPorts, SmallVectorImpl< calyx::PortInfo > &outPorts)
void buildAssignmentsForRegisterWrite(OpBuilder &builder, calyx::GroupOp groupOp, calyx::ComponentOp componentOp, calyx::RegisterOp ®, Value inputValue)
Creates register assignment operations within the provided groupOp.
DenseMap< const mlir::RewritePattern *, SmallPtrSet< Operation *, 16 > > PatternApplicationState
Extra state that is passed to all PartialLoweringPatterns so they can record when they have run on an...
Type convIndexType(OpBuilder &builder, Type type)
LogicalResult applyModuleOpConversion(mlir::ModuleOp, StringRef topLevelFunction)
Helper to update the top-level ModuleOp to set the entrypoing function.
WalkResult getCiderSourceLocationMetadata(calyx::ComponentOp component, SmallVectorImpl< Attribute > &sourceLocations)
bool matchConstantOp(Operation *op, APInt &value)
unsigned handleZeroWidth(int64_t dim)
hw::ConstantOp createConstant(Location loc, OpBuilder &builder, ComponentOp component, size_t width, size_t value)
A helper function to create constants in the HW dialect.
calyx::RegisterOp createRegister(Location loc, OpBuilder &builder, ComponentOp component, size_t width, Twine prefix)
Creates a RegisterOp, with input and output port bit widths defined by width.
bool noStoresToMemory(Value memoryReference)
bool singleLoadFromMemory(Value memoryReference)
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `‘Not’' gate on a value.
static constexpr std::string_view sPortNameAttr
std::variant< calyx::GroupOp, WhileScheduleable, ForScheduleable, CallScheduleable > Scheduleable
A variant of types representing scheduleable operations.
static LogicalResult buildAllocOp(ComponentLoweringState &componentState, PatternRewriter &rewriter, TAllocOp allocOp)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
std::unique_ptr< OperationPass< ModuleOp > > createSCFToCalyxPass()
Create an SCF to Calyx conversion pass.
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
std::optional< Value > writeEn
std::optional< Value > writeData
std::optional< Value > readData
std::optional< Value > writeDone
SmallVector< Value > addrPorts
When building groups which contain accesses to multiple sequential components, a group_done op is cre...
GroupDoneOp's are terminator operations and should therefore be the last operator in a group.
This holds information about the port for either a Component or Cell.
calyx::InstanceOp instanceOp
Instance for invoking.
ScfForOp forOp
For operation to schedule.
Creates a new Calyx component for each FuncOp in the program.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
RewritePatternSet pattern
ScfWhileOp whileOp
While operation to schedule.