3 Commits

9 changed files with 510 additions and 35 deletions

2
Cargo.lock generated
View File

@@ -4059,7 +4059,7 @@ dependencies = [
[[package]]
name = "traudit"
version = "0.0.8"
version = "0.0.9"
dependencies = [
"anyhow",
"async-trait",

View File

@@ -1,6 +1,6 @@
[package]
name = "traudit"
version = "0.0.8"
version = "0.0.9"
edition = "2021"
authors = ["awfufu"]
description = "A reverse proxy that streams audit records directly to databases."

View File

@@ -1,4 +1,5 @@
use serde::{Deserialize, Deserializer};
use std::collections::HashSet;
use std::net::IpAddr;
use std::path::Path;
use tokio::fs;
@@ -39,11 +40,73 @@ pub struct ServiceConfig {
pub service_type: String,
pub binds: Vec<BindEntry>,
#[serde(rename = "forward_to")]
pub forward_to: String,
pub forward_to: Option<String>,
#[serde(rename = "upstream_proxy")]
pub upstream_proxy: Option<String>,
}
#[derive(Debug, Clone)]
pub struct RedirectHttpsConfig {
pub enabled: bool,
pub code: u16,
pub port: u16,
}
#[derive(Debug, Deserialize)]
struct RedirectHttpsConfigObject {
enabled: bool,
#[serde(default = "default_redirect_code")]
code: u16,
#[serde(default = "default_redirect_port")]
port: u16,
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum RedirectHttpsConfigRaw {
Bool(bool),
Object(RedirectHttpsConfigObject),
}
impl Default for RedirectHttpsConfig {
fn default() -> Self {
Self {
enabled: false,
code: default_redirect_code(),
port: default_redirect_port(),
}
}
}
impl<'de> Deserialize<'de> for RedirectHttpsConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let raw = RedirectHttpsConfigRaw::deserialize(deserializer)?;
Ok(match raw {
RedirectHttpsConfigRaw::Bool(enabled) => Self {
enabled,
code: default_redirect_code(),
port: default_redirect_port(),
},
RedirectHttpsConfigRaw::Object(obj) => Self {
enabled: obj.enabled,
code: obj.code,
port: obj.port,
},
})
}
}
fn default_redirect_code() -> u16 {
308
}
fn default_redirect_port() -> u16 {
443
}
#[derive(Debug, Deserialize, Clone)]
pub struct RealIpConfig {
#[serde(default, rename = "from")]
@@ -103,6 +166,8 @@ pub struct BindEntry {
pub tls: Option<TlsConfig>,
#[serde(default)]
pub add_xff_header: bool,
#[serde(default)]
pub redirect_https: RedirectHttpsConfig,
pub real_ip: Option<RealIpConfig>,
}
@@ -190,8 +255,57 @@ impl Config {
}
pub fn validate(&self) -> anyhow::Result<()> {
let mut seen_service_names = HashSet::new();
for service in &self.services {
if !seen_service_names.insert(service.name.as_str()) {
anyhow::bail!(
"duplicate service name '{}' found. Service names must be unique.",
service.name
);
}
let needs_backend = match service.service_type.as_str() {
"tcp" => true,
"http" => service.binds.iter().any(|b| !b.redirect_https.enabled),
_ => true,
};
if needs_backend && service.forward_to.is_none() {
anyhow::bail!(
"Service '{}' requires 'forward_to'. For type '{}' this is required unless all HTTP binds are redirect-only.",
service.name,
service.service_type
);
}
for bind in &service.binds {
if bind.redirect_https.enabled {
if service.service_type != "http" {
anyhow::bail!(
"Service '{}' bind '{}' enables 'redirect_https', but this is only valid for type 'http'.",
service.name,
bind.addr
);
}
if bind.tls.is_some() {
anyhow::bail!(
"Service '{}' bind '{}' enables 'redirect_https' and 'tls' together. Redirect-to-HTTPS must be configured on non-TLS HTTP binds.",
service.name,
bind.addr
);
}
if !(300..=399).contains(&bind.redirect_https.code) {
anyhow::bail!(
"Service '{}' bind '{}' has invalid 'redirect_https.code' {}. Expected 3xx status code.",
service.name,
bind.addr,
bind.redirect_https.code
);
}
}
if let Some(real_ip) = &bind.real_ip {
// Rule 1: TCP services cannot use XFF as they don't parse HTTP headers
if service.service_type == "tcp" && real_ip.source == RealIpSource::Xff {
@@ -227,6 +341,13 @@ impl Config {
}
if let Some(upstream_proxy) = &service.upstream_proxy {
if service.forward_to.is_none() {
anyhow::bail!(
"Service '{}' sets 'upstream_proxy' but has no 'forward_to'.",
service.name
);
}
match upstream_proxy.as_str() {
"v1" | "v2" => {},
other => anyhow::bail!(
@@ -277,10 +398,175 @@ services:
assert_eq!(config.services[0].name, "ssh-prod");
assert_eq!(config.services[0].binds[0].addr, "0.0.0.0:22222");
assert_eq!(config.services[0].binds[0].proxy, Some("v2".to_string()));
assert_eq!(config.services[0].forward_to, "127.0.0.1:22");
assert_eq!(config.services[0].forward_to, Some("127.0.0.1:22".to_string()));
assert_eq!(config.services[0].upstream_proxy, None);
}
#[test]
fn test_redirect_https_bool_and_object() {
#[derive(Deserialize)]
struct TestBind {
#[serde(default)]
redirect_https: RedirectHttpsConfig,
}
let yaml_bool = "redirect_https: true";
let bind_bool: TestBind = serde_yaml::from_str(yaml_bool).unwrap();
assert!(bind_bool.redirect_https.enabled);
assert_eq!(bind_bool.redirect_https.code, 308);
assert_eq!(bind_bool.redirect_https.port, 443);
let yaml_obj = "redirect_https:\n enabled: true\n code: 301\n port: 8443\n";
let bind_obj: TestBind = serde_yaml::from_str(yaml_obj).unwrap();
assert!(bind_obj.redirect_https.enabled);
assert_eq!(bind_obj.redirect_https.code, 301);
assert_eq!(bind_obj.redirect_https.port, 8443);
}
#[tokio::test]
async fn test_http_redirect_only_can_omit_forward_to() {
let config_str = r#"
database:
type: clickhouse
dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db"
services:
- name: "redirect-only"
type: "http"
binds:
- addr: "0.0.0.0:80"
redirect_https: true
"#;
let mut file = tempfile::NamedTempFile::new().unwrap();
write!(file, "{}", config_str).unwrap();
let path = file.path().to_path_buf();
let config = Config::load(&path).await.expect("Failed to load config");
assert_eq!(config.services[0].forward_to, None);
}
#[tokio::test]
async fn test_http_non_redirect_bind_requires_forward_to() {
let config_str = r#"
database:
type: clickhouse
dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db"
services:
- name: "http-no-backend"
type: "http"
binds:
- addr: "0.0.0.0:8080"
"#;
let mut file = tempfile::NamedTempFile::new().unwrap();
write!(file, "{}", config_str).unwrap();
let path = file.path().to_path_buf();
let err = Config::load(&path).await.unwrap_err();
assert!(err.to_string().contains("requires 'forward_to'"));
}
#[tokio::test]
async fn test_redirect_https_rejects_tls_same_bind() {
let config_str = r#"
database:
type: clickhouse
dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db"
services:
- name: "bad-redirect-tls"
type: "http"
binds:
- addr: "0.0.0.0:443"
tls:
cert: "/tmp/cert.pem"
key: "/tmp/key.pem"
redirect_https: true
forward_to: "127.0.0.1:8080"
"#;
let mut file = tempfile::NamedTempFile::new().unwrap();
write!(file, "{}", config_str).unwrap();
let path = file.path().to_path_buf();
let err = Config::load(&path).await.unwrap_err();
assert!(err.to_string().contains("'redirect_https' and 'tls' together"));
}
#[tokio::test]
async fn test_redirect_https_requires_http_service() {
let config_str = r#"
database:
type: clickhouse
dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db"
services:
- name: "bad-redirect-tcp"
type: "tcp"
binds:
- addr: "0.0.0.0:2222"
redirect_https: true
forward_to: "127.0.0.1:22"
"#;
let mut file = tempfile::NamedTempFile::new().unwrap();
write!(file, "{}", config_str).unwrap();
let path = file.path().to_path_buf();
let err = Config::load(&path).await.unwrap_err();
assert!(err.to_string().contains("only valid for type 'http'"));
}
#[tokio::test]
async fn test_redirect_https_code_must_be_3xx() {
let config_str = r#"
database:
type: clickhouse
dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db"
services:
- name: "bad-redirect-code"
type: "http"
binds:
- addr: "0.0.0.0:80"
redirect_https:
enabled: true
code: 200
"#;
let mut file = tempfile::NamedTempFile::new().unwrap();
write!(file, "{}", config_str).unwrap();
let path = file.path().to_path_buf();
let err = Config::load(&path).await.unwrap_err();
assert!(err.to_string().contains("Expected 3xx status code"));
}
#[tokio::test]
async fn test_duplicate_service_names_not_allowed() {
let config_str = r#"
database:
type: clickhouse
dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db"
services:
- name: "dup"
type: "tcp"
forward_to: "127.0.0.1:22"
binds:
- addr: "0.0.0.0:2201"
- name: "dup"
type: "tcp"
forward_to: "127.0.0.1:22"
binds:
- addr: "0.0.0.0:2202"
"#;
let mut file = tempfile::NamedTempFile::new().unwrap();
write!(file, "{}", config_str).unwrap();
let path = file.path().to_path_buf();
let err = Config::load(&path).await.unwrap_err();
assert!(err.to_string().contains("duplicate service name"));
}
#[test]
fn test_mode_deserialization() {
#[derive(Deserialize)]

View File

@@ -1,7 +1,8 @@
use crate::config::{RealIpSource, ServiceConfig};
use crate::config::{RealIpSource, RedirectHttpsConfig, ServiceConfig};
use crate::db::clickhouse::{ClickHouseLogger, HttpLog, HttpMethod};
use async_trait::async_trait;
use pingora::prelude::*;
use pingora::http::ResponseHeader;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Instant;
@@ -12,6 +13,7 @@ pub struct TrauditProxy {
pub listen_addr: String,
pub real_ip: Option<crate::config::RealIpConfig>,
pub add_xff_header: bool,
pub redirect_https: RedirectHttpsConfig,
}
pub struct HttpContext {
@@ -237,6 +239,20 @@ impl ProxyHttp for TrauditProxy {
.unwrap_or("")
.to_string();
if self.redirect_https.enabled {
let location = build_https_redirect_location(
session.req_header(),
self.redirect_https.port,
)
.ok_or_else(|| Error::explain(InternalError, "failed to build https redirect location"))?;
let mut header = ResponseHeader::build(self.redirect_https.code, Some(0))?;
header.insert_header("Location", &location)?;
session.set_keepalive(None);
session.write_response_header(Box::new(header), true).await?;
return Ok(true);
}
Ok(false) // false to continue processing
}
@@ -245,7 +261,12 @@ impl ProxyHttp for TrauditProxy {
_session: &mut Session,
_ctx: &mut Self::CTX,
) -> Result<Box<HttpPeer>> {
let addr = &self.service_config.forward_to;
let addr = self.service_config.forward_to.as_deref().ok_or_else(|| {
Error::explain(
InternalError,
format!("service '{}' missing forward_to", self.service_config.name),
)
})?;
let peer = Box::new(HttpPeer::new(addr, false, "".to_string()));
Ok(peer)
}
@@ -307,3 +328,85 @@ impl ProxyHttp for TrauditProxy {
});
}
}
fn build_https_redirect_location(req: &pingora::http::RequestHeader, target_port: u16) -> Option<String> {
let host_raw = req
.uri
.host()
.map(ToString::to_string)
.or_else(|| {
req
.headers
.get("host")
.and_then(|v| v.to_str().ok())
.map(ToString::to_string)
})?;
let authority = host_raw
.parse::<http::uri::Authority>()
.ok()
.map(|a| a.host().to_string())
.unwrap_or_else(|| host_raw.clone());
let needs_brackets = authority.contains(':') && !authority.starts_with('[');
let host = if needs_brackets {
format!("[{}]", authority)
} else {
authority
};
let host_port = if target_port == 443 {
host
} else {
format!("{}:{}", host, target_port)
};
let path_q = req
.uri
.path_and_query()
.map(|v| v.as_str())
.unwrap_or("/");
Some(format!("https://{}{}", host_port, path_q))
}
#[cfg(test)]
mod tests {
use super::build_https_redirect_location;
use pingora::http::RequestHeader;
#[test]
fn test_redirect_location_from_host_header_default_port() {
let mut req = RequestHeader::build("GET", b"/a/b?x=1", None).unwrap();
req.insert_header("Host", "example.com").unwrap();
let location = build_https_redirect_location(&req, 443).unwrap();
assert_eq!(location, "https://example.com/a/b?x=1");
}
#[test]
fn test_redirect_location_overrides_host_port() {
let mut req = RequestHeader::build("GET", b"/", None).unwrap();
req.insert_header("Host", "example.com:8080").unwrap();
let location = build_https_redirect_location(&req, 8443).unwrap();
assert_eq!(location, "https://example.com:8443/");
}
#[test]
fn test_redirect_location_ipv6_host() {
let mut req = RequestHeader::build("GET", b"/hello", None).unwrap();
req.insert_header("Host", "[2001:db8::1]:8080").unwrap();
let location = build_https_redirect_location(&req, 443).unwrap();
assert_eq!(location, "https://[2001:db8::1]/hello");
}
#[test]
fn test_redirect_location_missing_host() {
let req = RequestHeader::build("GET", b"/", None).unwrap();
let location = build_https_redirect_location(&req, 443);
assert!(location.is_none());
}
}

View File

@@ -126,7 +126,13 @@ pub async fn handle_connection(
};
// 3. Connect Upstream
let mut upstream = UpstreamStream::connect(&service.forward_to).await?;
let forward_to = service.forward_to.as_deref().ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("service '{}' missing forward_to", service.name),
)
})?;
let mut upstream = UpstreamStream::connect(forward_to).await?;
// [NEW] Send Proxy Protocol Header if configured
if let Some(upstream_ver) = &service.upstream_proxy {

View File

@@ -11,6 +11,34 @@ use tokio::net::{TcpListener, UnixListener, UnixStream};
use tokio_openssl::SslStream;
use tracing::{error, info, warn};
fn normalize_ipv4_mapped_addr(addr: std::net::SocketAddr) -> std::net::SocketAddr {
match addr {
std::net::SocketAddr::V6(v6) => {
if let Some(v4) = v6.ip().to_ipv4_mapped() {
std::net::SocketAddr::new(std::net::IpAddr::V4(v4), v6.port())
} else {
std::net::SocketAddr::V6(v6)
}
}
other => other,
}
}
fn parse_tcp_bind_target(addr_str: &str) -> anyhow::Result<(std::net::SocketAddr, Option<bool>)> {
let (normalized_addr, force_v6_only) = if let Some(port) = addr_str.strip_prefix(":::") {
(format!("[::]:{}", port), Some(true))
} else if let Some(port) = addr_str.strip_prefix(":") {
(format!("[::]:{}", port), Some(false))
} else if let Some(port) = addr_str.strip_prefix("*:") {
(format!("[::]:{}", port), Some(false))
} else {
(addr_str.to_string(), None)
};
let addr: std::net::SocketAddr = normalized_addr.parse()?;
Ok((addr, force_v6_only))
}
pub enum UnifiedListener {
Tcp(TcpListener),
Unix(UnixListener, PathBuf), // PathBuf for cleanup on Drop
@@ -30,7 +58,7 @@ impl UnifiedListener {
match self {
UnifiedListener::Tcp(l) => {
let (stream, addr) = l.accept().await?;
Ok((InboundStream::Tcp(stream), addr))
Ok((InboundStream::Tcp(stream), normalize_ipv4_mapped_addr(addr)))
}
UnifiedListener::Unix(l, _) => {
let (stream, _addr) = l.accept().await?;
@@ -136,20 +164,11 @@ pub async fn bind_listener(
} else {
// TCP with SO_REUSEPORT
use nix::sys::socket::{setsockopt, sockopt};
use std::net::SocketAddr;
// AsRawFd removed
let normalized_addr = if addr_str.starts_with(":::") {
format!("[::]:{}", &addr_str[3..])
} else {
addr_str.to_string()
};
let addr: SocketAddr = normalized_addr
.parse()
.map_err(|e: std::net::AddrParseError| {
let (addr, force_v6_only) = parse_tcp_bind_target(addr_str).map_err(|e| {
error!("[{}] invalid address {}: {}", service_name, addr_str, e);
anyhow::anyhow!(e)
e
})?;
let domain = if addr.is_ipv4() {
@@ -173,6 +192,18 @@ pub async fn bind_listener(
}
}
if addr.is_ipv6() {
if let Some(v6_only) = force_v6_only {
socket.set_only_v6(v6_only).map_err(|e| {
error!(
"[{}] failed to configure IPV6_V6ONLY={} for {}: {}",
service_name, v6_only, addr_str, e
);
e
})?;
}
}
socket.set_nonblocking(true)?;
// Convert std::net::SocketAddr to socket2::SockAddr
@@ -234,6 +265,44 @@ pub async fn bind_listener(
Ok(listener)
}
#[cfg(test)]
mod tests {
use super::{normalize_ipv4_mapped_addr, parse_tcp_bind_target};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
#[test]
fn test_parse_tcp_bind_target_rules() {
let (a, v6_only) = parse_tcp_bind_target("0.0.0.0:80").unwrap();
assert_eq!(a, "0.0.0.0:80".parse::<SocketAddr>().unwrap());
assert_eq!(v6_only, None);
let (a, v6_only) = parse_tcp_bind_target(":::80").unwrap();
assert_eq!(a, "[::]:80".parse::<SocketAddr>().unwrap());
assert_eq!(v6_only, Some(true));
let (a, v6_only) = parse_tcp_bind_target(":80").unwrap();
assert_eq!(a, "[::]:80".parse::<SocketAddr>().unwrap());
assert_eq!(v6_only, Some(false));
let (a, v6_only) = parse_tcp_bind_target("*:80").unwrap();
assert_eq!(a, "[::]:80".parse::<SocketAddr>().unwrap());
assert_eq!(v6_only, Some(false));
}
#[test]
fn test_normalize_ipv4_mapped_addr() {
let mapped = SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xFFFF, 0xC000, 0x0280)),
8080,
);
let normalized = normalize_ipv4_mapped_addr(mapped);
assert_eq!(normalized, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 128)), 8080));
let normal_v6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080);
assert_eq!(normalize_ipv4_mapped_addr(normal_v6), normal_v6);
}
}
pub async fn serve_listener_loop<F, Fut>(
listener: UnifiedListener,
service: ServiceConfig,

View File

@@ -47,6 +47,7 @@ pub async fn run(
let proxy_proto_config = bind.proxy.clone();
let mode = bind.mode;
let real_ip_config = bind.real_ip.clone();
let redirect_https = bind.redirect_https.clone();
let is_tcp_service = service.service_type == "tcp";
@@ -146,14 +147,16 @@ pub async fn run(
if is_tcp_service {
// --- TCP Handler (with startup check) ---
if let Err(e) = UpstreamStream::connect(&service_config.forward_to).await {
if let Some(forward_to) = service_config.forward_to.as_deref() {
if let Err(e) = UpstreamStream::connect(forward_to).await {
tracing::warn!(
"[{}] -> '{}': startup check failed: {}",
service_config.name,
service_config.forward_to,
forward_to,
e
);
}
}
let db = db.clone();
let _proxy_cfg = proxy_proto_config.clone();
@@ -198,6 +201,7 @@ pub async fn run(
listen_addr: bind_addr.clone(),
real_ip: real_ip_config.clone(),
add_xff_header: bind.add_xff_header,
redirect_https,
};
let mut service_obj = http_proxy_service(&conf, inner_proxy);
let app = unsafe {

View File

@@ -18,7 +18,8 @@ use tokio::sync::OnceCell;
use tokio::task::JoinHandle;
use traudit::config::{
BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, ServiceConfig, TlsConfig,
BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, RedirectHttpsConfig,
ServiceConfig, TlsConfig,
};
static INIT: Once = Once::new();
@@ -354,7 +355,7 @@ pub async fn prepare_env(
services: vec![ServiceConfig {
name: "test-svc".to_string(),
service_type: service_type.to_string(),
forward_to: upstream_addr.to_string(),
forward_to: Some(upstream_addr.to_string()),
upstream_proxy: None,
binds: vec![BindEntry {
addr: bind_addr.clone(),
@@ -363,6 +364,7 @@ pub async fn prepare_env(
tls: tls_config,
real_ip,
add_xff_header: add_xff,
redirect_https: RedirectHttpsConfig::default(),
}],
}],
};

View File

@@ -3,7 +3,8 @@ use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use traudit::config::{
BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, ServiceConfig,
BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, RedirectHttpsConfig,
ServiceConfig,
};
mod common;
@@ -221,8 +222,9 @@ async fn prepare_chain_env() -> ChainTestResources {
tls: None,
add_xff_header: false,
real_ip: None,
redirect_https: RedirectHttpsConfig::default(),
}],
forward_to: addr2.clone(),
forward_to: Some(addr2.clone()),
upstream_proxy: Some("v1".to_string()),
},
// E2: (Proxy V1 In, Upstream Proxy V2)
@@ -236,8 +238,9 @@ async fn prepare_chain_env() -> ChainTestResources {
tls: None,
add_xff_header: false,
real_ip: real_ip_pp.clone(),
redirect_https: RedirectHttpsConfig::default(),
}],
forward_to: addr3.clone(),
forward_to: Some(addr3.clone()),
upstream_proxy: Some("v2".to_string()),
},
// E3: (Proxy V2 In, Upstream Proxy V1)
@@ -251,8 +254,9 @@ async fn prepare_chain_env() -> ChainTestResources {
tls: None,
add_xff_header: false,
real_ip: real_ip_pp.clone(),
redirect_https: RedirectHttpsConfig::default(),
}],
forward_to: addr4.clone(),
forward_to: Some(addr4.clone()),
upstream_proxy: Some("v1".to_string()),
},
// E4: (Proxy V1 In, No Upstream Proxy -> Mock Server)
@@ -266,8 +270,9 @@ async fn prepare_chain_env() -> ChainTestResources {
tls: None,
add_xff_header: false,
real_ip: real_ip_pp.clone(),
redirect_https: RedirectHttpsConfig::default(),
}],
forward_to: e4_upstream_addr.to_string(),
forward_to: Some(e4_upstream_addr.to_string()),
upstream_proxy: None,
},
];