mirror of
https://github.com/awfufu/traudit
synced 2026-03-01 05:29:44 +08:00
feat: support unix socket binds and update db schema
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -1014,7 +1014,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "traudit"
|
||||
version = "0.1.0"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "traudit"
|
||||
version = "0.1.0"
|
||||
version = "0.0.1"
|
||||
edition = "2021"
|
||||
authors = ["awfufu"]
|
||||
description = "A reverse proxy with auditing capabilities."
|
||||
@@ -29,7 +29,7 @@ tempfile = "3"
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
lto = true
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
strip = true
|
||||
|
||||
@@ -24,7 +24,7 @@ See [config_example.yaml](config_example.yaml).
|
||||
- [x] TCP Proxy & Zero-copy forwarding (`splice`)
|
||||
- [x] Proxy Protocol V1/V2 parsing
|
||||
- [ ] UDP Forwarding (Planned)
|
||||
- [ ] Unix Socket Forwarding (Planned)
|
||||
- [x] Unix Socket Forwarding
|
||||
- [x] Database Integration
|
||||
- [x] ClickHouse Adapter (Native Interface)
|
||||
- [x] Traffic Accounting (Bytes/Bandwidth)
|
||||
|
||||
@@ -24,7 +24,7 @@ traudit 是一个支持 TCP/UDP/Unix Socket 的反向代理程序,专注于连
|
||||
- [x] TCP 代理与零拷贝转发 (`splice`)
|
||||
- [x] Proxy Protocol V1/V2 解析
|
||||
- [ ] UDP 转发 (计划中)
|
||||
- [ ] Unix Socket 转发 (计划中)
|
||||
- [x] Unix Socket 转发
|
||||
- [x] 数据库集成
|
||||
- [x] ClickHouse 适配器 (原生接口)
|
||||
- [x] 流量统计 (字节数)
|
||||
|
||||
@@ -2,11 +2,9 @@
|
||||
|
||||
database:
|
||||
type: clickhouse
|
||||
dsn: "http://user:password@ip:port"
|
||||
dsn: "http://user:password@ip:port/traudit"
|
||||
batch_size: 50
|
||||
batch_timeout_secs: 5
|
||||
tables:
|
||||
tcp: tcp_log
|
||||
|
||||
services:
|
||||
# Receives traffic from FRP with v2 Proxy Protocol header, audits it,
|
||||
@@ -24,9 +22,3 @@ services:
|
||||
|
||||
forward_to: "127.0.0.1:22"
|
||||
|
||||
# - name: "web"
|
||||
# type: "tcp"
|
||||
# binds:
|
||||
# - addr: "0.0.0.0:8080"
|
||||
|
||||
# forward_to: "/run/nginx/web.sock"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use std::path::Path;
|
||||
use tokio::fs;
|
||||
|
||||
@@ -15,7 +14,6 @@ pub struct DatabaseConfig {
|
||||
#[allow(dead_code)]
|
||||
pub db_type: String,
|
||||
pub dsn: String,
|
||||
pub tables: HashMap<String, String>,
|
||||
#[serde(default = "default_batch_size")]
|
||||
#[allow(dead_code)]
|
||||
pub batch_size: usize,
|
||||
@@ -47,6 +45,41 @@ pub struct BindEntry {
|
||||
pub addr: String,
|
||||
#[serde(alias = "proxy_protocol", rename = "proxy")]
|
||||
pub proxy: Option<String>,
|
||||
#[serde(default = "default_socket_mode", deserialize_with = "deserialize_mode")]
|
||||
pub mode: u32,
|
||||
}
|
||||
|
||||
/// Permission bits applied to a bind entry when `mode` is absent from the config:
/// read/write for the owner only.
fn default_socket_mode() -> u32 {
    const OWNER_READ_WRITE: u32 = 0o600;
    OWNER_READ_WRITE
}
|
||||
|
||||
fn deserialize_mode<'de, D>(deserializer: D) -> Result<u32, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ModeValue {
|
||||
Integer(u32),
|
||||
String(String),
|
||||
}
|
||||
|
||||
let value = ModeValue::deserialize(deserializer)?;
|
||||
match value {
|
||||
ModeValue::Integer(i) => {
|
||||
// If user provides 666, they likely mean octal 0666.
|
||||
// But in YAML `mode: 666` is decimal 666.
|
||||
// The requirement says: "if user wrote integer (e.g. 666), process as octal"
|
||||
// So we interpret the decimal value as a sequence of octal digits.
|
||||
// e.g. decimal 666 -> octal 666 (which is decimal 438)
|
||||
let s = i.to_string();
|
||||
u32::from_str_radix(&s, 8).map_err(serde::de::Error::custom)
|
||||
}
|
||||
ModeValue::String(s) => {
|
||||
// If string, parse as octal
|
||||
u32::from_str_radix(&s, 8).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -70,8 +103,6 @@ database:
|
||||
dsn: "clickhouse://admin:password@127.0.0.1:8123/audit_db"
|
||||
batch_size: 50
|
||||
batch_timeout_secs: 5
|
||||
tables:
|
||||
tcp: tcp_log
|
||||
|
||||
services:
|
||||
- name: "ssh-prod"
|
||||
@@ -97,4 +128,26 @@ services:
|
||||
assert_eq!(config.services[0].binds[0].proxy, Some("v2".to_string()));
|
||||
assert_eq!(config.services[0].forward_to, "127.0.0.1:22");
|
||||
}
|
||||
|
||||
#[test]
fn test_mode_deserialization() {
    #[derive(Deserialize)]
    struct TestBind {
        #[serde(default = "default_socket_mode", deserialize_with = "deserialize_mode")]
        mode: u32,
    }

    // (YAML document, expected permission bits)
    let cases = [
        // Bare integer: digits are read as octal, so 666 -> 0o666 (438 decimal).
        ("mode: 666", 0o666),
        // Quoted string: parsed as octal, 0o600 (384 decimal).
        ("mode: '600'", 0o600),
        // Field omitted: falls back to the default.
        ("{}", 0o600),
    ];

    for (yaml, expected) in cases {
        let bind: TestBind = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(bind.mode, expected);
    }
}
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@ use crate::core::forwarder;
|
||||
use crate::core::upstream::UpstreamStream;
|
||||
use crate::db::clickhouse::ClickHouseLogger;
|
||||
use crate::protocol;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream};
|
||||
use tokio::signal;
|
||||
use tracing::{error, info};
|
||||
|
||||
@@ -16,7 +19,6 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
})?;
|
||||
let db = Arc::new(db_logger);
|
||||
|
||||
// init db table
|
||||
// init db table
|
||||
if let Err(e) = db.init().await {
|
||||
let msg = e.to_string();
|
||||
@@ -29,6 +31,7 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
let mut socket_guards = Vec::new();
|
||||
|
||||
for service in config.services {
|
||||
let db = db.clone();
|
||||
@@ -44,24 +47,49 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
let bind_addr = bind.addr.clone();
|
||||
// proxy is now Option<String>
|
||||
let proxy_proto_config = bind.proxy.clone();
|
||||
let mode = bind.mode;
|
||||
|
||||
// BindType is removed, assume TCP bind for "tcp" service
|
||||
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
|
||||
error!(
|
||||
"[{}] failed to bind {}: {}",
|
||||
service_config.name, bind_addr, e
|
||||
if bind_addr.starts_with("unix://") {
|
||||
let path = bind_addr.trim_start_matches("unix://");
|
||||
|
||||
// bind_robust handles cleanup, existing file checks, and permission checks
|
||||
let (listener, guard) = bind_robust(path, mode, &service_config.name).await?;
|
||||
|
||||
// Push guard to keep it alive until shutdown
|
||||
socket_guards.push(guard);
|
||||
|
||||
info!(
|
||||
"[{}] listening on unix {} (mode {:o})",
|
||||
service_config.name, path, mode
|
||||
);
|
||||
e
|
||||
})?;
|
||||
|
||||
info!("[{}] listening on tcp {}", service_config.name, bind_addr);
|
||||
join_set.spawn(start_unix_service(
|
||||
service_config,
|
||||
listener,
|
||||
proxy_proto_config,
|
||||
db.clone(),
|
||||
bind.addr.clone(),
|
||||
));
|
||||
} else {
|
||||
// BindType is removed, assume TCP bind for "tcp" service
|
||||
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
|
||||
error!(
|
||||
"[{}] failed to bind {}: {}",
|
||||
service_config.name, bind_addr, e
|
||||
);
|
||||
e
|
||||
})?;
|
||||
|
||||
join_set.spawn(start_tcp_service(
|
||||
service_config,
|
||||
listener,
|
||||
proxy_proto_config,
|
||||
db.clone(),
|
||||
));
|
||||
info!("[{}] listening on tcp {}", service_config.name, bind_addr);
|
||||
|
||||
join_set.spawn(start_tcp_service(
|
||||
service_config,
|
||||
listener,
|
||||
proxy_proto_config,
|
||||
db.clone(),
|
||||
bind.addr.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,14 +115,98 @@ pub async fn run(config: Config) -> anyhow::Result<()> {
|
||||
|
||||
join_set.shutdown().await;
|
||||
|
||||
// socket_guards are dropped here, cleaning up files
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct UnixSocketGuard {
|
||||
path: std::path::PathBuf,
|
||||
}
|
||||
|
||||
impl Drop for UnixSocketGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Err(e) = std::fs::remove_file(&self.path) {
|
||||
// It's possible the file is already gone or we lost permissions, just log debug.
|
||||
tracing::debug!("failed to remove socket file {:?}: {}", self.path, e);
|
||||
} else {
|
||||
tracing::debug!("removed socket file {:?}", self.path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Binds a Unix listener at `path`, replacing a stale socket file left by a previous run.
///
/// Steps:
/// 1. If something already exists at `path`, it must be a socket. Anything else
///    (regular file, directory, symlink, ...) is refused rather than deleted.
/// 2. A connection attempt decides whether the socket is live (bail with
///    "Address already in use") or stale (`ConnectionRefused`, so it is removed).
/// 3. The listener is bound and its permission bits are set to `mode`.
///
/// Returns the listener together with a [`UnixSocketGuard`] that removes the
/// socket file when dropped.
///
/// # Errors
///
/// Fails when the existing path is not a socket, is in use, cannot be inspected
/// or removed, or when binding fails. Failure to apply `mode` is logged but is
/// not treated as fatal.
async fn bind_robust(
    path: &str,
    mode: u32,
    service_name: &str,
) -> anyhow::Result<(UnixListener, UnixSocketGuard)> {
    use std::os::unix::fs::{FileTypeExt, PermissionsExt};

    let path_buf = std::path::Path::new(path).to_path_buf();

    if path_buf.exists() {
        // symlink_metadata() inspects the directory entry itself instead of following a symlink.
        match std::fs::symlink_metadata(&path_buf) {
            Ok(meta) => {
                // connect() to a non-socket file also reports ConnectionRefused (at least on
                // Linux). Without this check such a file would be mistaken for a stale socket
                // below and deleted.
                if !meta.file_type().is_socket() {
                    anyhow::bail!(
                        "refusing to replace existing non-socket file: {}",
                        path
                    );
                }
            }
            Err(e) => {
                if e.kind() == std::io::ErrorKind::PermissionDenied {
                    anyhow::bail!("Permission denied accessing existing socket: {}", path);
                }
                // Other errors are left to the connect/remove calls below to surface.
            }
        }

        // Probe the socket: a successful connect means another process is serving it.
        match UnixStream::connect(&path_buf).await {
            Ok(_) => {
                anyhow::bail!("Address already in use: {}", path);
            }
            Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
                // Nobody is listening: the file is a leftover, remove it so bind can succeed.
                info!("[{}] removing stale socket file: {}", service_name, path);
                if let Err(rm_err) = std::fs::remove_file(&path_buf) {
                    anyhow::bail!("failed to remove stale socket {}: {}", path, rm_err);
                }
            }
            Err(e) => {
                // Any other failure (e.g. permission denied while connecting) is fatal.
                anyhow::bail!("failed to check existing socket {}: {}", path, e);
            }
        }
    }

    let listener = UnixListener::bind(&path_buf).map_err(|e| {
        error!("[{}] failed to bind {}: {}", service_name, path, e);
        e
    })?;

    // Apply the configured permission bits, touching the file only when they differ.
    if let Ok(metadata) = std::fs::metadata(&path_buf) {
        let mut permissions = metadata.permissions();
        if permissions.mode() & 0o777 != mode & 0o777 {
            permissions.set_mode(mode);
            if let Err(e) = std::fs::set_permissions(&path_buf, permissions) {
                // Not fatal, but worth an error log: the socket keeps its creation-time mode.
                error!(
                    "[{}] failed to set permissions on {}: {}",
                    service_name, path, e
                );
            }
        }
    }

    Ok((listener, UnixSocketGuard { path: path_buf }))
}
|
||||
|
||||
async fn start_tcp_service(
|
||||
service: ServiceConfig,
|
||||
listener: TcpListener,
|
||||
proxy_cfg: Option<String>,
|
||||
db: Arc<ClickHouseLogger>,
|
||||
listen_addr: String,
|
||||
) {
|
||||
// Startup liveness check
|
||||
if let Err(e) = UpstreamStream::connect(&service.forward_to).await {
|
||||
@@ -123,12 +235,14 @@ async fn start_tcp_service(
|
||||
let service = service.clone();
|
||||
let db = db.clone();
|
||||
let proxy_cfg = proxy_cfg.clone();
|
||||
let listen_addr = listen_addr.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let svc_name = service.name.clone();
|
||||
let svc_target = service.forward_to.clone();
|
||||
let inbound = InboundStream::Tcp(inbound);
|
||||
|
||||
if let Err(e) = handle_connection(inbound, service, proxy_cfg, db).await {
|
||||
if let Err(e) = handle_connection(inbound, service, proxy_cfg, db, listen_addr).await {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
|
||||
// normal disconnects, debug log only
|
||||
@@ -154,19 +268,97 @@ async fn start_tcp_service(
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_unix_service(
|
||||
service: ServiceConfig,
|
||||
listener: UnixListener,
|
||||
proxy_cfg: Option<String>,
|
||||
db: Arc<ClickHouseLogger>,
|
||||
listen_addr: String,
|
||||
) {
|
||||
// Startup liveness check (same as TCP)
|
||||
if let Err(e) = UpstreamStream::connect(&service.forward_to).await {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionRefused => {
|
||||
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
|
||||
}
|
||||
std::io::ErrorKind::NotFound => {
|
||||
tracing::warn!("[{}] -> '{}': {}", service.name, service.forward_to, e);
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!(
|
||||
"[{}] -> '{}': startup check failed: {}",
|
||||
service.name,
|
||||
service.forward_to,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((inbound, _addr)) => {
|
||||
let service = service.clone();
|
||||
let db = db.clone();
|
||||
let proxy_cfg = proxy_cfg.clone();
|
||||
let listen_addr = listen_addr.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let svc_name = service.name.clone();
|
||||
let svc_target = service.forward_to.clone();
|
||||
let inbound = InboundStream::Unix(inbound);
|
||||
|
||||
if let Err(e) = handle_connection(inbound, service, proxy_cfg, db, listen_addr).await {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => {
|
||||
tracing::debug!("connection closed: {}", e);
|
||||
}
|
||||
std::io::ErrorKind::ConnectionRefused => {
|
||||
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
|
||||
}
|
||||
std::io::ErrorKind::NotFound => {
|
||||
tracing::warn!("[{}] -> '{}': {}", svc_name, svc_target, e);
|
||||
}
|
||||
_ => {
|
||||
error!("connection error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("accept error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_connection(
|
||||
mut inbound: tokio::net::TcpStream,
|
||||
mut inbound: InboundStream,
|
||||
service: ServiceConfig,
|
||||
proxy_cfg: Option<String>,
|
||||
db: Arc<ClickHouseLogger>,
|
||||
listen_addr: String,
|
||||
) -> std::io::Result<u64> {
|
||||
let conn_ts = time::OffsetDateTime::now_utc();
|
||||
let start_instant = std::time::Instant::now();
|
||||
|
||||
// Default metadata
|
||||
let mut final_ip = inbound.peer_addr()?.ip();
|
||||
let mut final_port = inbound.peer_addr()?.port();
|
||||
// We use this flag to help decide addr_family logic later, or infer from inbound type
|
||||
let is_unix = matches!(inbound, InboundStream::Unix(_));
|
||||
|
||||
let (mut final_ip, mut final_port) = match &inbound {
|
||||
InboundStream::Tcp(s) => {
|
||||
let addr = s.peer_addr()?;
|
||||
(addr.ip(), addr.port())
|
||||
}
|
||||
InboundStream::Unix(_) => (
|
||||
std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
|
||||
0,
|
||||
),
|
||||
};
|
||||
let mut proto_enum = crate::db::clickhouse::ProxyProto::None;
|
||||
let mut skip_log = false;
|
||||
|
||||
let result = async {
|
||||
// read proxy protocol (if configured)
|
||||
@@ -174,15 +366,23 @@ async fn handle_connection(
|
||||
|
||||
if proxy_cfg.is_some() {
|
||||
// If configured, we attempt to read.
|
||||
// Strict V2/V1 check can be implemented if needed, but here we just use the parser.
|
||||
match protocol::read_proxy_header(&mut inbound).await {
|
||||
Ok((proxy_info, buf)) => {
|
||||
buffer = buf;
|
||||
if let Some(info) = proxy_info {
|
||||
let physical = inbound.peer_addr()?;
|
||||
info!("[{}] <- {} ({})", service.name, info.source, physical);
|
||||
let physical = inbound.peer_addr_string()?;
|
||||
// INFO [ssh] unix://./test.sock <- 192.168.1.1:12345 (unix_socket)
|
||||
// Or INFO [ssh] 0.0.0.0:2222 <- 1.2.3.4:5678 (1.2.3.4:5678)
|
||||
info!(
|
||||
"[{}] {} <- {} ({})",
|
||||
service.name, listen_addr, info.source, physical
|
||||
);
|
||||
final_ip = info.source.ip();
|
||||
final_port = info.source.port();
|
||||
|
||||
// Note: If we get proxy info, it's effectively "proxied TCP" usually.
|
||||
// So we rely on the IP address family of final_ip later.
|
||||
|
||||
proto_enum = match info.version {
|
||||
protocol::Version::V1 => crate::db::clickhouse::ProxyProto::V1,
|
||||
protocol::Version::V2 => crate::db::clickhouse::ProxyProto::V2,
|
||||
@@ -202,17 +402,29 @@ async fn handle_connection(
|
||||
}
|
||||
} else {
|
||||
// Strict enforcement: if configured with proxy_protocol, MUST have a header
|
||||
let physical = inbound.peer_addr()?;
|
||||
let physical = inbound.peer_addr_string()?;
|
||||
let msg = format!("strict proxy protocol violation from {}", physical);
|
||||
error!("[{}] {}", service.name, msg);
|
||||
skip_log = true;
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, msg));
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
Err(e) => {
|
||||
skip_log = true;
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let addr = inbound.peer_addr()?;
|
||||
info!("[{}] <- {}", service.name, addr);
|
||||
let addr = if matches!(inbound, InboundStream::Unix(_)) {
|
||||
// If Unix socket without proxy, display 127.0.0.1:0 as per logic or ...
|
||||
// User requested: unix://... <- 127.0.0.1:port
|
||||
// But inbound.peer_addr_string() for unix is "unix_socket"
|
||||
// And we set final_ip to 127.0.0.1, final_port to 0
|
||||
format!("{}:{}", final_ip, final_port)
|
||||
} else {
|
||||
inbound.peer_addr_string()?
|
||||
};
|
||||
info!("[{}] {} <- {}", service.name, listen_addr, addr);
|
||||
}
|
||||
|
||||
// connect upstream
|
||||
@@ -224,7 +436,10 @@ async fn handle_connection(
|
||||
}
|
||||
|
||||
// zero-copy forwarding
|
||||
let inbound_async = crate::core::upstream::AsyncStream::from_tokio_tcp(inbound)?;
|
||||
let inbound_async = match inbound {
|
||||
InboundStream::Tcp(s) => crate::core::upstream::AsyncStream::from_tokio_tcp(s)?,
|
||||
InboundStream::Unix(s) => crate::core::upstream::AsyncStream::from_tokio_unix(s)?,
|
||||
};
|
||||
let upstream_async = upstream.into_async_stream()?;
|
||||
|
||||
let (spliced_bytes, splice_res) =
|
||||
@@ -240,7 +455,8 @@ async fn handle_connection(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!("[{}] connection closed cleanly", service.name);
|
||||
// Clean close logging removed as per request
|
||||
// info!("[{}] connection closed cleanly", service.name);
|
||||
}
|
||||
|
||||
// Total bytes = initial peeked/buffered payload + filtered bytes
|
||||
@@ -256,21 +472,97 @@ async fn handle_connection(
|
||||
|
||||
let bytes_transferred = result.as_ref().unwrap_or(&0).clone();
|
||||
|
||||
// Finalize AddrFamily based on final_ip
|
||||
// But if it was originally Unix AND no proxy info changed the IP (so it's still 127.0.0.1?)
|
||||
// Wait, if Unix without proxy, final_ip IS 127.0.0.1.
|
||||
// We want AddrFamily::Unix (1) for proper unix socket.
|
||||
// If Unix WITH proxy, final_ip is Real IP -> AddrFamily::Ipv4/6.
|
||||
|
||||
let mut addr_family = match final_ip {
|
||||
std::net::IpAddr::V4(_) => crate::db::clickhouse::AddrFamily::Ipv4,
|
||||
std::net::IpAddr::V6(_) => crate::db::clickhouse::AddrFamily::Ipv6,
|
||||
};
|
||||
|
||||
if is_unix && proto_enum == crate::db::clickhouse::ProxyProto::None {
|
||||
// Unix socket, direct connection (or no proxy header received)
|
||||
addr_family = crate::db::clickhouse::AddrFamily::Unix;
|
||||
// Store 0 (::)
|
||||
final_ip = std::net::IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED);
|
||||
final_port = 0;
|
||||
}
|
||||
|
||||
let log_entry = crate::db::clickhouse::TcpLog {
|
||||
service: service.name.clone(),
|
||||
conn_ts,
|
||||
duration,
|
||||
port: final_port,
|
||||
duration: duration as u32,
|
||||
addr_family,
|
||||
ip: final_ip,
|
||||
port: final_port,
|
||||
proxy_proto: proto_enum,
|
||||
bytes: bytes_transferred,
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = db.insert_log(log_entry).await {
|
||||
error!("failed to insert tcp log: {}", e);
|
||||
}
|
||||
});
|
||||
if !skip_log {
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = db.insert_log(log_entry).await {
|
||||
error!("failed to insert tcp log: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
enum InboundStream {
|
||||
Tcp(TcpStream),
|
||||
Unix(UnixStream),
|
||||
}
|
||||
|
||||
impl InboundStream {
|
||||
fn peer_addr_string(&self) -> std::io::Result<String> {
|
||||
match self {
|
||||
InboundStream::Tcp(s) => Ok(s.peer_addr()?.to_string()),
|
||||
InboundStream::Unix(_) => Ok("unix_socket".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for InboundStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
InboundStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
|
||||
InboundStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for InboundStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
match self.get_mut() {
|
||||
InboundStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
|
||||
InboundStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.get_mut() {
|
||||
InboundStream::Tcp(s) => Pin::new(s).poll_flush(cx),
|
||||
InboundStream::Unix(s) => Pin::new(s).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.get_mut() {
|
||||
InboundStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
|
||||
InboundStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,6 +102,12 @@ impl AsyncStream {
|
||||
Ok(AsyncStream::Tcp(tokio::io::unix::AsyncFd::new(std)?))
|
||||
}
|
||||
|
||||
pub fn from_tokio_unix(stream: tokio::net::UnixStream) -> io::Result<Self> {
|
||||
let std = stream.into_std()?;
|
||||
std.set_nonblocking(true)?;
|
||||
Ok(AsyncStream::Unix(tokio::io::unix::AsyncFd::new(std)?))
|
||||
}
|
||||
|
||||
pub async fn splice_read(&self, pipe_out: RawFd, len: usize) -> io::Result<usize> {
|
||||
match self {
|
||||
AsyncStream::Tcp(fd) => perform_splice_read(fd, pipe_out, len).await,
|
||||
|
||||
@@ -3,9 +3,9 @@ use clickhouse::{Client, Row};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_repr::{Deserialize_repr, Serialize_repr};
|
||||
use std::net::{IpAddr, Ipv6Addr};
|
||||
use tracing::info;
|
||||
use tracing::{error, info};
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr)]
|
||||
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr, PartialEq)]
|
||||
#[repr(u8)]
|
||||
pub enum ProxyProto {
|
||||
None = 0,
|
||||
@@ -13,49 +13,62 @@ pub enum ProxyProto {
|
||||
V2 = 2,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr)]
|
||||
#[repr(u8)]
|
||||
pub enum AddrFamily {
|
||||
Unix = 1,
|
||||
Ipv4 = 2,
|
||||
Ipv6 = 10,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TcpLog {
|
||||
pub service: String,
|
||||
pub conn_ts: time::OffsetDateTime,
|
||||
pub duration: u32,
|
||||
pub port: u16,
|
||||
pub addr_family: AddrFamily,
|
||||
pub ip: IpAddr,
|
||||
pub port: u16,
|
||||
pub proxy_proto: ProxyProto,
|
||||
pub bytes: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Row)]
|
||||
struct TcpLogV4 {
|
||||
struct TcpLogNew {
|
||||
pub service: String,
|
||||
#[serde(with = "clickhouse::serde::time::datetime")]
|
||||
#[serde(with = "clickhouse::serde::time::datetime64::millis")]
|
||||
pub conn_ts: time::OffsetDateTime,
|
||||
pub duration: u32,
|
||||
pub port: u16,
|
||||
pub ip: u32,
|
||||
pub proxy_proto: ProxyProto,
|
||||
pub bytes: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Row)]
|
||||
struct TcpLogV6 {
|
||||
pub service: String,
|
||||
#[serde(with = "clickhouse::serde::time::datetime")]
|
||||
pub conn_ts: time::OffsetDateTime,
|
||||
pub duration: u32,
|
||||
pub port: u16,
|
||||
pub addr_family: AddrFamily,
|
||||
pub ip: Ipv6Addr,
|
||||
pub port: u16,
|
||||
pub proxy_proto: ProxyProto,
|
||||
pub bytes: u64,
|
||||
}
|
||||
|
||||
pub struct ClickHouseLogger {
|
||||
client: Client,
|
||||
table_base: String,
|
||||
db_name: String,
|
||||
}
|
||||
|
||||
impl ClickHouseLogger {
|
||||
pub fn new(config: &DatabaseConfig) -> anyhow::Result<Self> {
|
||||
let url = url::Url::parse(&config.dsn).map_err(|e| anyhow::anyhow!("invalid dsn: {}", e))?;
|
||||
let mut url =
|
||||
url::Url::parse(&config.dsn).map_err(|e| anyhow::anyhow!("invalid dsn: {}", e))?;
|
||||
let mut db_name = "default".to_string();
|
||||
|
||||
// specific handling for extracting database from path
|
||||
if let Some(path_segments) = url.path_segments().map(|c| c.collect::<Vec<_>>()) {
|
||||
if let Some(db) = path_segments.first() {
|
||||
if !db.is_empty() {
|
||||
db_name = db.to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear path from URL so client doesn't append it to requests
|
||||
url.set_path("");
|
||||
|
||||
let mut client = Client::default().with_url(url.as_str());
|
||||
|
||||
if let (Some(u), Some(p)) = (Some(url.username()), url.password()) {
|
||||
@@ -66,165 +79,172 @@ impl ClickHouseLogger {
|
||||
client = client.with_user(url.username());
|
||||
}
|
||||
|
||||
if let Some(path) = url.path_segments().map(|c| c.collect::<Vec<_>>()) {
|
||||
if let Some(db) = path.first() {
|
||||
if !db.is_empty() {
|
||||
client = client.with_database(*db);
|
||||
}
|
||||
}
|
||||
if !db_name.is_empty() && db_name != "default" {
|
||||
client = client.with_database(&db_name);
|
||||
}
|
||||
|
||||
// Config table name, default to "tcp_log" if missing
|
||||
// We expect config.tables to contain "tcp" -> "tablename"
|
||||
let table_base = config
|
||||
.tables
|
||||
.get("tcp")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "tcp_log".to_string());
|
||||
|
||||
Ok(Self { client, table_base })
|
||||
Ok(Self { client, db_name })
|
||||
}
|
||||
|
||||
pub async fn init(&self) -> anyhow::Result<()> {
|
||||
let table_v4 = format!("{}_v4", self.table_base);
|
||||
let table_v6 = format!("{}_v6", self.table_base);
|
||||
let view_name = &self.table_base;
|
||||
|
||||
let sql_v4 = format!(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS {} (
|
||||
service LowCardinality(String),
|
||||
conn_ts DateTime('UTC'),
|
||||
duration UInt32,
|
||||
port UInt16,
|
||||
ip IPv4,
|
||||
proxy_proto Enum8('None' = 0, 'V1' = 1, 'V2' = 2),
|
||||
bytes UInt64
|
||||
) ENGINE = MergeTree()
|
||||
ORDER BY (service, conn_ts);
|
||||
"#,
|
||||
table_v4
|
||||
);
|
||||
|
||||
let sql_v6 = format!(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS {} (
|
||||
service LowCardinality(String),
|
||||
conn_ts DateTime('UTC'),
|
||||
duration UInt32,
|
||||
port UInt16,
|
||||
ip IPv6,
|
||||
proxy_proto Enum8('None' = 0, 'V1' = 1, 'V2' = 2),
|
||||
bytes UInt64
|
||||
) ENGINE = MergeTree()
|
||||
ORDER BY (service, conn_ts);
|
||||
"#,
|
||||
table_v6
|
||||
);
|
||||
|
||||
let sql_view = format!(
|
||||
r#"
|
||||
CREATE VIEW IF NOT EXISTS {} AS
|
||||
SELECT
|
||||
service, conn_ts, duration, port,
|
||||
IPv4NumToString(ip) AS ip_str,
|
||||
proxy_proto,
|
||||
formatReadableSize(bytes) AS traffic
|
||||
FROM {}
|
||||
UNION ALL
|
||||
SELECT
|
||||
service, conn_ts, duration, port,
|
||||
IPv6NumToString(ip) AS ip_str,
|
||||
proxy_proto,
|
||||
formatReadableSize(bytes) AS traffic
|
||||
FROM {};
|
||||
"#,
|
||||
view_name, table_v4, table_v6
|
||||
);
|
||||
|
||||
self
|
||||
.client
|
||||
.query(&sql_v4)
|
||||
// Ensure database exists. Use 'default' database context to execute CREATE DATABASE.
|
||||
let sys_client = self.client.clone().with_database("default");
|
||||
sys_client
|
||||
.query(&format!("CREATE DATABASE IF NOT EXISTS {}", self.db_name))
|
||||
.execute()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("failed to create v4 table: {}", e))?;
|
||||
.map_err(|e| anyhow::anyhow!("failed to create database: {}", e))?;
|
||||
|
||||
self
|
||||
.client
|
||||
.query(&sql_v6)
|
||||
.execute()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("failed to create v6 table: {}", e))?;
|
||||
|
||||
// Schema Check / Migration
|
||||
for (table, is_v6) in [(&table_v4, false), (&table_v6, true)] {
|
||||
let ip_type = if is_v6 { "IPv6" } else { "IPv4" };
|
||||
let columns = [
|
||||
("service", "LowCardinality(String)"),
|
||||
("conn_ts", "DateTime('UTC')"),
|
||||
("duration", "UInt32"),
|
||||
("port", "UInt16"),
|
||||
("ip", ip_type),
|
||||
("proxy_proto", "Enum8('None' = 0, 'V1' = 1, 'V2' = 2)"),
|
||||
("bytes", "UInt64"),
|
||||
];
|
||||
for (name, type_def) in columns {
|
||||
self
|
||||
.client
|
||||
.query(&format!(
|
||||
"ALTER TABLE {} ADD COLUMN IF NOT EXISTS {} {}",
|
||||
table, name, type_def
|
||||
))
|
||||
.execute()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("failed to add column {} to {}: {}", name, table, e))?;
|
||||
}
|
||||
}
|
||||
|
||||
self
|
||||
.client
|
||||
.query(&sql_view)
|
||||
.execute()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("failed to create view: {}", e))?;
|
||||
// Check migrations
|
||||
self.check_migrations().await?;
|
||||
|
||||
info!("connected to database");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn insert_log(&self, log: TcpLog) -> anyhow::Result<()> {
|
||||
match log.ip {
|
||||
IpAddr::V4(ip) => {
|
||||
let row = TcpLogV4 {
|
||||
service: log.service,
|
||||
conn_ts: log.conn_ts,
|
||||
duration: log.duration,
|
||||
port: log.port,
|
||||
ip: u32::from(ip),
|
||||
proxy_proto: log.proxy_proto,
|
||||
bytes: log.bytes,
|
||||
};
|
||||
let table = format!("{}_v4", self.table_base);
|
||||
let mut insert = self.client.insert(&table)?;
|
||||
insert.write(&row).await?;
|
||||
insert.end().await?;
|
||||
}
|
||||
IpAddr::V6(ip) => {
|
||||
let row = TcpLogV6 {
|
||||
service: log.service,
|
||||
conn_ts: log.conn_ts,
|
||||
duration: log.duration,
|
||||
port: log.port,
|
||||
ip,
|
||||
proxy_proto: log.proxy_proto,
|
||||
bytes: log.bytes,
|
||||
};
|
||||
let table = format!("{}_v6", self.table_base);
|
||||
let mut insert = self.client.insert(&table)?;
|
||||
insert.write(&row).await?;
|
||||
insert.end().await?;
|
||||
async fn check_migrations(&self) -> anyhow::Result<()> {
|
||||
// Create migrations table
|
||||
self
|
||||
.client
|
||||
.query(
|
||||
"
|
||||
CREATE TABLE IF NOT EXISTS db_migrations (
|
||||
version String,
|
||||
success UInt8,
|
||||
apply_ts DateTime64 DEFAULT now()
|
||||
) ENGINE = ReplacingMergeTree(apply_ts)
|
||||
ORDER BY version
|
||||
",
|
||||
)
|
||||
.execute()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("failed to create migrations table: {}", e))?;
|
||||
|
||||
// Get current DB version
|
||||
#[derive(Row, Deserialize)]
|
||||
struct MigrationRow {
|
||||
version: String,
|
||||
success: u8,
|
||||
}
|
||||
|
||||
let last_migration = self
|
||||
.client
|
||||
.query("SELECT version, success FROM db_migrations ORDER BY apply_ts DESC LIMIT 1")
|
||||
.fetch_optional::<MigrationRow>()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("failed to fetch last migration: {}", e))?;
|
||||
|
||||
let (current_db_version, success) = last_migration
|
||||
.map(|r| (r.version, r.success == 1))
|
||||
.unwrap_or_else(|| ("v0.0.0".to_string(), true));
|
||||
|
||||
if current_db_version == crate::VERSION && success {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !success {
|
||||
error!(
|
||||
"previous migration to {} failed. retrying...",
|
||||
current_db_version
|
||||
);
|
||||
} else {
|
||||
info!(
|
||||
"migrating database from {} to {}",
|
||||
current_db_version,
|
||||
crate::VERSION
|
||||
);
|
||||
}
|
||||
self.run_migrations(¤t_db_version, success).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_migrations(&self, from_version: &str, last_success: bool) -> anyhow::Result<()> {
|
||||
if from_version < "v0.0.1" || (from_version == "v0.0.1" && !last_success) {
|
||||
info!("applying migration v0.0.1...");
|
||||
if let Err(e) = self.apply_v0_0_1().await {
|
||||
error!("migration v0.0.1 failed: {}", e);
|
||||
// Record failure
|
||||
let _ = self
|
||||
.client
|
||||
.query("INSERT INTO db_migrations (version, success) VALUES (?, 0)")
|
||||
.bind(crate::VERSION)
|
||||
.execute()
|
||||
.await;
|
||||
return Err(e);
|
||||
}
|
||||
// Record success
|
||||
self
|
||||
.client
|
||||
.query("INSERT INTO db_migrations (version, success) VALUES (?, 1)")
|
||||
.bind(crate::VERSION)
|
||||
.execute()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("failed to record migration success: {}", e))?;
|
||||
info!("migration v0.0.1 applied successfully");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn apply_v0_0_1(&self) -> anyhow::Result<()> {
|
||||
// 1. Create table (tcp_log)
|
||||
let sql_create = r#"
|
||||
CREATE TABLE IF NOT EXISTS tcp_log (
|
||||
service LowCardinality(String),
|
||||
conn_ts DateTime64(3),
|
||||
duration UInt32,
|
||||
addr_family Enum8('unix'=1, 'ipv4'=2, 'ipv6'=10),
|
||||
ip IPv6,
|
||||
port UInt16,
|
||||
proxy_proto Enum8('None' = 0, 'V1' = 1, 'V2' = 2),
|
||||
bytes UInt64
|
||||
) ENGINE = MergeTree()
|
||||
ORDER BY (service, conn_ts);
|
||||
"#;
|
||||
self.client.query(sql_create).execute().await?;
|
||||
|
||||
// 2. Create View
|
||||
let sql_view_refined = r#"
|
||||
CREATE VIEW IF NOT EXISTS tcp_log_view AS
|
||||
SELECT
|
||||
service, conn_ts, duration, addr_family,
|
||||
multiIf(
|
||||
addr_family = 1, 'unix socket',
|
||||
addr_family = 2, IPv4NumToString(toIPv4(ip)),
|
||||
IPv6NumToString(ip)
|
||||
) as ip_str,
|
||||
port,
|
||||
proxy_proto,
|
||||
formatReadableSize(bytes) AS traffic
|
||||
FROM tcp_log
|
||||
"#;
|
||||
|
||||
self.client.query(sql_view_refined).execute().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn insert_log(&self, log: TcpLog) -> anyhow::Result<()> {
|
||||
let ipv6 = match log.ip {
|
||||
IpAddr::V4(ip) => ip.to_ipv6_mapped(),
|
||||
IpAddr::V6(ip) => ip,
|
||||
};
|
||||
|
||||
let row = TcpLogNew {
|
||||
service: log.service,
|
||||
conn_ts: log.conn_ts,
|
||||
duration: log.duration,
|
||||
addr_family: log.addr_family,
|
||||
ip: ipv6,
|
||||
port: log.port,
|
||||
proxy_proto: log.proxy_proto,
|
||||
bytes: log.bytes,
|
||||
};
|
||||
|
||||
let mut insert = self.client.insert("tcp_log")?;
|
||||
insert.write(&row).await?;
|
||||
insert.end().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ use std::env;
|
||||
use std::path::Path;
|
||||
use tracing::{error, info};
|
||||
|
||||
/// Crate version taken from `CARGO_PKG_VERSION` at compile time, prefixed with `v` (e.g. `v0.0.1`).
pub const VERSION: &str = concat!("v", env!("CARGO_PKG_VERSION"));
|
||||
|
||||
fn print_help() {
|
||||
println!("traudit - a reverse proxy with auditing capabilities");
|
||||
println!();
|
||||
@@ -18,6 +20,7 @@ fn print_help() {
|
||||
println!("options:");
|
||||
println!(" -f <config_file> path to the yaml configuration file");
|
||||
println!(" -t, --test test configuration and exit");
|
||||
println!(" -v, --version print version");
|
||||
println!(" -h, --help print this help message");
|
||||
println!();
|
||||
println!("project: https://github.com/awfufu/traudit");
|
||||
@@ -56,6 +59,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
print_help();
|
||||
return Ok(());
|
||||
}
|
||||
"-v" | "--version" => {
|
||||
println!("{}", VERSION);
|
||||
return Ok(());
|
||||
}
|
||||
_ => {
|
||||
bail!("unknown argument: {}\n\nuse -h for help", args[i]);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user