test: refactor integration tests and add proxy chain verification

This commit is contained in:
2026-01-21 15:25:58 +08:00
parent 37c948db7b
commit 779e836189
6 changed files with 945 additions and 750 deletions

2
Cargo.lock generated
View File

@@ -3461,7 +3461,7 @@ dependencies = [
[[package]]
name = "traudit"
version = "0.0.4"
version = "0.0.5"
dependencies = [
"anyhow",
"async-trait",

View File

@@ -1,6 +1,6 @@
[package]
name = "traudit"
version = "0.0.4"
version = "0.0.5"
edition = "2021"
authors = ["awfufu"]
description = "A reverse proxy that streams audit records directly to databases."

601
tests/common/mod.rs Normal file
View File

@@ -0,0 +1,601 @@
#![allow(dead_code)]
use std::io::Write;
use std::net::SocketAddr;
use std::os::unix::fs::PermissionsExt;
use std::path::PathBuf;
use std::sync::{Mutex, Once};
use std::time::Duration;
use clickhouse::Client;
use ctor::dtor;
use rcgen::generate_simple_self_signed;
use serde::Deserialize;
use testcontainers::runners::AsyncRunner;
use testcontainers::{GenericImage, ImageExt};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UnixStream};
use tokio::sync::OnceCell;
use tokio::task::JoinHandle;
use traudit::config::{
BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, ServiceConfig, TlsConfig,
};
static INIT: Once = Once::new();
// Shared Container Singleton
pub struct SharedDb {
pub port: u16,
}
static SHARED_DB: OnceCell<SharedDb> = OnceCell::const_new();
// Cleanup Info used by dtor
struct CleanupInfo {
container_id: Option<String>,
temp_dir: Option<PathBuf>,
}
static CLEANUP_INFO: Mutex<CleanupInfo> = Mutex::new(CleanupInfo {
container_id: None,
temp_dir: None,
});
pub async fn get_shared_db_port() -> u16 {
let db = SHARED_DB
.get_or_init(|| async {
init_env();
let image = GenericImage::new("clickhouse/clickhouse-server", "latest")
.with_env_var("CLICKHOUSE_DB", "traudit")
.with_env_var("CLICKHOUSE_USER", "traudit")
.with_env_var("CLICKHOUSE_PASSWORD", "traudit")
.with_env_var("CLICKHOUSE_DEFAULT_ACCESS_MANAGEMENT", "1");
let container = image.start().await.expect("Failed to start container");
let port = container
.get_host_port_ipv4(8123)
.await
.expect("Failed to get port");
// Save ID for cleanup
if let Ok(mut info) = CLEANUP_INFO.lock() {
info.container_id = Some(container.id().to_string());
}
Box::leak(Box::new(container));
// Async wait
wait_for_clickhouse(port).await;
SharedDb { port }
})
.await;
db.port
}
#[dtor]
fn cleanup() {
if let Ok(info) = CLEANUP_INFO.lock() {
// Cleanup Container
if let Some(id) = &info.container_id {
let _ = std::process::Command::new("docker")
.args(["rm", "-f", id])
.output();
let _ = std::process::Command::new("podman")
.args(["rm", "-f", id])
.output();
}
// Cleanup Temp Dir (shim)
if let Some(path) = &info.temp_dir {
let _ = std::fs::remove_dir_all(path);
}
}
}
pub fn init_env() {
INIT.call_once(|| {
// Initialize tracing
tracing_subscriber::fmt()
.with_env_filter("info")
.with_test_writer()
.try_init()
.ok();
// Install Rustls Crypto Provider (Ring)
let _ = rustls::crypto::ring::default_provider().install_default();
// Detect Podman socket if Docker socket is missing or empty
let docker_host = std::env::var("DOCKER_HOST").unwrap_or_default();
let docker_sock_exists = std::path::Path::new("/var/run/docker.sock").exists();
let need_detection = docker_host.trim().is_empty() && !docker_sock_exists;
eprintln!(
"DEBUG: DOCKER_HOST='{}', docker_sock_exists={}, need_detection={}",
docker_host, docker_sock_exists, need_detection
);
if need_detection {
eprintln!("DEBUG: Attempting detection...");
let mut found = false;
if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock");
if podman_sock.exists() {
std::env::set_var("DOCKER_HOST", format!("unix://{}", podman_sock.display()));
found = true;
}
}
if !found {
// Fallback to hardcoded path
let fallback = std::path::Path::new("/run/user/1000/podman/podman.sock");
if fallback.exists() {
eprintln!("DEBUG: Found fallback podman sock: {}", fallback.display());
std::env::set_var("DOCKER_HOST", format!("unix://{}", fallback.display()));
} else {
eprintln!("DEBUG: No podman socket found in fallback locations.");
}
}
} else {
eprintln!("DEBUG: Skipping detection. Using existing DOCKER_HOST or docker.sock");
}
// Create docker shim for podman if docker is missing
if std::process::Command::new("docker")
.arg("-v")
.output()
.is_err()
{
let temp_dir = tempfile::tempdir().expect("failed to create temp dir");
let temp_path = temp_dir.path().to_owned();
let docker_shim = temp_path.join("docker");
let mut file = std::fs::File::create(&docker_shim).expect("failed to create docker shim");
file
.write_all(b"#!/bin/sh\nexec podman \"$@\"")
.expect("failed to write shim");
let mut perms = file.metadata().unwrap().permissions();
perms.set_mode(0o755);
file.set_permissions(perms).unwrap();
let path = std::env::var("PATH").unwrap_or_default();
let new_path = format!("{}:{}", temp_path.display(), path);
std::env::set_var("PATH", new_path);
// Persist the temp dir
let _ = temp_dir.keep();
// Save for cleanup
if let Ok(mut info) = CLEANUP_INFO.lock() {
info.temp_dir = Some(temp_path);
}
}
});
}
// Stream Trait Alias
pub trait TestStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T: AsyncRead + AsyncWrite + Unpin + Send> TestStream for T {}
// Database Helpers
pub fn get_db_client(port: u16, db_name: &str) -> Client {
let url = format!("http://127.0.0.1:{}", port);
Client::default()
.with_url(&url)
.with_user("traudit")
.with_password("traudit")
.with_database(db_name)
}
pub async fn wait_for_clickhouse(port: u16) {
let client = get_db_client(port, "default");
let mut attempts = 0;
while attempts < 30 {
if client.query("SELECT 1").execute().await.is_ok() {
return;
}
tokio::time::sleep(Duration::from_millis(1000)).await;
attempts += 1;
}
panic!("ClickHouse failed to start on port {}", port);
}
#[derive(Debug, Deserialize, clickhouse::Row)]
pub struct TcpLogCount {
pub count: u64,
}
#[derive(Debug, Deserialize, clickhouse::Row)]
pub struct HttpLogCount {
pub count: u64,
}
// Mock Upstream
pub async fn spawn_mock_upstream() -> (SocketAddr, JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
loop {
if let Ok((mut socket, _)) = listener.accept().await {
tokio::spawn(async move {
let mut buf = [0u8; 4096];
loop {
let n = match socket.read(&mut buf).await {
Ok(0) => return,
Ok(n) => n,
Err(_) => return,
};
let data = &buf[..n];
if data.starts_with(b"GET") || data.starts_with(b"POST") {
let req = String::from_utf8_lossy(data);
let body = format!("Headers:\n{}", req);
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body
);
let _ = socket.write_all(response.as_bytes()).await;
return;
} else {
let _ = socket.write_all(data).await;
}
}
});
}
}
});
(addr, handle)
}
// Protocol Helpers
pub fn build_proxy_v1_header() -> Vec<u8> {
b"PROXY TCP4 1.1.1.1 2.2.2.2 1234 443\r\n".to_vec()
}
pub fn build_proxy_v2_header() -> Vec<u8> {
let mut header = vec![
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x11, 0x00, 0x0C,
];
header.extend_from_slice(&[1, 1, 1, 1]);
header.extend_from_slice(&[2, 2, 2, 2]);
header.extend_from_slice(&[0x04, 0xD2]);
header.extend_from_slice(&[0x01, 0xBB]);
header
}
// TLS Helpers
pub struct CertBundle {
pub cert_pem: String,
pub key_pem: String,
}
pub fn generate_cert() -> CertBundle {
let subject_alt_names = vec!["localhost".to_string(), "127.0.0.1".to_string()];
let certified_key = generate_simple_self_signed(subject_alt_names).unwrap();
CertBundle {
cert_pem: certified_key.cert.pem(),
key_pem: certified_key.key_pair.serialize_pem(),
}
}
// Config Builder
pub struct TestResources {
pub config: Config,
pub proxy_addr: String,
#[allow(dead_code)]
pub upstream_addr: SocketAddr,
pub _cert_file: Option<tempfile::NamedTempFile>,
pub _key_file: Option<tempfile::NamedTempFile>,
}
#[allow(clippy::too_many_arguments)]
pub async fn prepare_env(
service_type: &str,
proxy_proto: Option<&str>,
bind_tls: bool,
real_ip_source: Option<RealIpSource>,
add_xff: bool,
is_unix: bool,
db_port: u16,
db_name: String,
) -> TestResources {
let (upstream_addr, _) = spawn_mock_upstream().await;
let (bind_addr, port_guard, socket_path) = if is_unix {
let path = format!("/tmp/traudit_test_{}.sock", rand::random::<u64>());
let _ = std::fs::remove_file(&path);
(format!("unix://{}", path), None, Some(path))
} else {
let l = TcpListener::bind("127.0.0.1:0").await.unwrap();
let p = l.local_addr().unwrap().port();
(format!("127.0.0.1:{}", p), Some(l), None)
};
drop(port_guard);
let (tls_config, cert_f, key_f) = if bind_tls {
let bundle = generate_cert();
let mut cf = tempfile::NamedTempFile::new().unwrap();
cf.write_all(bundle.cert_pem.as_bytes()).unwrap();
let mut kf = tempfile::NamedTempFile::new().unwrap();
kf.write_all(bundle.key_pem.as_bytes()).unwrap();
(
Some(TlsConfig {
cert: cf.path().to_str().unwrap().to_string(),
key: Some(kf.path().to_str().unwrap().to_string()),
}),
Some(cf),
Some(kf),
)
} else {
(None, None, None)
};
let real_ip = real_ip_source.map(|s| RealIpConfig {
source: s,
trusted_proxies: vec![],
trust_private_ranges: true,
xff_trust_depth: 0,
});
let config = Config {
database: DatabaseConfig {
db_type: "clickhouse".to_string(),
dsn: format!("http://traudit:traudit@127.0.0.1:{}/{}", db_port, db_name),
batch_size: 1,
batch_timeout_secs: 1,
},
services: vec![ServiceConfig {
name: "test-svc".to_string(),
service_type: service_type.to_string(),
forward_to: upstream_addr.to_string(),
upstream_proxy: None,
binds: vec![BindEntry {
addr: bind_addr.clone(),
mode: 0o777,
proxy: proxy_proto.map(|s| s.to_string()),
tls: tls_config,
real_ip,
add_xff_header: add_xff,
}],
}],
};
let proxy_addr_clean = if is_unix {
socket_path.unwrap()
} else {
bind_addr
};
TestResources {
config,
proxy_addr: proxy_addr_clean,
upstream_addr,
_cert_file: cert_f,
_key_file: key_f,
}
}
pub async fn run_tcp_test(test_name: &str, proxy_proto: Option<&str>, is_unix: bool) {
// init_env called inside Lazy
let db_port = get_shared_db_port().await;
let db_name = test_name.to_string();
// Create DB
let system_client = get_db_client(db_port, "default");
system_client
.query(&format!("CREATE DATABASE IF NOT EXISTS {}", db_name))
.execute()
.await
.unwrap();
let client = get_db_client(db_port, &db_name);
let res = prepare_env(
"tcp",
proxy_proto,
false,
proxy_proto.map(|_| RealIpSource::ProxyProtocol),
false,
is_unix,
db_port,
db_name,
)
.await;
tokio::spawn(async move {
let _ = traudit::core::server::run(res.config).await;
});
// Wait for traudit startup and DB connect
tokio::time::sleep(Duration::from_millis(1000)).await;
let mut stream: Box<dyn TestStream> = if is_unix {
Box::new(
UnixStream::connect(&res.proxy_addr)
.await
.expect("Unix connect failed"),
)
} else {
Box::new(
TcpStream::connect(&res.proxy_addr)
.await
.expect("Tcp connect failed"),
)
};
if let Some(p) = proxy_proto {
if p == "v1" {
stream.write_all(&build_proxy_v1_header()).await.unwrap();
} else {
stream.write_all(&build_proxy_v2_header()).await.unwrap();
}
}
stream.write_all(b"ping").await.unwrap();
let mut buf = [0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
assert_eq!(&buf[..n], b"ping");
drop(stream);
tokio::time::sleep(Duration::from_millis(2000)).await;
let count = client
.query("SELECT count() as count FROM tcp_log WHERE service = 'test-svc'")
.fetch_one::<TcpLogCount>()
.await
.unwrap();
assert_eq!(count.count, 1);
if is_unix {
let _ = std::fs::remove_file(&res.proxy_addr);
}
}
pub async fn run_http_test(
test_name: &str,
proxy_proto: Option<&str>,
use_tls: bool,
real_ip_source: Option<RealIpSource>,
add_xff: bool,
expected_xff_in_upstream: Option<&str>,
) {
let db_port = get_shared_db_port().await;
let db_name = test_name.to_string();
let system_client = get_db_client(db_port, "default");
system_client
.query(&format!("CREATE DATABASE IF NOT EXISTS {}", db_name))
.execute()
.await
.unwrap();
let client = get_db_client(db_port, &db_name);
let res = prepare_env(
"http",
proxy_proto,
use_tls,
real_ip_source.clone(),
add_xff,
false,
db_port,
db_name,
)
.await;
tokio::spawn(async move {
let _ = traudit::core::server::run(res.config).await;
});
tokio::time::sleep(Duration::from_millis(1000)).await;
if proxy_proto.is_some() {
let mut stream = TcpStream::connect(&res.proxy_addr).await.unwrap();
if proxy_proto == Some("v1") {
stream.write_all(&build_proxy_v1_header()).await.unwrap();
} else {
stream.write_all(&build_proxy_v2_header()).await.unwrap();
}
if use_tls {
let mut root_store = rustls::RootCertStore::empty();
let cert_bytes = tokio::fs::read(res._cert_file.as_ref().unwrap().path())
.await
.unwrap();
let mut pem = std::io::BufReader::new(&cert_bytes[..]);
let certs = rustls_pemfile::certs(&mut pem)
.map(|c| c.unwrap())
.collect::<Vec<_>>();
root_store.add(certs[0].clone()).unwrap();
let config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let domain = rustls::pki_types::ServerName::try_from("localhost").unwrap();
let mut tls_stream = connector
.connect(domain, stream)
.await
.expect("TLS handshake failed");
let request = if add_xff || real_ip_source == Some(RealIpSource::Xff) {
if real_ip_source == Some(RealIpSource::Xff) {
"GET / HTTP/1.1\r\nHost: localhost\r\nX-Forwarded-For: 8.8.8.8\r\nConnection: close\r\n\r\n"
} else {
"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
}
} else {
"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
};
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut buf = Vec::new();
tls_stream.read_to_end(&mut buf).await.unwrap();
let body = String::from_utf8_lossy(&buf).to_string();
if let Some(expected) = expected_xff_in_upstream {
assert!(body.contains(expected), "Body: {}", body);
}
assert!(body.contains("200 OK"));
} else {
let request = if real_ip_source == Some(RealIpSource::Xff) {
"GET / HTTP/1.1\r\nHost: localhost\r\nX-Forwarded-For: 8.8.8.8\r\nConnection: close\r\n\r\n"
} else {
"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
};
stream.write_all(request.as_bytes()).await.unwrap();
let mut buf = Vec::new();
stream.read_to_end(&mut buf).await.unwrap();
let body = String::from_utf8_lossy(&buf).to_string();
if let Some(expected) = expected_xff_in_upstream {
assert!(body.contains(expected));
}
assert!(body.contains("200 OK"));
}
} else {
let client_builder = reqwest::Client::builder();
let client_http = if use_tls {
let cert_bytes = tokio::fs::read(res._cert_file.as_ref().unwrap().path())
.await
.unwrap();
let cert = reqwest::Certificate::from_pem(&cert_bytes).unwrap();
client_builder
.add_root_certificate(cert)
.danger_accept_invalid_certs(true)
.build()
.unwrap()
} else {
client_builder.build().unwrap()
};
let protocol = if use_tls { "https" } else { "http" };
let url = format!("{}://{}/", protocol, res.proxy_addr);
let resp = if real_ip_source == Some(RealIpSource::Xff) {
client_http
.get(&url)
.header("X-Forwarded-For", "8.8.8.8")
.send()
.await
.expect("Req failed")
} else {
client_http.get(&url).send().await.expect("Req failed")
};
assert_eq!(resp.status(), 200);
let body = resp.text().await.unwrap();
if let Some(expected) = expected_xff_in_upstream {
assert!(body.contains(expected));
}
}
tokio::time::sleep(Duration::from_millis(2000)).await;
let count = client
.query("SELECT count() as count FROM http_log WHERE service = 'test-svc'")
.fetch_one::<HttpLogCount>()
.await
.unwrap();
assert_eq!(count.count, 1);
}

View File

@@ -1,762 +1,17 @@
use std::io::Write;
use std::net::SocketAddr;
use std::sync::Once;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UnixStream};
use tokio::sync::OnceCell;
use tokio::task::JoinHandle;
use clickhouse::Client;
use serde::Deserialize;
use traudit::config::{
BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, ServiceConfig, TlsConfig,
};
// Testcontainers
use ctor::dtor;
use std::sync::Mutex;
use testcontainers::runners::AsyncRunner;
use testcontainers::{GenericImage, ImageExt};
// TLS Dependencies
use rcgen::generate_simple_self_signed;
use std::os::unix::fs::PermissionsExt;
static INIT: Once = Once::new();
// Shared Container Singleton
struct SharedDb {
port: u16,
}
static SHARED_DB: OnceCell<SharedDb> = OnceCell::const_new();
// Cleanup Info used by dtor
struct CleanupInfo {
container_id: Option<String>,
temp_dir: Option<std::path::PathBuf>,
}
static CLEANUP_INFO: Mutex<CleanupInfo> = Mutex::new(CleanupInfo {
container_id: None,
temp_dir: None,
});
async fn get_shared_db_port() -> u16 {
let db = SHARED_DB
.get_or_init(|| async {
init_env();
let image = GenericImage::new("clickhouse/clickhouse-server", "latest")
.with_env_var("CLICKHOUSE_DB", "traudit")
.with_env_var("CLICKHOUSE_USER", "traudit")
.with_env_var("CLICKHOUSE_PASSWORD", "traudit")
.with_env_var("CLICKHOUSE_DEFAULT_ACCESS_MANAGEMENT", "1");
let container = image.start().await.expect("Failed to start container");
let port = container
.get_host_port_ipv4(8123)
.await
.expect("Failed to get port");
// Save ID for cleanup
if let Ok(mut info) = CLEANUP_INFO.lock() {
info.container_id = Some(container.id().to_string());
}
Box::leak(Box::new(container));
// Async wait
wait_for_clickhouse(port).await;
SharedDb { port }
})
.await;
db.port
}
#[dtor]
fn cleanup() {
if let Ok(info) = CLEANUP_INFO.lock() {
// Cleanup Container
if let Some(id) = &info.container_id {
// We use standard process command to clean up attached container
// Try docker first, then podman
let _ = std::process::Command::new("docker")
.args(["rm", "-f", id])
.output();
let _ = std::process::Command::new("podman")
.args(["rm", "-f", id])
.output();
}
// Cleanup Temp Dir (shim)
if let Some(path) = &info.temp_dir {
let _ = std::fs::remove_dir_all(path);
}
}
}
fn init_env() {
INIT.call_once(|| {
// Initialize tracing
tracing_subscriber::fmt()
.with_env_filter("info")
.with_test_writer()
.try_init()
.ok();
// Install Rustls Crypto Provider (Ring)
let _ = rustls::crypto::ring::default_provider().install_default();
// Detect Podman socket if Docker socket is missing
if std::env::var("DOCKER_HOST").is_err()
&& !std::path::Path::new("/var/run/docker.sock").exists()
{
if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock");
if podman_sock.exists() {
std::env::set_var("DOCKER_HOST", format!("unix://{}", podman_sock.display()));
}
}
}
// Create docker shim for podman if docker is missing
if std::process::Command::new("docker")
.arg("-v")
.output()
.is_err()
{
let temp_dir = tempfile::tempdir().expect("failed to create temp dir");
let temp_path = temp_dir.path().to_owned();
let docker_shim = temp_path.join("docker");
let mut file = std::fs::File::create(&docker_shim).expect("failed to create docker shim");
file
.write_all(b"#!/bin/sh\nexec podman \"$@\"")
.expect("failed to write shim");
let mut perms = file.metadata().unwrap().permissions();
perms.set_mode(0o755);
file.set_permissions(perms).unwrap();
let path = std::env::var("PATH").unwrap_or_default();
let new_path = format!("{}:{}", temp_path.display(), path);
std::env::set_var("PATH", new_path);
// Persist the temp dir
let _ = temp_dir.keep();
// Save for cleanup
if let Ok(mut info) = CLEANUP_INFO.lock() {
info.temp_dir = Some(temp_path);
}
}
});
}
// Stream Trait Alias
trait TestStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T: AsyncRead + AsyncWrite + Unpin + Send> TestStream for T {}
// Database Helpers
fn get_db_client(port: u16, db_name: &str) -> Client {
let url = format!("http://127.0.0.1:{}", port);
Client::default()
.with_url(&url)
.with_user("traudit")
.with_password("traudit")
.with_database(db_name)
}
async fn wait_for_clickhouse(port: u16) {
let client = get_db_client(port, "default");
let mut attempts = 0;
while attempts < 30 {
if client.query("SELECT 1").execute().await.is_ok() {
return;
}
tokio::time::sleep(Duration::from_millis(1000)).await;
attempts += 1;
}
panic!("ClickHouse failed to start on port {}", port);
}
#[derive(Debug, Deserialize, clickhouse::Row)]
struct TcpLogCount {
count: u64,
}
#[derive(Debug, Deserialize, clickhouse::Row)]
struct HttpLogCount {
count: u64,
}
// Mock Upstream
async fn spawn_mock_upstream() -> (SocketAddr, JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
loop {
if let Ok((mut socket, _)) = listener.accept().await {
tokio::spawn(async move {
let mut buf = [0u8; 4096];
loop {
let n = match socket.read(&mut buf).await {
Ok(n) if n == 0 => return,
Ok(n) => n,
Err(_) => return,
};
let data = &buf[..n];
if data.starts_with(b"GET") || data.starts_with(b"POST") {
let req = String::from_utf8_lossy(data);
let body = format!("Headers:\n{}", req);
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body
);
let _ = socket.write_all(response.as_bytes()).await;
return;
} else {
let _ = socket.write_all(data).await;
}
}
});
}
}
});
(addr, handle)
}
// Protocol Helpers
fn build_proxy_v1_header() -> Vec<u8> {
b"PROXY TCP4 1.1.1.1 2.2.2.2 1234 443\r\n".to_vec()
}
fn build_proxy_v2_header() -> Vec<u8> {
let mut header = vec![
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x11, 0x00, 0x0C,
];
header.extend_from_slice(&[1, 1, 1, 1]);
header.extend_from_slice(&[2, 2, 2, 2]);
header.extend_from_slice(&[0x04, 0xD2]);
header.extend_from_slice(&[0x01, 0xBB]);
header
}
// TLS Helpers
struct CertBundle {
cert_pem: String,
key_pem: String,
}
fn generate_cert() -> CertBundle {
let subject_alt_names = vec!["localhost".to_string(), "127.0.0.1".to_string()];
let certified_key = generate_simple_self_signed(subject_alt_names).unwrap();
CertBundle {
cert_pem: certified_key.cert.pem(),
key_pem: certified_key.key_pair.serialize_pem(),
}
}
// Config Builder
struct TestResources {
config: Config,
proxy_addr: String,
#[allow(dead_code)]
upstream_addr: SocketAddr,
_cert_file: Option<tempfile::NamedTempFile>,
_key_file: Option<tempfile::NamedTempFile>,
}
async fn prepare_env(
service_type: &str,
proxy_proto: Option<&str>,
bind_tls: bool,
real_ip_source: Option<RealIpSource>,
add_xff: bool,
is_unix: bool,
db_port: u16,
db_name: String,
) -> TestResources {
let (upstream_addr, _) = spawn_mock_upstream().await;
let (bind_addr, port_guard, socket_path) = if is_unix {
let path = format!("/tmp/traudit_test_{}.sock", rand::random::<u64>());
let _ = std::fs::remove_file(&path);
(format!("unix://{}", path), None, Some(path))
} else {
let l = TcpListener::bind("127.0.0.1:0").await.unwrap();
let p = l.local_addr().unwrap().port();
(format!("127.0.0.1:{}", p), Some(l), None)
};
drop(port_guard);
let (tls_config, cert_f, key_f) = if bind_tls {
let bundle = generate_cert();
let mut cf = tempfile::NamedTempFile::new().unwrap();
cf.write_all(bundle.cert_pem.as_bytes()).unwrap();
let mut kf = tempfile::NamedTempFile::new().unwrap();
kf.write_all(bundle.key_pem.as_bytes()).unwrap();
(
Some(TlsConfig {
cert: cf.path().to_str().unwrap().to_string(),
key: Some(kf.path().to_str().unwrap().to_string()),
}),
Some(cf),
Some(kf),
)
} else {
(None, None, None)
};
let real_ip = real_ip_source.map(|s| RealIpConfig {
source: s,
trusted_proxies: vec![],
trust_private_ranges: true,
xff_trust_depth: 0,
});
let config = Config {
database: DatabaseConfig {
db_type: "clickhouse".to_string(),
dsn: format!("http://traudit:traudit@127.0.0.1:{}/{}", db_port, db_name),
batch_size: 1,
batch_timeout_secs: 1,
},
services: vec![ServiceConfig {
name: "test-svc".to_string(),
service_type: service_type.to_string(),
forward_to: upstream_addr.to_string(),
binds: vec![BindEntry {
addr: bind_addr.clone(),
mode: 0o777,
proxy: proxy_proto.map(|s| s.to_string()),
tls: tls_config,
real_ip,
add_xff_header: add_xff,
}],
}],
};
let proxy_addr_clean = if is_unix {
socket_path.unwrap()
} else {
bind_addr
};
TestResources {
config,
proxy_addr: proxy_addr_clean,
upstream_addr,
_cert_file: cert_f,
_key_file: key_f,
}
}
// Scenarios
async fn run_tcp_test(test_name: &str, proxy_proto: Option<&str>, is_unix: bool) {
// init_env called inside Lazy
let db_port = get_shared_db_port().await;
let db_name = test_name.to_string();
// Create DB
let system_client = get_db_client(db_port, "default");
system_client
.query(&format!("CREATE DATABASE IF NOT EXISTS {}", db_name))
.execute()
.await
.unwrap();
let client = get_db_client(db_port, &db_name);
let res = prepare_env(
"tcp",
proxy_proto,
false,
proxy_proto.map(|_| RealIpSource::ProxyProtocol),
false,
is_unix,
db_port,
db_name,
)
.await;
tokio::spawn(async move {
let _ = traudit::core::server::run(res.config).await;
});
// Wait for traudit startup and DB connect
tokio::time::sleep(Duration::from_millis(1000)).await;
let mut stream: Box<dyn TestStream> = if is_unix {
Box::new(
UnixStream::connect(&res.proxy_addr)
.await
.expect("Unix connect failed"),
)
} else {
Box::new(
TcpStream::connect(&res.proxy_addr)
.await
.expect("Tcp connect failed"),
)
};
if let Some(p) = proxy_proto {
if p == "v1" {
stream.write_all(&build_proxy_v1_header()).await.unwrap();
} else {
stream.write_all(&build_proxy_v2_header()).await.unwrap();
}
}
stream.write_all(b"ping").await.unwrap();
let mut buf = [0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
assert_eq!(&buf[..n], b"ping");
drop(stream);
tokio::time::sleep(Duration::from_millis(2000)).await;
let count = client
.query("SELECT count() as count FROM tcp_log WHERE service = 'test-svc'")
.fetch_one::<TcpLogCount>()
.await
.unwrap();
assert_eq!(count.count, 1);
if is_unix {
let _ = std::fs::remove_file(&res.proxy_addr);
}
}
mod common;
use common::*;
#[tokio::test]
async fn test_tcp_normal() {
run_tcp_test("test_tcp_normal", None, false).await;
}
#[tokio::test]
async fn test_tcp_proxy_v1() {
run_tcp_test("test_tcp_proxy_v1", Some("v1"), false).await;
}
#[tokio::test]
async fn test_tcp_proxy_v2() {
run_tcp_test("test_tcp_proxy_v2", Some("v2"), false).await;
}
#[tokio::test]
async fn test_unix_suite() {
#[cfg(unix)]
{
run_tcp_test("test_unix_normal", None, true).await;
run_tcp_test("test_unix_proxy_v1", Some("v1"), true).await;
run_tcp_test("test_unix_proxy_v2", Some("v2"), true).await;
}
}
async fn run_http_test(
test_name: &str,
proxy_proto: Option<&str>,
use_tls: bool,
real_ip_source: Option<RealIpSource>,
add_xff: bool,
expected_xff_in_upstream: Option<&str>,
) {
let db_port = get_shared_db_port().await;
let db_name = test_name.to_string();
let system_client = get_db_client(db_port, "default");
system_client
.query(&format!("CREATE DATABASE IF NOT EXISTS {}", db_name))
.execute()
.await
.unwrap();
let client = get_db_client(db_port, &db_name);
let res = prepare_env(
"http",
proxy_proto,
use_tls,
real_ip_source.clone(),
add_xff,
false,
db_port,
db_name,
)
.await;
tokio::spawn(async move {
let _ = traudit::core::server::run(res.config).await;
});
tokio::time::sleep(Duration::from_millis(1000)).await;
if proxy_proto.is_some() {
let mut stream = TcpStream::connect(&res.proxy_addr).await.unwrap();
if proxy_proto == Some("v1") {
stream.write_all(&build_proxy_v1_header()).await.unwrap();
} else {
stream.write_all(&build_proxy_v2_header()).await.unwrap();
}
if use_tls {
let mut root_store = rustls::RootCertStore::empty();
let cert_bytes = tokio::fs::read(res._cert_file.as_ref().unwrap().path())
.await
.unwrap();
let mut pem = std::io::BufReader::new(&cert_bytes[..]);
let certs = rustls_pemfile::certs(&mut pem)
.map(|c| c.unwrap())
.map(rustls::pki_types::CertificateDer::from)
.collect::<Vec<_>>();
root_store.add(certs[0].clone()).unwrap();
let config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let domain = rustls::pki_types::ServerName::try_from("localhost").unwrap();
let mut tls_stream = connector
.connect(domain, stream)
.await
.expect("TLS handshake failed");
let request = if add_xff || real_ip_source == Some(RealIpSource::Xff) {
if real_ip_source == Some(RealIpSource::Xff) {
"GET / HTTP/1.1\r\nHost: localhost\r\nX-Forwarded-For: 8.8.8.8\r\nConnection: close\r\n\r\n"
} else {
"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
}
} else {
"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
};
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut buf = Vec::new();
tls_stream.read_to_end(&mut buf).await.unwrap();
let body = String::from_utf8_lossy(&buf).to_string();
if let Some(expected) = expected_xff_in_upstream {
assert!(body.contains(expected), "Body: {}", body);
}
assert!(body.contains("200 OK"));
} else {
let request = if real_ip_source == Some(RealIpSource::Xff) {
"GET / HTTP/1.1\r\nHost: localhost\r\nX-Forwarded-For: 8.8.8.8\r\nConnection: close\r\n\r\n"
} else {
"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
};
stream.write_all(request.as_bytes()).await.unwrap();
let mut buf = Vec::new();
stream.read_to_end(&mut buf).await.unwrap();
let body = String::from_utf8_lossy(&buf).to_string();
if let Some(expected) = expected_xff_in_upstream {
assert!(body.contains(expected));
}
assert!(body.contains("200 OK"));
}
} else {
let client_builder = reqwest::Client::builder();
let client_http = if use_tls {
let cert_bytes = tokio::fs::read(res._cert_file.as_ref().unwrap().path())
.await
.unwrap();
let cert = reqwest::Certificate::from_pem(&cert_bytes).unwrap();
client_builder
.add_root_certificate(cert)
.danger_accept_invalid_certs(true)
.build()
.unwrap()
} else {
client_builder.build().unwrap()
};
let protocol = if use_tls { "https" } else { "http" };
let url = format!("{}://{}/", protocol, res.proxy_addr);
let resp = if real_ip_source == Some(RealIpSource::Xff) {
client_http
.get(&url)
.header("X-Forwarded-For", "8.8.8.8")
.send()
.await
.expect("Req failed")
} else {
client_http.get(&url).send().await.expect("Req failed")
};
assert_eq!(resp.status(), 200);
let body = resp.text().await.unwrap();
if let Some(expected) = expected_xff_in_upstream {
assert!(body.contains(expected));
}
}
tokio::time::sleep(Duration::from_millis(2000)).await;
let count = client
.query("SELECT count() as count FROM http_log WHERE service = 'test-svc'")
.fetch_one::<HttpLogCount>()
.await
.unwrap();
assert_eq!(count.count, 1);
}
#[tokio::test]
async fn test_http_normal() {
run_http_test("test_http_normal", None, false, None, false, None).await;
}
#[tokio::test]
async fn test_http_proxy_v1() {
run_http_test(
"test_http_proxy_v1",
Some("v1"),
false,
Some(RealIpSource::ProxyProtocol),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_http_proxy_v2() {
run_http_test(
"test_http_proxy_v2",
Some("v2"),
false,
Some(RealIpSource::ProxyProtocol),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_http_xff_source() {
run_http_test(
"test_http_xff_source",
None,
false,
Some(RealIpSource::Xff),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_http_append_xff() {
run_http_test(
"test_http_append_xff",
None,
false,
None,
true,
Some("X-Forwarded-For: 127.0.0.1"),
)
.await;
}
#[tokio::test]
async fn test_http_v1_append_xff() {
run_http_test(
"test_http_v1_append_xff",
Some("v1"),
false,
Some(RealIpSource::ProxyProtocol),
true,
Some("X-Forwarded-For: 1.1.1.1"),
)
.await;
}
#[tokio::test]
async fn test_http_v2_append_xff() {
run_http_test(
"test_http_v2_append_xff",
Some("v2"),
false,
Some(RealIpSource::ProxyProtocol),
true,
Some("X-Forwarded-For: 1.1.1.1"),
)
.await;
}
#[tokio::test]
async fn test_https_normal() {
run_http_test("test_https_normal", None, true, None, false, None).await;
}
#[tokio::test]
async fn test_https_proxy_v1() {
run_http_test(
"test_https_proxy_v1",
Some("v1"),
true,
Some(RealIpSource::ProxyProtocol),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_https_proxy_v2() {
run_http_test(
"test_https_proxy_v2",
Some("v2"),
true,
Some(RealIpSource::ProxyProtocol),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_https_xff_source() {
run_http_test(
"test_https_xff_source",
None,
true,
Some(RealIpSource::Xff),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_https_append_xff() {
run_http_test(
"test_https_append_xff",
None,
true,
None,
true,
Some("X-Forwarded-For: 127.0.0.1"),
)
.await;
}
#[tokio::test]
async fn test_https_v1_append_xff() {
run_http_test(
"test_https_v1_append_xff",
Some("v1"),
true,
Some(RealIpSource::ProxyProtocol),
true,
Some("X-Forwarded-For: 1.1.1.1"),
)
.await;
}
#[tokio::test]
async fn test_https_v2_append_xff() {
run_http_test(
"test_https_v2_append_xff",
Some("v2"),
true,
Some(RealIpSource::ProxyProtocol),
true,
Some("X-Forwarded-For: 1.1.1.1"),
)
.await;
}

327
tests/proxy_test.rs Normal file
View File

@@ -0,0 +1,327 @@
use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use traudit::config::{
BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, ServiceConfig,
};
mod common;
use common::*;
#[tokio::test]
async fn test_tcp_proxy_v1() {
run_tcp_test("test_tcp_proxy_v1", Some("v1"), false).await;
}
#[tokio::test]
async fn test_tcp_proxy_v2() {
run_tcp_test("test_tcp_proxy_v2", Some("v2"), false).await;
}
#[tokio::test]
async fn test_http_proxy_v1() {
run_http_test(
"test_http_proxy_v1",
Some("v1"),
false,
Some(RealIpSource::ProxyProtocol),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_http_proxy_v2() {
run_http_test(
"test_http_proxy_v2",
Some("v2"),
false,
Some(RealIpSource::ProxyProtocol),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_http_xff_source() {
run_http_test(
"test_http_xff_source",
None,
false,
Some(RealIpSource::Xff),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_http_append_xff() {
run_http_test(
"test_http_append_xff",
None,
false,
None,
true,
Some("X-Forwarded-For: 127.0.0.1"),
)
.await;
}
#[tokio::test]
async fn test_http_v1_append_xff() {
run_http_test(
"test_http_v1_append_xff",
Some("v1"),
false,
Some(RealIpSource::ProxyProtocol),
true,
Some("X-Forwarded-For: 1.1.1.1"),
)
.await;
}
#[tokio::test]
async fn test_http_v2_append_xff() {
run_http_test(
"test_http_v2_append_xff",
Some("v2"),
false,
Some(RealIpSource::ProxyProtocol),
true,
Some("X-Forwarded-For: 1.1.1.1"),
)
.await;
}
#[tokio::test]
async fn test_https_proxy_v1() {
run_http_test(
"test_https_proxy_v1",
Some("v1"),
true,
Some(RealIpSource::ProxyProtocol),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_https_proxy_v2() {
run_http_test(
"test_https_proxy_v2",
Some("v2"),
true,
Some(RealIpSource::ProxyProtocol),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_https_xff_source() {
run_http_test(
"test_https_xff_source",
None,
true,
Some(RealIpSource::Xff),
false,
None,
)
.await;
}
#[tokio::test]
async fn test_https_append_xff() {
run_http_test(
"test_https_append_xff",
None,
true,
None,
true,
Some("X-Forwarded-For: 127.0.0.1"),
)
.await;
}
#[tokio::test]
async fn test_https_v1_append_xff() {
run_http_test(
"test_https_v1_append_xff",
Some("v1"),
true,
Some(RealIpSource::ProxyProtocol),
true,
Some("X-Forwarded-For: 1.1.1.1"),
)
.await;
}
#[tokio::test]
async fn test_https_v2_append_xff() {
run_http_test(
"test_https_v2_append_xff",
Some("v2"),
true,
Some(RealIpSource::ProxyProtocol),
true,
Some("X-Forwarded-For: 1.1.1.1"),
)
.await;
}
// Helper for Chain Test
struct ChainTestResources {
config: Config,
e1_addr: String,
#[allow(dead_code)]
e4_upstream_addr: SocketAddr,
}
async fn prepare_chain_env() -> ChainTestResources {
// E4 Upstream (Mock Server)
let (e4_upstream_addr, _) = spawn_mock_upstream().await;
// Assign ports dynamically
let l1 = TcpListener::bind("127.0.0.1:0").await.unwrap();
let p1 = l1.local_addr().unwrap().port();
let addr1 = format!("127.0.0.1:{}", p1);
drop(l1);
let l2 = TcpListener::bind("127.0.0.1:0").await.unwrap();
let p2 = l2.local_addr().unwrap().port();
let addr2 = format!("127.0.0.1:{}", p2);
drop(l2);
let l3 = TcpListener::bind("127.0.0.1:0").await.unwrap();
let p3 = l3.local_addr().unwrap().port();
let addr3 = format!("127.0.0.1:{}", p3);
drop(l3);
let l4 = TcpListener::bind("127.0.0.1:0").await.unwrap();
let p4 = l4.local_addr().unwrap().port();
let addr4 = format!("127.0.0.1:{}", p4);
drop(l4);
// DB Config
let db_port = get_shared_db_port().await;
let db_config = DatabaseConfig {
db_type: "clickhouse".to_string(),
dsn: format!("http://traudit:traudit@127.0.0.1:{}/chain_test", db_port),
batch_size: 1,
batch_timeout_secs: 1,
};
// Create DB
let system_client = get_db_client(db_port, "default");
let _ = system_client
.query("CREATE DATABASE IF NOT EXISTS chain_test")
.execute()
.await;
// REAL IP CONFIG (Trust Proxy Protocol)
let real_ip_pp = Some(RealIpConfig {
source: RealIpSource::ProxyProtocol,
trusted_proxies: vec![],
trust_private_ranges: true, // Trusted because test runs on loopback
xff_trust_depth: 0,
});
// Services
let services = vec![
// E1: Entry (No Proxy In, Upstream Proxy V1)
ServiceConfig {
name: "e1".to_string(),
service_type: "tcp".to_string(),
binds: vec![BindEntry {
addr: addr1.clone(),
mode: 0o777,
proxy: None,
tls: None,
add_xff_header: false,
real_ip: None,
}],
forward_to: addr2.clone(),
upstream_proxy: Some("v1".to_string()),
},
// E2: (Proxy V1 In, Upstream Proxy V2)
ServiceConfig {
name: "e2".to_string(),
service_type: "tcp".to_string(),
binds: vec![BindEntry {
addr: addr2.clone(),
mode: 0o777,
proxy: Some("v1".to_string()),
tls: None,
add_xff_header: false,
real_ip: real_ip_pp.clone(),
}],
forward_to: addr3.clone(),
upstream_proxy: Some("v2".to_string()),
},
// E3: (Proxy V2 In, Upstream Proxy V1)
ServiceConfig {
name: "e3".to_string(),
service_type: "tcp".to_string(),
binds: vec![BindEntry {
addr: addr3.clone(),
mode: 0o777,
proxy: Some("v2".to_string()),
tls: None,
add_xff_header: false,
real_ip: real_ip_pp.clone(),
}],
forward_to: addr4.clone(),
upstream_proxy: Some("v1".to_string()),
},
// E4: (Proxy V1 In, No Upstream Proxy -> Mock Server)
ServiceConfig {
name: "e4".to_string(),
service_type: "tcp".to_string(),
binds: vec![BindEntry {
addr: addr4.clone(),
mode: 0o777,
proxy: Some("v1".to_string()),
tls: None,
add_xff_header: false,
real_ip: real_ip_pp.clone(),
}],
forward_to: e4_upstream_addr.to_string(),
upstream_proxy: None,
},
];
let config = Config {
database: db_config,
services,
};
ChainTestResources {
config,
e1_addr: addr1,
e4_upstream_addr,
}
}
#[tokio::test]
async fn test_proxy_chain() {
let res = prepare_chain_env().await;
tokio::spawn(async move {
let _ = traudit::core::server::run(res.config).await;
});
tokio::time::sleep(Duration::from_millis(2000)).await;
// Connect to E1
let mut stream = TcpStream::connect(&res.e1_addr)
.await
.expect("Failed to connect to E1");
// Send data
stream.write_all(b"chain_test_ping").await.unwrap();
// Read response
let mut buf = [0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
let response = &buf[..n];
// The mock upstream echoes "chain_test_ping" (since it doesn't match GET/POST)
assert_eq!(
response, b"chain_test_ping",
"Chain test failed: response mismatch"
);
}

12
tests/unix_sock_test.rs Normal file
View File

@@ -0,0 +1,12 @@
mod common;
use common::*;
#[tokio::test]
async fn test_unix_suite() {
#[cfg(unix)]
{
run_tcp_test("test_unix_normal", None, true).await;
run_tcp_test("test_unix_proxy_v1", Some("v1"), true).await;
run_tcp_test("test_unix_proxy_v2", Some("v2"), true).await;
}
}