Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions crates/sshx-core/proto/sshx.proto
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ message ClientUpdate {
uint32 closed_shell = 4; // Acknowledge that a shell was closed.
fixed64 pong = 14; // Response for latency measurement.
string error = 15;
FileListResponse file_list = 16;
FileChunkResponse file_chunk = 17;
FileDeletedResponse file_deleted = 18;
FileRenamedResponse file_renamed = 19;
}
}

Expand All @@ -85,6 +89,11 @@ message ServerUpdate {
TerminalSize resize = 5; // Resize a terminal window.
fixed64 ping = 14; // Request a pong, with the timestamp.
string error = 15;
FileListRequest list_files = 16;
FileDownloadRequest download_file = 17;
FileUploadRequest upload_file = 18;
FileDeleteRequest delete_file = 19;
FileRenameRequest rename_file = 20;
}
}

Expand Down Expand Up @@ -118,3 +127,63 @@ message SerializedShell {
uint32 winsize_rows = 8;
uint32 winsize_cols = 9;
}

message FileEntry {
string name = 1;
bool is_dir = 2;
uint64 size = 3;
int64 modified = 4;
}

message FileListRequest {
uint32 id = 1;
string path = 2;
}

message FileDownloadRequest {
uint32 id = 1;
string path = 2;
}

message FileUploadRequest {
uint32 id = 1;
string path = 2;
bytes data = 3;
bool done = 4;
}

message FileDeleteRequest {
uint32 id = 1;
string path = 2;
}

message FileRenameRequest {
uint32 id = 1;
string path = 2;
string new_name = 3;
}

message FileListResponse {
uint32 id = 1;
repeated FileEntry entries = 2;
string path = 3;
}

message FileChunkResponse {
uint32 id = 1;
bytes data = 2;
bool done = 3;
uint64 size = 4;
}

message FileDeletedResponse {
uint32 id = 1;
string error = 2;
string path = 3;
}

message FileRenamedResponse {
uint32 id = 1;
string new_path = 2;
string error = 3;
}
27 changes: 27 additions & 0 deletions crates/sshx-server/src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use tonic::{Request, Response, Status, Streaming};
use tracing::{error, info, warn};

use crate::session::{Metadata, Session};
use crate::web::protocol::WsFileEntry;
use crate::ServerState;

/// Interval for synchronizing sequence numbers with the client.
Expand Down Expand Up @@ -218,6 +219,32 @@ async fn handle_update(tx: &ServerTx, session: &Session, update: ClientUpdate) -
// TODO: Propagate these errors to listeners on the web interface?
error!(?err, "error received from client");
}
Some(ClientMessage::FileList(resp)) => {
let entries: Vec<WsFileEntry> = resp.entries.into_iter().map(|e| WsFileEntry {
name: e.name,
is_dir: e.is_dir,
size: e.size,
modified: e.modified,
}).collect();
session.send_file_list(resp.path, entries);
}
Some(ClientMessage::FileChunk(resp)) => {
session.forward_chunk(resp.id, resp.data, resp.done);
}
Some(ClientMessage::FileDeleted(resp)) => {
if resp.error.is_empty() {
session.send_file_changed(resp.path, "deleted".into());
} else {
session.send_file_error(resp.path, resp.error);
}
}
Some(ClientMessage::FileRenamed(resp)) => {
if resp.error.is_empty() {
session.send_file_changed(resp.new_path, "renamed".into());
} else {
session.send_file_error(resp.new_path, resp.error);
}
}
None => (), // Heartbeat message, ignored.
}
true
Expand Down
38 changes: 37 additions & 1 deletion crates/sshx-server/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::Arc;
use anyhow::{bail, Context, Result};
use bytes::Bytes;
use parking_lot::{Mutex, RwLock, RwLockWriteGuard};
use tokio::sync::mpsc;
use sshx_core::{
proto::{server_update::ServerMessage, SequenceNumbers},
IdCounter, Sid, Uid,
Expand All @@ -18,7 +19,7 @@ use tokio_stream::Stream;
use tracing::{debug, warn};

use crate::utils::Shutdown;
use crate::web::protocol::{WsServer, WsUser, WsWinsize};
use crate::web::protocol::{WsFileEntry, WsServer, WsUser, WsWinsize};

mod snapshot;

Expand Down Expand Up @@ -77,6 +78,9 @@ pub struct Session {

/// Set when this session has been closed and removed.
shutdown: Shutdown,

/// Pending download channels keyed by file operation ID.
pending_downloads: Mutex<HashMap<u32, mpsc::Sender<Bytes>>>,
}

/// Internal state for each shell.
Expand Down Expand Up @@ -118,6 +122,7 @@ impl Session {
update_rx,
sync_notify: Notify::new(),
shutdown: Shutdown::new(),
pending_downloads: Mutex::new(HashMap::new()),
}
}

Expand Down Expand Up @@ -373,6 +378,37 @@ impl Session {
self.broadcast.send(WsServer::ShellLatency(latency)).ok();
}

/// Send a file listing to all WebSocket clients.
pub fn send_file_list(&self, path: String, entries: Vec<WsFileEntry>) {
self.broadcast.send(WsServer::FileList(path, entries)).ok();
}

/// Send a file change notification to all WebSocket clients.
pub fn send_file_changed(&self, path: String, kind: String) {
self.broadcast.send(WsServer::FileChanged(path, kind)).ok();
}

/// Send a file operation error to all WebSocket clients.
pub fn send_file_error(&self, path: String, error: String) {
self.broadcast.send(WsServer::FileError(path, error)).ok();
}

/// Register a channel to receive download chunks for a file operation.
pub fn register_download(&self, id: u32, tx: mpsc::Sender<Bytes>) {
self.pending_downloads.lock().insert(id, tx);
}

/// Forward a file chunk to the registered download channel.
pub fn forward_chunk(&self, id: u32, data: Bytes, done: bool) {
let mut pending = self.pending_downloads.lock();
if let Some(tx) = pending.get(&id) {
let removed = tx.try_send(data).is_err() || done;
if removed {
pending.remove(&id);
}
}
}

/// Register a backend client heartbeat, refreshing the timestamp.
pub fn access(&self) {
*self.last_accessed.lock() = Instant::now();
Expand Down
10 changes: 8 additions & 2 deletions crates/sshx-server/src/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

use std::sync::Arc;

use axum::routing::{any, get_service};
use axum::routing::{any, get, get_service};
use axum::Router;
use tower_http::services::{ServeDir, ServeFile};

use crate::ServerState;

pub mod files;
pub mod protocol;
mod socket;

Expand All @@ -30,5 +31,10 @@ pub fn app() -> Router<Arc<ServerState>> {

/// Routes for the backend web API server.
fn backend() -> Router<Arc<ServerState>> {
Router::new().route("/s/{name}", any(socket::get_session_ws))
Router::new()
.route("/s/{name}", any(socket::get_session_ws))
.route(
"/s/{name}/files",
get(files::download_file).post(files::upload_file),
)
}
140 changes: 140 additions & 0 deletions crates/sshx-server/src/web/files.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
//! HTTP file transfer endpoints for downloading and uploading files
//! through a session's CLI connection.

use std::sync::Arc;

use axum::body::Body;
use axum::extract::{Path, Query, State};
use axum::http::{header, HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use base64::prelude::{Engine as _, BASE64_STANDARD};
use bytes::Bytes;
use serde::Deserialize;
use subtle::ConstantTimeEq;

use sshx_core::proto::server_update::ServerMessage;
use sshx_core::proto::{FileDownloadRequest, FileUploadRequest};

use crate::ServerState;

/// Query parameters for file transfer endpoints.
#[derive(Deserialize)]
pub struct FileParams {
path: String,
}

/// Validates the session authentication token from the `X-SSHX-Key` header.
async fn validate_auth(
state: &ServerState,
name: &str,
headers: &HeaderMap,
) -> Result<Arc<crate::session::Session>, StatusCode> {
let session = state.lookup(name).ok_or(StatusCode::NOT_FOUND)?;

let key = headers
.get("X-SSHX-Key")
.and_then(|v| v.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;

let key_bytes = BASE64_STANDARD
.decode(key)
.map_err(|_| StatusCode::UNAUTHORIZED)?;

if !bool::from(session.metadata().encrypted_zeros.ct_eq(&key_bytes)) {
return Err(StatusCode::UNAUTHORIZED);
}

Ok(session)
}

/// HTTP GET handler that streams a file from the CLI via gRPC.
pub async fn download_file(
State(state): State<Arc<ServerState>>,
Path(name): Path<String>,
Query(params): Query<FileParams>,
headers: HeaderMap,
) -> Response {
let session = match validate_auth(&state, &name, &headers).await {
Ok(s) => s,
Err(code) => return (code, "Unauthorized").into_response(),
};

if params.path.is_empty() {
return (StatusCode::BAD_REQUEST, "path is required").into_response();
}

let file_id = session.counter().next_sid().0;
let filename = std::path::Path::new(&params.path)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("download");

let (chunk_tx, chunk_rx) = tokio::sync::mpsc::channel::<Bytes>(32);
session.register_download(file_id, chunk_tx);
let update_tx = session.update_tx().clone();
let path = params.path.clone();
tokio::spawn(async move {
let msg = ServerMessage::DownloadFile(FileDownloadRequest { id: file_id, path });
let _ = update_tx.send(msg).await;
});

let stream = async_stream::stream! {
let mut rx = chunk_rx;
while let Some(chunk) = rx.recv().await {
if chunk.is_empty() {
break;
}
yield Ok::<Bytes, std::convert::Infallible>(chunk);
}
};

Response::builder()
.header(header::CONTENT_TYPE, "application/octet-stream")
.header(
header::CONTENT_DISPOSITION,
format!("attachment; filename=\"{}\"", filename),
)
.body(Body::from_stream(stream))
.unwrap()
}

/// HTTP POST handler that uploads a file to the CLI via gRPC.
pub async fn upload_file(
State(state): State<Arc<ServerState>>,
Path(name): Path<String>,
Query(params): Query<FileParams>,
headers: HeaderMap,
body: Bytes,
) -> Response {
let session = match validate_auth(&state, &name, &headers).await {
Ok(s) => s,
Err(code) => return (code, "Unauthorized").into_response(),
};

if params.path.is_empty() {
return (StatusCode::BAD_REQUEST, "path is required").into_response();
}

let file_id = session.counter().next_sid().0;
let update_tx = session.update_tx().clone();
const CHUNK_SIZE: usize = 1 << 16; // 64 KiB
let path = params.path.clone();

tokio::spawn(async move {
let total = body.len();
for (i, chunk) in body.chunks(CHUNK_SIZE).enumerate() {
let is_last = i * CHUNK_SIZE + chunk.len() >= total;
let msg = ServerMessage::UploadFile(FileUploadRequest {
id: file_id,
path: path.clone(),
data: chunk.to_vec().into(),
done: is_last,
});
if update_tx.send(msg).await.is_err() {
break;
}
}
});

StatusCode::OK.into_response()
}
Loading
Loading