feat: implement zero-copy tcp forwarding using splice

This commit is contained in:
2026-01-15 10:52:02 +08:00
parent e0d57dc975
commit b36ece9d86
6 changed files with 321 additions and 14 deletions

1
.gitignore vendored
View File

@@ -1 +1,2 @@
/target /target
/config.yaml

View File

@@ -23,3 +23,10 @@ async-trait = "0.1"
[dev-dependencies] [dev-dependencies]
tempfile = "3" tempfile = "3"
[profile.release]
opt-level = 3
lto = true
codegen-units = 1
panic = "abort"
strip = true

65
src/core/forwarder.rs Normal file
View File

@@ -0,0 +1,65 @@
use crate::core::upstream::AsyncStream;
use std::io;
// Zero-copy forwarding built on splice(2).
// No separate Spliceable trait is needed: AsyncStream handles readiness internally.
async fn splice_loop(read: &AsyncStream, write: &AsyncStream) -> io::Result<u64> {
let mut pipe = [0i32; 2];
if unsafe { libc::pipe2(pipe.as_mut_ptr(), libc::O_NONBLOCK | libc::O_CLOEXEC) } < 0 {
return Err(io::Error::last_os_error());
}
let (pipe_rd, pipe_wr) = (pipe[0], pipe[1]);
struct PipeGuard(i32, i32);
impl Drop for PipeGuard {
fn drop(&mut self) {
unsafe {
libc::close(self.0);
libc::close(self.1);
}
}
}
let _guard = PipeGuard(pipe_rd, pipe_wr);
let mut total_bytes = 0;
loop {
// src -> pipe
// splice_read handles readiness internally with AsyncFd
let len = match read.splice_read(pipe_wr, 65536).await {
Ok(0) => return Ok(total_bytes), // EOF
Ok(n) => n,
Err(e) => return Err(e),
};
// pipe -> dst
let mut written = 0;
while written < len {
let to_write = len - written;
let n = write.splice_write(pipe_rd, to_write).await?;
if n == 0 {
return Err(io::Error::new(
io::ErrorKind::WriteZero,
"Zero write in splice logic",
));
}
written += n;
total_bytes += n as u64;
}
}
}
/// Forwards data in both directions between `inbound` and `outbound` until
/// both directions finish or either one fails.
///
/// The two `splice_loop` directions run concurrently on the same task. The
/// first error from either direction is returned immediately and the other
/// direction is dropped, instead of waiting for it to finish on its own —
/// otherwise a broken connection could be held open for as long as the
/// opposite peer stays idle. Both streams are owned by this function and are
/// dropped when it returns.
///
/// # Errors
/// Returns the first I/O error produced by either direction.
pub async fn zero_copy_bidirectional(
    inbound: AsyncStream,
    outbound: AsyncStream,
) -> io::Result<()> {
    // Both futures only borrow the streams, which live until this function returns.
    tokio::try_join!(
        splice_loop(&inbound, &outbound),
        splice_loop(&outbound, &inbound)
    )?;
    Ok(())
}

View File

@@ -1 +1,3 @@
pub mod forwarder;
pub mod server; pub mod server;
pub mod upstream;

View File

@@ -1,10 +1,13 @@
use crate::config::{BindType, Config}; use crate::config::{BindType, Config, ServiceConfig};
use crate::core::forwarder;
use crate::core::upstream::UpstreamStream;
use crate::db::clickhouse::ClickHouseLogger; use crate::db::clickhouse::ClickHouseLogger;
use crate::db::AuditLogger; use crate::protocol;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::signal; use tokio::signal;
use tracing::{error, info}; use tracing::{error, info, instrument};
pub async fn run(config: Config) -> anyhow::Result<()> { pub async fn run(config: Config) -> anyhow::Result<()> {
let db = Arc::new(ClickHouseLogger::new(&config.database)); let db = Arc::new(ClickHouseLogger::new(&config.database));
@@ -13,15 +16,19 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
for service in config.services { for service in config.services {
let db = db.clone(); let db = db.clone();
for bind in service.binds { for bind in &service.binds {
let service_name = service.name.clone(); let service_config = service.clone(); // Clone for the task
let bind_addr = bind.addr.clone(); let bind_addr = bind.addr.clone();
let proxy_protocol = bind.proxy_protocol.is_some();
let bind_type = bind.bind_type; let bind_type = bind.bind_type;
// TODO: Handle UDP and Unix
if bind_type == BindType::Tcp { if bind_type == BindType::Tcp {
let db = db.clone(); join_set.spawn(start_tcp_service(
join_set.spawn(start_tcp_service(service_name, bind_addr, db)); service_config,
bind_addr,
proxy_protocol,
db.clone(),
));
} else { } else {
info!("Skipping non-TCP bind for now: {:?}", bind_type); info!("Skipping non-TCP bind for now: {:?}", bind_type);
} }
@@ -39,14 +46,18 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
} }
} }
// Abort all tasks
join_set.shutdown().await; join_set.shutdown().await;
Ok(()) Ok(())
} }
async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger>) { async fn start_tcp_service(
info!("Service {} listening on TCP {}", name, addr); service: ServiceConfig,
addr: String,
proxy_protocol: bool,
_db: Arc<ClickHouseLogger>,
) {
info!("Service {} listening on TCP {}", service.name, addr);
let listener = match TcpListener::bind(&addr).await { let listener = match TcpListener::bind(&addr).await {
Ok(l) => l, Ok(l) => l,
Err(e) => { Err(e) => {
@@ -57,10 +68,16 @@ async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger
loop { loop {
match listener.accept().await { match listener.accept().await {
Ok((_socket, client_addr)) => { Ok((mut inbound, client_addr)) => {
info!("New connection from {}", client_addr); info!("New connection from {}", client_addr);
// Spawn handler let service = service.clone();
// tokio::spawn(handle_connection(_socket, ...)); // let db = _db.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(inbound, service, proxy_protocol).await {
error!("Connection error: {}", e);
}
});
} }
Err(e) => { Err(e) => {
error!("Accept error: {}", e); error!("Accept error: {}", e);
@@ -68,3 +85,36 @@ async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger
} }
} }
} }
/// Serves one accepted client connection: optionally consumes a proxy-protocol
/// header, connects to the service's forward target, replays any bytes returned
/// alongside the header, and then forwards traffic in both directions.
///
/// # Errors
/// Returns the first I/O error from reading the header, connecting upstream,
/// writing the buffered bytes, converting the streams, or forwarding.
#[instrument(skip(inbound, service), fields(service = %service.name))]
async fn handle_connection(
    mut inbound: tokio::net::TcpStream,
    service: ServiceConfig,
    proxy_protocol: bool,
) -> std::io::Result<()> {
    // Step 1: consume the proxy-protocol header when this bind expects one.
    // The bytes returned with it are presumably data read past the header.
    let mut pending = bytes::BytesMut::new();
    if proxy_protocol {
        let (_proxy_info, leftover) = protocol::read_proxy_header(&mut inbound).await?;
        pending = leftover;
    }

    // Step 2: open the upstream connection.
    let mut upstream = UpstreamStream::connect(service.forward_type, &service.forward_addr).await?;

    // Step 3: forward header (TODO: if configured).

    // Step 4: replay bytes already taken off the client socket.
    if !pending.is_empty() {
        upstream.write_all_buf(&mut pending).await?;
    }

    // Step 5: hand both sockets to the splice-based forwarder.
    let client_side = crate::core::upstream::AsyncStream::from_tokio_tcp(inbound)?;
    let upstream_side = upstream.into_async_stream()?;
    forwarder::zero_copy_bidirectional(client_side, upstream_side).await
}

182
src/core/upstream.rs Normal file
View File

@@ -0,0 +1,182 @@
use crate::config::ForwardType;
use std::io;
use std::os::unix::io::{AsRawFd, RawFd};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{TcpStream, UnixStream};
/// A connected stream to a forwarding target, over TCP or a Unix socket.
#[derive(Debug)]
pub enum UpstreamStream {
/// Upstream reached through a tokio `TcpStream`.
Tcp(TcpStream),
/// Upstream reached through a tokio `UnixStream`.
Unix(UnixStream),
}
impl UpstreamStream {
    /// Connects to `addr` using the transport selected by `fw_type`.
    ///
    /// TCP connections have `TCP_NODELAY` enabled before being returned.
    ///
    /// # Errors
    /// Returns the connect or socket-option error, or an `Unsupported` error
    /// when `fw_type` is `ForwardType::Udp`.
    pub async fn connect(fw_type: ForwardType, addr: &str) -> io::Result<Self> {
        let connected = match fw_type {
            ForwardType::Tcp => {
                let tcp = TcpStream::connect(addr).await?;
                tcp.set_nodelay(true)?;
                Self::Tcp(tcp)
            }
            ForwardType::Unix => Self::Unix(UnixStream::connect(addr).await?),
            ForwardType::Udp => {
                return Err(io::Error::new(
                    io::ErrorKind::Unsupported,
                    "UDP forwarding not yet implemented in stream context",
                ));
            }
        };
        Ok(connected)
    }
}
impl AsRawFd for UpstreamStream {
fn as_raw_fd(&self) -> RawFd {
match self {
UpstreamStream::Tcp(s) => s.as_raw_fd(),
UpstreamStream::Unix(s) => s.as_raw_fd(),
}
}
}
impl AsyncRead for UpstreamStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
UpstreamStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
UpstreamStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
impl AsyncWrite for UpstreamStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
match self.get_mut() {
UpstreamStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
UpstreamStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match self.get_mut() {
UpstreamStream::Tcp(s) => Pin::new(s).poll_flush(cx),
UpstreamStream::Unix(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match self.get_mut() {
UpstreamStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
UpstreamStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
}
}
}
impl UpstreamStream {
    /// Converts this stream into an [`AsyncStream`] by unwrapping the tokio
    /// socket to its std counterpart, marking it non-blocking, and registering
    /// it with an `AsyncFd`.
    ///
    /// # Errors
    /// Returns any error from the conversion, from setting non-blocking mode,
    /// or from creating the `AsyncFd`.
    pub fn into_async_stream(self) -> io::Result<AsyncStream> {
        use tokio::io::unix::AsyncFd;

        let wrapped = match self {
            Self::Tcp(tcp) => {
                let raw = tcp.into_std()?;
                raw.set_nonblocking(true)?;
                AsyncStream::Tcp(AsyncFd::new(raw)?)
            }
            Self::Unix(unix) => {
                let raw = unix.into_std()?;
                raw.set_nonblocking(true)?;
                AsyncStream::Unix(AsyncFd::new(raw)?)
            }
        };
        Ok(wrapped)
    }
}
/// A std socket wrapped in tokio's `AsyncFd`, used by the splice helpers below.
pub enum AsyncStream {
/// A std TCP stream registered with `AsyncFd`.
Tcp(tokio::io::unix::AsyncFd<std::net::TcpStream>),
/// A std Unix stream registered with `AsyncFd`.
Unix(tokio::io::unix::AsyncFd<std::os::unix::net::UnixStream>),
}
impl AsyncStream {
    /// Builds an `AsyncStream::Tcp` from a tokio TCP stream by converting it
    /// to a std stream, marking it non-blocking, and wrapping it in `AsyncFd`.
    ///
    /// # Errors
    /// Returns any error from the conversion, from setting non-blocking mode,
    /// or from creating the `AsyncFd`.
    pub fn from_tokio_tcp(stream: tokio::net::TcpStream) -> io::Result<Self> {
        let raw = stream.into_std()?;
        raw.set_nonblocking(true)?;
        let registered = tokio::io::unix::AsyncFd::new(raw)?;
        Ok(Self::Tcp(registered))
    }

    /// Splices up to `len` bytes from this stream into the pipe write end
    /// `pipe_out`, waiting for read readiness as needed.
    ///
    /// Returns the number of bytes moved; the caller treats `0` as EOF.
    pub async fn splice_read(&self, pipe_out: RawFd, len: usize) -> io::Result<usize> {
        match self {
            Self::Tcp(tcp) => perform_splice_read(tcp, pipe_out, len).await,
            Self::Unix(unix) => perform_splice_read(unix, pipe_out, len).await,
        }
    }

    /// Splices up to `len` bytes from the pipe read end `pipe_in` into this
    /// stream, waiting for write readiness as needed.
    ///
    /// Returns the number of bytes moved.
    pub async fn splice_write(&self, pipe_in: RawFd, len: usize) -> io::Result<usize> {
        match self {
            Self::Tcp(tcp) => perform_splice_write(tcp, pipe_in, len).await,
            Self::Unix(unix) => perform_splice_write(unix, pipe_in, len).await,
        }
    }
}
/// Waits until `fd` is readable, then splices up to `len` bytes from it into
/// the pipe descriptor `pipe_out`.
///
/// The attempt is retried whenever `try_io` reports that the call would block.
/// Any other `splice` failure is returned as the OS error.
async fn perform_splice_read<T: AsRawFd>(
    fd: &tokio::io::unix::AsyncFd<T>,
    pipe_out: RawFd,
    len: usize,
) -> io::Result<usize> {
    loop {
        let mut ready = fd.readable().await?;
        let attempt = ready.try_io(|inner| {
            let flags = libc::SPLICE_F_MOVE | libc::SPLICE_F_NONBLOCK;
            // SAFETY: only descriptors and null offset pointers are passed, so
            // no memory of ours is dereferenced; `inner` keeps the source
            // descriptor open for the call, and the caller is expected to
            // pass an open pipe descriptor.
            let moved = unsafe {
                libc::splice(
                    inner.as_raw_fd(),
                    std::ptr::null_mut(),
                    pipe_out,
                    std::ptr::null_mut(),
                    len,
                    flags,
                )
            };
            if moved < 0 {
                Err(io::Error::last_os_error())
            } else {
                Ok(moved as usize)
            }
        });
        // On would-block, `try_io` has cleared readiness; wait and retry.
        if let Ok(outcome) = attempt {
            return outcome;
        }
    }
}
/// Waits until `fd` is writable, then splices up to `len` bytes from the pipe
/// descriptor `pipe_in` into it.
///
/// The attempt is retried whenever `try_io` reports that the call would block.
/// Any other `splice` failure is returned as the OS error.
async fn perform_splice_write<T: AsRawFd>(
    fd: &tokio::io::unix::AsyncFd<T>,
    pipe_in: RawFd,
    len: usize,
) -> io::Result<usize> {
    loop {
        let mut ready = fd.writable().await?;
        let attempt = ready.try_io(|inner| {
            let flags = libc::SPLICE_F_MOVE | libc::SPLICE_F_NONBLOCK;
            // SAFETY: only descriptors and null offset pointers are passed, so
            // no memory of ours is dereferenced; `inner` keeps the destination
            // descriptor open for the call, and the caller is expected to
            // pass an open pipe descriptor.
            let moved = unsafe {
                libc::splice(
                    pipe_in,
                    std::ptr::null_mut(),
                    inner.as_raw_fd(),
                    std::ptr::null_mut(),
                    len,
                    flags,
                )
            };
            if moved < 0 {
                Err(io::Error::last_os_error())
            } else {
                Ok(moved as usize)
            }
        });
        // On would-block, wait for writability again and retry.
        if let Ok(outcome) = attempt {
            return outcome;
        }
    }
}