diff --git a/Cargo.lock b/Cargo.lock index fd9da7e..8d7112d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -152,6 +152,7 @@ dependencies = [ "serde", "serde_yaml", "simple_logger", + "socket2 0.5.10", "tokio", "uuid", ] @@ -345,6 +346,16 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "socket2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "socket2" version = "0.6.2" @@ -411,7 +422,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.6.2", "tokio-macros", "windows-sys 0.61.2", ] @@ -516,13 +527,22 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" dependencies = [ - "windows-targets", + "windows-targets 0.53.5", ] [[package]] @@ -534,6 +554,22 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + [[package]] name = "windows-targets" version = "0.53.5" @@ -541,58 +577,106 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ "windows-link", - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + [[package]] name = "windows_aarch64_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + [[package]] name = "windows_aarch64_msvc" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + [[package]] name = "windows_i686_gnu" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + [[package]] name = "windows_i686_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + [[package]] name = "windows_i686_msvc" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + [[package]] name = "windows_x86_64_gnu" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + [[package]] name = "windows_x86_64_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "windows_x86_64_msvc" version = "0.53.1" diff --git a/minecraft/src/serialization.rs b/minecraft/src/serialization.rs index 2a5a0b3..ed10977 100644 --- a/minecraft/src/serialization.rs +++ b/minecraft/src/serialization.rs @@ -9,6 +9,19 @@ use crate::{ const SEGMENT_BITS: u32 = 0x7F; const CONTINUE_BIT: u32 = 0x80; +/// Returns the number of bytes a VarInt value occupies when encoded. +fn varint_size(value: i32) -> usize { + let mut value = value as u32; + let mut size = 0; + loop { + size += 1; + if (value & !SEGMENT_BITS) == 0 { + return size; + } + value >>= 7; + } +} + #[derive(Debug, PartialEq)] pub enum ReadingError { Insufficient, @@ -93,7 +106,7 @@ impl MinecraftStream { } pub fn data_len(&self) -> usize { - self.free - self.position + 1 + self.free - self.position } pub fn take_buffer(&mut self) -> Vec { @@ -160,8 +173,11 @@ impl MinecraftStream { where T: PacketDeserializer, { - if signature.length > self.data_len() { - match &self.fill_buffer_from_source(signature.length).await { + // signature.length includes the packet_id VarInt which was already consumed + // in read_signature, so subtract its encoded size to get the actual data length + let data_needed = signature.length.saturating_sub(varint_size(signature.packet_id)); + if data_needed > self.data_len() { + match &self.fill_buffer_from_source(data_needed).await { Ok(_) => {} Err(_) => return Err(ReadingError::Closed), }; @@ -199,10 +215,6 @@ impl MinecraftStream { T::read(self) } - fn remain_len(&self) -> usize { - self.buffer.len() - self.position - } - fn copy_buffer_to_start(&mut self) { let data_len = self.free - self.position; self.buffer.copy_within(self.position..self.free, 0); @@ -211,7 +223,8 @@ impl MinecraftStream { } fn expand_buffer(&mut self) { - todo!() + let new_len = self.buffer.len() * 2; + self.buffer.resize(new_len, 0); } async fn fill_buffer_from_source(&mut self, required: usize) -> Result<(), ()> { @@ -295,16 +308,15 @@ impl FieldReader for String { fn read( stream: &mut MinecraftStream, ) -> Result { - // todo: there is a bug - read_field changes position of the stream, but below can happen reading error if packet doesn't fully read let length = stream.read_field::()? as usize; - if length > stream.remain_len() { + if length > stream.data_len() { return Err(ReadingError::Insufficient); } let mut vec: Vec = vec![0; length]; vec.copy_from_slice(&stream.buffer[stream.position..stream.position + length]); stream.position += length; - Ok(String::from_utf8(vec).unwrap()) + String::from_utf8(vec).map_err(|_| ReadingError::Invalid) } } diff --git a/mineginx/Cargo.toml b/mineginx/Cargo.toml index 731b302..d941759 100644 --- a/mineginx/Cargo.toml +++ b/mineginx/Cargo.toml @@ -16,3 +16,4 @@ tokio = { version = "1.49.0", features = ["full"] } uuid = { version = "1.20.0", features = ["v4"] } log = { version = "0.4" } simple_logger = { version = "5.1.0" } +socket2 = "0.5" diff --git a/mineginx/src/main.rs b/mineginx/src/main.rs index a56c75e..9aef6d2 100644 --- a/mineginx/src/main.rs +++ b/mineginx/src/main.rs @@ -1,21 +1,24 @@ use std::{ - borrow::BorrowMut, collections::HashMap, env, fs::{self}, io::ErrorKind, path::Path, process::ExitCode, sync::Arc, time::Duration + borrow::BorrowMut, collections::HashMap, env, fs::{self}, io::ErrorKind, net::SocketAddr, path::Path, process::ExitCode, sync::Arc, time::Duration }; use config::{MinecraftServerDescription, MineginxConfig}; use log::{error, info, warn}; use minecraft::{packets::{HandshakeC2SPacket, MinecraftPacket}, serialization::{truncate_to_zero, MinecraftStream}}; use simple_logger::SimpleLogger; -use tokio::{io::AsyncWriteExt, net::{TcpListener, TcpStream}, sync::oneshot, task::JoinHandle, time::timeout}; -use stream::forward_stream; +use socket2::{Socket, Domain, Type}; +use tokio::{io::AsyncWriteExt, net::{TcpListener, TcpStream}, task::JoinHandle, time::timeout}; +use stream::forward_half; mod stream; mod config; -fn find_upstream(domain: &String, config: Arc) -> Option { +fn find_upstream_config<'a>(domain: &str, config: &'a MineginxConfig) -> Option<&'a MinecraftServerDescription> { + let domain = domain.trim_end_matches('.'); for x in &config.servers { for server_name in &x.server_names { - if server_name == domain { - return Some(x.clone()); + let cfg_name = server_name.trim_end_matches('.'); + if cfg_name.eq_ignore_ascii_case(domain) { + return Some(x); } } } @@ -41,9 +44,7 @@ async fn handle_client(mut client: TcpStream, config: Arc) { let handshake_result = timeout(timeout_future, read_handshake_packet(&mut minecraft)).await; let handshake = match handshake_result { Ok(result) => match result { - Ok(handshake) => { - handshake - } + Ok(handshake) => handshake, Err(_) => { error!("handshake failed for someone"); return; @@ -54,26 +55,25 @@ async fn handle_client(mut client: TcpStream, config: Arc) { return; } }; - let domain = truncate_to_zero(&handshake.domain).to_string(); - let upstream_server = match find_upstream(&domain, config.clone()) { - Some(x) => x, + let server_desc = match find_upstream_config(&domain, &config) { + Some(x) => x.clone(), None => { warn!("there is no upstream for domain {:#?}", &domain); return; } }; - info!("new connection (protocol_version: {}, domain: {}, upstream: {})", &handshake.protocol_version, &domain, upstream_server.proxy_pass); + info!("new connection (protocol_version: {}, domain: {}, upstream: {})", &handshake.protocol_version, &domain, &server_desc.proxy_pass); - let mut upstream = match TcpStream::connect(&upstream_server.proxy_pass).await { + let mut upstream_conn = match TcpStream::connect(&server_desc.proxy_pass).await { Ok(x) => x, Err(e) => { - error!("failed to connect upstream: {}, {e}", &upstream_server.proxy_pass); + error!("failed to connect upstream: {}, {e}", &server_desc.proxy_pass); return; } }; - if let Err(e) = upstream.set_nodelay(true) { + if let Err(e) = upstream_conn.set_nodelay(true) { error!("failed to set no_delay for upstream: {}", e); return; } @@ -81,12 +81,11 @@ async fn handle_client(mut client: TcpStream, config: Arc) { Some(v) => v, None => return }; - match upstream.write_all(&packet[0..packet.len()]).await { + match upstream_conn.write_all(&packet[0..packet.len()]).await { Ok(_) => { }, Err(_) => return }; - // flush unread buffer to the upstream - match upstream.write_all(&minecraft.take_buffer()).await { + match upstream_conn.write_all(&minecraft.take_buffer()).await { Ok(_) => {}, Err(_) => { return; @@ -94,26 +93,21 @@ async fn handle_client(mut client: TcpStream, config: Arc) { } let (client_reader, client_writer) = client.into_split(); - let (upstream_reader, upstream_writer) = upstream.into_split(); - let (client_close_sender, client_close_receiver) = oneshot::channel::<()>(); - let (upstream_close_sender, upstream_close_receiver) = oneshot::channel::<()>(); - forward_stream( - client_close_sender, - upstream_close_receiver, - client_reader, - upstream_writer, - if let Some(buffer_size) = upstream_server.buffer_size { buffer_size as usize } else { 2048 }); - forward_stream( - upstream_close_sender, - client_close_receiver, - upstream_reader, - client_writer, - if let Some(buffer_size) = upstream_server.buffer_size { buffer_size as usize } else { 2048 }); + let (upstream_reader, upstream_writer) = upstream_conn.into_split(); + let buf_size = server_desc.buffer_size.map(|b| b as usize).unwrap_or(8192); + + // Each direction is a separate task. When one side hits EOF, + // it shuts down its writer (sends FIN), causing the other to EOF too. + let c2s = forward_half(client_reader, upstream_writer, buf_size); + let s2c = forward_half(upstream_reader, client_writer, buf_size); + + let _ = tokio::join!(c2s, s2c); + info!("connection closed (domain: {}, upstream: {})", &domain, &server_desc.proxy_pass); } async fn handle_address(listener: &TcpListener, config: Arc) { loop { - let (socket, _address) = match listener.accept().await { + let (socket, _addr) = match listener.accept().await { Ok(x) => x, Err(e) => { error!("failed to accept client: {e}"); @@ -191,6 +185,12 @@ const CONFIG_FILE: &str = "./config/mineginx.yaml"; #[tokio::main(flavor = "multi_thread")] async fn main() -> ExitCode { SimpleLogger::new().init().unwrap(); + + std::panic::set_hook(Box::new(|panic_info| { + eprintln!("PANIC: {}", panic_info); + log::error!("panic occurred: {}", panic_info); + })); + let mut args = env::args(); if args.any(|x| &x == "-t") { return match check_config().await { @@ -207,20 +207,57 @@ async fn main() -> ExitCode { return ExitCode::from(2); } }; + let mut listening = HashMap::::new(); for server in &config.servers { if listening.contains_key(&server.listen) { continue; } info!("listening {}", &server.listen); - let listener = TcpListener::bind(&server.listen).await.unwrap(); + let addr: SocketAddr = match server.listen.parse() { + Ok(a) => a, + Err(e) => { + error!("failed to parse listen address {}: {e}", &server.listen); + return ExitCode::from(3); + } + }; + let socket = match Socket::new(Domain::for_address(addr), Type::STREAM, None) { + Ok(s) => s, + Err(e) => { + error!("failed to create socket for {}: {e}", &server.listen); + return ExitCode::from(3); + } + }; + if let Err(e) = socket.set_reuse_address(true) { + error!("failed to set SO_REUSEADDR for {}: {e}", &server.listen); + return ExitCode::from(3); + } + if let Err(e) = socket.bind(&addr.into()) { + error!("failed to bind {}: {e}", &server.listen); + return ExitCode::from(3); + } + + if let Err(e) = socket.listen(1024) { + error!("failed to listen on {}: {e}", &server.listen); + return ExitCode::from(3); + } + socket.set_nonblocking(true).unwrap(); + let listener = match TcpListener::from_std(socket.into()) { + Ok(l) => l, + Err(e) => { + error!("failed to create tokio listener for {}: {e}", &server.listen); + return ExitCode::from(3); + } + }; let conf = config.clone(); let task = tokio::spawn(async move { handle_address(&listener, conf).await; }); listening.insert(server.listen.to_string(), ListeningAddress(task)); } - tokio::signal::ctrl_c().await.unwrap(); + if let Err(e) = tokio::signal::ctrl_c().await { + error!("failed to listen for ctrl_c signal: {e}"); + } info!("shutdown"); ExitCode::from(0) } diff --git a/mineginx/src/stream.rs b/mineginx/src/stream.rs index c3fa842..2b2eb69 100644 --- a/mineginx/src/stream.rs +++ b/mineginx/src/stream.rs @@ -1,62 +1,28 @@ use tokio::{ task::JoinHandle, - sync::oneshot::{ - Sender, Receiver, error::TryRecvError - }, - net::tcp::{ - OwnedReadHalf, OwnedWriteHalf - }, - io::{AsyncReadExt, AsyncWriteExt} + net::tcp::{OwnedReadHalf, OwnedWriteHalf}, + io::{AsyncReadExt, AsyncWriteExt}, }; -pub fn forward_stream( - close: Sender<()>, - close_by_other: Receiver<()>, +/// Forwards data from `reader` to `writer` until EOF or error, +/// then shuts down the writer (sends TCP FIN). +pub fn forward_half( mut reader: OwnedReadHalf, mut writer: OwnedWriteHalf, - buffer_size: usize) -> JoinHandle<()> { + buffer_size: usize, +) -> JoinHandle<()> { tokio::spawn(async move { let mut buf = vec![0; buffer_size]; - let mut close = Some(close); - let mut close_by_other = Some(close_by_other); - let mut closed = false; loop { - if let Some(mut receiver) = close_by_other.take() { - match receiver.try_recv() { - Err(e ) => closed |= e == TryRecvError::Closed, - Ok(_) => closed = true - } - } - if closed { - return; - } - let res = reader.read(&mut buf).await; - match res { - Ok(size) => { - if size == 0 { - if let Some(sender) = close.take() { - closed = true; - _ = sender.send(()); - } - } - let writed = writer.write_all(&buf[..size]).await; - match writed { - Ok(_) => { }, - Err(_) => { - if let Some(sender) = close.take() { - _ = sender.send(()) - } - return; - } - } - }, - Err(_) => { - if let Some(sender) = close.take() { - _ = sender.send(()); + match reader.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => { + if writer.write_all(&buf[..n]).await.is_err() { + break; } - return; } } } + _ = writer.shutdown().await; }) }