diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 1d4b36d4..a102e888 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -8,6 +8,9 @@ on: branches: - main +env: + CARGO_TERM_COLOR: always + jobs: rustfmt: name: Rust format @@ -33,10 +36,22 @@ jobs: - uses: Swatinem/rust-cache@v2 + - run: cargo build -p sshx-server --release + + - run: cargo build -p sshx --release + - run: cargo test - run: cargo clippy --all-targets -- -D warnings + - uses: actions/upload-artifact@v4 + with: + name: linux-rust-binaries + path: | + target/release/sshx + target/release/sshx-server + if-no-files-found: error + windows_test: name: Client test (Windows) runs-on: windows-latest @@ -50,8 +65,16 @@ jobs: - uses: Swatinem/rust-cache@v2 + - run: cargo build -p sshx --release + - run: cargo test -p sshx + - uses: actions/upload-artifact@v4 + with: + name: windows-sshx + path: target/release/sshx.exe + if-no-files-found: error + web: name: Web lint, check, and build runs-on: ubuntu-latest @@ -62,6 +85,7 @@ jobs: - uses: actions/setup-node@v4 with: node-version: "18" + cache: npm - run: npm ci @@ -71,6 +95,12 @@ jobs: - run: npm run build + - uses: actions/upload-artifact@v4 + with: + name: web-build + path: build + if-no-files-found: error + deploy: name: Deploy runs-on: ubuntu-latest diff --git a/crates/sshx-server/src/session.rs b/crates/sshx-server/src/session.rs index b0d0eefc..b4a830fc 100644 --- a/crates/sshx-server/src/session.rs +++ b/crates/sshx-server/src/session.rs @@ -24,6 +24,7 @@ mod snapshot; /// Store a rolling buffer with at most this quantity of output, per shell. const SHELL_STORED_BYTES: u64 = 1 << 21; // 2 MiB +const SHELL_CHUNK_BROADCAST_CAPACITY: usize = 256; /// Static metadata for this session. #[derive(Debug, Clone)] @@ -80,7 +81,7 @@ pub struct Session { } /// Internal state for each shell. -#[derive(Default, Debug)] +#[derive(Debug)] struct State { /// Sequence number, indicating how many bytes have been received. seqnum: u64, @@ -99,6 +100,29 @@ struct State { /// Updated when any of the above fields change. notify: Arc, + + chunks: broadcast::Sender, +} + +impl State { + fn new() -> Self { + let (chunks, _rx) = broadcast::channel(SHELL_CHUNK_BROADCAST_CAPACITY); + State { + seqnum: 0, + data: Vec::new(), + chunk_offset: 0, + byte_offset: 0, + closed: false, + notify: Arc::new(Notify::new()), + chunks, + } + } +} + +fn encode_wsserver(msg: &WsServer) -> Result { + let mut buf = Vec::new(); + ciborium::ser::into_writer(msg, &mut buf)?; + Ok(Bytes::from(buf)) } impl Session { @@ -113,7 +137,7 @@ impl Session { counter: IdCounter::default(), last_accessed: Mutex::new(now), source: watch::channel(Vec::new()).0, - broadcast: broadcast::channel(64).0, + broadcast: broadcast::channel(1024).0, update_tx, update_rx, sync_notify: Notify::new(), @@ -155,6 +179,39 @@ impl Session { WatchStream::new(self.source.subscribe()) } + #[allow(missing_docs)] + pub fn list_shells(&self) -> Vec<(Sid, WsWinsize)> { + self.source.borrow().clone() + } + + #[allow(missing_docs)] + pub fn init_chunk_subscription( + &self, + id: Sid, + chunknum: u64, + ) -> Option<(broadcast::Receiver, u64, Vec, u64)> { + let shells = self.shells.read(); + let shell = shells.get(&id)?; + if shell.closed { + return None; + } + + let rx = shell.chunks.subscribe(); + let mut seqnum = shell.byte_offset; + let baseline_chunks = shell.chunk_offset + shell.data.len() as u64; + if chunknum < baseline_chunks { + let start = chunknum.saturating_sub(shell.chunk_offset) as usize; + seqnum += shell.data[..start] + .iter() + .map(|x| x.len() as u64) + .sum::(); + let chunks = shell.data[start..].to_vec(); + Some((rx, seqnum, chunks, baseline_chunks)) + } else { + Some((rx, shell.seqnum, Vec::new(), baseline_chunks)) + } + } + /// Subscribe for chunks from a shell, until it is closed. pub fn subscribe_chunks( &self, @@ -201,7 +258,7 @@ impl Session { use std::collections::hash_map::Entry::*; let _guard = match self.shells.write().entry(id) { Occupied(_) => bail!("shell already exists with id={id}"), - Vacant(v) => v.insert(State::default()), + Vacant(v) => v.insert(State::new()), }; self.source.send_modify(|source| { let winsize = WsWinsize { @@ -263,8 +320,9 @@ impl Session { let start = shell.seqnum - seq; let segment = data.slice(start as usize..); debug!(%id, bytes = segment.len(), "adding data to shell"); + let seqnum = shell.seqnum; shell.seqnum += segment.len() as u64; - shell.data.push(segment); + shell.data.push(Bytes::clone(&segment)); // Prune old chunks if we've exceeded the maximum stored bytes. let mut stored_bytes = shell.seqnum - shell.byte_offset; @@ -280,6 +338,11 @@ impl Session { shell.data.drain(..offset); } + let msg = WsServer::Chunks(id, seqnum, vec![segment]); + if let Ok(encoded) = encode_wsserver(&msg) { + shell.chunks.send(encoded).ok(); + } + shell.notify.notify_waiters(); } diff --git a/crates/sshx-server/src/session/snapshot.rs b/crates/sshx-server/src/session/snapshot.rs index a1f2579c..66782728 100644 --- a/crates/sshx-server/src/session/snapshot.rs +++ b/crates/sshx-server/src/session/snapshot.rs @@ -93,15 +93,13 @@ impl Session { cols: shell.winsize_cols.try_into().context("cols overflow")?, }, )); - let shell = State { - seqnum: shell.seqnum, - data: shell.data, - chunk_offset: shell.chunk_offset, - byte_offset: shell.byte_offset, - closed: shell.closed, - notify: Default::default(), - }; - shells.insert(Sid(sid), shell); + let mut state = State::new(); + state.seqnum = shell.seqnum; + state.data = shell.data; + state.chunk_offset = shell.chunk_offset; + state.byte_offset = shell.byte_offset; + state.closed = shell.closed; + shells.insert(Sid(sid), state); } drop(shells); session.source.send_replace(winsizes); diff --git a/crates/sshx-server/src/web/socket.rs b/crates/sshx-server/src/web/socket.rs index 85a43d62..01f6dde4 100644 --- a/crates/sshx-server/src/web/socket.rs +++ b/crates/sshx-server/src/web/socket.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use std::sync::Arc; -use anyhow::{Context, Result}; +use anyhow::Result; use axum::extract::{ ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade}, Path, State, @@ -10,9 +10,10 @@ use axum::response::IntoResponse; use bytes::Bytes; use futures_util::SinkExt; use sshx_core::proto::{server_update::ServerMessage, NewShell, TerminalInput, TerminalSize}; -use sshx_core::Sid; use subtle::ConstantTimeEq; use tokio::sync::mpsc; +use tokio::time::{Duration, Instant}; +use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; use tokio_stream::StreamExt; use tracing::{error, info_span, warn, Instrument}; @@ -79,6 +80,11 @@ async fn handle_socket(socket: &mut WebSocket, session: Arc) -> Result< Ok(()) } + async fn send_bytes(socket: &mut WebSocket, bytes: Bytes) -> Result<()> { + socket.send(Message::Binary(bytes)).await?; + Ok(()) + } + /// Receive a message from the client over WebSocket. async fn recv(socket: &mut WebSocket) -> Result> { Ok(loop { @@ -134,23 +140,29 @@ async fn handle_socket(socket: &mut WebSocket, session: Arc) -> Result< send(socket, WsServer::Users(session.list_users())).await?; let mut subscribed = HashSet::new(); // prevent duplicate subscriptions - let (chunks_tx, mut chunks_rx) = mpsc::channel::<(Sid, u64, Vec)>(1); + let (chunks_tx, mut chunks_rx) = mpsc::channel::(32); + let mut last_cursor_update = Instant::now() - Duration::from_secs(1); let mut shells_stream = session.subscribe_shells(); loop { let msg = tokio::select! { _ = session.terminated() => break, Some(result) = broadcast_stream.next() => { - let msg = result.context("client fell behind on broadcast stream")?; - send(socket, msg).await?; + match result { + Ok(msg) => send(socket, msg).await?, + Err(BroadcastStreamRecvError::Lagged(_)) => { + send(socket, WsServer::Users(session.list_users())).await?; + send(socket, WsServer::Shells(session.list_shells())).await?; + } + } continue; } Some(shells) = shells_stream.next() => { send(socket, WsServer::Shells(shells)).await?; continue; } - Some((id, seqnum, chunks)) = chunks_rx.recv() => { - send(socket, WsServer::Chunks(id, seqnum, chunks)).await?; + Some(bytes) = chunks_rx.recv() => { + send_bytes(socket, bytes).await?; continue; } result = recv(socket) => { @@ -169,7 +181,11 @@ async fn handle_socket(socket: &mut WebSocket, session: Arc) -> Result< } } WsClient::SetCursor(cursor) => { - session.update_user(user_id, |user| user.cursor = cursor)?; + let now = Instant::now(); + if now.duration_since(last_cursor_update) >= Duration::from_millis(33) { + last_cursor_update = now; + session.update_user(user_id, |user| user.cursor = cursor)?; + } } WsClient::SetFocus(id) => { session.update_user(user_id, |user| user.focus = id)?; @@ -231,11 +247,51 @@ async fn handle_socket(socket: &mut WebSocket, session: Arc) -> Result< let session = Arc::clone(&session); let chunks_tx = chunks_tx.clone(); tokio::spawn(async move { - let stream = session.subscribe_chunks(id, chunknum); - tokio::pin!(stream); - while let Some((seqnum, chunks)) = stream.next().await { - if chunks_tx.send((id, seqnum, chunks)).await.is_err() { - break; + let Some((rx, seqnum, chunks, baseline_chunks)) = + session.init_chunk_subscription(id, chunknum) + else { + return; + }; + let mut next_chunknum = baseline_chunks; + if !chunks.is_empty() { + let msg = WsServer::Chunks(id, seqnum, chunks); + let mut buf = Vec::new(); + if ciborium::ser::into_writer(&msg, &mut buf).is_err() { + return; + } + if chunks_tx.send(Bytes::from(buf)).await.is_err() { + return; + } + } + + let mut stream = BroadcastStream::new(rx); + while let Some(item) = stream.next().await { + match item { + Ok(bytes) => { + if chunks_tx.send(bytes).await.is_err() { + return; + } + next_chunknum += 1; + } + Err(BroadcastStreamRecvError::Lagged(_)) => loop { + let Some((_rx, seqnum, chunks, baseline_chunks)) = + session.init_chunk_subscription(id, next_chunknum) + else { + return; + }; + next_chunknum = baseline_chunks; + if chunks.is_empty() { + break; + } + let msg = WsServer::Chunks(id, seqnum, chunks); + let mut buf = Vec::new(); + if ciborium::ser::into_writer(&msg, &mut buf).is_err() { + return; + } + if chunks_tx.send(Bytes::from(buf)).await.is_err() { + return; + } + }, } } }); diff --git a/crates/sshx-server/tests/stress_ws.rs b/crates/sshx-server/tests/stress_ws.rs new file mode 100644 index 00000000..de3e8803 --- /dev/null +++ b/crates/sshx-server/tests/stress_ws.rs @@ -0,0 +1,88 @@ +use anyhow::Result; +use futures_util::{SinkExt, StreamExt}; +use sshx::{controller::Controller, encrypt::Encrypt, runner::Runner}; +use sshx_server::web::protocol::{WsClient, WsServer}; +use tokio::time::{sleep, Duration}; +use tokio_tungstenite::tungstenite::Message; + +use crate::common::*; + +pub mod common; + +#[tokio::test] +async fn test_ws_broadcast_lag_resync() -> Result<()> { + let server = TestServer::new().await; + + let mut controller = Controller::new(&server.endpoint(), "", Runner::Echo, false).await?; + let name = controller.name().to_owned(); + let key = controller.encryption_key().to_owned(); + tokio::spawn(async move { controller.run().await }); + + let endpoint = server.ws_endpoint(&name); + let mut s1 = ClientSocket::connect(&endpoint, &key, None).await?; + let mut s2 = ClientSocket::connect(&endpoint, &key, None).await?; + s1.flush().await; + s2.flush().await; + + for i in 0..5000 { + let focus = if i % 2 == 0 { + Some(sshx_core::Sid(1)) + } else { + None + }; + s2.send(WsClient::SetFocus(focus)).await; + } + + s1.send(WsClient::Create(0, 0)).await; + s1.flush().await; + assert_eq!(s1.shells.len(), 1); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 8)] +#[ignore] +async fn stress_1000_ws_clients() -> Result<()> { + const CLIENTS: usize = 1000; + + let server = TestServer::new().await; + + let mut controller = Controller::new(&server.endpoint(), "", Runner::Echo, false).await?; + let name = controller.name().to_owned(); + let key = controller.encryption_key().to_owned(); + tokio::spawn(async move { controller.run().await }); + + let endpoint = server.ws_endpoint(&name); + let encrypt = Encrypt::new(&key); + let auth = WsClient::Authenticate(encrypt.zeros().into(), None); + let mut auth_buf = Vec::new(); + ciborium::ser::into_writer(&auth, &mut auth_buf)?; + let auth_msg = Message::Binary(auth_buf.into()); + + let mut clients = Vec::with_capacity(CLIENTS); + for _ in 0..CLIENTS { + let (mut ws, _resp) = tokio_tungstenite::connect_async(&endpoint).await?; + loop { + match ws.next().await.transpose()? { + Some(Message::Binary(msg)) => { + let parsed: WsServer = ciborium::de::from_reader(&*msg)?; + if matches!(parsed, WsServer::Hello(_, _)) { + break; + } + } + Some(_) => {} + None => anyhow::bail!("server closed connection during handshake"), + } + } + ws.send(auth_msg.clone()).await?; + clients.push(ws); + } + + sleep(Duration::from_millis(250)).await; + + for mut ws in clients { + ws.close(None).await.ok(); + } + + Ok(()) +}