10 #include "../PassDetail.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "convert-to-arcs"
20 using namespace circt;
23 using llvm::MapVector;
// Body fragment of the arc-breaker predicate. It is presumably
// `static bool isArcBreakingOp(Operation *op)`; the function header and
// original line 27 are missing from this listing.
// Returns true when `op` is constant-like, is one of the listed op kinds
// (only the tail `seq::ClockGateOp>(op)` of that list is visible here),
// or produces more than one result.
26 return op->hasTrait<OpTrait::ConstantLike>() ||
28 seq::ClockGateOp>(op) ||
29 op->getNumResults() > 1;
// Converter member declarations (the enclosing class header and the other
// members, e.g. `arcBreakerIndices`, `globalNamespace`, `builder`, are not
// part of this listing).

// Converts every HWModuleOp in `module`. Existing top-level symbol names are
// reserved in `globalNamespace` first.
38 LogicalResult run(ModuleOp module);
// Builds `postOrder`, `faninMasks` and `faninMaskGroups` for the current
// module. Fails (with a diagnostic) on a combinational loop.
40 LogicalResult analyzeFanIn();
// Ops of the current module body recorded by runOnModule. Each op's position
// is stored in `arcBreakerIndices`. absorbRegs sets entries to null for
// registers it folds into an arc, so entries may be null afterwards.
48 SmallVector<Operation *> arcBreakers;
// Ops reached by the operand DFS in analyzeFanIn, in post-order: an op is
// appended only after all of its operands have been visited.
52 SmallVector<Operation *> postOrder;
// One bit per arc breaker (width `arcBreakers.size()`). Breakers are seeded
// with a one-hot mask at their index, and other ops get a mask derived from
// their users' masks.
56 MapVector<Operation *, APInt> faninMasks;
// Ops bucketed by identical mask. extractArcs turns each bucket into one
// DefineOp plus a CallOp.
59 MapVector<APInt, DenseSet<Operation *>> faninMaskGroups;
// Calls created by extractArcs. absorbRegs may replace entries with StateOps
// and compacts/truncates this list in place.
62 SmallVector<mlir::CallOpInterface> arcUses;
// Entry point of the conversion.
// 1. Every top-level op in `module` that carries a symbol-name attribute has
//    that name registered in `globalNamespace`. Presumably this is so the
//    `<module>_arc` names generated later by newName() cannot collide with
//    existing symbols.
// 2. runOnModule() is run on each HWModuleOp in `module`.
// NOTE(review): original lines 72 (the `sym` binding) and 77-79 are missing
// from this listing; presumably a runOnModule failure is returned as failure
// and success is returned otherwise. Confirm against the full file.
70 LogicalResult Converter::run(ModuleOp module) {
71 for (
auto &op : module.getOps())
73 op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
74 globalNamespace.newName(sym.getValue());
75 for (
auto module : module.getOps<
HWModuleOp>())
76 if (failed(runOnModule(module)))
// Converts a single HW module.
// Records the module body's ops as arc breakers, then runs the analysis
// (analyzeFanIn) and register absorption (absorbRegs).
// Returns failure if any body op has regions (unsupported) or if a later
// stage fails.
// NOTE(review): original lines 82-83, 88-89, 92-93, 96, 99-102 and 104-108
// are missing. They presumably reset the other per-module state, restrict
// the recorded ops to arc-breaking ones, return early for an empty module,
// and call extractArcs(module). Confirm against the full file.
81 LogicalResult Converter::runOnModule(
HWModuleOp module) {
84 arcBreakerIndices.clear();
85 for (Operation &op : *module.getBodyBlock()) {
// Ops with nested regions cannot be handled by this pass.
86 if (op.getNumRegions() > 0)
87 return op.emitOpError(
"has regions; not supported by ConvertToArcs");
// Remember the op's position so its bit in the fan-in masks can be found.
90 arcBreakerIndices[&op] = arcBreakers.size();
91 arcBreakers.push_back(&op);
// A body holding nothing but the hw.output terminator needs no conversion
// (the action taken is on a line missing from this listing).
94 if (module.getBodyBlock()->without_terminator().empty() &&
95 isa<hw::OutputOp>(module.getBodyBlock()->getTerminator()))
97 LLVM_DEBUG(llvm::dbgs() <<
"Analyzing " << module.getModuleNameAttr() <<
" ("
98 << arcBreakers.size() <<
" breakers)\n");
103 if (failed(analyzeFanIn()))
109 if (failed(absorbRegs(module)))
// Computes the op partitioning used by extractArcs, in three steps:
// 1. An iterative DFS over operands, starting from every arc breaker, records
//    ops in `postOrder` (operands before users) and reports combinational
//    loops.
// 2. Walking `postOrder` in reverse (users before operands), each op's mask
//    is built from the masks of its users that already have one, and stored
//    in `faninMasks`.
// 3. Ops are bucketed by identical mask into `faninMaskGroups`.
// Returns failure after emitting "combinational loop detected".
// NOTE(review): several original lines (e.g. 116-118, 124-126, 133, 135-139,
// 142, 144, 147-148, 162-164, 168-170, 173) are missing from this listing.
// Continue/return statements and the mask-combining step are inferred.
114 LogicalResult Converter::analyzeFanIn() {
// Each entry is (op, index of the next operand to visit).
115 SmallVector<std::tuple<Operation *, unsigned>> worklist;
// Seed: every arc breaker owns exactly one bit, at its recorded index.
119 for (
auto *op : arcBreakers) {
120 unsigned index = arcBreakerIndices.lookup(op);
121 auto mask = APInt::getOneBitSet(arcBreakers.size(), index);
122 faninMasks[op] =
mask;
123 worklist.push_back({op, 0});
// `seen`: ops that have been entered by the DFS.
// `finished`: ops whose operands have all been visited.
127 DenseSet<Operation *> seen;
128 DenseSet<Operation *> finished;
130 while (!worklist.empty()) {
// References into the back element. They are only used before the
// push_back below, which may reallocate the vector.
131 auto &[op, operandIdx] = worklist.back();
// All operands visited: emit the op in post-order.
132 if (operandIdx == op->getNumOperands()) {
134 postOrder.push_back(op);
140 auto operand = op->getOperand(operandIdx++);
141 auto *definingOp = operand.getDefiningOp();
// Tail of a skip condition whose start is missing; already-finished ops
// are not revisited.
143 finished.contains(definingOp))
// An op that was entered but is not finished is still on the DFS path,
// so reaching it again means the operand graph has a cycle.
145 if (!seen.insert(definingOp).second) {
146 definingOp->emitError(
"combinational loop detected");
149 worklist.push_back({definingOp, 0});
151 LLVM_DEBUG(llvm::dbgs() <<
"- Sorted " << postOrder.size() <<
" ops\n");
// Reverse post-order visits users before the ops that feed them, so any
// user mask consulted here has already been computed.
157 for (
auto *op : llvm::reverse(postOrder)) {
158 auto mask = APInt::getZero(arcBreakers.size());
159 for (
auto *user : op->getUsers()) {
160 auto it = faninMasks.find(user);
// Presumably the user's mask is OR-ed into `mask` here (the statement is
// on a missing line).
161 if (it != faninMasks.end())
165 auto duplicateOp = faninMasks.insert({op,
mask});
167 assert(duplicateOp.second &&
"duplicate op in order");
// Bucket ops by identical mask. Original line 173 is missing and may filter
// the entries.
171 faninMaskGroups.clear();
172 for (
auto [op, mask] : faninMasks)
174 faninMaskGroups[
mask].insert(op);
175 LLVM_DEBUG(llvm::dbgs() <<
"- Found " << faninMaskGroups.size()
176 <<
" fanin mask groups\n");
// Outlines each mask group in `faninMaskGroups` into its own arc.
// For each group:
// - A new Block is filled with the group's ops, visited in `postOrder`.
// - Operands defined outside the group become block arguments, deduplicated
//   through `valueMapping`. The outer values are collected in `inputs`.
// - Results with at least one use outside the group become arc outputs, and
//   those outside uses are recorded in `externalUses`.
// - A DefineOp named `<module>_arc` (uniqued by `globalNamespace`) is created
//   before the module. A CallOp to it is created before the module
//   terminator, and the outside uses are rewired to the call's results.
//   The call is recorded in `arcUses`.
// NOTE(review): several original lines are missing (e.g. 184 declaring
// `outputs`, 192-193, 197-200, 202, 204, 207-210, 213, 215, 220-222, 227,
// 229-232, 234-240, 243-244, 248-250). Op cloning, the output op, clearing of
// `inputs`/`outputs`, and the assignment of `lastOp` are therefore not shown.
181 void Converter::extractArcs(
HWModuleOp module) {
// Scratch state, reused across groups.
182 DenseMap<Value, Value> valueMapping;
183 SmallVector<Value>
inputs;
185 SmallVector<Type> inputTypes;
186 SmallVector<Type> outputTypes;
// (use outside the group, index of the arc output that should feed it)
187 SmallVector<std::pair<OpOperand *, unsigned>> externalUses;
190 for (
auto &group : faninMaskGroups) {
191 auto &opSet = group.second;
194 auto block = std::make_unique<Block>();
195 builder.setInsertionPointToStart(block.get());
196 valueMapping.clear();
201 externalUses.clear();
// Presumably set to the group's last op. It supplies the CallOp location
// below.
203 Operation *lastOp =
nullptr;
// Visit the group's ops in post-order so operands are handled before users.
205 for (
auto *op : postOrder) {
206 if (!opSet.contains(op))
211 for (
auto &operand : op->getOpOperands()) {
// Values produced inside the group need no block argument.
212 if (opSet.contains(operand.get().getDefiningOp()))
// One block argument per distinct outside value.
214 auto &mapped = valueMapping[operand.get()];
216 mapped = block->addArgument(operand.get().getType(),
217 operand.get().getLoc());
218 inputs.push_back(operand.get());
219 inputTypes.push_back(mapped.getType());
223 for (
auto result : op->getResults()) {
224 bool anyExternal =
false;
225 for (
auto &use : result.getUses()) {
226 if (!opSet.contains(use.getOwner())) {
// The output index is the current size of `outputs`, i.e. the slot this
// result is expected to occupy.
228 externalUses.push_back({&use,
outputs.size()});
233 outputTypes.push_back(result.getType());
// Emit the arc definition next to (before) the module being converted.
241 builder.setInsertionPoint(module);
242 auto defOp =
builder.create<DefineOp>(
245 globalNamespace.newName(module.getModuleName() +
"_arc")),
246 builder.getFunctionType(inputTypes, outputTypes));
247 defOp.getBody().push_back(block.release());
// Instantiate the arc at the end of the module body, before its terminator.
251 builder.setInsertionPoint(module.getBodyBlock()->getTerminator());
252 auto arcOp =
builder.create<CallOp>(lastOp->getLoc(), defOp,
inputs);
253 arcUses.push_back(arcOp);
// Redirect every outside use to the matching call result.
254 for (
auto [use, resultIdx] : externalUses)
255 use->set(arcOp.getResult(resultIdx));
// Folds seq.compreg registers into the arcs that feed them, then wraps the
// remaining registers in pass-through "shuffling" arcs.
//
// Phase 1 (trivial absorption), for each call in `arcUses`:
//   - Absorption happens only if every result has exactly one use, that use
//     is a seq::CompRegOp whose input is the result, and all those registers
//     share one clock. If the call is already a StateOp, that clock must
//     also match the StateOp's clock.
//   - An already-StateOp call gets the clock assigned and its latency
//     increased by one. Otherwise the call is replaced with a new StateOp.
//   - Absorbed registers are removed from `arcBreakers`/`arcBreakerIndices`,
//     and their uses are redirected to the arc's results.
//   - Calls that are not absorbed are compacted to the front of `arcUses`
//     (`outIdx`), and the list is then truncated.
// Phase 2 (shuffling):
//   - Registers still present in `arcBreakers` are grouped by
//     (clock, reset, defining op of the input).
//   - Each group becomes one DefineOp whose function type maps the register
//     types to themselves, plus one StateOp instantiating it.
//
// Returns failure with a diagnostic for unsupported reset setups: reset
// value not produced by an op, non-zero constant, non-constant value, or an
// arc that already has a reset.
// NOTE(review): many original lines are missing from this listing (e.g. the
// declarations of `outIdx`, `reset`, `regsByInput`'s name, `mapping`,
// `outputs`, `defOp`/`arcOp` in phase 2, and several continue/branch
// statements). The control flow described above is partly inferred from the
// visible statements.
259 LogicalResult Converter::absorbRegs(
HWModuleOp module) {
263 unsigned numTrivialRegs = 0;
264 for (
auto callOp : arcUses) {
// A call that is already a StateOp contributes its existing clock.
265 auto stateOp = dyn_cast<StateOp>(callOp.getOperation());
266 Value clock = stateOp ? stateOp.getClock() : Value{};
268 SmallVector<seq::CompRegOp> absorbedRegs;
// One name slot per result, pre-filled from an existing "names" attribute
// if the call has one.
269 SmallVector<Attribute> absorbedNames(callOp->getNumResults(), {});
270 if (
auto names = callOp->getAttrOfType<ArrayAttr>(
"names"))
271 absorbedNames.assign(names.getValue().begin(), names.getValue().end());
// Absorption is all-or-nothing per call: each result must feed exactly one
// compatible register.
276 bool isTrivial =
true;
277 for (
auto result : callOp->getResults()) {
278 if (!result.hasOneUse()) {
282 auto regOp = dyn_cast<seq::CompRegOp>(result.use_begin()->getOwner());
283 if (!regOp || regOp.getInput() != result ||
284 (clock && clock != regOp.getClk())) {
289 clock = regOp.getClk();
290 reset = regOp.getReset();
// Reset values must be an hw.constant equal to zero. A reset value with
// no defining op (presumably the check on missing line 296) is reported
// as an input.
294 Value resetValue = regOp.getResetValue();
295 Operation *op = resetValue.getDefiningOp();
297 return regOp->emitOpError(
298 "is reset by an input; not supported by ConvertToArcs");
299 if (
auto constant = dyn_cast<hw::ConstantOp>(op)) {
300 if (constant.getValue() != 0)
301 return regOp->emitOpError(
"is reset to a constant non-zero value; "
302 "not supported by ConvertToArcs");
304 return regOp->emitOpError(
"is reset to a value that is not clearly "
305 "constant; not supported by ConvertToArcs");
309 absorbedRegs.push_back(regOp);
// NOTE(review): if getNameAttr() can return a null StringAttr (register
// without a name), the `cast<StringAttr>()` in the any_of below would
// assert. Confirm that registers always carry a name.
313 absorbedNames[result.getResultNumber()] = regOp.getNameAttr();
// Keep calls that were not absorbed, compacted toward the front of the list.
318 arcUses[outIdx++] = callOp;
// An existing StateOp absorbs one more register stage: same arc, one more
// cycle of latency.
327 auto arc = dyn_cast<StateOp>(callOp.getOperation());
329 arc.getClockMutable().assign(clock);
330 arc.setLatency(arc.getLatency() + 1);
// A plain call is replaced by a StateOp on the same callee and arguments.
// The literal `1` is presumably the latency.
332 mlir::IRRewriter rewriter(module->getContext());
333 rewriter.setInsertionPoint(callOp);
334 arc = rewriter.replaceOpWithNewOp<StateOp>(
335 callOp.getOperation(),
336 callOp.getCallableForCallee().get<SymbolRefAttr>(),
337 callOp->getResultTypes(), clock, Value{}, 1, callOp.getArgOperands());
// Presumably emitted when the arc already has a reset operand (condition on
// a missing line).
342 return arc.emitError(
343 "StateOp tried to infer reset from CompReg, but already "
345 arc.getResetMutable().assign(reset);
// With `tapRegisters` set, attach the register names if any is non-empty.
347 if (tapRegisters && llvm::any_of(absorbedNames, [](
auto name) {
348 return !name.template cast<StringAttr>().getValue().empty();
350 arc->setAttr(
"names",
ArrayAttr::get(module.getContext(), absorbedNames));
// Retire each absorbed register. It is no longer an arc breaker, and its
// users now read the arc result directly.
351 for (
auto [arcResult,
reg] : llvm::zip(arc.getResults(), absorbedRegs)) {
352 auto it = arcBreakerIndices.find(
reg);
353 arcBreakers[it->second] = {};
354 arcBreakerIndices.erase(it);
355 reg.replaceAllUsesWith(arcResult);
359 if (numTrivialRegs > 0)
360 LLVM_DEBUG(llvm::dbgs() <<
"- Trivially converted " << numTrivialRegs
361 <<
" regs to arcs\n");
362 arcUses.truncate(outIdx);
// Phase 2: group the remaining registers (null entries are skipped by
// dyn_cast_or_null) by (clock, reset, defining op of the input).
367 MapVector<std::tuple<Value, Value, Operation *>, SmallVector<seq::CompRegOp>>
369 for (
auto *op : arcBreakers)
370 if (
auto regOp = dyn_cast_or_null<seq::CompRegOp>(op)) {
371 regsByInput[{regOp.getClk(), regOp.getReset(),
372 372 regOp.getInput().getDefiningOp()}]
376 unsigned numMappedRegs = 0;
377 for (
auto [clockAndResetAndOp, regOps] : regsByInput) {
378 numMappedRegs += regOps.size();
380 auto block = std::make_unique<Block>();
381 builder.setInsertionPointToStart(block.get());
383 SmallVector<Value>
inputs;
385 SmallVector<Attribute> names;
386 SmallVector<Type> types;
// For each register in `regOps`: index of the arc result that replaces it.
388 SmallVector<unsigned> regToOutputMapping;
389 for (
auto regOp : regOps) {
// Registers that share an input share one arc argument and result.
390 auto it = mapping.find(regOp.getInput());
391 if (it == mapping.end()) {
392 it = mapping.insert({regOp.getInput(),
inputs.size()}).first;
393 inputs.push_back(regOp.getInput());
394 types.push_back(regOp.getType());
// The new block argument is also collected as an output, which suggests
// the body forwards its arguments unchanged.
395 outputs.push_back(block->addArgument(regOp.getType(), regOp.getLoc()));
// NOTE(review): getAttrOfType returns a null attribute when "name" is
// absent, and the `cast<StringAttr>()` in the any_of below would then
// assert. Confirm that "name" is always present.
396 names.push_back(regOp->getAttrOfType<StringAttr>(
"name"));
398 regToOutputMapping.push_back(it->second);
401 auto loc = regOps.back().getLoc();
404 builder.setInsertionPoint(module);
407 builder.getStringAttr(globalNamespace.newName(
408 module.getModuleName() +
"_arc")),
409 builder.getFunctionType(types, types));
410 defOp.getBody().push_back(block.release());
412 builder.setInsertionPoint(module.getBodyBlock()->getTerminator());
414 builder.create<StateOp>(loc, defOp, std::get<0>(clockAndResetAndOp),
416 auto reset = std::get<1>(clockAndResetAndOp);
418 arcOp.getResetMutable().assign(reset);
419 if (tapRegisters && llvm::any_of(names, [](
auto name) {
420 return !name.template cast<StringAttr>().getValue().empty();
422 arcOp->setAttr(
"names",
builder.getArrayAttr(names));
423 for (
auto [
reg, resultIdx] : llvm::zip(regOps, regToOutputMapping)) {
424 reg.replaceAllUsesWith(arcOp.getResult(resultIdx));
429 if (numMappedRegs > 0)
430 LLVM_DEBUG(llvm::dbgs() <<
"- Mapped " << numMappedRegs <<
" regs to "
431 << regsByInput.size() <<
" shuffling arcs\n");
// Pulls in the tablegen-generated `impl::ConvertToArcsBase` pass definition.
441 #define GEN_PASS_DEF_CONVERTTOARCS
442 #include "circt/Conversion/Passes.h.inc"
// Pass wrapper around Converter. It operates on the top-level ModuleOp and
// inherits the generated base class's constructors (including the
// options-taking one used by the factory below).
446 struct ConvertToArcsPass :
public impl::ConvertToArcsBase<ConvertToArcsPass> {
447 using ConvertToArcsBase::ConvertToArcsBase;
// Forwards the `tapRegisters` pass option to the converter and runs it on
// the root module. The failure handling (presumably signalPassFailure()) is
// on a missing line.
449 void runOnOperation()
override {
451 converter.tapRegisters = tapRegisters;
452 if (failed(converter.run(getOperation())))
// Factory (presumably `createConvertToArcsPass(const ConvertToArcsOptions &)`;
// its name line is missing from this listing). Constructs the pass with the
// given options.
458 std::unique_ptr<OperationPass<ModuleOp>>
460 return std::make_unique<ConvertToArcsPass>(options);
assert(baseType &&"element must be base type")
static bool isArcBreakingOp(Operation *op)
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
A namespace that is used to store existing names and generate new names in some scope within the IR.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< OperationPass< ModuleOp > > createConvertToArcsPass(const ConvertToArcsOptions &options={})
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)