CIRCT 20.0.0git
Loading...
Searching...
No Matches
HandshakeToDC.cpp
Go to the documentation of this file.
1//===- HandshakeToDC.cpp - Translate Handshake into DC --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7//===----------------------------------------------------------------------===//
8//
9// This is the main Handshake to DC Conversion Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
24#include "mlir/Dialect/Arith/IR/Arith.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Pass/PassManager.h"
27#include "mlir/Transforms/DialectConversion.h"
28#include "llvm/Support/MathExtras.h"
29#include <optional>
30
31namespace circt {
32#define GEN_PASS_DEF_HANDSHAKETODC
33#include "circt/Conversion/Passes.h.inc"
34} // namespace circt
35
36using namespace mlir;
37using namespace circt;
38using namespace handshake;
39using namespace dc;
40using namespace hw;
41using namespace handshaketodc;
42
43namespace {
44
45struct DCTuple {
46 DCTuple() = default;
47 DCTuple(Value token, Value data) : token(token), data(data) {}
48 DCTuple(dc::UnpackOp unpack)
49 : token(unpack.getToken()), data(unpack.getOutput()) {}
50 Value token;
51 Value data;
52};
53
54// Unpack a !dc.value<...> into a DCTuple.
55static DCTuple unpack(OpBuilder &b, Value v) {
56 if (isa<dc::ValueType>(v.getType()))
57 return DCTuple(b.create<dc::UnpackOp>(v.getLoc(), v));
58 assert(isa<dc::TokenType>(v.getType()) && "Expected a dc::TokenType");
59 return DCTuple(v, {});
60}
61
62static Value pack(OpBuilder &b, Value token, Value data = {}) {
63 if (!data)
64 return token;
65 return b.create<dc::PackOp>(token.getLoc(), token, data);
66}
67
68// NOLINTNEXTLINE(misc-no-recursion)
69static StructType tupleToStruct(TupleType tuple) {
70 auto *ctx = tuple.getContext();
71 mlir::SmallVector<hw::StructType::FieldInfo, 8> hwfields;
72 for (auto [i, innerType] : llvm::enumerate(tuple)) {
73 Type convertedInnerType = innerType;
74 if (auto tupleInnerType = dyn_cast<TupleType>(innerType))
75 convertedInnerType = tupleToStruct(tupleInnerType);
76 hwfields.push_back(
77 {StringAttr::get(ctx, "field" + Twine(i)), convertedInnerType});
78 }
79
80 return hw::StructType::get(ctx, hwfields);
81}
82
83class DCTypeConverter : public TypeConverter {
84public:
85 DCTypeConverter() {
86 addConversion([](Type type) -> Type {
87 if (isa<NoneType>(type))
88 return dc::TokenType::get(type.getContext());
89
90 // For pragmatic reasons, we use a struct type to represent tuples in the
91 // DC lowering; upstream MLIR doesn't have builtin type-modifying ops,
92 // so the next best thing is our "local" struct type in CIRCT.
93 if (auto tupleType = dyn_cast<TupleType>(type))
94 return dc::ValueType::get(type.getContext(), tupleToStruct(tupleType));
95 return dc::ValueType::get(type.getContext(), type);
96 });
97 addConversion([](ValueType type) { return type; });
98 addConversion([](TokenType type) { return type; });
99
100 addTargetMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
101 mlir::ValueRange inputs,
102 mlir::Location loc) -> mlir::Value {
103 if (inputs.size() != 1)
104 return Value();
105
106 // Materialize !dc.value<> -> !dc.token
107 if (isa<dc::TokenType>(resultType) &&
108 isa<dc::ValueType>(inputs.front().getType()))
109 return unpack(builder, inputs.front()).token;
110
111 // Materialize !dc.token -> !dc.value<>
112 auto vt = dyn_cast<dc::ValueType>(resultType);
113 if (vt && !vt.getInnerType())
114 return pack(builder, inputs.front());
115
116 return builder
117 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
118 ->getResult(0);
119 });
120
121 addSourceMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
122 mlir::ValueRange inputs,
123 mlir::Location loc) -> mlir::Value {
124 if (inputs.size() != 1)
125 return Value();
126
127 // Materialize !dc.value<> -> !dc.token
128 if (isa<dc::TokenType>(resultType) &&
129 isa<dc::ValueType>(inputs.front().getType()))
130 return unpack(builder, inputs.front()).token;
131
132 // Materialize !dc.token -> !dc.value<>
133 auto vt = dyn_cast<dc::ValueType>(resultType);
134 if (vt && !vt.getInnerType())
135 return pack(builder, inputs.front());
136
137 return builder
138 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
139 ->getResult(0);
140 });
141 }
142};
143
144template <typename OpTy>
145class DCOpConversionPattern : public OpConversionPattern<OpTy> {
146public:
148 using OpAdaptor = typename OpTy::Adaptor;
149
150 DCOpConversionPattern(MLIRContext *context, TypeConverter &typeConverter,
151 ConvertedOps *convertedOps)
152 : OpConversionPattern<OpTy>(typeConverter, context),
153 convertedOps(convertedOps) {}
154 mutable ConvertedOps *convertedOps;
155};
156
157class CondBranchConversionPattern
158 : public DCOpConversionPattern<handshake::ConditionalBranchOp> {
159public:
160 using DCOpConversionPattern<
161 handshake::ConditionalBranchOp>::DCOpConversionPattern;
162 using OpAdaptor = typename handshake::ConditionalBranchOp::Adaptor;
163
164 LogicalResult
165 matchAndRewrite(handshake::ConditionalBranchOp op, OpAdaptor adaptor,
166 ConversionPatternRewriter &rewriter) const override {
167 auto condition = unpack(rewriter, adaptor.getConditionOperand());
168 auto data = unpack(rewriter, adaptor.getDataOperand());
169
170 // Join the token of the condition and the input.
171 auto join = rewriter.create<dc::JoinOp>(
172 op.getLoc(), ValueRange{condition.token, data.token});
173
174 // Pack that together with the condition data.
175 auto packedCondition = pack(rewriter, join, condition.data);
176
177 // Branch on the input data and the joined control input.
178 auto branch = rewriter.create<dc::BranchOp>(op.getLoc(), packedCondition);
179
180 // Pack the branch output tokens with the input data, and replace the uses.
181 llvm::SmallVector<Value, 4> packed;
182 packed.push_back(pack(rewriter, branch.getTrueToken(), data.data));
183 packed.push_back(pack(rewriter, branch.getFalseToken(), data.data));
184
185 rewriter.replaceOp(op, packed);
186 return success();
187 }
188};
189
190class ForkOpConversionPattern
191 : public DCOpConversionPattern<handshake::ForkOp> {
192public:
193 using DCOpConversionPattern<handshake::ForkOp>::DCOpConversionPattern;
194 using OpAdaptor = typename handshake::ForkOp::Adaptor;
195
196 LogicalResult
197 matchAndRewrite(handshake::ForkOp op, OpAdaptor adaptor,
198 ConversionPatternRewriter &rewriter) const override {
199 auto input = unpack(rewriter, adaptor.getOperand());
200 auto forkOut = rewriter.create<dc::ForkOp>(op.getLoc(), input.token,
201 op.getNumResults());
202
203 // Pack the fork result tokens with the input data, and replace the uses.
204 llvm::SmallVector<Value, 4> packed;
205 for (auto res : forkOut.getResults())
206 packed.push_back(pack(rewriter, res, input.data));
207
208 rewriter.replaceOp(op, packed);
209 return success();
210 }
211};
212
213class JoinOpConversion : public DCOpConversionPattern<handshake::JoinOp> {
214public:
215 using DCOpConversionPattern<handshake::JoinOp>::DCOpConversionPattern;
216 using OpAdaptor = typename handshake::JoinOp::Adaptor;
217
218 LogicalResult
219 matchAndRewrite(handshake::JoinOp op, OpAdaptor adaptor,
220 ConversionPatternRewriter &rewriter) const override {
221 llvm::SmallVector<Value, 4> inputTokens;
222 for (auto input : adaptor.getData())
223 inputTokens.push_back(unpack(rewriter, input).token);
224
225 rewriter.replaceOpWithNewOp<dc::JoinOp>(op, inputTokens);
226 return success();
227 }
228};
229
230class MergeOpConversion : public DCOpConversionPattern<handshake::MergeOp> {
231public:
232 using DCOpConversionPattern<handshake::MergeOp>::DCOpConversionPattern;
233 using OpAdaptor = typename handshake::MergeOp::Adaptor;
234
235 LogicalResult
236 matchAndRewrite(handshake::MergeOp op, OpAdaptor adaptor,
237 ConversionPatternRewriter &rewriter) const override {
238 if (op.getNumOperands() > 2)
239 return rewriter.notifyMatchFailure(op, "only two inputs supported");
240
241 SmallVector<Value, 4> tokens, data;
242
243 for (auto input : adaptor.getDataOperands()) {
244 auto up = unpack(rewriter, input);
245 tokens.push_back(up.token);
246 if (up.data)
247 data.push_back(up.data);
248 }
249
250 // Control side
251 Value selectedIndex = rewriter.create<dc::MergeOp>(op.getLoc(), tokens);
252 auto selectedIndexUnpacked = unpack(rewriter, selectedIndex);
253 Value mergeOutput;
254
255 if (!data.empty()) {
256 // Data-merge; mux the selected input.
257 auto dataMux = rewriter.create<arith::SelectOp>(
258 op.getLoc(), selectedIndexUnpacked.data, data[0], data[1]);
259 convertedOps->insert(dataMux);
260
261 // Pack the data mux with the control token.
262 mergeOutput = pack(rewriter, selectedIndexUnpacked.token, dataMux);
263 } else {
264 // Control-only merge; throw away the index value of the dc.merge
265 // operation and only forward the dc.token.
266 mergeOutput = selectedIndexUnpacked.token;
267 }
268
269 rewriter.replaceOp(op, mergeOutput);
270 return success();
271 }
272};
273
274class PackOpConversion : public DCOpConversionPattern<handshake::PackOp> {
275public:
276 using DCOpConversionPattern<handshake::PackOp>::DCOpConversionPattern;
277 using OpAdaptor = typename handshake::PackOp::Adaptor;
278
279 LogicalResult
280 matchAndRewrite(handshake::PackOp op, OpAdaptor adaptor,
281 ConversionPatternRewriter &rewriter) const override {
282 // Like the join conversion, but also emits a dc.pack_tuple operation to
283 // handle the data side of the operation (since there's no upstream support
284 // for doing so, sigh...)
285 llvm::SmallVector<Value, 4> inputTokens, inputData;
286 for (auto input : adaptor.getOperands()) {
287 DCTuple dct = unpack(rewriter, input);
288 inputTokens.push_back(dct.token);
289 if (dct.data)
290 inputData.push_back(dct.data);
291 }
292
293 auto join = rewriter.create<dc::JoinOp>(op.getLoc(), inputTokens);
294 StructType structType =
295 tupleToStruct(cast<TupleType>(op.getResult().getType()));
296 auto packedData =
297 rewriter.create<hw::StructCreateOp>(op.getLoc(), structType, inputData);
298 convertedOps->insert(packedData);
299 rewriter.replaceOp(op, pack(rewriter, join, packedData));
300 return success();
301 }
302};
303
304class UnpackOpConversion : public DCOpConversionPattern<handshake::UnpackOp> {
305public:
306 using DCOpConversionPattern<handshake::UnpackOp>::DCOpConversionPattern;
307 using OpAdaptor = typename handshake::UnpackOp::Adaptor;
308
309 LogicalResult
310 matchAndRewrite(handshake::UnpackOp op, OpAdaptor adaptor,
311 ConversionPatternRewriter &rewriter) const override {
312 // Unpack the !dc.value<tuple<...>> into the !dc.token and tuple<...>
313 // values.
314 DCTuple unpackedInput = unpack(rewriter, adaptor.getInput());
315 auto unpackedData =
316 rewriter.create<hw::StructExplodeOp>(op.getLoc(), unpackedInput.data);
317 convertedOps->insert(unpackedData);
318 // Re-pack each of the tuple elements with the token.
319 llvm::SmallVector<Value, 4> repackedInputs;
320 for (auto outputData : unpackedData.getResults())
321 repackedInputs.push_back(pack(rewriter, unpackedInput.token, outputData));
322
323 rewriter.replaceOp(op, repackedInputs);
324 return success();
325 }
326};
327
328class ControlMergeOpConversion
329 : public DCOpConversionPattern<handshake::ControlMergeOp> {
330public:
331 using DCOpConversionPattern<handshake::ControlMergeOp>::DCOpConversionPattern;
332
333 using OpAdaptor = typename handshake::ControlMergeOp::Adaptor;
334
335 LogicalResult
336 matchAndRewrite(handshake::ControlMergeOp op, OpAdaptor adaptor,
337 ConversionPatternRewriter &rewriter) const override {
338 if (op.getDataOperands().size() != 2)
339 return op.emitOpError("expected two data operands");
340
341 llvm::SmallVector<Value> tokens, data;
342 for (auto input : adaptor.getDataOperands()) {
343 auto up = unpack(rewriter, input);
344 tokens.push_back(up.token);
345 if (up.data)
346 data.push_back(up.data);
347 }
348
349 bool isIndexType = isa<IndexType>(op.getIndex().getType());
350
351 // control-side
352 Value selectedIndex = rewriter.create<dc::MergeOp>(op.getLoc(), tokens);
353 auto mergeOpUnpacked = unpack(rewriter, selectedIndex);
354 auto selValue = mergeOpUnpacked.data;
355
356 Value dataSide = selectedIndex;
357 if (!data.empty()) {
358 // Data side mux using the selected input.
359 auto dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), selValue,
360 data[0], data[1]);
361 convertedOps->insert(dataMux);
362 // Pack the data mux with the control token.
363 auto packed = pack(rewriter, mergeOpUnpacked.token, dataMux);
364
365 dataSide = packed;
366 }
367
368 // if the original op used `index` as the select operand type, we need to
369 // index-cast the unpacked select operand
370 if (isIndexType) {
371 selValue = rewriter.create<arith::IndexCastOp>(
372 op.getLoc(), rewriter.getIndexType(), selValue);
373 convertedOps->insert(selValue.getDefiningOp());
374 selectedIndex = pack(rewriter, mergeOpUnpacked.token, selValue);
375 } else {
376 // The cmerge had a specific type defined for the index type. dc.merge
377 // provides an i1 operand for the selected index, so we need to cast it.
378 selValue = rewriter.create<arith::ExtUIOp>(
379 op.getLoc(), op.getIndex().getType(), selValue);
380 convertedOps->insert(selValue.getDefiningOp());
381 selectedIndex = pack(rewriter, mergeOpUnpacked.token, selValue);
382 }
383
384 rewriter.replaceOp(op, {dataSide, selectedIndex});
385 return success();
386 }
387};
388
389class SyncOpConversion : public DCOpConversionPattern<handshake::SyncOp> {
390public:
391 using DCOpConversionPattern<handshake::SyncOp>::DCOpConversionPattern;
392 using OpAdaptor = typename handshake::SyncOp::Adaptor;
393
394 LogicalResult
395 matchAndRewrite(handshake::SyncOp op, OpAdaptor adaptor,
396 ConversionPatternRewriter &rewriter) const override {
397 llvm::SmallVector<Value, 4> inputTokens;
398 llvm::SmallVector<Value, 4> inputData;
399 for (auto input : adaptor.getOperands()) {
400 auto unpacked = unpack(rewriter, input);
401 inputTokens.push_back(unpacked.token);
402 inputData.push_back(unpacked.data);
403 }
404
405 auto syncToken = rewriter.create<dc::JoinOp>(op.getLoc(), inputTokens);
406
407 // Wrap all outputs with the synchronization token
408 llvm::SmallVector<Value, 4> wrappedInputs;
409 for (auto inputData : inputData)
410 wrappedInputs.push_back(pack(rewriter, syncToken, inputData));
411
412 rewriter.replaceOp(op, wrappedInputs);
413
414 return success();
415 }
416};
417
418class ConstantOpConversion
419 : public DCOpConversionPattern<handshake::ConstantOp> {
420public:
421 using DCOpConversionPattern<handshake::ConstantOp>::DCOpConversionPattern;
422 using OpAdaptor = typename handshake::ConstantOp::Adaptor;
423
424 LogicalResult
425 matchAndRewrite(handshake::ConstantOp op, OpAdaptor adaptor,
426 ConversionPatternRewriter &rewriter) const override {
427 // Wrap the constant with a token.
428 auto token = rewriter.create<dc::SourceOp>(op.getLoc());
429 auto cst =
430 rewriter.create<arith::ConstantOp>(op.getLoc(), adaptor.getValue());
431 convertedOps->insert(cst);
432 rewriter.replaceOp(op, pack(rewriter, token, cst));
433 return success();
434 }
435};
436
437struct UnitRateConversionPattern : public ConversionPattern {
438public:
439 UnitRateConversionPattern(MLIRContext *context, TypeConverter &converter,
440 ConvertedOps *joinedOps)
441 : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context),
442 joinedOps(joinedOps) {}
443 using ConversionPattern::ConversionPattern;
444
445 // Generic pattern which replaces an operation by one of the same type, but
446 // with the in- and outputs synchronized through join semantics.
447 LogicalResult
448 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
449 ConversionPatternRewriter &rewriter) const override {
450 llvm::SmallVector<Value> inputData;
451
452 Value outToken;
453 if (operands.empty()) {
454 if (!op->hasTrait<OpTrait::ConstantLike>())
455 return op->emitOpError(
456 "no-operand operation which isn't constant-like. Too dangerous "
457 "to assume semantics - won't convert");
458
459 // Constant-like operation; assume the token can be represented as a
460 // constant `dc.source`.
461 outToken = rewriter.create<dc::SourceOp>(op->getLoc());
462 } else {
463 llvm::SmallVector<Value> inputTokens;
464 for (auto input : operands) {
465 auto dct = unpack(rewriter, input);
466 inputData.push_back(dct.data);
467 inputTokens.push_back(dct.token);
468 }
469 // Join the tokens of the inputs.
470 assert(!inputTokens.empty() && "Expected at least one input token");
471 outToken = rewriter.create<dc::JoinOp>(op->getLoc(), inputTokens);
472 }
473
474 // Patchwork to fix bad IR design in Handshake.
475 auto opName = op->getName();
476 if (opName.getStringRef() == "handshake.select") {
477 opName = OperationName("arith.select", getContext());
478 } else if (opName.getStringRef() == "handshake.constant") {
479 opName = OperationName("arith.constant", getContext());
480 }
481
482 // Re-create the operation using the unpacked input data.
483 OperationState state(op->getLoc(), opName, inputData, op->getResultTypes(),
484 op->getAttrs(), op->getSuccessors());
485
486 Operation *newOp = rewriter.create(state);
487 joinedOps->insert(newOp);
488
489 // Pack the result token with the output data, and replace the uses.
490 llvm::SmallVector<Value> results;
491 for (auto result : newOp->getResults())
492 results.push_back(pack(rewriter, outToken, result));
493
494 rewriter.replaceOp(op, results);
495
496 return success();
497 }
498
499 mutable ConvertedOps *joinedOps;
500};
501
502class SinkOpConversionPattern
503 : public DCOpConversionPattern<handshake::SinkOp> {
504public:
505 using DCOpConversionPattern<handshake::SinkOp>::DCOpConversionPattern;
506 using OpAdaptor = typename handshake::SinkOp::Adaptor;
507
508 LogicalResult
509 matchAndRewrite(handshake::SinkOp op, OpAdaptor adaptor,
510 ConversionPatternRewriter &rewriter) const override {
511 auto input = unpack(rewriter, adaptor.getOperand());
512 rewriter.replaceOpWithNewOp<dc::SinkOp>(op, input.token);
513 return success();
514 }
515};
516
517class SourceOpConversionPattern
518 : public DCOpConversionPattern<handshake::SourceOp> {
519public:
520 using DCOpConversionPattern<handshake::SourceOp>::DCOpConversionPattern;
521 using OpAdaptor = typename handshake::SourceOp::Adaptor;
522
523 LogicalResult
524 matchAndRewrite(handshake::SourceOp op, OpAdaptor adaptor,
525 ConversionPatternRewriter &rewriter) const override {
526 rewriter.replaceOpWithNewOp<dc::SourceOp>(op);
527 return success();
528 }
529};
530
531class BufferOpConversion : public DCOpConversionPattern<handshake::BufferOp> {
532public:
533 using DCOpConversionPattern<handshake::BufferOp>::DCOpConversionPattern;
534 using OpAdaptor = typename handshake::BufferOp::Adaptor;
535
536 LogicalResult
537 matchAndRewrite(handshake::BufferOp op, OpAdaptor adaptor,
538 ConversionPatternRewriter &rewriter) const override {
539 rewriter.getI32IntegerAttr(1);
540 rewriter.replaceOpWithNewOp<dc::BufferOp>(
541 op, adaptor.getOperand(), static_cast<size_t>(op.getNumSlots()),
542 op.getInitValuesAttr());
543 return success();
544 }
545};
546
547class ReturnOpConversion : public OpConversionPattern<handshake::ReturnOp> {
548public:
549 using OpConversionPattern<handshake::ReturnOp>::OpConversionPattern;
550 using OpAdaptor = typename handshake::ReturnOp::Adaptor;
551
552 LogicalResult
553 matchAndRewrite(handshake::ReturnOp op, OpAdaptor adaptor,
554 ConversionPatternRewriter &rewriter) const override {
555 // Locate existing output op, Append operands to output op, and move to
556 // the end of the block.
557 auto hwModule = op->getParentOfType<hw::HWModuleOp>();
558 auto outputOp = *hwModule.getBodyBlock()->getOps<hw::OutputOp>().begin();
559 outputOp->setOperands(adaptor.getOperands());
560 outputOp->moveAfter(&hwModule.getBodyBlock()->back());
561 rewriter.eraseOp(op);
562 return success();
563 }
564};
565
566class MuxOpConversionPattern : public DCOpConversionPattern<handshake::MuxOp> {
567public:
568 using DCOpConversionPattern<handshake::MuxOp>::DCOpConversionPattern;
569 using OpAdaptor = typename handshake::MuxOp::Adaptor;
570
571 LogicalResult
572 matchAndRewrite(handshake::MuxOp op, OpAdaptor adaptor,
573 ConversionPatternRewriter &rewriter) const override {
574 auto select = unpack(rewriter, adaptor.getSelectOperand());
575 auto selectData = select.data;
576 auto selectToken = select.token;
577 bool isIndexType = isa<IndexType>(selectData.getType());
578
579 bool withData = !isa<NoneType>(op.getResult().getType());
580
581 llvm::SmallVector<DCTuple> inputs;
582 for (auto input : adaptor.getDataOperands())
583 inputs.push_back(unpack(rewriter, input));
584
585 Value dataMux;
586 Value controlMux = inputs.front().token;
587 // Convert the data-side mux to a sequence of arith.select operations.
588 // The data and control muxes are assumed one-hot and the base-case is set
589 // as the first input.
590 if (withData)
591 dataMux = inputs[0].data;
592
593 llvm::SmallVector<Value> controlMuxInputs = {inputs.front().token};
594 for (auto [i, input] :
595 llvm::enumerate(llvm::make_range(inputs.begin() + 1, inputs.end()))) {
596 if (!withData)
597 continue;
598
599 Value cmpIndex;
600 Value inputData = input.data;
601 Value inputControl = input.token;
602 if (isIndexType) {
603 cmpIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), i);
604 } else {
605 size_t width = cast<IntegerType>(selectData.getType()).getWidth();
606 cmpIndex = rewriter.create<arith::ConstantIntOp>(op.getLoc(), i, width);
607 }
608 auto inputSelected = rewriter.create<arith::CmpIOp>(
609 op.getLoc(), arith::CmpIPredicate::eq, selectData, cmpIndex);
610 dataMux = rewriter.create<arith::SelectOp>(op.getLoc(), inputSelected,
611 inputData, dataMux);
612
613 // Legalize the newly created operations.
614 convertedOps->insert(cmpIndex.getDefiningOp());
615 convertedOps->insert(dataMux.getDefiningOp());
616 convertedOps->insert(inputSelected);
617
618 // And similarly for the control mux, by muxing the input token with a
619 // select value that has it's control from the original select token +
620 // the inputSelected value.
621 auto inputSelectedControl = pack(rewriter, selectToken, inputSelected);
622 controlMux = rewriter.create<dc::SelectOp>(
623 op.getLoc(), inputSelectedControl, inputControl, controlMux);
624 convertedOps->insert(controlMux.getDefiningOp());
625 }
626
627 // finally, pack the control and data side muxes into the output value.
628 rewriter.replaceOp(
629 op, pack(rewriter, controlMux, withData ? dataMux : Value{}));
630 return success();
631 }
632};
633
634static hw::ModulePortInfo getModulePortInfoHS(const TypeConverter &tc,
635 handshake::FuncOp funcOp) {
636 SmallVector<hw::PortInfo> inputs, outputs;
637 auto ft = funcOp.getFunctionType();
638 funcOp.resolveArgAndResNames();
639
640 // Add all inputs of funcOp.
641 for (auto [index, type] : llvm::enumerate(ft.getInputs()))
642 inputs.push_back({{funcOp.getArgName(index), tc.convertType(type),
643 hw::ModulePort::Direction::Input},
644 index,
645 {}});
646
647 // Add all outputs of funcOp.
648 for (auto [index, type] : llvm::enumerate(ft.getResults()))
649 outputs.push_back({{funcOp.getResName(index), tc.convertType(type),
650 hw::ModulePort::Direction::Output},
651 index,
652 {}});
653
654 return hw::ModulePortInfo{inputs, outputs};
655}
656
657class FuncOpConversion : public DCOpConversionPattern<handshake::FuncOp> {
658public:
659 using DCOpConversionPattern<handshake::FuncOp>::DCOpConversionPattern;
660 using OpAdaptor = typename handshake::FuncOp::Adaptor;
661
662 // Replaces a handshake.func with a hw.module, converting the argument and
663 // result types using the provided type converter.
664 // @mortbopet: Not a fan of converting to hw here seeing as we don't
665 // necessarily have hardware semantics here. But, DC doesn't define a
666 // function operation, and there is no "func.graph_func" or any other
667 // generic function operation which is a graph region...
668 LogicalResult
669 matchAndRewrite(handshake::FuncOp op, OpAdaptor adaptor,
670 ConversionPatternRewriter &rewriter) const override {
671 ModulePortInfo ports = getModulePortInfoHS(*getTypeConverter(), op);
672
673 if (op.isExternal()) {
674 auto mod = rewriter.create<hw::HWModuleExternOp>(
675 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
676 convertedOps->insert(mod);
677 } else {
678 auto hwModule = rewriter.create<hw::HWModuleOp>(
679 op.getLoc(), rewriter.getStringAttr(op.getName()), ports);
680
681 auto &region = op->getRegions().front();
682
683 Region &moduleRegion = hwModule->getRegions().front();
684 rewriter.mergeBlocks(&region.getBlocks().front(), hwModule.getBodyBlock(),
685 hwModule.getBodyBlock()->getArguments());
686 TypeConverter::SignatureConversion result(moduleRegion.getNumArguments());
687 (void)getTypeConverter()->convertSignatureArgs(
688 TypeRange(moduleRegion.getArgumentTypes()), result);
689 rewriter.applySignatureConversion(hwModule.getBodyBlock(), result);
690 convertedOps->insert(hwModule);
691 }
692
693 rewriter.eraseOp(op);
694 return success();
695 }
696};
697
698/// Lower the ESIInstanceOp to `hw.instance` with `dc.from_esi` and `dc.to_esi`
699/// to convert the args/results.
700class ESIInstanceConversionPattern
701 : public OpConversionPattern<handshake::ESIInstanceOp> {
702public:
703 ESIInstanceConversionPattern(MLIRContext *context,
704 const HWSymbolCache &symCache)
705 : OpConversionPattern(context), symCache(symCache) {}
706
707 LogicalResult
708 matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
709 ConversionPatternRewriter &rewriter) const override {
710 Location loc = op.getLoc();
711 SmallVector<Value> operands;
712 for (size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
713 i < e; ++i)
714 operands.push_back(
715 rewriter.create<dc::FromESIOp>(loc, adaptor.getOperands()[i]));
716 operands.push_back(adaptor.getClk());
717 operands.push_back(adaptor.getRst());
718 // Locate the lowered module so the instance builder can get all the
719 // metadata.
720 Operation *targetModule = symCache.getDefinition(op.getModuleAttr());
721 // And replace the op with an instance of the target module.
722 auto inst = rewriter.create<hw::InstanceOp>(loc, targetModule,
723 op.getInstNameAttr(), operands);
724 SmallVector<Value> esiResults(
725 llvm::map_range(inst.getResults(), [&](Value v) {
726 return rewriter.create<dc::ToESIOp>(loc, v);
727 }));
728 rewriter.replaceOp(op, esiResults);
729 return success();
730 }
731
732private:
733 const HWSymbolCache &symCache;
734};
735
736/// Add DC clock and reset ports to the module.
737void addClkRst(hw::HWModuleOp mod, StringRef clkName, StringRef rstName) {
738 auto *ctx = mod.getContext();
739
740 size_t numInputs = mod.getNumInputPorts();
741 mod.insertInput(numInputs, clkName, seq::ClockType::get(ctx));
742 mod.setPortAttrs(
743 numInputs,
744 DictionaryAttr::get(ctx, {NamedAttribute(StringAttr::get(ctx, "dc.clock"),
745 UnitAttr::get(ctx))}));
746 mod.insertInput(numInputs + 1, rstName, IntegerType::get(ctx, 1));
747 mod.setPortAttrs(
748 numInputs + 1,
749 DictionaryAttr::get(ctx, {NamedAttribute(StringAttr::get(ctx, "dc.reset"),
750 UnitAttr::get(ctx))}));
751
752 // We must initialize any port attributes that are not set otherwise the
753 // verifier will fail.
754 for (size_t portNum = 0, e = mod.getNumPorts(); portNum < e; ++portNum) {
755 auto attrs = dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(portNum));
756 if (attrs)
757 continue;
758 mod.setPortAttrs(portNum, DictionaryAttr::get(ctx, {}));
759 }
760}
761
762class HandshakeToDCPass
763 : public circt::impl::HandshakeToDCBase<HandshakeToDCPass> {
764public:
765 using Base::Base;
766 void runOnOperation() override {
767 mlir::ModuleOp mod = getOperation();
768 auto patternBuilder = [&](TypeConverter &typeConverter,
769 handshaketodc::ConvertedOps &convertedOps,
770 RewritePatternSet &patterns) {
771 patterns.add<FuncOpConversion>(mod.getContext(), typeConverter,
772 &convertedOps);
773 patterns.add<ReturnOpConversion>(typeConverter, mod.getContext());
774 };
775
776 LogicalResult res =
777 runHandshakeToDC(mod, circt::HandshakeToDCOptions{clkName, rstName},
778 patternBuilder, nullptr);
779 if (failed(res))
780 signalPassFailure();
781 }
782};
783} // namespace
784
786 mlir::Operation *op, circt::HandshakeToDCOptions options,
787 llvm::function_ref<void(TypeConverter &typeConverter,
788 handshaketodc::ConvertedOps &convertedOps,
789 RewritePatternSet &patterns)>
790 patternBuilder,
791 llvm::function_ref<void(mlir::ConversionTarget &)> configureTarget) {
792 // Maintain the set of operations which has been converted either through
793 // unit rate conversion, or as part of other conversions.
794 // Rationale:
795 // This is needed for all of the arith ops that get created as part of the
796 // handshake ops (e.g. arith.select for handshake.mux). There's a bit of a
797 // dilemma here seeing as all operations need to be converted/touched in a
798 // handshake.func - which is done so by UnitRateConversionPattern (when no
799 // other pattern applies). However, we obviously don't want to run said
800 // pattern on these newly created ops since they do not have handshake
801 // semantics.
802 handshaketodc::ConvertedOps convertedOps;
803 mlir::MLIRContext *ctx = op->getContext();
804 ConversionTarget target(*ctx);
805 target.addIllegalDialect<handshake::HandshakeDialect>();
806 target.addLegalDialect<dc::DCDialect>();
807 target.addLegalOp<mlir::ModuleOp, handshake::ESIInstanceOp, hw::HWModuleOp,
808 hw::OutputOp>();
809
810 // And any user-specified target adjustments
811 if (configureTarget)
812 configureTarget(target);
813
814 // The various patterns will insert new operations into the module to
815 // facilitate the conversion - however, these operations must be
816 // distinguishable from already converted operations (which may be of the
817 // same type as the newly inserted operations). To do this, we mark all
818 // operations which have been converted as legal, and all other operations
819 // as illegal.
820 target.markUnknownOpDynamicallyLegal([&](Operation *op) {
821 return convertedOps.contains(op) ||
822 // Allow any ops which weren't in a `handshake.func` to pass through.
823 !convertedOps.contains(op->getParentOfType<hw::HWModuleOp>());
824 });
825
826 DCTypeConverter typeConverter;
827 RewritePatternSet patterns(ctx);
828
829 // Add handshake conversion patterns.
830 // Note: merge/control merge are not supported - these are non-deterministic
831 // operators and we do not care for them.
833 .add<BufferOpConversion, CondBranchConversionPattern,
834 SinkOpConversionPattern, SourceOpConversionPattern,
835 MuxOpConversionPattern, ForkOpConversionPattern, JoinOpConversion,
836 PackOpConversion, UnpackOpConversion, MergeOpConversion,
837 ControlMergeOpConversion, ConstantOpConversion, SyncOpConversion>(
838 ctx, typeConverter, &convertedOps);
839
840 // ALL other single-result operations are converted via the
841 // UnitRateConversionPattern.
842 patterns.add<UnitRateConversionPattern>(ctx, typeConverter, &convertedOps);
843
844 // Build any user-specified patterns
845 patternBuilder(typeConverter, convertedOps, patterns);
846 if (failed(applyPartialConversion(op, target, std::move(patterns))))
847 return failure();
848
849 // Add clock and reset ports to each converted module.
850 for (auto &op : convertedOps)
851 if (auto mod = dyn_cast<hw::HWModuleOp>(op); mod)
852 addClkRst(mod, options.clkName, options.rstName);
853
854 // Run conversions which need see everything.
855 HWSymbolCache symbolCache;
856 symbolCache.addDefinitions(op);
857 symbolCache.freeze();
858 ConversionTarget globalLoweringTarget(*ctx);
859 globalLoweringTarget.addIllegalDialect<handshake::HandshakeDialect>();
860 globalLoweringTarget.addLegalDialect<dc::DCDialect, hw::HWDialect>();
861 RewritePatternSet globalPatterns(ctx);
862 globalPatterns.add<ESIInstanceConversionPattern>(ctx, symbolCache);
863 if (failed(applyPartialConversion(op, globalLoweringTarget,
864 std::move(globalPatterns))))
865 return op->emitOpError() << "error during conversion";
866
867 return success();
868}
assert(baseType &&"element must be base type")
static Type tupleToStruct(TupleType tuple)
Definition DCToHW.cpp:48
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition HWSymCache.h:27
void freeze()
Mark the cache as frozen, which allows it to be shared across threads.
Definition HWSymCache.h:75
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition HWSymCache.h:56
mlir::Type innerType(mlir::Type type)
Definition ESITypes.cpp:227
DenseSet< Operation * > ConvertedOps
LogicalResult runHandshakeToDC(mlir::Operation *op, HandshakeToDCOptions options, llvm::function_ref< void(TypeConverter &typeConverter, ConvertedOps &convertedOps, RewritePatternSet &patterns)> patternBuilder, llvm::function_ref< void(mlir::ConversionTarget &)> configureTarget={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
This holds a decoded list of input/inout and output ports for a module or instance.
Creates a new Calyx component for each FuncOp in the program.