mirror of
https://github.com/awfufu/traudit
synced 2026-03-01 05:29:44 +08:00
feat: support per-bind real_ip configuration and unify tcp/http listener logic
This commit is contained in:
11
Cargo.lock
generated
11
Cargo.lock
generated
@@ -1064,6 +1064,15 @@ dependencies = [
|
||||
"hashbrown 0.16.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipnet"
|
||||
version = "2.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.17"
|
||||
@@ -2476,6 +2485,8 @@ dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"clickhouse",
|
||||
"httparse",
|
||||
"ipnet",
|
||||
"libc",
|
||||
"pingora",
|
||||
"serde",
|
||||
|
||||
@@ -24,6 +24,8 @@ async-trait = "0.1"
|
||||
time = { version = "0.3.45", features = ["serde", "macros", "formatting", "parsing"] }
|
||||
serde_repr = "0.1.20"
|
||||
pingora = { version = "0.6", features = ["lb", "openssl"] }
|
||||
ipnet = { version = "2.11.0", features = ["serde"] }
|
||||
httparse = "1.10.1"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
|
||||
@@ -2,37 +2,43 @@
|
||||
|
||||
database:
|
||||
type: clickhouse
|
||||
dsn: "http://user:password@ip:port/traudit"
|
||||
# dsn: "http://user:password@ip:port/database"
|
||||
dsn: http://traudit:traudit114514@127.0.0.1:8123/traudit
|
||||
batch_size: 50
|
||||
batch_timeout_secs: 5
|
||||
batch_timeout_secs: 100
|
||||
|
||||
services:
|
||||
# receives traffic from frp with v2 proxy protocol header, audits it,
|
||||
# strips the header, and forwards pure tcp to local sshd.
|
||||
# Receives traffic from FRP with v2 Proxy Protocol header, audits it,
|
||||
# strips the header, and forwards pure TCP to local SSHD.
|
||||
- name: "ssh"
|
||||
forward_to: "127.0.0.1:22"
|
||||
type: "tcp"
|
||||
real_ip:
|
||||
from: "proxy_protocol"
|
||||
trust_private_ranges: true
|
||||
binds:
|
||||
# Entry 1: Public traffic from FRP
|
||||
- addr: "unix://test.sock"
|
||||
- addr: "0.0.0.0:2223"
|
||||
proxy: "v2"
|
||||
|
||||
|
||||
# Entry 2: LAN direct traffic (no Proxy Protocol)
|
||||
- addr: "0.0.0.0:2222"
|
||||
# real_ip:
|
||||
# strategy: ["proxy_protocol", "remote_addr"] # default
|
||||
|
||||
forward_to: "127.0.0.1:22"
|
||||
|
||||
- name: "https-web"
|
||||
type: "tcp"
|
||||
- name: "web"
|
||||
forward_to: "127.0.0.1:8080"
|
||||
type: "http"
|
||||
binds:
|
||||
- addr: "0.0.0.0:443"
|
||||
- addr: 0.0.0.0:443
|
||||
tls:
|
||||
cert: "/etc/ssl/certs/site.pem"
|
||||
key: "/etc/ssl/private/site.key"
|
||||
|
||||
- addr: "0.0.0.0:4433"
|
||||
tls:
|
||||
cert: "/etc/ssl/certs/site.pem"
|
||||
key: "/etc/ssl/private/site.key"
|
||||
proxy: "v2"
|
||||
cert: "/path/to/cert.crt"
|
||||
key: "/path/to/key.key"
|
||||
proxy: v2
|
||||
real_ip:
|
||||
from: "proxy_protocol"
|
||||
trust_private_ranges: true
|
||||
trusted_proxies:
|
||||
- 1.2.3.4
|
||||
|
||||
forward_to: "127.0.0.1:8080"
|
||||
@@ -1,4 +1,5 @@
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use std::net::IpAddr;
|
||||
use std::path::Path;
|
||||
use tokio::fs;
|
||||
|
||||
@@ -41,6 +42,58 @@ pub struct ServiceConfig {
|
||||
pub forward_to: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct RealIpConfig {
|
||||
#[serde(default, rename = "from")]
|
||||
pub source: RealIpSource,
|
||||
#[serde(default, deserialize_with = "deserialize_trusted_proxies")]
|
||||
pub trusted_proxies: Vec<ipnet::IpNet>,
|
||||
#[serde(default)]
|
||||
pub trust_private_ranges: bool,
|
||||
#[serde(default)]
|
||||
pub xff_trust_depth: usize,
|
||||
}
|
||||
|
||||
impl RealIpConfig {
|
||||
pub fn is_trusted(&self, ip: IpAddr) -> bool {
|
||||
// Check explicit trusted proxies
|
||||
for net in &self.trusted_proxies {
|
||||
if net.contains(&ip) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if self.trust_private_ranges && is_private(&ip) {
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn is_private(ip: &IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(addr) => addr.is_loopback() || addr.is_link_local() || addr.is_private(),
|
||||
IpAddr::V6(addr) => {
|
||||
addr.is_loopback() ||
|
||||
// addr.is_unique_local() is unstable, check ranges manually
|
||||
// fc00::/7
|
||||
(addr.segments()[0] & 0xfe00) == 0xfc00 ||
|
||||
// fe80::/10
|
||||
(addr.segments()[0] & 0xffc0) == 0xfe80
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone, PartialEq, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RealIpSource {
|
||||
ProxyProtocol,
|
||||
Xff,
|
||||
#[default]
|
||||
RemoteAddr,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct BindEntry {
|
||||
pub addr: String,
|
||||
@@ -49,6 +102,7 @@ pub struct BindEntry {
|
||||
#[serde(default = "default_socket_mode", deserialize_with = "deserialize_mode")]
|
||||
pub mode: u32,
|
||||
pub tls: Option<TlsConfig>,
|
||||
pub real_ip: Option<RealIpConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
@@ -75,11 +129,7 @@ where
|
||||
let value = ModeValue::deserialize(deserializer)?;
|
||||
match value {
|
||||
ModeValue::Integer(i) => {
|
||||
// If user provides 666, they likely mean octal 0666.
|
||||
// But in YAML `mode: 666` is decimal 666.
|
||||
// The requirement says: "if user wrote integer (e.g. 666), process as octal"
|
||||
// So we interpret the decimal value as a sequence of octal digits.
|
||||
// e.g. decimal 666 -> octal 666 (which is decimal 438)
|
||||
// Interpret decimal integer as octal (e.g., 666 -> 0666) per requirements.
|
||||
let s = i.to_string();
|
||||
u32::from_str_radix(&s, 8).map_err(serde::de::Error::custom)
|
||||
}
|
||||
@@ -90,6 +140,27 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_trusted_proxies<'de, D>(deserializer: D) -> Result<Vec<ipnet::IpNet>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let strings: Vec<String> = Vec::deserialize(deserializer)?;
|
||||
let mut nets = Vec::with_capacity(strings.len());
|
||||
for s in strings {
|
||||
if let Ok(net) = s.parse::<ipnet::IpNet>() {
|
||||
nets.push(net);
|
||||
} else if let Ok(ip) = s.parse::<std::net::IpAddr>() {
|
||||
nets.push(ipnet::IpNet::from(ip));
|
||||
} else {
|
||||
return Err(serde::de::Error::custom(format!(
|
||||
"invalid IP address or CIDR: {}",
|
||||
s
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(nets)
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub async fn load<P: AsRef<Path>>(path: P) -> Result<Self, anyhow::Error> {
|
||||
let content = fs::read_to_string(path).await?;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::config::ServiceConfig;
|
||||
use crate::config::{RealIpSource, ServiceConfig};
|
||||
use crate::db::clickhouse::{ClickHouseLogger, HttpLog, HttpMethod};
|
||||
use async_trait::async_trait;
|
||||
use pingora::prelude::*;
|
||||
@@ -9,6 +9,8 @@ use std::time::Instant;
|
||||
pub struct TrauditProxy {
|
||||
pub db: Arc<ClickHouseLogger>,
|
||||
pub service_config: ServiceConfig,
|
||||
pub listen_addr: String,
|
||||
pub real_ip: Option<crate::config::RealIpConfig>,
|
||||
}
|
||||
|
||||
pub struct HttpContext {
|
||||
@@ -41,28 +43,129 @@ impl ProxyHttp for TrauditProxy {
|
||||
}
|
||||
|
||||
async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
|
||||
// IP Priority: Proxy Protocol > XFF > X-Real-IP > Peer
|
||||
let mut client_ip: Option<IpAddr> = session
|
||||
ctx.start_ts = Some(Instant::now());
|
||||
|
||||
// 1. Determine Source IP
|
||||
let peer_addr = session
|
||||
.client_addr()
|
||||
.and_then(|a| a.as_inet())
|
||||
.map(|a| a.ip());
|
||||
.map(|a| a.ip())
|
||||
.unwrap_or(IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)));
|
||||
|
||||
// Check headers for overrides
|
||||
if let Some(xff) = session.req_header().headers.get("x-forwarded-for") {
|
||||
if let Ok(xff_str) = xff.to_str() {
|
||||
if let Some(first_ip) = xff_str.split(',').next() {
|
||||
if let Ok(parsed_ip) = first_ip.trim().parse::<IpAddr>() {
|
||||
client_ip = Some(parsed_ip); // Overwrite
|
||||
let mut resolved_ip = peer_addr;
|
||||
|
||||
if let Some(cfg) = &self.real_ip {
|
||||
match cfg.source {
|
||||
RealIpSource::ProxyProtocol => {
|
||||
// If custom listener was used, peer_addr is already the injected Real IP.
|
||||
resolved_ip = peer_addr;
|
||||
}
|
||||
RealIpSource::RemoteAddr => {
|
||||
resolved_ip = peer_addr;
|
||||
}
|
||||
RealIpSource::Xff => {
|
||||
// Check trust on current peer/proxy IP
|
||||
if cfg.is_trusted(peer_addr) {
|
||||
if let Some(xff) = session.req_header().headers.get("x-forwarded-for") {
|
||||
if let Ok(xff_str) = xff.to_str() {
|
||||
let ips: Vec<&str> = xff_str.split(',').map(|s| s.trim()).collect();
|
||||
|
||||
if !ips.is_empty() {
|
||||
// Recursive trust (0) vs Fixed Depth
|
||||
if cfg.xff_trust_depth == 0 {
|
||||
// Recursive: walk backwards until first untrusted
|
||||
let mut candidate = None;
|
||||
for ip_str in ips.iter().rev() {
|
||||
if let Ok(ip) = ip_str.parse() {
|
||||
if cfg.is_trusted(ip) {
|
||||
continue;
|
||||
} else {
|
||||
candidate = Some(ip);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// If all trusted, take the first one (leftmost)
|
||||
if let Some(ip) = candidate {
|
||||
resolved_ip = ip;
|
||||
} else if let Some(first_str) = ips.first() {
|
||||
if let Ok(ip) = first_str.parse() {
|
||||
resolved_ip = ip;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fixed depth
|
||||
let idx = if ips.len() >= cfg.xff_trust_depth {
|
||||
ips.len() - cfg.xff_trust_depth
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
if let Some(val) = ips.get(idx) {
|
||||
if let Ok(ip) = val.parse() {
|
||||
resolved_ip = ip;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ip) = client_ip {
|
||||
ctx.src_ip = ip;
|
||||
ctx.src_ip = resolved_ip;
|
||||
|
||||
// Log connection info
|
||||
let src_fmt = resolved_ip.to_string();
|
||||
let physical_fmt = peer_addr.to_string();
|
||||
|
||||
if src_fmt == physical_fmt {
|
||||
// If we stuck to physical, check if there was an XFF we ignored
|
||||
let xff_msg = if let Some(xff) = session.req_header().headers.get("x-forwarded-for") {
|
||||
if let Ok(v) = xff.to_str() {
|
||||
// Only show if we actually have RealIpConfig that denied us
|
||||
if let Some(cfg) = &self.real_ip {
|
||||
if !cfg.is_trusted(peer_addr) {
|
||||
format!("(untrusted) xff: {}", v)
|
||||
} else {
|
||||
"".to_string()
|
||||
}
|
||||
} else {
|
||||
"".to_string()
|
||||
}
|
||||
} else {
|
||||
"".to_string()
|
||||
}
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
if !xff_msg.is_empty() {
|
||||
tracing::info!(
|
||||
"[{}] {} <- {} {}",
|
||||
self.service_config.name,
|
||||
self.listen_addr,
|
||||
src_fmt,
|
||||
xff_msg
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
"[{}] {} <- {}",
|
||||
self.service_config.name,
|
||||
self.listen_addr,
|
||||
src_fmt
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// fallback to 0.0.0.0 if entirely missing (unlikely)
|
||||
ctx.src_ip = IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0));
|
||||
tracing::info!(
|
||||
"[{}] {} <- {} ({})",
|
||||
self.service_config.name,
|
||||
self.listen_addr,
|
||||
src_fmt,
|
||||
physical_fmt
|
||||
);
|
||||
}
|
||||
|
||||
// 2. Audit Info
|
||||
@@ -137,8 +240,6 @@ impl ProxyHttp for TrauditProxy {
|
||||
ctx.status_code = header.status.as_u16();
|
||||
}
|
||||
|
||||
// Bytes (resp_body_size accumulated in filter)
|
||||
|
||||
ctx.req_body_size = session.body_bytes_read() as u64;
|
||||
|
||||
let addr_family = if ctx.src_ip.is_ipv4() {
|
||||
@@ -166,7 +267,6 @@ impl ProxyHttp for TrauditProxy {
|
||||
let db = self.db.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = db.insert_http_log(log).await {
|
||||
// log error
|
||||
tracing::error!("failed to insert http log: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,185 +1,232 @@
|
||||
use super::stream::InboundStream;
|
||||
use crate::config::ServiceConfig;
|
||||
use crate::config::{RealIpConfig, RealIpSource, ServiceConfig};
|
||||
use crate::core::forwarder;
|
||||
use crate::core::server::pingora_compat::PingoraStream;
|
||||
use crate::core::upstream::UpstreamStream;
|
||||
use crate::db::clickhouse::ClickHouseLogger;
|
||||
use crate::protocol;
|
||||
use crate::db::clickhouse::{ClickHouseLogger, ProxyProto};
|
||||
use crate::protocol::{self, ProxyInfo};
|
||||
use bytes::BytesMut;
|
||||
use pingora::protocols::GetSocketDigest;
|
||||
use std::io;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{error, info};
|
||||
|
||||
pub async fn handle_connection(
|
||||
mut inbound: InboundStream,
|
||||
stream: PingoraStream,
|
||||
proxy_info: Option<ProxyInfo>,
|
||||
service: ServiceConfig,
|
||||
proxy_cfg: Option<String>,
|
||||
db: Arc<ClickHouseLogger>,
|
||||
listen_addr: String,
|
||||
) -> std::io::Result<u64> {
|
||||
let conn_ts = time::OffsetDateTime::now_utc();
|
||||
let start_instant = std::time::Instant::now();
|
||||
|
||||
// Use this flag or inbound type to determine if it's a Unix socket
|
||||
let is_unix = matches!(inbound, InboundStream::Unix(_));
|
||||
|
||||
let (mut final_ip, mut final_port) = match &inbound {
|
||||
InboundStream::Tcp(s) => {
|
||||
let addr = s.peer_addr()?;
|
||||
// Extract resolved IP from digest (injected by listener)
|
||||
let digest = stream.get_socket_digest();
|
||||
let (final_ip, final_port) = if let Some(d) = digest {
|
||||
if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.peer_addr() {
|
||||
(addr.ip(), addr.port())
|
||||
}
|
||||
InboundStream::Unix(_) => (
|
||||
std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
|
||||
0,
|
||||
),
|
||||
};
|
||||
let mut proto_enum = crate::db::clickhouse::ProxyProto::None;
|
||||
let mut skip_log = false;
|
||||
|
||||
let result = async {
|
||||
// read proxy protocol (if configured)
|
||||
let mut buffer = bytes::BytesMut::new();
|
||||
|
||||
if proxy_cfg.is_some() {
|
||||
// If configured, we attempt to read.
|
||||
match protocol::read_proxy_header(&mut inbound).await {
|
||||
Ok((proxy_info, buf)) => {
|
||||
buffer = buf;
|
||||
if let Some(info) = proxy_info {
|
||||
let physical = inbound.peer_addr_string()?;
|
||||
|
||||
// Format: [ssh] unix://test.sock <- RealIP:Port (local) or [ssh] 0.0.0.0:2222 <- RealIP:Port (1.2.3.4:5678)
|
||||
let physical_fmt = if matches!(inbound, InboundStream::Unix(_)) {
|
||||
"local".to_string()
|
||||
} else {
|
||||
physical
|
||||
};
|
||||
|
||||
info!(
|
||||
"[{}] {} <- {} ({})",
|
||||
service.name, listen_addr, info.source, physical_fmt
|
||||
);
|
||||
final_ip = info.source.ip();
|
||||
final_port = info.source.port();
|
||||
|
||||
// Proxy info implies "proxied TCP" usually; rely on final_ip family later
|
||||
|
||||
proto_enum = match info.version {
|
||||
protocol::Version::V1 => crate::db::clickhouse::ProxyProto::V1,
|
||||
protocol::Version::V2 => crate::db::clickhouse::ProxyProto::V2,
|
||||
};
|
||||
|
||||
// Verify version matches config if required
|
||||
if let Some(ref required_ver) = proxy_cfg {
|
||||
match required_ver.as_str() {
|
||||
"v1" if info.version != protocol::Version::V1 => {
|
||||
// warn mismatch?
|
||||
}
|
||||
"v2" if info.version != protocol::Version::V2 => {
|
||||
// warn mismatch?
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Strict enforcement: config requires header
|
||||
let physical = inbound.peer_addr_string()?;
|
||||
let msg = format!("strict proxy protocol violation from {}", physical);
|
||||
error!("[{}] {}", service.name, msg);
|
||||
skip_log = true;
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, msg));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
skip_log = true;
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let addr = if matches!(inbound, InboundStream::Unix(_)) {
|
||||
// [ssh] unix://test.sock <- local
|
||||
"local".to_string()
|
||||
} else {
|
||||
inbound.peer_addr_string()?
|
||||
};
|
||||
info!("[{}] {} <- {}", service.name, listen_addr, addr);
|
||||
// Should not match other types if logic is correct
|
||||
(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0)
|
||||
}
|
||||
|
||||
// connect upstream
|
||||
let mut upstream = UpstreamStream::connect(&service.forward_to).await?;
|
||||
|
||||
// write buffered data (peeked bytes)
|
||||
if !buffer.is_empty() {
|
||||
upstream.write_all_buf(&mut buffer).await?;
|
||||
}
|
||||
|
||||
// zero-copy forwarding
|
||||
let inbound_async = match inbound {
|
||||
InboundStream::Tcp(s) => crate::core::upstream::AsyncStream::from_tokio_tcp(s)?,
|
||||
InboundStream::Unix(s) => crate::core::upstream::AsyncStream::from_tokio_unix(s)?,
|
||||
};
|
||||
let upstream_async = upstream.into_async_stream()?;
|
||||
|
||||
let (spliced_bytes, splice_res) =
|
||||
forwarder::zero_copy_bidirectional(inbound_async, upstream_async).await;
|
||||
|
||||
if let Err(e) = splice_res {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
|
||||
tracing::debug!("[{}] connection closed with error: {}", service.name, e);
|
||||
}
|
||||
_ => {
|
||||
error!("[{}] connection error: {}", service.name, e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Clean close logging removed
|
||||
}
|
||||
|
||||
// Total bytes = initial peeked + filtered
|
||||
Ok(spliced_bytes + buffer.len() as u64)
|
||||
}
|
||||
.await;
|
||||
|
||||
let duration = if result.is_ok() {
|
||||
start_instant.elapsed().as_millis() as u32
|
||||
} else {
|
||||
0
|
||||
(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0)
|
||||
};
|
||||
|
||||
let bytes_transferred = *result.as_ref().unwrap_or(&0);
|
||||
// Unwrap stream
|
||||
let (inbound, mut read_buffer) = stream.into_inner();
|
||||
|
||||
// Finalize AddrFamily based on final_ip; Unix logic handled below
|
||||
|
||||
let mut addr_family = match final_ip {
|
||||
std::net::IpAddr::V4(_) => crate::db::clickhouse::AddrFamily::Ipv4,
|
||||
std::net::IpAddr::V6(_) => crate::db::clickhouse::AddrFamily::Ipv6,
|
||||
let is_unix = matches!(inbound, InboundStream::Unix(_));
|
||||
let remote_addr = match &inbound {
|
||||
InboundStream::Tcp(s) => s.peer_addr()?,
|
||||
InboundStream::Unix(_) => SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), 0),
|
||||
};
|
||||
|
||||
if is_unix && proto_enum == crate::db::clickhouse::ProxyProto::None {
|
||||
// Unix socket, direct connection (or no proxy header received)
|
||||
addr_family = crate::db::clickhouse::AddrFamily::Unix;
|
||||
// Store 0 (::)
|
||||
final_ip = std::net::IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED);
|
||||
final_port = 0;
|
||||
// skip redundant proxy/IP resolution (done by listener); determine ProxyProto for logging
|
||||
let proto_enum = if let Some(ref info) = proxy_info {
|
||||
match info.version {
|
||||
protocol::Version::V1 => ProxyProto::V1,
|
||||
protocol::Version::V2 => ProxyProto::V2,
|
||||
}
|
||||
} else {
|
||||
ProxyProto::None
|
||||
};
|
||||
|
||||
// Log connection info
|
||||
let src_fmt = if is_unix && proto_enum == ProxyProto::None {
|
||||
"local".to_string()
|
||||
} else {
|
||||
final_ip.to_string()
|
||||
};
|
||||
let physical_fmt = if is_unix {
|
||||
"local".to_string()
|
||||
} else {
|
||||
remote_addr.to_string()
|
||||
};
|
||||
|
||||
if src_fmt == physical_fmt {
|
||||
info!("[{}] {} <- {}", service.name, listen_addr, src_fmt);
|
||||
} else {
|
||||
info!(
|
||||
"[{}] {} <- {} ({})",
|
||||
service.name, listen_addr, src_fmt, physical_fmt
|
||||
);
|
||||
}
|
||||
|
||||
// 3. Connect Upstream
|
||||
let mut upstream = UpstreamStream::connect(&service.forward_to).await?;
|
||||
|
||||
// 4. Write buffered data
|
||||
if !read_buffer.is_empty() {
|
||||
upstream.write_all_buf(&mut read_buffer).await?;
|
||||
}
|
||||
|
||||
// 5. Zero-copy forwarding
|
||||
let inbound_async = match inbound {
|
||||
InboundStream::Tcp(s) => crate::core::upstream::AsyncStream::from_tokio_tcp(s)?,
|
||||
InboundStream::Unix(s) => crate::core::upstream::AsyncStream::from_tokio_unix(s)?,
|
||||
};
|
||||
let upstream_async = upstream.into_async_stream()?;
|
||||
|
||||
let (spliced_bytes, splice_res) =
|
||||
forwarder::zero_copy_bidirectional(inbound_async, upstream_async).await;
|
||||
|
||||
if let Err(e) = splice_res {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
|
||||
tracing::debug!("[{}] connection closed: {}", service.name, e);
|
||||
}
|
||||
_ => error!("[{}] connection error: {}", service.name, e),
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate total bytes
|
||||
let total_bytes = spliced_bytes + read_buffer.len() as u64;
|
||||
|
||||
// Logging logic
|
||||
let duration = start_instant.elapsed().as_millis() as u32;
|
||||
|
||||
// Handle Unix socket specifics for logging
|
||||
let (log_ip, log_port, log_family) = if is_unix && proto_enum == ProxyProto::None {
|
||||
(
|
||||
IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED),
|
||||
0,
|
||||
crate::db::clickhouse::AddrFamily::Unix,
|
||||
)
|
||||
} else {
|
||||
let family = match final_ip {
|
||||
IpAddr::V4(_) => crate::db::clickhouse::AddrFamily::Ipv4,
|
||||
IpAddr::V6(_) => crate::db::clickhouse::AddrFamily::Ipv6,
|
||||
};
|
||||
(final_ip, final_port, family)
|
||||
};
|
||||
|
||||
let log_entry = crate::db::clickhouse::TcpLog {
|
||||
service: service.name.clone(),
|
||||
conn_ts,
|
||||
duration: duration as u32,
|
||||
addr_family,
|
||||
ip: final_ip,
|
||||
port: final_port,
|
||||
duration,
|
||||
addr_family: log_family,
|
||||
ip: log_ip,
|
||||
port: log_port,
|
||||
proxy_proto: proto_enum,
|
||||
bytes: bytes_transferred,
|
||||
bytes: total_bytes,
|
||||
};
|
||||
|
||||
if !skip_log {
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = db.insert_log(log_entry).await {
|
||||
error!("failed to insert tcp log: {}", e);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = db.insert_log(log_entry).await {
|
||||
error!("failed to insert tcp log: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(total_bytes)
|
||||
}
|
||||
|
||||
pub(crate) async fn resolve_real_ip(
|
||||
config: &Option<RealIpConfig>,
|
||||
remote_addr: SocketAddr,
|
||||
proxy_info: &Option<ProxyInfo>,
|
||||
inbound: &mut InboundStream,
|
||||
buffer: &mut BytesMut,
|
||||
) -> io::Result<(IpAddr, u16)> {
|
||||
if let Some(cfg) = config {
|
||||
match cfg.source {
|
||||
RealIpSource::ProxyProtocol => {
|
||||
if let Some(info) = proxy_info {
|
||||
// Trust check: The PHYSICAL connection must be from a trusted source
|
||||
if cfg.is_trusted(remote_addr.ip()) {
|
||||
return Ok((info.source.ip(), info.source.port()));
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
RealIpSource::Xff => {
|
||||
let current_ip = if let Some(info) = proxy_info {
|
||||
info.source.ip()
|
||||
} else {
|
||||
remote_addr.ip()
|
||||
};
|
||||
|
||||
if cfg.is_trusted(current_ip) {
|
||||
if let Some(ip) = peek_xff_ip(inbound, buffer, cfg.xff_trust_depth).await? {
|
||||
// XFF doesn't have port, use remote/proxy port
|
||||
let port = if let Some(info) = proxy_info {
|
||||
info.source.port()
|
||||
} else {
|
||||
remote_addr.port()
|
||||
};
|
||||
return Ok((ip, port));
|
||||
}
|
||||
}
|
||||
}
|
||||
RealIpSource::RemoteAddr => {
|
||||
return Ok((remote_addr.ip(), remote_addr.port()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
// Fallback to Remote Address if no config or strategy failed.
|
||||
Ok((remote_addr.ip(), remote_addr.port()))
|
||||
}
|
||||
|
||||
pub(crate) async fn peek_xff_ip<T: AsyncRead + Unpin>(
|
||||
stream: &mut T,
|
||||
buffer: &mut BytesMut,
|
||||
_trust_depth: usize,
|
||||
) -> io::Result<Option<IpAddr>> {
|
||||
let max_header = 4096;
|
||||
loop {
|
||||
if let Some(pos) = buffer.windows(4).position(|w| w == b"\r\n\r\n") {
|
||||
let header_bytes = &buffer[..pos];
|
||||
let mut headers = [httparse::Header {
|
||||
name: "",
|
||||
value: &[],
|
||||
}; 32];
|
||||
let mut req = httparse::Request::new(&mut headers);
|
||||
if req.parse(header_bytes).is_ok() {
|
||||
for header in req.headers {
|
||||
if header.name.eq_ignore_ascii_case("x-forwarded-for") {
|
||||
if let Ok(val) = std::str::from_utf8(header.value) {
|
||||
let ips: Vec<&str> = val.split(',').map(|s| s.trim()).collect();
|
||||
if let Some(ip_str) = ips.last() {
|
||||
if let Ok(ip) = ip_str.parse() {
|
||||
return Ok(Some(ip));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if buffer.len() >= max_header {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if stream.read_buf(buffer).await? == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,83 +1,205 @@
|
||||
use crate::config::ServiceConfig;
|
||||
use crate::core::server::pingora_compat::PingoraStream;
|
||||
use crate::core::server::stream::InboundStream;
|
||||
use bytes::BytesMut;
|
||||
use pingora::protocols::l4::socket::SocketAddr;
|
||||
use pingora::server::ShutdownWatch;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::path::PathBuf;
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tracing::{error, info};
|
||||
use tokio::net::{TcpListener, UnixListener, UnixStream};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
pub struct UnixSocketGuard {
|
||||
pub path: PathBuf,
|
||||
pub enum UnifiedListener {
|
||||
Tcp(TcpListener),
|
||||
Unix(UnixListener, PathBuf), // PathBuf for cleanup on Drop
|
||||
}
|
||||
|
||||
impl Drop for UnixSocketGuard {
|
||||
impl Drop for UnifiedListener {
|
||||
fn drop(&mut self) {
|
||||
if let Err(_e) = std::fs::remove_file(&self.path) {
|
||||
// File potentially gone or no permissions, debug log only
|
||||
} else {
|
||||
tracing::debug!("removed socket file {:?}", self.path);
|
||||
if let UnifiedListener::Unix(_, ref path) = self {
|
||||
let _ = std::fs::remove_file(path);
|
||||
tracing::debug!("removed socket file {:?}", path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn bind_robust(
|
||||
path: &str,
|
||||
impl UnifiedListener {
|
||||
pub async fn accept(&self) -> std::io::Result<(InboundStream, std::net::SocketAddr)> {
|
||||
match self {
|
||||
UnifiedListener::Tcp(l) => {
|
||||
let (stream, addr) = l.accept().await?;
|
||||
Ok((InboundStream::Tcp(stream), addr))
|
||||
}
|
||||
UnifiedListener::Unix(l, _) => {
|
||||
let (stream, _addr) = l.accept().await?;
|
||||
// Mock IPv4 loopback for Unix sockets
|
||||
let addr = std::net::SocketAddr::new(
|
||||
std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
|
||||
0,
|
||||
);
|
||||
Ok((InboundStream::Unix(stream), addr))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn bind_listener(
|
||||
addr_str: &str,
|
||||
mode: u32,
|
||||
service_name: &str,
|
||||
) -> anyhow::Result<(UnixListener, UnixSocketGuard)> {
|
||||
let path_buf = std::path::Path::new(path).to_path_buf();
|
||||
) -> anyhow::Result<UnifiedListener> {
|
||||
if let Some(path) = addr_str.strip_prefix("unix://") {
|
||||
// Robust bind logic adapted from previous implementation
|
||||
let path_buf = std::path::Path::new(path).to_path_buf();
|
||||
|
||||
if path_buf.exists() {
|
||||
// Check permissions; we need write access to remove it
|
||||
match std::fs::symlink_metadata(&path_buf) {
|
||||
Ok(_meta) => {
|
||||
// We rely on subsequent operations (connect/remove) to fail with PermissionDenied if we lack access.
|
||||
}
|
||||
Err(e) => {
|
||||
if path_buf.exists() {
|
||||
// Check permissions
|
||||
if let Err(e) = std::fs::symlink_metadata(&path_buf) {
|
||||
if e.kind() == std::io::ErrorKind::PermissionDenied {
|
||||
anyhow::bail!("Permission denied accessing existing socket: {}", path);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if active
|
||||
match UnixStream::connect(&path_buf).await {
|
||||
Ok(_) => anyhow::bail!("Address already in use: {}", path),
|
||||
Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
|
||||
info!("[{}] removing stale socket file: {}", service_name, path);
|
||||
std::fs::remove_file(&path_buf)?;
|
||||
}
|
||||
Err(e) => anyhow::bail!("failed to check existing socket {}: {}", path, e),
|
||||
}
|
||||
}
|
||||
|
||||
// Try to connect to check if it's active
|
||||
match UnixStream::connect(&path_buf).await {
|
||||
Ok(_) => {
|
||||
// Active!
|
||||
anyhow::bail!("Address already in use: {}", path);
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
|
||||
// Stale socket, remove it
|
||||
info!("[{}] removing stale socket file: {}", service_name, path);
|
||||
if let Err(rm_err) = std::fs::remove_file(&path_buf) {
|
||||
anyhow::bail!("failed to remove stale socket {}: {}", path, rm_err);
|
||||
let listener = UnixListener::bind(&path_buf).map_err(|e| {
|
||||
error!("[{}] failed to bind {}: {}", service_name, path, e);
|
||||
e
|
||||
})?;
|
||||
|
||||
// Permissions
|
||||
if let Ok(metadata) = std::fs::metadata(&path_buf) {
|
||||
let mut perms = metadata.permissions();
|
||||
if perms.mode() & 0o777 != mode & 0o777 {
|
||||
perms.set_mode(mode);
|
||||
if let Err(e) = std::fs::set_permissions(&path_buf, perms) {
|
||||
error!(
|
||||
"[{}] failed to set permissions on {}: {}",
|
||||
service_name, path, e
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// Other error, bail
|
||||
anyhow::bail!("failed to check existing socket {}: {}", path, e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(UnifiedListener::Unix(listener, path_buf))
|
||||
} else {
|
||||
// TCP
|
||||
let listener = TcpListener::bind(addr_str).await.map_err(|e| {
|
||||
error!("[{}] failed to bind {}: {}", service_name, addr_str, e);
|
||||
e
|
||||
})?;
|
||||
Ok(UnifiedListener::Tcp(listener))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve_listener_loop<F, Fut>(
|
||||
listener: UnifiedListener,
|
||||
service: ServiceConfig,
|
||||
real_ip_config: Option<crate::config::RealIpConfig>,
|
||||
proxy_cfg: Option<String>,
|
||||
_shutdown: ShutdownWatch,
|
||||
handler: F,
|
||||
) where
|
||||
F: Fn(PingoraStream, Option<crate::protocol::ProxyInfo>) -> Fut + Send + Sync + 'static + Clone,
|
||||
Fut: std::future::Future<Output = ()> + Send,
|
||||
{
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((mut stream, client_addr)) => {
|
||||
let proxy_cfg = proxy_cfg.clone();
|
||||
let service = service.clone();
|
||||
let real_ip_config = real_ip_config.clone();
|
||||
let handler = handler.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut buffer = BytesMut::new();
|
||||
let mut proxy_info = None;
|
||||
|
||||
// 1. Read PROXY header
|
||||
if proxy_cfg.is_some() {
|
||||
match crate::protocol::read_proxy_header(&mut stream).await {
|
||||
Ok((info, buf)) => {
|
||||
buffer = buf;
|
||||
if let Some(info) = info {
|
||||
// Validate version
|
||||
let valid = match proxy_cfg.as_deref() {
|
||||
Some("v1") => info.version == crate::protocol::Version::V1,
|
||||
Some("v2") => info.version == crate::protocol::Version::V2,
|
||||
_ => true,
|
||||
};
|
||||
if !valid {
|
||||
warn!("[{}] proxy protocol version mismatch", service.name);
|
||||
}
|
||||
proxy_info = Some(info);
|
||||
} else {
|
||||
let msg = format!("strict proxy protocol violation from {}", client_addr);
|
||||
error!("[{}] {}", service.name, msg);
|
||||
return; // Close connection
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to read proxy header: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Resolve Real IP (consumes stream/buffer for XFF peeking if needed).
|
||||
|
||||
let (real_peer_ip, real_peer_port) = match crate::core::server::handler::resolve_real_ip(
|
||||
&real_ip_config,
|
||||
client_addr,
|
||||
&proxy_info,
|
||||
&mut stream,
|
||||
&mut buffer,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok((ip, port)) => (ip, port),
|
||||
Err(e) => {
|
||||
error!("[{}] real ip resolution failed: {}", service.name, e);
|
||||
// Fallback or abort?
|
||||
// Abort is safer if I/O broken.
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let local_addr = match &stream {
|
||||
InboundStream::Tcp(s) => s.local_addr().ok(),
|
||||
_ => None,
|
||||
}
|
||||
.unwrap_or_else(|| "0.0.0.0:0".parse().unwrap());
|
||||
|
||||
// 3. Construct PingoraStream
|
||||
let stream = PingoraStream::new(
|
||||
stream,
|
||||
buffer,
|
||||
match SocketAddr::from(std::net::SocketAddr::new(real_peer_ip, real_peer_port)) {
|
||||
SocketAddr::Inet(addr) => addr,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
match SocketAddr::from(local_addr) {
|
||||
SocketAddr::Inet(addr) => addr,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
);
|
||||
|
||||
// 4. Handler
|
||||
handler(stream, proxy_info).await;
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("accept error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now bind
|
||||
let listener = UnixListener::bind(&path_buf).map_err(|e| {
|
||||
error!("[{}] failed to bind {}: {}", service_name, path, e);
|
||||
e
|
||||
})?;
|
||||
|
||||
// Set permissions
|
||||
if let Ok(metadata) = std::fs::metadata(&path_buf) {
|
||||
let mut permissions = metadata.permissions();
|
||||
// Verify if we need to change it
|
||||
if permissions.mode() & 0o777 != mode & 0o777 {
|
||||
permissions.set_mode(mode);
|
||||
if let Err(e) = std::fs::set_permissions(&path_buf, permissions) {
|
||||
// Non-fatal error, log only
|
||||
error!(
|
||||
"[{}] failed to set permissions on {}: {}",
|
||||
service_name, path, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok((listener, UnixSocketGuard { path: path_buf }))
|
||||
}
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
use crate::config::{Config, ServiceConfig};
|
||||
use crate::config::Config;
|
||||
use crate::core::upstream::UpstreamStream;
|
||||
use crate::db::clickhouse::ClickHouseLogger;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::{TcpListener, UnixListener};
|
||||
use pingora::apps::ServerApp;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::sync::{Arc, Barrier};
|
||||
use tokio::signal;
|
||||
use tracing::{error, info};
|
||||
|
||||
mod handler;
|
||||
mod listener;
|
||||
mod pingora_compat;
|
||||
mod stream;
|
||||
|
||||
use self::handler::handle_connection;
|
||||
use self::listener::bind_robust;
|
||||
use self::stream::InboundStream;
|
||||
use self::listener::{bind_listener, serve_listener_loop, UnifiedListener};
|
||||
|
||||
pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
let db_logger = ClickHouseLogger::new(&config.database).map_err(|e| {
|
||||
@@ -33,71 +34,200 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
let mut socket_guards = Vec::new();
|
||||
|
||||
// Pingora server initialization (TLS only)
|
||||
// Pingora server initialization (TLS only or Standard HTTP)
|
||||
let mut pingora_services = Vec::new();
|
||||
|
||||
for service in config.services {
|
||||
let db = db.clone();
|
||||
|
||||
for bind in &service.binds {
|
||||
let service_config = service.clone();
|
||||
let bind_addr = bind.addr.clone();
|
||||
let proxy_proto_config = bind.proxy.clone();
|
||||
let mode = bind.mode;
|
||||
let real_ip_config = bind.real_ip.clone();
|
||||
|
||||
// Check if this bind is TLS/Pingora managed
|
||||
if let Some(tls_config) = &bind.tls {
|
||||
// This is a Pingora service
|
||||
pingora_services.push((service_config, bind.clone(), tls_config.clone()));
|
||||
// Use custom loop for TCP services or HTTP services requiring PROXY protocol parsing (not fully supported by pingora standard loop).
|
||||
|
||||
let is_tcp_service = service.service_type == "tcp";
|
||||
let is_http_proxy =
|
||||
service.service_type == "http" && bind.proxy.is_some() && bind.tls.is_none();
|
||||
|
||||
let use_custom_loop = is_tcp_service || is_http_proxy;
|
||||
|
||||
if !use_custom_loop {
|
||||
// Use Standard Pingora Service (For TLS, or Pure HTTP, or Unix HTTP without PROXY)
|
||||
pingora_services.push((
|
||||
service_config,
|
||||
bind.clone(),
|
||||
bind.tls.clone(),
|
||||
real_ip_config,
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Legacy TCP/Unix Logic
|
||||
if bind_addr.starts_with("unix://") {
|
||||
let path = bind_addr.trim_start_matches("unix://");
|
||||
// --- Custom Loop Logic ---
|
||||
|
||||
// Bind robustly
|
||||
let (listener, guard) = bind_robust(path, mode, &service_config.name).await?;
|
||||
let listener_res = bind_listener(&bind_addr, mode, &service_config.name).await;
|
||||
|
||||
// Push guard to keep it alive
|
||||
socket_guards.push(guard);
|
||||
let listener = match listener_res {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let listen_type = match &listener {
|
||||
UnifiedListener::Unix(_, _) => "unix",
|
||||
UnifiedListener::Tcp(_) => "tcp",
|
||||
};
|
||||
|
||||
if is_http_proxy {
|
||||
info!(
|
||||
"[{}] listening on unix {} (mode {:o})",
|
||||
service_config.name, path, mode
|
||||
"[{}] listening on http {} {} (PROXY support)",
|
||||
service_config.name, listen_type, bind_addr
|
||||
);
|
||||
} else {
|
||||
info!(
|
||||
"[{}] listening on {} {}",
|
||||
service_config.name, listen_type, bind_addr
|
||||
);
|
||||
}
|
||||
|
||||
join_set.spawn(start_unix_service(
|
||||
service_config,
|
||||
let shutdown_dummy =
|
||||
pingora::server::ShutdownWatch::from(tokio::sync::watch::channel(false).1);
|
||||
|
||||
if is_tcp_service {
|
||||
// --- TCP Handler (with startup check) ---
|
||||
if let Err(e) = UpstreamStream::connect(&service_config.forward_to).await {
|
||||
tracing::warn!(
|
||||
"[{}] -> '{}': startup check failed: {}",
|
||||
service_config.name,
|
||||
service_config.forward_to,
|
||||
e
|
||||
);
|
||||
}
|
||||
|
||||
let db = db.clone();
|
||||
let _proxy_cfg = proxy_proto_config.clone();
|
||||
let listen_addr_log = bind_addr.clone();
|
||||
let svc_cfg = service_config.clone();
|
||||
|
||||
join_set.spawn(serve_listener_loop(
|
||||
listener,
|
||||
service_config,
|
||||
real_ip_config,
|
||||
proxy_proto_config,
|
||||
db.clone(),
|
||||
bind.addr.clone(),
|
||||
shutdown_dummy,
|
||||
move |stream, info| {
|
||||
let db = db.clone();
|
||||
let svc = svc_cfg.clone();
|
||||
let addr = listen_addr_log.clone();
|
||||
async move {
|
||||
if let Err(e) = handle_connection(stream, info, svc, db, addr).await {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
|
||||
tracing::debug!("connection closed: {}", e);
|
||||
}
|
||||
_ => error!("connection error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
));
|
||||
} else {
|
||||
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
|
||||
error!(
|
||||
"[{}] failed to bind {}: {}",
|
||||
service_config.name, bind_addr, e
|
||||
);
|
||||
e
|
||||
})?;
|
||||
// --- HTTP Proxy Handler ---
|
||||
use crate::core::pingora_proxy::TrauditProxy;
|
||||
use pingora::proxy::http_proxy_service;
|
||||
|
||||
info!("[{}] listening on tcp {}", service_config.name, bind_addr);
|
||||
let conf = Arc::new(pingora::server::configuration::ServerConf::default());
|
||||
let inner_proxy = TrauditProxy {
|
||||
db: db.clone(),
|
||||
service_config: service_config.clone(),
|
||||
listen_addr: bind_addr.clone(),
|
||||
real_ip: real_ip_config.clone(),
|
||||
};
|
||||
let mut service_obj = http_proxy_service(&conf, inner_proxy);
|
||||
let app = unsafe {
|
||||
let app_ptr = service_obj.app_logic_mut().expect("app logic missing");
|
||||
std::ptr::read(app_ptr)
|
||||
};
|
||||
std::mem::forget(service_obj);
|
||||
let app = Arc::new(app);
|
||||
|
||||
join_set.spawn(start_tcp_service(
|
||||
service_config,
|
||||
join_set.spawn(serve_listener_loop(
|
||||
listener,
|
||||
service_config,
|
||||
real_ip_config,
|
||||
proxy_proto_config,
|
||||
db.clone(),
|
||||
bind.addr.clone(),
|
||||
shutdown_dummy.clone(),
|
||||
move |stream, _info| {
|
||||
let app = app.clone();
|
||||
let shutdown = shutdown_dummy.clone();
|
||||
async move {
|
||||
let stream: pingora::protocols::Stream = Box::new(stream);
|
||||
app.process_new(stream, &shutdown).await;
|
||||
}
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run Pingora in a separate thread if needed
|
||||
if !pingora_services.is_empty() {
|
||||
let barrier = Arc::new(Barrier::new(2));
|
||||
let barrier_clone = barrier.clone();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
use crate::core::pingora_proxy::TrauditProxy;
|
||||
use pingora::proxy::http_proxy_service;
|
||||
use pingora::server::configuration::Opt;
|
||||
use pingora::server::Server;
|
||||
|
||||
if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
let mut server = Server::new(Some(Opt::default())).unwrap();
|
||||
server.bootstrap();
|
||||
|
||||
for (svc_config, bind, tls, real_ip) in pingora_services {
|
||||
let proxy = TrauditProxy {
|
||||
db: db.clone(),
|
||||
service_config: svc_config.clone(),
|
||||
listen_addr: bind.addr.clone(),
|
||||
real_ip,
|
||||
};
|
||||
|
||||
let mut service = http_proxy_service(&server.configuration, proxy);
|
||||
|
||||
if let Some(tls_config) = tls {
|
||||
let key_path = tls_config.key.as_deref().unwrap_or(&tls_config.cert);
|
||||
service
|
||||
.add_tls(&bind.addr, &tls_config.cert, key_path)
|
||||
.unwrap();
|
||||
info!("[{}] listening on https {}", svc_config.name, bind.addr);
|
||||
} else if bind.addr.starts_with("unix://") {
|
||||
let path = bind.addr.trim_start_matches("unix://");
|
||||
service.add_uds(path, Some(std::fs::Permissions::from_mode(bind.mode)));
|
||||
info!("[{}] listening on http unix {}", svc_config.name, path);
|
||||
} else {
|
||||
service.add_tcp(&bind.addr);
|
||||
info!("[{}] listening on http {}", svc_config.name, bind.addr);
|
||||
}
|
||||
|
||||
server.add_service(service);
|
||||
}
|
||||
|
||||
barrier_clone.wait();
|
||||
server.run_forever();
|
||||
})) {
|
||||
error!("pingora server panicked: {:?}", e);
|
||||
}
|
||||
error!("pingora server exited unexpectedly!");
|
||||
});
|
||||
|
||||
barrier.wait();
|
||||
}
|
||||
|
||||
info!("traudit started...");
|
||||
|
||||
// notify systemd if configured
|
||||
@@ -109,52 +239,6 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Run Pingora in a separate thread if needed
|
||||
if !pingora_services.is_empty() {
|
||||
info!(
|
||||
"initializing pingora for {} tls services",
|
||||
pingora_services.len()
|
||||
);
|
||||
|
||||
// Spawn Pingora
|
||||
std::thread::spawn(move || {
|
||||
use crate::core::pingora_proxy::TrauditProxy;
|
||||
use pingora::proxy::http_proxy_service;
|
||||
use pingora::server::configuration::Opt;
|
||||
use pingora::server::Server;
|
||||
|
||||
if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
let mut server = Server::new(Some(Opt::default())).unwrap();
|
||||
server.bootstrap();
|
||||
|
||||
for (svc_config, bind, tls) in pingora_services {
|
||||
let proxy = TrauditProxy {
|
||||
db: db.clone(),
|
||||
service_config: svc_config.clone(),
|
||||
};
|
||||
|
||||
let mut service = http_proxy_service(&server.configuration, proxy);
|
||||
|
||||
// Key path fallback
|
||||
let key_path = tls.key.as_deref().unwrap_or(&tls.cert);
|
||||
|
||||
service.add_tls(&bind.addr, &tls.cert, key_path).unwrap();
|
||||
|
||||
info!("[{}] listening on tcp {}", svc_config.name, bind.addr);
|
||||
server.add_service(service);
|
||||
}
|
||||
|
||||
info!("starting pingora server run_forever loop");
|
||||
server.run_forever();
|
||||
})) {
|
||||
error!("pingora server panicked: {:?}", e);
|
||||
}
|
||||
error!("pingora server exited unexpectedly!");
|
||||
error!("pingora server exited unexpectedly!");
|
||||
});
|
||||
}
|
||||
|
||||
// Always wait for signals
|
||||
match signal::ctrl_c().await {
|
||||
Ok(()) => {
|
||||
info!("shutdown signal received.");
|
||||
@@ -165,139 +249,5 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
join_set.shutdown().await;
|
||||
|
||||
// socket_guards dropped here, cleaning up files
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_tcp_service(
|
||||
service: ServiceConfig,
|
||||
listener: TcpListener,
|
||||
proxy_cfg: Option<String>,
|
||||
db: Arc<ClickHouseLogger>,
|
||||
listen_addr: String,
|
||||
) {
|
||||
// Startup liveness check
|
||||
if let Err(e) = UpstreamStream::connect(&service.forward_to).await {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionRefused => {
|
||||
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
|
||||
}
|
||||
std::io::ErrorKind::NotFound => {
|
||||
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
|
||||
}
|
||||
_ => {
|
||||
// Log other startup errors as warnings
|
||||
tracing::warn!(
|
||||
"[{}] -> '{}': startup check failed: {}",
|
||||
service.name,
|
||||
service.forward_to,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((inbound, _client_addr)) => {
|
||||
let service = service.clone();
|
||||
let db = db.clone();
|
||||
let proxy_cfg = proxy_cfg.clone();
|
||||
let listen_addr = listen_addr.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let svc_name = service.name.clone();
|
||||
let svc_target = service.forward_to.clone();
|
||||
let inbound = InboundStream::Tcp(inbound);
|
||||
|
||||
if let Err(e) = handle_connection(inbound, service, proxy_cfg, db, listen_addr).await {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
|
||||
// normal disconnects, debug log only
|
||||
tracing::debug!("connection closed: {}", e);
|
||||
}
|
||||
std::io::ErrorKind::ConnectionRefused => {
|
||||
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
|
||||
}
|
||||
std::io::ErrorKind::NotFound => {
|
||||
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
|
||||
}
|
||||
_ => {
|
||||
error!("connection error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("accept error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_unix_service(
|
||||
service: ServiceConfig,
|
||||
listener: UnixListener,
|
||||
proxy_cfg: Option<String>,
|
||||
db: Arc<ClickHouseLogger>,
|
||||
listen_addr: String,
|
||||
) {
|
||||
// Startup liveness check (same as TCP)
|
||||
if let Err(e) = UpstreamStream::connect(&service.forward_to).await {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionRefused => {
|
||||
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
|
||||
}
|
||||
std::io::ErrorKind::NotFound => {
|
||||
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!(
|
||||
"[{}] -> '{}': startup check failed: {}",
|
||||
service.name,
|
||||
service.forward_to,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((inbound, _addr)) => {
|
||||
let service = service.clone();
|
||||
let db = db.clone();
|
||||
let proxy_cfg = proxy_cfg.clone();
|
||||
let listen_addr = listen_addr.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let svc_name = service.name.clone();
|
||||
let svc_target = service.forward_to.clone();
|
||||
let inbound = InboundStream::Unix(inbound);
|
||||
|
||||
if let Err(e) = handle_connection(inbound, service, proxy_cfg, db, listen_addr).await {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
|
||||
tracing::debug!("connection closed: {}", e);
|
||||
}
|
||||
std::io::ErrorKind::ConnectionRefused => {
|
||||
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
|
||||
}
|
||||
std::io::ErrorKind::NotFound => {
|
||||
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
|
||||
}
|
||||
_ => {
|
||||
error!("connection error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("accept error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
129
src/core/server/pingora_compat.rs
Normal file
129
src/core/server/pingora_compat.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
use async_trait::async_trait;
|
||||
use bytes::{Buf, BytesMut};
|
||||
use pingora::protocols::l4::socket::SocketAddr;
|
||||
use pingora::protocols::{
|
||||
GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
|
||||
TimingDigest, UniqueID, UniqueIDType,
|
||||
};
|
||||
use std::fmt::Debug;
|
||||
use std::io;
|
||||
use std::net::SocketAddr as InetSocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
|
||||
use crate::core::server::stream::InboundStream;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PingoraStream {
|
||||
inner: InboundStream,
|
||||
buffer: BytesMut,
|
||||
digest: Arc<SocketDigest>,
|
||||
}
|
||||
|
||||
impl PingoraStream {
|
||||
pub fn new(
|
||||
inner: InboundStream,
|
||||
buffer: BytesMut,
|
||||
peer_addr: InetSocketAddr,
|
||||
local_addr: InetSocketAddr,
|
||||
) -> Self {
|
||||
#[cfg(unix)]
|
||||
let digest = {
|
||||
use std::os::fd::AsRawFd;
|
||||
let fd = match &inner {
|
||||
InboundStream::Tcp(s) => s.as_raw_fd(),
|
||||
InboundStream::Unix(s) => s.as_raw_fd(),
|
||||
};
|
||||
let digest = SocketDigest::from_raw_fd(fd);
|
||||
let _ = digest.peer_addr.set(Some(SocketAddr::Inet(peer_addr)));
|
||||
let _ = digest.local_addr.set(Some(SocketAddr::Inet(local_addr)));
|
||||
Arc::new(digest)
|
||||
};
|
||||
|
||||
// Windows support or non-unix fallback
|
||||
#[cfg(not(unix))]
|
||||
let digest = Arc::new(SocketDigest::default());
|
||||
|
||||
Self {
|
||||
inner,
|
||||
buffer,
|
||||
digest,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> (InboundStream, BytesMut) {
|
||||
(self.inner, self.buffer)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for PingoraStream {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if self.buffer.has_remaining() {
|
||||
let len = std::cmp::min(self.buffer.len(), buf.remaining());
|
||||
buf.put_slice(&self.buffer[..len]);
|
||||
self.buffer.advance(len);
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for PingoraStream {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Shutdown for PingoraStream {
|
||||
async fn shutdown(&mut self) -> () {
|
||||
let _ = <InboundStream as AsyncWriteExt>::shutdown(&mut self.inner).await;
|
||||
}
|
||||
}
|
||||
|
||||
impl UniqueID for PingoraStream {
|
||||
fn id(&self) -> UniqueIDType {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
impl Ssl for PingoraStream {}
|
||||
|
||||
impl GetTimingDigest for PingoraStream {
|
||||
fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
impl GetProxyDigest for PingoraStream {
|
||||
fn get_proxy_digest(&self) -> Option<Arc<pingora::protocols::raw_connect::ProxyDigest>> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl GetSocketDigest for PingoraStream {
|
||||
fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
|
||||
Some(self.digest.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Peek for PingoraStream {}
|
||||
@@ -3,20 +3,12 @@ use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::{TcpStream, UnixStream};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum InboundStream {
|
||||
Tcp(TcpStream),
|
||||
Unix(UnixStream),
|
||||
}
|
||||
|
||||
impl InboundStream {
|
||||
pub fn peer_addr_string(&self) -> std::io::Result<String> {
|
||||
match self {
|
||||
InboundStream::Tcp(s) => Ok(s.peer_addr()?.to_string()),
|
||||
InboundStream::Unix(_) => Ok("unix_socket".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for InboundStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
|
||||
Reference in New Issue
Block a user