use super::stream::InboundStream; use crate::config::{RealIpConfig, RealIpSource, ServiceConfig}; use crate::core::forwarder; use crate::core::server::pingora_compat::UnifiedPingoraStream; use crate::core::upstream::UpstreamStream; use crate::db::clickhouse::{ClickHouseLogger, ProxyProto}; use crate::protocol::{self, ProxyInfo}; use bytes::BytesMut; use pingora::protocols::GetSocketDigest; use std::io; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; use tracing::{error, info}; pub async fn handle_connection( stream: UnifiedPingoraStream, proxy_info: Option, service: ServiceConfig, db: Arc, listen_addr: String, physical_addr: SocketAddr, real_ip_config: Option, ) -> std::io::Result { let conn_ts = time::OffsetDateTime::now_utc(); let start_instant = std::time::Instant::now(); // Extract resolved IP from digest (injected by listener) let digest = stream.get_socket_digest(); let (final_ip, final_port, local_addr_opt) = if let Some(d) = &digest { let peer = if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.peer_addr() { (addr.ip(), addr.port()) } else { (std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0) }; let local = if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.local_addr() { Some(SocketAddr::new(addr.ip(), addr.port())) } else { None }; (peer.0, peer.1, local) } else { ( std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0, None, ) }; // Unwrap stream if Plain to attempt zero-copy, otherwise use generic stream let mut read_buffer = BytesMut::new(); let mut inbound_enum: Option = None; let mut generic_stream: Option = None; if let UnifiedPingoraStream::Plain(s) = stream { let (inbound, buf) = s.into_inner(); read_buffer = buf; inbound_enum = Some(inbound); } else { generic_stream = Some(stream); } let is_unix = if let Some(ref inbound) = inbound_enum { matches!(inbound, InboundStream::Unix(_)) } else { false }; let src_fmt = if is_unix { "local".to_string() } else { physical_addr.to_string() }; let mut extras = Vec::new(); let mut is_untrusted = false; // Determine untrusted status based solely on RealIpConfig trust range if let Some(ref cfg) = real_ip_config { if !cfg.is_trusted(physical_addr.ip()) { is_untrusted = true; } } // If we have proxy info, we should show it. if let Some(ref info) = proxy_info { // Only show (untrusted) if we have proxy info and the source is not trusted if is_untrusted { extras.push("(untrusted)".to_string()); } let version_str = match info.version { protocol::Version::V1 => "proxy.v1", protocol::Version::V2 => "proxy.v2", }; let helper_str = format!("{}: {}", version_str, info.source); extras.push(format!("({})", helper_str)); } else if is_untrusted && real_ip_config.is_some() { } let log_msg = if extras.is_empty() { format!("[{}] {} <- {}", service.name, listen_addr, src_fmt) } else { format!( "[{}] {} <- {} {}", service.name, listen_addr, src_fmt, extras.join(" ") ) }; info!("{}", log_msg); // skip redundant proxy/IP resolution (done by listener); determine ProxyProto for logging let proto_enum = if let Some(ref info) = proxy_info { match info.version { protocol::Version::V1 => ProxyProto::V1, protocol::Version::V2 => ProxyProto::V2, } } else { ProxyProto::None }; // 3. Connect Upstream let forward_to = service.forward_to.as_deref().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, format!("service '{}' missing forward_to", service.name), ) })?; let mut upstream = UpstreamStream::connect(forward_to).await?; // [NEW] Send Proxy Protocol Header if configured if let Some(upstream_ver) = &service.upstream_proxy { // Resolve addresses let src_addr = SocketAddr::new(final_ip, final_port); // Determine destination address, fallback to localhost if unknown. let mut dst_addr = local_addr_opt.unwrap_or_else(|| { if let Ok(addr) = listen_addr.parse::() { addr } else { // Last resort fallback SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), 0) } }); // If destination is unspecified (0.0.0.0), replace with localhost to be valid if dst_addr.ip().is_unspecified() { let new_ip = match dst_addr.ip() { IpAddr::V4(_) => IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), IpAddr::V6(_) => IpAddr::V6(std::net::Ipv6Addr::LOCALHOST), }; dst_addr = SocketAddr::new(new_ip, dst_addr.port()); } let version = match upstream_ver.as_str() { "v1" => protocol::Version::V1, _ => protocol::Version::V2, }; if let Err(e) = protocol::write_proxy_header(&mut upstream, version, src_addr, dst_addr).await { error!("Failed to write proxy header to upstream: {}", e); return Err(e); } } // 4. Write buffered data if !read_buffer.is_empty() { upstream.write_all_buf(&mut read_buffer).await?; } // 5. Forwarding let mut bytes_sent = 0; let mut bytes_recv = 0; if let Some(inbound) = inbound_enum { // Splice Optimization Path let inbound_async = match inbound { InboundStream::Tcp(s) => crate::core::upstream::AsyncStream::from_tokio_tcp(s)?, InboundStream::Unix(s) => crate::core::upstream::AsyncStream::from_tokio_unix(s)?, }; let upstream_async = upstream.into_async_stream()?; let ((s2c, c2s), res) = forwarder::zero_copy_bidirectional(inbound_async, upstream_async).await; bytes_sent = s2c; bytes_recv = c2s; if let Err(e) = res { match e.kind() { std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => { tracing::debug!("[{}] connection closed: {}", service.name, e); } _ => error!("[{}] connection error: {}", service.name, e), } } } else if let Some(mut stream) = generic_stream { // Generic Copy Path (TLS): internal buffer is automatically handled by the stream wrapper. match tokio::io::copy_bidirectional(&mut stream, &mut upstream).await { Ok((s2c, c2s)) => { bytes_sent = s2c; bytes_recv = c2s; } Err(e) => match e.kind() { std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => { tracing::debug!("[{}] connection closed: {}", service.name, e); } _ => error!("[{}] connection error: {}", service.name, e), }, } } // Calculate total bytes // bytes_sent/recv already set let bytes_recv = bytes_recv + read_buffer.len() as u64; let total_bytes = bytes_sent + bytes_recv; // Logging logic let duration = start_instant.elapsed().as_millis() as u32; // Handle Unix socket specifics for logging let (log_ip, log_port, log_family) = if is_unix && proto_enum == ProxyProto::None { ( IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED), 0, crate::db::clickhouse::AddrFamily::Unix, ) } else { let family = match final_ip { IpAddr::V4(_) => crate::db::clickhouse::AddrFamily::Ipv4, IpAddr::V6(_) => crate::db::clickhouse::AddrFamily::Ipv6, }; (final_ip, final_port, family) }; let log_entry = crate::db::clickhouse::TcpLog { service: service.name.clone(), conn_ts, duration, addr_family: log_family, ip: log_ip, port: log_port, proxy_proto: proto_enum, bytes: total_bytes, bytes_sent, bytes_recv, }; if let Err(e) = db.insert_log(log_entry).await { error!("failed to insert tcp log: {}", e); } Ok(total_bytes) } pub async fn resolve_real_ip( config: &Option, remote_addr: SocketAddr, proxy_info: &Option, inbound: &mut InboundStream, buffer: &mut BytesMut, ) -> io::Result<(IpAddr, u16)> { if let Some(cfg) = config { match cfg.source { RealIpSource::ProxyProtocol => { if let Some(info) = proxy_info { // Trust check: The PHYSICAL connection must be from a trusted source if cfg.is_trusted(remote_addr.ip()) { return Ok((info.source.ip(), info.source.port())); } } } RealIpSource::Xff => { let current_ip = if let Some(info) = proxy_info { info.source.ip() } else { remote_addr.ip() }; if cfg.is_trusted(current_ip) { if let Some(ip) = peek_xff_ip(inbound, buffer, cfg.xff_trust_depth).await? { // XFF doesn't have port, use remote/proxy port let port = if let Some(info) = proxy_info { info.source.port() } else { remote_addr.port() }; return Ok((ip, port)); } } } RealIpSource::RemoteAddr => { return Ok((remote_addr.ip(), remote_addr.port())); } } } // Fallback to Remote Address if no config or strategy failed. Ok((remote_addr.ip(), remote_addr.port())) } pub(crate) async fn peek_xff_ip( stream: &mut T, buffer: &mut BytesMut, _trust_depth: usize, ) -> io::Result> { let max_header = 4096; loop { if let Some(pos) = buffer.windows(4).position(|w| w == b"\r\n\r\n") { let header_bytes = &buffer[..pos + 4]; let mut headers = [httparse::Header { name: "", value: &[], }; 32]; let mut req = httparse::Request::new(&mut headers); if req.parse(header_bytes).is_ok() { for header in req.headers { if header.name.eq_ignore_ascii_case("x-forwarded-for") { if let Ok(val) = std::str::from_utf8(header.value) { let ips: Vec<&str> = val.split(',').map(|s| s.trim()).collect(); if let Some(ip_str) = ips.last() { if let Ok(ip) = ip_str.parse() { return Ok(Some(ip)); } } } } } } return Ok(None); } if buffer.len() >= max_header { return Ok(None); } if stream.read_buf(buffer).await? == 0 { return Ok(None); } } }