mirror of
https://github.com/awfufu/traudit
synced 2026-03-01 05:29:44 +08:00
perf: migrate from epoll to io_uring
This commit is contained in:
117
Cargo.lock
generated
117
Cargo.lock
generated
@@ -40,6 +40,17 @@ version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
||||
|
||||
[[package]]
|
||||
name = "auto-const-array"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd73835ad7deb4bd2b389e6f10333b143f025d607c55ca04c66a0bcc6bb2fc6d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.5.0"
|
||||
@@ -58,6 +69,12 @@ version = "1.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.10.0"
|
||||
@@ -445,6 +462,15 @@ dependencies = [
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fxhash"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
@@ -743,6 +769,16 @@ dependencies = [
|
||||
"hashbrown 0.16.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "io-uring"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "595a0399f411a508feb2ec1e970a4a30c249351e30208960d58298de8660b0e5"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.17"
|
||||
@@ -776,7 +812,7 @@ version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"bitflags 2.10.0",
|
||||
"libc",
|
||||
"redox_syscall 0.7.0",
|
||||
]
|
||||
@@ -841,6 +877,27 @@ version = "2.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "0.8.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"wasi",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "1.1.1"
|
||||
@@ -852,6 +909,50 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "monoio"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3bd0f8bcde87b1949f95338b547543fcab187bc7e7a5024247e359a5e828ba6a"
|
||||
dependencies = [
|
||||
"auto-const-array",
|
||||
"bytes",
|
||||
"fxhash",
|
||||
"io-uring",
|
||||
"libc",
|
||||
"memchr",
|
||||
"mio 0.8.11",
|
||||
"monoio-macros",
|
||||
"nix",
|
||||
"pin-project-lite",
|
||||
"socket2 0.5.10",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "monoio-macros"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "176a5f5e69613d9e88337cf2a65e11135332b4efbcc628404a7c555e4452084c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nix"
|
||||
version = "0.26.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"memoffset",
|
||||
"pin-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.50.3"
|
||||
@@ -1074,7 +1175,7 @@ version = "0.5.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"bitflags 2.10.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1083,7 +1184,7 @@ version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49f3fe0889e69e2ae9e41f4d6c4c0181701d00e4697b356fb1f74173a5e0ee27"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"bitflags 2.10.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1132,7 +1233,7 @@ version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"bitflags 2.10.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
@@ -1481,7 +1582,7 @@ checksum = "aa003f0038df784eb8fecbbac13affe3da23b45194bd57dba231c8f48199c526"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"base64",
|
||||
"bitflags",
|
||||
"bitflags 2.10.0",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"crc",
|
||||
@@ -1523,7 +1624,7 @@ checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"base64",
|
||||
"bitflags",
|
||||
"bitflags 2.10.0",
|
||||
"byteorder",
|
||||
"crc",
|
||||
"dotenvy",
|
||||
@@ -1722,7 +1823,7 @@ checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"libc",
|
||||
"mio",
|
||||
"mio 1.1.1",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
@@ -1825,7 +1926,9 @@ dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"clickhouse",
|
||||
"io-uring",
|
||||
"libc",
|
||||
"monoio",
|
||||
"serde",
|
||||
"serde_yaml",
|
||||
"socket2 0.5.10",
|
||||
|
||||
@@ -6,6 +6,7 @@ authors = ["awfufu"]
|
||||
description = "A reverse proxy with auditing capabilities."
|
||||
|
||||
[dependencies]
|
||||
monoio = { version = "0.2", features = ["async-cancel", "bytes", "macros", "utils", "splice", "iouring"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "mysql", "postgres", "sqlite"] }
|
||||
clickhouse = { version = "0.13", features = ["test-util"] }
|
||||
@@ -13,6 +14,7 @@ serde = { version = "1", features = ["derive"] }
|
||||
serde_yaml = "0.9"
|
||||
socket2 = "0.5"
|
||||
libc = "0.2"
|
||||
io-uring = "0.6"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
anyhow = "1.0"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use serde::Deserialize;
|
||||
use std::path::Path;
|
||||
use tokio::fs;
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct Config {
|
||||
@@ -62,8 +61,8 @@ pub enum ProxyProtocolVersion {
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub async fn load<P: AsRef<Path>>(path: P) -> Result<Self, anyhow::Error> {
|
||||
let content = fs::read_to_string(path).await?;
|
||||
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, anyhow::Error> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let config: Config = serde_yaml::from_str(&content)?;
|
||||
Ok(config)
|
||||
}
|
||||
@@ -99,7 +98,7 @@ services:
|
||||
// Close the file handle so tokio can read it, or just keep it open and read by path?
|
||||
// tempfile deletes on drop. We need to keep `file` alive.
|
||||
|
||||
let config = Config::load(&path).await.expect("Failed to load config");
|
||||
let config = Config::load(&path).expect("Failed to load config");
|
||||
|
||||
assert_eq!(
|
||||
config.database.dsn,
|
||||
|
||||
@@ -1,65 +1,96 @@
|
||||
use crate::core::upstream::AsyncStream;
|
||||
use crate::core::upstream::UpstreamStream;
|
||||
use monoio::fs::File;
|
||||
use monoio::io::splice::{SpliceDestination, SpliceSource};
|
||||
use monoio::io::Splitable;
|
||||
use monoio::net::unix::new_pipe;
|
||||
use monoio::net::TcpStream;
|
||||
use std::io;
|
||||
use std::os::unix::io::AsRawFd;
|
||||
|
||||
// Actual implementation below
|
||||
// Spliceable trait and its implementations are removed as AsyncStream handles readiness internally.
|
||||
const SPLICE_SIZE: u32 = 1024 * 1024; // 1MB
|
||||
|
||||
async fn splice_loop(read: &AsyncStream, write: &AsyncStream) -> io::Result<u64> {
|
||||
let mut pipe = [0i32; 2];
|
||||
if unsafe { libc::pipe2(pipe.as_mut_ptr(), libc::O_NONBLOCK | libc::O_CLOEXEC) } < 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
async fn transfer<R, W>(mut read: R, mut write: W) -> io::Result<()>
|
||||
where
|
||||
R: SpliceSource,
|
||||
W: SpliceDestination,
|
||||
{
|
||||
// Double buffering: Create two pipes
|
||||
let (mut p1_r, mut p1_w) = new_pipe()?;
|
||||
let (mut p2_r, mut p2_w) = new_pipe()?;
|
||||
|
||||
// Resize both pipes
|
||||
unsafe {
|
||||
let f1_r: &File = std::mem::transmute(&p1_r);
|
||||
let f1_w: &File = std::mem::transmute(&p1_w);
|
||||
let f2_r: &File = std::mem::transmute(&p2_r);
|
||||
let f2_w: &File = std::mem::transmute(&p2_w);
|
||||
|
||||
libc::fcntl(f1_r.as_raw_fd(), 1031, SPLICE_SIZE as libc::c_int);
|
||||
libc::fcntl(f1_w.as_raw_fd(), 1031, SPLICE_SIZE as libc::c_int);
|
||||
libc::fcntl(f2_r.as_raw_fd(), 1031, SPLICE_SIZE as libc::c_int);
|
||||
libc::fcntl(f2_w.as_raw_fd(), 1031, SPLICE_SIZE as libc::c_int);
|
||||
}
|
||||
let (pipe_rd, pipe_wr) = (pipe[0], pipe[1]);
|
||||
|
||||
struct PipeGuard(i32, i32);
|
||||
impl Drop for PipeGuard {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
libc::close(self.0);
|
||||
libc::close(self.1);
|
||||
}
|
||||
}
|
||||
// Prime the first pipe
|
||||
let n = read.splice_to_pipe(&mut p1_w, SPLICE_SIZE).await?;
|
||||
if n == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let _guard = PipeGuard(pipe_rd, pipe_wr);
|
||||
|
||||
let mut total_bytes = 0;
|
||||
|
||||
loop {
|
||||
// src -> pipe
|
||||
// splice_read handles readiness internally with AsyncFd
|
||||
let len = match read.splice_read(pipe_wr, 65536).await {
|
||||
Ok(0) => return Ok(total_bytes), // EOF
|
||||
Ok(n) => n,
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
// Write from p1_r -> write AND Read from read -> p2_w
|
||||
let (res_w, res_r) = monoio::join!(
|
||||
write.splice_from_pipe(&mut p1_r, SPLICE_SIZE),
|
||||
read.splice_to_pipe(&mut p2_w, SPLICE_SIZE)
|
||||
);
|
||||
|
||||
// pipe -> dst
|
||||
let mut written = 0;
|
||||
while written < len {
|
||||
let to_write = len - written;
|
||||
let n = write.splice_write(pipe_rd, to_write).await?;
|
||||
if n == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
"Zero write in splice logic",
|
||||
));
|
||||
}
|
||||
written += n;
|
||||
total_bytes += n as u64;
|
||||
let _w = res_w?;
|
||||
let r = res_r?;
|
||||
|
||||
if r == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Swap pipes so p2 becomes the source for next write, and p1 becomes available for read
|
||||
std::mem::swap(&mut p1_r, &mut p2_r);
|
||||
std::mem::swap(&mut p1_w, &mut p2_w);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_busy_poll(fd: std::os::unix::io::RawFd, us: libc::c_int) {
|
||||
unsafe {
|
||||
let val = us;
|
||||
libc::setsockopt(
|
||||
fd,
|
||||
libc::SOL_SOCKET,
|
||||
50, // SO_BUSY_POLL is 50 on Linux
|
||||
&val as *const _ as *const libc::c_void,
|
||||
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn zero_copy_bidirectional(
|
||||
inbound: AsyncStream,
|
||||
outbound: AsyncStream,
|
||||
inbound: TcpStream,
|
||||
outbound: UpstreamStream,
|
||||
) -> io::Result<()> {
|
||||
// We own the streams now, so we can split references to them for the join.
|
||||
let (c2s, s2c) = tokio::join!(
|
||||
splice_loop(&inbound, &outbound),
|
||||
splice_loop(&outbound, &inbound)
|
||||
);
|
||||
c2s?;
|
||||
s2c?;
|
||||
set_busy_poll(inbound.as_raw_fd(), 50);
|
||||
|
||||
let (in_r, in_w) = inbound.into_split();
|
||||
match outbound {
|
||||
UpstreamStream::Tcp(s) => {
|
||||
set_busy_poll(s.as_raw_fd(), 50);
|
||||
let (out_r, out_w) = s.into_split();
|
||||
let (r1, r2) = monoio::join!(transfer(in_r, out_w), transfer(out_r, in_w));
|
||||
r1?;
|
||||
r2?;
|
||||
}
|
||||
UpstreamStream::Unix(s) => {
|
||||
let (out_r, out_w) = s.into_split();
|
||||
let (r1, r2) = monoio::join!(transfer(in_r, out_w), transfer(out_r, in_w));
|
||||
r1?;
|
||||
r2?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -3,16 +3,14 @@ use crate::core::forwarder;
|
||||
use crate::core::upstream::UpstreamStream;
|
||||
use crate::db::clickhouse::ClickHouseLogger;
|
||||
use crate::protocol;
|
||||
use monoio::net::TcpListener;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
use tracing::{error, info};
|
||||
|
||||
pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
let db = Arc::new(ClickHouseLogger::new(&config.database));
|
||||
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for service in config.services {
|
||||
let db = db.clone();
|
||||
@@ -23,19 +21,22 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
let bind_type = bind.bind_type;
|
||||
|
||||
if bind_type == BindType::Tcp {
|
||||
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
|
||||
error!("[{}] failed to bind {}: {}", service_config.name, bind_addr, e);
|
||||
let listener = TcpListener::bind(&bind_addr).map_err(|e| {
|
||||
error!(
|
||||
"[{}] failed to bind {}: {}",
|
||||
service_config.name, bind_addr, e
|
||||
);
|
||||
e
|
||||
})?;
|
||||
|
||||
info!("[{}] listening on tcp {}", service_config.name, bind_addr);
|
||||
|
||||
join_set.spawn(start_tcp_service(
|
||||
handles.push(monoio::spawn(start_tcp_service(
|
||||
service_config,
|
||||
listener,
|
||||
proxy_protocol,
|
||||
db.clone(),
|
||||
));
|
||||
)));
|
||||
} else {
|
||||
info!("skipping non-tcp bind for now: {:?}", bind_type);
|
||||
}
|
||||
@@ -53,17 +54,14 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
match signal::ctrl_c().await {
|
||||
Ok(()) => {
|
||||
info!("shutdown signal received.");
|
||||
}
|
||||
Err(err) => {
|
||||
error!("unable to listen for shutdown signal: {}", err);
|
||||
}
|
||||
// Monoio doesn't have a signal::ctrl_c helper built-in effectively like tokio's
|
||||
// But we can just wait on the handles or use a simple waiting mechanism.
|
||||
// For now, we await the services. Since they loop forever, this runs forever.
|
||||
// TODO: Implement signal handling for graceful shutdown if needed.
|
||||
for h in handles {
|
||||
let _ = h.await;
|
||||
}
|
||||
|
||||
join_set.shutdown().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -76,20 +74,16 @@ async fn start_tcp_service(
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((inbound, _client_addr)) => {
|
||||
// log moved to handle_connection for consistent real ip logging
|
||||
let service = service.clone();
|
||||
// let db = _db.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
monoio::spawn(async move {
|
||||
if let Err(e) = handle_connection(inbound, service, proxy_protocol).await {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
|
||||
// normal disconnects, debug log only
|
||||
tracing::debug!("connection closed: {}", e);
|
||||
}
|
||||
_ => {
|
||||
error!("connection error: {}", e);
|
||||
}
|
||||
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
|
||||
tracing::debug!("connection closed: {}", e);
|
||||
}
|
||||
_ => {
|
||||
error!("connection error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -102,12 +96,12 @@ async fn start_tcp_service(
|
||||
}
|
||||
|
||||
async fn handle_connection(
|
||||
mut inbound: tokio::net::TcpStream,
|
||||
mut inbound: monoio::net::TcpStream,
|
||||
service: ServiceConfig,
|
||||
proxy_protocol: bool,
|
||||
) -> std::io::Result<()> {
|
||||
// read proxy protocol (if configured)
|
||||
let (_client_addr, mut buffer) = if proxy_protocol {
|
||||
let (_client_addr, buffer) = if proxy_protocol {
|
||||
let (proxy_info, buffer) = protocol::read_proxy_header(&mut inbound).await?;
|
||||
if let Some(info) = proxy_info {
|
||||
let physical = inbound.peer_addr()?;
|
||||
@@ -127,18 +121,18 @@ async fn handle_connection(
|
||||
// connect upstream
|
||||
let mut upstream = UpstreamStream::connect(service.forward_type, &service.forward_addr).await?;
|
||||
|
||||
// forward header (TODO: if configured)
|
||||
|
||||
// write buffered data (peeked bytes)
|
||||
if !buffer.is_empty() {
|
||||
upstream.write_all_buf(&mut buffer).await?;
|
||||
// UpstreamStream needs to support writing bytes directly
|
||||
// logic needs update in upstream.rs
|
||||
upstream.write_all(buffer).await.0?;
|
||||
}
|
||||
|
||||
// zero-copy forwarding
|
||||
let inbound_async = crate::core::upstream::AsyncStream::from_tokio_tcp(inbound)?;
|
||||
let upstream_async = upstream.into_async_stream()?;
|
||||
// Monoio's TcpStream is already compatible with splice if we access the fd,
|
||||
// but better to pass the stream itself to forwarder.
|
||||
|
||||
forwarder::zero_copy_bidirectional(inbound_async, upstream_async).await?;
|
||||
forwarder::zero_copy_bidirectional(inbound, upstream).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
use crate::config::ForwardType;
|
||||
use monoio::io::AsyncWriteRentExt;
|
||||
use monoio::net::{TcpStream, UnixStream};
|
||||
use std::io;
|
||||
use std::os::unix::io::{AsRawFd, RawFd};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::{TcpStream, UnixStream};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum UpstreamStream {
|
||||
@@ -30,153 +27,11 @@ impl UpstreamStream {
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRawFd for UpstreamStream {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
pub async fn write_all<T: monoio::buf::IoBuf>(&mut self, buf: T) -> (io::Result<usize>, T) {
|
||||
match self {
|
||||
UpstreamStream::Tcp(s) => s.as_raw_fd(),
|
||||
UpstreamStream::Unix(s) => s.as_raw_fd(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for UpstreamStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
UpstreamStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
|
||||
UpstreamStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for UpstreamStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
match self.get_mut() {
|
||||
UpstreamStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
|
||||
UpstreamStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
match self.get_mut() {
|
||||
UpstreamStream::Tcp(s) => Pin::new(s).poll_flush(cx),
|
||||
UpstreamStream::Unix(s) => Pin::new(s).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
match self.get_mut() {
|
||||
UpstreamStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
|
||||
UpstreamStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UpstreamStream {
|
||||
pub fn into_async_stream(self) -> io::Result<AsyncStream> {
|
||||
match self {
|
||||
UpstreamStream::Tcp(s) => {
|
||||
let std = s.into_std()?;
|
||||
std.set_nonblocking(true)?;
|
||||
Ok(AsyncStream::Tcp(tokio::io::unix::AsyncFd::new(std)?))
|
||||
}
|
||||
UpstreamStream::Unix(s) => {
|
||||
let std = s.into_std()?;
|
||||
std.set_nonblocking(true)?;
|
||||
Ok(AsyncStream::Unix(tokio::io::unix::AsyncFd::new(std)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum AsyncStream {
|
||||
Tcp(tokio::io::unix::AsyncFd<std::net::TcpStream>),
|
||||
Unix(tokio::io::unix::AsyncFd<std::os::unix::net::UnixStream>),
|
||||
}
|
||||
|
||||
impl AsyncStream {
|
||||
pub fn from_tokio_tcp(stream: tokio::net::TcpStream) -> io::Result<Self> {
|
||||
let std = stream.into_std()?;
|
||||
std.set_nonblocking(true)?;
|
||||
Ok(AsyncStream::Tcp(tokio::io::unix::AsyncFd::new(std)?))
|
||||
}
|
||||
|
||||
pub async fn splice_read(&self, pipe_out: RawFd, len: usize) -> io::Result<usize> {
|
||||
match self {
|
||||
AsyncStream::Tcp(fd) => perform_splice_read(fd, pipe_out, len).await,
|
||||
AsyncStream::Unix(fd) => perform_splice_read(fd, pipe_out, len).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn splice_write(&self, pipe_in: RawFd, len: usize) -> io::Result<usize> {
|
||||
match self {
|
||||
AsyncStream::Tcp(fd) => perform_splice_write(fd, pipe_in, len).await,
|
||||
AsyncStream::Unix(fd) => perform_splice_write(fd, pipe_in, len).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_splice_read<T: AsRawFd>(
|
||||
fd: &tokio::io::unix::AsyncFd<T>,
|
||||
pipe_out: RawFd,
|
||||
len: usize,
|
||||
) -> io::Result<usize> {
|
||||
loop {
|
||||
let mut guard = fd.readable().await?;
|
||||
match guard.try_io(|inner| unsafe {
|
||||
let res = libc::splice(
|
||||
inner.as_raw_fd(),
|
||||
std::ptr::null_mut(),
|
||||
pipe_out,
|
||||
std::ptr::null_mut(),
|
||||
len,
|
||||
libc::SPLICE_F_MOVE | libc::SPLICE_F_NONBLOCK,
|
||||
);
|
||||
if res >= 0 {
|
||||
Ok(res as usize)
|
||||
} else {
|
||||
Err(io::Error::last_os_error())
|
||||
}
|
||||
}) {
|
||||
Ok(res) => return res,
|
||||
Err(_would_block) => continue, // try_io clears readiness
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_splice_write<T: AsRawFd>(
|
||||
fd: &tokio::io::unix::AsyncFd<T>,
|
||||
pipe_in: RawFd,
|
||||
len: usize,
|
||||
) -> io::Result<usize> {
|
||||
loop {
|
||||
let mut guard = fd.writable().await?;
|
||||
match guard.try_io(|inner| unsafe {
|
||||
let res = libc::splice(
|
||||
pipe_in,
|
||||
std::ptr::null_mut(),
|
||||
inner.as_raw_fd(),
|
||||
std::ptr::null_mut(),
|
||||
len,
|
||||
libc::SPLICE_F_MOVE | libc::SPLICE_F_NONBLOCK,
|
||||
);
|
||||
if res >= 0 {
|
||||
Ok(res as usize)
|
||||
} else {
|
||||
Err(io::Error::last_os_error())
|
||||
}
|
||||
}) {
|
||||
Ok(res) => return res,
|
||||
Err(_would_block) => continue,
|
||||
UpstreamStream::Tcp(s) => s.write_all(buf).await,
|
||||
UpstreamStream::Unix(s) => s.write_all(buf).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
27
src/main.rs
27
src/main.rs
@@ -20,8 +20,29 @@ fn print_help() {
|
||||
println!(" -h, --help print this help message");
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
use monoio::time::TimeDriver;
|
||||
use monoio::{IoUringDriver, RuntimeBuilder};
|
||||
|
||||
fn main() {
|
||||
let mut uring_builder = io_uring::IoUring::builder();
|
||||
uring_builder.setup_sqpoll(2000); // 2000ms idle timeout
|
||||
// Optimizations for single-threaded runtime (requires Linux 6.0+)
|
||||
|
||||
let mut rt = RuntimeBuilder::<TimeDriver<IoUringDriver>>::new()
|
||||
.with_entries(32768)
|
||||
.uring_builder(uring_builder)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
rt.block_on(async {
|
||||
if let Err(e) = run().await {
|
||||
error!("application error: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn run() -> anyhow::Result<()> {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
let mut config_path = None;
|
||||
@@ -70,7 +91,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
info!("loading config from {}", config_path.display());
|
||||
|
||||
let config = Config::load(&config_path).await.map_err(|e| {
|
||||
let config = Config::load(&config_path).map_err(|e| {
|
||||
error!("failed to load config: {}", e);
|
||||
e
|
||||
})?;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use bytes::BytesMut;
|
||||
use monoio::io::AsyncReadRent;
|
||||
use std::io;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyInfo {
|
||||
@@ -19,7 +19,7 @@ pub enum Version {
|
||||
const V1_PREFIX: &[u8] = b"PROXY ";
|
||||
const V2_PREFIX: &[u8] = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"; // 12 bytes
|
||||
|
||||
pub async fn read_proxy_header<T: AsyncRead + Unpin>(
|
||||
pub async fn read_proxy_header<T: AsyncReadRent>(
|
||||
stream: &mut T,
|
||||
) -> io::Result<(Option<ProxyInfo>, BytesMut)> {
|
||||
let mut buf = BytesMut::with_capacity(512);
|
||||
@@ -27,7 +27,9 @@ pub async fn read_proxy_header<T: AsyncRead + Unpin>(
|
||||
// Read enough to distinguish version
|
||||
|
||||
// Initial read
|
||||
let n = stream.read_buf(&mut buf).await?;
|
||||
let (res, b) = stream.read(buf).await;
|
||||
buf = b;
|
||||
let n = res?;
|
||||
if n == 0 {
|
||||
return Ok((None, buf));
|
||||
}
|
||||
@@ -48,7 +50,7 @@ pub async fn read_proxy_header<T: AsyncRead + Unpin>(
|
||||
Ok((None, buf))
|
||||
}
|
||||
|
||||
async fn parse_v1<T: AsyncRead + Unpin>(
|
||||
async fn parse_v1<T: AsyncReadRent>(
|
||||
stream: &mut T,
|
||||
mut buf: BytesMut,
|
||||
) -> io::Result<(Option<ProxyInfo>, BytesMut)> {
|
||||
@@ -88,7 +90,9 @@ async fn parse_v1<T: AsyncRead + Unpin>(
|
||||
}
|
||||
|
||||
// Read more
|
||||
let n = stream.read_buf(&mut buf).await?;
|
||||
let (res, b) = stream.read(buf).await;
|
||||
buf = b;
|
||||
let n = res?;
|
||||
if n == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
@@ -104,7 +108,7 @@ async fn parse_v1<T: AsyncRead + Unpin>(
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_v2<T: AsyncRead + Unpin>(
|
||||
async fn parse_v2<T: AsyncReadRent>(
|
||||
stream: &mut T,
|
||||
mut buf: BytesMut,
|
||||
) -> io::Result<(Option<ProxyInfo>, BytesMut)> {
|
||||
@@ -114,7 +118,9 @@ async fn parse_v2<T: AsyncRead + Unpin>(
|
||||
// 15th-16th: len (u16 big endian)
|
||||
|
||||
while buf.len() < 16 {
|
||||
let n = stream.read_buf(&mut buf).await?;
|
||||
let (res, b) = stream.read(buf).await;
|
||||
buf = b;
|
||||
let n = res?;
|
||||
if n == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
@@ -138,7 +144,9 @@ async fn parse_v2<T: AsyncRead + Unpin>(
|
||||
|
||||
// Read payload
|
||||
while buf.len() < 16 + len {
|
||||
let n = stream.read_buf(&mut buf).await?;
|
||||
let (res, b) = stream.read(buf).await;
|
||||
buf = b;
|
||||
let n = res?;
|
||||
if n == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
|
||||
Reference in New Issue
Block a user