refactor: optimize configuration structure and logging, add config test flag

This commit is contained in:
2026-01-15 21:57:13 +08:00
parent ab67b5f45a
commit cccbd80078
6 changed files with 259 additions and 174 deletions

View File

@@ -1,42 +1,32 @@
# Traudit Configuration Example
database:
dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db"
batch:
size: 50
timeout_secs: 5
type: clickhouse
dsn: "http://user:password@ip:port"
batch_size: 50
batch_timeout_secs: 5
tables:
tcp: tcp_log
services:
# Scenario: SSH High-Level Audit Gateway
# Receives traffic from FRP with v2 Proxy Protocol header, audits it,
# strips the header, and forwards pure TCP to local SSHD.
- name: "ssh-prod"
db_table: "ssh_audit_logs"
- name: "ssh"
type: "tcp"
binds:
# Entry 1: Public traffic from FRP
- type: "tcp"
addr: "0.0.0.0:1222"
proxy_protocol: "v2"
# Entry 2: LAN direct traffic (no Proxy Protocol)
- type: "tcp"
addr: "0.0.0.0:1223"
forward_type: "tcp"
forward_addr: "127.0.0.1:22"
# forward_proxy_protocol omitted, sends pure stream to SSHD
- addr: "0.0.0.0:2222"
proxy: "v2"
# Scenario: Protocol Conversion and Local Socket Forwarding
# Receives normal TCP traffic, converts to v1 Proxy Protocol header,
# and forwards to local Unix socket (Nginx).
- name: "web-gateway"
db_table: "http_access_audit"
# Entry 2: LAN direct traffic (no Proxy Protocol)
- addr: "0.0.0.0:2233"
# no proxy
forward_to: "127.0.0.1:22"
- name: "web"
type: "tcp"
binds:
- type: "tcp"
addr: "0.0.0.0:8080"
- addr: "0.0.0.0:8080"
forward_type: "unix"
forward_addr: "/run/nginx/web.sock"
forward_proxy_protocol: "v1"
forward_to: "/run/nginx/web.sock"

View File

@@ -1,4 +1,5 @@
use serde::Deserialize;
use std::collections::HashMap;
use std::path::Path;
use tokio::fs;
@@ -10,61 +11,42 @@ pub struct Config {
#[derive(Debug, Deserialize, Clone)]
pub struct DatabaseConfig {
pub dsn: String,
#[serde(rename = "type")]
#[allow(dead_code)]
pub batch: BatchConfig,
pub db_type: String,
pub dsn: String,
pub tables: HashMap<String, String>,
#[serde(default = "default_batch_size")]
#[allow(dead_code)]
pub batch_size: usize,
#[serde(default = "default_timeout_secs")]
#[allow(dead_code)]
pub batch_timeout_secs: u64,
}
#[derive(Debug, Deserialize, Clone)]
pub struct BatchConfig {
#[allow(dead_code)]
pub size: usize,
#[allow(dead_code)]
pub timeout_secs: u64,
/// Fallback for `DatabaseConfig::batch_size` when the key is absent from the
/// config file (wired up through `#[serde(default = "default_batch_size")]`).
fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 1000;
    DEFAULT_BATCH_SIZE
}
/// Fallback for `DatabaseConfig::batch_timeout_secs` when the key is absent
/// from the config file (wired up through
/// `#[serde(default = "default_timeout_secs")]`).
fn default_timeout_secs() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 5;
    DEFAULT_TIMEOUT_SECS
}
#[derive(Debug, Deserialize, Clone)]
pub struct ServiceConfig {
pub name: String,
#[allow(dead_code)]
pub db_table: String,
pub binds: Vec<BindConfig>,
pub forward_type: ForwardType,
pub forward_addr: String,
#[allow(dead_code)]
pub forward_proxy_protocol: Option<ProxyProtocolVersion>,
#[serde(rename = "type")]
pub service_type: String,
pub binds: Vec<BindEntry>,
#[serde(rename = "forward_to")]
pub forward_to: String,
}
#[derive(Debug, Deserialize, Clone)]
pub struct BindConfig {
#[serde(rename = "type")]
pub bind_type: BindType,
pub struct BindEntry {
pub addr: String,
#[serde(alias = "proxy")]
pub proxy_protocol: Option<ProxyProtocolVersion>,
}
#[derive(Debug, Deserialize, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum BindType {
Tcp,
Udp,
Unix,
}
#[derive(Debug, Deserialize, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ForwardType {
Tcp,
Udp,
Unix,
}
#[derive(Debug, Deserialize, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ProxyProtocolVersion {
V1,
V2,
#[serde(alias = "proxy_protocol", rename = "proxy")]
pub proxy: Option<String>,
}
impl Config {
@@ -84,26 +66,24 @@ mod tests {
async fn test_load_config() {
let config_str = r#"
database:
type: clickhouse
dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db"
batch:
size: 50
timeout_secs: 5
batch_size: 50
batch_timeout_secs: 5
tables:
tcp: tcp_log
services:
- name: "ssh-prod"
db_table: "ssh_audit_logs"
type: "tcp"
binds:
- type: "tcp"
addr: "0.0.0.0:22222"
proxy_protocol: "v2"
forward_type: "tcp"
forward_addr: "127.0.0.1:22"
- addr: "0.0.0.0:22222"
proxy: "v2"
forward_to: "127.0.0.1:22"
"#;
let mut file = tempfile::NamedTempFile::new().unwrap();
write!(file, "{}", config_str).unwrap();
let path = file.path().to_path_buf();
// `tempfile` deletes the file when the handle is dropped, so keep `file` alive
// and let `Config::load` read it by path while the handle is still open.
let config = Config::load(&path).await.expect("Failed to load config");
@@ -113,10 +93,8 @@ services:
);
assert_eq!(config.services.len(), 1);
assert_eq!(config.services[0].name, "ssh-prod");
assert_eq!(config.services[0].binds[0].bind_type, BindType::Tcp);
assert_eq!(
config.services[0].binds[0].proxy_protocol,
Some(ProxyProtocolVersion::V2)
);
assert_eq!(config.services[0].binds[0].addr, "0.0.0.0:22222");
assert_eq!(config.services[0].binds[0].proxy, Some("v2".to_string()));
assert_eq!(config.services[0].forward_to, "127.0.0.1:22");
}
}

View File

@@ -1,4 +1,4 @@
use crate::config::{BindType, Config, ServiceConfig};
use crate::config::{Config, ServiceConfig};
use crate::core::forwarder;
use crate::core::upstream::UpstreamStream;
use crate::db::clickhouse::ClickHouseLogger;
@@ -10,11 +10,21 @@ use tokio::signal;
use tracing::{error, info};
pub async fn run(config: Config) -> anyhow::Result<()> {
let db = Arc::new(ClickHouseLogger::new(&config.database));
let db_logger = ClickHouseLogger::new(&config.database).map_err(|e| {
error!("database: {}", e);
e
})?;
let db = Arc::new(db_logger);
// init db table
// init db table
if let Err(e) = db.init().await {
error!("failed to init database: {}", e);
// Database errors can be very long, so the logged message is capped at
// roughly 200 bytes.
let msg = e.to_string();
if msg.len() > 200 {
    // Slicing a `str` at a byte offset that falls inside a multi-byte UTF-8
    // character panics, so back off to the nearest char boundary at or
    // below 200. Offset 0 is always a boundary, so a cut point always exists.
    let end = (0..=200)
        .rev()
        .find(|&i| msg.is_char_boundary(i))
        .unwrap_or(0);
    error!("failed to init database: {}... (truncated)", &msg[..end]);
} else {
    error!("failed to init database: {}", msg);
}
return Err(e);
}
@@ -22,32 +32,36 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
for service in config.services {
let db = db.clone();
// Only the "tcp" service type is supported for now; any other type is skipped.
if service.service_type != "tcp" {
info!("skipping non-tcp service: {}", service.name);
continue;
}
for bind in &service.binds {
let service_config = service.clone();
let bind_addr = bind.addr.clone();
let proxy_protocol = bind.proxy_protocol.is_some();
let bind_type = bind.bind_type;
// proxy is now Option<String>
let proxy_proto_config = bind.proxy.clone();
if bind_type == BindType::Tcp {
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
error!(
"[{}] failed to bind {}: {}",
service_config.name, bind_addr, e
);
e
})?;
// BindType is removed, assume TCP bind for "tcp" service
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
error!(
"[{}] failed to bind {}: {}",
service_config.name, bind_addr, e
);
e
})?;
info!("[{}] listening on tcp {}", service_config.name, bind_addr);
info!("[{}] listening on tcp {}", service_config.name, bind_addr);
join_set.spawn(start_tcp_service(
service_config,
listener,
proxy_protocol,
db.clone(),
));
} else {
info!("skipping non-tcp bind for now: {:?}", bind_type);
}
join_set.spawn(start_tcp_service(
service_config,
listener,
proxy_proto_config,
db.clone(),
));
}
}
@@ -79,22 +93,53 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
async fn start_tcp_service(
service: ServiceConfig,
listener: TcpListener,
proxy_protocol: bool,
proxy_cfg: Option<String>,
db: Arc<ClickHouseLogger>,
) {
// Startup liveness check
if let Err(e) = UpstreamStream::connect(&service.forward_to).await {
match e.kind() {
std::io::ErrorKind::ConnectionRefused => {
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
}
std::io::ErrorKind::NotFound => {
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
}
_ => {
// Any other error kind is also logged as a warning, with an explicit "startup check failed" note.
tracing::warn!(
"[{}] -> '{}': startup check failed: {}",
service.name,
service.forward_to,
e
);
}
}
}
loop {
match listener.accept().await {
Ok((inbound, _client_addr)) => {
let service = service.clone();
let db = db.clone();
let proxy_cfg = proxy_cfg.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(inbound, service, proxy_protocol, db).await {
let svc_name = service.name.clone();
let svc_target = service.forward_to.clone();
if let Err(e) = handle_connection(inbound, service, proxy_cfg, db).await {
match e.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
// normal disconnects, debug log only
tracing::debug!("connection closed: {}", e);
}
std::io::ErrorKind::ConnectionRefused => {
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
}
std::io::ErrorKind::NotFound => {
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
}
_ => {
error!("connection error: {}", e);
}
@@ -112,7 +157,7 @@ async fn start_tcp_service(
async fn handle_connection(
mut inbound: tokio::net::TcpStream,
service: ServiceConfig,
proxy_protocol: bool,
proxy_cfg: Option<String>,
db: Arc<ClickHouseLogger>,
) -> std::io::Result<u64> {
let conn_ts = time::OffsetDateTime::now_utc();
@@ -127,7 +172,9 @@ async fn handle_connection(
// read proxy protocol (if configured)
let mut buffer = bytes::BytesMut::new();
if proxy_protocol {
if proxy_cfg.is_some() {
// If configured, we attempt to read.
// Strict V2/V1 check can be implemented if needed, but here we just use the parser.
match protocol::read_proxy_header(&mut inbound).await {
Ok((proxy_info, buf)) => {
buffer = buf;
@@ -140,6 +187,19 @@ async fn handle_connection(
protocol::Version::V1 => crate::db::clickhouse::ProxyProto::V1,
protocol::Version::V2 => crate::db::clickhouse::ProxyProto::V2,
};
// Optional: verify version matches config if strictly required
if let Some(ref required_ver) = proxy_cfg {
match required_ver.as_str() {
"v1" if info.version != protocol::Version::V1 => {
// warn mismatch?
}
"v2" if info.version != protocol::Version::V2 => {
// warn mismatch?
}
_ => {}
}
}
} else {
// Strict enforcement: if configured with proxy_protocol, MUST have a header
let physical = inbound.peer_addr()?;
@@ -156,7 +216,7 @@ async fn handle_connection(
}
// connect upstream
let mut upstream = UpstreamStream::connect(service.forward_type, &service.forward_addr).await?;
let mut upstream = UpstreamStream::connect(&service.forward_to).await?;
// write buffered data (peeked bytes)
if !buffer.is_empty() {

View File

@@ -1,4 +1,3 @@
use crate::config::ForwardType;
use std::io;
use std::os::unix::io::{AsRawFd, RawFd};
use std::pin::Pin;
@@ -13,21 +12,14 @@ pub enum UpstreamStream {
}
impl UpstreamStream {
pub async fn connect(fw_type: ForwardType, addr: &str) -> io::Result<Self> {
match fw_type {
ForwardType::Tcp => {
let stream = TcpStream::connect(addr).await?;
stream.set_nodelay(true)?;
Ok(UpstreamStream::Tcp(stream))
}
ForwardType::Unix => {
let stream = UnixStream::connect(addr).await?;
Ok(UpstreamStream::Unix(stream))
}
ForwardType::Udp => Err(io::Error::new(
io::ErrorKind::Unsupported,
"UDP forwarding not yet implemented in stream context",
)),
/// Opens a connection to the upstream identified by `addr`.
///
/// An `addr` that begins with `/` is connected to as a Unix socket path; any
/// other value is passed to `TcpStream::connect`, after which
/// `set_nodelay(true)` is applied to the new TCP stream.
///
/// # Errors
///
/// Propagates the `io::Error` returned by the connect call or by
/// `set_nodelay`.
pub async fn connect(addr: &str) -> io::Result<Self> {
    // A leading slash is the only marker used to select the Unix-socket path.
    if addr.starts_with('/') {
        return UnixStream::connect(addr).await.map(UpstreamStream::Unix);
    }

    let tcp = TcpStream::connect(addr).await?;
    tcp.set_nodelay(true)?;
    Ok(UpstreamStream::Tcp(tcp))
}
}

View File

@@ -50,11 +50,12 @@ struct TcpLogV6 {
pub struct ClickHouseLogger {
client: Client,
table_base: String,
}
impl ClickHouseLogger {
pub fn new(config: &DatabaseConfig) -> Self {
let url = url::Url::parse(&config.dsn).expect("invalid dsn");
pub fn new(config: &DatabaseConfig) -> anyhow::Result<Self> {
let url = url::Url::parse(&config.dsn).map_err(|e| anyhow::anyhow!("invalid dsn: {}", e))?;
let mut client = Client::default().with_url(url.as_str());
if let (Some(u), Some(p)) = (Some(url.username()), url.password()) {
@@ -73,12 +74,25 @@ impl ClickHouseLogger {
}
}
Self { client }
// Config table name, default to "tcp_log" if missing
// We expect config.tables to contain "tcp" -> "tablename"
let table_base = config
.tables
.get("tcp")
.cloned()
.unwrap_or_else(|| "tcp_log".to_string());
Ok(Self { client, table_base })
}
pub async fn init(&self) -> anyhow::Result<()> {
let sql_v4 = r#"
CREATE TABLE IF NOT EXISTS tcp_log_v4 (
let table_v4 = format!("{}_v4", self.table_base);
let table_v6 = format!("{}_v6", self.table_base);
let view_name = &self.table_base;
let sql_v4 = format!(
r#"
CREATE TABLE IF NOT EXISTS {} (
service LowCardinality(String),
conn_ts DateTime('UTC'),
duration UInt32,
@@ -88,10 +102,13 @@ impl ClickHouseLogger {
bytes UInt64
) ENGINE = MergeTree()
ORDER BY (service, conn_ts);
"#;
"#,
table_v4
);
let sql_v6 = r#"
CREATE TABLE IF NOT EXISTS tcp_log_v6 (
let sql_v6 = format!(
r#"
CREATE TABLE IF NOT EXISTS {} (
service LowCardinality(String),
conn_ts DateTime('UTC'),
duration UInt32,
@@ -101,30 +118,35 @@ impl ClickHouseLogger {
bytes UInt64
) ENGINE = MergeTree()
ORDER BY (service, conn_ts);
"#;
"#,
table_v6
);
let drop_view = "DROP VIEW IF EXISTS tcp_log";
let drop_view = format!("DROP VIEW IF EXISTS {}", view_name);
let sql_view = r#"
CREATE VIEW tcp_log AS
let sql_view = format!(
r#"
CREATE VIEW {} AS
SELECT
service, conn_ts, duration, port,
IPv4NumToString(ip) AS ip_str,
proxy_proto,
formatReadableSize(bytes) AS traffic
FROM tcp_log_v4
FROM {}
UNION ALL
SELECT
service, conn_ts, duration, port,
IPv6NumToString(ip) AS ip_str,
proxy_proto,
formatReadableSize(bytes) AS traffic
FROM tcp_log_v6;
"#;
FROM {};
"#,
view_name, table_v4, table_v6
);
self
.client
.query(sql_v4)
.query(&sql_v4)
.execute()
.await
.map_err(|e| anyhow::anyhow!("failed to create v4 table: {}", e))?;
@@ -132,57 +154,75 @@ impl ClickHouseLogger {
// Migrations
let _ = self
.client
.query("ALTER TABLE tcp_log_v4 RENAME COLUMN IF EXISTS bytes_transferred TO bytes")
.query(&format!(
"ALTER TABLE {} RENAME COLUMN IF EXISTS bytes_transferred TO bytes",
table_v4
))
.execute()
.await;
let _ = self
.client
.query("ALTER TABLE tcp_log_v4 RENAME COLUMN IF EXISTS traffic TO bytes")
.query(&format!(
"ALTER TABLE {} RENAME COLUMN IF EXISTS traffic TO bytes",
table_v4
))
.execute()
.await;
let _ = self
.client
.query("ALTER TABLE tcp_log_v4 ADD COLUMN IF NOT EXISTS bytes UInt64")
.query(&format!(
"ALTER TABLE {} ADD COLUMN IF NOT EXISTS bytes UInt64",
table_v4
))
.execute()
.await;
self
.client
.query(sql_v6)
.query(&sql_v6)
.execute()
.await
.map_err(|e| anyhow::anyhow!("failed to create v6 table: {}", e))?;
let _ = self
.client
.query("ALTER TABLE tcp_log_v6 RENAME COLUMN IF EXISTS bytes_transferred TO bytes")
.query(&format!(
"ALTER TABLE {} RENAME COLUMN IF EXISTS bytes_transferred TO bytes",
table_v6
))
.execute()
.await;
let _ = self
.client
.query("ALTER TABLE tcp_log_v6 RENAME COLUMN IF EXISTS traffic TO bytes")
.query(&format!(
"ALTER TABLE {} RENAME COLUMN IF EXISTS traffic TO bytes",
table_v6
))
.execute()
.await;
let _ = self
.client
.query("ALTER TABLE tcp_log_v6 ADD COLUMN IF NOT EXISTS bytes UInt64")
.query(&format!(
"ALTER TABLE {} ADD COLUMN IF NOT EXISTS bytes UInt64",
table_v6
))
.execute()
.await;
self
.client
.query(drop_view)
.query(&drop_view)
.execute()
.await
.map_err(|e| anyhow::anyhow!("failed to drop view: {}", e))?;
self
.client
.query(sql_view)
.query(&sql_view)
.execute()
.await
.map_err(|e| anyhow::anyhow!("failed to create view: {}", e))?;
info!("ensured tables and view exist");
info!("connected to database");
Ok(())
}
@@ -198,7 +238,8 @@ impl ClickHouseLogger {
proxy_proto: log.proxy_proto,
bytes: log.bytes,
};
let mut insert = self.client.insert("tcp_log_v4")?;
let table = format!("{}_v4", self.table_base);
let mut insert = self.client.insert(&table)?;
insert.write(&row).await?;
insert.end().await?;
}
@@ -212,7 +253,8 @@ impl ClickHouseLogger {
proxy_proto: log.proxy_proto,
bytes: log.bytes,
};
let mut insert = self.client.insert("tcp_log_v6")?;
let table = format!("{}_v6", self.table_base);
let mut insert = self.client.insert(&table)?;
insert.write(&row).await?;
insert.end().await?;
}

View File

@@ -17,14 +17,23 @@ fn print_help() {
println!();
println!("options:");
println!(" -f <config_file> path to the yaml configuration file");
println!(" -t, --test test configuration and exit");
println!(" -h, --help print this help message");
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_target(false)
.with_thread_ids(false)
.with_file(false)
.with_line_number(false)
.init();
let args: Vec<String> = env::args().collect();
let mut config_path = None;
let mut test_config = false;
let mut i = 1;
while i < args.len() {
@@ -37,6 +46,10 @@ async fn main() -> anyhow::Result<()> {
bail!("missing value for -f");
}
}
"-t" | "--test" => {
test_config = true;
i += 1;
}
"-h" | "--help" => {
print_help();
return Ok(());
@@ -51,7 +64,8 @@ async fn main() -> anyhow::Result<()> {
Some(p) => {
let path = Path::new(&p);
if !path.exists() {
bail!("config file '{}' not found", p);
error!("config file '{}' not found", p);
std::process::exit(1);
}
std::fs::canonicalize(path)?
}
@@ -61,21 +75,30 @@ async fn main() -> anyhow::Result<()> {
}
};
tracing_subscriber::fmt()
.with_target(false)
.with_thread_ids(false)
.with_file(false)
.with_line_number(false)
.init();
info!("loading config from {}", config_path.display());
let config = Config::load(&config_path).await.map_err(|e| {
error!("failed to load config: {}", e);
e
})?;
let config = match Config::load(&config_path).await {
Ok(c) => c,
Err(e) => {
error!("failed to load config: {}", e);
std::process::exit(1);
}
};
core::server::run(config).await?;
if test_config {
// Validate database config
if let Err(e) = crate::db::clickhouse::ClickHouseLogger::new(&config.database) {
error!("configuration check failed: {}", e);
std::process::exit(1);
}
info!("configuration ok");
return Ok(());
}
if let Err(_e) = core::server::run(config).await {
std::process::exit(1);
}
Ok(())
}