feat: support per-bind real_ip configuration and unify tcp/http listener logic

This commit is contained in:
2026-01-18 22:32:16 +08:00
parent d757a23c7a
commit 4e5fdf3d21
10 changed files with 906 additions and 476 deletions

11
Cargo.lock generated
View File

@@ -1064,6 +1064,15 @@ dependencies = [
"hashbrown 0.16.1",
]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
dependencies = [
"serde",
]
[[package]]
name = "itoa"
version = "1.0.17"
@@ -2476,6 +2485,8 @@ dependencies = [
"async-trait",
"bytes",
"clickhouse",
"httparse",
"ipnet",
"libc",
"pingora",
"serde",

View File

@@ -24,6 +24,8 @@ async-trait = "0.1"
time = { version = "0.3.45", features = ["serde", "macros", "formatting", "parsing"] }
serde_repr = "0.1.20"
pingora = { version = "0.6", features = ["lb", "openssl"] }
ipnet = { version = "2.11.0", features = ["serde"] }
httparse = "1.10.1"
[dev-dependencies]
tempfile = "3"

View File

@@ -2,37 +2,43 @@
database:
type: clickhouse
dsn: "http://user:password@ip:port/traudit"
# dsn: "http://user:password@ip:port/database"
dsn: http://traudit:traudit114514@127.0.0.1:8123/traudit
batch_size: 50
batch_timeout_secs: 5
batch_timeout_secs: 100
services:
# receives traffic from frp with v2 proxy protocol header, audits it,
# strips the header, and forwards pure tcp to local sshd.
# Receives traffic from FRP with v2 Proxy Protocol header, audits it,
# strips the header, and forwards pure TCP to local SSHD.
- name: "ssh"
forward_to: "127.0.0.1:22"
type: "tcp"
real_ip:
from: "proxy_protocol"
trust_private_ranges: true
binds:
# Entry 1: Public traffic from FRP
- addr: "unix://test.sock"
- addr: "0.0.0.0:2223"
proxy: "v2"
# Entry 2: LAN direct traffic (no Proxy Protocol)
- addr: "0.0.0.0:2222"
# real_ip:
# strategy: ["proxy_protocol", "remote_addr"] # default
forward_to: "127.0.0.1:22"
- name: "https-web"
type: "tcp"
- name: "web"
forward_to: "127.0.0.1:8080"
type: "http"
binds:
- addr: "0.0.0.0:443"
- addr: 0.0.0.0:443
tls:
cert: "/etc/ssl/certs/site.pem"
key: "/etc/ssl/private/site.key"
- addr: "0.0.0.0:4433"
tls:
cert: "/etc/ssl/certs/site.pem"
key: "/etc/ssl/private/site.key"
proxy: "v2"
cert: "/path/to/cert.crt"
key: "/path/to/key.key"
proxy: v2
real_ip:
from: "proxy_protocol"
trust_private_ranges: true
trusted_proxies:
- 1.2.3.4
forward_to: "127.0.0.1:8080"

View File

@@ -1,4 +1,5 @@
use serde::{Deserialize, Deserializer};
use std::net::IpAddr;
use std::path::Path;
use tokio::fs;
@@ -41,6 +42,58 @@ pub struct ServiceConfig {
pub forward_to: String,
}
#[derive(Debug, Deserialize, Clone)]
pub struct RealIpConfig {
#[serde(default, rename = "from")]
pub source: RealIpSource,
#[serde(default, deserialize_with = "deserialize_trusted_proxies")]
pub trusted_proxies: Vec<ipnet::IpNet>,
#[serde(default)]
pub trust_private_ranges: bool,
#[serde(default)]
pub xff_trust_depth: usize,
}
impl RealIpConfig {
pub fn is_trusted(&self, ip: IpAddr) -> bool {
// Check explicit trusted proxies
for net in &self.trusted_proxies {
if net.contains(&ip) {
return true;
}
}
if self.trust_private_ranges && is_private(&ip) {
return true;
}
false
}
}
fn is_private(ip: &IpAddr) -> bool {
match ip {
IpAddr::V4(addr) => addr.is_loopback() || addr.is_link_local() || addr.is_private(),
IpAddr::V6(addr) => {
addr.is_loopback() ||
// addr.is_unique_local() is unstable, check ranges manually
// fc00::/7
(addr.segments()[0] & 0xfe00) == 0xfc00 ||
// fe80::/10
(addr.segments()[0] & 0xffc0) == 0xfe80
}
}
}
#[derive(Debug, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "snake_case")]
pub enum RealIpSource {
ProxyProtocol,
Xff,
#[default]
RemoteAddr,
}
#[derive(Debug, Deserialize, Clone)]
pub struct BindEntry {
pub addr: String,
@@ -49,6 +102,7 @@ pub struct BindEntry {
#[serde(default = "default_socket_mode", deserialize_with = "deserialize_mode")]
pub mode: u32,
pub tls: Option<TlsConfig>,
pub real_ip: Option<RealIpConfig>,
}
#[derive(Debug, Deserialize, Clone)]
@@ -75,11 +129,7 @@ where
let value = ModeValue::deserialize(deserializer)?;
match value {
ModeValue::Integer(i) => {
// If user provides 666, they likely mean octal 0666.
// But in YAML `mode: 666` is decimal 666.
// The requirement says: "if user wrote integer (e.g. 666), process as octal"
// So we interpret the decimal value as a sequence of octal digits.
// e.g. decimal 666 -> octal 666 (which is decimal 438)
// Interpret decimal integer as octal (e.g., 666 -> 0666) per requirements.
let s = i.to_string();
u32::from_str_radix(&s, 8).map_err(serde::de::Error::custom)
}
@@ -90,6 +140,27 @@ where
}
}
fn deserialize_trusted_proxies<'de, D>(deserializer: D) -> Result<Vec<ipnet::IpNet>, D::Error>
where
D: Deserializer<'de>,
{
let strings: Vec<String> = Vec::deserialize(deserializer)?;
let mut nets = Vec::with_capacity(strings.len());
for s in strings {
if let Ok(net) = s.parse::<ipnet::IpNet>() {
nets.push(net);
} else if let Ok(ip) = s.parse::<std::net::IpAddr>() {
nets.push(ipnet::IpNet::from(ip));
} else {
return Err(serde::de::Error::custom(format!(
"invalid IP address or CIDR: {}",
s
)));
}
}
Ok(nets)
}
impl Config {
pub async fn load<P: AsRef<Path>>(path: P) -> Result<Self, anyhow::Error> {
let content = fs::read_to_string(path).await?;

View File

@@ -1,4 +1,4 @@
use crate::config::ServiceConfig;
use crate::config::{RealIpSource, ServiceConfig};
use crate::db::clickhouse::{ClickHouseLogger, HttpLog, HttpMethod};
use async_trait::async_trait;
use pingora::prelude::*;
@@ -9,6 +9,8 @@ use std::time::Instant;
pub struct TrauditProxy {
pub db: Arc<ClickHouseLogger>,
pub service_config: ServiceConfig,
pub listen_addr: String,
pub real_ip: Option<crate::config::RealIpConfig>,
}
pub struct HttpContext {
@@ -41,28 +43,129 @@ impl ProxyHttp for TrauditProxy {
}
async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
// IP Priority: Proxy Protocol > XFF > X-Real-IP > Peer
let mut client_ip: Option<IpAddr> = session
ctx.start_ts = Some(Instant::now());
// 1. Determine Source IP
let peer_addr = session
.client_addr()
.and_then(|a| a.as_inet())
.map(|a| a.ip());
.map(|a| a.ip())
.unwrap_or(IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)));
// Check headers for overrides
if let Some(xff) = session.req_header().headers.get("x-forwarded-for") {
if let Ok(xff_str) = xff.to_str() {
if let Some(first_ip) = xff_str.split(',').next() {
if let Ok(parsed_ip) = first_ip.trim().parse::<IpAddr>() {
client_ip = Some(parsed_ip); // Overwrite
let mut resolved_ip = peer_addr;
if let Some(cfg) = &self.real_ip {
match cfg.source {
RealIpSource::ProxyProtocol => {
// If custom listener was used, peer_addr is already the injected Real IP.
resolved_ip = peer_addr;
}
RealIpSource::RemoteAddr => {
resolved_ip = peer_addr;
}
RealIpSource::Xff => {
// Check trust on current peer/proxy IP
if cfg.is_trusted(peer_addr) {
if let Some(xff) = session.req_header().headers.get("x-forwarded-for") {
if let Ok(xff_str) = xff.to_str() {
let ips: Vec<&str> = xff_str.split(',').map(|s| s.trim()).collect();
if !ips.is_empty() {
// Recursive trust (0) vs Fixed Depth
if cfg.xff_trust_depth == 0 {
// Recursive: walk backwards until first untrusted
let mut candidate = None;
for ip_str in ips.iter().rev() {
if let Ok(ip) = ip_str.parse() {
if cfg.is_trusted(ip) {
continue;
} else {
candidate = Some(ip);
break;
}
}
}
// If all trusted, take the first one (leftmost)
if let Some(ip) = candidate {
resolved_ip = ip;
} else if let Some(first_str) = ips.first() {
if let Ok(ip) = first_str.parse() {
resolved_ip = ip;
}
}
} else {
// Fixed depth
let idx = if ips.len() >= cfg.xff_trust_depth {
ips.len() - cfg.xff_trust_depth
} else {
0
};
if let Some(val) = ips.get(idx) {
if let Ok(ip) = val.parse() {
resolved_ip = ip;
}
}
}
}
}
}
}
}
}
}
if let Some(ip) = client_ip {
ctx.src_ip = ip;
ctx.src_ip = resolved_ip;
// Log connection info
let src_fmt = resolved_ip.to_string();
let physical_fmt = peer_addr.to_string();
if src_fmt == physical_fmt {
// If we stuck to physical, check if there was an XFF we ignored
let xff_msg = if let Some(xff) = session.req_header().headers.get("x-forwarded-for") {
if let Ok(v) = xff.to_str() {
// Only show if we actually have RealIpConfig that denied us
if let Some(cfg) = &self.real_ip {
if !cfg.is_trusted(peer_addr) {
format!("(untrusted) xff: {}", v)
} else {
"".to_string()
}
} else {
"".to_string()
}
} else {
"".to_string()
}
} else {
"".to_string()
};
if !xff_msg.is_empty() {
tracing::info!(
"[{}] {} <- {} {}",
self.service_config.name,
self.listen_addr,
src_fmt,
xff_msg
);
} else {
tracing::info!(
"[{}] {} <- {}",
self.service_config.name,
self.listen_addr,
src_fmt
);
}
} else {
// fallback to 0.0.0.0 if entirely missing (unlikely)
ctx.src_ip = IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0));
tracing::info!(
"[{}] {} <- {} ({})",
self.service_config.name,
self.listen_addr,
src_fmt,
physical_fmt
);
}
// 2. Audit Info
@@ -137,8 +240,6 @@ impl ProxyHttp for TrauditProxy {
ctx.status_code = header.status.as_u16();
}
// Bytes (resp_body_size accumulated in filter)
ctx.req_body_size = session.body_bytes_read() as u64;
let addr_family = if ctx.src_ip.is_ipv4() {
@@ -166,7 +267,6 @@ impl ProxyHttp for TrauditProxy {
let db = self.db.clone();
tokio::spawn(async move {
if let Err(e) = db.insert_http_log(log).await {
// log error
tracing::error!("failed to insert http log: {}", e);
}
});

View File

@@ -1,185 +1,232 @@
use super::stream::InboundStream;
use crate::config::ServiceConfig;
use crate::config::{RealIpConfig, RealIpSource, ServiceConfig};
use crate::core::forwarder;
use crate::core::server::pingora_compat::PingoraStream;
use crate::core::upstream::UpstreamStream;
use crate::db::clickhouse::ClickHouseLogger;
use crate::protocol;
use crate::db::clickhouse::{ClickHouseLogger, ProxyProto};
use crate::protocol::{self, ProxyInfo};
use bytes::BytesMut;
use pingora::protocols::GetSocketDigest;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tracing::{error, info};
pub async fn handle_connection(
mut inbound: InboundStream,
stream: PingoraStream,
proxy_info: Option<ProxyInfo>,
service: ServiceConfig,
proxy_cfg: Option<String>,
db: Arc<ClickHouseLogger>,
listen_addr: String,
) -> std::io::Result<u64> {
let conn_ts = time::OffsetDateTime::now_utc();
let start_instant = std::time::Instant::now();
// Use this flag or inbound type to determine if it's a Unix socket
let is_unix = matches!(inbound, InboundStream::Unix(_));
let (mut final_ip, mut final_port) = match &inbound {
InboundStream::Tcp(s) => {
let addr = s.peer_addr()?;
// Extract resolved IP from digest (injected by listener)
let digest = stream.get_socket_digest();
let (final_ip, final_port) = if let Some(d) = digest {
if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.peer_addr() {
(addr.ip(), addr.port())
}
InboundStream::Unix(_) => (
std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
0,
),
};
let mut proto_enum = crate::db::clickhouse::ProxyProto::None;
let mut skip_log = false;
let result = async {
// read proxy protocol (if configured)
let mut buffer = bytes::BytesMut::new();
if proxy_cfg.is_some() {
// If configured, we attempt to read.
match protocol::read_proxy_header(&mut inbound).await {
Ok((proxy_info, buf)) => {
buffer = buf;
if let Some(info) = proxy_info {
let physical = inbound.peer_addr_string()?;
// Format: [ssh] unix://test.sock <- RealIP:Port (local) or [ssh] 0.0.0.0:2222 <- RealIP:Port (1.2.3.4:5678)
let physical_fmt = if matches!(inbound, InboundStream::Unix(_)) {
"local".to_string()
} else {
physical
};
info!(
"[{}] {} <- {} ({})",
service.name, listen_addr, info.source, physical_fmt
);
final_ip = info.source.ip();
final_port = info.source.port();
// Proxy info implies "proxied TCP" usually; rely on final_ip family later
proto_enum = match info.version {
protocol::Version::V1 => crate::db::clickhouse::ProxyProto::V1,
protocol::Version::V2 => crate::db::clickhouse::ProxyProto::V2,
};
// Verify version matches config if required
if let Some(ref required_ver) = proxy_cfg {
match required_ver.as_str() {
"v1" if info.version != protocol::Version::V1 => {
// warn mismatch?
}
"v2" if info.version != protocol::Version::V2 => {
// warn mismatch?
}
_ => {}
}
}
} else {
// Strict enforcement: config requires header
let physical = inbound.peer_addr_string()?;
let msg = format!("strict proxy protocol violation from {}", physical);
error!("[{}] {}", service.name, msg);
skip_log = true;
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, msg));
}
}
Err(e) => {
skip_log = true;
return Err(e);
}
}
} else {
let addr = if matches!(inbound, InboundStream::Unix(_)) {
// [ssh] unix://test.sock <- local
"local".to_string()
} else {
inbound.peer_addr_string()?
};
info!("[{}] {} <- {}", service.name, listen_addr, addr);
// Should not match other types if logic is correct
(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0)
}
// connect upstream
let mut upstream = UpstreamStream::connect(&service.forward_to).await?;
// write buffered data (peeked bytes)
if !buffer.is_empty() {
upstream.write_all_buf(&mut buffer).await?;
}
// zero-copy forwarding
let inbound_async = match inbound {
InboundStream::Tcp(s) => crate::core::upstream::AsyncStream::from_tokio_tcp(s)?,
InboundStream::Unix(s) => crate::core::upstream::AsyncStream::from_tokio_unix(s)?,
};
let upstream_async = upstream.into_async_stream()?;
let (spliced_bytes, splice_res) =
forwarder::zero_copy_bidirectional(inbound_async, upstream_async).await;
if let Err(e) = splice_res {
match e.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
tracing::debug!("[{}] connection closed with error: {}", service.name, e);
}
_ => {
error!("[{}] connection error: {}", service.name, e);
}
}
} else {
// Clean close logging removed
}
// Total bytes = initial peeked + filtered
Ok(spliced_bytes + buffer.len() as u64)
}
.await;
let duration = if result.is_ok() {
start_instant.elapsed().as_millis() as u32
} else {
0
(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0)
};
let bytes_transferred = *result.as_ref().unwrap_or(&0);
// Unwrap stream
let (inbound, mut read_buffer) = stream.into_inner();
// Finalize AddrFamily based on final_ip; Unix logic handled below
let mut addr_family = match final_ip {
std::net::IpAddr::V4(_) => crate::db::clickhouse::AddrFamily::Ipv4,
std::net::IpAddr::V6(_) => crate::db::clickhouse::AddrFamily::Ipv6,
let is_unix = matches!(inbound, InboundStream::Unix(_));
let remote_addr = match &inbound {
InboundStream::Tcp(s) => s.peer_addr()?,
InboundStream::Unix(_) => SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), 0),
};
if is_unix && proto_enum == crate::db::clickhouse::ProxyProto::None {
// Unix socket, direct connection (or no proxy header received)
addr_family = crate::db::clickhouse::AddrFamily::Unix;
// Store 0 (::)
final_ip = std::net::IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED);
final_port = 0;
// skip redundant proxy/IP resolution (done by listener); determine ProxyProto for logging
let proto_enum = if let Some(ref info) = proxy_info {
match info.version {
protocol::Version::V1 => ProxyProto::V1,
protocol::Version::V2 => ProxyProto::V2,
}
} else {
ProxyProto::None
};
// Log connection info
let src_fmt = if is_unix && proto_enum == ProxyProto::None {
"local".to_string()
} else {
final_ip.to_string()
};
let physical_fmt = if is_unix {
"local".to_string()
} else {
remote_addr.to_string()
};
if src_fmt == physical_fmt {
info!("[{}] {} <- {}", service.name, listen_addr, src_fmt);
} else {
info!(
"[{}] {} <- {} ({})",
service.name, listen_addr, src_fmt, physical_fmt
);
}
// 3. Connect Upstream
let mut upstream = UpstreamStream::connect(&service.forward_to).await?;
// 4. Write buffered data
if !read_buffer.is_empty() {
upstream.write_all_buf(&mut read_buffer).await?;
}
// 5. Zero-copy forwarding
let inbound_async = match inbound {
InboundStream::Tcp(s) => crate::core::upstream::AsyncStream::from_tokio_tcp(s)?,
InboundStream::Unix(s) => crate::core::upstream::AsyncStream::from_tokio_unix(s)?,
};
let upstream_async = upstream.into_async_stream()?;
let (spliced_bytes, splice_res) =
forwarder::zero_copy_bidirectional(inbound_async, upstream_async).await;
if let Err(e) = splice_res {
match e.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
tracing::debug!("[{}] connection closed: {}", service.name, e);
}
_ => error!("[{}] connection error: {}", service.name, e),
}
}
// Calculate total bytes
let total_bytes = spliced_bytes + read_buffer.len() as u64;
// Logging logic
let duration = start_instant.elapsed().as_millis() as u32;
// Handle Unix socket specifics for logging
let (log_ip, log_port, log_family) = if is_unix && proto_enum == ProxyProto::None {
(
IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED),
0,
crate::db::clickhouse::AddrFamily::Unix,
)
} else {
let family = match final_ip {
IpAddr::V4(_) => crate::db::clickhouse::AddrFamily::Ipv4,
IpAddr::V6(_) => crate::db::clickhouse::AddrFamily::Ipv6,
};
(final_ip, final_port, family)
};
let log_entry = crate::db::clickhouse::TcpLog {
service: service.name.clone(),
conn_ts,
duration: duration as u32,
addr_family,
ip: final_ip,
port: final_port,
duration,
addr_family: log_family,
ip: log_ip,
port: log_port,
proxy_proto: proto_enum,
bytes: bytes_transferred,
bytes: total_bytes,
};
if !skip_log {
tokio::spawn(async move {
if let Err(e) = db.insert_log(log_entry).await {
error!("failed to insert tcp log: {}", e);
tokio::spawn(async move {
if let Err(e) = db.insert_log(log_entry).await {
error!("failed to insert tcp log: {}", e);
}
});
Ok(total_bytes)
}
pub(crate) async fn resolve_real_ip(
config: &Option<RealIpConfig>,
remote_addr: SocketAddr,
proxy_info: &Option<ProxyInfo>,
inbound: &mut InboundStream,
buffer: &mut BytesMut,
) -> io::Result<(IpAddr, u16)> {
if let Some(cfg) = config {
match cfg.source {
RealIpSource::ProxyProtocol => {
if let Some(info) = proxy_info {
// Trust check: The PHYSICAL connection must be from a trusted source
if cfg.is_trusted(remote_addr.ip()) {
return Ok((info.source.ip(), info.source.port()));
}
}
}
});
RealIpSource::Xff => {
let current_ip = if let Some(info) = proxy_info {
info.source.ip()
} else {
remote_addr.ip()
};
if cfg.is_trusted(current_ip) {
if let Some(ip) = peek_xff_ip(inbound, buffer, cfg.xff_trust_depth).await? {
// XFF doesn't have port, use remote/proxy port
let port = if let Some(info) = proxy_info {
info.source.port()
} else {
remote_addr.port()
};
return Ok((ip, port));
}
}
}
RealIpSource::RemoteAddr => {
return Ok((remote_addr.ip(), remote_addr.port()));
}
}
}
result
// Fallback to Remote Address if no config or strategy failed.
Ok((remote_addr.ip(), remote_addr.port()))
}
pub(crate) async fn peek_xff_ip<T: AsyncRead + Unpin>(
stream: &mut T,
buffer: &mut BytesMut,
_trust_depth: usize,
) -> io::Result<Option<IpAddr>> {
let max_header = 4096;
loop {
if let Some(pos) = buffer.windows(4).position(|w| w == b"\r\n\r\n") {
let header_bytes = &buffer[..pos];
let mut headers = [httparse::Header {
name: "",
value: &[],
}; 32];
let mut req = httparse::Request::new(&mut headers);
if req.parse(header_bytes).is_ok() {
for header in req.headers {
if header.name.eq_ignore_ascii_case("x-forwarded-for") {
if let Ok(val) = std::str::from_utf8(header.value) {
let ips: Vec<&str> = val.split(',').map(|s| s.trim()).collect();
if let Some(ip_str) = ips.last() {
if let Ok(ip) = ip_str.parse() {
return Ok(Some(ip));
}
}
}
}
}
}
return Ok(None);
}
if buffer.len() >= max_header {
return Ok(None);
}
if stream.read_buf(buffer).await? == 0 {
return Ok(None);
}
}
}

View File

@@ -1,83 +1,205 @@
use crate::config::ServiceConfig;
use crate::core::server::pingora_compat::PingoraStream;
use crate::core::server::stream::InboundStream;
use bytes::BytesMut;
use pingora::protocols::l4::socket::SocketAddr;
use pingora::server::ShutdownWatch;
use std::os::unix::fs::PermissionsExt;
use std::path::PathBuf;
use tokio::net::{UnixListener, UnixStream};
use tracing::{error, info};
use tokio::net::{TcpListener, UnixListener, UnixStream};
use tracing::{error, info, warn};
pub struct UnixSocketGuard {
pub path: PathBuf,
pub enum UnifiedListener {
Tcp(TcpListener),
Unix(UnixListener, PathBuf), // PathBuf for cleanup on Drop
}
impl Drop for UnixSocketGuard {
impl Drop for UnifiedListener {
fn drop(&mut self) {
if let Err(_e) = std::fs::remove_file(&self.path) {
// File potentially gone or no permissions, debug log only
} else {
tracing::debug!("removed socket file {:?}", self.path);
if let UnifiedListener::Unix(_, ref path) = self {
let _ = std::fs::remove_file(path);
tracing::debug!("removed socket file {:?}", path);
}
}
}
pub async fn bind_robust(
path: &str,
impl UnifiedListener {
pub async fn accept(&self) -> std::io::Result<(InboundStream, std::net::SocketAddr)> {
match self {
UnifiedListener::Tcp(l) => {
let (stream, addr) = l.accept().await?;
Ok((InboundStream::Tcp(stream), addr))
}
UnifiedListener::Unix(l, _) => {
let (stream, _addr) = l.accept().await?;
// Mock IPv4 loopback for Unix sockets
let addr = std::net::SocketAddr::new(
std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
0,
);
Ok((InboundStream::Unix(stream), addr))
}
}
}
}
pub async fn bind_listener(
addr_str: &str,
mode: u32,
service_name: &str,
) -> anyhow::Result<(UnixListener, UnixSocketGuard)> {
let path_buf = std::path::Path::new(path).to_path_buf();
) -> anyhow::Result<UnifiedListener> {
if let Some(path) = addr_str.strip_prefix("unix://") {
// Robust bind logic adapted from previous implementation
let path_buf = std::path::Path::new(path).to_path_buf();
if path_buf.exists() {
// Check permissions; we need write access to remove it
match std::fs::symlink_metadata(&path_buf) {
Ok(_meta) => {
// We rely on subsequent operations (connect/remove) to fail with PermissionDenied if we lack access.
}
Err(e) => {
if path_buf.exists() {
// Check permissions
if let Err(e) = std::fs::symlink_metadata(&path_buf) {
if e.kind() == std::io::ErrorKind::PermissionDenied {
anyhow::bail!("Permission denied accessing existing socket: {}", path);
}
}
// Check if active
match UnixStream::connect(&path_buf).await {
Ok(_) => anyhow::bail!("Address already in use: {}", path),
Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
info!("[{}] removing stale socket file: {}", service_name, path);
std::fs::remove_file(&path_buf)?;
}
Err(e) => anyhow::bail!("failed to check existing socket {}: {}", path, e),
}
}
// Try to connect to check if it's active
match UnixStream::connect(&path_buf).await {
Ok(_) => {
// Active!
anyhow::bail!("Address already in use: {}", path);
}
Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
// Stale socket, remove it
info!("[{}] removing stale socket file: {}", service_name, path);
if let Err(rm_err) = std::fs::remove_file(&path_buf) {
anyhow::bail!("failed to remove stale socket {}: {}", path, rm_err);
let listener = UnixListener::bind(&path_buf).map_err(|e| {
error!("[{}] failed to bind {}: {}", service_name, path, e);
e
})?;
// Permissions
if let Ok(metadata) = std::fs::metadata(&path_buf) {
let mut perms = metadata.permissions();
if perms.mode() & 0o777 != mode & 0o777 {
perms.set_mode(mode);
if let Err(e) = std::fs::set_permissions(&path_buf, perms) {
error!(
"[{}] failed to set permissions on {}: {}",
service_name, path, e
);
}
}
Err(e) => {
// Other error, bail
anyhow::bail!("failed to check existing socket {}: {}", path, e);
}
}
Ok(UnifiedListener::Unix(listener, path_buf))
} else {
// TCP
let listener = TcpListener::bind(addr_str).await.map_err(|e| {
error!("[{}] failed to bind {}: {}", service_name, addr_str, e);
e
})?;
Ok(UnifiedListener::Tcp(listener))
}
}
pub async fn serve_listener_loop<F, Fut>(
listener: UnifiedListener,
service: ServiceConfig,
real_ip_config: Option<crate::config::RealIpConfig>,
proxy_cfg: Option<String>,
_shutdown: ShutdownWatch,
handler: F,
) where
F: Fn(PingoraStream, Option<crate::protocol::ProxyInfo>) -> Fut + Send + Sync + 'static + Clone,
Fut: std::future::Future<Output = ()> + Send,
{
loop {
match listener.accept().await {
Ok((mut stream, client_addr)) => {
let proxy_cfg = proxy_cfg.clone();
let service = service.clone();
let real_ip_config = real_ip_config.clone();
let handler = handler.clone();
tokio::spawn(async move {
let mut buffer = BytesMut::new();
let mut proxy_info = None;
// 1. Read PROXY header
if proxy_cfg.is_some() {
match crate::protocol::read_proxy_header(&mut stream).await {
Ok((info, buf)) => {
buffer = buf;
if let Some(info) = info {
// Validate version
let valid = match proxy_cfg.as_deref() {
Some("v1") => info.version == crate::protocol::Version::V1,
Some("v2") => info.version == crate::protocol::Version::V2,
_ => true,
};
if !valid {
warn!("[{}] proxy protocol version mismatch", service.name);
}
proxy_info = Some(info);
} else {
let msg = format!("strict proxy protocol violation from {}", client_addr);
error!("[{}] {}", service.name, msg);
return; // Close connection
}
}
Err(e) => {
error!("failed to read proxy header: {}", e);
return;
}
}
}
// 2. Resolve Real IP (consumes stream/buffer for XFF peeking if needed).
let (real_peer_ip, real_peer_port) = match crate::core::server::handler::resolve_real_ip(
&real_ip_config,
client_addr,
&proxy_info,
&mut stream,
&mut buffer,
)
.await
{
Ok((ip, port)) => (ip, port),
Err(e) => {
error!("[{}] real ip resolution failed: {}", service.name, e);
// Fallback or abort?
// Abort is safer if I/O broken.
return;
}
};
let local_addr = match &stream {
InboundStream::Tcp(s) => s.local_addr().ok(),
_ => None,
}
.unwrap_or_else(|| "0.0.0.0:0".parse().unwrap());
// 3. Construct PingoraStream
let stream = PingoraStream::new(
stream,
buffer,
match SocketAddr::from(std::net::SocketAddr::new(real_peer_ip, real_peer_port)) {
SocketAddr::Inet(addr) => addr,
_ => unreachable!(),
},
match SocketAddr::from(local_addr) {
SocketAddr::Inet(addr) => addr,
_ => unreachable!(),
},
);
// 4. Handler
handler(stream, proxy_info).await;
});
}
Err(e) => {
error!("accept error: {}", e);
}
}
}
// Now bind
let listener = UnixListener::bind(&path_buf).map_err(|e| {
error!("[{}] failed to bind {}: {}", service_name, path, e);
e
})?;
// Set permissions
if let Ok(metadata) = std::fs::metadata(&path_buf) {
let mut permissions = metadata.permissions();
// Verify if we need to change it
if permissions.mode() & 0o777 != mode & 0o777 {
permissions.set_mode(mode);
if let Err(e) = std::fs::set_permissions(&path_buf, permissions) {
// Non-fatal error, log only
error!(
"[{}] failed to set permissions on {}: {}",
service_name, path, e
);
}
}
}
Ok((listener, UnixSocketGuard { path: path_buf }))
}

View File

@@ -1,18 +1,19 @@
use crate::config::{Config, ServiceConfig};
use crate::config::Config;
use crate::core::upstream::UpstreamStream;
use crate::db::clickhouse::ClickHouseLogger;
use std::sync::Arc;
use tokio::net::{TcpListener, UnixListener};
use pingora::apps::ServerApp;
use std::os::unix::fs::PermissionsExt;
use std::sync::{Arc, Barrier};
use tokio::signal;
use tracing::{error, info};
mod handler;
mod listener;
mod pingora_compat;
mod stream;
use self::handler::handle_connection;
use self::listener::bind_robust;
use self::stream::InboundStream;
use self::listener::{bind_listener, serve_listener_loop, UnifiedListener};
pub async fn run(config: Config) -> anyhow::Result<()> {
let db_logger = ClickHouseLogger::new(&config.database).map_err(|e| {
@@ -33,71 +34,200 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
}
let mut join_set = tokio::task::JoinSet::new();
let mut socket_guards = Vec::new();
// Pingora server initialization (TLS only)
// Pingora server initialization (TLS only or Standard HTTP)
let mut pingora_services = Vec::new();
for service in config.services {
let db = db.clone();
for bind in &service.binds {
let service_config = service.clone();
let bind_addr = bind.addr.clone();
let proxy_proto_config = bind.proxy.clone();
let mode = bind.mode;
let real_ip_config = bind.real_ip.clone();
// Check if this bind is TLS/Pingora managed
if let Some(tls_config) = &bind.tls {
// This is a Pingora service
pingora_services.push((service_config, bind.clone(), tls_config.clone()));
// Use custom loop for TCP services or HTTP services requiring PROXY protocol parsing (not fully supported by pingora standard loop).
let is_tcp_service = service.service_type == "tcp";
let is_http_proxy =
service.service_type == "http" && bind.proxy.is_some() && bind.tls.is_none();
let use_custom_loop = is_tcp_service || is_http_proxy;
if !use_custom_loop {
// Use Standard Pingora Service (For TLS, or Pure HTTP, or Unix HTTP without PROXY)
pingora_services.push((
service_config,
bind.clone(),
bind.tls.clone(),
real_ip_config,
));
continue;
}
// Legacy TCP/Unix Logic
if bind_addr.starts_with("unix://") {
let path = bind_addr.trim_start_matches("unix://");
// --- Custom Loop Logic ---
// Bind robustly
let (listener, guard) = bind_robust(path, mode, &service_config.name).await?;
let listener_res = bind_listener(&bind_addr, mode, &service_config.name).await;
// Push guard to keep it alive
socket_guards.push(guard);
let listener = match listener_res {
Ok(l) => l,
Err(e) => {
return Err(e);
}
};
let listen_type = match &listener {
UnifiedListener::Unix(_, _) => "unix",
UnifiedListener::Tcp(_) => "tcp",
};
if is_http_proxy {
info!(
"[{}] listening on unix {} (mode {:o})",
service_config.name, path, mode
"[{}] listening on http {} {} (PROXY support)",
service_config.name, listen_type, bind_addr
);
} else {
info!(
"[{}] listening on {} {}",
service_config.name, listen_type, bind_addr
);
}
join_set.spawn(start_unix_service(
service_config,
let shutdown_dummy =
pingora::server::ShutdownWatch::from(tokio::sync::watch::channel(false).1);
if is_tcp_service {
// --- TCP Handler (with startup check) ---
if let Err(e) = UpstreamStream::connect(&service_config.forward_to).await {
tracing::warn!(
"[{}] -> '{}': startup check failed: {}",
service_config.name,
service_config.forward_to,
e
);
}
let db = db.clone();
let _proxy_cfg = proxy_proto_config.clone();
let listen_addr_log = bind_addr.clone();
let svc_cfg = service_config.clone();
join_set.spawn(serve_listener_loop(
listener,
service_config,
real_ip_config,
proxy_proto_config,
db.clone(),
bind.addr.clone(),
shutdown_dummy,
move |stream, info| {
let db = db.clone();
let svc = svc_cfg.clone();
let addr = listen_addr_log.clone();
async move {
if let Err(e) = handle_connection(stream, info, svc, db, addr).await {
match e.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
tracing::debug!("connection closed: {}", e);
}
_ => error!("connection error: {}", e),
}
}
}
},
));
} else {
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
error!(
"[{}] failed to bind {}: {}",
service_config.name, bind_addr, e
);
e
})?;
// --- HTTP Proxy Handler ---
use crate::core::pingora_proxy::TrauditProxy;
use pingora::proxy::http_proxy_service;
info!("[{}] listening on tcp {}", service_config.name, bind_addr);
let conf = Arc::new(pingora::server::configuration::ServerConf::default());
let inner_proxy = TrauditProxy {
db: db.clone(),
service_config: service_config.clone(),
listen_addr: bind_addr.clone(),
real_ip: real_ip_config.clone(),
};
let mut service_obj = http_proxy_service(&conf, inner_proxy);
let app = unsafe {
let app_ptr = service_obj.app_logic_mut().expect("app logic missing");
std::ptr::read(app_ptr)
};
std::mem::forget(service_obj);
let app = Arc::new(app);
join_set.spawn(start_tcp_service(
service_config,
join_set.spawn(serve_listener_loop(
listener,
service_config,
real_ip_config,
proxy_proto_config,
db.clone(),
bind.addr.clone(),
shutdown_dummy.clone(),
move |stream, _info| {
let app = app.clone();
let shutdown = shutdown_dummy.clone();
async move {
let stream: pingora::protocols::Stream = Box::new(stream);
app.process_new(stream, &shutdown).await;
}
},
));
}
}
}
// Run Pingora in a separate thread if needed
if !pingora_services.is_empty() {
let barrier = Arc::new(Barrier::new(2));
let barrier_clone = barrier.clone();
std::thread::spawn(move || {
use crate::core::pingora_proxy::TrauditProxy;
use pingora::proxy::http_proxy_service;
use pingora::server::configuration::Opt;
use pingora::server::Server;
if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut server = Server::new(Some(Opt::default())).unwrap();
server.bootstrap();
for (svc_config, bind, tls, real_ip) in pingora_services {
let proxy = TrauditProxy {
db: db.clone(),
service_config: svc_config.clone(),
listen_addr: bind.addr.clone(),
real_ip,
};
let mut service = http_proxy_service(&server.configuration, proxy);
if let Some(tls_config) = tls {
let key_path = tls_config.key.as_deref().unwrap_or(&tls_config.cert);
service
.add_tls(&bind.addr, &tls_config.cert, key_path)
.unwrap();
info!("[{}] listening on https {}", svc_config.name, bind.addr);
} else if bind.addr.starts_with("unix://") {
let path = bind.addr.trim_start_matches("unix://");
service.add_uds(path, Some(std::fs::Permissions::from_mode(bind.mode)));
info!("[{}] listening on http unix {}", svc_config.name, path);
} else {
service.add_tcp(&bind.addr);
info!("[{}] listening on http {}", svc_config.name, bind.addr);
}
server.add_service(service);
}
barrier_clone.wait();
server.run_forever();
})) {
error!("pingora server panicked: {:?}", e);
}
error!("pingora server exited unexpectedly!");
});
barrier.wait();
}
info!("traudit started...");
// notify systemd if configured
@@ -109,52 +239,6 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
}
}
// Run Pingora in a separate thread if needed
if !pingora_services.is_empty() {
info!(
"initializing pingora for {} tls services",
pingora_services.len()
);
// Spawn Pingora
std::thread::spawn(move || {
use crate::core::pingora_proxy::TrauditProxy;
use pingora::proxy::http_proxy_service;
use pingora::server::configuration::Opt;
use pingora::server::Server;
if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut server = Server::new(Some(Opt::default())).unwrap();
server.bootstrap();
for (svc_config, bind, tls) in pingora_services {
let proxy = TrauditProxy {
db: db.clone(),
service_config: svc_config.clone(),
};
let mut service = http_proxy_service(&server.configuration, proxy);
// Key path fallback
let key_path = tls.key.as_deref().unwrap_or(&tls.cert);
service.add_tls(&bind.addr, &tls.cert, key_path).unwrap();
info!("[{}] listening on tcp {}", svc_config.name, bind.addr);
server.add_service(service);
}
info!("starting pingora server run_forever loop");
server.run_forever();
})) {
error!("pingora server panicked: {:?}", e);
}
error!("pingora server exited unexpectedly!");
error!("pingora server exited unexpectedly!");
});
}
// Always wait for signals
match signal::ctrl_c().await {
Ok(()) => {
info!("shutdown signal received.");
@@ -165,139 +249,5 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
}
join_set.shutdown().await;
// socket_guards dropped here, cleaning up files
Ok(())
}
async fn start_tcp_service(
service: ServiceConfig,
listener: TcpListener,
proxy_cfg: Option<String>,
db: Arc<ClickHouseLogger>,
listen_addr: String,
) {
// Startup liveness check
if let Err(e) = UpstreamStream::connect(&service.forward_to).await {
match e.kind() {
std::io::ErrorKind::ConnectionRefused => {
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
}
std::io::ErrorKind::NotFound => {
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
}
_ => {
// Log other startup errors as warnings
tracing::warn!(
"[{}] -> '{}': startup check failed: {}",
service.name,
service.forward_to,
e
);
}
}
}
loop {
match listener.accept().await {
Ok((inbound, _client_addr)) => {
let service = service.clone();
let db = db.clone();
let proxy_cfg = proxy_cfg.clone();
let listen_addr = listen_addr.clone();
tokio::spawn(async move {
let svc_name = service.name.clone();
let svc_target = service.forward_to.clone();
let inbound = InboundStream::Tcp(inbound);
if let Err(e) = handle_connection(inbound, service, proxy_cfg, db, listen_addr).await {
match e.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
// normal disconnects, debug log only
tracing::debug!("connection closed: {}", e);
}
std::io::ErrorKind::ConnectionRefused => {
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
}
std::io::ErrorKind::NotFound => {
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
}
_ => {
error!("connection error: {}", e);
}
}
}
});
}
Err(e) => {
error!("accept error: {}", e);
}
}
}
}
async fn start_unix_service(
service: ServiceConfig,
listener: UnixListener,
proxy_cfg: Option<String>,
db: Arc<ClickHouseLogger>,
listen_addr: String,
) {
// Startup liveness check (same as TCP)
if let Err(e) = UpstreamStream::connect(&service.forward_to).await {
match e.kind() {
std::io::ErrorKind::ConnectionRefused => {
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
}
std::io::ErrorKind::NotFound => {
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
}
_ => {
tracing::warn!(
"[{}] -> '{}': startup check failed: {}",
service.name,
service.forward_to,
e
);
}
}
}
loop {
match listener.accept().await {
Ok((inbound, _addr)) => {
let service = service.clone();
let db = db.clone();
let proxy_cfg = proxy_cfg.clone();
let listen_addr = listen_addr.clone();
tokio::spawn(async move {
let svc_name = service.name.clone();
let svc_target = service.forward_to.clone();
let inbound = InboundStream::Unix(inbound);
if let Err(e) = handle_connection(inbound, service, proxy_cfg, db, listen_addr).await {
match e.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
tracing::debug!("connection closed: {}", e);
}
std::io::ErrorKind::ConnectionRefused => {
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
}
std::io::ErrorKind::NotFound => {
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
}
_ => {
error!("connection error: {}", e);
}
}
}
});
}
Err(e) => {
error!("accept error: {}", e);
}
}
}
}

View File

@@ -0,0 +1,129 @@
use async_trait::async_trait;
use bytes::{Buf, BytesMut};
use pingora::protocols::l4::socket::SocketAddr;
use pingora::protocols::{
GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
TimingDigest, UniqueID, UniqueIDType,
};
use std::fmt::Debug;
use std::io;
use std::net::SocketAddr as InetSocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use crate::core::server::stream::InboundStream;
#[derive(Debug)]
pub struct PingoraStream {
inner: InboundStream,
buffer: BytesMut,
digest: Arc<SocketDigest>,
}
impl PingoraStream {
pub fn new(
inner: InboundStream,
buffer: BytesMut,
peer_addr: InetSocketAddr,
local_addr: InetSocketAddr,
) -> Self {
#[cfg(unix)]
let digest = {
use std::os::fd::AsRawFd;
let fd = match &inner {
InboundStream::Tcp(s) => s.as_raw_fd(),
InboundStream::Unix(s) => s.as_raw_fd(),
};
let digest = SocketDigest::from_raw_fd(fd);
let _ = digest.peer_addr.set(Some(SocketAddr::Inet(peer_addr)));
let _ = digest.local_addr.set(Some(SocketAddr::Inet(local_addr)));
Arc::new(digest)
};
// Windows support or non-unix fallback
#[cfg(not(unix))]
let digest = Arc::new(SocketDigest::default());
Self {
inner,
buffer,
digest,
}
}
pub fn into_inner(self) -> (InboundStream, BytesMut) {
(self.inner, self.buffer)
}
}
impl AsyncRead for PingoraStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if self.buffer.has_remaining() {
let len = std::cmp::min(self.buffer.len(), buf.remaining());
buf.put_slice(&self.buffer[..len]);
self.buffer.advance(len);
Poll::Ready(Ok(()))
} else {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
}
impl AsyncWrite for PingoraStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}
#[async_trait]
impl Shutdown for PingoraStream {
async fn shutdown(&mut self) -> () {
let _ = <InboundStream as AsyncWriteExt>::shutdown(&mut self.inner).await;
}
}
impl UniqueID for PingoraStream {
fn id(&self) -> UniqueIDType {
0
}
}
impl Ssl for PingoraStream {}
impl GetTimingDigest for PingoraStream {
fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
vec![]
}
}
impl GetProxyDigest for PingoraStream {
fn get_proxy_digest(&self) -> Option<Arc<pingora::protocols::raw_connect::ProxyDigest>> {
None
}
}
impl GetSocketDigest for PingoraStream {
fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
Some(self.digest.clone())
}
}
impl Peek for PingoraStream {}

View File

@@ -3,20 +3,12 @@ use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{TcpStream, UnixStream};
#[derive(Debug)]
pub enum InboundStream {
Tcp(TcpStream),
Unix(UnixStream),
}
impl InboundStream {
pub fn peer_addr_string(&self) -> std::io::Result<String> {
match self {
InboundStream::Tcp(s) => Ok(s.peer_addr()?.to_string()),
InboundStream::Unix(_) => Ok("unix_socket".to_string()),
}
}
}
impl AsyncRead for InboundStream {
fn poll_read(
self: Pin<&mut Self>,