CIRCT 23.0.0git
Loading...
Searching...
No Matches
MooreOps.cpp
Go to the documentation of this file.
1//===- MooreOps.cpp - Implement the Moore operations ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Moore dialect operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Builders.h"
19#include "llvm/ADT/APSInt.h"
20#include "llvm/ADT/SmallString.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include <mlir/Dialect/Func/IR/FuncOps.h>
23
24using namespace circt;
25using namespace circt::moore;
26using namespace mlir;
27
28//===----------------------------------------------------------------------===//
29// SVModuleOp
30//===----------------------------------------------------------------------===//
31
32void SVModuleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
33 llvm::StringRef name, hw::ModuleType type) {
34 state.addAttribute(SymbolTable::getSymbolAttrName(),
35 builder.getStringAttr(name));
36 state.addAttribute(getModuleTypeAttrName(state.name), TypeAttr::get(type));
37 state.addRegion();
38}
39
40void SVModuleOp::print(OpAsmPrinter &p) {
41 p << " ";
42
43 // Print the visibility of the module.
44 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
45 if (auto visibility = (*this)->getAttrOfType<StringAttr>(visibilityAttrName))
46 p << visibility.getValue() << ' ';
47
48 p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
50 getModuleType(), {}, {});
51 p << " ";
52 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false,
53 /*printBlockTerminators=*/true);
54
55 p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(),
56 getAttributeNames());
57}
58
59ParseResult SVModuleOp::parse(OpAsmParser &parser, OperationState &result) {
60 // Parse the visibility attribute.
61 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
62
63 // Parse the module name.
64 StringAttr nameAttr;
65 if (parser.parseSymbolName(nameAttr, getSymNameAttrName(result.name),
66 result.attributes))
67 return failure();
68
69 // Parse the ports.
70 SmallVector<hw::module_like_impl::PortParse> ports;
71 TypeAttr modType;
72 if (failed(
73 hw::module_like_impl::parseModuleSignature(parser, ports, modType)))
74 return failure();
75 result.addAttribute(getModuleTypeAttrName(result.name), modType);
76
77 // Parse the attributes.
78 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
79 return failure();
80
81 // Add the entry block arguments.
82 SmallVector<OpAsmParser::Argument, 4> entryArgs;
83 for (auto &port : ports)
84 if (port.direction != hw::ModulePort::Direction::Output)
85 entryArgs.push_back(port);
86
87 // Parse the optional function body.
88 auto &bodyRegion = *result.addRegion();
89 if (parser.parseRegion(bodyRegion, entryArgs))
90 return failure();
91
92 ensureTerminator(bodyRegion, parser.getBuilder(), result.location);
93 return success();
94}
95
96void SVModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
97 mlir::OpAsmSetValueNameFn setNameFn) {
98 if (&region != &getBodyRegion())
99 return;
100 auto moduleType = getModuleType();
101 for (auto [index, arg] : llvm::enumerate(region.front().getArguments()))
102 setNameFn(arg, moduleType.getInputNameAttr(index));
103}
104
105OutputOp SVModuleOp::getOutputOp() {
106 return cast<OutputOp>(getBody()->getTerminator());
107}
108
109OperandRange SVModuleOp::getOutputs() { return getOutputOp().getOperands(); }
110
111//===----------------------------------------------------------------------===//
112// OutputOp
113//===----------------------------------------------------------------------===//
114
115LogicalResult OutputOp::verify() {
116 auto module = getParentOp();
117
118 // Check that the number of operands matches the number of output ports.
119 auto outputTypes = module.getModuleType().getOutputTypes();
120 if (outputTypes.size() != getNumOperands())
121 return emitOpError("has ")
122 << getNumOperands() << " operands, but enclosing module @"
123 << module.getSymName() << " has " << outputTypes.size()
124 << " outputs";
125
126 // Check that the operand types match the output ports.
127 for (unsigned i = 0, e = outputTypes.size(); i != e; ++i)
128 if (outputTypes[i] != getOperand(i).getType())
129 return emitOpError() << "operand " << i << " (" << getOperand(i).getType()
130 << ") does not match output type (" << outputTypes[i]
131 << ") of module @" << module.getSymName();
132
133 return success();
134}
135
136//===----------------------------------------------------------------------===//
137// InstanceOp
138//===----------------------------------------------------------------------===//
139
140LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
141 // Resolve the target symbol.
142 auto *symbol =
143 symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
144 if (!symbol)
145 return emitOpError("references unknown symbol @") << getModuleName();
146
147 // Check that the symbol is a SVModuleOp.
148 auto module = dyn_cast<SVModuleOp>(symbol);
149 if (!module)
150 return emitOpError("must reference a 'moore.module', but @")
151 << getModuleName() << " is a '" << symbol->getName() << "'";
152
153 // Check that the input ports match.
154 auto moduleType = module.getModuleType();
155 auto inputTypes = moduleType.getInputTypes();
156
157 if (inputTypes.size() != getNumOperands())
158 return emitOpError("has ")
159 << getNumOperands() << " operands, but target module @"
160 << module.getSymName() << " has " << inputTypes.size() << " inputs";
161
162 for (unsigned i = 0, e = inputTypes.size(); i != e; ++i)
163 if (inputTypes[i] != getOperand(i).getType())
164 return emitOpError() << "operand " << i << " (" << getOperand(i).getType()
165 << ") does not match input type (" << inputTypes[i]
166 << ") of module @" << module.getSymName();
167
168 // Check that the output ports match.
169 auto outputTypes = moduleType.getOutputTypes();
170
171 if (outputTypes.size() != getNumResults())
172 return emitOpError("has ")
173 << getNumOperands() << " results, but target module @"
174 << module.getSymName() << " has " << outputTypes.size()
175 << " outputs";
176
177 for (unsigned i = 0, e = outputTypes.size(); i != e; ++i)
178 if (outputTypes[i] != getResult(i).getType())
179 return emitOpError() << "result " << i << " (" << getResult(i).getType()
180 << ") does not match output type (" << outputTypes[i]
181 << ") of module @" << module.getSymName();
182
183 return success();
184}
185
186void InstanceOp::print(OpAsmPrinter &p) {
187 p << " ";
188 p.printAttributeWithoutType(getInstanceNameAttr());
189 p << " ";
190 p.printAttributeWithoutType(getModuleNameAttr());
191 printInputPortList(p, getOperation(), getInputs(), getInputs().getTypes(),
192 getInputNames());
193 p << " -> ";
194 printOutputPortList(p, getOperation(), getOutputs().getTypes(),
195 getOutputNames());
196 p.printOptionalAttrDict(getOperation()->getAttrs(), getAttributeNames());
197}
198
199ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
200 // Parse the instance name.
201 StringAttr instanceName;
202 if (parser.parseAttribute(instanceName, "instanceName", result.attributes))
203 return failure();
204
205 // Parse the module name.
206 FlatSymbolRefAttr moduleName;
207 if (parser.parseAttribute(moduleName, "moduleName", result.attributes))
208 return failure();
209
210 // Parse the input port list.
211 auto loc = parser.getCurrentLocation();
212 SmallVector<OpAsmParser::UnresolvedOperand> inputs;
213 SmallVector<Type> types;
214 ArrayAttr names;
215 if (parseInputPortList(parser, inputs, types, names))
216 return failure();
217 if (parser.resolveOperands(inputs, types, loc, result.operands))
218 return failure();
219 result.addAttribute("inputNames", names);
220
221 // Parse `->`.
222 if (parser.parseArrow())
223 return failure();
224
225 // Parse the output port list.
226 types.clear();
227 if (parseOutputPortList(parser, types, names))
228 return failure();
229 result.addAttribute("outputNames", names);
230 result.addTypes(types);
231
232 // Parse the attributes.
233 if (parser.parseOptionalAttrDict(result.attributes))
234 return failure();
235
236 return success();
237}
238
239void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
240 SmallString<32> name;
241 name += getInstanceName();
242 name += '.';
243 auto baseLen = name.size();
244
245 for (auto [result, portName] :
246 llvm::zip(getOutputs(), getOutputNames().getAsRange<StringAttr>())) {
247 if (!portName || portName.empty())
248 continue;
249 name.resize(baseLen);
250 name += portName.getValue();
251 setNameFn(result, name);
252 }
253}
254
255//===----------------------------------------------------------------------===//
256// VariableOp
257//===----------------------------------------------------------------------===//
258
259void VariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
260 if (getName() && !getName()->empty())
261 setNameFn(getResult(), *getName());
262}
263
264LogicalResult VariableOp::canonicalize(VariableOp op,
265 PatternRewriter &rewriter) {
266 // If the variable is embedded in an SSACFG region, move the initial value
267 // into an assignment immediately after the variable op. This allows the
268 // mem2reg pass which cannot handle variables with initial values.
269 auto initial = op.getInitial();
270 if (initial && mlir::mayHaveSSADominance(*op->getParentRegion())) {
271 rewriter.modifyOpInPlace(op, [&] { op.getInitialMutable().clear(); });
272 rewriter.setInsertionPointAfter(op);
273 BlockingAssignOp::create(rewriter, initial.getLoc(), op, initial);
274 return success();
275 }
276
277 // Check if the variable has one unique continuous assignment to it, all other
278 // uses are reads, and that all uses are in the same block as the variable
279 // itself.
280 auto *block = op->getBlock();
281 ContinuousAssignOp uniqueAssignOp;
282 for (auto *user : op->getUsers()) {
283 // Ensure that all users of the variable are in the same block.
284 if (user->getBlock() != block)
285 return failure();
286
287 // Ensure there is at most one unique continuous assignment to the variable.
288 if (auto assignOp = dyn_cast<ContinuousAssignOp>(user)) {
289 if (uniqueAssignOp)
290 return failure();
291 uniqueAssignOp = assignOp;
292 continue;
293 }
294
295 // Ensure all other users are reads.
296 if (!isa<ReadOp>(user))
297 return failure();
298 }
299 if (!uniqueAssignOp)
300 return failure();
301
302 // If the original variable had a name, create an `AssignedVariableOp` as a
303 // replacement. Otherwise substitute the assigned value directly.
304 Value assignedValue = uniqueAssignOp.getSrc();
305 if (auto name = op.getNameAttr(); name && !name.empty())
306 assignedValue = AssignedVariableOp::create(rewriter, op.getLoc(), name,
307 uniqueAssignOp.getSrc());
308
309 // Remove the assign op and replace all reads with the new assigned var op.
310 rewriter.eraseOp(uniqueAssignOp);
311 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
312 auto readOp = cast<ReadOp>(user);
313 rewriter.replaceOp(readOp, assignedValue);
314 }
315
316 // Remove the original variable.
317 rewriter.eraseOp(op);
318 return success();
319}
320
321SmallVector<MemorySlot> VariableOp::getPromotableSlots() {
322 // We cannot promote variables with an initial value, since that value may not
323 // dominate the location where the default value needs to be constructed.
324 if (mlir::mayBeGraphRegion(*getOperation()->getParentRegion()) ||
325 getInitial())
326 return {};
327
328 // Ensure that `getDefaultValue` can conjure up a default value for the
329 // variable's type.
330 auto nestedType = dyn_cast<PackedType>(getType().getNestedType());
331 if (!nestedType || !nestedType.getBitSize())
332 return {};
333
334 return {MemorySlot{getResult(), getType().getNestedType()}};
335}
336
337Value VariableOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) {
338 auto packedType = dyn_cast<PackedType>(slot.elemType);
339 if (!packedType)
340 return {};
341 auto bitWidth = packedType.getBitSize();
342 if (!bitWidth)
343 return {};
344 auto fvint = packedType.getDomain() == Domain::FourValued
345 ? FVInt::getAllX(*bitWidth)
346 : FVInt::getZero(*bitWidth);
347 Value value = ConstantOp::create(
348 builder, getLoc(),
349 IntType::get(getContext(), *bitWidth, packedType.getDomain()), fvint);
350 if (value.getType() != packedType)
351 SBVToPackedOp::create(builder, getLoc(), packedType, value);
352 return value;
353}
354
355void VariableOp::handleBlockArgument(const MemorySlot &slot,
356 BlockArgument argument,
357 OpBuilder &builder) {}
358
359std::optional<mlir::PromotableAllocationOpInterface>
360VariableOp::handlePromotionComplete(const MemorySlot &slot, Value defaultValue,
361 OpBuilder &builder) {
362 if (defaultValue && defaultValue.use_empty())
363 defaultValue.getDefiningOp()->erase();
364 this->erase();
365 return {};
366}
367
368SmallVector<DestructurableMemorySlot> VariableOp::getDestructurableSlots() {
369 if (isa<SVModuleOp>(getOperation()->getParentOp()))
370 return {};
371 if (getInitial())
372 return {};
373
374 auto refType = getType();
375 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(refType);
376 if (!destructurable)
377 return {};
378
379 auto destructuredType = destructurable.getSubelementIndexMap();
380 if (!destructuredType)
381 return {};
382
383 return {DestructurableMemorySlot{{getResult(), refType}, *destructuredType}};
384}
385
386DenseMap<Attribute, MemorySlot> VariableOp::destructure(
387 const DestructurableMemorySlot &slot,
388 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
389 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
390 assert(slot.ptr == getResult());
391 assert(!getInitial());
392 builder.setInsertionPointAfter(*this);
393
394 auto destructurableType = cast<DestructurableTypeInterface>(getType());
395 DenseMap<Attribute, MemorySlot> slotMap;
396 for (Attribute index : usedIndices) {
397 auto elemType = cast<RefType>(destructurableType.getTypeAtIndex(index));
398 assert(elemType && "used index must exist");
399 StringAttr varName;
400 if (auto name = getName(); name && !name->empty())
401 varName = StringAttr::get(
402 getContext(), (*name) + "." + cast<StringAttr>(index).getValue());
403 auto varOp =
404 VariableOp::create(builder, getLoc(), elemType, varName, Value());
405 newAllocators.push_back(varOp);
406 slotMap.try_emplace<MemorySlot>(index, {varOp.getResult(), elemType});
407 }
408
409 return slotMap;
410}
411
412std::optional<DestructurableAllocationOpInterface>
413VariableOp::handleDestructuringComplete(const DestructurableMemorySlot &slot,
414 OpBuilder &builder) {
415 assert(slot.ptr == getResult());
416 this->erase();
417 return std::nullopt;
418}
419
420//===----------------------------------------------------------------------===//
421// NetOp
422//===----------------------------------------------------------------------===//
423
424void NetOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
425 if (getName() && !getName()->empty())
426 setNameFn(getResult(), *getName());
427}
428
429LogicalResult NetOp::canonicalize(NetOp op, PatternRewriter &rewriter) {
430 bool modified = false;
431
432 // Check if the net has one unique continuous assignment to it, and
433 // additionally if all other users are reads.
434 auto *block = op->getBlock();
435 ContinuousAssignOp uniqueAssignOp;
436 bool allUsesAreReads = true;
437 for (auto *user : op->getUsers()) {
438 // Ensure that all users of the net are in the same block.
439 if (user->getBlock() != block)
440 return failure();
441
442 // Ensure there is at most one unique continuous assignment to the net.
443 if (auto assignOp = dyn_cast<ContinuousAssignOp>(user)) {
444 if (uniqueAssignOp)
445 return failure();
446 uniqueAssignOp = assignOp;
447 continue;
448 }
449
450 // Ensure all other users are reads.
451 if (!isa<ReadOp>(user))
452 allUsesAreReads = false;
453 }
454
455 // If there was one unique assignment, and the `NetOp` does not yet have an
456 // assigned value set, fold the assignment into the net.
457 if (uniqueAssignOp && !op.getAssignment()) {
458 rewriter.modifyOpInPlace(
459 op, [&] { op.getAssignmentMutable().assign(uniqueAssignOp.getSrc()); });
460 rewriter.eraseOp(uniqueAssignOp);
461 modified = true;
462 uniqueAssignOp = {};
463 }
464
465 // If all users of the net op are reads, and any potential unique assignment
466 // has been folded into the net op itself, directly replace the reads with the
467 // net's assigned value.
468 if (!uniqueAssignOp && allUsesAreReads && op.getAssignment()) {
469 // If the original net had a name, create an `AssignedVariableOp` as a
470 // replacement. Otherwise substitute the assigned value directly.
471 auto assignedValue = op.getAssignment();
472 if (auto name = op.getNameAttr(); name && !name.empty())
473 assignedValue = AssignedVariableOp::create(rewriter, op.getLoc(), name,
474 assignedValue);
475
476 // Replace all reads with the new assigned var op and remove the original
477 // net op.
478 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
479 auto readOp = cast<ReadOp>(user);
480 rewriter.replaceOp(readOp, assignedValue);
481 }
482 rewriter.eraseOp(op);
483 modified = true;
484 }
485
486 return success(modified);
487}
488
489//===----------------------------------------------------------------------===//
490// AssignedVariableOp
491//===----------------------------------------------------------------------===//
492
493void AssignedVariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
494 if (getName() && !getName()->empty())
495 setNameFn(getResult(), *getName());
496}
497
498LogicalResult AssignedVariableOp::canonicalize(AssignedVariableOp op,
499 PatternRewriter &rewriter) {
500 // Eliminate chained variables with the same name.
501 // var(name, var(name, x)) -> var(name, x)
502 if (auto otherOp = op.getInput().getDefiningOp<AssignedVariableOp>()) {
503 if (otherOp.getNameAttr() == op.getNameAttr()) {
504 rewriter.replaceOp(op, otherOp);
505 return success();
506 }
507 }
508
509 // Eliminate variables that alias an input port of the same name.
510 if (auto blockArg = dyn_cast<BlockArgument>(op.getInput())) {
511 if (auto moduleOp =
512 dyn_cast<SVModuleOp>(blockArg.getOwner()->getParentOp())) {
513 auto moduleType = moduleOp.getModuleType();
514 auto portName = moduleType.getInputNameAttr(blockArg.getArgNumber());
515 if (portName == op.getNameAttr()) {
516 rewriter.replaceOp(op, blockArg);
517 return success();
518 }
519 }
520 }
521
522 // Eliminate variables that feed an output port of the same name.
523 for (auto &use : op->getUses()) {
524 auto *useOwner = use.getOwner();
525 if (auto outputOp = dyn_cast<OutputOp>(useOwner)) {
526 if (auto moduleOp = dyn_cast<SVModuleOp>(outputOp->getParentOp())) {
527 auto moduleType = moduleOp.getModuleType();
528 auto portName = moduleType.getOutputNameAttr(use.getOperandNumber());
529 if (portName == op.getNameAttr()) {
530 rewriter.replaceOp(op, op.getInput());
531 return success();
532 }
533 } else
534 break;
535 }
536 }
537
538 return failure();
539}
540
541//===----------------------------------------------------------------------===//
542// GlobalVariableOp
543//===----------------------------------------------------------------------===//
544
545LogicalResult GlobalVariableOp::verifyRegions() {
546 if (auto *block = getInitBlock()) {
547 auto &terminator = block->back();
548 if (!isa<YieldOp>(terminator))
549 return emitOpError() << "must have a 'moore.yield' terminator";
550 }
551 return success();
552}
553
554Block *GlobalVariableOp::getInitBlock() {
555 if (getInitRegion().empty())
556 return nullptr;
557 return &getInitRegion().front();
558}
559
560//===----------------------------------------------------------------------===//
561// GetGlobalVariableOp
562//===----------------------------------------------------------------------===//
563
564LogicalResult
565GetGlobalVariableOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
566 // Resolve the target symbol.
567 auto *symbol =
568 symbolTable.lookupNearestSymbolFrom(*this, getGlobalNameAttr());
569 if (!symbol)
570 return emitOpError() << "references unknown symbol " << getGlobalNameAttr();
571
572 // Check that the symbol is a global variable.
573 auto var = dyn_cast<GlobalVariableOp>(symbol);
574 if (!var)
575 return emitOpError() << "must reference a 'moore.global_variable', but "
576 << getGlobalNameAttr() << " is a '"
577 << symbol->getName() << "'";
578
579 // Check that the types match.
580 auto expType = var.getType();
581 auto actType = getType().getNestedType();
582 if (expType != actType)
583 return emitOpError() << "returns a " << actType << " reference, but "
584 << getGlobalNameAttr() << " is of type " << expType;
585
586 return success();
587}
588
589//===----------------------------------------------------------------------===//
590// ConstantOp
591//===----------------------------------------------------------------------===//
592
593void ConstantOp::print(OpAsmPrinter &p) {
594 p << " ";
595 printFVInt(p, getValue());
596 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
597 p << " : ";
598 p.printStrippedAttrOrType(getType());
599}
600
601ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
602 // Parse the constant value.
603 FVInt value;
604 auto valueLoc = parser.getCurrentLocation();
605 if (parseFVInt(parser, value))
606 return failure();
607
608 // Parse any optional attributes and the `:`.
609 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
610 return failure();
611
612 // Parse the result type.
613 IntType type;
614 if (parser.parseCustomTypeWithFallback(type))
615 return failure();
616
617 // Extend or truncate the constant value to match the size of the type.
618 if (type.getWidth() > value.getBitWidth()) {
619 // sext is always safe here, even for unsigned values, because the
620 // parseOptionalInteger method will return something with a zero in the
621 // top bits if it is a positive number.
622 value = value.sext(type.getWidth());
623 } else if (type.getWidth() < value.getBitWidth()) {
624 // The parser can return an unnecessarily wide result with leading
625 // zeros. This isn't a problem, but truncating off bits is bad.
626 unsigned neededBits =
627 value.isNegative() ? value.getSignificantBits() : value.getActiveBits();
628 if (type.getWidth() < neededBits)
629 return parser.emitError(valueLoc)
630 << "value requires " << neededBits
631 << " bits, but result type only has " << type.getWidth();
632 value = value.trunc(type.getWidth());
633 }
634
635 // If the constant contains any X or Z bits, the result type must be
636 // four-valued.
637 if (value.hasUnknown() && type.getDomain() != Domain::FourValued)
638 return parser.emitError(valueLoc)
639 << "value contains X or Z bits, but result type " << type
640 << " only allows two-valued bits";
641
642 // Build the attribute and op.
643 auto attrValue = FVIntegerAttr::get(parser.getContext(), value);
644 result.addAttribute("value", attrValue);
645 result.addTypes(type);
646 return success();
647}
648
649LogicalResult ConstantOp::verify() {
650 auto attrWidth = getValue().getBitWidth();
651 auto typeWidth = getType().getWidth();
652 if (attrWidth != typeWidth)
653 return emitError("attribute width ")
654 << attrWidth << " does not match return type's width " << typeWidth;
655 return success();
656}
657
658void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
659 const FVInt &value) {
660 assert(type.getWidth() == value.getBitWidth() &&
661 "FVInt width must match type width");
662 build(builder, result, type, FVIntegerAttr::get(builder.getContext(), value));
663}
664
665void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
666 const APInt &value) {
667 assert(type.getWidth() == value.getBitWidth() &&
668 "APInt width must match type width");
669 build(builder, result, type, FVInt(value));
670}
671
672/// This builder allows construction of small signed integers like 0, 1, -1
673/// matching a specified MLIR type. This shouldn't be used for general constant
674/// folding because it only works with values that can be expressed in an
675/// `int64_t`.
676void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
677 int64_t value, bool isSigned) {
678 build(builder, result, type,
679 APInt(type.getWidth(), (uint64_t)value, isSigned));
680}
681
682OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
683 assert(adaptor.getOperands().empty() && "constant has no operands");
684 return getValueAttr();
685}
686
687//===----------------------------------------------------------------------===//
688// ConstantTimeOp
689//===----------------------------------------------------------------------===//
690
691OpFoldResult ConstantTimeOp::fold(FoldAdaptor adaptor) {
692 return getValueAttr();
693}
694
695//===----------------------------------------------------------------------===//
696// ConstantRealOp
697//===----------------------------------------------------------------------===//
698
699LogicalResult ConstantRealOp::inferReturnTypes(
700 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
701 DictionaryAttr attrs, mlir::OpaqueProperties properties,
702 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
703 ConstantRealOp::Adaptor adaptor(operands, attrs, properties);
704 results.push_back(RealType::get(
705 context, static_cast<RealWidth>(
706 adaptor.getValueAttr().getType().getIntOrFloatBitWidth())));
707 return success();
708}
709
710//===----------------------------------------------------------------------===//
711// ConcatOp
712//===----------------------------------------------------------------------===//
713
714LogicalResult ConcatOp::inferReturnTypes(
715 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
716 DictionaryAttr attrs, mlir::OpaqueProperties properties,
717 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
718 Domain domain = Domain::TwoValued;
719 unsigned width = 0;
720 for (auto operand : operands) {
721 auto type = cast<IntType>(operand.getType());
722 if (type.getDomain() == Domain::FourValued)
723 domain = Domain::FourValued;
724 width += type.getWidth();
725 }
726 results.push_back(IntType::get(context, width, domain));
727 return success();
728}
729
730//===----------------------------------------------------------------------===//
731// ConcatRefOp
732//===----------------------------------------------------------------------===//
733
734LogicalResult ConcatRefOp::inferReturnTypes(
735 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
736 DictionaryAttr attrs, mlir::OpaqueProperties properties,
737 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
738 Domain domain = Domain::TwoValued;
739 unsigned width = 0;
740 for (Value operand : operands) {
741 UnpackedType nestedType = cast<RefType>(operand.getType()).getNestedType();
742 PackedType packedType = dyn_cast<PackedType>(nestedType);
743
744 if (!packedType) {
745 return failure();
746 }
747
748 if (packedType.getDomain() == Domain::FourValued)
749 domain = Domain::FourValued;
750
751 // getBitSize() for PackedType returns an optional, so we must check it.
752 std::optional<int> bitSize = packedType.getBitSize();
753 if (!bitSize) {
754 return failure();
755 }
756 width += *bitSize;
757 }
758 results.push_back(RefType::get(IntType::get(context, width, domain)));
759 return success();
760}
761
762//===----------------------------------------------------------------------===//
763// ArrayCreateOp
764//===----------------------------------------------------------------------===//
765
766static std::pair<unsigned, UnpackedType> getArrayElements(Type type) {
767 if (auto arrayType = dyn_cast<ArrayType>(type))
768 return {arrayType.getSize(), arrayType.getElementType()};
769 if (auto arrayType = dyn_cast<UnpackedArrayType>(type))
770 return {arrayType.getSize(), arrayType.getElementType()};
771 assert(0 && "expected ArrayType or UnpackedArrayType");
772 return {};
773}
774
775LogicalResult ArrayCreateOp::verify() {
776 auto [size, elementType] = getArrayElements(getType());
777
778 // Check that the number of operands matches the array size.
779 if (getElements().size() != size)
780 return emitOpError() << "has " << getElements().size()
781 << " operands, but result type requires " << size;
782
783 // Check that the operand types match the array element type. We only need to
784 // check one of the operands, since the `SameTypeOperands` trait ensures all
785 // operands have the same type.
786 if (size > 0) {
787 auto value = getElements()[0];
788 if (value.getType() != elementType)
789 return emitOpError() << "operands have type " << value.getType()
790 << ", but array requires " << elementType;
791 }
792 return success();
793}
794
795//===----------------------------------------------------------------------===//
796// StructCreateOp
797//===----------------------------------------------------------------------===//
798
799static std::optional<uint32_t> getStructFieldIndex(Type type, StringAttr name) {
800 if (auto structType = dyn_cast<StructType>(type))
801 return structType.getFieldIndex(name);
802 if (auto structType = dyn_cast<UnpackedStructType>(type))
803 return structType.getFieldIndex(name);
804 assert(0 && "expected StructType or UnpackedStructType");
805 return {};
806}
807
808static ArrayRef<StructLikeMember> getStructMembers(Type type) {
809 if (auto structType = dyn_cast<StructType>(type))
810 return structType.getMembers();
811 if (auto structType = dyn_cast<UnpackedStructType>(type))
812 return structType.getMembers();
813 assert(0 && "expected StructType or UnpackedStructType");
814 return {};
815}
816
817static UnpackedType getStructFieldType(Type type, StringAttr name) {
818 if (auto index = getStructFieldIndex(type, name))
819 return getStructMembers(type)[*index].type;
820 return {};
821}
822
823LogicalResult StructCreateOp::verify() {
824 auto members = getStructMembers(getType());
825
826 // Check that the number of operands matches the number of struct fields.
827 if (getFields().size() != members.size())
828 return emitOpError() << "has " << getFields().size()
829 << " operands, but result type requires "
830 << members.size();
831
832 // Check that the operand types match the struct field types.
833 for (auto [index, pair] : llvm::enumerate(llvm::zip(getFields(), members))) {
834 auto [value, member] = pair;
835 if (value.getType() != member.type)
836 return emitOpError() << "operand #" << index << " has type "
837 << value.getType() << ", but struct field "
838 << member.name << " requires " << member.type;
839 }
840 return success();
841}
842
843OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
844 SmallVector<NamedAttribute> fields;
845 for (auto [member, field] :
846 llvm::zip(getStructMembers(getType()), adaptor.getFields())) {
847 if (!field)
848 return {};
849 fields.push_back(NamedAttribute(member.name, field));
850 }
851 return DictionaryAttr::get(getContext(), fields);
852}
853
854//===----------------------------------------------------------------------===//
855// StructExtractOp
856//===----------------------------------------------------------------------===//
857
858LogicalResult StructExtractOp::verify() {
859 auto type = getStructFieldType(getInput().getType(), getFieldNameAttr());
860 if (!type)
861 return emitOpError() << "extracts field " << getFieldNameAttr()
862 << " which does not exist in " << getInput().getType();
863 if (type != getType())
864 return emitOpError() << "result type " << getType()
865 << " must match struct field type " << type;
866 return success();
867}
868
869OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
870 // Extract on a constant struct input.
871 if (auto fields = dyn_cast_or_null<DictionaryAttr>(adaptor.getInput()))
872 if (auto value = fields.get(getFieldNameAttr()))
873 return value;
874
875 // extract(inject(s, "field", v), "field") -> v
876 if (auto inject = getInput().getDefiningOp<StructInjectOp>()) {
877 if (inject.getFieldNameAttr() == getFieldNameAttr())
878 return inject.getNewValue();
879 return {};
880 }
881
882 // extract(create({"field": v, ...}), "field") -> v
883 if (auto create = getInput().getDefiningOp<StructCreateOp>()) {
884 if (auto index = getStructFieldIndex(create.getType(), getFieldNameAttr()))
885 return create.getFields()[*index];
886 return {};
887 }
888
889 return {};
890}
891
892//===----------------------------------------------------------------------===//
893// StructExtractRefOp
894//===----------------------------------------------------------------------===//
895
896LogicalResult StructExtractRefOp::verify() {
897 auto type = getStructFieldType(
898 cast<RefType>(getInput().getType()).getNestedType(), getFieldNameAttr());
899 if (!type)
900 return emitOpError() << "extracts field " << getFieldNameAttr()
901 << " which does not exist in " << getInput().getType();
902 if (type != getType().getNestedType())
903 return emitOpError() << "result ref of type " << getType().getNestedType()
904 << " must match struct field type " << type;
905 return success();
906}
907
908bool StructExtractRefOp::canRewire(
909 const DestructurableMemorySlot &slot,
910 SmallPtrSetImpl<Attribute> &usedIndices,
911 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
912 const DataLayout &dataLayout) {
913 if (slot.ptr != getInput())
914 return false;
915 auto index = getFieldNameAttr();
916 if (!index || !slot.subelementTypes.contains(index))
917 return false;
918 usedIndices.insert(index);
919 return true;
920}
921
/// Rewire this field accessor onto the sub-slot for its field. All users of the
/// extracted reference are redirected to the sub-slot pointer, and the op
/// erases itself here.
DeletionKind
StructExtractRefOp::rewire(const DestructurableMemorySlot &slot,
                           DenseMap<Attribute, MemorySlot> &subslots,
                           OpBuilder &builder, const DataLayout &dataLayout) {
  auto index = getFieldNameAttr();
  // `canRewire` inserted this field into `usedIndices`. A sub-slot is
  // presumably created for every used index; `at` asserts if it is missing.
  const MemorySlot &memorySlot = subslots.at(index);
  replaceAllUsesWith(memorySlot.ptr);
  // Drop the use of the original slot pointer before erasing the op.
  getInputMutable().drop();
  erase();
  // The op is already gone, so tell the caller not to delete it again.
  return DeletionKind::Keep;
}
933
934//===----------------------------------------------------------------------===//
935// StructInjectOp
936//===----------------------------------------------------------------------===//
937
938LogicalResult StructInjectOp::verify() {
939 auto type = getStructFieldType(getInput().getType(), getFieldNameAttr());
940 if (!type)
941 return emitOpError() << "injects field " << getFieldNameAttr()
942 << " which does not exist in " << getInput().getType();
943 if (type != getNewValue().getType())
944 return emitOpError() << "injected value " << getNewValue().getType()
945 << " must match struct field type " << type;
946 return success();
947}
948
949OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
950 auto input = adaptor.getInput();
951 auto newValue = adaptor.getNewValue();
952 if (!input || !newValue)
953 return {};
954 NamedAttrList fields(cast<DictionaryAttr>(input));
955 fields.set(getFieldNameAttr(), newValue);
956 return fields.getDictionary(getContext());
957}
958
/// Canonicalize a chain of `struct_inject` ops, optionally rooted at a
/// `struct_create`:
/// - if the chain (including a root `struct_create`) assigns every struct
///   field, replace it with a single `struct_create`;
/// - otherwise, if some field is assigned more than once, rebuild the chain on
///   its original input with one inject per assigned field. The new injects
///   follow struct member order and each carries the field's last value.
LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
                                           PatternRewriter &rewriter) {
  auto members = getStructMembers(op.getType());

  // Chase a chain of `struct_inject` ops, with an optional final
  // `struct_create`, and take note of the values assigned to each field.
  // The walk starts at `op` itself (the last inject), and `DenseMap::insert`
  // keeps an existing entry, so each field records its most recent value.
  SmallPtrSet<Operation *, 4> injectOps;
  DenseMap<StringAttr, Value> fieldValues;
  Value input = op;
  while (auto injectOp = input.getDefiningOp<StructInjectOp>()) {
    // Reaching the same inject twice means the use-def chain loops; give up.
    if (!injectOps.insert(injectOp).second)
      return failure();
    fieldValues.insert({injectOp.getFieldNameAttr(), injectOp.getNewValue()});
    input = injectOp.getInput();
  }
  // Operands of a root `struct_create` only fill fields that no inject in the
  // chain has already assigned.
  if (auto createOp = input.getDefiningOp<StructCreateOp>())
    for (auto [value, member] : llvm::zip(createOp.getFields(), members))
      fieldValues.insert({member.name, value});

  // If the inject chain sets all fields, canonicalize to a `struct_create`.
  if (fieldValues.size() == members.size()) {
    SmallVector<Value> values;
    values.reserve(fieldValues.size());
    for (auto member : members)
      values.push_back(fieldValues.lookup(member.name));
    rewriter.replaceOpWithNewOp<StructCreateOp>(op, op.getType(), values);
    return success();
  }

  // If each inject op in the chain assigned to a unique field, there is nothing
  // to canonicalize.
  if (injectOps.size() == fieldValues.size())
    return failure();

  // Otherwise we can eliminate overwrites by creating new injects. The hash map
  // of field values contains the last assigned value for each field.
  // At this point `input` is the value the original inject chain started from.
  for (auto member : members)
    if (auto value = fieldValues.lookup(member.name))
      input = StructInjectOp::create(rewriter, op.getLoc(), op.getType(), input,
                                     member.name, value);
  rewriter.replaceOp(op, input);
  return success();
}
1002
1003//===----------------------------------------------------------------------===//
1004// UnionCreateOp
1005//===----------------------------------------------------------------------===//
1006
1007LogicalResult UnionCreateOp::verify() {
1008 /// checks if the types of the input is exactly equal to the union field
1009 /// type
1010 return TypeSwitch<Type, LogicalResult>(getType())
1011 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
1012 auto members = type.getMembers();
1013 auto inputType = getInput().getType();
1014 auto fieldName = getFieldName();
1015 for (const auto &member : members)
1016 if (member.name == fieldName && member.type == inputType)
1017 return success();
1018 for (const auto &member : members) {
1019 if (member.name == fieldName) {
1020 emitOpError() << "input type " << inputType
1021 << " does not match union field '" << fieldName
1022 << "' type " << member.type;
1023 return failure();
1024 }
1025 }
1026 emitOpError() << "field '" << fieldName << "' not found in union type";
1027 return failure();
1028 })
1029 .Default([this](auto &) {
1030 emitOpError("input type must be UnionType or UnpackedUnionType");
1031 return failure();
1032 });
1033}
1034
1035//===----------------------------------------------------------------------===//
1036// UnionExtractOp
1037//===----------------------------------------------------------------------===//
1038
1039LogicalResult UnionExtractOp::verify() {
1040 /// checks if the types of the input is exactly equal to the one of the
1041 /// types of the result union fields
1042 return TypeSwitch<Type, LogicalResult>(getInput().getType())
1043 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
1044 auto members = type.getMembers();
1045 auto fieldName = getFieldName();
1046 auto resultType = getType();
1047 for (const auto &member : members)
1048 if (member.name == fieldName && member.type == resultType)
1049 return success();
1050 emitOpError("result type must match the union field type");
1051 return failure();
1052 })
1053 .Default([this](auto &) {
1054 emitOpError("input type must be UnionType or UnpackedUnionType");
1055 return failure();
1056 });
1057}
1058
1059//===----------------------------------------------------------------------===//
// UnionExtractRefOp
1061//===----------------------------------------------------------------------===//
1062
1063LogicalResult UnionExtractRefOp::verify() {
1064 /// checks if the types of the result is exactly equal to the type of the
1065 /// refe union field
1066 return TypeSwitch<Type, LogicalResult>(getInput().getType().getNestedType())
1067 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
1068 auto members = type.getMembers();
1069 auto fieldName = getFieldName();
1070 auto resultType = getType().getNestedType();
1071 for (const auto &member : members)
1072 if (member.name == fieldName && member.type == resultType)
1073 return success();
1074 emitOpError("result type must match the union field type");
1075 return failure();
1076 })
1077 .Default([this](auto &) {
1078 emitOpError("input type must be UnionType or UnpackedUnionType");
1079 return failure();
1080 });
1081}
1082
1083//===----------------------------------------------------------------------===//
1084// YieldOp
1085//===----------------------------------------------------------------------===//
1086
1087LogicalResult YieldOp::verify() {
1088 Type expType;
1089 auto *parentOp = getOperation()->getParentOp();
1090 if (auto cond = dyn_cast<ConditionalOp>(parentOp)) {
1091 expType = cond.getType();
1092 } else if (auto varOp = dyn_cast<GlobalVariableOp>(parentOp)) {
1093 expType = varOp.getType();
1094 } else {
1095 llvm_unreachable("all in ParentOneOf handled");
1096 }
1097
1098 auto actType = getOperand().getType();
1099 if (expType != actType) {
1100 return emitOpError() << "yields " << actType << ", but parent expects "
1101 << expType;
1102 }
1103 return success();
1104}
1105
1106//===----------------------------------------------------------------------===//
1107// ConversionOp
1108//===----------------------------------------------------------------------===//
1109
1110OpFoldResult ConversionOp::fold(FoldAdaptor adaptor) {
1111 // Fold away no-op casts.
1112 if (getInput().getType() == getResult().getType())
1113 return getInput();
1114
1115 // Convert domains of constant integer inputs.
1116 auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput());
1117 auto fromIntType = dyn_cast<IntType>(getInput().getType());
1118 auto toIntType = dyn_cast<IntType>(getResult().getType());
1119 if (intInput && fromIntType && toIntType &&
1120 fromIntType.getWidth() == toIntType.getWidth()) {
1121 // If we are going *to* a four-valued type, simply pass through the
1122 // constant.
1123 if (toIntType.getDomain() == Domain::FourValued)
1124 return intInput;
1125
1126 // Otherwise map all unknown bits to zero (the default in SystemVerilog) and
1127 // return a new constant.
1128 return FVIntegerAttr::get(getContext(), intInput.getValue().toAPInt(false));
1129 }
1130
1131 return {};
1132}
1133
1134//===----------------------------------------------------------------------===//
1135// LogicToIntOp
1136//===----------------------------------------------------------------------===//
1137
1138OpFoldResult LogicToIntOp::fold(FoldAdaptor adaptor) {
1139 // logic_to_int(int_to_logic(x)) -> x
1140 if (auto reverseOp = getInput().getDefiningOp<IntToLogicOp>())
1141 return reverseOp.getInput();
1142
1143 // Map all unknown bits to zero (the default in SystemVerilog) and return a
1144 // new constant.
1145 if (auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1146 return FVIntegerAttr::get(getContext(), intInput.getValue().toAPInt(false));
1147
1148 return {};
1149}
1150
1151//===----------------------------------------------------------------------===//
1152// IntToLogicOp
1153//===----------------------------------------------------------------------===//
1154
1155OpFoldResult IntToLogicOp::fold(FoldAdaptor adaptor) {
1156 // Cannot fold int_to_logic(logic_to_int(x)) -> x since that would lose
1157 // information.
1158
1159 // Simply pass through constants.
1160 if (auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1161 return intInput;
1162
1163 return {};
1164}
1165
1166//===----------------------------------------------------------------------===//
1167// TimeToLogicOp
1168//===----------------------------------------------------------------------===//
1169
1170OpFoldResult TimeToLogicOp::fold(FoldAdaptor adaptor) {
1171 // time_to_logic(logic_to_time(x)) -> x
1172 if (auto reverseOp = getInput().getDefiningOp<LogicToTimeOp>())
1173 return reverseOp.getInput();
1174
1175 // Convert constants.
1176 if (auto attr = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
1177 return FVIntegerAttr::get(getContext(), attr.getValue());
1178
1179 return {};
1180}
1181
1182//===----------------------------------------------------------------------===//
1183// LogicToTimeOp
1184//===----------------------------------------------------------------------===//
1185
1186OpFoldResult LogicToTimeOp::fold(FoldAdaptor adaptor) {
1187 // logic_to_time(time_to_logic(x)) -> x
1188 if (auto reverseOp = getInput().getDefiningOp<TimeToLogicOp>())
1189 return reverseOp.getInput();
1190
1191 // Convert constants.
1192 if (auto attr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1193 return IntegerAttr::get(getContext(), APSInt(attr.getValue().toAPInt(false),
1194 /*isUnsigned=*/true));
1195
1196 return {};
1197}
1198
1199//===----------------------------------------------------------------------===//
1200// ConvertRealOp
1201//===----------------------------------------------------------------------===//
1202
1203OpFoldResult ConvertRealOp::fold(FoldAdaptor adaptor) {
1204 if (getInput().getType() == getResult().getType())
1205 return getInput();
1206
1207 return {};
1208}
1209
1210//===----------------------------------------------------------------------===//
1211// TruncOp
1212//===----------------------------------------------------------------------===//
1213
1214OpFoldResult TruncOp::fold(FoldAdaptor adaptor) {
1215 // Truncate constants.
1216 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1217 auto width = getType().getWidth();
1218 return FVIntegerAttr::get(getContext(), intAttr.getValue().trunc(width));
1219 }
1220
1221 return {};
1222}
1223
1224//===----------------------------------------------------------------------===//
1225// ZExtOp
1226//===----------------------------------------------------------------------===//
1227
1228OpFoldResult ZExtOp::fold(FoldAdaptor adaptor) {
1229 // Zero-extend constants.
1230 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1231 auto width = getType().getWidth();
1232 return FVIntegerAttr::get(getContext(), intAttr.getValue().zext(width));
1233 }
1234
1235 return {};
1236}
1237
1238//===----------------------------------------------------------------------===//
1239// SExtOp
1240//===----------------------------------------------------------------------===//
1241
1242OpFoldResult SExtOp::fold(FoldAdaptor adaptor) {
1243 // Sign-extend constants.
1244 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1245 auto width = getType().getWidth();
1246 return FVIntegerAttr::get(getContext(), intAttr.getValue().sext(width));
1247 }
1248
1249 return {};
1250}
1251
1252//===----------------------------------------------------------------------===//
1253// BoolCastOp
1254//===----------------------------------------------------------------------===//
1255
1256OpFoldResult BoolCastOp::fold(FoldAdaptor adaptor) {
1257 // Fold away no-op casts.
1258 if (getInput().getType() == getResult().getType())
1259 return getInput();
1260 return {};
1261}
1262
1263//===----------------------------------------------------------------------===//
1264// BlockingAssignOp
1265//===----------------------------------------------------------------------===//
1266
1267bool BlockingAssignOp::loadsFrom(const MemorySlot &slot) { return false; }
1268
1269bool BlockingAssignOp::storesTo(const MemorySlot &slot) {
1270 return getDst() == slot.ptr;
1271}
1272
1273Value BlockingAssignOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1274 Value reachingDef,
1275 const DataLayout &dataLayout) {
1276 return getSrc();
1277}
1278
1279bool BlockingAssignOp::canUsesBeRemoved(
1280 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1281 SmallVectorImpl<OpOperand *> &newBlockingUses,
1282 const DataLayout &dataLayout) {
1283
1284 if (blockingUses.size() != 1)
1285 return false;
1286 Value blockingUse = (*blockingUses.begin())->get();
1287 return blockingUse == slot.ptr && getDst() == slot.ptr &&
1288 getSrc() != slot.ptr && getSrc().getType() == slot.elemType;
1289}
1290
1291DeletionKind BlockingAssignOp::removeBlockingUses(
1292 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1293 OpBuilder &builder, Value reachingDefinition,
1294 const DataLayout &dataLayout) {
1295 return DeletionKind::Delete;
1296}
1297
1298//===----------------------------------------------------------------------===//
1299// ReadOp
1300//===----------------------------------------------------------------------===//
1301
1302bool ReadOp::loadsFrom(const MemorySlot &slot) {
1303 return getInput() == slot.ptr;
1304}
1305
1306bool ReadOp::storesTo(const MemorySlot &slot) { return false; }
1307
1308Value ReadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1309 Value reachingDef, const DataLayout &dataLayout) {
1310 llvm_unreachable("getStored should not be called on ReadOp");
1311}
1312
1313bool ReadOp::canUsesBeRemoved(const MemorySlot &slot,
1314 const SmallPtrSetImpl<OpOperand *> &blockingUses,
1315 SmallVectorImpl<OpOperand *> &newBlockingUses,
1316 const DataLayout &dataLayout) {
1317
1318 if (blockingUses.size() != 1)
1319 return false;
1320 Value blockingUse = (*blockingUses.begin())->get();
1321 return blockingUse == slot.ptr && getOperand() == slot.ptr &&
1322 getResult().getType() == slot.elemType;
1323}
1324
1325DeletionKind
1326ReadOp::removeBlockingUses(const MemorySlot &slot,
1327 const SmallPtrSetImpl<OpOperand *> &blockingUses,
1328 OpBuilder &builder, Value reachingDefinition,
1329 const DataLayout &dataLayout) {
1330 getResult().replaceAllUsesWith(reachingDefinition);
1331 return DeletionKind::Delete;
1332}
1333
1334//===----------------------------------------------------------------------===//
1335// PowSOp
1336//===----------------------------------------------------------------------===//
1337
1338static OpFoldResult powCommonFolding(MLIRContext *ctxt, Attribute lhs,
1339 Attribute rhs) {
1340 auto lhsValue = dyn_cast_or_null<FVIntegerAttr>(lhs);
1341 if (lhsValue && lhsValue.getValue() == 1)
1342 return lhs;
1343
1344 auto rhsValue = dyn_cast_or_null<FVIntegerAttr>(rhs);
1345 if (rhsValue && rhsValue.getValue().isZero())
1346 return FVIntegerAttr::get(ctxt,
1347 FVInt(rhsValue.getValue().getBitWidth(), 1));
1348
1349 return {};
1350}
1351
1352OpFoldResult PowSOp::fold(FoldAdaptor adaptor) {
1353 return powCommonFolding(getContext(), adaptor.getLhs(), adaptor.getRhs());
1354}
1355
1356LogicalResult PowSOp::canonicalize(PowSOp op, PatternRewriter &rewriter) {
1357 Location loc = op.getLoc();
1358 auto intType = cast<IntType>(op.getRhs().getType());
1359 if (auto baseOp = op.getLhs().getDefiningOp<ConstantOp>()) {
1360 if (baseOp.getValue() == 2) {
1361 Value constOne = ConstantOp::create(rewriter, loc, intType, 1);
1362 Value constZero = ConstantOp::create(rewriter, loc, intType, 0);
1363 Value shift = ShlOp::create(rewriter, loc, constOne, op.getRhs());
1364 Value isNegative = SltOp::create(rewriter, loc, op.getRhs(), constZero);
1365 auto condOp = rewriter.replaceOpWithNewOp<ConditionalOp>(
1366 op, op.getLhs().getType(), isNegative);
1367 Block *thenBlock = rewriter.createBlock(&condOp.getTrueRegion());
1368 rewriter.setInsertionPointToStart(thenBlock);
1369 YieldOp::create(rewriter, loc, constZero);
1370 Block *elseBlock = rewriter.createBlock(&condOp.getFalseRegion());
1371 rewriter.setInsertionPointToStart(elseBlock);
1372 YieldOp::create(rewriter, loc, shift);
1373 return success();
1374 }
1375 }
1376
1377 return failure();
1378}
1379
1380//===----------------------------------------------------------------------===//
1381// PowUOp
1382//===----------------------------------------------------------------------===//
1383
1384OpFoldResult PowUOp::fold(FoldAdaptor adaptor) {
1385 return powCommonFolding(getContext(), adaptor.getLhs(), adaptor.getRhs());
1386}
1387
1388LogicalResult PowUOp::canonicalize(PowUOp op, PatternRewriter &rewriter) {
1389 Location loc = op.getLoc();
1390 auto intType = cast<IntType>(op.getRhs().getType());
1391 if (auto baseOp = op.getLhs().getDefiningOp<ConstantOp>()) {
1392 if (baseOp.getValue() == 2) {
1393 Value constOne = ConstantOp::create(rewriter, loc, intType, 1);
1394 rewriter.replaceOpWithNewOp<ShlOp>(op, constOne, op.getRhs());
1395 return success();
1396 }
1397 }
1398
1399 return failure();
1400}
1401
1402//===----------------------------------------------------------------------===//
1403// SubOp
1404//===----------------------------------------------------------------------===//
1405
1406OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1407 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs()))
1408 if (intAttr.getValue().isZero())
1409 return getLhs();
1410
1411 return {};
1412}
1413
1414//===----------------------------------------------------------------------===//
1415// MulOp
1416//===----------------------------------------------------------------------===//
1417
1418OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1419 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1420 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1421 if (lhs && rhs)
1422 return FVIntegerAttr::get(getContext(), lhs.getValue() * rhs.getValue());
1423 return {};
1424}
1425
1426//===----------------------------------------------------------------------===//
1427// DivUOp
1428//===----------------------------------------------------------------------===//
1429
1430OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1431 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1432 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1433 if (lhs && rhs)
1434 return FVIntegerAttr::get(getContext(),
1435 lhs.getValue().udiv(rhs.getValue()));
1436 return {};
1437}
1438
1439//===----------------------------------------------------------------------===//
1440// DivSOp
1441//===----------------------------------------------------------------------===//
1442
1443OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1444 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1445 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1446 if (lhs && rhs)
1447 return FVIntegerAttr::get(getContext(),
1448 lhs.getValue().sdiv(rhs.getValue()));
1449 return {};
1450}
1451
1452//===----------------------------------------------------------------------===//
1453// Classes
1454//===----------------------------------------------------------------------===//
1455
1456LogicalResult ClassDeclOp::verify() {
1457 mlir::Region &body = getBody();
1458 if (body.empty())
1459 return mlir::success();
1460
1461 auto &block = body.front();
1462 for (mlir::Operation &op : block) {
1463
1464 // allow only property and method decls and terminator
1465 if (llvm::isa<circt::moore::ClassPropertyDeclOp,
1466 circt::moore::ClassMethodDeclOp>(&op))
1467 continue;
1468
1469 return emitOpError()
1470 << "body may only contain 'moore.class.propertydecl' operations";
1471 }
1472 return mlir::success();
1473}
1474
1475LogicalResult ClassNewOp::verify() {
1476 // The result is constrained to ClassHandleType in ODS, so this cast should be
1477 // safe.
1478 auto handleTy = cast<ClassHandleType>(getResult().getType());
1479 mlir::SymbolRefAttr classSym = handleTy.getClassSym();
1480 if (!classSym)
1481 return emitOpError("result type is missing a class symbol");
1482
1483 // Resolve the referenced symbol starting from the nearest symbol table.
1484 mlir::Operation *sym =
1485 mlir::SymbolTable::lookupNearestSymbolFrom(getOperation(), classSym);
1486 if (!sym)
1487 return emitOpError("referenced class symbol `")
1488 << classSym << "` was not found";
1489
1490 if (!llvm::isa<ClassDeclOp>(sym))
1491 return emitOpError("symbol `")
1492 << classSym << "` does not name a `moore.class.classdecl`";
1493
1494 return mlir::success();
1495}
1496
1497void ClassNewOp::getEffects(
1498 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1499 &effects) {
1500 // Always allocates heap memory.
1501 effects.emplace_back(MemoryEffects::Allocate::get());
1502}
1503
1504LogicalResult
1505ClassUpcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1506 // 1) Type checks.
1507 auto srcTy = dyn_cast<ClassHandleType>(getOperand().getType());
1508 if (!srcTy)
1509 return emitOpError() << "operand must be !moore.class<...>; got "
1510 << getOperand().getType();
1511
1512 auto dstTy = dyn_cast<ClassHandleType>(getResult().getType());
1513 if (!dstTy)
1514 return emitOpError() << "result must be !moore.class<...>; got "
1515 << getResult().getType();
1516
1517 if (srcTy == dstTy)
1518 return success();
1519
1520 auto *op = getOperation();
1521
1522 auto *srcDeclOp =
1523 symbolTable.lookupNearestSymbolFrom(op, srcTy.getClassSym());
1524 auto *dstDeclOp =
1525 symbolTable.lookupNearestSymbolFrom(op, dstTy.getClassSym());
1526 if (!srcDeclOp || !dstDeclOp)
1527 return emitOpError() << "failed to resolve class symbol(s): src="
1528 << srcTy.getClassSym()
1529 << ", dst=" << dstTy.getClassSym();
1530
1531 auto srcDecl = dyn_cast<ClassDeclOp>(srcDeclOp);
1532 auto dstDecl = dyn_cast<ClassDeclOp>(dstDeclOp);
1533 if (!srcDecl || !dstDecl)
1534 return emitOpError()
1535 << "symbol(s) do not name `moore.class.classdecl` ops: src="
1536 << srcTy.getClassSym() << ", dst=" << dstTy.getClassSym();
1537
1538 auto cur = srcDecl;
1539 while (cur) {
1540 if (cur == dstDecl)
1541 return success(); // legal upcast: dst is src or an ancestor
1542
1543 auto baseSym = cur.getBaseAttr();
1544 if (!baseSym)
1545 break;
1546
1547 auto *baseOp = symbolTable.lookupNearestSymbolFrom(op, baseSym);
1548 cur = llvm::dyn_cast_or_null<ClassDeclOp>(baseOp);
1549 }
1550
1551 return emitOpError() << "cannot upcast from " << srcTy.getClassSym() << " to "
1552 << dstTy.getClassSym()
1553 << " (destination is not a base class)";
1554}
1555
1556LogicalResult
1557ClassPropertyRefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1558 // The operand is constrained to ClassHandleType in ODS; unwrap it.
1559 Type instTy = getInstance().getType();
1560 auto handleTy = dyn_cast<moore::ClassHandleType>(instTy);
1561 if (!handleTy)
1562 return emitOpError() << "instance must be a !moore.class<@C> value, got "
1563 << instTy;
1564
1565 // Extract the referenced class symbol from the handle type.
1566 SymbolRefAttr classSym = handleTy.getClassSym();
1567 if (!classSym)
1568 return emitOpError("instance type is missing a class symbol");
1569
1570 // Resolve the class symbol starting from the nearest symbol table.
1571 Operation *clsSym =
1572 symbolTable.lookupNearestSymbolFrom(getOperation(), classSym);
1573 if (!clsSym)
1574 return emitOpError("referenced class symbol `")
1575 << classSym << "` was not found";
1576 auto classDecl = dyn_cast<ClassDeclOp>(clsSym);
1577 if (!classDecl)
1578 return emitOpError("symbol `")
1579 << classSym << "` does not name a `moore.class.classdecl`";
1580
1581 // Look up the field symbol inside the class declaration's symbol table.
1582 FlatSymbolRefAttr fieldSym = getPropertyAttr();
1583 if (!fieldSym)
1584 return emitOpError("missing field symbol");
1585
1586 Operation *fldSym = symbolTable.lookupSymbolIn(classDecl, fieldSym.getAttr());
1587 if (!fldSym)
1588 return emitOpError("no field `") << fieldSym << "` in class " << classSym;
1589
1590 auto fieldDecl = dyn_cast<ClassPropertyDeclOp>(fldSym);
1591 if (!fieldDecl)
1592 return emitOpError("symbol `")
1593 << fieldSym << "` is not a `moore.class.propertydecl`";
1594
1595 // Result must be !moore.ref<T> where T matches the field's declared type.
1596 auto resRefTy = cast<RefType>(getPropertyRef().getType());
1597 if (!resRefTy)
1598 return emitOpError("result must be a !moore.ref<T>");
1599
1600 Type expectedElemTy = fieldDecl.getPropertyType();
1601 if (resRefTy.getNestedType() != expectedElemTy)
1602 return emitOpError("result element type (")
1603 << resRefTy.getNestedType() << ") does not match field type ("
1604 << expectedElemTy << ")";
1605
1606 return success();
1607}
1608
1609LogicalResult
1610VTableLoadMethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1611 Operation *op = getOperation();
1612
1613 auto object = getObject();
1614 auto implSym = object.getType().getClassSym();
1615
1616 // Check that classdecl of class handle exists
1617 Operation *implOp = symbolTable.lookupNearestSymbolFrom(op, implSym);
1618 if (!implOp)
1619 return emitOpError() << "implementing class " << implSym << " not found";
1620 auto implClass = cast<moore::ClassDeclOp>(implOp);
1621
1622 StringAttr methodName = getMethodSymAttr().getLeafReference();
1623 if (!methodName || methodName.getValue().empty())
1624 return emitOpError() << "empty method name";
1625
1626 moore::ClassDeclOp cursor = implClass;
1627 Operation *methodDeclOp = nullptr;
1628
1629 // Find method in class decl or parents' class decl
1630 while (cursor && !methodDeclOp) {
1631 methodDeclOp = symbolTable.lookupSymbolIn(cursor, methodName);
1632 if (methodDeclOp)
1633 break;
1634 SymbolRefAttr baseSym = cursor.getBaseAttr();
1635 if (!baseSym)
1636 break;
1637 Operation *baseOp = symbolTable.lookupNearestSymbolFrom(op, baseSym);
1638 cursor = baseOp ? cast<moore::ClassDeclOp>(baseOp) : moore::ClassDeclOp();
1639 }
1640
1641 if (!methodDeclOp)
1642 return emitOpError() << "no method `" << methodName << "` found in "
1643 << implClass.getSymName() << " or its bases";
1644
1645 // Make sure method decl is a ClassMethodDeclOp
1646 auto methodDecl = dyn_cast<moore::ClassMethodDeclOp>(methodDeclOp);
1647 if (!methodDecl)
1648 return emitOpError() << "`" << methodName
1649 << "` is not a method declaration";
1650
1651 // Make sure method signature matches
1652 auto resFnTy = cast<FunctionType>(getResult().getType());
1653 auto declFnTy = cast<FunctionType>(methodDecl.getFunctionType());
1654 if (resFnTy != declFnTy)
1655 return emitOpError() << "result type " << resFnTy
1656 << " does not match method erased ABI " << declFnTy;
1657
1658 return success();
1659}
1660
1661LogicalResult VTableOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1662 Operation *self = getOperation();
1663
1664 // sym_name's root must be a ClassDeclOp
1665 SymbolRefAttr name = getSymNameAttr();
1666 if (!name)
1667 return emitOpError("requires 'sym_name' SymbolRefAttr");
1668
1669 // Root symbol must resolve (from the nearest symbol table) to a ClassDeclOp.
1670 Operation *rootDef = symbolTable.lookupNearestSymbolFrom(
1671 self, SymbolRefAttr::get(name.getRootReference()));
1672 if (!rootDef)
1673 return emitOpError() << "cannot resolve root class symbol '"
1674 << name.getRootReference() << "' for sym_name "
1675 << name;
1676
1677 if (!isa<ClassDeclOp>(rootDef))
1678 return emitOpError()
1679 << "root of sym_name must name a 'moore.class.classdecl', got "
1680 << name;
1681
1682 // All good.
1683 return success();
1684}
1685
1686LogicalResult VTableOp::verifyRegions() {
1687 // Ensure only allowed ops appear inside.
1688 for (Operation &op : getBody().front()) {
1689 if (!isa<VTableOp, VTableEntryOp>(op))
1690 return emitOpError(
1691 "body may only contain 'moore.vtable' or 'moore.vtable_entry' ops");
1692 }
1693 return mlir::success();
1694}
1695
/// Verify a vtable entry against three conditions:
/// - `target` must resolve to a `func.func`;
/// - the entry must be nested directly inside a `moore.vtable`;
/// - walking outward through the enclosing vtables, some class declaration
///   must bind a method with this entry's name to `target`, and no class
///   further out may then bind that method to a different implementation.
LogicalResult
VTableEntryOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  Operation *self = getOperation();

  // 'target' must exist and resolve from the top-level symbol table of a func
  // op
  SymbolRefAttr target = getTargetAttr();
  func::FuncOp def =
      symbolTable.lookupNearestSymbolFrom<func::FuncOp>(self, target);
  if (!def)
    return emitOpError()
           << "cannot resolve target symbol to a function operation " << target;

  // VTableEntries may only exist in VTables.
  if (!isa<VTableOp>(self->getParentOp()))
    return emitOpError("must be nested directly inside a 'moore.vtable' op");

  Operation *currentOp = self;
  VTableOp currentVTable;
  // Set once an enclosing vtable's class declares this method with `target`
  // as its implementation.
  bool defined = false;

  // Walk up the VTable tree and check whether the corresponding classDeclOp
  // declares a method with the same implementation. Further checks all the way
  // up the tree if another classdeclop overrides the implementation.
  // The entry is correct iff the impl matches the most derived classdeclop's
  // methoddeclop implementing the virtual method.
  while (auto parentOp = dyn_cast<VTableOp>(currentOp->getParentOp())) {
    currentOp = parentOp;
    currentVTable = cast<VTableOp>(currentOp);

    // Resolve the class named by the root of this vtable's symbol name.
    auto classSymName = currentVTable.getSymName();
    ClassDeclOp parentClassDecl =
        symbolTable.lookupNearestSymbolFrom<ClassDeclOp>(
            parentOp, classSymName.getRootReference());
    // NOTE(review): this relies on VTableOp::verifySymbolUses having rejected
    // roots that are not class declarations. With asserts disabled, a null
    // result would be dereferenced below.
    assert(parentClassDecl && "VTableOp must point to a classdeclop");

    for (auto method : parentClassDecl.getBody().getOps<ClassMethodDeclOp>()) {
      // A virtual interface declaration. Ignore.
      if (!method.getImpl())
        continue;

      // A matching definition.
      if (method.getSymName() == getName() && method.getImplAttr() == target)
        defined = true;

      // All definitions of the same method up the tree must be the same as the
      // current definition, there is no shadowing.
      // Hence, if we encounter a methoddeclop that has the same name but a
      // different implementation that means this vtableentry should point to
      // the op's implementation - that's an error.
      else if (method.getSymName() == getName() &&
               method.getImplAttr() != target && defined)
        return emitOpError() << "Target " << target
                             << " should be overridden by " << classSymName;
    }
  }
  if (!defined)
    return emitOpError()
           << "Parent class does not point to any implementation!";

  return success();
}
1758
1759//===----------------------------------------------------------------------===//
1760// TableGen generated logic.
1761//===----------------------------------------------------------------------===//
1762
1763// Provide the autogenerated implementation guts for the Op classes.
1764#define GET_OP_CLASSES
1765#include "circt/Dialect/Moore/Moore.cpp.inc"
1766#include "circt/Dialect/Moore/MooreEnums.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static std::unique_ptr< Context > context
@ Output
Definition HW.h:42
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static OpFoldResult powCommonFolding(MLIRContext *ctxt, Attribute lhs, Attribute rhs)
static ArrayRef< StructLikeMember > getStructMembers(Type type)
Definition MooreOps.cpp:808
static std::optional< uint32_t > getStructFieldIndex(Type type, StringAttr name)
Definition MooreOps.cpp:799
static UnpackedType getStructFieldType(Type type, StringAttr name)
Definition MooreOps.cpp:817
static std::pair< unsigned, UnpackedType > getArrayElements(Type type)
Definition MooreOps.cpp:766
static InstancePath empty
Four-valued arbitrary precision integers.
Definition FVInt.h:37
bool isNegative() const
Determine whether the integer interpreted as a signed number would be negative.
Definition FVInt.h:185
FVInt sext(unsigned bitWidth) const
Sign-extend the integer to a new bit width.
Definition FVInt.h:148
unsigned getSignificantBits() const
Compute the minimum bit width necessary to accurately represent this integer's value and sign.
Definition FVInt.h:102
static FVInt getAllX(unsigned numBits)
Construct an FVInt with all bits set to X.
Definition FVInt.h:75
bool hasUnknown() const
Determine if any bits are X or Z.
Definition FVInt.h:168
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition FVInt.h:92
unsigned getBitWidth() const
Return the number of bits this integer has.
Definition FVInt.h:85
FVInt trunc(unsigned bitWidth) const
Truncate the integer to a smaller bit width.
Definition FVInt.h:132
A packed SystemVerilog type.
Definition MooreTypes.h:153
std::optional< unsigned > getBitSize() const
Get the size of this type in bits.
Domain getDomain() const
Get the value domain of this type.
An unpacked SystemVerilog type.
Definition MooreTypes.h:101
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition HWOps.cpp:529
Domain
The number of values each bit of a type can assume.
Definition MooreTypes.h:49
RealWidth
The type of floating point / real number behind a RealType.
Definition MooreTypes.h:57
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
void printFVInt(AsmPrinter &p, const FVInt &value)
Print a four-valued integer using an AsmPrinter.
Definition FVInt.cpp:147
ParseResult parseFVInt(AsmParser &p, FVInt &result)
Parse a four-valued integer using an AsmParser.
Definition FVInt.cpp:162
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183