feat: 支持 ipv4/ipv6 双栈访问

This commit is contained in:
mofeng
2026-01-30 14:47:41 +08:00
parent 58f9020192
commit 6a110258b9
13 changed files with 465 additions and 107 deletions

View File

@@ -1,9 +1,11 @@
use std::net::SocketAddr;
use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::sync::Arc;
use axum_server::tls_rustls::RustlsConfig;
use clap::{Parser, ValueEnum};
use futures::{stream::FuturesUnordered, StreamExt};
use rustls::crypto::{ring, CryptoProvider};
use tokio::sync::broadcast;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@@ -19,6 +21,7 @@ use one_kvm::msd::MsdController;
use one_kvm::otg::{configfs, OtgService};
use one_kvm::rustdesk::RustDeskService;
use one_kvm::state::AppState;
use one_kvm::utils::bind_tcp_listener;
use one_kvm::video::format::{PixelFormat, Resolution};
use one_kvm::video::{Streamer, VideoStreamManager};
use one_kvm::web;
@@ -134,7 +137,8 @@ async fn main() -> anyhow::Result<()> {
// Apply CLI argument overrides to config (only if explicitly specified)
if let Some(addr) = args.address {
config.web.bind_address = addr;
config.web.bind_address = addr.clone();
config.web.bind_addresses = vec![addr];
}
if let Some(port) = args.http_port {
config.web.http_port = port;
@@ -153,19 +157,18 @@ async fn main() -> anyhow::Result<()> {
config.web.ssl_key_path = Some(key_path.to_string_lossy().to_string());
}
// Log final configuration
if config.web.https_enabled {
tracing::info!(
"Server will listen on: https://{}:{}",
config.web.bind_address,
config.web.https_port
);
let bind_ips = resolve_bind_addresses(&config.web)?;
let scheme = if config.web.https_enabled { "https" } else { "http" };
let bind_port = if config.web.https_enabled {
config.web.https_port
} else {
tracing::info!(
"Server will listen on: http://{}:{}",
config.web.bind_address,
config.web.http_port
);
config.web.http_port
};
// Log final configuration
for ip in &bind_ips {
let addr = SocketAddr::new(*ip, bind_port);
tracing::info!("Server will listen on: {}://{}", scheme, addr);
}
// Initialize session store
@@ -598,12 +601,8 @@ async fn main() -> anyhow::Result<()> {
// Create router
let app = web::create_router(state.clone());
// Determine bind address based on HTTPS setting
let bind_addr: SocketAddr = if config.web.https_enabled {
format!("{}:{}", config.web.bind_address, config.web.https_port).parse()?
} else {
format!("{}:{}", config.web.bind_address, config.web.http_port).parse()?
};
// Bind sockets for configured addresses
let listeners = bind_tcp_listeners(&bind_ips, bind_port)?;
// Setup graceful shutdown
let shutdown_signal = async move {
@@ -640,33 +639,44 @@ async fn main() -> anyhow::Result<()> {
RustlsConfig::from_pem_file(&cert_path, &key_path).await?
};
tracing::info!("Starting HTTPS server on {}", bind_addr);
let mut servers = FuturesUnordered::new();
for listener in listeners {
let local_addr = listener.local_addr()?;
tracing::info!("Starting HTTPS server on {}", local_addr);
let server = axum_server::bind_rustls(bind_addr, tls_config).serve(app.into_make_service());
let server = axum_server::from_tcp_rustls(listener, tls_config.clone())?
.serve(app.clone().into_make_service());
servers.push(async move { server.await });
}
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = server => {
if let Err(e) = result {
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("HTTPS server error: {}", e);
}
cleanup(&state).await;
}
}
} else {
tracing::info!("Starting HTTP server on {}", bind_addr);
let mut servers = FuturesUnordered::new();
for listener in listeners {
let local_addr = listener.local_addr()?;
tracing::info!("Starting HTTP server on {}", local_addr);
let listener = tokio::net::TcpListener::bind(bind_addr).await?;
let server = axum::serve(listener, app);
let listener = tokio::net::TcpListener::from_std(listener)?;
let server = axum::serve(listener, app.clone());
servers.push(async move { server.await });
}
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = server => {
if let Err(e) = result {
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("HTTP server error: {}", e);
}
cleanup(&state).await;
@@ -719,6 +729,47 @@ fn get_data_dir() -> PathBuf {
PathBuf::from("/etc/one-kvm")
}
/// Resolve bind IPs from config, preferring bind_addresses when set.
fn resolve_bind_addresses(web: &config::WebConfig) -> anyhow::Result<Vec<IpAddr>> {
let raw_addrs = if !web.bind_addresses.is_empty() {
web.bind_addresses.as_slice()
} else {
std::slice::from_ref(&web.bind_address)
};
let mut seen = HashSet::new();
let mut addrs = Vec::new();
for addr in raw_addrs {
let ip: IpAddr = addr
.parse()
.map_err(|_| anyhow::anyhow!("Invalid bind address: {}", addr))?;
if seen.insert(ip) {
addrs.push(ip);
}
}
Ok(addrs)
}
fn bind_tcp_listeners(addrs: &[IpAddr], port: u16) -> anyhow::Result<Vec<std::net::TcpListener>> {
let mut listeners = Vec::new();
for ip in addrs {
let addr = SocketAddr::new(*ip, port);
match bind_tcp_listener(addr) {
Ok(listener) => listeners.push(listener),
Err(err) => {
tracing::warn!("Failed to bind {}: {}", addr, err);
}
}
}
if listeners.is_empty() {
anyhow::bail!("Failed to bind any addresses on port {}", port);
}
Ok(listeners)
}
/// Parse video format and resolution from config (avoids code duplication)
fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
let format = config