CIRCT  18.0.0git
FIRRTLOpInterfaces.cpp
Go to the documentation of this file.
1 //===- FIRRTLOpInterfaces.cpp - Implement the FIRRTL op interfaces --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the FIRRTL operation interfaces.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/BuiltinAttributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "llvm/ADT/SmallBitVector.h"
19 #include "llvm/ADT/StringRef.h"
20 
21 using namespace mlir;
22 using namespace llvm;
23 using namespace circt::firrtl;
24 
25 LogicalResult circt::firrtl::verifyModuleLikeOpInterface(FModuleLike module) {
26  // Verify port types first. This is used as the basis for the number of
27  // ports required everywhere else.
28  auto portTypes = module.getPortTypesAttr();
29  if (!portTypes || llvm::any_of(portTypes.getValue(), [](Attribute attr) {
30  return !isa<TypeAttr>(attr);
31  }))
32  return module.emitOpError("requires valid port types");
33 
34  auto numPorts = portTypes.size();
35 
36  // Verify the port dirctions.
37  auto portDirections = module.getPortDirectionsAttr();
38  if (!portDirections)
39  return module.emitOpError("requires valid port direction");
40  // TODO: bitwidth is 1 when there are no ports, since APInt previously did not
41  // support 0 bit widths.
42  auto bitWidth = portDirections.getValue().getBitWidth();
43  if (bitWidth != numPorts)
44  return module.emitOpError("requires ") << numPorts << " port directions";
45 
46  // Verify the port names.
47  auto portNames = module.getPortNamesAttr();
48  if (!portNames)
49  return module.emitOpError("requires valid port names");
50  if (portNames.size() != numPorts)
51  return module.emitOpError("requires ") << numPorts << " port names";
52  if (llvm::any_of(portNames.getValue(),
53  [](Attribute attr) { return !isa<StringAttr>(attr); }))
54  return module.emitOpError("port names should all be string attributes");
55 
56  // Verify the port annotations.
57  auto portAnnotations = module.getPortAnnotationsAttr();
58  if (!portAnnotations)
59  return module.emitOpError("requires valid port annotations");
60  // TODO: it seems weird to allow empty port annotations.
61  if (!portAnnotations.empty() && portAnnotations.size() != numPorts)
62  return module.emitOpError("requires ") << numPorts << " port annotations";
63  // TODO: Move this into an annotation verifier.
64  for (auto annos : portAnnotations.getValue()) {
65  auto arrayAttr = dyn_cast<ArrayAttr>(annos);
66  if (!arrayAttr)
67  return module.emitOpError(
68  "requires port annotations be array attributes");
69  if (llvm::any_of(arrayAttr.getValue(),
70  [](Attribute attr) { return !isa<DictionaryAttr>(attr); }))
71  return module.emitOpError(
72  "annotations must be dictionaries or subannotations");
73  }
74 
75  // Verify the port symbols.
76  auto portSymbols = module.getPortSymbolsAttr();
77  if (!portSymbols)
78  return module.emitOpError("requires valid port symbols");
79  if (!portSymbols.empty() && portSymbols.size() != numPorts)
80  return module.emitOpError("requires ") << numPorts << " port symbols";
81  if (llvm::any_of(portSymbols.getValue(), [](Attribute attr) {
82  return !attr || !isa<hw::InnerSymAttr>(attr);
83  }))
84  return module.emitOpError("port symbols should all be InnerSym attributes");
85 
86  // Verify the port locations.
87  auto portLocs = module.getPortLocationsAttr();
88  if (!portLocs)
89  return module.emitOpError("requires valid port locations");
90  if (portLocs.size() != numPorts)
91  return module.emitOpError("requires ") << numPorts << " port locations";
92  if (llvm::any_of(portLocs.getValue(), [](Attribute attr) {
93  return !attr || !isa<LocationAttr>(attr);
94  }))
95  return module.emitOpError("port symbols should all be location attributes");
96 
97  // Verify the body.
98  if (module->getNumRegions() != 1)
99  return module.emitOpError("requires one region");
100 
101  return success();
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // Forceable
106 //===----------------------------------------------------------------------===//
107 
109  Type type) {
110  auto base = dyn_cast_or_null<FIRRTLBaseType>(type);
111  // TODO: Find a way to not check same things RefType::get/verify does.
112  if (!forceable || !base || base.containsConst())
113  return {};
114  return circt::firrtl::RefType::get(base.getPassiveType(), forceable);
115 }
116 
117 LogicalResult circt::firrtl::detail::verifyForceableOp(Forceable op) {
118  bool forceable = op.isForceable();
119  auto ref = op.getDataRef();
120  if ((bool)ref != forceable)
121  return op.emitOpError("must have ref result iff marked forceable");
122  if (!forceable)
123  return success();
124  auto data = op.getDataRaw();
125  auto baseType = type_dyn_cast<FIRRTLBaseType>(data.getType());
126  if (!baseType)
127  return op.emitOpError("has data that is not a base type");
128  if (baseType.containsConst())
129  return op.emitOpError("cannot force a declaration of constant type");
130  auto expectedRefType = getForceableResultType(forceable, baseType);
131  if (ref.getType() != expectedRefType)
132  return op.emitOpError("reference result of incorrect type, found ")
133  << ref.getType() << ", expected " << expectedRefType;
134  return success();
135 }
136 
137 namespace {
138 /// Simple wrapper to allow construction from a context for local use.
139 class TrivialPatternRewriter : public PatternRewriter {
140 public:
141  explicit TrivialPatternRewriter(MLIRContext *context)
142  : PatternRewriter(context) {}
143 };
144 } // end namespace
145 
146 Forceable
147 circt::firrtl::detail::replaceWithNewForceability(Forceable op, bool forceable,
148  PatternRewriter *rewriter) {
149  if (forceable == op.isForceable())
150  return op;
151 
152  assert(op->getNumRegions() == 0);
153 
154  // Create copy of this operation with/without the forceable marker + result
155  // type.
156 
157  TrivialPatternRewriter localRewriter(op.getContext());
158  PatternRewriter &rw = rewriter ? *rewriter : localRewriter;
159 
160  // Grab the current operation's results and attributes.
161  SmallVector<Type, 8> resultTypes(op->getResultTypes());
162  SmallVector<NamedAttribute, 16> attributes(op->getAttrs());
163 
164  // Add/remove the optional ref result.
165  auto refType = firrtl::detail::getForceableResultType(true, op.getDataType());
166  if (forceable)
167  resultTypes.push_back(refType);
168  else {
169  assert(resultTypes.back() == refType &&
170  "expected forceable type as last result");
171  resultTypes.pop_back();
172  }
173 
174  // Add/remove the forceable marker.
175  auto forceableMarker =
176  rw.getNamedAttr(op.getForceableAttrName(), rw.getUnitAttr());
177  if (forceable)
178  attributes.push_back(forceableMarker);
179  else {
180  llvm::erase(attributes, forceableMarker);
181  assert(attributes.size() != op->getAttrs().size());
182  }
183 
184  // Create the replacement operation.
185  OperationState state(op.getLoc(), op->getName(), op->getOperands(),
186  resultTypes, attributes, op->getSuccessors());
187  rw.setInsertionPoint(op);
188  auto *replace = rw.create(state);
189 
190  // Dropping forceability (!forceable) -> no uses of forceable ref handle.
191  assert(forceable || op.getDataRef().use_empty());
192 
193  // Replace results.
194  for (auto result : llvm::drop_end(op->getResults(), forceable ? 0 : 1))
195  rw.replaceAllUsesWith(result, replace->getResult(result.getResultNumber()));
196  rw.eraseOp(op);
197  return cast<Forceable>(replace);
198 }
199 
200 #include "circt/Dialect/FIRRTL/FIRRTLOpInterfaces.cpp.inc"
assert(baseType &&"element must be base type")
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
LogicalResult verifyForceableOp(Forceable op)
Verify a Forceable op.
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
RefType getForceableResultType(bool forceable, Type type)
Return null or forceable reference result type.
LogicalResult verifyModuleLikeOpInterface(FModuleLike module)
Verification hook for verifying module like operations.