feat: support unix socket binds and update db schema

This commit is contained in:
2026-01-16 14:27:28 +08:00
parent 40d3156e88
commit 887c864550
10 changed files with 589 additions and 219 deletions

2
Cargo.lock generated
View File

@@ -1014,7 +1014,7 @@ dependencies = [
[[package]]
name = "traudit"
version = "0.1.0"
version = "0.0.1"
dependencies = [
"anyhow",
"async-trait",

View File

@@ -1,6 +1,6 @@
[package]
name = "traudit"
version = "0.1.0"
version = "0.0.1"
edition = "2021"
authors = ["awfufu"]
description = "A reverse proxy with auditing capabilities."
@@ -29,7 +29,7 @@ tempfile = "3"
[profile.release]
opt-level = 3
lto = true
lto = "fat"
codegen-units = 1
panic = "abort"
strip = true

View File

@@ -24,7 +24,7 @@ See [config_example.yaml](config_example.yaml).
- [x] TCP Proxy & Zero-copy forwarding (`splice`)
- [x] Proxy Protocol V1/V2 parsing
- [ ] UDP Forwarding (Planned)
- [ ] Unix Socket Forwarding (Planned)
- [x] Unix Socket Forwarding
- [x] Database Integration
- [x] ClickHouse Adapter (Native Interface)
- [x] Traffic Accounting (Bytes/Bandwidth)

View File

@@ -24,7 +24,7 @@ traudit 是一个支持 TCP/UDP/Unix Socket 的反向代理程序,专注于连
- [x] TCP 代理与零拷贝转发 (`splice`)
- [x] Proxy Protocol V1/V2 解析
- [ ] UDP 转发 (计划中)
- [ ] Unix Socket 转发 (计划中)
- [x] Unix Socket 转发
- [x] 数据库集成
- [x] ClickHouse 适配器 (原生接口)
- [x] 流量统计 (字节数)

View File

@@ -2,11 +2,9 @@
database:
type: clickhouse
dsn: "http://user:password@ip:port"
dsn: "http://user:password@ip:port/traudit"
batch_size: 50
batch_timeout_secs: 5
tables:
tcp: tcp_log
services:
# Receives traffic from FRP with v2 Proxy Protocol header, audits it,
@@ -24,9 +22,3 @@ services:
forward_to: "127.0.0.1:22"
# - name: "web"
# type: "tcp"
# binds:
# - addr: "0.0.0.0:8080"
# forward_to: "/run/nginx/web.sock"

View File

@@ -1,5 +1,4 @@
use serde::Deserialize;
use std::collections::HashMap;
use serde::{Deserialize, Deserializer};
use std::path::Path;
use tokio::fs;
@@ -15,7 +14,6 @@ pub struct DatabaseConfig {
#[allow(dead_code)]
pub db_type: String,
pub dsn: String,
pub tables: HashMap<String, String>,
#[serde(default = "default_batch_size")]
#[allow(dead_code)]
pub batch_size: usize,
@@ -47,6 +45,41 @@ pub struct BindEntry {
pub addr: String,
#[serde(alias = "proxy_protocol", rename = "proxy")]
pub proxy: Option<String>,
#[serde(default = "default_socket_mode", deserialize_with = "deserialize_mode")]
pub mode: u32,
}
/// Serde default for a bind entry's `mode` when the configuration omits it.
///
/// Returns `0o600`.
fn default_socket_mode() -> u32 {
  const DEFAULT_MODE: u32 = 0o600;
  DEFAULT_MODE
}
fn deserialize_mode<'de, D>(deserializer: D) -> Result<u32, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum ModeValue {
Integer(u32),
String(String),
}
let value = ModeValue::deserialize(deserializer)?;
match value {
ModeValue::Integer(i) => {
// If user provides 666, they likely mean octal 0666.
// But in YAML `mode: 666` is decimal 666.
// The requirement says: "if user wrote integer (e.g. 666), process as octal"
// So we interpret the decimal value as a sequence of octal digits.
// e.g. decimal 666 -> octal 666 (which is decimal 438)
let s = i.to_string();
u32::from_str_radix(&s, 8).map_err(serde::de::Error::custom)
}
ModeValue::String(s) => {
// If string, parse as octal
u32::from_str_radix(&s, 8).map_err(serde::de::Error::custom)
}
}
}
impl Config {
@@ -70,8 +103,6 @@ database:
dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db"
batch_size: 50
batch_timeout_secs: 5
tables:
tcp: tcp_log
services:
- name: "ssh-prod"
@@ -97,4 +128,26 @@ services:
assert_eq!(config.services[0].binds[0].proxy, Some("v2".to_string()));
assert_eq!(config.services[0].forward_to, "127.0.0.1:22");
}
#[test]
fn test_mode_deserialization() {
  #[derive(Deserialize)]
  struct TestBind {
    #[serde(default = "default_socket_mode", deserialize_with = "deserialize_mode")]
    mode: u32,
  }

  // (yaml input, expected mode)
  let cases: [(&str, u32); 3] = [
    // Integer digits are read as octal: 666 -> 0o666 (438 decimal).
    ("mode: 666", 0o666),
    // Strings are parsed as octal: '600' -> 0o600 (384 decimal).
    ("mode: '600'", 0o600),
    // A missing field falls back to the default.
    ("{}", 0o600),
  ];

  for (yaml, expected) in cases {
    let parsed: TestBind = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(parsed.mode, expected);
  }
}
}

View File

@@ -3,9 +3,12 @@ use crate::core::forwarder;
use crate::core::upstream::UpstreamStream;
use crate::db::clickhouse::ClickHouseLogger;
use crate::protocol;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream};
use tokio::signal;
use tracing::{error, info};
@@ -16,7 +19,6 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
})?;
let db = Arc::new(db_logger);
// init db table
// init db table
if let Err(e) = db.init().await {
let msg = e.to_string();
@@ -29,6 +31,7 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
}
let mut join_set = tokio::task::JoinSet::new();
let mut socket_guards = Vec::new();
for service in config.services {
let db = db.clone();
@@ -44,24 +47,49 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
let bind_addr = bind.addr.clone();
// proxy is now Option<String>
let proxy_proto_config = bind.proxy.clone();
let mode = bind.mode;
// BindType is removed, assume TCP bind for "tcp" service
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
error!(
"[{}] failed to bind {}: {}",
service_config.name, bind_addr, e
if bind_addr.starts_with("unix://") {
let path = bind_addr.trim_start_matches("unix://");
// bind_robust handles cleanup, existing file checks, and permission checks
let (listener, guard) = bind_robust(path, mode, &service_config.name).await?;
// Push guard to keep it alive until shutdown
socket_guards.push(guard);
info!(
"[{}] listening on unix {} (mode {:o})",
service_config.name, path, mode
);
e
})?;
info!("[{}] listening on tcp {}", service_config.name, bind_addr);
join_set.spawn(start_unix_service(
service_config,
listener,
proxy_proto_config,
db.clone(),
bind.addr.clone(),
));
} else {
// BindType is removed, assume TCP bind for "tcp" service
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
error!(
"[{}] failed to bind {}: {}",
service_config.name, bind_addr, e
);
e
})?;
join_set.spawn(start_tcp_service(
service_config,
listener,
proxy_proto_config,
db.clone(),
));
info!("[{}] listening on tcp {}", service_config.name, bind_addr);
join_set.spawn(start_tcp_service(
service_config,
listener,
proxy_proto_config,
db.clone(),
bind.addr.clone(),
));
}
}
}
@@ -87,14 +115,98 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
join_set.shutdown().await;
// socket_guards are dropped here, cleaning up files
Ok(())
}
/// Guard that deletes the socket file at `path` when it is dropped.
struct UnixSocketGuard {
  path: std::path::PathBuf,
}

impl Drop for UnixSocketGuard {
  /// Attempts to remove the socket file; the outcome is only logged at debug level.
  fn drop(&mut self) {
    match std::fs::remove_file(&self.path) {
      Ok(()) => tracing::debug!("removed socket file {:?}", self.path),
      // The file may already be gone, or we may have lost permissions.
      Err(e) => tracing::debug!("failed to remove socket file {:?}: {}", self.path, e),
    }
  }
}
/// Binds a Unix domain socket listener at `path`, replacing a stale socket file if present.
///
/// If something already exists at `path`:
/// - anything that is not a socket is left untouched and the bind is refused;
/// - a socket that accepts a connection is considered in use and the bind is refused;
/// - a socket that refuses connections is considered stale and is removed.
///
/// After binding, the socket file's permission bits are set to `mode`. Failing
/// to apply the permissions is logged but is not fatal.
///
/// Returns the listener together with a guard that removes the socket file
/// when dropped.
///
/// # Errors
///
/// Fails when the existing path cannot be inspected, is not a socket, is in
/// use, cannot be removed, or when the bind itself fails.
async fn bind_robust(
  path: &str,
  mode: u32,
  service_name: &str,
) -> anyhow::Result<(UnixListener, UnixSocketGuard)> {
  use std::io::ErrorKind;
  use std::os::unix::fs::{FileTypeExt, PermissionsExt};

  let path_buf = std::path::PathBuf::from(path);

  // metadata() follows symlinks, exactly like the connect() probe below.
  match std::fs::metadata(&path_buf) {
    Ok(existing) => {
      // connect() on a regular file also reports ConnectionRefused, so without
      // this check a mistyped path would cause an unrelated file to be deleted.
      if !existing.file_type().is_socket() {
        anyhow::bail!("refusing to replace existing non-socket file: {}", path);
      }

      // Probe the socket to find out whether another process is serving it.
      match UnixStream::connect(&path_buf).await {
        Ok(_) => {
          anyhow::bail!("Address already in use: {}", path);
        }
        Err(e) if e.kind() == ErrorKind::ConnectionRefused => {
          // Nobody is listening: the file is a leftover from a previous run.
          info!("[{}] removing stale socket file: {}", service_name, path);
          if let Err(rm_err) = std::fs::remove_file(&path_buf) {
            anyhow::bail!("failed to remove stale socket {}: {}", path, rm_err);
          }
        }
        Err(e) => {
          anyhow::bail!("failed to check existing socket {}: {}", path, e);
        }
      }
    }
    // Nothing there (or a dangling symlink): go straight to bind.
    Err(e) if e.kind() == ErrorKind::NotFound => {}
    Err(e) if e.kind() == ErrorKind::PermissionDenied => {
      anyhow::bail!("Permission denied accessing existing socket: {}", path);
    }
    Err(e) => {
      anyhow::bail!("failed to check existing socket {}: {}", path, e);
    }
  }

  let listener = UnixListener::bind(&path_buf).map_err(|e| {
    error!("[{}] failed to bind {}: {}", service_name, path, e);
    e
  })?;

  // Apply the configured permissions. This is best-effort: a failure is
  // logged, but the listener is still returned.
  match std::fs::metadata(&path_buf) {
    Ok(metadata) => {
      let mut permissions = metadata.permissions();
      // Compare the special bits too, so a requested sticky/setgid bit is not skipped.
      if permissions.mode() & 0o7777 != mode & 0o7777 {
        permissions.set_mode(mode);
        if let Err(e) = std::fs::set_permissions(&path_buf, permissions) {
          error!(
            "[{}] failed to set permissions on {}: {}",
            service_name, path, e
          );
        }
      }
    }
    Err(e) => {
      error!(
        "[{}] failed to set permissions on {}: {}",
        service_name, path, e
      );
    }
  }

  Ok((listener, UnixSocketGuard { path: path_buf }))
}
async fn start_tcp_service(
service: ServiceConfig,
listener: TcpListener,
proxy_cfg: Option<String>,
db: Arc<ClickHouseLogger>,
listen_addr: String,
) {
// Startup liveness check
if let Err(e) = UpstreamStream::connect(&service.forward_to).await {
@@ -123,12 +235,14 @@ async fn start_tcp_service(
let service = service.clone();
let db = db.clone();
let proxy_cfg = proxy_cfg.clone();
let listen_addr = listen_addr.clone();
tokio::spawn(async move {
let svc_name = service.name.clone();
let svc_target = service.forward_to.clone();
let inbound = InboundStream::Tcp(inbound);
if let Err(e) = handle_connection(inbound, service, proxy_cfg, db).await {
if let Err(e) = handle_connection(inbound, service, proxy_cfg, db, listen_addr).await {
match e.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
// normal disconnects, debug log only
@@ -154,19 +268,97 @@ async fn start_tcp_service(
}
}
/// Accept loop for a service bound to a Unix domain socket.
///
/// Performs one connection attempt to `service.forward_to` at startup and logs
/// a warning if it fails, then accepts inbound connections forever, handling
/// each one in its own spawned task via `handle_connection`. Accept errors are
/// logged and the loop continues.
async fn start_unix_service(
  service: ServiceConfig,
  listener: UnixListener,
  proxy_cfg: Option<String>,
  db: Arc<ClickHouseLogger>,
  listen_addr: String,
) {
  use std::io::ErrorKind;

  // Startup liveness check (same as TCP): only warns, never aborts.
  if let Err(e) = UpstreamStream::connect(&service.forward_to).await {
    if matches!(
      e.kind(),
      ErrorKind::ConnectionRefused | ErrorKind::NotFound
    ) {
      tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
    } else {
      tracing::warn!(
        "[{}] -> '{}': startup check failed: {}",
        service.name,
        service.forward_to,
        e
      );
    }
  }

  loop {
    let inbound = match listener.accept().await {
      Ok((stream, _peer)) => InboundStream::Unix(stream),
      Err(e) => {
        error!("accept error: {}", e);
        continue;
      }
    };

    // Each connection task owns its own copies of the shared settings.
    let service = service.clone();
    let db = db.clone();
    let proxy_cfg = proxy_cfg.clone();
    let listen_addr = listen_addr.clone();

    tokio::spawn(async move {
      let svc_name = service.name.clone();
      let svc_target = service.forward_to.clone();

      if let Err(e) = handle_connection(inbound, service, proxy_cfg, db, listen_addr).await {
        match e.kind() {
          // Ordinary disconnects are only worth a debug line.
          ErrorKind::ConnectionReset | ErrorKind::BrokenPipe => {
            tracing::debug!("connection closed: {}", e);
          }
          ErrorKind::ConnectionRefused | ErrorKind::NotFound => {
            tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
          }
          _ => {
            error!("connection error: {}", e);
          }
        }
      }
    });
  }
}
async fn handle_connection(
mut inbound: tokio::net::TcpStream,
mut inbound: InboundStream,
service: ServiceConfig,
proxy_cfg: Option<String>,
db: Arc<ClickHouseLogger>,
listen_addr: String,
) -> std::io::Result<u64> {
let conn_ts = time::OffsetDateTime::now_utc();
let start_instant = std::time::Instant::now();
// Default metadata
let mut final_ip = inbound.peer_addr()?.ip();
let mut final_port = inbound.peer_addr()?.port();
// We use this flag to help decide addr_family logic later, or infer from inbound type
let is_unix = matches!(inbound, InboundStream::Unix(_));
let (mut final_ip, mut final_port) = match &inbound {
InboundStream::Tcp(s) => {
let addr = s.peer_addr()?;
(addr.ip(), addr.port())
}
InboundStream::Unix(_) => (
std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
0,
),
};
let mut proto_enum = crate::db::clickhouse::ProxyProto::None;
let mut skip_log = false;
let result = async {
// read proxy protocol (if configured)
@@ -174,15 +366,23 @@ async fn handle_connection(
if proxy_cfg.is_some() {
// If configured, we attempt to read.
// Strict V2/V1 check can be implemented if needed, but here we just use the parser.
match protocol::read_proxy_header(&mut inbound).await {
Ok((proxy_info, buf)) => {
buffer = buf;
if let Some(info) = proxy_info {
let physical = inbound.peer_addr()?;
info!("[{}] <- {} ({})", service.name, info.source, physical);
let physical = inbound.peer_addr_string()?;
// INFO [ssh] unix://./test.sock <- 192.168.1.1:12345 (unix_socket)
// Or INFO [ssh] 0.0.0.0:2222 <- 1.2.3.4:5678 (1.2.3.4:5678)
info!(
"[{}] {} <- {} ({})",
service.name, listen_addr, info.source, physical
);
final_ip = info.source.ip();
final_port = info.source.port();
// Note: If we get proxy info, it's effectively "proxied TCP" usually.
// So we rely on the IP address family of final_ip later.
proto_enum = match info.version {
protocol::Version::V1 => crate::db::clickhouse::ProxyProto::V1,
protocol::Version::V2 => crate::db::clickhouse::ProxyProto::V2,
@@ -202,17 +402,29 @@ async fn handle_connection(
}
} else {
// Strict enforcement: if configured with proxy_protocol, MUST have a header
let physical = inbound.peer_addr()?;
let physical = inbound.peer_addr_string()?;
let msg = format!("strict proxy protocol violation from {}", physical);
error!("[{}] {}", service.name, msg);
skip_log = true;
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, msg));
}
}
Err(e) => return Err(e),
Err(e) => {
skip_log = true;
return Err(e);
}
}
} else {
let addr = inbound.peer_addr()?;
info!("[{}] <- {}", service.name, addr);
let addr = if matches!(inbound, InboundStream::Unix(_)) {
// If Unix socket without proxy, display 127.0.0.1:0 as per logic or ...
// User requested: unix://... <- 127.0.0.1:port
// But inbound.peer_addr_string() for unix is "unix_socket"
// And we set final_ip to 127.0.0.1, final_port to 0
format!("{}:{}", final_ip, final_port)
} else {
inbound.peer_addr_string()?
};
info!("[{}] {} <- {}", service.name, listen_addr, addr);
}
// connect upstream
@@ -224,7 +436,10 @@ async fn handle_connection(
}
// zero-copy forwarding
let inbound_async = crate::core::upstream::AsyncStream::from_tokio_tcp(inbound)?;
let inbound_async = match inbound {
InboundStream::Tcp(s) => crate::core::upstream::AsyncStream::from_tokio_tcp(s)?,
InboundStream::Unix(s) => crate::core::upstream::AsyncStream::from_tokio_unix(s)?,
};
let upstream_async = upstream.into_async_stream()?;
let (spliced_bytes, splice_res) =
@@ -240,7 +455,8 @@ async fn handle_connection(
}
}
} else {
info!("[{}] connection closed cleanly", service.name);
// Clean close logging removed as per request
// info!("[{}] connection closed cleanly", service.name);
}
// Total bytes = initial peeked/buffered payload + filtered bytes
@@ -256,21 +472,97 @@ async fn handle_connection(
let bytes_transferred = result.as_ref().unwrap_or(&0).clone();
// Finalize AddrFamily based on final_ip
// But if it was originally Unix AND no proxy info changed the IP (so it's still 127.0.0.1?)
// Wait, if Unix without proxy, final_ip IS 127.0.0.1.
// We want AddrFamily::Unix (1) for proper unix socket.
// If Unix WITH proxy, final_ip is Real IP -> AddrFamily::Ipv4/6.
let mut addr_family = match final_ip {
std::net::IpAddr::V4(_) => crate::db::clickhouse::AddrFamily::Ipv4,
std::net::IpAddr::V6(_) => crate::db::clickhouse::AddrFamily::Ipv6,
};
if is_unix && proto_enum == crate::db::clickhouse::ProxyProto::None {
// Unix socket, direct connection (or no proxy header received)
addr_family = crate::db::clickhouse::AddrFamily::Unix;
// Store 0 (::)
final_ip = std::net::IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED);
final_port = 0;
}
let log_entry = crate::db::clickhouse::TcpLog {
service: service.name.clone(),
conn_ts,
duration,
port: final_port,
duration: duration as u32,
addr_family,
ip: final_ip,
port: final_port,
proxy_proto: proto_enum,
bytes: bytes_transferred,
};
tokio::spawn(async move {
if let Err(e) = db.insert_log(log_entry).await {
error!("failed to insert tcp log: {}", e);
}
});
if !skip_log {
tokio::spawn(async move {
if let Err(e) = db.insert_log(log_entry).await {
error!("failed to insert tcp log: {}", e);
}
});
}
result
}
/// An accepted inbound connection, either TCP or Unix domain socket.
enum InboundStream {
  Tcp(TcpStream),
  Unix(UnixStream),
}

impl InboundStream {
  /// Describes the peer for log output.
  ///
  /// TCP streams yield the peer's socket address as text; Unix streams always
  /// yield the fixed label `unix_socket`.
  fn peer_addr_string(&self) -> std::io::Result<String> {
    match self {
      InboundStream::Tcp(tcp) => tcp.peer_addr().map(|addr| addr.to_string()),
      InboundStream::Unix(_) => Ok(String::from("unix_socket")),
    }
  }
}
impl AsyncRead for InboundStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
InboundStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
InboundStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
impl AsyncWrite for InboundStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
InboundStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
InboundStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
InboundStream::Tcp(s) => Pin::new(s).poll_flush(cx),
InboundStream::Unix(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
InboundStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
InboundStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
}
}
}

View File

@@ -102,6 +102,12 @@ impl AsyncStream {
Ok(AsyncStream::Tcp(tokio::io::unix::AsyncFd::new(std)?))
}
/// Converts a tokio Unix stream into an `AsyncStream::Unix` backed by an `AsyncFd`.
///
/// The stream is turned into its std equivalent, explicitly put in
/// non-blocking mode, and then registered through `AsyncFd::new`. Any step
/// that fails returns its I/O error.
pub fn from_tokio_unix(stream: tokio::net::UnixStream) -> io::Result<Self> {
  let std_stream = stream.into_std()?;
  std_stream.set_nonblocking(true)?;
  let async_fd = tokio::io::unix::AsyncFd::new(std_stream)?;
  Ok(AsyncStream::Unix(async_fd))
}
pub async fn splice_read(&self, pipe_out: RawFd, len: usize) -> io::Result<usize> {
match self {
AsyncStream::Tcp(fd) => perform_splice_read(fd, pipe_out, len).await,

View File

@@ -3,9 +3,9 @@ use clickhouse::{Client, Row};
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::net::{IpAddr, Ipv6Addr};
use tracing::info;
use tracing::{error, info};
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr)]
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr, PartialEq)]
#[repr(u8)]
pub enum ProxyProto {
None = 0,
@@ -13,49 +13,62 @@ pub enum ProxyProto {
V2 = 2,
}
/// Address family of the logged peer, (de)serialized as its `u8` discriminant.
///
/// The discriminants line up with the `addr_family` `Enum8` column of the
/// `tcp_log` table (`'unix'=1, 'ipv4'=2, 'ipv6'=10`).
// NOTE(review): the values also coincide with Linux's AF_UNIX / AF_INET / AF_INET6 — confirm that is intentional.
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr, PartialEq, Eq)]
#[repr(u8)]
pub enum AddrFamily {
  /// Connection arrived over a Unix domain socket.
  Unix = 1,
  /// Peer address is IPv4.
  Ipv4 = 2,
  /// Peer address is IPv6.
  Ipv6 = 10,
}
#[derive(Debug, Clone)]
pub struct TcpLog {
pub service: String,
pub conn_ts: time::OffsetDateTime,
pub duration: u32,
pub port: u16,
pub addr_family: AddrFamily,
pub ip: IpAddr,
pub port: u16,
pub proxy_proto: ProxyProto,
pub bytes: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, Row)]
struct TcpLogV4 {
struct TcpLogNew {
pub service: String,
#[serde(with = "clickhouse::serde::time::datetime")]
#[serde(with = "clickhouse::serde::time::datetime64::millis")]
pub conn_ts: time::OffsetDateTime,
pub duration: u32,
pub port: u16,
pub ip: u32,
pub proxy_proto: ProxyProto,
pub bytes: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, Row)]
struct TcpLogV6 {
pub service: String,
#[serde(with = "clickhouse::serde::time::datetime")]
pub conn_ts: time::OffsetDateTime,
pub duration: u32,
pub port: u16,
pub addr_family: AddrFamily,
pub ip: Ipv6Addr,
pub port: u16,
pub proxy_proto: ProxyProto,
pub bytes: u64,
}
pub struct ClickHouseLogger {
client: Client,
table_base: String,
db_name: String,
}
impl ClickHouseLogger {
pub fn new(config: &DatabaseConfig) -> anyhow::Result<Self> {
let url = url::Url::parse(&config.dsn).map_err(|e| anyhow::anyhow!("invalid dsn: {}", e))?;
let mut url =
url::Url::parse(&config.dsn).map_err(|e| anyhow::anyhow!("invalid dsn: {}", e))?;
let mut db_name = "default".to_string();
// specific handling for extracting database from path
if let Some(path_segments) = url.path_segments().map(|c| c.collect::<Vec<_>>()) {
if let Some(db) = path_segments.first() {
if !db.is_empty() {
db_name = db.to_string();
}
}
}
// Clear path from URL so client doesn't append it to requests
url.set_path("");
let mut client = Client::default().with_url(url.as_str());
if let (Some(u), Some(p)) = (Some(url.username()), url.password()) {
@@ -66,165 +79,172 @@ impl ClickHouseLogger {
client = client.with_user(url.username());
}
if let Some(path) = url.path_segments().map(|c| c.collect::<Vec<_>>()) {
if let Some(db) = path.first() {
if !db.is_empty() {
client = client.with_database(*db);
}
}
if !db_name.is_empty() && db_name != "default" {
client = client.with_database(&db_name);
}
// Config table name, default to "tcp_log" if missing
// We expect config.tables to contain "tcp" -> "tablename"
let table_base = config
.tables
.get("tcp")
.cloned()
.unwrap_or_else(|| "tcp_log".to_string());
Ok(Self { client, table_base })
Ok(Self { client, db_name })
}
pub async fn init(&self) -> anyhow::Result<()> {
let table_v4 = format!("{}_v4", self.table_base);
let table_v6 = format!("{}_v6", self.table_base);
let view_name = &self.table_base;
let sql_v4 = format!(
r#"
CREATE TABLE IF NOT EXISTS {} (
service LowCardinality(String),
conn_ts DateTime('UTC'),
duration UInt32,
port UInt16,
ip IPv4,
proxy_proto Enum8('None' = 0, 'V1' = 1, 'V2' = 2),
bytes UInt64
) ENGINE = MergeTree()
ORDER BY (service, conn_ts);
"#,
table_v4
);
let sql_v6 = format!(
r#"
CREATE TABLE IF NOT EXISTS {} (
service LowCardinality(String),
conn_ts DateTime('UTC'),
duration UInt32,
port UInt16,
ip IPv6,
proxy_proto Enum8('None' = 0, 'V1' = 1, 'V2' = 2),
bytes UInt64
) ENGINE = MergeTree()
ORDER BY (service, conn_ts);
"#,
table_v6
);
let sql_view = format!(
r#"
CREATE VIEW IF NOT EXISTS {} AS
SELECT
service, conn_ts, duration, port,
IPv4NumToString(ip) AS ip_str,
proxy_proto,
formatReadableSize(bytes) AS traffic
FROM {}
UNION ALL
SELECT
service, conn_ts, duration, port,
IPv6NumToString(ip) AS ip_str,
proxy_proto,
formatReadableSize(bytes) AS traffic
FROM {};
"#,
view_name, table_v4, table_v6
);
self
.client
.query(&sql_v4)
// Ensure database exists. Use 'default' database context to execute CREATE DATABASE.
let sys_client = self.client.clone().with_database("default");
sys_client
.query(&format!("CREATE DATABASE IF NOT EXISTS {}", self.db_name))
.execute()
.await
.map_err(|e| anyhow::anyhow!("failed to create v4 table: {}", e))?;
.map_err(|e| anyhow::anyhow!("failed to create database: {}", e))?;
self
.client
.query(&sql_v6)
.execute()
.await
.map_err(|e| anyhow::anyhow!("failed to create v6 table: {}", e))?;
// Schema Check / Migration
for (table, is_v6) in [(&table_v4, false), (&table_v6, true)] {
let ip_type = if is_v6 { "IPv6" } else { "IPv4" };
let columns = [
("service", "LowCardinality(String)"),
("conn_ts", "DateTime('UTC')"),
("duration", "UInt32"),
("port", "UInt16"),
("ip", ip_type),
("proxy_proto", "Enum8('None' = 0, 'V1' = 1, 'V2' = 2)"),
("bytes", "UInt64"),
];
for (name, type_def) in columns {
self
.client
.query(&format!(
"ALTER TABLE {} ADD COLUMN IF NOT EXISTS {} {}",
table, name, type_def
))
.execute()
.await
.map_err(|e| anyhow::anyhow!("failed to add column {} to {}: {}", name, table, e))?;
}
}
self
.client
.query(&sql_view)
.execute()
.await
.map_err(|e| anyhow::anyhow!("failed to create view: {}", e))?;
// Check migrations
self.check_migrations().await?;
info!("connected to database");
Ok(())
}
pub async fn insert_log(&self, log: TcpLog) -> anyhow::Result<()> {
match log.ip {
IpAddr::V4(ip) => {
let row = TcpLogV4 {
service: log.service,
conn_ts: log.conn_ts,
duration: log.duration,
port: log.port,
ip: u32::from(ip),
proxy_proto: log.proxy_proto,
bytes: log.bytes,
};
let table = format!("{}_v4", self.table_base);
let mut insert = self.client.insert(&table)?;
insert.write(&row).await?;
insert.end().await?;
}
IpAddr::V6(ip) => {
let row = TcpLogV6 {
service: log.service,
conn_ts: log.conn_ts,
duration: log.duration,
port: log.port,
ip,
proxy_proto: log.proxy_proto,
bytes: log.bytes,
};
let table = format!("{}_v6", self.table_base);
let mut insert = self.client.insert(&table)?;
insert.write(&row).await?;
insert.end().await?;
/// Ensures the `db_migrations` bookkeeping table exists and runs pending migrations.
///
/// The most recent row of `db_migrations` (by `apply_ts`) gives the database
/// version and whether that migration succeeded; with no rows the database is
/// treated as `v0.0.0` with a successful state. Nothing is done when the
/// database is already at `crate::VERSION` and the last migration succeeded;
/// otherwise `run_migrations` is invoked.
async fn check_migrations(&self) -> anyhow::Result<()> {
  #[derive(Row, Deserialize)]
  struct MigrationRow {
    version: String,
    success: u8,
  }

  // Create migrations table
  let create_migrations_table = "
CREATE TABLE IF NOT EXISTS db_migrations (
version String,
success UInt8,
apply_ts DateTime64 DEFAULT now()
) ENGINE = ReplacingMergeTree(apply_ts)
ORDER BY version
";
  self
    .client
    .query(create_migrations_table)
    .execute()
    .await
    .map_err(|e| anyhow::anyhow!("failed to create migrations table: {}", e))?;

  // Get current DB version
  let latest = self
    .client
    .query("SELECT version, success FROM db_migrations ORDER BY apply_ts DESC LIMIT 1")
    .fetch_optional::<MigrationRow>()
    .await
    .map_err(|e| anyhow::anyhow!("failed to fetch last migration: {}", e))?;

  let (current_db_version, success) = match latest {
    Some(row) => (row.version, row.success == 1),
    None => ("v0.0.0".to_string(), true),
  };

  if success {
    if current_db_version == crate::VERSION {
      // Already up to date.
      return Ok(());
    }
    info!(
      "migrating database from {} to {}",
      current_db_version,
      crate::VERSION
    );
  } else {
    error!(
      "previous migration to {} failed. retrying...",
      current_db_version
    );
  }

  self.run_migrations(&current_db_version, success).await
}
async fn run_migrations(&self, from_version: &str, last_success: bool) -> anyhow::Result<()> {
if from_version < "v0.0.1" || (from_version == "v0.0.1" && !last_success) {
info!("applying migration v0.0.1...");
if let Err(e) = self.apply_v0_0_1().await {
error!("migration v0.0.1 failed: {}", e);
// Record failure
let _ = self
.client
.query("INSERT INTO db_migrations (version, success) VALUES (?, 0)")
.bind(crate::VERSION)
.execute()
.await;
return Err(e);
}
// Record success
self
.client
.query("INSERT INTO db_migrations (version, success) VALUES (?, 1)")
.bind(crate::VERSION)
.execute()
.await
.map_err(|e| anyhow::anyhow!("failed to record migration success: {}", e))?;
info!("migration v0.0.1 applied successfully");
}
Ok(())
}
/// Migration v0.0.1: creates the `tcp_log` table and the `tcp_log_view` view.
///
/// Both statements use `IF NOT EXISTS`. They are executed in order (table
/// first), and the first failure is returned.
async fn apply_v0_0_1(&self) -> anyhow::Result<()> {
  // 1. Table (tcp_log)
  let create_table = r#"
CREATE TABLE IF NOT EXISTS tcp_log (
service LowCardinality(String),
conn_ts DateTime64(3),
duration UInt32,
addr_family Enum8('unix'=1, 'ipv4'=2, 'ipv6'=10),
ip IPv6,
port UInt16,
proxy_proto Enum8('None' = 0, 'V1' = 1, 'V2' = 2),
bytes UInt64
) ENGINE = MergeTree()
ORDER BY (service, conn_ts);
"#;

  // 2. View
  let create_view = r#"
CREATE VIEW IF NOT EXISTS tcp_log_view AS
SELECT
service, conn_ts, duration, addr_family,
multiIf(
addr_family = 1, 'unix socket',
addr_family = 2, IPv4NumToString(toIPv4(ip)),
IPv6NumToString(ip)
) as ip_str,
port,
proxy_proto,
formatReadableSize(bytes) AS traffic
FROM tcp_log
"#;

  for statement in [create_table, create_view] {
    self.client.query(statement).execute().await?;
  }

  Ok(())
}
/// Writes one connection record into the `tcp_log` table.
///
/// IPv4 addresses are converted to their IPv4-mapped IPv6 form before the row
/// is written; IPv6 addresses are used unchanged.
///
/// # Errors
///
/// Returns any error from creating the insert, writing the row, or ending it.
pub async fn insert_log(&self, log: TcpLog) -> anyhow::Result<()> {
  let row = TcpLogNew {
    ip: match log.ip {
      IpAddr::V4(v4) => v4.to_ipv6_mapped(),
      IpAddr::V6(v6) => v6,
    },
    service: log.service,
    conn_ts: log.conn_ts,
    duration: log.duration,
    addr_family: log.addr_family,
    port: log.port,
    proxy_proto: log.proxy_proto,
    bytes: log.bytes,
  };

  let mut writer = self.client.insert("tcp_log")?;
  writer.write(&row).await?;
  writer.end().await?;

  Ok(())
}
}

View File

@@ -9,6 +9,8 @@ use std::env;
use std::path::Path;
use tracing::{error, info};
/// Version string of this build: the crate's `CARGO_PKG_VERSION` prefixed with `v`.
pub const VERSION: &str = concat!("v", env!("CARGO_PKG_VERSION"));
fn print_help() {
println!("traudit - a reverse proxy with auditing capabilities");
println!();
@@ -18,6 +20,7 @@ fn print_help() {
println!("options:");
println!(" -f <config_file> path to the yaml configuration file");
println!(" -t, --test test configuration and exit");
println!(" -v, --version print version");
println!(" -h, --help print this help message");
println!();
println!("project: https://github.com/awfufu/traudit");
@@ -56,6 +59,10 @@ async fn main() -> anyhow::Result<()> {
print_help();
return Ok(());
}
"-v" | "--version" => {
println!("{}", VERSION);
return Ok(());
}
_ => {
bail!("unknown argument: {}\n\nuse -h for help", args[i]);
}