Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
LowerComb.cpp
Go to the documentation of this file.
1//===- LowerComb.cpp - Lower some ops in comb -------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Transforms/DialectConversion.h"
13#include "llvm/ADT/TypeSwitch.h"
14
15using namespace circt;
16using namespace circt::comb;
17
18namespace circt {
19namespace comb {
20#define GEN_PASS_DEF_LOWERCOMB
21#include "circt/Dialect/Comb/Passes.h.inc"
22} // namespace comb
23} // namespace circt
24
25namespace {
26/// Lower truth tables to mux trees.
27struct TruthTableToMuxTree : public OpConversionPattern<TruthTableOp> {
28 using OpConversionPattern::OpConversionPattern;
29
30private:
31 /// Get a mux tree for `inputs` corresponding to the given truth table. Do
32 /// this recursively by dividing the table in half for each input.
33 // NOLINTNEXTLINE(misc-no-recursion)
34 Value getMux(Location loc, OpBuilder &b, Value t, Value f,
35 ArrayRef<bool> table, Operation::operand_range inputs) const {
36 assert(table.size() == (1ull << inputs.size()));
37 if (table.size() == 1)
38 return table.front() ? t : f;
39
40 size_t half = table.size() / 2;
41 Value if1 =
42 getMux(loc, b, t, f, table.drop_front(half), inputs.drop_front());
43 Value if0 =
44 getMux(loc, b, t, f, table.drop_back(half), inputs.drop_front());
45 return MuxOp::create(b, loc, inputs.front(), if1, if0, false);
46 }
47
48public:
49 LogicalResult matchAndRewrite(TruthTableOp op, OpAdaptor adaptor,
50 ConversionPatternRewriter &b) const override {
51 Location loc = op.getLoc();
52 SmallVector<bool> table(
53 llvm::map_range(op.getLookupTableAttr().getAsValueRange<IntegerAttr>(),
54 [](const APInt &a) { return !a.isZero(); }));
55 Value t =
56 hw::ConstantOp::create(b, loc, b.getIntegerAttr(b.getI1Type(), 1));
57 Value f =
58 hw::ConstantOp::create(b, loc, b.getIntegerAttr(b.getI1Type(), 0));
59
60 Value tree = getMux(loc, b, t, f, table, op.getInputs());
61 b.modifyOpInPlace(tree.getDefiningOp(), [&]() {
62 tree.getDefiningOp()->setDialectAttrs(op->getDialectAttrs());
63 });
64 b.replaceOp(op, tree);
65 return success();
66 }
67};
68} // namespace
69
70namespace {
71class LowerCombPass : public impl::LowerCombBase<LowerCombPass> {
72public:
73 using LowerCombBase::LowerCombBase;
74
75 void runOnOperation() override;
76};
77} // namespace
78
79void LowerCombPass::runOnOperation() {
80 auto module = getOperation();
81
82 ConversionTarget target(getContext());
83 RewritePatternSet patterns(&getContext());
84 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
85 target.addIllegalOp<TruthTableOp>();
86
87 patterns.add<TruthTableToMuxTree>(patterns.getContext());
88
89 if (failed(applyPartialConversion(module, target, std::move(patterns))))
90 return signalPassFailure();
91}
assert(baseType &&"element must be base type")
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1