CIRCT 22.0.0git
Loading...
Searching...
No Matches
SVOps.cpp
Go to the documentation of this file.
1//===- SVOps.cpp - Implement the SV operations ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the SV ops.
10//
11//===----------------------------------------------------------------------===//
12
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinTypes.h"
26#include "mlir/IR/Matchers.h"
27#include "mlir/IR/PatternMatch.h"
28#include "mlir/Interfaces/FunctionImplementation.h"
29#include "llvm/ADT/SmallString.h"
30#include "llvm/ADT/StringExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
32
33#include <optional>
34
35using namespace circt;
36using namespace sv;
37using mlir::TypedAttr;
38
39/// Return true if the specified expression is 2-state. This is determined by
40/// looking at the defining op. This can look as far through the dataflow as it
41/// wants, but for now, it is just looking at the single value.
42bool sv::is2StateExpression(Value v) {
43 if (auto *op = v.getDefiningOp()) {
44 if (auto attr = op->getAttrOfType<UnitAttr>("twoState"))
45 return (bool)attr;
46 }
47 // Plain constants are obviously safe
48 return v.getDefiningOp<hw::ConstantOp>();
49}
50
51/// Return true if the specified operation is an expression.
52bool sv::isExpression(Operation *op) {
53 return isa<VerbatimExprOp, VerbatimExprSEOp, GetModportOp,
54 ReadInterfaceSignalOp, ConstantXOp, ConstantZOp, ConstantStrOp,
55 MacroRefExprOp, MacroRefExprSEOp>(op);
56}
57
58LogicalResult sv::verifyInProceduralRegion(Operation *op) {
59 if (op->getParentOp()->hasTrait<sv::ProceduralRegion>())
60 return success();
61 op->emitError() << op->getName() << " should be in a procedural region";
62 return failure();
63}
64
65LogicalResult sv::verifyInNonProceduralRegion(Operation *op) {
66 if (!op->getParentOp()->hasTrait<sv::ProceduralRegion>())
67 return success();
68 op->emitError() << op->getName() << " should be in a non-procedural region";
69 return failure();
70}
71
72/// Returns the operation registered with the given symbol name with the regions
73/// of 'symbolTableOp'. recurse through nested regions which don't contain the
74/// symboltable trait. Returns nullptr if no valid symbol was found.
75static Operation *lookupSymbolInNested(Operation *symbolTableOp,
76 StringRef symbol) {
77 Region &region = symbolTableOp->getRegion(0);
78 if (region.empty())
79 return nullptr;
80
81 // Look for a symbol with the given name.
82 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
83 SymbolTable::getSymbolAttrName());
84 for (Block &block : region)
85 for (Operation &nestedOp : block) {
86 auto nameAttr = nestedOp.getAttrOfType<StringAttr>(symbolNameId);
87 if (nameAttr && nameAttr.getValue() == symbol)
88 return &nestedOp;
89 if (!nestedOp.hasTrait<OpTrait::SymbolTable>() &&
90 nestedOp.getNumRegions()) {
91 if (auto *nop = lookupSymbolInNested(&nestedOp, symbol))
92 return nop;
93 }
94 }
95 return nullptr;
96}
97
98/// Verifies symbols referenced by macro identifiers.
99static LogicalResult
100verifyMacroIdentSymbolUses(Operation *op, FlatSymbolRefAttr attr,
101 SymbolTableCollection &symbolTable) {
102 auto *refOp = symbolTable.lookupNearestSymbolFrom(op, attr);
103 if (!refOp)
104 return op->emitError("references an undefined symbol: ") << attr;
105 if (!isa<MacroDeclOp>(refOp))
106 return op->emitError("must reference a macro declaration");
107 return success();
108}
109
110//===----------------------------------------------------------------------===//
111// VerbatimExprOp
112//===----------------------------------------------------------------------===//
113
114/// Get the asm name for sv.verbatim.expr and sv.verbatim.expr.se.
115static void
117 function_ref<void(Value, StringRef)> setNameFn) {
118 // If the string is macro like, then use a pretty name. We only take the
119 // string up to a weird character (like a paren) and currently ignore
120 // parenthesized expressions.
121 auto isOkCharacter = [](char c) { return llvm::isAlnum(c) || c == '_'; };
122 auto name = op->getAttrOfType<StringAttr>("format_string").getValue();
123 // Ignore a leading ` in macro name.
124 if (name.starts_with("`"))
125 name = name.drop_front();
126 name = name.take_while(isOkCharacter);
127 if (!name.empty())
128 setNameFn(op->getResult(0), name);
129}
130
131void VerbatimExprOp::getAsmResultNames(
132 function_ref<void(Value, StringRef)> setNameFn) {
133 getVerbatimExprAsmResultNames(getOperation(), std::move(setNameFn));
134}
135
136void VerbatimExprSEOp::getAsmResultNames(
137 function_ref<void(Value, StringRef)> setNameFn) {
138 getVerbatimExprAsmResultNames(getOperation(), std::move(setNameFn));
139}
140
141//===----------------------------------------------------------------------===//
142// MacroRefExprOp
143//===----------------------------------------------------------------------===//
144
145void MacroRefExprOp::getAsmResultNames(
146 function_ref<void(Value, StringRef)> setNameFn) {
147 setNameFn(getResult(), getMacroName());
148}
149
150void MacroRefExprSEOp::getAsmResultNames(
151 function_ref<void(Value, StringRef)> setNameFn) {
152 setNameFn(getResult(), getMacroName());
153}
154
155static MacroDeclOp getReferencedMacro(const hw::HWSymbolCache *cache,
156 Operation *op,
157 FlatSymbolRefAttr macroName) {
158 if (cache)
159 if (auto *result = cache->getDefinition(macroName.getAttr()))
160 return cast<MacroDeclOp>(result);
161
162 auto topLevelModuleOp = op->getParentOfType<ModuleOp>();
163 return topLevelModuleOp.lookupSymbol<MacroDeclOp>(macroName.getValue());
164}
165
166/// Lookup the module or extmodule for the symbol. This returns null on
167/// invalid IR.
168MacroDeclOp MacroRefExprOp::getReferencedMacro(const hw::HWSymbolCache *cache) {
169 return ::getReferencedMacro(cache, *this, getMacroNameAttr());
170}
171
172MacroDeclOp
173MacroRefExprSEOp::getReferencedMacro(const hw::HWSymbolCache *cache) {
174 return ::getReferencedMacro(cache, *this, getMacroNameAttr());
175}
176
177//===----------------------------------------------------------------------===//
178// MacroErrorOp
179//===----------------------------------------------------------------------===//
180
181std::string MacroErrorOp::getMacroIdentifier() {
182 const auto *prefix = "_ERROR";
183 auto msg = getMessage();
184 if (!msg || msg->empty())
185 return prefix;
186
187 std::string id(prefix);
188 id.push_back('_');
189 for (auto c : *msg) {
190 if (llvm::isAlnum(c))
191 id.push_back(c);
192 else
193 id.push_back('_');
194 }
195 return id;
196}
197
198//===----------------------------------------------------------------------===//
// MacroDefOp / MacroRefOp
200//===----------------------------------------------------------------------===//
201
202MacroDeclOp MacroDefOp::getReferencedMacro(const hw::HWSymbolCache *cache) {
203 return ::getReferencedMacro(cache, *this, getMacroNameAttr());
204}
205
206MacroDeclOp MacroRefOp::getReferencedMacro(const hw::HWSymbolCache *cache) {
207 return ::getReferencedMacro(cache, *this, getMacroNameAttr());
208}
209
210/// Ensure that the symbol being instantiated exists and is a MacroDefOp.
211LogicalResult
212MacroRefExprOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
213 return verifyMacroIdentSymbolUses(*this, getMacroNameAttr(), symbolTable);
214}
215
216/// Ensure that the symbol being instantiated exists and is a MacroDefOp.
217LogicalResult
218MacroRefExprSEOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
219 return verifyMacroIdentSymbolUses(*this, getMacroNameAttr(), symbolTable);
220}
221
222/// Ensure that the symbol being instantiated exists and is a MacroDefOp.
223LogicalResult MacroDefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
224 return verifyMacroIdentSymbolUses(*this, getMacroNameAttr(), symbolTable);
225}
226
227/// Ensure that the symbol being instantiated exists and is a MacroDefOp.
228LogicalResult MacroRefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
229 return verifyMacroIdentSymbolUses(*this, getMacroNameAttr(), symbolTable);
230}
231
232//===----------------------------------------------------------------------===//
233// MacroDeclOp
234//===----------------------------------------------------------------------===//
235
236StringRef MacroDeclOp::getMacroIdentifier() {
237 return getVerilogName().value_or(getSymName());
238}
239
240//===----------------------------------------------------------------------===//
241// ConstantXOp / ConstantZOp
242//===----------------------------------------------------------------------===//
243
244void ConstantXOp::getAsmResultNames(
245 function_ref<void(Value, StringRef)> setNameFn) {
246 SmallVector<char, 32> specialNameBuffer;
247 llvm::raw_svector_ostream specialName(specialNameBuffer);
248 specialName << "x_i" << getWidth();
249 setNameFn(getResult(), specialName.str());
250}
251
252LogicalResult ConstantXOp::verify() {
253 // We don't allow zero width constant or unknown width.
254 if (getWidth() <= 0)
255 return emitError("unsupported type");
256 return success();
257}
258
259void ConstantZOp::getAsmResultNames(
260 function_ref<void(Value, StringRef)> setNameFn) {
261 SmallVector<char, 32> specialNameBuffer;
262 llvm::raw_svector_ostream specialName(specialNameBuffer);
263 specialName << "z_i" << getWidth();
264 setNameFn(getResult(), specialName.str());
265}
266
267LogicalResult ConstantZOp::verify() {
268 // We don't allow zero width constant or unknown type.
269 if (getWidth() <= 0)
270 return emitError("unsupported type");
271 return success();
272}
273
274//===----------------------------------------------------------------------===//
275// LocalParamOp
276//===----------------------------------------------------------------------===//
277
278void LocalParamOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
279 // If the localparam has an optional 'name' attribute, use it.
280 auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
281 if (!nameAttr.getValue().empty())
282 setNameFn(getResult(), nameAttr.getValue());
283}
284
285LogicalResult LocalParamOp::verify() {
286 // Verify that this is a valid parameter value.
287 return hw::checkParameterInContext(
288 getValue(), (*this)->getParentOfType<hw::HWModuleOp>(), *this);
289}
290
291//===----------------------------------------------------------------------===//
292// RegOp
293//===----------------------------------------------------------------------===//
294
295static ParseResult
296parseImplicitInitType(OpAsmParser &p, mlir::Type regType,
297 std::optional<OpAsmParser::UnresolvedOperand> &initValue,
298 mlir::Type &initType) {
299 if (!initValue.has_value())
300 return success();
301
302 hw::InOutType ioType = dyn_cast<hw::InOutType>(regType);
303 if (!ioType)
304 return p.emitError(p.getCurrentLocation(), "expected inout type for reg");
305
306 initType = ioType.getElementType();
307 return success();
308}
309
310static void printImplicitInitType(OpAsmPrinter &p, Operation *op,
311 mlir::Type regType, mlir::Value initValue,
312 mlir::Type initType) {}
313
314void RegOp::build(OpBuilder &builder, OperationState &odsState,
315 Type elementType, StringAttr name, hw::InnerSymAttr innerSym,
316 mlir::Value initValue) {
317 if (!name)
318 name = builder.getStringAttr("");
319 odsState.addAttribute("name", name);
320 if (innerSym)
321 odsState.addAttribute(hw::InnerSymbolTable::getInnerSymbolAttrName(),
322 innerSym);
323 odsState.addTypes(hw::InOutType::get(elementType));
324 if (initValue)
325 odsState.addOperands(initValue);
326}
327
328/// Suggest a name for each result value based on the saved result names
329/// attribute.
330void RegOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
331 // If the wire has an optional 'name' attribute, use it.
332 auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
333 if (!nameAttr.getValue().empty())
334 setNameFn(getResult(), nameAttr.getValue());
335}
336
337std::optional<size_t> RegOp::getTargetResultIndex() { return 0; }
338
339// If this reg is only written to, delete the reg and all writers.
340LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
341 // Block if op has SV attributes.
342 if (hasSVAttributes(op))
343 return failure();
344
345 // If the reg has a symbol, then we can't delete it.
346 if (op.getInnerSymAttr())
347 return failure();
348 // Check that all operations on the wire are sv.assigns. All other wire
349 // operations will have been handled by other canonicalization.
350 for (auto *user : op.getResult().getUsers())
351 if (!isa<AssignOp>(user))
352 return failure();
353
354 // Remove all uses of the wire.
355 for (auto *user : llvm::make_early_inc_range(op.getResult().getUsers()))
356 rewriter.eraseOp(user);
357
358 // Remove the wire.
359 rewriter.eraseOp(op);
360 return success();
361}
362
363//===----------------------------------------------------------------------===//
364// LogicOp
365//===----------------------------------------------------------------------===//
366
367void LogicOp::build(OpBuilder &builder, OperationState &odsState,
368 Type elementType, StringAttr name,
369 hw::InnerSymAttr innerSym) {
370 if (!name)
371 name = builder.getStringAttr("");
372 odsState.addAttribute("name", name);
373 if (innerSym)
374 odsState.addAttribute(hw::InnerSymbolTable::getInnerSymbolAttrName(),
375 innerSym);
376 odsState.addTypes(hw::InOutType::get(elementType));
377}
378
379/// Suggest a name for each result value based on the saved result names
380/// attribute.
381void LogicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
382 // If the logic has an optional 'name' attribute, use it.
383 auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
384 if (!nameAttr.getValue().empty())
385 setNameFn(getResult(), nameAttr.getValue());
386}
387
388std::optional<size_t> LogicOp::getTargetResultIndex() { return 0; }
389
390//===----------------------------------------------------------------------===//
391// Control flow like-operations
392//===----------------------------------------------------------------------===//
393
394//===----------------------------------------------------------------------===//
395// IfDefOp
396//===----------------------------------------------------------------------===//
397
398void IfDefOp::build(OpBuilder &builder, OperationState &result, StringRef cond,
399 std::function<void()> thenCtor,
400 std::function<void()> elseCtor) {
401 build(builder, result, builder.getStringAttr(cond), std::move(thenCtor),
402 std::move(elseCtor));
403}
404
405void IfDefOp::build(OpBuilder &builder, OperationState &result, StringAttr cond,
406 std::function<void()> thenCtor,
407 std::function<void()> elseCtor) {
408 build(builder, result, FlatSymbolRefAttr::get(builder.getContext(), cond),
409 std::move(thenCtor), std::move(elseCtor));
410}
411
412void IfDefOp::build(OpBuilder &builder, OperationState &result,
413 FlatSymbolRefAttr cond, std::function<void()> thenCtor,
414 std::function<void()> elseCtor) {
415 build(builder, result, MacroIdentAttr::get(builder.getContext(), cond),
416 std::move(thenCtor), std::move(elseCtor));
417}
418
419void IfDefOp::build(OpBuilder &builder, OperationState &result,
420 MacroIdentAttr cond, std::function<void()> thenCtor,
421 std::function<void()> elseCtor) {
422 OpBuilder::InsertionGuard guard(builder);
423
424 result.addAttribute("cond", cond);
425 builder.createBlock(result.addRegion());
426
427 // Fill in the body of the #ifdef.
428 if (thenCtor)
429 thenCtor();
430
431 Region *elseRegion = result.addRegion();
432 if (elseCtor) {
433 builder.createBlock(elseRegion);
434 elseCtor();
435 }
436}
437
438LogicalResult IfDefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
439 return verifyMacroIdentSymbolUses(*this, getCond().getIdent(), symbolTable);
440}
441
442// If both thenRegion and elseRegion are empty, erase op.
443template <class Op>
444static LogicalResult canonicalizeIfDefLike(Op op, PatternRewriter &rewriter) {
445 if (!op.getThenBlock()->empty())
446 return failure();
447
448 if (op.hasElse() && !op.getElseBlock()->empty())
449 return failure();
450
451 rewriter.eraseOp(op);
452 return success();
453}
454
455LogicalResult IfDefOp::canonicalize(IfDefOp op, PatternRewriter &rewriter) {
456 return canonicalizeIfDefLike(op, rewriter);
457}
458
459//===----------------------------------------------------------------------===//
460// IfDefProceduralOp
461//===----------------------------------------------------------------------===//
462
463void IfDefProceduralOp::build(OpBuilder &builder, OperationState &result,
464 StringRef cond, std::function<void()> thenCtor,
465 std::function<void()> elseCtor) {
466 build(builder, result, builder.getStringAttr(cond), std::move(thenCtor),
467 std::move(elseCtor));
468}
469
470void IfDefProceduralOp::build(OpBuilder &builder, OperationState &result,
471 StringAttr cond, std::function<void()> thenCtor,
472 std::function<void()> elseCtor) {
473 build(builder, result, FlatSymbolRefAttr::get(builder.getContext(), cond),
474 std::move(thenCtor), std::move(elseCtor));
475}
476
477void IfDefProceduralOp::build(OpBuilder &builder, OperationState &result,
478 FlatSymbolRefAttr cond,
479 std::function<void()> thenCtor,
480 std::function<void()> elseCtor) {
481 build(builder, result, MacroIdentAttr::get(builder.getContext(), cond),
482 std::move(thenCtor), std::move(elseCtor));
483}
484
485void IfDefProceduralOp::build(OpBuilder &builder, OperationState &result,
486 MacroIdentAttr cond,
487 std::function<void()> thenCtor,
488 std::function<void()> elseCtor) {
489 OpBuilder::InsertionGuard guard(builder);
490
491 result.addAttribute("cond", cond);
492 builder.createBlock(result.addRegion());
493
494 // Fill in the body of the #ifdef.
495 if (thenCtor)
496 thenCtor();
497
498 Region *elseRegion = result.addRegion();
499 if (elseCtor) {
500 builder.createBlock(elseRegion);
501 elseCtor();
502 }
503}
504
505LogicalResult IfDefProceduralOp::canonicalize(IfDefProceduralOp op,
506 PatternRewriter &rewriter) {
507 return canonicalizeIfDefLike(op, rewriter);
508}
509
510LogicalResult
511IfDefProceduralOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
512 return verifyMacroIdentSymbolUses(*this, getCond().getIdent(), symbolTable);
513}
514
515//===----------------------------------------------------------------------===//
516// IfOp
517//===----------------------------------------------------------------------===//
518
519void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
520 std::function<void()> thenCtor,
521 std::function<void()> elseCtor) {
522 OpBuilder::InsertionGuard guard(builder);
523
524 result.addOperands(cond);
525 builder.createBlock(result.addRegion());
526
527 // Fill in the body of the if.
528 if (thenCtor)
529 thenCtor();
530
531 Region *elseRegion = result.addRegion();
532 if (elseCtor) {
533 builder.createBlock(elseRegion);
534 elseCtor();
535 }
536}
537
538/// Replaces the given op with the contents of the given single-block region.
539static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
540 Region &region) {
541 assert(llvm::hasSingleElement(region) && "expected single-region block");
542 Block *fromBlock = &region.front();
543 // Merge it in above the specified operation.
544 op->getBlock()->getOperations().splice(Block::iterator(op),
545 fromBlock->getOperations());
546}
547
548LogicalResult IfOp::canonicalize(IfOp op, PatternRewriter &rewriter) {
549 // Block if op has SV attributes.
550 if (hasSVAttributes(op))
551 return failure();
552
553 if (auto constant = op.getCond().getDefiningOp<hw::ConstantOp>()) {
554
555 if (constant.getValue().isAllOnes())
556 replaceOpWithRegion(rewriter, op, op.getThenRegion());
557 else if (!op.getElseRegion().empty())
558 replaceOpWithRegion(rewriter, op, op.getElseRegion());
559
560 rewriter.eraseOp(op);
561
562 return success();
563 }
564
565 // Erase empty if-else block.
566 if (!op.getThenBlock()->empty() && op.hasElse() &&
567 op.getElseBlock()->empty()) {
568 rewriter.eraseBlock(op.getElseBlock());
569 return success();
570 }
571
572 // Erase empty if's.
573
574 // If there is stuff in the then block, leave this operation alone.
575 if (!op.getThenBlock()->empty())
576 return failure();
577
578 // If not and there is no else, then this operation is just useless.
579 if (!op.hasElse() || op.getElseBlock()->empty()) {
580 rewriter.eraseOp(op);
581 return success();
582 }
583
584 // Otherwise, invert the condition and move the 'else' block to the 'then'
585 // region if the condition is a 2-state operation. This changes x prop
586 // behavior so it needs to be guarded.
587 if (is2StateExpression(op.getCond())) {
588 auto cond = comb::createOrFoldNot(op.getLoc(), op.getCond(), rewriter);
589 op.setOperand(cond);
590
591 auto *thenBlock = op.getThenBlock(), *elseBlock = op.getElseBlock();
592
593 // Move the body of the then block over to the else.
594 thenBlock->getOperations().splice(thenBlock->end(),
595 elseBlock->getOperations());
596 rewriter.eraseBlock(elseBlock);
597 return success();
598 }
599 return failure();
600}
601
602//===----------------------------------------------------------------------===//
603// AlwaysOp
604//===----------------------------------------------------------------------===//
605
606AlwaysOp::Condition AlwaysOp::getCondition(size_t idx) {
607 return Condition{EventControl(cast<IntegerAttr>(getEvents()[idx]).getInt()),
608 getOperand(idx)};
609}
610
611void AlwaysOp::build(OpBuilder &builder, OperationState &result,
612 ArrayRef<sv::EventControl> events, ArrayRef<Value> clocks,
613 std::function<void()> bodyCtor) {
614 assert(events.size() == clocks.size() &&
615 "mismatch between event and clock list");
616 OpBuilder::InsertionGuard guard(builder);
617
618 SmallVector<Attribute> eventAttrs;
619 for (auto event : events)
620 eventAttrs.push_back(
621 builder.getI32IntegerAttr(static_cast<int32_t>(event)));
622 result.addAttribute("events", builder.getArrayAttr(eventAttrs));
623 result.addOperands(clocks);
624
625 // Set up the body. Moves the insert point
626 builder.createBlock(result.addRegion());
627
628 // Fill in the body of the #ifdef.
629 if (bodyCtor)
630 bodyCtor();
631}
632
633/// Ensure that the symbol being instantiated exists and is an InterfaceOp.
634LogicalResult AlwaysOp::verify() {
635 if (getEvents().size() != getNumOperands())
636 return emitError("different number of operands and events");
637 return success();
638}
639
640static ParseResult parseEventList(
641 OpAsmParser &p, Attribute &eventsAttr,
642 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &clocksOperands) {
643
644 // Parse zero or more conditions intoevents and clocksOperands.
645 SmallVector<Attribute> events;
646
647 auto loc = p.getCurrentLocation();
648 StringRef keyword;
649 if (!p.parseOptionalKeyword(&keyword)) {
650 while (1) {
651 auto kind = sv::symbolizeEventControl(keyword);
652 if (!kind.has_value())
653 return p.emitError(loc, "expected 'posedge', 'negedge', or 'edge'");
654 auto eventEnum = static_cast<int32_t>(*kind);
655 events.push_back(p.getBuilder().getI32IntegerAttr(eventEnum));
656
657 clocksOperands.push_back({});
658 if (p.parseOperand(clocksOperands.back()))
659 return failure();
660
661 if (failed(p.parseOptionalComma()))
662 break;
663 if (p.parseKeyword(&keyword))
664 return failure();
665 }
666 }
667 eventsAttr = p.getBuilder().getArrayAttr(events);
668 return success();
669}
670
671static void printEventList(OpAsmPrinter &p, AlwaysOp op, ArrayAttr portsAttr,
672 OperandRange operands) {
673 for (size_t i = 0, e = op.getNumConditions(); i != e; ++i) {
674 if (i != 0)
675 p << ", ";
676 auto cond = op.getCondition(i);
677 p << stringifyEventControl(cond.event);
678 p << ' ';
679 p.printOperand(cond.value);
680 }
681}
682
683//===----------------------------------------------------------------------===//
684// AlwaysFFOp
685//===----------------------------------------------------------------------===//
686
687void AlwaysFFOp::build(OpBuilder &builder, OperationState &result,
688 EventControl clockEdge, Value clock,
689 std::function<void()> bodyCtor) {
690 OpBuilder::InsertionGuard guard(builder);
691
692 result.addAttribute(
693 "clockEdge", builder.getI32IntegerAttr(static_cast<int32_t>(clockEdge)));
694 result.addOperands(clock);
695 result.addAttribute(
696 "resetStyle",
697 builder.getI32IntegerAttr(static_cast<int32_t>(ResetType::NoReset)));
698
699 // Set up the body. Moves Insert Point
700 builder.createBlock(result.addRegion());
701
702 if (bodyCtor)
703 bodyCtor();
704
705 // Set up the reset region.
706 result.addRegion();
707}
708
709void AlwaysFFOp::build(OpBuilder &builder, OperationState &result,
710 EventControl clockEdge, Value clock,
711 ResetType resetStyle, EventControl resetEdge,
712 Value reset, std::function<void()> bodyCtor,
713 std::function<void()> resetCtor) {
714 OpBuilder::InsertionGuard guard(builder);
715
716 result.addAttribute(
717 "clockEdge", builder.getI32IntegerAttr(static_cast<int32_t>(clockEdge)));
718 result.addOperands(clock);
719 result.addAttribute("resetStyle", builder.getI32IntegerAttr(
720 static_cast<int32_t>(resetStyle)));
721 result.addAttribute(
722 "resetEdge", builder.getI32IntegerAttr(static_cast<int32_t>(resetEdge)));
723 result.addOperands(reset);
724
725 // Set up the body. Moves Insert Point.
726 builder.createBlock(result.addRegion());
727
728 if (bodyCtor)
729 bodyCtor();
730
731 // Set up the reset. Moves Insert Point.
732 builder.createBlock(result.addRegion());
733
734 if (resetCtor)
735 resetCtor();
736}
737
738//===----------------------------------------------------------------------===//
739// AlwaysCombOp
740//===----------------------------------------------------------------------===//
741
742void AlwaysCombOp::build(OpBuilder &builder, OperationState &result,
743 std::function<void()> bodyCtor) {
744 OpBuilder::InsertionGuard guard(builder);
745
746 builder.createBlock(result.addRegion());
747
748 if (bodyCtor)
749 bodyCtor();
750}
751
752//===----------------------------------------------------------------------===//
753// InitialOp
754//===----------------------------------------------------------------------===//
755
756void InitialOp::build(OpBuilder &builder, OperationState &result,
757 std::function<void()> bodyCtor) {
758 OpBuilder::InsertionGuard guard(builder);
759
760 builder.createBlock(result.addRegion());
761
762 // Fill in the body of the #ifdef.
763 if (bodyCtor)
764 bodyCtor();
765}
766
767//===----------------------------------------------------------------------===//
768// CaseOp
769//===----------------------------------------------------------------------===//
770
771/// Return the letter for the specified pattern bit, e.g. "0", "1", "x" or "z".
773 switch (bit) {
775 return '0';
777 return '1';
779 return 'x';
781 return 'z';
782 }
783 llvm_unreachable("invalid casez PatternBit");
784}
785
786/// Return the specified bit, bit 0 is the least significant bit.
787auto CaseBitPattern::getBit(size_t bitNumber) const -> CasePatternBit {
788 return CasePatternBit(unsigned(intAttr.getValue()[bitNumber * 2]) +
789 2 * unsigned(intAttr.getValue()[bitNumber * 2 + 1]));
790}
791
793 for (size_t i = 0, e = getWidth(); i != e; ++i)
794 if (getBit(i) == CasePatternBit::AnyX)
795 return true;
796 return false;
797}
798
800 for (size_t i = 0, e = getWidth(); i != e; ++i)
801 if (getBit(i) == CasePatternBit::AnyZ)
802 return true;
803 return false;
804}
805static SmallVector<CasePatternBit> getPatternBitsForValue(const APInt &value) {
806 SmallVector<CasePatternBit> result;
807 result.reserve(value.getBitWidth());
808 for (size_t i = 0, e = value.getBitWidth(); i != e; ++i)
809 result.push_back(CasePatternBit(value[i]));
810
811 return result;
812}
813
814// Get a CaseBitPattern from a specified list of PatternBits. Bits are
815// specified in most least significant order - element zero is the least
816// significant bit.
817CaseBitPattern::CaseBitPattern(const APInt &value, MLIRContext *context)
818 : CaseBitPattern(getPatternBitsForValue(value), context) {}
819
820// Get a CaseBitPattern from a specified list of PatternBits. Bits are
821// specified in most least significant order - element zero is the least
822// significant bit.
823CaseBitPattern::CaseBitPattern(ArrayRef<CasePatternBit> bits,
824 MLIRContext *context)
825 : CasePattern(CPK_bit) {
826 APInt pattern(bits.size() * 2, 0);
827 for (auto elt : llvm::reverse(bits)) {
828 pattern <<= 2;
829 pattern |= unsigned(elt);
830 }
831 auto patternType = IntegerType::get(context, bits.size() * 2);
832 intAttr = IntegerAttr::get(patternType, pattern);
833}
834
835auto CaseOp::getCases() -> SmallVector<CaseInfo, 4> {
836 SmallVector<CaseInfo, 4> result;
837 assert(getCasePatterns().size() == getNumRegions() &&
838 "case pattern / region count mismatch");
839 size_t nextRegion = 0;
840 for (auto elt : getCasePatterns()) {
841 llvm::TypeSwitch<Attribute>(elt)
842 .Case<hw::EnumFieldAttr>([&](auto enumAttr) {
843 result.push_back({std::make_unique<CaseEnumPattern>(enumAttr),
844 &getRegion(nextRegion++).front()});
845 })
846 .Case<CaseExprPatternAttr>([&](auto exprAttr) {
847 result.push_back({std::make_unique<CaseExprPattern>(getContext()),
848 &getRegion(nextRegion++).front()});
849 })
850 .Case<IntegerAttr>([&](auto intAttr) {
851 result.push_back({std::make_unique<CaseBitPattern>(intAttr),
852 &getRegion(nextRegion++).front()});
853 })
854 .Case<CaseDefaultPattern::AttrType>([&](auto) {
855 result.push_back({std::make_unique<CaseDefaultPattern>(getContext()),
856 &getRegion(nextRegion++).front()});
857 })
858 .Default([](auto) {
859 assert(false && "invalid case pattern attribute type");
860 });
861 }
862
863 return result;
864}
865
867 return cast<hw::EnumFieldAttr>(enumAttr).getField();
868}
869
870/// Parse case op.
871/// case op ::= `sv.case` case-style? validation-qualifier? cond `:` type
872/// attr-dict case-pattern^*
873/// case-style ::= `case` | `casex` | `casez`
874/// validation-qualifier (see SV Spec 12.5.3) ::= `unique` | `unique0`
875/// | `priority`
876/// case-pattern ::= `case` bit-pattern `:` region
877ParseResult CaseOp::parse(OpAsmParser &parser, OperationState &result) {
878 auto &builder = parser.getBuilder();
879
880 OpAsmParser::UnresolvedOperand condOperand;
881 Type condType;
882
883 auto loc = parser.getCurrentLocation();
884
885 StringRef keyword;
886 if (!parser.parseOptionalKeyword(&keyword, {"case", "casex", "casez"})) {
887 auto kind = symbolizeCaseStmtType(keyword);
888 auto caseEnum = static_cast<int32_t>(kind.value());
889 result.addAttribute("caseStyle", builder.getI32IntegerAttr(caseEnum));
890 }
891
892 // Parse validation qualifier.
893 if (!parser.parseOptionalKeyword(
894 &keyword, {"plain", "priority", "unique", "unique0"})) {
895 auto kind = symbolizeValidationQualifierTypeEnum(keyword);
896 result.addAttribute("validationQualifier",
897 ValidationQualifierTypeEnumAttr::get(
898 builder.getContext(), kind.value()));
899 }
900
901 if (parser.parseOperand(condOperand) || parser.parseColonType(condType) ||
902 parser.parseOptionalAttrDict(result.attributes) ||
903 parser.resolveOperand(condOperand, condType, result.operands))
904 return failure();
905
906 // Check the integer type.
907 Type canonicalCondType = hw::getCanonicalType(condType);
908 hw::EnumType enumType = dyn_cast<hw::EnumType>(canonicalCondType);
909 unsigned condWidth = 0;
910 if (!enumType) {
911 if (!result.operands[0].getType().isSignlessInteger())
912 return parser.emitError(loc, "condition must have signless integer type");
913 condWidth = condType.getIntOrFloatBitWidth();
914 }
915
916 // Parse all the cases.
917 SmallVector<Attribute> casePatterns;
918 SmallVector<CasePatternBit, 16> caseBits;
919 while (1) {
920 mlir::OptionalParseResult caseValueParseResult;
921 OpAsmParser::UnresolvedOperand caseValueOperand;
922 if (succeeded(parser.parseOptionalKeyword("default"))) {
923 casePatterns.push_back(CaseDefaultPattern(parser.getContext()).attr());
924 } else if (failed(parser.parseOptionalKeyword("case"))) {
925 // Not default or case, must be the end of the cases.
926 break;
927 } else if (enumType) {
928 // Enumerated case; parse the case value.
929 StringRef caseVal;
930
931 if (parser.parseKeyword(&caseVal))
932 return failure();
933
934 if (!enumType.contains(caseVal))
935 return parser.emitError(loc)
936 << "case value '" + caseVal + "' is not a member of enum type "
937 << enumType;
938 casePatterns.push_back(
939 hw::EnumFieldAttr::get(parser.getEncodedSourceLoc(loc),
940 builder.getStringAttr(caseVal), condType));
941 } else if ((caseValueParseResult =
942 parser.parseOptionalOperand(caseValueOperand))
943 .has_value()) {
944 if (failed(caseValueParseResult.value()) ||
945 parser.resolveOperand(caseValueOperand, condType, result.operands))
946 return failure();
947 casePatterns.push_back(CaseExprPattern(parser.getContext()).attr());
948 } else {
949 // Parse the pattern. It always starts with b, so it is an MLIR
950 // keyword.
951 StringRef caseVal;
952 loc = parser.getCurrentLocation();
953 if (parser.parseKeyword(&caseVal))
954 return failure();
955
956 if (caseVal.front() != 'b')
957 return parser.emitError(loc, "expected case value starting with 'b'");
958 caseVal = caseVal.drop_front();
959
960 // Parse and decode each bit, we reverse the list later for MSB->LSB.
961 for (; !caseVal.empty(); caseVal = caseVal.drop_front()) {
962 CasePatternBit bit;
963 switch (caseVal.front()) {
964 case '0':
966 break;
967 case '1':
969 break;
970 case 'x':
972 break;
973 case 'z':
975 break;
976 default:
977 return parser.emitError(loc, "unexpected case bit '")
978 << caseVal.front() << "'";
979 }
980 caseBits.push_back(bit);
981 }
982
983 if (caseVal.size() > condWidth)
984 return parser.emitError(loc, "too many bits specified in pattern");
985 std::reverse(caseBits.begin(), caseBits.end());
986
987 // High zeros may be missing.
988 if (caseBits.size() < condWidth)
989 caseBits.append(condWidth - caseBits.size(), CasePatternBit::Zero);
990
991 auto resultPattern = CaseBitPattern(caseBits, builder.getContext());
992 casePatterns.push_back(resultPattern.attr());
993 caseBits.clear();
994 }
995
996 // Parse the case body.
997 auto caseRegion = std::make_unique<Region>();
998 if (parser.parseColon() || parser.parseRegion(*caseRegion))
999 return failure();
1000 result.addRegion(std::move(caseRegion));
1001 }
1002
1003 result.addAttribute("casePatterns", builder.getArrayAttr(casePatterns));
1004 return success();
1005}
1006
1007void CaseOp::print(OpAsmPrinter &p) {
1008 p << ' ';
1009 if (getCaseStyle() == CaseStmtType::CaseXStmt)
1010 p << "casex ";
1011 else if (getCaseStyle() == CaseStmtType::CaseZStmt)
1012 p << "casez ";
1013
1014 if (getValidationQualifier() !=
1015 ValidationQualifierTypeEnum::ValidationQualifierPlain)
1016 p << stringifyValidationQualifierTypeEnum(getValidationQualifier()) << ' ';
1017
1018 p << getCond() << " : " << getCond().getType();
1019 p.printOptionalAttrDict(
1020 (*this)->getAttrs(),
1021 /*elidedAttrs=*/{"casePatterns", "caseStyle", "validationQualifier"});
1022
1023 size_t caseValueIndex = 0;
1024 for (auto &caseInfo : getCases()) {
1025 p.printNewline();
1026 auto &pattern = caseInfo.pattern;
1027
1028 llvm::TypeSwitch<CasePattern *>(pattern.get())
1029 .Case<CaseBitPattern>([&](auto bitPattern) {
1030 p << "case b";
1031 for (size_t bit = 0, e = bitPattern->getWidth(); bit != e; ++bit)
1032 p << getLetter(bitPattern->getBit(e - bit - 1));
1033 })
1034 .Case<CaseEnumPattern>([&](auto enumPattern) {
1035 p << "case " << enumPattern->getFieldValue();
1036 })
1037 .Case<CaseExprPattern>([&](auto) {
1038 p << "case ";
1039 p.printOperand(getCaseValues()[caseValueIndex++]);
1040 })
1041 .Case<CaseDefaultPattern>([&](auto) { p << "default"; })
1042 .Default([&](auto) { assert(false && "unhandled case pattern"); });
1043
1044 p << ": ";
1045 p.printRegion(*caseInfo.block->getParent(), /*printEntryBlockArgs=*/false,
1046 /*printBlockTerminators=*/true);
1047 }
1048}
1049
1050LogicalResult CaseOp::verify() {
1051 if (!(hw::isHWIntegerType(getCond().getType()) ||
1052 hw::isHWEnumType(getCond().getType())))
1053 return emitError("condition must have either integer or enum type");
1054
1055 // Ensure that the number of regions and number of case values match.
1056 if (getCasePatterns().size() != getNumRegions())
1057 return emitOpError("case pattern / region count mismatch");
1058 return success();
1059}
1060
1061/// This ctor allows you to build a CaseZ with some number of cases, getting
1062/// a callback for each case.
1063void CaseOp::build(
1064 OpBuilder &builder, OperationState &result, CaseStmtType caseStyle,
1065 ValidationQualifierTypeEnum validationQualifier, Value cond,
1066 size_t numCases,
1067 std::function<std::unique_ptr<CasePattern>(size_t)> caseCtor) {
1068 result.addOperands(cond);
1069 result.addAttribute("caseStyle",
1070 CaseStmtTypeAttr::get(builder.getContext(), caseStyle));
1071 result.addAttribute("validationQualifier",
1072 ValidationQualifierTypeEnumAttr::get(
1073 builder.getContext(), validationQualifier));
1074 SmallVector<Attribute> casePatterns;
1075
1076 OpBuilder::InsertionGuard guard(builder);
1077
1078 // Fill in the cases with the callback.
1079 for (size_t i = 0, e = numCases; i != e; ++i) {
1080 builder.createBlock(result.addRegion());
1081 casePatterns.push_back(caseCtor(i)->attr());
1082 }
1083
1084 result.addAttribute("casePatterns", builder.getArrayAttr(casePatterns));
1085}
1086
1087// Strength reduce case styles based on the bit patterns.
1088LogicalResult CaseOp::canonicalize(CaseOp op, PatternRewriter &rewriter) {
1089 if (op.getCaseStyle() == CaseStmtType::CaseStmt)
1090 return failure();
1091 if (isa<hw::EnumType>(op.getCond().getType()))
1092 return failure();
1093
1094 auto caseInfo = op.getCases();
1095 bool noXZ = llvm::all_of(caseInfo, [](const CaseInfo &ci) {
1096 return !ci.pattern.get()->hasX() && !ci.pattern.get()->hasZ();
1097 });
1098 bool noX = llvm::all_of(caseInfo, [](const CaseInfo &ci) {
1099 if (isa<CaseDefaultPattern>(ci.pattern))
1100 return true;
1101 return !ci.pattern.get()->hasX();
1102 });
1103 bool noZ = llvm::all_of(caseInfo, [](const CaseInfo &ci) {
1104 if (isa<CaseDefaultPattern>(ci.pattern))
1105 return true;
1106 return !ci.pattern.get()->hasZ();
1107 });
1108
1109 if (op.getCaseStyle() == CaseStmtType::CaseXStmt) {
1110 if (noXZ) {
1111 rewriter.modifyOpInPlace(op, [&]() {
1112 op.setCaseStyleAttr(
1113 CaseStmtTypeAttr::get(op.getContext(), CaseStmtType::CaseStmt));
1114 });
1115 return success();
1116 }
1117 if (noX) {
1118 rewriter.modifyOpInPlace(op, [&]() {
1119 op.setCaseStyleAttr(
1120 CaseStmtTypeAttr::get(op.getContext(), CaseStmtType::CaseZStmt));
1121 });
1122 return success();
1123 }
1124 }
1125
1126 if (op.getCaseStyle() == CaseStmtType::CaseZStmt && noZ) {
1127 rewriter.modifyOpInPlace(op, [&]() {
1128 op.setCaseStyleAttr(
1129 CaseStmtTypeAttr::get(op.getContext(), CaseStmtType::CaseStmt));
1130 });
1131 return success();
1132 }
1133
1134 return failure();
1135}
1136
1137//===----------------------------------------------------------------------===//
1138// OrderedOutputOp
1139//===----------------------------------------------------------------------===//
1140
1141void OrderedOutputOp::build(OpBuilder &builder, OperationState &result,
1142 std::function<void()> body) {
1143 OpBuilder::InsertionGuard guard(builder);
1144
1145 builder.createBlock(result.addRegion());
1146
1147 // Fill in the body of the ordered block.
1148 if (body)
1149 body();
1150}
1151
1152//===----------------------------------------------------------------------===//
1153// ForOp
1154//===----------------------------------------------------------------------===//
1155
1156void ForOp::build(OpBuilder &builder, OperationState &result,
1157 int64_t lowerBound, int64_t upperBound, int64_t step,
1158 IntegerType type, StringRef name,
1159 llvm::function_ref<void(BlockArgument)> body) {
1160 auto lb = hw::ConstantOp::create(builder, result.location, type, lowerBound);
1161 auto ub = hw::ConstantOp::create(builder, result.location, type, upperBound);
1162 auto st = hw::ConstantOp::create(builder, result.location, type, step);
1163 build(builder, result, lb, ub, st, name, body);
1164}
1165void ForOp::build(OpBuilder &builder, OperationState &result, Value lowerBound,
1166 Value upperBound, Value step, StringRef name,
1167 llvm::function_ref<void(BlockArgument)> body) {
1168 OpBuilder::InsertionGuard guard(builder);
1169 build(builder, result, lowerBound, upperBound, step, name);
1170 auto *region = result.regions.front().get();
1171 builder.createBlock(region);
1172 BlockArgument blockArgument =
1173 region->addArgument(lowerBound.getType(), result.location);
1174
1175 if (body)
1176 body(blockArgument);
1177}
1178
1179void ForOp::getAsmBlockArgumentNames(mlir::Region &region,
1180 mlir::OpAsmSetValueNameFn setNameFn) {
1181 auto *block = &region.front();
1182 setNameFn(block->getArgument(0), getInductionVarNameAttr());
1183}
1184
1185ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
1186 auto &builder = parser.getBuilder();
1187 Type type;
1188
1189 OpAsmParser::Argument inductionVariable;
1190 OpAsmParser::UnresolvedOperand lb, ub, step;
1191 // Parse the optional initial iteration arguments.
1192 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1193
1194 // Parse the induction variable followed by '='.
1195 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
1196 // Parse loop bounds.
1197 parser.parseOperand(lb) || parser.parseKeyword("to") ||
1198 parser.parseOperand(ub) || parser.parseKeyword("step") ||
1199 parser.parseOperand(step) || parser.parseColon() ||
1200 parser.parseType(type))
1201 return failure();
1202
1203 regionArgs.push_back(inductionVariable);
1204
1205 // Resolve input operands.
1206 regionArgs.front().type = type;
1207 if (parser.resolveOperand(lb, type, result.operands) ||
1208 parser.resolveOperand(ub, type, result.operands) ||
1209 parser.resolveOperand(step, type, result.operands))
1210 return failure();
1211
1212 // Parse the body region.
1213 Region *body = result.addRegion();
1214 if (parser.parseRegion(*body, regionArgs))
1215 return failure();
1216
1217 // Parse the optional attribute list.
1218 if (parser.parseOptionalAttrDict(result.attributes))
1219 return failure();
1220
1221 if (!inductionVariable.ssaName.name.empty()) {
1222 if (!isdigit(inductionVariable.ssaName.name[1]))
1223 // Retrive from its SSA name.
1224 result.attributes.append(
1225 {builder.getStringAttr("inductionVarName"),
1226 builder.getStringAttr(inductionVariable.ssaName.name.drop_front())});
1227 }
1228
1229 return success();
1230}
1231
1232void ForOp::print(OpAsmPrinter &p) {
1233 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
1234 << getUpperBound() << " step " << getStep();
1235 p << " : " << getInductionVar().getType() << ' ';
1236 p.printRegion(getRegion(),
1237 /*printEntryBlockArgs=*/false,
1238 /*printBlockTerminators=*/false);
1239 p.printOptionalAttrDict((*this)->getAttrs(), {"inductionVarName"});
1240}
1241
1242LogicalResult ForOp::canonicalize(ForOp op, PatternRewriter &rewriter) {
1243 APInt lb, ub, step;
1244 if (matchPattern(op.getLowerBound(), mlir::m_ConstantInt(&lb)) &&
1245 matchPattern(op.getUpperBound(), mlir::m_ConstantInt(&ub)) &&
1246 matchPattern(op.getStep(), mlir::m_ConstantInt(&step)) &&
1247 lb + step == ub) {
1248 // Unroll the loop if it's executed only once.
1249 rewriter.replaceAllUsesWith(op.getInductionVar(), op.getLowerBound());
1250 replaceOpWithRegion(rewriter, op, op.getBodyRegion());
1251 rewriter.eraseOp(op);
1252 return success();
1253 }
1254 return failure();
1255}
1256
1257//===----------------------------------------------------------------------===//
1258// Assignment statements
1259//===----------------------------------------------------------------------===//
1260
1261LogicalResult BPAssignOp::verify() {
1262 if (isa<sv::WireOp>(getDest().getDefiningOp()))
1263 return emitOpError(
1264 "Verilog disallows procedural assignment to a net type (did you intend "
1265 "to use a variable type, e.g., sv.reg?)");
1266 return success();
1267}
1268
1269LogicalResult PAssignOp::verify() {
1270 if (isa<sv::WireOp>(getDest().getDefiningOp()))
1271 return emitOpError(
1272 "Verilog disallows procedural assignment to a net type (did you intend "
1273 "to use a variable type, e.g., sv.reg?)");
1274 return success();
1275}
1276
1277namespace {
1278// This represents a slice of an array.
1279struct ArraySlice {
1280 Value array;
1281 Value start;
1282 size_t size; // Represent a range array[start, start + size).
1283
1284 // Get a struct from the value. Return std::nullopt if the value doesn't
1285 // represent an array slice.
1286 static std::optional<ArraySlice> getArraySlice(Value v) {
1287 auto *op = v.getDefiningOp();
1288 if (!op)
1289 return std::nullopt;
1290 return TypeSwitch<Operation *, std::optional<ArraySlice>>(op)
1291 .Case<hw::ArrayGetOp, ArrayIndexInOutOp>(
1292 [](auto arrayIndex) -> std::optional<ArraySlice> {
1293 hw::ConstantOp constant =
1294 arrayIndex.getIndex()
1295 .template getDefiningOp<hw::ConstantOp>();
1296 if (!constant)
1297 return std::nullopt;
1298 return ArraySlice{/*array=*/arrayIndex.getInput(),
1299 /*start=*/constant,
1300 /*end=*/1};
1301 })
1302 .Case<hw::ArraySliceOp>([](hw::ArraySliceOp slice)
1303 -> std::optional<ArraySlice> {
1304 auto constant = slice.getLowIndex().getDefiningOp<hw::ConstantOp>();
1305 if (!constant)
1306 return std::nullopt;
1307 return ArraySlice{
1308 /*array=*/slice.getInput(), /*start=*/constant,
1309 /*end=*/
1310 hw::type_cast<hw::ArrayType>(slice.getType()).getNumElements()};
1311 })
1312 .Case<sv::IndexedPartSelectInOutOp>(
1313 [](sv::IndexedPartSelectInOutOp index)
1314 -> std::optional<ArraySlice> {
1315 auto constant = index.getBase().getDefiningOp<hw::ConstantOp>();
1316 if (!constant || index.getDecrement())
1317 return std::nullopt;
1318 return ArraySlice{/*array=*/index.getInput(),
1319 /*start=*/constant,
1320 /*end=*/index.getWidth()};
1321 })
1322 .Default([](auto) { return std::nullopt; });
1323 }
1324
1325 // Create a pair of ArraySlice from source and destination of assignments.
1326 static std::optional<std::pair<ArraySlice, ArraySlice>>
1327 getAssignedRange(Operation *op) {
1328 assert((isa<PAssignOp, BPAssignOp>(op) && "assignments are expected"));
1329 auto srcRange = ArraySlice::getArraySlice(op->getOperand(1));
1330 if (!srcRange)
1331 return std::nullopt;
1332 auto destRange = ArraySlice::getArraySlice(op->getOperand(0));
1333 if (!destRange)
1334 return std::nullopt;
1335
1336 return std::make_pair(*destRange, *srcRange);
1337 }
1338};
1339} // namespace
1340
1341// This canonicalization merges neiboring assignments of array elements into
1342// array slice assignments. e.g.
1343// a[0] <= b[1]
1344// a[1] <= b[2]
1345// ->
1346// a[1:0] <= b[2:1]
1347template <typename AssignTy>
1348static LogicalResult mergeNeiboringAssignments(AssignTy op,
1349 PatternRewriter &rewriter) {
1350 // Get assigned ranges of each assignment.
1351 auto assignedRangeOpt = ArraySlice::getAssignedRange(op);
1352 if (!assignedRangeOpt)
1353 return failure();
1354
1355 auto [dest, src] = *assignedRangeOpt;
1356 AssignTy nextAssign = dyn_cast_or_null<AssignTy>(op->getNextNode());
1357 bool changed = false;
1358 SmallVector<Location> loc{op.getLoc()};
1359 // Check that a next operation is a same kind of the assignment.
1360 while (nextAssign) {
1361 auto nextAssignedRange = ArraySlice::getAssignedRange(nextAssign);
1362 if (!nextAssignedRange)
1363 break;
1364 auto [nextDest, nextSrc] = *nextAssignedRange;
1365 // Check that these assignments are mergaable.
1366 if (dest.array != nextDest.array || src.array != nextSrc.array ||
1367 !hw::isOffset(dest.start, nextDest.start, dest.size) ||
1368 !hw::isOffset(src.start, nextSrc.start, src.size))
1369 break;
1370
1371 dest.size += nextDest.size;
1372 src.size += nextSrc.size;
1373 changed = true;
1374 loc.push_back(nextAssign.getLoc());
1375 rewriter.eraseOp(nextAssign);
1376 nextAssign = dyn_cast_or_null<AssignTy>(op->getNextNode());
1377 }
1378
1379 if (!changed)
1380 return failure();
1381
1382 // From here, construct assignments of array slices.
1383 auto resultType = hw::ArrayType::get(
1384 hw::type_cast<hw::ArrayType>(src.array.getType()).getElementType(),
1385 src.size);
1386 auto newDest = sv::IndexedPartSelectInOutOp::create(
1387 rewriter, op.getLoc(), dest.array, dest.start, dest.size);
1388 auto newSrc = hw::ArraySliceOp::create(rewriter, op.getLoc(), resultType,
1389 src.array, src.start);
1390 auto newLoc = rewriter.getFusedLoc(loc);
1391 auto newOp = rewriter.replaceOpWithNewOp<AssignTy>(op, newDest, newSrc);
1392 newOp->setLoc(newLoc);
1393 return success();
1394}
1395
1396LogicalResult PAssignOp::canonicalize(PAssignOp op, PatternRewriter &rewriter) {
1397 return mergeNeiboringAssignments(op, rewriter);
1398}
1399
1400LogicalResult BPAssignOp::canonicalize(BPAssignOp op,
1401 PatternRewriter &rewriter) {
1402 return mergeNeiboringAssignments(op, rewriter);
1403}
1404
1405//===----------------------------------------------------------------------===//
1406// TypeDecl operations
1407//===----------------------------------------------------------------------===//
1408
1409void InterfaceOp::build(OpBuilder &builder, OperationState &result,
1410 StringRef sym_name, std::function<void()> body) {
1411 OpBuilder::InsertionGuard guard(builder);
1412
1413 result.addAttribute(::SymbolTable::getSymbolAttrName(),
1414 builder.getStringAttr(sym_name));
1415 builder.createBlock(result.addRegion());
1416 if (body)
1417 body();
1418}
1419
1420ModportType InterfaceOp::getModportType(StringRef modportName) {
1421 assert(lookupSymbol<InterfaceModportOp>(modportName) &&
1422 "Modport symbol not found.");
1423 auto *ctxt = getContext();
1424 return ModportType::get(
1425 getContext(),
1426 SymbolRefAttr::get(ctxt, getSymName(),
1427 {SymbolRefAttr::get(ctxt, modportName)}));
1428}
1429
1430Type InterfaceOp::getSignalType(StringRef signalName) {
1431 InterfaceSignalOp signal = lookupSymbol<InterfaceSignalOp>(signalName);
1432 assert(signal && "Interface signal symbol not found.");
1433 return signal.getType();
1434}
1435
1436static ParseResult parseModportStructs(OpAsmParser &parser,
1437 ArrayAttr &portsAttr) {
1438
1439 auto *context = parser.getBuilder().getContext();
1440
1441 SmallVector<Attribute, 8> ports;
1442 auto parseElement = [&]() -> ParseResult {
1443 auto direction = ModportDirectionAttr::parse(parser, {});
1444 if (!direction)
1445 return failure();
1446
1447 FlatSymbolRefAttr signal;
1448 if (parser.parseAttribute(signal))
1449 return failure();
1450
1451 ports.push_back(ModportStructAttr::get(
1452 context, cast<ModportDirectionAttr>(direction), signal));
1453 return success();
1454 };
1455 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
1456 parseElement))
1457 return failure();
1458
1459 portsAttr = ArrayAttr::get(context, ports);
1460 return success();
1461}
1462
1463static void printModportStructs(OpAsmPrinter &p, Operation *,
1464 ArrayAttr portsAttr) {
1465 p << "(";
1466 llvm::interleaveComma(portsAttr, p, [&](Attribute attr) {
1467 auto port = cast<ModportStructAttr>(attr);
1468 p << stringifyEnum(port.getDirection().getValue());
1469 p << ' ';
1470 p.printSymbolName(port.getSignal().getRootReference().getValue());
1471 });
1472 p << ')';
1473}
1474
1475void InterfaceSignalOp::build(mlir::OpBuilder &builder,
1476 ::mlir::OperationState &state, StringRef name,
1477 mlir::Type type) {
1478 build(builder, state, name, mlir::TypeAttr::get(type));
1479}
1480
1481void InterfaceModportOp::build(OpBuilder &builder, OperationState &state,
1482 StringRef name, ArrayRef<StringRef> inputs,
1483 ArrayRef<StringRef> outputs) {
1484 auto *ctxt = builder.getContext();
1485 SmallVector<Attribute, 8> directions;
1486 auto inputDir = ModportDirectionAttr::get(ctxt, ModportDirection::input);
1487 auto outputDir = ModportDirectionAttr::get(ctxt, ModportDirection::output);
1488 for (auto input : inputs)
1489 directions.push_back(ModportStructAttr::get(
1490 ctxt, inputDir, SymbolRefAttr::get(ctxt, input)));
1491 for (auto output : outputs)
1492 directions.push_back(ModportStructAttr::get(
1493 ctxt, outputDir, SymbolRefAttr::get(ctxt, output)));
1494 build(builder, state, name, ArrayAttr::get(ctxt, directions));
1495}
1496
1497std::optional<size_t> InterfaceInstanceOp::getTargetResultIndex() {
1498 // Inner symbols on instance operations target the op not any result.
1499 return std::nullopt;
1500}
1501
1502/// Suggest a name for each result value based on the saved result names
1503/// attribute.
1504void InterfaceInstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
1505 setNameFn(getResult(), getName());
1506}
1507
1508/// Ensure that the symbol being instantiated exists and is an InterfaceOp.
1509LogicalResult InterfaceInstanceOp::verify() {
1510 if (getName().empty())
1511 return emitOpError("requires non-empty name");
1512 return success();
1513}
1514
1515LogicalResult
1516InterfaceInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1517 auto *symtable = SymbolTable::getNearestSymbolTable(*this);
1518 if (!symtable)
1519 return emitError("sv.interface.instance must exist within a region "
1520 "which has a symbol table.");
1521 auto ifaceTy = getType();
1522 auto *referencedOp =
1523 symbolTable.lookupSymbolIn(symtable, ifaceTy.getInterface());
1524 if (!referencedOp)
1525 return emitError("Symbol not found: ") << ifaceTy.getInterface() << ".";
1526 if (!isa<InterfaceOp>(referencedOp))
1527 return emitError("Symbol ")
1528 << ifaceTy.getInterface() << " is not an InterfaceOp.";
1529 return success();
1530}
1531
1532/// Ensure that the symbol being instantiated exists and is an
1533/// InterfaceModportOp.
1534LogicalResult
1535GetModportOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1536 auto *symtable = SymbolTable::getNearestSymbolTable(*this);
1537 if (!symtable)
1538 return emitError("sv.interface.instance must exist within a region "
1539 "which has a symbol table.");
1540
1541 auto ifaceTy = getType();
1542 auto *referencedOp =
1543 symbolTable.lookupSymbolIn(symtable, ifaceTy.getModport());
1544 if (!referencedOp)
1545 return emitError("Symbol not found: ") << ifaceTy.getModport() << ".";
1546 if (!isa<InterfaceModportOp>(referencedOp))
1547 return emitError("Symbol ")
1548 << ifaceTy.getModport() << " is not an InterfaceModportOp.";
1549 return success();
1550}
1551
1552void GetModportOp::build(OpBuilder &builder, OperationState &state, Value value,
1553 StringRef field) {
1554 auto ifaceTy = dyn_cast<InterfaceType>(value.getType());
1555 assert(ifaceTy && "GetModportOp expects an InterfaceType.");
1556 auto fieldAttr = SymbolRefAttr::get(builder.getContext(), field);
1557 auto modportSym =
1558 SymbolRefAttr::get(ifaceTy.getInterface().getRootReference(), fieldAttr);
1559 build(builder, state, ModportType::get(builder.getContext(), modportSym),
1560 value, fieldAttr);
1561}
1562
1563/// Lookup the op for the modport declaration. This returns null on invalid
1564/// IR.
1565InterfaceModportOp
1566GetModportOp::getReferencedDecl(const hw::HWSymbolCache &cache) {
1567 return dyn_cast_or_null<InterfaceModportOp>(
1568 cache.getDefinition(getFieldAttr()));
1569}
1570
1571void ReadInterfaceSignalOp::build(OpBuilder &builder, OperationState &state,
1572 Value iface, StringRef signalName) {
1573 auto ifaceTy = dyn_cast<InterfaceType>(iface.getType());
1574 assert(ifaceTy && "ReadInterfaceSignalOp expects an InterfaceType.");
1575 auto fieldAttr = SymbolRefAttr::get(builder.getContext(), signalName);
1576 InterfaceOp ifaceDefOp = SymbolTable::lookupNearestSymbolFrom<InterfaceOp>(
1577 iface.getDefiningOp(), ifaceTy.getInterface());
1578 assert(ifaceDefOp &&
1579 "ReadInterfaceSignalOp could not resolve an InterfaceOp.");
1580 build(builder, state, ifaceDefOp.getSignalType(signalName), iface, fieldAttr);
1581}
1582
1583/// Lookup the op for the signal declaration. This returns null on invalid
1584/// IR.
1585InterfaceSignalOp
1586ReadInterfaceSignalOp::getReferencedDecl(const hw::HWSymbolCache &cache) {
1587 return dyn_cast_or_null<InterfaceSignalOp>(
1588 cache.getDefinition(getSignalNameAttr()));
1589}
1590
1591ParseResult parseIfaceTypeAndSignal(OpAsmParser &p, Type &ifaceTy,
1592 FlatSymbolRefAttr &signalName) {
1593 SymbolRefAttr fullSym;
1594 if (p.parseAttribute(fullSym) || fullSym.getNestedReferences().size() != 1)
1595 return failure();
1596
1597 auto *ctxt = p.getBuilder().getContext();
1598 ifaceTy = InterfaceType::get(
1599 ctxt, FlatSymbolRefAttr::get(fullSym.getRootReference()));
1600 signalName = FlatSymbolRefAttr::get(fullSym.getLeafReference());
1601 return success();
1602}
1603
1604void printIfaceTypeAndSignal(OpAsmPrinter &p, Operation *op, Type type,
1605 FlatSymbolRefAttr signalName) {
1606 InterfaceType ifaceTy = dyn_cast<InterfaceType>(type);
1607 assert(ifaceTy && "Expected an InterfaceType");
1608 auto sym = SymbolRefAttr::get(ifaceTy.getInterface().getRootReference(),
1609 {signalName});
1610 p << sym;
1611}
1612
1613LogicalResult verifySignalExists(Value ifaceVal, FlatSymbolRefAttr signalName) {
1614 auto ifaceTy = dyn_cast<InterfaceType>(ifaceVal.getType());
1615 if (!ifaceTy)
1616 return failure();
1617 InterfaceOp iface = SymbolTable::lookupNearestSymbolFrom<InterfaceOp>(
1618 ifaceVal.getDefiningOp(), ifaceTy.getInterface());
1619 if (!iface)
1620 return failure();
1621 InterfaceSignalOp signal = iface.lookupSymbol<InterfaceSignalOp>(signalName);
1622 if (!signal)
1623 return failure();
1624 return success();
1625}
1626
1627Operation *
1628InterfaceInstanceOp::getReferencedInterface(const hw::HWSymbolCache *cache) {
1629 FlatSymbolRefAttr interface = getInterfaceType().getInterface();
1630 if (cache)
1631 if (auto *result = cache->getDefinition(interface))
1632 return result;
1633
1634 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1635 if (!topLevelModuleOp)
1636 return nullptr;
1637
1638 return topLevelModuleOp.lookupSymbol(interface);
1639}
1640
1641LogicalResult AssignInterfaceSignalOp::verify() {
1642 return verifySignalExists(getIface(), getSignalNameAttr());
1643}
1644
1645LogicalResult ReadInterfaceSignalOp::verify() {
1646 return verifySignalExists(getIface(), getSignalNameAttr());
1647}
1648
1649//===----------------------------------------------------------------------===//
1650// WireOp
1651//===----------------------------------------------------------------------===//
1652
1653void WireOp::build(OpBuilder &builder, OperationState &odsState,
1654 Type elementType, StringAttr name,
1655 hw::InnerSymAttr innerSym) {
1656 if (!name)
1657 name = builder.getStringAttr("");
1658 if (innerSym)
1659 odsState.addAttribute(hw::InnerSymbolTable::getInnerSymbolAttrName(),
1660 innerSym);
1661
1662 odsState.addAttribute("name", name);
1663 odsState.addTypes(InOutType::get(elementType));
1664}
1665
1666/// Suggest a name for each result value based on the saved result names
1667/// attribute.
1668void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
1669 // If the wire has an optional 'name' attribute, use it.
1670 auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
1671 if (!nameAttr.getValue().empty())
1672 setNameFn(getResult(), nameAttr.getValue());
1673}
1674
1675std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
1676
1677// If this wire is only written to, delete the wire and all writers.
1678LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
1679 // Block if op has SV attributes.
1680 if (hasSVAttributes(wire))
1681 return failure();
1682
1683 // If the wire has a symbol, then we can't delete it.
1684 if (wire.getInnerSymAttr())
1685 return failure();
1686
1687 // Wires have inout type, so they'll have assigns and read_inout operations
1688 // that work on them. If anything unexpected is found then leave it alone.
1689 SmallVector<sv::ReadInOutOp> reads;
1690 sv::AssignOp write;
1691
1692 for (auto *user : wire->getUsers()) {
1693 if (auto read = dyn_cast<sv::ReadInOutOp>(user)) {
1694 reads.push_back(read);
1695 continue;
1696 }
1697
1698 // Otherwise must be an assign, and we must not have seen a write yet.
1699 auto assign = dyn_cast<sv::AssignOp>(user);
1700 // Either the wire has more than one write or another kind of Op (other than
1701 // AssignOp and ReadInOutOp), then can't optimize.
1702 if (!assign || write)
1703 return failure();
1704
1705 // If the assign op has SV attributes, we don't want to delete the
1706 // assignment.
1707 if (hasSVAttributes(assign))
1708 return failure();
1709
1710 write = assign;
1711 }
1712
1713 Value connected;
1714 if (!write) {
1715 // If no write and only reads, then replace with ZOp.
1716 // SV 6.6: "If no driver is connected to a net, its
1717 // value shall be high-impedance (z) unless the net is a trireg"
1718 connected = ConstantZOp::create(
1719 rewriter, wire.getLoc(),
1720 cast<InOutType>(wire.getResult().getType()).getElementType());
1721 } else if (isa<hw::HWModuleOp>(write->getParentOp()))
1722 connected = write.getSrc();
1723 else
1724 // If the write is happening at the module level then we don't have any
1725 // use-before-def checking to do, so we only handle that for now.
1726 return failure();
1727
1728 // If the wire has a name attribute, propagate the name to the expression.
1729 if (auto *connectedOp = connected.getDefiningOp())
1730 if (!wire.getName().empty())
1731 rewriter.modifyOpInPlace(connectedOp, [&] {
1732 connectedOp->setAttr("sv.namehint", wire.getNameAttr());
1733 });
1734
1735 // Ok, we can do this. Replace all the reads with the connected value.
1736 for (auto read : reads)
1737 rewriter.replaceOp(read, connected);
1738
1739 // And remove the write and wire itself.
1740 if (write)
1741 rewriter.eraseOp(write);
1742 rewriter.eraseOp(wire);
1743 return success();
1744}
1745
1746//===----------------------------------------------------------------------===//
1747// IndexedPartSelectInOutOp
1748//===----------------------------------------------------------------------===//
1749
1750// A helper function to infer a return type of IndexedPartSelectInOutOp.
1751static Type getElementTypeOfWidth(Type type, int32_t width) {
1752 auto elemTy = cast<hw::InOutType>(type).getElementType();
1753 if (isa<IntegerType>(elemTy))
1754 return hw::InOutType::get(IntegerType::get(type.getContext(), width));
1755 if (isa<hw::ArrayType>(elemTy))
1756 return hw::InOutType::get(hw::ArrayType::get(
1757 cast<hw::ArrayType>(elemTy).getElementType(), width));
1758 return {};
1759}
1760
1761LogicalResult IndexedPartSelectInOutOp::inferReturnTypes(
1762 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
1763 DictionaryAttr attrs, mlir::OpaqueProperties properties,
1764 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
1765 Adaptor adaptor(operands, attrs, properties, regions);
1766 auto width = adaptor.getWidthAttr();
1767 if (!width)
1768 return failure();
1769
1770 auto typ = getElementTypeOfWidth(operands[0].getType(),
1771 width.getValue().getZExtValue());
1772 if (!typ)
1773 return failure();
1774 results.push_back(typ);
1775 return success();
1776}
1777
1778LogicalResult IndexedPartSelectInOutOp::verify() {
1779 unsigned inputWidth = 0, resultWidth = 0;
1780 auto opWidth = getWidth();
1781 auto inputElemTy = cast<InOutType>(getInput().getType()).getElementType();
1782 auto resultElemTy = cast<InOutType>(getType()).getElementType();
1783 if (auto i = dyn_cast<IntegerType>(inputElemTy))
1784 inputWidth = i.getWidth();
1785 else if (auto i = hw::type_cast<hw::ArrayType>(inputElemTy))
1786 inputWidth = i.getNumElements();
1787 else
1788 return emitError("input element type must be Integer or Array");
1789
1790 if (auto resType = dyn_cast<IntegerType>(resultElemTy))
1791 resultWidth = resType.getWidth();
1792 else if (auto resType = hw::type_cast<hw::ArrayType>(resultElemTy))
1793 resultWidth = resType.getNumElements();
1794 else
1795 return emitError("result element type must be Integer or Array");
1796
1797 if (opWidth > inputWidth)
1798 return emitError("slice width should not be greater than input width");
1799 if (opWidth != resultWidth)
1800 return emitError("result width must be equal to slice width");
1801 return success();
1802}
1803
1804OpFoldResult IndexedPartSelectInOutOp::fold(FoldAdaptor) {
1805 if (getType() == getInput().getType())
1806 return getInput();
1807 return {};
1808}
1809
1810//===----------------------------------------------------------------------===//
1811// IndexedPartSelectOp
1812//===----------------------------------------------------------------------===//
1813
1814LogicalResult IndexedPartSelectOp::inferReturnTypes(
1815 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
1816 DictionaryAttr attrs, mlir::OpaqueProperties properties,
1817 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
1818 Adaptor adaptor(operands, attrs, properties, regions);
1819 auto width = adaptor.getWidthAttr();
1820 if (!width)
1821 return failure();
1822
1823 results.push_back(IntegerType::get(context, width.getInt()));
1824 return success();
1825}
1826
1827LogicalResult IndexedPartSelectOp::verify() {
1828 auto opWidth = getWidth();
1829
1830 unsigned resultWidth = cast<IntegerType>(getType()).getWidth();
1831 unsigned inputWidth = cast<IntegerType>(getInput().getType()).getWidth();
1832
1833 if (opWidth > inputWidth)
1834 return emitError("slice width should not be greater than input width");
1835 if (opWidth != resultWidth)
1836 return emitError("result width must be equal to slice width");
1837 return success();
1838}
1839
1840//===----------------------------------------------------------------------===//
1841// StructFieldInOutOp
1842//===----------------------------------------------------------------------===//
1843
1844LogicalResult StructFieldInOutOp::inferReturnTypes(
1845 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
1846 DictionaryAttr attrs, mlir::OpaqueProperties properties,
1847 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
1848 Adaptor adaptor(operands, attrs, properties, regions);
1849 auto field = adaptor.getFieldAttr();
1850 if (!field)
1851 return failure();
1852 auto structType =
1853 hw::type_cast<hw::StructType>(getInOutElementType(operands[0].getType()));
1854 auto resultType = structType.getFieldType(field);
1855 if (!resultType)
1856 return failure();
1857
1858 results.push_back(hw::InOutType::get(resultType));
1859 return success();
1860}
1861
1862//===----------------------------------------------------------------------===//
1863// Other ops.
1864//===----------------------------------------------------------------------===//
1865
1866LogicalResult AliasOp::verify() {
1867 // Must have at least two operands.
1868 if (getAliases().size() < 2)
1869 return emitOpError("alias must have at least two operands");
1870
1871 return success();
1872}
1873
1874//===----------------------------------------------------------------------===//
1875// BindOp
1876//===----------------------------------------------------------------------===//
1877
1878/// Instances must be at the top level of the hw.module (or within a `ifdef)
1879// and are typically at the end of it, so we scan backwards to find them.
1880template <class Op>
1881static Op findInstanceSymbolInBlock(StringAttr name, Block *body) {
1882 for (auto &op : llvm::reverse(body->getOperations())) {
1883 if (auto instance = dyn_cast<Op>(op)) {
1884 if (auto innerSym = instance.getInnerSym())
1885 if (innerSym->getSymName() == name)
1886 return instance;
1887 }
1888
1889 if (auto ifdef = dyn_cast<IfDefOp>(op)) {
1890 if (auto result =
1891 findInstanceSymbolInBlock<Op>(name, ifdef.getThenBlock()))
1892 return result;
1893 if (ifdef.hasElse())
1894 if (auto result =
1895 findInstanceSymbolInBlock<Op>(name, ifdef.getElseBlock()))
1896 return result;
1897 }
1898 }
1899 return {};
1900}
1901
1902hw::InstanceOp BindOp::getReferencedInstance(const hw::HWSymbolCache *cache) {
1903 // If we have a cache, directly look up the referenced instance.
1904 if (cache) {
1905 auto result = cache->getInnerDefinition(getInstance());
1906 return cast<hw::InstanceOp>(result.getOp());
1907 }
1908
1909 // Otherwise, resolve the instance by looking up the module ...
1910 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1911 if (!topLevelModuleOp)
1912 return {};
1913
1914 auto hwModule = dyn_cast_or_null<hw::HWModuleOp>(
1915 topLevelModuleOp.lookupSymbol(getInstance().getModule()));
1916 if (!hwModule)
1917 return {};
1918
1919 // ... then look up the instance within it.
1920 return findInstanceSymbolInBlock<hw::InstanceOp>(getInstance().getName(),
1921 hwModule.getBodyBlock());
1922}
1923
1924/// Ensure that the symbol being instantiated exists and is an InterfaceOp.
1925LogicalResult BindOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1926 auto module = (*this)->getParentOfType<mlir::ModuleOp>();
1927 auto hwModule = dyn_cast_or_null<hw::HWModuleOp>(
1928 symbolTable.lookupSymbolIn(module, getInstance().getModule()));
1929 if (!hwModule)
1930 return emitError("Referenced module doesn't exist ")
1931 << getInstance().getModule() << "::" << getInstance().getName();
1932
1933 auto inst = findInstanceSymbolInBlock<hw::InstanceOp>(
1934 getInstance().getName(), hwModule.getBodyBlock());
1935 if (!inst)
1936 return emitError("Referenced instance doesn't exist ")
1937 << getInstance().getModule() << "::" << getInstance().getName();
1938 if (!inst.getDoNotPrint())
1939 return emitError("Referenced instance isn't marked as doNotPrint");
1940 return success();
1941}
1942
1943void BindOp::build(OpBuilder &builder, OperationState &odsState, StringAttr mod,
1944 StringAttr name) {
1945 auto ref = hw::InnerRefAttr::get(mod, name);
1946 odsState.addAttribute("instance", ref);
1947}
1948
1949//===----------------------------------------------------------------------===//
1950// SVVerbatimSourceOp
1951//===----------------------------------------------------------------------===//
1952
1953void SVVerbatimSourceOp::print(OpAsmPrinter &p) {
1954 p << ' ';
1955
1956 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1957 if (auto visibility = (*this)->getAttrOfType<StringAttr>(visibilityAttrName))
1958 p << visibility.getValue() << ' ';
1959
1960 p.printSymbolName(getSymName());
1961
1962 // Print parameters
1963 circt::printOptionalParameterList(p, *this, getParameters());
1964
1965 // Print attributes using the helper function
1966 SmallVector<StringRef> omittedAttrs = {SymbolTable::getSymbolAttrName(),
1967 "parameters", visibilityAttrName};
1968
1969 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), omittedAttrs);
1970}
1971
1972ParseResult SVVerbatimSourceOp::parse(OpAsmParser &parser,
1973 OperationState &result) {
1974
1975 // parse optional visibility
1976 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1977 StringRef visibility;
1978 if (succeeded(parser.parseOptionalKeyword(&visibility,
1979 {"public", "private", "nested"}))) {
1980 result.addAttribute(visibilityAttrName,
1981 parser.getBuilder().getStringAttr(visibility));
1982 }
1983
1984 // Parse the symbol name
1985 StringAttr nameAttr;
1986 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1987 result.attributes))
1988 return failure();
1989
1990 // Parse optional parameters
1991 ArrayAttr parameters;
1992 if (circt::parseOptionalParameterList(parser, parameters))
1993 return failure();
1994 result.addAttribute("parameters", parameters);
1995
1996 // Parse attributes using the helper function
1997 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1998 return failure();
1999
2000 return success();
2001}
2002
2003LogicalResult SVVerbatimSourceOp::verify() {
2004 // must have verbatim content
2005 if (getContent().empty())
2006 return emitOpError("missing or empty content attribute");
2007
2008 return success();
2009}
2010
2011LogicalResult
2012SVVerbatimSourceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2013 // Verify that all symbols in additional_files are emit.file operations
2014 if (auto additionalFiles = getAdditionalFiles()) {
2015 for (auto fileRef : *additionalFiles) {
2016 auto flatRef = dyn_cast<FlatSymbolRefAttr>(fileRef);
2017 if (!flatRef)
2018 return emitOpError(
2019 "additional_files must contain flat symbol references");
2020
2021 auto *referencedOp =
2022 symbolTable.lookupNearestSymbolFrom(getOperation(), flatRef);
2023 if (!referencedOp)
2024 return emitOpError("references nonexistent file ")
2025 << flatRef.getValue();
2026
2027 // Check that the referenced operation is an emit.file
2028 if (referencedOp->getName().getStringRef() != "emit.file")
2029 return emitOpError("references ")
2030 << flatRef.getValue() << ", which is not an emit.file";
2031 }
2032 }
2033
2034 return success();
2035}
2036
2037//===----------------------------------------------------------------------===//
2038// SVVerbatimModuleOp
2039//===----------------------------------------------------------------------===//
2040
2041SmallVector<hw::PortInfo> SVVerbatimModuleOp::getPortList() {
2042 SmallVector<hw::PortInfo> ports;
2043 auto moduleType = getModuleType();
2044 auto portLocs = getPortLocs();
2045 auto portAttrs = getPerPortAttrs();
2046
2047 for (size_t i = 0, e = moduleType.getNumPorts(); i < e; ++i) {
2048 auto port = moduleType.getPorts()[i];
2049 LocationAttr loc = portLocs && i < portLocs->size()
2050 ? cast<LocationAttr>((*portLocs)[i])
2051 : UnknownLoc::get(getContext());
2052 DictionaryAttr attrs = portAttrs && i < portAttrs->size()
2053 ? cast<DictionaryAttr>((*portAttrs)[i])
2054 : DictionaryAttr::get(getContext());
2061 ports.push_back({{port.name, port.type, dir}, i, attrs, loc});
2062 }
2063 return ports;
2064}
2065
2066hw::PortInfo SVVerbatimModuleOp::getPort(size_t idx) {
2067 return getPortList()[idx];
2068}
2069
2070size_t SVVerbatimModuleOp::getPortIdForInputId(size_t idx) {
2071 return getModuleType().getPortIdForInputId(idx);
2072}
2073
2074size_t SVVerbatimModuleOp::getPortIdForOutputId(size_t idx) {
2075 return getModuleType().getPortIdForOutputId(idx);
2076}
2077
2078size_t SVVerbatimModuleOp::getNumPorts() {
2079 return getModuleType().getNumPorts();
2080}
2081
2082size_t SVVerbatimModuleOp::getNumInputPorts() {
2083 return getModuleType().getNumInputs();
2084}
2085
2086size_t SVVerbatimModuleOp::getNumOutputPorts() {
2087 return getModuleType().getNumOutputs();
2088}
2089
2090hw::ModuleType SVVerbatimModuleOp::getHWModuleType() { return getModuleType(); }
2091
2092ArrayRef<Attribute> SVVerbatimModuleOp::getAllPortAttrs() {
2093 if (auto attrs = getPerPortAttrs())
2094 return attrs->getValue();
2095 return {};
2096}
2097
2098void SVVerbatimModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
2099 setPerPortAttrsAttr(ArrayAttr::get(getContext(), attrs));
2100}
2101
2102void SVVerbatimModuleOp::removeAllPortAttrs() { removePerPortAttrsAttr(); }
2103
2104SmallVector<Location> SVVerbatimModuleOp::getAllPortLocs() {
2105 if (auto locs = getPortLocs()) {
2106 SmallVector<Location> result;
2107 result.reserve(locs->size());
2108 for (auto loc : *locs)
2109 result.push_back(cast<Location>(loc));
2110 return result;
2111 }
2112 return SmallVector<Location>(getNumPorts(), UnknownLoc::get(getContext()));
2113}
2114
2115void SVVerbatimModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
2116 setPortLocsAttr(ArrayAttr::get(getContext(), locs));
2117}
2118
2119void SVVerbatimModuleOp::setHWModuleType(hw::ModuleType type) {
2120 setModuleTypeAttr(TypeAttr::get(type));
2121}
2122
2123void SVVerbatimModuleOp::setAllPortNames(ArrayRef<Attribute> names) {
2124 // Port names are part of the module type, so we need to reconstruct it
2125 auto currentType = getModuleType();
2126 SmallVector<hw::ModulePort> ports;
2127 for (size_t i = 0, e = currentType.getNumPorts(); i < e; ++i) {
2128 auto port = currentType.getPorts()[i];
2129 if (i < names.size())
2130 port.name = cast<StringAttr>(names[i]);
2131 ports.push_back(port);
2132 }
2133 setHWModuleType(hw::ModuleType::get(getContext(), ports));
2134}
2135
2136void SVVerbatimModuleOp::print(OpAsmPrinter &p) {
2137 p << ' ';
2138
2139 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
2140 if (auto visibility = (*this)->getAttrOfType<StringAttr>(visibilityAttrName))
2141 p << visibility.getValue() << ' ';
2142
2143 p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
2144
2145 printOptionalParameterList(p, *this, getParameters());
2146
2147 Region emptyRegion;
2149 p, emptyRegion, getModuleType(), getAllPortAttrs(), getAllPortLocs());
2150
2151 SmallVector<StringRef> omittedAttrs = {
2152 SymbolTable::getSymbolAttrName(), SymbolTable::getVisibilityAttrName(),
2153 getModuleTypeAttrName().getValue(), getPerPortAttrsAttrName().getValue(),
2154 getPortLocsAttrName().getValue(), getParametersAttrName().getValue()};
2155
2156 mlir::function_interface_impl::printFunctionAttributes(p, *this,
2157 omittedAttrs);
2158}
2159
2160ParseResult SVVerbatimModuleOp::parse(OpAsmParser &parser,
2161 OperationState &result) {
2162 using namespace mlir::function_interface_impl;
2163 auto builder = parser.getBuilder();
2164
2165 // Parse the visibility attribute.
2166 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
2167
2168 // Parse the name as a symbol.
2169 StringAttr nameAttr;
2170 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2171 result.attributes))
2172 return failure();
2173
2174 // Parse the parameters.
2175 ArrayAttr parameters;
2176 if (parseOptionalParameterList(parser, parameters))
2177 return failure();
2178
2179 SmallVector<hw::module_like_impl::PortParse> ports;
2180 TypeAttr modType;
2181 if (failed(
2182 hw::module_like_impl::parseModuleSignature(parser, ports, modType)))
2183 return failure();
2184
2185 result.addAttribute(getModuleTypeAttrName(result.name), modType);
2186 result.addAttribute("parameters", parameters);
2187
2188 // Convert the specified array of dictionary attrs (which may have null
2189 // entries) to an ArrayAttr of dictionaries.
2190 auto unknownLoc = builder.getUnknownLoc();
2191 SmallVector<Attribute> attrs, locs;
2192
2193 for (auto &port : ports) {
2194 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
2195 auto loc = port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc;
2196 locs.push_back(loc);
2197 }
2198
2199 if (!attrs.empty())
2200 result.addAttribute("per_port_attrs", builder.getArrayAttr(attrs));
2201 if (!locs.empty())
2202 result.addAttribute("port_locs", builder.getArrayAttr(locs));
2203
2204 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
2205 return failure();
2206
2207 // Verify required attributes exist
2208 if (!result.attributes.get("source"))
2209 return parser.emitError(parser.getCurrentLocation(),
2210 "sv.verbatim.module requires 'source' attribute");
2211
2212 return success();
2213}
2214
2215LogicalResult SVVerbatimModuleOp::verify() { return success(); }
2216
2217LogicalResult
2218SVVerbatimModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2219 // Verify that the source attribute references an sv.verbatim.source operation
2220 auto sourceOp = dyn_cast_or_null<SVVerbatimSourceOp>(
2221 symbolTable.lookupNearestSymbolFrom(*this, getSourceAttr()));
2222 if (!sourceOp)
2223 return emitError("references ") << getSourceAttr().getAttr().getValue()
2224 << ", which is not an sv.verbatim.source";
2225
2226 return success();
2227}
2228
2229//===----------------------------------------------------------------------===//
2230// BindInterfaceOp
2231//===----------------------------------------------------------------------===//
2232
2233sv::InterfaceInstanceOp
2234BindInterfaceOp::getReferencedInstance(const hw::HWSymbolCache *cache) {
2235 // If we have a cache, directly look up the referenced instance.
2236 if (cache) {
2237 auto result = cache->getInnerDefinition(getInstance());
2238 return cast<sv::InterfaceInstanceOp>(result.getOp());
2239 }
2240
2241 // Otherwise, resolve the instance by looking up the module ...
2242 auto *symbolTable = SymbolTable::getNearestSymbolTable(*this);
2243 if (!symbolTable)
2244 return {};
2245 auto *parentOp =
2246 lookupSymbolInNested(symbolTable, getInstance().getModule().getValue());
2247 if (!parentOp)
2248 return {};
2249
2250 // ... then look up the instance within it.
2251 return findInstanceSymbolInBlock<sv::InterfaceInstanceOp>(
2252 getInstance().getName(), &parentOp->getRegion(0).front());
2253}
2254
2255/// Ensure that the symbol being instantiated exists and is an InterfaceOp.
2256LogicalResult
2257BindInterfaceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2258 auto *parentOp =
2259 symbolTable.lookupNearestSymbolFrom(*this, getInstance().getModule());
2260 if (!parentOp)
2261 return emitError("Referenced module doesn't exist ")
2262 << getInstance().getModule() << "::" << getInstance().getName();
2263
2264 auto inst = findInstanceSymbolInBlock<sv::InterfaceInstanceOp>(
2265 getInstance().getName(), &parentOp->getRegion(0).front());
2266 if (!inst)
2267 return emitError("Referenced interface doesn't exist ")
2268 << getInstance().getModule() << "::" << getInstance().getName();
2269 if (!inst.getDoNotPrint())
2270 return emitError("Referenced interface isn't marked as doNotPrint");
2271 return success();
2272}
2273
2274//===----------------------------------------------------------------------===//
2275// XMROp
2276//===----------------------------------------------------------------------===//
2277
2278ParseResult parseXMRPath(::mlir::OpAsmParser &parser, ArrayAttr &pathAttr,
2279 StringAttr &terminalAttr) {
2280 SmallVector<Attribute> strings;
2281 ParseResult ret = parser.parseCommaSeparatedList([&]() {
2282 StringAttr result;
2283 StringRef keyword;
2284 if (succeeded(parser.parseOptionalKeyword(&keyword))) {
2285 strings.push_back(parser.getBuilder().getStringAttr(keyword));
2286 return success();
2287 }
2288 if (succeeded(parser.parseAttribute(
2289 result, parser.getBuilder().getType<NoneType>()))) {
2290 strings.push_back(result);
2291 return success();
2292 }
2293 return failure();
2294 });
2295 if (succeeded(ret)) {
2296 pathAttr = parser.getBuilder().getArrayAttr(
2297 ArrayRef<Attribute>(strings).drop_back());
2298 terminalAttr = cast<StringAttr>(*strings.rbegin());
2299 }
2300 return ret;
2301}
2302
2303void printXMRPath(OpAsmPrinter &p, XMROp op, ArrayAttr pathAttr,
2304 StringAttr terminalAttr) {
2305 llvm::interleaveComma(pathAttr, p);
2306 p << ", " << terminalAttr;
2307}
2308
2309/// Ensure that the symbol being instantiated exists and is a HierPathOp.
2310LogicalResult XMRRefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2311 auto *table = SymbolTable::getNearestSymbolTable(*this);
2312 auto path = dyn_cast_or_null<hw::HierPathOp>(
2313 symbolTable.lookupSymbolIn(table, getRefAttr()));
2314 if (!path)
2315 return emitError("Referenced path doesn't exist ") << getRefAttr();
2316
2317 return success();
2318}
2319
2320hw::HierPathOp XMRRefOp::getReferencedPath(const hw::HWSymbolCache *cache) {
2321 if (cache)
2322 if (auto *result = cache->getDefinition(getRefAttr().getAttr()))
2323 return cast<hw::HierPathOp>(result);
2324
2325 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
2326 return topLevelModuleOp.lookupSymbol<hw::HierPathOp>(getRefAttr().getValue());
2327}
2328
2329//===----------------------------------------------------------------------===//
2330// Verification Ops.
2331//===----------------------------------------------------------------------===//
2332
2333static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value value,
2334 PatternRewriter &rewriter,
2335 bool eraseIfZero) {
2336 if (auto constant = value.getDefiningOp<hw::ConstantOp>())
2337 if (constant.getValue().isZero() == eraseIfZero) {
2338 rewriter.eraseOp(op);
2339 return success();
2340 }
2341
2342 return failure();
2343}
2344
2345template <class Op, bool EraseIfZero = false>
2346static LogicalResult canonicalizeImmediateVerifOp(Op op,
2347 PatternRewriter &rewriter) {
2348 return eraseIfZeroOrNotZero(op, op.getExpression(), rewriter, EraseIfZero);
2349}
2350
2351void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
2352 MLIRContext *context) {
2353 results.add(canonicalizeImmediateVerifOp<AssertOp>);
2354}
2355
2356void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2357 MLIRContext *context) {
2358 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
2359}
2360
2361void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
2362 MLIRContext *context) {
2363 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
2364}
2365
2366template <class Op, bool EraseIfZero = false>
2367static LogicalResult canonicalizeConcurrentVerifOp(Op op,
2368 PatternRewriter &rewriter) {
2369 return eraseIfZeroOrNotZero(op, op.getProperty(), rewriter, EraseIfZero);
2370}
2371
2372void AssertConcurrentOp::getCanonicalizationPatterns(RewritePatternSet &results,
2373 MLIRContext *context) {
2374 results.add(canonicalizeConcurrentVerifOp<AssertConcurrentOp>);
2375}
2376
2377void AssumeConcurrentOp::getCanonicalizationPatterns(RewritePatternSet &results,
2378 MLIRContext *context) {
2379 results.add(canonicalizeConcurrentVerifOp<AssumeConcurrentOp>);
2380}
2381
2382void CoverConcurrentOp::getCanonicalizationPatterns(RewritePatternSet &results,
2383 MLIRContext *context) {
2384 results.add(
2385 canonicalizeConcurrentVerifOp<CoverConcurrentOp, /* EraseIfZero */ true>);
2386}
2387
2388//===----------------------------------------------------------------------===//
2389// SV generate ops
2390//===----------------------------------------------------------------------===//
2391
2392/// Parse cases formatted like:
2393/// case (pattern, "name") { ... }
2394bool parseCaseRegions(OpAsmParser &p, ArrayAttr &patternsArray,
2395 ArrayAttr &caseNamesArray,
2396 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
2397 SmallVector<Attribute> patterns;
2398 SmallVector<Attribute> names;
2399 while (!p.parseOptionalKeyword("case")) {
2400 Attribute pattern;
2401 StringAttr name;
2402 std::unique_ptr<Region> region = std::make_unique<Region>();
2403 if (p.parseLParen() || p.parseAttribute(pattern) || p.parseComma() ||
2404 p.parseAttribute(name) || p.parseRParen() || p.parseRegion(*region))
2405 return true;
2406 patterns.push_back(pattern);
2407 names.push_back(name);
2408 if (region->empty())
2409 region->push_back(new Block());
2410 caseRegions.push_back(std::move(region));
2411 }
2412 patternsArray = p.getBuilder().getArrayAttr(patterns);
2413 caseNamesArray = p.getBuilder().getArrayAttr(names);
2414 return false;
2415}
2416
2417/// Print cases formatted like:
2418/// case (pattern, "name") { ... }
2419void printCaseRegions(OpAsmPrinter &p, Operation *, ArrayAttr patternsArray,
2420 ArrayAttr namesArray,
2421 MutableArrayRef<Region> caseRegions) {
2422 assert(patternsArray.size() == caseRegions.size());
2423 assert(patternsArray.size() == namesArray.size());
2424 for (size_t i = 0, e = caseRegions.size(); i < e; ++i) {
2425 p.printNewline();
2426 p << "case (" << patternsArray[i] << ", " << namesArray[i] << ") ";
2427 p.printRegion(caseRegions[i]);
2428 }
2429 p.printNewline();
2430}
2431
2432LogicalResult GenerateCaseOp::verify() {
2433 size_t numPatterns = getCasePatterns().size();
2434 if (getCaseRegions().size() != numPatterns ||
2435 getCaseNames().size() != numPatterns)
2436 return emitOpError(
2437 "Size of caseRegions, patterns, and caseNames must match");
2438
2439 StringSet<> usedNames;
2440 for (Attribute name : getCaseNames()) {
2441 StringAttr nameStr = dyn_cast<StringAttr>(name);
2442 if (!nameStr)
2443 return emitOpError("caseNames must all be string attributes");
2444 if (usedNames.contains(nameStr.getValue()))
2445 return emitOpError("caseNames must be unique");
2446 usedNames.insert(nameStr.getValue());
2447 }
2448
2449 // mlir::FailureOr<Type> condType = evaluateParametricType();
2450
2451 return success();
2452}
2453
2454ModportStructAttr ModportStructAttr::get(MLIRContext *context,
2455 ModportDirection direction,
2456 FlatSymbolRefAttr signal) {
2457 return get(context, ModportDirectionAttr::get(context, direction), signal);
2458}
2459
2460//===----------------------------------------------------------------------===//
2461// FuncOp
2462//===----------------------------------------------------------------------===//
2463
2464ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
2465 auto builder = parser.getBuilder();
2466 // Parse visibility.
2467 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
2468
2469 // Parse the name as a symbol.
2470 StringAttr nameAttr;
2471 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2472 result.attributes))
2473 return failure();
2474
2475 SmallVector<hw::module_like_impl::PortParse> ports;
2476 TypeAttr modType;
2477 if (failed(
2478 hw::module_like_impl::parseModuleSignature(parser, ports, modType)))
2479 return failure();
2480
2481 result.addAttribute(FuncOp::getModuleTypeAttrName(result.name), modType);
2482
2483 // Convert the specified array of dictionary attrs (which may have null
2484 // entries) to an ArrayAttr of dictionaries.
2485 auto unknownLoc = builder.getUnknownLoc();
2486 SmallVector<Attribute> attrs, inputLocs, outputLocs;
2487 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
2488 return attr && cast<Location>(attr) != unknownLoc;
2489 };
2490
2491 for (auto &port : ports) {
2492 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
2493 auto loc = port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc;
2494 (port.direction == hw::PortInfo::Direction::Output ? outputLocs : inputLocs)
2495 .push_back(loc);
2496 }
2497
2498 result.addAttribute(FuncOp::getPerArgumentAttrsAttrName(result.name),
2499 builder.getArrayAttr(attrs));
2500
2501 if (llvm::any_of(outputLocs, nonEmptyLocsFn))
2502 result.addAttribute(FuncOp::getResultLocsAttrName(result.name),
2503 builder.getArrayAttr(outputLocs));
2504 // Parse the attribute dict.
2505 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
2506 return failure();
2507
2508 // Add the entry block arguments.
2509 SmallVector<OpAsmParser::Argument, 4> entryArgs;
2510 for (auto &port : ports)
2511 if (port.direction != hw::ModulePort::Direction::Output)
2512 entryArgs.push_back(port);
2513
2514 // Parse the optional function body. The printer will not print the body if
2515 // its empty, so disallow parsing of empty body in the parser.
2516 auto *body = result.addRegion();
2517 llvm::SMLoc loc = parser.getCurrentLocation();
2518
2519 mlir::OptionalParseResult parseResult =
2520 parser.parseOptionalRegion(*body, entryArgs,
2521 /*enableNameShadowing=*/false);
2522 if (parseResult.has_value()) {
2523 if (failed(*parseResult))
2524 return failure();
2525 // Function body was parsed, make sure its not empty.
2526 if (body->empty())
2527 return parser.emitError(loc, "expected non-empty function body");
2528 } else {
2529 if (llvm::any_of(inputLocs, nonEmptyLocsFn))
2530 result.addAttribute(FuncOp::getInputLocsAttrName(result.name),
2531 builder.getArrayAttr(inputLocs));
2532 }
2533
2534 return success();
2535}
2536
2537void FuncOp::getAsmBlockArgumentNames(mlir::Region &region,
2538 mlir::OpAsmSetValueNameFn setNameFn) {
2539 if (region.empty())
2540 return;
2541 // Assign port names to the bbargs.
2542 auto func = cast<FuncOp>(region.getParentOp());
2543
2544 auto *block = &region.front();
2545
2546 auto names = func.getModuleType().getInputNames();
2547 for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
2548 // Let mlir deterministically convert names to valid identifiers
2549 setNameFn(block->getArgument(i), cast<StringAttr>(names[i]));
2550 }
2551}
2552
2553Type FuncOp::getExplicitlyReturnedType() {
2554 if (!getPerArgumentAttrs() || getNumOutputs() == 0)
2555 return {};
2556
2557 // Check if the last port is used as an explicit return.
2558 auto lastArgument = getModuleType().getPorts().back();
2559 auto lastArgumentAttr = dyn_cast<DictionaryAttr>(
2560 getPerArgumentAttrsAttr()[getPerArgumentAttrsAttr().size() - 1]);
2561
2562 if (lastArgument.dir == hw::ModulePort::Output && lastArgumentAttr &&
2563 lastArgumentAttr.getAs<UnitAttr>(getExplicitlyReturnedAttrName()))
2564 return lastArgument.type;
2565 return {};
2566}
2567
2568ArrayRef<Attribute> FuncOp::getAllPortAttrs() {
2569 if (getPerArgumentAttrs())
2570 return getPerArgumentAttrs()->getValue();
2571 return {};
2572}
2573
2574void FuncOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
2575 setPerArgumentAttrsAttr(ArrayAttr::get(getContext(), attrs));
2576}
2577
2578void FuncOp::removeAllPortAttrs() { setPerArgumentAttrsAttr({}); }
2579SmallVector<Location> FuncOp::getAllPortLocs() {
2580 SmallVector<Location> portLocs;
2581 portLocs.reserve(getNumPorts());
2582 auto resultLocs = getResultLocsAttr();
2583 unsigned inputCount = 0;
2584 auto modType = getModuleType();
2585 auto unknownLoc = UnknownLoc::get(getContext());
2586 auto *body = getBodyBlock();
2587 auto inputLocs = getInputLocsAttr();
2588 for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
2589 if (modType.isOutput(i)) {
2590 auto loc = resultLocs
2591 ? cast<Location>(
2592 resultLocs.getValue()[portLocs.size() - inputCount])
2593 : unknownLoc;
2594 portLocs.push_back(loc);
2595 } else {
2596 auto loc = body ? body->getArgument(inputCount).getLoc()
2597 : (inputLocs ? cast<Location>(inputLocs[inputCount])
2598 : unknownLoc);
2599 portLocs.push_back(loc);
2600 ++inputCount;
2601 }
2602 }
2603 return portLocs;
2604}
2605
2606void FuncOp::setAllPortLocsAttrs(llvm::ArrayRef<mlir::Attribute> locs) {
2607 SmallVector<Attribute> resultLocs, inputLocs;
2608 unsigned inputCount = 0;
2609 auto modType = getModuleType();
2610 auto *body = getBodyBlock();
2611 for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
2612 if (modType.isOutput(i))
2613 resultLocs.push_back(locs[i]);
2614 else if (body)
2615 body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
2616 else // Need to store locations in an attribute if declaration.
2617 inputLocs.push_back(locs[i]);
2618 }
2619 setResultLocsAttr(ArrayAttr::get(getContext(), resultLocs));
2620 if (!body)
2621 setInputLocsAttr(ArrayAttr::get(getContext(), inputLocs));
2622}
2623
2624SmallVector<hw::PortInfo> FuncOp::getPortList() { return getPortList(false); }
2625
2626hw::PortInfo FuncOp::getPort(size_t idx) {
2627 auto modTy = getHWModuleType();
2628 auto emptyDict = DictionaryAttr::get(getContext());
2629 LocationAttr loc = getPortLoc(idx);
2630 DictionaryAttr attrs = dyn_cast_or_null<DictionaryAttr>(getPortAttrs(idx));
2631 if (!attrs)
2632 attrs = emptyDict;
2633 return {modTy.getPorts()[idx],
2634 modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
2635 : modTy.getInputIdForPortId(idx),
2636 attrs, loc};
2637}
2638
2639SmallVector<hw::PortInfo> FuncOp::getPortList(bool excludeExplicitReturn) {
2640 auto modTy = getModuleType();
2641 auto emptyDict = DictionaryAttr::get(getContext());
2642 auto skipLastArgument = getExplicitlyReturnedType() && excludeExplicitReturn;
2643 SmallVector<hw::PortInfo> retval;
2644 auto portAttr = getAllPortLocs();
2645 for (unsigned i = 0, e = skipLastArgument ? modTy.getNumPorts() - 1
2646 : modTy.getNumPorts();
2647 i < e; ++i) {
2648 DictionaryAttr attrs = emptyDict;
2649 if (auto perArgumentAttr = getPerArgumentAttrs())
2650 if (auto argumentAttr =
2651 dyn_cast_or_null<DictionaryAttr>((*perArgumentAttr)[i]))
2652 attrs = argumentAttr;
2653
2654 retval.push_back({modTy.getPorts()[i],
2655 modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
2656 : modTy.getInputIdForPortId(i),
2657 attrs, portAttr[i]});
2658 }
2659 return retval;
2660}
2661
2662void FuncOp::print(OpAsmPrinter &p) {
2663 FuncOp op = *this;
2664 // Print the operation and the function name.
2665 auto funcName =
2666 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
2667 .getValue();
2668 p << ' ';
2669
2670 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
2671 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
2672 p << visibility.getValue() << ' ';
2673 p.printSymbolName(funcName);
2675 p, op.getBody(), op.getModuleType(),
2676 op.getPerArgumentAttrsAttr()
2677 ? ArrayRef<Attribute>(op.getPerArgumentAttrsAttr().getValue())
2678 : ArrayRef<Attribute>{},
2679 getAllPortLocs());
2680
2681 mlir::function_interface_impl::printFunctionAttributes(
2682 p, op,
2683 {visibilityAttrName, getModuleTypeAttrName(),
2684 getPerArgumentAttrsAttrName(), getInputLocsAttrName(),
2685 getResultLocsAttrName()});
2686 // Print the body if this is not an external function.
2687 Region &body = op->getRegion(0);
2688 if (!body.empty()) {
2689 p << ' ';
2690 p.printRegion(body, /*printEntryBlockArgs=*/false,
2691 /*printBlockTerminators=*/true);
2692 }
2693}
2694
2695//===----------------------------------------------------------------------===//
2696// ReturnOp
2697//===----------------------------------------------------------------------===//
2698
2699LogicalResult ReturnOp::verify() {
2700 auto func = getParentOp<sv::FuncOp>();
2701 auto funcResults = func.getResultTypes();
2702 auto returnedValues = getOperands();
2703 if (funcResults.size() != returnedValues.size())
2704 return emitOpError("must have same number of operands as region results.");
2705 // Check that the types of our operands and the region's results match.
2706 for (size_t i = 0, e = funcResults.size(); i < e; ++i) {
2707 if (funcResults[i] != returnedValues[i].getType()) {
2708 emitOpError("output types must match function. In "
2709 "operand ")
2710 << i << ", expected " << funcResults[i] << ", but got "
2711 << returnedValues[i].getType() << ".";
2712 return failure();
2713 }
2714 }
2715 return success();
2716}
2717
2718//===----------------------------------------------------------------------===//
2719// Call Ops
2720//===----------------------------------------------------------------------===//
2721
2722static Value
2724 mlir::Operation::result_range results) {
2725 if (!op.getExplicitlyReturnedType())
2726 return {};
2727 return results.back();
2728}
2729
2730Value FuncCallOp::getExplicitlyReturnedValue(sv::FuncOp op) {
2731 return getExplicitlyReturnedValueImpl(op, getResults());
2732}
2733
2734Value FuncCallProceduralOp::getExplicitlyReturnedValue(sv::FuncOp op) {
2735 return getExplicitlyReturnedValueImpl(op, getResults());
2736}
2737
2738LogicalResult
2739FuncCallProceduralOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2740 auto referencedOp = dyn_cast_or_null<sv::FuncOp>(
2741 symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr()));
2742 if (!referencedOp)
2743 return emitError("cannot find function declaration '")
2744 << getCallee() << "'";
2745 return success();
2746}
2747
2748LogicalResult FuncCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2749 auto referencedOp = dyn_cast_or_null<sv::FuncOp>(
2750 symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr()));
2751 if (!referencedOp)
2752 return emitError("cannot find function declaration '")
2753 << getCallee() << "'";
2754
2755 // Non-procedural call cannot have output arguments.
2756 if (referencedOp.getNumOutputs() != 1 ||
2757 !referencedOp.getExplicitlyReturnedType()) {
2758 auto diag = emitError()
2759 << "function called in a non-procedural region must "
2760 "return a single result";
2761 diag.attachNote(referencedOp.getLoc()) << "doesn't satisfy the constraint";
2762 return failure();
2763 }
2764 return success();
2765}
2766
2767//===----------------------------------------------------------------------===//
2768// FuncDPIImportOp
2769//===----------------------------------------------------------------------===//
2770
2771LogicalResult
2772FuncDPIImportOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2773 auto referencedOp = dyn_cast_or_null<sv::FuncOp>(
2774 symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr()));
2775
2776 if (!referencedOp)
2777 return emitError("cannot find function declaration '")
2778 << getCallee() << "'";
2779 if (!referencedOp.isDeclaration())
2780 return emitError("imported function must be a declaration but '")
2781 << getCallee() << "' is defined";
2782 return success();
2783}
2784
2785//===----------------------------------------------------------------------===//
2786// Assert Property Like ops
2787//===----------------------------------------------------------------------===//
2788
2790// Check that a clock is never given without an event
2791// and that an event is never given with a clock.
2792static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc) {
2793 if ((!clock && eventExists) || (clock && !eventExists))
2794 return mlir::emitError(
2795 loc, "Every clock must be associated to an even and vice-versa!");
2796 return success();
2797}
2798} // namespace AssertPropertyLikeOp
2799
2800LogicalResult AssertPropertyOp::verify() {
2801 return AssertPropertyLikeOp::verify(getClock(), getEvent().has_value(),
2802 getLoc());
2803}
2804
2805LogicalResult AssumePropertyOp::verify() {
2806 return AssertPropertyLikeOp::verify(getClock(), getEvent().has_value(),
2807 getLoc());
2808}
2809
2810LogicalResult CoverPropertyOp::verify() {
2811 return AssertPropertyLikeOp::verify(getClock(), getEvent().has_value(),
2812 getLoc());
2813}
2814
2815//===----------------------------------------------------------------------===//
2816// TableGen generated logic.
2817//===----------------------------------------------------------------------===//
2818
2819// Provide the autogenerated implementation guts for the Op classes.
2820#define GET_OP_CLASSES
2821#include "circt/Dialect/SV/SV.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static bool hasSVAttributes(Operation *op)
Definition CombFolds.cpp:67
#define isdigit(x)
Definition FIRLexer.cpp:26
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition HWOps.cpp:1428
static SmallVector< Location > getAllPortLocs(ModTy module)
Definition HWOps.cpp:1206
static void setHWModuleType(ModTy &mod, ModuleType type)
Definition HWOps.cpp:1349
@ Output
Definition HW.h:42
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
bool parseCaseRegions(OpAsmParser &p, ArrayAttr &patternsArray, ArrayAttr &caseNamesArray, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse cases formatted like: case (pattern, "name") { ... }.
Definition SVOps.cpp:2394
ParseResult parseIfaceTypeAndSignal(OpAsmParser &p, Type &ifaceTy, FlatSymbolRefAttr &signalName)
Definition SVOps.cpp:1591
LogicalResult verifySignalExists(Value ifaceVal, FlatSymbolRefAttr signalName)
Definition SVOps.cpp:1613
void printCaseRegions(OpAsmPrinter &p, Operation *, ArrayAttr patternsArray, ArrayAttr namesArray, MutableArrayRef< Region > caseRegions)
Print cases formatted like: case (pattern, "name") { ... }.
Definition SVOps.cpp:2419
static Value getExplicitlyReturnedValueImpl(sv::FuncOp op, mlir::Operation::result_range results)
Definition SVOps.cpp:2723
void printIfaceTypeAndSignal(OpAsmPrinter &p, Operation *op, Type type, FlatSymbolRefAttr signalName)
Definition SVOps.cpp:1604
static void printModportStructs(OpAsmPrinter &p, Operation *, ArrayAttr portsAttr)
Definition SVOps.cpp:1463
static LogicalResult canonicalizeConcurrentVerifOp(Op op, PatternRewriter &rewriter)
Definition SVOps.cpp:2367
static ParseResult parseEventList(OpAsmParser &p, Attribute &eventsAttr, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &clocksOperands)
Definition SVOps.cpp:640
static MacroDeclOp getReferencedMacro(const hw::HWSymbolCache *cache, Operation *op, FlatSymbolRefAttr macroName)
Definition SVOps.cpp:155
static LogicalResult canonicalizeIfDefLike(Op op, PatternRewriter &rewriter)
Definition SVOps.cpp:444
ParseResult parseXMRPath(::mlir::OpAsmParser &parser, ArrayAttr &pathAttr, StringAttr &terminalAttr)
Definition SVOps.cpp:2278
static Type getElementTypeOfWidth(Type type, int32_t width)
Definition SVOps.cpp:1751
static LogicalResult mergeNeiboringAssignments(AssignTy op, PatternRewriter &rewriter)
Definition SVOps.cpp:1348
static Op findInstanceSymbolInBlock(StringAttr name, Block *body)
Instances must be at the top level of the hw.module (or within a `ifdef)
Definition SVOps.cpp:1881
static void printEventList(OpAsmPrinter &p, AlwaysOp op, ArrayAttr portsAttr, OperandRange operands)
Definition SVOps.cpp:671
static SmallVector< CasePatternBit > getPatternBitsForValue(const APInt &value)
Definition SVOps.cpp:805
static ParseResult parseImplicitInitType(OpAsmParser &p, mlir::Type regType, std::optional< OpAsmParser::UnresolvedOperand > &initValue, mlir::Type &initType)
Definition SVOps.cpp:296
static LogicalResult verifyMacroIdentSymbolUses(Operation *op, FlatSymbolRefAttr attr, SymbolTableCollection &symbolTable)
Verifies symbols referenced by macro identifiers.
Definition SVOps.cpp:100
static void getVerbatimExprAsmResultNames(Operation *op, function_ref< void(Value, StringRef)> setNameFn)
Get the asm name for sv.verbatim.expr and sv.verbatim.expr.se.
Definition SVOps.cpp:116
static void printImplicitInitType(OpAsmPrinter &p, Operation *op, mlir::Type regType, mlir::Value initValue, mlir::Type initType)
Definition SVOps.cpp:310
static ParseResult parseModportStructs(OpAsmParser &parser, ArrayAttr &portsAttr)
Definition SVOps.cpp:1436
static Operation * lookupSymbolInNested(Operation *symbolTableOp, StringRef symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
Definition SVOps.cpp:75
void printXMRPath(OpAsmPrinter &p, XMROp op, ArrayAttr pathAttr, StringAttr terminalAttr)
Definition SVOps.cpp:2303
static InstancePath empty
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition HWSymCache.h:27
HWSymbolCache::Item getInnerDefinition(mlir::StringAttr modSymbol, mlir::StringAttr name) const
Definition HWSymCache.h:65
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition HWSymCache.h:56
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
IntegerAttr intAttr
Definition SVOps.h:123
CasePatternBit getBit(size_t bitNumber) const
Return the specified bit, bit 0 is the least significant bit.
Definition SVOps.cpp:787
bool hasZ() const override
Return true if this pattern has an Z.
Definition SVOps.cpp:799
CaseBitPattern(ArrayRef< CasePatternBit > bits, MLIRContext *context)
Get a CasePattern from a specified list of CasePatternBit.
Definition SVOps.cpp:823
bool hasX() const override
Return true if this pattern has an X.
Definition SVOps.cpp:792
hw::EnumFieldAttr enumAttr
Definition SVOps.h:140
StringRef getFieldValue() const
Definition SVOps.cpp:866
Signals that an operation's regions are procedural.
Definition SVOps.h:176
create(array_value, low_index, ret_type)
Definition hw.py:466
create(data_type, value)
Definition hw.py:433
Definition sv.py:70
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition SVOps.cpp:2792
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a 'Not' gate on a value.
Definition CombOps.cpp:66
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
bool isHWIntegerType(mlir::Type type)
Return true if the specified type is a value HW Integer type.
Definition HWTypes.cpp:60
bool isOffset(Value base, Value index, uint64_t offset)
Definition HWOps.cpp:1842
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition HWOps.cpp:529
bool isHWEnumType(mlir::Type type)
Return true if the specified type is a HW Enum type.
Definition HWTypes.cpp:73
mlir::Type getCanonicalType(mlir::Type type)
Definition HWTypes.cpp:49
CasePatternBit
This describes the bit in a pattern, 0/1/x/z.
Definition SVOps.h:49
char getLetter(CasePatternBit bit)
Return the letter for the specified pattern bit, e.g. "0", "1", "x" or "z".
Definition SVOps.cpp:772
bool hasSVAttributes(mlir::Operation *op)
Helper functions to handle SV attributes.
bool is2StateExpression(Value v)
Returns if the expression is known to be 2-state (binary)
Definition SVOps.cpp:42
mlir::Type getInOutElementType(mlir::Type type)
Return the element type of an InOutType or null if the operand isn't an InOut type.
Definition SVTypes.cpp:42
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Parse a parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
Definition sv.py:1
This holds the name, type, direction of a module's ports.