CIRCT 23.0.0git
Loading...
Searching...
No Matches
ESIRuntimeTypedPortsTest.cpp
Go to the documentation of this file.
1//===- ESIRuntimeTypedPortsTest.cpp - Typed ESI port tests ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "esi/TypedPorts.h"
#include "gtest/gtest.h"

#include <memory>
11
12using namespace esi;
13
14namespace {
15
16//===----------------------------------------------------------------------===//
17// verifyTypeCompatibility tests
18//===----------------------------------------------------------------------===//
19
20TEST(TypedPortsTest, VoidTypeCompatibility) {
21 VoidType voidType("void");
22 EXPECT_NO_THROW(verifyTypeCompatibility<void>(&voidType));
23
24 // Non-void types should fail.
25 UIntType uint1("ui1", 1);
26 EXPECT_THROW(verifyTypeCompatibility<void>(&uint1), AcceleratorMismatchError);
27
28 SIntType sint32("si32", 32);
29 EXPECT_THROW(verifyTypeCompatibility<void>(&sint32),
31}
32
33TEST(TypedPortsTest, BoolTypeCompatibility) {
34 BitsType bits1("i1", 1);
35 EXPECT_NO_THROW(verifyTypeCompatibility<bool>(&bits1));
36
37 // Width > 1 should fail.
38 BitsType bits8("i8", 8);
39 EXPECT_THROW(verifyTypeCompatibility<bool>(&bits8), AcceleratorMismatchError);
40
41 // Wrong type entirely should fail.
42 SIntType sint1("si1", 1);
43 EXPECT_THROW(verifyTypeCompatibility<bool>(&sint1), AcceleratorMismatchError);
44}
45
46TEST(TypedPortsTest, SignedIntTypeCompatibility) {
47 // int32_t can hold si17 (width 17, in range (16,32]).
48 SIntType sint17("si17", 17);
49 EXPECT_NO_THROW(verifyTypeCompatibility<int32_t>(&sint17));
50
51 // si32 has width 32, which fits exactly in int32_t. Should pass.
52 SIntType sint32("si32", 32);
53 EXPECT_NO_THROW(verifyTypeCompatibility<int32_t>(&sint32));
54
55 // si33 has width 33, which exceeds int32_t. Should fail.
56 SIntType sint33("si33", 33);
57 EXPECT_THROW(verifyTypeCompatibility<int32_t>(&sint33),
59
60 // si16 fits in int32_t but a smaller type (int16_t) would suffice. Reject.
61 SIntType sint16("si16", 16);
62 EXPECT_THROW(verifyTypeCompatibility<int32_t>(&sint16),
64
65 // si8 is even smaller — also reject for int32_t.
66 SIntType sint8("si8", 8);
67 EXPECT_THROW(verifyTypeCompatibility<int32_t>(&sint8),
69
70 // But si8 should be fine for int8_t (closest match).
71 EXPECT_NO_THROW(verifyTypeCompatibility<int8_t>(&sint8));
72
73 // UIntType should fail for signed C++ type.
74 UIntType uint31("ui31", 31);
75 EXPECT_THROW(verifyTypeCompatibility<int32_t>(&uint31),
77
78 // int64_t can hold si33 (width 33, in range (32,64]).
79 SIntType sint33b("si33", 33);
80 EXPECT_NO_THROW(verifyTypeCompatibility<int64_t>(&sint33b));
81
82 // si64 fits exactly in int64_t. Should pass.
83 SIntType sint64("si64", 64);
84 EXPECT_NO_THROW(verifyTypeCompatibility<int64_t>(&sint64));
85
86 // si65 exceeds int64_t. Should fail.
87 SIntType sint65("si65", 65);
88 EXPECT_THROW(verifyTypeCompatibility<int64_t>(&sint65),
90
91 // si32 fits in int64_t but int32_t would suffice. Reject.
92 EXPECT_THROW(verifyTypeCompatibility<int64_t>(&sint32),
94}
95
96TEST(TypedPortsTest, UnsignedIntTypeCompatibility) {
97 // uint32_t can hold ui17 (width 17, in range (16,32]).
98 UIntType uint17("ui17", 17);
99 EXPECT_NO_THROW(verifyTypeCompatibility<uint32_t>(&uint17));
100
101 // ui32 has width 32, which fits exactly in uint32_t. Should pass.
102 UIntType uint32_t_("ui32", 32);
103 EXPECT_NO_THROW(verifyTypeCompatibility<uint32_t>(&uint32_t_));
104
105 // ui33 exceeds uint32_t. Should fail.
106 UIntType uint33("ui33", 33);
107 EXPECT_THROW(verifyTypeCompatibility<uint32_t>(&uint33),
109
110 // ui16 fits but uint16_t would suffice. Reject for uint32_t.
111 UIntType uint16_("ui16", 16);
112 EXPECT_THROW(verifyTypeCompatibility<uint32_t>(&uint16_),
114
115 // But ui16 should be fine for uint16_t.
116 EXPECT_NO_THROW(verifyTypeCompatibility<uint16_t>(&uint16_));
117
118 // BitsType (signless iM) should also be accepted for unsigned.
119 BitsType bits17("i17", 17);
120 EXPECT_NO_THROW(verifyTypeCompatibility<uint32_t>(&bits17));
121
122 // BitsType with width 32 fits in uint32_t. Should pass.
123 BitsType bits32("i32", 32);
124 EXPECT_NO_THROW(verifyTypeCompatibility<uint32_t>(&bits32));
125
126 // BitsType with width 33 exceeds uint32_t. Should fail.
127 BitsType bits33("i33", 33);
128 EXPECT_THROW(verifyTypeCompatibility<uint32_t>(&bits33),
130
131 // BitsType width 8 should be rejected for uint32_t (uint8_t suffices).
132 BitsType bits8("i8", 8);
133 EXPECT_THROW(verifyTypeCompatibility<uint32_t>(&bits8),
135 EXPECT_NO_THROW(verifyTypeCompatibility<uint8_t>(&bits8));
136
137 // uint64_t with ui33 (in range (32,64]).
138 UIntType uint33b("ui33", 33);
139 EXPECT_NO_THROW(verifyTypeCompatibility<uint64_t>(&uint33b));
140
141 // uint64_t with ui64 fits exactly. Should pass.
142 UIntType uint64_t_("ui64", 64);
143 EXPECT_NO_THROW(verifyTypeCompatibility<uint64_t>(&uint64_t_));
144
145 // uint64_t with ui65 exceeds. Should fail.
146 UIntType uint65("ui65", 65);
147 EXPECT_THROW(verifyTypeCompatibility<uint64_t>(&uint65),
149
150 // uint64_t with ui32 — uint32_t would suffice. Reject.
151 EXPECT_THROW(verifyTypeCompatibility<uint64_t>(&uint32_t_),
153
154 // SIntType should fail for unsigned C++ type.
155 SIntType sint31("si31", 31);
156 EXPECT_THROW(verifyTypeCompatibility<uint32_t>(&sint31),
158}
159
// Stand-in for a generated C++ struct. Its static _ESI_ID string is the ID that
// verifyTypeCompatibility matches against the port type's ID.
struct TestStruct {
  static constexpr std::string_view _ESI_ID{"MyModule.TestStruct"};
  uint32_t field1;
  uint16_t field2;
};
166
167TEST(TypedPortsTest, ESIIDTypeCompatibility) {
168 // Matching ID should pass.
169 StructType matchType("MyModule.TestStruct", {});
170 EXPECT_NO_THROW(verifyTypeCompatibility<TestStruct>(&matchType));
171
172 // Mismatched ID should fail.
173 StructType mismatchType("OtherModule.OtherStruct", {});
174 EXPECT_THROW(verifyTypeCompatibility<TestStruct>(&mismatchType),
176
177 // Even a non-struct type with matching ID should pass (ID comparison only).
178 UIntType uintWithMatchingID("MyModule.TestStruct", 32);
179 EXPECT_NO_THROW(verifyTypeCompatibility<TestStruct>(&uintWithMatchingID));
180}
181
182TEST(TypedPortsTest, NullPortTypeThrows) {
183 EXPECT_THROW(verifyTypeCompatibility<int32_t>(nullptr),
185 EXPECT_THROW(verifyTypeCompatibility<void>(nullptr),
187}
188
// Neither integral nor tagged with an _ESI_ID: verifyTypeCompatibility has no
// matching rule for this type and is expected to take its fallback error path.
struct UnknownCppType { double x; };
193
194TEST(TypedPortsTest, FallbackThrows) {
195 UIntType uint32("ui32", 32);
196 EXPECT_THROW(verifyTypeCompatibility<UnknownCppType>(&uint32),
198}
199
200//===----------------------------------------------------------------------===//
201// TypedWritePort round-trip tests (verify MessageData encoding)
202//===----------------------------------------------------------------------===//
203
204// A minimal concrete WriteChannelPort for testing. Captures the last written
205// MessageData instead of sending it anywhere.
206class MockWritePort : public WriteChannelPort {
207public:
208 MockWritePort(const Type *type) : WriteChannelPort(type) {}
209
210 void connect(const ConnectOptions &opts = {}) override {
211 connectImpl(opts);
212 connected = true;
213 }
214 void disconnect() override { connected = false; }
215 bool isConnected() const override { return connected; }
216
217 MessageData lastWritten;
218
219protected:
220 void writeImpl(const MessageData &data) override { lastWritten = data; }
221 bool tryWriteImpl(const MessageData &data) override {
222 lastWritten = data;
223 return true;
224 }
225
226private:
227 bool connected = false;
228};
229
230TEST(TypedPortsTest, TypedWritePortConnectThrowsOnMismatch) {
231 UIntType uint32("ui32", 32);
232 MockWritePort mock(&uint32);
233 TypedWritePort<int32_t> typed(mock); // int32_t expects SIntType
234 EXPECT_THROW(typed.connect(), AcceleratorMismatchError);
235}
236
237TEST(TypedPortsTest, TypedWritePortConnectSucceeds) {
238 SIntType sint31("si31", 31);
239 MockWritePort mock(&sint31);
240 TypedWritePort<int32_t> typed(mock);
241 EXPECT_NO_THROW(typed.connect());
242 EXPECT_TRUE(typed.isConnected());
243}
244
245TEST(TypedPortsTest, TypedWritePortRoundTrip) {
246 SIntType sint15("si15", 15);
247 MockWritePort mock(&sint15);
248 TypedWritePort<int16_t> typed(mock);
249 typed.connect();
250
251 int16_t val = 12345;
252 typed.write(val);
253
254 // Wire size for si15 is 2 bytes ((15+7)/8).
255 ASSERT_EQ(mock.lastWritten.getSize(), 2u);
256}
257
258TEST(TypedPortsTest, SignExtensionNonByteAligned) {
259 // Test fromMessageData sign extension for non-byte-aligned widths.
260 // si4 value -1 is wire 0x0F (4 bits: 1111). Sign bit is bit 3.
261 {
262 SIntType si4("si4", 4);
263 WireInfo wi = getWireInfo(&si4);
264 EXPECT_EQ(wi.bytes, 1u);
265 EXPECT_EQ(wi.bitWidth, 4u);
266 uint8_t wire = 0x0F; // -1 in si4
267 MessageData msg(&wire, 1);
268 int32_t val = fromMessageData<int32_t>(msg, wi);
269 EXPECT_EQ(val, -1);
270 }
271 // si4 value 7 is wire 0x07 (4 bits: 0111). Positive, no sign extension.
272 {
273 SIntType si4("si4", 4);
274 WireInfo wi = getWireInfo(&si4);
275 uint8_t wire = 0x07;
276 MessageData msg(&wire, 1);
277 int32_t val = fromMessageData<int32_t>(msg, wi);
278 EXPECT_EQ(val, 7);
279 }
280 // si22 value -1 is wire {0xFF, 0xFF, 0x3F} (22 bits all 1s).
281 // Sign bit is bit 21 = bit 5 of byte 2, mask 0x20.
282 {
283 SIntType si22("si22", 22);
284 WireInfo wi = getWireInfo(&si22);
285 EXPECT_EQ(wi.bytes, 3u);
286 uint8_t wire[3] = {0xFF, 0xFF, 0x3F};
287 MessageData msg(wire, 3);
288 int32_t val = fromMessageData<int32_t>(msg, wi);
289 EXPECT_EQ(val, -1);
290 }
291 // si22 positive value: 0x1FFFFF (all data bits 1, sign bit 0)
292 {
293 SIntType si22("si22", 22);
294 WireInfo wi = getWireInfo(&si22);
295 uint8_t wire[3] = {0xFF, 0xFF, 0x1F}; // bit 21 = 0
296 MessageData msg(wire, 3);
297 int32_t val = fromMessageData<int32_t>(msg, wi);
298 EXPECT_EQ(val, 0x1FFFFF); // 2097151
299 }
300}
301
302TEST(TypedPortsTest, TypedWritePortVoid) {
303 VoidType voidType("void");
304 MockWritePort mock(&voidType);
305 TypedWritePort<void> typed(mock);
306 EXPECT_NO_THROW(typed.connect());
307
308 typed.write();
309 ASSERT_EQ(mock.lastWritten.getSize(), 1u);
310 EXPECT_EQ(mock.lastWritten.getData()[0], 0);
311}
312
313//===----------------------------------------------------------------------===//
314// MockReadPort for TypedFunction testing
315//===----------------------------------------------------------------------===//
316
317// A minimal concrete ReadChannelPort that returns a preset response.
318class MockReadPort : public ReadChannelPort {
319public:
320 MockReadPort(const Type *type) : ReadChannelPort(type) {}
321
322 void connect(std::function<bool(MessageData)>,
323 const ConnectOptions & = {}) override {
324 mode = Mode::Callback;
325 }
326 void connect(const ConnectOptions & = {}) override { mode = Mode::Polling; }
327
328 void read(MessageData &outData) override { outData = nextResponse; }
329 std::future<MessageData> readAsync() override {
330 std::promise<MessageData> p;
331 p.set_value(nextResponse);
332 return p.get_future();
333 }
334
335 MessageData nextResponse;
336};
337
338//===----------------------------------------------------------------------===//
339// TypedFunction tests
340//===----------------------------------------------------------------------===//
341
342TEST(TypedPortsTest, TypedFunctionNullThrowsAtConnect) {
343 // Null is accepted at construction but throws at connect().
345 EXPECT_THROW(typed.connect(), AcceleratorMismatchError);
346}
347
348TEST(TypedPortsTest, TypedFunctionConnectVerifiesTypes) {
349 // Create channel types matching si24 arg and ui16 result.
350 SIntType argInner("si24", 24);
351 ChannelType argChanType("channel<si24>", &argInner);
352 UIntType resultInner("ui15", 15);
353 ChannelType resultChanType("channel<ui15>", &resultInner);
354
355 BundleType::ChannelVector channels = {
356 {"arg", BundleType::Direction::To, &argChanType},
357 {"result", BundleType::Direction::From, &resultChanType},
358 };
359 BundleType bundleType("func_bundle", channels);
360
361 MockWritePort mockWrite(&argInner);
362 MockReadPort mockRead(&resultInner);
363
364 auto *func = services::FuncService::Function::get(AppID("test"), &bundleType,
365 mockWrite, mockRead);
366
367 // int32_t arg (signed) against si24 — should pass.
368 // uint16_t result (unsigned) against ui15 — should pass.
370 EXPECT_NO_THROW(typed.connect());
371 delete func;
372}
373
374TEST(TypedPortsTest, TypedFunctionConnectRejectsArgMismatch) {
375 UIntType argInner("ui24", 24);
376 ChannelType argChanType("channel<ui24>", &argInner);
377 UIntType resultInner("ui15", 15);
378 ChannelType resultChanType("channel<ui15>", &resultInner);
379
380 BundleType::ChannelVector channels = {
381 {"arg", BundleType::Direction::To, &argChanType},
382 {"result", BundleType::Direction::From, &resultChanType},
383 };
384 BundleType bundleType("func_bundle", channels);
385
386 MockWritePort mockWrite(&argInner);
387 MockReadPort mockRead(&resultInner);
388
389 auto *func = services::FuncService::Function::get(AppID("test"), &bundleType,
390 mockWrite, mockRead);
391
392 // int32_t (signed) against UIntType — should fail at connect.
394 EXPECT_THROW(typed.connect(), AcceleratorMismatchError);
395 delete func;
396}
397
398TEST(TypedPortsTest, TypedFunctionCallRoundTrip) {
399 SIntType argInner("si24", 24);
400 ChannelType argChanType("channel<si24>", &argInner);
401 UIntType resultInner("ui15", 15);
402 ChannelType resultChanType("channel<ui15>", &resultInner);
403
404 BundleType::ChannelVector channels = {
405 {"arg", BundleType::Direction::To, &argChanType},
406 {"result", BundleType::Direction::From, &resultChanType},
407 };
408 BundleType bundleType("func_bundle", channels);
409
410 MockWritePort mockWrite(&argInner);
411 MockReadPort mockRead(&resultInner);
412
413 // Set up mock read to return a known uint16_t value.
414 uint16_t expected = 42;
415 mockRead.nextResponse = MessageData::from(expected);
416
417 auto *func = services::FuncService::Function::get(AppID("test"), &bundleType,
418 mockWrite, mockRead);
419
421 typed.connect();
422
423 int32_t arg = 100;
424 uint16_t result = typed.call(arg).get();
425 EXPECT_EQ(result, 42);
426
427 // Verify the written arg matches — si24 wire size is 3 bytes.
428 ASSERT_EQ(mockWrite.lastWritten.getSize(), 3u);
429 delete func;
430}
431
432} // namespace
Bits are just an array of bits.
Definition Types.h:199
Bundles represent a collection of channels.
Definition Types.h:99
std::vector< std::tuple< std::string, Direction, const Type * > > ChannelVector
Definition Types.h:104
virtual void connectImpl(const ConnectOptions &options)
Called by all connect methods to let backends initiate the underlying connections.
Definition Ports.h:202
Channels are the basic communication primitives.
Definition Types.h:120
A logical chunk of data representing serialized data.
Definition Common.h:113
static MessageData from(T &t)
Cast from a type to its raw bytes.
Definition Common.h:158
A ChannelPort which reads data from the accelerator.
Definition Ports.h:318
virtual std::future< MessageData > readAsync()
Asynchronous read.
Definition Ports.cpp:126
virtual void connect(std::function< bool(MessageData)> callback, const ConnectOptions &options={})
Definition Ports.cpp:69
virtual void read(MessageData &outData)
Specify a buffer to read into.
Definition Ports.h:358
Signed integer.
Definition Types.h:217
Structs are an ordered collection of fields, each with a name and a type.
Definition Types.h:239
Root class of the ESI type system.
Definition Types.h:36
Unsigned integer.
Definition Types.h:228
The "void" type is a special type which can be used to represent no type.
Definition Types.h:136
A ChannelPort which sends data to the accelerator.
Definition Ports.h:206
virtual bool isConnected() const override
Definition Ports.h:218
virtual void disconnect() override
Definition Ports.h:217
virtual bool tryWriteImpl(const MessageData &data)=0
Implementation for tryWrite(). Subclasses must implement this.
volatile bool connected
Definition Ports.h:290
virtual void connect(const ConnectOptions &options={}) override
Set up a connection to the accelerator.
Definition Ports.h:210
virtual void writeImpl(const MessageData &)=0
Implementation for write(). Subclasses must implement this.
static Function * get(AppID id, BundleType *type, WriteChannelPort &arg, ReadChannelPort &result)
Definition Services.cpp:286
Definition esi.py:1
WireInfo getWireInfo(const Type *portType)
Definition TypedPorts.h:69
Compute the wire byte count for a port type.
Definition TypedPorts.h:64
size_t bitWidth
Definition TypedPorts.h:66