mirror of
https://github.com/awfufu/traudit
synced 2026-03-01 05:29:44 +08:00
606 lines
17 KiB
Rust
606 lines
17 KiB
Rust
#![allow(dead_code)]
|
|
use std::io::Write;
|
|
use std::net::SocketAddr;
|
|
use std::os::unix::fs::PermissionsExt;
|
|
use std::path::PathBuf;
|
|
use std::sync::{Mutex, Once};
|
|
use std::time::Duration;
|
|
|
|
use clickhouse::Client;
|
|
use ctor::dtor;
|
|
use rcgen::generate_simple_self_signed;
|
|
use serde::Deserialize;
|
|
use testcontainers::runners::AsyncRunner;
|
|
use testcontainers::{GenericImage, ImageExt};
|
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
use tokio::net::{TcpListener, TcpStream, UnixStream};
|
|
use tokio::sync::OnceCell;
|
|
use tokio::task::JoinHandle;
|
|
|
|
use traudit::config::{
|
|
BindEntry, Config, DatabaseConfig, RealIpConfig, RealIpSource, RedirectHttpsConfig,
|
|
ServiceConfig, TlsConfig,
|
|
};
|
|
|
|
static INIT: Once = Once::new();
|
|
|
|
// Shared Container Singleton
|
|
pub struct SharedDb {
|
|
pub port: u16,
|
|
}
|
|
|
|
static SHARED_DB: OnceCell<SharedDb> = OnceCell::const_new();
|
|
|
|
// Cleanup Info used by dtor

/// Resources created during the test run that must be removed at process exit.
struct CleanupInfo {
    /// Id of the shared ClickHouse container, once it has been started.
    container_id: Option<String>,
    /// Directory holding the `docker` -> `podman` shim script, if one was created.
    temp_dir: Option<PathBuf>,
}

/// Recorded as resources are created and read by the `#[dtor]` cleanup hook.
/// Starts out empty.
static CLEANUP_INFO: Mutex<CleanupInfo> = Mutex::new(CleanupInfo {
    temp_dir: None,
    container_id: None,
});
|
|
|
|
pub async fn get_shared_db_port() -> u16 {
|
|
let db = SHARED_DB
|
|
.get_or_init(|| async {
|
|
init_env();
|
|
|
|
let image = GenericImage::new("clickhouse/clickhouse-server", "latest")
|
|
.with_env_var("CLICKHOUSE_DB", "traudit")
|
|
.with_env_var("CLICKHOUSE_USER", "traudit")
|
|
.with_env_var("CLICKHOUSE_PASSWORD", "traudit")
|
|
.with_env_var("CLICKHOUSE_DEFAULT_ACCESS_MANAGEMENT", "1");
|
|
|
|
let container = image.start().await.expect("Failed to start container");
|
|
let port = container
|
|
.get_host_port_ipv4(8123)
|
|
.await
|
|
.expect("Failed to get port");
|
|
|
|
// Save ID for cleanup
|
|
if let Ok(mut info) = CLEANUP_INFO.lock() {
|
|
info.container_id = Some(container.id().to_string());
|
|
}
|
|
|
|
Box::leak(Box::new(container));
|
|
|
|
// Async wait
|
|
wait_for_clickhouse(port).await;
|
|
|
|
SharedDb { port }
|
|
})
|
|
.await;
|
|
|
|
db.port
|
|
}
|
|
|
|
#[dtor]
|
|
fn cleanup() {
|
|
if let Ok(info) = CLEANUP_INFO.lock() {
|
|
// Cleanup Container
|
|
if let Some(id) = &info.container_id {
|
|
let _ = std::process::Command::new("docker")
|
|
.args(["rm", "-f", id])
|
|
.output();
|
|
let _ = std::process::Command::new("podman")
|
|
.args(["rm", "-f", id])
|
|
.output();
|
|
}
|
|
|
|
// Cleanup Temp Dir (shim)
|
|
if let Some(path) = &info.temp_dir {
|
|
let _ = std::fs::remove_dir_all(path);
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn init_env() {
|
|
INIT.call_once(|| {
|
|
// Initialize tracing
|
|
tracing_subscriber::fmt()
|
|
.with_env_filter("info")
|
|
.with_test_writer()
|
|
.try_init()
|
|
.ok();
|
|
// Install Rustls Crypto Provider (Ring)
|
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
|
|
|
// Detect Podman socket if Docker socket is missing or empty
|
|
let docker_host = std::env::var("DOCKER_HOST").unwrap_or_default();
|
|
let docker_sock_exists = std::path::Path::new("/var/run/docker.sock").exists();
|
|
let need_detection = docker_host.trim().is_empty() && !docker_sock_exists;
|
|
eprintln!(
|
|
"DEBUG: DOCKER_HOST='{}', docker_sock_exists={}, need_detection={}",
|
|
docker_host, docker_sock_exists, need_detection
|
|
);
|
|
|
|
if need_detection {
|
|
eprintln!("DEBUG: Attempting detection...");
|
|
let mut found = false;
|
|
if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
|
|
let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock");
|
|
if podman_sock.exists() {
|
|
std::env::set_var("DOCKER_HOST", format!("unix://{}", podman_sock.display()));
|
|
found = true;
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
// Fallback to hardcoded path
|
|
let fallback = std::path::Path::new("/run/user/1000/podman/podman.sock");
|
|
if fallback.exists() {
|
|
eprintln!("DEBUG: Found fallback podman sock: {}", fallback.display());
|
|
std::env::set_var("DOCKER_HOST", format!("unix://{}", fallback.display()));
|
|
} else {
|
|
eprintln!("DEBUG: No podman socket found in fallback locations.");
|
|
}
|
|
}
|
|
} else {
|
|
eprintln!("DEBUG: Skipping detection. Using existing DOCKER_HOST or docker.sock");
|
|
}
|
|
|
|
// Create docker shim for podman if docker is missing
|
|
if std::process::Command::new("docker")
|
|
.arg("-v")
|
|
.output()
|
|
.is_err()
|
|
{
|
|
let temp_dir = tempfile::tempdir().expect("failed to create temp dir");
|
|
let temp_path = temp_dir.path().to_owned();
|
|
let docker_shim = temp_path.join("docker");
|
|
let mut file = std::fs::File::create(&docker_shim).expect("failed to create docker shim");
|
|
file
|
|
.write_all(b"#!/bin/sh\nexec podman \"$@\"")
|
|
.expect("failed to write shim");
|
|
|
|
let mut perms = file.metadata().unwrap().permissions();
|
|
perms.set_mode(0o755);
|
|
file.set_permissions(perms).unwrap();
|
|
|
|
let path = std::env::var("PATH").unwrap_or_default();
|
|
let new_path = format!("{}:{}", temp_path.display(), path);
|
|
std::env::set_var("PATH", new_path);
|
|
|
|
// Persist the temp dir
|
|
let _ = temp_dir.keep();
|
|
|
|
// Save for cleanup
|
|
if let Ok(mut info) = CLEANUP_INFO.lock() {
|
|
info.temp_dir = Some(temp_path);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
// Stream Trait Alias
|
|
pub trait TestStream: AsyncRead + AsyncWrite + Unpin + Send {}
|
|
impl<T: AsyncRead + AsyncWrite + Unpin + Send> TestStream for T {}
|
|
|
|
// Database Helpers
|
|
|
|
pub fn get_db_client(port: u16, db_name: &str) -> Client {
|
|
let url = format!("http://127.0.0.1:{}", port);
|
|
Client::default()
|
|
.with_url(&url)
|
|
.with_user("traudit")
|
|
.with_password("traudit")
|
|
.with_database(db_name)
|
|
}
|
|
|
|
pub async fn wait_for_clickhouse(port: u16) {
|
|
let client = get_db_client(port, "default");
|
|
let mut attempts = 0;
|
|
while attempts < 30 {
|
|
if client.query("SELECT 1").execute().await.is_ok() {
|
|
return;
|
|
}
|
|
tokio::time::sleep(Duration::from_millis(1000)).await;
|
|
attempts += 1;
|
|
}
|
|
panic!("ClickHouse failed to start on port {}", port);
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, clickhouse::Row)]
|
|
pub struct TcpLogCount {
|
|
pub count: u64,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, clickhouse::Row)]
|
|
pub struct HttpLogCount {
|
|
pub count: u64,
|
|
}
|
|
|
|
// Mock Upstream
|
|
|
|
pub async fn spawn_mock_upstream() -> (SocketAddr, JoinHandle<()>) {
|
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
let addr = listener.local_addr().unwrap();
|
|
|
|
let handle = tokio::spawn(async move {
|
|
loop {
|
|
if let Ok((mut socket, _)) = listener.accept().await {
|
|
tokio::spawn(async move {
|
|
let mut buf = [0u8; 4096];
|
|
loop {
|
|
let n = match socket.read(&mut buf).await {
|
|
Ok(0) => return,
|
|
Ok(n) => n,
|
|
Err(_) => return,
|
|
};
|
|
let data = &buf[..n];
|
|
|
|
if data.starts_with(b"GET") || data.starts_with(b"POST") {
|
|
let req = String::from_utf8_lossy(data);
|
|
let body = format!("Headers:\n{}", req);
|
|
let response = format!(
|
|
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
|
|
body.len(),
|
|
body
|
|
);
|
|
let _ = socket.write_all(response.as_bytes()).await;
|
|
return;
|
|
} else {
|
|
let _ = socket.write_all(data).await;
|
|
}
|
|
}
|
|
});
|
|
}
|
|
}
|
|
});
|
|
|
|
(addr, handle)
|
|
}
|
|
|
|
// Protocol Helpers

/// Returns a PROXY protocol v1 (text) header for a TCP/IPv4 connection from
/// `1.1.1.1:1234` to `2.2.2.2:443`, terminated by CRLF.
pub fn build_proxy_v1_header() -> Vec<u8> {
    let line: &str = "PROXY TCP4 1.1.1.1 2.2.2.2 1234 443\r\n";
    line.as_bytes().to_vec()
}
|
|
|
|
/// Returns a PROXY protocol v2 (binary) header for TCP over IPv4 from
/// `1.1.1.1:1234` to `2.2.2.2:443`.
///
/// Layout, in order:
/// - the 12-byte signature;
/// - the version/command byte `0x21` (v2, PROXY);
/// - the family/protocol byte `0x11` (INET, STREAM);
/// - the big-endian address-block length (12);
/// - source address, destination address, source port, destination port.
pub fn build_proxy_v2_header() -> Vec<u8> {
    const SIGNATURE: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];
    const VERSION_COMMAND: u8 = 0x21;
    const FAMILY_PROTOCOL: u8 = 0x11;

    let src_ip = [1u8, 1, 1, 1];
    let dst_ip = [2u8, 2, 2, 2];
    let src_port: u16 = 1234;
    let dst_port: u16 = 443;
    let addr_len: u16 = 12;

    let mut out = Vec::with_capacity(16 + usize::from(addr_len));
    out.extend_from_slice(&SIGNATURE);
    out.push(VERSION_COMMAND);
    out.push(FAMILY_PROTOCOL);
    out.extend_from_slice(&addr_len.to_be_bytes());
    out.extend_from_slice(&src_ip);
    out.extend_from_slice(&dst_ip);
    out.extend_from_slice(&src_port.to_be_bytes());
    out.extend_from_slice(&dst_port.to_be_bytes());
    out
}
|
|
|
|
// TLS Helpers
|
|
|
|
pub struct CertBundle {
|
|
pub cert_pem: String,
|
|
pub key_pem: String,
|
|
}
|
|
|
|
pub fn generate_cert() -> CertBundle {
|
|
let subject_alt_names = vec!["localhost".to_string(), "127.0.0.1".to_string()];
|
|
let certified_key = generate_simple_self_signed(subject_alt_names).unwrap();
|
|
CertBundle {
|
|
cert_pem: certified_key.cert.pem(),
|
|
key_pem: certified_key.signing_key.serialize_pem(),
|
|
}
|
|
}
|
|
|
|
// Config Builder
|
|
|
|
pub struct TestResources {
|
|
pub config: Config,
|
|
pub proxy_addr: String,
|
|
#[allow(dead_code)]
|
|
pub upstream_addr: SocketAddr,
|
|
pub _cert_file: Option<tempfile::NamedTempFile>,
|
|
pub _key_file: Option<tempfile::NamedTempFile>,
|
|
}
|
|
|
|
#[allow(clippy::too_many_arguments)]
|
|
pub async fn prepare_env(
|
|
service_type: &str,
|
|
proxy_proto: Option<&str>,
|
|
bind_tls: bool,
|
|
real_ip_source: Option<RealIpSource>,
|
|
add_xff: bool,
|
|
is_unix: bool,
|
|
db_port: u16,
|
|
db_name: String,
|
|
) -> TestResources {
|
|
let (upstream_addr, _) = spawn_mock_upstream().await;
|
|
|
|
let (bind_addr, port_guard, socket_path) = if is_unix {
|
|
let path = format!("/tmp/traudit_test_{}.sock", rand::random::<u64>());
|
|
let _ = std::fs::remove_file(&path);
|
|
(format!("unix://{}", path), None, Some(path))
|
|
} else {
|
|
let l = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
let p = l.local_addr().unwrap().port();
|
|
(format!("127.0.0.1:{}", p), Some(l), None)
|
|
};
|
|
drop(port_guard);
|
|
|
|
let (tls_config, cert_f, key_f) = if bind_tls {
|
|
let bundle = generate_cert();
|
|
let mut cf = tempfile::NamedTempFile::new().unwrap();
|
|
cf.write_all(bundle.cert_pem.as_bytes()).unwrap();
|
|
let mut kf = tempfile::NamedTempFile::new().unwrap();
|
|
kf.write_all(bundle.key_pem.as_bytes()).unwrap();
|
|
(
|
|
Some(TlsConfig {
|
|
cert: cf.path().to_str().unwrap().to_string(),
|
|
key: Some(kf.path().to_str().unwrap().to_string()),
|
|
}),
|
|
Some(cf),
|
|
Some(kf),
|
|
)
|
|
} else {
|
|
(None, None, None)
|
|
};
|
|
|
|
let real_ip = real_ip_source.map(|s| RealIpConfig {
|
|
source: s,
|
|
trusted_proxies: vec![],
|
|
trust_private_ranges: true,
|
|
xff_trust_depth: 0,
|
|
});
|
|
|
|
let config = Config {
|
|
database: DatabaseConfig {
|
|
db_type: "clickhouse".to_string(),
|
|
dsn: format!("http://traudit:traudit@127.0.0.1:{}/{}", db_port, db_name),
|
|
batch_size: 1,
|
|
batch_timeout_secs: 1,
|
|
},
|
|
services: vec![ServiceConfig {
|
|
name: "test-svc".to_string(),
|
|
service_type: service_type.to_string(),
|
|
forward_to: Some(upstream_addr.to_string()),
|
|
upstream_proxy: None,
|
|
binds: vec![BindEntry {
|
|
addr: bind_addr.clone(),
|
|
mode: 0o777,
|
|
proxy: proxy_proto.map(|s| s.to_string()),
|
|
tls: tls_config,
|
|
real_ip,
|
|
add_xff_header: add_xff,
|
|
redirect_https: RedirectHttpsConfig::default(),
|
|
}],
|
|
}],
|
|
};
|
|
|
|
let proxy_addr_clean = if is_unix {
|
|
socket_path.unwrap()
|
|
} else {
|
|
bind_addr
|
|
};
|
|
|
|
TestResources {
|
|
config,
|
|
proxy_addr: proxy_addr_clean,
|
|
upstream_addr,
|
|
_cert_file: cert_f,
|
|
_key_file: key_f,
|
|
}
|
|
}
|
|
|
|
pub async fn run_tcp_test(test_name: &str, proxy_proto: Option<&str>, is_unix: bool) {
|
|
// init_env called inside Lazy
|
|
let db_port = get_shared_db_port().await;
|
|
let db_name = test_name.to_string();
|
|
|
|
// Create DB
|
|
let system_client = get_db_client(db_port, "default");
|
|
system_client
|
|
.query(&format!("CREATE DATABASE IF NOT EXISTS {}", db_name))
|
|
.execute()
|
|
.await
|
|
.unwrap();
|
|
|
|
let client = get_db_client(db_port, &db_name);
|
|
|
|
let res = prepare_env(
|
|
"tcp",
|
|
proxy_proto,
|
|
false,
|
|
proxy_proto.map(|_| RealIpSource::ProxyProtocol),
|
|
false,
|
|
is_unix,
|
|
db_port,
|
|
db_name,
|
|
)
|
|
.await;
|
|
|
|
tokio::spawn(async move {
|
|
let (tx, _rx) = tokio::sync::broadcast::channel(1);
|
|
let _ = traudit::core::server::run(res.config, tx).await;
|
|
});
|
|
// Wait for traudit startup and DB connect
|
|
tokio::time::sleep(Duration::from_millis(1000)).await;
|
|
|
|
let mut stream: Box<dyn TestStream> = if is_unix {
|
|
Box::new(
|
|
UnixStream::connect(&res.proxy_addr)
|
|
.await
|
|
.expect("Unix connect failed"),
|
|
)
|
|
} else {
|
|
Box::new(
|
|
TcpStream::connect(&res.proxy_addr)
|
|
.await
|
|
.expect("Tcp connect failed"),
|
|
)
|
|
};
|
|
|
|
if let Some(p) = proxy_proto {
|
|
if p == "v1" {
|
|
stream.write_all(&build_proxy_v1_header()).await.unwrap();
|
|
} else {
|
|
stream.write_all(&build_proxy_v2_header()).await.unwrap();
|
|
}
|
|
}
|
|
|
|
stream.write_all(b"ping").await.unwrap();
|
|
let mut buf = [0u8; 1024];
|
|
let n = stream.read(&mut buf).await.unwrap();
|
|
assert_eq!(&buf[..n], b"ping");
|
|
|
|
drop(stream);
|
|
tokio::time::sleep(Duration::from_millis(2000)).await;
|
|
|
|
let count = client
|
|
.query("SELECT count() as count FROM tcp_log WHERE service = 'test-svc'")
|
|
.fetch_one::<TcpLogCount>()
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(count.count, 1);
|
|
|
|
if is_unix {
|
|
let _ = std::fs::remove_file(&res.proxy_addr);
|
|
}
|
|
}
|
|
|
|
pub async fn run_http_test(
|
|
test_name: &str,
|
|
proxy_proto: Option<&str>,
|
|
use_tls: bool,
|
|
real_ip_source: Option<RealIpSource>,
|
|
add_xff: bool,
|
|
expected_xff_in_upstream: Option<&str>,
|
|
) {
|
|
let db_port = get_shared_db_port().await;
|
|
let db_name = test_name.to_string();
|
|
|
|
let system_client = get_db_client(db_port, "default");
|
|
system_client
|
|
.query(&format!("CREATE DATABASE IF NOT EXISTS {}", db_name))
|
|
.execute()
|
|
.await
|
|
.unwrap();
|
|
|
|
let client = get_db_client(db_port, &db_name);
|
|
|
|
let res = prepare_env(
|
|
"http",
|
|
proxy_proto,
|
|
use_tls,
|
|
real_ip_source.clone(),
|
|
add_xff,
|
|
false,
|
|
db_port,
|
|
db_name,
|
|
)
|
|
.await;
|
|
|
|
tokio::spawn(async move {
|
|
let (tx, _rx) = tokio::sync::broadcast::channel(1);
|
|
let _ = traudit::core::server::run(res.config, tx).await;
|
|
});
|
|
tokio::time::sleep(Duration::from_millis(1000)).await;
|
|
|
|
if proxy_proto.is_some() {
|
|
let mut stream = TcpStream::connect(&res.proxy_addr).await.unwrap();
|
|
if proxy_proto == Some("v1") {
|
|
stream.write_all(&build_proxy_v1_header()).await.unwrap();
|
|
} else {
|
|
stream.write_all(&build_proxy_v2_header()).await.unwrap();
|
|
}
|
|
|
|
if use_tls {
|
|
let mut root_store = rustls::RootCertStore::empty();
|
|
let cert_bytes = tokio::fs::read(res._cert_file.as_ref().unwrap().path())
|
|
.await
|
|
.unwrap();
|
|
let mut pem = std::io::BufReader::new(&cert_bytes[..]);
|
|
let certs = rustls_pemfile::certs(&mut pem)
|
|
.map(|c| c.unwrap())
|
|
.collect::<Vec<_>>();
|
|
root_store.add(certs[0].clone()).unwrap();
|
|
|
|
let config = rustls::ClientConfig::builder()
|
|
.with_root_certificates(root_store)
|
|
.with_no_client_auth();
|
|
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
|
|
let domain = rustls::pki_types::ServerName::try_from("localhost").unwrap();
|
|
let mut tls_stream = connector
|
|
.connect(domain, stream)
|
|
.await
|
|
.expect("TLS handshake failed");
|
|
|
|
let request = if add_xff || real_ip_source == Some(RealIpSource::Xff) {
|
|
if real_ip_source == Some(RealIpSource::Xff) {
|
|
"GET / HTTP/1.1\r\nHost: localhost\r\nX-Forwarded-For: 8.8.8.8\r\nConnection: close\r\n\r\n"
|
|
} else {
|
|
"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
|
|
}
|
|
} else {
|
|
"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
|
|
};
|
|
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
|
|
|
let mut buf = Vec::new();
|
|
tls_stream.read_to_end(&mut buf).await.unwrap();
|
|
let body = String::from_utf8_lossy(&buf).to_string();
|
|
|
|
if let Some(expected) = expected_xff_in_upstream {
|
|
assert!(body.contains(expected), "Body: {}", body);
|
|
}
|
|
assert!(body.contains("200 OK"));
|
|
} else {
|
|
let request = if real_ip_source == Some(RealIpSource::Xff) {
|
|
"GET / HTTP/1.1\r\nHost: localhost\r\nX-Forwarded-For: 8.8.8.8\r\nConnection: close\r\n\r\n"
|
|
} else {
|
|
"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
|
|
};
|
|
stream.write_all(request.as_bytes()).await.unwrap();
|
|
let mut buf = Vec::new();
|
|
stream.read_to_end(&mut buf).await.unwrap();
|
|
let body = String::from_utf8_lossy(&buf).to_string();
|
|
if let Some(expected) = expected_xff_in_upstream {
|
|
assert!(body.contains(expected));
|
|
}
|
|
assert!(body.contains("200 OK"));
|
|
}
|
|
} else {
|
|
let client_builder = reqwest::Client::builder();
|
|
let client_http = if use_tls {
|
|
let cert_bytes = tokio::fs::read(res._cert_file.as_ref().unwrap().path())
|
|
.await
|
|
.unwrap();
|
|
let cert = reqwest::Certificate::from_pem(&cert_bytes).unwrap();
|
|
client_builder
|
|
.add_root_certificate(cert)
|
|
.danger_accept_invalid_certs(true)
|
|
.build()
|
|
.unwrap()
|
|
} else {
|
|
client_builder.build().unwrap()
|
|
};
|
|
let protocol = if use_tls { "https" } else { "http" };
|
|
let url = format!("{}://{}/", protocol, res.proxy_addr);
|
|
let resp = if real_ip_source == Some(RealIpSource::Xff) {
|
|
client_http
|
|
.get(&url)
|
|
.header("X-Forwarded-For", "8.8.8.8")
|
|
.send()
|
|
.await
|
|
.expect("Req failed")
|
|
} else {
|
|
client_http.get(&url).send().await.expect("Req failed")
|
|
};
|
|
assert_eq!(resp.status(), 200);
|
|
let body = resp.text().await.unwrap();
|
|
if let Some(expected) = expected_xff_in_upstream {
|
|
assert!(body.contains(expected));
|
|
}
|
|
}
|
|
|
|
tokio::time::sleep(Duration::from_millis(2000)).await;
|
|
let count = client
|
|
.query("SELECT count() as count FROM http_log WHERE service = 'test-svc'")
|
|
.fetch_one::<HttpLogCount>()
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(count.count, 1);
|
|
}
|