Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

218 changes: 212 additions & 6 deletions crates/peek-client/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::time::Duration;
use std::{collections::HashMap, sync::Arc, time::Duration};

mod pages;

use futures_util::{SinkExt, StreamExt};
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;
use tokio::sync::{Mutex, mpsc};
use tokio_tungstenite::tungstenite::{Message, client::IntoClientRequest};
use tracing::{info, warn};

use peek_proto::{RegistrationRequest, RegistrationResponse};
use peek_proto::{RegistrationRequest, RegistrationResponse, WsFrame, WsMessageKind};

const HOP_BY_HOP_HEADERS: &[&str] = &[
"connection",
Expand Down Expand Up @@ -174,6 +174,7 @@ fn spawn_connection_tasks<S, K>(

let write_tx_read = write_tx.clone();
let http_client = http_client.clone();
let ws_streams = Arc::new(Mutex::new(HashMap::<u32, mpsc::Sender<WsFrame>>::new()));
tokio::spawn(async move {
loop {
let msg = match stream.next().await {
Expand All @@ -189,8 +190,9 @@ fn spawn_connection_tasks<S, K>(
Message::Binary(data) => {
let write_tx = write_tx_read.clone();
let http_client = http_client.clone();
let ws_streams = ws_streams.clone();
tokio::spawn(async move {
handle_request(&data, port, &http_client, &write_tx).await;
handle_tunnel_frame(&data, port, &http_client, &write_tx, ws_streams).await;
});
}
Message::Ping(data) => {
Expand Down Expand Up @@ -235,11 +237,12 @@ impl Drop for TunnelHandle {
}
}

async fn handle_request(
async fn handle_tunnel_frame(
data: &[u8],
port: u16,
http_client: &reqwest::Client,
write_tx: &mpsc::Sender<Message>,
ws_streams: Arc<Mutex<HashMap<u32, mpsc::Sender<WsFrame>>>>,
) {
let (request_id, payload) = match peek_proto::decode_frame(data) {
Ok(r) => r,
Expand All @@ -249,6 +252,209 @@ async fn handle_request(
}
};

if peek_proto::is_ws_frame(payload) {
match peek_proto::deserialize_ws_frame(payload) {
Ok(WsFrame::Open { uri, headers }) => {
open_local_ws(request_id, port, uri, headers, write_tx.clone(), ws_streams).await;
}
Ok(frame) => {
let streams = ws_streams.lock().await;
if let Some(tx) = streams.get(&request_id) {
let _ = tx.send(frame).await;
}
}
Err(e) => {
warn!(error = %e, "failed to decode websocket frame");
}
}
return;
}

handle_request(request_id, payload, port, http_client, write_tx).await;
}

async fn open_local_ws(
request_id: u32,
port: u16,
uri: String,
headers: Vec<(String, String)>,
write_tx: mpsc::Sender<Message>,
ws_streams: Arc<Mutex<HashMap<u32, mpsc::Sender<WsFrame>>>>,
) {
let (incoming_tx, incoming_rx) = mpsc::channel::<WsFrame>(256);
ws_streams.lock().await.insert(request_id, incoming_tx);
tokio::spawn(connect_and_bridge_local_ws(
request_id,
port,
uri,
headers,
incoming_rx,
write_tx,
ws_streams,
));
}

async fn connect_and_bridge_local_ws(
request_id: u32,
port: u16,
uri: String,
headers: Vec<(String, String)>,
incoming_rx: mpsc::Receiver<WsFrame>,
write_tx: mpsc::Sender<Message>,
ws_streams: Arc<Mutex<HashMap<u32, mpsc::Sender<WsFrame>>>>,
) {
let local_url = format!("ws://127.0.0.1:{port}{uri}");
let request = match local_ws_request(&local_url, &headers) {
Ok(request) => request,
Err(e) => {
warn!(error = %e, url = %local_url, "failed to build local websocket request");
ws_streams.lock().await.remove(&request_id);
send_ws_close(request_id, &write_tx).await;
return;
}
};
let (local_ws, _) = match tokio_tungstenite::connect_async(request).await {
Ok(result) => result,
Err(e) => {
warn!(error = %e, url = %local_url, "failed to connect local websocket");
ws_streams.lock().await.remove(&request_id);
send_ws_close(request_id, &write_tx).await;
return;
}
};

bridge_local_ws(request_id, local_ws, incoming_rx, write_tx, ws_streams).await;
}

fn local_ws_request(
local_url: &str,
headers: &[(String, String)],
) -> Result<tokio_tungstenite::tungstenite::handshake::client::Request, String> {
let mut request = local_url
.into_client_request()
.map_err(|error| error.to_string())?;
for (name, value) in headers {
if is_ws_handshake_header(name) {
continue;
}
let Ok(name) = name.parse::<tokio_tungstenite::tungstenite::http::HeaderName>() else {
continue;
};
let Ok(value) = value.parse::<tokio_tungstenite::tungstenite::http::HeaderValue>() else {
continue;
};
request.headers_mut().insert(name, value);
}
Ok(request)
}

fn is_ws_handshake_header(name: &str) -> bool {
is_hop_by_hop_header(name)
|| name.eq_ignore_ascii_case("sec-websocket-accept")
|| name.eq_ignore_ascii_case("sec-websocket-key")
|| name.eq_ignore_ascii_case("sec-websocket-version")
}

async fn bridge_local_ws<S>(
request_id: u32,
local_ws: tokio_tungstenite::WebSocketStream<S>,
mut incoming_rx: mpsc::Receiver<WsFrame>,
write_tx: mpsc::Sender<Message>,
ws_streams: Arc<Mutex<HashMap<u32, mpsc::Sender<WsFrame>>>>,
) where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
let (mut local_sink, mut local_stream) = local_ws.split();
loop {
tokio::select! {
msg = local_stream.next() => {
let Some(Ok(msg)) = msg else {
send_ws_close(request_id, &write_tx).await;
break;
};
if !send_local_ws_message_to_tunnel(request_id, msg, &write_tx).await {
break;
}
}
frame = incoming_rx.recv() => {
let Some(frame) = frame else {
break;
};
if !send_tunnel_ws_message_to_local(&mut local_sink, frame).await {
break;
}
}
}
}

ws_streams.lock().await.remove(&request_id);
}

async fn send_local_ws_message_to_tunnel(
request_id: u32,
msg: Message,
write_tx: &mpsc::Sender<Message>,
) -> bool {
let payload = match msg {
Message::Text(text) => {
peek_proto::serialize_ws_message(WsMessageKind::Text, text.as_str().as_bytes())
}
Message::Binary(data) => peek_proto::serialize_ws_message(WsMessageKind::Binary, &data),
Message::Ping(data) => peek_proto::serialize_ws_message(WsMessageKind::Ping, &data),
Message::Pong(data) => peek_proto::serialize_ws_message(WsMessageKind::Pong, &data),
Message::Close(_) => peek_proto::serialize_ws_close(),
Message::Frame(_) => return true,
};
write_tx
.send(Message::Binary(
peek_proto::encode_frame(request_id, &payload).into(),
))
.await
.is_ok()
}

async fn send_tunnel_ws_message_to_local<S>(
local_sink: &mut futures_util::stream::SplitSink<
tokio_tungstenite::WebSocketStream<S>,
Message,
>,
frame: WsFrame,
) -> bool
where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
let msg = match frame {
WsFrame::Message { kind, data } => match kind {
WsMessageKind::Text => match String::from_utf8(data) {
Ok(text) => Message::Text(text.into()),
Err(_) => return false,
},
WsMessageKind::Binary => Message::Binary(data.into()),
WsMessageKind::Ping => Message::Ping(data.into()),
WsMessageKind::Pong => Message::Pong(data.into()),
},
WsFrame::Close => Message::Close(None),
WsFrame::Open { .. } => return true,
};
local_sink.send(msg).await.is_ok()
}

async fn send_ws_close(request_id: u32, write_tx: &mpsc::Sender<Message>) {
let payload = peek_proto::serialize_ws_close();
let _ = write_tx
.send(Message::Binary(
peek_proto::encode_frame(request_id, &payload).into(),
))
.await;
}

async fn handle_request(
request_id: u32,
payload: &[u8],
port: u16,
http_client: &reqwest::Client,
write_tx: &mpsc::Sender<Message>,
) {
let req = match peek_proto::deserialize_request(payload) {
Ok(r) => r,
Err(e) => {
Expand Down
Loading