CIRCT 23.0.0git
Loading...
Searching...
No Matches
common.py
Go to the documentation of this file.
1# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2# See https://llvm.org/LICENSE.txt for license information.
3# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from __future__ import annotations
6from math import ceil
7
8from pycde.common import Clock, Input, InputChannel, Output, OutputChannel, Reset
9from pycde.constructs import (AssignableSignal, ControlReg, Counter, Mux,
10 NamedWire, Wire)
11from pycde import esi
12from pycde.module import Module, generator, modparams
13from pycde.signals import BitsSignal, ChannelSignal, StructSignal
14from pycde.support import clog2
15from pycde.system import System
16from pycde.types import (Array, Bits, Bundle, BundledChannel, Channel,
17 ChannelDirection, StructType, Type, UInt, Window)
18
19from typing import Callable, Dict, List, Tuple
20import typing
21
# Identification values for the ESI MMIO header: the magic number is served
# from header slot 1 and the version number from slot 2.
MagicNumber = 0x207D98E5_E5100E51  # random + ESI__ESI
VersionNumber = 0  # Version 0: format subject to change

# Identification values for the indirect MMIO register block (served at byte
# offsets 0x8 and 0x10 of 'MMIOIndirection').
IndirectionMagicNumber = 0x312bf0cc_E5100E51  # random + ESI__ESI
IndirectionVersionNumber = 0  # Version 0: format subject to change

# Magic value which, when written by the host to header slot 7, requests a
# design reset. Keep in sync with 'ResetMagicNumber' in the runtime
# (cpp/include/esi/Accelerator.h). This magic number guards against "write
# spraying", which other devices have been known to do on boot.
ResetMagicNumber = 0x00000E510000B007
# Number of cycles to wait after a reset is requested before asserting it. This
# gives in-flight transactions time to drain.
ResetCycles = 8192
36
37
class ESI_Manifest_ROM(Module):
  """Module which will be created later by CIRCT which will contain the
  compressed manifest.

  Only the port interface is declared here; no generator is attached, so the
  body is expected to be supplied elsewhere under the name given by
  `module_name`."""

  module_name = "__ESI_Manifest_ROM"

  clk = Clock()
  # 29-bit address. The wrapper in this file drives it with a 32-bit byte
  # offset whose lower three bits have been dropped, i.e. a 64-bit word index.
  address = Input(Bits(29))
  # Data is two cycles delayed after address changes.
  data = Output(Bits(64))
48
49
class ESI_Manifest_ROM_Wrapper(Module):
  """Wrap the manifest ROM with ESI bundle.

  Exposes `ESI_Manifest_ROM` as a read-only MMIO client: the 'offset' channel
  of the `read` bundle supplies the byte address and the ROM output is
  returned on the bundle's 'data' channel."""

  clk = Clock()
  read = Input(esi.MMIO.read.type)

  @generator
  def build(self):
    # The response channel has to exist before the bundle can be unpacked, so
    # its payload and valid are declared as wires and driven further down.
    data, data_valid = Wire(Bits(64)), Wire(Bits(1))
    data_chan, data_ready = Channel(Bits(64)).wrap(data, data_valid)
    address_chan = self.read.unpack(data=data_chan)['offset']
    # The request is acknowledged with the response channel's ready signal.
    address, address_valid = address_chan.unwrap(data_ready)
    address_words = address.as_bits(32)[3:]  # Lop off the lower three bits.

    rom = ESI_Manifest_ROM(clk=self.clk, address=address_words)
    data.assign(rom.data)
    # Delay 'valid' by two cycles to line up with the ROM's read latency.
    data_valid.assign(address_valid.reg(self.clk, name="data_valid", cycles=2))
67
68
@modparams
def HeaderMMIO(manifest_loc: int) -> Module:
  """Build the ESI header MMIO module for a given manifest location.

  Args:
    manifest_loc: Absolute address of the manifest ROM. It is baked into the
      header as a constant and served from header slot 3.

  Returns:
    The `HeaderMMIO` module class defined below.
  """

  class HeaderMMIO(Module):
    """Construct the ESI header MMIO adhering to the MMIO layout specified in
    the ChannelMMIO service implementation."""

    clk = Clock()
    rst = Reset()
    read = Input(esi.MMIO.read_write.type)
    # Asserted for one cycle when the host writes the reset magic number to
    # header slot 7. Propagates up to the BSP which performs the actual reset.
    reset_request = Output(Bits(1))

    @generator
    def build(ports):
      clk = ports.clk
      rst = ports.rst
      # The response channel is needed to unpack the bundle, so it is declared
      # as a wire here and driven once the response stage has been built.
      data_chan_wire = Wire(Channel(esi.MMIODataType))
      input_bundles = ports.read.unpack(data=data_chan_wire)
      cmd_chan = input_bundles['cmd']

      # Two-stage half-throughput pipeline: stage 1 captures the incoming
      # command, stage 2 holds the looked-up response. Each stage carries its
      # own occupancy bit.
      cmd_ready = Wire(Bits(1))
      s1_to_s2_xact = Wire(Bits(1))
      cmd_raw, cmd_valid = cmd_chan.unwrap(cmd_ready)

      # Stage 1: command capture register and occupancy bit.
      s1_load = cmd_valid & cmd_ready
      cmd = cmd_raw.reg(clk, rst, ce=s1_load, name="cmd")
      s1_valid = ControlReg(clk,
                            rst,
                            asserts=[s1_load],
                            resets=[s1_to_s2_xact],
                            name="s1_valid")
      # Accept a new command when stage 1 is empty.
      cmd_ready.assign(~s1_valid)

      address_words = cmd.offset.as_bits()[3:]  # Lop off the lower three bits.
      # Three bits of word address select one of the eight header slots.
      slot = address_words[:3]

      # Free-running 64-bit counter: never cleared, incremented every cycle.
      cycles = Counter(64)(clk=ports.clk,
                           rst=ports.rst,
                           clear=Bits(1)(0),
                           increment=Bits(1)(1),
                           instance_name="cycle_counter")

      # Layout the header as an array.
      core_freq = System.current().core_freq
      if core_freq is None:
        # Unknown frequency is reported to the host as 0.
        core_freq = 0
      header = Array(Bits(64), 8)([
          0,  # Generally a good idea to not use address 0.
          MagicNumber,  # ESI magic number.
          VersionNumber,  # ESI version number.
          manifest_loc,  # Absolute address of the manifest ROM.
          0,  # Reserved for future use.
          cycles.out.as_bits(),  # Cycle counter.
          core_freq,  # Core frequency, if known.
          0,  # Slot 7: write the reset magic number here to request a reset.
      ])
      header.name = "header"

      # Stage 2: registered response value and its occupancy bit.
      s2_valid = Wire(Bits(1))
      data_chan_ready = Wire(Bits(1))
      s2_xact = s2_valid & data_chan_ready
      # Stage 1 advances into stage 2 only when stage 2 is empty.
      s1_to_s2_xact.assign(s1_valid & ~s2_valid)

      # The selected slot is captured as the command moves into stage 2.
      header_out = header[slot].reg(clk=clk,
                                    rst=rst,
                                    ce=s1_to_s2_xact,
                                    name="header_out")
      s2_valid.assign(
          ControlReg(clk,
                     rst,
                     asserts=[s1_to_s2_xact],
                     resets=[s2_xact],
                     name="header_out_valid"))
      # Wrap the response.
      data_chan, data_chan_ready_sig = Channel(esi.MMIODataType).wrap(
          header_out, s2_valid)
      data_chan_wire.assign(data_chan)
      data_chan_ready.assign(data_chan_ready_sig)

      # Detect a write of the reset magic number to slot 7. The detect is
      # computed from the captured (registered) command and qualified with the
      # stage 1 -> stage 2 transfer, which is high for a single cycle per
      # command, so the request is a one-cycle pulse asserted as the command
      # advances into the response stage. 'DesignResetController' latches it,
      # so a single-cycle pulse is sufficient to trigger the reset.
      reset_detect = (cmd.write & (slot == Bits(3)(7)) &
                      (cmd.data == Bits(64)(ResetMagicNumber))).as_bits()
      ports.reset_request = (reset_detect & s1_to_s2_xact).as_bits()

  return HeaderMMIO
166
167
@modparams
def ChannelDemuxN_HalfStage_ReadyBlocking(
    data_type: Type, num_outs: int,
    next_sel_width: int) -> type["ChannelDemuxNImpl"]:
  """N-way channel demultiplexer for valid/ready signaling. Contains
  valid/ready registers on the output channels. The selection signal is
  embedded in the input channel payload as a struct {sel, next_sel, data};
  'sel' picks the output and {next_sel, data} is forwarded on it. Input
  signals ready when the selected output register is empty.

  Args:
    data_type: Type of the 'data' field carried through the demux.
    num_outs: Number of output channels (`output_0` .. `output_{N-1}`).
    next_sel_width: Width of the pass-through 'next_sel' field.

  Returns:
    The parameterized `ChannelDemuxNImpl` module class.
  """

  assert num_outs >= 1, "num_outs must be at least 1."

  class ChannelDemuxNImpl(Module):
    clk = Clock()
    rst = Reset()

    # Input channel carries the selection along with the data.
    InPayloadType = StructType([
        ("sel", Bits(clog2(num_outs))),
        ("next_sel", Bits(next_sel_width)),
        ("data", data_type),
    ])
    inp = Input(Channel(InPayloadType))
    OutPayloadType = StructType([
        ("next_sel", Bits(next_sel_width)),
        ("data", data_type),
    ])
    # Outputs are channels of OutPayloadType, which includes both 'next_sel'
    # and 'data' fields. The ports are declared dynamically, one per output.
    for i in range(num_outs):
      locals()[f"output_{i}"] = Output(Channel(OutPayloadType))

    @generator
    def generate(ports) -> None:
      # Half-stage demux: one register per output channel. Input is ready
      # when the currently selected output register is empty (not valid).
      clk = ports.clk
      rst = ports.rst
      sel_width = clog2(num_outs)

      # Unwrap input with backpressure from selected output register.
      input_ready = Wire(Bits(1), name="input_ready")
      in_payload, in_valid = ports.inp.unwrap(input_ready)
      in_sel = in_payload.sel
      in_next_sel = in_payload.next_sel
      in_data = in_payload.data

      # Track per-output valid regs and build a purely combinational
      # expression 'selected_valid_expr' = OR_i((sel==i)&valid_i). Avoid
      # assigning to a Wire multiple times.
      valid_regs: List[BitsSignal] = []
      selected_valid_expr = Bits(1)(0)

      for i in range(num_outs):
        # Write when the input transaction targets this output. 'input_ready'
        # already guarantees the selected output is not holding data.
        will_write = Wire(Bits(1), name=f"will_write_{i}")
        write_cond = (in_valid & input_ready & (in_sel == Bits(sel_width)(i)))
        will_write.assign(write_cond)

        # Data and next_sel registers.
        out_msg_reg = ChannelDemuxNImpl.OutPayloadType({
            "next_sel": in_next_sel,
            "data": in_data
        }).reg(clk=clk, rst=rst, ce=will_write, name=f"out{i}_msg_reg")

        # Valid register cleared on successful downstream consume.
        consume = Wire(Bits(1), name=f"consume_{i}")
        valid_reg = ControlReg(
            clk=clk,
            rst=rst,
            asserts=[will_write],
            resets=[consume],
            name=f"out{i}_valid_reg",
        )
        valid_regs.append(valid_reg)

        # Channel wrapper.
        ch_sig, ch_ready = Channel(ChannelDemuxNImpl.OutPayloadType).wrap(
            out_msg_reg, valid_reg)
        setattr(ports, f"output_{i}", ch_sig)
        consume.assign(valid_reg & ch_ready)

        # Accumulate selected_valid expression.
        selected_valid_expr = (selected_valid_expr | (
            (in_sel == Bits(sel_width)(i)) & valid_reg)).as_bits()

      # Input ready only when selected output has no valid data latched
      # (XOR with 1 inverts the accumulated 'selected output is full' bit).
      input_ready.assign((selected_valid_expr ^ Bits(1)(1)).as_bits())

    def get_out(self, index: int) -> ChannelSignal:
      """Return the output channel with the given index."""
      return getattr(self, f"output_{index}")

  return ChannelDemuxNImpl
259
260
261@modparams
263 data_type: Type, num_outs: int,
264 branching_factor_log2: int) -> type["ChannelDemuxTree"]:
265 """Pipelined N-way channel demultiplexer for valid/ready signaling. This
266 implementation uses a tree structure of
267 ChannelDemuxN_HalfStage_ReadyBlocking modules to reduce fanout pressure.
268 Supports maximum half-throughput to save complexity and area.
269 """
270
271 root_sel_width = clog2(num_outs)
272 # Simplify algorithm by making sure num_outs is a power of two.
273 num_outs = 2**root_sel_width
274 sel_width = branching_factor_log2
275 fanout = 2**sel_width
276
277 class ChannelDemuxTree(Module):
278 clk = Clock()
279 rst = Reset()
280 # Input now embeds selection bits alongside data.
281 InPayloadType = StructType([
282 ("sel", Bits(clog2(num_outs))),
283 ("data", data_type),
284 ])
285 inp = Input(Channel(InPayloadType))
286
287 # Outputs (data only).
288 for i in range(num_outs):
289 locals()[f"output_{i}"] = Output(Channel(data_type))
290
291 @generator
292 def build(ports) -> None:
293 assert branching_factor_log2 > 0
294 if num_outs == 1:
295 # Strip selection bits and return single channel.
296 setattr(ports, "output_0", ports.inp.transform(lambda p: p.data))
297 return
298
299 def payload_type(sel_width: int, next_sel_width: int) -> Type:
300 return StructType([
301 ("sel", Bits(sel_width)),
302 ("next_sel", Bits(next_sel_width)),
303 ("data", data_type),
304 ])
305
306 def next_sel_width_calc(curr_sel_width) -> int:
307 return max(curr_sel_width - sel_width, 0)
308
309 def payload_next(curr_msg: StructSignal) -> StructSignal:
310 """Given current level payload, produce next level payload by
311 stripping off the top selection bits."""
312
313 next_sel_width = next_sel_width_calc(curr_msg.next_sel.type.width)
314 curr_sel_width = curr_msg.next_sel.type.width
315 new_sel_width = min(curr_sel_width, sel_width)
316 return payload_type(
317 new_sel_width,
318 next_sel_width,
319 )({
320 # Use the MSB bits of next_sel as the next level selection.
321 "sel": (curr_msg.next_sel[next_sel_width:]
322 if curr_sel_width > 0 else Bits(0)(0)),
323 "next_sel": (curr_msg.next_sel[:next_sel_width]
324 if next_sel_width > 0 else Bits(0)(0)),
325 "data": curr_msg.data,
326 })
327
328 current_channels: List[ChannelSignal] = [
329 ports.inp.transform(lambda m: payload_type(0, root_sel_width)({
330 "sel": Bits(0)(0),
331 "next_sel": m.sel,
332 "data": m.data,
333 }))
334 ]
335
336 curr_sel_width = root_sel_width
337 level = 0
338 while len(current_channels) < num_outs:
339 next_level: List[ChannelSignal] = []
340 level_num_outs = min(2**curr_sel_width, fanout)
341 for i, c in enumerate(current_channels):
343 data_type,
344 num_outs=level_num_outs,
345 next_sel_width=next_sel_width_calc(curr_sel_width),
346 )(
347 clk=ports.clk,
348 rst=ports.rst,
349 inp=c.transform(payload_next),
350 instance_name=f"demux_l{level}_i{i}",
351 )
352 for j in range(level_num_outs):
353 next_level.append(dmux.get_out(j))
354 current_channels = next_level
355 curr_sel_width -= sel_width
356 level += 1
357
358 for i in range(num_outs):
359 # Strip off next_sel bits for final output.
360 setattr(
361 ports,
362 f"output_{i}",
363 current_channels[i].transform(lambda p: p.data),
364 )
365
366 def get_out(self, index: int) -> ChannelSignal:
367 return getattr(self, f"output_{index}")
368
369 return ChannelDemuxTree
370
371
@modparams
def DesignResetController(
    delay_cycles: int) -> type["DesignResetControllerImpl"]:
  """Delayed design-reset generator.

  Once `reset_request` is observed the request is latched and `delay_cycles`
  clock cycles are counted, after which `design_reset` is asserted for one
  cycle. Drive this module from the *external* reset only (never from the
  reset it generates) so that the countdown is not disturbed by its own output.

  `reset_pending` is high from the moment a reset is requested until it fires.
  It is meant to quiesce the design (e.g. stop accepting new transactions) so
  that nothing is in flight when the reset is asserted.

  Raises:
    ValueError: if `delay_cycles` is less than 1.
  """

  if delay_cycles < 1:
    raise ValueError("'delay_cycles' must be at least 1.")

  # Sized to hold 'delay_cycles - 1', but never narrower than a single bit.
  counter_width = max(clog2(delay_cycles), 1)

  class DesignResetControllerImpl(Module):
    clk = Clock()
    rst = Reset()
    reset_request = Input(Bits(1))
    design_reset = Output(Bits(1))
    # Stays high between the reset request and the cycle the reset fires. Use
    # it to refuse new work so outstanding transactions can drain first.
    reset_pending = Output(Bits(1))

    @generator
    def build(ports):
      clk, rst = ports.clk, ports.rst

      # Declared up front because both the latch and the counter depend on it.
      fire_now = Wire(Bits(1))

      # Remember the request until the reset has actually been issued.
      request_latched = ControlReg(clk=clk,
                                   rst=rst,
                                   asserts=[ports.reset_request],
                                   resets=[fire_now],
                                   name="reset_pending")

      # The counter is held at zero while idle and on the firing cycle, and
      # advances on every cycle a request is pending.
      hold_at_zero = (fire_now | ~request_latched).as_bits()
      delay_counter = Counter(counter_width)(
          clk=clk,
          rst=rst,
          clear=hold_at_zero,
          increment=request_latched,
          instance_name="reset_delay_counter")

      # Fire once the final count value is reached while still pending.
      last_count = UInt(counter_width)(delay_cycles - 1)
      fire_now.assign(
          (request_latched & (delay_counter.out == last_count)).as_bits())

      ports.design_reset = fire_now
      ports.reset_pending = request_latched

  return DesignResetControllerImpl
420
421
422class ChannelMMIO(esi.ServiceImplementation):
423 """MMIO service implementation with MMIO bundle interfaces. Should be
424 relatively easy to adapt to physical interfaces by wrapping the wires to
425 channels then bundles. Allows the implementation to be shared and (hopefully)
426 platform independent.
427
428 Whether or not to support unaligned accesses is up to the clients. The header
429 and manifest do not support unaligned accesses and throw away the lower three
430 bits.
431
432 Only allows for one outstanding request at a time. If a client doesn't return
433 a response, the MMIO service will hang. TODO: add some kind of timeout.
434
435 Implementation-defined MMIO layout:
436 - 0x0: 0 constant
437 - 0x8: Magic number (0x207D98E5_E5100E51)
438 - 0x10: ESI version number (0)
439 - 0x18: Location of the manifest ROM (absolute address)
440
441 - 0x400: Start of MMIO space for requests. Mapping is contained in the
442 manifest so can be dynamically queried.
443
444 - addr(Manifest ROM) + 0: Size of compressed manifest
445 - addr(Manifest ROM) + 8: Start of compressed manifest
446
447 This layout _should_ be pretty standard, but different BSPs may have various
448 different restrictions. Any BSP which uses this service implementation will
449 have this layout, possibly with an offset or address window.
450 """
451
452 clk = Clock()
453 rst = Input(Bits(1))
454
455 cmd = Input(esi.MMIO.read_write.type)
456
457 # Asserted for one cycle when the host requests a design reset via an MMIO
458 # write to the header. Propagates up to the BSP which performs the reset.
459 reset_request = Output(Bits(1))
460
461 # Amount of register space each client gets. This is a GIANT HACK and needs to
462 # be replaced by parameterizable services.
463 # TODO: make the amount of register space each client gets a parameter.
464 # Supporting this will require more address decode logic.
465 #
466 # TODO: only supports one outstanding transaction at a time. This is NOT
467 # enforced or checked! Enforce this.
468
469 RegisterSpace = 0x400
470 RegisterSpaceBits = RegisterSpace.bit_length() - 1
471 AddressMask = 0x3FF
472
473 # Start at this address for assigning MMIO addresses to service requests.
474 initial_offset: int = RegisterSpace
475
476 @generator
477 def generate(ports, bundles: esi._ServiceGeneratorBundles):
478 table, manifest_loc = ChannelMMIO.build_table(bundles)
479 ChannelMMIO.build_read(ports, manifest_loc, table)
480 return True
481
482 @staticmethod
483 def build_table(bundles) -> Tuple[Dict[int, AssignableSignal], int]:
484 """Build a table of read and write addresses to BundleSignals."""
485 offset = ChannelMMIO.initial_offset
486 table: Dict[int, AssignableSignal] = {}
487 for bundle in bundles.to_client_reqs:
488 if bundle.port == 'read':
489 table[offset] = bundle
490 bundle.add_record(details={
491 "offset": offset,
492 "size": ChannelMMIO.RegisterSpace,
493 "type": "ro"
494 })
495 offset += ChannelMMIO.RegisterSpace
496 elif bundle.port == 'read_write':
497 table[offset] = bundle
498 bundle.add_record(details={
499 "offset": offset,
500 "size": ChannelMMIO.RegisterSpace,
501 "type": "rw"
502 })
503 offset += ChannelMMIO.RegisterSpace
504 else:
505 assert False, "Unrecognized port name."
506
507 manifest_loc = offset
508 return table, manifest_loc
509
510 @staticmethod
511 def build_read(ports, manifest_loc: int, table: Dict[int, AssignableSignal]):
512 """Builds the read side of the MMIO service."""
513
514 # Instantiate the header and manifest ROM. Fill in the read_table with
515 # bundle wires to be assigned identically to the other MMIO clients.
516 header_bundle_wire = Wire(esi.MMIO.read_write.type)
517 table[0] = header_bundle_wire
518 header = HeaderMMIO(manifest_loc)(clk=ports.clk,
519 rst=ports.rst,
520 read=header_bundle_wire)
521
522 mani_bundle_wire = Wire(esi.MMIO.read.type)
523 table[manifest_loc] = mani_bundle_wire
524 ESI_Manifest_ROM_Wrapper(clk=ports.clk, read=mani_bundle_wire)
525
526 # Unpack the cmd bundle.
527 data_resp_channel = Wire(Channel(esi.MMIODataType))
528 counted_output = Wire(Channel(esi.MMIODataType))
529 cmd_channel = ports.cmd.unpack(data=counted_output)["cmd"]
530 counted_output.assign(data_resp_channel)
531
532 # Get the selection index and the address to hand off to the clients.
533 sel_bits, client_cmd_chan = ChannelMMIO.build_addr_read(
534 cmd_channel, len(table), manifest_loc)
535
536 # Build the demux/mux and assign the results of each appropriately.
537 read_clients_clog2 = clog2(len(table))
538 # Combine selection bits and command channel payload into a struct channel for the demux tree.
539 TreeInType = StructType([
540 ("sel", Bits(read_clients_clog2)),
541 ("data", client_cmd_chan.type.inner_type),
542 ])
543 sel_bits_truncated = sel_bits.pad_or_truncate(read_clients_clog2)
544 combined_cmd_chan = client_cmd_chan.transform(
545 lambda cmd, _sel=sel_bits_truncated: TreeInType({
546 "sel": _sel,
547 "data": cmd
548 }))
550 client_cmd_chan.type.inner_type, len(table), branching_factor_log2=2)(
551 clk=ports.clk,
552 rst=ports.rst,
553 inp=combined_cmd_chan,
554 instance_name="client_cmd_demux",
555 )
556 client_cmd_channels = [demux_inst.get_out(i) for i in range(len(table))]
557 client_data_channels = []
558 for (idx, offset) in enumerate(sorted(table.keys())):
559 bundle_wire = table[offset]
560 bundle_type = bundle_wire.type
561 if bundle_type == esi.MMIO.read.type:
562 offset = client_cmd_channels[idx].transform(lambda cmd: cmd.offset)
563 bundle, bundle_froms = esi.MMIO.read.type.pack(offset=offset)
564 elif bundle_type == esi.MMIO.read_write.type:
565 bundle, bundle_froms = esi.MMIO.read_write.type.pack(
566 cmd=client_cmd_channels[idx])
567 else:
568 assert False, "Unrecognized bundle type."
569 bundle_wire.assign(bundle)
570 client_data_channels.append(bundle_froms["data"])
571 resp_channel = esi.ChannelMux(client_data_channels)
572 data_resp_channel.assign(resp_channel)
573
574 # The header surfaces a reset request when the host writes the reset magic
575 # number to slot 7. Propagate it up to the caller (the BSP).
576 ports.reset_request = header.reset_request
577
578 @staticmethod
579 def build_addr_read(read_addr_chan: ChannelSignal, num_clients: int,
580 manifest_loc: int) -> Tuple[BitsSignal, ChannelSignal]:
581 """Build a channel for the address read request. Returns the index to select
582 the client and a channel for the masked address to be passed to the
583 clients."""
584
585 # Decoding the selection bits is very simple as of now. This might need to
586 # change to support more flexibility in addressing. Not clear if what we're
587 # doing now it sufficient or not.
588
589 manifest_loc_const = UInt(32)(manifest_loc)
590
591 cmd_ready_wire = Wire(Bits(1))
592 cmd, cmd_valid = read_addr_chan.unwrap(cmd_ready_wire)
593 is_manifest_read = cmd.offset >= manifest_loc_const
594 sel_bits = NamedWire(Bits(32 - ChannelMMIO.RegisterSpaceBits), "sel_bits")
595 # If reading the manifest, override the selection to select the manifest instead.
596 sel_bits.assign(
597 Mux(is_manifest_read,
598 cmd.offset.as_bits()[ChannelMMIO.RegisterSpaceBits:],
599 Bits(32 - ChannelMMIO.RegisterSpaceBits)(num_clients - 1)))
600 regular_client_offset = (cmd.offset.as_bits() &
601 Bits(32)(ChannelMMIO.AddressMask)).as_uint()
602 offset = Mux(is_manifest_read, regular_client_offset,
603 (cmd.offset - manifest_loc_const).as_uint(32))
604 client_cmd = NamedWire(esi.MMIOReadWriteCmdType, "client_cmd")
605 client_cmd.assign(
606 esi.MMIOReadWriteCmdType({
607 "write": cmd.write,
608 "offset": offset,
609 "data": cmd.data
610 }))
611 client_addr_chan, client_addr_ready = Channel(
612 esi.MMIOReadWriteCmdType).wrap(client_cmd, cmd_valid)
613 cmd_ready_wire.assign(client_addr_ready)
614 return sel_bits, client_addr_chan
615
616
class MMIOIndirection(Module):
  """Some platforms do not support MMIO space greater than a certain size (e.g.
  Vitis 2022's limit is 4k). This module implements a level of indirection to
  provide access to a full 32-bit address space.

  MMIO addresses:
    - 0x0: 0 constant
    - 0x8: 64 bit ESI magic number for Indirect MMIO (0x312bf0cc_E5100E51)
    - 0x10: Version number for Indirect MMIO (0)
    - 0x18: Location of read/write in the virtual MMIO space.
    - 0x20: A read from this location will initiate a read in the virtual MMIO
            space specified by the address stored in 0x18 and return the result.
            A write to this location will initiate a write into the virtual MMIO
            space to the virtual address specified in 0x18.
  """
  clk = Clock()
  rst = Reset()

  upstream = Input(esi.MMIO.read_write.type)
  downstream = Output(esi.MMIO.read_write.type)

  @generator
  def build(ports):
    # This implementation assumes there is only one outstanding upstream MMIO
    # transaction in flight at once. TODO: enforce this or make it more robust.

    # Only the low 'reg_bits' bits of the offset are decoded locally.
    reg_bits = 8
    location_reg = UInt(reg_bits)(0x18)
    indirect_mmio_reg = UInt(reg_bits)(0x20)
    virt_address = Wire(UInt(32))

    # Set up the upstream MMIO interface. Snoop the upstream command channel so
    # the command payload can be inspected without consuming it.
    upstream_resp_chan_wire = Wire(Channel(esi.MMIODataType))
    upstream_cmd_chan = ports.upstream.unpack(
        data=upstream_resp_chan_wire)["cmd"]
    _, _, upstream_cmd_data = upstream_cmd_chan.snoop()

    # Set up a channel demux to separate the MMIO commands which get processed
    # locally from the ones which should be transformed and forwarded
    # downstream.
    phys_loc = upstream_cmd_data.offset.as_uint(reg_bits)
    fwd_upstream = NamedWire(phys_loc == indirect_mmio_reg, "fwd_upstream")
    local_reg_cmd_chan, downstream_cmd_channel = esi.ChannelDemux(
        upstream_cmd_chan, fwd_upstream, 2, "upstream_demux")

    # Set up the downstream MMIO interface. Forwarded commands keep their
    # write flag and data but are re-addressed to the stored virtual address.
    downstream_cmd_channel = downstream_cmd_channel.transform(
        lambda cmd: esi.MMIOReadWriteCmdType({
            "write": cmd.write,
            "offset": virt_address,
            "data": cmd.data
        }))
    ports.downstream, froms = esi.MMIO.read_write.type.pack(
        cmd=downstream_cmd_channel)
    downstream_data_chan = froms["data"]

    # Process local regs. The virtual address register is updated when a write
    # to the location register (0x18) is transferred on the local channel.
    (local_reg_cmd_valid, local_reg_cmd_ready,
     local_reg_cmd) = local_reg_cmd_chan.snoop()
    write_virt_address = (local_reg_cmd_valid & local_reg_cmd_ready &
                          local_reg_cmd.write & (phys_loc == location_reg))
    virt_address.assign(
        local_reg_cmd.data.as_uint(32).reg(
            name="virt_address",
            clk=ports.clk,
            ce=write_virt_address,
        ))

    # Build the physical MMIO register space.
    local_reg_resp_array = Array(Bits(64), 4)([
        0x0,  # 0x0
        IndirectionMagicNumber,  # 0x8
        IndirectionVersionNumber,  # 0x10
        virt_address.as_bits(64),  # 0x18
    ])
    # The registers are 8 bytes apart, so index by 64-bit word: drop the three
    # byte-offset bits and use the next two. (The low two bits of the byte
    # offset are zero for all four registers and cannot tell them apart.)
    local_reg_resp_chan = local_reg_cmd_chan.transform(
        lambda cmd: local_reg_resp_array[cmd.offset.as_bits()[3:5]])

    # Mux together the local register responses and the downstream data to
    # create the upstream response.
    upstream_resp = esi.ChannelMux([local_reg_resp_chan, downstream_data_chan])
    upstream_resp_chan_wire.assign(upstream_resp)
700
701
@modparams
def TaggedReadGearbox(input_bitwidth: int,
                      output_bitwidth: int) -> type["TaggedReadGearboxImpl"]:
  """Build a gearbox to convert the upstream data to the client data
  type. Assumes a struct {tag, data} and only gearboxes the data. Tag is stored
  separately and the struct is re-assembled later on.

  Args:
    input_bitwidth: Width of the 'data' field arriving from upstream.
    output_bitwidth: Width of the 'data' field delivered to the client.

  Returns:
    The parameterized `TaggedReadGearboxImpl` module class.
  """

  class TaggedReadGearboxImpl(Module):
    clk = Clock()
    rst = Reset()
    in_ = InputChannel(
        StructType([
            ("tag", esi.HostMem.TagType),
            ("data", Bits(input_bitwidth)),
        ]))
    out = OutputChannel(
        StructType([
            ("tag", esi.HostMem.TagType),
            ("data", Bits(output_bitwidth)),
        ]))

    @generator
    def build(ports):
      # Upstream ready is driven from the client side at the end of the build.
      ready_for_upstream = Wire(Bits(1), name="ready_for_upstream")
      upstream_tag_and_data, upstream_valid = ports.in_.unwrap(
          ready_for_upstream)
      upstream_data = upstream_tag_and_data.data
      upstream_xact = ready_for_upstream & upstream_valid

      # Determine if gearboxing is necessary and whether it needs to be
      # gearboxed up or just sliced down.
      if output_bitwidth == input_bitwidth:
        # Same width: pass data and valid straight through.
        client_data_bits = upstream_data
        client_valid = upstream_valid
      elif output_bitwidth < input_bitwidth:
        # Narrower client: keep the low 'output_bitwidth' bits.
        client_data_bits = upstream_data[:output_bitwidth]
        client_valid = upstream_valid
      else:
        # Create registers equal to the number of upstream transactions needed
        # to fill the client data. Set the output to the concatenation of said
        # registers.
        chunks = ceil(output_bitwidth / input_bitwidth)
        reg_ces = [Wire(Bits(1)) for _ in range(chunks)]
        regs = [
            upstream_data.reg(ports.clk,
                              ports.rst,
                              ce=reg_ces[idx],
                              name=f"chunk_reg_{idx}") for idx in range(chunks)
        ]
        # Registers are reversed before concatenation, then the result is
        # trimmed to the client width.
        client_data_bits = BitsSignal.concat(reversed(regs))[:output_bitwidth]

        # Use counter to determine to which register to write and determine if
        # the registers are all full.
        clear_counter = Wire(Bits(1))
        counter_width = clog2(chunks)
        counter = Counter(counter_width)(clk=ports.clk,
                                         rst=ports.rst,
                                         clear=clear_counter,
                                         increment=upstream_xact)
        # The output becomes valid when the last chunk is transferred and is
        # cleared again once the client takes the assembled word.
        set_client_valid = counter.out == chunks - 1
        client_xact = Wire(Bits(1))
        client_valid = ControlReg(ports.clk, ports.rst,
                                  [set_client_valid & upstream_xact],
                                  [client_xact])
        client_xact.assign(client_valid & ready_for_upstream)
        clear_counter.assign(client_xact)
        # Each upstream transfer is written to the register the counter
        # currently points at.
        for idx, reg_ce in enumerate(reg_ces):
          reg_ce.assign(upstream_xact &
                        (counter.out == UInt(counter_width)(idx)))

      # Construct the output channel. Shared logic across all three cases.
      # NOTE(review): in the pass-through and slice-down cases the data is
      # forwarded combinationally while the tag comes from 'tag_reg', which is
      # only loaded on the upstream transfer itself -- confirm the tag presented
      # alongside each beat is the intended one.
      tag_reg = upstream_tag_and_data.tag.reg(ports.clk,
                                              ports.rst,
                                              ce=upstream_xact,
                                              name="tag_reg")

      client_channel, client_ready = TaggedReadGearboxImpl.out.type.wrap(
          {
              "tag": tag_reg,
              "data": client_data_bits,
          }, client_valid)
      ready_for_upstream.assign(client_ready)
      ports.out = client_channel

  return TaggedReadGearboxImpl
787
788
def HostmemReadProcessor(read_width: int, hostmem_module,
                         reqs: List[esi._OutputBundleSetter]):
  """Construct a host memory read request module to orchestrate the read
  connections. Responsible for both gearboxing the data, multiplexing the
  requests, reassembling out-of-order responses and routing the responses to the
  correct clients.

  Generate this module dynamically to allow for multiple read clients of
  multiple types to be directly accommodated.

  Args:
    read_width: Bit width of the data returned by the upstream read port.
    hostmem_module: Provider of the upstream `read` bundle type and the
      `UpstreamReadReq` request type.
    reqs: Client read request bundles; one output port is created for each.

  Returns:
    The generated `HostmemReadProcessorImpl` module class.
  """

  class HostmemReadProcessorImpl(Module):
    clk = Clock()
    rst = Reset()

    # Add an output port for each read client.
    reqPortMap: Dict[esi._OutputBundleSetter, str] = {}
    for req in reqs:
      name = "client_" + req.client_name_str
      locals()[name] = Output(req.type)
      reqPortMap[req] = name

    # And then the port which goes to the host.
    upstream = Output(hostmem_module.read.type)

    @generator
    def build(ports):
      """Build the read side of the HostMem service."""

      # If there's no read clients, just return a no-op read bundle.
      if not reqs:
        upstream_req_channel, _ = Channel(hostmem_module.UpstreamReadReq).wrap(
            {
                "tag": 0,
                "length": 0,
                "address": 0
            }, 0)
        upstream_read_bundle, _ = hostmem_module.read.type.pack(
            req=upstream_req_channel)
        ports.upstream = upstream_read_bundle
        return

      # Since we use the tag to identify the client, we can't have more than 256
      # read clients. Supporting more than 256 clients would require
      # tag-rewriting, which we'll probably have to implement at some point.
      # TODO: Implement tag-rewriting.
      assert len(reqs) <= 256, "More than 256 read clients not supported."

      # Pack the upstream bundle and leave the request as a wire.
      upstream_req_channel = Wire(Channel(hostmem_module.UpstreamReadReq))
      upstream_read_bundle, froms = hostmem_module.read.type.pack(
          req=upstream_req_channel)
      ports.upstream = upstream_read_bundle
      upstream_resp_channel = froms["resp"]

      # Route each upstream response to one of the per-client outputs.
      demux = esi.TaggedDemux(len(reqs), upstream_resp_channel.type)(
          clk=ports.clk, rst=ports.rst, in_=upstream_resp_channel)

      tagged_client_reqs = []
      for idx, client in enumerate(reqs):
        # Find the response channel in the request bundle.
        resp_type = [
            c.channel for c in client.type.channels if c.name == 'resp'
        ][0]
        demuxed_upstream_channel = demux.get_out(idx)

        # TODO: Should responses come back out-of-order (interleaved tags),
        # re-order them here so the gearbox doesn't get confused. (Longer term.)
        # For now, only support one outstanding transaction at a time. This has
        # the additional benefit of letting the upstream tag be the client
        # identifier. TODO: Implement the gating logic here.

        # Gearbox the data to the client's data type.
        client_type = resp_type.inner_type
        if client_type.data.bitwidth == 0:
          raise ValueError("Client data type cannot be zero-width. Use a "
                           "single-bit type if no data is needed.")

        gearbox = TaggedReadGearbox(read_width, client_type.data.bitwidth)(
            clk=ports.clk, rst=ports.rst, in_=demuxed_upstream_channel)
        # Per-iteration values are bound as lambda defaults so the callbacks
        # stay correct regardless of when they are invoked.
        client_resp_channel = gearbox.out.transform(
            lambda m, _client_type=client_type: _client_type({
                "tag": m.tag,
                "data": m.data.bitcast(_client_type.data)
            }))

        # Assign the client response to the correct port.
        client_bundle, client_froms = client.type.pack(
            resp=client_resp_channel)
        client_req = client_froms["req"]
        # Request length is the client data width rounded up to whole bytes.
        length_bytes = (client_type.data.bitwidth + 7) // 8
        tagged_client_req = client_req.transform(
            lambda r, _length=length_bytes, _tag=idx:
            (hostmem_module.UpstreamReadReq({
                "address": r.address,
                "length": _length,
                # TODO: Change this once we support tag-rewriting.
                "tag": _tag
            })))
        tagged_client_reqs.append(tagged_client_req)

        # Set the port for the client request.
        setattr(ports, HostmemReadProcessorImpl.reqPortMap[client],
                client_bundle)

      # Assign the multiplexed read request to the upstream request.
      # TODO: Don't release a request until the client is ready to accept
      # the response otherwise the system could deadlock.
      muxed_client_reqs = esi.ChannelMux(tagged_client_reqs)
      upstream_req_channel.assign(muxed_client_reqs)
      # The map is only needed while building; drop the bundle references.
      HostmemReadProcessorImpl.reqPortMap.clear()

  return HostmemReadProcessorImpl
897
898
@modparams
def TaggedWriteGearbox(input_bitwidth: int,
                       output_bitwidth: int) -> type["TaggedWriteGearboxImpl"]:
  """Build a gearbox to convert the client data to upstream write chunks.
  Assumes a struct {address, tag, data} and only gearboxes the data. The tag
  and address are carried alongside and the output struct {address, tag, data,
  valid_bytes} is re-assembled at the end.

  Args:
    input_bitwidth: Width in bits of the client 'data' field. A width which is
      not a multiple of 8 is widened to the next multiple of 8 with
      `pad_or_truncate` before gearboxing.
    output_bitwidth: Width in bits of the upstream 'data' field. Must be a
      positive multiple of 8.

  Returns:
    The gearbox module class. Its `num_chunks` attribute is
    ceil(padded input width / output width).

  Raises:
    ValueError: If `output_bitwidth` is not a positive multiple of 8.
  """

  if output_bitwidth % 8 != 0:
    raise ValueError("Output bitwidth must be a multiple of 8.")
  if output_bitwidth <= 0:
    # Without this, the `num_chunks` division below fails with an unhelpful
    # ZeroDivisionError (or yields a negative chunk count).
    raise ValueError("Output bitwidth must be positive.")
  input_pad_bits = 0
  if input_bitwidth % 8 != 0:
    input_pad_bits = 8 - (input_bitwidth % 8)
  input_padded_bitwidth = input_bitwidth + input_pad_bits

  class TaggedWriteGearboxImpl(Module):
    clk = Clock()
    rst = Reset()
    in_ = InputChannel(
        StructType([
            ("address", UInt(64)),
            ("tag", esi.HostMem.TagType),
            ("data", Bits(input_bitwidth)),
        ]))
    out = OutputChannel(
        StructType([
            ("address", UInt(64)),
            ("tag", esi.HostMem.TagType),
            ("data", Bits(output_bitwidth)),
            ("valid_bytes", Bits(8)),
        ]))

    # Read by users of this module (via the class) to learn how many upstream
    # messages correspond to one client message.
    num_chunks = ceil(input_padded_bitwidth / output_bitwidth)

    @generator
    def build(ports):
      upstream_ready = Wire(Bits(1))
      ready_for_client = Wire(Bits(1))
      client_tag_and_data, client_valid = ports.in_.unwrap(ready_for_client)
      client_data = client_tag_and_data.data
      if input_pad_bits > 0:
        client_data = client_data.pad_or_truncate(input_padded_bitwidth)
      client_xact = ready_for_client & client_valid
      input_bitwidth_bytes = input_padded_bitwidth // 8
      output_bitwidth_bytes = output_bitwidth // 8

      # Determine if gearboxing is necessary and whether it needs to be
      # gearboxed up or just sliced down.
      if output_bitwidth == input_padded_bitwidth:
        # Same width: pure pass-through.
        upstream_data_bits = client_data
        upstream_valid = client_valid
        ready_for_client.assign(upstream_ready)
        tag = client_tag_and_data.tag
        address = client_tag_and_data.address
        valid_bytes = Bits(8)(input_bitwidth_bytes)
      elif output_bitwidth > input_padded_bitwidth:
        # Upstream is wider: widen the data and report how many bytes of it
        # are meaningful.
        upstream_data_bits = client_data.as_bits(output_bitwidth)
        upstream_valid = client_valid
        ready_for_client.assign(upstream_ready)
        tag = client_tag_and_data.tag
        address = client_tag_and_data.address
        valid_bytes = Bits(8)(input_bitwidth_bytes)
      else:
        # Upstream is narrower: create registers equal to the number of
        # upstream transactions needed to complete the transmission and send
        # them one at a time.
        num_chunks = TaggedWriteGearboxImpl.num_chunks
        num_chunks_idx_bitwidth = clog2(num_chunks)
        if input_padded_bitwidth % output_bitwidth == 0:
          padding_numbits = 0
        else:
          padding_numbits = output_bitwidth - (input_padded_bitwidth %
                                               output_bitwidth)
        # Extend the data to a whole number of chunks so every slice below is
        # exactly `output_bitwidth` wide.
        client_data_padded = BitsSignal.concat(
            [Bits(padding_numbits)(0), client_data])
        chunks = [
            client_data_padded[i * output_bitwidth:(i + 1) * output_bitwidth]
            for i in range(num_chunks)
        ]
        # Latch all chunks when the client message is accepted.
        chunk_regs = Array(Bits(output_bitwidth), num_chunks)([
            c.reg(ports.clk, ce=client_xact, name=f"chunk_{idx}")
            for idx, c in enumerate(chunks)
        ])
        increment = Wire(Bits(1))
        clear = Wire(Bits(1))
        counter = Counter(num_chunks_idx_bitwidth)(clk=ports.clk,
                                                   rst=ports.rst,
                                                   increment=increment,
                                                   clear=clear)
        upstream_data_bits = chunk_regs[counter.out]
        # Valid from the cycle after a client message is accepted until the
        # last chunk has been accepted upstream.
        upstream_valid = ControlReg(ports.clk, ports.rst, [client_xact],
                                    [clear])
        upstream_xact = upstream_valid & upstream_ready
        clear.assign(upstream_xact & (counter.out == (num_chunks - 1)))
        increment.assign(upstream_xact)
        # Only accept a new client message when no chunks are pending.
        ready_for_client.assign(~upstream_valid)

        # Byte offset of the current chunk: counter * output_bitwidth_bytes.
        # Built as a sum of shifted copies of the counter, one per set bit of
        # the (constant) chunk size. A plain shift by clog2(chunk size) is only
        # correct when the chunk size is a power of two; for those sizes this
        # loop produces exactly that single shifted term.
        counter_bits = counter.out.as_bits()
        counter_bytes = None
        for shift in range(output_bitwidth_bytes.bit_length()):
          if not (output_bitwidth_bytes >> shift) & 1:
            continue
          term = BitsSignal.concat([counter_bits, Bits(shift)(0)]).as_uint()
          if counter_bytes is None:
            counter_bytes = term
          else:
            counter_bytes = counter_bytes + term

        # The client's tag and address must be held for the duration of the
        # chunk sequence since the client channel has moved on.
        tag_reg = client_tag_and_data.tag.reg(ports.clk,
                                              ce=client_xact,
                                              name="tag_reg")
        addr_reg = client_tag_and_data.address.reg(ports.clk,
                                                   ce=client_xact,
                                                   name="address_reg")
        address = (addr_reg + counter_bytes).as_uint(64)
        tag = tag_reg
        # Every chunk is full except possibly the last one, which excludes the
        # chunk-alignment padding added above.
        valid_bytes = Mux(counter.out == (num_chunks - 1),
                          Bits(8)(output_bitwidth_bytes),
                          Bits(8)((output_bitwidth - padding_numbits) // 8))

      # Construct the output channel. Shared logic across all three cases.
      upstream_channel, upstrm_ready_sig = TaggedWriteGearboxImpl.out.type.wrap(
          {
              "address": address,
              "tag": tag,
              "data": upstream_data_bits,
              "valid_bytes": valid_bytes
          }, upstream_valid)
      upstream_ready.assign(upstrm_ready_sig)
      ports.out = upstream_channel

  return TaggedWriteGearboxImpl
1022
1023
@modparams
def EmitEveryN(message_type: Type, N: int) -> type['EmitEveryNImpl']:
  """Build a module which forwards one message per group of N input messages.
  The forwarded message is the last one received in each group.

  Args:
    message_type: Payload type of both the input and the output channel.
    N: Group size. N == 1 degenerates to a pass-through.

  Raises:
    ValueError: If N is less than 1.
  """

  if N < 1:
    raise ValueError("N must be >= 1")

  class EmitEveryNImpl(Module):
    clk = Clock()
    rst = Reset()
    in_ = InputChannel(message_type)
    out = OutputChannel(message_type)

    @generator
    def build(ports):
      input_ready = Wire(Bits(1))
      msg, msg_valid = ports.in_.unwrap(input_ready)
      # High on cycles in which an input message is actually consumed.
      msg_accepted = msg_valid & input_ready

      # With a group size of one, every message is forwarded unchanged.
      if N == 1:
        passthrough, downstream_ready = EmitEveryNImpl.out.type.wrap(
            msg, msg_valid)
        input_ready.assign(downstream_ready)
        ports.out = passthrough
        return

      # Count consumed messages; wraps back to zero at the end of each group.
      idx_width = clog2(N)
      group_done = Wire(Bits(1))
      msg_counter = Counter(idx_width)(clk=ports.clk,
                                       rst=ports.rst,
                                       increment=msg_accepted,
                                       clear=group_done)

      # Remember the most recently consumed message; when the group completes
      # this holds the group's final message.
      held_msg = msg.reg(ports.clk, ports.rst, ce=msg_accepted, name="last_msg")
      # The group completes when its N-th message is consumed.
      final_msg_taken = (msg_counter.out == UInt(idx_width)(N - 1)) & msg_accepted
      group_done.assign(final_msg_taken)

      # Output stays valid from group completion until downstream takes it.
      output_taken = Wire(Bits(1))
      output_pending = ControlReg(ports.clk, ports.rst, [final_msg_taken],
                                  [output_taken])

      emitted, downstream_ready = EmitEveryNImpl.out.type.wrap(
          held_msg, output_pending)
      # Back-pressure the input only while an output is pending and downstream
      # is not accepting it.
      input_ready.assign(~(output_pending & ~downstream_ready))
      output_taken.assign(output_pending & downstream_ready)

      ports.out = emitted

  return EmitEveryNImpl
1075
1076
def HostMemWriteProcessor(
    write_width: int, hostmem_module,
    reqs: List[esi._OutputBundleSetter]) -> type["HostMemWriteProcessorImpl"]:
  """Construct a host memory write request module to orchestrate the write
  connections. Responsible for gearboxing each client's data to the upstream
  write width, multiplexing the resulting requests onto the single upstream
  write bundle, and routing the upstream acks back to the correct clients
  (one ack per client message).

  Generate this module dynamically to allow for multiple write clients of
  multiple types to be directly accommodated.

  Args:
    write_width: Width in bits of the upstream write data.
    hostmem_module: Provider of the upstream `write` port type and the
      `UpstreamWriteReq` struct type.
    reqs: The write client bundles to service. At most 256 are supported since
      the upstream tag is used as the client index.

  Returns:
    The module class. Its `reqPortMap` maps each entry of `reqs` to the name
    of the output port created for it.
  """

  class HostMemWriteProcessorImpl(Module):

    clk = Clock()
    rst = Reset()

    # Add an output port for each write client. Ports are injected into the
    # class namespace under a generated name since their number and names are
    # only known at module construction time.
    reqPortMap: Dict[esi._OutputBundleSetter, str] = {}
    for req in reqs:
      name = "client_" + req.client_name_str
      locals()[name] = Output(req.type)
      reqPortMap[req] = name

    # And then the port which goes to the host.
    upstream = Output(hostmem_module.write.type)

    @generator
    def build(ports):
      clk = ports.clk
      rst = ports.rst

      # If there are no write clients, just create a no-op write bundle whose
      # request channel is never valid.
      if len(reqs) == 0:
        req, _ = Channel(hostmem_module.UpstreamWriteReq).wrap(
            {
                "address": 0,
                "tag": 0,
                "data": 0,
                "valid_bytes": 0,
            }, 0)
        write_bundle, _ = hostmem_module.write.type.pack(req=req)
        ports.upstream = write_bundle
        return

      # The upstream tag carries the client index (see below), so the client
      # count is bounded by the tag range.
      assert len(reqs) <= 256, "More than 256 write clients not supported."

      # Pack the upstream bundle and leave the request as a wire to be driven
      # by the request mux at the end.
      upstream_req_channel = Wire(Channel(hostmem_module.UpstreamWriteReq))
      upstream_write_bundle, froms = hostmem_module.write.type.pack(
          req=upstream_req_channel)
      ports.upstream = upstream_write_bundle
      upstream_ack_tag = froms["ackTag"]

      # Split the upstream acks into one channel per client.
      demuxed_acks = esi.TaggedDemux(len(reqs), upstream_ack_tag.type)(
          clk=ports.clk, rst=ports.rst, in_=upstream_ack_tag)

      # TODO: re-write the tags and store the client and client tag.

      # Build the write request channels and ack wires.
      write_channels: List[ChannelSignal] = []
      for idx, req in enumerate(reqs):
        # Get the request channel and its data type.
        reqch = [c.channel for c in req.type.channels if c.name == 'req'][0]
        client_type = reqch.inner_type
        # NOTE(review): the check is on the 'data' field but the lowering is
        # applied to the whole request type -- confirm this is intended.
        if isinstance(client_type.data, Window):
          client_type = client_type.lowered_type

        # Pack up the bundle and assign the request channel.
        write_req_bundle_type = esi.HostMem.write_req_bundle_type(
            client_type.data)
        input_flit_ack = Wire(upstream_ack_tag.type)
        bundle_sig, froms = write_req_bundle_type.pack(ackTag=input_flit_ack)

        gearbox_mod = TaggedWriteGearbox(client_type.data.bitwidth, write_width)
        gearbox_in_type = gearbox_mod.in_.type.inner_type
        tagged_client_req = froms["req"]
        bitcast_client_req = tagged_client_req.transform(
            lambda m: gearbox_in_type({
                "tag": m.tag,
                "address": m.address,
                "data": m.data.bitcast(gearbox_in_type.data)
            }))

        # Gearbox the client's data to the upstream write width.
        gearbox = gearbox_mod(clk=ports.clk,
                              rst=ports.rst,
                              in_=bitcast_client_req)
        # Replace the client's tag with the client index so the ack can be
        # routed back by the demux above.
        write_channels.append(
            gearbox.out.transform(lambda m: m.type({
                "address": m.address,
                "tag": idx,
                "data": m.data,
                "valid_bytes": m.valid_bytes
            })))

        # Count the number of acks received from hostmem for this client
        # and only send one back to the client per input.
        ack_every_n = EmitEveryN(upstream_ack_tag.type, gearbox_mod.num_chunks)(
            clk=clk, rst=rst, in_=demuxed_acks.get_out(idx))
        input_flit_ack.assign(ack_every_n.out)

        # Set the port for the client request.
        setattr(ports, HostMemWriteProcessorImpl.reqPortMap[req], bundle_sig)

      # Build a channel mux for the write requests.
      muxed_write_channel = esi.ChannelMux(write_channels)
      upstream_req_channel.assign(muxed_write_channel)

  return HostMemWriteProcessorImpl
1185
1186
@modparams
def ChannelHostMem(read_width: int,
                   write_width: int) -> typing.Type['ChannelHostMemImpl']:
  """Build a HostMem service implementation with the given upstream read and
  write data widths (in bits).

  Raises:
    ValueError: If `write_width` is not a multiple of 8.
  """

  class ChannelHostMemImpl(esi.ServiceImplementation):
    """Builds a HostMem service which multiplexes multiple HostMem clients into
    two (read and write) bundles of the given data width."""

    clk = Clock()
    rst = Reset()

    # Upstream read interface: requests go out, tagged data comes back.
    UpstreamReadReq = StructType([
        ("address", UInt(64)),
        ("length", UInt(32)),  # In bytes.
        ("tag", UInt(8)),
    ])
    read = Output(
        Bundle([
            BundledChannel("req", ChannelDirection.TO, UpstreamReadReq),
            BundledChannel(
                "resp", ChannelDirection.FROM,
                StructType([
                    ("tag", esi.HostMem.TagType),
                    ("data", Bits(read_width)),
                ])),
        ]))

    # Upstream write interface: requests go out, ack tags come back.
    if write_width % 8 != 0:
      raise ValueError("Write width must be a multiple of 8.")
    UpstreamWriteReq = StructType([
        ("address", UInt(64)),
        ("tag", UInt(8)),
        ("data", Bits(write_width)),
        ("valid_bytes", Bits(8)),
    ])
    write = Output(
        Bundle([
            BundledChannel("req", ChannelDirection.TO, UpstreamWriteReq),
            BundledChannel("ackTag", ChannelDirection.FROM, UInt(8)),
        ]))

    @generator
    def generate(ports, bundles: esi._ServiceGeneratorBundles):

      def clients_of(port_name: str):
        """Select the client requests made against the named service port."""
        return [c for c in bundles.to_client_reqs if c.port == port_name]

      # The read side lives in its own module. That module exposes one output
      # port per client, which must be wired to the client here since a
      # request cannot be serviced from inside a different module.
      readers = clients_of('read')
      reader_mod = HostmemReadProcessor(read_width, ChannelHostMemImpl, readers)
      reader = reader_mod(clk=ports.clk, rst=ports.rst)
      ports.read = reader.upstream
      for client in readers:
        client.assign(getattr(reader, reader_mod.reqPortMap[client]))

      # The write side follows the same pattern.
      writers = clients_of('write')
      writer_mod = HostMemWriteProcessor(write_width, ChannelHostMemImpl,
                                         writers)
      writer = writer_mod(clk=ports.clk, rst=ports.rst)
      ports.write = writer.upstream
      for client in writers:
        client.assign(getattr(writer, writer_mod.reqPortMap[client]))

  return ChannelHostMemImpl
1253
1254
@modparams
def DummyToHostEngine(client_type: Type) -> type['DummyToHostEngineImpl']:
  """Create a placeholder to-host DMA engine for `client_type` messages. The
  engine declares an input channel but builds no logic behind it."""

  class DummyToHostEngineImpl(esi.EngineModule):

    @property
    def TypeName(self):
      return "DummyToHostEngine"

    clk = Clock()
    rst = Reset()
    input_channel = InputChannel(client_type)

    @generator
    def build(ports):
      # Intentionally empty: nothing is connected to the input channel.
      return

  return DummyToHostEngineImpl
1274
1275
@modparams
def DummyFromHostEngine(client_type: Type) -> type['DummyFromHostEngineImpl']:
  """Create a placeholder from-host DMA engine for `client_type` messages. Its
  output channel is driven with a constant-zero payload and is never valid."""

  class DummyFromHostEngineImpl(esi.EngineModule):

    @property
    def TypeName(self):
      return "DummyFromHostEngine"

    clk = Clock()
    rst = Reset()
    output_channel = OutputChannel(client_type)

    @generator
    def build(ports):
      # Valid is tied low, so the all-zeros payload is never presented as a
      # message. The channel's ready signal is unused.
      never_valid = Bits(1)(0)
      zero_payload = Bits(client_type.bitwidth)(0).bitcast(client_type)
      idle_channel, _ = Channel(client_type).wrap(zero_payload, never_valid)
      ports.output_channel = idle_channel

  return DummyFromHostEngineImpl
1298
1299
def ChannelEngineService(
    to_host_engine_gen: Callable,
    from_host_engine_gen: Callable) -> type['ChannelEngineService']:
  """Returns a channel service implementation which calls
  to_host_engine_gen(<client_type>) or from_host_engine_gen(<client_type>) to
  generate the to_host and from_host engines for each channel. Does not support
  engines which can service multiple clients at once.

  Args:
    to_host_engine_gen: Called with a channel's message type; returns the
      engine module to instantiate for a channel coming FROM the client.
    from_host_engine_gen: Same, for a channel headed TO the client.
  """

  class ChannelEngineService(esi.ServiceImplementation):
    """Service implementation which services the clients via a per-channel DMA
    engine."""

    clk = Clock()
    rst = Reset()

    @generator
    def build(ports, bundles: esi._ServiceGeneratorBundles):
      clk = ports.clk
      rst = ports.rst

      def build_engine_appid(client_appid: List[esi.AppID],
                             channel_name: str) -> str:
        """Name an engine after its client's AppID path and channel name."""
        appid_strings = [str(appid) for appid in client_appid]
        return f"{'_'.join(appid_strings)}.{channel_name}"

      def build_engine(bundle, bc: BundledChannel, input_channel=None):
        """Instantiate and record the engine servicing channel `bc` of the
        client request `bundle`.

        `input_channel` is the client's channel to feed into the engine; it is
        only supplied for channels coming FROM the client. Returns the engine
        instance.
        """
        idbase = build_engine_appid(bundle.client_name, bc.name)
        eng_appid = esi.AppID(idbase)
        # DMA engines require at least 1 byte of data; substitute Bits(8)
        # for zero-width (void) channel types so the engine never sees a
        # zero-length transfer.
        engine_client_type = bc.channel.inner_type
        is_void = (engine_client_type.bitwidth == 0)
        if is_void:
          engine_client_type = Bits(8)
        if bc.direction == ChannelDirection.FROM:
          engine_mod = to_host_engine_gen(engine_client_type)
        else:
          engine_mod = from_host_engine_gen(engine_client_type)
        eng_inputs = {
            "clk": clk,
            "rst": rst,
        }
        eng_details: Dict[str, object] = {"engine_inst": eng_appid}
        if input_channel is not None:
          # For void channels, widen the 0-bit input to the 8-bit
          # placeholder the engine expects.
          if is_void:
            input_channel = input_channel.transform(lambda _: Bits(8)(0))
          # Insert a buffer stage to convert between signaling protocols when
          # the engine's input uses a different one than the client's channel.
          if (engine_mod.input_channel.type.signaling
              != input_channel.type.signaling):
            input_channel = input_channel.buffer(
                clk,
                rst,
                stages=1,
                output_signaling=engine_mod.input_channel.type.signaling)
          eng_inputs["input_channel"] = input_channel
        # Connect whichever optional service ports this engine declares.
        if hasattr(engine_mod, "mmio"):
          mmio_appid = esi.AppID(idbase + ".mmio")
          eng_inputs["mmio"] = esi.MMIO.read_write(mmio_appid)
          eng_details["mmio"] = mmio_appid
        if hasattr(engine_mod, "hostmem_write"):
          eng_inputs["hostmem_write"] = esi.HostMem.write_from_bundle(
              esi.AppID(idbase + ".hostmem_write"),
              engine_mod.hostmem_write.type)
        if hasattr(engine_mod, "hostmem_read"):
          eng_inputs["hostmem_read"] = esi.HostMem.read_from_bundle(
              esi.AppID(idbase + ".hostmem_read"), engine_mod.hostmem_read.type)
        engine = engine_mod(appid=eng_appid, **eng_inputs)
        engine_rec = bundles.emit_engine(engine, details=eng_details)
        engine_rec.add_record(bundle, {bc.name: {}})
        return engine

      for bundle in bundles.to_client_reqs:
        bundle_type = bundle.type
        to_channels = {}
        # Create a DMA engine for each channel headed TO the client (from the
        # host).
        for bc in bundle_type.channels:
          if bc.direction == ChannelDirection.TO:
            engine = build_engine(bundle, bc)
            out_chan = engine.output_channel
            # For void channels, narrow the 8-bit placeholder back to 0-bit.
            if bc.channel.inner_type.bitwidth == 0:
              out_chan = out_chan.transform(lambda _: Bits(0)(0))
            to_channels[bc.name] = out_chan

        client_bundle_sig, froms = bundle_type.pack(**to_channels)
        bundle.assign(client_bundle_sig)

        # Create a DMA engine for each channel headed FROM the client (to the
        # host).
        for bc in bundle_type.channels:
          if bc.direction == ChannelDirection.FROM:
            build_engine(bundle, bc, froms[bc.name])

  return ChannelEngineService
return wrap(CMemoryType::get(unwrap(ctx), baseType, numElements))
Tuple[BitsSignal, ChannelSignal] build_addr_read(ChannelSignal read_addr_chan, int num_clients, int manifest_loc)
Definition common.py:580
generate(ports, esi._ServiceGeneratorBundles bundles)
Definition common.py:477
Tuple[Dict[int, AssignableSignal], int] build_table(bundles)
Definition common.py:483
build_read(ports, int manifest_loc, Dict[int, AssignableSignal] table)
Definition common.py:511
type["ChannelDemuxNImpl"] ChannelDemuxN_HalfStage_ReadyBlocking(Type data_type, int num_outs, int next_sel_width)
Definition common.py:171
HostmemReadProcessor(int read_width, hostmem_module, List[esi._OutputBundleSetter] reqs)
Definition common.py:790
type["ChannelDemuxTree"] ChannelDemuxTree_HalfStage_ReadyBlocking(Type data_type, int num_outs, int branching_factor_log2)
Definition common.py:264
Module HeaderMMIO(int manifest_loc)
Definition common.py:70
type["TaggedWriteGearboxImpl"] TaggedWriteGearbox(int input_bitwidth, int output_bitwidth)
Definition common.py:901
type[ 'DummyToHostEngineImpl'] DummyToHostEngine(Type client_type)
Definition common.py:1256
type[ 'DummyFromHostEngineImpl'] DummyFromHostEngine(Type client_type)
Definition common.py:1277
type[ 'EmitEveryNImpl'] EmitEveryN(Type message_type, int N)
Definition common.py:1025
type["TaggedReadGearboxImpl"] TaggedReadGearbox(int input_bitwidth, int output_bitwidth)
Definition common.py:704
type["HostMemWriteProcessorImpl"] HostMemWriteProcessor(int write_width, hostmem_module, List[esi._OutputBundleSetter] reqs)
Definition common.py:1079