feat: implement upstream proxy protocol support (v1/v2)

This commit is contained in:
2026-01-21 15:25:48 +08:00
parent 003ca203a4
commit 37c948db7b
5 changed files with 199 additions and 7 deletions

View File

@@ -40,6 +40,8 @@ pub struct ServiceConfig {
pub binds: Vec<BindEntry>,
#[serde(rename = "forward_to")]
pub forward_to: String,
#[serde(rename = "upstream_proxy")]
pub upstream_proxy: Option<String>,
}
#[derive(Debug, Deserialize, Clone)]
@@ -223,6 +225,17 @@ impl Config {
}
}
}
if let Some(upstream_proxy) = &service.upstream_proxy {
match upstream_proxy.as_str() {
"v1" | "v2" => {},
other => anyhow::bail!(
"Service '{}' has invalid 'upstream_proxy' value '{}'. Allowed values are 'v1' or 'v2'.",
service.name,
other
),
}
}
}
Ok(())
}
@@ -265,6 +278,7 @@ services:
assert_eq!(config.services[0].binds[0].addr, "0.0.0.0:22222");
assert_eq!(config.services[0].binds[0].proxy, Some("v2".to_string()));
assert_eq!(config.services[0].forward_to, "127.0.0.1:22");
assert_eq!(config.services[0].upstream_proxy, None);
}
#[test]

View File

@@ -182,7 +182,7 @@ impl ProxyHttp for TrauditProxy {
for (i, e) in extras.iter().enumerate() {
extra_str.push_str(&format!("({})", e));
if i < extras.len() - 1 {
extra_str.push_str(" ");
extra_str.push(' ');
}
}

View File

@@ -27,14 +27,27 @@ pub async fn handle_connection(
// Extract resolved IP from digest (injected by listener)
let digest = stream.get_socket_digest();
let (final_ip, final_port) = if let Some(d) = digest {
if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.peer_addr() {
let (final_ip, final_port, local_addr_opt) = if let Some(d) = &digest {
let peer = if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.peer_addr() {
(addr.ip(), addr.port())
} else {
(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0)
}
};
let local = if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.local_addr()
{
Some(SocketAddr::new(addr.ip(), addr.port()))
} else {
None
};
(peer.0, peer.1, local)
} else {
(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0)
(
std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)),
0,
None,
)
};
// Unwrap stream if Plain to attempt zero-copy, otherwise use generic stream
@@ -79,12 +92,11 @@ pub async fn handle_connection(
extras.push("(untrusted)".to_string());
}
let helper_str;
let version_str = match info.version {
protocol::Version::V1 => "proxy.v1",
protocol::Version::V2 => "proxy.v2",
};
helper_str = format!("{}: {}", version_str, info.source);
let helper_str = format!("{}: {}", version_str, info.source);
extras.push(format!("({})", helper_str));
} else if is_untrusted && real_ip_config.is_some() {
}
@@ -116,6 +128,24 @@ pub async fn handle_connection(
// 3. Connect Upstream
let mut upstream = UpstreamStream::connect(&service.forward_to).await?;
// [NEW] Send Proxy Protocol Header if configured
if let Some(upstream_ver) = &service.upstream_proxy {
// Resolve addresses
let src_addr = SocketAddr::new(final_ip, final_port);
// Use extracted local_addr or fallback to physical_addr (server socket)
let dst_addr = local_addr_opt.unwrap_or(physical_addr);
let version = match upstream_ver.as_str() {
"v1" => protocol::Version::V1,
_ => protocol::Version::V2,
};
if let Err(e) = protocol::write_proxy_header(&mut upstream, version, src_addr, dst_addr).await {
error!("Failed to write proxy header to upstream: {}", e);
return Err(e);
}
}
// 4. Write buffered data
if !read_buffer.is_empty() {
upstream.write_all_buf(&mut read_buffer).await?;

View File

@@ -9,6 +9,9 @@ pub struct ProxyInfo {
pub source: SocketAddr,
}
mod writer;
pub use writer::*;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Version {
V1,

145
src/protocol/writer.rs Normal file
View File

@@ -0,0 +1,145 @@
use crate::protocol::{Version, V2_PREFIX};
use std::io;
use std::net::SocketAddr;
use tokio::io::{AsyncWrite, AsyncWriteExt};
pub async fn write_proxy_header<T: AsyncWrite + Unpin>(
stream: &mut T,
version: Version,
src_addr: SocketAddr,
dst_addr: SocketAddr,
) -> io::Result<()> {
match version {
Version::V1 => write_v1(stream, src_addr, dst_addr).await,
Version::V2 => write_v2(stream, src_addr, dst_addr).await,
}
}
async fn write_v1<T: AsyncWrite + Unpin>(
stream: &mut T,
src: SocketAddr,
dst: SocketAddr,
) -> io::Result<()> {
// Format: PROXY TCP4/TCP6 src_ip dst_ip src_port dst_port\r\n
let proto = match src {
SocketAddr::V4(_) => "TCP4",
SocketAddr::V6(_) => "TCP6",
};
let header = format!(
"PROXY {} {} {} {} {}\r\n",
proto,
src.ip(),
dst.ip(),
src.port(),
dst.port()
);
stream.write_all(header.as_bytes()).await
}
async fn write_v2<T: AsyncWrite + Unpin>(
stream: &mut T,
src: SocketAddr,
dst: SocketAddr,
) -> io::Result<()> {
// Signature
stream.write_all(V2_PREFIX).await?;
// Version + Command (Ver=2, Cmd=1 Proxy) -> 0x21
stream.write_u8(0x21).await?;
// Family + Protocol
let (fam, len) = match (src, dst) {
(SocketAddr::V4(_), SocketAddr::V4(_)) => {
// AF_INET=1, STREAM=1 -> 0x11
// Length: 4+4+2+2 = 12
(0x11, 12u16)
}
(SocketAddr::V6(_), SocketAddr::V6(_)) => {
// AF_INET6=2, STREAM=1 -> 0x21
// Length: 16+16+2+2 = 36
(0x21, 36u16)
}
_ => {
// Mismatched families? Should not happen in normal flows if we stick to one protocol.
// But if it happens, we might just send UNSPEC or fail.
// Let's send UNSPEC (AF_UNSPEC=0, UNSPEC=0) -> 0x00 and len 0
stream.write_u8(0x00).await?;
stream.write_u16(0).await?;
return Ok(());
}
};
stream.write_u8(fam).await?;
stream.write_u16(len).await?;
match (src, dst) {
(SocketAddr::V4(s), SocketAddr::V4(d)) => {
stream.write_all(&s.ip().octets()).await?;
stream.write_all(&d.ip().octets()).await?;
stream.write_u16(s.port()).await?;
stream.write_u16(d.port()).await?;
}
(SocketAddr::V6(s), SocketAddr::V6(d)) => {
stream.write_all(&s.ip().octets()).await?;
stream.write_all(&d.ip().octets()).await?;
stream.write_u16(s.port()).await?;
stream.write_u16(d.port()).await?;
}
_ => {}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
use std::net::{IpAddr, Ipv4Addr};
#[tokio::test]
async fn test_write_v1() {
let mut buf = Vec::new();
let mut cursor = Cursor::new(&mut buf);
let src = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 1000);
let dst = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)), 2000);
write_proxy_header(&mut cursor, Version::V1, src, dst)
.await
.unwrap();
let output = String::from_utf8(buf).unwrap();
assert_eq!(output, "PROXY TCP4 1.1.1.1 2.2.2.2 1000 2000\r\n");
}
#[tokio::test]
async fn test_write_v2_ipv4() {
let mut buf = Vec::new();
let mut cursor = Cursor::new(&mut buf);
let src = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 1000);
let dst = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)), 2000);
write_proxy_header(&mut cursor, Version::V2, src, dst)
.await
.unwrap();
// Check signature
assert!(buf.starts_with(V2_PREFIX));
// Check Version/Command (0x21)
assert_eq!(buf[12], 0x21);
// Check Fam/Proto (0x11 for IPv4 TCP)
assert_eq!(buf[13], 0x11);
// Length (12 bytes)
assert_eq!(buf[14], 0);
assert_eq!(buf[15], 12);
// Payload (src ip, dst ip, src port, dst port)
// 1.1.1.1 = 01 01 01 01
// 2.2.2.2 = 02 02 02 02
// 1000 = 03 E8
// 2000 = 07 D0
let payload = &buf[16..];
assert_eq!(payload, &[1, 1, 1, 1, 2, 2, 2, 2, 0x03, 0xE8, 0x07, 0xD0]);
}
}