mirror of
https://github.com/awfufu/traudit
synced 2026-03-01 05:29:44 +08:00
feat: implement zero-copy tcp forwarding using splice
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1 +1,2 @@
|
||||
/target
|
||||
/config.yaml
|
||||
@@ -23,3 +23,10 @@ async-trait = "0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
strip = true
|
||||
|
||||
65
src/core/forwarder.rs
Normal file
65
src/core/forwarder.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use crate::core::upstream::AsyncStream;
|
||||
use std::io;
|
||||
|
||||
// Actual implementation below
|
||||
// Spliceable trait and its implementations are removed as AsyncStream handles readiness internally.
|
||||
|
||||
async fn splice_loop(read: &AsyncStream, write: &AsyncStream) -> io::Result<u64> {
|
||||
let mut pipe = [0i32; 2];
|
||||
if unsafe { libc::pipe2(pipe.as_mut_ptr(), libc::O_NONBLOCK | libc::O_CLOEXEC) } < 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
let (pipe_rd, pipe_wr) = (pipe[0], pipe[1]);
|
||||
|
||||
struct PipeGuard(i32, i32);
|
||||
impl Drop for PipeGuard {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
libc::close(self.0);
|
||||
libc::close(self.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
let _guard = PipeGuard(pipe_rd, pipe_wr);
|
||||
|
||||
let mut total_bytes = 0;
|
||||
|
||||
loop {
|
||||
// src -> pipe
|
||||
// splice_read handles readiness internally with AsyncFd
|
||||
let len = match read.splice_read(pipe_wr, 65536).await {
|
||||
Ok(0) => return Ok(total_bytes), // EOF
|
||||
Ok(n) => n,
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
// pipe -> dst
|
||||
let mut written = 0;
|
||||
while written < len {
|
||||
let to_write = len - written;
|
||||
let n = write.splice_write(pipe_rd, to_write).await?;
|
||||
if n == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
"Zero write in splice logic",
|
||||
));
|
||||
}
|
||||
written += n;
|
||||
total_bytes += n as u64;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn zero_copy_bidirectional(
|
||||
inbound: AsyncStream,
|
||||
outbound: AsyncStream,
|
||||
) -> io::Result<()> {
|
||||
// We own the streams now, so we can split references to them for the join.
|
||||
let (c2s, s2c) = tokio::join!(
|
||||
splice_loop(&inbound, &outbound),
|
||||
splice_loop(&outbound, &inbound)
|
||||
);
|
||||
c2s?;
|
||||
s2c?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1 +1,3 @@
|
||||
pub mod forwarder;
|
||||
pub mod server;
|
||||
pub mod upstream;
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use crate::config::{BindType, Config};
|
||||
use crate::config::{BindType, Config, ServiceConfig};
|
||||
use crate::core::forwarder;
|
||||
use crate::core::upstream::UpstreamStream;
|
||||
use crate::db::clickhouse::ClickHouseLogger;
|
||||
use crate::db::AuditLogger;
|
||||
use crate::protocol;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
use tracing::{error, info};
|
||||
use tracing::{error, info, instrument};
|
||||
|
||||
pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
let db = Arc::new(ClickHouseLogger::new(&config.database));
|
||||
@@ -13,15 +16,19 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
|
||||
for service in config.services {
|
||||
let db = db.clone();
|
||||
for bind in service.binds {
|
||||
let service_name = service.name.clone();
|
||||
for bind in &service.binds {
|
||||
let service_config = service.clone(); // Clone for the task
|
||||
let bind_addr = bind.addr.clone();
|
||||
let proxy_protocol = bind.proxy_protocol.is_some();
|
||||
let bind_type = bind.bind_type;
|
||||
|
||||
// TODO: Handle UDP and Unix
|
||||
if bind_type == BindType::Tcp {
|
||||
let db = db.clone();
|
||||
join_set.spawn(start_tcp_service(service_name, bind_addr, db));
|
||||
join_set.spawn(start_tcp_service(
|
||||
service_config,
|
||||
bind_addr,
|
||||
proxy_protocol,
|
||||
db.clone(),
|
||||
));
|
||||
} else {
|
||||
info!("Skipping non-TCP bind for now: {:?}", bind_type);
|
||||
}
|
||||
@@ -39,14 +46,18 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Abort all tasks
|
||||
join_set.shutdown().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger>) {
|
||||
info!("Service {} listening on TCP {}", name, addr);
|
||||
async fn start_tcp_service(
|
||||
service: ServiceConfig,
|
||||
addr: String,
|
||||
proxy_protocol: bool,
|
||||
_db: Arc<ClickHouseLogger>,
|
||||
) {
|
||||
info!("Service {} listening on TCP {}", service.name, addr);
|
||||
let listener = match TcpListener::bind(&addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
@@ -57,10 +68,16 @@ async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((_socket, client_addr)) => {
|
||||
Ok((mut inbound, client_addr)) => {
|
||||
info!("New connection from {}", client_addr);
|
||||
// Spawn handler
|
||||
// tokio::spawn(handle_connection(_socket, ...));
|
||||
let service = service.clone();
|
||||
// let db = _db.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_connection(inbound, service, proxy_protocol).await {
|
||||
error!("Connection error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Accept error: {}", e);
|
||||
@@ -68,3 +85,36 @@ async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(inbound, service), fields(service = %service.name))]
|
||||
async fn handle_connection(
|
||||
mut inbound: tokio::net::TcpStream,
|
||||
service: ServiceConfig,
|
||||
proxy_protocol: bool,
|
||||
) -> std::io::Result<()> {
|
||||
// 1. Read Proxy Protocol (if configured)
|
||||
let mut buffer = if proxy_protocol {
|
||||
let (_proxy_info, buffer) = protocol::read_proxy_header(&mut inbound).await?;
|
||||
buffer
|
||||
} else {
|
||||
bytes::BytesMut::new()
|
||||
};
|
||||
|
||||
// 2. Connect Upstream
|
||||
let mut upstream = UpstreamStream::connect(service.forward_type, &service.forward_addr).await?;
|
||||
|
||||
// 3. Forward Header (TODO: if configured)
|
||||
|
||||
// 4. Write buffered data (peeked bytes)
|
||||
if !buffer.is_empty() {
|
||||
upstream.write_all_buf(&mut buffer).await?;
|
||||
}
|
||||
|
||||
// 5. Zero-copy forwarding
|
||||
let inbound_async = crate::core::upstream::AsyncStream::from_tokio_tcp(inbound)?;
|
||||
let upstream_async = upstream.into_async_stream()?;
|
||||
|
||||
forwarder::zero_copy_bidirectional(inbound_async, upstream_async).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
182
src/core/upstream.rs
Normal file
182
src/core/upstream.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use crate::config::ForwardType;
|
||||
use std::io;
|
||||
use std::os::unix::io::{AsRawFd, RawFd};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::{TcpStream, UnixStream};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum UpstreamStream {
|
||||
Tcp(TcpStream),
|
||||
Unix(UnixStream),
|
||||
}
|
||||
|
||||
impl UpstreamStream {
|
||||
pub async fn connect(fw_type: ForwardType, addr: &str) -> io::Result<Self> {
|
||||
match fw_type {
|
||||
ForwardType::Tcp => {
|
||||
let stream = TcpStream::connect(addr).await?;
|
||||
stream.set_nodelay(true)?;
|
||||
Ok(UpstreamStream::Tcp(stream))
|
||||
}
|
||||
ForwardType::Unix => {
|
||||
let stream = UnixStream::connect(addr).await?;
|
||||
Ok(UpstreamStream::Unix(stream))
|
||||
}
|
||||
ForwardType::Udp => Err(io::Error::new(
|
||||
io::ErrorKind::Unsupported,
|
||||
"UDP forwarding not yet implemented in stream context",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRawFd for UpstreamStream {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
match self {
|
||||
UpstreamStream::Tcp(s) => s.as_raw_fd(),
|
||||
UpstreamStream::Unix(s) => s.as_raw_fd(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for UpstreamStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
UpstreamStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
|
||||
UpstreamStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for UpstreamStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
match self.get_mut() {
|
||||
UpstreamStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
|
||||
UpstreamStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
match self.get_mut() {
|
||||
UpstreamStream::Tcp(s) => Pin::new(s).poll_flush(cx),
|
||||
UpstreamStream::Unix(s) => Pin::new(s).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
match self.get_mut() {
|
||||
UpstreamStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
|
||||
UpstreamStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UpstreamStream {
|
||||
pub fn into_async_stream(self) -> io::Result<AsyncStream> {
|
||||
match self {
|
||||
UpstreamStream::Tcp(s) => {
|
||||
let std = s.into_std()?;
|
||||
std.set_nonblocking(true)?;
|
||||
Ok(AsyncStream::Tcp(tokio::io::unix::AsyncFd::new(std)?))
|
||||
}
|
||||
UpstreamStream::Unix(s) => {
|
||||
let std = s.into_std()?;
|
||||
std.set_nonblocking(true)?;
|
||||
Ok(AsyncStream::Unix(tokio::io::unix::AsyncFd::new(std)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum AsyncStream {
|
||||
Tcp(tokio::io::unix::AsyncFd<std::net::TcpStream>),
|
||||
Unix(tokio::io::unix::AsyncFd<std::os::unix::net::UnixStream>),
|
||||
}
|
||||
|
||||
impl AsyncStream {
|
||||
pub fn from_tokio_tcp(stream: tokio::net::TcpStream) -> io::Result<Self> {
|
||||
let std = stream.into_std()?;
|
||||
std.set_nonblocking(true)?;
|
||||
Ok(AsyncStream::Tcp(tokio::io::unix::AsyncFd::new(std)?))
|
||||
}
|
||||
|
||||
pub async fn splice_read(&self, pipe_out: RawFd, len: usize) -> io::Result<usize> {
|
||||
match self {
|
||||
AsyncStream::Tcp(fd) => perform_splice_read(fd, pipe_out, len).await,
|
||||
AsyncStream::Unix(fd) => perform_splice_read(fd, pipe_out, len).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn splice_write(&self, pipe_in: RawFd, len: usize) -> io::Result<usize> {
|
||||
match self {
|
||||
AsyncStream::Tcp(fd) => perform_splice_write(fd, pipe_in, len).await,
|
||||
AsyncStream::Unix(fd) => perform_splice_write(fd, pipe_in, len).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_splice_read<T: AsRawFd>(
|
||||
fd: &tokio::io::unix::AsyncFd<T>,
|
||||
pipe_out: RawFd,
|
||||
len: usize,
|
||||
) -> io::Result<usize> {
|
||||
loop {
|
||||
let mut guard = fd.readable().await?;
|
||||
match guard.try_io(|inner| unsafe {
|
||||
let res = libc::splice(
|
||||
inner.as_raw_fd(),
|
||||
std::ptr::null_mut(),
|
||||
pipe_out,
|
||||
std::ptr::null_mut(),
|
||||
len,
|
||||
libc::SPLICE_F_MOVE | libc::SPLICE_F_NONBLOCK,
|
||||
);
|
||||
if res >= 0 {
|
||||
Ok(res as usize)
|
||||
} else {
|
||||
Err(io::Error::last_os_error())
|
||||
}
|
||||
}) {
|
||||
Ok(res) => return res,
|
||||
Err(_would_block) => continue, // try_io clears readiness
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_splice_write<T: AsRawFd>(
|
||||
fd: &tokio::io::unix::AsyncFd<T>,
|
||||
pipe_in: RawFd,
|
||||
len: usize,
|
||||
) -> io::Result<usize> {
|
||||
loop {
|
||||
let mut guard = fd.writable().await?;
|
||||
match guard.try_io(|inner| unsafe {
|
||||
let res = libc::splice(
|
||||
pipe_in,
|
||||
std::ptr::null_mut(),
|
||||
inner.as_raw_fd(),
|
||||
std::ptr::null_mut(),
|
||||
len,
|
||||
libc::SPLICE_F_MOVE | libc::SPLICE_F_NONBLOCK,
|
||||
);
|
||||
if res >= 0 {
|
||||
Ok(res as usize)
|
||||
} else {
|
||||
Err(io::Error::last_os_error())
|
||||
}
|
||||
}) {
|
||||
Ok(res) => return res,
|
||||
Err(_would_block) => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user