CIRCT 23.0.0git
Loading...
Searching...
No Matches
CellIFT.cpp
Go to the documentation of this file.
1//===- CellIFT.cpp - Cell-level Information Flow Tracking -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass instruments HW/Comb/Seq IR with parallel taint-tracking logic
10// using the CellIFT methodology (cell-level dynamic information flow tracking).
11//
12// For every data value v : iN, the pass creates a parallel taint value
13// v_t : iN. Module signatures are extended with taint ports, instances are
14// updated accordingly, and combinational/sequential operations are instrumented
15// with taint propagation rules that operate at the macrocell level.
16//
17// References:
// F. Solt, B. Gras, and K. Razavi, "CellIFT: Leveraging Cells for
19// Scalable and Precise Dynamic Information Flow Tracking in RTL,"
20// USENIX Security 2022.
21//
22//===----------------------------------------------------------------------===//
23
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/BuiltinOps.h"
32#include "mlir/IR/ImplicitLocOpBuilder.h"
33#include "mlir/IR/Threading.h"
34#include "mlir/Pass/Pass.h"
35#include "llvm/ADT/DenseMap.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/StringSet.h"
38#include "llvm/ADT/TypeSwitch.h"
39#include "llvm/Support/Debug.h"
40
41#include <type_traits>
42
43#define DEBUG_TYPE "cellift-instrument"
44
45namespace circt {
46#define GEN_PASS_DEF_CELLIFTINSTRUMENT
47#include "circt/Transforms/Passes.h.inc"
48} // namespace circt
49
50using namespace mlir;
51using namespace circt;
52using namespace circt::hw;
53using namespace circt::comb;
54using namespace circt::seq;
55
56//===----------------------------------------------------------------------===//
57// CellIFT Pass
58//===----------------------------------------------------------------------===//
59
60namespace {
61
62using TaintMap = DenseMap<Value, Value>;
63using PendingTaintMap = DenseMap<Value, Backedge>;
64
65struct ModulePortTaintInfo {
66 SmallVector<unsigned> taintInputSources;
67 SmallVector<unsigned> taintOutputSources;
68};
69
70struct ModuleInstrumentationInfo {
71 HWModuleOp mod;
72 ModulePortTaintInfo portInfo;
73};
74
75Value getZero(ImplicitLocOpBuilder &b, Type ty);
76Value getAllOnes(ImplicitLocOpBuilder &b, Type ty);
77Value orReduce(ImplicitLocOpBuilder &b, Value v);
78Value broadcast(ImplicitLocOpBuilder &b, Value bit, Type ty);
79Value getTaint(TaintMap &taintOf, PendingTaintMap &pendingTaints,
80 BackedgeBuilder &backedgeBuilder, Value v);
81void setTaint(TaintMap &taintOf, PendingTaintMap &pendingTaints, Value v,
82 Value taint);
83Value conservativeTaint(ImplicitLocOpBuilder &b, ValueRange taintInputs,
84 Type resTy);
85Value conservativeTaint(ImplicitLocOpBuilder &b, TaintMap &taintOf,
86 PendingTaintMap &pendingTaints,
87 BackedgeBuilder &backedgeBuilder, ValueRange inputs,
88 Type resTy);
89
90template <typename ShiftOp>
91using EnableIfCombShiftOp =
92 std::enable_if_t<std::is_same_v<ShiftOp, comb::ShlOp> ||
93 std::is_same_v<ShiftOp, comb::ShrUOp> ||
94 std::is_same_v<ShiftOp, comb::ShrSOp>>;
95
96Value instrumentOperation(ImplicitLocOpBuilder &b, hw::ConstantOp op,
97 bool taintConstants);
98Value instrumentOperation(ImplicitLocOpBuilder &b, comb::AndOp op,
99 comb::AndOp::Adaptor taintAdaptor);
100Value instrumentOperation(ImplicitLocOpBuilder &b, comb::OrOp op,
101 comb::OrOp::Adaptor taintAdaptor);
102Value instrumentOperation(ImplicitLocOpBuilder &b, comb::XorOp op,
103 comb::XorOp::Adaptor taintAdaptor);
104Value instrumentOperation(ImplicitLocOpBuilder &b, comb::AddOp op,
105 comb::AddOp::Adaptor taintAdaptor);
106Value instrumentOperation(ImplicitLocOpBuilder &b, comb::SubOp op,
107 comb::SubOp::Adaptor taintAdaptor);
108Value instrumentOperation(ImplicitLocOpBuilder &b, comb::MulOp op,
109 comb::MulOp::Adaptor taintAdaptor);
110Value instrumentOperation(ImplicitLocOpBuilder &b, comb::DivUOp op,
111 comb::DivUOp::Adaptor taintAdaptor);
112Value instrumentOperation(ImplicitLocOpBuilder &b, comb::DivSOp op,
113 comb::DivSOp::Adaptor taintAdaptor);
114Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ModUOp op,
115 comb::ModUOp::Adaptor taintAdaptor);
116Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ModSOp op,
117 comb::ModSOp::Adaptor taintAdaptor);
118Value instrumentOperation(ImplicitLocOpBuilder &b, comb::MuxOp op,
119 comb::MuxOp::Adaptor taintAdaptor);
120Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ConcatOp op,
121 comb::ConcatOp::Adaptor taintAdaptor);
122Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ExtractOp op,
123 comb::ExtractOp::Adaptor taintAdaptor);
124Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ReplicateOp op,
125 comb::ReplicateOp::Adaptor taintAdaptor);
126Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ICmpOp op,
127 comb::ICmpOp::Adaptor taintAdaptor);
128template <typename ShiftOp, typename = EnableIfCombShiftOp<ShiftOp>>
129Value instrumentOperation(ImplicitLocOpBuilder &b, ShiftOp op,
130 typename ShiftOp::Adaptor taintAdaptor);
131Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ParityOp op,
132 comb::ParityOp::Adaptor taintAdaptor);
133Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ReverseOp op,
134 comb::ReverseOp::Adaptor taintAdaptor);
135
136template <typename OpT>
137void instrumentSingleResultOperation(TaintMap &taintOf,
138 PendingTaintMap &pendingTaints,
139 BackedgeBuilder &backedgeBuilder,
140 ImplicitLocOpBuilder &b, OpT op) {
141 SmallVector<Value> taintOperands;
142 taintOperands.reserve(op->getNumOperands());
143 for (auto operand : op->getOperands())
144 taintOperands.push_back(
145 getTaint(taintOf, pendingTaints, backedgeBuilder, operand));
146 typename OpT::Adaptor taintAdaptor(taintOperands, op->getAttrDictionary());
147 setTaint(taintOf, pendingTaints, op.getResult(),
148 instrumentOperation(b, op, taintAdaptor));
149}
150
151void instrumentInstance(TaintMap &taintOf, PendingTaintMap &pendingTaints,
152 BackedgeBuilder &backedgeBuilder, StringRef taintSuffix,
153 hw::InstanceOp op, ImplicitLocOpBuilder &b);
154
155class CellIFTInstrumentPass
156 : public circt::impl::CellIFTInstrumentBase<CellIFTInstrumentPass> {
157public:
158 using CellIFTInstrumentBase::CellIFTInstrumentBase;
159 void runOnOperation() override;
160
161private:
162 // ---- Module-level passes ----------------------------------------------
163 LogicalResult rewriteModuleSignature(HWModuleOp mod, StringRef taintSuffix,
164 ModulePortTaintInfo &portInfo);
165 LogicalResult instrumentModuleBody(HWModuleOp mod,
166 const ModulePortTaintInfo &portInfo,
167 TaintMap &taintOf, bool taintConstants,
168 StringRef taintSuffix);
169 void rewriteOutputOp(HWModuleOp mod, const ModulePortTaintInfo &portInfo,
170 TaintMap &taintOf);
171};
172
173} // namespace
174
175//===----------------------------------------------------------------------===//
176// Taint Rules
177//===----------------------------------------------------------------------===//
178
179namespace {
180
181Value getZero(ImplicitLocOpBuilder &b, Type ty) {
182 return hw::ConstantOp::create(b, APInt(cast<IntegerType>(ty).getWidth(), 0));
183}
184
185Value getAllOnes(ImplicitLocOpBuilder &b, Type ty) {
187 b, APInt::getAllOnes(cast<IntegerType>(ty).getWidth()));
188}
189
190Value orReduce(ImplicitLocOpBuilder &b, Value v) {
191 auto width = cast<IntegerType>(v.getType()).getWidth();
192 if (width == 1)
193 return v;
194 return comb::ICmpOp::create(b, comb::ICmpPredicate::ne, v,
195 getZero(b, v.getType()));
196}
197
198Value broadcast(ImplicitLocOpBuilder &b, Value bit, Type ty) {
199 return b.createOrFold<comb::ReplicateOp>(ty, bit);
200}
201
202Value getTaint(TaintMap &taintOf, PendingTaintMap &pendingTaints,
203 BackedgeBuilder &backedgeBuilder, Value v) {
204 assert(isa<IntegerType>(v.getType()) && "can only taint integer values");
205 auto it = taintOf.find(v);
206 if (it != taintOf.end())
207 return it->second;
208
209 auto backedgeIt =
210 pendingTaints.try_emplace(v, backedgeBuilder.get(v.getType())).first;
211 Value placeholder = backedgeIt->second;
212 taintOf[v] = placeholder;
213 return placeholder;
214}
215
216void setTaint(TaintMap &taintOf, PendingTaintMap &pendingTaints, Value v,
217 Value taint) {
218 if (auto it = pendingTaints.find(v); it != pendingTaints.end()) {
219 it->second.setValue(taint);
220 pendingTaints.erase(it);
221 }
222 taintOf[v] = taint;
223}
224
225Value conservativeTaint(ImplicitLocOpBuilder &b, ValueRange taintInputs,
226 Type resTy) {
227 SmallVector<Value> bits;
228 bits.reserve(taintInputs.size());
229 for (auto taint : taintInputs)
230 bits.push_back(orReduce(b, taint));
231
232 if (bits.empty())
233 return getZero(b, resTy);
234
235 Value any = bits.size() == 1
236 ? bits.front()
237 : comb::OrOp::create(b, bits, /*twoState=*/false);
238 return broadcast(b, any, resTy);
239}
240
241Value conservativeTaint(ImplicitLocOpBuilder &b, TaintMap &taintOf,
242 PendingTaintMap &pendingTaints,
243 BackedgeBuilder &backedgeBuilder, ValueRange inputs,
244 Type resTy) {
245 SmallVector<Value> taintInputs;
246 taintInputs.reserve(inputs.size());
247 for (auto input : inputs)
248 if (isa<IntegerType>(input.getType()))
249 taintInputs.push_back(
250 getTaint(taintOf, pendingTaints, backedgeBuilder, input));
251 return conservativeTaint(b, taintInputs, resTy);
252}
253
254Value instrumentOperation(ImplicitLocOpBuilder &b, hw::ConstantOp op,
255 bool taintConstants) {
256 return taintConstants ? getAllOnes(b, op.getType())
257 : getZero(b, op.getType());
258}
259
260// AND: y_t = (a & b_t) | (b & a_t) | (a_t & b_t)
261Value instrumentOperation(ImplicitLocOpBuilder &b, comb::AndOp op,
262 comb::AndOp::Adaptor taintAdaptor) {
263 auto inputs = op.getInputs();
264 auto taintInputs = taintAdaptor.getInputs();
265 Value value = inputs.front();
266 Value taint = taintInputs.front();
267 for (auto [input, inputTaint] :
268 llvm::zip_equal(inputs.drop_front(), taintInputs.drop_front())) {
269 Value t1 = comb::AndOp::create(b, value, inputTaint);
270 Value t2 = comb::AndOp::create(b, input, taint);
271 Value t3 = comb::AndOp::create(b, taint, inputTaint);
272 taint = comb::OrOp::create(b, ValueRange{t1, t2, t3}, /*twoState=*/false);
273 value = comb::AndOp::create(b, value, input);
274 }
275 return taint;
276}
277
278// OR: y_t = (~a & b_t) | (~b & a_t) | (a_t & b_t)
279Value instrumentOperation(ImplicitLocOpBuilder &b, comb::OrOp op,
280 comb::OrOp::Adaptor taintAdaptor) {
281 auto inputs = op.getInputs();
282 auto taintInputs = taintAdaptor.getInputs();
283 Value value = inputs.front();
284 Value taint = taintInputs.front();
285 auto ones = getAllOnes(b, value.getType());
286 for (auto [input, inputTaint] :
287 llvm::zip_equal(inputs.drop_front(), taintInputs.drop_front())) {
288 Value notValue = comb::XorOp::create(b, value, ones);
289 Value notInput = comb::XorOp::create(b, input, ones);
290 Value t1 = comb::AndOp::create(b, notValue, inputTaint);
291 Value t2 = comb::AndOp::create(b, notInput, taint);
292 Value t3 = comb::AndOp::create(b, taint, inputTaint);
293 taint = comb::OrOp::create(b, ValueRange{t1, t2, t3}, /*twoState=*/false);
294 value = comb::OrOp::create(b, ValueRange{value, input},
295 /*twoState=*/false);
296 }
297 return taint;
298}
299
300// XOR: y_t = OR of all input taints.
301Value instrumentOperation(ImplicitLocOpBuilder &b, comb::XorOp,
302 comb::XorOp::Adaptor taintAdaptor) {
303 auto taintInputs = taintAdaptor.getInputs();
304 return taintInputs.size() == 1
305 ? taintInputs.front()
306 : comb::OrOp::create(b, taintInputs, /*twoState=*/false);
307}
308
309// ADD (precise): y_t = ((a&~a_t)+(b&~b_t)) XOR ((a|a_t)+(b|b_t)) | a_t | b_t
310Value instrumentOperation(ImplicitLocOpBuilder &b, comb::AddOp op,
311 comb::AddOp::Adaptor taintAdaptor) {
312 auto inputs = op.getInputs();
313 auto taintInputs = taintAdaptor.getInputs();
314 Value value = inputs.front();
315 Value taint = taintInputs.front();
316 auto ones = getAllOnes(b, value.getType());
317
318 for (auto [input, inputTaint] :
319 llvm::zip_equal(inputs.drop_front(), taintInputs.drop_front())) {
320 Value notTaint = comb::XorOp::create(b, taint, ones);
321 Value notInputTaint = comb::XorOp::create(b, inputTaint, ones);
322 Value valueZero = comb::AndOp::create(b, value, notTaint);
323 Value inputZero = comb::AndOp::create(b, input, notInputTaint);
324 Value valueOne = comb::OrOp::create(b, value, taint);
325 Value inputOne = comb::OrOp::create(b, input, inputTaint);
326 Value sumMin = comb::AddOp::create(b, valueZero, inputZero);
327 Value sumMax = comb::AddOp::create(b, valueOne, inputOne);
328 Value xorResult = comb::XorOp::create(b, sumMin, sumMax);
329 taint = comb::OrOp::create(b, ValueRange{xorResult, taint, inputTaint},
330 /*twoState=*/false);
331 value = comb::AddOp::create(b, value, input);
332 }
333 return taint;
334}
335
336// SUB (precise): y_t = ((a|a_t)-(b&~b_t)) XOR ((a&~a_t)-(b|b_t)) | a_t | b_t
337Value instrumentOperation(ImplicitLocOpBuilder &b, comb::SubOp op,
338 comb::SubOp::Adaptor taintAdaptor) {
339 Value a = op.getLhs();
340 Value aT = taintAdaptor.getLhs();
341 Value bv = op.getRhs();
342 Value bT = taintAdaptor.getRhs();
343 auto ones = getAllOnes(b, a.getType());
344
345 Value notAT = comb::XorOp::create(b, aT, ones);
346 Value notBT = comb::XorOp::create(b, bT, ones);
347 Value aOne = comb::OrOp::create(b, a, aT);
348 Value bZero = comb::AndOp::create(b, bv, notBT);
349 Value aZero = comb::AndOp::create(b, a, notAT);
350 Value bOne = comb::OrOp::create(b, bv, bT);
351
352 Value sub1 = comb::SubOp::create(b, aOne, bZero);
353 Value sub2 = comb::SubOp::create(b, aZero, bOne);
354 Value xorResult = comb::XorOp::create(b, sub1, sub2);
355 return comb::OrOp::create(b, ValueRange{xorResult, aT, bT},
356 /*twoState=*/false);
357}
358
359Value instrumentOperation(ImplicitLocOpBuilder &b, comb::MulOp op,
360 comb::MulOp::Adaptor taintAdaptor) {
361 return conservativeTaint(b, taintAdaptor.getInputs(), op.getType());
362}
363
364// DIV (conservative): any tainted input taints the full result.
365Value instrumentOperation(ImplicitLocOpBuilder &b, comb::DivUOp op,
366 comb::DivUOp::Adaptor taintAdaptor) {
367 SmallVector<Value> taintInputs{taintAdaptor.getLhs(), taintAdaptor.getRhs()};
368 return conservativeTaint(b, taintInputs, op.getType());
369}
370
371Value instrumentOperation(ImplicitLocOpBuilder &b, comb::DivSOp op,
372 comb::DivSOp::Adaptor taintAdaptor) {
373 SmallVector<Value> taintInputs{taintAdaptor.getLhs(), taintAdaptor.getRhs()};
374 return conservativeTaint(b, taintInputs, op.getType());
375}
376
377// MOD (conservative): any tainted input taints the full result.
378Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ModUOp op,
379 comb::ModUOp::Adaptor taintAdaptor) {
380 SmallVector<Value> taintInputs{taintAdaptor.getLhs(), taintAdaptor.getRhs()};
381 return conservativeTaint(b, taintInputs, op.getType());
382}
383
384Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ModSOp op,
385 comb::ModSOp::Adaptor taintAdaptor) {
386 SmallVector<Value> taintInputs{taintAdaptor.getLhs(), taintAdaptor.getRhs()};
387 return conservativeTaint(b, taintInputs, op.getType());
388}
389
390// MUX: y_t = mux(sel, t_t, f_t) | replicate(sel_t) & (t^f | t_t | f_t)
391Value instrumentOperation(ImplicitLocOpBuilder &b, comb::MuxOp op,
392 comb::MuxOp::Adaptor taintAdaptor) {
393 Value dataTaint =
394 comb::MuxOp::create(b, op.getCond(), taintAdaptor.getTrueValue(),
395 taintAdaptor.getFalseValue());
396 Value selBroad = broadcast(b, taintAdaptor.getCond(), op.getType());
397 Value diff = comb::XorOp::create(b, op.getTrueValue(), op.getFalseValue());
398 Value inner = comb::OrOp::create(b,
399 ValueRange{diff, taintAdaptor.getTrueValue(),
400 taintAdaptor.getFalseValue()},
401 /*twoState=*/false);
402 Value ctrlTaint = comb::AndOp::create(b, selBroad, inner);
403 return comb::OrOp::create(b, dataTaint, ctrlTaint);
404}
405
406// CONCAT: y_t = concat(each input_t)
407Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ConcatOp,
408 comb::ConcatOp::Adaptor taintAdaptor) {
409 return comb::ConcatOp::create(b, taintAdaptor.getInputs());
410}
411
412// EXTRACT: y_t = extract(input_t, lowBit)
413Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ExtractOp op,
414 comb::ExtractOp::Adaptor taintAdaptor) {
415 return comb::ExtractOp::create(b, op.getResult().getType(),
416 taintAdaptor.getInput(), op.getLowBit());
417}
418
419// REPLICATE: y_t = replicate(input_t)
420Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ReplicateOp op,
421 comb::ReplicateOp::Adaptor taintAdaptor) {
422 return comb::ReplicateOp::create(b, op.getResult().getType(),
423 taintAdaptor.getInput());
424}
425
426// ICMP: precise rules per supported predicate with conservative fallback.
427Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ICmpOp op,
428 comb::ICmpOp::Adaptor taintAdaptor) {
429 Value a = op.getLhs();
430 Value aT = taintAdaptor.getLhs();
431 Value bv = op.getRhs();
432 Value bT = taintAdaptor.getRhs();
433 auto pred = op.getPredicate();
434 auto ty = a.getType();
435 unsigned width = cast<IntegerType>(ty).getWidth();
436 auto ones = getAllOnes(b, ty);
437
438 if (pred == ICmpPredicate::ceq || pred == ICmpPredicate::cne ||
439 pred == ICmpPredicate::weq || pred == ICmpPredicate::wne) {
440 op.emitWarning() << "falling back to conservative taint propagation for "
441 << stringifyICmpPredicate(pred) << " predicate";
442 SmallVector<Value, 2> taintInputs{aT, bT};
443 return conservativeTaint(b, taintInputs, op.getType());
444 }
445
446 if (pred == ICmpPredicate::eq || pred == ICmpPredicate::ne) {
447 Value combined = comb::OrOp::create(b, aT, bT);
448 Value hasTaint = orReduce(b, combined);
449 Value mask = comb::XorOp::create(b, combined, ones); // ~(a_t | b_t)
450 Value maskedA = comb::AndOp::create(b, a, mask);
451 Value maskedB = comb::AndOp::create(b, bv, mask);
452 Value eqUntainted =
453 comb::ICmpOp::create(b, ICmpPredicate::eq, maskedA, maskedB);
454 return comb::AndOp::create(b, hasTaint, eqUntainted);
455 }
456
457 bool isSigned = pred == ICmpPredicate::slt || pred == ICmpPredicate::sle ||
458 pred == ICmpPredicate::sgt || pred == ICmpPredicate::sge;
459
460 Value notAT = comb::XorOp::create(b, aT, ones);
461 Value notBT = comb::XorOp::create(b, bT, ones);
462
463 Value minA, maxA, minB, maxB;
464 if (width == 1 || !isSigned) {
465 minA = comb::AndOp::create(b, a, notAT);
466 maxA = comb::OrOp::create(b, a, aT);
467 minB = comb::AndOp::create(b, bv, notBT);
468 maxB = comb::OrOp::create(b, bv, bT);
469 } else {
470 auto lsbTy = IntegerType::get(b.getContext(), width - 1);
471 auto bitTy = IntegerType::get(b.getContext(), 1);
472
473 Value aLsbs = comb::ExtractOp::create(b, lsbTy, a, 0);
474 Value aMsb = comb::ExtractOp::create(b, bitTy, a, width - 1);
475 Value aTLsbs = comb::ExtractOp::create(b, lsbTy, aT, 0);
476 Value aTMsb = comb::ExtractOp::create(b, bitTy, aT, width - 1);
477
478 Value bLsbs = comb::ExtractOp::create(b, lsbTy, bv, 0);
479 Value bMsb = comb::ExtractOp::create(b, bitTy, bv, width - 1);
480 Value bTLsbs = comb::ExtractOp::create(b, lsbTy, bT, 0);
481 Value bTMsb = comb::ExtractOp::create(b, bitTy, bT, width - 1);
482
483 auto lsbOnes = getAllOnes(b, lsbTy);
484 Value notATLsbs = comb::XorOp::create(b, aTLsbs, lsbOnes);
485 Value notBTLsbs = comb::XorOp::create(b, bTLsbs, lsbOnes);
486
487 auto bitOnes = getAllOnes(b, bitTy);
488 Value notATMsb = comb::XorOp::create(b, aTMsb, bitOnes);
489 Value notBTMsb = comb::XorOp::create(b, bTMsb, bitOnes);
490
491 Value minALsbs = comb::AndOp::create(b, aLsbs, notATLsbs);
492 Value maxALsbs = comb::OrOp::create(b, aLsbs, aTLsbs);
493 Value minBLsbs = comb::AndOp::create(b, bLsbs, notBTLsbs);
494 Value maxBLsbs = comb::OrOp::create(b, bLsbs, bTLsbs);
495
496 Value minAMsb = comb::OrOp::create(b, aMsb, aTMsb);
497 Value maxAMsb = comb::AndOp::create(b, aMsb, notATMsb);
498 Value minBMsb = comb::OrOp::create(b, bMsb, bTMsb);
499 Value maxBMsb = comb::AndOp::create(b, bMsb, notBTMsb);
500
501 minA = comb::ConcatOp::create(b, minAMsb, minALsbs);
502 maxA = comb::ConcatOp::create(b, maxAMsb, maxALsbs);
503 minB = comb::ConcatOp::create(b, minBMsb, minBLsbs);
504 maxB = comb::ConcatOp::create(b, maxBMsb, maxBLsbs);
505 }
506
507 Value cmp1 = comb::ICmpOp::create(b, pred, minA, maxB);
508 Value cmp2 = comb::ICmpOp::create(b, pred, maxA, minB);
509 return comb::XorOp::create(b, cmp1, cmp2);
510}
511
512template <typename ShiftOp>
513Value shiftTaintImprecise(ImplicitLocOpBuilder &b, Value dataT, Value amt,
514 Value amtT, Type resTy) {
515 Value shiftAmtTaint = broadcast(b, orReduce(b, amtT), resTy);
516 Value shiftedDataT = ShiftOp::create(b, dataT, amt);
517 return comb::OrOp::create(b, shiftAmtTaint, shiftedDataT);
518}
519
520template <typename ShiftOp, typename>
521Value instrumentOperation(ImplicitLocOpBuilder &b, ShiftOp op,
522 typename ShiftOp::Adaptor taintAdaptor) {
523 return shiftTaintImprecise<ShiftOp>(b, taintAdaptor.getLhs(), op.getRhs(),
524 taintAdaptor.getRhs(), op.getType());
525}
526
527// PARITY: y_t = OR-reduce(input_t)
528Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ParityOp,
529 comb::ParityOp::Adaptor taintAdaptor) {
530 return orReduce(b, taintAdaptor.getInput());
531}
532
533// REVERSE: y_t = reverse(input_t)
534Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ReverseOp op,
535 comb::ReverseOp::Adaptor taintAdaptor) {
536 return comb::ReverseOp::create(b, b.getLoc(), op.getType(),
537 taintAdaptor.getInput());
538}
539
540//===----------------------------------------------------------------------===//
541// Instance Instrumentation
542//===----------------------------------------------------------------------===//
543
544void instrumentInstance(TaintMap &taintOf, PendingTaintMap &pendingTaints,
545 BackedgeBuilder &backedgeBuilder, StringRef taintSuffix,
546 hw::InstanceOp op, ImplicitLocOpBuilder &b) {
547 SmallVector<Value> newInputs;
548 SmallVector<Attribute> newArgNames, newResNames;
549 SmallVector<Type> newResTys;
550
551 for (auto [input, nameAttr] :
552 llvm::zip_equal(op.getInputs(), op.getInputNames())) {
553 auto name = cast<StringAttr>(nameAttr);
554 newInputs.push_back(input);
555 newArgNames.push_back(name);
556
557 if (!isa<IntegerType>(input.getType()))
558 continue;
559
560 newInputs.push_back(
561 getTaint(taintOf, pendingTaints, backedgeBuilder, input));
562 newArgNames.push_back(
563 b.getStringAttr(name.getValue().str() + taintSuffix.str()));
564 }
565
566 for (auto [result, nameAttr] :
567 llvm::zip_equal(op.getResults(), op.getOutputNames())) {
568 auto name = cast<StringAttr>(nameAttr);
569 newResTys.push_back(result.getType());
570 newResNames.push_back(name);
571
572 if (!isa<IntegerType>(result.getType()))
573 continue;
574
575 newResTys.push_back(result.getType());
576 newResNames.push_back(
577 b.getStringAttr(name.getValue().str() + taintSuffix.str()));
578 }
579
580 auto newInst = hw::InstanceOp::create(
581 b, op.getLoc(), newResTys, op.getInstanceNameAttr(),
582 op.getModuleNameAttr(), newInputs, b.getArrayAttr(newArgNames),
583 b.getArrayAttr(newResNames), op.getParametersAttr(), op.getInnerSymAttr(),
584 op.getDoNotPrintAttr());
585
586 unsigned newResIdx = 0;
587 for (auto result : op.getResults()) {
588 Value newResult = newInst.getResult(newResIdx++);
589 if (isa<IntegerType>(result.getType())) {
590 Value newTaint = newInst.getResult(newResIdx++);
591 setTaint(taintOf, pendingTaints, result, newTaint);
592 taintOf[newResult] = newTaint;
593 }
594 result.replaceAllUsesWith(newResult);
595 }
596
597 op.erase();
598}
599
600} // namespace
601
602//===----------------------------------------------------------------------===//
603// Module Signature Rewriting
604//===----------------------------------------------------------------------===//
605
606LogicalResult CellIFTInstrumentPass::rewriteModuleSignature(
607 HWModuleOp mod, StringRef taintSuffix, ModulePortTaintInfo &portInfo) {
608 auto *ctx = mod.getContext();
609 auto portList = mod.getPortList();
610 Block *body = mod.getBodyBlock();
611
612 SmallVector<std::pair<unsigned, PortInfo>> insIns, insOuts;
613 struct BlockArgInsertion {
614 unsigned index;
615 Type type;
616 Location loc;
617 };
618 SmallVector<BlockArgInsertion> blockArgInsertions;
619 llvm::StringSet<> portNames;
620 for (auto port : portList)
621 portNames.insert(port.getName());
622
623 portInfo.taintInputSources.clear();
624 portInfo.taintOutputSources.clear();
625
626 // Insert a taint port after every integer-typed input/output.
627 // Also insert block arguments for the new input ports.
628 unsigned inIdx = 0, outIdx = 0;
629 unsigned blockArgInsertionOffset = 0;
630 for (auto port : portList) {
631 if (port.isInput()) {
632 if (isa<IntegerType>(port.type)) {
633 std::string taintName = port.getName().str() + taintSuffix.str();
634 if (!portNames.insert(taintName).second) {
635 mod.emitOpError() << "cannot add taint port '" << taintName
636 << "' because the name already exists; pick a "
637 "different taint suffix";
638 return failure();
639 }
640
641 PortInfo tp;
642 tp.name = StringAttr::get(ctx, taintName);
643 tp.type = port.type;
644 tp.dir = ModulePort::Direction::Input;
645 tp.loc = port.loc ? port.loc : UnknownLoc::get(ctx);
646 insIns.push_back({inIdx + 1, tp});
647 blockArgInsertions.push_back({inIdx + 1 + blockArgInsertionOffset,
648 tp.type, cast<Location>(tp.loc)});
649 portInfo.taintInputSources.push_back(inIdx);
650 blockArgInsertionOffset++;
651 }
652 inIdx++;
653 } else if (port.isOutput()) {
654 if (isa<IntegerType>(port.type)) {
655 std::string taintName = port.getName().str() + taintSuffix.str();
656 if (!portNames.insert(taintName).second) {
657 mod.emitOpError() << "cannot add taint port '" << taintName
658 << "' because the name already exists; pick a "
659 "different taint suffix";
660 return failure();
661 }
662
663 PortInfo tp;
664 tp.name = StringAttr::get(ctx, taintName);
665 tp.type = port.type;
666 tp.dir = ModulePort::Direction::Output;
667 tp.loc = port.loc ? port.loc : UnknownLoc::get(ctx);
668 insOuts.push_back({outIdx + 1, tp});
669 portInfo.taintOutputSources.push_back(outIdx);
670 }
671 outIdx++;
672 }
673 }
674
675 for (auto &insertion : blockArgInsertions)
676 body->insertArgument(insertion.index, insertion.type, insertion.loc);
677
678 mod.modifyPorts(insIns, insOuts, {}, {});
679 return success();
680}
681
682//===----------------------------------------------------------------------===//
683// Body Instrumentation
684//===----------------------------------------------------------------------===//
685
686LogicalResult CellIFTInstrumentPass::instrumentModuleBody(
687 HWModuleOp mod, const ModulePortTaintInfo &portInfo, TaintMap &taintOf,
688 bool taintConstants, StringRef taintSuffix) {
689 taintOf.clear();
690 Block *body = mod.getBodyBlock();
691 PendingTaintMap pendingTaints;
692
693 // Map input block args -> their taint block args.
694 for (auto [taintOrdinal, sourceIdx] :
695 llvm::enumerate(portInfo.taintInputSources)) {
696 unsigned origArgIdx = sourceIdx + taintOrdinal;
697 unsigned taintArgIdx = origArgIdx + 1;
698 setTaint(taintOf, pendingTaints, body->getArgument(origArgIdx),
699 body->getArgument(taintArgIdx));
700 }
701
702 ImplicitLocOpBuilder b(mod.getLoc(), mod.getContext());
703 BackedgeBuilder backedgeBuilder(b, mod.getLoc());
704
705 for (Operation &op : llvm::make_early_inc_range(*body)) {
706 if (isa<hw::OutputOp>(op))
707 continue;
708
709 b.setLoc(op.getLoc());
710 b.setInsertionPointAfter(&op);
711
712 llvm::TypeSwitch<Operation *>(&op)
713 .Case<seq::CompRegOp>([&](auto compreg) {
714 StringAttr name;
715 if (auto n = compreg.getNameAttr())
716 name = b.getStringAttr(n.getValue().str() + taintSuffix.str());
717
718 Value inputT = getTaint(taintOf, pendingTaints, backedgeBuilder,
719 compreg.getInput());
720 Value tReg;
721 if (compreg.getReset()) {
722 Value rstVal = getZero(b, compreg.getType());
723 tReg = seq::CompRegOp::create(b, compreg.getLoc(), inputT,
724 compreg.getClk(), compreg.getReset(),
725 rstVal, name);
726 } else {
727 tReg = seq::CompRegOp::create(b, compreg.getLoc(), inputT,
728 compreg.getClk(), name);
729 }
730 setTaint(taintOf, pendingTaints, compreg.getData(), tReg);
731 })
732 .Case<seq::FirRegOp>([&](auto firreg) {
733 StringAttr name =
734 b.getStringAttr(firreg.getName().str() + taintSuffix.str());
735
736 Value nextT = getTaint(taintOf, pendingTaints, backedgeBuilder,
737 firreg.getNext());
738 Value tReg;
739 if (firreg.hasReset()) {
740 Value rstVal = getZero(b, firreg.getType());
741 tReg = seq::FirRegOp::create(
742 b, nextT, firreg.getClk(), name, firreg.getReset(), rstVal,
743 firreg.getInnerSymAttr(), firreg.getIsAsync());
744 } else {
745 tReg = seq::FirRegOp::create(b, nextT, firreg.getClk(), name);
746 }
747 setTaint(taintOf, pendingTaints, firreg.getData(), tReg);
748 })
749 .Case<hw::ConstantOp>([&](auto o) {
750 setTaint(taintOf, pendingTaints, o.getResult(),
751 instrumentOperation(b, o, taintConstants));
752 })
756 comb::ReplicateOp, comb::ICmpOp, comb::ShlOp, comb::ShrUOp,
757 comb::ShrSOp, comb::ParityOp, comb::ReverseOp>([&](auto o) {
758 instrumentSingleResultOperation(taintOf, pendingTaints,
759 backedgeBuilder, b, o);
760 })
761 .Case<hw::InstanceOp>([&](auto o) {
762 instrumentInstance(taintOf, pendingTaints, backedgeBuilder,
763 taintSuffix, o, b);
764 })
765 .Default([&](Operation *o) {
766 // Conservative fallback for unknown ops with integer results.
767 for (auto res : o->getResults())
768 if (isa<IntegerType>(res.getType()))
769 setTaint(taintOf, pendingTaints, res,
770 conservativeTaint(b, taintOf, pendingTaints,
771 backedgeBuilder, o->getOperands(),
772 res.getType()));
773 });
774 }
775
776 return backedgeBuilder.clearOrEmitError();
777}
778
779void CellIFTInstrumentPass::rewriteOutputOp(HWModuleOp mod,
780 const ModulePortTaintInfo &portInfo,
781 TaintMap &taintOf) {
782 Block *body = mod.getBodyBlock();
783 auto outputOp = cast<hw::OutputOp>(body->getTerminator());
784
785 SmallVector<Value> newOuts;
786 ImplicitLocOpBuilder b(outputOp.getLoc(), outputOp);
787
788 unsigned nextTaintOutput = 0;
789 for (unsigned origOutputIdx = 0, e = outputOp.getNumOperands();
790 origOutputIdx < e; ++origOutputIdx) {
791 Value origOutput = outputOp.getOperand(origOutputIdx);
792 newOuts.push_back(origOutput);
793
794 if (nextTaintOutput < portInfo.taintOutputSources.size() &&
795 portInfo.taintOutputSources[nextTaintOutput] == origOutputIdx) {
796 newOuts.push_back(taintOf.at(origOutput));
797 ++nextTaintOutput;
798 }
799 }
800
801 hw::OutputOp::create(b, outputOp.getLoc(), newOuts);
802 outputOp.erase();
803}
804
805//===----------------------------------------------------------------------===//
806// Main
807//===----------------------------------------------------------------------===//
808
809void CellIFTInstrumentPass::runOnOperation() {
810 auto modules =
811 llvm::to_vector(getOperation().getBody()->getOps<HWModuleOp>());
812
813 std::string taintSuffix = this->taintSuffix;
814 bool taintConstants = this->taintConstants;
815
816 SmallVector<ModuleInstrumentationInfo> moduleInfos;
817 moduleInfos.reserve(modules.size());
818
819 for (auto mod : modules) {
820 ModuleInstrumentationInfo info{mod, {}};
821 if (failed(rewriteModuleSignature(mod, taintSuffix, info.portInfo))) {
822 signalPassFailure();
823 return;
824 }
825 moduleInfos.push_back(std::move(info));
826 }
827
828 if (failed(failableParallelForEach(
829 &getContext(), moduleInfos,
830 [&](ModuleInstrumentationInfo &info) -> LogicalResult {
831 TaintMap taintOf;
832 if (failed(instrumentModuleBody(info.mod, info.portInfo, taintOf,
833 taintConstants, taintSuffix)))
834 return failure();
835 rewriteOutputOp(info.mod, info.portInfo, taintOf);
836 return success();
837 })))
838 signalPassFailure();
839}
assert(baseType &&"element must be base type")
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
mlir::LogicalResult clearOrEmitError()
Clear the backedges, erasing any remaining cursor ops.
create(low_bit, result_type, input=None)
Definition comb.py:187
create(data_type, value)
Definition hw.py:433
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
Definition seq.py:157
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
mlir::Type type
Definition HWTypes.h:31
mlir::StringAttr name
Definition HWTypes.h:30
This holds the name, type, direction of a module's ports.