mirror of
https://github.com/awfufu/traudit
synced 2026-03-01 05:29:44 +08:00
feat: support unix socket in forward_to and fix log writing
This commit is contained in:
@@ -59,8 +59,16 @@ pub async fn zero_copy_bidirectional(
|
||||
) -> (u64, io::Result<()>) {
|
||||
// We own the streams now, so we can split references to them for the join.
|
||||
let ((c2s_bytes, c2s_res), (s2c_bytes, s2c_res)) = tokio::join!(
|
||||
splice_loop(&inbound, &outbound),
|
||||
splice_loop(&outbound, &inbound)
|
||||
async {
|
||||
let res = splice_loop(&inbound, &outbound).await;
|
||||
let _ = outbound.shutdown_write();
|
||||
res
|
||||
},
|
||||
async {
|
||||
let res = splice_loop(&outbound, &inbound).await;
|
||||
let _ = inbound.shutdown_write();
|
||||
res
|
||||
}
|
||||
);
|
||||
|
||||
let total = c2s_bytes + s2c_bytes;
|
||||
|
||||
@@ -371,11 +371,18 @@ async fn handle_connection(
|
||||
buffer = buf;
|
||||
if let Some(info) = proxy_info {
|
||||
let physical = inbound.peer_addr_string()?;
|
||||
// INFO [ssh] unix://./test.sock <- 192.168.1.1:12345 (unix_socket)
|
||||
// Or INFO [ssh] 0.0.0.0:2222 <- 1.2.3.4:5678 (1.2.3.4:5678)
|
||||
|
||||
// Format: [ssh] unix://test.sock <- RealIP:Port (local)
|
||||
// or [ssh] 0.0.0.0:2222 <- RealIP:Port (1.2.3.4:5678)
|
||||
let physical_fmt = if matches!(inbound, InboundStream::Unix(_)) {
|
||||
"local".to_string()
|
||||
} else {
|
||||
physical
|
||||
};
|
||||
|
||||
info!(
|
||||
"[{}] {} <- {} ({})",
|
||||
service.name, listen_addr, info.source, physical
|
||||
service.name, listen_addr, info.source, physical_fmt
|
||||
);
|
||||
final_ip = info.source.ip();
|
||||
final_port = info.source.port();
|
||||
@@ -416,11 +423,8 @@ async fn handle_connection(
|
||||
}
|
||||
} else {
|
||||
let addr = if matches!(inbound, InboundStream::Unix(_)) {
|
||||
// If Unix socket without proxy, display 127.0.0.1:0 as per logic or ...
|
||||
// User requested: unix://... <- 127.0.0.1:port
|
||||
// But inbound.peer_addr_string() for unix is "unix_socket"
|
||||
// And we set final_ip to 127.0.0.1, final_port to 0
|
||||
format!("{}:{}", final_ip, final_port)
|
||||
// [ssh] unix://test.sock <- local
|
||||
"local".to_string()
|
||||
} else {
|
||||
inbound.peer_addr_string()?
|
||||
};
|
||||
|
||||
@@ -16,6 +16,9 @@ impl UpstreamStream {
|
||||
if addr.starts_with('/') {
|
||||
let stream = UnixStream::connect(addr).await?;
|
||||
Ok(UpstreamStream::Unix(stream))
|
||||
} else if let Some(path) = addr.strip_prefix("unix://") {
|
||||
let stream = UnixStream::connect(path).await?;
|
||||
Ok(UpstreamStream::Unix(stream))
|
||||
} else {
|
||||
let stream = TcpStream::connect(addr).await?;
|
||||
stream.set_nodelay(true)?;
|
||||
@@ -121,6 +124,13 @@ impl AsyncStream {
|
||||
AsyncStream::Unix(fd) => perform_splice_write(fd, pipe_in, len).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shutdown_write(&self) -> io::Result<()> {
|
||||
match self {
|
||||
AsyncStream::Tcp(fd) => fd.get_ref().shutdown(std::net::Shutdown::Write),
|
||||
AsyncStream::Unix(fd) => fd.get_ref().shutdown(std::net::Shutdown::Write),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_splice_read<T: AsRawFd>(
|
||||
@@ -178,3 +188,49 @@ async fn perform_splice_write<T: AsRawFd>(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
use tokio::net::UnixListener;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connect_unix_scheme() {
|
||||
let dir = tempdir().unwrap();
|
||||
let socket_path = dir.path().join("test_scheme.sock");
|
||||
let socket_path_str = socket_path.to_str().unwrap();
|
||||
|
||||
// Start a listener
|
||||
let _listener = UnixListener::bind(&socket_path).unwrap();
|
||||
|
||||
// Test unix:// path
|
||||
let addr = format!("unix://{}", socket_path_str);
|
||||
let stream = UpstreamStream::connect(&addr).await;
|
||||
assert!(
|
||||
stream.is_ok(),
|
||||
"Failed to connect to unix socket with unix:// prefix: {:?}",
|
||||
stream.err()
|
||||
);
|
||||
assert!(matches!(stream.unwrap(), UpstreamStream::Unix(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connect_unix_absolute_path() {
|
||||
let dir = tempdir().unwrap();
|
||||
let socket_path = dir.path().join("test_abs.sock");
|
||||
let socket_path_str = socket_path.to_str().unwrap();
|
||||
|
||||
// Start a listener
|
||||
let _listener = UnixListener::bind(&socket_path).unwrap();
|
||||
|
||||
// Test absolute path (legacy support)
|
||||
let stream = UpstreamStream::connect(socket_path_str).await;
|
||||
assert!(
|
||||
stream.is_ok(),
|
||||
"Failed to connect to unix socket with absolute path: {:?}",
|
||||
stream.err()
|
||||
);
|
||||
assert!(matches!(stream.unwrap(), UpstreamStream::Unix(_)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
[Unit]
|
||||
Description=Traudit Reverse Proxy
|
||||
Description=Traudit Reverse Proxy (https://github.com/awfufu/traudit)
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
Type=notify
|
||||
RuntimeDirectory=traudit
|
||||
WorkingDirectory=/run/traudit
|
||||
ExecStart=/usr/bin/traudit -f /etc/traudit/config.yaml
|
||||
Restart=on-failure
|
||||
RestartSec=5s
|
||||
|
||||
Reference in New Issue
Block a user