feat: 支持 ipv4/ipv6 双栈访问

This commit is contained in:
mofeng
2026-01-30 14:47:41 +08:00
parent 58f9020192
commit 6a110258b9
13 changed files with 465 additions and 107 deletions

View File

@@ -576,7 +576,9 @@ pub struct WebConfig {
pub http_port: u16,
/// HTTPS port
pub https_port: u16,
/// Bind address
/// Bind addresses (preferred)
pub bind_addresses: Vec<String>,
/// Bind address (legacy)
pub bind_address: String,
/// Enable HTTPS
pub https_enabled: bool,
@@ -591,6 +593,7 @@ impl Default for WebConfig {
Self {
http_port: 8080,
https_port: 8443,
bind_addresses: Vec::new(),
bind_address: "0.0.0.0".to_string(),
https_enabled: false,
ssl_cert_path: None,

View File

@@ -1,9 +1,11 @@
use std::net::SocketAddr;
use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::sync::Arc;
use axum_server::tls_rustls::RustlsConfig;
use clap::{Parser, ValueEnum};
use futures::{stream::FuturesUnordered, StreamExt};
use rustls::crypto::{ring, CryptoProvider};
use tokio::sync::broadcast;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@@ -19,6 +21,7 @@ use one_kvm::msd::MsdController;
use one_kvm::otg::{configfs, OtgService};
use one_kvm::rustdesk::RustDeskService;
use one_kvm::state::AppState;
use one_kvm::utils::bind_tcp_listener;
use one_kvm::video::format::{PixelFormat, Resolution};
use one_kvm::video::{Streamer, VideoStreamManager};
use one_kvm::web;
@@ -134,7 +137,8 @@ async fn main() -> anyhow::Result<()> {
// Apply CLI argument overrides to config (only if explicitly specified)
if let Some(addr) = args.address {
config.web.bind_address = addr;
config.web.bind_address = addr.clone();
config.web.bind_addresses = vec![addr];
}
if let Some(port) = args.http_port {
config.web.http_port = port;
@@ -153,19 +157,18 @@ async fn main() -> anyhow::Result<()> {
config.web.ssl_key_path = Some(key_path.to_string_lossy().to_string());
}
// Log final configuration
if config.web.https_enabled {
tracing::info!(
"Server will listen on: https://{}:{}",
config.web.bind_address,
config.web.https_port
);
let bind_ips = resolve_bind_addresses(&config.web)?;
let scheme = if config.web.https_enabled { "https" } else { "http" };
let bind_port = if config.web.https_enabled {
config.web.https_port
} else {
tracing::info!(
"Server will listen on: http://{}:{}",
config.web.bind_address,
config.web.http_port
);
config.web.http_port
};
// Log final configuration
for ip in &bind_ips {
let addr = SocketAddr::new(*ip, bind_port);
tracing::info!("Server will listen on: {}://{}", scheme, addr);
}
// Initialize session store
@@ -598,12 +601,8 @@ async fn main() -> anyhow::Result<()> {
// Create router
let app = web::create_router(state.clone());
// Determine bind address based on HTTPS setting
let bind_addr: SocketAddr = if config.web.https_enabled {
format!("{}:{}", config.web.bind_address, config.web.https_port).parse()?
} else {
format!("{}:{}", config.web.bind_address, config.web.http_port).parse()?
};
// Bind sockets for configured addresses
let listeners = bind_tcp_listeners(&bind_ips, bind_port)?;
// Setup graceful shutdown
let shutdown_signal = async move {
@@ -640,33 +639,44 @@ async fn main() -> anyhow::Result<()> {
RustlsConfig::from_pem_file(&cert_path, &key_path).await?
};
tracing::info!("Starting HTTPS server on {}", bind_addr);
let mut servers = FuturesUnordered::new();
for listener in listeners {
let local_addr = listener.local_addr()?;
tracing::info!("Starting HTTPS server on {}", local_addr);
let server = axum_server::bind_rustls(bind_addr, tls_config).serve(app.into_make_service());
let server = axum_server::from_tcp_rustls(listener, tls_config.clone())?
.serve(app.clone().into_make_service());
servers.push(async move { server.await });
}
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = server => {
if let Err(e) = result {
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("HTTPS server error: {}", e);
}
cleanup(&state).await;
}
}
} else {
tracing::info!("Starting HTTP server on {}", bind_addr);
let mut servers = FuturesUnordered::new();
for listener in listeners {
let local_addr = listener.local_addr()?;
tracing::info!("Starting HTTP server on {}", local_addr);
let listener = tokio::net::TcpListener::bind(bind_addr).await?;
let server = axum::serve(listener, app);
let listener = tokio::net::TcpListener::from_std(listener)?;
let server = axum::serve(listener, app.clone());
servers.push(async move { server.await });
}
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = server => {
if let Err(e) = result {
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("HTTP server error: {}", e);
}
cleanup(&state).await;
@@ -719,6 +729,47 @@ fn get_data_dir() -> PathBuf {
PathBuf::from("/etc/one-kvm")
}
/// Resolve bind IPs from config, preferring bind_addresses when set.
fn resolve_bind_addresses(web: &config::WebConfig) -> anyhow::Result<Vec<IpAddr>> {
let raw_addrs = if !web.bind_addresses.is_empty() {
web.bind_addresses.as_slice()
} else {
std::slice::from_ref(&web.bind_address)
};
let mut seen = HashSet::new();
let mut addrs = Vec::new();
for addr in raw_addrs {
let ip: IpAddr = addr
.parse()
.map_err(|_| anyhow::anyhow!("Invalid bind address: {}", addr))?;
if seen.insert(ip) {
addrs.push(ip);
}
}
Ok(addrs)
}
fn bind_tcp_listeners(addrs: &[IpAddr], port: u16) -> anyhow::Result<Vec<std::net::TcpListener>> {
let mut listeners = Vec::new();
for ip in addrs {
let addr = SocketAddr::new(*ip, port);
match bind_tcp_listener(addr) {
Ok(listener) => listeners.push(listener),
Err(err) => {
tracing::warn!("Failed to bind {}: {}", addr, err);
}
}
}
if listeners.is_empty() {
anyhow::bail!("Failed to bind any addresses on port {}", port);
}
Ok(listeners)
}
/// Parse video format and resolution from config (avoids code duplication)
fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
let format = config

View File

@@ -23,7 +23,7 @@ pub mod protocol;
pub mod punch;
pub mod rendezvous;
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
@@ -37,6 +37,7 @@ use tracing::{debug, error, info, warn};
use crate::audio::AudioController;
use crate::hid::HidController;
use crate::video::stream_manager::VideoStreamManager;
use crate::utils::bind_tcp_listener;
use self::config::RustDeskConfig;
use self::connection::ConnectionManager;
@@ -84,7 +85,7 @@ pub struct RustDeskService {
status: Arc<RwLock<ServiceStatus>>,
rendezvous: Arc<RwLock<Option<Arc<RendezvousMediator>>>>,
rendezvous_handle: Arc<RwLock<Option<JoinHandle<()>>>>,
tcp_listener_handle: Arc<RwLock<Option<JoinHandle<()>>>>,
tcp_listener_handle: Arc<RwLock<Option<Vec<JoinHandle<()>>>>>,
listen_port: Arc<RwLock<u16>>,
connection_manager: Arc<ConnectionManager>,
video_manager: Arc<VideoStreamManager>,
@@ -212,8 +213,8 @@ impl RustDeskService {
// Start TCP listener BEFORE the rendezvous mediator to ensure port is set correctly
// This prevents race condition where mediator starts registration with wrong port
let (tcp_handle, listen_port) = self.start_tcp_listener_with_port().await?;
*self.tcp_listener_handle.write() = Some(tcp_handle);
let (tcp_handles, listen_port) = self.start_tcp_listener_with_port().await?;
*self.tcp_listener_handle.write() = Some(tcp_handles);
// Set the listen port on mediator before starting the registration loop
mediator.set_listen_port(listen_port);
@@ -373,52 +374,83 @@ impl RustDeskService {
/// Start TCP listener for direct peer connections
/// Returns the join handle and the port that was bound
async fn start_tcp_listener_with_port(&self) -> anyhow::Result<(JoinHandle<()>, u16)> {
async fn start_tcp_listener_with_port(&self) -> anyhow::Result<(Vec<JoinHandle<()>>, u16)> {
// Try to bind to the default port, or find an available port
let listener = match TcpListener::bind(format!("0.0.0.0:{}", DIRECT_LISTEN_PORT)).await {
Ok(l) => l,
Err(_) => {
// Try binding to port 0 to get an available port
TcpListener::bind("0.0.0.0:0").await?
let (listeners, listen_port) = match self.bind_direct_listeners(DIRECT_LISTEN_PORT) {
Ok(result) => result,
Err(err) => {
warn!(
"Failed to bind RustDesk TCP on port {}: {}, falling back to random port",
DIRECT_LISTEN_PORT, err
);
self.bind_direct_listeners(0)?
}
};
let local_addr = listener.local_addr()?;
let listen_port = local_addr.port();
*self.listen_port.write() = listen_port;
info!("RustDesk TCP listener started on {}", local_addr);
let connection_manager = self.connection_manager.clone();
let mut shutdown_rx = self.shutdown_tx.subscribe();
let mut handles = Vec::new();
let handle = tokio::spawn(async move {
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, peer_addr)) => {
info!("Accepted direct connection from {}", peer_addr);
let conn_mgr = connection_manager.clone();
tokio::spawn(async move {
if let Err(e) = conn_mgr.accept_connection(stream, peer_addr).await {
error!("Failed to handle direct connection from {}: {}", peer_addr, e);
}
});
}
Err(e) => {
error!("TCP accept error: {}", e);
for listener in listeners {
let local_addr = listener.local_addr()?;
info!("RustDesk TCP listener started on {}", local_addr);
let conn_mgr = connection_manager.clone();
let mut shutdown_rx = self.shutdown_tx.subscribe();
let handle = tokio::spawn(async move {
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, peer_addr)) => {
info!("Accepted direct connection from {}", peer_addr);
let conn_mgr = conn_mgr.clone();
tokio::spawn(async move {
if let Err(e) = conn_mgr.accept_connection(stream, peer_addr).await {
error!("Failed to handle direct connection from {}: {}", peer_addr, e);
}
});
}
Err(e) => {
error!("TCP accept error: {}", e);
}
}
}
}
_ = shutdown_rx.recv() => {
info!("TCP listener shutting down");
break;
_ = shutdown_rx.recv() => {
info!("TCP listener shutting down");
break;
}
}
}
}
});
});
handles.push(handle);
}
Ok((handle, listen_port))
Ok((handles, listen_port))
}
fn bind_direct_listeners(&self, port: u16) -> anyhow::Result<(Vec<TcpListener>, u16)> {
let v4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port);
let v4_listener = bind_tcp_listener(v4_addr)?;
let listen_port = v4_listener.local_addr()?.port();
let mut listeners = vec![TcpListener::from_std(v4_listener)?];
let v6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), listen_port);
match bind_tcp_listener(v6_addr) {
Ok(v6_listener) => {
listeners.push(TcpListener::from_std(v6_listener)?);
}
Err(err) => {
warn!(
"IPv6 listener unavailable on port {}: {}, continuing with IPv4 only",
listen_port, err
);
}
}
Ok((listeners, listen_port))
}
/// Stop the RustDesk service
@@ -446,8 +478,10 @@ impl RustDeskService {
}
// Wait for TCP listener task to finish
if let Some(handle) = self.tcp_listener_handle.write().take() {
handle.abort();
if let Some(handles) = self.tcp_listener_handle.write().take() {
for handle in handles {
handle.abort();
}
}
*self.rendezvous.write() = None;

View File

@@ -4,7 +4,7 @@
//! It registers the device ID and public key, handles punch hole requests,
//! and relay requests.
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
@@ -15,6 +15,8 @@ use tokio::sync::broadcast;
use tokio::time::interval;
use tracing::{debug, error, info, warn};
use crate::utils::bind_udp_socket;
use super::config::RustDeskConfig;
use super::crypto::{KeyPair, SigningKeyPair};
use super::protocol::{
@@ -288,8 +290,13 @@ impl RendezvousMediator {
.next()
.ok_or_else(|| anyhow::anyhow!("Failed to resolve {}", addr))?;
// Create UDP socket
let socket = UdpSocket::bind("0.0.0.0:0").await?;
// Create UDP socket (match address family, enforce IPV6_V6ONLY)
let bind_addr = match server_addr {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
};
let std_socket = bind_udp_socket(bind_addr)?;
let socket = UdpSocket::from_std(std_socket)?;
socket.connect(server_addr).await?;
info!("Connected to rendezvous server at {}", server_addr);

View File

@@ -3,5 +3,7 @@
//! This module contains common utilities used across the codebase.
pub mod throttle;
pub mod net;
pub use throttle::LogThrottler;
pub use net::{bind_tcp_listener, bind_udp_socket};

84
src/utils/net.rs Normal file
View File

@@ -0,0 +1,84 @@
//! Networking helpers for binding sockets with explicit IPv6-only behavior.
use std::io;
use std::net::{SocketAddr, TcpListener, UdpSocket};
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
use nix::sys::socket::{
self, sockopt, AddressFamily, Backlog, SockFlag, SockProtocol, SockType, SockaddrIn,
SockaddrIn6,
};
fn socket_addr_family(addr: &SocketAddr) -> AddressFamily {
match addr {
SocketAddr::V4(_) => AddressFamily::Inet,
SocketAddr::V6(_) => AddressFamily::Inet6,
}
}
/// Bind a TCP listener with IPv6-only set for IPv6 sockets.
pub fn bind_tcp_listener(addr: SocketAddr) -> io::Result<TcpListener> {
let domain = socket_addr_family(&addr);
let fd = socket::socket(
domain,
SockType::Stream,
SockFlag::SOCK_CLOEXEC,
SockProtocol::Tcp,
)
.map_err(io::Error::from)?;
socket::setsockopt(&fd, sockopt::ReuseAddr, &true).map_err(io::Error::from)?;
if matches!(addr, SocketAddr::V6(_)) {
socket::setsockopt(&fd, sockopt::Ipv6V6Only, &true).map_err(io::Error::from)?;
}
match addr {
SocketAddr::V4(v4) => {
let sockaddr = SockaddrIn::from(v4);
socket::bind(fd.as_raw_fd(), &sockaddr).map_err(io::Error::from)?;
}
SocketAddr::V6(v6) => {
let sockaddr = SockaddrIn6::from(v6);
socket::bind(fd.as_raw_fd(), &sockaddr).map_err(io::Error::from)?;
}
}
socket::listen(&fd, Backlog::MAXCONN).map_err(io::Error::from)?;
let listener = unsafe { TcpListener::from_raw_fd(fd.into_raw_fd()) };
listener.set_nonblocking(true)?;
Ok(listener)
}
/// Bind a UDP socket with IPv6-only set for IPv6 sockets.
pub fn bind_udp_socket(addr: SocketAddr) -> io::Result<UdpSocket> {
let domain = socket_addr_family(&addr);
let fd = socket::socket(
domain,
SockType::Datagram,
SockFlag::SOCK_CLOEXEC,
SockProtocol::Udp,
)
.map_err(io::Error::from)?;
socket::setsockopt(&fd, sockopt::ReuseAddr, &true).map_err(io::Error::from)?;
if matches!(addr, SocketAddr::V6(_)) {
socket::setsockopt(&fd, sockopt::Ipv6V6Only, &true).map_err(io::Error::from)?;
}
match addr {
SocketAddr::V4(v4) => {
let sockaddr = SockaddrIn::from(v4);
socket::bind(fd.as_raw_fd(), &sockaddr).map_err(io::Error::from)?;
}
SocketAddr::V6(v6) => {
let sockaddr = SockaddrIn6::from(v6);
socket::bind(fd.as_raw_fd(), &sockaddr).map_err(io::Error::from)?;
}
}
let socket = unsafe { UdpSocket::from_raw_fd(fd.into_raw_fd()) };
socket.set_nonblocking(true)?;
Ok(socket)
}

View File

@@ -610,6 +610,7 @@ impl RustDeskConfigUpdate {
pub struct WebConfigUpdate {
pub http_port: Option<u16>,
pub https_port: Option<u16>,
pub bind_addresses: Option<Vec<String>>,
pub bind_address: Option<String>,
pub https_enabled: Option<bool>,
}
@@ -626,6 +627,13 @@ impl WebConfigUpdate {
return Err(AppError::BadRequest("HTTPS port cannot be 0".into()));
}
}
if let Some(ref addrs) = self.bind_addresses {
for addr in addrs {
if addr.parse::<std::net::IpAddr>().is_err() {
return Err(AppError::BadRequest("Invalid bind address".into()));
}
}
}
if let Some(ref addr) = self.bind_address {
if addr.parse::<std::net::IpAddr>().is_err() {
return Err(AppError::BadRequest("Invalid bind address".into()));
@@ -641,8 +649,16 @@ impl WebConfigUpdate {
if let Some(port) = self.https_port {
config.https_port = port;
}
if let Some(ref addr) = self.bind_address {
if let Some(ref addrs) = self.bind_addresses {
config.bind_addresses = addrs.clone();
if let Some(first) = addrs.first() {
config.bind_address = first.clone();
}
} else if let Some(ref addr) = self.bind_address {
config.bind_address = addr.clone();
if config.bind_addresses.is_empty() {
config.bind_addresses = vec![addr.clone()];
}
}
if let Some(enabled) = self.https_enabled {
config.https_enabled = enabled;

View File

@@ -316,28 +316,11 @@ fn get_network_addresses() -> Vec<NetworkAddress> {
Err(_) => return Vec::new(),
};
// Build a map of interface name -> IPv4 address
let mut ipv4_map: std::collections::HashMap<String, String> = std::collections::HashMap::new();
for ifaddr in all_addrs {
// Skip loopback
if ifaddr.interface_name == "lo" {
continue;
}
// Only collect IPv4 addresses (skip if already have one for this interface)
if !ipv4_map.contains_key(&ifaddr.interface_name) {
if let Some(addr) = ifaddr.address {
if let Some(sockaddr_in) = addr.as_sockaddr_in() {
ipv4_map.insert(ifaddr.interface_name.clone(), sockaddr_in.ip().to_string());
}
}
}
}
// Now check which interfaces are up
let mut addresses = Vec::new();
// Check which interfaces are up
let mut up_ifaces = std::collections::HashSet::new();
let net_dir = match std::fs::read_dir("/sys/class/net") {
Ok(dir) => dir,
Err(_) => return addresses,
Err(_) => return Vec::new(),
};
for entry in net_dir.flatten() {
@@ -361,12 +344,43 @@ fn get_network_addresses() -> Vec<NetworkAddress> {
continue;
}
// Get IP from pre-fetched map
if let Some(ip) = ipv4_map.remove(&iface_name) {
addresses.push(NetworkAddress {
interface: iface_name,
ip,
});
up_ifaces.insert(iface_name);
}
let mut addresses = Vec::new();
let mut seen = std::collections::HashSet::new();
for ifaddr in all_addrs {
let iface_name = &ifaddr.interface_name;
if iface_name == "lo" || !up_ifaces.contains(iface_name) {
continue;
}
if let Some(addr) = ifaddr.address {
if let Some(sockaddr_in) = addr.as_sockaddr_in() {
let ip = sockaddr_in.ip();
if ip.is_loopback() {
continue;
}
let ip_str = ip.to_string();
if seen.insert((iface_name.clone(), ip_str.clone())) {
addresses.push(NetworkAddress {
interface: iface_name.clone(),
ip: ip_str,
});
}
} else if let Some(sockaddr_in6) = addr.as_sockaddr_in6() {
let ip = sockaddr_in6.ip();
if ip.is_loopback() || ip.is_unspecified() || ip.is_unicast_link_local() {
continue;
}
let ip_str = ip.to_string();
if seen.insert((iface_name.clone(), ip_str.clone())) {
addresses.push(NetworkAddress {
interface: iface_name.clone(),
ip: ip_str,
});
}
}
}
}