diff --git a/.github/workflows/sonarcloud.yaml b/.github/workflows/sonarcloud.yaml index 121533ea..83457faf 100644 --- a/.github/workflows/sonarcloud.yaml +++ b/.github/workflows/sonarcloud.yaml @@ -88,7 +88,7 @@ jobs: overwrite: true - name: SonarQube Scan - uses: SonarSource/sonarqube-scan-action@v5 + uses: SonarSource/sonarqube-scan-action@v8 if: ${{ !cancelled() }} env: SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} diff --git a/docs/explanation/nuclearnet.md b/docs/explanation/nuclearnet.md index 178f8b4a..7978ca36 100644 --- a/docs/explanation/nuclearnet.md +++ b/docs/explanation/nuclearnet.md @@ -1,23 +1,19 @@ -# NUClearNet: Peer-to-Peer Networking +# NUClearNet: peer-to-peer networking NUClearNet is NUClear's built-in networking layer — a decentralized, peer-to-peer messaging system that lets NUClear nodes communicate transparently across a network. It's designed for robotics and distributed systems where nodes need to discover each other automatically and exchange typed messages with minimal configuration. -## Architecture & Design +## Architecture and design ```mermaid graph TD A[Node A] <-->|UDP| B[Node B] A <-->|UDP| C[Node C] - A <-->|UDP| D[Node D] B <-->|UDP| C - B <-->|UDP| D - C <-->|UDP| D style A fill:#4a9eff,color:#fff style B fill:#4a9eff,color:#fff style C fill:#4a9eff,color:#fff - style D fill:#4a9eff,color:#fff ``` Key design principles: @@ -25,14 +21,16 @@ Key design principles: - **Decentralized mesh** — no central server or message broker. Every node is equal. - **Autonomous discovery** — nodes find each other via periodic announcements, no manual configuration of peer addresses. -- **UDP-only** — both discovery and data transfer use UDP (no TCP). +- **UDP-only** — both discovery and data transfer use UDP (User Datagram Protocol), not TCP. This keeps the implementation simple and avoids head-of-line blocking. -- **Two socket types** — each node has an *announce socket* (for discovery) and a *data socket* (for payload transfer). 
+- **Two socket types** — each node has an *announce socket* (for receiving discovery messages) and a *data socket* (for sending announces and transferring data). +- **Subscription-based routing** — nodes advertise which message types they are interested in, + and senders only transmit messages to peers that have subscribed to that type. The announce socket listens on a shared multicast/broadcast address that all nodes agree on. -The data socket uses an ephemeral port unique to each node — peers learn each other's data address through announce packets. +The data socket uses an ephemeral port unique to each node — peers learn each other's data address from the UDP source address of announce packets. -## Component Layers +## Modular component architecture ```mermaid graph TB @@ -50,61 +48,110 @@ graph TB NC[NetworkController] end - subgraph "NUClearNetwork Engine" - NN[NUClearNetwork class] - end - - subgraph "Operating System" - AS[Announce Socket - UDP multicast/broadcast] - DS[Data Socket - UDP unicast] + subgraph "NUClearNet Engine" + NN[NUClearNet] + DISC[Discovery] + FRAG[Fragmentation] + REL[Reliability] + ROUTE[Routing] + DEDUP[PacketDeduplicator] + RTT[RTTEstimator] + + subgraph "Operating System" + AS[Announce Socket - UDP multicast/broadcast] + DS[Data Socket - UDP ephemeral port] + end end DSL --> NW --> NC EMIT --> NE --> NC NC <--> NN + NN --> DISC + NN --> FRAG + NN --> REL + NN --> ROUTE + NN --> DEDUP + REL --> RTT NN <--> AS NN <--> DS ``` -## Peer Discovery +The NUClearNet engine is decomposed into focused modules: + +| Module | Responsibility | +| -------------------- | ---------------------------------------------------------------------------------------------------- | +| `Discovery` | Peer lifecycle — announce, join/leave detection, peer timeout | +| `Fragmentation` | Splitting large messages into MTU-sized (Maximum Transmission Unit) fragments, reassembly on receive | +| `Reliability` | ACK (Acknowledgment) tracking, retransmission scheduling 
| +| `Routing` | Subscription-based message filtering per peer | +| `PacketDeduplicator` | Sliding-window duplicate detection per peer | +| `RTTEstimator` | Per-peer RTT (Round-Trip Time) estimation for retransmission timing | -Every node periodically broadcasts an `AnnouncePacket` on the announce address. +## Peer discovery + +Every node periodically sends an `AnnouncePacket` on the announce address. This is how nodes find each other. -### Discovery Sequence +### Discovery sequence ```mermaid sequenceDiagram participant A as Node A (existing) - participant Net as Network (multicast) participant B as Node B (joining) Note over A: Running, announcing every ~interval - A->>Net: AnnouncePacket (name="A") + A-->>B: AnnouncePacket (name="A") [multicast] + Note over B: Not yet joined — doesn't receive + + Note over B: Starts up, joins multicast group + B-->>A: AnnouncePacket (name="B") [multicast] - Note over B: Starts up, begins announcing - B->>Net: AnnouncePacket (name="B") + Note over A: New peer heard on announce channel! + A->>A: Add B (announce_heard=true, handshake=IDLE) + A-->>B: AnnouncePacket (forced re-announce) [multicast] + A->>B: ConnectPacket (SYN) [data port] - Note over A: New peer heard! 
- A->>A: Add B to peer list - A->>B: Immediate announce back (unicast) - A->>A: Fire NetworkJoin event + Note over B: Heard A on announce channel + B->>B: Add A (announce_heard=true, handshake=IDLE) + B-->>A: AnnouncePacket (forced re-announce) [multicast] - Note over B: Hears A's announce - B->>B: Add A to peer list - B->>B: Fire NetworkJoin event + Note over B: Received SYN on data port from A + B->>B: handshake → SYN_RECEIVED + B->>A: ConnectPacket (SYN+ACK) [data port] + + Note over A: Received SYN+ACK on data port + A->>A: handshake → CONFIRMED + A->>A: announce_heard ✓ + CONFIRMED → Fire NetworkJoin + A->>B: ConnectPacket (ACK) [data port] + + Note over B: Received ACK on data port + B->>B: handshake → CONFIRMED + B->>B: announce_heard ✓ + CONFIRMED → Fire NetworkJoin loop Ongoing - A->>Net: AnnouncePacket every ~500ms - B->>Net: AnnouncePacket every ~500ms + A-->>B: AnnouncePacket every ~500ms [multicast] + B-->>A: AnnouncePacket every ~500ms [multicast] end - Note over A: No packet from B for 2+ seconds - A->>A: Remove B from peer list - A->>A: Fire NetworkLeave event + alt Graceful shutdown + B-->>A: LeavePacket [multicast] + Note over A: Immediate removal, Fire NetworkLeave + else Timeout (no packets for 2s) + Note over A: No packet from B for 2 seconds + A->>A: Remove B, Fire NetworkLeave + end ``` -### Announce Address Options +Dashed lines (`-->>`) represent packets sent to the multicast/broadcast group (announce channel). +Solid lines (`->>`) represent packets sent directly to a peer's data port (unicast). + +When a node hears an announce from an unknown peer, +it immediately re-announces to the multicast group (so the new peer can hear it) +and sends a CONNECT(SYN) to the peer's data port to begin the data handshake. +The connection is only considered "up" once both the announce path and data path are confirmed +(see [Connection establishment](#connection-establishment) below). 
+ +### Announce address options The announce address can be: @@ -113,21 +160,270 @@ The announce address can be: - **Broadcast** (e.g., `255.255.255.255`) — works on simple LANs without multicast support. - **Unicast** — for point-to-point setups or testing. -### Peer Timeout +### NAT-friendly port learning + +NAT (Network Address Translation) devices translate source ports and addresses between networks. +Announces are sent from the *data socket* (ephemeral port), not the announce socket. +This means the receiver learns the sender's data port directly from the UDP source address — no explicit port field is needed in the announce packet. +This design also works naturally with NAT devices that translate source ports. + +### Peer timeout + +Each peer's `last_seen` timestamp is refreshed every time any packet is received from them. +If no packet is received within the configured timeout (default 2 seconds), the peer is considered gone — it's removed from the peer list and a `NetworkLeave` event fires. + +### Connection establishment + +After discovering a peer via announce packets, +both sides must satisfy two independent conditions before the connection is considered "up": + +1. **Announce path confirmed** (`announce_heard`) — the peer's announce was received on the multicast/broadcast channel. + This proves that their data port can reach our announce address. +1. **Data handshake confirmed** (`handshake == CONFIRMED`) — a 3-way handshake over the data ports proves bidirectional data connectivity. + +Both conditions are required because NAT devices may remap ephemeral ports, +making the data-to-data paths unreliable, +while the announce path (multicast group membership) confirms that broadcast-targeted messages will arrive. 
+ +There are four communication paths between two nodes: + +| Path | Meaning | How confirmed | +| --------- | ---------------------------------------- | ------------------------------- | +| b_d → a_a | B's data port → A's announce (multicast) | A receives B's AnnouncePacket | +| a_d → b_a | A's data port → B's announce (multicast) | B receives A's AnnouncePacket | +| a_d → b_d | A's data port → B's data port (unicast) | B receives ConnectPacket from A | +| b_d → a_d | B's data port → A's data port (unicast) | A receives ConnectPacket from B | + +The packet type encodes which path was used: + +- **ANNOUNCE** packets are always sent to the multicast group — receiving one proves the announce path. +- **CONNECT** packets are always sent to a peer's data port — receiving one proves the data path. + +This means no socket tracking is needed to determine which path a packet arrived on; +the packet type itself is the proof. + +#### Two-flag connection model + +```mermaid +stateDiagram-v2 + state "Connection Status" as conn { + state "announce_heard = false" as af + state "announce_heard = true" as at + + state "Handshake State Machine" as hs { + [*] --> IDLE + IDLE --> SYN_SENT: mark_syn_sent() + IDLE --> SYN_RECEIVED: receive SYN + SYN_SENT --> SYN_RECEIVED: receive SYN + SYN_SENT --> CONFIRMED: receive SYN+ACK + SYN_RECEIVED --> CONFIRMED: receive ACK or SYN+ACK + } + } + + note right of conn + Connected = announce_heard AND handshake == CONFIRMED + Either flag can be satisfied first. + end note +``` + +The two flags are independent — they can be satisfied in any order: + +- **Normal flow**: Announce heard first (from periodic announce), then data handshake completes. +- **Late announce**: Data handshake completes first (CONNECT received before announce), + then the announce arrives and triggers the join event. 
+ +#### Handshake sequence (normal flow) + +```mermaid +sequenceDiagram + participant A as Node A + participant B as Node B + + B-->>A: AnnouncePacket (name="B") [multicast] + Note over A: announce_heard = true ✓ + + Note over A: Forces re-announce + sends SYN + A-->>B: AnnouncePacket (re-announce) [multicast] + A->>B: ConnectPacket (SYN) [data port] + Note over A: handshake = SYN_SENT + + Note over B: announce_heard = true ✓ + + Note over B: Received SYN on data port + Note over B: Knows: a_d→b_d ✓ + B->>A: ConnectPacket (SYN+ACK) [data port] + Note over B: handshake = SYN_RECEIVED -Each peer's `last_update` timestamp is refreshed every time a packet arrives from them. -If no packet is received for approximately 2 seconds (configurable), the peer is considered gone — it's removed from the peer list and a `NetworkLeave` event fires. + Note over A: Received SYN+ACK on data port + Note over A: Knows: b_d→a_d ✓, a_d→b_d ✓ (B responded) + Note over A: handshake = CONFIRMED + Note over A: announce_heard ✓ + CONFIRMED → CONNECTED ✓ + A->>B: ConnectPacket (ACK) [data port] -## Wire Protocol + Note over B: Received ACK on data port + Note over B: Knows: b_d→a_d ✓ (A confirmed receipt) + Note over B: handshake = CONFIRMED + Note over B: announce_heard ✓ + CONFIRMED → CONNECTED ✓ -All NUClearNet packets share a common header format. + Note over A,B: Both sides connected — data transfer begins +``` + +The knowledge progression: + +1. **A receives B's announce on multicast** — A learns that `b_d→a_a` works (announce path). +1. **A sends SYN to B's data port** — when B receives this, B learns that `a_d→b_d` works. +1. **B sends SYN+ACK to A's data port** — when A receives this, A learns that `b_d→a_d` works. + A also infers `a_d→b_d` works (because B's response proves the SYN arrived). +1. **A sends ACK to B's data port** — when B receives this, B learns that `b_d→a_d` works + (because A's ACK proves the SYN+ACK arrived). 
+ +After step 4 (and once B has heard A's re-announce, which confirms `a_d→b_a`), both sides have confirmed all four communication paths. +Each side fires its `NetworkJoin` event and starts accepting data as soon as its own two conditions are met — A after step 3, B after step 4. + +#### Late announce (data handshake completes first) + +If a CONNECT packet arrives from a peer before their announce has been heard +(for example, when multicast delivery is slower than unicast), +the data handshake proceeds normally but the connection is not declared "up" until the announce arrives: + +```mermaid +sequenceDiagram + participant A as Node A + participant B as Node B + + Note over B: B heard A's announce, sends SYN + B->>A: ConnectPacket (SYN) [data port] + Note over A: Creates peer entry (announce_heard=false) + Note over A: handshake: IDLE → SYN_RECEIVED + A->>B: ConnectPacket (SYN+ACK) [data port] + + B->>A: ConnectPacket (ACK) [data port] + Note over A: handshake = CONFIRMED + Note over A: But announce_heard=false → NOT connected yet + + Note over A: Later, B's announce arrives + B-->>A: AnnouncePacket from B [multicast] + Note over A: announce_heard = true ✓ + Note over A: announce_heard ✓ + CONFIRMED → CONNECTED ✓ + Note over A: Fire NetworkJoin event +``` + +#### Simultaneous open + +If both nodes hear each other's announces at nearly the same time, +both will send SYN simultaneously. 
+The state machine handles this gracefully: + +```mermaid +sequenceDiagram + participant A as Node A + participant B as Node B + + A->>B: ConnectPacket (SYN) [data port] + B->>A: ConnectPacket (SYN) [data port] + Note over A,B: Both in SYN_SENT, receive SYN → SYN_RECEIVED + + A->>B: ConnectPacket (SYN+ACK) [data port] + B->>A: ConnectPacket (SYN+ACK) [data port] + Note over A: SYN_RECEIVED → receives SYN+ACK → CONFIRMED + Note over B: SYN_RECEIVED → receives SYN+ACK → CONFIRMED + + A->>B: ConnectPacket (ACK) [data port] + B->>A: ConnectPacket (ACK) [data port] + Note over A,B: Both CONFIRMED (duplicate ACKs are harmless) +``` + +Duplicate or out-of-order handshake packets do not cause state regressions — +once a peer reaches CONFIRMED, it stays there. + +#### Handshake resilience + +UDP packets can be dropped at any point in the handshake. +Rather than adding a separate retransmission timer, +the handshake piggybacks on the periodic announce cycle (~500ms): + +Each time an announce is received from a peer whose handshake is incomplete, +the appropriate CONNECT packet is retransmitted: + +| Current state | Retransmit | Purpose | +| ------------- | ---------- | ----------------------------------------------- | +| IDLE | SYN | Initial SYN was never sent or was dropped | +| SYN_SENT | SYN | Our SYN was dropped, retry | +| SYN_RECEIVED | SYN+ACK | Our SYN+ACK was dropped, retry | +| CONFIRMED | ACK | Our ACK was dropped, help peer finish handshake | + +This provides automatic recovery for every drop scenario: + +```mermaid +sequenceDiagram + participant A as Node A + participant B as Node B + + A->>B: ConnectPacket (SYN) [data port] + Note over A: handshake = SYN_SENT + Note over B: SYN dropped ✗ + + B-->>A: AnnouncePacket [multicast, periodic] + Note over A: Peer not connected, retransmit + A->>B: ConnectPacket (SYN) [data port, retransmit] + + Note over B: Received SYN this time + B->>A: ConnectPacket (SYN+ACK) [data port] + Note over A: handshake = CONFIRMED ✓ + 
A->>B: ConnectPacket (ACK) [data port] + Note over B: handshake = CONFIRMED ✓ +``` + +Since both peers announce periodically, +a dropped packet is retried within at most one announce interval. +If the data path is permanently broken in one direction, +the handshake will never complete — which is correct, +since bidirectional data connectivity is required for message exchange. -### Packet Header +#### Connect packet + +```mermaid +block-beta + columns 6 + hdr["Header (5B)"]:3 flags["flags (1B)"]:3 + + style hdr fill:#ff6b6b,color:#fff + style flags fill:#ff922b,color:#fff +``` + +- **flags** — bit 0: SYN (initiating connection), bit 1: ACK (acknowledging receipt) + +| Flags | Value | Meaning | +| --------- | ----- | ---------------------------- | +| SYN | 0x01 | Initiating a new connection | +| ACK | 0x02 | Acknowledging a received SYN+ACK | +| SYN + ACK | 0x03 | Responding to a received SYN | + +CONNECT packets are always sent to a peer's data port (never to the multicast group). +Receiving a CONNECT packet proves that the sender's data port can reach your data port. + +#### Data gating + +While the connection is incomplete (either flag unsatisfied), +data packets from the peer are dropped. +This prevents processing messages from a peer whose connectivity has not been fully verified. +Once both `announce_heard` and `handshake == CONFIRMED` are satisfied, +normal data transfer begins immediately. + +### Graceful departure + +When a node shuts down cleanly, it sends a `LeavePacket` so peers can remove it immediately without waiting for the timeout. + +## Wire protocol + +All NUClearNet packets share a common 5-byte header. 
+ +### Packet header ```mermaid block-beta columns 8 - h1["0xE2"]:1 h2["0x98"]:1 h3["0xA2"]:1 ver["Version 0x02"]:1 type["Type"]:1 payload["Payload..."]:3 + h1["0xE2"]:1 h2["0x98"]:1 h3["0xA2"]:1 ver["Version 0x03"]:1 type["Type"]:1 payload["Payload..."]:3 style h1 fill:#ff6b6b,color:#fff style h2 fill:#ff6b6b,color:#fff @@ -137,94 +433,152 @@ block-beta style payload fill:#96ceb4,color:#fff ``` -- **Bytes 0-2**: `0xE2 0x98 0xA2` — the ☢ (radioactive) symbol in UTF-8. +- **Bytes 0–2**: `0xE2 0x98 0xA2` — the ☢ (radioactive) symbol in UTF-8. Acts as a magic number to identify NUClear packets. -- **Byte 3**: Version — currently `0x02` +- **Byte 3**: Protocol version — `0x03` for the current implementation - **Byte 4**: Packet type -### Packet Types +A received packet is only accepted if the magic bytes, version, and type field all pass validation. + +### Packet types + +| Type | Value | Purpose | +| -------- | ----- | ----------------------------------------- | +| ANNOUNCE | 1 | Periodic discovery broadcast | +| LEAVE | 2 | Graceful departure notification | +| DATA | 3 | Data payload (original or retransmission) | +| ACK | 4 | Acknowledgment of received fragments | +| CONNECT | 5 | Connection handshake (SYN/ACK flags) | + +### Announce packet + +```mermaid +block-beta + columns 10 + hdr["Header (5B)"]:2 nlen["name_length (2B)"]:2 name["name (variable)"]:2 nsub["num_subs (2B)"]:2 subs["sub hashes (N×8B)"]:2 + + style hdr fill:#ff6b6b,color:#fff + style nlen fill:#ffd93d,color:#333 + style name fill:#6bcb77,color:#fff + style nsub fill:#4d96ff,color:#fff + style subs fill:#9775fa,color:#fff +``` + +- **name_length** — length of the node name string +- **name** — the node's name (UTF-8, not null-terminated) +- **num_subscriptions** — how many type hashes follow (0 = interested in all messages) +- **subscription hashes** — `uint64_t` type hashes this node wants to receive -| Type | Value | Purpose | -| ------------------- | ----- | -------------------------------------- 
| -| ANNOUNCE | 1 | Periodic discovery broadcast | -| LEAVE | 2 | Graceful departure notification | -| DATA | 3 | Normal data payload | -| DATA_RETRANSMISSION | 4 | Retransmitted data fragment | -| ACK | 5 | Acknowledgment of received fragments | -| NACK | 6 | Request for specific missing fragments | +No port field is included — the receiver learns the sender's data port from the UDP source address. -### DataPacket Structure +### Data packet ```mermaid block-beta columns 12 - hdr["Header (5B)"]:2 pid["packet_id (2B)"]:2 pno["packet_no (2B)"]:2 pcnt["packet_count (2B)"]:2 rel["reliable (1B)"]:1 hash["type_hash (8B)"]:1 data["payload..."]:2 + hdr["Header (5B)"]:2 pid["packet_id (2B)"]:2 pno["packet_no (2B)"]:2 pcnt["packet_count (2B)"]:2 flags["flags (1B)"]:1 hash["type_hash (8B)"]:1 data["payload..."]:2 style hdr fill:#ff6b6b,color:#fff style pid fill:#ffd93d,color:#333 style pno fill:#6bcb77,color:#fff style pcnt fill:#4d96ff,color:#fff - style rel fill:#ff922b,color:#fff + style flags fill:#ff922b,color:#fff style hash fill:#9775fa,color:#fff style data fill:#96ceb4,color:#fff ``` -- **packet_id** — a semi-unique identifier for this message (groups fragments together) +- **packet_id** — a semi-unique identifier for this message group (wraps at 65535) - **packet_no** — which fragment this is (0-indexed) - **packet_count** — total number of fragments in this message -- **reliable** — whether this packet requires acknowledgment +- **flags** — bit 0: reliable delivery requested - **hash** — 64-bit type hash identifying what kind of data this is -- **data** — the serialized payload bytes +- **payload** — the serialized payload bytes for this fragment + +### ACK packet + +```mermaid +block-beta + columns 8 + hdr["Header (5B)"]:2 pid["packet_id (2B)"]:1 pcnt["packet_count (2B)"]:1 bits["bitset (⌈count/8⌉ bytes)"]:4 + + style hdr fill:#ff6b6b,color:#fff + style pid fill:#ffd93d,color:#333 + style pcnt fill:#4d96ff,color:#fff + style bits fill:#6bcb77,color:#fff +``` -## 
Fragmentation & Reassembly +- **packet_id** — which packet group this ACK refers to +- **packet_count** — total fragments in the group (for validation) +- **bitset** — one bit per fragment (LSB first). + Bit set means the corresponding fragment has been received. -UDP datagrams have a practical size limit (the network MTU). -Large messages must be split across multiple packets. +## Fragmentation and reassembly -### MTU Calculation +Because NUClearNet uses UDP, each packet must fit within a single network datagram. +UDP datagrams have a practical size limit — the network's MTU. +Messages larger than this limit are automatically split into fragments, +each sent as a separate UDP datagram. +The receiver reassembles the fragments back into the complete message. + +### MTU calculation ``` -fragment_size = network_mtu - IP_header(40) - UDP_header(8) - DataPacket_header +fragment_size = network_mtu - IP_header(20/40) - UDP_header(8) - DataPacket_header(20) ``` -With a typical 1500-byte MTU, this gives roughly **1441 bytes per fragment** (accounting for the DataPacket fields). +With a typical 1500-byte Ethernet MTU this gives approximately **1452 bytes per fragment** for IPv4. -### Sending Large Messages +### Sending large messages ```mermaid flowchart LR MSG["Message (5000 bytes)"] --> SPLIT[Split into fragments] - SPLIT --> F1["Fragment 0 (1441B)"] - SPLIT --> F2["Fragment 1 (1441B)"] - SPLIT --> F3["Fragment 2 (1441B)"] - SPLIT --> F4["Fragment 3 (677B)"] + SPLIT --> F1["Fragment 0 (1452B)"] + SPLIT --> F2["Fragment 1 (1452B)"] + SPLIT --> F3["Fragment 2 (1452B)"] + SPLIT --> F4["Fragment 3 (644B)"] F1 --> UDP1[UDP Datagram] F2 --> UDP2[UDP Datagram] F3 --> UDP3[UDP Datagram] F4 --> UDP4[UDP Datagram] ``` -### Reassembly on the Receiver +### Reassembly on the receiver The receiver collects fragments keyed by `(source_address, packet_id)`. -Once all `packet_count` fragments arrive, the original message is reassembled and delivered. 
+Once all `packet_count` fragments have arrived, the original message is reassembled and delivered. + +**Assembly timeout:** +If an incomplete message hasn't received new fragments within the peer timeout (default 2 seconds), it's discarded. +This matches the peer liveness timeout — if no fragments have arrived in this period, +either the peer is dead (and will be removed) or the sender has moved on (unreliable message). +For reliable messages, the sender's retransmissions will keep refreshing the assembly's timestamp, +so the assembly will not expire while the sender is still alive and retransmitting. -- **Stale assemblies**: If an incomplete message hasn't received new fragments in `10 × RTT` (round-trip time to that peer), it's discarded. - This prevents memory leaks from lost unreliable packets. +**Maximum assembly size:** +A configurable limit (default 64 MB) prevents memory exhaustion from maliciously large messages. +If a message's total size would exceed this limit, the assembly is rejected. -## Reliable Delivery +## Reliable delivery -By default, NUClearNet is **unreliable** — packets are fire-and-forget, just like raw UDP. -But when you need guaranteed delivery, the reliable mode adds ACK-based retransmission. +Fragmentation solves the message *size* problem, but UDP itself provides no delivery guarantees. +Packets can be lost, reordered, or duplicated by the network. +For many robotics use cases (sensor streams, video frames), this is fine — a missing update is quickly superseded by the next one. +But some messages *must* arrive: configuration commands, state transitions, calibration data. -### Unreliable (Default) +NUClearNet's reliable delivery mode adds ACK-based retransmission on top of the same fragmented UDP transport. +When you send a message reliably, the system tracks which fragments have been acknowledged by the receiver and retransmits any that go missing. 
+ +### Unreliable (default) - Send and forget - No ACKs, no retransmission - Fastest possible — zero overhead - Fine for high-frequency data where missing one update doesn't matter (sensor streams, video frames) -### Reliable Mode +### Reliable mode + +When reliable delivery is requested, the sender and receiver engage in a conversation to ensure all fragments arrive: ```mermaid sequenceDiagram @@ -236,41 +590,135 @@ sequenceDiagram S->>R: DataPacket (id=42, no=2, reliable=true) Note over R: Received 0 and 2, missing 1 - R->>S: ACK (id=42, no=2, bitset=[1,0,1]) + R->>S: ACK (id=42, bitset=[1,0,1]) Note over S: Sees fragment 1 not ACKed - Note over S: Wait RTT timeout... - S->>R: DataRetransmission (id=42, no=1) + Note over S: Wait RTO timeout... + S->>R: DataPacket (id=42, no=1, retransmit) - R->>S: ACK (id=42, no=1, bitset=[1,1,1]) + R->>S: ACK (id=42, bitset=[1,1,1]) Note over R: All fragments received, deliver message ``` Key mechanisms: -- **ACK per fragment** — when the receiver gets a fragment, it responds with an ACK that includes a bitset of *all* received fragments for that packet_id. - This gives the sender full visibility. -- **RTT-based retransmission** — the sender waits one estimated RTT before retransmitting un-ACKed fragments. - Retransmitting too early wastes bandwidth; too late adds latency. -- **Adaptive RTT estimation** — each peer's round-trip time is tracked using a Kalman filter. - This adapts to changing network conditions smoothly. -- **NACK support** — the receiver can proactively request specific missing fragments via NACK packets. -- **Duplicate detection** — a circular buffer of recent `packet_id` values prevents processing the same message twice. +- **Bitset ACK** — when the receiver gets a fragment, + it responds with an ACK containing a bitset of *all* received fragments for that packet group. + This gives the sender full visibility into what's been received. 
+- **RTO-based retransmission** — the sender waits one RTO (Retransmission Timeout) before retransmitting un-ACKed fragments. + The RTO is calculated per peer based on measured round-trip times (see below). +- **No retransmission limit** — reliable packets are retransmitted indefinitely until either all fragments are ACKed, + or the peer is removed (due to timeout or graceful leave). + This guarantees delivery as long as the connection remains alive. -### RTT Estimation (Kalman Filter) +### How long to wait before retransmitting (RTT estimation) -Rather than using a simple moving average, NUClearNet uses a single-state Kalman filter per peer: +The key challenge in retransmission is choosing *when* to retransmit. +Too soon, and you waste bandwidth resending packets that were simply delayed. +Too late, and the receiver is left waiting for data that was lost. + +The answer depends on how long packets actually take to travel between two specific peers — the round-trip time. +NUClearNet measures this per peer by timing how long it takes between sending a fragment and receiving its ACK. +This measurement is then smoothed using the Jacobson/Karels algorithm (the same approach TCP uses, defined in RFC 6298): ``` -K = (P + Q) / (P + Q + R) // Kalman gain -P = R * (P + Q) / (R + P + Q) // Update variance -X = X + (measurement - X) * K // Update estimate +RTTVAR = (1 - β) × RTTVAR + β × |SRTT - sample| +SRTT = (1 - α) × SRTT + α × sample +RTO = SRTT + 4 × RTTVAR ``` -Where `Q` is process noise (how much RTT might change), `R` is measurement noise (how noisy individual measurements are), and `X` is the current RTT estimate. -This gives smooth, responsive RTT tracking. 
+Where: + +- `α = 0.125` — smoothing factor for RTT (standard TCP value) +- `β = 0.25` — smoothing factor for RTT variation (standard TCP value) +- `SRTT` — smoothed RTT estimate +- `RTTVAR` — RTT variation (jitter) +- `RTO` — retransmission timeout (clamped between 100 ms and 60 s) -## Type Routing +The RTO is the actual wait time before retransmitting. +It's set slightly above the smoothed RTT (plus a jitter margin) so that under normal conditions, +the ACK arrives just before the timeout fires. +If the network gets congested and round-trip times increase, the RTO automatically grows to compensate. +If the network recovers, the RTO shrinks back down. + +This means retransmission timing is always appropriate for the current link conditions between each specific pair of peers, +rather than relying on a fixed timeout that would be too aggressive for slow links or too conservative for fast ones. + +## Subscription-based routing + +With fragmentation, reliability, and deduplication handling the *transport* of messages, +the final piece is deciding *which peers* should receive each message. +Rather than broadcasting everything to all peers (which wastes bandwidth), +nodes advertise which message types they want to receive via subscription hashes in their announce packets. +This allows senders to skip transmitting messages to peers that aren't interested in them. + +```mermaid +sequenceDiagram + participant A as Node A + participant B as Node B (subscribed to SensorData) + participant C as Node C (subscribed to Commands) + + Note over A: Sending SensorData (hash=0x1234) + A->>B: DataPacket (hash=0x1234) + Note over A: Node C not subscribed, skip + + Note over A: Sending Command (hash=0x5678) + A->>C: DataPacket (hash=0x5678) + Note over A: Node B not subscribed, skip +``` + +**Default behavior:** +If a peer advertises an empty subscription set (no hashes), it receives *all* messages. 
+This ensures backward compatibility and supports "gateway" nodes that need to see everything. + +When a local `on<Network<T>>` reaction is registered, +the `NetworkController` adds the corresponding type hash to this node's subscription list and re-announces with the updated subscriptions. + +### Broadcast delivery via multicast + +When a message is sent without a specific target (broadcast to all peers) and does not require reliable delivery, +it is sent once to the multicast/broadcast group rather than unicast to each peer individually. +This is significantly more efficient when there are many peers: + +```mermaid +sequenceDiagram + participant A as Node A + participant B as Node B + participant C as Node C + + Note over A: Unreliable broadcast (empty target) + A-->>B: DataPacket (hash=0x1234) [multicast] + A-->>C: DataPacket (hash=0x1234) [multicast] + Note over A: Single send — network delivers to all + + Note over B: Subscribed → accept + Note over C: Not subscribed → discard +``` + +Each receiver checks its local subscription list and discards messages it is not interested in. +This filtering happens before fragmentation reassembly to avoid wasted work. + +Reliable sends and targeted sends (to a specific named peer) are always unicast, +because ACK tracking and retransmission require per-peer communication. + +## Packet deduplication + +Reliable delivery creates a new problem: duplicate packets. +When a sender retransmits a fragment because the ACK was lost (not the fragment itself), +the receiver may process the same fragment twice. +Similarly, network anomalies can cause any UDP packet to arrive more than once. + +To handle this, each peer has an associated `PacketDeduplicator` — a sliding-window bitset that tracks the last 256 packet IDs seen from that peer. +When a data packet arrives: + +1. If the packet ID falls within the window and is already marked as seen, it's dropped as a duplicate. +1. 
If the packet ID is newer than the window, the window slides forward and the packet is accepted. +1. If the packet ID is older than the entire window (more than 256 behind), it's dropped. + +This handles scenarios like retransmissions arriving after the original was already processed, +or network loops causing packets to appear multiple times. + +## Type routing Messages are identified by a **type hash** rather than string names or channel IDs. @@ -328,10 +776,11 @@ on>().then([](const SensorData& data) { When you use `Network`: -1. At bind time, the reaction's type hash is registered with the `NetworkController` -1. The `NetworkController` maps `hash → reaction` in its internal multimap +1. At bind time, the reaction's type hash is registered with the `NetworkController`. +1. The `NetworkController` adds the hash to its subscription set and re-announces. +1. The hash is mapped to the reaction in an internal multimap. 1. When a packet arrives with that hash, the `NetworkController`: - - Stores the raw bytes in ThreadStore + - Stores the raw bytes in `ThreadStore` - Calls `get_task()` on the matched reactions - The `Network` word's `get()` deserializes the bytes into a `T` @@ -343,12 +792,14 @@ emit(std::make_unique(reading), "target_name", true) This triggers: -1. `emit::Network` serializes the data and computes the type hash -1. A `NetworkEmit` message is emitted locally -1. `NetworkController` catches it and calls `NUClearNetwork::send(hash, payload, target, reliable)` -1. The network engine fragments and transmits the packet +1. `emit::Network` serializes the data and computes the type hash. +1. A `NetworkEmit` message is emitted locally. +1. `NetworkController` catches it and calls `NUClearNet::send(hash, payload, target, reliable)`. +1. The network engine checks which peers subscribe to the hash via the `Routing` module. +1. For each eligible peer, the message is fragmented and transmitted. +1. 
If reliable, the `Reliability` module tracks the packet group for ACK/retransmission. -### Peer Lifecycle Events +### Peer lifecycle events ```cpp on>().then([](const NetworkJoin& event) { @@ -360,9 +811,9 @@ on>().then([](const NetworkLeave& event) { }); ``` -These are emitted by the `NetworkController` when its join/leave callbacks fire from the network engine. +These are emitted by the `NetworkController` when its join/leave callbacks fire from the `Discovery` module. -## Data Transmission Flow +## Data transmission flow ```mermaid sequenceDiagram @@ -370,19 +821,17 @@ sequenceDiagram participant Net as UDP Network participant Receiver as Receiver Node - Note over Sender: Serialise data + compute type hash + Note over Sender: Serialize data + compute type hash + Sender->>Sender: Check routing (peer subscriptions) Sender->>Sender: Fragment if larger than MTU Sender->>Net: UDP DataPacket(s) Net->>Receiver: UDP DataPacket(s) + Receiver->>Receiver: Deduplicate (sliding window) Receiver->>Receiver: Reassemble fragments Note over Receiver: Look up reactions by type hash - Note over Receiver: Deserialise → callback executes + Note over Receiver: Deserialize → callback executes ``` -In more detail, the sender side proceeds as: `emit(data)` → serialise → compute type hash → emit a local `NetworkEmit` message → `NetworkController` calls `NUClearNetwork::send()` → fragment and transmit via UDP. - -On the receiver side: `NUClearNetwork` reassembles fragments → calls `packet_callback` on `NetworkController` → looks up reactions by hash → creates tasks → `Network::get()` deserialises → callback runs. 
- ## Configuration The network is configured by emitting a `NetworkConfiguration` message: @@ -395,5 +844,24 @@ emit(std::make_unique( )); ``` +### Configuration fields + +| Field | Type | Default | Description | +| ------------------ | ---------- | ------------------- | ----------------------------------------------- | +| `name` | `string` | — | Unique name for this node on the network | +| `announce_address` | `string` | `"239.226.152.162"` | Address for node discovery announcements | +| `announce_port` | `uint16_t` | `7447` | Port for announce messages | +| `bind_address` | `string` | `""` (all) | Local interface to bind to | +| `mtu` | `uint16_t` | `1500` | Maximum transmission unit (fragments if larger) | + When a new configuration is received, the `NetworkController` tears down existing sockets and reinitializes with the new settings. The node name becomes the identifier that other peers see in `NetworkJoin` events. + +### Internal engine parameters + +The `NUClearNet` engine supports additional parameters beyond what's exposed through `NetworkConfiguration`: + +| Parameter | Default | Description | +| -------------- | --------- | -------------------------------------------------- | +| `peer_timeout` | 2 seconds | How long without a packet before a peer is removed | +| `max_assembly_size` | 64 MB | Maximum reassembled message size (prevents memory bombs) | diff --git a/docs/how-to/networking.md b/docs/how-to/networking.md index 98ca2cd0..30af976d 100644 --- a/docs/how-to/networking.md +++ b/docs/how-to/networking.md @@ -48,17 +48,17 @@ public: }; ``` -### NetworkConfiguration Fields +### NetworkConfiguration fields -| Field | Type | Default | Description | -| ------------------ | ---------- | ---------- | ----------------------------------------------- | -| `name` | `string` | — | Unique name for this node on the network | -| `announce_address` | `string` | — | Address for node discovery announcements | -| `announce_port` | `uint16_t` | — | Port for
announce messages | -| `bind_address` | `string` | `""` (all) | Local interface to bind to | -| `mtu` | `uint16_t` | `1500` | Maximum transmission unit (fragments if larger) | +| Field | Type | Default | Description | +| ------------------ | ---------- | ------------------- | ----------------------------------------------- | +| `name` | `string` | — | Unique name for this node on the network | +| `announce_address` | `string` | `"239.226.152.162"` | Address for node discovery announcements | +| `announce_port` | `uint16_t` | `7447` | Port for announce messages | +| `bind_address` | `string` | `""` (all) | Local interface to bind to | +| `mtu` | `uint16_t` | `1500` | Maximum transmission unit (fragments if larger) | -### Network Modes +### Network modes NUClearNet supports several discovery modes depending on the `announce_address` you configure: @@ -226,12 +226,12 @@ public: }; ``` -## Reliable vs Unreliable Delivery +## Reliable vs unreliable delivery -| Mode | Behavior | Use when | -| ---------- | ---------------------------------------------------- | -------------------------------- | -| Unreliable | Fire-and-forget. No retransmission. Lowest latency. | Streaming data, periodic updates | -| Reliable | Retransmits until acknowledged. Delivery guaranteed. | Commands, configuration, events | +| Mode | Behavior | Use when | +| ---------- | --------------------------------------------------------------------------------- | -------------------------------- | +| Unreliable | Fire-and-forget. No retransmission. Lowest latency. | Streaming data, periodic updates | +| Reliable | Retransmits until acknowledged (ACK bitset). Uses Jacobson/Karels RTT estimation. | Commands, configuration, events | Pass `true` as the reliability argument to `emit`: @@ -243,7 +243,7 @@ emit(std::make_unique(cmd)); emit(std::make_unique(cmd), true); ``` -## Serialization Requirements +## Serialization requirements Types sent over the network must be serializable. 
NUClear handles this automatically for **trivially copyable** types (POD structs with no pointers or dynamic memory). @@ -251,3 +251,12 @@ NUClear handles this automatically for **trivially copyable** types (POD structs For complex types, specialize `NUClear::util::serialise::Serialise` to provide custom `serialise()`, `deserialise()`, and `hash()` methods. Type safety across nodes is ensured by hash matching — if a type's hash doesn't match between sender and receiver, the message is silently discarded. + +## Subscription-based routing + +NUClearNet automatically advertises which message types your node is interested in. +When you register an `on<Network<T>>` reaction, the type hash is added to your node's subscription set and announced to peers. +Peers will only send messages to your node if you are subscribed to that message type. + +If a node has no subscriptions (no `on<Network<T>>` reactions), it receives all messages by default. +This is useful for debugging or gateway nodes that need to observe all traffic. diff --git a/docs/reference/dsl/network.md b/docs/reference/dsl/network.md index 7a1f93cf..be05190f 100644 --- a/docs/reference/dsl/network.md +++ b/docs/reference/dsl/network.md @@ -41,6 +41,8 @@ sequenceDiagram ``` **Bind phase:** Emits a `NetworkListen` message with the type hash of `T` to register interest with the `NetworkController`. +The hash is also added to this node's subscription set, which is advertised to peers via announce packets. +Peers use this subscription information to avoid sending messages that no local reaction is listening for. **Get phase:** Deserializes the message from `ThreadStore` data populated by `NetworkController`, using `Serialise::deserialise()`. @@ -90,6 +92,8 @@ on>().then([](const NetworkSource& src, const SensorReadi - Only reacts to messages received over the network, never to local emits. - The type hash is computed from the type name string — renaming a type breaks compatibility with peers using the old name.
- Multiple nodes can listen for the same type simultaneously. +- Registering a `Network` reaction causes this node to advertise the type hash as a subscription, + enabling subscription-based routing so peers only send relevant messages. ## See Also diff --git a/docs/reference/emit/network.md b/docs/reference/emit/network.md index c4ff879d..74b8d0e7 100644 --- a/docs/reference/emit/network.md +++ b/docs/reference/emit/network.md @@ -82,9 +82,12 @@ public: - Requires `NetworkConfiguration` to be emitted for the network to be active. - The type must be serializable: either trivially copyable, or provide a `util::serialise::Serialise` specialization. - Type routing uses a hash — the same type must be defined on both peers. -- If `reliable` is true, delivery is guaranteed (TCP-like semantics). - If false, packets may be lost (UDP-like). +- If `reliable` is true, the message uses ACK-based retransmission with Jacobson/Karels RTO estimation. + Retransmissions continue indefinitely until the peer acknowledges or disconnects. + If false, packets are fire-and-forget (UDP-like). - If the target peer is not connected, the message is silently dropped even with `reliable = true`. +- Messages are only sent to peers that have subscribed to the type hash (subscription-based routing). + Peers with no subscriptions receive all messages by default. 
## See Also diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7ed0fb8a..07bab21f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -28,7 +28,7 @@ configure_file(nuclear.in ${PROJECT_BINARY_DIR}/nuclear) # Build the library find_package(Threads REQUIRED) -file(GLOB_RECURSE src "*.c" "*.cpp" "*.hpp" "*.ipp") +file(GLOB_RECURSE src CONFIGURE_DEPENDS "*.c" "*.cpp" "*.hpp" "*.ipp") add_library(nuclear STATIC ${src}) add_library(NUClear::nuclear ALIAS nuclear) diff --git a/src/extension/NetworkController.cpp b/src/extension/NetworkController.cpp index 56158ee5..07f9a859 100644 --- a/src/extension/NetworkController.cpp +++ b/src/extension/NetworkController.cpp @@ -27,6 +27,8 @@ #include #include #include +#include +#include #include #include @@ -37,6 +39,8 @@ #include "../dsl/word/emit/Network.hpp" #include "../message/NetworkConfiguration.hpp" #include "../message/NetworkEvent.hpp" +#include "../nuclearnet/Discovery.hpp" +#include "../nuclearnet/NUClearNet.hpp" #include "../util/get_hostname.hpp" namespace NUClear { @@ -52,12 +56,13 @@ namespace extension { : Reactor(std::move(environment)) { // Set our function callback - network.set_packet_callback([this](const network::NUClearNetwork::NetworkTarget& remote, - const uint64_t& hash, - const bool& reliable, - std::vector&& payload) { + net.set_packet_callback([this](const network::NUClearNet::sock_t& source, + const std::string& peer_name, + uint64_t hash, + bool reliable, + std::vector&& payload) { // Construct our NetworkSource information - const dsl::word::NetworkSource src{remote.name, remote.target, reliable}; + const dsl::word::NetworkSource src{peer_name, source, reliable}; // Move the payload in as we are stealing it const std::vector p(std::move(payload)); @@ -85,23 +90,23 @@ namespace extension { }); // Set our join callback - network.set_join_callback([this](const network::NUClearNetwork::NetworkTarget& remote) { + net.set_join_callback([this](const network::PeerInfo& peer) { auto l = 
std::make_unique(); - l->name = remote.name; - l->address = remote.target; + l->name = peer.name; + l->address = peer.address; emit(l); }); // Set our leave callback - network.set_leave_callback([this](const network::NUClearNetwork::NetworkTarget& remote) { + net.set_leave_callback([this](const network::PeerInfo& peer) { auto l = std::make_unique(); - l->name = remote.name; - l->address = remote.target; + l->name = peer.name; + l->address = peer.address; emit(l); }); // Set our event timer callback - network.set_next_event_callback([this](std::chrono::steady_clock::time_point t) { + net.set_event_callback([this](std::chrono::steady_clock::time_point t) { const std::chrono::steady_clock::duration emit_offset = t - std::chrono::steady_clock::now(); emit(std::make_unique(), std::chrono::duration_cast(emit_offset)); @@ -114,6 +119,9 @@ namespace extension { // Insert our new reaction reactions.insert(std::make_pair(l.hash, l.reaction)); + + // Add subscription so peers know to send us this type + net.add_subscription(l.hash); }); // Stop listening for a network type @@ -128,13 +136,20 @@ namespace extension { if (it != reactions.end()) { reactions.erase(it); } + + // Rebuild subscriptions from remaining reactions + std::set subs; + for (const auto& r : reactions) { + subs.insert(r.first); + } + net.set_subscriptions(subs); }); - on>().then("Network Emit", [this](const NetworkEmit& emit) { - network.send(emit.hash, emit.payload, emit.target, emit.reliable); + on>().then("Network Emit", [this](const NetworkEmit& e) { + net.send(e.hash, e.payload.data(), e.payload.size(), e.target, e.reliable); }); - on().then("Shutdown Network", [this] { network.shutdown(); }); + on().then("Shutdown Network", [this] { net.shutdown(); }); // Configure the NUClearNetwork options on>().then([this](const NetworkConfiguration& config) { @@ -151,17 +166,32 @@ namespace extension { listen_handles.clear(); } - // Name becomes hostname by default if not set - const std::string name = 
config.name.empty() ? util::get_hostname() : config.name; + // Build configuration + network::NetworkConfig net_config; + net_config.name = config.name.empty() ? util::get_hostname() : config.name; + net_config.announce_address = config.announce_address; + net_config.announce_port = config.announce_port; + net_config.bind_address = config.bind_address; + net_config.mtu = config.mtu; + + // Collect current subscriptions + { + const std::lock_guard lock(reaction_mutex); + std::set subs; + for (const auto& r : reactions) { + subs.insert(r.first); + } + net.set_subscriptions(subs); + } // Reset our network using this configuration - network.reset(name, config.announce_address, config.announce_port, config.bind_address, config.mtu); + net.reset(net_config); // Execution handle - process_handle = on>().then("Network processing", [this] { network.process(); }); + process_handle = on>().then("Network processing", [this] { net.process(); }); - for (auto& fd : network.listen_fds()) { - listen_handles.push_back(on(fd, IO::READ).then("Packet", [this] { network.process(); })); + for (auto& fd : net.listen_fds()) { + listen_handles.push_back(on(fd, IO::READ).then("Packet", [this] { net.process(); })); } }); } diff --git a/src/extension/NetworkController.hpp b/src/extension/NetworkController.hpp index 1452f4b6..94a6e9e6 100644 --- a/src/extension/NetworkController.hpp +++ b/src/extension/NetworkController.hpp @@ -30,8 +30,8 @@ #include "../PowerPlant.hpp" #include "../Reactor.hpp" #include "../message/NetworkConfiguration.hpp" +#include "../nuclearnet/NUClearNet.hpp" #include "../util/get_hostname.hpp" -#include "network/NUClearNetwork.hpp" namespace NUClear { namespace extension { @@ -42,8 +42,8 @@ namespace extension { explicit NetworkController(std::unique_ptr environment); private: - /// Our NUClearNetwork object that handles the networking - network::NUClearNetwork network; + /// Our NUClearNet object that handles the networking + network::NUClearNet net; /// The reaction 
that handles timed events from the network ReactionHandle process_handle; diff --git a/src/extension/network/NUClearNetwork.cpp b/src/extension/network/NUClearNetwork.cpp deleted file mode 100644 index 11839da9..00000000 --- a/src/extension/network/NUClearNetwork.cpp +++ /dev/null @@ -1,1135 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2017 NUClear Contributors - * - * This file is part of the NUClear codebase. - * See https://github.com/Fastcode/NUClear for further info. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated - * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to - * permit persons to whom the Software is furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the - * Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE - * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR - * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- */ - -#include "NUClearNetwork.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../util/network/if_number_from_address.hpp" -#include "../../util/network/resolve.hpp" -#include "../../util/network/sock_t.hpp" -#include "../../util/platform.hpp" -#include "wire_protocol.hpp" - -namespace NUClear { -namespace extension { - namespace network { - - namespace { // Anonymous namespace for internal linkage - - /** - * Read a single packet from the given udp file descriptor. - * - * @param fd The file descriptor to read from - * - * @return The data and who it was sent from - */ - std::pair> read_socket(fd_t fd) { - - // Allocate a vector that can hold a datagram - std::vector payload(1500); - iovec iov{}; - iov.iov_base = reinterpret_cast(payload.data()); - iov.iov_len = static_cast(payload.size()); - - // Who we are receiving from - util::network::sock_t from{}; - - // Setup our message header to receive - msghdr mh{}; - mh.msg_name = &from.sock; - mh.msg_namelen = sizeof(from); - mh.msg_iov = &iov; - mh.msg_iovlen = 1; - - // Now read the data for real - const ssize_t received = recvmsg(fd, &mh, 0); - payload.resize(received); - - return {from, std::move(payload)}; - } - - } // namespace - - NUClearNetwork::PacketQueue::PacketTarget::PacketTarget(std::weak_ptr target, - std::vector acked) - : target(std::move(target)), acked(std::move(acked)), last_send(std::chrono::steady_clock::now()) {} - - NUClearNetwork::PacketQueue::PacketQueue() = default; - - NUClearNetwork::~NUClearNetwork() { - shutdown(); - } - - void NUClearNetwork::set_packet_callback( - std::function&&)> f) { - packet_callback = std::move(f); - } - - - void NUClearNetwork::set_join_callback(std::function f) { - join_callback = std::move(f); - } - - - void NUClearNetwork::set_leave_callback(std::function f) { - leave_callback = std::move(f); - } - - void 
NUClearNetwork::set_next_event_callback(std::function f) { - next_event_callback = std::move(f); - } - - std::array NUClearNetwork::udp_key(const sock_t& address) { - - // Get our keys for our maps, it will be the ip and then port - std::array key = {0}; - - switch (address.sock.sa_family) { - case AF_INET: - // The first chars are 0 (ipv6) and after that is our address and then port - std::memcpy(&key[6], &address.ipv4.sin_addr, sizeof(address.ipv4.sin_addr)); - key[8] = address.ipv4.sin_port; - break; - - case AF_INET6: - // IPv6 address then port - std::memcpy(key.data(), &address.ipv6.sin6_addr, sizeof(address.ipv6.sin6_addr)); - key[8] = address.ipv6.sin6_port; - break; - - default: throw std::invalid_argument("Unknown address family"); - } - - return key; - } - - - void NUClearNetwork::remove_target(const std::shared_ptr& target) { - - // Erase udp - auto key = udp_key(target->target); - if (udp_target.find(key) != udp_target.end()) { - udp_target.erase(udp_target.find(key)); - } - - // Erase name - auto range = name_target.equal_range(target->name); - for (auto it = range.first; it != range.second; ++it) { - if (it->second == target) { - name_target.erase(it); - break; - } - } - - // Erase target - auto t = std::find(targets.begin(), targets.end(), target); - if (t != targets.end()) { - targets.erase(t); - } - } - - - void NUClearNetwork::open_data(const sock_t& bind_address) { - - // Create the "join any" address for this address family - sock_t address = bind_address; - - // Set port to 0 to get an ephemeral port - if (address.sock.sa_family == AF_INET) { - address.ipv4.sin_port = 0; - } - // IPv6 - else if (address.sock.sa_family == AF_INET6) { - address.ipv6.sin6_port = 0; - } - - // Open a socket with the same family as our announce target - data_fd = ::socket(address.sock.sa_family, SOCK_DGRAM, IPPROTO_UDP); - if (data_fd < 0) { - throw std::system_error(network_errno, std::system_category(), "Unable to open the UDP socket"); - } - - // Set broadcast 
so we can send to broadcast addresses if needed - int yes = 1; - if (::setsockopt(data_fd, SOL_SOCKET, SO_BROADCAST, reinterpret_cast(&yes), sizeof(yes)) < 0) { - throw std::system_error(network_errno, std::system_category(), "Unable to set broadcast on the socket"); - } - - // Bind to the address, and if we fail throw an error - if (::bind(data_fd, &address.sock, address.size()) != 0) { - throw std::system_error(network_errno, - std::system_category(), - "Unable to bind the UDP socket to the port"); - } - } - - - void NUClearNetwork::open_announce(const sock_t& announce_target, const sock_t& bind_address) { - - // Work out what type of announce we are doing as it will influence how we make the socket - const bool multicast = - (announce_target.sock.sa_family == AF_INET - && (ntohl(announce_target.ipv4.sin_addr.s_addr) & 0xF0000000U) == 0xE0000000) - || (announce_target.sock.sa_family == AF_INET6 && announce_target.ipv6.sin6_addr.s6_addr[0] == 0xFF); - - // Make our socket - announce_fd = ::socket(bind_address.sock.sa_family, SOCK_DGRAM, IPPROTO_UDP); - if (announce_fd < 0) { - throw std::system_error(network_errno, std::system_category(), "Unable to open the UDP socket"); - } - - // Set that we reuse the address so more than one application can bind (this applies for unicast as well) - int yes = 1; - if (::setsockopt(announce_fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), sizeof(yes)) < 0) { - throw std::system_error(network_errno, std::system_category(), "Unable to reuse address on the socket"); - } - -// If SO_REUSEPORT is available set it too -#ifdef SO_REUSEPORT - if (::setsockopt(announce_fd, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), sizeof(yes)) < 0) { - throw std::system_error(network_errno, std::system_category(), "Unable to reuse port on the socket"); - } -#endif - - // We enable SO_BROADCAST since sometimes we need to send broadcast packets - if (::setsockopt(announce_fd, SOL_SOCKET, SO_BROADCAST, reinterpret_cast(&yes), sizeof(yes)) < 0) 
{ - throw std::system_error(network_errno, std::system_category(), "Unable to set broadcast on the socket"); - } - - // Bind to the address - if (::bind(announce_fd, &bind_address.sock, bind_address.size()) != 0) { - throw std::system_error(network_errno, std::system_category(), "Unable to bind the UDP socket"); - } - - // If we have a multicast address, then we need to join the multicast groups - if (multicast) { - - // Our multicast join request will depend on protocol version - if (announce_target.sock.sa_family == AF_INET) { - - // Set the multicast address we are listening on and bind address - ip_mreq mreq{}; - mreq.imr_multiaddr = announce_target.ipv4.sin_addr; - mreq.imr_interface = bind_address.ipv4.sin_addr; - - // Join our multicast group - if (::setsockopt(announce_fd, - IPPROTO_IP, - IP_ADD_MEMBERSHIP, - reinterpret_cast(&mreq), - sizeof(ip_mreq)) - < 0) { - throw std::system_error(network_errno, - std::system_category(), - "There was an error while attempting to join the multicast group"); - } - - // Set our transmission interface for the multicast socket - if (::setsockopt(announce_fd, - IPPROTO_IP, - IP_MULTICAST_IF, - reinterpret_cast(&bind_address.ipv4.sin_addr), - sizeof(bind_address.ipv4.sin_addr)) - < 0) { - throw std::system_error(network_errno, - std::system_category(), - "Unable to use the requested interface for multicast"); - } - } - else if (announce_target.sock.sa_family == AF_INET6) { - - // Set the multicast address we are listening on - ipv6_mreq mreq{}; - mreq.ipv6mr_multiaddr = announce_target.ipv6.sin6_addr; - mreq.ipv6mr_interface = util::network::if_number_from_address(bind_address.ipv6); - - // Join our multicast group - if (::setsockopt(announce_fd, - IPPROTO_IPV6, - IPV6_JOIN_GROUP, - reinterpret_cast(&mreq), - sizeof(ipv6_mreq)) - < 0) { - throw std::system_error(network_errno, - std::system_category(), - "There was an error while attempting to join the multicast group"); - } - - // Set our transmission interface for the 
multicast socket - if (::setsockopt(announce_fd, - IPPROTO_IPV6, - IPV6_MULTICAST_IF, - reinterpret_cast(&mreq.ipv6mr_interface), - sizeof(mreq.ipv6mr_interface)) - < 0) { - throw std::system_error(network_errno, - std::system_category(), - "Unable to use the requested interface for multicast"); - } - } - } - } - - void NUClearNetwork::shutdown() { - - // If we have an fd, send a shutdown message - if (data_fd > 0) { - // Make a leave packet from our announce packet - LeavePacket packet; - - auto announce_targets = name_target.equal_range(""); - for (auto it = announce_targets.first; it != announce_targets.second; ++it) { - - // Send the packet - ::sendto(data_fd, - reinterpret_cast(&packet), - sizeof(packet), - 0, - &it->second->target.sock, - it->second->target.size()); - } - } - - // Close our existing FDs if they exist - if (data_fd > 0) { - close(data_fd); - data_fd = INVALID_SOCKET; - } - if (announce_fd > 0) { - close(announce_fd); - announce_fd = INVALID_SOCKET; - } - } - - void NUClearNetwork::reset(const std::string& name, - const std::string& address, - in_port_t port, - const std::string& bind_address, - uint16_t network_mtu) { - - // Close our existing FDs if they exist - shutdown(); - - // Lock all mutexes - std::lock(target_mutex, send_queue_mutex); - const std::lock_guard target_lock(target_mutex, std::adopt_lock); - const std::lock_guard send_lock(send_queue_mutex, std::adopt_lock); - - // Clear all our data structures - send_queue.clear(); - name_target.clear(); - targets.clear(); - udp_target.clear(); - - // Resolve the announce address and port into a sockaddr - const util::network::sock_t announce_target = util::network::resolve(address, port); - - // If we have a bind address, resolve it otherwise use the announce address family with any address - sock_t bind_target{}; - if (bind_address.empty()) { - bind_target = announce_target; - if (announce_target.sock.sa_family == AF_INET) { - bind_target.ipv4.sin_addr.s_addr = htonl(INADDR_ANY); - } - 
else if (announce_target.sock.sa_family == AF_INET6) { - bind_target.ipv6.sin6_addr = IN6ADDR_ANY_INIT; - } - else { - throw std::invalid_argument("Unknown address family"); - } - } - else { - bind_target = util::network::resolve(bind_address, port); - // If the family doesn't match, throw an error - if (bind_target.sock.sa_family != announce_target.sock.sa_family) { - throw std::invalid_argument("Bind address family does not match announce address family"); - } - } - - // Add the target for our multicast packets - auto all_target = std::make_shared("", announce_target); - targets.push_front(all_target); - name_target.insert(std::make_pair("", all_target)); - udp_target.insert(std::make_pair(udp_key(announce_target), all_target)); - - // Work out our MTU for udp packets - packet_data_mtu = network_mtu; // Start with the total mtu - packet_data_mtu -= sizeof(DataPacket) - 1; // Now remove data packet header size - // IPv6 headers are always 40 bytes, and IPv4 can be 20-60 but if we assume 40 for all cases it should - // be safe enough - packet_data_mtu -= 40; // Remove size of an IPv4 header or IPv6 header - packet_data_mtu -= 8; // Size of a UDP packet header - - // Build our announce packet - announce_packet.resize(sizeof(AnnouncePacket) + name.size(), 0); - AnnouncePacket& pkt = *reinterpret_cast(announce_packet.data()); - pkt = AnnouncePacket(); - std::memcpy(&pkt.name, name.c_str(), name.size()); - - // Open the data and announce sockets - open_data(bind_target); - open_announce(announce_target, bind_target); - } - - void NUClearNetwork::reset(const std::string& name, - const std::string& address, - in_port_t port, - uint16_t network_mtu) { - reset(name, address, port, "", network_mtu); - } - - void NUClearNetwork::process() { - - // Record the time - auto now = std::chrono::steady_clock::now(); - - // Check if we should announce now - if (now - last_announce > std::chrono::milliseconds(500)) { - last_announce = now; - announce(); - - // Update our event timer 
- auto next_announce = now + std::chrono::milliseconds(500); - if (next_announce > next_event) { - next_event = next_announce; - - // Let the system know when we need attention again - next_event_callback(next_event); - } - } - - // We need to make this list outside mutex scope in case the callback needs the mutex - std::vector> leavers; - - // Check if any of our existing connections have timed out - /* Mutex Scope */ { - const std::lock_guard lock(target_mutex); - - // Always skip the first element since it's the "all" target - for (auto it = std::next(targets.begin(), 1); it != targets.end();) { - - auto ptr = *it; - ++it; - - if (now - ptr->last_update > std::chrono::seconds(2)) { - - // Remove this, it timed out - leavers.push_back(ptr); - remove_target(ptr); - } - } - } - - // Run the callback for anyone that left - for (const auto& l : leavers) { - leave_callback(*l); - } - - // Check if we have packets to resend and if so resend - if (!send_queue.empty()) { - retransmit(); - } - - // Used for storing how many bytes are available on a socket - unsigned long count = 0; // NOLINT(google-runtime-int) MSVC wants an unsigned long - - // Read packets from the multicast socket while there is data available - ioctl(announce_fd, FIONREAD, &(count = 0)); - while (count > 0) { - auto packet = read_socket(announce_fd); - process_packet(packet.first, std::move(packet.second)); - ioctl(announce_fd, FIONREAD, &(count = 0)); - } - - // Check if we have a packet available on the data socket - ioctl(data_fd, FIONREAD, &(count = 0)); - while (count > 0) { - auto packet = read_socket(data_fd); - process_packet(packet.first, std::move(packet.second)); - ioctl(data_fd, FIONREAD, &(count = 0)); - } - } - - void NUClearNetwork::retransmit() { - - // Locking send_queue_mutex second after target_mutex - std::lock(target_mutex, send_queue_mutex); - const std::lock_guard target_lock(target_mutex, std::adopt_lock); - const std::lock_guard send_lock(send_queue_mutex, std::adopt_lock); - 
- for (auto qit = send_queue.begin(); qit != send_queue.end();) { - for (auto it = qit->second.targets.begin(); it != qit->second.targets.end();) { - - // Get the pointer to our target - auto ptr = it->target.lock(); - - // If our pointer is valid (they haven't disconnected) - if (ptr) { - - auto now = std::chrono::steady_clock::now(); - auto timeout = it->last_send + ptr->round_trip_time; - - // Check if we should have expected an ack by now for some packets - if (timeout < now) { - - // We last sent now - it->last_send = now; - - // The next time we should check for a timeout - auto next_timeout = now + ptr->round_trip_time; - if (next_timeout < next_event) { - next_event = next_timeout; - next_event_callback(next_event); - } - - // Work out which packets to resend and resend them - for (uint16_t i = 0; i < qit->second.header.packet_count; ++i) { - if ((it->acked[i / 8] & uint8_t(1 << (i % 8))) == 0) { - send_packet(ptr->target, qit->second.header, i, qit->second.payload, true); - } - } - } - - ++it; - } - // Remove them from the list - else { - it = qit->second.targets.erase(it); - } - } - - if (qit->second.targets.empty()) { - qit = send_queue.erase(qit); - } - else { - ++qit; - } - } - } - - void NUClearNetwork::announce() { - - // Get all our targets that are global targets - auto announce_targets = name_target.equal_range(""); - for (auto it = announce_targets.first; it != announce_targets.second; ++it) { - - // Send the packet - if (::sendto(data_fd, - reinterpret_cast(announce_packet.data()), - static_cast(announce_packet.size()), - 0, - &it->second->target.sock, - it->second->target.size()) - < 0) { - throw std::system_error(network_errno, - std::system_category(), - "Network error when sending the announce packet"); - } - } - } - - void NUClearNetwork::process_packet(const sock_t& address, std::vector&& payload) { - - // First validate this is a NUClear network packet we can read (a version 2 NUClear packet) - if (payload.size() >= sizeof(PacketHeader) 
&& payload[0] == 0xE2 && payload[1] == 0x98 && payload[2] == 0xA2 - && payload[3] == 0x02) { - - // This is a real packet! get our header information - const PacketHeader& header = *reinterpret_cast(payload.data()); - - // Get the map key for this device - auto key = udp_key(address); - - // From here on, we are doing things with our target lists that if changed would make us sad - std::shared_ptr remote; - /* Mutex scope */ { - const std::lock_guard lock(target_mutex); - auto r = udp_target.find(key); - remote = r == udp_target.end() ? nullptr : r->second; - } - - switch (header.type) { - - // A packet announcing that a user is on the network - case ANNOUNCE: { - // This is an announce packet! - const AnnouncePacket& announce = *reinterpret_cast(payload.data()); - - // They're new! - if (!remote) { - const std::string name(&announce.name, payload.size() - sizeof(AnnouncePacket)); - - // If they sent us an empty name ignore that's reserved for multicast transmissions - if (!name.empty()) { - // Add them into our list - auto ptr = std::make_shared(name, address); - bool new_connection = false; - /* Mutex scope */ { - const std::lock_guard lock(target_mutex); - - // Double check they are new - if (udp_target.count(key) == 0) { - new_connection = true; - targets.push_back(ptr); - udp_target.insert(std::make_pair(key, ptr)); - name_target.insert(std::make_pair(name, ptr)); - - // Say hi back! - ::sendto(data_fd, - reinterpret_cast(announce_packet.data()), - static_cast(announce_packet.size()), - 0, - &ptr->target.sock, - ptr->target.size()); - } - } - - // Only call the callback if it is new - if (new_connection) { - join_callback(*ptr); - } - } - } - // They're old but at least they're not timing out - else { - remote->last_update = std::chrono::steady_clock::now(); - } - } break; - case LEAVE: { - - // Goodbye! 
- if (remote) { - bool left = false; - - // Remove from our list - /* Mutex scope */ { - const std::lock_guard lock(target_mutex); - - // Double check they are gone after locking before removal - if (udp_target.count(key) > 0) { - left = true; - remove_target(remote); - } - } - // Call the callback if they really left - if (left) { - leave_callback(*remote); - } - } - - } break; - - // A packet containing data - case DATA_RETRANSMISSION: - case DATA: { - - // It's a data packet - const DataPacket& packet = *reinterpret_cast(payload.data()); - - // If the packet is obviously corrupt, drop it and since we didn't ack it it'll be resent if - // it's important - if (packet.packet_no > packet.packet_count) { - return; - } - - // Check if we know who this is and if we don't know them, ignore - if (remote) { - - // We got a packet from them recently - remote->last_update = std::chrono::steady_clock::now(); - - // Check if this packet is a retransmission of data - if (header.type == DATA_RETRANSMISSION) { - - // See if we recently processed this packet - // NOLINTNEXTLINE(readability-qualified-auto) MSVC disagrees - auto it = std::find(remote->recent_packets.begin(), - remote->recent_packets.end(), - packet.packet_id); - - // We recently processed this packet, this is just a failed ack - // Send the ack again if it was reliable - if (it != remote->recent_packets.end() && packet.reliable) { - - // Allocate room for the whole ack packet - std::vector r(sizeof(ACKPacket) + (packet.packet_count / 8), 0); - ACKPacket& response = *reinterpret_cast(r.data()); - response = ACKPacket(); - response.packet_id = packet.packet_id; - response.packet_no = packet.packet_no; - response.packet_count = packet.packet_count; - - // Set the bits for all packets (we got the whole thing) - for (int i = 0; i < packet.packet_count; ++i) { - (&response.packets)[i / 8] |= uint8_t(1 << (i % 8)); - } - - // Make who we are sending it to into a useable address - const sock_t& to = remote->target; - - // 
Send the packet - ::sendto(data_fd, - reinterpret_cast(r.data()), - static_cast(r.size()), - 0, - &to.sock, - to.size()); - - // We don't need to process this packet we already did - return; - } - } - - // If this is a solo packet (in a single chunk) - if (packet.packet_count == 1) { - - // Copy our data into a vector - std::vector out(&packet.data, - &packet.data + payload.size() - sizeof(DataPacket) + 1); - - // If this is a reliable packet, send an ack back - if (packet.reliable) { - // This response is easy since there is only one packet - ACKPacket response; - response.packet_id = packet.packet_id; - response.packet_no = packet.packet_no; - response.packet_count = packet.packet_count; - response.packets = 1; - - // Make who we are sending it to into a useable address - const sock_t& to = remote->target; - - ::sendto(data_fd, - reinterpret_cast(&response), - sizeof(response), - 0, - &to.sock, - to.size()); - - // Set this packet to have been recently received - remote->recent_packets[remote->recent_packets_index - .fetch_add(1, std::memory_order_relaxed)] = - packet.packet_id; - } - - packet_callback(*remote, packet.hash, packet.reliable, std::move(out)); - } - else { - const std::lock_guard lock(remote->assemblers_mutex); - - // Grab the payload and put it in our list of assemblers targets - auto& assemblers = remote->assemblers; - - auto& assembler = assemblers[packet.packet_id]; - - // First check that our cache isn't super corrupted by ensuring that our last packet - // in our list isn't after the number of packets we have - if (!assembler.second.empty() - && std::next(assembler.second.end(), -1)->first >= packet.packet_count) { - - // If so, we need to purge our cache and if this was a reliable packet, send a - // NACK back for all the packets we thought we had - // We don't know if we have any packets except the one we just got - if (packet.reliable) { - - // A basic ack has room for 8 packets and we need 1 extra byte for each 8 - // additional packets - 
std::vector r(sizeof(NACKPacket) + (packet.packet_count / 8), 0); - NACKPacket& response = *reinterpret_cast(r.data()); - response = NACKPacket(); - response.packet_id = packet.packet_id; - response.packet_count = packet.packet_count; - - // Set the bits for the packets we thought we received - for (const auto& p : assembler.second) { - (&response.packets)[p.first / 8] |= uint8_t(1 << (p.first % 8)); - } - - // Ensure the bit for this packet isn't NACKed - (&response.packets)[packet.packet_no / 8] &= - ~uint8_t(1 << (packet.packet_no % 8)); - - // Make who we are sending it to into a useable address - const sock_t& to = remote->target; - - // Send the packet - ::sendto(data_fd, - reinterpret_cast(r.data()), - static_cast(r.size()), - 0, - &to.sock, - to.size()); - } - - // Clear our packets here (the one we just got will be added right after this) - assembler.second.clear(); - } - - // Add our packet to our list of assemblers - assembler.first = std::chrono::steady_clock::now(); - assembler.second[packet.packet_no] = std::move(payload); - - // Create and send our ACK packet if this is a reliable transmission - if (packet.reliable) { - // A basic ack has room for 8 packets and we need 1 extra byte for each 8 - // additional packets - std::vector r(sizeof(ACKPacket) + (packet.packet_count / 8), 0); - ACKPacket& response = *reinterpret_cast(r.data()); - response = ACKPacket(); - response.packet_id = packet.packet_id; - response.packet_no = packet.packet_no; - response.packet_count = packet.packet_count; - - // Set the bits for the packets we have received - for (const auto& p : assembler.second) { - (&response.packets)[p.first / 8] |= uint8_t(1 << (p.first % 8)); - } - - // Make who we are sending it to into a useable address - const sock_t& to = remote->target; - - // Send the packet - ::sendto(data_fd, - reinterpret_cast(r.data()), - static_cast(r.size()), - 0, - &to.sock, - to.size()); - } - - // Check to see if we have enough to assemble the whole thing - if 
(assembler.second.size() == packet.packet_count) { - - // Work out exactly how much data we will need first so we only need one - // allocation - size_t payload_size = 0; - for (const auto& p : assembler.second) { - payload_size += p.second.size() - sizeof(DataPacket) + 1; - } - - // Read in our data - std::vector out; - out.reserve(payload_size); - for (auto& p : assembler.second) { - const DataPacket& part = *reinterpret_cast(p.second.data()); - out.insert(out.end(), - &part.data, - &part.data + p.second.size() - sizeof(DataPacket) + 1); - } - - // Send our assembled data packet - packet_callback(*remote, packet.hash, packet.reliable, std::move(out)); - - // If the packet was reliable add that it was recently received - if (packet.reliable) { - // Set this packet to have been recently received - remote->recent_packets[remote->recent_packets_index - .fetch_add(1, std::memory_order_relaxed)] = - packet.packet_id; - } - - // We have completed this packet, discard the data - assemblers.erase(assemblers.find(packet.packet_id)); - } - - // Check for and delete any timed out packets - for (auto it = assemblers.begin(); it != assemblers.end();) { - const auto now = std::chrono::steady_clock::now(); - const auto timeout = remote->round_trip_time * 10.0; - const auto& last_chunk_time = it->second.first; - - it = now > last_chunk_time + timeout ? 
assemblers.erase(it) : std::next(it); - } - } - } - } break; - - // Packet acknowledging the receipt of a packet of data - case ACK: { - - // It's an ack packet - const ACKPacket& packet = *reinterpret_cast(payload.data()); - - // Check if we know who this is and if we don't know them, ignore - if (remote) { - - // We got a packet from them recently - remote->last_update = std::chrono::steady_clock::now(); - - // lock the send queue mutex - const std::lock_guard send_lock(send_queue_mutex); - - // Check for our packet id in the send queue - if (send_queue.count(packet.packet_id) > 0) { - - auto& queue = send_queue[packet.packet_id]; - - // Find this target in the send queue - auto s = std::find_if(queue.targets.begin(), - queue.targets.end(), - [&](const PacketQueue::PacketTarget& target) { - return target.target.lock() == remote; - }); - - // Check for all the ways this ACK could be invalid: - // From an unknown person - if (s != queue.targets.end() - // Wrong packet - && packet.packet_count == queue.header.packet_count - // Truncated packet - && payload.size() == (sizeof(ACKPacket) + (queue.header.packet_count / 8))) { - - // Work out about how long our round trip time is - auto now = std::chrono::steady_clock::now(); - auto round_trip = now - s->last_send; - - // Approximate how long the round trip is to this remote so we can work out how - // long before retransmitting - // We use a baby kalman filter to help smooth out jitter - remote->measure_round_trip(round_trip); - - // Update our acks - bool all_acked = true; - for (unsigned i = 0; i < s->acked.size(); ++i) { - - // Update our bitset - s->acked[i] |= (&packet.packets)[i]; - - // Work out what a "fully acked" packet would look like - const uint8_t expected = i + 1 < s->acked.size() || packet.packet_count % 8 == 0 - ? 
0xFF - : 0xFF >> (8 - (packet.packet_count % 8)); - - all_acked = all_acked && ((s->acked[i] & expected) == expected); - } - - // The remote has received this entire packet we can erase our sender - if (all_acked) { - queue.targets.erase(s); - - // If we're all done remove the whole thing - if (queue.targets.empty()) { - send_queue.erase(packet.packet_id); - } - } - } - } - } - } break; - - // Packet requesting a retransmission of some corrupt data - case NACK: { - // It's a nack packet - const NACKPacket& packet = *reinterpret_cast(payload.data()); - - // Check if we know who this is and if we don't know them, ignore - if (remote) { - - // We got a packet from them recently - remote->last_update = std::chrono::steady_clock::now(); - - // Check for our packet id in the send queue - if (send_queue.count(packet.packet_id) > 0) { - - // Find this packet in our sending queue - auto& queue = send_queue[packet.packet_id]; - - // Find this target in the send queue - auto s = std::find_if(queue.targets.begin(), - queue.targets.end(), - [&](const PacketQueue::PacketTarget& target) { - return target.target.lock() == remote; - }); - - // Validate that the nack is relevant and valid - // We know who it is - if (s != queue.targets.end() - // It's not corrupted - && packet.packet_count == queue.header.packet_count - // It's not truncated - && payload.size() == (sizeof(NACKPacket) + (queue.header.packet_count / 8))) { - - // Store the time as we are now sending new packets - s->last_send = std::chrono::steady_clock::now(); - - // The next time we should check for a timeout - auto next_timeout = s->last_send + remote->round_trip_time; - if (next_timeout < next_event) { - next_event = next_timeout; - next_event_callback(next_event); - } - - // Update our acks with the nacked data - for (unsigned i = 0; i < s->acked.size(); ++i) { - - // Update our bitset - s->acked[i] &= ~(&packet.packets)[i]; - } - - // Now we have to retransmit the nacked packets - for (uint16_t i = 0; i < 
packet.packet_count * 8; ++i) { - - // Check if this packet needs to be sent - const uint8_t bit = 1 << (i % 8); - if (((&packet.packets)[i] & bit) == bit) { - send_packet(remote->target, queue.header, i, queue.payload, true); - } - } - } - } - } - } - } - } - } - - - std::vector NUClearNetwork::listen_fds() { - return std::vector({data_fd, announce_fd}); - } - - void NUClearNetwork::send_packet(const sock_t& target, - NUClear::extension::network::DataPacket header, - uint16_t packet_no, - const std::vector& payload, - const bool& /*reliable*/) { - - // Our packet we are sending - msghdr message{}; - - std::array data{}; - message.msg_iov = data.data(); - message.msg_iovlen = 2; - - // Update our headers packet number and set it in the message - header.packet_no = packet_no; - data[0].iov_base = reinterpret_cast(&header); - data[0].iov_len = sizeof(DataPacket) - 1; - - // Work out what chunk of data we are sending - // const cast is fine as posix guarantees it won't be modified on a sendmsg - const char* start = reinterpret_cast(payload.data()) + (packet_no * packet_data_mtu); - data[1].iov_base = const_cast(start); // NOLINT(cppcoreguidelines-pro-type-const-cast) - data[1].iov_len = packet_no + 1 < header.packet_count ? 
packet_data_mtu : payload.size() % packet_data_mtu; - - // Set our target and send (once again const cast is fine) - message.msg_name = const_cast(&target.sock); // NOLINT(cppcoreguidelines-pro-type-const-cast) - message.msg_namelen = target.size(); - - // TODO(trent): if reliable, run select first to see if this socket is writeable - // If it is not reliable just don't send the message instead of blocking - sendmsg(data_fd, &message, 0); - } - - - void NUClearNetwork::send(const uint64_t& hash, - const std::vector& payload, - const std::string& target, - bool reliable) { - - // If we are not connected throw an error - if (targets.empty()) { - throw std::runtime_error("Cannot send messages as the network is not connected"); - } - - // The header for our packet - DataPacket header; - - /* Mutex Scope */ { - const std::lock_guard lock(send_queue_mutex); - // For the packet id we ensure that it's not currently used for retransmission - while (send_queue.count(++packet_id_source) > 0) { - } - header.packet_id = packet_id_source; - } - - header.packet_no = 0; - header.packet_count = uint16_t((payload.size() / packet_data_mtu) + 1); - header.reliable = reliable; - header.hash = hash; - - // If this was a reliable packet we need to cache it in case it needs to be resent - if (reliable) { - std::lock(target_mutex, send_queue_mutex); - const std::lock_guard lock_target(target_mutex, std::adopt_lock); - const std::lock_guard lock_send(send_queue_mutex, std::adopt_lock); - - auto& queue = send_queue[header.packet_id]; - - // Store the header, but update it's type to be a retransmission so it can be ignored if - // overtransmitted - queue.header = header; - queue.header.type = DATA_RETRANSMISSION; - // TODO(trent): there might be some better memory management that can happen here - queue.payload = payload; - const std::vector acks((header.packet_count / 8) + 1, 0); - - // Find interested parties or if multicast it's everyone we are connected to - auto range = target.empty() ? 
std::make_pair(name_target.begin(), name_target.end()) - : name_target.equal_range(target); - for (auto it = range.first; it != range.second; ++it) { - // If this target is an announce target ignore it - if (!it->first.empty()) { - // Add this guy to the queue - queue.targets.emplace_back(it->second, acks); - - // The next time we should check for a timeout - auto next_timeout = std::chrono::steady_clock::now() + it->second->round_trip_time; - if (next_timeout < next_event) { - next_event = next_timeout; - next_event_callback(next_event); - } - } - } - } - - /* Mutex Scope */ { - const std::lock_guard lock(target_mutex); - - // Now send all our packets to our targets - auto send_to = name_target.equal_range(target); - for (uint16_t i = 0; i < header.packet_count; ++i) { - for (auto s = send_to.first; s != send_to.second; ++s) { - send_packet(s->second->target, header, i, payload, reliable); - } - } - } - } - - } // namespace network -} // namespace extension -} // namespace NUClear diff --git a/src/extension/network/NUClearNetwork.hpp b/src/extension/network/NUClearNetwork.hpp deleted file mode 100644 index e2277af2..00000000 --- a/src/extension/network/NUClearNetwork.hpp +++ /dev/null @@ -1,357 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2017 NUClear Contributors - * - * This file is part of the NUClear codebase. - * See https://github.com/Fastcode/NUClear for further info. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated - * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to - * permit persons to whom the Software is furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the - * Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE - * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR - * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -#ifndef NUCLEAR_EXTENSION_NETWORK_NUCLEAR_NETWORK_HPP -#define NUCLEAR_EXTENSION_NETWORK_NUCLEAR_NETWORK_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../util/network/sock_t.hpp" -#include "../../util/platform.hpp" -#include "wire_protocol.hpp" - -namespace NUClear { -namespace extension { - namespace network { - - class NUClearNetwork { - private: - using sock_t = util::network::sock_t; - - public: - struct NetworkTarget { - - NetworkTarget( - std::string name, - const sock_t& target, - const std::chrono::steady_clock::time_point& last_update = std::chrono::steady_clock::now()) - : name(std::move(name)), target(target), last_update(last_update) { - - // Set our recent packets to an invalid value - recent_packets.fill(-1); - } - - /// The name of the remote target - std::string name; - /// The socket address for the remote target - sock_t target{}; - /// When we last received data from the remote target - std::chrono::steady_clock::time_point last_update; - /// A list of the last n packet groups to be received - std::array::max()> recent_packets{}; - /// An index for the recent_packets (circular buffer) - std::atomic recent_packets_index{0}; - /// Mutex to protect the fragmented packet storage - std::mutex assemblers_mutex; - /// Storage for fragmented packets while we build them - std::map>>> - assemblers; - - /// Struct storing the kalman filter for round trip time - struct 
RoundTripKF { - float process_noise = 1e-6f; - float measurement_noise = 1e-1f; - float variance = 1.0f; - float mean = 1.0f; - }; - /// A little kalman filter for estimating round trip time - RoundTripKF round_trip_kf{}; - - std::chrono::steady_clock::duration round_trip_time{std::chrono::seconds(1)}; - - void measure_round_trip(std::chrono::steady_clock::duration time) { - - // Make our measurement into a float seconds type - const std::chrono::duration m = - std::chrono::duration_cast>(time); - - // Alias variables - const auto& Q = round_trip_kf.process_noise; - const auto& R = round_trip_kf.measurement_noise; - auto& P = round_trip_kf.variance; - auto& X = round_trip_kf.mean; - - // Calculate our kalman gain - const float K = (P + Q) / (P + Q + R); - - // Do filter - P = R * (P + Q) / (R + P + Q); - X = X + (m.count() - X) * K; - - // Put result into our variable - round_trip_time = std::chrono::duration_cast( - std::chrono::duration(X)); - } - }; - - NUClearNetwork() = default; - virtual ~NUClearNetwork(); - NUClearNetwork(const NUClearNetwork& /*other*/) = delete; - NUClearNetwork(NUClearNetwork&& /*other*/) noexcept = delete; - NUClearNetwork& operator=(const NUClearNetwork& /*rhs*/) = delete; - NUClearNetwork& operator=(NUClearNetwork&& /*rhs*/) noexcept = delete; - - /** - * Send data using the NUClear network. - * - * @param hash The identifying hash for the data - * @param data The bytes that are to be sent - * @param target Who we are sending to (blank means everyone) - * @param reliable If the delivery of the data should be ensured - */ - void send(const uint64_t& hash, - const std::vector& payload, - const std::string& target, - bool reliable); - - /** - * Set the callback to use when a data packet is completed. - * - * @param f The callback function - */ - void set_packet_callback( - std::function&&)> f); - - /** - * Set the callback to use when a node joins the network. 
- * - * @param f The callback function - */ - void set_join_callback(std::function f); - - /** - * Set the callback to use when a node leaves the network. - * - * @param f The callback function - */ - void set_leave_callback(std::function f); - - /** - * Set the callback to use when the system want's to notify when it next needs attention. - * - * @param f The callback function - */ - void set_next_event_callback(std::function f); - - /** - * Leave the NUClear network. - */ - void shutdown(); - - /** - * Reset our network to use the new settings. - * - * Resets the networking system to use the new announce information and name. - * If the network was already joined, it will first leave and then rejoin the new network. - * If the provided address is multicast it will join a multicast network. - * If it is broadcast it will use IPv4 broadcast traffic to announce, unicast addresses will only announce - * to a single target. - * - * @param name The name of this node in the network - * @param address The address to announce on - * @param port The port to use for announcement - * @param bind_address The address to bind to (if unset will bind to all interfaces) - * @param network_mtu The mtu of the network we operate on - */ - void reset(const std::string& name, - const std::string& address, - in_port_t port, - const std::string& bind_address = "", - uint16_t network_mtu = 1500); - void reset(const std::string& name, - const std::string& address, - in_port_t port, - uint16_t network_mtu = 1500); - - /** - * Process waiting data in the UDP sockets and send them to the callback if they are relevant. - */ - void process(); - - /** - * Get the file descriptors that the network listens on. 
- * - * @return A list of file descriptors that the system listens on - */ - std::vector listen_fds(); - - private: - struct PacketQueue { - - struct PacketTarget { - - /// Constructor a new PacketTarget - PacketTarget(std::weak_ptr target, std::vector acked); - - /// The target we are sending this packet to - std::weak_ptr target; - - /// The bitset of the packets that have been acked - std::vector acked; - - /// When we last sent data to this client - std::chrono::steady_clock::time_point last_send; - }; - - /// Default constructor for the PacketQueue - PacketQueue(); - - /// The remote targets that want this packet - std::list targets; - - /// The header of the packet to send - DataPacket header; - - /// The data to send - std::vector payload; - }; - - /** - * Open our data udp socket. - * - * @param bind_address The address to bind to or any to bind to all interfaces - */ - void open_data(const sock_t& bind_address); - - /** - * Open our announce udp socket. - * - * @param announce_target The target to announce to - * @param bind_address The address to bind to or any to bind to all interfaces - */ - void open_announce(const sock_t& announce_target, const sock_t& bind_address); - - /** - * Processes the given packet and calls the callback if a packet was completed. - * - * @param address Who the packet came from - * @param data The data that was sent in this packet - */ - void process_packet(const sock_t& address, std::vector&& payload); - - /** - * Send an announce packet to our announce address. - */ - void announce(); - - /** - * Retransmit waiting packets that failed to send. - */ - void retransmit(); - - /** - * Send an individual packet to an individual target. 
- * - * @param target The target to send the packet to - * @param header The header for this packet - * @param packet_no The packet number we are sending - * @param payload The data bytes for the entire packet - * @param reliable If the packet is reliable (don't drop) - */ - void send_packet(const sock_t& target, - DataPacket header, - uint16_t packet_no, - const std::vector& payload, - const bool& reliable); - - /** - * Get the map key for this socket address. - * - * @param address Who the packet came from - * - * @return The map key for this socket - */ - static std::array udp_key(const sock_t& address); - - /** - * Remove a target from our list of targets. - * - * @param t The target to remove - */ - void remove_target(const std::shared_ptr& target); - - /// The file descriptor for the socket we use to send data and receive regular data - fd_t data_fd{INVALID_SOCKET}; - /// The file descriptor for the socket we use to receive announce data - fd_t announce_fd{INVALID_SOCKET}; - - /// The largest packet of data we will transmit, based on our IP version and MTU - uint16_t packet_data_mtu{1000}; - - // Our announce packet - std::vector announce_packet; - - /// An source for packet IDs to make sure they are semi unique - uint16_t packet_id_source{0}; - - /// The callback to execute when a data packet is completed - std::function&&)> - packet_callback; - /// The callback to execute when a node joins the network - std::function join_callback; - /// The callback to execute when a node leaves the network - std::function leave_callback; - /// The callback to execute when a node leaves the network - std::function next_event_callback; - - /// When we are next due to send an announce packet - std::chrono::steady_clock::time_point last_announce{std::chrono::seconds(0)}; - /// When the next timed event is due - std::chrono::steady_clock::time_point next_event{std::chrono::seconds(0)}; - - /// A mutex to guard modifications to the target lists - /// NOTE: mutex lock order must 
always be this order to avoid deadlocks - std::mutex target_mutex; - /// A mutex to guard modifications to the send queue - std::mutex send_queue_mutex; - - /// A map from packet_id to allow resending reliable data - std::map send_queue; - - /// A list of targets that we are connected to on the network - std::list> targets; - - /// A map of string names to targets with that name - std::multimap, std::less<>> name_target; - - /// A map of ip/port pairs to the network target they belong to - std::map, std::shared_ptr> udp_target; - }; - - } // namespace network -} // namespace extension -} // namespace NUClear - -#endif // NUCLEAR_EXTENSION_NETWORK_NUCLEAR_NETWORK_HPP diff --git a/src/extension/network/wire_protocol.hpp b/src/extension/network/wire_protocol.hpp deleted file mode 100644 index 775f54e3..00000000 --- a/src/extension/network/wire_protocol.hpp +++ /dev/null @@ -1,119 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2017 NUClear Contributors - * - * This file is part of the NUClear codebase. - * See https://github.com/Fastcode/NUClear for further info. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated - * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to - * permit persons to whom the Software is furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the - * Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE - * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR - * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -#ifndef NUCLEAR_EXTENSION_NETWORK_WIRE_PROTOCOL_HPP -#define NUCLEAR_EXTENSION_NETWORK_WIRE_PROTOCOL_HPP - -#include - -// These macros are used to pack the structs so that they are sent over the network in the correct format -#ifdef _MSC_VER - // NOLINTNEXTLINE(cppcoreguidelines-macro-usage) - #define PACK(...) __pragma(pack(push, 1)) __VA_ARGS__ __pragma(pack(pop)) -#else - // NOLINTNEXTLINE(cppcoreguidelines-macro-usage) - #define PACK(...) __VA_ARGS__ __attribute__((__packed__)) -#endif - -namespace NUClear { -namespace extension { - namespace network { - - /** - * A number that is used to represent the type of packet that is being sent/received - */ - enum Type : uint8_t { ANNOUNCE = 1, LEAVE = 2, DATA = 3, DATA_RETRANSMISSION = 4, ACK = 5, NACK = 6 }; - - /** - * The header that is sent with every packet. 
- */ - PACK(struct PacketHeader { - explicit PacketHeader(const Type& t) : type(t) {} - - /// Radioactive symbol in UTF8 - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - uint8_t header[3] = {0xE2, 0x98, 0xA2}; - /// The NUClear networking version - uint8_t version = 0x02; - /// The type of packet - Type type; - }); - - PACK(struct AnnouncePacket - : PacketHeader { - AnnouncePacket() : PacketHeader(ANNOUNCE) {} - - // A null terminated string name for this node (&name) - char name{0}; - }); - - PACK(struct LeavePacket : PacketHeader{LeavePacket(): PacketHeader(LEAVE){}}); - - PACK(struct DataPacket - : PacketHeader { - DataPacket() : PacketHeader(DATA) {} - - // A semi-unique identifier for this packet group - uint16_t packet_id{0}; - // What packet number this is within the group - uint16_t packet_no{0}; - // How many packets there are in the group - uint16_t packet_count{1}; - // If this packet is reliable and should be acked - bool reliable{false}; - // The 64 bit hash to identify the data type - uint64_t hash{0}; - // The data (access using &data) - char data{0}; - }); - - PACK(struct ACKPacket - : PacketHeader { - ACKPacket() : PacketHeader(ACK) {} - - /// The packet group identifier we are acknowledging - uint16_t packet_id{0}; - /// The index of the packet we are acknowledging - uint16_t packet_no{0}; - /// How many packets there are in the group - uint16_t packet_count{1}; - /// A bitset of which packets we have received (access using &packets) - uint8_t packets{0}; - }); - - PACK(struct NACKPacket - : PacketHeader { - NACKPacket() : PacketHeader(NACK) {} - - /// The packet group identifier we are acknowledging - uint16_t packet_id{0}; - /// How many packets there are in the group - uint16_t packet_count{1}; - /// A bitset of which packets we have received (access using &packets) - uint8_t packets{0}; - }); - - } // namespace network -} // namespace extension -} // namespace NUClear - -#endif // 
NUCLEAR_EXTENSION_NETWORK_WIRE_PROTOCOL_HPP diff --git a/src/nuclearnet/CMakeLists.txt b/src/nuclearnet/CMakeLists.txt new file mode 100644 index 00000000..35b3aed9 --- /dev/null +++ b/src/nuclearnet/CMakeLists.txt @@ -0,0 +1,64 @@ +#[[ +MIT License + +Copyright (c) 2025 NUClear Contributors + +This file is part of the NUClear codebase. +See https://github.com/Fastcode/NUClear for further info. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the Software without restriction, including without limitation the +rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the +Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +]] + +# This CMakeLists.txt allows building NUClearNet as a standalone library, +# independent of the NUClear reactor framework. +# +# Usage (standalone): +# cmake -S src/nuclearnet -B build_net +# cmake --build build_net +# +# When built as part of the main NUClear project, these sources are included +# automatically via the parent src/CMakeLists.txt glob. 
+
+# Standalone project declaration; only a C++ toolchain is required.
+cmake_minimum_required(VERSION 3.15)
+project(NUClearNet VERSION 1.0.0 LANGUAGES CXX)
+
+# The library links against the platform thread library (see target_link_libraries below).
+find_package(Threads REQUIRED)
+
+# Collect every source/header below this directory. CONFIGURE_DEPENDS asks CMake to re-check the glob at build time.
+# NOTE(review): NUCLEARNET_HEADERS is collected but not referenced by any command visible in this file.
+file(GLOB_RECURSE NUCLEARNET_SOURCES CONFIGURE_DEPENDS "*.cpp")
+file(GLOB_RECURSE NUCLEARNET_HEADERS CONFIGURE_DEPENDS "*.hpp")
+
+# Static library plus a namespaced alias so consumers can link NUClear::nuclearnet.
+add_library(nuclearnet STATIC ${NUCLEARNET_SOURCES})
+add_library(NUClear::nuclearnet ALIAS nuclearnet)
+
+# NOTE(review): the two bare `$` entries below look like generator expressions whose `<...>` part was lost
+# (presumably a BUILD_INTERFACE and an INSTALL_INTERFACE entry) — confirm against the real file before use.
+target_include_directories(nuclearnet
+    PUBLIC
+        $
+        $
+)
+
+# Usage requirements propagated to consumers: threads, C++14, and position independent code for the static archive.
+target_link_libraries(nuclearnet PUBLIC Threads::Threads)
+target_compile_features(nuclearnet PUBLIC cxx_std_14)
+set_target_properties(nuclearnet PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+# Platform-specific linking
+if(WIN32)
+    target_link_libraries(nuclearnet PUBLIC ws2_32 mswsock iphlpapi)
+endif()
+
+# Warnings
+if(MSVC)
+    target_compile_options(nuclearnet PRIVATE /W4)
+else()
+    target_compile_options(nuclearnet PRIVATE -Wall -Wextra -pedantic)
+endif()
diff --git a/src/nuclearnet/Discovery.cpp b/src/nuclearnet/Discovery.cpp
new file mode 100644
index 00000000..b5d509a7
--- /dev/null
+++ b/src/nuclearnet/Discovery.cpp
@@ -0,0 +1,436 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2025 NUClear Contributors
+ *
+ * This file is part of the NUClear codebase.
+ * See https://github.com/Fastcode/NUClear for further info.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+ * documentation files (the "Software"), to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+ * to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+ * the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+ * THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+ * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "Discovery.hpp"
+
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <map>
+#include <mutex>
+#include <set>
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "wire_protocol.hpp"
+
+namespace NUClear {
+namespace network {
+
+    Discovery::Discovery(std::chrono::steady_clock::duration peer_timeout) : peer_timeout(peer_timeout) {}
+
+    void Discovery::set_join_callback(JoinCallback cb) {
+        join_callback = std::move(cb);
+    }
+
+    void Discovery::set_leave_callback(LeaveCallback cb) {
+        leave_callback = std::move(cb);
+    }
+
+    void Discovery::set_subscription_change_callback(SubscriptionChangeCallback cb) {
+        subscription_change_callback = std::move(cb);
+    }
+
+    /**
+     * Serialise an announce packet.
+     *
+     * Layout: PacketHeader, name_length (uint16), name bytes, num_subscriptions (uint16), subscription hashes
+     * (uint64 each). All fields are copied in host byte order.
+     *
+     * @throws std::length_error if the name or the subscription list does not fit its 16-bit length field
+     */
+    std::vector<uint8_t> Discovery::build_announce_packet(const std::string& name,
+                                                          const std::vector<uint64_t>& subscriptions) {
+        // Both counts travel as uint16_t. Narrowing a larger value would write a length field that disagrees with
+        // the bytes that follow it, producing a packet the receiver mis-parses, so refuse it up front.
+        constexpr std::size_t max_field = std::numeric_limits<uint16_t>::max();
+        if (name.size() > max_field) {
+            throw std::length_error("announce name does not fit in a 16-bit length field");
+        }
+        if (subscriptions.size() > max_field) {
+            throw std::length_error("announce subscription count does not fit in a 16-bit field");
+        }
+
+        // header + name_length + name + num_subscriptions + subscriptions
+        const std::size_t size = sizeof(PacketHeader) + sizeof(uint16_t) + name.size() + sizeof(uint16_t)
+                                 + subscriptions.size() * sizeof(uint64_t);
+
+        std::vector<uint8_t> packet(size);
+        uint8_t* ptr = packet.data();
+
+        // Write header
+        const PacketHeader header(ANNOUNCE);
+        std::memcpy(ptr, &header, sizeof(PacketHeader));
+        ptr += sizeof(PacketHeader);
+
+        // Write name length and name
+        const auto name_len = static_cast<uint16_t>(name.size());
+        std::memcpy(ptr, &name_len, sizeof(uint16_t));
+        ptr += sizeof(uint16_t);
+        std::memcpy(ptr, name.data(), name.size());
+        ptr += name.size();
+
+        // Write subscription count and hashes
+        const auto num_subs = static_cast<uint16_t>(subscriptions.size());
+        std::memcpy(ptr, &num_subs, sizeof(uint16_t));
+        ptr += sizeof(uint16_t);
+        for (const auto& hash : subscriptions) {
+            std::memcpy(ptr, &hash, sizeof(uint64_t));
+            ptr += sizeof(uint64_t);
+        }
+
+        return packet;
+    }
+
+    std::vector<uint8_t> Discovery::build_leave_packet() {
+        std::vector<uint8_t> packet(sizeof(LeavePacket));
+        const LeavePacket leave;
+        std::memcpy(packet.data(), &leave, sizeof(LeavePacket));
+        return packet;
+    }
+
+    std::vector<uint8_t> Discovery::build_connect_packet(uint8_t flags) {
+        std::vector<uint8_t> packet(sizeof(ConnectPacket));
+        ConnectPacket connect;
+        connect.flags = flags;
+        std::memcpy(packet.data(), &connect, sizeof(ConnectPacket));
+        return packet;
+    }
+
+    /**
+     * Parse an announce packet and update (or create) the peer entry for its source address.
+     *
+     * Malformed packets (too short, inconsistent lengths, empty name) are ignored and yield a default result.
+     * A previously unknown peer is recorded with is_new set; the join callback is not fired here for new peers,
+     * it fires once the data handshake is also confirmed.
+     */
+    Discovery::AnnounceResult Discovery::process_announce(const sock_t& source,
+                                                          const uint8_t* data,
+                                                          std::size_t length,
+                                                          std::chrono::steady_clock::time_point now) {
+        AnnounceResult announce_result;
+
+        // Minimum size: header + name_length + num_subscriptions
+        if (data == nullptr || length < sizeof(PacketHeader) + sizeof(uint16_t) + sizeof(uint16_t)) {
+            return announce_result;
+        }
+
+        const uint8_t* ptr    = data + sizeof(PacketHeader);
+        std::size_t remaining = length - sizeof(PacketHeader);
+
+        // Read name length
+        uint16_t name_len = 0;
+        std::memcpy(&name_len, ptr, sizeof(uint16_t));
+        ptr += sizeof(uint16_t);
+        remaining -= sizeof(uint16_t);
+
+        // The name and the subscription count that follows it must both fit in what is left
+        if (remaining < name_len + sizeof(uint16_t)) {
+            return announce_result;
+        }
+
+        // Read name
+        const std::string name(reinterpret_cast<const char*>(ptr), name_len);
+        ptr += name_len;
+        remaining -= name_len;
+
+        // Ignore empty names
+        if (name.empty()) {
+            return announce_result;
+        }
+
+        // Read subscription count
+        uint16_t num_subs = 0;
+        std::memcpy(&num_subs, ptr, sizeof(uint16_t));
+        ptr += sizeof(uint16_t);
+        remaining -= sizeof(uint16_t);
+
+        // Validate subscriptions fit
+        if (remaining < num_subs * sizeof(uint64_t)) {
+            return announce_result;
+        }
+
+        // Read subscriptions (duplicates collapse in the set)
+        std::set<uint64_t> subscriptions;
+        for (uint16_t i = 0; i < num_subs; ++i) {
+            uint64_t hash = 0;
+            std::memcpy(&hash, ptr, sizeof(uint64_t));
+            ptr += sizeof(uint64_t);
+            subscriptions.insert(hash);
+        }
+
+        // Check if this is a new peer or an existing one
+        bool subs_changed = false;
+        bool fire_join    = false;
+        PeerInfo join_info;
+        {
+            const std::lock_guard<std::mutex> lock(peers_mutex);
+
+            auto it = peers.find(source);
+            if (it == peers.end()) {
+                // New peer — record with announce_heard = true
+                announce_result.is_new = true;
+                PeerInfo info;
+                info.name           = name;
+                info.address        = source;
+                info.last_seen      = now;
+                info.subscriptions  = std::move(subscriptions);
+                info.announce_heard = true;
+                info.handshake      = HandshakeState::IDLE;
+                peers.emplace(source, std::move(info));
+            }
+            else {
+                auto& peer     = it->second;
+                peer.last_seen = now;
+
+                // Mark announce as heard (may complete the connection if the data path was already confirmed)
+                if (!peer.announce_heard) {
+                    peer.announce_heard = true;
+                    if (peer.handshake == HandshakeState::CONFIRMED) {
+                        fire_join = true;
+                        join_info = peer;
+                    }
+                }
+
+                // Update name if it was unknown (peer added via CONNECT before announce)
+                if (peer.name.empty()) {
+                    peer.name = name;
+                }
+
+                // Check for subscription changes
+                if (peer.subscriptions != subscriptions) {
+                    peer.subscriptions = std::move(subscriptions);
+                    subs_changed       = true;
+                }
+
+                // Determine which CONNECT flags to (re)send based on handshake state
+                switch (peer.handshake) {
+                    case HandshakeState::IDLE:
+                    case HandshakeState::SYN_SENT: announce_result.response_flags = SYN; break;
+                    case HandshakeState::SYN_RECEIVED: announce_result.response_flags = SYN | CON_ACK; break;
+                    case HandshakeState::CONFIRMED:
+                        // Send ACK to help the other side if they're stuck in SYN_RECEIVED
+                        announce_result.response_flags = CON_ACK;
+                        break;
+                }
+            }
+        }
+
+        // Fire callbacks outside the lock so they may call back into this object
+        if (fire_join && join_callback) {
+            join_callback(join_info);
+        }
+        if (subs_changed && subscription_change_callback) {
+            // Re-read the peer: it may have been removed since the lock was released, in which case nothing fires
+            PeerInfo info;
+            {
+                const std::lock_guard<std::mutex> lock(peers_mutex);
+                auto it = peers.find(source);
+                if (it != peers.end()) {
+                    info = it->second;
+                }
+                else {
+                    subs_changed = false;
+                }
+            }
+            if (subs_changed) {
+                subscription_change_callback(info);
+            }
+        }
+
+        return announce_result;
+    }
+
+    void Discovery::process_leave(const sock_t& source) {
+        PeerInfo removed;
+        bool was_connected = false;
+        {
+            const std::lock_guard<std::mutex> lock(peers_mutex);
+            auto it = peers.find(source);
+            if (it != peers.end()) {
+                was_connected = it->second.announce_heard && it->second.handshake == HandshakeState::CONFIRMED;
+                removed       = it->second;
+                peers.erase(it);
+            }
+        }
+
+        // Only fire leave callback for peers that completed the handshake
+        if (was_connected && leave_callback) {
+            leave_callback(removed);
+        }
+    }
+
+    /**
+     * Advance the CONNECT handshake state machine for the peer at `source`.
+     *
+     * An unknown source is added with minimal info (its name and subscriptions arrive with a later announce).
+     * The join callback fires, outside the lock, when this packet confirms the data path of a peer whose
+     * announce has already been heard.
+     */
+    Discovery::ConnectResult Discovery::process_connect(const sock_t& source,
+                                                        uint8_t flags,
+                                                        std::chrono::steady_clock::time_point now) {
+        ConnectResult result;
+        bool fire_join = false;
+        PeerInfo info;
+
+        {
+            const std::lock_guard<std::mutex> lock(peers_mutex);
+            auto it = peers.find(source);
+            if (it == peers.end()) {
+                // Unknown peer — add with minimal info (name/subs will come from announce)
+                PeerInfo new_peer;
+                new_peer.address        = source;
+                new_peer.last_seen      = now;
+                new_peer.announce_heard = false;
+                new_peer.handshake      = HandshakeState::IDLE;
+                it                      = peers.emplace(source, std::move(new_peer)).first;
+            }
+
+            auto& peer     = it->second;
+            peer.last_seen = now;
+
+            const bool has_syn = (flags & SYN) != 0;
+            const bool has_ack = (flags & CON_ACK) != 0;
+
+            // Shared by every transition into CONFIRMED: the peer only counts as connected (and the join callback
+            // only fires) when its announce has been heard as well.
+            const auto confirm_data_path = [&] {
+                peer.handshake = HandshakeState::CONFIRMED;
+                if (peer.announce_heard) {
+                    result.just_connected = true;
+                    fire_join             = true;
+                    info                  = peer;
+                }
+            };
+
+            switch (peer.handshake) {
+                case HandshakeState::IDLE:
+                    // A bare SYN, or a SYN+ACK we never solicited, is treated as a SYN: answer SYN+ACK
+                    if (has_syn) {
+                        peer.handshake        = HandshakeState::SYN_RECEIVED;
+                        result.response_flags = SYN | CON_ACK;
+                    }
+                    break;
+
+                case HandshakeState::SYN_SENT:
+                    if (has_syn && has_ack) {
+                        // SYN+ACK answering our SYN — send ACK, data path confirmed
+                        confirm_data_path();
+                        result.response_flags = CON_ACK;
+                    }
+                    else if (has_syn) {
+                        // Simultaneous open: both sides sent SYN at the same time — respond with SYN+ACK
+                        peer.handshake        = HandshakeState::SYN_RECEIVED;
+                        result.response_flags = SYN | CON_ACK;
+                    }
+                    break;
+
+                case HandshakeState::SYN_RECEIVED:
+                    if (has_ack) {
+                        // A plain ACK completes the handshake silently; a SYN+ACK (simultaneous open) is also
+                        // acknowledged so the other side can finish.
+                        confirm_data_path();
+                        if (has_syn) {
+                            result.response_flags = CON_ACK;
+                        }
+                    }
+                    break;
+
+                case HandshakeState::CONFIRMED:
+                    // Already confirmed — respond to duplicates
+                    if (has_syn && has_ack) {
+                        result.response_flags = CON_ACK;
+                    }
+                    else if (has_syn) {
+                        // Peer might have restarted — respond with SYN+ACK
+                        result.response_flags = SYN | CON_ACK;
+                    }
+                    break;
+            }
+        }
+
+        // Fire join callback outside the lock
+        if (fire_join && join_callback) {
+            join_callback(info);
+        }
+
+        return result;
+    }
+
+    void Discovery::mark_syn_sent(const sock_t& address) {
+        const std::lock_guard<std::mutex> lock(peers_mutex);
+        auto it = peers.find(address);
+        // Only an IDLE peer moves to SYN_SENT; later states are never downgraded
+        if (it != peers.end() && it->second.handshake == HandshakeState::IDLE) {
+            it->second.handshake = HandshakeState::SYN_SENT;
+        }
+    }
+
+    bool Discovery::is_connected(const sock_t& address) const {
+        const std::lock_guard<std::mutex> lock(peers_mutex);
+        auto it = peers.find(address);
+        return it != peers.end() && it->second.announce_heard && it->second.handshake == HandshakeState::CONFIRMED;
+    }
+
+    void Discovery::touch_peer(const sock_t& source, std::chrono::steady_clock::time_point now) {
+        const std::lock_guard<std::mutex> lock(peers_mutex);
+        auto it = peers.find(source);
+        if (it != peers.end()) {
+            it->second.last_seen = now;
+        }
+    }
+
+    std::vector<PeerInfo> Discovery::check_timeouts(std::chrono::steady_clock::time_point now) {
+        std::vector<PeerInfo> removed;
+
+        {
+            const std::lock_guard<std::mutex> lock(peers_mutex);
+            for (auto it = peers.begin(); it != peers.end();) {
+                if (now - it->second.last_seen > peer_timeout) {
+                    // Every expired entry is erased, but only fully connected peers are reported as having left
+                    if (it->second.announce_heard && it->second.handshake == HandshakeState::CONFIRMED) {
+                        removed.push_back(it->second);
+                    }
+                    it = peers.erase(it);
+                }
+                else {
+                    ++it;
+                }
+            }
+        }
+
+        // Fire leave callbacks outside the lock
+        if (leave_callback) {
+            for (const auto& peer : removed) {
+                leave_callback(peer);
+            }
+        }
+
+        return removed;
+    }
+
+    std::map<Discovery::sock_t, PeerInfo> Discovery::get_peers() const {
+        const std::lock_guard<std::mutex> lock(peers_mutex);
+        return peers;
+    }
+
+    bool Discovery::has_peer(const sock_t& address) const {
+        const std::lock_guard<std::mutex> lock(peers_mutex);
+        return peers.count(address) > 0;
+    }
+
+    const PeerInfo* Discovery::get_peer(const sock_t& address) const {
+        // NOTE(review): the returned pointer refers to an element of `peers` and outlives the lock taken here.
+        // process_leave() and check_timeouts() erase elements, so a caller on another thread can be left with a
+        // dangling pointer. Fixing that needs an interface change (return a copy); until then callers must not
+        // keep the pointer across calls that can remove peers.
+        const std::lock_guard<std::mutex> lock(peers_mutex);
+        auto it = peers.find(address);
+        if (it != peers.end()) {
+            return &it->second;
+        }
+        return nullptr;
+    }
+
+}  // namespace network
+}  // namespace NUClear
diff --git a/src/nuclearnet/Discovery.hpp b/src/nuclearnet/Discovery.hpp
new file mode 100644
index 00000000..7bc03a59
--- /dev/null
+++ b/src/nuclearnet/Discovery.hpp
@@ -0,0 +1,261 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2025 NUClear Contributors
+ *
+ * This file is part of the NUClear codebase.
+ * See https://github.com/Fastcode/NUClear for further info.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+ * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+ * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+#ifndef NUCLEAR_NETWORK_DISCOVERY_HPP
+#define NUCLEAR_NETWORK_DISCOVERY_HPP
+
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <map>
+#include <mutex>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "../util/network/sock_t.hpp"
+
+namespace NUClear {
+namespace network {
+
+    /**
+     * Data port handshake state for the CONNECT 3-way handshake.
+     */
+    enum class HandshakeState : uint8_t {
+        IDLE,          ///< No handshake attempted yet
+        SYN_SENT,      ///< Sent SYN, waiting for SYN+ACK
+        SYN_RECEIVED,  ///< Received SYN, sent SYN+ACK, waiting for ACK
+        CONFIRMED,     ///< Data port handshake complete
+    };
+
+    /**
+     * Information about a discovered peer on the network.
+     */
+    struct PeerInfo {
+        /// The peer's announced name (empty until an announce is received for peers first seen via CONNECT)
+        std::string name;
+        /// The peer's socket address (IP + data port, learned from UDP source)
+        util::network::sock_t address{};
+        /// When we last received any packet from this peer
+        std::chrono::steady_clock::time_point last_seen;
+        /// The set of message type hashes this peer has subscribed to (empty = wants all)
+        std::set<uint64_t> subscriptions;
+        /// Whether we have heard this peer's announce on the announce channel (proves their_d→our_a)
+        bool announce_heard = false;
+        /// Data port handshake state (proves their_d↔our_d via CONNECT packets)
+        HandshakeState handshake = HandshakeState::IDLE;
+    };
+
+    /**
+     * Handles peer discovery via periodic announce packets.
+     *
+     * Responsibilities:
+     * - Building announce, leave and connect packets (sent by the owner from the data socket for NAT-friendliness)
+     * - Processing received announce packets to discover peers
+     * - Tracking peer liveness via last-seen timestamps
+     * - Removing peers that have timed out
+     * - Processing LEAVE packets for graceful departure
+     * - Storing per-peer subscription information from announce packets
+     *
+     * The peer table is guarded by an internal mutex; callbacks are invoked after that mutex has been released.
+     */
+    class Discovery {
+    public:
+        using sock_t = util::network::sock_t;
+
+        /// Callback when a peer becomes fully connected (announce heard and data handshake confirmed)
+        using JoinCallback = std::function<void(const PeerInfo&)>;
+        /// Callback when a fully connected peer leaves (timeout or graceful)
+        using LeaveCallback = std::function<void(const PeerInfo&)>;
+        /// Callback when a peer's subscriptions change
+        using SubscriptionChangeCallback = std::function<void(const PeerInfo&)>;
+
+        /**
+         * Construct the discovery module.
+         *
+         * @param peer_timeout How long without receiving a packet before a peer is considered gone
+         */
+        explicit Discovery(std::chrono::steady_clock::duration peer_timeout = std::chrono::seconds(2));
+
+        /**
+         * Set the callback to invoke when a new peer joins.
+         */
+        void set_join_callback(JoinCallback cb);
+
+        /**
+         * Set the callback to invoke when a peer leaves.
+         */
+        void set_leave_callback(LeaveCallback cb);
+
+        /**
+         * Set the callback to invoke when a peer's subscriptions change.
+         */
+        void set_subscription_change_callback(SubscriptionChangeCallback cb);
+
+        /**
+         * Build an announce packet for this node.
+         *
+         * @param name          This node's name
+         * @param subscriptions The set of type hashes this node subscribes to (empty = all)
+         *
+         * @return The serialized announce packet bytes
+         */
+        static std::vector<uint8_t> build_announce_packet(const std::string& name,
+                                                          const std::vector<uint64_t>& subscriptions);
+
+        /**
+         * Build a leave packet.
+         *
+         * @return The serialized leave packet bytes
+         */
+        static std::vector<uint8_t> build_leave_packet();
+
+        /**
+         * Build a connect packet with the given flags.
+         *
+         * @param flags SYN, CON_ACK, or SYN|CON_ACK
+         *
+         * @return The serialized connect packet bytes
+         */
+        static std::vector<uint8_t> build_connect_packet(uint8_t flags);
+
+        /**
+         * Outcome of processing an announce packet.
+         */
+        struct AnnounceResult {
+            bool is_new            = false;  ///< Whether this was a previously unknown peer
+            uint8_t response_flags = 0;      ///< CONNECT flags to send (0 = don't send)
+        };
+
+        /**
+         * Process a received announce packet from a peer.
+         * Returns information about what action to take (send CONNECT, re-announce, etc.)
+         *
+         * @param source The UDP source address (IP + port) of the packet
+         * @param data   The raw packet data
+         * @param length The length of the packet data
+         * @param now    The current time (defaults to steady_clock::now())
+         *
+         * @return Result indicating whether this is a new peer and what CONNECT flags to send
+         */
+        AnnounceResult process_announce(const sock_t& source,
+                                        const uint8_t* data,
+                                        std::size_t length,
+                                        std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now());
+
+        /**
+         * Process a received leave packet from a peer.
+         *
+         * @param source The UDP source address of the packet
+         */
+        void process_leave(const sock_t& source);
+
+        /**
+         * Outcome of processing a CONNECT packet.
+         */
+        struct ConnectResult {
+            uint8_t response_flags = 0;      ///< Flags for response packet (0 = don't send)
+            bool just_connected    = false;  ///< Whether this transition completed the full connection
+        };
+
+        /**
+         * Process a received CONNECT packet from a peer.
+         * Advances the data port handshake state machine.
+         *
+         * @param source The UDP source address
+         * @param flags  The connect flags (SYN, CON_ACK, or SYN|CON_ACK)
+         * @param now    The current time (defaults to steady_clock::now())
+         *
+         * @return The response flags to send back (0 = no response needed),
+         *         and whether the peer just became fully connected
+         */
+        ConnectResult process_connect(const sock_t& source,
+                                      uint8_t flags,
+                                      std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now());
+
+        /**
+         * Advance a peer to SYN_SENT state (called when we send a SYN).
+         *
+         * @param address The peer's address
+         */
+        void mark_syn_sent(const sock_t& address);
+
+        /**
+         * Check if a peer is fully connected (both announce heard AND data handshake confirmed).
+         *
+         * @param address The peer's address
+         * @return true if the peer is fully connected
+         */
+        bool is_connected(const sock_t& address) const;
+
+        /**
+         * Update the last_seen timestamp for a peer (called on any received packet).
+         *
+         * @param source The UDP source address
+         * @param now    The current time (defaults to steady_clock::now())
+         */
+        void touch_peer(const sock_t& source,
+                        std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now());
+
+        /**
+         * Check for peers that have timed out and remove them.
+         *
+         * @param now The current time (defaults to steady_clock::now())
+         *
+         * @return List of fully connected peers that were removed due to timeout
+         */
+        std::vector<PeerInfo> check_timeouts(
+            std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now());
+
+        /**
+         * Get the current list of known peers.
+         *
+         * @return A copy of the map of socket address to peer info
+         */
+        std::map<sock_t, PeerInfo> get_peers() const;
+
+        /**
+         * Check if a specific peer is known.
+         *
+         * @param address The peer's address
+         * @return true if the peer is known
+         */
+        bool has_peer(const sock_t& address) const;
+
+        /**
+         * Get a specific peer's info.
+         *
+         * The pointer refers to an entry of the internal peer table and is returned after the internal lock has
+         * been released. It is invalidated when the peer is removed (process_leave, check_timeouts), so it must
+         * not be retained across such calls or read while another thread may remove peers. Use get_peers() when
+         * a stable snapshot is needed.
+         *
+         * @param address The peer's address
+         * @return Pointer to peer info, or nullptr if not found
+         */
+        const PeerInfo* get_peer(const sock_t& address) const;
+
+    private:
+        /// How long without hearing from a peer before it's removed
+        std::chrono::steady_clock::duration peer_timeout;
+
+        /// Mutex protecting the peers map
+        mutable std::mutex peers_mutex;
+
+        /// Known peers indexed by their socket address
+        std::map<sock_t, PeerInfo> peers;
+
+        /// Callbacks
+        JoinCallback join_callback;
+        LeaveCallback leave_callback;
+        SubscriptionChangeCallback subscription_change_callback;
+    };
+
+}  // namespace network
+}  // namespace NUClear
+
+#endif  // NUCLEAR_NETWORK_DISCOVERY_HPP
diff --git a/src/nuclearnet/FileDescriptor.hpp b/src/nuclearnet/FileDescriptor.hpp
new file mode 100644
index 00000000..71ab6d98
--- /dev/null
+++ b/src/nuclearnet/FileDescriptor.hpp
@@ -0,0 +1,104 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2025 NUClear Contributors
+ *
+ * This file is part of the NUClear codebase.
+ * See https://github.com/Fastcode/NUClear for further info.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+ * documentation files (the "Software"), to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+ * to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+ * the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+ * THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+ * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#ifndef NUCLEAR_NETWORK_FILE_DESCRIPTOR_HPP
+#define NUCLEAR_NETWORK_FILE_DESCRIPTOR_HPP
+
+#include "../util/platform.hpp"
+
+namespace NUClear {
+namespace network {
+
+    /**
+     * RAII wrapper for a file descriptor (socket).
+     * Automatically closes the file descriptor on destruction.
+     * Non-copyable, move-only.
+     */
+    class FileDescriptor {
+    public:
+        /// Construct with an invalid descriptor
+        FileDescriptor() = default;
+
+        /// Construct taking ownership of an existing file descriptor
+        explicit FileDescriptor(fd_t fd) : fd(fd) {}
+
+        /// Destructor closes the descriptor if valid
+        ~FileDescriptor() {
+            reset();
+        }
+
+        // Non-copyable: two owners would close the same descriptor twice
+        FileDescriptor(const FileDescriptor&)            = delete;
+        FileDescriptor& operator=(const FileDescriptor&) = delete;
+
+        /// Move construction steals the descriptor and leaves `other` invalid
+        FileDescriptor(FileDescriptor&& other) noexcept : fd(other.fd) {
+            other.fd = INVALID_SOCKET;
+        }
+
+        /// Move assignment closes the descriptor currently held, then steals the one from `other`
+        FileDescriptor& operator=(FileDescriptor&& other) noexcept {
+            if (this != &other) {
+                reset();
+                fd       = other.fd;
+                other.fd = INVALID_SOCKET;
+            }
+            return *this;
+        }
+
+        /// Get the raw file descriptor (ownership is retained)
+        fd_t get() const noexcept {
+            return fd;
+        }
+
+        /// Check if the descriptor is valid
+        bool valid() const noexcept {
+            return fd != INVALID_SOCKET;
+        }
+
+        /// Implicit conversion to fd_t for use with system calls
+        operator fd_t() const noexcept {  // NOLINT(google-explicit-constructor)
+            return fd;
+        }
+
+        /// Release ownership without closing; the caller becomes responsible for the returned descriptor
+        fd_t release() noexcept {
+            const fd_t old = fd;
+            fd             = INVALID_SOCKET;
+            return old;
+        }
+
+        /**
+         * Close the current descriptor and take ownership of a new one.
+         *
+         * Passing the descriptor that is already owned is a no-op: closing it and then storing the same value
+         * would leave this object holding a closed descriptor that gets closed a second time later.
+         *
+         * @param new_fd The descriptor to adopt, or INVALID_SOCKET to simply close the current one
+         */
+        void reset(fd_t new_fd = INVALID_SOCKET) {
+            if (new_fd == fd) {
+                return;
+            }
+            if (fd != INVALID_SOCKET) {
+                // NOTE(review): ::close is the POSIX call — confirm ../util/platform.hpp maps it to closesocket
+                // for Windows SOCKET handles.
+                ::close(fd);
+            }
+            fd = new_fd;
+        }
+
+    private:
+        /// The owned descriptor, or INVALID_SOCKET when empty
+        fd_t fd{INVALID_SOCKET};
+    };
+
+}  // namespace network
+}  // namespace NUClear
+
+#endif  // NUCLEAR_NETWORK_FILE_DESCRIPTOR_HPP
diff --git a/src/nuclearnet/Fragmentation.cpp b/src/nuclearnet/Fragmentation.cpp
new file mode 100644
index 00000000..f9beadc9
--- /dev/null
+++ b/src/nuclearnet/Fragmentation.cpp
@@ -0,0 +1,150 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2025 NUClear Contributors
+ *
+ * This file is part of the NUClear codebase.
+ * See https://github.com/Fastcode/NUClear for further info.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+ * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+ * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+#include "Fragmentation.hpp"
+
+#include <algorithm>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <mutex>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+namespace NUClear {
+namespace network {
+
+    Fragmentation::Fragmentation(uint16_t packet_mtu,
+                                 std::size_t max_assembly_size,
+                                 std::chrono::steady_clock::duration assembly_timeout)
+        : packet_mtu(packet_mtu), max_assembly_size(max_assembly_size), assembly_timeout(assembly_timeout) {}
+
+    /**
+     * Split a payload into fragments of at most packet_mtu bytes each.
+     *
+     * An empty payload still produces one (empty) fragment so that the message is transmitted.
+     *
+     * @param packet_id Identifier shared by every fragment of this message
+     * @param hash      Message type hash copied into every fragment
+     * @param flags     Flags copied into every fragment
+     * @param payload   The bytes to split
+     *
+     * @return The fragments in order, each carrying its index and the total count
+     *
+     * @throws std::invalid_argument if this object was configured with a packet_mtu of zero
+     * @throws std::length_error     if the payload would need more fragments than a 16-bit count can express
+     */
+    auto Fragmentation::fragment(uint16_t packet_id,
+                                 uint64_t hash,
+                                 uint8_t flags,
+                                 const std::vector<uint8_t>& payload) const -> std::vector<Fragment> {
+        // A zero MTU would divide by zero below
+        if (packet_mtu == 0) {
+            throw std::invalid_argument("cannot fragment with a packet_mtu of zero");
+        }
+
+        // Work in size_t so the fragment count is checked before it is narrowed; narrowing first would wrap the
+        // count and silently drop the tail of the payload.
+        const std::size_t mtu    = packet_mtu;
+        const std::size_t needed = payload.empty() ? 1 : payload.size() / mtu + (payload.size() % mtu != 0 ? 1 : 0);
+        if (needed > std::numeric_limits<uint16_t>::max()) {
+            throw std::length_error("payload needs more fragments than a 16-bit packet_count can express");
+        }
+        const auto packet_count = static_cast<uint16_t>(needed);
+
+        std::vector<Fragment> fragments;
+        fragments.reserve(packet_count);
+
+        for (uint16_t i = 0; i < packet_count; ++i) {
+            Fragment frag;
+            frag.packet_id    = packet_id;
+            frag.packet_no    = i;
+            frag.packet_count = packet_count;
+            frag.flags        = flags;
+            frag.hash         = hash;
+
+            // Slice [offset, offset + length) of the payload; the last fragment may be shorter than the MTU
+            const std::size_t offset = static_cast<std::size_t>(i) * mtu;
+            const std::size_t length = std::min(mtu, payload.size() - offset);
+            const auto first         = payload.begin() + static_cast<std::ptrdiff_t>(offset);
+            frag.data.assign(first, first + static_cast<std::ptrdiff_t>(length));
+
+            fragments.push_back(std::move(frag));
+        }
+
+        return fragments;
+    }
+
+    /**
+     * Store one received fragment and, once every fragment of the message is present, emit the assembled payload.
+     *
+     * Fragments are rejected (false, nothing stored) when the index/count pair is inconsistent, when the data
+     * pointer is null for a non-empty fragment, or when the message would exceed max_assembly_size. If a
+     * fragment arrives for an existing assembly with a different packet_count or hash, the packet_id has been
+     * reused for another message and the stale fragments are discarded first.
+     *
+     * @return true if this fragment completed the message and out_packet was filled, false otherwise
+     */
+    bool Fragmentation::submit_fragment(uint64_t source_key,
+                                        uint16_t packet_id,
+                                        uint16_t packet_no,
+                                        uint16_t packet_count,
+                                        uint64_t hash,
+                                        uint8_t flags,
+                                        const uint8_t* data,
+                                        std::size_t data_length,
+                                        AssembledPacket& out_packet,
+                                        std::chrono::steady_clock::time_point now) {
+        if (packet_count == 0 || packet_no >= packet_count) {
+            return false;
+        }
+        if (data == nullptr && data_length != 0) {
+            return false;
+        }
+
+        // Enforce the assembly size limit (0 disables it), both on the size projected from the fragment count
+        // and on this single fragment.
+        if (max_assembly_size > 0) {
+            const std::size_t projected_size = static_cast<std::size_t>(packet_count) * packet_mtu;
+            if (projected_size > max_assembly_size || data_length > max_assembly_size) {
+                return false;
+            }
+        }
+
+        const AssemblyKey key{source_key, packet_id};
+
+        const std::lock_guard<decltype(assembly_mutex)> lock(assembly_mutex);
+
+        auto& assembly = assemblies[key];
+
+        // Never splice fragments of two different messages together: a changed count or hash means the sender
+        // reused this packet_id, so whatever was collected so far belongs to an abandoned message.
+        if (!assembly.fragments.empty() && (assembly.packet_count != packet_count || assembly.hash != hash)) {
+            assembly.fragments.clear();
+        }
+
+        assembly.hash         = hash;
+        assembly.flags        = flags;
+        assembly.packet_count = packet_count;
+        assembly.last_update  = now;
+
+        // Store this fragment (a retransmitted fragment simply overwrites the earlier copy)
+        assembly.fragments[packet_no].assign(data, data + data_length);
+
+        // Every stored index is below packet_count, so holding packet_count entries means none is missing
+        if (assembly.fragments.size() != packet_count) {
+            return false;
+        }
+
+        out_packet.packet_id = packet_id;
+        out_packet.hash      = hash;
+        out_packet.flags     = flags;
+
+        // Reserve once for the whole payload
+        std::size_t total_size = 0;
+        for (const auto& fragment_entry : assembly.fragments) {
+            total_size += fragment_entry.second.size();
+        }
+        out_packet.payload.clear();
+        out_packet.payload.reserve(total_size);
+
+        // Concatenate in fragment order; at() never inserts, unlike operator[]
+        for (uint16_t i = 0; i < packet_count; ++i) {
+            const auto& frag_data = assembly.fragments.at(i);
+            out_packet.payload.insert(out_packet.payload.end(), frag_data.begin(), frag_data.end());
+        }
+
+        // Remove the completed assembly (this invalidates `assembly`)
+        assemblies.erase(key);
+
+        return true;
+    }
+
+    /**
+     * Drop every partial assembly that has not received a fragment for longer than assembly_timeout.
+     *
+     * @return The number of assemblies removed
+     */
+    std::size_t Fragmentation::cleanup_expired(std::chrono::steady_clock::time_point now) {
+        std::size_t removed = 0;
+
+        const std::lock_guard<decltype(assembly_mutex)> lock(assembly_mutex);
+        for (auto it = assemblies.begin(); it != assemblies.end();) {
+            if (now - it->second.last_update > assembly_timeout) {
+                it = assemblies.erase(it);
+                ++removed;
+            }
+            else {
+                ++it;
+            }
+        }
+
+        return removed;
+    }
+
+}  // namespace network
+}  // namespace NUClear
diff --git a/src/nuclearnet/Fragmentation.hpp b/src/nuclearnet/Fragmentation.hpp
new file mode 100644
index 00000000..c7fe658d
--- /dev/null
+++ b/src/nuclearnet/Fragmentation.hpp
@@ -0,0 +1,161 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2025 NUClear Contributors
+ *
+ * This file is part of the NUClear codebase.
+ * See https://github.com/Fastcode/NUClear for further info.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef NUCLEAR_NETWORK_FRAGMENTATION_HPP +#define NUCLEAR_NETWORK_FRAGMENTATION_HPP + +#include +#include +#include +#include +#include +#include + +namespace NUClear { +namespace network { + +
    /**
     * Handles fragmentation of large messages into MTU-sized packets and reassembly on the receiving side.
     *
     * Responsibilities:
     * - Splitting a payload into fragments that fit within the packet MTU
     * - Reassembling received fragments back into complete messages
     * - Tracking incomplete assemblies with timeouts
     * - Enforcing maximum reassembly size limits to prevent memory bombs
     *
     * The reassembly state is guarded by an internal mutex; fragment() touches no mutable state.
     */
    class Fragmentation {
    public:
        /// Result of a completed reassembly
        struct AssembledPacket {
            uint16_t packet_id{0};         ///< ID of the packet group that was reassembled
            uint64_t hash{0};              ///< Message type hash
            uint8_t flags{0};              ///< Packet flags carried by the fragments
            std::vector<uint8_t> payload;  ///< The fragments' data concatenated in index order
        };

        /// A single fragment ready to be sent
        struct Fragment {
            uint16_t packet_id{0};      ///< ID shared by every fragment of the message
            uint16_t packet_no{0};      ///< 0-based index of this fragment
            uint16_t packet_count{0};   ///< Total number of fragments in the message
            uint8_t flags{0};           ///< Packet flags (e.g., RELIABLE)
            uint64_t hash{0};           ///< Message type hash
            std::vector<uint8_t> data;  ///< This fragment's slice of the payload
        };

        /**
         * Construct the fragmentation module.
         *
         * @param packet_mtu        Maximum payload bytes per fragment
         * @param max_assembly_size Maximum total size of a reassembled message (0 = unlimited)
         * @param assembly_timeout  How long to keep an incomplete assembly before discarding.
         *                          Should match the peer timeout since if no fragments arrive within this period,
         *                          the peer will be considered dead and cleaned up anyway.
         */
        Fragmentation(uint16_t packet_mtu                                  = 1452,
                      std::size_t max_assembly_size                        = 64 * 1024 * 1024,  // 64 MB default
                      std::chrono::steady_clock::duration assembly_timeout = std::chrono::seconds(2));

        /**
         * Fragment a message into MTU-sized pieces.
         *
         * @param packet_id The unique ID for this packet group
         * @param hash      The message type hash
         * @param flags     Packet flags (e.g., RELIABLE)
         * @param payload   The full message payload
         *
         * @return Vector of fragments ready to be sent
         */
        std::vector<Fragment> fragment(uint16_t packet_id,
                                       uint64_t hash,
                                       uint8_t flags,
                                       const std::vector<uint8_t>& payload) const;

        /**
         * Submit a received fragment for reassembly.
         *
         * @param source_key   Opaque key identifying the sender (for per-peer assembly tracking)
         * @param packet_id    The packet group ID
         * @param packet_no    This fragment's index (0-based)
         * @param packet_count Total fragments in the group
         * @param hash         The message type hash
         * @param flags        Packet flags
         * @param data         Pointer to the fragment payload
         * @param data_length  Number of payload bytes available at @p data
         * @param out_packet   Filled with the assembled packet when reassembly completes
         * @param now          The current time (defaults to steady_clock::now())
         *
         * @return true if all fragments are now received and @p out_packet is valid, false otherwise
         */
        bool submit_fragment(uint64_t source_key,
                             uint16_t packet_id,
                             uint16_t packet_no,
                             uint16_t packet_count,
                             uint64_t hash,
                             uint8_t flags,
                             const uint8_t* data,
                             std::size_t data_length,
                             AssembledPacket& out_packet,
                             std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now());

        /**
         * Clean up assemblies that have timed out.
         *
         * @param now The current time (defaults to steady_clock::now())
         *
         * @return Number of assemblies that were discarded
         */
        std::size_t cleanup_expired(std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now());

        /**
         * Get the packet MTU (max payload per fragment).
         *
         * @return the packet_mtu value this module was constructed with
         */
        uint16_t get_packet_mtu() const {
            return packet_mtu;
        }

    private:
        /// Key for an in-progress assembly: (source_key, packet_id)
        using AssemblyKey = std::pair<uint64_t, uint16_t>;

        /// State for an in-progress reassembly
        struct Assembly {
            uint64_t hash{0};                                     ///< Message type hash of the group
            uint8_t flags{0};                                     ///< Flags from the most recent fragment
            uint16_t packet_count{0};                             ///< Number of fragments the group needs
            std::chrono::steady_clock::time_point last_update{};  ///< When the last fragment was stored
            std::map<uint16_t, std::vector<uint8_t>> fragments;   ///< Received fragment data keyed by index
        };

        uint16_t packet_mtu;
        std::size_t max_assembly_size;
        std::chrono::steady_clock::duration assembly_timeout;

        /// Guards assemblies
        mutable std::mutex assembly_mutex;
        std::map<AssemblyKey, Assembly> assemblies;
    };

+ +} // namespace network +} // namespace NUClear + +#endif // NUCLEAR_NETWORK_FRAGMENTATION_HPP diff --git a/src/nuclearnet/NUClearNet.cpp b/src/nuclearnet/NUClearNet.cpp new file mode 100644 index 00000000..76e53b76 --- /dev/null +++ b/src/nuclearnet/NUClearNet.cpp @@ -0,0 +1,597 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "NUClearNet.hpp" + +#include +#include +#include +#include + +#include "../util/network/resolve.hpp" +#include "wire_protocol.hpp" + +namespace NUClear { +namespace network { + + const std::chrono::milliseconds NUClearNet::ANNOUNCE_INTERVAL(500); + + NUClearNet::NUClearNet() + : discovery(std::make_unique(std::chrono::seconds(2))) + , fragmentation(std::make_unique(1452, 64 * 1024 * 1024, std::chrono::seconds(2))) + , reliability(std::make_unique()) { + + // Wire up discovery callbacks to forward to user callbacks + discovery->set_join_callback([this](const PeerInfo& peer) { + // Update routing with peer's subscriptions + routing.update_peer_subscriptions(peer.address, peer.subscriptions); + if (join_callback) { + join_callback(peer); + } + }); + + discovery->set_leave_callback([this](const PeerInfo& peer) { + routing.remove_peer(peer.address); + reliability->remove_peer(peer.address); + deduplicators.erase(peer.address); + if (leave_callback) { + leave_callback(peer); + } + }); + + discovery->set_subscription_change_callback([this](const PeerInfo& peer) { + routing.update_peer_subscriptions(peer.address, peer.subscriptions); + }); + } + + NUClearNet::~NUClearNet() { + shutdown(); + } + + void NUClearNet::reset(const NetworkConfig& cfg) { + // Shut down existing connections + shutdown(); + + config = cfg; + node_name = cfg.name; + + // Update module configurations + discovery = std::make_unique(cfg.peer_timeout); + fragmentation = std::make_unique( + static_cast(cfg.mtu - sizeof(DataPacket) + 1 - 40 - 8), // MTU - headers + cfg.max_assembly_size, + cfg.peer_timeout); // Assembly timeout matches peer timeout + reliability = std::make_unique(); + + // Re-wire discovery callbacks + 
discovery->set_join_callback([this](const PeerInfo& peer) { + routing.update_peer_subscriptions(peer.address, peer.subscriptions); + if (join_callback) { + join_callback(peer); + } + }); + discovery->set_leave_callback([this](const PeerInfo& peer) { + routing.remove_peer(peer.address); + reliability->remove_peer(peer.address); + deduplicators.erase(peer.address); + if (leave_callback) { + leave_callback(peer); + } + }); + discovery->set_subscription_change_callback([this](const PeerInfo& peer) { + routing.update_peer_subscriptions(peer.address, peer.subscriptions); + }); + + // Resolve announce target + announce_target = util::network::resolve(cfg.announce_address, cfg.announce_port); + + // Determine bind address + sock_t bind_addr{}; + if (cfg.bind_address.empty()) { + bind_addr = announce_target; + if (announce_target.sock.sa_family == AF_INET) { + bind_addr.ipv4.sin_addr.s_addr = htonl(INADDR_ANY); + } + else if (announce_target.sock.sa_family == AF_INET6) { + bind_addr.ipv6.sin6_addr = IN6ADDR_ANY_INIT; + } + } + else { + bind_addr = util::network::resolve(cfg.bind_address, cfg.announce_port); + } + + // Open data socket (ephemeral port) + { + sock_t data_bind = bind_addr; + if (data_bind.sock.sa_family == AF_INET) { + data_bind.ipv4.sin_port = 0; + } + else { + data_bind.ipv6.sin6_port = 0; + } + + fd_t fd = ::socket(data_bind.sock.sa_family, SOCK_DGRAM, IPPROTO_UDP); + if (fd < 0) { + throw std::system_error(network_errno, std::system_category(), "Failed to create data socket"); + } + + int yes = 1; + ::setsockopt(fd, SOL_SOCKET, SO_BROADCAST, reinterpret_cast(&yes), sizeof(yes)); + + if (::bind(fd, &data_bind.sock, data_bind.size()) != 0) { + ::close(fd); + throw std::system_error(network_errno, std::system_category(), "Failed to bind data socket"); + } + + data_fd.reset(fd); + } + + // Open announce socket (known port) + { + fd_t fd = ::socket(bind_addr.sock.sa_family, SOCK_DGRAM, IPPROTO_UDP); + if (fd < 0) { + throw std::system_error(network_errno, 
std::system_category(), "Failed to create announce socket"); + } + + int yes = 1; + ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), sizeof(yes)); +#ifdef SO_REUSEPORT + ::setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), sizeof(yes)); +#endif + ::setsockopt(fd, SOL_SOCKET, SO_BROADCAST, reinterpret_cast(&yes), sizeof(yes)); + + if (::bind(fd, &bind_addr.sock, bind_addr.size()) != 0) { + ::close(fd); + throw std::system_error(network_errno, std::system_category(), "Failed to bind announce socket"); + } + + // Join multicast group if applicable + bool multicast = (announce_target.sock.sa_family == AF_INET + && (ntohl(announce_target.ipv4.sin_addr.s_addr) & 0xF0000000U) == 0xE0000000U) + || (announce_target.sock.sa_family == AF_INET6 + && announce_target.ipv6.sin6_addr.s6_addr[0] == 0xFF); + + if (multicast) { + if (announce_target.sock.sa_family == AF_INET) { + ip_mreq mreq{}; + mreq.imr_multiaddr = announce_target.ipv4.sin_addr; + mreq.imr_interface.s_addr = htonl(INADDR_ANY); + if (::setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, reinterpret_cast(&mreq), sizeof(mreq)) < 0) { + ::close(fd); + throw std::system_error(network_errno, std::system_category(), "Failed to join multicast group"); + } + } + else { + ipv6_mreq mreq{}; + mreq.ipv6mr_multiaddr = announce_target.ipv6.sin6_addr; + mreq.ipv6mr_interface = 0; + if (::setsockopt(fd, IPPROTO_IPV6, IPV6_JOIN_GROUP, reinterpret_cast(&mreq), sizeof(mreq)) < 0) { + ::close(fd); + throw std::system_error(network_errno, std::system_category(), "Failed to join IPv6 multicast group"); + } + } + } + + announce_fd.reset(fd); + } + + // Send initial announce + last_announce = std::chrono::steady_clock::time_point{}; + } + + void NUClearNet::shutdown() { + // Send leave packet to announce address if we have a data socket + if (data_fd.valid()) { + auto leave = Discovery::build_leave_packet(); + send_buf(data_fd, announce_target, leave.data(), leave.size()); + } + data_fd.reset(); + 
announce_fd.reset(); + } + + void NUClearNet::process() { + auto now = std::chrono::steady_clock::now(); + + // Send announce if interval has elapsed + if (now - last_announce >= ANNOUNCE_INTERVAL) { + announce(); + last_announce = now; + } + + // Check for timed-out peers + discovery->check_timeouts(); + + // Check for retransmissions + auto retransmissions = reliability->check_retransmissions(fragmentation->get_packet_mtu()); + for (const auto& req : retransmissions) { + // Build header on stack, scatter-write header + fragment data + DataPacket header{}; + header.packet_id = req.packet_id; + header.packet_no = req.packet_no; + header.packet_count = req.packet_count; + header.flags = req.flags; + header.hash = req.hash; + + struct iovec iov[2]; + iov[0].iov_base = reinterpret_cast(&header); + iov[0].iov_len = sizeof(DataPacket) - 1; + iov[1].iov_base = const_cast(static_cast(req.data.data())); + iov[1].iov_len = req.data.size(); + + send_iov(data_fd, req.target, iov, 2); + } + + // Clean up expired fragment assemblies + fragmentation->cleanup_expired(); + + // Read pending packets from both sockets + if (announce_fd.valid()) { + read_socket(announce_fd); + } + if (data_fd.valid()) { + read_socket(data_fd); + } + + // Schedule next event + if (event_callback) { + event_callback(now + ANNOUNCE_INTERVAL); + } + } + + void NUClearNet::send(uint64_t hash, const uint8_t* payload, std::size_t length, const std::string& target, bool reliable) { + if (!data_fd.valid()) { + return; + } + + // Get a packet ID + uint16_t packet_id = next_packet_id++; + + uint8_t flags = reliable ? RELIABLE : 0; + + // Compute fragment count + uint16_t packet_mtu = fragmentation->get_packet_mtu(); + uint16_t packet_count = length == 0 ? 
1 : static_cast((length + packet_mtu - 1) / packet_mtu); + + // Unreliable broadcast: send once to the multicast/broadcast group + // All connected peers receive it on their announce socket — receivers filter by subscription + if (target.empty() && !reliable) { + for (uint16_t i = 0; i < packet_count; ++i) { + DataPacket header{}; + header.packet_id = packet_id; + header.packet_no = i; + header.packet_count = packet_count; + header.flags = flags; + header.hash = hash; + + std::size_t offset = static_cast(i) * packet_mtu; + std::size_t frag_len = std::min(static_cast(packet_mtu), length - offset); + + struct iovec iov[2]; + iov[0].iov_base = reinterpret_cast(&header); + iov[0].iov_len = sizeof(DataPacket) - 1; + iov[1].iov_base = const_cast(static_cast(payload + offset)); + iov[1].iov_len = frag_len; + + send_iov(data_fd, announce_target, iov, 2); + } + return; + } + + // Targeted or reliable sends: unicast to each matching peer + auto peers = discovery->get_peers(); + std::vector targets; + + if (target.empty()) { + // Reliable broadcast: send to all subscribing peers individually (for ACK tracking) + for (const auto& peer : peers) { + const auto& addr = peer.first; + if (routing.should_send(addr, hash)) { + targets.push_back(addr); + } + } + } + else { + // Send to specific named peer(s) + for (const auto& peer : peers) { + const auto& addr = peer.first; + const auto& info = peer.second; + if (info.name == target && routing.should_send(addr, hash)) { + targets.push_back(addr); + } + } + } + + if (targets.empty()) { + return; + } + + // Send each fragment to each target using scatter IO (no data copy) + for (const auto& tgt : targets) { + for (uint16_t i = 0; i < packet_count; ++i) { + DataPacket header{}; + header.packet_id = packet_id; + header.packet_no = i; + header.packet_count = packet_count; + header.flags = flags; + header.hash = hash; + + std::size_t offset = static_cast(i) * packet_mtu; + std::size_t frag_len = std::min(static_cast(packet_mtu), length - 
offset); + + struct iovec iov[2]; + iov[0].iov_base = reinterpret_cast(&header); + iov[0].iov_len = sizeof(DataPacket) - 1; + iov[1].iov_base = const_cast(static_cast(payload + offset)); + iov[1].iov_len = frag_len; + + send_iov(data_fd, tgt, iov, 2); + } + + // If reliable, track for retransmission (single copy stored internally) + if (reliable) { + reliability->track_packet(tgt, packet_id, packet_count, hash, flags, payload, length); + } + } + } + + void NUClearNet::set_subscriptions(const std::set& subscriptions) { + routing.set_local_subscriptions(subscriptions); + } + + void NUClearNet::add_subscription(uint64_t hash) { + routing.add_local_subscription(hash); + } + + void NUClearNet::set_packet_callback(PacketCallback cb) { + packet_callback = std::move(cb); + } + + void NUClearNet::set_join_callback(JoinCallback cb) { + join_callback = std::move(cb); + } + + void NUClearNet::set_leave_callback(LeaveCallback cb) { + leave_callback = std::move(cb); + } + + void NUClearNet::set_event_callback(EventCallback cb) { + event_callback = std::move(cb); + } + + std::vector NUClearNet::listen_fds() const { + std::vector fds; + if (data_fd.valid()) { + fds.push_back(data_fd.get()); + } + if (announce_fd.valid()) { + fds.push_back(announce_fd.get()); + } + return fds; + } + + void NUClearNet::announce() { + if (!data_fd.valid()) { + return; + } + + auto subs = routing.get_local_subscriptions(); + auto packet = Discovery::build_announce_packet(node_name, subs); + + // Send announce from data socket to announce target (NAT-friendly) + send_buf(data_fd, announce_target, packet.data(), packet.size()); + } + + void NUClearNet::read_socket(fd_t fd) { + // Stack buffer — 65535 is the maximum UDP datagram size + alignas(8) uint8_t buffer[65535]; + sock_t source{}; + + for (;;) { + socklen_t source_len = sizeof(source.storage); + ssize_t received = ::recvfrom(fd, + reinterpret_cast(buffer), + sizeof(buffer), + MSG_DONTWAIT, + &source.sock, + &source_len); + + if (received <= 0) { + 
break; + } + + process_packet(source, buffer, static_cast(received)); + } + } + + void NUClearNet::process_packet(const sock_t& source, const uint8_t* data, std::size_t length) { + if (!validate_header(data, length)) { + return; + } + + const auto* header = reinterpret_cast(data); + + // Touch the peer to reset timeout + discovery->touch_peer(source); + + switch (header->type) { + case ANNOUNCE: { + // Check if this is a new peer before processing (which adds them) + const bool is_new_peer = !discovery->has_peer(source); + auto announce_result = discovery->process_announce(source, data, length); + + if (announce_result.is_new) { + // Force an immediate announce to the multicast/broadcast group + // so the new peer hears us on the announce channel (confirms our_d→their_a) + announce(); + last_announce = std::chrono::steady_clock::now(); + } + + // Send CONNECT packet if the handshake needs it (initial SYN or retransmit) + if (announce_result.response_flags != 0) { + auto pkt = Discovery::build_connect_packet(announce_result.response_flags); + send_buf(data_fd, source, pkt.data(), pkt.size()); + + // Mark SYN_SENT if we're sending a SYN (only advances from IDLE) + if ((announce_result.response_flags & SYN) != 0) { + discovery->mark_syn_sent(source); + } + } + } break; + + case LEAVE: { + discovery->process_leave(source); + } break; + + case CONNECT: { + if (length < sizeof(ConnectPacket)) { + return; + } + const auto* pkt = reinterpret_cast(data); + auto result = discovery->process_connect(source, pkt->flags); + + // Send response if the state machine requires one + if (result.response_flags != 0) { + auto response = Discovery::build_connect_packet(result.response_flags); + send_buf(data_fd, source, response.data(), response.size()); + } + } break; + + case DATA: { + // Only accept data from connected peers + if (!discovery->is_connected(source)) { + return; + } + + if (length < sizeof(DataPacket)) { + return; + } + + const auto* pkt = reinterpret_cast(data); + + // 
Drop messages we are not subscribed to (relevant for multicast broadcast data) + if (!routing.is_locally_subscribed(pkt->hash)) { + return; + } + + // Check for duplicates (at the packet group level) + auto& dedup = deduplicators[source]; + + if (dedup.is_duplicate(pkt->packet_id)) { + // Already processed this packet group — send ACK if reliable + if ((pkt->flags & RELIABLE) != 0) { + std::vector all_received(pkt->packet_count, true); + auto ack = Reliability::build_ack_packet(pkt->packet_id, pkt->packet_count, all_received); + send_buf(data_fd, source, ack.data(), ack.size()); + } + return; + } + + // Extract fragment data + const uint8_t* frag_data = data + sizeof(DataPacket) - 1; + std::size_t frag_length = length - (sizeof(DataPacket) - 1); + + // Use a hash of the source address bytes as the source key for fragmentation + // We use the first 8 bytes of the storage as a simple key + uint64_t source_key = 0; + std::memcpy(&source_key, &source.storage, std::min(sizeof(source_key), sizeof(source.storage))); + + // Submit to fragmentation + Fragmentation::AssembledPacket assembled; + bool has_assembled = fragmentation->submit_fragment(source_key, + pkt->packet_id, + pkt->packet_no, + pkt->packet_count, + pkt->hash, + pkt->flags, + frag_data, + frag_length, + assembled); + + // Send ACK for reliable packets + if ((pkt->flags & RELIABLE) != 0) { + // Build a partial ACK (we'd need to track received fragments per assembly) + // For now, send an ACK indicating we have this fragment + std::vector received(pkt->packet_count, false); + received[pkt->packet_no] = true; + auto ack = Reliability::build_ack_packet(pkt->packet_id, pkt->packet_count, received); + send_buf(data_fd, source, ack.data(), ack.size()); + } + + // If we have a complete message, deliver it + if (has_assembled) { + dedup.add_packet(assembled.packet_id); + + if (packet_callback) { + // Look up peer name + std::string peer_name; + const auto* peer = discovery->get_peer(source); + if (peer != nullptr) { + 
peer_name = peer->name; + } + + bool reliable = (assembled.flags & RELIABLE) != 0; + packet_callback(source, peer_name, assembled.hash, reliable, std::move(assembled.payload)); + } + + // Send full ACK for reliable + if ((assembled.flags & RELIABLE) != 0) { + std::vector all_received(pkt->packet_count, true); + auto ack = Reliability::build_ack_packet(pkt->packet_id, pkt->packet_count, all_received); + send_buf(data_fd, source, ack.data(), ack.size()); + } + } + } break; + + case ACK: { + // Only accept ACKs from connected peers + if (!discovery->is_connected(source)) { + return; + } + + if (length < sizeof(ACKPacket)) { + return; + } + const auto* pkt = reinterpret_cast(data); + const uint8_t* bitset = data + sizeof(ACKPacket) - 1; + std::size_t bitset_size = length - (sizeof(ACKPacket) - 1); + reliability->process_ack(source, pkt->packet_id, pkt->packet_count, bitset, bitset_size); + } break; + + default: break; + } + } + + void NUClearNet::send_iov(fd_t fd, const sock_t& target, const struct iovec* iov, int iovcnt) { + struct msghdr msg{}; + msg.msg_name = const_cast(&target.sock); + msg.msg_namelen = target.size(); + msg.msg_iov = const_cast(iov); + msg.msg_iovlen = static_cast(iovcnt); + msg.msg_control = nullptr; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + ::sendmsg(fd, &msg, 0); + } + +} // namespace network +} // namespace NUClear diff --git a/src/nuclearnet/NUClearNet.hpp b/src/nuclearnet/NUClearNet.hpp new file mode 100644 index 00000000..cb6b9e13 --- /dev/null +++ b/src/nuclearnet/NUClearNet.hpp @@ -0,0 +1,227 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef NUCLEAR_NETWORK_NUCLEARNET_HPP +#define NUCLEAR_NETWORK_NUCLEARNET_HPP + +#include +#include +#include +#include +#include +#include +#include + +#include "../util/network/sock_t.hpp" +#include "Discovery.hpp" +#include "FileDescriptor.hpp" +#include "Fragmentation.hpp" +#include "PacketDeduplicator.hpp" +#include "Reliability.hpp" +#include "Routing.hpp" + +namespace NUClear { +namespace network { + + /** + * Configuration for the NUClearNet networking system. 
+ */ + struct NetworkConfig { + /// This node's name on the network + std::string name; + /// The multicast/broadcast/unicast address to announce on + std::string announce_address = "239.226.152.162"; + /// The port to use for announce discovery + in_port_t announce_port = 7447; + /// Address to bind to (empty = all interfaces) + std::string bind_address; + /// Network MTU (used to calculate fragment size) + uint16_t mtu = 1500; + /// Peer timeout duration + std::chrono::steady_clock::duration peer_timeout = std::chrono::seconds(2); + /// Maximum total assembly size for fragmented messages + std::size_t max_assembly_size = 64 * 1024 * 1024; + }; + + /** + * NUClearNet — standalone UDP networking library for peer-to-peer communication. + * + * Provides: + * - Automatic peer discovery via multicast/broadcast announces + * - NAT-friendly port learning (from UDP source address) + * - Fragmentation and reassembly of large messages + * - Optional reliable delivery with ACK-based retransmission + * - Subscription-based message filtering + * - Per-peer RTT estimation with Jacobson/Karels algorithm + * + * This class can be used independently of the NUClear reactor framework. + * It operates via a poll-based model: call process() to handle network events. + */ + class NUClearNet { + public: + using sock_t = util::network::sock_t; + + /// Callback for received complete messages + using PacketCallback = std::function&& payload)>; + + /// Callback for peer join events + using JoinCallback = std::function; + + /// Callback for peer leave events + using LeaveCallback = std::function; + + /// Callback for when the system needs attention at a specific time + using EventCallback = std::function; + + NUClearNet(); + ~NUClearNet(); + NUClearNet(const NUClearNet&) = delete; + NUClearNet(NUClearNet&&) = delete; + NUClearNet& operator=(const NUClearNet&) = delete; + NUClearNet& operator=(NUClearNet&&) = delete; + + /** + * Reset/configure the network with new settings. 
+ * If already running, shuts down first, then reinitializes. + * + * @param config The network configuration + */ + void reset(const NetworkConfig& config); + + /** + * Shut down the network, sending a leave packet and closing sockets. + */ + void shutdown(); + + /** + * Process pending network events (send announces, check timeouts, read packets). + * Call this periodically or when a file descriptor becomes readable. + */ + void process(); + + /** + * Send a message to the network. + * + * @param hash The message type hash + * @param payload Pointer to the serialized message data + * @param length Length of the payload in bytes + * @param target Target peer name (empty = send to all eligible peers) + * @param reliable Whether to use reliable delivery + */ + void send(uint64_t hash, + const uint8_t* payload, + std::size_t length, + const std::string& target, + bool reliable); + + /** + * Set this node's subscriptions (which message types to receive). + * This information is advertised in announce packets. + * + * @param subscriptions Set of message type hashes to subscribe to (empty = receive all) + */ + void set_subscriptions(const std::set& subscriptions); + + /** + * Add a single subscription. + * + * @param hash The message type hash to subscribe to + */ + void add_subscription(uint64_t hash); + + // Callback setters + void set_packet_callback(PacketCallback cb); + void set_join_callback(JoinCallback cb); + void set_leave_callback(LeaveCallback cb); + void set_event_callback(EventCallback cb); + + /** + * Get the file descriptors that should be monitored for read events. + * When any of these become readable, call process(). 
+ * + * @return Vector of file descriptors to monitor + */ + std::vector listen_fds() const; + + private: + /// Send an announce packet to the announce address + void announce(); + + /// Read and process all pending packets from a socket + void read_socket(fd_t fd); + + /// Process a single received packet + void process_packet(const sock_t& source, const uint8_t* data, std::size_t length); + + /// Send raw bytes to a target using scatter IO (multiple buffers without copying) + void send_iov(fd_t fd, const sock_t& target, const struct iovec* iov, int iovcnt); + + /// Send a single contiguous buffer to a target + void send_buf(fd_t fd, const sock_t& target, const uint8_t* data, std::size_t length) { + struct iovec iov{}; + iov.iov_base = const_cast(static_cast(data)); // NOLINT(cppcoreguidelines-pro-type-const-cast) + iov.iov_len = length; + send_iov(fd, target, &iov, 1); + } + + // Configuration + NetworkConfig config; + std::string node_name; + + // Sockets + FileDescriptor data_fd; ///< Data socket (ephemeral port, sends announces + data) + FileDescriptor announce_fd; ///< Announce socket (known port, receives announces) + + // The announce target address + sock_t announce_target{}; + + // Modules + std::unique_ptr discovery; + std::unique_ptr fragmentation; + std::unique_ptr reliability; + Routing routing; + + // Per-peer deduplication + std::map deduplicators; + + // Packet ID source (monotonically increasing) + uint16_t next_packet_id{0}; + + // Timing + std::chrono::steady_clock::time_point last_announce; + static const std::chrono::milliseconds ANNOUNCE_INTERVAL; + + // Callbacks + PacketCallback packet_callback; + JoinCallback join_callback; + LeaveCallback leave_callback; + EventCallback event_callback; + }; + +} // namespace network +} // namespace NUClear + +#endif // NUCLEAR_NETWORK_NUCLEARNET_HPP diff --git a/src/nuclearnet/PacketDeduplicator.cpp b/src/nuclearnet/PacketDeduplicator.cpp new file mode 100644 index 00000000..812b22b0 --- /dev/null +++ 
b/src/nuclearnet/PacketDeduplicator.cpp @@ -0,0 +1,97 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ */ + +#include "PacketDeduplicator.hpp" + +namespace NUClear { +namespace network { + + bool PacketDeduplicator::is_duplicate(uint16_t packet_id) const { + // If we haven't seen any packets yet, nothing is a duplicate + if (!initialized) { + return false; + } + + // Check how far ahead of newest_seen this packet is + uint16_t forward_distance = packet_id - newest_seen; + + // If it's ahead of newest_seen (newer), it can't be a duplicate + if (forward_distance != 0 && forward_distance < 0x8000U) { + return false; + } + + // It's at or behind newest_seen — check the distance backward + uint16_t distance = newest_seen - packet_id; + + // If distance >= WINDOW_SIZE, it's too old — treat as duplicate (already processed or lost) + if (distance >= WINDOW_SIZE) { + return true; + } + + // Check the bit in our window + return window.test(distance); + } + + void PacketDeduplicator::add_packet(uint16_t packet_id) { + if (!initialized) { + // First packet ever — initialize the window + initialized = true; + newest_seen = packet_id; + window.reset(); + window.set(0); // Mark current packet as seen + return; + } + + // Calculate how far ahead of newest_seen this packet is (signed interpretation of unsigned diff) + uint16_t forward_distance = packet_id - newest_seen; + + // If the high bit is set, it's actually behind us (wrapped subtraction gave a large positive number) + // We use half the uint16_t range as the threshold for "ahead" vs "behind" + if (forward_distance == 0) { + // Same as newest — just make sure the bit is set + window.set(0); + } + else if (forward_distance < 0x8000U) { + // This packet is NEWER than our current newest + // Slide the window forward by forward_distance positions + if (forward_distance >= WINDOW_SIZE) { + // The new packet is so far ahead that the entire window is invalidated + window.reset(); + } + else { + window <<= forward_distance; + } + newest_seen = packet_id; + window.set(0); // Mark the new packet as seen + } + else { + // This packet 
is OLDER than our current newest (behind us) + uint16_t distance = newest_seen - packet_id; + if (distance < WINDOW_SIZE) { + window.set(distance); // Mark it as seen in the appropriate position + } + // If it's too old (distance >= WINDOW_SIZE), we just ignore it + } + } + +} // namespace network +} // namespace NUClear diff --git a/src/nuclearnet/PacketDeduplicator.hpp b/src/nuclearnet/PacketDeduplicator.hpp new file mode 100644 index 00000000..775c371a --- /dev/null +++ b/src/nuclearnet/PacketDeduplicator.hpp @@ -0,0 +1,75 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
#ifndef NUCLEAR_NETWORK_PACKET_DEDUPLICATOR_HPP
#define NUCLEAR_NETWORK_PACKET_DEDUPLICATOR_HPP

#include <bitset>
#include <cstdint>

namespace NUClear {
namespace network {

    /**
     * Sliding window bitset for packet deduplication.
     *
     * Maintains a WINDOW_SIZE-bit (256) window of recently seen packet IDs.
     * The window slides forward as newer packets are added.
     * Packets that are WINDOW_SIZE or more IDs behind the newest are reported as duplicates,
     * because their state can no longer be represented in the window.
     *
     * Uses uint16_t packet IDs with unsigned modular arithmetic for wrap-around handling:
     * an ID up to 0x7FFF ahead of the newest is "newer", anything else is "at or behind".
     */
    class PacketDeduplicator {
    public:
        /// Window size in bits (number of packet IDs tracked behind the newest one)
        static constexpr uint16_t WINDOW_SIZE = 256;

        /**
         * Check if a packet ID has been seen recently.
         *
         * @param packet_id The packet ID to check
         *
         * @return true if the packet has been seen recently or is too old to track (is a duplicate),
         *         false otherwise (including before any packet has been added)
         */
        bool is_duplicate(uint16_t packet_id) const;

        /**
         * Add a packet ID to the window, marking it as seen.
         * If the packet_id is newer than the current window, the window slides forward.
         * IDs that are WINDOW_SIZE or more behind the newest are ignored.
         *
         * @param packet_id The packet ID to add
         */
        void add_packet(uint16_t packet_id);

    private:
        /// Whether we've seen any packets yet
        bool initialized{false};
        /// The newest packet ID we've seen
        uint16_t newest_seen{0};
        /// The window of seen packets (bit 0 = newest_seen, higher indices = older)
        std::bitset<WINDOW_SIZE> window;
    };

}  // namespace network
}  // namespace NUClear

#endif  // NUCLEAR_NETWORK_PACKET_DEDUPLICATOR_HPP
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and + * to permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of + * the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO + * THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF + * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ */ + +#include "RTTEstimator.hpp" + +#include +#include + + namespace NUClear { + namespace network { + + namespace { + float clampf(const float value, const float min_value, const float max_value) { + return std::max(min_value, std::min(value, max_value)); + } + } // namespace + + RTTEstimator::RTTEstimator(float alpha, + float beta, + float initial_rtt, + float initial_rtt_var, + float min_rto, + float max_rto) + : alpha(alpha) + , beta(beta) + , min_rto(min_rto) + , max_rto(max_rto) + , smoothed_rtt(initial_rtt) + , rtt_var(initial_rtt_var) + , rto(clampf(initial_rtt + 4.0f * initial_rtt_var, min_rto, max_rto)) { + + if (alpha < 0.0f || alpha > 1.0f) { + throw std::invalid_argument("RTTEstimator: alpha must be in [0, 1]"); + } + if (beta < 0.0f || beta > 1.0f) { + throw std::invalid_argument("RTTEstimator: beta must be in [0, 1]"); + } + if (min_rto >= max_rto) { + throw std::invalid_argument("RTTEstimator: min_rto must be less than max_rto"); + } + } + + void RTTEstimator::measure(std::chrono::steady_clock::duration time) { + // Convert measurement to seconds as float + const float sample = std::chrono::duration(time).count(); + + if (!has_measurement) { + // RFC 6298: First measurement initialization + // SRTT = R, RTTVAR = R/2, RTO = SRTT + 4*RTTVAR = 3*R + smoothed_rtt = sample; + rtt_var = sample * 0.5f; + has_measurement = true; + } + else { + // Jacobson/Karels algorithm (RFC 6298) + rtt_var = (1.0f - beta) * rtt_var + beta * std::abs(smoothed_rtt - sample); + smoothed_rtt = (1.0f - alpha) * smoothed_rtt + alpha * sample; + } + + rto = clampf(smoothed_rtt + 4.0f * rtt_var, min_rto, max_rto); + } + + std::chrono::steady_clock::duration RTTEstimator::timeout() const { + return std::chrono::duration_cast(std::chrono::duration(rto)); + } + + } // namespace network +} // namespace NUClear diff --git a/src/nuclearnet/RTTEstimator.hpp b/src/nuclearnet/RTTEstimator.hpp new file mode 100644 index 00000000..8d4945a8 --- /dev/null +++ 
#ifndef NUCLEAR_NETWORK_RTT_ESTIMATOR_HPP
#define NUCLEAR_NETWORK_RTT_ESTIMATOR_HPP

#include <chrono>
#include <cstdint>

namespace NUClear {
namespace network {

    /**
     * TCP-style Round Trip Time (RTT) estimation using Jacobson/Karels algorithm (RFC 6298).
     *
     * Uses Exponentially Weighted Moving Average (EWMA) to smooth RTT measurements and
     * calculate a retransmission timeout (RTO) value. The RTO is always clamped to
     * [min_rto, max_rto].
     */
    class RTTEstimator {
    public:
        /**
         * Construct a new RTT Estimator.
         *
         * The initial RTO is initial_rtt + 4 * initial_rtt_var, clamped to [min_rto, max_rto].
         *
         * @param alpha           Weight for RTT smoothing (default: 0.125, TCP standard)
         * @param beta            Weight for RTT variation (default: 0.25, TCP standard)
         * @param initial_rtt     Initial RTT estimate in seconds (default: 1.0)
         * @param initial_rtt_var Initial RTT variation in seconds (default: 0.0)
         * @param min_rto         Minimum RTO value in seconds (default: 0.1)
         * @param max_rto         Maximum RTO value in seconds (default: 60.0)
         *
         * @throws std::invalid_argument if alpha or beta is outside [0, 1], or min_rto >= max_rto
         */
        RTTEstimator(float alpha           = 0.125f,
                     float beta            = 0.25f,
                     float initial_rtt     = 1.0f,
                     float initial_rtt_var = 0.0f,
                     float min_rto         = 0.1f,
                     float max_rto         = 60.0f);

        /**
         * Update the RTT estimate with a new measurement.
         *
         * The first measurement R replaces the initial values (SRTT = R, RTTVAR = R / 2).
         * Every later measurement applies the Jacobson/Karels algorithm:
         *   RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - sample|
         *   SRTT = (1 - alpha) * SRTT + alpha * sample
         * In both cases the timeout is then recomputed as
         *   RTO = clamp(SRTT + 4 * RTTVAR, min_rto, max_rto)
         *
         * @param time The measured round trip time
         */
        void measure(std::chrono::steady_clock::duration time);

        /**
         * Get the current retransmission timeout.
         *
         * @return The RTO as a duration
         */
        std::chrono::steady_clock::duration timeout() const;

    private:
        float alpha;                  ///< Weight for RTT smoothing
        float beta;                   ///< Weight for RTT variation
        float min_rto;                ///< Minimum RTO value in seconds
        float max_rto;                ///< Maximum RTO value in seconds
        float smoothed_rtt;           ///< Smoothed RTT estimate in seconds
        float rtt_var;                ///< RTT variation in seconds
        float rto;                    ///< Current retransmission timeout in seconds
        bool has_measurement{false};  ///< Whether we've received at least one measurement
    };

}  // namespace network
}  // namespace NUClear

#endif  // NUCLEAR_NETWORK_RTT_ESTIMATOR_HPP
+ * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without + * limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the + * Software, and to permit persons to whom the Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions + * of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED + * TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF + * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ */ + +#include "Reliability.hpp" + +#include +#include + +#include "wire_protocol.hpp" + + namespace NUClear { + namespace network { + + + + void Reliability::track_packet(const sock_t& target, + uint16_t packet_id, + uint16_t packet_count, + uint64_t hash, + uint8_t flags, + const uint8_t* payload, + std::size_t payload_len, + std::chrono::steady_clock::time_point now) { + TrackedPacket tp; + tp.target = target; + tp.packet_id = packet_id; + tp.packet_count = packet_count; + tp.hash = hash; + tp.flags = flags; + tp.payload.assign(payload, payload + payload_len); + tp.acked.resize(packet_count, false); + tp.last_send = now; + + TrackingKey key{target, packet_id}; + + const std::lock_guard lock(tracking_mutex); + tracked_packets[key] = std::move(tp); + } + + void Reliability::process_ack(const sock_t& source, + uint16_t packet_id, + uint16_t packet_count, + const uint8_t* ack_bitset, + std::size_t bitset_size, + std::chrono::steady_clock::time_point now) { + TrackingKey key{source, packet_id}; + + const std::lock_guard lock(tracking_mutex); + auto it = tracked_packets.find(key); + if (it == tracked_packets.end()) { + return; + } + + auto& tp = it->second; + + // Update RTT estimate based on time since last send + auto rtt = now - tp.last_send; + { + const std::lock_guard rtt_lock(rtt_mutex); + rtt_estimators[source].measure(rtt); + } + + // Validate that the ACK's packet_count matches our tracked packet + if (packet_count != tp.packet_count) { + return; + } + + // Update acked bitset + bool all_acked = true; + for (uint16_t i = 0; i < packet_count && i < tp.acked.size(); ++i) { + std::size_t byte_idx = i / 8; + uint8_t bit_idx = i % 8; + if (byte_idx < bitset_size && (ack_bitset[byte_idx] & (1u << bit_idx)) != 0) { + tp.acked[i] = true; + } + if (!tp.acked[i]) { + all_acked = false; + } + } + + // If all fragments are ACKed, remove from tracking + if (all_acked) { + tracked_packets.erase(it); + } + } + + std::vector Reliability::build_ack_packet(uint16_t 
packet_id, + uint16_t packet_count, + const std::vector& received) { + // Calculate bitset size: ceil(packet_count / 8) + std::size_t bitset_bytes = (packet_count + 7) / 8; + std::size_t total_size = sizeof(ACKPacket) - 1 + bitset_bytes; // -1 for the placeholder uint8_t + + std::vector packet(total_size, 0); + + // Write header + ACKPacket ack_header; + ack_header.packet_id = packet_id; + ack_header.packet_count = packet_count; + std::memcpy(packet.data(), &ack_header, sizeof(ACKPacket) - 1); + + // Write bitset + uint8_t* bitset = packet.data() + sizeof(ACKPacket) - 1; + for (std::size_t i = 0; i < received.size() && i < packet_count; ++i) { + if (received[i]) { + bitset[i / 8] |= static_cast(1u << (i % 8)); + } + } + + return packet; + } + + std::vector Reliability::check_retransmissions( + uint16_t packet_mtu, + std::chrono::steady_clock::time_point now) { + std::vector retransmissions; + + const std::lock_guard lock(tracking_mutex); + + for (auto& entry : tracked_packets) { + auto& tp = entry.second; + + // Get the timeout for this peer + std::chrono::steady_clock::duration rto; + { + const std::lock_guard rtt_lock(rtt_mutex); + rto = rtt_estimators[tp.target].timeout(); + } + + // Check if it's time to retransmit + if (now - tp.last_send < rto) { + continue; + } + + // Retransmit unacked fragments (continues until peer is removed) + for (uint16_t i = 0; i < tp.packet_count; ++i) { + if (!tp.acked[i]) { + RetransmitRequest req; + req.target = tp.target; + req.packet_id = tp.packet_id; + req.packet_no = i; + req.packet_count = tp.packet_count; + req.flags = tp.flags; + req.hash = tp.hash; + + // Extract the fragment data + std::size_t offset = static_cast(i) * packet_mtu; + std::size_t length = std::min(static_cast(packet_mtu), tp.payload.size() - offset); + req.data.assign(tp.payload.begin() + offset, tp.payload.begin() + offset + length); + + retransmissions.push_back(std::move(req)); + } + } + + tp.last_send = now; + tp.retransmit_count++; + } + + return 
retransmissions; + } + + void Reliability::remove_peer(const sock_t& target) { + { + const std::lock_guard lock(tracking_mutex); + for (auto it = tracked_packets.begin(); it != tracked_packets.end();) { + if (it->first.first == target) { + it = tracked_packets.erase(it); + } + else { + ++it; + } + } + } + { + const std::lock_guard lock(rtt_mutex); + rtt_estimators.erase(target); + } + } + + RTTEstimator& Reliability::get_rtt(const sock_t& target) { + const std::lock_guard lock(rtt_mutex); + return rtt_estimators[target]; + } + + } // namespace network +} // namespace NUClear diff --git a/src/nuclearnet/Reliability.hpp b/src/nuclearnet/Reliability.hpp new file mode 100644 index 00000000..6e558730 --- /dev/null +++ b/src/nuclearnet/Reliability.hpp @@ -0,0 +1,171 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef NUCLEAR_NETWORK_RELIABILITY_HPP +#define NUCLEAR_NETWORK_RELIABILITY_HPP + +#include +#include +#include +#include +#include +#include + +#include "../util/network/sock_t.hpp" +#include "RTTEstimator.hpp" + +namespace NUClear { +namespace network { + + /** + * Handles reliable delivery via ACK tracking and retransmission. + * + * Responsibilities: + * - Tracking which fragments have been ACKed for each reliable packet group + * - Scheduling retransmissions based on RTT estimates + * - Processing incoming ACK packets + * - Providing per-peer RTT estimation + * - Exponential backoff on repeated failures + */ + class Reliability { + public: + using sock_t = util::network::sock_t; + + /// Information about a fragment that needs retransmitting + struct RetransmitRequest { + sock_t target; + uint16_t packet_id; + uint16_t packet_no; + uint16_t packet_count; + uint8_t flags; + uint64_t hash; + std::vector data; + }; + + Reliability() = default; + + /** + * Register a sent reliable packet group for tracking. 
+ * + * @param target The peer we sent to + * @param packet_id The packet group ID + * @param packet_count Total fragments in the group + * @param hash Message type hash + * @param flags Packet flags + * @param payload Pointer to the full original payload (copied internally for retransmission) + * @param payload_len Length of the payload in bytes + * @param now The current time (defaults to steady_clock::now()) + */ + void track_packet(const sock_t& target, + uint16_t packet_id, + uint16_t packet_count, + uint64_t hash, + uint8_t flags, + const uint8_t* payload, + std::size_t payload_len, + std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()); + + /** + * Process a received ACK packet. + * + * @param source Who sent the ACK + * @param packet_id The packet group being acknowledged + * @param packet_count Total fragments in the group + * @param ack_bitset Bitset of received fragments (1 bit per fragment, LSB first) + * @param bitset_size Size of the ack_bitset in bytes + * @param now The current time (defaults to steady_clock::now()) + */ + void process_ack(const sock_t& source, + uint16_t packet_id, + uint16_t packet_count, + const uint8_t* ack_bitset, + std::size_t bitset_size, + std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()); + + /** + * Build an ACK packet payload (excluding header) for a packet group. + * + * @param packet_id The packet group to acknowledge + * @param packet_count Total fragments + * @param received Bitset of which fragments have been received + * + * @return Serialized ACK packet bytes (complete packet including header) + */ + static std::vector build_ack_packet(uint16_t packet_id, + uint16_t packet_count, + const std::vector& received); + + /** + * Check for packets that need retransmission and return them. 
+ * + * @param packet_mtu The MTU to use for fragmenting retransmissions + * @param now The current time (defaults to steady_clock::now()) + * + * @return List of fragments that need to be retransmitted + */ + std::vector check_retransmissions( + uint16_t packet_mtu, + std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now()); + + /** + * Remove all tracking for a given peer (e.g., on disconnect). + * + * @param target The peer to remove + */ + void remove_peer(const sock_t& target); + + /** + * Get the RTT estimator for a specific peer. + * + * @param target The peer address + * @return Reference to the RTT estimator (creates one if it doesn't exist) + */ + RTTEstimator& get_rtt(const sock_t& target); + + private: + /// State for a tracked reliable packet group + struct TrackedPacket { + sock_t target; + uint16_t packet_id; + uint16_t packet_count; + uint64_t hash; + uint8_t flags; + std::vector payload; + std::vector acked; ///< Which fragments have been ACKed + std::chrono::steady_clock::time_point last_send; + uint16_t retransmit_count{0}; + }; + + /// Key for tracked packets: (target, packet_id) + using TrackingKey = std::pair; + + mutable std::mutex tracking_mutex; + std::map tracked_packets; + + mutable std::mutex rtt_mutex; + std::map rtt_estimators; + }; + +} // namespace network +} // namespace NUClear + +#endif // NUCLEAR_NETWORK_RELIABILITY_HPP diff --git a/src/nuclearnet/Routing.cpp b/src/nuclearnet/Routing.cpp new file mode 100644 index 00000000..0fa26980 --- /dev/null +++ b/src/nuclearnet/Routing.cpp @@ -0,0 +1,90 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ */ + +#include "Routing.hpp" + +namespace NUClear { +namespace network { + + void Routing::update_peer_subscriptions(const sock_t& peer, std::set subscriptions) { + const std::lock_guard lock(peer_mutex); + peer_subscriptions[peer] = std::move(subscriptions); + } + + void Routing::remove_peer(const sock_t& peer) { + const std::lock_guard lock(peer_mutex); + peer_subscriptions.erase(peer); + } + + bool Routing::should_send(const sock_t& peer, uint64_t hash) const { + const std::lock_guard lock(peer_mutex); + auto it = peer_subscriptions.find(peer); + if (it == peer_subscriptions.end()) { + // Unknown peer — default to sending + return true; + } + // Empty subscription set means "send everything" + if (it->second.empty()) { + return true; + } + // Check if the hash is in the peer's subscription set + return it->second.count(hash) > 0; + } + + std::vector Routing::get_targets(const std::vector& all_peers, uint64_t hash) const { + std::vector targets; + targets.reserve(all_peers.size()); + for (const auto& peer : all_peers) { + if (should_send(peer, hash)) { + targets.push_back(peer); + } + } + return targets; + } + + void Routing::set_local_subscriptions(std::set subscriptions) { + const std::lock_guard lock(local_mutex); + local_subscriptions = std::move(subscriptions); + } + + std::vector Routing::get_local_subscriptions() const { + const std::lock_guard lock(local_mutex); + return {local_subscriptions.begin(), local_subscriptions.end()}; + } + + void Routing::add_local_subscription(uint64_t hash) { + const std::lock_guard lock(local_mutex); + local_subscriptions.insert(hash); + } + + void Routing::remove_local_subscription(uint64_t hash) { + const std::lock_guard lock(local_mutex); + local_subscriptions.erase(hash); + } + + bool Routing::is_locally_subscribed(uint64_t hash) const { + const std::lock_guard lock(local_mutex); + return local_subscriptions.empty() || local_subscriptions.count(hash) != 0; + } + +} // namespace network +} // namespace NUClear diff 
--git a/src/nuclearnet/Routing.hpp b/src/nuclearnet/Routing.hpp new file mode 100644 index 00000000..bb26e89b --- /dev/null +++ b/src/nuclearnet/Routing.hpp @@ -0,0 +1,139 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef NUCLEAR_NETWORK_ROUTING_HPP +#define NUCLEAR_NETWORK_ROUTING_HPP + +#include +#include +#include +#include +#include + +#include "../util/network/sock_t.hpp" + +namespace NUClear { +namespace network { + + /** + * Manages subscription-based message routing. 
+ * + * Responsibilities: + * - Tracking per-peer subscriptions (which message types each peer wants) + * - Filtering outgoing messages based on subscriptions + * - Managing this node's own subscription list for announce packets + * + * Default behavior: if a peer has no subscriptions (empty set), send all messages to it. + * This ensures backwards-compatible behavior with peers that don't support filtering. + */ + class Routing { + public: + using sock_t = util::network::sock_t; + + /** + * Update the subscription set for a remote peer. + * + * @param peer The peer's address + * @param subscriptions The set of hashes the peer wants (empty = all) + */ + void update_peer_subscriptions(const sock_t& peer, std::set subscriptions); + + /** + * Remove a peer's subscription information (e.g., on disconnect). + * + * @param peer The peer to remove + */ + void remove_peer(const sock_t& peer); + + /** + * Check whether a message with the given hash should be sent to a specific peer. + * + * @param peer The target peer + * @param hash The message type hash + * + * @return true if the message should be sent (peer subscribes to it, or has no filter) + */ + bool should_send(const sock_t& peer, uint64_t hash) const; + + /** + * Get the list of peers that should receive a message with the given hash. + * + * @param all_peers All known peer addresses + * @param hash The message type hash + * + * @return Subset of peers that should receive this message + */ + std::vector get_targets(const std::vector& all_peers, uint64_t hash) const; + + /** + * Set this node's local subscriptions (what we want to receive). + * This is used when building announce packets. + * + * @param subscriptions The set of hashes this node wants to receive (empty = all) + */ + void set_local_subscriptions(std::set subscriptions); + + /** + * Get this node's local subscriptions as a vector (for building announce packets). 
+ * + * @return The subscription hashes + */ + std::vector get_local_subscriptions() const; + + /** + * Add a single hash to the local subscriptions. + * + * @param hash The message type hash to subscribe to + */ + void add_local_subscription(uint64_t hash); + + /** + * Remove a single hash from the local subscriptions. + * + * @param hash The message type hash to unsubscribe from + */ + void remove_local_subscription(uint64_t hash); + + /** + * Check if we are locally subscribed to a message type hash. + * Returns true if we have no subscriptions (empty = receive all) or if the hash is in our set. + * + * @param hash The message type hash to check + * + * @return true if we should accept this message type + */ + bool is_locally_subscribed(uint64_t hash) const; + + private: + /// Per-peer subscription sets (empty set = send all) + mutable std::mutex peer_mutex; + std::map> peer_subscriptions; + + /// This node's own subscriptions + mutable std::mutex local_mutex; + std::set local_subscriptions; + }; + +} // namespace network +} // namespace NUClear + +#endif // NUCLEAR_NETWORK_ROUTING_HPP diff --git a/src/nuclearnet/wire_protocol.hpp b/src/nuclearnet/wire_protocol.hpp new file mode 100644 index 00000000..17a4d1f9 --- /dev/null +++ b/src/nuclearnet/wire_protocol.hpp @@ -0,0 +1,201 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef NUCLEAR_NETWORK_WIRE_PROTOCOL_HPP +#define NUCLEAR_NETWORK_WIRE_PROTOCOL_HPP + +#include +#include + +// Packing macros for wire format structs +#ifdef _MSC_VER + #define NUCLEAR_NET_PACK(...) __pragma(pack(push, 1)) __VA_ARGS__ __pragma(pack(pop)) +#else + #define NUCLEAR_NET_PACK(...) __VA_ARGS__ __attribute__((__packed__)) +#endif + +namespace NUClear { +namespace network { + + /// Protocol version for NUClearNet v2 + constexpr uint8_t PROTOCOL_VERSION = 0x03; + + /// Packet type identifiers + enum PacketType : uint8_t { + ANNOUNCE = 1, + LEAVE = 2, + DATA = 3, + ACK = 4, + CONNECT = 5, + }; + + /// Data packet flags (bit field) + enum DataFlags : uint8_t { + RELIABLE = 0x01, + }; + + /// Connect packet flags (bit field) + enum ConnectFlags : uint8_t { + SYN = 0x01, + CON_ACK = 0x02, + }; + + /** + * Header present on every NUClearNet packet. 
+ * + * Wire layout (5 bytes): + * [0-2] 0xE2 0x98 0xA2 (UTF-8 radioactive symbol ☢) + * [3] Protocol version (0x03) + * [4] Packet type + */ + NUCLEAR_NET_PACK(struct PacketHeader { + explicit PacketHeader(PacketType t) : type(t) {} + + /// Magic bytes: radioactive symbol in UTF-8 + uint8_t header[3] = {0xE2, 0x98, 0xA2}; + /// Protocol version + uint8_t version = PROTOCOL_VERSION; + /// Packet type + PacketType type; + }); + + /** + * Announce packet — sent periodically for peer discovery. + * + * Wire layout (variable): + * [0-4] PacketHeader (type = ANNOUNCE) + * [5-6] name_length (uint16_t) + * [7..7+name_length-1] name (UTF-8, NOT null-terminated) + * [next 2 bytes] num_subscriptions (uint16_t) + * [next num_subscriptions*8 bytes] subscription hashes (uint64_t each) + * + * No port field — receiver learns the data port from the UDP source address. + * If num_subscriptions == 0, the sender wants ALL data (no filtering). + */ + NUCLEAR_NET_PACK(struct AnnouncePacket : PacketHeader { + AnnouncePacket() : PacketHeader(ANNOUNCE) {} + + /// Length of the name field that follows + uint16_t name_length{0}; + /// Variable-length data follows: name bytes, then subscription count, then subscription hashes + /// It starts at byte offset 7 of the packet (right after name_length); step through it by byte offset, not by uint16_t pointer arithmetic + }); + + /** + * Leave packet — sent on graceful shutdown. + * + * Wire layout (5 bytes): + * [0-4] PacketHeader (type = LEAVE) + */ + NUCLEAR_NET_PACK(struct LeavePacket : PacketHeader { + LeavePacket() : PacketHeader(LEAVE) {} + }); + + /** + * Data packet — carries message payload (possibly one fragment of a larger message). 
+ * + * Wire layout (20+ bytes): + * [0-4] PacketHeader (type = DATA) + * [5-6] packet_id (uint16_t) — unique identifier for this packet group + * [7-8] packet_no (uint16_t) — fragment index within the group (0-based) + * [9-10] packet_count (uint16_t) — total fragments in the group + * [11] flags (uint8_t) — bit 0: reliable + * [12-19] hash (uint64_t) — message type identifier + * [20+] payload data + */ + NUCLEAR_NET_PACK(struct DataPacket : PacketHeader { + DataPacket() : PacketHeader(DATA) {} + + /// Unique identifier for this packet group + uint16_t packet_id{0}; + /// Fragment index within the group (0-based) + uint16_t packet_no{0}; + /// Total number of fragments in the group + uint16_t packet_count{1}; + /// Flags (bit 0 = reliable) + uint8_t flags{0}; + /// 64-bit hash identifying the message type + uint64_t hash{0}; + /// Payload data starts here (access via &data) + char data{0}; + }); + + /** + * ACK packet — acknowledges receipt of data fragments. + * + * Wire layout (10+ bytes): + * [0-4] PacketHeader (type = ACK) + * [5-6] packet_id (uint16_t) + * [7-8] packet_count (uint16_t) + * [9+] bitset of received packets (1 bit per fragment, LSB first) + */ + NUCLEAR_NET_PACK(struct ACKPacket : PacketHeader { + ACKPacket() : PacketHeader(ACK) {} + + /// The packet group we are acknowledging + uint16_t packet_id{0}; + /// Total number of fragments in the group + uint16_t packet_count{0}; + /// Bitset of received fragments (access via &packets) + uint8_t packets{0}; + }); + + /** + * Connect packet — used for the 3-way handshake to confirm bidirectional data port connectivity. + * + * Wire layout (6 bytes): + * [0-4] PacketHeader (type = CONNECT) + * [5] flags (uint8_t) — bit 0: SYN (initiating), bit 1: CON_ACK (acknowledging) + * + * Handshake sequence: + * 1. Initiator sends CONNECT (SYN) to peer's data port + * 2. Peer responds with CONNECT (SYN|ACK) + * 3. 
Initiator sends CONNECT (ACK) — connection confirmed on both sides + */ + NUCLEAR_NET_PACK(struct ConnectPacket : PacketHeader { + ConnectPacket() : PacketHeader(CONNECT) {} + + /// Connection flags bit field, see ConnectFlags (SYN = initiating, CON_ACK = acknowledging) + uint8_t flags{0}; + }); + + /** + * Validate that a buffer contains a valid NUClearNet packet header. + * + * @param data Pointer to the received data + * @param length Length of the received data in bytes + * + * @return true if length covers a PacketHeader and the header is valid (correct magic, correct version, type between ANNOUNCE and CONNECT) + */ + inline bool validate_header(const void* data, std::size_t length) { + if (length < sizeof(PacketHeader)) { + return false; + } + const auto* header = static_cast(data); + return header->header[0] == 0xE2 && header->header[1] == 0x98 && header->header[2] == 0xA2 + && header->version == PROTOCOL_VERSION && header->type >= ANNOUNCE && header->type <= CONNECT; + } + +} // namespace network +} // namespace NUClear + +#endif // NUCLEAR_NETWORK_WIRE_PROTOCOL_HPP diff --git a/src/util/network/sock_t.hpp b/src/util/network/sock_t.hpp index 4c52b2c0..fa51f3e1 100644 --- a/src/util/network/sock_t.hpp +++ b/src/util/network/sock_t.hpp @@ -25,8 +25,11 @@ #include #include +#include #include +#include #include +#include #include "../platform.hpp" @@ -42,6 +45,52 @@ namespace util { sockaddr_in6 ipv6; }; + /// Equality: same family, port and address for AF_INET/AF_INET6 (IPv6 compares sin6_addr and sin6_port only); any other family is never equal, even to itself + friend bool operator==(const sock_t& a, const sock_t& b) { + if (a.sock.sa_family != b.sock.sa_family) { + return false; + } + if (a.sock.sa_family == AF_INET) { + return a.ipv4.sin_port == b.ipv4.sin_port + && a.ipv4.sin_addr.s_addr == b.ipv4.sin_addr.s_addr; + } + if (a.sock.sa_family == AF_INET6) { + return a.ipv6.sin6_port == b.ipv6.sin6_port + && std::memcmp(&a.ipv6.sin6_addr, &b.ipv6.sin6_addr, sizeof(in6_addr)) == 0; + } + return false; + } + + /// Inequality comparison operator + friend bool operator!=(const sock_t& a, const sock_t& b) { + return !(a == b); + } + + /// Less-than comparison for use as 
map key + friend bool operator<(const sock_t& a, const sock_t& b) { + if (a.sock.sa_family != b.sock.sa_family) { + return a.sock.sa_family < b.sock.sa_family; + } + if (a.sock.sa_family == AF_INET) { + return std::forward_as_tuple(ntohl(a.ipv4.sin_addr.s_addr), ntohs(a.ipv4.sin_port)) + < std::forward_as_tuple(ntohl(b.ipv4.sin_addr.s_addr), ntohs(b.ipv4.sin_port)); + } + if (a.sock.sa_family == AF_INET6) { + const int cmp = std::memcmp(&a.ipv6.sin6_addr, &b.ipv6.sin6_addr, sizeof(in6_addr)); + if (cmp != 0) { + return cmp < 0; + } + return ntohs(a.ipv6.sin6_port) < ntohs(b.ipv6.sin6_port); + } + return false; + } + + /// Stream output operator: prints the pair returned by address(true) as first:second + friend std::ostream& operator<<(std::ostream& os, const sock_t& addr) { + auto addr_pair = addr.address(true); + return os << addr_pair.first << ":" << addr_pair.second; + } + socklen_t size() const { switch (sock.sa_family) { case AF_INET: return sizeof(sockaddr_in); diff --git a/tests/test_util/has_multicast.cpp b/tests/test_util/has_multicast.cpp index 4b5b0fbc..d7a5e3af 100644 --- a/tests/test_util/has_multicast.cpp +++ b/tests/test_util/has_multicast.cpp @@ -23,26 +23,189 @@ #include "has_multicast.hpp" #include +#include #include "util/network/get_interfaces.hpp" #include "util/platform.hpp" namespace test_util { +namespace { + +/** + * Attempt an actual multicast send/receive round-trip to group_addr (af == AF_INET selects IPv4, any other value is handled as IPv6). + * Returns true only if the test packet is read back within the 200ms wait and its bytes match what was sent. + * This detects environments (e.g., macOS CI VMs) where interfaces report IFF_MULTICAST + * but the hypervisor doesn't actually deliver multicast packets. 
+ */ +bool test_multicast_roundtrip(int af, const char* group_addr) { + // Create a UDP socket for receiving + NUClear::fd_t recv_fd = ::socket(af, SOCK_DGRAM, 0); + if (recv_fd < 0) { + return false; + } + + // Allow address reuse + int one = 1; + ::setsockopt(recv_fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&one), sizeof(one)); +#ifdef SO_REUSEPORT + ::setsockopt(recv_fd, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&one), sizeof(one)); +#endif + + // Bind to any address on an ephemeral port + uint16_t port = 0; + if (af == AF_INET) { + sockaddr_in bind_addr{}; + bind_addr.sin_family = AF_INET; + bind_addr.sin_addr.s_addr = htonl(INADDR_ANY); + bind_addr.sin_port = 0; + + if (::bind(recv_fd, reinterpret_cast(&bind_addr), sizeof(bind_addr)) < 0) { + ::close(recv_fd); + return false; + } + + // Get the assigned port + socklen_t len = sizeof(bind_addr); + ::getsockname(recv_fd, reinterpret_cast(&bind_addr), &len); + port = ntohs(bind_addr.sin_port); + + // Join the multicast group + struct ip_mreq mreq {}; + ::inet_pton(AF_INET, group_addr, &mreq.imr_multiaddr); + mreq.imr_interface.s_addr = htonl(INADDR_ANY); + if (::setsockopt(recv_fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, reinterpret_cast(&mreq), sizeof(mreq)) + < 0) { + ::close(recv_fd); + return false; + } + } + else { + sockaddr_in6 bind_addr{}; + bind_addr.sin6_family = AF_INET6; + bind_addr.sin6_addr = in6addr_any; + bind_addr.sin6_port = 0; + + if (::bind(recv_fd, reinterpret_cast(&bind_addr), sizeof(bind_addr)) < 0) { + ::close(recv_fd); + return false; + } + + socklen_t len = sizeof(bind_addr); + ::getsockname(recv_fd, reinterpret_cast(&bind_addr), &len); + port = ntohs(bind_addr.sin6_port); + + // Join the multicast group + struct ipv6_mreq mreq {}; + ::inet_pton(AF_INET6, group_addr, &mreq.ipv6mr_multiaddr); + mreq.ipv6mr_interface = 0; + if (::setsockopt(recv_fd, + IPPROTO_IPV6, + IPV6_JOIN_GROUP, + reinterpret_cast(&mreq), + sizeof(mreq)) + < 0) { + ::close(recv_fd); + return false; + } + } + + // Create a 
send socket + NUClear::fd_t send_fd = ::socket(af, SOCK_DGRAM, 0); + if (send_fd < 0) { + ::close(recv_fd); + return false; + } + + // Set multicast loopback so we receive our own packet + if (af == AF_INET) { + uint8_t loop = 1; + ::setsockopt(send_fd, IPPROTO_IP, IP_MULTICAST_LOOP, reinterpret_cast(&loop), sizeof(loop)); + } + else { + int loop = 1; + ::setsockopt(send_fd, IPPROTO_IPV6, IPV6_MULTICAST_LOOP, reinterpret_cast(&loop), sizeof(loop)); + } + + // Send a test packet to the multicast group + const char test_msg[] = "MCAST_TEST"; + if (af == AF_INET) { + sockaddr_in dest{}; + dest.sin_family = AF_INET; + dest.sin_port = htons(port); + ::inet_pton(AF_INET, group_addr, &dest.sin_addr); + ::sendto(send_fd, + test_msg, + sizeof(test_msg), + 0, + reinterpret_cast(&dest), + sizeof(dest)); + } + else { + sockaddr_in6 dest{}; + dest.sin6_family = AF_INET6; + dest.sin6_port = htons(port); + ::inet_pton(AF_INET6, group_addr, &dest.sin6_addr); + ::sendto(send_fd, + test_msg, + sizeof(test_msg), + 0, + reinterpret_cast(&dest), + sizeof(dest)); + } + + // Wait for the packet with a 200ms timeout using select (portable across all platforms) + fd_set read_fds; + FD_ZERO(&read_fds); // NOLINT(readability-isolate-declaration) + FD_SET(recv_fd, &read_fds); // NOLINT(hicpp-signed-bitwise) + struct timeval tv {}; + tv.tv_sec = 0; + tv.tv_usec = 200000; // 200ms + + int ready = ::select(static_cast(recv_fd) + 1, &read_fds, nullptr, nullptr, &tv); + + bool success = false; + if (ready > 0) { + // Verify the received data matches what we sent to avoid false positives + char buf[64] = {0}; + ssize_t n = ::recvfrom(recv_fd, buf, sizeof(buf), 0, nullptr, nullptr); + success = (n == static_cast(sizeof(test_msg)) && std::memcmp(buf, test_msg, sizeof(test_msg)) == 0); + } + + ::close(send_fd); + ::close(recv_fd); + + return success; +} + +} // namespace + bool has_ipv4_multicast() { - // See if any interface has multicast ipv4 + // First check if any interface reports multicast 
support auto ifaces = NUClear::util::network::get_interfaces(); - return std::any_of(ifaces.begin(), ifaces.end(), [](const auto& iface) { + bool has_flag = std::any_of(ifaces.begin(), ifaces.end(), [](const auto& iface) { return iface.ip.sock.sa_family == AF_INET && iface.flags.multicast; }); + if (!has_flag) { + return false; + } + + // Then verify multicast actually works with a real round-trip + return test_multicast_roundtrip(AF_INET, "239.255.255.250"); } bool has_ipv6_multicast() { - // See if any interface has multicast ipv6 + // First check if any interface reports multicast support auto ifaces = NUClear::util::network::get_interfaces(); - return std::any_of(ifaces.begin(), ifaces.end(), [](const auto& iface) { + bool has_flag = std::any_of(ifaces.begin(), ifaces.end(), [](const auto& iface) { return iface.ip.sock.sa_family == AF_INET6 && iface.flags.multicast; }); + if (!has_flag) { + return false; + } + + // Then verify multicast actually works with a real round-trip + return test_multicast_roundtrip(AF_INET6, "ff02::1"); } } // namespace test_util diff --git a/tests/tests/nuclearnet/Discovery.cpp b/tests/tests/nuclearnet/Discovery.cpp new file mode 100644 index 00000000..bf7defff --- /dev/null +++ b/tests/tests/nuclearnet/Discovery.cpp @@ -0,0 +1,442 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ */ + +#include "nuclearnet/Discovery.hpp" + +#include +#include +#include + +#include "nuclearnet/wire_protocol.hpp" +#include "util/platform.hpp" + +using NUClear::network::HandshakeState; +using NUClear::network::Discovery; +using NUClear::network::PeerInfo; +using NUClear::network::SYN; +using NUClear::network::CON_ACK; +using NUClear::util::network::sock_t; + +namespace { +sock_t make_addr(uint32_t ip, uint16_t port) { + sock_t addr{}; + addr.ipv4.sin_family = AF_INET; + addr.ipv4.sin_port = htons(port); + addr.ipv4.sin_addr.s_addr = htonl(ip); + return addr; +} +} // namespace + +SCENARIO("Discovery build_announce_packet produces valid packet", "[nuclearnet][discovery]") { + auto packet = Discovery::build_announce_packet("test_node", {0x1111, 0x2222}); + + // Should start with the magic bytes (☢ = 0xE2, 0x98, 0xA2) + REQUIRE(packet.size() >= 5); + REQUIRE(packet[0] == 0xE2); + REQUIRE(packet[1] == 0x98); + REQUIRE(packet[2] == 0xA2); +} + +SCENARIO("Discovery build_leave_packet produces valid packet", "[nuclearnet][discovery]") { + auto packet = Discovery::build_leave_packet(); + + // Should start with the magic bytes (NOTE(review): the LEAVE type byte, packet[4], is not asserted here) + REQUIRE(packet.size() >= 5); + REQUIRE(packet[0] == 0xE2); + REQUIRE(packet[1] == 0x98); + REQUIRE(packet[2] == 0xA2); +} + +SCENARIO("Discovery process_announce adds a new peer", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + bool join_called = false; + + disc.set_join_callback([&](const PeerInfo& info) { + join_called = true; + }); + + auto announce = Discovery::build_announce_packet("peer_a", {0x1111}); + disc.process_announce(peer_addr, announce.data(), announce.size()); + + // Join is deferred until handshake completes + REQUIRE_FALSE(join_called); + REQUIRE(disc.has_peer(peer_addr)); + + const auto* peer = disc.get_peer(peer_addr); + REQUIRE(peer != nullptr); + REQUIRE(peer->name == "peer_a"); + REQUIRE(peer->subscriptions.count(0x1111) == 1); + 
REQUIRE(peer->announce_heard); + REQUIRE(peer->handshake == HandshakeState::IDLE); +} + +SCENARIO("Discovery process_announce updates existing peer subscriptions", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + bool sub_changed = false; + + disc.set_subscription_change_callback([&](const PeerInfo&) { sub_changed = true; }); + + auto announce1 = Discovery::build_announce_packet("peer_a", {0x1111}); + disc.process_announce(peer_addr, announce1.data(), announce1.size()); + + // Send a new announce with different subscriptions + sub_changed = false; + auto announce2 = Discovery::build_announce_packet("peer_a", {0x2222, 0x3333}); + disc.process_announce(peer_addr, announce2.data(), announce2.size()); + + REQUIRE(sub_changed); + const auto* peer = disc.get_peer(peer_addr); + REQUIRE(peer != nullptr); + REQUIRE(peer->subscriptions.count(0x2222) == 1); + REQUIRE(peer->subscriptions.count(0x3333) == 1); + REQUIRE(peer->subscriptions.count(0x1111) == 0); +} + +SCENARIO("Discovery process_leave removes a peer", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + bool leave_called = false; + + disc.set_leave_callback([&](const PeerInfo& info) { + leave_called = true; + }); + + // Add the peer and complete the handshake + auto announce = Discovery::build_announce_packet("peer_a", {}); + disc.process_announce(peer_addr, announce.data(), announce.size()); + disc.mark_syn_sent(peer_addr); + disc.process_connect(peer_addr, SYN | CON_ACK); // SYN+ACK response + REQUIRE(disc.is_connected(peer_addr)); + + // Now process a leave + disc.process_leave(peer_addr); + + REQUIRE(leave_called); + REQUIRE_FALSE(disc.has_peer(peer_addr)); +} + +SCENARIO("Discovery check_timeouts removes stale peers", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::milliseconds(20)); // 20ms timeout for testing + + sock_t peer_addr = 
make_addr(0x0A000001, 5000); + bool leave_called = false; + + disc.set_leave_callback([&](const PeerInfo&) { leave_called = true; }); + + // Add peer at time T and complete handshake + auto t = std::chrono::steady_clock::now(); + auto announce = Discovery::build_announce_packet("peer_a", {}); + disc.process_announce(peer_addr, announce.data(), announce.size(), t); + disc.mark_syn_sent(peer_addr); + disc.process_connect(peer_addr, SYN | CON_ACK, t); + + // Check at T+10ms (before timeout) — peer should still be there + auto removed = disc.check_timeouts(t + std::chrono::milliseconds(10)); + REQUIRE(removed.empty()); + REQUIRE(disc.has_peer(peer_addr)); + + // Check at T+25ms (after 20ms timeout) — peer should be removed + removed = disc.check_timeouts(t + std::chrono::milliseconds(25)); + REQUIRE(removed.size() == 1); + REQUIRE(removed[0].name == "peer_a"); + REQUIRE(leave_called); + REQUIRE_FALSE(disc.has_peer(peer_addr)); +} + +SCENARIO("Discovery touch_peer resets timeout", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::milliseconds(200)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + + // Add peer at time T and complete handshake + auto t = std::chrono::steady_clock::now(); + auto announce = Discovery::build_announce_packet("peer_a", {}); + disc.process_announce(peer_addr, announce.data(), announce.size(), t); + disc.mark_syn_sent(peer_addr); + disc.process_connect(peer_addr, SYN | CON_ACK, t); + + // Touch at T+120ms (before 200ms timeout expires) + disc.touch_peer(peer_addr, t + std::chrono::milliseconds(120)); + + // Check at T+240ms — 240ms since announce, but only 120ms since touch + // Since timeout is 200ms from last_seen, peer should still be alive + auto removed = disc.check_timeouts(t + std::chrono::milliseconds(240)); + REQUIRE(removed.empty()); + REQUIRE(disc.has_peer(peer_addr)); + + // Check at T+325ms — 205ms since touch, should now be timed out + removed = disc.check_timeouts(t + std::chrono::milliseconds(325)); + 
REQUIRE(removed.size() == 1); + REQUIRE(removed[0].name == "peer_a"); +} + +SCENARIO("Discovery get_peers returns all known peers", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t addr_a = make_addr(0x0A000001, 5000); + sock_t addr_b = make_addr(0x0A000002, 5000); + + auto announce_a = Discovery::build_announce_packet("node_a", {0x1111}); + auto announce_b = Discovery::build_announce_packet("node_b", {0x2222}); + + disc.process_announce(addr_a, announce_a.data(), announce_a.size()); + disc.process_announce(addr_b, announce_b.data(), announce_b.size()); + + auto peers = disc.get_peers(); + REQUIRE(peers.size() == 2); + REQUIRE(peers.count(addr_a) == 1); + REQUIRE(peers.count(addr_b) == 1); + REQUIRE(peers[addr_a].name == "node_a"); + REQUIRE(peers[addr_b].name == "node_b"); +} + +SCENARIO("Discovery 3-way handshake normal flow", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + bool join_called = false; + std::string joined_name; + + disc.set_join_callback([&](const PeerInfo& info) { + join_called = true; + joined_name = info.name; + }); + + // Peer announces (heard on announce channel — sets announce_heard) + auto announce = Discovery::build_announce_packet("peer_a", {0x1111}); + disc.process_announce(peer_addr, announce.data(), announce.size()); + REQUIRE_FALSE(join_called); + REQUIRE(disc.get_peer(peer_addr)->announce_heard); + REQUIRE(disc.get_peer(peer_addr)->handshake == HandshakeState::IDLE); + + // We send SYN + disc.mark_syn_sent(peer_addr); + REQUIRE(disc.get_peer(peer_addr)->handshake == HandshakeState::SYN_SENT); + + // Peer responds with SYN+ACK + auto result = disc.process_connect(peer_addr, SYN | CON_ACK); + REQUIRE(result.just_connected); + REQUIRE(result.response_flags == CON_ACK); // We should send ACK back + REQUIRE(join_called); + REQUIRE(joined_name == "peer_a"); + REQUIRE(disc.is_connected(peer_addr)); +} + +SCENARIO("Discovery 3-way 
handshake receiving SYN first", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + bool join_called = false; + + disc.set_join_callback([&](const PeerInfo&) { join_called = true; }); + + // Peer announces and we add them + auto announce = Discovery::build_announce_packet("peer_a", {}); + disc.process_announce(peer_addr, announce.data(), announce.size()); + + // Peer sends SYN to us (they initiated) + auto result = disc.process_connect(peer_addr, SYN); + REQUIRE_FALSE(result.just_connected); + REQUIRE(result.response_flags == (SYN | CON_ACK)); // We respond with SYN+ACK + REQUIRE_FALSE(join_called); + REQUIRE(disc.get_peer(peer_addr)->handshake == HandshakeState::SYN_RECEIVED); + + // Peer sends ACK to complete the handshake + result = disc.process_connect(peer_addr, CON_ACK); + REQUIRE(result.just_connected); + REQUIRE(result.response_flags == 0); // No further response needed + REQUIRE(join_called); + REQUIRE(disc.is_connected(peer_addr)); +} + +SCENARIO("Discovery 3-way handshake simultaneous open", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + bool join_called = false; + + disc.set_join_callback([&](const PeerInfo&) { join_called = true; }); + + // Peer announces + auto announce = Discovery::build_announce_packet("peer_a", {}); + disc.process_announce(peer_addr, announce.data(), announce.size()); + + // We send SYN + disc.mark_syn_sent(peer_addr); + + // But peer also sent SYN at the same time (simultaneous open) + auto result = disc.process_connect(peer_addr, SYN); + REQUIRE_FALSE(result.just_connected); + REQUIRE(result.response_flags == (SYN | CON_ACK)); // Respond with SYN+ACK + REQUIRE(disc.get_peer(peer_addr)->handshake == HandshakeState::SYN_RECEIVED); + + // Peer also sends SYN+ACK (they got our SYN) + result = disc.process_connect(peer_addr, SYN | CON_ACK); + REQUIRE(result.just_connected); + 
REQUIRE(result.response_flags == CON_ACK); + REQUIRE(join_called); + REQUIRE(disc.is_connected(peer_addr)); +} + +SCENARIO("Discovery build_connect_packet produces valid packet", "[nuclearnet][discovery]") { + auto syn_packet = Discovery::build_connect_packet(SYN); + REQUIRE(syn_packet.size() == 6); + REQUIRE(syn_packet[0] == 0xE2); + REQUIRE(syn_packet[1] == 0x98); + REQUIRE(syn_packet[2] == 0xA2); + REQUIRE(syn_packet[4] == 5); // CONNECT type + REQUIRE(syn_packet[5] == SYN); + + auto synack_packet = Discovery::build_connect_packet(SYN | CON_ACK); + REQUIRE(synack_packet[5] == (SYN | CON_ACK)); +} + +SCENARIO("Discovery process_leave does not fire callback for non-connected peer", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + bool leave_called = false; + + disc.set_leave_callback([&](const PeerInfo&) { leave_called = true; }); + + // Add peer but do NOT complete handshake + auto announce = Discovery::build_announce_packet("peer_a", {}); + disc.process_announce(peer_addr, announce.data(), announce.size()); + + disc.process_leave(peer_addr); + + REQUIRE_FALSE(leave_called); + REQUIRE_FALSE(disc.has_peer(peer_addr)); +} + +SCENARIO("Discovery connection deferred until announce heard", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + bool join_called = false; + + disc.set_join_callback([&](const PeerInfo&) { join_called = true; }); + + // Peer sends SYN before we've heard their announce (they added us via CONNECT) + auto result = disc.process_connect(peer_addr, SYN); + REQUIRE(result.response_flags == (SYN | CON_ACK)); + REQUIRE_FALSE(join_called); + + // Complete the data handshake + result = disc.process_connect(peer_addr, CON_ACK); + // Data path is confirmed but announce not yet heard — NOT connected + REQUIRE_FALSE(result.just_connected); + REQUIRE_FALSE(join_called); + 
REQUIRE_FALSE(disc.is_connected(peer_addr)); + REQUIRE(disc.get_peer(peer_addr)->handshake == HandshakeState::CONFIRMED); + REQUIRE_FALSE(disc.get_peer(peer_addr)->announce_heard); + + // Now we hear their announce on the announce channel + auto announce = Discovery::build_announce_packet("peer_a", {}); + disc.process_announce(peer_addr, announce.data(), announce.size()); + + // NOW the connection should be up + REQUIRE(join_called); + REQUIRE(disc.is_connected(peer_addr)); +} + +SCENARIO("Discovery retransmits SYN when announce received in SYN_SENT state", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + + // First announce — new peer + auto announce = Discovery::build_announce_packet("peer_a", {}); + auto result = disc.process_announce(peer_addr, announce.data(), announce.size()); + REQUIRE(result.is_new); + + // We send SYN (externally) and mark state + disc.mark_syn_sent(peer_addr); + REQUIRE(disc.get_peer(peer_addr)->handshake == HandshakeState::SYN_SENT); + + // SYN was dropped. Another announce arrives — should indicate SYN retransmit + result = disc.process_announce(peer_addr, announce.data(), announce.size()); + REQUIRE_FALSE(result.is_new); + REQUIRE(result.response_flags == SYN); +} + +SCENARIO("Discovery retransmits SYN+ACK when announce received in SYN_RECEIVED state", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + + // Add peer via announce + auto announce = Discovery::build_announce_packet("peer_a", {}); + disc.process_announce(peer_addr, announce.data(), announce.size()); + + // Peer sends SYN — we go to SYN_RECEIVED + auto connect_result = disc.process_connect(peer_addr, SYN); + REQUIRE(connect_result.response_flags == (SYN | CON_ACK)); + REQUIRE(disc.get_peer(peer_addr)->handshake == HandshakeState::SYN_RECEIVED); + + // Our SYN+ACK was dropped. 
Another announce arrives — should indicate SYN+ACK retransmit + auto result = disc.process_announce(peer_addr, announce.data(), announce.size()); + REQUIRE_FALSE(result.is_new); + REQUIRE(result.response_flags == (SYN | CON_ACK)); +} + +SCENARIO("Discovery retransmits ACK when announce received in CONFIRMED but peer not connected", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + + // Add peer via announce and complete handshake + auto announce = Discovery::build_announce_packet("peer_a", {}); + disc.process_announce(peer_addr, announce.data(), announce.size()); + disc.mark_syn_sent(peer_addr); + auto connect_result = disc.process_connect(peer_addr, SYN | CON_ACK); + REQUIRE(connect_result.just_connected); + REQUIRE(disc.is_connected(peer_addr)); + + // We see the peer as fully connected, but a further announce should still yield an ACK retransmit (presumably in case the peer missed our final ACK) + auto result = disc.process_announce(peer_addr, announce.data(), announce.size()); + REQUIRE_FALSE(result.is_new); + REQUIRE(result.response_flags == CON_ACK); +} + +SCENARIO("Discovery no retransmit for IDLE peer (not yet sent SYN)", "[nuclearnet][discovery]") { + Discovery disc(std::chrono::seconds(5)); + + sock_t peer_addr = make_addr(0x0A000001, 5000); + + // First announce — new peer, handshake IDLE + auto announce = Discovery::build_announce_packet("peer_a", {}); + auto result = disc.process_announce(peer_addr, announce.data(), announce.size()); + REQUIRE(result.is_new); + REQUIRE(disc.get_peer(peer_addr)->handshake == HandshakeState::IDLE); + + // Second announce — peer still in IDLE (we haven't sent SYN yet) + // Should indicate an initial SYN is needed (nothing has been sent yet, so this is not a retransmit) + result = disc.process_announce(peer_addr, announce.data(), announce.size()); + REQUIRE_FALSE(result.is_new); + REQUIRE(result.response_flags == SYN); +} diff --git a/tests/tests/nuclearnet/Fragmentation.cpp b/tests/tests/nuclearnet/Fragmentation.cpp new file mode 100644 index 00000000..4fa96b4c --- /dev/null +++ 
b/tests/tests/nuclearnet/Fragmentation.cpp @@ -0,0 +1,237 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ */ + +#include "nuclearnet/Fragmentation.hpp" + +#include +#include +#include +#include +#include + +using NUClear::network::Fragmentation; + +SCENARIO("Fragmentation splits large payload into MTU-sized fragments", "[nuclearnet][fragmentation]") { + Fragmentation frag(100); // 100-byte MTU + + // 250 bytes of data → should produce 3 fragments (100 + 100 + 50) + std::vector payload(250); + std::iota(payload.begin(), payload.end(), uint8_t(0)); + + auto fragments = frag.fragment(1, 0xDEADBEEF, 0, payload); + + REQUIRE(fragments.size() == 3); + REQUIRE(fragments[0].packet_no == 0); + REQUIRE(fragments[1].packet_no == 1); + REQUIRE(fragments[2].packet_no == 2); + + REQUIRE(fragments[0].packet_count == 3); + REQUIRE(fragments[1].packet_count == 3); + REQUIRE(fragments[2].packet_count == 3); + + REQUIRE(fragments[0].data.size() == 100); + REQUIRE(fragments[1].data.size() == 100); + REQUIRE(fragments[2].data.size() == 50); + + // Verify data integrity + REQUIRE(fragments[0].data[0] == 0); + REQUIRE(fragments[0].data[99] == 99); + REQUIRE(fragments[1].data[0] == 100); + REQUIRE(fragments[2].data[0] == 200); + REQUIRE(fragments[2].data[49] == 249); +} + +SCENARIO("Fragmentation produces single fragment for small payload", "[nuclearnet][fragmentation]") { + Fragmentation frag(1452); + + std::vector payload(100, 0xAB); + auto fragments = frag.fragment(5, 0x12345678, 0x01, payload); + + REQUIRE(fragments.size() == 1); + REQUIRE(fragments[0].packet_no == 0); + REQUIRE(fragments[0].packet_count == 1); + REQUIRE(fragments[0].data.size() == 100); + REQUIRE(fragments[0].hash == 0x12345678); + REQUIRE(fragments[0].flags == 0x01); +} + +SCENARIO("Fragmentation produces single fragment for empty payload", "[nuclearnet][fragmentation]") { + Fragmentation frag(1452); + + std::vector payload; + auto fragments = frag.fragment(0, 0x11111111, 0, payload); + + REQUIRE(fragments.size() == 1); + REQUIRE(fragments[0].packet_count == 1); + REQUIRE(fragments[0].data.empty()); +} + 
+SCENARIO("Fragmentation reassembles fragments into original payload", "[nuclearnet][fragmentation]") { + Fragmentation frag(100); + + // Create a payload and fragment it + std::vector payload(250); + std::iota(payload.begin(), payload.end(), uint8_t(0)); + + auto fragments = frag.fragment(1, 0xDEADBEEF, 0, payload); + + // Submit fragments in order + Fragmentation::AssembledPacket result; + bool complete = false; + for (const auto& f : fragments) { + complete = frag.submit_fragment(99, + f.packet_id, + f.packet_no, + f.packet_count, + f.hash, + f.flags, + f.data.data(), + f.data.size(), + result); + } + + REQUIRE(complete); + REQUIRE(result.payload == payload); + REQUIRE(result.hash == 0xDEADBEEF); + REQUIRE(result.packet_id == 1); +} + +SCENARIO("Fragmentation reassembles out-of-order fragments", "[nuclearnet][fragmentation]") { + Fragmentation frag(100); + + std::vector payload(250); + std::iota(payload.begin(), payload.end(), uint8_t(0)); + + auto fragments = frag.fragment(7, 0xCAFEBABE, 0x01, payload); + + // Submit in reverse order + Fragmentation::AssembledPacket result; + REQUIRE_FALSE(frag.submit_fragment(1, + 7, + 2, + 3, + 0xCAFEBABE, + 0x01, + fragments[2].data.data(), + fragments[2].data.size(), + result)); + REQUIRE_FALSE(frag.submit_fragment(1, + 7, + 0, + 3, + 0xCAFEBABE, + 0x01, + fragments[0].data.data(), + fragments[0].data.size(), + result)); + + // Last fragment completes the assembly + bool complete = frag.submit_fragment(1, + 7, + 1, + 3, + 0xCAFEBABE, + 0x01, + fragments[1].data.data(), + fragments[1].data.size(), + result); + + REQUIRE(complete); + REQUIRE(result.payload == payload); +} + +SCENARIO("Fragmentation rejects oversized assemblies", "[nuclearnet][fragmentation]") { + // Allow only 200 bytes total + Fragmentation frag(100, 200); + + // Try to submit a fragment that implies a total size > 200 bytes + // 3 fragments × 100 byte MTU = 300 bytes projected > 200 byte limit + uint8_t data[100] = {}; + Fragmentation::AssembledPacket 
result; + bool complete = frag.submit_fragment(1, 1, 0, 3, 0x1234, 0, data, 100, result); + REQUIRE_FALSE(complete); +} + +SCENARIO("Fragmentation rejects invalid fragment indices", "[nuclearnet][fragmentation]") { + Fragmentation frag(100); + + uint8_t data[50] = {}; + + // packet_no >= packet_count is invalid + Fragmentation::AssembledPacket result; + bool complete = frag.submit_fragment(1, 1, 5, 3, 0x1234, 0, data, 50, result); + REQUIRE_FALSE(complete); + + // packet_count == 0 is invalid + complete = frag.submit_fragment(1, 1, 0, 0, 0x1234, 0, data, 50, result); + REQUIRE_FALSE(complete); +} + +SCENARIO("Fragmentation cleanup_expired removes stale assemblies", "[nuclearnet][fragmentation]") { + // Use a very short timeout for testing + Fragmentation frag(100, 64 * 1024 * 1024, std::chrono::milliseconds(1)); + + // Submit a partial assembly at time T + auto t = std::chrono::steady_clock::now(); + uint8_t data[50] = {}; + Fragmentation::AssembledPacket result; + frag.submit_fragment(1, 1, 0, 3, 0x1234, 0, data, 50, result, t); + + // Cleanup at T (not expired yet) — nothing removed + std::size_t removed = frag.cleanup_expired(t); + REQUIRE(removed == 0); + + // Cleanup at T+10ms (past 1ms timeout) — should remove it + removed = frag.cleanup_expired(t + std::chrono::milliseconds(10)); + REQUIRE(removed == 1); + + // Second cleanup should find nothing + removed = frag.cleanup_expired(t + std::chrono::milliseconds(20)); + REQUIRE(removed == 0); +} + +SCENARIO("Fragmentation handles multiple independent assemblies", "[nuclearnet][fragmentation]") { + Fragmentation frag(100); + + uint8_t data_a[100]; + uint8_t data_b[100]; + std::fill_n(data_a, 100, 0xAA); + std::fill_n(data_b, 100, 0xBB); + + // Two different sources sending 2-fragment messages + Fragmentation::AssembledPacket result1; + Fragmentation::AssembledPacket result2; + REQUIRE_FALSE(frag.submit_fragment(1, 10, 0, 2, 0x1111, 0, data_a, 100, result1)); + REQUIRE_FALSE(frag.submit_fragment(2, 10, 0, 2, 
0x2222, 0, data_b, 100, result2)); + + // Complete source 2's message + REQUIRE(frag.submit_fragment(2, 10, 1, 2, 0x2222, 0, data_b, 100, result2)); + REQUIRE(result2.hash == 0x2222); + REQUIRE(result2.payload.size() == 200); + REQUIRE(result2.payload[0] == 0xBB); + + // Source 1 still incomplete + // Complete it + REQUIRE(frag.submit_fragment(1, 10, 1, 2, 0x1111, 0, data_a, 100, result1)); + REQUIRE(result1.hash == 0x1111); + REQUIRE(result1.payload[0] == 0xAA); +} diff --git a/tests/tests/nuclearnet/Integration.cpp b/tests/tests/nuclearnet/Integration.cpp new file mode 100644 index 00000000..9d040c61 --- /dev/null +++ b/tests/tests/nuclearnet/Integration.cpp @@ -0,0 +1,262 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO + * THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF + * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "nuclearnet/NUClearNet.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "test_util/has_multicast.hpp" + +using namespace std::chrono_literals; + +namespace { + +using NUClear::network::NUClearNet; +using NUClear::network::NetworkConfig; +using NUClear::network::PeerInfo; +using NUClear::util::network::sock_t; + +constexpr uint64_t HASH_A = 0x1111'2222'3333'4444ULL; +constexpr uint64_t HASH_B = 0x5555'6666'7777'8888ULL; + +NetworkConfig make_config(const std::string& name) { + NetworkConfig config; + config.name = name; + config.announce_address = "239.226.152.162"; + config.announce_port = 17747; + config.mtu = 1500; + config.peer_timeout = 1s; + return config; +} + +bool wait_for(const std::function& predicate, + std::chrono::milliseconds timeout, + const std::function& pump) { + auto deadline = std::chrono::steady_clock::now() + timeout; + while (std::chrono::steady_clock::now() < deadline) { + if (predicate()) { + return true; + } + pump(); + std::this_thread::sleep_for(5ms); + } + return predicate(); +} + +struct NetworkPair { + NUClearNet a; + NUClearNet b; + + ~NetworkPair() { + a.shutdown(); + b.shutdown(); + } +}; + +std::vector make_payload(std::size_t size, uint8_t seed) { + std::vector payload(size); + for (std::size_t i = 0; i < payload.size(); ++i) { + payload[i] = static_cast(seed + static_cast(i)); + } + return payload; +} + +} // namespace + +SCENARIO("Two NUClearNet instances discover and exchange messages", "[nuclearnet][integration]") { + if (!test_util::has_ipv4_multicast()) { + SKIP("IPv4 multicast is unavailable on this system"); + } + + NetworkPair net; + + net.a.reset(make_config("alpha")); + net.b.reset(make_config("bravo")); + + // Each peer subscribes to the other's hash so routing has something real to filter on. 
+ net.a.set_subscriptions({HASH_B}); + net.b.set_subscriptions({HASH_A}); + + std::mutex mutex; + std::vector join_events; + std::vector leave_events; + std::vector>> received; + + net.a.set_join_callback([&](const PeerInfo& peer) { + std::lock_guard lock(mutex); + join_events.push_back("a:" + peer.name); + }); + net.b.set_join_callback([&](const PeerInfo& peer) { + std::lock_guard lock(mutex); + join_events.push_back("b:" + peer.name); + }); + + net.a.set_leave_callback([&](const PeerInfo& peer) { + std::lock_guard lock(mutex); + leave_events.push_back("a:" + peer.name); + }); + net.b.set_leave_callback([&](const PeerInfo& peer) { + std::lock_guard lock(mutex); + leave_events.push_back("b:" + peer.name); + }); + + net.a.set_packet_callback([&](const sock_t&, const std::string& peer_name, uint64_t hash, bool reliable, + std::vector&& payload) { + std::lock_guard lock(mutex); + received.emplace_back("a:" + peer_name + ":" + std::to_string(hash) + ":" + (reliable ? "1" : "0"), + std::move(payload)); + }); + net.b.set_packet_callback([&](const sock_t&, const std::string& peer_name, uint64_t hash, bool reliable, + std::vector&& payload) { + std::lock_guard lock(mutex); + received.emplace_back("b:" + peer_name + ":" + std::to_string(hash) + ":" + (reliable ? 
"1" : "0"), + std::move(payload)); + }); + + REQUIRE(wait_for([&] { + std::lock_guard lock(mutex); + return std::find(join_events.begin(), join_events.end(), "a:bravo") != join_events.end() + && std::find(join_events.begin(), join_events.end(), "b:alpha") != join_events.end(); + }, 5s, [&] { + net.a.process(); + net.b.process(); + })); + + auto payload_a = make_payload(4096, 0x10); + auto payload_b = make_payload(64, 0x80); + + net.a.send(HASH_A, payload_a.data(), payload_a.size(), "", true); + net.b.send(HASH_B, payload_b.data(), payload_b.size(), "", false); + + REQUIRE(wait_for([&] { + std::lock_guard lock(mutex); + return received.size() == 2; + }, 5s, [&] { + net.a.process(); + net.b.process(); + })); + + { + std::lock_guard lock(mutex); + REQUIRE(received.size() == 2); + + const auto expected_from_b = std::string("b:alpha:") + std::to_string(HASH_A) + ":1"; + const auto expected_from_a = std::string("a:bravo:") + std::to_string(HASH_B) + ":0"; + + auto it_a = std::find_if(received.begin(), received.end(), [&expected_from_b](const auto& entry) { + return entry.first == expected_from_b; + }); + auto it_b = std::find_if(received.begin(), received.end(), [&expected_from_a](const auto& entry) { + return entry.first == expected_from_a; + }); + + REQUIRE(it_a != received.end()); + REQUIRE(it_b != received.end()); + REQUIRE(it_a->second == payload_a); + REQUIRE(it_b->second == payload_b); + } + + net.b.shutdown(); + + REQUIRE(wait_for([&] { + std::lock_guard lock(mutex); + return std::find(leave_events.begin(), leave_events.end(), "a:bravo") != leave_events.end(); + }, 5s, [&] { + net.a.process(); + })); +} + +SCENARIO("NUClearNet handles bidirectional reliable traffic", "[nuclearnet][integration]") { + if (!test_util::has_ipv4_multicast()) { + SKIP("IPv4 multicast is unavailable on this system"); + } + + NetworkPair net; + + net.a.reset(make_config("left")); + net.b.reset(make_config("right")); + + net.a.set_subscriptions({HASH_B}); + 
net.b.set_subscriptions({HASH_A}); + + std::mutex mutex; + std::vector join_events; + std::vector> a_received; + std::vector> b_received; + + net.a.set_join_callback([&](const PeerInfo& peer) { + std::lock_guard lock(mutex); + join_events.push_back("a:" + peer.name); + }); + net.b.set_join_callback([&](const PeerInfo& peer) { + std::lock_guard lock(mutex); + join_events.push_back("b:" + peer.name); + }); + + net.a.set_packet_callback([&](const sock_t&, const std::string&, uint64_t, bool, std::vector&& payload) { + std::lock_guard lock(mutex); + a_received.push_back(std::move(payload)); + }); + net.b.set_packet_callback([&](const sock_t&, const std::string&, uint64_t, bool, std::vector&& payload) { + std::lock_guard lock(mutex); + b_received.push_back(std::move(payload)); + }); + + REQUIRE(wait_for([&] { + std::lock_guard lock(mutex); + return std::find(join_events.begin(), join_events.end(), "a:right") != join_events.end() + && std::find(join_events.begin(), join_events.end(), "b:left") != join_events.end(); + }, 5s, [&] { + net.a.process(); + net.b.process(); + })); + + auto large_payload = make_payload(8192, 0x33); + auto small_payload = make_payload(32, 0x44); + + net.a.send(HASH_A, large_payload.data(), large_payload.size(), "", true); + net.b.send(HASH_B, small_payload.data(), small_payload.size(), "", true); + + REQUIRE(wait_for([&] { + std::lock_guard lock(mutex); + return a_received.size() == 1 && b_received.size() == 1; + }, 5s, [&] { + net.a.process(); + net.b.process(); + })); + + { + std::lock_guard lock(mutex); + REQUIRE(a_received[0] == small_payload); + REQUIRE(b_received[0] == large_payload); + } +} diff --git a/tests/tests/nuclearnet/PacketDeduplicator.cpp b/tests/tests/nuclearnet/PacketDeduplicator.cpp new file mode 100644 index 00000000..4f2a78b1 --- /dev/null +++ b/tests/tests/nuclearnet/PacketDeduplicator.cpp @@ -0,0 +1,120 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. 
+ * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ */ + +#include "nuclearnet/PacketDeduplicator.hpp" + +#include + +using NUClear::network::PacketDeduplicator; + +SCENARIO("PacketDeduplicator rejects duplicate packet IDs", "[nuclearnet][deduplicator]") { + PacketDeduplicator dedup; + + // First time seeing packet 42 — not a duplicate + REQUIRE_FALSE(dedup.is_duplicate(42)); + + // Add it to the seen set + dedup.add_packet(42); + + // Now it should be detected as a duplicate + REQUIRE(dedup.is_duplicate(42)); +} + +SCENARIO("PacketDeduplicator accepts distinct packet IDs", "[nuclearnet][deduplicator]") { + PacketDeduplicator dedup; + + dedup.add_packet(1); + dedup.add_packet(2); + dedup.add_packet(3); + + REQUIRE_FALSE(dedup.is_duplicate(4)); + REQUIRE_FALSE(dedup.is_duplicate(100)); + REQUIRE_FALSE(dedup.is_duplicate(255)); +} + +SCENARIO("PacketDeduplicator sliding window advances and forgets old IDs", "[nuclearnet][deduplicator]") { + PacketDeduplicator dedup; + + // Add packet 0 + dedup.add_packet(0); + REQUIRE(dedup.is_duplicate(0)); + + // Advance the window far enough that packet 0 falls outside the window (window size = 256) + for (uint16_t i = 1; i <= 256; ++i) { + dedup.add_packet(i); + } + + // Packet 0 should now be outside the window + // Since it's behind the window base, it should be treated as a duplicate (too old) + REQUIRE(dedup.is_duplicate(0)); +} + +SCENARIO("PacketDeduplicator handles sequential packet IDs", "[nuclearnet][deduplicator]") { + PacketDeduplicator dedup; + + // Process packets 0-99 in order + for (uint16_t i = 0; i < 100; ++i) { + REQUIRE_FALSE(dedup.is_duplicate(i)); + dedup.add_packet(i); + } + + // All should now be duplicates + for (uint16_t i = 0; i < 100; ++i) { + REQUIRE(dedup.is_duplicate(i)); + } +} + +SCENARIO("PacketDeduplicator handles uint16_t wraparound", "[nuclearnet][deduplicator]") { + PacketDeduplicator dedup; + + // Start near the max value + uint16_t start = 65500; + for (uint16_t i = 0; i < 100; ++i) { + uint16_t id = static_cast(start + i); // Will wrap 
around 65535 → 0 + REQUIRE_FALSE(dedup.is_duplicate(id)); + dedup.add_packet(id); + } + + // IDs that wrapped around should be marked as seen + REQUIRE(dedup.is_duplicate(65500)); + REQUIRE(dedup.is_duplicate(static_cast(65535))); + REQUIRE(dedup.is_duplicate(0)); // wrapped + REQUIRE(dedup.is_duplicate(63)); // 65500 + 99 - 65536 = 63 +} + +SCENARIO("PacketDeduplicator handles out-of-order within window", "[nuclearnet][deduplicator]") { + PacketDeduplicator dedup; + + // Add packet 10 first (advances window) + dedup.add_packet(10); + + // Packets 0-9 arrive late but are still within the 256-element window + for (uint16_t i = 0; i < 10; ++i) { + REQUIRE_FALSE(dedup.is_duplicate(i)); + dedup.add_packet(i); + } + + // All should now be duplicates + for (uint16_t i = 0; i <= 10; ++i) { + REQUIRE(dedup.is_duplicate(i)); + } +} diff --git a/tests/tests/nuclearnet/RTTEstimator.cpp b/tests/tests/nuclearnet/RTTEstimator.cpp new file mode 100644 index 00000000..3933b76f --- /dev/null +++ b/tests/tests/nuclearnet/RTTEstimator.cpp @@ -0,0 +1,86 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "nuclearnet/RTTEstimator.hpp" + +#include +#include +#include + +using NUClear::network::RTTEstimator; +using namespace std::chrono_literals; + +SCENARIO("RTTEstimator initial timeout is 1 second", "[nuclearnet][rtt]") { + RTTEstimator rtt; + // RFC 6298: initial RTO should be generous (we use 1s default) + auto timeout = rtt.timeout(); + REQUIRE(timeout >= 900ms); + REQUIRE(timeout <= 1100ms); +} + +SCENARIO("RTTEstimator converges towards measured RTT", "[nuclearnet][rtt]") { + RTTEstimator rtt; + + // Simulate stable 50ms RTT for several measurements + for (int i = 0; i < 20; ++i) { + rtt.measure(50ms); + } + + auto timeout = rtt.timeout(); + // After convergence, timeout should be close to the measured value + // (SRTT + 4*RTTVAR, but RTTVAR converges to near-zero with constant measurements) + // With constant 50ms, SRTT→50ms, RTTVAR→0, so timeout→50ms + min_rttvar + // But due to the algorithm, it should be well below 500ms + REQUIRE(timeout < 500ms); + REQUIRE(timeout >= 50ms); +} + +SCENARIO("RTTEstimator increases timeout with variable measurements", "[nuclearnet][rtt]") { + RTTEstimator rtt; + + // First establish a baseline + for (int i = 0; i < 10; ++i) { + rtt.measure(50ms); + } + auto stable_timeout = rtt.timeout(); + + // Now introduce high variance + rtt.measure(200ms); + rtt.measure(10ms); + rtt.measure(300ms); + + auto variable_timeout = rtt.timeout(); + // The timeout should be larger due to increased variance + REQUIRE(variable_timeout > stable_timeout); +} 
+ +SCENARIO("RTTEstimator first measurement sets initial estimates", "[nuclearnet][rtt]") { + RTTEstimator rtt; + + // First measurement: SRTT = R, RTTVAR = R/2, RTO = SRTT + 4*RTTVAR = R + 2R = 3R + rtt.measure(100ms); + auto timeout = rtt.timeout(); + + // Should be approximately 3 * 100ms = 300ms + REQUIRE(timeout >= 250ms); + REQUIRE(timeout <= 350ms); +} diff --git a/tests/tests/nuclearnet/Reliability.cpp b/tests/tests/nuclearnet/Reliability.cpp new file mode 100644 index 00000000..853026e1 --- /dev/null +++ b/tests/tests/nuclearnet/Reliability.cpp @@ -0,0 +1,171 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ */ + +#include "nuclearnet/Reliability.hpp" + +#include +#include +#include +#include +#include + +#include "util/platform.hpp" + +using NUClear::network::Reliability; +using NUClear::util::network::sock_t; + +namespace { +sock_t make_addr(uint32_t ip, uint16_t port) { + sock_t addr{}; + addr.ipv4.sin_family = AF_INET; + addr.ipv4.sin_port = htons(port); + addr.ipv4.sin_addr.s_addr = htonl(ip); + return addr; +} +} // namespace + +SCENARIO("Reliability tracks sent packet and requests retransmission after timeout", "[nuclearnet][reliability]") { + Reliability rel; + + sock_t target = make_addr(0x0A000001, 5000); + std::vector payload(200, 0xAB); + + // Track a 2-fragment packet at time T + auto t = std::chrono::steady_clock::now(); + rel.track_packet(target, 1, 2, 0x1234, 0x01, payload.data(), payload.size(), t); + + // Immediately, no retransmissions (timeout not elapsed) + auto retransmissions = rel.check_retransmissions(100, t); + REQUIRE(retransmissions.empty()); + + // Inject an RTT measurement to reduce the timeout to min_rto (100ms) + rel.get_rtt(target).measure(std::chrono::milliseconds(10)); + + // Check at T+150ms (past min_rto of 100ms) — should retransmit + retransmissions = rel.check_retransmissions(100, t + std::chrono::milliseconds(150)); + REQUIRE(retransmissions.size() == 2); // Both fragments unacked + REQUIRE(retransmissions[0].packet_no == 0); + REQUIRE(retransmissions[1].packet_no == 1); + REQUIRE(retransmissions[0].data.size() == 100); + REQUIRE(retransmissions[1].data.size() == 100); +} + +SCENARIO("Reliability stops retransmitting ACKed fragments", "[nuclearnet][reliability]") { + Reliability rel; + + sock_t target = make_addr(0x0A000001, 5000); + std::vector payload(200, 0xCC); + + auto t = std::chrono::steady_clock::now(); + rel.track_packet(target, 1, 2, 0x1234, 0x01, payload.data(), payload.size(), t); + + // ACK fragment 0 (bitset: bit 0 set) + uint8_t ack_bits = 0x01; // fragment 0 received + rel.process_ack(target, 1, 2, &ack_bits, 
1, t + std::chrono::milliseconds(50)); + + // Inject short RTT and check past min_rto + rel.get_rtt(target).measure(std::chrono::milliseconds(5)); + + auto retransmissions = rel.check_retransmissions(100, t + std::chrono::milliseconds(200)); + // Only fragment 1 should be retransmitted (fragment 0 was ACKed) + REQUIRE(retransmissions.size() == 1); + REQUIRE(retransmissions[0].packet_no == 1); +} + +SCENARIO("Reliability removes tracked packet when all fragments ACKed", "[nuclearnet][reliability]") { + Reliability rel; + + sock_t target = make_addr(0x0A000001, 5000); + std::vector payload(100, 0xDD); + + auto t = std::chrono::steady_clock::now(); + rel.track_packet(target, 5, 1, 0x5678, 0x01, payload.data(), payload.size(), t); + + // ACK all fragments + uint8_t ack_bits = 0x01; // fragment 0 received (only 1 fragment total) + rel.process_ack(target, 5, 1, &ack_bits, 1, t + std::chrono::milliseconds(50)); + + // Inject short RTT and check well past min_rto — nothing to retransmit + rel.get_rtt(target).measure(std::chrono::milliseconds(5)); + + auto retransmissions = rel.check_retransmissions(100, t + std::chrono::milliseconds(200)); + REQUIRE(retransmissions.empty()); +} + +SCENARIO("Reliability retransmits indefinitely until peer is removed", "[nuclearnet][reliability]") { + Reliability rel; + + sock_t target = make_addr(0x0A000001, 5000); + std::vector payload(50, 0xEE); + + auto t = std::chrono::steady_clock::now(); + rel.track_packet(target, 1, 1, 0x1234, 0x01, payload.data(), payload.size(), t); + + // Inject short RTT + rel.get_rtt(target).measure(std::chrono::milliseconds(10)); + + // First retransmission (T+150ms, past min_rto of 100ms) + auto r1 = rel.check_retransmissions(100, t + std::chrono::milliseconds(150)); + REQUIRE(r1.size() == 1); + + // Second retransmission (T+300ms, 150ms since last_send was updated to T+150ms) + auto r2 = rel.check_retransmissions(100, t + std::chrono::milliseconds(300)); + REQUIRE(r2.size() == 1); + + // Third retransmission 
still works — no limit + auto r3 = rel.check_retransmissions(100, t + std::chrono::milliseconds(450)); + REQUIRE(r3.size() == 1); + + // Removing the peer cleans up all tracked packets + rel.remove_peer(target); + auto r4 = rel.check_retransmissions(100, t + std::chrono::milliseconds(600)); + REQUIRE(r4.empty()); +} + +SCENARIO("Reliability build_ack_packet encodes bitset correctly", "[nuclearnet][reliability]") { + std::vector received = {true, false, true, true, false, false, false, true}; // 0b10001101 = 0x8D + auto ack = Reliability::build_ack_packet(42, 8, received); + + // The packet should contain a header + 1 byte of bitset + // Verify the bitset byte + REQUIRE(ack.size() > 0); + // The last byte(s) should be the bitset + uint8_t bitset_byte = ack.back(); + REQUIRE(bitset_byte == 0x8D); +} + +SCENARIO("Reliability remove_peer removes all tracked state", "[nuclearnet][reliability]") { + Reliability rel; + + sock_t target = make_addr(0x0A000001, 5000); + std::vector payload(100, 0xFF); + + auto t = std::chrono::steady_clock::now(); + rel.track_packet(target, 1, 1, 0x1234, 0x01, payload.data(), payload.size(), t); + rel.get_rtt(target).measure(std::chrono::milliseconds(50)); + + rel.remove_peer(target); + + // After removing, no retransmissions should occur even well past RTO + auto retransmissions = rel.check_retransmissions(100, t + std::chrono::milliseconds(500)); + REQUIRE(retransmissions.empty()); +} diff --git a/tests/tests/nuclearnet/Routing.cpp b/tests/tests/nuclearnet/Routing.cpp new file mode 100644 index 00000000..76e6db91 --- /dev/null +++ b/tests/tests/nuclearnet/Routing.cpp @@ -0,0 +1,150 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ */
+
+#include "nuclearnet/Routing.hpp"
+
+#include <algorithm>  // NOTE(review): the four system header names are restored from the names used below — confirm
+#include <catch2/catch_test_macros.hpp>
+#include <cstdint>
+#include <vector>
+
+#include "util/platform.hpp"
+
+using NUClear::network::Routing;
+using NUClear::util::network::sock_t;
+
+namespace {
+sock_t make_addr(uint32_t ip, uint16_t port) {  // test helper: IPv4 sock_t from a host-order address and port
+    sock_t addr{};
+    addr.ipv4.sin_family = AF_INET;
+    addr.ipv4.sin_port = htons(port);       // converted to network byte order
+    addr.ipv4.sin_addr.s_addr = htonl(ip);  // converted to network byte order
+    return addr;
+}
+}  // namespace
+
+SCENARIO("Routing delivers to peers subscribed to the message hash", "[nuclearnet][routing]") {
+    Routing routing;
+
+    sock_t peer_a = make_addr(0x0A000001, 5000);
+    sock_t peer_b = make_addr(0x0A000002, 5000);
+
+    routing.update_peer_subscriptions(peer_a, {0x1111, 0x2222});  // 0x2222 is shared by both peers
+    routing.update_peer_subscriptions(peer_b, {0x2222, 0x3333});
+
+    // Hash 0x1111 — only peer_a
+    REQUIRE(routing.should_send(peer_a, 0x1111));
+    REQUIRE_FALSE(routing.should_send(peer_b, 0x1111));
+
+    // Hash 0x2222 — both peers
+    REQUIRE(routing.should_send(peer_a, 0x2222));
+    REQUIRE(routing.should_send(peer_b, 0x2222));
+
+    // Hash 0x3333 — only peer_b
+    REQUIRE_FALSE(routing.should_send(peer_a, 0x3333));
+    REQUIRE(routing.should_send(peer_b, 0x3333));
+}
+
+SCENARIO("Routing delivers all messages when peer has empty subscription set", "[nuclearnet][routing]") {
+    Routing routing;
+
+    sock_t peer = make_addr(0x0A000001, 5000);
+
+    // Empty subscription set = receive everything
+    routing.update_peer_subscriptions(peer, {});
+
+    REQUIRE(routing.should_send(peer, 0x1111));
+    REQUIRE(routing.should_send(peer, 0x9999));
+    REQUIRE(routing.should_send(peer, 0));  // hash value 0 is treated like any other hash
+}
+
+SCENARIO("Routing allows sending to unknown peers by default", "[nuclearnet][routing]") {
+    Routing routing;
+
+    sock_t unknown = make_addr(0x0A000099, 5000);
+
+    // Unknown peer (no update_peer_subscriptions call was made for it) — should default to allowing sends
+    REQUIRE(routing.should_send(unknown, 0x1111));
+}
+
+SCENARIO("Routing get_targets returns all subscribed peers for a hash", "[nuclearnet][routing]") {
+    Routing routing;
+
+    sock_t peer_a = make_addr(0x0A000001, 5000);
+    sock_t peer_b = make_addr(0x0A000002, 5000);
+    sock_t peer_c = make_addr(0x0A000003, 5000);
+
+    routing.update_peer_subscriptions(peer_a, {0x1111});
+    routing.update_peer_subscriptions(peer_b, {0x1111, 0x2222});
+    routing.update_peer_subscriptions(peer_c, {0x2222});
+
+    std::vector<sock_t> all_peers = {peer_a, peer_b, peer_c};
+    auto targets = routing.get_targets(all_peers, 0x1111);
+    REQUIRE(targets.size() == 2);  // peer_a and peer_b subscribe to 0x1111; peer_c does not
+}
+
+SCENARIO("Routing local subscriptions are tracked correctly", "[nuclearnet][routing]") {
+    Routing routing;
+
+    routing.set_local_subscriptions({0x1111, 0x2222, 0x3333});
+    auto subs = routing.get_local_subscriptions();
+    REQUIRE(subs.size() == 3);
+    REQUIRE(std::find(subs.begin(), subs.end(), 0x1111) != subs.end());
+    REQUIRE(std::find(subs.begin(), subs.end(), 0x2222) != subs.end());
+    REQUIRE(std::find(subs.begin(), subs.end(), 0x3333) != subs.end());
+
+    routing.add_local_subscription(0x4444);  // expected to extend the set: size goes from 3 to 4
+    subs = routing.get_local_subscriptions();
+    REQUIRE(subs.size() == 4);
+    REQUIRE(std::find(subs.begin(), subs.end(), 0x4444) != subs.end());
+}
+
+SCENARIO("Routing removes peer correctly", "[nuclearnet][routing]") {
+    Routing routing;
+
+    sock_t peer = make_addr(0x0A000001, 5000);
+
+    routing.update_peer_subscriptions(peer, {0x1111});
+    REQUIRE(routing.should_send(peer, 0x1111));
+    REQUIRE_FALSE(routing.should_send(peer, 0x9999));  // Not subscribed
+
+    routing.remove_peer(peer);
+
+    // After removal, peer is unknown again — defaults to allowing sends, including the previously filtered hash
+    REQUIRE(routing.should_send(peer, 0x1111));
+    REQUIRE(routing.should_send(peer, 0x9999));
+}
+
+SCENARIO("Routing updates subscriptions for existing peer", "[nuclearnet][routing]") {
+    Routing routing;
+
+    sock_t peer = make_addr(0x0A000001, 5000);
+
+    routing.update_peer_subscriptions(peer, {0x1111});
+    REQUIRE(routing.should_send(peer, 0x1111));
+    REQUIRE_FALSE(routing.should_send(peer, 0x2222));
+
+    // Update subscriptions — the new set is expected to replace the old one, not merge with it
+    routing.update_peer_subscriptions(peer, {0x2222});
+    REQUIRE_FALSE(routing.should_send(peer, 0x1111));
+    REQUIRE(routing.should_send(peer, 0x2222));
+}
diff --git a/tests/tests/nuclearnet/wire_protocol.cpp b/tests/tests/nuclearnet/wire_protocol.cpp
new file mode 100644
index 00000000..56bdfb3f
--- /dev/null
+++ b/tests/tests/nuclearnet/wire_protocol.cpp
@@ -0,0 +1,110 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2025 NUClear Contributors
+ *
+ * This file is part of the NUClear codebase.
+ * See https://github.com/Fastcode/NUClear for further info.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+ * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+ * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+#include "nuclearnet/wire_protocol.hpp"
+
+#include <catch2/catch_test_macros.hpp>  // NOTE(review): the three system header names are restored from usage — confirm
+#include <cstdint>
+#include <cstring>
+
+using namespace NUClear::network;
+
+SCENARIO("Wire protocol structs have correct packed sizes", "[nuclearnet][wire_protocol]") {
+    // PacketHeader: magic(3) + version(1) + type(1) = 5
+    REQUIRE(sizeof(PacketHeader) == 5);
+
+    // DataPacket: header(5) + packet_id(2) + packet_no(2) + packet_count(2) + flags(1) + hash(8) + data(1) = 21
+    REQUIRE(sizeof(DataPacket) == 21);
+
+    // ACKPacket: header(5) + packet_id(2) + packet_count(2) + packets(1) = 10
+    REQUIRE(sizeof(ACKPacket) == 10);
+
+    // LeavePacket: just the header = 5
+    REQUIRE(sizeof(LeavePacket) == 5);
+}
+
+SCENARIO("Wire protocol header is laid out at expected byte offsets", "[nuclearnet][wire_protocol]") {
+    DataPacket pkt;
+    pkt.packet_id = 0x0102;
+    pkt.packet_no = 0x0304;
+    pkt.packet_count = 0x0506;
+    pkt.flags = 0x07;
+    pkt.hash = 0x08090A0B0C0D0E0F;
+
+    const auto* raw = reinterpret_cast<const uint8_t*>(&pkt);  // byte view of the packet for offset checks
+
+    // Magic bytes at offset 0-2 (these three bytes are the UTF-8 encoding of U+2622)
+    REQUIRE(raw[0] == 0xE2);
+    REQUIRE(raw[1] == 0x98);
+    REQUIRE(raw[2] == 0xA2);
+
+    // Version at offset 3
+    REQUIRE(raw[3] == PROTOCOL_VERSION);
+
+    // Type at offset 4
+    REQUIRE(raw[4] == DATA);
+
+    // packet_id at offset 5-6 (memcpy reads in native byte order, matching the native assignment above)
+    uint16_t pid;
+    std::memcpy(&pid, raw + 5, 2);
+    REQUIRE(pid == 0x0102);
+
+    // packet_no at offset 7-8
+    uint16_t pno;
+    std::memcpy(&pno, raw + 7, 2);
+    REQUIRE(pno == 0x0304);
+
+    // packet_count at offset 9-10
+    uint16_t pcnt;
+    std::memcpy(&pcnt, raw + 9, 2);
+    REQUIRE(pcnt == 0x0506);
+
+    // flags at offset 11
+    REQUIRE(raw[11] == 0x07);
+
+    // hash at offset 12-19
+    uint64_t h;
+    std::memcpy(&h, raw + 12, 8);
+    REQUIRE(h == 0x08090A0B0C0D0E0F);
+}
+
+SCENARIO("validate_header accepts valid packets", "[nuclearnet][wire_protocol]") {
+    DataPacket pkt;  // relies on the struct's defaults for magic, version and type
+    const auto* raw = reinterpret_cast<const uint8_t*>(&pkt);
+
+    REQUIRE(validate_header(raw, sizeof(DataPacket)));
+}
+
+SCENARIO("validate_header rejects packet too short", "[nuclearnet][wire_protocol]") {
+    uint8_t data[4] = {0xE2, 0x98, 0xA2, PROTOCOL_VERSION};  // valid magic and version; only the length is wrong
+    REQUIRE_FALSE(validate_header(data, sizeof(data)));
+}
+
+SCENARIO("validate_header rejects wrong magic bytes", "[nuclearnet][wire_protocol]") {
+    uint8_t data[5] = {0x00, 0x00, 0x00, PROTOCOL_VERSION, 0x01};  // current version, so only the magic is wrong
+    REQUIRE_FALSE(validate_header(data, sizeof(data)));
+}
+
+SCENARIO("validate_header rejects wrong protocol version", "[nuclearnet][wire_protocol]") {
+    uint8_t data[5] = {0xE2, 0x98, 0xA2, 0x99, 0x01};  // Version 0x99
+    REQUIRE_FALSE(validate_header(data, sizeof(data)));
+}