From f883753b96de6e7c3f9d785103759d9fc5cd34f6 Mon Sep 17 00:00:00 2001 From: awfufu Date: Fri, 20 Feb 2026 20:50:00 +0800 Subject: [PATCH] feat: add configurable HTTP-to-HTTPS redirects with redirect-only HTTP services --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/config/mod.rs | 253 ++++++++++++++++++++++++++++++++++++- src/core/pingora_proxy.rs | 107 +++++++++++++++- src/core/server/handler.rs | 8 +- src/core/server/mod.rs | 18 ++- tests/common/mod.rs | 6 +- tests/proxy_test.rs | 15 ++- 8 files changed, 390 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 39bc20b..ecbb188 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4059,7 +4059,7 @@ dependencies = [ [[package]] name = "traudit" -version = "0.0.8" +version = "0.0.9" dependencies = [ "anyhow", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index e143481..317ba4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "traudit" -version = "0.0.8" +version = "0.0.9" edition = "2021" authors = ["awfufu"] description = "A reverse proxy that streams audit records directly to databases." diff --git a/src/config/mod.rs b/src/config/mod.rs index bddb08a..fcac540 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -39,11 +39,73 @@ pub struct ServiceConfig { pub service_type: String, pub binds: Vec, #[serde(rename = "forward_to")] - pub forward_to: String, + pub forward_to: Option, #[serde(rename = "upstream_proxy")] pub upstream_proxy: Option, } +#[derive(Debug, Clone)] +pub struct RedirectHttpsConfig { + pub enabled: bool, + pub code: u16, + pub port: u16, +} + +#[derive(Debug, Deserialize)] +struct RedirectHttpsConfigObject { + enabled: bool, + #[serde(default = "default_redirect_code")] + code: u16, + #[serde(default = "default_redirect_port")] + port: u16, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum RedirectHttpsConfigRaw { + Bool(bool), + Object(RedirectHttpsConfigObject), +} + +impl Default for RedirectHttpsConfig { + fn default() -> Self { + Self { + enabled: false, + code: default_redirect_code(), + port: default_redirect_port(), + } + } +} + +impl<'de> Deserialize<'de> for RedirectHttpsConfig { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let raw = RedirectHttpsConfigRaw::deserialize(deserializer)?; + Ok(match raw { + RedirectHttpsConfigRaw::Bool(enabled) => Self { + enabled, + code: default_redirect_code(), + port: default_redirect_port(), + }, + RedirectHttpsConfigRaw::Object(obj) => Self { + enabled: obj.enabled, + code: obj.code, + port: obj.port, + }, + }) + } +} + +fn default_redirect_code() -> u16 { + 308 +} + +fn default_redirect_port() -> u16 { + 443 +} + #[derive(Debug, Deserialize, Clone)] pub struct RealIpConfig { #[serde(default, rename = "from")] @@ -103,6 +165,8 @@ pub struct BindEntry { pub tls: Option, #[serde(default)] pub add_xff_header: bool, + #[serde(default)] + pub redirect_https: RedirectHttpsConfig, pub real_ip: Option, } @@ -191,7 +255,48 @@ impl Config { pub fn validate(&self) -> anyhow::Result<()> { for service in &self.services { + let needs_backend = match service.service_type.as_str() { + "tcp" => true, + "http" => service.binds.iter().any(|b| !b.redirect_https.enabled), + _ => true, + }; + + if needs_backend && service.forward_to.is_none() { + anyhow::bail!( + "Service '{}' requires 'forward_to'. For type '{}' this is required unless all HTTP binds are redirect-only.", + service.name, + service.service_type + ); + } + for bind in &service.binds { + if bind.redirect_https.enabled { + if service.service_type != "http" { + anyhow::bail!( + "Service '{}' bind '{}' enables 'redirect_https', but this is only valid for type 'http'.", + service.name, + bind.addr + ); + } + + if bind.tls.is_some() { + anyhow::bail!( + "Service '{}' bind '{}' enables 'redirect_https' and 'tls' together. Redirect-to-HTTPS must be configured on non-TLS HTTP binds.", + service.name, + bind.addr + ); + } + + if !(300..=399).contains(&bind.redirect_https.code) { + anyhow::bail!( + "Service '{}' bind '{}' has invalid 'redirect_https.code' {}. Expected 3xx status code.", + service.name, + bind.addr, + bind.redirect_https.code + ); + } + } + if let Some(real_ip) = &bind.real_ip { // Rule 1: TCP services cannot use XFF as they don't parse HTTP headers if service.service_type == "tcp" && real_ip.source == RealIpSource::Xff { @@ -227,6 +332,13 @@ impl Config { } if let Some(upstream_proxy) = &service.upstream_proxy { + if service.forward_to.is_none() { + anyhow::bail!( + "Service '{}' sets 'upstream_proxy' but has no 'forward_to'.", + service.name + ); + } + match upstream_proxy.as_str() { "v1" | "v2" => {}, other => anyhow::bail!( @@ -277,10 +389,147 @@ services: assert_eq!(config.services[0].name, "ssh-prod"); assert_eq!(config.services[0].binds[0].addr, "0.0.0.0:22222"); assert_eq!(config.services[0].binds[0].proxy, Some("v2".to_string())); - assert_eq!(config.services[0].forward_to, "127.0.0.1:22"); + assert_eq!(config.services[0].forward_to, Some("127.0.0.1:22".to_string())); assert_eq!(config.services[0].upstream_proxy, None); } + #[test] + fn test_redirect_https_bool_and_object() { + #[derive(Deserialize)] + struct TestBind { + #[serde(default)] + redirect_https: RedirectHttpsConfig, + } + + let yaml_bool = "redirect_https: true"; + let bind_bool: TestBind = serde_yaml::from_str(yaml_bool).unwrap(); + assert!(bind_bool.redirect_https.enabled); + assert_eq!(bind_bool.redirect_https.code, 308); + assert_eq!(bind_bool.redirect_https.port, 443); + + let yaml_obj = "redirect_https:\n enabled: true\n code: 301\n port: 8443\n"; + let bind_obj: TestBind = serde_yaml::from_str(yaml_obj).unwrap(); + assert!(bind_obj.redirect_https.enabled); + assert_eq!(bind_obj.redirect_https.code, 301); + assert_eq!(bind_obj.redirect_https.port, 8443); + } + + #[tokio::test] + async fn test_http_redirect_only_can_omit_forward_to() { + let config_str = r#" +database: + type: clickhouse + dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db" + +services: + - name: "redirect-only" + type: "http" + binds: + - addr: "0.0.0.0:80" + redirect_https: true +"#; + let mut file = tempfile::NamedTempFile::new().unwrap(); + write!(file, "{}", config_str).unwrap(); + let path = file.path().to_path_buf(); + + let config = Config::load(&path).await.expect("Failed to load config"); + assert_eq!(config.services[0].forward_to, None); + } + + #[tokio::test] + async fn test_http_non_redirect_bind_requires_forward_to() { + let config_str = r#" +database: + type: clickhouse + dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db" + +services: + - name: "http-no-backend" + type: "http" + binds: + - addr: "0.0.0.0:8080" +"#; + let mut file = tempfile::NamedTempFile::new().unwrap(); + write!(file, "{}", config_str).unwrap(); + let path = file.path().to_path_buf(); + + let err = Config::load(&path).await.unwrap_err(); + assert!(err.to_string().contains("requires 'forward_to'")); + } + + #[tokio::test] + async fn test_redirect_https_rejects_tls_same_bind() { + let config_str = r#" +database: + type: clickhouse + dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db" + +services: + - name: "bad-redirect-tls" + type: "http" + binds: + - addr: "0.0.0.0:443" + tls: + cert: "/tmp/cert.pem" + key: "/tmp/key.pem" + redirect_https: true + forward_to: "127.0.0.1:8080" +"#; + let mut file = tempfile::NamedTempFile::new().unwrap(); + write!(file, "{}", config_str).unwrap(); + let path = file.path().to_path_buf(); + + let err = Config::load(&path).await.unwrap_err(); + assert!(err.to_string().contains("'redirect_https' and 'tls' together")); + } + + #[tokio::test] + async fn test_redirect_https_requires_http_service() { + let config_str = r#" +database: + type: clickhouse + dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db" + +services: + - name: "bad-redirect-tcp" + type: "tcp" + binds: + - addr: "0.0.0.0:2222" + redirect_https: true + forward_to: "127.0.0.1:22" +"#; + let mut file = tempfile::NamedTempFile::new().unwrap(); + write!(file, "{}", config_str).unwrap(); + let path = file.path().to_path_buf(); + + let err = Config::load(&path).await.unwrap_err(); + assert!(err.to_string().contains("only valid for type 'http'")); + } + + #[tokio::test] + async fn test_redirect_https_code_must_be_3xx() { + let config_str = r#" +database: + type: clickhouse + dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db" + +services: + - name: "bad-redirect-code" + type: "http" + binds: + - addr: "0.0.0.0:80" + redirect_https: + enabled: true + code: 200 +"#; + let mut file = tempfile::NamedTempFile::new().unwrap(); + write!(file, "{}", config_str).unwrap(); + let path = file.path().to_path_buf(); + + let err = Config::load(&path).await.unwrap_err(); + assert!(err.to_string().contains("Expected 3xx status code")); + } + #[test] fn test_mode_deserialization() { #[derive(Deserialize)] diff --git a/src/core/pingora_proxy.rs b/src/core/pingora_proxy.rs index dab9e6e..969ff14 100644 --- a/src/core/pingora_proxy.rs +++ b/src/core/pingora_proxy.rs @@ -1,7 +1,8 @@ -use crate::config::{RealIpSource, ServiceConfig}; +use crate::config::{RealIpSource, RedirectHttpsConfig, ServiceConfig}; use crate::db::clickhouse::{ClickHouseLogger, HttpLog, HttpMethod}; use async_trait::async_trait; use pingora::prelude::*; +use pingora::http::ResponseHeader; use std::net::IpAddr; use std::sync::Arc; use std::time::Instant; @@ -12,6 +13,7 @@ pub struct TrauditProxy { pub listen_addr: String, pub real_ip: Option, pub add_xff_header: bool, + pub redirect_https: RedirectHttpsConfig, } pub struct HttpContext { @@ -237,6 +239,20 @@ impl ProxyHttp for TrauditProxy { .unwrap_or("") .to_string(); + if self.redirect_https.enabled { + let location = build_https_redirect_location( + session.req_header(), + self.redirect_https.port, + ) + .ok_or_else(|| Error::explain(InternalError, "failed to build https redirect location"))?; + + let mut header = ResponseHeader::build(self.redirect_https.code, Some(0))?; + header.insert_header("Location", &location)?; + session.set_keepalive(None); + session.write_response_header(Box::new(header), true).await?; + return Ok(true); + } + Ok(false) // false to continue processing } @@ -245,7 +261,12 @@ impl ProxyHttp for TrauditProxy { _session: &mut Session, _ctx: &mut Self::CTX, ) -> Result> { - let addr = &self.service_config.forward_to; + let addr = self.service_config.forward_to.as_deref().ok_or_else(|| { + Error::explain( + InternalError, + format!("service '{}' missing forward_to", self.service_config.name), + ) + })?; let peer = Box::new(HttpPeer::new(addr, false, "".to_string())); Ok(peer) } @@ -307,3 +328,85 @@ impl ProxyHttp for TrauditProxy { }); } } + +fn build_https_redirect_location(req: &pingora::http::RequestHeader, target_port: u16) -> Option { + let host_raw = req + .uri + .host() + .map(ToString::to_string) + .or_else(|| { + req + .headers + .get("host") + .and_then(|v| v.to_str().ok()) + .map(ToString::to_string) + })?; + + let authority = host_raw + .parse::() + .ok() + .map(|a| a.host().to_string()) + .unwrap_or_else(|| host_raw.clone()); + + let needs_brackets = authority.contains(':') && !authority.starts_with('['); + let host = if needs_brackets { + format!("[{}]", authority) + } else { + authority + }; + + let host_port = if target_port == 443 { + host + } else { + format!("{}:{}", host, target_port) + }; + + let path_q = req + .uri + .path_and_query() + .map(|v| v.as_str()) + .unwrap_or("/"); + + Some(format!("https://{}{}", host_port, path_q)) +} + +#[cfg(test)] +mod tests { + use super::build_https_redirect_location; + use pingora::http::RequestHeader; + + #[test] + fn test_redirect_location_from_host_header_default_port() { + let mut req = RequestHeader::build("GET", b"/a/b?x=1", None).unwrap(); + req.insert_header("Host", "example.com").unwrap(); + + let location = build_https_redirect_location(&req, 443).unwrap(); + assert_eq!(location, "https://example.com/a/b?x=1"); + } + + #[test] + fn test_redirect_location_overrides_host_port() { + let mut req = RequestHeader::build("GET", b"/", None).unwrap(); + req.insert_header("Host", "example.com:8080").unwrap(); + + let location = build_https_redirect_location(&req, 8443).unwrap(); + assert_eq!(location, "https://example.com:8443/"); + } + + #[test] + fn test_redirect_location_ipv6_host() { + let mut req = RequestHeader::build("GET", b"/hello", None).unwrap(); + req.insert_header("Host", "[2001:db8::1]:8080").unwrap(); + + let location = build_https_redirect_location(&req, 443).unwrap(); + assert_eq!(location, "https://[2001:db8::1]/hello"); + } + + #[test] + fn test_redirect_location_missing_host() { + let req = RequestHeader::build("GET", b"/", None).unwrap(); + + let location = build_https_redirect_location(&req, 443); + assert!(location.is_none()); + } +} diff --git a/src/core/server/handler.rs b/src/core/server/handler.rs index 199aa0b..2bbb4cd 100644 --- a/src/core/server/handler.rs +++ b/src/core/server/handler.rs @@ -126,7 +126,13 @@ pub async fn handle_connection( }; // 3. Connect Upstream - let mut upstream = UpstreamStream::connect(&service.forward_to).await?; + let forward_to = service.forward_to.as_deref().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("service '{}' missing forward_to", service.name), + ) + })?; + let mut upstream = UpstreamStream::connect(forward_to).await?; // [NEW] Send Proxy Protocol Header if configured if let Some(upstream_ver) = &service.upstream_proxy { diff --git a/src/core/server/mod.rs b/src/core/server/mod.rs index 1c7500f..1a10178 100644 --- a/src/core/server/mod.rs +++ b/src/core/server/mod.rs @@ -47,6 +47,7 @@ pub async fn run( let proxy_proto_config = bind.proxy.clone(); let mode = bind.mode; let real_ip_config = bind.real_ip.clone(); + let redirect_https = bind.redirect_https.clone(); let is_tcp_service = service.service_type == "tcp"; @@ -146,13 +147,15 @@ pub async fn run( if is_tcp_service { // --- TCP Handler (with startup check) --- - if let Err(e) = UpstreamStream::connect(&service_config.forward_to).await { - tracing::warn!( - "[{}] -> '{}': startup check failed: {}", - service_config.name, - service_config.forward_to, - e - ); + if let Some(forward_to) = service_config.forward_to.as_deref() { + if let Err(e) = UpstreamStream::connect(forward_to).await { + tracing::warn!( + "[{}] -> '{}': startup check failed: {}", + service_config.name, + forward_to, + e + ); + } } let db = db.clone(); @@ -198,6 +201,7 @@ pub async fn run( listen_addr: bind_addr.clone(), real_ip: real_ip_config.clone(), add_xff_header: bind.add_xff_header, + redirect_https, }; let mut service_obj = http_proxy_service(&conf, inner_proxy); let app = unsafe { diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 7a16c9a..674b28e 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -18,7 +18,8 @@ use tokio::sync::OnceCell; use tokio::task::JoinHandle; use traudit::config::{ - BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, ServiceConfig, TlsConfig, + BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, RedirectHttpsConfig, + ServiceConfig, TlsConfig, }; static INIT: Once = Once::new(); @@ -354,7 +355,7 @@ pub async fn prepare_env( services: vec![ServiceConfig { name: "test-svc".to_string(), service_type: service_type.to_string(), - forward_to: upstream_addr.to_string(), + forward_to: Some(upstream_addr.to_string()), upstream_proxy: None, binds: vec![BindEntry { addr: bind_addr.clone(), @@ -363,6 +364,7 @@ pub async fn prepare_env( tls: tls_config, real_ip, add_xff_header: add_xff, + redirect_https: RedirectHttpsConfig::default(), }], }], }; diff --git a/tests/proxy_test.rs b/tests/proxy_test.rs index d165435..2c79669 100644 --- a/tests/proxy_test.rs +++ b/tests/proxy_test.rs @@ -3,7 +3,8 @@ use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use traudit::config::{ - BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, ServiceConfig, + BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, RedirectHttpsConfig, + ServiceConfig, }; mod common; @@ -221,8 +222,9 @@ async fn prepare_chain_env() -> ChainTestResources { tls: None, add_xff_header: false, real_ip: None, + redirect_https: RedirectHttpsConfig::default(), }], - forward_to: addr2.clone(), + forward_to: Some(addr2.clone()), upstream_proxy: Some("v1".to_string()), }, // E2: (Proxy V1 In, Upstream Proxy V2) @@ -236,8 +238,9 @@ async fn prepare_chain_env() -> ChainTestResources { tls: None, add_xff_header: false, real_ip: real_ip_pp.clone(), + redirect_https: RedirectHttpsConfig::default(), }], - forward_to: addr3.clone(), + forward_to: Some(addr3.clone()), upstream_proxy: Some("v2".to_string()), }, // E3: (Proxy V2 In, Upstream Proxy V1) @@ -251,8 +254,9 @@ async fn prepare_chain_env() -> ChainTestResources { tls: None, add_xff_header: false, real_ip: real_ip_pp.clone(), + redirect_https: RedirectHttpsConfig::default(), }], - forward_to: addr4.clone(), + forward_to: Some(addr4.clone()), upstream_proxy: Some("v1".to_string()), }, // E4: (Proxy V1 In, No Upstream Proxy -> Mock Server) @@ -266,8 +270,9 @@ async fn prepare_chain_env() -> ChainTestResources { tls: None, add_xff_header: false, real_ip: real_ip_pp.clone(), + redirect_https: RedirectHttpsConfig::default(), }], - forward_to: e4_upstream_addr.to_string(), + forward_to: Some(e4_upstream_addr.to_string()), upstream_proxy: None, }, ];