CIRCT 20.0.0git
Loading...
Searching...
No Matches
SCFToCalyx.cpp
Go to the documentation of this file.
1//===- SCFToCalyx.cpp - SCF to Calyx pass entry point -----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main SCF to Calyx conversion pass implementation.
10//
11//===----------------------------------------------------------------------===//
12
19#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
20#include "mlir/Conversion/LLVMCommon/Pattern.h"
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/MemRef/IR/MemRef.h"
25#include "mlir/Dialect/SCF/IR/SCF.h"
26#include "mlir/IR/AsmState.h"
27#include "mlir/IR/Matchers.h"
28#include "mlir/Pass/Pass.h"
29#include "mlir/Support/LogicalResult.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/LogicalResult.h"
33#include "llvm/Support/raw_os_ostream.h"
34#include "llvm/Support/raw_ostream.h"
35#include <algorithm>
36#include <filesystem>
37#include <fstream>
38
39#include <locale>
40#include <numeric>
41#include <variant>
42
43namespace circt {
44#define GEN_PASS_DEF_SCFTOCALYX
45#include "circt/Conversion/Passes.h.inc"
46} // namespace circt
47
48using namespace llvm;
49using namespace mlir;
50using namespace mlir::arith;
51using namespace mlir::cf;
52using namespace mlir::func;
53namespace circt {
54class ComponentLoweringStateInterface;
55namespace scftocalyx {
56
57//===----------------------------------------------------------------------===//
58// Utility types
59//===----------------------------------------------------------------------===//
60
61class ScfWhileOp : public calyx::WhileOpInterface<scf::WhileOp> {
62public:
63 explicit ScfWhileOp(scf::WhileOp op)
64 : calyx::WhileOpInterface<scf::WhileOp>(op) {}
65
66 Block::BlockArgListType getBodyArgs() override {
67 return getOperation().getAfterArguments();
68 }
69
70 Block *getBodyBlock() override { return &getOperation().getAfter().front(); }
71
72 Block *getConditionBlock() override {
73 return &getOperation().getBefore().front();
74 }
75
76 Value getConditionValue() override {
77 return getOperation().getConditionOp().getOperand(0);
78 }
79
80 std::optional<int64_t> getBound() override { return std::nullopt; }
81};
82
83class ScfForOp : public calyx::RepeatOpInterface<scf::ForOp> {
84public:
85 explicit ScfForOp(scf::ForOp op) : calyx::RepeatOpInterface<scf::ForOp>(op) {}
86
87 Block::BlockArgListType getBodyArgs() override {
88 return getOperation().getRegion().getArguments();
89 }
90
91 Block *getBodyBlock() override {
92 return &getOperation().getRegion().getBlocks().front();
93 }
94
95 std::optional<int64_t> getBound() override {
96 return constantTripCount(getOperation().getLowerBound(),
97 getOperation().getUpperBound(),
98 getOperation().getStep());
99 }
100};
101
102//===----------------------------------------------------------------------===//
103// Lowering state classes
104//===----------------------------------------------------------------------===//
105
107 scf::IfOp ifOp;
108};
109
111 /// While operation to schedule.
113};
114
116 /// For operation to schedule.
118 /// Bound
119 uint64_t bound;
120};
121
123 /// Instance for invoking.
124 calyx::InstanceOp instanceOp;
125 // CallOp for getting the arguments.
126 func::CallOp callOp;
127};
128
130 /// Parallel operation to schedule.
131 scf::ParallelOp parOp;
132};
133
134/// A variant of types representing scheduleable operations.
136 std::variant<calyx::GroupOp, WhileScheduleable, ForScheduleable,
138
140public:
141 void setThenGroup(scf::IfOp op, calyx::GroupOp group) {
142 Operation *operation = op.getOperation();
143 assert(thenGroup.count(operation) == 0 &&
144 "A then group was already set for this scf::IfOp!\n");
145 thenGroup[operation] = group;
146 }
147
148 calyx::GroupOp getThenGroup(scf::IfOp op) {
149 auto it = thenGroup.find(op.getOperation());
150 assert(it != thenGroup.end() &&
151 "No then group was set for this scf::IfOp!\n");
152 return it->second;
153 }
154
155 void setElseGroup(scf::IfOp op, calyx::GroupOp group) {
156 Operation *operation = op.getOperation();
157 assert(elseGroup.count(operation) == 0 &&
158 "An else group was already set for this scf::IfOp!\n");
159 elseGroup[operation] = group;
160 }
161
162 calyx::GroupOp getElseGroup(scf::IfOp op) {
163 auto it = elseGroup.find(op.getOperation());
164 assert(it != elseGroup.end() &&
165 "No else group was set for this scf::IfOp!\n");
166 return it->second;
167 }
168
169 void setResultRegs(scf::IfOp op, calyx::RegisterOp reg, unsigned idx) {
170 assert(resultRegs[op.getOperation()].count(idx) == 0 &&
171 "A register was already registered for the given yield result.\n");
172 assert(idx < op->getNumOperands());
173 resultRegs[op.getOperation()][idx] = reg;
174 }
175
176 const DenseMap<unsigned, calyx::RegisterOp> &getResultRegs(scf::IfOp op) {
177 return resultRegs[op.getOperation()];
178 }
179
180 calyx::RegisterOp getResultRegs(scf::IfOp op, unsigned idx) {
181 auto regs = getResultRegs(op);
182 auto it = regs.find(idx);
183 assert(it != regs.end() && "resultReg not found");
184 return it->second;
185 }
186
187private:
188 DenseMap<Operation *, calyx::GroupOp> thenGroup;
189 DenseMap<Operation *, calyx::GroupOp> elseGroup;
190 DenseMap<Operation *, DenseMap<unsigned, calyx::RegisterOp>> resultRegs;
191};
192
195public:
196 SmallVector<calyx::GroupOp> getWhileLoopInitGroups(ScfWhileOp op) {
197 return getLoopInitGroups(std::move(op));
198 }
200 OpBuilder &builder, ScfWhileOp op, calyx::ComponentOp componentOp,
201 Twine uniqueSuffix, MutableArrayRef<OpOperand> ops) {
202 return buildLoopIterArgAssignments(builder, std::move(op), componentOp,
203 uniqueSuffix, ops);
204 }
205 void addWhileLoopIterReg(ScfWhileOp op, calyx::RegisterOp reg, unsigned idx) {
206 return addLoopIterReg(std::move(op), reg, idx);
207 }
208 const DenseMap<unsigned, calyx::RegisterOp> &
210 return getLoopIterRegs(std::move(op));
211 }
212 void setWhileLoopLatchGroup(ScfWhileOp op, calyx::GroupOp group) {
213 return setLoopLatchGroup(std::move(op), group);
214 }
216 return getLoopLatchGroup(std::move(op));
217 }
219 SmallVector<calyx::GroupOp> groups) {
220 return setLoopInitGroups(std::move(op), std::move(groups));
221 }
222};
223
226public:
227 SmallVector<calyx::GroupOp> getForLoopInitGroups(ScfForOp op) {
228 return getLoopInitGroups(std::move(op));
229 }
231 OpBuilder &builder, ScfForOp op, calyx::ComponentOp componentOp,
232 Twine uniqueSuffix, MutableArrayRef<OpOperand> ops) {
233 return buildLoopIterArgAssignments(builder, std::move(op), componentOp,
234 uniqueSuffix, ops);
235 }
236 void addForLoopIterReg(ScfForOp op, calyx::RegisterOp reg, unsigned idx) {
237 return addLoopIterReg(std::move(op), reg, idx);
238 }
239 const DenseMap<unsigned, calyx::RegisterOp> &getForLoopIterRegs(ScfForOp op) {
240 return getLoopIterRegs(std::move(op));
241 }
242 calyx::RegisterOp getForLoopIterReg(ScfForOp op, unsigned idx) {
243 return getLoopIterReg(std::move(op), idx);
244 }
245 void setForLoopLatchGroup(ScfForOp op, calyx::GroupOp group) {
246 return setLoopLatchGroup(std::move(op), group);
247 }
248 calyx::GroupOp getForLoopLatchGroup(ScfForOp op) {
249 return getLoopLatchGroup(std::move(op));
250 }
251 void setForLoopInitGroups(ScfForOp op, SmallVector<calyx::GroupOp> groups) {
252 return setLoopInitGroups(std::move(op), std::move(groups));
253 }
254};
255
256/// Handles the current state of lowering of a Calyx component. It is mainly
257/// used as a key/value store for recording information during partial lowering,
258/// which is required at later lowering passes.
268
269//===----------------------------------------------------------------------===//
270// Conversion patterns
271//===----------------------------------------------------------------------===//
272
273/// Iterate through the operations of a source function and instantiate
274/// components or primitives based on the type of the operations.
276public:
277 BuildOpGroups(MLIRContext *context, LogicalResult &resRef,
279 DenseMap<mlir::func::FuncOp, calyx::ComponentOp> &map,
281 mlir::Pass::Option<std::string> &writeJsonOpt)
282 : FuncOpPartialLoweringPattern(context, resRef, patternState, map, state),
283 writeJson(writeJsonOpt) {}
284 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
285
286 LogicalResult
288 PatternRewriter &rewriter) const override {
289 /// We walk the operations of the funcOp to ensure that all def's have
290 /// been visited before their uses.
291 bool opBuiltSuccessfully = true;
292 funcOp.walk([&](Operation *_op) {
293 opBuiltSuccessfully &=
294 TypeSwitch<mlir::Operation *, bool>(_op)
295 .template Case<arith::ConstantOp, ReturnOp, BranchOpInterface,
296 /// SCF
297 scf::YieldOp, scf::WhileOp, scf::ForOp, scf::IfOp,
298 scf::ParallelOp, scf::ReduceOp,
299 scf::ExecuteRegionOp,
300 /// memref
301 memref::AllocOp, memref::AllocaOp, memref::LoadOp,
302 memref::StoreOp, memref::GetGlobalOp,
303 /// standard arithmetic
304 AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
305 AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
306 MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
307 /// floating point
308 AddFOp, MulFOp, CmpFOp,
309 /// others
310 SelectOp, IndexCastOp, CallOp>(
311 [&](auto op) { return buildOp(rewriter, op).succeeded(); })
312 .template Case<FuncOp, scf::ConditionOp>([&](auto) {
313 /// Skip: these special cases will be handled separately.
314 return true;
315 })
316 .Default([&](auto op) {
317 op->emitError() << "Unhandled operation during BuildOpGroups()";
318 return false;
319 });
320
321 return opBuiltSuccessfully ? WalkResult::advance()
322 : WalkResult::interrupt();
323 });
324
325 if (!writeJson.empty()) {
326 if (auto fileLoc = dyn_cast<mlir::FileLineColLoc>(funcOp->getLoc())) {
327 std::string filename = fileLoc.getFilename().str();
328 std::filesystem::path path(filename);
329 std::string jsonFileName = writeJson.getValue() + ".json";
330 auto outFileName = path.parent_path().append(jsonFileName);
331 std::ofstream outFile(outFileName);
332
333 if (!outFile.is_open()) {
334 llvm::errs() << "Unable to open file: " << outFileName.string()
335 << " for writing\n";
336 return failure();
337 }
338 llvm::raw_os_ostream llvmOut(outFile);
339 llvm::json::OStream jsonOS(llvmOut, /*IndentSize=*/2);
340 jsonOS.value(getState<ComponentLoweringState>().getExtMemData());
341 jsonOS.flush();
342 outFile.close();
343 }
344 }
345
346 return success(opBuiltSuccessfully);
347 }
348
349private:
 /// Reference to a string-valued pass option. It is held by reference, so the
 /// option object must outlive this pattern.
 mlir::Pass::Option<std::string> &writeJson;
 /// Op builder specializations.
 /// One overload per supported source operation type; each returns a
 /// LogicalResult.

 // Terminators / control transfer and constants.
 LogicalResult buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp) const;
 LogicalResult buildOp(PatternRewriter &rewriter,
                       BranchOpInterface brOp) const;
 LogicalResult buildOp(PatternRewriter &rewriter,
                       arith::ConstantOp constOp) const;
 // Integer arithmetic and select.
 LogicalResult buildOp(PatternRewriter &rewriter, SelectOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, AddIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, SubIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, MulIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, DivUIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, DivSIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, RemUIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op) const;
 // Floating point.
 LogicalResult buildOp(PatternRewriter &rewriter, AddFOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, MulFOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, CmpFOp op) const;
 // Shifts, bitwise logic, comparison and width changes.
 LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, AndIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, OrIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, XOrIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, CmpIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, TruncIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, ExtUIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, ExtSIOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, ReturnOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, IndexCastOp op) const;
 // Memory operations.
 LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocaOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter,
                       memref::GetGlobalOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, memref::LoadOp op) const;
 LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op) const;
 // Structured control flow.
 LogicalResult buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp) const;
 LogicalResult buildOp(PatternRewriter &rewriter, scf::ForOp forOp) const;
 LogicalResult buildOp(PatternRewriter &rewriter, scf::IfOp ifOp) const;
 LogicalResult buildOp(PatternRewriter &rewriter,
                       scf::ReduceOp reduceOp) const;
 LogicalResult buildOp(PatternRewriter &rewriter,
                       scf::ParallelOp parallelOp) const;
 LogicalResult buildOp(PatternRewriter &rewriter,
                       scf::ExecuteRegionOp executeRegionOp) const;
 // Function calls.
 LogicalResult buildOp(PatternRewriter &rewriter, CallOp callOp) const;
396
397 /// buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the
398 /// source operation TSrcOp.
399 template <typename TGroupOp, typename TCalyxLibOp, typename TSrcOp>
400 LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op,
401 TypeRange srcTypes, TypeRange dstTypes) const {
402 SmallVector<Type> types;
403 for (Type srcType : srcTypes)
404 types.push_back(calyx::toBitVector(srcType));
405 for (Type dstType : dstTypes)
406 types.push_back(calyx::toBitVector(dstType));
407
408 auto calyxOp =
409 getState<ComponentLoweringState>().getNewLibraryOpInstance<TCalyxLibOp>(
410 rewriter, op.getLoc(), types);
411
412 auto directions = calyxOp.portDirections();
413 SmallVector<Value, 4> opInputPorts;
414 SmallVector<Value, 4> opOutputPorts;
415 for (auto dir : enumerate(directions)) {
416 if (dir.value() == calyx::Direction::Input)
417 opInputPorts.push_back(calyxOp.getResult(dir.index()));
418 else
419 opOutputPorts.push_back(calyxOp.getResult(dir.index()));
420 }
421 assert(
422 opInputPorts.size() == op->getNumOperands() &&
423 opOutputPorts.size() == op->getNumResults() &&
424 "Expected an equal number of in/out ports in the Calyx library op with "
425 "respect to the number of operands/results of the source operation.");
426
427 /// Create assignments to the inputs of the library op.
428 auto group = createGroupForOp<TGroupOp>(rewriter, op);
429 rewriter.setInsertionPointToEnd(group.getBodyBlock());
430 for (auto dstOp : enumerate(opInputPorts))
431 rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
432 op->getOperand(dstOp.index()));
433
434 /// Replace the result values of the source operator with the new operator.
435 for (auto res : enumerate(opOutputPorts)) {
436 getState<ComponentLoweringState>().registerEvaluatingGroup(res.value(),
437 group);
438 op->getResult(res.index()).replaceAllUsesWith(res.value());
439 }
440 return success();
441 }
442
443 /// buildLibraryOp which provides in- and output types based on the operands
444 /// and results of the op argument.
445 template <typename TGroupOp, typename TCalyxLibOp, typename TSrcOp>
446 LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op) const {
447 return buildLibraryOp<TGroupOp, TCalyxLibOp, TSrcOp>(
448 rewriter, op, op.getOperandTypes(), op->getResultTypes());
449 }
450
451 /// Creates a group named by the basic block which the input op resides in.
452 template <typename TGroupOp>
453 TGroupOp createGroupForOp(PatternRewriter &rewriter, Operation *op) const {
454 Block *block = op->getBlock();
455 auto groupName = getState<ComponentLoweringState>().getUniqueName(
456 loweringState().blockName(block));
457 return calyx::createGroup<TGroupOp>(
458 rewriter, getState<ComponentLoweringState>().getComponentOp(),
459 op->getLoc(), groupName);
460 }
461
462 /// buildLibraryBinaryPipeOp will build a TCalyxLibBinaryPipeOp, to
463 /// deal with MulIOp, DivUIOp and RemUIOp.
464 template <typename TOpType, typename TSrcOp>
465 LogicalResult buildLibraryBinaryPipeOp(PatternRewriter &rewriter, TSrcOp op,
466 TOpType opPipe, Value out) const {
467 StringRef opName = TSrcOp::getOperationName().split(".").second;
468 Location loc = op.getLoc();
469 Type width = op.getResult().getType();
470 auto reg = createRegister(
471 op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(),
472 getState<ComponentLoweringState>().getUniqueName(opName));
473 // Operation pipelines are not combinational, so a GroupOp is required.
474 auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
475 OpBuilder builder(group->getRegion(0));
476 getState<ComponentLoweringState>().addBlockScheduleable(op->getBlock(),
477 group);
478
479 rewriter.setInsertionPointToEnd(group.getBodyBlock());
480 rewriter.create<calyx::AssignOp>(loc, opPipe.getLeft(), op.getLhs());
481 rewriter.create<calyx::AssignOp>(loc, opPipe.getRight(), op.getRhs());
482 // Write the output to this register.
483 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), out);
484 // The write enable port is high when the pipeline is done.
485 rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), opPipe.getDone());
486 // Set pipelineOp to high as long as its done signal is not high.
487 // This prevents the pipelineOP from executing for the cycle that we write
488 // to register. To get !(pipelineOp.done) we do 1 xor pipelineOp.done
489 hw::ConstantOp c1 = createConstant(loc, rewriter, getComponent(), 1, 1);
490 rewriter.create<calyx::AssignOp>(
491 loc, opPipe.getGo(), c1,
492 comb::createOrFoldNot(group.getLoc(), opPipe.getDone(), builder));
493 // The group is done when the register write is complete.
494 rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());
495
496 // Pass the result from the source operation to register holding the resullt
497 // from the Calyx primitive.
498 op.getResult().replaceAllUsesWith(reg.getOut());
499
500 if (isa<calyx::AddFOpIEEE754>(opPipe)) {
501 auto opFOp = cast<calyx::AddFOpIEEE754>(opPipe);
502 hw::ConstantOp subOp;
503 if (isa<arith::AddFOp>(op)) {
504 subOp = createConstant(loc, rewriter, getComponent(), /*width=*/1,
505 /*subtract=*/0);
506 } else {
507 subOp = createConstant(loc, rewriter, getComponent(), /*width=*/1,
508 /*subtract=*/1);
509 }
510 rewriter.create<calyx::AssignOp>(loc, opFOp.getSubOp(), subOp);
511 }
512
513 // Register the values for the pipeline.
514 getState<ComponentLoweringState>().registerEvaluatingGroup(out, group);
515 getState<ComponentLoweringState>().registerEvaluatingGroup(opPipe.getLeft(),
516 group);
517 getState<ComponentLoweringState>().registerEvaluatingGroup(
518 opPipe.getRight(), group);
519
520 return success();
521 }
522
523 /// Creates assignments within the provided group to the address ports of the
524 /// memoryOp based on the provided addressValues.
525 void assignAddressPorts(PatternRewriter &rewriter, Location loc,
526 calyx::GroupInterface group,
527 calyx::MemoryInterface memoryInterface,
528 Operation::operand_range addressValues) const {
529 IRRewriter::InsertionGuard guard(rewriter);
530 rewriter.setInsertionPointToEnd(group.getBody());
531 auto addrPorts = memoryInterface.addrPorts();
532 if (addressValues.empty()) {
533 assert(
534 addrPorts.size() == 1 &&
535 "We expected a 1 dimensional memory of size 1 because there were no "
536 "address assignment values");
537 // Assign to address 1'd0 in memory.
538 rewriter.create<calyx::AssignOp>(
539 loc, addrPorts[0],
540 createConstant(loc, rewriter, getComponent(), 1, 0));
541 } else {
542 assert(addrPorts.size() == addressValues.size() &&
543 "Mismatch between number of address ports of the provided memory "
544 "and address assignment values");
545 for (auto address : enumerate(addressValues))
546 rewriter.create<calyx::AssignOp>(loc, addrPorts[address.index()],
547 address.value());
548 }
549 }
550
551 calyx::RegisterOp createSignalRegister(PatternRewriter &rewriter,
552 Value signal, bool invert,
553 StringRef nameSuffix,
554 calyx::CompareFOpIEEE754 calyxCmpFOp,
555 calyx::GroupOp group) const {
556 Location loc = calyxCmpFOp.getLoc();
557 IntegerType one = rewriter.getI1Type();
558 auto component = getComponent();
559 OpBuilder builder(group->getRegion(0));
560 auto reg = createRegister(
561 loc, rewriter, component, 1,
562 getState<ComponentLoweringState>().getUniqueName(nameSuffix));
563 rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(),
564 calyxCmpFOp.getDone());
565 if (invert) {
566 auto notLibOp = getState<ComponentLoweringState>()
567 .getNewLibraryOpInstance<calyx::NotLibOp>(
568 rewriter, loc, {one, one});
569 rewriter.create<calyx::AssignOp>(loc, notLibOp.getIn(), signal);
570 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), notLibOp.getOut());
571 getState<ComponentLoweringState>().registerEvaluatingGroup(
572 notLibOp.getOut(), group);
573 } else
574 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), signal);
575 return reg;
576 };
577};
578
579LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
580 memref::LoadOp loadOp) const {
581 Value memref = loadOp.getMemref();
582 auto memoryInterface =
583 getState<ComponentLoweringState>().getMemoryInterface(memref);
584 auto group = createGroupForOp<calyx::GroupOp>(rewriter, loadOp);
585 assignAddressPorts(rewriter, loadOp.getLoc(), group, memoryInterface,
586 loadOp.getIndices());
587
588 rewriter.setInsertionPointToEnd(group.getBodyBlock());
589
590 bool needReg = true;
591 Value res;
592 Value regWriteEn =
593 createConstant(loadOp.getLoc(), rewriter, getComponent(), 1, 1);
594 if (memoryInterface.readEnOpt().has_value()) {
595 auto oneI1 =
596 calyx::createConstant(loadOp.getLoc(), rewriter, getComponent(), 1, 1);
597 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), memoryInterface.readEn(),
598 oneI1);
599 regWriteEn = memoryInterface.done();
600 if (calyx::noStoresToMemory(memref) &&
602 // Single load from memory; we do not need to write the output to a
603 // register. The readData value will be held until readEn is asserted
604 // again
605 needReg = false;
606 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(),
607 memoryInterface.done());
608 // We refrain from replacing the loadOp result with
609 // memoryInterface.readData, since multiple loadOp's need to be converted
610 // to a single memory's ReadData. If this replacement is done now, we lose
611 // the link between which SSA memref::LoadOp values map to which groups
612 // for loading a value from the Calyx memory. At this point of lowering,
613 // we keep the memref::LoadOp SSA value, and do value replacement _after_
614 // control has been generated (see LateSSAReplacement). This is *vital*
615 // for things such as calyx::InlineCombGroups to be able to properly track
616 // which memory assignment groups belong to which accesses.
617 res = loadOp.getResult();
618 }
619 } else if (memoryInterface.contentEnOpt().has_value()) {
620 auto oneI1 =
621 calyx::createConstant(loadOp.getLoc(), rewriter, getComponent(), 1, 1);
622 auto zeroI1 =
623 calyx::createConstant(loadOp.getLoc(), rewriter, getComponent(), 1, 0);
624 rewriter.create<calyx::AssignOp>(loadOp.getLoc(),
625 memoryInterface.contentEn(), oneI1);
626 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), memoryInterface.writeEn(),
627 zeroI1);
628 regWriteEn = memoryInterface.done();
629 if (calyx::noStoresToMemory(memref) &&
631 // Single load from memory; we do not need to write the output to a
632 // register. The readData value will be held until contentEn is asserted
633 // again
634 needReg = false;
635 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(),
636 memoryInterface.done());
637 // We refrain from replacing the loadOp result with
638 // memoryInterface.readData, since multiple loadOp's need to be converted
639 // to a single memory's ReadData. If this replacement is done now, we lose
640 // the link between which SSA memref::LoadOp values map to which groups
641 // for loading a value from the Calyx memory. At this point of lowering,
642 // we keep the memref::LoadOp SSA value, and do value replacement _after_
643 // control has been generated (see LateSSAReplacement). This is *vital*
644 // for things such as calyx::InlineCombGroups to be able to properly track
645 // which memory assignment groups belong to which accesses.
646 res = loadOp.getResult();
647 }
648 }
649
650 if (needReg) {
651 // Multiple loads from the same memory; In this case, we _may_ have a
652 // structural hazard in the design we generate. To get around this, we
653 // conservatively place a register in front of each load operation, and
654 // replace all uses of the loaded value with the register output. Reading
655 // for sequential memories will cause a read to take at least 2 cycles,
656 // but it will usually be better because combinational reads on memories
657 // can significantly decrease the maximum achievable frequency.
658 auto reg = createRegister(
659 loadOp.getLoc(), rewriter, getComponent(),
660 loadOp.getMemRefType().getElementTypeBitWidth(),
661 getState<ComponentLoweringState>().getUniqueName("load"));
662 rewriter.setInsertionPointToEnd(group.getBodyBlock());
663 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), reg.getIn(),
664 memoryInterface.readData());
665 rewriter.create<calyx::AssignOp>(loadOp.getLoc(), reg.getWriteEn(),
666 regWriteEn);
667 rewriter.create<calyx::GroupDoneOp>(loadOp.getLoc(), reg.getDone());
668 loadOp.getResult().replaceAllUsesWith(reg.getOut());
669 res = reg.getOut();
670 }
671
672 getState<ComponentLoweringState>().registerEvaluatingGroup(res, group);
673 getState<ComponentLoweringState>().addBlockScheduleable(loadOp->getBlock(),
674 group);
675 return success();
676}
677
678LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
679 memref::StoreOp storeOp) const {
680 auto memoryInterface = getState<ComponentLoweringState>().getMemoryInterface(
681 storeOp.getMemref());
682 auto group = createGroupForOp<calyx::GroupOp>(rewriter, storeOp);
683
684 // This is a sequential group, so register it as being scheduleable for the
685 // block.
686 getState<ComponentLoweringState>().addBlockScheduleable(storeOp->getBlock(),
687 group);
688 assignAddressPorts(rewriter, storeOp.getLoc(), group, memoryInterface,
689 storeOp.getIndices());
690 rewriter.setInsertionPointToEnd(group.getBodyBlock());
691 rewriter.create<calyx::AssignOp>(
692 storeOp.getLoc(), memoryInterface.writeData(), storeOp.getValueToStore());
693 rewriter.create<calyx::AssignOp>(
694 storeOp.getLoc(), memoryInterface.writeEn(),
695 createConstant(storeOp.getLoc(), rewriter, getComponent(), 1, 1));
696 if (memoryInterface.contentEnOpt().has_value()) {
697 // If memory has content enable, it must be asserted when writing
698 rewriter.create<calyx::AssignOp>(
699 storeOp.getLoc(), memoryInterface.contentEn(),
700 createConstant(storeOp.getLoc(), rewriter, getComponent(), 1, 1));
701 }
702 rewriter.create<calyx::GroupDoneOp>(storeOp.getLoc(), memoryInterface.done());
703
704 return success();
705}
706
707LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
708 MulIOp mul) const {
709 Location loc = mul.getLoc();
710 Type width = mul.getResult().getType(), one = rewriter.getI1Type();
711 auto mulPipe =
712 getState<ComponentLoweringState>()
713 .getNewLibraryOpInstance<calyx::MultPipeLibOp>(
714 rewriter, loc, {one, one, one, width, width, width, one});
715 return buildLibraryBinaryPipeOp<calyx::MultPipeLibOp>(
716 rewriter, mul, mulPipe,
717 /*out=*/mulPipe.getOut());
718}
719
720LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
721 DivUIOp div) const {
722 Location loc = div.getLoc();
723 Type width = div.getResult().getType(), one = rewriter.getI1Type();
724 auto divPipe =
725 getState<ComponentLoweringState>()
726 .getNewLibraryOpInstance<calyx::DivUPipeLibOp>(
727 rewriter, loc, {one, one, one, width, width, width, one});
728 return buildLibraryBinaryPipeOp<calyx::DivUPipeLibOp>(
729 rewriter, div, divPipe,
730 /*out=*/divPipe.getOut());
731}
732
733LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
734 DivSIOp div) const {
735 Location loc = div.getLoc();
736 Type width = div.getResult().getType(), one = rewriter.getI1Type();
737 auto divPipe =
738 getState<ComponentLoweringState>()
739 .getNewLibraryOpInstance<calyx::DivSPipeLibOp>(
740 rewriter, loc, {one, one, one, width, width, width, one});
741 return buildLibraryBinaryPipeOp<calyx::DivSPipeLibOp>(
742 rewriter, div, divPipe,
743 /*out=*/divPipe.getOut());
744}
745
746LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
747 RemUIOp rem) const {
748 Location loc = rem.getLoc();
749 Type width = rem.getResult().getType(), one = rewriter.getI1Type();
750 auto remPipe =
751 getState<ComponentLoweringState>()
752 .getNewLibraryOpInstance<calyx::RemUPipeLibOp>(
753 rewriter, loc, {one, one, one, width, width, width, one});
754 return buildLibraryBinaryPipeOp<calyx::RemUPipeLibOp>(
755 rewriter, rem, remPipe,
756 /*out=*/remPipe.getOut());
757}
758
759LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
760 RemSIOp rem) const {
761 Location loc = rem.getLoc();
762 Type width = rem.getResult().getType(), one = rewriter.getI1Type();
763 auto remPipe =
764 getState<ComponentLoweringState>()
765 .getNewLibraryOpInstance<calyx::RemSPipeLibOp>(
766 rewriter, loc, {one, one, one, width, width, width, one});
767 return buildLibraryBinaryPipeOp<calyx::RemSPipeLibOp>(
768 rewriter, rem, remPipe,
769 /*out=*/remPipe.getOut());
770}
771
772LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
773 AddFOp addf) const {
774 Location loc = addf.getLoc();
775 IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
776 five = rewriter.getIntegerType(5),
777 width = rewriter.getIntegerType(
778 addf.getType().getIntOrFloatBitWidth());
779 auto addFOp =
780 getState<ComponentLoweringState>()
781 .getNewLibraryOpInstance<calyx::AddFOpIEEE754>(
782 rewriter, loc,
783 {one, one, one, one, one, width, width, three, width, five, one});
784 return buildLibraryBinaryPipeOp<calyx::AddFOpIEEE754>(rewriter, addf, addFOp,
785 addFOp.getOut());
786}
787
788LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
789 MulFOp mulf) const {
790 Location loc = mulf.getLoc();
791 IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
792 five = rewriter.getIntegerType(5),
793 width = rewriter.getIntegerType(
794 mulf.getType().getIntOrFloatBitWidth());
795 auto mulFOp =
796 getState<ComponentLoweringState>()
797 .getNewLibraryOpInstance<calyx::MulFOpIEEE754>(
798 rewriter, loc,
799 {one, one, one, one, width, width, three, width, five, one});
800 return buildLibraryBinaryPipeOp<calyx::MulFOpIEEE754>(rewriter, mulf, mulFOp,
801 mulFOp.getOut());
802}
803
804LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
805 CmpFOp cmpf) const {
806 Location loc = cmpf.getLoc();
807 IntegerType one = rewriter.getI1Type(), five = rewriter.getIntegerType(5),
808 width = rewriter.getIntegerType(
809 cmpf.getLhs().getType().getIntOrFloatBitWidth());
810 auto calyxCmpFOp = getState<ComponentLoweringState>()
811 .getNewLibraryOpInstance<calyx::CompareFOpIEEE754>(
812 rewriter, loc,
813 {one, one, one, width, width, one, one, one, one,
814 one, five, one});
815 hw::ConstantOp c0 = createConstant(loc, rewriter, getComponent(), 1, 0);
816 hw::ConstantOp c1 = createConstant(loc, rewriter, getComponent(), 1, 1);
817 rewriter.setInsertionPointToStart(getComponent().getBodyBlock());
818
820 using CombLogic = PredicateInfo::CombLogic;
821 using Port = PredicateInfo::InputPorts::Port;
822 PredicateInfo info = calyx::getPredicateInfo(cmpf.getPredicate());
823 if (info.logic == CombLogic::None) {
824 if (cmpf.getPredicate() == CmpFPredicate::AlwaysTrue) {
825 cmpf.getResult().replaceAllUsesWith(c1);
826 return success();
827 }
828
829 if (cmpf.getPredicate() == CmpFPredicate::AlwaysFalse) {
830 cmpf.getResult().replaceAllUsesWith(c0);
831 return success();
832 }
833 }
834
835 // General case
836 StringRef opName = cmpf.getOperationName().split(".").second;
837 auto reg =
838 createRegister(loc, rewriter, getComponent(), 1,
839 getState<ComponentLoweringState>().getUniqueName(opName));
840
841 // Operation pipelines are not combinational, so a GroupOp is required.
842 auto group = createGroupForOp<calyx::GroupOp>(rewriter, cmpf);
843 OpBuilder builder(group->getRegion(0));
844 getState<ComponentLoweringState>().addBlockScheduleable(cmpf->getBlock(),
845 group);
846
847 rewriter.setInsertionPointToEnd(group.getBodyBlock());
848 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getLeft(), cmpf.getLhs());
849 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getRight(), cmpf.getRhs());
850
851 bool signalingFlag = false;
852 switch (cmpf.getPredicate()) {
853 case CmpFPredicate::UGT:
854 case CmpFPredicate::UGE:
855 case CmpFPredicate::ULT:
856 case CmpFPredicate::ULE:
857 case CmpFPredicate::OGT:
858 case CmpFPredicate::OGE:
859 case CmpFPredicate::OLT:
860 case CmpFPredicate::OLE:
861 signalingFlag = true;
862 break;
863 case CmpFPredicate::UEQ:
864 case CmpFPredicate::UNE:
865 case CmpFPredicate::OEQ:
866 case CmpFPredicate::ONE:
867 case CmpFPredicate::UNO:
868 case CmpFPredicate::ORD:
869 case CmpFPredicate::AlwaysTrue:
870 case CmpFPredicate::AlwaysFalse:
871 signalingFlag = false;
872 break;
873 }
874
875 // The IEEE Standard mandates that equality comparisons ordinarily are quiet,
876 // while inequality comparisons ordinarily are signaling.
877 rewriter.create<calyx::AssignOp>(loc, calyxCmpFOp.getSignaling(),
878 signalingFlag ? c1 : c0);
879
880 // Prepare signals and create registers
881 SmallVector<calyx::RegisterOp> inputRegs;
882 for (const auto &input : info.inputPorts) {
883 Value signal;
884 switch (input.port) {
885 case Port::Eq: {
886 signal = calyxCmpFOp.getEq();
887 break;
888 }
889 case Port::Gt: {
890 signal = calyxCmpFOp.getGt();
891 break;
892 }
893 case Port::Lt: {
894 signal = calyxCmpFOp.getLt();
895 break;
896 }
897 case Port::Unordered: {
898 signal = calyxCmpFOp.getUnordered();
899 break;
900 }
901 }
902 std::string nameSuffix =
903 (input.port == PredicateInfo::InputPorts::Port::Unordered)
904 ? "unordered_port"
905 : "compare_port";
906 auto signalReg = createSignalRegister(rewriter, signal, input.invert,
907 nameSuffix, calyxCmpFOp, group);
908 inputRegs.push_back(signalReg);
909 }
910
911 // Create the output logical operation
912 Value outputValue, doneValue;
913 switch (info.logic) {
914 case CombLogic::None: {
915 // it's guaranteed to be either ORD or UNO
916 outputValue = inputRegs[0].getOut();
917 doneValue = inputRegs[0].getDone();
918 break;
919 }
920 case CombLogic::And: {
921 auto outputLibOp = getState<ComponentLoweringState>()
922 .getNewLibraryOpInstance<calyx::AndLibOp>(
923 rewriter, loc, {one, one, one});
924 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
925 inputRegs[0].getOut());
926 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
927 inputRegs[1].getOut());
928
929 outputValue = outputLibOp.getOut();
930 break;
931 }
932 case CombLogic::Or: {
933 auto outputLibOp = getState<ComponentLoweringState>()
934 .getNewLibraryOpInstance<calyx::OrLibOp>(
935 rewriter, loc, {one, one, one});
936 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getLeft(),
937 inputRegs[0].getOut());
938 rewriter.create<calyx::AssignOp>(loc, outputLibOp.getRight(),
939 inputRegs[1].getOut());
940
941 outputValue = outputLibOp.getOut();
942 break;
943 }
944 }
945
946 if (info.logic != CombLogic::None) {
947 auto doneLibOp = getState<ComponentLoweringState>()
948 .getNewLibraryOpInstance<calyx::AndLibOp>(
949 rewriter, loc, {one, one, one});
950 rewriter.create<calyx::AssignOp>(loc, doneLibOp.getLeft(),
951 inputRegs[0].getDone());
952 rewriter.create<calyx::AssignOp>(loc, doneLibOp.getRight(),
953 inputRegs[1].getDone());
954 doneValue = doneLibOp.getOut();
955 }
956
957 // Write to the output register
958 rewriter.create<calyx::AssignOp>(loc, reg.getIn(), outputValue);
959 rewriter.create<calyx::AssignOp>(loc, reg.getWriteEn(), doneValue);
960
961 // Set the go and done signal
962 rewriter.create<calyx::AssignOp>(
963 loc, calyxCmpFOp.getGo(), c1,
964 comb::createOrFoldNot(loc, calyxCmpFOp.getDone(), builder));
965 rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());
966
967 cmpf.getResult().replaceAllUsesWith(reg.getOut());
968
969 // Register evaluating groups
970 getState<ComponentLoweringState>().registerEvaluatingGroup(outputValue,
971 group);
972 getState<ComponentLoweringState>().registerEvaluatingGroup(doneValue, group);
973 getState<ComponentLoweringState>().registerEvaluatingGroup(
974 calyxCmpFOp.getLeft(), group);
975 getState<ComponentLoweringState>().registerEvaluatingGroup(
976 calyxCmpFOp.getRight(), group);
977
978 return success();
979}
980
981template <typename TAllocOp>
982static LogicalResult buildAllocOp(ComponentLoweringState &componentState,
983 PatternRewriter &rewriter, TAllocOp allocOp) {
984 rewriter.setInsertionPointToStart(
985 componentState.getComponentOp().getBodyBlock());
986 MemRefType memtype = allocOp.getType();
987 SmallVector<int64_t> addrSizes;
988 SmallVector<int64_t> sizes;
989 for (int64_t dim : memtype.getShape()) {
990 sizes.push_back(dim);
991 addrSizes.push_back(calyx::handleZeroWidth(dim));
992 }
993 // If memref has no size (e.g., memref<i32>) create a 1 dimensional memory of
994 // size 1.
995 if (sizes.empty() && addrSizes.empty()) {
996 sizes.push_back(1);
997 addrSizes.push_back(1);
998 }
999 auto memoryOp = rewriter.create<calyx::SeqMemoryOp>(
1000 allocOp.getLoc(), componentState.getUniqueName("mem"),
1001 memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
1002
1003 // Externalize memories conditionally (only in the top-level component because
1004 // Calyx compiler requires it as a well-formness check).
1005 memoryOp->setAttr("external",
1006 IntegerAttr::get(rewriter.getI1Type(), llvm::APInt(1, 1)));
1007 componentState.registerMemoryInterface(allocOp.getResult(),
1008 calyx::MemoryInterface(memoryOp));
1009
1010 unsigned elmTyBitWidth = memtype.getElementTypeBitWidth();
1011 assert(elmTyBitWidth <= 64 && "element bitwidth should not exceed 64");
1012 bool isFloat = !memtype.getElementType().isInteger();
1013
1014 auto shape = allocOp.getType().getShape();
1015 int totalSize =
1016 std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
1017 // The `totalSize <= 1` check is a hack to:
1018 // https://github.com/llvm/circt/pull/2661, where a multi-dimensional memory
1019 // whose size in some dimension equals 1, e.g. memref<1x1x1x1xi32>, will be
1020 // collapsed to `memref<1xi32>` with `totalSize == 1`. While the above case is
1021 // a trivial fix, Calyx expects 1-dimensional memories in general:
1022 // https://github.com/calyxir/calyx/issues/907
1023 if (!(shape.size() <= 1 || totalSize <= 1)) {
1024 allocOp.emitError("input memory dimension must be empty or one.");
1025 return failure();
1026 }
1027
1028 std::vector<uint64_t> flattenedVals(totalSize, 0);
1029 if (isa<memref::GetGlobalOp>(allocOp)) {
1030 auto getGlobalOp = cast<memref::GetGlobalOp>(allocOp);
1031 auto *symbolTableOp =
1032 getGlobalOp->template getParentWithTrait<mlir::OpTrait::SymbolTable>();
1033 auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
1034 SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
1035 // Flatten the values in the attribute
1036 auto cstAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(
1037 globalOp.getConstantInitValue());
1038 int sizeCount = 0;
1039 for (auto attr : cstAttr.template getValues<Attribute>()) {
1040 assert((isa<mlir::FloatAttr, mlir::IntegerAttr>(attr)) &&
1041 "memory attributes must be float or int");
1042 if (auto fltAttr = dyn_cast<mlir::FloatAttr>(attr)) {
1043 flattenedVals[sizeCount++] =
1044 bit_cast<uint64_t>(fltAttr.getValueAsDouble());
1045 } else {
1046 auto intAttr = dyn_cast<mlir::IntegerAttr>(attr);
1047 APInt value = intAttr.getValue();
1048 flattenedVals[sizeCount++] = *value.getRawData();
1049 }
1050 }
1051
1052 rewriter.eraseOp(globalOp);
1053 }
1054
1055 llvm::json::Array result;
1056 result.reserve(std::max(static_cast<int>(shape.size()), 1));
1057
1058 Type elemType = memtype.getElementType();
1059 bool isSigned =
1060 !elemType.isSignlessInteger() && !elemType.isUnsignedInteger();
1061 for (uint64_t bitValue : flattenedVals) {
1062 llvm::json::Value value = 0;
1063 if (isFloat) {
1064 // We cast to `double` and let downstream calyx to deal with the actual
1065 // value's precision handling.
1066 value = bit_cast<double>(bitValue);
1067 } else {
1068 APInt apInt(/*numBits=*/elmTyBitWidth, bitValue, isSigned,
1069 /*implicitTrunc=*/true);
1070 // The conditional ternary operation will cause the `value` to interpret
1071 // the underlying data as unsigned regardless `isSigned` or not.
1072 if (isSigned)
1073 value = static_cast<int64_t>(apInt.getSExtValue());
1074 else
1075 value = apInt.getZExtValue();
1076 }
1077 result.push_back(std::move(value));
1078 }
1079
1080 componentState.setDataField(memoryOp.getName(), result);
1081 std::string numType =
1082 memtype.getElementType().isInteger() ? "bitnum" : "ieee754_float";
1083 componentState.setFormat(memoryOp.getName(), numType, isSigned,
1084 elmTyBitWidth);
1085
1086 return success();
1087}
1088
1089LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1090 memref::AllocOp allocOp) const {
1091 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
1092}
1093
1094LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1095 memref::AllocaOp allocOp) const {
1096 return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
1097}
1098
1099LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1100 memref::GetGlobalOp getGlobalOp) const {
1101 return buildAllocOp(getState<ComponentLoweringState>(), rewriter,
1102 getGlobalOp);
1103}
1104
1105LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1106 scf::YieldOp yieldOp) const {
1107 if (yieldOp.getOperands().empty()) {
1108 if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
1109 ScfForOp forOpInterface(forOp);
1110
1111 // Get the ForLoop's Induction Register.
1112 auto inductionReg = getState<ComponentLoweringState>().getForLoopIterReg(
1113 forOpInterface, 0);
1114
1115 Type regWidth = inductionReg.getOut().getType();
1116 // Adder should have same width as the inductionReg.
1117 SmallVector<Type> types(3, regWidth);
1118 auto addOp = getState<ComponentLoweringState>()
1119 .getNewLibraryOpInstance<calyx::AddLibOp>(
1120 rewriter, forOp.getLoc(), types);
1121
1122 auto directions = addOp.portDirections();
1123 // For an add operation, we expect two input ports and one output port.
1124 SmallVector<Value, 2> opInputPorts;
1125 Value opOutputPort;
1126 for (auto dir : enumerate(directions)) {
1127 switch (dir.value()) {
1129 opInputPorts.push_back(addOp.getResult(dir.index()));
1130 break;
1131 }
1133 opOutputPort = addOp.getResult(dir.index());
1134 break;
1135 }
1136 }
1137 }
1138
1139 // "Latch Group" increments inductionReg by forLoop's step value.
1140 calyx::ComponentOp componentOp =
1141 getState<ComponentLoweringState>().getComponentOp();
1142 SmallVector<StringRef, 4> groupIdentifier = {
1143 "incr", getState<ComponentLoweringState>().getUniqueName(forOp),
1144 "induction", "var"};
1145 auto groupOp = calyx::createGroup<calyx::GroupOp>(
1146 rewriter, componentOp, forOp.getLoc(),
1147 llvm::join(groupIdentifier, "_"));
1148 rewriter.setInsertionPointToEnd(groupOp.getBodyBlock());
1149
1150 // Assign inductionReg.out to the left port of the adder.
1151 Value leftOp = opInputPorts.front();
1152 rewriter.create<calyx::AssignOp>(forOp.getLoc(), leftOp,
1153 inductionReg.getOut());
1154 // Assign forOp.getConstantStep to the right port of the adder.
1155 Value rightOp = opInputPorts.back();
1156 rewriter.create<calyx::AssignOp>(
1157 forOp.getLoc(), rightOp,
1158 createConstant(forOp->getLoc(), rewriter, componentOp,
1159 regWidth.getIntOrFloatBitWidth(),
1160 forOp.getConstantStep().value().getSExtValue()));
1161 // Assign adder's output port to inductionReg.
1162 buildAssignmentsForRegisterWrite(rewriter, groupOp, componentOp,
1163 inductionReg, opOutputPort);
1164 // Set group as For Loop's "latch" group.
1165 getState<ComponentLoweringState>().setForLoopLatchGroup(forOpInterface,
1166 groupOp);
1167 getState<ComponentLoweringState>().registerEvaluatingGroup(opOutputPort,
1168 groupOp);
1169 return success();
1170 }
1171 if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp()))
1172 // Empty yield inside ifOp, essentially a no-op.
1173 return success();
1174 return yieldOp.getOperation()->emitError()
1175 << "Unsupported empty yieldOp outside ForOp or IfOp.";
1176 }
1177 // If yieldOp for a for loop is not empty, then we do not transform for loop.
1178 if (dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
1179 return yieldOp.getOperation()->emitError()
1180 << "Currently do not support non-empty yield operations inside for "
1181 "loops. Run --scf-for-to-while before running --scf-to-calyx.";
1182 }
1183
1184 if (auto whileOp = dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
1185 ScfWhileOp whileOpInterface(whileOp);
1186
1187 auto assignGroup =
1188 getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
1189 rewriter, whileOpInterface,
1190 getState<ComponentLoweringState>().getComponentOp(),
1191 getState<ComponentLoweringState>().getUniqueName(whileOp) +
1192 "_latch",
1193 yieldOp->getOpOperands());
1194 getState<ComponentLoweringState>().setWhileLoopLatchGroup(whileOpInterface,
1195 assignGroup);
1196 return success();
1197 }
1198
1199 if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
1200 auto resultRegs = getState<ComponentLoweringState>().getResultRegs(ifOp);
1201
1202 if (yieldOp->getParentRegion() == &ifOp.getThenRegion()) {
1203 auto thenGroup = getState<ComponentLoweringState>().getThenGroup(ifOp);
1204 for (auto op : enumerate(yieldOp.getOperands())) {
1205 auto resultReg =
1206 getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
1207 buildAssignmentsForRegisterWrite(
1208 rewriter, thenGroup,
1209 getState<ComponentLoweringState>().getComponentOp(), resultReg,
1210 op.value());
1211 getState<ComponentLoweringState>().registerEvaluatingGroup(
1212 ifOp.getResult(op.index()), thenGroup);
1213 }
1214 }
1215
1216 if (!ifOp.getElseRegion().empty() &&
1217 (yieldOp->getParentRegion() == &ifOp.getElseRegion())) {
1218 auto elseGroup = getState<ComponentLoweringState>().getElseGroup(ifOp);
1219 for (auto op : enumerate(yieldOp.getOperands())) {
1220 auto resultReg =
1221 getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
1222 buildAssignmentsForRegisterWrite(
1223 rewriter, elseGroup,
1224 getState<ComponentLoweringState>().getComponentOp(), resultReg,
1225 op.value());
1226 getState<ComponentLoweringState>().registerEvaluatingGroup(
1227 ifOp.getResult(op.index()), elseGroup);
1228 }
1229 }
1230 }
1231 return success();
1232}
1233
1234LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1235 BranchOpInterface brOp) const {
1236 /// Branch argument passing group creation
1237 /// Branch operands are passed through registers. In BuildBasicBlockRegs we
1238 /// created registers for all branch arguments of each block. We now
1239 /// create groups for assigning values to these registers.
1240 Block *srcBlock = brOp->getBlock();
1241 for (auto succBlock : enumerate(brOp->getSuccessors())) {
1242 auto succOperands = brOp.getSuccessorOperands(succBlock.index());
1243 if (succOperands.empty())
1244 continue;
1245 // Create operand passing group
1246 std::string groupName = loweringState().blockName(srcBlock) + "_to_" +
1247 loweringState().blockName(succBlock.value());
1248 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter, getComponent(),
1249 brOp.getLoc(), groupName);
1250 // Fetch block argument registers associated with the basic block
1251 auto dstBlockArgRegs =
1252 getState<ComponentLoweringState>().getBlockArgRegs(succBlock.value());
1253 // Create register assignment for each block argument
1254 for (auto arg : enumerate(succOperands.getForwardedOperands())) {
1255 auto reg = dstBlockArgRegs[arg.index()];
1257 rewriter, groupOp,
1258 getState<ComponentLoweringState>().getComponentOp(), reg,
1259 arg.value());
1260 }
1261 /// Register the group as a block argument group, to be executed
1262 /// when entering the successor block from this block (srcBlock).
1263 getState<ComponentLoweringState>().addBlockArgGroup(
1264 srcBlock, succBlock.value(), groupOp);
1265 }
1266 return success();
1267}
1268
1269/// For each return statement, we create a new group for assigning to the
1270/// previously created return value registers.
1271LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1272 ReturnOp retOp) const {
1273 if (retOp.getNumOperands() == 0)
1274 return success();
1275
1276 std::string groupName =
1277 getState<ComponentLoweringState>().getUniqueName("ret_assign");
1278 auto groupOp = calyx::createGroup<calyx::GroupOp>(rewriter, getComponent(),
1279 retOp.getLoc(), groupName);
1280 for (auto op : enumerate(retOp.getOperands())) {
1281 auto reg = getState<ComponentLoweringState>().getReturnReg(op.index());
1283 rewriter, groupOp, getState<ComponentLoweringState>().getComponentOp(),
1284 reg, op.value());
1285 }
1286 /// Schedule group for execution for when executing the return op block.
1287 getState<ComponentLoweringState>().addBlockScheduleable(retOp->getBlock(),
1288 groupOp);
1289 return success();
1290}
1291
1292LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1293 arith::ConstantOp constOp) const {
1294 if (isa<IntegerType>(constOp.getType())) {
1295 /// Move constant operations to the compOp body as hw::ConstantOp's.
1296 APInt value;
1297 calyx::matchConstantOp(constOp, value);
1298 auto hwConstOp =
1299 rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp, value);
1300 hwConstOp->moveAfter(getComponent().getBodyBlock(),
1301 getComponent().getBodyBlock()->begin());
1302 } else {
1303 std::string name = getState<ComponentLoweringState>().getUniqueName("cst");
1304 auto floatAttr = cast<FloatAttr>(constOp.getValueAttr());
1305 auto intType =
1306 rewriter.getIntegerType(floatAttr.getType().getIntOrFloatBitWidth());
1307 auto calyxConstOp = rewriter.create<calyx::ConstantOp>(
1308 constOp.getLoc(), name, floatAttr, intType);
1309 calyxConstOp->moveAfter(getComponent().getBodyBlock(),
1310 getComponent().getBodyBlock()->begin());
1311 rewriter.replaceAllUsesWith(constOp, calyxConstOp.getOut());
1312 }
1313
1314 return success();
1315}
1316
1317LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1318 AddIOp op) const {
1319 return buildLibraryOp<calyx::CombGroupOp, calyx::AddLibOp>(rewriter, op);
1320}
1321LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1322 SubIOp op) const {
1323 return buildLibraryOp<calyx::CombGroupOp, calyx::SubLibOp>(rewriter, op);
1324}
1325LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1326 ShRUIOp op) const {
1327 return buildLibraryOp<calyx::CombGroupOp, calyx::RshLibOp>(rewriter, op);
1328}
1329LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1330 ShRSIOp op) const {
1331 return buildLibraryOp<calyx::CombGroupOp, calyx::SrshLibOp>(rewriter, op);
1332}
1333LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1334 ShLIOp op) const {
1335 return buildLibraryOp<calyx::CombGroupOp, calyx::LshLibOp>(rewriter, op);
1336}
1337LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1338 AndIOp op) const {
1339 return buildLibraryOp<calyx::CombGroupOp, calyx::AndLibOp>(rewriter, op);
1340}
1341LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1342 OrIOp op) const {
1343 return buildLibraryOp<calyx::CombGroupOp, calyx::OrLibOp>(rewriter, op);
1344}
1345LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1346 XOrIOp op) const {
1347 return buildLibraryOp<calyx::CombGroupOp, calyx::XorLibOp>(rewriter, op);
1348}
1349LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1350 SelectOp op) const {
1351 return buildLibraryOp<calyx::CombGroupOp, calyx::MuxLibOp>(rewriter, op);
1352}
1353
1354LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1355 CmpIOp op) const {
1356 switch (op.getPredicate()) {
1357 case CmpIPredicate::eq:
1358 return buildLibraryOp<calyx::CombGroupOp, calyx::EqLibOp>(rewriter, op);
1359 case CmpIPredicate::ne:
1360 return buildLibraryOp<calyx::CombGroupOp, calyx::NeqLibOp>(rewriter, op);
1361 case CmpIPredicate::uge:
1362 return buildLibraryOp<calyx::CombGroupOp, calyx::GeLibOp>(rewriter, op);
1363 case CmpIPredicate::ult:
1364 return buildLibraryOp<calyx::CombGroupOp, calyx::LtLibOp>(rewriter, op);
1365 case CmpIPredicate::ugt:
1366 return buildLibraryOp<calyx::CombGroupOp, calyx::GtLibOp>(rewriter, op);
1367 case CmpIPredicate::ule:
1368 return buildLibraryOp<calyx::CombGroupOp, calyx::LeLibOp>(rewriter, op);
1369 case CmpIPredicate::sge:
1370 return buildLibraryOp<calyx::CombGroupOp, calyx::SgeLibOp>(rewriter, op);
1371 case CmpIPredicate::slt:
1372 return buildLibraryOp<calyx::CombGroupOp, calyx::SltLibOp>(rewriter, op);
1373 case CmpIPredicate::sgt:
1374 return buildLibraryOp<calyx::CombGroupOp, calyx::SgtLibOp>(rewriter, op);
1375 case CmpIPredicate::sle:
1376 return buildLibraryOp<calyx::CombGroupOp, calyx::SleLibOp>(rewriter, op);
1377 }
1378 llvm_unreachable("unsupported comparison predicate");
1379}
1380LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1381 TruncIOp op) const {
1382 return buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
1383 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1384}
1385LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1386 ExtUIOp op) const {
1387 return buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
1388 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1389}
1390
1391LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1392 ExtSIOp op) const {
1393 return buildLibraryOp<calyx::CombGroupOp, calyx::ExtSILibOp>(
1394 rewriter, op, {op.getOperand().getType()}, {op.getType()});
1395}
1396
1397LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1398 IndexCastOp op) const {
1399 Type sourceType = calyx::normalizeType(rewriter, op.getOperand().getType());
1400 Type targetType = calyx::normalizeType(rewriter, op.getResult().getType());
1401 unsigned targetBits = targetType.getIntOrFloatBitWidth();
1402 unsigned sourceBits = sourceType.getIntOrFloatBitWidth();
1403 LogicalResult res = success();
1404
1405 if (targetBits == sourceBits) {
1406 /// Drop the index cast and replace uses of the target value with the source
1407 /// value.
1408 op.getResult().replaceAllUsesWith(op.getOperand());
1409 } else {
1410 /// pad/slice the source operand.
1411 if (sourceBits > targetBits)
1412 res = buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
1413 rewriter, op, {sourceType}, {targetType});
1414 else
1415 res = buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
1416 rewriter, op, {sourceType}, {targetType});
1417 }
1418 rewriter.eraseOp(op);
1419 return res;
1420}
1421
1422LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1423 scf::WhileOp whileOp) const {
1424 // Only need to add the whileOp to the BlockSchedulables scheduler interface.
1425 // Everything else was handled in the `BuildWhileGroups` pattern.
1426 ScfWhileOp scfWhileOp(whileOp);
1427 getState<ComponentLoweringState>().addBlockScheduleable(
1428 whileOp.getOperation()->getBlock(), WhileScheduleable{scfWhileOp});
1429 return success();
1430}
1431
1432LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1433 scf::ForOp forOp) const {
1434 // Only need to add the forOp to the BlockSchedulables scheduler interface.
1435 // Everything else was handled in the `BuildForGroups` pattern.
1436 ScfForOp scfForOp(forOp);
1437 // If we cannot compute the trip count of the for loop, then we should
1438 // emit an error saying to use --scf-for-to-while
1439 std::optional<uint64_t> bound = scfForOp.getBound();
1440 if (!bound.has_value()) {
1441 return scfForOp.getOperation()->emitError()
1442 << "Loop bound not statically known. Should "
1443 "transform into while loop using `--scf-for-to-while` before "
1444 "running --lower-scf-to-calyx.";
1445 }
1446 getState<ComponentLoweringState>().addBlockScheduleable(
1447 forOp.getOperation()->getBlock(), ForScheduleable{
1448 scfForOp,
1449 bound.value(),
1450 });
1451 return success();
1452}
1453
1454LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1455 scf::IfOp ifOp) const {
1456 getState<ComponentLoweringState>().addBlockScheduleable(
1457 ifOp.getOperation()->getBlock(), IfScheduleable{ifOp});
1458 return success();
1459}
1460
1461LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1462 scf::ReduceOp reduceOp) const {
1463 // we don't handle reduce operation and simply return success for now since
1464 // BuildParGroups would have already emitted an error and exited early
1465 // if a reduce operation was encountered.
1466 return success();
1467}
1468
1469LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1470 scf::ParallelOp parOp) const {
1471 getState<ComponentLoweringState>().addBlockScheduleable(
1472 parOp.getOperation()->getBlock(), ParScheduleable{parOp});
1473 return success();
1474}
1475
1476LogicalResult
1477BuildOpGroups::buildOp(PatternRewriter &rewriter,
1478 scf::ExecuteRegionOp executeRegionOp) const {
1479 // Simply return success because the only remaining `scf.execute_region` op
1480 // are generated by the `BuildParGroups` pass - the rest of them are inlined
1481 // by the `InlineExecuteRegionOpPattern`.
1482 return success();
1483}
1484
1485LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1486 CallOp callOp) const {
1487 std::string instanceName = calyx::getInstanceName(callOp);
1488 calyx::InstanceOp instanceOp =
1489 getState<ComponentLoweringState>().getInstance(instanceName);
1490 SmallVector<Value, 4> outputPorts;
1491 auto portInfos = instanceOp.getReferencedComponent().getPortInfo();
1492 for (auto [idx, portInfo] : enumerate(portInfos)) {
1493 if (portInfo.direction == calyx::Direction::Output)
1494 outputPorts.push_back(instanceOp.getResult(idx));
1495 }
1496
1497 // Replacing a CallOp results in the out port of the instance.
1498 for (auto [idx, result] : llvm::enumerate(callOp.getResults()))
1499 rewriter.replaceAllUsesWith(result, outputPorts[idx]);
1500
1501 // CallScheduleanle requires an instance, while CallOp can be used to get the
1502 // input ports.
1503 getState<ComponentLoweringState>().addBlockScheduleable(
1504 callOp.getOperation()->getBlock(), CallScheduleable{instanceOp, callOp});
1505 return success();
1506}
1507
1508/// Inlines Calyx ExecuteRegionOp operations within their parent blocks.
1509/// An execution region op (ERO) is inlined by:
1510/// i : add a sink basic block for all yield operations inside the
1511/// ERO to jump to
1512/// ii : Rewrite scf.yield calls inside the ERO to branch to the sink block
1513/// iii: inline the ERO region
1514/// TODO(#1850) evaluate the usefulness of this lowering pattern.
1516 : public OpRewritePattern<scf::ExecuteRegionOp> {
1517 using OpRewritePattern::OpRewritePattern;
1518
1519 LogicalResult matchAndRewrite(scf::ExecuteRegionOp execOp,
1520 PatternRewriter &rewriter) const override {
1521 /// Determine type of "yield" operations inside the ERO.
1522 TypeRange yieldTypes = execOp.getResultTypes();
1523
1524 /// Create sink basic block and rewrite uses of yield results to sink block
1525 /// arguments.
1526 rewriter.setInsertionPointAfter(execOp);
1527 auto *sinkBlock = rewriter.splitBlock(
1528 execOp->getBlock(),
1529 execOp.getOperation()->getIterator()->getNextNode()->getIterator());
1530 sinkBlock->addArguments(
1531 yieldTypes,
1532 SmallVector<Location, 4>(yieldTypes.size(), rewriter.getUnknownLoc()));
1533 for (auto res : enumerate(execOp.getResults()))
1534 res.value().replaceAllUsesWith(sinkBlock->getArgument(res.index()));
1535
1536 /// Rewrite yield calls as branches.
1537 for (auto yieldOp :
1538 make_early_inc_range(execOp.getRegion().getOps<scf::YieldOp>())) {
1539 rewriter.setInsertionPointAfter(yieldOp);
1540 rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, sinkBlock,
1541 yieldOp.getOperands());
1542 }
1543
1544 /// Inline the regionOp.
1545 auto *preBlock = execOp->getBlock();
1546 auto *execOpEntryBlock = &execOp.getRegion().front();
1547 auto *postBlock = execOp->getBlock()->splitBlock(execOp);
1548 rewriter.inlineRegionBefore(execOp.getRegion(), postBlock);
1549 rewriter.mergeBlocks(postBlock, preBlock);
1550 rewriter.eraseOp(execOp);
1551
1552 /// Finally, erase the unused entry block of the execOp region.
1553 rewriter.mergeBlocks(execOpEntryBlock, preBlock);
1554
1555 return success();
1556 }
1557};
1558
1559/// Creates a new Calyx component for each FuncOp in the program.
1561 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1562
1563 LogicalResult
1565 PatternRewriter &rewriter) const override {
1566 /// Maintain a mapping between funcOp input arguments and the port index
1567 /// which the argument will eventually map to.
1568 DenseMap<Value, unsigned> funcOpArgRewrites;
1569
1570 /// Maintain a mapping between funcOp output indexes and the component
1571 /// output port index which the return value will eventually map to.
1572 DenseMap<unsigned, unsigned> funcOpResultMapping;
1573
1574 /// Maintain a mapping between an external memory argument (identified by a
1575 /// memref) and eventual component input- and output port indices that will
1576 /// map to the memory ports. The pair denotes the start index of the memory
1577 /// ports in the in- and output ports of the component. Ports are expected
1578 /// to be ordered in the same manner as they are added by
1579 /// calyx::appendPortsForExternalMemref.
1580 DenseMap<Value, std::pair<unsigned, unsigned>> extMemoryCompPortIndices;
1581
1582 /// Create I/O ports. Maintain separate in/out port vectors to determine
1583 /// which port index each function argument will eventually map to.
1584 SmallVector<calyx::PortInfo> inPorts, outPorts;
1585 FunctionType funcType = funcOp.getFunctionType();
1586 for (auto arg : enumerate(funcOp.getArguments())) {
1587 if (!isa<MemRefType>(arg.value().getType())) {
1588 /// Single-port arguments
1589 std::string inName;
1590 if (auto portNameAttr = funcOp.getArgAttrOfType<StringAttr>(
1591 arg.index(), scfToCalyx::sPortNameAttr))
1592 inName = portNameAttr.str();
1593 else
1594 inName = "in" + std::to_string(arg.index());
1595 funcOpArgRewrites[arg.value()] = inPorts.size();
1596 inPorts.push_back(calyx::PortInfo{
1597 rewriter.getStringAttr(inName),
1598 calyx::normalizeType(rewriter, arg.value().getType()),
1600 DictionaryAttr::get(rewriter.getContext(), {})});
1601 }
1602 }
1603 for (auto res : enumerate(funcType.getResults())) {
1604 std::string resName;
1605 if (auto portNameAttr = funcOp.getResultAttrOfType<StringAttr>(
1606 res.index(), scfToCalyx::sPortNameAttr))
1607 resName = portNameAttr.str();
1608 else
1609 resName = "out" + std::to_string(res.index());
1610 funcOpResultMapping[res.index()] = outPorts.size();
1611
1612 outPorts.push_back(calyx::PortInfo{
1613 rewriter.getStringAttr(resName),
1614 calyx::normalizeType(rewriter, res.value()), calyx::Direction::Output,
1615 DictionaryAttr::get(rewriter.getContext(), {})});
1616 }
1617
1618 /// We've now recorded all necessary indices. Merge in- and output ports
1619 /// and add the required mandatory component ports.
1620 auto ports = inPorts;
1621 llvm::append_range(ports, outPorts);
1622 calyx::addMandatoryComponentPorts(rewriter, ports);
1623
1624 /// Create a calyx::ComponentOp corresponding to the to-be-lowered function.
1625 auto compOp = rewriter.create<calyx::ComponentOp>(
1626 funcOp.getLoc(), rewriter.getStringAttr(funcOp.getSymName()), ports);
1627
1628 std::string funcName = "func_" + funcOp.getSymName().str();
1629 rewriter.modifyOpInPlace(funcOp, [&]() { funcOp.setSymName(funcName); });
1630
1631 /// Mark this component as the toplevel if it's the top-level function of
1632 /// the module.
1633 if (compOp.getName() == loweringState().getTopLevelFunction())
1634 compOp->setAttr("toplevel", rewriter.getUnitAttr());
1635
1636 /// Store the function-to-component mapping.
1637 functionMapping[funcOp] = compOp;
1638 auto *compState = loweringState().getState<ComponentLoweringState>(compOp);
1639 compState->setFuncOpResultMapping(funcOpResultMapping);
1640
1641 unsigned extMemCounter = 0;
1642 for (auto arg : enumerate(funcOp.getArguments())) {
1643 if (isa<MemRefType>(arg.value().getType())) {
1644 std::string memName =
1645 llvm::join_items("_", "arg_mem", std::to_string(extMemCounter++));
1646
1647 rewriter.setInsertionPointToStart(compOp.getBodyBlock());
1648 MemRefType memtype = cast<MemRefType>(arg.value().getType());
1649 SmallVector<int64_t> addrSizes;
1650 SmallVector<int64_t> sizes;
1651 for (int64_t dim : memtype.getShape()) {
1652 sizes.push_back(dim);
1653 addrSizes.push_back(calyx::handleZeroWidth(dim));
1654 }
1655 if (sizes.empty() && addrSizes.empty()) {
1656 sizes.push_back(1);
1657 addrSizes.push_back(1);
1658 }
1659 auto memOp = rewriter.create<calyx::SeqMemoryOp>(
1660 funcOp.getLoc(), memName,
1661 memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
1662 // we don't set the memory to "external", which implies it's a reference
1663
1664 compState->registerMemoryInterface(arg.value(),
1665 calyx::MemoryInterface(memOp));
1666 }
1667 }
1668
1669 /// Rewrite funcOp SSA argument values to the CompOp arguments.
1670 for (auto &mapping : funcOpArgRewrites)
1671 mapping.getFirst().replaceAllUsesWith(
1672 compOp.getArgument(mapping.getSecond()));
1673
1674 return success();
1675 }
1676};
1677
1678 /// In BuildWhileGroups, a register is created for each iteration argument of
1679 /// the while op. These registers are written to by the while op's terminating
1680 /// yield operation, and also before executing the whileOp in the schedule, to
1681 /// set the initial values of the argument registers.
1683 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1684
1685 LogicalResult
1687 PatternRewriter &rewriter) const override {
1688 LogicalResult res = success();
1689 funcOp.walk([&](Operation *op) {
1690 // Only work on ops that support the ScfWhileOp.
1691 if (!isa<scf::WhileOp>(op))
1692 return WalkResult::advance();
1693
1694 auto scfWhileOp = cast<scf::WhileOp>(op);
1695 ScfWhileOp whileOp(scfWhileOp);
1696
1697 getState<ComponentLoweringState>().setUniqueName(whileOp.getOperation(),
1698 "while");
1699
1700 /// Check for do-while loops.
1701 /// TODO(mortbopet) can we support these? for now, do not support loops
1702 /// where iterargs are changed in the 'before' region. scf.WhileOp also
1703 /// has support for different types of iter_args and return args which we
1704 /// also do not support; iter_args and while return values are placed in
1705 /// the same registers.
1706 for (auto barg :
1707 enumerate(scfWhileOp.getBefore().front().getArguments())) {
1708 auto condOp = scfWhileOp.getConditionOp().getArgs()[barg.index()];
1709 if (barg.value() != condOp) {
1710 res = whileOp.getOperation()->emitError()
1711 << loweringState().irName(barg.value())
1712 << " != " << loweringState().irName(condOp)
1713 << "do-while loops not supported; expected iter-args to "
1714 "remain untransformed in the 'before' region of the "
1715 "scf.while op.";
1716 return WalkResult::interrupt();
1717 }
1718 }
1719
1720 /// Create iteration argument registers.
1721 /// The iteration argument registers will be referenced:
1722 /// - In the "before" part of the while loop, calculating the conditional,
1723 /// - In the "after" part of the while loop,
1724 /// - Outside the while loop, rewriting the while loop return values.
1725 for (auto arg : enumerate(whileOp.getBodyArgs())) {
1726 std::string name = getState<ComponentLoweringState>()
1727 .getUniqueName(whileOp.getOperation())
1728 .str() +
1729 "_arg" + std::to_string(arg.index());
1730 auto reg =
1731 createRegister(arg.value().getLoc(), rewriter, getComponent(),
1732 arg.value().getType().getIntOrFloatBitWidth(), name);
1733 getState<ComponentLoweringState>().addWhileLoopIterReg(whileOp, reg,
1734 arg.index());
1735 arg.value().replaceAllUsesWith(reg.getOut());
1736
1737 /// Also replace uses in the "before" region of the while loop
1738 whileOp.getConditionBlock()
1739 ->getArgument(arg.index())
1740 .replaceAllUsesWith(reg.getOut());
1741 }
1742
1743 /// Create iter args initial value assignment group(s), one per register.
1744 SmallVector<calyx::GroupOp> initGroups;
1745 auto numOperands = whileOp.getOperation()->getNumOperands();
1746 for (size_t i = 0; i < numOperands; ++i) {
1747 auto initGroupOp =
1748 getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
1749 rewriter, whileOp,
1750 getState<ComponentLoweringState>().getComponentOp(),
1751 getState<ComponentLoweringState>().getUniqueName(
1752 whileOp.getOperation()) +
1753 "_init_" + std::to_string(i),
1754 whileOp.getOperation()->getOpOperand(i));
1755 initGroups.push_back(initGroupOp);
1756 }
1757
1758 getState<ComponentLoweringState>().setWhileLoopInitGroups(whileOp,
1759 initGroups);
1760
1761 return WalkResult::advance();
1762 });
1763 return res;
1764 }
1765};
1766
1767/// In BuildForGroups, a register is created for the iteration argument of
1768/// the for op. This register is then initialized to the lowerBound of the for
1769/// loop in a group that executes the for loop.
1771 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1772
1773 LogicalResult
1775 PatternRewriter &rewriter) const override {
1776 LogicalResult res = success();
1777 funcOp.walk([&](Operation *op) {
1778 // Only work on ops that support the ScfForOp.
1779 if (!isa<scf::ForOp>(op))
1780 return WalkResult::advance();
1781
1782 auto scfForOp = cast<scf::ForOp>(op);
1783 ScfForOp forOp(scfForOp);
1784
1785 getState<ComponentLoweringState>().setUniqueName(forOp.getOperation(),
1786 "for");
1787
1788 // Create a register for the InductionVar, and set that Register as the
1789 // only IterReg for the For Loop
1790 auto inductionVar = forOp.getOperation().getInductionVar();
1791 SmallVector<std::string, 3> inductionVarIdentifiers = {
1792 getState<ComponentLoweringState>()
1793 .getUniqueName(forOp.getOperation())
1794 .str(),
1795 "induction", "var"};
1796 std::string name = llvm::join(inductionVarIdentifiers, "_");
1797 auto reg =
1798 createRegister(inductionVar.getLoc(), rewriter, getComponent(),
1799 inductionVar.getType().getIntOrFloatBitWidth(), name);
1800 getState<ComponentLoweringState>().addForLoopIterReg(forOp, reg, 0);
1801 inductionVar.replaceAllUsesWith(reg.getOut());
1802
1803 // Create InitGroup that sets the InductionVar to LowerBound
1804 calyx::ComponentOp componentOp =
1805 getState<ComponentLoweringState>().getComponentOp();
1806 SmallVector<calyx::GroupOp> initGroups;
1807 SmallVector<std::string, 4> groupIdentifiers = {
1808 "init",
1809 getState<ComponentLoweringState>()
1810 .getUniqueName(forOp.getOperation())
1811 .str(),
1812 "induction", "var"};
1813 std::string groupName = llvm::join(groupIdentifiers, "_");
1814 auto groupOp = calyx::createGroup<calyx::GroupOp>(
1815 rewriter, componentOp, forOp.getLoc(), groupName);
1816 buildAssignmentsForRegisterWrite(rewriter, groupOp, componentOp, reg,
1817 forOp.getOperation().getLowerBound());
1818 initGroups.push_back(groupOp);
1819 getState<ComponentLoweringState>().setForLoopInitGroups(forOp,
1820 initGroups);
1821
1822 return WalkResult::advance();
1823 });
1824 return res;
1825 }
1826};
1827
1829 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1830
1831 LogicalResult
1833 PatternRewriter &rewriter) const override {
1834 LogicalResult res = success();
1835 funcOp.walk([&](Operation *op) {
1836 if (!isa<scf::IfOp>(op))
1837 return WalkResult::advance();
1838
1839 auto scfIfOp = cast<scf::IfOp>(op);
1840
1841 // There is no need to build `thenGroup` and `elseGroup` if `scfIfOp`
1842 // doesn't yield any result since these groups are created for managing
1843 // the result values.
1844 if (scfIfOp.getResults().empty())
1845 return WalkResult::advance();
1846
1847 calyx::ComponentOp componentOp =
1848 getState<ComponentLoweringState>().getComponentOp();
1849
1850 std::string thenGroupName =
1851 getState<ComponentLoweringState>().getUniqueName("then_br");
1852 auto thenGroupOp = calyx::createGroup<calyx::GroupOp>(
1853 rewriter, componentOp, scfIfOp.getLoc(), thenGroupName);
1854 getState<ComponentLoweringState>().setThenGroup(scfIfOp, thenGroupOp);
1855
1856 if (!scfIfOp.getElseRegion().empty()) {
1857 std::string elseGroupName =
1858 getState<ComponentLoweringState>().getUniqueName("else_br");
1859 auto elseGroupOp = calyx::createGroup<calyx::GroupOp>(
1860 rewriter, componentOp, scfIfOp.getLoc(), elseGroupName);
1861 getState<ComponentLoweringState>().setElseGroup(scfIfOp, elseGroupOp);
1862 }
1863
1864 for (auto ifOpRes : scfIfOp.getResults()) {
1865 auto reg = createRegister(
1866 scfIfOp.getLoc(), rewriter, getComponent(),
1867 ifOpRes.getType().getIntOrFloatBitWidth(),
1868 getState<ComponentLoweringState>().getUniqueName("if_res"));
1869 getState<ComponentLoweringState>().setResultRegs(
1870 scfIfOp, reg, ifOpRes.getResultNumber());
1871 }
1872
1873 return WalkResult::advance();
1874 });
1875 return res;
1876 }
1877};
1878
1880 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
1881
1882 LogicalResult
1884 PatternRewriter &rewriter) const override {
1885 WalkResult walkResult = funcOp.walk([&](scf::ParallelOp scfParOp) {
1886 if (!scfParOp.getResults().empty()) {
1887 scfParOp.emitError(
1888 "Reduce operations in scf.parallel is not supported yet");
1889 return WalkResult::interrupt();
1890 }
1891
1892 if (failed(partialEval(rewriter, scfParOp)))
1893 return WalkResult::interrupt();
1894
1895 return WalkResult::advance();
1896 });
1897
1898 return walkResult.wasInterrupted() ? failure() : success();
1899 }
1900
1901private:
1902 // Partially evaluate/pre-compute all blocks being executed in parallel by
1903 // statically generate loop indices combinations
1904 LogicalResult partialEval(PatternRewriter &rewriter,
1905 scf::ParallelOp scfParOp) const {
1906 assert(scfParOp.getLoopSteps() && "Parallel loop must have steps");
1907 auto *body = scfParOp.getBody();
1908 auto parOpIVs = scfParOp.getInductionVars();
1909 auto steps = scfParOp.getStep();
1910 auto lowerBounds = scfParOp.getLowerBound();
1911 auto upperBounds = scfParOp.getUpperBound();
1912 rewriter.setInsertionPointAfter(scfParOp);
1913 scf::ParallelOp newParOp = scfParOp.cloneWithoutRegions();
1914 auto loc = newParOp.getLoc();
1915 rewriter.insert(newParOp);
1916 OpBuilder insideBuilder(newParOp);
1917 auto &region = newParOp.getRegion();
1918 auto *newParBodyBlock = &region.emplaceBlock();
1919
1920 // extract lower bounds, upper bounds, and steps as integer index values
1921 SmallVector<int64_t> lbVals, ubVals, stepVals;
1922 for (auto lb : lowerBounds) {
1923 auto lbOp = lb.getDefiningOp<arith::ConstantIndexOp>();
1924 assert(lbOp &&
1925 "Lower bound must be a statically computable constant index");
1926 lbVals.push_back(lbOp.value());
1927 }
1928 for (auto ub : upperBounds) {
1929 auto ubOp = ub.getDefiningOp<arith::ConstantIndexOp>();
1930 assert(ubOp &&
1931 "Upper bound must be a statically computable constant index");
1932 ubVals.push_back(ubOp.value());
1933 }
1934 for (auto step : steps) {
1935 auto stepOp = step.getDefiningOp<arith::ConstantIndexOp>();
1936 assert(stepOp && "Step must be a statically computable constant index");
1937 stepVals.push_back(stepOp.value());
1938 }
1939
1940 // Initialize indices with lower bounds
1941 SmallVector<int64_t> indices = lbVals;
1942
1943 while (true) {
1944 insideBuilder.setInsertionPointToEnd(newParBodyBlock);
1945 // Create an `scf.execute_region` to wrap each unrolled block since
1946 // `scf.parallel` requires only one block in the body region.
1947 auto execRegionOp =
1948 insideBuilder.create<scf::ExecuteRegionOp>(loc, TypeRange{});
1949 auto &execRegion = execRegionOp.getRegion();
1950 Block *execBlock = &execRegion.emplaceBlock();
1951 OpBuilder regionBuilder(execRegionOp);
1952 // Each iteration starts with a fresh mapping, so each new block’s
1953 // argument of a region-based operation (such as `scf.for`) get re-mapped
1954 // independently.
1955 IRMapping operandMap;
1956
1957 regionBuilder.setInsertionPointToEnd(execBlock);
1958 // Map induction variables to constant indices
1959 for (unsigned i = 0; i < indices.size(); ++i) {
1960 Value ivConstant =
1961 regionBuilder.create<arith::ConstantIndexOp>(loc, indices[i]);
1962 operandMap.map(parOpIVs[i], ivConstant);
1963 }
1964
1965 for (auto it = body->begin(); it != std::prev(body->end()); ++it)
1966 regionBuilder.clone(*it, operandMap);
1967
1968 // A terminator should always be inserted in `scf.execute_region`'s block.
1969 regionBuilder.create<scf::ReduceOp>(loc);
1970 // Increment indices using `step`
1971 bool done = false;
1972 for (int dim = indices.size() - 1; dim >= 0; --dim) {
1973 indices[dim] += stepVals[dim];
1974 if (indices[dim] < ubVals[dim])
1975 break;
1976 indices[dim] = lbVals[dim];
1977 if (dim == 0)
1978 // All combinations have been generated
1979 done = true;
1980 }
1981 if (done)
1982 break;
1983 }
1984
1985 rewriter.setInsertionPointToEnd(newParOp.getBody());
1986 rewriter.create<scf::ReduceOp>(newParOp.getLoc());
1987
1988 rewriter.replaceOp(scfParOp, newParOp);
1989
1990 auto containsIfOp = [](scf::ParallelOp parOp) -> bool {
1991 bool hasIfOp = false;
1992 parOp.walk([&](scf::IfOp ifOp) {
1993 hasIfOp = true;
1994 return WalkResult::interrupt();
1995 });
1996 return hasIfOp;
1997 };
1998 if (containsIfOp(newParOp)) {
1999 auto *context = newParOp.getContext();
2000 RewritePatternSet patterns(newParOp.getContext());
2001 scf::IfOp::getCanonicalizationPatterns(patterns, context);
2002 if (failed(
2003 applyPatternsGreedily(newParOp->getParentOfType<func::FuncOp>(),
2004 std::move(patterns)))) {
2005 return failure();
2006 }
2007 }
2008
2009 return success();
2010 }
2011};
2012
2013/// Builds a control schedule by traversing the CFG of the function and
2014/// associating this with the previously created groups.
2015/// For simplicity, the generated control flow is expanded for all possible
2016/// paths in the input DAG. This elaborated control flow is later reduced in
2017/// the runControlFlowSimplification passes.
2019 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2020
2021 LogicalResult
2023 PatternRewriter &rewriter) const override {
2024 auto *entryBlock = &funcOp.getBlocks().front();
2025 rewriter.setInsertionPointToStart(
2026 getComponent().getControlOp().getBodyBlock());
2027 auto topLevelSeqOp = rewriter.create<calyx::SeqOp>(funcOp.getLoc());
2028 DenseSet<Block *> path;
2029 return buildCFGControl(path, rewriter, topLevelSeqOp.getBodyBlock(),
2030 nullptr, entryBlock);
2031 }
2032
2033private:
2034 /// Sequentially schedules the groups that registered themselves with
2035 /// 'block'.
2036 LogicalResult scheduleBasicBlock(PatternRewriter &rewriter,
2037 const DenseSet<Block *> &path,
2038 mlir::Block *parentCtrlBlock,
2039 mlir::Block *block) const {
2040 auto compBlockScheduleables =
2041 getState<ComponentLoweringState>().getBlockScheduleables(block);
2042 auto loc = block->front().getLoc();
2043
2044 if (compBlockScheduleables.size() > 1 &&
2045 !isa<scf::ParallelOp>(block->getParentOp())) {
2046 auto seqOp = rewriter.create<calyx::SeqOp>(loc);
2047 parentCtrlBlock = seqOp.getBodyBlock();
2048 }
2049
2050 for (auto &group : compBlockScheduleables) {
2051 rewriter.setInsertionPointToEnd(parentCtrlBlock);
2052 if (auto groupPtr = std::get_if<calyx::GroupOp>(&group); groupPtr) {
2053 rewriter.create<calyx::EnableOp>(groupPtr->getLoc(),
2054 groupPtr->getSymName());
2055 } else if (auto whileSchedPtr = std::get_if<WhileScheduleable>(&group);
2056 whileSchedPtr) {
2057 auto &whileOp = whileSchedPtr->whileOp;
2058
2059 auto whileCtrlOp = buildWhileCtrlOp(
2060 whileOp,
2061 getState<ComponentLoweringState>().getWhileLoopInitGroups(whileOp),
2062 rewriter);
2063 rewriter.setInsertionPointToEnd(whileCtrlOp.getBodyBlock());
2064 auto whileBodyOp =
2065 rewriter.create<calyx::SeqOp>(whileOp.getOperation()->getLoc());
2066 auto *whileBodyOpBlock = whileBodyOp.getBodyBlock();
2067
2068 /// Only schedule the 'after' block. The 'before' block is
2069 /// implicitly scheduled when evaluating the while condition.
2070 if (LogicalResult result =
2071 buildCFGControl(path, rewriter, whileBodyOpBlock, block,
2072 whileOp.getBodyBlock());
2073 result.failed())
2074 return result;
2075
2076 // Insert loop-latch at the end of the while group
2077 rewriter.setInsertionPointToEnd(whileBodyOpBlock);
2078 calyx::GroupOp whileLatchGroup =
2079 getState<ComponentLoweringState>().getWhileLoopLatchGroup(whileOp);
2080 rewriter.create<calyx::EnableOp>(whileLatchGroup.getLoc(),
2081 whileLatchGroup.getName());
2082 } else if (auto *parSchedPtr = std::get_if<ParScheduleable>(&group)) {
2083 auto parOp = parSchedPtr->parOp;
2084 auto calyxParOp = rewriter.create<calyx::ParOp>(parOp.getLoc());
2085
2086 WalkResult walkResult =
2087 parOp.walk([&](scf::ExecuteRegionOp execRegion) {
2088 rewriter.setInsertionPointToEnd(calyxParOp.getBodyBlock());
2089 auto seqOp = rewriter.create<calyx::SeqOp>(execRegion.getLoc());
2090 rewriter.setInsertionPointToEnd(seqOp.getBodyBlock());
2091
2092 for (auto &execBlock : execRegion.getRegion().getBlocks()) {
2093 if (LogicalResult res = scheduleBasicBlock(
2094 rewriter, path, seqOp.getBodyBlock(), &execBlock);
2095 res.failed()) {
2096 return WalkResult::interrupt();
2097 }
2098 }
2099 return WalkResult::advance();
2100 });
2101
2102 if (walkResult.wasInterrupted())
2103 return failure();
2104 } else if (auto *forSchedPtr = std::get_if<ForScheduleable>(&group);
2105 forSchedPtr) {
2106 auto forOp = forSchedPtr->forOp;
2107
2108 auto forCtrlOp = buildForCtrlOp(
2109 forOp,
2110 getState<ComponentLoweringState>().getForLoopInitGroups(forOp),
2111 forSchedPtr->bound, rewriter);
2112 rewriter.setInsertionPointToEnd(forCtrlOp.getBodyBlock());
2113 auto forBodyOp =
2114 rewriter.create<calyx::SeqOp>(forOp.getOperation()->getLoc());
2115 auto *forBodyOpBlock = forBodyOp.getBodyBlock();
2116
2117 // Schedule the body of the for loop.
2118 if (LogicalResult res = buildCFGControl(path, rewriter, forBodyOpBlock,
2119 block, forOp.getBodyBlock());
2120 res.failed())
2121 return res;
2122
2123 // Insert loop-latch at the end of the while group.
2124 rewriter.setInsertionPointToEnd(forBodyOpBlock);
2125 calyx::GroupOp forLatchGroup =
2126 getState<ComponentLoweringState>().getForLoopLatchGroup(forOp);
2127 rewriter.create<calyx::EnableOp>(forLatchGroup.getLoc(),
2128 forLatchGroup.getName());
2129 } else if (auto *ifSchedPtr = std::get_if<IfScheduleable>(&group);
2130 ifSchedPtr) {
2131 auto ifOp = ifSchedPtr->ifOp;
2132
2133 Location loc = ifOp->getLoc();
2134
2135 auto cond = ifOp.getCondition();
2136 auto condGroup = getState<ComponentLoweringState>()
2137 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
2138
2139 auto symbolAttr = FlatSymbolRefAttr::get(
2140 StringAttr::get(getContext(), condGroup.getSymName()));
2141
2142 bool initElse = !ifOp.getElseRegion().empty();
2143 auto ifCtrlOp = rewriter.create<calyx::IfOp>(
2144 loc, cond, symbolAttr, /*initializeElseBody=*/initElse);
2145
2146 rewriter.setInsertionPointToEnd(ifCtrlOp.getBodyBlock());
2147
2148 auto thenSeqOp =
2149 rewriter.create<calyx::SeqOp>(ifOp.getThenRegion().getLoc());
2150 auto *thenSeqOpBlock = thenSeqOp.getBodyBlock();
2151
2152 auto *thenBlock = &ifOp.getThenRegion().front();
2153 LogicalResult res = buildCFGControl(path, rewriter, thenSeqOpBlock,
2154 /*preBlock=*/block, thenBlock);
2155 if (res.failed())
2156 return res;
2157
2158 // `thenGroup`s won't be created in the first place if there's no
2159 // yielded results for this `ifOp`.
2160 if (!ifOp.getResults().empty()) {
2161 rewriter.setInsertionPointToEnd(thenSeqOpBlock);
2162 calyx::GroupOp thenGroup =
2163 getState<ComponentLoweringState>().getThenGroup(ifOp);
2164 rewriter.create<calyx::EnableOp>(thenGroup.getLoc(),
2165 thenGroup.getName());
2166 }
2167
2168 if (!ifOp.getElseRegion().empty()) {
2169 rewriter.setInsertionPointToEnd(ifCtrlOp.getElseBody());
2170
2171 auto elseSeqOp =
2172 rewriter.create<calyx::SeqOp>(ifOp.getElseRegion().getLoc());
2173 auto *elseSeqOpBlock = elseSeqOp.getBodyBlock();
2174
2175 auto *elseBlock = &ifOp.getElseRegion().front();
2176 res = buildCFGControl(path, rewriter, elseSeqOpBlock,
2177 /*preBlock=*/block, elseBlock);
2178 if (res.failed())
2179 return res;
2180
2181 if (!ifOp.getResults().empty()) {
2182 rewriter.setInsertionPointToEnd(elseSeqOpBlock);
2183 calyx::GroupOp elseGroup =
2184 getState<ComponentLoweringState>().getElseGroup(ifOp);
2185 rewriter.create<calyx::EnableOp>(elseGroup.getLoc(),
2186 elseGroup.getName());
2187 }
2188 }
2189 } else if (auto *callSchedPtr = std::get_if<CallScheduleable>(&group)) {
2190 auto instanceOp = callSchedPtr->instanceOp;
2191 OpBuilder::InsertionGuard g(rewriter);
2192 auto callBody = rewriter.create<calyx::SeqOp>(instanceOp.getLoc());
2193 rewriter.setInsertionPointToStart(callBody.getBodyBlock());
2194
2195 auto callee = callSchedPtr->callOp.getCallee();
2196 auto *calleeOp = SymbolTable::lookupNearestSymbolFrom(
2197 callSchedPtr->callOp.getOperation()->getParentOp(),
2198 StringAttr::get(rewriter.getContext(), "func_" + callee.str()));
2199 FuncOp calleeFunc = dyn_cast_or_null<FuncOp>(calleeOp);
2200
2201 auto instanceOpComp =
2202 llvm::cast<calyx::ComponentOp>(instanceOp.getReferencedComponent());
2203 auto *instanceOpLoweringState =
2204 loweringState().getState(instanceOpComp);
2205
2206 SmallVector<Value, 4> instancePorts;
2207 SmallVector<Value, 4> inputPorts;
2208 SmallVector<Attribute, 4> refCells;
2209 for (auto operandEnum : enumerate(callSchedPtr->callOp.getOperands())) {
2210 auto operand = operandEnum.value();
2211 auto index = operandEnum.index();
2212 if (!isa<MemRefType>(operand.getType())) {
2213 inputPorts.push_back(operand);
2214 continue;
2215 }
2216
2217 auto memOpName = getState<ComponentLoweringState>()
2218 .getMemoryInterface(operand)
2219 .memName();
2220 auto memOpNameAttr =
2221 SymbolRefAttr::get(rewriter.getContext(), memOpName);
2222 Value argI = calleeFunc.getArgument(index);
2223 if (isa<MemRefType>(argI.getType())) {
2224 NamedAttrList namedAttrList;
2225 namedAttrList.append(
2226 rewriter.getStringAttr(
2227 instanceOpLoweringState->getMemoryInterface(argI)
2228 .memName()),
2229 memOpNameAttr);
2230 refCells.push_back(
2231 DictionaryAttr::get(rewriter.getContext(), namedAttrList));
2232 }
2233 }
2234 llvm::copy(instanceOp.getResults().take_front(inputPorts.size()),
2235 std::back_inserter(instancePorts));
2236
2237 ArrayAttr refCellsAttr =
2238 ArrayAttr::get(rewriter.getContext(), refCells);
2239
2240 rewriter.create<calyx::InvokeOp>(
2241 instanceOp.getLoc(), instanceOp.getSymName(), instancePorts,
2242 inputPorts, refCellsAttr, ArrayAttr::get(rewriter.getContext(), {}),
2243 ArrayAttr::get(rewriter.getContext(), {}));
2244 } else
2245 llvm_unreachable("Unknown scheduleable");
2246 }
2247 return success();
2248 }
2249
2250 /// Schedules a block by inserting a branch argument assignment block (if any)
2251 /// before recursing into the scheduling of the block innards.
2252 /// Blocks 'from' and 'to' refer to blocks in the source program.
2253 /// parentCtrlBlock refers to the control block wherein control operations are
2254 /// to be inserted.
2255 LogicalResult schedulePath(PatternRewriter &rewriter,
2256 const DenseSet<Block *> &path, Location loc,
2257 Block *from, Block *to,
2258 Block *parentCtrlBlock) const {
2259 /// Schedule any registered block arguments to be executed before the body
2260 /// of the branch.
2261 rewriter.setInsertionPointToEnd(parentCtrlBlock);
2262 auto preSeqOp = rewriter.create<calyx::SeqOp>(loc);
2263 rewriter.setInsertionPointToEnd(preSeqOp.getBodyBlock());
2264 for (auto barg :
2265 getState<ComponentLoweringState>().getBlockArgGroups(from, to))
2266 rewriter.create<calyx::EnableOp>(barg.getLoc(), barg.getSymName());
2267
2268 return buildCFGControl(path, rewriter, parentCtrlBlock, from, to);
2269 }
2270
2271 LogicalResult buildCFGControl(DenseSet<Block *> path,
2272 PatternRewriter &rewriter,
2273 mlir::Block *parentCtrlBlock,
2274 mlir::Block *preBlock,
2275 mlir::Block *block) const {
2276 if (path.count(block) != 0)
2277 return preBlock->getTerminator()->emitError()
2278 << "CFG backedge detected. Loops must be raised to 'scf.while' or "
2279 "'scf.for' operations.";
2280
2281 rewriter.setInsertionPointToEnd(parentCtrlBlock);
2282 LogicalResult bbSchedResult =
2283 scheduleBasicBlock(rewriter, path, parentCtrlBlock, block);
2284 if (bbSchedResult.failed())
2285 return bbSchedResult;
2286
2287 path.insert(block);
2288 auto successors = block->getSuccessors();
2289 auto nSuccessors = successors.size();
2290 if (nSuccessors > 0) {
2291 auto brOp = dyn_cast<BranchOpInterface>(block->getTerminator());
2292 assert(brOp);
2293 if (nSuccessors > 1) {
2294 /// TODO(mortbopet): we could choose to support ie. std.switch, but it
2295 /// would probably be easier to just require it to be lowered
2296 /// beforehand.
2297 assert(nSuccessors == 2 &&
2298 "only conditional branches supported for now...");
2299 /// Wrap each branch inside an if/else.
2300 auto cond = brOp->getOperand(0);
2301 auto condGroup = getState<ComponentLoweringState>()
2302 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
2303 auto symbolAttr = FlatSymbolRefAttr::get(
2304 StringAttr::get(getContext(), condGroup.getSymName()));
2305
2306 auto ifOp = rewriter.create<calyx::IfOp>(
2307 brOp->getLoc(), cond, symbolAttr, /*initializeElseBody=*/true);
2308 rewriter.setInsertionPointToStart(ifOp.getThenBody());
2309 auto thenSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
2310 rewriter.setInsertionPointToStart(ifOp.getElseBody());
2311 auto elseSeqOp = rewriter.create<calyx::SeqOp>(brOp.getLoc());
2312
2313 bool trueBrSchedSuccess =
2314 schedulePath(rewriter, path, brOp.getLoc(), block, successors[0],
2315 thenSeqOp.getBodyBlock())
2316 .succeeded();
2317 bool falseBrSchedSuccess = true;
2318 if (trueBrSchedSuccess) {
2319 falseBrSchedSuccess =
2320 schedulePath(rewriter, path, brOp.getLoc(), block, successors[1],
2321 elseSeqOp.getBodyBlock())
2322 .succeeded();
2323 }
2324
2325 return success(trueBrSchedSuccess && falseBrSchedSuccess);
2326 } else {
2327 /// Schedule sequentially within the current parent control block.
2328 return schedulePath(rewriter, path, brOp.getLoc(), block,
2329 successors.front(), parentCtrlBlock);
2330 }
2331 }
2332 return success();
2333 }
2334
2335 // Insert a Par of initGroups at Location loc. Used as helper for
2336 // `buildWhileCtrlOp` and `buildForCtrlOp`.
2337 void
2338 insertParInitGroups(PatternRewriter &rewriter, Location loc,
2339 const SmallVector<calyx::GroupOp> &initGroups) const {
2340 PatternRewriter::InsertionGuard g(rewriter);
2341 auto parOp = rewriter.create<calyx::ParOp>(loc);
2342 rewriter.setInsertionPointToStart(parOp.getBodyBlock());
2343 for (calyx::GroupOp group : initGroups)
2344 rewriter.create<calyx::EnableOp>(group.getLoc(), group.getName());
2345 }
2346
2347 calyx::WhileOp buildWhileCtrlOp(ScfWhileOp whileOp,
2348 SmallVector<calyx::GroupOp> initGroups,
2349 PatternRewriter &rewriter) const {
2350 Location loc = whileOp.getLoc();
2351 /// Insert while iter arg initialization group(s). Emit a
2352 /// parallel group to assign one or more registers all at once.
2353 insertParInitGroups(rewriter, loc, initGroups);
2354
2355 /// Insert the while op itself.
2356 auto cond = whileOp.getConditionValue();
2357 auto condGroup = getState<ComponentLoweringState>()
2358 .getEvaluatingGroup<calyx::CombGroupOp>(cond);
2359 auto symbolAttr = FlatSymbolRefAttr::get(
2360 StringAttr::get(getContext(), condGroup.getSymName()));
2361 return rewriter.create<calyx::WhileOp>(loc, cond, symbolAttr);
2362 }
2363
2364 calyx::RepeatOp buildForCtrlOp(ScfForOp forOp,
2365 SmallVector<calyx::GroupOp> const &initGroups,
2366 uint64_t bound,
2367 PatternRewriter &rewriter) const {
2368 Location loc = forOp.getLoc();
2369 // Insert for iter arg initialization group(s). Emit a
2370 // parallel group to assign one or more registers all at once.
2371 insertParInitGroups(rewriter, loc, initGroups);
2372
2373 // Insert the repeatOp that corresponds to the For loop.
2374 return rewriter.create<calyx::RepeatOp>(loc, bound);
2375 }
2376};
2377
2378/// LateSSAReplacement contains various functions for replacing SSA values that
2379/// were not replaced during op construction.
2381 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2382
2383 LogicalResult partiallyLowerFuncToComp(FuncOp funcOp,
2384 PatternRewriter &) const override {
2385 funcOp.walk([&](scf::IfOp op) {
2386 for (auto res : getState<ComponentLoweringState>().getResultRegs(op))
2387 op.getOperation()->getResults()[res.first].replaceAllUsesWith(
2388 res.second.getOut());
2389 });
2390
2391 funcOp.walk([&](scf::WhileOp op) {
2392 /// The yielded values returned from the while op will be present in the
2393 /// iterargs registers post execution of the loop.
2394 /// This is done now, as opposed to during BuildWhileGroups since if the
2395 /// results of the whileOp were replaced before
2396 /// BuildOpGroups/BuildControl, the whileOp would get dead-code
2397 /// eliminated.
2398 ScfWhileOp whileOp(op);
2399 for (auto res :
2400 getState<ComponentLoweringState>().getWhileLoopIterRegs(whileOp))
2401 whileOp.getOperation()->getResults()[res.first].replaceAllUsesWith(
2402 res.second.getOut());
2403 });
2404
2405 funcOp.walk([&](memref::LoadOp loadOp) {
2406 if (calyx::singleLoadFromMemory(loadOp)) {
2407 /// In buildOpGroups we did not replace loadOp's results, to ensure a
2408 /// link between evaluating groups (which fix the input addresses of a
2409 /// memory op) and a readData result. Now, we may replace these SSA
2410 /// values with their memoryOp readData output.
2411 loadOp.getResult().replaceAllUsesWith(
2412 getState<ComponentLoweringState>()
2413 .getMemoryInterface(loadOp.getMemref())
2414 .readData());
2415 }
2416 });
2417
2418 return success();
2419 }
2420};
2421
2422/// Erases FuncOp operations.
2424 using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
2425
/// Erases `funcOp` through the rewriter and unconditionally reports success.
2426 LogicalResult matchAndRewrite(FuncOp funcOp,
2427 PatternRewriter &rewriter) const override {
2428 rewriter.eraseOp(funcOp);
2429 return success();
2430 }
2431
/// No-op override that always succeeds; the actual work of this pattern is
/// done in `matchAndRewrite`, which erases the function.
/// NOTE(review): the line carrying this method's name and first parameter is
/// missing from this listing; presumably this is the
/// `partiallyLowerFuncToComp(FuncOp funcOp, ...)` override — confirm against
/// the original file.
2432 LogicalResult
2434 PatternRewriter &rewriter) const override {
2435 return success();
2436 }
2437};
2438
2439} // namespace scftocalyx
2440
2441namespace {
2442
2443using namespace circt::scftocalyx;
2444
2445//===----------------------------------------------------------------------===//
2446// Pass driver
2447//===----------------------------------------------------------------------===//
2448class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
2449public:
2450 SCFToCalyxPass()
2451 : SCFToCalyxBase<SCFToCalyxPass>(), partialPatternRes(success()) {}
2452 void runOnOperation() override;
2453
2454 LogicalResult setTopLevelFunction(mlir::ModuleOp moduleOp,
2455 std::string &topLevelFunction) {
2456 if (!topLevelFunctionOpt.empty()) {
2457 if (SymbolTable::lookupSymbolIn(moduleOp, topLevelFunctionOpt) ==
2458 nullptr) {
2459 moduleOp.emitError() << "Top level function '" << topLevelFunctionOpt
2460 << "' not found in module.";
2461 return failure();
2462 }
2463 topLevelFunction = topLevelFunctionOpt;
2464 } else {
2465 /// No top level function set; infer top level if the module only contains
2466 /// a single function, else, throw error.
2467 auto funcOps = moduleOp.getOps<FuncOp>();
2468 if (std::distance(funcOps.begin(), funcOps.end()) == 1)
2469 topLevelFunction = (*funcOps.begin()).getSymName().str();
2470 else {
2471 moduleOp.emitError()
2472 << "Module contains multiple functions, but no top level "
2473 "function was set. Please see --top-level-function";
2474 return failure();
2475 }
2476 }
2477
2478 return createOptNewTopLevelFn(moduleOp, topLevelFunction);
2479 }
2480
2481 struct LoweringPattern {
2482 enum class Strategy { Once, Greedy };
2483 RewritePatternSet pattern;
2484 Strategy strategy;
2485 };
2486
2487 //// Labels the entry point of a Calyx program.
2488 /// Furthermore, this function performs validation on the input function,
2489 /// to ensure that we've implemented the capabilities necessary to convert
2490 /// it.
2491 LogicalResult labelEntryPoint(StringRef topLevelFunction) {
2492 // Program legalization - the partial conversion driver will not run
2493 // unless some pattern is provided - provide a dummy pattern.
2494 struct DummyPattern : public OpRewritePattern<mlir::ModuleOp> {
2495 using OpRewritePattern::OpRewritePattern;
2496 LogicalResult matchAndRewrite(mlir::ModuleOp,
2497 PatternRewriter &) const override {
2498 return failure();
2499 }
2500 };
2501
2502 ConversionTarget target(getContext());
2503 target.addLegalDialect<calyx::CalyxDialect>();
2504 target.addLegalDialect<scf::SCFDialect>();
2505 target.addIllegalDialect<hw::HWDialect>();
2506 target.addIllegalDialect<comb::CombDialect>();
2507
2508 // Only accept std operations which we've added lowerings for
2509 target.addIllegalDialect<FuncDialect>();
2510 target.addIllegalDialect<ArithDialect>();
2511 target.addLegalOp<AddIOp, SelectOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp,
2512 ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
2513 CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
2514 RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
2515 ExtSIOp, CallOp, AddFOp, MulFOp, CmpFOp>();
2516
2517 RewritePatternSet legalizePatterns(&getContext());
2518 legalizePatterns.add<DummyPattern>(&getContext());
2519 DenseSet<Operation *> legalizedOps;
2520 if (applyPartialConversion(getOperation(), target,
2521 std::move(legalizePatterns))
2522 .failed())
2523 return failure();
2524
2525 // Program conversion
2526 return calyx::applyModuleOpConversion(getOperation(), topLevelFunction);
2527 }
2528
2529 /// 'Once' patterns are expected to take an additional LogicalResult&
2530 /// argument, to forward their result state (greedyPatternRewriteDriver
2531 /// results are skipped for Once patterns).
2532 template <typename TPattern, typename... PatternArgs>
2533 void addOncePattern(SmallVectorImpl<LoweringPattern> &patterns,
2534 PatternArgs &&...args) {
2535 RewritePatternSet ps(&getContext());
2536 ps.add<TPattern>(&getContext(), partialPatternRes, args...);
2537 patterns.push_back(
2538 LoweringPattern{std::move(ps), LoweringPattern::Strategy::Once});
2539 }
2540
2541 template <typename TPattern, typename... PatternArgs>
2542 void addGreedyPattern(SmallVectorImpl<LoweringPattern> &patterns,
2543 PatternArgs &&...args) {
2544 RewritePatternSet ps(&getContext());
2545 ps.add<TPattern>(&getContext(), args...);
2546 patterns.push_back(
2547 LoweringPattern{std::move(ps), LoweringPattern::Strategy::Greedy});
2548 }
2549
2550 LogicalResult runPartialPattern(RewritePatternSet &pattern, bool runOnce) {
2551 assert(pattern.getNativePatterns().size() == 1 &&
2552 "Should only apply 1 partial lowering pattern at once");
2553
2554 // During component creation, the function body is inlined into the
2555 // component body for further processing. However, proper control flow
2556 // will only be established later in the conversion process, so ensure
2557 // that rewriter optimizations (especially DCE) are disabled.
2558 GreedyRewriteConfig config;
2559 config.enableRegionSimplification =
2560 mlir::GreedySimplifyRegionLevel::Disabled;
2561 if (runOnce)
2562 config.maxIterations = 1;
2563
2564 /// Can't return applyPatternsGreedily. Root isn't
2565 /// necessarily erased so it will always return failed(). Instead,
2566 /// forward the 'succeeded' value from PartialLoweringPatternBase.
2567 (void)applyPatternsGreedily(getOperation(), std::move(pattern), config);
2568 return partialPatternRes;
2569 }
2570
2571private:
2572 LogicalResult partialPatternRes;
2573 std::shared_ptr<calyx::CalyxLoweringState> loweringState = nullptr;
2574
2575 /// Creates a new new top-level function based on `baseName`.
2576 FuncOp createNewTopLevelFn(ModuleOp moduleOp, std::string &baseName) {
2577 std::string newName = "main";
2578
2579 if (auto *existingMainOp = SymbolTable::lookupSymbolIn(moduleOp, newName)) {
2580 auto existingMainFunc = dyn_cast<FuncOp>(existingMainOp);
2581 if (existingMainFunc == nullptr) {
2582 moduleOp.emitError() << "Symbol 'main' exists but is not a function";
2583 return nullptr;
2584 }
2585 unsigned counter = 0;
2586 std::string newOldName = baseName;
2587 while (SymbolTable::lookupSymbolIn(moduleOp, newOldName))
2588 newOldName = llvm::join_items("_", baseName, std::to_string(++counter));
2589 existingMainFunc.setName(newOldName);
2590 if (baseName == "main")
2591 baseName = newOldName;
2592 }
2593
2594 // Create the new "main" function
2595 OpBuilder builder(moduleOp.getContext());
2596 builder.setInsertionPointToStart(moduleOp.getBody());
2597
2598 FunctionType funcType = builder.getFunctionType({}, {});
2599
2600 if (auto newFunc =
2601 builder.create<FuncOp>(moduleOp.getLoc(), newName, funcType))
2602 return newFunc;
2603
2604 return nullptr;
2605 }
2606
2607 /// Insert a call from the newly created top-level function/`caller` to the
2608 /// old top-level function/`callee`; and create `memref.alloc`s inside the new
2609 /// top-level function for arguments with `memref` types and for the
2610 /// `memref.alloc`s inside `callee`.
2611 void insertCallFromNewTopLevel(OpBuilder &builder, FuncOp caller,
2612 FuncOp callee) {
2613 if (caller.getBody().empty()) {
2614 caller.addEntryBlock();
2615 }
2616
2617 Block *callerEntryBlock = &caller.getBody().front();
2618 builder.setInsertionPointToStart(callerEntryBlock);
2619
2620 // For those non-memref arguments passing to the original top-level
2621 // function, we need to copy them to the new top-level function.
2622 SmallVector<Type, 4> nonMemRefCalleeArgTypes;
2623 for (auto arg : callee.getArguments()) {
2624 if (!isa<MemRefType>(arg.getType())) {
2625 nonMemRefCalleeArgTypes.push_back(arg.getType());
2626 }
2627 }
2628
2629 for (Type type : nonMemRefCalleeArgTypes) {
2630 callerEntryBlock->addArgument(type, caller.getLoc());
2631 }
2632
2633 FunctionType callerFnType = caller.getFunctionType();
2634 SmallVector<Type, 4> updatedCallerArgTypes(
2635 caller.getFunctionType().getInputs());
2636 updatedCallerArgTypes.append(nonMemRefCalleeArgTypes.begin(),
2637 nonMemRefCalleeArgTypes.end());
2638 caller.setType(FunctionType::get(caller.getContext(), updatedCallerArgTypes,
2639 callerFnType.getResults()));
2640
2641 Block *calleeFnBody = &callee.getBody().front();
2642 unsigned originalCalleeArgNum = callee.getArguments().size();
2643
2644 SmallVector<Value, 4> extraMemRefArgs;
2645 SmallVector<Type, 4> extraMemRefArgTypes;
2646 SmallVector<Value, 4> extraMemRefOperands;
2647 SmallVector<Operation *, 4> opsToModify;
2648 for (auto &op : callee.getBody().getOps()) {
2649 if (isa<memref::AllocaOp, memref::AllocOp, memref::GetGlobalOp>(op))
2650 opsToModify.push_back(&op);
2651 }
2652
2653 // Replace `alloc`/`getGlobal` in the original top-level with new
2654 // corresponding operations in the new top-level.
2655 builder.setInsertionPointToEnd(callerEntryBlock);
2656 for (auto *op : opsToModify) {
2657 // TODO (https://github.com/llvm/circt/issues/7764)
2658 Value newOpRes;
2659 TypeSwitch<Operation *>(op)
2660 .Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
2661 newOpRes = builder.create<memref::AllocaOp>(callee.getLoc(),
2662 allocaOp.getType());
2663 })
2664 .Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
2665 newOpRes = builder.create<memref::AllocOp>(callee.getLoc(),
2666 allocOp.getType());
2667 })
2668 .Case<memref::GetGlobalOp>([&](memref::GetGlobalOp getGlobalOp) {
2669 newOpRes = builder.create<memref::GetGlobalOp>(
2670 caller.getLoc(), getGlobalOp.getType(), getGlobalOp.getName());
2671 })
2672 .Default([&](Operation *defaultOp) {
2673 llvm::report_fatal_error("Unsupported operation in TypeSwitch");
2674 });
2675 extraMemRefOperands.push_back(newOpRes);
2676
2677 calleeFnBody->addArgument(newOpRes.getType(), callee.getLoc());
2678 BlockArgument newBodyArg = calleeFnBody->getArguments().back();
2679 op->getResult(0).replaceAllUsesWith(newBodyArg);
2680 op->erase();
2681 extraMemRefArgs.push_back(newBodyArg);
2682 extraMemRefArgTypes.push_back(newBodyArg.getType());
2683 }
2684
2685 SmallVector<Type, 4> updatedCalleeArgTypes(
2686 callee.getFunctionType().getInputs());
2687 updatedCalleeArgTypes.append(extraMemRefArgTypes.begin(),
2688 extraMemRefArgTypes.end());
2689 callee.setType(FunctionType::get(callee.getContext(), updatedCalleeArgTypes,
2690 callee.getFunctionType().getResults()));
2691
2692 unsigned otherArgsCount = 0;
2693 SmallVector<Value, 4> calleeArgFnOperands;
2694 builder.setInsertionPointToStart(callerEntryBlock);
2695 for (auto arg : callee.getArguments().take_front(originalCalleeArgNum)) {
2696 if (isa<MemRefType>(arg.getType())) {
2697 auto memrefType = cast<MemRefType>(arg.getType());
2698 auto allocOp =
2699 builder.create<memref::AllocOp>(callee.getLoc(), memrefType);
2700 calleeArgFnOperands.push_back(allocOp);
2701 } else {
2702 auto callerArg = callerEntryBlock->getArgument(otherArgsCount++);
2703 calleeArgFnOperands.push_back(callerArg);
2704 }
2705 }
2706
2707 SmallVector<Value, 4> fnOperands;
2708 fnOperands.append(calleeArgFnOperands.begin(), calleeArgFnOperands.end());
2709 fnOperands.append(extraMemRefOperands.begin(), extraMemRefOperands.end());
2710 auto calleeName =
2711 SymbolRefAttr::get(builder.getContext(), callee.getSymName());
2712 auto resultTypes = callee.getResultTypes();
2713
2714 builder.setInsertionPointToEnd(callerEntryBlock);
2715 builder.create<CallOp>(caller.getLoc(), calleeName, resultTypes,
2716 fnOperands);
2717 builder.create<ReturnOp>(caller.getLoc());
2718 }
2719
2720 /// Conditionally creates an optional new top-level function; and inserts a
2721 /// call from the new top-level function to the old top-level function if we
2722 /// did create one
2723 LogicalResult createOptNewTopLevelFn(ModuleOp moduleOp,
2724 std::string &topLevelFunction) {
2725 auto hasMemrefArguments = [](FuncOp func) {
2726 return std::any_of(
2727 func.getArguments().begin(), func.getArguments().end(),
2728 [](BlockArgument arg) { return isa<MemRefType>(arg.getType()); });
2729 };
2730
2731 /// We only create a new top-level function and call the original top-level
2732 /// function from the new one if the original top-level has `memref` in its
2733 /// argument
2734 auto funcOps = moduleOp.getOps<FuncOp>();
2735 bool hasMemrefArgsInTopLevel =
2736 std::any_of(funcOps.begin(), funcOps.end(), [&](auto funcOp) {
2737 return funcOp.getName() == topLevelFunction &&
2738 hasMemrefArguments(funcOp);
2739 });
2740
2741 if (hasMemrefArgsInTopLevel) {
2742 auto newTopLevelFunc = createNewTopLevelFn(moduleOp, topLevelFunction);
2743 if (!newTopLevelFunc)
2744 return failure();
2745
2746 OpBuilder builder(moduleOp.getContext());
2747 Operation *oldTopLevelFuncOp =
2748 SymbolTable::lookupSymbolIn(moduleOp, topLevelFunction);
2749 if (auto oldTopLevelFunc = dyn_cast<FuncOp>(oldTopLevelFuncOp))
2750 insertCallFromNewTopLevel(builder, newTopLevelFunc, oldTopLevelFunc);
2751 else {
2752 moduleOp.emitOpError("Original top-level function not found!");
2753 return failure();
2754 }
2755 topLevelFunction = "main";
2756 }
2757
2758 return success();
2759 }
2760};
2761
2762void SCFToCalyxPass::runOnOperation() {
2763 // Clear internal state. See https://github.com/llvm/circt/issues/3235
2764 loweringState.reset();
2765 partialPatternRes = LogicalResult::failure();
2766
2767 std::string topLevelFunction;
2768 if (failed(setTopLevelFunction(getOperation(), topLevelFunction))) {
2769 signalPassFailure();
2770 return;
2771 }
2772
2773 /// Start conversion
2774 if (failed(labelEntryPoint(topLevelFunction))) {
2775 signalPassFailure();
2776 return;
2777 }
2778 loweringState = std::make_shared<calyx::CalyxLoweringState>(getOperation(),
2779 topLevelFunction);
2780
2781 /// --------------------------------------------------------------------------
2782 /// If you are a developer, it may be helpful to add a
2783 /// 'getOperation()->dump()' call after the execution of each stage to
2784 /// view the transformations that's going on.
2785 /// --------------------------------------------------------------------------
2786
2787 /// A mapping is maintained between a function operation and its corresponding
2788 /// Calyx component.
2789 DenseMap<FuncOp, calyx::ComponentOp> funcMap;
2790 SmallVector<LoweringPattern, 8> loweringPatterns;
2791 calyx::PatternApplicationState patternState;
2792
2793 /// Creates a new Calyx component for each FuncOp in the inpurt module.
2794 addOncePattern<FuncOpConversion>(loweringPatterns, patternState, funcMap,
2795 *loweringState);
2796
2797 /// This pass inlines scf.ExecuteRegionOp's by adding control-flow.
2798 addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);
2799
2800 /// Partial evaluate the scf.ParallelOp and apply the scf.IfOp
2801 /// canonicalization optionally.
2802 addOncePattern<BuildParGroups>(loweringPatterns, patternState, funcMap,
2803 *loweringState);
2804
2805 /// This pattern converts all index typed values to an i32 integer.
2806 addOncePattern<calyx::ConvertIndexTypes>(loweringPatterns, patternState,
2807 funcMap, *loweringState);
2808
2809 /// This pattern creates registers for all basic-block arguments.
2810 addOncePattern<calyx::BuildBasicBlockRegs>(loweringPatterns, patternState,
2811 funcMap, *loweringState);
2812
2813 addOncePattern<calyx::BuildCallInstance>(loweringPatterns, patternState,
2814 funcMap, *loweringState);
2815
2816 /// This pattern creates registers for the function return values.
2817 addOncePattern<calyx::BuildReturnRegs>(loweringPatterns, patternState,
2818 funcMap, *loweringState);
2819
2820 /// This pattern creates registers for iteration arguments of scf.while
2821 /// operations. Additionally, creates a group for assigning the initial
2822 /// value of the iteration argument registers.
2823 addOncePattern<BuildWhileGroups>(loweringPatterns, patternState, funcMap,
2824 *loweringState);
2825
2826 /// This pattern creates registers for iteration arguments of scf.for
2827 /// operations. Additionally, creates a group for assigning the initial
2828 /// value of the iteration argument registers.
2829 addOncePattern<BuildForGroups>(loweringPatterns, patternState, funcMap,
2830 *loweringState);
2831
2832 addOncePattern<BuildIfGroups>(loweringPatterns, patternState, funcMap,
2833 *loweringState);
2834
2835 /// This pattern converts operations within basic blocks to Calyx library
2836 /// operators. Combinational operations are assigned inside a
2837 /// calyx::CombGroupOp, and sequential inside calyx::GroupOps.
2838 /// Sequential groups are registered with the Block* of which the operation
2839 /// originated from. This is used during control schedule generation. By
2840 /// having a distinct group for each operation, groups are analogous to SSA
2841 /// values in the source program.
2842 addOncePattern<BuildOpGroups>(loweringPatterns, patternState, funcMap,
2843 *loweringState, writeJsonOpt);
2844
2845 /// This pattern traverses the CFG of the program and generates a control
2846 /// schedule based on the calyx::GroupOp's which were registered for each
2847 /// basic block in the source function.
2848 addOncePattern<BuildControl>(loweringPatterns, patternState, funcMap,
2849 *loweringState);
2850
2851 /// This pass recursively inlines use-def chains of combinational logic (from
2852 /// non-stateful groups) into groups referenced in the control schedule.
2853 addOncePattern<calyx::InlineCombGroups>(loweringPatterns, patternState,
2854 *loweringState);
2855
2856 /// This pattern performs various SSA replacements that must be done
2857 /// after control generation.
2858 addOncePattern<LateSSAReplacement>(loweringPatterns, patternState, funcMap,
2859 *loweringState);
2860
2861 /// Eliminate any unused combinational groups. This is done before
2862 /// calyx::RewriteMemoryAccesses to avoid inferring slice components for
2863 /// groups that will be removed.
2864 addGreedyPattern<calyx::EliminateUnusedCombGroups>(loweringPatterns);
2865
2866 /// This pattern rewrites accesses to memories which are too wide due to
2867 /// index types being converted to a fixed-width integer type.
2868 addOncePattern<calyx::RewriteMemoryAccesses>(loweringPatterns, patternState,
2869 *loweringState);
2870
2871 /// This pattern removes the source FuncOp which has now been converted into
2872 /// a Calyx component.
2873 addOncePattern<CleanupFuncOps>(loweringPatterns, patternState, funcMap,
2874 *loweringState);
2875
2876 /// Sequentially apply each lowering pattern.
2877 for (auto &pat : loweringPatterns) {
2878 LogicalResult partialPatternRes = runPartialPattern(
2879 pat.pattern,
2880 /*runOnce=*/pat.strategy == LoweringPattern::Strategy::Once);
2881 if (succeeded(partialPatternRes))
2882 continue;
2883 signalPassFailure();
2884 return;
2885 }
2886
2887 //===--------------------------------------------------------------------===//
2888 // Cleanup patterns
2889 //===--------------------------------------------------------------------===//
2890 RewritePatternSet cleanupPatterns(&getContext());
2891 cleanupPatterns.add<calyx::MultipleGroupDonePattern,
2893 if (failed(
2894 applyPatternsGreedily(getOperation(), std::move(cleanupPatterns)))) {
2895 signalPassFailure();
2896 return;
2897 }
2898
2899 if (ciderSourceLocationMetadata) {
2900 // Debugging information for the Cider debugger.
2901 // Reference: https://docs.calyxir.org/debug/cider.html
2902 SmallVector<Attribute, 16> sourceLocations;
2903 getOperation()->walk([&](calyx::ComponentOp component) {
2904 return getCiderSourceLocationMetadata(component, sourceLocations);
2905 });
2906
2907 MLIRContext *context = getOperation()->getContext();
2908 getOperation()->setAttr("calyx.metadata",
2909 ArrayAttr::get(context, sourceLocations));
2910 }
2911}
2912} // namespace
2913
2914//===----------------------------------------------------------------------===//
2915// Pass initialization
2916//===----------------------------------------------------------------------===//
2917
2918std::unique_ptr<OperationPass<ModuleOp>> createSCFToCalyxPass() {
2919 return std::make_unique<SCFToCalyxPass>();
2920}
2921
2922} // namespace circt
assert(baseType &&"element must be base type")
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
Strategy strategy
std::shared_ptr< calyx::CalyxLoweringState > loweringState
LogicalResult partialPatternRes
An interface for conversion passes that lower Calyx programs.
std::string irName(ValueOrBlock &v)
Returns a meaningful name for a value within the program scope.
std::string blockName(Block *b)
Returns a meaningful name for a block within the program scope (removes the ^ prefix from block names...
StringRef getTopLevelFunction() const
Returns the name of the top-level function in the source program.
T * getState(calyx::ComponentOp op)
Returns the component lowering state associated with op.
void setFuncOpResultMapping(const DenseMap< unsigned, unsigned > &mapping)
Assign a mapping between the source funcOp result indices and the corresponding output port indices o...
std::string getUniqueName(StringRef prefix)
Returns a unique name within compOp with the provided prefix.
calyx::ComponentOp component
The component which this lowering state is associated to.
void registerMemoryInterface(Value memref, const calyx::MemoryInterface &memoryInterface)
Registers a memory interface as being associated with a memory identified by 'memref'.
calyx::ComponentOp getComponentOp()
Returns the calyx::ComponentOp associated with this lowering state.
void setDataField(StringRef name, llvm::json::Array data)
ComponentLoweringStateInterface(calyx::ComponentOp component)
void setFormat(StringRef name, std::string numType, bool isSigned, unsigned width)
FuncOpPartialLoweringPatterns are patterns which intend to match on FuncOps and then perform their ow...
calyx::ComponentOp getComponent() const
Returns the component operation associated with the currently executing partial lowering.
DenseMap< mlir::func::FuncOp, calyx::ComponentOp > & functionMapping
CalyxLoweringState & loweringState() const
Return the calyx lowering state for this pattern.
FuncOpPartialLoweringPattern(MLIRContext *context, LogicalResult &resRef, PatternApplicationState &patternState, DenseMap< mlir::func::FuncOp, calyx::ComponentOp > &map, calyx::CalyxLoweringState &state)
calyx::GroupOp getLoopLatchGroup(ScfWhileOp op)
Retrieve the loop latch group registered for op.
void setLoopLatchGroup(ScfWhileOp op, calyx::GroupOp group)
Registers grp to be the loop latch group of op.
calyx::RegisterOp getLoopIterReg(ScfForOp op, unsigned idx)
Return a mapping of block argument indices to block argument.
void addLoopIterReg(ScfWhileOp op, calyx::RegisterOp reg, unsigned idx)
Register reg as being the idx'th iter_args register for 'op'.
void setLoopInitGroups(ScfWhileOp op, SmallVector< calyx::GroupOp > groups)
Registers groups to be the loop init groups of op.
SmallVector< calyx::GroupOp > getLoopInitGroups(ScfWhileOp op)
Retrieve the loop init groups registered for op.
calyx::GroupOp buildLoopIterArgAssignments(OpBuilder &builder, ScfWhileOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
Creates a new group that assigns the 'ops' values to the iter arg registers of the loop operation.
const DenseMap< unsigned, calyx::RegisterOp > & getLoopIterRegs(ScfWhileOp op)
Return a mapping of block argument indices to block argument.
Holds common utilities used for scheduling when lowering to Calyx.
Builds a control schedule by traversing the CFG of the function and associating this with the previou...
calyx::RepeatOp buildForCtrlOp(ScfForOp forOp, SmallVector< calyx::GroupOp > const &initGroups, uint64_t bound, PatternRewriter &rewriter) const
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult schedulePath(PatternRewriter &rewriter, const DenseSet< Block * > &path, Location loc, Block *from, Block *to, Block *parentCtrlBlock) const
Schedules a block by inserting a branch argument assignment block (if any) before recursing into the ...
calyx::WhileOp buildWhileCtrlOp(ScfWhileOp whileOp, SmallVector< calyx::GroupOp > initGroups, PatternRewriter &rewriter) const
LogicalResult scheduleBasicBlock(PatternRewriter &rewriter, const DenseSet< Block * > &path, mlir::Block *parentCtrlBlock, mlir::Block *block) const
Sequentially schedules the groups that registered themselves with 'block'.
LogicalResult buildCFGControl(DenseSet< Block * > path, PatternRewriter &rewriter, mlir::Block *parentCtrlBlock, mlir::Block *preBlock, mlir::Block *block) const
void insertParInitGroups(PatternRewriter &rewriter, Location loc, const SmallVector< calyx::GroupOp > &initGroups) const
In BuildForGroups, a register is created for the iteration argument of the for op.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Iterate through the operations of a source function and instantiate components or primitives based on...
BuildOpGroups(MLIRContext *context, LogicalResult &resRef, calyx::PatternApplicationState &patternState, DenseMap< mlir::func::FuncOp, calyx::ComponentOp > &map, calyx::CalyxLoweringState &state, mlir::Pass::Option< std::string > &writeJsonOpt)
TGroupOp createGroupForOp(PatternRewriter &rewriter, Operation *op) const
Creates a group named by the basic block which the input op resides in.
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op) const
buildLibraryOp which provides in- and output types based on the operands and results of the op argume...
LogicalResult buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp) const
Op builder specializations.
calyx::RegisterOp createSignalRegister(PatternRewriter &rewriter, Value signal, bool invert, StringRef nameSuffix, calyx::CompareFOpIEEE754 calyxCmpFOp, calyx::GroupOp group) const
void assignAddressPorts(PatternRewriter &rewriter, Location loc, calyx::GroupInterface group, calyx::MemoryInterface memoryInterface, Operation::operand_range addressValues) const
Creates assignments within the provided group to the address ports of the memoryOp based on the provi...
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op, TypeRange srcTypes, TypeRange dstTypes) const
buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the source operation TSrcOp.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult buildLibraryBinaryPipeOp(PatternRewriter &rewriter, TSrcOp op, TOpType opPipe, Value out) const
buildLibraryBinaryPipeOp will build a TCalyxLibBinaryPipeOp, to deal with MulIOp, DivUIOp and RemUIOp...
mlir::Pass::Option< std::string > & writeJson
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partialEval(PatternRewriter &rewriter, scf::ParallelOp scfParOp) const
In BuildWhileGroups, a register is created for each iteration argument of the while op.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Erases FuncOp operations.
LogicalResult matchAndRewrite(FuncOp funcOp, PatternRewriter &rewriter) const override
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
Handles the current state of lowering of a Calyx component.
ComponentLoweringState(calyx::ComponentOp component)
void setForLoopInitGroups(ScfForOp op, SmallVector< calyx::GroupOp > groups)
calyx::GroupOp buildForLoopIterArgAssignments(OpBuilder &builder, ScfForOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
void setForLoopLatchGroup(ScfForOp op, calyx::GroupOp group)
SmallVector< calyx::GroupOp > getForLoopInitGroups(ScfForOp op)
void addForLoopIterReg(ScfForOp op, calyx::RegisterOp reg, unsigned idx)
calyx::GroupOp getForLoopLatchGroup(ScfForOp op)
calyx::RegisterOp getForLoopIterReg(ScfForOp op, unsigned idx)
const DenseMap< unsigned, calyx::RegisterOp > & getForLoopIterRegs(ScfForOp op)
DenseMap< Operation *, calyx::GroupOp > elseGroup
DenseMap< Operation *, calyx::GroupOp > thenGroup
const DenseMap< unsigned, calyx::RegisterOp > & getResultRegs(scf::IfOp op)
void setElseGroup(scf::IfOp op, calyx::GroupOp group)
void setResultRegs(scf::IfOp op, calyx::RegisterOp reg, unsigned idx)
void setThenGroup(scf::IfOp op, calyx::GroupOp group)
DenseMap< Operation *, DenseMap< unsigned, calyx::RegisterOp > > resultRegs
calyx::RegisterOp getResultRegs(scf::IfOp op, unsigned idx)
calyx::GroupOp getThenGroup(scf::IfOp op)
calyx::GroupOp getElseGroup(scf::IfOp op)
Inlines Calyx ExecuteRegionOp operations within their parent blocks.
LogicalResult matchAndRewrite(scf::ExecuteRegionOp execOp, PatternRewriter &rewriter) const override
LateSSAReplacement contains various functions for replacing SSA values that were not replaced during ...
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &) const override
std::optional< int64_t > getBound() override
Block::BlockArgListType getBodyArgs() override
Block * getBodyBlock() override
Block * getBodyBlock() override
Block::BlockArgListType getBodyArgs() override
Value getConditionValue() override
std::optional< int64_t > getBound() override
Block * getConditionBlock() override
calyx::GroupOp buildWhileLoopIterArgAssignments(OpBuilder &builder, ScfWhileOp op, calyx::ComponentOp componentOp, Twine uniqueSuffix, MutableArrayRef< OpOperand > ops)
void setWhileLoopInitGroups(ScfWhileOp op, SmallVector< calyx::GroupOp > groups)
SmallVector< calyx::GroupOp > getWhileLoopInitGroups(ScfWhileOp op)
void addWhileLoopIterReg(ScfWhileOp op, calyx::RegisterOp reg, unsigned idx)
void setWhileLoopLatchGroup(ScfWhileOp op, calyx::GroupOp group)
const DenseMap< unsigned, calyx::RegisterOp > & getWhileLoopIterRegs(ScfWhileOp op)
calyx::GroupOp getWhileLoopLatchGroup(ScfWhileOp op)
void addMandatoryComponentPorts(PatternRewriter &rewriter, SmallVectorImpl< calyx::PortInfo > &ports)
void buildAssignmentsForRegisterWrite(OpBuilder &builder, calyx::GroupOp groupOp, calyx::ComponentOp componentOp, calyx::RegisterOp &reg, Value inputValue)
Creates register assignment operations within the provided groupOp.
DenseMap< const mlir::RewritePattern *, SmallPtrSet< Operation *, 16 > > PatternApplicationState
Extra state that is passed to all PartialLoweringPatterns so they can record when they have run on an...
PredicateInfo getPredicateInfo(mlir::arith::CmpFPredicate pred)
Type normalizeType(OpBuilder &builder, Type type)
LogicalResult applyModuleOpConversion(mlir::ModuleOp, StringRef topLevelFunction)
Helper to update the top-level ModuleOp to set the entry-point function.
WalkResult getCiderSourceLocationMetadata(calyx::ComponentOp component, SmallVectorImpl< Attribute > &sourceLocations)
bool matchConstantOp(Operation *op, APInt &value)
unsigned handleZeroWidth(int64_t dim)
hw::ConstantOp createConstant(Location loc, OpBuilder &builder, ComponentOp component, size_t width, size_t value)
A helper function to create constants in the HW dialect.
bool noStoresToMemory(Value memoryReference)
bool singleLoadFromMemory(Value memoryReference)
Type toBitVector(T type)
Performs a bit cast from a non-signless integer type value, such as a floating point value,...
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a 'Not' gate on a value.
Definition CombOps.cpp:48
static constexpr std::string_view sPortNameAttr
Definition SCFToCalyx.h:29
static LogicalResult buildAllocOp(ComponentLoweringState &componentState, PatternRewriter &rewriter, TAllocOp allocOp)
std::variant< calyx::GroupOp, WhileScheduleable, ForScheduleable, IfScheduleable, CallScheduleable, ParScheduleable > Scheduleable
A variant of types representing scheduleable operations.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< OperationPass< ModuleOp > > createSCFToCalyxPass()
Create an SCF to Calyx conversion pass.
When building groups which contain accesses to multiple sequential components, a group_done op is cre...
GroupDoneOp's are terminator operations and should therefore be the last operator in a group.
This holds information about the port for either a Component or Cell.
Definition CalyxOps.h:89
Predicate information for the floating point comparisons.
calyx::InstanceOp instanceOp
Instance for invoking.
ScfForOp forOp
For operation to schedule.
Creates a new Calyx component for each FuncOp in the program.
LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &rewriter) const override
scf::ParallelOp parOp
Parallel operation to schedule.
ScfWhileOp whileOp
While operation to schedule.