From e633f0d478e6b544d4d419290f80a22190b39e96 Mon Sep 17 00:00:00 2001 From: sasicodes Date: Sat, 6 Jun 2026 19:56:42 +0530 Subject: [PATCH] Add websocket tunnel support --- Cargo.lock | 1 + crates/peek-client/src/lib.rs | 218 ++++++++++++++++++++++++- crates/peek-proto/src/lib.rs | 154 +++++++++++++++++ crates/peek-relay/Cargo.toml | 1 + crates/peek-relay/src/handler.rs | 174 +++++++++++++++++++- crates/peek-relay/src/registry.rs | 2 + crates/peek-relay/tests/integration.rs | 169 ++++++++++++++++++- 7 files changed, 707 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e0d720b..d091694 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1172,6 +1172,7 @@ dependencies = [ "sha2", "subtle", "tokio", + "tokio-tungstenite", "tower-http", "tracing", "tracing-subscriber", diff --git a/crates/peek-client/src/lib.rs b/crates/peek-client/src/lib.rs index 62d0be9..fd6c483 100644 --- a/crates/peek-client/src/lib.rs +++ b/crates/peek-client/src/lib.rs @@ -1,13 +1,13 @@ -use std::time::Duration; +use std::{collections::HashMap, sync::Arc, time::Duration}; mod pages; use futures_util::{SinkExt, StreamExt}; -use tokio::sync::mpsc; -use tokio_tungstenite::tungstenite::Message; +use tokio::sync::{Mutex, mpsc}; +use tokio_tungstenite::tungstenite::{Message, client::IntoClientRequest}; use tracing::{info, warn}; -use peek_proto::{RegistrationRequest, RegistrationResponse}; +use peek_proto::{RegistrationRequest, RegistrationResponse, WsFrame, WsMessageKind}; const HOP_BY_HOP_HEADERS: &[&str] = &[ "connection", @@ -174,6 +174,7 @@ fn spawn_connection_tasks( let write_tx_read = write_tx.clone(); let http_client = http_client.clone(); + let ws_streams = Arc::new(Mutex::new(HashMap::>::new())); tokio::spawn(async move { loop { let msg = match stream.next().await { @@ -189,8 +190,9 @@ fn spawn_connection_tasks( Message::Binary(data) => { let write_tx = write_tx_read.clone(); let http_client = http_client.clone(); + let ws_streams = ws_streams.clone(); tokio::spawn(async move { - handle_request(&data, port, &http_client, &write_tx).await; + handle_tunnel_frame(&data, port, &http_client, &write_tx, ws_streams).await; }); } Message::Ping(data) => { @@ -235,11 +237,12 @@ impl Drop for TunnelHandle { } } -async fn handle_request( +async fn handle_tunnel_frame( data: &[u8], port: u16, http_client: &reqwest::Client, write_tx: &mpsc::Sender, + ws_streams: Arc>>>, ) { let (request_id, payload) = match peek_proto::decode_frame(data) { Ok(r) => r, @@ -249,6 +252,209 @@ async fn handle_request( } }; + if peek_proto::is_ws_frame(payload) { + match peek_proto::deserialize_ws_frame(payload) { + Ok(WsFrame::Open { uri, headers }) => { + open_local_ws(request_id, port, uri, headers, write_tx.clone(), ws_streams).await; + } + Ok(frame) => { + let streams = ws_streams.lock().await; + if let Some(tx) = streams.get(&request_id) { + let _ = tx.send(frame).await; + } + } + Err(e) => { + warn!(error = %e, "failed to decode websocket frame"); + } + } + return; + } + + handle_request(request_id, payload, port, http_client, write_tx).await; +} + +async fn open_local_ws( + request_id: u32, + port: u16, + uri: String, + headers: Vec<(String, String)>, + write_tx: mpsc::Sender, + ws_streams: Arc>>>, +) { + let (incoming_tx, incoming_rx) = mpsc::channel::(256); + ws_streams.lock().await.insert(request_id, incoming_tx); + tokio::spawn(connect_and_bridge_local_ws( + request_id, + port, + uri, + headers, + incoming_rx, + write_tx, + ws_streams, + )); +} + +async fn connect_and_bridge_local_ws( + request_id: u32, + port: u16, + uri: String, + headers: Vec<(String, String)>, + incoming_rx: mpsc::Receiver, + write_tx: mpsc::Sender, + ws_streams: Arc>>>, +) { + let local_url = format!("ws://127.0.0.1:{port}{uri}"); + let request = match local_ws_request(&local_url, &headers) { + Ok(request) => request, + Err(e) => { + warn!(error = %e, url = %local_url, "failed to build local websocket request"); + ws_streams.lock().await.remove(&request_id); + send_ws_close(request_id, &write_tx).await; + return; + } + }; + let (local_ws, _) = match tokio_tungstenite::connect_async(request).await { + Ok(result) => result, + Err(e) => { + warn!(error = %e, url = %local_url, "failed to connect local websocket"); + ws_streams.lock().await.remove(&request_id); + send_ws_close(request_id, &write_tx).await; + return; + } + }; + + bridge_local_ws(request_id, local_ws, incoming_rx, write_tx, ws_streams).await; +} + +fn local_ws_request( + local_url: &str, + headers: &[(String, String)], +) -> Result { + let mut request = local_url + .into_client_request() + .map_err(|error| error.to_string())?; + for (name, value) in headers { + if is_ws_handshake_header(name) { + continue; + } + let Ok(name) = name.parse::() else { + continue; + }; + let Ok(value) = value.parse::() else { + continue; + }; + request.headers_mut().insert(name, value); + } + Ok(request) +} + +fn is_ws_handshake_header(name: &str) -> bool { + is_hop_by_hop_header(name) + || name.eq_ignore_ascii_case("sec-websocket-accept") + || name.eq_ignore_ascii_case("sec-websocket-key") + || name.eq_ignore_ascii_case("sec-websocket-version") +} + +async fn bridge_local_ws( + request_id: u32, + local_ws: tokio_tungstenite::WebSocketStream, + mut incoming_rx: mpsc::Receiver, + write_tx: mpsc::Sender, + ws_streams: Arc>>>, +) where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, +{ + let (mut local_sink, mut local_stream) = local_ws.split(); + loop { + tokio::select! { + msg = local_stream.next() => { + let Some(Ok(msg)) = msg else { + send_ws_close(request_id, &write_tx).await; + break; + }; + if !send_local_ws_message_to_tunnel(request_id, msg, &write_tx).await { + break; + } + } + frame = incoming_rx.recv() => { + let Some(frame) = frame else { + break; + }; + if !send_tunnel_ws_message_to_local(&mut local_sink, frame).await { + break; + } + } + } + } + + ws_streams.lock().await.remove(&request_id); +} + +async fn send_local_ws_message_to_tunnel( + request_id: u32, + msg: Message, + write_tx: &mpsc::Sender, +) -> bool { + let payload = match msg { + Message::Text(text) => { + peek_proto::serialize_ws_message(WsMessageKind::Text, text.as_str().as_bytes()) + } + Message::Binary(data) => peek_proto::serialize_ws_message(WsMessageKind::Binary, &data), + Message::Ping(data) => peek_proto::serialize_ws_message(WsMessageKind::Ping, &data), + Message::Pong(data) => peek_proto::serialize_ws_message(WsMessageKind::Pong, &data), + Message::Close(_) => peek_proto::serialize_ws_close(), + Message::Frame(_) => return true, + }; + write_tx + .send(Message::Binary( + peek_proto::encode_frame(request_id, &payload).into(), + )) + .await + .is_ok() +} + +async fn send_tunnel_ws_message_to_local( + local_sink: &mut futures_util::stream::SplitSink< + tokio_tungstenite::WebSocketStream, + Message, + >, + frame: WsFrame, +) -> bool +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + let msg = match frame { + WsFrame::Message { kind, data } => match kind { + WsMessageKind::Text => match String::from_utf8(data) { + Ok(text) => Message::Text(text.into()), + Err(_) => return false, + }, + WsMessageKind::Binary => Message::Binary(data.into()), + WsMessageKind::Ping => Message::Ping(data.into()), + WsMessageKind::Pong => Message::Pong(data.into()), + }, + WsFrame::Close => Message::Close(None), + WsFrame::Open { .. } => return true, + }; + local_sink.send(msg).await.is_ok() +} + +async fn send_ws_close(request_id: u32, write_tx: &mpsc::Sender) { + let payload = peek_proto::serialize_ws_close(); + let _ = write_tx + .send(Message::Binary( + peek_proto::encode_frame(request_id, &payload).into(), + )) + .await; +} + +async fn handle_request( + request_id: u32, + payload: &[u8], + port: u16, + http_client: &reqwest::Client, + write_tx: &mpsc::Sender, +) { let req = match peek_proto::deserialize_request(payload) { Ok(r) => r, Err(e) => { diff --git a/crates/peek-proto/src/lib.rs b/crates/peek-proto/src/lib.rs index e634600..2c143b1 100644 --- a/crates/peek-proto/src/lib.rs +++ b/crates/peek-proto/src/lib.rs @@ -9,6 +9,8 @@ pub enum ProtoError { InvalidHttp(String), #[error("JSON error: {0}")] Json(#[from] serde_json::Error), + #[error("invalid WebSocket frame: {0}")] + InvalidWebSocket(String), } #[must_use] @@ -108,6 +110,130 @@ pub struct DeserializedResponse { pub body: Vec, } +const WS_MAGIC: &[u8; 3] = b"WS1"; +const WS_OPEN: u8 = 1; +const WS_MESSAGE: u8 = 2; +const WS_CLOSE: u8 = 3; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum WsFrame { + Open { + uri: String, + headers: Vec<(String, String)>, + }, + Message { + kind: WsMessageKind, + data: Vec, + }, + Close, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WsMessageKind { + Text, + Binary, + Ping, + Pong, +} + +#[must_use] +pub fn is_ws_frame(data: &[u8]) -> bool { + data.starts_with(WS_MAGIC) +} + +#[must_use] +pub fn serialize_ws_open(uri: &str, headers: &[(String, String)]) -> Vec { + let mut buf = Vec::new(); + buf.extend_from_slice(WS_MAGIC); + buf.push(WS_OPEN); + buf.extend_from_slice(uri.as_bytes()); + buf.extend_from_slice(CRLF); + for (k, v) in headers { + buf.extend_from_slice(k.as_bytes()); + buf.extend_from_slice(b": "); + buf.extend_from_slice(v.as_bytes()); + buf.extend_from_slice(CRLF); + } + buf.extend_from_slice(CRLF); + buf +} + +#[must_use] +pub fn serialize_ws_message(kind: WsMessageKind, data: &[u8]) -> Vec { + let mut buf = Vec::with_capacity(5 + data.len()); + buf.extend_from_slice(WS_MAGIC); + buf.push(WS_MESSAGE); + buf.push(match kind { + WsMessageKind::Text => 1, + WsMessageKind::Binary => 2, + WsMessageKind::Ping => 3, + WsMessageKind::Pong => 4, + }); + buf.extend_from_slice(data); + buf +} + +#[must_use] +pub fn serialize_ws_close() -> Vec { + let mut buf = Vec::with_capacity(4); + buf.extend_from_slice(WS_MAGIC); + buf.push(WS_CLOSE); + buf +} + +pub fn deserialize_ws_frame(data: &[u8]) -> Result { + if !is_ws_frame(data) || data.len() < 4 { + return Err(ProtoError::InvalidWebSocket("missing prefix".into())); + } + + match data[3] { + WS_OPEN => deserialize_ws_open(&data[4..]), + WS_MESSAGE => deserialize_ws_message(&data[4..]), + WS_CLOSE => Ok(WsFrame::Close), + value => Err(ProtoError::InvalidWebSocket(format!( + "unknown frame type: {value}" + ))), + } +} + +fn deserialize_ws_open(data: &[u8]) -> Result { + let (head, _) = split_head_body(data); + let mut lines = head.split(|&b| b == b'\n'); + let uri = lines + .next() + .map(strip_cr) + .ok_or_else(|| ProtoError::InvalidWebSocket("missing uri".into()))?; + let uri = String::from_utf8_lossy(uri).to_string(); + if uri.is_empty() || !uri.starts_with('/') { + return Err(ProtoError::InvalidWebSocket("invalid uri".into())); + } + Ok(WsFrame::Open { + uri, + headers: parse_headers(lines), + }) +} + +fn deserialize_ws_message(data: &[u8]) -> Result { + if data.is_empty() { + return Err(ProtoError::InvalidWebSocket("missing message kind".into())); + } + let kind = match data[0] { + 1 => WsMessageKind::Text, + 2 => WsMessageKind::Binary, + 3 => WsMessageKind::Ping, + 4 => WsMessageKind::Pong, + value => { + return Err(ProtoError::InvalidWebSocket(format!( + "unknown message kind: {value}" + ))); + } + }; + Ok(WsFrame::Message { + kind, + data: data[1..].to_vec(), + }) +} + pub fn deserialize_response(data: &[u8]) -> Result { let (head, body) = split_head_body(data); let mut lines = head.split(|&b| b == b'\n'); @@ -285,6 +411,34 @@ mod tests { assert!(resp.body.is_empty()); } + #[test] + fn websocket_frames_roundtrip() { + let headers = vec![("x-test".into(), "one".into())]; + let open = serialize_ws_open("/ws?room=1", &headers); + assert!(is_ws_frame(&open)); + assert_eq!( + deserialize_ws_frame(&open).unwrap(), + WsFrame::Open { + uri: "/ws?room=1".into(), + headers + } + ); + + let msg = serialize_ws_message(WsMessageKind::Text, b"hello"); + assert_eq!( + deserialize_ws_frame(&msg).unwrap(), + WsFrame::Message { + kind: WsMessageKind::Text, + data: b"hello".to_vec() + } + ); + + assert_eq!( + deserialize_ws_frame(&serialize_ws_close()).unwrap(), + WsFrame::Close + ); + } + #[test] fn close_codes_permanent() { assert!(crate::close_codes::is_permanent(4000)); diff --git a/crates/peek-relay/Cargo.toml b/crates/peek-relay/Cargo.toml index 222dbdf..c693c1f 100644 --- a/crates/peek-relay/Cargo.toml +++ b/crates/peek-relay/Cargo.toml @@ -24,3 +24,4 @@ subtle = "2" peek-client = { path = "../peek-client" } reqwest = { version = "0.13", default-features = false } http-body-util = "0.1" +tokio-tungstenite = "0.29" diff --git a/crates/peek-relay/src/handler.rs b/crates/peek-relay/src/handler.rs index a0866f6..77f6626 100644 --- a/crates/peek-relay/src/handler.rs +++ b/crates/peek-relay/src/handler.rs @@ -6,7 +6,7 @@ use axum::{ body::Body, extract::{ ConnectInfo, Request, State, - ws::{Message, WebSocket, WebSocketUpgrade}, + ws::{Message, WebSocket, WebSocketUpgrade, rejection::WebSocketUpgradeRejection}, }, http::HeaderMap, response::Response, @@ -17,7 +17,7 @@ use subtle::ConstantTimeEq; use tokio::sync::{mpsc, oneshot}; use tracing::{error, info, warn}; -use peek_proto::{self, RegistrationRequest, RegistrationResponse}; +use peek_proto::{self, RegistrationRequest, RegistrationResponse, WsFrame, WsMessageKind}; use crate::{ pages, @@ -155,9 +155,23 @@ async fn handle_tunnel_client(socket: WebSocket, registry: Arc) { match msg { Message::Binary(data) => match peek_proto::decode_frame(&data) { Ok((request_id, payload)) => { - let mut pending = conn.pending.lock().await; - if let Some(tx) = pending.remove(&request_id) { - let _ = tx.send(payload.to_vec()); + if peek_proto::is_ws_frame(payload) { + match peek_proto::deserialize_ws_frame(payload) { + Ok(frame) => { + let streams = conn.ws_streams.lock().await; + if let Some(tx) = streams.get(&request_id) { + let _ = tx.send(frame).await; + } + } + Err(e) => { + warn!(subdomain = %subdomain, error = %e, "bad websocket frame from client"); + } + } + } else { + let mut pending = conn.pending.lock().await; + if let Some(tx) = pending.remove(&request_id) { + let _ = tx.send(payload.to_vec()); + } } } Err(e) => { @@ -177,6 +191,7 @@ async fn handle_tunnel_client(socket: WebSocket, registry: Arc) { drop(write_tx); let _ = writer_handle.await; conn.pending.lock().await.clear(); + conn.ws_streams.lock().await.clear(); } async fn read_registration( @@ -246,6 +261,7 @@ fn registration_error_json(message: &'static str) -> String { pub async fn public_handler( State(registry): State>, ConnectInfo(peer_addr): ConnectInfo, + ws: Result, mut request: Request, ) -> Response { let ip = extract_client_ip( @@ -277,6 +293,11 @@ pub async fn public_handler( Err(response) => return response, }; } + + if let Ok(ws) = ws { + return handle_public_ws_upgrade(ws, request, conn, registry.max_body_size).await; + } + let max_body_size = registry.max_body_size; let method = request.method().to_string(); let uri = request.uri().to_string(); @@ -354,6 +375,149 @@ pub async fn public_handler( } } +async fn handle_public_ws_upgrade( + ws: WebSocketUpgrade, + request: Request, + conn: Arc, + max_message_size: usize, +) -> Response { + { + let streams = conn.ws_streams.lock().await; + if streams.len() >= MAX_PENDING_PER_TUNNEL { + return service_unavailable_page(); + } + } + + let request_id = conn.next_id(); + let uri = request.uri().to_string(); + let protocols = requested_ws_protocols(request.headers()); + let headers: Vec<(String, String)> = request + .headers() + .iter() + .filter(|(k, _)| !is_hop_by_hop_header(k.as_str())) + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + let open = peek_proto::serialize_ws_open(&uri, &headers); + let frame = peek_proto::encode_frame(request_id, &open); + + let (incoming_tx, incoming_rx) = mpsc::channel::(256); + conn.ws_streams.lock().await.insert(request_id, incoming_tx); + + if conn + .write_tx + .send(Message::Binary(frame.into())) + .await + .is_err() + { + conn.ws_streams.lock().await.remove(&request_id); + return bad_gateway_page(); + } + + ws.protocols(protocols) + .max_frame_size(max_message_size) + .max_message_size(max_message_size) + .on_upgrade(move |socket| bridge_public_ws(socket, conn, request_id, incoming_rx)) +} + +fn requested_ws_protocols(headers: &HeaderMap) -> Vec { + headers + .get("sec-websocket-protocol") + .and_then(|value| value.to_str().ok()) + .map(|value| { + value + .split(',') + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(str::to_string) + .collect() + }) + .unwrap_or_default() +} + +async fn bridge_public_ws( + socket: WebSocket, + conn: Arc, + request_id: u32, + mut incoming_rx: mpsc::Receiver, +) { + let (mut public_sink, mut public_stream) = socket.split(); + loop { + tokio::select! { + msg = public_stream.next() => { + let Some(Ok(msg)) = msg else { + send_ws_close(&conn, request_id).await; + break; + }; + if !send_public_ws_message_to_tunnel(&conn, request_id, msg).await { + break; + } + } + frame = incoming_rx.recv() => { + let Some(frame) = frame else { + break; + }; + if !send_tunnel_ws_message_to_public(&mut public_sink, frame).await { + break; + } + } + } + } + + conn.ws_streams.lock().await.remove(&request_id); +} + +async fn send_public_ws_message_to_tunnel( + conn: &TunnelConnection, + request_id: u32, + msg: Message, +) -> bool { + let payload = match msg { + Message::Text(text) => { + peek_proto::serialize_ws_message(WsMessageKind::Text, text.as_str().as_bytes()) + } + Message::Binary(data) => peek_proto::serialize_ws_message(WsMessageKind::Binary, &data), + Message::Ping(data) => peek_proto::serialize_ws_message(WsMessageKind::Ping, &data), + Message::Pong(data) => peek_proto::serialize_ws_message(WsMessageKind::Pong, &data), + Message::Close(_) => peek_proto::serialize_ws_close(), + }; + conn.write_tx + .send(Message::Binary( + peek_proto::encode_frame(request_id, &payload).into(), + )) + .await + .is_ok() +} + +async fn send_tunnel_ws_message_to_public( + public_sink: &mut futures_util::stream::SplitSink, + frame: WsFrame, +) -> bool { + let msg = match frame { + WsFrame::Message { kind, data } => match kind { + WsMessageKind::Text => match String::from_utf8(data) { + Ok(text) => Message::Text(text.into()), + Err(_) => return false, + }, + WsMessageKind::Binary => Message::Binary(data.into()), + WsMessageKind::Ping => Message::Ping(data.into()), + WsMessageKind::Pong => Message::Pong(data.into()), + }, + WsFrame::Close => Message::Close(None), + WsFrame::Open { .. } => return true, + }; + public_sink.send(msg).await.is_ok() +} + +async fn send_ws_close(conn: &TunnelConnection, request_id: u32) { + let payload = peek_proto::serialize_ws_close(); + let _ = conn + .write_tx + .send(Message::Binary( + peek_proto::encode_frame(request_id, &payload).into(), + )) + .await; +} + async fn password_gate( request: Request, subdomain: &str, diff --git a/crates/peek-relay/src/registry.rs b/crates/peek-relay/src/registry.rs index d5535cf..fdbe60d 100644 --- a/crates/peek-relay/src/registry.rs +++ b/crates/peek-relay/src/registry.rs @@ -24,6 +24,7 @@ pub struct Registry { pub struct TunnelConnection { pub write_tx: mpsc::Sender, pub pending: Mutex>>>, + pub ws_streams: Mutex>>, pub password: Option, next_request_id: AtomicU32, } @@ -110,6 +111,7 @@ impl TunnelConnection { Self { write_tx, pending: Mutex::new(HashMap::new()), + ws_streams: Mutex::new(HashMap::new()), password, next_request_id: AtomicU32::new(1), } diff --git a/crates/peek-relay/tests/integration.rs b/crates/peek-relay/tests/integration.rs index db598d8..e243766 100644 --- a/crates/peek-relay/tests/integration.rs +++ b/crates/peek-relay/tests/integration.rs @@ -2,9 +2,16 @@ use std::fmt::Write as _; use std::net::SocketAddr; use std::time::Duration; -use axum::{Router, routing::get}; +use axum::{ + Router, + extract::ws::{Message, WebSocketUpgrade}, + http::HeaderMap, + routing::get, +}; +use futures_util::{SinkExt, StreamExt}; use peek_relay::{AppConfig, build_app}; use tokio::net::TcpListener; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; async fn start_relay(domain: &str) -> SocketAddr { let app = build_app(AppConfig { @@ -76,6 +83,48 @@ async fn start_local_server() -> (SocketAddr, &'static str) { tokio::time::sleep(Duration::from_secs(2)).await; "slow response" }), + ) + .route( + "/ws", + get(|ws: WebSocketUpgrade| async move { + ws.on_upgrade(|socket| async move { + let (mut sink, mut stream) = socket.split(); + while let Some(Ok(msg)) = stream.next().await { + match msg { + Message::Text(_) | Message::Binary(_) | Message::Ping(_) => { + if sink.send(msg).await.is_err() { + break; + } + } + Message::Close(_) => break, + Message::Pong(_) => {} + } + } + }) + }), + ) + .route( + "/ws-header", + get(|headers: HeaderMap, ws: WebSocketUpgrade| async move { + let header = headers + .get("x-ws-token") + .and_then(|value| value.to_str().ok()) + .unwrap_or("") + .to_string(); + ws.on_upgrade(|socket| async move { + let (mut sink, _) = socket.split(); + let _ = sink.send(Message::Text(header.into())).await; + }) + }), + ) + .route( + "/ws-protocol", + get(|ws: WebSocketUpgrade| async move { + ws.protocols(["chat"]).on_upgrade(|socket| async move { + let (mut sink, _) = socket.split(); + let _ = sink.send(Message::Text("protocol ok".into())).await; + }) + }), ); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -129,6 +178,124 @@ async fn test_tunnel_end_to_end() { handle.close().await; } +#[tokio::test] +async fn test_websocket_tunnel_selects_subprotocol() { + let (local_addr, _) = start_local_server().await; + let relay_addr = start_relay("test-ws-protocol.local").await; + + let client = peek_client::TunnelClient::new(&format!("ws://{relay_addr}/tunnel")).unwrap(); + let handle = client + .connect_with_subdomain(local_addr.port(), Some("testsubdomain".into())) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut request = format!("ws://{relay_addr}/ws-protocol") + .into_client_request() + .unwrap(); + request.headers_mut().insert( + "host", + "testsubdomain.test-ws-protocol.local".parse().unwrap(), + ); + request + .headers_mut() + .insert("sec-websocket-protocol", "chat".parse().unwrap()); + + let (mut socket, response) = tokio_tungstenite::connect_async(request).await.unwrap(); + assert_eq!( + response + .headers() + .get("sec-websocket-protocol") + .and_then(|value| value.to_str().ok()), + Some("chat") + ); + let msg = socket.next().await.unwrap().unwrap(); + assert_eq!(msg.into_text().unwrap(), "protocol ok"); + + handle.close().await; +} + +#[tokio::test] +async fn test_websocket_tunnel_echo() { + let (local_addr, expected_body) = start_local_server().await; + let relay_addr = start_relay("test-ws.local").await; + + let client = peek_client::TunnelClient::new(&format!("ws://{relay_addr}/tunnel")).unwrap(); + let handle = client + .connect_with_subdomain(local_addr.port(), Some("testsubdomain".into())) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + + let http_client = reqwest::Client::new(); + let resp = http_client + .get(format!("http://{relay_addr}/")) + .header("host", "testsubdomain.test-ws.local") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + assert_eq!(resp.text().await.unwrap(), expected_body); + + let mut request = format!("ws://{relay_addr}/ws") + .into_client_request() + .unwrap(); + request + .headers_mut() + .insert("host", "testsubdomain.test-ws.local".parse().unwrap()); + let (mut socket, _) = tokio_tungstenite::connect_async(request).await.unwrap(); + + socket + .send(tokio_tungstenite::tungstenite::Message::Text( + "hello".into(), + )) + .await + .unwrap(); + let msg = socket.next().await.unwrap().unwrap(); + assert_eq!(msg.into_text().unwrap(), "hello"); + + socket + .send(tokio_tungstenite::tungstenite::Message::Binary( + vec![1, 2, 3].into(), + )) + .await + .unwrap(); + let msg = socket.next().await.unwrap().unwrap(); + assert_eq!(msg.into_data(), vec![1, 2, 3]); + + handle.close().await; +} + +#[tokio::test] +async fn test_websocket_tunnel_forwards_headers() { + let (local_addr, _) = start_local_server().await; + let relay_addr = start_relay("test-ws-headers.local").await; + + let client = peek_client::TunnelClient::new(&format!("ws://{relay_addr}/tunnel")).unwrap(); + let handle = client + .connect_with_subdomain(local_addr.port(), Some("testsubdomain".into())) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut request = format!("ws://{relay_addr}/ws-header") + .into_client_request() + .unwrap(); + request.headers_mut().insert( + "host", + "testsubdomain.test-ws-headers.local".parse().unwrap(), + ); + request + .headers_mut() + .insert("x-ws-token", "secret".parse().unwrap()); + + let (mut socket, _) = tokio_tungstenite::connect_async(request).await.unwrap(); + let msg = socket.next().await.unwrap().unwrap(); + assert_eq!(msg.into_text().unwrap(), "secret"); + + handle.close().await; +} + #[tokio::test] async fn test_tunnel_not_found() { let relay_addr = start_relay("test2.local").await;