CIRCT  18.0.0git
Circuit.cpp
Go to the documentation of this file.
1 //===-- Circuit.cpp - intermediate representation for circuits --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// This file defines an intermediate representation for circuits acting as
10 /// an abstraction for constraints defined over an SMT solver's context.
11 ///
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/IR/Builders.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
20 
21 #define DEBUG_TYPE "lec-circuit"
22 
23 using namespace mlir;
24 using namespace circt;
25 
26 /// Add an input to the circuit; internally a new value gets allocated.
27 void Solver::Circuit::addInput(Value value) {
28  LLVM_DEBUG(lec::dbgs() << name << " addInput\n");
29  lec::Scope indent;
30  z3::expr input = fetchOrAllocateExpr(value);
31  inputs.insert(inputs.end(), input);
32 }
33 
34 /// Add an output to the circuit.
35 void Solver::Circuit::addOutput(Value value) {
36  LLVM_DEBUG(lec::dbgs() << name << " addOutput\n");
37  // Referenced value already assigned, fetching from expression table.
38  z3::expr output = fetchOrAllocateExpr(value);
39  outputs.insert(outputs.end(), output);
40 }
41 
42 /// Recover the inputs.
43 llvm::ArrayRef<z3::expr> Solver::Circuit::getInputs() { return inputs; }
44 
45 /// Recover the outputs.
46 llvm::ArrayRef<z3::expr> Solver::Circuit::getOutputs() { return outputs; }
47 
48 //===----------------------------------------------------------------------===//
49 // `hw` dialect operations
50 //===----------------------------------------------------------------------===//
51 
52 void Solver::Circuit::addConstant(Value opResult, const APInt &opValue) {
53  LLVM_DEBUG(lec::dbgs() << name << " addConstant\n");
54  lec::Scope indent;
55  allocateConstant(opResult, opValue);
56 }
57 
58 void Solver::Circuit::addInstance(llvm::StringRef instanceName,
59  circt::hw::HWModuleOp op,
60  OperandRange arguments, ResultRange results) {
61  LLVM_DEBUG(lec::dbgs() << name << " addInstance\n");
62  lec::Scope indent;
63  LLVM_DEBUG(lec::dbgs() << "instance name: " << instanceName << "\n");
64  LLVM_DEBUG(lec::dbgs() << "module name: " << op->getName() << "\n");
65  // There is no preventing multiple instances holding the same name.
66  // As an hack, a suffix is used to differentiate them.
67  std::string suffix = "_" + std::to_string(assignments);
68  Circuit instance(name + "@" + instanceName + suffix, solver);
69  // Export logic to the instance's circuit by visiting the IR of the
70  // instanced module.
71  auto res = LogicExporter(op.getModuleName(), &instance).run(op);
72  (void)res; // Suppress Warning
73  assert(res.succeeded() && "Instance visit failed");
74 
75  // Constrain the inputs and outputs of the instanced circuit to, respectively,
76  // the arguments and results of the instance operation.
77  {
78  LLVM_DEBUG(lec::dbgs() << "instance inputs:\n");
79  lec::Scope indent;
80  auto *input = instance.inputs.begin();
81  for (Value argument : arguments) {
82  LLVM_DEBUG(lec::dbgs() << "input\n");
83  z3::expr argExpr = fetchOrAllocateExpr(argument);
84  solver.solver.add(argExpr == *input++);
85  }
86  }
87  {
88  LLVM_DEBUG(lec::dbgs() << "instance results:\n");
89  lec::Scope indent;
90  auto *output = instance.outputs.begin();
91  for (circt::OpResult result : results) {
92  z3::expr resultExpr = fetchOrAllocateExpr(result);
93  solver.solver.add(resultExpr == *output++);
94  }
95  }
96 }
97 
98 //===----------------------------------------------------------------------===//
99 // `comb` dialect operations
100 //===----------------------------------------------------------------------===//
101 
102 void Solver::Circuit::performAdd(Value result, OperandRange operands) {
103  LLVM_DEBUG(lec::dbgs() << name << " perform Add\n");
104  lec::Scope indent;
105  variadicOperation(result, operands,
106  [](auto op1, auto op2) { return op1 + op2; });
107 }
108 
109 void Solver::Circuit::performAnd(Value result, OperandRange operands) {
110  LLVM_DEBUG(lec::dbgs() << name << " perform And\n");
111  lec::Scope indent;
112  variadicOperation(result, operands,
113  [](auto op1, auto op2) { return z3::operator&(op1, op2); });
114 }
115 
116 void Solver::Circuit::performConcat(Value result, OperandRange operands) {
117  LLVM_DEBUG(lec::dbgs() << name << " perform Concat\n");
118  lec::Scope indent;
119  variadicOperation(result, operands,
120  [](auto op1, auto op2) { return z3::concat(op1, op2); });
121 }
122 
123 void Solver::Circuit::performDivS(Value result, Value lhs, Value rhs) {
124  LLVM_DEBUG(lec::dbgs() << name << " perform DivS\n");
125  lec::Scope indent;
126  LLVM_DEBUG(lec::dbgs() << "lhs:\n");
127  z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
128  LLVM_DEBUG(lec::dbgs() << "rhs:\n");
129  z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
130  z3::expr op = z3::operator/(lhsExpr, rhsExpr);
131  constrainResult(result, op);
132 }
133 
134 void Solver::Circuit::performDivU(Value result, Value lhs, Value rhs) {
135  LLVM_DEBUG(lec::dbgs() << name << " perform DivU\n");
136  lec::Scope indent;
137  LLVM_DEBUG(lec::dbgs() << "lhs:\n");
138  z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
139  LLVM_DEBUG(lec::dbgs() << "rhs:\n");
140  z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
141  z3::expr op = z3::udiv(lhsExpr, rhsExpr);
142  constrainResult(result, op);
143 }
144 
145 void Solver::Circuit::performExtract(Value result, Value input,
146  uint32_t lowBit) {
147  LLVM_DEBUG(lec::dbgs() << name << " performExtract\n");
148  lec::Scope indent;
149  LLVM_DEBUG(lec::dbgs() << "input:\n");
150  z3::expr inputExpr = fetchOrAllocateExpr(input);
151  unsigned width = result.getType().getIntOrFloatBitWidth();
152  LLVM_DEBUG(lec::dbgs() << "width: " << width << "\n");
153  z3::expr extract = inputExpr.extract(lowBit + width - 1, lowBit);
154  constrainResult(result, extract);
155 }
156 
157 LogicalResult Solver::Circuit::performICmp(Value result,
158  circt::comb::ICmpPredicate predicate,
159  Value lhs, Value rhs) {
160  LLVM_DEBUG(lec::dbgs() << name << " performICmp\n");
161  lec::Scope indent;
162  LLVM_DEBUG(lec::dbgs() << "lhs:\n");
163  z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
164  LLVM_DEBUG(lec::dbgs() << "rhs:\n");
165  z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
166  z3::expr icmp(solver.context);
167 
168  switch (predicate) {
169  case circt::comb::ICmpPredicate::eq:
170  icmp = boolToBv(lhsExpr == rhsExpr);
171  break;
172  case circt::comb::ICmpPredicate::ne:
173  icmp = boolToBv(lhsExpr != rhsExpr);
174  break;
175  case circt::comb::ICmpPredicate::slt:
176  icmp = boolToBv(z3::slt(lhsExpr, rhsExpr));
177  break;
178  case circt::comb::ICmpPredicate::sle:
179  icmp = boolToBv(z3::sle(lhsExpr, rhsExpr));
180  break;
181  case circt::comb::ICmpPredicate::sgt:
182  icmp = boolToBv(z3::sgt(lhsExpr, rhsExpr));
183  break;
184  case circt::comb::ICmpPredicate::sge:
185  icmp = boolToBv(z3::sge(lhsExpr, rhsExpr));
186  break;
187  case circt::comb::ICmpPredicate::ult:
188  icmp = boolToBv(z3::ult(lhsExpr, rhsExpr));
189  break;
190  case circt::comb::ICmpPredicate::ule:
191  icmp = boolToBv(z3::ule(lhsExpr, rhsExpr));
192  break;
193  case circt::comb::ICmpPredicate::ugt:
194  icmp = boolToBv(z3::ugt(lhsExpr, rhsExpr));
195  break;
196  case circt::comb::ICmpPredicate::uge:
197  icmp = boolToBv(z3::uge(lhsExpr, rhsExpr));
198  break;
199  // Multi-valued logic comparisons are not supported.
200  case circt::comb::ICmpPredicate::ceq:
201  case circt::comb::ICmpPredicate::weq:
202  case circt::comb::ICmpPredicate::cne:
203  case circt::comb::ICmpPredicate::wne:
204  result.getDefiningOp()->emitError(
205  "n-state logic predicates are not supported");
206  return failure();
207  };
208 
209  constrainResult(result, icmp);
210  return success();
211 }
212 
213 void Solver::Circuit::performModS(Value result, Value lhs, Value rhs) {
214  LLVM_DEBUG(lec::dbgs() << name << " perform ModS\n");
215  lec::Scope indent;
216  LLVM_DEBUG(lec::dbgs() << "lhs:\n");
217  z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
218  LLVM_DEBUG(lec::dbgs() << "rhs:\n");
219  z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
220  z3::expr op = z3::smod(lhsExpr, rhsExpr);
221  constrainResult(result, op);
222 }
223 
224 void Solver::Circuit::performModU(Value result, Value lhs, Value rhs) {
225  LLVM_DEBUG(lec::dbgs() << name << " perform ModU\n");
226  lec::Scope indent;
227  LLVM_DEBUG(lec::dbgs() << "lhs:\n");
228  z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
229  LLVM_DEBUG(lec::dbgs() << "rhs:\n");
230  z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
231  z3::expr op = z3::urem(lhsExpr, rhsExpr);
232  constrainResult(result, op);
233 }
234 
235 void Solver::Circuit::performMul(Value result, OperandRange operands) {
236  LLVM_DEBUG(lec::dbgs() << name << " perform Mul\n");
237  lec::Scope indent;
238  variadicOperation(result, operands,
239  [](auto op1, auto op2) { return op1 * op2; });
240 }
241 
242 void Solver::Circuit::performMux(Value result, Value cond, Value trueValue,
243  Value falseValue) {
244  LLVM_DEBUG(lec::dbgs() << name << " performMux\n");
245  lec::Scope indent;
246  LLVM_DEBUG(lec::dbgs() << "cond:\n");
247  z3::expr condExpr = fetchOrAllocateExpr(cond);
248  LLVM_DEBUG(lec::dbgs() << "trueValue:\n");
249  z3::expr tvalue = fetchOrAllocateExpr(trueValue);
250  LLVM_DEBUG(lec::dbgs() << "falseValue:\n");
251  z3::expr fvalue = fetchOrAllocateExpr(falseValue);
252  // Conversion due to z3::ite requiring a bool rather than a bitvector.
253  z3::expr mux = z3::ite(bvToBool(condExpr), tvalue, fvalue);
254  constrainResult(result, mux);
255 }
256 
257 void Solver::Circuit::performOr(Value result, OperandRange operands) {
258  LLVM_DEBUG(lec::dbgs() << name << " perform Or\n");
259  lec::Scope indent;
260  variadicOperation(result, operands,
261  [](auto op1, auto op2) { return op1 | op2; });
262 }
263 
264 void Solver::Circuit::performParity(Value result, Value input) {
265  LLVM_DEBUG(lec::dbgs() << name << " performParity\n");
266  lec::Scope indent;
267  LLVM_DEBUG(lec::dbgs() << "input:\n");
268  z3::expr inputExpr = fetchOrAllocateExpr(input);
269 
270  unsigned width = inputExpr.get_sort().bv_size();
271 
272  // input has 1 or more bits
273  z3::expr parity = inputExpr.extract(0, 0);
274  // calculate parity with every other bit
275  for (unsigned int i = 1; i < width; i++) {
276  parity = parity ^ inputExpr.extract(i, i);
277  }
278 
279  constrainResult(result, parity);
280 }
281 
282 void Solver::Circuit::performReplicate(Value result, Value input) {
283  LLVM_DEBUG(lec::dbgs() << name << " performReplicate\n");
284  lec::Scope indent;
285  LLVM_DEBUG(lec::dbgs() << "input:\n");
286  z3::expr inputExpr = fetchOrAllocateExpr(input);
287 
288  unsigned int final = result.getType().getIntOrFloatBitWidth();
289  unsigned int initial = input.getType().getIntOrFloatBitWidth();
290  unsigned int times = final / initial;
291  LLVM_DEBUG(lec::dbgs() << "replies: " << times << "\n");
292 
293  z3::expr replicate = inputExpr;
294  for (unsigned int i = 1; i < times; i++) {
295  replicate = z3::concat(replicate, inputExpr);
296  }
297 
298  constrainResult(result, replicate);
299 }
300 
301 void Solver::Circuit::performShl(Value result, Value lhs, Value rhs) {
302  LLVM_DEBUG(lec::dbgs() << name << " perform Shl\n");
303  lec::Scope indent;
304  LLVM_DEBUG(lec::dbgs() << "lhs:\n");
305  z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
306  LLVM_DEBUG(lec::dbgs() << "rhs:\n");
307  z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
308  z3::expr op = z3::shl(lhsExpr, rhsExpr);
309  constrainResult(result, op);
310 }
311 
312 // Arithmetic shift right.
313 void Solver::Circuit::performShrS(Value result, Value lhs, Value rhs) {
314  LLVM_DEBUG(lec::dbgs() << name << " perform ShrS\n");
315  lec::Scope indent;
316  LLVM_DEBUG(lec::dbgs() << "lhs:\n");
317  z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
318  LLVM_DEBUG(lec::dbgs() << "rhs:\n");
319  z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
320  z3::expr op = z3::ashr(lhsExpr, rhsExpr);
321  constrainResult(result, op);
322 }
323 
324 // Logical shift right.
325 void Solver::Circuit::performShrU(Value result, Value lhs, Value rhs) {
326  LLVM_DEBUG(lec::dbgs() << name << " perform ShrU\n");
327  lec::Scope indent;
328  LLVM_DEBUG(lec::dbgs() << "lhs:\n");
329  z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
330  LLVM_DEBUG(lec::dbgs() << "rhs:\n");
331  z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
332  z3::expr op = z3::lshr(lhsExpr, rhsExpr);
333  constrainResult(result, op);
334 }
335 
336 void Solver::Circuit::performSub(Value result, OperandRange operands) {
337  LLVM_DEBUG(lec::dbgs() << name << " perform Sub\n");
338  lec::Scope indent;
339  variadicOperation(result, operands,
340  [](auto op1, auto op2) { return op1 - op2; });
341 }
342 
343 void Solver::Circuit::performXor(Value result, OperandRange operands) {
344  LLVM_DEBUG(lec::dbgs() << name << " perform Xor\n");
345  lec::Scope indent;
346  variadicOperation(result, operands,
347  [](auto op1, auto op2) { return op1 ^ op2; });
348 }
349 
350 /// Helper function for performing a variadic operation: it executes a lambda
351 /// over a range of operands.
352 void Solver::Circuit::variadicOperation(
353  Value result, OperandRange operands,
354  llvm::function_ref<z3::expr(const z3::expr &, const z3::expr &)>
355  operation) {
356  // Allocate operands if unallocated
357  LLVM_DEBUG(lec::dbgs() << "variadic operation\n");
358  lec::Scope indent;
359  // Vacuous base case.
360  auto it = operands.begin();
361  Value operand = *it;
362  z3::expr varOp = fetchOrAllocateExpr(operand);
363  {
364  LLVM_DEBUG(lec::dbgs() << "first operand:\n");
365  lec::Scope indent;
366  LLVM_DEBUG(lec::printValue(operand));
367  }
368  ++it;
369  // Inductive step.
370  while (it != operands.end()) {
371  operand = *it;
372  varOp = operation(varOp, fetchOrAllocateExpr(operand));
373  {
374  LLVM_DEBUG(lec::dbgs() << "next operand:\n");
375  lec::Scope indent;
376  LLVM_DEBUG(lec::printValue(operand));
377  }
378  ++it;
379  }
380  constrainResult(result, varOp);
381 }
382 
383 /// Allocates an IR value in the logical backend and returns its representing
384 /// expression.
385 z3::expr Solver::Circuit::fetchOrAllocateExpr(Value value) {
386  z3::expr expr(solver.context);
387  auto exprPair = exprTable.find(value);
388  if (exprPair != exprTable.end()) {
389  LLVM_DEBUG(lec::dbgs() << "value already allocated:\n");
390  lec::Scope indent;
391  expr = exprPair->second;
392  LLVM_DEBUG(lec::printExpr(expr));
393  LLVM_DEBUG(lec::printValue(value));
394  } else {
395  std::string valueName = name + "%" + std::to_string(assignments++);
396  LLVM_DEBUG(lec::dbgs() << "allocating value:\n");
397  lec::Scope indent;
398  Type type = value.getType();
399  assert(type.isSignlessInteger() && "Unsupported type");
400  unsigned int width = type.getIntOrFloatBitWidth();
401  // Technically allowed for the `hw` dialect but
402  // disallowed for `comb` operations; should check separately.
403  assert(width > 0 && "0-width integers are not supported"); // NOLINT
404  expr = solver.context.bv_const(valueName.c_str(), width);
405  LLVM_DEBUG(lec::printExpr(expr));
406  LLVM_DEBUG(lec::printValue(value));
407  auto exprInsertion = exprTable.insert(std::pair(value, expr));
408  (void)exprInsertion; // Suppress Warning
409  assert(exprInsertion.second && "Value not inserted in expression table");
410  Builder builder(solver.mlirCtx);
411  StringAttr symbol = builder.getStringAttr(valueName);
412  auto symInsertion = solver.symbolTable.insert(std::pair(symbol, value));
413  (void)symInsertion; // Suppress Warning
414  assert(symInsertion.second && "Value not inserted in symbol table");
415  }
416  return expr;
417 }
418 
419 /// Allocates a constant value in the logical backend and returns its
420 /// representing expression.
421 void Solver::Circuit::allocateConstant(Value result, const APInt &value) {
422  // `The constant operation produces a constant value
423  // of standard integer type without a sign`
424  const z3::expr constant =
425  solver.context.bv_val(value.getZExtValue(), value.getBitWidth());
426  // Check whether the constant has been pre-allocated
427  auto allocatedPair = exprTable.find(result);
428  if (allocatedPair == exprTable.end()) {
429  // If not, then allocate
430  auto insertion = exprTable.insert(std::pair(result, constant));
431  (void)insertion; // suppress warning
432  assert(insertion.second && "Constant not inserted in expression table");
433  LLVM_DEBUG(lec::printExpr(constant));
434  LLVM_DEBUG(lec::printValue(result));
435  } else {
436  // If it has, then we force equivalence to the constant (we cannot just
437  // overwrite in the table as when it was allocated, a constraint was already
438  // formed using the symbolic form).
439  solver.solver.add(allocatedPair->second == constant);
440  LLVM_DEBUG(lec::dbgs() << "constraining symbolic value to constant:\n");
441  lec::Scope indent;
442  LLVM_DEBUG(lec::printExpr(constant));
443  LLVM_DEBUG(lec::printValue(result));
444  }
445 }
446 
447 /// Constrains the result of a MLIR operation to be equal a given logical
448 /// express, simulating an assignment.
449 void Solver::Circuit::constrainResult(Value &result, z3::expr &expr) {
450  LLVM_DEBUG(lec::dbgs() << "constraining result:\n");
451  lec::Scope indent;
452  {
453  LLVM_DEBUG(lec::dbgs() << "result expression:\n");
454  lec::Scope indent;
455  LLVM_DEBUG(lec::printExpr(expr));
456  }
457  z3::expr resExpr = fetchOrAllocateExpr(result);
458  z3::expr constraint = resExpr == expr;
459  {
460  LLVM_DEBUG(lec::dbgs() << "adding constraint:\n");
461  lec::Scope indent;
462  LLVM_DEBUG(lec::dbgs() << constraint.to_string() << "\n");
463  }
464  solver.solver.add(constraint);
465 }
466 
467 /// Convert from bitvector to bool sort.
468 z3::expr Solver::Circuit::bvToBool(const z3::expr &condition) {
469  // bitvector is true if it's different from 0
470  return condition != 0;
471 }
472 
473 /// Convert from a boolean sort to the corresponding 1-width bitvector.
474 z3::expr Solver::Circuit::boolToBv(const z3::expr &condition) {
475  return z3::ite(condition, solver.context.bv_val(1, 1),
476  solver.context.bv_val(0, 1));
477 }
lowerAnnotationsNoRefTypePorts FirtoolPreserveValuesMode value
Definition: Firtool.cpp:95
assert(baseType &&"element must be base type")
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
Definition: CalyxOps.cpp:119
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition: CalyxOps.cpp:538
int32_t width
Definition: FIRRTL.cpp:27
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
A class traversing MLIR IR to extrapolate the logic of a given circuit.
Definition: LogicExporter.h:35
mlir::LogicalResult run(mlir::ModuleOp &module)
Initializes the exporting by visiting the builtin module.
The representation of a circuit within a logical engine.
Definition: Circuit.h:35
llvm::SmallVector< z3::expr > outputs
The list for the circuit's outputs.
Definition: Circuit.h:112
llvm::SmallVector< z3::expr > inputs
The list for the circuit's inputs.
Definition: Circuit.h:110
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
void printExpr(const z3::expr &expr)
Helper function to provide a common debug formatting for z3 expressions.
Definition: Utility.h:51
void printValue(const mlir::Value &value)
Helper function to provide a common debug formatting for MLIR values.
Definition: Utility.h:59
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28
RAII struct to indent the output streams.
Definition: Utility.h:44