Files
traudit/src/core/server/handler.rs

347 lines
10 KiB
Rust

use super::stream::InboundStream;
use crate::config::{RealIpConfig, RealIpSource, ServiceConfig};
use crate::core::forwarder;
use crate::core::server::pingora_compat::UnifiedPingoraStream;
use crate::core::upstream::UpstreamStream;
use crate::db::clickhouse::{ClickHouseLogger, ProxyProto};
use crate::protocol::{self, ProxyInfo};
use bytes::BytesMut;
use pingora::protocols::GetSocketDigest;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tracing::{error, info};
pub async fn handle_connection(
stream: UnifiedPingoraStream,
proxy_info: Option<ProxyInfo>,
service: ServiceConfig,
db: Arc<ClickHouseLogger>,
listen_addr: String,
physical_addr: SocketAddr,
real_ip_config: Option<RealIpConfig>,
) -> std::io::Result<u64> {
let conn_ts = time::OffsetDateTime::now_utc();
let start_instant = std::time::Instant::now();
// Extract resolved IP from digest (injected by listener)
let digest = stream.get_socket_digest();
let (final_ip, final_port, local_addr_opt) = if let Some(d) = &digest {
let peer = if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.peer_addr() {
(addr.ip(), addr.port())
} else {
(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0)
};
let local = if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.local_addr()
{
Some(SocketAddr::new(addr.ip(), addr.port()))
} else {
None
};
(peer.0, peer.1, local)
} else {
(
std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)),
0,
None,
)
};
// Unwrap stream if Plain to attempt zero-copy, otherwise use generic stream
let mut read_buffer = BytesMut::new();
let mut inbound_enum: Option<InboundStream> = None;
let mut generic_stream: Option<UnifiedPingoraStream> = None;
if let UnifiedPingoraStream::Plain(s) = stream {
let (inbound, buf) = s.into_inner();
read_buffer = buf;
inbound_enum = Some(inbound);
} else {
generic_stream = Some(stream);
}
let is_unix = if let Some(ref inbound) = inbound_enum {
matches!(inbound, InboundStream::Unix(_))
} else {
false
};
let src_fmt = if is_unix {
"local".to_string()
} else {
physical_addr.to_string()
};
let mut extras = Vec::new();
let mut is_untrusted = false;
// Determine untrusted status based solely on RealIpConfig trust range
if let Some(ref cfg) = real_ip_config {
if !cfg.is_trusted(physical_addr.ip()) {
is_untrusted = true;
}
}
// If we have proxy info, we should show it.
if let Some(ref info) = proxy_info {
// Only show (untrusted) if we have proxy info and the source is not trusted
if is_untrusted {
extras.push("(untrusted)".to_string());
}
let version_str = match info.version {
protocol::Version::V1 => "proxy.v1",
protocol::Version::V2 => "proxy.v2",
};
let helper_str = format!("{}: {}", version_str, info.source);
extras.push(format!("({})", helper_str));
} else if is_untrusted && real_ip_config.is_some() {
}
let log_msg = if extras.is_empty() {
format!("[{}] {} <- {}", service.name, listen_addr, src_fmt)
} else {
format!(
"[{}] {} <- {} {}",
service.name,
listen_addr,
src_fmt,
extras.join(" ")
)
};
info!("{}", log_msg);
// skip redundant proxy/IP resolution (done by listener); determine ProxyProto for logging
let proto_enum = if let Some(ref info) = proxy_info {
match info.version {
protocol::Version::V1 => ProxyProto::V1,
protocol::Version::V2 => ProxyProto::V2,
}
} else {
ProxyProto::None
};
// 3. Connect Upstream
let forward_to = service.forward_to.as_deref().ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("service '{}' missing forward_to", service.name),
)
})?;
let mut upstream = UpstreamStream::connect(forward_to).await?;
// [NEW] Send Proxy Protocol Header if configured
if let Some(upstream_ver) = &service.upstream_proxy {
// Resolve addresses
let src_addr = SocketAddr::new(final_ip, final_port);
// Determine destination address, fallback to localhost if unknown.
let mut dst_addr = local_addr_opt.unwrap_or_else(|| {
if let Ok(addr) = listen_addr.parse::<SocketAddr>() {
addr
} else {
// Last resort fallback
SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), 0)
}
});
// If destination is unspecified (0.0.0.0), replace with localhost to be valid
if dst_addr.ip().is_unspecified() {
let new_ip = match dst_addr.ip() {
IpAddr::V4(_) => IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
IpAddr::V6(_) => IpAddr::V6(std::net::Ipv6Addr::LOCALHOST),
};
dst_addr = SocketAddr::new(new_ip, dst_addr.port());
}
let version = match upstream_ver.as_str() {
"v1" => protocol::Version::V1,
_ => protocol::Version::V2,
};
if let Err(e) = protocol::write_proxy_header(&mut upstream, version, src_addr, dst_addr).await {
error!("Failed to write proxy header to upstream: {}", e);
return Err(e);
}
}
// 4. Write buffered data
if !read_buffer.is_empty() {
upstream.write_all_buf(&mut read_buffer).await?;
}
// 5. Forwarding
let mut bytes_sent = 0;
let mut bytes_recv = 0;
if let Some(inbound) = inbound_enum {
// Splice Optimization Path
let inbound_async = match inbound {
InboundStream::Tcp(s) => crate::core::upstream::AsyncStream::from_tokio_tcp(s)?,
InboundStream::Unix(s) => crate::core::upstream::AsyncStream::from_tokio_unix(s)?,
};
let upstream_async = upstream.into_async_stream()?;
let ((s2c, c2s), res) = forwarder::zero_copy_bidirectional(inbound_async, upstream_async).await;
bytes_sent = s2c;
bytes_recv = c2s;
if let Err(e) = res {
match e.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
tracing::debug!("[{}] connection closed: {}", service.name, e);
}
_ => error!("[{}] connection error: {}", service.name, e),
}
}
} else if let Some(mut stream) = generic_stream {
// Generic Copy Path (TLS): internal buffer is automatically handled by the stream wrapper.
match tokio::io::copy_bidirectional(&mut stream, &mut upstream).await {
Ok((s2c, c2s)) => {
bytes_sent = s2c;
bytes_recv = c2s;
}
Err(e) => match e.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
tracing::debug!("[{}] connection closed: {}", service.name, e);
}
_ => error!("[{}] connection error: {}", service.name, e),
},
}
}
// Calculate total bytes
// bytes_sent/recv already set
let bytes_recv = bytes_recv + read_buffer.len() as u64;
let total_bytes = bytes_sent + bytes_recv;
// Logging logic
let duration = start_instant.elapsed().as_millis() as u32;
// Handle Unix socket specifics for logging
let (log_ip, log_port, log_family) = if is_unix && proto_enum == ProxyProto::None {
(
IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED),
0,
crate::db::clickhouse::AddrFamily::Unix,
)
} else {
let family = match final_ip {
IpAddr::V4(_) => crate::db::clickhouse::AddrFamily::Ipv4,
IpAddr::V6(_) => crate::db::clickhouse::AddrFamily::Ipv6,
};
(final_ip, final_port, family)
};
let log_entry = crate::db::clickhouse::TcpLog {
service: service.name.clone(),
conn_ts,
duration,
addr_family: log_family,
ip: log_ip,
port: log_port,
proxy_proto: proto_enum,
bytes: total_bytes,
bytes_sent,
bytes_recv,
};
if let Err(e) = db.insert_log(log_entry).await {
error!("failed to insert tcp log: {}", e);
}
Ok(total_bytes)
}
/// Determines the client address to attribute a connection to.
///
/// With no `config`, or when the configured strategy yields nothing, the
/// physical `remote_addr` is returned. Otherwise the strategy in `cfg.source`
/// is applied:
///
/// - `ProxyProtocol`: the PROXY header source is used when a header is present
///   and the physical peer (`remote_addr`) is trusted by `cfg`.
/// - `Xff`: the trust check is made against the PROXY header source IP when a
///   header is present, else against `remote_addr`. If trusted, the
///   `X-Forwarded-For` address found by `peek_xff_ip` is paired with the port
///   of that same hop, since the header value carries no port.
/// - `RemoteAddr`: always the physical peer.
///
/// # Errors
/// Propagates I/O errors from `peek_xff_ip` while reading from `inbound`.
pub async fn resolve_real_ip(
    config: &Option<RealIpConfig>,
    remote_addr: SocketAddr,
    proxy_info: &Option<ProxyInfo>,
    inbound: &mut InboundStream,
    buffer: &mut BytesMut,
) -> io::Result<(IpAddr, u16)> {
    // Answer used whenever no strategy produces a result.
    let fallback = (remote_addr.ip(), remote_addr.port());

    let cfg = match config {
        Some(cfg) => cfg,
        None => return Ok(fallback),
    };

    let resolved = match cfg.source {
        RealIpSource::ProxyProtocol => proxy_info
            .as_ref()
            // The PHYSICAL connection must come from a trusted source.
            .filter(|_| cfg.is_trusted(remote_addr.ip()))
            .map(|info| (info.source.ip(), info.source.port())),
        RealIpSource::Xff => {
            // The hop whose trust is checked and whose port is reused.
            let (hop_ip, hop_port) = match proxy_info {
                Some(info) => (info.source.ip(), info.source.port()),
                None => fallback,
            };
            if cfg.is_trusted(hop_ip) {
                peek_xff_ip(inbound, buffer, cfg.xff_trust_depth)
                    .await?
                    .map(|ip| (ip, hop_port))
            } else {
                None
            }
        }
        RealIpSource::RemoteAddr => None,
    };

    Ok(resolved.unwrap_or(fallback))
}
/// Reads an HTTP request head from `stream` into `buffer` and extracts an
/// address from its `X-Forwarded-For` header.
///
/// Data is appended to `buffer` until a blank line (`\r\n\r\n`) is present.
/// The head is then parsed (up to 32 headers) and, for each
/// `X-Forwarded-For` header in order, the last comma-separated entry is
/// trimmed and parsed as an IP address; the first one that parses is returned.
///
/// Returns `Ok(None)` when the stream ends first, when `buffer` already holds
/// 4096 bytes or more without a complete head, when the head does not parse,
/// or when no header yields a valid address.
///
/// `_trust_depth` is currently not consulted.
///
/// # Errors
/// Propagates I/O errors from reading `stream`.
pub(crate) async fn peek_xff_ip<T: AsyncRead + Unpin>(
    stream: &mut T,
    buffer: &mut BytesMut,
    _trust_depth: usize,
) -> io::Result<Option<IpAddr>> {
    const MAX_HEADER: usize = 4096;
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    // Phase 1: accumulate bytes until the end of the request head is visible.
    let head_end = loop {
        if let Some(pos) = buffer
            .windows(TERMINATOR.len())
            .position(|w| w == TERMINATOR)
        {
            break pos + TERMINATOR.len();
        }
        // Give up on oversized heads and on end-of-stream.
        if buffer.len() >= MAX_HEADER || stream.read_buf(buffer).await? == 0 {
            return Ok(None);
        }
    };

    // Phase 2: parse the head and look for a usable X-Forwarded-For entry.
    let mut headers = [httparse::Header {
        name: "",
        value: &[],
    }; 32];
    let mut req = httparse::Request::new(&mut headers);
    if req.parse(&buffer[..head_end]).is_err() {
        return Ok(None);
    }

    let forwarded_ip = req
        .headers
        .iter()
        .filter(|h| h.name.eq_ignore_ascii_case("x-forwarded-for"))
        .filter_map(|h| std::str::from_utf8(h.value).ok())
        // Only the right-most entry of each header value is considered.
        .filter_map(|val| val.rsplit(',').next())
        .find_map(|entry| entry.trim().parse::<IpAddr>().ok());

    Ok(forwarded_ip)
}