14 #include "../PassDetail.h"
20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21 #include "mlir/Conversion/LLVMCommon/Pattern.h"
22 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
23 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/IR/AsmState.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29 #include "llvm/ADT/TypeSwitch.h"
40 class ComponentLoweringStateInterface;
41 namespace scftocalyx {
50 : calyx::WhileOpInterface<scf::
WhileOp>(op) {}
53 return getOperation().getAfterArguments();
// Returns a pointer to the first block of the wrapped operation's "after"
// region (getAfter().front()), which this interface exposes as the body block.
56 Block *
getBodyBlock()
override {
return &getOperation().getAfter().front(); }
59 return &getOperation().getBefore().front();
63 return getOperation().getConditionOp().getOperand(0);
// Bound query for this interface implementation: unconditionally returns None,
// i.e. no bound value is ever reported.
66 Optional<uint64_t>
getBound()
override {
return None; }
// A Scheduleable holds exactly one of two alternatives: a plain
// calyx::GroupOp or a WhileScheduleable.
82 using Scheduleable = std::variant<calyx::GroupOp, WhileScheduleable>;
93 : calyx::ComponentLoweringStateInterface(component) {}
103 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
107 PatternRewriter &rewriter)
const override {
110 bool opBuiltSuccessfully =
true;
111 funcOp.walk([&](Operation *_op) {
112 opBuiltSuccessfully &=
113 TypeSwitch<mlir::Operation *, bool>(_op)
114 .template Case<arith::ConstantOp, ReturnOp, BranchOpInterface,
118 memref::AllocOp, memref::AllocaOp, memref::LoadOp,
121 AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
122 AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
123 MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
125 [&](
auto op) {
return buildOp(rewriter, op).succeeded(); })
126 .
template Case<scf::WhileOp, FuncOp, scf::ConditionOp>([&](
auto) {
130 .Default([&](
auto op) {
131 op->emitError() <<
"Unhandled operation during BuildOpGroups()";
135 return opBuiltSuccessfully ? WalkResult::advance()
136 : WalkResult::interrupt();
139 return success(opBuiltSuccessfully);
// buildOp overload set: one const member declaration per supported source
// operation type. Every overload takes the PatternRewriter plus the typed
// operation and reports its outcome as a LogicalResult. Only declarations
// appear here; the definitions are written out-of-line as
// BuildOpGroups::buildOp(...).

// Terminators / control flow and constants.
144 LogicalResult buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp)
const;
145 LogicalResult buildOp(PatternRewriter &rewriter,
146 BranchOpInterface brOp)
const;
147 LogicalResult buildOp(PatternRewriter &rewriter,
148 arith::ConstantOp constOp)
const;
// Integer add/sub/mul/div/rem.
149 LogicalResult buildOp(PatternRewriter &rewriter, AddIOp op)
const;
150 LogicalResult buildOp(PatternRewriter &rewriter, SubIOp op)
const;
151 LogicalResult buildOp(PatternRewriter &rewriter, MulIOp op)
const;
152 LogicalResult buildOp(PatternRewriter &rewriter, DivUIOp op)
const;
153 LogicalResult buildOp(PatternRewriter &rewriter, DivSIOp op)
const;
154 LogicalResult buildOp(PatternRewriter &rewriter, RemUIOp op)
const;
155 LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op)
const;
// Shifts and bitwise logic.
156 LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op)
const;
157 LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op)
const;
158 LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op)
const;
159 LogicalResult buildOp(PatternRewriter &rewriter, AndIOp op)
const;
160 LogicalResult buildOp(PatternRewriter &rewriter, OrIOp op)
const;
161 LogicalResult buildOp(PatternRewriter &rewriter, XOrIOp op)
const;
// Comparison and width-changing ops.
162 LogicalResult buildOp(PatternRewriter &rewriter, CmpIOp op)
const;
163 LogicalResult buildOp(PatternRewriter &rewriter, TruncIOp op)
const;
164 LogicalResult buildOp(PatternRewriter &rewriter, ExtUIOp op)
const;
165 LogicalResult buildOp(PatternRewriter &rewriter, ExtSIOp op)
const;
166 LogicalResult buildOp(PatternRewriter &rewriter, ReturnOp op)
const;
167 LogicalResult buildOp(PatternRewriter &rewriter, IndexCastOp op)
const;
// memref allocation and access.
168 LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocOp op)
const;
169 LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocaOp op)
const;
170 LogicalResult buildOp(PatternRewriter &rewriter, memref::LoadOp op)
const;
171 LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op)
const;
175 template <
typename TGroupOp,
typename TCalyxLibOp,
typename TSrcOp>
177 TypeRange srcTypes, TypeRange dstTypes)
const {
178 SmallVector<Type>
types;
179 llvm::append_range(
types, srcTypes);
180 llvm::append_range(
types, dstTypes);
183 getState<ComponentLoweringState>().getNewLibraryOpInstance<TCalyxLibOp>(
184 rewriter, op.getLoc(),
types);
186 auto directions = calyxOp.portDirections();
187 SmallVector<Value, 4> opInputPorts;
188 SmallVector<Value, 4> opOutputPorts;
189 for (
auto dir : enumerate(directions)) {
191 opInputPorts.push_back(calyxOp.getResult(dir.index()));
193 opOutputPorts.push_back(calyxOp.getResult(dir.index()));
196 opInputPorts.size() == op->getNumOperands() &&
197 opOutputPorts.size() == op->getNumResults() &&
198 "Expected an equal number of in/out ports in the Calyx library op with "
199 "respect to the number of operands/results of the source operation.");
202 auto group = createGroupForOp<TGroupOp>(rewriter, op);
203 rewriter.setInsertionPointToEnd(group.getBody());
204 for (
auto dstOp : enumerate(opInputPorts))
205 rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
206 op->getOperand(dstOp.index()));
209 for (
auto res : enumerate(opOutputPorts)) {
210 getState<ComponentLoweringState>().registerEvaluatingGroup(res.value(),
212 op->getResult(res.index()).replaceAllUsesWith(res.value());
219 template <
typename TGroupOp,
typename TCalyxLibOp,
typename TSrcOp>
221 return buildLibraryOp<TGroupOp, TCalyxLibOp, TSrcOp>(
222 rewriter, op, op.getOperandTypes(), op->getResultTypes());
226 template <
typename TGroupOp>
228 Block *block = op->getBlock();
229 auto groupName = getState<ComponentLoweringState>().getUniqueName(
230 programState().blockName(block));
231 return calyx::createGroup<TGroupOp>(
232 rewriter, getState<ComponentLoweringState>().getComponentOp(),
233 op->getLoc(), groupName);
238 template <
typename TOpType,
typename TSrcOp>
240 TOpType opPipe, Value out)
const {
241 StringRef opName = TSrcOp::getOperationName().split(
".").second;
242 Location loc = op.getLoc();
243 Type width = op.getResult().getType();
245 op.getResult().replaceAllUsesWith(out);
247 op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(),
248 getState<ComponentLoweringState>().getUniqueName(opName));
250 auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
251 getState<ComponentLoweringState>().addBlockScheduleable(op->getBlock(),
254 rewriter.setInsertionPointToEnd(group.getBody());
255 rewriter.create<calyx::AssignOp>(loc, opPipe.left(), op.getLhs());
256 rewriter.create<calyx::AssignOp>(loc, opPipe.right(), op.getRhs());
258 rewriter.create<calyx::AssignOp>(loc,
reg.in(), out);
260 rewriter.create<calyx::AssignOp>(loc,
reg.write_en(), opPipe.done());
261 rewriter.create<calyx::AssignOp>(
262 loc, opPipe.go(),
createConstant(loc, rewriter, getComponent(), 1, 1));
264 rewriter.create<calyx::GroupDoneOp>(loc,
reg.done());
267 getState<ComponentLoweringState>().registerEvaluatingGroup(out, group);
268 getState<ComponentLoweringState>().registerEvaluatingGroup(opPipe.left(),
270 getState<ComponentLoweringState>().registerEvaluatingGroup(opPipe.right(),
279 calyx::GroupInterface group,
281 Operation::operand_range addressValues)
const {
282 IRRewriter::InsertionGuard guard(rewriter);
283 rewriter.setInsertionPointToEnd(group.getBody());
284 auto addrPorts = memoryInterface.
addrPorts();
285 assert(addrPorts.size() == addressValues.size() &&
286 "Mismatch between number of address ports of the provided memory "
287 "and address assignment values");
288 for (
auto &address : enumerate(addressValues))
289 rewriter.create<calyx::AssignOp>(loc, addrPorts[address.index()],
294 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
295 memref::LoadOp loadOp)
const {
296 Value memref = loadOp.memref();
297 auto memoryInterface =
298 getState<ComponentLoweringState>().getMemoryInterface(memref);
306 auto combGroup = createGroupForOp<calyx::CombGroupOp>(rewriter, loadOp);
307 assignAddressPorts(rewriter, loadOp.getLoc(), combGroup, memoryInterface,
308 loadOp.getIndices());
319 getState<ComponentLoweringState>().registerEvaluatingGroup(
320 loadOp.getResult(), combGroup);
322 auto group = createGroupForOp<calyx::GroupOp>(rewriter, loadOp);
323 assignAddressPorts(rewriter, loadOp.getLoc(), group, memoryInterface,
324 loadOp.getIndices());
335 loadOp.getLoc(), rewriter, getComponent(),
336 loadOp.getMemRefType().getElementTypeBitWidth(),
337 getState<ComponentLoweringState>().getUniqueName(
"load"));
339 rewriter, group, getState<ComponentLoweringState>().getComponentOp(),
340 reg, memoryInterface.readData());
341 loadOp.getResult().replaceAllUsesWith(
reg.out());
342 getState<ComponentLoweringState>().addBlockScheduleable(loadOp->getBlock(),
// Lowers a memref store. The memory interface registered for the stored-to
// memref is looked up, a calyx::GroupOp is created for the op and added as a
// scheduleable of the op's block, and the group body is filled with
// assignments to the memory's address ports, write-data port and write-enable
// port. The group's done condition is the memory interface's done signal.
348 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
349 memref::StoreOp storeOp)
const {
350 auto memoryInterface =
351 getState<ComponentLoweringState>().getMemoryInterface(storeOp.memref());
352 auto group = createGroupForOp<calyx::GroupOp>(rewriter, storeOp);
356 getState<ComponentLoweringState>().addBlockScheduleable(storeOp->getBlock(),
358 assignAddressPorts(rewriter, storeOp.getLoc(), group, memoryInterface,
359 storeOp.getIndices());
// The remaining assignments are appended at the end of the group body.
360 rewriter.setInsertionPointToEnd(group.getBody());
361 rewriter.create<calyx::AssignOp>(
362 storeOp.getLoc(), memoryInterface.writeData(), storeOp.getValueToStore());
// Write-enable is driven by a constant created with arguments (1, 1) --
// presumably a 1-bit value of 1; confirm against createConstant's signature.
363 rewriter.create<calyx::AssignOp>(
364 storeOp.getLoc(), memoryInterface.writeEn(),
365 createConstant(storeOp.getLoc(), rewriter, getComponent(), 1, 1));
366 rewriter.create<calyx::GroupDoneOp>(storeOp.getLoc(), memoryInterface.done());
// Multiplication lowering: requests a new calyx::MultPipeLibOp instance whose
// port type list is {i1, i1, i1, T, T, T, i1} (T = the op's result type), then
// delegates group construction to buildLibraryBinaryPipeOp.
371 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
373 Location loc = mul.getLoc();
// `width` is the result Type itself; `one` is the i1 type.
374 Type width = mul.getResult().getType(), one = rewriter.getI1Type();
376 getState<ComponentLoweringState>()
377 .getNewLibraryOpInstance<calyx::MultPipeLibOp>(
378 rewriter, loc, {one, one, one, width, width, width, one});
379 return buildLibraryBinaryPipeOp<calyx::MultPipeLibOp>(rewriter, mul, mulPipe,
// Division lowering via calyx::DivUPipeLibOp: requests a new library op
// instance whose port type list is {i1, i1, i1, T, T, T, i1} (T = the op's
// result type), then delegates to buildLibraryBinaryPipeOp.
383 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
385 Location loc = div.getLoc();
// `width` is the result Type itself; `one` is the i1 type.
386 Type width = div.getResult().getType(), one = rewriter.getI1Type();
388 getState<ComponentLoweringState>()
389 .getNewLibraryOpInstance<calyx::DivUPipeLibOp>(
390 rewriter, loc, {one, one, one, width, width, width, one});
391 return buildLibraryBinaryPipeOp<calyx::DivUPipeLibOp>(rewriter, div, divPipe,
// Division lowering via calyx::DivSPipeLibOp: requests a new library op
// instance whose port type list is {i1, i1, i1, T, T, T, i1} (T = the op's
// result type), then delegates to buildLibraryBinaryPipeOp.
395 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
397 Location loc = div.getLoc();
// `width` is the result Type itself; `one` is the i1 type.
398 Type width = div.getResult().getType(), one = rewriter.getI1Type();
400 getState<ComponentLoweringState>()
401 .getNewLibraryOpInstance<calyx::DivSPipeLibOp>(
402 rewriter, loc, {one, one, one, width, width, width, one});
403 return buildLibraryBinaryPipeOp<calyx::DivSPipeLibOp>(rewriter, div, divPipe,
// Remainder lowering via calyx::RemUPipeLibOp: requests a new library op
// instance whose port type list is {i1, i1, i1, T, T, T, i1} (T = the op's
// result type), then delegates to buildLibraryBinaryPipeOp.
407 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
409 Location loc = rem.getLoc();
// `width` is the result Type itself; `one` is the i1 type.
410 Type width = rem.getResult().getType(), one = rewriter.getI1Type();
412 getState<ComponentLoweringState>()
413 .getNewLibraryOpInstance<calyx::RemUPipeLibOp>(
414 rewriter, loc, {one, one, one, width, width, width, one});
415 return buildLibraryBinaryPipeOp<calyx::RemUPipeLibOp>(rewriter, rem, remPipe,
// Remainder lowering via calyx::RemSPipeLibOp: requests a new library op
// instance whose port type list is {i1, i1, i1, T, T, T, i1} (T = the op's
// result type), then delegates to buildLibraryBinaryPipeOp.
419 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
421 Location loc = rem.getLoc();
// `width` is the result Type itself; `one` is the i1 type.
422 Type width = rem.getResult().getType(), one = rewriter.getI1Type();
424 getState<ComponentLoweringState>()
425 .getNewLibraryOpInstance<calyx::RemSPipeLibOp>(
426 rewriter, loc, {one, one, one, width, width, width, one});
427 return buildLibraryBinaryPipeOp<calyx::RemSPipeLibOp>(rewriter, rem, remPipe,
431 template <
typename TAllocOp>
433 PatternRewriter &rewriter, TAllocOp allocOp) {
434 rewriter.setInsertionPointToStart(componentState.
getComponentOp().getBody());
435 MemRefType memtype = allocOp.getType();
436 SmallVector<int64_t> addrSizes;
437 SmallVector<int64_t> sizes;
438 for (int64_t dim : memtype.getShape()) {
439 sizes.push_back(dim);
442 auto memoryOp = rewriter.create<calyx::MemoryOp>(
444 memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
447 memoryOp->setAttr(
"external",
// memref::AllocOp lowering: forwards to the shared buildAllocOp helper,
// passing this pattern's ComponentLoweringState, the rewriter and the op.
454 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
455 memref::AllocOp allocOp)
const {
456 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
// memref::AllocaOp lowering: forwards to the same buildAllocOp helper used for
// memref::AllocOp, with the component lowering state, rewriter and op.
459 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
460 memref::AllocaOp allocOp)
const {
461 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
464 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
465 scf::YieldOp yieldOp)
const {
466 if (yieldOp.getOperands().size() == 0)
468 auto whileOp = dyn_cast<scf::WhileOp>(yieldOp->getParentOp());
473 getState<ComponentLoweringState>().buildLoopIterArgAssignments(
474 rewriter, whileOpInterface,
475 getState<ComponentLoweringState>().getComponentOp(),
476 getState<ComponentLoweringState>().getUniqueName(whileOp) +
"_latch",
477 yieldOp->getOpOperands());
478 getState<ComponentLoweringState>().setLoopLatchGroup(whileOpInterface,
483 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
484 BranchOpInterface brOp)
const {
489 Block *srcBlock = brOp->getBlock();
490 for (
auto succBlock : enumerate(brOp->getSuccessors())) {
491 auto succOperands = brOp.getSuccessorOperands(succBlock.index());
492 if (succOperands.empty())
495 std::string groupName = programState().blockName(srcBlock) +
"_to_" +
496 programState().blockName(succBlock.value());
497 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter, getComponent(),
498 brOp.getLoc(), groupName);
500 auto dstBlockArgRegs =
501 getState<ComponentLoweringState>().getBlockArgRegs(succBlock.value());
503 for (
auto arg : enumerate(succOperands.getForwardedOperands())) {
504 auto reg = dstBlockArgRegs[arg.index()];
507 getState<ComponentLoweringState>().getComponentOp(),
reg,
512 getState<ComponentLoweringState>().addBlockArgGroup(
513 srcBlock, succBlock.value(), groupOp);
520 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
521 ReturnOp retOp)
const {
522 if (retOp.getNumOperands() == 0)
525 std::string groupName =
526 getState<ComponentLoweringState>().getUniqueName(
"ret_assign");
527 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter, getComponent(),
528 retOp.getLoc(), groupName);
529 for (
auto op : enumerate(retOp.getOperands())) {
530 auto reg = getState<ComponentLoweringState>().getReturnReg(op.index());
532 rewriter, groupOp, getState<ComponentLoweringState>().getComponentOp(),
536 getState<ComponentLoweringState>().addBlockScheduleable(retOp->getBlock(),
541 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
542 arith::ConstantOp constOp)
const {
546 auto hwConstOp = rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp, value);
547 hwConstOp->moveAfter(getComponent().getBody(),
548 getComponent().getBody()->begin());
// Lowers the op through buildLibraryOp using a calyx::AddLibOp placed in a
// calyx::CombGroupOp.
552 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
554 return buildLibraryOp<calyx::CombGroupOp, calyx::AddLibOp>(rewriter, op);
// Lowers the op through buildLibraryOp using a calyx::SubLibOp placed in a
// calyx::CombGroupOp.
556 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
558 return buildLibraryOp<calyx::CombGroupOp, calyx::SubLibOp>(rewriter, op);
// Lowers the op through buildLibraryOp using a calyx::RshLibOp placed in a
// calyx::CombGroupOp.
560 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
562 return buildLibraryOp<calyx::CombGroupOp, calyx::RshLibOp>(rewriter, op);
// Lowers the op through buildLibraryOp using a calyx::SrshLibOp placed in a
// calyx::CombGroupOp.
564 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
566 return buildLibraryOp<calyx::CombGroupOp, calyx::SrshLibOp>(rewriter, op);
// Lowers the op through buildLibraryOp using a calyx::LshLibOp placed in a
// calyx::CombGroupOp.
568 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
570 return buildLibraryOp<calyx::CombGroupOp, calyx::LshLibOp>(rewriter, op);
// Lowers the op through buildLibraryOp using a calyx::AndLibOp placed in a
// calyx::CombGroupOp.
572 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
574 return buildLibraryOp<calyx::CombGroupOp, calyx::AndLibOp>(rewriter, op);
// Lowers the op through buildLibraryOp using a calyx::OrLibOp placed in a
// calyx::CombGroupOp.
576 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
578 return buildLibraryOp<calyx::CombGroupOp, calyx::OrLibOp>(rewriter, op);
// Lowers the op through buildLibraryOp using a calyx::XorLibOp placed in a
// calyx::CombGroupOp.
580 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
582 return buildLibraryOp<calyx::CombGroupOp, calyx::XorLibOp>(rewriter, op);
// Comparison lowering: dispatches on op.getPredicate() and lowers through
// buildLibraryOp with the Calyx comparison library op matching the predicate.
// Every case uses a calyx::CombGroupOp as the group type. Control reaching
// past the listed cases hits llvm_unreachable.
585 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
587 switch (op.getPredicate()) {
// Equality / inequality.
588 case CmpIPredicate::eq:
589 return buildLibraryOp<calyx::CombGroupOp, calyx::EqLibOp>(rewriter, op);
590 case CmpIPredicate::ne:
591 return buildLibraryOp<calyx::CombGroupOp, calyx::NeqLibOp>(rewriter, op);
// u-prefixed predicates map to the Ge/Lt/Gt/Le library ops.
592 case CmpIPredicate::uge:
593 return buildLibraryOp<calyx::CombGroupOp, calyx::GeLibOp>(rewriter, op);
594 case CmpIPredicate::ult:
595 return buildLibraryOp<calyx::CombGroupOp, calyx::LtLibOp>(rewriter, op);
596 case CmpIPredicate::ugt:
597 return buildLibraryOp<calyx::CombGroupOp, calyx::GtLibOp>(rewriter, op);
598 case CmpIPredicate::ule:
599 return buildLibraryOp<calyx::CombGroupOp, calyx::LeLibOp>(rewriter, op);
// s-prefixed predicates map to the S-prefixed library ops.
600 case CmpIPredicate::sge:
601 return buildLibraryOp<calyx::CombGroupOp, calyx::SgeLibOp>(rewriter, op);
602 case CmpIPredicate::slt:
603 return buildLibraryOp<calyx::CombGroupOp, calyx::SltLibOp>(rewriter, op);
604 case CmpIPredicate::sgt:
605 return buildLibraryOp<calyx::CombGroupOp, calyx::SgtLibOp>(rewriter, op);
606 case CmpIPredicate::sle:
607 return buildLibraryOp<calyx::CombGroupOp, calyx::SleLibOp>(rewriter, op);
609 llvm_unreachable(
"unsupported comparison predicate");
// Lowers the op to a calyx::SliceLibOp in a calyx::CombGroupOp, passing the
// operand type as the single source type and the op's type as the single
// destination type.
611 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
613 return buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
614 rewriter, op, {op.getOperand().getType()}, {op.getType()});
// Lowers the op to a calyx::PadLibOp in a calyx::CombGroupOp, passing the
// operand type as the single source type and the op's type as the single
// destination type.
616 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
618 return buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
619 rewriter, op, {op.getOperand().getType()}, {op.getType()});
// Lowers the op to a calyx::ExtSILibOp in a calyx::CombGroupOp, passing the
// operand type as the single source type and the op's type as the single
// destination type.
622 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
624 return buildLibraryOp<calyx::CombGroupOp, calyx::ExtSILibOp>(
625 rewriter, op, {op.getOperand().getType()}, {op.getType()});
// IndexCastOp lowering, decided by comparing the bit widths of the source and
// target types:
//  - equal widths: all uses of the result are redirected to the operand;
//  - source wider than target: a calyx::SliceLibOp is built;
//  - the remaining visible path builds a calyx::PadLibOp.
// The cast op is erased via the rewriter afterwards.
628 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
629 IndexCastOp op)
const {
632 unsigned targetBits = targetType.getIntOrFloatBitWidth();
633 unsigned sourceBits = sourceType.getIntOrFloatBitWidth();
// Defaults to success; only overwritten when a library op has to be built.
634 LogicalResult res = success();
636 if (targetBits == sourceBits) {
639 op.getResult().replaceAllUsesWith(op.getOperand());
642 if (sourceBits > targetBits)
643 res = buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
644 rewriter, op, {sourceType}, {targetType});
646 res = buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
647 rewriter, op, {sourceType}, {targetType});
649 rewriter.eraseOp(op);
662 using OpRewritePattern::OpRewritePattern;
665 PatternRewriter &rewriter)
const override {
667 TypeRange yieldTypes = execOp.getResultTypes();
671 rewriter.setInsertionPointAfter(execOp);
672 auto *sinkBlock = rewriter.splitBlock(
674 execOp.getOperation()->getIterator()->getNextNode()->getIterator());
675 sinkBlock->addArguments(
677 SmallVector<Location, 4>(yieldTypes.size(), rewriter.getUnknownLoc()));
678 for (
auto res : enumerate(execOp.getResults()))
679 res.value().replaceAllUsesWith(sinkBlock->getArgument(res.index()));
683 make_early_inc_range(execOp.getRegion().getOps<scf::YieldOp>())) {
684 rewriter.setInsertionPointAfter(yieldOp);
685 rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, sinkBlock,
686 yieldOp.getOperands());
690 auto *preBlock = execOp->getBlock();
691 auto *execOpEntryBlock = &execOp.getRegion().front();
692 auto *postBlock = execOp->getBlock()->splitBlock(execOp);
693 rewriter.inlineRegionBefore(execOp.getRegion(), postBlock);
694 rewriter.mergeBlocks(postBlock, preBlock);
695 rewriter.eraseOp(execOp);
698 rewriter.mergeBlocks(execOpEntryBlock, preBlock);
706 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
710 PatternRewriter &rewriter)
const override {
713 DenseMap<Value, unsigned> funcOpArgRewrites;
717 DenseMap<unsigned, unsigned> funcOpResultMapping;
725 DenseMap<Value, std::pair<unsigned, unsigned>> extMemoryCompPortIndices;
729 SmallVector<calyx::PortInfo> inPorts, outPorts;
730 FunctionType funcType = funcOp.getFunctionType();
731 unsigned extMemCounter = 0;
732 for (
auto &arg : enumerate(funcOp.getArguments())) {
733 if (arg.value().getType().isa<MemRefType>()) {
736 "ext_mem" + std::to_string(extMemoryCompPortIndices.size());
737 extMemoryCompPortIndices[arg.value()] = {inPorts.size(),
740 extMemCounter++, inPorts, outPorts);
743 auto inName =
"in" + std::to_string(arg.index());
744 funcOpArgRewrites[arg.value()] = inPorts.size();
746 rewriter.getStringAttr(inName),
752 for (
auto &res : enumerate(funcType.getResults())) {
753 funcOpResultMapping[res.index()] = outPorts.size();
755 rewriter.getStringAttr(
"out" + std::to_string(res.index())),
762 auto ports = inPorts;
763 llvm::append_range(ports, outPorts);
767 auto compOp = rewriter.create<calyx::ComponentOp>(
768 funcOp.getLoc(), rewriter.getStringAttr(funcOp.getSymName()), ports);
771 compOp->setAttr(
"toplevel", rewriter.getUnitAttr());
774 functionMapping[funcOp] = compOp;
779 for (
auto &mapping : funcOpArgRewrites)
780 mapping.getFirst().replaceAllUsesWith(
781 compOp.getArgument(mapping.getSecond()));
784 for (
auto extMemPortIndices : extMemoryCompPortIndices) {
788 unsigned inPortsIt = extMemPortIndices.getSecond().first;
789 unsigned outPortsIt = extMemPortIndices.getSecond().second +
790 compOp.getInputPortInfo().size();
791 extMemPorts.
readData = compOp.getArgument(inPortsIt++);
792 extMemPorts.
done = compOp.getArgument(inPortsIt);
793 extMemPorts.
writeData = compOp.getArgument(outPortsIt++);
794 unsigned nAddresses = extMemPortIndices.getFirst()
799 for (
unsigned j = 0; j < nAddresses; ++j)
800 extMemPorts.
addrPorts.push_back(compOp.getArgument(outPortsIt++));
801 extMemPorts.
writeEn = compOp.getArgument(outPortsIt);
805 compState->registerMemoryInterface(extMemPortIndices.getFirst(),
818 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
822 PatternRewriter &rewriter)
const override {
823 LogicalResult res = success();
824 funcOp.walk([&](Operation *op) {
826 if (!isa<scf::WhileOp>(op))
827 return WalkResult::advance();
829 auto scfWhileOp = cast<scf::WhileOp>(op);
832 getState<ComponentLoweringState>().setUniqueName(whileOp.
getOperation(),
842 enumerate(scfWhileOp.getBefore().front().getArguments())) {
843 auto condOp = scfWhileOp.getConditionOp().getArgs()[barg.index()];
844 if (barg.value() != condOp) {
846 << programState().irName(barg.value())
847 <<
" != " << programState().irName(condOp)
848 <<
"do-while loops not supported; expected iter-args to "
849 "remain untransformed in the 'before' region of the "
851 return WalkResult::interrupt();
860 for (
auto arg : enumerate(whileOp.
getBodyArgs())) {
861 std::string name = getState<ComponentLoweringState>()
864 "_arg" + std::to_string(arg.index());
867 arg.value().getType().getIntOrFloatBitWidth(), name);
868 getState<ComponentLoweringState>().addLoopIterReg(whileOp,
reg,
870 arg.value().replaceAllUsesWith(
reg.out());
874 ->getArgument(arg.index())
875 .replaceAllUsesWith(
reg.out());
879 SmallVector<calyx::GroupOp> initGroups;
880 auto numOperands = whileOp.
getOperation()->getNumOperands();
881 for (
size_t i = 0; i < numOperands; ++i) {
883 getState<ComponentLoweringState>().buildLoopIterArgAssignments(
885 getState<ComponentLoweringState>().getComponentOp(),
886 getState<ComponentLoweringState>().getUniqueName(
888 "_init_" + std::to_string(i),
890 initGroups.push_back(initGroupOp);
893 getState<ComponentLoweringState>().addBlockScheduleable(
898 return WalkResult::advance();
910 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
914 PatternRewriter &rewriter)
const override {
915 auto *entryBlock = &funcOp.getBlocks().front();
916 rewriter.setInsertionPointToStart(getComponent().getControlOp().getBody());
917 auto topLevelSeqOp = rewriter.create<calyx::SeqOp>(funcOp.getLoc());
918 DenseSet<Block *> path;
919 return buildCFGControl(path, rewriter, topLevelSeqOp.getBody(),
nullptr,
927 const DenseSet<Block *> &path,
928 mlir::Block *parentCtrlBlock,
929 mlir::Block *block)
const {
930 auto compBlockScheduleables =
931 getState<ComponentLoweringState>().getBlockScheduleables(block);
932 auto loc = block->front().getLoc();
934 if (compBlockScheduleables.size() > 1) {
935 auto seqOp = rewriter.create<calyx::SeqOp>(loc);
936 parentCtrlBlock = seqOp.getBody();
939 for (
auto &group : compBlockScheduleables) {
940 rewriter.setInsertionPointToEnd(parentCtrlBlock);
941 if (
auto groupPtr = std::get_if<calyx::GroupOp>(&group); groupPtr) {
942 rewriter.create<calyx::EnableOp>(groupPtr->getLoc(),
943 groupPtr->sym_name());
944 }
else if (
auto whileSchedPtr = std::get_if<WhileScheduleable>(&group);
946 auto &whileOp = whileSchedPtr->whileOp;
949 buildWhileCtrlOp(whileOp, whileSchedPtr->initGroups, rewriter);
950 rewriter.setInsertionPointToEnd(whileCtrlOp.getBody());
952 rewriter.create<calyx::SeqOp>(whileOp.getOperation()->getLoc());
953 auto *whileBodyOpBlock = whileBodyOp.getBody();
957 LogicalResult res = buildCFGControl(path, rewriter, whileBodyOpBlock,
958 block, whileOp.getBodyBlock());
961 rewriter.setInsertionPointToEnd(whileBodyOpBlock);
962 calyx::GroupOp whileLatchGroup =
963 getState<ComponentLoweringState>().getLoopLatchGroup(whileOp);
964 rewriter.create<calyx::EnableOp>(whileLatchGroup.getLoc(),
965 whileLatchGroup.getName());
970 llvm_unreachable(
"Unknown scheduleable");
981 const DenseSet<Block *> &path, Location loc,
982 Block *from, Block *to,
983 Block *parentCtrlBlock)
const {
986 rewriter.setInsertionPointToEnd(parentCtrlBlock);
987 auto preSeqOp = rewriter.create<calyx::SeqOp>(loc);
988 rewriter.setInsertionPointToEnd(preSeqOp.getBody());
990 getState<ComponentLoweringState>().getBlockArgGroups(from, to))
991 rewriter.create<calyx::EnableOp>(barg.getLoc(), barg.sym_name());
993 return buildCFGControl(path, rewriter, parentCtrlBlock, from, to);
997 PatternRewriter &rewriter,
998 mlir::Block *parentCtrlBlock,
999 mlir::Block *preBlock,
1000 mlir::Block *block)
const {
1001 if (path.count(block) != 0)
1002 return preBlock->getTerminator()->emitError()
1003 <<
"CFG backedge detected. Loops must be raised to 'scf.while' or "
1004 "'scf.for' operations.";
1006 rewriter.setInsertionPointToEnd(parentCtrlBlock);
1007 LogicalResult bbSchedResult =
1008 scheduleBasicBlock(rewriter, path, parentCtrlBlock, block);
1009 if (bbSchedResult.failed())
1010 return bbSchedResult;
1013 auto successors = block->getSuccessors();
1014 auto nSuccessors = successors.size();
1015 if (nSuccessors > 0) {
1016 auto brOp = dyn_cast<BranchOpInterface>(block->getTerminator());
1018 if (nSuccessors > 1) {
1022 assert(nSuccessors == 2 &&
1023 "only conditional branches supported for now...");
1025 auto cond = brOp->getOperand(0);
1026 auto condGroup = getState<ComponentLoweringState>()
1027 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
1031 auto ifOp = rewriter.create<calyx::IfOp>(
1032 brOp->getLoc(), cond, symbolAttr,
true);
1033 rewriter.setInsertionPointToStart(ifOp.getThenBody());
1034 auto thenSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
1035 rewriter.setInsertionPointToStart(ifOp.getElseBody());
1036 auto elseSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
1038 bool trueBrSchedSuccess =
1039 schedulePath(rewriter, path, brOp.getLoc(), block, successors[0],
1040 thenSeqOp.getBody())
1042 bool falseBrSchedSuccess =
true;
1043 if (trueBrSchedSuccess) {
1044 falseBrSchedSuccess =
1045 schedulePath(rewriter, path, brOp.getLoc(), block, successors[1],
1046 elseSeqOp.getBody())
1050 return success(trueBrSchedSuccess && falseBrSchedSuccess);
1053 return schedulePath(rewriter, path, brOp.getLoc(), block,
1054 successors.front(), parentCtrlBlock);
1061 SmallVector<calyx::GroupOp> initGroups,
1062 PatternRewriter &rewriter)
const {
1063 Location loc = whileOp.
getLoc();
1067 PatternRewriter::InsertionGuard g(rewriter);
1068 auto parOp = rewriter.create<calyx::ParOp>(loc);
1069 rewriter.setInsertionPointToStart(parOp.getBody());
1070 for (calyx::GroupOp group : initGroups)
1071 rewriter.create<calyx::EnableOp>(group.getLoc(), group.getName());
1076 auto condGroup = getState<ComponentLoweringState>()
1077 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
1080 return rewriter.create<calyx::WhileOp>(loc, cond, symbolAttr);
1087 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1090 PatternRewriter &)
const override {
1091 funcOp.walk([&](scf::WhileOp op) {
1100 getState<ComponentLoweringState>().getLoopIterRegs(whileOp))
1101 whileOp.
getOperation()->getResults()[res.first].replaceAllUsesWith(
1105 funcOp.walk([&](memref::LoadOp loadOp) {
1111 loadOp.getResult().replaceAllUsesWith(
1112 getState<ComponentLoweringState>()
1113 .getMemoryInterface(loadOp.memref())
1124 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1128 PatternRewriter &rewriter)
const override {
1129 rewriter.eraseOp(funcOp);
1141 void runOnOperation()
override;
1144 std::string &topLevelFunction) {
1145 if (!topLevelFunctionOpt.empty()) {
1146 if (SymbolTable::lookupSymbolIn(moduleOp, topLevelFunctionOpt) ==
1148 moduleOp.emitError() <<
"Top level function '" << topLevelFunctionOpt
1149 <<
"' not found in module.";
1152 topLevelFunction = topLevelFunctionOpt;
1156 auto funcOps = moduleOp.getOps<FuncOp>();
1157 if (std::distance(funcOps.begin(), funcOps.end()) == 1)
1158 topLevelFunction = (*funcOps.begin()).getSymName().str();
1160 moduleOp.emitError()
1161 <<
"Module contains multiple functions, but no top level "
1162 "function was set. Please see --top-level-function";
1181 calyx::ProgramOp *programOpOut) {
1185 using OpRewritePattern::OpRewritePattern;
1186 LogicalResult matchAndRewrite(mlir::ModuleOp,
1187 PatternRewriter &)
const override {
1192 ConversionTarget target(getContext());
1193 target.addLegalDialect<calyx::CalyxDialect>();
1194 target.addLegalDialect<scf::SCFDialect>();
1195 target.addIllegalDialect<hw::HWDialect>();
1196 target.addIllegalDialect<comb::CombDialect>();
1199 target.addIllegalOp<scf::ForOp>();
1202 target.addIllegalDialect<FuncDialect>();
1203 target.addIllegalDialect<ArithmeticDialect>();
1204 target.addLegalOp<AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp, AndIOp,
1205 XOrIOp, OrIOp, ExtUIOp, TruncIOp, CondBranchOp, BranchOp,
1206 MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp, ReturnOp,
1207 arith::ConstantOp, IndexCastOp, FuncOp, ExtSIOp>();
1209 RewritePatternSet legalizePatterns(&getContext());
1210 legalizePatterns.add<DummyPattern>(&getContext());
1211 DenseSet<Operation *> legalizedOps;
1212 if (applyPartialConversion(getOperation(), target,
1213 std::move(legalizePatterns))
1218 RewritePatternSet conversionPatterns(&getContext());
1220 &getContext(), topLevelFunction, programOpOut);
1221 return applyOpPatternsAndFold(getOperation(),
1222 std::move(conversionPatterns));
1228 template <
typename TPattern,
typename... PatternArgs>
1230 PatternArgs &&...args) {
1231 RewritePatternSet ps(&getContext());
1232 ps.add<TPattern>(&getContext(), partialPatternRes, args...);
1237 template <
typename TPattern,
typename... PatternArgs>
1239 PatternArgs &&...args) {
1240 RewritePatternSet ps(&getContext());
1241 ps.add<TPattern>(&getContext(), args...);
1247 assert(pattern.getNativePatterns().size() == 1 &&
1248 "Should only apply 1 partial lowering pattern at once");
1254 GreedyRewriteConfig config;
1255 config.enableRegionSimplification =
false;
1257 config.maxIterations = 0;
1262 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(pattern),
1264 return partialPatternRes;
1269 std::shared_ptr<calyx::ProgramLoweringState> loweringState =
nullptr;
1272 void SCFToCalyxPass::runOnOperation() {
1274 loweringState.reset();
1275 partialPatternRes = LogicalResult::failure();
1277 std::string topLevelFunction;
1278 if (failed(setTopLevelFunction(getOperation(), topLevelFunction))) {
1279 signalPassFailure();
1284 calyx::ProgramOp programOp;
1285 if (failed(createProgram(topLevelFunction, &programOp))) {
1286 signalPassFailure();
1289 assert(programOp.getOperation() !=
nullptr &&
1290 "programOp should have been set during module "
1291 "conversion, if module conversion succeeded.");
1292 loweringState = std::make_shared<calyx::ProgramLoweringState>(
1293 programOp, topLevelFunction);
1303 DenseMap<FuncOp, calyx::ComponentOp> funcMap;
1304 SmallVector<LoweringPattern, 8> loweringPatterns;
1307 addOncePattern<FuncOpConversion>(loweringPatterns, funcMap, *loweringState);
1310 addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);
1313 addOncePattern<calyx::ConvertIndexTypes>(loweringPatterns, funcMap,
1317 addOncePattern<calyx::BuildBasicBlockRegs>(loweringPatterns, funcMap,
1321 addOncePattern<calyx::BuildReturnRegs>(loweringPatterns, funcMap,
1327 addOncePattern<BuildWhileGroups>(loweringPatterns, funcMap, *loweringState);
1336 addOncePattern<BuildOpGroups>(loweringPatterns, funcMap, *loweringState);
1341 addOncePattern<BuildControl>(loweringPatterns, funcMap, *loweringState);
1345 addOncePattern<calyx::InlineCombGroups>(loweringPatterns, *loweringState);
1349 addOncePattern<LateSSAReplacement>(loweringPatterns, funcMap, *loweringState);
1354 addGreedyPattern<calyx::EliminateUnusedCombGroups>(loweringPatterns);
1358 addOncePattern<calyx::RewriteMemoryAccesses>(loweringPatterns,
1363 addOncePattern<CleanupFuncOps>(loweringPatterns, funcMap, *loweringState);
1366 for (
auto &pat : loweringPatterns) {
1367 LogicalResult partialPatternRes = runPartialPattern(
1369 pat.strategy == LoweringPattern::Strategy::Once);
1370 if (succeeded(partialPatternRes))
1372 signalPassFailure();
1379 RewritePatternSet cleanupPatterns(&getContext());
1382 if (failed(applyPatternsAndFoldGreedily(getOperation(),
1383 std::move(cleanupPatterns)))) {
1384 signalPassFailure();
1388 if (ciderSourceLocationMetadata) {
1391 SmallVector<Attribute, 16> sourceLocations;
1392 getOperation()->walk([&](calyx::ComponentOp component) {
1396 MLIRContext *context = getOperation()->getContext();
1397 getOperation()->setAttr(
"calyx.metadata",
1409 return std::make_unique<scftocalyx::SCFToCalyxPass>();