CIRCT 22.0.0git
Loading...
Searching...
No Matches
MooreOps.cpp
Go to the documentation of this file.
1//===- MooreOps.cpp - Implement the Moore operations ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Moore dialect operations.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/IR/Builders.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
22
23using namespace circt;
24using namespace circt::moore;
25using namespace mlir;
26
27//===----------------------------------------------------------------------===//
28// SVModuleOp
29//===----------------------------------------------------------------------===//
30
31void SVModuleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
32 llvm::StringRef name, hw::ModuleType type) {
33 state.addAttribute(SymbolTable::getSymbolAttrName(),
34 builder.getStringAttr(name));
35 state.addAttribute(getModuleTypeAttrName(state.name), TypeAttr::get(type));
36 state.addRegion();
37}
38
39void SVModuleOp::print(OpAsmPrinter &p) {
40 p << " ";
41
42 // Print the visibility of the module.
43 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
44 if (auto visibility = (*this)->getAttrOfType<StringAttr>(visibilityAttrName))
45 p << visibility.getValue() << ' ';
46
47 p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
49 getModuleType(), {}, {});
50 p << " ";
51 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false,
52 /*printBlockTerminators=*/true);
53
54 p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(),
55 getAttributeNames());
56}
57
58ParseResult SVModuleOp::parse(OpAsmParser &parser, OperationState &result) {
59 // Parse the visibility attribute.
60 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
61
62 // Parse the module name.
63 StringAttr nameAttr;
64 if (parser.parseSymbolName(nameAttr, getSymNameAttrName(result.name),
65 result.attributes))
66 return failure();
67
68 // Parse the ports.
69 SmallVector<hw::module_like_impl::PortParse> ports;
70 TypeAttr modType;
71 if (failed(
72 hw::module_like_impl::parseModuleSignature(parser, ports, modType)))
73 return failure();
74 result.addAttribute(getModuleTypeAttrName(result.name), modType);
75
76 // Parse the attributes.
77 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
78 return failure();
79
80 // Add the entry block arguments.
81 SmallVector<OpAsmParser::Argument, 4> entryArgs;
82 for (auto &port : ports)
83 if (port.direction != hw::ModulePort::Direction::Output)
84 entryArgs.push_back(port);
85
86 // Parse the optional function body.
87 auto &bodyRegion = *result.addRegion();
88 if (parser.parseRegion(bodyRegion, entryArgs))
89 return failure();
90
91 ensureTerminator(bodyRegion, parser.getBuilder(), result.location);
92 return success();
93}
94
95void SVModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
96 mlir::OpAsmSetValueNameFn setNameFn) {
97 if (&region != &getBodyRegion())
98 return;
99 auto moduleType = getModuleType();
100 for (auto [index, arg] : llvm::enumerate(region.front().getArguments()))
101 setNameFn(arg, moduleType.getInputNameAttr(index));
102}
103
104OutputOp SVModuleOp::getOutputOp() {
105 return cast<OutputOp>(getBody()->getTerminator());
106}
107
108OperandRange SVModuleOp::getOutputs() { return getOutputOp().getOperands(); }
109
110//===----------------------------------------------------------------------===//
111// OutputOp
112//===----------------------------------------------------------------------===//
113
114LogicalResult OutputOp::verify() {
115 auto module = getParentOp();
116
117 // Check that the number of operands matches the number of output ports.
118 auto outputTypes = module.getModuleType().getOutputTypes();
119 if (outputTypes.size() != getNumOperands())
120 return emitOpError("has ")
121 << getNumOperands() << " operands, but enclosing module @"
122 << module.getSymName() << " has " << outputTypes.size()
123 << " outputs";
124
125 // Check that the operand types match the output ports.
126 for (unsigned i = 0, e = outputTypes.size(); i != e; ++i)
127 if (outputTypes[i] != getOperand(i).getType())
128 return emitOpError() << "operand " << i << " (" << getOperand(i).getType()
129 << ") does not match output type (" << outputTypes[i]
130 << ") of module @" << module.getSymName();
131
132 return success();
133}
134
135//===----------------------------------------------------------------------===//
136// InstanceOp
137//===----------------------------------------------------------------------===//
138
139LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
140 // Resolve the target symbol.
141 auto *symbol =
142 symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
143 if (!symbol)
144 return emitOpError("references unknown symbol @") << getModuleName();
145
146 // Check that the symbol is a SVModuleOp.
147 auto module = dyn_cast<SVModuleOp>(symbol);
148 if (!module)
149 return emitOpError("must reference a 'moore.module', but @")
150 << getModuleName() << " is a '" << symbol->getName() << "'";
151
152 // Check that the input ports match.
153 auto moduleType = module.getModuleType();
154 auto inputTypes = moduleType.getInputTypes();
155
156 if (inputTypes.size() != getNumOperands())
157 return emitOpError("has ")
158 << getNumOperands() << " operands, but target module @"
159 << module.getSymName() << " has " << inputTypes.size() << " inputs";
160
161 for (unsigned i = 0, e = inputTypes.size(); i != e; ++i)
162 if (inputTypes[i] != getOperand(i).getType())
163 return emitOpError() << "operand " << i << " (" << getOperand(i).getType()
164 << ") does not match input type (" << inputTypes[i]
165 << ") of module @" << module.getSymName();
166
167 // Check that the output ports match.
168 auto outputTypes = moduleType.getOutputTypes();
169
170 if (outputTypes.size() != getNumResults())
171 return emitOpError("has ")
172 << getNumOperands() << " results, but target module @"
173 << module.getSymName() << " has " << outputTypes.size()
174 << " outputs";
175
176 for (unsigned i = 0, e = outputTypes.size(); i != e; ++i)
177 if (outputTypes[i] != getResult(i).getType())
178 return emitOpError() << "result " << i << " (" << getResult(i).getType()
179 << ") does not match output type (" << outputTypes[i]
180 << ") of module @" << module.getSymName();
181
182 return success();
183}
184
185void InstanceOp::print(OpAsmPrinter &p) {
186 p << " ";
187 p.printAttributeWithoutType(getInstanceNameAttr());
188 p << " ";
189 p.printAttributeWithoutType(getModuleNameAttr());
190 printInputPortList(p, getOperation(), getInputs(), getInputs().getTypes(),
191 getInputNames());
192 p << " -> ";
193 printOutputPortList(p, getOperation(), getOutputs().getTypes(),
194 getOutputNames());
195 p.printOptionalAttrDict(getOperation()->getAttrs(), getAttributeNames());
196}
197
198ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
199 // Parse the instance name.
200 StringAttr instanceName;
201 if (parser.parseAttribute(instanceName, "instanceName", result.attributes))
202 return failure();
203
204 // Parse the module name.
205 FlatSymbolRefAttr moduleName;
206 if (parser.parseAttribute(moduleName, "moduleName", result.attributes))
207 return failure();
208
209 // Parse the input port list.
210 auto loc = parser.getCurrentLocation();
211 SmallVector<OpAsmParser::UnresolvedOperand> inputs;
212 SmallVector<Type> types;
213 ArrayAttr names;
214 if (parseInputPortList(parser, inputs, types, names))
215 return failure();
216 if (parser.resolveOperands(inputs, types, loc, result.operands))
217 return failure();
218 result.addAttribute("inputNames", names);
219
220 // Parse `->`.
221 if (parser.parseArrow())
222 return failure();
223
224 // Parse the output port list.
225 types.clear();
226 if (parseOutputPortList(parser, types, names))
227 return failure();
228 result.addAttribute("outputNames", names);
229 result.addTypes(types);
230
231 // Parse the attributes.
232 if (parser.parseOptionalAttrDict(result.attributes))
233 return failure();
234
235 return success();
236}
237
238void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
239 SmallString<32> name;
240 name += getInstanceName();
241 name += '.';
242 auto baseLen = name.size();
243
244 for (auto [result, portName] :
245 llvm::zip(getOutputs(), getOutputNames().getAsRange<StringAttr>())) {
246 if (!portName || portName.empty())
247 continue;
248 name.resize(baseLen);
249 name += portName.getValue();
250 setNameFn(result, name);
251 }
252}
253
254//===----------------------------------------------------------------------===//
255// VariableOp
256//===----------------------------------------------------------------------===//
257
258void VariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
259 if (getName() && !getName()->empty())
260 setNameFn(getResult(), *getName());
261}
262
263LogicalResult VariableOp::canonicalize(VariableOp op,
264 PatternRewriter &rewriter) {
265 // If the variable is embedded in an SSACFG region, move the initial value
266 // into an assignment immediately after the variable op. This allows the
267 // mem2reg pass which cannot handle variables with initial values.
268 auto initial = op.getInitial();
269 if (initial && mlir::mayHaveSSADominance(*op->getParentRegion())) {
270 rewriter.modifyOpInPlace(op, [&] { op.getInitialMutable().clear(); });
271 rewriter.setInsertionPointAfter(op);
272 BlockingAssignOp::create(rewriter, initial.getLoc(), op, initial);
273 return success();
274 }
275
276 // Check if the variable has one unique continuous assignment to it, all other
277 // uses are reads, and that all uses are in the same block as the variable
278 // itself.
279 auto *block = op->getBlock();
280 ContinuousAssignOp uniqueAssignOp;
281 for (auto *user : op->getUsers()) {
282 // Ensure that all users of the variable are in the same block.
283 if (user->getBlock() != block)
284 return failure();
285
286 // Ensure there is at most one unique continuous assignment to the variable.
287 if (auto assignOp = dyn_cast<ContinuousAssignOp>(user)) {
288 if (uniqueAssignOp)
289 return failure();
290 uniqueAssignOp = assignOp;
291 continue;
292 }
293
294 // Ensure all other users are reads.
295 if (!isa<ReadOp>(user))
296 return failure();
297 }
298 if (!uniqueAssignOp)
299 return failure();
300
301 // If the original variable had a name, create an `AssignedVariableOp` as a
302 // replacement. Otherwise substitute the assigned value directly.
303 Value assignedValue = uniqueAssignOp.getSrc();
304 if (auto name = op.getNameAttr(); name && !name.empty())
305 assignedValue = AssignedVariableOp::create(rewriter, op.getLoc(), name,
306 uniqueAssignOp.getSrc());
307
308 // Remove the assign op and replace all reads with the new assigned var op.
309 rewriter.eraseOp(uniqueAssignOp);
310 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
311 auto readOp = cast<ReadOp>(user);
312 rewriter.replaceOp(readOp, assignedValue);
313 }
314
315 // Remove the original variable.
316 rewriter.eraseOp(op);
317 return success();
318}
319
320SmallVector<MemorySlot> VariableOp::getPromotableSlots() {
321 // We cannot promote variables with an initial value, since that value may not
322 // dominate the location where the default value needs to be constructed.
323 if (mlir::mayBeGraphRegion(*getOperation()->getParentRegion()) ||
324 getInitial())
325 return {};
326
327 // Ensure that `getDefaultValue` can conjure up a default value for the
328 // variable's type.
329 if (!isa<PackedType>(getType().getNestedType()))
330 return {};
331
332 return {MemorySlot{getResult(), getType().getNestedType()}};
333}
334
335Value VariableOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) {
336 auto packedType = dyn_cast<PackedType>(slot.elemType);
337 if (!packedType)
338 return {};
339 auto bitWidth = packedType.getBitSize();
340 if (!bitWidth)
341 return {};
342 auto fvint = packedType.getDomain() == Domain::FourValued
343 ? FVInt::getAllX(*bitWidth)
344 : FVInt::getZero(*bitWidth);
345 Value value = ConstantOp::create(
346 builder, getLoc(),
347 IntType::get(getContext(), *bitWidth, packedType.getDomain()), fvint);
348 if (value.getType() != packedType)
349 ConversionOp::create(builder, getLoc(), packedType, value);
350 return value;
351}
352
353void VariableOp::handleBlockArgument(const MemorySlot &slot,
354 BlockArgument argument,
355 OpBuilder &builder) {}
356
357std::optional<mlir::PromotableAllocationOpInterface>
358VariableOp::handlePromotionComplete(const MemorySlot &slot, Value defaultValue,
359 OpBuilder &builder) {
360 if (defaultValue && defaultValue.use_empty())
361 defaultValue.getDefiningOp()->erase();
362 this->erase();
363 return {};
364}
365
366SmallVector<DestructurableMemorySlot> VariableOp::getDestructurableSlots() {
367 if (isa<SVModuleOp>(getOperation()->getParentOp()))
368 return {};
369 if (getInitial())
370 return {};
371
372 auto refType = getType();
373 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(refType);
374 if (!destructurable)
375 return {};
376
377 auto destructuredType = destructurable.getSubelementIndexMap();
378 if (!destructuredType)
379 return {};
380
381 return {DestructurableMemorySlot{{getResult(), refType}, *destructuredType}};
382}
383
384DenseMap<Attribute, MemorySlot> VariableOp::destructure(
385 const DestructurableMemorySlot &slot,
386 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
387 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
388 assert(slot.ptr == getResult());
389 assert(!getInitial());
390 builder.setInsertionPointAfter(*this);
391
392 auto destructurableType = cast<DestructurableTypeInterface>(getType());
393 DenseMap<Attribute, MemorySlot> slotMap;
394 for (Attribute index : usedIndices) {
395 auto elemType = cast<RefType>(destructurableType.getTypeAtIndex(index));
396 assert(elemType && "used index must exist");
397 StringAttr varName;
398 if (auto name = getName(); name && !name->empty())
399 varName = StringAttr::get(
400 getContext(), (*name) + "." + cast<StringAttr>(index).getValue());
401 auto varOp =
402 VariableOp::create(builder, getLoc(), elemType, varName, Value());
403 newAllocators.push_back(varOp);
404 slotMap.try_emplace<MemorySlot>(index, {varOp.getResult(), elemType});
405 }
406
407 return slotMap;
408}
409
410std::optional<DestructurableAllocationOpInterface>
411VariableOp::handleDestructuringComplete(const DestructurableMemorySlot &slot,
412 OpBuilder &builder) {
413 assert(slot.ptr == getResult());
414 this->erase();
415 return std::nullopt;
416}
417
418//===----------------------------------------------------------------------===//
419// NetOp
420//===----------------------------------------------------------------------===//
421
422void NetOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
423 if (getName() && !getName()->empty())
424 setNameFn(getResult(), *getName());
425}
426
427LogicalResult NetOp::canonicalize(NetOp op, PatternRewriter &rewriter) {
428 bool modified = false;
429
430 // Check if the net has one unique continuous assignment to it, and
431 // additionally if all other users are reads.
432 auto *block = op->getBlock();
433 ContinuousAssignOp uniqueAssignOp;
434 bool allUsesAreReads = true;
435 for (auto *user : op->getUsers()) {
436 // Ensure that all users of the net are in the same block.
437 if (user->getBlock() != block)
438 return failure();
439
440 // Ensure there is at most one unique continuous assignment to the net.
441 if (auto assignOp = dyn_cast<ContinuousAssignOp>(user)) {
442 if (uniqueAssignOp)
443 return failure();
444 uniqueAssignOp = assignOp;
445 continue;
446 }
447
448 // Ensure all other users are reads.
449 if (!isa<ReadOp>(user))
450 allUsesAreReads = false;
451 }
452
453 // If there was one unique assignment, and the `NetOp` does not yet have an
454 // assigned value set, fold the assignment into the net.
455 if (uniqueAssignOp && !op.getAssignment()) {
456 rewriter.modifyOpInPlace(
457 op, [&] { op.getAssignmentMutable().assign(uniqueAssignOp.getSrc()); });
458 rewriter.eraseOp(uniqueAssignOp);
459 modified = true;
460 uniqueAssignOp = {};
461 }
462
463 // If all users of the net op are reads, and any potential unique assignment
464 // has been folded into the net op itself, directly replace the reads with the
465 // net's assigned value.
466 if (!uniqueAssignOp && allUsesAreReads && op.getAssignment()) {
467 // If the original net had a name, create an `AssignedVariableOp` as a
468 // replacement. Otherwise substitute the assigned value directly.
469 auto assignedValue = op.getAssignment();
470 if (auto name = op.getNameAttr(); name && !name.empty())
471 assignedValue = AssignedVariableOp::create(rewriter, op.getLoc(), name,
472 assignedValue);
473
474 // Replace all reads with the new assigned var op and remove the original
475 // net op.
476 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
477 auto readOp = cast<ReadOp>(user);
478 rewriter.replaceOp(readOp, assignedValue);
479 }
480 rewriter.eraseOp(op);
481 modified = true;
482 }
483
484 return success(modified);
485}
486
487//===----------------------------------------------------------------------===//
488// AssignedVariableOp
489//===----------------------------------------------------------------------===//
490
491void AssignedVariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
492 if (getName() && !getName()->empty())
493 setNameFn(getResult(), *getName());
494}
495
496LogicalResult AssignedVariableOp::canonicalize(AssignedVariableOp op,
497 PatternRewriter &rewriter) {
498 // Eliminate chained variables with the same name.
499 // var(name, var(name, x)) -> var(name, x)
500 if (auto otherOp = op.getInput().getDefiningOp<AssignedVariableOp>()) {
501 if (otherOp.getNameAttr() == op.getNameAttr()) {
502 rewriter.replaceOp(op, otherOp);
503 return success();
504 }
505 }
506
507 // Eliminate variables that alias an input port of the same name.
508 if (auto blockArg = dyn_cast<BlockArgument>(op.getInput())) {
509 if (auto moduleOp =
510 dyn_cast<SVModuleOp>(blockArg.getOwner()->getParentOp())) {
511 auto moduleType = moduleOp.getModuleType();
512 auto portName = moduleType.getInputNameAttr(blockArg.getArgNumber());
513 if (portName == op.getNameAttr()) {
514 rewriter.replaceOp(op, blockArg);
515 return success();
516 }
517 }
518 }
519
520 // Eliminate variables that feed an output port of the same name.
521 for (auto &use : op->getUses()) {
522 auto *useOwner = use.getOwner();
523 if (auto outputOp = dyn_cast<OutputOp>(useOwner)) {
524 if (auto moduleOp = dyn_cast<SVModuleOp>(outputOp->getParentOp())) {
525 auto moduleType = moduleOp.getModuleType();
526 auto portName = moduleType.getOutputNameAttr(use.getOperandNumber());
527 if (portName == op.getNameAttr()) {
528 rewriter.replaceOp(op, op.getInput());
529 return success();
530 }
531 } else
532 break;
533 }
534 }
535
536 return failure();
537}
538
539//===----------------------------------------------------------------------===//
540// ConstantOp
541//===----------------------------------------------------------------------===//
542
543void ConstantOp::print(OpAsmPrinter &p) {
544 p << " ";
545 printFVInt(p, getValue());
546 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
547 p << " : ";
548 p.printStrippedAttrOrType(getType());
549}
550
551ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
552 // Parse the constant value.
553 FVInt value;
554 auto valueLoc = parser.getCurrentLocation();
555 if (parseFVInt(parser, value))
556 return failure();
557
558 // Parse any optional attributes and the `:`.
559 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
560 return failure();
561
562 // Parse the result type.
563 IntType type;
564 if (parser.parseCustomTypeWithFallback(type))
565 return failure();
566
567 // Extend or truncate the constant value to match the size of the type.
568 if (type.getWidth() > value.getBitWidth()) {
569 // sext is always safe here, even for unsigned values, because the
570 // parseOptionalInteger method will return something with a zero in the
571 // top bits if it is a positive number.
572 value = value.sext(type.getWidth());
573 } else if (type.getWidth() < value.getBitWidth()) {
574 // The parser can return an unnecessarily wide result with leading
575 // zeros. This isn't a problem, but truncating off bits is bad.
576 unsigned neededBits =
577 value.isNegative() ? value.getSignificantBits() : value.getActiveBits();
578 if (type.getWidth() < neededBits)
579 return parser.emitError(valueLoc)
580 << "value requires " << neededBits
581 << " bits, but result type only has " << type.getWidth();
582 value = value.trunc(type.getWidth());
583 }
584
585 // If the constant contains any X or Z bits, the result type must be
586 // four-valued.
587 if (value.hasUnknown() && type.getDomain() != Domain::FourValued)
588 return parser.emitError(valueLoc)
589 << "value contains X or Z bits, but result type " << type
590 << " only allows two-valued bits";
591
592 // Build the attribute and op.
593 auto attrValue = FVIntegerAttr::get(parser.getContext(), value);
594 result.addAttribute("value", attrValue);
595 result.addTypes(type);
596 return success();
597}
598
599LogicalResult ConstantOp::verify() {
600 auto attrWidth = getValue().getBitWidth();
601 auto typeWidth = getType().getWidth();
602 if (attrWidth != typeWidth)
603 return emitError("attribute width ")
604 << attrWidth << " does not match return type's width " << typeWidth;
605 return success();
606}
607
608void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
609 const FVInt &value) {
610 assert(type.getWidth() == value.getBitWidth() &&
611 "FVInt width must match type width");
612 build(builder, result, type, FVIntegerAttr::get(builder.getContext(), value));
613}
614
615void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
616 const APInt &value) {
617 assert(type.getWidth() == value.getBitWidth() &&
618 "APInt width must match type width");
619 build(builder, result, type, FVInt(value));
620}
621
622/// This builder allows construction of small signed integers like 0, 1, -1
623/// matching a specified MLIR type. This shouldn't be used for general constant
624/// folding because it only works with values that can be expressed in an
625/// `int64_t`.
626void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
627 int64_t value, bool isSigned) {
628 build(builder, result, type,
629 APInt(type.getWidth(), (uint64_t)value, isSigned));
630}
631
632OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
633 assert(adaptor.getOperands().empty() && "constant has no operands");
634 return getValueAttr();
635}
636
637//===----------------------------------------------------------------------===//
638// ConstantTimeOp
639//===----------------------------------------------------------------------===//
640
641OpFoldResult ConstantTimeOp::fold(FoldAdaptor adaptor) {
642 return getValueAttr();
643}
644
645//===----------------------------------------------------------------------===//
646// ConcatOp
647//===----------------------------------------------------------------------===//
648
649LogicalResult ConcatOp::inferReturnTypes(
650 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
651 DictionaryAttr attrs, mlir::OpaqueProperties properties,
652 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
653 Domain domain = Domain::TwoValued;
654 unsigned width = 0;
655 for (auto operand : operands) {
656 auto type = cast<IntType>(operand.getType());
657 if (type.getDomain() == Domain::FourValued)
658 domain = Domain::FourValued;
659 width += type.getWidth();
660 }
661 results.push_back(IntType::get(context, width, domain));
662 return success();
663}
664
665//===----------------------------------------------------------------------===//
666// ConcatRefOp
667//===----------------------------------------------------------------------===//
668
669LogicalResult ConcatRefOp::inferReturnTypes(
670 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
671 DictionaryAttr attrs, mlir::OpaqueProperties properties,
672 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
673 Domain domain = Domain::TwoValued;
674 unsigned width = 0;
675 for (Value operand : operands) {
676 UnpackedType nestedType = cast<RefType>(operand.getType()).getNestedType();
677 PackedType packedType = dyn_cast<PackedType>(nestedType);
678
679 if (!packedType) {
680 return failure();
681 }
682
683 if (packedType.getDomain() == Domain::FourValued)
684 domain = Domain::FourValued;
685
686 // getBitSize() for PackedType returns an optional, so we must check it.
687 std::optional<int> bitSize = packedType.getBitSize();
688 if (!bitSize) {
689 return failure();
690 }
691 width += *bitSize;
692 }
693 results.push_back(RefType::get(IntType::get(context, width, domain)));
694 return success();
695}
696
697//===----------------------------------------------------------------------===//
698// ArrayCreateOp
699//===----------------------------------------------------------------------===//
700
701static std::pair<unsigned, UnpackedType> getArrayElements(Type type) {
702 if (auto arrayType = dyn_cast<ArrayType>(type))
703 return {arrayType.getSize(), arrayType.getElementType()};
704 if (auto arrayType = dyn_cast<UnpackedArrayType>(type))
705 return {arrayType.getSize(), arrayType.getElementType()};
706 assert(0 && "expected ArrayType or UnpackedArrayType");
707 return {};
708}
709
710LogicalResult ArrayCreateOp::verify() {
711 auto [size, elementType] = getArrayElements(getType());
712
713 // Check that the number of operands matches the array size.
714 if (getElements().size() != size)
715 return emitOpError() << "has " << getElements().size()
716 << " operands, but result type requires " << size;
717
718 // Check that the operand types match the array element type. We only need to
719 // check one of the operands, since the `SameTypeOperands` trait ensures all
720 // operands have the same type.
721 if (size > 0) {
722 auto value = getElements()[0];
723 if (value.getType() != elementType)
724 return emitOpError() << "operands have type " << value.getType()
725 << ", but array requires " << elementType;
726 }
727 return success();
728}
729
730//===----------------------------------------------------------------------===//
731// StructCreateOp
732//===----------------------------------------------------------------------===//
733
734static std::optional<uint32_t> getStructFieldIndex(Type type, StringAttr name) {
735 if (auto structType = dyn_cast<StructType>(type))
736 return structType.getFieldIndex(name);
737 if (auto structType = dyn_cast<UnpackedStructType>(type))
738 return structType.getFieldIndex(name);
739 assert(0 && "expected StructType or UnpackedStructType");
740 return {};
741}
742
743static ArrayRef<StructLikeMember> getStructMembers(Type type) {
744 if (auto structType = dyn_cast<StructType>(type))
745 return structType.getMembers();
746 if (auto structType = dyn_cast<UnpackedStructType>(type))
747 return structType.getMembers();
748 assert(0 && "expected StructType or UnpackedStructType");
749 return {};
750}
751
752static UnpackedType getStructFieldType(Type type, StringAttr name) {
753 if (auto index = getStructFieldIndex(type, name))
754 return getStructMembers(type)[*index].type;
755 return {};
756}
757
758LogicalResult StructCreateOp::verify() {
759 auto members = getStructMembers(getType());
760
761 // Check that the number of operands matches the number of struct fields.
762 if (getFields().size() != members.size())
763 return emitOpError() << "has " << getFields().size()
764 << " operands, but result type requires "
765 << members.size();
766
767 // Check that the operand types match the struct field types.
768 for (auto [index, pair] : llvm::enumerate(llvm::zip(getFields(), members))) {
769 auto [value, member] = pair;
770 if (value.getType() != member.type)
771 return emitOpError() << "operand #" << index << " has type "
772 << value.getType() << ", but struct field "
773 << member.name << " requires " << member.type;
774 }
775 return success();
776}
777
778OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
779 SmallVector<NamedAttribute> fields;
780 for (auto [member, field] :
781 llvm::zip(getStructMembers(getType()), adaptor.getFields())) {
782 if (!field)
783 return {};
784 fields.push_back(NamedAttribute(member.name, field));
785 }
786 return DictionaryAttr::get(getContext(), fields);
787}
788
789//===----------------------------------------------------------------------===//
790// StructExtractOp
791//===----------------------------------------------------------------------===//
792
793LogicalResult StructExtractOp::verify() {
794 auto type = getStructFieldType(getInput().getType(), getFieldNameAttr());
795 if (!type)
796 return emitOpError() << "extracts field " << getFieldNameAttr()
797 << " which does not exist in " << getInput().getType();
798 if (type != getType())
799 return emitOpError() << "result type " << getType()
800 << " must match struct field type " << type;
801 return success();
802}
803
804OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
805 // Extract on a constant struct input.
806 if (auto fields = dyn_cast_or_null<DictionaryAttr>(adaptor.getInput()))
807 if (auto value = fields.get(getFieldNameAttr()))
808 return value;
809
810 // extract(inject(s, "field", v), "field") -> v
811 if (auto inject = getInput().getDefiningOp<StructInjectOp>()) {
812 if (inject.getFieldNameAttr() == getFieldNameAttr())
813 return inject.getNewValue();
814 return {};
815 }
816
817 // extract(create({"field": v, ...}), "field") -> v
818 if (auto create = getInput().getDefiningOp<StructCreateOp>()) {
819 if (auto index = getStructFieldIndex(create.getType(), getFieldNameAttr()))
820 return create.getFields()[*index];
821 return {};
822 }
823
824 return {};
825}
826
827//===----------------------------------------------------------------------===//
828// StructExtractRefOp
829//===----------------------------------------------------------------------===//
830
831LogicalResult StructExtractRefOp::verify() {
832 auto type = getStructFieldType(
833 cast<RefType>(getInput().getType()).getNestedType(), getFieldNameAttr());
834 if (!type)
835 return emitOpError() << "extracts field " << getFieldNameAttr()
836 << " which does not exist in " << getInput().getType();
837 if (type != getType().getNestedType())
838 return emitOpError() << "result ref of type " << getType().getNestedType()
839 << " must match struct field type " << type;
840 return success();
841}
842
843bool StructExtractRefOp::canRewire(
844 const DestructurableMemorySlot &slot,
845 SmallPtrSetImpl<Attribute> &usedIndices,
846 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
847 const DataLayout &dataLayout) {
848 if (slot.ptr != getInput())
849 return false;
850 auto index = getFieldNameAttr();
851 if (!index || !slot.subelementTypes.contains(index))
852 return false;
853 usedIndices.insert(index);
854 return true;
855}
856
857DeletionKind
858StructExtractRefOp::rewire(const DestructurableMemorySlot &slot,
859 DenseMap<Attribute, MemorySlot> &subslots,
860 OpBuilder &builder, const DataLayout &dataLayout) {
861 auto index = getFieldNameAttr();
862 const MemorySlot &memorySlot = subslots.at(index);
863 replaceAllUsesWith(memorySlot.ptr);
864 getInputMutable().drop();
865 erase();
866 return DeletionKind::Keep;
867}
868
869//===----------------------------------------------------------------------===//
870// StructInjectOp
871//===----------------------------------------------------------------------===//
872
873LogicalResult StructInjectOp::verify() {
874 auto type = getStructFieldType(getInput().getType(), getFieldNameAttr());
875 if (!type)
876 return emitOpError() << "injects field " << getFieldNameAttr()
877 << " which does not exist in " << getInput().getType();
878 if (type != getNewValue().getType())
879 return emitOpError() << "injected value " << getNewValue().getType()
880 << " must match struct field type " << type;
881 return success();
882}
883
884OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
885 auto input = adaptor.getInput();
886 auto newValue = adaptor.getNewValue();
887 if (!input || !newValue)
888 return {};
889 NamedAttrList fields(cast<DictionaryAttr>(input));
890 fields.set(getFieldNameAttr(), newValue);
891 return fields.getDictionary(getContext());
892}
893
894LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
895 PatternRewriter &rewriter) {
896 auto members = getStructMembers(op.getType());
897
898 // Chase a chain of `struct_inject` ops, with an optional final
899 // `struct_create`, and take note of the values assigned to each field.
900 SmallPtrSet<Operation *, 4> injectOps;
901 DenseMap<StringAttr, Value> fieldValues;
902 Value input = op;
903 while (auto injectOp = input.getDefiningOp<StructInjectOp>()) {
904 if (!injectOps.insert(injectOp).second)
905 return failure();
906 fieldValues.insert({injectOp.getFieldNameAttr(), injectOp.getNewValue()});
907 input = injectOp.getInput();
908 }
909 if (auto createOp = input.getDefiningOp<StructCreateOp>())
910 for (auto [value, member] : llvm::zip(createOp.getFields(), members))
911 fieldValues.insert({member.name, value});
912
913 // If the inject chain sets all fields, canonicalize to a `struct_create`.
914 if (fieldValues.size() == members.size()) {
915 SmallVector<Value> values;
916 values.reserve(fieldValues.size());
917 for (auto member : members)
918 values.push_back(fieldValues.lookup(member.name));
919 rewriter.replaceOpWithNewOp<StructCreateOp>(op, op.getType(), values);
920 return success();
921 }
922
923 // If each inject op in the chain assigned to a unique field, there is nothing
924 // to canonicalize.
925 if (injectOps.size() == fieldValues.size())
926 return failure();
927
928 // Otherwise we can eliminate overwrites by creating new injects. The hash map
929 // of field values contains the last assigned value for each field.
930 for (auto member : members)
931 if (auto value = fieldValues.lookup(member.name))
932 input = StructInjectOp::create(rewriter, op.getLoc(), op.getType(), input,
933 member.name, value);
934 rewriter.replaceOp(op, input);
935 return success();
936}
937
938//===----------------------------------------------------------------------===//
939// UnionCreateOp
940//===----------------------------------------------------------------------===//
941
942LogicalResult UnionCreateOp::verify() {
943 /// checks if the types of the input is exactly equal to the union field
944 /// type
945 return TypeSwitch<Type, LogicalResult>(getType())
946 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
947 auto members = type.getMembers();
948 auto resultType = getType();
949 auto fieldName = getFieldName();
950 for (const auto &member : members)
951 if (member.name == fieldName && member.type == resultType)
952 return success();
953 emitOpError("input type must match the union field type");
954 return failure();
955 })
956 .Default([this](auto &) {
957 emitOpError("input type must be UnionType or UnpackedUnionType");
958 return failure();
959 });
960}
961
962//===----------------------------------------------------------------------===//
963// UnionExtractOp
964//===----------------------------------------------------------------------===//
965
966LogicalResult UnionExtractOp::verify() {
967 /// checks if the types of the input is exactly equal to the one of the
968 /// types of the result union fields
969 return TypeSwitch<Type, LogicalResult>(getInput().getType())
970 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
971 auto members = type.getMembers();
972 auto fieldName = getFieldName();
973 auto resultType = getType();
974 for (const auto &member : members)
975 if (member.name == fieldName && member.type == resultType)
976 return success();
977 emitOpError("result type must match the union field type");
978 return failure();
979 })
980 .Default([this](auto &) {
981 emitOpError("input type must be UnionType or UnpackedUnionType");
982 return failure();
983 });
984}
985
986//===----------------------------------------------------------------------===//
// UnionExtractRefOp
988//===----------------------------------------------------------------------===//
989
990LogicalResult UnionExtractRefOp::verify() {
991 /// checks if the types of the result is exactly equal to the type of the
992 /// refe union field
993 return TypeSwitch<Type, LogicalResult>(getInput().getType().getNestedType())
994 .Case<UnionType, UnpackedUnionType>([this](auto &type) {
995 auto members = type.getMembers();
996 auto fieldName = getFieldName();
997 auto resultType = getType().getNestedType();
998 for (const auto &member : members)
999 if (member.name == fieldName && member.type == resultType)
1000 return success();
1001 emitOpError("result type must match the union field type");
1002 return failure();
1003 })
1004 .Default([this](auto &) {
1005 emitOpError("input type must be UnionType or UnpackedUnionType");
1006 return failure();
1007 });
1008}
1009
1010//===----------------------------------------------------------------------===//
1011// YieldOp
1012//===----------------------------------------------------------------------===//
1013
1014LogicalResult YieldOp::verify() {
1015 // Check that YieldOp's parent operation is ConditionalOp.
1016 auto cond = dyn_cast<ConditionalOp>(*(*this).getParentOp());
1017 if (!cond) {
1018 emitOpError("must have a conditional parent");
1019 return failure();
1020 }
1021
1022 // Check that the operand matches the parent operation's result.
1023 auto condType = cond.getType();
1024 auto yieldType = getOperand().getType();
1025 if (condType != yieldType) {
1026 emitOpError("yield type must match conditional. Expected ")
1027 << condType << ", but got " << yieldType << ".";
1028 return failure();
1029 }
1030
1031 return success();
1032}
1033
1034//===----------------------------------------------------------------------===//
1035// ConversionOp
1036//===----------------------------------------------------------------------===//
1037
1038OpFoldResult ConversionOp::fold(FoldAdaptor adaptor) {
1039 // Fold away no-op casts.
1040 if (getInput().getType() == getResult().getType())
1041 return getInput();
1042
1043 // Convert domains of constant integer inputs.
1044 auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput());
1045 auto fromIntType = dyn_cast<IntType>(getInput().getType());
1046 auto toIntType = dyn_cast<IntType>(getResult().getType());
1047 if (intInput && fromIntType && toIntType &&
1048 fromIntType.getWidth() == toIntType.getWidth()) {
1049 // If we are going *to* a four-valued type, simply pass through the
1050 // constant.
1051 if (toIntType.getDomain() == Domain::FourValued)
1052 return intInput;
1053
1054 // Otherwise map all unknown bits to zero (the default in SystemVerilog) and
1055 // return a new constant.
1056 return FVIntegerAttr::get(getContext(), intInput.getValue().toAPInt(false));
1057 }
1058
1059 return {};
1060}
1061
1062//===----------------------------------------------------------------------===//
1063// LogicToIntOp
1064//===----------------------------------------------------------------------===//
1065
1066OpFoldResult LogicToIntOp::fold(FoldAdaptor adaptor) {
1067 // logic_to_int(int_to_logic(x)) -> x
1068 if (auto reverseOp = getInput().getDefiningOp<IntToLogicOp>())
1069 return reverseOp.getInput();
1070
1071 // Map all unknown bits to zero (the default in SystemVerilog) and return a
1072 // new constant.
1073 if (auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1074 return FVIntegerAttr::get(getContext(), intInput.getValue().toAPInt(false));
1075
1076 return {};
1077}
1078
1079//===----------------------------------------------------------------------===//
1080// IntToLogicOp
1081//===----------------------------------------------------------------------===//
1082
1083OpFoldResult IntToLogicOp::fold(FoldAdaptor adaptor) {
1084 // Cannot fold int_to_logic(logic_to_int(x)) -> x since that would lose
1085 // information.
1086
1087 // Simply pass through constants.
1088 if (auto intInput = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1089 return intInput;
1090
1091 return {};
1092}
1093
1094//===----------------------------------------------------------------------===//
1095// TimeToLogicOp
1096//===----------------------------------------------------------------------===//
1097
1098OpFoldResult TimeToLogicOp::fold(FoldAdaptor adaptor) {
1099 // time_to_logic(logic_to_time(x)) -> x
1100 if (auto reverseOp = getInput().getDefiningOp<LogicToTimeOp>())
1101 return reverseOp.getInput();
1102
1103 // Convert constants.
1104 if (auto attr = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
1105 return FVIntegerAttr::get(getContext(), attr.getValue());
1106
1107 return {};
1108}
1109
1110//===----------------------------------------------------------------------===//
1111// LogicToTimeOp
1112//===----------------------------------------------------------------------===//
1113
1114OpFoldResult LogicToTimeOp::fold(FoldAdaptor adaptor) {
1115 // logic_to_time(time_to_logic(x)) -> x
1116 if (auto reverseOp = getInput().getDefiningOp<TimeToLogicOp>())
1117 return reverseOp.getInput();
1118
1119 // Convert constants.
1120 if (auto attr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput()))
1121 return IntegerAttr::get(getContext(), APSInt(attr.getValue().toAPInt(false),
1122 /*isUnsigned=*/true));
1123
1124 return {};
1125}
1126
1127//===----------------------------------------------------------------------===//
1128// TruncOp
1129//===----------------------------------------------------------------------===//
1130
1131OpFoldResult TruncOp::fold(FoldAdaptor adaptor) {
1132 // Truncate constants.
1133 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1134 auto width = getType().getWidth();
1135 return FVIntegerAttr::get(getContext(), intAttr.getValue().trunc(width));
1136 }
1137
1138 return {};
1139}
1140
1141//===----------------------------------------------------------------------===//
1142// ZExtOp
1143//===----------------------------------------------------------------------===//
1144
1145OpFoldResult ZExtOp::fold(FoldAdaptor adaptor) {
1146 // Zero-extend constants.
1147 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1148 auto width = getType().getWidth();
1149 return FVIntegerAttr::get(getContext(), intAttr.getValue().zext(width));
1150 }
1151
1152 return {};
1153}
1154
1155//===----------------------------------------------------------------------===//
1156// SExtOp
1157//===----------------------------------------------------------------------===//
1158
1159OpFoldResult SExtOp::fold(FoldAdaptor adaptor) {
1160 // Sign-extend constants.
1161 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getInput())) {
1162 auto width = getType().getWidth();
1163 return FVIntegerAttr::get(getContext(), intAttr.getValue().sext(width));
1164 }
1165
1166 return {};
1167}
1168
1169//===----------------------------------------------------------------------===//
1170// BoolCastOp
1171//===----------------------------------------------------------------------===//
1172
1173OpFoldResult BoolCastOp::fold(FoldAdaptor adaptor) {
1174 // Fold away no-op casts.
1175 if (getInput().getType() == getResult().getType())
1176 return getInput();
1177 return {};
1178}
1179
1180//===----------------------------------------------------------------------===//
1181// BlockingAssignOp
1182//===----------------------------------------------------------------------===//
1183
1184bool BlockingAssignOp::loadsFrom(const MemorySlot &slot) { return false; }
1185
1186bool BlockingAssignOp::storesTo(const MemorySlot &slot) {
1187 return getDst() == slot.ptr;
1188}
1189
1190Value BlockingAssignOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1191 Value reachingDef,
1192 const DataLayout &dataLayout) {
1193 return getSrc();
1194}
1195
1196bool BlockingAssignOp::canUsesBeRemoved(
1197 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1198 SmallVectorImpl<OpOperand *> &newBlockingUses,
1199 const DataLayout &dataLayout) {
1200
1201 if (blockingUses.size() != 1)
1202 return false;
1203 Value blockingUse = (*blockingUses.begin())->get();
1204 return blockingUse == slot.ptr && getDst() == slot.ptr &&
1205 getSrc() != slot.ptr && getSrc().getType() == slot.elemType;
1206}
1207
1208DeletionKind BlockingAssignOp::removeBlockingUses(
1209 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1210 OpBuilder &builder, Value reachingDefinition,
1211 const DataLayout &dataLayout) {
1212 return DeletionKind::Delete;
1213}
1214
1215//===----------------------------------------------------------------------===//
1216// ReadOp
1217//===----------------------------------------------------------------------===//
1218
1219bool ReadOp::loadsFrom(const MemorySlot &slot) {
1220 return getInput() == slot.ptr;
1221}
1222
1223bool ReadOp::storesTo(const MemorySlot &slot) { return false; }
1224
1225Value ReadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1226 Value reachingDef, const DataLayout &dataLayout) {
1227 llvm_unreachable("getStored should not be called on ReadOp");
1228}
1229
1230bool ReadOp::canUsesBeRemoved(const MemorySlot &slot,
1231 const SmallPtrSetImpl<OpOperand *> &blockingUses,
1232 SmallVectorImpl<OpOperand *> &newBlockingUses,
1233 const DataLayout &dataLayout) {
1234
1235 if (blockingUses.size() != 1)
1236 return false;
1237 Value blockingUse = (*blockingUses.begin())->get();
1238 return blockingUse == slot.ptr && getOperand() == slot.ptr &&
1239 getResult().getType() == slot.elemType;
1240}
1241
1242DeletionKind
1243ReadOp::removeBlockingUses(const MemorySlot &slot,
1244 const SmallPtrSetImpl<OpOperand *> &blockingUses,
1245 OpBuilder &builder, Value reachingDefinition,
1246 const DataLayout &dataLayout) {
1247 getResult().replaceAllUsesWith(reachingDefinition);
1248 return DeletionKind::Delete;
1249}
1250
1251//===----------------------------------------------------------------------===//
1252// PowSOp
1253//===----------------------------------------------------------------------===//
1254
1255static OpFoldResult powCommonFolding(MLIRContext *ctxt, Attribute lhs,
1256 Attribute rhs) {
1257 auto lhsValue = dyn_cast_or_null<FVIntegerAttr>(lhs);
1258 if (lhsValue && lhsValue.getValue() == 1)
1259 return lhs;
1260
1261 auto rhsValue = dyn_cast_or_null<FVIntegerAttr>(rhs);
1262 if (rhsValue && rhsValue.getValue().isZero())
1263 return FVIntegerAttr::get(ctxt,
1264 FVInt(rhsValue.getValue().getBitWidth(), 1));
1265
1266 return {};
1267}
1268
1269OpFoldResult PowSOp::fold(FoldAdaptor adaptor) {
1270 return powCommonFolding(getContext(), adaptor.getLhs(), adaptor.getRhs());
1271}
1272
1273LogicalResult PowSOp::canonicalize(PowSOp op, PatternRewriter &rewriter) {
1274 Location loc = op.getLoc();
1275 auto intType = cast<IntType>(op.getRhs().getType());
1276 if (auto baseOp = op.getLhs().getDefiningOp<ConstantOp>()) {
1277 if (baseOp.getValue() == 2) {
1278 Value constOne = ConstantOp::create(rewriter, loc, intType, 1);
1279 Value constZero = ConstantOp::create(rewriter, loc, intType, 0);
1280 Value shift = ShlOp::create(rewriter, loc, constOne, op.getRhs());
1281 Value isNegative = SltOp::create(rewriter, loc, op.getRhs(), constZero);
1282 auto condOp = rewriter.replaceOpWithNewOp<ConditionalOp>(
1283 op, op.getLhs().getType(), isNegative);
1284 Block *thenBlock = rewriter.createBlock(&condOp.getTrueRegion());
1285 rewriter.setInsertionPointToStart(thenBlock);
1286 YieldOp::create(rewriter, loc, constZero);
1287 Block *elseBlock = rewriter.createBlock(&condOp.getFalseRegion());
1288 rewriter.setInsertionPointToStart(elseBlock);
1289 YieldOp::create(rewriter, loc, shift);
1290 return success();
1291 }
1292 }
1293
1294 return failure();
1295}
1296
1297//===----------------------------------------------------------------------===//
1298// PowUOp
1299//===----------------------------------------------------------------------===//
1300
1301OpFoldResult PowUOp::fold(FoldAdaptor adaptor) {
1302 return powCommonFolding(getContext(), adaptor.getLhs(), adaptor.getRhs());
1303}
1304
1305LogicalResult PowUOp::canonicalize(PowUOp op, PatternRewriter &rewriter) {
1306 Location loc = op.getLoc();
1307 auto intType = cast<IntType>(op.getRhs().getType());
1308 if (auto baseOp = op.getLhs().getDefiningOp<ConstantOp>()) {
1309 if (baseOp.getValue() == 2) {
1310 Value constOne = ConstantOp::create(rewriter, loc, intType, 1);
1311 rewriter.replaceOpWithNewOp<ShlOp>(op, constOne, op.getRhs());
1312 return success();
1313 }
1314 }
1315
1316 return failure();
1317}
1318
1319//===----------------------------------------------------------------------===//
1320// SubOp
1321//===----------------------------------------------------------------------===//
1322
1323OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1324 if (auto intAttr = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs()))
1325 if (intAttr.getValue().isZero())
1326 return getLhs();
1327
1328 return {};
1329}
1330
1331//===----------------------------------------------------------------------===//
1332// MulOp
1333//===----------------------------------------------------------------------===//
1334
1335OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1336 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1337 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1338 if (lhs && rhs)
1339 return FVIntegerAttr::get(getContext(), lhs.getValue() * rhs.getValue());
1340 return {};
1341}
1342
1343//===----------------------------------------------------------------------===//
1344// DivUOp
1345//===----------------------------------------------------------------------===//
1346
1347OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1348 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1349 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1350 if (lhs && rhs)
1351 return FVIntegerAttr::get(getContext(),
1352 lhs.getValue().udiv(rhs.getValue()));
1353 return {};
1354}
1355
1356//===----------------------------------------------------------------------===//
1357// DivSOp
1358//===----------------------------------------------------------------------===//
1359
1360OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1361 auto lhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getLhs());
1362 auto rhs = dyn_cast_or_null<FVIntegerAttr>(adaptor.getRhs());
1363 if (lhs && rhs)
1364 return FVIntegerAttr::get(getContext(),
1365 lhs.getValue().sdiv(rhs.getValue()));
1366 return {};
1367}
1368
1369//===----------------------------------------------------------------------===//
1370// TableGen generated logic.
1371//===----------------------------------------------------------------------===//
1372
1373// Provide the autogenerated implementation guts for the Op classes.
1374#define GET_OP_CLASSES
1375#include "circt/Dialect/Moore/Moore.cpp.inc"
1376#include "circt/Dialect/Moore/MooreEnums.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
@ Output
Definition HW.h:35
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static OpFoldResult powCommonFolding(MLIRContext *ctxt, Attribute lhs, Attribute rhs)
static ArrayRef< StructLikeMember > getStructMembers(Type type)
Definition MooreOps.cpp:743
static std::optional< uint32_t > getStructFieldIndex(Type type, StringAttr name)
Definition MooreOps.cpp:734
static UnpackedType getStructFieldType(Type type, StringAttr name)
Definition MooreOps.cpp:752
static std::pair< unsigned, UnpackedType > getArrayElements(Type type)
Definition MooreOps.cpp:701
static InstancePath empty
Four-valued arbitrary precision integers.
Definition FVInt.h:37
bool isNegative() const
Determine whether the integer interpreted as a signed number would be negative.
Definition FVInt.h:185
FVInt sext(unsigned bitWidth) const
Sign-extend the integer to a new bit width.
Definition FVInt.h:148
unsigned getSignificantBits() const
Compute the minimum bit width necessary to accurately represent this integer's value and sign.
Definition FVInt.h:102
static FVInt getAllX(unsigned numBits)
Construct an FVInt with all bits set to X.
Definition FVInt.h:75
bool hasUnknown() const
Determine if any bits are X or Z.
Definition FVInt.h:168
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition FVInt.h:92
unsigned getBitWidth() const
Return the number of bits this integer has.
Definition FVInt.h:85
FVInt trunc(unsigned bitWidth) const
Truncate the integer to a smaller bit width.
Definition FVInt.h:132
A packed SystemVerilog type.
Definition MooreTypes.h:141
std::optional< unsigned > getBitSize() const
Get the size of this type in bits.
Domain getDomain() const
Get the value domain of this type.
An unpacked SystemVerilog type.
Definition MooreTypes.h:90
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition HWOps.cpp:529
Domain
The number of values each bit of a type can assume.
Definition MooreTypes.h:48
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
void printFVInt(AsmPrinter &p, const FVInt &value)
Print a four-valued integer using an AsmPrinter.
Definition FVInt.cpp:147
ParseResult parseFVInt(AsmParser &p, FVInt &result)
Parse a four-valued integer using an AsmParser.
Definition FVInt.cpp:162
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183