Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ on:
branches:
- main

env:
CARGO_TERM_COLOR: always

jobs:
rustfmt:
name: Rust format
Expand All @@ -33,10 +36,22 @@ jobs:

- uses: Swatinem/rust-cache@v2

- run: cargo build -p sshx-server --release

- run: cargo build -p sshx --release

- run: cargo test

- run: cargo clippy --all-targets -- -D warnings

- uses: actions/upload-artifact@v4
with:
name: linux-rust-binaries
path: |
target/release/sshx
target/release/sshx-server
if-no-files-found: error

windows_test:
name: Client test (Windows)
runs-on: windows-latest
Expand All @@ -50,8 +65,16 @@ jobs:

- uses: Swatinem/rust-cache@v2

- run: cargo build -p sshx --release

- run: cargo test -p sshx

- uses: actions/upload-artifact@v4
with:
name: windows-sshx
path: target/release/sshx.exe
if-no-files-found: error

web:
name: Web lint, check, and build
runs-on: ubuntu-latest
Expand All @@ -62,6 +85,7 @@ jobs:
- uses: actions/setup-node@v4
with:
node-version: "18"
cache: npm

- run: npm ci

Expand All @@ -71,6 +95,12 @@ jobs:

- run: npm run build

- uses: actions/upload-artifact@v4
with:
name: web-build
path: build
if-no-files-found: error

deploy:
name: Deploy
runs-on: ubuntu-latest
Expand Down
71 changes: 67 additions & 4 deletions crates/sshx-server/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod snapshot;

/// Store a rolling buffer with at most this quantity of output, per shell.
const SHELL_STORED_BYTES: u64 = 1 << 21; // 2 MiB
const SHELL_CHUNK_BROADCAST_CAPACITY: usize = 256;

/// Static metadata for this session.
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -80,7 +81,7 @@ pub struct Session {
}

/// Internal state for each shell.
#[derive(Default, Debug)]
#[derive(Debug)]
struct State {
/// Sequence number, indicating how many bytes have been received.
seqnum: u64,
Expand All @@ -99,6 +100,29 @@ struct State {

/// Updated when any of the above fields change.
notify: Arc<Notify>,

chunks: broadcast::Sender<Bytes>,
}

impl State {
fn new() -> Self {
let (chunks, _rx) = broadcast::channel(SHELL_CHUNK_BROADCAST_CAPACITY);
State {
seqnum: 0,
data: Vec::new(),
chunk_offset: 0,
byte_offset: 0,
closed: false,
notify: Arc::new(Notify::new()),
chunks,
}
}
}

fn encode_wsserver(msg: &WsServer) -> Result<Bytes> {
let mut buf = Vec::new();
ciborium::ser::into_writer(msg, &mut buf)?;
Ok(Bytes::from(buf))
}

impl Session {
Expand All @@ -113,7 +137,7 @@ impl Session {
counter: IdCounter::default(),
last_accessed: Mutex::new(now),
source: watch::channel(Vec::new()).0,
broadcast: broadcast::channel(64).0,
broadcast: broadcast::channel(1024).0,
update_tx,
update_rx,
sync_notify: Notify::new(),
Expand Down Expand Up @@ -155,6 +179,39 @@ impl Session {
WatchStream::new(self.source.subscribe())
}

#[allow(missing_docs)]
pub fn list_shells(&self) -> Vec<(Sid, WsWinsize)> {
self.source.borrow().clone()
}

#[allow(missing_docs)]
pub fn init_chunk_subscription(
&self,
id: Sid,
chunknum: u64,
) -> Option<(broadcast::Receiver<Bytes>, u64, Vec<Bytes>, u64)> {
let shells = self.shells.read();
let shell = shells.get(&id)?;
if shell.closed {
return None;
}

let rx = shell.chunks.subscribe();
let mut seqnum = shell.byte_offset;
let baseline_chunks = shell.chunk_offset + shell.data.len() as u64;
if chunknum < baseline_chunks {
let start = chunknum.saturating_sub(shell.chunk_offset) as usize;
seqnum += shell.data[..start]
.iter()
.map(|x| x.len() as u64)
.sum::<u64>();
let chunks = shell.data[start..].to_vec();
Some((rx, seqnum, chunks, baseline_chunks))
} else {
Some((rx, shell.seqnum, Vec::new(), baseline_chunks))
}
}

/// Subscribe for chunks from a shell, until it is closed.
pub fn subscribe_chunks(
&self,
Expand Down Expand Up @@ -201,7 +258,7 @@ impl Session {
use std::collections::hash_map::Entry::*;
let _guard = match self.shells.write().entry(id) {
Occupied(_) => bail!("shell already exists with id={id}"),
Vacant(v) => v.insert(State::default()),
Vacant(v) => v.insert(State::new()),
};
self.source.send_modify(|source| {
let winsize = WsWinsize {
Expand Down Expand Up @@ -263,8 +320,9 @@ impl Session {
let start = shell.seqnum - seq;
let segment = data.slice(start as usize..);
debug!(%id, bytes = segment.len(), "adding data to shell");
let seqnum = shell.seqnum;
shell.seqnum += segment.len() as u64;
shell.data.push(segment);
shell.data.push(Bytes::clone(&segment));

// Prune old chunks if we've exceeded the maximum stored bytes.
let mut stored_bytes = shell.seqnum - shell.byte_offset;
Expand All @@ -280,6 +338,11 @@ impl Session {
shell.data.drain(..offset);
}

let msg = WsServer::Chunks(id, seqnum, vec![segment]);
if let Ok(encoded) = encode_wsserver(&msg) {
shell.chunks.send(encoded).ok();
}

shell.notify.notify_waiters();
}

Expand Down
16 changes: 7 additions & 9 deletions crates/sshx-server/src/session/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,13 @@ impl Session {
cols: shell.winsize_cols.try_into().context("cols overflow")?,
},
));
let shell = State {
seqnum: shell.seqnum,
data: shell.data,
chunk_offset: shell.chunk_offset,
byte_offset: shell.byte_offset,
closed: shell.closed,
notify: Default::default(),
};
shells.insert(Sid(sid), shell);
let mut state = State::new();
state.seqnum = shell.seqnum;
state.data = shell.data;
state.chunk_offset = shell.chunk_offset;
state.byte_offset = shell.byte_offset;
state.closed = shell.closed;
shells.insert(Sid(sid), state);
}
drop(shells);
session.source.send_replace(winsizes);
Expand Down
82 changes: 69 additions & 13 deletions crates/sshx-server/src/web/socket.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashSet;
use std::sync::Arc;

use anyhow::{Context, Result};
use anyhow::Result;
use axum::extract::{
ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade},
Path, State,
Expand All @@ -10,9 +10,10 @@ use axum::response::IntoResponse;
use bytes::Bytes;
use futures_util::SinkExt;
use sshx_core::proto::{server_update::ServerMessage, NewShell, TerminalInput, TerminalSize};
use sshx_core::Sid;
use subtle::ConstantTimeEq;
use tokio::sync::mpsc;
use tokio::time::{Duration, Instant};
use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
use tokio_stream::StreamExt;
use tracing::{error, info_span, warn, Instrument};

Expand Down Expand Up @@ -79,6 +80,11 @@ async fn handle_socket(socket: &mut WebSocket, session: Arc<Session>) -> Result<
Ok(())
}

async fn send_bytes(socket: &mut WebSocket, bytes: Bytes) -> Result<()> {
socket.send(Message::Binary(bytes)).await?;
Ok(())
}

/// Receive a message from the client over WebSocket.
async fn recv(socket: &mut WebSocket) -> Result<Option<WsClient>> {
Ok(loop {
Expand Down Expand Up @@ -134,23 +140,29 @@ async fn handle_socket(socket: &mut WebSocket, session: Arc<Session>) -> Result<
send(socket, WsServer::Users(session.list_users())).await?;

let mut subscribed = HashSet::new(); // prevent duplicate subscriptions
let (chunks_tx, mut chunks_rx) = mpsc::channel::<(Sid, u64, Vec<Bytes>)>(1);
let (chunks_tx, mut chunks_rx) = mpsc::channel::<Bytes>(32);
let mut last_cursor_update = Instant::now() - Duration::from_secs(1);

let mut shells_stream = session.subscribe_shells();
loop {
let msg = tokio::select! {
_ = session.terminated() => break,
Some(result) = broadcast_stream.next() => {
let msg = result.context("client fell behind on broadcast stream")?;
send(socket, msg).await?;
match result {
Ok(msg) => send(socket, msg).await?,
Err(BroadcastStreamRecvError::Lagged(_)) => {
send(socket, WsServer::Users(session.list_users())).await?;
send(socket, WsServer::Shells(session.list_shells())).await?;
}
}
continue;
}
Some(shells) = shells_stream.next() => {
send(socket, WsServer::Shells(shells)).await?;
continue;
}
Some((id, seqnum, chunks)) = chunks_rx.recv() => {
send(socket, WsServer::Chunks(id, seqnum, chunks)).await?;
Some(bytes) = chunks_rx.recv() => {
send_bytes(socket, bytes).await?;
continue;
}
result = recv(socket) => {
Expand All @@ -169,7 +181,11 @@ async fn handle_socket(socket: &mut WebSocket, session: Arc<Session>) -> Result<
}
}
WsClient::SetCursor(cursor) => {
session.update_user(user_id, |user| user.cursor = cursor)?;
let now = Instant::now();
if now.duration_since(last_cursor_update) >= Duration::from_millis(33) {
last_cursor_update = now;
session.update_user(user_id, |user| user.cursor = cursor)?;
}
}
WsClient::SetFocus(id) => {
session.update_user(user_id, |user| user.focus = id)?;
Expand Down Expand Up @@ -231,11 +247,51 @@ async fn handle_socket(socket: &mut WebSocket, session: Arc<Session>) -> Result<
let session = Arc::clone(&session);
let chunks_tx = chunks_tx.clone();
tokio::spawn(async move {
let stream = session.subscribe_chunks(id, chunknum);
tokio::pin!(stream);
while let Some((seqnum, chunks)) = stream.next().await {
if chunks_tx.send((id, seqnum, chunks)).await.is_err() {
break;
let Some((rx, seqnum, chunks, baseline_chunks)) =
session.init_chunk_subscription(id, chunknum)
else {
return;
};
let mut next_chunknum = baseline_chunks;
if !chunks.is_empty() {
let msg = WsServer::Chunks(id, seqnum, chunks);
let mut buf = Vec::new();
if ciborium::ser::into_writer(&msg, &mut buf).is_err() {
return;
}
if chunks_tx.send(Bytes::from(buf)).await.is_err() {
return;
}
}

let mut stream = BroadcastStream::new(rx);
while let Some(item) = stream.next().await {
match item {
Ok(bytes) => {
if chunks_tx.send(bytes).await.is_err() {
return;
}
next_chunknum += 1;
}
Err(BroadcastStreamRecvError::Lagged(_)) => loop {
let Some((_rx, seqnum, chunks, baseline_chunks)) =
session.init_chunk_subscription(id, next_chunknum)
else {
return;
};
next_chunknum = baseline_chunks;
if chunks.is_empty() {
break;
}
let msg = WsServer::Chunks(id, seqnum, chunks);
let mut buf = Vec::new();
if ciborium::ser::into_writer(&msg, &mut buf).is_err() {
return;
}
if chunks_tx.send(Bytes::from(buf)).await.is_err() {
return;
}
},
}
}
});
Expand Down
Loading
Loading