diff --git a/src/atx/mod.rs b/src/atx/mod.rs index a3c3671c..0a6b1be4 100644 --- a/src/atx/mod.rs +++ b/src/atx/mod.rs @@ -93,11 +93,7 @@ mod tests { #[test] fn test_discover_devices() { - let devices = discover_devices(); - // Just verify the function runs without error - assert!(devices.gpio_chips.len() >= 0); - assert!(devices.usb_relays.len() >= 0); - assert!(devices.serial_ports.len() >= 0); + let _devices = discover_devices(); } #[test] diff --git a/src/audio/controller.rs b/src/audio/controller.rs index 9da09108..5861aec1 100644 --- a/src/audio/controller.rs +++ b/src/audio/controller.rs @@ -13,7 +13,7 @@ use super::encoder::{OpusConfig, OpusFrame}; use super::monitor::{AudioHealthMonitor, AudioHealthStatus}; use super::streamer::{AudioStreamer, AudioStreamerConfig}; use crate::error::{AppError, Result}; -use crate::events::{EventBus, SystemEvent}; +use crate::events::EventBus; /// Audio quality presets #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] @@ -139,15 +139,15 @@ impl AudioController { } } - /// Set event bus for publishing audio events + /// Set event bus for internal state notifications. pub async fn set_event_bus(&self, event_bus: Arc) { *self.event_bus.write().await = Some(event_bus); } - /// Publish an event to the event bus - async fn publish_event(&self, event: SystemEvent) { + /// Mark the device-info snapshot as stale. + async fn mark_device_info_dirty(&self) { if let Some(ref bus) = *self.event_bus.read().await { - bus.publish(event); + bus.mark_device_info_dirty(); } } @@ -276,11 +276,7 @@ impl AudioController { .report_error(Some(&config.device), &error_msg, "start_failed") .await; - self.publish_event(SystemEvent::AudioStateChanged { - streaming: false, - device: None, - }) - .await; + self.mark_device_info_dirty().await; return Err(AppError::AudioError(error_msg)); } @@ -292,12 +288,7 @@ impl AudioController { self.monitor.report_recovered(Some(&config.device)).await; } - // Publish event - self.publish_event(SystemEvent::AudioStateChanged { - streaming: true, - device: Some(config.device), - }) - .await; + self.mark_device_info_dirty().await; info!("Audio streaming started"); Ok(()) @@ -309,12 +300,7 @@ impl AudioController { streamer.stop().await?; } - // Publish event - self.publish_event(SystemEvent::AudioStateChanged { - streaming: false, - device: None, - }) - .await; + self.mark_device_info_dirty().await; info!("Audio streaming stopped"); Ok(()) diff --git a/src/events/mod.rs b/src/events/mod.rs index b3de4e18..f317fd6b 100644 --- a/src/events/mod.rs +++ b/src/events/mod.rs @@ -15,6 +15,39 @@ use tokio::sync::broadcast; /// Event channel capacity (ring buffer size) const EVENT_CHANNEL_CAPACITY: usize = 256; +const EXACT_TOPICS: &[&str] = &[ + "stream.mode_switching", + "stream.state_changed", + "stream.config_changing", + "stream.config_applied", + "stream.device_lost", + "stream.reconnecting", + "stream.recovered", + "stream.webrtc_ready", + "stream.stats_update", + "stream.mode_changed", + "stream.mode_ready", + "webrtc.ice_candidate", + "webrtc.ice_complete", + "msd.upload_progress", + "msd.download_progress", + "system.device_info", + "error", +]; + +const PREFIX_TOPICS: &[&str] = &["stream.*", "webrtc.*", "msd.*", "system.*"]; + +fn make_sender() -> broadcast::Sender { + let (tx, _rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY); + tx +} + +fn topic_prefix(event_name: &str) -> Option { + event_name + .split_once('.') + .map(|(prefix, _)| format!("{}.*", prefix)) +} + /// Global event bus for broadcasting system events /// /// The event bus uses tokio's broadcast channel to distribute events @@ -43,13 +76,31 @@ const EVENT_CHANNEL_CAPACITY: usize = 256; /// ``` pub struct EventBus { tx: broadcast::Sender, + exact_topics: std::collections::HashMap<&'static str, broadcast::Sender>, + prefix_topics: std::collections::HashMap<&'static str, broadcast::Sender>, + device_info_dirty_tx: broadcast::Sender<()>, } impl EventBus { /// Create a new event bus pub fn new() -> Self { - let (tx, _rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY); - Self { tx } + let tx = make_sender(); + let exact_topics = EXACT_TOPICS + .iter() + .map(|topic| (*topic, make_sender())) + .collect(); + let prefix_topics = PREFIX_TOPICS + .iter() + .map(|topic| (*topic, make_sender())) + .collect(); + let (device_info_dirty_tx, _dirty_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY); + + Self { + tx, + exact_topics, + prefix_topics, + device_info_dirty_tx, + } } /// Publish an event to all subscribers @@ -57,6 +108,18 @@ impl EventBus { /// If there are no active subscribers, the event is silently dropped. /// This is by design - events are fire-and-forget notifications. pub fn publish(&self, event: SystemEvent) { + let event_name = event.event_name(); + + if let Some(tx) = self.exact_topics.get(event_name) { + let _ = tx.send(event.clone()); + } + + if let Some(prefix) = topic_prefix(event_name) { + if let Some(tx) = self.prefix_topics.get(prefix.as_str()) { + let _ = tx.send(event.clone()); + } + } + // If no subscribers, send returns Err which is normal let _ = self.tx.send(event); } @@ -70,6 +133,35 @@ impl EventBus { self.tx.subscribe() } + /// Subscribe to a specific topic. + /// + /// Supports exact event names, namespace wildcards like `stream.*`, and + /// `*` for the full event stream. + pub fn subscribe_topic(&self, topic: &str) -> Option> { + if topic == "*" { + return Some(self.tx.subscribe()); + } + + if topic.ends_with(".*") { + return self.prefix_topics.get(topic).map(|tx| tx.subscribe()); + } + + self.exact_topics.get(topic).map(|tx| tx.subscribe()) + } + + /// Mark the device-info snapshot as stale. + /// + /// This is an internal trigger used to refresh the latest `system.device_info` + /// snapshot without exposing another public WebSocket event. + pub fn mark_device_info_dirty(&self) { + let _ = self.device_info_dirty_tx.send(()); + } + + /// Subscribe to internal device-info refresh triggers. + pub fn subscribe_device_info_dirty(&self) -> broadcast::Receiver<()> { + self.device_info_dirty_tx.subscribe() + } + /// Get the current number of active subscribers /// /// Useful for monitoring and debugging. @@ -122,6 +214,40 @@ mod tests { assert!(matches!(event2, SystemEvent::StreamStateChanged { .. })); } + #[tokio::test] + async fn test_subscribe_topic_exact() { + let bus = EventBus::new(); + let mut rx = bus.subscribe_topic("stream.state_changed").unwrap(); + + bus.publish(SystemEvent::StreamStateChanged { + state: "ready".to_string(), + device: None, + }); + + let event = rx.recv().await.unwrap(); + assert!(matches!(event, SystemEvent::StreamStateChanged { .. })); + } + + #[tokio::test] + async fn test_subscribe_topic_prefix() { + let bus = EventBus::new(); + let mut rx = bus.subscribe_topic("stream.*").unwrap(); + + bus.publish(SystemEvent::StreamStateChanged { + state: "ready".to_string(), + device: None, + }); + + let event = rx.recv().await.unwrap(); + assert!(matches!(event, SystemEvent::StreamStateChanged { .. })); + } + + #[test] + fn test_subscribe_topic_unknown() { + let bus = EventBus::new(); + assert!(bus.subscribe_topic("unknown.topic").is_none()); + } + #[test] fn test_no_subscribers() { let bus = EventBus::new(); diff --git a/src/events/types.rs b/src/events/types.rs index 8f7930ef..8a0aa4fa 100644 --- a/src/events/types.rs +++ b/src/events/types.rs @@ -5,9 +5,6 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use crate::atx::PowerStatus; -use crate::msd::MsdMode; - // ============================================================================ // Device Info Structures (for system.device_info event) // ============================================================================ @@ -278,36 +275,9 @@ pub enum SystemEvent { mode: String, }, - // ============================================================================ - // HID Events - // ============================================================================ - /// HID backend state changed - #[serde(rename = "hid.state_changed")] - HidStateChanged { - /// Backend type: "otg", "ch9329", "none" - backend: String, - /// Whether backend is initialized and ready - initialized: bool, - /// Whether backend is currently online - online: bool, - /// Error message if any, None if OK - error: Option, - /// Error code for programmatic handling: "epipe", "eagain", "port_not_found", etc. - error_code: Option, - }, - // ============================================================================ // MSD (Mass Storage Device) Events // ============================================================================ - /// MSD state changed - #[serde(rename = "msd.state_changed")] - MsdStateChanged { - /// Operating mode - mode: MsdMode, - /// Whether storage is connected to target - connected: bool, - }, - /// File upload progress (for large file uploads) #[serde(rename = "msd.upload_progress")] MsdUploadProgress { @@ -342,28 +312,6 @@ pub enum SystemEvent { status: String, }, - // ============================================================================ - // ATX (Power Control) Events - // ============================================================================ - /// ATX power state changed - #[serde(rename = "atx.state_changed")] - AtxStateChanged { - /// Power status - power_status: PowerStatus, - }, - - // ============================================================================ - // Audio Events - // ============================================================================ - /// Audio state changed (streaming started/stopped) - #[serde(rename = "audio.state_changed")] - AudioStateChanged { - /// Whether audio is currently streaming - streaming: bool, - /// Current device (None if stopped) - device: Option, - }, - /// Complete device information (sent on WebSocket connect and state changes) #[serde(rename = "system.device_info")] DeviceInfo { @@ -404,12 +352,8 @@ impl SystemEvent { Self::StreamModeReady { .. } => "stream.mode_ready", Self::WebRTCIceCandidate { .. } => "webrtc.ice_candidate", Self::WebRTCIceComplete { .. } => "webrtc.ice_complete", - Self::HidStateChanged { .. } => "hid.state_changed", - Self::MsdStateChanged { .. } => "msd.state_changed", Self::MsdUploadProgress { .. } => "msd.upload_progress", Self::MsdDownloadProgress { .. } => "msd.download_progress", - Self::AtxStateChanged { .. } => "atx.state_changed", - Self::AudioStateChanged { .. } => "audio.state_changed", Self::DeviceInfo { .. } => "system.device_info", Self::Error { .. } => "error", } @@ -448,12 +392,6 @@ mod tests { device: Some("/dev/video0".to_string()), }; assert_eq!(event.event_name(), "stream.state_changed"); - - let event = SystemEvent::MsdStateChanged { - mode: MsdMode::Image, - connected: true, - }; - assert_eq!(event.event_name(), "msd.state_changed"); } #[test] diff --git a/src/hid/ch9329.rs b/src/hid/ch9329.rs index e7ec3a93..b889d148 100644 --- a/src/hid/ch9329.rs +++ b/src/hid/ch9329.rs @@ -567,8 +567,9 @@ impl Ch9329Backend { data: &[u8], ) -> Result<()> { let packet = Self::build_packet(address, cmd, data); - port.write_all(&packet) - .map_err(|e| Self::backend_error(format!("Failed to write to CH9329: {}", e), "write_failed"))?; + port.write_all(&packet).map_err(|e| { + Self::backend_error(format!("Failed to write to CH9329: {}", e), "write_failed") + })?; trace!("CH9329 TX [cmd=0x{:02X}]: {:02X?}", cmd, packet); Ok(()) } @@ -599,7 +600,11 @@ impl Ch9329Backend { } fn expected_response_cmd(cmd: u8, is_error: bool) -> u8 { - cmd | if is_error { RESPONSE_ERROR_MASK } else { RESPONSE_SUCCESS_MASK } + cmd | if is_error { + RESPONSE_ERROR_MASK + } else { + RESPONSE_SUCCESS_MASK + } } fn xfer_packet( @@ -700,9 +705,9 @@ impl Ch9329Backend { fn enqueue_command(&self, command: WorkerCommand) -> Result<()> { let guard = self.worker_tx.lock(); - let sender = guard.as_ref().ok_or_else(|| { - Self::backend_error("CH9329 worker is not running", "worker_stopped") - })?; + let sender = guard + .as_ref() + .ok_or_else(|| Self::backend_error("CH9329 worker is not running", "worker_stopped"))?; sender .send(command) .map_err(|_| Self::backend_error("CH9329 worker stopped", "worker_stopped")) @@ -765,9 +770,7 @@ impl Ch9329Backend { } Err(err) => { if let AppError::HidError { - reason, - error_code, - .. + reason, error_code, .. } = err { runtime.set_error(reason, error_code); @@ -894,9 +897,7 @@ impl Ch9329Backend { } Err(err) => { if let AppError::HidError { - reason, - error_code, - .. + reason, error_code, .. } = &err { runtime.set_error(reason.clone(), error_code.clone()); @@ -912,9 +913,7 @@ impl Ch9329Backend { Ok(WorkerCommand::Packet { cmd, data }) => { if let Err(err) = Self::xfer_packet(port.as_mut(), address, cmd, &data) { if let AppError::HidError { - reason, - error_code, - .. + reason, error_code, .. } = err { runtime.set_error(reason, error_code); @@ -949,9 +948,7 @@ impl Ch9329Backend { for (cmd, data) in reset_sequence { if let Err(err) = Self::xfer_packet(port.as_mut(), address, cmd, &data) { if let AppError::HidError { - reason, - error_code, - .. + reason, error_code, .. } = err { runtime.set_error(reason, error_code); @@ -988,9 +985,7 @@ impl Ch9329Backend { } Err(err) => { if let AppError::HidError { - reason, - error_code, - .. + reason, error_code, .. } = err { runtime.set_error(reason, error_code); @@ -1050,14 +1045,7 @@ impl HidBackend for Ch9329Backend { .name("ch9329-worker".to_string()) .spawn(move || { Self::worker_loop( - port_path, - baud_rate, - address, - rx, - chip_info, - led_status, - runtime, - init_tx, + port_path, baud_rate, address, rx, chip_info, led_status, runtime, init_tx, ); }) .map_err(|e| AppError::Internal(format!("Failed to spawn CH9329 worker: {}", e)))?; @@ -1084,7 +1072,10 @@ impl HidBackend for Ch9329Backend { Ok(Err(err)) => { let _ = handle.join(); self.record_error( - format!("CH9329 not responding on {} @ {} baud: {}", self.port_path, self.baud_rate, err), + format!( + "CH9329 not responding on {} @ {} baud: {}", + self.port_path, self.baud_rate, err + ), "init_failed", ); Err(AppError::Internal(format!( @@ -1398,15 +1389,14 @@ mod tests { #[test] fn test_packet_building() { - let backend = Ch9329Backend::new("/dev/null").unwrap(); - // Test GET_INFO packet (no data) - let packet = backend.build_packet(cmd::GET_INFO, &[]); + let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::GET_INFO, &[]); assert_eq!(packet, vec![0x57, 0xAB, 0x00, 0x01, 0x00, 0x03]); // Test keyboard packet (8 bytes data) let data = [0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; // 'A' key - let packet = backend.build_packet(cmd::SEND_KB_GENERAL_DATA, &data); + let packet = + Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_KB_GENERAL_DATA, &data); assert_eq!(packet[0], 0x57); // Header assert_eq!(packet[1], 0xAB); // Header @@ -1415,17 +1405,17 @@ mod tests { assert_eq!(packet[4], 8); // Length (8 data bytes) assert_eq!(&packet[5..13], &data); // Data // Checksum = 0x57 + 0xAB + 0x00 + 0x02 + 0x08 + 0x00 + 0x00 + 0x04 + ... = 0x10 - let expected_checksum: u8 = packet[..13].iter().fold(0u8, |acc, &x| acc.wrapping_add(x)); + let expected_checksum: u8 = packet[..13] + .iter() + .fold(0u8, |acc: u8, &x| acc.wrapping_add(x)); assert_eq!(packet[13], expected_checksum); } #[test] fn test_relative_mouse_packet() { - let backend = Ch9329Backend::new("/dev/null").unwrap(); - // Test relative mouse: move right 50 pixels let data = [0x01, 0x00, 50u8, 0x00, 0x00]; - let packet = backend.build_packet(cmd::SEND_MS_REL_DATA, &data); + let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_MS_REL_DATA, &data); assert_eq!(packet[0], 0x57); assert_eq!(packet[1], 0xAB); diff --git a/src/hid/mod.rs b/src/hid/mod.rs index 7b28a882..f6e93449 100644 --- a/src/hid/mod.rs +++ b/src/hid/mod.rs @@ -113,11 +113,11 @@ impl HidRuntimeState { use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use tracing::{info, warn}; use tokio::sync::RwLock; +use tracing::{info, warn}; use crate::error::{AppError, Result}; -use crate::events::{EventBus, SystemEvent}; +use crate::events::EventBus; use crate::otg::OtgService; use std::time::Duration; use tokio::sync::mpsc; @@ -360,18 +360,6 @@ impl HidController { self.runtime_state.read().await.clone() } - /// Get current state as SystemEvent - pub async fn current_state_event(&self) -> crate::events::SystemEvent { - let state = self.snapshot().await; - SystemEvent::HidStateChanged { - backend: state.backend, - initialized: state.initialized, - online: state.online, - error: state.error, - error_code: state.error_code, - } - } - /// Reload the HID backend with new type pub async fn reload(&self, new_backend_type: HidBackendType) -> Result<()> { info!("Reloading HID backend: {:?}", new_backend_type); @@ -707,12 +695,6 @@ async fn apply_runtime_state( } if let Some(events) = events.read().await.as_ref() { - events.publish(SystemEvent::HidStateChanged { - backend: next.backend, - initialized: next.initialized, - online: next.online, - error: next.error, - error_code: next.error_code, - }); + events.mark_device_info_dirty(); } } diff --git a/src/main.rs b/src/main.rs index 3cd9ef0e..2ce2dc01 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ use axum_server::tls_rustls::RustlsConfig; use clap::{Parser, ValueEnum}; use futures::{stream::FuturesUnordered, StreamExt}; use rustls::crypto::{ring, CryptoProvider}; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use one_kvm::atx::AtxController; @@ -646,6 +646,8 @@ async fn main() -> anyhow::Result<()> { tracing::info!("Extension health check task started"); } + state.publish_device_info().await; + // Start device info broadcast task // This monitors state change events and broadcasts DeviceInfo to all clients spawn_device_info_broadcaster(state.clone(), events); @@ -854,12 +856,86 @@ fn generate_self_signed_cert() -> anyhow::Result, events: Arc) { - use one_kvm::events::SystemEvent; use std::time::{Duration, Instant}; - let mut rx = events.subscribe(); + enum DeviceInfoTrigger { + Event, + Lagged { topic: &'static str, count: u64 }, + } + + const DEVICE_INFO_TOPICS: &[&str] = &[ + "stream.state_changed", + "stream.config_applied", + "stream.mode_ready", + ]; const DEBOUNCE_MS: u64 = 100; + let (trigger_tx, mut trigger_rx) = mpsc::unbounded_channel(); + + for topic in DEVICE_INFO_TOPICS { + let Some(mut rx) = events.subscribe_topic(topic) else { + tracing::warn!( + "DeviceInfo broadcaster missing topic subscription: {}", + topic + ); + continue; + }; + + let trigger_tx = trigger_tx.clone(); + let topic_name = *topic; + tokio::spawn(async move { + loop { + match rx.recv().await { + Ok(_) => { + if trigger_tx.send(DeviceInfoTrigger::Event).is_err() { + break; + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => { + if trigger_tx + .send(DeviceInfoTrigger::Lagged { + topic: topic_name, + count, + }) + .is_err() + { + break; + } + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } + } + }); + } + + { + let mut dirty_rx = events.subscribe_device_info_dirty(); + let trigger_tx = trigger_tx.clone(); + tokio::spawn(async move { + loop { + match dirty_rx.recv().await { + Ok(()) => { + if trigger_tx.send(DeviceInfoTrigger::Event).is_err() { + break; + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => { + if trigger_tx + .send(DeviceInfoTrigger::Lagged { + topic: "device_info_dirty", + count, + }) + .is_err() + { + break; + } + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } + } + }); + } + tokio::spawn(async move { let mut last_broadcast = Instant::now() - Duration::from_millis(DEBOUNCE_MS); let mut pending_broadcast = false; @@ -869,32 +945,24 @@ fn spawn_device_info_broadcaster(state: Arc, events: Arc) { let recv_result = if pending_broadcast { let remaining = DEBOUNCE_MS.saturating_sub(last_broadcast.elapsed().as_millis() as u64); - tokio::time::timeout(Duration::from_millis(remaining), rx.recv()).await + tokio::time::timeout(Duration::from_millis(remaining), trigger_rx.recv()).await } else { - Ok(rx.recv().await) + Ok(trigger_rx.recv().await) }; match recv_result { - Ok(Ok(event)) => { - let should_broadcast = matches!( - event, - SystemEvent::StreamStateChanged { .. } - | SystemEvent::StreamConfigApplied { .. } - | SystemEvent::StreamModeReady { .. } - | SystemEvent::HidStateChanged { .. } - | SystemEvent::MsdStateChanged { .. } - | SystemEvent::AtxStateChanged { .. } - | SystemEvent::AudioStateChanged { .. } - ); - if should_broadcast { - pending_broadcast = true; - } - } - Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(n))) => { - tracing::warn!("DeviceInfo broadcaster lagged by {} events", n); + Ok(Some(DeviceInfoTrigger::Event)) => { pending_broadcast = true; } - Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => { + Ok(Some(DeviceInfoTrigger::Lagged { topic, count })) => { + tracing::warn!( + "DeviceInfo broadcaster lagged by {} events on topic {}", + count, + topic + ); + pending_broadcast = true; + } + Ok(None) => { tracing::info!("Event bus closed, stopping DeviceInfo broadcaster"); break; } diff --git a/src/msd/controller.rs b/src/msd/controller.rs index 4830c088..d2e039bd 100644 --- a/src/msd/controller.rs +++ b/src/msd/controller.rs @@ -115,15 +115,6 @@ impl MsdController { Ok(()) } - /// Get current state as SystemEvent - pub async fn current_state_event(&self) -> crate::events::SystemEvent { - let state = self.state.read().await; - crate::events::SystemEvent::MsdStateChanged { - mode: state.mode.clone(), - connected: state.connected, - } - } - /// Get current MSD state pub async fn state(&self) -> MsdState { self.state.read().await.clone() @@ -141,6 +132,12 @@ impl MsdController { } } + async fn mark_device_info_dirty(&self) { + if let Some(ref bus) = *self.events.read().await { + bus.mark_device_info_dirty(); + } + } + /// Check if MSD is available pub async fn is_available(&self) -> bool { self.state.read().await.available @@ -228,11 +225,7 @@ impl MsdController { self.monitor.report_recovered().await; } - self.publish_event(crate::events::SystemEvent::MsdStateChanged { - mode: MsdMode::Image, - connected: true, - }) - .await; + self.mark_device_info_dirty().await; Ok(()) } @@ -303,12 +296,7 @@ impl MsdController { self.monitor.report_recovered().await; } - // Publish event - self.publish_event(crate::events::SystemEvent::MsdStateChanged { - mode: MsdMode::Drive, - connected: true, - }) - .await; + self.mark_device_info_dirty().await; Ok(()) } @@ -340,11 +328,7 @@ impl MsdController { drop(state); drop(_op_guard); - self.publish_event(crate::events::SystemEvent::MsdStateChanged { - mode: MsdMode::None, - connected: false, - }) - .await; + self.mark_device_info_dirty().await; Ok(()) } diff --git a/src/state.rs b/src/state.rs index 814ec9cc..21469f9f 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,5 +1,5 @@ use std::{collections::VecDeque, sync::Arc}; -use tokio::sync::{broadcast, RwLock}; +use tokio::sync::{broadcast, watch, RwLock}; use crate::atx::AtxController; use crate::audio::AudioController; @@ -58,6 +58,8 @@ pub struct AppState { pub extensions: Arc, /// Event bus for real-time notifications pub events: Arc, + /// Latest device info snapshot for WebSocket clients + device_info_tx: watch::Sender>, /// Online update service pub update: Arc, /// Shutdown signal sender @@ -89,6 +91,8 @@ impl AppState { shutdown_tx: broadcast::Sender<()>, data_dir: std::path::PathBuf, ) -> Arc { + let (device_info_tx, _device_info_rx) = watch::channel(None); + Arc::new(Self { config, sessions, @@ -103,6 +107,7 @@ impl AppState { rtsp: Arc::new(RwLock::new(rtsp)), extensions, events, + device_info_tx, update, shutdown_tx, revoked_sessions: Arc::new(RwLock::new(VecDeque::new())), @@ -120,6 +125,11 @@ impl AppState { self.shutdown_tx.subscribe() } + /// Subscribe to the latest device info snapshot. + pub fn subscribe_device_info(&self) -> watch::Receiver> { + self.device_info_tx.subscribe() + } + /// Record revoked session IDs (bounded queue) pub async fn remember_revoked_sessions(&self, session_ids: Vec) { if session_ids.is_empty() { @@ -167,7 +177,7 @@ impl AppState { /// Publish DeviceInfo event to all connected WebSocket clients pub async fn publish_device_info(&self) { let device_info = self.get_device_info().await; - self.events.publish(device_info); + let _ = self.device_info_tx.send(Some(device_info)); } /// Collect video device information diff --git a/src/video/stream_manager.rs b/src/video/stream_manager.rs index b5beb25f..91231074 100644 --- a/src/video/stream_manager.rs +++ b/src/video/stream_manager.rs @@ -532,17 +532,30 @@ impl VideoStreamManager { device_path, format, resolution.width, resolution.height, fps, mode ); + if mode == StreamMode::WebRTC { + // Stop the shared pipeline before replacing the capture source so WebRTC + // sessions do not stay attached to a stale frame source. + self.webrtc_streamer + .update_video_config(resolution, format, fps) + .await; + info!("WebRTC streamer config updated (pipeline stopped, sessions closed)"); + } + // Apply to streamer (handles video capture) self.streamer .apply_video_config(device_path, format, resolution, fps) .await?; + if mode != StreamMode::WebRTC { + if let Err(e) = self.start().await { + error!("Failed to start streamer after config change: {}", e); + } else { + info!("Streamer started after config change"); + } + } + // Update WebRTC config if in WebRTC mode if mode == StreamMode::WebRTC { - self.webrtc_streamer - .update_video_config(resolution, format, fps) - .await; - let (device_path, actual_resolution, actual_format, actual_fps, jpeg_quality) = self.streamer.current_capture_config().await; if actual_format != format || actual_resolution != resolution || actual_fps != fps { diff --git a/src/web/handlers/config/apply.rs b/src/web/handlers/config/apply.rs index e4a16036..c2282b9a 100644 --- a/src/web/handlers/config/apply.rs +++ b/src/web/handlers/config/apply.rs @@ -6,7 +6,6 @@ use std::sync::Arc; use crate::config::*; use crate::error::{AppError, Result}; -use crate::events::SystemEvent; use crate::rtsp::RtspService; use crate::state::AppState; use crate::video::codec_constraints::{ @@ -45,73 +44,11 @@ pub async fn apply_video_config( let resolution = crate::video::format::Resolution::new(new_config.width, new_config.height); - // Step 1: 更新 WebRTC streamer 配置(停止现有 pipeline 和 sessions) state .stream_manager - .webrtc_streamer() - .update_video_config(resolution, format, new_config.fps) - .await; - tracing::info!("WebRTC streamer config updated"); - - // Step 2: 应用视频配置到 streamer(重新创建 capturer) - state - .stream_manager - .streamer() .apply_video_config(&device, format, resolution, new_config.fps) .await .map_err(|e| AppError::VideoError(format!("Failed to apply video config: {}", e)))?; - tracing::info!("Video config applied to streamer"); - - // Step 3: 重启 streamer(仅 MJPEG 模式) - if !state.stream_manager.is_webrtc_enabled().await { - if let Err(e) = state.stream_manager.start().await { - tracing::error!("Failed to start streamer after config change: {}", e); - } else { - tracing::info!("Streamer started after config change"); - } - } - - // 配置 WebRTC direct capture(所有模式统一配置) - let (device_path, _resolution, _format, _fps, jpeg_quality) = state - .stream_manager - .streamer() - .current_capture_config() - .await; - if let Some(device_path) = device_path { - state - .stream_manager - .webrtc_streamer() - .set_capture_device(device_path, jpeg_quality) - .await; - } else { - tracing::warn!("No capture device configured for WebRTC"); - } - - if state.stream_manager.is_webrtc_enabled().await { - use crate::video::encoder::VideoCodecType; - let codec = state - .stream_manager - .webrtc_streamer() - .current_video_codec() - .await; - let codec_str = match codec { - VideoCodecType::H264 => "h264", - VideoCodecType::H265 => "h265", - VideoCodecType::VP8 => "vp8", - VideoCodecType::VP9 => "vp9", - } - .to_string(); - let is_hardware = state - .stream_manager - .webrtc_streamer() - .is_hardware_encoding() - .await; - state.events.publish(SystemEvent::WebRTCReady { - transition_id: None, - codec: codec_str, - hardware: is_hardware, - }); - } tracing::info!("Video config applied successfully"); Ok(()) diff --git a/src/web/handlers/mod.rs b/src/web/handlers/mod.rs index 70229ea9..91448ab9 100644 --- a/src/web/handlers/mod.rs +++ b/src/web/handlers/mod.rs @@ -12,7 +12,6 @@ use tracing::{info, warn}; use crate::auth::{Session, SESSION_COOKIE}; use crate::config::{AppConfig, StreamMode}; use crate::error::{AppError, Result}; -use crate::events::SystemEvent; use crate::state::AppState; use crate::update::{UpdateChannel, UpdateOverviewResponse, UpdateStatusResponse, UpgradeRequest}; use crate::video::codec_constraints::codec_to_id; @@ -936,20 +935,8 @@ pub async fn update_config( let resolution = crate::video::format::Resolution::new(new_config.video.width, new_config.video.height); - // Step 1: Update WebRTC streamer config FIRST - // This stops the shared pipeline and closes existing sessions BEFORE capturer is recreated - // This ensures the pipeline won't be subscribed to a stale frame source - state - .stream_manager - .webrtc_streamer() - .update_video_config(resolution, format, new_config.video.fps) - .await; - tracing::info!("WebRTC streamer config updated (pipeline stopped, sessions closed)"); - - // Step 2: Apply video config to streamer (recreates capturer) if let Err(e) = state .stream_manager - .streamer() .apply_video_config(&device, format, resolution, new_config.video.fps) .await { @@ -962,59 +949,6 @@ pub async fn update_config( })); } tracing::info!("Video config applied successfully"); - - // Step 3: Start the streamer to begin capturing frames (MJPEG mode only) - if !state.stream_manager.is_webrtc_enabled().await { - // This is necessary because apply_video_config only creates the capturer but doesn't start it - if let Err(e) = state.stream_manager.start().await { - tracing::error!("Failed to start streamer after config change: {}", e); - // Don't fail the request - the stream might start later when client connects - } else { - tracing::info!("Streamer started after config change"); - } - } - - // Configure WebRTC direct capture (all modes) - let (device_path, _resolution, _format, _fps, jpeg_quality) = state - .stream_manager - .streamer() - .current_capture_config() - .await; - if let Some(device_path) = device_path { - state - .stream_manager - .webrtc_streamer() - .set_capture_device(device_path, jpeg_quality) - .await; - } else { - tracing::warn!("No capture device configured for WebRTC"); - } - - if state.stream_manager.is_webrtc_enabled().await { - use crate::video::encoder::VideoCodecType; - let codec = state - .stream_manager - .webrtc_streamer() - .current_video_codec() - .await; - let codec_str = match codec { - VideoCodecType::H264 => "h264", - VideoCodecType::H265 => "h265", - VideoCodecType::VP8 => "vp8", - VideoCodecType::VP9 => "vp9", - } - .to_string(); - let is_hardware = state - .stream_manager - .webrtc_streamer() - .is_hardware_encoding() - .await; - state.events.publish(SystemEvent::WebRTCReady { - transition_id: None, - codec: codec_str, - hardware: is_hardware, - }); - } } // Stream config processing (encoder backend, bitrate, etc.) diff --git a/src/web/ws.rs b/src/web/ws.rs index 38ede8c0..6cefcac8 100644 --- a/src/web/ws.rs +++ b/src/web/ws.rs @@ -16,12 +16,122 @@ use axum::{ use futures::{SinkExt, StreamExt}; use serde::Deserialize; use std::sync::Arc; -use tokio::sync::broadcast; +use tokio::{sync::mpsc, task::JoinHandle}; use tracing::{debug, info, warn}; use crate::events::SystemEvent; use crate::state::AppState; +enum BusMessage { + Event(SystemEvent), + Lagged { topic: String, count: u64 }, +} + +fn normalize_topics(topics: &[String]) -> Vec { + let mut normalized = topics.to_vec(); + normalized.sort(); + normalized.dedup(); + + if normalized.iter().any(|topic| topic == "*") { + return vec!["*".to_string()]; + } + + normalized + .into_iter() + .filter(|topic| { + if topic.ends_with(".*") { + return true; + } + + let Some((prefix, _)) = topic.split_once('.') else { + return true; + }; + + let wildcard = format!("{}.*", prefix); + !topics.iter().any(|candidate| candidate == &wildcard) + }) + .collect() +} + +fn is_device_info_topic(topic: &str) -> bool { + matches!(topic, "*" | "system.*" | "system.device_info") +} + +fn rebuild_event_tasks( + state: &Arc, + topics: &[String], + event_tx: &mpsc::UnboundedSender, + event_tasks: &mut Vec>, +) { + for task in event_tasks.drain(..) { + task.abort(); + } + + let topics = normalize_topics(topics); + let mut device_info_task_added = false; + for topic in topics { + if is_device_info_topic(&topic) && !device_info_task_added { + let mut rx = state.subscribe_device_info(); + let event_tx = event_tx.clone(); + event_tasks.push(tokio::spawn(async move { + if let Some(snapshot) = rx.borrow().clone() { + if event_tx.send(BusMessage::Event(snapshot)).is_err() { + return; + } + } + + loop { + if rx.changed().await.is_err() { + break; + } + + if let Some(snapshot) = rx.borrow().clone() { + if event_tx.send(BusMessage::Event(snapshot)).is_err() { + break; + } + } + } + })); + device_info_task_added = true; + } + + if is_device_info_topic(&topic) && topic != "*" { + continue; + } + + let Some(mut rx) = state.events.subscribe_topic(&topic) else { + warn!("Client subscribed to unknown topic: {}", topic); + continue; + }; + + let event_tx = event_tx.clone(); + let topic_name = topic.clone(); + event_tasks.push(tokio::spawn(async move { + loop { + match rx.recv().await { + Ok(event) => { + if event_tx.send(BusMessage::Event(event)).is_err() { + break; + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => { + if event_tx + .send(BusMessage::Lagged { + topic: topic_name.clone(), + count, + }) + .is_err() + { + break; + } + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } + } + })); + } +} + /// Client-to-server message #[derive(Debug, Deserialize)] #[serde(tag = "type", content = "payload")] @@ -50,16 +160,12 @@ pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State> /// Handle WebSocket connection async fn handle_socket(socket: WebSocket, state: Arc) { let (mut sender, mut receiver) = socket.split(); - - // Subscribe to event bus - let mut event_rx = state.events.subscribe(); + let (event_tx, mut event_rx) = mpsc::unbounded_channel(); + let mut event_tasks: Vec> = Vec::new(); // Track subscribed topics (default: none until client subscribes) let mut subscribed_topics: Vec = vec![]; - // Flag to send device info after first subscribe - let mut device_info_sent = false; - info!("WebSocket client connected"); // Heartbeat interval (30 seconds) @@ -73,18 +179,13 @@ async fn handle_socket(socket: WebSocket, state: Arc) { Some(Ok(Message::Text(text))) => { if let Err(e) = handle_client_message(&text, &mut subscribed_topics).await { warn!("Failed to handle client message: {}", e); - } - - // Send device info after first subscribe - if !device_info_sent && !subscribed_topics.is_empty() { - let device_info = state.get_device_info().await; - if let Ok(json) = serialize_event(&device_info) { - if sender.send(Message::Text(json.into())).await.is_err() { - warn!("Failed to send device info to client"); - break; - } - } - device_info_sent = true; + } else { + rebuild_event_tasks( + &state, + &subscribed_topics, + &event_tx, + &mut event_tasks, + ); } } Some(Ok(Message::Ping(_))) => { @@ -109,28 +210,29 @@ async fn handle_socket(socket: WebSocket, state: Arc) { // Receive event from event bus event = event_rx.recv() => { match event { - Ok(event) => { + Some(BusMessage::Event(event)) => { // Filter event based on subscribed topics - if should_send_event(&event, &subscribed_topics) { - if let Ok(json) = serialize_event(&event) { - if sender.send(Message::Text(json.into())).await.is_err() { - warn!("Failed to send event to client, disconnecting"); - break; - } + if let Ok(json) = serialize_event(&event) { + if sender.send(Message::Text(json.into())).await.is_err() { + warn!("Failed to send event to client, disconnecting"); + break; } } } - Err(broadcast::error::RecvError::Lagged(n)) => { - warn!("WebSocket client lagged by {} events", n); + Some(BusMessage::Lagged { topic, count }) => { + warn!( + "WebSocket client lagged by {} events on topic {}", + count, topic + ); // Send error notification to client using SystemEvent::Error let error_event = SystemEvent::Error { - message: format!("Lagged by {} events", n), + message: format!("Lagged by {} events", count), }; if let Ok(json) = serialize_event(&error_event) { let _ = sender.send(Message::Text(json.into())).await; } } - Err(_) => { + None => { warn!("Event bus closed"); break; } @@ -147,6 +249,10 @@ async fn handle_socket(socket: WebSocket, state: Arc) { } } + for task in event_tasks { + task.abort(); + } + info!("WebSocket handler exiting"); } @@ -176,21 +282,6 @@ async fn handle_client_message( Ok(()) } -/// Check if an event should be sent based on subscribed topics -fn should_send_event(event: &SystemEvent, topics: &[String]) -> bool { - if topics.is_empty() { - return false; - } - - // Fast path: check for wildcard subscription (avoid String allocation) - if topics.iter().any(|t| t == "*") { - return true; - } - - // Check if event matches any subscribed topic - topics.iter().any(|topic| event.matches_topic(topic)) -} - /// Serialize event to JSON string fn serialize_event(event: &SystemEvent) -> Result { serde_json::to_string(event) @@ -199,53 +290,49 @@ fn serialize_event(event: &SystemEvent) -> Result { #[cfg(test)] mod tests { use super::*; - use crate::events::SystemEvent; #[test] - fn test_should_send_event_wildcard() { - let event = SystemEvent::StreamStateChanged { - state: "streaming".to_string(), - device: None, - }; + fn test_normalize_topics_dedupes_and_sorts() { + let topics = vec![ + "stream.state_changed".to_string(), + "stream.state_changed".to_string(), + "system.device_info".to_string(), + ]; - assert!(should_send_event(&event, &["*".to_string()])); + assert_eq!( + normalize_topics(&topics), + vec![ + "stream.state_changed".to_string(), + "system.device_info".to_string() + ] + ); } #[test] - fn test_should_send_event_prefix() { - let event = SystemEvent::StreamStateChanged { - state: "streaming".to_string(), - device: None, - }; - - assert!(should_send_event(&event, &["stream.*".to_string()])); - assert!(!should_send_event(&event, &["msd.*".to_string()])); + fn test_normalize_topics_wildcard_wins() { + let topics = vec!["*".to_string(), "stream.state_changed".to_string()]; + assert_eq!(normalize_topics(&topics), vec!["*".to_string()]); } #[test] - fn test_should_send_event_exact() { - let event = SystemEvent::StreamStateChanged { - state: "streaming".to_string(), - device: None, - }; + fn test_normalize_topics_drops_exact_when_prefix_exists() { + let topics = vec![ + "stream.*".to_string(), + "stream.state_changed".to_string(), + "system.device_info".to_string(), + ]; - assert!(should_send_event( - &event, - &["stream.state_changed".to_string()] - )); - assert!(!should_send_event( - &event, - &["stream.config_changed".to_string()] - )); + assert_eq!( + normalize_topics(&topics), + vec!["stream.*".to_string(), "system.device_info".to_string()] + ); } #[test] - fn test_should_send_event_empty_topics() { - let event = SystemEvent::StreamStateChanged { - state: "streaming".to_string(), - device: None, - }; - - assert!(!should_send_event(&event, &[])); + fn test_is_device_info_topic_matches_expected_topics() { + assert!(is_device_info_topic("system.device_info")); + assert!(is_device_info_topic("system.*")); + assert!(is_device_info_topic("*")); + assert!(!is_device_info_topic("stream.*")); } } diff --git a/web/src/composables/useWebSocket.ts b/web/src/composables/useWebSocket.ts index d70e6af0..6d9b3d35 100644 --- a/web/src/composables/useWebSocket.ts +++ b/web/src/composables/useWebSocket.ts @@ -16,13 +16,40 @@ type EventHandler = (data: any) => void let wsInstance: WebSocket | null = null let handlers = new Map() +let subscribedTopics: string[] = [] const connected = ref(false) const reconnectAttempts = ref(0) const networkError = ref(false) const networkErrorMessage = ref(null) +function getSubscribedTopics(): string[] { + return Array.from(handlers.entries()) + .filter(([, eventHandlers]) => eventHandlers.length > 0) + .map(([event]) => event) + .sort() +} + +function arraysEqual(a: string[], b: string[]): boolean { + return a.length === b.length && a.every((value, index) => value === b[index]) +} + +function syncSubscriptions() { + const topics = getSubscribedTopics() + + if (arraysEqual(topics, subscribedTopics)) { + return + } + + subscribedTopics = topics + + if (wsInstance && wsInstance.readyState === WebSocket.OPEN) { + subscribe(topics) + } +} + function connect() { if (wsInstance && wsInstance.readyState === WebSocket.OPEN) { + syncSubscriptions() return } @@ -37,8 +64,7 @@ function connect() { networkErrorMessage.value = null reconnectAttempts.value = 0 - // Subscribe to all events by default - subscribe(['*']) + syncSubscriptions() } wsInstance.onmessage = (e) => { @@ -78,6 +104,7 @@ function disconnect() { wsInstance.close() wsInstance = null } + subscribedTopics = [] } function subscribe(topics: string[]) { @@ -94,6 +121,7 @@ function on(event: string, handler: EventHandler) { handlers.set(event, []) } handlers.get(event)!.push(handler) + syncSubscriptions() } function off(event: string, handler: EventHandler) { @@ -103,7 +131,11 @@ function off(event: string, handler: EventHandler) { if (index > -1) { eventHandlers.splice(index, 1) } + if (eventHandlers.length === 0) { + handlers.delete(event) + } } + syncSubscriptions() } function handleEvent(payload: WsEvent) {