CIRCT 23.0.0git
Loading...
Searching...
No Matches
loopback_typed.cpp
Go to the documentation of this file.
1#include "loopback/LoopbackIP.h"
2
3#include "esi/Accelerator.h"
4#include "esi/CLI.h"
5#include "esi/Manifest.h"
6#include "esi/Services.h"
7#include "esi/TypedPorts.h"
8
9#include <cstdint>
10#include <iostream>
11#include <random>
12#include <stdexcept>
13#include <vector>
14
15using namespace esi;
16
17static void runLoopbackI8(Accelerator *accel) {
18 AppIDPath lastLookup;
19 BundlePort *inPort = accel->resolvePort(
20 {AppID("loopback_inst", 0), AppID("loopback_tohw")}, lastLookup);
21 if (!inPort)
22 throw std::runtime_error("No loopback_tohw port found");
23 BundlePort *outPort = accel->resolvePort(
24 {AppID("loopback_inst", 0), AppID("loopback_fromhw")}, lastLookup);
25 if (!outPort)
26 throw std::runtime_error("No loopback_fromhw port found");
27
28 // Use TypedWritePort and TypedReadPort instead of raw channels.
29 TypedWritePort<uint8_t> toHw(inPort->getRawWrite("recv"));
30 TypedReadPort<uint8_t> fromHw(outPort->getRawRead("send"));
31 toHw.connect();
32 fromHw.connect();
33
34 uint8_t sendVal = 0x5a;
35 toHw.write(sendVal);
36
37 uint8_t got = fromHw.read();
38 if (got != sendVal)
39 throw std::runtime_error("Loopback byte mismatch");
40
41 std::cout << "loopback i8 ok: 0x" << std::hex << (int)got << std::dec << "\n";
42}
43
/// Call the `structFunc` function service with an ArgStruct and check the
/// returned ResultStruct.
///
/// Sends a = 0x1234, b = -7 and expects res.x == int8_t(b + 1) and
/// res.y == b.
///
/// @param accel Accelerator built from the manifest.
/// Throws std::runtime_error when the port is missing or the result does not
/// match the expected values.
44static void runStructFunc(Accelerator *accel) {
45 AppIDPath lastLookup;
46 BundlePort *port = accel->resolvePort({AppID("structFunc")}, lastLookup);
47 if (!port)
48 throw std::runtime_error("No structFunc port found");
49
50 // Use TypedFunction instead of raw FuncService::Function.
 // NOTE(review): the declaration of `func` (source lines 51-52) is missing
 // from this extract; presumably a TypedFunction<esi_system::ArgStruct,
 // esi_system::ResultStruct> built from `port` -- confirm against the file.
53 func.connect();
54
55 esi_system::ArgStruct arg{};
56 arg.a = 0x1234;
57 arg.b = static_cast<int8_t>(-7);
58
 // Blocks on the future until the hardware returns the result.
59 esi_system::ResultStruct res = func.call(arg).get();
60
 // Expected: x is b + 1 (wrapped to int8_t), y echoes b.
61 int8_t expectedX = static_cast<int8_t>(arg.b + 1);
62 if (res.x != expectedX || res.y != arg.b)
63 throw std::runtime_error("Struct func result mismatch");
64
 // Cast to int so the int8_t fields print as numbers, not characters.
65 std::cout << "struct func ok: b=" << (int)arg.b << " x=" << (int)res.x
66 << " y=" << (int)res.y << "\n";
67}
68
/// Call the `oddStructFunc` function service with an OddStruct (including a
/// nested `inner` struct with a two-element array `r`) and check the
/// per-field transformation applied by the hardware.
///
/// Expected result, per the checks below: a + 1, b - 3, inner.p + 5,
/// inner.q + 2, inner.r[0] + 1, inner.r[1] + 2 (each wrapped to the field's
/// C++ type).
///
/// @param accel Accelerator built from the manifest.
/// Throws std::runtime_error when the port is missing or any field mismatches.
69static void runOddStructFunc(Accelerator *accel) {
70 AppIDPath lastLookup;
71 BundlePort *port = accel->resolvePort({AppID("oddStructFunc")}, lastLookup);
72 if (!port)
73 throw std::runtime_error("No oddStructFunc port found");
74
75 // Use TypedFunction with OddStruct for both arg and result.
 // NOTE(review): the declaration of `func` (source lines 76-77) is missing
 // from this extract; presumably a TypedFunction<esi_system::OddStruct,
 // esi_system::OddStruct> built from `port` -- confirm against the file.
78 func.connect();
79
80 esi_system::OddStruct arg{};
81 arg.a = 0xabc;
82 arg.b = static_cast<int8_t>(-17);
83 arg.inner.p = 5;
84 arg.inner.q = static_cast<int8_t>(-7);
85 arg.inner.r[0] = 3;
86 arg.inner.r[1] = 4;
87
88 esi_system::OddStruct res = func.call(arg).get();
89
 // Expected values, truncated to each field's C++ type.
90 uint16_t expectA = static_cast<uint16_t>(arg.a + 1);
91 int8_t expectB = static_cast<int8_t>(arg.b - 3);
92 uint8_t expectP = static_cast<uint8_t>(arg.inner.p + 5);
93 int8_t expectQ = static_cast<int8_t>(arg.inner.q + 2);
94 uint8_t expectR0 = static_cast<uint8_t>(arg.inner.r[0] + 1);
95 uint8_t expectR1 = static_cast<uint8_t>(arg.inner.r[1] + 2);
96 if (res.a != expectA || res.b != expectB || res.inner.p != expectP ||
97 res.inner.q != expectQ || res.inner.r[0] != expectR0 ||
98 res.inner.r[1] != expectR1)
99 throw std::runtime_error("Odd struct func result mismatch");
100
 // 8-bit fields are cast to int so they print as numbers, not characters.
101 std::cout << "odd struct func ok: a=" << res.a << " b=" << (int)res.b
102 << " p=" << (int)res.inner.p << " q=" << (int)res.inner.q
103 << " r0=" << (int)res.inner.r[0] << " r1=" << (int)res.inner.r[1]
104 << "\n";
105}
106
107static void runArrayFunc(Accelerator *accel) {
108 AppIDPath lastLookup;
109 BundlePort *port = accel->resolvePort({AppID("arrayFunc")}, lastLookup);
110 if (!port)
111 throw std::runtime_error("No arrayFunc port found");
112
113 auto *func = port->getAs<services::FuncService::Function>();
114 if (!func)
115 throw std::runtime_error("arrayFunc not a FuncService::Function");
116 func->connect();
117
118 int8_t argArray[1] = {static_cast<int8_t>(-3)};
119 MessageData resMsg =
120 func->call(MessageData(reinterpret_cast<const uint8_t *>(argArray),
121 sizeof(argArray)))
122 .get();
123
124 const auto *res = resMsg.as<esi_system::ResultArray>();
125 int8_t a = (*res)[0];
126 int8_t b = (*res)[1];
127 int8_t expect0 = argArray[0];
128 int8_t expect1 = static_cast<int8_t>(argArray[0] + 1);
129
130 bool ok = (a == expect0 && b == expect1) || (a == expect1 && b == expect0);
131 if (!ok)
132 throw std::runtime_error("Array func result mismatch");
133
134 int8_t low = a;
135 int8_t high = b;
136 if (low > high) {
137 int8_t tmp = low;
138 low = high;
139 high = tmp;
140 }
141 std::cout << "array func ok: " << (int)low << " " << (int)high << "\n";
142}
143
/// Loop signed 4-bit values through the `sint4Func` function service.
///
/// Sends a positive value (5) and a negative value (-3) and requires each to
/// come back unchanged. The negative case checks that the 4-bit wire value is
/// sign-extended back to the same int8_t.
///
/// @param accel Accelerator built from the manifest.
/// Throws std::runtime_error when the port is missing or either value does not
/// round-trip. The message includes the value received.
144static void runSInt4Loopback(Accelerator *accel) {
145 AppIDPath lastLookup;
146 BundlePort *port = accel->resolvePort({AppID("sint4Func")}, lastLookup);
147 if (!port)
148 throw std::runtime_error("No sint4Func port found");
149
150 // Use TypedFunction<int8_t, int8_t> for si4 → si4 loopback.
151 // si4 fits in int8_t (width 4 <= 8). Tests sign extension of small widths.
 // NOTE(review): the declaration of `func` (source lines 152-153) is missing
 // from this extract; presumably the TypedFunction<int8_t, int8_t> described
 // above, built from `port` -- confirm against the file.
154 func.connect();
155
156 // Test positive value.
157 int8_t posArg = 5;
158 int8_t posResult = func.call(posArg).get();
159 if (posResult != posArg)
160 throw std::runtime_error("sint4 loopback positive mismatch: got " +
161 std::to_string(posResult));
162
163 // Test negative value (-3, which is 0x0D in si4 wire format).
164 int8_t negArg = -3;
165 int8_t negResult = func.call(negArg).get();
166 if (negResult != negArg)
167 throw std::runtime_error("sint4 loopback negative mismatch: got " +
168 std::to_string(negResult));
169
 // Cast to int so int8_t prints as a number, not a character.
170 std::cout << "sint4 loopback ok: pos=" << (int)posResult
171 << " neg=" << (int)negResult << "\n";
172}
173
174//
175// SerialCoordTranslator test
176//
177
// Argument type for the serial coordinate translator, taken from the
// generated esi_system headers. It is constructed from (xTrans, yTrans,
// coords). Its value_type is the per-coordinate element type.
178using SerialCoordInput = esi_system::serial_coord_args;
179using SerialCoordValue = SerialCoordInput::value_type;
180
// Result-frame layouts. They are byte-packed (no compiler padding) so the test
// can reinterpret_cast raw result-message bytes onto them.
// NOTE(review): the opening lines of these three types (source lines 182,
// 186 and 190-192) are missing from this extract. Per the index, the frame
// type has members `SerialCoordOutputHeader header` and
// `SerialCoordOutputData data`. Both are 8 bytes and the frame is asserted to
// be 8 bytes, so the frame presumably overlays them (a union) -- confirm.
181#pragma pack(push, 1)
// Header frame: 6 padding bytes, then the number of data frames that follow.
// A count of 0 ends the reply.
183 uint8_t _pad[6];
184 uint16_t coordsCount;
185};
// Data frame: one translated coordinate, y first, then x.
187 uint32_t y;
188 uint32_t x;
189};
193};
194#pragma pack(pop)
// Every result message is expected to be exactly one 8-byte frame.
195static_assert(sizeof(SerialCoordOutputFrame) == 8, "Size mismatch");
196
 /// Serial coordinate-translator test.
 ///
 /// NOTE(review): this function's signature line (source line 197) is missing
 /// from this extract. The index lists it as
 /// `static void serialCoordTranslateTest(Accelerator *accel)`.
 ///
 /// Flow:
 /// - Sends 100 pseudo-random coordinates. They come from a fixed-seed
 ///   mt19937, so the inputs are the same on every run.
 /// - Uses translation offsets x + 10, y + 20.
 /// - Collects the multi-frame reply from the `coord_translator_serial` child's
 ///   `translate_coords_serial` port.
 /// - Checks that every result equals its input plus the offsets, in order.
 ///
 /// Throws std::runtime_error when the child or port is missing, the port is
 /// not a FuncService::Function, a frame has the wrong size, or the result
 /// count or values mismatch.
198 size_t numCoords = 100;
199 uint32_t xTrans = 10, yTrans = 20;
200
201 // Generate random coordinates.
202 std::mt19937 rng(0xDEADBEEF);
203 std::uniform_int_distribution<uint32_t> dist(0, 1000000);
204 std::vector<SerialCoordValue> coords;
205 coords.reserve(numCoords);
206 for (uint32_t i = 0; i < numCoords; ++i)
207 coords.emplace_back(dist(rng), dist(rng));
208
209 auto child = accel->getChildren().find(AppID("coord_translator_serial"));
210 if (child == accel->getChildren().end())
211 throw std::runtime_error("Serial coord translate test: no "
212 "'coord_translator_serial' child found");
213
214 auto &ports = child->second->getPorts();
215 auto portIter = ports.find(AppID("translate_coords_serial"));
216 if (portIter == ports.end())
217 throw std::runtime_error(
218 "Serial coord translate test: no 'translate_coords_serial' port found");
219
220 auto *func = portIter->second.getAs<services::FuncService::Function>();
221 if (!func)
222 throw std::runtime_error(
223 "Serial coord translate test: port is not a FuncService::Function");
224
225 // Keep the raw result channel here: the serial window reply arrives as
226 // multiple frames, while FuncService::Function / TypedFunction only waits
227 // for a single result message.
228 TypedWritePort<SerialCoordInput, /*SkipTypeCheck=*/true> argPort(
229 func->getRawWrite("arg"));
230 ReadChannelPort &resultPort = func->getRawRead("result");
231
 // NOTE(review): the meaning of the `false` ConnectOptions argument is defined
 // in Ports.h and is not visible here.
232 argPort.connect(ChannelPort::ConnectOptions(std::nullopt, false));
233 resultPort.connect(ChannelPort::ConnectOptions(std::nullopt, false));
234
235 auto batch = std::make_unique<SerialCoordInput>(xTrans, yTrans, coords);
236 argPort.write(batch);
237
 // Reply framing:
 // - a header frame carries coordsCount N and is followed by N data frames;
 // - a header with coordsCount == 0 ends the reply.
 // NOTE(review): there is no iteration bound or timeout. If read() blocks and
 // the terminating header never arrives, this loop does not return.
238 std::vector<SerialCoordValue> results;
239 while (true) {
240 MessageData msg;
241 resultPort.read(msg);
242 if (msg.getSize() != sizeof(SerialCoordOutputFrame))
243 throw std::runtime_error("Unexpected result message size");
244
245 const auto *frame =
246 reinterpret_cast<const SerialCoordOutputFrame *>(msg.getBytes());
247 uint16_t batchCount = frame->header.coordsCount;
248 if (batchCount == 0)
249 break;
250
251 for (uint16_t i = 0; i < batchCount; ++i) {
252 resultPort.read(msg);
253 if (msg.getSize() != sizeof(SerialCoordOutputFrame))
254 throw std::runtime_error("Unexpected result message size");
255 const auto *dFrame =
256 reinterpret_cast<const SerialCoordOutputFrame *>(msg.getBytes());
 // NOTE(review): the braced init passes (y, x). This relies on
 // SerialCoordValue's member/constructor order being y then x; confirm
 // against the generated header.
257 results.push_back({dFrame->data.y, dFrame->data.x});
258 }
259 }
260
 // Results must arrive in the same order as the inputs.
261 if (results.size() != coords.size())
262 throw std::runtime_error("Serial coord translate result size mismatch");
263 for (size_t i = 0; i < coords.size(); ++i) {
264 uint32_t expX = coords[i].x + xTrans;
265 uint32_t expY = coords[i].y + yTrans;
266 if (results[i].x != expX || results[i].y != expY)
267 throw std::runtime_error("Serial coord translate result mismatch");
268 }
269
 // NOTE(review): the ports are disconnected only on the success path. Any
 // throw above skips these calls.
270 argPort.disconnect();
271 resultPort.disconnect();
272}
273
/// Entry point for the typed-port loopback cosim test.
///
/// Steps:
/// - parse the ESI CLI options;
/// - connect to the accelerator;
/// - build the Accelerator from the JSON manifest reported by the SysInfo
///   service;
/// - print LoopbackIP::depth;
/// - run each test in sequence.
///
/// Returns the parser's non-zero code if parsing fails, 0 on success (or after
/// the help check), and 1 after logging the error if any test throws
/// std::exception.
274int main(int argc, const char *argv[]) {
275 CliParser cli("loopback-typed-cpp");
276 cli.description(
277 "Loopback cosim test using generated ESI headers and typed ports.");
278 if (int rc = cli.esiParse(argc, argv))
279 return rc;
 // Presumably exits when --help was given (CLI11 help option) -- confirm.
280 if (!cli.get_help_ptr()->empty())
281 return 0;
282
283 Context &ctxt = cli.getContext();
 // NOTE(review): connect() is called outside the try block, so an exception
 // thrown here is not logged by the handler below and escapes main.
284 AcceleratorConnection *conn = cli.connect();
285 try {
286 const auto &info = *conn->getService<services::SysInfo>();
287 Manifest manifest(ctxt, info.getJsonManifest());
288 Accelerator *accel = manifest.buildAccelerator(*conn);
289 conn->getServiceThread()->addPoll(*accel);
290
291 std::cout << "depth: 0x" << std::hex << esi_system::LoopbackIP::depth
292 << std::dec << "\n";
293
294 runLoopbackI8(accel);
295 runSInt4Loopback(accel);
296 runStructFunc(accel);
297 runOddStructFunc(accel);
298 runArrayFunc(accel);
 // NOTE(review): source line 299 is missing from this extract. It is
 // presumably `serialCoordTranslateTest(accel);`, since that static function
 // is otherwise uncalled here -- confirm.
300
301 conn->disconnect();
302 } catch (std::exception &e) {
303 ctxt.getLogger().error("loopback-typed-cpp", e.what());
304 conn->disconnect();
305 return 1;
306 }
307
308 return 0;
309}
Abstract class representing a connection to an accelerator.
Definition Accelerator.h:89
Top level accelerator class.
Definition Accelerator.h:70
Services provide connections to 'bundles' – collections of named, unidirectional communication channels.
Definition Ports.h:456
T * getAs() const
Cast this Bundle port to a subclass which is actually useful.
Definition Ports.h:484
ReadChannelPort & getRawRead(const std::string &name) const
Definition Ports.cpp:52
WriteChannelPort & getRawWrite(const std::string &name) const
Get access to the raw byte streams of a channel.
Definition Ports.cpp:42
Common options and code for ESI runtime tools.
Definition CLI.h:29
Context & getContext()
Get the context.
Definition CLI.h:69
AcceleratorConnection * connect()
Connect to the accelerator using the specified backend and connection.
Definition CLI.h:66
int esiParse(int argc, const char **argv)
Run the parser.
Definition CLI.h:52
AcceleratorConnections, Accelerators, and Manifests must all share a context.
Definition Context.h:34
Logger & getLogger()
Definition Context.h:69
BundlePort * resolvePort(const AppIDPath &path, AppIDPath &lastLookup) const
Attempt to resolve a path to a port.
Definition Design.cpp:72
const std::map< AppID, Instance * > & getChildren() const
Access the module's children by ID.
Definition Design.h:71
virtual void error(const std::string &subsystem, const std::string &msg, const std::map< std::string, std::any > *details=nullptr)
Report an error.
Definition Logging.h:64
Class to parse a manifest.
Definition Manifest.h:39
A logical chunk of data representing serialized data.
Definition Common.h:113
const uint8_t * getBytes() const
Definition Common.h:124
const T * as() const
Cast to a type.
Definition Common.h:148
size_t getSize() const
Get the size of the data in bytes.
Definition Common.h:138
A ChannelPort which reads data from the accelerator.
Definition Ports.h:341
virtual void connect(std::function< bool(MessageData)> callback, const ConnectOptions &options={})
Definition Ports.cpp:69
virtual void disconnect() override
Definition Ports.h:346
virtual void read(MessageData &outData)
Specify a buffer to read into.
Definition Ports.h:381
std::future< ResultT > call(const ArgT &arg)
Definition TypedPorts.h:444
void connect(const ChannelPort::ConnectOptions &opts={})
Definition TypedPorts.h:320
void connect(const ChannelPort::ConnectOptions &opts={})
Definition TypedPorts.h:233
void write(const T &data)
Definition TypedPorts.h:242
A function call which gets attached to a service port.
Definition Services.h:353
Information about the Accelerator system.
Definition Services.h:113
esi_system::serial_coord_args SerialCoordInput
static void runOddStructFunc(Accelerator *accel)
int main(int argc, const char *argv[])
static void serialCoordTranslateTest(Accelerator *accel)
static void runLoopbackI8(Accelerator *accel)
SerialCoordInput::value_type SerialCoordValue
static void runStructFunc(Accelerator *accel)
static void runArrayFunc(Accelerator *accel)
static void runSInt4Loopback(Accelerator *accel)
Definition esi.py:1
SerialCoordOutputData data
SerialCoordOutputHeader header