mirror of
https://github.com/awfufu/traudit
synced 2026-03-01 05:29:44 +08:00
feat: implement upstream proxy protocol support (v1/v2)
This commit is contained in:
@@ -40,6 +40,8 @@ pub struct ServiceConfig {
|
||||
pub binds: Vec<BindEntry>,
|
||||
#[serde(rename = "forward_to")]
|
||||
pub forward_to: String,
|
||||
#[serde(rename = "upstream_proxy")]
|
||||
pub upstream_proxy: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
@@ -223,6 +225,17 @@ impl Config {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(upstream_proxy) = &service.upstream_proxy {
|
||||
match upstream_proxy.as_str() {
|
||||
"v1" | "v2" => {},
|
||||
other => anyhow::bail!(
|
||||
"Service '{}' has invalid 'upstream_proxy' value '{}'. Allowed values are 'v1' or 'v2'.",
|
||||
service.name,
|
||||
other
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -265,6 +278,7 @@ services:
|
||||
assert_eq!(config.services[0].binds[0].addr, "0.0.0.0:22222");
|
||||
assert_eq!(config.services[0].binds[0].proxy, Some("v2".to_string()));
|
||||
assert_eq!(config.services[0].forward_to, "127.0.0.1:22");
|
||||
assert_eq!(config.services[0].upstream_proxy, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -182,7 +182,7 @@ impl ProxyHttp for TrauditProxy {
|
||||
for (i, e) in extras.iter().enumerate() {
|
||||
extra_str.push_str(&format!("({})", e));
|
||||
if i < extras.len() - 1 {
|
||||
extra_str.push_str(" ");
|
||||
extra_str.push(' ');
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,14 +27,27 @@ pub async fn handle_connection(
|
||||
|
||||
// Extract resolved IP from digest (injected by listener)
|
||||
let digest = stream.get_socket_digest();
|
||||
let (final_ip, final_port) = if let Some(d) = digest {
|
||||
if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.peer_addr() {
|
||||
let (final_ip, final_port, local_addr_opt) = if let Some(d) = &digest {
|
||||
let peer = if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.peer_addr() {
|
||||
(addr.ip(), addr.port())
|
||||
} else {
|
||||
(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0)
|
||||
}
|
||||
};
|
||||
|
||||
let local = if let Some(pingora::protocols::l4::socket::SocketAddr::Inet(addr)) = d.local_addr()
|
||||
{
|
||||
Some(SocketAddr::new(addr.ip(), addr.port()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(peer.0, peer.1, local)
|
||||
} else {
|
||||
(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 0)
|
||||
(
|
||||
std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)),
|
||||
0,
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
// Unwrap stream if Plain to attempt zero-copy, otherwise use generic stream
|
||||
@@ -79,12 +92,11 @@ pub async fn handle_connection(
|
||||
extras.push("(untrusted)".to_string());
|
||||
}
|
||||
|
||||
let helper_str;
|
||||
let version_str = match info.version {
|
||||
protocol::Version::V1 => "proxy.v1",
|
||||
protocol::Version::V2 => "proxy.v2",
|
||||
};
|
||||
helper_str = format!("{}: {}", version_str, info.source);
|
||||
let helper_str = format!("{}: {}", version_str, info.source);
|
||||
extras.push(format!("({})", helper_str));
|
||||
} else if is_untrusted && real_ip_config.is_some() {
|
||||
}
|
||||
@@ -116,6 +128,24 @@ pub async fn handle_connection(
|
||||
// 3. Connect Upstream
|
||||
let mut upstream = UpstreamStream::connect(&service.forward_to).await?;
|
||||
|
||||
// [NEW] Send Proxy Protocol Header if configured
|
||||
if let Some(upstream_ver) = &service.upstream_proxy {
|
||||
// Resolve addresses
|
||||
let src_addr = SocketAddr::new(final_ip, final_port);
|
||||
// Use extracted local_addr or fallback to physical_addr (server socket)
|
||||
let dst_addr = local_addr_opt.unwrap_or(physical_addr);
|
||||
|
||||
let version = match upstream_ver.as_str() {
|
||||
"v1" => protocol::Version::V1,
|
||||
_ => protocol::Version::V2,
|
||||
};
|
||||
|
||||
if let Err(e) = protocol::write_proxy_header(&mut upstream, version, src_addr, dst_addr).await {
|
||||
error!("Failed to write proxy header to upstream: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Write buffered data
|
||||
if !read_buffer.is_empty() {
|
||||
upstream.write_all_buf(&mut read_buffer).await?;
|
||||
|
||||
@@ -9,6 +9,9 @@ pub struct ProxyInfo {
|
||||
pub source: SocketAddr,
|
||||
}
|
||||
|
||||
mod writer;
|
||||
pub use writer::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Version {
|
||||
V1,
|
||||
|
||||
145
src/protocol/writer.rs
Normal file
145
src/protocol/writer.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
use crate::protocol::{Version, V2_PREFIX};
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::io::{AsyncWrite, AsyncWriteExt};
|
||||
|
||||
pub async fn write_proxy_header<T: AsyncWrite + Unpin>(
|
||||
stream: &mut T,
|
||||
version: Version,
|
||||
src_addr: SocketAddr,
|
||||
dst_addr: SocketAddr,
|
||||
) -> io::Result<()> {
|
||||
match version {
|
||||
Version::V1 => write_v1(stream, src_addr, dst_addr).await,
|
||||
Version::V2 => write_v2(stream, src_addr, dst_addr).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_v1<T: AsyncWrite + Unpin>(
|
||||
stream: &mut T,
|
||||
src: SocketAddr,
|
||||
dst: SocketAddr,
|
||||
) -> io::Result<()> {
|
||||
// Format: PROXY TCP4/TCP6 src_ip dst_ip src_port dst_port\r\n
|
||||
let proto = match src {
|
||||
SocketAddr::V4(_) => "TCP4",
|
||||
SocketAddr::V6(_) => "TCP6",
|
||||
};
|
||||
|
||||
let header = format!(
|
||||
"PROXY {} {} {} {} {}\r\n",
|
||||
proto,
|
||||
src.ip(),
|
||||
dst.ip(),
|
||||
src.port(),
|
||||
dst.port()
|
||||
);
|
||||
|
||||
stream.write_all(header.as_bytes()).await
|
||||
}
|
||||
|
||||
async fn write_v2<T: AsyncWrite + Unpin>(
|
||||
stream: &mut T,
|
||||
src: SocketAddr,
|
||||
dst: SocketAddr,
|
||||
) -> io::Result<()> {
|
||||
// Signature
|
||||
stream.write_all(V2_PREFIX).await?;
|
||||
|
||||
// Version + Command (Ver=2, Cmd=1 Proxy) -> 0x21
|
||||
stream.write_u8(0x21).await?;
|
||||
|
||||
// Family + Protocol
|
||||
let (fam, len) = match (src, dst) {
|
||||
(SocketAddr::V4(_), SocketAddr::V4(_)) => {
|
||||
// AF_INET=1, STREAM=1 -> 0x11
|
||||
// Length: 4+4+2+2 = 12
|
||||
(0x11, 12u16)
|
||||
}
|
||||
(SocketAddr::V6(_), SocketAddr::V6(_)) => {
|
||||
// AF_INET6=2, STREAM=1 -> 0x21
|
||||
// Length: 16+16+2+2 = 36
|
||||
(0x21, 36u16)
|
||||
}
|
||||
_ => {
|
||||
// Mismatched families? Should not happen in normal flows if we stick to one protocol.
|
||||
// But if it happens, we might just send UNSPEC or fail.
|
||||
// Let's send UNSPEC (AF_UNSPEC=0, UNSPEC=0) -> 0x00 and len 0
|
||||
stream.write_u8(0x00).await?;
|
||||
stream.write_u16(0).await?;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
stream.write_u8(fam).await?;
|
||||
stream.write_u16(len).await?;
|
||||
|
||||
match (src, dst) {
|
||||
(SocketAddr::V4(s), SocketAddr::V4(d)) => {
|
||||
stream.write_all(&s.ip().octets()).await?;
|
||||
stream.write_all(&d.ip().octets()).await?;
|
||||
stream.write_u16(s.port()).await?;
|
||||
stream.write_u16(d.port()).await?;
|
||||
}
|
||||
(SocketAddr::V6(s), SocketAddr::V6(d)) => {
|
||||
stream.write_all(&s.ip().octets()).await?;
|
||||
stream.write_all(&d.ip().octets()).await?;
|
||||
stream.write_u16(s.port()).await?;
|
||||
stream.write_u16(d.port()).await?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Cursor;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_v1() {
|
||||
let mut buf = Vec::new();
|
||||
let mut cursor = Cursor::new(&mut buf);
|
||||
let src = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 1000);
|
||||
let dst = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)), 2000);
|
||||
|
||||
write_proxy_header(&mut cursor, Version::V1, src, dst)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let output = String::from_utf8(buf).unwrap();
|
||||
assert_eq!(output, "PROXY TCP4 1.1.1.1 2.2.2.2 1000 2000\r\n");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_v2_ipv4() {
|
||||
let mut buf = Vec::new();
|
||||
let mut cursor = Cursor::new(&mut buf);
|
||||
let src = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 1000);
|
||||
let dst = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)), 2000);
|
||||
|
||||
write_proxy_header(&mut cursor, Version::V2, src, dst)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Check signature
|
||||
assert!(buf.starts_with(V2_PREFIX));
|
||||
// Check Version/Command (0x21)
|
||||
assert_eq!(buf[12], 0x21);
|
||||
// Check Fam/Proto (0x11 for IPv4 TCP)
|
||||
assert_eq!(buf[13], 0x11);
|
||||
// Length (12 bytes)
|
||||
assert_eq!(buf[14], 0);
|
||||
assert_eq!(buf[15], 12);
|
||||
// Payload (src ip, dst ip, src port, dst port)
|
||||
// 1.1.1.1 = 01 01 01 01
|
||||
// 2.2.2.2 = 02 02 02 02
|
||||
// 1000 = 03 E8
|
||||
// 2000 = 07 D0
|
||||
let payload = &buf[16..];
|
||||
assert_eq!(payload, &[1, 1, 1, 1, 2, 2, 2, 2, 0x03, 0xE8, 0x07, 0xD0]);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user