30#include "mlir/IR/Builders.h"
31#include "mlir/IR/BuiltinOps.h"
32#include "mlir/IR/ImplicitLocOpBuilder.h"
33#include "mlir/IR/Threading.h"
34#include "mlir/Pass/Pass.h"
35#include "llvm/ADT/DenseMap.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/StringSet.h"
38#include "llvm/ADT/TypeSwitch.h"
39#include "llvm/Support/Debug.h"
43#define DEBUG_TYPE "cellift-instrument"
46#define GEN_PASS_DEF_CELLIFTINSTRUMENT
47#include "circt/Transforms/Passes.h.inc"
// Maps an original SSA value to the value that carries its taint.
62using TaintMap = DenseMap<Value, Value>;
// Maps a value to the Backedge used as a placeholder for its taint; getTaint
// inserts entries and setTaint resolves and erases them.
63using PendingTaintMap = DenseMap<Value, Backedge>;
// Records which original ports of a module received a companion taint port.
65struct ModulePortTaintInfo {
// Original input indices of the integer-typed inputs that got a taint input
// (filled by rewriteModuleSignature).
66 SmallVector<unsigned> taintInputSources;
// Original output indices of the integer-typed outputs that got a taint
// output (filled by rewriteModuleSignature).
67 SmallVector<unsigned> taintOutputSources;
// Per-module work item. runOnOperation constructs it as `info{mod, {}}` and
// later reads `info.mod` and `info.portInfo`; the `mod` member declaration is
// not visible in this excerpt.
70struct ModuleInstrumentationInfo {
// Port bookkeeping produced by rewriteModuleSignature for this module.
72 ModulePortTaintInfo portInfo;
// Forward declarations of the helpers shared by the per-op taint rules.
// NOTE(review): several of the declarations below are cut off in this excerpt
// (their trailing parameter lines are missing).

// Constant helpers; getAllOnes builds an all-ones APInt of ty's width.
75Value getZero(ImplicitLocOpBuilder &b, Type ty);
76Value getAllOnes(ImplicitLocOpBuilder &b, Type ty);
// Tests v against zero with an `icmp ne`.
77Value orReduce(ImplicitLocOpBuilder &b, Value v);
// Replicates `bit` to type `ty` via comb.replicate (createOrFold).
78Value broadcast(ImplicitLocOpBuilder &b, Value bit, Type ty);
// Taint lookup / registration on the two maps.
79Value getTaint(TaintMap &taintOf, PendingTaintMap &pendingTaints,
81void setTaint(TaintMap &taintOf, PendingTaintMap &pendingTaints, Value v,
// Coarse fallback: result is tainted whenever any input taint bit is set.
83Value conservativeTaint(ImplicitLocOpBuilder &b, ValueRange taintInputs,
85Value conservativeTaint(ImplicitLocOpBuilder &b, TaintMap &taintOf,
86 PendingTaintMap &pendingTaints,
90template <
typename ShiftOp>
91using EnableIfCombShiftOp =
92 std::enable_if_t<std::is_same_v<ShiftOp, comb::ShlOp> ||
93 std::is_same_v<ShiftOp, comb::ShrUOp> ||
94 std::is_same_v<ShiftOp, comb::ShrSOp>>;
// Forward declarations of the per-operation taint rules. Each overload takes
// the original op plus an adaptor whose operands are the taints of the op's
// operands, and returns the taint of the op's result.
// NOTE(review): the hw::ConstantOp declaration and the declaration ending in
// comb::ExtractOp::Adaptor are cut off in this excerpt.
96Value instrumentOperation(ImplicitLocOpBuilder &b,
hw::ConstantOp op,
98Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::AndOp op,
99 comb::AndOp::Adaptor taintAdaptor);
100Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::OrOp op,
101 comb::OrOp::Adaptor taintAdaptor);
102Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::XorOp op,
103 comb::XorOp::Adaptor taintAdaptor);
104Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::AddOp op,
105 comb::AddOp::Adaptor taintAdaptor);
106Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::SubOp op,
107 comb::SubOp::Adaptor taintAdaptor);
108Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::MulOp op,
109 comb::MulOp::Adaptor taintAdaptor);
110Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::DivUOp op,
111 comb::DivUOp::Adaptor taintAdaptor);
112Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::DivSOp op,
113 comb::DivSOp::Adaptor taintAdaptor);
114Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::ModUOp op,
115 comb::ModUOp::Adaptor taintAdaptor);
116Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::ModSOp op,
117 comb::ModSOp::Adaptor taintAdaptor);
118Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::MuxOp op,
119 comb::MuxOp::Adaptor taintAdaptor);
120Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::ConcatOp op,
121 comb::ConcatOp::Adaptor taintAdaptor);
123 comb::ExtractOp::Adaptor taintAdaptor);
124Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ReplicateOp op,
125 comb::ReplicateOp::Adaptor taintAdaptor);
126Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ICmpOp op,
127 comb::ICmpOp::Adaptor taintAdaptor);
// One template covers the three comb shift ops; EnableIfCombShiftOp keeps it
// out of overload resolution for every other op type.
128template <
typename ShiftOp,
typename = EnableIfCombShiftOp<ShiftOp>>
129Value instrumentOperation(ImplicitLocOpBuilder &b, ShiftOp op,
130 typename ShiftOp::Adaptor taintAdaptor);
131Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::ParityOp op,
132 comb::ParityOp::Adaptor taintAdaptor);
133Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ReverseOp op,
134 comb::ReverseOp::Adaptor taintAdaptor);
// Generic driver for single-result ops: collects the taint of every operand,
// wraps those taints in OpT's adaptor, and records the value returned by the
// matching instrumentOperation overload as the taint of op's result.
// NOTE(review): the `backedgeBuilder` parameter used below is declared on a
// line that is missing from this excerpt.
136template <
typename OpT>
137void instrumentSingleResultOperation(TaintMap &taintOf,
138 PendingTaintMap &pendingTaints,
140 ImplicitLocOpBuilder &b, OpT op) {
141 SmallVector<Value> taintOperands;
142 taintOperands.reserve(op->getNumOperands());
// One taint value per operand, in operand order.
143 for (
auto operand : op->getOperands())
144 taintOperands.push_back(
145 getTaint(taintOf, pendingTaints, backedgeBuilder, operand));
// The adaptor lets the rule read taints through the op's named accessors.
146 typename OpT::Adaptor taintAdaptor(taintOperands, op->getAttrDictionary());
147 setTaint(taintOf, pendingTaints, op.getResult(),
148 instrumentOperation(b, op, taintAdaptor));
// Forward declaration of the hw.instance rewrite (defined later in the file).
// NOTE(review): one parameter line is missing from this excerpt.
151void instrumentInstance(TaintMap &taintOf, PendingTaintMap &pendingTaints,
153 hw::InstanceOp op, ImplicitLocOpBuilder &b);
// Pass class built on the generated CellIFTInstrumentBase. The work is split
// into three per-module steps declared below and driven by runOnOperation.
// NOTE(review): the class is cut off in this excerpt (access specifiers, the
// tail of rewriteOutputOp's declaration and the closing brace are missing).
155class CellIFTInstrumentPass
156 :
public circt::impl::CellIFTInstrumentBase<CellIFTInstrumentPass> {
// Inherit the base-class constructors.
158 using CellIFTInstrumentBase::CellIFTInstrumentBase;
159 void runOnOperation()
override;
// Adds taint ports next to integer-typed ports and fills `portInfo`.
163 LogicalResult rewriteModuleSignature(HWModuleOp mod, StringRef taintSuffix,
164 ModulePortTaintInfo &portInfo);
// Walks the module body and creates taint logic for each operation.
165 LogicalResult instrumentModuleBody(HWModuleOp mod,
166 const ModulePortTaintInfo &portInfo,
167 TaintMap &taintOf,
bool taintConstants,
168 StringRef taintSuffix);
// Rebuilds the hw.output terminator so it also returns the taint outputs.
169 void rewriteOutputOp(HWModuleOp mod,
const ModulePortTaintInfo &portInfo,
// NOTE(review): body not visible in this excerpt. Presumably builds an
// all-zero constant of type `ty`; callers use the result as a "no taint" value
// and as a register reset value. Confirm against the full file.
181Value getZero(ImplicitLocOpBuilder &b, Type ty) {
// Builds a value from an all-ones APInt whose width is that of `ty`, which
// must be an IntegerType (enforced by the cast).
// NOTE(review): the line naming the op being created is missing here.
185Value getAllOnes(ImplicitLocOpBuilder &b, Type ty) {
187 b, APInt::getAllOnes(cast<IntegerType>(ty).
getWidth()));
// Reduces `v` to a single bit that is set when any bit of `v` is set, by
// comparing `v` against zero with `icmp ne`.
// NOTE(review): `width` is read on lines missing from this excerpt
// (presumably a shortcut for 1-bit inputs).
190Value orReduce(ImplicitLocOpBuilder &b, Value v) {
191 auto width = cast<IntegerType>(v.getType()).getWidth();
194 return comb::ICmpOp::create(b, comb::ICmpPredicate::ne, v,
195 getZero(b, v.getType()));
198Value broadcast(ImplicitLocOpBuilder &b, Value bit, Type ty) {
199 return b.createOrFold<comb::ReplicateOp>(ty, bit);
// Looks up the taint recorded for `v`. When none exists yet, a backedge of
// v's type is created as a placeholder, remembered in `pendingTaints`, and
// registered in `taintOf` so later lookups see the same placeholder.
// NOTE(review): the tail of the signature and both return statements are
// missing from this excerpt.
202Value getTaint(TaintMap &taintOf, PendingTaintMap &pendingTaints,
// Only integer-typed values carry taint.
204 assert(isa<IntegerType>(v.getType()) &&
"can only taint integer values");
205 auto it = taintOf.find(v);
206 if (it != taintOf.end())
// try_emplace keeps an existing backedge for `v` instead of replacing it.
210 pendingTaints.try_emplace(v, backedgeBuilder.
get(v.getType())).first;
211 Value placeholder = backedgeIt->second;
212 taintOf[v] = placeholder;
// Records `taint` as the taint of `v`. If a placeholder backedge was handed
// out earlier for `v`, it is resolved to `taint` and dropped from the pending
// map.
// NOTE(review): the rest of the body is missing from this excerpt.
216void setTaint(TaintMap &taintOf, PendingTaintMap &pendingTaints, Value v,
218 if (
auto it = pendingTaints.find(v); it != pendingTaints.end()) {
219 it->second.setValue(taint);
220 pendingTaints.erase(it);
// Coarse taint rule: or-reduces every input taint to one bit, ORs those bits
// together and broadcasts the result to `resTy`, so the whole result is
// tainted as soon as any input bit is tainted.
// NOTE(review): the `resTy` parameter line and the conditions guarding the
// zero-input and single-input cases are missing from this excerpt.
225Value conservativeTaint(ImplicitLocOpBuilder &b, ValueRange taintInputs,
227 SmallVector<Value> bits;
228 bits.reserve(taintInputs.size());
229 for (
auto taint : taintInputs)
230 bits.push_back(orReduce(
b, taint));
// Untainted (all-zero) result.
233 return getZero(b, resTy);
// With a single bit no OR is needed (per the `bits.size() == 1` test).
235 Value any = bits.size() == 1
237 : comb::OrOp::create(b, bits,
false);
238 return broadcast(b, any, resTy);
// Convenience overload: looks up the taint of each integer-typed value in
// `inputs` (non-integer values are skipped) and defers to the overload that
// takes taint values directly.
// NOTE(review): part of the parameter list is missing from this excerpt.
241Value conservativeTaint(ImplicitLocOpBuilder &b, TaintMap &taintOf,
242 PendingTaintMap &pendingTaints,
245 SmallVector<Value> taintInputs;
246 taintInputs.reserve(inputs.size());
247 for (
auto input : inputs)
248 if (isa<IntegerType>(input.getType()))
249 taintInputs.push_back(
250 getTaint(taintOf, pendingTaints, backedgeBuilder, input));
251 return conservativeTaint(b, taintInputs, resTy);
254Value instrumentOperation(ImplicitLocOpBuilder &b,
hw::ConstantOp op,
255 bool taintConstants) {
256 return taintConstants ? getAllOnes(b, op.getType())
257 : getZero(
b, op.getType());
// Taint rule for comb.and, folded pairwise over the op's inputs. `value` and
// `taint` hold the AND of the inputs consumed so far and its taint.
// NOTE(review): the loop's closing brace and the return statement are missing
// from this excerpt.
261Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::AndOp op,
262 comb::AndOp::Adaptor taintAdaptor) {
263 auto inputs = op.getInputs();
264 auto taintInputs = taintAdaptor.getInputs();
265 Value value = inputs.front();
266 Value taint = taintInputs.front();
267 for (
auto [input, inputTaint] :
268 llvm::zip_equal(inputs.drop_front(), taintInputs.drop_front())) {
// taint' = (value & inputTaint) | (input & taint) | (taint & inputTaint)
269 Value t1 = comb::AndOp::create(b, value, inputTaint);
270 Value t2 = comb::AndOp::create(b, input, taint);
271 Value t3 = comb::AndOp::create(b, taint, inputTaint);
272 taint = comb::OrOp::create(b, ValueRange{t1, t2, t3},
false);
// Keep the running data value in step with the running taint.
273 value = comb::AndOp::create(b, value, input);
// Taint rule for comb.or, folded pairwise over the op's inputs. Mirrors the
// AND rule but uses the complemented data values.
// NOTE(review): the end of the function is missing from this excerpt.
279Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::OrOp op,
280 comb::OrOp::Adaptor taintAdaptor) {
281 auto inputs = op.getInputs();
282 auto taintInputs = taintAdaptor.getInputs();
283 Value value = inputs.front();
284 Value taint = taintInputs.front();
// XOR with all-ones is used below as bitwise NOT.
285 auto ones = getAllOnes(b, value.getType());
286 for (
auto [input, inputTaint] :
287 llvm::zip_equal(inputs.drop_front(), taintInputs.drop_front())) {
288 Value notValue = comb::XorOp::create(b, value, ones);
289 Value notInput = comb::XorOp::create(b, input, ones);
// taint' = (~value & inputTaint) | (~input & taint) | (taint & inputTaint)
290 Value t1 = comb::AndOp::create(b, notValue, inputTaint);
291 Value t2 = comb::AndOp::create(b, notInput, taint);
292 Value t3 = comb::AndOp::create(b, taint, inputTaint);
293 taint = comb::OrOp::create(b, ValueRange{t1, t2, t3},
false);
294 value = comb::OrOp::create(b, ValueRange{value, input},
301Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::XorOp,
302 comb::XorOp::Adaptor taintAdaptor) {
303 auto taintInputs = taintAdaptor.getInputs();
304 return taintInputs.size() == 1
305 ? taintInputs.front()
306 : comb::OrOp::create(b, taintInputs,
false);
// Taint rule for comb.add, folded pairwise over the op's inputs. For each
// step the sum is computed twice - once with all tainted bits forced to 0 and
// once with them forced to 1 - and any bit where the two sums differ is
// tainted, in addition to the operand taints themselves.
// NOTE(review): the end of the function is missing from this excerpt.
310Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::AddOp op,
311 comb::AddOp::Adaptor taintAdaptor) {
312 auto inputs = op.getInputs();
313 auto taintInputs = taintAdaptor.getInputs();
314 Value value = inputs.front();
315 Value taint = taintInputs.front();
316 auto ones = getAllOnes(b, value.getType());
318 for (
auto [input, inputTaint] :
319 llvm::zip_equal(inputs.drop_front(), taintInputs.drop_front())) {
320 Value notTaint = comb::XorOp::create(b, taint, ones);
321 Value notInputTaint = comb::XorOp::create(b, inputTaint, ones);
// Operands with tainted bits cleared ...
322 Value valueZero = comb::AndOp::create(b, value, notTaint);
323 Value inputZero = comb::AndOp::create(b, input, notInputTaint);
// ... and with tainted bits set.
324 Value valueOne = comb::OrOp::create(b, value, taint);
325 Value inputOne = comb::OrOp::create(b, input, inputTaint);
326 Value sumMin = comb::AddOp::create(b, valueZero, inputZero);
327 Value sumMax = comb::AddOp::create(b, valueOne, inputOne);
328 Value xorResult = comb::XorOp::create(b, sumMin, sumMax);
329 taint = comb::OrOp::create(b, ValueRange{xorResult, taint, inputTaint},
331 value = comb::AddOp::create(b, value, input);
// Taint rule for comb.sub. Two differences are formed from the operand values
// with tainted bits forced to opposite extremes; bits where they disagree are
// tainted, in addition to the operand taints themselves.
// NOTE(review): the final argument of the last call is missing from this
// excerpt.
337Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::SubOp op,
338 comb::SubOp::Adaptor taintAdaptor) {
339 Value
a = op.getLhs();
340 Value aT = taintAdaptor.getLhs();
341 Value bv = op.getRhs();
342 Value bT = taintAdaptor.getRhs();
343 auto ones = getAllOnes(b,
a.getType());
345 Value notAT = comb::XorOp::create(b, aT, ones);
346 Value notBT = comb::XorOp::create(b, bT, ones);
// *One = tainted bits set, *Zero = tainted bits cleared.
347 Value aOne = comb::OrOp::create(b, a, aT);
348 Value bZero = comb::AndOp::create(b, bv, notBT);
349 Value aZero = comb::AndOp::create(b, a, notAT);
350 Value bOne = comb::OrOp::create(b, bv, bT);
// sub1 = (a | aT) - (b & ~bT), sub2 = (a & ~aT) - (b | bT)
352 Value sub1 = comb::SubOp::create(b, aOne, bZero);
353 Value sub2 = comb::SubOp::create(b, aZero, bOne);
354 Value xorResult = comb::XorOp::create(b, sub1, sub2);
355 return comb::OrOp::create(b, ValueRange{xorResult, aT, bT},
359Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::MulOp op,
360 comb::MulOp::Adaptor taintAdaptor) {
361 return conservativeTaint(b, taintAdaptor.getInputs(), op.getType());
365Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::DivUOp op,
366 comb::DivUOp::Adaptor taintAdaptor) {
367 SmallVector<Value> taintInputs{taintAdaptor.getLhs(), taintAdaptor.getRhs()};
368 return conservativeTaint(b, taintInputs, op.getType());
371Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::DivSOp op,
372 comb::DivSOp::Adaptor taintAdaptor) {
373 SmallVector<Value> taintInputs{taintAdaptor.getLhs(), taintAdaptor.getRhs()};
374 return conservativeTaint(b, taintInputs, op.getType());
378Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::ModUOp op,
379 comb::ModUOp::Adaptor taintAdaptor) {
380 SmallVector<Value> taintInputs{taintAdaptor.getLhs(), taintAdaptor.getRhs()};
381 return conservativeTaint(b, taintInputs, op.getType());
384Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::ModSOp op,
385 comb::ModSOp::Adaptor taintAdaptor) {
386 SmallVector<Value> taintInputs{taintAdaptor.getLhs(), taintAdaptor.getRhs()};
387 return conservativeTaint(b, taintInputs, op.getType());
// Taint rule for comb.mux. The result combines a data component (a mux over
// the true/false taints, steered by the original condition) with a control
// component that is active only when the condition itself is tainted.
// NOTE(review): the declaration of `dataTaint` and the last argument of the
// `inner` OR are on lines missing from this excerpt.
391Value instrumentOperation(ImplicitLocOpBuilder &b,
comb::MuxOp op,
392 comb::MuxOp::Adaptor taintAdaptor) {
394 comb::MuxOp::create(b, op.getCond(), taintAdaptor.getTrueValue(),
395 taintAdaptor.getFalseValue());
// Condition taint widened to the result type.
396 Value selBroad = broadcast(b, taintAdaptor.getCond(), op.getType());
// Bits where the two data inputs differ.
397 Value diff = comb::XorOp::create(b, op.getTrueValue(), op.getFalseValue());
398 Value inner = comb::OrOp::create(b,
399 ValueRange{diff, taintAdaptor.getTrueValue(),
400 taintAdaptor.getFalseValue()},
402 Value ctrlTaint = comb::AndOp::create(b, selBroad, inner);
403 return comb::OrOp::create(b, dataTaint, ctrlTaint);
// Taint rule for comb.concat: the result taint is the concatenation of the
// input taints.
// NOTE(review): the first line of the signature is missing from this excerpt.
408 comb::ConcatOp::Adaptor taintAdaptor) {
409 return comb::ConcatOp::create(b, taintAdaptor.getInputs());
// Taint rule for comb.extract: applied to the input's taint at the original
// op's low bit.
// NOTE(review): the signature's first line and the start of the return
// statement are missing from this excerpt.
414 comb::ExtractOp::Adaptor taintAdaptor) {
416 taintAdaptor.getInput(), op.getLowBit());
420Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ReplicateOp op,
421 comb::ReplicateOp::Adaptor taintAdaptor) {
422 return comb::ReplicateOp::create(b, op.getResult().getType(),
423 taintAdaptor.getInput());
// Taint rule for comb.icmp. Three cases, chosen by predicate:
//  - ceq/cne/weq/wne: warn and use the conservative fallback;
//  - eq/ne: tainted iff some operand bit is tainted and the untainted bits of
//    both operands compare equal;
//  - ordering predicates: evaluate the comparison at the two extreme operand
//    assignments allowed by the taints and flag the result when they differ.
// NOTE(review): several lines are missing from this excerpt, including the
// `else` that starts the signed branch, the extraction of the MSB/LSB slices
// and the declaration of `eqUntainted`.
427Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ICmpOp op,
428 comb::ICmpOp::Adaptor taintAdaptor) {
429 Value
a = op.getLhs();
430 Value aT = taintAdaptor.getLhs();
431 Value bv = op.getRhs();
432 Value bT = taintAdaptor.getRhs();
433 auto pred = op.getPredicate();
434 auto ty =
a.getType();
435 unsigned width = cast<IntegerType>(ty).getWidth();
436 auto ones = getAllOnes(b, ty);
// Case/wildcard equality has no precise rule here.
438 if (pred == ICmpPredicate::ceq || pred == ICmpPredicate::cne ||
439 pred == ICmpPredicate::weq || pred == ICmpPredicate::wne) {
440 op.emitWarning() <<
"falling back to conservative taint propagation for "
441 << stringifyICmpPredicate(pred) <<
" predicate";
442 SmallVector<Value, 2> taintInputs{aT, bT};
443 return conservativeTaint(b, taintInputs, op.getType());
// Equality: mask out every bit tainted on either side, then compare.
446 if (pred == ICmpPredicate::eq || pred == ICmpPredicate::ne) {
447 Value combined = comb::OrOp::create(b, aT, bT);
448 Value hasTaint = orReduce(b, combined);
449 Value
mask = comb::XorOp::create(b, combined, ones);
450 Value maskedA = comb::AndOp::create(b, a, mask);
451 Value maskedB = comb::AndOp::create(b, bv, mask);
453 comb::ICmpOp::create(b, ICmpPredicate::eq, maskedA, maskedB);
454 return comb::AndOp::create(b, hasTaint, eqUntainted);
457 bool isSigned = pred == ICmpPredicate::slt || pred == ICmpPredicate::sle ||
458 pred == ICmpPredicate::sgt || pred == ICmpPredicate::sge;
460 Value notAT = comb::XorOp::create(b, aT, ones);
461 Value notBT = comb::XorOp::create(b, bT, ones);
463 Value minA, maxA, minB, maxB;
// Unsigned (or 1-bit) operands: min clears tainted bits, max sets them.
464 if (width == 1 || !isSigned) {
465 minA = comb::AndOp::create(b, a, notAT);
466 maxA = comb::OrOp::create(b, a, aT);
467 minB = comb::AndOp::create(b, bv, notBT);
468 maxB = comb::OrOp::create(b, bv, bT);
// Signed operands: the top bit and the remaining width-1 bits are treated
// separately.
470 auto lsbTy = IntegerType::get(
b.getContext(), width - 1);
471 auto bitTy = IntegerType::get(
b.getContext(), 1);
483 auto lsbOnes = getAllOnes(b, lsbTy);
484 Value notATLsbs = comb::XorOp::create(b, aTLsbs, lsbOnes);
485 Value notBTLsbs = comb::XorOp::create(b, bTLsbs, lsbOnes);
487 auto bitOnes = getAllOnes(b, bitTy);
488 Value notATMsb = comb::XorOp::create(b, aTMsb, bitOnes);
489 Value notBTMsb = comb::XorOp::create(b, bTMsb, bitOnes);
// Low bits follow the unsigned rule.
491 Value minALsbs = comb::AndOp::create(b, aLsbs, notATLsbs);
492 Value maxALsbs = comb::OrOp::create(b, aLsbs, aTLsbs);
493 Value minBLsbs = comb::AndOp::create(b, bLsbs, notBTLsbs);
494 Value maxBLsbs = comb::OrOp::create(b, bLsbs, bTLsbs);
// The top bit is inverted relative to the low bits: a tainted top bit is set
// for the minimum and cleared for the maximum.
496 Value minAMsb = comb::OrOp::create(b, aMsb, aTMsb);
497 Value maxAMsb = comb::AndOp::create(b, aMsb, notATMsb);
498 Value minBMsb = comb::OrOp::create(b, bMsb, bTMsb);
499 Value maxBMsb = comb::AndOp::create(b, bMsb, notBTMsb);
501 minA = comb::ConcatOp::create(b, minAMsb, minALsbs);
502 maxA = comb::ConcatOp::create(b, maxAMsb, maxALsbs);
503 minB = comb::ConcatOp::create(b, minBMsb, minBLsbs);
504 maxB = comb::ConcatOp::create(b, maxBMsb, maxBLsbs);
// The result is tainted when the two extreme comparisons disagree.
507 Value cmp1 = comb::ICmpOp::create(b, pred, minA, maxB);
508 Value cmp2 = comb::ICmpOp::create(b, pred, maxA, minB);
509 return comb::XorOp::create(b, cmp1, cmp2);
512template <
typename ShiftOp>
513Value shiftTaintImprecise(ImplicitLocOpBuilder &b, Value dataT, Value amt,
514 Value amtT, Type resTy) {
515 Value shiftAmtTaint = broadcast(b, orReduce(b, amtT), resTy);
516 Value shiftedDataT = ShiftOp::create(b, dataT, amt);
517 return comb::OrOp::create(b, shiftAmtTaint, shiftedDataT);
520template <
typename ShiftOp,
typename>
521Value instrumentOperation(ImplicitLocOpBuilder &b, ShiftOp op,
522 typename ShiftOp::Adaptor taintAdaptor) {
523 return shiftTaintImprecise<ShiftOp>(b, taintAdaptor.getLhs(), op.getRhs(),
524 taintAdaptor.getRhs(), op.getType());
// Taint rule for comb.parity: the single result bit is tainted when any
// input bit is tainted.
// NOTE(review): the first line of the signature is missing from this excerpt.
529 comb::ParityOp::Adaptor taintAdaptor) {
530 return orReduce(b, taintAdaptor.getInput());
534Value instrumentOperation(ImplicitLocOpBuilder &b, comb::ReverseOp op,
535 comb::ReverseOp::Adaptor taintAdaptor) {
536 return comb::ReverseOp::create(b,
b.getLoc(), op.getType(),
537 taintAdaptor.getInput());
// Rewrites an hw.instance so that every integer-typed input and result is
// followed by a companion taint operand/result, named with `taintSuffix`
// appended to the original port name. A replacement instance is created and
// the uses of the old results are redirected to it.
// NOTE(review): one parameter line, the `continue` statements after the
// non-integer checks, the push of the taint operand and whatever happens to
// the old op afterwards are missing from this excerpt.
544void instrumentInstance(TaintMap &taintOf, PendingTaintMap &pendingTaints,
546 hw::InstanceOp op, ImplicitLocOpBuilder &b) {
547 SmallVector<Value> newInputs;
548 SmallVector<Attribute> newArgNames, newResNames;
549 SmallVector<Type> newResTys;
// Inputs: keep the original, then (for integer types) add its taint.
551 for (
auto [input, nameAttr] :
552 llvm::zip_equal(op.getInputs(), op.getInputNames())) {
553 auto name = cast<StringAttr>(nameAttr);
554 newInputs.push_back(input);
555 newArgNames.push_back(name);
557 if (!isa<IntegerType>(input.getType()))
561 getTaint(taintOf, pendingTaints, backedgeBuilder, input));
562 newArgNames.push_back(
563 b.getStringAttr(name.getValue().str() + taintSuffix.str()));
// Results: same layout; the taint result has the same type as the original.
566 for (
auto [result, nameAttr] :
567 llvm::zip_equal(op.getResults(), op.getOutputNames())) {
568 auto name = cast<StringAttr>(nameAttr);
569 newResTys.push_back(result.getType());
570 newResNames.push_back(name);
572 if (!isa<IntegerType>(result.getType()))
575 newResTys.push_back(result.getType());
576 newResNames.push_back(
577 b.getStringAttr(name.getValue().str() + taintSuffix.str()));
// Same instance name, target module, parameters and symbol as the original.
580 auto newInst = hw::InstanceOp::create(
581 b, op.getLoc(), newResTys, op.getInstanceNameAttr(),
582 op.getModuleNameAttr(), newInputs,
b.getArrayAttr(newArgNames),
583 b.getArrayAttr(newResNames), op.getParametersAttr(), op.getInnerSymAttr(),
584 op.getDoNotPrintAttr());
// Walk the new results in the interleaved order built above.
586 unsigned newResIdx = 0;
587 for (
auto result : op.getResults()) {
588 Value newResult = newInst.getResult(newResIdx++);
589 if (isa<IntegerType>(result.getType())) {
590 Value newTaint = newInst.getResult(newResIdx++);
// Register the taint under both the old and the new result value.
591 setTaint(taintOf, pendingTaints, result, newTaint);
592 taintOf[newResult] = newTaint;
594 result.replaceAllUsesWith(newResult);
// Adds a taint port right after every integer-typed input and output port of
// `mod`. Taint port names are the original name plus `taintSuffix`; a clash
// with an existing port name is reported as an op error. `portInfo` receives
// the original indices of the ports that got a taint companion.
// NOTE(review): many lines are missing from this excerpt, including the
// fields of BlockArgInsertion, the declaration of `tp`, the failure returns,
// the inIdx/outIdx increments and the final return.
606LogicalResult CellIFTInstrumentPass::rewriteModuleSignature(
607 HWModuleOp mod, StringRef taintSuffix, ModulePortTaintInfo &portInfo) {
608 auto *ctx = mod.getContext();
609 auto portList = mod.getPortList();
610 Block *body = mod.getBodyBlock();
// Pending port insertions, passed to modifyPorts at the end.
612 SmallVector<std::pair<unsigned, PortInfo>> insIns, insOuts;
613 struct BlockArgInsertion {
618 SmallVector<BlockArgInsertion> blockArgInsertions;
// All existing port names, used to detect clashes with new taint names.
619 llvm::StringSet<> portNames;
620 for (
auto port : portList)
621 portNames.insert(port.
getName());
623 portInfo.taintInputSources.clear();
624 portInfo.taintOutputSources.clear();
628 unsigned inIdx = 0, outIdx = 0;
// Number of block arguments scheduled for insertion so far; shifts the
// position of each later insertion.
629 unsigned blockArgInsertionOffset = 0;
630 for (
auto port : portList) {
631 if (port.isInput()) {
632 if (isa<IntegerType>(port.type)) {
633 std::string taintName = port.getName().str() + taintSuffix.str();
634 if (!portNames.insert(taintName).second) {
635 mod.emitOpError() <<
"cannot add taint port '" << taintName
636 <<
"' because the name already exists; pick a "
637 "different taint suffix";
642 tp.
name = StringAttr::get(ctx, taintName);
644 tp.
dir = ModulePort::Direction::Input;
645 tp.
loc = port.loc ? port.loc : UnknownLoc::get(ctx);
// The taint input goes directly after the input it shadows.
646 insIns.push_back({inIdx + 1, tp});
647 blockArgInsertions.push_back({inIdx + 1 + blockArgInsertionOffset,
648 tp.
type, cast<Location>(tp.
loc)});
649 portInfo.taintInputSources.push_back(inIdx);
650 blockArgInsertionOffset++;
653 }
else if (port.isOutput()) {
654 if (isa<IntegerType>(port.type)) {
655 std::string taintName = port.getName().str() + taintSuffix.str();
656 if (!portNames.insert(taintName).second) {
657 mod.emitOpError() <<
"cannot add taint port '" << taintName
658 <<
"' because the name already exists; pick a "
659 "different taint suffix";
664 tp.
name = StringAttr::get(ctx, taintName);
666 tp.
dir = ModulePort::Direction::Output;
667 tp.
loc = port.loc ? port.loc : UnknownLoc::get(ctx);
// The taint output goes directly after the output it shadows.
668 insOuts.push_back({outIdx + 1, tp});
669 portInfo.taintOutputSources.push_back(outIdx);
// Add the new block arguments to the body, then update the module's ports.
675 for (
auto &insertion : blockArgInsertions)
676 body->insertArgument(insertion.index, insertion.type, insertion.loc);
678 mod.modifyPorts(insIns, insOuts, {}, {});
// Creates the taint logic for every operation in the body of `mod`. Taint
// block arguments are bound to their data arguments first; then each op is
// dispatched by type: registers get a shadow taint register, constants and
// instances have dedicated handling, other ops go through
// instrumentSingleResultOperation, and anything unrecognised falls back to
// conservativeTaint on its integer results.
// NOTE(review): many lines are missing from this excerpt, including the
// declaration of `backedgeBuilder`, the list of op types routed to
// instrumentSingleResultOperation, the CompReg case header and the return.
686LogicalResult CellIFTInstrumentPass::instrumentModuleBody(
687 HWModuleOp mod,
const ModulePortTaintInfo &portInfo, TaintMap &taintOf,
688 bool taintConstants, StringRef taintSuffix) {
690 Block *body = mod.getBodyBlock();
691 PendingTaintMap pendingTaints;
// Each earlier taint argument shifts later arguments by one, so the original
// argument sits at sourceIdx + taintOrdinal and its taint right after it.
694 for (
auto [taintOrdinal, sourceIdx] :
695 llvm::enumerate(portInfo.taintInputSources)) {
696 unsigned origArgIdx = sourceIdx + taintOrdinal;
697 unsigned taintArgIdx = origArgIdx + 1;
698 setTaint(taintOf, pendingTaints, body->getArgument(origArgIdx),
699 body->getArgument(taintArgIdx));
702 ImplicitLocOpBuilder
b(mod.getLoc(), mod.getContext());
// Early-inc iteration: handlers may create or replace ops while walking.
705 for (Operation &op :
llvm::make_early_inc_range(*body)) {
706 if (isa<hw::OutputOp>(op))
// Taint logic is emitted right after the op it shadows, at the op's location.
709 b.setLoc(op.getLoc());
710 b.setInsertionPointAfter(&op);
712 llvm::TypeSwitch<Operation *>(&op)
// Shadow register name = original name + taintSuffix (when a name exists).
715 if (
auto n = compreg.getNameAttr())
716 name =
b.getStringAttr(n.getValue().str() + taintSuffix.str());
718 Value inputT = getTaint(taintOf, pendingTaints, backedgeBuilder,
// With a reset, the shadow register resets to zero.
721 if (compreg.getReset()) {
722 Value rstVal = getZero(b, compreg.getType());
724 compreg.getClk(), compreg.getReset(),
728 compreg.getClk(), name);
730 setTaint(taintOf, pendingTaints, compreg.getData(), tReg);
732 .Case<seq::FirRegOp>([&](
auto firreg) {
734 b.getStringAttr(firreg.getName().str() + taintSuffix.str());
736 Value nextT = getTaint(taintOf, pendingTaints, backedgeBuilder,
// Same scheme as above: zero reset value, same clock/reset/async settings.
739 if (firreg.hasReset()) {
740 Value rstVal = getZero(b, firreg.getType());
741 tReg = seq::FirRegOp::create(
742 b, nextT, firreg.getClk(), name, firreg.getReset(), rstVal,
743 firreg.getInnerSymAttr(), firreg.getIsAsync());
745 tReg = seq::FirRegOp::create(b, nextT, firreg.getClk(), name);
747 setTaint(taintOf, pendingTaints, firreg.getData(), tReg);
749 .Case<hw::ConstantOp>([&](
auto o) {
750 setTaint(taintOf, pendingTaints, o.getResult(),
751 instrumentOperation(b, o, taintConstants));
758 instrumentSingleResultOperation(taintOf, pendingTaints,
759 backedgeBuilder, b, o);
761 .Case<hw::InstanceOp>([&](
auto o) {
762 instrumentInstance(taintOf, pendingTaints, backedgeBuilder,
// Unknown ops: every integer result is conservatively tainted from the
// op's operands.
765 .Default([&](Operation *o) {
767 for (
auto res : o->getResults())
768 if (isa<IntegerType>(res.getType()))
769 setTaint(taintOf, pendingTaints, res,
770 conservativeTaint(
b, taintOf, pendingTaints,
771 backedgeBuilder, o->getOperands(),
// Builds a new hw.output whose operand list interleaves each original output
// with its taint, for the outputs recorded in portInfo.taintOutputSources.
// NOTE(review): the `taintOf` parameter line, the increment of
// `nextTaintOutput` and the handling of the old terminator are missing from
// this excerpt.
779void CellIFTInstrumentPass::rewriteOutputOp(HWModuleOp mod,
780 const ModulePortTaintInfo &portInfo,
782 Block *body = mod.getBodyBlock();
783 auto outputOp = cast<hw::OutputOp>(body->getTerminator());
785 SmallVector<Value> newOuts;
786 ImplicitLocOpBuilder
b(outputOp.getLoc(), outputOp);
// Index of the next entry of taintOutputSources still to be matched.
788 unsigned nextTaintOutput = 0;
789 for (
unsigned origOutputIdx = 0, e = outputOp.getNumOperands();
790 origOutputIdx < e; ++origOutputIdx) {
791 Value origOutput = outputOp.getOperand(origOutputIdx);
792 newOuts.push_back(origOutput);
// This output has a taint port: append its taint directly after it.
794 if (nextTaintOutput < portInfo.taintOutputSources.size() &&
795 portInfo.taintOutputSources[nextTaintOutput] == origOutputIdx) {
796 newOuts.push_back(taintOf.at(origOutput));
801 hw::OutputOp::create(b, outputOp.getLoc(), newOuts);
// Pass entry point. Phase 1 rewrites the signature of every HWModuleOp in
// sequence; phase 2 instruments the module bodies and rewrites their output
// ops through failableParallelForEach.
// NOTE(review): several lines are missing from this excerpt, including the
// declaration of `modules`, the per-module `taintOf` map and the failure
// handling.
809void CellIFTInstrumentPass::runOnOperation() {
811 llvm::to_vector(getOperation().getBody()->getOps<HWModuleOp>());
// Local copies of the pass options, captured by the lambda below.
813 std::string taintSuffix = this->taintSuffix;
814 bool taintConstants = this->taintConstants;
816 SmallVector<ModuleInstrumentationInfo> moduleInfos;
817 moduleInfos.reserve(modules.size());
// Phase 1: signatures.
819 for (
auto mod : modules) {
820 ModuleInstrumentationInfo
info{mod, {}};
821 if (failed(rewriteModuleSignature(mod, taintSuffix,
info.portInfo))) {
825 moduleInfos.push_back(std::move(info));
// Phase 2: bodies and output ops.
828 if (failed(failableParallelForEach(
829 &getContext(), moduleInfos,
830 [&](ModuleInstrumentationInfo &info) -> LogicalResult {
832 if (failed(instrumentModuleBody(
info.mod,
info.portInfo, taintOf,
833 taintConstants, taintSuffix)))
835 rewriteOutputOp(
info.mod,
info.portInfo, taintOf);
assert(baseType &&"element must be base type")
Instantiate one of these and use it to build typed backedges.
Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc={})
Create a typed backedge.
mlir::LogicalResult clearOrEmitError()
Clear the backedges, erasing any remaining cursor ops.
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
uint64_t getWidth(Type t)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name, type, direction of a module's ports.