diff --git a/.gitignore b/.gitignore
index ea8c4bf..1895703 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 /target
+/config.yaml
\ No newline at end of file
diff --git a/Cargo.toml b/Cargo.toml
index 26026e4..196b14d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,3 +23,10 @@ async-trait = "0.1"
 
 [dev-dependencies]
 tempfile = "3"
+
+[profile.release]
+opt-level = 3
+lto = true
+codegen-units = 1
+panic = "abort"
+strip = true
diff --git a/src/core/forwarder.rs b/src/core/forwarder.rs
new file mode 100644
index 0000000..30a9697
--- /dev/null
+++ b/src/core/forwarder.rs
@@ -0,0 +1,65 @@
+use crate::core::upstream::AsyncStream;
+use std::io;
+
+// Actual implementation below
+// Spliceable trait and its implementations are removed as AsyncStream handles readiness internally.
+
+async fn splice_loop(read: &AsyncStream, write: &AsyncStream) -> io::Result<u64> {
+    let mut pipe = [0i32; 2];
+    if unsafe { libc::pipe2(pipe.as_mut_ptr(), libc::O_NONBLOCK | libc::O_CLOEXEC) } < 0 {
+        return Err(io::Error::last_os_error());
+    }
+    let (pipe_rd, pipe_wr) = (pipe[0], pipe[1]);
+
+    struct PipeGuard(i32, i32);
+    impl Drop for PipeGuard {
+        fn drop(&mut self) {
+            unsafe {
+                libc::close(self.0);
+                libc::close(self.1);
+            }
+        }
+    }
+    let _guard = PipeGuard(pipe_rd, pipe_wr);
+
+    let mut total_bytes = 0;
+
+    loop {
+        // src -> pipe
+        // splice_read handles readiness internally with AsyncFd
+        let len = match read.splice_read(pipe_wr, 65536).await {
+            Ok(0) => return Ok(total_bytes), // EOF
+            Ok(n) => n,
+            Err(e) => return Err(e),
+        };
+
+        // pipe -> dst
+        let mut written = 0;
+        while written < len {
+            let to_write = len - written;
+            let n = write.splice_write(pipe_rd, to_write).await?;
+            if n == 0 {
+                return Err(io::Error::new(
+                    io::ErrorKind::WriteZero,
+                    "Zero write in splice logic",
+                ));
+            }
+            written += n;
+            total_bytes += n as u64;
+        }
+    }
+}
+
+pub async fn zero_copy_bidirectional(
+    inbound: AsyncStream,
+    outbound: AsyncStream,
+) -> io::Result<()> {
+    // We own the streams now, so we can split references to them for the join.
+    let (c2s, s2c) = tokio::join!(
+        splice_loop(&inbound, &outbound),
+        splice_loop(&outbound, &inbound)
+    );
+    c2s?;
+    s2c?;
+    Ok(())
+}
diff --git a/src/core/mod.rs b/src/core/mod.rs
index 74f47ad..cfeb2ca 100644
--- a/src/core/mod.rs
+++ b/src/core/mod.rs
@@ -1 +1,3 @@
+pub mod forwarder;
 pub mod server;
+pub mod upstream;
diff --git a/src/core/server.rs b/src/core/server.rs
index 0aedb34..693e329 100644
--- a/src/core/server.rs
+++ b/src/core/server.rs
@@ -1,10 +1,13 @@
-use crate::config::{BindType, Config};
+use crate::config::{BindType, Config, ServiceConfig};
+use crate::core::forwarder;
+use crate::core::upstream::UpstreamStream;
 use crate::db::clickhouse::ClickHouseLogger;
-use crate::db::AuditLogger;
+use crate::protocol;
 use std::sync::Arc;
+use tokio::io::AsyncWriteExt;
 use tokio::net::TcpListener;
 use tokio::signal;
-use tracing::{error, info};
+use tracing::{error, info, instrument};
 
 pub async fn run(config: Config) -> anyhow::Result<()> {
     let db = Arc::new(ClickHouseLogger::new(&config.database));
@@ -13,15 +16,19 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
 
     for service in config.services {
         let db = db.clone();
-        for bind in service.binds {
-            let service_name = service.name.clone();
+        for bind in &service.binds {
+            let service_config = service.clone(); // Clone for the task
             let bind_addr = bind.addr.clone();
+            let proxy_protocol = bind.proxy_protocol.is_some();
             let bind_type = bind.bind_type;
 
-            // TODO: Handle UDP and Unix
             if bind_type == BindType::Tcp {
-                let db = db.clone();
-                join_set.spawn(start_tcp_service(service_name, bind_addr, db));
+                join_set.spawn(start_tcp_service(
+                    service_config,
+                    bind_addr,
+                    proxy_protocol,
+                    db.clone(),
+                ));
             } else {
                 info!("Skipping non-TCP bind for now: {:?}", bind_type);
             }
@@ -39,14 +46,18 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
         }
     }
 
-    // Abort all tasks
     join_set.shutdown().await;
 
     Ok(())
 }
 
-async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger>) {
-    info!("Service {} listening on TCP {}", name, addr);
+async fn start_tcp_service(
+    service: ServiceConfig,
+    addr: String,
+    proxy_protocol: bool,
+    _db: Arc<ClickHouseLogger>,
+) {
+    info!("Service {} listening on TCP {}", service.name, addr);
     let listener = match TcpListener::bind(&addr).await {
         Ok(l) => l,
         Err(e) => {
@@ -57,10 +68,16 @@ async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger>) {
 
     loop {
        match listener.accept().await {
-            Ok((_socket, client_addr)) => {
+            Ok((mut inbound, client_addr)) => {
                 info!("New connection from {}", client_addr);
-                // Spawn handler
-                // tokio::spawn(handle_connection(_socket, ...));
+                let service = service.clone();
+                // let db = _db.clone();
+
+                tokio::spawn(async move {
+                    if let Err(e) = handle_connection(inbound, service, proxy_protocol).await {
+                        error!("Connection error: {}", e);
+                    }
+                });
             }
             Err(e) => {
                 error!("Accept error: {}", e);
@@ -68,3 +85,36 @@ async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger>) {
         }
     }
 }
+
+#[instrument(skip_all)]
+async fn handle_connection(
+    mut inbound: tokio::net::TcpStream,
+    service: ServiceConfig,
+    proxy_protocol: bool,
+) -> std::io::Result<()> {
+    // 1. Read Proxy Protocol (if configured)
+    let mut buffer = if proxy_protocol {
+        let (_proxy_info, buffer) = protocol::read_proxy_header(&mut inbound).await?;
+        buffer
+    } else {
+        bytes::BytesMut::new()
+    };
+
+    // 2. Connect Upstream
+    let mut upstream = UpstreamStream::connect(service.forward_type, &service.forward_addr).await?;
+
+    // 3. Forward Header (TODO: if configured)
+
+    // 4. Write buffered data (peeked bytes)
+    if !buffer.is_empty() {
+        upstream.write_all_buf(&mut buffer).await?;
+    }
+
+    // 5. Zero-copy forwarding
+    let inbound_async = crate::core::upstream::AsyncStream::from_tokio_tcp(inbound)?;
+    let upstream_async = upstream.into_async_stream()?;
+
+    forwarder::zero_copy_bidirectional(inbound_async, upstream_async).await?;
+
+    Ok(())
+}
diff --git a/src/core/upstream.rs b/src/core/upstream.rs
new file mode 100644
index 0000000..4e32865
--- /dev/null
+++ b/src/core/upstream.rs
@@ -0,0 +1,182 @@
+use crate::config::ForwardType;
+use std::io;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio::net::{TcpStream, UnixStream};
+
+#[derive(Debug)]
+pub enum UpstreamStream {
+    Tcp(TcpStream),
+    Unix(UnixStream),
+}
+
+impl UpstreamStream {
+    pub async fn connect(fw_type: ForwardType, addr: &str) -> io::Result<Self> {
+        match fw_type {
+            ForwardType::Tcp => {
+                let stream = TcpStream::connect(addr).await?;
+                stream.set_nodelay(true)?;
+                Ok(UpstreamStream::Tcp(stream))
+            }
+            ForwardType::Unix => {
+                let stream = UnixStream::connect(addr).await?;
+                Ok(UpstreamStream::Unix(stream))
+            }
+            ForwardType::Udp => Err(io::Error::new(
+                io::ErrorKind::Unsupported,
+                "UDP forwarding not yet implemented in stream context",
+            )),
+        }
+    }
+}
+
+impl AsRawFd for UpstreamStream {
+    fn as_raw_fd(&self) -> RawFd {
+        match self {
+            UpstreamStream::Tcp(s) => s.as_raw_fd(),
+            UpstreamStream::Unix(s) => s.as_raw_fd(),
+        }
+    }
+}
+
+impl AsyncRead for UpstreamStream {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        match self.get_mut() {
+            UpstreamStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
+            UpstreamStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
+        }
+    }
+}
+
+impl AsyncWrite for UpstreamStream {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        match self.get_mut() {
+            UpstreamStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
+            UpstreamStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
+        }
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        match self.get_mut() {
+            UpstreamStream::Tcp(s) => Pin::new(s).poll_flush(cx),
+            UpstreamStream::Unix(s) => Pin::new(s).poll_flush(cx),
+        }
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        match self.get_mut() {
+            UpstreamStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
+            UpstreamStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
+        }
+    }
+}
+
+impl UpstreamStream {
+    pub fn into_async_stream(self) -> io::Result<AsyncStream> {
+        match self {
+            UpstreamStream::Tcp(s) => {
+                let std = s.into_std()?;
+                std.set_nonblocking(true)?;
+                Ok(AsyncStream::Tcp(tokio::io::unix::AsyncFd::new(std)?))
+            }
+            UpstreamStream::Unix(s) => {
+                let std = s.into_std()?;
+                std.set_nonblocking(true)?;
+                Ok(AsyncStream::Unix(tokio::io::unix::AsyncFd::new(std)?))
+            }
+        }
+    }
+}
+
+pub enum AsyncStream {
+    Tcp(tokio::io::unix::AsyncFd<std::net::TcpStream>),
+    Unix(tokio::io::unix::AsyncFd<std::os::unix::net::UnixStream>),
+}
+
+impl AsyncStream {
+    pub fn from_tokio_tcp(stream: tokio::net::TcpStream) -> io::Result<Self> {
+        let std = stream.into_std()?;
+        std.set_nonblocking(true)?;
+        Ok(AsyncStream::Tcp(tokio::io::unix::AsyncFd::new(std)?))
+    }
+
+    pub async fn splice_read(&self, pipe_out: RawFd, len: usize) -> io::Result<usize> {
+        match self {
+            AsyncStream::Tcp(fd) => perform_splice_read(fd, pipe_out, len).await,
+            AsyncStream::Unix(fd) => perform_splice_read(fd, pipe_out, len).await,
+        }
+    }
+
+    pub async fn splice_write(&self, pipe_in: RawFd, len: usize) -> io::Result<usize> {
+        match self {
+            AsyncStream::Tcp(fd) => perform_splice_write(fd, pipe_in, len).await,
+            AsyncStream::Unix(fd) => perform_splice_write(fd, pipe_in, len).await,
+        }
+    }
+}
+
+async fn perform_splice_read<T: AsRawFd>(
+    fd: &tokio::io::unix::AsyncFd<T>,
+    pipe_out: RawFd,
+    len: usize,
+) -> io::Result<usize> {
+    loop {
+        let mut guard = fd.readable().await?;
+        match guard.try_io(|inner| unsafe {
+            let res = libc::splice(
+                inner.as_raw_fd(),
+                std::ptr::null_mut(),
+                pipe_out,
+                std::ptr::null_mut(),
+                len,
+                libc::SPLICE_F_MOVE | libc::SPLICE_F_NONBLOCK,
+            );
+            if res >= 0 {
+                Ok(res as usize)
+            } else {
+                Err(io::Error::last_os_error())
+            }
+        }) {
+            Ok(res) => return res,
+            Err(_would_block) => continue, // try_io clears readiness
+        }
+    }
+}
+
+async fn perform_splice_write<T: AsRawFd>(
+    fd: &tokio::io::unix::AsyncFd<T>,
+    pipe_in: RawFd,
+    len: usize,
+) -> io::Result<usize> {
+    loop {
+        let mut guard = fd.writable().await?;
+        match guard.try_io(|inner| unsafe {
+            let res = libc::splice(
+                pipe_in,
+                std::ptr::null_mut(),
+                inner.as_raw_fd(),
+                std::ptr::null_mut(),
+                len,
+                libc::SPLICE_F_MOVE | libc::SPLICE_F_NONBLOCK,
+            );
+            if res >= 0 {
+                Ok(res as usize)
+            } else {
+                Err(io::Error::last_os_error())
+            }
+        }) {
+            Ok(res) => return res,
+            Err(_would_block) => continue,
+        }
+    }
+}