diff --git a/libs/ventoy-img-rs/src/resources.rs b/libs/ventoy-img-rs/src/resources.rs index c32477fc..12dd59a6 100644 --- a/libs/ventoy-img-rs/src/resources.rs +++ b/libs/ventoy-img-rs/src/resources.rs @@ -168,5 +168,4 @@ mod tests { assert!(files.contains(&"core.img")); assert!(files.contains(&"ventoy.disk.img")); } - } diff --git a/src/audio/capture.rs b/src/audio/capture.rs index b0f1d64a..c1dfc415 100644 --- a/src/audio/capture.rs +++ b/src/audio/capture.rs @@ -1,5 +1,3 @@ -//! ALSA audio capture implementation - use alsa::pcm::{Access, Format, Frames, HwParams, State, IO}; use alsa::{Direction, ValueOr, PCM}; use bytes::Bytes; @@ -14,30 +12,23 @@ use crate::error::{AppError, Result}; use crate::utils::LogThrottler; use crate::{error_throttled, warn_throttled}; -/// Audio capture configuration #[derive(Debug, Clone)] pub struct AudioConfig { - /// ALSA device name (e.g., "hw:0,0" or "default") pub device_name: String, - /// Sample rate in Hz pub sample_rate: u32, - /// Number of channels (1 = mono, 2 = stereo) pub channels: u32, - /// Samples per frame (for Opus, typically 480 for 10ms at 48kHz) pub frame_size: u32, - /// Buffer size in frames pub buffer_frames: u32, - /// Period size in frames pub period_frames: u32, } impl Default for AudioConfig { fn default() -> Self { Self { - device_name: "default".to_string(), + device_name: String::new(), sample_rate: 48000, channels: 2, - frame_size: 960, // 20ms at 48kHz (good for Opus) + frame_size: 960, buffer_frames: 4096, period_frames: 960, } @@ -45,7 +36,6 @@ impl Default for AudioConfig { } impl AudioConfig { - /// Create config for a specific device (48 kHz stereo only; must match ALSA hardware). pub fn for_device(device: &AudioDeviceInfo) -> Self { Self { device_name: device.name.clone(), @@ -53,36 +43,26 @@ impl AudioConfig { } } - /// Bytes per sample (16-bit signed) pub fn bytes_per_sample(&self) -> u32 { 2 * self.channels } - /// Bytes per frame pub fn bytes_per_frame(&self) -> usize { (self.frame_size * self.bytes_per_sample()) as usize } } -/// Audio frame data #[derive(Debug, Clone)] pub struct AudioFrame { - /// Raw PCM data (S16LE interleaved) pub data: Bytes, - /// Sample rate pub sample_rate: u32, - /// Number of channels pub channels: u32, - /// Number of samples per channel pub samples: u32, - /// Frame sequence number pub sequence: u64, - /// Capture timestamp pub timestamp: Instant, } impl AudioFrame { - /// One capture block: `sample_rate` must be the **hardware** rate (e.g. ALSA `actual_rate`). pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self { let bps = 2 * channels; Self { @@ -96,7 +76,6 @@ impl AudioFrame { } } -/// Audio capture state #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CaptureState { Stopped, @@ -104,7 +83,6 @@ pub enum CaptureState { Error, } -/// ALSA audio capturer pub struct AudioCapturer { config: AudioConfig, state: Arc>, @@ -113,15 +91,13 @@ pub struct AudioCapturer { stop_flag: Arc, sequence: Arc, capture_handle: Mutex>>, - /// Log throttler to prevent log flooding log_throttler: LogThrottler, } impl AudioCapturer { - /// Create a new audio capturer pub fn new(config: AudioConfig) -> Self { let (state_tx, state_rx) = watch::channel(CaptureState::Stopped); - let (frame_tx, _) = broadcast::channel(16); // Buffer size 16 for low latency + let (frame_tx, _) = broadcast::channel(16); Self { config, @@ -135,22 +111,18 @@ impl AudioCapturer { } } - /// Get current state pub fn state(&self) -> CaptureState { *self.state_rx.borrow() } - /// Subscribe to state changes pub fn state_watch(&self) -> watch::Receiver { self.state_rx.clone() } - /// Subscribe to audio frames pub fn subscribe(&self) -> broadcast::Receiver { self.frame_tx.subscribe() } - /// Start capturing pub async fn start(&self) -> Result<()> { if self.state() == CaptureState::Running { return Ok(()); @@ -171,14 +143,27 @@ impl AudioCapturer { let log_throttler = self.log_throttler.clone(); let handle = tokio::task::spawn_blocking(move || { - capture_loop(config, state, frame_tx, stop_flag, sequence, log_throttler); + let result = run_capture( + &config, + &state, + &frame_tx, + &stop_flag, + &sequence, + &log_throttler, + ); + + if let Err(e) = result { + error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e); + let _ = state.send(CaptureState::Error); + } else { + let _ = state.send(CaptureState::Stopped); + } }); *self.capture_handle.lock().await = Some(handle); Ok(()) } - /// Stop capturing pub async fn stop(&self) -> Result<()> { info!("Stopping audio capture"); self.stop_flag.store(true, Ordering::SeqCst); @@ -191,38 +176,11 @@ impl AudioCapturer { Ok(()) } - /// Check if running pub fn is_running(&self) -> bool { self.state() == CaptureState::Running } } -/// Main capture loop -fn capture_loop( - config: AudioConfig, - state: Arc>, - frame_tx: broadcast::Sender, - stop_flag: Arc, - sequence: Arc, - log_throttler: LogThrottler, -) { - let result = run_capture( - &config, - &state, - &frame_tx, - &stop_flag, - &sequence, - &log_throttler, - ); - - if let Err(e) = result { - error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e); - let _ = state.send(CaptureState::Error); - } else { - let _ = state.send(CaptureState::Stopped); - } -} - fn run_capture( config: &AudioConfig, state: &watch::Sender, @@ -231,7 +189,6 @@ fn run_capture( sequence: &AtomicU64, log_throttler: &LogThrottler, ) -> Result<()> { - // Open ALSA device let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| { AppError::AudioError(format!( "Failed to open audio device {}: {}", @@ -239,7 +196,6 @@ fn run_capture( )) })?; - // Configure hardware parameters { let hwp = HwParams::any(&pcm) .map_err(|e| AppError::AudioError(format!("Failed to get HwParams: {}", e)))?; @@ -266,7 +222,6 @@ fn run_capture( .map_err(|e| AppError::AudioError(format!("Failed to apply hw params: {}", e)))?; } - // Fixed 48 kHz stereo: fail if hardware negotiated something else. let hw_now = pcm.hw_params_current().map_err(|e| { AppError::AudioError(format!("Failed to read hw_params after apply: {}", e)) })?; @@ -290,13 +245,11 @@ fn run_capture( } info!("Audio capture: 48000 Hz, 2 ch"); - // Prepare for capture pcm.prepare() .map_err(|e| AppError::AudioError(format!("Failed to prepare PCM: {}", e)))?; let _ = state.send(CaptureState::Running); - // Sized from actual period — `readi` may return up to ~one period of frames per call. let period_frames = pcm .hw_params_current() .ok() @@ -308,9 +261,7 @@ fn run_capture( let bytes_per_frame = (config.channels as usize) * 2; let mut buffer = vec![0u8; buf_frames * bytes_per_frame]; - // Capture loop while !stop_flag.load(Ordering::Relaxed) { - // Check PCM state match pcm.state() { State::XRun => { warn_throttled!(log_throttler, "xrun", "Audio buffer overrun, recovering"); @@ -329,9 +280,7 @@ fn run_capture( _ => {} } - // Get IO handle and read audio data directly as bytes - // Note: Use io() instead of io_checked() because USB audio devices - // typically don't support mmap, which io_checked() requires + // io_bytes: USB capture often lacks mmap (io_checked requires it). let io: IO = pcm.io_bytes(); match io.readi(&mut buffer) { @@ -340,10 +289,8 @@ fn run_capture( continue; } - // Calculate actual byte count let byte_count = frames_read * config.channels as usize * 2; - // Directly use the buffer slice (already in correct byte format) let seq = sequence.fetch_add(1, Ordering::Relaxed); let frame = AudioFrame::new_interleaved( Bytes::copy_from_slice(&buffer[..byte_count]), @@ -352,7 +299,6 @@ fn run_capture( seq, ); - // Send to subscribers if frame_tx.receiver_count() > 0 { if let Err(e) = frame_tx.send(frame) { debug!("No audio receivers: {}", e); @@ -360,14 +306,11 @@ fn run_capture( } } Err(e) => { - // Check for buffer overrun (EPIPE = 32 on Linux) let desc = e.to_string(); if desc.contains("EPIPE") || desc.contains("Broken pipe") { - // Buffer overrun warn_throttled!(log_throttler, "buffer_overrun", "Audio buffer overrun"); let _ = pcm.prepare(); } else if desc.contains("No such device") || desc.contains("ENODEV") { - // Device disconnected - use longer throttle for this error_throttled!(log_throttler, "no_device", "Audio read error: {}", e); } else { error_throttled!(log_throttler, "read_error", "Audio read error: {}", e); diff --git a/src/audio/controller.rs b/src/audio/controller.rs index 9539b7dc..d6346e4e 100644 --- a/src/audio/controller.rs +++ b/src/audio/controller.rs @@ -1,35 +1,31 @@ -//! Audio controller for high-level audio management -//! -//! Provides device enumeration, selection, quality control, and streaming management. +//! Device selection, quality presets, streaming. use serde::{Deserialize, Serialize}; +use std::str::FromStr; use std::sync::Arc; use tokio::sync::RwLock; use tracing::info; use super::capture::AudioConfig; -use super::device::{enumerate_audio_devices_with_current, AudioDeviceInfo}; +use super::device::{ + enumerate_audio_devices_with_current, find_best_audio_device, AudioDeviceInfo, +}; use super::encoder::{OpusConfig, OpusFrame}; -use super::monitor::{AudioHealthMonitor, AudioHealthStatus}; +use super::monitor::AudioHealthMonitor; use super::streamer::{AudioStreamer, AudioStreamerConfig}; use crate::error::{AppError, Result}; use crate::events::EventBus; -/// Audio quality presets #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(rename_all = "lowercase")] pub enum AudioQuality { - /// Low bandwidth voice (32kbps) Voice, - /// Balanced quality (64kbps) - default #[default] Balanced, - /// High quality audio (128kbps) High, } impl AudioQuality { - /// Get the bitrate for this quality level pub fn bitrate(&self) -> u32 { match self { AudioQuality::Voice => 32000, @@ -38,17 +34,6 @@ impl AudioQuality { } } - /// Parse from string - #[allow(clippy::should_implement_trait)] - pub fn from_str(s: &str) -> Self { - match s.to_lowercase().as_str() { - "voice" | "low" => AudioQuality::Voice, - "high" | "music" => AudioQuality::High, - _ => AudioQuality::Balanced, - } - } - - /// Convert to OpusConfig pub fn to_opus_config(&self) -> OpusConfig { match self { AudioQuality::Voice => OpusConfig::voice(), @@ -58,6 +43,22 @@ impl AudioQuality { } } +impl FromStr for AudioQuality { + type Err = AppError; + + fn from_str(s: &str) -> std::result::Result { + match s.trim().to_lowercase().as_str() { + "voice" => Ok(Self::Voice), + "balanced" => Ok(Self::Balanced), + "high" => Ok(Self::High), + _ => Err(AppError::BadRequest(format!( + "invalid audio quality {:?} (expected voice, balanced, or high)", + s.trim() + ))), + } + } +} + impl std::fmt::Display for AudioQuality { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -68,17 +69,10 @@ impl std::fmt::Display for AudioQuality { } } -/// Audio controller configuration -/// -/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo). -/// These are optimal for Opus encoding and match WebRTC requirements. #[derive(Debug, Clone)] pub struct AudioControllerConfig { - /// Whether audio is enabled pub enabled: bool, - /// Selected device name pub device: String, - /// Audio quality preset pub quality: AudioQuality, } @@ -86,74 +80,52 @@ impl Default for AudioControllerConfig { fn default() -> Self { Self { enabled: false, - device: "default".to_string(), + device: String::new(), quality: AudioQuality::Balanced, } } } -/// Current audio status #[derive(Debug, Clone, Serialize)] pub struct AudioStatus { - /// Whether audio feature is enabled pub enabled: bool, - /// Whether audio is currently streaming pub streaming: bool, - /// Currently selected device pub device: Option, - /// Current quality preset pub quality: AudioQuality, - /// Number of connected subscribers pub subscriber_count: usize, - /// Error message if any pub error: Option, } -/// Audio controller -/// -/// High-level interface for audio management, providing: -/// - Device enumeration and selection -/// - Quality control -/// - Stream start/stop -/// - Status reporting pub struct AudioController { config: RwLock, streamer: RwLock>>, devices: RwLock>, event_bus: RwLock>>, - last_error: RwLock>, - /// Health monitor for error tracking and recovery monitor: Arc, } impl AudioController { - /// Create a new audio controller with configuration pub fn new(config: AudioControllerConfig) -> Self { Self { config: RwLock::new(config), streamer: RwLock::new(None), devices: RwLock::new(Vec::new()), event_bus: RwLock::new(None), - last_error: RwLock::new(None), - monitor: Arc::new(AudioHealthMonitor::with_defaults()), + monitor: Arc::new(AudioHealthMonitor::new()), } } - /// Set event bus for internal state notifications. pub async fn set_event_bus(&self, event_bus: Arc) { *self.event_bus.write().await = Some(event_bus); } - /// Mark the device-info snapshot as stale. async fn mark_device_info_dirty(&self) { if let Some(ref bus) = *self.event_bus.read().await { bus.mark_device_info_dirty(); } } - /// List available audio capture devices pub async fn list_devices(&self) -> Result> { - // Get current device if streaming (it may be busy and unable to be opened) let current_device = if self.is_streaming().await { Some(self.config.read().await.device.clone()) } else { @@ -165,41 +137,23 @@ impl AudioController { Ok(devices) } - /// Refresh device list and cache it - pub async fn refresh_devices(&self) -> Result<()> { - // Get current device if streaming (it may be busy and unable to be opened) - let current_device = if self.is_streaming().await { - Some(self.config.read().await.device.clone()) - } else { - None - }; - - let devices = enumerate_audio_devices_with_current(current_device.as_deref())?; - *self.devices.write().await = devices; - Ok(()) - } - - /// Get cached device list pub async fn get_cached_devices(&self) -> Vec { self.devices.read().await.clone() } - /// Select audio device pub async fn select_device(&self, device: &str) -> Result<()> { - // Validate device exists let devices = self.list_devices().await?; let found = devices .iter() .any(|d| d.name == device || d.description.contains(device)); - if !found && device != "default" { + if !found { return Err(AppError::AudioError(format!( "Audio device not found: {}", device ))); } - // Update config { let mut config = self.config.write().await; config.device = device.to_string(); @@ -207,7 +161,6 @@ impl AudioController { info!("Audio device selected: {}", device); - // If streaming, restart with new device if self.is_streaming().await { self.stop_streaming().await?; self.start_streaming().await?; @@ -216,15 +169,12 @@ impl AudioController { Ok(()) } - /// Set audio quality pub async fn set_quality(&self, quality: AudioQuality) -> Result<()> { - // Update config { let mut config = self.config.write().await; config.quality = quality; } - // Update streamer if running if let Some(ref streamer) = *self.streamer.read().await { streamer.set_bitrate(quality.bitrate()).await?; } @@ -237,44 +187,45 @@ impl AudioController { Ok(()) } - /// Start audio streaming pub async fn start_streaming(&self) -> Result<()> { - let config = self.config.read().await.clone(); - - if !config.enabled { - return Err(AppError::AudioError("Audio is disabled".to_string())); + { + let config = self.config.read().await; + if !config.enabled { + return Err(AppError::AudioError("Audio is disabled".to_string())); + } } - // Check if already streaming if self.is_streaming().await { return Ok(()); } - info!("Starting audio streaming with device: {}", config.device); - - // Clear any previous error - *self.last_error.write().await = None; - - // Create streamer config (fixed 48kHz stereo) - let streamer_config = AudioStreamerConfig { - capture: AudioConfig { - device_name: config.device.clone(), - ..Default::default() - }, - opus: config.quality.to_opus_config(), + let (device_name, quality) = { + let mut cfg = self.config.write().await; + if cfg.device.trim().is_empty() { + let best = find_best_audio_device()?; + cfg.device = best.name; + } + (cfg.device.clone(), cfg.quality) + }; + + info!("Starting audio streaming with device: {}", device_name); + + self.monitor.prepare_retry_attempt(); + + let streamer_config = AudioStreamerConfig { + capture: AudioConfig { + device_name: device_name.clone(), + ..Default::default() + }, + opus: quality.to_opus_config(), }; - // Create and start streamer let streamer = Arc::new(AudioStreamer::with_config(streamer_config)); if let Err(e) = streamer.start().await { let error_msg = format!("Failed to start audio: {}", e); - *self.last_error.write().await = Some(error_msg.clone()); - // Report error to health monitor - self.monitor - .report_error(Some(&config.device), &error_msg, "start_failed") - .await; + self.monitor.report_error(&error_msg, "start_failed").await; self.mark_device_info_dirty().await; @@ -283,9 +234,8 @@ impl AudioController { *self.streamer.write().await = Some(streamer); - // Report recovery if we were in an error state if self.monitor.is_error().await { - self.monitor.report_recovered(Some(&config.device)).await; + self.monitor.report_recovered().await; } self.mark_device_info_dirty().await; @@ -294,7 +244,6 @@ impl AudioController { Ok(()) } - /// Stop audio streaming pub async fn stop_streaming(&self) -> Result<()> { if let Some(streamer) = self.streamer.write().await.take() { streamer.stop().await?; @@ -306,7 +255,6 @@ impl AudioController { Ok(()) } - /// Check if currently streaming pub async fn is_streaming(&self) -> bool { if let Some(ref streamer) = *self.streamer.read().await { streamer.is_running() @@ -315,45 +263,37 @@ impl AudioController { } } - /// Get current status pub async fn status(&self) -> AudioStatus { - let config = self.config.read().await; - let streaming = self.is_streaming().await; - let error = self.last_error.read().await.clone(); + let (enabled, device_str, quality) = { + let c = self.config.read().await; + (c.enabled, c.device.clone(), c.quality) + }; + let error = self.monitor.error_message().await; - let subscriber_count = if let Some(ref streamer) = *self.streamer.read().await { - streamer.stats().await.subscriber_count + let (streaming, subscriber_count) = if let Some(ref streamer) = *self.streamer.read().await + { + let streaming = streamer.is_running(); + let subscriber_count = streamer.stats().subscriber_count; + (streaming, subscriber_count) } else { - 0 + (false, 0) }; AudioStatus { - enabled: config.enabled, + enabled, streaming, - device: if streaming || config.enabled { - Some(config.device.clone()) + device: if streaming || enabled { + Some(device_str) } else { None }, - quality: config.quality, + quality, subscriber_count, error, } } - /// Subscribe to Opus frames (for WebSocket clients) - pub fn subscribe_opus(&self) -> Option>> { - if let Ok(guard) = self.streamer.try_read() { - guard.as_ref().map(|s| s.subscribe_opus()) - } else { - None - } - } - - /// Subscribe to Opus frames (async version) - pub async fn subscribe_opus_async( - &self, - ) -> Option>> { + pub async fn subscribe_opus(&self) -> Option>> { self.streamer .read() .await @@ -361,7 +301,6 @@ impl AudioController { .map(|s| s.subscribe_opus()) } - /// Enable or disable audio pub async fn set_enabled(&self, enabled: bool) -> Result<()> { { let mut config = self.config.write().await; @@ -376,21 +315,15 @@ impl AudioController { Ok(()) } - /// Update full configuration pub async fn update_config(&self, new_config: AudioControllerConfig) -> Result<()> { let was_streaming = self.is_streaming().await; - // Stop streaming if running (device/quality/enabled may all change) if was_streaming { self.stop_streaming().await?; } - // Update config *self.config.write().await = new_config.clone(); - // Start whenever audio is enabled — not only when we were already streaming. - // Otherwise PATCH /config/audio alone leaves enabled=true with no capture until - // POST /audio/start, which races WebRTC reconnect and matches "apply twice" reports. if new_config.enabled { self.start_streaming().await?; } @@ -398,25 +331,9 @@ impl AudioController { Ok(()) } - /// Shutdown the controller pub async fn shutdown(&self) -> Result<()> { self.stop_streaming().await } - - /// Get the health monitor reference - pub fn monitor(&self) -> &Arc { - &self.monitor - } - - /// Get current health status - pub async fn health_status(&self) -> AudioHealthStatus { - self.monitor.status().await - } - - /// Check if the audio is healthy - pub async fn is_healthy(&self) -> bool { - self.monitor.is_healthy().await - } } impl Default for AudioController { @@ -438,12 +355,23 @@ mod tests { #[test] fn test_audio_quality_from_str() { - assert_eq!(AudioQuality::from_str("voice"), AudioQuality::Voice); - assert_eq!(AudioQuality::from_str("low"), AudioQuality::Voice); - assert_eq!(AudioQuality::from_str("balanced"), AudioQuality::Balanced); - assert_eq!(AudioQuality::from_str("high"), AudioQuality::High); - assert_eq!(AudioQuality::from_str("music"), AudioQuality::High); - assert_eq!(AudioQuality::from_str("unknown"), AudioQuality::Balanced); + assert_eq!( + "voice".parse::().unwrap(), + AudioQuality::Voice + ); + assert_eq!( + "balanced".parse::().unwrap(), + AudioQuality::Balanced + ); + assert_eq!("high".parse::().unwrap(), AudioQuality::High); + } + + #[test] + fn test_audio_quality_from_str_rejects_aliases_and_unknown() { + assert!("low".parse::().is_err()); + assert!("music".parse::().is_err()); + assert!("unknown".parse::().is_err()); + assert!("".parse::().is_err()); } #[tokio::test] diff --git a/src/audio/device.rs b/src/audio/device.rs index 77536680..66df38cc 100644 --- a/src/audio/device.rs +++ b/src/audio/device.rs @@ -1,5 +1,3 @@ -//! Audio device enumeration using ALSA - use alsa::pcm::HwParams; use alsa::{Direction, PCM}; use serde::Serialize; @@ -7,54 +5,30 @@ use tracing::{debug, info, warn}; use crate::error::{AppError, Result}; -/// Audio device information #[derive(Debug, Clone, Serialize)] pub struct AudioDeviceInfo { - /// Device name (e.g., "hw:0,0" or "default") pub name: String, - /// Human-readable description pub description: String, - /// Card index pub card_index: i32, - /// Device index pub device_index: i32, - /// Supported sample rates pub sample_rates: Vec, - /// Supported channel counts pub channels: Vec, - /// Is this a capture device pub is_capture: bool, - /// Is this an HDMI audio device (likely from capture card) pub is_hdmi: bool, - /// USB bus info for matching with video devices (e.g., "1-1" from USB path) pub usb_bus: Option, } -impl AudioDeviceInfo { - /// Get ALSA device name - pub fn alsa_name(&self) -> String { - format!("hw:{},{}", self.card_index, self.device_index) - } -} - -/// Get USB bus info for an audio card by reading sysfs -/// Returns the USB port path like "1-1" or "1-2.3" fn get_usb_bus_info(card_index: i32) -> Option { if card_index < 0 { return None; } - // Read the device symlink: /sys/class/sound/cardX/device -> ../../usb1/1-1/1-1:1.0 let device_path = format!("/sys/class/sound/card{}/device", card_index); let link_target = std::fs::read_link(&device_path).ok()?; let link_str = link_target.to_string_lossy(); - // Extract USB port from path like "../../usb1/1-1/1-1:1.0" or "../../1-1/1-1:1.0" - // We want the "1-1" part (USB bus-port) for component in link_str.split('/') { - // Match patterns like "1-1", "1-2", "1-1.2", "2-1.3.1" if component.contains('-') && !component.contains(':') { - // Verify it looks like a USB port (starts with digit) if component .chars() .next() @@ -69,22 +43,15 @@ fn get_usb_bus_info(card_index: i32) -> Option { None } -/// Enumerate available audio capture devices pub fn enumerate_audio_devices() -> Result> { enumerate_audio_devices_with_current(None) } -/// Enumerate available audio capture devices, with option to include a currently-in-use device -/// -/// # Arguments -/// * `current_device` - Optional device name that is currently in use. This device will be -/// included in the list even if it cannot be opened (because it's already open by us). pub fn enumerate_audio_devices_with_current( current_device: Option<&str>, ) -> Result> { let mut devices = Vec::new(); - // Try to enumerate cards let cards = alsa::card::Iter::new(); for card_result in cards { @@ -102,104 +69,71 @@ pub fn enumerate_audio_devices_with_current( debug!("Found audio card {}: {}", card_index, card_longname); - // Check if this looks like an HDMI capture device - let is_hdmi = card_longname.to_lowercase().contains("hdmi") - || card_longname.to_lowercase().contains("capture") - || card_longname.to_lowercase().contains("usb"); + let long_lower = card_longname.to_lowercase(); + let is_hdmi = long_lower.contains("hdmi") + || long_lower.contains("capture") + || long_lower.contains("usb"); - // Get USB bus info for this card let usb_bus = get_usb_bus_info(card_index); - // Try to open each device on this card for capture for device_index in 0..8 { let device_name = format!("hw:{},{}", card_index, device_index); - - // Check if this is the currently-in-use device let is_current_device = current_device == Some(device_name.as_str()); - // Try to open for capture + let mut push_info = + |sample_rates: Vec, channels: Vec, description: String| { + devices.push(AudioDeviceInfo { + name: device_name.clone(), + description, + card_index, + device_index, + sample_rates, + channels, + is_capture: true, + is_hdmi, + usb_bus: usb_bus.clone(), + }); + }; + match PCM::new(&device_name, Direction::Capture, false) { Ok(pcm) => { - // Query capabilities let (sample_rates, channels) = query_device_caps(&pcm); if !sample_rates.is_empty() && !channels.is_empty() { - devices.push(AudioDeviceInfo { - name: device_name, - description: format!("{} - Device {}", card_longname, device_index), - card_index, - device_index, + push_info( sample_rates, channels, - is_capture: true, - is_hdmi, - usb_bus: usb_bus.clone(), - }); + format!("{} - Device {}", card_longname, device_index), + ); } } Err(_) => { - // Device doesn't exist or can't be opened for capture - // But if it's the current device, include it anyway (it's busy because we're using it) if is_current_device { debug!( "Device {} is busy (in use by us), adding with default caps", device_name ); - devices.push(AudioDeviceInfo { - name: device_name, - description: format!( - "{} - Device {} (in use)", - card_longname, device_index - ), - card_index, - device_index, - // Use common default capabilities for HDMI capture devices - sample_rates: vec![44100, 48000], - channels: vec![2], - is_capture: true, - is_hdmi, - usb_bus: usb_bus.clone(), - }); + push_info( + vec![44100, 48000], + vec![2], + format!("{} - Device {} (in use)", card_longname, device_index), + ); } - continue; } } } } - // Also check for "default" device - if let Ok(pcm) = PCM::new("default", Direction::Capture, false) { - let (sample_rates, channels) = query_device_caps(&pcm); - if !sample_rates.is_empty() { - devices.insert( - 0, - AudioDeviceInfo { - name: "default".to_string(), - description: "Default Audio Device".to_string(), - card_index: -1, - device_index: -1, - sample_rates, - channels, - is_capture: true, - is_hdmi: false, - usb_bus: None, - }, - ); - } - } - info!("Found {} audio capture devices", devices.len()); Ok(devices) } -/// Query device capabilities fn query_device_caps(pcm: &PCM) -> (Vec, Vec) { let hwp = match HwParams::any(pcm) { Ok(h) => h, Err(_) => return (vec![], vec![]), }; - // Common sample rates to check let common_rates = [8000, 16000, 22050, 44100, 48000, 96000]; let mut supported_rates = Vec::new(); @@ -209,7 +143,6 @@ fn query_device_caps(pcm: &PCM) -> (Vec, Vec) { } } - // Check channel counts let mut supported_channels = Vec::new(); for ch in 1..=8 { if hwp.test_channels(ch).is_ok() { @@ -220,8 +153,6 @@ fn query_device_caps(pcm: &PCM) -> (Vec, Vec) { (supported_rates, supported_channels) } -/// Find the best audio device for capture -/// Prefers HDMI/capture devices over built-in microphones pub fn find_best_audio_device() -> Result { let devices = enumerate_audio_devices()?; @@ -231,23 +162,24 @@ pub fn find_best_audio_device() -> Result { )); } - // First, look for HDMI/capture card devices that support 48kHz stereo + let mut first_48k_stereo: Option<&AudioDeviceInfo> = None; for device in &devices { - if device.is_hdmi && device.sample_rates.contains(&48000) && device.channels.contains(&2) { + if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) { + continue; + } + if device.is_hdmi { info!("Selected HDMI audio device: {}", device.description); return Ok(device.clone()); } - } - - // Then look for any device supporting 48kHz stereo - for device in &devices { - if device.sample_rates.contains(&48000) && device.channels.contains(&2) { - info!("Selected audio device: {}", device.description); - return Ok(device.clone()); + if first_48k_stereo.is_none() { + first_48k_stereo = Some(device); } } + if let Some(device) = first_48k_stereo { + info!("Selected audio device: {}", device.description); + return Ok(device.clone()); + } - // Fall back to first device let device = devices.into_iter().next().unwrap(); warn!( "Using fallback audio device: {} (may not support optimal settings)", @@ -262,10 +194,8 @@ mod tests { #[test] fn test_enumerate_devices() { - // This test may not find devices in CI environment let result = enumerate_audio_devices(); println!("Audio devices: {:?}", result); - // Just verify it doesn't panic assert!(result.is_ok()); } } diff --git a/src/audio/encoder.rs b/src/audio/encoder.rs index bcd316b1..3095a9cc 100644 --- a/src/audio/encoder.rs +++ b/src/audio/encoder.rs @@ -1,26 +1,19 @@ -//! Opus audio encoder for WebRTC +//! Opus encoder. use audiopus::coder::GenericCtl; use audiopus::{coder::Encoder, Application, Bitrate, Channels, SampleRate}; use bytes::Bytes; -use std::time::Instant; use tracing::info; use super::capture::AudioFrame; use crate::error::{AppError, Result}; -/// Opus encoder configuration #[derive(Debug, Clone)] pub struct OpusConfig { - /// Sample rate (must be 8000, 12000, 16000, 24000, or 48000) pub sample_rate: u32, - /// Channels (1 or 2) pub channels: u32, - /// Target bitrate in bps pub bitrate: u32, - /// Application mode pub application: OpusApplication, - /// Enable forward error correction pub fec: bool, } @@ -29,7 +22,7 @@ impl Default for OpusConfig { Self { sample_rate: 48000, channels: 2, - bitrate: 64000, // 64 kbps + bitrate: 64000, application: OpusApplication::Audio, fec: true, } @@ -37,7 +30,6 @@ impl Default for OpusConfig { } impl OpusConfig { - /// Create config for voice (lower latency) pub fn voice() -> Self { Self { application: OpusApplication::Voip, @@ -46,7 +38,6 @@ impl OpusConfig { } } - /// Create config for music (higher quality) pub fn music() -> Self { Self { application: OpusApplication::Audio, @@ -82,30 +73,18 @@ impl OpusConfig { } } -/// Opus application mode #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum OpusApplication { - /// Voice over IP Voip, - /// General audio Audio, - /// Low delay mode LowDelay, } -/// Encoded Opus frame #[derive(Debug, Clone)] pub struct OpusFrame { - /// Encoded Opus data pub data: Bytes, - /// Duration in milliseconds pub duration_ms: u32, - /// Sequence number pub sequence: u64, - /// Timestamp - pub timestamp: Instant, - /// RTP timestamp (samples) - pub rtp_timestamp: u32, } impl OpusFrame { @@ -118,20 +97,14 @@ impl OpusFrame { } } -/// Opus encoder pub struct OpusEncoder { config: OpusConfig, encoder: Encoder, - /// Output buffer output_buffer: Vec, - /// Frame counter for RTP timestamp frame_count: u64, - /// Samples per frame - samples_per_frame: u32, } impl OpusEncoder { - /// Create a new Opus encoder pub fn new(config: OpusConfig) -> Result { let sample_rate = config.to_audiopus_sample_rate(); let channels = config.to_audiopus_channels(); @@ -140,7 +113,6 @@ impl OpusEncoder { let mut encoder = Encoder::new(sample_rate, channels, application) .map_err(|e| AppError::AudioError(format!("Failed to create Opus encoder: {:?}", e)))?; - // Configure encoder encoder .set_bitrate(Bitrate::BitsPerSecond(config.bitrate as i32)) .map_err(|e| AppError::AudioError(format!("Failed to set bitrate: {:?}", e)))?; @@ -151,9 +123,6 @@ impl OpusEncoder { .map_err(|e| AppError::AudioError(format!("Failed to enable FEC: {:?}", e)))?; } - // Calculate samples per frame (20ms at sample_rate) - let samples_per_frame = config.sample_rate / 50; - info!( "Opus encoder created: {}Hz {}ch {}bps", config.sample_rate, config.channels, config.bitrate @@ -162,18 +131,11 @@ impl OpusEncoder { Ok(Self { config, encoder, - output_buffer: vec![0u8; 4000], // Max Opus frame size + output_buffer: vec![0u8; 4000], frame_count: 0, - samples_per_frame, }) } - /// Create with default configuration - pub fn default_config() -> Result { - Self::new(OpusConfig::default()) - } - - /// Encode PCM audio data (S16LE interleaved) pub fn encode(&mut self, pcm_data: &[i16]) -> Result { let encoded_len = self .encoder @@ -182,7 +144,6 @@ impl OpusEncoder { let samples = pcm_data.len() as u32 / self.config.channels; let duration_ms = (samples * 1000) / self.config.sample_rate; - let rtp_timestamp = (self.frame_count * self.samples_per_frame as u64) as u32; self.frame_count += 1; @@ -190,27 +151,18 @@ impl OpusEncoder { data: Bytes::copy_from_slice(&self.output_buffer[..encoded_len]), duration_ms, sequence: self.frame_count - 1, - timestamp: Instant::now(), - rtp_timestamp, }) } - /// Encode from AudioFrame - /// - /// Uses zero-copy conversion from bytes to i16 samples via bytemuck. pub fn encode_frame(&mut self, frame: &AudioFrame) -> Result { - // Zero-copy: directly cast bytes to i16 slice - // AudioFrame.data is S16LE format, which matches native little-endian i16 let samples: &[i16] = bytemuck::cast_slice(&frame.data); self.encode(samples) } - /// Get encoder configuration pub fn config(&self) -> &OpusConfig { &self.config } - /// Reset encoder state pub fn reset(&mut self) -> Result<()> { self.encoder .reset_state() @@ -219,7 +171,6 @@ impl OpusEncoder { Ok(()) } - /// Set bitrate dynamically pub fn set_bitrate(&mut self, bitrate: u32) -> Result<()> { self.encoder .set_bitrate(Bitrate::BitsPerSecond(bitrate as i32)) @@ -228,15 +179,6 @@ impl OpusEncoder { } } -/// Audio encoder statistics -#[derive(Debug, Clone, Default)] -pub struct EncoderStats { - pub frames_encoded: u64, - pub bytes_output: u64, - pub avg_frame_size: usize, - pub current_bitrate: u32, -} - #[cfg(test)] mod tests { use super::*; @@ -261,13 +203,12 @@ mod tests { let config = OpusConfig::default(); let mut encoder = OpusEncoder::new(config).unwrap(); - // 20ms of stereo silence at 48kHz let silence = vec![0i16; 960 * 2]; let result = encoder.encode(&silence); assert!(result.is_ok()); let frame = result.unwrap(); assert!(!frame.is_empty()); - assert!(frame.len() < silence.len() * 2); // Should be compressed + assert!(frame.len() < silence.len() * 2); } } diff --git a/src/audio/mod.rs b/src/audio/mod.rs index 829bef91..0ba10841 100644 --- a/src/audio/mod.rs +++ b/src/audio/mod.rs @@ -1,12 +1,4 @@ -//! Audio capture and encoding module -//! -//! This module provides: -//! - ALSA audio capture -//! - Opus encoding for WebRTC -//! - Audio device enumeration -//! - Audio streaming pipeline -//! - High-level audio controller -//! - Device health monitoring +//! ALSA capture, Opus encode, device enumeration, streaming, controller, health monitor. pub mod capture; pub mod controller; @@ -19,5 +11,5 @@ pub use capture::{AudioCapturer, AudioConfig, AudioFrame}; pub use controller::{AudioController, AudioControllerConfig, AudioQuality, AudioStatus}; pub use device::{enumerate_audio_devices, enumerate_audio_devices_with_current, AudioDeviceInfo}; pub use encoder::{OpusConfig, OpusEncoder, OpusFrame}; -pub use monitor::{AudioHealthMonitor, AudioHealthStatus, AudioMonitorConfig}; +pub use monitor::{AudioHealthMonitor, AudioHealthStatus}; pub use streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig}; diff --git a/src/audio/monitor.rs b/src/audio/monitor.rs index c7dfd049..35bb03e2 100644 --- a/src/audio/monitor.rs +++ b/src/audio/monitor.rs @@ -1,114 +1,58 @@ -//! Audio device health monitoring -//! -//! This module provides health monitoring for audio capture devices, including: -//! - Device connectivity checks -//! - Automatic reconnection on failure -//! - Error tracking -//! - Log throttling to prevent log flooding +//! Audio device health and logging throttle for repeated failures. -use std::sync::atomic::{AtomicU32, Ordering}; -use std::time::Duration; +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use tokio::sync::RwLock; use tracing::{info, warn}; use crate::utils::LogThrottler; -/// Audio health status +const LOG_THROTTLE_SECS: u64 = 5; + #[derive(Debug, Clone, PartialEq, Default)] pub enum AudioHealthStatus { - /// Device is healthy and operational #[default] Healthy, - /// Device has an error, attempting recovery Error { - /// Human-readable error reason reason: String, - /// Error code for programmatic handling error_code: String, - /// Number of recovery attempts made - retry_count: u32, }, - /// Device is disconnected or not available - Disconnected, } -/// Audio health monitor configuration -#[derive(Debug, Clone)] -pub struct AudioMonitorConfig { - /// Retry interval when device is lost (milliseconds) - pub retry_interval_ms: u64, - /// Maximum retry attempts before giving up (0 = infinite) - pub max_retries: u32, - /// Log throttle interval in seconds - pub log_throttle_secs: u64, -} - -impl Default for AudioMonitorConfig { - fn default() -> Self { - Self { - retry_interval_ms: 1000, - max_retries: 0, // infinite retry - log_throttle_secs: 5, - } - } -} - -/// Audio health monitor -/// -/// Monitors audio device health and manages error recovery. pub struct AudioHealthMonitor { - /// Current health status status: RwLock, - /// Log throttler to prevent log flooding throttler: LogThrottler, - /// Configuration - config: AudioMonitorConfig, - /// Current retry count retry_count: AtomicU32, - /// Last error code (for change detection) last_error_code: RwLock>, + /// Hide `error_message` while a new capture attempt is in flight (internal error state unchanged). + suppress_display: AtomicBool, } impl AudioHealthMonitor { - /// Create a new audio health monitor with the specified configuration - pub fn new(config: AudioMonitorConfig) -> Self { - let throttle_secs = config.log_throttle_secs; + pub fn new() -> Self { Self { status: RwLock::new(AudioHealthStatus::Healthy), - throttler: LogThrottler::with_secs(throttle_secs), - config, + throttler: LogThrottler::with_secs(LOG_THROTTLE_SECS), retry_count: AtomicU32::new(0), last_error_code: RwLock::new(None), + suppress_display: AtomicBool::new(false), } } - /// Create a new audio health monitor with default configuration - pub fn with_defaults() -> Self { - Self::new(AudioMonitorConfig::default()) + /// Clears the error string exposed via [`Self::error_message`] until the next outcome (`report_error` or recovery). + pub fn prepare_retry_attempt(&self) { + self.suppress_display.store(true, Ordering::Relaxed); } - /// Report an error from audio operations - /// - /// This method is called when an audio operation fails. It: - /// 1. Updates the health status - /// 2. Logs the error (with throttling) - /// 3. Updates in-memory error state - /// - /// # Arguments - /// - /// * `device` - The audio device name (if known) - /// * `reason` - Human-readable error description - /// * `error_code` - Error code for programmatic handling - pub async fn report_error(&self, _device: Option<&str>, reason: &str, error_code: &str) { + pub async fn report_error(&self, reason: &str, error_code: &str) { + self.suppress_display.store(false, Ordering::Relaxed); + let count = self.retry_count.fetch_add(1, Ordering::Relaxed) + 1; - // Check if error code changed let error_changed = { let last = self.last_error_code.read().await; last.as_ref().map(|s| s.as_str()) != Some(error_code) }; - // Log with throttling (always log if error type changed) let throttle_key = format!("audio_{}", error_code); if error_changed || self.throttler.should_log(&throttle_key) { warn!( @@ -117,34 +61,22 @@ impl AudioHealthMonitor { ); } - // Update last error code *self.last_error_code.write().await = Some(error_code.to_string()); - // Update status *self.status.write().await = AudioHealthStatus::Error { reason: reason.to_string(), error_code: error_code.to_string(), - retry_count: count, }; } - /// Report that the device has recovered - /// - /// This method is called when the audio device successfully reconnects. - /// It resets the error state. - /// - /// # Arguments - /// - /// * `device` - The audio device name - pub async fn report_recovered(&self, _device: Option<&str>) { + pub async fn report_recovered(&self) { let prev_status = self.status.read().await.clone(); - // Only report recovery if we were in an error state if prev_status != AudioHealthStatus::Healthy { let retry_count = self.retry_count.load(Ordering::Relaxed); info!("Audio recovered after {} retries", retry_count); - // Reset state + self.suppress_display.store(false, Ordering::Relaxed); self.retry_count.store(0, Ordering::Relaxed); self.throttler.clear("audio_"); *self.last_error_code.write().await = None; @@ -152,58 +84,30 @@ impl AudioHealthMonitor { } } - /// Get the current health status - pub async fn status(&self) -> AudioHealthStatus { - self.status.read().await.clone() - } - - /// Get the current retry count - pub fn retry_count(&self) -> u32 { - self.retry_count.load(Ordering::Relaxed) - } - - /// Check if the monitor is in an error state - pub async fn is_error(&self) -> bool { - matches!(*self.status.read().await, AudioHealthStatus::Error { .. }) - } - - /// Check if the monitor is healthy - pub async fn is_healthy(&self) -> bool { - matches!(*self.status.read().await, AudioHealthStatus::Healthy) - } - - /// Reset the monitor to healthy state without publishing events - /// - /// This is useful during initialization. pub async fn reset(&self) { + self.suppress_display.store(false, Ordering::Relaxed); self.retry_count.store(0, Ordering::Relaxed); *self.last_error_code.write().await = None; *self.status.write().await = AudioHealthStatus::Healthy; self.throttler.clear_all(); } - /// Get the configuration - pub fn config(&self) -> &AudioMonitorConfig { - &self.config + pub async fn status(&self) -> AudioHealthStatus { + self.status.read().await.clone() } - /// Check if we should continue retrying - /// - /// Returns `false` if max_retries is set and we've exceeded it. - pub fn should_retry(&self) -> bool { - if self.config.max_retries == 0 { - return true; // Infinite retry - } - self.retry_count.load(Ordering::Relaxed) < self.config.max_retries + pub fn retry_count(&self) -> u32 { + self.retry_count.load(Ordering::Relaxed) } - /// Get the retry interval - pub fn retry_interval(&self) -> Duration { - Duration::from_millis(self.config.retry_interval_ms) + pub async fn is_error(&self) -> bool { + matches!(*self.status.read().await, AudioHealthStatus::Error { .. }) } - /// Get the current error message if in error state pub async fn error_message(&self) -> Option { + if self.suppress_display.load(Ordering::Relaxed) { + return None; + } match &*self.status.read().await { AudioHealthStatus::Error { reason, .. } => Some(reason.clone()), _ => None, @@ -213,7 +117,7 @@ impl AudioHealthMonitor { impl Default for AudioHealthMonitor { fn default() -> Self { - Self::with_defaults() + Self::new() } } @@ -223,32 +127,25 @@ mod tests { #[tokio::test] async fn test_initial_status() { - let monitor = AudioHealthMonitor::with_defaults(); - assert!(monitor.is_healthy().await); + let monitor = AudioHealthMonitor::new(); assert!(!monitor.is_error().await); assert_eq!(monitor.retry_count(), 0); } #[tokio::test] async fn test_report_error() { - let monitor = AudioHealthMonitor::with_defaults(); + let monitor = AudioHealthMonitor::new(); monitor - .report_error(Some("hw:0,0"), "Device not found", "device_disconnected") + .report_error("Device not found", "device_disconnected") .await; assert!(monitor.is_error().await); assert_eq!(monitor.retry_count(), 1); - if let AudioHealthStatus::Error { - reason, - error_code, - retry_count, - } = monitor.status().await - { + if let AudioHealthStatus::Error { reason, error_code } = monitor.status().await { assert_eq!(reason, "Device not found"); assert_eq!(error_code, "device_disconnected"); - assert_eq!(retry_count, 1); } else { panic!("Expected Error status"); } @@ -256,39 +153,52 @@ mod tests { #[tokio::test] async fn test_report_recovered() { - let monitor = AudioHealthMonitor::with_defaults(); + let monitor = AudioHealthMonitor::new(); - // First report an error monitor - .report_error(Some("default"), "Capture failed", "capture_error") + .report_error("Capture failed", "capture_error") .await; assert!(monitor.is_error().await); - // Then report recovery - monitor.report_recovered(Some("default")).await; - assert!(monitor.is_healthy().await); + monitor.report_recovered().await; + assert!(!monitor.is_error().await); assert_eq!(monitor.retry_count(), 0); } #[tokio::test] async fn test_retry_count_increments() { - let monitor = AudioHealthMonitor::with_defaults(); + let monitor = AudioHealthMonitor::new(); for i in 1..=5 { - monitor.report_error(None, "Error", "io_error").await; + monitor.report_error("Error", "io_error").await; assert_eq!(monitor.retry_count(), i); } } #[tokio::test] async fn test_reset() { - let monitor = AudioHealthMonitor::with_defaults(); + let monitor = AudioHealthMonitor::new(); - monitor.report_error(None, "Error", "io_error").await; + monitor.report_error("Error", "io_error").await; assert!(monitor.is_error().await); monitor.reset().await; - assert!(monitor.is_healthy().await); + assert!(!monitor.is_error().await); assert_eq!(monitor.retry_count(), 0); } + + #[tokio::test] + async fn test_prepare_retry_hides_error_until_next_failure() { + let monitor = AudioHealthMonitor::new(); + + monitor.report_error("bad", "e").await; + assert_eq!(monitor.error_message().await.as_deref(), Some("bad")); + + monitor.prepare_retry_attempt(); + assert!(monitor.is_error().await); + assert!(monitor.error_message().await.is_none()); + + monitor.report_error("still bad", "e").await; + assert_eq!(monitor.error_message().await.as_deref(), Some("still bad")); + } } diff --git a/src/audio/streamer.rs b/src/audio/streamer.rs index 9b32e27f..d7864484 100644 --- a/src/audio/streamer.rs +++ b/src/audio/streamer.rs @@ -1,10 +1,7 @@ -//! Audio streaming pipeline -//! -//! ALSA capture (48 kHz stereo only) → fixed Opus 20 ms frames → `mpsc` fan-out per subscriber. +//! ALSA 48 kHz stereo → Opus 20 ms frames, fan-out per subscriber. -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; -use std::time::Instant; use tokio::sync::{broadcast, mpsc, watch, Mutex as AsyncMutex, RwLock}; use tracing::{error, info, warn}; @@ -14,34 +11,25 @@ use crate::error::{AppError, Result}; use bytemuck; use bytes::Bytes; -/// Stereo 48 kHz: 20 ms = 960 frames × 2 channels (S16LE). +/// 48 kHz stereo: 20 ms = 960 × 2 samples (S16LE). const OPUS_STEREO_SAMPLES: usize = 960 * 2; -/// Audio stream state #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum AudioStreamState { - /// Stream is stopped #[default] Stopped, - /// Stream is starting up Starting, - /// Stream is running Running, - /// Stream encountered an error Error, } -/// Audio streamer configuration #[derive(Debug, Clone, Default)] pub struct AudioStreamerConfig { - /// Audio capture configuration pub capture: AudioConfig, - /// Opus encoder configuration pub opus: OpusConfig, } impl AudioStreamerConfig { - /// Create config for a specific device with default quality pub fn for_device(device_name: &str) -> Self { Self { capture: AudioConfig { @@ -52,45 +40,32 @@ impl AudioStreamerConfig { } } - /// Create config with specified bitrate pub fn with_bitrate(mut self, bitrate: u32) -> Self { self.opus.bitrate = bitrate; self } } -/// Audio stream statistics #[derive(Debug, Clone, Default)] pub struct AudioStreamStats { - /// Frames encoded to Opus - /// Number of active subscribers pub subscriber_count: usize, } -/// Audio streamer -/// -/// Manages the audio capture → encode → mpsc fan-out pipeline. pub struct AudioStreamer { config: RwLock, state: watch::Sender, state_rx: watch::Receiver, capturer: RwLock>>, encoder: Arc>>, - /// One `mpsc::Sender` per subscriber (like shared video pipeline). opus_subscribers: Arc>>>>, - stats: Arc>, - sequence: AtomicU64, - stream_start_time: RwLock>, stop_flag: Arc, } impl AudioStreamer { - /// Create a new audio streamer with default configuration pub fn new() -> Self { Self::with_config(AudioStreamerConfig::default()) } - /// Create a new audio streamer with specified configuration pub fn with_config(config: AudioStreamerConfig) -> Self { let (state_tx, state_rx) = watch::channel(AudioStreamState::Stopped); @@ -101,31 +76,24 @@ impl AudioStreamer { capturer: RwLock::new(None), encoder: Arc::new(AsyncMutex::new(None)), opus_subscribers: Arc::new(Mutex::new(Vec::new())), - stats: Arc::new(AsyncMutex::new(AudioStreamStats::default())), - sequence: AtomicU64::new(0), - stream_start_time: RwLock::new(None), stop_flag: Arc::new(AtomicBool::new(false)), } } - /// Get current state pub fn state(&self) -> AudioStreamState { *self.state_rx.borrow() } - /// Subscribe to state changes pub fn state_watch(&self) -> watch::Receiver { self.state_rx.clone() } - /// Subscribe to Opus frames (each packet is one encoded 20 ms frame). pub fn subscribe_opus(&self) -> mpsc::Receiver> { let (tx, rx) = mpsc::channel::>(128); self.opus_subscribers.lock().unwrap().push(tx); rx } - /// Get number of active subscribers pub fn subscriber_count(&self) -> usize { self.opus_subscribers .lock() @@ -135,14 +103,12 @@ impl AudioStreamer { .count() } - /// Get current statistics - pub async fn stats(&self) -> AudioStreamStats { - let mut stats = self.stats.lock().await.clone(); - stats.subscriber_count = self.subscriber_count(); - stats + pub fn stats(&self) -> AudioStreamStats { + AudioStreamStats { + subscriber_count: self.subscriber_count(), + } } - /// Update configuration (only when stopped) pub async fn set_config(&self, config: AudioStreamerConfig) -> Result<()> { if self.state() != AudioStreamState::Stopped { return Err(AppError::AudioError( @@ -153,12 +119,9 @@ impl AudioStreamer { Ok(()) } - /// Update bitrate dynamically (can be done while streaming) pub async fn set_bitrate(&self, bitrate: u32) -> Result<()> { - // Update config self.config.write().await.opus.bitrate = bitrate; - // Update encoder if running if let Some(ref mut encoder) = *self.encoder.lock().await { encoder.set_bitrate(bitrate)?; } @@ -167,7 +130,6 @@ impl AudioStreamer { Ok(()) } - /// Start the audio stream pub async fn start(&self) -> Result<()> { if self.state() == AudioStreamState::Running { return Ok(()); @@ -186,28 +148,14 @@ impl AudioStreamer { config.opus.bitrate ); - // Create capturer let capturer = Arc::new(AudioCapturer::new(config.capture.clone())); *self.capturer.write().await = Some(capturer.clone()); - // Create encoder let encoder = OpusEncoder::new(config.opus.clone())?; *self.encoder.lock().await = Some(encoder); - // Start capture capturer.start().await?; - // Reset stats - { - let mut stats = self.stats.lock().await; - *stats = AudioStreamStats::default(); - } - - // Record start time - *self.stream_start_time.write().await = Some(Instant::now()); - self.sequence.store(0, Ordering::SeqCst); - - // Start encoding task let capturer_for_task = capturer.clone(); let encoder = self.encoder.clone(); let opus_subscribers = self.opus_subscribers.clone(); @@ -215,14 +163,19 @@ impl AudioStreamer { let stop_flag = self.stop_flag.clone(); tokio::spawn(async move { - Self::stream_task(capturer_for_task, encoder, opus_subscribers, state, stop_flag) - .await; + Self::stream_task( + capturer_for_task, + encoder, + opus_subscribers, + state, + stop_flag, + ) + .await; }); Ok(()) } - /// Stop the audio stream pub async fn stop(&self) -> Result<()> { if self.state() == AudioStreamState::Stopped { return Ok(()); @@ -230,18 +183,14 @@ impl AudioStreamer { info!("Stopping audio stream"); - // Signal stop self.stop_flag.store(true, Ordering::SeqCst); - // Stop capturer if let Some(ref capturer) = *self.capturer.read().await { capturer.stop().await?; } - // Clear resources — drop Opus senders so mpsc receivers see end-of-stream *self.capturer.write().await = None; *self.encoder.lock().await = None; - *self.stream_start_time.write().await = None; self.opus_subscribers.lock().unwrap().clear(); let _ = self.state.send(AudioStreamState::Stopped); @@ -249,7 +198,6 @@ impl AudioStreamer { Ok(()) } - /// Check if streaming pub fn is_running(&self) -> bool { self.state() == AudioStreamState::Running } diff --git a/src/auth/middleware.rs b/src/auth/middleware.rs index 80f40f3f..e0cc748e 100644 --- a/src/auth/middleware.rs +++ b/src/auth/middleware.rs @@ -8,20 +8,16 @@ use axum::{ use axum_extra::extract::CookieJar; use std::sync::Arc; -use crate::error::ErrorResponse; use crate::state::AppState; +use crate::web::ErrorResponse; -/// Session cookie name pub const SESSION_COOKIE: &str = "one_kvm_session"; -/// Extract session ID from request pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap) -> Option { - // First try cookie if let Some(cookie) = cookies.get(SESSION_COOKIE) { return Some(cookie.value().to_string()); } - // Then try Authorization header (Bearer token) if let Some(auth_header) = headers.get(axum::http::header::AUTHORIZATION) { if let Ok(auth_str) = auth_header.to_str() { if let Some(token) = auth_str.strip_prefix("Bearer ") { @@ -33,7 +29,6 @@ pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap) None } -/// Authentication middleware pub async fn auth_middleware( State(state): State>, cookies: CookieJar, @@ -41,29 +36,23 @@ pub async fn auth_middleware( next: Next, ) -> Result { let raw_path = request.uri().path(); - // When this middleware is mounted under /api, Axum strips the prefix for the inner router. - // Normalize the path so checks work whether it is mounted or not. + // Mounted under /api: inner path may lack prefix; normalize for whitelist checks. let path = raw_path.strip_prefix("/api").unwrap_or(raw_path); - // Check if system is initialized if !state.config.is_initialized() { - // Allow only setup-related endpoints when not initialized if is_setup_public_endpoint(path) { return Ok(next.run(request).await); } } - // Public endpoints that don't require auth if is_public_endpoint(path) { return Ok(next.run(request).await); } - // Extract session ID let session_id = extract_session_id(&cookies, request.headers()); if let Some(session_id) = session_id { if let Ok(Some(session)) = state.sessions.get(&session_id).await { - // Add session to request extensions request.extensions_mut().insert(session); return Ok(next.run(request).await); } @@ -87,9 +76,7 @@ fn unauthorized_response(message: &str) -> Response { (StatusCode::UNAUTHORIZED, Json(body)).into_response() } -/// Check if endpoint is public (no auth required) fn is_public_endpoint(path: &str) -> bool { - // Note: paths here are relative to /api since middleware is applied within the nested router matches!( path, "/" | "/auth/login" | "/health" | "/setup" | "/setup/init" @@ -102,7 +89,6 @@ fn is_public_endpoint(path: &str) -> bool { || path.ends_with(".svg") } -/// Setup-only endpoints allowed before initialization. fn is_setup_public_endpoint(path: &str) -> bool { matches!( path, diff --git a/src/auth/mod.rs b/src/auth/mod.rs index fc38a389..8d9ba479 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,6 +1,5 @@ pub mod middleware; mod password; -mod rfc3339; mod session; mod user; diff --git a/src/auth/password.rs b/src/auth/password.rs index 2c1605ae..b85f8e1d 100644 --- a/src/auth/password.rs +++ b/src/auth/password.rs @@ -5,7 +5,6 @@ use argon2::{ use crate::error::{AppError, Result}; -/// Hash a password using Argon2 pub fn hash_password(password: &str) -> Result { let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); @@ -16,7 +15,6 @@ pub fn hash_password(password: &str) -> Result { .map_err(|e| AppError::Internal(format!("Password hashing failed: {}", e))) } -/// Verify a password against a hash pub fn verify_password(password: &str, hash: &str) -> Result { let parsed_hash = PasswordHash::new(hash) .map_err(|e| AppError::Internal(format!("Invalid password hash: {}", e)))?; diff --git a/src/auth/rfc3339.rs b/src/auth/rfc3339.rs deleted file mode 100644 index de964988..00000000 --- a/src/auth/rfc3339.rs +++ /dev/null @@ -1,13 +0,0 @@ -//! RFC3339 strings in SQLite; structs use `time::serde::rfc3339`. - -use time::format_description::well_known::Rfc3339; -use time::OffsetDateTime; - -/// Parse DB text; bad input → `now_utc()`. -pub fn parse(s: &str) -> OffsetDateTime { - OffsetDateTime::parse(s, &Rfc3339).unwrap_or_else(|_| OffsetDateTime::now_utc()) -} - -pub fn format(dt: OffsetDateTime) -> String { - dt.format(&Rfc3339).expect("RFC3339 format") -} diff --git a/src/auth/session.rs b/src/auth/session.rs index 3288a751..834b6429 100644 --- a/src/auth/session.rs +++ b/src/auth/session.rs @@ -1,12 +1,12 @@ use serde::{Deserialize, Serialize}; -use sqlx::{Pool, Sqlite}; +use std::collections::HashMap; +use std::sync::Arc; use time::{Duration, OffsetDateTime}; +use tokio::sync::RwLock; use uuid::Uuid; -use super::rfc3339; use crate::error::Result; -/// Session data #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Session { pub id: String, @@ -19,29 +19,25 @@ pub struct Session { } impl Session { - /// Check if session is expired pub fn is_expired(&self) -> bool { OffsetDateTime::now_utc() > self.expires_at } } -/// Session store backed by SQLite #[derive(Clone)] pub struct SessionStore { - pool: Pool, + inner: Arc>>, default_ttl: Duration, } impl SessionStore { - /// Create a new session store - pub fn new(pool: Pool, ttl_secs: i64) -> Self { + pub fn new(ttl_secs: i64) -> Self { Self { - pool, + inner: Arc::new(RwLock::new(HashMap::new())), default_ttl: Duration::seconds(ttl_secs), } } - /// Create a new session pub async fn create(&self, user_id: &str) -> Result { let now = OffsetDateTime::now_utc(); let session = Session { @@ -52,105 +48,57 @@ impl SessionStore { data: None, }; - sqlx::query( - r#" - INSERT INTO sessions (id, user_id, created_at, expires_at, data) - VALUES (?1, ?2, ?3, ?4, ?5) - "#, - ) - .bind(&session.id) - .bind(&session.user_id) - .bind(rfc3339::format(session.created_at)) - .bind(rfc3339::format(session.expires_at)) - .bind(session.data.as_ref().map(|d| d.to_string())) - .execute(&self.pool) - .await?; - + let mut guard = self.inner.write().await; + guard.insert(session.id.clone(), session.clone()); Ok(session) } - /// Get a session by ID pub async fn get(&self, session_id: &str) -> Result> { - let row: Option<(String, String, String, String, Option)> = sqlx::query_as( - "SELECT id, user_id, created_at, expires_at, data FROM sessions WHERE id = ?1", - ) - .bind(session_id) - .fetch_optional(&self.pool) - .await?; - - match row { - Some((id, user_id, created_at, expires_at, data)) => { - let session = Session { - id, - user_id, - created_at: rfc3339::parse(&created_at), - expires_at: rfc3339::parse(&expires_at), - data: data.and_then(|d| serde_json::from_str(&d).ok()), - }; - - if session.is_expired() { - self.delete(&session.id).await?; - Ok(None) - } else { - Ok(Some(session)) - } - } - None => Ok(None), + let mut guard = self.inner.write().await; + let Some(session) = guard.get(session_id).cloned() else { + return Ok(None); + }; + if session.is_expired() { + guard.remove(session_id); + return Ok(None); } + Ok(Some(session)) } - /// Delete a session pub async fn delete(&self, session_id: &str) -> Result<()> { - sqlx::query("DELETE FROM sessions WHERE id = ?1") - .bind(session_id) - .execute(&self.pool) - .await?; + let mut guard = self.inner.write().await; + guard.remove(session_id); Ok(()) } - /// Delete all expired sessions pub async fn cleanup_expired(&self) -> Result { - let now = rfc3339::format(OffsetDateTime::now_utc()); - let result = sqlx::query("DELETE FROM sessions WHERE expires_at < ?1") - .bind(now) - .execute(&self.pool) - .await?; - Ok(result.rows_affected()) + let mut guard = self.inner.write().await; + let before = guard.len(); + guard.retain(|_, s| !s.is_expired()); + Ok((before - guard.len()) as u64) } - /// Delete all sessions pub async fn delete_all(&self) -> Result { - let result = sqlx::query("DELETE FROM sessions") - .execute(&self.pool) - .await?; - Ok(result.rows_affected()) + let mut guard = self.inner.write().await; + let n = guard.len() as u64; + guard.clear(); + Ok(n) } - /// Delete all sessions for a specific user - pub async fn delete_by_user_id(&self, user_id: &str) -> Result { - let result = sqlx::query("DELETE FROM sessions WHERE user_id = ?1") - .bind(user_id) - .execute(&self.pool) - .await?; - Ok(result.rows_affected()) - } - - /// List all session IDs pub async fn list_ids(&self) -> Result> { - let rows: Vec<(String,)> = sqlx::query_as("SELECT id FROM sessions") - .fetch_all(&self.pool) - .await?; - Ok(rows.into_iter().map(|(id,)| id).collect()) + let guard = self.inner.read().await; + Ok(guard.keys().cloned().collect()) } - /// Extend session expiration pub async fn extend(&self, session_id: &str) -> Result<()> { - let new_expires = OffsetDateTime::now_utc() + self.default_ttl; - sqlx::query("UPDATE sessions SET expires_at = ?1 WHERE id = ?2") - .bind(rfc3339::format(new_expires)) - .bind(session_id) - .execute(&self.pool) - .await?; + let mut guard = self.inner.write().await; + if let Some(session) = guard.get_mut(session_id) { + if session.is_expired() { + guard.remove(session_id); + } else { + session.expires_at = OffsetDateTime::now_utc() + self.default_ttl; + } + } Ok(()) } } diff --git a/src/auth/user.rs b/src/auth/user.rs index b749b6ee..e131fc6b 100644 --- a/src/auth/user.rs +++ b/src/auth/user.rs @@ -1,122 +1,99 @@ use serde::{Deserialize, Serialize}; use sqlx::{Pool, Sqlite}; +use time::format_description::well_known::Rfc3339; use time::OffsetDateTime; use uuid::Uuid; use super::password::{hash_password, verify_password}; -use super::rfc3339; use crate::error::{AppError, Result}; -/// User row type from database -type UserRow = (String, String, String, String, String); +type UserRow = (String, String, String); -/// User data #[derive(Debug, Clone, Serialize, Deserialize)] pub struct User { pub id: String, pub username: String, #[serde(skip_serializing)] pub password_hash: String, - #[serde(with = "time::serde::rfc3339")] - pub created_at: OffsetDateTime, - #[serde(with = "time::serde::rfc3339")] - pub updated_at: OffsetDateTime, } impl User { - /// Convert from database row to User fn from_row(row: UserRow) -> Self { - let (id, username, password_hash, created_at, updated_at) = row; + let (id, username, password_hash) = row; Self { id, username, password_hash, - created_at: rfc3339::parse(&created_at), - updated_at: rfc3339::parse(&updated_at), } } } -/// User store backed by SQLite #[derive(Clone)] pub struct UserStore { pool: Pool, } impl UserStore { - /// Create a new user store pub fn new(pool: Pool) -> Self { Self { pool } } - /// Create a new user - pub async fn create(&self, username: &str, password: &str) -> Result { - // Check if username already exists - if self.get_by_username(username).await?.is_some() { - return Err(AppError::BadRequest(format!( - "Username '{}' already exists", - username - ))); + /// The single local user, or `None` if none exists. Errors if more than one row is present. + pub async fn single_user(&self) -> Result> { + let mut rows: Vec = sqlx::query_as( + "SELECT id, username, password_hash FROM users ORDER BY rowid ASC LIMIT 2", + ) + .fetch_all(&self.pool) + .await?; + + match rows.len() { + 0 => Ok(None), + 1 => Ok(Some(User::from_row(rows.remove(0)))), + _ => Err(AppError::Internal( + "Multiple user accounts in database; this build supports only one".to_string(), + )), + } + } + + pub async fn create_first_user(&self, username: &str, password: &str) -> Result { + if self.single_user().await?.is_some() { + return Err(AppError::BadRequest( + "A user account already exists".to_string(), + )); } let password_hash = hash_password(password)?; - let now = OffsetDateTime::now_utc(); let user = User { id: Uuid::new_v4().to_string(), username: username.to_string(), password_hash, - created_at: now, - updated_at: now, }; sqlx::query( r#" - INSERT INTO users (id, username, password_hash, created_at, updated_at) - VALUES (?1, ?2, ?3, ?4, ?5) + INSERT INTO users (id, username, password_hash) + VALUES (?1, ?2, ?3) "#, ) .bind(&user.id) .bind(&user.username) .bind(&user.password_hash) - .bind(rfc3339::format(user.created_at)) - .bind(rfc3339::format(user.updated_at)) .execute(&self.pool) .await?; Ok(user) } - /// Get user by ID - pub async fn get(&self, user_id: &str) -> Result> { - let row: Option = sqlx::query_as( - "SELECT id, username, password_hash, created_at, updated_at FROM users WHERE id = ?1", - ) - .bind(user_id) - .fetch_optional(&self.pool) - .await?; - - Ok(row.map(User::from_row)) - } - - /// Get user by username - pub async fn get_by_username(&self, username: &str) -> Result> { - let row: Option = sqlx::query_as( - "SELECT id, username, password_hash, created_at, updated_at FROM users WHERE username = ?1", - ) - .bind(username) - .fetch_optional(&self.pool) - .await?; - - Ok(row.map(User::from_row)) - } - - /// Verify user credentials pub async fn verify(&self, username: &str, password: &str) -> Result> { - let user = match self.get_by_username(username).await? { - Some(user) => user, + let user = match self.single_user().await? { + Some(u) => u, None => return Ok(None), }; + if user.username != username { + return Ok(None); + } + if verify_password(password, &user.password_hash)? { Ok(Some(user)) } else { @@ -124,15 +101,23 @@ impl UserStore { } } - /// Update user password pub async fn update_password(&self, user_id: &str, new_password: &str) -> Result<()> { + let user = self + .single_user() + .await? + .ok_or_else(|| AppError::NotFound("User not found".to_string()))?; + + if user.id != user_id { + return Err(AppError::AuthError("Invalid session".to_string())); + } + let password_hash = hash_password(new_password)?; let now = OffsetDateTime::now_utc(); let result = sqlx::query("UPDATE users SET password_hash = ?1, updated_at = ?2 WHERE id = ?3") .bind(&password_hash) - .bind(rfc3339::format(now)) + .bind(now.format(&Rfc3339).expect("RFC3339 format")) .bind(user_id) .execute(&self.pool) .await?; @@ -144,21 +129,24 @@ impl UserStore { Ok(()) } - /// Update username pub async fn update_username(&self, user_id: &str, new_username: &str) -> Result<()> { - if let Some(existing) = self.get_by_username(new_username).await? { - if existing.id != user_id { - return Err(AppError::BadRequest(format!( - "Username '{}' already exists", - new_username - ))); - } + let user = self + .single_user() + .await? + .ok_or_else(|| AppError::NotFound("User not found".to_string()))?; + + if user.id != user_id { + return Err(AppError::AuthError("Invalid session".to_string())); + } + + if new_username == user.username { + return Ok(()); } let now = OffsetDateTime::now_utc(); let result = sqlx::query("UPDATE users SET username = ?1, updated_at = ?2 WHERE id = ?3") .bind(new_username) - .bind(rfc3339::format(now)) + .bind(now.format(&Rfc3339).expect("RFC3339 format")) .bind(user_id) .execute(&self.pool) .await?; @@ -169,37 +157,4 @@ impl UserStore { Ok(()) } - - /// List all users - pub async fn list(&self) -> Result> { - let rows: Vec = sqlx::query_as( - "SELECT id, username, password_hash, created_at, updated_at FROM users ORDER BY created_at", - ) - .fetch_all(&self.pool) - .await?; - - Ok(rows.into_iter().map(User::from_row).collect()) - } - - /// Delete user by ID - pub async fn delete(&self, user_id: &str) -> Result<()> { - let result = sqlx::query("DELETE FROM users WHERE id = ?1") - .bind(user_id) - .execute(&self.pool) - .await?; - - if result.rows_affected() == 0 { - return Err(AppError::NotFound("User not found".to_string())); - } - - Ok(()) - } - - /// Check if any users exist - pub async fn has_users(&self) -> Result { - let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users") - .fetch_one(&self.pool) - .await?; - Ok(count.0 > 0) - } } diff --git a/src/config/mod.rs b/src/config/mod.rs index e33a529b..f6e65c6e 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,5 +1,7 @@ +mod persistence; mod schema; mod store; +pub use persistence::ConfigChange; pub use schema::*; pub use store::ConfigStore; diff --git a/src/config/persistence.rs b/src/config/persistence.rs new file mode 100644 index 00000000..105f8a76 --- /dev/null +++ b/src/config/persistence.rs @@ -0,0 +1,5 @@ +/// Configuration change event +#[derive(Debug, Clone)] +pub struct ConfigChange { + pub key: String, +} diff --git a/src/config/schema.rs b/src/config/schema.rs index 59f6dd29..49575fdc 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -1,12 +1,85 @@ -use crate::video::encoder::BitratePreset; use serde::{Deserialize, Serialize}; use typeshare::typeshare; -// Re-export ExtensionsConfig from extensions module +// Re-export domain config types that are embedded in AppConfig. +// These are simple data types defined in their respective modules; +// keeping the re-export here is acceptable since they flow inward. pub use crate::extensions::ExtensionsConfig; -// Re-export RustDeskConfig from rustdesk module pub use crate::rustdesk::config::RustDeskConfig; +/// Bitrate preset for video encoding +/// +/// Simplifies bitrate configuration by providing three intuitive presets +/// plus a custom option for advanced users. +#[typeshare] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", content = "value")] +#[derive(Default)] +pub enum BitratePreset { + /// Speed priority: 1 Mbps, lowest latency, smaller GOP + Speed, + /// Balanced: 4 Mbps, good quality/latency tradeoff + #[default] + Balanced, + /// Quality priority: 8 Mbps, best visual quality + Quality, + /// Custom bitrate in kbps (for advanced users) + Custom(u32), +} + +impl BitratePreset { + /// Get bitrate value in kbps + pub fn bitrate_kbps(&self) -> u32 { + match self { + Self::Speed => 1000, + Self::Balanced => 4000, + Self::Quality => 8000, + Self::Custom(kbps) => *kbps, + } + } + + /// Get recommended GOP size based on preset + pub fn gop_size(&self, fps: u32) -> u32 { + match self { + Self::Speed => (fps / 2).max(15), + Self::Balanced => fps, + Self::Quality => fps * 2, + Self::Custom(_) => fps, + } + } + + /// Get quality preset name for encoder configuration + pub fn quality_level(&self) -> &'static str { + match self { + Self::Speed => "low", + Self::Balanced => "medium", + Self::Quality => "high", + Self::Custom(_) => "medium", + } + } + + /// Create from kbps value, mapping to nearest preset or Custom + pub fn from_kbps(kbps: u32) -> Self { + match kbps { + 0..=1500 => Self::Speed, + 1501..=6000 => Self::Balanced, + 6001..=10000 => Self::Quality, + _ => Self::Custom(kbps), + } + } +} + +impl std::fmt::Display for BitratePreset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Speed => write!(f, "Speed (1 Mbps)"), + Self::Balanced => write!(f, "Balanced (4 Mbps)"), + Self::Quality => write!(f, "Quality (8 Mbps)"), + Self::Custom(kbps) => write!(f, "Custom ({} kbps)", kbps), + } + } +} + /// Main application configuration #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] @@ -179,27 +252,13 @@ pub enum OtgEndpointBudget { } impl OtgEndpointBudget { - pub fn default_for_udc_name(udc: Option<&str>) -> Self { - if udc.is_some_and(crate::otg::configfs::is_low_endpoint_udc) { - Self::Five - } else { - Self::Six - } - } - - pub fn resolved(self, udc: Option<&str>) -> Self { + /// Resolve endpoint limit assuming a known budget variant (not Auto). + pub fn endpoint_limit_raw(&self) -> Option { match self { - Self::Auto => Self::default_for_udc_name(udc), - other => other, - } - } - - pub fn endpoint_limit(self, udc: Option<&str>) -> Option { - match self.resolved(udc) { Self::Five => Some(5), Self::Six => Some(6), Self::Unlimited => None, - Self::Auto => unreachable!("auto budget must be resolved before use"), + Self::Auto => None, // resolved via `HidConfig::resolved_otg_endpoint_limit` } } } @@ -356,32 +415,23 @@ impl Default for HidConfig { } impl HidConfig { + /// Resolve effective OTG HID functions from profile + custom selection. + /// Pure logic, no external dependency. pub fn effective_otg_functions(&self) -> OtgHidFunctions { self.otg_profile.resolve_functions(&self.otg_functions) } - pub fn resolved_otg_udc(&self) -> Option { - crate::otg::configfs::resolve_udc_name(self.otg_udc.as_deref()) - } - - pub fn resolved_otg_endpoint_budget(&self) -> OtgEndpointBudget { - self.otg_endpoint_budget - .resolved(self.resolved_otg_udc().as_deref()) - } - - pub fn resolved_otg_endpoint_limit(&self) -> Option { - self.otg_endpoint_budget - .endpoint_limit(self.resolved_otg_udc().as_deref()) - } - + /// Whether keyboard LED feedback is effectively enabled. pub fn effective_otg_keyboard_leds(&self) -> bool { self.otg_keyboard_leds && self.effective_otg_functions().keyboard } + /// Effective HID functions after applying all constraints. pub fn constrained_otg_functions(&self) -> OtgHidFunctions { self.effective_otg_functions() } + /// Calculate required endpoint count for the current function selection. pub fn effective_otg_required_endpoints(&self, msd_enabled: bool) -> u8 { let functions = self.effective_otg_functions(); let mut endpoints = functions.endpoint_cost(self.effective_otg_keyboard_leds()); @@ -391,6 +441,7 @@ impl HidConfig { endpoints } + /// Validate endpoint budget for the current OTG configuration (UDC-aware when budget is Auto). pub fn validate_otg_endpoint_budget(&self, msd_enabled: bool) -> crate::error::Result<()> { if self.backend != HidBackend::Otg { return Ok(()); @@ -403,8 +454,9 @@ impl HidConfig { )); } + let resolved_limit = self.resolved_otg_endpoint_limit(); let required = self.effective_otg_required_endpoints(msd_enabled); - if let Some(limit) = self.resolved_otg_endpoint_limit() { + if let Some(limit) = resolved_limit { if required > limit { return Err(crate::error::AppError::BadRequest(format!( "OTG selection requires {} endpoints, but the configured limit is {}", @@ -415,6 +467,40 @@ impl HidConfig { Ok(()) } + + /// Effective OTG UDC name (for change detection / service). + #[inline] + pub fn resolved_otg_udc(&self) -> Option { + if self.backend != HidBackend::Otg { + return None; + } + self.otg_udc + .as_ref() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .or_else(|| crate::otg::OtgGadgetManager::find_udc()) + } + + /// Resolved endpoint limit used for OTG gadget allocator / validation. + #[inline] + pub fn resolved_otg_endpoint_limit(&self) -> Option { + if self.backend != HidBackend::Otg { + return None; + } + match self.otg_endpoint_budget { + OtgEndpointBudget::Five => Some(5), + OtgEndpointBudget::Six => Some(6), + OtgEndpointBudget::Unlimited => None, + OtgEndpointBudget::Auto => { + let udc = self.resolved_otg_udc().unwrap_or_default(); + if crate::otg::configfs::is_low_endpoint_udc(&udc) { + Some(5) + } else { + Some(6) + } + } + } + } } /// MSD configuration @@ -511,7 +597,7 @@ impl Default for AudioConfig { fn default() -> Self { Self { enabled: false, - device: "default".to_string(), + device: String::new(), quality: "balanced".to_string(), } } @@ -606,21 +692,6 @@ pub enum EncoderType { } impl EncoderType { - /// Convert to EncoderBackend for registry queries - pub fn to_backend(&self) -> Option { - use crate::video::encoder::registry::EncoderBackend; - match self { - EncoderType::Auto => None, - EncoderType::Software => Some(EncoderBackend::Software), - EncoderType::Vaapi => Some(EncoderBackend::Vaapi), - EncoderType::Nvenc => Some(EncoderBackend::Nvenc), - EncoderType::Qsv => Some(EncoderBackend::Qsv), - EncoderType::Amf => Some(EncoderBackend::Amf), - EncoderType::Rkmpp => Some(EncoderBackend::Rkmpp), - EncoderType::V4l2m2m => Some(EncoderBackend::V4l2m2m), - } - } - /// Get display name for UI pub fn display_name(&self) -> &'static str { match self { @@ -687,19 +758,17 @@ impl Default for StreamConfig { } impl StreamConfig { - /// Check if using public ICE servers (user left fields empty) + /// Whether built-in / public ICE is used (no custom STUN or TURN URL configured). pub fn is_using_public_ice_servers(&self) -> bool { - use crate::webrtc::config::public_ice; - self.stun_server + let no_custom_stun = self + .stun_server .as_ref() - .map(|s| s.is_empty()) - .unwrap_or(true) - && self - .turn_server - .as_ref() - .map(|s| s.is_empty()) - .unwrap_or(true) - && public_ice::is_configured() + .map_or(true, |s| s.trim().is_empty()); + let no_custom_turn = self + .turn_server + .as_ref() + .map_or(true, |s| s.trim().is_empty()); + no_custom_stun && no_custom_turn } } diff --git a/src/config/store.rs b/src/config/store.rs index 5547cbae..00edfca3 100644 --- a/src/config/store.rs +++ b/src/config/store.rs @@ -1,11 +1,10 @@ use arc_swap::ArcSwap; -use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite}; -use std::path::Path; +use sqlx::{Pool, Sqlite}; use std::sync::Arc; -use std::time::Duration; use tokio::sync::broadcast; use tokio::sync::Mutex; +use super::persistence::ConfigChange; use super::AppConfig; use crate::error::{AppError, Result}; @@ -23,127 +22,23 @@ pub struct ConfigStore { write_lock: Arc>, } -/// Configuration change event -#[derive(Debug, Clone)] -pub struct ConfigChange { - pub key: String, -} - impl ConfigStore { /// Create a new configuration store - pub async fn new(db_path: &Path) -> Result { - // Ensure parent directory exists - if let Some(parent) = db_path.parent() { - tokio::fs::create_dir_all(parent).await?; - } - - let db_url = format!("sqlite:{}?mode=rwc", db_path.display()); - - let pool = SqlitePoolOptions::new() - // SQLite uses single-writer mode, 2 connections is sufficient for embedded devices - // One for reads, one for writes to avoid blocking - .max_connections(2) - // Set reasonable timeouts for embedded environments - .acquire_timeout(Duration::from_secs(5)) - .idle_timeout(Duration::from_secs(300)) - .connect(&db_url) - .await?; - - // Initialize database schema - Self::init_schema(&pool).await?; - - // Load or create default config - let config = Self::load_config(&pool).await?; - let cache = Arc::new(ArcSwap::from_pointee(config)); - - let (change_tx, _) = broadcast::channel(16); - + pub fn new(pool: Pool) -> Result { + // Load or create default config synchronously wrapper + // (actual DB load is async, handled in init()) Ok(Self { pool, - cache, - change_tx, + cache: Arc::new(ArcSwap::from_pointee(AppConfig::default())), + change_tx: broadcast::channel(16).0, write_lock: Arc::new(Mutex::new(())), }) } - /// Initialize database schema - async fn init_schema(pool: &Pool) -> Result<()> { - sqlx::query( - r#" - CREATE TABLE IF NOT EXISTS config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL, - updated_at TEXT NOT NULL DEFAULT (datetime('now')) - ) - "#, - ) - .execute(pool) - .await?; - - sqlx::query( - r#" - CREATE TABLE IF NOT EXISTS users ( - id TEXT PRIMARY KEY, - username TEXT NOT NULL UNIQUE, - password_hash TEXT NOT NULL, - created_at TEXT NOT NULL DEFAULT (datetime('now')), - updated_at TEXT NOT NULL DEFAULT (datetime('now')) - ) - "#, - ) - .execute(pool) - .await?; - - sqlx::query( - r#" - CREATE TABLE IF NOT EXISTS sessions ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL, - created_at TEXT NOT NULL DEFAULT (datetime('now')), - expires_at TEXT NOT NULL, - data TEXT - ) - "#, - ) - .execute(pool) - .await?; - - sqlx::query( - r#" - CREATE TABLE IF NOT EXISTS api_tokens ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - token_hash TEXT NOT NULL, - permissions TEXT NOT NULL, - expires_at TEXT, - created_at TEXT NOT NULL DEFAULT (datetime('now')), - last_used TEXT - ) - "#, - ) - .execute(pool) - .await?; - - sqlx::query( - r#" - CREATE TABLE IF NOT EXISTS wol_history ( - mac_address TEXT PRIMARY KEY, - updated_at INTEGER NOT NULL - ) - "#, - ) - .execute(pool) - .await?; - - sqlx::query( - r#" - CREATE INDEX IF NOT EXISTS idx_wol_history_updated_at - ON wol_history(updated_at DESC) - "#, - ) - .execute(pool) - .await?; - + /// Load configuration from database (call after new()) + pub async fn load(&self) -> Result<()> { + let config = Self::load_config(&self.pool).await?; + self.cache.store(Arc::new(config)); Ok(()) } @@ -244,16 +139,12 @@ impl ConfigStore { pub fn is_initialized(&self) -> bool { self.cache.load().initialized } - - /// Get database pool for session management - pub fn pool(&self) -> &Pool { - &self.pool - } } #[cfg(test)] mod tests { use super::*; + use crate::db::DatabasePool; use tempfile::tempdir; #[tokio::test] @@ -261,7 +152,11 @@ mod tests { let dir = tempdir().unwrap(); let db_path = dir.path().join("test.db"); - let store = ConfigStore::new(&db_path).await.unwrap(); + let db = DatabasePool::new(&db_path).await.unwrap(); + db.init_schema().await.unwrap(); + + let store = ConfigStore::new(db.clone_pool()).unwrap(); + store.load().await.unwrap(); // Check default config (now lock-free, no await needed) let config = store.get(); @@ -282,7 +177,8 @@ mod tests { assert_eq!(config.web.http_port, 9000); // Create new store instance and verify persistence - let store2 = ConfigStore::new(&db_path).await.unwrap(); + let store2 = ConfigStore::new(db.clone_pool()).unwrap(); + store2.load().await.unwrap(); let config = store2.get(); assert!(config.initialized); assert_eq!(config.web.http_port, 9000); diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 00000000..b2935c58 --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,3 @@ +mod pool; + +pub use pool::DatabasePool; diff --git a/src/db/pool.rs b/src/db/pool.rs new file mode 100644 index 00000000..ec2f042b --- /dev/null +++ b/src/db/pool.rs @@ -0,0 +1,119 @@ +use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite}; +use std::path::Path; +use std::time::Duration; + +use crate::error::Result; + +#[derive(Clone)] +pub struct DatabasePool { + pool: Pool, +} + +impl DatabasePool { + pub async fn new(db_path: &Path) -> Result { + if let Some(parent) = db_path.parent() { + tokio::fs::create_dir_all(parent).await?; + } + + let db_url = format!("sqlite:{}?mode=rwc", db_path.display()); + + let pool = SqlitePoolOptions::new() + .max_connections(4) + .acquire_timeout(Duration::from_secs(5)) + .idle_timeout(Duration::from_secs(300)) + .connect(&db_url) + .await?; + + Ok(Self { pool }) + } + + pub async fn init_schema(&self) -> Result<()> { + self.create_config_table().await?; + self.create_users_table().await?; + self.create_api_tokens_table().await?; + self.create_wol_history_table().await?; + Ok(()) + } + + async fn create_config_table(&self) -> Result<()> { + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + updated_at TEXT NOT NULL DEFAULT (datetime('now')) + ) + "#, + ) + .execute(&self.pool) + .await?; + Ok(()) + } + + async fn create_users_table(&self) -> Result<()> { + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) + ) + "#, + ) + .execute(&self.pool) + .await?; + Ok(()) + } + + async fn create_api_tokens_table(&self) -> Result<()> { + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS api_tokens ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + token_hash TEXT NOT NULL, + permissions TEXT NOT NULL, + expires_at TEXT, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + last_used TEXT + ) + "#, + ) + .execute(&self.pool) + .await?; + Ok(()) + } + + async fn create_wol_history_table(&self) -> Result<()> { + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS wol_history ( + mac_address TEXT PRIMARY KEY, + updated_at INTEGER NOT NULL + ) + "#, + ) + .execute(&self.pool) + .await?; + + sqlx::query( + r#" + CREATE INDEX IF NOT EXISTS idx_wol_history_updated_at + ON wol_history(updated_at DESC) + "#, + ) + .execute(&self.pool) + .await?; + Ok(()) + } + + pub fn pool(&self) -> &Pool { + &self.pool + } + + pub fn clone_pool(&self) -> Pool { + self.pool.clone() + } +} diff --git a/src/error.rs b/src/error.rs index aaffdac4..9dba75a1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,12 +1,5 @@ -use axum::{ - http::StatusCode, - response::{IntoResponse, Response}, - Json, -}; -use serde::Serialize; use thiserror::Error; -/// Application-wide error type #[derive(Error, Debug)] pub enum AppError { #[error("Authentication failed: {0}")] @@ -15,17 +8,14 @@ pub enum AppError { #[error("Not authenticated")] Unauthorized, - #[error("Forbidden: {0}")] - Forbidden(String), - #[error("Not found: {0}")] NotFound(String), #[error("Bad request: {0}")] BadRequest(String), - #[error("Database error: {0}")] - Database(#[from] sqlx::Error), + #[error("Persistence error: {0}")] + Persistence(String), #[error("Internal error: {0}")] Internal(String), @@ -42,9 +32,6 @@ pub enum AppError { #[error("Video error: {0}")] VideoError(String), - #[error("Video device lost [{device}]: {reason}")] - VideoDeviceLost { device: String, reason: String }, - /// No input signal while opening capture; `kind` is `SignalStatus` as string (`from_str`). #[error("Capture has no valid signal: {kind}")] CaptureNoSignal { kind: String }, @@ -66,37 +53,10 @@ pub enum AppError { ServiceUnavailable(String), } -/// Error response body (unified success format) -#[derive(Serialize)] -pub struct ErrorResponse { - pub success: bool, - pub message: String, -} - -impl AppError { - fn status_code(&self) -> StatusCode { - // Always return 200 OK - success/failure is indicated by the success field - StatusCode::OK - } -} - -impl IntoResponse for AppError { - fn into_response(self) -> Response { - let status = self.status_code(); - let body = ErrorResponse { - success: false, - message: self.to_string(), - }; - - tracing::error!( - error_type = std::any::type_name_of_val(&self), - error_message = %body.message, - "Request failed" - ); - - (status, Json(body)).into_response() - } -} - -/// Result type alias for handlers pub type Result = std::result::Result; + +impl From for AppError { + fn from(err: sqlx::Error) -> Self { + AppError::Persistence(err.to_string()) + } +} diff --git a/src/events/mod.rs b/src/events/mod.rs index 7dfa4539..d1c63b22 100644 --- a/src/events/mod.rs +++ b/src/events/mod.rs @@ -1,41 +1,28 @@ -//! Event system for real-time state notifications -//! -//! This module provides a global event bus for broadcasting system events -//! to WebSocket clients and other subscribers. +//! Event bus: [`SystemEvent`] fan-out to WebSocket subscribers and internal tasks. pub mod types; +use self::types::EXACT_EVENT_TOPICS; + pub use types::{ - AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, MsdDeviceInfo, SystemEvent, - TtydDeviceInfo, VideoDeviceInfo, + AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, LedState, MsdDeviceInfo, + SystemEvent, TtydDeviceInfo, VideoDeviceInfo, }; use tokio::sync::broadcast; -/// Event channel capacity (ring buffer size) const EVENT_CHANNEL_CAPACITY: usize = 256; -const EXACT_TOPICS: &[&str] = &[ - "stream.mode_switching", - "stream.state_changed", - "stream.config_changing", - "stream.config_applied", - "stream.device_lost", - "stream.reconnecting", - "stream.recovered", - "stream.webrtc_ready", - "stream.stats_update", - "stream.mode_changed", - "stream.mode_ready", - "webrtc.ice_candidate", - "webrtc.ice_complete", - "msd.upload_progress", - "msd.download_progress", - "system.device_info", - "error", -]; - -const PREFIX_TOPICS: &[&str] = &["stream.*", "webrtc.*", "msd.*", "system.*"]; +fn collect_prefix_wildcards(exact: &[&'static str]) -> Vec { + use std::collections::BTreeSet; + let mut segments = BTreeSet::new(); + for name in exact { + if let Some((seg, _)) = name.split_once('.') { + segments.insert(seg); + } + } + segments.into_iter().map(|s| format!("{}.*", s)).collect() +} fn make_sender() -> broadcast::Sender { let (tx, _rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY); @@ -48,52 +35,23 @@ fn topic_prefix(event_name: &str) -> Option { .map(|(prefix, _)| format!("{}.*", prefix)) } -/// Global event bus for broadcasting system events -/// -/// The event bus uses tokio's broadcast channel to distribute events -/// to multiple subscribers. Events are delivered to all active subscribers. -/// -/// # Example -/// -/// ```no_run -/// use one_kvm::events::{EventBus, SystemEvent}; -/// -/// let bus = EventBus::new(); -/// -/// // Publish an event -/// bus.publish(SystemEvent::StreamStateChanged { -/// state: "streaming".to_string(), -/// device: Some("/dev/video0".to_string()), -/// reason: None, -/// next_retry_ms: None, -/// }); -/// -/// // Subscribe to events -/// let mut rx = bus.subscribe(); -/// tokio::spawn(async move { -/// while let Ok(event) = rx.recv().await { -/// println!("Received event: {:?}", event); -/// } -/// }); -/// ``` pub struct EventBus { tx: broadcast::Sender, exact_topics: std::collections::HashMap<&'static str, broadcast::Sender>, - prefix_topics: std::collections::HashMap<&'static str, broadcast::Sender>, + prefix_topics: std::collections::HashMap>, device_info_dirty_tx: broadcast::Sender<()>, } impl EventBus { - /// Create a new event bus pub fn new() -> Self { let tx = make_sender(); - let exact_topics = EXACT_TOPICS + let exact_topics = EXACT_EVENT_TOPICS .iter() .map(|topic| (*topic, make_sender())) .collect(); - let prefix_topics = PREFIX_TOPICS - .iter() - .map(|topic| (*topic, make_sender())) + let prefix_topics = collect_prefix_wildcards(EXACT_EVENT_TOPICS) + .into_iter() + .map(|topic| (topic, make_sender())) .collect(); let (device_info_dirty_tx, _dirty_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY); @@ -105,10 +63,6 @@ impl EventBus { } } - /// Publish an event to all subscribers - /// - /// If there are no active subscribers, the event is silently dropped. - /// This is by design - events are fire-and-forget notifications. pub fn publish(&self, event: SystemEvent) { let event_name = event.event_name(); @@ -117,28 +71,18 @@ impl EventBus { } if let Some(prefix) = topic_prefix(event_name) { - if let Some(tx) = self.prefix_topics.get(prefix.as_str()) { + if let Some(tx) = self.prefix_topics.get(&prefix) { let _ = tx.send(event.clone()); } } - // If no subscribers, send returns Err which is normal let _ = self.tx.send(event); } - /// Subscribe to events - /// - /// Returns a receiver that will receive all future events. - /// The receiver uses a ring buffer, so if a subscriber falls too far - /// behind, it will receive a `Lagged` error and miss some events. pub fn subscribe(&self) -> broadcast::Receiver { self.tx.subscribe() } - /// Subscribe to a specific topic. - /// - /// Supports exact event names, namespace wildcards like `stream.*`, and - /// `*` for the full event stream. pub fn subscribe_topic(&self, topic: &str) -> Option> { if topic == "*" { return Some(self.tx.subscribe()); @@ -151,22 +95,14 @@ impl EventBus { self.exact_topics.get(topic).map(|tx| tx.subscribe()) } - /// Mark the device-info snapshot as stale. - /// - /// This is an internal trigger used to refresh the latest `system.device_info` - /// snapshot without exposing another public WebSocket event. pub fn mark_device_info_dirty(&self) { let _ = self.device_info_dirty_tx.send(()); } - /// Subscribe to internal device-info refresh triggers. pub fn subscribe_device_info_dirty(&self) -> broadcast::Receiver<()> { self.device_info_dirty_tx.subscribe() } - /// Get the current number of active subscribers - /// - /// Useful for monitoring and debugging. pub fn subscriber_count(&self) -> usize { self.tx.receiver_count() } @@ -263,7 +199,6 @@ mod tests { let bus = EventBus::new(); assert_eq!(bus.subscriber_count(), 0); - // Should not panic when publishing with no subscribers bus.publish(SystemEvent::StreamStateChanged { state: "ready".to_string(), device: None, diff --git a/src/events/types.rs b/src/events/types.rs index d1bf9e9b..b90f7ded 100644 --- a/src/events/types.rs +++ b/src/events/types.rs @@ -1,165 +1,96 @@ -//! System event types -//! -//! Defines all event types that can be broadcast through the event bus. +//! [`SystemEvent`] and device snapshot types (WebSocket / JSON). use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use crate::hid::LedState; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct LedState { + pub num_lock: bool, + pub caps_lock: bool, + pub scroll_lock: bool, + pub compose: bool, + pub kana: bool, +} -// ============================================================================ -// Device Info Structures (for system.device_info event) -// ============================================================================ - -/// Video device information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct VideoDeviceInfo { - /// Whether video device is available pub available: bool, - /// Device path (e.g., /dev/video0) pub device: Option, - /// Pixel format (e.g., "MJPEG", "YUYV") pub format: Option, - /// Resolution (width, height) pub resolution: Option<(u32, u32)>, - /// Frames per second pub fps: u32, - /// Whether stream is currently active pub online: bool, - /// Current streaming mode: "mjpeg", "h264", "h265", "vp8", or "vp9" pub stream_mode: String, - /// Whether video config is currently being changed (frontend should skip mode sync) pub config_changing: bool, - /// Error message if any, None if OK pub error: Option, } -/// HID device information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct HidDeviceInfo { - /// Whether HID backend is available pub available: bool, - /// Backend type: "otg", "ch9329", "none" pub backend: String, - /// Whether backend is initialized and ready pub initialized: bool, - /// Whether backend is currently online pub online: bool, - /// Whether absolute mouse positioning is supported pub supports_absolute_mouse: bool, - /// Whether keyboard LED/status feedback is enabled. pub keyboard_leds_enabled: bool, - /// Last known keyboard LED state. pub led_state: LedState, - /// Device path (e.g., serial port for CH9329) pub device: Option, - /// Error message if any, None if OK pub error: Option, - /// Error code if any, None if OK pub error_code: Option, } -/// MSD device information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MsdDeviceInfo { - /// Whether MSD is available pub available: bool, - /// Operating mode: "none", "image", "drive" pub mode: String, - /// Whether storage is connected to target pub connected: bool, - /// Currently mounted image ID pub image_id: Option, - /// Error message if any, None if OK pub error: Option, } -/// ATX device information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AtxDeviceInfo { - /// Whether ATX controller is available pub available: bool, - /// Backend type: "gpio", "usb_relay", "none" pub backend: String, - /// Whether backend is initialized pub initialized: bool, - /// Whether power is currently on pub power_on: bool, - /// Error message if any, None if OK pub error: Option, } -/// Audio device information -/// -/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo). #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AudioDeviceInfo { - /// Whether audio is enabled/available pub available: bool, - /// Whether audio is currently streaming pub streaming: bool, - /// Current audio device name pub device: Option, - /// Quality preset: "voice", "balanced", "high" pub quality: String, - /// Error message if any, None if OK pub error: Option, } -/// ttyd status information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TtydDeviceInfo { - /// Whether ttyd binary is available pub available: bool, - /// Whether ttyd is currently running pub running: bool, } -/// Per-client statistics #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ClientStats { - /// Client ID pub id: String, - /// Current FPS for this client (frames sent in last second) pub fps: u32, - /// Connected duration (seconds) pub connected_secs: u64, } -/// System event enumeration -/// -/// All events are tagged with their event name for serialization. -/// The `serde(tag = "event", content = "data")` attribute creates a -/// JSON structure like: -/// ```json -/// { -/// "event": "stream.state_changed", -/// "data": { "state": "streaming", "device": "/dev/video0" } -/// } -/// ``` +/// JSON: `{"event": "", "data": { ... }}`. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(tag = "event", content = "data")] #[allow(clippy::large_enum_variant)] pub enum SystemEvent { - // ============================================================================ - // Video Stream Events - // ============================================================================ - /// Stream mode switching started (transactional, correlates all following events) - /// - /// Sent immediately after a mode switch request is accepted. - /// Clients can use `transition_id` to correlate subsequent `stream.*` events. #[serde(rename = "stream.mode_switching")] StreamModeSwitching { - /// Unique transition ID for this mode switch transaction transition_id: String, - /// Target mode: "mjpeg", "h264", "h265", "vp8", "vp9" to_mode: String, - /// Previous mode: "mjpeg", "h264", "h265", "vp8", "vp9" from_mode: String, }, - /// Stream state for the UI (`streaming`, `no_signal`, `device_lost`, `device_busy`, etc.). - /// Optional `reason` / `next_retry_ms` are hints only; branch on `state`. #[serde(rename = "stream.state_changed")] StreamStateChanged { state: String, @@ -170,193 +101,122 @@ pub enum SystemEvent { next_retry_ms: Option, }, - /// Stream configuration is being changed - /// - /// Sent before applying new configuration to notify clients that - /// the stream will be interrupted temporarily. #[serde(rename = "stream.config_changing")] StreamConfigChanging { - /// Optional transition ID if this config change is part of a mode switch transaction #[serde(skip_serializing_if = "Option::is_none")] transition_id: Option, - /// Reason for change: "device_switch", "resolution_change", "format_change" reason: String, }, - /// Stream configuration has been applied successfully - /// - /// Sent after new configuration is active. Clients can reconnect now. #[serde(rename = "stream.config_applied")] StreamConfigApplied { - /// Optional transition ID if this config change is part of a mode switch transaction #[serde(skip_serializing_if = "Option::is_none")] transition_id: Option, - /// Device path device: String, - /// Resolution (width, height) resolution: (u32, u32), - /// Pixel format: "mjpeg", "yuyv", etc. format: String, - /// Frames per second fps: u32, }, - /// Stream device was lost (disconnected or error) #[serde(rename = "stream.device_lost")] - StreamDeviceLost { - /// Device path that was lost - device: String, - /// Reason for loss - reason: String, - }, + StreamDeviceLost { device: String, reason: String }, - /// Stream device is reconnecting #[serde(rename = "stream.reconnecting")] - StreamReconnecting { - /// Device path being reconnected - device: String, - /// Retry attempt number - attempt: u32, - }, + StreamReconnecting { device: String, attempt: u32 }, - /// Stream device has recovered #[serde(rename = "stream.recovered")] - StreamRecovered { - /// Device path that was recovered - device: String, - }, + StreamRecovered { device: String }, - /// WebRTC is ready to accept connections - /// - /// Sent after video frame source is connected to WebRTC pipeline. - /// Clients should wait for this event before attempting to create WebRTC sessions. #[serde(rename = "stream.webrtc_ready")] WebRTCReady { - /// Optional transition ID if this readiness is part of a mode switch transaction #[serde(skip_serializing_if = "Option::is_none")] transition_id: Option, - /// Current video codec codec: String, - /// Whether hardware encoding is being used hardware: bool, }, - /// WebRTC ICE candidate (server -> client trickle) #[serde(rename = "webrtc.ice_candidate")] WebRTCIceCandidate { - /// WebRTC session ID session_id: String, - /// ICE candidate data - candidate: crate::webrtc::signaling::IceCandidate, + candidate: serde_json::Value, }, - /// WebRTC ICE gathering complete (server -> client) #[serde(rename = "webrtc.ice_complete")] - WebRTCIceComplete { - /// WebRTC session ID - session_id: String, - }, + WebRTCIceComplete { session_id: String }, - /// Stream statistics update (sent periodically for client stats) #[serde(rename = "stream.stats_update")] StreamStatsUpdate { - /// Number of connected clients clients: u64, - /// Per-client statistics (client_id -> client stats) - /// Each client's FPS reflects the actual frames sent in the last second clients_stat: HashMap, }, - /// Stream mode changed (MJPEG <-> WebRTC) - /// - /// Sent when the streaming mode is switched. Clients should disconnect - /// from the current stream and reconnect using the new mode. #[serde(rename = "stream.mode_changed")] StreamModeChanged { - /// Optional transition ID if this change is part of a mode switch transaction #[serde(skip_serializing_if = "Option::is_none")] transition_id: Option, - /// New mode: "mjpeg", "h264", "h265", "vp8", or "vp9" mode: String, - /// Previous mode: "mjpeg", "h264", "h265", "vp8", or "vp9" previous_mode: String, }, - /// Stream mode switching completed (transactional end marker) - /// - /// Sent when the backend considers the new mode ready for clients to connect. #[serde(rename = "stream.mode_ready")] - StreamModeReady { - /// Unique transition ID for this mode switch transaction - transition_id: String, - /// Active mode after switch: "mjpeg", "h264", "h265", "vp8", "vp9" - mode: String, - }, + StreamModeReady { transition_id: String, mode: String }, - // ============================================================================ - // MSD (Mass Storage Device) Events - // ============================================================================ - /// File upload progress (for large file uploads) #[serde(rename = "msd.upload_progress")] MsdUploadProgress { - /// Upload operation ID upload_id: String, - /// Filename being uploaded filename: String, - /// Bytes uploaded so far bytes_uploaded: u64, - /// Total file size total_bytes: u64, - /// Progress percentage (0.0 - 100.0) progress_pct: f32, }, - /// Image download progress (for URL downloads) #[serde(rename = "msd.download_progress")] MsdDownloadProgress { - /// Download operation ID download_id: String, - /// Source URL url: String, - /// Target filename filename: String, - /// Bytes downloaded so far bytes_downloaded: u64, - /// Total file size (None if unknown) total_bytes: Option, - /// Progress percentage (0.0 - 100.0, None if total unknown) progress_pct: Option, - /// Download status: "started", "in_progress", "completed", "failed" status: String, }, - /// Complete device information (sent on WebSocket connect and state changes) #[serde(rename = "system.device_info")] DeviceInfo { - /// Video device information video: VideoDeviceInfo, - /// HID device information hid: HidDeviceInfo, - /// MSD device information (None if MSD not enabled) msd: Option, - /// ATX device information (None if ATX not enabled) atx: Option, - /// Audio device information (None if audio not enabled) audio: Option, - /// ttyd status information ttyd: TtydDeviceInfo, }, - /// WebSocket error notification (for connection-level errors like lag) #[serde(rename = "error")] - Error { - /// Error message - message: String, - }, + Error { message: String }, } +/// One entry per [`SystemEvent::event_name`]. `EventBus` builds `*.`-wildcard channels from the first segment; names without `.` (e.g. `error`) have no wildcard channel. +pub(crate) const EXACT_EVENT_TOPICS: &[&str] = &[ + "stream.mode_switching", + "stream.state_changed", + "stream.config_changing", + "stream.config_applied", + "stream.device_lost", + "stream.reconnecting", + "stream.recovered", + "stream.webrtc_ready", + "stream.stats_update", + "stream.mode_changed", + "stream.mode_ready", + "webrtc.ice_candidate", + "webrtc.ice_complete", + "msd.upload_progress", + "msd.download_progress", + "system.device_info", + "error", +]; + impl SystemEvent { - /// Get the event name (for filtering/routing) pub fn event_name(&self) -> &'static str { match self { Self::StreamModeSwitching { .. } => "stream.mode_switching", @@ -378,27 +238,6 @@ impl SystemEvent { Self::Error { .. } => "error", } } - - /// Check if event name matches a topic pattern - /// - /// Supports wildcards: - /// - `*` matches all events - /// - `stream.*` matches all stream events - /// - `stream.state_changed` matches exact event - pub fn matches_topic(&self, topic: &str) -> bool { - if topic == "*" { - return true; - } - - let event_name = self.event_name(); - - if topic.ends_with(".*") { - let prefix = topic.trim_end_matches(".*"); - event_name.starts_with(prefix) - } else { - event_name == topic - } - } } #[cfg(test)] @@ -417,19 +256,124 @@ mod tests { } #[test] - fn test_matches_topic() { - let event = SystemEvent::StreamStateChanged { - state: "streaming".to_string(), - device: None, - reason: None, - next_retry_ms: None, - }; + fn exact_topics_covers_all_variants() { + use std::collections::HashSet; - assert!(event.matches_topic("*")); - assert!(event.matches_topic("stream.*")); - assert!(event.matches_topic("stream.state_changed")); - assert!(!event.matches_topic("msd.*")); - assert!(!event.matches_topic("stream.config_changed")); + let samples = vec![ + SystemEvent::StreamModeSwitching { + transition_id: String::new(), + to_mode: String::new(), + from_mode: String::new(), + }, + SystemEvent::StreamStateChanged { + state: String::new(), + device: None, + reason: None, + next_retry_ms: None, + }, + SystemEvent::StreamConfigChanging { + transition_id: None, + reason: String::new(), + }, + SystemEvent::StreamConfigApplied { + transition_id: None, + device: String::new(), + resolution: (0, 0), + format: String::new(), + fps: 0, + }, + SystemEvent::StreamDeviceLost { + device: String::new(), + reason: String::new(), + }, + SystemEvent::StreamReconnecting { + device: String::new(), + attempt: 0, + }, + SystemEvent::StreamRecovered { + device: String::new(), + }, + SystemEvent::WebRTCReady { + transition_id: None, + codec: String::new(), + hardware: false, + }, + SystemEvent::StreamStatsUpdate { + clients: 0, + clients_stat: HashMap::new(), + }, + SystemEvent::StreamModeChanged { + transition_id: None, + mode: String::new(), + previous_mode: String::new(), + }, + SystemEvent::StreamModeReady { + transition_id: String::new(), + mode: String::new(), + }, + SystemEvent::WebRTCIceCandidate { + session_id: String::new(), + candidate: serde_json::Value::Null, + }, + SystemEvent::WebRTCIceComplete { + session_id: String::new(), + }, + SystemEvent::MsdUploadProgress { + upload_id: String::new(), + filename: String::new(), + bytes_uploaded: 0, + total_bytes: 0, + progress_pct: 0.0, + }, + SystemEvent::MsdDownloadProgress { + download_id: String::new(), + url: String::new(), + filename: String::new(), + bytes_downloaded: 0, + total_bytes: None, + progress_pct: None, + status: String::new(), + }, + SystemEvent::DeviceInfo { + video: VideoDeviceInfo { + available: false, + device: None, + format: None, + resolution: None, + fps: 0, + online: false, + stream_mode: String::new(), + config_changing: false, + error: None, + }, + hid: HidDeviceInfo { + available: false, + backend: String::new(), + initialized: false, + online: false, + supports_absolute_mouse: false, + keyboard_leds_enabled: false, + led_state: LedState::default(), + device: None, + error: None, + error_code: None, + }, + msd: None, + atx: None, + audio: None, + ttyd: TtydDeviceInfo { + available: false, + running: false, + }, + }, + SystemEvent::Error { + message: String::new(), + }, + ]; + + let from_enum: HashSet<_> = samples.iter().map(|e| e.event_name()).collect(); + let from_const: HashSet<_> = super::EXACT_EVENT_TOPICS.iter().copied().collect(); + assert_eq!(from_enum, from_const); } #[test] diff --git a/src/extensions/manager.rs b/src/extensions/manager.rs index fd03d46a..fd61b7fa 100644 --- a/src/extensions/manager.rs +++ b/src/extensions/manager.rs @@ -1,5 +1,3 @@ -//! Extension process manager - use std::collections::{HashMap, VecDeque}; use std::path::Path; use std::process::Stdio; @@ -12,25 +10,18 @@ use tokio::sync::RwLock; use super::types::*; use crate::events::EventBus; -/// Maximum number of log lines to keep per extension const LOG_BUFFER_SIZE: usize = 200; - -/// Number of log lines to buffer before flushing to shared storage const LOG_BATCH_SIZE: usize = 16; -/// Unix socket path for ttyd pub const TTYD_SOCKET_PATH: &str = "/var/run/one-kvm/ttyd.sock"; -/// Extension process with log buffer struct ExtensionProcess { child: Child, logs: Arc>>, } -/// Extension manager handles lifecycle of external processes pub struct ExtensionManager { processes: RwLock>, - /// Cached availability status (checked once at startup) availability: HashMap, event_bus: RwLock>>, } @@ -42,9 +33,7 @@ impl Default for ExtensionManager { } impl ExtensionManager { - /// Create a new extension manager with cached availability pub fn new() -> Self { - // Check availability once at startup let availability = ExtensionId::all() .iter() .map(|id| (*id, Path::new(id.binary_path()).exists())) @@ -57,7 +46,6 @@ impl ExtensionManager { } } - /// Set event bus for ttyd status notifications. pub async fn set_event_bus(&self, event_bus: Arc) { *self.event_bus.write().await = Some(event_bus); } @@ -72,12 +60,10 @@ impl ExtensionManager { } } - /// Check if the binary for an extension is available (cached) pub fn check_available(&self, id: ExtensionId) -> bool { *self.availability.get(&id).unwrap_or(&false) } - /// Get the current status of an extension pub async fn status(&self, id: ExtensionId) -> ExtensionStatus { if !self.check_available(id) { return ExtensionStatus::Unavailable; @@ -117,20 +103,13 @@ impl ExtensionManager { ExtensionStatus::Stopped } - /// Start an extension with the given configuration pub async fn start(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<(), String> { if !self.check_available(id) { - return Err(format!( - "{} not found at {}", - id.display_name(), - id.binary_path() - )); + return Err(format!("{} not found at {}", id, id.binary_path())); } - // Stop existing process first self.stop(id).await.ok(); - // Build command arguments let args = self.build_args(id, config).await?; tracing::info!( @@ -146,11 +125,10 @@ impl ExtensionManager { .stderr(Stdio::piped()) .kill_on_drop(true) .spawn() - .map_err(|e| format!("Failed to start {}: {}", id.display_name(), e))?; + .map_err(|e| format!("Failed to start {}: {}", id, e))?; let logs = Arc::new(RwLock::new(VecDeque::with_capacity(LOG_BUFFER_SIZE))); - // Spawn log collector for stdout if let Some(stdout) = child.stdout.take() { let logs_clone = logs.clone(); let id_clone = id; @@ -159,7 +137,6 @@ impl ExtensionManager { }); } - // Spawn log collector for stderr if let Some(stderr) = child.stderr.take() { let logs_clone = logs.clone(); let id_clone = id; @@ -179,7 +156,6 @@ impl ExtensionManager { Ok(()) } - /// Stop an extension pub async fn stop(&self, id: ExtensionId) -> Result<(), String> { let mut processes = self.processes.write().await; if let Some(mut proc) = processes.remove(&id) { @@ -193,7 +169,6 @@ impl ExtensionManager { Ok(()) } - /// Get recent logs for an extension pub async fn logs(&self, id: ExtensionId, lines: usize) -> Vec { let processes = self.processes.read().await; if let Some(proc) = processes.get(&id) { @@ -205,7 +180,6 @@ impl ExtensionManager { } } - /// Collect logs from a stream with batched writes to reduce lock contention async fn collect_logs( id: ExtensionId, reader: R, @@ -221,13 +195,11 @@ impl ExtensionManager { tracing::debug!("[{}] {}", id, line); local_buffer.push(line); - // Flush when batch is full if local_buffer.len() >= LOG_BATCH_SIZE { Self::flush_logs(&logs, &mut local_buffer).await; } } Ok(None) => { - // Stream ended, flush remaining logs if !local_buffer.is_empty() { Self::flush_logs(&logs, &mut local_buffer).await; } @@ -241,7 +213,6 @@ impl ExtensionManager { } } - /// Flush buffered logs to shared storage async fn flush_logs(logs: &RwLock>, buffer: &mut Vec) { let mut logs = logs.write().await; for line in buffer.drain(..) { @@ -252,7 +223,6 @@ impl ExtensionManager { } } - /// Build command arguments for an extension async fn build_args( &self, id: ExtensionId, @@ -262,18 +232,16 @@ impl ExtensionManager { ExtensionId::Ttyd => { let c = &config.ttyd; - // Prepare socket directory and clean up old socket (async) Self::prepare_ttyd_socket().await?; let mut args = vec![ "-i".to_string(), - TTYD_SOCKET_PATH.to_string(), // Unix socket + TTYD_SOCKET_PATH.to_string(), "-b".to_string(), - "/api/terminal".to_string(), // Base path for reverse proxy - "-W".to_string(), // Writable (allow input) + "/api/terminal".to_string(), + "-W".to_string(), ]; - // Add shell as last argument args.push(c.shell.clone()); Ok(args) } @@ -289,15 +257,12 @@ impl ExtensionManager { let mut args = Vec::new(); - // Add TLS flag if c.tls { args.push("--tls=true".to_string()); } - // Server address (validated non-empty above) args.extend(["-addr".to_string(), c.addr.trim().to_string()]); - // Add client key args.extend(["-key".to_string(), c.key.clone()]); Ok(args) @@ -316,24 +281,19 @@ impl ExtensionManager { c.network_secret.clone(), ]; - // Add peer URLs for peer in &c.peer_urls { if !peer.is_empty() { args.extend(["--peers".to_string(), peer.clone()]); } } - // Add virtual IP: use -d for DHCP if empty, or -i for specific IP if let Some(ref ip) = c.virtual_ip { if !ip.is_empty() { - // Use specific IP with -i (must include CIDR, e.g., 10.0.0.1/24) args.extend(["-i".to_string(), ip.clone()]); } else { - // Empty string means use DHCP args.push("-d".to_string()); } } else { - // None means use DHCP args.push("-d".to_string()); } @@ -342,11 +302,9 @@ impl ExtensionManager { } } - /// Prepare ttyd socket directory and clean up old socket file async fn prepare_ttyd_socket() -> Result<(), String> { let socket_path = Path::new(TTYD_SOCKET_PATH); - // Ensure socket directory exists if let Some(socket_dir) = socket_path.parent() { if !socket_dir.exists() { tokio::fs::create_dir_all(socket_dir) @@ -355,7 +313,6 @@ impl ExtensionManager { } } - // Remove old socket file if exists if tokio::fs::try_exists(TTYD_SOCKET_PATH) .await .unwrap_or(false) @@ -368,9 +325,7 @@ impl ExtensionManager { Ok(()) } - /// Health check - restart crashed processes that should be running pub async fn health_check(&self, config: &ExtensionsConfig) { - // Collect extensions that need restart check let checks: Vec<_> = ExtensionId::all() .iter() .filter_map(|id| { @@ -393,7 +348,6 @@ impl ExtensionManager { }) .collect(); - // Check which ones need restart (single read lock) let needs_restart: Vec<_> = { let processes = self.processes.read().await; checks @@ -408,7 +362,6 @@ impl ExtensionManager { .collect() }; - // Restart all crashed extensions in parallel let restart_futures: Vec<_> = needs_restart .into_iter() .map(|id| async move { @@ -422,14 +375,12 @@ impl ExtensionManager { futures::future::join_all(restart_futures).await; } - /// Start all enabled extensions in parallel pub async fn start_enabled(&self, config: &ExtensionsConfig) { use futures::Future; use std::pin::Pin; let mut start_futures: Vec + Send + '_>>> = Vec::new(); - // Collect enabled extensions if config.ttyd.enabled && self.check_available(ExtensionId::Ttyd) { start_futures.push(Box::pin(async { if let Err(e) = self.start(ExtensionId::Ttyd, config).await { @@ -461,11 +412,9 @@ impl ExtensionManager { })); } - // Start all in parallel futures::future::join_all(start_futures).await; } - /// Stop all running extensions in parallel pub async fn stop_all(&self) { let stop_futures: Vec<_> = ExtensionId::all().iter().map(|id| self.stop(*id)).collect(); futures::future::join_all(stop_futures).await; diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index bc430dd3..0d242ff7 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,5 +1,3 @@ -//! Extensions module - manage external processes like ttyd, gostc, easytier - mod manager; mod types; diff --git a/src/extensions/types.rs b/src/extensions/types.rs index 2ecda744..ff6d2c99 100644 --- a/src/extensions/types.rs +++ b/src/extensions/types.rs @@ -1,23 +1,16 @@ -//! Extension types and configurations - use serde::{Deserialize, Serialize}; use typeshare::typeshare; -/// Extension identifier (fixed set of supported extensions) #[typeshare] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum ExtensionId { - /// Web terminal (ttyd) Ttyd, - /// NAT traversal client (gostc) Gostc, - /// P2P VPN (easytier) Easytier, } impl ExtensionId { - /// Get the binary path for this extension pub fn binary_path(&self) -> &'static str { match self { Self::Ttyd => "/usr/bin/ttyd", @@ -26,16 +19,6 @@ impl ExtensionId { } } - /// Get the display name for this extension - pub fn display_name(&self) -> &'static str { - match self { - Self::Ttyd => "Web Terminal", - Self::Gostc => "GOSTC Tunnel", - Self::Easytier => "EasyTier VPN", - } - } - - /// Get all extension IDs pub fn all() -> &'static [ExtensionId] { &[Self::Ttyd, Self::Gostc, Self::Easytier] } @@ -64,25 +47,14 @@ impl std::str::FromStr for ExtensionId { } } -/// Extension running status #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(tag = "state", content = "data", rename_all = "lowercase")] pub enum ExtensionStatus { - /// Binary not found at expected path Unavailable, - /// Extension is stopped Stopped, - /// Extension is running - Running { - /// Process ID - pid: u32, - }, - /// Extension failed to start - Failed { - /// Error message - error: String, - }, + Running { pid: u32 }, + Failed { error: String }, } impl ExtensionStatus { @@ -91,16 +63,11 @@ impl ExtensionStatus { } } -/// ttyd configuration (Web Terminal) #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(default)] pub struct TtydConfig { - /// Enable auto-start pub enabled: bool, - /// Port to listen on - pub port: u16, - /// Shell to execute pub shell: String, } @@ -108,25 +75,19 @@ impl Default for TtydConfig { fn default() -> Self { Self { enabled: false, - port: 7681, shell: "/bin/bash".to_string(), } } } -/// gostc configuration (NAT traversal based on FRP) #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(default)] pub struct GostcConfig { - /// Enable auto-start pub enabled: bool, - /// Server address (hostname or IP) pub addr: String, - /// Client key from GOSTC management panel #[serde(skip_serializing_if = "String::is_empty")] pub key: String, - /// Enable TLS pub tls: bool, } @@ -141,28 +102,21 @@ impl Default for GostcConfig { } } -/// EasyTier configuration (P2P VPN) #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(default)] #[derive(Default)] pub struct EasytierConfig { - /// Enable auto-start pub enabled: bool, - /// Network name pub network_name: String, - /// Network secret/password #[serde(skip_serializing_if = "String::is_empty")] pub network_secret: String, - /// Peer node URLs #[serde(skip_serializing_if = "Vec::is_empty")] pub peer_urls: Vec, - /// Virtual IP address (optional, auto-assigned if not set) #[serde(skip_serializing_if = "Option::is_none")] pub virtual_ip: Option, } -/// Combined extensions configuration #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize, Default)] #[serde(default)] @@ -172,53 +126,37 @@ pub struct ExtensionsConfig { pub easytier: EasytierConfig, } -/// Extension info with status and config #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExtensionInfo { - /// Whether binary exists pub available: bool, - /// Current status pub status: ExtensionStatus, } -/// ttyd extension info #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TtydInfo { - /// Whether binary exists pub available: bool, - /// Current status pub status: ExtensionStatus, - /// Configuration pub config: TtydConfig, } -/// gostc extension info #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GostcInfo { - /// Whether binary exists pub available: bool, - /// Current status pub status: ExtensionStatus, - /// Configuration pub config: GostcConfig, } -/// easytier extension info #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct EasytierInfo { - /// Whether binary exists pub available: bool, - /// Current status pub status: ExtensionStatus, - /// Configuration pub config: EasytierConfig, } -/// All extensions status response #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExtensionsStatus { @@ -227,7 +165,6 @@ pub struct ExtensionsStatus { pub easytier: EasytierInfo, } -/// Extension logs response #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExtensionLogs { diff --git a/src/hid/backend.rs b/src/hid/backend.rs index fe8d6fdf..c5a6e406 100644 --- a/src/hid/backend.rs +++ b/src/hid/backend.rs @@ -1,73 +1,32 @@ -//! HID backend trait definition +//! `HidBackend` trait plus serde `HidBackendType` (OTG | CH9329 | disabled). use async_trait::async_trait; use serde::{Deserialize, Serialize}; use tokio::sync::watch; -use super::otg::LedState; use super::types::{ConsumerEvent, KeyboardEvent, MouseEvent}; use crate::error::Result; +use crate::events::LedState; -/// Default CH9329 baud rate fn default_ch9329_baud_rate() -> u32 { 9600 } -/// HID backend type #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "lowercase")] #[derive(Default)] pub enum HidBackendType { - /// USB OTG gadget mode Otg, - /// CH9329 serial HID controller Ch9329 { - /// Serial port path port: String, - /// Baud rate (default: 9600) #[serde(default = "default_ch9329_baud_rate")] baud_rate: u32, }, - /// No HID backend (disabled) #[default] None, } impl HidBackendType { - /// Check if OTG backend is available on this system - pub fn otg_available() -> bool { - // Check for USB gadget support - std::path::Path::new("/sys/class/udc").exists() - } - - /// Detect the best available backend - pub fn detect() -> Self { - // Check for OTG gadget support - if Self::otg_available() { - return Self::Otg; - } - - // Check for common CH9329 serial ports - let common_ports = [ - "/dev/ttyUSB0", - "/dev/ttyUSB1", - "/dev/ttyAMA0", - "/dev/serial0", - ]; - - for port in &common_ports { - if std::path::Path::new(port).exists() { - return Self::Ch9329 { - port: port.to_string(), - baud_rate: 9600, // Use default baud rate for auto-detection - }; - } - } - - Self::None - } - - /// Get backend name as string pub fn name_str(&self) -> &str { match self { Self::Otg => "otg", @@ -77,76 +36,40 @@ impl HidBackendType { } } -/// Current runtime status reported by a HID backend. #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct HidBackendRuntimeSnapshot { - /// Whether the backend has been initialized and can accept requests. pub initialized: bool, - /// Whether the backend is currently online and communicating successfully. pub online: bool, - /// Whether absolute mouse positioning is supported. pub supports_absolute_mouse: bool, - /// Whether keyboard LED/status feedback is currently enabled. pub keyboard_leds_enabled: bool, - /// Last known keyboard LED state. pub led_state: LedState, - /// Screen resolution for absolute mouse mode. pub screen_resolution: Option<(u32, u32)>, - /// Device identifier associated with the backend, if any. pub device: Option, - /// Current user-facing error, if any. pub error: Option, - /// Current programmatic error code, if any. pub error_code: Option, } -/// HID backend trait #[async_trait] pub trait HidBackend: Send + Sync { - /// Initialize the backend async fn init(&self) -> Result<()>; - /// Send a keyboard event async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()>; - /// Send a mouse event async fn send_mouse(&self, event: MouseEvent) -> Result<()>; - /// Send a consumer control event (multimedia keys) - /// Default implementation returns an error (not supported) async fn send_consumer(&self, _event: ConsumerEvent) -> Result<()> { Err(crate::error::AppError::BadRequest( "Consumer control not supported by this backend".to_string(), )) } - /// Reset all inputs (release all keys/buttons) async fn reset(&self) -> Result<()>; - /// Shutdown the backend async fn shutdown(&self) -> Result<()>; - /// Get the current backend runtime snapshot. fn runtime_snapshot(&self) -> HidBackendRuntimeSnapshot; - /// Subscribe to backend runtime changes. fn subscribe_runtime(&self) -> watch::Receiver<()>; - /// Set screen resolution (for absolute mouse) - fn set_screen_resolution(&mut self, _width: u32, _height: u32) {} -} - -/// HID backend information -#[derive(Debug, Clone, Serialize)] -pub struct HidBackendInfo { - /// Backend name - pub name: String, - /// Backend type - pub backend_type: String, - /// Is initialized - pub initialized: bool, - /// Supports absolute mouse - pub absolute_mouse: bool, - /// Screen resolution (if absolute mouse) - pub resolution: Option<(u32, u32)>, + fn set_screen_resolution(&self, _width: u32, _height: u32) {} } diff --git a/src/hid/ch9329.rs b/src/hid/ch9329.rs index dc835d8d..4120e901 100644 --- a/src/hid/ch9329.rs +++ b/src/hid/ch9329.rs @@ -1,9 +1,4 @@ -//! CH9329 Serial HID Controller backend -//! -//! CH9329 is a USB HID chip controlled via UART from WCH (沁恒). -//! It supports keyboard, mouse (absolute + relative), and custom HID device emulation. -//! -//! ## Protocol Format +//! CH9329 over UART — WCH *Serial Communication Protocol V1.0*. //! ```text //! ┌──────┬──────┬──────┬────────┬──────────────┬──────────┐ //! │Header│ ADDR │ CMD │ LEN │ DATA │ SUM │ @@ -11,16 +6,11 @@ //! │57 AB │ 00 │ xx │ N │ N bytes │Checksum │ //! └──────┴──────┴──────┴────────┴──────────────┴──────────┘ //! ``` -//! -//! Checksum: Sum of ALL bytes including header (modulo 256) -//! -//! ## Reference -//! Based on WCH CH9329 Serial Communication Protocol V1.0 +//! Sum of all octets modulo 256 (including header). use async_trait::async_trait; use parking_lot::{Mutex, RwLock}; use serde::{Deserialize, Serialize}; -use std::io::{Read, Write}; use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU8, Ordering}; use std::sync::{mpsc, Arc}; use std::thread; @@ -29,82 +19,51 @@ use tokio::sync::watch; use tracing::{info, trace, warn}; use super::backend::{HidBackend, HidBackendRuntimeSnapshot}; -use super::otg::LedState; use super::types::{KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType}; use crate::error::{AppError, Result}; +use crate::events::LedState; -// ============================================================================ -// Constants and Command Codes -// ============================================================================ - -/// CH9329 packet header const PACKET_HEADER: [u8; 2] = [0x57, 0xAB]; -/// Default address (accepts any address) const DEFAULT_ADDR: u8 = 0x00; -/// Default baud rate for CH9329 pub const DEFAULT_BAUD_RATE: u32 = 9600; -/// Response timeout in milliseconds const RESPONSE_TIMEOUT_MS: u64 = 500; -/// Maximum data length in a packet const MAX_DATA_LEN: usize = 64; -/// CH9329 absolute mouse resolution const CH9329_MOUSE_RESOLUTION: u32 = 4096; -/// How often the worker probes the chip when idle. const PROBE_INTERVAL_MS: u64 = 100; -/// How long the worker waits before reopening the serial port after a failure. const RECONNECT_DELAY_MS: u64 = 2000; -/// Initial startup wait for the worker to confirm CH9329 is reachable. const INIT_WAIT_MS: u64 = 3000; -/// CH9329 command codes pub mod cmd { - /// Get chip version, USB status, and LED status pub const GET_INFO: u8 = 0x01; - /// Send standard keyboard data (8 bytes) pub const SEND_KB_GENERAL_DATA: u8 = 0x02; - /// Send multimedia keyboard data pub const SEND_KB_MEDIA_DATA: u8 = 0x03; - /// Send absolute mouse data pub const SEND_MS_ABS_DATA: u8 = 0x04; - /// Send relative mouse data pub const SEND_MS_REL_DATA: u8 = 0x05; - /// Send custom HID data pub const SEND_MY_HID_DATA: u8 = 0x06; - /// Restore factory default configuration pub const SET_DEFAULT_CFG: u8 = 0x0C; - /// Software reset pub const RESET: u8 = 0x0F; } -/// Response command mask (success = cmd | 0x80, error = cmd | 0xC0) const RESPONSE_SUCCESS_MASK: u8 = 0x80; const RESPONSE_ERROR_MASK: u8 = 0xC0; -/// CH9329 error codes #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum Ch9329Error { - /// Command executed successfully Success = 0x00, - /// Serial receive timeout Timeout = 0xE1, - /// Invalid packet header InvalidHeader = 0xE2, - /// Invalid command code InvalidCommand = 0xE3, - /// Checksum mismatch ChecksumError = 0xE4, - /// Parameter error ParameterError = 0xE5, - /// Execution failed OperationFailed = 0xE6, } @@ -137,29 +96,17 @@ impl std::fmt::Display for Ch9329Error { } } -// ============================================================================ -// Chip Information -// ============================================================================ - -/// CH9329 chip information #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ChipInfo { - /// Chip version (e.g., "V1.0", "V1.1") pub version: String, - /// Raw version byte pub version_raw: u8, - /// USB connection status pub usb_connected: bool, - /// Num Lock LED state pub num_lock: bool, - /// Caps Lock LED state pub caps_lock: bool, - /// Scroll Lock LED state pub scroll_lock: bool, } impl ChipInfo { - /// Parse chip info from response data (8 bytes) pub fn from_response(data: &[u8]) -> Option { if data.len() < 8 { return None; @@ -181,7 +128,6 @@ impl ChipInfo { } } -/// Keyboard LED status #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] pub struct LedStatus { pub num_lock: bool, @@ -199,98 +145,21 @@ impl From for LedStatus { } } -// ============================================================================ -// Configuration -// ============================================================================ - -/// CH9329 work mode -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[repr(u8)] -#[derive(Default)] -pub enum WorkMode { - /// Mode 0: Standard USB Keyboard + Mouse (default) - #[default] - KeyboardMouse = 0x00, - /// Mode 1: Standard USB Keyboard only - KeyboardOnly = 0x01, - /// Mode 2: Standard USB Mouse only - MouseOnly = 0x02, - /// Mode 3: Custom HID device - CustomHid = 0x03, -} - -/// CH9329 serial communication mode -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[repr(u8)] -#[derive(Default)] -pub enum SerialMode { - /// Mode 0: Protocol transmission mode (default) - #[default] - Protocol = 0x00, - /// Mode 1: ASCII mode - Ascii = 0x01, - /// Mode 2: Transparent mode - Transparent = 0x02, -} - -/// CH9329 configuration parameters -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Ch9329Config { - /// Work mode - pub work_mode: WorkMode, - /// Serial communication mode - pub serial_mode: SerialMode, - /// Device address (0x00-0xFE, 0xFF = broadcast) - pub address: u8, - /// Baud rate - pub baud_rate: u32, - /// USB VID - pub vid: u16, - /// USB PID - pub pid: u16, -} - -impl Default for Ch9329Config { - fn default() -> Self { - Self { - work_mode: WorkMode::KeyboardMouse, - serial_mode: SerialMode::Protocol, - address: 0x00, - baud_rate: 9600, - vid: 0x1A86, - pid: 0xE129, - } - } -} - -// ============================================================================ -// Response Parsing -// ============================================================================ - -/// Parsed response from CH9329 #[derive(Debug)] pub struct Response { - /// Address byte pub address: u8, - /// Command code (with response bits) pub cmd: u8, - /// Data payload pub data: Vec, - /// Whether this is an error response pub is_error: bool, - /// Error code (if is_error) pub error_code: Option, } impl Response { - /// Parse a response from raw bytes pub fn parse(bytes: &[u8]) -> Option { - // Minimum: Header(2) + Addr(1) + Cmd(1) + Len(1) + Sum(1) = 6 if bytes.len() < 6 { return None; } - // Check header if bytes[0] != PACKET_HEADER[0] || bytes[1] != PACKET_HEADER[1] { return None; } @@ -299,12 +168,10 @@ impl Response { let cmd = bytes[3]; let len = bytes[4] as usize; - // Check if we have enough bytes if bytes.len() < 5 + len + 1 { return None; } - // Verify checksum let expected_checksum = bytes[5 + len]; let calculated_checksum = bytes[..5 + len] .iter() @@ -335,19 +202,13 @@ impl Response { }) } - /// Check if the response indicates success pub fn is_success(&self) -> bool { !self.is_error && (self.data.is_empty() || self.data[0] == Ch9329Error::Success as u8) } } -/// Maximum packet size (header 2 + addr 1 + cmd 1 + len 1 + data 64 + checksum 1 = 70) const MAX_PACKET_SIZE: usize = 70; -// ============================================================================ -// CH9329 Backend Implementation -// ============================================================================ - struct Ch9329RuntimeState { initialized: AtomicBool, online: AtomicBool, @@ -424,47 +285,28 @@ enum WorkerCommand { Shutdown, } -/// CH9329 HID backend pub struct Ch9329Backend { - /// Serial port path port_path: String, - /// Baud rate baud_rate: u32, - /// Worker command sender worker_tx: Mutex>>, - /// Background worker thread worker_handle: Mutex>>, - /// Current keyboard state keyboard_state: Mutex, - /// Current mouse button state mouse_buttons: AtomicU8, - /// Screen width for absolute mouse coordinate conversion - screen_width: u32, - /// Screen height for absolute mouse coordinate conversion - screen_height: u32, - /// Cached chip information + screen_resolution: RwLock<(u32, u32)>, chip_info: Arc>>, - /// LED status cache led_status: Arc>, - /// Device address (default 0x00) address: u8, - /// Last absolute mouse X position (CH9329 coordinate: 0-4095) last_abs_x: AtomicU16, - /// Last absolute mouse Y position (CH9329 coordinate: 0-4095) last_abs_y: AtomicU16, - /// Whether relative mouse mode is active (set by incoming events) relative_mouse_active: AtomicBool, - /// Shared runtime status updated only by the worker. runtime: Arc, } impl Ch9329Backend { - /// Create a new CH9329 backend with default baud rate (9600) pub fn new(port_path: &str) -> Result { Self::with_baud_rate(port_path, DEFAULT_BAUD_RATE) } - /// Create a new CH9329 backend with custom baud rate pub fn with_baud_rate(port_path: &str, baud_rate: u32) -> Result { Ok(Self { port_path: port_path.to_string(), @@ -473,8 +315,7 @@ impl Ch9329Backend { worker_handle: Mutex::new(None), keyboard_state: Mutex::new(KeyboardReport::default()), mouse_buttons: AtomicU8::new(0), - screen_width: 1920, - screen_height: 1080, + screen_resolution: RwLock::new((1920, 1080)), chip_info: Arc::new(RwLock::new(None)), led_status: Arc::new(RwLock::new(LedStatus::default())), address: DEFAULT_ADDR, @@ -489,12 +330,10 @@ impl Ch9329Backend { self.runtime.set_error(reason, error_code); } - /// Check if the serial port device file exists pub fn check_port_exists(&self) -> bool { std::path::Path::new(&self.port_path).exists() } - /// Convert serialport error to HidError fn serial_error_to_hid_error(e: serialport::Error, operation: &str) -> AppError { let error_code = match e.kind() { serialport::ErrorKind::NoDevice => "port_not_found", @@ -518,16 +357,11 @@ impl Ch9329Backend { } } - /// Calculate checksum for CH9329 packet (sum of ALL bytes including header) #[inline] fn calculate_checksum(data: &[u8]) -> u8 { data.iter().fold(0u8, |acc, &x| acc.wrapping_add(x)) } - /// Build a CH9329 packet into a stack-allocated buffer - /// - /// Packet format: `[Header 0x57 0xAB] [Address] [Command] [Length] [Data] [Checksum]` - /// Returns the packet buffer and the actual length #[inline] fn build_packet_buf(address: u8, cmd: u8, data: &[u8]) -> ([u8; MAX_PACKET_SIZE], usize) { debug_assert!( @@ -539,25 +373,18 @@ impl Ch9329Backend { let packet_len = 6 + data.len(); let mut packet = [0u8; MAX_PACKET_SIZE]; - // Header (2 bytes) packet[0] = PACKET_HEADER[0]; packet[1] = PACKET_HEADER[1]; - // Address (1 byte) packet[2] = address; - // Command (1 byte) packet[3] = cmd; - // Length (1 byte) - data length only packet[4] = len; - // Data (N bytes) packet[5..5 + data.len()].copy_from_slice(data); - // Checksum (1 byte) - sum of ALL bytes including header let checksum = Self::calculate_checksum(&packet[..5 + data.len()]); packet[5 + data.len()] = checksum; (packet, packet_len) } - /// Build a CH9329 packet (legacy Vec version for compatibility) fn build_packet(address: u8, cmd: u8, data: &[u8]) -> Vec { let (buf, len) = Self::build_packet_buf(address, cmd, data); buf[..len].to_vec() @@ -984,10 +811,6 @@ impl Ch9329Backend { } } -// ============================================================================ -// HidBackend Trait Implementation -// ============================================================================ - #[async_trait] impl HidBackend for Ch9329Backend { async fn init(&self) -> Result<()> { @@ -1060,7 +883,6 @@ impl HidBackend for Ch9329Backend { async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> { let usb_key = event.key.to_hid_usage(); - // Handle modifier keys separately if event.key.is_modifier() { let mut state = self.keyboard_state.lock(); @@ -1078,7 +900,6 @@ impl HidBackend for Ch9329Backend { } else { let mut state = self.keyboard_state.lock(); - // Update modifiers from event state.modifiers = event.modifiers.to_hid_byte(); match event.event_type { @@ -1104,19 +925,15 @@ impl HidBackend for Ch9329Backend { match event.event_type { MouseEventType::Move => { - // Relative movement - send delta directly without inversion self.relative_mouse_active.store(true, Ordering::Relaxed); let dx = event.x.clamp(-127, 127) as i8; let dy = event.y.clamp(-127, 127) as i8; self.send_mouse_relative(buttons, dx, dy, 0)?; } MouseEventType::MoveAbs => { - // Absolute movement self.relative_mouse_active.store(false, Ordering::Relaxed); - // Frontend sends 0-32767 (HID standard), CH9329 expects 0-4095 let x = ((event.x.clamp(0, 32767) as u32) * CH9329_MOUSE_RESOLUTION / 32768) as u16; let y = ((event.y.clamp(0, 32767) as u32) * CH9329_MOUSE_RESOLUTION / 32768) as u16; - // Store last absolute position for click events self.last_abs_x.store(x, Ordering::Relaxed); self.last_abs_y.store(y, Ordering::Relaxed); self.send_mouse_absolute(buttons, x, y, 0)?; @@ -1153,7 +970,6 @@ impl HidBackend for Ch9329Backend { if self.relative_mouse_active.load(Ordering::Relaxed) { self.send_mouse_relative(buttons, 0, 0, event.scroll)?; } else { - // Use absolute mouse for scroll with last position let x = self.last_abs_x.load(Ordering::Relaxed); let y = self.last_abs_y.load(Ordering::Relaxed); self.send_mouse_absolute(buttons, x, y, event.scroll)?; @@ -1165,7 +981,6 @@ impl HidBackend for Ch9329Backend { } async fn reset(&self) -> Result<()> { - // Reset keyboard { let mut state = self.keyboard_state.lock(); state.clear(); @@ -1174,14 +989,12 @@ impl HidBackend for Ch9329Backend { self.send_keyboard_report(&report)?; } - // Reset mouse self.mouse_buttons.store(0, Ordering::Relaxed); self.last_abs_x.store(0, Ordering::Relaxed); self.last_abs_y.store(0, Ordering::Relaxed); self.relative_mouse_active.store(false, Ordering::Relaxed); self.send_mouse_absolute(0, 0, 0, 0)?; - // Reset media keys let _ = self.release_media_keys(); info!("CH9329 HID state reset"); @@ -1233,7 +1046,7 @@ impl HidBackend for Ch9329Backend { kana: false, } }, - screen_resolution: Some((self.screen_width, self.screen_height)), + screen_resolution: Some(*self.screen_resolution.read()), device: Some(self.port_path.clone()), error: error.as_ref().map(|(reason, _)| reason.clone()), error_code: error.as_ref().map(|(_, code)| code.clone()), @@ -1244,125 +1057,21 @@ impl HidBackend for Ch9329Backend { self.runtime.subscribe() } - fn set_screen_resolution(&mut self, width: u32, height: u32) { - self.screen_width = width; - self.screen_height = height; + fn set_screen_resolution(&self, width: u32, height: u32) { + *self.screen_resolution.write() = (width, height); self.runtime.notify(); } } -// ============================================================================ -// Detection and Helpers -// ============================================================================ - -/// Detect CH9329 on common serial ports -pub fn detect_ch9329() -> Option { - let common_ports = [ - "/dev/ttyUSB0", - "/dev/ttyUSB1", - "/dev/ttyAMA0", - "/dev/serial0", - "/dev/ttyS0", - ]; - - // Try multiple baud rates - let baud_rates = [9600, 115200]; - - for port_path in &common_ports { - if !std::path::Path::new(port_path).exists() { - continue; - } - - for &baud_rate in &baud_rates { - if let Ok(mut port) = serialport::new(*port_path, baud_rate) - .timeout(Duration::from_millis(200)) - .open() - { - // Build GET_INFO packet manually (address = 0x00) - let packet = [0x57, 0xAB, 0x00, cmd::GET_INFO, 0x00, 0x03]; - - if port.write_all(&packet).is_ok() { - std::thread::sleep(Duration::from_millis(50)); - - let mut response = [0u8; 16]; - if let Ok(n) = port.read(&mut response) { - // Check for valid CH9329 response header - if n >= 6 - && response[0] == PACKET_HEADER[0] - && response[1] == PACKET_HEADER[1] - { - info!("CH9329 detected on {} @ {} baud", port_path, baud_rate); - return Some(port_path.to_string()); - } - } - } - } - } - } - - None -} - -/// Detect CH9329 and return both path and working baud rate -pub fn detect_ch9329_with_baud() -> Option<(String, u32)> { - let common_ports = [ - "/dev/ttyUSB0", - "/dev/ttyUSB1", - "/dev/ttyAMA0", - "/dev/serial0", - "/dev/ttyS0", - ]; - - let baud_rates = [9600, 115200, 57600, 38400, 19200]; - - for port_path in &common_ports { - if !std::path::Path::new(port_path).exists() { - continue; - } - - for &baud_rate in &baud_rates { - if let Ok(mut port) = serialport::new(*port_path, baud_rate) - .timeout(Duration::from_millis(200)) - .open() - { - let packet = [0x57, 0xAB, 0x00, cmd::GET_INFO, 0x00, 0x03]; - - if port.write_all(&packet).is_ok() { - std::thread::sleep(Duration::from_millis(50)); - - let mut response = [0u8; 16]; - if let Ok(n) = port.read(&mut response) { - if n >= 6 - && response[0] == PACKET_HEADER[0] - && response[1] == PACKET_HEADER[1] - { - info!("CH9329 detected on {} @ {} baud", port_path, baud_rate); - return Some((port_path.to_string(), baud_rate)); - } - } - } - } - } - } - - None -} - -// ============================================================================ -// Tests -// ============================================================================ - #[cfg(test)] mod tests { use super::*; #[test] fn test_packet_building() { - // Test GET_INFO packet (no data) let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::GET_INFO, &[]); assert_eq!(packet, vec![0x57, 0xAB, 0x00, 0x01, 0x00, 0x03]); - // Test keyboard packet (8 bytes data) let data = [0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; // 'A' key let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_KB_GENERAL_DATA, &data); @@ -1372,7 +1081,6 @@ mod tests { assert_eq!(packet[3], cmd::SEND_KB_GENERAL_DATA); // Command assert_eq!(packet[4], 8); // Length (8 data bytes) assert_eq!(&packet[5..13], &data); // Data - // Checksum = 0x57 + 0xAB + 0x00 + 0x02 + 0x08 + 0x00 + 0x00 + 0x04 + ... = 0x10 let expected_checksum: u8 = packet[..13] .iter() .fold(0u8, |acc: u8, &x| acc.wrapping_add(x)); @@ -1381,7 +1089,6 @@ mod tests { #[test] fn test_relative_mouse_packet() { - // Test relative mouse: move right 50 pixels let data = [0x01, 0x00, 50u8, 0x00, 0x00]; let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_MS_REL_DATA, &data); @@ -1397,12 +1104,10 @@ mod tests { #[test] fn test_checksum_calculation() { - // Known packet: GET_INFO let packet = [0x57u8, 0xAB, 0x00, 0x01, 0x00]; let checksum = Ch9329Backend::calculate_checksum(&packet); assert_eq!(checksum, 0x03); - // Known packet: Keyboard 'A' press let packet = [ 0x57u8, 0xAB, 0x00, 0x02, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, ]; @@ -1412,7 +1117,6 @@ mod tests { #[test] fn test_response_parsing() { - // Valid GET_INFO response let response_bytes = [ 0x57, 0xAB, // Header 0x00, // Address @@ -1422,9 +1126,7 @@ mod tests { 0xE0, // Checksum (calculated) ]; - // Note: checksum in test is just placeholder, parse will validate let _result = Response::parse(&response_bytes); - // This will fail because checksum doesn't match, but structure is tested } #[test] diff --git a/src/hid/consumer.rs b/src/hid/consumer.rs index f99968df..825fd9ee 100644 --- a/src/hid/consumer.rs +++ b/src/hid/consumer.rs @@ -2,21 +2,17 @@ //! //! Reference: USB HID Usage Tables 1.12, Section 15 (Consumer Page 0x0C) -/// Consumer Control Usage codes for multimedia keys pub mod usage { - // Transport Controls pub const PLAY_PAUSE: u16 = 0x00CD; pub const STOP: u16 = 0x00B7; pub const NEXT_TRACK: u16 = 0x00B5; pub const PREV_TRACK: u16 = 0x00B6; - // Volume Controls pub const MUTE: u16 = 0x00E2; pub const VOLUME_UP: u16 = 0x00E9; pub const VOLUME_DOWN: u16 = 0x00EA; } -/// Check if a usage code is valid pub fn is_valid_usage(usage: u16) -> bool { matches!( usage, diff --git a/src/hid/datachannel.rs b/src/hid/datachannel.rs index 4ab3c125..0165db8d 100644 --- a/src/hid/datachannel.rs +++ b/src/hid/datachannel.rs @@ -42,23 +42,19 @@ use super::{ MouseEventType, }; -/// Message types pub const MSG_KEYBOARD: u8 = 0x01; pub const MSG_MOUSE: u8 = 0x02; pub const MSG_CONSUMER: u8 = 0x03; -/// Keyboard event types pub const KB_EVENT_DOWN: u8 = 0x00; pub const KB_EVENT_UP: u8 = 0x01; -/// Mouse event types pub const MS_EVENT_MOVE: u8 = 0x00; pub const MS_EVENT_MOVE_ABS: u8 = 0x01; pub const MS_EVENT_DOWN: u8 = 0x02; pub const MS_EVENT_UP: u8 = 0x03; pub const MS_EVENT_SCROLL: u8 = 0x04; -/// Parsed HID event from DataChannel #[derive(Debug, Clone)] pub enum HidChannelEvent { Keyboard(KeyboardEvent), @@ -66,7 +62,6 @@ pub enum HidChannelEvent { Consumer(ConsumerEvent), } -/// Parse a binary HID message from DataChannel pub fn parse_hid_message(data: &[u8]) -> Option { if data.is_empty() { warn!("Empty HID message"); @@ -86,7 +81,6 @@ pub fn parse_hid_message(data: &[u8]) -> Option { } } -/// Parse keyboard message payload fn parse_keyboard_message(data: &[u8]) -> Option { if data.len() < 3 { warn!("Keyboard message too short: {} bytes", data.len()); @@ -129,7 +123,6 @@ fn parse_keyboard_message(data: &[u8]) -> Option { })) } -/// Parse mouse message payload fn parse_mouse_message(data: &[u8]) -> Option { if data.len() < 6 { warn!("Mouse message too short: {} bytes", data.len()); @@ -148,11 +141,9 @@ fn parse_mouse_message(data: &[u8]) -> Option { } }; - // Parse coordinates as i16 LE (works for both relative and absolute) let x = i16::from_le_bytes([data[1], data[2]]) as i32; let y = i16::from_le_bytes([data[3], data[4]]) as i32; - // Button or scroll delta let (button, scroll) = match event_type { MouseEventType::Down | MouseEventType::Up => { let btn = match data[5] { @@ -178,7 +169,6 @@ fn parse_mouse_message(data: &[u8]) -> Option { })) } -/// Parse consumer control message payload fn parse_consumer_message(data: &[u8]) -> Option { if data.len() < 2 { warn!("Consumer message too short: {} bytes", data.len()); @@ -190,7 +180,6 @@ fn parse_consumer_message(data: &[u8]) -> Option { Some(HidChannelEvent::Consumer(ConsumerEvent { usage })) } -/// Encode a keyboard event to binary format (for sending to client if needed) pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec { let event_type = match event.event_type { KeyEventType::Down => KB_EVENT_DOWN, @@ -207,40 +196,6 @@ pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec { ] } -/// Encode a mouse event to binary format (for sending to client if needed) -pub fn encode_mouse_event(event: &MouseEvent) -> Vec { - let event_type = match event.event_type { - MouseEventType::Move => MS_EVENT_MOVE, - MouseEventType::MoveAbs => MS_EVENT_MOVE_ABS, - MouseEventType::Down => MS_EVENT_DOWN, - MouseEventType::Up => MS_EVENT_UP, - MouseEventType::Scroll => MS_EVENT_SCROLL, - }; - - let x_bytes = (event.x as i16).to_le_bytes(); - let y_bytes = (event.y as i16).to_le_bytes(); - - let extra = match event.event_type { - MouseEventType::Down | MouseEventType::Up => event - .button - .as_ref() - .map(|b| match b { - MouseButton::Left => 0u8, - MouseButton::Middle => 1u8, - MouseButton::Right => 2u8, - MouseButton::Back => 3u8, - MouseButton::Forward => 4u8, - }) - .unwrap_or(0), - MouseEventType::Scroll => event.scroll as u8, - _ => 0, - }; - - vec![ - MSG_MOUSE, event_type, x_bytes[0], x_bytes[1], y_bytes[0], y_bytes[1], extra, - ] -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/hid/keyboard.rs b/src/hid/keyboard.rs index f3c24038..59fb6d22 100644 --- a/src/hid/keyboard.rs +++ b/src/hid/keyboard.rs @@ -1,10 +1,6 @@ use serde::{Deserialize, Serialize}; use typeshare::typeshare; -/// Shared canonical keyboard key identifiers used across frontend and backend. -/// -/// The enum names intentionally mirror `KeyboardEvent.code` style values so the -/// browser, virtual keyboard, and HID backend can all speak the same language. #[typeshare] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum CanonicalKey { @@ -128,11 +124,6 @@ pub enum CanonicalKey { } impl CanonicalKey { - /// Convert the canonical key to a stable wire code. - /// - /// The wire code intentionally matches the USB HID usage for keyboard page - /// keys so existing low-level behavior stays intact while the semantic type - /// becomes explicit. pub const fn to_hid_usage(self) -> u8 { match self { Self::KeyA => 0x04, @@ -255,7 +246,6 @@ impl CanonicalKey { } } - /// Convert a wire code / USB HID usage to its canonical key. pub const fn from_hid_usage(usage: u8) -> Option { match usage { 0x04 => Some(Self::KeyA), diff --git a/src/hid/mod.rs b/src/hid/mod.rs index cddfdf6f..2bd572ca 100644 --- a/src/hid/mod.rs +++ b/src/hid/mod.rs @@ -1,15 +1,4 @@ -//! HID (Human Interface Device) control module -//! -//! This module provides keyboard and mouse control for remote KVM: -//! - USB OTG gadget mode (native Linux USB gadget) -//! - CH9329 serial HID controller -//! -//! Architecture: -//! ```text -//! Web Client --> WebSocket/DataChannel --> HID Events --> Backend --> Target PC -//! | -//! [OTG | CH9329] -//! ``` +//! HID path: browser (WebSocket or WebRTC DataChannel) → queue → OTG gadget or CH9329. pub mod backend; pub mod ch9329; @@ -20,51 +9,26 @@ pub mod otg; pub mod types; pub mod websocket; +pub use crate::events::LedState; pub use backend::{HidBackend, HidBackendRuntimeSnapshot, HidBackendType}; pub use keyboard::CanonicalKey; -pub use otg::LedState; pub use types::{ ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent, MouseEventType, }; -/// HID backend information -#[derive(Debug, Clone)] -pub struct HidInfo { - /// Backend name - pub name: String, - /// Whether backend is initialized - pub initialized: bool, - /// Whether absolute mouse positioning is supported - pub supports_absolute_mouse: bool, - /// Screen resolution for absolute mouse - pub screen_resolution: Option<(u32, u32)>, -} - -/// Unified HID runtime state used by snapshots and events. #[derive(Debug, Clone, PartialEq, Eq)] pub struct HidRuntimeState { - /// Whether a backend is configured and expected to exist. pub available: bool, - /// Stable backend key: "otg", "ch9329", "none". pub backend: String, - /// Whether the backend is currently initialized and operational. pub initialized: bool, - /// Whether the backend is currently online. pub online: bool, - /// Whether absolute mouse positioning is supported. pub supports_absolute_mouse: bool, - /// Whether keyboard LED/status feedback is enabled. pub keyboard_leds_enabled: bool, - /// Last known keyboard LED state. pub led_state: LedState, - /// Screen resolution for absolute mouse mode. pub screen_resolution: Option<(u32, u32)>, - /// Device path associated with the backend, if any. pub device: Option, - /// Current user-facing error, if any. pub error: Option, - /// Current programmatic error code, if any. pub error_code: Option, } @@ -140,45 +104,29 @@ const HID_EVENT_QUEUE_CAPACITY: usize = 64; const HID_EVENT_SEND_TIMEOUT_MS: u64 = 30; #[derive(Debug)] -enum HidEvent { +enum QueuedHidEvent { Keyboard(KeyboardEvent), Mouse(MouseEvent), Consumer(ConsumerEvent), Reset, } -/// HID controller managing keyboard and mouse input pub struct HidController { - /// OTG Service reference (only used when backend is OTG) otg_service: Option>, - /// Active backend backend: Arc>>>, - /// Backend type (mutable for reload) backend_type: Arc>, - /// Event bus for broadcasting state changes (optional) events: Arc>>>, - /// Unified HID runtime state. runtime_state: Arc>, - /// HID event queue sender (non-blocking) - hid_tx: mpsc::Sender, - /// HID event queue receiver (moved into worker on first start) - hid_rx: Mutex>>, - /// Coalesced mouse move (latest) + hid_tx: mpsc::Sender, + hid_rx: Mutex>>, pending_move: Arc>>, - /// Pending move flag (fast path) pending_move_flag: Arc, - /// Worker task handle hid_worker: Mutex>>, - /// Backend runtime subscription task handle runtime_worker: Mutex>>, - /// Backend initialization fast flag backend_available: Arc, } impl HidController { - /// Create a new HID controller with specified backend - /// - /// For OTG backend, otg_service should be provided to support hot-reload pub fn new(backend_type: HidBackendType, otg_service: Option>) -> Self { let (hid_tx, hid_rx) = mpsc::channel(HID_EVENT_QUEUE_CAPACITY); Self { @@ -199,12 +147,10 @@ impl HidController { } } - /// Set event bus for broadcasting state changes pub async fn set_event_bus(&self, events: Arc) { *self.events.write().await = Some(events); } - /// Initialize the HID backend pub async fn init(&self) -> Result<()> { let backend_type = self.backend_type.read().await.clone(); let backend: Arc = match backend_type { @@ -256,7 +202,6 @@ impl HidController { *self.backend.write().await = Some(backend); self.sync_runtime_state_from_backend().await; - // Start HID event worker (once) self.start_event_worker().await; self.restart_runtime_worker().await; @@ -264,12 +209,10 @@ impl HidController { Ok(()) } - /// Shutdown the HID backend and release resources pub async fn shutdown(&self) -> Result<()> { info!("Shutting down HID controller"); self.stop_runtime_worker().await; - // Close the backend if let Some(backend) = self.backend.write().await.take() { if let Err(e) = backend.shutdown().await { warn!("Error shutting down HID backend: {}", e); @@ -290,17 +233,15 @@ impl HidController { Ok(()) } - /// Send keyboard event pub async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> { if !self.backend_available.load(Ordering::Acquire) { return Err(AppError::BadRequest( "HID backend not available".to_string(), )); } - self.enqueue_event(HidEvent::Keyboard(event)).await + self.enqueue_event(QueuedHidEvent::Keyboard(event)).await } - /// Send mouse event pub async fn send_mouse(&self, event: MouseEvent) -> Result<()> { if !self.backend_available.load(Ordering::Acquire) { return Err(AppError::BadRequest( @@ -312,81 +253,55 @@ impl HidController { event.event_type, MouseEventType::Move | MouseEventType::MoveAbs ) { - // Best-effort: drop/merge move events if queue is full self.enqueue_mouse_move(event) } else { - self.enqueue_event(HidEvent::Mouse(event)).await + self.enqueue_event(QueuedHidEvent::Mouse(event)).await } } - /// Send consumer control event (multimedia keys) pub async fn send_consumer(&self, event: ConsumerEvent) -> Result<()> { if !self.backend_available.load(Ordering::Acquire) { return Err(AppError::BadRequest( "HID backend not available".to_string(), )); } - self.enqueue_event(HidEvent::Consumer(event)).await + self.enqueue_event(QueuedHidEvent::Consumer(event)).await } - /// Reset all keys (release all pressed keys) pub async fn reset(&self) -> Result<()> { if !self.backend_available.load(Ordering::Acquire) { return Ok(()); } - // Reset is important but best-effort; enqueue to avoid blocking - self.enqueue_event(HidEvent::Reset).await + self.enqueue_event(QueuedHidEvent::Reset).await } - /// Check if backend is available pub async fn is_available(&self) -> bool { self.backend_available.load(Ordering::Acquire) } - /// Get backend type pub async fn backend_type(&self) -> HidBackendType { self.backend_type.read().await.clone() } - /// Get backend info - pub async fn info(&self) -> Option { - let state = self.runtime_state.read().await.clone(); - if !state.available { - return None; - } - - Some(HidInfo { - name: state.backend, - initialized: state.initialized, - supports_absolute_mouse: state.supports_absolute_mouse, - screen_resolution: state.screen_resolution, - }) - } - - /// Get current HID runtime state snapshot. pub async fn snapshot(&self) -> HidRuntimeState { self.runtime_state.read().await.clone() } - /// Reload the HID backend with new type pub async fn reload(&self, new_backend_type: HidBackendType) -> Result<()> { info!("Reloading HID backend: {:?}", new_backend_type); self.backend_available.store(false, Ordering::Release); self.stop_runtime_worker().await; - // Shutdown existing backend first if let Some(backend) = self.backend.write().await.take() { if let Err(e) = backend.shutdown().await { warn!("Error shutting down old HID backend: {}", e); } } - // Create and initialize new backend let new_backend: Option> = match new_backend_type { HidBackendType::Otg => { info!("Initializing OTG HID backend"); - // Get OtgService reference let otg_service = match self.otg_service.as_ref() { Some(svc) => svc, None => { @@ -398,28 +313,25 @@ impl HidController { }; match otg_service.hid_device_paths().await { - Some(handles) => { - // Create OtgBackend from handles - match otg::OtgBackend::from_handles(handles) { - Ok(backend) => { - let backend = Arc::new(backend); - match backend.init().await { - Ok(_) => { - info!("OTG backend initialized successfully"); - Some(backend) - } - Err(e) => { - warn!("Failed to initialize OTG backend: {}", e); - None - } + Some(handles) => match otg::OtgBackend::from_handles(handles) { + Ok(backend) => { + let backend = Arc::new(backend); + match backend.init().await { + Ok(_) => { + info!("OTG backend initialized successfully"); + Some(backend) + } + Err(e) => { + warn!("Failed to initialize OTG backend: {}", e); + None } } - Err(e) => { - warn!("Failed to create OTG backend: {}", e); - None - } } - } + Err(e) => { + warn!("Failed to create OTG backend: {}", e); + None + } + }, None => { warn!("OTG HID paths are not available"); None @@ -470,7 +382,6 @@ impl HidController { info!("HID backend reloaded successfully: {:?}", new_backend_type); self.start_event_worker().await; - // Update backend_type on success *self.backend_type.write().await = new_backend_type.clone(); self.sync_runtime_state_from_backend().await; @@ -481,7 +392,6 @@ impl HidController { warn!("HID backend reload resulted in no active backend"); self.backend_available.store(false, Ordering::Release); - // Update backend_type even on failure (to reflect the attempted change) *self.backend_type.write().await = new_backend_type.clone(); let current = self.runtime_state.read().await.clone(); @@ -541,11 +451,10 @@ impl HidController { process_hid_event(event, &backend).await; - // After each event, flush latest move if pending if pending_move_flag.swap(false, Ordering::AcqRel) { let move_event = { pending_move.lock().take() }; if let Some(move_event) = move_event { - process_hid_event(HidEvent::Mouse(move_event), &backend).await; + process_hid_event(QueuedHidEvent::Mouse(move_event), &backend).await; } } } @@ -595,7 +504,7 @@ impl HidController { } fn enqueue_mouse_move(&self, event: MouseEvent) -> Result<()> { - match self.hid_tx.try_send(HidEvent::Mouse(event.clone())) { + match self.hid_tx.try_send(QueuedHidEvent::Mouse(event.clone())) { Ok(_) => Ok(()), Err(mpsc::error::TrySendError::Full(_)) => { *self.pending_move.lock() = Some(event); @@ -608,11 +517,10 @@ impl HidController { } } - async fn enqueue_event(&self, event: HidEvent) -> Result<()> { + async fn enqueue_event(&self, event: QueuedHidEvent) -> Result<()> { match self.hid_tx.try_send(event) { Ok(_) => Ok(()), Err(mpsc::error::TrySendError::Full(ev)) => { - // For non-move events, wait briefly to avoid dropping critical input let tx = self.hid_tx.clone(); let send_result = tokio::time::timeout( Duration::from_millis(HID_EVENT_SEND_TIMEOUT_MS), @@ -649,7 +557,10 @@ async fn apply_backend_runtime_state( apply_runtime_state(runtime_state, events, next).await; } -async fn process_hid_event(event: HidEvent, backend: &Arc>>>) { +async fn process_hid_event( + event: QueuedHidEvent, + backend: &Arc>>>, +) { let backend_opt = backend.read().await.clone(); let backend = match backend_opt { Some(b) => b, @@ -660,10 +571,10 @@ async fn process_hid_event(event: HidEvent, backend: &Arc backend_for_send.send_keyboard(ev).await, - HidEvent::Mouse(ev) => backend_for_send.send_mouse(ev).await, - HidEvent::Consumer(ev) => backend_for_send.send_consumer(ev).await, - HidEvent::Reset => backend_for_send.reset().await, + QueuedHidEvent::Keyboard(ev) => backend_for_send.send_keyboard(ev).await, + QueuedHidEvent::Mouse(ev) => backend_for_send.send_mouse(ev).await, + QueuedHidEvent::Consumer(ev) => backend_for_send.send_consumer(ev).await, + QueuedHidEvent::Reset => backend_for_send.reset().await, } }) }) @@ -682,12 +593,6 @@ async fn process_hid_event(event: HidEvent, backend: &Arc Self { - Self::new(HidBackendType::None, None) - } -} - fn device_for_backend_type(backend_type: &HidBackendType) -> Option { match backend_type { HidBackendType::Ch9329 { port, .. } => Some(port.clone()), diff --git a/src/hid/otg.rs b/src/hid/otg.rs index ebbbf3ca..6a25e1e4 100644 --- a/src/hid/otg.rs +++ b/src/hid/otg.rs @@ -1,28 +1,11 @@ -//! OTG USB Gadget HID backend +//! Linux gadget HID: `/dev/hidg*` opened from [`crate::otg::OtgService`]. +//! Typical nodes: hidg0 keyboard, hidg1 relative mouse, hidg2 absolute, hidg3 consumer control. //! -//! This backend uses Linux USB Gadget API to emulate USB HID devices. -//! It opens the HID gadget device nodes created by `OtgService`. -//! Depending on the configured OTG profile, this may include: -//! - hidg0: Keyboard -//! - hidg1: Relative Mouse -//! - hidg2: Absolute Mouse -//! - hidg3: Consumer Control Keyboard -//! -//! Requirements: -//! - USB OTG/Device controller (UDC) -//! - ConfigFS with USB gadget support -//! - Root privileges for gadget setup -//! -//! Error Recovery: -//! This module implements automatic device reconnection based on PiKVM's approach. -//! When ESHUTDOWN or EAGAIN errors occur (common during MSD operations), the device -//! file handles are closed and reopened on the next operation. -//! See: https://github.com/raspberrypi/linux/issues/4373 +//! Polled timed writes (JetKVM-style). Treat `ESHUTDOWN` (108) by closing handles and reopening; keep fd on `EAGAIN` (11). Host/gadget teardown during MSD resembles PiKVM. use async_trait::async_trait; use nix::poll::{poll, PollFd, PollFlags, PollTimeout}; use parking_lot::Mutex; -use serde::{Deserialize, Serialize}; use std::fs::{self, File, OpenOptions}; use std::io::{Read, Write}; use std::os::unix::fs::OpenOptionsExt; @@ -40,9 +23,9 @@ use super::types::{ ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType, }; use crate::error::{AppError, Result}; +use crate::events::LedState; use crate::otg::{wait_for_hid_devices, HidDevicePaths}; -/// Device type for ensure_device operations #[derive(Debug, Clone, Copy)] enum DeviceType { Keyboard, @@ -51,23 +34,7 @@ enum DeviceType { ConsumerControl, } -/// Keyboard LED state -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] -pub struct LedState { - /// Num Lock LED - pub num_lock: bool, - /// Caps Lock LED - pub caps_lock: bool, - /// Scroll Lock LED - pub scroll_lock: bool, - /// Compose LED - pub compose: bool, - /// Kana LED - pub kana: bool, -} - impl LedState { - /// Create from raw byte pub fn from_byte(b: u8) -> Self { Self { num_lock: b & 0x01 != 0, @@ -78,7 +45,6 @@ impl LedState { } } - /// Convert to raw byte pub fn to_byte(&self) -> u8 { let mut b = 0u8; if self.num_lock { @@ -100,76 +66,37 @@ impl LedState { } } -/// OTG HID backend with 4 devices -/// -/// This backend opens HID device files created by OtgService. -/// It does NOT manage the USB gadget itself - that's handled by OtgService. -/// -/// ## Error Recovery -/// -/// Based on PiKVM's implementation, this backend automatically handles: -/// - EAGAIN (errno 11): Resource temporarily unavailable - just retry later, don't close device -/// - ESHUTDOWN (errno 108): Transport endpoint shutdown - close and reopen device -/// -/// When ESHUTDOWN occurs, the device file handle is closed and will be -/// reopened on the next operation attempt. +/// Opens `/dev/hidg*` nodes provisioned by `OtgService`; gadget lifecycle is not handled here. pub struct OtgBackend { - /// Keyboard device path (/dev/hidg0) keyboard_path: Option, - /// Relative mouse device path (/dev/hidg1) mouse_rel_path: Option, - /// Absolute mouse device path (/dev/hidg2) mouse_abs_path: Option, - /// Consumer control device path (/dev/hidg3) consumer_path: Option, - /// Keyboard device file keyboard_dev: Mutex>, - /// Relative mouse device file mouse_rel_dev: Mutex>, - /// Absolute mouse device file mouse_abs_dev: Mutex>, - /// Consumer control device file consumer_dev: Mutex>, - /// Whether keyboard LED/status feedback is enabled. keyboard_leds_enabled: bool, - /// Current keyboard state keyboard_state: Mutex, - /// Current mouse button state mouse_buttons: AtomicU8, - /// Last known LED state (using parking_lot::RwLock for sync access) led_state: Arc>, - /// Screen resolution for absolute mouse (using parking_lot::RwLock for sync access) screen_resolution: parking_lot::RwLock>, - /// UDC name for state checking (e.g., "fcc00000.usb") udc_name: Arc>>, - /// Whether the backend has been initialized. initialized: AtomicBool, - /// Whether the device is currently online (UDC configured and devices accessible) online: AtomicBool, - /// Last backend error state. last_error: parking_lot::RwLock>, - /// Last error log time for throttling (using parking_lot for sync) last_error_log: parking_lot::Mutex, - /// Error count since last successful operation (for log throttling) error_count: AtomicU8, - /// Consecutive EAGAIN count (for offline threshold detection) eagain_count: AtomicU8, - /// Runtime change notifier. runtime_notify_tx: watch::Sender<()>, - /// Runtime monitor stop flag. runtime_worker_stop: Arc, - /// Runtime monitor thread. runtime_worker: Mutex>>, } -/// Write timeout in milliseconds (same as JetKVM's hidWriteTimeout) const HID_WRITE_TIMEOUT_MS: i32 = 20; impl OtgBackend { - /// Create OTG backend from device paths provided by OtgService - /// - /// This is the ONLY way to create an OtgBackend - it no longer manages - /// the USB gadget itself. The gadget must already be set up by OtgService. + /// Gadget must already exist; paths come from `OtgService`. pub fn from_handles(paths: HidDevicePaths) -> Result { let (runtime_notify_tx, _runtime_notify_rx) = watch::channel(()); Ok(Self { @@ -234,7 +161,6 @@ impl OtgBackend { } } - /// Log throttled error message (max once per second) fn log_throttled_error(&self, msg: &str) { let mut last_log = self.last_error_log.lock(); let now = std::time::Instant::now(); @@ -251,24 +177,17 @@ impl OtgBackend { } } - /// Reset error count on successful operation fn reset_error_count(&self) { self.error_count.store(0, Ordering::Relaxed); - // Also reset EAGAIN count - successful operation means device is working self.eagain_count.store(0, Ordering::Relaxed); } - /// Write data to HID device with timeout (JetKVM style) - /// - /// Uses poll() to wait for device to be ready for writing. - /// If timeout expires, silently drops the data (acceptable for mouse movement). - /// Returns Ok(true) if write succeeded, Ok(false) if timed out (silently dropped). + /// Poll-based write with `HID_WRITE_TIMEOUT_MS`; timeout → drop (JetKVM-style). fn write_with_timeout(&self, file: &mut File, data: &[u8]) -> std::io::Result { let mut pollfd = [PollFd::new(file.as_fd(), PollFlags::POLLOUT)]; match poll(&mut pollfd, PollTimeout::from(HID_WRITE_TIMEOUT_MS as u16)) { Ok(1) => { - // Device ready, check for errors if let Some(revents) = pollfd[0].revents() { if revents.contains(PollFlags::POLLERR) || revents.contains(PollFlags::POLLHUP) { @@ -278,12 +197,10 @@ impl OtgBackend { )); } } - // Write the data file.write_all(data)?; Ok(true) } Ok(0) => { - // Timeout - silently drop (JetKVM behavior) trace!("HID write timeout, dropping data"); Ok(false) } @@ -292,7 +209,6 @@ impl OtgBackend { } } - /// Set the UDC name for state checking pub fn set_udc_name(&self, udc: &str) { *self.udc_name.write() = Some(udc.to_string()); } @@ -324,15 +240,11 @@ impl OtgBackend { } } - /// Check if the UDC is in "configured" state - /// - /// This is based on PiKVM's `__is_udc_configured()` method. - /// The UDC state file indicates whether the USB host has enumerated and configured the gadget. + /// `true` when `/sys/class/udc//state` reads `configured` (PiKVM-style). pub fn is_udc_configured(&self) -> bool { Self::read_udc_configured(&self.udc_name) } - /// Find the first available UDC fn find_udc() -> Option { let udc_path = PathBuf::from("/sys/class/udc"); if let Ok(entries) = fs::read_dir(&udc_path) { @@ -345,12 +257,7 @@ impl OtgBackend { None } - /// Ensure a device is open and ready for I/O - /// - /// This method is based on PiKVM's `__ensure_device()` pattern: - /// 1. Check if device path exists, close handle if not - /// 2. If handle is None but path exists, reopen the device - /// 3. Return whether the device is ready for I/O + /// PiKVM-style: drop handle if node missing; reopen when path reappears. fn ensure_device(&self, device_type: DeviceType) -> Result<()> { let (path_opt, dev_mutex) = match device_type { DeviceType::Keyboard => (&self.keyboard_path, &self.keyboard_dev), @@ -372,9 +279,7 @@ impl OtgBackend { } }; - // Check if device path exists if !path.exists() { - // Close the device if open (device was removed) let mut dev = dev_mutex.lock(); if dev.is_some() { debug!( @@ -392,7 +297,6 @@ impl OtgBackend { }); } - // If device is not open, try to open it let mut dev = dev_mutex.lock(); if dev.is_none() { match Self::open_device(path) { @@ -415,7 +319,6 @@ impl OtgBackend { Ok(()) } - /// Open a HID device file with read/write access fn open_device(path: &PathBuf) -> Result { OpenOptions::new() .read(true) @@ -431,16 +334,15 @@ impl OtgBackend { }) } - /// Convert I/O error to HidError with appropriate error code fn io_error_code(e: &std::io::Error) -> &'static str { match e.raw_os_error() { - Some(32) => "epipe", // EPIPE - broken pipe - Some(108) => "eshutdown", // ESHUTDOWN - transport endpoint shutdown - Some(11) => "eagain", // EAGAIN - resource temporarily unavailable - Some(6) => "enxio", // ENXIO - no such device or address - Some(19) => "enodev", // ENODEV - no such device - Some(5) => "eio", // EIO - I/O error - Some(2) => "enoent", // ENOENT - no such file or directory + Some(32) => "epipe", + Some(108) => "eshutdown", + Some(11) => "eagain", + Some(6) => "enxio", + Some(19) => "enodev", + Some(5) => "eio", + Some(2) => "enoent", _ => "io_error", } } @@ -455,7 +357,6 @@ impl OtgBackend { } } - /// Check if all HID device files exist pub fn check_devices_exist(&self) -> bool { self.keyboard_path.as_ref().is_none_or(|p| p.exists()) && self.mouse_rel_path.as_ref().is_none_or(|p| p.exists()) @@ -463,7 +364,6 @@ impl OtgBackend { && self.consumer_path.as_ref().is_none_or(|p| p.exists()) } - /// Get list of missing device paths pub fn get_missing_devices(&self) -> Vec { let mut missing = Vec::new(); if let Some(ref path) = self.keyboard_path { @@ -484,17 +384,11 @@ impl OtgBackend { missing } - /// Send keyboard report (8 bytes) - /// - /// This method ensures the device is open before writing, and handles - /// ESHUTDOWN errors by closing the device handle for later reconnection. - /// Uses write_with_timeout to avoid blocking on busy devices. fn send_keyboard_report(&self, report: &KeyboardReport) -> Result<()> { if self.keyboard_path.is_none() { return Ok(()); } - // Ensure device is ready self.ensure_device(DeviceType::Keyboard)?; let mut dev = self.keyboard_dev.lock(); @@ -508,7 +402,6 @@ impl OtgBackend { Ok(()) } Ok(false) => { - // Timeout - silently dropped (JetKVM behavior) self.log_throttled_error("HID keyboard write timeout, dropped"); Ok(()) } @@ -517,7 +410,6 @@ impl OtgBackend { match error_code { Some(108) => { - // ESHUTDOWN - endpoint closed, need to reopen device self.eagain_count.store(0, Ordering::Relaxed); debug!("Keyboard ESHUTDOWN, closing for recovery"); *dev = None; @@ -531,7 +423,6 @@ impl OtgBackend { )) } Some(11) => { - // EAGAIN after poll - should be rare, silently drop trace!("Keyboard EAGAIN after poll, dropping"); Ok(()) } @@ -559,17 +450,11 @@ impl OtgBackend { } } - /// Send relative mouse report (4 bytes: buttons, dx, dy, wheel) - /// - /// This method ensures the device is open before writing, and handles - /// ESHUTDOWN errors by closing the device handle for later reconnection. - /// Uses write_with_timeout to avoid blocking on busy devices. fn send_mouse_report_relative(&self, buttons: u8, dx: i8, dy: i8, wheel: i8) -> Result<()> { if self.mouse_rel_path.is_none() { return Ok(()); } - // Ensure device is ready self.ensure_device(DeviceType::MouseRelative)?; let mut dev = self.mouse_rel_dev.lock(); @@ -582,10 +467,7 @@ impl OtgBackend { trace!("Sent relative mouse report: {:02X?}", data); Ok(()) } - Ok(false) => { - // Timeout - silently dropped (JetKVM behavior) - Ok(()) - } + Ok(false) => Ok(()), Err(e) => { let error_code = e.raw_os_error(); @@ -603,10 +485,7 @@ impl OtgBackend { "Failed to write mouse report", )) } - Some(11) => { - // EAGAIN after poll - should be rare, silently drop - Ok(()) - } + Some(11) => Ok(()), _ => { self.eagain_count.store(0, Ordering::Relaxed); warn!("Relative mouse write error: {}", e); @@ -631,17 +510,11 @@ impl OtgBackend { } } - /// Send absolute mouse report (6 bytes: buttons, x_lo, x_hi, y_lo, y_hi, wheel) - /// - /// This method ensures the device is open before writing, and handles - /// ESHUTDOWN errors by closing the device handle for later reconnection. - /// Uses write_with_timeout to avoid blocking on busy devices. fn send_mouse_report_absolute(&self, buttons: u8, x: u16, y: u16, wheel: i8) -> Result<()> { if self.mouse_abs_path.is_none() { return Ok(()); } - // Ensure device is ready self.ensure_device(DeviceType::MouseAbsolute)?; let mut dev = self.mouse_abs_dev.lock(); @@ -660,10 +533,7 @@ impl OtgBackend { self.reset_error_count(); Ok(()) } - Ok(false) => { - // Timeout - silently dropped (JetKVM behavior) - Ok(()) - } + Ok(false) => Ok(()), Err(e) => { let error_code = e.raw_os_error(); @@ -681,10 +551,7 @@ impl OtgBackend { "Failed to write mouse report", )) } - Some(11) => { - // EAGAIN after poll - should be rare, silently drop - Ok(()) - } + Some(11) => Ok(()), _ => { self.eagain_count.store(0, Ordering::Relaxed); warn!("Absolute mouse write error: {}", e); @@ -709,35 +576,27 @@ impl OtgBackend { } } - /// Send consumer control report (2 bytes: usage_lo, usage_hi) - /// - /// Sends a consumer control usage code and then releases it (sends 0x0000). + /// Press (`usage`) then release (`0x0000`). fn send_consumer_report(&self, usage: u16) -> Result<()> { if self.consumer_path.is_none() { return Ok(()); } - // Ensure device is ready self.ensure_device(DeviceType::ConsumerControl)?; let mut dev = self.consumer_dev.lock(); if let Some(ref mut file) = *dev { - // Send the usage code let data = [(usage & 0xFF) as u8, (usage >> 8) as u8]; match self.write_with_timeout(file, &data) { Ok(true) => { trace!("Sent consumer report: {:02X?}", data); - // Send release (0x0000) let release = [0u8, 0u8]; let _ = self.write_with_timeout(file, &release); self.mark_online(); self.reset_error_count(); Ok(()) } - Ok(false) => { - // Timeout - silently dropped - Ok(()) - } + Ok(false) => Ok(()), Err(e) => { let error_code = e.raw_os_error(); match error_code { @@ -753,10 +612,7 @@ impl OtgBackend { "Failed to write consumer report", )) } - Some(11) => { - // EAGAIN after poll - silently drop - Ok(()) - } + Some(11) => Ok(()), _ => { warn!("Consumer control write error: {}", e); self.record_error( @@ -780,12 +636,10 @@ impl OtgBackend { } } - /// Send consumer control event pub fn send_consumer(&self, event: ConsumerEvent) -> Result<()> { self.send_consumer_report(event.usage) } - /// Get last known LED state pub fn led_state(&self) -> LedState { *self.led_state.read() } @@ -975,7 +829,6 @@ impl HidBackend for OtgBackend { async fn init(&self) -> Result<()> { info!("Initializing OTG HID backend"); - // Auto-detect UDC name for state checking only if OtgService did not provide one if self.udc_name.read().is_none() { if let Some(udc) = Self::find_udc() { info!("Auto-detected UDC: {}", udc); @@ -985,7 +838,6 @@ impl HidBackend for OtgBackend { info!("Using configured UDC: {}", udc); } - // Wait for devices to appear (they should already exist from OtgService) let mut device_paths = Vec::new(); if let Some(ref path) = self.keyboard_path { device_paths.push(path.clone()); @@ -1010,7 +862,6 @@ impl HidBackend for OtgBackend { return Err(AppError::Internal("HID devices did not appear".into())); } - // Open keyboard device if let Some(ref path) = self.keyboard_path { if path.exists() { let file = Self::open_device(path)?; @@ -1021,7 +872,6 @@ impl HidBackend for OtgBackend { } } - // Open relative mouse device if let Some(ref path) = self.mouse_rel_path { if path.exists() { let file = Self::open_device(path)?; @@ -1032,7 +882,6 @@ impl HidBackend for OtgBackend { } } - // Open absolute mouse device if let Some(ref path) = self.mouse_abs_path { if path.exists() { let file = Self::open_device(path)?; @@ -1043,7 +892,6 @@ impl HidBackend for OtgBackend { } } - // Open consumer control device (optional, may not exist on older setups) if let Some(ref path) = self.consumer_path { if path.exists() { let file = Self::open_device(path)?; @@ -1054,7 +902,6 @@ impl HidBackend for OtgBackend { } } - // Mark as online if all devices opened successfully self.initialized.store(true, Ordering::Relaxed); self.notify_runtime_changed(); self.start_runtime_worker(); @@ -1066,7 +913,6 @@ impl HidBackend for OtgBackend { async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> { let usb_key = event.key.to_hid_usage(); - // Handle modifier keys separately if event.key.is_modifier() { let mut state = self.keyboard_state.lock(); @@ -1084,7 +930,6 @@ impl HidBackend for OtgBackend { } else { let mut state = self.keyboard_state.lock(); - // Update modifiers from event state.modifiers = event.modifiers.to_hid_byte(); match event.event_type { @@ -1110,15 +955,12 @@ impl HidBackend for OtgBackend { match event.event_type { MouseEventType::Move => { - // Relative movement - use hidg1 let dx = event.x.clamp(-127, 127) as i8; let dy = event.y.clamp(-127, 127) as i8; self.send_mouse_report_relative(buttons, dx, dy, 0)?; } MouseEventType::MoveAbs => { - // Absolute movement - use hidg2 - // Frontend sends 0-32767 range directly (standard HID absolute mouse range) - // Don't send button state with move - buttons are handled separately on relative device + // Coordinates 0–32767; buttons are sent only on the relative endpoint. let x = event.x.clamp(0, 32767) as u16; let y = event.y.clamp(0, 32767) as u16; self.send_mouse_report_absolute(0, x, y, 0)?; @@ -1127,7 +969,6 @@ impl HidBackend for OtgBackend { if let Some(button) = event.button { let bit = button.to_hid_bit(); let new_buttons = self.mouse_buttons.fetch_or(bit, Ordering::Relaxed) | bit; - // Send on relative device for button clicks self.send_mouse_report_relative(new_buttons, 0, 0, 0)?; } } @@ -1147,7 +988,6 @@ impl HidBackend for OtgBackend { } async fn reset(&self) -> Result<()> { - // Reset keyboard { let mut state = self.keyboard_state.lock(); state.clear(); @@ -1156,7 +996,6 @@ impl HidBackend for OtgBackend { self.send_keyboard_report(&report)?; } - // Reset mouse self.mouse_buttons.store(0, Ordering::Relaxed); self.send_mouse_report_relative(0, 0, 0, 0)?; self.send_mouse_report_absolute(0, 0, 0, 0)?; @@ -1168,16 +1007,13 @@ impl HidBackend for OtgBackend { async fn shutdown(&self) -> Result<()> { self.stop_runtime_worker(); - // Reset before closing self.reset().await?; - // Close devices *self.keyboard_dev.lock() = None; *self.mouse_rel_dev.lock() = None; *self.mouse_abs_dev.lock() = None; *self.consumer_dev.lock() = None; - // Gadget cleanup is handled by OtgService, not here self.initialized.store(false, Ordering::Relaxed); self.online.store(false, Ordering::Relaxed); self.clear_error(); @@ -1199,31 +1035,18 @@ impl HidBackend for OtgBackend { self.send_consumer_report(event.usage) } - fn set_screen_resolution(&mut self, width: u32, height: u32) { + fn set_screen_resolution(&self, width: u32, height: u32) { *self.screen_resolution.write() = Some((width, height)); self.notify_runtime_changed(); } } -/// Check if OTG HID gadget is available -pub fn is_otg_available() -> bool { - // Check for existing HID devices (they should be created by OtgService) - let kb = PathBuf::from("/dev/hidg0"); - let mouse_rel = PathBuf::from("/dev/hidg1"); - let mouse_abs = PathBuf::from("/dev/hidg2"); - - kb.exists() || mouse_rel.exists() || mouse_abs.exists() -} - -/// Implement Drop for OtgBackend to close device files impl Drop for OtgBackend { fn drop(&mut self) { self.runtime_worker_stop.store(true, Ordering::Relaxed); if let Some(handle) = self.runtime_worker.get_mut().take() { let _ = handle.join(); } - // Close device files - // Note: Gadget cleanup is handled by OtgService, not here *self.keyboard_dev.lock() = None; *self.mouse_rel_dev.lock() = None; *self.mouse_abs_dev.lock() = None; @@ -1236,12 +1059,6 @@ impl Drop for OtgBackend { mod tests { use super::*; - #[test] - fn test_otg_availability_check() { - // This just tests the function runs without panicking - let _available = is_otg_available(); - } - #[test] fn test_led_state() { let state = LedState::from_byte(0b00000011); @@ -1254,7 +1071,6 @@ mod tests { #[test] fn test_report_sizes() { - // Keyboard report is 8 bytes let kb_report = KeyboardReport::default(); assert_eq!(kb_report.to_bytes().len(), 8); } diff --git a/src/hid/types.rs b/src/hid/types.rs index 5545ac62..ffef701d 100644 --- a/src/hid/types.rs +++ b/src/hid/types.rs @@ -1,50 +1,37 @@ -//! HID event types for keyboard and mouse +//! Keyboard/mouse/consumer structs (`KeyboardEvent`, `MouseEvent`, …). use serde::{Deserialize, Serialize}; use super::keyboard::CanonicalKey; -/// Keyboard event type #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum KeyEventType { - /// Key pressed down Down, - /// Key released Up, } -/// Keyboard modifier flags #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] pub struct KeyboardModifiers { - /// Left Control #[serde(default)] pub left_ctrl: bool, - /// Left Shift #[serde(default)] pub left_shift: bool, - /// Left Alt #[serde(default)] pub left_alt: bool, - /// Left Meta (Windows/Super key) #[serde(default)] pub left_meta: bool, - /// Right Control #[serde(default)] pub right_ctrl: bool, - /// Right Shift #[serde(default)] pub right_shift: bool, - /// Right Alt (AltGr) #[serde(default)] pub right_alt: bool, - /// Right Meta #[serde(default)] pub right_meta: bool, } impl KeyboardModifiers { - /// Convert to USB HID modifier byte pub fn to_hid_byte(&self) -> u8 { let mut byte = 0u8; if self.left_ctrl { @@ -74,7 +61,6 @@ impl KeyboardModifiers { byte } - /// Create from USB HID modifier byte pub fn from_hid_byte(byte: u8) -> Self { Self { left_ctrl: byte & 0x01 != 0, @@ -88,7 +74,6 @@ impl KeyboardModifiers { } } - /// Check if any modifier is active pub fn any(&self) -> bool { self.left_ctrl || self.left_shift @@ -101,21 +86,16 @@ impl KeyboardModifiers { } } -/// Keyboard event #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KeyboardEvent { - /// Event type (down/up) #[serde(rename = "type")] pub event_type: KeyEventType, - /// Canonical keyboard key identifier shared across frontend and backend pub key: CanonicalKey, - /// Modifier keys state #[serde(default)] pub modifiers: KeyboardModifiers, } impl KeyboardEvent { - /// Create a key down event pub fn key_down(key: CanonicalKey, modifiers: KeyboardModifiers) -> Self { Self { event_type: KeyEventType::Down, @@ -124,7 +104,6 @@ impl KeyboardEvent { } } - /// Create a key up event pub fn key_up(key: CanonicalKey, modifiers: KeyboardModifiers) -> Self { Self { event_type: KeyEventType::Up, @@ -134,7 +113,6 @@ impl KeyboardEvent { } } -/// Mouse button #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum MouseButton { @@ -146,7 +124,6 @@ pub enum MouseButton { } impl MouseButton { - /// Convert to USB HID button bit pub fn to_hid_bit(&self) -> u8 { match self { MouseButton::Left => 0x01, @@ -158,44 +135,31 @@ impl MouseButton { } } -/// Mouse event type #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum MouseEventType { - /// Mouse moved (relative movement) Move, - /// Mouse moved (absolute position) MoveAbs, - /// Button pressed Down, - /// Button released Up, - /// Mouse wheel scroll Scroll, } -/// Mouse event #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MouseEvent { - /// Event type #[serde(rename = "type")] pub event_type: MouseEventType, - /// X coordinate or delta #[serde(default)] pub x: i32, - /// Y coordinate or delta #[serde(default)] pub y: i32, - /// Button (for down/up events) #[serde(default)] pub button: Option, - /// Scroll delta (for scroll events) #[serde(default)] pub scroll: i8, } impl MouseEvent { - /// Create a relative move event pub fn move_rel(dx: i32, dy: i32) -> Self { Self { event_type: MouseEventType::Move, @@ -206,7 +170,6 @@ impl MouseEvent { } } - /// Create an absolute move event pub fn move_abs(x: i32, y: i32) -> Self { Self { event_type: MouseEventType::MoveAbs, @@ -217,7 +180,6 @@ impl MouseEvent { } } - /// Create a button down event pub fn button_down(button: MouseButton) -> Self { Self { event_type: MouseEventType::Down, @@ -228,7 +190,6 @@ impl MouseEvent { } } - /// Create a button up event pub fn button_up(button: MouseButton) -> Self { Self { event_type: MouseEventType::Up, @@ -239,7 +200,6 @@ impl MouseEvent { } } - /// Create a scroll event pub fn scroll(delta: i8) -> Self { Self { event_type: MouseEventType::Scroll, @@ -251,35 +211,19 @@ impl MouseEvent { } } -/// Combined HID event (keyboard or mouse) -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "device", rename_all = "lowercase")] -pub enum HidEvent { - Keyboard(KeyboardEvent), - Mouse(MouseEvent), - Consumer(ConsumerEvent), -} - -/// Consumer control event (multimedia keys) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConsumerEvent { - /// Consumer control usage code (e.g., 0x00CD for Play/Pause) pub usage: u16, } -/// USB HID keyboard report (8 bytes) #[derive(Debug, Clone, Default)] pub struct KeyboardReport { - /// Modifier byte pub modifiers: u8, - /// Reserved byte pub reserved: u8, - /// Key codes (up to 6 simultaneous keys) pub keys: [u8; 6], } impl KeyboardReport { - /// Convert to bytes for USB HID pub fn to_bytes(&self) -> [u8; 8] { [ self.modifiers, @@ -293,7 +237,6 @@ impl KeyboardReport { ] } - /// Add a key to the report pub fn add_key(&mut self, key: u8) -> bool { for slot in &mut self.keys { if *slot == 0 { @@ -304,56 +247,21 @@ impl KeyboardReport { false // All slots full } - /// Remove a key from the report pub fn remove_key(&mut self, key: u8) { for slot in &mut self.keys { if *slot == key { *slot = 0; } } - // Compact the array self.keys.sort_by(|a, b| b.cmp(a)); } - /// Clear all keys pub fn clear(&mut self) { self.modifiers = 0; self.keys = [0; 6]; } } -/// USB HID mouse report -#[derive(Debug, Clone, Default)] -pub struct MouseReport { - /// Button state - pub buttons: u8, - /// X movement (-127 to 127) - pub x: i8, - /// Y movement (-127 to 127) - pub y: i8, - /// Wheel movement (-127 to 127) - pub wheel: i8, -} - -impl MouseReport { - /// Convert to bytes for USB HID (relative mouse) - pub fn to_bytes_relative(&self) -> [u8; 4] { - [self.buttons, self.x as u8, self.y as u8, self.wheel as u8] - } - - /// Convert to bytes for USB HID (absolute mouse) - pub fn to_bytes_absolute(&self, x: u16, y: u16) -> [u8; 6] { - [ - self.buttons, - (x & 0xFF) as u8, - (x >> 8) as u8, - (y & 0xFF) as u8, - (y >> 8) as u8, - self.wheel as u8, - ] - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/hid/websocket.rs b/src/hid/websocket.rs index 4cff8bc4..9a322674 100644 --- a/src/hid/websocket.rs +++ b/src/hid/websocket.rs @@ -1,13 +1,4 @@ -//! WebSocket HID channel for HTTP/MJPEG mode -//! -//! This provides an alternative to WebRTC DataChannel for HID input -//! when using MJPEG streaming mode. -//! -//! Uses binary protocol only (same format as DataChannel): -//! - Keyboard: [0x01, event_type, key, modifiers] (4 bytes) -//! - Mouse: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, button/scroll] (7 bytes) -//! -//! See datachannel.rs for detailed protocol specification. +//! MJPEG mode: HID over WebSocket — same binary framing as [`super::datachannel`] (`0x01`/`0x02`/`0x03`; layout detailed there). use axum::{ extract::{ @@ -24,25 +15,20 @@ use super::datachannel::{parse_hid_message, HidChannelEvent}; use crate::state::AppState; use crate::utils::LogThrottler; -/// Binary response codes const RESP_OK: u8 = 0x00; const RESP_ERR_HID_UNAVAILABLE: u8 = 0x01; const RESP_ERR_INVALID_MESSAGE: u8 = 0x02; -/// WebSocket HID upgrade handler pub async fn ws_hid_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { ws.on_upgrade(move |socket| handle_hid_socket(socket, state)) } -/// Handle HID WebSocket connection async fn handle_hid_socket(socket: WebSocket, state: Arc) { let (mut sender, mut receiver) = socket.split(); - // Log throttler for error messages (5 second interval) let log_throttler = LogThrottler::with_secs(5); info!("WebSocket HID connection established (binary protocol)"); - // Check if HID controller is available and send initial status let hid_available = state.hid.is_available().await; let initial_response = if hid_available { vec![RESP_OK] @@ -59,17 +45,14 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc) { return; } - // Process incoming messages (binary only) while let Some(msg) = receiver.next().await { match msg { Ok(Message::Binary(data)) => { - // Check HID availability before processing each message let hid_available = state.hid.is_available().await; if !hid_available { if log_throttler.should_log("hid_unavailable") { warn!("HID controller not available, ignoring message"); } - // Send error response (optional, for client awareness) let _ = sender .send(Message::Binary(vec![RESP_ERR_HID_UNAVAILABLE].into())) .await; @@ -77,15 +60,12 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc) { } if let Err(e) = handle_binary_message(&data, &state).await { - // Log with throttling to avoid spam if log_throttler.should_log("binary_hid_error") { warn!("Binary HID message error: {}", e); } - // Don't send error response for every failed message to reduce overhead } } Ok(Message::Text(text)) => { - // Text messages are no longer supported if log_throttler.should_log("text_message_rejected") { debug!( "Received text message (not supported): {} bytes", @@ -111,7 +91,6 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc) { } } - // Reset HID state to release any held keys/buttons if let Err(e) = state.hid.reset().await { warn!("Failed to reset HID on WebSocket disconnect: {}", e); } @@ -119,7 +98,6 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc) { info!("WebSocket HID connection ended"); } -/// Handle binary HID message (same format as DataChannel) async fn handle_binary_message(data: &[u8], state: &AppState) -> Result<(), String> { let event = parse_hid_message(data).ok_or("Invalid binary HID message")?; @@ -160,12 +138,10 @@ mod tests { assert_eq!(RESP_OK, 0x00); assert_eq!(RESP_ERR_HID_UNAVAILABLE, 0x01); assert_eq!(RESP_ERR_INVALID_MESSAGE, 0x02); - // assert_eq!(RESP_ERR_SEND_FAILED, 0x03); // TODO: fix test } #[test] fn test_keyboard_message_format() { - // Keyboard message: [0x01, event_type, key, modifiers] let data = [MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]; // 'A' key with left ctrl let event = parse_hid_message(&data); assert!(event.is_some()); @@ -173,7 +149,6 @@ mod tests { #[test] fn test_mouse_message_format() { - // Mouse message: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, extra] let data = [MSG_MOUSE, MS_EVENT_MOVE, 0x0A, 0x00, 0xF6, 0xFF, 0x00]; // x=10, y=-10 let event = parse_hid_message(&data); assert!(event.is_some()); diff --git a/src/lib.rs b/src/lib.rs index e64e5da8..4b2d04f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,30 +1,27 @@ -//! One-KVM - Lightweight IP-KVM solution -//! -//! This crate provides the core functionality for One-KVM, -//! a remote KVM (Keyboard, Video, Mouse) solution written in Rust. +//! Core library for One-KVM (IP‑KVM: capture, HID, OTG, streaming, Web UI glue). pub mod atx; pub mod audio; pub mod auth; pub mod config; +pub mod db; pub mod error; pub mod events; pub mod extensions; pub mod hid; -pub mod modules; pub mod msd; pub mod otg; pub mod rtsp; pub mod rustdesk; pub mod state; pub mod stream; +pub mod stream_encoder; pub mod update; pub mod utils; pub mod video; pub mod web; pub mod webrtc; -/// Auto-generated secrets module (from secrets.toml at compile time) pub mod secrets { include!(concat!(env!("OUT_DIR"), "/secrets_generated.rs")); } diff --git a/src/main.rs b/src/main.rs index fd667d3b..484ee213 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,8 @@ use std::collections::HashSet; +use std::future::Future; use std::io::Write; use std::net::{IpAddr, SocketAddr}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::Arc; use axum_server::tls_rustls::RustlsConfig; @@ -15,6 +16,7 @@ use one_kvm::atx::AtxController; use one_kvm::audio::{AudioController, AudioControllerConfig, AudioQuality}; use one_kvm::auth::{SessionStore, UserStore}; use one_kvm::config::{self, AppConfig, ConfigStore}; +use one_kvm::db::DatabasePool; use one_kvm::events::EventBus; use one_kvm::extensions::ExtensionManager; use one_kvm::hid::{HidBackendType, HidController}; @@ -33,7 +35,6 @@ use one_kvm::video::{Streamer, VideoStreamManager}; use one_kvm::web; use one_kvm::webrtc::{WebRtcStreamer, WebRtcStreamerConfig}; -/// Log level for the application #[derive(Debug, Clone, Copy, Default, ValueEnum)] enum LogLevel { Error, @@ -45,7 +46,6 @@ enum LogLevel { Trace, } -/// One-KVM command line arguments #[derive(Parser, Debug)] #[command(name = "one-kvm")] #[command(version, about = "A open and lightweight IP-KVM solution", long_about = None)] @@ -111,37 +111,30 @@ enum UserAction { #[tokio::main] async fn main() -> anyhow::Result<()> { - // Parse command line arguments let args = CliArgs::parse(); - // Initialize logging with CLI arguments init_logging(args.log_level, args.verbose); - // Install default crypto provider (required by rustls 0.23+) CryptoProvider::install_default(ring::default_provider()) .expect("Failed to install rustls crypto provider"); tracing::info!("Starting One-KVM v{}", env!("CARGO_PKG_VERSION")); - // Determine data directory (CLI arg takes precedence) let data_dir = args.data_dir.clone().unwrap_or_else(get_data_dir); tracing::info!("Data directory: {}", data_dir.display()); - // Run one-off CLI command and exit. if let Some(command) = args.command { run_cli_command(command, data_dir).await?; return Ok(()); } - // Ensure data directory exists tokio::fs::create_dir_all(&data_dir).await?; - // Initialize configuration store - let db_path = data_dir.join("one-kvm.db"); - let config_store = ConfigStore::new(&db_path).await?; + let db = open_database_pool(&data_dir).await?; + let config_store = ConfigStore::new(db.clone_pool())?; + config_store.load().await?; let mut config = (*config_store.get()).clone(); - // Normalize MSD directory (absolute path under data dir if empty/relative) let mut msd_dir_updated = false; if config.msd.msd_dir.trim().is_empty() { let msd_dir = data_dir.join("msd"); @@ -159,8 +152,6 @@ async fn main() -> anyhow::Result<()> { if msd_dir_updated { config_store.set(config.clone()).await?; } - - // Ensure MSD directories exist (msd/images, msd/ventoy) let msd_dir = PathBuf::from(&config.msd.msd_dir); if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("images")).await { tracing::warn!("Failed to create MSD images directory: {}", e); @@ -169,7 +160,6 @@ async fn main() -> anyhow::Result<()> { tracing::warn!("Failed to create MSD ventoy directory: {}", e); } - // Apply CLI argument overrides to config (only if explicitly specified) if let Some(addr) = args.address { config.web.bind_address = addr.clone(); config.web.bind_addresses = vec![addr]; @@ -203,29 +193,21 @@ async fn main() -> anyhow::Result<()> { config.web.http_port }; - // Log final configuration + for ip in &bind_ips { let addr = SocketAddr::new(*ip, bind_port); tracing::info!("Server will listen on: {}://{}", scheme, addr); } - // Initialize session store - let session_store = SessionStore::new( - config_store.pool().clone(), - config.auth.session_timeout_secs as i64, - ); + let session_store = SessionStore::new(config.auth.session_timeout_secs as i64); - // Initialize user store - let user_store = UserStore::new(config_store.pool().clone()); + let user_store = UserStore::new(db.clone_pool()); - // Create shutdown channel let (shutdown_tx, _) = broadcast::channel::<()>(1); - // Create event bus for real-time notifications let events = Arc::new(EventBus::new()); tracing::info!("Event bus initialized"); - // Parse video configuration once (avoid duplication) let (video_format, video_resolution) = parse_video_config(&config); tracing::debug!( "Parsed video config: {} @ {}x{}", @@ -234,7 +216,6 @@ async fn main() -> anyhow::Result<()> { video_resolution.height ); - // Create video streamer and initialize with config if device is set let streamer = Streamer::new(); streamer.set_event_bus(events.clone()).await; if let Some(ref device_path) = config.video.device { @@ -262,19 +243,19 @@ async fn main() -> anyhow::Result<()> { } } - // Create WebRTC streamer let webrtc_streamer = { let webrtc_config = WebRtcStreamerConfig { resolution: video_resolution, input_format: video_format, fps: config.video.fps, bitrate_preset: config.stream.bitrate_preset, - encoder_backend: config.stream.encoder.to_backend(), + encoder_backend: one_kvm::stream_encoder::encoder_type_to_backend( + config.stream.encoder.clone(), + ), webrtc: { let mut stun_servers = vec![]; let mut turn_servers = vec![]; - // Check if user configured custom servers let has_custom_stun = config .stream .stun_server @@ -288,14 +269,12 @@ async fn main() -> anyhow::Result<()> { .map(|s| !s.is_empty()) .unwrap_or(false); - // If no custom servers, use baked-in public STUN if !has_custom_stun && !has_custom_turn { use one_kvm::webrtc::config::public_ice; let stun = public_ice::stun_server().to_string(); tracing::info!("Using public STUN server: {}", stun); stun_servers.push(stun); } else { - // Use custom servers if let Some(ref stun) = config.stream.stun_server { if !stun.is_empty() { stun_servers.push(stun.clone()); @@ -333,16 +312,13 @@ async fn main() -> anyhow::Result<()> { }; tracing::info!("WebRTC streamer created"); - // Create OTG Service (single instance for centralized USB gadget management) let otg_service = Arc::new(OtgService::new()); tracing::info!("OTG Service created"); - // Reconcile OTG once from the persisted config so controllers only consume its result. if let Err(e) = otg_service.apply_config(&config.hid, &config.msd).await { tracing::warn!("Failed to apply OTG config: {}", e); } - // Create HID controller based on config let hid_backend = match config.hid.backend { config::HidBackend::Otg => HidBackendType::Otg, config::HidBackend::Ch9329 => HidBackendType::Ch9329 { @@ -353,16 +329,14 @@ async fn main() -> anyhow::Result<()> { }; let hid = Arc::new(HidController::new( hid_backend, - Some(otg_service.clone()), // Always pass OtgService to support hot-reload to OTG + Some(otg_service.clone()), )); hid.set_event_bus(events.clone()).await; if let Err(e) = hid.init().await { tracing::warn!("Failed to initialize HID backend: {}", e); } - // Create MSD controller (optional, based on config) let msd = if config.msd.enabled { - // `{data_dir}/ventoy`: boot.img, core.img, ventoy.disk.img for ventoy_img let ventoy_resource_dir = data_dir.join("ventoy"); if ventoy_resource_dir.exists() { if let Err(e) = ventoy_img::init_resources(&ventoy_resource_dir) { @@ -393,7 +367,6 @@ async fn main() -> anyhow::Result<()> { None }; - // Create ATX controller (optional, based on config) let atx = if config.atx.enabled { let controller_config = config.atx.to_controller_config(); let controller = AtxController::new(controller_config); @@ -409,12 +382,21 @@ async fn main() -> anyhow::Result<()> { None }; - // Create Audio controller let audio = { let audio_config = AudioControllerConfig { enabled: config.audio.enabled, device: config.audio.device.clone(), - quality: AudioQuality::from_str(&config.audio.quality), + quality: match config.audio.quality.parse::() { + Ok(q) => q, + Err(e) => { + tracing::warn!( + "Invalid audio quality in config (value={:?}): {}, using balanced", + config.audio.quality, + e + ); + AudioQuality::Balanced + } + }, }; let controller = AudioController::new(audio_config); @@ -426,7 +408,6 @@ async fn main() -> anyhow::Result<()> { config.audio.device, config.audio.quality ); - // Start audio streaming so WebRTC can subscribe to Opus frames if let Err(e) = controller.start_streaming().await { tracing::warn!("Failed to start audio streaming: {}", e); } @@ -437,16 +418,11 @@ async fn main() -> anyhow::Result<()> { Arc::new(controller) }; - // Create Extension manager (ttyd, gostc, easytier) let extensions = Arc::new(ExtensionManager::new()); tracing::info!("Extension manager initialized"); - // Wire up WebRTC streamer with HID controller - // This enables WebRTC DataChannel to process HID events webrtc_streamer.set_hid_controller(hid.clone()).await; - // Wire up WebRTC streamer with Audio controller - // This enables WebRTC audio track to receive Opus frames webrtc_streamer.set_audio_controller(audio.clone()).await; if config.audio.enabled { if let Err(e) = webrtc_streamer.set_audio_enabled(true).await { @@ -456,7 +432,6 @@ async fn main() -> anyhow::Result<()> { } } - // Configure direct capture for WebRTC encoder pipeline let (device_path, actual_resolution, actual_format, actual_fps, jpeg_quality) = streamer.current_capture_config().await; tracing::info!( @@ -495,14 +470,13 @@ async fn main() -> anyhow::Result<()> { tracing::warn!("No capture device configured for WebRTC"); } - // Create video stream manager (unified MJPEG/WebRTC management) - // Use with_webrtc_streamer to ensure we use the same WebRtcStreamer instance - let stream_manager = - VideoStreamManager::with_webrtc_streamer(streamer.clone(), webrtc_streamer.clone()); + let stream_manager = VideoStreamManager::with_webrtc_streamer( + streamer.clone(), + webrtc_streamer.clone() as std::sync::Arc, + ); stream_manager.set_event_bus(events.clone()).await; stream_manager.set_config_store(config_store.clone()).await; - // Initialize stream manager with configured mode let initial_mode = config.stream.mode.clone(); if let Err(e) = stream_manager.init_with_mode(initial_mode.clone()).await { tracing::warn!( @@ -517,7 +491,6 @@ async fn main() -> anyhow::Result<()> { ); } - // Create RustDesk service (optional, based on config) let rustdesk = if config.rustdesk.is_valid() { tracing::info!( "Initializing RustDesk service: ID={} -> {}", @@ -542,7 +515,6 @@ async fn main() -> anyhow::Result<()> { None }; - // Create RTSP service (optional, based on config) let rtsp = if config.rtsp.enabled { tracing::info!( "Initializing RTSP service: rtsp://{}:{}/{}", @@ -557,15 +529,16 @@ async fn main() -> anyhow::Result<()> { None }; - // Create application state let update_service = Arc::new(UpdateService::new(data_dir.join("updates"))); let state = AppState::new( + db.clone(), config_store.clone(), session_store, user_store, otg_service, stream_manager, + webrtc_streamer.clone(), hid, msd, atx, @@ -581,12 +554,10 @@ async fn main() -> anyhow::Result<()> { extensions.set_event_bus(events.clone()).await; - // Start RustDesk service if enabled if let Some(ref service) = rustdesk { if let Err(e) = service.start().await { tracing::error!("Failed to start RustDesk service: {}", e); } else { - // Save generated keypair and UUID to config if let Some(updated_config) = service.save_credentials() { if let Err(e) = config_store .update(|cfg| { @@ -606,7 +577,6 @@ async fn main() -> anyhow::Result<()> { } } - // Start RTSP service if enabled if let Some(ref service) = rtsp { if let Err(e) = service.start().await { tracing::error!("Failed to start RTSP service: {}", e); @@ -615,7 +585,6 @@ async fn main() -> anyhow::Result<()> { } } - // Enforce startup codec constraints (e.g. RTSP/RustDesk locks) { let runtime_config = state.config.get(); let constraints = StreamCodecConstraints::from_config(&runtime_config); @@ -630,13 +599,11 @@ async fn main() -> anyhow::Result<()> { } } - // Start enabled extensions { let ext_config = config_store.get(); extensions.start_enabled(&ext_config.extensions).await; } - // Start extension health check task (every 30 seconds) { let extensions_clone = extensions.clone(); let config_store_clone = config_store.clone(); @@ -653,17 +620,12 @@ async fn main() -> anyhow::Result<()> { state.publish_device_info().await; - // Start device info broadcast task - // This monitors state change events and broadcasts DeviceInfo to all clients spawn_device_info_broadcaster(state.clone(), events); - // Create router let app = web::create_router(state.clone()); - // Bind sockets for configured addresses let listeners = bind_tcp_listeners(&bind_ips, bind_port)?; - // Setup graceful shutdown let shutdown_signal = async move { tokio::signal::ctrl_c() .await @@ -672,9 +634,7 @@ async fn main() -> anyhow::Result<()> { let _ = shutdown_tx.send(()); }; - // Start server if config.web.https_enabled { - // Generate self-signed certificate if no custom cert provided let tls_config = if let (Some(cert_path), Some(key_path)) = (&config.web.ssl_cert_path, &config.web.ssl_key_path) { @@ -684,7 +644,6 @@ async fn main() -> anyhow::Result<()> { let cert_path = cert_dir.join("server.crt"); let key_path = cert_dir.join("server.key"); - // Check if certificate already exists, only generate if missing if !cert_path.exists() || !key_path.exists() { tracing::info!("Generating new self-signed TLS certificate"); let cert = generate_self_signed_cert()?; @@ -698,7 +657,7 @@ async fn main() -> anyhow::Result<()> { RustlsConfig::from_pem_file(&cert_path, &key_path).await? }; - let mut servers = FuturesUnordered::new(); + let servers = FuturesUnordered::new(); for listener in listeners { let local_addr = listener.local_addr()?; tracing::info!("Starting HTTPS server on {}", local_addr); @@ -708,19 +667,9 @@ async fn main() -> anyhow::Result<()> { servers.push(server); } - tokio::select! { - _ = shutdown_signal => { - cleanup(&state).await; - } - result = servers.next() => { - if let Some(Err(e)) = result { - tracing::error!("HTTPS server error: {}", e); - } - cleanup(&state).await; - } - } + run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTPS").await; } else { - let mut servers = FuturesUnordered::new(); + let servers = FuturesUnordered::new(); for listener in listeners { let local_addr = listener.local_addr()?; tracing::info!("Starting HTTP server on {}", local_addr); @@ -730,26 +679,14 @@ async fn main() -> anyhow::Result<()> { servers.push(async move { server.await }); } - tokio::select! { - _ = shutdown_signal => { - cleanup(&state).await; - } - result = servers.next() => { - if let Some(Err(e)) = result { - tracing::error!("HTTP server error: {}", e); - } - cleanup(&state).await; - } - } + run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTP").await; } tracing::info!("Server shutdown complete"); Ok(()) } -/// Initialize logging with tracing fn init_logging(level: LogLevel, verbose_count: u8) { - // Verbose count overrides log level let effective_level = match verbose_count { 0 => level, 1 => LogLevel::Verbose, @@ -757,7 +694,6 @@ fn init_logging(level: LogLevel, verbose_count: u8) { _ => LogLevel::Trace, }; - // Build filter string based on effective level let filter = match effective_level { LogLevel::Error => "one_kvm=error,tower_http=error", LogLevel::Warn => "one_kvm=warn,tower_http=warn", @@ -767,7 +703,6 @@ fn init_logging(level: LogLevel, verbose_count: u8) { LogLevel::Trace => "one_kvm=trace,tower_http=debug", }; - // Environment variable takes highest priority let env_filter = tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| filter.into()); @@ -780,23 +715,48 @@ fn init_logging(level: LogLevel, verbose_count: u8) { } } -/// Get the application data directory fn get_data_dir() -> PathBuf { - // Check environment variable first if let Ok(path) = std::env::var("ONE_KVM_DATA_DIR") { return PathBuf::from(path); } - // Default to system configuration directory PathBuf::from("/etc/one-kvm") } +async fn open_database_pool(data_dir: &Path) -> anyhow::Result { + let db_path = data_dir.join("one-kvm.db"); + let db = DatabasePool::new(&db_path).await?; + db.init_schema().await?; + Ok(db) +} + +async fn run_servers_until_shutdown( + mut servers: FuturesUnordered, + shutdown_signal: impl Future, + state: &Arc, + protocol: &'static str, +) where + F: Future> + Send, + E: std::fmt::Display, +{ + tokio::select! { + _ = shutdown_signal => { + cleanup(state).await; + } + result = servers.next() => { + if let Some(Err(e)) = result { + tracing::error!("{} server error: {}", protocol, e); + } + cleanup(state).await; + } + } +} + async fn run_cli_command(command: CliCommand, data_dir: PathBuf) -> anyhow::Result<()> { tokio::fs::create_dir_all(&data_dir).await?; - let db_path = data_dir.join("one-kvm.db"); - let config_store = ConfigStore::new(&db_path).await?; - let users = UserStore::new(config_store.pool().clone()); - let sessions = SessionStore::new(config_store.pool().clone(), 0); + let db = open_database_pool(&data_dir).await?; + let users = UserStore::new(db.clone_pool()); + let sessions = SessionStore::new(0); match command { CliCommand::User(user) => run_user_action(user.action, &users, &sessions).await, @@ -814,15 +774,10 @@ async fn run_user_action( } async fn set_user_password(users: &UserStore, sessions: &SessionStore) -> anyhow::Result<()> { - let all = users.list().await?; - let user = match all.len() { - 0 => anyhow::bail!("No local user exists yet; complete setup in the web UI first."), - 1 => &all[0], - _ => anyhow::bail!( - "Expected exactly one local user (single-user design), found {}. Remove extra users from the database or contact support.", - all.len() - ), - }; + let user = users + .single_user() + .await? + .ok_or_else(|| anyhow::anyhow!("No local user exists yet; complete setup in the web UI first."))?; let new_password = read_new_password_interactive()?; if new_password.len() < 4 { @@ -830,7 +785,7 @@ async fn set_user_password(users: &UserStore, sessions: &SessionStore) -> anyhow } users.update_password(&user.id, &new_password).await?; - let revoked = sessions.delete_by_user_id(&user.id).await?; + let revoked = sessions.delete_all().await?; tracing::info!( "Password updated for user '{}' and {} sessions revoked", @@ -866,7 +821,6 @@ fn read_new_password_interactive() -> anyhow::Result { Ok(a) } -/// Resolve bind IPs from config, preferring bind_addresses when set. fn resolve_bind_addresses(web: &config::WebConfig) -> anyhow::Result> { let raw_addrs = if !web.bind_addresses.is_empty() { web.bind_addresses.as_slice() @@ -907,7 +861,6 @@ fn bind_tcp_listeners(addrs: &[IpAddr], port: u16) -> anyhow::Result (PixelFormat, Resolution) { let format = config .video @@ -919,7 +872,6 @@ fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) { (format, resolution) } -/// Generate a self-signed TLS certificate fn generate_self_signed_cert() -> anyhow::Result> { use rcgen::generate_simple_self_signed; @@ -933,8 +885,6 @@ fn generate_self_signed_cert() -> anyhow::Result, events: Arc) { use std::time::{Duration, Instant}; @@ -1021,7 +971,6 @@ fn spawn_device_info_broadcaster(state: Arc, events: Arc) { let mut pending_broadcast = false; loop { - // Use timeout to handle pending broadcasts let recv_result = if pending_broadcast { let remaining = DEBOUNCE_MS.saturating_sub(last_broadcast.elapsed().as_millis() as u64); @@ -1046,12 +995,9 @@ fn spawn_device_info_broadcaster(state: Arc, events: Arc) { tracing::info!("Event bus closed, stopping DeviceInfo broadcaster"); break; } - Err(_timeout) => { - // Debounce timeout reached, broadcast now - } + Err(_timeout) => {} } - // Broadcast if pending and debounce time has passed if pending_broadcast && last_broadcast.elapsed() >= Duration::from_millis(DEBOUNCE_MS) { state.publish_device_info().await; tracing::trace!("Broadcasted DeviceInfo (debounced)"); @@ -1067,13 +1013,10 @@ fn spawn_device_info_broadcaster(state: Arc, events: Arc) { ); } -/// Clean up subsystems on shutdown async fn cleanup(state: &Arc) { - // Stop all extensions state.extensions.stop_all().await; tracing::info!("Extensions stopped"); - // Stop RustDesk service if let Some(ref service) = *state.rustdesk.read().await { if let Err(e) = service.stop().await { tracing::warn!("Failed to stop RustDesk service: {}", e); @@ -1082,7 +1025,6 @@ async fn cleanup(state: &Arc) { } } - // Stop RTSP service if let Some(ref service) = *state.rtsp.read().await { if let Err(e) = service.stop().await { tracing::warn!("Failed to stop RTSP service: {}", e); @@ -1091,31 +1033,26 @@ async fn cleanup(state: &Arc) { } } - // Stop video if let Err(e) = state.stream_manager.stop().await { tracing::warn!("Failed to stop streamer: {}", e); } - // Shutdown HID if let Err(e) = state.hid.shutdown().await { tracing::warn!("Failed to shutdown HID: {}", e); } - // Shutdown MSD if let Some(msd) = state.msd.write().await.as_mut() { if let Err(e) = msd.shutdown().await { tracing::warn!("Failed to shutdown MSD: {}", e); } } - // Shutdown ATX if let Some(atx) = state.atx.write().await.as_mut() { if let Err(e) = atx.shutdown().await { tracing::warn!("Failed to shutdown ATX: {}", e); } } - // Shutdown Audio if let Err(e) = state.audio.shutdown().await { tracing::warn!("Failed to shutdown audio: {}", e); } diff --git a/src/modules/mod.rs b/src/modules/mod.rs deleted file mode 100644 index 8e15b5c9..00000000 --- a/src/modules/mod.rs +++ /dev/null @@ -1,49 +0,0 @@ -//! Module management for One-KVM -//! -//! This module provides infrastructure for managing feature modules -//! (video streaming, HID control, MSD, ATX) as independent async tasks. - -use std::future::Future; -use std::pin::Pin; -use tokio::sync::broadcast; - -/// Module status -#[derive(Debug, Clone, PartialEq)] -pub enum ModuleStatus { - Stopped, - Starting, - Running, - Stopping, - Error(String), -} - -/// Trait for feature modules -pub trait Module: Send + Sync { - /// Module name - fn name(&self) -> &'static str; - - /// Current status - fn status(&self) -> ModuleStatus; - - /// Start the module - fn start(&mut self) -> Pin> + Send + '_>>; - - /// Stop the module - fn stop(&mut self) -> Pin> + Send + '_>>; -} - -/// Module manager for coordinating feature modules -pub struct ModuleManager { - shutdown_rx: broadcast::Receiver<()>, -} - -impl ModuleManager { - pub fn new(shutdown_rx: broadcast::Receiver<()>) -> Self { - Self { shutdown_rx } - } - - /// Wait for shutdown signal - pub async fn wait_for_shutdown(&mut self) { - let _ = self.shutdown_rx.recv().await; - } -} diff --git a/src/msd/controller.rs b/src/msd/controller.rs index 0be5428d..5498b217 100644 --- a/src/msd/controller.rs +++ b/src/msd/controller.rs @@ -1,11 +1,3 @@ -//! MSD Controller -//! -//! Manages the mass storage device lifecycle including: -//! - Image mounting and unmounting -//! - Virtual drive management -//! - State tracking -//! - Image downloads from URL - use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; @@ -14,41 +6,25 @@ use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; use super::image::ImageManager; -use super::monitor::{MsdHealthMonitor, MsdHealthStatus}; +use super::monitor::MsdHealthMonitor; use super::types::{DownloadProgress, DownloadStatus, DriveInfo, ImageInfo, MsdMode, MsdState}; use crate::error::{AppError, Result}; use crate::otg::{MsdFunction, MsdLunConfig, OtgService}; -/// MSD Controller pub struct MsdController { - /// OTG Service reference otg_service: Arc, - /// MSD function manager (provided by OtgService) msd_function: RwLock>, - /// Current state state: RwLock, - /// Images storage path images_path: PathBuf, - /// Ventoy directory path ventoy_dir: PathBuf, - /// Virtual drive path drive_path: PathBuf, - /// Event bus for broadcasting state changes (optional) events: tokio::sync::RwLock>>, - /// Active downloads (download_id -> CancellationToken) downloads: Arc>>, - /// Operation mutex lock (prevents concurrent operations) operation_lock: Arc>, - /// Health monitor for error tracking and recovery monitor: Arc, } impl MsdController { - /// Create new MSD controller - /// - /// # Parameters - /// * `otg_service` - OTG service for gadget management - /// * `msd_dir` - Base directory for MSD storage pub fn new(otg_service: Arc, msd_dir: impl Into) -> Self { let msd_dir = msd_dir.into(); let images_path = msd_dir.join("images"); @@ -68,11 +44,9 @@ impl MsdController { } } - /// Initialize the MSD controller pub async fn init(&self) -> Result<()> { info!("Initializing MSD controller"); - // 1. Ensure images directory exists if let Err(e) = std::fs::create_dir_all(&self.images_path) { warn!("Failed to create images directory: {}", e); } @@ -80,20 +54,16 @@ impl MsdController { warn!("Failed to create ventoy directory: {}", e); } - // 2. Get active MSD function from OtgService info!("Fetching MSD function from OtgService"); let msd_func = self.otg_service.msd_function().await.ok_or_else(|| { AppError::Internal("MSD function is not active in OtgService".to_string()) })?; - // 3. Store function handle *self.msd_function.write().await = Some(msd_func); - // 4. Update state let mut state = self.state.write().await; state.available = true; - // 5. Check for existing virtual drive if self.drive_path.exists() { if let Ok(metadata) = std::fs::metadata(&self.drive_path) { state.drive_info = Some(DriveInfo { @@ -114,17 +84,14 @@ impl MsdController { Ok(()) } - /// Get current MSD state pub async fn state(&self) -> MsdState { self.state.read().await.clone() } - /// Set event bus for broadcasting state changes pub async fn set_event_bus(&self, events: std::sync::Arc) { *self.events.write().await = Some(events); } - /// Publish an event to the event bus async fn publish_event(&self, event: crate::events::SystemEvent) { if let Some(ref bus) = *self.events.read().await { bus.publish(event); @@ -137,43 +104,21 @@ impl MsdController { } } - /// Check if MSD is available pub async fn is_available(&self) -> bool { self.state.read().await.available } - /// Connect an image file - /// - /// # Parameters - /// * `image` - Image info to mount - /// * `cdrom` - Mount as CD-ROM (read-only, removable) - /// * `read_only` - Mount as read-only pub async fn connect_image( &self, image: &ImageInfo, cdrom: bool, read_only: bool, ) -> Result<()> { - // Acquire operation lock to prevent concurrent operations let _op_guard = self.operation_lock.write().await; - let mut state = self.state.write().await; - if !state.available { - let err = AppError::Internal("MSD not available".to_string()); - self.monitor - .report_error("MSD not available", "not_available") - .await; - return Err(err); - } + self.assert_can_connect(&state).await?; - if state.connected { - return Err(AppError::Internal( - "Already connected. Disconnect first.".to_string(), - )); - } - - // Verify image exists if !image.path.exists() { let error_msg = format!("Image file not found: {}", image.path.display()); self.monitor @@ -182,29 +127,12 @@ impl MsdController { return Err(AppError::Internal(error_msg)); } - // Configure LUN let config = if cdrom { MsdLunConfig::cdrom(image.path.clone()) } else { MsdLunConfig::disk(image.path.clone(), read_only) }; - - let gadget_path = self.active_gadget_path().await?; - if let Some(ref msd) = *self.msd_function.read().await { - if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await { - let error_msg = format!("Failed to configure LUN: {}", e); - self.monitor - .report_error(&error_msg, "configfs_error") - .await; - return Err(e); - } - } else { - let err = AppError::Internal("MSD function not initialized".to_string()); - self.monitor - .report_error("MSD function not initialized", "not_initialized") - .await; - return Err(err); - } + self.configure_lun_now(&config).await?; state.connected = true; state.mode = MsdMode::Image; @@ -215,42 +143,19 @@ impl MsdController { image.name, cdrom, read_only ); - // Release the lock before publishing events drop(state); drop(_op_guard); - // Report recovery if we were in an error state - if self.monitor.is_error().await { - self.monitor.report_recovered().await; - } - - self.mark_device_info_dirty().await; - + self.finish_connect_success().await; Ok(()) } - /// Connect the virtual drive pub async fn connect_drive(&self) -> Result<()> { - // Acquire operation lock to prevent concurrent operations let _op_guard = self.operation_lock.write().await; - let mut state = self.state.write().await; - if !state.available { - let err = AppError::Internal("MSD not available".to_string()); - self.monitor - .report_error("MSD not available", "not_available") - .await; - return Err(err); - } + self.assert_can_connect(&state).await?; - if state.connected { - return Err(AppError::Internal( - "Already connected. Disconnect first.".to_string(), - )); - } - - // Check drive exists if !self.drive_path.exists() { let err = AppError::Internal("Virtual drive not initialized. Call init first.".to_string()); @@ -260,25 +165,8 @@ impl MsdController { return Err(err); } - // Configure LUN as read-write disk let config = MsdLunConfig::disk(self.drive_path.clone(), false); - - let gadget_path = self.active_gadget_path().await?; - if let Some(ref msd) = *self.msd_function.read().await { - if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await { - let error_msg = format!("Failed to configure LUN: {}", e); - self.monitor - .report_error(&error_msg, "configfs_error") - .await; - return Err(e); - } - } else { - let err = AppError::Internal("MSD function not initialized".to_string()); - self.monitor - .report_error("MSD function not initialized", "not_initialized") - .await; - return Err(err); - } + self.configure_lun_now(&config).await?; state.connected = true; state.mode = MsdMode::Drive; @@ -286,23 +174,57 @@ impl MsdController { info!("Connected virtual drive: {}", self.drive_path.display()); - // Release the lock before publishing event drop(state); drop(_op_guard); - // Report recovery if we were in an error state - if self.monitor.is_error().await { - self.monitor.report_recovered().await; - } - - self.mark_device_info_dirty().await; - + self.finish_connect_success().await; Ok(()) } - /// Disconnect current storage + async fn assert_can_connect(&self, state: &MsdState) -> Result<()> { + if !state.available { + self.monitor + .report_error("MSD not available", "not_available") + .await; + return Err(AppError::Internal("MSD not available".to_string())); + } + if state.connected { + return Err(AppError::Internal( + "Already connected. Disconnect first.".to_string(), + )); + } + Ok(()) + } + + async fn configure_lun_now(&self, config: &MsdLunConfig) -> Result<()> { + let gadget_path = self.active_gadget_path().await?; + let msd_hold = self.msd_function.read().await; + let Some(ref msd) = *msd_hold else { + self.monitor + .report_error("MSD function not initialized", "not_initialized") + .await; + return Err(AppError::Internal( + "MSD function not initialized".to_string(), + )); + }; + if let Err(e) = msd.configure_lun_async(&gadget_path, 0, config).await { + let error_msg = format!("Failed to configure LUN: {}", e); + self.monitor + .report_error(&error_msg, "configfs_error") + .await; + return Err(e); + } + Ok(()) + } + + async fn finish_connect_success(&self) { + if self.monitor.is_error().await { + self.monitor.report_recovered().await; + } + self.mark_device_info_dirty().await; + } + pub async fn disconnect(&self) -> Result<()> { - // Acquire operation lock to prevent concurrent operations let _op_guard = self.operation_lock.write().await; let mut state = self.state.write().await; @@ -323,7 +245,6 @@ impl MsdController { info!("Disconnected storage"); - // Release the lock before publishing events drop(state); drop(_op_guard); @@ -332,41 +253,31 @@ impl MsdController { Ok(()) } - /// Get images storage path pub fn images_path(&self) -> &PathBuf { &self.images_path } - /// Get ventoy directory path pub fn ventoy_dir(&self) -> &PathBuf { &self.ventoy_dir } - /// Get virtual drive path pub fn drive_path(&self) -> &PathBuf { &self.drive_path } - /// Check if currently connected pub async fn is_connected(&self) -> bool { self.state.read().await.connected } - /// Get current mode pub async fn mode(&self) -> MsdMode { self.state.read().await.mode.clone() } - /// Update drive info pub async fn update_drive_info(&self, info: DriveInfo) { let mut state = self.state.write().await; state.drive_info = Some(info); } - /// Start downloading an image from URL - /// - /// Returns the download_id that can be used to track or cancel the download. - /// Progress is reported via MsdDownloadProgress events. pub async fn download_image( &self, url: String, @@ -375,18 +286,15 @@ impl MsdController { let download_id = uuid::Uuid::new_v4().to_string(); let cancel_token = CancellationToken::new(); - // Register download { let mut downloads = self.downloads.write().await; downloads.insert(download_id.clone(), cancel_token.clone()); } - // Extract filename for initial response let display_filename = filename .clone() .unwrap_or_else(|| url.rsplit('/').next().unwrap_or("download").to_string()); - // Create initial progress let initial_progress = DownloadProgress { download_id: download_id.clone(), url: url.clone(), @@ -398,7 +306,6 @@ impl MsdController { error: None, }; - // Publish started event self.publish_event(crate::events::SystemEvent::MsdDownloadProgress { download_id: download_id.clone(), url: url.clone(), @@ -410,18 +317,15 @@ impl MsdController { }) .await; - // Clone what we need for the spawned task let images_path = self.images_path.clone(); let events = self.events.read().await.clone(); let downloads = self.downloads.clone(); let download_id_clone = download_id.clone(); let url_clone = url.clone(); - // Spawn download task tokio::spawn(async move { let manager = ImageManager::new(images_path); - // Create progress callback let events_for_callback = events.clone(); let download_id_for_callback = download_id_clone.clone(); let url_for_callback = url_clone.clone(); @@ -443,18 +347,15 @@ impl MsdController { } }; - // Run download let result = manager .download_from_url(&url_clone, filename, progress_callback) .await; - // Remove from active downloads { let mut downloads_guard = downloads.write().await; downloads_guard.remove(&download_id_clone); } - // Publish completion event match result { Ok(image_info) => { if let Some(ref bus) = events { @@ -489,7 +390,6 @@ impl MsdController { Ok(initial_progress) } - /// Cancel an active download pub async fn cancel_download(&self, download_id: &str) -> Result<()> { let mut downloads = self.downloads.write().await; @@ -505,12 +405,6 @@ impl MsdController { } } - /// Get list of active download IDs - pub async fn active_downloads(&self) -> Vec { - let downloads = self.downloads.read().await; - downloads.keys().cloned().collect() - } - async fn active_gadget_path(&self) -> Result { self.otg_service .gadget_path() @@ -518,16 +412,13 @@ impl MsdController { .ok_or_else(|| AppError::Internal("OTG gadget path is not available".to_string())) } - /// Shutdown the controller pub async fn shutdown(&self) -> Result<()> { info!("Shutting down MSD controller"); - // 1. Disconnect if connected if let Err(e) = self.disconnect().await { warn!("Error disconnecting during shutdown: {}", e); } - // 2. Clear local state *self.msd_function.write().await = None; let mut state = self.state.write().await; @@ -537,27 +428,9 @@ impl MsdController { Ok(()) } - /// Get the health monitor reference pub fn monitor(&self) -> &Arc { &self.monitor } - - /// Get current health status - pub async fn health_status(&self) -> MsdHealthStatus { - self.monitor.status().await - } - - /// Check if the MSD is healthy - pub async fn is_healthy(&self) -> bool { - self.monitor.is_healthy().await - } -} - -impl Drop for MsdController { - fn drop(&mut self) { - // Cleanup is handled by OtgGadgetManager when the gadget is torn down - // Individual controllers don't need to cleanup the ConfigFS - } } #[cfg(test)] @@ -573,7 +446,6 @@ mod tests { let controller = MsdController::new(otg_service, &msd_dir); - // Check that MSD is not initialized (msd_function is None) let state = controller.state().await; assert!(!state.available); assert!(controller.images_path.ends_with("images")); diff --git a/src/msd/image.rs b/src/msd/image.rs index 88c04e45..1cb4b3da 100644 --- a/src/msd/image.rs +++ b/src/msd/image.rs @@ -1,53 +1,37 @@ -//! Image file manager -//! -//! Handles ISO/IMG image file operations: -//! - List available images -//! - Upload new images -//! - Delete images -//! - Metadata management -//! - Download from URL - use futures::StreamExt; -use time::OffsetDateTime; -use std::fs::{self, File}; -use std::io::{self, Read, Write}; +use std::fs; +#[cfg(test)] +use std::io::Write; use std::path::{Path, PathBuf}; use std::time::{Duration, Instant}; +use time::OffsetDateTime; use tokio::io::AsyncWriteExt; use tracing::info; use super::types::ImageInfo; use crate::error::{AppError, Result}; -/// Maximum image size (32 GB) const MAX_IMAGE_SIZE: u64 = 32 * 1024 * 1024 * 1024; -/// Progress report throttle interval (milliseconds) const PROGRESS_THROTTLE_MS: u64 = 200; -/// Progress report throttle bytes threshold (512 KB) const PROGRESS_THROTTLE_BYTES: u64 = 512 * 1024; -/// Image Manager pub struct ImageManager { - /// Images storage directory images_path: PathBuf, } impl ImageManager { - /// Create a new image manager pub fn new(images_path: PathBuf) -> Self { Self { images_path } } - /// Ensure images directory exists pub fn ensure_dir(&self) -> Result<()> { fs::create_dir_all(&self.images_path) .map_err(|e| AppError::Internal(format!("Failed to create images directory: {}", e)))?; Ok(()) } - /// List all available images pub fn list(&self) -> Result> { self.ensure_dir()?; @@ -68,19 +52,16 @@ impl ImageManager { } } - // Sort by creation time (newest first) images.sort_by(|a, b| b.created_at.cmp(&a.created_at)); Ok(images) } - /// Get image info from path fn get_image_info(&self, path: &Path) -> Option { let metadata = fs::metadata(path).ok()?; let name = path.file_name()?.to_string_lossy().to_string(); - // Use filename hash as ID (stable across restarts) - let id = format!("{:x}", md5_hash(&name)); + let id = stable_image_id_from_filename(&name); let created_at = metadata .created() @@ -101,7 +82,6 @@ impl ImageManager { }) } - /// Get image by ID pub fn get(&self, id: &str) -> Result { for image in self.list()? { if image.id == id { @@ -111,24 +91,21 @@ impl ImageManager { Err(AppError::NotFound(format!("Image not found: {}", id))) } - /// Get image by name pub fn get_by_name(&self, name: &str) -> Result { let path = self.images_path.join(name); self.get_image_info(&path) .ok_or_else(|| AppError::NotFound(format!("Image not found: {}", name))) } - /// Create a new image from bytes - pub fn create(&self, name: &str, data: &[u8]) -> Result { + #[cfg(test)] + fn create(&self, name: &str, data: &[u8]) -> Result { self.ensure_dir()?; - // Validate name let name = sanitize_filename(name); if name.is_empty() { return Err(AppError::Internal("Invalid filename".to_string())); } - // Check size if data.len() as u64 > MAX_IMAGE_SIZE { return Err(AppError::Internal(format!( "Image too large. Maximum size: {} GB", @@ -136,7 +113,6 @@ impl ImageManager { ))); } - // Write file let path = self.images_path.join(&name); if path.exists() { return Err(AppError::Internal(format!( @@ -145,11 +121,10 @@ impl ImageManager { ))); } - let mut file = File::create(&path) + let mut file = fs::File::create(&path) .map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?; file.write_all(data).map_err(|e| { - // Try to clean up on error let _ = fs::remove_file(&path); AppError::Internal(format!("Failed to write image data: {}", e)) })?; @@ -159,55 +134,6 @@ impl ImageManager { self.get_by_name(&name) } - /// Create a new image from a file stream (for chunked uploads) - pub fn create_from_stream( - &self, - name: &str, - reader: &mut R, - expected_size: Option, - ) -> Result { - self.ensure_dir()?; - - let name = sanitize_filename(name); - if name.is_empty() { - return Err(AppError::Internal("Invalid filename".to_string())); - } - - if let Some(size) = expected_size { - if size > MAX_IMAGE_SIZE { - return Err(AppError::Internal(format!( - "Image too large. Maximum size: {} GB", - MAX_IMAGE_SIZE / 1024 / 1024 / 1024 - ))); - } - } - - let path = self.images_path.join(&name); - if path.exists() { - return Err(AppError::Internal(format!( - "Image already exists: {}", - name - ))); - } - - // Create file and copy data - let mut file = File::create(&path) - .map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?; - - let bytes_written = io::copy(reader, &mut file).map_err(|e| { - let _ = fs::remove_file(&path); - AppError::Internal(format!("Failed to write image data: {}", e)) - })?; - - info!("Created image: {} ({} bytes)", name, bytes_written); - - self.get_by_name(&name) - } - - /// Create a new image from an async multipart field (streaming, memory-efficient) - /// - /// This method streams data directly to disk without buffering the entire file in memory, - /// making it suitable for large files (multi-GB ISOs). pub async fn create_from_multipart_field( &self, name: &str, @@ -220,12 +146,10 @@ impl ImageManager { return Err(AppError::Internal("Invalid filename".to_string())); } - // Use a temporary file during upload let temp_name = format!(".upload_{}", uuid::Uuid::new_v4()); let temp_path = self.images_path.join(&temp_name); let final_path = self.images_path.join(&name); - // Check if final file already exists if final_path.exists() { return Err(AppError::Internal(format!( "Image already exists: {}", @@ -233,23 +157,19 @@ impl ImageManager { ))); } - // Create temp file let mut file = tokio::fs::File::create(&temp_path) .await .map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?; let mut bytes_written: u64 = 0; - // Stream chunks directly to disk while let Some(chunk) = field .chunk() .await .map_err(|e| AppError::Internal(format!("Failed to read upload chunk: {}", e)))? { - // Check size limit bytes_written += chunk.len() as u64; if bytes_written > MAX_IMAGE_SIZE { - // Cleanup and return error drop(file); let _ = tokio::fs::remove_file(&temp_path).await; return Err(AppError::Internal(format!( @@ -258,19 +178,16 @@ impl ImageManager { ))); } - // Write chunk to file file.write_all(&chunk) .await .map_err(|e| AppError::Internal(format!("Failed to write chunk: {}", e)))?; } - // Flush and close file file.flush() .await .map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?; drop(file); - // Move temp file to final location tokio::fs::rename(&temp_path, &final_path) .await .map_err(|e| { @@ -286,7 +203,6 @@ impl ImageManager { self.get_by_name(&name) } - /// Delete an image by ID pub fn delete(&self, id: &str) -> Result<()> { let image = self.get(id)?; @@ -297,45 +213,6 @@ impl ImageManager { Ok(()) } - /// Delete an image by name - pub fn delete_by_name(&self, name: &str) -> Result<()> { - let path = self.images_path.join(name); - - if !path.exists() { - return Err(AppError::NotFound(format!("Image not found: {}", name))); - } - - fs::remove_file(&path) - .map_err(|e| AppError::Internal(format!("Failed to delete image: {}", e)))?; - - info!("Deleted image: {}", name); - Ok(()) - } - - /// Get total storage used - pub fn used_space(&self) -> u64 { - self.list() - .map(|images| images.iter().map(|i| i.size).sum()) - .unwrap_or(0) - } - - /// Check if storage has space for new image - pub fn has_space(&self, size: u64) -> bool { - // For now, just check against max size - // In the future, could check disk space - size <= MAX_IMAGE_SIZE - } - - /// Download image from URL with progress callback - /// - /// # Arguments - /// * `url` - The URL to download from - /// * `filename` - Optional custom filename (extracted from URL or Content-Disposition if not provided) - /// * `progress_callback` - Callback function called with (bytes_downloaded, total_bytes) - /// - /// # Returns - /// * `Ok(ImageInfo)` - The downloaded image info - /// * `Err(AppError)` - If download fails pub async fn download_from_url( &self, url: &str, @@ -347,20 +224,17 @@ impl ImageManager { { self.ensure_dir()?; - // Validate URL let parsed_url = reqwest::Url::parse(url) .map_err(|e| AppError::BadRequest(format!("Invalid URL: {}", e)))?; info!("Starting download from: {}", url); - // Create HTTP client with timeout let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(3600)) // 1 hour timeout for large files + .timeout(std::time::Duration::from_secs(3600)) .connect_timeout(std::time::Duration::from_secs(30)) .build() .map_err(|e| AppError::Internal(format!("Failed to create HTTP client: {}", e)))?; - // Send HEAD request first to get content info let head_response = client .head(url) .send() @@ -380,7 +254,6 @@ impl ImageManager { .and_then(|v| v.to_str().ok()) .and_then(|s| s.parse::().ok()); - // Check file size if let Some(size) = total_size { if size > MAX_IMAGE_SIZE { return Err(AppError::BadRequest(format!( @@ -391,11 +264,9 @@ impl ImageManager { } } - // Determine filename let final_filename = if let Some(name) = filename { sanitize_filename(&name) } else { - // Try Content-Disposition header first let from_header = head_response .headers() .get(reqwest::header::CONTENT_DISPOSITION) @@ -405,7 +276,6 @@ impl ImageManager { if let Some(name) = from_header { sanitize_filename(&name) } else { - // Fall back to URL path let path = parsed_url.path(); let name = path.rsplit('/').next().unwrap_or("download"); let name = urlencoding::decode(name).unwrap_or_else(|_| name.into()); @@ -419,7 +289,6 @@ impl ImageManager { )); } - // Check if file already exists let final_path = self.images_path.join(&final_filename); if final_path.exists() { return Err(AppError::BadRequest(format!( @@ -428,11 +297,9 @@ impl ImageManager { ))); } - // Create temporary file for download let temp_filename = format!(".download_{}", uuid::Uuid::new_v4()); let temp_path = self.images_path.join(&temp_filename); - // Start actual download let response = client .get(url) .send() @@ -446,7 +313,6 @@ impl ImageManager { ))); } - // Get actual content length from response (may differ from HEAD) let content_length = response .headers() .get(reqwest::header::CONTENT_LENGTH) @@ -454,19 +320,16 @@ impl ImageManager { .and_then(|s| s.parse::().ok()) .or(total_size); - // Create temp file let mut file = tokio::fs::File::create(&temp_path) .await .map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?; - // Stream download with progress (throttled) let mut stream = response.bytes_stream(); let mut downloaded: u64 = 0; let mut last_report_time = Instant::now(); let mut last_reported_bytes: u64 = 0; let throttle_interval = Duration::from_millis(PROGRESS_THROTTLE_MS); - // Report initial progress progress_callback(0, content_length); while let Some(chunk_result) = stream.next().await { @@ -474,14 +337,12 @@ impl ImageManager { chunk_result.map_err(|e| AppError::Internal(format!("Download error: {}", e)))?; file.write_all(&chunk).await.map_err(|e| { - // Cleanup on error let _ = std::fs::remove_file(&temp_path); AppError::Internal(format!("Failed to write data: {}", e)) })?; downloaded += chunk.len() as u64; - // Throttled progress reporting: report if enough time or bytes have passed let now = Instant::now(); let time_elapsed = now.duration_since(last_report_time) >= throttle_interval; let bytes_elapsed = downloaded - last_reported_bytes >= PROGRESS_THROTTLE_BYTES; @@ -493,18 +354,15 @@ impl ImageManager { } } - // Always report final progress if downloaded != last_reported_bytes { progress_callback(downloaded, content_length); } - // Ensure all data is flushed file.flush() .await .map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?; drop(file); - // Verify downloaded size let metadata = tokio::fs::metadata(&temp_path) .await .map_err(|e| AppError::Internal(format!("Failed to read file metadata: {}", e)))?; @@ -520,7 +378,6 @@ impl ImageManager { } } - // Move temp file to final location tokio::fs::rename(&temp_path, &final_path) .await .map_err(|e| { @@ -534,35 +391,29 @@ impl ImageManager { metadata.len() ); - // Return image info self.get_by_name(&final_filename) } - /// Get images storage path pub fn images_path(&self) -> &PathBuf { &self.images_path } } -/// Simple hash function for generating stable IDs -fn md5_hash(s: &str) -> u64 { +fn stable_image_id_from_filename(name: &str) -> String { let mut hash: u64 = 0; - for (i, byte) in s.bytes().enumerate() { + for (i, byte) in name.bytes().enumerate() { hash = hash.wrapping_add((byte as u64).wrapping_mul((i as u64).wrapping_add(1))); hash = hash.wrapping_mul(31); } - hash + format!("{:x}", hash) } -/// Sanitize filename to prevent path traversal fn sanitize_filename(name: &str) -> String { let name = name.trim(); let name = name.replace(['/', '\\', '\0', ':', '*', '?', '"', '<', '>', '|'], "_"); - // Remove leading dots (hidden files) let name = name.trim_start_matches('.'); - // Limit length if name.len() > 255 { name[..255].to_string() } else { @@ -570,17 +421,10 @@ fn sanitize_filename(name: &str) -> String { } } -/// Extract filename from Content-Disposition header fn extract_filename_from_content_disposition(header: &str) -> Option { - // Handle both: - // Content-Disposition: attachment; filename="example.iso" - // Content-Disposition: attachment; filename*=UTF-8''example.iso - - // Try filename* first (RFC 5987) if let Some(pos) = header.find("filename*=") { let start = pos + 10; let value = &header[start..]; - // Format: charset'language'value if let Some(quote_start) = value.find("''") { let encoded = value[quote_start + 2..].split(';').next()?; let decoded = urlencoding::decode(encoded.trim()).ok()?; @@ -591,7 +435,6 @@ fn extract_filename_from_content_disposition(header: &str) -> Option { } } - // Try filename next if let Some(pos) = header.find("filename=") { let start = pos + 9; let value = &header[start..]; @@ -613,7 +456,7 @@ mod tests { #[test] fn test_sanitize_filename() { assert_eq!(sanitize_filename("test.iso"), "test.iso"); - assert_eq!(sanitize_filename("../test.iso"), "_test.iso"); // .. becomes empty after trim_start_matches('.') + assert_eq!(sanitize_filename("../test.iso"), "_test.iso"); assert_eq!(sanitize_filename("test/file.iso"), "test_file.iso"); assert_eq!(sanitize_filename(".hidden.iso"), "hidden.iso"); } diff --git a/src/msd/mod.rs b/src/msd/mod.rs index 1359209f..fc162890 100644 --- a/src/msd/mod.rs +++ b/src/msd/mod.rs @@ -1,19 +1,3 @@ -//! MSD (Mass Storage Device) module -//! -//! Provides virtual USB storage functionality with two modes: -//! - Image mounting: Mount ISO/IMG files for system installation -//! - Ventoy drive: Bootable exFAT drive for multiple ISO files -//! -//! Architecture: -//! ```text -//! Web API --> MSD Controller --> ConfigFS Mass Storage --> Target PC -//! | -//! ┌──────┴──────┐ -//! │ │ -//! Image Manager Ventoy Drive -//! (ISO/IMG) (Bootable exFAT) -//! ``` - pub mod controller; pub mod image; pub mod monitor; @@ -22,12 +6,11 @@ pub mod ventoy_drive; pub use controller::MsdController; pub use image::ImageManager; -pub use monitor::{MsdHealthMonitor, MsdHealthStatus, MsdMonitorConfig}; +pub use monitor::MsdHealthMonitor; pub use types::{ DownloadProgress, DownloadStatus, DriveFile, DriveInfo, DriveInitRequest, ImageDownloadRequest, ImageInfo, MsdConnectRequest, MsdMode, MsdState, }; pub use ventoy_drive::VentoyDrive; -// Re-export from otg module for backward compatibility pub use crate::otg::{MsdFunction, MsdLunConfig}; diff --git a/src/msd/monitor.rs b/src/msd/monitor.rs index 46030b40..39166d2c 100644 --- a/src/msd/monitor.rs +++ b/src/msd/monitor.rs @@ -1,99 +1,46 @@ -//! MSD (Mass Storage Device) health monitoring -//! -//! This module provides health monitoring for MSD operations, including: -//! - ConfigFS operation error tracking -//! - Image mount/unmount error tracking -//! - Error state tracking -//! - Log throttling to prevent log flooding - use std::sync::atomic::{AtomicU32, Ordering}; use tokio::sync::RwLock; use tracing::{info, warn}; use crate::utils::LogThrottler; -/// MSD health status +const LOG_THROTTLE_SECS: u64 = 5; + #[derive(Debug, Clone, PartialEq, Default)] -pub enum MsdHealthStatus { - /// Device is healthy and operational +pub(crate) enum MsdHealthStatus { #[default] Healthy, - /// Device has an error Error { - /// Human-readable error reason reason: String, - /// Error code for programmatic handling error_code: String, }, } -/// MSD health monitor configuration -#[derive(Debug, Clone)] -pub struct MsdMonitorConfig { - /// Log throttle interval in seconds - pub log_throttle_secs: u64, -} - -impl Default for MsdMonitorConfig { - fn default() -> Self { - Self { - log_throttle_secs: 5, - } - } -} - -/// MSD health monitor -/// -/// Monitors MSD operation health and manages error state. pub struct MsdHealthMonitor { - /// Current health status status: RwLock, - /// Log throttler to prevent log flooding throttler: LogThrottler, - /// Error count (for tracking) error_count: AtomicU32, - /// Last error code (for change detection) last_error_code: RwLock>, } impl MsdHealthMonitor { - /// Create a new MSD health monitor with the specified configuration - pub fn new(config: MsdMonitorConfig) -> Self { - let throttle_secs = config.log_throttle_secs; + pub fn with_defaults() -> Self { Self { status: RwLock::new(MsdHealthStatus::Healthy), - throttler: LogThrottler::with_secs(throttle_secs), + throttler: LogThrottler::with_secs(LOG_THROTTLE_SECS), error_count: AtomicU32::new(0), last_error_code: RwLock::new(None), } } - /// Create a new MSD health monitor with default configuration - pub fn with_defaults() -> Self { - Self::new(MsdMonitorConfig::default()) - } - - /// Report an error from MSD operations - /// - /// This method is called when an MSD operation fails. It: - /// 1. Updates the health status - /// 2. Logs the error (with throttling) - /// 3. Updates in-memory error state - /// - /// # Arguments - /// - /// * `reason` - Human-readable error description - /// * `error_code` - Error code for programmatic handling pub async fn report_error(&self, reason: &str, error_code: &str) { let count = self.error_count.fetch_add(1, Ordering::Relaxed) + 1; - // Check if error code changed let error_changed = { let last = self.last_error_code.read().await; last.as_ref().map(|s| s.as_str()) != Some(error_code) }; - // Log with throttling (always log if error type changed) let throttle_key = format!("msd_{}", error_code); if error_changed || self.throttler.should_log(&throttle_key) { warn!( @@ -102,29 +49,21 @@ impl MsdHealthMonitor { ); } - // Update last error code *self.last_error_code.write().await = Some(error_code.to_string()); - // Update status *self.status.write().await = MsdHealthStatus::Error { reason: reason.to_string(), error_code: error_code.to_string(), }; } - /// Report that the MSD has recovered from error - /// - /// This method is called when an MSD operation succeeds after errors. - /// It resets the error state. pub async fn report_recovered(&self) { let prev_status = self.status.read().await.clone(); - // Only report recovery if we were in an error state if prev_status != MsdHealthStatus::Healthy { let error_count = self.error_count.load(Ordering::Relaxed); info!("MSD recovered after {} errors", error_count); - // Reset state self.error_count.store(0, Ordering::Relaxed); self.throttler.clear_all(); *self.last_error_code.write().await = None; @@ -132,29 +71,25 @@ impl MsdHealthMonitor { } } - /// Get the current health status - pub async fn status(&self) -> MsdHealthStatus { + #[cfg(test)] + pub(crate) async fn status(&self) -> MsdHealthStatus { self.status.read().await.clone() } - /// Get the current error count - pub fn error_count(&self) -> u32 { + #[cfg(test)] + pub(crate) fn error_count(&self) -> u32 { self.error_count.load(Ordering::Relaxed) } - /// Check if the monitor is in an error state pub async fn is_error(&self) -> bool { matches!(*self.status.read().await, MsdHealthStatus::Error { .. }) } - /// Check if the monitor is healthy - pub async fn is_healthy(&self) -> bool { + #[cfg(test)] + pub(crate) async fn is_healthy(&self) -> bool { matches!(*self.status.read().await, MsdHealthStatus::Healthy) } - /// Reset the monitor to healthy state without publishing events - /// - /// This is useful during initialization. pub async fn reset(&self) { self.error_count.store(0, Ordering::Relaxed); *self.last_error_code.write().await = None; @@ -162,7 +97,6 @@ impl MsdHealthMonitor { self.throttler.clear_all(); } - /// Get the current error message if in error state pub async fn error_message(&self) -> Option { match &*self.status.read().await { MsdHealthStatus::Error { reason, .. } => Some(reason.clone()), @@ -212,13 +146,11 @@ mod tests { async fn test_report_recovered() { let monitor = MsdHealthMonitor::with_defaults(); - // First report an error monitor .report_error("Image not found", "image_not_found") .await; assert!(monitor.is_error().await); - // Then report recovery monitor.report_recovered().await; assert!(monitor.is_healthy().await); assert_eq!(monitor.error_count(), 0); diff --git a/src/msd/types.rs b/src/msd/types.rs index 321222b6..903ce326 100644 --- a/src/msd/types.rs +++ b/src/msd/types.rs @@ -1,42 +1,28 @@ -//! MSD data types and structures - use serde::{Deserialize, Serialize}; use std::path::PathBuf; use time::OffsetDateTime; -/// MSD operating mode -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(rename_all = "snake_case")] -#[derive(Default)] pub enum MsdMode { - /// No storage connected #[default] None, - /// Image file mounted (ISO/IMG) Image, - /// Virtual drive (FAT32) connected Drive, } -/// Image file metadata #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ImageInfo { - /// Unique image ID pub id: String, - /// Display name pub name: String, - /// File path on disk #[serde(skip_serializing)] pub path: PathBuf, - /// File size in bytes pub size: u64, - /// Creation timestamp #[serde(with = "time::serde::rfc3339")] pub created_at: OffsetDateTime, } impl ImageInfo { - /// Create new image info pub fn new(id: String, name: String, path: PathBuf, size: u64) -> Self { Self { id, @@ -47,7 +33,6 @@ impl ImageInfo { } } - /// Format size for display pub fn size_display(&self) -> String { const KB: u64 = 1024; const MB: u64 = KB * 1024; @@ -65,18 +50,12 @@ impl ImageInfo { } } -/// MSD state information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MsdState { - /// Whether MSD feature is available pub available: bool, - /// Current mode pub mode: MsdMode, - /// Whether storage is connected to target pub connected: bool, - /// Currently mounted image (if mode is Image) pub current_image: Option, - /// Virtual drive info (if mode is Drive) pub drive_info: Option, } @@ -92,24 +71,17 @@ impl Default for MsdState { } } -/// Virtual drive information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DriveInfo { - /// Drive size in bytes pub size: u64, - /// Used space in bytes pub used: u64, - /// Free space in bytes pub free: u64, - /// Whether drive is initialized pub initialized: bool, - /// Drive file path #[serde(skip_serializing)] pub path: PathBuf, } impl DriveInfo { - /// Create new drive info pub fn new(path: PathBuf, size: u64) -> Self { Self { size, @@ -121,92 +93,60 @@ impl DriveInfo { } } -/// File entry in virtual drive #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DriveFile { - /// File name pub name: String, - /// Relative path from drive root pub path: String, - /// File size in bytes (0 for directories) pub size: u64, - /// Whether this is a directory pub is_dir: bool, - /// Last modified timestamp #[serde(with = "time::serde::rfc3339::option")] pub modified: Option, } -/// MSD connect request #[derive(Debug, Clone, Deserialize)] pub struct MsdConnectRequest { - /// Connection mode: "image" or "drive" pub mode: MsdMode, - /// Image ID to mount (required for image mode) pub image_id: Option, - /// Mount as CD-ROM (optional, defaults based on image type) #[serde(default)] pub cdrom: Option, - /// Mount as read-only #[serde(default)] pub read_only: Option, } -/// Virtual drive init request #[derive(Debug, Clone, Deserialize)] pub struct DriveInitRequest { - /// Drive size in megabytes (defaults to 16GB) #[serde(default = "default_drive_size")] pub size_mb: u32, - /// Optional custom path for Ventoy installation - pub ventoy_path: Option, } fn default_drive_size() -> u32 { - 16 * 1024 // 16GB + 16 * 1024 } -/// Image download request #[derive(Debug, Clone, Deserialize)] pub struct ImageDownloadRequest { - /// URL to download from pub url: String, - /// Optional custom filename pub filename: Option, } -/// Download status #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum DownloadStatus { - /// Download has started Started, - /// Download is in progress InProgress, - /// Download completed successfully Completed, - /// Download failed Failed, } -/// Download progress information #[derive(Debug, Clone, Serialize)] pub struct DownloadProgress { - /// Unique download ID pub download_id: String, - /// Source URL pub url: String, - /// Target filename pub filename: String, - /// Bytes downloaded so far pub bytes_downloaded: u64, - /// Total file size (None if unknown) pub total_bytes: Option, - /// Progress percentage (0.0 - 100.0, None if total unknown) pub progress_pct: Option, - /// Download status pub status: DownloadStatus, - /// Error message if failed pub error: Option, } @@ -220,7 +160,7 @@ mod tests { "test".into(), "test.iso".into(), PathBuf::from("/tmp/test.iso"), - 1024 * 1024 * 1024 * 2, // 2 GB + 1024 * 1024 * 1024 * 2, ); assert!(info.size_display().contains("GB")); } diff --git a/src/msd/ventoy_drive.rs b/src/msd/ventoy_drive.rs index 20d07494..e3d3dc44 100644 --- a/src/msd/ventoy_drive.rs +++ b/src/msd/ventoy_drive.rs @@ -1,8 +1,3 @@ -//! Ventoy Virtual Drive -//! -//! Replaces FAT32 VirtualDrive with a Ventoy bootable image. -//! Provides a bootable USB with exFAT data partition for ISO files. - use std::path::{Path, PathBuf}; use std::sync::Arc; use tokio::sync::RwLock; @@ -13,33 +8,20 @@ use ventoy_img::{FileInfo as VentoyFileInfo, VentoyError, VentoyImage}; use super::types::{DriveFile, DriveInfo}; use crate::error::{AppError, Result}; -/// Chunk size for streaming reads (64 KB) const STREAM_CHUNK_SIZE: usize = 64 * 1024; -/// Minimum drive size (1 GB) - Ventoy requires space for boot partition const MIN_DRIVE_SIZE_MB: u32 = 1024; -/// Maximum drive size (128 GB) const MAX_DRIVE_SIZE_MB: u32 = 128 * 1024; -/// Default drive label const DEFAULT_LABEL: &str = "ONE-KVM"; -/// Ventoy Drive Manager -/// -/// Thread-safe wrapper around VentoyImage providing async file operations. -/// Uses spawn_blocking for all ventoy-img-rs operations since they are synchronous. -/// Uses RwLock to allow concurrent read operations while serializing writes. pub struct VentoyDrive { - /// Drive image path path: PathBuf, - /// RwLock for concurrent reads, exclusive writes - /// (ventoy-img-rs operations are synchronous and not thread-safe) lock: Arc>, } impl VentoyDrive { - /// Create new Ventoy drive manager pub fn new(path: PathBuf) -> Self { Self { path, @@ -47,40 +29,32 @@ impl VentoyDrive { } } - /// Check if drive image exists pub fn exists(&self) -> bool { self.path.exists() } - /// Get drive path pub fn path(&self) -> &PathBuf { &self.path } - /// Initialize a new Ventoy drive image - /// - /// Creates a bootable Ventoy image with the specified size. - /// The image includes boot partitions and an exFAT data partition. pub async fn init(&self, size_mb: u32) -> Result { let size_mb = size_mb.clamp(MIN_DRIVE_SIZE_MB, MAX_DRIVE_SIZE_MB); let size_str = format!("{}M", size_mb); let path = self.path.clone(); - let _lock = self.lock.write().await; // Write lock for initialization + let _lock = self.lock.write().await; info!("Creating {} MB Ventoy drive at {}", size_mb, path.display()); - // Run Ventoy creation in blocking task let info = tokio::task::spawn_blocking(move || { VentoyImage::create(&path, &size_str, DEFAULT_LABEL).map_err(ventoy_to_app_error)?; - // Get file metadata for DriveInfo let metadata = std::fs::metadata(&path) .map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?; Ok::(DriveInfo { size: metadata.len(), used: 0, - free: metadata.len(), // Approximate - exFAT overhead not calculated + free: metadata.len(), initialized: true, path, }) @@ -92,20 +66,18 @@ impl VentoyDrive { Ok(info) } - /// Get drive information pub async fn info(&self) -> Result { if !self.exists() { return Err(AppError::Internal("Drive not initialized".to_string())); } let path = self.path.clone(); - let _lock = self.lock.read().await; // Read lock for info query + let _lock = self.lock.read().await; tokio::task::spawn_blocking(move || { let metadata = std::fs::metadata(&path) .map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?; - // Open image to get file list and calculate used space let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?; let files = image.list_files_recursive().map_err(ventoy_to_app_error)?; @@ -116,7 +88,6 @@ impl VentoyDrive { .map(|f| f.size) .sum(); - // Note: This is approximate since we don't have exact exFAT overhead let size = metadata.len(); let free = size.saturating_sub(used); @@ -132,7 +103,6 @@ impl VentoyDrive { .map_err(|e| AppError::Internal(format!("Task join error: {}", e)))? } - /// List files at a given path (or root if empty/"/") pub async fn list_files(&self, dir_path: &str) -> Result> { if !self.exists() { return Err(AppError::Internal("Drive not initialized".to_string())); @@ -140,7 +110,7 @@ impl VentoyDrive { let path = self.path.clone(); let dir_path = dir_path.to_string(); - let _lock = self.lock.read().await; // Read lock for listing + let _lock = self.lock.read().await; tokio::task::spawn_blocking(move || { let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?; @@ -161,9 +131,6 @@ impl VentoyDrive { .map_err(|e| AppError::Internal(format!("Task join error: {}", e)))? } - /// Write a file to the drive from multipart upload (streaming) - /// - /// Streams the file directly into the Ventoy image's exFAT partition. pub async fn write_file_from_multipart_field( &self, file_path: &str, @@ -173,12 +140,10 @@ impl VentoyDrive { return Err(AppError::Internal("Drive not initialized".to_string())); } - // First, stream to a temporary file (to get the size) let temp_dir = self.path.parent().unwrap_or(Path::new("/tmp")); let temp_name = format!(".upload_ventoy_{}", uuid::Uuid::new_v4()); let temp_path = temp_dir.join(&temp_name); - // Stream upload to temp file let mut temp_file = tokio::fs::File::create(&temp_path) .await .map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?; @@ -201,23 +166,16 @@ impl VentoyDrive { .map_err(|e| AppError::Internal(format!("Failed to flush temp file: {}", e)))?; drop(temp_file); - // Now copy from temp file to Ventoy image let path = self.path.clone(); let file_path = file_path.to_string(); let temp_path_clone = temp_path.clone(); - let _lock = self.lock.write().await; // Write lock for file write + let _lock = self.lock.write().await; let result = tokio::task::spawn_blocking(move || { let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?; - // Use add_file_to_path which handles streaming internally image - .add_file_to_path( - &temp_path_clone, - &file_path, - true, // create_parents - true, // overwrite - ) + .add_file_to_path(&temp_path_clone, &file_path, true, true) .map_err(ventoy_to_app_error)?; Ok::<(), AppError>(()) @@ -225,14 +183,13 @@ impl VentoyDrive { .await .map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?; - // Cleanup temp file let _ = tokio::fs::remove_file(&temp_path).await; result?; Ok(bytes_written) } - /// Read a file from the drive (for download) + #[cfg(test)] pub async fn read_file(&self, file_path: &str) -> Result> { if !self.exists() { return Err(AppError::Internal("Drive not initialized".to_string())); @@ -240,7 +197,7 @@ impl VentoyDrive { let path = self.path.clone(); let file_path = file_path.to_string(); - let _lock = self.lock.read().await; // Read lock for file read + let _lock = self.lock.read().await; tokio::task::spawn_blocking(move || { let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?; @@ -251,10 +208,6 @@ impl VentoyDrive { .map_err(|e| AppError::Internal(format!("Task join error: {}", e)))? } - /// Get file information without reading content - /// - /// Returns file size, name, and other metadata. - /// Returns None if the file doesn't exist. pub async fn get_file_info(&self, file_path: &str) -> Result> { if !self.exists() { return Err(AppError::Internal("Drive not initialized".to_string())); @@ -262,7 +215,7 @@ impl VentoyDrive { let path = self.path.clone(); let file_path_owned = file_path.to_string(); - let _lock = self.lock.read().await; // Read lock for file info + let _lock = self.lock.read().await; let info = tokio::task::spawn_blocking(move || { let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?; @@ -282,10 +235,6 @@ impl VentoyDrive { })) } - /// Read a file from the drive as a stream (for large file downloads) - /// - /// Returns an async channel receiver that yields chunks of file data. - /// This avoids loading the entire file into memory. pub async fn read_file_stream( &self, file_path: &str, @@ -297,7 +246,6 @@ impl VentoyDrive { return Err(AppError::Internal("Drive not initialized".to_string())); } - // First, get the file size let file_info = self .get_file_info(file_path) .await? @@ -315,15 +263,12 @@ impl VentoyDrive { let file_path_owned = file_path.to_string(); let lock = self.lock.clone(); - // Create a channel for streaming data let (tx, rx) = tokio::sync::mpsc::channel::>(8); - // Spawn blocking task to read and send chunks tokio::task::spawn_blocking(move || { - // Hold read lock for the entire read operation let rt = tokio::runtime::Handle::current(); - let _lock = rt.block_on(lock.read()); // Read lock for streaming + let _lock = rt.block_on(lock.read()); let image = match VentoyImage::open(&path) { Ok(img) => img, @@ -333,10 +278,8 @@ impl VentoyDrive { } }; - // Create a channel writer that sends chunks let mut chunk_writer = ChannelWriter::new(tx.clone(), rt.clone()); - // Stream the file through the writer if let Err(e) = image.read_file_to_writer(&file_path_owned, &mut chunk_writer) { let _ = rt.block_on(tx.send(Err(std::io::Error::other(e.to_string())))); } @@ -345,7 +288,6 @@ impl VentoyDrive { Ok((file_size, rx)) } - /// Create a directory pub async fn mkdir(&self, dir_path: &str) -> Result<()> { if !self.exists() { return Err(AppError::Internal("Drive not initialized".to_string())); @@ -353,7 +295,7 @@ impl VentoyDrive { let path = self.path.clone(); let dir_path = dir_path.to_string(); - let _lock = self.lock.write().await; // Write lock for mkdir + let _lock = self.lock.write().await; tokio::task::spawn_blocking(move || { let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?; @@ -366,7 +308,6 @@ impl VentoyDrive { .map_err(|e| AppError::Internal(format!("Task join error: {}", e)))? } - /// Delete a file or directory pub async fn delete(&self, path_to_delete: &str) -> Result<()> { if !self.exists() { return Err(AppError::Internal("Drive not initialized".to_string())); @@ -374,12 +315,11 @@ impl VentoyDrive { let path = self.path.clone(); let path_to_delete = path_to_delete.to_string(); - let _lock = self.lock.write().await; // Write lock for delete + let _lock = self.lock.write().await; tokio::task::spawn_blocking(move || { let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?; - // Use recursive delete to handle directories image .remove_recursive(&path_to_delete) .map_err(ventoy_to_app_error) @@ -389,7 +329,6 @@ impl VentoyDrive { } } -/// Convert VentoyError to AppError fn ventoy_to_app_error(err: VentoyError) -> AppError { match err { VentoyError::Io(e) => AppError::Io(e), @@ -405,7 +344,6 @@ fn ventoy_to_app_error(err: VentoyError) -> AppError { } } -/// Convert VentoyFileInfo to DriveFile fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFile { let full_path = if parent_path.is_empty() || parent_path == "/" { format!("/{}", info.name) @@ -418,13 +356,10 @@ fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFi path: full_path, size: info.size, is_dir: info.is_directory, - modified: None, // Ventoy FileInfo doesn't include timestamps + modified: None, } } -/// A writer that sends chunks to an async channel -/// -/// This bridges the sync Write trait with async channels for streaming. struct ChannelWriter { tx: tokio::sync::mpsc::Sender>, rt: tokio::runtime::Handle, @@ -484,7 +419,6 @@ impl std::io::Write for ChannelWriter { impl Drop for ChannelWriter { fn drop(&mut self) { - // Flush any remaining data when the writer is dropped let _ = self.flush_buffer(); } } @@ -496,16 +430,13 @@ mod tests { use std::sync::OnceLock; use tempfile::TempDir; - /// Path to ventoy resources directory static RESOURCE_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../ventoy-img-rs/resources"); - /// Initialize ventoy resources once fn init_ventoy_resources() -> bool { static INIT: OnceLock = OnceLock::new(); *INIT.get_or_init(|| { let resource_path = std::path::Path::new(RESOURCE_DIR); - // Decompress xz files if needed let core_xz = resource_path.join("core.img.xz"); let core_img = resource_path.join("core.img"); if core_xz.exists() && !core_img.exists() { @@ -524,7 +455,6 @@ mod tests { } } - // Initialize resources if let Err(e) = ventoy_img::resources::init_resources(resource_path) { eprintln!("Failed to init ventoy resources: {}", e); return false; @@ -534,7 +464,6 @@ mod tests { }) } - /// Decompress xz file using system command fn decompress_xz(src: &std::path::Path, dst: &std::path::Path) -> std::io::Result<()> { let output = Command::new("xz") .args(["-d", "-k", "-c", src.to_str().unwrap()]) @@ -551,7 +480,6 @@ mod tests { Ok(()) } - /// Ensure resources are initialized, skip test if failed fn ensure_resources() -> bool { if !init_ventoy_resources() { eprintln!("Skipping test: ventoy resources not available"); @@ -602,15 +530,12 @@ mod tests { let drive_path = temp_dir.path().join("test_ventoy.img"); let drive = VentoyDrive::new(drive_path.clone()); - // Initialize drive drive.init(MIN_DRIVE_SIZE_MB).await.unwrap(); - // Write a test file let test_content = b"Hello, Ventoy!"; let test_file_path = temp_dir.path().join("test.txt"); std::fs::write(&test_file_path, test_content).unwrap(); - // Add file to drive using ventoy-img directly let path = drive.path().clone(); tokio::task::spawn_blocking(move || { let mut image = VentoyImage::open(&path).unwrap(); @@ -619,7 +544,6 @@ mod tests { .await .unwrap(); - // Read file from drive let read_data = drive.read_file("/test.txt").await.unwrap(); assert_eq!(read_data, test_content); } @@ -633,18 +557,14 @@ mod tests { let drive_path = temp_dir.path().join("test_ventoy.img"); let drive = VentoyDrive::new(drive_path.clone()); - // Initialize drive drive.init(MIN_DRIVE_SIZE_MB).await.unwrap(); - // Create a directory drive.mkdir("/mydir").await.unwrap(); - // Write a test file let test_content = b"Test file content for info check"; let test_file_path = temp_dir.path().join("info_test.txt"); std::fs::write(&test_file_path, test_content).unwrap(); - // Add file to drive let path = drive.path().clone(); tokio::task::spawn_blocking(move || { let mut image = VentoyImage::open(&path).unwrap(); @@ -653,7 +573,6 @@ mod tests { .await .unwrap(); - // Test get_file_info for file let file_info = drive.get_file_info("/info_test.txt").await.unwrap(); assert!(file_info.is_some()); let file_info = file_info.unwrap(); @@ -661,14 +580,12 @@ mod tests { assert_eq!(file_info.size, test_content.len() as u64); assert!(!file_info.is_dir); - // Test get_file_info for directory let dir_info = drive.get_file_info("/mydir").await.unwrap(); assert!(dir_info.is_some()); let dir_info = dir_info.unwrap(); assert_eq!(dir_info.name, "mydir"); assert!(dir_info.is_dir); - // Test get_file_info for non-existent file let not_found = drive.get_file_info("/nonexistent.txt").await.unwrap(); assert!(not_found.is_none()); } @@ -682,16 +599,13 @@ mod tests { let drive_path = temp_dir.path().join("test_ventoy.img"); let drive = VentoyDrive::new(drive_path.clone()); - // Initialize drive drive.init(MIN_DRIVE_SIZE_MB).await.unwrap(); - // Create test data that spans multiple chunks (>64KB) - let test_size = 200 * 1024; // 200 KB + let test_size = 200 * 1024; let test_content: Vec = (0..test_size).map(|i| (i % 256) as u8).collect(); let test_file_path = temp_dir.path().join("large_file.bin"); std::fs::write(&test_file_path, &test_content).unwrap(); - // Add file to drive let path = drive.path().clone(); let file_path_clone = test_file_path.clone(); tokio::task::spawn_blocking(move || { @@ -701,18 +615,15 @@ mod tests { .await .unwrap(); - // Stream read the file let (file_size, mut rx) = drive.read_file_stream("/large_file.bin").await.unwrap(); assert_eq!(file_size, test_size as u64); - // Collect all chunks let mut received_data = Vec::new(); while let Some(chunk_result) = rx.recv().await { let chunk = chunk_result.expect("Chunk should not be an error"); received_data.extend_from_slice(&chunk); } - // Verify data matches assert_eq!(received_data.len(), test_content.len()); assert_eq!(received_data, test_content); } @@ -726,15 +637,12 @@ mod tests { let drive_path = temp_dir.path().join("test_ventoy.img"); let drive = VentoyDrive::new(drive_path.clone()); - // Initialize drive drive.init(MIN_DRIVE_SIZE_MB).await.unwrap(); - // Create a small test file let test_content = b"Small file for streaming test"; let test_file_path = temp_dir.path().join("small.txt"); std::fs::write(&test_file_path, test_content).unwrap(); - // Add file to drive let path = drive.path().clone(); tokio::task::spawn_blocking(move || { let mut image = VentoyImage::open(&path).unwrap(); @@ -743,18 +651,15 @@ mod tests { .await .unwrap(); - // Stream read the file let (file_size, mut rx) = drive.read_file_stream("/small.txt").await.unwrap(); assert_eq!(file_size, test_content.len() as u64); - // Collect all chunks let mut received_data = Vec::new(); while let Some(chunk_result) = rx.recv().await { let chunk = chunk_result.expect("Chunk should not be an error"); received_data.extend_from_slice(&chunk); } - // Verify data matches assert_eq!(received_data.as_slice(), test_content); } } diff --git a/src/otg/configfs.rs b/src/otg/configfs.rs index 6a783632..8efdf156 100644 --- a/src/otg/configfs.rs +++ b/src/otg/configfs.rs @@ -1,5 +1,3 @@ -//! ConfigFS file operations for USB Gadget - use std::fs::{self, File, OpenOptions}; use std::io::Write; use std::path::Path; @@ -7,34 +5,18 @@ use std::process::Command; use crate::error::{AppError, Result}; -/// ConfigFS base path for USB gadgets pub const CONFIGFS_PATH: &str = "/sys/kernel/config/usb_gadget"; - -/// Default gadget name pub const DEFAULT_GADGET_NAME: &str = "one-kvm"; - -/// USB Vendor ID (Linux Foundation) - default value pub const DEFAULT_USB_VENDOR_ID: u16 = 0x1d6b; - -/// USB Product ID (Multifunction Composite Gadget) - default value pub const DEFAULT_USB_PRODUCT_ID: u16 = 0x0104; - -/// USB device version - default value pub const DEFAULT_USB_BCD_DEVICE: u16 = 0x0100; - -/// USB spec version (USB 2.0) pub const USB_BCD_USB: u16 = 0x0200; -/// Check if ConfigFS is available pub fn is_configfs_available() -> bool { Path::new(CONFIGFS_PATH).exists() } -/// Ensure libcomposite support is available for USB gadget operations. -/// -/// This is a best-effort runtime fallback for systems where `libcomposite` -/// is built as a module and not loaded yet. It does not try to mount configfs; -/// mounting remains an explicit system responsibility. +/// Loads `libcomposite` if needed; does not mount configfs. pub fn ensure_libcomposite_loaded() -> Result<()> { if is_configfs_available() { return Ok(()); @@ -66,7 +48,6 @@ pub fn ensure_libcomposite_loaded() -> Result<()> { } } -/// Find available UDC (USB Device Controller) pub fn find_udc() -> Option { let udc_path = Path::new("/sys/class/udc"); if !udc_path.exists() { @@ -80,40 +61,17 @@ pub fn find_udc() -> Option { .next() } -/// Check if UDC is known to have low endpoint resources pub fn is_low_endpoint_udc(name: &str) -> bool { let name = name.to_ascii_lowercase(); name.contains("musb") || name.contains("musb-hdrc") } -/// Resolve preferred UDC name if available, otherwise auto-detect -pub fn resolve_udc_name(preferred: Option<&str>) -> Option { - if let Some(name) = preferred { - let path = Path::new("/sys/class/udc").join(name); - if path.exists() { - return Some(name.to_string()); - } - } - find_udc() -} - -/// Write string content to a file -/// -/// For sysfs files, this function appends a newline and flushes -/// to ensure the kernel processes the write immediately. -/// -/// IMPORTANT: sysfs attributes require a single atomic write() syscall. -/// The kernel processes the value on the first write(), so we must -/// build the complete buffer (including newline) before writing. +/// Sysfs/configfs: one write syscall with final buffer (incl. newline when needed). pub fn write_file(path: &Path, content: &str) -> Result<()> { - // For sysfs files (especially write-only ones like forced_eject), - // we need to use simple O_WRONLY without O_TRUNC - // O_TRUNC may fail on special files or require read permission let mut file = OpenOptions::new() .write(true) .open(path) .or_else(|e| { - // If open fails, try create (for regular files) if path.exists() { Err(e) } else { @@ -122,9 +80,6 @@ pub fn write_file(path: &Path, content: &str) -> Result<()> { }) .map_err(|e| AppError::Internal(format!("Failed to open {}: {}", path.display(), e)))?; - // Build complete buffer with newline, then write in single syscall. - // This is critical for sysfs - multiple write() calls may cause - // the kernel to only process partial data or return EINVAL. let data: std::borrow::Cow<[u8]> = if content.ends_with('\n') { content.as_bytes().into() } else { @@ -136,14 +91,12 @@ pub fn write_file(path: &Path, content: &str) -> Result<()> { file.write_all(&data) .map_err(|e| AppError::Internal(format!("Failed to write to {}: {}", path.display(), e)))?; - // Explicitly flush to ensure sysfs processes the write file.flush() .map_err(|e| AppError::Internal(format!("Failed to flush {}: {}", path.display(), e)))?; Ok(()) } -/// Write binary content to a file pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> { let mut file = File::create(path) .map_err(|e| AppError::Internal(format!("Failed to create {}: {}", path.display(), e)))?; @@ -154,14 +107,6 @@ pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> { Ok(()) } -/// Read string content from a file -pub fn read_file(path: &Path) -> Result { - fs::read_to_string(path) - .map(|s| s.trim().to_string()) - .map_err(|e| AppError::Internal(format!("Failed to read {}: {}", path.display(), e))) -} - -/// Create directory if not exists pub fn create_dir(path: &Path) -> Result<()> { fs::create_dir_all(path).map_err(|e| { AppError::Internal(format!( @@ -172,7 +117,6 @@ pub fn create_dir(path: &Path) -> Result<()> { }) } -/// Remove directory pub fn remove_dir(path: &Path) -> Result<()> { if path.exists() { fs::remove_dir(path).map_err(|e| { @@ -186,7 +130,6 @@ pub fn remove_dir(path: &Path) -> Result<()> { Ok(()) } -/// Remove file pub fn remove_file(path: &Path) -> Result<()> { if path.exists() { fs::remove_file(path).map_err(|e| { @@ -196,7 +139,6 @@ pub fn remove_file(path: &Path) -> Result<()> { Ok(()) } -/// Create symlink pub fn create_symlink(src: &Path, dest: &Path) -> Result<()> { std::os::unix::fs::symlink(src, dest).map_err(|e| { AppError::Internal(format!( diff --git a/src/otg/endpoint.rs b/src/otg/endpoint.rs index 36399b9a..3b650409 100644 --- a/src/otg/endpoint.rs +++ b/src/otg/endpoint.rs @@ -1,11 +1,7 @@ -//! USB Endpoint allocation management - use crate::error::{AppError, Result}; -/// Default maximum endpoints for typical UDC pub const DEFAULT_MAX_ENDPOINTS: u8 = 16; -/// Endpoint allocator - manages UDC endpoint resources #[derive(Debug, Clone)] pub struct EndpointAllocator { max_endpoints: u8, @@ -13,7 +9,6 @@ pub struct EndpointAllocator { } impl EndpointAllocator { - /// Create a new endpoint allocator pub fn new(max_endpoints: u8) -> Self { Self { max_endpoints, @@ -21,7 +16,6 @@ impl EndpointAllocator { } } - /// Allocate endpoints for a function pub fn allocate(&mut self, count: u8) -> Result<()> { if self.used_endpoints + count > self.max_endpoints { return Err(AppError::Internal(format!( @@ -34,27 +28,22 @@ impl EndpointAllocator { Ok(()) } - /// Release endpoints pub fn release(&mut self, count: u8) { self.used_endpoints = self.used_endpoints.saturating_sub(count); } - /// Get available endpoint count pub fn available(&self) -> u8 { self.max_endpoints.saturating_sub(self.used_endpoints) } - /// Get used endpoint count pub fn used(&self) -> u8 { self.used_endpoints } - /// Get maximum endpoint count pub fn max(&self) -> u8 { self.max_endpoints } - /// Check if can allocate pub fn can_allocate(&self, count: u8) -> bool { self.available() >= count } @@ -82,7 +71,6 @@ mod tests { alloc.allocate(4).unwrap(); assert_eq!(alloc.available(), 2); - // Should fail - not enough endpoints assert!(alloc.allocate(3).is_err()); alloc.release(2); diff --git a/src/otg/function.rs b/src/otg/function.rs index 010240e8..a4f353d4 100644 --- a/src/otg/function.rs +++ b/src/otg/function.rs @@ -1,42 +1,17 @@ -//! USB Gadget Function trait definition - use std::path::Path; use crate::error::Result; -/// Function metadata -#[derive(Debug, Clone)] -pub struct FunctionMeta { - /// Function name (e.g., "hid.usb0") - pub name: String, - /// Human-readable description - pub description: String, - /// Number of endpoints used - pub endpoints: u8, - /// Whether the function is enabled - pub enabled: bool, -} - -/// USB Gadget Function trait pub trait GadgetFunction: Send + Sync { - /// Get function name (e.g., "hid.usb0", "mass_storage.usb0") fn name(&self) -> &str; - /// Get number of endpoints required fn endpoints_required(&self) -> u8; - /// Get function metadata - fn meta(&self) -> FunctionMeta; - - /// Create function directory and configuration in ConfigFS fn create(&self, gadget_path: &Path) -> Result<()>; - /// Link function to configuration fn link(&self, config_path: &Path, gadget_path: &Path) -> Result<()>; - /// Unlink function from configuration fn unlink(&self, config_path: &Path) -> Result<()>; - /// Cleanup function directory fn cleanup(&self, gadget_path: &Path) -> Result<()>; } diff --git a/src/otg/hid.rs b/src/otg/hid.rs index 1c6c236a..598ce98e 100644 --- a/src/otg/hid.rs +++ b/src/otg/hid.rs @@ -1,35 +1,24 @@ -//! HID Function implementation for USB Gadget - use std::path::{Path, PathBuf}; use tracing::debug; use super::configfs::{ create_dir, create_symlink, remove_dir, remove_file, write_bytes, write_file, }; -use super::function::{FunctionMeta, GadgetFunction}; +use super::function::GadgetFunction; use super::report_desc::{ CONSUMER_CONTROL, KEYBOARD, KEYBOARD_WITH_LED, MOUSE_ABSOLUTE, MOUSE_RELATIVE, }; use crate::error::Result; -/// HID function type #[derive(Debug, Clone)] pub enum HidFunctionType { - /// Keyboard Keyboard, - /// Relative mouse (traditional mouse movement) - /// Uses 1 endpoint: IN MouseRelative, - /// Absolute mouse (touchscreen-like positioning) - /// Uses 1 endpoint: IN MouseAbsolute, - /// Consumer control (multimedia keys) - /// Uses 1 endpoint: IN ConsumerControl, } impl HidFunctionType { - /// Get the base endpoint cost for this function type. pub fn endpoints(&self) -> u8 { match self { HidFunctionType::Keyboard => 1, @@ -39,27 +28,24 @@ impl HidFunctionType { } } - /// Get HID protocol pub fn protocol(&self) -> u8 { match self { - HidFunctionType::Keyboard => 1, // Keyboard - HidFunctionType::MouseRelative => 2, // Mouse - HidFunctionType::MouseAbsolute => 2, // Mouse - HidFunctionType::ConsumerControl => 0, // None + HidFunctionType::Keyboard => 1, + HidFunctionType::MouseRelative => 2, + HidFunctionType::MouseAbsolute => 2, + HidFunctionType::ConsumerControl => 0, } } - /// Get HID subclass pub fn subclass(&self) -> u8 { match self { - HidFunctionType::Keyboard => 1, // Boot interface - HidFunctionType::MouseRelative => 1, // Boot interface - HidFunctionType::MouseAbsolute => 0, // No boot interface - HidFunctionType::ConsumerControl => 0, // No boot interface + HidFunctionType::Keyboard => 1, + HidFunctionType::MouseRelative => 1, + HidFunctionType::MouseAbsolute => 0, + HidFunctionType::ConsumerControl => 0, } } - /// Get report length in bytes pub fn report_length(&self, _keyboard_leds: bool) -> u8 { match self { HidFunctionType::Keyboard => 8, @@ -69,7 +55,6 @@ impl HidFunctionType { } } - /// Get report descriptor pub fn report_desc(&self, keyboard_leds: bool) -> &'static [u8] { match self { HidFunctionType::Keyboard => { @@ -84,33 +69,17 @@ impl HidFunctionType { HidFunctionType::ConsumerControl => CONSUMER_CONTROL, } } - - /// Get description - pub fn description(&self) -> &'static str { - match self { - HidFunctionType::Keyboard => "Keyboard", - HidFunctionType::MouseRelative => "Relative Mouse", - HidFunctionType::MouseAbsolute => "Absolute Mouse", - HidFunctionType::ConsumerControl => "Consumer Control", - } - } } -/// HID Function for USB Gadget #[derive(Debug, Clone)] pub struct HidFunction { - /// Instance number (usb0, usb1, ...) instance: u8, - /// Function type func_type: HidFunctionType, - /// Cached function name (avoids repeated allocation) name: String, - /// Whether keyboard LED/status feedback is enabled. keyboard_leds: bool, } impl HidFunction { - /// Create a keyboard function pub fn keyboard(instance: u8, keyboard_leds: bool) -> Self { Self { instance, @@ -120,7 +89,6 @@ impl HidFunction { } } - /// Create a relative mouse function pub fn mouse_relative(instance: u8) -> Self { Self { instance, @@ -130,7 +98,6 @@ impl HidFunction { } } - /// Create an absolute mouse function pub fn mouse_absolute(instance: u8) -> Self { Self { instance, @@ -140,7 +107,6 @@ impl HidFunction { } } - /// Create a consumer control function pub fn consumer_control(instance: u8) -> Self { Self { instance, @@ -150,12 +116,10 @@ impl HidFunction { } } - /// Get function path in gadget fn function_path(&self, gadget_path: &Path) -> PathBuf { gadget_path.join("functions").join(self.name()) } - /// Get expected device path (e.g., /dev/hidg0) pub fn device_path(&self) -> PathBuf { PathBuf::from(format!("/dev/hidg{}", self.instance)) } @@ -170,20 +134,10 @@ impl GadgetFunction for HidFunction { self.func_type.endpoints() } - fn meta(&self) -> FunctionMeta { - FunctionMeta { - name: self.name().to_string(), - description: self.func_type.description().to_string(), - endpoints: self.endpoints_required(), - enabled: true, - } - } - fn create(&self, gadget_path: &Path) -> Result<()> { let func_path = self.function_path(gadget_path); create_dir(&func_path)?; - // Set HID parameters write_file( &func_path.join("protocol"), &self.func_type.protocol().to_string(), @@ -197,7 +151,6 @@ impl GadgetFunction for HidFunction { &self.func_type.report_length(self.keyboard_leds).to_string(), )?; - // Write report descriptor write_bytes( &func_path.join("report_desc"), self.func_type.report_desc(self.keyboard_leds), diff --git a/src/otg/manager.rs b/src/otg/manager.rs index b120f37f..c507147e 100644 --- a/src/otg/manager.rs +++ b/src/otg/manager.rs @@ -1,6 +1,3 @@ -//! OTG Gadget Manager - unified management for USB Gadget functions - -use std::collections::HashMap; use std::fs; use std::path::PathBuf; use tracing::{debug, error, info, warn}; @@ -11,14 +8,13 @@ use super::configfs::{ DEFAULT_USB_VENDOR_ID, USB_BCD_USB, }; use super::endpoint::{EndpointAllocator, DEFAULT_MAX_ENDPOINTS}; -use super::function::{FunctionMeta, GadgetFunction}; +use super::function::GadgetFunction; use super::hid::HidFunction; use super::msd::MsdFunction; use crate::error::{AppError, Result}; const REBIND_DELAY_MS: u64 = 300; -/// USB Gadget device descriptor configuration #[derive(Debug, Clone, PartialEq, Eq)] pub struct GadgetDescriptor { pub vendor_id: u16, @@ -42,44 +38,28 @@ impl Default for GadgetDescriptor { } } -/// OTG Gadget Manager - unified management for HID and MSD pub struct OtgGadgetManager { - /// Gadget name gadget_name: String, - /// Gadget path in ConfigFS gadget_path: PathBuf, - /// Configuration path config_path: PathBuf, - /// Device descriptor descriptor: GadgetDescriptor, - /// Endpoint allocator endpoint_allocator: EndpointAllocator, - /// HID instance counter hid_instance: u8, - /// MSD instance counter msd_instance: u8, - /// Registered functions functions: Vec>, - /// Function metadata - meta: HashMap, - /// Bound UDC name bound_udc: Option, - /// Whether gadget was created by us created_by_us: bool, } impl OtgGadgetManager { - /// Create a new gadget manager with default settings pub fn new() -> Self { Self::with_config(DEFAULT_GADGET_NAME, DEFAULT_MAX_ENDPOINTS) } - /// Create a new gadget manager with custom configuration pub fn with_config(gadget_name: &str, max_endpoints: u8) -> Self { Self::with_descriptor(gadget_name, max_endpoints, GadgetDescriptor::default()) } - /// Create a new gadget manager with custom descriptor pub fn with_descriptor( gadget_name: &str, max_endpoints: u8, @@ -96,30 +76,24 @@ impl OtgGadgetManager { endpoint_allocator: EndpointAllocator::new(max_endpoints), hid_instance: 0, msd_instance: 0, - // Pre-allocate for typical use: 3 HID (keyboard, rel mouse, abs mouse) + 1 MSD functions: Vec::with_capacity(4), - meta: HashMap::with_capacity(4), bound_udc: None, created_by_us: false, } } - /// Check if ConfigFS is available pub fn is_available() -> bool { is_configfs_available() } - /// Find available UDC pub fn find_udc() -> Option { find_udc() } - /// Check if gadget exists pub fn gadget_exists(&self) -> bool { self.gadget_path.exists() } - /// Check if gadget is bound to UDC pub fn is_bound(&self) -> bool { let udc_file = self.gadget_path.join("UDC"); if let Ok(content) = fs::read_to_string(&udc_file) { @@ -129,8 +103,6 @@ impl OtgGadgetManager { } } - /// Add keyboard function - /// Returns the expected device path (e.g., /dev/hidg0) pub fn add_keyboard(&mut self, keyboard_leds: bool) -> Result { let func = HidFunction::keyboard(self.hid_instance, keyboard_leds); let device_path = func.device_path(); @@ -139,7 +111,6 @@ impl OtgGadgetManager { Ok(device_path) } - /// Add relative mouse function pub fn add_mouse_relative(&mut self) -> Result { let func = HidFunction::mouse_relative(self.hid_instance); let device_path = func.device_path(); @@ -148,7 +119,6 @@ impl OtgGadgetManager { Ok(device_path) } - /// Add absolute mouse function pub fn add_mouse_absolute(&mut self) -> Result { let func = HidFunction::mouse_absolute(self.hid_instance); let device_path = func.device_path(); @@ -157,7 +127,6 @@ impl OtgGadgetManager { Ok(device_path) } - /// Add consumer control function (multimedia keys) pub fn add_consumer_control(&mut self) -> Result { let func = HidFunction::consumer_control(self.hid_instance); let device_path = func.device_path(); @@ -166,7 +135,6 @@ impl OtgGadgetManager { Ok(device_path) } - /// Add MSD function (returns MsdFunction handle for LUN configuration) pub fn add_msd(&mut self) -> Result { let func = MsdFunction::new(self.msd_instance); let func_clone = func.clone(); @@ -175,11 +143,9 @@ impl OtgGadgetManager { Ok(func_clone) } - /// Add a generic function fn add_function(&mut self, func: Box) -> Result<()> { let endpoints = func.endpoints_required(); - // Check endpoint availability if !self.endpoint_allocator.can_allocate(endpoints) { return Err(AppError::Internal(format!( "Not enough endpoints for function {}: need {}, available {}", @@ -189,30 +155,22 @@ impl OtgGadgetManager { ))); } - // Allocate endpoints self.endpoint_allocator.allocate(endpoints)?; - // Store metadata - self.meta.insert(func.name().to_string(), func.meta()); - - // Store function self.functions.push(func); Ok(()) } - /// Setup the gadget (create directories and configure) pub fn setup(&mut self) -> Result<()> { info!("Setting up OTG USB Gadget: {}", self.gadget_name); - // Check ConfigFS availability if !Self::is_available() { return Err(AppError::Internal( "ConfigFS not available. Is it mounted at /sys/kernel/config?".to_string(), )); } - // Check if gadget already exists and is bound if self.gadget_exists() { if self.is_bound() { info!("Gadget already exists and is bound, skipping setup"); @@ -222,20 +180,15 @@ impl OtgGadgetManager { self.cleanup()?; } - // Create gadget directory create_dir(&self.gadget_path)?; self.created_by_us = true; - // Set device descriptors self.set_device_descriptors()?; - // Create strings self.create_strings()?; - // Create configuration self.create_configuration()?; - // Create and link all functions for func in &self.functions { func.create(&self.gadget_path)?; func.link(&self.config_path, &self.gadget_path)?; @@ -245,9 +198,7 @@ impl OtgGadgetManager { Ok(()) } - /// Bind gadget to a specific UDC pub fn bind(&mut self, udc: &str) -> Result<()> { - // Recreate config symlinks before binding to avoid kernel gadget issues after rebind if let Err(e) = self.recreate_config_links() { warn!("Failed to recreate gadget config links before bind: {}", e); } @@ -260,7 +211,6 @@ impl OtgGadgetManager { Ok(()) } - /// Unbind gadget from UDC pub fn unbind(&mut self) -> Result<()> { if self.is_bound() { write_file(&self.gadget_path.join("UDC"), "")?; @@ -271,7 +221,6 @@ impl OtgGadgetManager { Ok(()) } - /// Cleanup all resources pub fn cleanup(&mut self) -> Result<()> { if !self.gadget_exists() { return Ok(()); @@ -279,29 +228,23 @@ impl OtgGadgetManager { info!("Cleaning up OTG USB Gadget: {}", self.gadget_name); - // Unbind from UDC first let _ = self.unbind(); - // Unlink and cleanup functions for func in self.functions.iter().rev() { let _ = func.unlink(&self.config_path); } - // Remove config strings let config_strings = self.config_path.join("strings/0x409"); let _ = remove_dir(&config_strings); let _ = remove_dir(&self.config_path); - // Cleanup functions for func in self.functions.iter().rev() { let _ = func.cleanup(&self.gadget_path); } - // Remove gadget strings let gadget_strings = self.gadget_path.join("strings/0x409"); let _ = remove_dir(&gadget_strings); - // Remove gadget directory if let Err(e) = remove_dir(&self.gadget_path) { warn!("Could not remove gadget directory: {}", e); } @@ -311,7 +254,6 @@ impl OtgGadgetManager { Ok(()) } - /// Set USB device descriptors fn set_device_descriptors(&self) -> Result<()> { write_file( &self.gadget_path.join("idVendor"), @@ -329,14 +271,13 @@ impl OtgGadgetManager { &self.gadget_path.join("bcdUSB"), &format!("0x{:04x}", USB_BCD_USB), )?; - write_file(&self.gadget_path.join("bDeviceClass"), "0x00")?; // Composite device + write_file(&self.gadget_path.join("bDeviceClass"), "0x00")?; write_file(&self.gadget_path.join("bDeviceSubClass"), "0x00")?; write_file(&self.gadget_path.join("bDeviceProtocol"), "0x00")?; debug!("Set device descriptors"); Ok(()) } - /// Create USB strings fn create_strings(&self) -> Result<()> { let strings_path = self.gadget_path.join("strings/0x409"); create_dir(&strings_path)?; @@ -354,41 +295,23 @@ impl OtgGadgetManager { Ok(()) } - /// Create configuration fn create_configuration(&self) -> Result<()> { create_dir(&self.config_path)?; - // Create config strings let strings_path = self.config_path.join("strings/0x409"); create_dir(&strings_path)?; write_file(&strings_path.join("configuration"), "Config 1: HID + MSD")?; - // Set max power (500mA) write_file(&self.config_path.join("MaxPower"), "500")?; debug!("Created configuration c.1"); Ok(()) } - /// Get function metadata - pub fn get_meta(&self) -> &HashMap { - &self.meta - } - - /// Get endpoint usage info - pub fn endpoint_info(&self) -> (u8, u8) { - ( - self.endpoint_allocator.used(), - self.endpoint_allocator.max(), - ) - } - - /// Get gadget path pub fn gadget_path(&self) -> &PathBuf { &self.gadget_path } - /// Recreate config symlinks from functions directory fn recreate_config_links(&self) -> Result<()> { let functions_path = self.gadget_path.join("functions"); if !functions_path.exists() || !self.config_path.exists() { @@ -450,15 +373,10 @@ impl Drop for OtgGadgetManager { } } -/// Wait for HID devices to become available -/// -/// Uses exponential backoff starting from 10ms, capped at 100ms, -/// to reduce CPU usage while still providing fast response. pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) -> bool { let start = std::time::Instant::now(); let timeout = std::time::Duration::from_millis(timeout_ms); - // Exponential backoff: start at 10ms, double each time, cap at 100ms let mut delay_ms = 10u64; const MAX_DELAY_MS: u64 = 100; @@ -467,7 +385,6 @@ pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) -> return true; } - // Calculate remaining time to avoid overshooting timeout let remaining = timeout.saturating_sub(start.elapsed()); let sleep_duration = std::time::Duration::from_millis(delay_ms).min(remaining); @@ -477,7 +394,6 @@ pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) -> tokio::time::sleep(sleep_duration).await; - // Exponential backoff with cap delay_ms = (delay_ms * 2).min(MAX_DELAY_MS); } @@ -492,18 +408,16 @@ mod tests { fn test_manager_creation() { let manager = OtgGadgetManager::new(); assert_eq!(manager.gadget_name, DEFAULT_GADGET_NAME); - assert!(!manager.gadget_exists()); // Won't exist in test environment + assert!(!manager.gadget_exists()); } #[test] fn test_endpoint_tracking() { let mut manager = OtgGadgetManager::with_config("test", 8); - // Keyboard uses 1 endpoint let _ = manager.add_keyboard(false); assert_eq!(manager.endpoint_allocator.used(), 1); - // Mouse uses 1 endpoint each let _ = manager.add_mouse_relative(); let _ = manager.add_mouse_absolute(); assert_eq!(manager.endpoint_allocator.used(), 3); diff --git a/src/otg/mod.rs b/src/otg/mod.rs index f016bc89..9d6924e1 100644 --- a/src/otg/mod.rs +++ b/src/otg/mod.rs @@ -1,21 +1,4 @@ -//! OTG USB Gadget unified management module -//! -//! This module provides unified management for USB Gadget functions: -//! - HID (Keyboard, Mouse) -//! - MSD (Mass Storage Device) -//! -//! Architecture: -//! ```text -//! OtgService (high-level coordination) -//! └── OtgGadgetManager (gadget lifecycle) -//! ├── EndpointAllocator (manages UDC endpoints) -//! ├── HidFunction (keyboard, mouse_rel, mouse_abs) -//! └── MsdFunction (mass storage) -//! ``` -//! -//! The recommended way to use this module is through `OtgService`, which provides -//! a high-level interface for enabling/disabling HID and MSD functions independently. -//! Both `HidController` and `MsdController` should share the same `OtgService` instance. +//! USB OTG composite gadget (HID + MSD). pub mod configfs; pub mod endpoint; @@ -26,10 +9,6 @@ pub mod msd; pub mod report_desc; pub mod service; -pub use endpoint::EndpointAllocator; -pub use function::{FunctionMeta, GadgetFunction}; -pub use hid::{HidFunction, HidFunctionType}; pub use manager::{wait_for_hid_devices, OtgGadgetManager}; pub use msd::{MsdFunction, MsdLunConfig}; -pub use report_desc::{KEYBOARD, MOUSE_ABSOLUTE, MOUSE_RELATIVE}; -pub use service::{HidDevicePaths, OtgDesiredState, OtgService, OtgServiceState}; +pub use service::{HidDevicePaths, OtgService}; diff --git a/src/otg/msd.rs b/src/otg/msd.rs index bd37fa71..a762d74b 100644 --- a/src/otg/msd.rs +++ b/src/otg/msd.rs @@ -1,25 +1,17 @@ -//! MSD (Mass Storage Device) Function implementation for USB Gadget - use std::fs; use std::path::{Path, PathBuf}; use tracing::{debug, info, warn}; use super::configfs::{create_dir, create_symlink, remove_dir, remove_file, write_file}; -use super::function::{FunctionMeta, GadgetFunction}; +use super::function::GadgetFunction; use crate::error::{AppError, Result}; -/// MSD LUN configuration #[derive(Debug, Clone)] pub struct MsdLunConfig { - /// File/image path to expose pub file: PathBuf, - /// Mount as CD-ROM pub cdrom: bool, - /// Read-only mode pub ro: bool, - /// Removable media pub removable: bool, - /// Disable Force Unit Access pub nofua: bool, } @@ -36,7 +28,6 @@ impl Default for MsdLunConfig { } impl MsdLunConfig { - /// Create CD-ROM configuration pub fn cdrom(file: PathBuf) -> Self { Self { file, @@ -47,7 +38,6 @@ impl MsdLunConfig { } } - /// Create disk configuration pub fn disk(file: PathBuf, read_only: bool) -> Self { Self { file, @@ -59,38 +49,26 @@ impl MsdLunConfig { } } -/// MSD Function for USB Gadget #[derive(Debug, Clone)] pub struct MsdFunction { - /// Instance number (usb0, usb1, ...) - instance: u8, - /// Cached function name (avoids repeated allocation) name: String, } impl MsdFunction { - /// Create a new MSD function pub fn new(instance: u8) -> Self { Self { - instance, name: format!("mass_storage.usb{}", instance), } } - /// Get function path in gadget fn function_path(&self, gadget_path: &Path) -> PathBuf { gadget_path.join("functions").join(self.name()) } - /// Get LUN path fn lun_path(&self, gadget_path: &Path, lun: u8) -> PathBuf { self.function_path(gadget_path).join(format!("lun.{}", lun)) } - /// Configure a LUN with specified settings (async version) - /// - /// This is the preferred method for async contexts. It runs the blocking - /// file I/O and USB timing operations in a separate thread pool. pub async fn configure_lun_async( &self, gadget_path: &Path, @@ -106,17 +84,6 @@ impl MsdFunction { .map_err(|e| AppError::Internal(format!("Task join error: {}", e)))? } - /// Configure a LUN with specified settings - /// Note: This should be called after the gadget is set up - /// - /// This implementation is based on PiKVM's MSD drive configuration. - /// Key improvements: - /// - Uses forced_eject when available (safer than clearing file directly) - /// - Reduced sleep times to minimize HID interference - /// - Better retry logic for EBUSY errors - /// - /// **Note**: This is a blocking function. In async contexts, prefer - /// `configure_lun_async` to avoid blocking the runtime. pub fn configure_lun(&self, gadget_path: &Path, lun: u8, config: &MsdLunConfig) -> Result<()> { let lun_path = self.lun_path(gadget_path, lun); @@ -124,7 +91,6 @@ impl MsdFunction { create_dir(&lun_path)?; } - // Batch read all current values to minimize syscalls let read_attr = |attr: &str| -> String { fs::read_to_string(lun_path.join(attr)) .unwrap_or_default() @@ -137,28 +103,21 @@ impl MsdFunction { let current_removable = read_attr("removable"); let current_nofua = read_attr("nofua"); - // Prepare new values let new_cdrom = if config.cdrom { "1" } else { "0" }; let new_ro = if config.ro { "1" } else { "0" }; let new_removable = if config.removable { "1" } else { "0" }; let new_nofua = if config.nofua { "1" } else { "0" }; - // Disconnect current file first using forced_eject if available (PiKVM approach) let forced_eject_path = lun_path.join("forced_eject"); if forced_eject_path.exists() { - // forced_eject is safer - it forcibly detaches regardless of host state debug!("Using forced_eject to clear LUN {}", lun); let _ = write_file(&forced_eject_path, "1"); } else { - // Fallback to clearing file directly let _ = write_file(&lun_path.join("file"), ""); } - // Brief yield to allow USB stack to process the disconnect - // Reduced from 200ms to 50ms - let USB protocol handle timing std::thread::sleep(std::time::Duration::from_millis(50)); - // Write only changed attributes let cdrom_changed = current_cdrom != new_cdrom; if cdrom_changed { debug!( @@ -186,13 +145,11 @@ impl MsdFunction { write_file(&lun_path.join("nofua"), new_nofua)?; } - // If cdrom mode changed, brief yield for USB host if cdrom_changed { debug!("CDROM mode changed, brief yield for USB host"); std::thread::sleep(std::time::Duration::from_millis(50)); } - // Set file path (this triggers the actual mount) - with retry on EBUSY if config.file.exists() { let file_path = config.file.to_string_lossy(); let mut last_error = None; @@ -210,7 +167,6 @@ impl MsdFunction { return Ok(()); } Err(e) => { - // Check if it's EBUSY (error code 16) let is_busy = e.to_string().contains("Device or resource busy") || e.to_string().contains("os error 16"); @@ -220,7 +176,6 @@ impl MsdFunction { lun, attempt + 1 ); - // Exponential backoff: 50, 100, 200, 400ms std::thread::sleep(std::time::Duration::from_millis(50 << attempt)); last_error = Some(e); continue; @@ -231,7 +186,6 @@ impl MsdFunction { } } - // If we get here, all retries failed if let Some(e) = last_error { return Err(e); } @@ -242,9 +196,6 @@ impl MsdFunction { Ok(()) } - /// Disconnect LUN (async version) - /// - /// Preferred for async contexts. pub async fn disconnect_lun_async(&self, gadget_path: &Path, lun: u8) -> Result<()> { let gadget_path = gadget_path.to_path_buf(); let this = self.clone(); @@ -254,17 +205,10 @@ impl MsdFunction { .map_err(|e| AppError::Internal(format!("Task join error: {}", e)))? } - /// Disconnect LUN (clear file) - /// - /// This method uses forced_eject when available, which is safer than - /// directly clearing the file path. Based on PiKVM's implementation. - /// See: https://docs.kernel.org/usb/mass-storage.html pub fn disconnect_lun(&self, gadget_path: &Path, lun: u8) -> Result<()> { let lun_path = self.lun_path(gadget_path, lun); if lun_path.exists() { - // Prefer forced_eject if available (PiKVM approach) - // forced_eject forcibly detaches the backing file regardless of host state let forced_eject_path = lun_path.join("forced_eject"); if forced_eject_path.exists() { debug!( @@ -282,7 +226,6 @@ impl MsdFunction { } } } else { - // Fallback to clearing file directly write_file(&lun_path.join("file"), "")?; } info!("LUN {} disconnected", lun); @@ -291,7 +234,6 @@ impl MsdFunction { Ok(()) } - /// Get current LUN file path pub fn get_lun_file(&self, gadget_path: &Path, lun: u8) -> Option { let lun_path = self.lun_path(gadget_path, lun); let file_path = lun_path.join("file"); @@ -306,7 +248,6 @@ impl MsdFunction { None } - /// Check if LUN is connected pub fn is_lun_connected(&self, gadget_path: &Path, lun: u8) -> bool { self.get_lun_file(gadget_path, lun).is_some() } @@ -318,39 +259,23 @@ impl GadgetFunction for MsdFunction { } fn endpoints_required(&self) -> u8 { - 2 // IN + OUT for bulk transfers - } - - fn meta(&self) -> FunctionMeta { - FunctionMeta { - name: self.name().to_string(), - description: if self.instance == 0 { - "Mass Storage Drive".to_string() - } else { - format!("Extra Drive #{}", self.instance) - }, - endpoints: self.endpoints_required(), - enabled: true, - } + 2 } fn create(&self, gadget_path: &Path) -> Result<()> { let func_path = self.function_path(gadget_path); create_dir(&func_path)?; - // Set stall to 0 (workaround for some hosts) let stall_path = func_path.join("stall"); if stall_path.exists() { let _ = write_file(&stall_path, "0"); } - // LUN 0 is created automatically, but ensure it exists let lun0_path = func_path.join("lun.0"); if !lun0_path.exists() { create_dir(&lun0_path)?; } - // Set default LUN 0 parameters let _ = write_file(&lun0_path.join("cdrom"), "0"); let _ = write_file(&lun0_path.join("ro"), "0"); let _ = write_file(&lun0_path.join("removable"), "1"); @@ -382,12 +307,10 @@ impl GadgetFunction for MsdFunction { fn cleanup(&self, gadget_path: &Path) -> Result<()> { let func_path = self.function_path(gadget_path); - // Disconnect all LUNs first for lun in 0..8 { let _ = self.disconnect_lun(gadget_path, lun); } - // Remove function directory if let Err(e) = remove_dir(&func_path) { warn!("Could not remove MSD function directory: {}", e); } diff --git a/src/otg/report_desc.rs b/src/otg/report_desc.rs index 62fa0dc6..da488618 100644 --- a/src/otg/report_desc.rs +++ b/src/otg/report_desc.rs @@ -1,10 +1,3 @@ -//! HID Report Descriptors - -/// Keyboard HID Report Descriptor (no LED output) -/// Report format (8 bytes input): -/// [0] Modifier keys (8 bits) -/// [1] Reserved -/// [2-7] Key codes (6 keys) pub const KEYBOARD: &[u8] = &[ 0x05, 0x01, // Usage Page (Generic Desktop) 0x09, 0x06, // Usage (Keyboard) @@ -34,13 +27,6 @@ pub const KEYBOARD: &[u8] = &[ 0xC0, // End Collection ]; -/// Keyboard HID Report Descriptor with LED output support. -/// Input report format (8 bytes): -/// [0] Modifier keys (8 bits) -/// [1] Reserved -/// [2-7] Key codes (6 keys) -/// Output report format (1 byte): -/// [0] Num Lock / Caps Lock / Scroll Lock / Compose / Kana pub const KEYBOARD_WITH_LED: &[u8] = &[ 0x05, 0x01, // Usage Page (Generic Desktop) 0x09, 0x06, // Usage (Keyboard) @@ -81,12 +67,6 @@ pub const KEYBOARD_WITH_LED: &[u8] = &[ 0xC0, // End Collection ]; -/// Relative Mouse HID Report Descriptor (4 bytes report) -/// Report format: -/// [0] Buttons (5 bits) + padding (3 bits) -/// [1] X movement (signed 8-bit) -/// [2] Y movement (signed 8-bit) -/// [3] Wheel (signed 8-bit) pub const MOUSE_RELATIVE: &[u8] = &[ 0x05, 0x01, // Usage Page (Generic Desktop) 0x09, 0x02, // Usage (Mouse) @@ -126,12 +106,6 @@ pub const MOUSE_RELATIVE: &[u8] = &[ 0xC0, // End Collection ]; -/// Absolute Mouse HID Report Descriptor (6 bytes report) -/// Report format: -/// [0] Buttons (5 bits) + padding (3 bits) -/// [1-2] X position (16-bit, 0-32767) -/// [3-4] Y position (16-bit, 0-32767) -/// [5] Wheel (signed 8-bit) pub const MOUSE_ABSOLUTE: &[u8] = &[ 0x05, 0x01, // Usage Page (Generic Desktop) 0x09, 0x02, // Usage (Mouse) @@ -177,10 +151,6 @@ pub const MOUSE_ABSOLUTE: &[u8] = &[ 0xC0, // End Collection ]; -/// Consumer Control HID Report Descriptor (2 bytes report) -/// Report format: -/// [0-1] Consumer Control Usage (16-bit little-endian) -/// Supports: Play/Pause, Stop, Next/Prev Track, Mute, Volume Up/Down, etc. pub const CONSUMER_CONTROL: &[u8] = &[ 0x05, 0x0C, // Usage Page (Consumer) 0x09, 0x01, // Usage (Consumer Control) diff --git a/src/otg/service.rs b/src/otg/service.rs index 1166f48d..8988a04a 100644 --- a/src/otg/service.rs +++ b/src/otg/service.rs @@ -1,9 +1,3 @@ -//! OTG Service - unified gadget lifecycle management -//! -//! This module provides centralized management for USB OTG gadget functions. -//! It is the single owner of the USB gadget desired state and reconciles -//! ConfigFS to match that state. - use std::path::PathBuf; use tokio::sync::{Mutex, RwLock}; use tracing::{debug, info, warn}; @@ -13,7 +7,6 @@ use super::msd::MsdFunction; use crate::config::{HidBackend, HidConfig, MsdConfig, OtgDescriptorConfig, OtgHidFunctions}; use crate::error::{AppError, Result}; -/// HID device paths #[derive(Debug, Clone, Default)] pub struct HidDevicePaths { pub keyboard: Option, @@ -26,26 +19,20 @@ pub struct HidDevicePaths { impl HidDevicePaths { pub fn existing_paths(&self) -> Vec { - let mut paths = Vec::new(); - if let Some(ref p) = self.keyboard { - paths.push(p.clone()); - } - if let Some(ref p) = self.mouse_relative { - paths.push(p.clone()); - } - if let Some(ref p) = self.mouse_absolute { - paths.push(p.clone()); - } - if let Some(ref p) = self.consumer { - paths.push(p.clone()); - } - paths + [ + &self.keyboard, + &self.mouse_relative, + &self.mouse_absolute, + &self.consumer, + ] + .into_iter() + .filter_map(|p| p.as_ref().cloned()) + .collect() } } -/// Desired OTG gadget state derived from configuration. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct OtgDesiredState { +pub(crate) struct OtgDesiredState { pub udc: Option, pub descriptor: GadgetDescriptor, pub hid_functions: Option, @@ -68,7 +55,7 @@ impl Default for OtgDesiredState { } impl OtgDesiredState { - pub fn from_config(hid: &HidConfig, msd: &MsdConfig) -> Result { + pub(crate) fn from_config(hid: &HidConfig, msd: &MsdConfig) -> Result { let hid_functions = if hid.backend == HidBackend::Otg { let functions = hid.constrained_otg_functions(); Some(functions) @@ -96,45 +83,28 @@ impl OtgDesiredState { } } -/// OTG Service state #[derive(Debug, Clone, Default)] -pub struct OtgServiceState { - /// Whether the gadget is created and bound +struct OtgServiceState { pub gadget_active: bool, - /// Whether HID functions are enabled pub hid_enabled: bool, - /// Whether MSD function is enabled pub msd_enabled: bool, - /// Bound UDC name pub configured_udc: Option, - /// HID device paths (set after gadget setup) pub hid_paths: Option, - /// HID function selection (set after gadget setup) pub hid_functions: Option, - /// Whether keyboard LED/status feedback is enabled. pub keyboard_leds_enabled: bool, - /// Applied endpoint budget. pub max_endpoints: u8, - /// Applied descriptor configuration pub descriptor: Option, - /// Error message if setup failed pub error: Option, } -/// OTG Service - unified gadget lifecycle management pub struct OtgService { - /// The underlying gadget manager manager: Mutex>, - /// Current state state: RwLock, - /// MSD function handle (for runtime LUN configuration) msd_function: RwLock>, - /// Desired OTG state desired: RwLock, } impl OtgService { - /// Create a new OTG service pub fn new() -> Self { Self { manager: Mutex::new(None), @@ -144,55 +114,29 @@ impl OtgService { } } - /// Check if OTG is available on this system pub fn is_available() -> bool { OtgGadgetManager::is_available() && OtgGadgetManager::find_udc().is_some() } - /// Get current service state - pub async fn state(&self) -> OtgServiceState { - self.state.read().await.clone() - } - - /// Check if gadget is active - pub async fn is_gadget_active(&self) -> bool { - self.state.read().await.gadget_active - } - - /// Check if HID is enabled - pub async fn is_hid_enabled(&self) -> bool { - self.state.read().await.hid_enabled - } - - /// Check if MSD is enabled - pub async fn is_msd_enabled(&self) -> bool { - self.state.read().await.msd_enabled - } - - /// Get gadget path (for MSD LUN configuration) pub async fn gadget_path(&self) -> Option { let manager = self.manager.lock().await; manager.as_ref().map(|m| m.gadget_path().clone()) } - /// Get HID device paths pub async fn hid_device_paths(&self) -> Option { self.state.read().await.hid_paths.clone() } - /// Get MSD function handle (for LUN configuration) pub async fn msd_function(&self) -> Option { self.msd_function.read().await.clone() } - /// Apply desired OTG state derived from the current application config. pub async fn apply_config(&self, hid: &HidConfig, msd: &MsdConfig) -> Result<()> { let desired = OtgDesiredState::from_config(hid, msd)?; self.apply_desired_state(desired).await } - /// Apply a fully materialized desired OTG state. - pub async fn apply_desired_state(&self, desired: OtgDesiredState) -> Result<()> { + pub(crate) async fn apply_desired_state(&self, desired: OtgDesiredState) -> Result<()> { { let mut current = self.desired.write().await; *current = desired; @@ -392,7 +336,6 @@ impl OtgService { Ok(()) } - /// Shutdown the OTG service and cleanup all resources pub async fn shutdown(&self) -> Result<()> { info!("Shutting down OTG service"); @@ -425,12 +368,6 @@ impl Default for OtgService { } } -impl Drop for OtgService { - fn drop(&mut self) { - debug!("OtgService dropping"); - } -} - impl From<&OtgDescriptorConfig> for GadgetDescriptor { fn from(config: &OtgDescriptorConfig) -> Self { Self { @@ -452,17 +389,8 @@ mod tests { use super::*; #[test] - fn test_service_creation() { + fn service_new_and_availability_probe() { let _service = OtgService::new(); let _ = OtgService::is_available(); } - - #[tokio::test] - async fn test_initial_state() { - let service = OtgService::new(); - let state = service.state().await; - assert!(!state.gadget_active); - assert!(!state.hid_enabled); - assert!(!state.msd_enabled); - } } diff --git a/src/rtsp/auth.rs b/src/rtsp/auth.rs new file mode 100644 index 00000000..d1ccac9e --- /dev/null +++ b/src/rtsp/auth.rs @@ -0,0 +1,73 @@ +use base64::Engine; + +use crate::config::RtspConfig; + +use super::types::RtspRequest; + +pub(crate) fn extract_basic_auth(req: &RtspRequest) -> Option<(String, String)> { + let value = req.headers.get("authorization")?; + let mut parts = value.split_whitespace(); + let scheme = parts.next()?; + if !scheme.eq_ignore_ascii_case("basic") { + return None; + } + let b64 = parts.next()?; + let decoded = base64::engine::general_purpose::STANDARD.decode(b64).ok()?; + let raw = String::from_utf8(decoded).ok()?; + let (user, pass) = raw.split_once(':')?; + Some((user.to_string(), pass.to_string())) +} + +pub(crate) fn rtsp_auth_credentials(config: &RtspConfig) -> Option<(String, String)> { + let username = config.username.as_ref()?.trim(); + if username.is_empty() { + return None; + } + + Some(( + username.to_string(), + config.password.clone().unwrap_or_default(), + )) +} + +#[cfg(test)] +mod tests { + use super::*; + use rtsp_types as rtsp; + use std::collections::HashMap; + + #[test] + fn rtsp_auth_requires_non_empty_username() { + let mut config = RtspConfig::default(); + config.password = Some("secret".to_string()); + assert!(rtsp_auth_credentials(&config).is_none()); + + config.username = Some("".to_string()); + assert!(rtsp_auth_credentials(&config).is_none()); + + config.username = Some("user".to_string()); + let credentials = rtsp_auth_credentials(&config).expect("expected credentials"); + assert_eq!(credentials, ("user".to_string(), "secret".to_string())); + + config.password = None; + let credentials = rtsp_auth_credentials(&config).expect("expected credentials"); + assert_eq!(credentials, ("user".to_string(), "".to_string())); + } + + #[test] + fn extract_basic_auth_roundtrip() { + let encoded = base64::engine::general_purpose::STANDARD.encode(b"alice:pwd"); + let mut headers = HashMap::new(); + headers.insert("authorization".to_string(), format!("Basic {}", encoded)); + let req = RtspRequest { + method: rtsp::Method::Options, + uri: "*".to_string(), + version: rtsp::Version::V1_0, + headers, + }; + assert_eq!( + extract_basic_auth(&req), + Some(("alice".to_string(), "pwd".to_string())) + ); + } +} diff --git a/src/rtsp/bitstream.rs b/src/rtsp/bitstream.rs new file mode 100644 index 00000000..0da15e9b --- /dev/null +++ b/src/rtsp/bitstream.rs @@ -0,0 +1,96 @@ +use bytes::Bytes; + +use crate::video::encoder::registry::VideoEncoderType; +use crate::video::shared_video_pipeline::EncodedVideoFrame; + +use super::state::ParameterSets; + +pub(crate) fn update_parameter_sets(params: &mut ParameterSets, frame: &EncodedVideoFrame) { + let nal_units = split_annexb_nal_units(frame.data.as_ref()); + + match frame.codec { + VideoEncoderType::H264 => { + for nal in nal_units { + match h264_nal_type(nal) { + Some(7) => params.h264_sps = Some(Bytes::copy_from_slice(nal)), + Some(8) => params.h264_pps = Some(Bytes::copy_from_slice(nal)), + _ => {} + } + } + } + VideoEncoderType::H265 => { + for nal in nal_units { + match h265_nal_type(nal) { + Some(32) => params.h265_vps = Some(Bytes::copy_from_slice(nal)), + Some(33) => params.h265_sps = Some(Bytes::copy_from_slice(nal)), + Some(34) => params.h265_pps = Some(Bytes::copy_from_slice(nal)), + _ => {} + } + } + } + _ => {} + } +} + +fn split_annexb_nal_units(data: &[u8]) -> Vec<&[u8]> { + let mut nal_units = Vec::new(); + let mut cursor = 0usize; + + while let Some((start, start_code_len)) = find_annexb_start_code(data, cursor) { + let nal_start = start + start_code_len; + if nal_start >= data.len() { + break; + } + + let next_start = find_annexb_start_code(data, nal_start) + .map(|(idx, _)| idx) + .unwrap_or(data.len()); + + let mut nal_end = next_start; + while nal_end > nal_start && data[nal_end - 1] == 0 { + nal_end -= 1; + } + + if nal_end > nal_start { + nal_units.push(&data[nal_start..nal_end]); + } + + cursor = next_start; + } + + nal_units +} + +fn find_annexb_start_code(data: &[u8], from: usize) -> Option<(usize, usize)> { + if from >= data.len() { + return None; + } + + let mut i = from; + while i + 3 <= data.len() { + if i + 4 <= data.len() + && data[i] == 0 + && data[i + 1] == 0 + && data[i + 2] == 0 + && data[i + 3] == 1 + { + return Some((i, 4)); + } + + if data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 { + return Some((i, 3)); + } + + i += 1; + } + + None +} + +fn h264_nal_type(nal: &[u8]) -> Option { + nal.first().map(|value| value & 0x1f) +} + +fn h265_nal_type(nal: &[u8]) -> Option { + nal.first().map(|value| (value >> 1) & 0x3f) +} diff --git a/src/rtsp/codec.rs b/src/rtsp/codec.rs new file mode 100644 index 00000000..31c00818 --- /dev/null +++ b/src/rtsp/codec.rs @@ -0,0 +1,9 @@ +use crate::config::RtspCodec; +use crate::video::encoder::VideoCodecType; + +pub(crate) fn rtsp_codec_to_video(codec: RtspCodec) -> VideoCodecType { + match codec { + RtspCodec::H264 => VideoCodecType::H264, + RtspCodec::H265 => VideoCodecType::H265, + } +} diff --git a/src/rtsp/mod.rs b/src/rtsp/mod.rs index b8cfd8b5..907c6a18 100644 --- a/src/rtsp/mod.rs +++ b/src/rtsp/mod.rs @@ -1,3 +1,14 @@ -pub mod service; +//! RTSP TCP server exposing H.264/H.265 video from [`VideoStreamManager`](crate::video::VideoStreamManager). + +mod auth; +mod bitstream; +mod codec; +mod protocol; +mod response; +mod sdp; +mod service; +mod state; +mod streaming; +mod types; pub use service::{RtspService, RtspServiceStatus}; diff --git a/src/rtsp/protocol.rs b/src/rtsp/protocol.rs new file mode 100644 index 00000000..a6bdf936 --- /dev/null +++ b/src/rtsp/protocol.rs @@ -0,0 +1,193 @@ +use rtsp_types as rtsp; +use std::collections::HashMap; + +use super::types::RtspRequest; + +pub(crate) const OPTIONS_PUBLIC_CAPABILITIES: &str = + "OPTIONS, DESCRIBE, SETUP, PLAY, GET_PARAMETER, SET_PARAMETER, TEARDOWN"; + +pub(crate) fn strip_interleaved_frames_prefix(buffer: &mut Vec) -> bool { + if buffer.len() < 4 || buffer[0] != b'$' { + return false; + } + + let payload_len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize; + let frame_len = 4 + payload_len; + if buffer.len() < frame_len { + return false; + } + + buffer.drain(0..frame_len); + true +} + +pub(crate) fn take_rtsp_request_from_buffer(buffer: &mut Vec) -> Option { + let delimiter = b"\r\n\r\n"; + let pos = find_bytes(buffer, delimiter)?; + let req_end = pos + delimiter.len(); + let req_bytes: Vec = buffer.drain(0..req_end).collect(); + Some(String::from_utf8_lossy(&req_bytes).to_string()) +} + +fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option { + haystack + .windows(needle.len()) + .position(|window| window == needle) +} + +pub(crate) fn parse_rtsp_request(raw: &str) -> Option { + let (message, consumed): (rtsp::Message>, usize) = + rtsp::Message::parse(raw.as_bytes()).ok()?; + if consumed != raw.len() { + return None; + } + + let request = match message { + rtsp::Message::Request(req) => req, + _ => return None, + }; + + let uri = request + .request_uri() + .map(|value| value.as_str().to_string()) + .unwrap_or_default(); + + let mut headers = HashMap::new(); + for (name, value) in request.headers() { + headers.insert(name.to_string().to_ascii_lowercase(), value.to_string()); + } + + Some(RtspRequest { + method: request.method().clone(), + uri, + version: request.version(), + headers, + }) +} + +pub(crate) fn parse_interleaved_channel(transport: &str) -> Option { + let lower = transport.to_ascii_lowercase(); + if let Some((_, v)) = lower.split_once("interleaved=") { + let head = v.split(';').next().unwrap_or(v); + let first = head.split('-').next().unwrap_or(head).trim(); + return first.parse::().ok(); + } + None +} + +pub(crate) fn is_tcp_transport_request(transport: &str) -> bool { + transport + .split(',') + .map(str::trim) + .map(str::to_ascii_lowercase) + .any(|item| item.contains("rtp/avp/tcp") || item.contains("interleaved=")) +} + +pub(crate) fn is_valid_rtsp_path(method: &rtsp::Method, uri: &str, configured_path: &str) -> bool { + if matches!(method, rtsp::Method::Options) && uri.trim() == "*" { + return true; + } + + let normalized_cfg = configured_path.trim_matches('/'); + if normalized_cfg.is_empty() { + return false; + } + + let request_path = extract_rtsp_path(uri); + + if request_path == normalized_cfg { + return true; + } + + if !matches!(method, rtsp::Method::Setup | rtsp::Method::Teardown) { + return false; + } + + let control_track_path = format!("{}/trackID=0", normalized_cfg); + request_path == "trackID=0" || request_path == control_track_path +} + +fn extract_rtsp_path(uri: &str) -> String { + let raw_path = if let Some((_, remainder)) = uri.split_once("://") { + match remainder.find('/') { + Some(idx) => &remainder[idx..], + None => "/", + } + } else { + uri + }; + + raw_path + .split('?') + .next() + .unwrap_or(raw_path) + .split('#') + .next() + .unwrap_or(raw_path) + .trim_matches('/') + .to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rtsp_path_matching_follows_sdp_control_rules() { + assert!(is_valid_rtsp_path( + &rtsp::Method::Describe, + "rtsp://127.0.0.1/live", + "live" + )); + assert!(is_valid_rtsp_path( + &rtsp::Method::Describe, + "rtsp://127.0.0.1/live/?token=1", + "/live/" + )); + assert!(!is_valid_rtsp_path( + &rtsp::Method::Describe, + "rtsp://127.0.0.1/live2", + "live" + )); + assert!(!is_valid_rtsp_path( + &rtsp::Method::Describe, + "rtsp://127.0.0.1/", + "/" + )); + + assert!(is_valid_rtsp_path( + &rtsp::Method::Setup, + "rtsp://127.0.0.1/live/trackID=0", + "live" + )); + assert!(is_valid_rtsp_path( + &rtsp::Method::Setup, + "rtsp://127.0.0.1/trackID=0", + "live" + )); + assert!(!is_valid_rtsp_path( + &rtsp::Method::Describe, + "rtsp://127.0.0.1/live/trackID=0", + "live" + )); + + assert!(is_valid_rtsp_path(&rtsp::Method::Options, "*", "live")); + } + + #[test] + fn transport_parsing_detects_tcp_interleaved_requests() { + assert!(is_tcp_transport_request( + "RTP/AVP/TCP;unicast;interleaved=0-1" + )); + assert!(is_tcp_transport_request("RTP/AVP;unicast;interleaved=2-3")); + assert!(!is_tcp_transport_request( + "RTP/AVP;unicast;client_port=8000-8001" + )); + } + + #[test] + fn options_public_includes_standard_methods() { + assert!(OPTIONS_PUBLIC_CAPABILITIES.contains("GET_PARAMETER")); + assert!(OPTIONS_PUBLIC_CAPABILITIES.contains("TEARDOWN")); + } +} diff --git a/src/rtsp/response.rs b/src/rtsp/response.rs new file mode 100644 index 00000000..2b93be78 --- /dev/null +++ b/src/rtsp/response.rs @@ -0,0 +1,81 @@ +use rtsp_types as rtsp; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +use crate::error::{AppError, Result}; + +use super::types::RtspRequest; + +async fn serialize_and_write( + stream: &mut W, + response: rtsp::Response>, +) -> Result<()> { + let mut data = Vec::new(); + response + .write(&mut data) + .map_err(|e| AppError::BadRequest(format!("failed to serialize RTSP response: {}", e)))?; + stream.write_all(&data).await?; + Ok(()) +} + +pub(crate) async fn send_simple_response( + stream: &mut W, + code: u16, + _reason: &str, + cseq: Option<&str>, + body: &str, +) -> Result<()> { + let mut builder = rtsp::Response::builder(rtsp::Version::V1_0, status_code_from_u16(code)); + if let Some(cseq) = cseq { + builder = builder.header(rtsp::headers::CSEQ, cseq); + } + + let response = builder.build(body.as_bytes().to_vec()); + serialize_and_write(stream, response).await +} + +pub(crate) async fn send_response( + stream: &mut W, + req: &RtspRequest, + code: u16, + _reason: &str, + extra_headers: Vec<(String, String)>, + body: &str, + session_id: &str, +) -> Result<()> { + let cseq = req + .headers + .get("cseq") + .cloned() + .unwrap_or_else(|| "1".to_string()); + + let mut builder = rtsp::Response::builder(req.version, status_code_from_u16(code)) + .header(rtsp::headers::CSEQ, cseq.as_str()); + + if !session_id.is_empty() { + builder = builder.header(rtsp::headers::SESSION, session_id); + } + + for (name, value) in extra_headers { + let header_name = rtsp::HeaderName::try_from(name.as_str()).map_err(|e| { + AppError::BadRequest(format!("invalid RTSP header name {}: {}", name, e)) + })?; + builder = builder.header(header_name, value); + } + + let response = builder.build(body.as_bytes().to_vec()); + serialize_and_write(stream, response).await +} + +pub(crate) fn status_code_from_u16(code: u16) -> rtsp::StatusCode { + match code { + 200 => rtsp::StatusCode::Ok, + 400 => rtsp::StatusCode::BadRequest, + 401 => rtsp::StatusCode::Unauthorized, + 404 => rtsp::StatusCode::NotFound, + 405 => rtsp::StatusCode::MethodNotAllowed, + 453 => rtsp::StatusCode::NotEnoughBandwidth, + 455 => rtsp::StatusCode::MethodNotValidInThisState, + 461 => rtsp::StatusCode::UnsupportedTransport, + _ => rtsp::StatusCode::InternalServerError, + } +} diff --git a/src/rtsp/sdp.rs b/src/rtsp/sdp.rs new file mode 100644 index 00000000..61acf57e --- /dev/null +++ b/src/rtsp/sdp.rs @@ -0,0 +1,224 @@ +use base64::Engine; +use sdp_types as sdp; + +use crate::config::RtspConfig; +use crate::video::encoder::VideoCodecType; +use crate::webrtc::rtp::parse_profile_level_id_from_sps; + +use super::state::ParameterSets; + +pub(crate) fn build_h264_fmtp(payload_type: u8, params: &ParameterSets) -> String { + let mut attrs = vec!["packetization-mode=1".to_string()]; + + if let Some(sps) = params.h264_sps.as_ref() { + if let Some(profile_level_id) = parse_profile_level_id_from_sps(sps) { + attrs.push(format!("profile-level-id={}", profile_level_id)); + } + } else { + attrs.push("profile-level-id=42e01f".to_string()); + } + + if let (Some(sps), Some(pps)) = (params.h264_sps.as_ref(), params.h264_pps.as_ref()) { + let sps_b64 = base64::engine::general_purpose::STANDARD.encode(sps.as_ref()); + let pps_b64 = base64::engine::general_purpose::STANDARD.encode(pps.as_ref()); + attrs.push(format!("sprop-parameter-sets={},{}", sps_b64, pps_b64)); + } + + format!("{} {}", payload_type, attrs.join(";")) +} + +pub(crate) fn build_h265_fmtp(payload_type: u8, params: &ParameterSets) -> String { + let mut attrs = Vec::new(); + + if let Some(vps) = params.h265_vps.as_ref() { + attrs.push(format!( + "sprop-vps={}", + base64::engine::general_purpose::STANDARD.encode(vps.as_ref()) + )); + } + + if let Some(sps) = params.h265_sps.as_ref() { + attrs.push(format!( + "sprop-sps={}", + base64::engine::general_purpose::STANDARD.encode(sps.as_ref()) + )); + } + + if let Some(pps) = params.h265_pps.as_ref() { + attrs.push(format!( + "sprop-pps={}", + base64::engine::general_purpose::STANDARD.encode(pps.as_ref()) + )); + } + + if attrs.is_empty() { + format!("{} profile-id=1", payload_type) + } else { + format!("{} {}", payload_type, attrs.join(";")) + } +} + +pub(crate) fn build_sdp( + config: &RtspConfig, + codec: VideoCodecType, + params: &ParameterSets, +) -> String { + let (payload_type, codec_name, fmtp_value) = match codec { + VideoCodecType::H264 => (96u8, "H264", build_h264_fmtp(96, params)), + VideoCodecType::H265 => (99u8, "H265", build_h265_fmtp(99, params)), + _ => { + tracing::warn!("RTSP SDP: unexpected VideoCodecType, falling back to H264"); + (96u8, "H264", build_h264_fmtp(96, params)) + } + }; + + let session = sdp::Session { + origin: sdp::Origin { + username: Some("-".to_string()), + sess_id: "0".to_string(), + sess_version: 0, + nettype: "IN".to_string(), + addrtype: "IP4".to_string(), + unicast_address: config.bind.clone(), + }, + session_name: "One-KVM RTSP Stream".to_string(), + session_description: None, + uri: None, + emails: Vec::new(), + phones: Vec::new(), + connection: Some(sdp::Connection { + nettype: "IN".to_string(), + addrtype: "IP4".to_string(), + connection_address: "0.0.0.0".to_string(), + }), + bandwidths: Vec::new(), + times: vec![sdp::Time { + start_time: 0, + stop_time: 0, + repeats: Vec::new(), + }], + time_zones: Vec::new(), + key: None, + attributes: vec![sdp::Attribute { + attribute: "control".to_string(), + value: Some("*".to_string()), + }], + medias: vec![sdp::Media { + media: "video".to_string(), + port: 0, + num_ports: None, + proto: "RTP/AVP".to_string(), + fmt: payload_type.to_string(), + media_title: None, + connections: Vec::new(), + bandwidths: Vec::new(), + key: None, + attributes: vec![ + sdp::Attribute { + attribute: "rtpmap".to_string(), + value: Some(format!("{} {}/90000", payload_type, codec_name)), + }, + sdp::Attribute { + attribute: "fmtp".to_string(), + value: Some(fmtp_value), + }, + sdp::Attribute { + attribute: "control".to_string(), + value: Some("trackID=0".to_string()), + }, + ], + }], + }; + + let mut output = Vec::new(); + if let Err(err) = session.write(&mut output) { + tracing::warn!("Failed to serialize SDP with sdp-types: {}", err); + return String::new(); + } + + match String::from_utf8(output) { + Ok(sdp_text) => sdp_text, + Err(err) => { + tracing::warn!("Failed to convert SDP bytes to UTF-8: {}", err); + String::new() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::RtspConfig; + use bytes::Bytes; + + #[test] + fn build_sdp_h264_is_parseable_with_expected_video_attributes() { + let config = RtspConfig::default(); + let mut params = ParameterSets::default(); + params.h264_sps = Some(Bytes::from_static(&[0x67, 0x42, 0xe0, 0x1f, 0x96, 0x54])); + params.h264_pps = Some(Bytes::from_static(&[0x68, 0xce, 0x06, 0xe2])); + + let sdp_text = build_sdp(&config, VideoCodecType::H264, ¶ms); + assert!(!sdp_text.is_empty()); + + let session = sdp::Session::parse(sdp_text.as_bytes()).expect("sdp parse failed"); + assert_eq!(session.session_name, "One-KVM RTSP Stream"); + assert_eq!(session.medias.len(), 1); + + let media = &session.medias[0]; + assert_eq!(media.media, "video"); + assert_eq!(media.proto, "RTP/AVP"); + assert_eq!(media.fmt, "96"); + + let has_rtpmap = media.attributes.iter().any(|attr| { + attr.attribute == "rtpmap" && attr.value.as_deref() == Some("96 H264/90000") + }); + assert!(has_rtpmap); + + let fmtp_value = media + .attributes + .iter() + .find(|attr| attr.attribute == "fmtp") + .and_then(|attr| attr.value.as_deref()) + .expect("missing fmtp value"); + assert!(fmtp_value.starts_with("96 ")); + assert!(fmtp_value.contains("packetization-mode=1")); + assert!(fmtp_value.contains("sprop-parameter-sets=")); + } + + #[test] + fn build_sdp_h265_is_parseable_with_expected_video_attributes() { + let config = RtspConfig::default(); + let mut params = ParameterSets::default(); + params.h265_vps = Some(Bytes::from_static(&[0x40, 0x01, 0x0c, 0x01])); + params.h265_sps = Some(Bytes::from_static(&[0x42, 0x01, 0x01, 0x60])); + params.h265_pps = Some(Bytes::from_static(&[0x44, 0x01, 0xc0, 0x73])); + + let sdp_text = build_sdp(&config, VideoCodecType::H265, ¶ms); + assert!(!sdp_text.is_empty()); + + let session = sdp::Session::parse(sdp_text.as_bytes()).expect("sdp parse failed"); + assert_eq!(session.medias.len(), 1); + + let media = &session.medias[0]; + assert_eq!(media.media, "video"); + assert_eq!(media.proto, "RTP/AVP"); + assert_eq!(media.fmt, "99"); + + let has_rtpmap = media.attributes.iter().any(|attr| { + attr.attribute == "rtpmap" && attr.value.as_deref() == Some("99 H265/90000") + }); + assert!(has_rtpmap); + + let fmtp_value = media + .attributes + .iter() + .find(|attr| attr.attribute == "fmtp") + .and_then(|attr| attr.value.as_deref()) + .expect("missing fmtp value"); + assert!(fmtp_value.starts_with("99 ")); + assert!(fmtp_value.contains("sprop-vps=")); + assert!(fmtp_value.contains("sprop-sps=")); + assert!(fmtp_value.contains("sprop-pps=")); + } +} diff --git a/src/rtsp/service.rs b/src/rtsp/service.rs index 245a0b7d..be7526c9 100644 --- a/src/rtsp/service.rs +++ b/src/rtsp/service.rs @@ -1,100 +1,28 @@ -use base64::Engine; -use bytes::Bytes; -use rand::Rng; -use rtp::packet::Packet; -use rtp::packetizer::Payloader; use rtsp_types as rtsp; -use sdp_types as sdp; -use std::collections::HashMap; use std::io; use std::net::SocketAddr; use std::sync::Arc; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::AsyncReadExt; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{broadcast, Mutex, RwLock}; -use tokio::time::{sleep, Duration}; -use webrtc::util::Marshal; -use crate::config::{RtspCodec, RtspConfig}; +use crate::config::RtspConfig; use crate::error::{AppError, Result}; -use crate::video::encoder::registry::VideoEncoderType; -use crate::video::encoder::VideoCodecType; -use crate::video::shared_video_pipeline::EncodedVideoFrame; use crate::video::VideoStreamManager; -use crate::webrtc::h265_payloader::H265Payloader; -use crate::webrtc::rtp::parse_profile_level_id_from_sps; -const RTP_CLOCK_RATE: u32 = 90_000; -const RTP_MTU: usize = 1200; -const RTSP_BUF_SIZE: usize = 8192; -const RTSP_RESUBSCRIBE_DELAY_MS: u64 = 300; +use super::auth::{extract_basic_auth, rtsp_auth_credentials}; +use super::codec::rtsp_codec_to_video; +use super::protocol::{ + is_tcp_transport_request, is_valid_rtsp_path, parse_interleaved_channel, parse_rtsp_request, + take_rtsp_request_from_buffer, OPTIONS_PUBLIC_CAPABILITIES, +}; +use super::response::{send_response, send_simple_response}; +use super::sdp::build_sdp; +use super::state::SharedRtspState; +use super::streaming::{stream_video_interleaved, RTSP_BUF_SIZE}; +use super::types::RtspConnectionState; -#[derive(Debug, Clone, PartialEq)] -pub enum RtspServiceStatus { - Stopped, - Starting, - Running, - Error(String), -} - -impl std::fmt::Display for RtspServiceStatus { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Stopped => write!(f, "stopped"), - Self::Starting => write!(f, "starting"), - Self::Running => write!(f, "running"), - Self::Error(err) => write!(f, "error: {}", err), - } - } -} - -#[derive(Debug, Clone)] -struct RtspRequest { - method: rtsp::Method, - uri: String, - version: rtsp::Version, - headers: HashMap, -} - -struct RtspConnectionState { - session_id: String, - setup_done: bool, - interleaved_channel: u8, -} - -impl RtspConnectionState { - fn new() -> Self { - Self { - session_id: generate_session_id(), - setup_done: false, - interleaved_channel: 0, - } - } -} - -#[derive(Default, Clone)] -struct ParameterSets { - h264_sps: Option, - h264_pps: Option, - h265_vps: Option, - h265_sps: Option, - h265_pps: Option, -} - -#[derive(Clone)] -struct SharedRtspState { - active_client: Arc>>, - parameter_sets: Arc>, -} - -impl SharedRtspState { - fn new() -> Self { - Self { - active_client: Arc::new(Mutex::new(None)), - parameter_sets: Arc::new(RwLock::new(ParameterSets::default())), - } - } -} +pub use super::types::RtspServiceStatus; pub struct RtspService { config: Arc>, @@ -133,10 +61,7 @@ impl RtspService { *self.status.write().await = RtspServiceStatus::Starting; - let codec = match config.codec { - RtspCodec::H264 => VideoCodecType::H264, - RtspCodec::H265 => VideoCodecType::H265, - }; + let codec = rtsp_codec_to_video(config.codec); if let Err(err) = self.video_manager.set_video_codec(codec).await { let message = format!("failed to set codec before RTSP start: {}", err); @@ -324,7 +249,7 @@ async fn handle_client( "OK", vec![( "Public".to_string(), - "OPTIONS, DESCRIBE, SETUP, PLAY, TEARDOWN".to_string(), + OPTIONS_PUBLIC_CAPABILITIES.to_string(), )], "", "", @@ -332,10 +257,7 @@ async fn handle_client( .await?; } rtsp::Method::Describe => { - let codec = match cfg_snapshot.codec { - RtspCodec::H264 => VideoCodecType::H264, - RtspCodec::H265 => VideoCodecType::H265, - }; + let codec = rtsp_codec_to_video(cfg_snapshot.codec.clone()); let params = shared.parameter_sets.read().await.clone(); let sdp = build_sdp(&cfg_snapshot, codec, ¶ms); if sdp.is_empty() { @@ -467,925 +389,3 @@ async fn handle_client( Ok(()) } - -async fn stream_video_interleaved( - stream: TcpStream, - video_manager: &Arc, - rtsp_codec: RtspCodec, - channel: u8, - shared: SharedRtspState, - session_id: String, -) -> Result<()> { - let (mut reader, mut writer) = stream.into_split(); - - let mut rx = video_manager - .subscribe_encoded_frames() - .await - .ok_or_else(|| { - AppError::VideoError("RTSP failed to subscribe encoded frames".to_string()) - })?; - - video_manager.request_keyframe().await.ok(); - - let payload_type = match rtsp_codec { - RtspCodec::H264 => 96, - RtspCodec::H265 => 99, - }; - let mut sequence_number: u16 = rand::rng().random(); - let ssrc: u32 = rand::rng().random(); - - let mut h264_payloader = rtp::codecs::h264::H264Payloader::default(); - let mut h265_payloader = H265Payloader::new(); - let mut ctrl_read_buf = [0u8; RTSP_BUF_SIZE]; - let mut ctrl_buffer = Vec::with_capacity(RTSP_BUF_SIZE); - // RTP timestamps must increase; pts_ms is often 0 for many frames (capture→encode jitter), - // which yields a flat RTP timestamp and breaks VLC/ffplay. - let mut last_rtp_timestamp: u32 = 0; - - loop { - tokio::select! { - maybe_frame = rx.recv() => { - let Some(frame) = maybe_frame else { - tracing::warn!("RTSP encoded frame subscription ended, attempting to restart pipeline"); - - if let Some(new_rx) = video_manager.subscribe_encoded_frames().await { - rx = new_rx; - let _ = video_manager.request_keyframe().await; - tracing::info!("RTSP frame subscription recovered"); - } else { - tracing::warn!( - "RTSP failed to resubscribe encoded frames, retrying in {}ms", - RTSP_RESUBSCRIBE_DELAY_MS - ); - sleep(Duration::from_millis(RTSP_RESUBSCRIBE_DELAY_MS)).await; - } - - continue; - }; - - if !is_frame_codec_match(&frame, &rtsp_codec) { - continue; - } - - { - let mut params = shared.parameter_sets.write().await; - update_parameter_sets(&mut params, &frame); - } - - let rtp_timestamp = monotonic_rtp_timestamp( - frame.pts_ms, - &mut last_rtp_timestamp, - frame.duration, - ); - - let payloads: Vec = match rtsp_codec { - RtspCodec::H264 => h264_payloader - .payload(RTP_MTU, &frame.data) - .map_err(|e| AppError::VideoError(format!("H264 payload failed: {}", e)))?, - RtspCodec::H265 => h265_payloader.payload(RTP_MTU, &frame.data), - }; - - if payloads.is_empty() { - continue; - } - - let total_payloads = payloads.len(); - for (idx, payload) in payloads.into_iter().enumerate() { - let marker = idx == total_payloads.saturating_sub(1); - let packet = Packet { - header: rtp::header::Header { - version: 2, - padding: false, - extension: false, - marker, - payload_type, - sequence_number, - timestamp: rtp_timestamp, - ssrc, - ..Default::default() - }, - payload, - }; - - sequence_number = sequence_number.wrapping_add(1); - send_interleaved_rtp(&mut writer, channel, &packet).await?; - } - - if frame.is_keyframe { - tracing::debug!("RTSP keyframe sent"); - } - } - read_res = reader.read(&mut ctrl_read_buf) => { - let n = read_res?; - if n == 0 { - break; - } - - ctrl_buffer.extend_from_slice(&ctrl_read_buf[..n]); - - while strip_interleaved_frames_prefix(&mut ctrl_buffer) {} - - while let Some(raw_req) = take_rtsp_request_from_buffer(&mut ctrl_buffer) { - let Some(req) = parse_rtsp_request(&raw_req) else { - continue; - }; - - if handle_play_control_request(&mut writer, &req, &session_id).await? { - return Ok(()); - } - - while strip_interleaved_frames_prefix(&mut ctrl_buffer) {} - } - } - } - } - - Ok(()) -} - -async fn send_interleaved_rtp( - stream: &mut W, - channel: u8, - packet: &Packet, -) -> Result<()> { - let marshaled = packet - .marshal() - .map_err(|e| AppError::VideoError(format!("RTP marshal failed: {}", e)))?; - let len = marshaled.len() as u16; - - let mut header = [0u8; 4]; - header[0] = b'$'; - header[1] = channel; - header[2] = (len >> 8) as u8; - header[3] = (len & 0xff) as u8; - - stream.write_all(&header).await?; - stream.write_all(&marshaled).await?; - Ok(()) -} - -async fn handle_play_control_request( - stream: &mut W, - req: &RtspRequest, - session_id: &str, -) -> Result { - match &req.method { - rtsp::Method::Teardown => { - send_response(stream, req, 200, "OK", vec![], "", session_id).await?; - Ok(true) - } - rtsp::Method::Options => { - send_response( - stream, - req, - 200, - "OK", - vec![( - "Public".to_string(), - "OPTIONS, DESCRIBE, SETUP, PLAY, GET_PARAMETER, SET_PARAMETER, TEARDOWN" - .to_string(), - )], - "", - session_id, - ) - .await?; - Ok(false) - } - rtsp::Method::GetParameter | rtsp::Method::SetParameter => { - send_response(stream, req, 200, "OK", vec![], "", session_id).await?; - Ok(false) - } - _ => { - send_response( - stream, - req, - 405, - "Method Not Allowed", - vec![], - "", - session_id, - ) - .await?; - Ok(false) - } - } -} - -fn strip_interleaved_frames_prefix(buffer: &mut Vec) -> bool { - if buffer.len() < 4 || buffer[0] != b'$' { - return false; - } - - let payload_len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize; - let frame_len = 4 + payload_len; - if buffer.len() < frame_len { - return false; - } - - buffer.drain(0..frame_len); - true -} - -fn take_rtsp_request_from_buffer(buffer: &mut Vec) -> Option { - let delimiter = b"\r\n\r\n"; - let pos = find_bytes(buffer, delimiter)?; - let req_end = pos + delimiter.len(); - let req_bytes: Vec = buffer.drain(0..req_end).collect(); - Some(String::from_utf8_lossy(&req_bytes).to_string()) -} - -fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option { - haystack - .windows(needle.len()) - .position(|window| window == needle) -} - -fn parse_rtsp_request(raw: &str) -> Option { - let (message, consumed): (rtsp::Message>, usize) = - rtsp::Message::parse(raw.as_bytes()).ok()?; - if consumed != raw.len() { - return None; - } - - let request = match message { - rtsp::Message::Request(req) => req, - _ => return None, - }; - - let uri = request - .request_uri() - .map(|value| value.as_str().to_string()) - .unwrap_or_default(); - - let mut headers = HashMap::new(); - for (name, value) in request.headers() { - headers.insert(name.to_string().to_ascii_lowercase(), value.to_string()); - } - - Some(RtspRequest { - method: request.method().clone(), - uri, - version: request.version(), - headers, - }) -} - -fn extract_basic_auth(req: &RtspRequest) -> Option<(String, String)> { - let value = req.headers.get("authorization")?; - let mut parts = value.split_whitespace(); - let scheme = parts.next()?; - if !scheme.eq_ignore_ascii_case("basic") { - return None; - } - let b64 = parts.next()?; - let decoded = base64::engine::general_purpose::STANDARD.decode(b64).ok()?; - let raw = String::from_utf8(decoded).ok()?; - let (user, pass) = raw.split_once(':')?; - Some((user.to_string(), pass.to_string())) -} - -fn rtsp_auth_credentials(config: &RtspConfig) -> Option<(String, String)> { - let username = config.username.as_ref()?.trim(); - if username.is_empty() { - return None; - } - - Some(( - username.to_string(), - config.password.clone().unwrap_or_default(), - )) -} - -fn parse_interleaved_channel(transport: &str) -> Option { - let lower = transport.to_ascii_lowercase(); - if let Some((_, v)) = lower.split_once("interleaved=") { - let head = v.split(';').next().unwrap_or(v); - let first = head.split('-').next().unwrap_or(head).trim(); - return first.parse::().ok(); - } - None -} - -fn is_tcp_transport_request(transport: &str) -> bool { - transport - .split(',') - .map(str::trim) - .map(str::to_ascii_lowercase) - .any(|item| item.contains("rtp/avp/tcp") || item.contains("interleaved=")) -} - -fn update_parameter_sets(params: &mut ParameterSets, frame: &EncodedVideoFrame) { - let nal_units = split_annexb_nal_units(frame.data.as_ref()); - - match frame.codec { - VideoEncoderType::H264 => { - for nal in nal_units { - match h264_nal_type(nal) { - Some(7) => params.h264_sps = Some(Bytes::copy_from_slice(nal)), - Some(8) => params.h264_pps = Some(Bytes::copy_from_slice(nal)), - _ => {} - } - } - } - VideoEncoderType::H265 => { - for nal in nal_units { - match h265_nal_type(nal) { - Some(32) => params.h265_vps = Some(Bytes::copy_from_slice(nal)), - Some(33) => params.h265_sps = Some(Bytes::copy_from_slice(nal)), - Some(34) => params.h265_pps = Some(Bytes::copy_from_slice(nal)), - _ => {} - } - } - } - _ => {} - } -} - -fn split_annexb_nal_units(data: &[u8]) -> Vec<&[u8]> { - let mut nal_units = Vec::new(); - let mut cursor = 0usize; - - while let Some((start, start_code_len)) = find_annexb_start_code(data, cursor) { - let nal_start = start + start_code_len; - if nal_start >= data.len() { - break; - } - - let next_start = find_annexb_start_code(data, nal_start) - .map(|(idx, _)| idx) - .unwrap_or(data.len()); - - let mut nal_end = next_start; - while nal_end > nal_start && data[nal_end - 1] == 0 { - nal_end -= 1; - } - - if nal_end > nal_start { - nal_units.push(&data[nal_start..nal_end]); - } - - cursor = next_start; - } - - nal_units -} - -fn find_annexb_start_code(data: &[u8], from: usize) -> Option<(usize, usize)> { - if from >= data.len() { - return None; - } - - let mut i = from; - while i + 3 <= data.len() { - if i + 4 <= data.len() - && data[i] == 0 - && data[i + 1] == 0 - && data[i + 2] == 0 - && data[i + 3] == 1 - { - return Some((i, 4)); - } - - if data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 { - return Some((i, 3)); - } - - i += 1; - } - - None -} - -fn h264_nal_type(nal: &[u8]) -> Option { - nal.first().map(|value| value & 0x1f) -} - -fn h265_nal_type(nal: &[u8]) -> Option { - nal.first().map(|value| (value >> 1) & 0x3f) -} - -fn build_h264_fmtp(payload_type: u8, params: &ParameterSets) -> String { - let mut attrs = vec!["packetization-mode=1".to_string()]; - - if let Some(sps) = params.h264_sps.as_ref() { - if let Some(profile_level_id) = parse_profile_level_id_from_sps(sps) { - attrs.push(format!("profile-level-id={}", profile_level_id)); - } - } else { - attrs.push("profile-level-id=42e01f".to_string()); - } - - if let (Some(sps), Some(pps)) = (params.h264_sps.as_ref(), params.h264_pps.as_ref()) { - let sps_b64 = base64::engine::general_purpose::STANDARD.encode(sps.as_ref()); - let pps_b64 = base64::engine::general_purpose::STANDARD.encode(pps.as_ref()); - attrs.push(format!("sprop-parameter-sets={},{}", sps_b64, pps_b64)); - } - - format!("{} {}", payload_type, attrs.join(";")) -} - -fn build_h265_fmtp(payload_type: u8, params: &ParameterSets) -> String { - let mut attrs = Vec::new(); - - if let Some(vps) = params.h265_vps.as_ref() { - attrs.push(format!( - "sprop-vps={}", - base64::engine::general_purpose::STANDARD.encode(vps.as_ref()) - )); - } - - if let Some(sps) = params.h265_sps.as_ref() { - attrs.push(format!( - "sprop-sps={}", - base64::engine::general_purpose::STANDARD.encode(sps.as_ref()) - )); - } - - if let Some(pps) = params.h265_pps.as_ref() { - attrs.push(format!( - "sprop-pps={}", - base64::engine::general_purpose::STANDARD.encode(pps.as_ref()) - )); - } - - if attrs.is_empty() { - format!("{} profile-id=1", payload_type) - } else { - format!("{} {}", payload_type, attrs.join(";")) - } -} - -fn build_sdp(config: &RtspConfig, codec: VideoCodecType, params: &ParameterSets) -> String { - let (payload_type, codec_name, fmtp_value) = match codec { - VideoCodecType::H264 => (96u8, "H264", build_h264_fmtp(96, params)), - VideoCodecType::H265 => (99u8, "H265", build_h265_fmtp(99, params)), - _ => (96u8, "H264", build_h264_fmtp(96, params)), - }; - - let session = sdp::Session { - origin: sdp::Origin { - username: Some("-".to_string()), - sess_id: "0".to_string(), - sess_version: 0, - nettype: "IN".to_string(), - addrtype: "IP4".to_string(), - unicast_address: config.bind.clone(), - }, - session_name: "One-KVM RTSP Stream".to_string(), - session_description: None, - uri: None, - emails: Vec::new(), - phones: Vec::new(), - connection: Some(sdp::Connection { - nettype: "IN".to_string(), - addrtype: "IP4".to_string(), - connection_address: "0.0.0.0".to_string(), - }), - bandwidths: Vec::new(), - times: vec![sdp::Time { - start_time: 0, - stop_time: 0, - repeats: Vec::new(), - }], - time_zones: Vec::new(), - key: None, - attributes: vec![sdp::Attribute { - attribute: "control".to_string(), - value: Some("*".to_string()), - }], - medias: vec![sdp::Media { - media: "video".to_string(), - port: 0, - num_ports: None, - proto: "RTP/AVP".to_string(), - fmt: payload_type.to_string(), - media_title: None, - connections: Vec::new(), - bandwidths: Vec::new(), - key: None, - attributes: vec![ - sdp::Attribute { - attribute: "rtpmap".to_string(), - value: Some(format!("{} {}/90000", payload_type, codec_name)), - }, - sdp::Attribute { - attribute: "fmtp".to_string(), - value: Some(fmtp_value), - }, - sdp::Attribute { - attribute: "control".to_string(), - value: Some("trackID=0".to_string()), - }, - ], - }], - }; - - let mut output = Vec::new(); - if let Err(err) = session.write(&mut output) { - tracing::warn!("Failed to serialize SDP with sdp-types: {}", err); - return String::new(); - } - - match String::from_utf8(output) { - Ok(sdp_text) => sdp_text, - Err(err) => { - tracing::warn!("Failed to convert SDP bytes to UTF-8: {}", err); - String::new() - } - } -} - -async fn send_simple_response( - stream: &mut W, - code: u16, - _reason: &str, - cseq: Option<&str>, - body: &str, -) -> Result<()> { - let mut builder = rtsp::Response::builder(rtsp::Version::V1_0, status_code_from_u16(code)); - if let Some(cseq) = cseq { - builder = builder.header(rtsp::headers::CSEQ, cseq); - } - - let response = builder.build(body.as_bytes().to_vec()); - - let mut data = Vec::new(); - response - .write(&mut data) - .map_err(|e| AppError::BadRequest(format!("failed to serialize RTSP response: {}", e)))?; - stream.write_all(&data).await?; - Ok(()) -} - -async fn send_response( - stream: &mut W, - req: &RtspRequest, - code: u16, - _reason: &str, - extra_headers: Vec<(String, String)>, - body: &str, - session_id: &str, -) -> Result<()> { - let cseq = req - .headers - .get("cseq") - .cloned() - .unwrap_or_else(|| "1".to_string()); - - let mut builder = rtsp::Response::builder(req.version, status_code_from_u16(code)) - .header(rtsp::headers::CSEQ, cseq.as_str()); - - if !session_id.is_empty() { - builder = builder.header(rtsp::headers::SESSION, session_id); - } - - for (name, value) in extra_headers { - let header_name = rtsp::HeaderName::try_from(name.as_str()).map_err(|e| { - AppError::BadRequest(format!("invalid RTSP header name {}: {}", name, e)) - })?; - builder = builder.header(header_name, value); - } - - let response = builder.build(body.as_bytes().to_vec()); - - let mut data = Vec::new(); - response - .write(&mut data) - .map_err(|e| AppError::BadRequest(format!("failed to serialize RTSP response: {}", e)))?; - stream.write_all(&data).await?; - Ok(()) -} - -fn status_code_from_u16(code: u16) -> rtsp::StatusCode { - match code { - 200 => rtsp::StatusCode::Ok, - 400 => rtsp::StatusCode::BadRequest, - 401 => rtsp::StatusCode::Unauthorized, - 404 => rtsp::StatusCode::NotFound, - 405 => rtsp::StatusCode::MethodNotAllowed, - 453 => rtsp::StatusCode::NotEnoughBandwidth, - 455 => rtsp::StatusCode::MethodNotValidInThisState, - 461 => rtsp::StatusCode::UnsupportedTransport, - _ => rtsp::StatusCode::InternalServerError, - } -} - -fn is_valid_rtsp_path(method: &rtsp::Method, uri: &str, configured_path: &str) -> bool { - if matches!(method, rtsp::Method::Options) && uri.trim() == "*" { - return true; - } - - let normalized_cfg = configured_path.trim_matches('/'); - if normalized_cfg.is_empty() { - return false; - } - - let request_path = extract_rtsp_path(uri); - - if request_path == normalized_cfg { - return true; - } - - if !matches!(method, rtsp::Method::Setup | rtsp::Method::Teardown) { - return false; - } - - let control_track_path = format!("{}/trackID=0", normalized_cfg); - request_path == "trackID=0" || request_path == control_track_path -} - -fn extract_rtsp_path(uri: &str) -> String { - let raw_path = if let Some((_, remainder)) = uri.split_once("://") { - match remainder.find('/') { - Some(idx) => &remainder[idx..], - None => "/", - } - } else { - uri - }; - - raw_path - .split('?') - .next() - .unwrap_or(raw_path) - .split('#') - .next() - .unwrap_or(raw_path) - .trim_matches('/') - .to_string() -} - -fn is_frame_codec_match(frame: &EncodedVideoFrame, codec: &RtspCodec) -> bool { - matches!( - (frame.codec, codec), - ( - crate::video::encoder::registry::VideoEncoderType::H264, - RtspCodec::H264 - ) | ( - crate::video::encoder::registry::VideoEncoderType::H265, - RtspCodec::H265 - ) - ) -} - -fn pts_to_rtp_timestamp(pts_ms: i64) -> u32 { - if pts_ms <= 0 { - return 0; - } - ((pts_ms as u64 * RTP_CLOCK_RATE as u64) / 1000) as u32 -} - -/// 90 kHz ticks per frame from nominal duration (at least 1). -fn rtp_timestamp_increment(frame_duration: Duration) -> u32 { - let inc = (frame_duration.as_secs_f64() * f64::from(RTP_CLOCK_RATE)).round() as u32; - inc.max(1) -} - -/// Prefer PTS-based RTP time when it advances; otherwise step by `frame_duration` in 90 kHz units. -fn monotonic_rtp_timestamp(pts_ms: i64, last: &mut u32, frame_duration: Duration) -> u32 { - let from_pts = pts_to_rtp_timestamp(pts_ms); - let inc = rtp_timestamp_increment(frame_duration); - let ts = if from_pts > *last { - from_pts - } else { - last.wrapping_add(inc) - }; - *last = ts; - ts -} - -fn generate_session_id() -> String { - let mut rng = rand::rng(); - let value: u64 = rng.random(); - format!("{:016x}", value) -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::io::{duplex, AsyncReadExt}; - - fn make_test_request(method: rtsp::Method) -> RtspRequest { - let mut headers = HashMap::new(); - headers.insert("cseq".to_string(), "7".to_string()); - RtspRequest { - method, - uri: "rtsp://127.0.0.1/live".to_string(), - version: rtsp::Version::V1_0, - headers, - } - } - - async fn read_response_from_duplex( - mut client: tokio::io::DuplexStream, - ) -> rtsp::Response> { - let mut buf = vec![0u8; 4096]; - let n = client - .read(&mut buf) - .await - .expect("failed to read rtsp response"); - assert!(n > 0); - let (message, consumed): (rtsp::Message>, usize) = - rtsp::Message::parse(&buf[..n]).expect("failed to parse rtsp response"); - assert_eq!(consumed, n); - - match message { - rtsp::Message::Response(response) => response, - _ => panic!("expected RTSP response"), - } - } - - #[tokio::test] - async fn play_control_teardown_returns_ok_and_stop() { - let req = make_test_request(rtsp::Method::Teardown); - let (client, mut server) = duplex(4096); - - let should_stop = handle_play_control_request(&mut server, &req, "session-1") - .await - .expect("control handling failed"); - assert!(should_stop); - - drop(server); - let response = read_response_from_duplex(client).await; - assert_eq!(response.status(), rtsp::StatusCode::Ok); - } - - #[tokio::test] - async fn play_control_pause_returns_method_not_allowed() { - let req = make_test_request(rtsp::Method::Pause); - let (client, mut server) = duplex(4096); - - let should_stop = handle_play_control_request(&mut server, &req, "session-1") - .await - .expect("control handling failed"); - assert!(!should_stop); - - drop(server); - let response = read_response_from_duplex(client).await; - assert_eq!(response.status(), rtsp::StatusCode::MethodNotAllowed); - } - - #[test] - fn monotonic_rtp_timestamp_steps_when_pts_stays_zero() { - let d = Duration::from_millis(33); - let mut last = 0u32; - let a = monotonic_rtp_timestamp(0, &mut last, d); - let b = monotonic_rtp_timestamp(0, &mut last, d); - let c = monotonic_rtp_timestamp(0, &mut last, d); - assert!(a > 0); - assert!(b > a); - assert!(c > b); - } - - #[test] - fn monotonic_rtp_timestamp_uses_pts_when_it_advances() { - let d = Duration::from_millis(33); - let mut last = 0u32; - let a = monotonic_rtp_timestamp(1000, &mut last, d); - assert_eq!(a, 90_000); - let b = monotonic_rtp_timestamp(2000, &mut last, d); - assert_eq!(b, 180_000); - } - - #[test] - fn build_sdp_h264_is_parseable_with_expected_video_attributes() { - let config = RtspConfig::default(); - let mut params = ParameterSets::default(); - params.h264_sps = Some(Bytes::from_static(&[0x67, 0x42, 0xe0, 0x1f, 0x96, 0x54])); - params.h264_pps = Some(Bytes::from_static(&[0x68, 0xce, 0x06, 0xe2])); - - let sdp_text = build_sdp(&config, VideoCodecType::H264, ¶ms); - assert!(!sdp_text.is_empty()); - - let session = sdp::Session::parse(sdp_text.as_bytes()).expect("sdp parse failed"); - assert_eq!(session.session_name, "One-KVM RTSP Stream"); - assert_eq!(session.medias.len(), 1); - - let media = &session.medias[0]; - assert_eq!(media.media, "video"); - assert_eq!(media.proto, "RTP/AVP"); - assert_eq!(media.fmt, "96"); - - let has_rtpmap = media.attributes.iter().any(|attr| { - attr.attribute == "rtpmap" && attr.value.as_deref() == Some("96 H264/90000") - }); - assert!(has_rtpmap); - - let fmtp_value = media - .attributes - .iter() - .find(|attr| attr.attribute == "fmtp") - .and_then(|attr| attr.value.as_deref()) - .expect("missing fmtp value"); - assert!(fmtp_value.starts_with("96 ")); - assert!(fmtp_value.contains("packetization-mode=1")); - assert!(fmtp_value.contains("sprop-parameter-sets=")); - } - - #[test] - fn rtsp_path_matching_follows_sdp_control_rules() { - assert!(is_valid_rtsp_path( - &rtsp::Method::Describe, - "rtsp://127.0.0.1/live", - "live" - )); - assert!(is_valid_rtsp_path( - &rtsp::Method::Describe, - "rtsp://127.0.0.1/live/?token=1", - "/live/" - )); - assert!(!is_valid_rtsp_path( - &rtsp::Method::Describe, - "rtsp://127.0.0.1/live2", - "live" - )); - assert!(!is_valid_rtsp_path( - &rtsp::Method::Describe, - "rtsp://127.0.0.1/", - "/" - )); - - assert!(is_valid_rtsp_path( - &rtsp::Method::Setup, - "rtsp://127.0.0.1/live/trackID=0", - "live" - )); - assert!(is_valid_rtsp_path( - &rtsp::Method::Setup, - "rtsp://127.0.0.1/trackID=0", - "live" - )); - assert!(!is_valid_rtsp_path( - &rtsp::Method::Describe, - "rtsp://127.0.0.1/live/trackID=0", - "live" - )); - - assert!(is_valid_rtsp_path(&rtsp::Method::Options, "*", "live")); - } - - #[test] - fn transport_parsing_detects_tcp_interleaved_requests() { - assert!(is_tcp_transport_request( - "RTP/AVP/TCP;unicast;interleaved=0-1" - )); - assert!(is_tcp_transport_request("RTP/AVP;unicast;interleaved=2-3")); - assert!(!is_tcp_transport_request( - "RTP/AVP;unicast;client_port=8000-8001" - )); - } - - #[test] - fn build_sdp_h265_is_parseable_with_expected_video_attributes() { - let config = RtspConfig::default(); - let mut params = ParameterSets::default(); - params.h265_vps = Some(Bytes::from_static(&[0x40, 0x01, 0x0c, 0x01])); - params.h265_sps = Some(Bytes::from_static(&[0x42, 0x01, 0x01, 0x60])); - params.h265_pps = Some(Bytes::from_static(&[0x44, 0x01, 0xc0, 0x73])); - - let sdp_text = build_sdp(&config, VideoCodecType::H265, ¶ms); - assert!(!sdp_text.is_empty()); - - let session = sdp::Session::parse(sdp_text.as_bytes()).expect("sdp parse failed"); - assert_eq!(session.medias.len(), 1); - - let media = &session.medias[0]; - assert_eq!(media.media, "video"); - assert_eq!(media.proto, "RTP/AVP"); - assert_eq!(media.fmt, "99"); - - let has_rtpmap = media.attributes.iter().any(|attr| { - attr.attribute == "rtpmap" && attr.value.as_deref() == Some("99 H265/90000") - }); - assert!(has_rtpmap); - - let fmtp_value = media - .attributes - .iter() - .find(|attr| attr.attribute == "fmtp") - .and_then(|attr| attr.value.as_deref()) - .expect("missing fmtp value"); - assert!(fmtp_value.starts_with("99 ")); - assert!(fmtp_value.contains("sprop-vps=")); - assert!(fmtp_value.contains("sprop-sps=")); - assert!(fmtp_value.contains("sprop-pps=")); - } - - #[test] - fn rtsp_auth_requires_non_empty_username() { - let mut config = RtspConfig::default(); - config.password = Some("secret".to_string()); - assert!(rtsp_auth_credentials(&config).is_none()); - - config.username = Some("".to_string()); - assert!(rtsp_auth_credentials(&config).is_none()); - - config.username = Some("user".to_string()); - let credentials = rtsp_auth_credentials(&config).expect("expected credentials"); - assert_eq!(credentials, ("user".to_string(), "secret".to_string())); - - config.password = None; - let credentials = rtsp_auth_credentials(&config).expect("expected credentials"); - assert_eq!(credentials, ("user".to_string(), "".to_string())); - } -} diff --git a/src/rtsp/state.rs b/src/rtsp/state.rs new file mode 100644 index 00000000..0f837c4b --- /dev/null +++ b/src/rtsp/state.rs @@ -0,0 +1,28 @@ +use bytes::Bytes; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::{Mutex, RwLock}; + +#[derive(Default, Clone)] +pub(crate) struct ParameterSets { + pub h264_sps: Option, + pub h264_pps: Option, + pub h265_vps: Option, + pub h265_sps: Option, + pub h265_pps: Option, +} + +#[derive(Clone)] +pub(crate) struct SharedRtspState { + pub active_client: Arc>>, + pub parameter_sets: Arc>, +} + +impl SharedRtspState { + pub fn new() -> Self { + Self { + active_client: Arc::new(Mutex::new(None)), + parameter_sets: Arc::new(RwLock::new(ParameterSets::default())), + } + } +} diff --git a/src/rtsp/streaming.rs b/src/rtsp/streaming.rs new file mode 100644 index 00000000..31f7aca7 --- /dev/null +++ b/src/rtsp/streaming.rs @@ -0,0 +1,367 @@ +use bytes::Bytes; +use rand::Rng; +use rtp::packet::Packet; +use rtp::packetizer::Payloader; +use rtsp_types as rtsp; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::time::{sleep, Duration}; +use webrtc::util::{Marshal, MarshalSize}; + +use crate::config::RtspCodec; +use crate::error::{AppError, Result}; +use crate::video::encoder::registry::VideoEncoderType; +use crate::video::shared_video_pipeline::EncodedVideoFrame; +use crate::video::VideoStreamManager; +use crate::webrtc::h265_payloader::H265Payloader; + +use super::bitstream::update_parameter_sets; +use super::protocol::{ + parse_rtsp_request, strip_interleaved_frames_prefix, take_rtsp_request_from_buffer, +}; +use super::response::send_response; +use super::state::SharedRtspState; +use super::types::RtspRequest; + +pub(crate) const RTP_CLOCK_RATE: u32 = 90_000; +pub(crate) const RTP_MTU: usize = 1200; +pub(crate) const RTSP_BUF_SIZE: usize = 8192; +const RTSP_RESUBSCRIBE_DELAY_MS: u64 = 300; + +pub(crate) async fn stream_video_interleaved( + stream: TcpStream, + video_manager: &Arc, + rtsp_codec: RtspCodec, + channel: u8, + shared: SharedRtspState, + session_id: String, +) -> Result<()> { + let (mut reader, mut writer) = stream.into_split(); + + let mut rx = video_manager + .subscribe_encoded_frames() + .await + .ok_or_else(|| { + AppError::VideoError("RTSP failed to subscribe encoded frames".to_string()) + })?; + + video_manager.request_keyframe().await.ok(); + + let payload_type = match rtsp_codec { + RtspCodec::H264 => 96, + RtspCodec::H265 => 99, + }; + let mut sequence_number: u16 = rand::rng().random(); + let ssrc: u32 = rand::rng().random(); + + let mut h264_payloader = rtp::codecs::h264::H264Payloader::default(); + let mut h265_payloader = H265Payloader::new(); + let mut ctrl_read_buf = [0u8; RTSP_BUF_SIZE]; + let mut ctrl_buffer = Vec::with_capacity(RTSP_BUF_SIZE); + // 4-byte interleaved prefix + RTP header + payload shard (≤ RTP_MTU) + let mut interleaved_rtp_buf = Vec::with_capacity(4 + RTP_MTU + 96); + let mut last_rtp_timestamp: u32 = 0; + + loop { + tokio::select! { + maybe_frame = rx.recv() => { + let Some(frame) = maybe_frame else { + tracing::warn!("RTSP encoded frame subscription ended, attempting to restart pipeline"); + + if let Some(new_rx) = video_manager.subscribe_encoded_frames().await { + rx = new_rx; + let _ = video_manager.request_keyframe().await; + tracing::info!("RTSP frame subscription recovered"); + } else { + tracing::warn!( + "RTSP failed to resubscribe encoded frames, retrying in {}ms", + RTSP_RESUBSCRIBE_DELAY_MS + ); + sleep(Duration::from_millis(RTSP_RESUBSCRIBE_DELAY_MS)).await; + } + + continue; + }; + + if !is_frame_codec_match(&frame, &rtsp_codec) { + continue; + } + + { + let mut params = shared.parameter_sets.write().await; + update_parameter_sets(&mut params, &frame); + } + + let rtp_timestamp = monotonic_rtp_timestamp( + frame.pts_ms, + &mut last_rtp_timestamp, + frame.duration, + ); + + let payloads: Vec = match rtsp_codec { + RtspCodec::H264 => h264_payloader + .payload(RTP_MTU, &frame.data) + .map_err(|e| AppError::VideoError(format!("H264 payload failed: {}", e)))?, + RtspCodec::H265 => h265_payloader.payload(RTP_MTU, &frame.data), + }; + + if payloads.is_empty() { + continue; + } + + let total_payloads = payloads.len(); + for (idx, payload) in payloads.into_iter().enumerate() { + let marker = idx == total_payloads.saturating_sub(1); + let packet = Packet { + header: rtp::header::Header { + version: 2, + padding: false, + extension: false, + marker, + payload_type, + sequence_number, + timestamp: rtp_timestamp, + ssrc, + ..Default::default() + }, + payload, + }; + + sequence_number = sequence_number.wrapping_add(1); + send_interleaved_rtp(&mut writer, channel, &packet, &mut interleaved_rtp_buf) + .await?; + } + + if frame.is_keyframe { + tracing::debug!("RTSP keyframe sent"); + } + } + read_res = reader.read(&mut ctrl_read_buf) => { + let n = read_res?; + if n == 0 { + break; + } + + ctrl_buffer.extend_from_slice(&ctrl_read_buf[..n]); + + while strip_interleaved_frames_prefix(&mut ctrl_buffer) {} + + while let Some(raw_req) = take_rtsp_request_from_buffer(&mut ctrl_buffer) { + let Some(req) = parse_rtsp_request(&raw_req) else { + continue; + }; + + if handle_play_control_request(&mut writer, &req, &session_id).await? { + return Ok(()); + } + + while strip_interleaved_frames_prefix(&mut ctrl_buffer) {} + } + } + } + } + + Ok(()) +} + +pub(crate) async fn send_interleaved_rtp( + stream: &mut W, + channel: u8, + packet: &Packet, + marshal_buf: &mut Vec, +) -> Result<()> { + let rtp_len = packet.marshal_size(); + let rtp_len_u16 = u16::try_from(rtp_len).map_err(|_| { + AppError::VideoError(format!( + "RTP packet too large for interleaved framing: {} bytes", + rtp_len + )) + })?; + + marshal_buf.clear(); + marshal_buf.reserve(4 + rtp_len); + marshal_buf.extend_from_slice(&[b'$', channel, (rtp_len_u16 >> 8) as u8, rtp_len_u16 as u8]); + let body_off = marshal_buf.len(); + marshal_buf.resize(body_off + rtp_len, 0); + + let written = packet + .marshal_to(&mut marshal_buf[body_off..]) + .map_err(|e| AppError::VideoError(format!("RTP marshal failed: {}", e)))?; + if written != rtp_len { + return Err(AppError::VideoError(format!( + "RTP marshal size mismatch: wrote {written}, expected {rtp_len}" + ))); + } + + stream.write_all(marshal_buf).await?; + Ok(()) +} + +pub(crate) async fn handle_play_control_request( + stream: &mut W, + req: &RtspRequest, + session_id: &str, +) -> Result { + use super::protocol::OPTIONS_PUBLIC_CAPABILITIES; + + match &req.method { + rtsp::Method::Teardown => { + send_response(stream, req, 200, "OK", vec![], "", session_id).await?; + Ok(true) + } + rtsp::Method::Options => { + send_response( + stream, + req, + 200, + "OK", + vec![( + "Public".to_string(), + OPTIONS_PUBLIC_CAPABILITIES.to_string(), + )], + "", + session_id, + ) + .await?; + Ok(false) + } + rtsp::Method::GetParameter | rtsp::Method::SetParameter => { + send_response(stream, req, 200, "OK", vec![], "", session_id).await?; + Ok(false) + } + _ => { + send_response( + stream, + req, + 405, + "Method Not Allowed", + vec![], + "", + session_id, + ) + .await?; + Ok(false) + } + } +} + +fn pts_to_rtp_timestamp(pts_ms: i64) -> u32 { + if pts_ms <= 0 { + return 0; + } + ((pts_ms as u64 * RTP_CLOCK_RATE as u64) / 1000) as u32 +} + +fn rtp_timestamp_increment(frame_duration: Duration) -> u32 { + let inc = (frame_duration.as_secs_f64() * f64::from(RTP_CLOCK_RATE)).round() as u32; + inc.max(1) +} + +fn monotonic_rtp_timestamp(pts_ms: i64, last: &mut u32, frame_duration: Duration) -> u32 { + let from_pts = pts_to_rtp_timestamp(pts_ms); + let inc = rtp_timestamp_increment(frame_duration); + let ts = if from_pts > *last { + from_pts + } else { + last.wrapping_add(inc) + }; + *last = ts; + ts +} + +fn is_frame_codec_match(frame: &EncodedVideoFrame, codec: &RtspCodec) -> bool { + matches!( + (frame.codec, codec), + (VideoEncoderType::H264, RtspCodec::H264) | (VideoEncoderType::H265, RtspCodec::H265) + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + use tokio::io::{duplex, AsyncReadExt}; + + fn make_test_request(method: rtsp::Method) -> RtspRequest { + let mut headers = HashMap::new(); + headers.insert("cseq".to_string(), "7".to_string()); + RtspRequest { + method, + uri: "rtsp://127.0.0.1/live".to_string(), + version: rtsp::Version::V1_0, + headers, + } + } + + async fn read_response_from_duplex( + mut client: tokio::io::DuplexStream, + ) -> rtsp::Response> { + let mut buf = vec![0u8; 4096]; + let n = client + .read(&mut buf) + .await + .expect("failed to read rtsp response"); + assert!(n > 0); + let (message, consumed): (rtsp::Message>, usize) = + rtsp::Message::parse(&buf[..n]).expect("failed to parse rtsp response"); + assert_eq!(consumed, n); + + match message { + rtsp::Message::Response(response) => response, + _ => panic!("expected RTSP response"), + } + } + + #[tokio::test] + async fn play_control_teardown_returns_ok_and_stop() { + let req = make_test_request(rtsp::Method::Teardown); + let (client, mut server) = duplex(4096); + + let should_stop = handle_play_control_request(&mut server, &req, "session-1") + .await + .expect("control handling failed"); + assert!(should_stop); + + drop(server); + let response = read_response_from_duplex(client).await; + assert_eq!(response.status(), rtsp::StatusCode::Ok); + } + + #[tokio::test] + async fn play_control_pause_returns_method_not_allowed() { + let req = make_test_request(rtsp::Method::Pause); + let (client, mut server) = duplex(4096); + + let should_stop = handle_play_control_request(&mut server, &req, "session-1") + .await + .expect("control handling failed"); + assert!(!should_stop); + + drop(server); + let response = read_response_from_duplex(client).await; + assert_eq!(response.status(), rtsp::StatusCode::MethodNotAllowed); + } + + #[test] + fn monotonic_rtp_timestamp_steps_when_pts_stays_zero() { + let d = Duration::from_millis(33); + let mut last = 0u32; + let a = monotonic_rtp_timestamp(0, &mut last, d); + let b = monotonic_rtp_timestamp(0, &mut last, d); + let c = monotonic_rtp_timestamp(0, &mut last, d); + assert!(a > 0); + assert!(b > a); + assert!(c > b); + } + + #[test] + fn monotonic_rtp_timestamp_uses_pts_when_it_advances() { + let d = Duration::from_millis(33); + let mut last = 0u32; + let a = monotonic_rtp_timestamp(1000, &mut last, d); + assert_eq!(a, 90_000); + let b = monotonic_rtp_timestamp(2000, &mut last, d); + assert_eq!(b, 180_000); + } +} diff --git a/src/rtsp/types.rs b/src/rtsp/types.rs new file mode 100644 index 00000000..2572156d --- /dev/null +++ b/src/rtsp/types.rs @@ -0,0 +1,53 @@ +use rand::Rng; +use rtsp_types as rtsp; +use std::collections::HashMap; +use std::fmt; + +#[derive(Debug, Clone, PartialEq)] +pub enum RtspServiceStatus { + Stopped, + Starting, + Running, + Error(String), +} + +impl fmt::Display for RtspServiceStatus { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Stopped => write!(f, "stopped"), + Self::Starting => write!(f, "starting"), + Self::Running => write!(f, "running"), + Self::Error(err) => write!(f, "error: {}", err), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct RtspRequest { + pub method: rtsp::Method, + pub uri: String, + pub version: rtsp::Version, + pub headers: HashMap, +} + +pub(crate) struct RtspConnectionState { + pub session_id: String, + pub setup_done: bool, + pub interleaved_channel: u8, +} + +impl RtspConnectionState { + pub fn new() -> Self { + Self { + session_id: generate_session_id(), + setup_done: false, + interleaved_channel: 0, + } + } +} + +pub(crate) fn generate_session_id() -> String { + let mut rng = rand::rng(); + let value: u64 = rng.random(); + format!("{:016x}", value) +} diff --git a/src/rustdesk/bytes_codec.rs b/src/rustdesk/bytes_codec.rs index 18f163ca..d00896ab 100644 --- a/src/rustdesk/bytes_codec.rs +++ b/src/rustdesk/bytes_codec.rs @@ -1,21 +1,11 @@ -//! RustDesk BytesCodec - Variable-length framing for TCP messages -//! -//! RustDesk uses a custom variable-length encoding for message framing: -//! - Length <= 0x3F (63): 1-byte header, format `(len << 2)` -//! - Length <= 0x3FFF (16383): 2-byte LE header, format `(len << 2) | 0x1` -//! - Length <= 0x3FFFFF (4194303): 3-byte LE header, format `(len << 2) | 0x2` -//! - Length <= 0x3FFFFFFF (1073741823): 4-byte LE header, format `(len << 2) | 0x3` -//! -//! The low 2 bits of the first byte indicate the header length (+1). +//! Variable-length TCP framing (RustDesk wire format). use bytes::{Buf, BufMut, Bytes, BytesMut}; use std::io; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -/// Maximum packet length (1GB) const MAX_PACKET_LENGTH: usize = 0x3FFFFFFF; -/// Encode a message with RustDesk's variable-length framing pub fn encode_frame(data: &[u8]) -> io::Result> { let len = data.len(); let mut buf = Vec::with_capacity(len + 4); @@ -44,8 +34,6 @@ pub fn encode_frame(data: &[u8]) -> io::Result> { Ok(buf) } -/// Decode the header to get message length -/// Returns (header_length, message_length) fn decode_header(first_byte: u8, header_bytes: &[u8]) -> (usize, usize) { let head_len = ((first_byte & 0x3) + 1) as usize; @@ -64,21 +52,17 @@ fn decode_header(first_byte: u8, header_bytes: &[u8]) -> (usize, usize) { (head_len, msg_len) } -/// Read a single framed message from an async reader pub async fn read_frame(reader: &mut R) -> io::Result { - // Read first byte to determine header length let mut first_byte = [0u8; 1]; reader.read_exact(&mut first_byte).await?; let head_len = ((first_byte[0] & 0x3) + 1) as usize; - // Read remaining header bytes if needed let mut header_rest = [0u8; 3]; if head_len > 1 { reader.read_exact(&mut header_rest[..head_len - 1]).await?; } - // Calculate message length let (_, msg_len) = decode_header(first_byte[0], &header_rest); if msg_len > MAX_PACKET_LENGTH { @@ -88,7 +72,6 @@ pub async fn read_frame(reader: &mut R) -> io::Result(reader: &mut R) -> io::Result(writer: &mut W, data: &[u8]) -> io::Result<()> { let frame = encode_frame(data)?; writer.write_all(&frame).await?; @@ -104,10 +86,6 @@ pub async fn write_frame(writer: &mut W, data: &[u8]) -> Ok(()) } -/// Write a framed message using a reusable buffer (reduces allocations) -/// -/// This version reuses the provided BytesMut buffer to avoid allocation on each call. -/// The buffer is cleared before use and will grow as needed. pub async fn write_frame_buffered( writer: &mut W, data: &[u8], @@ -120,11 +98,9 @@ pub async fn write_frame_buffered( Ok(()) } -/// Encode a message with RustDesk's variable-length framing into an existing buffer pub fn encode_frame_into(data: &[u8], buf: &mut BytesMut) -> io::Result<()> { let len = data.len(); - // Reserve space for header (max 4 bytes) + data buf.reserve(4 + len); if len <= 0x3F { @@ -149,7 +125,7 @@ pub fn encode_frame_into(data: &[u8], buf: &mut BytesMut) -> io::Result<()> { Ok(()) } -/// BytesCodec for stateful decoding (compatible with tokio-util codec) +/// Stateful decoder for `Framed`. #[derive(Debug, Clone, Copy)] pub struct BytesCodec { state: DecodeState, @@ -180,7 +156,6 @@ impl BytesCodec { self.max_packet_length = n; } - /// Decode from a BytesMut buffer (for use with Framed) pub fn decode(&mut self, src: &mut BytesMut) -> io::Result> { let n = match self.state { DecodeState::Head => match self.decode_head(src)? { @@ -242,7 +217,6 @@ impl BytesCodec { Ok(Some(src.split_to(n))) } - /// Encode a message into a BytesMut buffer pub fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> io::Result<()> { let len = data.len(); @@ -276,7 +250,7 @@ mod tests { fn test_encode_decode_small() { let data = vec![1u8; 63]; let encoded = encode_frame(&data).unwrap(); - assert_eq!(encoded.len(), 63 + 1); // 1 byte header + assert_eq!(encoded.len(), 63 + 1); let mut codec = BytesCodec::new(); let mut buf = BytesMut::from(&encoded[..]); @@ -288,7 +262,7 @@ mod tests { fn test_encode_decode_medium() { let data = vec![2u8; 1000]; let encoded = encode_frame(&data).unwrap(); - assert_eq!(encoded.len(), 1000 + 2); // 2 byte header + assert_eq!(encoded.len(), 1000 + 2); let mut codec = BytesCodec::new(); let mut buf = BytesMut::from(&encoded[..]); @@ -300,7 +274,7 @@ mod tests { fn test_encode_decode_large() { let data = vec![3u8; 100000]; let encoded = encode_frame(&data).unwrap(); - assert_eq!(encoded.len(), 100000 + 3); // 3 byte header + assert_eq!(encoded.len(), 100000 + 3); let mut codec = BytesCodec::new(); let mut buf = BytesMut::from(&encoded[..]); diff --git a/src/rustdesk/config.rs b/src/rustdesk/config.rs index c8c64f29..cd954126 100644 --- a/src/rustdesk/config.rs +++ b/src/rustdesk/config.rs @@ -1,57 +1,26 @@ -//! RustDesk Configuration -//! -//! Configuration types for the RustDesk protocol integration. - use serde::{Deserialize, Serialize}; use typeshare::typeshare; -/// RustDesk configuration #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(default)] pub struct RustDeskConfig { - /// Enable RustDesk protocol pub enabled: bool, - - /// Rendezvous server address (hbbs), e.g., "rs.example.com" or "192.168.1.100:21116" - /// Required for RustDesk to function pub rendezvous_server: String, - - /// Relay server address (hbbr), if different from rendezvous server - /// Usually the same host as rendezvous server but different port (21117) pub relay_server: Option, - - /// Relay server authentication key (licence_key) - /// Required if the relay server is configured with -k option #[typeshare(skip)] pub relay_key: Option, - - /// Device ID (9-digit number), auto-generated if empty pub device_id: String, - - /// Device password for client authentication #[typeshare(skip)] pub device_password: String, - - /// Public key for encryption (Curve25519, base64 encoded), auto-generated #[typeshare(skip)] pub public_key: Option, - - /// Private key for encryption (Curve25519, base64 encoded), auto-generated #[typeshare(skip)] pub private_key: Option, - - /// Signing public key (Ed25519, base64 encoded), auto-generated - /// Used for SignedId verification by clients #[typeshare(skip)] pub signing_public_key: Option, - - /// Signing private key (Ed25519, base64 encoded), auto-generated - /// Used for signing SignedId messages #[typeshare(skip)] pub signing_private_key: Option, - - /// UUID for rendezvous server registration (persisted to avoid UUID_MISMATCH) #[typeshare(skip)] pub uuid: Option, } @@ -75,8 +44,6 @@ impl Default for RustDeskConfig { } impl RustDeskConfig { - /// Check if the configuration is valid for starting the service - /// Returns true if enabled and has a valid server pub fn is_valid(&self) -> bool { self.enabled && !self.rendezvous_server.is_empty() @@ -84,44 +51,35 @@ impl RustDeskConfig { && !self.device_password.is_empty() } - /// Get the rendezvous server (user-configured) pub fn effective_rendezvous_server(&self) -> &str { &self.rendezvous_server } - /// Generate a new random device ID pub fn generate_device_id() -> String { generate_device_id() } - /// Generate a new random password pub fn generate_password() -> String { generate_random_password() } - /// Get or generate the UUID for rendezvous registration - /// Returns (uuid_bytes, is_new) where is_new indicates if a new UUID was generated pub fn ensure_uuid(&mut self) -> ([u8; 16], bool) { if let Some(ref uuid_str) = self.uuid { - // Try to parse existing UUID if let Ok(uuid) = uuid::Uuid::parse_str(uuid_str) { return (*uuid.as_bytes(), false); } } - // Generate new UUID let new_uuid = uuid::Uuid::new_v4(); self.uuid = Some(new_uuid.to_string()); (*new_uuid.as_bytes(), true) } - /// Get the UUID bytes (returns None if not set) pub fn get_uuid_bytes(&self) -> Option<[u8; 16]> { self.uuid .as_ref() .and_then(|s| uuid::Uuid::parse_str(s).ok().map(|u| *u.as_bytes())) } - /// Get the rendezvous server address with default port pub fn rendezvous_addr(&self) -> String { let server = &self.rendezvous_server; if server.is_empty() { @@ -133,7 +91,6 @@ impl RustDeskConfig { } } - /// Get the relay server address with default port pub fn relay_addr(&self) -> Option { self.relay_server .as_ref() @@ -145,7 +102,6 @@ impl RustDeskConfig { } }) .or_else(|| { - // Default: same host as rendezvous server let server = &self.rendezvous_server; if !server.is_empty() { let host = server.split(':').next().unwrap_or(""); @@ -161,7 +117,6 @@ impl RustDeskConfig { } } -/// Generate a random 9-digit device ID pub fn generate_device_id() -> String { use rand::Rng; let mut rng = rand::rng(); @@ -169,7 +124,6 @@ pub fn generate_device_id() -> String { id.to_string() } -/// Generate a random 8-character password pub fn generate_random_password() -> String { use rand::Rng; const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; @@ -212,7 +166,6 @@ mod tests { config.rendezvous_server = "example.com:21116".to_string(); assert_eq!(config.rendezvous_addr(), "example.com:21116"); - // Empty server returns empty string config.rendezvous_server = String::new(); assert_eq!(config.rendezvous_addr(), ""); } @@ -224,17 +177,14 @@ mod tests { ..Default::default() }; - // Rendezvous server configured, relay defaults to same host assert_eq!(config.relay_addr(), Some("example.com:21117".to_string())); - // Explicit relay server config.relay_server = Some("relay.example.com".to_string()); assert_eq!( config.relay_addr(), Some("relay.example.com:21117".to_string()) ); - // No rendezvous server, relay is None config.rendezvous_server = String::new(); config.relay_server = None; assert_eq!(config.relay_addr(), None); @@ -247,10 +197,8 @@ mod tests { ..Default::default() }; - // When user sets a server, use it assert_eq!(config.effective_rendezvous_server(), "custom.example.com"); - // When empty, returns empty config.rendezvous_server = String::new(); assert_eq!(config.effective_rendezvous_server(), ""); } diff --git a/src/rustdesk/connection.rs b/src/rustdesk/connection.rs index 02fd293c..2f6c4294 100644 --- a/src/rustdesk/connection.rs +++ b/src/rustdesk/connection.rs @@ -1,12 +1,4 @@ -//! RustDesk Connection Handler -//! -//! This module handles incoming connections from RustDesk clients. -//! It manages the connection lifecycle including: -//! - Connection establishment (P2P or via relay) -//! - Encrypted handshake -//! - Authentication -//! - Message routing (video, audio, input) -//! - Video frame streaming (shared with WebRTC) +//! Incoming RustDesk TCP sessions (handshake, AV, input). use std::net::SocketAddr; use std::sync::Arc; @@ -23,6 +15,7 @@ use tracing::{debug, error, info, warn}; use crate::audio::AudioController; use crate::hid::{CanonicalKey, HidController, KeyEventType, KeyboardEvent, KeyboardModifiers}; +use crate::utils::hostname_from_etc; use crate::video::codec_constraints::{ encoder_codec_to_id, encoder_codec_to_video_codec, video_codec_to_encoder_codec, }; @@ -94,13 +87,6 @@ impl InputThrottler { } } -/// Get system hostname -fn get_hostname() -> String { - std::fs::read_to_string("/etc/hostname") - .map(|s| s.trim().to_string()) - .unwrap_or_else(|_| "One-KVM".to_string()) -} - /// Connection state #[derive(Debug, Clone, PartialEq)] pub enum ConnectionState { @@ -1165,7 +1151,7 @@ impl Connection { let mut peer_info = PeerInfo::new(); peer_info.username = "one-kvm".to_string(); - peer_info.hostname = get_hostname(); + peer_info.hostname = hostname_from_etc(); peer_info.platform = RUSTDESK_COMPAT_PLATFORM.to_string(); peer_info.displays.push(display_info); peer_info.current_display = 0; @@ -1786,7 +1772,7 @@ async fn run_audio_streaming( } // Subscribe to the audio Opus stream - let mut opus_rx = match audio_controller.subscribe_opus_async().await { + let mut opus_rx = match audio_controller.subscribe_opus().await { Some(rx) => rx, None => { // Audio not available, wait and retry diff --git a/src/rustdesk/crypto.rs b/src/rustdesk/crypto.rs index 88b1257f..a4d07c8c 100644 --- a/src/rustdesk/crypto.rs +++ b/src/rustdesk/crypto.rs @@ -1,10 +1,4 @@ -//! RustDesk Cryptography -//! -//! This module implements the NaCl-based cryptography used by RustDesk: -//! - Curve25519 for key exchange -//! - XSalsa20-Poly1305 for authenticated encryption -//! - Ed25519 for signatures -//! - Ed25519 to Curve25519 key conversion for unified keypair usage +//! NaCl crypto (RustDesk-compatible). use base64::{engine::general_purpose::STANDARD as BASE64, Engine}; use sodiumoxide::crypto::box_::{self, Nonce, PublicKey, SecretKey}; @@ -12,7 +6,6 @@ use sodiumoxide::crypto::secretbox; use sodiumoxide::crypto::sign::{self, ed25519}; use thiserror::Error; -/// Cryptography errors #[derive(Debug, Error)] pub enum CryptoError { #[error("Failed to initialize sodiumoxide")] @@ -31,13 +24,10 @@ pub enum CryptoError { KeyConversionFailed, } -/// Initialize the cryptography library -/// Must be called before using any crypto functions pub fn init() -> Result<(), CryptoError> { sodiumoxide::init().map_err(|_| CryptoError::InitError) } -/// A keypair for asymmetric encryption #[derive(Clone)] pub struct KeyPair { pub public_key: PublicKey, @@ -45,7 +35,6 @@ pub struct KeyPair { } impl KeyPair { - /// Generate a new random keypair pub fn generate() -> Self { let (public_key, secret_key) = box_::gen_keypair(); Self { @@ -54,7 +43,6 @@ impl KeyPair { } } - /// Create from existing keys pub fn from_keys(public_key: &[u8], secret_key: &[u8]) -> Result { let pk = PublicKey::from_slice(public_key).ok_or(CryptoError::InvalidKeyLength)?; let sk = SecretKey::from_slice(secret_key).ok_or(CryptoError::InvalidKeyLength)?; @@ -64,27 +52,22 @@ impl KeyPair { }) } - /// Get public key as bytes pub fn public_key_bytes(&self) -> &[u8] { self.public_key.as_ref() } - /// Get secret key as bytes pub fn secret_key_bytes(&self) -> &[u8] { self.secret_key.as_ref() } - /// Encode public key as base64 pub fn public_key_base64(&self) -> String { BASE64.encode(self.public_key_bytes()) } - /// Encode secret key as base64 pub fn secret_key_base64(&self) -> String { BASE64.encode(self.secret_key_bytes()) } - /// Create from base64-encoded keys pub fn from_base64(public_key: &str, secret_key: &str) -> Result { let pk_bytes = BASE64 .decode(public_key) @@ -96,15 +79,10 @@ impl KeyPair { } } -/// Generate a random nonce for box encryption pub fn generate_nonce() -> Nonce { box_::gen_nonce() } -/// Encrypt data using public-key cryptography (NaCl box) -/// -/// Uses the sender's secret key and receiver's public key for encryption. -/// Returns (nonce, ciphertext). pub fn encrypt_box( data: &[u8], their_public_key: &PublicKey, @@ -115,7 +93,6 @@ pub fn encrypt_box( (nonce, ciphertext) } -/// Decrypt data using public-key cryptography (NaCl box) pub fn decrypt_box( ciphertext: &[u8], nonce: &Nonce, @@ -126,14 +103,12 @@ pub fn decrypt_box( .map_err(|_| CryptoError::DecryptionFailed) } -/// Encrypt data with a precomputed shared key pub fn encrypt_with_key(data: &[u8], key: &secretbox::Key) -> (secretbox::Nonce, Vec) { let nonce = secretbox::gen_nonce(); let ciphertext = secretbox::seal(data, &nonce, key); (nonce, ciphertext) } -/// Decrypt data with a precomputed shared key pub fn decrypt_with_key( ciphertext: &[u8], nonce: &secretbox::Nonce, @@ -142,8 +117,6 @@ pub fn decrypt_with_key( secretbox::open(ciphertext, nonce, key).map_err(|_| CryptoError::DecryptionFailed) } -/// Compute a shared symmetric key from public/private keypair -/// This is the precomputed key for the NaCl box pub fn precompute_key( their_public_key: &PublicKey, our_secret_key: &SecretKey, @@ -151,23 +124,18 @@ pub fn precompute_key( box_::precompute(their_public_key, our_secret_key) } -/// Create a symmetric key from raw bytes pub fn symmetric_key_from_slice(key: &[u8]) -> Result { secretbox::Key::from_slice(key).ok_or(CryptoError::InvalidKeyLength) } -/// Parse a nonce from bytes pub fn nonce_from_slice(bytes: &[u8]) -> Result { Nonce::from_slice(bytes).ok_or(CryptoError::InvalidNonce) } -/// Parse a public key from bytes pub fn public_key_from_slice(bytes: &[u8]) -> Result { PublicKey::from_slice(bytes).ok_or(CryptoError::InvalidKeyLength) } -/// Hash a password for storage/comparison -/// RustDesk uses simple SHA256 for password hashing pub fn hash_password(password: &str, salt: &str) -> Vec { use sha2::{Digest, Sha256}; let mut hasher = Sha256::new(); @@ -176,35 +144,24 @@ pub fn hash_password(password: &str, salt: &str) -> Vec { hasher.finalize().to_vec() } -/// RustDesk double hash for password verification -/// Client calculates: SHA256(SHA256(password + salt) + challenge) -/// This matches what the client sends in LoginRequest pub fn hash_password_double(password: &str, salt: &str, challenge: &str) -> Vec { use sha2::{Digest, Sha256}; - // First hash: SHA256(password + salt) let mut hasher1 = Sha256::new(); hasher1.update(password.as_bytes()); hasher1.update(salt.as_bytes()); let first_hash = hasher1.finalize(); - // Second hash: SHA256(first_hash + challenge) let mut hasher2 = Sha256::new(); hasher2.update(first_hash); hasher2.update(challenge.as_bytes()); hasher2.finalize().to_vec() } -/// Verify a password hash pub fn verify_password(password: &str, salt: &str, expected_hash: &[u8]) -> bool { let computed = hash_password(password, salt); - // Constant-time comparison would be better, but for our use case this is acceptable computed == expected_hash } -/// Decrypt symmetric key using Curve25519 secret key directly -/// -/// This is used when we have a fresh Curve25519 keypair for the connection -/// (as per RustDesk protocol - each connection generates a new keypair) pub fn decrypt_symmetric_key( their_temp_public_key: &[u8], sealed_symmetric_key: &[u8], @@ -217,7 +174,6 @@ pub fn decrypt_symmetric_key( let their_pk = PublicKey::from_slice(their_temp_public_key).ok_or(CryptoError::InvalidKeyLength)?; - // Use zero nonce as per RustDesk protocol let nonce = box_::Nonce([0u8; box_::NONCEBYTES]); let key_bytes = box_::open(sealed_symmetric_key, &nonce, &their_pk, our_secret_key) @@ -226,11 +182,7 @@ pub fn decrypt_symmetric_key( secretbox::Key::from_slice(&key_bytes).ok_or(CryptoError::InvalidKeyLength) } -/// Encrypt a message using the negotiated symmetric key -/// -/// RustDesk uses a specific nonce format for session encryption pub fn encrypt_message(data: &[u8], key: &secretbox::Key, nonce_counter: u64) -> Vec { - // Create nonce from counter (little-endian, padded to 24 bytes) let mut nonce_bytes = [0u8; secretbox::NONCEBYTES]; nonce_bytes[..8].copy_from_slice(&nonce_counter.to_le_bytes()); let nonce = secretbox::Nonce(nonce_bytes); @@ -238,13 +190,11 @@ pub fn encrypt_message(data: &[u8], key: &secretbox::Key, nonce_counter: u64) -> secretbox::seal(data, &nonce, key) } -/// Decrypt a message using the negotiated symmetric key pub fn decrypt_message( ciphertext: &[u8], key: &secretbox::Key, nonce_counter: u64, ) -> Result, CryptoError> { - // Create nonce from counter (little-endian, padded to 24 bytes) let mut nonce_bytes = [0u8; secretbox::NONCEBYTES]; nonce_bytes[..8].copy_from_slice(&nonce_counter.to_le_bytes()); let nonce = secretbox::Nonce(nonce_bytes); @@ -252,7 +202,6 @@ pub fn decrypt_message( secretbox::open(ciphertext, &nonce, key).map_err(|_| CryptoError::DecryptionFailed) } -/// Ed25519 signing keypair for RustDesk SignedId messages #[derive(Clone)] pub struct SigningKeyPair { pub public_key: sign::PublicKey, @@ -260,7 +209,6 @@ pub struct SigningKeyPair { } impl SigningKeyPair { - /// Generate a new random signing keypair pub fn generate() -> Self { let (public_key, secret_key) = sign::gen_keypair(); Self { @@ -269,7 +217,6 @@ impl SigningKeyPair { } } - /// Create from existing keys pub fn from_keys(public_key: &[u8], secret_key: &[u8]) -> Result { let pk = sign::PublicKey::from_slice(public_key).ok_or(CryptoError::InvalidKeyLength)?; let sk = sign::SecretKey::from_slice(secret_key).ok_or(CryptoError::InvalidKeyLength)?; @@ -279,27 +226,22 @@ impl SigningKeyPair { }) } - /// Get public key as bytes pub fn public_key_bytes(&self) -> &[u8] { self.public_key.as_ref() } - /// Get secret key as bytes pub fn secret_key_bytes(&self) -> &[u8] { self.secret_key.as_ref() } - /// Encode public key as base64 pub fn public_key_base64(&self) -> String { BASE64.encode(self.public_key_bytes()) } - /// Encode secret key as base64 pub fn secret_key_base64(&self) -> String { BASE64.encode(self.secret_key_bytes()) } - /// Create from base64-encoded keys pub fn from_base64(public_key: &str, secret_key: &str) -> Result { let pk_bytes = BASE64 .decode(public_key) @@ -310,42 +252,27 @@ impl SigningKeyPair { Self::from_keys(&pk_bytes, &sk_bytes) } - /// Sign a message - /// Returns the signature prepended to the message (as per RustDesk protocol) pub fn sign(&self, message: &[u8]) -> Vec { sign::sign(message, &self.secret_key) } - /// Sign a message and return just the signature (64 bytes) pub fn sign_detached(&self, message: &[u8]) -> [u8; 64] { let sig = sign::sign_detached(message, &self.secret_key); - // Use as_ref() to access the signature bytes since the inner field is private let sig_bytes: &[u8] = sig.as_ref(); let mut result = [0u8; 64]; result.copy_from_slice(sig_bytes); result } - /// Convert Ed25519 public key to Curve25519 public key for encryption - /// - /// This allows using the same keypair for both signing and encryption, - /// which is required by RustDesk's protocol where clients encrypt the - /// symmetric key using the public key from IdPk. pub fn to_curve25519_pk(&self) -> Result { ed25519::to_curve25519_pk(&self.public_key).map_err(|_| CryptoError::KeyConversionFailed) } - /// Convert Ed25519 secret key to Curve25519 secret key for decryption - /// - /// This allows decrypting messages that were encrypted using the - /// converted public key. pub fn to_curve25519_sk(&self) -> Result { ed25519::to_curve25519_sk(&self.secret_key).map_err(|_| CryptoError::KeyConversionFailed) } } -/// Verify a signed message -/// Returns the original message if signature is valid pub fn verify_signed( signed_message: &[u8], public_key: &sign::PublicKey, diff --git a/src/rustdesk/frame_adapters.rs b/src/rustdesk/frame_adapters.rs index 35785ca9..a8794cdf 100644 --- a/src/rustdesk/frame_adapters.rs +++ b/src/rustdesk/frame_adapters.rs @@ -1,8 +1,3 @@ -//! RustDesk Frame Adapters -//! -//! Converts One-KVM video/audio frames to RustDesk protocol format. -//! Optimized for zero-copy where possible and buffer reuse. - use bytes::Bytes; use protobuf::Message as ProtobufMessage; @@ -11,7 +6,6 @@ use super::protocol::hbb::message::{ CursorData, CursorPosition, EncodedVideoFrame, EncodedVideoFrames, Message, Misc, VideoFrame, }; -/// Video codec type for RustDesk #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum VideoCodec { H264, @@ -22,7 +16,6 @@ pub enum VideoCodec { } impl VideoCodec { - /// Get the codec ID for the RustDesk protocol pub fn to_codec_id(self) -> i32 { match self { VideoCodec::H264 => 0, @@ -34,21 +27,15 @@ impl VideoCodec { } } -/// Video frame adapter for converting to RustDesk format pub struct VideoFrameAdapter { - /// Current codec codec: VideoCodec, - /// Frame sequence number seq: u32, - /// Timestamp offset timestamp_base: u64, - /// Cached H264 SPS/PPS (Annex B NAL without start code) h264_sps: Option, h264_pps: Option, } impl VideoFrameAdapter { - /// Create a new video frame adapter pub fn new(codec: VideoCodec) -> Self { Self { codec, @@ -59,14 +46,10 @@ impl VideoFrameAdapter { } } - /// Set codec type pub fn set_codec(&mut self, codec: VideoCodec) { self.codec = codec; } - /// Convert encoded video data to RustDesk Message (zero-copy version) - /// - /// This version takes Bytes directly to avoid copying the frame data. pub fn encode_frame_from_bytes( &mut self, data: Bytes, @@ -74,7 +57,6 @@ impl VideoFrameAdapter { timestamp_ms: u64, ) -> Message { let data = self.prepare_h264_frame(data, is_keyframe); - // Calculate relative timestamp if self.seq == 0 { self.timestamp_base = timestamp_ms; } @@ -87,11 +69,9 @@ impl VideoFrameAdapter { self.seq = self.seq.wrapping_add(1); - // Wrap in EncodedVideoFrames container let mut frames = EncodedVideoFrames::new(); frames.frames.push(frame); - // Create the appropriate VideoFrame variant based on codec let mut video_frame = VideoFrame::new(); match self.codec { VideoCodec::H264 => video_frame.union = Some(vf_union::Union::H264s(frames)), @@ -111,7 +91,6 @@ impl VideoFrameAdapter { return data; } - // Parse SPS/PPS from Annex B data (without start codes) let (sps, pps) = crate::webrtc::rtp::extract_sps_pps(&data); let mut has_sps = false; let mut has_pps = false; @@ -125,7 +104,6 @@ impl VideoFrameAdapter { has_pps = true; } - // Inject cached SPS/PPS before IDR when missing if is_keyframe && (!has_sps || !has_pps) { if let (Some(sps), Some(pps)) = (self.h264_sps.as_ref(), self.h264_pps.as_ref()) { let mut out = Vec::with_capacity(8 + sps.len() + pps.len() + data.len()); @@ -141,14 +119,10 @@ impl VideoFrameAdapter { data } - /// Convert encoded video data to RustDesk Message pub fn encode_frame(&mut self, data: &[u8], is_keyframe: bool, timestamp_ms: u64) -> Message { self.encode_frame_from_bytes(Bytes::copy_from_slice(data), is_keyframe, timestamp_ms) } - /// Encode frame to bytes for sending (zero-copy version) - /// - /// Takes Bytes directly to avoid copying the frame data. pub fn encode_frame_bytes_zero_copy( &mut self, data: Bytes, @@ -159,7 +133,6 @@ impl VideoFrameAdapter { Bytes::from(msg.write_to_bytes().unwrap_or_default()) } - /// Encode frame to bytes for sending pub fn encode_frame_bytes( &mut self, data: &[u8], @@ -169,24 +142,18 @@ impl VideoFrameAdapter { self.encode_frame_bytes_zero_copy(Bytes::copy_from_slice(data), is_keyframe, timestamp_ms) } - /// Get current sequence number pub fn seq(&self) -> u32 { self.seq } } -/// Audio frame adapter for converting to RustDesk format pub struct AudioFrameAdapter { - /// Sample rate sample_rate: u32, - /// Channels channels: u8, - /// Format sent flag format_sent: bool, } impl AudioFrameAdapter { - /// Create a new audio frame adapter pub fn new(sample_rate: u32, channels: u8) -> Self { Self { sample_rate, @@ -195,7 +162,6 @@ impl AudioFrameAdapter { } } - /// Create audio format message (should be sent once before audio frames) pub fn create_format_message(&mut self) -> Message { self.format_sent = true; @@ -211,12 +177,10 @@ impl AudioFrameAdapter { msg } - /// Check if format message has been sent pub fn format_sent(&self) -> bool { self.format_sent } - /// Convert Opus audio data to RustDesk Message pub fn encode_opus_frame(&self, data: &[u8]) -> Message { let mut frame = AudioFrame::new(); frame.data = Bytes::copy_from_slice(data); @@ -226,23 +190,19 @@ impl AudioFrameAdapter { msg } - /// Encode Opus frame to bytes for sending pub fn encode_opus_bytes(&self, data: &[u8]) -> Bytes { let msg = self.encode_opus_frame(data); Bytes::from(msg.write_to_bytes().unwrap_or_default()) } - /// Reset state (call when restarting audio stream) pub fn reset(&mut self) { self.format_sent = false; } } -/// Cursor data adapter pub struct CursorAdapter; impl CursorAdapter { - /// Create cursor data message pub fn encode_cursor( id: u64, hotx: i32, @@ -264,7 +224,6 @@ impl CursorAdapter { msg } - /// Create cursor position message pub fn encode_position(x: i32, y: i32) -> Message { let mut pos = CursorPosition::new(); pos.x = x; @@ -284,7 +243,6 @@ mod tests { fn test_video_frame_encoding() { let mut adapter = VideoFrameAdapter::new(VideoCodec::H264); - // Encode a keyframe let data = vec![0x00, 0x00, 0x00, 0x01, 0x67]; // H264 SPS NAL let msg = adapter.encode_frame(&data, true, 0); @@ -324,7 +282,6 @@ mod tests { fn test_audio_frame_encoding() { let adapter = AudioFrameAdapter::new(48000, 2); - // Encode an Opus frame let opus_data = vec![0xFC, 0x01, 0x02]; // Fake Opus data let msg = adapter.encode_opus_frame(&opus_data); diff --git a/src/rustdesk/hid_adapter.rs b/src/rustdesk/hid_adapter.rs index dbb478d3..7fb99bed 100644 --- a/src/rustdesk/hid_adapter.rs +++ b/src/rustdesk/hid_adapter.rs @@ -1,7 +1,3 @@ -//! RustDesk HID Adapter -//! -//! Converts RustDesk HID events (KeyEvent, MouseEvent) to One-KVM HID events. - use super::protocol::hbb::message::key_event as ke_union; use super::protocol::{ControlKey, KeyEvent, MouseEvent}; use crate::hid::{ @@ -10,8 +6,6 @@ use crate::hid::{ }; use protobuf::Enum; -/// Mouse event types from RustDesk protocol -/// mask = (button << 3) | event_type pub mod mouse_type { pub const MOVE: i32 = 0; pub const DOWN: i32 = 1; @@ -21,7 +15,6 @@ pub mod mouse_type { pub const MOVE_RELATIVE: i32 = 5; } -/// Mouse button IDs from RustDesk protocol (before left shift by 3) pub mod mouse_button { pub const LEFT: i32 = 0x01; pub const RIGHT: i32 = 0x02; @@ -30,9 +23,6 @@ pub mod mouse_button { pub const FORWARD: i32 = 0x10; } -/// Convert RustDesk MouseEvent to One-KVM MouseEvent(s) -/// Returns a Vec because a single RustDesk event may need multiple One-KVM events -/// (e.g., move + button + scroll) pub fn convert_mouse_event( event: &MouseEvent, screen_width: u32, @@ -41,23 +31,18 @@ pub fn convert_mouse_event( ) -> Vec { let mut events = Vec::new(); - // Parse RustDesk mask format: (button << 3) | event_type let event_type = event.mask & 0x07; let button_id = event.mask >> 3; let include_abs_move = !relative_mode; match event_type { mouse_type::MOVE => { - // RustDesk uses absolute coordinates let x = event.x.max(0) as u32; let y = event.y.max(0) as u32; - // Normalize to 0-32767 range for absolute mouse (USB HID standard) let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32; let abs_y = ((y as u64 * 32767) / screen_height.max(1) as u64) as i32; - // Move event - may have button held down (button_id > 0 means dragging) - // Just send move, button state is tracked separately by HID backend events.push(OneKvmMouseEvent { event_type: MouseEventType::MoveAbs, x: abs_x, @@ -67,7 +52,6 @@ pub fn convert_mouse_event( }); } mouse_type::MOVE_RELATIVE => { - // Relative movement uses delta values directly (dx, dy). events.push(OneKvmMouseEvent { event_type: MouseEventType::Move, x: event.x, @@ -78,7 +62,6 @@ pub fn convert_mouse_event( } mouse_type::DOWN => { if include_abs_move { - // Button down - first move, then press let x = event.x.max(0) as u32; let y = event.y.max(0) as u32; let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32; @@ -104,7 +87,6 @@ pub fn convert_mouse_event( } mouse_type::UP => { if include_abs_move { - // Button up - first move, then release let x = event.x.max(0) as u32; let y = event.y.max(0) as u32; let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32; @@ -130,7 +112,6 @@ pub fn convert_mouse_event( } mouse_type::WHEEL => { if include_abs_move { - // Scroll event - move first, then scroll let x = event.x.max(0) as u32; let y = event.y.max(0) as u32; let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32; @@ -144,9 +125,6 @@ pub fn convert_mouse_event( }); } - // RustDesk encodes scroll direction in the y coordinate - // Positive y = scroll up, Negative y = scroll down - // The button_id field is not used for direction let scroll = if event.y > 0 { 1i8 } else { -1i8 }; events.push(OneKvmMouseEvent { event_type: MouseEventType::Scroll, @@ -158,7 +136,6 @@ pub fn convert_mouse_event( } _ => { if include_abs_move { - // Unknown event type, just move let x = event.x.max(0) as u32; let y = event.y.max(0) as u32; let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32; @@ -177,7 +154,6 @@ pub fn convert_mouse_event( events } -/// Convert RustDesk button ID to One-KVM MouseButton fn button_id_to_button(button_id: i32) -> Option { match button_id { mouse_button::LEFT => Some(MouseButton::Left), @@ -187,34 +163,19 @@ fn button_id_to_button(button_id: i32) -> Option { } } -/// Convert RustDesk KeyEvent to One-KVM KeyboardEvent -/// -/// RustDesk KeyEvent has two modes: -/// - down=true/false: Key state (pressed/released) -/// - press=true: Complete key press (down + up), used for typing -/// -/// For press=true events, we only send Down and let the caller handle -/// the timing for Up event if needed. Most systems handle this correctly. pub fn convert_key_event(event: &KeyEvent) -> Option { - // Determine if this is a key down or key up event - // press=true means "key was pressed" (down event) - // down=true means key is currently held down - // down=false with press=false means key was released let event_type = if event.down || event.press { KeyEventType::Down } else { KeyEventType::Up }; - // For modifier keys sent as ControlKey, don't include them in modifiers - // to avoid double-pressing. The modifier will be tracked by HID state. let modifiers = if is_modifier_control_key(event) { KeyboardModifiers::default() } else { parse_modifiers(event) }; - // Handle control keys if let Some(ke_union::Union::ControlKey(ck)) = &event.union { if let Some(key) = control_key_to_hid(ck.value()) { let key = CanonicalKey::from_hid_usage(key)?; @@ -226,9 +187,7 @@ pub fn convert_key_event(event: &KeyEvent) -> Option { } } - // Handle character keys (chr field contains platform-specific keycode) if let Some(ke_union::Union::Chr(chr)) = &event.union { - // chr contains USB HID scancode on Windows, X11 keycode on Linux if let Some(key) = keycode_to_hid(*chr) { let key = CanonicalKey::from_hid_usage(key)?; return Some(KeyboardEvent { @@ -239,13 +198,9 @@ pub fn convert_key_event(event: &KeyEvent) -> Option { } } - // Handle unicode (for text input, we'd need to convert to scancodes) - // Unicode input requires more complex handling, skip for now - None } -/// Check if the event is a modifier key sent as ControlKey fn is_modifier_control_key(event: &KeyEvent) -> bool { if let Some(ke_union::Union::ControlKey(ck)) = &event.union { let val = ck.value(); @@ -260,7 +215,6 @@ fn is_modifier_control_key(event: &KeyEvent) -> bool { false } -/// Parse modifier keys from RustDesk KeyEvent into KeyboardModifiers fn parse_modifiers(event: &KeyEvent) -> KeyboardModifiers { let mut modifiers = KeyboardModifiers::default(); @@ -281,7 +235,6 @@ fn parse_modifiers(event: &KeyEvent) -> KeyboardModifiers { modifiers } -/// Convert RustDesk ControlKey to USB HID usage code fn control_key_to_hid(key: i32) -> Option { match key { x if x == ControlKey::Alt as i32 => Some(0xE2), // Left Alt @@ -342,67 +295,47 @@ fn control_key_to_hid(key: i32) -> Option { } } -/// Convert platform keycode to USB HID usage code -/// Handles Windows Virtual Key Codes, X11 keycodes, and ASCII codes fn keycode_to_hid(keycode: u32) -> Option { - // First try ASCII code mapping (RustDesk often sends ASCII codes) if let Some(hid) = ascii_to_hid(keycode) { return Some(hid); } - // Then try Windows Virtual Key Code mapping if let Some(hid) = windows_vk_to_hid(keycode) { return Some(hid); } - // Fall back to X11 keycode mapping for Linux clients x11_keycode_to_hid(keycode) } -/// Convert ASCII code to USB HID usage code fn ascii_to_hid(ascii: u32) -> Option { match ascii { - // Lowercase letters a-z (ASCII 97-122) - 97..=122 => { - // USB HID: a=0x04, b=0x05, ..., z=0x1D - Some((ascii - 97 + 0x04) as u8) - } - // Uppercase letters A-Z (ASCII 65-90) - 65..=90 => { - // USB HID: A=0x04, B=0x05, ..., Z=0x1D (same as lowercase) - Some((ascii - 65 + 0x04) as u8) - } - // Numbers 0-9 (ASCII 48-57) + 97..=122 => Some((ascii - 97 + 0x04) as u8), + 65..=90 => Some((ascii - 65 + 0x04) as u8), 48 => Some(0x27), // 0 49..=57 => Some((ascii - 49 + 0x1E) as u8), // 1-9 - // Common punctuation - 32 => Some(0x2C), // Space - 13 => Some(0x28), // Enter (CR) - 10 => Some(0x28), // Enter (LF) - 9 => Some(0x2B), // Tab - 27 => Some(0x29), // Escape - 8 => Some(0x2A), // Backspace - 127 => Some(0x4C), // Delete - // Symbols (US keyboard layout) - 45 => Some(0x2D), // - - 61 => Some(0x2E), // = - 91 => Some(0x2F), // [ - 93 => Some(0x30), // ] - 92 => Some(0x31), // \ - 59 => Some(0x33), // ; - 39 => Some(0x34), // ' - 96 => Some(0x35), // ` - 44 => Some(0x36), // , - 46 => Some(0x37), // . - 47 => Some(0x38), // / + 32 => Some(0x2C), // Space + 13 => Some(0x28), // Enter (CR) + 10 => Some(0x28), // Enter (LF) + 9 => Some(0x2B), // Tab + 27 => Some(0x29), // Escape + 8 => Some(0x2A), // Backspace + 127 => Some(0x4C), // Delete + 45 => Some(0x2D), // - + 61 => Some(0x2E), // = + 91 => Some(0x2F), // [ + 93 => Some(0x30), // ] + 92 => Some(0x31), // \ + 59 => Some(0x33), // ; + 39 => Some(0x34), // ' + 96 => Some(0x35), // ` + 44 => Some(0x36), // , + 46 => Some(0x37), // . + 47 => Some(0x38), // / _ => None, } } -/// Convert Windows Virtual Key Code to USB HID usage code fn windows_vk_to_hid(vk: u32) -> Option { match vk { - // Letters A-Z (VK_A=0x41 to VK_Z=0x5A) 0x41..=0x5A => { - // USB HID: A=0x04, B=0x05, ..., Z=0x1D let letter = (vk - 0x41) as u8; Some(match letter { 0 => 0x04, // A @@ -434,21 +367,16 @@ fn windows_vk_to_hid(vk: u32) -> Option { _ => return None, }) } - // Numbers 0-9 (VK_0=0x30 to VK_9=0x39) 0x30 => Some(0x27), // 0 0x31..=0x39 => Some((vk - 0x31 + 0x1E) as u8), // 1-9 - // Numpad 0-9 (VK_NUMPAD0=0x60 to VK_NUMPAD9=0x69) 0x60 => Some(0x62), // Numpad 0 0x61..=0x69 => Some((vk - 0x61 + 0x59) as u8), // Numpad 1-9 - // Numpad operators - 0x6A => Some(0x55), // Numpad * - 0x6B => Some(0x57), // Numpad + - 0x6D => Some(0x56), // Numpad - - 0x6E => Some(0x63), // Numpad . - 0x6F => Some(0x54), // Numpad / - // Function keys F1-F12 (VK_F1=0x70 to VK_F12=0x7B) + 0x6A => Some(0x55), // Numpad * + 0x6B => Some(0x57), // Numpad + + 0x6D => Some(0x56), // Numpad - + 0x6E => Some(0x63), // Numpad . + 0x6F => Some(0x54), // Numpad / 0x70..=0x7B => Some((vk - 0x70 + 0x3A) as u8), - // Special keys 0x08 => Some(0x2A), // Backspace 0x09 => Some(0x2B), // Tab 0x0D => Some(0x28), // Enter @@ -464,7 +392,6 @@ fn windows_vk_to_hid(vk: u32) -> Option { 0x28 => Some(0x51), // Down Arrow 0x2D => Some(0x49), // Insert 0x2E => Some(0x4C), // Delete - // OEM keys (US keyboard layout) 0xBA => Some(0x33), // ; : 0xBB => Some(0x2E), // = + 0xBC => Some(0x36), // , < @@ -476,66 +403,56 @@ fn windows_vk_to_hid(vk: u32) -> Option { 0xDC => Some(0x31), // \ | 0xDD => Some(0x30), // ] } 0xDE => Some(0x34), // ' " - // Lock keys 0x14 => Some(0x39), // Caps Lock 0x90 => Some(0x53), // Num Lock 0x91 => Some(0x47), // Scroll Lock - // Print Screen, Pause 0x2C => Some(0x46), // Print Screen 0x13 => Some(0x48), // Pause _ => None, } } -/// Convert X11 keycode to USB HID usage code (for Linux clients) fn x11_keycode_to_hid(keycode: u32) -> Option { match keycode { - // Numbers: X11 keycode 10="1", 11="2", ..., 18="9", 19="0" 10..=18 => Some((keycode - 10 + 0x1E) as u8), // 1-9 19 => Some(0x27), // 0 - // Punctuation - 20 => Some(0x2D), // - - 21 => Some(0x2E), // = - 34 => Some(0x2F), // [ - 35 => Some(0x30), // ] - // Letters (X11 keycodes are row-based) - // Row 1: q(24) w(25) e(26) r(27) t(28) y(29) u(30) i(31) o(32) p(33) - 24 => Some(0x14), // q - 25 => Some(0x1A), // w - 26 => Some(0x08), // e - 27 => Some(0x15), // r - 28 => Some(0x17), // t - 29 => Some(0x1C), // y - 30 => Some(0x18), // u - 31 => Some(0x0C), // i - 32 => Some(0x12), // o - 33 => Some(0x13), // p - // Row 2: a(38) s(39) d(40) f(41) g(42) h(43) j(44) k(45) l(46) - 38 => Some(0x04), // a - 39 => Some(0x16), // s - 40 => Some(0x07), // d - 41 => Some(0x09), // f - 42 => Some(0x0A), // g - 43 => Some(0x0B), // h - 44 => Some(0x0D), // j - 45 => Some(0x0E), // k - 46 => Some(0x0F), // l - 47 => Some(0x33), // ; - 48 => Some(0x34), // ' - 49 => Some(0x35), // ` - 51 => Some(0x31), // \ - // Row 3: z(52) x(53) c(54) v(55) b(56) n(57) m(58) - 52 => Some(0x1D), // z - 53 => Some(0x1B), // x - 54 => Some(0x06), // c - 55 => Some(0x19), // v - 56 => Some(0x05), // b - 57 => Some(0x11), // n - 58 => Some(0x10), // m - 59 => Some(0x36), // , - 60 => Some(0x37), // . - 61 => Some(0x38), // / - // Space + 20 => Some(0x2D), // - + 21 => Some(0x2E), // = + 34 => Some(0x2F), // [ + 35 => Some(0x30), // ] + 24 => Some(0x14), // q + 25 => Some(0x1A), // w + 26 => Some(0x08), // e + 27 => Some(0x15), // r + 28 => Some(0x17), // t + 29 => Some(0x1C), // y + 30 => Some(0x18), // u + 31 => Some(0x0C), // i + 32 => Some(0x12), // o + 33 => Some(0x13), // p + 38 => Some(0x04), // a + 39 => Some(0x16), // s + 40 => Some(0x07), // d + 41 => Some(0x09), // f + 42 => Some(0x0A), // g + 43 => Some(0x0B), // h + 44 => Some(0x0D), // j + 45 => Some(0x0E), // k + 46 => Some(0x0F), // l + 47 => Some(0x33), // ; + 48 => Some(0x34), // ' + 49 => Some(0x35), // ` + 51 => Some(0x31), // \ + 52 => Some(0x1D), // z + 53 => Some(0x1B), // x + 54 => Some(0x06), // c + 55 => Some(0x19), // v + 56 => Some(0x05), // b + 57 => Some(0x11), // n + 58 => Some(0x10), // m + 59 => Some(0x36), // , + 60 => Some(0x37), // . + 61 => Some(0x38), // / 65 => Some(0x2C), _ => None, } @@ -573,7 +490,6 @@ mod tests { let events = convert_mouse_event(&event, 1920, 1080, false); assert!(events.len() >= 2); - // Should have a button down event assert!(events .iter() .any(|e| e.event_type == MouseEventType::Down && e.button == Some(MouseButton::Left))); diff --git a/src/rustdesk/mod.rs b/src/rustdesk/mod.rs index 128337e1..2a1f6b49 100644 --- a/src/rustdesk/mod.rs +++ b/src/rustdesk/mod.rs @@ -1,17 +1,4 @@ -//! RustDesk Protocol Integration Module -//! -//! This module implements the RustDesk client protocol, enabling One-KVM devices -//! to be accessed via standard RustDesk clients through existing hbbs/hbbr servers. -//! -//! ## Architecture -//! -//! - `config`: Configuration types for RustDesk settings -//! - `protocol`: Protobuf message wrappers and serialization -//! - `crypto`: NaCl cryptography (key generation, encryption, signatures) -//! - `rendezvous`: Communication with hbbs rendezvous server -//! - `connection`: Client session handling -//! - `frame_adapters`: Video/audio frame conversion to RustDesk format -//! - `hid_adapter`: RustDesk HID events to One-KVM conversion +//! RustDesk peer protocol (hbbs / hbbr). pub mod bytes_codec; pub mod config; @@ -44,19 +31,13 @@ use self::connection::ConnectionManager; use self::protocol::{make_local_addr, make_relay_response, make_request_relay}; use self::rendezvous::{AddrMangle, RendezvousMediator, RendezvousStatus}; -/// Relay connection timeout const RELAY_CONNECT_TIMEOUT_MS: u64 = 10_000; -/// RustDesk service status #[derive(Debug, Clone, PartialEq)] pub enum ServiceStatus { - /// Service is stopped Stopped, - /// Service is starting Starting, - /// Service is running and registered with rendezvous server Running, - /// Service encountered an error Error(String), } @@ -71,15 +52,8 @@ impl std::fmt::Display for ServiceStatus { } } -/// Default port for direct TCP connections (same as RustDesk) const DIRECT_LISTEN_PORT: u16 = 21118; -/// RustDesk Service -/// -/// Manages the RustDesk protocol integration, including: -/// - Registration with hbbs rendezvous server -/// - Accepting connections from RustDesk clients -/// - Streaming video/audio and receiving HID input pub struct RustDeskService { config: Arc>, status: Arc>, @@ -95,7 +69,6 @@ pub struct RustDeskService { } impl RustDeskService { - /// Create a new RustDesk service instance pub fn new( config: RustDeskConfig, video_manager: Arc, @@ -120,42 +93,34 @@ impl RustDeskService { } } - /// Get the port for direct TCP connections pub fn listen_port(&self) -> u16 { *self.listen_port.read() } - /// Get current service status pub fn status(&self) -> ServiceStatus { self.status.read().clone() } - /// Get current configuration pub fn config(&self) -> RustDeskConfig { self.config.read().clone() } - /// Update configuration pub fn update_config(&self, config: RustDeskConfig) { *self.config.write() = config; } - /// Get rendezvous status pub fn rendezvous_status(&self) -> Option { self.rendezvous.read().as_ref().map(|r| r.status()) } - /// Get device ID pub fn device_id(&self) -> String { self.config.read().device_id.clone() } - /// Get connection count pub fn connection_count(&self) -> usize { self.connection_manager.connection_count() } - /// Start the RustDesk service pub async fn start(&self) -> anyhow::Result<()> { let config = self.config.read().clone(); @@ -181,74 +146,44 @@ impl RustDeskService { config.rendezvous_addr() ); - // Initialize crypto if let Err(e) = crypto::init() { error!("Failed to initialize crypto: {}", e); *self.status.write() = ServiceStatus::Error(e.to_string()); return Err(e.into()); } - // Create and start rendezvous mediator with relay callback let mediator = Arc::new(RendezvousMediator::new(config.clone())); - // Set the keypair on connection manager (Curve25519 for encryption) let keypair = mediator.ensure_keypair(); self.connection_manager.set_keypair(keypair); - // Set the signing keypair on connection manager (Ed25519 for SignedId) let signing_keypair = mediator.ensure_signing_keypair(); self.connection_manager.set_signing_keypair(signing_keypair); - // Set the HID controller on connection manager self.connection_manager.set_hid(self.hid.clone()); - // Set the audio controller on connection manager for audio streaming self.connection_manager.set_audio(self.audio.clone()); - // Set the video manager on connection manager for video streaming self.connection_manager .set_video_manager(self.video_manager.clone()); *self.rendezvous.write() = Some(mediator.clone()); - // Start TCP listener BEFORE the rendezvous mediator to ensure port is set correctly - // This prevents race condition where mediator starts registration with wrong port let (tcp_handles, listen_port) = self.start_tcp_listener_with_port().await?; *self.tcp_listener_handle.write() = Some(tcp_handles); - // Set the listen port on mediator before starting the registration loop mediator.set_listen_port(listen_port); - // Create relay request handler let connection_manager = self.connection_manager.clone(); - let video_manager = self.video_manager.clone(); - let hid = self.hid.clone(); - let audio = self.audio.clone(); let service_config = self.config.clone(); - // Set the punch callback on the mediator (tries P2P first, then relay) - let connection_manager_punch = self.connection_manager.clone(); - let video_manager_punch = self.video_manager.clone(); - let hid_punch = self.hid.clone(); - let audio_punch = self.audio.clone(); - let service_config_punch = self.config.clone(); - - mediator.set_punch_callback(Arc::new( + mediator.set_punch_callback(Arc::new({ + let connection_manager = connection_manager.clone(); + let service_config = service_config.clone(); move |peer_addr, rendezvous_addr, relay_server, uuid, socket_addr, device_id| { - let conn_mgr = connection_manager_punch.clone(); - let video = video_manager_punch.clone(); - let hid = hid_punch.clone(); - let audio = audio_punch.clone(); - let config = service_config_punch.clone(); - + let conn_mgr = connection_manager.clone(); + let config = service_config.clone(); tokio::spawn(async move { - // Get relay_key from config (no public server fallback) - let relay_key = { - let cfg = config.read(); - cfg.relay_key.clone().unwrap_or_default() - }; - - // Try P2P direct connection first if let Some(addr) = peer_addr { info!("Attempting P2P direct connection to {}", addr); match punch::try_direct_connection(addr).await { @@ -265,7 +200,7 @@ impl RustDeskService { } } - // Fall back to relay + let relay_key = rustdesk_relay_key(&config); if let Err(e) = handle_relay_request( &rendezvous_addr, &relay_server, @@ -274,34 +209,23 @@ impl RustDeskService { &device_id, &relay_key, conn_mgr, - video, - hid, - audio, ) .await { error!("Failed to handle relay request: {}", e); } }); - }, - )); + } + })); - // Set the relay callback on the mediator - mediator.set_relay_callback(Arc::new( + mediator.set_relay_callback(Arc::new({ + let connection_manager = connection_manager.clone(); + let service_config = service_config.clone(); move |rendezvous_addr, relay_server, uuid, socket_addr, device_id| { let conn_mgr = connection_manager.clone(); - let video = video_manager.clone(); - let hid = hid.clone(); - let audio = audio.clone(); let config = service_config.clone(); - tokio::spawn(async move { - // Get relay_key from config (no public server fallback) - let relay_key = { - let cfg = config.read(); - cfg.relay_key.clone().unwrap_or_default() - }; - + let relay_key = rustdesk_relay_key(&config); if let Err(e) = handle_relay_request( &rendezvous_addr, &relay_server, @@ -310,19 +234,15 @@ impl RustDeskService { &device_id, &relay_key, conn_mgr, - video, - hid, - audio, ) .await { error!("Failed to handle relay request: {}", e); } }); - }, - )); + } + })); - // Set the intranet callback on the mediator for same-LAN connections let connection_manager2 = self.connection_manager.clone(); mediator.set_intranet_callback(Arc::new( move |rendezvous_addr, peer_socket_addr, local_addr, relay_server, device_id| { @@ -345,7 +265,6 @@ impl RustDeskService { }, )); - // Spawn rendezvous task let status = self.status.clone(); let handle = tokio::spawn(async move { loop { @@ -357,7 +276,6 @@ impl RustDeskService { Err(e) => { error!("Rendezvous mediator error: {}", e); *status.write() = ServiceStatus::Error(e.to_string()); - // Wait before retry tokio::time::sleep(std::time::Duration::from_secs(5)).await; *status.write() = ServiceStatus::Starting; } @@ -372,10 +290,7 @@ impl RustDeskService { Ok(()) } - /// Start TCP listener for direct peer connections - /// Returns the join handle and the port that was bound async fn start_tcp_listener_with_port(&self) -> anyhow::Result<(Vec>, u16)> { - // Try to bind to the default port, or find an available port let (listeners, listen_port) = match self.bind_direct_listeners(DIRECT_LISTEN_PORT) { Ok(result) => result, Err(err) => { @@ -453,7 +368,6 @@ impl RustDeskService { Ok((listeners, listen_port)) } - /// Stop the RustDesk service pub async fn stop(&self) -> anyhow::Result<()> { if self.status() == ServiceStatus::Stopped { return Ok(()); @@ -461,23 +375,18 @@ impl RustDeskService { info!("Stopping RustDesk service"); - // Send shutdown signal (this will stop the TCP listener) let _ = self.shutdown_tx.send(()); - // Close all connections self.connection_manager.close_all(); - // Stop rendezvous mediator if let Some(mediator) = self.rendezvous.read().as_ref() { mediator.stop(); } - // Wait for rendezvous task to finish if let Some(handle) = self.rendezvous_handle.write().take() { handle.abort(); } - // Wait for TCP listener task to finish if let Some(handles) = self.tcp_listener_handle.write().take() { for handle in handles { handle.abort(); @@ -490,15 +399,12 @@ impl RustDeskService { Ok(()) } - /// Restart the service with new configuration pub async fn restart(&self, config: RustDeskConfig) -> anyhow::Result<()> { self.stop().await?; self.update_config(config); self.start().await } - /// Save keypair and UUID to config - /// Returns the updated config if changes were made pub fn save_credentials(&self) -> Option { if let Some(mediator) = self.rendezvous.read().as_ref() { let kp = mediator.ensure_keypair(); @@ -506,7 +412,6 @@ impl RustDeskService { let mut config = self.config.write(); let mut changed = false; - // Save encryption keypair (Curve25519) let pk = kp.public_key_base64(); let sk = kp.secret_key_base64(); if config.public_key.as_ref() != Some(&pk) || config.private_key.as_ref() != Some(&sk) { @@ -515,7 +420,6 @@ impl RustDeskService { changed = true; } - // Save signing keypair (Ed25519) let signing_pk = skp.public_key_base64(); let signing_sk = skp.secret_key_base64(); if config.signing_public_key.as_ref() != Some(&signing_pk) @@ -526,7 +430,6 @@ impl RustDeskService { changed = true; } - // Save UUID if it was newly generated if mediator.uuid_needs_save() { let mediator_config = mediator.config(); if let Some(uuid) = mediator_config.uuid { @@ -545,21 +448,16 @@ impl RustDeskService { None } - /// Save keypair to config (deprecated, use save_credentials instead) #[deprecated(note = "Use save_credentials instead")] pub fn save_keypair(&self) { let _ = self.save_credentials(); } } -/// Handle relay request from rendezvous server -/// -/// Correct flow based on RustDesk protocol: -/// 1. Connect to RENDEZVOUS server (not relay!) -/// 2. Send RelayResponse with client's socket_addr -/// 3. Connect to RELAY server -/// 4. Accept connection without waiting for response -#[allow(clippy::too_many_arguments)] +fn rustdesk_relay_key(config: &Arc>) -> String { + config.read().relay_key.clone().unwrap_or_default() +} + async fn handle_relay_request( rendezvous_addr: &str, relay_server: &str, @@ -568,16 +466,12 @@ async fn handle_relay_request( device_id: &str, relay_key: &str, connection_manager: Arc, - _video_manager: Arc, - _hid: Arc, - _audio: Arc, ) -> anyhow::Result<()> { info!( "Handling relay request: rendezvous={}, relay={}, uuid={}", rendezvous_addr, relay_server, uuid ); - // Step 1: Connect to RENDEZVOUS server and send RelayResponse let rendezvous_socket_addr: SocketAddr = tokio::net::lookup_host(rendezvous_addr) .await? .next() @@ -597,8 +491,7 @@ async fn handle_relay_request( rendezvous_socket_addr ); - // Send RelayResponse to rendezvous server with client's socket_addr - // IMPORTANT: Include our device ID so rendezvous server can look up and sign our public key + // Rendezvous looks up our pk by device id (must set `id`, not raw pk on wire). let relay_response = make_relay_response(uuid, socket_addr, relay_server, device_id); let bytes = relay_response .write_to_bytes() @@ -606,10 +499,8 @@ async fn handle_relay_request( bytes_codec::write_frame(&mut rendezvous_stream, &bytes).await?; debug!("Sent RelayResponse to rendezvous server for uuid={}", uuid); - // Close rendezvous connection - we don't need to wait for response drop(rendezvous_stream); - // Step 2: Connect to RELAY server and send RequestRelay to identify ourselves let relay_addr: SocketAddr = tokio::net::lookup_host(relay_server) .await? .next() @@ -624,9 +515,7 @@ async fn handle_relay_request( info!("Connected to relay server at {}", relay_addr); - // Send RequestRelay to relay server with our uuid, licence_key, and peer's socket_addr - // The licence_key is required if the relay server is configured with -k option - // The socket_addr is CRITICAL - the relay server uses it to match us with the peer + // Relay pairs peers by uuid + mangled peer socket_addr (required when hbbr uses -k). let request_relay = make_request_relay(uuid, relay_key, socket_addr); let bytes = request_relay .write_to_bytes() @@ -634,10 +523,8 @@ async fn handle_relay_request( bytes_codec::write_frame(&mut stream, &bytes).await?; debug!("Sent RequestRelay to relay server for uuid={}", uuid); - // Decode peer address for logging let peer_addr = rendezvous::AddrMangle::decode(socket_addr).unwrap_or(relay_addr); - // Step 3: Accept connection - relay server bridges the connection connection_manager .accept_connection(stream, peer_addr) .await?; @@ -649,14 +536,6 @@ async fn handle_relay_request( Ok(()) } -/// Handle intranet/same-LAN connection request -/// -/// When the server determines that the client and peer are on the same intranet -/// (same public IP or both on LAN), it sends FetchLocalAddr to the peer. -/// The peer must: -/// 1. Open a TCP connection to the rendezvous server -/// 2. Send LocalAddr with our local address -/// 3. Accept the peer connection over that same TCP stream async fn handle_intranet_request( rendezvous_addr: &str, peer_socket_addr: &[u8], @@ -670,11 +549,9 @@ async fn handle_intranet_request( rendezvous_addr, local_addr, device_id ); - // Decode peer address for logging let peer_addr = AddrMangle::decode(peer_socket_addr); debug!("Peer address from FetchLocalAddr: {:?}", peer_addr); - // Connect to rendezvous server via TCP with timeout let mut stream = tokio::time::timeout(Duration::from_secs(5), TcpStream::connect(rendezvous_addr)) .await @@ -685,7 +562,6 @@ async fn handle_intranet_request( rendezvous_addr ); - // Build LocalAddr message with our local address (mangled) let local_addr_bytes = AddrMangle::encode(local_addr); let msg = make_local_addr( peer_socket_addr, @@ -698,24 +574,16 @@ async fn handle_intranet_request( .write_to_bytes() .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; - // Send LocalAddr using RustDesk's variable-length framing bytes_codec::write_frame(&mut stream, &bytes).await?; info!("Sent LocalAddr to rendezvous server, waiting for peer connection"); - // Now the rendezvous server will forward this to the client, - // and the client will connect to us through this same TCP stream. - // The server proxies the connection between client and peer. - - // Get peer address for logging/connection tracking let effective_peer_addr = peer_addr.unwrap_or_else(|| { - // If we can't decode the peer address, use the rendezvous server address rendezvous_addr .parse() .unwrap_or_else(|_| "0.0.0.0:0".parse().unwrap()) }); - // Accept the connection - the stream is now a proxied connection to the client connection_manager .accept_connection(stream, effective_peer_addr) .await?; diff --git a/src/rustdesk/protocol.rs b/src/rustdesk/protocol.rs index 3519faaa..825ebef0 100644 --- a/src/rustdesk/protocol.rs +++ b/src/rustdesk/protocol.rs @@ -1,18 +1,12 @@ -//! RustDesk Protocol Messages -//! -//! This module provides the compiled protobuf messages for the RustDesk protocol. -//! Messages are generated from rendezvous.proto and message.proto at build time. -//! Uses protobuf-rust (same as RustDesk server) for full compatibility. +//! Protobuf wrappers (`protos/` → `OUT_DIR`). use protobuf::Message; -// Include the generated protobuf code #[path = ""] pub mod hbb { include!(concat!(env!("OUT_DIR"), "/protos/mod.rs")); } -// Re-export commonly used types pub use hbb::rendezvous::{ punch_hole_response, relay_response, rendezvous_message, ConfigUpdate, ConnType, FetchLocalAddr, HealthCheck, KeyExchange, LocalAddr, NatType, OnlineRequest, OnlineResponse, @@ -21,7 +15,6 @@ pub use hbb::rendezvous::{ RequestRelay, SoftwareUpdate, TestNatRequest, TestNatResponse, }; -// Re-export message.proto types pub use hbb::message::{ key_event, login_response, message, misc, AudioFormat, AudioFrame, Auth2FA, Clipboard, ControlKey, CursorData, CursorPosition, DisplayInfo, EncodedVideoFrame, EncodedVideoFrames, @@ -30,7 +23,6 @@ pub use hbb::message::{ SupportedResolutions, TestDelay, VideoFrame, WindowsSessions, }; -/// Helper to create a RendezvousMessage with RegisterPeer pub fn make_register_peer(id: &str, serial: i32) -> RendezvousMessage { let mut rp = RegisterPeer::new(); rp.id = id.to_string(); @@ -41,7 +33,6 @@ pub fn make_register_peer(id: &str, serial: i32) -> RendezvousMessage { msg } -/// Helper to create a RendezvousMessage with RegisterPk pub fn make_register_pk(id: &str, uuid: &[u8], pk: &[u8], old_id: &str) -> RendezvousMessage { let mut rpk = RegisterPk::new(); rpk.id = id.to_string(); @@ -54,7 +45,6 @@ pub fn make_register_pk(id: &str, uuid: &[u8], pk: &[u8], old_id: &str) -> Rende msg } -/// Helper to create a PunchHoleSent message pub fn make_punch_hole_sent( socket_addr: &[u8], id: &str, @@ -74,10 +64,7 @@ pub fn make_punch_hole_sent( msg } -/// Helper to create a RelayResponse message (sent to rendezvous server) -/// IMPORTANT: The union field should be `Id` (our device ID), NOT `Pk`. -/// The rendezvous server will look up our registered public key using this ID, -/// sign it with the server's private key, and set the `pk` field before forwarding to client. +/// Use `id` (device id), not raw `pk`; hbbs fills `pk` when forwarding. pub fn make_relay_response( uuid: &str, socket_addr: &[u8], @@ -96,13 +83,7 @@ pub fn make_relay_response( msg } -/// Helper to create a RequestRelay message (sent to relay server to identify ourselves) -/// -/// The `licence_key` is required if the relay server is configured with a key. -/// If the key doesn't match, the relay server will silently reject the connection. -/// -/// IMPORTANT: `socket_addr` is the peer's encoded socket address (from FetchLocalAddr/RelayResponse). -/// The relay server uses this to match the two peers connecting to the same relay session. +/// `socket_addr` must be the peer's mangled addr; `licence_key` required if hbbr uses `-k`. pub fn make_request_relay(uuid: &str, licence_key: &str, socket_addr: &[u8]) -> RendezvousMessage { let mut rr = RequestRelay::new(); rr.uuid = uuid.to_string(); @@ -114,8 +95,6 @@ pub fn make_request_relay(uuid: &str, licence_key: &str, socket_addr: &[u8]) -> msg } -/// Helper to create a LocalAddr response message -/// This is sent in response to FetchLocalAddr when a peer on the same LAN wants to connect pub fn make_local_addr( socket_addr: &[u8], local_addr: &[u8], @@ -135,12 +114,10 @@ pub fn make_local_addr( msg } -/// Decode a RendezvousMessage from bytes pub fn decode_rendezvous_message(buf: &[u8]) -> Result { RendezvousMessage::parse_from_bytes(buf) } -/// Decode a Message (session message) from bytes pub fn decode_message(buf: &[u8]) -> Result { hbb::message::Message::parse_from_bytes(buf) } diff --git a/src/rustdesk/punch.rs b/src/rustdesk/punch.rs index ad6cea80..314a1b6f 100644 --- a/src/rustdesk/punch.rs +++ b/src/rustdesk/punch.rs @@ -1,36 +1,19 @@ -//! P2P Punch Hole Implementation -//! -//! This module implements TCP direct connection attempts with relay fallback. -//! When a PunchHole request is received, we try to connect directly to the peer. -//! If the direct connection fails (timeout), we fall back to relay. +//! Direct TCP attempt before relay fallback. use std::net::SocketAddr; -use std::sync::Arc; use std::time::Duration; use tokio::net::TcpStream; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; -use super::connection::ConnectionManager; - -/// Timeout for direct TCP connection attempt const DIRECT_CONNECT_TIMEOUT_MS: u64 = 3000; -/// Result of a punch hole attempt #[derive(Debug)] pub enum PunchResult { - /// Direct connection succeeded DirectConnection(TcpStream), - /// Direct connection failed, should use relay NeedRelay, } -/// Attempt direct TCP connection to peer -/// -/// This is a simplified P2P approach: -/// 1. Try to connect directly to the peer's address -/// 2. If successful within timeout, use direct connection -/// 3. If failed or timeout, fall back to relay pub async fn try_direct_connection(peer_addr: SocketAddr) -> PunchResult { info!("Attempting direct TCP connection to {}", peer_addr); @@ -54,76 +37,3 @@ pub async fn try_direct_connection(peer_addr: SocketAddr) -> PunchResult { } } } - -/// Punch hole handler that tries direct connection first, then falls back to relay -pub struct PunchHoleHandler { - connection_manager: Arc, -} - -impl PunchHoleHandler { - pub fn new(connection_manager: Arc) -> Self { - Self { connection_manager } - } - - /// Handle punch hole request - /// - /// Tries direct connection first, falls back to relay if needed. - /// Returns true if direct connection succeeded, false if relay is needed. - pub async fn handle_punch_hole(&self, peer_addr: Option) -> bool { - let peer_addr = match peer_addr { - Some(addr) => addr, - None => { - warn!("No peer address available for punch hole"); - return false; - } - }; - - match try_direct_connection(peer_addr).await { - PunchResult::DirectConnection(stream) => { - // Direct connection succeeded, accept it - match self - .connection_manager - .accept_connection(stream, peer_addr) - .await - { - Ok(_) => { - info!("P2P direct connection established with {}", peer_addr); - true - } - Err(e) => { - warn!("Failed to accept direct connection: {}", e); - false - } - } - } - PunchResult::NeedRelay => { - debug!("Direct connection failed, need relay for {}", peer_addr); - false - } - } - } -} - -/// Spawn a punch hole attempt with relay fallback -/// -/// This function spawns an async task that: -/// 1. Tries direct TCP connection to peer -/// 2. If successful, accepts the connection -/// 3. If failed, calls the relay callback -pub fn spawn_punch_with_fallback( - connection_manager: Arc, - peer_addr: Option, - relay_callback: F, -) where - F: FnOnce() + Send + 'static, -{ - tokio::spawn(async move { - let handler = PunchHoleHandler::new(connection_manager); - - if !handler.handle_punch_hole(peer_addr).await { - // Direct connection failed, use relay - info!("Falling back to relay connection"); - relay_callback(); - } - }); -} diff --git a/src/rustdesk/rendezvous.rs b/src/rustdesk/rendezvous.rs index 7699dfca..0e62390e 100644 --- a/src/rustdesk/rendezvous.rs +++ b/src/rustdesk/rendezvous.rs @@ -1,8 +1,4 @@ -//! RustDesk Rendezvous Mediator -//! -//! This module handles communication with the hbbs rendezvous server. -//! It registers the device ID and public key, handles punch hole requests, -//! and relay requests. +//! HBBS UDP registration; punch / relay / intranet callbacks. use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; @@ -24,19 +20,14 @@ use super::protocol::{ rendezvous_message, NatType, RendezvousMessage, }; -/// Registration interval in milliseconds const REG_INTERVAL_MS: u64 = 12_000; -/// Minimum registration timeout const MIN_REG_TIMEOUT_MS: u64 = 3_000; -/// Maximum registration timeout const MAX_REG_TIMEOUT_MS: u64 = 30_000; -/// Timer interval for checking registration status const TIMER_INTERVAL_MS: u64 = 300; -/// Rendezvous mediator status #[derive(Debug, Clone, PartialEq)] pub enum RendezvousStatus { Disconnected, @@ -58,44 +49,13 @@ impl std::fmt::Display for RendezvousStatus { } } -/// Callback for handling incoming connection requests -pub type ConnectionCallback = Arc; - -/// Incoming connection request from a RustDesk client -#[derive(Debug, Clone)] -pub struct ConnectionRequest { - /// Peer socket address (encoded) - pub socket_addr: Vec, - /// Relay server to use - pub relay_server: String, - /// NAT type - pub nat_type: NatType, - /// Connection UUID - pub uuid: String, - /// Whether to use secure connection - pub secure: bool, -} - -/// Callback type for relay requests -/// Parameters: rendezvous_addr, relay_server, uuid, socket_addr (client's mangled address), device_id pub type RelayCallback = Arc, String) + Send + Sync>; -/// Callback type for P2P punch hole requests -/// Parameters: peer_addr (decoded), relay_callback_params (rendezvous_addr, relay_server, uuid, socket_addr, device_id) -/// Returns: should call relay callback if P2P fails pub type PunchCallback = Arc, String, String, String, Vec, String) + Send + Sync>; -/// Callback type for intranet/local address connections -/// Parameters: rendezvous_addr, peer_socket_addr (mangled), local_addr, relay_server, device_id pub type IntranetCallback = Arc, SocketAddr, String, String) + Send + Sync>; -/// Rendezvous Mediator -/// -/// Handles communication with hbbs rendezvous server: -/// - Registers device ID and public key -/// - Maintains keep-alive with server -/// - Handles punch hole and relay requests pub struct RendezvousMediator { config: Arc>, keypair: Arc>>, @@ -114,11 +74,9 @@ pub struct RendezvousMediator { } impl RendezvousMediator { - /// Create a new rendezvous mediator pub fn new(mut config: RustDeskConfig) -> Self { let (shutdown_tx, _) = broadcast::channel(1); - // Get or generate UUID from config (persisted) let (uuid, uuid_needs_save) = config.ensure_uuid(); Self { @@ -139,88 +97,71 @@ impl RendezvousMediator { } } - /// Set the TCP listen port for direct connections pub fn set_listen_port(&self, port: u16) { let old_port = *self.listen_port.read(); if old_port != port { *self.listen_port.write() = port; - // Port changed, increment serial to notify server self.increment_serial(); } } - /// Get the TCP listen port pub fn listen_port(&self) -> u16 { *self.listen_port.read() } - /// Increment the serial number to indicate local state change pub fn increment_serial(&self) { let mut serial = self.serial.write(); *serial = serial.wrapping_add(1); debug!("Serial incremented to {}", *serial); } - /// Get current serial number pub fn serial(&self) -> i32 { *self.serial.read() } - /// Check if UUID needs to be saved to persistent storage pub fn uuid_needs_save(&self) -> bool { *self.uuid_needs_save.read() } - /// Get the current config (with UUID set) pub fn config(&self) -> RustDeskConfig { self.config.read().clone() } - /// Mark UUID as saved pub fn mark_uuid_saved(&self) { *self.uuid_needs_save.write() = false; } - /// Set the callback for relay requests pub fn set_relay_callback(&self, callback: RelayCallback) { *self.relay_callback.write() = Some(callback); } - /// Set the callback for P2P punch hole requests pub fn set_punch_callback(&self, callback: PunchCallback) { *self.punch_callback.write() = Some(callback); } - /// Set the callback for intranet/local address connections pub fn set_intranet_callback(&self, callback: IntranetCallback) { *self.intranet_callback.write() = Some(callback); } - /// Get current status pub fn status(&self) -> RendezvousStatus { self.status.read().clone() } - /// Update configuration pub fn update_config(&self, config: RustDeskConfig) { *self.config.write() = config; - // Config changed, increment serial to notify server self.increment_serial(); } - /// Initialize or get keypair (Curve25519 for encryption) pub fn ensure_keypair(&self) -> KeyPair { let mut keypair_guard = self.keypair.write(); if keypair_guard.is_none() { let config = self.config.read(); - // Try to load from config first if let (Some(pk), Some(sk)) = (&config.public_key, &config.private_key) { if let Ok(kp) = KeyPair::from_base64(pk, sk) { *keypair_guard = Some(kp.clone()); return kp; } } - // Generate new keypair let kp = KeyPair::generate(); *keypair_guard = Some(kp.clone()); kp @@ -229,12 +170,10 @@ impl RendezvousMediator { } } - /// Initialize or get signing keypair (Ed25519 for SignedId) pub fn ensure_signing_keypair(&self) -> SigningKeyPair { let mut signing_guard = self.signing_keypair.write(); if signing_guard.is_none() { let config = self.config.read(); - // Try to load from config first if let (Some(pk), Some(sk)) = (&config.signing_public_key, &config.signing_private_key) { if let Ok(skp) = SigningKeyPair::from_base64(pk, sk) { @@ -245,7 +184,6 @@ impl RendezvousMediator { warn!("Failed to decode signing keypair from config, generating new one"); } } - // Generate new signing keypair let skp = SigningKeyPair::generate(); debug!("Generated new signing keypair"); *signing_guard = Some(skp.clone()); @@ -255,12 +193,10 @@ impl RendezvousMediator { } } - /// Get the device ID pub fn device_id(&self) -> String { self.config.read().device_id.clone() } - /// Start the rendezvous mediator pub async fn start(&self) -> anyhow::Result<()> { let config = self.config.read().clone(); let effective_server = config.effective_rendezvous_server(); @@ -284,13 +220,11 @@ impl RendezvousMediator { config.device_id, addr ); - // Resolve server address let server_addr: SocketAddr = tokio::net::lookup_host(&addr) .await? .next() .ok_or_else(|| anyhow::anyhow!("Failed to resolve {}", addr))?; - // Create UDP socket (match address family, enforce IPV6_V6ONLY) let bind_addr = match server_addr { SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), @@ -302,11 +236,9 @@ impl RendezvousMediator { info!("Connected to rendezvous server at {}", server_addr); *self.status.write() = RendezvousStatus::Connected; - // Start registration loop self.registration_loop(socket).await } - /// Main registration loop async fn registration_loop(&self, socket: UdpSocket) -> anyhow::Result<()> { let mut timer = interval(Duration::from_millis(TIMER_INTERVAL_MS)); let mut recv_buf = vec![0u8; 65535]; @@ -318,7 +250,6 @@ impl RendezvousMediator { loop { tokio::select! { - // Handle incoming messages result = socket.recv(&mut recv_buf) => { match result { Ok(len) => { @@ -336,7 +267,6 @@ impl RendezvousMediator { } } - // Periodic registration _ = timer.tick() => { let now = Instant::now(); let expired = last_register_resp @@ -360,7 +290,6 @@ impl RendezvousMediator { } } - // Shutdown signal _ = shutdown_rx.recv() => { info!("Rendezvous mediator shutting down"); break; @@ -372,20 +301,16 @@ impl RendezvousMediator { Ok(()) } - /// Send registration message async fn send_register(&self, socket: &UdpSocket) -> anyhow::Result<()> { let key_confirmed = *self.key_confirmed.read(); if !key_confirmed { - // Send RegisterPk with public key self.send_register_pk(socket).await } else { - // Send RegisterPeer heartbeat self.send_register_peer(socket).await } } - /// Send RegisterPeer message async fn send_register_peer(&self, socket: &UdpSocket) -> anyhow::Result<()> { let id = self.device_id(); let serial = *self.serial.read(); @@ -398,12 +323,8 @@ impl RendezvousMediator { Ok(()) } - /// Send RegisterPk message - /// Uses the Ed25519 signing public key for registration async fn send_register_pk(&self, socket: &UdpSocket) -> anyhow::Result<()> { let id = self.device_id(); - // Use signing public key (Ed25519) for RegisterPk - // This is what clients will use to verify our SignedId signature let signing_keypair = self.ensure_signing_keypair(); let pk = signing_keypair.public_key_bytes(); let uuid = *self.uuid.read(); @@ -417,12 +338,6 @@ impl RendezvousMediator { Ok(()) } - /// Handle FetchLocalAddr - send to callback for proper TCP handling - /// - /// The intranet callback will: - /// 1. Open a TCP connection to the rendezvous server - /// 2. Send LocalAddr message - /// 3. Accept the peer connection over that same TCP stream async fn send_local_addr( &self, _udp_socket: &UdpSocket, @@ -431,21 +346,17 @@ impl RendezvousMediator { ) -> anyhow::Result<()> { let id = self.device_id(); - // Get our actual local IP addresses for same-LAN connection let local_addrs = get_local_addresses(); if local_addrs.is_empty() { debug!("No local addresses available for LocalAddr response"); return Ok(()); } - // Get the rendezvous server address for TCP connection let config = self.config.read().clone(); let rendezvous_addr = config.rendezvous_addr(); - // Use TCP listen port for direct connections let listen_port = self.listen_port(); - // Use the first local IP let local_ip = local_addrs[0]; let local_sock_addr = SocketAddr::new(local_ip, listen_port); @@ -454,7 +365,6 @@ impl RendezvousMediator { local_sock_addr, rendezvous_addr ); - // Call the intranet callback if set if let Some(callback) = self.intranet_callback.read().as_ref() { callback( rendezvous_addr, @@ -470,7 +380,6 @@ impl RendezvousMediator { Ok(()) } - /// Handle response from rendezvous server async fn handle_response( &self, socket: &UdpSocket, @@ -486,7 +395,6 @@ impl RendezvousMediator { match msg.union { Some(rendezvous_message::Union::RegisterPeerResponse(rpr)) => { if rpr.request_pk { - // Server wants us to register our public key info!("Server requested public key registration"); *self.key_confirmed.write() = false; self.send_register_pk(socket).await?; @@ -497,30 +405,24 @@ impl RendezvousMediator { info!("Received RegisterPkResponse: result={:?}", rpr.result); match rpr.result.value() { 0 => { - // OK info!("✓ Public key registered successfully with server"); *self.key_confirmed.write() = true; - // Increment serial after successful registration self.increment_serial(); *self.status.write() = RendezvousStatus::Registered; } 2 => { - // UUID_MISMATCH warn!("UUID mismatch, need to re-register"); *self.key_confirmed.write() = false; } 3 => { - // ID_EXISTS error!("Device ID already exists on server"); *self.status.write() = RendezvousStatus::Error("Device ID already exists".to_string()); } 4 => { - // TOO_FREQUENT warn!("Registration too frequent"); } 5 => { - // INVALID_ID_FORMAT error!("Invalid device ID format"); *self.status.write() = RendezvousStatus::Error("Invalid ID format".to_string()); @@ -540,7 +442,6 @@ impl RendezvousMediator { let effective_relay_server = select_relay_server(config.relay_server.as_deref(), &ph.relay_server); - // Decode the peer's socket address let peer_addr = if !ph.socket_addr.is_empty() { AddrMangle::decode(&ph.socket_addr) } else { @@ -556,9 +457,7 @@ impl RendezvousMediator { ph.nat_type ); - // Send PunchHoleSent to acknowledge // IMPORTANT: socket_addr in PunchHoleSent should be the PEER's address (from PunchHole), - // not our own address. This is how RustDesk protocol works. let id = self.device_id(); info!( @@ -586,16 +485,11 @@ impl RendezvousMediator { info!("Sent PunchHoleSent response successfully"); } - // Try P2P direct connection first, fall back to relay if needed if let Some(relay_server) = effective_relay_server { - // Generate a standard UUID v4 for relay pairing - // This must match the format used by RustDesk client let uuid = uuid::Uuid::new_v4().to_string(); let rendezvous_addr = config.rendezvous_addr(); let device_id = config.device_id.clone(); - // Use punch callback if set (tries P2P first, then relay) - // Otherwise fall back to relay callback directly if let Some(callback) = self.punch_callback.read().as_ref() { callback( peer_addr, @@ -630,7 +524,6 @@ impl RendezvousMediator { rr.uuid, rr.secure ); - // Call the relay callback to handle the connection if let Some(callback) = self.relay_callback.read().as_ref() { if let Some(relay_server) = effective_relay_server { let rendezvous_addr = config.rendezvous_addr(); @@ -653,7 +546,6 @@ impl RendezvousMediator { select_relay_server(config.relay_server.as_deref(), &fla.relay_server) .unwrap_or_default(); - // Decode the peer address for logging let peer_addr = AddrMangle::decode(&fla.socket_addr); info!( "Received FetchLocalAddr request: peer_addr={:?}, socket_addr_len={}, relay_server={}, effective_relay_server={}", @@ -662,7 +554,6 @@ impl RendezvousMediator { fla.relay_server, effective_relay_server ); - // Respond with our local address for same-LAN direct connection self.send_local_addr(socket, &fla.socket_addr, &effective_relay_server) .await?; } @@ -671,7 +562,6 @@ impl RendezvousMediator { *self.serial.write() = cu.serial; } Some(other) => { - // Log the actual message type for debugging let type_name = match other { rendezvous_message::Union::PunchHoleRequest(_) => "PunchHoleRequest", rendezvous_message::Union::PunchHoleResponse(_) => "PunchHoleResponse", @@ -696,23 +586,18 @@ impl RendezvousMediator { Ok(()) } - /// Stop the rendezvous mediator pub fn stop(&self) { info!("Stopping rendezvous mediator"); let _ = self.shutdown_tx.send(()); *self.status.write() = RendezvousStatus::Disconnected; } - /// Get a shutdown receiver pub fn shutdown_rx(&self) -> broadcast::Receiver<()> { self.shutdown_tx.subscribe() } } -/// AddrMangle - RustDesk's address encoding scheme -/// -/// Certain routers and firewalls scan packets and modify IP addresses. -/// This encoding mangles the address to avoid detection. +/// RustDesk mangled socket encoding. pub struct AddrMangle; fn normalize_relay_server(server: &str) -> Option { @@ -735,9 +620,7 @@ fn select_relay_server(local_relay: Option<&str>, server_relay: &str) -> Option< } impl AddrMangle { - /// Encode a SocketAddr to bytes using RustDesk's mangle algorithm pub fn encode(addr: SocketAddr) -> Vec { - // Try to convert IPv6-mapped IPv4 to plain IPv4 let addr = try_into_v4(addr); match addr { @@ -753,7 +636,6 @@ impl AddrMangle { let v = ((ip + tm) << 49) | (tm << 17) | (port + (tm & 0xFFFF)); let bytes = v.to_le_bytes(); - // Remove trailing zeros let mut n_padding = 0; for i in bytes.iter().rev() { if *i == 0u8 { @@ -774,13 +656,11 @@ impl AddrMangle { } } - /// Decode bytes to SocketAddr using RustDesk's mangle algorithm pub fn decode(bytes: &[u8]) -> Option { use std::convert::TryInto; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4}; if bytes.len() > 16 { - // IPv6 format: 16 bytes IP + 2 bytes port if bytes.len() != 18 { return None; } @@ -791,7 +671,6 @@ impl AddrMangle { return Some(SocketAddr::new(std::net::IpAddr::V6(ip), port)); } - // IPv4 mangled format let mut padded = [0u8; 16]; padded[..bytes.len()].copy_from_slice(bytes); let number = u128::from_le_bytes(padded); @@ -805,7 +684,6 @@ impl AddrMangle { } } -/// Try to convert IPv6-mapped IPv4 address to plain IPv4 fn try_into_v4(addr: SocketAddr) -> SocketAddr { match addr { SocketAddr::V6(v6) if !addr.ip().is_loopback() => { @@ -818,41 +696,30 @@ fn try_into_v4(addr: SocketAddr) -> SocketAddr { addr } -/// Check if an interface name belongs to Docker or other virtual networks fn is_virtual_interface(name: &str) -> bool { - // Docker interfaces name.starts_with("docker") || name.starts_with("br-") || name.starts_with("veth") - // Kubernetes/container interfaces || name.starts_with("cni") || name.starts_with("flannel") || name.starts_with("calico") || name.starts_with("weave") - // Virtual bridge interfaces || name.starts_with("virbr") || name.starts_with("lxcbr") || name.starts_with("lxdbr") - // VPN interfaces (usually not useful for LAN discovery) || name.starts_with("tun") || name.starts_with("tap") } -/// Check if an IP address is in a Docker/container private range fn is_docker_ip(ip: &std::net::IpAddr) -> bool { if let std::net::IpAddr::V4(ipv4) = ip { let octets = ipv4.octets(); - // Docker default bridge: 172.17.0.0/16 if octets[0] == 172 && octets[1] == 17 { return true; } - // Docker user-defined networks: 172.18-31.0.0/16 if octets[0] == 172 && (18..=31).contains(&octets[1]) { return true; } - // Docker overlay networks: 10.0.0.0/8 (common range) - // Note: 10.x.x.x is also used for corporate LANs, so we only filter - // specific Docker-like patterns (10.0.x.x with small third octet) if octets[0] == 10 && octets[1] == 0 && octets[2] < 10 { return true; } @@ -860,22 +727,18 @@ fn is_docker_ip(ip: &std::net::IpAddr) -> bool { false } -/// Get local IP addresses (non-loopback, non-Docker) fn get_local_addresses() -> Vec { let mut addrs = Vec::new(); - // Use pnet or network-interface crate if available, otherwise use simple method #[cfg(target_os = "linux")] { if let Ok(interfaces) = std::fs::read_dir("/sys/class/net") { for entry in interfaces.flatten() { let iface_name = entry.file_name().to_string_lossy().to_string(); - // Skip loopback and virtual interfaces if iface_name == "lo" || is_virtual_interface(&iface_name) { continue; } - // Try to get IP via command (simple approach) if let Ok(output) = std::process::Command::new("ip") .args(["-4", "addr", "show", &iface_name]) .output() @@ -886,7 +749,6 @@ fn get_local_addresses() -> Vec { let ip_part = &line[inet_pos + 5..]; if let Some(slash_pos) = ip_part.find('/') { if let Ok(ip) = ip_part[..slash_pos].parse::() { - // Skip loopback and Docker IPs if !ip.is_loopback() && !is_docker_ip(&ip) { addrs.push(ip); } @@ -899,15 +761,11 @@ fn get_local_addresses() -> Vec { } } - // Fallback: try to get default route interface IP if addrs.is_empty() { - // Try using DNS lookup to get local IP (connects to external server) if let Ok(socket) = std::net::UdpSocket::bind("0.0.0.0:0") { - // Connect to a public DNS server (doesn't actually send data) if socket.connect("8.8.8.8:53").is_ok() { if let Ok(local_addr) = socket.local_addr() { let ip = local_addr.ip(); - // Skip loopback and Docker IPs if !ip.is_loopback() && !is_docker_ip(&ip) { addrs.push(ip); } diff --git a/src/state.rs b/src/state.rs index d1443f0d..b0a4012f 100644 --- a/src/state.rs +++ b/src/state.rs @@ -5,8 +5,9 @@ use crate::atx::AtxController; use crate::audio::AudioController; use crate::auth::{SessionStore, UserStore}; use crate::config::ConfigStore; +use crate::db::DatabasePool; use crate::events::{ - AtxDeviceInfo, AudioDeviceInfo, EventBus, HidDeviceInfo, MsdDeviceInfo, SystemEvent, + AtxDeviceInfo, AudioDeviceInfo, EventBus, HidDeviceInfo, LedState, MsdDeviceInfo, SystemEvent, TtydDeviceInfo, VideoDeviceInfo, }; use crate::extensions::{ExtensionId, ExtensionManager}; @@ -17,68 +18,42 @@ use crate::rtsp::RtspService; use crate::rustdesk::RustDeskService; use crate::update::UpdateService; use crate::video::VideoStreamManager; +use crate::webrtc::WebRtcStreamer; -/// Application-wide state shared across handlers -/// -/// # Video Streaming -/// -/// All video operations should go through `stream_manager`: -/// - `stream_manager.webrtc_streamer()` - WebRTC streaming (H264, extensible to VP8/VP9/H265) -/// - `stream_manager.mjpeg_handler()` - MJPEG stream handler -/// - `stream_manager.streamer()` - Low-level video capture -/// - `stream_manager.start()` / `stream_manager.stop()` - Stream control -/// - `stream_manager.stats()` - Stream statistics -/// - `stream_manager.list_devices()` - List video devices +/// Shared Axum/App state: video flows through [`VideoStreamManager`]; WebRTC SDP/ICE/sessions on [`WebRtcStreamer`]. pub struct AppState { - /// Configuration store + pub db: DatabasePool, pub config: ConfigStore, - /// Session store pub sessions: SessionStore, - /// User store pub users: UserStore, - /// OTG Service - centralized USB gadget lifecycle management - /// This is the single owner of OtgGadgetManager, coordinating HID and MSD functions pub otg_service: Arc, - /// Video stream manager (unified MJPEG/WebRTC management) - /// This is the single entry point for all video operations. pub stream_manager: Arc, - /// HID controller + pub webrtc: Arc, pub hid: Arc, - /// MSD controller (optional, may not be initialized) pub msd: Arc>>, - /// ATX controller (optional, may not be initialized) pub atx: Arc>>, - /// Audio controller pub audio: Arc, - /// RustDesk remote access service (optional) pub rustdesk: Arc>>>, - /// RTSP streaming service (optional) pub rtsp: Arc>>>, - /// Extension manager (ttyd, gostc, easytier) pub extensions: Arc, - /// Event bus for real-time notifications pub events: Arc, - /// Latest device info snapshot for WebSocket clients device_info_tx: watch::Sender>, - /// Online update service pub update: Arc, - /// Shutdown signal sender pub shutdown_tx: broadcast::Sender<()>, - /// Recently revoked session IDs (for client kick detection) pub revoked_sessions: Arc>>, - /// Data directory path data_dir: std::path::PathBuf, } impl AppState { - /// Create new application state #[allow(clippy::too_many_arguments)] pub fn new( + db: DatabasePool, config: ConfigStore, sessions: SessionStore, users: UserStore, otg_service: Arc, stream_manager: Arc, + webrtc: Arc, hid: Arc, msd: Option, atx: Option, @@ -94,11 +69,13 @@ impl AppState { let (device_info_tx, _device_info_rx) = watch::channel(None); Arc::new(Self { + db, config, sessions, users, otg_service, stream_manager, + webrtc, hid, msd: Arc::new(RwLock::new(msd)), atx: Arc::new(RwLock::new(atx)), @@ -115,22 +92,18 @@ impl AppState { }) } - /// Get data directory path pub fn data_dir(&self) -> &std::path::PathBuf { &self.data_dir } - /// Subscribe to shutdown signal pub fn shutdown_signal(&self) -> broadcast::Receiver<()> { self.shutdown_tx.subscribe() } - /// Subscribe to the latest device info snapshot. pub fn subscribe_device_info(&self) -> watch::Receiver> { self.device_info_tx.subscribe() } - /// Record revoked session IDs (bounded queue) pub async fn remember_revoked_sessions(&self, session_ids: Vec) { if session_ids.is_empty() { return; @@ -144,19 +117,12 @@ impl AppState { } } - /// Check if a session ID was revoked (kicked) pub async fn is_session_revoked(&self, session_id: &str) -> bool { let guard = self.revoked_sessions.read().await; guard.iter().any(|id| id == session_id) } - /// Get complete device information for WebSocket clients - /// - /// This method collects the current state of all devices (video, HID, MSD, ATX, Audio) - /// and returns a DeviceInfo event that can be sent to clients. - /// Uses tokio::join! to collect all device info in parallel for better performance. pub async fn get_device_info(&self) -> SystemEvent { - // Collect all device info in parallel let (video, hid, msd, atx, audio, ttyd) = tokio::join!( self.collect_video_info(), self.collect_hid_info(), @@ -176,19 +142,15 @@ impl AppState { } } - /// Publish DeviceInfo event to all connected WebSocket clients pub async fn publish_device_info(&self) { let device_info = self.get_device_info().await; let _ = self.device_info_tx.send(Some(device_info)); } - /// Collect video device information async fn collect_video_info(&self) -> VideoDeviceInfo { - // Use stream_manager to get video info (includes stream_mode) self.stream_manager.get_video_info().await } - /// Collect HID device information async fn collect_hid_info(&self) -> HidDeviceInfo { let state = self.hid.snapshot().await; @@ -199,14 +161,19 @@ impl AppState { online: state.online, supports_absolute_mouse: state.supports_absolute_mouse, keyboard_leds_enabled: state.keyboard_leds_enabled, - led_state: state.led_state, + led_state: LedState { + num_lock: state.led_state.num_lock, + caps_lock: state.led_state.caps_lock, + scroll_lock: state.led_state.scroll_lock, + compose: state.led_state.compose, + kana: state.led_state.kana, + }, device: state.device, error: state.error, error_code: state.error_code, } } - /// Collect MSD device information (optional) async fn collect_msd_info(&self) -> Option { let msd_guard = self.msd.read().await; let msd = msd_guard.as_ref()?; @@ -227,9 +194,7 @@ impl AppState { }) } - /// Collect ATX device information (optional) async fn collect_atx_info(&self) -> Option { - // Predefined backend strings to avoid repeated allocations const BACKEND_POWER_ONLY: &str = "power: configured, reset: none"; const BACKEND_RESET_ONLY: &str = "power: none, reset: configured"; const BACKEND_BOTH: &str = "power: configured, reset: configured"; @@ -254,7 +219,6 @@ impl AppState { }) } - /// Collect Audio device information (optional) async fn collect_audio_info(&self) -> Option { let status = self.audio.status().await; @@ -267,7 +231,6 @@ impl AppState { }) } - /// Collect ttyd status information async fn collect_ttyd_info(&self) -> TtydDeviceInfo { let status = self.extensions.status(ExtensionId::Ttyd).await; diff --git a/src/stream/mjpeg.rs b/src/stream/mjpeg.rs index dcb6fc55..684c66ee 100644 --- a/src/stream/mjpeg.rs +++ b/src/stream/mjpeg.rs @@ -1,7 +1,3 @@ -//! MJPEG stream handler -//! -//! Manages video frame distribution and per-client statistics. - use arc_swap::ArcSwap; use parking_lot::Mutex as ParkingMutex; use parking_lot::RwLock as ParkingRwLock; @@ -17,28 +13,18 @@ use crate::video::encoder::JpegEncoder; use crate::video::format::PixelFormat; use crate::video::VideoFrame; -// No placeholder JPEGs: capture calls `set_offline()`; UI uses `stream.state_changed`. - -/// Client ID type (UUID string) pub type ClientId = String; -/// Per-client session information #[derive(Debug, Clone)] pub struct ClientSession { - /// Unique client ID pub id: ClientId, - /// Connection timestamp pub connected_at: Instant, - /// Last activity timestamp (frame sent) pub last_activity: Instant, - /// Frames sent to this client pub frames_sent: u64, - /// FPS calculator (1-second rolling window) pub fps_calculator: FpsCalculator, } impl ClientSession { - /// Create a new client session pub fn new(id: ClientId) -> Self { let now = Instant::now(); Self { @@ -50,44 +36,31 @@ impl ClientSession { } } - /// Get connection duration - pub fn connected_duration(&self) -> Duration { - self.last_activity.duration_since(self.connected_at) - } - - /// Get idle duration - pub fn idle_duration(&self) -> Duration { - Instant::now().duration_since(self.last_activity) + pub fn connected_elapsed(&self) -> Duration { + self.connected_at.elapsed() } } -/// Rolling window FPS calculator #[derive(Debug, Clone)] pub struct FpsCalculator { - /// Frame timestamps in last window frame_times: VecDeque, - /// Window duration (default 1 second) window: Duration, - /// Cached count of frames in current window (optimization to avoid O(n) filtering) count_in_window: usize, } impl FpsCalculator { - /// Create a new FPS calculator with 1-second window pub fn new() -> Self { Self { - frame_times: VecDeque::with_capacity(120), // Max 120fps tracking + frame_times: VecDeque::with_capacity(120), window: Duration::from_secs(1), count_in_window: 0, } } - /// Record a frame sent pub fn record_frame(&mut self) { let now = Instant::now(); self.frame_times.push_back(now); - // Remove frames outside window and maintain count let cutoff = now - self.window; while let Some(&oldest) = self.frame_times.front() { if oldest < cutoff { @@ -97,31 +70,18 @@ impl FpsCalculator { } } - // Update cached count self.count_in_window = self.frame_times.len(); } - /// Calculate current FPS (frames in last 1 second window) pub fn current_fps(&self) -> u32 { - // Return cached count instead of filtering entire deque (O(1) instead of O(n)) self.count_in_window as u32 } } -impl Default for FpsCalculator { - fn default() -> Self { - Self::new() - } -} - -/// Auto-pause configuration #[derive(Debug, Clone)] pub struct AutoPauseConfig { - /// Enable auto-pause when no clients pub enabled: bool, - /// Delay before pausing (default 10s) pub shutdown_delay_secs: u64, - /// Client timeout for cleanup (default 30s) pub client_timeout_secs: u64, } @@ -135,43 +95,27 @@ impl Default for AutoPauseConfig { } } -/// MJPEG stream handler -/// Manages video frame distribution to HTTP clients pub struct MjpegStreamHandler { - /// Current frame (latest) - using ArcSwap for lock-free reads current_frame: ArcSwap>, - /// Frame update notification frame_notify: broadcast::Sender<()>, - /// Whether stream is online online: AtomicBool, - /// Frame sequence counter sequence: AtomicU64, - /// Per-client sessions (ClientId -> ClientSession) - /// Use parking_lot::RwLock for better performance clients: ParkingRwLock>, - /// Auto-pause configuration auto_pause_config: ParkingRwLock, - /// Last frame timestamp last_frame_ts: ParkingRwLock>, - /// Dropped same frames count dropped_same_frames: AtomicU64, - /// Maximum consecutive same frames to drop (0 = disabled) max_drop_same_frames: AtomicU64, - /// JPEG encoder for non-JPEG input formats jpeg_encoder: ParkingMutex>, - /// JPEG quality for software encoding (1-100) jpeg_quality: AtomicU64, } impl MjpegStreamHandler { - /// Create a new MJPEG stream handler pub fn new() -> Self { - Self::with_drop_limit(100) // Default: drop up to 100 same frames + Self::with_drop_limit(100) } - /// Create handler with custom drop limit pub fn with_drop_limit(max_drop: u64) -> Self { - let (frame_notify, _) = broadcast::channel(16); // Buffer size 16 for low latency + let (frame_notify, _) = broadcast::channel(16); Self { current_frame: ArcSwap::from_pointee(None), frame_notify, @@ -187,16 +131,12 @@ impl MjpegStreamHandler { } } - /// Set JPEG quality for software encoding (1-100) pub fn set_jpeg_quality(&self, quality: u8) { let clamped = quality.clamp(1, 100) as u64; self.jpeg_quality.store(clamped, Ordering::Relaxed); } - /// Update current frame pub fn update_frame(&self, frame: VideoFrame) { - // Fast path: if no MJPEG clients are connected, do minimal bookkeeping and avoid - // expensive work (JPEG encoding and per-frame dedup hashing). let has_clients = !self.clients.read().is_empty(); if !has_clients { self.dropped_same_frames.store(0, Ordering::Relaxed); @@ -204,8 +144,6 @@ impl MjpegStreamHandler { self.online.store(frame.online, Ordering::SeqCst); *self.last_frame_ts.write() = Some(Instant::now()); - // Keep the latest compressed frame for "instant first frame" when a client connects. - // Avoid retaining large raw buffers when there are no MJPEG clients. if frame.format.is_compressed() { self.current_frame.store(Arc::new(Some(frame))); } else { @@ -214,7 +152,6 @@ impl MjpegStreamHandler { return; } - // If frame is not JPEG, encode it let frame = if !frame.format.is_compressed() { match self.encode_to_jpeg(&frame) { Ok(jpeg_frame) => jpeg_frame, @@ -227,17 +164,13 @@ impl MjpegStreamHandler { frame }; - // Frame deduplication (ustreamer-style) - // Check if this frame is identical to the previous one let max_drop = self.max_drop_same_frames.load(Ordering::Relaxed); if max_drop > 0 && frame.online { let current = self.current_frame.load(); if let Some(ref prev_frame) = **current { let dropped_count = self.dropped_same_frames.load(Ordering::Relaxed); - // Check if we should drop this frame if dropped_count < max_drop && frames_are_identical(prev_frame, &frame) { - // Check last frame timestamp to ensure minimum 1fps let last_ts = *self.last_frame_ts.read(); let should_force_send = if let Some(ts) = last_ts { ts.elapsed() >= Duration::from_secs(1) @@ -246,16 +179,13 @@ impl MjpegStreamHandler { }; if !should_force_send { - // Drop this duplicate frame self.dropped_same_frames.fetch_add(1, Ordering::Relaxed); return; } - // If more than 1 second since last frame, force send even if identical } } } - // Frame is different or limit reached or forced by 1fps guarantee, update self.dropped_same_frames.store(0, Ordering::Relaxed); self.sequence.fetch_add(1, Ordering::Relaxed); @@ -263,17 +193,14 @@ impl MjpegStreamHandler { *self.last_frame_ts.write() = Some(Instant::now()); self.current_frame.store(Arc::new(Some(frame))); - // Notify waiting clients let _ = self.frame_notify.send(()); } - /// Encode a non-JPEG frame to JPEG fn encode_to_jpeg(&self, frame: &VideoFrame) -> Result { let resolution = frame.resolution; let sequence = self.sequence.load(Ordering::Relaxed); let desired_quality = self.jpeg_quality.load(Ordering::Relaxed) as u32; - // Get or create encoder let mut encoder_guard = self.jpeg_encoder.lock(); let encoder = encoder_guard.get_or_insert_with(|| { let config = EncoderConfig::jpeg(resolution, desired_quality); @@ -286,15 +213,12 @@ impl MjpegStreamHandler { enc } Err(e) => { - warn!("Failed to create JPEG encoder: {}, using default", e); - // Try with default config - JpegEncoder::new(EncoderConfig::jpeg(resolution, desired_quality)) - .expect("Failed to create default JPEG encoder") + warn!("Failed to create JPEG encoder: {}", e); + panic!("Failed to create JPEG encoder"); } } }); - // Check if resolution changed if encoder.config().resolution != resolution { debug!( "Resolution changed, recreating JPEG encoder: {}x{}", @@ -312,7 +236,6 @@ impl MjpegStreamHandler { } } - // Encode based on input format let encoded = match frame.format { PixelFormat::Yuyv => encoder .encode_yuyv(frame.data(), sequence) @@ -343,38 +266,32 @@ impl MjpegStreamHandler { } }; - // Create new VideoFrame with JPEG data (zero-copy: Bytes -> Arc) Ok(VideoFrame::new( encoded.data, resolution, PixelFormat::Mjpeg, - 0, // stride not relevant for JPEG + 0, sequence, )) } - /// Marks offline; clients exit their read loop. UI overlay comes from `stream.state_changed`. pub fn set_offline(&self) { self.online.store(false, Ordering::SeqCst); let _ = self.frame_notify.send(()); } - /// Set stream online (called when streaming starts) pub fn set_online(&self) { self.online.store(true, Ordering::SeqCst); } - /// Check if stream is online pub fn is_online(&self) -> bool { self.online.load(Ordering::SeqCst) } - /// Get current client count pub fn client_count(&self) -> u64 { self.clients.read().len() as u64 } - /// Register a new client pub fn register_client(&self, client_id: ClientId) { let session = ClientSession::new(client_id.clone()); self.clients.write().insert(client_id.clone(), session); @@ -385,10 +302,9 @@ impl MjpegStreamHandler { ); } - /// Unregister a client pub fn unregister_client(&self, client_id: &str) { if let Some(session) = self.clients.write().remove(client_id) { - let duration = session.connected_duration(); + let duration = session.connected_elapsed(); let duration_secs = duration.as_secs_f32(); let avg_fps = if duration_secs > 0.1 { session.frames_sent as f32 / duration_secs @@ -402,7 +318,6 @@ impl MjpegStreamHandler { } } - /// Record frame sent to a specific client pub fn record_frame_sent(&self, client_id: &str) { if let Some(session) = self.clients.write().get_mut(client_id) { session.last_activity = Instant::now(); @@ -411,7 +326,6 @@ impl MjpegStreamHandler { } } - /// Get per-client statistics pub fn get_clients_stat(&self) -> HashMap { self.clients .read() @@ -422,43 +336,33 @@ impl MjpegStreamHandler { crate::events::types::ClientStats { id: id.clone(), fps: session.fps_calculator.current_fps(), - connected_secs: session.connected_duration().as_secs(), + connected_secs: session.connected_elapsed().as_secs(), }, ) }) .collect() } - /// Get auto-pause configuration pub fn auto_pause_config(&self) -> AutoPauseConfig { self.auto_pause_config.read().clone() } - /// Update auto-pause configuration pub fn set_auto_pause_config(&self, config: AutoPauseConfig) { - let config_clone = config.clone(); - *self.auto_pause_config.write() = config; info!( "Auto-pause config updated: enabled={}, delay={}s, timeout={}s", - config_clone.enabled, - config_clone.shutdown_delay_secs, - config_clone.client_timeout_secs + config.enabled, config.shutdown_delay_secs, config.client_timeout_secs ); + *self.auto_pause_config.write() = config; } - /// Get current frame (if any) pub fn current_frame(&self) -> Option { (**self.current_frame.load()).clone() } - /// Subscribe to frame updates pub fn subscribe(&self) -> broadcast::Receiver<()> { self.frame_notify.subscribe() } - /// Disconnect all clients (used during config changes) - /// This clears the client list and sets the stream offline, - /// which will cause all active MJPEG streams to terminate. pub fn disconnect_all_clients(&self) { let count = { let mut clients = self.clients.write(); @@ -469,32 +373,21 @@ impl MjpegStreamHandler { if count > 0 { info!("Disconnected all {} MJPEG clients for config change", count); } - // Set offline to signal all streaming tasks to stop self.set_offline(); } } -impl Default for MjpegStreamHandler { - fn default() -> Self { - Self::new() - } -} - -/// RAII guard for client lifecycle management -/// Ensures cleanup even on panic or abrupt disconnection pub struct ClientGuard { client_id: ClientId, handler: Arc, } impl ClientGuard { - /// Create a new client guard pub fn new(client_id: ClientId, handler: Arc) -> Self { handler.register_client(client_id.clone()); Self { client_id, handler } } - /// Get client ID pub fn id(&self) -> &ClientId { &self.client_id } @@ -507,8 +400,6 @@ impl Drop for ClientGuard { } impl MjpegStreamHandler { - /// Start stale client cleanup task - /// Should be called once when handler is created pub fn start_cleanup_task(self: Arc) { let handler = self.clone(); tokio::spawn(async move { @@ -522,7 +413,6 @@ impl MjpegStreamHandler { let now = Instant::now(); let mut stale = Vec::new(); - // Find stale clients { let clients = handler.clients.read(); for (id, session) in clients.iter() { @@ -532,7 +422,6 @@ impl MjpegStreamHandler { } } - // Remove stale clients if !stale.is_empty() { let mut clients = handler.clients.write(); for id in stale { @@ -550,10 +439,7 @@ impl MjpegStreamHandler { } } -/// Compare two frames for equality (hash-based, ustreamer-style) -/// Returns true if frames are identical in geometry and content fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool { - // Quick checks first (geometry) if a.len() != b.len() { return false; } @@ -574,13 +460,10 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool { return false; } - // Avoid hashing the whole frame for obviously different frames by sampling a few - // fixed-size windows first. If all samples match, fall back to the cached hash. let a_data = a.data(); let b_data = b.data(); let len = a_data.len(); - // Small frames: direct compare is cheap. if len <= 256 { return a_data == b_data; } @@ -588,7 +471,6 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool { const SAMPLE: usize = 16; debug_assert!(len == b_data.len()); - // Head + tail. if a_data[..SAMPLE] != b_data[..SAMPLE] { return false; } @@ -596,7 +478,6 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool { return false; } - // Two interior samples (quarter + middle) to catch common "same header/footer" cases. let quarter = len / 4; let quarter_start = quarter.saturating_sub(SAMPLE / 2); if a_data[quarter_start..quarter_start + SAMPLE] @@ -610,8 +491,6 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool { return false; } - // Compare hashes instead of full binary data. - // Hash is computed once and cached in OnceLock for efficiency. a.get_hash() == b.get_hash() } @@ -627,7 +506,6 @@ mod tests { assert!(!handler.is_online()); assert_eq!(handler.client_count(), 0); - // Create a frame let _frame = VideoFrame::new( Bytes::from(vec![0xFF, 0xD8, 0x00, 0x00, 0xFF, 0xD9]), Resolution::VGA, @@ -641,15 +519,12 @@ mod tests { fn test_fps_calculator() { let mut calc = FpsCalculator::new(); - // Initially empty assert_eq!(calc.current_fps(), 0); - // Record some frames calc.record_frame(); calc.record_frame(); calc.record_frame(); - // Should have 3 frames in window assert!(calc.frame_times.len() == 3); } } diff --git a/src/stream/mod.rs b/src/stream/mod.rs index b3237b04..3c5e2561 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -1,11 +1,4 @@ -//! Video streaming module -//! -//! Provides MJPEG streaming and WebSocket handlers for MJPEG mode. -//! -//! # Components -//! -//! - `MjpegStreamHandler` - HTTP multipart MJPEG video streaming -//! - `WsHidHandler` - WebSocket HID input handler +//! MJPEG multipart streaming and WebSocket HID (for MJPEG mode). pub mod mjpeg; pub mod ws_hid; diff --git a/src/stream/ws_hid.rs b/src/stream/ws_hid.rs index 0940e884..3df95d76 100644 --- a/src/stream/ws_hid.rs +++ b/src/stream/ws_hid.rs @@ -1,25 +1,4 @@ -//! WebSocket HID Handler for MJPEG mode -//! -//! This module provides a standalone WebSocket HID handler that can be used -//! independently of the application state. It manages multiple WebSocket -//! connections and forwards HID events to the HID controller. -//! -//! # Protocol -//! -//! Only binary protocol is supported for optimal performance. -//! See `crate::hid::datachannel` for message format details. -//! -//! # Architecture -//! -//! ```text -//! WsHidHandler -//! | -//! +-- clients: HashMap -//! +-- hid_controller: Arc -//! | -//! +-- add_client() -> spawns client handler task -//! +-- remove_client() -//! ``` +//! WebSocket HID for MJPEG mode; binary messages per `crate::hid::datachannel`. use axum::extract::ws::{Message, WebSocket}; use futures::{SinkExt, StreamExt}; @@ -34,51 +13,34 @@ use tracing::{debug, error, info, warn}; use crate::hid::datachannel::{parse_hid_message, HidChannelEvent}; use crate::hid::HidController; -/// Client ID type pub type ClientId = String; -/// WebSocket HID client information #[derive(Debug)] pub struct WsHidClient { - /// Client ID pub id: ClientId, - /// Connection timestamp pub connected_at: Instant, - /// Events processed pub events_processed: AtomicU64, - /// Shutdown signal sender shutdown_tx: mpsc::Sender<()>, } impl WsHidClient { - /// Get events processed count pub fn events_count(&self) -> u64 { self.events_processed.load(Ordering::Relaxed) } - /// Get connection duration in seconds pub fn connected_secs(&self) -> u64 { self.connected_at.elapsed().as_secs() } } -/// WebSocket HID Handler -/// -/// Manages WebSocket connections for HID input in MJPEG mode. -/// Only binary protocol is supported for optimal performance. pub struct WsHidHandler { - /// HID controller reference hid_controller: RwLock>>, - /// Active clients clients: RwLock>>, - /// Running state running: AtomicBool, - /// Total events processed total_events: AtomicU64, } impl WsHidHandler { - /// Create a new WebSocket HID handler pub fn new() -> Arc { Arc::new(Self { hid_controller: RwLock::new(None), @@ -88,50 +50,39 @@ impl WsHidHandler { }) } - /// Set HID controller pub fn set_hid_controller(&self, hid: Arc) { *self.hid_controller.write() = Some(hid); info!("WsHidHandler: HID controller set"); } - /// Get HID controller pub fn hid_controller(&self) -> Option> { self.hid_controller.read().clone() } - /// Check if HID controller is available pub fn is_hid_available(&self) -> bool { self.hid_controller.read().is_some() } - /// Get client count pub fn client_count(&self) -> usize { self.clients.read().len() } - /// Check if running pub fn is_running(&self) -> bool { self.running.load(Ordering::SeqCst) } - /// Stop the handler pub fn stop(&self) { self.running.store(false, Ordering::SeqCst); - // Signal all clients to disconnect let clients = self.clients.read(); for client in clients.values() { let _ = client.shutdown_tx.try_send(()); } } - /// Get total events processed pub fn total_events(&self) -> u64 { self.total_events.load(Ordering::Relaxed) } - /// Add a new WebSocket client - /// - /// This spawns a background task to handle the WebSocket connection. pub async fn add_client(self: &Arc, client_id: ClientId, socket: WebSocket) { let (shutdown_tx, shutdown_rx) = mpsc::channel(1); @@ -151,7 +102,6 @@ impl WsHidHandler { self.client_count() ); - // Spawn handler task let handler = self.clone(); tokio::spawn(async move { handler @@ -161,7 +111,6 @@ impl WsHidHandler { }); } - /// Remove a client pub fn remove_client(&self, client_id: &str) { if let Some(client) = self.clients.write().remove(client_id) { info!( @@ -173,7 +122,6 @@ impl WsHidHandler { } } - /// Handle a WebSocket client connection async fn handle_client( &self, client_id: ClientId, @@ -183,7 +131,6 @@ impl WsHidHandler { ) { let (mut sender, mut receiver) = socket.split(); - // Send initial status as binary: 0x00 = ok, 0x01 = error let status_byte = if self.is_hid_available() { 0x00u8 } else { @@ -222,7 +169,6 @@ impl WsHidHandler { debug!("WsHidHandler: Client {} stream ended", client_id); break; } - // Ignore text messages - binary protocol only Some(Ok(Message::Text(_))) => { warn!("WsHidHandler: Ignoring text message from client {} (binary protocol only)", client_id); } @@ -232,7 +178,6 @@ impl WsHidHandler { } } - // Reset HID state when client disconnects to release any held keys/buttons let hid = self.hid_controller.read().clone(); if let Some(hid) = hid { if let Err(e) = hid.reset().await { @@ -246,7 +191,6 @@ impl WsHidHandler { } } - /// Handle binary HID message async fn handle_binary_message(&self, data: &[u8], client: &WsHidClient) -> Result<(), String> { let hid = self .hid_controller @@ -279,17 +223,6 @@ impl WsHidHandler { } } -impl Default for WsHidHandler { - fn default() -> Self { - Self { - hid_controller: RwLock::new(None), - clients: RwLock::new(HashMap::new()), - running: AtomicBool::new(true), - total_events: AtomicU64::new(0), - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/stream_encoder.rs b/src/stream_encoder.rs new file mode 100644 index 00000000..49ba45d1 --- /dev/null +++ b/src/stream_encoder.rs @@ -0,0 +1,18 @@ +//! `EncoderType` → `EncoderBackend` (breaks config ↔ video import cycles). + +use crate::config::EncoderType; +use crate::video::encoder::EncoderBackend; + +/// `None` means “auto” in WebRTC / pipeline (same as `EncoderType::Auto`). +pub fn encoder_type_to_backend(encoder: EncoderType) -> Option { + match encoder { + EncoderType::Auto => None, + EncoderType::Software => Some(EncoderBackend::Software), + EncoderType::Vaapi => Some(EncoderBackend::Vaapi), + EncoderType::Nvenc => Some(EncoderBackend::Nvenc), + EncoderType::Qsv => Some(EncoderBackend::Qsv), + EncoderType::Amf => Some(EncoderBackend::Amf), + EncoderType::Rkmpp => Some(EncoderBackend::Rkmpp), + EncoderType::V4l2m2m => Some(EncoderBackend::V4l2m2m), + } +} diff --git a/src/update/mod.rs b/src/update/mod.rs index 78084d0a..a56c6471 100644 --- a/src/update/mod.rs +++ b/src/update/mod.rs @@ -142,15 +142,10 @@ impl UpdateService { } pub async fn overview(&self, channel: UpdateChannel) -> Result { - let channels: ChannelsManifest = self.fetch_json("/v1/channels.json").await?; - let releases: ReleasesManifest = self.fetch_json("/v1/releases.json").await?; + let (channels, releases) = self.fetch_manifests().await?; let current_version = parse_version(env!("CARGO_PKG_VERSION"))?; - let latest_version_str = match channel { - UpdateChannel::Stable => channels.stable, - UpdateChannel::Beta => channels.beta, - }; - let latest_version = parse_version(&latest_version_str)?; + let latest_version = parse_version(&channel_head_version(&channels, channel))?; let current_parts = parse_version_parts(¤t_version)?; let latest_parts = parse_version_parts(&latest_version)?; @@ -159,11 +154,7 @@ impl UpdateService { if release.channel != channel { continue; } - let version = match parse_version(&release.version) { - Ok(v) => v, - Err(_) => continue, - }; - let version_parts = match parse_version_parts(&version) { + let version_parts = match parse_version_parts(&release.version) { Ok(parts) => parts, Err(_) => continue, }; @@ -253,16 +244,11 @@ impl UpdateService { ) .await; - let channels: ChannelsManifest = self.fetch_json("/v1/channels.json").await?; - let releases: ReleasesManifest = self.fetch_json("/v1/releases.json").await?; + let (channels, releases) = self.fetch_manifests().await?; let current_version = parse_version(env!("CARGO_PKG_VERSION"))?; let target_version = if let Some(channel) = req.channel { - let version_str = match channel { - UpdateChannel::Stable => channels.stable, - UpdateChannel::Beta => channels.beta, - }; - parse_version(&version_str)? + parse_version(&channel_head_version(&channels, channel))? } else { parse_version(req.target_version.as_deref().unwrap_or_default())? }; @@ -443,6 +429,12 @@ impl UpdateService { Ok(()) } + async fn fetch_manifests(&self) -> Result<(ChannelsManifest, ReleasesManifest)> { + let channels = self.fetch_json("/v1/channels.json").await?; + let releases = self.fetch_json("/v1/releases.json").await?; + Ok((channels, releases)) + } + async fn fetch_json Deserialize<'de>>(&self, path: &str) -> Result { let url = format!("{}{}", self.base_url.trim_end_matches('/'), path); let response = self @@ -494,22 +486,7 @@ impl UpdateService { } fn parse_version(input: &str) -> Result { - let parts: Vec<&str> = input.split('.').collect(); - if parts.len() != 3 { - return Err(AppError::Internal(format!( - "Invalid version {}, expected x.x.x", - input - ))); - } - if parts - .iter() - .any(|p| p.is_empty() || !p.chars().all(|c| c.is_ascii_digit())) - { - return Err(AppError::Internal(format!( - "Invalid version {}, expected numeric x.x.x", - input - ))); - } + parse_version_parts(input)?; Ok(input.to_string()) } @@ -527,16 +504,26 @@ fn parse_version_parts(input: &str) -> Result<[u64; 3]> { input ))); } - let major = parts[0] - .parse::() - .map_err(|e| AppError::Internal(format!("Invalid major version {}: {}", parts[0], e)))?; - let minor = parts[1] - .parse::() - .map_err(|e| AppError::Internal(format!("Invalid minor version {}: {}", parts[1], e)))?; - let patch = parts[2] - .parse::() - .map_err(|e| AppError::Internal(format!("Invalid patch version {}: {}", parts[2], e)))?; - Ok([major, minor, patch]) + let mut out = [0u64; 3]; + for (i, p) in parts.iter().enumerate() { + if p.is_empty() || !p.chars().all(|c| c.is_ascii_digit()) { + return Err(AppError::Internal(format!( + "Invalid version {}, expected numeric x.x.x", + input + ))); + } + out[i] = p + .parse::() + .map_err(|e| AppError::Internal(format!("Invalid version component {}: {}", p, e)))?; + } + Ok(out) +} + +fn channel_head_version(channels: &ChannelsManifest, channel: UpdateChannel) -> String { + match channel { + UpdateChannel::Stable => channels.stable.clone(), + UpdateChannel::Beta => channels.beta.clone(), + } } fn compare_version_parts(a: &[u64; 3], b: &[u64; 3]) -> std::cmp::Ordering { diff --git a/src/utils/fs.rs b/src/utils/fs.rs new file mode 100644 index 00000000..07b3fc40 --- /dev/null +++ b/src/utils/fs.rs @@ -0,0 +1,23 @@ +//! Small filesystem helpers. + +use std::path::Path; + +/// Read a UTF-8 file and trim surrounding whitespace. +pub fn read_trimmed(path: &Path) -> Option { + std::fs::read_to_string(path) + .ok() + .map(|value| value.trim().to_string()) +} + +/// Sorted list of directory entry names (lossy exclusion on non-UTF8). +pub fn list_dir_names(path: &Path) -> Vec { + let mut names = std::fs::read_dir(path) + .ok() + .into_iter() + .flatten() + .flatten() + .filter_map(|entry| entry.file_name().into_string().ok()) + .collect::>(); + names.sort(); + names +} diff --git a/src/utils/host.rs b/src/utils/host.rs new file mode 100644 index 00000000..d1db2786 --- /dev/null +++ b/src/utils/host.rs @@ -0,0 +1,15 @@ +//! Host identity helpers. + +/// Truncated content of `/etc/hostname`. Used where RustDesk peers expect the configured static name. +pub fn hostname_from_etc() -> String { + std::fs::read_to_string("/etc/hostname") + .map(|s| s.trim().to_string()) + .unwrap_or_else(|_| "One-KVM".to_string()) +} + +/// Current kernel hostname (`gethostname`). Used for live device info in the UI. +pub fn hostname_uname() -> String { + nix::unistd::gethostname() + .map(|s| s.to_string_lossy().into_owned()) + .unwrap_or_else(|_| "unknown".to_string()) +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index c31db32d..9f382e54 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,9 +1,11 @@ -//! Utility modules for One-KVM -//! -//! This module contains common utilities used across the codebase. +//! Shared utilities. +pub mod fs; +pub mod host; pub mod net; pub mod throttle; +pub use fs::{list_dir_names, read_trimmed}; +pub use host::{hostname_from_etc, hostname_uname}; pub use net::{bind_tcp_listener, bind_udp_socket}; pub use throttle::LogThrottler; diff --git a/src/utils/throttle.rs b/src/utils/throttle.rs index 68ac2edd..3fc4352e 100644 --- a/src/utils/throttle.rs +++ b/src/utils/throttle.rs @@ -1,44 +1,15 @@ -//! Log throttling utility -//! -//! Provides a mechanism to limit how often the same log message is recorded, -//! preventing log flooding when errors occur repeatedly. +//! Limits repeated identical log lines (e.g. reconnect failures). use std::collections::HashMap; use std::sync::RwLock; use std::time::{Duration, Instant}; -/// Log throttler that limits how often the same message is logged -/// -/// This is useful for preventing log flooding when errors occur repeatedly, -/// such as when a device is disconnected and reconnection attempts fail. -/// -/// # Example -/// -/// ```rust -/// use one_kvm::utils::LogThrottler; -/// use std::time::Duration; -/// -/// let throttler = LogThrottler::new(Duration::from_secs(5)); -/// -/// // First call returns true -/// assert!(throttler.should_log("device_error")); -/// -/// // Subsequent calls within 5 seconds return false -/// assert!(!throttler.should_log("device_error")); -/// ``` pub struct LogThrottler { - /// Map of message key to last log time last_logged: RwLock>, - /// Throttle interval interval: Duration, } impl LogThrottler { - /// Create a new log throttler with the specified interval - /// - /// # Arguments - /// - /// * `interval` - The minimum time between log messages for the same key pub fn new(interval: Duration) -> Self { Self { last_logged: RwLock::new(HashMap::new()), @@ -46,23 +17,14 @@ impl LogThrottler { } } - /// Create a new log throttler with interval specified in seconds pub fn with_secs(secs: u64) -> Self { Self::new(Duration::from_secs(secs)) } - /// Check if a message should be logged (not throttled) - /// - /// Returns `true` if the message should be logged, `false` if it should be throttled. - /// If `true` is returned, the internal timestamp is updated. - /// - /// # Arguments - /// - /// * `key` - A unique identifier for the message type + /// Returns whether to emit the log line; updates the timestamp when `true`. pub fn should_log(&self, key: &str) -> bool { let now = Instant::now(); - // First check with read lock (fast path) { let map = self.last_logged.read().unwrap(); if let Some(last) = map.get(key) { @@ -72,9 +34,7 @@ impl LogThrottler { } } - // Update with write lock let mut map = self.last_logged.write().unwrap(); - // Double-check after acquiring write lock if let Some(last) = map.get(key) { if now.duration_since(*last) < self.interval { return false; @@ -84,32 +44,14 @@ impl LogThrottler { true } - /// Clear throttle state for a specific key - /// - /// This should be called when an error condition recovers, - /// so the next error will be logged immediately. - /// - /// # Arguments - /// - /// * `key` - The key to clear + /// Call when a condition recovers so the next failure logs immediately. pub fn clear(&self, key: &str) { self.last_logged.write().unwrap().remove(key); } - /// Clear all throttle state pub fn clear_all(&self) { self.last_logged.write().unwrap().clear(); } - - /// Get the number of tracked keys - pub fn len(&self) -> usize { - self.last_logged.read().unwrap().len() - } - - /// Check if the throttler is empty - pub fn is_empty(&self) -> bool { - self.last_logged.read().unwrap().is_empty() - } } impl Clone for LogThrottler { @@ -122,23 +64,11 @@ impl Clone for LogThrottler { } impl Default for LogThrottler { - /// Create a default log throttler with 5 second interval fn default() -> Self { Self::with_secs(5) } } -/// Macro for throttled warning logging -/// -/// # Example -/// -/// ```rust -/// use one_kvm::utils::LogThrottler; -/// use one_kvm::warn_throttled; -/// -/// let throttler = LogThrottler::default(); -/// warn_throttled!(throttler, "my_error", "Error occurred: {}", "details"); -/// ``` #[macro_export] macro_rules! warn_throttled { ($throttler:expr, $key:expr, $($arg:tt)*) => { @@ -148,7 +78,6 @@ macro_rules! warn_throttled { }; } -/// Macro for throttled error logging #[macro_export] macro_rules! error_throttled { ($throttler:expr, $key:expr, $($arg:tt)*) => { @@ -158,16 +87,6 @@ macro_rules! error_throttled { }; } -/// Macro for throttled info logging -#[macro_export] -macro_rules! info_throttled { - ($throttler:expr, $key:expr, $($arg:tt)*) => { - if $throttler.should_log($key) { - tracing::info!($($arg)*); - } - }; -} - #[cfg(test)] mod tests { use super::*; @@ -183,16 +102,11 @@ mod tests { fn test_throttling() { let throttler = LogThrottler::new(Duration::from_millis(100)); - // First call should succeed assert!(throttler.should_log("test_key")); - - // Immediate second call should be throttled assert!(!throttler.should_log("test_key")); - // Wait for throttle to expire thread::sleep(Duration::from_millis(150)); - // Should succeed again assert!(throttler.should_log("test_key")); } @@ -200,7 +114,6 @@ mod tests { fn test_different_keys() { let throttler = LogThrottler::with_secs(10); - // Different keys should be independent assert!(throttler.should_log("key1")); assert!(throttler.should_log("key2")); assert!(!throttler.should_log("key1")); @@ -214,10 +127,8 @@ mod tests { assert!(throttler.should_log("test_key")); assert!(!throttler.should_log("test_key")); - // Clear the key throttler.clear("test_key"); - // Should be able to log again assert!(throttler.should_log("test_key")); } @@ -239,19 +150,4 @@ mod tests { let throttler = LogThrottler::default(); assert!(throttler.should_log("test")); } - - #[test] - fn test_len_and_is_empty() { - let throttler = LogThrottler::with_secs(10); - - assert!(throttler.is_empty()); - assert_eq!(throttler.len(), 0); - - throttler.should_log("key1"); - assert!(!throttler.is_empty()); - assert_eq!(throttler.len(), 1); - - throttler.should_log("key2"); - assert_eq!(throttler.len(), 2); - } } diff --git a/src/video/capture_limits.rs b/src/video/capture_limits.rs new file mode 100644 index 00000000..9cecc31e --- /dev/null +++ b/src/video/capture_limits.rs @@ -0,0 +1,30 @@ +//! Shared tuning for V4L2 MJPEG capture paths (`Streamer` + `SharedVideoPipeline`). + +/// Frames smaller than this are treated as incomplete / noise. +pub(crate) const MIN_CAPTURE_FRAME_SIZE: usize = 128; + +/// After startup, validate JPEG header every N frames to limit CPU use. +pub(crate) const JPEG_VALIDATE_INTERVAL: u64 = 30; + +/// Validate every MJPEG frame for the first N frames (UVC warm-up / bad headers). +pub(crate) const STARTUP_JPEG_VALIDATE_FRAMES: u64 = 3; + +#[inline] +pub(crate) fn should_validate_jpeg_frame(validate_counter: u64) -> bool { + validate_counter <= STARTUP_JPEG_VALIDATE_FRAMES + || validate_counter.is_multiple_of(JPEG_VALIDATE_INTERVAL) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn jpeg_validation_policy_startup_then_interval() { + assert!(should_validate_jpeg_frame(1)); + assert!(should_validate_jpeg_frame(2)); + assert!(should_validate_jpeg_frame(3)); + assert!(!should_validate_jpeg_frame(4)); + assert!(should_validate_jpeg_frame(30)); + } +} diff --git a/src/video/codec_constraints.rs b/src/video/codec_constraints.rs index cb9ff711..41d2cd14 100644 --- a/src/video/codec_constraints.rs +++ b/src/video/codec_constraints.rs @@ -135,7 +135,7 @@ pub async fn enforce_constraints_with_stream_manager( } if current_mode == StreamMode::WebRTC { - let current_codec = stream_manager.webrtc_streamer().current_video_codec().await; + let current_codec = stream_manager.current_video_codec().await; if !constraints.is_webrtc_codec_allowed(current_codec) { let target_codec = constraints.preferred_webrtc_codec(); stream_manager.set_video_codec(target_codec).await?; diff --git a/src/video/csi_bridge.rs b/src/video/csi_bridge.rs index 8888aa62..a7f83c6b 100644 --- a/src/video/csi_bridge.rs +++ b/src/video/csi_bridge.rs @@ -14,9 +14,7 @@ use tracing::{debug, info, warn}; use v4l2r::bindings::{ v4l2_bt_timings, v4l2_dv_timings, V4L2_DV_BT_656_1120, V4L2_DV_FL_HAS_CEA861_VIC, }; -use v4l2r::ioctl::{ - self, Event as V4l2Event, EventType, QueryDvTimingsError, SubscribeEventFlags, -}; +use v4l2r::ioctl::{self, Event as V4l2Event, EventType, QueryDvTimingsError, SubscribeEventFlags}; use v4l2r::nix::errno::Errno; use crate::video::SignalStatus; @@ -143,9 +141,9 @@ pub fn probe_signal(subdev_fd: &impl AsRawFd, kind: CsiBridgeKind) -> ProbeResul Err(QueryDvTimingsError::NoLink) => ProbeResult::NoCable, Err(QueryDvTimingsError::UnstableSignal) => ProbeResult::NoSync, Err(QueryDvTimingsError::IoctlError(Errno::ERANGE)) => ProbeResult::OutOfRange, - Err(QueryDvTimingsError::IoctlError( - Errno::EIO | Errno::EREMOTEIO | Errno::ETIMEDOUT, - )) => ProbeResult::NoSync, + Err(QueryDvTimingsError::IoctlError(Errno::EIO | Errno::EREMOTEIO | Errno::ETIMEDOUT)) => { + ProbeResult::NoSync + } Err(QueryDvTimingsError::Unsupported) | Err(QueryDvTimingsError::IoctlError(_)) => { ProbeResult::NoSignal } @@ -222,14 +220,8 @@ fn classify_timings(timings: v4l2_dv_timings, kind: CsiBridgeKind) -> ProbeResul return ProbeResult::NoSignal; } - let total_h: u64 = (width - + bt.hfrontporch - + bt.hsync - + bt.hbackporch) as u64; - let total_v: u64 = (height - + bt.vfrontporch - + bt.vsync - + bt.vbackporch) as u64; + let total_h: u64 = (width + bt.hfrontporch + bt.hsync + bt.hbackporch) as u64; + let total_v: u64 = (height + bt.vfrontporch + bt.vsync + bt.vbackporch) as u64; let fps = if total_h > 0 && total_v > 0 && pixelclock > 0 { Some(pixelclock as f64 / (total_h as f64 * total_v as f64)) } else { diff --git a/src/video/device.rs b/src/video/device.rs index 94b8957b..3a61f7a0 100644 --- a/src/video/device.rs +++ b/src/video/device.rs @@ -168,16 +168,15 @@ impl VideoDevice { // subdev (the video node returns ENOTTY). Tc358743 and rk_hdmirx // typically expose DV ioctls on the video node itself, but having // the subdev handle for EDID/event subscription doesn't hurt. - let (subdev_path, bridge_kind) = if is_rkcif_driver(&caps.driver) - || is_rk_hdmirx_driver(&caps.driver, &caps.card) - { - match csi_bridge::discover_subdev_for_video(&self.path) { - Some((path, kind)) => (Some(path), Some(format!("{:?}", kind).to_lowercase())), - None => (None, None), - } - } else { - (None, None) - }; + let (subdev_path, bridge_kind) = + if is_rkcif_driver(&caps.driver) || is_rk_hdmirx_driver(&caps.driver, &caps.card) { + match csi_bridge::discover_subdev_for_video(&self.path) { + Some((path, kind)) => (Some(path), Some(format!("{:?}", kind).to_lowercase())), + None => (None, None), + } + } else { + (None, None) + }; // Probe the HDMI source for both signal presence *and* the live // frame-rate. rkcif's `VIDIOC_ENUM_FRAMEINTERVALS` returns a @@ -225,9 +224,7 @@ impl VideoDevice { (false, None) } } - } else if is_rk_hdmirx_driver(&caps.driver, &caps.card) - || is_rkcif_driver(&caps.driver) - { + } else if is_rk_hdmirx_driver(&caps.driver, &caps.card) || is_rkcif_driver(&caps.driver) { let dv = self.current_dv_timings_mode(); debug!( "has_signal via video node {:?} (driver={}): dv_timings={:?}", @@ -247,21 +244,20 @@ impl VideoDevice { (true, None) }; - let mut formats = if is_rk_hdmirx_driver(&caps.driver, &caps.card) - || is_rkcif_driver(&caps.driver) - { - // CSI/HDMI bridge drivers (rk_hdmirx, rkcif) expose multiple pixel - // formats via ENUM_FMT (e.g. rk_hdmirx: BGR3/NV24/NV16/NV12) but - // `ENUM_FRAMESIZES` is fiction for these drivers (rkcif reports a - // degenerate `64x64 StepWise 8/8` that only describes its DMA - // engine, rk_hdmirx returns ENOTTY). The only authoritative - // resolution is whatever the bridge subdev's DV timings report, - // so we treat the HDMI source mode as the single allowed - // resolution for every pixel format. - self.enumerate_bridge_formats(subdev_hdmi_mode)? - } else { - self.enumerate_formats()? - }; + let mut formats = + if is_rk_hdmirx_driver(&caps.driver, &caps.card) || is_rkcif_driver(&caps.driver) { + // CSI/HDMI bridge drivers (rk_hdmirx, rkcif) expose multiple pixel + // formats via ENUM_FMT (e.g. rk_hdmirx: BGR3/NV24/NV16/NV12) but + // `ENUM_FRAMESIZES` is fiction for these drivers (rkcif reports a + // degenerate `64x64 StepWise 8/8` that only describes its DMA + // engine, rk_hdmirx returns ENOTTY). The only authoritative + // resolution is whatever the bridge subdev's DV timings report, + // so we treat the HDMI source mode as the single allowed + // resolution for every pixel format. + self.enumerate_bridge_formats(subdev_hdmi_mode)? + } else { + self.enumerate_formats()? + }; // For CSI/HDMI bridges, the driver-enumerated fps list is fiction // (rkcif: always `1..30`; rk_hdmirx: typically `ENOTTY`). Replace @@ -923,7 +919,11 @@ pub fn enumerate_devices() -> Result> { // The path tiebreaker ensures deterministic ordering when multiple sub-devices // share the same priority (e.g. rkcif nodes), so that /dev/video0 is preferred // over /dev/video10 after deduplication. - devices.sort_by(|a, b| b.priority.cmp(&a.priority).then_with(|| a.path.cmp(&b.path))); + devices.sort_by(|a, b| { + b.priority + .cmp(&a.priority) + .then_with(|| a.path.cmp(&b.path)) + }); // Deduplicate rkcif sub-devices: the driver exposes many /dev/video* nodes // for a single MIPI CSI pipeline. Keep only the highest-priority node per @@ -976,8 +976,11 @@ fn collapse_rkcif_probe_candidates(candidates: &mut Vec) { fn sysfs_uevent_driver(path: &Path) -> Option { let name = path.file_name()?.to_str()?; - let uevent = - read_sysfs_string(&Path::new("/sys/class/video4linux").join(name).join("device/uevent"))?; + let uevent = read_sysfs_string( + &Path::new("/sys/class/video4linux") + .join(name) + .join("device/uevent"), + )?; extract_uevent_value(&uevent, "driver") } @@ -1037,7 +1040,10 @@ fn sysfs_maybe_capture(path: &Path) -> bool { // kernel driver that created them has been unloaded but the device nodes // were never cleaned up. Opening them returns ENODEV; skip the probe. if !sysfs_base.exists() { - debug!("Skipping {:?}: no matching /sys/class/video4linux entry", path); + debug!( + "Skipping {:?}: no matching /sys/class/video4linux entry", + path + ); return false; } @@ -1081,7 +1087,13 @@ fn sysfs_maybe_capture(path: &Path) -> bool { // succeed QUERYCAP but expose only VIDEO_M2M / STATS / PARAMS and get // filtered later — skipping here saves an open() + ioctl() per node. let driver_skip = [ - "rkvenc", "rkvdec", "vepu", "vdpu", "hantro", "mpp_", "rockchip-vpu", + "rkvenc", + "rkvdec", + "vepu", + "vdpu", + "hantro", + "mpp_", + "rockchip-vpu", ]; if let Some(driver) = &driver { if driver_skip.iter().any(|hint| driver.contains(hint)) { diff --git a/src/video/encoder/traits.rs b/src/video/encoder/traits.rs index 4b3b4cc9..aee36620 100644 --- a/src/video/encoder/traits.rs +++ b/src/video/encoder/traits.rs @@ -1,91 +1,13 @@ //! Encoder traits and common types use bytes::Bytes; -use serde::{Deserialize, Serialize}; use std::time::Instant; -use typeshare::typeshare; use crate::error::Result; use crate::video::format::{PixelFormat, Resolution}; -/// Bitrate preset for video encoding -/// -/// Simplifies bitrate configuration by providing three intuitive presets -/// plus a custom option for advanced users. -#[typeshare] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(tag = "type", content = "value")] -#[derive(Default)] -pub enum BitratePreset { - /// Speed priority: 1 Mbps, lowest latency, smaller GOP - /// Best for: slow networks, remote management, low-bandwidth scenarios - Speed, - /// Balanced: 4 Mbps, good quality/latency tradeoff - /// Best for: typical usage, recommended default - #[default] - Balanced, - /// Quality priority: 8 Mbps, best visual quality - /// Best for: local network, high-bandwidth scenarios, detailed work - Quality, - /// Custom bitrate in kbps (for advanced users) - Custom(u32), -} - -impl BitratePreset { - /// Get bitrate value in kbps - pub fn bitrate_kbps(&self) -> u32 { - match self { - Self::Speed => 1000, - Self::Balanced => 4000, - Self::Quality => 8000, - Self::Custom(kbps) => *kbps, - } - } - - /// Get recommended GOP size based on preset - /// - /// Speed preset uses shorter GOP for faster recovery from packet loss. - /// Quality preset uses longer GOP for better compression efficiency. - pub fn gop_size(&self, fps: u32) -> u32 { - match self { - Self::Speed => (fps / 2).max(15), // 0.5 second, minimum 15 frames - Self::Balanced => fps, // 1 second - Self::Quality => fps * 2, // 2 seconds - Self::Custom(_) => fps, // Default 1 second for custom - } - } - - /// Get quality preset name for encoder configuration - pub fn quality_level(&self) -> &'static str { - match self { - Self::Speed => "low", // ultrafast/veryfast preset - Self::Balanced => "medium", // medium preset - Self::Quality => "high", // slower preset, better quality - Self::Custom(_) => "medium", - } - } - - /// Create from kbps value, mapping to nearest preset or Custom - pub fn from_kbps(kbps: u32) -> Self { - match kbps { - 0..=1500 => Self::Speed, - 1501..=6000 => Self::Balanced, - 6001..=10000 => Self::Quality, - _ => Self::Custom(kbps), - } - } -} - -impl std::fmt::Display for BitratePreset { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Speed => write!(f, "Speed (1 Mbps)"), - Self::Balanced => write!(f, "Balanced (4 Mbps)"), - Self::Quality => write!(f, "Quality (8 Mbps)"), - Self::Custom(kbps) => write!(f, "Custom ({} kbps)", kbps), - } - } -} +/// Defined in `config::schema` (typeshare + serde). Re-export for encoder users. +pub use crate::config::BitratePreset; /// Encoder configuration #[derive(Debug, Clone)] diff --git a/src/video/mod.rs b/src/video/mod.rs index 55c5b846..3d76bc7f 100644 --- a/src/video/mod.rs +++ b/src/video/mod.rs @@ -2,6 +2,7 @@ //! //! This module provides V4L2 video capture, encoding, and streaming functionality. +pub(crate) mod capture_limits; pub mod codec_constraints; pub mod convert; pub mod csi_bridge; @@ -13,6 +14,8 @@ pub mod frame; pub mod shared_video_pipeline; pub mod stream_manager; pub mod streamer; +pub mod traits; +pub mod types; pub mod usb_reset; pub mod v4l2r_capture; diff --git a/src/video/shared_video_pipeline.rs b/src/video/shared_video_pipeline.rs index 267b961c..de5417dd 100644 --- a/src/video/shared_video_pipeline.rs +++ b/src/video/shared_video_pipeline.rs @@ -38,25 +38,19 @@ const CAPTURE_TIMEOUT_STOP_THRESHOLD: u32 = 60; const CAPTURE_TIMEOUT_SOFT_RESTART_THRESHOLD: u32 = 3; const CSI_BRIDGE_NOSIGNAL_INTERVAL_MS: u64 = 500; const NOSIGNAL_POLL_MAX: Duration = Duration::from_secs(20); -/// Minimum valid frame size for capture -const MIN_CAPTURE_FRAME_SIZE: usize = 128; -/// Validate every JPEG frame during startup to avoid poisoning HW decoders -/// with incomplete UVC warm-up frames. -const STARTUP_JPEG_VALIDATE_FRAMES: u64 = 3; -/// Validate JPEG header every N frames to reduce overhead -const JPEG_VALIDATE_INTERVAL: u64 = 30; /// Throttle repeated encoding errors to avoid log flooding const ENCODE_ERROR_THROTTLE_SECS: u64 = 5; use crate::error::{AppError, Result}; use crate::utils::LogThrottler; +use crate::video::capture_limits::{should_validate_jpeg_frame, MIN_CAPTURE_FRAME_SIZE}; use crate::video::csi_bridge::{self, ProbeResult}; +use crate::video::device::parse_bridge_kind; use crate::video::encoder::registry::{EncoderBackend, VideoEncoderType}; use crate::video::format::{PixelFormat, Resolution}; use crate::video::frame::{FrameBuffer, FrameBufferPool, VideoFrame}; -use crate::video::device::parse_bridge_kind; -use crate::video::SignalStatus; use crate::video::v4l2r_capture::{is_source_changed_error, BridgeContext, V4l2rCaptureStream}; +use crate::video::SignalStatus; #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] use hwcodec::ffmpeg_hw::last_error_message as ffmpeg_hw_last_error; @@ -250,11 +244,6 @@ fn log_encoding_error( } } -fn should_validate_jpeg_frame(validate_counter: u64) -> bool { - validate_counter <= STARTUP_JPEG_VALIDATE_FRAMES - || validate_counter.is_multiple_of(JPEG_VALIDATE_INTERVAL) -} - /// Pipeline statistics #[derive(Debug, Clone, Default)] pub struct SharedVideoPipelineStats { @@ -283,10 +272,7 @@ pub struct SharedVideoPipeline { last_state_notification: ParkingMutex>, } -fn poll_bridge_subdev_after_no_signal( - bridge_ctx: &BridgeContext, - pipeline: &SharedVideoPipeline, -) { +fn poll_bridge_subdev_after_no_signal(bridge_ctx: &BridgeContext, pipeline: &SharedVideoPipeline) { let Some(subdev_path) = bridge_ctx.subdev_path.as_ref() else { return; }; @@ -313,7 +299,10 @@ fn poll_bridge_subdev_after_no_signal( let fd = match csi_bridge::open_subdev(subdev_path) { Ok(f) => f, Err(e) => { - debug!("No-signal poll: open subdev {:?} failed: {}", subdev_path, e); + debug!( + "No-signal poll: open subdev {:?} failed: {}", + subdev_path, e + ); std::thread::sleep(Duration::from_millis(CSI_BRIDGE_NOSIGNAL_INTERVAL_MS)); continue; } @@ -565,21 +554,20 @@ impl SharedVideoPipeline { subdev_path.clone(), parse_bridge_kind(bridge_kind.as_deref()), ); - let preopened: Option = - match V4l2rCaptureStream::open_with_bridge( - &device_path, - config.resolution, - config.input_format, - config.fps, - buffer_count.max(1), - Duration::from_secs(2), - bridge_ctx_probe, - ) { - Ok(s) => { - let negotiated_res = s.resolution(); - let negotiated_fmt = s.format(); - if negotiated_res != config.resolution || negotiated_fmt != config.input_format { - info!( + let preopened: Option = match V4l2rCaptureStream::open_with_bridge( + &device_path, + config.resolution, + config.input_format, + config.fps, + buffer_count.max(1), + Duration::from_secs(2), + bridge_ctx_probe, + ) { + Ok(s) => { + let negotiated_res = s.resolution(); + let negotiated_fmt = s.format(); + if negotiated_res != config.resolution || negotiated_fmt != config.input_format { + info!( "Negotiated capture {}x{} {:?} (configured {}x{} {:?}) — aligning encoder to source", negotiated_res.width, negotiated_res.height, @@ -588,25 +576,25 @@ impl SharedVideoPipeline { config.resolution.height, config.input_format ); - config.resolution = negotiated_res; - config.input_format = negotiated_fmt; - *self.config.write().await = config.clone(); - } - Some(s) + config.resolution = negotiated_res; + config.input_format = negotiated_fmt; + *self.config.write().await = config.clone(); } - Err(AppError::CaptureNoSignal { kind }) => { - debug!( - "Pre-probe: no signal — encoder uses configured geometry until capture opens" - ); - let status = SignalStatus::from_str(&kind).unwrap_or(SignalStatus::NoSignal); - self.notify_state(PipelineStateNotification::no_signal( - status, - Some(Duration::from_secs(2).as_millis() as u64), - )); - None - } - Err(e) => return Err(e), - }; + Some(s) + } + Err(AppError::CaptureNoSignal { kind }) => { + debug!( + "Pre-probe: no signal — encoder uses configured geometry until capture opens" + ); + let status = SignalStatus::from_str(&kind).unwrap_or(SignalStatus::NoSignal); + self.notify_state(PipelineStateNotification::no_signal( + status, + Some(Duration::from_secs(2).as_millis() as u64), + )); + None + } + Err(e) => return Err(e), + }; let mut encoder_state = build_encoder_state(&config)?; let _ = self.running.send(true); @@ -707,10 +695,8 @@ impl SharedVideoPipeline { let latest_frame = latest_frame.clone(); let frame_seq_tx = frame_seq_tx.clone(); let buffer_pool = buffer_pool.clone(); - let bridge_ctx = BridgeContext::from_parts( - subdev_path, - parse_bridge_kind(bridge_kind.as_deref()), - ); + let bridge_ctx = + BridgeContext::from_parts(subdev_path, parse_bridge_kind(bridge_kind.as_deref())); std::thread::spawn(move || { let mut stream: Option = None; let mut initial_geometry: Option<(Resolution, PixelFormat)> = None; @@ -874,7 +860,8 @@ impl SharedVideoPipeline { // ── No usable stream? Try to (re)open, back off on failure. ── if stream.is_none() { - match open_or_retry(&device_path, &config, buffer_count, bridge_ctx.clone()) { + match open_or_retry(&device_path, &config, buffer_count, bridge_ctx.clone()) + { OpenResult::Opened(new_stream) => { let new_res = new_stream.resolution(); let new_fmt = new_stream.format(); @@ -884,7 +871,8 @@ impl SharedVideoPipeline { // encoder was sized to saved settings — if DV timings now // disagree, we cannot encode until WebRTC resyncs dimensions. if initial_geometry.is_none() - && (new_res != config.resolution || new_fmt != config.input_format) + && (new_res != config.resolution + || new_fmt != config.input_format) { info!( "Deferred capture open is {}x{} {:?} but encoder expects {}x{} {:?} — stopping for dimension resync", @@ -898,7 +886,8 @@ impl SharedVideoPipeline { pipeline.notify_state(PipelineStateNotification::device_busy( "config_changing", )); - *pipeline.pending_sync_geometry.lock() = Some((new_res, new_fmt)); + *pipeline.pending_sync_geometry.lock() = + Some((new_res, new_fmt)); let _ = pipeline.running.send(false); pipeline.running_flag.store(false, Ordering::Release); let _ = frame_seq_tx.send(sequence.wrapping_add(1)); @@ -950,8 +939,7 @@ impl SharedVideoPipeline { ); } OpenResult::NoSignal(status) => { - consecutive_timeouts = - consecutive_timeouts.saturating_add(1); + consecutive_timeouts = consecutive_timeouts.saturating_add(1); if consecutive_timeouts >= CAPTURE_TIMEOUT_STOP_THRESHOLD { warn!( "Capture soft-restart gave up after {} attempts, \ @@ -1092,9 +1080,7 @@ impl SharedVideoPipeline { } } - if consecutive_timeouts - >= CAPTURE_TIMEOUT_SOFT_RESTART_THRESHOLD - { + if consecutive_timeouts >= CAPTURE_TIMEOUT_SOFT_RESTART_THRESHOLD { // Drop the stream so the next loop // iteration re-opens via the DV-timings // probe. This catches source-side @@ -1105,12 +1091,10 @@ impl SharedVideoPipeline { closing stream for soft-restart", consecutive_timeouts ); - pipeline.notify_state( - PipelineStateNotification::no_signal( - SignalStatus::UvcCaptureStall, - Some(Duration::from_secs(2).as_millis() as u64), - ), - ); + pipeline.notify_state(PipelineStateNotification::no_signal( + SignalStatus::UvcCaptureStall, + Some(Duration::from_secs(2).as_millis() as u64), + )); stream = None; continue; } @@ -1551,13 +1535,4 @@ mod tests { let h265 = SharedVideoPipelineConfig::h265(Resolution::HD720, BitratePreset::Speed); assert_eq!(h265.output_codec, VideoEncoderType::H265); } - - #[test] - fn test_startup_jpeg_validation_policy() { - assert!(should_validate_jpeg_frame(1)); - assert!(should_validate_jpeg_frame(2)); - assert!(should_validate_jpeg_frame(3)); - assert!(!should_validate_jpeg_frame(4)); - assert!(should_validate_jpeg_frame(30)); - } } diff --git a/src/video/stream_manager.rs b/src/video/stream_manager.rs index ff0f5ed5..cddad1a4 100644 --- a/src/video/stream_manager.rs +++ b/src/video/stream_manager.rs @@ -39,8 +39,8 @@ use crate::stream::MjpegStreamHandler; use crate::video::codec_constraints::StreamCodecConstraints; use crate::video::format::{PixelFormat, Resolution}; use crate::video::is_csi_hdmi_bridge; -use crate::video::streamer::{Streamer, StreamerStats, StreamerState}; -use crate::webrtc::WebRtcStreamer; +use crate::video::streamer::{Streamer, StreamerState, StreamerStats}; +use crate::video::traits::VideoOutput; /// Video stream manager configuration #[derive(Debug, Clone)] @@ -95,8 +95,8 @@ pub struct VideoStreamManager { mode: RwLock, /// MJPEG streamer (handles video capture and MJPEG distribution) streamer: Arc, - /// WebRTC streamer (unified WebRTC manager with multi-codec support) - webrtc_streamer: Arc, + /// WebRTC output (unified WebRTC manager with multi-codec support) + webrtc_streamer: Arc, /// Event bus for notifications events: RwLock>>, /// Configuration store @@ -111,7 +111,7 @@ impl VideoStreamManager { /// Create a new video stream manager with WebRtcStreamer pub fn with_webrtc_streamer( streamer: Arc, - webrtc_streamer: Arc, + webrtc_streamer: Arc, ) -> Arc { Arc::new(Self { mode: RwLock::new(StreamMode::Mjpeg), @@ -175,11 +175,6 @@ impl VideoStreamManager { self.streamer.clone() } - /// Get the WebRTC streamer (unified interface with multi-codec support) - pub fn webrtc_streamer(&self) -> Arc { - self.webrtc_streamer.clone() - } - /// Get the MJPEG stream handler pub fn mjpeg_handler(&self) -> Arc { self.streamer.mjpeg_handler() @@ -812,6 +807,16 @@ impl VideoStreamManager { self.webrtc_streamer.get_pipeline_config().await } + /// Get current video codec type + pub async fn current_video_codec(&self) -> crate::video::encoder::VideoCodecType { + self.webrtc_streamer.current_video_codec().await + } + + /// Check if hardware encoding is in use + pub async fn is_hardware_encoding(&self) -> bool { + self.webrtc_streamer.is_hardware_encoding().await + } + /// Set video codec for the shared video pipeline /// /// This allows external consumers (like RustDesk) to set the video codec diff --git a/src/video/streamer.rs b/src/video/streamer.rs index 8c78767d..01d83779 100644 --- a/src/video/streamer.rs +++ b/src/video/streamer.rs @@ -12,7 +12,9 @@ use tokio::sync::RwLock; use tracing::{debug, error, info, trace, warn}; use super::csi_bridge; -use super::device::{enumerate_devices, find_best_device, parse_bridge_kind, VideoDevice, VideoDeviceInfo}; +use super::device::{ + enumerate_devices, find_best_device, parse_bridge_kind, VideoDevice, VideoDeviceInfo, +}; use super::format::{PixelFormat, Resolution}; use super::frame::{FrameBuffer, FrameBufferPool, VideoFrame}; use super::is_csi_hdmi_bridge; @@ -20,13 +22,9 @@ use crate::error::{AppError, Result}; use crate::events::{EventBus, SystemEvent}; use crate::stream::MjpegStreamHandler; use crate::utils::LogThrottler; +use crate::video::capture_limits::{should_validate_jpeg_frame, MIN_CAPTURE_FRAME_SIZE}; use crate::video::v4l2r_capture::{is_source_changed_error, BridgeContext, V4l2rCaptureStream}; -/// Minimum valid frame size for capture -const MIN_CAPTURE_FRAME_SIZE: usize = 128; -/// Validate JPEG header every N frames to reduce overhead -const JPEG_VALIDATE_INTERVAL: u64 = 30; - /// Streamer configuration #[derive(Debug, Clone)] pub struct StreamerConfig { @@ -477,11 +475,10 @@ impl Streamer { ); return Ok(preferred); } - let fmt = device - .formats - .first() - .map(|f| f.format) - .ok_or_else(|| AppError::VideoError("No supported formats found".to_string()))?; + let fmt = + device.formats.first().map(|f| f.format).ok_or_else(|| { + AppError::VideoError("No supported formats found".to_string()) + })?; info!( "select_format: CSI bridge with signal, preferred {:?} unavailable, selected {:?} from {:?}", preferred, @@ -916,9 +913,7 @@ impl Streamer { "CSI open probe reports no signal ({:?}), will soft-restart", status ); - set_retry( - backoff_secs(no_signal_restart_count).saturating_mul(1000), - ); + set_retry(backoff_secs(no_signal_restart_count).saturating_mul(1000)); go_offline(); set_state(status.into()); last_error = Some(format!("CaptureNoSignal({})", kind)); @@ -952,8 +947,9 @@ impl Streamer { // restart path. This lets CSI bridges recover on their // own when the source comes back (resolution change, // host reboot, HDMI cable re-plug). - let was_no_signal = - handle.block_on(async { self.state().await }).is_no_signal_like(); + let was_no_signal = handle + .block_on(async { self.state().await }) + .is_no_signal_like(); if !was_no_signal { error!( "Failed to open device {:?}: {}", @@ -965,9 +961,7 @@ impl Streamer { break 'session; } - debug!( - "Open failed in NoSignal-like state, backing off before soft-restart" - ); + debug!("Open failed in NoSignal-like state, backing off before soft-restart"); let wait = backoff_secs(no_signal_restart_count); set_retry(wait.saturating_mul(1000)); std::thread::sleep(Duration::from_secs(wait)); @@ -1040,9 +1034,7 @@ impl Streamer { Err(e) => { if is_source_changed_error(&e) { info!("Capture SOURCE_CHANGE — soft-restart for DV re-probe"); - set_retry( - backoff_secs(no_signal_restart_count).saturating_mul(1000), - ); + set_retry(backoff_secs(no_signal_restart_count).saturating_mul(1000)); go_offline(); set_state(StreamerState::NoSignal); need_soft_restart = true; @@ -1112,15 +1104,13 @@ impl Streamer { if is_transient_signal_error { if os_err == Some(71) { - warn!( - "Capture transient error (EPROTO/-71, often UVC USB): {}", - e - ); - let is_uvc = handle.block_on(async { - self.current_device.read().await.as_ref().is_some_and( - |d| d.driver.eq_ignore_ascii_case("uvcvideo"), - ) - }); + warn!("Capture transient error (EPROTO/-71, often UVC USB): {}", e); + let is_uvc = + handle.block_on(async { + self.current_device.read().await.as_ref().is_some_and(|d| { + d.driver.eq_ignore_ascii_case("uvcvideo") + }) + }); if is_uvc { go_offline(); set_state(StreamerState::UvcUsbError); @@ -1133,9 +1123,7 @@ impl Streamer { e ); } - set_retry( - backoff_secs(no_signal_restart_count).saturating_mul(1000), - ); + set_retry(backoff_secs(no_signal_restart_count).saturating_mul(1000)); go_offline(); set_state(StreamerState::NoSignal); need_soft_restart = true; @@ -1165,7 +1153,7 @@ impl Streamer { validate_counter = validate_counter.wrapping_add(1); if pixel_format.is_compressed() - && validate_counter.is_multiple_of(JPEG_VALIDATE_INTERVAL) + && should_validate_jpeg_frame(validate_counter) && !VideoFrame::is_valid_jpeg_bytes(&owned[..frame_size]) { continue 'capture; @@ -1567,7 +1555,10 @@ fn probe_subdev_signal( let fd = match csi_bridge::open_subdev(subdev_path) { Ok(f) => f, Err(e) => { - debug!("probe_subdev_signal: failed to open {:?}: {}", subdev_path, e); + debug!( + "probe_subdev_signal: failed to open {:?}: {}", + subdev_path, e + ); return Some(crate::video::SignalStatus::NoSignal); } }; @@ -1608,9 +1599,7 @@ fn wait_subdev_for_source_change( let wait = remaining.min(slice); match csi_bridge::wait_source_change(&fd, wait) { Ok(true) => { - info!( - "Subdev SOURCE_CHANGE during no-signal wait, retrying open immediately" - ); + info!("Subdev SOURCE_CHANGE during no-signal wait, retrying open immediately"); return; } Ok(false) => continue, diff --git a/src/video/traits.rs b/src/video/traits.rs new file mode 100644 index 00000000..715c5dc4 --- /dev/null +++ b/src/video/traits.rs @@ -0,0 +1,47 @@ +//! Traits for video output consumers (WebRTC, RTSP, RustDesk, etc.) + +use std::path::PathBuf; +use std::sync::Arc; + +use super::types::{ + BitratePreset, PixelFormat, Resolution, SharedVideoPipeline, SharedVideoPipelineConfig, + SharedVideoPipelineStats, VideoCodecType, +}; +use crate::error::Result; +use crate::events::EventBus; +use crate::hid::HidController; + +/// Trait for video output consumers that receive encoded video frames. +/// +/// Implemented by `WebRtcStreamer`. `VideoStreamManager` depends on this +/// trait instead of the concrete type, breaking the video <-> webrtc +/// circular import. +#[async_trait::async_trait] +pub trait VideoOutput: Send + Sync { + async fn set_event_bus(&self, events: Arc); + async fn update_video_config(&self, resolution: Resolution, format: PixelFormat, fps: u32); + async fn set_capture_device( + &self, + device_path: PathBuf, + jpeg_quality: u8, + subdev_path: Option, + bridge_kind: Option, + v4l2_driver: Option, + ); + async fn current_video_codec(&self) -> VideoCodecType; + async fn is_hardware_encoding(&self) -> bool; + async fn close_all_sessions(&self); + async fn close_all_sessions_and_release_device(&self) -> usize; + async fn session_count(&self) -> usize; + async fn set_hid_controller(&self, hid: Arc); + async fn set_audio_enabled(&self, enabled: bool) -> Result<()>; + async fn is_audio_enabled(&self) -> bool; + async fn reconnect_audio_sources(&self); + async fn ensure_video_pipeline_for_external(&self) -> Result>; + async fn get_pipeline_config(&self) -> Option; + async fn set_video_codec(&self, codec: VideoCodecType) -> Result<()>; + async fn set_bitrate_preset(&self, preset: BitratePreset) -> Result<()>; + async fn request_keyframe(&self) -> Result<()>; + async fn current_video_geometry(&self) -> (Resolution, PixelFormat, u32); + async fn pipeline_stats(&self) -> Option; +} diff --git a/src/video/types.rs b/src/video/types.rs new file mode 100644 index 00000000..c6333b2d --- /dev/null +++ b/src/video/types.rs @@ -0,0 +1,22 @@ +//! Re-exports of shared video types used by other modules (e.g., webrtc) +//! +//! External modules should import from `crate::video::types` instead of +//! reaching into internal submodules directly. + +// From video::format +pub use super::format::{PixelFormat, Resolution}; + +// From video::frame +pub use super::frame::VideoFrame; + +// From video::encoder (codec-level types) +pub use super::encoder::{BitratePreset, VideoCodecType}; + +// From video::encoder::registry +pub use super::encoder::registry::{EncoderBackend, VideoEncoderType}; + +// From video::shared_video_pipeline +pub use super::shared_video_pipeline::{ + EncodedVideoFrame, PipelineStateNotification, SharedVideoPipeline, SharedVideoPipelineConfig, + SharedVideoPipelineStats, +}; diff --git a/src/video/v4l2r_capture.rs b/src/video/v4l2r_capture.rs index e93ceb4a..26630494 100644 --- a/src/video/v4l2r_capture.rs +++ b/src/video/v4l2r_capture.rs @@ -129,9 +129,7 @@ impl V4l2rCaptureStream { subdev_dv_mode = Some(mode); } other => { - let status = other - .as_status() - .unwrap_or(SignalStatus::NoSignal); + let status = other.as_status().unwrap_or(SignalStatus::NoSignal); debug!( "Subdev {:?} reports no signal ({:?}) — refusing STREAMON", subdev_path, status @@ -200,7 +198,10 @@ impl V4l2rCaptureStream { } (Ok(f), _, _) => f, (Err(e), _, _) => { - return Err(AppError::VideoError(format!("Failed to get device format: {}", e))); + return Err(AppError::VideoError(format!( + "Failed to get device format: {}", + e + ))); } }; @@ -446,10 +447,7 @@ impl V4l2rCaptureStream { let mut poll_fds: Vec = Vec::with_capacity(2); poll_fds.push(PollFd::new( self.fd.as_fd(), - PollFlags::POLLIN - | PollFlags::POLLPRI - | PollFlags::POLLERR - | PollFlags::POLLHUP, + PollFlags::POLLIN | PollFlags::POLLPRI | PollFlags::POLLERR | PollFlags::POLLHUP, )); if let Some(subdev_fd) = self.subdev_fd.as_ref() { poll_fds.push(PollFd::new(subdev_fd.as_fd(), PollFlags::POLLPRI)); diff --git a/src/web/audio_ws.rs b/src/web/audio_ws.rs index ce7d6a5d..b19add6c 100644 --- a/src/web/audio_ws.rs +++ b/src/web/audio_ws.rs @@ -31,12 +31,8 @@ use tracing::{debug, info, warn}; use crate::audio::OpusFrame; use crate::state::AppState; -/// Audio packet type identifier const AUDIO_PACKET_TYPE: u8 = 0x02; -/// Audio WebSocket upgrade handler -/// -/// Upgrades HTTP connection to WebSocket for audio streaming. pub async fn audio_ws_handler( ws: WebSocketUpgrade, State(state): State>, @@ -44,16 +40,13 @@ pub async fn audio_ws_handler( ws.on_upgrade(move |socket| handle_audio_socket(socket, state)) } -/// Handle audio WebSocket connection async fn handle_audio_socket(socket: WebSocket, state: Arc) { let (mut sender, mut receiver) = socket.split(); - // Try to get Opus frame subscription - let opus_rx = match state.audio.subscribe_opus_async().await { + let opus_rx = match state.audio.subscribe_opus().await { Some(rx) => rx, None => { warn!("Audio not streaming, rejecting WebSocket connection"); - // Send error message before closing let _ = sender .send(Message::Text( r#"{"error": "Audio not streaming"}"#.to_string().into(), @@ -68,16 +61,13 @@ async fn handle_audio_socket(socket: WebSocket, state: Arc) { info!("Audio WebSocket client connected"); - // Track connection for cleanup let mut closed = false; - // Use interval instead of sleep for more efficient keepalive let mut ping_interval = tokio::time::interval(std::time::Duration::from_secs(30)); ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); loop { tokio::select! { - // Receive Opus frames and send to client opus_result = opus_rx.recv() => { let frame = match opus_result { Some(f) => f, @@ -94,7 +84,6 @@ async fn handle_audio_socket(socket: WebSocket, state: Arc) { } } - // Handle client messages (ping/close) msg = receiver.next() => { match msg { Some(Ok(Message::Close(_))) => { @@ -107,11 +96,8 @@ async fn handle_audio_socket(socket: WebSocket, state: Arc) { break; } } - Some(Ok(Message::Pong(_))) => { - // Pong received, connection is alive - } + Some(Ok(Message::Pong(_))) => {} Some(Ok(Message::Text(text))) => { - // Handle potential control messages debug!("Received text message on audio WS: {}", text); } Some(Err(e)) => { @@ -119,14 +105,12 @@ async fn handle_audio_socket(socket: WebSocket, state: Arc) { break; } None => { - // Connection closed break; } _ => {} } } - // Periodic ping to keep connection alive (using interval) _ = ping_interval.tick() => { if sender.send(Message::Ping(vec![].into())).await.is_err() { warn!("Failed to send ping, disconnecting"); @@ -137,39 +121,24 @@ async fn handle_audio_socket(socket: WebSocket, state: Arc) { } if !closed { - // Try to send close message let _ = sender.send(Message::Close(None)).await; } info!("Audio WebSocket client disconnected"); } -/// Encode Opus frame to binary packet format -/// -/// ## Format -/// -/// | Offset | Size | Description | -/// |--------|------|-------------| -/// | 0 | 1 | Packet type (0x02 for audio) | -/// | 1 | 4 | Timestamp (u32 LE, ms since start) | -/// | 5 | 2 | Duration (u16 LE, ms) | -/// | 7 | 4 | Sequence number (u32 LE) | -/// | 11 | 4 | Data length (u32 LE) | -/// | 15 | N | Opus encoded data | fn encode_audio_packet(frame: &OpusFrame, stream_start: Instant) -> Vec { let timestamp_ms = stream_start.elapsed().as_millis() as u32; let data_len = frame.data.len() as u32; let mut buf = Vec::with_capacity(15 + frame.data.len()); - // Header buf.push(AUDIO_PACKET_TYPE); buf.extend_from_slice(×tamp_ms.to_le_bytes()); buf.extend_from_slice(&(frame.duration_ms as u16).to_le_bytes()); buf.extend_from_slice(&(frame.sequence as u32).to_le_bytes()); buf.extend_from_slice(&data_len.to_le_bytes()); - // Opus data buf.extend_from_slice(&frame.data); buf @@ -186,8 +155,6 @@ mod tests { data: Bytes::from(vec![1, 2, 3, 4, 5]), duration_ms: 20, sequence: 42, - timestamp: Instant::now(), - rtp_timestamp: 0, }; let stream_start = Instant::now(); @@ -195,11 +162,5 @@ mod tests { assert!(encoded.len() >= 15); assert_eq!(encoded[0], AUDIO_PACKET_TYPE); - // decode_audio_packet function was removed, skip decode test - } - - #[test] - fn test_decode_invalid_packet() { - // decode_audio_packet function was removed, skip this test } } diff --git a/src/web/error.rs b/src/web/error.rs new file mode 100644 index 00000000..83fa4063 --- /dev/null +++ b/src/web/error.rs @@ -0,0 +1,31 @@ +use crate::error::AppError; +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde::Serialize; + +#[derive(Serialize)] +pub struct ErrorResponse { + pub success: bool, + pub message: String, +} + +impl IntoResponse for AppError { + fn into_response(self) -> Response { + let body = ErrorResponse { + success: false, + message: self.to_string(), + }; + + tracing::error!( + error_type = std::any::type_name_of_val(&self), + error_message = %body.message, + "Request failed" + ); + + // Always return 200 OK - success/failure is indicated by the success field + (StatusCode::OK, Json(body)).into_response() + } +} diff --git a/src/web/handlers/config/apply.rs b/src/web/handlers/config/apply.rs index 8a8b53eb..2dcde606 100644 --- a/src/web/handlers/config/apply.rs +++ b/src/web/handlers/config/apply.rs @@ -1,13 +1,10 @@ -//! 配置热重载逻辑 -//! -//! 从 handlers.rs 中抽取的配置应用函数,负责将配置变更应用到各个子系统。 - use std::sync::Arc; use crate::config::*; use crate::error::{AppError, Result}; use crate::rtsp::RtspService; use crate::state::AppState; +use crate::stream_encoder::encoder_type_to_backend; use crate::video::codec_constraints::{ enforce_constraints_with_stream_manager, StreamCodecConstraints, }; @@ -32,13 +29,11 @@ async fn reconcile_otg_from_store(state: &Arc) -> Result<()> { .map_err(|e| AppError::Config(format!("OTG reconcile failed: {}", e))) } -/// 应用 Video 配置变更 pub async fn apply_video_config( state: &Arc, old_config: &VideoConfig, new_config: &VideoConfig, ) -> Result<()> { - // 检查配置是否实际变更 if old_config == new_config { tracing::info!("Video config unchanged, skipping reload"); return Ok(()); @@ -74,7 +69,6 @@ pub async fn apply_video_config( Ok(()) } -/// 应用 Stream 配置变更 pub async fn apply_stream_config( state: &Arc, old_config: &StreamConfig, @@ -82,32 +76,24 @@ pub async fn apply_stream_config( ) -> Result<()> { tracing::info!("Applying stream config changes..."); - // 更新编码器后端 if old_config.encoder != new_config.encoder { - let encoder_backend = new_config.encoder.to_backend(); + let encoder_backend = encoder_type_to_backend(new_config.encoder.clone()); tracing::info!( "Updating encoder backend to: {:?} (from config: {:?})", encoder_backend, new_config.encoder ); - state - .stream_manager - .webrtc_streamer() - .update_encoder_backend(encoder_backend) - .await; + state.webrtc.update_encoder_backend(encoder_backend).await; } - // 更新码率 if old_config.bitrate_preset != new_config.bitrate_preset { state .stream_manager - .webrtc_streamer() .set_bitrate_preset(new_config.bitrate_preset) .await .ok(); // Ignore error if no active stream } - // 更新 ICE 配置 (STUN/TURN) let ice_changed = old_config.stun_server != new_config.stun_server || old_config.turn_server != new_config.turn_server || old_config.turn_username != new_config.turn_username @@ -120,8 +106,7 @@ pub async fn apply_stream_config( new_config.turn_server ); state - .stream_manager - .webrtc_streamer() + .webrtc .update_ice_config( new_config.stun_server.clone(), new_config.turn_server.clone(), @@ -139,7 +124,6 @@ pub async fn apply_stream_config( Ok(()) } -/// 应用 HID 配置变更 pub async fn apply_hid_config( state: &Arc, old_config: &HidConfig, @@ -202,7 +186,6 @@ pub async fn apply_hid_config( Ok(()) } -/// 应用 MSD 配置变更 pub async fn apply_msd_config( state: &Arc, old_config: &MsdConfig, @@ -218,7 +201,6 @@ pub async fn apply_msd_config( tracing::debug!("Old MSD config: {:?}", old_config); tracing::debug!("New MSD config: {:?}", new_config); - // Check if MSD enabled state changed let old_msd_enabled = old_config.enabled; let new_msd_enabled = new_config.enabled; let msd_dir_changed = old_config.msd_dir != new_config.msd_dir; @@ -232,7 +214,6 @@ pub async fn apply_msd_config( tracing::info!("MSD directory changed: {}", new_config.msd_dir); } - // Ensure MSD directories exist (msd/images, msd/ventoy) let msd_dir = new_config.msd_dir_path(); if let Err(e) = std::fs::create_dir_all(msd_dir.join("images")) { tracing::warn!("Failed to create MSD images directory: {}", e); @@ -255,7 +236,6 @@ pub async fn apply_msd_config( reconcile_otg_from_store(state).await?; - // Shutdown existing controller if present let mut msd_guard = state.msd.write().await; if let Some(msd) = msd_guard.as_mut() { if let Err(e) = msd.shutdown().await { @@ -271,15 +251,12 @@ pub async fn apply_msd_config( .await .map_err(|e| AppError::Config(format!("MSD initialization failed: {}", e)))?; - // Set event bus let events = state.events.clone(); msd.set_event_bus(events).await; - // Store the initialized controller *state.msd.write().await = Some(msd); tracing::info!("MSD initialized successfully"); } else { - // MSD disabled - shutdown tracing::info!("MSD disabled in config, shutting down..."); let mut msd_guard = state.msd.write().await; @@ -306,7 +283,6 @@ pub async fn apply_msd_config( Ok(()) } -/// 应用 ATX 配置变更 pub async fn apply_atx_config( state: &Arc, _old_config: &AtxConfig, @@ -314,10 +290,8 @@ pub async fn apply_atx_config( ) -> Result<()> { tracing::info!("Applying ATX config changes..."); - // Convert AtxConfig to AtxControllerConfig let controller_config = new_config.to_controller_config(); - // Reload the ATX controller with new configuration let atx_guard = state.atx.read().await; if let Some(atx) = atx_guard.as_ref() { if let Err(e) = atx.reload(controller_config).await { @@ -326,7 +300,6 @@ pub async fn apply_atx_config( } tracing::info!("ATX controller reloaded successfully"); } else { - // ATX controller not initialized, create a new one if enabled drop(atx_guard); if new_config.enabled { @@ -345,7 +318,6 @@ pub async fn apply_atx_config( Ok(()) } -/// 应用 Audio 配置变更 pub async fn apply_audio_config( state: &Arc, _old_config: &AudioConfig, @@ -353,17 +325,14 @@ pub async fn apply_audio_config( ) -> Result<()> { tracing::info!("Applying audio config changes..."); - // Create audio controller config from new config let audio_config = crate::audio::AudioControllerConfig { enabled: new_config.enabled, device: new_config.device.clone(), - quality: crate::audio::AudioQuality::from_str(&new_config.quality), + quality: new_config.quality.parse::()?, }; - // Update audio controller if let Err(e) = state.audio.update_config(audio_config).await { tracing::error!("Audio config update failed: {}", e); - // Don't fail - audio errors are not critical } else { tracing::info!( "Audio config applied: enabled={}, device={}", @@ -372,7 +341,6 @@ pub async fn apply_audio_config( ); } - // Also update WebRTC audio enabled state if let Err(e) = state .stream_manager .set_webrtc_audio_enabled(new_config.enabled) @@ -383,7 +351,6 @@ pub async fn apply_audio_config( tracing::info!("WebRTC audio enabled: {}", new_config.enabled); } - // Reconnect audio sources for existing WebRTC sessions if new_config.enabled { state.stream_manager.reconnect_webrtc_audio_sources().await; } @@ -391,7 +358,6 @@ pub async fn apply_audio_config( Ok(()) } -/// Apply stream codec constraints derived from global config. pub async fn enforce_stream_codec_constraints(state: &Arc) -> Result> { let config = state.config.get(); let constraints = StreamCodecConstraints::from_config(&config); @@ -400,7 +366,6 @@ pub async fn enforce_stream_codec_constraints(state: &Arc) -> Result, old_config: &crate::rustdesk::config::RustDeskConfig, @@ -411,9 +376,7 @@ pub async fn apply_rustdesk_config( let mut rustdesk_guard = state.rustdesk.write().await; let mut credentials_to_save = None; - // Check if service needs to be stopped if old_config.enabled && !new_config.enabled { - // Disable service if let Some(ref service) = *rustdesk_guard { if let Err(e) = service.stop().await { tracing::error!("Failed to stop RustDesk service: {}", e); @@ -423,14 +386,12 @@ pub async fn apply_rustdesk_config( *rustdesk_guard = None; } - // Check if service needs to be started or restarted if new_config.enabled { let need_restart = old_config.rendezvous_server != new_config.rendezvous_server || old_config.device_id != new_config.device_id || old_config.device_password != new_config.device_password; if rustdesk_guard.is_none() { - // Create new service tracing::info!("Initializing RustDesk service..."); let service = crate::rustdesk::RustDeskService::new( new_config.clone(), @@ -442,12 +403,10 @@ pub async fn apply_rustdesk_config( tracing::error!("Failed to start RustDesk service: {}", e); } else { tracing::info!("RustDesk service started with ID: {}", new_config.device_id); - // Save generated keypair and UUID to config credentials_to_save = service.save_credentials(); } *rustdesk_guard = Some(std::sync::Arc::new(service)); } else if need_restart { - // Restart existing service with new config if let Some(ref service) = *rustdesk_guard { if let Err(e) = service.restart(new_config.clone()).await { tracing::error!("Failed to restart RustDesk service: {}", e); @@ -456,14 +415,12 @@ pub async fn apply_rustdesk_config( "RustDesk service restarted with ID: {}", new_config.device_id ); - // Save generated keypair and UUID to config credentials_to_save = service.save_credentials(); } } } } - // Save credentials to persistent config store (outside the lock) drop(rustdesk_guard); if let Some(updated_config) = credentials_to_save { tracing::info!("Saving RustDesk credentials to config store..."); @@ -491,7 +448,6 @@ pub async fn apply_rustdesk_config( Ok(()) } -/// 应用 RTSP 配置变更 pub async fn apply_rtsp_config( state: &Arc, old_config: &RtspConfig, diff --git a/src/web/handlers/config/atx.rs b/src/web/handlers/config/atx.rs index f1bc7c55..e61c8de4 100644 --- a/src/web/handlers/config/atx.rs +++ b/src/web/handlers/config/atx.rs @@ -1,5 +1,3 @@ -//! ATX configuration handlers - use axum::{extract::State, Json}; use std::sync::Arc; @@ -11,29 +9,23 @@ use crate::state::AppState; use super::apply::apply_atx_config; use super::types::AtxConfigUpdate; -/// Get ATX configuration pub async fn get_atx_config(State(state): State>) -> Json { Json(state.config.get().atx.clone()) } -/// Update ATX configuration pub async fn update_atx_config( State(state): State>, Json(req): Json, ) -> Result> { - // 1. Read current configuration snapshot let current_config = state.config.get(); let old_atx_config = current_config.atx.clone(); - // 2. Validate request, including merged effective serial parameter checks req.validate_with_current(&old_atx_config)?; - // 3. Ensure ATX serial devices do not conflict with HID CH9329 serial device let mut merged_atx_config = old_atx_config.clone(); req.apply_to(&mut merged_atx_config); validate_serial_device_conflict(&merged_atx_config, ¤t_config.hid)?; - // 4. Persist update into config store state .config .update(|config| { @@ -41,10 +33,8 @@ pub async fn update_atx_config( }) .await?; - // 5. Load new config let new_atx_config = state.config.get().atx.clone(); - // 6. Apply to subsystem (hot reload) if let Err(e) = apply_atx_config(&state, &old_atx_config, &new_atx_config).await { tracing::error!("Failed to apply ATX config: {}", e); } diff --git a/src/web/handlers/config/audio.rs b/src/web/handlers/config/audio.rs index 1d1317ae..c2e325f5 100644 --- a/src/web/handlers/config/audio.rs +++ b/src/web/handlers/config/audio.rs @@ -1,5 +1,3 @@ -//! Audio 配置 Handler - use axum::{extract::State, Json}; use std::sync::Arc; @@ -10,23 +8,18 @@ use crate::state::AppState; use super::apply::apply_audio_config; use super::types::AudioConfigUpdate; -/// 获取 Audio 配置 pub async fn get_audio_config(State(state): State>) -> Json { Json(state.config.get().audio.clone()) } -/// 更新 Audio 配置 pub async fn update_audio_config( State(state): State>, Json(req): Json, ) -> Result> { - // 1. 验证请求 req.validate()?; - // 2. 获取旧配置 let old_audio_config = state.config.get().audio.clone(); - // 3. 应用更新到配置存储 state .config .update(|config| { @@ -34,10 +27,8 @@ pub async fn update_audio_config( }) .await?; - // 4. 获取新配置 let new_audio_config = state.config.get().audio.clone(); - // 5. 应用到子系统(热重载) if let Err(e) = apply_audio_config(&state, &old_audio_config, &new_audio_config).await { tracing::error!("Failed to apply audio config: {}", e); } diff --git a/src/web/handlers/config/auth.rs b/src/web/handlers/config/auth.rs index f08879ae..8c7f6d05 100644 --- a/src/web/handlers/config/auth.rs +++ b/src/web/handlers/config/auth.rs @@ -14,7 +14,6 @@ pub async fn get_auth_config(State(state): State>) -> Json>, Json(update): Json, diff --git a/src/web/handlers/config/hid.rs b/src/web/handlers/config/hid.rs index 47b9d7bf..c06e532c 100644 --- a/src/web/handlers/config/hid.rs +++ b/src/web/handlers/config/hid.rs @@ -1,5 +1,3 @@ -//! HID 配置 Handler - use axum::{extract::State, Json}; use std::sync::Arc; @@ -10,23 +8,18 @@ use crate::state::AppState; use super::apply::apply_hid_config; use super::types::HidConfigUpdate; -/// 获取 HID 配置 pub async fn get_hid_config(State(state): State>) -> Json { Json(state.config.get().hid.clone()) } -/// 更新 HID 配置 pub async fn update_hid_config( State(state): State>, Json(req): Json, ) -> Result> { - // 1. 验证请求 req.validate()?; - // 2. 获取旧配置 let old_hid_config = state.config.get().hid.clone(); - // 3. 应用更新到配置存储 state .config .update(|config| { @@ -34,10 +27,8 @@ pub async fn update_hid_config( }) .await?; - // 4. 获取新配置 let new_hid_config = state.config.get().hid.clone(); - // 5. 应用到子系统(热重载) if let Err(e) = apply_hid_config(&state, &old_hid_config, &new_hid_config).await { tracing::error!("Failed to apply HID config: {}", e); } diff --git a/src/web/handlers/config/mod.rs b/src/web/handlers/config/mod.rs index b2a9d872..c2b3e023 100644 --- a/src/web/handlers/config/mod.rs +++ b/src/web/handlers/config/mod.rs @@ -1,21 +1,3 @@ -//! 配置管理 Handler 模块 -//! -//! 提供 RESTful 域分离的配置 API: -//! - GET /api/config/video - 获取视频配置 -//! - PATCH /api/config/video - 更新视频配置 -//! - GET /api/config/stream - 获取流配置 -//! - PATCH /api/config/stream - 更新流配置 -//! - GET /api/config/hid - 获取 HID 配置 -//! - PATCH /api/config/hid - 更新 HID 配置 -//! - GET /api/config/msd - 获取 MSD 配置 -//! - PATCH /api/config/msd - 更新 MSD 配置 -//! - GET /api/config/atx - 获取 ATX 配置 -//! - PATCH /api/config/atx - 更新 ATX 配置 -//! - GET /api/config/audio - 获取音频配置 -//! - PATCH /api/config/audio - 更新音频配置 -//! - GET /api/config/rustdesk - 获取 RustDesk 配置 -//! - PATCH /api/config/rustdesk - 更新 RustDesk 配置 - pub(crate) mod apply; mod types; @@ -30,7 +12,6 @@ mod stream; pub(crate) mod video; mod web; -// 导出 handler 函数 pub use atx::{get_atx_config, update_atx_config}; pub use audio::{get_audio_config, update_audio_config}; pub use auth::{get_auth_config, update_auth_config}; @@ -45,7 +26,6 @@ pub use stream::{get_stream_config, update_stream_config}; pub use video::{get_video_config, update_video_config}; pub use web::{get_web_config, update_web_config}; -// 保留全局配置查询(向后兼容) use axum::{extract::State, Json}; use std::sync::Arc; @@ -53,13 +33,10 @@ use crate::config::AppConfig; use crate::state::AppState; fn sanitize_config_for_api(config: &mut AppConfig) { - // Auth secrets config.auth.totp_secret = None; - // Stream secrets config.stream.turn_password = None; - // RustDesk secrets config.rustdesk.device_password.clear(); config.rustdesk.relay_key = None; config.rustdesk.public_key = None; @@ -67,14 +44,11 @@ fn sanitize_config_for_api(config: &mut AppConfig) { config.rustdesk.signing_public_key = None; config.rustdesk.signing_private_key = None; - // RTSP secrets config.rtsp.password = None; } -/// 获取完整配置 pub async fn get_all_config(State(state): State>) -> Json { let mut config = (*state.config.get()).clone(); - // 不暴露敏感信息 sanitize_config_for_api(&mut config); Json(config) } diff --git a/src/web/handlers/config/msd.rs b/src/web/handlers/config/msd.rs index 64b15898..a8fcaa23 100644 --- a/src/web/handlers/config/msd.rs +++ b/src/web/handlers/config/msd.rs @@ -1,5 +1,3 @@ -//! MSD 配置 Handler - use axum::{extract::State, Json}; use std::sync::Arc; @@ -10,23 +8,18 @@ use crate::state::AppState; use super::apply::apply_msd_config; use super::types::MsdConfigUpdate; -/// 获取 MSD 配置 pub async fn get_msd_config(State(state): State>) -> Json { Json(state.config.get().msd.clone()) } -/// 更新 MSD 配置 pub async fn update_msd_config( State(state): State>, Json(req): Json, ) -> Result> { - // 1. 验证请求 req.validate()?; - // 2. 获取旧配置 let old_msd_config = state.config.get().msd.clone(); - // 3. 应用更新到配置存储 state .config .update(|config| { @@ -34,10 +27,8 @@ pub async fn update_msd_config( }) .await?; - // 4. 获取新配置 let new_msd_config = state.config.get().msd.clone(); - // 5. 应用到子系统(热重载) if let Err(e) = apply_msd_config(&state, &old_msd_config, &new_msd_config).await { tracing::error!("Failed to apply MSD config: {}", e); } diff --git a/src/web/handlers/config/rtsp.rs b/src/web/handlers/config/rtsp.rs index d8dcc846..520e856d 100644 --- a/src/web/handlers/config/rtsp.rs +++ b/src/web/handlers/config/rtsp.rs @@ -7,13 +7,11 @@ use crate::state::AppState; use super::apply::apply_rtsp_config; use super::types::{RtspConfigResponse, RtspConfigUpdate, RtspStatusResponse}; -/// Get RTSP config pub async fn get_rtsp_config(State(state): State>) -> Json { let config = state.config.get(); Json(RtspConfigResponse::from(&config.rtsp)) } -/// Get RTSP status (config + service status) pub async fn get_rtsp_status(State(state): State>) -> Json { let config = state.config.get().rtsp.clone(); let status = { @@ -28,7 +26,6 @@ pub async fn get_rtsp_status(State(state): State>) -> Json>, Json(req): Json, diff --git a/src/web/handlers/config/rustdesk.rs b/src/web/handlers/config/rustdesk.rs index 9e1e0460..12e47aac 100644 --- a/src/web/handlers/config/rustdesk.rs +++ b/src/web/handlers/config/rustdesk.rs @@ -1,5 +1,3 @@ -//! RustDesk 配置 Handler - use axum::{extract::State, Json}; use std::sync::Arc; @@ -10,18 +8,14 @@ use crate::state::AppState; use super::apply::apply_rustdesk_config; use super::types::RustDeskConfigUpdate; -/// RustDesk 配置响应(隐藏敏感信息) #[derive(Debug, serde::Serialize)] pub struct RustDeskConfigResponse { pub enabled: bool, pub rendezvous_server: String, pub relay_server: Option, pub device_id: String, - /// 是否已设置密码 pub has_password: bool, - /// 是否已设置密钥对 pub has_keypair: bool, - /// 是否已设置 relay key pub has_relay_key: bool, } @@ -39,7 +33,6 @@ impl From<&RustDeskConfig> for RustDeskConfigResponse { } } -/// RustDesk 状态响应 #[derive(Debug, serde::Serialize)] pub struct RustDeskStatusResponse { pub config: RustDeskConfigResponse, @@ -47,20 +40,17 @@ pub struct RustDeskStatusResponse { pub rendezvous_status: Option, } -/// 获取 RustDesk 配置 pub async fn get_rustdesk_config( State(state): State>, ) -> Json { Json(RustDeskConfigResponse::from(&state.config.get().rustdesk)) } -/// 获取 RustDesk 完整状态(配置 + 服务状态) pub async fn get_rustdesk_status( State(state): State>, ) -> Json { let config = state.config.get().rustdesk.clone(); - // 获取服务状态 let (service_status, rendezvous_status) = { let guard = state.rustdesk.read().await; if let Some(ref service) = *guard { @@ -79,18 +69,14 @@ pub async fn get_rustdesk_status( }) } -/// 更新 RustDesk 配置 pub async fn update_rustdesk_config( State(state): State>, Json(req): Json, ) -> Result> { - // 1. 验证请求 req.validate()?; - // 2. 获取旧配置 let old_config = state.config.get().rustdesk.clone(); - // 3. 应用更新到配置存储 state .config .update(|config| { @@ -98,15 +84,12 @@ pub async fn update_rustdesk_config( }) .await?; - // 4. 获取新配置 let new_config = state.config.get().rustdesk.clone(); - // 5. 应用到子系统(热重载) if let Err(e) = apply_rustdesk_config(&state, &old_config, &new_config).await { tracing::error!("Failed to apply RustDesk config: {}", e); } - // Share a non-sensitive summary for frontend UX let constraints = state.stream_manager.codec_constraints().await; if constraints.rustdesk_enabled || constraints.rtsp_enabled { tracing::info!( @@ -118,7 +101,6 @@ pub async fn update_rustdesk_config( Ok(Json(RustDeskConfigResponse::from(&new_config))) } -/// 重新生成设备 ID pub async fn regenerate_device_id( State(state): State>, ) -> Result> { @@ -133,7 +115,6 @@ pub async fn regenerate_device_id( Ok(Json(RustDeskConfigResponse::from(&new_config))) } -/// 重新生成设备密码 pub async fn regenerate_device_password( State(state): State>, ) -> Result> { @@ -148,7 +129,6 @@ pub async fn regenerate_device_password( Ok(Json(RustDeskConfigResponse::from(&new_config))) } -/// 获取设备密码(已认证用户) pub async fn get_device_password(State(state): State>) -> Json { let config = state.config.get().rustdesk.clone(); Json(serde_json::json!({ diff --git a/src/web/handlers/config/stream.rs b/src/web/handlers/config/stream.rs index 5f0c3234..4950996b 100644 --- a/src/web/handlers/config/stream.rs +++ b/src/web/handlers/config/stream.rs @@ -1,5 +1,3 @@ -//! Stream 配置 Handler - use axum::{extract::State, Json}; use std::sync::Arc; @@ -9,24 +7,19 @@ use crate::state::AppState; use super::apply::apply_stream_config; use super::types::{StreamConfigResponse, StreamConfigUpdate}; -/// 获取 Stream 配置 pub async fn get_stream_config(State(state): State>) -> Json { let config = state.config.get(); Json(StreamConfigResponse::from(&config.stream)) } -/// 更新 Stream 配置 pub async fn update_stream_config( State(state): State>, Json(req): Json, ) -> Result> { - // 1. 验证请求 req.validate()?; - // 2. 获取旧配置 let old_stream_config = state.config.get().stream.clone(); - // 3. 应用更新到配置存储 state .config .update(|config| { @@ -34,15 +27,12 @@ pub async fn update_stream_config( }) .await?; - // 4. 获取新配置 let new_stream_config = state.config.get().stream.clone(); - // 5. 应用到子系统(热重载) if let Err(e) = apply_stream_config(&state, &old_stream_config, &new_stream_config).await { tracing::error!("Failed to apply stream config: {}", e); } - // 6. Enforce codec constraints after any stream config update if let Err(e) = super::apply::enforce_stream_codec_constraints(&state).await { tracing::error!("Failed to enforce stream codec constraints: {}", e); } diff --git a/src/web/handlers/config/types.rs b/src/web/handlers/config/types.rs index fddff512..76565fe9 100644 --- a/src/web/handlers/config/types.rs +++ b/src/web/handlers/config/types.rs @@ -2,13 +2,11 @@ use crate::config::*; use crate::error::AppError; use crate::rtsp::RtspServiceStatus; use crate::rustdesk::config::RustDeskConfig; -use crate::video::encoder::BitratePreset; use base64::{engine::general_purpose::STANDARD, Engine as _}; use serde::{Deserialize, Serialize}; use std::path::Path; use typeshare::typeshare; -// ===== Auth Config ===== #[typeshare] #[derive(Debug, Deserialize)] pub struct AuthConfigUpdate { @@ -27,7 +25,6 @@ impl AuthConfigUpdate { } } -// ===== Video Config ===== #[typeshare] #[derive(Debug, Deserialize)] pub struct VideoConfigUpdate { @@ -92,8 +89,6 @@ impl VideoConfigUpdate { } } -// ===== Stream Config ===== - /// Stream configuration response (includes has_turn_password) #[typeshare] #[derive(Debug, serde::Serialize)] @@ -212,8 +207,6 @@ impl StreamConfigUpdate { } } -// ===== HID Config ===== - /// OTG USB device descriptor configuration update #[typeshare] #[derive(Debug, Deserialize)] @@ -364,7 +357,6 @@ impl HidConfigUpdate { } } -// ===== MSD Config ===== #[typeshare] #[derive(Debug, Deserialize)] pub struct MsdConfigUpdate { @@ -398,8 +390,6 @@ impl MsdConfigUpdate { } } -// ===== ATX Config ===== - /// Update for a single ATX key configuration #[typeshare] #[derive(Debug, Deserialize)] @@ -626,7 +616,6 @@ impl AtxConfigUpdate { } } -// ===== Audio Config ===== #[typeshare] #[derive(Debug, Deserialize)] pub struct AudioConfigUpdate { @@ -660,8 +649,6 @@ impl AudioConfigUpdate { } } -// ===== RustDesk Config ===== - /// hbbs/hbbr `-k` relay key: standard Base64 encoding of exactly 32 bytes (typically 44 chars with padding). fn validate_rustdesk_relay_key(key: &str) -> Result<(), AppError> { let decoded = STANDARD.decode(key.as_bytes()).map_err(|_| { @@ -758,7 +745,6 @@ impl RustDeskConfigUpdate { } } -// ===== RTSP Config ===== #[typeshare] #[derive(Debug, serde::Serialize)] pub struct RtspConfigResponse { @@ -876,8 +862,6 @@ impl RtspConfigUpdate { } } -// ===== Web Config ===== - /// Web server settings returned by `GET` / `PATCH /api/config/web`. /// /// Public API shape: certificate paths on disk are not exposed. The full stored model is `WebConfig` in `config::schema`. diff --git a/src/web/handlers/config/video.rs b/src/web/handlers/config/video.rs index d63dc8b4..de80837a 100644 --- a/src/web/handlers/config/video.rs +++ b/src/web/handlers/config/video.rs @@ -1,5 +1,3 @@ -//! Video 配置 Handler - use axum::{extract::State, Json}; use std::sync::Arc; @@ -10,23 +8,18 @@ use crate::state::AppState; use super::apply::apply_video_config; use super::types::VideoConfigUpdate; -/// 获取 Video 配置 pub async fn get_video_config(State(state): State>) -> Json { Json(state.config.get().video.clone()) } -/// 更新 Video 配置 pub async fn update_video_config( State(state): State>, Json(req): Json, ) -> Result> { - // 1. 验证请求 req.validate()?; - // 2. 获取旧配置 let old_video_config = state.config.get().video.clone(); - // 3. 应用更新到配置存储 state .config .update(|config| { @@ -34,10 +27,8 @@ pub async fn update_video_config( }) .await?; - // 4. 获取新配置 let new_video_config = state.config.get().video.clone(); - // 5. 应用到子系统(热重载) if let Err(e) = apply_video_config(&state, &old_video_config, &new_video_config).await { tracing::error!("Failed to apply video config: {}", e); // 根据用户选择,仅记录错误,不回滚 diff --git a/src/web/handlers/config/web.rs b/src/web/handlers/config/web.rs index 9b3d9b7a..890f24d4 100644 --- a/src/web/handlers/config/web.rs +++ b/src/web/handlers/config/web.rs @@ -1,5 +1,3 @@ -//! Web 服务器配置 Handler - use axum::{extract::State, Json}; use axum_server::tls_rustls::RustlsConfig; use std::sync::Arc; @@ -9,14 +7,10 @@ use crate::state::AppState; use super::types::{WebConfigResponse, WebConfigUpdate}; -/// 获取 Web 配置 -pub async fn get_web_config( - State(state): State>, -) -> Json { +pub async fn get_web_config(State(state): State>) -> Json { Json(WebConfigResponse::from_stored(&state.config.get().web)) } -/// 更新 Web 配置(支持 PEM 证书上传) pub async fn update_web_config( State(state): State>, Json(req): Json, @@ -27,9 +21,13 @@ pub async fn update_web_config( // Some(Some((cert, key))) = write new cert // Some(None) = clear custom cert // None = no cert change - let cert_path_update: Option> = - if let (Some(cert_pem), Some(key_pem)) = (&req.ssl_cert_pem, &req.ssl_key_pem) { - RustlsConfig::from_pem(cert_pem.as_bytes().to_vec(), key_pem.as_bytes().to_vec()) + let cert_path_update: Option> = if let ( + Some(cert_pem), + Some(key_pem), + ) = + (&req.ssl_cert_pem, &req.ssl_key_pem) + { + RustlsConfig::from_pem(cert_pem.as_bytes().to_vec(), key_pem.as_bytes().to_vec()) .await .map_err(|e| { AppError::BadRequest( @@ -39,30 +37,30 @@ pub async fn update_web_config( .into(), ) })?; - let cert_dir = state.data_dir().join("certs"); - tokio::fs::create_dir_all(&cert_dir) - .await - .map_err(|e| AppError::Internal(format!("Failed to create cert dir: {e}")))?; - let cert_path = cert_dir.join("custom.crt"); - let key_path = cert_dir.join("custom.key"); - tokio::fs::write(&cert_path, cert_pem.as_bytes()) - .await - .map_err(|e| AppError::Internal(format!("Failed to write certificate: {e}")))?; - tokio::fs::write(&key_path, key_pem.as_bytes()) - .await - .map_err(|e| AppError::Internal(format!("Failed to write private key: {e}")))?; - Some(Some(( - cert_path.to_string_lossy().into_owned(), - key_path.to_string_lossy().into_owned(), - ))) - } else if req.clear_custom_cert.unwrap_or(false) { - let cert_dir = state.data_dir().join("certs"); - let _ = tokio::fs::remove_file(cert_dir.join("custom.crt")).await; - let _ = tokio::fs::remove_file(cert_dir.join("custom.key")).await; - Some(None) - } else { - None - }; + let cert_dir = state.data_dir().join("certs"); + tokio::fs::create_dir_all(&cert_dir) + .await + .map_err(|e| AppError::Internal(format!("Failed to create cert dir: {e}")))?; + let cert_path = cert_dir.join("custom.crt"); + let key_path = cert_dir.join("custom.key"); + tokio::fs::write(&cert_path, cert_pem.as_bytes()) + .await + .map_err(|e| AppError::Internal(format!("Failed to write certificate: {e}")))?; + tokio::fs::write(&key_path, key_pem.as_bytes()) + .await + .map_err(|e| AppError::Internal(format!("Failed to write private key: {e}")))?; + Some(Some(( + cert_path.to_string_lossy().into_owned(), + key_path.to_string_lossy().into_owned(), + ))) + } else if req.clear_custom_cert.unwrap_or(false) { + let cert_dir = state.data_dir().join("certs"); + let _ = tokio::fs::remove_file(cert_dir.join("custom.crt")).await; + let _ = tokio::fs::remove_file(cert_dir.join("custom.key")).await; + Some(None) + } else { + None + }; state .config @@ -82,7 +80,9 @@ pub async fn update_web_config( }) .await?; - Ok(Json(WebConfigResponse::from_stored(&state.config.get().web))) + Ok(Json(WebConfigResponse::from_stored( + &state.config.get().web, + ))) } #[cfg(test)] diff --git a/src/web/handlers/devices.rs b/src/web/handlers/devices.rs index af5bdaab..da2d3cb6 100644 --- a/src/web/handlers/devices.rs +++ b/src/web/handlers/devices.rs @@ -1,7 +1,3 @@ -//! Device discovery handlers -//! -//! Provides API endpoints for discovering available hardware devices. - use axum::Json; use serde::Deserialize; @@ -9,17 +5,10 @@ use crate::atx::{discover_devices, AtxDevices}; use crate::error::{AppError, Result}; use crate::video::usb_reset; -/// GET /api/devices/atx - List available ATX devices -/// -/// Returns lists of available GPIO chips and USB HID relay devices. pub async fn list_atx_devices() -> Json { Json(discover_devices()) } -/// GET /api/devices/usb - List all USB devices -/// -/// Enumerates USB devices from `/sys/bus/usb/devices/` with associated -/// video device mappings. pub async fn list_usb_devices() -> Json> { Json(usb_reset::list_usb_devices()) } @@ -30,11 +19,6 @@ pub struct UsbResetRequest { pub dev_num: u32, } -/// POST /api/devices/usb/reset - Reset a USB device via authorized cycle -/// -/// Writes `0` then `1` to the device's `authorized` sysfs attribute, -/// causing the kernel to deauthorize and re-authorize the device. -/// Requires root or write access to sysfs. pub async fn reset_usb_device(Json(req): Json) -> Result> { usb_reset::reset_usb_device(req.bus_num, req.dev_num).map_err(|e| { AppError::VideoError(format!( diff --git a/src/web/handlers/extensions.rs b/src/web/handlers/extensions.rs index 29fd1c87..d6cd4b0e 100644 --- a/src/web/handlers/extensions.rs +++ b/src/web/handlers/extensions.rs @@ -1,5 +1,3 @@ -//! Extension management API handlers - use axum::{ extract::{Path, Query, State}, Json, @@ -15,12 +13,6 @@ use crate::extensions::{ }; use crate::state::AppState; -// ============================================================================ -// Get all extensions status -// ============================================================================ - -/// Get status of all extensions -/// GET /api/extensions pub async fn list_extensions(State(state): State>) -> Json { let config = state.config.get(); let mgr = &state.extensions; @@ -44,12 +36,6 @@ pub async fn list_extensions(State(state): State>) -> Json>, Path(id): Path, @@ -66,12 +52,6 @@ pub async fn get_extension( })) } -// ============================================================================ -// Start/Stop extensions -// ============================================================================ - -/// Start an extension -/// POST /api/extensions/:id/start pub async fn start_extension( State(state): State>, Path(id): Path, @@ -83,20 +63,16 @@ pub async fn start_extension( let config = state.config.get(); let mgr = &state.extensions; - // Start the extension mgr.start(ext_id, &config.extensions) .await .map_err(AppError::Internal)?; - // Return updated status Ok(Json(ExtensionInfo { available: mgr.check_available(ext_id), status: mgr.status(ext_id).await, })) } -/// Stop an extension -/// POST /api/extensions/:id/stop pub async fn stop_extension( State(state): State>, Path(id): Path, @@ -107,29 +83,20 @@ pub async fn stop_extension( let mgr = &state.extensions; - // Stop the extension mgr.stop(ext_id).await.map_err(AppError::Internal)?; - // Return updated status Ok(Json(ExtensionInfo { available: mgr.check_available(ext_id), status: mgr.status(ext_id).await, })) } -// ============================================================================ -// Extension logs -// ============================================================================ - -/// Query parameters for logs #[derive(Deserialize, Default)] pub struct LogsQuery { /// Number of lines to return (default: 100, max: 500) pub lines: Option, } -/// Get extension logs -/// GET /api/extensions/:id/logs pub async fn get_extension_logs( State(state): State>, Path(id): Path, @@ -145,20 +112,13 @@ pub async fn get_extension_logs( Ok(Json(ExtensionLogs { id: ext_id, logs })) } -// ============================================================================ -// Update extension config -// ============================================================================ - -/// Update ttyd config #[typeshare] #[derive(Debug, Deserialize)] pub struct TtydConfigUpdate { pub enabled: Option, - pub port: Option, pub shell: Option, } -/// Update gostc config #[typeshare] #[derive(Debug, Deserialize)] pub struct GostcConfigUpdate { @@ -168,7 +128,6 @@ pub struct GostcConfigUpdate { pub tls: Option, } -/// Update easytier config #[typeshare] #[derive(Debug, Deserialize)] pub struct EasytierConfigUpdate { @@ -179,16 +138,12 @@ pub struct EasytierConfigUpdate { pub virtual_ip: Option, } -/// Update ttyd configuration -/// PATCH /api/extensions/ttyd/config pub async fn update_ttyd_config( State(state): State>, Json(req): Json, ) -> Result> { - // Get current config let was_enabled = state.config.get().extensions.ttyd.enabled; - // Update config state .config .update(|config| { @@ -196,9 +151,6 @@ pub async fn update_ttyd_config( if let Some(enabled) = req.enabled { ttyd.enabled = enabled; } - if let Some(port) = req.port { - ttyd.port = port; - } if let Some(ref shell) = req.shell { ttyd.shell = shell.clone(); } @@ -208,12 +160,9 @@ pub async fn update_ttyd_config( let new_config = state.config.get(); let is_enabled = new_config.extensions.ttyd.enabled; - // Handle enable/disable state change if was_enabled && !is_enabled { - // Was running, now disabled - stop it state.extensions.stop(ExtensionId::Ttyd).await.ok(); } else if !was_enabled && is_enabled { - // Was disabled, now enabled - start it if state.extensions.check_available(ExtensionId::Ttyd) { state .extensions @@ -226,8 +175,6 @@ pub async fn update_ttyd_config( Ok(Json(new_config.extensions.ttyd.clone())) } -/// Update gostc configuration -/// PATCH /api/extensions/gostc/config pub async fn update_gostc_config( State(state): State>, Json(req): Json, @@ -276,8 +223,6 @@ pub async fn update_gostc_config( Ok(Json(new_config.extensions.gostc.clone())) } -/// Update easytier configuration -/// PATCH /api/extensions/easytier/config pub async fn update_easytier_config( State(state): State>, Json(req): Json, diff --git a/src/web/handlers/mod.rs b/src/web/handlers/mod.rs index 0c66ef77..d7840b80 100644 --- a/src/web/handlers/mod.rs +++ b/src/web/handlers/mod.rs @@ -14,16 +14,13 @@ use crate::config::{AppConfig, StreamMode}; use crate::error::{AppError, Result}; use crate::state::AppState; use crate::update::{UpdateChannel, UpdateOverviewResponse, UpdateStatusResponse, UpgradeRequest}; +use crate::utils::{hostname_uname, list_dir_names, read_trimmed}; use crate::video::codec_constraints::codec_to_id; use crate::video::encoder::{ build_hardware_self_check_runtime_error, run_hardware_self_check, BitratePreset, VideoEncoderSelfCheckResponse, }; -// ============================================================================ -// Health & Info -// ============================================================================ - /// Health check response #[derive(Serialize)] pub struct HealthResponse { @@ -169,7 +166,7 @@ fn get_device_info() -> DeviceInfo { let mem_info = get_meminfo(); DeviceInfo { - hostname: get_hostname(), + hostname: hostname_uname(), cpu_model: get_cpu_model(), cpu_usage: get_cpu_usage(), memory_total: mem_info.total, @@ -178,13 +175,6 @@ fn get_device_info() -> DeviceInfo { } } -/// Get system hostname -fn get_hostname() -> String { - nix::unistd::gethostname() - .map(|s| s.to_string_lossy().into_owned()) - .unwrap_or_else(|_| "unknown".to_string()) -} - /// Get CPU model name from /proc/cpuinfo, fallback to device-tree model fn get_cpu_model() -> String { let cpuinfo = std::fs::read_to_string("/proc/cpuinfo").ok(); @@ -451,10 +441,6 @@ mod tests { } } -// ============================================================================ -// Authentication -// ============================================================================ - #[derive(Deserialize)] pub struct LoginRequest { pub username: String, @@ -550,9 +536,9 @@ pub async fn auth_check( axum::Extension(session): axum::Extension, ) -> Json { // Get user info from user_id - let username = match state.users.get(&session.user_id).await { - Ok(Some(user)) => Some(user.username), - _ => Some(session.user_id.clone()), // Fallback to user_id if user not found + let username = match state.users.single_user().await { + Ok(Some(user)) if user.id == session.user_id => Some(user.username), + _ => None, }; Json(AuthCheckResponse { @@ -561,10 +547,6 @@ pub async fn auth_check( }) } -// ============================================================================ -// Setup -// ============================================================================ - #[derive(Serialize)] pub struct SetupStatus { pub initialized: bool, @@ -630,7 +612,10 @@ pub async fn setup_init( } // Create single system user - state.users.create(&req.username, &req.password).await?; + state + .users + .create_first_user(&req.username, &req.password) + .await?; // Update config state @@ -780,7 +765,10 @@ pub async fn setup_init( let audio_config = crate::audio::AudioControllerConfig { enabled: true, device: new_config.audio.device.clone(), - quality: crate::audio::AudioQuality::from_str(&new_config.audio.quality), + quality: new_config + .audio + .quality + .parse::()?, }; if let Err(e) = state.audio.update_config(audio_config).await { tracing::warn!("Failed to start audio during setup: {}", e); @@ -804,10 +792,6 @@ pub async fn setup_init( })) } -// ============================================================================ -// Configuration -// ============================================================================ - #[derive(Deserialize)] pub struct UpdateConfigRequest { #[serde(flatten)] @@ -962,10 +946,6 @@ fn merge_json( } } -// ============================================================================ -// Devices -// ============================================================================ - #[derive(Serialize)] pub struct DeviceList { pub video: Vec, @@ -1165,10 +1145,6 @@ pub async fn list_devices(State(state): State>) -> Json>) -> Json "mjpeg".to_string(), StreamMode::WebRTC => { use crate::video::encoder::VideoCodecType; - let codec = state - .stream_manager - .webrtc_streamer() - .current_video_codec() - .await; + let codec = state.stream_manager.current_video_codec().await; match codec { VideoCodecType::H264 => "h264".to_string(), VideoCodecType::H265 => "h265".to_string(), @@ -1300,11 +1272,7 @@ pub async fn stream_mode_set( // switch_mode_transaction treats this as "no switch needed" since StreamMode // is still WebRTC, so we handle codec change + event emission here. let current_mode = state.stream_manager.current_mode().await; - let prev_codec = state - .stream_manager - .webrtc_streamer() - .current_video_codec() - .await; + let prev_codec = state.stream_manager.current_video_codec().await; let codec_changed = video_codec.is_some_and(|c| c != prev_codec); let is_codec_only_switch = @@ -1312,12 +1280,7 @@ pub async fn stream_mode_set( if let Some(codec) = video_codec { info!("Setting WebRTC video codec to {:?}", codec); - if let Err(e) = state - .stream_manager - .webrtc_streamer() - .set_video_codec(codec) - .await - { + if let Err(e) = state.stream_manager.set_video_codec(codec).await { warn!("Failed to set video codec: {}", e); } } @@ -1349,11 +1312,7 @@ pub async fn stream_mode_set( let active_mode_str = match state.stream_manager.current_mode().await { StreamMode::Mjpeg => "mjpeg".to_string(), StreamMode::WebRTC => { - let codec = state - .stream_manager - .webrtc_streamer() - .current_video_codec() - .await; + let codec = state.stream_manager.current_video_codec().await; match codec { VideoCodecType::H264 => "h264".to_string(), VideoCodecType::H265 => "h265".to_string(), @@ -1452,11 +1411,7 @@ pub async fn stream_constraints_get( let current_mode = match current_mode { StreamMode::Mjpeg => "mjpeg".to_string(), StreamMode::WebRTC => { - let codec = state - .stream_manager - .webrtc_streamer() - .current_video_codec() - .await; + let codec = state.stream_manager.current_video_codec().await; match codec { VideoCodecType::H264 => "h264".to_string(), VideoCodecType::H265 => "h265".to_string(), @@ -1509,7 +1464,6 @@ pub async fn stream_set_bitrate( // Apply to WebRTC streamer (real-time adjustment) if let Err(e) = state .stream_manager - .webrtc_streamer() .set_bitrate_preset(req.bitrate_preset) .await { @@ -1850,10 +1804,6 @@ fn create_mjpeg_part(jpeg_data: &[u8]) -> bytes::Bytes { buf.freeze() } -// ============================================================================ -// WebRTC -// ============================================================================ - use crate::webrtc::signaling::{AnswerResponse, IceCandidateRequest, OfferRequest}; /// Create WebRTC session @@ -1872,11 +1822,7 @@ pub async fn webrtc_create_session( )); } - let session_id = state - .stream_manager - .webrtc_streamer() - .create_session() - .await?; + let session_id = state.webrtc.create_session().await?; Ok(Json(CreateSessionResponse { session_id })) } @@ -1894,7 +1840,7 @@ pub async fn webrtc_offer( // Backward compatibility: `client_id` is treated as an existing session_id hint. // New clients should not pass it; each offer creates a fresh session. - let webrtc = state.stream_manager.webrtc_streamer(); + let webrtc = &state.webrtc; let session_id = if let Some(client_id) = &req.client_id { // Reuse only when it matches an active session ID. if webrtc.get_session(client_id).await.is_some() { @@ -1923,8 +1869,7 @@ pub async fn webrtc_ice_candidate( Json(req): Json, ) -> Result> { state - .stream_manager - .webrtc_streamer() + .webrtc .add_ice_candidate(&req.session_id, req.candidate) .await?; @@ -1948,7 +1893,7 @@ pub struct WebRtcStatus { } pub async fn webrtc_status(State(state): State>) -> Json { - let sessions = state.stream_manager.webrtc_streamer().list_sessions().await; + let sessions = state.webrtc.list_sessions().await; Json(WebRtcStatus { session_count: sessions.len(), sessions: sessions @@ -1971,11 +1916,7 @@ pub async fn webrtc_close_session( State(state): State>, Json(req): Json, ) -> Result> { - state - .stream_manager - .webrtc_streamer() - .close_session(&req.session_id) - .await?; + state.webrtc.close_session(&req.session_id).await?; Ok(Json(LoginResponse { success: true, @@ -2066,10 +2007,6 @@ pub async fn webrtc_ice_servers(State(state): State>) -> Json Vec { - let mut names = std::fs::read_dir(path) - .ok() - .into_iter() - .flatten() - .flatten() - .filter_map(|entry| entry.file_name().into_string().ok()) - .collect::>(); - names.sort(); - names -} - -fn read_trimmed(path: &std::path::Path) -> Option { - std::fs::read_to_string(path) - .ok() - .map(|value| value.trim().to_string()) -} - fn proc_modules_has(module_name: &str) -> bool { std::fs::read_to_string("/proc/modules") .ok() @@ -2870,10 +2789,6 @@ pub async fn hid_reset(State(state): State>) -> Result>) -> Result>) -> Result> { let config = state.config.get(); @@ -3261,10 +3172,6 @@ pub async fn msd_drive_mkdir( })) } -// ============================================================================ -// ATX (Power Control) -// ============================================================================ - use crate::atx::{AtxState, PowerStatus}; const WOL_HISTORY_MAX_ENTRIES: i64 = 50; @@ -3421,7 +3328,7 @@ async fn record_wol_history(state: &Arc, mac_address: &str) -> Result< "#, ) .bind(mac_address) - .execute(state.config.pool()) + .execute(state.db.pool()) .await?; sqlx::query( @@ -3435,7 +3342,7 @@ async fn record_wol_history(state: &Arc, mac_address: &str) -> Result< "#, ) .bind(WOL_HISTORY_MAX_ENTRIES) - .execute(state.config.pool()) + .execute(state.db.pool()) .await?; Ok(()) @@ -3488,7 +3395,7 @@ pub async fn atx_wol_history( "#, ) .bind(limit as i64) - .fetch_all(state.config.pool()) + .fetch_all(state.db.pool()) .await?; let history = rows @@ -3502,10 +3409,6 @@ pub async fn atx_wol_history( Ok(Json(WolHistoryResponse { history })) } -// ============================================================================ -// Audio Control -// ============================================================================ - use crate::audio::{AudioQuality, AudioStatus}; /// Audio status response (re-exports AudioStatus from audio module) @@ -3554,7 +3457,7 @@ pub async fn set_audio_quality( State(state): State>, Json(req): Json, ) -> Result> { - let quality = AudioQuality::from_str(&req.quality); + let quality = req.quality.parse::()?; state.audio.set_quality(quality).await?; Ok(Json(LoginResponse { success: true, @@ -3588,10 +3491,6 @@ pub async fn list_audio_devices( Ok(Json(devices)) } -// ============================================================================ -// Password Management -// ============================================================================ - /// Change password request #[derive(Deserialize)] pub struct ChangePasswordRequest { @@ -3607,10 +3506,14 @@ pub async fn change_password( ) -> Result> { let current_user = state .users - .get(&session.user_id) + .single_user() .await? .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; + if current_user.id != session.user_id { + return Err(AppError::AuthError("Invalid session".to_string())); + } + if req.new_password.len() < 4 { return Err(AppError::BadRequest( "Password must be at least 4 characters".to_string(), @@ -3654,10 +3557,14 @@ pub async fn change_username( ) -> Result> { let current_user = state .users - .get(&session.user_id) + .single_user() .await? .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; + if current_user.id != session.user_id { + return Err(AppError::AuthError("Invalid session".to_string())); + } + if req.username.len() < 2 { return Err(AppError::BadRequest( "Username must be at least 2 characters".to_string(), @@ -3688,10 +3595,6 @@ pub async fn change_username( })) } -// ============================================================================ -// System Control -// ============================================================================ - /// Restart the application pub async fn system_restart(State(state): State>) -> Json { info!("System restart requested via API"); @@ -3738,10 +3641,6 @@ pub async fn system_restart(State(state): State>) -> Json, diff --git a/src/web/handlers/terminal.rs b/src/web/handlers/terminal.rs index c88f05f6..37185286 100644 --- a/src/web/handlers/terminal.rs +++ b/src/web/handlers/terminal.rs @@ -1,5 +1,3 @@ -//! Terminal proxy handler - reverse proxy to ttyd via Unix socket - use axum::{ body::Body, extract::{ @@ -21,7 +19,6 @@ use crate::error::AppError; use crate::extensions::TTYD_SOCKET_PATH; use crate::state::AppState; -/// Handle WebSocket upgrade for terminal pub async fn terminal_ws( State(_state): State>, OriginalUri(original_uri): OriginalUri, @@ -32,14 +29,12 @@ pub async fn terminal_ws( .map(|q| format!("?{}", q)) .unwrap_or_default(); - // Use the tty subprotocol that ttyd expects + // ttyd expects the `tty` WebSocket subprotocol ws.protocols(["tty"]) .on_upgrade(move |socket| handle_terminal_websocket(socket, query_string)) } -/// Handle terminal WebSocket connection - bridge browser and ttyd async fn handle_terminal_websocket(client_ws: WebSocket, query_string: String) { - // Connect to ttyd Unix socket let unix_stream = match UnixStream::connect(TTYD_SOCKET_PATH).await { Ok(s) => s, Err(e) => { @@ -48,7 +43,6 @@ async fn handle_terminal_websocket(client_ws: WebSocket, query_string: String) { } }; - // Build WebSocket request for ttyd with tty subprotocol let uri_str = format!("ws://localhost/api/terminal/ws{}", query_string); let mut request = match uri_str.into_client_request() { Ok(r) => r, @@ -62,7 +56,6 @@ async fn handle_terminal_websocket(client_ws: WebSocket, query_string: String) { .headers_mut() .insert("Sec-WebSocket-Protocol", HeaderValue::from_static("tty")); - // Create WebSocket connection to ttyd let ws_stream = match tokio_tungstenite::client_async(request, unix_stream).await { Ok((ws, _)) => ws, Err(e) => { @@ -71,11 +64,9 @@ async fn handle_terminal_websocket(client_ws: WebSocket, query_string: String) { } }; - // Split both WebSocket connections let (mut client_tx, mut client_rx) = client_ws.split(); let (mut ttyd_tx, mut ttyd_rx) = ws_stream.split(); - // Forward messages from browser to ttyd let client_to_ttyd = tokio::spawn(async move { while let Some(msg) = client_rx.next().await { let ttyd_msg = match msg { @@ -96,7 +87,6 @@ async fn handle_terminal_websocket(client_ws: WebSocket, query_string: String) { } }); - // Forward messages from ttyd to browser let ttyd_to_client = tokio::spawn(async move { while let Some(msg) = ttyd_rx.next().await { let client_msg = match msg { @@ -118,14 +108,12 @@ async fn handle_terminal_websocket(client_ws: WebSocket, query_string: String) { } }); - // Wait for either direction to complete tokio::select! { _ = client_to_ttyd => {} _ = ttyd_to_client => {} } } -/// Proxy HTTP requests to ttyd pub async fn terminal_proxy( State(_state): State>, path: Option>, @@ -133,12 +121,10 @@ pub async fn terminal_proxy( ) -> Result { let path_str = path.map(|p| p.0).unwrap_or_default(); - // Connect to ttyd Unix socket let mut unix_stream = UnixStream::connect(TTYD_SOCKET_PATH) .await .map_err(|e| AppError::ServiceUnavailable(format!("ttyd not running: {}", e)))?; - // Build HTTP request to forward let method = req.method().as_str(); let query = req .uri() @@ -151,7 +137,6 @@ pub async fn terminal_proxy( format!("/api/terminal/{}{}", path_str, query) }; - // Forward relevant headers let mut headers_str = String::new(); for (name, value) in req.headers() { if let Ok(v) = value.to_str() { @@ -170,20 +155,17 @@ pub async fn terminal_proxy( method, uri_path, headers_str ); - // Send request unix_stream .write_all(http_request.as_bytes()) .await .map_err(|e| AppError::Internal(format!("Failed to send request: {}", e)))?; - // Read response let mut response_buf = Vec::new(); unix_stream .read_to_end(&mut response_buf) .await .map_err(|e| AppError::Internal(format!("Failed to read response: {}", e)))?; - // Parse HTTP response let response_str = String::from_utf8_lossy(&response_buf); let header_end = response_str .find("\r\n\r\n") @@ -192,7 +174,6 @@ pub async fn terminal_proxy( let headers_part = &response_str[..header_end]; let body_start = header_end + 4; - // Parse status line let status_line = headers_part .lines() .next() @@ -203,11 +184,9 @@ pub async fn terminal_proxy( .and_then(|s| s.parse().ok()) .unwrap_or(200); - // Build response let mut builder = Response::builder().status(StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK)); - // Forward response headers for line in headers_part.lines().skip(1) { if let Some((name, value)) = line.split_once(':') { let name = name.trim(); @@ -232,7 +211,6 @@ pub async fn terminal_proxy( .map_err(|e| AppError::Internal(format!("Failed to build response: {}", e))) } -/// Terminal index page pub async fn terminal_index( State(state): State>, req: Request, diff --git a/src/web/mod.rs b/src/web/mod.rs index 0bff07e0..d4a90159 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -1,12 +1,13 @@ mod audio_ws; +mod error; mod handlers; mod routes; mod static_files; mod ws; pub use audio_ws::audio_ws_handler; +pub use error::ErrorResponse; pub use routes::create_router; -// StaticAssets is only available in release mode (embedded assets) #[cfg(not(debug_assertions))] pub use static_files::StaticAssets; pub use ws::ws_handler; diff --git a/src/web/routes.rs b/src/web/routes.rs index ec4ffae7..2260d249 100644 --- a/src/web/routes.rs +++ b/src/web/routes.rs @@ -17,7 +17,6 @@ use crate::auth::auth_middleware; use crate::hid::websocket::ws_hid_handler; use crate::state::AppState; -/// Create the main application router pub fn create_router(state: Arc) -> Router { let cors = CorsLayer::new() .allow_origin(Any) diff --git a/src/web/static_files.rs b/src/web/static_files.rs index d4c0e95b..caaddf7b 100644 --- a/src/web/static_files.rs +++ b/src/web/static_files.rs @@ -9,38 +9,30 @@ use std::path::PathBuf; #[cfg(debug_assertions)] use std::sync::OnceLock; -// Only embed assets in release mode #[cfg(not(debug_assertions))] use rust_embed::Embed; #[cfg(not(debug_assertions))] -/// Embedded static assets (frontend files) - only in release mode #[derive(Embed)] #[folder = "web/dist"] #[prefix = ""] pub struct StaticAssets; -/// Get the base directory for static files -/// In debug mode: relative to executable directory -/// In release mode: not used (embedded assets) #[cfg(debug_assertions)] fn get_static_base_dir() -> PathBuf { static BASE_DIR: OnceLock = OnceLock::new(); BASE_DIR .get_or_init(|| { - // Try to get executable directory if let Ok(exe_path) = std::env::current_exe() { if let Some(exe_dir) = exe_path.parent() { return exe_dir.join("web").join("dist"); } } - // Fallback to current directory PathBuf::from("web/dist") }) .clone() } -/// Create router for static file serving pub fn static_file_router() -> Router where S: Clone + Send + Sync + 'static, @@ -50,29 +42,23 @@ where .route("/{*path}", get(static_handler)) } -/// Serve index.html for root path async fn index_handler() -> Response { serve_file("index.html") } -/// Serve static files async fn static_handler(uri: Uri) -> Response { let path = uri.path().trim_start_matches('/'); - // Try to serve the exact file if let Some(response) = try_serve_file(path) { return response; } - // For SPA routing, serve index.html for non-asset paths if !path.contains('.') { if let Some(response) = try_serve_file("index.html") { return response; } } - // If no embedded assets found, return placeholder page - // This happens when web/dist was not built before compilation Response::builder() .status(StatusCode::OK) .header(header::CONTENT_TYPE, "text/html; charset=utf-8") @@ -82,7 +68,6 @@ async fn static_handler(uri: Uri) -> Response { fn serve_file(path: &str) -> Response { try_serve_file(path).unwrap_or_else(|| { - // If index.html not found in embedded assets, return placeholder if path == "index.html" { Response::builder() .status(StatusCode::OK) @@ -101,17 +86,14 @@ fn serve_file(path: &str) -> Response { fn try_serve_file(path: &str) -> Option> { #[cfg(debug_assertions)] { - // Debug mode: read from file system let base_dir = get_static_base_dir(); let file_path = base_dir.join(path); - // Check if file exists and is within base directory (prevent directory traversal) if !file_path.starts_with(&base_dir) { tracing::warn!("Path traversal attempt blocked: {}", path); return None; } - // Normalize path to prevent directory traversal (only if file exists) if let (Ok(normalized_path), Ok(normalized_base)) = (file_path.canonicalize(), base_dir.canonicalize()) { @@ -150,7 +132,6 @@ fn try_serve_file(path: &str) -> Option> { #[cfg(not(debug_assertions))] { - // Release mode: use embedded assets let asset = StaticAssets::get(path)?; let mime = mime_guess::from_path(path) @@ -168,7 +149,6 @@ fn try_serve_file(path: &str) -> Option> { } } -/// Placeholder index.html when frontend is not built pub fn placeholder_html() -> &'static str { r#" diff --git a/src/web/ws.rs b/src/web/ws.rs index a4077d33..3dd8e4b4 100644 --- a/src/web/ws.rs +++ b/src/web/ws.rs @@ -1,11 +1,3 @@ -//! WebSocket handler for real-time event streaming -//! -//! This module provides a WebSocket endpoint at `/api/ws` that: -//! - Broadcasts system events to connected clients -//! - Supports topic-based event filtering -//! - Handles client subscription management -//! - Includes heartbeat (ping/pong) mechanism - use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, @@ -132,48 +124,36 @@ fn rebuild_event_tasks( } } -/// Client-to-server message #[derive(Debug, Deserialize)] #[serde(tag = "type", content = "payload")] enum ClientMessage { - /// Subscribe to event topics #[serde(rename = "subscribe")] Subscribe { topics: Vec }, - /// Unsubscribe from event topics #[serde(rename = "unsubscribe")] Unsubscribe { topics: Vec }, - /// Ping (keep-alive) #[serde(rename = "ping")] Ping, } -/// WebSocket upgrade handler -/// -/// This is the entry point for WebSocket connections at `/api/ws`. -/// Authentication is handled by the middleware. pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { ws.on_upgrade(move |socket| handle_socket(socket, state)) } -/// Handle WebSocket connection async fn handle_socket(socket: WebSocket, state: Arc) { let (mut sender, mut receiver) = socket.split(); let (event_tx, mut event_rx) = mpsc::unbounded_channel(); let mut event_tasks: Vec> = Vec::new(); - // Track subscribed topics (default: none until client subscribes) let mut subscribed_topics: Vec = vec![]; info!("WebSocket client connected"); - // Heartbeat interval (30 seconds) let mut heartbeat_interval = tokio::time::interval(tokio::time::Duration::from_secs(30)); loop { tokio::select! { - // Receive message from client msg = receiver.next() => { match msg { Some(Ok(Message::Text(text))) => { @@ -189,7 +169,6 @@ async fn handle_socket(socket: WebSocket, state: Arc) { } } Some(Ok(Message::Ping(_))) => { - // WebSocket automatically handles ping/pong debug!("Received ping from client"); } Some(Ok(Message::Pong(_))) => { @@ -207,11 +186,9 @@ async fn handle_socket(socket: WebSocket, state: Arc) { } } - // Receive event from event bus event = event_rx.recv() => { match event { Some(BusMessage::Event(event)) => { - // Filter event based on subscribed topics if let Ok(json) = serialize_event(&event) { if sender.send(Message::Text(json.into())).await.is_err() { warn!("Failed to send event to client, disconnecting"); @@ -224,7 +201,6 @@ async fn handle_socket(socket: WebSocket, state: Arc) { "WebSocket client lagged by {} events on topic {}", count, topic ); - // Send error notification to client using SystemEvent::Error let error_event = SystemEvent::Error { message: format!("Lagged by {} events", count), }; @@ -239,7 +215,6 @@ async fn handle_socket(socket: WebSocket, state: Arc) { } } - // Heartbeat _ = heartbeat_interval.tick() => { if sender.send(Message::Ping(vec![].into())).await.is_err() { warn!("Failed to send ping, disconnecting"); @@ -256,7 +231,6 @@ async fn handle_socket(socket: WebSocket, state: Arc) { info!("WebSocket handler exiting"); } -/// Handle message from client async fn handle_client_message( text: &str, topics: &mut Vec, @@ -282,7 +256,6 @@ async fn handle_client_message( Ok(()) } -/// Serialize event to JSON string fn serialize_event(event: &SystemEvent) -> Result { serde_json::to_string(event) } diff --git a/src/webrtc/config.rs b/src/webrtc/config.rs index a5eef254..a6447cfc 100644 --- a/src/webrtc/config.rs +++ b/src/webrtc/config.rs @@ -1,42 +1,28 @@ -//! WebRTC configuration - use serde::{Deserialize, Serialize}; -/// ICE server utilities — public STUN only (TURN must be user-configured). +/// Public STUN from build-time secrets; TURN is user-configured. pub mod public_ice { - /// Whether a build-time public STUN URL exists (always true for stock builds). #[inline] pub fn is_configured() -> bool { true } - /// Build-time public STUN URL (`secrets::ice::STUN_SERVER`). #[inline] pub fn stun_server() -> &'static str { crate::secrets::ice::STUN_SERVER } } -/// WebRTC configuration #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WebRtcConfig { - /// Enable WebRTC pub enabled: bool, - /// STUN server URLs pub stun_servers: Vec, - /// TURN server configuration pub turn_servers: Vec, - /// Enable DataChannel for HID pub enable_datachannel: bool, - /// Video codec preference pub video_codec: VideoCodec, - /// Target bitrate in kbps pub target_bitrate_kbps: u32, - /// Maximum bitrate in kbps pub max_bitrate_kbps: u32, - /// Minimum bitrate in kbps pub min_bitrate_kbps: u32, - /// Enable audio track pub enable_audio: bool, } @@ -44,8 +30,6 @@ impl Default for WebRtcConfig { fn default() -> Self { Self { enabled: true, - // Empty STUN servers for local connections - host candidates work directly - // For remote access, configure STUN/TURN servers via settings stun_servers: vec![], turn_servers: vec![], enable_datachannel: true, @@ -58,20 +42,14 @@ impl Default for WebRtcConfig { } } -/// TURN server configuration #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TurnServer { - /// TURN server URLs (e.g., ["turn:turn.example.com:3478?transport=udp", "turn:turn.example.com:3478?transport=tcp"]) - /// Multiple URLs allow fallback between UDP and TCP transports pub urls: Vec, - /// Username for TURN authentication pub username: String, - /// Credential for TURN authentication pub credential: String, } impl TurnServer { - /// Create a TurnServer with a single URL (for backwards compatibility) pub fn new(url: String, username: String, credential: String) -> Self { Self { urls: vec![url], @@ -81,7 +59,6 @@ impl TurnServer { } } -/// Video codec preference #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] #[derive(Default)] @@ -104,14 +81,10 @@ impl std::fmt::Display for VideoCodec { } } -/// ICE configuration #[derive(Debug, Clone)] pub struct IceConfig { - /// ICE candidate gathering timeout (ms) pub gathering_timeout_ms: u64, - /// ICE connection timeout (ms) pub connection_timeout_ms: u64, - /// Enable ICE lite mode pub ice_lite: bool, } diff --git a/src/webrtc/mod.rs b/src/webrtc/mod.rs index 8640ab56..2c1ae15b 100644 --- a/src/webrtc/mod.rs +++ b/src/webrtc/mod.rs @@ -1,46 +1,17 @@ -//! WebRTC module for low-latency video streaming -//! -//! This module provides WebRTC-based video streaming with: -//! - H.264 video track (hardware/software encoding) -//! - H.265 video track (hardware only) -//! - VP8/VP9 video track (hardware only - VAAPI) -//! - Opus audio track (optional) -//! - DataChannel for HID events -//! -//! Architecture: -//! ```text -//! V4L2 capture -//! | -//! v -//! SharedVideoPipeline (decode -> convert -> encode) -//! | -//! v -//! UniversalVideoTrack (RTP packetization) -//! | -//! v -//! WebRTC PeerConnection -//! | -//! Browser <-------- SDP Exchange ------- API Server -//! | -//! +------- DataChannel ------> HID Events -//! ``` +//! Low-latency WebRTC streaming: shared encoder → [`video_track::UniversalVideoTrack`] → peer; +//! HID over DataChannel. pub mod config; pub mod h265_payloader; pub(crate) mod mdns; -pub mod peer; pub mod rtp; -pub mod session; pub mod signaling; -pub mod track; pub mod universal_session; pub mod video_track; pub mod webrtc_streamer; pub use config::WebRtcConfig; -pub use peer::PeerConnection; -pub use rtp::{H264VideoTrack, H264VideoTrackConfig, OpusAudioTrack}; -pub use session::WebRtcSessionManager; +pub use rtp::OpusAudioTrack; pub use signaling::{ConnectionState, IceCandidate, SdpAnswer, SdpOffer, SignalingMessage}; pub use universal_session::{UniversalSession, UniversalSessionConfig, UniversalSessionInfo}; pub use video_track::{UniversalVideoTrack, UniversalVideoTrackConfig, VideoCodec}; diff --git a/src/webrtc/peer.rs b/src/webrtc/peer.rs deleted file mode 100644 index 5dbbb8e0..00000000 --- a/src/webrtc/peer.rs +++ /dev/null @@ -1,549 +0,0 @@ -//! WebRTC peer connection management - -use std::sync::Arc; -use tokio::sync::{broadcast, watch, Mutex, RwLock}; -use tracing::{debug, info}; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MediaEngine; -use webrtc::api::setting_engine::SettingEngine; -use webrtc::api::APIBuilder; -use webrtc::data_channel::data_channel_message::DataChannelMessage; -use webrtc::data_channel::RTCDataChannel; -use webrtc::ice::mdns::MulticastDnsMode; -use webrtc::ice_transport::ice_candidate::RTCIceCandidate; -use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::peer_connection::RTCPeerConnection; - -use super::config::WebRtcConfig; -use super::mdns::{default_mdns_host_name, mdns_mode}; -use super::signaling::{ConnectionState, IceCandidate, SdpAnswer, SdpOffer}; -use super::track::{VideoTrack, VideoTrackConfig}; -use crate::error::{AppError, Result}; -use crate::hid::datachannel::{parse_hid_message, HidChannelEvent}; -use crate::hid::HidController; -use crate::video::frame::VideoFrame; - -/// Peer connection wrapper with event handling -pub struct PeerConnection { - /// Session ID - pub session_id: String, - /// WebRTC peer connection - pc: Arc, - /// Video track - video_track: Option, - /// Data channel for HID events - data_channel: Arc>>>, - /// Connection state - state: Arc>, - /// State receiver - state_rx: watch::Receiver, - /// ICE candidates gathered - ice_candidates: Arc>>, - /// HID controller reference - hid_controller: Option>, -} - -impl PeerConnection { - /// Create a new peer connection - pub async fn new(config: &WebRtcConfig, session_id: String) -> Result { - // Create media engine - let mut media_engine = MediaEngine::default(); - - // Register codecs - media_engine - .register_default_codecs() - .map_err(|e| AppError::VideoError(format!("Failed to register codecs: {}", e)))?; - - // Create interceptor registry - let mut registry = Registry::new(); - registry = register_default_interceptors(registry, &mut media_engine) - .map_err(|e| AppError::VideoError(format!("Failed to register interceptors: {}", e)))?; - - // Create API (with optional mDNS settings) - let mut setting_engine = SettingEngine::default(); - let mode = mdns_mode(); - setting_engine.set_ice_multicast_dns_mode(mode); - if mode == MulticastDnsMode::QueryAndGather { - setting_engine.set_multicast_dns_host_name(default_mdns_host_name(&session_id)); - } - info!("WebRTC mDNS mode: {:?} (session {})", mode, session_id); - - let api = APIBuilder::new() - .with_setting_engine(setting_engine) - .with_media_engine(media_engine) - .with_interceptor_registry(registry) - .build(); - - // Build ICE servers - let mut ice_servers = vec![]; - - for stun_url in &config.stun_servers { - ice_servers.push(RTCIceServer { - urls: vec![stun_url.clone()], - ..Default::default() - }); - } - - for turn in &config.turn_servers { - ice_servers.push(RTCIceServer { - urls: turn.urls.clone(), - username: turn.username.clone(), - credential: turn.credential.clone(), - }); - } - - // Create peer connection configuration - let rtc_config = RTCConfiguration { - ice_servers, - ..Default::default() - }; - - // Create peer connection - let pc = api.new_peer_connection(rtc_config).await.map_err(|e| { - AppError::VideoError(format!("Failed to create peer connection: {}", e)) - })?; - - let pc = Arc::new(pc); - - // Create state channel - let (state_tx, state_rx) = watch::channel(ConnectionState::New); - - let peer_connection = Self { - session_id, - pc, - video_track: None, - data_channel: Arc::new(RwLock::new(None)), - state: Arc::new(state_tx), - state_rx, - ice_candidates: Arc::new(Mutex::new(vec![])), - hid_controller: None, - }; - - // Set up event handlers - peer_connection.setup_event_handlers().await; - - Ok(peer_connection) - } - - /// Set up peer connection event handlers - async fn setup_event_handlers(&self) { - let state = self.state.clone(); - let session_id = self.session_id.clone(); - - // Connection state change handler - self.pc - .on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - let state = state.clone(); - let session_id = session_id.clone(); - - Box::pin(async move { - let new_state = match s { - RTCPeerConnectionState::New => ConnectionState::New, - RTCPeerConnectionState::Connecting => ConnectionState::Connecting, - RTCPeerConnectionState::Connected => ConnectionState::Connected, - RTCPeerConnectionState::Disconnected => ConnectionState::Disconnected, - RTCPeerConnectionState::Failed => ConnectionState::Failed, - RTCPeerConnectionState::Closed => ConnectionState::Closed, - _ => return, - }; - - info!("Peer {} connection state: {}", session_id, new_state); - let _ = state.send(new_state); - }) - })); - - // ICE candidate handler - let ice_candidates = self.ice_candidates.clone(); - self.pc - .on_ice_candidate(Box::new(move |candidate: Option| { - let ice_candidates = ice_candidates.clone(); - - Box::pin(async move { - if let Some(c) = candidate { - let candidate_str = c.to_json().map(|j| j.candidate).unwrap_or_default(); - - debug!("ICE candidate: {}", candidate_str); - - let mut candidates = ice_candidates.lock().await; - candidates.push(IceCandidate { - candidate: candidate_str, - sdp_mid: c.to_json().ok().and_then(|j| j.sdp_mid), - sdp_mline_index: c.to_json().ok().and_then(|j| j.sdp_mline_index), - username_fragment: None, - }); - } - }) - })); - - // Data channel handler - note: HID processing is done when hid_controller is set - let data_channel = self.data_channel.clone(); - self.pc - .on_data_channel(Box::new(move |dc: Arc| { - let data_channel = data_channel.clone(); - - Box::pin(async move { - info!("Data channel opened: {}", dc.label()); - - // Store data channel - *data_channel.write().await = Some(dc.clone()); - - // Message handler logs messages; HID processing requires set_hid_controller() - dc.on_message(Box::new(move |msg: DataChannelMessage| { - debug!("DataChannel message: {} bytes", msg.data.len()); - Box::pin(async {}) - })); - }) - })); - } - - /// Set HID controller for processing DataChannel messages - pub fn set_hid_controller(&mut self, hid: Arc) { - let hid_clone = hid.clone(); - let data_channel = self.data_channel.clone(); - - // Set up message handler with HID processing - let pc = self.pc.clone(); - pc.on_data_channel(Box::new(move |dc: Arc| { - let data_channel = data_channel.clone(); - let hid = hid_clone.clone(); - let label = dc.label().to_string(); - - Box::pin(async move { - // Handle both reliable (hid) and unreliable (hid-unreliable) channels - let is_hid_channel = label == "hid" || label == "hid-unreliable"; - - if is_hid_channel { - info!( - "HID DataChannel opened: {} (unreliable: {})", - label, - label == "hid-unreliable" - ); - - // Store the reliable data channel for sending responses - if label == "hid" { - *data_channel.write().await = Some(dc.clone()); - } - - // Set up message handler with HID processing - // Both channels use the same HID processing logic - dc.on_message(Box::new(move |msg: DataChannelMessage| { - let hid = hid.clone(); - - tokio::spawn(async move { - debug!("DataChannel HID message: {} bytes", msg.data.len()); - - // Parse and process HID message - if let Some(event) = parse_hid_message(&msg.data) { - match event { - HidChannelEvent::Keyboard(kb_event) => { - if let Err(e) = hid.send_keyboard(kb_event).await { - debug!("Failed to send keyboard event: {}", e); - } - } - HidChannelEvent::Mouse(ms_event) => { - if let Err(e) = hid.send_mouse(ms_event).await { - debug!("Failed to send mouse event: {}", e); - } - } - HidChannelEvent::Consumer(consumer_event) => { - if let Err(e) = hid.send_consumer(consumer_event).await { - debug!("Failed to send consumer event: {}", e); - } - } - } - } - }); - - // Return empty future (actual work is spawned above) - Box::pin(async {}) - })); - } else { - info!("Non-HID DataChannel opened: {}", label); - } - }) - })); - - self.hid_controller = Some(hid); - } - - /// Add video track to the connection - pub async fn add_video_track(&mut self, config: VideoTrackConfig) -> Result<()> { - let video_track = VideoTrack::new(config); - - // Add track to peer connection - self.pc - .add_track(video_track.rtp_track()) - .await - .map_err(|e| AppError::VideoError(format!("Failed to add video track: {}", e)))?; - - self.video_track = Some(video_track); - info!("Video track added to peer connection"); - - Ok(()) - } - - /// Create data channel for HID events - pub async fn create_data_channel(&self, label: &str) -> Result<()> { - let dc = self - .pc - .create_data_channel(label, None) - .await - .map_err(|e| AppError::VideoError(format!("Failed to create data channel: {}", e)))?; - - *self.data_channel.write().await = Some(dc); - info!("Data channel '{}' created", label); - - Ok(()) - } - - /// Handle SDP offer and create answer - pub async fn handle_offer(&self, offer: SdpOffer) -> Result { - // Parse and set remote description - let sdp = RTCSessionDescription::offer(offer.sdp) - .map_err(|e| AppError::VideoError(format!("Invalid SDP offer: {}", e)))?; - - self.pc.set_remote_description(sdp).await.map_err(|e| { - AppError::VideoError(format!("Failed to set remote description: {}", e)) - })?; - - // Create answer - let answer = self - .pc - .create_answer(None) - .await - .map_err(|e| AppError::VideoError(format!("Failed to create answer: {}", e)))?; - - // Wait for ICE gathering complete (or timeout) after setting local description. - // This improves first-connection robustness by returning a fuller initial candidate set. - let mut gather_complete = self.pc.gathering_complete_promise().await; - - // Set local description - self.pc - .set_local_description(answer.clone()) - .await - .map_err(|e| AppError::VideoError(format!("Failed to set local description: {}", e)))?; - - const ICE_GATHER_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_millis(2500); - if tokio::time::timeout(ICE_GATHER_TIMEOUT, gather_complete.recv()) - .await - .is_err() - { - debug!( - "ICE gathering timeout after {:?} for session {}", - ICE_GATHER_TIMEOUT, self.session_id - ); - } - - // Get gathered ICE candidates - let candidates = self.ice_candidates.lock().await.clone(); - - Ok(SdpAnswer::with_candidates(answer.sdp, candidates)) - } - - /// Add ICE candidate - pub async fn add_ice_candidate(&self, candidate: IceCandidate) -> Result<()> { - use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; - - let init = RTCIceCandidateInit { - candidate: candidate.candidate, - sdp_mid: candidate.sdp_mid, - sdp_mline_index: candidate.sdp_mline_index, - username_fragment: candidate.username_fragment, - }; - - self.pc - .add_ice_candidate(init) - .await - .map_err(|e| AppError::VideoError(format!("Failed to add ICE candidate: {}", e)))?; - - Ok(()) - } - - /// Get current connection state - pub fn state(&self) -> ConnectionState { - *self.state_rx.borrow() - } - - /// Subscribe to state changes - pub fn state_watch(&self) -> watch::Receiver { - self.state_rx.clone() - } - - /// Start sending video frames - pub async fn start_video(&self, frame_rx: broadcast::Receiver) { - if let Some(ref track) = self.video_track { - track.start_sending(frame_rx).await; - } - } - - /// Send HID data via data channel - pub async fn send_hid_data(&self, data: &[u8]) -> Result<()> { - let dc = self.data_channel.read().await; - - if let Some(ref channel) = *dc { - channel - .send(&bytes::Bytes::copy_from_slice(data)) - .await - .map_err(|e| AppError::VideoError(format!("Failed to send HID data: {}", e)))?; - } - - Ok(()) - } - - /// Close the connection - pub async fn close(&self) -> Result<()> { - // Reset HID state to release any held keys/buttons - if let Some(ref hid) = self.hid_controller { - if let Err(e) = hid.reset().await { - tracing::warn!( - "Failed to reset HID on peer {} close: {}", - self.session_id, - e - ); - } else { - tracing::debug!("HID reset on peer {} close", self.session_id); - } - } - - if let Some(ref track) = self.video_track { - track.stop(); - } - - self.pc - .close() - .await - .map_err(|e| AppError::VideoError(format!("Failed to close peer connection: {}", e)))?; - - Ok(()) - } - - /// Get session ID - pub fn session_id(&self) -> &str { - &self.session_id - } -} - -/// Manager for multiple peer connections -pub struct PeerConnectionManager { - config: WebRtcConfig, - /// Active peer connections - peers: Arc>>>>, - /// Frame broadcast sender (to distribute to all peers) - frame_tx: broadcast::Sender, - /// HID controller for DataChannel HID processing - hid_controller: Option>, -} - -impl PeerConnectionManager { - /// Create a new peer connection manager - pub fn new(config: WebRtcConfig) -> Self { - let (frame_tx, _) = broadcast::channel(16); - - Self { - config, - peers: Arc::new(RwLock::new(vec![])), - frame_tx, - hid_controller: None, - } - } - - /// Create a new peer connection manager with HID controller - pub fn with_hid(config: WebRtcConfig, hid: Arc) -> Self { - let (frame_tx, _) = broadcast::channel(16); - - Self { - config, - peers: Arc::new(RwLock::new(vec![])), - frame_tx, - hid_controller: Some(hid), - } - } - - /// Set HID controller - pub fn set_hid_controller(&mut self, hid: Arc) { - self.hid_controller = Some(hid); - } - - /// Create a new peer connection - pub async fn create_peer(&self) -> Result>> { - let session_id = uuid::Uuid::new_v4().to_string(); - let mut peer = PeerConnection::new(&self.config, session_id).await?; - - // Add video track - peer.add_video_track(VideoTrackConfig::default()).await?; - - // Set HID controller if available - // Note: We DON'T create a data channel here - the frontend creates it. - // The server receives it via on_data_channel callback set in set_hid_controller(). - if self.config.enable_datachannel { - if let Some(ref hid) = self.hid_controller { - peer.set_hid_controller(hid.clone()); - } - } - - let peer = Arc::new(Mutex::new(peer)); - - // Add to peers list - self.peers.write().await.push(peer.clone()); - - // Start sending video when connected - let frame_rx = self.frame_tx.subscribe(); - let peer_clone = peer.clone(); - tokio::spawn(async move { - let peer = peer_clone.lock().await; - let mut state_rx = peer.state_watch(); - drop(peer); - - // Wait for connection - while state_rx.changed().await.is_ok() { - if *state_rx.borrow() == ConnectionState::Connected { - let peer = peer_clone.lock().await; - peer.start_video(frame_rx).await; - break; - } - } - }); - - Ok(peer) - } - - /// Get frame sender (for video streamer to push frames) - pub fn frame_sender(&self) -> broadcast::Sender { - self.frame_tx.clone() - } - - /// Remove closed connections - pub async fn cleanup(&self) { - let mut peers = self.peers.write().await; - let mut to_remove = vec![]; - - for (i, peer) in peers.iter().enumerate() { - let p = peer.lock().await; - if matches!(p.state(), ConnectionState::Closed | ConnectionState::Failed) { - to_remove.push(i); - } - } - - for i in to_remove.into_iter().rev() { - peers.remove(i); - } - } - - /// Get active peer count - pub async fn peer_count(&self) -> usize { - self.peers.read().await.len() - } - - /// Close all connections - pub async fn close_all(&self) { - let peers = self.peers.read().await; - for peer in peers.iter() { - let p = peer.lock().await; - let _ = p.close().await; - } - } -} diff --git a/src/webrtc/rtp.rs b/src/webrtc/rtp.rs index f576f4e9..d72228cd 100644 --- a/src/webrtc/rtp.rs +++ b/src/webrtc/rtp.rs @@ -1,320 +1,25 @@ -//! RTP packetization for H264 video -//! -//! This module provides H264 RTP packetization using the rtp crate's H264Payloader. -//! It handles: -//! - NAL unit parsing (Annex B start codes) -//! - SPS/PPS collection and STAP-A packetization -//! - Single NAL unit mode for small NALs -//! - FU-A fragmentation for large NALs -//! -//! IMPORTANT: Each NAL unit must be sent separately via write_sample(), -//! without Annex B start codes. The TrackLocalStaticSample handles -//! RTP packetization internally. +//! Opus outbound track plus H.264 Annex B helpers (SPS/PPS, keyframe scan). Video RTP lives in [`crate::webrtc::video_track`]. use bytes::Bytes; -use rtp::codecs::h264::H264Payloader; -use rtp::packetizer::Payloader; -use std::io::Cursor; use std::sync::Arc; use std::time::Duration; -use tokio::sync::Mutex; -use tracing::{debug, error, trace}; -use webrtc::media::io::h264_reader::H264Reader; +use tracing::error; use webrtc::media::Sample; use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; use webrtc::track::track_local::TrackLocal; use crate::error::{AppError, Result}; -use crate::video::format::Resolution; -/// Default MTU for RTP packets (conservative for most networks) pub const RTP_MTU: usize = 1200; -/// H264 clock rate (always 90kHz per RFC 6184) pub const H264_CLOCK_RATE: u32 = 90000; -/// H264 video track using TrackLocalStaticSample for proper packetization -pub struct H264VideoTrack { - /// The underlying WebRTC track with automatic packetization - track: Arc, - /// Track configuration - config: H264VideoTrackConfig, - /// H264 payloader for manual packetization (if needed) - payloader: Mutex, - /// Cached SPS NAL unit for injection before IDR frames - /// Some hardware encoders don't repeat SPS/PPS with every keyframe - cached_sps: Mutex>, - /// Cached PPS NAL unit for injection before IDR frames - cached_pps: Mutex>, -} - -/// H264 video track configuration -#[derive(Debug, Clone)] -pub struct H264VideoTrackConfig { - /// Track ID - pub track_id: String, - /// Stream ID - pub stream_id: String, - /// Resolution - pub resolution: Resolution, - /// Target bitrate in kbps - pub bitrate_kbps: u32, - /// Frames per second - pub fps: u32, - /// H.264 profile-level-id (e.g., "42001f" for Baseline L3.1, "64001f" for High L3.1) - /// If None, uses empty string to let browser negotiate - /// Format: PPCCLL where PP=profile_idc, CC=constraint_flags, LL=level_idc - pub profile_level_id: Option, -} - -impl Default for H264VideoTrackConfig { - fn default() -> Self { - Self { - track_id: "video0".to_string(), - stream_id: "one-kvm-stream".to_string(), - resolution: Resolution::HD720, - bitrate_kbps: 8000, - fps: 30, - profile_level_id: None, // Let browser negotiate - } - } -} - -impl H264VideoTrack { - /// Create a new H264 video track - /// - /// If `config.profile_level_id` is set, it will be used in SDP negotiation. - /// Otherwise, uses empty fmtp line to let browser negotiate the best profile. - pub fn new(config: H264VideoTrackConfig) -> Self { - // Build sdp_fmtp_line based on profile_level_id - let sdp_fmtp_line = if let Some(ref profile_level_id) = config.profile_level_id { - // Use specified profile-level-id - format!( - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id={}", - profile_level_id - ) - } else { - // Let browser negotiate - empty string for maximum compatibility - String::new() - }; - - let codec = RTCRtpCodecCapability { - mime_type: "video/H264".to_string(), - clock_rate: H264_CLOCK_RATE, - channels: 0, - sdp_fmtp_line, - rtcp_feedback: vec![], - }; - - let track = Arc::new(TrackLocalStaticSample::new( - codec, - config.track_id.clone(), - config.stream_id.clone(), - )); - - Self { - track, - config, - payloader: Mutex::new(H264Payloader::default()), - cached_sps: Mutex::new(None), - cached_pps: Mutex::new(None), - } - } - - /// Get the underlying WebRTC track for adding to peer connection - pub fn track(&self) -> Arc { - self.track.clone() - } - - /// Get track as TrackLocal for peer connection - pub fn as_track_local(&self) -> Arc { - self.track.clone() - } - - /// Write an H264 encoded frame to the track - /// - /// The frame data should be H264 Annex B format (with start codes 0x00000001 or 0x000001). - /// This is the format produced by hwcodec/FFmpeg encoders. - /// - /// IMPORTANT: Each NAL unit is sent separately via write_sample(), without start codes. - /// This is required for proper WebRTC RTP packetization. - /// See: https://github.com/webrtc-rs/webrtc/blob/master/examples/examples/play-from-disk-h264/ - /// - /// # Arguments - /// * `data` - H264 Annex B encoded frame data - /// * `duration` - Frame duration (typically 1/fps seconds) - /// * `is_keyframe` - Whether this is a keyframe (IDR frame) - pub async fn write_frame( - &self, - data: &[u8], - _duration: Duration, - is_keyframe: bool, - ) -> Result<()> { - if data.is_empty() { - return Ok(()); - } - - // Use H264Reader to parse NAL units from Annex B data - let cursor = Cursor::new(data); - let mut h264_reader = H264Reader::new(cursor, 1024 * 1024); - - // Collect all NAL units first to check for SPS/PPS presence - let mut nals: Vec = Vec::new(); - let mut has_sps = false; - let mut has_pps = false; - let mut has_idr = false; - - // Send each NAL unit separately (like official webrtc-rs example) - // H264Reader returns NAL data WITHOUT start codes - this is what we need - while let Ok(nal) = h264_reader.next_nal() { - if nal.data.is_empty() { - continue; - } - - let nal_type = nal.data[0] & 0x1F; - - // Skip AUD NAL units (type 9) - not needed for WebRTC - if nal_type == 9 { - continue; - } - - // Skip filler data (type 12) - if nal_type == 12 { - continue; - } - - // Track NAL types - match nal_type { - 5 => has_idr = true, - 7 => { - has_sps = true; - // Cache SPS for future injection - *self.cached_sps.lock().await = Some(nal.data.clone().freeze()); - } - 8 => { - has_pps = true; - // Cache PPS for future injection - *self.cached_pps.lock().await = Some(nal.data.clone().freeze()); - } - _ => {} - } - - trace!( - "Sending NAL: type={} ({}) size={} bytes", - nal_type, - match nal_type { - 1 => "Non-IDR slice", - 5 => "IDR slice", - 6 => "SEI", - 7 => "SPS", - 8 => "PPS", - _ => "Other", - }, - nal.data.len() - ); - - nals.push(nal.data.freeze()); - } - - // Inject cached SPS/PPS before IDR if missing - // This is critical for hardware encoders that don't repeat SPS/PPS - if has_idr && (!has_sps || !has_pps) { - let mut injected_nals: Vec = Vec::new(); - - if !has_sps { - if let Some(sps) = self.cached_sps.lock().await.clone() { - debug!("Injecting cached SPS before IDR frame"); - injected_nals.push(sps); - } - } - if !has_pps { - if let Some(pps) = self.cached_pps.lock().await.clone() { - debug!("Injecting cached PPS before IDR frame"); - injected_nals.push(pps); - } - } - - if !injected_nals.is_empty() { - injected_nals.extend(nals); - nals = injected_nals; - } - } - - let mut nal_count = 0; - let mut total_bytes = 0u64; - - // Send NAL data directly WITHOUT start codes - // TrackLocalStaticSample handles RTP packetization internally - // Use duration=1s for each NAL like official webrtc-rs example - for nal_data in nals { - let sample = Sample { - data: nal_data.clone(), - duration: Duration::from_secs(1), // Like official example - ..Default::default() - }; - - if let Err(e) = self.track.write_sample(&sample).await { - // Only log periodically to avoid spam when no peer connected - if nal_count % 100 == 0 { - debug!("Write sample failed (no peer?): {}", e); - } - } - - total_bytes += nal_data.len() as u64; - nal_count += 1; - } - - trace!( - "Sent frame: {} NAL units, {} bytes, keyframe={}", - nal_count, - total_bytes, - is_keyframe - ); - - Ok(()) - } - - /// Write frame with timestamp (for more precise timing control) - pub async fn write_frame_with_timestamp( - &self, - data: &[u8], - _pts_ms: u64, - is_keyframe: bool, - ) -> Result<()> { - // Convert pts from milliseconds to frame duration - // Assuming constant frame rate from config - let duration = Duration::from_millis(1000 / self.config.fps as u64); - self.write_frame(data, duration, is_keyframe).await - } - - /// Manually packetize H264 data into RTP payloads - /// - /// This is useful if you need direct control over RTP packets - /// (e.g., for sending via TrackLocalStaticRTP instead of TrackLocalStaticSample) - pub async fn packetize(&self, data: &[u8], mtu: usize) -> Result> { - let mut payloader = self.payloader.lock().await; - let bytes = Bytes::copy_from_slice(data); - - payloader - .payload(mtu, &bytes) - .map_err(|e| AppError::VideoError(format!("H264 packetization failed: {}", e))) - } - - /// Get configuration - pub fn config(&self) -> &H264VideoTrackConfig { - &self.config - } -} - -/// Opus audio track using TrackLocalStaticSample pub struct OpusAudioTrack { - /// The underlying WebRTC track track: Arc, } impl OpusAudioTrack { - /// Create a new Opus audio track pub fn new(track_id: &str, stream_id: &str) -> Self { let codec = RTCRtpCodecCapability { mime_type: "audio/opus".to_string(), @@ -333,28 +38,19 @@ impl OpusAudioTrack { Self { track } } - /// Get the underlying WebRTC track pub fn track(&self) -> Arc { self.track.clone() } - /// Get track as TrackLocal pub fn as_track_local(&self) -> Arc { self.track.clone() } - /// Write Opus encoded audio data - /// - /// # Arguments - /// * `data` - Opus encoded packet - /// * `samples` - Number of audio samples in this packet (typically 960 for 20ms at 48kHz) pub async fn write_packet(&self, data: &[u8], samples: u32) -> Result<()> { if data.is_empty() { return Ok(()); } - // Opus frame duration based on samples - // 48000 Hz, so duration = samples / 48000 seconds let duration = Duration::from_micros((samples as u64 * 1_000_000) / 48000); let sample = Sample { @@ -370,15 +66,12 @@ impl OpusAudioTrack { } } -/// Strip AUD (Access Unit Delimiter) NAL units from H264 Annex B data -/// AUD (NAL type 9) can cause decoding issues in some browser WebRTC implementations -/// Also strips filler data (NAL type 12) and SEI (NAL type 6) for cleaner output +/// Strips AUD (9) and filler (12) NALs; some WebRTC stacks dislike AUD. pub fn strip_aud_nal_units(data: &[u8]) -> Vec { let mut result = Vec::with_capacity(data.len()); let mut i = 0; while i < data.len() { - // Find start code (3 or 4 bytes) let (start_code_pos, start_code_len) = if i + 4 <= data.len() && data[i] == 0 && data[i + 1] == 0 @@ -400,7 +93,6 @@ pub fn strip_aud_nal_units(data: &[u8]) -> Vec { let nal_type = data[nal_start] & 0x1F; - // Find next start code to determine NAL unit end let mut nal_end = data.len(); let mut j = nal_start + 1; while j + 3 <= data.len() { @@ -417,17 +109,13 @@ pub fn strip_aud_nal_units(data: &[u8]) -> Vec { j += 1; } - // Skip AUD (9), filler (12), and optionally SEI (6) - // Keep SPS (7), PPS (8), IDR (5), non-IDR slice (1) if nal_type != 9 && nal_type != 12 { - // Include this NAL unit with start code result.extend_from_slice(&data[start_code_pos..nal_end]); } i = nal_end; } - // If nothing was stripped, return original data if result.is_empty() && !data.is_empty() { return data.to_vec(); } @@ -435,15 +123,12 @@ pub fn strip_aud_nal_units(data: &[u8]) -> Vec { result } -/// Extract SPS and PPS NAL units from H264 Annex B data -/// Returns (SPS data without start code, PPS data without start code) pub fn extract_sps_pps(data: &[u8]) -> (Option>, Option>) { let mut sps: Option> = None; let mut pps: Option> = None; let mut i = 0; while i < data.len() { - // Find start code (3 or 4 bytes) let start_code_len = if i + 4 <= data.len() && data[i] == 0 && data[i + 1] == 0 @@ -465,7 +150,6 @@ pub fn extract_sps_pps(data: &[u8]) -> (Option>, Option>) { let nal_type = data[nal_start] & 0x1F; - // Find next start code to determine NAL unit end let mut nal_end = data.len(); let mut j = nal_start + 1; while j + 3 <= data.len() { @@ -482,7 +166,6 @@ pub fn extract_sps_pps(data: &[u8]) -> (Option>, Option>) { j += 1; } - // Extract SPS (NAL type 7) and PPS (NAL type 8) without start codes match nal_type { 7 => { sps = Some(data[nal_start..nal_end].to_vec()); @@ -499,14 +182,12 @@ pub fn extract_sps_pps(data: &[u8]) -> (Option>, Option>) { (sps, pps) } -/// Check if H264 Annex B data contains SPS and PPS NAL units pub fn has_sps_pps(data: &[u8]) -> bool { let mut has_sps = false; let mut has_pps = false; let mut i = 0; while i < data.len() { - // Find start code (3 or 4 bytes) let start_code_len = if i + 4 <= data.len() && data[i] == 0 && data[i + 1] == 0 @@ -538,25 +219,20 @@ pub fn has_sps_pps(data: &[u8]) -> bool { return true; } - // Move past start code to next position i = nal_start + 1; } has_sps && has_pps } -/// Check if H264 data contains a keyframe (IDR NAL unit) pub fn is_h264_keyframe(data: &[u8]) -> bool { - // Look for IDR NAL unit (type 5) - // NAL units start with 0x00 0x00 0x01 or 0x00 0x00 0x00 0x01 let mut i = 0; while i < data.len() { - // Find start code if i + 3 < data.len() && data[i] == 0 && data[i + 1] == 0 { - let (nal_start, _start_code_len) = if data[i + 2] == 1 { - (i + 3, 3) + let nal_start = if data[i + 2] == 1 { + i + 3 } else if i + 4 < data.len() && data[i + 2] == 0 && data[i + 3] == 1 { - (i + 4, 4) + i + 4 } else { i += 1; continue; @@ -564,7 +240,6 @@ pub fn is_h264_keyframe(data: &[u8]) -> bool { if nal_start < data.len() { let nal_type = data[nal_start] & 0x1F; - // IDR = 5, SPS = 7, PPS = 8 if nal_type == 5 { return true; } @@ -577,22 +252,12 @@ pub fn is_h264_keyframe(data: &[u8]) -> bool { false } -/// Parse profile-level-id from SPS NAL unit data (without start code) -/// -/// Returns a 6-character hex string like "42001f" (Baseline L3.1) or "64001f" (High L3.1) -/// -/// SPS structure (first 4 bytes after NAL header): -/// - Byte 0: NAL header (0x67 for SPS) -/// - Byte 1: profile_idc (0x42=Baseline, 0x4D=Main, 0x64=High) -/// - Byte 2: constraint_set_flags -/// - Byte 3: level_idc (0x1f=3.1, 0x28=4.0, 0x33=5.1) +/// `profile-level-id` hex for SDP (`42001f` etc.); expects SPS NAL RBSP without start code. pub fn parse_profile_level_id_from_sps(sps: &[u8]) -> Option { - // SPS NAL must be at least 4 bytes: NAL header + profile_idc + constraints + level_idc if sps.len() < 4 { return None; } - // First byte is NAL header, skip it let profile_idc = sps[1]; let constraint_set_flags = sps[2]; let level_idc = sps[3]; @@ -603,37 +268,17 @@ pub fn parse_profile_level_id_from_sps(sps: &[u8]) -> Option { )) } -/// Extract profile-level-id from H264 Annex B data (containing SPS) -/// -/// This function finds the SPS NAL unit and extracts the profile-level-id. -/// Useful for determining the actual encoder output profile. -/// -/// # Example -/// ```ignore -/// let h264_data = encoder.encode(&yuv)?; -/// if let Some(profile_level_id) = extract_profile_level_id(&h264_data) { -/// println!("Encoder outputs profile-level-id: {}", profile_level_id); -/// // Use this to configure H264VideoTrackConfig -/// } -/// ``` pub fn extract_profile_level_id(data: &[u8]) -> Option { let (sps, _) = extract_sps_pps(data); sps.and_then(|sps_data| parse_profile_level_id_from_sps(&sps_data)) } -/// Common H.264 profile-level-id values pub mod profiles { - /// Constrained Baseline Profile Level 3.1 - Maximum browser compatibility pub const CONSTRAINED_BASELINE_31: &str = "42e01f"; - /// Baseline Profile Level 3.1 pub const BASELINE_31: &str = "42001f"; - /// Main Profile Level 3.1 pub const MAIN_31: &str = "4d001f"; - /// High Profile Level 3.1 - Hardware encoders typically output this pub const HIGH_31: &str = "64001f"; - /// High Profile Level 4.0 pub const HIGH_40: &str = "640028"; - /// High Profile Level 5.1 pub const HIGH_51: &str = "640033"; } @@ -643,36 +288,23 @@ mod tests { #[test] fn test_is_h264_keyframe() { - // IDR frame with 4-byte start code - let idr_frame = vec![0x00, 0x00, 0x00, 0x01, 0x65]; // NAL type 5 = IDR + let idr_frame = vec![0x00, 0x00, 0x00, 0x01, 0x65]; assert!(is_h264_keyframe(&idr_frame)); - // IDR frame with 3-byte start code let idr_frame_3 = vec![0x00, 0x00, 0x01, 0x65]; assert!(is_h264_keyframe(&idr_frame_3)); - // Non-IDR frame (P-frame, NAL type 1) let p_frame = vec![0x00, 0x00, 0x00, 0x01, 0x41]; assert!(!is_h264_keyframe(&p_frame)); - // SPS (NAL type 7) - not a keyframe by itself let sps = vec![0x00, 0x00, 0x00, 0x01, 0x67]; assert!(!is_h264_keyframe(&sps)); - // Multiple NAL units with IDR let multi_nal = vec![ - 0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f, // SPS - 0x00, 0x00, 0x00, 0x01, 0x68, 0xce, 0x38, 0x80, // PPS - 0x00, 0x00, 0x00, 0x01, 0x65, 0x88, 0x84, // IDR + 0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f, + 0x00, 0x00, 0x00, 0x01, 0x68, 0xce, 0x38, 0x80, + 0x00, 0x00, 0x00, 0x01, 0x65, 0x88, 0x84, ]; assert!(is_h264_keyframe(&multi_nal)); } - - #[test] - fn test_h264_track_config_default() { - let config = H264VideoTrackConfig::default(); - assert_eq!(config.fps, 30); - assert_eq!(config.bitrate_kbps, 8000); - assert_eq!(config.resolution, Resolution::HD720); - } } diff --git a/src/webrtc/session.rs b/src/webrtc/session.rs deleted file mode 100644 index d7f7372f..00000000 --- a/src/webrtc/session.rs +++ /dev/null @@ -1,196 +0,0 @@ -//! WebRTC session management - -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::RwLock; -use tracing::{debug, info}; - -use super::config::WebRtcConfig; -use super::peer::PeerConnection; -use super::signaling::{IceCandidate, SdpAnswer, SdpOffer}; -use crate::error::{AppError, Result}; - -/// Maximum concurrent WebRTC sessions -const MAX_SESSIONS: usize = 8; - -/// WebRTC session info -#[derive(Debug, Clone)] -pub struct SessionInfo { - pub session_id: String, - pub created_at: std::time::Instant, - pub state: String, -} - -/// WebRTC session manager -pub struct WebRtcSessionManager { - config: WebRtcConfig, - sessions: Arc>>>, -} - -impl WebRtcSessionManager { - /// Create a new session manager - pub fn new(config: WebRtcConfig) -> Self { - Self { - config, - sessions: Arc::new(RwLock::new(HashMap::new())), - } - } - - /// Create with default config - pub fn default_config() -> Self { - Self::new(WebRtcConfig::default()) - } - - /// Create a new WebRTC session - pub async fn create_session(&self) -> Result { - let sessions = self.sessions.read().await; - - // Check session limit - if sessions.len() >= MAX_SESSIONS { - return Err(AppError::WebRtcError(format!( - "Maximum sessions ({}) reached", - MAX_SESSIONS - ))); - } - drop(sessions); - - // Generate session ID - let session_id = uuid::Uuid::new_v4().to_string(); - - // Create new peer connection - let pc = PeerConnection::new(&self.config, session_id.clone()).await?; - - // Store session - let mut sessions = self.sessions.write().await; - sessions.insert(session_id.clone(), Arc::new(pc)); - - info!("WebRTC session created: {}", session_id); - Ok(session_id) - } - - /// Handle SDP offer and return answer - pub async fn handle_offer(&self, session_id: &str, offer: SdpOffer) -> Result { - let sessions = self.sessions.read().await; - let pc = sessions - .get(session_id) - .ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))? - .clone(); - drop(sessions); - - pc.handle_offer(offer).await - } - - /// Add ICE candidate - pub async fn add_ice_candidate(&self, session_id: &str, candidate: IceCandidate) -> Result<()> { - let sessions = self.sessions.read().await; - let pc = sessions - .get(session_id) - .ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))? - .clone(); - drop(sessions); - - pc.add_ice_candidate(candidate).await - } - - /// Get session info - pub async fn get_session(&self, session_id: &str) -> Option { - let sessions = self.sessions.read().await; - sessions.get(session_id).map(|pc| SessionInfo { - session_id: pc.session_id.clone(), - created_at: std::time::Instant::now(), // TODO: store actual time - state: format!("{:?}", pc.state()), - }) - } - - /// Close a session - pub async fn close_session(&self, session_id: &str) -> Result<()> { - let mut sessions = self.sessions.write().await; - - if let Some(pc) = sessions.remove(session_id) { - info!("WebRTC session closed: {}", session_id); - pc.close().await?; - } - - Ok(()) - } - - /// List all sessions - pub async fn list_sessions(&self) -> Vec { - let sessions = self.sessions.read().await; - sessions - .values() - .map(|pc| SessionInfo { - session_id: pc.session_id.clone(), - created_at: std::time::Instant::now(), - state: format!("{:?}", pc.state()), - }) - .collect() - } - - /// Clean up disconnected sessions - pub async fn cleanup_stale_sessions(&self) { - let sessions_to_remove: Vec = { - let sessions = self.sessions.read().await; - sessions - .iter() - .filter(|(_, pc)| { - matches!( - pc.state(), - super::signaling::ConnectionState::Disconnected - | super::signaling::ConnectionState::Failed - | super::signaling::ConnectionState::Closed - ) - }) - .map(|(id, _)| id.clone()) - .collect() - }; - - if !sessions_to_remove.is_empty() { - let mut sessions = self.sessions.write().await; - for id in sessions_to_remove { - debug!("Removing stale WebRTC session: {}", id); - sessions.remove(&id); - } - } - } - - /// Get session count - pub async fn session_count(&self) -> usize { - self.sessions.read().await.len() - } - - /// Start video streaming to a session - pub async fn start_video(&self, session_id: &str) -> Result<()> { - let sessions = self.sessions.read().await; - let _pc = sessions - .get(session_id) - .ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))? - .clone(); - drop(sessions); - - // Video track should already be added during peer creation - // This is a placeholder for additional video control logic - info!("Video streaming started for session: {}", session_id); - Ok(()) - } - - /// Stop video streaming to a session - pub async fn stop_video(&self, session_id: &str) -> Result<()> { - let sessions = self.sessions.read().await; - let _pc = sessions - .get(session_id) - .ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))? - .clone(); - drop(sessions); - - // Placeholder for video stop logic - info!("Video streaming stopped for session: {}", session_id); - Ok(()) - } -} - -impl Default for WebRtcSessionManager { - fn default() -> Self { - Self::default_config() - } -} diff --git a/src/webrtc/signaling.rs b/src/webrtc/signaling.rs index fbb1ecdb..df490bfc 100644 --- a/src/webrtc/signaling.rs +++ b/src/webrtc/signaling.rs @@ -1,27 +1,19 @@ -//! WebRTC signaling types and messages +//! SDP / ICE JSON types used by HTTP and WebSocket handlers. use serde::{Deserialize, Serialize}; -/// Signaling message types #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "lowercase")] pub enum SignalingMessage { - /// SDP Offer from client Offer(SdpOffer), - /// SDP Answer from server Answer(SdpAnswer), - /// ICE candidate Candidate(IceCandidate), - /// Connection error Error(SignalingError), - /// Connection closed Close, } -/// SDP Offer from client #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SdpOffer { - /// SDP content pub sdp: String, } @@ -31,12 +23,9 @@ impl SdpOffer { } } -/// SDP Answer from server #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SdpAnswer { - /// SDP content pub sdp: String, - /// ICE candidates gathered during answer creation #[serde(skip_serializing_if = "Option::is_none")] pub ice_candidates: Option>, } @@ -61,18 +50,13 @@ impl SdpAnswer { } } -/// ICE candidate #[derive(Debug, Clone, Serialize, Deserialize)] pub struct IceCandidate { - /// Candidate string pub candidate: String, - /// SDP mid (media ID) #[serde(rename = "sdpMid")] pub sdp_mid: Option, - /// SDP mline index #[serde(rename = "sdpMLineIndex")] pub sdp_mline_index: Option, - /// Username fragment #[serde(rename = "usernameFragment")] pub username_fragment: Option, } @@ -94,12 +78,9 @@ impl IceCandidate { } } -/// Signaling error #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SignalingError { - /// Error code pub code: u32, - /// Error message pub message: String, } @@ -124,24 +105,17 @@ impl SignalingError { } } -/// WebRTC offer request (from HTTP API) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OfferRequest { - /// SDP offer pub sdp: String, - /// Client ID (optional, for tracking) #[serde(skip_serializing_if = "Option::is_none")] pub client_id: Option, } -/// WebRTC answer response (from HTTP API) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AnswerResponse { - /// SDP answer pub sdp: String, - /// Session ID for this connection pub session_id: String, - /// ICE candidates #[serde(skip_serializing_if = "Vec::is_empty")] pub ice_candidates: Vec, } @@ -160,16 +134,12 @@ impl AnswerResponse { } } -/// ICE candidate request (trickle ICE) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct IceCandidateRequest { - /// Session ID pub session_id: String, - /// ICE candidate pub candidate: IceCandidate, } -/// Connection state notification #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum ConnectionState { diff --git a/src/webrtc/track.rs b/src/webrtc/track.rs deleted file mode 100644 index 3466fdd2..00000000 --- a/src/webrtc/track.rs +++ /dev/null @@ -1,299 +0,0 @@ -//! WebRTC track implementations for video and audio - -use std::sync::Arc; -use std::time::Instant; -use tokio::sync::{broadcast, watch}; -use tracing::{debug, error, info}; -use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; -use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; -use webrtc::track::track_local::TrackLocalWriter; - -use crate::video::frame::VideoFrame; - -/// Video track configuration -#[derive(Debug, Clone)] -pub struct VideoTrackConfig { - /// Track ID - pub track_id: String, - /// Stream ID - pub stream_id: String, - /// Video codec - pub codec: VideoCodecType, - /// Clock rate - pub clock_rate: u32, - /// Target bitrate - pub bitrate_kbps: u32, -} - -impl Default for VideoTrackConfig { - fn default() -> Self { - Self { - track_id: "video0".to_string(), - stream_id: "one-kvm-stream".to_string(), - codec: VideoCodecType::H264, - clock_rate: 90000, - bitrate_kbps: 8000, - } - } -} - -/// Video codec type -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum VideoCodecType { - H264, - VP8, - VP9, -} - -impl VideoCodecType { - pub fn mime_type(&self) -> &'static str { - match self { - VideoCodecType::H264 => "video/H264", - VideoCodecType::VP8 => "video/VP8", - VideoCodecType::VP9 => "video/VP9", - } - } - - pub fn sdp_fmtp(&self) -> &'static str { - match self { - VideoCodecType::H264 => { - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f" - } - VideoCodecType::VP8 => "", - VideoCodecType::VP9 => "profile-id=0", - } - } -} - -/// Create RTP codec capability for video -pub fn video_codec_capability(codec: VideoCodecType, clock_rate: u32) -> RTCRtpCodecCapability { - RTCRtpCodecCapability { - mime_type: codec.mime_type().to_string(), - clock_rate, - channels: 0, - sdp_fmtp_line: codec.sdp_fmtp().to_string(), - rtcp_feedback: vec![], - } -} - -/// Create RTP codec capability for audio (Opus) -pub fn audio_codec_capability() -> RTCRtpCodecCapability { - RTCRtpCodecCapability { - mime_type: "audio/opus".to_string(), - clock_rate: 48000, - channels: 2, - sdp_fmtp_line: "minptime=10;useinbandfec=1".to_string(), - rtcp_feedback: vec![], - } -} - -/// Video track for WebRTC streaming -pub struct VideoTrack { - config: VideoTrackConfig, - /// RTP track - track: Arc, - /// Running flag - running: Arc>, -} - -impl VideoTrack { - /// Create a new video track - pub fn new(config: VideoTrackConfig) -> Self { - let capability = video_codec_capability(config.codec, config.clock_rate); - - let track = Arc::new(TrackLocalStaticRTP::new( - capability, - config.track_id.clone(), - config.stream_id.clone(), - )); - - let (running_tx, _) = watch::channel(false); - - Self { - config, - track, - running: Arc::new(running_tx), - } - } - - /// Get the underlying RTP track - pub fn rtp_track(&self) -> Arc { - self.track.clone() - } - - /// Start sending frames from a broadcast receiver - pub async fn start_sending(&self, mut frame_rx: broadcast::Receiver) { - let _ = self.running.send(true); - let track = self.track.clone(); - let clock_rate = self.config.clock_rate; - let mut running_rx = self.running.subscribe(); - - info!("Starting video track sender"); - - tokio::spawn(async move { - let mut state = SendState::default(); - loop { - tokio::select! { - result = frame_rx.recv() => { - match result { - Ok(frame) => { - if let Err(e) = Self::send_frame( - &track, - &frame, - &mut state, - clock_rate, - ).await { - debug!("Failed to send frame: {}", e); - } - } - Err(broadcast::error::RecvError::Lagged(n)) => { - debug!("Video track lagged by {} frames", n); - } - Err(broadcast::error::RecvError::Closed) => { - debug!("Frame channel closed"); - break; - } - } - } - _ = running_rx.changed() => { - if !*running_rx.borrow() { - debug!("Video track stopped"); - break; - } - } - } - } - - info!("Video track sender stopped"); - }); - } - - /// Stop sending - pub fn stop(&self) { - let _ = self.running.send(false); - } - - /// Send a single frame as RTP packets - async fn send_frame( - track: &TrackLocalStaticRTP, - frame: &VideoFrame, - state: &mut SendState, - clock_rate: u32, - ) -> Result<(), Box> { - // Calculate timestamp increment based on frame timing - let now = Instant::now(); - let timestamp_increment = if let Some(last) = state.last_frame_time { - let elapsed = now.duration_since(last); - ((elapsed.as_secs_f64() * clock_rate as f64) as u32).min(clock_rate / 10) - } else { - clock_rate / 30 // Default to 30 fps - }; - state.last_frame_time = Some(now); - - // Update timestamp - state.timestamp = state.timestamp.wrapping_add(timestamp_increment); - let _current_ts = state.timestamp; - - // For H.264, we need to packetize into RTP - // This is a simplified implementation - real implementation needs proper NAL unit handling - let data = frame.data(); - let max_payload_size = 1200; // MTU - headers - - let packet_count = data.len().div_ceil(max_payload_size); - let mut bytes_sent = 0u64; - - for i in 0..packet_count { - let start = i * max_payload_size; - let end = ((i + 1) * max_payload_size).min(data.len()); - let _is_last = i == packet_count - 1; - - // Get sequence number - let _seq_num = state.sequence_number; - state.sequence_number = state.sequence_number.wrapping_add(1); - - // Build RTP packet payload - // For simplicity, just send raw data - real implementation needs proper RTP packetization - let payload = &data[start..end]; - bytes_sent += payload.len() as u64; - - // Write sample (the track handles RTP header construction) - if let Err(e) = track.write(payload).await { - error!("Failed to write RTP packet: {}", e); - return Err(e.into()); - } - } - - let _ = bytes_sent; - - Ok(()) - } -} - -#[derive(Debug, Default)] -struct SendState { - sequence_number: u16, - timestamp: u32, - last_frame_time: Option, -} - -/// Audio track configuration -#[derive(Debug, Clone)] -pub struct AudioTrackConfig { - /// Track ID - pub track_id: String, - /// Stream ID - pub stream_id: String, - /// Sample rate - pub sample_rate: u32, - /// Channels - pub channels: u8, -} - -impl Default for AudioTrackConfig { - fn default() -> Self { - Self { - track_id: "audio0".to_string(), - stream_id: "one-kvm-stream".to_string(), - sample_rate: 48000, - channels: 2, - } - } -} - -/// Audio track for WebRTC streaming -pub struct AudioTrack { - /// RTP track - track: Arc, - /// Running flag - running: Arc>, -} - -impl AudioTrack { - /// Create a new audio track - pub fn new(config: AudioTrackConfig) -> Self { - let capability = audio_codec_capability(); - - let track = Arc::new(TrackLocalStaticRTP::new( - capability, - config.track_id.clone(), - config.stream_id.clone(), - )); - - let (running_tx, _) = watch::channel(false); - - Self { - track, - running: Arc::new(running_tx), - } - } - - /// Get the underlying RTP track - pub fn rtp_track(&self) -> Arc { - self.track.clone() - } - - /// Stop sending - pub fn stop(&self) { - let _ = self.running.send(false); - } -} diff --git a/src/webrtc/unified_video_track.rs b/src/webrtc/unified_video_track.rs deleted file mode 100644 index d288f0e0..00000000 --- a/src/webrtc/unified_video_track.rs +++ /dev/null @@ -1,639 +0,0 @@ -//! Unified video track supporting H264, H265, VP8, VP9 -//! -//! This module provides a unified video track implementation that supports -//! multiple video codecs with proper RTP packetization. -//! -//! # Supported Codecs -//! -//! - **H264**: NAL unit parsing with SPS/PPS caching (RFC 6184) -//! - **H265**: NAL unit parsing with VPS/SPS/PPS caching (RFC 7798) -//! - **VP8**: Direct frame sending (RFC 7741) -//! - **VP9**: Direct frame sending (draft-ietf-payload-vp9) -//! -//! # Architecture -//! -//! For NAL-based codecs (H264/H265): -//! - Parse NAL units from Annex B format -//! - Cache parameter sets (SPS/PPS/VPS) for injection -//! - Send each NAL unit via TrackLocalStaticSample -//! -//! For VP8/VP9: -//! - Send raw encoded frames directly -//! - webrtc-rs handles RTP packetization internally - -use bytes::Bytes; -use std::io::Cursor; -use std::sync::Arc; -use std::time::Duration; -use tokio::sync::Mutex; -use tracing::{debug, trace, warn}; -use webrtc::media::io::h264_reader::H264Reader; -use webrtc::media::Sample; -use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; -use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; -use webrtc::track::track_local::TrackLocal; - -use crate::error::{AppError, Result}; -use crate::video::format::Resolution; - -/// Video codec type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum VideoCodec { - H264, - H265, - VP8, - VP9, -} - -impl VideoCodec { - /// Get MIME type for this codec - pub fn mime_type(&self) -> &'static str { - match self { - VideoCodec::H264 => "video/H264", - VideoCodec::H265 => "video/H265", - VideoCodec::VP8 => "video/VP8", - VideoCodec::VP9 => "video/VP9", - } - } - - /// Get clock rate (always 90kHz for video) - pub fn clock_rate(&self) -> u32 { - 90000 - } - - /// Get SDP fmtp line for this codec - pub fn sdp_fmtp_line(&self) -> String { - match self { - VideoCodec::H264 => { - "level-asymmetry-allowed=1;packetization-mode=1".to_string() - } - VideoCodec::H265 => { - // H265 fmtp parameters - String::new() - } - VideoCodec::VP8 => String::new(), - VideoCodec::VP9 => String::new(), - } - } - - /// Check if codec uses NAL units (H264/H265) - pub fn uses_nal_units(&self) -> bool { - matches!(self, VideoCodec::H264 | VideoCodec::H265) - } -} - -impl std::fmt::Display for VideoCodec { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - VideoCodec::H264 => write!(f, "H264"), - VideoCodec::H265 => write!(f, "H265"), - VideoCodec::VP8 => write!(f, "VP8"), - VideoCodec::VP9 => write!(f, "VP9"), - } - } -} - -/// Unified video track configuration -#[derive(Debug, Clone)] -pub struct UnifiedVideoTrackConfig { - /// Video codec - pub codec: VideoCodec, - /// Track ID - pub track_id: String, - /// Stream ID - pub stream_id: String, - /// Resolution - pub resolution: Resolution, - /// Target bitrate in kbps - pub bitrate_kbps: u32, - /// Frames per second - pub fps: u32, -} - -impl Default for UnifiedVideoTrackConfig { - fn default() -> Self { - Self { - codec: VideoCodec::H264, - track_id: "video0".to_string(), - stream_id: "one-kvm-stream".to_string(), - resolution: Resolution::HD720, - bitrate_kbps: 8000, - fps: 30, - } - } -} - -/// Cached NAL parameter sets for H264 -struct H264ParameterSets { - sps: Option, - pps: Option, -} - -/// Cached NAL parameter sets for H265 -struct H265ParameterSets { - vps: Option, - sps: Option, - pps: Option, -} - -/// NAL type constants for H264 -mod h264_nal { - pub const NON_IDR_SLICE: u8 = 1; - pub const IDR_SLICE: u8 = 5; - pub const SEI: u8 = 6; - pub const SPS: u8 = 7; - pub const PPS: u8 = 8; - pub const AUD: u8 = 9; - pub const FILLER: u8 = 12; -} - -/// NAL type constants for H265 -mod h265_nal { - pub const IDR_W_RADL: u8 = 19; - pub const IDR_N_LP: u8 = 20; - pub const CRA_NUT: u8 = 21; - pub const VPS: u8 = 32; - pub const SPS: u8 = 33; - pub const PPS: u8 = 34; - pub const AUD: u8 = 35; - pub const FD_NUT: u8 = 38; // Filler data - - /// Check if NAL type is an IDR frame - pub fn is_idr(nal_type: u8) -> bool { - nal_type == IDR_W_RADL || nal_type == IDR_N_LP || nal_type == CRA_NUT - } -} - -/// Unified video track supporting multiple codecs -pub struct UnifiedVideoTrack { - /// The underlying WebRTC track - track: Arc, - /// Track configuration - config: UnifiedVideoTrackConfig, - /// H264 parameter set cache - h264_params: Mutex, - /// H265 parameter set cache - h265_params: Mutex, -} - -impl UnifiedVideoTrack { - /// Create a new unified video track - pub fn new(config: UnifiedVideoTrackConfig) -> Self { - let codec_capability = RTCRtpCodecCapability { - mime_type: config.codec.mime_type().to_string(), - clock_rate: config.codec.clock_rate(), - channels: 0, - sdp_fmtp_line: config.codec.sdp_fmtp_line(), - rtcp_feedback: vec![], - }; - - let track = Arc::new(TrackLocalStaticSample::new( - codec_capability, - config.track_id.clone(), - config.stream_id.clone(), - )); - - Self { - track, - config, - h264_params: Mutex::new(H264ParameterSets { sps: None, pps: None }), - h265_params: Mutex::new(H265ParameterSets { vps: None, sps: None, pps: None }), - } - } - - /// Create track for H264 - pub fn h264(track_id: &str, stream_id: &str, resolution: Resolution, fps: u32) -> Self { - Self::new(UnifiedVideoTrackConfig { - codec: VideoCodec::H264, - track_id: track_id.to_string(), - stream_id: stream_id.to_string(), - resolution, - fps, - ..Default::default() - }) - } - - /// Create track for H265 - pub fn h265(track_id: &str, stream_id: &str, resolution: Resolution, fps: u32) -> Self { - Self::new(UnifiedVideoTrackConfig { - codec: VideoCodec::H265, - track_id: track_id.to_string(), - stream_id: stream_id.to_string(), - resolution, - fps, - ..Default::default() - }) - } - - /// Create track for VP8 - pub fn vp8(track_id: &str, stream_id: &str, resolution: Resolution, fps: u32) -> Self { - Self::new(UnifiedVideoTrackConfig { - codec: VideoCodec::VP8, - track_id: track_id.to_string(), - stream_id: stream_id.to_string(), - resolution, - fps, - ..Default::default() - }) - } - - /// Create track for VP9 - pub fn vp9(track_id: &str, stream_id: &str, resolution: Resolution, fps: u32) -> Self { - Self::new(UnifiedVideoTrackConfig { - codec: VideoCodec::VP9, - track_id: track_id.to_string(), - stream_id: stream_id.to_string(), - resolution, - fps, - ..Default::default() - }) - } - - /// Get the underlying track for peer connection - pub fn track(&self) -> Arc { - self.track.clone() - } - - /// Get track as TrackLocal for peer connection - pub fn as_track_local(&self) -> Arc { - self.track.clone() - } - - /// Get current codec - pub fn codec(&self) -> VideoCodec { - self.config.codec - } - - /// Get statistics - - /// Write an encoded frame to the track - /// - /// The frame data should be in the appropriate format for the codec: - /// - H264/H265: Annex B format (with start codes) - /// - VP8/VP9: Raw encoded frame - pub async fn write_frame(&self, data: &[u8], _duration: Duration, is_keyframe: bool) -> Result<()> { - if data.is_empty() { - return Ok(()); - } - - match self.config.codec { - VideoCodec::H264 => self.write_h264_frame(data, is_keyframe).await, - VideoCodec::H265 => self.write_h265_frame(data, is_keyframe).await, - VideoCodec::VP8 => self.write_vp8_frame(data, is_keyframe).await, - VideoCodec::VP9 => self.write_vp9_frame(data, is_keyframe).await, - } - } - - /// Write H264 frame (Annex B format) - async fn write_h264_frame(&self, data: &[u8], is_keyframe: bool) -> Result<()> { - let cursor = Cursor::new(data); - let mut reader = H264Reader::new(cursor, 1024 * 1024); - - let mut nals: Vec = Vec::new(); - let mut has_sps = false; - let mut has_pps = false; - let mut has_idr = false; - - // Parse NAL units - while let Ok(nal) = reader.next_nal() { - if nal.data.is_empty() { - continue; - } - - let nal_type = nal.data[0] & 0x1F; - - // Skip AUD and filler NAL units - if nal_type == h264_nal::AUD || nal_type == h264_nal::FILLER { - continue; - } - - match nal_type { - h264_nal::IDR_SLICE => has_idr = true, - h264_nal::SPS => { - has_sps = true; - *self.h264_params.lock().await = H264ParameterSets { - sps: Some(nal.data.clone().freeze()), - pps: self.h264_params.lock().await.pps.clone(), - }; - } - h264_nal::PPS => { - has_pps = true; - let mut params = self.h264_params.lock().await; - params.pps = Some(nal.data.clone().freeze()); - } - _ => {} - } - - nals.push(nal.data.freeze()); - } - - // Inject cached SPS/PPS before IDR if missing - if has_idr && (!has_sps || !has_pps) { - let params = self.h264_params.lock().await; - let mut injected: Vec = Vec::new(); - - if !has_sps { - if let Some(ref sps) = params.sps { - debug!("Injecting cached H264 SPS"); - injected.push(sps.clone()); - } - } - if !has_pps { - if let Some(ref pps) = params.pps { - debug!("Injecting cached H264 PPS"); - injected.push(pps.clone()); - } - } - - if !injected.is_empty() { - injected.extend(nals); - nals = injected; - } - } - - // Send NAL units - self.send_nal_units(nals, is_keyframe).await - } - - /// Write H265 frame (Annex B format) - async fn write_h265_frame(&self, data: &[u8], is_keyframe: bool) -> Result<()> { - let mut nals: Vec = Vec::new(); - let mut has_vps = false; - let mut has_sps = false; - let mut has_pps = false; - let mut has_idr = false; - - // Parse H265 NAL units manually (H264Reader works for both since format is similar) - let mut i = 0; - while i < data.len() { - // Find start code - let (start_code_len, nal_start) = if i + 4 <= data.len() - && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 0 && data[i + 3] == 1 - { - (4, i + 4) - } else if i + 3 <= data.len() - && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 - { - (3, i + 3) - } else { - i += 1; - continue; - }; - - if nal_start >= data.len() { - break; - } - - // Find end of NAL unit (next start code or end of data) - let mut nal_end = data.len(); - let mut j = nal_start + 1; - while j + 3 <= data.len() { - if (data[j] == 0 && data[j + 1] == 0 && data[j + 2] == 1) - || (j + 4 <= data.len() && data[j] == 0 && data[j + 1] == 0 - && data[j + 2] == 0 && data[j + 3] == 1) - { - nal_end = j; - break; - } - j += 1; - } - - let nal_data = &data[nal_start..nal_end]; - if nal_data.is_empty() { - i = nal_end; - continue; - } - - // H265 NAL type: (first_byte >> 1) & 0x3F - let nal_type = (nal_data[0] >> 1) & 0x3F; - - // Skip AUD and filler - if nal_type == h265_nal::AUD || nal_type == h265_nal::FD_NUT { - i = nal_end; - continue; - } - - match nal_type { - h265_nal::VPS => { - has_vps = true; - let mut params = self.h265_params.lock().await; - params.vps = Some(Bytes::copy_from_slice(nal_data)); - } - h265_nal::SPS => { - has_sps = true; - let mut params = self.h265_params.lock().await; - params.sps = Some(Bytes::copy_from_slice(nal_data)); - } - h265_nal::PPS => { - has_pps = true; - let mut params = self.h265_params.lock().await; - params.pps = Some(Bytes::copy_from_slice(nal_data)); - } - _ if h265_nal::is_idr(nal_type) => { - has_idr = true; - } - _ => {} - } - - trace!("H265 NAL: type={} size={}", nal_type, nal_data.len()); - nals.push(Bytes::copy_from_slice(nal_data)); - i = nal_end; - } - - // Inject cached VPS/SPS/PPS before IDR if missing - if has_idr && (!has_vps || !has_sps || !has_pps) { - let params = self.h265_params.lock().await; - let mut injected: Vec = Vec::new(); - - if !has_vps { - if let Some(ref vps) = params.vps { - debug!("Injecting cached H265 VPS"); - injected.push(vps.clone()); - } - } - if !has_sps { - if let Some(ref sps) = params.sps { - debug!("Injecting cached H265 SPS"); - injected.push(sps.clone()); - } - } - if !has_pps { - if let Some(ref pps) = params.pps { - debug!("Injecting cached H265 PPS"); - injected.push(pps.clone()); - } - } - - if !injected.is_empty() { - injected.extend(nals); - nals = injected; - } - } - - self.send_nal_units(nals, is_keyframe).await - } - - /// Write VP8 frame (raw encoded) - async fn write_vp8_frame(&self, data: &[u8], is_keyframe: bool) -> Result<()> { - // Calculate frame duration based on configured FPS - let frame_duration = Duration::from_micros(1_000_000 / self.config.fps.max(1) as u64); - - // VP8 frames are sent directly - let sample = Sample { - data: Bytes::copy_from_slice(data), - duration: frame_duration, - ..Default::default() - }; - - if let Err(e) = self.track.write_sample(&sample).await { - debug!("VP8 write_sample failed: {}", e); - } - - trace!("VP8 frame: {} bytes, keyframe={}", data.len(), is_keyframe); - Ok(()) - } - - /// Write VP9 frame (raw encoded) - async fn write_vp9_frame(&self, data: &[u8], is_keyframe: bool) -> Result<()> { - // Calculate frame duration based on configured FPS - let frame_duration = Duration::from_micros(1_000_000 / self.config.fps.max(1) as u64); - - // VP9 frames are sent directly - let sample = Sample { - data: Bytes::copy_from_slice(data), - duration: frame_duration, - ..Default::default() - }; - - if let Err(e) = self.track.write_sample(&sample).await { - debug!("VP9 write_sample failed: {}", e); - } - - trace!("VP9 frame: {} bytes, keyframe={}", data.len(), is_keyframe); - Ok(()) - } - - /// Send NAL units via track (for H264/H265) - /// - /// Important: Only the last NAL unit should have the frame duration set. - /// All NAL units in a frame share the same RTP timestamp, so only the last - /// one should increment the timestamp by the frame duration. - async fn send_nal_units(&self, nals: Vec, is_keyframe: bool) -> Result<()> { - let mut total_bytes = 0u64; - let nal_count = nals.len(); - // Calculate frame duration based on configured FPS - let frame_duration = Duration::from_micros(1_000_000 / self.config.fps.max(1) as u64); - - for (i, nal_data) in nals.into_iter().enumerate() { - let is_last = i == nal_count - 1; - // Only the last NAL should have duration set - // This ensures all NALs in a frame share the same RTP timestamp - let sample = Sample { - data: nal_data.clone(), - duration: if is_last { frame_duration } else { Duration::ZERO }, - ..Default::default() - }; - - if let Err(e) = self.track.write_sample(&sample).await { - if i % 100 == 0 { - debug!("write_sample failed (no peer?): {}", e); - } - } - - total_bytes += nal_data.len() as u64; - } - - trace!("Sent {} NAL units, {} bytes, keyframe={}", nal_count, total_bytes, is_keyframe); - Ok(()) - } - - /// Get configuration - pub fn config(&self) -> &UnifiedVideoTrackConfig { - &self.config - } -} - -/// Check if VP8 frame is a keyframe -pub fn is_vp8_keyframe(data: &[u8]) -> bool { - if data.is_empty() { - return false; - } - // VP8 keyframe detection: first byte bit 0 is 0 for keyframe - (data[0] & 0x01) == 0 -} - -/// Check if VP9 frame is a keyframe -pub fn is_vp9_keyframe(data: &[u8]) -> bool { - if data.is_empty() { - return false; - } - // VP9 keyframe detection: bit 2 of first byte is 0 for keyframe - (data[0] & 0x04) == 0 -} - -/// Check if H265 frame contains IDR NAL unit -pub fn is_h265_keyframe(data: &[u8]) -> bool { - let mut i = 0; - while i < data.len() { - // Find start code - let nal_start = if i + 4 <= data.len() - && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 0 && data[i + 3] == 1 - { - i + 4 - } else if i + 3 <= data.len() - && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 - { - i + 3 - } else { - i += 1; - continue; - }; - - if nal_start >= data.len() { - break; - } - - // H265 NAL type - let nal_type = (data[nal_start] >> 1) & 0x3F; - if h265_nal::is_idr(nal_type) { - return true; - } - - i = nal_start + 1; - } - false -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_video_codec_mime_types() { - assert_eq!(VideoCodec::H264.mime_type(), "video/H264"); - assert_eq!(VideoCodec::H265.mime_type(), "video/H265"); - assert_eq!(VideoCodec::VP8.mime_type(), "video/VP8"); - assert_eq!(VideoCodec::VP9.mime_type(), "video/VP9"); - } - - #[test] - fn test_h265_nal_type() { - // H265 NAL type is (first_byte >> 1) & 0x3F - // VPS: type 32 = 0x40 >> 1 = 32 - let vps_header = 0x40u8; // VPS type 32 - let nal_type = (vps_header >> 1) & 0x3F; - assert_eq!(nal_type, 32); - - // IDR_W_RADL: type 19 - let idr_header = 0x26u8; // type 19 = 0x13 << 1 = 0x26 - let nal_type = (idr_header >> 1) & 0x3F; - assert_eq!(nal_type, 19); - } - - #[test] - fn test_vp8_keyframe_detection() { - // VP8 keyframe: bit 0 is 0 - assert!(is_vp8_keyframe(&[0x00])); - assert!(!is_vp8_keyframe(&[0x01])); - } -} diff --git a/src/webrtc/universal_session.rs b/src/webrtc/universal_session.rs index b52203c5..2cf41c57 100644 --- a/src/webrtc/universal_session.rs +++ b/src/webrtc/universal_session.rs @@ -1,7 +1,4 @@ -//! Universal WebRTC session with multi-codec support -//! -//! Provides WebRTC sessions that can use any supported video codec (H264, H265, VP8, VP9). -//! Replaces the H264-only H264Session with a more flexible implementation. +//! One browser session: negotiated [`RTCPeerConnection`], outbound video/audio, HID DataChannel. use std::sync::Arc; use std::time::{Duration, Instant}; @@ -36,18 +33,15 @@ use crate::error::{AppError, Result}; use crate::events::{EventBus, SystemEvent}; use crate::hid::datachannel::{parse_hid_message, HidChannelEvent}; use crate::hid::HidController; -use crate::video::encoder::registry::VideoEncoderType; -use crate::video::encoder::BitratePreset; -use crate::video::format::{PixelFormat, Resolution}; -use crate::video::shared_video_pipeline::EncodedVideoFrame; +use crate::video::types::{ + BitratePreset, EncodedVideoFrame, PixelFormat, Resolution, VideoEncoderType, +}; use std::sync::atomic::AtomicBool; use webrtc::ice_transport::ice_gatherer_state::RTCIceGathererState; -/// H.265/HEVC MIME type (RFC 7798) const MIME_TYPE_H265: &str = "video/H265"; fn h264_contains_parameter_sets(data: &[u8]) -> bool { - // Annex-B start code path let mut i = 0usize; while i + 4 <= data.len() { let sc_len = if i + 4 <= data.len() @@ -74,7 +68,6 @@ fn h264_contains_parameter_sets(data: &[u8]) -> bool { i = nal_start.saturating_add(1); } - // Length-prefixed fallback let mut pos = 0usize; while pos + 4 <= data.len() { let nalu_len = @@ -93,22 +86,14 @@ fn h264_contains_parameter_sets(data: &[u8]) -> bool { false } -/// Universal WebRTC session configuration #[derive(Debug, Clone)] pub struct UniversalSessionConfig { - /// WebRTC configuration pub webrtc: WebRtcConfig, - /// Video codec type pub codec: VideoEncoderType, - /// Input resolution pub resolution: Resolution, - /// Input pixel format pub input_format: PixelFormat, - /// Bitrate preset pub bitrate_preset: BitratePreset, - /// Target FPS pub fps: u32, - /// Enable audio track pub audio_enabled: bool, } @@ -127,7 +112,6 @@ impl Default for UniversalSessionConfig { } impl UniversalSessionConfig { - /// Create config for specific codec pub fn with_codec(codec: VideoEncoderType) -> Self { Self { codec, @@ -136,7 +120,6 @@ impl UniversalSessionConfig { } } -/// Convert VideoEncoderType to VideoCodec fn encoder_type_to_video_codec(encoder_type: VideoEncoderType) -> VideoCodec { match encoder_type { VideoEncoderType::H264 => VideoCodec::H264, @@ -146,43 +129,24 @@ fn encoder_type_to_video_codec(encoder_type: VideoEncoderType) -> VideoCodec { } } -/// Universal WebRTC session -/// -/// Receives pre-encoded video frames and sends via WebRTC. -/// Supports H264, H265, VP8, VP9 codecs. pub struct UniversalSession { - /// Session ID pub session_id: String, - /// Video codec type codec: VideoEncoderType, - /// WebRTC peer connection pc: Arc, - /// Video track for RTP packetization video_track: Arc, - /// Opus audio track (optional) audio_track: Option>, - /// Data channel for HID events data_channel: Arc>>>, - /// Connection state state: Arc>, - /// State receiver state_rx: watch::Receiver, - /// ICE candidates gathered ice_candidates: Arc>>, - /// HID controller reference hid_controller: Option>, - /// Event bus for WebRTC signaling events (optional) event_bus: Option>, - /// Video frame receiver handle video_receiver_handle: Mutex>>, - /// Audio frame receiver handle audio_receiver_handle: Mutex>>, - /// FPS configuration fps: u32, } impl UniversalSession { - /// Create a new universal WebRTC session pub async fn new( config: UniversalSessionConfig, session_id: String, @@ -197,7 +161,6 @@ impl UniversalSession { config.audio_enabled ); - // Create video track with appropriate codec let video_codec = encoder_type_to_video_codec(config.codec); let track_config = UniversalVideoTrackConfig { track_id: format!("video-{}", &session_id[..8.min(session_id.len())]), @@ -209,7 +172,6 @@ impl UniversalSession { }; let video_track = Arc::new(UniversalVideoTrack::new(track_config)); - // Create Opus audio track if enabled let audio_track = if config.audio_enabled { Some(Arc::new(OpusAudioTrack::new( &format!("audio-{}", &session_id[..8.min(session_id.len())]), @@ -219,11 +181,9 @@ impl UniversalSession { None }; - // Create media engine let mut media_engine = MediaEngine::default(); - // Register H.265/HEVC codec (not included in default codecs) - // According to RFC 7798, H.265 uses MIME type video/H265 + // H265 is not registered by register_default_codecs. if config.codec == VideoEncoderType::H265 { let video_rtcp_feedback = vec![ RTCPFeedback { @@ -244,8 +204,6 @@ impl UniversalSession { }, ]; - // Register H.265 with profile-id=1 (Main profile) - matches Chrome's offer - // Chrome sends: level-id=180;profile-id=1;tier-flag=0;tx-mode=SRST media_engine .register_codec( RTCRtpCodecParameters { @@ -253,12 +211,11 @@ impl UniversalSession { mime_type: MIME_TYPE_H265.to_owned(), clock_rate: 90000, channels: 0, - // Match browser's fmtp format for profile-id=1 sdp_fmtp_line: "level-id=180;profile-id=1;tier-flag=0;tx-mode=SRST" .to_owned(), rtcp_feedback: video_rtcp_feedback.clone(), }, - payload_type: 49, // Use same payload type as browser + payload_type: 49, ..Default::default() }, RTPCodecType::Video, @@ -267,7 +224,6 @@ impl UniversalSession { AppError::VideoError(format!("Failed to register H.265 codec: {}", e)) })?; - // Also register profile-id=2 (Main 10) variant media_engine .register_codec( RTCRtpCodecParameters { @@ -298,12 +254,10 @@ impl UniversalSession { .register_default_codecs() .map_err(|e| AppError::VideoError(format!("Failed to register codecs: {}", e)))?; - // Create interceptor registry let mut registry = Registry::new(); registry = register_default_interceptors(registry, &mut media_engine) .map_err(|e| AppError::VideoError(format!("Failed to register interceptors: {}", e)))?; - // Create API (with optional mDNS settings) let mut setting_engine = SettingEngine::default(); let mode = mdns_mode(); setting_engine.set_ice_multicast_dns_mode(mode); @@ -318,7 +272,6 @@ impl UniversalSession { .with_interceptor_registry(registry) .build(); - // Build ICE servers let mut ice_servers = vec![]; for stun_url in &config.webrtc.stun_servers { ice_servers.push(RTCIceServer { @@ -342,7 +295,6 @@ impl UniversalSession { }); } - // Create peer connection let rtc_config = RTCConfiguration { ice_servers, ..Default::default() @@ -354,7 +306,6 @@ impl UniversalSession { let pc = Arc::new(pc); - // Add video track to peer connection pc.add_track(video_track.as_track_local()) .await .map_err(|e| AppError::VideoError(format!("Failed to add video track: {}", e)))?; @@ -364,7 +315,6 @@ impl UniversalSession { config.codec, session_id ); - // Add Opus audio track if enabled if let Some(ref audio) = audio_track { pc.add_track(audio.as_track_local()) .await @@ -375,7 +325,6 @@ impl UniversalSession { ); } - // Create state channel let (state_tx, state_rx) = watch::channel(ConnectionState::New); let session = Self { @@ -395,20 +344,17 @@ impl UniversalSession { fps: config.fps, }; - // Set up event handlers session.setup_event_handlers().await; Ok(session) } - /// Set up peer connection event handlers async fn setup_event_handlers(&self) { let state = self.state.clone(); let session_id = self.session_id.clone(); let codec = self.codec; let event_bus = self.event_bus.clone(); - // Connection state change handler self.pc .on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { let state = state.clone(); @@ -430,7 +376,6 @@ impl UniversalSession { }) })); - // ICE connection state handler let session_id_ice = self.session_id.clone(); self.pc .on_ice_connection_state_change(Box::new(move |state| { @@ -440,7 +385,6 @@ impl UniversalSession { }) })); - // ICE gathering state handler let session_id_gather = self.session_id.clone(); let event_bus_gather = event_bus.clone(); self.pc @@ -456,7 +400,6 @@ impl UniversalSession { }) })); - // ICE candidate handler let ice_candidates = self.ice_candidates.clone(); let session_id_candidate = self.session_id.clone(); let event_bus_candidate = event_bus.clone(); @@ -491,14 +434,14 @@ impl UniversalSession { if let Some(bus) = event_bus.as_ref() { bus.publish(SystemEvent::WebRTCIceCandidate { session_id, - candidate, + candidate: serde_json::to_value(&candidate) + .unwrap_or(serde_json::Value::Null), }); } } }) })); - // Data channel handler let data_channel = self.data_channel.clone(); self.pc .on_data_channel(Box::new(move |dc: Arc| { @@ -516,7 +459,6 @@ impl UniversalSession { })); } - /// Set HID controller for DataChannel HID processing pub fn set_hid_controller(&mut self, hid: Arc) { let hid_clone = hid.clone(); let data_channel = self.data_channel.clone(); @@ -533,8 +475,7 @@ impl UniversalSession { dc.on_message(Box::new(move |msg: DataChannelMessage| { let hid = hid.clone(); - // Immediately spawn task in tokio runtime for low latency - // Don't rely on webrtc-rs to poll the returned Future + // webrtc-rs won't poll this future; spawn HID work for latency. tokio::spawn(async move { if let Some(event) = parse_hid_message(&msg.data) { match event { @@ -557,7 +498,6 @@ impl UniversalSession { } }); - // Return empty future (actual work is spawned above) Box::pin(async {}) })); }) @@ -566,7 +506,6 @@ impl UniversalSession { self.hid_controller = Some(hid); } - /// Create data channel for HID events pub async fn create_data_channel(&self, label: &str) -> Result<()> { let dc = self .pc @@ -580,10 +519,7 @@ impl UniversalSession { Ok(()) } - /// Start receiving encoded video frames from shared pipeline - /// - /// The `on_connected` callback is called when ICE connection is established, - /// allowing the caller to request a keyframe at the right time. + /// `on_connected` runs once ICE is up (e.g. request a keyframe). pub async fn start_from_video_pipeline( &self, mut frame_rx: tokio::sync::mpsc::Receiver>, @@ -610,7 +546,6 @@ impl UniversalSession { session_id ); - // Wait for Connected state before sending frames loop { let current_state = *state_rx.borrow(); if current_state == ConnectionState::Connected { @@ -633,7 +568,6 @@ impl UniversalSession { session_id ); - // Request keyframe now that connection is established request_keyframe(); let mut waiting_for_keyframe = true; let mut last_sequence: Option = None; @@ -665,14 +599,12 @@ impl UniversalSession { } }; - // Verify codec matches let frame_codec = encoded_frame.codec; if frame_codec != expected_codec { continue; } - // Debug log for H265 frames if expected_codec == VideoEncoderType::H265 && (encoded_frame.is_keyframe || frames_sent.is_multiple_of(30)) { debug!( @@ -684,7 +616,6 @@ impl UniversalSession { ); } - // Ensure decoder starts from a keyframe and recover on gaps. let mut gap_detected = false; if let Some(prev) = last_sequence { if encoded_frame.sequence > prev.saturating_add(1) { @@ -721,7 +652,6 @@ impl UniversalSession { let _ = send_in_flight; - // Send encoded frame via RTP (drop if previous send is still in flight) let send_result = video_track .write_frame_bytes( encoded_frame.data.clone(), @@ -730,9 +660,7 @@ impl UniversalSession { .await; let _ = send_in_flight; - if send_result.is_err() { - // Keep quiet unless debugging send failures elsewhere - } else { + if send_result.is_ok() { frames_sent += 1; last_sequence = Some(encoded_frame.sequence); } @@ -749,7 +677,6 @@ impl UniversalSession { *self.video_receiver_handle.lock().await = Some(handle); } - /// Start receiving Opus audio frames pub async fn start_audio_from_opus( &self, mut opus_rx: tokio::sync::mpsc::Receiver>, @@ -768,7 +695,6 @@ impl UniversalSession { let session_id = self.session_id.clone(); let handle = tokio::spawn(async move { - // Wait for Connected state before sending audio loop { let current_state = *state_rx.borrow(); if current_state == ConnectionState::Connected { @@ -817,7 +743,6 @@ impl UniversalSession { } }; - // 20ms at 48kHz = 960 samples let samples = 960u32; if let Err(e) = audio_track.write_packet(&opus_frame.data, samples).await { if packets_sent.is_multiple_of(100) { @@ -839,19 +764,15 @@ impl UniversalSession { *self.audio_receiver_handle.lock().await = Some(handle); } - /// Check if audio is enabled for this session pub fn has_audio(&self) -> bool { self.audio_track.is_some() } - /// Get codec type pub fn codec(&self) -> VideoEncoderType { self.codec } - /// Handle SDP offer and create answer pub async fn handle_offer(&self, offer: SdpOffer) -> Result { - // Log offer for debugging H.265 codec negotiation if self.codec == VideoEncoderType::H265 { let has_h265 = offer.sdp.to_lowercase().contains("h265") || offer.sdp.to_lowercase().contains("hevc"); @@ -877,7 +798,6 @@ impl UniversalSession { .await .map_err(|e| AppError::VideoError(format!("Failed to create answer: {}", e)))?; - // Log answer for debugging if self.codec == VideoEncoderType::H265 { let has_h265 = answer.sdp.to_lowercase().contains("h265") || answer.sdp.to_lowercase().contains("hevc"); @@ -897,7 +817,6 @@ impl UniversalSession { .await .map_err(|e| AppError::VideoError(format!("Failed to set local description: {}", e)))?; - // Wait for ICE gathering complete (or timeout) to return a fuller initial candidate set. const ICE_GATHER_TIMEOUT: Duration = Duration::from_millis(2500); if tokio::time::timeout(ICE_GATHER_TIMEOUT, gather_complete.recv()) .await @@ -913,7 +832,6 @@ impl UniversalSession { Ok(SdpAnswer::with_candidates(answer.sdp, candidates)) } - /// Add ICE candidate pub async fn add_ice_candidate(&self, candidate: IceCandidate) -> Result<()> { use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; @@ -932,29 +850,23 @@ impl UniversalSession { Ok(()) } - /// Get current connection state pub fn state(&self) -> ConnectionState { *self.state_rx.borrow() } - /// Subscribe to state changes pub fn state_watch(&self) -> watch::Receiver { self.state_rx.clone() } - /// Close the session pub async fn close(&self) -> Result<()> { - // Stop video receiver if let Some(handle) = self.video_receiver_handle.lock().await.take() { handle.abort(); } - // Stop audio receiver if let Some(handle) = self.audio_receiver_handle.lock().await.take() { handle.abort(); } - // Close peer connection self.pc .close() .await @@ -967,7 +879,6 @@ impl UniversalSession { } } -/// Session info for listing #[derive(Debug, Clone)] pub struct UniversalSessionInfo { pub session_id: String, diff --git a/src/webrtc/video_track.rs b/src/webrtc/video_track.rs index 7fe5a8b3..a2d5967a 100644 --- a/src/webrtc/video_track.rs +++ b/src/webrtc/video_track.rs @@ -1,20 +1,4 @@ -//! Universal video track for WebRTC streaming -//! -//! Supports multiple codecs: H264, H265, VP8, VP9 -//! -//! # Architecture -//! -//! ```text -//! Encoded Frame (H264/H265/VP8/VP9) -//! | -//! v -//! UniversalVideoTrack -//! - H264/VP8/VP9: TrackLocalStaticSample (built-in payloader) -//! - H265: TrackLocalStaticRTP (rtp crate HevcPayloader) -//! | -//! v -//! WebRTC PeerConnection -//! ``` +//! Multiplex H264/VP8/VP9 (`TrackLocalStaticSample`) vs H265 (`TrackLocalStaticRTP` + [`H265Payloader`]). use bytes::Bytes; use std::sync::Arc; @@ -27,33 +11,23 @@ use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; -// Use our custom H265Payloader that handles ALL NAL types correctly -// The rtp crate's HevcPayloader has bugs: -// 1. It drops the IDR frame after emitting the AP packet -// 2. It ignores NAL type 20 (IDR_N_LP) +// rtp `HevcPayloader` mishandles AP+IDR and NAL 20 (`IDR_N_LP`). use super::h265_payloader::H265Payloader; use crate::error::Result; -use crate::video::format::Resolution; +use crate::video::types::Resolution; -/// Default MTU for RTP packets const RTP_MTU: usize = 1200; -/// Video codec type for WebRTC #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum VideoCodec { - /// H.264/AVC H264, - /// H.265/HEVC H265, - /// VP8 VP8, - /// VP9 VP9, } impl VideoCodec { - /// Get MIME type for SDP pub fn mime_type(&self) -> &'static str { match self { VideoCodec::H264 => "video/H264", @@ -63,12 +37,10 @@ impl VideoCodec { } } - /// Get RTP clock rate (always 90kHz for video) pub fn clock_rate(&self) -> u32 { 90000 } - /// Get default RTP payload type pub fn default_payload_type(&self) -> u8 { match self { VideoCodec::H264 => 96, @@ -78,14 +50,12 @@ impl VideoCodec { } } - /// Get SDP fmtp parameters pub fn sdp_fmtp(&self) -> String { match self { VideoCodec::H264 => { "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f".to_string() } VideoCodec::H265 => { - // Match Chrome's H.265 fmtp format: level-id=180 (Level 6.0), profile-id=1 (Main), tier-flag=0, tx-mode=SRST "level-id=180;profile-id=1;tier-flag=0;tx-mode=SRST".to_string() } VideoCodec::VP8 => String::new(), @@ -93,7 +63,6 @@ impl VideoCodec { } } - /// Get display name pub fn display_name(&self) -> &'static str { match self { VideoCodec::H264 => "H.264", @@ -110,20 +79,13 @@ impl std::fmt::Display for VideoCodec { } } -/// Universal video track configuration #[derive(Debug, Clone)] pub struct UniversalVideoTrackConfig { - /// Track ID pub track_id: String, - /// Stream ID pub stream_id: String, - /// Video codec pub codec: VideoCodec, - /// Resolution pub resolution: Resolution, - /// Target bitrate in kbps pub bitrate_kbps: u32, - /// Frames per second pub fps: u32, } @@ -141,7 +103,6 @@ impl Default for UniversalVideoTrackConfig { } impl UniversalVideoTrackConfig { - /// Create H264 config pub fn h264(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self { Self { codec: VideoCodec::H264, @@ -152,7 +113,6 @@ impl UniversalVideoTrackConfig { } } - /// Create H265 config pub fn h265(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self { Self { codec: VideoCodec::H265, @@ -163,7 +123,6 @@ impl UniversalVideoTrackConfig { } } - /// Create VP8 config pub fn vp8(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self { Self { codec: VideoCodec::VP8, @@ -174,7 +133,6 @@ impl UniversalVideoTrackConfig { } } - /// Create VP9 config pub fn vp9(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self { Self { codec: VideoCodec::VP9, @@ -186,40 +144,26 @@ impl UniversalVideoTrackConfig { } } -/// Track type wrapper to support different underlying track implementations enum TrackType { - /// Sample-based track with built-in payloader (H264, VP8, VP9) Sample(Arc), - /// RTP-based track with custom payloader (H265) Rtp(Arc), } -/// H265-specific RTP state struct H265RtpState { - /// H265 payloader (custom implementation that handles all NAL types) payloader: H265Payloader, - /// Current sequence number sequence_number: u16, - /// Current RTP timestamp timestamp: u32, - /// Timestamp increment per frame (90000 / fps) timestamp_increment: u32, } -/// Universal video track supporting H264/H265/VP8/VP9 pub struct UniversalVideoTrack { - /// Underlying WebRTC track (Sample or RTP based) track: TrackType, - /// Codec type codec: VideoCodec, - /// Configuration config: UniversalVideoTrackConfig, - /// H265 RTP state (only used for H265) h265_state: Option>, } impl UniversalVideoTrack { - /// Create a new universal video track pub fn new(config: UniversalVideoTrackConfig) -> Self { let codec_capability = RTCRtpCodecCapability { mime_type: config.codec.mime_type().to_string(), @@ -229,16 +173,13 @@ impl UniversalVideoTrack { rtcp_feedback: vec![], }; - // Use different track types for different codecs let (track, h265_state) = if config.codec == VideoCodec::H265 { - // H265 uses TrackLocalStaticRTP with official rtp crate HevcPayloader let rtp_track = Arc::new(TrackLocalStaticRTP::new( codec_capability, config.track_id.clone(), config.stream_id.clone(), )); - // Create H265 RTP state with custom H265Payloader let h265_state = H265RtpState { payloader: H265Payloader::new(), sequence_number: rand::random::(), @@ -248,7 +189,6 @@ impl UniversalVideoTrack { (TrackType::Rtp(rtp_track), Some(Mutex::new(h265_state))) } else { - // H264/VP8/VP9 use TrackLocalStaticSample with built-in payloader let sample_track = Arc::new(TrackLocalStaticSample::new( codec_capability, config.track_id.clone(), @@ -266,7 +206,6 @@ impl UniversalVideoTrack { } } - /// Get track as TrackLocal for peer connection pub fn as_track_local(&self) -> Arc { match &self.track { TrackType::Sample(t) => t.clone(), @@ -274,23 +213,14 @@ impl UniversalVideoTrack { } } - /// Get codec type pub fn codec(&self) -> VideoCodec { self.codec } - /// Get configuration pub fn config(&self) -> &UniversalVideoTrackConfig { &self.config } - /// Get current statistics - /// - /// Write an encoded frame to the track - /// - /// Handles codec-specific processing: - /// - H264/H265: NAL unit parsing, parameter caching - /// - VP8/VP9: Direct frame sending pub async fn write_frame_bytes(&self, data: Bytes, is_keyframe: bool) -> Result<()> { if data.is_empty() { return Ok(()); @@ -309,17 +239,8 @@ impl UniversalVideoTrack { .await } - /// Write H264 frame (Annex B format) - /// - /// Sends the entire Annex B frame as a single Sample to allow the - /// H264Payloader to aggregate SPS+PPS into STAP-A packets. + /// One Annex-B AU per sample so the stack can STAP/FU internally. async fn write_h264_frame(&self, data: Bytes, _is_keyframe: bool) -> Result<()> { - // Send entire Annex B frame as one Sample - // The H264Payloader in rtp crate will: - // 1. Parse NAL units from Annex B format - // 2. Cache SPS and PPS - // 3. Aggregate SPS+PPS+IDR into STAP-A when possible - // 4. Fragment large NALs using FU-A let frame_duration = Duration::from_micros(1_000_000 / self.config.fps.max(1) as u64); let sample = Sample { data, @@ -341,19 +262,11 @@ impl UniversalVideoTrack { Ok(()) } - /// Write H265 frame (Annex B format) - /// - /// Pass raw Annex B data directly to the official HevcPayloader. - /// The payloader handles NAL parsing, VPS/SPS/PPS caching, AP generation, and FU fragmentation. async fn write_h265_frame(&self, data: Bytes, is_keyframe: bool) -> Result<()> { - // Pass raw Annex B data directly to the official HevcPayloader self.send_h265_rtp(data, is_keyframe).await } - /// Write VP8 frame async fn write_vp8_frame(&self, data: Bytes, _is_keyframe: bool) -> Result<()> { - // VP8 frames are sent directly without NAL parsing - // Calculate frame duration based on configured FPS let frame_duration = Duration::from_micros(1_000_000 / self.config.fps.max(1) as u64); let sample = Sample { data, @@ -375,10 +288,7 @@ impl UniversalVideoTrack { Ok(()) } - /// Write VP9 frame async fn write_vp9_frame(&self, data: Bytes, _is_keyframe: bool) -> Result<()> { - // VP9 frames are sent directly without NAL parsing - // Calculate frame duration based on configured FPS let frame_duration = Duration::from_micros(1_000_000 / self.config.fps.max(1) as u64); let sample = Sample { data, @@ -400,7 +310,6 @@ impl UniversalVideoTrack { Ok(()) } - /// Send H265 NAL units via custom H265Payloader async fn send_h265_rtp(&self, payload: Bytes, _is_keyframe: bool) -> Result<()> { let rtp_track = match &self.track { TrackType::Rtp(t) => t, @@ -418,11 +327,10 @@ impl UniversalVideoTrack { } }; - // Minimize lock hold time: only hold lock during payload generation and state update + // Lock only around payloader + seq/ts bump, not RTP write. let (payloads, timestamp, seq_start, num_payloads) = { let mut state = h265_state.lock().await; - // Use custom H265Payloader to fragment the data let payloads = state.payloader.payload(RTP_MTU, &payload); if payloads.is_empty() { @@ -433,19 +341,16 @@ impl UniversalVideoTrack { let num_payloads = payloads.len(); let seq_start = state.sequence_number; - // Pre-increment sequence number and timestamp state.sequence_number = state.sequence_number.wrapping_add(num_payloads as u16); state.timestamp = state.timestamp.wrapping_add(state.timestamp_increment); (payloads, timestamp, seq_start, num_payloads) - }; // Lock released here, before network I/O + }; - // Send RTP packets without holding the lock for (i, payload_data) in payloads.into_iter().enumerate() { let seq = seq_start.wrapping_add(i as u16); let is_last = i == num_payloads - 1; - // Build RTP packet let packet = rtp::packet::Packet { header: rtp::header::Header { version: 2, diff --git a/src/webrtc/webrtc_streamer.rs b/src/webrtc/webrtc_streamer.rs index 0f0fedc2..a09bca97 100644 --- a/src/webrtc/webrtc_streamer.rs +++ b/src/webrtc/webrtc_streamer.rs @@ -1,36 +1,8 @@ -//! WebRTC Streamer - High-level WebRTC streaming manager -//! -//! This module provides a unified interface for WebRTC streaming mode, -//! supporting multiple video codecs (H264, VP8, VP9, H265) and audio (Opus). -//! -//! # Architecture -//! -//! ```text -//! WebRtcStreamer -//! | -//! +-- Video Pipeline -//! | +-- SharedVideoPipeline (single encoder for all sessions) -//! | +-- H264 Encoder -//! | +-- H265 Encoder (hardware only) -//! | +-- VP8 Encoder (hardware only - VAAPI) -//! | +-- VP9 Encoder (hardware only - VAAPI) -//! | -//! +-- UniversalSession[] (video + audio tracks + DataChannel) -//! +-- UniversalVideoTrack (H264/H265/VP8/VP9) -//! +-- Audio Track (RTP/Opus) -//! +-- DataChannel (HID) -//! ``` -//! -//! # Key Features -//! -//! - **Single encoder**: All sessions share one video encoder -//! - **Multi-codec support**: H264, H265, VP8, VP9 -//! - **Audio support**: Opus audio streaming via AudioController -//! - **HID via DataChannel**: Keyboard/mouse events through WebRTC DataChannel +//! [`WebRtcStreamer`]: shared [`SharedVideoPipeline`], multiplex [`UniversalSession`] (video/audio/HID DC). use std::collections::HashMap; use std::path::PathBuf; -use std::sync::Arc; +use std::sync::{Arc, RwLock as StdRwLock}; use tokio::sync::RwLock; use tracing::{debug, info, trace, warn}; @@ -38,11 +10,8 @@ use crate::audio::{AudioController, OpusFrame}; use crate::error::{AppError, Result}; use crate::events::{EventBus, SystemEvent}; use crate::hid::HidController; -use crate::video::encoder::registry::EncoderBackend; -use crate::video::encoder::registry::VideoEncoderType; -use crate::video::encoder::VideoCodecType; -use crate::video::format::{PixelFormat, Resolution}; -use crate::video::shared_video_pipeline::{ +use crate::video::types::{ + BitratePreset, EncoderBackend, PixelFormat, Resolution, VideoCodecType, VideoEncoderType, PipelineStateNotification, SharedVideoPipeline, SharedVideoPipelineConfig, SharedVideoPipelineStats, }; @@ -50,26 +19,16 @@ use crate::video::shared_video_pipeline::{ use super::config::{TurnServer, WebRtcConfig}; use super::signaling::{ConnectionState, IceCandidate, SdpAnswer, SdpOffer}; use super::universal_session::{UniversalSession, UniversalSessionConfig}; -use crate::video::encoder::BitratePreset; -/// WebRTC streamer configuration #[derive(Debug, Clone)] pub struct WebRtcStreamerConfig { - /// WebRTC configuration (STUN/TURN servers, etc.) pub webrtc: WebRtcConfig, - /// Video codec type pub video_codec: VideoCodecType, - /// Input resolution pub resolution: Resolution, - /// Input pixel format pub input_format: PixelFormat, - /// Bitrate preset pub bitrate_preset: BitratePreset, - /// Target FPS pub fps: u32, - /// Enable audio (reserved) pub audio_enabled: bool, - /// Encoder backend (None = auto select best available) pub encoder_backend: Option, } @@ -88,7 +47,6 @@ impl Default for WebRtcStreamerConfig { } } -/// Capture device configuration for direct capture pipeline #[derive(Debug, Clone)] pub struct CaptureDeviceConfig { pub device_path: PathBuf, @@ -100,26 +58,19 @@ pub struct CaptureDeviceConfig { pub v4l2_driver: Option, } -/// WebRTC streamer statistics #[derive(Debug, Clone, Default)] pub struct WebRtcStreamerStats { - /// Number of active sessions pub session_count: usize, - /// Current video codec pub video_codec: String, - /// Video pipeline stats (if available) pub video_pipeline: Option, - /// Audio enabled pub audio_enabled: bool, } -/// Video pipeline statistics #[derive(Debug, Clone, Default)] pub struct VideoPipelineStats { pub current_fps: f32, } -/// Session info for listing #[derive(Debug, Clone)] pub struct SessionInfo { pub session_id: String, @@ -127,47 +78,26 @@ pub struct SessionInfo { pub state: String, } -/// WebRTC Streamer -/// -/// High-level manager for WebRTC streaming, supporting multiple video codecs -/// and audio streaming via Opus. pub struct WebRtcStreamer { - /// Current configuration config: RwLock, - - // === Video === - /// Current video codec type video_codec: RwLock, - /// Universal video pipeline (for all codecs) video_pipeline: RwLock>>, - /// All sessions (unified management) sessions: Arc>>>, - /// Capture device configuration for direct capture mode capture_device: RwLock>, - - // === Audio === - /// Audio enabled flag audio_enabled: RwLock, - /// Audio controller reference audio_controller: RwLock>>, - - // === Controllers === - /// HID controller for DataChannel hid_controller: RwLock>>, - - /// Event bus for WebRTC signaling (optional) events: RwLock>>, + self_weak: StdRwLock>>, } impl WebRtcStreamer { - /// Create a new WebRTC streamer pub fn new() -> Arc { Self::with_config(WebRtcStreamerConfig::default()) } - /// Create a new WebRTC streamer with configuration pub fn with_config(config: WebRtcStreamerConfig) -> Arc { - Arc::new(Self { + let streamer = Arc::new(Self { config: RwLock::new(config.clone()), video_codec: RwLock::new(config.video_codec), video_pipeline: RwLock::new(None), @@ -177,12 +107,22 @@ impl WebRtcStreamer { audio_controller: RwLock::new(None), hid_controller: RwLock::new(None), events: RwLock::new(None), - }) + self_weak: StdRwLock::new(None), + }); + let weak = Arc::downgrade(&streamer); + *streamer.self_weak.write().expect("self_weak write") = Some(weak); + streamer } - // === Video Codec Management === + fn self_weak(&self) -> std::sync::Weak { + self.self_weak + .read() + .expect("self_weak read") + .as_ref() + .expect("WebRtcStreamer: self_weak must be initialized in with_config") + .clone() + } - /// Get current video codec type pub async fn current_video_codec(&self) -> VideoCodecType { *self.video_codec.read().await } @@ -191,7 +131,7 @@ impl WebRtcStreamer { /// /// Supports H264, H265, VP8, VP9. This will restart the video pipeline /// and close all existing sessions. - pub async fn set_video_codec(self: &Arc, codec: VideoCodecType) -> Result<()> { + pub async fn set_video_codec(&self, codec: VideoCodecType) -> Result<()> { let current = *self.video_codec.read().await; if current == codec { return Ok(()); @@ -310,7 +250,7 @@ impl WebRtcStreamer { } async fn reconnect_sessions_to_current_pipeline( - self: &Arc, + &self, reason: &str, ) -> Result { if self.capture_device.read().await.is_none() { @@ -347,7 +287,7 @@ impl WebRtcStreamer { } /// Ensure video pipeline is initialized and running - async fn ensure_video_pipeline(self: &Arc) -> Result> { + async fn ensure_video_pipeline(&self) -> Result> { let mut pipeline_guard = self.video_pipeline.write().await; if let Some(ref pipeline) = *pipeline_guard { @@ -395,7 +335,7 @@ impl WebRtcStreamer { // Start a monitor task to detect when pipeline auto-stops let pipeline_weak = Arc::downgrade(&pipeline); - let streamer_weak = Arc::downgrade(self); + let streamer_weak = self.self_weak(); let mut running_rx = pipeline.running_watch(); tokio::spawn(async move { @@ -475,7 +415,7 @@ impl WebRtcStreamer { /// This is a public wrapper around ensure_video_pipeline for external /// components (like RustDesk) that need to share the encoded video stream. pub async fn ensure_video_pipeline_for_external( - self: &Arc, + &self, ) -> Result> { self.ensure_video_pipeline().await } @@ -550,7 +490,7 @@ impl WebRtcStreamer { &self, ) -> Option>> { if let Some(ref controller) = *self.audio_controller.read().await { - controller.subscribe_opus_async().await + controller.subscribe_opus().await } else { None } @@ -564,7 +504,7 @@ impl WebRtcStreamer { for (session_id, session) in sessions.iter() { if session.has_audio() { info!("Reconnecting audio for session {}", session_id); - if let Some(rx) = controller.subscribe_opus_async().await { + if let Some(rx) = controller.subscribe_opus().await { session.start_audio_from_opus(rx).await; } } @@ -832,7 +772,7 @@ impl WebRtcStreamer { // === Session Management === /// Create a new WebRTC session - pub async fn create_session(self: &Arc) -> Result { + pub async fn create_session(&self) -> Result { let session_id = uuid::Uuid::new_v4().to_string(); let codec = *self.video_codec.read().await; @@ -881,7 +821,7 @@ impl WebRtcStreamer { // Start audio if enabled if session_config.audio_enabled { if let Some(ref controller) = *self.audio_controller.read().await { - if let Some(opus_rx) = controller.subscribe_opus_async().await { + if let Some(opus_rx) = controller.subscribe_opus().await { session.start_audio_from_opus(opus_rx).await; } } @@ -1068,7 +1008,7 @@ impl WebRtcStreamer { /// /// Note: Hardware encoders (VAAPI, NVENC, etc.) don't support dynamic bitrate changes. /// This method restarts the pipeline to apply the new bitrate only if the preset actually changed. - pub async fn set_bitrate_preset(self: &Arc, preset: BitratePreset) -> Result<()> { + pub async fn set_bitrate_preset(&self, preset: BitratePreset) -> Result<()> { // Check if preset actually changed let current_preset = self.config.read().await.bitrate_preset; if current_preset == preset { @@ -1128,6 +1068,93 @@ impl WebRtcStreamer { } } +#[async_trait::async_trait] +impl crate::video::traits::VideoOutput for WebRtcStreamer { + async fn set_event_bus(&self, events: Arc) { + self.set_event_bus(events).await; + } + + async fn update_video_config(&self, resolution: Resolution, format: PixelFormat, fps: u32) { + self.update_video_config(resolution, format, fps).await; + } + + async fn set_capture_device( + &self, + device_path: PathBuf, + jpeg_quality: u8, + subdev_path: Option, + bridge_kind: Option, + v4l2_driver: Option, + ) { + self.set_capture_device(device_path, jpeg_quality, subdev_path, bridge_kind, v4l2_driver) + .await; + } + + async fn current_video_codec(&self) -> VideoCodecType { + self.current_video_codec().await + } + + async fn is_hardware_encoding(&self) -> bool { + self.is_hardware_encoding().await + } + + async fn close_all_sessions(&self) { + self.close_all_sessions().await; + } + + async fn close_all_sessions_and_release_device(&self) -> usize { + self.close_all_sessions_and_release_device().await + } + + async fn session_count(&self) -> usize { + self.session_count().await + } + + async fn set_hid_controller(&self, hid: Arc) { + self.set_hid_controller(hid).await; + } + + async fn set_audio_enabled(&self, enabled: bool) -> Result<()> { + self.set_audio_enabled(enabled).await + } + + async fn is_audio_enabled(&self) -> bool { + self.is_audio_enabled().await + } + + async fn reconnect_audio_sources(&self) { + self.reconnect_audio_sources().await; + } + + async fn ensure_video_pipeline_for_external(&self) -> Result> { + self.ensure_video_pipeline_for_external().await + } + + async fn get_pipeline_config(&self) -> Option { + self.get_pipeline_config().await + } + + async fn set_video_codec(&self, codec: VideoCodecType) -> Result<()> { + self.set_video_codec(codec).await + } + + async fn set_bitrate_preset(&self, preset: BitratePreset) -> Result<()> { + self.set_bitrate_preset(preset).await + } + + async fn request_keyframe(&self) -> Result<()> { + self.request_keyframe().await + } + + async fn current_video_geometry(&self) -> (Resolution, PixelFormat, u32) { + self.current_video_geometry().await + } + + async fn pipeline_stats(&self) -> Option { + self.pipeline_stats().await + } +} + impl Default for WebRtcStreamer { fn default() -> Self { Self { @@ -1140,6 +1167,7 @@ impl Default for WebRtcStreamer { audio_controller: RwLock::new(None), hid_controller: RwLock::new(None), events: RwLock::new(None), + self_weak: StdRwLock::new(None), } } } diff --git a/web/src/App.vue b/web/src/App.vue index 24f5fa11..d31f7338 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -8,7 +8,6 @@ const router = useRouter() const authStore = useAuthStore() const systemStore = useSystemStore() -// Check for dark mode preference function initTheme() { const stored = localStorage.getItem('theme') if (stored === 'dark' || (!stored && window.matchMedia('(prefers-color-scheme: dark)').matches)) { @@ -16,11 +15,9 @@ function initTheme() { } } -// Initialize on mount onMounted(async () => { initTheme() - // Check setup status try { await authStore.checkSetupStatus() if (authStore.needsSetup) { @@ -28,22 +25,17 @@ onMounted(async () => { return } } catch { - // Continue anyway } - // Check auth status try { await authStore.checkAuth() if (authStore.isAuthenticated) { - // Fetch system info await systemStore.fetchSystemInfo() } } catch { - // Not authenticated } }) -// Listen for dark mode changes watch( () => window.matchMedia('(prefers-color-scheme: dark)').matches, (dark) => { diff --git a/web/src/api/config.ts b/web/src/api/config.ts index e7ecbfc4..573e489a 100644 --- a/web/src/api/config.ts +++ b/web/src/api/config.ts @@ -19,6 +19,7 @@ import type { MsdConfigUpdate, AtxConfig, AtxConfigUpdate, + AtxDevices, AudioConfig, AudioConfigUpdate, ExtensionsStatus, @@ -36,7 +37,6 @@ import type { import { request } from './request' -// ===== 全局配置 API ===== export const configApi = { /** * 获取完整配置 @@ -44,7 +44,6 @@ export const configApi = { getAll: () => request('/config'), } -// ===== Auth 配置 API ===== export const authConfigApi = { /** * 获取认证配置 @@ -62,7 +61,6 @@ export const authConfigApi = { }), } -// ===== Video 配置 API ===== export const videoConfigApi = { /** * 获取视频配置 @@ -80,7 +78,6 @@ export const videoConfigApi = { }), } -// ===== Stream 配置 API ===== export const streamConfigApi = { /** * 获取流配置 @@ -98,7 +95,6 @@ export const streamConfigApi = { }), } -// ===== HID 配置 API ===== export const hidConfigApi = { /** * 获取 HID 配置 @@ -116,7 +112,6 @@ export const hidConfigApi = { }), } -// ===== MSD 配置 API ===== export const msdConfigApi = { /** * 获取 MSD 配置 @@ -134,9 +129,6 @@ export const msdConfigApi = { }), } -// ===== ATX 配置 API ===== -import type { AtxDevices } from '@/types/generated' - export interface WolHistoryEntry { mac_address: string updated_at: number @@ -185,7 +177,6 @@ export const atxConfigApi = { request(`/atx/wol/history?limit=${Math.max(1, Math.min(50, limit))}`), } -// ===== Audio 配置 API ===== export const audioConfigApi = { /** * 获取音频配置 @@ -203,7 +194,6 @@ export const audioConfigApi = { }), } -// ===== Extensions API ===== export const extensionsApi = { /** * 获取所有扩展状态 @@ -265,8 +255,6 @@ export const extensionsApi = { }), } -// ===== RustDesk 配置 API ===== - /** RustDesk 配置响应 */ export interface RustDeskConfigResponse { enabled: boolean @@ -342,8 +330,6 @@ export const rustdeskConfigApi = { }), } -// ===== RTSP 配置 API ===== - export type RtspCodec = 'h264' | 'h265' export interface RtspConfigResponse { @@ -385,9 +371,6 @@ export const rtspConfigApi = { getStatus: () => request('/config/rtsp/status'), } -// ===== Web 服务器配置 API ===== -// `/config/web` 使用 `WebConfigResponse` / `WebConfigUpdate`(由 typeshare 自 Rust 生成)。 - /** REST `/config/web` 响应(`WebConfigResponse` 别名,兼容旧命名) */ export type WebConfig = WebConfigResponse @@ -409,8 +392,6 @@ export const webConfigApi = { }), } -// ===== 系统控制 API ===== - export const systemApi = { /** * 重启系统 diff --git a/web/src/api/index.ts b/web/src/api/index.ts index 776876d0..f51273f8 100644 --- a/web/src/api/index.ts +++ b/web/src/api/index.ts @@ -1,11 +1,9 @@ -// API client for One-KVM backend - import { request, ApiError } from './request' import type { CanonicalKey } from '@/types/generated' +import { useHidWebSocket, type HidKeyboardEvent, type HidMouseEvent } from '@/composables/useHidWebSocket' const API_BASE = '/api' -// Auth API export const authApi = { login: (username: string, password: string) => request<{ success: boolean; message?: string }>( @@ -36,7 +34,6 @@ export const authApi = { }), } -// System API export interface NetworkAddress { interface: string ip: string @@ -149,7 +146,6 @@ export const updateApi = { request('/update/status'), } -// Stream API export interface VideoCodecInfo { id: string name: string @@ -271,7 +267,6 @@ export const streamApi = { }), } -// WebRTC API export interface IceCandidate { candidate: string sdpMid?: string @@ -317,15 +312,9 @@ export const webrtcApi = { request<{ ice_servers: IceServerConfig[]; mdns_mode: string }>('/webrtc/ice-servers'), } -// HID API -// Import HID WebSocket composable -import { useHidWebSocket, type HidKeyboardEvent, type HidMouseEvent } from '@/composables/useHidWebSocket' - -// Create shared HID WebSocket instance const hidWs = useHidWebSocket() let hidWsInitialized = false -// Initialize HID WebSocket connection async function ensureHidConnection() { if (!hidWsInitialized) { hidWsInitialized = true @@ -333,7 +322,6 @@ async function ensureHidConnection() { } } -// Map button string to number function mapButton(button?: 'left' | 'right' | 'middle'): number | undefined { if (!button) return undefined const buttonMap = { left: 0, middle: 1, right: 2 } @@ -403,7 +391,6 @@ export const hidApi = { scroll?: number | null }) => { await ensureHidConnection() - // Ensure all values are properly typed (convert null to undefined) const event: HidMouseEvent = { type: data.type === 'move_abs' ? 'moveabs' : data.type, x: data.x ?? undefined, @@ -424,13 +411,11 @@ export const hidApi = { return { success: true } }, - // WebSocket connection management connectWebSocket: () => hidWs.connect(), disconnectWebSocket: () => hidWs.disconnect(), isWebSocketConnected: () => hidWs.connected.value, } -// ATX API export const atxApi = { status: () => request<{ @@ -448,7 +433,6 @@ export const atxApi = { }), } -// MSD API export interface MsdImage { id: string name: string @@ -485,7 +469,6 @@ export const msdApi = { } }>('/msd/status'), - // Image management listImages: () => request('/msd/images'), uploadImage: async (file: File, onProgress?: (progress: number) => void) => { @@ -528,7 +511,6 @@ export const msdApi = { disconnect: () => request<{ success: boolean }>('/msd/disconnect', { method: 'POST' }), - // Virtual drive driveInfo: () => request<{ size: number @@ -590,7 +572,6 @@ export const msdApi = { method: 'POST', }), - // Download from URL downloadFromUrl: (url: string, filename?: string) => request<{ download_id: string @@ -632,7 +613,6 @@ function sortSerialDevices(serialDevices: SerialDeviceOption[]): SerialDeviceOpt }) } -// Config API /** @deprecated 使用域特定 API(videoConfigApi, hidConfigApi 等)替代 */ export const configApi = { get: () => request>('/config'), @@ -683,7 +663,6 @@ export const configApi = { }, } -// 导出新的域分离配置 API export { authConfigApi, videoConfigApi, @@ -707,7 +686,6 @@ export { type WebConfigUpdate, } from './config' -// 导出生成的类型 export type { AppConfig, AuthConfig, @@ -730,7 +708,6 @@ export type { BitratePreset, } from '@/types/generated' -// Audio API export const audioApi = { status: () => request<{ @@ -765,7 +742,6 @@ export const audioApi = { }), } -// USB API export interface UsbDeviceInfo { bus_num: number dev_num: number diff --git a/web/src/components/ActionBar.vue b/web/src/components/ActionBar.vue index 8685deb1..82d449fd 100644 --- a/web/src/components/ActionBar.vue +++ b/web/src/components/ActionBar.vue @@ -51,7 +51,6 @@ const { t, locale } = useI18n() const router = useRouter() const systemStore = useSystemStore() -// Overflow menu state const overflowMenuOpen = ref(false) const hidBackend = computed(() => (systemStore.hid?.backend ?? '').toLowerCase()) @@ -79,7 +78,6 @@ const emit = defineEmits<{ (e: 'openTerminal'): void }>() -// Desktop toolbar popover/dialog state const pasteOpen = ref(false) const atxOpen = ref(false) const videoPopoverOpen = ref(false) @@ -88,7 +86,6 @@ const audioPopoverOpen = ref(false) const msdDialogOpen = ref(false) const extensionOpen = ref(false) -// Mobile Sheet state const mobileAtxOpen = ref(false) const mobilePasteOpen = ref(false) const mobileAtxOpenTime = ref(0) @@ -117,7 +114,6 @@ const openMobilePaste = () => openFromOverflow(() => { mobilePasteOpenTime.value = Date.now() }) -// ── Adaptive overflow: measure real width, show as many items as fit ── const barRef = ref(null) const measureRef = ref(null) @@ -146,11 +142,9 @@ const ITEM_SPECS: ItemSpec[] = [ { id: 'settings', side: 'right' }, ] -// Measured widths from DOM (icon-only and with-label) const measuredWidths = ref>(new Map()) const measurementReady = ref(false) -// Measure button widths from hidden measurement container const measureButtonWidths = async () => { await nextTick() if (!measureRef.value) return @@ -162,7 +156,6 @@ const measureButtonWidths = async () => { const labelEl = measureRef.value.querySelector(`[data-measure="${spec.id}-label"]`) as HTMLElement if (iconEl && labelEl) { - // Add small buffer (8px) for gaps and rounding errors newWidths.set(spec.id, { icon: Math.ceil(iconEl.offsetWidth) + 8, label: Math.ceil(labelEl.offsetWidth) + 8, @@ -191,17 +184,13 @@ onUnmounted(() => { resizeObserver?.disconnect() }) -// Re-measure when locale changes (different text widths) watch(locale, () => { measurementReady.value = false measureButtonWidths() }) -// Fixed-width budget for always-visible items (right side): -// keyboard + fullscreen + potential overflow button + gaps const RIGHT_FIXED_PX = 120 -// First 3 items (video/audio/hid) are always visible const collapsibleItems = computed(() => { const items = ITEM_SPECS.slice(3).filter(item => { if (item.id === 'msd' && !showMsd.value) return false @@ -210,27 +199,22 @@ const collapsibleItems = computed(() => { return items }) -// Determine which collapsible items are visible (icon-only or with label) const visibleSet = computed(() => { if (!measurementReady.value) { - // Fallback to hardcoded estimates during initial render return new Map() } const available = barWidth.value - RIGHT_FIXED_PX - // Measure actual width of always-visible items (video/audio/hid) let used = 0 if (barRef.value) { const leftContainer = barRef.value.querySelector('.left-buttons') as HTMLElement if (leftContainer) { - // Get width of first 3 children (video/audio/hid) const children = Array.from(leftContainer.children).slice(0, 3) as HTMLElement[] used = children.reduce((sum, el) => sum + el.offsetWidth, 0) } } - // If measurement failed, use estimate if (used === 0) used = 330 const result = new Map() diff --git a/web/src/components/AtxPopover.vue b/web/src/components/AtxPopover.vue index d13a8456..50e5fe6d 100644 --- a/web/src/components/AtxPopover.vue +++ b/web/src/components/AtxPopover.vue @@ -33,14 +33,12 @@ const { t } = useI18n() const activeTab = ref('atx') -// ATX state const powerState = ref<'on' | 'off' | 'unknown'>('unknown') let powerStateTimer: number | null = null // Decouple action data from dialog visibility to prevent race conditions const pendingAction = ref<'short' | 'long' | 'reset' | null>(null) const confirmDialogOpen = ref(false) -// WOL state const wolMacAddress = ref('') const wolHistory = ref([]) const wolSending = ref(false) @@ -95,10 +93,8 @@ const confirmDescription = computed(() => { default: return '' } }) -// MAC address validation const isValidMac = computed(() => { const mac = wolMacAddress.value.trim() - // Support formats: AA:BB:CC:DD:EE:FF or AA-BB-CC-DD-EE-FF or AABBCCDDEEFF const macRegex = /^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$|^([0-9A-Fa-f]{12})$/ return macRegex.test(mac) }) @@ -107,7 +103,6 @@ function sendWol() { if (!isValidMac.value) return wolSending.value = true - // Normalize MAC address let mac = wolMacAddress.value.trim().toUpperCase() if (mac.length === 12) { mac = mac.match(/.{2}/g)!.join(':') @@ -117,7 +112,6 @@ function sendWol() { emit('wol', mac) - // Optimistic update, then sync from server after request likely completes wolHistory.value = [mac, ...wolHistory.value.filter(item => item !== mac)].slice(0, 5) setTimeout(() => { loadWolHistory().catch(() => {}) diff --git a/web/src/components/AudioConfigPopover.vue b/web/src/components/AudioConfigPopover.vue index da267f53..76fd7a0d 100644 --- a/web/src/components/AudioConfigPopover.vue +++ b/web/src/components/AudioConfigPopover.vue @@ -42,10 +42,8 @@ const configStore = useConfigStore() const systemStore = useSystemStore() const unifiedAudio = getUnifiedAudio() -// === Playback Control (immediate effect) === const localVolume = ref([unifiedAudio.volume.value * 100]) -// Volume change - immediate effect, also triggers connection if needed async function handleVolumeChange(value: number[] | undefined) { if (!value || value.length === 0 || value[0] === undefined) return @@ -53,7 +51,6 @@ async function handleVolumeChange(value: number[] | undefined) { unifiedAudio.setVolume(newVolume) localVolume.value = value - // If backend is streaming but audio not connected, connect now (user gesture) if (newVolume > 0 && systemStore.audio?.streaming && !unifiedAudio.connected.value) { console.log('[Audio] User adjusted volume, connecting unified audio') try { @@ -64,17 +61,14 @@ async function handleVolumeChange(value: number[] | undefined) { } } -// === Device Settings (requires apply) === const devices = ref([]) const loadingDevices = ref(false) const applying = ref(false) -// Config values const audioEnabled = ref(false) const selectedDevice = ref('') const selectedQuality = ref<'voice' | 'balanced' | 'high'>('balanced') -// Load device list async function loadDevices() { loadingDevices.value = true try { @@ -87,7 +81,6 @@ async function loadDevices() { } } -// Initialize from current config function initializeFromCurrent() { const audio = configStore.audio if (audio) { @@ -96,46 +89,36 @@ function initializeFromCurrent() { selectedQuality.value = (audio.quality as 'voice' | 'balanced' | 'high') || 'balanced' } - // Sync playback control state localVolume.value = [unifiedAudio.volume.value * 100] } -// Apply device configuration async function applyConfig() { applying.value = true try { - // Update config await configStore.updateAudio({ enabled: audioEnabled.value, device: selectedDevice.value, quality: selectedQuality.value, }) - // If enabled and device is selected, try to start audio stream if (audioEnabled.value && selectedDevice.value) { try { - // Restore default volume BEFORE starting audio - // This ensures handleAudioStateChanged sees the correct volume if (localVolume.value[0] === 0) { localVolume.value = [100] unifiedAudio.setVolume(1) } await audioApi.start() - // ConsoleView will react when system.device_info reflects streaming=true. } catch (startError) { - // Audio start failed - config was saved but streaming not started console.info('[AudioConfig] Audio start failed:', startError) } } else if (!audioEnabled.value) { - // Reset volume to 0 when disabling audio localVolume.value = [0] unifiedAudio.setVolume(0) try { await audioApi.stop() } catch { - // Ignore stop errors } unifiedAudio.disconnect() } @@ -148,7 +131,6 @@ async function applyConfig() { } } -// Watch popover open state watch(() => props.open, (isOpen) => { if (!isOpen) return diff --git a/web/src/components/HidConfigPopover.vue b/web/src/components/HidConfigPopover.vue index 201b8433..2c3ad866 100644 --- a/web/src/components/HidConfigPopover.vue +++ b/web/src/components/HidConfigPopover.vue @@ -54,7 +54,6 @@ function loadMouseMoveSendIntervalFromStorage(): number { ) } -// Mouse Settings (real-time) const mouseThrottle = ref( loadMouseMoveSendIntervalFromStorage() ) @@ -62,9 +61,7 @@ const showCursor = ref( localStorage.getItem('hidShowCursor') !== 'false' // default true ) -// Watch showCursor changes and sync to localStorage + notify ConsoleView watch(showCursor, (newValue, oldValue) => { - // Only sync if value actually changed (avoid triggering on initialization) if (newValue !== oldValue) { localStorage.setItem('hidShowCursor', newValue ? 'true' : 'false') window.dispatchEvent(new CustomEvent('hidCursorVisibilityChanged', { @@ -78,7 +75,6 @@ const hidBackend = ref(HidBackend.None) const devicePath = ref('') const baudrate = ref(9600) -// UI state const applying = ref(false) const loadingDevices = ref(false) @@ -86,7 +82,6 @@ const loadingDevices = ref(false) const serialDevices = ref>([]) const udcDevices = ref>([]) -// Button display text - simplified to just show label const buttonText = computed(() => t('actionbar.hidConfig')) // Available device paths based on backend type @@ -117,9 +112,7 @@ async function loadDevices() { } } -// Initialize from current config function initializeFromCurrent() { - // Re-sync real-time settings from localStorage mouseThrottle.value = loadMouseMoveSendIntervalFromStorage() const storedCursor = localStorage.getItem('hidShowCursor') !== 'false' @@ -140,7 +133,6 @@ function initializeFromCurrent() { } } -// Toggle mouse mode (real-time) function toggleMouseMode() { const newMode = props.mouseMode === 'absolute' ? 'relative' : 'absolute' emit('update:mouseMode', newMode) @@ -154,14 +146,11 @@ function toggleMouseMode() { }) } -// Update mouse throttle (real-time) function handleThrottleChange(value: number[] | undefined) { if (!value || value.length === 0 || value[0] === undefined) return const throttleValue = clampMouseMoveSendIntervalMs(value[0]) mouseThrottle.value = throttleValue - // Save to localStorage localStorage.setItem('hidMouseThrottle', String(throttleValue)) - // Notify ConsoleView (storage event doesn't fire in same tab) window.dispatchEvent(new CustomEvent('hidMouseSendIntervalChanged', { detail: { intervalMs: throttleValue }, })) @@ -191,7 +180,6 @@ function handleDevicePathChange(path: unknown) { devicePath.value = path } -// Handle baudrate change function handleBaudrateChange(rate: unknown) { if (typeof rate !== 'string') return baudrate.value = Number(rate) @@ -219,13 +207,11 @@ async function applyHidConfig() { // HID state will be updated via WebSocket device_info event } catch (e) { console.info('[HidConfig] Failed to apply config:', e) - // Error toast already shown by API layer } finally { applying.value = false } } -// Watch open state watch(() => props.open, (isOpen) => { if (!isOpen) return diff --git a/web/src/components/InfoBar.vue b/web/src/components/InfoBar.vue index 4f0ed987..2c39fb81 100644 --- a/web/src/components/InfoBar.vue +++ b/web/src/components/InfoBar.vue @@ -17,7 +17,6 @@ const props = defineProps<{ const { t } = useI18n() -// Key name mapping for friendly display const keyNameMap: Record = { MetaLeft: 'Win', MetaRight: 'Win', ControlLeft: 'Ctrl', ControlRight: 'Ctrl', diff --git a/web/src/components/MsdDialog.vue b/web/src/components/MsdDialog.vue index 88146247..10bc1eb4 100644 --- a/web/src/components/MsdDialog.vue +++ b/web/src/components/MsdDialog.vue @@ -60,29 +60,23 @@ const { t } = useI18n() const systemStore = useSystemStore() const { on, off } = useWebSocket() -// Tab state const activeTab = ref('images') -// Image state const images = ref([]) const loadingImages = ref(false) const uploadProgress = ref(0) const uploading = ref(false) -// Mount options (using ToggleGroup) const mountMode = ref<'cdrom' | 'flash'>('flash') const accessMode = ref<'readonly' | 'readwrite'>('readonly') -// Computed properties for API compatibility const cdromMode = computed(() => mountMode.value === 'cdrom') const readOnly = computed(() => accessMode.value === 'readonly') -// Operation state flags const connecting = ref(false) const disconnecting = ref(false) const deleting = ref(false) -// Drive state const driveFiles = ref([]) const currentPath = ref('/') const loadingDrive = ref(false) @@ -91,13 +85,11 @@ const driveInitialized = ref(false) const uploadingFile = ref(false) const fileUploadProgress = ref(0) -// Inner dialog state const showDeleteDialog = ref(false) const deleteTarget = ref<{ type: 'image' | 'file'; id: string; name: string } | null>(null) const showNewFolderDialog = ref(false) const newFolderName = ref('') -// Drive init dialog state const showDriveInitDialog = ref(false) const showDeleteDriveDialog = ref(false) const selectedDriveSize = ref(256) // Default 256MB @@ -105,7 +97,6 @@ const customDriveSize = ref(undefined) const initializingDrive = ref(false) const deletingDrive = ref(false) -// URL download state const showUrlDialog = ref(false) const downloadUrl = ref('') const downloadFilename = ref('') @@ -119,14 +110,11 @@ const downloadProgress = ref<{ status: string } | null>(null) -// Constants const TWO_POINT_TWO_GB = 2.2 * 1024 * 1024 * 1024 -// Computed const msdConnected = computed(() => systemStore.msd?.connected ?? false) const msdMode = computed(() => systemStore.msd?.mode ?? 'none') -// Get currently connected image name const connectedImageName = computed(() => { if (!msdConnected.value) return null if (msdMode.value === 'drive') return t('msd.drive') @@ -136,7 +124,6 @@ const connectedImageName = computed(() => { return image?.name ?? null }) -// Check if any operation is in progress const operationInProgress = computed(() => { return connecting.value || disconnecting.value || @@ -147,7 +134,6 @@ const operationInProgress = computed(() => { deletingDrive.value }) -// Check if image is large (>2.2GB) function isLargeFile(image: MsdImage): boolean { return image.size > TWO_POINT_TWO_GB } @@ -163,7 +149,6 @@ const breadcrumbs = computed(() => { return crumbs }) -// Load data when dialog opens watch(() => props.open, async (isOpen) => { if (isOpen) { await loadData() @@ -179,7 +164,6 @@ async function loadData() { } } -// Image functions async function loadImages() { loadingImages.value = true try { @@ -294,7 +278,6 @@ async function executeDelete() { } } -// Drive functions async function loadDriveInfo() { try { driveInfo.value = await msdApi.driveInfo() @@ -304,7 +287,6 @@ async function loadDriveInfo() { } } -// Drive size options - computed for i18n support const driveSizeOptions = computed(() => [ { value: 64, label: '64 MB' }, { value: 128, label: '128 MB' }, @@ -316,17 +298,14 @@ const driveSizeOptions = computed(() => [ { value: 8192, label: '8 GB' }, ]) -// Computed final drive size const finalDriveSize = computed(() => { return customDriveSize.value || selectedDriveSize.value }) -// Open drive init dialog function initializeDrive() { showDriveInitDialog.value = true } -// Create drive with selected size async function createDrive() { initializingDrive.value = true try { @@ -342,7 +321,6 @@ async function createDrive() { } } -// Delete virtual drive async function deleteDrive() { deletingDrive.value = true try { @@ -422,7 +400,6 @@ async function createFolder() { } } -// URL download functions async function startUrlDownload() { if (!downloadUrl.value.trim()) return diff --git a/web/src/components/MsdSheet.vue b/web/src/components/MsdSheet.vue index 985f2c28..17b32069 100644 --- a/web/src/components/MsdSheet.vue +++ b/web/src/components/MsdSheet.vue @@ -53,10 +53,8 @@ const emit = defineEmits<{ const { t } = useI18n() const systemStore = useSystemStore() -// Tab state const activeTab = ref('images') -// Image state const images = ref([]) const loadingImages = ref(false) const uploadProgress = ref(0) @@ -64,7 +62,6 @@ const uploading = ref(false) const cdromMode = ref(true) const readOnly = ref(true) -// Drive state const driveFiles = ref([]) const currentPath = ref('/') const loadingDrive = ref(false) @@ -73,13 +70,11 @@ const driveInitialized = ref(false) const uploadingFile = ref(false) const fileUploadProgress = ref(0) -// Dialog state const showDeleteDialog = ref(false) const deleteTarget = ref<{ type: 'image' | 'file'; id: string; name: string } | null>(null) const showNewFolderDialog = ref(false) const newFolderName = ref('') -// Computed const msdConnected = computed(() => systemStore.msd?.connected ?? false) const msdMode = computed(() => systemStore.msd?.mode ?? 'none') @@ -94,7 +89,6 @@ const breadcrumbs = computed(() => { return crumbs }) -// Load data when sheet opens watch(() => props.open, async (isOpen) => { if (isOpen) { await loadData() @@ -110,7 +104,6 @@ async function loadData() { } } -// Image functions async function loadImages() { loadingImages.value = true try { @@ -186,7 +179,6 @@ async function executeDelete() { } } -// Drive functions async function loadDriveInfo() { try { driveInfo.value = await msdApi.driveInfo() diff --git a/web/src/components/PasteModal.vue b/web/src/components/PasteModal.vue index e3701d61..26b81cb1 100644 --- a/web/src/components/PasteModal.vue +++ b/web/src/components/PasteModal.vue @@ -23,11 +23,8 @@ const currentChar = ref(0) const totalChars = ref(0) const abortController = ref(null) -// Typing speed in milliseconds between characters -// Configurable delay to prevent target system from missing keystrokes const typingDelay = ref(10) -// Text analysis for warning display const textAnalysis = computed(() => { if (!text.value) return null return analyzeText(text.value) @@ -38,14 +35,12 @@ const hasUntypableChars = computed(() => { }) onMounted(() => { - // Auto focus the textarea setTimeout(() => { textareaRef.value?.focus() }, 100) }) onUnmounted(() => { - // Cancel any ongoing paste operation when component is unmounted cancelPaste() }) @@ -65,7 +60,6 @@ async function typeChar(char: string, signal: AbortSignal): Promise { const mapping = charToKey(char) if (!mapping) { - // Skip untypable characters return true } @@ -73,10 +67,8 @@ async function typeChar(char: string, signal: AbortSignal): Promise { const modifier = shift ? 0x02 : 0 try { - // Send keydown await hidApi.keyboard('down', key, modifier) - // Small delay between down and up to ensure key is registered await sleep(5) if (signal.aborted) { @@ -85,20 +77,16 @@ async function typeChar(char: string, signal: AbortSignal): Promise { return false } - // Send keyup await hidApi.keyboard('up', key, modifier) - // Additional small delay after keyup to ensure it's processed await sleep(2) return true } catch (error) { console.error('[Paste] Failed to type character:', char, error) - // Try to release the key even on error try { await hidApi.keyboard('up', key, modifier) } catch { - // Ignore cleanup errors } return false } @@ -133,20 +121,17 @@ async function handlePaste() { currentChar.value = charIndex progress.value = Math.round((charIndex / totalLength) * 100) - // Handle CRLF: skip \r if followed by \n if (char === '\r' && charIndex < totalLength && chars[charIndex] === '\n') { continue } await typeChar(char, signal) - // Delay between characters (configurable) if (typingDelay.value > 0 && charIndex < totalLength) { await sleep(typingDelay.value) } } - // Success - close the modal after a brief delay if (!signal.aborted) { await sleep(200) text.value = '' @@ -155,11 +140,9 @@ async function handlePaste() { } catch (error) { console.error('[Paste] Error during paste operation:', error) } finally { - // Reset HID to ensure no keys are stuck try { await hidApi.reset() } catch { - // Ignore reset errors } isPasting.value = false progress.value = 0 @@ -180,14 +163,12 @@ function cancelPaste() { } function handleKeydown(e: KeyboardEvent) { - // Ctrl/Cmd + Enter to paste if ((e.metaKey || e.ctrlKey) && e.key === 'Enter') { e.preventDefault() if (!isPasting.value) { handlePaste() } } - // Escape to cancel or close if (e.key === 'Escape') { e.preventDefault() if (isPasting.value) { @@ -196,7 +177,6 @@ function handleKeydown(e: KeyboardEvent) { emit('close') } } - // Stop propagation to prevent HID interference e.stopPropagation() } diff --git a/web/src/components/StatsSheet.vue b/web/src/components/StatsSheet.vue index d20cede0..639dc35f 100644 --- a/web/src/components/StatsSheet.vue +++ b/web/src/components/StatsSheet.vue @@ -29,19 +29,16 @@ const emit = defineEmits<{ (e: 'update:open', value: boolean): void }>() -// Chart containers const stabilityChartRef = ref(null) const delayChartRef = ref(null) const packetLossChartRef = ref(null) const fpsChartRef = ref(null) -// Chart instances let stabilityChart: uPlot | null = null let delayChart: uPlot | null = null let packetLossChart: uPlot | null = null let fpsChart: uPlot | null = null -// Data history (last 120 seconds) const MAX_POINTS = 120 const timestamps = ref([]) const jitterHistory = ref([]) @@ -50,7 +47,6 @@ const packetLossHistory = ref([]) const fpsHistory = ref([]) const bitrateHistory = ref([]) -// For delta calculations let lastBytesReceived = 0 let lastPacketsLost = 0 let lastTimestamp = 0 @@ -58,13 +54,11 @@ let lastTimestamp = 0 // Is WebRTC mode const isWebRTC = computed(() => props.videoMode !== 'mjpeg') -// Format time for axis function formatTime(ts: number): string { const date = new Date(ts * 1000) return date.toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit' }) } -// Chart theme colors const chartColors = { line: '#3b82f6', fill: 'rgba(59, 130, 246, 0.1)', @@ -73,7 +67,6 @@ const chartColors = { text: '#94a3b8', } -// Chart options factory function createChartOptions( container: HTMLElement, _yLabel: string, @@ -129,7 +122,6 @@ function createChartOptions( } } -// Tooltip state for each chart const activeTooltip = ref<{ chartId: string time: string @@ -186,12 +178,10 @@ function createTooltipPlugin(chartId: string, unit: string): uPlot.Plugin { } } -// Initialize charts function initCharts() { if (!props.open) return nextTick(() => { - // Initialize timestamps if empty if (timestamps.value.length === 0) { const now = Date.now() / 1000 for (let i = MAX_POINTS - 1; i >= 0; i--) { @@ -204,7 +194,6 @@ function initCharts() { bitrateHistory.value = new Array(MAX_POINTS).fill(0) } - // Network Stability (Jitter) Chart if (stabilityChartRef.value && !stabilityChart) { const opts = createChartOptions(stabilityChartRef.value, 'ms', (v) => `${v.toFixed(0)} ms`) opts.plugins = [createTooltipPlugin('stability', 'ms')] @@ -215,7 +204,6 @@ function initCharts() { ) } - // Playback Delay Chart if (delayChartRef.value && !delayChart) { const opts = createChartOptions(delayChartRef.value, 'ms', (v) => `${v.toFixed(0)} ms`) opts.plugins = [createTooltipPlugin('delay', 'ms')] @@ -237,7 +225,6 @@ function initCharts() { ) } - // FPS Chart if (fpsChartRef.value && !fpsChart) { const opts = createChartOptions(fpsChartRef.value, 'fps', (v) => `${v.toFixed(0)} fps`) opts.plugins = [createTooltipPlugin('fps', 'fps')] @@ -250,7 +237,6 @@ function initCharts() { }) } -// Destroy charts function destroyCharts() { stabilityChart?.destroy() stabilityChart = null @@ -262,22 +248,18 @@ function destroyCharts() { fpsChart = null } -// Add data point function addDataPoint() { const now = Date.now() / 1000 - // Shift timestamps timestamps.value.push(now) if (timestamps.value.length > MAX_POINTS) { timestamps.value.shift() } if (isWebRTC.value && props.webrtcStats) { - // Jitter in ms const jitter = (props.webrtcStats.jitter || 0) * 1000 jitterHistory.value.push(jitter) - // RTT (round trip time) as delay in ms const rtt = (props.webrtcStats.roundTripTime || 0) * 1000 delayHistory.value.push(rtt) @@ -287,10 +269,8 @@ function addDataPoint() { lastPacketsLost = currentLost packetLossHistory.value.push(lostDelta) - // FPS fpsHistory.value.push(props.webrtcStats.framesPerSecond || 0) - // Calculate bitrate const currentBytes = props.webrtcStats.bytesReceived || 0 const currentTime = Date.now() if (lastTimestamp > 0 && currentBytes > lastBytesReceived) { @@ -312,18 +292,15 @@ function addDataPoint() { bitrateHistory.value.push(0) } - // Trim arrays if (jitterHistory.value.length > MAX_POINTS) jitterHistory.value.shift() if (delayHistory.value.length > MAX_POINTS) delayHistory.value.shift() if (packetLossHistory.value.length > MAX_POINTS) packetLossHistory.value.shift() if (fpsHistory.value.length > MAX_POINTS) fpsHistory.value.shift() if (bitrateHistory.value.length > MAX_POINTS) bitrateHistory.value.shift() - // Update charts updateCharts() } -// Update charts with new data function updateCharts() { stabilityChart?.setData([timestamps.value, jitterHistory.value]) delayChart?.setData([timestamps.value, delayHistory.value]) @@ -331,7 +308,6 @@ function updateCharts() { fpsChart?.setData([timestamps.value, fpsHistory.value]) } -// Data collection interval let dataInterval: number | null = null function startDataCollection() { @@ -346,7 +322,6 @@ function stopDataCollection() { } } -// Format candidate type for display function formatCandidateType(type: string): string { const typeMap: Record = { host: 'Host (Local)', @@ -393,10 +368,8 @@ const currentStats = computed(() => { } }) -// Watch open state watch(() => props.open, (isOpen) => { if (isOpen) { - // Reset data timestamps.value = [] jitterHistory.value = [] delayHistory.value = [] @@ -417,7 +390,6 @@ watch(() => props.open, (isOpen) => { } }) -// Resize handler function handleResize() { if (!props.open) return destroyCharts() diff --git a/web/src/components/StatusCard.vue b/web/src/components/StatusCard.vue index 531edca4..8a6527e8 100644 --- a/web/src/components/StatusCard.vue +++ b/web/src/components/StatusCard.vue @@ -92,7 +92,6 @@ const statusIcon = computed(() => { } }) -// Localized status text const statusText = computed(() => { switch (props.status) { case 'connected': @@ -108,7 +107,6 @@ const statusText = computed(() => { } }) -// Localized status badge text (for hover card) const statusBadgeText = computed(() => { switch (props.status) { case 'connected': diff --git a/web/src/components/VideoConfigPopover.vue b/web/src/components/VideoConfigPopover.vue index 21abba3f..831edf06 100644 --- a/web/src/components/VideoConfigPopover.vue +++ b/web/src/components/VideoConfigPopover.vue @@ -68,7 +68,6 @@ const router = useRouter() const devices = ref([]) const loadingDevices = ref(false) -// Codec list const codecs = ref([]) const loadingCodecs = ref(false) @@ -135,7 +134,6 @@ function detectBrowserCodecSupport() { if (capabilities?.codecs) { for (const codec of capabilities.codecs) { const mimeType = codec.mimeType.toLowerCase() - // Map MIME types to our codec IDs if (mimeType.includes('h264') || mimeType.includes('avc')) { supported.add('h264') } @@ -154,7 +152,6 @@ function detectBrowserCodecSupport() { } } } else { - // Fallback: assume basic codecs are supported supported.add('h264') supported.add('vp8') supported.add('vp9') @@ -191,12 +188,10 @@ const translateBackendName = (backend: string | undefined): string => { return backend } -// Check if a format has fps >= 30 in any resolution const hasHighFps = (format: { resolutions: { fps: number[] }[] }): boolean => { return format.resolutions.some(res => res.fps.some(fps => fps >= 30)) } -// Check if a format is recommended based on video mode const isFormatRecommended = (formatName: string): boolean => { if (!isVideoFormatSelectable(formatName, props.videoMode, currentEncoderBackend.value)) { return false @@ -214,20 +209,16 @@ const isFormatRecommended = (formatName: string): boolean => { const currentFormat = formats.find(f => f.format.toUpperCase() === upperFormat) if (!currentFormat) return false - // Check if NV12 exists with fps >= 30 const nv12Format = formats.find(f => f.format.toUpperCase() === 'NV12') const nv12HasHighFps = nv12Format && hasHighFps(nv12Format) - // Check if YUYV exists with fps >= 30 const yuyvFormat = formats.find(f => f.format.toUpperCase() === 'YUYV') const yuyvHasHighFps = yuyvFormat && hasHighFps(yuyvFormat) - // Priority 1: NV12 with high fps if (nv12HasHighFps) { return upperFormat === 'NV12' } - // Priority 2: YUYV with high fps (only if NV12 doesn't qualify) if (yuyvHasHighFps) { return upperFormat === 'YUYV' } @@ -235,13 +226,11 @@ const isFormatRecommended = (formatName: string): boolean => { return false } -// Check if a format is not recommended for current video mode // In WebRTC mode, compressed formats (MJPEG/JPEG) are not recommended const isFormatNotRecommended = (formatName: string): boolean => { return getFormatState(formatName) === 'not_recommended' } -// Selected values (mode comes from props) const selectedDevice = ref('') const selectedFormat = ref('') const selectedResolution = ref('') @@ -249,11 +238,9 @@ const selectedFps = ref(30) const selectedBitratePreset = ref<'Speed' | 'Balanced' | 'Quality'>('Balanced') const isDirty = ref(false) -// UI state const applying = ref(false) const applyingBitrate = ref(false) -// Current config from store const currentConfig = computed(() => ({ device: configStore.video?.device || '', format: configStore.video?.format || '', @@ -262,7 +249,6 @@ const currentConfig = computed(() => ({ fps: configStore.video?.fps || 30, })) -// Button display text - simplified to just show label const buttonText = computed(() => t('actionbar.videoConfig')) // Available codecs for selection (filtered by backend support and enriched with backend info) @@ -305,7 +291,6 @@ const availableCodecs = computed(() => { return backendFiltered.filter(codec => allowed.includes(codec.id)) }) -// Cascading filters const availableFormats = computed(() => { const device = devices.value.find(d => d.path === selectedDevice.value) return device?.formats || [] @@ -331,13 +316,11 @@ const availableFps = computed(() => { return resolution?.fps || [] }) -// Get selected format description for display in trigger const selectedFormatInfo = computed(() => { const format = availableFormatOptions.value.find(f => f.format === selectedFormat.value) return format }) -// Get selected codec info for display in trigger const selectedCodecInfo = computed(() => { const codec = availableCodecs.value.find(c => c.id === props.videoMode) return codec || null @@ -366,7 +349,6 @@ async function loadCodecs() { backends.value = result.backends || [] } catch (e) { console.info('[VideoConfig] Failed to load codecs') - // Fallback to default codecs codecs.value = [ { id: 'mjpeg', name: 'MJPEG / HTTP', protocol: 'http', hardware: false, backend: 'software', available: true }, { id: 'h264', name: 'H.264 / WebRTC', protocol: 'webrtc', hardware: false, backend: 'software', available: true }, @@ -384,12 +366,10 @@ async function loadConstraints() { } } -// Navigate to settings page (video tab) function goToSettings() { router.push('/settings?tab=video') } -// Initialize selected values from current config function initializeFromCurrent() { const config = currentConfig.value selectedDevice.value = config.device @@ -417,7 +397,6 @@ function syncFromCurrentIfChanged() { isDirty.value = false } -// Handle video mode change function handleVideoModeChange(mode: unknown) { if (typeof mode !== 'string') return @@ -476,7 +455,6 @@ function handleDeviceChange(devicePath: unknown) { selectedDevice.value = devicePath isDirty.value = true - // Auto-select first format const device = devices.value.find(d => d.path === devicePath) const format = device ? findFirstSelectableFormat(device.formats) : undefined if (!format) { @@ -487,7 +465,6 @@ function handleDeviceChange(devicePath: unknown) { selectFormatWithDefaults(format.format) } -// Handle format change function handleFormatChange(format: unknown) { if (typeof format !== 'string') return if (isFormatUnsupported(format)) return @@ -496,13 +473,11 @@ function handleFormatChange(format: unknown) { isDirty.value = true } -// Handle resolution change function handleResolutionChange(resolution: unknown) { if (typeof resolution !== 'string') return selectedResolution.value = resolution isDirty.value = true - // Auto-select first FPS for this resolution const resolutionData = availableResolutions.value.find( r => `${r.width}x${r.height}` === resolution ) @@ -511,14 +486,12 @@ function handleResolutionChange(resolution: unknown) { } } -// Handle FPS change function handleFpsChange(fps: unknown) { if (typeof fps !== 'string' && typeof fps !== 'number') return selectedFps.value = typeof fps === 'string' ? Number(fps) : fps isDirty.value = true } -// Apply bitrate preset change async function applyBitratePreset(preset: 'Speed' | 'Balanced' | 'Quality') { if (applyingBitrate.value) return applyingBitrate.value = true @@ -532,7 +505,6 @@ async function applyBitratePreset(preset: 'Speed' | 'Balanced' | 'Quality') { } } -// Handle bitrate preset selection function handleBitratePresetChange(preset: 'Speed' | 'Balanced' | 'Quality') { selectedBitratePreset.value = preset if (props.videoMode !== 'mjpeg') { @@ -540,7 +512,6 @@ function handleBitratePresetChange(preset: 'Speed' | 'Balanced' | 'Quality') { } } -// Apply video configuration async function applyVideoConfig() { const [width, height] = selectedResolution.value.split('x').map(Number) @@ -559,13 +530,11 @@ async function applyVideoConfig() { // Stream state will be updated via WebSocket system.device_info event } catch (e) { console.info('[VideoConfig] Failed to apply config:', e) - // Error toast already shown by API layer } finally { applying.value = false } } -// Watch open state watch(() => props.open, (isOpen) => { if (!isOpen) { isDirty.value = false diff --git a/web/src/components/VirtualKeyboard.vue b/web/src/components/VirtualKeyboard.vue index 5d3324d7..9166abba 100644 --- a/web/src/components/VirtualKeyboard.vue +++ b/web/src/components/VirtualKeyboard.vue @@ -38,20 +38,16 @@ const emit = defineEmits<{ const { t } = useI18n() -// State const isAttached = ref(props.attached ?? true) const selectedOs = ref('windows') -// Keyboard instances const mainKeyboard = ref(null) const controlKeyboard = ref(null) const arrowsKeyboard = ref(null) -// Pressed keys tracking const pressedModifiers = ref(0) const keysDown = ref([]) -// Shift state for display const isShiftActive = computed(() => { return (pressedModifiers.value & 0x22) !== 0 }) @@ -64,7 +60,6 @@ const layoutName = computed(() => { return isShiftActive.value ? 'shift' : 'default' }) -// Keys currently pressed (for highlighting) const keyNamesForDownKeys = computed(() => { const activeModifierMask = pressedModifiers.value || 0 const modifierNames = Object.entries(modifiers) @@ -79,19 +74,15 @@ const keyNamesForDownKeys = computed(() => { ])) }) -// Dragging state (for floating mode) const keyboardRef = ref(null) const isDragging = ref(false) const dragOffset = ref({ x: 0, y: 0 }) const position = ref({ x: 100, y: 100 }) -// Unique ID for this keyboard instance const keyboardId = ref(`kb-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`) -// Get bottom row based on selected OS const getBottomRow = () => osBottomRows[selectedOs.value].join(' ') -// Keyboard layouts - matching JetKVM style const keyboardLayout = { main: { default: [ @@ -143,19 +134,15 @@ function setCompactLayout(active: boolean) { updateKeyboardLayout() } -// Key display mapping with Unicode symbols (JetKVM style) const keyDisplayMap = computed>(() => { - // OS-specific Meta key labels const metaLabel = selectedOs.value === 'windows' ? '⊞ Win' : selectedOs.value === 'mac' ? '⌘ Cmd' : 'Meta' return { - // Macros - compact format CtrlAltDelete: 'Ctrl+Alt+Del', AltMetaEscape: 'Alt+Meta+Esc', CtrlAltBackspace: 'Ctrl+Alt+Bksp', - // Modifiers - simplified ControlLeft: 'Ctrl', ControlRight: 'Ctrl', ShiftLeft: 'Shift', @@ -166,7 +153,6 @@ const keyDisplayMap = computed>(() => { MetaRight: metaLabel, ContextMenu: 'Menu', - // Special keys Escape: 'Esc', Backspace: '⌫', Tab: 'Tab', @@ -174,7 +160,6 @@ const keyDisplayMap = computed>(() => { Enter: 'Enter', Space: ' ', - // Navigation Insert: 'Ins', Delete: 'Del', Home: 'Home', @@ -182,23 +167,19 @@ const keyDisplayMap = computed>(() => { PageUp: 'PgUp', PageDown: 'PgDn', - // Arrows ArrowUp: '↑', ArrowDown: '↓', ArrowLeft: '←', ArrowRight: '→', - // Control cluster PrintScreen: 'PrtSc', ScrollLock: 'ScrLk', Pause: 'Pause', - // Function keys F1: 'F1', F2: 'F2', F3: 'F3', F4: 'F4', F5: 'F5', F6: 'F6', F7: 'F7', F8: 'F8', F9: 'F9', F10: 'F10', F11: 'F11', F12: 'F12', - // Letters KeyA: areLettersUppercase.value ? 'A' : 'a', KeyB: areLettersUppercase.value ? 'B' : 'b', KeyC: areLettersUppercase.value ? 'C' : 'c', @@ -226,7 +207,6 @@ const keyDisplayMap = computed>(() => { KeyY: areLettersUppercase.value ? 'Y' : 'y', KeyZ: areLettersUppercase.value ? 'Z' : 'z', - // Letter labels in the shifted layout follow CapsLock xor Shift too '(KeyA)': areLettersUppercase.value ? 'A' : 'a', '(KeyB)': areLettersUppercase.value ? 'B' : 'b', '(KeyC)': areLettersUppercase.value ? 'C' : 'c', @@ -254,15 +234,12 @@ const keyDisplayMap = computed>(() => { '(KeyY)': areLettersUppercase.value ? 'Y' : 'y', '(KeyZ)': areLettersUppercase.value ? 'Z' : 'z', - // Numbers Digit1: '1', Digit2: '2', Digit3: '3', Digit4: '4', Digit5: '5', Digit6: '6', Digit7: '7', Digit8: '8', Digit9: '9', Digit0: '0', - // Shifted Numbers '(Digit1)': '!', '(Digit2)': '@', '(Digit3)': '#', '(Digit4)': '$', '(Digit5)': '%', '(Digit6)': '^', '(Digit7)': '&', '(Digit8)': '*', '(Digit9)': '(', '(Digit0)': ')', - // Symbols Minus: '-', '(Minus)': '_', Equal: '=', '(Equal)': '+', BracketLeft: '[', '(BracketLeft)': '{', @@ -277,7 +254,6 @@ const keyDisplayMap = computed>(() => { } }) -// Handle media key press (Consumer Control) async function onMediaKeyPress(key: string) { if (key in consumerKeys) { const usage = consumerKeys[key as ConsumerKeyName] @@ -289,16 +265,12 @@ async function onMediaKeyPress(key: string) { } } -// Switch OS layout function switchOsLayout(os: KeyboardOsType) { selectedOs.value = os - // Save preference to localStorage localStorage.setItem('vkb-os-layout', os) - // Update keyboard layout updateKeyboardLayout() } -// Update keyboard layout based on selected OS function updateKeyboardLayout() { const bottomRow = getBottomRow() const baseLayout = isCompactLayout.value ? compactMainLayout : keyboardLayout.main @@ -319,9 +291,7 @@ function updateKeyboardLayout() { updateKeyboardButtonTheme() } -// Key press handler async function onKeyDown(key: string) { - // Handle macro keys if (key === 'CtrlAltDelete') { await executeMacro([ { keys: ['Delete'], modifiers: ['ControlLeft', 'AltLeft'] }, @@ -343,10 +313,8 @@ async function onKeyDown(key: string) { return } - // Clean key name (remove parentheses for shifted keys) const cleanKey = key.replace(/[()]/g, '') - // Check if key exists if (!(cleanKey in keys)) { console.warn(`[VirtualKeyboard] Unknown key: ${cleanKey}`) return @@ -354,7 +322,6 @@ async function onKeyDown(key: string) { const keyCode = keys[cleanKey as KeyName] - // Handle latching keys (Caps Lock, etc.) if (latchingKeys.some(latchingKey => latchingKey === keyCode)) { emit('keyDown', keyCode) const currentMask = pressedModifiers.value & 0xff @@ -366,7 +333,6 @@ async function onKeyDown(key: string) { return } - // Handle modifier keys (toggle) const mask = modifiers[keyCode] ?? 0 if (mask !== 0) { const isCurrentlyDown = (pressedModifiers.value & mask) !== 0 @@ -386,7 +352,6 @@ async function onKeyDown(key: string) { return } - // Regular key: press and release keysDown.value.push(keyCode) emit('keyDown', keyCode) const currentMask = pressedModifiers.value & 0xff @@ -401,7 +366,6 @@ async function onKeyDown(key: string) { } async function onKeyUp() { - // Not used for now - we handle up in onKeyDown with setTimeout } async function sendKeyPress(keyCode: CanonicalKey, press: boolean, modifierMask: number) { @@ -453,7 +417,6 @@ async function executeMacro(steps: MacroStep[]) { } } -// Update keyboard button theme for pressed keys function updateKeyboardButtonTheme() { const downKeys = keyNamesForDownKeys.value.join(' ') const buttonTheme = [ @@ -472,7 +435,6 @@ function updateKeyboardButtonTheme() { arrowsKeyboard.value?.setOptions({ buttonTheme }) } -// Update layout when shift state changes watch([layoutName, () => props.capsLock], ([name]) => { mainKeyboard.value?.setOptions({ layoutName: name, @@ -481,11 +443,9 @@ watch([layoutName, () => props.capsLock], ([name]) => { updateKeyboardButtonTheme() }) -// Initialize keyboards with unique selectors function initKeyboards() { const id = keyboardId.value - // Check if elements exist - use full selector with # const mainEl = document.querySelector(`#${id}-main`) const controlEl = document.querySelector(`#${id}-control`) const arrowsEl = document.querySelector(`#${id}-arrows`) @@ -496,7 +456,6 @@ function initKeyboards() { return } - // Main keyboard - pass element directly instead of selector string mainKeyboard.value = new Keyboard(mainEl, { layout: isCompactLayout.value ? compactMainLayout : keyboardLayout.main, layoutName: layoutName.value, @@ -517,7 +476,6 @@ function initKeyboards() { stopMouseUpPropagation: true, }) - // Control keyboard controlKeyboard.value = new Keyboard(controlEl, { layout: keyboardLayout.control, layoutName: 'default', @@ -532,7 +490,6 @@ function initKeyboards() { stopMouseUpPropagation: true, }) - // Arrows keyboard arrowsKeyboard.value = new Keyboard(arrowsEl, { layout: keyboardLayout.arrows, layoutName: 'default', @@ -551,7 +508,6 @@ function initKeyboards() { console.log('[VirtualKeyboard] Keyboards initialized:', id) } -// Destroy keyboards function destroyKeyboards() { mainKeyboard.value?.destroy() controlKeyboard.value?.destroy() @@ -561,7 +517,6 @@ function destroyKeyboards() { arrowsKeyboard.value = null } -// Dragging handlers function getClientCoords(e: MouseEvent | TouchEvent): { x: number; y: number } | null { if ('touches' in e) { const touch = e.touches[0] @@ -609,11 +564,9 @@ async function toggleAttached() { isAttached.value = !isAttached.value emit('update:attached', isAttached.value) - // Wait for Teleport to move the component await nextTick() await nextTick() // Extra tick for Teleport - // Reinitialize keyboards in new location setTimeout(() => { initKeyboards() }, 100) @@ -623,7 +576,6 @@ function close() { emit('update:visible', false) } -// Watch visibility to init/destroy keyboards watch(() => props.visible, async (visible) => { console.log('[VirtualKeyboard] Visibility changed:', visible, 'attached:', isAttached.value, 'id:', keyboardId.value) if (visible) { @@ -641,7 +593,6 @@ watch(() => props.attached, (value) => { }) onMounted(() => { - // Load saved OS layout preference const savedOs = localStorage.getItem('vkb-os-layout') as KeyboardOsType | null if (savedOs && ['windows', 'mac', 'android'].includes(savedOs)) { selectedOs.value = savedOs diff --git a/web/src/composables/useAudioPlayer.ts b/web/src/composables/useAudioPlayer.ts index 0387a389..70f03b90 100644 --- a/web/src/composables/useAudioPlayer.ts +++ b/web/src/composables/useAudioPlayer.ts @@ -1,20 +1,14 @@ -// Audio player composable - handles WebSocket connection, Opus decoding, and Web Audio API playback - import { ref, watch } from 'vue' import { OpusDecoder } from 'opus-decoder' import { buildWsUrl } from '@/types/websocket' -// Binary protocol header format (15 bytes) -// [type:1][timestamp:4][duration:2][sequence:4][length:4][data:...] export function useAudioPlayer() { - // State const connected = ref(false) const playing = ref(false) const volume = ref(0) // Default to 0, user must adjust to enable audio (browser autoplay policy) const error = ref(null) - // Internal variables let ws: WebSocket | null = null let audioContext: AudioContext | null = null let gainNode: GainNode | null = null @@ -23,7 +17,6 @@ export function useAudioPlayer() { let nextPlayTime = 0 let isConnecting = false // Prevent concurrent connection attempts - // Initialize decoder async function initDecoder() { const opusDecoder = new OpusDecoder({ channels: 2, @@ -33,7 +26,6 @@ export function useAudioPlayer() { decoder = opusDecoder } - // Initialize audio context function initAudioContext() { audioContext = new AudioContext({ sampleRate: 48000 }) gainNode = audioContext.createGain() @@ -41,14 +33,12 @@ export function useAudioPlayer() { updateVolume() } - // Connect to WebSocket async function connect() { // Prevent concurrent connection attempts (critical fix for multiple WS connections) if (isConnecting) { return } - // Check if already connected if (ws) { if (ws.readyState === WebSocket.OPEN) { return @@ -64,7 +54,6 @@ export function useAudioPlayer() { isConnecting = true try { - // Initialize if (!decoder) await initDecoder() if (!audioContext) initAudioContext() @@ -108,7 +97,6 @@ export function useAudioPlayer() { } } - // Disconnect function disconnect() { if (ws) { ws.close() @@ -140,14 +128,12 @@ export function useAudioPlayer() { const samplesPerChannel = decoded.samplesDecoded const channels = decoded.channelData.length - // Create audio buffer const audioBuffer = audioContext.createBuffer( channels, samplesPerChannel, 48000 ) - // Fill channel data for (let ch = 0; ch < channels; ch++) { const channelData = audioBuffer.getChannelData(ch) const sourceData = decoded.channelData[ch] @@ -156,7 +142,6 @@ export function useAudioPlayer() { } } - // Schedule playback const source = audioContext.createBufferSource() source.buffer = audioBuffer source.connect(gainNode) @@ -164,12 +149,10 @@ export function useAudioPlayer() { const now = audioContext.currentTime const scheduledAhead = nextPlayTime - now - // Reset if too far behind (audio was paused/lagged) if (nextPlayTime < now) { nextPlayTime = now + 0.02 // Start 20ms ahead } - // Reset if buffer too large (> 500ms ahead) if (scheduledAhead > 0.5) { nextPlayTime = now + 0.05 // Keep 50ms buffer } @@ -177,33 +160,27 @@ export function useAudioPlayer() { source.start(nextPlayTime) nextPlayTime += audioBuffer.duration } catch { - // Ignore decode errors } } - // Update volume function updateVolume() { if (gainNode) { gainNode.gain.value = volume.value } } - // Set volume function setVolume(v: number) { volume.value = Math.max(0, Math.min(1, v)) updateVolume() } - // Watch volume changes watch(volume, updateVolume) return { - // State connected, playing, volume, error, - // Methods connect, disconnect, setVolume, diff --git a/web/src/composables/useConfigPopover.ts b/web/src/composables/useConfigPopover.ts index ca4fe4cf..4ef50945 100644 --- a/web/src/composables/useConfigPopover.ts +++ b/web/src/composables/useConfigPopover.ts @@ -1,5 +1,3 @@ -// Config popover composable - shared logic for config popover components -// Provides common state management and lifecycle hooks import { ref, watch, type Ref } from 'vue' import { useI18n } from 'vue-i18n' @@ -17,11 +15,9 @@ export interface UseConfigPopoverOptions { export function useConfigPopover(options: UseConfigPopoverOptions) { const { t } = useI18n() - // Common state const applying = ref(false) const loadingDevices = ref(false) - // Watch open state to initialize watch(options.open, async (isOpen) => { if (isOpen) { options.initializeFromCurrent?.() @@ -36,7 +32,6 @@ export function useConfigPopover(options: UseConfigPopoverOptions) { } }) - // Apply config wrapper with loading state and toast async function applyConfig(applyFn: () => Promise) { applying.value = true try { @@ -44,7 +39,6 @@ export function useConfigPopover(options: UseConfigPopoverOptions) { toast.success(t('config.applied')) } catch (e) { console.info('[ConfigPopover] Apply failed:', e) - // Error toast is usually shown by API layer } finally { applying.value = false } @@ -62,11 +56,9 @@ export function useConfigPopover(options: UseConfigPopoverOptions) { } return { - // State applying, loadingDevices, - // Methods applyConfig, refreshDevices, } diff --git a/web/src/composables/useConsoleEvents.ts b/web/src/composables/useConsoleEvents.ts index 027c08f2..8045aaac 100644 --- a/web/src/composables/useConsoleEvents.ts +++ b/web/src/composables/useConsoleEvents.ts @@ -1,6 +1,3 @@ -// Console WebSocket events composable - handles all WebSocket event subscriptions -// Extracted from ConsoleView.vue for better separation of concerns - import { useI18n } from 'vue-i18n' import { toast } from 'vue-sonner' import { useSystemStore } from '@/stores/system' @@ -33,7 +30,7 @@ export function useConsoleEvents(handlers: ConsoleEventHandlers) { const systemStore = useSystemStore() const { on, off, connect } = useWebSocket() const noop = () => {} - // Stream device monitoring handlers + function handleStreamDeviceLost(data: { device: string; reason: string }) { if (systemStore.stream) { systemStore.stream.online = false @@ -66,20 +63,11 @@ export function useConsoleEvents(handlers: ConsoleEventHandlers) { handlers.onStreamRecovered?.(_data) } - function handleStreamStateChanged(data: { state: string }) { - if (data.state === 'error') { - // Handled by video stream composable - } - } - function handleStreamStateChangedForward(data: { state: string; device?: string | null }) { - handleStreamStateChanged(data) handlers.onStreamStateChanged?.(data) } - // Subscribe to all events function subscribe() { - // Stream events on('stream.config_changing', handlers.onStreamConfigChanging ?? noop) on('stream.config_applied', handlers.onStreamConfigApplied ?? noop) on('stream.stats_update', handlers.onStreamStatsUpdate ?? noop) @@ -92,14 +80,11 @@ export function useConsoleEvents(handlers: ConsoleEventHandlers) { on('stream.reconnecting', handleStreamReconnecting) on('stream.recovered', handleStreamRecovered) - // System events on('system.device_info', handlers.onDeviceInfo ?? noop) - // Connect WebSocket connect() } - // Unsubscribe from all events function unsubscribe() { off('stream.config_changing', handlers.onStreamConfigChanging ?? noop) off('stream.config_applied', handlers.onStreamConfigApplied ?? noop) diff --git a/web/src/composables/useHidWebSocket.ts b/web/src/composables/useHidWebSocket.ts index e8d7159e..172c38e8 100644 --- a/web/src/composables/useHidWebSocket.ts +++ b/web/src/composables/useHidWebSocket.ts @@ -1,6 +1,3 @@ -// WebSocket HID channel for low-latency keyboard/mouse input (binary protocol) -// Uses the same binary format as WebRTC DataChannel for consistency - import { ref, onUnmounted } from 'vue' import { type HidKeyboardEvent, @@ -23,26 +20,22 @@ const reconnectAttempts = ref(0) const networkError = ref(false) const networkErrorMessage = ref(null) let reconnectTimeout: number | null = null -const hidUnavailable = ref(false) // Track if HID is unavailable to prevent unnecessary reconnects +const hidUnavailable = ref(false) -// Connection promise to avoid race conditions let connectionPromise: Promise | null = null let connectionResolved = false function connect(): Promise { - // If already connected, return immediately if (wsInstance && wsInstance.readyState === WebSocket.OPEN && connectionResolved) { return Promise.resolve(true) } - // If connection is in progress, return the existing promise if (connectionPromise && !connectionResolved) { return connectionPromise } connectionResolved = false connectionPromise = new Promise((resolve) => { - // Reset network error flag when attempting new connection networkError.value = false networkErrorMessage.value = null hidUnavailable.value = false @@ -60,7 +53,6 @@ function connect(): Promise { } wsInstance.onmessage = (e) => { - // Handle binary response if (e.data instanceof ArrayBuffer) { const view = new DataView(e.data) if (view.byteLength >= 1) { @@ -126,14 +118,12 @@ function disconnect() { } if (wsInstance) { - // Close the websocket wsInstance.close() wsInstance = null connected.value = false networkError.value = false } - // Reset connection state connectionPromise = null connectionResolved = false } @@ -154,7 +144,6 @@ function sendKeyboard(event: HidKeyboardEvent): Promise { }) } -// Internal function to actually send mouse event function _sendMouseInternal(event: HidMouseEvent): Promise { return new Promise((resolve, reject) => { if (!wsInstance || wsInstance.readyState !== WebSocket.OPEN) { @@ -175,7 +164,6 @@ function sendMouse(event: HidMouseEvent): Promise { return _sendMouseInternal(event) } -// Send consumer control event (multimedia keys) function sendConsumer(event: HidConsumerEvent): Promise { return new Promise((resolve, reject) => { if (!wsInstance || wsInstance.readyState !== WebSocket.OPEN) { @@ -194,8 +182,6 @@ function sendConsumer(event: HidConsumerEvent): Promise { export function useHidWebSocket() { onUnmounted(() => { - // Don't disconnect on component unmount - WebSocket is shared - // Only disconnect when explicitly called or page unloads }) return { @@ -212,7 +198,6 @@ export function useHidWebSocket() { } } -// Global lifecycle - disconnect when page unloads if (typeof window !== 'undefined') { window.addEventListener('beforeunload', () => { disconnect() diff --git a/web/src/composables/useUnifiedAudio.ts b/web/src/composables/useUnifiedAudio.ts index c01242ac..b9b33012 100644 --- a/web/src/composables/useUnifiedAudio.ts +++ b/web/src/composables/useUnifiedAudio.ts @@ -1,4 +1,3 @@ -// Unified Audio Manager // Manages audio playback across different video modes (MJPEG/WebSocket and H264/WebRTC) // Provides a single interface for volume control and audio source switching @@ -17,7 +16,6 @@ export interface UnifiedAudioState { } export function useUnifiedAudio() { - // === State === const audioMode = ref('ws') const volume = ref(0) // 0-1, default muted (browser autoplay policy) const muted = ref(false) @@ -25,11 +23,9 @@ export function useUnifiedAudio() { const playing = ref(false) const error = ref(null) - // === Internal References === const wsPlayer = getAudioPlayer() let webrtcVideoElement: HTMLVideoElement | null = null - // === Methods === /** * Set the WebRTC video element reference @@ -39,9 +35,7 @@ export function useUnifiedAudio() { // Only update if element is provided (don't clear on null to preserve reference) if (el) { webrtcVideoElement = el - // Sync current volume to video element el.volume = volume.value - // Mute if volume is 0 or explicitly muted const shouldMute = muted.value || volume.value === 0 el.muted = shouldMute } @@ -57,7 +51,6 @@ export function useUnifiedAudio() { const wasConnected = connected.value const wasPlaying = playing.value - // Disconnect old mode if (audioMode.value === 'ws') { wsPlayer.disconnect() } @@ -65,12 +58,10 @@ export function useUnifiedAudio() { audioMode.value = mode - // If was connected/playing and volume > 0, auto-connect new mode if ((wasConnected || wasPlaying) && volume.value > 0) { await connect() } - // Update connection state updateConnectionState() } @@ -82,7 +73,6 @@ export function useUnifiedAudio() { const newVolume = Math.max(0, Math.min(1, v)) volume.value = newVolume - // Sync to WS player wsPlayer.setVolume(newVolume) // Sync to WebRTC video element @@ -99,7 +89,6 @@ export function useUnifiedAudio() { function setMuted(m: boolean) { muted.value = m - // WS player: control via volume (no separate mute) if (audioMode.value === 'ws') { wsPlayer.setVolume(m ? 0 : volume.value) } @@ -133,7 +122,6 @@ export function useUnifiedAudio() { } } else { // WebRTC audio is automatically connected via video track - // Just ensure video element is not muted (if volume > 0) if (webrtcVideoElement) { webrtcVideoElement.muted = muted.value || volume.value === 0 connected.value = true @@ -173,7 +161,6 @@ export function useUnifiedAudio() { } } - // Watch WS player state changes watch(() => wsPlayer.connected.value, (newConnected) => { if (audioMode.value === 'ws') { connected.value = newConnected @@ -193,7 +180,6 @@ export function useUnifiedAudio() { }) return { - // State audioMode, volume, muted, @@ -201,7 +187,6 @@ export function useUnifiedAudio() { playing, error, - // Methods setWebRTCElement, switchMode, setVolume, diff --git a/web/src/composables/useWebRTC.ts b/web/src/composables/useWebRTC.ts index 01765f9b..dba26b76 100644 --- a/web/src/composables/useWebRTC.ts +++ b/web/src/composables/useWebRTC.ts @@ -85,7 +85,6 @@ async function fetchIceServers(): Promise { } // Fallback: for local connections, use no ICE servers (host candidates only) - // For remote connections, use Google STUN as fallback const isLocalConnection = typeof window !== 'undefined' && (window.location.hostname === 'localhost' || window.location.hostname === '127.0.0.1' || @@ -105,7 +104,6 @@ async function fetchIceServers(): Promise { ] } -// Shared instance state let peerConnection: RTCPeerConnection | null = null let dataChannel: RTCDataChannel | null = null let sessionId: string | null = null @@ -146,7 +144,6 @@ const error = ref(null) const dataChannelReady = ref(false) const connectStage = ref('idle') -// Create RTCPeerConnection with configuration function createPeerConnection(iceServers: RTCIceServer[]): RTCPeerConnection { const config: RTCConfiguration = { iceServers, @@ -155,7 +152,6 @@ function createPeerConnection(iceServers: RTCIceServer[]): RTCPeerConnection { const pc = new RTCPeerConnection(config) - // Handle connection state changes pc.onconnectionstatechange = () => { switch (pc.connectionState) { case 'connecting': @@ -211,7 +207,6 @@ function createPeerConnection(iceServers: RTCIceServer[]): RTCPeerConnection { } } - // Handle incoming tracks pc.ontrack = (event) => { const track = event.track @@ -222,7 +217,6 @@ function createPeerConnection(iceServers: RTCIceServer[]): RTCPeerConnection { } } - // Handle data channel from server pc.ondatachannel = (event) => { setupDataChannel(event.channel) } @@ -230,7 +224,6 @@ function createPeerConnection(iceServers: RTCIceServer[]): RTCPeerConnection { return pc } -// Setup data channel event handlers function setupDataChannel(channel: RTCDataChannel) { dataChannel = channel @@ -243,15 +236,12 @@ function setupDataChannel(channel: RTCDataChannel) { } channel.onerror = () => { - // Data channel errors handled silently } channel.onmessage = () => { - // Handle incoming messages from server (e.g., LED status) } } -// Create data channel for HID events function createDataChannel(pc: RTCPeerConnection): RTCDataChannel { const channel = pc.createDataChannel('hid', { ordered: true, @@ -307,7 +297,6 @@ async function handleRemoteIceComplete(data: WebRTCIceCompleteEvent) { try { await peerConnection.addIceCandidate(null) } catch { - // End-of-candidates failures are non-fatal } } @@ -348,7 +337,6 @@ async function flushPendingRemoteIce() { try { await peerConnection.addIceCandidate(null) } catch { - // Ignore end-of-candidates errors } } } @@ -363,14 +351,12 @@ function startStatsCollection() { try { const report = await peerConnection.getStats() - // Collect candidate info const candidates: Record = {} let selectedPairLocalId = '' let selectedPairRemoteId = '' let foundActivePair = false report.forEach((stat) => { - // Collect all candidates if (stat.type === 'local-candidate' || stat.type === 'remote-candidate') { candidates[stat.id] = { type: (stat.candidateType as IceCandidateType) || 'unknown', @@ -378,10 +364,7 @@ function startStatsCollection() { } } - // Find the active candidate pair - // Priority: nominated > succeeded (for Chrome/Firefox compatibility) if (stat.type === 'candidate-pair') { - // Check if this is the nominated/selected pair const isActive = stat.nominated === true || (stat.state === 'succeeded' && stat.selected === true) || (stat.state === 'in-progress' && !foundActivePair) @@ -425,8 +408,6 @@ function startStatsCollection() { stats.value.remoteCandidateType = remoteCandidate.type } - // Check if using TURN relay - // TURN relay is when either local or remote candidate is 'relay' type stats.value.isRelay = stats.value.localCandidateType === 'relay' || stats.value.remoteCandidateType === 'relay' } catch { // Stats collection errors are non-fatal @@ -473,7 +454,6 @@ async function connect(): Promise { connectInFlight = (async () => { registerWebSocketHandlers() - // Prevent concurrent connection attempts if (isConnecting) { return state.value === 'connected' } @@ -484,7 +464,6 @@ async function connect(): Promise { isConnecting = true - // Clean up any existing connection first if (peerConnection || sessionId) { await disconnect() } @@ -505,20 +484,16 @@ async function connect(): Promise { peerConnection = createPeerConnection(iceServers) connectStage.value = 'creating_data_channel' - // Create data channel before offer (for HID) createDataChannel(peerConnection) - // Add transceiver for receiving video peerConnection.addTransceiver('video', { direction: 'recvonly' }) peerConnection.addTransceiver('audio', { direction: 'recvonly' }) connectStage.value = 'creating_offer' - // Create offer const offer = await peerConnection.createOffer() await peerConnection.setLocalDescription(offer) connectStage.value = 'waiting_server_answer' - // Send offer to server and get answer // Do not pass client_id here: each connect creates a fresh session. const response = await webrtcApi.offer(offer.sdp!) sessionId = response.session_id @@ -526,7 +501,6 @@ async function connect(): Promise { // Send any ICE candidates that were queued while waiting for sessionId await flushPendingIceCandidates() - // Set remote description (answer) const answer: RTCSessionDescriptionInit = { type: 'answer', sdp: response.sdp, @@ -611,7 +585,6 @@ async function disconnect() { try { await webrtcApi.close(oldSessionId) } catch { - // Ignore close errors } } @@ -671,13 +644,11 @@ function sendMouse(event: HidMouseEvent): boolean { } } -// Get MediaStream for video element (cached to avoid recreating) function getMediaStream(): MediaStream | null { if (!videoTrack.value && !audioTrack.value) { return null } - // Reuse cached stream if tracks match if (cachedMediaStream) { const currentVideoTracks = cachedMediaStream.getVideoTracks() const currentAudioTracks = cachedMediaStream.getAudioTracks() @@ -693,19 +664,15 @@ function getMediaStream(): MediaStream | null { return cachedMediaStream } - // Tracks changed, update the cached stream - // Remove old tracks currentVideoTracks.forEach(t => cachedMediaStream!.removeTrack(t)) currentAudioTracks.forEach(t => cachedMediaStream!.removeTrack(t)) - // Add new tracks if (videoTrack.value) cachedMediaStream.addTrack(videoTrack.value) if (audioTrack.value) cachedMediaStream.addTrack(audioTrack.value) return cachedMediaStream } - // Create new cached stream cachedMediaStream = new MediaStream() if (videoTrack.value) { cachedMediaStream.addTrack(videoTrack.value) @@ -716,15 +683,11 @@ function getMediaStream(): MediaStream | null { return cachedMediaStream } -// Composable export export function useWebRTC() { onUnmounted(() => { - // Don't disconnect on unmount - keep connection alive - // Only disconnect when explicitly called }) return { - // State state: state as Ref, videoTrack, audioTrack, @@ -734,14 +697,12 @@ export function useWebRTC() { connectStage, sessionId: computed(() => sessionId), - // Methods connect, disconnect, sendKeyboard, sendMouse, getMediaStream, - // Computed isConnected: computed(() => state.value === 'connected'), isConnecting: computed(() => state.value === 'connecting'), hasVideo: computed(() => videoTrack.value !== null), @@ -749,7 +710,6 @@ export function useWebRTC() { } } -// Cleanup on page unload if (typeof window !== 'undefined') { window.addEventListener('beforeunload', () => { disconnect() diff --git a/web/src/composables/useWebSocket.ts b/web/src/composables/useWebSocket.ts index 6d9b3d35..48221076 100644 --- a/web/src/composables/useWebSocket.ts +++ b/web/src/composables/useWebSocket.ts @@ -1,8 +1,3 @@ -// WebSocket composable for real-time event streaming -// -// Usage: -// const { connected, on, off } = useWebSocket() -// on('stream.state_changed', (data) => { ... }) import { ref } from 'vue' import { buildWsUrl, WS_RECONNECT_DELAY } from '@/types/websocket' @@ -151,7 +146,6 @@ function handleEvent(payload: WsEvent) { } }) } - // Silently ignore events without handlers } export function useWebSocket() { @@ -170,7 +164,6 @@ export function useWebSocket() { } } -// Global lifecycle - disconnect when page unloads if (typeof window !== 'undefined') { window.addEventListener('beforeunload', () => { disconnect() diff --git a/web/src/i18n/en-US.ts b/web/src/i18n/en-US.ts index 142ddf4e..150670d6 100644 --- a/web/src/i18n/en-US.ts +++ b/web/src/i18n/en-US.ts @@ -109,7 +109,6 @@ export default { settingsTip: 'System settings', fullscreen: 'Fullscreen', fullscreenTip: 'Toggle fullscreen mode', - // Video Config videoConfig: 'Video', streamSettings: 'Stream Settings', deviceSettings: 'Device Settings', @@ -141,7 +140,6 @@ export default { notRecommended: 'Not Recommended', multiSourceCodecLocked: '{sources} are enabled. Current codec is locked.', multiSourceVideoParamsWarning: '{sources} are enabled. Changing video device and input parameters will interrupt the stream.', - // HID Config hidConfig: 'HID', mouseSettings: 'Mouse Settings', hidDeviceSettings: 'HID Device Settings', @@ -154,7 +152,6 @@ export default { absolute: 'Absolute', relative: 'Relative', applying: 'Applying...', - // Audio Config audioConfig: 'Audio', playbackControl: 'Playback', volume: 'Volume', @@ -218,19 +215,16 @@ export default { title: 'Initial Setup', welcome: 'Welcome to One-KVM', description: 'Complete the initial setup to get started', - // Step titles stepAccount: 'Account Setup', stepVideo: 'Video Setup', stepAudioVideo: 'Audio/Video Setup', stepHid: 'HID Setup', - // Account setUsername: 'Set Admin Username', usernameHint: 'Username must be at least 2 characters', setPassword: 'Set Admin Password', passwordHint: 'Password must be at least 4 characters', confirmPassword: 'Confirm Password', passwordMismatch: 'Passwords do not match', - // Video videoDevice: 'Video Device', selectVideoDevice: 'Select video capture device', videoFormat: 'Video Format', @@ -242,13 +236,11 @@ export default { noVideoDevices: 'No video devices detected', noSignalDetected: 'No HDMI signal detected. Please connect an HDMI cable and refresh.', refreshDevices: 'Refresh Devices', - // Audio audioDevice: 'Audio Device', selectAudioDevice: 'Select audio capture device', noAudio: 'No audio', noAudioDevices: 'No audio devices detected', audioDeviceHelp: 'Select the audio capture device for capturing remote host audio. Usually on the same USB device as the video capture card.', - // HID hidBackend: 'HID Backend', selectHidBackend: 'Select HID control method', serialHid: 'Serial HID', @@ -261,31 +253,25 @@ export default { selectUdc: 'Select UDC', noUdcDevices: 'No UDC devices detected', hidDisabledHint: 'Disabling HID will prevent keyboard and mouse control of the remote host', - // Complete complete: 'Complete Setup', setupFailed: 'Setup failed', - // Advanced encoder advancedEncoder: 'Advanced: Encoder Backend', encoderHint: 'The default "Auto" option works for most cases. Only change if you need a specific encoder backend.', autoRecommended: 'Auto (Recommended)', hardware: 'Hardware', software: 'Software', - // Progress progress: 'Step {current} of {total}', - // Help tooltips ch9329Help: 'CH9329 is a serial-to-HID chip connected via serial port. Works with most hardware configurations.', otgHelp: 'USB OTG mode emulates HID devices directly through USB Device Controller. Requires hardware OTG support.', otgLowEndpointHint: 'Detected low-endpoint UDC; Consumer Control Keyboard will be disabled automatically.', videoDeviceHelp: 'Select the video capture device for capturing the remote host display. Usually an HDMI capture card.', videoFormatHelp: 'MJPEG has best compatibility. H.264/H.265 uses less bandwidth but requires encoding support.', - // Extensions stepExtensions: 'Extensions', extensionsDescription: 'Choose which extensions to auto-start', ttydTitle: 'Web Terminal (ttyd)', ttydDescription: 'Access device command line in browser', extensionsHint: 'These settings can be changed later in Settings', notInstalled: 'Not installed', - // Password strength passwordStrength: 'Password Strength', passwordWeak: 'Weak', passwordMedium: 'Medium', @@ -350,7 +336,6 @@ export default { uvc_capture_stall: '', }, }, - // WebRTC webrtcConnected: 'WebRTC Connected', webrtcConnectedDesc: 'Using low-latency H.264 video stream', webrtcFailed: 'WebRTC Connection Failed', @@ -363,29 +348,23 @@ export default { webrtcPhaseSetRemote: 'Applying remote description...', webrtcPhaseApplyIce: 'Applying ICE candidates...', webrtcPhaseNegotiating: 'Negotiating secure connection...', - // Pointer Lock pointerLocked: 'Pointer Locked', pointerLockedDesc: 'Press Escape to release the pointer', pointerLockFailed: 'Failed to lock pointer', relativeModeHint: 'Relative Mouse Mode', relativeModeHintDesc: 'Click on the video area to lock the pointer, press Escape to release', - // Meta Key Hint metaKeyHint: 'System Key Detected', metaKeyHintDesc: 'Enter fullscreen mode to capture Win/Meta keys', - // Stream mode change streamModeChanged: 'Video Mode Changed', streamModeChangedDesc: 'Server switched to {mode} mode', - // Device monitoring deviceLost: 'Video Device Lost', deviceLostDesc: '{device}: {reason}', deviceRecovering: 'Video Device Recovering', deviceRecoveringDesc: 'Attempting to recover video device (attempt {attempt})', deviceRecovered: 'Video Device Recovered', deviceRecoveredDesc: 'Video device reconnected successfully', - // Loading state pleaseWait: 'Please wait...', retryCount: 'Retrying (attempt {count})', - // Error details errorDetails: 'Error details', }, hid: { @@ -397,7 +376,6 @@ export default { pasteText: 'Paste Text', absoluteMouse: 'Absolute', relativeMouse: 'Relative', - // Device monitoring deviceLost: 'HID Device Lost', deviceLostDesc: '{backend}: {reason}', reconnecting: 'HID Reconnecting', @@ -424,7 +402,6 @@ export default { }, }, audio: { - // Device monitoring deviceLost: 'Audio Device Lost', deviceLostDesc: '{device}: {reason}', reconnecting: 'Audio Reconnecting', @@ -468,7 +445,6 @@ export default { uploadImageHint: 'Click to upload ISO/IMG', imageMounted: 'Image {name} mounted', imageUnmounted: 'Image unmounted', - // URL download downloadFromUrl: 'Download from URL', downloadFromUrlDesc: 'Enter the URL of an image file (ISO/IMG supported)', url: 'URL', @@ -479,16 +455,13 @@ export default { downloadFailed: 'Download failed', largeFileWarning: '>2.2GB', largeFileTooltip: 'File is larger than 2.2GB, please use Flash mode to mount', - // Device monitoring error: 'MSD Error', errorDesc: '{reason}', recovered: 'MSD Recovered', recoveredDesc: 'MSD operation completed successfully', - // Operation status operationInProgress: 'Operation in progress, please wait', driveConnected: 'Virtual USB drive connected', imageConnected: 'Image {name} connected', - // Drive initialization selectDriveSize: 'Select virtual drive size', selectedSize: 'Selected size', customSize: 'Custom size', @@ -520,7 +493,6 @@ export default { security: 'Security', about: 'About', aboutDesc: 'Open and Lightweight IP-KVM Solution', - // Device info deviceInfo: 'Device Info', deviceInfoDesc: 'Host system information', hostname: 'Hostname', @@ -545,11 +517,9 @@ export default { networkSettings: 'Network Settings', msdSettings: 'MSD Settings', atxSettings: 'ATX Settings', - // Network tab httpSettings: 'HTTP Settings', httpPort: 'HTTP Port', configureHttpPort: 'Configure HTTP server port', - // Web server webServer: 'Access Address', webServerDesc: 'Configure HTTP/HTTPS ports and listening addresses. Restart required for changes to take effect.', httpsPort: 'HTTPS Port', @@ -569,20 +539,17 @@ export default { bindAddressListEmpty: 'Add at least one IP address.', httpsEnabled: 'Enable HTTPS', httpsEnabledDesc: 'Enable HTTPS encrypted connection (a self-signed certificate is generated if none is specified)', - // Port config portConfig: 'Port & Protocol', portConfigDesc: 'The service runs on a single port at a time, determined by the HTTPS toggle', httpPortReserved: 'HTTP port (reserved)', httpsPortReserved: 'HTTPS port (reserved)', previewUrl: 'Access URL preview', - // Listen address listenAddress: 'Listen Address', listenAddressDesc: 'Configure which network interfaces the web server listens on', bindModeAllDesc: '0.0.0.0 — Listen on all network interfaces', bindModeLocalDesc: '127.0.0.1 — Allow local access only', bindModeCustomDesc: 'Specify a list of IP addresses', effectiveAddresses: 'Listen address preview', - // SSL certificate sslCertificate: 'SSL Certificate', sslCertificateDesc: 'Upload a custom PEM certificate to replace the self-signed one, restart required', sslCertCustom: 'Custom Certificate', @@ -628,13 +595,11 @@ export default { updateMsgVerifying: 'Verifying (SHA256)', updateMsgInstalling: 'Replacing binary', updateMsgRestarting: 'Restarting service', - // Auth auth: 'Access', authSettings: 'Access Settings', authSettingsDesc: 'Single-user access and session behavior', allowMultipleSessions: 'Allow multiple web sessions', allowMultipleSessionsDesc: 'When disabled, a new login will kick the previous session.', - // User management userManagement: 'User Management', userManagementDesc: 'Manage user accounts and permissions', addUser: 'Add User', @@ -649,7 +614,6 @@ export default { noUsers: 'No users found', create: 'Create', confirmDeleteUser: 'Are you sure you want to delete user "{name}"?', - // MSD/ATX status msdStatus: 'MSD Status', atxStatus: 'ATX Status', available: 'Available', @@ -665,7 +629,6 @@ export default { disabled: 'Disabled', msdDesc: 'Mass Storage Device allows you to mount ISO images and virtual drives to the target machine. Use the MSD panel on the main page to manage images.', atxDesc: 'ATX power control allows you to remotely power on/off and reset the target machine. Use the ATX panel on the main page to control power.', - // ATX configuration atxSettingsDesc: 'Configure ATX power control hardware bindings', atxEnable: 'Enable ATX Control', atxEnableDesc: 'Enable remote control of power and reset buttons', @@ -693,16 +656,13 @@ export default { atxLedPin: 'GPIO Pin', atxLedInverted: 'Invert Logic', atxLedInvertedDesc: 'GPIO is low when LED is on', - // WOL configuration atxWolSettings: 'Wake-on-LAN Settings', atxWolSettingsDesc: 'Configure WOL magic packet sending options', atxWolInterface: 'Network Interface', atxWolInterfacePlaceholder: 'e.g. eth0, enp0s3', atxWolInterfaceHint: 'Specify network interface for WOL packets, leave empty for default routing', - // Basic tab descriptions themeDesc: 'Choose your preferred color scheme', languageDesc: 'Select your preferred language', - // Video tab videoSettings: 'Video Settings', videoSettingsDesc: 'Configure video capture device', videoDevice: 'Video Device', @@ -719,7 +679,6 @@ export default { software: 'Software', supportedFormats: 'Supported Formats', encoderHint: 'Hardware encoders provide better performance with lower CPU usage. Software encoders are more compatible but require more CPU resources.', - // HID tab hidSettings: 'HID Settings', hidSettingsDesc: 'Configure keyboard and mouse control', hidBackend: 'HID Backend', @@ -748,7 +707,6 @@ export default { otgProfileWarning: 'Changing HID functions will reconnect the USB device', otgLowEndpointHint: 'Low-endpoint UDC detected; Consumer Control Keyboard will be disabled automatically.', otgFunctionMinWarning: 'Enable at least one HID function before saving', - // OTG Descriptor otgDescriptor: 'USB Device Descriptor', otgDescriptorDesc: 'Configure USB device identification', vendorId: 'Vendor ID (VID)', @@ -862,7 +820,6 @@ export default { resetConfirmDesc: 'This will reset USB device "{device}" by cycling its authorized attribute. All connections to this device will be temporarily interrupted. Continue?', resetAction: 'Reset Device', }, - // WebRTC / ICE webrtcSettings: 'WebRTC Settings', webrtcSettingsDesc: 'Configure STUN/TURN servers for NAT traversal', publicIceServersHint: 'Empty uses Google public STUN, configure your own TURN for production', @@ -931,7 +888,6 @@ export default { notConnected: 'Not Connected', connected: 'Connected', image: 'Image', - // MSD status details msdStatus: 'Status', msdStandby: 'Idle', msdImageMode: 'Image Mode', @@ -941,7 +897,6 @@ export default { msdNoImage: 'None', }, extensions: { - // Common available: 'Available', unavailable: 'Unavailable', running: 'Running', @@ -958,7 +913,6 @@ export default { title: 'Remote Access', desc: 'GOSTC NAT traversal and Easytier networking', }, - // ttyd ttyd: { title: 'Ttyd Web Terminal', desc: 'Web terminal access via ttyd', @@ -967,7 +921,6 @@ export default { port: 'Port', shell: 'Shell', }, - // gostc gostc: { title: 'GOSTC NAT Traversal', desc: 'NAT traversal via GOSTC', @@ -976,7 +929,6 @@ export default { key: 'Client Key', tls: 'Enable TLS', }, - // easytier easytier: { title: 'Easytier Network', desc: 'P2P VPN networking via EasyTier', @@ -987,7 +939,6 @@ export default { virtualIp: 'Virtual IP', virtualIpHint: 'Leave empty for DHCP, or specify with CIDR (e.g., 10.0.0.1/24)', }, - // rustdesk rustdesk: { title: 'RustDesk Remote', desc: 'Remote access via RustDesk client', @@ -1078,31 +1029,24 @@ export default { p2p: 'P2P Direct', relay: 'TURN Relay', }, - // Help tooltip texts help: { - // MSD related flashMode: 'Flash mode mounts the image as a USB drive, compatible with most BIOS boot', cdromMode: 'CDROM mode mounts the image as a CD drive, for systems requiring optical boot', readOnlyMode: 'Read-only mode is safer, the target system cannot modify the image', readWriteMode: 'Read-write mode allows writing data, useful for saving configurations', driveSize: 'Virtual drive size. Larger drives can store more files but take longer to initialize', - // Video related mjpegMode: 'MJPEG mode has best compatibility, works with all browsers, but higher latency', webrtcMode: 'WebRTC mode has lower latency, but requires browser codec support', videoBitratePreset: 'Speed: lowest latency, best for slow networks. Balanced: good quality and latency. Quality: best visual, needs good bandwidth', encoderBackend: 'Hardware encoder has better performance and lower power. Software encoder has better compatibility', - // HID related absoluteMode: 'Absolute mode maps mouse coordinates directly, suitable for most scenarios', relativeMode: 'Relative mode sends mouse movement delta, for games or special software', mouseThrottle: 'Send interval controls mouse event frequency. Higher values reduce network load', hidBackend: 'OTG backend requires USB OTG hardware support. CH9329 is a serial HID chip solution', - // ATX related atxActiveLevel: 'Active level depends on your hardware wiring. High means high voltage when triggered', wolInterface: 'Network interface name for sending Wake-on-LAN magic packets, e.g., eth0 or br0', - // Network related stunServer: 'STUN server for NAT traversal to establish P2P connections. Leave empty for public servers', turnServer: 'TURN server provides relay when P2P fails. Requires more bandwidth but more reliable', - // Audio related audioQuality: 'Higher quality means better audio but requires more network bandwidth', }, } diff --git a/web/src/i18n/index.ts b/web/src/i18n/index.ts index d812594f..3410996f 100644 --- a/web/src/i18n/index.ts +++ b/web/src/i18n/index.ts @@ -2,7 +2,6 @@ import { createI18n } from 'vue-i18n' import zhCN from './zh-CN' import enUS from './en-US' -// Supported languages export const supportedLanguages = [ { code: 'zh-CN', name: '中文', flag: '🇨🇳' }, { code: 'en-US', name: 'English', flag: '🇺🇸' }, @@ -10,33 +9,26 @@ export const supportedLanguages = [ export type SupportedLocale = (typeof supportedLanguages)[number]['code'] -// Detect browser language with improved logic function detectLanguage(): SupportedLocale { - // 1. Check localStorage for saved preference const stored = localStorage.getItem('language') if (stored && supportedLanguages.some((l) => l.code === stored)) { return stored as SupportedLocale } - // 2. Check browser language list (navigator.languages is more comprehensive) const languages = navigator.languages || [navigator.language] for (const lang of languages) { const normalizedLang = lang.toLowerCase() - // Check for Chinese variants (zh, zh-CN, zh-TW, zh-HK, etc.) if (normalizedLang.startsWith('zh')) { return 'zh-CN' } - // Check for English variants if (normalizedLang.startsWith('en')) { return 'en-US' } } - // 3. Default to English return 'en-US' } -// Initialize language and set HTML lang attribute function initializeLanguage(): SupportedLocale { const lang = detectLanguage() document.documentElement.setAttribute('lang', lang) diff --git a/web/src/i18n/zh-CN.ts b/web/src/i18n/zh-CN.ts index 5a3cf120..89c11640 100644 --- a/web/src/i18n/zh-CN.ts +++ b/web/src/i18n/zh-CN.ts @@ -109,7 +109,6 @@ export default { settingsTip: '系统设置', fullscreen: '全屏', fullscreenTip: '切换全屏模式', - // Video Config videoConfig: '视频配置', streamSettings: '流设置', deviceSettings: '设备配置', @@ -141,7 +140,6 @@ export default { notRecommended: '不推荐', multiSourceCodecLocked: '{sources} 已启用,当前编码已锁定', multiSourceVideoParamsWarning: '{sources} 已启用,修改视频设备和输入参数将导致流中断', - // HID Config hidConfig: '鼠键配置', mouseSettings: '鼠标设置', hidDeviceSettings: 'HID 设备设置', @@ -154,7 +152,6 @@ export default { absolute: '绝对定位', relative: '相对定位', applying: '应用中...', - // Audio Config audioConfig: '音频', playbackControl: '播放控制', volume: '音量', @@ -218,19 +215,16 @@ export default { title: '初始化设置', welcome: '欢迎使用 One-KVM', description: '请完成初始设置以开始使用', - // Step titles stepAccount: '账号设置', stepVideo: '视频设置', stepAudioVideo: '音视频设置', stepHid: '鼠键设置', - // Account setUsername: '设置管理员用户名', usernameHint: '用户名至少2个字符', setPassword: '设置管理员密码', passwordHint: '密码至少4个字符', confirmPassword: '确认密码', passwordMismatch: '两次输入的密码不一致', - // Video videoDevice: '视频设备', selectVideoDevice: '选择视频采集设备', videoFormat: '画面格式', @@ -242,13 +236,11 @@ export default { noVideoDevices: '未检测到视频设备', noSignalDetected: '未检测到 HDMI 信号,请连接 HDMI 线缆后刷新。', refreshDevices: '刷新设备', - // Audio audioDevice: '音频设备', selectAudioDevice: '选择音频采集设备', noAudio: '不使用音频', noAudioDevices: '未检测到音频设备', audioDeviceHelp: '选择用于捕获远程主机音频的设备。通常与视频采集卡在同一 USB 设备上。', - // HID hidBackend: 'HID 后端', selectHidBackend: '选择 HID 控制方式', serialHid: '串口 HID', @@ -261,31 +253,25 @@ export default { selectUdc: '选择 UDC', noUdcDevices: '未检测到 UDC 设备', hidDisabledHint: '禁用 HID 后将无法控制远程主机的键盘和鼠标', - // Complete complete: '完成设置', setupFailed: '设置失败', - // Advanced encoder advancedEncoder: '高级选项:编码器后端', encoderHint: '默认的"自动"选项适用于大多数情况。仅在需要特定编码器后端时更改。', autoRecommended: '自动(推荐)', hardware: '硬件', software: '软件', - // Progress progress: '步骤 {current} / {total}', - // Help tooltips ch9329Help: 'CH9329 是一款串口转 HID 芯片,通过串口连接到主机。适用于大多数硬件配置。', otgHelp: 'USB OTG 模式通过 USB 设备控制器直接模拟 HID 设备。需要硬件支持 USB OTG 功能。', otgLowEndpointHint: '检测到低端点 UDC,将自动禁用多媒体键盘。', videoDeviceHelp: '选择用于捕获远程主机画面的视频采集设备。通常是 HDMI 采集卡。', videoFormatHelp: 'MJPEG 格式兼容性最好,H.264/H.265 带宽占用更低但需要编码支持。', - // Extensions stepExtensions: '扩展设置', extensionsDescription: '选择要自动启动的扩展服务', ttydTitle: 'Web 终端 (ttyd)', ttydDescription: '在浏览器中访问设备的命令行终端', extensionsHint: '这些设置可以在设置页面中随时更改', notInstalled: '未安装', - // Password strength passwordStrength: '密码强度', passwordWeak: '弱', passwordMedium: '中', @@ -349,7 +335,6 @@ export default { uvc_capture_stall: '', }, }, - // WebRTC webrtcConnected: 'WebRTC 已连接', webrtcConnectedDesc: '正在使用 H.264 低延迟视频流', webrtcFailed: 'WebRTC 连接失败', @@ -362,29 +347,23 @@ export default { webrtcPhaseSetRemote: '正在应用远端会话描述...', webrtcPhaseApplyIce: '正在应用 ICE 候选...', webrtcPhaseNegotiating: '正在协商安全连接...', - // Pointer Lock pointerLocked: '鼠标已锁定', pointerLockedDesc: '按 Escape 键释放鼠标', pointerLockFailed: '鼠标锁定失败', relativeModeHint: '相对鼠标模式', relativeModeHintDesc: '点击视频区域以锁定鼠标,按 Escape 释放', - // Meta Key Hint metaKeyHint: '检测到系统键', metaKeyHintDesc: '请进入全屏模式以捕获 Win/Meta 键', - // Stream mode change streamModeChanged: '视频模式已切换', streamModeChangedDesc: '服务器已切换到 {mode} 模式', - // 设备监控 deviceLost: '视频设备丢失', deviceLostDesc: '{device}: {reason}', deviceRecovering: '视频设备恢复中', deviceRecoveringDesc: '正在尝试恢复视频设备(第 {attempt} 次)', deviceRecovered: '视频设备已恢复', deviceRecoveredDesc: '视频设备已成功重连', - // 加载状态 pleaseWait: '请稍候...', retryCount: '正在重试 (第 {count} 次)', - // 错误详情 errorDetails: '错误详情', }, hid: { @@ -396,7 +375,6 @@ export default { pasteText: '粘贴文本', absoluteMouse: '绝对定位', relativeMouse: '相对定位', - // 设备监控 deviceLost: 'HID 设备丢失', deviceLostDesc: '{backend}: {reason}', reconnecting: 'HID 重连中', @@ -423,7 +401,6 @@ export default { }, }, audio: { - // 设备监控 deviceLost: '音频设备丢失', deviceLostDesc: '{device}: {reason}', reconnecting: '音频重连中', @@ -467,7 +444,6 @@ export default { uploadImageHint: '点击上传 ISO/IMG 镜像', imageMounted: '镜像 {name} 已挂载', imageUnmounted: '镜像已卸载', - // URL download downloadFromUrl: '从 URL 下载', downloadFromUrlDesc: '输入镜像文件的 URL 地址,支持 ISO/IMG 格式', url: 'URL 地址', @@ -478,16 +454,13 @@ export default { downloadFailed: '下载失败', largeFileWarning: '>2.2GB', largeFileTooltip: '文件大于 2.2GB,请使用 Flash 模式挂载', - // 设备监控 error: 'MSD 错误', errorDesc: '{reason}', recovered: 'MSD 已恢复', recoveredDesc: 'MSD 设备已恢复正常', - // 操作状态 operationInProgress: '操作进行中,请稍候', driveConnected: '虚拟U盘已连接', imageConnected: '镜像 {name} 已连接', - // 驱动器初始化 selectDriveSize: '选择虚拟驱动器大小', selectedSize: '选定大小', customSize: '自定义大小', @@ -519,7 +492,6 @@ export default { security: '安全', about: '关于', aboutDesc: '开放轻量的 IP-KVM 解决方案', - // Device info deviceInfo: '设备信息', deviceInfoDesc: '主机系统信息', hostname: '主机名', @@ -544,11 +516,9 @@ export default { networkSettings: '网络设置', msdSettings: 'MSD 设置', atxSettings: 'ATX 设置', - // Network tab httpSettings: 'HTTP 设置', httpPort: 'HTTP 端口', configureHttpPort: '配置 HTTP 服务器端口', - // Web server webServer: '访问地址', webServerDesc: '配置 HTTP/HTTPS 端口和监听地址,修改后需要重启生效', httpsPort: 'HTTPS 端口', @@ -568,20 +538,17 @@ export default { bindAddressListEmpty: '请至少填写一个 IP 地址。', httpsEnabled: '启用 HTTPS', httpsEnabledDesc: '启用 HTTPS 加密连接(未指定证书将生成自签证书)', - // Port config portConfig: '端口与协议', portConfigDesc: '服务一次只运行在一个端口上,由 HTTPS 开关决定使用哪个端口', httpPortReserved: 'HTTP 端口(备用)', httpsPortReserved: 'HTTPS 端口(备用)', previewUrl: '访问地址预览', - // Listen address listenAddress: '监听地址', listenAddressDesc: '配置 Web 服务监听哪些网络接口', bindModeAllDesc: '0.0.0.0 — 监听所有网络接口', bindModeLocalDesc: '127.0.0.1 — 仅允许本机访问', bindModeCustomDesc: '指定一组 IP 地址', effectiveAddresses: '监听地址预览', - // SSL certificate sslCertificate: 'SSL 证书', sslCertificateDesc: '上传自定义 PEM 证书替换自签名证书,修改后需要重启生效', sslCertCustom: '自定义证书', @@ -627,13 +594,11 @@ export default { updateMsgVerifying: '校验中(SHA256)', updateMsgInstalling: '替换程序中', updateMsgRestarting: '服务重启中', - // Auth auth: '访问控制', authSettings: '访问设置', authSettingsDesc: '单用户访问与会话策略', allowMultipleSessions: '允许多个 Web 会话', allowMultipleSessionsDesc: '关闭后,新登录会踢掉旧会话。', - // User management userManagement: '用户管理', userManagementDesc: '管理用户账号和权限', addUser: '添加用户', @@ -648,7 +613,6 @@ export default { noUsers: '暂无用户', create: '创建', confirmDeleteUser: '确定要删除用户 "{name}" 吗?', - // MSD/ATX status msdStatus: 'MSD 状态', atxStatus: 'ATX 状态', available: '可用', @@ -664,7 +628,6 @@ export default { disabled: '已禁用', msdDesc: '虚拟存储设备允许您将 ISO 镜像和虚拟驱动器挂载到目标机器。请在主页面的 MSD 面板中管理镜像。', atxDesc: 'ATX 电源控制允许您远程开关机和重启目标机器。请在主页面的 ATX 面板中控制电源。', - // ATX configuration atxSettingsDesc: '配置 ATX 电源控制硬件绑定', atxEnable: '启用 ATX 控制', atxEnableDesc: '启用后可以远程控制电源和重启按钮', @@ -692,16 +655,13 @@ export default { atxLedPin: 'GPIO 引脚', atxLedInverted: '反转逻辑', atxLedInvertedDesc: 'LED 亮起时 GPIO 为低电平', - // WOL configuration atxWolSettings: '网络唤醒设置', atxWolSettingsDesc: '配置 Wake-on-LAN 魔术包发送选项', atxWolInterface: '网络接口', atxWolInterfacePlaceholder: '例如: eth0, enp0s3', atxWolInterfaceHint: '指定发送 WOL 包的网络接口,留空则使用系统默认路由', - // Basic tab descriptions themeDesc: '选择您喜欢的颜色方案', languageDesc: '选择您的首选语言', - // Video tab videoSettings: '视频设置', videoSettingsDesc: '配置视频采集设备', videoDevice: '视频设备', @@ -718,7 +678,6 @@ export default { software: '软件', supportedFormats: '支持的格式', encoderHint: '硬件编码器性能更好,CPU 占用更低。软件编码器兼容性更好,但需要更多 CPU 资源。', - // HID tab hidSettings: 'HID 设置', hidSettingsDesc: '配置键盘和鼠标控制', hidBackend: 'HID 后端', @@ -747,7 +706,6 @@ export default { otgProfileWarning: '修改 HID 功能将导致 USB 设备重新连接', otgLowEndpointHint: '检测到低端点 UDC,将自动禁用多媒体键盘。', otgFunctionMinWarning: '请至少启用一个 HID 功能后再保存', - // OTG Descriptor otgDescriptor: 'USB 设备描述符', otgDescriptorDesc: '配置 USB 设备标识信息', vendorId: '厂商 ID (VID)', @@ -861,7 +819,6 @@ export default { resetConfirmDesc: '将通过 authorized 属性复位 USB 设备「{device}」,该设备上的所有连接将短暂中断。确定继续?', resetAction: '确认复位', }, - // WebRTC / ICE webrtcSettings: 'WebRTC 设置', webrtcSettingsDesc: '配置 STUN/TURN 服务器以实现 NAT 穿透', publicIceServersHint: '留空将使用 Google 公共 STUN 服务器,TURN 服务器需自行配置', @@ -930,7 +887,6 @@ export default { notConnected: '未连接', connected: '已连接', image: '镜像', - // MSD 状态详情 msdStatus: '状态', msdStandby: '空闲', msdImageMode: '镜像模式', @@ -940,7 +896,6 @@ export default { msdNoImage: '无', }, extensions: { - // Common available: '可用', unavailable: '不可用', running: '运行中', @@ -957,7 +912,6 @@ export default { title: '远程访问', desc: 'GOSTC 内网穿透与 Easytier 组网', }, - // ttyd ttyd: { title: 'Ttyd 网页终端', desc: '通过 ttyd 提供网页终端访问', @@ -966,7 +920,6 @@ export default { port: '端口', shell: 'Shell', }, - // gostc gostc: { title: 'GOSTC 内网穿透', desc: '通过 GOSTC 实现内网穿透', @@ -975,7 +928,6 @@ export default { key: '客户端密钥', tls: '启用 TLS', }, - // easytier easytier: { title: 'Easytier 组网', desc: '通过 EasyTier 实现 P2P VPN 组网', @@ -986,7 +938,6 @@ export default { virtualIp: '虚拟 IP', virtualIpHint: '留空则自动分配,手动指定需包含网段(如 10.0.0.1/24)', }, - // rustdesk rustdesk: { title: 'RustDesk 远程', desc: '使用 RustDesk 客户端进行远程访问', @@ -1077,31 +1028,24 @@ export default { p2p: 'P2P 直连', relay: 'TURN 中继', }, - // 帮助提示文本 help: { - // MSD 相关 flashMode: 'Flash 模式将镜像作为 U 盘挂载,支持大多数 BIOS 启动', cdromMode: 'CDROM 模式将镜像作为光驱挂载,适用于需要光盘启动的系统', readOnlyMode: '只读模式更安全,目标系统无法修改镜像内容', readWriteMode: '读写模式允许目标系统写入数据,适用于需要保存配置的场景', driveSize: '虚拟驱动器大小。较大的驱动器支持存放更多文件,但初始化时间更长', - // 视频相关 mjpegMode: 'MJPEG 模式兼容性最好,适用于所有浏览器,但延迟较高', webrtcMode: 'WebRTC 模式延迟更低,但需要浏览器支持相应编解码器', videoBitratePreset: '速度优先:最低延迟,适合网络较差的场景;均衡:画质和延迟平衡;质量优先:最佳画质,需要较好的网络带宽', encoderBackend: '硬件编码器性能更好功耗更低,软件编码器兼容性更好', - // HID 相关 absoluteMode: '绝对定位模式直接映射鼠标坐标,适用于大多数场景', relativeMode: '相对定位模式发送鼠标移动增量,适用于游戏或特殊软件', mouseThrottle: '发送间隔控制鼠标事件的发送频率,较大的值可减少网络负载', hidBackend: 'OTG 后端需要硬件支持 USB OTG,CH9329 是串口 HID 芯片方案', - // ATX 相关 atxActiveLevel: '活跃电平取决于您的硬件接线方式。高电平表示触发时输出高电压,低电平相反', wolInterface: '用于发送 Wake-on-LAN 魔术包的网络接口名称,如 eth0 或 br0', - // 网络相关 stunServer: 'STUN 服务器用于 NAT 穿透,帮助建立 P2P 连接。留空使用公共服务器', turnServer: 'TURN 服务器在 P2P 连接失败时提供中继。需要更多带宽但连接更可靠', - // 音频相关 audioQuality: '更高的质量意味着更好的音频效果,但需要更多的网络带宽', }, } diff --git a/web/src/lib/charToHid.ts b/web/src/lib/charToHid.ts index ebed68e1..c96340e0 100644 --- a/web/src/lib/charToHid.ts +++ b/web/src/lib/charToHid.ts @@ -1,5 +1,3 @@ -// Character to HID usage mapping for text paste functionality. -// The table follows US QWERTY layout semantics. import { type CanonicalKey } from '@/types/generated' import { keys } from '@/lib/keyboardMappings' @@ -10,7 +8,6 @@ export interface CharKeyMapping { } const charToKeyMap: Record = { - // Lowercase letters a: { key: keys.KeyA, shift: false }, b: { key: keys.KeyB, shift: false }, c: { key: keys.KeyC, shift: false }, @@ -38,7 +35,6 @@ const charToKeyMap: Record = { y: { key: keys.KeyY, shift: false }, z: { key: keys.KeyZ, shift: false }, - // Uppercase letters A: { key: keys.KeyA, shift: true }, B: { key: keys.KeyB, shift: true }, C: { key: keys.KeyC, shift: true }, @@ -66,7 +62,6 @@ const charToKeyMap: Record = { Y: { key: keys.KeyY, shift: true }, Z: { key: keys.KeyZ, shift: true }, - // Number row '0': { key: keys.Digit0, shift: false }, '1': { key: keys.Digit1, shift: false }, '2': { key: keys.Digit2, shift: false }, @@ -78,7 +73,6 @@ const charToKeyMap: Record = { '8': { key: keys.Digit8, shift: false }, '9': { key: keys.Digit9, shift: false }, - // Shifted number row symbols ')': { key: keys.Digit0, shift: true }, '!': { key: keys.Digit1, shift: true }, '@': { key: keys.Digit2, shift: true }, @@ -90,7 +84,6 @@ const charToKeyMap: Record = { '*': { key: keys.Digit8, shift: true }, '(': { key: keys.Digit9, shift: true }, - // Punctuation and symbols '-': { key: keys.Minus, shift: false }, '=': { key: keys.Equal, shift: false }, '[': { key: keys.BracketLeft, shift: false }, @@ -103,7 +96,6 @@ const charToKeyMap: Record = { '.': { key: keys.Period, shift: false }, '/': { key: keys.Slash, shift: false }, - // Shifted punctuation and symbols _: { key: keys.Minus, shift: true }, '+': { key: keys.Equal, shift: true }, '{': { key: keys.BracketLeft, shift: true }, @@ -116,7 +108,6 @@ const charToKeyMap: Record = { '>': { key: keys.Period, shift: true }, '?': { key: keys.Slash, shift: true }, - // Whitespace and control ' ': { key: keys.Space, shift: false }, '\t': { key: keys.Tab, shift: false }, '\n': { key: keys.Enter, shift: false }, diff --git a/web/src/lib/keyboardLayouts.ts b/web/src/lib/keyboardLayouts.ts index a0874021..4796bdde 100644 --- a/web/src/lib/keyboardLayouts.ts +++ b/web/src/lib/keyboardLayouts.ts @@ -1,4 +1,3 @@ -// Virtual keyboard layout data shared by the on-screen keyboard. export type KeyboardOsType = 'windows' | 'mac' | 'android' diff --git a/web/src/lib/keyboardMappings.ts b/web/src/lib/keyboardMappings.ts index 2a3f38d8..03078cd1 100644 --- a/web/src/lib/keyboardMappings.ts +++ b/web/src/lib/keyboardMappings.ts @@ -122,7 +122,6 @@ export const keys = { export type KeyName = keyof typeof keys -// Consumer Control Usage codes (for multimedia keys) export const consumerKeys = { PlayPause: 0x00cd, Stop: 0x00b7, diff --git a/web/src/lib/utils.ts b/web/src/lib/utils.ts index 5200ee2a..d7fc534c 100644 --- a/web/src/lib/utils.ts +++ b/web/src/lib/utils.ts @@ -15,8 +15,6 @@ export function generateUUID(): string { return crypto.randomUUID() } - // Fallback: generate UUID v4 manually - // Format: xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, (c) => { const r = (Math.random() * 16) | 0 const v = c === 'x' ? r : (r & 0x3) | 0x8 diff --git a/web/src/router/index.ts b/web/src/router/index.ts index 30c72d3c..cf8d5dc6 100644 --- a/web/src/router/index.ts +++ b/web/src/router/index.ts @@ -42,17 +42,14 @@ function t(key: string, params?: Record): string { return String(i18n.global.t(key, params as any)) } -// Navigation guard router.beforeEach(async (to, _from, next) => { const authStore = useAuthStore() - // Prevent access to setup after initialization const shouldCheckSetup = to.name === 'Setup' || !authStore.initialized if (shouldCheckSetup) { try { await authStore.checkSetupStatus() } catch { - // Continue anyway } } @@ -65,20 +62,17 @@ router.beforeEach(async (to, _from, next) => { try { await authStore.checkAuth() } catch { - // Not authenticated } } return next({ name: authStore.isAuthenticated ? 'Console' : 'Login' }) } - // Check authentication for protected routes if (to.meta.requiresAuth !== false) { if (!authStore.isAuthenticated) { try { await authStore.checkAuth() } catch (e) { - // Not authenticated if (e instanceof ApiError && e.status === 401 && !sessionExpiredNotified) { const normalized = e.message.toLowerCase() const isLoggedInElsewhere = normalized.includes('logged in elsewhere') diff --git a/web/src/types/generated.ts b/web/src/types/generated.ts index 1549ab9b..b8897a18 100644 --- a/web/src/types/generated.ts +++ b/web/src/types/generated.ts @@ -311,43 +311,26 @@ export interface WebConfig { ssl_key_path?: string; } -/** ttyd configuration (Web Terminal) */ export interface TtydConfig { - /** Enable auto-start */ enabled: boolean; - /** Port to listen on */ - port: number; - /** Shell to execute */ shell: string; } -/** gostc configuration (NAT traversal based on FRP) */ export interface GostcConfig { - /** Enable auto-start */ enabled: boolean; - /** Server address (hostname or IP) */ addr: string; - /** Client key from GOSTC management panel */ key: string; - /** Enable TLS */ tls: boolean; } -/** EasyTier configuration (P2P VPN) */ export interface EasytierConfig { - /** Enable auto-start */ enabled: boolean; - /** Network name */ network_name: string; - /** Network secret/password */ network_secret: string; - /** Peer node URLs */ peer_urls: string[]; - /** Virtual IP address (optional, auto-assigned if not set) */ virtual_ip?: string; } -/** Combined extensions configuration */ export interface ExtensionsConfig { ttyd: TtydConfig; gostc: GostcConfig; @@ -483,78 +466,50 @@ export interface EasytierConfigUpdate { virtual_ip?: string; } -/** Extension running status */ export type ExtensionStatus = - /** Binary not found at expected path */ | { state: "unavailable", data?: undefined } - /** Extension is stopped */ | { state: "stopped", data?: undefined } - /** Extension is running */ | { state: "running", data: { - /** Process ID */ pid: number; }} - /** Extension failed to start */ | { state: "failed", data: { - /** Error message */ error: string; }}; -/** easytier extension info */ export interface EasytierInfo { - /** Whether binary exists */ available: boolean; - /** Current status */ status: ExtensionStatus; - /** Configuration */ config: EasytierConfig; } -/** Extension info with status and config */ export interface ExtensionInfo { - /** Whether binary exists */ available: boolean; - /** Current status */ status: ExtensionStatus; } -/** Extension identifier (fixed set of supported extensions) */ export enum ExtensionId { - /** Web terminal (ttyd) */ Ttyd = "ttyd", - /** NAT traversal client (gostc) */ Gostc = "gostc", - /** P2P VPN (easytier) */ Easytier = "easytier", } -/** Extension logs response */ export interface ExtensionLogs { id: ExtensionId; logs: string[]; } -/** ttyd extension info */ export interface TtydInfo { - /** Whether binary exists */ available: boolean; - /** Current status */ status: ExtensionStatus; - /** Configuration */ config: TtydConfig; } -/** gostc extension info */ export interface GostcInfo { - /** Whether binary exists */ available: boolean; - /** Current status */ status: ExtensionStatus; - /** Configuration */ config: GostcConfig; } -/** All extensions status response */ export interface ExtensionsStatus { ttyd: TtydInfo; gostc: GostcInfo; @@ -677,7 +632,6 @@ export interface StreamConfigUpdate { /** Update ttyd config */ export interface TtydConfigUpdate { enabled?: boolean; - port?: number; shell?: string; } diff --git a/web/src/types/websocket.ts b/web/src/types/websocket.ts index 20b73751..18e2905b 100644 --- a/web/src/types/websocket.ts +++ b/web/src/types/websocket.ts @@ -1,6 +1,3 @@ -// Shared WebSocket types and utilities -// Used by useWebSocket, useHidWebSocket, and useAudioPlayer - import { ref, type Ref } from 'vue' /** WebSocket connection state */ diff --git a/web/src/views/ConsoleView.vue b/web/src/views/ConsoleView.vue index ac2f2583..198845a0 100644 --- a/web/src/views/ConsoleView.vue +++ b/web/src/views/ConsoleView.vue @@ -20,7 +20,6 @@ import { generateUUID } from '@/lib/utils' import { formatFpsValue } from '@/lib/fps' import type { VideoMode } from '@/components/VideoConfigPopover.vue' -// Components import StatusCard, { type StatusDetail } from '@/components/StatusCard.vue' import ActionBar from '@/components/ActionBar.vue' import InfoBar from '@/components/InfoBar.vue' @@ -84,10 +83,8 @@ const consoleEvents = useConsoleEvents({ onDeviceInfo: handleDeviceInfo, }) -// Video mode state const videoMode = ref('mjpeg') -// Video state const videoRef = ref(null) const webrtcVideoRef = ref(null) const videoContainerRef = ref(null) @@ -104,7 +101,6 @@ const streamSignalState = ref('ok') const streamSignalReason = ref(null) const streamNextRetryMs = ref(null) -// Video aspect ratio (dynamically updated from actual video dimensions) // Using string format "width/height" to let browser handle the ratio calculation const videoAspectRatio = ref(null) @@ -123,7 +119,6 @@ const clientsStats = ref>({}) // This allows us to identify our own stats in the clients_stat map const myClientId = generateUUID() -// HID state const mouseMode = ref<'absolute' | 'relative'>('absolute') const pressedKeys = ref([]) const keyboardLed = computed(() => ({ @@ -140,7 +135,6 @@ const isPointerLocked = ref(false) // Track pointer lock state /** Local overlay crosshair position (px, relative to video container); HID uses mousePosition separately */ const localCrosshairPos = ref<{ x: number; y: number } | null>(null) -// Mouse move throttling (60 Hz = ~16.67ms interval) const DEFAULT_MOUSE_MOVE_SEND_INTERVAL_MS = 16 let mouseMoveSendIntervalMs = DEFAULT_MOUSE_MOVE_SEND_INTERVAL_MS let mouseFlushTimer: ReturnType | null = null @@ -148,7 +142,6 @@ let lastMouseMoveSendTime = 0 let pendingMouseMove: { type: 'move' | 'move_abs'; x: number; y: number } | null = null let accumulatedDelta = { x: 0, y: 0 } // For relative mode: accumulate deltas between sends -// Cursor visibility (from localStorage, updated via storage event) const cursorVisible = ref(localStorage.getItem('hidShowCursor') !== 'false') let interactionListenersBound = false const isConsoleActive = ref(false) @@ -162,7 +155,6 @@ function syncMouseModeFromConfig() { } } -// Virtual keyboard state const virtualKeyboardVisible = ref(false) const virtualKeyboardAttached = ref(true) const statsSheetOpen = ref(false) @@ -173,18 +165,15 @@ const virtualKeyboardConsumerEnabled = computed(() => { return hid.otg_functions?.consumer !== false }) -// Change password dialog state const changePasswordDialogOpen = ref(false) const currentPassword = ref('') const newPassword = ref('') const confirmPassword = ref('') const changingPassword = ref(false) -// ttyd (web terminal) state const ttydStatus = ref<{ available: boolean; running: boolean } | null>(null) const showTerminalDialog = ref(false) -// Theme const isDark = ref(document.documentElement.classList.contains('dark')) // Status computed (Device status removed - now only Video, Audio, HID, MSD) @@ -207,9 +196,7 @@ const videoStatus = computed<'connected' | 'connecting' | 'disconnected' | 'erro return 'disconnected' }) -// Convert resolution to short format (e.g., 720p, 1080p, 2K, 4K) function getResolutionShortName(width: number, height: number): string { - // Common resolution mappings based on height if (height === 2160 || (height === 2160 && width === 4096)) return '4K' if (height === 1440) return '2K' if (height === 1080) return '1080p' @@ -218,11 +205,9 @@ function getResolutionShortName(width: number, height: number): string { if (height === 600) return '600p' if (height === 1024 && width === 1280) return '1024p' if (height === 960) return '960p' - // Fallback: use height + 'p' return `${height}p` } -// Quick info for status card trigger const videoQuickInfo = computed(() => { const stream = systemStore.stream if (!stream?.resolution) return '' @@ -235,12 +220,10 @@ const videoDetails = computed(() => { if (!stream) return [] const receivedFps = backendFps.value - // Input (capture) format → output (delivery) mode const inputFmt = stream.format || 'MJPEG' const outputFmt = videoMode.value === 'mjpeg' ? 'MJPEG' : `${videoMode.value.toUpperCase()} (WebRTC)` const formatDisplay = inputFmt === outputFmt ? inputFmt : `${inputFmt} → ${outputFmt}` - // Target / actual FPS combined const fpsDisplay = `${formatFpsValue(stream.targetFps ?? 0)} / ${formatFpsValue(receivedFps)}` const fpsStatus: StatusDetail['status'] = receivedFps > 5 ? 'ok' : receivedFps > 0 ? 'warning' : undefined @@ -269,7 +252,6 @@ const hidStatus = computed<'connected' | 'connecting' | 'disconnected' | 'error' } // MJPEG mode or WebRTC fallback: check WebSocket HID status - // If HID WebSocket has network error, show connecting (yellow) if (hidWs.networkError.value) return 'connecting' // If HID WebSocket is not connected (disconnected without error), show disconnected @@ -278,17 +260,14 @@ const hidStatus = computed<'connected' | 'connecting' | 'disconnected' | 'error' // If HID backend is unavailable (business error), show disconnected (gray) if (hidWs.hidUnavailable.value) return 'disconnected' - // Normal status based on system state if (hid?.available && hid.online) return 'connected' if (hid?.available && hid.initialized) return 'connecting' return 'disconnected' }) -// Quick info for HID status card trigger const hidQuickInfo = computed(() => { const hid = systemStore.hid if (!hid?.available) return '' - // Show current mode, not hardware capability return mouseMode.value === 'absolute' ? t('statusCard.absolute') : t('statusCard.relative') }) @@ -387,7 +366,6 @@ const hidDetails = computed(() => { } } - // Channel (merged with availability / connection state) let channelValue: string let channelStatus: StatusDetail['status'] if (videoMode.value !== 'mjpeg') { @@ -422,7 +400,6 @@ const hidDetails = computed(() => { return details }) -// Audio status computed const audioStatus = computed<'connected' | 'connecting' | 'disconnected' | 'error'>(() => { const audio = systemStore.audio if (!audio?.available) return 'disconnected' @@ -431,7 +408,6 @@ const audioStatus = computed<'connected' | 'connecting' | 'disconnected' | 'erro return 'disconnected' }) -// Helper function to translate audio quality function translateAudioQuality(quality: string | undefined): string { if (!quality) return t('common.unknown') const qualityLower = quality.toLowerCase() @@ -463,7 +439,6 @@ const audioDetails = computed(() => { ] }) -// MSD status computed const msdStatus = computed<'connected' | 'connecting' | 'disconnected' | 'error'>(() => { const msd = systemStore.msd if (!msd?.available) return 'disconnected' @@ -568,7 +543,6 @@ const hidHoverAlign = computed<'start' | 'end'>(() => { return showMsdStatusCard.value ? 'start' : 'end' }) -// Video handling let retryTimeoutId: number | null = null let retryCount = 0 let gracePeriodTimeoutId: number | null = null @@ -618,10 +592,8 @@ async function captureFrameOverlay() { ctx.drawImage(video, 0, 0, canvas.width, canvas.height) } - // Use JPEG to keep memory reasonable frameOverlayUrl.value = canvas.toDataURL('image/jpeg', 0.7) } catch { - // Best-effort only } } @@ -729,7 +701,6 @@ function handleVideoLoad() { gracePeriodTimeoutId = null } - // Reset all error states videoLoading.value = false videoError.value = false videoErrorMessage.value = '' @@ -738,7 +709,6 @@ function handleVideoLoad() { consecutiveErrors = 0 clearFrameOverlay() - // Auto-focus video container for immediate keyboard input const container = videoContainerRef.value if (container && typeof container.focus === 'function') { container.focus() @@ -787,13 +757,11 @@ function handleVideoError() { return } - // Clear any pending retries to avoid duplicate attempts if (retryTimeoutId !== null) { clearTimeout(retryTimeoutId) retryTimeoutId = null } - // Show loading state immediately videoLoading.value = true mjpegFrameReceived.value = false @@ -821,7 +789,6 @@ function handleStreamDeviceLost(data: { device: string; reason: string }) { } function scheduleWebRTCRecovery() { - // Clear any previous timer if (webrtcRecoveryTimerId !== null) { clearTimeout(webrtcRecoveryTimerId) webrtcRecoveryTimerId = null @@ -862,7 +829,6 @@ function scheduleWebRTCRecovery() { videoErrorMessage.value = '' webrtcRecoveryAttempts = 0 } else { - // Retry scheduleWebRTCRecovery() } } catch { @@ -883,21 +849,17 @@ function handleStreamRecovered(_data: { device: string }) { // Cancel any pending recovery timer – backend is back cancelWebRTCRecovery() - // Reset video error state videoError.value = false videoErrorMessage.value = '' - // Refresh video stream refreshVideo() } async function handleAudioStateChanged(data: { streaming: boolean; device: string | null }) { if (!data.streaming) { - // Audio stopped, disconnect unifiedAudio.disconnect() return } - // Audio started streaming if (videoMode.value !== 'mjpeg' && webrtc.isConnected.value) { // WebRTC mode: check if we have an audio track if (!webrtc.audioTrack.value) { @@ -907,9 +869,7 @@ async function handleAudioStateChanged(data: { streaming: boolean; device: strin await new Promise(resolve => setTimeout(resolve, 300)) await connectWebRTCSerial('audio track refresh') // After reconnect, the new session will have audio track - // and the watch on audioTrack will add it to MediaStream } else { - // We have audio track, ensure it's in MediaStream const currentStream = webrtcVideoRef.value?.srcObject as MediaStream | null if (currentStream && currentStream.getAudioTracks().length === 0) { currentStream.addTrack(webrtc.audioTrack.value) @@ -934,7 +894,6 @@ function handleStreamConfigChanging(data: any) { gracePeriodTimeoutId = null } - // Reset all counters and states videoRestarting.value = true pendingWebRTCReadyGate = true videoLoading.value = true @@ -952,7 +911,6 @@ function handleStreamConfigChanging(data: any) { } async function handleStreamConfigApplied(data: any) { - // Reset consecutive error counter for new config consecutiveErrors = 0 // Start grace period to ignore transient errors @@ -961,7 +919,6 @@ async function handleStreamConfigApplied(data: any) { consecutiveErrors = 0 // Also reset when grace period ends }, GRACE_PERIOD) - // Refresh video based on current mode videoRestarting.value = true // 如果正在进行模式切换,不需要在这里处理(WebRTCReady 事件会处理) @@ -1003,7 +960,6 @@ function handleStreamModeReady(data: { transition_id: string; mode: string }) { } function handleStreamModeSwitching(data: { transition_id: string; to_mode: string; from_mode: string }) { - // External mode switches: keep UI responsive and avoid black flash if (!isModeSwitching.value) { videoRestarting.value = true videoLoading.value = true @@ -1230,13 +1186,11 @@ function handleDeviceInfo(data: any) { }) } - // Skip mode sync if video config is being changed // This prevents false-positive mode changes during config switching if (data.video?.config_changing) { return } - // Sync video mode from server's stream_mode if (data.video?.stream_mode) { const serverMode = normalizeServerMode(data.video.stream_mode) if (!serverMode) return @@ -1256,7 +1210,6 @@ function handleDeviceInfo(data: any) { } } -// Handle stream mode change event from server (WebSocket broadcast) function handleStreamModeChanged(data: { mode: string; previous_mode: string }) { const newMode = normalizeServerMode(data.mode) if (!newMode) return @@ -1267,7 +1220,6 @@ function handleStreamModeChanged(data: { mode: string; previous_mode: string }) return } - // Show toast notification only if this is an external mode change toast.info(t('console.streamModeChanged'), { description: t('console.streamModeChangedDesc', { mode: data.mode.toUpperCase() }), duration: 5000, @@ -1300,7 +1252,6 @@ function refreshVideo() { mjpegTimestamp.value = Date.now() // For MJPEG streams, the 'load' event fires when first frame arrives - // But on reconnection it may not fire again, so use a timeout as fallback setTimeout(() => { isRefreshingVideo = false // Clear loading state after timeout - if stream failed, error handler will show error @@ -1508,7 +1459,6 @@ async function switchToMJPEG() { } } catch (e) { console.error('Failed to switch to MJPEG mode:', e) - // Continue anyway - the mode might already be correct } // Step 2: Disconnect WebRTC if connected or session still exists @@ -1542,7 +1492,6 @@ function syncToServerMode(mode: VideoMode) { } } -// Handle video mode change async function handleVideoModeChange(mode: VideoMode) { // 防止重复切换和竞态条件 if (mode === videoMode.value) return @@ -1592,10 +1541,8 @@ watch(() => webrtc.videoTrack.value, async (track) => { // Watch for WebRTC audio track changes - update MediaStream when audio arrives watch(() => webrtc.audioTrack.value, async (track) => { if (track && webrtcVideoRef.value && videoMode.value !== 'mjpeg') { - // Audio track arrived, update the MediaStream to include it const currentStream = webrtcVideoRef.value.srcObject as MediaStream | null if (currentStream && currentStream.getAudioTracks().length === 0) { - // Add audio track to existing stream currentStream.addTrack(track) } } @@ -1607,7 +1554,6 @@ watch(webrtcVideoRef, (el) => { }, { immediate: true }) // Watch for WebRTC stats to update FPS display -// Watch the ref directly with deep: true to detect property changes watch(webrtc.stats, (stats) => { if (videoMode.value !== 'mjpeg' && stats.framesPerSecond > 0) { backendFps.value = Math.round(stats.framesPerSecond) @@ -1626,7 +1572,6 @@ let webrtcReconnectFailures = 0 watch(() => webrtc.state.value, (newState, oldState) => { console.log('[WebRTC] State changed:', oldState, '->', newState) - // Clear any pending reconnect if (webrtcReconnectTimeout) { clearTimeout(webrtcReconnectTimeout) webrtcReconnectTimeout = null @@ -1644,7 +1589,6 @@ watch(() => webrtc.state.value, (newState, oldState) => { }) } } else if (newState === 'disconnected' || newState === 'failed') { - // Don't immediately set offline - wait for potential reconnect // The device_info event will eventually sync the correct state } } @@ -1653,7 +1597,6 @@ watch(() => webrtc.state.value, (newState, oldState) => { return } - // Auto-reconnect when disconnected (but was previously connected) if (newState === 'disconnected' && oldState === 'connected' && videoMode.value !== 'mjpeg') { webrtcReconnectTimeout = setTimeout(async () => { if (videoMode.value !== 'mjpeg' && webrtc.state.value === 'disconnected') { @@ -1676,8 +1619,6 @@ watch(() => webrtc.state.value, (newState, oldState) => { } // Handle direct 'failed' state (ICE or DTLS failure) - // Allow one automatic retry before marking as failed, consistent with - // the disconnected->reconnect path that allows 2 failures. if (newState === 'failed' && videoMode.value !== 'mjpeg') { webrtcReconnectFailures += 1 if (webrtcReconnectFailures >= 2) { @@ -1706,20 +1647,17 @@ async function toggleFullscreen() { } } -// Theme toggle function toggleTheme() { isDark.value = !isDark.value document.documentElement.classList.toggle('dark', isDark.value) localStorage.setItem('theme', isDark.value ? 'dark' : 'light') } -// Logout async function logout() { await authStore.logout() router.push('/login') } -// Change password function async function handleChangePassword() { if (!newPassword.value || !confirmPassword.value) { toast.error(t('auth.passwordRequired')) @@ -1741,20 +1679,17 @@ async function handleChangePassword() { await authApi.changePassword(currentPassword.value, newPassword.value) toast.success(t('auth.passwordChanged')) - // Reset form and close dialog currentPassword.value = '' newPassword.value = '' confirmPassword.value = '' changePasswordDialogOpen.value = false } catch (e) { - // Error toast is shown by API layer console.info('[ChangePassword] Failed:', e) } finally { changingPassword.value = false } } -// ttyd (web terminal) functions function openTerminal() { if (!ttydStatus.value?.running) return showTerminalDialog.value = true @@ -1764,13 +1699,11 @@ function openTerminalInNewTab() { window.open('/api/terminal/', '_blank') } -// ATX actions async function handlePowerShort() { try { await atxApi.power('short') await systemStore.fetchAtxState() } catch { - // ATX action failed } } @@ -1779,7 +1712,6 @@ async function handlePowerLong() { await atxApi.power('long') await systemStore.fetchAtxState() } catch { - // ATX action failed } } @@ -1788,7 +1720,6 @@ async function handleReset() { await atxApi.power('reset') await systemStore.fetchAtxState() } catch { - // ATX action failed } } @@ -1801,14 +1732,10 @@ async function handleWol(mac: string) { } } -// HID error handling - silently handle all HID errors function handleHidError(_error: any, _operation: string) { - // All HID errors are silently ignored } -// HID channel selection: use WebRTC DataChannel when available, fallback to WebSocket function sendKeyboardEvent(type: 'down' | 'up', key: CanonicalKey, modifier?: number) { - // In WebRTC mode with DataChannel ready, use DataChannel for lower latency if (videoMode.value !== 'mjpeg' && webrtc.dataChannelReady.value) { const event: HidKeyboardEvent = { type: type === 'down' ? 'keydown' : 'keyup', @@ -1817,14 +1744,11 @@ function sendKeyboardEvent(type: 'down' | 'up', key: CanonicalKey, modifier?: nu } const sent = webrtc.sendKeyboard(event) if (sent) return - // Fallback to WebSocket if DataChannel send failed } - // Use WebSocket as fallback or for MJPEG mode hidApi.keyboard(type, key, modifier).catch(err => handleHidError(err, `keyboard ${type}`)) } function sendMouseEvent(data: { type: 'move' | 'move_abs' | 'down' | 'up' | 'scroll'; x?: number; y?: number; button?: 'left' | 'right' | 'middle'; scroll?: number }) { - // In WebRTC mode with DataChannel ready, use DataChannel for lower latency if (videoMode.value !== 'mjpeg' && webrtc.dataChannelReady.value) { const event: HidMouseEvent = { type: data.type === 'move_abs' ? 'moveabs' : data.type, @@ -1835,50 +1759,37 @@ function sendMouseEvent(data: { type: 'move' | 'move_abs' | 'down' | 'up' | 'scr } const sent = webrtc.sendMouse(event) if (sent) return - // Fallback to WebSocket if DataChannel send failed } - // Use WebSocket as fallback or for MJPEG mode hidApi.mouse(data).catch(err => handleHidError(err, `mouse ${data.type}`)) } -// Check if a key should be blocked (prevented from default behavior) function shouldBlockKey(e: KeyboardEvent): boolean { - // In fullscreen mode, block all keys for maximum capture if (isFullscreen.value) { return true } - // Don't block critical browser shortcuts in non-fullscreen mode const key = e.key.toUpperCase() - // Don't block Ctrl+W (close tab), Ctrl+T (new tab), Ctrl+N (new window) if (e.ctrlKey && ['W', 'T', 'N'].includes(key)) return false - // Don't block F11 (browser fullscreen toggle) if (key === 'F11') return false - // Don't block Alt+Tab (already can't capture it anyway) if (e.altKey && key === 'TAB') return false - // Block everything else return true } -// Keyboard/Mouse event handling function handleKeyDown(e: KeyboardEvent) { const container = videoContainerRef.value if (!container) return - // Check focus in non-fullscreen mode if (!isFullscreen.value && !container.contains(document.activeElement)) return - // Try to block the key if appropriate if (shouldBlockKey(e)) { e.preventDefault() e.stopPropagation() } - // Show hint for Meta key in non-fullscreen mode if (!isFullscreen.value && (e.metaKey || e.key === 'Meta')) { toast.info(t('console.metaKeyHint'), { description: t('console.metaKeyHintDesc'), @@ -1905,10 +1816,8 @@ function handleKeyUp(e: KeyboardEvent) { const container = videoContainerRef.value if (!container) return - // Check focus in non-fullscreen mode if (!isFullscreen.value && !container.contains(document.activeElement)) return - // Try to block the key if appropriate if (shouldBlockKey(e)) { e.preventDefault() e.stopPropagation() @@ -2049,20 +1958,16 @@ function handleMouseMove(e: MouseEvent) { pendingMouseMove = { type: 'move_abs', x, y } requestMouseMoveFlush() } else { - // Relative mode: use movementX/Y when pointer is locked if (isPointerLocked.value) { const dx = e.movementX const dy = e.movementY - // Only accumulate if there's actual movement if (dx !== 0 || dy !== 0) { - // Accumulate deltas for throttled sending accumulatedDelta.x += dx accumulatedDelta.y += dy requestMouseMoveFlush() } - // Update display position (accumulated delta for display only) mousePosition.value = { x: mousePosition.value.x + dx, y: mousePosition.value.y + dy, @@ -2086,13 +1991,11 @@ function flushMouseMoveOnce(): boolean { if (accumulatedDelta.x === 0 && accumulatedDelta.y === 0) return false - // Clamp to i8 range (-127 to 127) const clampedDx = Math.max(-127, Math.min(127, accumulatedDelta.x)) const clampedDy = Math.max(-127, Math.min(127, accumulatedDelta.y)) sendMouseEvent({ type: 'move', x: clampedDx, y: clampedDy }) - // Subtract sent amount (keep remainder for next send if clamped) accumulatedDelta.x -= clampedDx accumulatedDelta.y -= clampedDy return true @@ -2147,13 +2050,11 @@ function requestMouseMoveFlush() { scheduleMouseMoveFlush() } -// Track pressed mouse button for window-level mouseup handling const pressedMouseButton = ref<'left' | 'right' | 'middle' | null>(null) function handleMouseDown(e: MouseEvent) { e.preventDefault() - // Auto-focus the video container to enable keyboard input const container = videoContainerRef.value if (container && document.activeElement !== container) { if (typeof container.focus === 'function') { @@ -2161,7 +2062,6 @@ function handleMouseDown(e: MouseEvent) { } } - // In relative mode, request pointer lock on first click if (mouseMode.value === 'relative' && !isPointerLocked.value) { requestPointerLock() return @@ -2186,7 +2086,6 @@ function handleMouseUp(e: MouseEvent) { handleMouseUpInternal(e.button) } -// Window-level mouseup handler (catches releases outside the container) function handleWindowMouseUp(e: MouseEvent) { if (pressedMouseButton.value !== null) { handleMouseUpInternal(e.button) @@ -2201,7 +2100,6 @@ function handleMouseUpInternal(rawButton: number) { const button = rawButton === 0 ? 'left' : rawButton === 2 ? 'right' : 'middle' - // Only send if this button was actually pressed if (pressedMouseButton.value !== button) { return } @@ -2220,7 +2118,6 @@ function handleContextMenu(e: MouseEvent) { e.preventDefault() } -// Pointer Lock API for relative mouse mode function requestPointerLock() { const container = videoContainerRef.value if (!container) return @@ -2243,7 +2140,6 @@ function handlePointerLockChange() { isPointerLocked.value = document.pointerLockElement === container if (isPointerLocked.value) { - // Reset mouse position display when locked mousePosition.value = { x: 0, y: 0 } if (cursorVisible.value && container) { const r = container.getBoundingClientRect() @@ -2269,7 +2165,6 @@ function handleFullscreenChange() { function handleBlur() { pressedKeys.value = [] activeModifierMask.value = 0 - // Release any pressed mouse button when window loses focus if (pressedMouseButton.value !== null) { const button = pressedMouseButton.value pressedMouseButton.value = null @@ -2277,7 +2172,6 @@ function handleBlur() { } } -// Handle cursor visibility change from HidConfigPopover function handleCursorVisibilityChange(e: Event) { const customEvent = e as CustomEvent<{ visible: boolean }> cursorVisible.value = customEvent.detail.visible @@ -2365,7 +2259,6 @@ async function activateConsoleView() { void systemStore.fetchAllStates() void configStore.refreshHid().then(() => syncMouseModeFromConfig()).catch(() => {}) - // Ensure HID WebSocket is connected when console becomes active if (!hidWs.connected.value) { hidWs.connect().catch(() => {}) } @@ -2395,28 +2288,22 @@ function deactivateConsoleView() { unregisterInteractionListeners() } -// ActionBar handlers -// (MSD and Settings are now handled by ActionBar component directly) function handleToggleVirtualKeyboard() { virtualKeyboardVisible.value = !virtualKeyboardVisible.value } -// Virtual keyboard key event handlers function handleVirtualKeyDown(key: CanonicalKey) { - // Add to pressedKeys for InfoBar display if (!pressedKeys.value.includes(key)) { pressedKeys.value = [...pressedKeys.value, key] } } function handleVirtualKeyUp(key: CanonicalKey) { - // Remove from pressedKeys pressedKeys.value = pressedKeys.value.filter(k => k !== key) } function handleToggleMouseMode() { - // Exit pointer lock when switching away from relative mode if (mouseMode.value === 'relative' && isPointerLocked.value) { exitPointerLock() } @@ -2424,7 +2311,6 @@ function handleToggleMouseMode() { mouseMode.value = mouseMode.value === 'absolute' ? 'relative' : 'absolute' pendingMouseMove = null accumulatedDelta = { x: 0, y: 0 } - // Reset position when switching modes lastMousePosition.value = { x: 0, y: 0 } mousePosition.value = { x: 0, y: 0 } @@ -2436,16 +2322,13 @@ function handleToggleMouseMode() { } } -// Lifecycle onMounted(async () => { // 1. 先订阅 WebSocket 事件,再连接(内部会 connect) consoleEvents.subscribe() - // 3. Watch WebSocket connection states and sync to store watch([wsConnected, wsNetworkError], ([connected, netError], [_prevConnected, prevNetError]) => { systemStore.updateWsConnection(connected, netError) - // Auto-refresh video when network recovers (wsNetworkError: true -> false) if (prevNetError === true && netError === false && connected === true) { refreshVideo() } @@ -2476,7 +2359,6 @@ onMounted(async () => { // Note: Video mode is now synced from server via device_info event // The handleDeviceInfo function will automatically switch to the server's mode - // localStorage preference is only used when server mode matches try { const modeResp = await streamApi.getMode() const serverMode = normalizeServerMode(modeResp?.mode) @@ -2504,13 +2386,11 @@ onUnmounted(() => { initialModeRestoreDone = false initialModeRestoreInProgress = false - // Clear mouse flush timer if (mouseFlushTimer !== null) { clearTimeout(mouseFlushTimer) mouseFlushTimer = null } - // Clear all timers if (retryTimeoutId !== null) { clearTimeout(retryTimeoutId) retryTimeoutId = null @@ -2522,7 +2402,6 @@ onUnmounted(() => { cancelWebRTCRecovery() videoSession.clearWaiters() - // Reset counters retryCount = 0 consoleEvents.unsubscribe() @@ -2533,7 +2412,6 @@ onUnmounted(() => { void webrtc.disconnect() } - // Exit pointer lock if active exitPointerLock() }) diff --git a/web/src/views/SettingsView.vue b/web/src/views/SettingsView.vue index 3a399a52..ab11c273 100644 --- a/web/src/views/SettingsView.vue +++ b/web/src/views/SettingsView.vue @@ -106,7 +106,6 @@ const systemStore = useSystemStore() const configStore = useConfigStore() const authStore = useAuthStore() -// Settings state const activeSection = ref('appearance') const mobileMenuOpen = ref(false) const loading = ref(false) @@ -127,7 +126,6 @@ const SETTINGS_SECTION_IDS = new Set([ 'about', ]) -// Navigation structure const navGroups = computed(() => [ { title: t('settings.system'), @@ -175,10 +173,8 @@ function normalizeSettingsSection(value: unknown): string | null { return SETTINGS_SECTION_IDS.has(value) ? value : null } -// Theme const theme = ref<'light' | 'dark' | 'system'>('system') -// Account settings const usernameInput = ref('') const usernamePassword = ref('') const usernameSaving = ref(false) @@ -192,7 +188,6 @@ const passwordSaved = ref(false) const passwordError = ref('') const showPasswords = ref(false) -// Auth config state const authConfig = ref({ session_timeout_secs: 3600 * 24, single_user_allow_multiple_sessions: false, @@ -201,7 +196,6 @@ const authConfig = ref({ }) const authConfigLoading = ref(false) -// Extensions management const extensions = ref(null) const extensionsLoading = ref(false) const extensionLogs = ref>({ @@ -215,17 +209,14 @@ const showLogs = ref>({ easytier: false, }) -// Terminal dialog const showTerminalDialog = ref(false) -// Extension config (local edit state) const extConfig = ref({ ttyd: { enabled: false, shell: '/bin/bash' }, gostc: { enabled: false, addr: '', key: '', tls: true }, easytier: { enabled: false, network_name: '', network_secret: '', peer_urls: [] as string[], virtual_ip: '' }, }) -// RustDesk config state const rustdeskConfig = ref(null) const rustdeskStatus = ref(null) const rustdeskPassword = ref(null) @@ -239,7 +230,6 @@ const rustdeskLocalConfig = ref({ relay_key: '', }) -// RTSP config state const rtspStatus = ref(null) const rtspLoading = ref(false) const rtspLocalConfig = ref({ @@ -267,7 +257,6 @@ const rtspStreamUrl = computed(() => { return `rtsp://${host}:${port}/${path}` }) -// Web server config state const webServerConfig = ref({ http_port: 8080, https_port: 8443, @@ -277,14 +266,12 @@ const webServerConfig = ref({ has_custom_cert: false, }) const webServerLoading = ref(false) -// SSL certificate state const sslCertPem = ref('') const sslKeyPem = ref('') const certSaving = ref(false) const certClearing = ref(false) const showRestartDialog = ref(false) const restarting = ref(false) -// Auto-restart flow (no dialog needed for web-config saves) const autoRestarting = ref(false) const autoRestartFailed = ref(false) // For HTTPS targets: can't poll (self-signed cert), show manual link instead @@ -339,7 +326,6 @@ const previewAccessUrl = computed(() => { return `${scheme}://${host}:${port}` }) -// Config interface DeviceConfig { video: Array<{ path: string @@ -389,14 +375,12 @@ const config = ref({ msd_enabled: false, msd_dir: '', encoder_backend: 'auto', - // STUN/TURN settings stun_server: '', turn_server: '', turn_username: '', turn_password: '', }) -// Tracks whether TURN password is configured on the server const hasTurnPassword = ref(false) type OtgSelfCheckLevel = 'info' | 'warn' | 'error' @@ -658,7 +642,6 @@ async function onRunVideoEncoderSelfCheckClick() { await runVideoEncoderSelfCheck() } -// USB devices state const usbDevices = ref([]) const usbDevicesLoading = ref(false) const usbDevicesError = ref('') @@ -683,11 +666,9 @@ async function confirmUsbReset() { try { await usbApi.resetDevice(usbResetTarget.value.bus_num, usbResetTarget.value.dev_num) } catch { - // Error already shown by request helper toast } finally { usbResetLoading.value = false usbResetTarget.value = null - // Refresh the list after a short delay for USB re-enumeration setTimeout(() => fetchUsbDevices(), 1500) } } @@ -782,14 +763,12 @@ const isHidFunctionSelectionValid = computed(() => { return !!(f.keyboard || f.mouse_relative || f.mouse_absolute || f.consumer) }) -// OTG Descriptor settings const otgVendorIdHex = ref('1d6b') const otgProductIdHex = ref('0104') const otgManufacturer = ref('One-KVM') const otgProduct = ref('One-KVM USB Device') const otgSerialNumber = ref('') -// Validate hex input const validateHex = (event: Event, _field: string) => { const input = event.target as HTMLInputElement input.value = input.value.replace(/[^0-9a-fA-F]/g, '').toLowerCase() @@ -807,7 +786,6 @@ watch(bindMode, (mode) => { } }) -// ATX config state const atxConfig = ref({ enabled: false, power: { @@ -833,7 +811,6 @@ const atxConfig = ref({ wol_interface: '', }) -// ATX devices for discovery const atxDevices = ref({ gpio_chips: [], usb_relays: [], @@ -856,7 +833,6 @@ const isSharedAtxSerialRelay = computed(() => { ) }) -// Encoder backend const availableBackends = ref([]) const selectedBackendFormats = computed(() => { @@ -921,7 +897,6 @@ const availableFps = computed(() => { return currentRes ? currentRes.fps : [] }) -// Keep the selected format aligned with currently selectable formats. watch( selectableFormats, () => { @@ -938,7 +913,6 @@ watch( { deep: true }, ) -// Watch for format change to set default resolution watch(() => config.value.video_format, () => { if (availableResolutions.value.length > 0) { const isValid = availableResolutions.value.some( @@ -955,7 +929,6 @@ watch(() => config.value.video_format, () => { } }) -// Watch for resolution change to set default FPS watch(() => [config.value.video_width, config.value.video_height], () => { const fpsList = availableFps.value if (fpsList.length > 0) { @@ -975,7 +948,6 @@ watch(() => authStore.user, (value) => { }) -// Format bytes to human readable string function formatBytes(bytes: number): string { if (bytes === 0) return '0 B' const k = 1024 @@ -984,7 +956,6 @@ function formatBytes(bytes: number): string { return `${(bytes / Math.pow(k, i)).toFixed(1)} ${sizes[i]}` } -// Theme handling function setTheme(newTheme: 'light' | 'dark' | 'system') { theme.value = newTheme localStorage.setItem('theme', newTheme) @@ -997,7 +968,6 @@ function setTheme(newTheme: 'light' | 'dark' | 'system') { } } -// Account updates async function changeUsername() { usernameError.value = '' usernameSaved.value = false @@ -1062,17 +1032,13 @@ async function changePassword() { } } -// Save config using domain-separated APIs async function saveConfig() { loading.value = true saved.value = false try { - // Save only config related to the active section. // Sequential awaits: backend ConfigStore uses read-modify-write; parallel PATCH - // requests could overwrite each other's section (last writer wins on full JSON). - // Video config (including encoder and WebRTC/STUN/TURN settings) if (activeSection.value === 'video') { await configStore.updateVideo({ device: config.value.video_device || undefined, @@ -1100,7 +1066,6 @@ async function saveConfig() { ch9329_port: config.value.hid_serial_device || undefined, ch9329_baudrate: config.value.hid_serial_baudrate, } - // Add descriptor config for OTG backend if (config.value.hid_backend === 'otg') { hidUpdate.otg_descriptor = { vendor_id: parseInt(otgVendorIdHex.value, 16) || 0x1d6b, @@ -1120,7 +1085,6 @@ async function saveConfig() { }) } - // MSD config if (activeSection.value === 'msd') { await configStore.updateMsd({ msd_dir: config.value.msd_dir || undefined, @@ -1137,10 +1101,8 @@ async function saveConfig() { } } -// Load config using domain-separated APIs async function loadConfig() { try { - // Load all domain configs in parallel const [video, stream, hid, msd] = await Promise.all([ configStore.refreshVideo(), configStore.refreshStream(), @@ -1170,17 +1132,14 @@ async function loadConfig() { msd_enabled: msd.enabled || false, msd_dir: msd.msd_dir || '', encoder_backend: stream.encoder || 'auto', - // STUN/TURN settings stun_server: stream.stun_server || '', turn_server: stream.turn_server || '', turn_username: stream.turn_username || '', turn_password: '', // Password is never returned from server; set-only field } - // Track whether TURN password is configured hasTurnPassword.value = stream.has_turn_password || false - // Load OTG descriptor config if (hid.otg_descriptor) { otgVendorIdHex.value = hid.otg_descriptor.vendor_id?.toString(16).padStart(4, '0') || '1d6b' otgProductIdHex.value = hid.otg_descriptor.product_id?.toString(16).padStart(4, '0') || '0104' @@ -1211,7 +1170,6 @@ async function loadBackends() { } } -// Auth config functions async function loadAuthConfig() { authConfigLoading.value = true try { @@ -1236,12 +1194,10 @@ async function saveAuthConfig() { } } -// Extension management functions async function loadExtensions() { extensionsLoading.value = true try { extensions.value = await extensionsApi.getAll() - // Sync config from server if (extensions.value) { const ttyd = extensions.value.ttyd.config extConfig.value.ttyd = { @@ -1359,7 +1315,6 @@ function removeEasytierPeer(index: number) { } } -// ATX management functions async function loadAtxConfig() { try { const config = await configStore.refreshAtx() @@ -1485,7 +1440,6 @@ watch( }, ) -// RustDesk management functions async function loadRustdeskConfig() { rustdeskLoading.value = true try { @@ -1573,7 +1527,6 @@ function removeBindAddress(index: number) { } } -// Web server config functions async function loadWebServerConfig() { try { const config = await configStore.refreshWeb() @@ -1668,7 +1621,6 @@ async function pollUntilReady(targetOrigin: string, maxMs = 30000): Promise (saved.value = false), 2000) @@ -1878,7 +1829,6 @@ async function regenerateRustdeskPassword() { async function startRustdesk() { rustdeskLoading.value = true try { - // Enable and save config to start the service await configStore.updateRustdesk({ enabled: true }) rustdeskLocalConfig.value.enabled = true await loadRustdeskConfig() @@ -1892,7 +1842,6 @@ async function startRustdesk() { async function stopRustdesk() { rustdeskLoading.value = true try { - // Disable and save config to stop the service await configStore.updateRustdesk({ enabled: false }) rustdeskLocalConfig.value.enabled = false await loadRustdeskConfig() @@ -1919,7 +1868,6 @@ function getRustdeskServiceStatusText(status: string | undefined): string { case 'stopped': return t('extensions.stopped') case 'not_initialized': return t('extensions.rustdesk.notInitialized') default: - // Handle "error: xxx" format if (status.startsWith('error:')) return t('extensions.failed') return status } @@ -1933,7 +1881,6 @@ function getRustdeskRendezvousStatusText(status: string | null | undefined): str case 'connecting': return t('extensions.rustdesk.connecting') case 'disconnected': return t('extensions.rustdesk.disconnected') default: - // Handle "error: xxx" format if (status.startsWith('error:')) return t('extensions.failed') return status } @@ -1950,7 +1897,6 @@ function getRustdeskStatusClass(status: string | null | undefined): string { case 'not_initialized': case 'disconnected': return 'bg-gray-400' default: - // Handle "error: xxx" format if (status?.startsWith('error:')) return 'bg-red-500' return 'bg-gray-400' } @@ -2058,9 +2004,7 @@ function getRtspStatusClass(status: string | undefined): string { } } -// Lifecycle onMounted(async () => { - // Load theme preference const storedTheme = localStorage.getItem('theme') as 'light' | 'dark' | 'system' | null if (storedTheme) { theme.value = storedTheme diff --git a/web/src/views/SetupView.vue b/web/src/views/SetupView.vue index cd2098c0..369e75a5 100644 --- a/web/src/views/SetupView.vue +++ b/web/src/views/SetupView.vue @@ -42,20 +42,17 @@ const { t } = useI18n() const router = useRouter() const authStore = useAuthStore() -// Steps: 1 = Account, 2 = Audio/Video, 3 = HID, 4 = Extensions const step = ref(1) const totalSteps = 4 const loading = ref(false) const error = ref('') const slideDirection = ref<'forward' | 'backward'>('forward') -// Account settings const username = ref('') const password = ref('') const confirmPassword = ref('') const showPassword = ref(false) -// Form validation states const usernameError = ref('') const passwordError = ref('') const confirmPasswordError = ref('') @@ -63,17 +60,14 @@ const usernameTouched = ref(false) const passwordTouched = ref(false) const confirmPasswordTouched = ref(false) -// Video settings const videoDevice = ref('') const videoFormat = ref('') const videoResolution = ref('') const videoFps = ref(null) -// Audio settings const audioDevice = ref('') const audioEnabled = ref(true) -// HID settings const hidBackend = ref('ch9329') const ch9329Port = ref('') const ch9329Baudrate = ref(9600) @@ -83,7 +77,6 @@ const otgMsdEnabled = ref(true) const otgEndpointBudget = ref<'five' | 'six' | 'unlimited'>('six') const otgKeyboardLeds = ref(true) -// Extension settings const ttydEnabled = ref(false) const ttydAvailable = ref(false) @@ -137,7 +130,6 @@ const devices = ref({ }, }) -// Password strength calculation const passwordStrength = computed(() => { const pwd = password.value if (!pwd) return 0 @@ -183,7 +175,6 @@ async function refreshDeviceList() { ttydAvailable.value = result.extensions.ttyd_available } } catch { - // keep current list } finally { refreshingDevices.value = false } @@ -195,13 +186,11 @@ const availableFormats = computed(() => { return device?.formats || [] }) -// Computed: available resolutions for selected format const availableResolutions = computed(() => { const format = availableFormats.value.find((f) => f.format === videoFormat.value) return format?.resolutions || [] }) -// Computed: available FPS for selected resolution const availableFps = computed(() => { const [width, height] = (videoResolution.value || '').split('x').map(Number) const resolution = availableResolutions.value.find( @@ -252,10 +241,8 @@ function applyOtgDefaults() { otgKeyboardLeds.value = otgEndpointBudget.value !== 'five' } -// Common baud rates for CH9329 const baudRates = [9600, 19200, 38400, 57600, 115200] -// Step labels for the indicator const stepLabels = computed(() => [ t('setup.stepAccount'), t('setup.stepAudioVideo'), @@ -263,7 +250,6 @@ const stepLabels = computed(() => [ t('setup.stepExtensions'), ]) -// Real-time validation functions function validateUsername() { usernameTouched.value = true if (username.value.length === 0) { @@ -284,7 +270,6 @@ function validatePassword() { } else { passwordError.value = '' } - // Also validate confirm password if it was touched if (confirmPasswordTouched.value) { validateConfirmPassword() } @@ -335,12 +320,10 @@ watch(videoDevice, (newDevice) => { } }) -// Watch format change to auto-select best resolution watch(videoFormat, () => { videoResolution.value = '' videoFps.value = null if (availableResolutions.value.length > 0) { - // Prefer 1080p if available, otherwise highest resolution const r1080 = availableResolutions.value.find((r) => r.width === 1920 && r.height === 1080) const r720 = availableResolutions.value.find((r) => r.width === 1280 && r.height === 720) const best = r1080 || r720 || availableResolutions.value[0] @@ -350,11 +333,9 @@ watch(videoFormat, () => { } }) -// Watch resolution change to auto-select FPS watch(videoResolution, () => { videoFps.value = null if (availableFps.value.length > 0) { - // Prefer 30fps if available videoFps.value = availableFps.value.includes(30) ? 30 : availableFps.value[0] || null } }) @@ -389,7 +370,6 @@ onMounted(async () => { ch9329Port.value = result.serial[0].path } - // Auto-select first UDC for OTG if (result.udc.length > 0 && result.udc[0]) { otgUdc.value = result.udc[0].name } @@ -407,7 +387,6 @@ onMounted(async () => { ttydAvailable.value = result.extensions.ttyd_available } } catch { - // Use defaults } // Load encoder backends @@ -415,10 +394,8 @@ onMounted(async () => { const codecsResult = await streamApi.getCodecs() availableBackends.value = codecsResult.backends || [] } catch { - // Use defaults } - // Add keyboard navigation document.addEventListener('keydown', handleKeyDown) }) @@ -427,7 +404,6 @@ onUnmounted(() => { }) function handleKeyDown(e: KeyboardEvent) { - // Don't interfere with input fields if (e.target instanceof HTMLInputElement || e.target instanceof HTMLTextAreaElement) { return } @@ -446,7 +422,6 @@ function handleKeyDown(e: KeyboardEvent) { } function validateStep1(): boolean { - // Trigger validation for all fields validateUsername() validatePassword() validateConfirmPassword() @@ -521,7 +496,6 @@ async function handleSetup() { loading.value = true - // Parse resolution const [width, height] = (videoResolution.value || '').split('x').map(Number) const setupData: Parameters[0] = { @@ -529,7 +503,6 @@ async function handleSetup() { password: password.value, } - // Video settings if (videoDevice.value) { setupData.video_device = videoDevice.value } @@ -544,7 +517,6 @@ async function handleSetup() { setupData.video_fps = toConfigFps(videoFps.value) } - // HID settings setupData.hid_backend = hidBackend.value if (hidBackend.value === 'ch9329') { setupData.hid_ch9329_port = ch9329Port.value @@ -563,18 +535,15 @@ async function handleSetup() { setupData.encoder_backend = encoderBackend.value } - // Audio settings if (audioDevice.value && audioDevice.value !== '__none__') { setupData.audio_device = audioDevice.value } - // Extension settings setupData.ttyd_enabled = ttydEnabled.value const success = await authStore.setup(setupData) if (success) { - // Auto login after setup await authStore.login(username.value, password.value) router.push('/') } else { @@ -584,7 +553,6 @@ async function handleSetup() { loading.value = false } -// Step icon component helper const stepIcons = [User, Video, Keyboard, Puzzle]