CIRCT 22.0.0git
Loading...
Searching...
No Matches
ConvertToArcs.cpp
Go to the documentation of this file.
1//===- ConvertToArcs.cpp --------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Transforms/DialectConversion.h"
20#include "mlir/Transforms/RegionUtils.h"
21#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "convert-to-arcs"
24
25using namespace circt;
26using namespace arc;
27using namespace hw;
28using llvm::MapVector;
29using llvm::SmallSetVector;
30using mlir::ConversionConfig;
31
32static bool isArcBreakingOp(Operation *op) {
33 if (isa<TapOp>(op))
34 return false;
35 return op->hasTrait<OpTrait::ConstantLike>() ||
36 isa<hw::InstanceOp, seq::CompRegOp, MemoryOp, MemoryReadPortOp,
37 ClockedOpInterface, seq::InitialOp, seq::ClockGateOp,
38 sim::DPICallOp>(op) ||
39 op->getNumResults() > 1 || op->getNumRegions() > 0 ||
40 !mlir::isMemoryEffectFree(op);
41}
42
43static LogicalResult convertInitialValue(seq::CompRegOp reg,
44 SmallVectorImpl<Value> &values) {
45 if (!reg.getInitialValue())
46 return values.push_back({}), success();
47
48 // Use from_immutable cast to convert the seq.immutable type to the reg's
49 // type.
50 OpBuilder builder(reg);
51 auto init = seq::FromImmutableOp::create(builder, reg.getLoc(), reg.getType(),
52 reg.getInitialValue());
53
54 values.push_back(init);
55 return success();
56}
57
58//===----------------------------------------------------------------------===//
59// Conversion
60//===----------------------------------------------------------------------===//
61
62namespace {
63struct Converter {
64 LogicalResult run(ModuleOp module);
65 LogicalResult runOnModule(HWModuleOp module);
66 LogicalResult analyzeFanIn();
67 void extractArcs(HWModuleOp module);
68 LogicalResult absorbRegs(HWModuleOp module);
69
70 /// The global namespace used to create unique definition names.
71 Namespace globalNamespace;
72
73 /// All arc-breaking operations in the current module.
74 SmallVector<Operation *> arcBreakers;
76
77 /// A post-order traversal of the operations in the current module.
78 SmallVector<Operation *> postOrder;
79
80 /// The set of arc-breaking ops an operation in the current module
81 /// contributes to, represented as a bit mask.
82 MapVector<Operation *, APInt> faninMasks;
83
84 /// The sets of operations that contribute to the same arc-breaking ops.
85 MapVector<APInt, DenseSet<Operation *>> faninMaskGroups;
86
87 /// The arc uses generated by `extractArcs`.
88 SmallVector<mlir::CallOpInterface> arcUses;
89
90 /// Whether registers should be made observable by assigning their arcs a
91 /// "name" attribute.
92 bool tapRegisters;
93};
94} // namespace
95
96LogicalResult Converter::run(ModuleOp module) {
97 for (auto &op : module.getOps())
98 if (auto sym =
99 op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
100 globalNamespace.newName(sym.getValue());
101 for (auto module : module.getOps<HWModuleOp>())
102 if (failed(runOnModule(module)))
103 return failure();
104 return success();
105}
106
107LogicalResult Converter::runOnModule(HWModuleOp module) {
108 // Find all arc-breaking operations in this module and assign them an index.
109 arcBreakers.clear();
110 arcBreakerIndices.clear();
111 for (Operation &op : *module.getBodyBlock()) {
112 if (isa<seq::InitialOp>(&op))
113 continue;
114 if (!isArcBreakingOp(&op) && !isa<hw::OutputOp>(&op))
115 continue;
116 arcBreakerIndices[&op] = arcBreakers.size();
117 arcBreakers.push_back(&op);
118 }
119 // Skip modules with only `OutputOp`.
120 if (module.getBodyBlock()->without_terminator().empty() &&
121 isa<hw::OutputOp>(module.getBodyBlock()->getTerminator()))
122 return success();
123 LLVM_DEBUG(llvm::dbgs() << "Analyzing " << module.getModuleNameAttr() << " ("
124 << arcBreakers.size() << " breakers)\n");
125
126 // For each operation, figure out the set of breaker ops it contributes to,
127 // in the form of a bit mask. Then group operations together that contribute
128 // to the same set of breaker ops.
129 if (failed(analyzeFanIn()))
130 return failure();
131
132 // Extract the fanin mask groups into separate combinational arcs and
133 // combine them with the registers in the design.
134 extractArcs(module);
135 if (failed(absorbRegs(module)))
136 return failure();
137
138 return success();
139}
140
/// Computes, for each operation feeding an arc breaker, the set of breakers it
/// (transitively) contributes to, and groups operations by that set.
///
/// On success this fills:
/// - `postOrder`: the non-breaker ops reached backwards (through operands)
///   from the breakers, in DFS post-order, so operands come before users;
/// - `faninMasks`: a bit mask of width `arcBreakers.size()` for every breaker
///   (only its own bit set) and for every op in `postOrder` (the union of the
///   masks of its users, lifted to the op's nesting level);
/// - `faninMaskGroups`: the non-breaker ops bucketed by identical mask.
/// Ops not reached from any breaker are not recorded.
/// Emits an error and fails if a combinational loop is detected.
LogicalResult Converter::analyzeFanIn() {
  // Each worklist entry is an op together with the operands still to visit.
  SmallVector<std::tuple<Operation *, SmallVector<Value, 2>>> worklist;
  SetVector<Value> seenOperands;
  // Push `op` with its unique operands, including values defined above that
  // are used inside its regions.
  auto addToWorklist = [&](Operation *op) {
    seenOperands.clear();
    for (auto operand : op->getOperands())
      seenOperands.insert(operand);
    mlir::getUsedValuesDefinedAbove(op->getRegions(), seenOperands);
    worklist.emplace_back(op, seenOperands.getArrayRef());
  };

  // Seed the worklist and fanin masks with the arc breaking operations. Each
  // breaker's mask has only the bit of its own index set.
  faninMasks.clear();
  for (auto *op : arcBreakers) {
    unsigned index = arcBreakerIndices.lookup(op);
    auto mask = APInt::getOneBitSet(arcBreakers.size(), index);
    faninMasks[op] = mask;
    addToWorklist(op);
  }

  // Establish a post-order among the operations using an iterative DFS over
  // operand edges. `seen` holds the ops currently on the DFS stack, so
  // reaching one of them again means the fan-in contains a cycle.
  DenseSet<Operation *> seen;
  DenseSet<Operation *> finished;
  postOrder.clear();
  while (!worklist.empty()) {
    auto &[op, operands] = worklist.back();
    if (operands.empty()) {
      // All operands visited. Breakers and outputs are only DFS roots and are
      // not part of the post-order.
      if (!isArcBreakingOp(op) && !isa<hw::OutputOp>(op))
        postOrder.push_back(op);
      finished.insert(op);
      seen.erase(op);
      worklist.pop_back();
      continue;
    }
    auto operand = operands.pop_back_val(); // advance to next operand
    auto *definingOp = operand.getDefiningOp();
    // Block arguments, results of breakers and already finished ops end the
    // search along this edge.
    if (!definingOp || isArcBreakingOp(definingOp) ||
        finished.contains(definingOp))
      continue;
    if (!seen.insert(definingOp).second) {
      definingOp->emitError("combinational loop detected");
      return failure();
    }
    // This may reallocate `worklist`; `op` and `operands` are not used again
    // in this iteration.
    addToWorklist(definingOp);
  }
  LLVM_DEBUG(llvm::dbgs() << "- Sorted " << postOrder.size() << " ops\n");

  // Compute fanin masks in reverse post-order, which will compute the mask
  // for an operation's uses before it computes it for the operation itself.
  // This allows us to compute the set of arc breakers an operation
  // contributes to in one pass.
  for (auto *op : llvm::reverse(postOrder)) {
    auto mask = APInt::getZero(arcBreakers.size());
    for (auto *user : op->getUsers()) {
      // Users nested inside regions are attributed to their ancestor that
      // lives at the same level as `op`.
      while (user->getParentOp() != op->getParentOp())
        user = user->getParentOp();
      auto it = faninMasks.find(user);
      if (it != faninMasks.end())
        mask |= it->second;
    }

    auto duplicateOp = faninMasks.insert({op, mask});
    (void)duplicateOp;
    assert(duplicateOp.second && "duplicate op in order");
  }

  // Group the operations by their fan-in mask. Breakers and outputs are not
  // grouped since they are never moved into an arc.
  faninMaskGroups.clear();
  for (auto [op, mask] : faninMasks)
    if (!isArcBreakingOp(op) && !isa<hw::OutputOp>(op))
      faninMaskGroups[mask].insert(op);
  LLVM_DEBUG(llvm::dbgs() << "- Found " << faninMaskGroups.size()
                          << " fanin mask groups\n");

  return success();
}
217
218void Converter::extractArcs(HWModuleOp module) {
219 DenseMap<Value, Value> valueMapping;
220 SmallVector<Value> inputs;
221 SmallVector<Value> outputs;
222 SmallVector<Type> inputTypes;
223 SmallVector<Type> outputTypes;
224 SmallVector<std::pair<OpOperand *, unsigned>> externalUses;
225
226 arcUses.clear();
227 for (auto &group : faninMaskGroups) {
228 auto &opSet = group.second;
229 OpBuilder builder(module);
230
231 auto block = std::make_unique<Block>();
232 builder.setInsertionPointToStart(block.get());
233 valueMapping.clear();
234 inputs.clear();
235 outputs.clear();
236 inputTypes.clear();
237 outputTypes.clear();
238 externalUses.clear();
239
240 Operation *lastOp = nullptr;
241 // TODO: Remove the elements from the post order as we go.
242 for (auto *op : postOrder) {
243 if (!opSet.contains(op))
244 continue;
245 lastOp = op;
246 op->remove();
247 builder.insert(op);
248 for (auto &operand : op->getOpOperands()) {
249 if (opSet.contains(operand.get().getDefiningOp()))
250 continue;
251 auto &mapped = valueMapping[operand.get()];
252 if (!mapped) {
253 mapped = block->addArgument(operand.get().getType(),
254 operand.get().getLoc());
255 inputs.push_back(operand.get());
256 inputTypes.push_back(mapped.getType());
257 }
258 operand.set(mapped);
259 }
260 for (auto result : op->getResults()) {
261 bool anyExternal = false;
262 for (auto &use : result.getUses()) {
263 if (!opSet.contains(use.getOwner())) {
264 anyExternal = true;
265 externalUses.push_back({&use, outputs.size()});
266 }
267 }
268 if (anyExternal) {
269 outputs.push_back(result);
270 outputTypes.push_back(result.getType());
271 }
272 }
273 }
274 assert(lastOp);
275 arc::OutputOp::create(builder, lastOp->getLoc(), outputs);
276
277 // Create the arc definition.
278 builder.setInsertionPoint(module);
279 auto defOp =
280 DefineOp::create(builder, lastOp->getLoc(),
281 builder.getStringAttr(globalNamespace.newName(
282 module.getModuleName() + "_arc")),
283 builder.getFunctionType(inputTypes, outputTypes));
284 defOp.getBody().push_back(block.release());
285
286 // Create the call to the arc definition to replace the operations that
287 // we have just extracted.
288 builder.setInsertionPoint(module.getBodyBlock()->getTerminator());
289 auto arcOp = CallOp::create(builder, lastOp->getLoc(), defOp, inputs);
290 arcUses.push_back(arcOp);
291 for (auto [use, resultIdx] : externalUses)
292 use->set(arcOp.getResult(resultIdx));
293 }
294}
295
296LogicalResult Converter::absorbRegs(HWModuleOp module) {
297 // Handle the trivial cases where all of an arc's results are used by
298 // exactly one register each.
299 unsigned outIdx = 0;
300 unsigned numTrivialRegs = 0;
301 for (auto callOp : arcUses) {
302 auto stateOp = dyn_cast<StateOp>(callOp.getOperation());
303 Value clock = stateOp ? stateOp.getClock() : Value{};
304 Value reset;
305 SmallVector<Value> initialValues;
306 SmallVector<seq::CompRegOp> absorbedRegs;
307 SmallVector<Attribute> absorbedNames(callOp->getNumResults(), {});
308 if (auto names = callOp->getAttrOfType<ArrayAttr>("names"))
309 absorbedNames.assign(names.getValue().begin(), names.getValue().end());
310
311 // Go through all every arc result and collect the single register that uses
312 // it. If a result has multiple uses or is used by something other than a
313 // register, skip the arc for now and handle it later.
314 bool isTrivial = true;
315 for (auto result : callOp->getResults()) {
316 if (!result.hasOneUse()) {
317 isTrivial = false;
318 break;
319 }
320 auto regOp = dyn_cast<seq::CompRegOp>(result.use_begin()->getOwner());
321 if (!regOp || regOp.getInput() != result ||
322 (clock && clock != regOp.getClk())) {
323 isTrivial = false;
324 break;
325 }
326
327 clock = regOp.getClk();
328 reset = regOp.getReset();
329
330 // Check that if the register has a reset, it is to a constant zero
331 if (reset) {
332 Value resetValue = regOp.getResetValue();
333 Operation *op = resetValue.getDefiningOp();
334 if (!op)
335 return regOp->emitOpError(
336 "is reset by an input; not supported by ConvertToArcs");
337 if (auto constant = dyn_cast<hw::ConstantOp>(op)) {
338 if (constant.getValue() != 0)
339 return regOp->emitOpError("is reset to a constant non-zero value; "
340 "not supported by ConvertToArcs");
341 } else {
342 return regOp->emitOpError("is reset to a value that is not clearly "
343 "constant; not supported by ConvertToArcs");
344 }
345 }
346
347 if (failed(convertInitialValue(regOp, initialValues)))
348 return failure();
349
350 absorbedRegs.push_back(regOp);
351 // If we absorb a register into the arc, the arc effectively produces that
352 // register's value. So if the register had a name, ensure that we assign
353 // that name to the arc's output.
354 absorbedNames[result.getResultNumber()] = regOp.getNameAttr();
355 }
356
357 // If this wasn't a trivial case keep the arc around for a second iteration.
358 if (!isTrivial) {
359 arcUses[outIdx++] = callOp;
360 continue;
361 }
362 ++numTrivialRegs;
363
364 // Set the arc's clock to the clock of the registers we've absorbed, bump
365 // the latency up by one to account for the registers, add the reset if
366 // present and update the output names. Then replace the registers.
367
368 auto arc = dyn_cast<StateOp>(callOp.getOperation());
369 if (arc) {
370 arc.getClockMutable().assign(clock);
371 arc.setLatency(arc.getLatency() + 1);
372 } else {
373 mlir::IRRewriter rewriter(module->getContext());
374 rewriter.setInsertionPoint(callOp);
375 arc = rewriter.replaceOpWithNewOp<StateOp>(
376 callOp.getOperation(),
377 llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee()),
378 callOp->getResultTypes(), clock, Value{}, 1, callOp.getArgOperands());
379 }
380
381 if (reset) {
382 if (arc.getReset())
383 return arc.emitError(
384 "StateOp tried to infer reset from CompReg, but already "
385 "had a reset.");
386 arc.getResetMutable().assign(reset);
387 }
388
389 bool onlyDefaultInitializers =
390 llvm::all_of(initialValues, [](auto val) -> bool { return !val; });
391
392 if (!onlyDefaultInitializers) {
393 if (!arc.getInitials().empty()) {
394 return arc.emitError(
395 "StateOp tried to infer initial values from CompReg, but already "
396 "had an initial value.");
397 }
398 // Create 0 constants for default initialization
399 for (unsigned i = 0; i < initialValues.size(); ++i) {
400 if (!initialValues[i]) {
401 OpBuilder zeroBuilder(arc);
402 initialValues[i] = zeroBuilder.createOrFold<hw::ConstantOp>(
403 arc.getLoc(),
404 zeroBuilder.getIntegerAttr(arc.getResult(i).getType(), 0));
405 }
406 }
407 arc.getInitialsMutable().assign(initialValues);
408 }
409
410 if (tapRegisters && llvm::any_of(absorbedNames, [](auto name) {
411 return !cast<StringAttr>(name).getValue().empty();
412 }))
413 arc->setAttr("names", ArrayAttr::get(module.getContext(), absorbedNames));
414 for (auto [arcResult, reg] : llvm::zip(arc.getResults(), absorbedRegs)) {
415 auto it = arcBreakerIndices.find(reg);
416 arcBreakers[it->second] = {};
417 arcBreakerIndices.erase(it);
418 reg.replaceAllUsesWith(arcResult);
419 reg.erase();
420 }
421 }
422 if (numTrivialRegs > 0)
423 LLVM_DEBUG(llvm::dbgs() << "- Trivially converted " << numTrivialRegs
424 << " regs to arcs\n");
425 arcUses.truncate(outIdx);
426
427 // Group the remaining registers by their clock, their reset and the operation
428 // they use as input. This will allow us to generally collapse registers
429 // derived from the same arc into one shuffling arc.
430 MapVector<std::tuple<Value, Value, Operation *>, SmallVector<seq::CompRegOp>>
431 regsByInput;
432 for (auto *op : arcBreakers)
433 if (auto regOp = dyn_cast_or_null<seq::CompRegOp>(op)) {
434 regsByInput[{regOp.getClk(), regOp.getReset(),
435 regOp.getInput().getDefiningOp()}]
436 .push_back(regOp);
437 }
438
439 unsigned numMappedRegs = 0;
440 for (auto [clockAndResetAndOp, regOps] : regsByInput) {
441 numMappedRegs += regOps.size();
442 OpBuilder builder(module);
443 auto block = std::make_unique<Block>();
444 builder.setInsertionPointToStart(block.get());
445
446 SmallVector<Value> inputs;
447 SmallVector<Value> outputs;
448 SmallVector<Attribute> names;
449 SmallVector<Type> types;
450 SmallVector<Value> initialValues;
452 SmallVector<unsigned> regToOutputMapping;
453 for (auto regOp : regOps) {
454 auto it = mapping.find(regOp.getInput());
455 if (it == mapping.end()) {
456 it = mapping.insert({regOp.getInput(), inputs.size()}).first;
457 inputs.push_back(regOp.getInput());
458 types.push_back(regOp.getType());
459 outputs.push_back(block->addArgument(regOp.getType(), regOp.getLoc()));
460 names.push_back(regOp->getAttrOfType<StringAttr>("name"));
461 if (failed(convertInitialValue(regOp, initialValues)))
462 return failure();
463 }
464 regToOutputMapping.push_back(it->second);
465 }
466
467 auto loc = regOps.back().getLoc();
468 arc::OutputOp::create(builder, loc, outputs);
469
470 builder.setInsertionPoint(module);
471 auto defOp = DefineOp::create(builder, loc,
472 builder.getStringAttr(globalNamespace.newName(
473 module.getModuleName() + "_arc")),
474 builder.getFunctionType(types, types));
475 defOp.getBody().push_back(block.release());
476
477 builder.setInsertionPoint(module.getBodyBlock()->getTerminator());
478
479 bool onlyDefaultInitializers =
480 llvm::all_of(initialValues, [](auto val) -> bool { return !val; });
481
482 if (onlyDefaultInitializers)
483 initialValues.clear();
484 else
485 for (unsigned i = 0; i < initialValues.size(); ++i) {
486 if (!initialValues[i])
487 initialValues[i] = builder.createOrFold<hw::ConstantOp>(
488 loc, builder.getIntegerAttr(types[i], 0));
489 }
490
491 auto arcOp =
492 StateOp::create(builder, loc, defOp, std::get<0>(clockAndResetAndOp),
493 /*enable=*/Value{}, 1, inputs, initialValues);
494 auto reset = std::get<1>(clockAndResetAndOp);
495 if (reset)
496 arcOp.getResetMutable().assign(reset);
497 if (tapRegisters && llvm::any_of(names, [](auto name) {
498 return !cast<StringAttr>(name).getValue().empty();
499 }))
500 arcOp->setAttr("names", builder.getArrayAttr(names));
501 for (auto [reg, resultIdx] : llvm::zip(regOps, regToOutputMapping)) {
502 reg.replaceAllUsesWith(arcOp.getResult(resultIdx));
503 reg.erase();
504 }
505 }
506
507 if (numMappedRegs > 0)
508 LLVM_DEBUG(llvm::dbgs() << "- Mapped " << numMappedRegs << " regs to "
509 << regsByInput.size() << " shuffling arcs\n");
510
511 return success();
512}
513
514//===----------------------------------------------------------------------===//
515// LLHD Conversion
516//===----------------------------------------------------------------------===//
517
518/// `llhd.combinational` -> `arc.execute`
519static LogicalResult convert(llhd::CombinationalOp op,
520 llhd::CombinationalOp::Adaptor adaptor,
521 ConversionPatternRewriter &rewriter,
522 const TypeConverter &converter) {
523 // Convert the result types.
524 SmallVector<Type> resultTypes;
525 if (failed(converter.convertTypes(op.getResultTypes(), resultTypes)))
526 return failure();
527
528 // Collect the SSA values defined outside but used inside the body region.
529 auto cloneIntoBody = [](Operation *op) {
530 return op->hasTrait<OpTrait::ConstantLike>();
531 };
532 auto operands =
533 mlir::makeRegionIsolatedFromAbove(rewriter, op.getBody(), cloneIntoBody);
534
535 // Create a replacement `arc.execute` op.
536 auto executeOp =
537 ExecuteOp::create(rewriter, op.getLoc(), resultTypes, operands);
538 executeOp.getBody().takeBody(op.getBody());
539 rewriter.replaceOp(op, executeOp.getResults());
540 return success();
541}
542
543/// `llhd.yield` -> `arc.output`
544static LogicalResult convert(llhd::YieldOp op, llhd::YieldOp::Adaptor adaptor,
545 ConversionPatternRewriter &rewriter) {
546 rewriter.replaceOpWithNewOp<arc::OutputOp>(op, adaptor.getOperands());
547 return success();
548}
549
550//===----------------------------------------------------------------------===//
551// Pass Infrastructure
552//===----------------------------------------------------------------------===//
553
554namespace circt {
555#define GEN_PASS_DEF_CONVERTTOARCSPASS
556#include "circt/Conversion/Passes.h.inc"
557} // namespace circt
558
559namespace {
560struct ConvertToArcsPass
561 : public circt::impl::ConvertToArcsPassBase<ConvertToArcsPass> {
562 using ConvertToArcsPassBase::ConvertToArcsPassBase;
563 void runOnOperation() override;
564};
565} // namespace
566
567void ConvertToArcsPass::runOnOperation() {
568 // Setup the type conversion.
569 TypeConverter converter;
570
571 // Define legal types.
572 converter.addConversion([](Type type) -> std::optional<Type> {
573 if (isa<llhd::LLHDDialect>(type.getDialect()))
574 return std::nullopt;
575 return type;
576 });
577
578 // Gather the conversion patterns.
579 ConversionPatternSet patterns(&getContext(), converter);
580 patterns.add<llhd::CombinationalOp>(convert);
581 patterns.add<llhd::YieldOp>(convert);
582
583 // Setup the legal ops. (Sort alphabetically.)
584 ConversionTarget target(getContext());
585 target.addIllegalDialect<llhd::LLHDDialect>();
586 target.markUnknownOpDynamicallyLegal(
587 [](Operation *op) { return !isa<llhd::LLHDDialect>(op->getDialect()); });
588
589 // Disable pattern rollback to use the faster one-shot dialect conversion.
590 ConversionConfig config;
591 config.allowPatternRollback = false;
592
593 // Apply the dialect conversion patterns.
594 if (failed(applyPartialConversion(getOperation(), target, std::move(patterns),
595 config))) {
596 emitError(getOperation().getLoc()) << "conversion to arcs failed";
597 return signalPassFailure();
598 }
599
600 // Outline operations into arcs.
601 Converter outliner;
602 outliner.tapRegisters = tapRegisters;
603 if (failed(outliner.run(getOperation())))
604 return signalPassFailure();
605}
assert(baseType &&"element must be base type")
static LogicalResult convertInitialValue(seq::CompRegOp reg, SmallVectorImpl< Value > &values)
static LogicalResult convert(llhd::CombinationalOp op, llhd::CombinationalOp::Adaptor adaptor, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
llhd.combinational -> arc.execute
static bool isArcBreakingOp(Operation *op)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static Block * getBodyBlock(FModuleLike mod)
Extension of RewritePatternSet that allows adding matchAndRewrite functions with op adaptors and Conv...
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:121
Definition hw.py:1
Definition seq.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21