CIRCT  20.0.0git
HandshakeToDC.cpp
Go to the documentation of this file.
1 //===- HandshakeToDC.cpp - Translate Handshake into DC --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 
7 //===----------------------------------------------------------------------===//
8 //
9 // This is the main Handshake to DC Conversion Pass Implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "circt/Dialect/DC/DCOps.h"
18 #include "circt/Dialect/HW/HWOps.h"
24 #include "mlir/Dialect/Arith/IR/Arith.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Pass/PassManager.h"
27 #include "mlir/Transforms/DialectConversion.h"
28 #include "llvm/Support/MathExtras.h"
29 #include <optional>
30 
31 namespace circt {
32 #define GEN_PASS_DEF_HANDSHAKETODC
33 #include "circt/Conversion/Passes.h.inc"
34 } // namespace circt
35 
36 using namespace mlir;
37 using namespace circt;
38 using namespace handshake;
39 using namespace dc;
40 using namespace hw;
41 using namespace handshaketodc;
42 
43 namespace {
44 
45 struct DCTuple {
46  DCTuple() = default;
47  DCTuple(Value token, Value data) : token(token), data(data) {}
48  DCTuple(dc::UnpackOp unpack)
49  : token(unpack.getToken()), data(unpack.getOutput()) {}
50  Value token;
51  Value data;
52 };
53 
54 // Unpack a !dc.value<...> into a DCTuple.
55 static DCTuple unpack(OpBuilder &b, Value v) {
56  if (isa<dc::ValueType>(v.getType()))
57  return DCTuple(b.create<dc::UnpackOp>(v.getLoc(), v));
58  assert(isa<dc::TokenType>(v.getType()) && "Expected a dc::TokenType");
59  return DCTuple(v, {});
60 }
61 
62 static Value pack(OpBuilder &b, Value token, Value data = {}) {
63  if (!data)
64  return token;
65  return b.create<dc::PackOp>(token.getLoc(), token, data);
66 }
67 
68 // NOLINTNEXTLINE(misc-no-recursion)
69 static StructType tupleToStruct(TupleType tuple) {
70  auto *ctx = tuple.getContext();
71  mlir::SmallVector<hw::StructType::FieldInfo, 8> hwfields;
72  for (auto [i, innerType] : llvm::enumerate(tuple)) {
73  Type convertedInnerType = innerType;
74  if (auto tupleInnerType = dyn_cast<TupleType>(innerType))
75  convertedInnerType = tupleToStruct(tupleInnerType);
76  hwfields.push_back(
77  {StringAttr::get(ctx, "field" + Twine(i)), convertedInnerType});
78  }
79 
80  return hw::StructType::get(ctx, hwfields);
81 }
82 
83 class DCTypeConverter : public TypeConverter {
84 public:
85  DCTypeConverter() {
86  addConversion([](Type type) -> Type {
87  if (isa<NoneType>(type))
88  return dc::TokenType::get(type.getContext());
89 
90  // For pragmatic reasons, we use a struct type to represent tuples in the
91  // DC lowering; upstream MLIR doesn't have builtin type-modifying ops,
92  // so the next best thing is our "local" struct type in CIRCT.
93  if (auto tupleType = dyn_cast<TupleType>(type))
94  return dc::ValueType::get(type.getContext(), tupleToStruct(tupleType));
95  return dc::ValueType::get(type.getContext(), type);
96  });
97  addConversion([](ValueType type) { return type; });
98  addConversion([](TokenType type) { return type; });
99 
100  addTargetMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
101  mlir::ValueRange inputs,
102  mlir::Location loc) -> mlir::Value {
103  if (inputs.size() != 1)
104  return Value();
105 
106  // Materialize !dc.value<> -> !dc.token
107  if (isa<dc::TokenType>(resultType) &&
108  isa<dc::ValueType>(inputs.front().getType()))
109  return unpack(builder, inputs.front()).token;
110 
111  // Materialize !dc.token -> !dc.value<>
112  auto vt = dyn_cast<dc::ValueType>(resultType);
113  if (vt && !vt.getInnerType())
114  return pack(builder, inputs.front());
115 
116  return builder
117  .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
118  ->getResult(0);
119  });
120 
121  addSourceMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
122  mlir::ValueRange inputs,
123  mlir::Location loc) -> mlir::Value {
124  if (inputs.size() != 1)
125  return Value();
126 
127  // Materialize !dc.value<> -> !dc.token
128  if (isa<dc::TokenType>(resultType) &&
129  isa<dc::ValueType>(inputs.front().getType()))
130  return unpack(builder, inputs.front()).token;
131 
132  // Materialize !dc.token -> !dc.value<>
133  auto vt = dyn_cast<dc::ValueType>(resultType);
134  if (vt && !vt.getInnerType())
135  return pack(builder, inputs.front());
136 
137  return builder
138  .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
139  ->getResult(0);
140  });
141  }
142 };
143 
144 template <typename OpTy>
145 class DCOpConversionPattern : public OpConversionPattern<OpTy> {
146 public:
148  using OpAdaptor = typename OpTy::Adaptor;
149 
150  DCOpConversionPattern(MLIRContext *context, TypeConverter &typeConverter,
151  ConvertedOps *convertedOps)
152  : OpConversionPattern<OpTy>(typeConverter, context),
153  convertedOps(convertedOps) {}
154  mutable ConvertedOps *convertedOps;
155 };
156 
157 class CondBranchConversionPattern
158  : public DCOpConversionPattern<handshake::ConditionalBranchOp> {
159 public:
160  using DCOpConversionPattern<
161  handshake::ConditionalBranchOp>::DCOpConversionPattern;
162  using OpAdaptor = typename handshake::ConditionalBranchOp::Adaptor;
163 
164  LogicalResult
165  matchAndRewrite(handshake::ConditionalBranchOp op, OpAdaptor adaptor,
166  ConversionPatternRewriter &rewriter) const override {
167  auto condition = unpack(rewriter, adaptor.getConditionOperand());
168  auto data = unpack(rewriter, adaptor.getDataOperand());
169 
170  // Join the token of the condition and the input.
171  auto join = rewriter.create<dc::JoinOp>(
172  op.getLoc(), ValueRange{condition.token, data.token});
173 
174  // Pack that together with the condition data.
175  auto packedCondition = pack(rewriter, join, condition.data);
176 
177  // Branch on the input data and the joined control input.
178  auto branch = rewriter.create<dc::BranchOp>(op.getLoc(), packedCondition);
179 
180  // Pack the branch output tokens with the input data, and replace the uses.
181  llvm::SmallVector<Value, 4> packed;
182  packed.push_back(pack(rewriter, branch.getTrueToken(), data.data));
183  packed.push_back(pack(rewriter, branch.getFalseToken(), data.data));
184 
185  rewriter.replaceOp(op, packed);
186  return success();
187  }
188 };
189 
190 class ForkOpConversionPattern
191  : public DCOpConversionPattern<handshake::ForkOp> {
192 public:
193  using DCOpConversionPattern<handshake::ForkOp>::DCOpConversionPattern;
194  using OpAdaptor = typename handshake::ForkOp::Adaptor;
195 
196  LogicalResult
197  matchAndRewrite(handshake::ForkOp op, OpAdaptor adaptor,
198  ConversionPatternRewriter &rewriter) const override {
199  auto input = unpack(rewriter, adaptor.getOperand());
200  auto forkOut = rewriter.create<dc::ForkOp>(op.getLoc(), input.token,
201  op.getNumResults());
202 
203  // Pack the fork result tokens with the input data, and replace the uses.
204  llvm::SmallVector<Value, 4> packed;
205  for (auto res : forkOut.getResults())
206  packed.push_back(pack(rewriter, res, input.data));
207 
208  rewriter.replaceOp(op, packed);
209  return success();
210  }
211 };
212 
213 class JoinOpConversion : public DCOpConversionPattern<handshake::JoinOp> {
214 public:
215  using DCOpConversionPattern<handshake::JoinOp>::DCOpConversionPattern;
216  using OpAdaptor = typename handshake::JoinOp::Adaptor;
217 
218  LogicalResult
219  matchAndRewrite(handshake::JoinOp op, OpAdaptor adaptor,
220  ConversionPatternRewriter &rewriter) const override {
221  llvm::SmallVector<Value, 4> inputTokens;
222  for (auto input : adaptor.getData())
223  inputTokens.push_back(unpack(rewriter, input).token);
224 
225  rewriter.replaceOpWithNewOp<dc::JoinOp>(op, inputTokens);
226  return success();
227  }
228 };
229 
230 class MergeOpConversion : public DCOpConversionPattern<handshake::MergeOp> {
231 public:
232  using DCOpConversionPattern<handshake::MergeOp>::DCOpConversionPattern;
233  using OpAdaptor = typename handshake::MergeOp::Adaptor;
234 
235  LogicalResult
236  matchAndRewrite(handshake::MergeOp op, OpAdaptor adaptor,
237  ConversionPatternRewriter &rewriter) const override {
238  if (op.getNumOperands() > 2)
239  return rewriter.notifyMatchFailure(op, "only two inputs supported");
240 
241  SmallVector<Value, 4> tokens, data;
242 
243  for (auto input : adaptor.getDataOperands()) {
244  auto up = unpack(rewriter, input);
245  tokens.push_back(up.token);
246  if (up.data)
247  data.push_back(up.data);
248  }
249 
250  // Control side
251  Value selectedIndex = rewriter.create<dc::MergeOp>(op.getLoc(), tokens);
252  auto selectedIndexUnpacked = unpack(rewriter, selectedIndex);
253  Value mergeOutput;
254 
255  if (!data.empty()) {
256  // Data-merge; mux the selected input.
257  auto dataMux = rewriter.create<arith::SelectOp>(
258  op.getLoc(), selectedIndexUnpacked.data, data[0], data[1]);
259  convertedOps->insert(dataMux);
260 
261  // Pack the data mux with the control token.
262  mergeOutput = pack(rewriter, selectedIndexUnpacked.token, dataMux);
263  } else {
264  // Control-only merge; throw away the index value of the dc.merge
265  // operation and only forward the dc.token.
266  mergeOutput = selectedIndexUnpacked.token;
267  }
268 
269  rewriter.replaceOp(op, mergeOutput);
270  return success();
271  }
272 };
273 
274 class PackOpConversion : public DCOpConversionPattern<handshake::PackOp> {
275 public:
276  using DCOpConversionPattern<handshake::PackOp>::DCOpConversionPattern;
277  using OpAdaptor = typename handshake::PackOp::Adaptor;
278 
279  LogicalResult
280  matchAndRewrite(handshake::PackOp op, OpAdaptor adaptor,
281  ConversionPatternRewriter &rewriter) const override {
282  // Like the join conversion, but also emits a dc.pack_tuple operation to
283  // handle the data side of the operation (since there's no upstream support
284  // for doing so, sigh...)
285  llvm::SmallVector<Value, 4> inputTokens, inputData;
286  for (auto input : adaptor.getOperands()) {
287  DCTuple dct = unpack(rewriter, input);
288  inputTokens.push_back(dct.token);
289  if (dct.data)
290  inputData.push_back(dct.data);
291  }
292 
293  auto join = rewriter.create<dc::JoinOp>(op.getLoc(), inputTokens);
294  StructType structType =
295  tupleToStruct(cast<TupleType>(op.getResult().getType()));
296  auto packedData =
297  rewriter.create<hw::StructCreateOp>(op.getLoc(), structType, inputData);
298  convertedOps->insert(packedData);
299  rewriter.replaceOp(op, pack(rewriter, join, packedData));
300  return success();
301  }
302 };
303 
304 class UnpackOpConversion : public DCOpConversionPattern<handshake::UnpackOp> {
305 public:
306  using DCOpConversionPattern<handshake::UnpackOp>::DCOpConversionPattern;
307  using OpAdaptor = typename handshake::UnpackOp::Adaptor;
308 
309  LogicalResult
310  matchAndRewrite(handshake::UnpackOp op, OpAdaptor adaptor,
311  ConversionPatternRewriter &rewriter) const override {
312  // Unpack the !dc.value<tuple<...>> into the !dc.token and tuple<...>
313  // values.
314  DCTuple unpackedInput = unpack(rewriter, adaptor.getInput());
315  auto unpackedData =
316  rewriter.create<hw::StructExplodeOp>(op.getLoc(), unpackedInput.data);
317  convertedOps->insert(unpackedData);
318  // Re-pack each of the tuple elements with the token.
319  llvm::SmallVector<Value, 4> repackedInputs;
320  for (auto outputData : unpackedData.getResults())
321  repackedInputs.push_back(pack(rewriter, unpackedInput.token, outputData));
322 
323  rewriter.replaceOp(op, repackedInputs);
324  return success();
325  }
326 };
327 
328 class ControlMergeOpConversion
329  : public DCOpConversionPattern<handshake::ControlMergeOp> {
330 public:
331  using DCOpConversionPattern<handshake::ControlMergeOp>::DCOpConversionPattern;
332 
333  using OpAdaptor = typename handshake::ControlMergeOp::Adaptor;
334 
335  LogicalResult
336  matchAndRewrite(handshake::ControlMergeOp op, OpAdaptor adaptor,
337  ConversionPatternRewriter &rewriter) const override {
338  if (op.getDataOperands().size() != 2)
339  return op.emitOpError("expected two data operands");
340 
341  llvm::SmallVector<Value> tokens, data;
342  for (auto input : adaptor.getDataOperands()) {
343  auto up = unpack(rewriter, input);
344  tokens.push_back(up.token);
345  if (up.data)
346  data.push_back(up.data);
347  }
348 
349  bool isIndexType = isa<IndexType>(op.getIndex().getType());
350 
351  // control-side
352  Value selectedIndex = rewriter.create<dc::MergeOp>(op.getLoc(), tokens);
353  auto mergeOpUnpacked = unpack(rewriter, selectedIndex);
354  auto selValue = mergeOpUnpacked.data;
355 
356  Value dataSide = selectedIndex;
357  if (!data.empty()) {
358  // Data side mux using the selected input.
359  auto dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), selValue,
360  data[0], data[1]);
361  convertedOps->insert(dataMux);
362  // Pack the data mux with the control token.
363  auto packed = pack(rewriter, mergeOpUnpacked.token, dataMux);
364 
365  dataSide = packed;
366  }
367 
368  // if the original op used `index` as the select operand type, we need to
369  // index-cast the unpacked select operand
370  if (isIndexType) {
371  selValue = rewriter.create<arith::IndexCastOp>(
372  op.getLoc(), rewriter.getIndexType(), selValue);
373  convertedOps->insert(selValue.getDefiningOp());
374  selectedIndex = pack(rewriter, mergeOpUnpacked.token, selValue);
375  } else {
376  // The cmerge had a specific type defined for the index type. dc.merge
377  // provides an i1 operand for the selected index, so we need to cast it.
378  selValue = rewriter.create<arith::ExtUIOp>(
379  op.getLoc(), op.getIndex().getType(), selValue);
380  convertedOps->insert(selValue.getDefiningOp());
381  selectedIndex = pack(rewriter, mergeOpUnpacked.token, selValue);
382  }
383 
384  rewriter.replaceOp(op, {dataSide, selectedIndex});
385  return success();
386  }
387 };
388 
389 class SyncOpConversion : public DCOpConversionPattern<handshake::SyncOp> {
390 public:
391  using DCOpConversionPattern<handshake::SyncOp>::DCOpConversionPattern;
392  using OpAdaptor = typename handshake::SyncOp::Adaptor;
393 
394  LogicalResult
395  matchAndRewrite(handshake::SyncOp op, OpAdaptor adaptor,
396  ConversionPatternRewriter &rewriter) const override {
397  llvm::SmallVector<Value, 4> inputTokens;
398  llvm::SmallVector<Value, 4> inputData;
399  for (auto input : adaptor.getOperands()) {
400  auto unpacked = unpack(rewriter, input);
401  inputTokens.push_back(unpacked.token);
402  inputData.push_back(unpacked.data);
403  }
404 
405  auto syncToken = rewriter.create<dc::JoinOp>(op.getLoc(), inputTokens);
406 
407  // Wrap all outputs with the synchronization token
408  llvm::SmallVector<Value, 4> wrappedInputs;
409  for (auto inputData : inputData)
410  wrappedInputs.push_back(pack(rewriter, syncToken, inputData));
411 
412  rewriter.replaceOp(op, wrappedInputs);
413 
414  return success();
415  }
416 };
417 
418 class ConstantOpConversion
419  : public DCOpConversionPattern<handshake::ConstantOp> {
420 public:
421  using DCOpConversionPattern<handshake::ConstantOp>::DCOpConversionPattern;
422  using OpAdaptor = typename handshake::ConstantOp::Adaptor;
423 
424  LogicalResult
425  matchAndRewrite(handshake::ConstantOp op, OpAdaptor adaptor,
426  ConversionPatternRewriter &rewriter) const override {
427  // Wrap the constant with a token.
428  auto token = rewriter.create<dc::SourceOp>(op.getLoc());
429  auto cst =
430  rewriter.create<arith::ConstantOp>(op.getLoc(), adaptor.getValue());
431  convertedOps->insert(cst);
432  rewriter.replaceOp(op, pack(rewriter, token, cst));
433  return success();
434  }
435 };
436 
437 struct UnitRateConversionPattern : public ConversionPattern {
438 public:
439  UnitRateConversionPattern(MLIRContext *context, TypeConverter &converter,
440  ConvertedOps *joinedOps)
441  : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context),
442  joinedOps(joinedOps) {}
443  using ConversionPattern::ConversionPattern;
444 
445  // Generic pattern which replaces an operation by one of the same type, but
446  // with the in- and outputs synchronized through join semantics.
447  LogicalResult
448  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
449  ConversionPatternRewriter &rewriter) const override {
450  llvm::SmallVector<Value> inputData;
451 
452  Value outToken;
453  if (operands.empty()) {
454  if (!op->hasTrait<OpTrait::ConstantLike>())
455  return op->emitOpError(
456  "no-operand operation which isn't constant-like. Too dangerous "
457  "to assume semantics - won't convert");
458 
459  // Constant-like operation; assume the token can be represented as a
460  // constant `dc.source`.
461  outToken = rewriter.create<dc::SourceOp>(op->getLoc());
462  } else {
463  llvm::SmallVector<Value> inputTokens;
464  for (auto input : operands) {
465  auto dct = unpack(rewriter, input);
466  inputData.push_back(dct.data);
467  inputTokens.push_back(dct.token);
468  }
469  // Join the tokens of the inputs.
470  assert(!inputTokens.empty() && "Expected at least one input token");
471  outToken = rewriter.create<dc::JoinOp>(op->getLoc(), inputTokens);
472  }
473 
474  // Patchwork to fix bad IR design in Handshake.
475  auto opName = op->getName();
476  if (opName.getStringRef() == "handshake.select") {
477  opName = OperationName("arith.select", getContext());
478  } else if (opName.getStringRef() == "handshake.constant") {
479  opName = OperationName("arith.constant", getContext());
480  }
481 
482  // Re-create the operation using the unpacked input data.
483  OperationState state(op->getLoc(), opName, inputData, op->getResultTypes(),
484  op->getAttrs(), op->getSuccessors());
485 
486  Operation *newOp = rewriter.create(state);
487  joinedOps->insert(newOp);
488 
489  // Pack the result token with the output data, and replace the uses.
490  llvm::SmallVector<Value> results;
491  for (auto result : newOp->getResults())
492  results.push_back(pack(rewriter, outToken, result));
493 
494  rewriter.replaceOp(op, results);
495 
496  return success();
497  }
498 
499  mutable ConvertedOps *joinedOps;
500 };
501 
502 class SinkOpConversionPattern
503  : public DCOpConversionPattern<handshake::SinkOp> {
504 public:
505  using DCOpConversionPattern<handshake::SinkOp>::DCOpConversionPattern;
506  using OpAdaptor = typename handshake::SinkOp::Adaptor;
507 
508  LogicalResult
509  matchAndRewrite(handshake::SinkOp op, OpAdaptor adaptor,
510  ConversionPatternRewriter &rewriter) const override {
511  auto input = unpack(rewriter, adaptor.getOperand());
512  rewriter.replaceOpWithNewOp<dc::SinkOp>(op, input.token);
513  return success();
514  }
515 };
516 
517 class SourceOpConversionPattern
518  : public DCOpConversionPattern<handshake::SourceOp> {
519 public:
520  using DCOpConversionPattern<handshake::SourceOp>::DCOpConversionPattern;
521  using OpAdaptor = typename handshake::SourceOp::Adaptor;
522 
523  LogicalResult
524  matchAndRewrite(handshake::SourceOp op, OpAdaptor adaptor,
525  ConversionPatternRewriter &rewriter) const override {
526  rewriter.replaceOpWithNewOp<dc::SourceOp>(op);
527  return success();
528  }
529 };
530 
531 class BufferOpConversion : public DCOpConversionPattern<handshake::BufferOp> {
532 public:
533  using DCOpConversionPattern<handshake::BufferOp>::DCOpConversionPattern;
534  using OpAdaptor = typename handshake::BufferOp::Adaptor;
535 
536  LogicalResult
537  matchAndRewrite(handshake::BufferOp op, OpAdaptor adaptor,
538  ConversionPatternRewriter &rewriter) const override {
539  rewriter.getI32IntegerAttr(1);
540  rewriter.replaceOpWithNewOp<dc::BufferOp>(
541  op, adaptor.getOperand(), static_cast<size_t>(op.getNumSlots()),
542  op.getInitValuesAttr());
543  return success();
544  }
545 };
546 
547 class ReturnOpConversion : public OpConversionPattern<handshake::ReturnOp> {
548 public:
550  using OpAdaptor = typename handshake::ReturnOp::Adaptor;
551 
552  LogicalResult
553  matchAndRewrite(handshake::ReturnOp op, OpAdaptor adaptor,
554  ConversionPatternRewriter &rewriter) const override {
555  // Locate existing output op, Append operands to output op, and move to
556  // the end of the block.
557  auto hwModule = op->getParentOfType<hw::HWModuleOp>();
558  auto outputOp = *hwModule.getBodyBlock()->getOps<hw::OutputOp>().begin();
559  outputOp->setOperands(adaptor.getOperands());
560  outputOp->moveAfter(&hwModule.getBodyBlock()->back());
561  rewriter.eraseOp(op);
562  return success();
563  }
564 };
565 
566 class MuxOpConversionPattern : public DCOpConversionPattern<handshake::MuxOp> {
567 public:
568  using DCOpConversionPattern<handshake::MuxOp>::DCOpConversionPattern;
569  using OpAdaptor = typename handshake::MuxOp::Adaptor;
570 
571  LogicalResult
572  matchAndRewrite(handshake::MuxOp op, OpAdaptor adaptor,
573  ConversionPatternRewriter &rewriter) const override {
574  auto select = unpack(rewriter, adaptor.getSelectOperand());
575  auto selectData = select.data;
576  auto selectToken = select.token;
577  bool isIndexType = isa<IndexType>(selectData.getType());
578 
579  bool withData = !isa<NoneType>(op.getResult().getType());
580 
581  llvm::SmallVector<DCTuple> inputs;
582  for (auto input : adaptor.getDataOperands())
583  inputs.push_back(unpack(rewriter, input));
584 
585  Value dataMux;
586  Value controlMux = inputs.front().token;
587  // Convert the data-side mux to a sequence of arith.select operations.
588  // The data and control muxes are assumed one-hot and the base-case is set
589  // as the first input.
590  if (withData)
591  dataMux = inputs[0].data;
592 
593  llvm::SmallVector<Value> controlMuxInputs = {inputs.front().token};
594  for (auto [i, input] :
595  llvm::enumerate(llvm::make_range(inputs.begin() + 1, inputs.end()))) {
596  if (!withData)
597  continue;
598 
599  Value cmpIndex;
600  Value inputData = input.data;
601  Value inputControl = input.token;
602  if (isIndexType) {
603  cmpIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), i);
604  } else {
605  size_t width = cast<IntegerType>(selectData.getType()).getWidth();
606  cmpIndex = rewriter.create<arith::ConstantIntOp>(op.getLoc(), i, width);
607  }
608  auto inputSelected = rewriter.create<arith::CmpIOp>(
609  op.getLoc(), arith::CmpIPredicate::eq, selectData, cmpIndex);
610  dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), inputSelected,
611  inputData, dataMux);
612 
613  // Legalize the newly created operations.
614  convertedOps->insert(cmpIndex.getDefiningOp());
615  convertedOps->insert(dataMux.getDefiningOp());
616  convertedOps->insert(inputSelected);
617 
618  // And similarly for the control mux, by muxing the input token with a
619  // select value that has it's control from the original select token +
620  // the inputSelected value.
621  auto inputSelectedControl = pack(rewriter, selectToken, inputSelected);
622  controlMux = rewriter.create<dc::SelectOp>(
623  op.getLoc(), inputSelectedControl, inputControl, controlMux);
624  convertedOps->insert(controlMux.getDefiningOp());
625  }
626 
627  // finally, pack the control and data side muxes into the output value.
628  rewriter.replaceOp(
629  op, pack(rewriter, controlMux, withData ? dataMux : Value{}));
630  return success();
631  }
632 };
633 
634 static hw::ModulePortInfo getModulePortInfoHS(const TypeConverter &tc,
635  handshake::FuncOp funcOp) {
636  SmallVector<hw::PortInfo> inputs, outputs;
637  auto ft = funcOp.getFunctionType();
638  funcOp.resolveArgAndResNames();
639 
640  // Add all inputs of funcOp.
641  for (auto [index, type] : llvm::enumerate(ft.getInputs()))
642  inputs.push_back({{funcOp.getArgName(index), tc.convertType(type),
644  index,
645  {}});
646 
647  // Add all outputs of funcOp.
648  for (auto [index, type] : llvm::enumerate(ft.getResults()))
649  outputs.push_back({{funcOp.getResName(index), tc.convertType(type),
651  index,
652  {}});
653 
654  return hw::ModulePortInfo{inputs, outputs};
655 }
656 
657 class FuncOpConversion : public DCOpConversionPattern<handshake::FuncOp> {
658 public:
659  using DCOpConversionPattern<handshake::FuncOp>::DCOpConversionPattern;
660  using OpAdaptor = typename handshake::FuncOp::Adaptor;
661 
662  // Replaces a handshake.func with a hw.module, converting the argument and
663  // result types using the provided type converter.
664  // @mortbopet: Not a fan of converting to hw here seeing as we don't
665  // necessarily have hardware semantics here. But, DC doesn't define a
666  // function operation, and there is no "func.graph_func" or any other
667  // generic function operation which is a graph region...
668  LogicalResult
669  matchAndRewrite(handshake::FuncOp op, OpAdaptor adaptor,
670  ConversionPatternRewriter &rewriter) const override {
671  ModulePortInfo ports = getModulePortInfoHS(*getTypeConverter(), op);
672 
673  if (op.isExternal()) {
674  auto mod = rewriter.create<hw::HWModuleExternOp>(
675  op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
676  convertedOps->insert(mod);
677  } else {
678  auto hwModule = rewriter.create<hw::HWModuleOp>(
679  op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
680 
681  auto &region = op->getRegions().front();
682 
683  Region &moduleRegion = hwModule->getRegions().front();
684  rewriter.mergeBlocks(&region.getBlocks().front(), hwModule.getBodyBlock(),
685  hwModule.getBodyBlock()->getArguments());
686  TypeConverter::SignatureConversion result(moduleRegion.getNumArguments());
687  (void)getTypeConverter()->convertSignatureArgs(
688  TypeRange(moduleRegion.getArgumentTypes()), result);
689  rewriter.applySignatureConversion(hwModule.getBodyBlock(), result);
690  convertedOps->insert(hwModule);
691  }
692 
693  rewriter.eraseOp(op);
694  return success();
695  }
696 };
697 
698 /// Lower the ESIInstanceOp to `hw.instance` with `dc.from_esi` and `dc.to_esi`
699 /// to convert the args/results.
700 class ESIInstanceConversionPattern
701  : public OpConversionPattern<handshake::ESIInstanceOp> {
702 public:
703  ESIInstanceConversionPattern(MLIRContext *context,
704  const HWSymbolCache &symCache)
705  : OpConversionPattern(context), symCache(symCache) {}
706 
707  LogicalResult
708  matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
709  ConversionPatternRewriter &rewriter) const override {
710  Location loc = op.getLoc();
711  SmallVector<Value> operands;
712  for (size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
713  i < e; ++i)
714  operands.push_back(
715  rewriter.create<dc::FromESIOp>(loc, adaptor.getOperands()[i]));
716  operands.push_back(adaptor.getClk());
717  operands.push_back(adaptor.getRst());
718  // Locate the lowered module so the instance builder can get all the
719  // metadata.
720  Operation *targetModule = symCache.getDefinition(op.getModuleAttr());
721  // And replace the op with an instance of the target module.
722  auto inst = rewriter.create<hw::InstanceOp>(loc, targetModule,
723  op.getInstNameAttr(), operands);
724  SmallVector<Value> esiResults(
725  llvm::map_range(inst.getResults(), [&](Value v) {
726  return rewriter.create<dc::ToESIOp>(loc, v);
727  }));
728  rewriter.replaceOp(op, esiResults);
729  return success();
730  }
731 
732 private:
733  const HWSymbolCache &symCache;
734 };
735 
736 /// Add DC clock and reset ports to the module.
737 void addClkRst(hw::HWModuleOp mod, StringRef clkName, StringRef rstName) {
738  auto *ctx = mod.getContext();
739 
740  size_t numInputs = mod.getNumInputPorts();
741  mod.insertInput(numInputs, clkName, seq::ClockType::get(ctx));
742  mod.setPortAttrs(
743  numInputs,
744  DictionaryAttr::get(ctx, {NamedAttribute(StringAttr::get(ctx, "dc.clock"),
745  UnitAttr::get(ctx))}));
746  mod.insertInput(numInputs + 1, rstName, IntegerType::get(ctx, 1));
747  mod.setPortAttrs(
748  numInputs + 1,
749  DictionaryAttr::get(ctx, {NamedAttribute(StringAttr::get(ctx, "dc.reset"),
750  UnitAttr::get(ctx))}));
751 
752  // We must initialize any port attributes that are not set otherwise the
753  // verifier will fail.
754  for (size_t portNum = 0, e = mod.getNumPorts(); portNum < e; ++portNum) {
755  auto attrs = dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(portNum));
756  if (attrs)
757  continue;
758  mod.setPortAttrs(portNum, DictionaryAttr::get(ctx, {}));
759  }
760 }
761 
762 class HandshakeToDCPass
763  : public circt::impl::HandshakeToDCBase<HandshakeToDCPass> {
764 public:
765  using Base::Base;
766  void runOnOperation() override {
767  mlir::ModuleOp mod = getOperation();
768  auto patternBuilder = [&](TypeConverter &typeConverter,
769  handshaketodc::ConvertedOps &convertedOps,
770  RewritePatternSet &patterns) {
771  patterns.add<FuncOpConversion>(mod.getContext(), typeConverter,
772  &convertedOps);
773  patterns.add<ReturnOpConversion>(typeConverter, mod.getContext());
774  };
775 
776  LogicalResult res =
777  runHandshakeToDC(mod, circt::HandshakeToDCOptions{clkName, rstName},
778  patternBuilder, nullptr);
779  if (failed(res))
780  signalPassFailure();
781  }
782 };
783 } // namespace
784 
786  mlir::Operation *op, circt::HandshakeToDCOptions options,
787  llvm::function_ref<void(TypeConverter &typeConverter,
788  handshaketodc::ConvertedOps &convertedOps,
789  RewritePatternSet &patterns)>
790  patternBuilder,
791  llvm::function_ref<void(mlir::ConversionTarget &)> configureTarget) {
792  // Maintain the set of operations which has been converted either through
793  // unit rate conversion, or as part of other conversions.
794  // Rationale:
795  // This is needed for all of the arith ops that get created as part of the
796  // handshake ops (e.g. arith.select for handshake.mux). There's a bit of a
797  // dilemma here seeing as all operations need to be converted/touched in a
798  // handshake.func - which is done so by UnitRateConversionPattern (when no
799  // other pattern applies). However, we obviously don't want to run said
800  // pattern on these newly created ops since they do not have handshake
801  // semantics.
802  handshaketodc::ConvertedOps convertedOps;
803  mlir::MLIRContext *ctx = op->getContext();
804  ConversionTarget target(*ctx);
805  target.addIllegalDialect<handshake::HandshakeDialect>();
806  target.addLegalDialect<dc::DCDialect>();
807  target.addLegalOp<mlir::ModuleOp, handshake::ESIInstanceOp, hw::HWModuleOp,
808  hw::OutputOp>();
809 
810  // And any user-specified target adjustments
811  if (configureTarget)
812  configureTarget(target);
813 
814  // The various patterns will insert new operations into the module to
815  // facilitate the conversion - however, these operations must be
816  // distinguishable from already converted operations (which may be of the
817  // same type as the newly inserted operations). To do this, we mark all
818  // operations which have been converted as legal, and all other operations
819  // as illegal.
820  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
821  return convertedOps.contains(op) ||
822  // Allow any ops which weren't in a `handshake.func` to pass through.
823  !convertedOps.contains(op->getParentOfType<hw::HWModuleOp>());
824  });
825 
826  DCTypeConverter typeConverter;
827  RewritePatternSet patterns(ctx);
828 
829  // Add handshake conversion patterns.
830  // Note: merge/control merge are not supported - these are non-deterministic
831  // operators and we do not care for them.
832  patterns
833  .add<BufferOpConversion, CondBranchConversionPattern,
834  SinkOpConversionPattern, SourceOpConversionPattern,
835  MuxOpConversionPattern, ForkOpConversionPattern, JoinOpConversion,
836  PackOpConversion, UnpackOpConversion, MergeOpConversion,
837  ControlMergeOpConversion, ConstantOpConversion, SyncOpConversion>(
838  ctx, typeConverter, &convertedOps);
839 
840  // ALL other single-result operations are converted via the
841  // UnitRateConversionPattern.
842  patterns.add<UnitRateConversionPattern>(ctx, typeConverter, &convertedOps);
843 
844  // Build any user-specified patterns
845  patternBuilder(typeConverter, convertedOps, patterns);
846  if (failed(applyPartialConversion(op, target, std::move(patterns))))
847  return failure();
848 
849  // Add clock and reset ports to each converted module.
850  for (auto &op : convertedOps)
851  if (auto mod = dyn_cast<hw::HWModuleOp>(op); mod)
852  addClkRst(mod, options.clkName, options.rstName);
853 
854  // Run conversions which need see everything.
855  HWSymbolCache symbolCache;
856  symbolCache.addDefinitions(op);
857  symbolCache.freeze();
858  ConversionTarget globalLoweringTarget(*ctx);
859  globalLoweringTarget.addIllegalDialect<handshake::HandshakeDialect>();
860  globalLoweringTarget.addLegalDialect<dc::DCDialect, hw::HWDialect>();
861  RewritePatternSet globalPatterns(ctx);
862  globalPatterns.add<ESIInstanceConversionPattern>(ctx, symbolCache);
863  if (failed(applyPartialConversion(op, globalLoweringTarget,
864  std::move(globalPatterns))))
865  return op->emitOpError() << "error during conversion";
866 
867  return success();
868 }
assert(baseType &&"element must be base type")
static Type tupleToStruct(TupleType tuple)
Definition: DCToHW.cpp:48
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition: HWSymCache.h:27
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition: HWSymCache.h:56
void freeze()
Mark the cache as frozen, which allows it to be shared across threads.
Definition: HWSymCache.h:75
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
DenseSet< Operation * > ConvertedOps
Definition: HandshakeToDC.h:31
LogicalResult runHandshakeToDC(mlir::Operation *op, HandshakeToDCOptions options, llvm::function_ref< void(TypeConverter &typeConverter, ConvertedOps &convertedOps, RewritePatternSet &patterns)> patternBuilder, llvm::function_ref< void(mlir::ConversionTarget &)> configureTarget={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: hw.py:1
This holds a decoded list of input/inout and output ports for a module or instance.
Creates a new Calyx component for each FuncOp in the program.