diff --git a/src/main.rs b/src/main.rs index 9488eb86..7e37434b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -27,7 +27,7 @@ use one_kvm::otg::OtgService; use one_kvm::platform::PlatformCapabilities; use one_kvm::rtsp::RtspService; use one_kvm::rustdesk::RustDeskService; -use one_kvm::state::AppState; +use one_kvm::state::{AppState, ShutdownAction}; use one_kvm::update::UpdateService; use one_kvm::utils::bind_tcp_listener; use one_kvm::video::codec_constraints::{ @@ -181,7 +181,7 @@ async fn main() -> anyhow::Result<()> { let user_store = UserStore::new(db.clone_pool()); - let (shutdown_tx, _) = broadcast::channel::<()>(1); + let (shutdown_tx, _) = broadcast::channel::(1); let events = Arc::new(EventBus::new()); tracing::info!("Event bus initialized"); @@ -622,15 +622,32 @@ async fn main() -> anyhow::Result<()> { let listeners = bind_tcp_listeners(&bind_ips, bind_port)?; - let shutdown_signal = async move { - tokio::signal::ctrl_c() - .await - .expect("Failed to install CTRL+C handler"); - tracing::info!("Shutdown signal received"); - let _ = shutdown_tx.send(()); + let shutdown_signal = { + let mut shutdown_rx = shutdown_tx.subscribe(); + async move { + tokio::select! { + result = tokio::signal::ctrl_c() => { + result.expect("Failed to install CTRL+C handler"); + tracing::info!("Shutdown signal received"); + ShutdownAction::Exit + } + request = shutdown_rx.recv() => { + match request { + Ok(action) => { + tracing::info!("Shutdown request received: {:?}", action); + action + } + Err(e) => { + tracing::warn!("Shutdown request channel closed: {}", e); + ShutdownAction::Exit + } + } + } + } + } }; - if config.web.https_enabled { + let shutdown_action = if config.web.https_enabled { let tls_config = if let (Some(cert_path), Some(key_path)) = (&config.web.ssl_cert_path, &config.web.ssl_key_path) { @@ -663,7 +680,7 @@ async fn main() -> anyhow::Result<()> { servers.push(server); } - run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTPS").await; + run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTPS").await } else { let servers = FuturesUnordered::new(); for listener in listeners { @@ -675,10 +692,13 @@ async fn main() -> anyhow::Result<()> { servers.push(async move { server.await }); } - run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTP").await; - } + run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTP").await + }; tracing::info!("Server shutdown complete"); + if let ShutdownAction::Restart { exe_path } = shutdown_action { + restart_current_process(exe_path)?; + } Ok(()) } @@ -741,23 +761,46 @@ async fn open_database_pool(data_dir: &Path) -> anyhow::Result { async fn run_servers_until_shutdown( mut servers: FuturesUnordered, - shutdown_signal: impl Future, + shutdown_signal: impl Future, state: &Arc, protocol: &'static str, -) where +) -> ShutdownAction +where F: Future> + Send, E: std::fmt::Display, { - tokio::select! { - _ = shutdown_signal => { - cleanup(state).await; + let action = tokio::select! { + action = shutdown_signal => { + action } result = servers.next() => { if let Some(Err(e)) = result { tracing::error!("{} server error: {}", protocol, e); } - cleanup(state).await; + ShutdownAction::Exit } + }; + cleanup(state).await; + action +} + +fn restart_current_process(exe_path: Option) -> anyhow::Result<()> { + let exe = exe_path.unwrap_or(std::env::current_exe()?); + let args: Vec = std::env::args().skip(1).collect(); + + tracing::info!("Restarting: {:?} {:?}", exe, args); + + #[cfg(unix)] + { + use std::os::unix::process::CommandExt; + let err = std::process::Command::new(&exe).args(&args).exec(); + Err(anyhow::anyhow!("Failed to restart: {}", err)) + } + + #[cfg(not(unix))] + { + std::process::Command::new(&exe).args(&args).spawn()?; + std::process::exit(0); } } diff --git a/src/runtime/android.rs b/src/runtime/android.rs index d0bc60ba..1f1c9202 100644 --- a/src/runtime/android.rs +++ b/src/runtime/android.rs @@ -27,7 +27,7 @@ use crate::msd::MsdController; use crate::otg::OtgService; use crate::rtsp::RtspService; use crate::rustdesk::RustDeskService; -use crate::state::AppState; +use crate::state::{AppState, ShutdownAction}; use crate::stream_encoder::encoder_type_to_backend; use crate::update::UpdateService; use crate::utils::bind_tcp_listener; @@ -132,7 +132,7 @@ async fn run_async( ) -> Result<(), String> { let (db, config_store, app_config) = load_runtime_config(&PathBuf::from(&config.data_dir), &config).await?; - let (shutdown_tx, _) = broadcast::channel::<()>(1); + let (shutdown_tx, _) = broadcast::channel::(1); let state = build_app_state( PathBuf::from(&config.data_dir), db, @@ -156,10 +156,26 @@ async fn run_async( .map_err(|err| format!("failed to create tokio listener: {err}"))?; let server = axum::serve(listener, app); - let shutdown_signal = async move { - let _ = stop_rx.await; - tracing::info!("Android stop request received"); - let _ = shutdown_tx.send(()); + let shutdown_signal = { + let mut shutdown_rx = shutdown_tx.subscribe(); + async move { + tokio::select! { + _ = stop_rx => { + tracing::info!("Android stop request received"); + let _ = shutdown_tx.send(ShutdownAction::Exit); + } + request = shutdown_rx.recv() => { + match request { + Ok(action) => { + tracing::info!("Android shutdown request received: {:?}", action); + } + Err(err) => { + tracing::warn!("Android shutdown request channel closed: {}", err); + } + } + } + } + } }; tokio::select! { @@ -261,7 +277,7 @@ async fn build_app_state( db: DatabasePool, config_store: ConfigStore, config: AppConfig, - shutdown_tx: broadcast::Sender<()>, + shutdown_tx: broadcast::Sender, ) -> Result, String> { let session_store = SessionStore::new(config.auth.session_timeout_secs as i64); let user_store = UserStore::new(db.clone_pool()); diff --git a/src/state.rs b/src/state.rs index 0237673d..6c841d30 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,4 +1,4 @@ -use std::{collections::VecDeque, sync::Arc}; +use std::{collections::VecDeque, path::PathBuf, sync::Arc}; use tokio::sync::{broadcast, watch, Mutex, RwLock}; use crate::atx::AtxController; @@ -33,6 +33,12 @@ pub struct ConfigApplyLocks { pub rtsp: Arc>, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ShutdownAction { + Exit, + Restart { exe_path: Option }, +} + impl ConfigApplyLocks { fn new() -> Self { Self { @@ -68,7 +74,7 @@ pub struct AppState { pub events: Arc, device_info_tx: watch::Sender>, pub update: Arc, - pub shutdown_tx: broadcast::Sender<()>, + pub shutdown_tx: broadcast::Sender, pub revoked_sessions: Arc>>, pub config_apply_locks: ConfigApplyLocks, data_dir: std::path::PathBuf, @@ -93,7 +99,7 @@ impl AppState { extensions: Arc, events: Arc, update: Arc, - shutdown_tx: broadcast::Sender<()>, + shutdown_tx: broadcast::Sender, data_dir: std::path::PathBuf, ) -> Arc { let (device_info_tx, _device_info_rx) = watch::channel(None); @@ -129,10 +135,6 @@ impl AppState { &self.data_dir } - pub fn shutdown_signal(&self) -> broadcast::Receiver<()> { - self.shutdown_tx.subscribe() - } - pub fn subscribe_device_info(&self) -> watch::Receiver> { self.device_info_tx.subscribe() } diff --git a/src/update/mod.rs b/src/update/mod.rs index a56c6471..cec38e0c 100644 --- a/src/update/mod.rs +++ b/src/update/mod.rs @@ -9,6 +9,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::{broadcast, RwLock, Semaphore}; use crate::error::{AppError, Result}; +use crate::state::ShutdownAction; const DEFAULT_UPDATE_BASE_URL: &str = "https://update.one-kvm.cn"; @@ -128,7 +129,7 @@ impl UpdateService { success: true, phase: UpdatePhase::Idle, progress: 0, - current_version: env!("CARGO_PKG_VERSION").to_string(), + current_version: current_version_for_update(), target_version: None, message: None, last_error: None, @@ -144,7 +145,7 @@ impl UpdateService { pub async fn overview(&self, channel: UpdateChannel) -> Result { let (channels, releases) = self.fetch_manifests().await?; - let current_version = parse_version(env!("CARGO_PKG_VERSION"))?; + let current_version = parse_version(¤t_version_for_update())?; let latest_version = parse_version(&channel_head_version(&channels, channel))?; let current_parts = parse_version_parts(¤t_version)?; let latest_parts = parse_version_parts(&latest_version)?; @@ -197,7 +198,7 @@ impl UpdateService { pub fn start_upgrade( self: &Arc, req: UpgradeRequest, - shutdown_tx: broadcast::Sender<()>, + shutdown_tx: broadcast::Sender, ) -> Result<()> { if req.channel.is_none() == req.target_version.is_none() { return Err(AppError::BadRequest( @@ -233,7 +234,7 @@ impl UpdateService { async fn execute_upgrade( &self, req: UpgradeRequest, - shutdown_tx: broadcast::Sender<()>, + shutdown_tx: broadcast::Sender, ) -> Result<()> { self.set_status( UpdatePhase::Checking, @@ -246,7 +247,7 @@ impl UpdateService { let (channels, releases) = self.fetch_manifests().await?; - let current_version = parse_version(env!("CARGO_PKG_VERSION"))?; + let current_version = parse_version(¤t_version_for_update())?; let target_version = if let Some(channel) = req.channel { parse_version(&channel_head_version(&channels, channel))? } else { @@ -305,7 +306,7 @@ impl UpdateService { ) .await; - self.install_binary(&staging_path).await?; + let restart_exe = self.install_binary(&staging_path).await?; self.set_status( UpdatePhase::Restarting, @@ -316,10 +317,11 @@ impl UpdateService { ) .await; - let _ = shutdown_tx.send(()); - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - - restart_current_process()?; + shutdown_tx + .send(ShutdownAction::Restart { + exe_path: Some(restart_exe), + }) + .map_err(|e| AppError::Internal(format!("Failed to request restart: {}", e)))?; Ok(()) } @@ -399,7 +401,7 @@ impl UpdateService { Ok(()) } - async fn install_binary(&self, staging_path: &Path) -> Result<()> { + async fn install_binary(&self, staging_path: &Path) -> Result { let current_exe = std::env::current_exe() .map_err(|e| AppError::Internal(format!("Failed to get current exe path: {}", e)))?; let exe_dir = current_exe.parent().ok_or_else(|| { @@ -426,7 +428,7 @@ impl UpdateService { .await .map_err(|e| AppError::Internal(format!("Failed to replace executable {}", e)))?; - Ok(()) + Ok(current_exe) } async fn fetch_manifests(&self) -> Result<(ChannelsManifest, ReleasesManifest)> { @@ -481,10 +483,17 @@ impl UpdateService { status.message = message; status.last_error = last_error; status.success = status.phase != UpdatePhase::Failed; - status.current_version = env!("CARGO_PKG_VERSION").to_string(); + status.current_version = current_version_for_update(); } } +fn current_version_for_update() -> String { + std::env::var("ONE_KVM_UPDATE_CURRENT_VERSION") + .ok() + .filter(|value| !value.trim().is_empty()) + .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string()) +} + fn parse_version(input: &str) -> Result { parse_version_parts(input)?; Ok(input.to_string()) @@ -569,25 +578,3 @@ fn current_target_triple() -> Result { }; Ok(triple.to_string()) } - -fn restart_current_process() -> Result<()> { - let exe = std::env::current_exe() - .map_err(|e| AppError::Internal(format!("Failed to get current exe: {}", e)))?; - let args: Vec = std::env::args().skip(1).collect(); - - #[cfg(unix)] - { - use std::os::unix::process::CommandExt; - let err = std::process::Command::new(&exe).args(&args).exec(); - Err(AppError::Internal(format!("Failed to restart: {}", err))) - } - - #[cfg(not(unix))] - { - std::process::Command::new(&exe) - .args(&args) - .spawn() - .map_err(|e| AppError::Internal(format!("Failed to spawn restart process: {}", e)))?; - std::process::exit(0); - } -} diff --git a/src/web/handlers/account.rs b/src/web/handlers/account.rs index 5ac6c668..d5975f48 100644 --- a/src/web/handlers/account.rs +++ b/src/web/handlers/account.rs @@ -1,4 +1,5 @@ use super::*; +use crate::state::ShutdownAction; /// Change password request #[derive(Deserialize)] @@ -108,41 +109,9 @@ pub async fn change_username( pub async fn system_restart(State(state): State>) -> Json { info!("System restart requested via API"); - // Send shutdown signal - let _ = state.shutdown_tx.send(()); - - // Spawn restart task in background - tokio::spawn(async { - // Wait for resources to be released (OTG, video, etc.) - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - - // Get current executable and args - let exe = match std::env::current_exe() { - Ok(e) => e, - Err(e) => { - tracing::error!("Failed to get current exe: {}", e); - std::process::exit(1); - } - }; - let args: Vec = std::env::args().skip(1).collect(); - - info!("Restarting: {:?} {:?}", exe, args); - - // Use exec to replace current process (Unix) - #[cfg(unix)] - { - use std::os::unix::process::CommandExt; - let err = std::process::Command::new(&exe).args(&args).exec(); - tracing::error!("Failed to restart: {}", err); - std::process::exit(1); - } - - #[cfg(not(unix))] - { - let _ = std::process::Command::new(&exe).args(&args).spawn(); - std::process::exit(0); - } - }); + let _ = state + .shutdown_tx + .send(ShutdownAction::Restart { exe_path: None }); Json(LoginResponse { success: true, diff --git a/web/src/views/SettingsView.vue b/web/src/views/SettingsView.vue index 62c36620..2401a623 100644 --- a/web/src/views/SettingsView.vue +++ b/web/src/views/SettingsView.vue @@ -1,8 +1,9 @@