mirror of
https://github.com/awfufu/traudit
synced 2026-03-01 05:29:44 +08:00
feat: implement zero-downtime hot reload via FD passing and consolidate server loops
This commit is contained in:
3
Cargo.lock
generated
3
Cargo.lock
generated
@@ -3887,7 +3887,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "traudit"
|
||||
version = "0.0.7"
|
||||
version = "0.0.8"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
@@ -3909,6 +3909,7 @@ dependencies = [
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"serde_ignored",
|
||||
"serde_json",
|
||||
"serde_repr",
|
||||
"serde_yaml 0.9.34+deprecated",
|
||||
"socket2",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "traudit"
|
||||
version = "0.0.7"
|
||||
version = "0.0.8"
|
||||
edition = "2021"
|
||||
authors = ["awfufu"]
|
||||
description = "A reverse proxy that streams audit records directly to databases."
|
||||
@@ -12,6 +12,7 @@ tokio = { version = "1", features = ["full"] }
|
||||
clickhouse = { version = "0.14", features = ["time"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_yaml = "0.9"
|
||||
serde_json = "1"
|
||||
http = "1"
|
||||
socket2 = "0.6"
|
||||
libc = "0.2"
|
||||
@@ -30,7 +31,7 @@ httparse = "1.10.1"
|
||||
openssl = { version = "0.10" }
|
||||
serde_ignored = "0.1.14"
|
||||
tokio-openssl = "0.6"
|
||||
nix = { version = "0.31.1", features = ["signal", "process", "socket"] }
|
||||
nix = { version = "0.31.1", features = ["signal", "process", "socket", "fs"] }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
||||
@@ -4,7 +4,6 @@ use crate::core::server::stream::InboundStream;
|
||||
use bytes::BytesMut;
|
||||
use openssl::ssl::{Ssl, SslAcceptor};
|
||||
use pingora::protocols::l4::socket::SocketAddr;
|
||||
// ShutdownWatch removed
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
@@ -46,12 +45,52 @@ impl UnifiedListener {
|
||||
}
|
||||
}
|
||||
|
||||
// Global registry for FDs to be passed during reload
|
||||
pub static FD_REGISTRY: std::sync::OnceLock<
|
||||
std::sync::Mutex<std::collections::HashMap<String, std::os::unix::io::RawFd>>,
|
||||
> = std::sync::OnceLock::new();
|
||||
|
||||
pub fn get_fd_registry(
|
||||
) -> &'static std::sync::Mutex<std::collections::HashMap<String, std::os::unix::io::RawFd>> {
|
||||
FD_REGISTRY.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()))
|
||||
}
|
||||
|
||||
pub async fn bind_listener(
|
||||
addr_str: &str,
|
||||
mode: u32,
|
||||
service_name: &str,
|
||||
) -> anyhow::Result<UnifiedListener> {
|
||||
if let Some(path) = addr_str.strip_prefix("unix://") {
|
||||
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
|
||||
|
||||
// Check if we inherited an FD for this service
|
||||
let inherited_fds_json = std::env::var("TRAUDIT_INHERITED_FDS").ok();
|
||||
let mut inherited_fd: Option<RawFd> = None;
|
||||
|
||||
if let Some(json) = inherited_fds_json {
|
||||
let map: std::collections::HashMap<String, RawFd> =
|
||||
serde_json::from_str(&json).unwrap_or_default();
|
||||
if let Some(&fd) = map.get(service_name) {
|
||||
info!("[{}] inherited fd: {}", service_name, fd);
|
||||
inherited_fd = Some(fd);
|
||||
}
|
||||
}
|
||||
|
||||
let listener = if let Some(fd) = inherited_fd {
|
||||
// Determine type based on address string prefix
|
||||
if addr_str.starts_with("unix://") {
|
||||
let l = unsafe { std::os::unix::net::UnixListener::from_raw_fd(fd) };
|
||||
// We must set it non-blocking as tokio expects
|
||||
l.set_nonblocking(true)?;
|
||||
let l = UnixListener::from_std(l)?;
|
||||
let path = std::path::PathBuf::from(addr_str.trim_start_matches("unix://"));
|
||||
UnifiedListener::Unix(l, path)
|
||||
} else {
|
||||
let l = unsafe { std::net::TcpListener::from_raw_fd(fd) };
|
||||
l.set_nonblocking(true)?;
|
||||
let l = TcpListener::from_std(l)?;
|
||||
UnifiedListener::Tcp(l)
|
||||
}
|
||||
} else if let Some(path) = addr_str.strip_prefix("unix://") {
|
||||
// Robust bind logic adapted from previous implementation
|
||||
let path_buf = std::path::Path::new(path).to_path_buf();
|
||||
|
||||
@@ -93,7 +132,7 @@ pub async fn bind_listener(
|
||||
}
|
||||
}
|
||||
|
||||
Ok(UnifiedListener::Unix(listener, path_buf))
|
||||
UnifiedListener::Unix(listener, path_buf)
|
||||
} else {
|
||||
// TCP with SO_REUSEPORT
|
||||
use nix::sys::socket::{setsockopt, sockopt};
|
||||
@@ -158,8 +197,41 @@ pub async fn bind_listener(
|
||||
e
|
||||
})?;
|
||||
|
||||
Ok(UnifiedListener::Tcp(listener))
|
||||
UnifiedListener::Tcp(listener)
|
||||
};
|
||||
|
||||
// Register duplicated FD for reload to pass to the next process.
|
||||
let raw_fd = match &listener {
|
||||
UnifiedListener::Tcp(l) => l.as_raw_fd(),
|
||||
UnifiedListener::Unix(l, _) => l.as_raw_fd(),
|
||||
};
|
||||
|
||||
// Use libc for dup to avoid nix version issues
|
||||
let dup_fd = unsafe { libc::dup(raw_fd) };
|
||||
if dup_fd < 0 {
|
||||
let err = std::io::Error::last_os_error();
|
||||
error!("failed to dup fd: {}", err);
|
||||
return Err(anyhow::anyhow!(err));
|
||||
}
|
||||
|
||||
// Set CLOEXEC on the dup_fd
|
||||
let flags = unsafe { libc::fcntl(dup_fd, libc::F_GETFD) };
|
||||
if flags < 0 {
|
||||
let _ = unsafe { libc::close(dup_fd) };
|
||||
return Err(anyhow::anyhow!(std::io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
if unsafe { libc::fcntl(dup_fd, libc::F_SETFD, flags | libc::FD_CLOEXEC) } < 0 {
|
||||
let _ = unsafe { libc::close(dup_fd) };
|
||||
return Err(anyhow::anyhow!(std::io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
get_fd_registry()
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(service_name.to_string(), dup_fd);
|
||||
|
||||
Ok(listener)
|
||||
}
|
||||
|
||||
pub async fn serve_listener_loop<F, Fut>(
|
||||
|
||||
@@ -2,13 +2,12 @@ use crate::config::Config;
|
||||
use crate::core::upstream::UpstreamStream;
|
||||
use crate::db::clickhouse::ClickHouseLogger;
|
||||
use pingora::apps::ServerApp;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::sync::{Arc, Barrier};
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
|
||||
pub mod context;
|
||||
pub mod handler;
|
||||
mod listener;
|
||||
pub mod listener;
|
||||
mod pingora_compat;
|
||||
pub mod stream;
|
||||
|
||||
@@ -37,11 +36,9 @@ pub async fn run(
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// JoinSet to manage all server tasks
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
|
||||
// Pingora server initialization (TLS only or Standard HTTP)
|
||||
let mut pingora_services = Vec::new();
|
||||
|
||||
for service in config.services {
|
||||
let db = db.clone();
|
||||
for bind in &service.binds {
|
||||
@@ -51,26 +48,9 @@ pub async fn run(
|
||||
let mode = bind.mode;
|
||||
let real_ip_config = bind.real_ip.clone();
|
||||
|
||||
// Use custom loop for TCP services or HTTP services requiring PROXY protocol parsing (not fully supported by pingora standard loop).
|
||||
|
||||
let is_tcp_service = service.service_type == "tcp";
|
||||
// Use custom loop if Proxy Protocol is enabled, even if TLS is used
|
||||
let is_http_proxy = service.service_type == "http" && bind.proxy.is_some();
|
||||
|
||||
let use_custom_loop = is_tcp_service || is_http_proxy;
|
||||
|
||||
if !use_custom_loop {
|
||||
// Use Standard Pingora Service (For TLS, or Pure HTTP, or Unix HTTP without PROXY)
|
||||
pingora_services.push((
|
||||
service_config,
|
||||
bind.clone(),
|
||||
bind.tls.clone(),
|
||||
real_ip_config,
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
// --- Custom Loop Logic ---
|
||||
// Custom Loop
|
||||
|
||||
let mut tls_acceptor = None;
|
||||
if let Some(tls_config) = &bind.tls {
|
||||
@@ -91,7 +71,7 @@ pub async fn run(
|
||||
error!("failed to load cert chain {}: {}", tls_config.cert, e);
|
||||
anyhow::anyhow!(e)
|
||||
})?;
|
||||
// ALPN support matching Pingora's defaults?
|
||||
// ALPN support matching Pingora's defaults
|
||||
acceptor.set_alpn_protos(b"\x02h2\x08http/1.1").ok();
|
||||
tls_acceptor = Some(Arc::new(acceptor.build()));
|
||||
}
|
||||
@@ -150,14 +130,14 @@ pub async fn run(
|
||||
format!(" {}", tags.join(" "))
|
||||
};
|
||||
|
||||
if is_http_proxy {
|
||||
if is_tcp_service {
|
||||
info!(
|
||||
"[{}] listening on http {} {}{}",
|
||||
"[{}] listening on {} {}{}",
|
||||
service_config.name, listen_type, bind_addr, tag_str
|
||||
);
|
||||
} else {
|
||||
info!(
|
||||
"[{}] listening on {} {}{}",
|
||||
"[{}] listening on http {} {}{}",
|
||||
service_config.name, listen_type, bind_addr, tag_str
|
||||
);
|
||||
}
|
||||
@@ -262,61 +242,6 @@ pub async fn run(
|
||||
}
|
||||
}
|
||||
|
||||
// Run Pingora in a separate thread if needed
|
||||
if !pingora_services.is_empty() {
|
||||
let barrier = Arc::new(Barrier::new(2));
|
||||
let barrier_clone = barrier.clone();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
use crate::core::pingora_proxy::TrauditProxy;
|
||||
use pingora::proxy::http_proxy_service;
|
||||
use pingora::server::configuration::Opt;
|
||||
use pingora::server::Server;
|
||||
|
||||
if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
let mut server = Server::new(Some(Opt::default())).unwrap();
|
||||
server.bootstrap();
|
||||
|
||||
for (svc_config, bind, tls, real_ip) in pingora_services {
|
||||
let proxy = TrauditProxy {
|
||||
db: db.clone(),
|
||||
service_config: svc_config.clone(),
|
||||
listen_addr: bind.addr.clone(),
|
||||
real_ip,
|
||||
add_xff_header: bind.add_xff_header,
|
||||
};
|
||||
|
||||
let mut service = http_proxy_service(&server.configuration, proxy);
|
||||
|
||||
if let Some(tls_config) = tls {
|
||||
let key_path = tls_config.key.as_deref().unwrap_or(&tls_config.cert);
|
||||
service
|
||||
.add_tls(&bind.addr, &tls_config.cert, key_path)
|
||||
.unwrap();
|
||||
info!("[{}] listening on https {}", svc_config.name, bind.addr);
|
||||
} else if bind.addr.starts_with("unix://") {
|
||||
let path = bind.addr.trim_start_matches("unix://");
|
||||
service.add_uds(path, Some(std::fs::Permissions::from_mode(bind.mode)));
|
||||
info!("[{}] listening on http unix {}", svc_config.name, path);
|
||||
} else {
|
||||
service.add_tcp(&bind.addr);
|
||||
info!("[{}] listening on http {}", svc_config.name, bind.addr);
|
||||
}
|
||||
|
||||
server.add_service(service);
|
||||
}
|
||||
|
||||
barrier_clone.wait();
|
||||
server.run_forever();
|
||||
})) {
|
||||
error!("pingora server panicked: {:?}", e);
|
||||
}
|
||||
error!("pingora server exited unexpectedly!");
|
||||
});
|
||||
|
||||
barrier.wait();
|
||||
}
|
||||
|
||||
info!("traudit started...");
|
||||
|
||||
// notify systemd if configured
|
||||
|
||||
35
src/main.rs
35
src/main.rs
@@ -127,11 +127,40 @@ async fn main() -> anyhow::Result<()> {
|
||||
tokio::select! {
|
||||
_ = sighup.recv() => {
|
||||
info!("received SIGHUP (reload). spawning new process...");
|
||||
|
||||
// Prepare FDs to pass
|
||||
let fd_map = {
|
||||
let registry = traudit::core::server::listener::get_fd_registry().lock().unwrap();
|
||||
registry.clone()
|
||||
};
|
||||
|
||||
let fd_json = serde_json::to_string(&fd_map).unwrap_or_default();
|
||||
info!("passing fds: {}", fd_json);
|
||||
|
||||
// Spawn new process
|
||||
let args: Vec<String> = env::args().collect();
|
||||
match std::process::Command::new(&args[0])
|
||||
.args(&args[1..])
|
||||
.spawn() {
|
||||
let mut cmd = std::process::Command::new(&args[0]);
|
||||
cmd.args(&args[1..]);
|
||||
cmd.env("TRAUDIT_INHERITED_FDS", fd_json);
|
||||
|
||||
unsafe {
|
||||
// Use pre_exec to clear CLOEXEC on the FDs to be inherited.
|
||||
let fd_map_for_closure = fd_map.clone();
|
||||
|
||||
use std::os::unix::process::CommandExt;
|
||||
cmd.pre_exec(move || {
|
||||
for (_, &fd) in &fd_map_for_closure {
|
||||
// Clear FD_CLOEXEC flag
|
||||
let flags = libc::fcntl(fd, libc::F_GETFD);
|
||||
if flags >= 0 {
|
||||
libc::fcntl(fd, libc::F_SETFD, flags & !libc::FD_CLOEXEC);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
match cmd.spawn() {
|
||||
Ok(child) => {
|
||||
let child_pid = child.id();
|
||||
info!("spawned new process with pid: {}", child_pid);
|
||||
|
||||
Reference in New Issue
Block a user