diff --git a/src/config/mod.rs b/src/config/mod.rs index f96aa12..bddb08a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -40,6 +40,8 @@ pub struct ServiceConfig { pub binds: Vec, #[serde(rename = "forward_to")] pub forward_to: String, + #[serde(rename = "upstream_proxy")] + pub upstream_proxy: Option, } #[derive(Debug, Deserialize, Clone)] @@ -223,6 +225,17 @@ impl Config { } } } + + if let Some(upstream_proxy) = &service.upstream_proxy { + match upstream_proxy.as_str() { + "v1" | "v2" => {}, + other => anyhow::bail!( + "Service '{}' has invalid 'upstream_proxy' value '{}'. Allowed values are 'v1' or 'v2'.", + service.name, + other + ), + } + } } Ok(()) } @@ -265,6 +278,7 @@ services: assert_eq!(config.services[0].binds[0].addr, "0.0.0.0:22222"); assert_eq!(config.services[0].binds[0].proxy, Some("v2".to_string())); assert_eq!(config.services[0].forward_to, "127.0.0.1:22"); + assert_eq!(config.services[0].upstream_proxy, None); } #[test] diff --git a/src/core/pingora_proxy.rs b/src/core/pingora_proxy.rs index 44cdb3c..dab9e6e 100644 --- a/src/core/pingora_proxy.rs +++ b/src/core/pingora_proxy.rs @@ -182,7 +182,7 @@ impl ProxyHttp for TrauditProxy { for (i, e) in extras.iter().enumerate() { extra_str.push_str(&format!("({})", e)); if i < extras.len() - 1 { - extra_str.push_str(" "); + extra_str.push(' '); } } diff --git a/src/core/server/handler.rs b/src/core/server/handler.rs index 19aac2e..c458e05 100644 --- a/src/core/server/handler.rs +++ b/src/core/server/handler.rs @@ -27,14 +27,27 @@ pub async fn handle_connection( // Extract resolved IP from digest (injected by listener) let digest = stream.get_socket_digest(); - let (final_ip, final_port) = if let Some(d) = digest { - if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.peer_addr() { + let (final_ip, final_port, local_addr_opt) = if let Some(d) = &digest { + let peer = if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.peer_addr() { (addr.ip(), addr.port()) } else { (std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0) - } + }; + + let local = if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.local_addr() + { + Some(SocketAddr::new(addr.ip(), addr.port())) + } else { + None + }; + + (peer.0, peer.1, local) } else { - (std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0) + ( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), + 0, + None, + ) }; // Unwrap stream if Plain to attempt zero-copy, otherwise use generic stream @@ -79,12 +92,11 @@ pub async fn handle_connection( extras.push("(untrusted)".to_string()); } - let helper_str; let version_str = match info.version { protocol::Version::V1 => "proxy.v1", protocol::Version::V2 => "proxy.v2", }; - helper_str = format!("{}: {}", version_str, info.source); + let helper_str = format!("{}: {}", version_str, info.source); extras.push(format!("({})", helper_str)); } else if is_untrusted && real_ip_config.is_some() { } @@ -116,6 +128,24 @@ pub async fn handle_connection( // 3. Connect Upstream let mut upstream = UpstreamStream::connect(&service.forward_to).await?; + // [NEW] Send Proxy Protocol Header if configured + if let Some(upstream_ver) = &service.upstream_proxy { + // Resolve addresses + let src_addr = SocketAddr::new(final_ip, final_port); + // Use extracted local_addr or fallback to physical_addr (server socket) + let dst_addr = local_addr_opt.unwrap_or(physical_addr); + + let version = match upstream_ver.as_str() { + "v1" => protocol::Version::V1, + _ => protocol::Version::V2, + }; + + if let Err(e) = protocol::write_proxy_header(&mut upstream, version, src_addr, dst_addr).await { + error!("Failed to write proxy header to upstream: {}", e); + return Err(e); + } + } + // 4. Write buffered data if !read_buffer.is_empty() { upstream.write_all_buf(&mut read_buffer).await?; diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 71facfe..af92fa8 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -9,6 +9,9 @@ pub struct ProxyInfo { pub source: SocketAddr, } +mod writer; +pub use writer::*; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Version { V1, diff --git a/src/protocol/writer.rs b/src/protocol/writer.rs new file mode 100644 index 0000000..d09fb33 --- /dev/null +++ b/src/protocol/writer.rs @@ -0,0 +1,145 @@ +use crate::protocol::{Version, V2_PREFIX}; +use std::io; +use std::net::SocketAddr; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +pub async fn write_proxy_header( + stream: &mut T, + version: Version, + src_addr: SocketAddr, + dst_addr: SocketAddr, +) -> io::Result<()> { + match version { + Version::V1 => write_v1(stream, src_addr, dst_addr).await, + Version::V2 => write_v2(stream, src_addr, dst_addr).await, + } +} + +async fn write_v1( + stream: &mut T, + src: SocketAddr, + dst: SocketAddr, +) -> io::Result<()> { + // Format: PROXY TCP4/TCP6 src_ip dst_ip src_port dst_port\r\n + let proto = match src { + SocketAddr::V4(_) => "TCP4", + SocketAddr::V6(_) => "TCP6", + }; + + let header = format!( + "PROXY {} {} {} {} {}\r\n", + proto, + src.ip(), + dst.ip(), + src.port(), + dst.port() + ); + + stream.write_all(header.as_bytes()).await +} + +async fn write_v2( + stream: &mut T, + src: SocketAddr, + dst: SocketAddr, +) -> io::Result<()> { + // Signature + stream.write_all(V2_PREFIX).await?; + + // Version + Command (Ver=2, Cmd=1 Proxy) -> 0x21 + stream.write_u8(0x21).await?; + + // Family + Protocol + let (fam, len) = match (src, dst) { + (SocketAddr::V4(_), SocketAddr::V4(_)) => { + // AF_INET=1, STREAM=1 -> 0x11 + // Length: 4+4+2+2 = 12 + (0x11, 12u16) + } + (SocketAddr::V6(_), SocketAddr::V6(_)) => { + // AF_INET6=2, STREAM=1 -> 0x21 + // Length: 16+16+2+2 = 36 + (0x21, 36u16) + } + _ => { + // Mismatched families? Should not happen in normal flows if we stick to one protocol. + // But if it happens, we might just send UNSPEC or fail. + // Let's send UNSPEC (AF_UNSPEC=0, UNSPEC=0) -> 0x00 and len 0 + stream.write_u8(0x00).await?; + stream.write_u16(0).await?; + return Ok(()); + } + }; + + stream.write_u8(fam).await?; + stream.write_u16(len).await?; + + match (src, dst) { + (SocketAddr::V4(s), SocketAddr::V4(d)) => { + stream.write_all(&s.ip().octets()).await?; + stream.write_all(&d.ip().octets()).await?; + stream.write_u16(s.port()).await?; + stream.write_u16(d.port()).await?; + } + (SocketAddr::V6(s), SocketAddr::V6(d)) => { + stream.write_all(&s.ip().octets()).await?; + stream.write_all(&d.ip().octets()).await?; + stream.write_u16(s.port()).await?; + stream.write_u16(d.port()).await?; + } + _ => {} + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Cursor; + use std::net::{IpAddr, Ipv4Addr}; + + #[tokio::test] + async fn test_write_v1() { + let mut buf = Vec::new(); + let mut cursor = Cursor::new(&mut buf); + let src = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 1000); + let dst = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)), 2000); + + write_proxy_header(&mut cursor, Version::V1, src, dst) + .await + .unwrap(); + + let output = String::from_utf8(buf).unwrap(); + assert_eq!(output, "PROXY TCP4 1.1.1.1 2.2.2.2 1000 2000\r\n"); + } + + #[tokio::test] + async fn test_write_v2_ipv4() { + let mut buf = Vec::new(); + let mut cursor = Cursor::new(&mut buf); + let src = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 1000); + let dst = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)), 2000); + + write_proxy_header(&mut cursor, Version::V2, src, dst) + .await + .unwrap(); + + // Check signature + assert!(buf.starts_with(V2_PREFIX)); + // Check Version/Command (0x21) + assert_eq!(buf[12], 0x21); + // Check Fam/Proto (0x11 for IPv4 TCP) + assert_eq!(buf[13], 0x11); + // Length (12 bytes) + assert_eq!(buf[14], 0); + assert_eq!(buf[15], 12); + // Payload (src ip, dst ip, src port, dst port) + // 1.1.1.1 = 01 01 01 01 + // 2.2.2.2 = 02 02 02 02 + // 1000 = 03 E8 + // 2000 = 07 D0 + let payload = &buf[16..]; + assert_eq!(payload, &[1, 1, 1, 1, 2, 2, 2, 2, 0x03, 0xE8, 0x07, 0xD0]); + } +}