mirror of
https://github.com/awfufu/traudit
synced 2026-03-01 05:29:44 +08:00
feat: implement zero-copy tcp forwarding using splice
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1 +1,2 @@
|
|||||||
/target
|
/target
|
||||||
|
/config.yaml
|
||||||
@@ -23,3 +23,10 @@ async-trait = "0.1"
|
|||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = "3"
|
tempfile = "3"
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
opt-level = 3
|
||||||
|
lto = true
|
||||||
|
codegen-units = 1
|
||||||
|
panic = "abort"
|
||||||
|
strip = true
|
||||||
|
|||||||
65
src/core/forwarder.rs
Normal file
65
src/core/forwarder.rs
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
use crate::core::upstream::AsyncStream;
|
||||||
|
use std::io;
|
||||||
|
|
||||||
|
// Actual implementation below
|
||||||
|
// Spliceable trait and its implementations are removed as AsyncStream handles readiness internally.
|
||||||
|
|
||||||
|
async fn splice_loop(read: &AsyncStream, write: &AsyncStream) -> io::Result<u64> {
|
||||||
|
let mut pipe = [0i32; 2];
|
||||||
|
if unsafe { libc::pipe2(pipe.as_mut_ptr(), libc::O_NONBLOCK | libc::O_CLOEXEC) } < 0 {
|
||||||
|
return Err(io::Error::last_os_error());
|
||||||
|
}
|
||||||
|
let (pipe_rd, pipe_wr) = (pipe[0], pipe[1]);
|
||||||
|
|
||||||
|
struct PipeGuard(i32, i32);
|
||||||
|
impl Drop for PipeGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
unsafe {
|
||||||
|
libc::close(self.0);
|
||||||
|
libc::close(self.1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _guard = PipeGuard(pipe_rd, pipe_wr);
|
||||||
|
|
||||||
|
let mut total_bytes = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// src -> pipe
|
||||||
|
// splice_read handles readiness internally with AsyncFd
|
||||||
|
let len = match read.splice_read(pipe_wr, 65536).await {
|
||||||
|
Ok(0) => return Ok(total_bytes), // EOF
|
||||||
|
Ok(n) => n,
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
|
// pipe -> dst
|
||||||
|
let mut written = 0;
|
||||||
|
while written < len {
|
||||||
|
let to_write = len - written;
|
||||||
|
let n = write.splice_write(pipe_rd, to_write).await?;
|
||||||
|
if n == 0 {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
io::ErrorKind::WriteZero,
|
||||||
|
"Zero write in splice logic",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
written += n;
|
||||||
|
total_bytes += n as u64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn zero_copy_bidirectional(
|
||||||
|
inbound: AsyncStream,
|
||||||
|
outbound: AsyncStream,
|
||||||
|
) -> io::Result<()> {
|
||||||
|
// We own the streams now, so we can split references to them for the join.
|
||||||
|
let (c2s, s2c) = tokio::join!(
|
||||||
|
splice_loop(&inbound, &outbound),
|
||||||
|
splice_loop(&outbound, &inbound)
|
||||||
|
);
|
||||||
|
c2s?;
|
||||||
|
s2c?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -1 +1,3 @@
|
|||||||
|
pub mod forwarder;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
|
pub mod upstream;
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
use crate::config::{BindType, Config};
|
use crate::config::{BindType, Config, ServiceConfig};
|
||||||
|
use crate::core::forwarder;
|
||||||
|
use crate::core::upstream::UpstreamStream;
|
||||||
use crate::db::clickhouse::ClickHouseLogger;
|
use crate::db::clickhouse::ClickHouseLogger;
|
||||||
use crate::db::AuditLogger;
|
use crate::protocol;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
use tracing::{error, info};
|
use tracing::{error, info, instrument};
|
||||||
|
|
||||||
pub async fn run(config: Config) -> anyhow::Result<()> {
|
pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||||
let db = Arc::new(ClickHouseLogger::new(&config.database));
|
let db = Arc::new(ClickHouseLogger::new(&config.database));
|
||||||
@@ -13,15 +16,19 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
|||||||
|
|
||||||
for service in config.services {
|
for service in config.services {
|
||||||
let db = db.clone();
|
let db = db.clone();
|
||||||
for bind in service.binds {
|
for bind in &service.binds {
|
||||||
let service_name = service.name.clone();
|
let service_config = service.clone(); // Clone for the task
|
||||||
let bind_addr = bind.addr.clone();
|
let bind_addr = bind.addr.clone();
|
||||||
|
let proxy_protocol = bind.proxy_protocol.is_some();
|
||||||
let bind_type = bind.bind_type;
|
let bind_type = bind.bind_type;
|
||||||
|
|
||||||
// TODO: Handle UDP and Unix
|
|
||||||
if bind_type == BindType::Tcp {
|
if bind_type == BindType::Tcp {
|
||||||
let db = db.clone();
|
join_set.spawn(start_tcp_service(
|
||||||
join_set.spawn(start_tcp_service(service_name, bind_addr, db));
|
service_config,
|
||||||
|
bind_addr,
|
||||||
|
proxy_protocol,
|
||||||
|
db.clone(),
|
||||||
|
));
|
||||||
} else {
|
} else {
|
||||||
info!("Skipping non-TCP bind for now: {:?}", bind_type);
|
info!("Skipping non-TCP bind for now: {:?}", bind_type);
|
||||||
}
|
}
|
||||||
@@ -39,14 +46,18 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Abort all tasks
|
|
||||||
join_set.shutdown().await;
|
join_set.shutdown().await;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger>) {
|
async fn start_tcp_service(
|
||||||
info!("Service {} listening on TCP {}", name, addr);
|
service: ServiceConfig,
|
||||||
|
addr: String,
|
||||||
|
proxy_protocol: bool,
|
||||||
|
_db: Arc<ClickHouseLogger>,
|
||||||
|
) {
|
||||||
|
info!("Service {} listening on TCP {}", service.name, addr);
|
||||||
let listener = match TcpListener::bind(&addr).await {
|
let listener = match TcpListener::bind(&addr).await {
|
||||||
Ok(l) => l,
|
Ok(l) => l,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -57,10 +68,16 @@ async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger
|
|||||||
|
|
||||||
loop {
|
loop {
|
||||||
match listener.accept().await {
|
match listener.accept().await {
|
||||||
Ok((_socket, client_addr)) => {
|
Ok((mut inbound, client_addr)) => {
|
||||||
info!("New connection from {}", client_addr);
|
info!("New connection from {}", client_addr);
|
||||||
// Spawn handler
|
let service = service.clone();
|
||||||
// tokio::spawn(handle_connection(_socket, ...));
|
// let db = _db.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = handle_connection(inbound, service, proxy_protocol).await {
|
||||||
|
error!("Connection error: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Accept error: {}", e);
|
error!("Accept error: {}", e);
|
||||||
@@ -68,3 +85,36 @@ async fn start_tcp_service(name: String, addr: String, _db: Arc<ClickHouseLogger
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[instrument(skip(inbound, service), fields(service = %service.name))]
|
||||||
|
async fn handle_connection(
|
||||||
|
mut inbound: tokio::net::TcpStream,
|
||||||
|
service: ServiceConfig,
|
||||||
|
proxy_protocol: bool,
|
||||||
|
) -> std::io::Result<()> {
|
||||||
|
// 1. Read Proxy Protocol (if configured)
|
||||||
|
let mut buffer = if proxy_protocol {
|
||||||
|
let (_proxy_info, buffer) = protocol::read_proxy_header(&mut inbound).await?;
|
||||||
|
buffer
|
||||||
|
} else {
|
||||||
|
bytes::BytesMut::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
// 2. Connect Upstream
|
||||||
|
let mut upstream = UpstreamStream::connect(service.forward_type, &service.forward_addr).await?;
|
||||||
|
|
||||||
|
// 3. Forward Header (TODO: if configured)
|
||||||
|
|
||||||
|
// 4. Write buffered data (peeked bytes)
|
||||||
|
if !buffer.is_empty() {
|
||||||
|
upstream.write_all_buf(&mut buffer).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Zero-copy forwarding
|
||||||
|
let inbound_async = crate::core::upstream::AsyncStream::from_tokio_tcp(inbound)?;
|
||||||
|
let upstream_async = upstream.into_async_stream()?;
|
||||||
|
|
||||||
|
forwarder::zero_copy_bidirectional(inbound_async, upstream_async).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|||||||
182
src/core/upstream.rs
Normal file
182
src/core/upstream.rs
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
use crate::config::ForwardType;
|
||||||
|
use std::io;
|
||||||
|
use std::os::unix::io::{AsRawFd, RawFd};
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
use tokio::net::{TcpStream, UnixStream};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum UpstreamStream {
|
||||||
|
Tcp(TcpStream),
|
||||||
|
Unix(UnixStream),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamStream {
|
||||||
|
pub async fn connect(fw_type: ForwardType, addr: &str) -> io::Result<Self> {
|
||||||
|
match fw_type {
|
||||||
|
ForwardType::Tcp => {
|
||||||
|
let stream = TcpStream::connect(addr).await?;
|
||||||
|
stream.set_nodelay(true)?;
|
||||||
|
Ok(UpstreamStream::Tcp(stream))
|
||||||
|
}
|
||||||
|
ForwardType::Unix => {
|
||||||
|
let stream = UnixStream::connect(addr).await?;
|
||||||
|
Ok(UpstreamStream::Unix(stream))
|
||||||
|
}
|
||||||
|
ForwardType::Udp => Err(io::Error::new(
|
||||||
|
io::ErrorKind::Unsupported,
|
||||||
|
"UDP forwarding not yet implemented in stream context",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsRawFd for UpstreamStream {
|
||||||
|
fn as_raw_fd(&self) -> RawFd {
|
||||||
|
match self {
|
||||||
|
UpstreamStream::Tcp(s) => s.as_raw_fd(),
|
||||||
|
UpstreamStream::Unix(s) => s.as_raw_fd(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for UpstreamStream {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
UpstreamStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
|
||||||
|
UpstreamStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for UpstreamStream {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize, io::Error>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
UpstreamStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
|
||||||
|
UpstreamStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
UpstreamStream::Tcp(s) => Pin::new(s).poll_flush(cx),
|
||||||
|
UpstreamStream::Unix(s) => Pin::new(s).poll_flush(cx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
UpstreamStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
|
||||||
|
UpstreamStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamStream {
|
||||||
|
pub fn into_async_stream(self) -> io::Result<AsyncStream> {
|
||||||
|
match self {
|
||||||
|
UpstreamStream::Tcp(s) => {
|
||||||
|
let std = s.into_std()?;
|
||||||
|
std.set_nonblocking(true)?;
|
||||||
|
Ok(AsyncStream::Tcp(tokio::io::unix::AsyncFd::new(std)?))
|
||||||
|
}
|
||||||
|
UpstreamStream::Unix(s) => {
|
||||||
|
let std = s.into_std()?;
|
||||||
|
std.set_nonblocking(true)?;
|
||||||
|
Ok(AsyncStream::Unix(tokio::io::unix::AsyncFd::new(std)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum AsyncStream {
|
||||||
|
Tcp(tokio::io::unix::AsyncFd<std::net::TcpStream>),
|
||||||
|
Unix(tokio::io::unix::AsyncFd<std::os::unix::net::UnixStream>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncStream {
|
||||||
|
pub fn from_tokio_tcp(stream: tokio::net::TcpStream) -> io::Result<Self> {
|
||||||
|
let std = stream.into_std()?;
|
||||||
|
std.set_nonblocking(true)?;
|
||||||
|
Ok(AsyncStream::Tcp(tokio::io::unix::AsyncFd::new(std)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn splice_read(&self, pipe_out: RawFd, len: usize) -> io::Result<usize> {
|
||||||
|
match self {
|
||||||
|
AsyncStream::Tcp(fd) => perform_splice_read(fd, pipe_out, len).await,
|
||||||
|
AsyncStream::Unix(fd) => perform_splice_read(fd, pipe_out, len).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn splice_write(&self, pipe_in: RawFd, len: usize) -> io::Result<usize> {
|
||||||
|
match self {
|
||||||
|
AsyncStream::Tcp(fd) => perform_splice_write(fd, pipe_in, len).await,
|
||||||
|
AsyncStream::Unix(fd) => perform_splice_write(fd, pipe_in, len).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn perform_splice_read<T: AsRawFd>(
|
||||||
|
fd: &tokio::io::unix::AsyncFd<T>,
|
||||||
|
pipe_out: RawFd,
|
||||||
|
len: usize,
|
||||||
|
) -> io::Result<usize> {
|
||||||
|
loop {
|
||||||
|
let mut guard = fd.readable().await?;
|
||||||
|
match guard.try_io(|inner| unsafe {
|
||||||
|
let res = libc::splice(
|
||||||
|
inner.as_raw_fd(),
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
pipe_out,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
len,
|
||||||
|
libc::SPLICE_F_MOVE | libc::SPLICE_F_NONBLOCK,
|
||||||
|
);
|
||||||
|
if res >= 0 {
|
||||||
|
Ok(res as usize)
|
||||||
|
} else {
|
||||||
|
Err(io::Error::last_os_error())
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
Ok(res) => return res,
|
||||||
|
Err(_would_block) => continue, // try_io clears readiness
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn perform_splice_write<T: AsRawFd>(
|
||||||
|
fd: &tokio::io::unix::AsyncFd<T>,
|
||||||
|
pipe_in: RawFd,
|
||||||
|
len: usize,
|
||||||
|
) -> io::Result<usize> {
|
||||||
|
loop {
|
||||||
|
let mut guard = fd.writable().await?;
|
||||||
|
match guard.try_io(|inner| unsafe {
|
||||||
|
let res = libc::splice(
|
||||||
|
pipe_in,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
inner.as_raw_fd(),
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
len,
|
||||||
|
libc::SPLICE_F_MOVE | libc::SPLICE_F_NONBLOCK,
|
||||||
|
);
|
||||||
|
if res >= 0 {
|
||||||
|
Ok(res as usize)
|
||||||
|
} else {
|
||||||
|
Err(io::Error::last_os_error())
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
Ok(res) => return res,
|
||||||
|
Err(_would_block) => continue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user