CIRCT 22.0.0git
Loading...
Searching...
No Matches
MooreOps.cpp
Go to the documentation of this file.
1//===- MooreOps.cpp - Implement the Moore operations ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Moore dialect operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Builders.h"
19#include "llvm/ADT/APSInt.h"
20#include "llvm/ADT/SmallString.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include <mlir/Dialect/Func/IR/FuncOps.h>
23
24using namespace circt;
25using namespace circt::moore;
26using namespace mlir;
27
28//===----------------------------------------------------------------------===//
29// SVModuleOp
30//===----------------------------------------------------------------------===//
31
32void SVModuleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
33 llvm::StringRef name, hw::ModuleType type) {
34 state.addAttribute(SymbolTable::getSymbolAttrName(),
35 builder.getStringAttr(name));
36 state.addAttribute(getModuleTypeAttrName(state.name), TypeAttr::get(type));
37 state.addRegion();
38}
39
40void SVModuleOp::print(OpAsmPrinter &p) {
41 p << " ";
42
43 // Print the visibility of the module.
44 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
45 if (auto visibility = (*this)->getAttrOfType<StringAttr>(visibilityAttrName))
46 p << visibility.getValue() << ' ';
47
48 p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
50 getModuleType(), {}, {});
51 p << " ";
52 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false,
53 /*printBlockTerminators=*/true);
54
55 p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(),
56 getAttributeNames());
57}
58
59ParseResult SVModuleOp::parse(OpAsmParser &parser, OperationState &result) {
60 // Parse the visibility attribute.
61 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
62
63 // Parse the module name.
64 StringAttr nameAttr;
65 if (parser.parseSymbolName(nameAttr, getSymNameAttrName(result.name),
66 result.attributes))
67 return failure();
68
69 // Parse the ports.
70 SmallVector<hw::module_like_impl::PortParse> ports;
71 TypeAttr modType;
72 if (failed(
73 hw::module_like_impl::parseModuleSignature(parser, ports, modType)))
74 return failure();
75 result.addAttribute(getModuleTypeAttrName(result.name), modType);
76
77 // Parse the attributes.
78 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
79 return failure();
80
81 // Add the entry block arguments.
82 SmallVector<OpAsmParser::Argument, 4> entryArgs;
83 for (auto &port : ports)
84 if (port.direction != hw::ModulePort::Direction::Output)
85 entryArgs.push_back(port);
86
87 // Parse the optional function body.
88 auto &bodyRegion = *result.addRegion();
89 if (parser.parseRegion(bodyRegion, entryArgs))
90 return failure();
91
92 ensureTerminator(bodyRegion, parser.getBuilder(), result.location);
93 return success();
94}
95
96void SVModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
97 mlir::OpAsmSetValueNameFn setNameFn) {
98 if (&region != &getBodyRegion())
99 return;
100 auto moduleType = getModuleType();
101 for (auto [index, arg] : llvm::enumerate(region.front().getArguments()))
102 setNameFn(arg, moduleType.getInputNameAttr(index));
103}
104
105OutputOp SVModuleOp::getOutputOp() {
106 return cast<OutputOp>(getBody()->getTerminator());
107}
108
109OperandRange SVModuleOp::getOutputs() { return getOutputOp().getOperands(); }
110
111//===----------------------------------------------------------------------===//
112// OutputOp
113//===----------------------------------------------------------------------===//
114
115LogicalResult OutputOp::verify() {
116 auto module = getParentOp();
117
118 // Check that the number of operands matches the number of output ports.
119 auto outputTypes = module.getModuleType().getOutputTypes();
120 if (outputTypes.size() != getNumOperands())
121 return emitOpError("has ")
122 << getNumOperands() << " operands, but enclosing module @"
123 << module.getSymName() << " has " << outputTypes.size()
124 << " outputs";
125
126 // Check that the operand types match the output ports.
127 for (unsigned i = 0, e = outputTypes.size(); i != e; ++i)
128 if (outputTypes[i] != getOperand(i).getType())
129 return emitOpError() << "operand " << i << " (" << getOperand(i).getType()
130 << ") does not match output type (" << outputTypes[i]
131 << ") of module @" << module.getSymName();
132
133 return success();
134}
135
136//===----------------------------------------------------------------------===//
137// InstanceOp
138//===----------------------------------------------------------------------===//
139
140LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
141 // Resolve the target symbol.
142 auto *symbol =
143 symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
144 if (!symbol)
145 return emitOpError("references unknown symbol @") << getModuleName();
146
147 // Check that the symbol is a SVModuleOp.
148 auto module = dyn_cast<SVModuleOp>(symbol);
149 if (!module)
150 return emitOpError("must reference a 'moore.module', but @")
151 << getModuleName() << " is a '" << symbol->getName() << "'";
152
153 // Check that the input ports match.
154 auto moduleType = module.getModuleType();
155 auto inputTypes = moduleType.getInputTypes();
156
157 if (inputTypes.size() != getNumOperands())
158 return emitOpError("has ")
159 << getNumOperands() << " operands, but target module @"
160 << module.getSymName() << " has " << inputTypes.size() << " inputs";
161
162 for (unsigned i = 0, e = inputTypes.size(); i != e; ++i)
163 if (inputTypes[i] != getOperand(i).getType())
164 return emitOpError() << "operand " << i << " (" << getOperand(i).getType()
165 << ") does not match input type (" << inputTypes[i]
166 << ") of module @" << module.getSymName();
167
168 // Check that the output ports match.
169 auto outputTypes = moduleType.getOutputTypes();
170
171 if (outputTypes.size() != getNumResults())
172 return emitOpError("has ")
173 << getNumOperands() << " results, but target module @"
174 << module.getSymName() << " has " << outputTypes.size()
175 << " outputs";
176
177 for (unsigned i = 0, e = outputTypes.size(); i != e; ++i)
178 if (outputTypes[i] != getResult(i).getType())
179 return emitOpError() << "result " << i << " (" << getResult(i).getType()
180 << ") does not match output type (" << outputTypes[i]
181 << ") of module @" << module.getSymName();
182
183 return success();
184}
185
186void InstanceOp::print(OpAsmPrinter &p) {
187 p << " ";
188 p.printAttributeWithoutType(getInstanceNameAttr());
189 p << " ";
190 p.printAttributeWithoutType(getModuleNameAttr());
191 printInputPortList(p, getOperation(), getInputs(), getInputs().getTypes(),
192 getInputNames());
193 p << " -> ";
194 printOutputPortList(p, getOperation(), getOutputs().getTypes(),
195 getOutputNames());
196 p.printOptionalAttrDict(getOperation()->getAttrs(), getAttributeNames());
197}
198
199ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
200 // Parse the instance name.
201 StringAttr instanceName;
202 if (parser.parseAttribute(instanceName, "instanceName", result.attributes))
203 return failure();
204
205 // Parse the module name.
206 FlatSymbolRefAttr moduleName;
207 if (parser.parseAttribute(moduleName, "moduleName", result.attributes))
208 return failure();
209
210 // Parse the input port list.
211 auto loc = parser.getCurrentLocation();
212 SmallVector<OpAsmParser::UnresolvedOperand> inputs;
213 SmallVector<Type> types;
214 ArrayAttr names;
215 if (parseInputPortList(parser, inputs, types, names))
216 return failure();
217 if (parser.resolveOperands(inputs, types, loc, result.operands))
218 return failure();
219 result.addAttribute("inputNames", names);
220
221 // Parse `->`.
222 if (parser.parseArrow())
223 return failure();
224
225 // Parse the output port list.
226 types.clear();
227 if (parseOutputPortList(parser, types, names))
228 return failure();
229 result.addAttribute("outputNames", names);
230 result.addTypes(types);
231
232 // Parse the attributes.
233 if (parser.parseOptionalAttrDict(result.attributes))
234 return failure();
235
236 return success();
237}
238
239void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
240 SmallString<32> name;
241 name += getInstanceName();
242 name += '.';
243 auto baseLen = name.size();
244
245 for (auto [result, portName] :
246 llvm::zip(getOutputs(), getOutputNames().getAsRange<StringAttr>())) {
247 if (!portName || portName.empty())
248 continue;
249 name.resize(baseLen);
250 name += portName.getValue();
251 setNameFn(result, name);
252 }
253}
254
255//===----------------------------------------------------------------------===//
256// VariableOp
257//===----------------------------------------------------------------------===//
258
259void VariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
260 if (getName() && !getName()->empty())
261 setNameFn(getResult(), *getName());
262}
263
264LogicalResult VariableOp::canonicalize(VariableOp op,
265 PatternRewriter &rewriter) {
266 // If the variable is embedded in an SSACFG region, move the initial value
267 // into an assignment immediately after the variable op. This allows the
268 // mem2reg pass which cannot handle variables with initial values.
269 auto initial = op.getInitial();
270 if (initial && mlir::mayHaveSSADominance(*op->getParentRegion())) {
271 rewriter.modifyOpInPlace(op, [&] { op.getInitialMutable().clear(); });
272 rewriter.setInsertionPointAfter(op);
273 BlockingAssignOp::create(rewriter, initial.getLoc(), op, initial);
274 return success();
275 }
276
277 // Check if the variable has one unique continuous assignment to it, all other
278 // uses are reads, and that all uses are in the same block as the variable
279 // itself.
280 auto *block = op->getBlock();
281 ContinuousAssignOp uniqueAssignOp;
282 for (auto *user : op->getUsers()) {
283 // Ensure that all users of the variable are in the same block.
284 if (user->getBlock() != block)
285 return failure();
286
287 // Ensure there is at most one unique continuous assignment to the variable.
288 if (auto assignOp = dyn_cast<ContinuousAssignOp>(user)) {
289 if (uniqueAssignOp)
290 return failure();
291 uniqueAssignOp = assignOp;
292 continue;
293 }
294
295 // Ensure all other users are reads.
296 if (!isa<ReadOp>(user))
297 return failure();
298 }
299 if (!uniqueAssignOp)
300 return failure();
301
302 // If the original variable had a name, create an `AssignedVariableOp` as a
303 // replacement. Otherwise substitute the assigned value directly.
304 Value assignedValue = uniqueAssignOp.getSrc();
305 if (auto name = op.getNameAttr(); name && !name.empty())
306 assignedValue = AssignedVariableOp::create(rewriter, op.getLoc(), name,
307 uniqueAssignOp.getSrc());
308
309 // Remove the assign op and replace all reads with the new assigned var op.
310 rewriter.eraseOp(uniqueAssignOp);
311 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
312 auto readOp = cast<ReadOp>(user);
313 rewriter.replaceOp(readOp, assignedValue);
314 }
315
316 // Remove the original variable.
317 rewriter.eraseOp(op);
318 return success();
319}
320
321SmallVector<MemorySlot> VariableOp::getPromotableSlots() {
322 // We cannot promote variables with an initial value, since that value may not
323 // dominate the location where the default value needs to be constructed.
324 if (mlir::mayBeGraphRegion(*getOperation()->getParentRegion()) ||
325 getInitial())
326 return {};
327
328 // Ensure that `getDefaultValue` can conjure up a default value for the
329 // variable's type.
330 auto nestedType = dyn_cast<PackedType>(getType().getNestedType());
331 if (!nestedType || !nestedType.getBitSize())
332 return {};
333
334 return {MemorySlot{getResult(), getType().getNestedType()}};
335}
336
337Value VariableOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) {
338 auto packedType = dyn_cast<PackedType>(slot.elemType);
339 if (!packedType)
340 return {};
341 auto bitWidth = packedType.getBitSize();
342 if (!bitWidth)
343 return {};
344 auto fvint = packedType.getDomain() == Domain::FourValued
345 ? FVInt::getAllX(*bitWidth)
346 : FVInt::getZero(*bitWidth);
347 Value value = ConstantOp::create(
348 builder, getLoc(),
349 IntType::get(getContext(), *bitWidth, packedType.getDomain()), fvint);
350 if (value.getType() != packedType)
351 SBVToPackedOp::create(builder, getLoc(), packedType, value);
352 return value;
353}
354
355void VariableOp::handleBlockArgument(const MemorySlot &slot,
356 BlockArgument argument,
357 OpBuilder &builder) {}
358
359std::optional<mlir::PromotableAllocationOpInterface>
360VariableOp::handlePromotionComplete(const MemorySlot &slot, Value defaultValue,
361 OpBuilder &builder) {
362 if (defaultValue && defaultValue.use_empty())
363 defaultValue.getDefiningOp()->erase();
364 this->erase();
365 return {};
366}
367
368SmallVector<DestructurableMemorySlot> VariableOp::getDestructurableSlots() {
369 if (isa<SVModuleOp>(getOperation()->getParentOp()))
370 return {};
371 if (getInitial())
372 return {};
373
374 auto refType = getType();
375 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(refType);
376 if (!destructurable)
377 return {};
378
379 auto destructuredType = destructurable.getSubelementIndexMap();
380 if (!destructuredType)
381 return {};
382
383 return {DestructurableMemorySlot{{getResult(), refType}, *destructuredType}};
384}
385
386DenseMap<Attribute, MemorySlot> VariableOp::destructure(
387 const DestructurableMemorySlot &slot,
388 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
389 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
390 assert(slot.ptr == getResult());
391 assert(!getInitial());
392 builder.setInsertionPointAfter(*this);
393
394 auto destructurableType = cast<DestructurableTypeInterface>(getType());
395 DenseMap<Attribute, MemorySlot> slotMap;
396 for (Attribute index : usedIndices) {
397 auto elemType = cast<RefType>(destructurableType.getTypeAtIndex(index));
398 assert(elemType && "used index must exist");
399 StringAttr varName;
400 if (auto name = getName(); name && !name->empty())
401 varName = StringAttr::get(
402 getContext(), (*name) + "." + cast<StringAttr>(index).getValue());
403 auto varOp =
404 VariableOp::create(builder, getLoc(), elemType, varName, Value());
405 newAllocators.push_back(varOp);
406 slotMap.try_emplace<MemorySlot>(index, {varOp.getResult(), elemType});
407 }
408
409 return slotMap;
410}
411
412std::optional<DestructurableAllocationOpInterface>
413VariableOp::handleDestructuringComplete(const DestructurableMemorySlot &slot,
414 OpBuilder &builder) {
415 assert(slot.ptr == getResult());
416 this->erase();
417 return std::nullopt;
418}
419
420//===----------------------------------------------------------------------===//
421// NetOp
422//===----------------------------------------------------------------------===//
423
424void NetOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
425 if (getName() && !getName()->empty())
426 setNameFn(getResult(), *getName());
427}
428
429LogicalResult NetOp::canonicalize(NetOp op, PatternRewriter &rewriter) {
430 bool modified = false;
431
432 // Check if the net has one unique continuous assignment to it, and
433 // additionally if all other users are reads.
434 auto *block = op->getBlock();
435 ContinuousAssignOp uniqueAssignOp;
436 bool allUsesAreReads = true;
437 for (auto *user : op->getUsers()) {
438 // Ensure that all users of the net are in the same block.
439 if (user->getBlock() != block)
440 return failure();
441
442 // Ensure there is at most one unique continuous assignment to the net.
443 if (auto assignOp = dyn_cast<ContinuousAssignOp>(user)) {
444 if (uniqueAssignOp)
445 return failure();
446 uniqueAssignOp = assignOp;
447 continue;
448 }
449
450 // Ensure all other users are reads.
451 if (!isa<ReadOp>(user))
452 allUsesAreReads = false;
453 }
454
455 // If there was one unique assignment, and the `NetOp` does not yet have an
456 // assigned value set, fold the assignment into the net.
457 if (uniqueAssignOp && !op.getAssignment()) {
458 rewriter.modifyOpInPlace(
459 op, [&] { op.getAssignmentMutable().assign(uniqueAssignOp.getSrc()); });
460 rewriter.eraseOp(uniqueAssignOp);
461 modified = true;
462 uniqueAssignOp = {};
463 }
464
465 // If all users of the net op are reads, and any potential unique assignment
466 // has been folded into the net op itself, directly replace the reads with the
467 // net's assigned value.
468 if (!uniqueAssignOp && allUsesAreReads && op.getAssignment()) {
469 // If the original net had a name, create an `AssignedVariableOp` as a
470 // replacement. Otherwise substitute the assigned value directly.
471 auto assignedValue = op.getAssignment();
472 if (auto name = op.getNameAttr(); name && !name.empty())
473 assignedValue = AssignedVariableOp::create(rewriter, op.getLoc(), name,
474 assignedValue);
475
476 // Replace all reads with the new assigned var op and remove the original
477 // net op.
478 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
479 auto readOp = cast<ReadOp>(user);
480 rewriter.replaceOp(readOp, assignedValue);
481 }
482 rewriter.eraseOp(op);
483 modified = true;
484 }
485
486 return success(modified);
487}
488
489//===----------------------------------------------------------------------===//
490// AssignedVariableOp
491//===----------------------------------------------------------------------===//
492
493void AssignedVariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
494 if (getName() && !getName()->empty())
495 setNameFn(getResult(), *getName());
496}
497
498LogicalResult AssignedVariableOp::canonicalize(AssignedVariableOp op,
499 PatternRewriter &rewriter) {
500 // Eliminate chained variables with the same name.
501 // var(name, var(name, x)) -> var(name, x)
502 if (auto otherOp = op.getInput().getDefiningOp<AssignedVariableOp>()) {
503 if (otherOp.getNameAttr() == op.getNameAttr()) {
504 rewriter.replaceOp(op, otherOp);
505 return success();
506 }
507 }
508
509 // Eliminate variables that alias an input port of the same name.
510 if (auto blockArg = dyn_cast<BlockArgument>(op.getInput())) {
511 if (auto moduleOp =
512 dyn_cast<SVModuleOp>(blockArg.getOwner()->getParentOp())) {
513 auto moduleType = moduleOp.getModuleType();
514 auto portName = moduleType.getInputNameAttr(blockArg.getArgNumber());
515 if (portName == op.getNameAttr()) {
516 rewriter.replaceOp(op, blockArg);
517 return success();
518 }
519 }
520 }
521
522 // Eliminate variables that feed an output port of the same name.
523 for (auto &use : op->getUses()) {
524 auto *useOwner = use.getOwner();
525 if (auto outputOp = dyn_cast<OutputOp>(useOwner)) {
526 if (auto moduleOp = dyn_cast<SVModuleOp>(outputOp->getParentOp())) {
527 auto moduleType = moduleOp.getModuleType();
528 auto portName = moduleType.getOutputNameAttr(use.getOperandNumber());
529 if (portName == op.getNameAttr()) {
530 rewriter.replaceOp(op, op.getInput());
531 return success();
532 }
533 } else
534 break;
535 }
536 }
537
538 return failure();
539}
540
541//===----------------------------------------------------------------------===//
542// GlobalVariableOp
543//===----------------------------------------------------------------------===//
544
545LogicalResult GlobalVariableOp::verifyRegions() {
546 if (auto *block = getInitBlock()) {
547 auto &terminator = block->back();
548 if (!isa<YieldOp>(terminator))
549 return emitOpError() << "must have a 'moore.yield' terminator";
550 }
551 return success();
552}
553
554Block *GlobalVariableOp::getInitBlock() {
555 if (getInitRegion().empty())
556 return nullptr;
557 return &getInitRegion().front();
558}
559
560//===----------------------------------------------------------------------===//
561// GetGlobalVariableOp
562//===----------------------------------------------------------------------===//
563
564LogicalResult
565GetGlobalVariableOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
566 // Resolve the target symbol.
567 auto *symbol =
568 symbolTable.lookupNearestSymbolFrom(*this, getGlobalNameAttr());
569 if (!symbol)
570 return emitOpError() << "references unknown symbol " << getGlobalNameAttr();
571
572 // Check that the symbol is a global variable.
573 auto var = dyn_cast<GlobalVariableOp>(symbol);
574 if (!var)
575 return emitOpError() << "must reference a 'moore.global_variable', but "
576 << getGlobalNameAttr() << " is a '"
577 << symbol->getName() << "'";
578
579 // Check that the types match.
580 auto expType = var.getType();
581 auto actType = getType().getNestedType();
582 if (expType != actType)
583 return emitOpError() << "returns a " << actType << " reference, but "
584 << getGlobalNameAttr() << " is of type " << expType;
585
586 return success();
587}
588
589//===----------------------------------------------------------------------===//
590// ConstantOp
591//===----------------------------------------------------------------------===//
592
593void ConstantOp::print(OpAsmPrinter &p) {
594 p << " ";
595 printFVInt(p, getValue());
596 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
597 p << " : ";
598 p.printStrippedAttrOrType(getType());
599}
600
601ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
602 // Parse the constant value.
603 FVInt value;
604 auto valueLoc = parser.getCurrentLocation();
605 if (parseFVInt(parser, value))
606 return failure();
607
608 // Parse any optional attributes and the `:`.
609 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
610 return failure();
611
612 // Parse the result type.
613 IntType type;
614 if (parser.parseCustomTypeWithFallback(type))
615 return failure();
616
617 // Extend or truncate the constant value to match the size of the type.
618 if (type.getWidth() > value.getBitWidth()) {
619 // sext is always safe here, even for unsigned values, because the
620 // parseOptionalInteger method will return something with a zero in the
621 // top bits if it is a positive number.
622 value = value.sext(type.getWidth());
623 } else if (type.getWidth() < value.getBitWidth()) {
624 // The parser can return an unnecessarily wide result with leading
625 // zeros. This isn't a problem, but truncating off bits is bad.
626 unsigned neededBits =
627 value.isNegative() ? value.getSignificantBits() : value.getActiveBits();
628 if (type.getWidth() < neededBits)
629 return parser.emitError(valueLoc)
630 << "value requires " << neededBits
631 << " bits, but result type only has " << type.getWidth();
632 value = value.trunc(type.getWidth());
633 }
634
635 // If the constant contains any X or Z bits, the result type must be
636 // four-valued.
637 if (value.hasUnknown() && type.getDomain() != Domain::FourValued)
638 return parser.emitError(valueLoc)
639 << "value contains X or Z bits, but result type " << type
640 << " only allows two-valued bits";
641
642 // Build the attribute and op.
643 auto attrValue = FVIntegerAttr::get(parser.getContext(), value);
644 result.addAttribute("value", attrValue);
645 result.addTypes(type);
646 return success();
647}
648
649LogicalResult ConstantOp::verify() {
650 auto attrWidth = getValue().getBitWidth();
651 auto typeWidth = getType().getWidth();
652 if (attrWidth != typeWidth)
653 return emitError("attribute width ")
654 << attrWidth << " does not match return type's width " << typeWidth;
655 return success();
656}
657
658void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
659 const FVInt &value) {
660 assert(type.getWidth() == value.getBitWidth() &&
661 "FVInt width must match type width");
662 build(builder, result, type, FVIntegerAttr::get(builder.getContext(), value));
663}
664
665void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
666 const APInt &value) {
667 assert(type.getWidth() == value.getBitWidth() &&
668 "APInt width must match type width");
669 build(builder, result, type, FVInt(value));
670}
671
672/// This builder allows construction of small signed integers like 0, 1, -1
673/// matching a specified MLIR type. This shouldn't be used for general constant
674/// folding because it only works with values that can be expressed in an
675/// `int64_t`.
676void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
677 int64_t value, bool isSigned) {
678 build(builder, result, type,
679 APInt(type.getWidth(), (uint64_t)value, isSigned));
680}
681
682OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
683 assert(adaptor.getOperands().empty() && "constant has no operands");
684 return getValueAttr();
685}
686
687//===----------------------------------------------------------------------===//
688// ConstantTimeOp
689//===----------------------------------------------------------------------===//
690
691OpFoldResult ConstantTimeOp::fold(FoldAdaptor adaptor) {
692 return getValueAttr();
693}
694
695//===----------------------------------------------------------------------===//
696// ConstantRealOp
697//===----------------------------------------------------------------------===//
698
699LogicalResult ConstantRealOp::inferReturnTypes(
700 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
701 DictionaryAttr attrs, mlir::OpaqueProperties properties,
702 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
703 ConstantRealOp::Adaptor adaptor(operands, attrs, properties);
704 results.push_back(RealType::get(
705 context, static_cast<RealWidth>(
706 adaptor.getValueAttr().getType().getIntOrFloatBitWidth())));
707 return success();
708}
709
710//===----------------------------------------------------------------------===//
711// ConcatOp
712//===----------------------------------------------------------------------===//
713
714LogicalResult ConcatOp::inferReturnTypes(
715 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
716 DictionaryAttr attrs, mlir::OpaqueProperties properties,
717 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
718 Domain domain = Domain::TwoValued;
719 unsigned width = 0;
720 for (auto operand : operands) {
721 auto type = cast<IntType>(operand.getType());
722 if (type.getDomain() == Domain::FourValued)
723 domain = Domain::FourValued;
724 width += type.getWidth();
725 }
726 results.push_back(IntType::get(context, width, domain));
727 return success();
728}
729
730//===----------------------------------------------------------------------===//
731// ConcatRefOp
732//===----------------------------------------------------------------------===//
733
734LogicalResult ConcatRefOp::inferReturnTypes(
735 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
736 DictionaryAttr attrs, mlir::OpaqueProperties properties,
737 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
738 Domain domain = Domain::TwoValued;
739 unsigned width = 0;
740 for (Value operand : operands) {
741 UnpackedType nestedType = cast<RefType>(operand.getType()).getNestedType();
742 PackedType packedType = dyn_cast<PackedType>(nestedType);
743
744 if (!packedType) {
745 return failure();
746 }
747
748 if (packedType.getDomain() == Domain::FourValued)
749 domain = Domain::FourValued;
750
751 // getBitSize() for PackedType returns an optional, so we must check it.
752 std::optional<int> bitSize = packedType.getBitSize();
753 if (!bitSize) {
754 return failure();
755 }
756 width += *bitSize;
757 }
758 results.push_back(RefType::get(IntType::get(context, width, domain)));
759 return success();
760}
761
762//===----------------------------------------------------------------------===//
763// ArrayCreateOp
764//===----------------------------------------------------------------------===//
765
766static std::pair<unsigned, UnpackedType> getArrayElements(Type type) {
767 if (auto arrayType = dyn_cast<ArrayType>(type))
768 return {arrayType.getSize(), arrayType.getElementType()};
769 if (auto arrayType = dyn_cast<UnpackedArrayType>(type))
770 return {arrayType.getSize(), arrayType.getElementType()};
771 assert(0 && "expected ArrayType or UnpackedArrayType");
772 return {};
773}
774
775LogicalResult ArrayCreateOp::verify() {
776 auto [size, elementType] = getArrayElements(getType());
777
778 // Check that the number of operands matches the array size.
779 if (getElements().size() != size)
780 return emitOpError() << "has " << getElements().size()
781 << " operands, but result type requires " << size;
782
783 // Check that the operand types match the array element type. We only need to
784 // check one of the operands, since the `SameTypeOperands` trait ensures all
785 // operands have the same type.
786 if (size > 0) {
787 auto value = getElements()[0];
788 if (value.getType() != elementType)
789 return emitOpError() << "operands have type " << value.getType()
790 << ", but array requires " << elementType;
791 }
792 return success();
793}
794
795//===----------------------------------------------------------------------===//
796// StructCreateOp
797//===----------------------------------------------------------------------===//
798
799static std::optional<uint32_t> getStructFieldIndex(Type type, StringAttr name) {
800 if (auto structType = dyn_cast<StructType>(type))
801 return structType.getFieldIndex(name);
802 if (auto structType = dyn_cast<UnpackedStructType>(type))
803 return structType.getFieldIndex(name);
804 assert(0 && "expected StructType or UnpackedStructType");
805 return {};
806}
807
808static ArrayRef<StructLikeMember> getStructMembers(Type type) {
809 if (auto structType = dyn_cast<StructType>(type))
810 return structType.getMembers();
811 if (auto structType = dyn_cast<UnpackedStructType>(type))
812 return structType.getMembers();
813 assert(0 && "expected StructType or UnpackedStructType");
814 return {};
815}
816
817static UnpackedType getStructFieldType(Type type, StringAttr name) {
818 if (auto index = getStructFieldIndex(type, name))
819 return getStructMembers(type)[*index].type;
820 return {};
821}
822
823LogicalResult StructCreateOp::verify() {
824 auto members = getStructMembers(getType());
825
826 // Check that the number of operands matches the number of struct fields.
827 if (getFields().size() != members.size())
828 return emitOpError() << "has " << getFields().size()
829 << " operands, but result type requires "
830 << members.size();
831
832 // Check that the operand types match the struct field types.
833 for (auto [index, pair] : llvm::enumerate(llvm::zip(getFields(), members))) {
834 auto [value, member] = pair;
835 if (value.getType() != member.type)
836 return emitOpError() << "operand #" << index << " has type "
837 << value.getType() << ", but struct field "
838 << member.name << " requires " << member.type;
839 }
840 return success();
841}
842
843OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
844 SmallVector<NamedAttribute> fields;
845 for (auto [member, field] :
846 llvm::zip(getStructMembers(getType()), adaptor.getFields())) {
847 if (!field)
848 return {};
849 fields.push_back(NamedAttribute(member.name, field));
850 }
851 return DictionaryAttr::get(getContext(), fields);
852}
853
854//===----------------------------------------------------------------------===//
855// StructExtractOp
856//===----------------------------------------------------------------------===//
857
858LogicalResult StructExtractOp::verify() {
859 auto type = getStructFieldType(getInput().getType(), getFieldNameAttr());
860 if (!type)
861 return emitOpError() << "extracts field " << getFieldNameAttr()
862 << " which does not exist in " << getInput().getType();
863 if (type != getType())
864 return emitOpError() << "result type " << getType()
865 << " must match struct field type " << type;
866 return success();
867}
868
869OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
870 // Extract on a constant struct input.
871 if (auto fields = dyn_cast_or_null<DictionaryAttr>(adaptor.getInput()))
872 if (auto value = fields.get(getFieldNameAttr()))
873 return value;
874
875 // extract(inject(s, "field", v), "field") -> v
876 if (auto inject = getInput().getDefiningOp<StructInjectOp>()) {
877 if (inject.getFieldNameAttr() == getFieldNameAttr())
878 return inject.getNewValue();
879 return {};
880 }
881
882 // extract(create({"field": v, ...}), "field") -> v
883 if (auto create = getInput().getDefiningOp<StructCreateOp>()) {
884 if (auto index = getStructFieldIndex(create.getType(), getFieldNameAttr()))
885 return create.getFields()[*index];
886 return {};
887 }
888
889 return {};
890}
891
892//===----------------------------------------------------------------------===//
893// StructExtractRefOp
894//===----------------------------------------------------------------------===//
895
896LogicalResult StructExtractRefOp::verify() {
897 auto type = getStructFieldType(
898 cast<RefType>(getInput().getType()).getNestedType(), getFieldNameAttr());
899 if (!type)
900 return emitOpError() << "extracts field " << getFieldNameAttr()
901 << " which does not exist in " << getInput().getType();
902 if (type != getType().getNestedType())
903 return emitOpError() << "result ref of type " << getType().getNestedType()
904 << " must match struct field type " << type;
905 return success();
906}
907
908bool StructExtractRefOp::canRewire(
909 const DestructurableMemorySlot &slot,
910 SmallPtrSetImpl<Attribute> &usedIndices,
911 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
912 const DataLayout &dataLayout) {
913 if (slot.ptr != getInput())
914 return false;
915 auto index = getFieldNameAttr();
916 if (!index || !slot.subelementTypes.contains(index))
917 return false;
918 usedIndices.insert(index);
919 return true;
920}
921
/// Redirect all users of this field projection to the subslot created for the
/// projected field, then remove the projection.
///
/// The op erases itself here and therefore returns `DeletionKind::Keep`, so
/// that the caller is not asked to delete it a second time.
/// NOTE(review): this relies on the SROA driver not touching the op after
/// `rewire()` returns; confirm against the driver.
DeletionKind
StructExtractRefOp::rewire(const DestructurableMemorySlot &slot,
                           DenseMap<Attribute, MemorySlot> &subslots,
                           OpBuilder &builder, const DataLayout &dataLayout) {
  // `canRewire` only accepted field names present in the slot, so `at` is
  // expected to find an entry.
  auto index = getFieldNameAttr();
  const MemorySlot &memorySlot = subslots.at(index);
  replaceAllUsesWith(memorySlot.ptr);
  // Detach from the original (now destructured) slot pointer before erasing.
  getInputMutable().drop();
  erase();
  return DeletionKind::Keep;
}
933
934//===----------------------------------------------------------------------===//
935// StructInjectOp
936//===----------------------------------------------------------------------===//
937
938LogicalResult StructInjectOp::verify() {
939 auto type = getStructFieldType(getInput().getType(), getFieldNameAttr());
940 if (!type)
941 return emitOpError() << "injects field " << getFieldNameAttr()
942 << " which does not exist in " << getInput().getType();
943 if (type != getNewValue().getType())
944 return emitOpError() << "injected value " << getNewValue().getType()
945 << " must match struct field type " << type;
946 return success();
947}
948
949OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
950 auto input = adaptor.getInput();
951 auto newValue = adaptor.getNewValue();
952 if (!input || !newValue)
953 return {};
954 NamedAttrList fields(cast<DictionaryAttr>(input));
955 fields.set(getFieldNameAttr(), newValue);
956 return fields.getDictionary(getContext());
957}
958
/// Canonicalize a chain of `struct_inject` ops, optionally rooted at a
/// `struct_create`:
///  - if the chain (plus the root create) assigns every member, replace it
///    with a single `struct_create`;
///  - otherwise, if some field is injected more than once, rebuild the chain
///    with one inject per field, keeping only the last value assigned;
///  - otherwise, fail (nothing to do).
LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
                                           PatternRewriter &rewriter) {
  auto members = getStructMembers(op.getType());

  // Chase a chain of `struct_inject` ops, with an optional final
  // `struct_create`, and take note of the values assigned to each field.
  // The walk starts at `op` (the last inject) and moves towards the root, and
  // `DenseMap::insert` does not overwrite existing keys. As a result, the
  // first value recorded for a field is the most recently injected one.
  SmallPtrSet<Operation *, 4> injectOps;
  DenseMap<StringAttr, Value> fieldValues;
  Value input = op;
  while (auto injectOp = input.getDefiningOp<StructInjectOp>()) {
    // Revisiting an inject means the chain is cyclic; give up.
    if (!injectOps.insert(injectOp).second)
      return failure();
    fieldValues.insert({injectOp.getFieldNameAttr(), injectOp.getNewValue()});
    input = injectOp.getInput();
  }
  // Fields of a root `struct_create` fill in any members not injected later.
  if (auto createOp = input.getDefiningOp<StructCreateOp>())
    for (auto [value, member] : llvm::zip(createOp.getFields(), members))
      fieldValues.insert({member.name, value});

  // If the inject chain sets all fields, canonicalize to a `struct_create`.
  if (fieldValues.size() == members.size()) {
    SmallVector<Value> values;
    values.reserve(fieldValues.size());
    for (auto member : members)
      values.push_back(fieldValues.lookup(member.name));
    rewriter.replaceOpWithNewOp<StructCreateOp>(op, op.getType(), values);
    return success();
  }

  // If each inject op in the chain assigned to a unique field, there is nothing
  // to canonicalize.
  if (injectOps.size() == fieldValues.size())
    return failure();

  // Otherwise we can eliminate overwrites by creating new injects. The hash map
  // of field values contains the last assigned value for each field. Here
  // `input` is the non-inject value at the root of the chain; new injects are
  // emitted in struct member order.
  for (auto member : members)
    if (auto value = fieldValues.lookup(member.name))
      input = StructInjectOp::create(rewriter, op.getLoc(), op.getType(), input,
                                     member.name, value);
  rewriter.replaceOp(op, input);
  return success();
}
1002
1003//===----------------------------------------------------------------------===//
1004// UnionCreateOp
1005//===----------------------------------------------------------------------===//
1006
1007LogicalResult UnionCreateOp::verify() {
1008 /// checks if the types of the input is exactly equal to the union field
1009 /// type
1010 return TypeSwitch<Type, LogicalResult>(getType())
1011 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
1012 auto members = type.getMembers();
1013 auto resultType = getType();
1014 auto fieldName = getFieldName();
1015 for (const auto &member : members)
1016 if (member.name == fieldName && member.type == resultType)
1017 return success();
1018 emitOpError("input type must match the union field type");
1019 return failure();
1020 })
1021 .Default([this](auto &) {
1022 emitOpError("input type must be UnionType or UnpackedUnionType");
1023 return failure();
1024 });
1025}
1026
1027//===----------------------------------------------------------------------===//
1028// UnionExtractOp
1029//===----------------------------------------------------------------------===//
1030
1031LogicalResult UnionExtractOp::verify() {
1032 /// checks if the types of the input is exactly equal to the one of the
1033 /// types of the result union fields
1034 return TypeSwitch<Type, LogicalResult>(getInput().getType())
1035 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
1036 auto members = type.getMembers();
1037 auto fieldName = getFieldName();
1038 auto resultType = getType();
1039 for (const auto &member : members)
1040 if (member.name == fieldName && member.type == resultType)
1041 return success();
1042 emitOpError("result type must match the union field type");
1043 return failure();
1044 })
1045 .Default([this](auto &) {
1046 emitOpError("input type must be UnionType or UnpackedUnionType");
1047 return failure();
1048 });
1049}
1050
1051//===----------------------------------------------------------------------===//
// UnionExtractRefOp
1053//===----------------------------------------------------------------------===//
1054
1055LogicalResult UnionExtractRefOp::verify() {
1056 /// checks if the types of the result is exactly equal to the type of the
1057 /// refe union field
1058 return TypeSwitch<Type, LogicalResult>(getInput().getType().getNestedType())
1059 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
1060 auto members = type.getMembers();
1061 auto fieldName = getFieldName();
1062 auto resultType = getType().getNestedType();
1063 for (const auto &member : members)
1064 if (member.name == fieldName && member.type == resultType)
1065 return success();
1066 emitOpError("result type must match the union field type");
1067 return failure();
1068 })
1069 .Default([this](auto &) {
1070 emitOpError("input type must be UnionType or UnpackedUnionType");
1071 return failure();
1072 });
1073}
1074
1075//===----------------------------------------------------------------------===//
1076// YieldOp
1077//===----------------------------------------------------------------------===//
1078
1079LogicalResult YieldOp::verify() {
1080 Type expType;
1081 auto *parentOp = getOperation()->getParentOp();
1082 if (auto cond = dyn_cast<ConditionalOp>(parentOp)) {
1083 expType = cond.getType();
1084 } else if (auto varOp = dyn_cast<GlobalVariableOp>(parentOp)) {
1085 expType = varOp.getType();
1086 } else {
1087 llvm_unreachable("all in ParentOneOf handled");
1088 }
1089
1090 auto actType = getOperand().getType();
1091 if (expType != actType) {
1092 return emitOpError() << "yields " << actType << ", but parent expects "
1093 << expType;
1094 }
1095 return success();
1096}
1097
1098//===----------------------------------------------------------------------===//
1099// ConversionOp
1100//===----------------------------------------------------------------------===//
1101
1102OpFoldResult ConversionOp::fold(FoldAdaptor adaptor) {
1103 // Fold away no-op casts.
1104 if (getInput().getType() == getResult().getType())
1105 return getInput();
1106
1107 // Convert domains of constant integer inputs.
1108 auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput());
1109 auto fromIntType = dyn_cast<IntType>(getInput().getType());
1110 auto toIntType = dyn_cast<IntType>(getResult().getType());
1111 if (intInput && fromIntType && toIntType &&
1112 fromIntType.getWidth() == toIntType.getWidth()) {
1113 // If we are going *to* a four-valued type, simply pass through the
1114 // constant.
1115 if (toIntType.getDomain() == Domain::FourValued)
1116 return intInput;
1117
1118 // Otherwise map all unknown bits to zero (the default in SystemVerilog) and
1119 // return a new constant.
1120 return FVIntegerAttr::get(getContext(), intInput.getValue().toAPInt(false));
1121 }
1122
1123 return {};
1124}
1125
1126//===----------------------------------------------------------------------===//
1127// LogicToIntOp
1128//===----------------------------------------------------------------------===//
1129
1130OpFoldResult LogicToIntOp::fold(FoldAdaptor adaptor) {
1131 // logic_to_int(int_to_logic(x)) -> x
1132 if (auto reverseOp = getInput().getDefiningOp<IntToLogicOp>())
1133 return reverseOp.getInput();
1134
1135 // Map all unknown bits to zero (the default in SystemVerilog) and return a
1136 // new constant.
1137 if (auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1138 return FVIntegerAttr::get(getContext(), intInput.getValue().toAPInt(false));
1139
1140 return {};
1141}
1142
1143//===----------------------------------------------------------------------===//
1144// IntToLogicOp
1145//===----------------------------------------------------------------------===//
1146
1147OpFoldResult IntToLogicOp::fold(FoldAdaptor adaptor) {
1148 // Cannot fold int_to_logic(logic_to_int(x)) -> x since that would lose
1149 // information.
1150
1151 // Simply pass through constants.
1152 if (auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1153 return intInput;
1154
1155 return {};
1156}
1157
1158//===----------------------------------------------------------------------===//
1159// TimeToLogicOp
1160//===----------------------------------------------------------------------===//
1161
1162OpFoldResult TimeToLogicOp::fold(FoldAdaptor adaptor) {
1163 // time_to_logic(logic_to_time(x)) -> x
1164 if (auto reverseOp = getInput().getDefiningOp<LogicToTimeOp>())
1165 return reverseOp.getInput();
1166
1167 // Convert constants.
1168 if (auto attr = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
1169 return FVIntegerAttr::get(getContext(), attr.getValue());
1170
1171 return {};
1172}
1173
1174//===----------------------------------------------------------------------===//
1175// LogicToTimeOp
1176//===----------------------------------------------------------------------===//
1177
1178OpFoldResult LogicToTimeOp::fold(FoldAdaptor adaptor) {
1179 // logic_to_time(time_to_logic(x)) -> x
1180 if (auto reverseOp = getInput().getDefiningOp<TimeToLogicOp>())
1181 return reverseOp.getInput();
1182
1183 // Convert constants.
1184 if (auto attr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1185 return IntegerAttr::get(getContext(), APSInt(attr.getValue().toAPInt(false),
1186 /*isUnsigned=*/true));
1187
1188 return {};
1189}
1190
1191//===----------------------------------------------------------------------===//
1192// ConvertRealOp
1193//===----------------------------------------------------------------------===//
1194
1195OpFoldResult ConvertRealOp::fold(FoldAdaptor adaptor) {
1196 if (getInput().getType() == getResult().getType())
1197 return getInput();
1198
1199 return {};
1200}
1201
1202//===----------------------------------------------------------------------===//
1203// TruncOp
1204//===----------------------------------------------------------------------===//
1205
1206OpFoldResult TruncOp::fold(FoldAdaptor adaptor) {
1207 // Truncate constants.
1208 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1209 auto width = getType().getWidth();
1210 return FVIntegerAttr::get(getContext(), intAttr.getValue().trunc(width));
1211 }
1212
1213 return {};
1214}
1215
1216//===----------------------------------------------------------------------===//
1217// ZExtOp
1218//===----------------------------------------------------------------------===//
1219
1220OpFoldResult ZExtOp::fold(FoldAdaptor adaptor) {
1221 // Zero-extend constants.
1222 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1223 auto width = getType().getWidth();
1224 return FVIntegerAttr::get(getContext(), intAttr.getValue().zext(width));
1225 }
1226
1227 return {};
1228}
1229
1230//===----------------------------------------------------------------------===//
1231// SExtOp
1232//===----------------------------------------------------------------------===//
1233
1234OpFoldResult SExtOp::fold(FoldAdaptor adaptor) {
1235 // Sign-extend constants.
1236 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1237 auto width = getType().getWidth();
1238 return FVIntegerAttr::get(getContext(), intAttr.getValue().sext(width));
1239 }
1240
1241 return {};
1242}
1243
1244//===----------------------------------------------------------------------===//
1245// BoolCastOp
1246//===----------------------------------------------------------------------===//
1247
1248OpFoldResult BoolCastOp::fold(FoldAdaptor adaptor) {
1249 // Fold away no-op casts.
1250 if (getInput().getType() == getResult().getType())
1251 return getInput();
1252 return {};
1253}
1254
1255//===----------------------------------------------------------------------===//
1256// BlockingAssignOp
1257//===----------------------------------------------------------------------===//
1258
1259bool BlockingAssignOp::loadsFrom(const MemorySlot &slot) { return false; }
1260
1261bool BlockingAssignOp::storesTo(const MemorySlot &slot) {
1262 return getDst() == slot.ptr;
1263}
1264
1265Value BlockingAssignOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1266 Value reachingDef,
1267 const DataLayout &dataLayout) {
1268 return getSrc();
1269}
1270
1271bool BlockingAssignOp::canUsesBeRemoved(
1272 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1273 SmallVectorImpl<OpOperand *> &newBlockingUses,
1274 const DataLayout &dataLayout) {
1275
1276 if (blockingUses.size() != 1)
1277 return false;
1278 Value blockingUse = (*blockingUses.begin())->get();
1279 return blockingUse == slot.ptr && getDst() == slot.ptr &&
1280 getSrc() != slot.ptr && getSrc().getType() == slot.elemType;
1281}
1282
1283DeletionKind BlockingAssignOp::removeBlockingUses(
1284 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1285 OpBuilder &builder, Value reachingDefinition,
1286 const DataLayout &dataLayout) {
1287 return DeletionKind::Delete;
1288}
1289
1290//===----------------------------------------------------------------------===//
1291// ReadOp
1292//===----------------------------------------------------------------------===//
1293
1294bool ReadOp::loadsFrom(const MemorySlot &slot) {
1295 return getInput() == slot.ptr;
1296}
1297
1298bool ReadOp::storesTo(const MemorySlot &slot) { return false; }
1299
1300Value ReadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1301 Value reachingDef, const DataLayout &dataLayout) {
1302 llvm_unreachable("getStored should not be called on ReadOp");
1303}
1304
1305bool ReadOp::canUsesBeRemoved(const MemorySlot &slot,
1306 const SmallPtrSetImpl<OpOperand *> &blockingUses,
1307 SmallVectorImpl<OpOperand *> &newBlockingUses,
1308 const DataLayout &dataLayout) {
1309
1310 if (blockingUses.size() != 1)
1311 return false;
1312 Value blockingUse = (*blockingUses.begin())->get();
1313 return blockingUse == slot.ptr && getOperand() == slot.ptr &&
1314 getResult().getType() == slot.elemType;
1315}
1316
1317DeletionKind
1318ReadOp::removeBlockingUses(const MemorySlot &slot,
1319 const SmallPtrSetImpl<OpOperand *> &blockingUses,
1320 OpBuilder &builder, Value reachingDefinition,
1321 const DataLayout &dataLayout) {
1322 getResult().replaceAllUsesWith(reachingDefinition);
1323 return DeletionKind::Delete;
1324}
1325
1326//===----------------------------------------------------------------------===//
1327// PowSOp
1328//===----------------------------------------------------------------------===//
1329
1330static OpFoldResult powCommonFolding(MLIRContext *ctxt, Attribute lhs,
1331 Attribute rhs) {
1332 auto lhsValue = dyn_cast_or_null<FVIntegerAttr>(lhs);
1333 if (lhsValue && lhsValue.getValue() == 1)
1334 return lhs;
1335
1336 auto rhsValue = dyn_cast_or_null<FVIntegerAttr>(rhs);
1337 if (rhsValue && rhsValue.getValue().isZero())
1338 return FVIntegerAttr::get(ctxt,
1339 FVInt(rhsValue.getValue().getBitWidth(), 1));
1340
1341 return {};
1342}
1343
1344OpFoldResult PowSOp::fold(FoldAdaptor adaptor) {
1345 return powCommonFolding(getContext(), adaptor.getLhs(), adaptor.getRhs());
1346}
1347
1348LogicalResult PowSOp::canonicalize(PowSOp op, PatternRewriter &rewriter) {
1349 Location loc = op.getLoc();
1350 auto intType = cast<IntType>(op.getRhs().getType());
1351 if (auto baseOp = op.getLhs().getDefiningOp<ConstantOp>()) {
1352 if (baseOp.getValue() == 2) {
1353 Value constOne = ConstantOp::create(rewriter, loc, intType, 1);
1354 Value constZero = ConstantOp::create(rewriter, loc, intType, 0);
1355 Value shift = ShlOp::create(rewriter, loc, constOne, op.getRhs());
1356 Value isNegative = SltOp::create(rewriter, loc, op.getRhs(), constZero);
1357 auto condOp = rewriter.replaceOpWithNewOp<ConditionalOp>(
1358 op, op.getLhs().getType(), isNegative);
1359 Block *thenBlock = rewriter.createBlock(&condOp.getTrueRegion());
1360 rewriter.setInsertionPointToStart(thenBlock);
1361 YieldOp::create(rewriter, loc, constZero);
1362 Block *elseBlock = rewriter.createBlock(&condOp.getFalseRegion());
1363 rewriter.setInsertionPointToStart(elseBlock);
1364 YieldOp::create(rewriter, loc, shift);
1365 return success();
1366 }
1367 }
1368
1369 return failure();
1370}
1371
1372//===----------------------------------------------------------------------===//
1373// PowUOp
1374//===----------------------------------------------------------------------===//
1375
1376OpFoldResult PowUOp::fold(FoldAdaptor adaptor) {
1377 return powCommonFolding(getContext(), adaptor.getLhs(), adaptor.getRhs());
1378}
1379
1380LogicalResult PowUOp::canonicalize(PowUOp op, PatternRewriter &rewriter) {
1381 Location loc = op.getLoc();
1382 auto intType = cast<IntType>(op.getRhs().getType());
1383 if (auto baseOp = op.getLhs().getDefiningOp<ConstantOp>()) {
1384 if (baseOp.getValue() == 2) {
1385 Value constOne = ConstantOp::create(rewriter, loc, intType, 1);
1386 rewriter.replaceOpWithNewOp<ShlOp>(op, constOne, op.getRhs());
1387 return success();
1388 }
1389 }
1390
1391 return failure();
1392}
1393
1394//===----------------------------------------------------------------------===//
1395// SubOp
1396//===----------------------------------------------------------------------===//
1397
1398OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1399 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs()))
1400 if (intAttr.getValue().isZero())
1401 return getLhs();
1402
1403 return {};
1404}
1405
1406//===----------------------------------------------------------------------===//
1407// MulOp
1408//===----------------------------------------------------------------------===//
1409
1410OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1411 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1412 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1413 if (lhs && rhs)
1414 return FVIntegerAttr::get(getContext(), lhs.getValue() * rhs.getValue());
1415 return {};
1416}
1417
1418//===----------------------------------------------------------------------===//
1419// DivUOp
1420//===----------------------------------------------------------------------===//
1421
1422OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1423 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1424 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1425 if (lhs && rhs)
1426 return FVIntegerAttr::get(getContext(),
1427 lhs.getValue().udiv(rhs.getValue()));
1428 return {};
1429}
1430
1431//===----------------------------------------------------------------------===//
1432// DivSOp
1433//===----------------------------------------------------------------------===//
1434
1435OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1436 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1437 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1438 if (lhs && rhs)
1439 return FVIntegerAttr::get(getContext(),
1440 lhs.getValue().sdiv(rhs.getValue()));
1441 return {};
1442}
1443
1444//===----------------------------------------------------------------------===//
1445// Classes
1446//===----------------------------------------------------------------------===//
1447
1448LogicalResult ClassDeclOp::verify() {
1449 mlir::Region &body = getBody();
1450 if (body.empty())
1451 return mlir::success();
1452
1453 auto &block = body.front();
1454 for (mlir::Operation &op : block) {
1455
1456 // allow only property and method decls and terminator
1457 if (llvm::isa<circt::moore::ClassPropertyDeclOp,
1458 circt::moore::ClassMethodDeclOp>(&op))
1459 continue;
1460
1461 return emitOpError()
1462 << "body may only contain 'moore.class.propertydecl' operations";
1463 }
1464 return mlir::success();
1465}
1466
1467LogicalResult ClassNewOp::verify() {
1468 // The result is constrained to ClassHandleType in ODS, so this cast should be
1469 // safe.
1470 auto handleTy = cast<ClassHandleType>(getResult().getType());
1471 mlir::SymbolRefAttr classSym = handleTy.getClassSym();
1472 if (!classSym)
1473 return emitOpError("result type is missing a class symbol");
1474
1475 // Resolve the referenced symbol starting from the nearest symbol table.
1476 mlir::Operation *sym =
1477 mlir::SymbolTable::lookupNearestSymbolFrom(getOperation(), classSym);
1478 if (!sym)
1479 return emitOpError("referenced class symbol `")
1480 << classSym << "` was not found";
1481
1482 if (!llvm::isa<ClassDeclOp>(sym))
1483 return emitOpError("symbol `")
1484 << classSym << "` does not name a `moore.class.classdecl`";
1485
1486 return mlir::success();
1487}
1488
1489void ClassNewOp::getEffects(
1490 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1491 &effects) {
1492 // Always allocates heap memory.
1493 effects.emplace_back(MemoryEffects::Allocate::get());
1494}
1495
1496LogicalResult
1497ClassUpcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1498 // 1) Type checks.
1499 auto srcTy = dyn_cast<ClassHandleType>(getOperand().getType());
1500 if (!srcTy)
1501 return emitOpError() << "operand must be !moore.class<...>; got "
1502 << getOperand().getType();
1503
1504 auto dstTy = dyn_cast<ClassHandleType>(getResult().getType());
1505 if (!dstTy)
1506 return emitOpError() << "result must be !moore.class<...>; got "
1507 << getResult().getType();
1508
1509 if (srcTy == dstTy)
1510 return success();
1511
1512 auto *op = getOperation();
1513
1514 auto *srcDeclOp =
1515 symbolTable.lookupNearestSymbolFrom(op, srcTy.getClassSym());
1516 auto *dstDeclOp =
1517 symbolTable.lookupNearestSymbolFrom(op, dstTy.getClassSym());
1518 if (!srcDeclOp || !dstDeclOp)
1519 return emitOpError() << "failed to resolve class symbol(s): src="
1520 << srcTy.getClassSym()
1521 << ", dst=" << dstTy.getClassSym();
1522
1523 auto srcDecl = dyn_cast<ClassDeclOp>(srcDeclOp);
1524 auto dstDecl = dyn_cast<ClassDeclOp>(dstDeclOp);
1525 if (!srcDecl || !dstDecl)
1526 return emitOpError()
1527 << "symbol(s) do not name `moore.class.classdecl` ops: src="
1528 << srcTy.getClassSym() << ", dst=" << dstTy.getClassSym();
1529
1530 auto cur = srcDecl;
1531 while (cur) {
1532 if (cur == dstDecl)
1533 return success(); // legal upcast: dst is src or an ancestor
1534
1535 auto baseSym = cur.getBaseAttr();
1536 if (!baseSym)
1537 break;
1538
1539 auto *baseOp = symbolTable.lookupNearestSymbolFrom(op, baseSym);
1540 cur = llvm::dyn_cast_or_null<ClassDeclOp>(baseOp);
1541 }
1542
1543 return emitOpError() << "cannot upcast from " << srcTy.getClassSym() << " to "
1544 << dstTy.getClassSym()
1545 << " (destination is not a base class)";
1546}
1547
1548LogicalResult
1549ClassPropertyRefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1550 // The operand is constrained to ClassHandleType in ODS; unwrap it.
1551 Type instTy = getInstance().getType();
1552 auto handleTy = dyn_cast<moore::ClassHandleType>(instTy);
1553 if (!handleTy)
1554 return emitOpError() << "instance must be a !moore.class<@C> value, got "
1555 << instTy;
1556
1557 // Extract the referenced class symbol from the handle type.
1558 SymbolRefAttr classSym = handleTy.getClassSym();
1559 if (!classSym)
1560 return emitOpError("instance type is missing a class symbol");
1561
1562 // Resolve the class symbol starting from the nearest symbol table.
1563 Operation *clsSym =
1564 symbolTable.lookupNearestSymbolFrom(getOperation(), classSym);
1565 if (!clsSym)
1566 return emitOpError("referenced class symbol `")
1567 << classSym << "` was not found";
1568 auto classDecl = dyn_cast<ClassDeclOp>(clsSym);
1569 if (!classDecl)
1570 return emitOpError("symbol `")
1571 << classSym << "` does not name a `moore.class.classdecl`";
1572
1573 // Look up the field symbol inside the class declaration's symbol table.
1574 FlatSymbolRefAttr fieldSym = getPropertyAttr();
1575 if (!fieldSym)
1576 return emitOpError("missing field symbol");
1577
1578 Operation *fldSym = symbolTable.lookupSymbolIn(classDecl, fieldSym.getAttr());
1579 if (!fldSym)
1580 return emitOpError("no field `") << fieldSym << "` in class " << classSym;
1581
1582 auto fieldDecl = dyn_cast<ClassPropertyDeclOp>(fldSym);
1583 if (!fieldDecl)
1584 return emitOpError("symbol `")
1585 << fieldSym << "` is not a `moore.class.propertydecl`";
1586
1587 // Result must be !moore.ref<T> where T matches the field's declared type.
1588 auto resRefTy = cast<RefType>(getPropertyRef().getType());
1589 if (!resRefTy)
1590 return emitOpError("result must be a !moore.ref<T>");
1591
1592 Type expectedElemTy = fieldDecl.getPropertyType();
1593 if (resRefTy.getNestedType() != expectedElemTy)
1594 return emitOpError("result element type (")
1595 << resRefTy.getNestedType() << ") does not match field type ("
1596 << expectedElemTy << ")";
1597
1598 return success();
1599}
1600
1601LogicalResult
1602VTableLoadMethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1603 Operation *op = getOperation();
1604
1605 auto object = getObject();
1606 auto implSym = object.getType().getClassSym();
1607
1608 // Check that classdecl of class handle exists
1609 Operation *implOp = symbolTable.lookupNearestSymbolFrom(op, implSym);
1610 if (!implOp)
1611 return emitOpError() << "implementing class " << implSym << " not found";
1612 auto implClass = cast<moore::ClassDeclOp>(implOp);
1613
1614 StringAttr methodName = getMethodSymAttr().getLeafReference();
1615 if (!methodName || methodName.getValue().empty())
1616 return emitOpError() << "empty method name";
1617
1618 moore::ClassDeclOp cursor = implClass;
1619 Operation *methodDeclOp = nullptr;
1620
1621 // Find method in class decl or parents' class decl
1622 while (cursor && !methodDeclOp) {
1623 methodDeclOp = symbolTable.lookupSymbolIn(cursor, methodName);
1624 if (methodDeclOp)
1625 break;
1626 SymbolRefAttr baseSym = cursor.getBaseAttr();
1627 if (!baseSym)
1628 break;
1629 Operation *baseOp = symbolTable.lookupNearestSymbolFrom(op, baseSym);
1630 cursor = baseOp ? cast<moore::ClassDeclOp>(baseOp) : moore::ClassDeclOp();
1631 }
1632
1633 if (!methodDeclOp)
1634 return emitOpError() << "no method `" << methodName << "` found in "
1635 << implClass.getSymName() << " or its bases";
1636
1637 // Make sure method decl is a ClassMethodDeclOp
1638 auto methodDecl = dyn_cast<moore::ClassMethodDeclOp>(methodDeclOp);
1639 if (!methodDecl)
1640 return emitOpError() << "`" << methodName
1641 << "` is not a method declaration";
1642
1643 // Make sure method signature matches
1644 auto resFnTy = cast<FunctionType>(getResult().getType());
1645 auto declFnTy = cast<FunctionType>(methodDecl.getFunctionType());
1646 if (resFnTy != declFnTy)
1647 return emitOpError() << "result type " << resFnTy
1648 << " does not match method erased ABI " << declFnTy;
1649
1650 return success();
1651}
1652
1653LogicalResult VTableOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1654 Operation *self = getOperation();
1655
1656 // sym_name's root must be a ClassDeclOp
1657 SymbolRefAttr name = getSymNameAttr();
1658 if (!name)
1659 return emitOpError("requires 'sym_name' SymbolRefAttr");
1660
1661 // Root symbol must resolve (from the nearest symbol table) to a ClassDeclOp.
1662 Operation *rootDef = symbolTable.lookupNearestSymbolFrom(
1663 self, SymbolRefAttr::get(name.getRootReference()));
1664 if (!rootDef)
1665 return emitOpError() << "cannot resolve root class symbol '"
1666 << name.getRootReference() << "' for sym_name "
1667 << name;
1668
1669 if (!isa<ClassDeclOp>(rootDef))
1670 return emitOpError()
1671 << "root of sym_name must name a 'moore.class.classdecl', got "
1672 << name;
1673
1674 // All good.
1675 return success();
1676}
1677
1678LogicalResult VTableOp::verifyRegions() {
1679 // Ensure only allowed ops appear inside.
1680 for (Operation &op : getBody().front()) {
1681 if (!isa<VTableOp, VTableEntryOp>(op))
1682 return emitOpError(
1683 "body may only contain 'moore.vtable' or 'moore.vtable_entry' ops");
1684 }
1685 return mlir::success();
1686}
1687
1688LogicalResult
1689VTableEntryOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1690 Operation *self = getOperation();
1691
1692 // 'target' must exist and resolve from the top-level symbol table of a func
1693 // op
1694 SymbolRefAttr target = getTargetAttr();
1695 func::FuncOp def =
1696 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(self, target);
1697 if (!def)
1698 return emitOpError()
1699 << "cannot resolve target symbol to a function operation " << target;
1700
1701 // VTableEntries may only exist in VTables.
1702 if (!isa<VTableOp>(self->getParentOp()))
1703 return emitOpError("must be nested directly inside a 'moore.vtable' op");
1704
1705 Operation *currentOp = self;
1706 VTableOp currentVTable;
1707 bool defined = false;
1708
1709 // Walk up the VTable tree and check whether the corresponding classDeclOp
1710 // declares a method with the same implementation. Further checks all the way
1711 // up the tree if another classdeclop overrides the implementation.
1712 // The entry is correct iff the impl matches the most derived classdeclop's
1713 // methoddeclop implementing the virtual method.
1714 while (auto parentOp = dyn_cast<VTableOp>(currentOp->getParentOp())) {
1715 currentOp = parentOp;
1716 currentVTable = cast<VTableOp>(currentOp);
1717
1718 auto classSymName = currentVTable.getSymName();
1719 ClassDeclOp parentClassDecl =
1720 symbolTable.lookupNearestSymbolFrom<ClassDeclOp>(
1721 parentOp, classSymName.getRootReference());
1722 assert(parentClassDecl && "VTableOp must point to a classdeclop");
1723
1724 for (auto method : parentClassDecl.getBody().getOps<ClassMethodDeclOp>()) {
1725 // A virtual interface declaration. Ignore.
1726 if (!method.getImpl())
1727 continue;
1728
1729 // A matching definition.
1730 if (method.getSymName() == getName() && method.getImplAttr() == target)
1731 defined = true;
1732
1733 // All definitions of the same method up the tree must be the same as the
1734 // current definition, there is no shadowing.
1735 // Hence, if we encounter a methoddeclop that has the same name but a
1736 // different implementation that means this vtableentry should point to
1737 // the op's implementation - that's an error.
1738 else if (method.getSymName() == getName() &&
1739 method.getImplAttr() != target && defined)
1740 return emitOpError() << "Target " << target
1741 << " should be overridden by " << classSymName;
1742 }
1743 }
1744 if (!defined)
1745 return emitOpError()
1746 << "Parent class does not point to any implementation!";
1747
1748 return success();
1749}
1750
1751//===----------------------------------------------------------------------===//
1752// TableGen generated logic.
1753//===----------------------------------------------------------------------===//
1754
1755// Provide the autogenerated implementation guts for the Op classes.
1756#define GET_OP_CLASSES
1757#include "circt/Dialect/Moore/Moore.cpp.inc"
1758#include "circt/Dialect/Moore/MooreEnums.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static std::unique_ptr< Context > context
@ Output
Definition HW.h:42
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static OpFoldResult powCommonFolding(MLIRContext *ctxt, Attribute lhs, Attribute rhs)
static ArrayRef< StructLikeMember > getStructMembers(Type type)
Definition MooreOps.cpp:808
static std::optional< uint32_t > getStructFieldIndex(Type type, StringAttr name)
Definition MooreOps.cpp:799
static UnpackedType getStructFieldType(Type type, StringAttr name)
Definition MooreOps.cpp:817
static std::pair< unsigned, UnpackedType > getArrayElements(Type type)
Definition MooreOps.cpp:766
static InstancePath empty
Four-valued arbitrary precision integers.
Definition FVInt.h:37
bool isNegative() const
Determine whether the integer interpreted as a signed number would be negative.
Definition FVInt.h:185
FVInt sext(unsigned bitWidth) const
Sign-extend the integer to a new bit width.
Definition FVInt.h:148
unsigned getSignificantBits() const
Compute the minimum bit width necessary to accurately represent this integer's value and sign.
Definition FVInt.h:102
static FVInt getAllX(unsigned numBits)
Construct an FVInt with all bits set to X.
Definition FVInt.h:75
bool hasUnknown() const
Determine if any bits are X or Z.
Definition FVInt.h:168
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition FVInt.h:92
unsigned getBitWidth() const
Return the number of bits this integer has.
Definition FVInt.h:85
FVInt trunc(unsigned bitWidth) const
Truncate the integer to a smaller bit width.
Definition FVInt.h:132
A packed SystemVerilog type.
Definition MooreTypes.h:153
std::optional< unsigned > getBitSize() const
Get the size of this type in bits.
Domain getDomain() const
Get the value domain of this type.
An unpacked SystemVerilog type.
Definition MooreTypes.h:101
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition HWOps.cpp:529
Domain
The number of values each bit of a type can assume.
Definition MooreTypes.h:49
RealWidth
The type of floating point / real number behind a RealType.
Definition MooreTypes.h:57
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
void printFVInt(AsmPrinter &p, const FVInt &value)
Print a four-valued integer using an AsmPrinter.
Definition FVInt.cpp:147
ParseResult parseFVInt(AsmParser &p, FVInt &result)
Parse a four-valued integer using an AsmParser.
Definition FVInt.cpp:162
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183