CIRCT 22.0.0git
Loading...
Searching...
No Matches
MooreOps.cpp
Go to the documentation of this file.
1//===- MooreOps.cpp - Implement the Moore operations ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Moore dialect operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Builders.h"
19#include "llvm/ADT/APSInt.h"
20#include "llvm/ADT/SmallString.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include <mlir/Dialect/Func/IR/FuncOps.h>
23
24using namespace circt;
25using namespace circt::moore;
26using namespace mlir;
27
28//===----------------------------------------------------------------------===//
29// SVModuleOp
30//===----------------------------------------------------------------------===//
31
32void SVModuleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
33 llvm::StringRef name, hw::ModuleType type) {
34 state.addAttribute(SymbolTable::getSymbolAttrName(),
35 builder.getStringAttr(name));
36 state.addAttribute(getModuleTypeAttrName(state.name), TypeAttr::get(type));
37 state.addRegion();
38}
39
40void SVModuleOp::print(OpAsmPrinter &p) {
41 p << " ";
42
43 // Print the visibility of the module.
44 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
45 if (auto visibility = (*this)->getAttrOfType<StringAttr>(visibilityAttrName))
46 p << visibility.getValue() << ' ';
47
48 p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
50 getModuleType(), {}, {});
51 p << " ";
52 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false,
53 /*printBlockTerminators=*/true);
54
55 p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(),
56 getAttributeNames());
57}
58
59ParseResult SVModuleOp::parse(OpAsmParser &parser, OperationState &result) {
60 // Parse the visibility attribute.
61 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
62
63 // Parse the module name.
64 StringAttr nameAttr;
65 if (parser.parseSymbolName(nameAttr, getSymNameAttrName(result.name),
66 result.attributes))
67 return failure();
68
69 // Parse the ports.
70 SmallVector<hw::module_like_impl::PortParse> ports;
71 TypeAttr modType;
72 if (failed(
73 hw::module_like_impl::parseModuleSignature(parser, ports, modType)))
74 return failure();
75 result.addAttribute(getModuleTypeAttrName(result.name), modType);
76
77 // Parse the attributes.
78 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
79 return failure();
80
81 // Add the entry block arguments.
82 SmallVector<OpAsmParser::Argument, 4> entryArgs;
83 for (auto &port : ports)
84 if (port.direction != hw::ModulePort::Direction::Output)
85 entryArgs.push_back(port);
86
87 // Parse the optional function body.
88 auto &bodyRegion = *result.addRegion();
89 if (parser.parseRegion(bodyRegion, entryArgs))
90 return failure();
91
92 ensureTerminator(bodyRegion, parser.getBuilder(), result.location);
93 return success();
94}
95
96void SVModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
97 mlir::OpAsmSetValueNameFn setNameFn) {
98 if (&region != &getBodyRegion())
99 return;
100 auto moduleType = getModuleType();
101 for (auto [index, arg] : llvm::enumerate(region.front().getArguments()))
102 setNameFn(arg, moduleType.getInputNameAttr(index));
103}
104
105OutputOp SVModuleOp::getOutputOp() {
106 return cast<OutputOp>(getBody()->getTerminator());
107}
108
109OperandRange SVModuleOp::getOutputs() { return getOutputOp().getOperands(); }
110
111//===----------------------------------------------------------------------===//
112// OutputOp
113//===----------------------------------------------------------------------===//
114
115LogicalResult OutputOp::verify() {
116 auto module = getParentOp();
117
118 // Check that the number of operands matches the number of output ports.
119 auto outputTypes = module.getModuleType().getOutputTypes();
120 if (outputTypes.size() != getNumOperands())
121 return emitOpError("has ")
122 << getNumOperands() << " operands, but enclosing module @"
123 << module.getSymName() << " has " << outputTypes.size()
124 << " outputs";
125
126 // Check that the operand types match the output ports.
127 for (unsigned i = 0, e = outputTypes.size(); i != e; ++i)
128 if (outputTypes[i] != getOperand(i).getType())
129 return emitOpError() << "operand " << i << " (" << getOperand(i).getType()
130 << ") does not match output type (" << outputTypes[i]
131 << ") of module @" << module.getSymName();
132
133 return success();
134}
135
136//===----------------------------------------------------------------------===//
137// InstanceOp
138//===----------------------------------------------------------------------===//
139
140LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
141 // Resolve the target symbol.
142 auto *symbol =
143 symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
144 if (!symbol)
145 return emitOpError("references unknown symbol @") << getModuleName();
146
147 // Check that the symbol is a SVModuleOp.
148 auto module = dyn_cast<SVModuleOp>(symbol);
149 if (!module)
150 return emitOpError("must reference a 'moore.module', but @")
151 << getModuleName() << " is a '" << symbol->getName() << "'";
152
153 // Check that the input ports match.
154 auto moduleType = module.getModuleType();
155 auto inputTypes = moduleType.getInputTypes();
156
157 if (inputTypes.size() != getNumOperands())
158 return emitOpError("has ")
159 << getNumOperands() << " operands, but target module @"
160 << module.getSymName() << " has " << inputTypes.size() << " inputs";
161
162 for (unsigned i = 0, e = inputTypes.size(); i != e; ++i)
163 if (inputTypes[i] != getOperand(i).getType())
164 return emitOpError() << "operand " << i << " (" << getOperand(i).getType()
165 << ") does not match input type (" << inputTypes[i]
166 << ") of module @" << module.getSymName();
167
168 // Check that the output ports match.
169 auto outputTypes = moduleType.getOutputTypes();
170
171 if (outputTypes.size() != getNumResults())
172 return emitOpError("has ")
173 << getNumOperands() << " results, but target module @"
174 << module.getSymName() << " has " << outputTypes.size()
175 << " outputs";
176
177 for (unsigned i = 0, e = outputTypes.size(); i != e; ++i)
178 if (outputTypes[i] != getResult(i).getType())
179 return emitOpError() << "result " << i << " (" << getResult(i).getType()
180 << ") does not match output type (" << outputTypes[i]
181 << ") of module @" << module.getSymName();
182
183 return success();
184}
185
186void InstanceOp::print(OpAsmPrinter &p) {
187 p << " ";
188 p.printAttributeWithoutType(getInstanceNameAttr());
189 p << " ";
190 p.printAttributeWithoutType(getModuleNameAttr());
191 printInputPortList(p, getOperation(), getInputs(), getInputs().getTypes(),
192 getInputNames());
193 p << " -> ";
194 printOutputPortList(p, getOperation(), getOutputs().getTypes(),
195 getOutputNames());
196 p.printOptionalAttrDict(getOperation()->getAttrs(), getAttributeNames());
197}
198
199ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
200 // Parse the instance name.
201 StringAttr instanceName;
202 if (parser.parseAttribute(instanceName, "instanceName", result.attributes))
203 return failure();
204
205 // Parse the module name.
206 FlatSymbolRefAttr moduleName;
207 if (parser.parseAttribute(moduleName, "moduleName", result.attributes))
208 return failure();
209
210 // Parse the input port list.
211 auto loc = parser.getCurrentLocation();
212 SmallVector<OpAsmParser::UnresolvedOperand> inputs;
213 SmallVector<Type> types;
214 ArrayAttr names;
215 if (parseInputPortList(parser, inputs, types, names))
216 return failure();
217 if (parser.resolveOperands(inputs, types, loc, result.operands))
218 return failure();
219 result.addAttribute("inputNames", names);
220
221 // Parse `->`.
222 if (parser.parseArrow())
223 return failure();
224
225 // Parse the output port list.
226 types.clear();
227 if (parseOutputPortList(parser, types, names))
228 return failure();
229 result.addAttribute("outputNames", names);
230 result.addTypes(types);
231
232 // Parse the attributes.
233 if (parser.parseOptionalAttrDict(result.attributes))
234 return failure();
235
236 return success();
237}
238
239void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
240 SmallString<32> name;
241 name += getInstanceName();
242 name += '.';
243 auto baseLen = name.size();
244
245 for (auto [result, portName] :
246 llvm::zip(getOutputs(), getOutputNames().getAsRange<StringAttr>())) {
247 if (!portName || portName.empty())
248 continue;
249 name.resize(baseLen);
250 name += portName.getValue();
251 setNameFn(result, name);
252 }
253}
254
255//===----------------------------------------------------------------------===//
256// VariableOp
257//===----------------------------------------------------------------------===//
258
259void VariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
260 if (getName() && !getName()->empty())
261 setNameFn(getResult(), *getName());
262}
263
264LogicalResult VariableOp::canonicalize(VariableOp op,
265 PatternRewriter &rewriter) {
266 // If the variable is embedded in an SSACFG region, move the initial value
267 // into an assignment immediately after the variable op. This allows the
268 // mem2reg pass which cannot handle variables with initial values.
269 auto initial = op.getInitial();
270 if (initial && mlir::mayHaveSSADominance(*op->getParentRegion())) {
271 rewriter.modifyOpInPlace(op, [&] { op.getInitialMutable().clear(); });
272 rewriter.setInsertionPointAfter(op);
273 BlockingAssignOp::create(rewriter, initial.getLoc(), op, initial);
274 return success();
275 }
276
277 // Check if the variable has one unique continuous assignment to it, all other
278 // uses are reads, and that all uses are in the same block as the variable
279 // itself.
280 auto *block = op->getBlock();
281 ContinuousAssignOp uniqueAssignOp;
282 for (auto *user : op->getUsers()) {
283 // Ensure that all users of the variable are in the same block.
284 if (user->getBlock() != block)
285 return failure();
286
287 // Ensure there is at most one unique continuous assignment to the variable.
288 if (auto assignOp = dyn_cast<ContinuousAssignOp>(user)) {
289 if (uniqueAssignOp)
290 return failure();
291 uniqueAssignOp = assignOp;
292 continue;
293 }
294
295 // Ensure all other users are reads.
296 if (!isa<ReadOp>(user))
297 return failure();
298 }
299 if (!uniqueAssignOp)
300 return failure();
301
302 // If the original variable had a name, create an `AssignedVariableOp` as a
303 // replacement. Otherwise substitute the assigned value directly.
304 Value assignedValue = uniqueAssignOp.getSrc();
305 if (auto name = op.getNameAttr(); name && !name.empty())
306 assignedValue = AssignedVariableOp::create(rewriter, op.getLoc(), name,
307 uniqueAssignOp.getSrc());
308
309 // Remove the assign op and replace all reads with the new assigned var op.
310 rewriter.eraseOp(uniqueAssignOp);
311 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
312 auto readOp = cast<ReadOp>(user);
313 rewriter.replaceOp(readOp, assignedValue);
314 }
315
316 // Remove the original variable.
317 rewriter.eraseOp(op);
318 return success();
319}
320
321SmallVector<MemorySlot> VariableOp::getPromotableSlots() {
322 // We cannot promote variables with an initial value, since that value may not
323 // dominate the location where the default value needs to be constructed.
324 if (mlir::mayBeGraphRegion(*getOperation()->getParentRegion()) ||
325 getInitial())
326 return {};
327
328 // Ensure that `getDefaultValue` can conjure up a default value for the
329 // variable's type.
330 if (!isa<PackedType>(getType().getNestedType()))
331 return {};
332
333 return {MemorySlot{getResult(), getType().getNestedType()}};
334}
335
336Value VariableOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) {
337 auto packedType = dyn_cast<PackedType>(slot.elemType);
338 if (!packedType)
339 return {};
340 auto bitWidth = packedType.getBitSize();
341 if (!bitWidth)
342 return {};
343 auto fvint = packedType.getDomain() == Domain::FourValued
344 ? FVInt::getAllX(*bitWidth)
345 : FVInt::getZero(*bitWidth);
346 Value value = ConstantOp::create(
347 builder, getLoc(),
348 IntType::get(getContext(), *bitWidth, packedType.getDomain()), fvint);
349 if (value.getType() != packedType)
350 ConversionOp::create(builder, getLoc(), packedType, value);
351 return value;
352}
353
354void VariableOp::handleBlockArgument(const MemorySlot &slot,
355 BlockArgument argument,
356 OpBuilder &builder) {}
357
358std::optional<mlir::PromotableAllocationOpInterface>
359VariableOp::handlePromotionComplete(const MemorySlot &slot, Value defaultValue,
360 OpBuilder &builder) {
361 if (defaultValue && defaultValue.use_empty())
362 defaultValue.getDefiningOp()->erase();
363 this->erase();
364 return {};
365}
366
367SmallVector<DestructurableMemorySlot> VariableOp::getDestructurableSlots() {
368 if (isa<SVModuleOp>(getOperation()->getParentOp()))
369 return {};
370 if (getInitial())
371 return {};
372
373 auto refType = getType();
374 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(refType);
375 if (!destructurable)
376 return {};
377
378 auto destructuredType = destructurable.getSubelementIndexMap();
379 if (!destructuredType)
380 return {};
381
382 return {DestructurableMemorySlot{{getResult(), refType}, *destructuredType}};
383}
384
385DenseMap<Attribute, MemorySlot> VariableOp::destructure(
386 const DestructurableMemorySlot &slot,
387 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
388 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
389 assert(slot.ptr == getResult());
390 assert(!getInitial());
391 builder.setInsertionPointAfter(*this);
392
393 auto destructurableType = cast<DestructurableTypeInterface>(getType());
394 DenseMap<Attribute, MemorySlot> slotMap;
395 for (Attribute index : usedIndices) {
396 auto elemType = cast<RefType>(destructurableType.getTypeAtIndex(index));
397 assert(elemType && "used index must exist");
398 StringAttr varName;
399 if (auto name = getName(); name && !name->empty())
400 varName = StringAttr::get(
401 getContext(), (*name) + "." + cast<StringAttr>(index).getValue());
402 auto varOp =
403 VariableOp::create(builder, getLoc(), elemType, varName, Value());
404 newAllocators.push_back(varOp);
405 slotMap.try_emplace<MemorySlot>(index, {varOp.getResult(), elemType});
406 }
407
408 return slotMap;
409}
410
411std::optional<DestructurableAllocationOpInterface>
412VariableOp::handleDestructuringComplete(const DestructurableMemorySlot &slot,
413 OpBuilder &builder) {
414 assert(slot.ptr == getResult());
415 this->erase();
416 return std::nullopt;
417}
418
419//===----------------------------------------------------------------------===//
420// NetOp
421//===----------------------------------------------------------------------===//
422
423void NetOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
424 if (getName() && !getName()->empty())
425 setNameFn(getResult(), *getName());
426}
427
428LogicalResult NetOp::canonicalize(NetOp op, PatternRewriter &rewriter) {
429 bool modified = false;
430
431 // Check if the net has one unique continuous assignment to it, and
432 // additionally if all other users are reads.
433 auto *block = op->getBlock();
434 ContinuousAssignOp uniqueAssignOp;
435 bool allUsesAreReads = true;
436 for (auto *user : op->getUsers()) {
437 // Ensure that all users of the net are in the same block.
438 if (user->getBlock() != block)
439 return failure();
440
441 // Ensure there is at most one unique continuous assignment to the net.
442 if (auto assignOp = dyn_cast<ContinuousAssignOp>(user)) {
443 if (uniqueAssignOp)
444 return failure();
445 uniqueAssignOp = assignOp;
446 continue;
447 }
448
449 // Ensure all other users are reads.
450 if (!isa<ReadOp>(user))
451 allUsesAreReads = false;
452 }
453
454 // If there was one unique assignment, and the `NetOp` does not yet have an
455 // assigned value set, fold the assignment into the net.
456 if (uniqueAssignOp && !op.getAssignment()) {
457 rewriter.modifyOpInPlace(
458 op, [&] { op.getAssignmentMutable().assign(uniqueAssignOp.getSrc()); });
459 rewriter.eraseOp(uniqueAssignOp);
460 modified = true;
461 uniqueAssignOp = {};
462 }
463
464 // If all users of the net op are reads, and any potential unique assignment
465 // has been folded into the net op itself, directly replace the reads with the
466 // net's assigned value.
467 if (!uniqueAssignOp && allUsesAreReads && op.getAssignment()) {
468 // If the original net had a name, create an `AssignedVariableOp` as a
469 // replacement. Otherwise substitute the assigned value directly.
470 auto assignedValue = op.getAssignment();
471 if (auto name = op.getNameAttr(); name && !name.empty())
472 assignedValue = AssignedVariableOp::create(rewriter, op.getLoc(), name,
473 assignedValue);
474
475 // Replace all reads with the new assigned var op and remove the original
476 // net op.
477 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
478 auto readOp = cast<ReadOp>(user);
479 rewriter.replaceOp(readOp, assignedValue);
480 }
481 rewriter.eraseOp(op);
482 modified = true;
483 }
484
485 return success(modified);
486}
487
488//===----------------------------------------------------------------------===//
489// AssignedVariableOp
490//===----------------------------------------------------------------------===//
491
492void AssignedVariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
493 if (getName() && !getName()->empty())
494 setNameFn(getResult(), *getName());
495}
496
497LogicalResult AssignedVariableOp::canonicalize(AssignedVariableOp op,
498 PatternRewriter &rewriter) {
499 // Eliminate chained variables with the same name.
500 // var(name, var(name, x)) -> var(name, x)
501 if (auto otherOp = op.getInput().getDefiningOp<AssignedVariableOp>()) {
502 if (otherOp.getNameAttr() == op.getNameAttr()) {
503 rewriter.replaceOp(op, otherOp);
504 return success();
505 }
506 }
507
508 // Eliminate variables that alias an input port of the same name.
509 if (auto blockArg = dyn_cast<BlockArgument>(op.getInput())) {
510 if (auto moduleOp =
511 dyn_cast<SVModuleOp>(blockArg.getOwner()->getParentOp())) {
512 auto moduleType = moduleOp.getModuleType();
513 auto portName = moduleType.getInputNameAttr(blockArg.getArgNumber());
514 if (portName == op.getNameAttr()) {
515 rewriter.replaceOp(op, blockArg);
516 return success();
517 }
518 }
519 }
520
521 // Eliminate variables that feed an output port of the same name.
522 for (auto &use : op->getUses()) {
523 auto *useOwner = use.getOwner();
524 if (auto outputOp = dyn_cast<OutputOp>(useOwner)) {
525 if (auto moduleOp = dyn_cast<SVModuleOp>(outputOp->getParentOp())) {
526 auto moduleType = moduleOp.getModuleType();
527 auto portName = moduleType.getOutputNameAttr(use.getOperandNumber());
528 if (portName == op.getNameAttr()) {
529 rewriter.replaceOp(op, op.getInput());
530 return success();
531 }
532 } else
533 break;
534 }
535 }
536
537 return failure();
538}
539
540//===----------------------------------------------------------------------===//
541// GlobalVariableOp
542//===----------------------------------------------------------------------===//
543
544LogicalResult GlobalVariableOp::verifyRegions() {
545 if (auto *block = getInitBlock()) {
546 auto &terminator = block->back();
547 if (!isa<YieldOp>(terminator))
548 return emitOpError() << "must have a 'moore.yield' terminator";
549 }
550 return success();
551}
552
553Block *GlobalVariableOp::getInitBlock() {
554 if (getInitRegion().empty())
555 return nullptr;
556 return &getInitRegion().front();
557}
558
559//===----------------------------------------------------------------------===//
560// GetGlobalVariableOp
561//===----------------------------------------------------------------------===//
562
563LogicalResult
564GetGlobalVariableOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
565 // Resolve the target symbol.
566 auto *symbol =
567 symbolTable.lookupNearestSymbolFrom(*this, getGlobalNameAttr());
568 if (!symbol)
569 return emitOpError() << "references unknown symbol " << getGlobalNameAttr();
570
571 // Check that the symbol is a global variable.
572 auto var = dyn_cast<GlobalVariableOp>(symbol);
573 if (!var)
574 return emitOpError() << "must reference a 'moore.global_variable', but "
575 << getGlobalNameAttr() << " is a '"
576 << symbol->getName() << "'";
577
578 // Check that the types match.
579 auto expType = var.getType();
580 auto actType = getType().getNestedType();
581 if (expType != actType)
582 return emitOpError() << "returns a " << actType << " reference, but "
583 << getGlobalNameAttr() << " is of type " << expType;
584
585 return success();
586}
587
588//===----------------------------------------------------------------------===//
589// ConstantOp
590//===----------------------------------------------------------------------===//
591
592void ConstantOp::print(OpAsmPrinter &p) {
593 p << " ";
594 printFVInt(p, getValue());
595 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
596 p << " : ";
597 p.printStrippedAttrOrType(getType());
598}
599
600ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
601 // Parse the constant value.
602 FVInt value;
603 auto valueLoc = parser.getCurrentLocation();
604 if (parseFVInt(parser, value))
605 return failure();
606
607 // Parse any optional attributes and the `:`.
608 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
609 return failure();
610
611 // Parse the result type.
612 IntType type;
613 if (parser.parseCustomTypeWithFallback(type))
614 return failure();
615
616 // Extend or truncate the constant value to match the size of the type.
617 if (type.getWidth() > value.getBitWidth()) {
618 // sext is always safe here, even for unsigned values, because the
619 // parseOptionalInteger method will return something with a zero in the
620 // top bits if it is a positive number.
621 value = value.sext(type.getWidth());
622 } else if (type.getWidth() < value.getBitWidth()) {
623 // The parser can return an unnecessarily wide result with leading
624 // zeros. This isn't a problem, but truncating off bits is bad.
625 unsigned neededBits =
626 value.isNegative() ? value.getSignificantBits() : value.getActiveBits();
627 if (type.getWidth() < neededBits)
628 return parser.emitError(valueLoc)
629 << "value requires " << neededBits
630 << " bits, but result type only has " << type.getWidth();
631 value = value.trunc(type.getWidth());
632 }
633
634 // If the constant contains any X or Z bits, the result type must be
635 // four-valued.
636 if (value.hasUnknown() && type.getDomain() != Domain::FourValued)
637 return parser.emitError(valueLoc)
638 << "value contains X or Z bits, but result type " << type
639 << " only allows two-valued bits";
640
641 // Build the attribute and op.
642 auto attrValue = FVIntegerAttr::get(parser.getContext(), value);
643 result.addAttribute("value", attrValue);
644 result.addTypes(type);
645 return success();
646}
647
648LogicalResult ConstantOp::verify() {
649 auto attrWidth = getValue().getBitWidth();
650 auto typeWidth = getType().getWidth();
651 if (attrWidth != typeWidth)
652 return emitError("attribute width ")
653 << attrWidth << " does not match return type's width " << typeWidth;
654 return success();
655}
656
657void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
658 const FVInt &value) {
659 assert(type.getWidth() == value.getBitWidth() &&
660 "FVInt width must match type width");
661 build(builder, result, type, FVIntegerAttr::get(builder.getContext(), value));
662}
663
664void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
665 const APInt &value) {
666 assert(type.getWidth() == value.getBitWidth() &&
667 "APInt width must match type width");
668 build(builder, result, type, FVInt(value));
669}
670
671/// This builder allows construction of small signed integers like 0, 1, -1
672/// matching a specified MLIR type. This shouldn't be used for general constant
673/// folding because it only works with values that can be expressed in an
674/// `int64_t`.
675void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
676 int64_t value, bool isSigned) {
677 build(builder, result, type,
678 APInt(type.getWidth(), (uint64_t)value, isSigned));
679}
680
681OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
682 assert(adaptor.getOperands().empty() && "constant has no operands");
683 return getValueAttr();
684}
685
686//===----------------------------------------------------------------------===//
687// ConstantTimeOp
688//===----------------------------------------------------------------------===//
689
690OpFoldResult ConstantTimeOp::fold(FoldAdaptor adaptor) {
691 return getValueAttr();
692}
693
694//===----------------------------------------------------------------------===//
695// ConstantRealOp
696//===----------------------------------------------------------------------===//
697
698LogicalResult ConstantRealOp::inferReturnTypes(
699 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
700 DictionaryAttr attrs, mlir::OpaqueProperties properties,
701 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
702 ConstantRealOp::Adaptor adaptor(operands, attrs, properties);
703 results.push_back(RealType::get(
704 context, static_cast<RealWidth>(
705 adaptor.getValueAttr().getType().getIntOrFloatBitWidth())));
706 return success();
707}
708
709//===----------------------------------------------------------------------===//
710// ConcatOp
711//===----------------------------------------------------------------------===//
712
713LogicalResult ConcatOp::inferReturnTypes(
714 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
715 DictionaryAttr attrs, mlir::OpaqueProperties properties,
716 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
717 Domain domain = Domain::TwoValued;
718 unsigned width = 0;
719 for (auto operand : operands) {
720 auto type = cast<IntType>(operand.getType());
721 if (type.getDomain() == Domain::FourValued)
722 domain = Domain::FourValued;
723 width += type.getWidth();
724 }
725 results.push_back(IntType::get(context, width, domain));
726 return success();
727}
728
729//===----------------------------------------------------------------------===//
730// ConcatRefOp
731//===----------------------------------------------------------------------===//
732
733LogicalResult ConcatRefOp::inferReturnTypes(
734 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
735 DictionaryAttr attrs, mlir::OpaqueProperties properties,
736 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
737 Domain domain = Domain::TwoValued;
738 unsigned width = 0;
739 for (Value operand : operands) {
740 UnpackedType nestedType = cast<RefType>(operand.getType()).getNestedType();
741 PackedType packedType = dyn_cast<PackedType>(nestedType);
742
743 if (!packedType) {
744 return failure();
745 }
746
747 if (packedType.getDomain() == Domain::FourValued)
748 domain = Domain::FourValued;
749
750 // getBitSize() for PackedType returns an optional, so we must check it.
751 std::optional<int> bitSize = packedType.getBitSize();
752 if (!bitSize) {
753 return failure();
754 }
755 width += *bitSize;
756 }
757 results.push_back(RefType::get(IntType::get(context, width, domain)));
758 return success();
759}
760
761//===----------------------------------------------------------------------===//
762// ArrayCreateOp
763//===----------------------------------------------------------------------===//
764
765static std::pair<unsigned, UnpackedType> getArrayElements(Type type) {
766 if (auto arrayType = dyn_cast<ArrayType>(type))
767 return {arrayType.getSize(), arrayType.getElementType()};
768 if (auto arrayType = dyn_cast<UnpackedArrayType>(type))
769 return {arrayType.getSize(), arrayType.getElementType()};
770 assert(0 && "expected ArrayType or UnpackedArrayType");
771 return {};
772}
773
774LogicalResult ArrayCreateOp::verify() {
775 auto [size, elementType] = getArrayElements(getType());
776
777 // Check that the number of operands matches the array size.
778 if (getElements().size() != size)
779 return emitOpError() << "has " << getElements().size()
780 << " operands, but result type requires " << size;
781
782 // Check that the operand types match the array element type. We only need to
783 // check one of the operands, since the `SameTypeOperands` trait ensures all
784 // operands have the same type.
785 if (size > 0) {
786 auto value = getElements()[0];
787 if (value.getType() != elementType)
788 return emitOpError() << "operands have type " << value.getType()
789 << ", but array requires " << elementType;
790 }
791 return success();
792}
793
794//===----------------------------------------------------------------------===//
795// StructCreateOp
796//===----------------------------------------------------------------------===//
797
798static std::optional<uint32_t> getStructFieldIndex(Type type, StringAttr name) {
799 if (auto structType = dyn_cast<StructType>(type))
800 return structType.getFieldIndex(name);
801 if (auto structType = dyn_cast<UnpackedStructType>(type))
802 return structType.getFieldIndex(name);
803 assert(0 && "expected StructType or UnpackedStructType");
804 return {};
805}
806
807static ArrayRef<StructLikeMember> getStructMembers(Type type) {
808 if (auto structType = dyn_cast<StructType>(type))
809 return structType.getMembers();
810 if (auto structType = dyn_cast<UnpackedStructType>(type))
811 return structType.getMembers();
812 assert(0 && "expected StructType or UnpackedStructType");
813 return {};
814}
815
816static UnpackedType getStructFieldType(Type type, StringAttr name) {
817 if (auto index = getStructFieldIndex(type, name))
818 return getStructMembers(type)[*index].type;
819 return {};
820}
821
822LogicalResult StructCreateOp::verify() {
823 auto members = getStructMembers(getType());
824
825 // Check that the number of operands matches the number of struct fields.
826 if (getFields().size() != members.size())
827 return emitOpError() << "has " << getFields().size()
828 << " operands, but result type requires "
829 << members.size();
830
831 // Check that the operand types match the struct field types.
832 for (auto [index, pair] : llvm::enumerate(llvm::zip(getFields(), members))) {
833 auto [value, member] = pair;
834 if (value.getType() != member.type)
835 return emitOpError() << "operand #" << index << " has type "
836 << value.getType() << ", but struct field "
837 << member.name << " requires " << member.type;
838 }
839 return success();
840}
841
842OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
843 SmallVector<NamedAttribute> fields;
844 for (auto [member, field] :
845 llvm::zip(getStructMembers(getType()), adaptor.getFields())) {
846 if (!field)
847 return {};
848 fields.push_back(NamedAttribute(member.name, field));
849 }
850 return DictionaryAttr::get(getContext(), fields);
851}
852
853//===----------------------------------------------------------------------===//
854// StructExtractOp
855//===----------------------------------------------------------------------===//
856
857LogicalResult StructExtractOp::verify() {
858 auto type = getStructFieldType(getInput().getType(), getFieldNameAttr());
859 if (!type)
860 return emitOpError() << "extracts field " << getFieldNameAttr()
861 << " which does not exist in " << getInput().getType();
862 if (type != getType())
863 return emitOpError() << "result type " << getType()
864 << " must match struct field type " << type;
865 return success();
866}
867
868OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
869 // Extract on a constant struct input.
870 if (auto fields = dyn_cast_or_null<DictionaryAttr>(adaptor.getInput()))
871 if (auto value = fields.get(getFieldNameAttr()))
872 return value;
873
874 // extract(inject(s, "field", v), "field") -> v
875 if (auto inject = getInput().getDefiningOp<StructInjectOp>()) {
876 if (inject.getFieldNameAttr() == getFieldNameAttr())
877 return inject.getNewValue();
878 return {};
879 }
880
881 // extract(create({"field": v, ...}), "field") -> v
882 if (auto create = getInput().getDefiningOp<StructCreateOp>()) {
883 if (auto index = getStructFieldIndex(create.getType(), getFieldNameAttr()))
884 return create.getFields()[*index];
885 return {};
886 }
887
888 return {};
889}
890
891//===----------------------------------------------------------------------===//
892// StructExtractRefOp
893//===----------------------------------------------------------------------===//
894
895LogicalResult StructExtractRefOp::verify() {
896 auto type = getStructFieldType(
897 cast<RefType>(getInput().getType()).getNestedType(), getFieldNameAttr());
898 if (!type)
899 return emitOpError() << "extracts field " << getFieldNameAttr()
900 << " which does not exist in " << getInput().getType();
901 if (type != getType().getNestedType())
902 return emitOpError() << "result ref of type " << getType().getNestedType()
903 << " must match struct field type " << type;
904 return success();
905}
906
907bool StructExtractRefOp::canRewire(
908 const DestructurableMemorySlot &slot,
909 SmallPtrSetImpl<Attribute> &usedIndices,
910 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
911 const DataLayout &dataLayout) {
912 if (slot.ptr != getInput())
913 return false;
914 auto index = getFieldNameAttr();
915 if (!index || !slot.subelementTypes.contains(index))
916 return false;
917 usedIndices.insert(index);
918 return true;
919}
920
921DeletionKind
922StructExtractRefOp::rewire(const DestructurableMemorySlot &slot,
923 DenseMap<Attribute, MemorySlot> &subslots,
924 OpBuilder &builder, const DataLayout &dataLayout) {
925 auto index = getFieldNameAttr();
926 const MemorySlot &memorySlot = subslots.at(index);
927 replaceAllUsesWith(memorySlot.ptr);
928 getInputMutable().drop();
929 erase();
930 return DeletionKind::Keep;
931}
932
933//===----------------------------------------------------------------------===//
934// StructInjectOp
935//===----------------------------------------------------------------------===//
936
937LogicalResult StructInjectOp::verify() {
938 auto type = getStructFieldType(getInput().getType(), getFieldNameAttr());
939 if (!type)
940 return emitOpError() << "injects field " << getFieldNameAttr()
941 << " which does not exist in " << getInput().getType();
942 if (type != getNewValue().getType())
943 return emitOpError() << "injected value " << getNewValue().getType()
944 << " must match struct field type " << type;
945 return success();
946}
947
948OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
949 auto input = adaptor.getInput();
950 auto newValue = adaptor.getNewValue();
951 if (!input || !newValue)
952 return {};
953 NamedAttrList fields(cast<DictionaryAttr>(input));
954 fields.set(getFieldNameAttr(), newValue);
955 return fields.getDictionary(getContext());
956}
957
958LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
959 PatternRewriter &rewriter) {
960 auto members = getStructMembers(op.getType());
961
962 // Chase a chain of `struct_inject` ops, with an optional final
963 // `struct_create`, and take note of the values assigned to each field.
964 SmallPtrSet<Operation *, 4> injectOps;
965 DenseMap<StringAttr, Value> fieldValues;
966 Value input = op;
967 while (auto injectOp = input.getDefiningOp<StructInjectOp>()) {
968 if (!injectOps.insert(injectOp).second)
969 return failure();
970 fieldValues.insert({injectOp.getFieldNameAttr(), injectOp.getNewValue()});
971 input = injectOp.getInput();
972 }
973 if (auto createOp = input.getDefiningOp<StructCreateOp>())
974 for (auto [value, member] : llvm::zip(createOp.getFields(), members))
975 fieldValues.insert({member.name, value});
976
977 // If the inject chain sets all fields, canonicalize to a `struct_create`.
978 if (fieldValues.size() == members.size()) {
979 SmallVector<Value> values;
980 values.reserve(fieldValues.size());
981 for (auto member : members)
982 values.push_back(fieldValues.lookup(member.name));
983 rewriter.replaceOpWithNewOp<StructCreateOp>(op, op.getType(), values);
984 return success();
985 }
986
987 // If each inject op in the chain assigned to a unique field, there is nothing
988 // to canonicalize.
989 if (injectOps.size() == fieldValues.size())
990 return failure();
991
992 // Otherwise we can eliminate overwrites by creating new injects. The hash map
993 // of field values contains the last assigned value for each field.
994 for (auto member : members)
995 if (auto value = fieldValues.lookup(member.name))
996 input = StructInjectOp::create(rewriter, op.getLoc(), op.getType(), input,
997 member.name, value);
998 rewriter.replaceOp(op, input);
999 return success();
1000}
1001
1002//===----------------------------------------------------------------------===//
1003// UnionCreateOp
1004//===----------------------------------------------------------------------===//
1005
1006LogicalResult UnionCreateOp::verify() {
1007 /// checks if the types of the input is exactly equal to the union field
1008 /// type
1009 return TypeSwitch<Type, LogicalResult>(getType())
1010 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
1011 auto members = type.getMembers();
1012 auto resultType = getType();
1013 auto fieldName = getFieldName();
1014 for (const auto &member : members)
1015 if (member.name == fieldName && member.type == resultType)
1016 return success();
1017 emitOpError("input type must match the union field type");
1018 return failure();
1019 })
1020 .Default([this](auto &) {
1021 emitOpError("input type must be UnionType or UnpackedUnionType");
1022 return failure();
1023 });
1024}
1025
1026//===----------------------------------------------------------------------===//
1027// UnionExtractOp
1028//===----------------------------------------------------------------------===//
1029
1030LogicalResult UnionExtractOp::verify() {
1031 /// checks if the types of the input is exactly equal to the one of the
1032 /// types of the result union fields
1033 return TypeSwitch<Type, LogicalResult>(getInput().getType())
1034 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
1035 auto members = type.getMembers();
1036 auto fieldName = getFieldName();
1037 auto resultType = getType();
1038 for (const auto &member : members)
1039 if (member.name == fieldName && member.type == resultType)
1040 return success();
1041 emitOpError("result type must match the union field type");
1042 return failure();
1043 })
1044 .Default([this](auto &) {
1045 emitOpError("input type must be UnionType or UnpackedUnionType");
1046 return failure();
1047 });
1048}
1049
1050//===----------------------------------------------------------------------===//
1051// UnionExtractRefOp
1052//===----------------------------------------------------------------------===//
1053
1054LogicalResult UnionExtractRefOp::verify() {
1055 /// checks if the types of the result is exactly equal to the type of the
1056 /// refe union field
1057 return TypeSwitch<Type, LogicalResult>(getInput().getType().getNestedType())
1058 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
1059 auto members = type.getMembers();
1060 auto fieldName = getFieldName();
1061 auto resultType = getType().getNestedType();
1062 for (const auto &member : members)
1063 if (member.name == fieldName && member.type == resultType)
1064 return success();
1065 emitOpError("result type must match the union field type");
1066 return failure();
1067 })
1068 .Default([this](auto &) {
1069 emitOpError("input type must be UnionType or UnpackedUnionType");
1070 return failure();
1071 });
1072}
1073
1074//===----------------------------------------------------------------------===//
1075// YieldOp
1076//===----------------------------------------------------------------------===//
1077
1078LogicalResult YieldOp::verify() {
1079 Type expType;
1080 auto *parentOp = getOperation()->getParentOp();
1081 if (auto cond = dyn_cast<ConditionalOp>(parentOp)) {
1082 expType = cond.getType();
1083 } else if (auto varOp = dyn_cast<GlobalVariableOp>(parentOp)) {
1084 expType = varOp.getType();
1085 } else {
1086 llvm_unreachable("all in ParentOneOf handled");
1087 }
1088
1089 auto actType = getOperand().getType();
1090 if (expType != actType) {
1091 return emitOpError() << "yields " << actType << ", but parent expects "
1092 << expType;
1093 }
1094 return success();
1095}
1096
1097//===----------------------------------------------------------------------===//
1098// ConversionOp
1099//===----------------------------------------------------------------------===//
1100
1101OpFoldResult ConversionOp::fold(FoldAdaptor adaptor) {
1102 // Fold away no-op casts.
1103 if (getInput().getType() == getResult().getType())
1104 return getInput();
1105
1106 // Convert domains of constant integer inputs.
1107 auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput());
1108 auto fromIntType = dyn_cast<IntType>(getInput().getType());
1109 auto toIntType = dyn_cast<IntType>(getResult().getType());
1110 if (intInput && fromIntType && toIntType &&
1111 fromIntType.getWidth() == toIntType.getWidth()) {
1112 // If we are going *to* a four-valued type, simply pass through the
1113 // constant.
1114 if (toIntType.getDomain() == Domain::FourValued)
1115 return intInput;
1116
1117 // Otherwise map all unknown bits to zero (the default in SystemVerilog) and
1118 // return a new constant.
1119 return FVIntegerAttr::get(getContext(), intInput.getValue().toAPInt(false));
1120 }
1121
1122 return {};
1123}
1124
1125//===----------------------------------------------------------------------===//
1126// LogicToIntOp
1127//===----------------------------------------------------------------------===//
1128
1129OpFoldResult LogicToIntOp::fold(FoldAdaptor adaptor) {
1130 // logic_to_int(int_to_logic(x)) -> x
1131 if (auto reverseOp = getInput().getDefiningOp<IntToLogicOp>())
1132 return reverseOp.getInput();
1133
1134 // Map all unknown bits to zero (the default in SystemVerilog) and return a
1135 // new constant.
1136 if (auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1137 return FVIntegerAttr::get(getContext(), intInput.getValue().toAPInt(false));
1138
1139 return {};
1140}
1141
1142//===----------------------------------------------------------------------===//
1143// IntToLogicOp
1144//===----------------------------------------------------------------------===//
1145
1146OpFoldResult IntToLogicOp::fold(FoldAdaptor adaptor) {
1147 // Cannot fold int_to_logic(logic_to_int(x)) -> x since that would lose
1148 // information.
1149
1150 // Simply pass through constants.
1151 if (auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1152 return intInput;
1153
1154 return {};
1155}
1156
1157//===----------------------------------------------------------------------===//
1158// TimeToLogicOp
1159//===----------------------------------------------------------------------===//
1160
1161OpFoldResult TimeToLogicOp::fold(FoldAdaptor adaptor) {
1162 // time_to_logic(logic_to_time(x)) -> x
1163 if (auto reverseOp = getInput().getDefiningOp<LogicToTimeOp>())
1164 return reverseOp.getInput();
1165
1166 // Convert constants.
1167 if (auto attr = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
1168 return FVIntegerAttr::get(getContext(), attr.getValue());
1169
1170 return {};
1171}
1172
1173//===----------------------------------------------------------------------===//
1174// LogicToTimeOp
1175//===----------------------------------------------------------------------===//
1176
1177OpFoldResult LogicToTimeOp::fold(FoldAdaptor adaptor) {
1178 // logic_to_time(time_to_logic(x)) -> x
1179 if (auto reverseOp = getInput().getDefiningOp<TimeToLogicOp>())
1180 return reverseOp.getInput();
1181
1182 // Convert constants.
1183 if (auto attr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1184 return IntegerAttr::get(getContext(), APSInt(attr.getValue().toAPInt(false),
1185 /*isUnsigned=*/true));
1186
1187 return {};
1188}
1189
1190//===----------------------------------------------------------------------===//
1191// TruncOp
1192//===----------------------------------------------------------------------===//
1193
1194OpFoldResult TruncOp::fold(FoldAdaptor adaptor) {
1195 // Truncate constants.
1196 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1197 auto width = getType().getWidth();
1198 return FVIntegerAttr::get(getContext(), intAttr.getValue().trunc(width));
1199 }
1200
1201 return {};
1202}
1203
1204//===----------------------------------------------------------------------===//
1205// ZExtOp
1206//===----------------------------------------------------------------------===//
1207
1208OpFoldResult ZExtOp::fold(FoldAdaptor adaptor) {
1209 // Zero-extend constants.
1210 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1211 auto width = getType().getWidth();
1212 return FVIntegerAttr::get(getContext(), intAttr.getValue().zext(width));
1213 }
1214
1215 return {};
1216}
1217
1218//===----------------------------------------------------------------------===//
1219// SExtOp
1220//===----------------------------------------------------------------------===//
1221
1222OpFoldResult SExtOp::fold(FoldAdaptor adaptor) {
1223 // Sign-extend constants.
1224 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1225 auto width = getType().getWidth();
1226 return FVIntegerAttr::get(getContext(), intAttr.getValue().sext(width));
1227 }
1228
1229 return {};
1230}
1231
1232//===----------------------------------------------------------------------===//
1233// BoolCastOp
1234//===----------------------------------------------------------------------===//
1235
1236OpFoldResult BoolCastOp::fold(FoldAdaptor adaptor) {
1237 // Fold away no-op casts.
1238 if (getInput().getType() == getResult().getType())
1239 return getInput();
1240 return {};
1241}
1242
1243//===----------------------------------------------------------------------===//
1244// BlockingAssignOp
1245//===----------------------------------------------------------------------===//
1246
1247bool BlockingAssignOp::loadsFrom(const MemorySlot &slot) { return false; }
1248
1249bool BlockingAssignOp::storesTo(const MemorySlot &slot) {
1250 return getDst() == slot.ptr;
1251}
1252
1253Value BlockingAssignOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1254 Value reachingDef,
1255 const DataLayout &dataLayout) {
1256 return getSrc();
1257}
1258
1259bool BlockingAssignOp::canUsesBeRemoved(
1260 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1261 SmallVectorImpl<OpOperand *> &newBlockingUses,
1262 const DataLayout &dataLayout) {
1263
1264 if (blockingUses.size() != 1)
1265 return false;
1266 Value blockingUse = (*blockingUses.begin())->get();
1267 return blockingUse == slot.ptr && getDst() == slot.ptr &&
1268 getSrc() != slot.ptr && getSrc().getType() == slot.elemType;
1269}
1270
1271DeletionKind BlockingAssignOp::removeBlockingUses(
1272 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1273 OpBuilder &builder, Value reachingDefinition,
1274 const DataLayout &dataLayout) {
1275 return DeletionKind::Delete;
1276}
1277
1278//===----------------------------------------------------------------------===//
1279// ReadOp
1280//===----------------------------------------------------------------------===//
1281
1282bool ReadOp::loadsFrom(const MemorySlot &slot) {
1283 return getInput() == slot.ptr;
1284}
1285
1286bool ReadOp::storesTo(const MemorySlot &slot) { return false; }
1287
1288Value ReadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1289 Value reachingDef, const DataLayout &dataLayout) {
1290 llvm_unreachable("getStored should not be called on ReadOp");
1291}
1292
1293bool ReadOp::canUsesBeRemoved(const MemorySlot &slot,
1294 const SmallPtrSetImpl<OpOperand *> &blockingUses,
1295 SmallVectorImpl<OpOperand *> &newBlockingUses,
1296 const DataLayout &dataLayout) {
1297
1298 if (blockingUses.size() != 1)
1299 return false;
1300 Value blockingUse = (*blockingUses.begin())->get();
1301 return blockingUse == slot.ptr && getOperand() == slot.ptr &&
1302 getResult().getType() == slot.elemType;
1303}
1304
1305DeletionKind
1306ReadOp::removeBlockingUses(const MemorySlot &slot,
1307 const SmallPtrSetImpl<OpOperand *> &blockingUses,
1308 OpBuilder &builder, Value reachingDefinition,
1309 const DataLayout &dataLayout) {
1310 getResult().replaceAllUsesWith(reachingDefinition);
1311 return DeletionKind::Delete;
1312}
1313
1314//===----------------------------------------------------------------------===//
1315// PowSOp
1316//===----------------------------------------------------------------------===//
1317
1318static OpFoldResult powCommonFolding(MLIRContext *ctxt, Attribute lhs,
1319 Attribute rhs) {
1320 auto lhsValue = dyn_cast_or_null<FVIntegerAttr>(lhs);
1321 if (lhsValue && lhsValue.getValue() == 1)
1322 return lhs;
1323
1324 auto rhsValue = dyn_cast_or_null<FVIntegerAttr>(rhs);
1325 if (rhsValue && rhsValue.getValue().isZero())
1326 return FVIntegerAttr::get(ctxt,
1327 FVInt(rhsValue.getValue().getBitWidth(), 1));
1328
1329 return {};
1330}
1331
1332OpFoldResult PowSOp::fold(FoldAdaptor adaptor) {
1333 return powCommonFolding(getContext(), adaptor.getLhs(), adaptor.getRhs());
1334}
1335
1336LogicalResult PowSOp::canonicalize(PowSOp op, PatternRewriter &rewriter) {
1337 Location loc = op.getLoc();
1338 auto intType = cast<IntType>(op.getRhs().getType());
1339 if (auto baseOp = op.getLhs().getDefiningOp<ConstantOp>()) {
1340 if (baseOp.getValue() == 2) {
1341 Value constOne = ConstantOp::create(rewriter, loc, intType, 1);
1342 Value constZero = ConstantOp::create(rewriter, loc, intType, 0);
1343 Value shift = ShlOp::create(rewriter, loc, constOne, op.getRhs());
1344 Value isNegative = SltOp::create(rewriter, loc, op.getRhs(), constZero);
1345 auto condOp = rewriter.replaceOpWithNewOp<ConditionalOp>(
1346 op, op.getLhs().getType(), isNegative);
1347 Block *thenBlock = rewriter.createBlock(&condOp.getTrueRegion());
1348 rewriter.setInsertionPointToStart(thenBlock);
1349 YieldOp::create(rewriter, loc, constZero);
1350 Block *elseBlock = rewriter.createBlock(&condOp.getFalseRegion());
1351 rewriter.setInsertionPointToStart(elseBlock);
1352 YieldOp::create(rewriter, loc, shift);
1353 return success();
1354 }
1355 }
1356
1357 return failure();
1358}
1359
1360//===----------------------------------------------------------------------===//
1361// PowUOp
1362//===----------------------------------------------------------------------===//
1363
1364OpFoldResult PowUOp::fold(FoldAdaptor adaptor) {
1365 return powCommonFolding(getContext(), adaptor.getLhs(), adaptor.getRhs());
1366}
1367
1368LogicalResult PowUOp::canonicalize(PowUOp op, PatternRewriter &rewriter) {
1369 Location loc = op.getLoc();
1370 auto intType = cast<IntType>(op.getRhs().getType());
1371 if (auto baseOp = op.getLhs().getDefiningOp<ConstantOp>()) {
1372 if (baseOp.getValue() == 2) {
1373 Value constOne = ConstantOp::create(rewriter, loc, intType, 1);
1374 rewriter.replaceOpWithNewOp<ShlOp>(op, constOne, op.getRhs());
1375 return success();
1376 }
1377 }
1378
1379 return failure();
1380}
1381
1382//===----------------------------------------------------------------------===//
1383// SubOp
1384//===----------------------------------------------------------------------===//
1385
1386OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1387 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs()))
1388 if (intAttr.getValue().isZero())
1389 return getLhs();
1390
1391 return {};
1392}
1393
1394//===----------------------------------------------------------------------===//
1395// MulOp
1396//===----------------------------------------------------------------------===//
1397
1398OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1399 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1400 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1401 if (lhs && rhs)
1402 return FVIntegerAttr::get(getContext(), lhs.getValue() * rhs.getValue());
1403 return {};
1404}
1405
1406//===----------------------------------------------------------------------===//
1407// DivUOp
1408//===----------------------------------------------------------------------===//
1409
1410OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1411 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1412 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1413 if (lhs && rhs)
1414 return FVIntegerAttr::get(getContext(),
1415 lhs.getValue().udiv(rhs.getValue()));
1416 return {};
1417}
1418
1419//===----------------------------------------------------------------------===//
1420// DivSOp
1421//===----------------------------------------------------------------------===//
1422
1423OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1424 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1425 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1426 if (lhs && rhs)
1427 return FVIntegerAttr::get(getContext(),
1428 lhs.getValue().sdiv(rhs.getValue()));
1429 return {};
1430}
1431
1432//===----------------------------------------------------------------------===//
1433// Classes
1434//===----------------------------------------------------------------------===//
1435
1436LogicalResult ClassDeclOp::verify() {
1437 mlir::Region &body = getBody();
1438 if (body.empty())
1439 return mlir::success();
1440
1441 auto &block = body.front();
1442 for (mlir::Operation &op : block) {
1443
1444 // allow only property and method decls and terminator
1445 if (llvm::isa<circt::moore::ClassPropertyDeclOp,
1446 circt::moore::ClassMethodDeclOp>(&op))
1447 continue;
1448
1449 return emitOpError()
1450 << "body may only contain 'moore.class.propertydecl' operations";
1451 }
1452 return mlir::success();
1453}
1454
1455LogicalResult ClassNewOp::verify() {
1456 // The result is constrained to ClassHandleType in ODS, so this cast should be
1457 // safe.
1458 auto handleTy = cast<ClassHandleType>(getResult().getType());
1459 mlir::SymbolRefAttr classSym = handleTy.getClassSym();
1460 if (!classSym)
1461 return emitOpError("result type is missing a class symbol");
1462
1463 // Resolve the referenced symbol starting from the nearest symbol table.
1464 mlir::Operation *sym =
1465 mlir::SymbolTable::lookupNearestSymbolFrom(getOperation(), classSym);
1466 if (!sym)
1467 return emitOpError("referenced class symbol `")
1468 << classSym << "` was not found";
1469
1470 if (!llvm::isa<ClassDeclOp>(sym))
1471 return emitOpError("symbol `")
1472 << classSym << "` does not name a `moore.class.classdecl`";
1473
1474 return mlir::success();
1475}
1476
1477void ClassNewOp::getEffects(
1478 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1479 &effects) {
1480 // Always allocates heap memory.
1481 effects.emplace_back(MemoryEffects::Allocate::get());
1482}
1483
1484LogicalResult
1485ClassUpcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1486 // 1) Type checks.
1487 auto srcTy = dyn_cast<ClassHandleType>(getOperand().getType());
1488 if (!srcTy)
1489 return emitOpError() << "operand must be !moore.class<...>; got "
1490 << getOperand().getType();
1491
1492 auto dstTy = dyn_cast<ClassHandleType>(getResult().getType());
1493 if (!dstTy)
1494 return emitOpError() << "result must be !moore.class<...>; got "
1495 << getResult().getType();
1496
1497 if (srcTy == dstTy)
1498 return success();
1499
1500 auto *op = getOperation();
1501
1502 auto *srcDeclOp =
1503 symbolTable.lookupNearestSymbolFrom(op, srcTy.getClassSym());
1504 auto *dstDeclOp =
1505 symbolTable.lookupNearestSymbolFrom(op, dstTy.getClassSym());
1506 if (!srcDeclOp || !dstDeclOp)
1507 return emitOpError() << "failed to resolve class symbol(s): src="
1508 << srcTy.getClassSym()
1509 << ", dst=" << dstTy.getClassSym();
1510
1511 auto srcDecl = dyn_cast<ClassDeclOp>(srcDeclOp);
1512 auto dstDecl = dyn_cast<ClassDeclOp>(dstDeclOp);
1513 if (!srcDecl || !dstDecl)
1514 return emitOpError()
1515 << "symbol(s) do not name `moore.class.classdecl` ops: src="
1516 << srcTy.getClassSym() << ", dst=" << dstTy.getClassSym();
1517
1518 auto cur = srcDecl;
1519 while (cur) {
1520 if (cur == dstDecl)
1521 return success(); // legal upcast: dst is src or an ancestor
1522
1523 auto baseSym = cur.getBaseAttr();
1524 if (!baseSym)
1525 break;
1526
1527 auto *baseOp = symbolTable.lookupNearestSymbolFrom(op, baseSym);
1528 cur = llvm::dyn_cast_or_null<ClassDeclOp>(baseOp);
1529 }
1530
1531 return emitOpError() << "cannot upcast from " << srcTy.getClassSym() << " to "
1532 << dstTy.getClassSym()
1533 << " (destination is not a base class)";
1534}
1535
1536LogicalResult
1537ClassPropertyRefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1538 // The operand is constrained to ClassHandleType in ODS; unwrap it.
1539 Type instTy = getInstance().getType();
1540 auto handleTy = dyn_cast<moore::ClassHandleType>(instTy);
1541 if (!handleTy)
1542 return emitOpError() << "instance must be a !moore.class<@C> value, got "
1543 << instTy;
1544
1545 // Extract the referenced class symbol from the handle type.
1546 SymbolRefAttr classSym = handleTy.getClassSym();
1547 if (!classSym)
1548 return emitOpError("instance type is missing a class symbol");
1549
1550 // Resolve the class symbol starting from the nearest symbol table.
1551 Operation *clsSym =
1552 symbolTable.lookupNearestSymbolFrom(getOperation(), classSym);
1553 if (!clsSym)
1554 return emitOpError("referenced class symbol `")
1555 << classSym << "` was not found";
1556 auto classDecl = dyn_cast<ClassDeclOp>(clsSym);
1557 if (!classDecl)
1558 return emitOpError("symbol `")
1559 << classSym << "` does not name a `moore.class.classdecl`";
1560
1561 // Look up the field symbol inside the class declaration's symbol table.
1562 FlatSymbolRefAttr fieldSym = getPropertyAttr();
1563 if (!fieldSym)
1564 return emitOpError("missing field symbol");
1565
1566 Operation *fldSym = symbolTable.lookupSymbolIn(classDecl, fieldSym.getAttr());
1567 if (!fldSym)
1568 return emitOpError("no field `") << fieldSym << "` in class " << classSym;
1569
1570 auto fieldDecl = dyn_cast<ClassPropertyDeclOp>(fldSym);
1571 if (!fieldDecl)
1572 return emitOpError("symbol `")
1573 << fieldSym << "` is not a `moore.class.propertydecl`";
1574
1575 // Result must be !moore.ref<T> where T matches the field's declared type.
1576 auto resRefTy = cast<RefType>(getPropertyRef().getType());
1577 if (!resRefTy)
1578 return emitOpError("result must be a !moore.ref<T>");
1579
1580 Type expectedElemTy = fieldDecl.getPropertyType();
1581 if (resRefTy.getNestedType() != expectedElemTy)
1582 return emitOpError("result element type (")
1583 << resRefTy.getNestedType() << ") does not match field type ("
1584 << expectedElemTy << ")";
1585
1586 return success();
1587}
1588
1589LogicalResult
1590VTableLoadMethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1591 Operation *op = getOperation();
1592
1593 auto object = getObject();
1594 auto implSym = object.getType().getClassSym();
1595
1596 // Check that classdecl of class handle exists
1597 Operation *implOp = symbolTable.lookupNearestSymbolFrom(op, implSym);
1598 if (!implOp)
1599 return emitOpError() << "implementing class " << implSym << " not found";
1600 auto implClass = cast<moore::ClassDeclOp>(implOp);
1601
1602 StringAttr methodName = getMethodSymAttr().getLeafReference();
1603 if (!methodName || methodName.getValue().empty())
1604 return emitOpError() << "empty method name";
1605
1606 moore::ClassDeclOp cursor = implClass;
1607 Operation *methodDeclOp = nullptr;
1608
1609 // Find method in class decl or parents' class decl
1610 while (cursor && !methodDeclOp) {
1611 methodDeclOp = symbolTable.lookupSymbolIn(cursor, methodName);
1612 if (methodDeclOp)
1613 break;
1614 SymbolRefAttr baseSym = cursor.getBaseAttr();
1615 if (!baseSym)
1616 break;
1617 Operation *baseOp = symbolTable.lookupNearestSymbolFrom(op, baseSym);
1618 cursor = baseOp ? cast<moore::ClassDeclOp>(baseOp) : moore::ClassDeclOp();
1619 }
1620
1621 if (!methodDeclOp)
1622 return emitOpError() << "no method `" << methodName << "` found in "
1623 << implClass.getSymName() << " or its bases";
1624
1625 // Make sure method decl is a ClassMethodDeclOp
1626 auto methodDecl = dyn_cast<moore::ClassMethodDeclOp>(methodDeclOp);
1627 if (!methodDecl)
1628 return emitOpError() << "`" << methodName
1629 << "` is not a method declaration";
1630
1631 // Make sure method signature matches
1632 auto resFnTy = cast<FunctionType>(getResult().getType());
1633 auto declFnTy = cast<FunctionType>(methodDecl.getFunctionType());
1634 if (resFnTy != declFnTy)
1635 return emitOpError() << "result type " << resFnTy
1636 << " does not match method erased ABI " << declFnTy;
1637
1638 return success();
1639}
1640
1641LogicalResult VTableOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1642 Operation *self = getOperation();
1643
1644 // sym_name's root must be a ClassDeclOp
1645 SymbolRefAttr name = getSymNameAttr();
1646 if (!name)
1647 return emitOpError("requires 'sym_name' SymbolRefAttr");
1648
1649 // Root symbol must resolve (from the nearest symbol table) to a ClassDeclOp.
1650 Operation *rootDef = symbolTable.lookupNearestSymbolFrom(
1651 self, SymbolRefAttr::get(name.getRootReference()));
1652 if (!rootDef)
1653 return emitOpError() << "cannot resolve root class symbol '"
1654 << name.getRootReference() << "' for sym_name "
1655 << name;
1656
1657 if (!isa<ClassDeclOp>(rootDef))
1658 return emitOpError()
1659 << "root of sym_name must name a 'moore.class.classdecl', got "
1660 << name;
1661
1662 // All good.
1663 return success();
1664}
1665
1666LogicalResult VTableOp::verifyRegions() {
1667 // Ensure only allowed ops appear inside.
1668 for (Operation &op : getBody().front()) {
1669 if (!isa<VTableOp, VTableEntryOp>(op))
1670 return emitOpError(
1671 "body may only contain 'moore.vtable' or 'moore.vtable_entry' ops");
1672 }
1673 return mlir::success();
1674}
1675
1676LogicalResult
1677VTableEntryOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1678 Operation *self = getOperation();
1679
1680 // 'target' must exist and resolve from the top-level symbol table of a func
1681 // op
1682 SymbolRefAttr target = getTargetAttr();
1683 func::FuncOp def =
1684 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(self, target);
1685 if (!def)
1686 return emitOpError()
1687 << "cannot resolve target symbol to a function operation " << target;
1688
1689 // VTableEntries may only exist in VTables.
1690 if (!isa<VTableOp>(self->getParentOp()))
1691 return emitOpError("must be nested directly inside a 'moore.vtable' op");
1692
1693 Operation *currentOp = self;
1694 VTableOp currentVTable;
1695 bool defined = false;
1696
1697 // Walk up the VTable tree and check whether the corresponding classDeclOp
1698 // declares a method with the same implementation. Further checks all the way
1699 // up the tree if another classdeclop overrides the implementation.
1700 // The entry is correct iff the impl matches the most derived classdeclop's
1701 // methoddeclop implementing the virtual method.
1702 while (auto parentOp = dyn_cast<VTableOp>(currentOp->getParentOp())) {
1703 currentOp = parentOp;
1704 currentVTable = cast<VTableOp>(currentOp);
1705
1706 auto classSymName = currentVTable.getSymName();
1707 ClassDeclOp parentClassDecl =
1708 symbolTable.lookupNearestSymbolFrom<ClassDeclOp>(
1709 parentOp, classSymName.getRootReference());
1710 assert(parentClassDecl && "VTableOp must point to a classdeclop");
1711
1712 for (auto method : parentClassDecl.getBody().getOps<ClassMethodDeclOp>()) {
1713 // A virtual interface declaration. Ignore.
1714 if (!method.getImpl())
1715 continue;
1716
1717 // A matching definition.
1718 if (method.getSymName() == getName() && method.getImplAttr() == target)
1719 defined = true;
1720
1721 // All definitions of the same method up the tree must be the same as the
1722 // current definition, there is no shadowing.
1723 // Hence, if we encounter a methoddeclop that has the same name but a
1724 // different implementation that means this vtableentry should point to
1725 // the op's implementation - that's an error.
1726 else if (method.getSymName() == getName() &&
1727 method.getImplAttr() != target && defined)
1728 return emitOpError() << "Target " << target
1729 << " should be overridden by " << classSymName;
1730 }
1731 }
1732 if (!defined)
1733 return emitOpError()
1734 << "Parent class does not point to any implementation!";
1735
1736 return success();
1737}
1738
1739//===----------------------------------------------------------------------===//
1740// TableGen generated logic.
1741//===----------------------------------------------------------------------===//
1742
1743// Provide the autogenerated implementation guts for the Op classes.
1744#define GET_OP_CLASSES
1745#include "circt/Dialect/Moore/Moore.cpp.inc"
1746#include "circt/Dialect/Moore/MooreEnums.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
@ Output
Definition HW.h:42
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static OpFoldResult powCommonFolding(MLIRContext *ctxt, Attribute lhs, Attribute rhs)
static ArrayRef< StructLikeMember > getStructMembers(Type type)
Definition MooreOps.cpp:807
static std::optional< uint32_t > getStructFieldIndex(Type type, StringAttr name)
Definition MooreOps.cpp:798
static UnpackedType getStructFieldType(Type type, StringAttr name)
Definition MooreOps.cpp:816
static std::pair< unsigned, UnpackedType > getArrayElements(Type type)
Definition MooreOps.cpp:765
static InstancePath empty
Four-valued arbitrary precision integers.
Definition FVInt.h:37
bool isNegative() const
Determine whether the integer interpreted as a signed number would be negative.
Definition FVInt.h:185
FVInt sext(unsigned bitWidth) const
Sign-extend the integer to a new bit width.
Definition FVInt.h:148
unsigned getSignificantBits() const
Compute the minimum bit width necessary to accurately represent this integer's value and sign.
Definition FVInt.h:102
static FVInt getAllX(unsigned numBits)
Construct an FVInt with all bits set to X.
Definition FVInt.h:75
bool hasUnknown() const
Determine if any bits are X or Z.
Definition FVInt.h:168
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition FVInt.h:92
unsigned getBitWidth() const
Return the number of bits this integer has.
Definition FVInt.h:85
FVInt trunc(unsigned bitWidth) const
Truncate the integer to a smaller bit width.
Definition FVInt.h:132
A packed SystemVerilog type.
Definition MooreTypes.h:153
std::optional< unsigned > getBitSize() const
Get the size of this type in bits.
Domain getDomain() const
Get the value domain of this type.
An unpacked SystemVerilog type.
Definition MooreTypes.h:101
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition HWOps.cpp:529
Domain
The number of values each bit of a type can assume.
Definition MooreTypes.h:49
RealWidth
The type of floating point / real number behind a RealType.
Definition MooreTypes.h:57
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
void printFVInt(AsmPrinter &p, const FVInt &value)
Print a four-valued integer using an AsmPrinter.
Definition FVInt.cpp:147
ParseResult parseFVInt(AsmParser &p, FVInt &result)
Parse a four-valued integer using an AsmParser.
Definition FVInt.cpp:162
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183