feat: implement zero-downtime reload (SIGHUP-triggered process handover via SO_REUSEPORT) and graceful connection-draining shutdown

This commit is contained in:
2026-01-27 13:41:56 +08:00
parent 64980f00c5
commit a71e950734
8 changed files with 294 additions and 122 deletions

27
Cargo.lock generated
View File

@@ -1816,6 +1816,15 @@ dependencies = [
"autocfg",
]
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]]
name = "mime"
version = "0.3.17"
@@ -1910,7 +1919,20 @@ dependencies = [
"bitflags 1.3.2",
"cfg-if",
"libc",
"memoffset",
"memoffset 0.6.5",
]
[[package]]
name = "nix"
version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225e7cfe711e0ba79a68baeddb2982723e4235247aefce1482f2f16c27865b66"
dependencies = [
"bitflags 2.10.0",
"cfg-if",
"cfg_aliases",
"libc",
"memoffset 0.9.1",
]
[[package]]
@@ -2275,7 +2297,7 @@ dependencies = [
"httpdate",
"libc",
"log",
"nix",
"nix 0.24.3",
"once_cell",
"openssl-probe 0.1.6",
"parking_lot",
@@ -3876,6 +3898,7 @@ dependencies = [
"httparse",
"ipnet",
"libc",
"nix 0.31.1",
"once_cell",
"openssl",
"pingora",

View File

@@ -30,6 +30,7 @@ httparse = "1.10.1"
openssl = { version = "0.10" }
serde_ignored = "0.1.14"
tokio-openssl = "0.6"
nix = { version = "0.31.1", features = ["signal", "process", "socket"] }
[features]
default = []

View File

@@ -4,7 +4,7 @@ use crate::core::server::stream::InboundStream;
use bytes::BytesMut;
use openssl::ssl::{Ssl, SslAcceptor};
use pingora::protocols::l4::socket::SocketAddr;
use pingora::server::ShutdownWatch;
// ShutdownWatch removed
use std::os::unix::fs::PermissionsExt;
use std::path::PathBuf;
use std::sync::Arc;
@@ -95,11 +95,61 @@ pub async fn bind_listener(
Ok(UnifiedListener::Unix(listener, path_buf))
} else {
// TCP
let listener = TcpListener::bind(addr_str).await.map_err(|e| {
// TCP with SO_REUSEPORT
use nix::sys::socket::{setsockopt, sockopt};
use std::net::SocketAddr;
// AsRawFd removed
let addr: SocketAddr = addr_str.parse().map_err(|e: std::net::AddrParseError| {
error!("[{}] invalid address {}: {}", service_name, addr_str, e);
anyhow::anyhow!(e)
})?;
let domain = if addr.is_ipv4() {
socket2::Domain::IPV4
} else {
socket2::Domain::IPV6
};
let socket = socket2::Socket::new(domain, socket2::Type::STREAM, None).map_err(|e| {
error!("[{}] failed to create socket: {}", service_name, e);
e
})?;
#[cfg(unix)]
{
if let Err(e) = setsockopt(&socket, sockopt::ReusePort, &true) {
warn!("[{}] failed to set SO_REUSEPORT: {}", service_name, e);
}
if let Err(e) = setsockopt(&socket, sockopt::ReuseAddr, &true) {
warn!("[{}] failed to set SO_REUSEADDR: {}", service_name, e);
}
}
socket.set_nonblocking(true)?;
// Convert std::net::SocketAddr to socket2::SockAddr
let sock_addr = socket2::SockAddr::from(addr);
socket.bind(&sock_addr).map_err(|e| {
error!("[{}] failed to bind {}: {}", service_name, addr_str, e);
e
})?;
socket.listen(1024).map_err(|e| {
error!("[{}] failed to listen {}: {}", service_name, addr_str, e);
e
})?;
let std_listener: std::net::TcpListener = socket.into();
let listener = TcpListener::from_std(std_listener).map_err(|e| {
error!(
"[{}] failed to convert to tokio listener: {}",
service_name, e
);
e
})?;
Ok(UnifiedListener::Tcp(listener))
}
}
@@ -110,7 +160,7 @@ pub async fn serve_listener_loop<F, Fut>(
real_ip_config: Option<crate::config::RealIpConfig>,
proxy_cfg: Option<String>,
tls_acceptor: Option<Arc<SslAcceptor>>,
_shutdown: ShutdownWatch,
mut shutdown_rx: tokio::sync::broadcast::Receiver<()>,
handler: F,
) where
F: Fn(UnifiedPingoraStream, Option<crate::protocol::ProxyInfo>, std::net::SocketAddr) -> Fut
@@ -120,120 +170,169 @@ pub async fn serve_listener_loop<F, Fut>(
+ Clone,
Fut: std::future::Future<Output = ()> + Send,
{
use std::sync::atomic::{AtomicUsize, Ordering};
// Track active connections
let active_connections = Arc::new(AtomicUsize::new(0));
let notify_shutdown = Arc::new(tokio::sync::Notify::new());
loop {
match listener.accept().await {
Ok((mut stream, client_addr)) => {
let proxy_cfg = proxy_cfg.clone();
let service = service.clone();
let real_ip_config = real_ip_config.clone();
let handler = handler.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::select! {
_ = shutdown_rx.recv() => {
info!("[{}] shutdown signal received, stopping acceptance", service.name);
break;
}
accept_res = listener.accept() => {
match accept_res {
Ok((mut stream, client_addr)) => {
let proxy_cfg = proxy_cfg.clone();
let service = service.clone();
let real_ip_config = real_ip_config.clone();
let handler = handler.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::spawn(async move {
let mut buffer = BytesMut::new();
let mut proxy_info = None;
// Increment counter
active_connections.fetch_add(1, Ordering::SeqCst);
let active_connections = active_connections.clone();
let notify_shutdown = notify_shutdown.clone();
// 1. Read PROXY header
if proxy_cfg.is_some() {
match crate::protocol::read_proxy_header(&mut stream).await {
Ok((info, buf)) => {
buffer = buf;
if let Some(info) = info {
// Validate version
let valid = match proxy_cfg.as_deref() {
Some("v1") => info.version == crate::protocol::Version::V1,
Some("v2") => info.version == crate::protocol::Version::V2,
_ => true,
};
if !valid {
warn!("[{}] proxy protocol version mismatch", service.name);
tokio::spawn(async move {
// Ensure we decrement on drop
struct ConnectionGuard {
counter: Arc<AtomicUsize>,
notify: Arc<tokio::sync::Notify>,
}
impl Drop for ConnectionGuard {
fn drop(&mut self) {
let prev = self.counter.fetch_sub(1, Ordering::SeqCst);
if prev == 1 {
self.notify.notify_waiters();
}
proxy_info = Some(info);
} else {
let msg = format!("strict proxy protocol violation from {}", client_addr);
error!("[{}] {}", service.name, msg);
return; // Close connection
}
}
Err(e) => {
error!("failed to read proxy header: {}", e);
return;
}
}
}
let _guard = ConnectionGuard {
counter: active_connections,
notify: notify_shutdown,
};
// 2. Resolve Real IP (consumes stream/buffer for XFF peeking if needed).
let mut buffer = BytesMut::new();
let mut proxy_info = None;
let (real_peer_ip, real_peer_port) = match crate::core::server::handler::resolve_real_ip(
&real_ip_config,
client_addr,
&proxy_info,
&mut stream,
&mut buffer,
)
.await
{
Ok((ip, port)) => (ip, port),
Err(e) => {
error!("[{}] real ip resolution failed: {}", service.name, e);
// Fallback or abort?
// Abort is safer if I/O broken.
return;
}
};
let local_addr = match &stream {
InboundStream::Tcp(s) => s.local_addr().ok(),
_ => None,
}
.unwrap_or_else(|| "0.0.0.0:0".parse().unwrap());
// 3. Construct base PingoraStream
let stream = PingoraStream::new(
stream,
buffer,
match SocketAddr::from(std::net::SocketAddr::new(real_peer_ip, real_peer_port)) {
SocketAddr::Inet(addr) => addr,
_ => unreachable!(),
},
match SocketAddr::from(local_addr) {
SocketAddr::Inet(addr) => addr,
_ => unreachable!(),
},
);
// 4. TLS Handshake if configured
let stream: UnifiedPingoraStream = if let Some(acceptor) = tls_acceptor {
match Ssl::new(acceptor.context()) {
Ok(ssl) => match SslStream::new(ssl, stream) {
Ok(mut ssl_stream) => match std::pin::Pin::new(&mut ssl_stream).accept().await {
Ok(_) => UnifiedPingoraStream::Tls(PingoraTlsStream::new(ssl_stream)),
// 1. Read PROXY header
if proxy_cfg.is_some() {
match crate::protocol::read_proxy_header(&mut stream).await {
Ok((info, buf)) => {
buffer = buf;
if let Some(info) = info {
// Validate version
let valid = match proxy_cfg.as_deref() {
Some("v1") => info.version == crate::protocol::Version::V1,
Some("v2") => info.version == crate::protocol::Version::V2,
_ => true,
};
if !valid {
warn!("[{}] proxy protocol version mismatch", service.name);
}
proxy_info = Some(info);
} else {
let msg = format!("strict proxy protocol violation from {}", client_addr);
error!("[{}] {}", service.name, msg);
return; // Close connection
}
}
Err(e) => {
error!("[{}] tls handshake failed: {}", service.name, e);
error!("failed to read proxy header: {}", e);
return;
}
},
}
}
// 2. Resolve Real IP (consumes stream/buffer for XFF peeking if needed).
let (real_peer_ip, real_peer_port) = match crate::core::server::handler::resolve_real_ip(
&real_ip_config,
client_addr,
&proxy_info,
&mut stream,
&mut buffer,
)
.await
{
Ok((ip, port)) => (ip, port),
Err(e) => {
error!("[{}] failed to create ssl stream: {}", service.name, e);
error!("[{}] real ip resolution failed: {}", service.name, e);
// Fallback or abort?
// Abort is safer if I/O broken.
return;
}
},
Err(e) => {
error!("[{}] failed to create ssl object: {}", service.name, e);
return;
}
}
} else {
UnifiedPingoraStream::Plain(stream)
};
};
// 5. Handler
handler(stream, proxy_info, client_addr).await;
});
}
Err(e) => {
error!("accept error: {}", e);
let local_addr = match &stream {
InboundStream::Tcp(s) => s.local_addr().ok(),
_ => None,
}
.unwrap_or_else(|| "0.0.0.0:0".parse().unwrap());
// 3. Construct base PingoraStream
let stream = PingoraStream::new(
stream,
buffer,
match SocketAddr::from(std::net::SocketAddr::new(real_peer_ip, real_peer_port)) {
SocketAddr::Inet(addr) => addr,
_ => unreachable!(),
},
match SocketAddr::from(local_addr) {
SocketAddr::Inet(addr) => addr,
_ => unreachable!(),
},
);
// 4. TLS Handshake if configured
let stream: UnifiedPingoraStream = if let Some(acceptor) = tls_acceptor {
match Ssl::new(acceptor.context()) {
Ok(ssl) => match SslStream::new(ssl, stream) {
Ok(mut ssl_stream) => match std::pin::Pin::new(&mut ssl_stream).accept().await {
Ok(_) => UnifiedPingoraStream::Tls(PingoraTlsStream::new(ssl_stream)),
Err(e) => {
error!("[{}] tls handshake failed: {}", service.name, e);
return;
}
},
Err(e) => {
error!("[{}] failed to create ssl stream: {}", service.name, e);
return;
}
},
Err(e) => {
error!("[{}] failed to create ssl object: {}", service.name, e);
return;
}
}
} else {
UnifiedPingoraStream::Plain(stream)
};
// 5. Handler
handler(stream, proxy_info, client_addr).await;
});
}
Err(e) => {
error!("accept error: {}", e);
}
}
}
}
}
// Graceful shutdown: wait for active connections
drop(listener); // Close socket immediately
if active_connections.load(Ordering::SeqCst) > 0 {
info!(
"[{}] waiting for {} active connections...",
service.name,
active_connections.load(Ordering::SeqCst)
);
notify_shutdown.notified().await;
}
info!("[{}] shutdown complete", service.name);
}

View File

@@ -4,7 +4,6 @@ use crate::db::clickhouse::ClickHouseLogger;
use pingora::apps::ServerApp;
use std::os::unix::fs::PermissionsExt;
use std::sync::{Arc, Barrier};
use tokio::signal;
use tracing::{error, info};
pub mod context;
@@ -17,7 +16,10 @@ use self::handler::handle_connection;
use self::listener::{bind_listener, serve_listener_loop, UnifiedListener};
use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
pub async fn run(config: Config) -> anyhow::Result<()> {
pub async fn run(
config: Config,
shutdown_tx: tokio::sync::broadcast::Sender<()>,
) -> anyhow::Result<()> {
let db_logger = ClickHouseLogger::new(&config.database).map_err(|e| {
error!("database: {}", e);
e
@@ -160,8 +162,7 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
);
}
let shutdown_dummy =
pingora::server::ShutdownWatch::from(tokio::sync::watch::channel(false).1);
let shutdown_rx = shutdown_tx.subscribe();
if is_tcp_service {
// --- TCP Handler (with startup check) ---
@@ -185,7 +186,7 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
real_ip_config.clone(),
proxy_proto_config,
tls_acceptor,
shutdown_dummy,
shutdown_rx,
move |stream, info, client_addr| {
let db = db.clone();
let svc = svc_cfg.clone();
@@ -226,13 +227,17 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
std::mem::forget(service_obj);
let app = Arc::new(app);
// pass the REAL shutdown signal to serve_listener_loop.
let shutdown_dummy =
pingora::server::ShutdownWatch::from(tokio::sync::watch::channel(false).1);
join_set.spawn(serve_listener_loop(
listener,
service_config,
real_ip_config,
proxy_proto_config,
tls_acceptor,
shutdown_dummy.clone(),
shutdown_rx,
move |stream, info, client_addr| {
let app = app.clone();
let shutdown = shutdown_dummy.clone();
@@ -323,15 +328,13 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
}
}
match signal::ctrl_c().await {
Ok(()) => {
info!("shutdown signal received.");
}
Err(err) => {
error!("unable to listen for shutdown signal: {}", err);
// Wait for all server components to finish gracefully
while let Some(res) = join_set.join_next().await {
if let Err(e) = res {
error!("task panicked: {}", e);
}
}
join_set.shutdown().await;
info!("server components finished.");
Ok(())
}

View File

@@ -4,6 +4,7 @@ use traudit::core;
use anyhow::bail;
use std::env;
use std::path::Path;
use tokio::signal;
use tracing::{error, info};
pub const VERSION: &str = concat!("v", env!("CARGO_PKG_VERSION"));
@@ -106,9 +107,49 @@ async fn main() -> anyhow::Result<()> {
return Ok(());
}
if let Err(_e) = core::server::run(config).await {
std::process::exit(1);
// Create a channel to signal shutdown to the server component
let (shutdown_tx, _shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
let shutdown_tx_clone = shutdown_tx.clone();
// Run server in a separate task
let server_handle = tokio::spawn(async move {
if let Err(e) = core::server::run(config, shutdown_tx_clone).await {
error!("server error: {}", e);
std::process::exit(1);
}
});
// Signal handling loop
let mut sighup = signal::unix::signal(signal::unix::SignalKind::hangup())?;
let mut sigint = signal::unix::signal(signal::unix::SignalKind::interrupt())?;
let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?;
tokio::select! {
_ = sighup.recv() => {
info!("received SIGHUP (reload). spawning new process...");
// Spawn new process
let args: Vec<String> = env::args().collect();
match std::process::Command::new(&args[0])
.args(&args[1..])
.spawn() {
Ok(child) => info!("spawned new process with pid: {}", child.id()),
Err(e) => error!("failed to spawn new process: {}", e),
}
// Initiate graceful shutdown for this process
info!("shutting down old process gracefully...");
}
_ = sigint.recv() => {
info!("received SIGINT, shutdown...");
}
_ = sigterm.recv() => {
info!("received SIGTERM, shutdown...");
}
}
// Send shutdown signal to server components
let _ = shutdown_tx.send(());
// Wait for server to finish (graceful drain)
let _ = server_handle.await;
Ok(())
}

View File

@@ -410,7 +410,8 @@ pub async fn run_tcp_test(test_name: &str, proxy_proto: Option<&str>, is_unix: b
.await;
tokio::spawn(async move {
let _ = traudit::core::server::run(res.config).await;
let (tx, _rx) = tokio::sync::broadcast::channel(1);
let _ = traudit::core::server::run(res.config, tx).await;
});
// Wait for traudit startup and DB connect
tokio::time::sleep(Duration::from_millis(1000)).await;
@@ -490,7 +491,8 @@ pub async fn run_http_test(
.await;
tokio::spawn(async move {
let _ = traudit::core::server::run(res.config).await;
let (tx, _rx) = tokio::sync::broadcast::channel(1);
let _ = traudit::core::server::run(res.config, tx).await;
});
tokio::time::sleep(Duration::from_millis(1000)).await;

View File

@@ -302,7 +302,8 @@ async fn test_proxy_chain() {
let res = prepare_chain_env().await;
tokio::spawn(async move {
let _ = traudit::core::server::run(res.config).await;
let (tx, _rx) = tokio::sync::broadcast::channel(1);
let _ = traudit::core::server::run(res.config, tx).await;
});
tokio::time::sleep(Duration::from_millis(2000)).await;

View File

@@ -7,6 +7,8 @@ Type=notify
RuntimeDirectory=traudit
WorkingDirectory=/run/traudit
ExecStart=/usr/bin/traudit -f /etc/traudit/config.yaml
ExecReload=/bin/kill -HUP $MAINPID
KillSignal=SIGINT
Restart=on-failure
RestartSec=5s