CIRCT 22.0.0git
Loading...
Searching...
No Matches
MooreOps.cpp
Go to the documentation of this file.
1//===- MooreOps.cpp - Implement the Moore operations ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Moore dialect operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Builders.h"
19#include "llvm/ADT/APSInt.h"
20#include "llvm/ADT/SmallString.h"
21#include "llvm/ADT/TypeSwitch.h"
22
23using namespace circt;
24using namespace circt::moore;
25using namespace mlir;
26
27//===----------------------------------------------------------------------===//
28// SVModuleOp
29//===----------------------------------------------------------------------===//
30
31void SVModuleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
32 llvm::StringRef name, hw::ModuleType type) {
33 state.addAttribute(SymbolTable::getSymbolAttrName(),
34 builder.getStringAttr(name));
35 state.addAttribute(getModuleTypeAttrName(state.name), TypeAttr::get(type));
36 state.addRegion();
37}
38
39void SVModuleOp::print(OpAsmPrinter &p) {
40 p << " ";
41
42 // Print the visibility of the module.
43 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
44 if (auto visibility = (*this)->getAttrOfType<StringAttr>(visibilityAttrName))
45 p << visibility.getValue() << ' ';
46
47 p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
49 getModuleType(), {}, {});
50 p << " ";
51 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false,
52 /*printBlockTerminators=*/true);
53
54 p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(),
55 getAttributeNames());
56}
57
58ParseResult SVModuleOp::parse(OpAsmParser &parser, OperationState &result) {
59 // Parse the visibility attribute.
60 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
61
62 // Parse the module name.
63 StringAttr nameAttr;
64 if (parser.parseSymbolName(nameAttr, getSymNameAttrName(result.name),
65 result.attributes))
66 return failure();
67
68 // Parse the ports.
69 SmallVector<hw::module_like_impl::PortParse> ports;
70 TypeAttr modType;
71 if (failed(
72 hw::module_like_impl::parseModuleSignature(parser, ports, modType)))
73 return failure();
74 result.addAttribute(getModuleTypeAttrName(result.name), modType);
75
76 // Parse the attributes.
77 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
78 return failure();
79
80 // Add the entry block arguments.
81 SmallVector<OpAsmParser::Argument, 4> entryArgs;
82 for (auto &port : ports)
83 if (port.direction != hw::ModulePort::Direction::Output)
84 entryArgs.push_back(port);
85
86 // Parse the optional function body.
87 auto &bodyRegion = *result.addRegion();
88 if (parser.parseRegion(bodyRegion, entryArgs))
89 return failure();
90
91 ensureTerminator(bodyRegion, parser.getBuilder(), result.location);
92 return success();
93}
94
95void SVModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
96 mlir::OpAsmSetValueNameFn setNameFn) {
97 if (&region != &getBodyRegion())
98 return;
99 auto moduleType = getModuleType();
100 for (auto [index, arg] : llvm::enumerate(region.front().getArguments()))
101 setNameFn(arg, moduleType.getInputNameAttr(index));
102}
103
104OutputOp SVModuleOp::getOutputOp() {
105 return cast<OutputOp>(getBody()->getTerminator());
106}
107
108OperandRange SVModuleOp::getOutputs() { return getOutputOp().getOperands(); }
109
110//===----------------------------------------------------------------------===//
111// OutputOp
112//===----------------------------------------------------------------------===//
113
114LogicalResult OutputOp::verify() {
115 auto module = getParentOp();
116
117 // Check that the number of operands matches the number of output ports.
118 auto outputTypes = module.getModuleType().getOutputTypes();
119 if (outputTypes.size() != getNumOperands())
120 return emitOpError("has ")
121 << getNumOperands() << " operands, but enclosing module @"
122 << module.getSymName() << " has " << outputTypes.size()
123 << " outputs";
124
125 // Check that the operand types match the output ports.
126 for (unsigned i = 0, e = outputTypes.size(); i != e; ++i)
127 if (outputTypes[i] != getOperand(i).getType())
128 return emitOpError() << "operand " << i << " (" << getOperand(i).getType()
129 << ") does not match output type (" << outputTypes[i]
130 << ") of module @" << module.getSymName();
131
132 return success();
133}
134
135//===----------------------------------------------------------------------===//
136// InstanceOp
137//===----------------------------------------------------------------------===//
138
139LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
140 // Resolve the target symbol.
141 auto *symbol =
142 symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
143 if (!symbol)
144 return emitOpError("references unknown symbol @") << getModuleName();
145
146 // Check that the symbol is a SVModuleOp.
147 auto module = dyn_cast<SVModuleOp>(symbol);
148 if (!module)
149 return emitOpError("must reference a 'moore.module', but @")
150 << getModuleName() << " is a '" << symbol->getName() << "'";
151
152 // Check that the input ports match.
153 auto moduleType = module.getModuleType();
154 auto inputTypes = moduleType.getInputTypes();
155
156 if (inputTypes.size() != getNumOperands())
157 return emitOpError("has ")
158 << getNumOperands() << " operands, but target module @"
159 << module.getSymName() << " has " << inputTypes.size() << " inputs";
160
161 for (unsigned i = 0, e = inputTypes.size(); i != e; ++i)
162 if (inputTypes[i] != getOperand(i).getType())
163 return emitOpError() << "operand " << i << " (" << getOperand(i).getType()
164 << ") does not match input type (" << inputTypes[i]
165 << ") of module @" << module.getSymName();
166
167 // Check that the output ports match.
168 auto outputTypes = moduleType.getOutputTypes();
169
170 if (outputTypes.size() != getNumResults())
171 return emitOpError("has ")
172 << getNumOperands() << " results, but target module @"
173 << module.getSymName() << " has " << outputTypes.size()
174 << " outputs";
175
176 for (unsigned i = 0, e = outputTypes.size(); i != e; ++i)
177 if (outputTypes[i] != getResult(i).getType())
178 return emitOpError() << "result " << i << " (" << getResult(i).getType()
179 << ") does not match output type (" << outputTypes[i]
180 << ") of module @" << module.getSymName();
181
182 return success();
183}
184
185void InstanceOp::print(OpAsmPrinter &p) {
186 p << " ";
187 p.printAttributeWithoutType(getInstanceNameAttr());
188 p << " ";
189 p.printAttributeWithoutType(getModuleNameAttr());
190 printInputPortList(p, getOperation(), getInputs(), getInputs().getTypes(),
191 getInputNames());
192 p << " -> ";
193 printOutputPortList(p, getOperation(), getOutputs().getTypes(),
194 getOutputNames());
195 p.printOptionalAttrDict(getOperation()->getAttrs(), getAttributeNames());
196}
197
198ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
199 // Parse the instance name.
200 StringAttr instanceName;
201 if (parser.parseAttribute(instanceName, "instanceName", result.attributes))
202 return failure();
203
204 // Parse the module name.
205 FlatSymbolRefAttr moduleName;
206 if (parser.parseAttribute(moduleName, "moduleName", result.attributes))
207 return failure();
208
209 // Parse the input port list.
210 auto loc = parser.getCurrentLocation();
211 SmallVector<OpAsmParser::UnresolvedOperand> inputs;
212 SmallVector<Type> types;
213 ArrayAttr names;
214 if (parseInputPortList(parser, inputs, types, names))
215 return failure();
216 if (parser.resolveOperands(inputs, types, loc, result.operands))
217 return failure();
218 result.addAttribute("inputNames", names);
219
220 // Parse `->`.
221 if (parser.parseArrow())
222 return failure();
223
224 // Parse the output port list.
225 types.clear();
226 if (parseOutputPortList(parser, types, names))
227 return failure();
228 result.addAttribute("outputNames", names);
229 result.addTypes(types);
230
231 // Parse the attributes.
232 if (parser.parseOptionalAttrDict(result.attributes))
233 return failure();
234
235 return success();
236}
237
238void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
239 SmallString<32> name;
240 name += getInstanceName();
241 name += '.';
242 auto baseLen = name.size();
243
244 for (auto [result, portName] :
245 llvm::zip(getOutputs(), getOutputNames().getAsRange<StringAttr>())) {
246 if (!portName || portName.empty())
247 continue;
248 name.resize(baseLen);
249 name += portName.getValue();
250 setNameFn(result, name);
251 }
252}
253
254//===----------------------------------------------------------------------===//
255// VariableOp
256//===----------------------------------------------------------------------===//
257
258void VariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
259 if (getName() && !getName()->empty())
260 setNameFn(getResult(), *getName());
261}
262
263LogicalResult VariableOp::canonicalize(VariableOp op,
264 PatternRewriter &rewriter) {
265 // If the variable is embedded in an SSACFG region, move the initial value
266 // into an assignment immediately after the variable op. This allows the
267 // mem2reg pass which cannot handle variables with initial values.
268 auto initial = op.getInitial();
269 if (initial && mlir::mayHaveSSADominance(*op->getParentRegion())) {
270 rewriter.modifyOpInPlace(op, [&] { op.getInitialMutable().clear(); });
271 rewriter.setInsertionPointAfter(op);
272 BlockingAssignOp::create(rewriter, initial.getLoc(), op, initial);
273 return success();
274 }
275
276 // Check if the variable has one unique continuous assignment to it, all other
277 // uses are reads, and that all uses are in the same block as the variable
278 // itself.
279 auto *block = op->getBlock();
280 ContinuousAssignOp uniqueAssignOp;
281 for (auto *user : op->getUsers()) {
282 // Ensure that all users of the variable are in the same block.
283 if (user->getBlock() != block)
284 return failure();
285
286 // Ensure there is at most one unique continuous assignment to the variable.
287 if (auto assignOp = dyn_cast<ContinuousAssignOp>(user)) {
288 if (uniqueAssignOp)
289 return failure();
290 uniqueAssignOp = assignOp;
291 continue;
292 }
293
294 // Ensure all other users are reads.
295 if (!isa<ReadOp>(user))
296 return failure();
297 }
298 if (!uniqueAssignOp)
299 return failure();
300
301 // If the original variable had a name, create an `AssignedVariableOp` as a
302 // replacement. Otherwise substitute the assigned value directly.
303 Value assignedValue = uniqueAssignOp.getSrc();
304 if (auto name = op.getNameAttr(); name && !name.empty())
305 assignedValue = AssignedVariableOp::create(rewriter, op.getLoc(), name,
306 uniqueAssignOp.getSrc());
307
308 // Remove the assign op and replace all reads with the new assigned var op.
309 rewriter.eraseOp(uniqueAssignOp);
310 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
311 auto readOp = cast<ReadOp>(user);
312 rewriter.replaceOp(readOp, assignedValue);
313 }
314
315 // Remove the original variable.
316 rewriter.eraseOp(op);
317 return success();
318}
319
320SmallVector<MemorySlot> VariableOp::getPromotableSlots() {
321 // We cannot promote variables with an initial value, since that value may not
322 // dominate the location where the default value needs to be constructed.
323 if (mlir::mayBeGraphRegion(*getOperation()->getParentRegion()) ||
324 getInitial())
325 return {};
326
327 // Ensure that `getDefaultValue` can conjure up a default value for the
328 // variable's type.
329 if (!isa<PackedType>(getType().getNestedType()))
330 return {};
331
332 return {MemorySlot{getResult(), getType().getNestedType()}};
333}
334
335Value VariableOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) {
336 auto packedType = dyn_cast<PackedType>(slot.elemType);
337 if (!packedType)
338 return {};
339 auto bitWidth = packedType.getBitSize();
340 if (!bitWidth)
341 return {};
342 auto fvint = packedType.getDomain() == Domain::FourValued
343 ? FVInt::getAllX(*bitWidth)
344 : FVInt::getZero(*bitWidth);
345 Value value = ConstantOp::create(
346 builder, getLoc(),
347 IntType::get(getContext(), *bitWidth, packedType.getDomain()), fvint);
348 if (value.getType() != packedType)
349 ConversionOp::create(builder, getLoc(), packedType, value);
350 return value;
351}
352
353void VariableOp::handleBlockArgument(const MemorySlot &slot,
354 BlockArgument argument,
355 OpBuilder &builder) {}
356
357std::optional<mlir::PromotableAllocationOpInterface>
358VariableOp::handlePromotionComplete(const MemorySlot &slot, Value defaultValue,
359 OpBuilder &builder) {
360 if (defaultValue && defaultValue.use_empty())
361 defaultValue.getDefiningOp()->erase();
362 this->erase();
363 return {};
364}
365
366SmallVector<DestructurableMemorySlot> VariableOp::getDestructurableSlots() {
367 if (isa<SVModuleOp>(getOperation()->getParentOp()))
368 return {};
369 if (getInitial())
370 return {};
371
372 auto refType = getType();
373 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(refType);
374 if (!destructurable)
375 return {};
376
377 auto destructuredType = destructurable.getSubelementIndexMap();
378 if (!destructuredType)
379 return {};
380
381 return {DestructurableMemorySlot{{getResult(), refType}, *destructuredType}};
382}
383
384DenseMap<Attribute, MemorySlot> VariableOp::destructure(
385 const DestructurableMemorySlot &slot,
386 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
387 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
388 assert(slot.ptr == getResult());
389 assert(!getInitial());
390 builder.setInsertionPointAfter(*this);
391
392 auto destructurableType = cast<DestructurableTypeInterface>(getType());
393 DenseMap<Attribute, MemorySlot> slotMap;
394 for (Attribute index : usedIndices) {
395 auto elemType = cast<RefType>(destructurableType.getTypeAtIndex(index));
396 assert(elemType && "used index must exist");
397 StringAttr varName;
398 if (auto name = getName(); name && !name->empty())
399 varName = StringAttr::get(
400 getContext(), (*name) + "." + cast<StringAttr>(index).getValue());
401 auto varOp =
402 VariableOp::create(builder, getLoc(), elemType, varName, Value());
403 newAllocators.push_back(varOp);
404 slotMap.try_emplace<MemorySlot>(index, {varOp.getResult(), elemType});
405 }
406
407 return slotMap;
408}
409
410std::optional<DestructurableAllocationOpInterface>
411VariableOp::handleDestructuringComplete(const DestructurableMemorySlot &slot,
412 OpBuilder &builder) {
413 assert(slot.ptr == getResult());
414 this->erase();
415 return std::nullopt;
416}
417
418//===----------------------------------------------------------------------===//
419// NetOp
420//===----------------------------------------------------------------------===//
421
422void NetOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
423 if (getName() && !getName()->empty())
424 setNameFn(getResult(), *getName());
425}
426
427LogicalResult NetOp::canonicalize(NetOp op, PatternRewriter &rewriter) {
428 bool modified = false;
429
430 // Check if the net has one unique continuous assignment to it, and
431 // additionally if all other users are reads.
432 auto *block = op->getBlock();
433 ContinuousAssignOp uniqueAssignOp;
434 bool allUsesAreReads = true;
435 for (auto *user : op->getUsers()) {
436 // Ensure that all users of the net are in the same block.
437 if (user->getBlock() != block)
438 return failure();
439
440 // Ensure there is at most one unique continuous assignment to the net.
441 if (auto assignOp = dyn_cast<ContinuousAssignOp>(user)) {
442 if (uniqueAssignOp)
443 return failure();
444 uniqueAssignOp = assignOp;
445 continue;
446 }
447
448 // Ensure all other users are reads.
449 if (!isa<ReadOp>(user))
450 allUsesAreReads = false;
451 }
452
453 // If there was one unique assignment, and the `NetOp` does not yet have an
454 // assigned value set, fold the assignment into the net.
455 if (uniqueAssignOp && !op.getAssignment()) {
456 rewriter.modifyOpInPlace(
457 op, [&] { op.getAssignmentMutable().assign(uniqueAssignOp.getSrc()); });
458 rewriter.eraseOp(uniqueAssignOp);
459 modified = true;
460 uniqueAssignOp = {};
461 }
462
463 // If all users of the net op are reads, and any potential unique assignment
464 // has been folded into the net op itself, directly replace the reads with the
465 // net's assigned value.
466 if (!uniqueAssignOp && allUsesAreReads && op.getAssignment()) {
467 // If the original net had a name, create an `AssignedVariableOp` as a
468 // replacement. Otherwise substitute the assigned value directly.
469 auto assignedValue = op.getAssignment();
470 if (auto name = op.getNameAttr(); name && !name.empty())
471 assignedValue = AssignedVariableOp::create(rewriter, op.getLoc(), name,
472 assignedValue);
473
474 // Replace all reads with the new assigned var op and remove the original
475 // net op.
476 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
477 auto readOp = cast<ReadOp>(user);
478 rewriter.replaceOp(readOp, assignedValue);
479 }
480 rewriter.eraseOp(op);
481 modified = true;
482 }
483
484 return success(modified);
485}
486
487//===----------------------------------------------------------------------===//
488// AssignedVariableOp
489//===----------------------------------------------------------------------===//
490
491void AssignedVariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
492 if (getName() && !getName()->empty())
493 setNameFn(getResult(), *getName());
494}
495
496LogicalResult AssignedVariableOp::canonicalize(AssignedVariableOp op,
497 PatternRewriter &rewriter) {
498 // Eliminate chained variables with the same name.
499 // var(name, var(name, x)) -> var(name, x)
500 if (auto otherOp = op.getInput().getDefiningOp<AssignedVariableOp>()) {
501 if (otherOp.getNameAttr() == op.getNameAttr()) {
502 rewriter.replaceOp(op, otherOp);
503 return success();
504 }
505 }
506
507 // Eliminate variables that alias an input port of the same name.
508 if (auto blockArg = dyn_cast<BlockArgument>(op.getInput())) {
509 if (auto moduleOp =
510 dyn_cast<SVModuleOp>(blockArg.getOwner()->getParentOp())) {
511 auto moduleType = moduleOp.getModuleType();
512 auto portName = moduleType.getInputNameAttr(blockArg.getArgNumber());
513 if (portName == op.getNameAttr()) {
514 rewriter.replaceOp(op, blockArg);
515 return success();
516 }
517 }
518 }
519
520 // Eliminate variables that feed an output port of the same name.
521 for (auto &use : op->getUses()) {
522 auto *useOwner = use.getOwner();
523 if (auto outputOp = dyn_cast<OutputOp>(useOwner)) {
524 if (auto moduleOp = dyn_cast<SVModuleOp>(outputOp->getParentOp())) {
525 auto moduleType = moduleOp.getModuleType();
526 auto portName = moduleType.getOutputNameAttr(use.getOperandNumber());
527 if (portName == op.getNameAttr()) {
528 rewriter.replaceOp(op, op.getInput());
529 return success();
530 }
531 } else
532 break;
533 }
534 }
535
536 return failure();
537}
538
539//===----------------------------------------------------------------------===//
540// ConstantOp
541//===----------------------------------------------------------------------===//
542
543void ConstantOp::print(OpAsmPrinter &p) {
544 p << " ";
545 printFVInt(p, getValue());
546 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
547 p << " : ";
548 p.printStrippedAttrOrType(getType());
549}
550
551ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
552 // Parse the constant value.
553 FVInt value;
554 auto valueLoc = parser.getCurrentLocation();
555 if (parseFVInt(parser, value))
556 return failure();
557
558 // Parse any optional attributes and the `:`.
559 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
560 return failure();
561
562 // Parse the result type.
563 IntType type;
564 if (parser.parseCustomTypeWithFallback(type))
565 return failure();
566
567 // Extend or truncate the constant value to match the size of the type.
568 if (type.getWidth() > value.getBitWidth()) {
569 // sext is always safe here, even for unsigned values, because the
570 // parseOptionalInteger method will return something with a zero in the
571 // top bits if it is a positive number.
572 value = value.sext(type.getWidth());
573 } else if (type.getWidth() < value.getBitWidth()) {
574 // The parser can return an unnecessarily wide result with leading
575 // zeros. This isn't a problem, but truncating off bits is bad.
576 unsigned neededBits =
577 value.isNegative() ? value.getSignificantBits() : value.getActiveBits();
578 if (type.getWidth() < neededBits)
579 return parser.emitError(valueLoc)
580 << "value requires " << neededBits
581 << " bits, but result type only has " << type.getWidth();
582 value = value.trunc(type.getWidth());
583 }
584
585 // If the constant contains any X or Z bits, the result type must be
586 // four-valued.
587 if (value.hasUnknown() && type.getDomain() != Domain::FourValued)
588 return parser.emitError(valueLoc)
589 << "value contains X or Z bits, but result type " << type
590 << " only allows two-valued bits";
591
592 // Build the attribute and op.
593 auto attrValue = FVIntegerAttr::get(parser.getContext(), value);
594 result.addAttribute("value", attrValue);
595 result.addTypes(type);
596 return success();
597}
598
599LogicalResult ConstantOp::verify() {
600 auto attrWidth = getValue().getBitWidth();
601 auto typeWidth = getType().getWidth();
602 if (attrWidth != typeWidth)
603 return emitError("attribute width ")
604 << attrWidth << " does not match return type's width " << typeWidth;
605 return success();
606}
607
608void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
609 const FVInt &value) {
610 assert(type.getWidth() == value.getBitWidth() &&
611 "FVInt width must match type width");
612 build(builder, result, type, FVIntegerAttr::get(builder.getContext(), value));
613}
614
615void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
616 const APInt &value) {
617 assert(type.getWidth() == value.getBitWidth() &&
618 "APInt width must match type width");
619 build(builder, result, type, FVInt(value));
620}
621
622/// This builder allows construction of small signed integers like 0, 1, -1
623/// matching a specified MLIR type. This shouldn't be used for general constant
624/// folding because it only works with values that can be expressed in an
625/// `int64_t`.
626void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
627 int64_t value, bool isSigned) {
628 build(builder, result, type,
629 APInt(type.getWidth(), (uint64_t)value, isSigned));
630}
631
632OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
633 assert(adaptor.getOperands().empty() && "constant has no operands");
634 return getValueAttr();
635}
636
637//===----------------------------------------------------------------------===//
638// ConstantTimeOp
639//===----------------------------------------------------------------------===//
640
641OpFoldResult ConstantTimeOp::fold(FoldAdaptor adaptor) {
642 return getValueAttr();
643}
644
645//===----------------------------------------------------------------------===//
646// ConstantRealOp
647//===----------------------------------------------------------------------===//
648
649LogicalResult ConstantRealOp::inferReturnTypes(
650 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
651 DictionaryAttr attrs, mlir::OpaqueProperties properties,
652 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
653 ConstantRealOp::Adaptor adaptor(operands, attrs, properties);
654 results.push_back(RealType::get(
655 context, static_cast<RealWidth>(
656 adaptor.getValueAttr().getType().getIntOrFloatBitWidth())));
657 return success();
658}
659
660//===----------------------------------------------------------------------===//
661// ConcatOp
662//===----------------------------------------------------------------------===//
663
664LogicalResult ConcatOp::inferReturnTypes(
665 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
666 DictionaryAttr attrs, mlir::OpaqueProperties properties,
667 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
668 Domain domain = Domain::TwoValued;
669 unsigned width = 0;
670 for (auto operand : operands) {
671 auto type = cast<IntType>(operand.getType());
672 if (type.getDomain() == Domain::FourValued)
673 domain = Domain::FourValued;
674 width += type.getWidth();
675 }
676 results.push_back(IntType::get(context, width, domain));
677 return success();
678}
679
680//===----------------------------------------------------------------------===//
681// ConcatRefOp
682//===----------------------------------------------------------------------===//
683
684LogicalResult ConcatRefOp::inferReturnTypes(
685 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
686 DictionaryAttr attrs, mlir::OpaqueProperties properties,
687 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
688 Domain domain = Domain::TwoValued;
689 unsigned width = 0;
690 for (Value operand : operands) {
691 UnpackedType nestedType = cast<RefType>(operand.getType()).getNestedType();
692 PackedType packedType = dyn_cast<PackedType>(nestedType);
693
694 if (!packedType) {
695 return failure();
696 }
697
698 if (packedType.getDomain() == Domain::FourValued)
699 domain = Domain::FourValued;
700
701 // getBitSize() for PackedType returns an optional, so we must check it.
702 std::optional<int> bitSize = packedType.getBitSize();
703 if (!bitSize) {
704 return failure();
705 }
706 width += *bitSize;
707 }
708 results.push_back(RefType::get(IntType::get(context, width, domain)));
709 return success();
710}
711
712//===----------------------------------------------------------------------===//
713// ArrayCreateOp
714//===----------------------------------------------------------------------===//
715
716static std::pair<unsigned, UnpackedType> getArrayElements(Type type) {
717 if (auto arrayType = dyn_cast<ArrayType>(type))
718 return {arrayType.getSize(), arrayType.getElementType()};
719 if (auto arrayType = dyn_cast<UnpackedArrayType>(type))
720 return {arrayType.getSize(), arrayType.getElementType()};
721 assert(0 && "expected ArrayType or UnpackedArrayType");
722 return {};
723}
724
725LogicalResult ArrayCreateOp::verify() {
726 auto [size, elementType] = getArrayElements(getType());
727
728 // Check that the number of operands matches the array size.
729 if (getElements().size() != size)
730 return emitOpError() << "has " << getElements().size()
731 << " operands, but result type requires " << size;
732
733 // Check that the operand types match the array element type. We only need to
734 // check one of the operands, since the `SameTypeOperands` trait ensures all
735 // operands have the same type.
736 if (size > 0) {
737 auto value = getElements()[0];
738 if (value.getType() != elementType)
739 return emitOpError() << "operands have type " << value.getType()
740 << ", but array requires " << elementType;
741 }
742 return success();
743}
744
745//===----------------------------------------------------------------------===//
746// StructCreateOp
747//===----------------------------------------------------------------------===//
748
749static std::optional<uint32_t> getStructFieldIndex(Type type, StringAttr name) {
750 if (auto structType = dyn_cast<StructType>(type))
751 return structType.getFieldIndex(name);
752 if (auto structType = dyn_cast<UnpackedStructType>(type))
753 return structType.getFieldIndex(name);
754 assert(0 && "expected StructType or UnpackedStructType");
755 return {};
756}
757
758static ArrayRef<StructLikeMember> getStructMembers(Type type) {
759 if (auto structType = dyn_cast<StructType>(type))
760 return structType.getMembers();
761 if (auto structType = dyn_cast<UnpackedStructType>(type))
762 return structType.getMembers();
763 assert(0 && "expected StructType or UnpackedStructType");
764 return {};
765}
766
767static UnpackedType getStructFieldType(Type type, StringAttr name) {
768 if (auto index = getStructFieldIndex(type, name))
769 return getStructMembers(type)[*index].type;
770 return {};
771}
772
773LogicalResult StructCreateOp::verify() {
774 auto members = getStructMembers(getType());
775
776 // Check that the number of operands matches the number of struct fields.
777 if (getFields().size() != members.size())
778 return emitOpError() << "has " << getFields().size()
779 << " operands, but result type requires "
780 << members.size();
781
782 // Check that the operand types match the struct field types.
783 for (auto [index, pair] : llvm::enumerate(llvm::zip(getFields(), members))) {
784 auto [value, member] = pair;
785 if (value.getType() != member.type)
786 return emitOpError() << "operand #" << index << " has type "
787 << value.getType() << ", but struct field "
788 << member.name << " requires " << member.type;
789 }
790 return success();
791}
792
793OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
794 SmallVector<NamedAttribute> fields;
795 for (auto [member, field] :
796 llvm::zip(getStructMembers(getType()), adaptor.getFields())) {
797 if (!field)
798 return {};
799 fields.push_back(NamedAttribute(member.name, field));
800 }
801 return DictionaryAttr::get(getContext(), fields);
802}
803
804//===----------------------------------------------------------------------===//
805// StructExtractOp
806//===----------------------------------------------------------------------===//
807
808LogicalResult StructExtractOp::verify() {
809 auto type = getStructFieldType(getInput().getType(), getFieldNameAttr());
810 if (!type)
811 return emitOpError() << "extracts field " << getFieldNameAttr()
812 << " which does not exist in " << getInput().getType();
813 if (type != getType())
814 return emitOpError() << "result type " << getType()
815 << " must match struct field type " << type;
816 return success();
817}
818
819OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
820 // Extract on a constant struct input.
821 if (auto fields = dyn_cast_or_null<DictionaryAttr>(adaptor.getInput()))
822 if (auto value = fields.get(getFieldNameAttr()))
823 return value;
824
825 // extract(inject(s, "field", v), "field") -> v
826 if (auto inject = getInput().getDefiningOp<StructInjectOp>()) {
827 if (inject.getFieldNameAttr() == getFieldNameAttr())
828 return inject.getNewValue();
829 return {};
830 }
831
832 // extract(create({"field": v, ...}), "field") -> v
833 if (auto create = getInput().getDefiningOp<StructCreateOp>()) {
834 if (auto index = getStructFieldIndex(create.getType(), getFieldNameAttr()))
835 return create.getFields()[*index];
836 return {};
837 }
838
839 return {};
840}
841
842//===----------------------------------------------------------------------===//
843// StructExtractRefOp
844//===----------------------------------------------------------------------===//
845
846LogicalResult StructExtractRefOp::verify() {
847 auto type = getStructFieldType(
848 cast<RefType>(getInput().getType()).getNestedType(), getFieldNameAttr());
849 if (!type)
850 return emitOpError() << "extracts field " << getFieldNameAttr()
851 << " which does not exist in " << getInput().getType();
852 if (type != getType().getNestedType())
853 return emitOpError() << "result ref of type " << getType().getNestedType()
854 << " must match struct field type " << type;
855 return success();
856}
857
858bool StructExtractRefOp::canRewire(
859 const DestructurableMemorySlot &slot,
860 SmallPtrSetImpl<Attribute> &usedIndices,
861 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
862 const DataLayout &dataLayout) {
863 if (slot.ptr != getInput())
864 return false;
865 auto index = getFieldNameAttr();
866 if (!index || !slot.subelementTypes.contains(index))
867 return false;
868 usedIndices.insert(index);
869 return true;
870}
871
872DeletionKind
873StructExtractRefOp::rewire(const DestructurableMemorySlot &slot,
874 DenseMap<Attribute, MemorySlot> &subslots,
875 OpBuilder &builder, const DataLayout &dataLayout) {
876 auto index = getFieldNameAttr();
877 const MemorySlot &memorySlot = subslots.at(index);
878 replaceAllUsesWith(memorySlot.ptr);
879 getInputMutable().drop();
880 erase();
881 return DeletionKind::Keep;
882}
883
884//===----------------------------------------------------------------------===//
885// StructInjectOp
886//===----------------------------------------------------------------------===//
887
888LogicalResult StructInjectOp::verify() {
889 auto type = getStructFieldType(getInput().getType(), getFieldNameAttr());
890 if (!type)
891 return emitOpError() << "injects field " << getFieldNameAttr()
892 << " which does not exist in " << getInput().getType();
893 if (type != getNewValue().getType())
894 return emitOpError() << "injected value " << getNewValue().getType()
895 << " must match struct field type " << type;
896 return success();
897}
898
899OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
900 auto input = adaptor.getInput();
901 auto newValue = adaptor.getNewValue();
902 if (!input || !newValue)
903 return {};
904 NamedAttrList fields(cast<DictionaryAttr>(input));
905 fields.set(getFieldNameAttr(), newValue);
906 return fields.getDictionary(getContext());
907}
908
909LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
910 PatternRewriter &rewriter) {
911 auto members = getStructMembers(op.getType());
912
913 // Chase a chain of `struct_inject` ops, with an optional final
914 // `struct_create`, and take note of the values assigned to each field.
915 SmallPtrSet<Operation *, 4> injectOps;
916 DenseMap<StringAttr, Value> fieldValues;
917 Value input = op;
918 while (auto injectOp = input.getDefiningOp<StructInjectOp>()) {
919 if (!injectOps.insert(injectOp).second)
920 return failure();
921 fieldValues.insert({injectOp.getFieldNameAttr(), injectOp.getNewValue()});
922 input = injectOp.getInput();
923 }
924 if (auto createOp = input.getDefiningOp<StructCreateOp>())
925 for (auto [value, member] : llvm::zip(createOp.getFields(), members))
926 fieldValues.insert({member.name, value});
927
928 // If the inject chain sets all fields, canonicalize to a `struct_create`.
929 if (fieldValues.size() == members.size()) {
930 SmallVector<Value> values;
931 values.reserve(fieldValues.size());
932 for (auto member : members)
933 values.push_back(fieldValues.lookup(member.name));
934 rewriter.replaceOpWithNewOp<StructCreateOp>(op, op.getType(), values);
935 return success();
936 }
937
938 // If each inject op in the chain assigned to a unique field, there is nothing
939 // to canonicalize.
940 if (injectOps.size() == fieldValues.size())
941 return failure();
942
943 // Otherwise we can eliminate overwrites by creating new injects. The hash map
944 // of field values contains the last assigned value for each field.
945 for (auto member : members)
946 if (auto value = fieldValues.lookup(member.name))
947 input = StructInjectOp::create(rewriter, op.getLoc(), op.getType(), input,
948 member.name, value);
949 rewriter.replaceOp(op, input);
950 return success();
951}
952
953//===----------------------------------------------------------------------===//
954// UnionCreateOp
955//===----------------------------------------------------------------------===//
956
957LogicalResult UnionCreateOp::verify() {
958 /// checks if the types of the input is exactly equal to the union field
959 /// type
960 return TypeSwitch<Type, LogicalResult>(getType())
961 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
962 auto members = type.getMembers();
963 auto resultType = getType();
964 auto fieldName = getFieldName();
965 for (const auto &member : members)
966 if (member.name == fieldName && member.type == resultType)
967 return success();
968 emitOpError("input type must match the union field type");
969 return failure();
970 })
971 .Default([this](auto &) {
972 emitOpError("input type must be UnionType or UnpackedUnionType");
973 return failure();
974 });
975}
976
977//===----------------------------------------------------------------------===//
978// UnionExtractOp
979//===----------------------------------------------------------------------===//
980
981LogicalResult UnionExtractOp::verify() {
982 /// checks if the types of the input is exactly equal to the one of the
983 /// types of the result union fields
984 return TypeSwitch<Type, LogicalResult>(getInput().getType())
985 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
986 auto members = type.getMembers();
987 auto fieldName = getFieldName();
988 auto resultType = getType();
989 for (const auto &member : members)
990 if (member.name == fieldName && member.type == resultType)
991 return success();
992 emitOpError("result type must match the union field type");
993 return failure();
994 })
995 .Default([this](auto &) {
996 emitOpError("input type must be UnionType or UnpackedUnionType");
997 return failure();
998 });
999}
1000
1001//===----------------------------------------------------------------------===//
1002// UnionExtractRefOp
1003//===----------------------------------------------------------------------===//
1004
1005LogicalResult UnionExtractRefOp::verify() {
1006 /// checks if the types of the result is exactly equal to the type of the
1007 /// refe union field
1008 return TypeSwitch<Type, LogicalResult>(getInput().getType().getNestedType())
1009 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
1010 auto members = type.getMembers();
1011 auto fieldName = getFieldName();
1012 auto resultType = getType().getNestedType();
1013 for (const auto &member : members)
1014 if (member.name == fieldName && member.type == resultType)
1015 return success();
1016 emitOpError("result type must match the union field type");
1017 return failure();
1018 })
1019 .Default([this](auto &) {
1020 emitOpError("input type must be UnionType or UnpackedUnionType");
1021 return failure();
1022 });
1023}
1024
1025//===----------------------------------------------------------------------===//
1026// YieldOp
1027//===----------------------------------------------------------------------===//
1028
1029LogicalResult YieldOp::verify() {
1030 // Check that YieldOp's parent operation is ConditionalOp.
1031 auto cond = dyn_cast<ConditionalOp>(*(*this).getParentOp());
1032 if (!cond) {
1033 emitOpError("must have a conditional parent");
1034 return failure();
1035 }
1036
1037 // Check that the operand matches the parent operation's result.
1038 auto condType = cond.getType();
1039 auto yieldType = getOperand().getType();
1040 if (condType != yieldType) {
1041 emitOpError("yield type must match conditional. Expected ")
1042 << condType << ", but got " << yieldType << ".";
1043 return failure();
1044 }
1045
1046 return success();
1047}
1048
1049//===----------------------------------------------------------------------===//
1050// ConversionOp
1051//===----------------------------------------------------------------------===//
1052
1053OpFoldResult ConversionOp::fold(FoldAdaptor adaptor) {
1054 // Fold away no-op casts.
1055 if (getInput().getType() == getResult().getType())
1056 return getInput();
1057
1058 // Convert domains of constant integer inputs.
1059 auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput());
1060 auto fromIntType = dyn_cast<IntType>(getInput().getType());
1061 auto toIntType = dyn_cast<IntType>(getResult().getType());
1062 if (intInput && fromIntType && toIntType &&
1063 fromIntType.getWidth() == toIntType.getWidth()) {
1064 // If we are going *to* a four-valued type, simply pass through the
1065 // constant.
1066 if (toIntType.getDomain() == Domain::FourValued)
1067 return intInput;
1068
1069 // Otherwise map all unknown bits to zero (the default in SystemVerilog) and
1070 // return a new constant.
1071 return FVIntegerAttr::get(getContext(), intInput.getValue().toAPInt(false));
1072 }
1073
1074 return {};
1075}
1076
1077//===----------------------------------------------------------------------===//
1078// LogicToIntOp
1079//===----------------------------------------------------------------------===//
1080
1081OpFoldResult LogicToIntOp::fold(FoldAdaptor adaptor) {
1082 // logic_to_int(int_to_logic(x)) -> x
1083 if (auto reverseOp = getInput().getDefiningOp<IntToLogicOp>())
1084 return reverseOp.getInput();
1085
1086 // Map all unknown bits to zero (the default in SystemVerilog) and return a
1087 // new constant.
1088 if (auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1089 return FVIntegerAttr::get(getContext(), intInput.getValue().toAPInt(false));
1090
1091 return {};
1092}
1093
1094//===----------------------------------------------------------------------===//
1095// IntToLogicOp
1096//===----------------------------------------------------------------------===//
1097
1098OpFoldResult IntToLogicOp::fold(FoldAdaptor adaptor) {
1099 // Cannot fold int_to_logic(logic_to_int(x)) -> x since that would lose
1100 // information.
1101
1102 // Simply pass through constants.
1103 if (auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1104 return intInput;
1105
1106 return {};
1107}
1108
1109//===----------------------------------------------------------------------===//
1110// TimeToLogicOp
1111//===----------------------------------------------------------------------===//
1112
1113OpFoldResult TimeToLogicOp::fold(FoldAdaptor adaptor) {
1114 // time_to_logic(logic_to_time(x)) -> x
1115 if (auto reverseOp = getInput().getDefiningOp<LogicToTimeOp>())
1116 return reverseOp.getInput();
1117
1118 // Convert constants.
1119 if (auto attr = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
1120 return FVIntegerAttr::get(getContext(), attr.getValue());
1121
1122 return {};
1123}
1124
1125//===----------------------------------------------------------------------===//
1126// LogicToTimeOp
1127//===----------------------------------------------------------------------===//
1128
1129OpFoldResult LogicToTimeOp::fold(FoldAdaptor adaptor) {
1130 // logic_to_time(time_to_logic(x)) -> x
1131 if (auto reverseOp = getInput().getDefiningOp<TimeToLogicOp>())
1132 return reverseOp.getInput();
1133
1134 // Convert constants.
1135 if (auto attr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1136 return IntegerAttr::get(getContext(), APSInt(attr.getValue().toAPInt(false),
1137 /*isUnsigned=*/true));
1138
1139 return {};
1140}
1141
1142//===----------------------------------------------------------------------===//
1143// TruncOp
1144//===----------------------------------------------------------------------===//
1145
1146OpFoldResult TruncOp::fold(FoldAdaptor adaptor) {
1147 // Truncate constants.
1148 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1149 auto width = getType().getWidth();
1150 return FVIntegerAttr::get(getContext(), intAttr.getValue().trunc(width));
1151 }
1152
1153 return {};
1154}
1155
1156//===----------------------------------------------------------------------===//
1157// ZExtOp
1158//===----------------------------------------------------------------------===//
1159
1160OpFoldResult ZExtOp::fold(FoldAdaptor adaptor) {
1161 // Zero-extend constants.
1162 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1163 auto width = getType().getWidth();
1164 return FVIntegerAttr::get(getContext(), intAttr.getValue().zext(width));
1165 }
1166
1167 return {};
1168}
1169
1170//===----------------------------------------------------------------------===//
1171// SExtOp
1172//===----------------------------------------------------------------------===//
1173
1174OpFoldResult SExtOp::fold(FoldAdaptor adaptor) {
1175 // Sign-extend constants.
1176 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1177 auto width = getType().getWidth();
1178 return FVIntegerAttr::get(getContext(), intAttr.getValue().sext(width));
1179 }
1180
1181 return {};
1182}
1183
1184//===----------------------------------------------------------------------===//
1185// BoolCastOp
1186//===----------------------------------------------------------------------===//
1187
1188OpFoldResult BoolCastOp::fold(FoldAdaptor adaptor) {
1189 // Fold away no-op casts.
1190 if (getInput().getType() == getResult().getType())
1191 return getInput();
1192 return {};
1193}
1194
1195//===----------------------------------------------------------------------===//
1196// BlockingAssignOp
1197//===----------------------------------------------------------------------===//
1198
1199bool BlockingAssignOp::loadsFrom(const MemorySlot &slot) { return false; }
1200
1201bool BlockingAssignOp::storesTo(const MemorySlot &slot) {
1202 return getDst() == slot.ptr;
1203}
1204
1205Value BlockingAssignOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1206 Value reachingDef,
1207 const DataLayout &dataLayout) {
1208 return getSrc();
1209}
1210
1211bool BlockingAssignOp::canUsesBeRemoved(
1212 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1213 SmallVectorImpl<OpOperand *> &newBlockingUses,
1214 const DataLayout &dataLayout) {
1215
1216 if (blockingUses.size() != 1)
1217 return false;
1218 Value blockingUse = (*blockingUses.begin())->get();
1219 return blockingUse == slot.ptr && getDst() == slot.ptr &&
1220 getSrc() != slot.ptr && getSrc().getType() == slot.elemType;
1221}
1222
1223DeletionKind BlockingAssignOp::removeBlockingUses(
1224 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1225 OpBuilder &builder, Value reachingDefinition,
1226 const DataLayout &dataLayout) {
1227 return DeletionKind::Delete;
1228}
1229
1230//===----------------------------------------------------------------------===//
1231// ReadOp
1232//===----------------------------------------------------------------------===//
1233
1234bool ReadOp::loadsFrom(const MemorySlot &slot) {
1235 return getInput() == slot.ptr;
1236}
1237
1238bool ReadOp::storesTo(const MemorySlot &slot) { return false; }
1239
1240Value ReadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1241 Value reachingDef, const DataLayout &dataLayout) {
1242 llvm_unreachable("getStored should not be called on ReadOp");
1243}
1244
1245bool ReadOp::canUsesBeRemoved(const MemorySlot &slot,
1246 const SmallPtrSetImpl<OpOperand *> &blockingUses,
1247 SmallVectorImpl<OpOperand *> &newBlockingUses,
1248 const DataLayout &dataLayout) {
1249
1250 if (blockingUses.size() != 1)
1251 return false;
1252 Value blockingUse = (*blockingUses.begin())->get();
1253 return blockingUse == slot.ptr && getOperand() == slot.ptr &&
1254 getResult().getType() == slot.elemType;
1255}
1256
1257DeletionKind
1258ReadOp::removeBlockingUses(const MemorySlot &slot,
1259 const SmallPtrSetImpl<OpOperand *> &blockingUses,
1260 OpBuilder &builder, Value reachingDefinition,
1261 const DataLayout &dataLayout) {
1262 getResult().replaceAllUsesWith(reachingDefinition);
1263 return DeletionKind::Delete;
1264}
1265
1266//===----------------------------------------------------------------------===//
1267// PowSOp
1268//===----------------------------------------------------------------------===//
1269
1270static OpFoldResult powCommonFolding(MLIRContext *ctxt, Attribute lhs,
1271 Attribute rhs) {
1272 auto lhsValue = dyn_cast_or_null<FVIntegerAttr>(lhs);
1273 if (lhsValue && lhsValue.getValue() == 1)
1274 return lhs;
1275
1276 auto rhsValue = dyn_cast_or_null<FVIntegerAttr>(rhs);
1277 if (rhsValue && rhsValue.getValue().isZero())
1278 return FVIntegerAttr::get(ctxt,
1279 FVInt(rhsValue.getValue().getBitWidth(), 1));
1280
1281 return {};
1282}
1283
1284OpFoldResult PowSOp::fold(FoldAdaptor adaptor) {
1285 return powCommonFolding(getContext(), adaptor.getLhs(), adaptor.getRhs());
1286}
1287
1288LogicalResult PowSOp::canonicalize(PowSOp op, PatternRewriter &rewriter) {
1289 Location loc = op.getLoc();
1290 auto intType = cast<IntType>(op.getRhs().getType());
1291 if (auto baseOp = op.getLhs().getDefiningOp<ConstantOp>()) {
1292 if (baseOp.getValue() == 2) {
1293 Value constOne = ConstantOp::create(rewriter, loc, intType, 1);
1294 Value constZero = ConstantOp::create(rewriter, loc, intType, 0);
1295 Value shift = ShlOp::create(rewriter, loc, constOne, op.getRhs());
1296 Value isNegative = SltOp::create(rewriter, loc, op.getRhs(), constZero);
1297 auto condOp = rewriter.replaceOpWithNewOp<ConditionalOp>(
1298 op, op.getLhs().getType(), isNegative);
1299 Block *thenBlock = rewriter.createBlock(&condOp.getTrueRegion());
1300 rewriter.setInsertionPointToStart(thenBlock);
1301 YieldOp::create(rewriter, loc, constZero);
1302 Block *elseBlock = rewriter.createBlock(&condOp.getFalseRegion());
1303 rewriter.setInsertionPointToStart(elseBlock);
1304 YieldOp::create(rewriter, loc, shift);
1305 return success();
1306 }
1307 }
1308
1309 return failure();
1310}
1311
1312//===----------------------------------------------------------------------===//
1313// PowUOp
1314//===----------------------------------------------------------------------===//
1315
1316OpFoldResult PowUOp::fold(FoldAdaptor adaptor) {
1317 return powCommonFolding(getContext(), adaptor.getLhs(), adaptor.getRhs());
1318}
1319
1320LogicalResult PowUOp::canonicalize(PowUOp op, PatternRewriter &rewriter) {
1321 Location loc = op.getLoc();
1322 auto intType = cast<IntType>(op.getRhs().getType());
1323 if (auto baseOp = op.getLhs().getDefiningOp<ConstantOp>()) {
1324 if (baseOp.getValue() == 2) {
1325 Value constOne = ConstantOp::create(rewriter, loc, intType, 1);
1326 rewriter.replaceOpWithNewOp<ShlOp>(op, constOne, op.getRhs());
1327 return success();
1328 }
1329 }
1330
1331 return failure();
1332}
1333
1334//===----------------------------------------------------------------------===//
1335// SubOp
1336//===----------------------------------------------------------------------===//
1337
1338OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1339 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs()))
1340 if (intAttr.getValue().isZero())
1341 return getLhs();
1342
1343 return {};
1344}
1345
1346//===----------------------------------------------------------------------===//
1347// MulOp
1348//===----------------------------------------------------------------------===//
1349
1350OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1351 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1352 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1353 if (lhs && rhs)
1354 return FVIntegerAttr::get(getContext(), lhs.getValue() * rhs.getValue());
1355 return {};
1356}
1357
1358//===----------------------------------------------------------------------===//
1359// DivUOp
1360//===----------------------------------------------------------------------===//
1361
1362OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1363 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1364 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1365 if (lhs && rhs)
1366 return FVIntegerAttr::get(getContext(),
1367 lhs.getValue().udiv(rhs.getValue()));
1368 return {};
1369}
1370
1371//===----------------------------------------------------------------------===//
1372// DivSOp
1373//===----------------------------------------------------------------------===//
1374
1375OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1376 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1377 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1378 if (lhs && rhs)
1379 return FVIntegerAttr::get(getContext(),
1380 lhs.getValue().sdiv(rhs.getValue()));
1381 return {};
1382}
1383
1384//===----------------------------------------------------------------------===//
1385// Classes
1386//===----------------------------------------------------------------------===//
1387
1388LogicalResult ClassDeclOp::verify() {
1389 mlir::Region &body = getBody();
1390 if (body.empty())
1391 return mlir::success();
1392
1393 auto &block = body.front();
1394 for (mlir::Operation &op : block) {
1395
1396 // allow only property and method decls and terminator
1397 if (llvm::isa<circt::moore::ClassPropertyDeclOp,
1398 circt::moore::ClassMethodDeclOp>(&op))
1399 continue;
1400
1401 return emitOpError()
1402 << "body may only contain 'moore.class.propertydecl' operations";
1403 }
1404 return mlir::success();
1405}
1406
1407LogicalResult ClassNewOp::verify() {
1408 // The result is constrained to ClassHandleType in ODS, so this cast should be
1409 // safe.
1410 auto handleTy = cast<ClassHandleType>(getResult().getType());
1411 mlir::SymbolRefAttr classSym = handleTy.getClassSym();
1412 if (!classSym)
1413 return emitOpError("result type is missing a class symbol");
1414
1415 // Resolve the referenced symbol starting from the nearest symbol table.
1416 mlir::Operation *sym =
1417 mlir::SymbolTable::lookupNearestSymbolFrom(getOperation(), classSym);
1418 if (!sym)
1419 return emitOpError("referenced class symbol `")
1420 << classSym << "` was not found";
1421
1422 if (!llvm::isa<ClassDeclOp>(sym))
1423 return emitOpError("symbol `")
1424 << classSym << "` does not name a `moore.class.classdecl`";
1425
1426 return mlir::success();
1427}
1428
1429void ClassNewOp::getEffects(
1430 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1431 &effects) {
1432 // Always allocates heap memory.
1433 effects.emplace_back(MemoryEffects::Allocate::get());
1434}
1435
1436LogicalResult
1437ClassUpcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1438 // 1) Type checks.
1439 auto srcTy = dyn_cast<ClassHandleType>(getOperand().getType());
1440 if (!srcTy)
1441 return emitOpError() << "operand must be !moore.class<...>; got "
1442 << getOperand().getType();
1443
1444 auto dstTy = dyn_cast<ClassHandleType>(getResult().getType());
1445 if (!dstTy)
1446 return emitOpError() << "result must be !moore.class<...>; got "
1447 << getResult().getType();
1448
1449 if (srcTy == dstTy)
1450 return success();
1451
1452 auto *op = getOperation();
1453
1454 auto *srcDeclOp =
1455 symbolTable.lookupNearestSymbolFrom(op, srcTy.getClassSym());
1456 auto *dstDeclOp =
1457 symbolTable.lookupNearestSymbolFrom(op, dstTy.getClassSym());
1458 if (!srcDeclOp || !dstDeclOp)
1459 return emitOpError() << "failed to resolve class symbol(s): src="
1460 << srcTy.getClassSym()
1461 << ", dst=" << dstTy.getClassSym();
1462
1463 auto srcDecl = dyn_cast<ClassDeclOp>(srcDeclOp);
1464 auto dstDecl = dyn_cast<ClassDeclOp>(dstDeclOp);
1465 if (!srcDecl || !dstDecl)
1466 return emitOpError()
1467 << "symbol(s) do not name `moore.class.classdecl` ops: src="
1468 << srcTy.getClassSym() << ", dst=" << dstTy.getClassSym();
1469
1470 auto cur = srcDecl;
1471 while (cur) {
1472 if (cur == dstDecl)
1473 return success(); // legal upcast: dst is src or an ancestor
1474
1475 auto baseSym = cur.getBaseAttr();
1476 if (!baseSym)
1477 break;
1478
1479 auto *baseOp = symbolTable.lookupNearestSymbolFrom(op, baseSym);
1480 cur = llvm::dyn_cast_or_null<ClassDeclOp>(baseOp);
1481 }
1482
1483 return emitOpError() << "cannot upcast from " << srcTy.getClassSym() << " to "
1484 << dstTy.getClassSym()
1485 << " (destination is not a base class)";
1486}
1487LogicalResult
1488ClassPropertyRefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1489 // The operand is constrained to ClassHandleType in ODS; unwrap it.
1490 Type instTy = getInstance().getType();
1491 auto handleTy = dyn_cast<moore::ClassHandleType>(instTy);
1492 if (!handleTy)
1493 return emitOpError() << "instance must be a !moore.class<@C> value, got "
1494 << instTy;
1495
1496 // Extract the referenced class symbol from the handle type.
1497 SymbolRefAttr classSym = handleTy.getClassSym();
1498 if (!classSym)
1499 return emitOpError("instance type is missing a class symbol");
1500
1501 // Resolve the class symbol starting from the nearest symbol table.
1502 Operation *clsSym =
1503 symbolTable.lookupNearestSymbolFrom(getOperation(), classSym);
1504 if (!clsSym)
1505 return emitOpError("referenced class symbol `")
1506 << classSym << "` was not found";
1507 auto classDecl = dyn_cast<ClassDeclOp>(clsSym);
1508 if (!classDecl)
1509 return emitOpError("symbol `")
1510 << classSym << "` does not name a `moore.class.classdecl`";
1511
1512 // Look up the field symbol inside the class declaration's symbol table.
1513 FlatSymbolRefAttr fieldSym = getPropertyAttr();
1514 if (!fieldSym)
1515 return emitOpError("missing field symbol");
1516
1517 Operation *fldSym = symbolTable.lookupSymbolIn(classDecl, fieldSym.getAttr());
1518 if (!fldSym)
1519 return emitOpError("no field `") << fieldSym << "` in class " << classSym;
1520
1521 auto fieldDecl = dyn_cast<ClassPropertyDeclOp>(fldSym);
1522 if (!fieldDecl)
1523 return emitOpError("symbol `")
1524 << fieldSym << "` is not a `moore.class.propertydecl`";
1525
1526 // Result must be !moore.ref<T> where T matches the field's declared type.
1527 auto resRefTy = cast<RefType>(getPropertyRef().getType());
1528 if (!resRefTy)
1529 return emitOpError("result must be a !moore.ref<T>");
1530
1531 Type expectedElemTy = fieldDecl.getPropertyType();
1532 if (resRefTy.getNestedType() != expectedElemTy)
1533 return emitOpError("result element type (")
1534 << resRefTy.getNestedType() << ") does not match field type ("
1535 << expectedElemTy << ")";
1536
1537 return success();
1538}
1539
1540//===----------------------------------------------------------------------===//
1541// TableGen generated logic.
1542//===----------------------------------------------------------------------===//
1543
1544// Provide the autogenerated implementation guts for the Op classes.
1545#define GET_OP_CLASSES
1546#include "circt/Dialect/Moore/Moore.cpp.inc"
1547#include "circt/Dialect/Moore/MooreEnums.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
@ Output
Definition HW.h:35
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static OpFoldResult powCommonFolding(MLIRContext *ctxt, Attribute lhs, Attribute rhs)
static ArrayRef< StructLikeMember > getStructMembers(Type type)
Definition MooreOps.cpp:758
static std::optional< uint32_t > getStructFieldIndex(Type type, StringAttr name)
Definition MooreOps.cpp:749
static UnpackedType getStructFieldType(Type type, StringAttr name)
Definition MooreOps.cpp:767
static std::pair< unsigned, UnpackedType > getArrayElements(Type type)
Definition MooreOps.cpp:716
static InstancePath empty
Four-valued arbitrary precision integers.
Definition FVInt.h:37
bool isNegative() const
Determine whether the integer interpreted as a signed number would be negative.
Definition FVInt.h:185
FVInt sext(unsigned bitWidth) const
Sign-extend the integer to a new bit width.
Definition FVInt.h:148
unsigned getSignificantBits() const
Compute the minimum bit width necessary to accurately represent this integer's value and sign.
Definition FVInt.h:102
static FVInt getAllX(unsigned numBits)
Construct an FVInt with all bits set to X.
Definition FVInt.h:75
bool hasUnknown() const
Determine if any bits are X or Z.
Definition FVInt.h:168
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition FVInt.h:92
unsigned getBitWidth() const
Return the number of bits this integer has.
Definition FVInt.h:85
FVInt trunc(unsigned bitWidth) const
Truncate the integer to a smaller bit width.
Definition FVInt.h:132
A packed SystemVerilog type.
Definition MooreTypes.h:153
std::optional< unsigned > getBitSize() const
Get the size of this type in bits.
Domain getDomain() const
Get the value domain of this type.
An unpacked SystemVerilog type.
Definition MooreTypes.h:101
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition HWOps.cpp:529
Domain
The number of values each bit of a type can assume.
Definition MooreTypes.h:49
RealWidth
The type of floating point / real number behind a RealType.
Definition MooreTypes.h:57
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
void printFVInt(AsmPrinter &p, const FVInt &value)
Print a four-valued integer using an AsmPrinter.
Definition FVInt.cpp:147
ParseResult parseFVInt(AsmParser &p, FVInt &result)
Parse a four-valued integer using an AsmParser.
Definition FVInt.cpp:162
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183