From 206594e292a1ebca51dfad2c19812fb956a52ebe Mon Sep 17 00:00:00 2001 From: mofeng-git Date: Sun, 11 Jan 2026 10:41:57 +0800 Subject: [PATCH] feat(video): transactional mode switching with unified frontend orchestration; broader video input format support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Backend: mode-switch transactions with transition_id; /stream/mode returns switching/transition_id and the actual codec - Events: add mode_switching/mode_ready; config/webrtc_ready/mode_changed now carry the transaction - Encoding/formats: extend NV21/NV16/NV24/RGB/BGR input and conversion paths; optimize RKMPP direct input - Frontend: useVideoSession unifies switching; on failure, actually fall back to MJPEG; fix menu format sync - Cleanup: useVideoStream reduced to MJPEG-only --- libs/hwcodec/build.rs | 41 +- libs/hwcodec/cpp/common/ffmpeg_ffi.h | 6 +- .../cpp/ffmpeg_ram/ffmpeg_ram_encode.cpp | 6 +- libs/hwcodec/src/common.rs | 2 +- libs/hwcodec/src/ffmpeg_ram/encode.rs | 12 +- libs/ventoy-img-rs/src/exfat/format.rs | 33 +- libs/ventoy-img-rs/src/exfat/ops.rs | 260 ++++++--- libs/ventoy-img-rs/src/exfat/unicode.rs | 2 +- libs/ventoy-img-rs/src/image.rs | 12 +- libs/ventoy-img-rs/src/lib.rs | 2 +- libs/ventoy-img-rs/src/main.rs | 63 ++- res/vcpkg/libyuv/build.rs | 32 +- res/vcpkg/libyuv/src/lib.rs | 31 ++ src/atx/led.rs | 5 +- src/atx/mod.rs | 4 +- src/atx/wol.rs | 10 +- src/audio/capture.rs | 66 +-- src/audio/controller.rs | 35 +- src/audio/device.rs | 16 +- src/audio/encoder.rs | 5 +- src/audio/mod.rs | 4 +- src/audio/monitor.rs | 8 +- src/audio/shared_pipeline.rs | 9 +- src/audio/streamer.rs | 7 +- src/auth/middleware.rs | 5 +- src/auth/mod.rs | 4 +- src/auth/user.rs | 17 +- src/config/schema.rs | 17 +- src/config/store.rs | 20 +- src/events/mod.rs | 3 +- src/events/types.rs | 45 +- src/extensions/manager.rs | 26 +- src/hid/ch9329.rs | 91 ++-- src/hid/datachannel.rs | 18 +- src/hid/keymap.rs | 58 +-- src/hid/mod.rs | 106 ++-- src/hid/monitor.rs | 20 +- src/hid/otg.rs | 142 +++-- src/hid/types.rs | 7 +- src/hid/websocket.rs | 21 +- src/main.rs | 139 +++-- src/msd/controller.rs | 58 ++- src/msd/image.rs | 76 +-- src/msd/mod.rs | 4 +- src/msd/monitor.rs | 5 +- src/msd/ventoy_drive.rs | 57 +- src/otg/configfs.rs | 33 +- src/otg/hid.rs | 41 +- src/otg/manager.rs | 44 +- src/otg/msd.rs | 25 +- src/otg/report_desc.rs | 18 +- src/otg/service.rs | 12 +- src/rustdesk/bytes_codec.rs | 25 +- src/rustdesk/config.rs | 48 +- src/rustdesk/connection.rs | 177 +++++-- src/rustdesk/crypto.rs | 39 +- src/rustdesk/frame_adapters.rs | 54 +- src/rustdesk/hid_adapter.rs | 64 ++- src/rustdesk/mod.rs | 247 +++++---- src/rustdesk/protocol.rs | 29 +- src/rustdesk/punch.rs | 11 +- src/rustdesk/rendezvous.rs | 58 ++- src/state.rs | 5 +- src/stream/mjpeg.rs | 71 +-- src/stream/mjpeg_streamer.rs | 37 +- src/stream/mod.rs | 4 +- src/stream/ws_hid.rs | 15 +- src/utils/throttle.rs | 1 + src/video/capture.rs | 55 +- src/video/convert.rs | 179 ++++++- src/video/device.rs | 63 ++- src/video/encoder/h264.rs | 58 ++- src/video/encoder/h265.rs | 57 +- src/video/encoder/jpeg.rs | 17 +- src/video/encoder/mod.rs | 4 +- src/video/encoder/registry.rs | 26 +- src/video/encoder/traits.rs | 10 +- src/video/encoder/vp8.rs | 23 +- src/video/encoder/vp9.rs | 23 +- src/video/format.rs | 21 +- src/video/frame.rs | 6 +- src/video/h264_pipeline.rs | 12 +- src/video/mod.rs | 10 +- src/video/shared_video_pipeline.rs | 433 +++++++++++---- src/video/stream_manager.rs | 245 +++++++-- src/video/streamer.rs | 155 ++++--
src/video/video_session.rs | 8 +- src/web/handlers/config/apply.rs | 42 +- src/web/handlers/config/mod.rs | 20 +- src/web/handlers/config/rustdesk.rs | 8 +- src/web/handlers/config/types.rs | 73 ++- src/web/handlers/extensions.rs | 8 +- src/web/handlers/mod.rs | 338 ++++++++---- src/web/handlers/terminal.rs | 20 +- src/web/mod.rs | 2 +- src/web/routes.rs | 99 +++- src/web/static_files.rs | 38 +- src/webrtc/h265_payloader.rs | 11 +- src/webrtc/mod.rs | 4 +- src/webrtc/peer.rs | 123 ++--- src/webrtc/rtp.rs | 20 +- src/webrtc/track.rs | 11 +- src/webrtc/universal_session.rs | 109 ++-- src/webrtc/webrtc_streamer.rs | 99 ++-- web/src/api/index.ts | 4 +- web/src/components/VideoConfigPopover.vue | 34 ++ web/src/composables/useConsoleEvents.ts | 46 +- web/src/composables/useVideoSession.ts | 185 +++++++ web/src/composables/useVideoStream.ts | 346 +----------- web/src/views/ConsoleView.vue | 493 ++++++++---------- 110 files changed, 3955 insertions(+), 2251 deletions(-) create mode 100644 web/src/composables/useVideoSession.ts diff --git a/libs/hwcodec/build.rs b/libs/hwcodec/build.rs index 7980a7fb..444d6a8c 100644 --- a/libs/hwcodec/build.rs +++ b/libs/hwcodec/build.rs @@ -56,7 +56,10 @@ fn build_common(builder: &mut Build) { // Unsupported platforms if target_os != "windows" && target_os != "linux" { - panic!("Unsupported OS: {}. Only Windows and Linux are supported.", target_os); + panic!( + "Unsupported OS: {}. Only Windows and Linux are supported.", + target_os + ); } // tool @@ -103,7 +106,9 @@ mod ffmpeg { use std::process::Command; // Check if static linking is requested - let use_static = std::env::var("FFMPEG_STATIC").map(|v| v == "1").unwrap_or(false); + let use_static = std::env::var("FFMPEG_STATIC") + .map(|v| v == "1") + .unwrap_or(false); let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default(); // Try custom library path first: @@ -142,7 +147,7 @@ mod ffmpeg { // VAAPI for x86_64 println!("cargo:rustc-link-lib=va"); println!("cargo:rustc-link-lib=va-drm"); - println!("cargo:rustc-link-lib=va-x11"); // Required for vaGetDisplay + println!("cargo:rustc-link-lib=va-x11"); // Required for vaGetDisplay println!("cargo:rustc-link-lib=mfx"); } else { // RKMPP for ARM @@ -172,10 +177,7 @@ mod ffmpeg { for lib in &libs { // Get cflags - if let Ok(output) = Command::new("pkg-config") - .args(["--cflags", lib]) - .output() - { + if let Ok(output) = Command::new("pkg-config").args(["--cflags", lib]).output() { if output.status.success() { let cflags = String::from_utf8_lossy(&output.stdout); for flag in cflags.split_whitespace() { @@ -193,10 +195,7 @@ mod ffmpeg { vec!["--libs", lib] }; - if let Ok(output) = Command::new("pkg-config") - .args(&pkg_config_args) - .output() - { + if let Ok(output) = Command::new("pkg-config").args(&pkg_config_args).output() { if output.status.success() { let libs_str = String::from_utf8_lossy(&output.stdout); for flag in libs_str.split_whitespace() { @@ -221,7 +220,9 @@ mod ffmpeg { panic!("pkg-config failed for {}. Install FFmpeg development libraries: sudo apt install libavcodec-dev libavutil-dev", lib); } } else { - panic!("pkg-config not found. Install pkg-config and FFmpeg development libraries."); + panic!( + "pkg-config not found. Install pkg-config and FFmpeg development libraries." + ); } } @@ -301,7 +302,10 @@ mod ffmpeg { // ARM (aarch64, arm): no X11 needed, uses RKMPP/V4L2 v } else { - panic!("Unsupported OS: {}. Only Windows and Linux are supported.", target_os); + panic!( + "Unsupported OS: {}. 
Only Windows and Linux are supported.", + target_os + ); }; for lib in dyn_libs.iter() { @@ -312,10 +316,9 @@ mod ffmpeg { fn ffmpeg_ffi() { let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); let ffmpeg_ram_dir = manifest_dir.join("cpp").join("common"); - let ffi_header = ffmpeg_ram_dir - .join("ffmpeg_ffi.h") - .to_string_lossy() - .to_string(); + let ffi_header_path = ffmpeg_ram_dir.join("ffmpeg_ffi.h"); + println!("cargo:rerun-if-changed={}", ffi_header_path.display()); + let ffi_header = ffi_header_path.to_string_lossy().to_string(); bindgen::builder() .header(ffi_header) .rustified_enum("*") @@ -340,8 +343,6 @@ mod ffmpeg { .write_to_file(Path::new(&env::var_os("OUT_DIR").unwrap()).join("ffmpeg_ram_ffi.rs")) .unwrap(); - builder.files( - ["ffmpeg_ram_encode.cpp"].map(|f| ffmpeg_ram_dir.join(f)), - ); + builder.files(["ffmpeg_ram_encode.cpp"].map(|f| ffmpeg_ram_dir.join(f))); } } diff --git a/libs/hwcodec/cpp/common/ffmpeg_ffi.h b/libs/hwcodec/cpp/common/ffmpeg_ffi.h index 1845d406..07777772 100644 --- a/libs/hwcodec/cpp/common/ffmpeg_ffi.h +++ b/libs/hwcodec/cpp/common/ffmpeg_ffi.h @@ -14,11 +14,15 @@ enum AVPixelFormat { AV_PIX_FMT_YUV420P = 0, AV_PIX_FMT_YUYV422 = 1, + AV_PIX_FMT_RGB24 = 2, + AV_PIX_FMT_BGR24 = 3, AV_PIX_FMT_YUV422P = 4, // planar YUV 4:2:2 AV_PIX_FMT_YUVJ420P = 12, // JPEG full-range YUV420P (same layout as YUV420P) AV_PIX_FMT_YUVJ422P = 13, // JPEG full-range YUV422P (same layout as YUV422P) AV_PIX_FMT_NV12 = 23, AV_PIX_FMT_NV21 = 24, + AV_PIX_FMT_NV16 = 101, + AV_PIX_FMT_NV24 = 188, }; int av_log_get_level(void); @@ -26,4 +30,4 @@ void av_log_set_level(int level); void hwcodec_set_av_log_callback(); void hwcodec_set_flag_could_not_find_ref_with_poc(); -#endif \ No newline at end of file +#endif diff --git a/libs/hwcodec/cpp/ffmpeg_ram/ffmpeg_ram_encode.cpp b/libs/hwcodec/cpp/ffmpeg_ram/ffmpeg_ram_encode.cpp index f0c36ae3..1fb942aa 100644 --- a/libs/hwcodec/cpp/ffmpeg_ram/ffmpeg_ram_encode.cpp +++ b/libs/hwcodec/cpp/ffmpeg_ram/ffmpeg_ram_encode.cpp @@ -388,7 +388,9 @@ private: } _exit: av_packet_unref(pkt_); - return encoded ? 0 : -1; + // If no packet is produced for this input frame, treat it as EAGAIN. + // This is not a fatal error: encoders may buffer internally (e.g., startup delay). + return encoded ? 
0 : AVERROR(EAGAIN); } int fill_frame(AVFrame *frame, uint8_t *data, int data_length, @@ -511,4 +513,4 @@ extern "C" void ffmpeg_ram_request_keyframe(FFmpegRamEncoder *encoder) { } catch (const std::exception &e) { LOG_ERROR(std::string("ffmpeg_ram_request_keyframe failed, ") + std::string(e.what())); } -} \ No newline at end of file +} diff --git a/libs/hwcodec/src/common.rs b/libs/hwcodec/src/common.rs index 3d6d8b4c..b545d9ef 100644 --- a/libs/hwcodec/src/common.rs +++ b/libs/hwcodec/src/common.rs @@ -84,7 +84,7 @@ pub fn setup_parent_death_signal() { pub fn child_exit_when_parent_exit(child_process_id: u32) -> bool { unsafe { extern "C" { - fn add_process_to_new_job(child_process_id: u32) -> i32; + fn add_process_to_new_job(child_process_id: u32) -> i32; } let result = add_process_to_new_job(child_process_id); result == 0 diff --git a/libs/hwcodec/src/ffmpeg_ram/encode.rs b/libs/hwcodec/src/ffmpeg_ram/encode.rs index a1a7e241..dff0a135 100644 --- a/libs/hwcodec/src/ffmpeg_ram/encode.rs +++ b/libs/hwcodec/src/ffmpeg_ram/encode.rs @@ -3,7 +3,8 @@ use crate::{ ffmpeg::{init_av_log, AVPixelFormat}, ffmpeg_ram::{ ffmpeg_linesize_offset_length, ffmpeg_ram_encode, ffmpeg_ram_free_encoder, - ffmpeg_ram_new_encoder, ffmpeg_ram_request_keyframe, ffmpeg_ram_set_bitrate, CodecInfo, AV_NUM_DATA_POINTERS, + ffmpeg_ram_new_encoder, ffmpeg_ram_request_keyframe, ffmpeg_ram_set_bitrate, CodecInfo, + AV_NUM_DATA_POINTERS, }, }; use log::trace; @@ -123,6 +124,12 @@ impl Encoder { self.frames as *const _ as *const c_void, ms, ); + // ffmpeg_ram_encode returns AVERROR(EAGAIN) when the encoder accepts the frame + // but does not output a packet yet (e.g., startup delay / internal buffering). + // Treat this as a successful call with an empty output list. + if result == -11 { + return Ok(&mut *self.frames); + } if result != 0 { return Err(result); } @@ -358,7 +365,8 @@ impl Encoder { if frames[0].key == 1 && elapsed < TEST_TIMEOUT_MS as _ { debug!( "Encoder {} test passed on attempt {}", - codec.name, attempt + 1 + codec.name, + attempt + 1 ); res.push(codec.clone()); passed = true; diff --git a/libs/ventoy-img-rs/src/exfat/format.rs b/libs/ventoy-img-rs/src/exfat/format.rs index 5f10460b..f0717e0d 100644 --- a/libs/ventoy-img-rs/src/exfat/format.rs +++ b/libs/ventoy-img-rs/src/exfat/format.rs @@ -33,13 +33,13 @@ fn get_cluster_size(total_sectors: u64) -> u32 { /// For example: 32KB cluster = 64 sectors = 2^6, so shift = 6 fn sectors_per_cluster_shift(cluster_size: u32) -> u8 { match cluster_size { - 4096 => 3, // 8 sectors (4KB) - 8192 => 4, // 16 sectors (8KB) - 16384 => 5, // 32 sectors (16KB) - 32768 => 6, // 64 sectors (32KB) - 65536 => 7, // 128 sectors (64KB) - 131072 => 8, // 256 sectors (128KB) - 262144 => 9, // 512 sectors (256KB) + 4096 => 3, // 8 sectors (4KB) + 8192 => 4, // 16 sectors (8KB) + 16384 => 5, // 32 sectors (16KB) + 32768 => 6, // 64 sectors (32KB) + 65536 => 7, // 128 sectors (64KB) + 131072 => 8, // 256 sectors (128KB) + 262144 => 9, // 512 sectors (256KB) _ => { // Fallback: calculate dynamically let sectors = cluster_size / 512; @@ -75,11 +75,7 @@ struct ExfatBootSector { } impl ExfatBootSector { - fn new( - volume_length: u64, - cluster_size: u32, - volume_serial: u32, - ) -> Self { + fn new(volume_length: u64, cluster_size: u32, volume_serial: u32) -> Self { let sector_size: u32 = 512; let sectors_per_cluster = cluster_size / sector_size; let spc_shift = sectors_per_cluster_shift(cluster_size); @@ -106,7 +102,8 @@ impl ExfatBootSector { // Cluster 3...: Upcase table (128KB, may 
span multiple clusters) // Next available: Root directory const UPCASE_TABLE_SIZE: u64 = 128 * 1024; - let upcase_clusters = ((UPCASE_TABLE_SIZE + cluster_size as u64 - 1) / cluster_size as u64) as u32; + let upcase_clusters = + ((UPCASE_TABLE_SIZE + cluster_size as u64 - 1) / cluster_size as u64) as u32; let first_cluster_of_root = 3 + upcase_clusters; Self { @@ -235,7 +232,7 @@ fn create_bitmap_entry(start_cluster: u32, size: u64) -> [u8; 32] { let mut entry = [0u8; 32]; entry[0] = ENTRY_TYPE_BITMAP; entry[1] = 0; // BitmapFlags - // Reserved: bytes 2-19 + // Reserved: bytes 2-19 entry[20..24].copy_from_slice(&start_cluster.to_le_bytes()); entry[24..32].copy_from_slice(&size.to_le_bytes()); entry @@ -306,7 +303,8 @@ pub fn format_exfat( // Calculate how many clusters the upcase table needs (128KB) const UPCASE_TABLE_SIZE: u64 = 128 * 1024; - let upcase_clusters = ((UPCASE_TABLE_SIZE + cluster_size as u64 - 1) / cluster_size as u64) as u32; + let upcase_clusters = + ((UPCASE_TABLE_SIZE + cluster_size as u64 - 1) / cluster_size as u64) as u32; let root_cluster = 3 + upcase_clusters; // Root comes after bitmap and upcase // FAT entries: cluster 0 and 1 are reserved @@ -349,13 +347,14 @@ pub fn format_exfat( // Cluster 2: Allocation Bitmap let bitmap_size = (boot_sector.cluster_count + 7) / 8; - let _bitmap_clusters = ((bitmap_size as u64 + cluster_size as u64 - 1) / cluster_size as u64).max(1); + let _bitmap_clusters = + ((bitmap_size as u64 + cluster_size as u64 - 1) / cluster_size as u64).max(1); let mut bitmap = vec![0u8; cluster_size as usize]; // Mark clusters 2, 3..3+upcase_clusters-1, root_cluster as used // Cluster 2: bitmap bitmap[0] |= 0b00000100; // Bit 2 - // Clusters 3..3+upcase_clusters-1: upcase table + // Clusters 3..3+upcase_clusters-1: upcase table for i in 0..upcase_clusters { let cluster = 3 + i; let byte_idx = (cluster / 8) as usize; diff --git a/libs/ventoy-img-rs/src/exfat/ops.rs b/libs/ventoy-img-rs/src/exfat/ops.rs index 7a25fe16..b05ba956 100644 --- a/libs/ventoy-img-rs/src/exfat/ops.rs +++ b/libs/ventoy-img-rs/src/exfat/ops.rs @@ -53,8 +53,7 @@ impl FatCache { if self.entries.is_empty() { return false; } - cluster >= self.start_cluster - && cluster < self.start_cluster + self.entries.len() as u32 + cluster >= self.start_cluster && cluster < self.start_cluster + self.entries.len() as u32 } /// Get a FAT entry from cache (if present) @@ -243,7 +242,9 @@ impl ExfatFs { /// Get the byte offset of a FAT entry fn fat_entry_offset(&self, cluster: u32) -> u64 { - self.partition_offset + self.fat_offset as u64 * self.bytes_per_sector as u64 + cluster as u64 * 4 + self.partition_offset + + self.fat_offset as u64 * self.bytes_per_sector as u64 + + cluster as u64 * 4 } /// Load a FAT segment into cache starting from the given cluster @@ -287,7 +288,10 @@ impl ExfatFs { // Should be in cache now self.fat_cache.get(cluster).ok_or_else(|| { - VentoyError::FilesystemError(format!("Failed to cache FAT entry for cluster {}", cluster)) + VentoyError::FilesystemError(format!( + "Failed to cache FAT entry for cluster {}", + cluster + )) }) } @@ -490,9 +494,9 @@ impl ExfatFs { fn extend_cluster_chain(&mut self, first_cluster: u32) -> Result { // Find the last cluster in the chain let chain = self.read_cluster_chain(first_cluster)?; - let last_cluster = *chain.last().ok_or_else(|| { - VentoyError::FilesystemError("Empty cluster chain".to_string()) - })?; + let last_cluster = *chain + .last() + .ok_or_else(|| VentoyError::FilesystemError("Empty cluster chain".to_string()))?; // Allocate 
one new cluster let new_cluster = self.allocate_clusters(1)?; @@ -532,7 +536,12 @@ impl ExfatFs { } /// Create file directory entries for a new file - fn create_file_entries(name: &str, first_cluster: u32, size: u64, is_dir: bool) -> Vec<[u8; 32]> { + fn create_file_entries( + name: &str, + first_cluster: u32, + size: u64, + is_dir: bool, + ) -> Vec<[u8; 32]> { let name_utf16: Vec = name.encode_utf16().collect(); let name_entries_needed = (name_utf16.len() + 14) / 15; // 15 chars per name entry let secondary_count = 1 + name_entries_needed; // Stream + Name entries @@ -552,18 +561,22 @@ impl ExfatFs { .map(|d| d.as_secs() as u32) .unwrap_or(0); // DOS timestamp format (simplified) - let dos_time = ((now / 2) & 0x1F) | (((now / 60) & 0x3F) << 5) | (((now / 3600) & 0x1F) << 11); + let dos_time = + ((now / 2) & 0x1F) | (((now / 60) & 0x3F) << 5) | (((now / 3600) & 0x1F) << 11); let dos_date = 1 | (1 << 5) | ((45) << 9); // Jan 1, 2025 - file_entry[8..12].copy_from_slice(&(dos_date as u32 | ((dos_time as u32) << 16)).to_le_bytes()); - file_entry[12..16].copy_from_slice(&(dos_date as u32 | ((dos_time as u32) << 16)).to_le_bytes()); - file_entry[16..20].copy_from_slice(&(dos_date as u32 | ((dos_time as u32) << 16)).to_le_bytes()); + file_entry[8..12] + .copy_from_slice(&(dos_date as u32 | ((dos_time as u32) << 16)).to_le_bytes()); + file_entry[12..16] + .copy_from_slice(&(dos_date as u32 | ((dos_time as u32) << 16)).to_le_bytes()); + file_entry[16..20] + .copy_from_slice(&(dos_date as u32 | ((dos_time as u32) << 16)).to_le_bytes()); entries.push(file_entry); // 2. Stream Extension Entry (0xC0) let mut stream_entry = [0u8; 32]; stream_entry[0] = ENTRY_TYPE_STREAM; stream_entry[1] = 0x03; // GeneralSecondaryFlags: AllocationPossible | NoFatChain (for contiguous) - // For non-contiguous files, use 0x01 + // For non-contiguous files, use 0x01 if size > 0 { stream_entry[1] = 0x01; // AllocationPossible, use FAT chain } @@ -588,7 +601,8 @@ impl ExfatFs { for i in 0..15 { if char_index < name_utf16.len() { let offset = 2 + i * 2; - name_entry[offset..offset + 2].copy_from_slice(&name_utf16[char_index].to_le_bytes()); + name_entry[offset..offset + 2] + .copy_from_slice(&name_utf16[char_index].to_le_bytes()); char_index += 1; } } @@ -603,7 +617,11 @@ impl ExfatFs { } /// Find a file entry in a specific directory cluster - fn find_entry_in_directory(&mut self, dir_cluster: u32, name: &str) -> Result> { + fn find_entry_in_directory( + &mut self, + dir_cluster: u32, + name: &str, + ) -> Result> { let target_name_lower = name.to_lowercase(); // Read all clusters in the directory chain @@ -747,7 +765,11 @@ impl ExfatFs { /// /// If no free slot is found in existing clusters, this method will /// automatically extend the directory by allocating a new cluster. 
- fn find_free_slot_in_directory(&mut self, dir_cluster: u32, entries_needed: usize) -> Result<(u32, u32)> { + fn find_free_slot_in_directory( + &mut self, + dir_cluster: u32, + entries_needed: usize, + ) -> Result<(u32, u32)> { let dir_clusters = self.read_cluster_chain(dir_cluster)?; for &cluster in &dir_clusters { @@ -759,10 +781,12 @@ impl ExfatFs { while i < cluster_data.len() { let entry_type = cluster_data[i]; - if entry_type == ENTRY_TYPE_END || entry_type == 0x00 + if entry_type == ENTRY_TYPE_END + || entry_type == 0x00 || entry_type == ENTRY_TYPE_DELETED_FILE || entry_type == ENTRY_TYPE_DELETED_STREAM - || entry_type == ENTRY_TYPE_DELETED_NAME { + || entry_type == ENTRY_TYPE_DELETED_NAME + { if consecutive_free == 0 { slot_start = i; } @@ -795,7 +819,8 @@ impl ExfatFs { // This is critical: when we extend a directory, we need to clear any END markers // that may exist in previous clusters, otherwise list_files will stop prematurely let dir_clusters_before = self.read_cluster_chain(dir_cluster)?; - for &cluster in &dir_clusters_before[..dir_clusters_before.len()-1] { // Exclude the newly added cluster + for &cluster in &dir_clusters_before[..dir_clusters_before.len() - 1] { + // Exclude the newly added cluster let mut cluster_data = self.read_cluster(cluster)?; // Scan for END markers and replace them with 0xFF (invalid entry, will be skipped) @@ -815,14 +840,23 @@ impl ExfatFs { /// Find a free slot in the root directory for new entries (backward compatible) #[allow(dead_code)] fn find_free_directory_slot(&mut self, entries_needed: usize) -> Result { - let (_, offset) = self.find_free_slot_in_directory(self.first_cluster_of_root, entries_needed)?; + let (_, offset) = + self.find_free_slot_in_directory(self.first_cluster_of_root, entries_needed)?; Ok(offset) } /// Create an entry in a specific directory - fn create_entry_in_directory(&mut self, dir_cluster: u32, name: &str, first_cluster: u32, size: u64, is_dir: bool) -> Result<()> { + fn create_entry_in_directory( + &mut self, + dir_cluster: u32, + name: &str, + first_cluster: u32, + size: u64, + is_dir: bool, + ) -> Result<()> { let entries = Self::create_file_entries(name, first_cluster, size, is_dir); - let (slot_cluster, slot_offset) = self.find_free_slot_in_directory(dir_cluster, entries.len())?; + let (slot_cluster, slot_offset) = + self.find_free_slot_in_directory(dir_cluster, entries.len())?; let mut cluster_data = self.read_cluster(slot_cluster)?; @@ -854,7 +888,10 @@ impl ExfatFs { } // Check if already exists - if self.find_entry_in_directory(parent_cluster, name)?.is_some() { + if self + .find_entry_in_directory(parent_cluster, name)? 
+ .is_some() + { return Err(VentoyError::FilesystemError(format!( "Entry '{}' already exists", name @@ -903,7 +940,11 @@ impl ExfatFs { // ==================== Public File Operations ==================== /// List files in a specific directory cluster - fn list_files_in_directory(&mut self, dir_cluster: u32, current_path: &str) -> Result> { + fn list_files_in_directory( + &mut self, + dir_cluster: u32, + current_path: &str, + ) -> Result> { let dir_clusters = self.read_cluster_chain(dir_cluster)?; // Pre-allocate Vec based on estimated entries @@ -1038,7 +1079,12 @@ impl ExfatFs { } /// Write file data to allocated clusters and create directory entry - fn write_file_data_and_entry(&mut self, dir_cluster: u32, name: &str, data: &[u8]) -> Result<()> { + fn write_file_data_and_entry( + &mut self, + dir_cluster: u32, + name: &str, + data: &[u8], + ) -> Result<()> { // Calculate clusters needed let clusters_needed = if data.is_empty() { 0 @@ -1121,7 +1167,13 @@ impl ExfatFs { /// Path can include directories, e.g., "iso/linux/ubuntu.iso" /// If create_parents is true, intermediate directories will be created. /// If overwrite is true, existing files will be replaced. - pub fn write_file_path(&mut self, path: &str, data: &[u8], create_parents: bool, overwrite: bool) -> Result<()> { + pub fn write_file_path( + &mut self, + path: &str, + data: &[u8], + create_parents: bool, + overwrite: bool, + ) -> Result<()> { let resolved = self.resolve_path(path, create_parents)?; // Validate filename @@ -1177,9 +1229,9 @@ impl ExfatFs { /// Read a file from the filesystem (root directory) pub fn read_file(&mut self, name: &str) -> Result> { - let location = self.find_file_entry(name)?.ok_or_else(|| { - VentoyError::FilesystemError(format!("File '{}' not found", name)) - })?; + let location = self + .find_file_entry(name)? + .ok_or_else(|| VentoyError::FilesystemError(format!("File '{}' not found", name)))?; self.read_file_from_location(&location) } @@ -1224,9 +1276,9 @@ impl ExfatFs { /// Delete a file from the filesystem (root directory) pub fn delete_file(&mut self, name: &str) -> Result<()> { - let location = self.find_file_entry(name)?.ok_or_else(|| { - VentoyError::FilesystemError(format!("File '{}' not found", name)) - })?; + let location = self + .find_file_entry(name)? 
+ .ok_or_else(|| VentoyError::FilesystemError(format!("File '{}' not found", name)))?; // Free cluster chain if location.first_cluster >= 2 { @@ -1244,9 +1296,9 @@ impl ExfatFs { pub fn delete_path(&mut self, path: &str) -> Result<()> { let resolved = self.resolve_path(path, false)?; - let location = resolved.location.ok_or_else(|| { - VentoyError::FilesystemError(format!("'{}' not found", path)) - })?; + let location = resolved + .location + .ok_or_else(|| VentoyError::FilesystemError(format!("'{}' not found", path)))?; // If it's a directory, check if it's empty if location.is_directory { @@ -1275,9 +1327,9 @@ impl ExfatFs { pub fn delete_recursive(&mut self, path: &str) -> Result<()> { let resolved = self.resolve_path(path, false)?; - let location = resolved.location.ok_or_else(|| { - VentoyError::FilesystemError(format!("'{}' not found", path)) - })?; + let location = resolved + .location + .ok_or_else(|| VentoyError::FilesystemError(format!("'{}' not found", path)))?; if location.is_directory { // Get all contents and delete them first @@ -1344,7 +1396,12 @@ impl<'a> ExfatFileWriter<'a> { } /// Create a new file writer with overwrite option - pub fn create_overwrite(fs: &'a mut ExfatFs, name: &str, total_size: u64, overwrite: bool) -> Result { + pub fn create_overwrite( + fs: &'a mut ExfatFs, + name: &str, + total_size: u64, + overwrite: bool, + ) -> Result { let root_cluster = fs.first_cluster_of_root; Self::create_in_directory(fs, root_cluster, name, total_size, overwrite) } @@ -1353,7 +1410,13 @@ impl<'a> ExfatFileWriter<'a> { /// /// If create_parents is true, intermediate directories will be created. /// If overwrite is true, existing files will be replaced. - pub fn create_at_path(fs: &'a mut ExfatFs, path: &str, total_size: u64, create_parents: bool, overwrite: bool) -> Result { + pub fn create_at_path( + fs: &'a mut ExfatFs, + path: &str, + total_size: u64, + create_parents: bool, + overwrite: bool, + ) -> Result { let resolved = fs.resolve_path(path, create_parents)?; // Handle existing file @@ -1378,11 +1441,23 @@ impl<'a> ExfatFileWriter<'a> { } } - Self::create_in_directory(fs, resolved.parent_cluster, &resolved.name, total_size, false) + Self::create_in_directory( + fs, + resolved.parent_cluster, + &resolved.name, + total_size, + false, + ) } /// Internal: Create a file writer in a specific directory - fn create_in_directory(fs: &'a mut ExfatFs, dir_cluster: u32, name: &str, total_size: u64, overwrite: bool) -> Result { + fn create_in_directory( + fs: &'a mut ExfatFs, + dir_cluster: u32, + name: &str, + total_size: u64, + overwrite: bool, + ) -> Result { // Validate filename if name.is_empty() || name.len() > 255 { return Err(VentoyError::FilesystemError( @@ -1473,7 +1548,9 @@ impl<'a> ExfatFileWriter<'a> { /// This must be called after all data has been written. 
pub fn finish(self) -> Result<()> { // Write any remaining data in buffer - if !self.cluster_buffer.is_empty() && self.current_cluster_index < self.allocated_clusters.len() { + if !self.cluster_buffer.is_empty() + && self.current_cluster_index < self.allocated_clusters.len() + { let cluster = self.allocated_clusters[self.current_cluster_index]; self.fs.write_cluster(cluster, &self.cluster_buffer)?; } @@ -1485,7 +1562,13 @@ impl<'a> ExfatFileWriter<'a> { self.allocated_clusters[0] }; - self.fs.create_entry_in_directory(self.dir_cluster, &self.name, first_cluster, self.total_size, false)?; + self.fs.create_entry_in_directory( + self.dir_cluster, + &self.name, + first_cluster, + self.total_size, + false, + )?; self.fs.file.flush()?; Ok(()) @@ -1517,9 +1600,9 @@ pub struct ExfatFileReader<'a> { impl<'a> ExfatFileReader<'a> { /// Open a file for reading from root directory pub fn open(fs: &'a mut ExfatFs, name: &str) -> Result { - let location = fs.find_file_entry(name)?.ok_or_else(|| { - VentoyError::FilesystemError(format!("File '{}' not found", name)) - })?; + let location = fs + .find_file_entry(name)? + .ok_or_else(|| VentoyError::FilesystemError(format!("File '{}' not found", name)))?; if location.is_directory { return Err(VentoyError::FilesystemError(format!( @@ -1635,7 +1718,8 @@ impl<'a> Read for ExfatFileReader<'a> { // Read the cluster and copy data { - let cluster_data = self.read_cluster_cached(cluster_index) + let cluster_data = self + .read_cluster_cached(cluster_index) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; // Copy data to buffer @@ -1676,11 +1760,7 @@ impl ExfatFs { /// Read a file to a writer (streaming) /// /// This is useful for reading large files without loading them into memory. - pub fn read_file_to_writer( - &mut self, - name: &str, - writer: &mut W, - ) -> Result { + pub fn read_file_to_writer(&mut self, name: &str, writer: &mut W) -> Result { let mut reader = ExfatFileReader::open(self, name)?; Self::do_stream_read(&mut reader, writer) } @@ -1701,13 +1781,13 @@ impl ExfatFs { let mut total_bytes = 0u64; loop { - let bytes_read = reader.read(&mut buffer).map_err(|e| { - VentoyError::Io(e) - })?; + let bytes_read = reader.read(&mut buffer).map_err(|e| VentoyError::Io(e))?; if bytes_read == 0 { break; } - writer.write_all(&buffer[..bytes_read]).map_err(VentoyError::Io)?; + writer + .write_all(&buffer[..bytes_read]) + .map_err(VentoyError::Io)?; total_bytes += bytes_read as u64; } @@ -1755,7 +1835,8 @@ impl ExfatFs { create_parents: bool, overwrite: bool, ) -> Result<()> { - let mut writer = ExfatFileWriter::create_at_path(self, path, size, create_parents, overwrite)?; + let mut writer = + ExfatFileWriter::create_at_path(self, path, size, create_parents, overwrite)?; Self::do_stream_write(&mut writer, reader)?; writer.finish() } @@ -1804,8 +1885,13 @@ mod tests { file.set_len(size).unwrap(); // Format data partition (this will use 4KB clusters for 64MB volume) - crate::exfat::format::format_exfat(&mut file, layout.data_offset(), layout.data_size(), "TEST") - .unwrap(); + crate::exfat::format::format_exfat( + &mut file, + layout.data_offset(), + layout.data_size(), + "TEST", + ) + .unwrap(); drop(file); @@ -1826,21 +1912,12 @@ mod tests { let data = format!("content {}", i); let mut cursor = Cursor::new(data.as_bytes()); - fs.write_file_from_reader( - &filename, - &mut cursor, - data.len() as u64, - )?; + fs.write_file_from_reader(&filename, &mut cursor, data.len() as u64)?; } // Verify all files were created let files = 
fs.list_files().unwrap(); - assert_eq!( - files.len(), - 50, - "Expected 50 files, found {}", - files.len() - ); + assert_eq!(files.len(), 50, "Expected 50 files, found {}", files.len()); // Verify we can read all files back for i in 0..50 { @@ -1882,8 +1959,13 @@ mod tests { file.set_len(size).unwrap(); // Format data partition - crate::exfat::format::format_exfat(&mut file, layout.data_offset(), layout.data_size(), "TEST") - .unwrap(); + crate::exfat::format::format_exfat( + &mut file, + layout.data_offset(), + layout.data_size(), + "TEST", + ) + .unwrap(); drop(file); @@ -1903,7 +1985,8 @@ mod tests { assert_eq!(reader.position(), 0); let mut read_data = Vec::new(); - let bytes_read = reader.read_to_end(&mut read_data) + let bytes_read = reader + .read_to_end(&mut read_data) .map_err(|e| VentoyError::Io(e))?; assert_eq!(bytes_read, test_data.len()); @@ -1932,22 +2015,32 @@ mod tests { let mut reader = ExfatFileReader::open(&mut fs, "large_file.bin")?; // Seek to middle - reader.seek(SeekFrom::Start(10000)).map_err(|e| VentoyError::Io(e))?; + reader + .seek(SeekFrom::Start(10000)) + .map_err(|e| VentoyError::Io(e))?; assert_eq!(reader.position(), 10000); let mut buffer = [0u8; 10]; - reader.read_exact(&mut buffer).map_err(|e| VentoyError::Io(e))?; + reader + .read_exact(&mut buffer) + .map_err(|e| VentoyError::Io(e))?; assert_eq!(&buffer, &test_data[10000..10010]); // Seek from current position - reader.seek(SeekFrom::Current(-5)).map_err(|e| VentoyError::Io(e))?; + reader + .seek(SeekFrom::Current(-5)) + .map_err(|e| VentoyError::Io(e))?; assert_eq!(reader.position(), 10005); // Seek from end - reader.seek(SeekFrom::End(-100)).map_err(|e| VentoyError::Io(e))?; + reader + .seek(SeekFrom::End(-100)) + .map_err(|e| VentoyError::Io(e))?; assert_eq!(reader.position(), test_data.len() as u64 - 100); - reader.read_exact(&mut buffer).map_err(|e| VentoyError::Io(e))?; + reader + .read_exact(&mut buffer) + .map_err(|e| VentoyError::Io(e))?; let expected_start = test_data.len() - 100; assert_eq!(&buffer, &test_data[expected_start..expected_start + 10]); } @@ -1981,8 +2074,13 @@ mod tests { file.set_len(size).unwrap(); // Format data partition - crate::exfat::format::format_exfat(&mut file, layout.data_offset(), layout.data_size(), "TEST") - .unwrap(); + crate::exfat::format::format_exfat( + &mut file, + layout.data_offset(), + layout.data_size(), + "TEST", + ) + .unwrap(); drop(file); diff --git a/libs/ventoy-img-rs/src/exfat/unicode.rs b/libs/ventoy-img-rs/src/exfat/unicode.rs index d29b6d7f..2001eafa 100644 --- a/libs/ventoy-img-rs/src/exfat/unicode.rs +++ b/libs/ventoy-img-rs/src/exfat/unicode.rs @@ -93,7 +93,7 @@ pub fn to_uppercase_simple(ch: u16) -> u16 { // Greek lowercase (α-ω and variants) 0x03B1..=0x03C1 => ch - 32, // α-ρ -> Α-Ρ 0x03C3..=0x03C9 => ch - 32, // σ-ω -> Σ-Ω - 0x03C2 => 0x03A3, // ς (final sigma) -> Σ + 0x03C2 => 0x03A3, // ς (final sigma) -> Σ // Cyrillic lowercase (а-я) 0x0430..=0x044F => ch - 32, // а-я -> А-Я diff --git a/libs/ventoy-img-rs/src/image.rs b/libs/ventoy-img-rs/src/image.rs index 9f157be7..a74ef32e 100644 --- a/libs/ventoy-img-rs/src/image.rs +++ b/libs/ventoy-img-rs/src/image.rs @@ -22,7 +22,11 @@ impl VentoyImage { let size = parse_size(size_str)?; let layout = PartitionLayout::calculate(size)?; - println!("[INFO] Creating {}MB image: {}", size / (1024 * 1024), path.display()); + println!( + "[INFO] Creating {}MB image: {}", + size / (1024 * 1024), + path.display() + ); // Create sparse file let mut file = File::create(path)?; @@ -247,7 +251,11 @@ impl 
VentoyImage { /// /// This is the preferred method for large files as it doesn't load /// the entire file into memory. - pub fn read_file_to_writer(&self, path: &str, writer: &mut W) -> Result { + pub fn read_file_to_writer( + &self, + path: &str, + writer: &mut W, + ) -> Result { let mut fs = ExfatFs::open(&self.path, &self.layout)?; fs.read_file_path_to_writer(path, writer) } diff --git a/libs/ventoy-img-rs/src/lib.rs b/libs/ventoy-img-rs/src/lib.rs index 5799db47..cc3636d7 100644 --- a/libs/ventoy-img-rs/src/lib.rs +++ b/libs/ventoy-img-rs/src/lib.rs @@ -45,4 +45,4 @@ pub use error::{Result, VentoyError}; pub use exfat::FileInfo; pub use image::VentoyImage; pub use partition::{parse_size, PartitionLayout}; -pub use resources::{init_resources, get_resource_dir, is_initialized, required_files}; +pub use resources::{get_resource_dir, init_resources, is_initialized, required_files}; diff --git a/libs/ventoy-img-rs/src/main.rs b/libs/ventoy-img-rs/src/main.rs index 842912b9..ab290690 100644 --- a/libs/ventoy-img-rs/src/main.rs +++ b/libs/ventoy-img-rs/src/main.rs @@ -4,7 +4,7 @@ use clap::{Parser, Subcommand}; use std::path::PathBuf; use std::process::ExitCode; -use ventoy_img::{VentoyImage, Result, VentoyError}; +use ventoy_img::{Result, VentoyError, VentoyImage}; #[derive(Parser)] #[command(name = "ventoy-img")] @@ -103,11 +103,33 @@ fn main() -> ExitCode { let cli = Cli::parse(); let result = match cli.command { - Commands::Create { size, output, label } => cmd_create(&output, &size, &label), - Commands::Add { image, file, dest, force, parents } => cmd_add(&image, &file, dest.as_deref(), force, parents), - Commands::List { image, path, recursive } => cmd_list(&image, path.as_deref(), recursive), - Commands::Remove { image, path, recursive } => cmd_remove(&image, &path, recursive), - Commands::Mkdir { image, path, parents } => cmd_mkdir(&image, &path, parents), + Commands::Create { + size, + output, + label, + } => cmd_create(&output, &size, &label), + Commands::Add { + image, + file, + dest, + force, + parents, + } => cmd_add(&image, &file, dest.as_deref(), force, parents), + Commands::List { + image, + path, + recursive, + } => cmd_list(&image, path.as_deref(), recursive), + Commands::Remove { + image, + path, + recursive, + } => cmd_remove(&image, &path, recursive), + Commands::Mkdir { + image, + path, + parents, + } => cmd_mkdir(&image, &path, parents), Commands::Info { image } => cmd_info(&image), }; @@ -138,7 +160,13 @@ fn cmd_create(output: &PathBuf, size: &str, label: &str) -> Result<()> { Ok(()) } -fn cmd_add(image: &PathBuf, file: &PathBuf, dest: Option<&str>, force: bool, parents: bool) -> Result<()> { +fn cmd_add( + image: &PathBuf, + file: &PathBuf, + dest: Option<&str>, + force: bool, + parents: bool, +) -> Result<()> { if !file.exists() { return Err(VentoyError::FileNotFound(file.display().to_string())); } @@ -234,18 +262,23 @@ fn cmd_info(image: &PathBuf) -> Result<()> { println!(); println!("Partition Layout:"); println!(" Data partition:"); - println!(" Start: sector {} (offset {})", + println!( + " Start: sector {} (offset {})", layout.data_start_sector, - format_size(layout.data_offset())); - println!(" Size: {} sectors ({})", + format_size(layout.data_offset()) + ); + println!( + " Size: {} sectors ({})", layout.data_size_sectors, - format_size(layout.data_size())); + format_size(layout.data_size()) + ); println!(" EFI partition:"); - println!(" Start: sector {} (offset {})", + println!( + " Start: sector {} (offset {})", layout.efi_start_sector, - 
format_size(layout.efi_offset())); - println!(" Size: {} sectors (32 MB)", - layout.efi_size_sectors); + format_size(layout.efi_offset()) + ); + println!(" Size: {} sectors (32 MB)", layout.efi_size_sectors); Ok(()) } diff --git a/res/vcpkg/libyuv/build.rs b/res/vcpkg/libyuv/build.rs index d754996a..f03715e5 100644 --- a/res/vcpkg/libyuv/build.rs +++ b/res/vcpkg/libyuv/build.rs @@ -225,15 +225,19 @@ fn link_system() -> bool { } // Then standard paths - lib_paths.extend([ - "/usr/local/lib", // Custom builds - "/usr/local/lib64", - "/usr/lib", - "/usr/lib64", - "/usr/lib/x86_64-linux-gnu", // Debian/Ubuntu x86_64 - "/usr/lib/aarch64-linux-gnu", // Debian/Ubuntu ARM64 - "/usr/lib/arm-linux-gnueabihf", // Debian/Ubuntu ARMv7 - ].iter().map(|s| s.to_string())); + lib_paths.extend( + [ + "/usr/local/lib", // Custom builds + "/usr/local/lib64", + "/usr/lib", + "/usr/lib64", + "/usr/lib/x86_64-linux-gnu", // Debian/Ubuntu x86_64 + "/usr/lib/aarch64-linux-gnu", // Debian/Ubuntu ARM64 + "/usr/lib/arm-linux-gnueabihf", // Debian/Ubuntu ARMv7 + ] + .iter() + .map(|s| s.to_string()), + ); for path in &lib_paths { let lib_path = Path::new(path); @@ -245,7 +249,10 @@ fn link_system() -> bool { println!("cargo:rustc-link-search=native={}", path); println!("cargo:rustc-link-lib=static=yuv"); println!("cargo:rustc-link-lib=stdc++"); - println!("cargo:info=Using system libyuv from {} (static linking)", path); + println!( + "cargo:info=Using system libyuv from {} (static linking)", + path + ); return true; } @@ -257,7 +264,10 @@ fn link_system() -> bool { #[cfg(target_os = "linux")] println!("cargo:rustc-link-lib=stdc++"); - println!("cargo:info=Using system libyuv from {} (dynamic linking)", path); + println!( + "cargo:info=Using system libyuv from {} (dynamic linking)", + path + ); return true; } } diff --git a/res/vcpkg/libyuv/src/lib.rs b/res/vcpkg/libyuv/src/lib.rs index 45e1fc66..c86f8ad7 100644 --- a/res/vcpkg/libyuv/src/lib.rs +++ b/res/vcpkg/libyuv/src/lib.rs @@ -404,6 +404,37 @@ pub fn nv12_to_i420(src: &[u8], dst: &mut [u8], width: i32, height: i32) -> Resu )) } +/// Convert NV21 to I420 (YUV420P) +pub fn nv21_to_i420(src: &[u8], dst: &mut [u8], width: i32, height: i32) -> Result<()> { + if width % 2 != 0 || height % 2 != 0 { + return Err(YuvError::InvalidDimensions); + } + + let w = width as usize; + let h = height as usize; + let y_size = w * h; + let uv_size = (w / 2) * (h / 2); + + if src.len() < nv12_size(w, h) || dst.len() < i420_size(w, h) { + return Err(YuvError::BufferTooSmall); + } + + call_yuv!(NV21ToI420( + src.as_ptr(), + width, + src[y_size..].as_ptr(), + width, + dst.as_mut_ptr(), + width, + dst[y_size..].as_mut_ptr(), + width / 2, + dst[y_size + uv_size..].as_mut_ptr(), + width / 2, + width, + height, + )) +} + // ============================================================================ // ARGB/BGRA conversions (32-bit) // Note: libyuv ARGB = BGRA in memory on little-endian systems diff --git a/src/atx/led.rs b/src/atx/led.rs index 93765e27..6f970479 100644 --- a/src/atx/led.rs +++ b/src/atx/led.rs @@ -55,7 +55,10 @@ impl LedSensor { .map_err(|e| AppError::Internal(format!("LED GPIO chip failed: {}", e)))?; let line = chip.get_line(self.config.gpio_pin).map_err(|e| { - AppError::Internal(format!("LED GPIO line {} failed: {}", self.config.gpio_pin, e)) + AppError::Internal(format!( + "LED GPIO line {} failed: {}", + self.config.gpio_pin, e + )) })?; let handle = line diff --git a/src/atx/mod.rs b/src/atx/mod.rs index 84c35bad..dd7c90b4 100644 --- a/src/atx/mod.rs +++ 
b/src/atx/mod.rs @@ -52,8 +52,8 @@ mod wol; pub use controller::{AtxController, AtxControllerConfig}; pub use executor::timing; pub use types::{ - ActiveLevel, AtxAction, AtxDevices, AtxDriverType, AtxKeyConfig, AtxLedConfig, - AtxPowerRequest, AtxState, PowerStatus, + ActiveLevel, AtxAction, AtxDevices, AtxDriverType, AtxKeyConfig, AtxLedConfig, AtxPowerRequest, + AtxState, PowerStatus, }; pub use wol::send_wol; diff --git a/src/atx/wol.rs b/src/atx/wol.rs index 412d7226..958e3a97 100644 --- a/src/atx/wol.rs +++ b/src/atx/wol.rs @@ -22,7 +22,10 @@ fn parse_mac_address(mac: &str) -> Result<[u8; 6]> { } else if mac.contains('-') { mac.split('-').collect() } else { - return Err(AppError::Config(format!("Invalid MAC address format: {}", mac))); + return Err(AppError::Config(format!( + "Invalid MAC address format: {}", + mac + ))); }; if parts.len() != 6 { @@ -34,9 +37,8 @@ fn parse_mac_address(mac: &str) -> Result<[u8; 6]> { let mut bytes = [0u8; 6]; for (i, part) in parts.iter().enumerate() { - bytes[i] = u8::from_str_radix(part, 16).map_err(|_| { - AppError::Config(format!("Invalid MAC address byte: {}", part)) - })?; + bytes[i] = u8::from_str_radix(part, 16) + .map_err(|_| AppError::Config(format!("Invalid MAC address byte: {}", part)))?; } Ok(bytes) diff --git a/src/audio/capture.rs b/src/audio/capture.rs index 4252b9cf..229e1650 100644 --- a/src/audio/capture.rs +++ b/src/audio/capture.rs @@ -201,7 +201,15 @@ impl AudioCapturer { let log_throttler = self.log_throttler.clone(); let handle = tokio::task::spawn_blocking(move || { - capture_loop(config, state, stats, frame_tx, stop_flag, sequence, log_throttler); + capture_loop( + config, + state, + stats, + frame_tx, + stop_flag, + sequence, + log_throttler, + ); }); *self.capture_handle.lock().await = Some(handle); @@ -274,40 +282,34 @@ fn run_capture( // Configure hardware parameters { - let hwp = HwParams::any(&pcm).map_err(|e| { - AppError::AudioError(format!("Failed to get HwParams: {}", e)) - })?; + let hwp = HwParams::any(&pcm) + .map_err(|e| AppError::AudioError(format!("Failed to get HwParams: {}", e)))?; - hwp.set_channels(config.channels).map_err(|e| { - AppError::AudioError(format!("Failed to set channels: {}", e)) - })?; + hwp.set_channels(config.channels) + .map_err(|e| AppError::AudioError(format!("Failed to set channels: {}", e)))?; - hwp.set_rate(config.sample_rate, ValueOr::Nearest).map_err(|e| { - AppError::AudioError(format!("Failed to set sample rate: {}", e)) - })?; + hwp.set_rate(config.sample_rate, ValueOr::Nearest) + .map_err(|e| AppError::AudioError(format!("Failed to set sample rate: {}", e)))?; - hwp.set_format(Format::s16()).map_err(|e| { - AppError::AudioError(format!("Failed to set format: {}", e)) - })?; + hwp.set_format(Format::s16()) + .map_err(|e| AppError::AudioError(format!("Failed to set format: {}", e)))?; - hwp.set_access(Access::RWInterleaved).map_err(|e| { - AppError::AudioError(format!("Failed to set access: {}", e)) - })?; + hwp.set_access(Access::RWInterleaved) + .map_err(|e| AppError::AudioError(format!("Failed to set access: {}", e)))?; - hwp.set_buffer_size_near(config.buffer_frames as Frames).map_err(|e| { - AppError::AudioError(format!("Failed to set buffer size: {}", e)) - })?; + hwp.set_buffer_size_near(config.buffer_frames as Frames) + .map_err(|e| AppError::AudioError(format!("Failed to set buffer size: {}", e)))?; hwp.set_period_size_near(config.period_frames as Frames, ValueOr::Nearest) .map_err(|e| AppError::AudioError(format!("Failed to set period size: {}", e)))?; - 
pcm.hw_params(&hwp).map_err(|e| { - AppError::AudioError(format!("Failed to apply hw params: {}", e)) - })?; + pcm.hw_params(&hwp) + .map_err(|e| AppError::AudioError(format!("Failed to apply hw params: {}", e)))?; } // Get actual configuration - let actual_rate = pcm.hw_params_current() + let actual_rate = pcm + .hw_params_current() .map(|h| h.get_rate().unwrap_or(config.sample_rate)) .unwrap_or(config.sample_rate); @@ -317,9 +319,8 @@ fn run_capture( ); // Prepare for capture - pcm.prepare().map_err(|e| { - AppError::AudioError(format!("Failed to prepare PCM: {}", e)) - })?; + pcm.prepare() + .map_err(|e| AppError::AudioError(format!("Failed to prepare PCM: {}", e)))?; let _ = state.send(CaptureState::Running); @@ -340,7 +341,11 @@ fn run_capture( continue; } State::Suspended => { - warn_throttled!(log_throttler, "suspended", "Audio device suspended, recovering"); + warn_throttled!( + log_throttler, + "suspended", + "Audio device suspended, recovering" + ); let _ = pcm.resume(); continue; } @@ -363,11 +368,8 @@ fn run_capture( // Directly use the buffer slice (already in correct byte format) let seq = sequence.fetch_add(1, Ordering::Relaxed); - let frame = AudioFrame::new( - Bytes::copy_from_slice(&buffer[..byte_count]), - config, - seq, - ); + let frame = + AudioFrame::new(Bytes::copy_from_slice(&buffer[..byte_count]), config, seq); // Send to subscribers if frame_tx.receiver_count() > 0 { diff --git a/src/audio/controller.rs b/src/audio/controller.rs index ed67e72b..9858a764 100644 --- a/src/audio/controller.rs +++ b/src/audio/controller.rs @@ -193,7 +193,9 @@ impl AudioController { pub async fn select_device(&self, device: &str) -> Result<()> { // Validate device exists let devices = self.list_devices().await?; - let found = devices.iter().any(|d| d.name == device || d.description.contains(device)); + let found = devices + .iter() + .any(|d| d.name == device || d.description.contains(device)); if !found && device != "default" { return Err(AppError::AudioError(format!( @@ -244,7 +246,11 @@ impl AudioController { }) .await; - info!("Audio quality set to: {:?} ({}bps)", quality, quality.bitrate()); + info!( + "Audio quality set to: {:?} ({}bps)", + quality, + quality.bitrate() + ); Ok(()) } @@ -346,14 +352,17 @@ impl AudioController { let streaming = self.is_streaming().await; let error = self.last_error.read().await.clone(); - let (subscriber_count, frames_encoded, bytes_output) = if let Some(ref streamer) = - *self.streamer.read().await - { - let stats = streamer.stats().await; - (stats.subscriber_count, stats.frames_encoded, stats.bytes_output) - } else { - (0, 0, 0) - }; + let (subscriber_count, frames_encoded, bytes_output) = + if let Some(ref streamer) = *self.streamer.read().await { + let stats = streamer.stats().await; + ( + stats.subscriber_count, + stats.frames_encoded, + stats.bytes_output, + ) + } else { + (0, 0, 0) + }; AudioStatus { enabled: config.enabled, @@ -383,7 +392,11 @@ impl AudioController { /// Subscribe to Opus frames (async version) pub async fn subscribe_opus_async(&self) -> Option> { - self.streamer.read().await.as_ref().map(|s| s.subscribe_opus()) + self.streamer + .read() + .await + .as_ref() + .map(|s| s.subscribe_opus()) } /// Enable or disable audio diff --git a/src/audio/device.rs b/src/audio/device.rs index 70c8d1a0..ed42726c 100644 --- a/src/audio/device.rs +++ b/src/audio/device.rs @@ -55,7 +55,12 @@ fn get_usb_bus_info(card_index: i32) -> Option { // Match patterns like "1-1", "1-2", "1-1.2", "2-1.3.1" if component.contains('-') && 
!component.contains(':') { // Verify it looks like a USB port (starts with digit) - if component.chars().next().map(|c| c.is_ascii_digit()).unwrap_or(false) { + if component + .chars() + .next() + .map(|c| c.is_ascii_digit()) + .unwrap_or(false) + { return Some(component.to_string()); } } @@ -223,15 +228,14 @@ pub fn find_best_audio_device() -> Result { let devices = enumerate_audio_devices()?; if devices.is_empty() { - return Err(AppError::AudioError("No audio capture devices found".to_string())); + return Err(AppError::AudioError( + "No audio capture devices found".to_string(), + )); } // First, look for HDMI/capture card devices that support 48kHz stereo for device in &devices { - if device.is_hdmi - && device.sample_rates.contains(&48000) - && device.channels.contains(&2) - { + if device.is_hdmi && device.sample_rates.contains(&48000) && device.channels.contains(&2) { info!("Selected HDMI audio device: {}", device.description); return Ok(device.clone()); } diff --git a/src/audio/encoder.rs b/src/audio/encoder.rs index 671ae967..bcd316b1 100644 --- a/src/audio/encoder.rs +++ b/src/audio/encoder.rs @@ -137,9 +137,8 @@ impl OpusEncoder { let channels = config.to_audiopus_channels(); let application = config.to_audiopus_application(); - let mut encoder = Encoder::new(sample_rate, channels, application).map_err(|e| { - AppError::AudioError(format!("Failed to create Opus encoder: {:?}", e)) - })?; + let mut encoder = Encoder::new(sample_rate, channels, application) + .map_err(|e| AppError::AudioError(format!("Failed to create Opus encoder: {:?}", e)))?; // Configure encoder encoder diff --git a/src/audio/mod.rs b/src/audio/mod.rs index bc70a8ce..b6a3f9c6 100644 --- a/src/audio/mod.rs +++ b/src/audio/mod.rs @@ -22,5 +22,7 @@ pub use controller::{AudioController, AudioControllerConfig, AudioQuality, Audio pub use device::{enumerate_audio_devices, enumerate_audio_devices_with_current, AudioDeviceInfo}; pub use encoder::{OpusConfig, OpusEncoder, OpusFrame}; pub use monitor::{AudioHealthMonitor, AudioHealthStatus, AudioMonitorConfig}; -pub use shared_pipeline::{SharedAudioPipeline, SharedAudioPipelineConfig, SharedAudioPipelineStats}; +pub use shared_pipeline::{ + SharedAudioPipeline, SharedAudioPipelineConfig, SharedAudioPipelineStats, +}; pub use streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig}; diff --git a/src/audio/monitor.rs b/src/audio/monitor.rs index f6933ccd..d29b747a 100644 --- a/src/audio/monitor.rs +++ b/src/audio/monitor.rs @@ -329,9 +329,7 @@ mod tests { let monitor = AudioHealthMonitor::with_defaults(); for i in 1..=5 { - monitor - .report_error(None, "Error", "io_error") - .await; + monitor.report_error(None, "Error", "io_error").await; assert_eq!(monitor.retry_count(), i); } } @@ -340,9 +338,7 @@ mod tests { async fn test_reset() { let monitor = AudioHealthMonitor::with_defaults(); - monitor - .report_error(None, "Error", "io_error") - .await; + monitor.report_error(None, "Error", "io_error").await; assert!(monitor.is_error().await); monitor.reset().await; diff --git a/src/audio/shared_pipeline.rs b/src/audio/shared_pipeline.rs index feda8c82..0e2cab4b 100644 --- a/src/audio/shared_pipeline.rs +++ b/src/audio/shared_pipeline.rs @@ -60,7 +60,7 @@ impl Default for SharedAudioPipelineConfig { bitrate: 64000, application: OpusApplicationMode::Audio, fec: true, - channel_capacity: 16, // Reduced from 64 for lower latency + channel_capacity: 16, // Reduced from 64 for lower latency } } } @@ -320,11 +320,8 @@ impl SharedAudioPipeline { } // Receive audio frame with 
timeout - let recv_result = tokio::time::timeout( - std::time::Duration::from_secs(2), - audio_rx.recv(), - ) - .await; + let recv_result = + tokio::time::timeout(std::time::Duration::from_secs(2), audio_rx.recv()).await; match recv_result { Ok(Ok(audio_frame)) => { diff --git a/src/audio/streamer.rs b/src/audio/streamer.rs index 3e2cd8e0..3e208b0b 100644 --- a/src/audio/streamer.rs +++ b/src/audio/streamer.rs @@ -297,11 +297,8 @@ impl AudioStreamer { } // Receive PCM frame with timeout - let recv_result = tokio::time::timeout( - std::time::Duration::from_secs(2), - pcm_rx.recv(), - ) - .await; + let recv_result = + tokio::time::timeout(std::time::Duration::from_secs(2), pcm_rx.recv()).await; match recv_result { Ok(Ok(audio_frame)) => { diff --git a/src/auth/middleware.rs b/src/auth/middleware.rs index ed08d3f1..5954b700 100644 --- a/src/auth/middleware.rs +++ b/src/auth/middleware.rs @@ -46,7 +46,10 @@ pub async fn auth_middleware( if !state.config.is_initialized() { // Allow access to setup endpoints when not initialized let path = request.uri().path(); - if path.starts_with("/api/setup") || path == "/api/info" || path.starts_with("/") && !path.starts_with("/api/") { + if path.starts_with("/api/setup") + || path == "/api/info" + || path.starts_with("/") && !path.starts_with("/api/") + { return Ok(next.run(request).await); } } diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 35ef9b23..0e5147be 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,9 +1,9 @@ +pub mod middleware; mod password; mod session; mod user; -pub mod middleware; +pub use middleware::{auth_middleware, require_admin, AuthLayer, SESSION_COOKIE}; pub use password::{hash_password, verify_password}; pub use session::{Session, SessionStore}; pub use user::{User, UserStore}; -pub use middleware::{AuthLayer, SESSION_COOKIE, auth_middleware, require_admin}; diff --git a/src/auth/user.rs b/src/auth/user.rs index dc0061bd..8c68cb79 100644 --- a/src/auth/user.rs +++ b/src/auth/user.rs @@ -3,8 +3,8 @@ use serde::{Deserialize, Serialize}; use sqlx::{Pool, Sqlite}; use uuid::Uuid; -use crate::error::{AppError, Result}; use super::password::{hash_password, verify_password}; +use crate::error::{AppError, Result}; /// User row type from database type UserRow = (String, String, String, i32, String, String); @@ -134,14 +134,13 @@ impl UserStore { let password_hash = hash_password(new_password)?; let now = Utc::now(); - let result = sqlx::query( - "UPDATE users SET password_hash = ?1, updated_at = ?2 WHERE id = ?3", - ) - .bind(&password_hash) - .bind(now.to_rfc3339()) - .bind(user_id) - .execute(&self.pool) - .await?; + let result = + sqlx::query("UPDATE users SET password_hash = ?1, updated_at = ?2 WHERE id = ?3") + .bind(&password_hash) + .bind(now.to_rfc3339()) + .bind(user_id) + .execute(&self.pool) + .await?; if result.rows_affected() == 0 { return Err(AppError::NotFound("User not found".to_string())); diff --git a/src/config/schema.rs b/src/config/schema.rs index 770bc14a..c2b2062b 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -1,6 +1,6 @@ +use crate::video::encoder::BitratePreset; use serde::{Deserialize, Serialize}; use typeshare::typeshare; -use crate::video::encoder::BitratePreset; // Re-export ExtensionsConfig from extensions module pub use crate::extensions::ExtensionsConfig; @@ -147,8 +147,8 @@ pub struct OtgDescriptorConfig { impl Default for OtgDescriptorConfig { fn default() -> Self { Self { - vendor_id: 0x1d6b, // Linux Foundation - product_id: 0x0104, // Multifunction Composite Gadget + 
vendor_id: 0x1d6b, // Linux Foundation + product_id: 0x0104, // Multifunction Composite Gadget manufacturer: "One-KVM".to_string(), product: "One-KVM USB Device".to_string(), serial_number: None, @@ -425,8 +425,15 @@ impl StreamConfig { /// Check if using public ICE servers (user left fields empty) pub fn is_using_public_ice_servers(&self) -> bool { use crate::webrtc::config::public_ice; - self.stun_server.as_ref().map(|s| s.is_empty()).unwrap_or(true) - && self.turn_server.as_ref().map(|s| s.is_empty()).unwrap_or(true) + self.stun_server + .as_ref() + .map(|s| s.is_empty()) + .unwrap_or(true) + && self + .turn_server + .as_ref() + .map(|s| s.is_empty()) + .unwrap_or(true) && public_ice::is_configured() } } diff --git a/src/config/store.rs b/src/config/store.rs index 827edcc2..0e48be8d 100644 --- a/src/config/store.rs +++ b/src/config/store.rs @@ -126,11 +126,10 @@ impl ConfigStore { /// Load configuration from database async fn load_config(pool: &Pool) -> Result { - let row: Option<(String,)> = sqlx::query_as( - "SELECT value FROM config WHERE key = 'app_config'" - ) - .fetch_optional(pool) - .await?; + let row: Option<(String,)> = + sqlx::query_as("SELECT value FROM config WHERE key = 'app_config'") + .fetch_optional(pool) + .await?; match row { Some((json,)) => { @@ -245,10 +244,13 @@ mod tests { assert!(!config.initialized); // Update config - store.update(|c| { - c.initialized = true; - c.web.http_port = 9000; - }).await.unwrap(); + store + .update(|c| { + c.initialized = true; + c.web.http_port = 9000; + }) + .await + .unwrap(); // Verify update let config = store.get(); diff --git a/src/events/mod.rs b/src/events/mod.rs index c0912302..466b259d 100644 --- a/src/events/mod.rs +++ b/src/events/mod.rs @@ -6,7 +6,8 @@ pub mod types; pub use types::{ - AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, MsdDeviceInfo, SystemEvent, VideoDeviceInfo, + AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, MsdDeviceInfo, SystemEvent, + VideoDeviceInfo, }; use tokio::sync::broadcast; diff --git a/src/events/types.rs b/src/events/types.rs index 98ccef55..38835088 100644 --- a/src/events/types.rs +++ b/src/events/types.rs @@ -128,6 +128,20 @@ pub enum SystemEvent { // ============================================================================ // Video Stream Events // ============================================================================ + /// Stream mode switching started (transactional, correlates all following events) + /// + /// Sent immediately after a mode switch request is accepted. + /// Clients can use `transition_id` to correlate subsequent `stream.*` events. + #[serde(rename = "stream.mode_switching")] + StreamModeSwitching { + /// Unique transition ID for this mode switch transaction + transition_id: String, + /// Target mode: "mjpeg", "h264", "h265", "vp8", "vp9" + to_mode: String, + /// Previous mode: "mjpeg", "h264", "h265", "vp8", "vp9" + from_mode: String, + }, + /// Stream state changed (e.g., started, stopped, error) #[serde(rename = "stream.state_changed")] StreamStateChanged { @@ -143,6 +157,9 @@ pub enum SystemEvent { /// the stream will be interrupted temporarily. 
#[serde(rename = "stream.config_changing")] StreamConfigChanging { + /// Optional transition ID if this config change is part of a mode switch transaction + #[serde(skip_serializing_if = "Option::is_none")] + transition_id: Option, /// Reason for change: "device_switch", "resolution_change", "format_change" reason: String, }, @@ -152,6 +169,9 @@ pub enum SystemEvent { /// Sent after new configuration is active. Clients can reconnect now. #[serde(rename = "stream.config_applied")] StreamConfigApplied { + /// Optional transition ID if this config change is part of a mode switch transaction + #[serde(skip_serializing_if = "Option::is_none")] + transition_id: Option, /// Device path device: String, /// Resolution (width, height) @@ -193,6 +213,9 @@ pub enum SystemEvent { /// Clients should wait for this event before attempting to create WebRTC sessions. #[serde(rename = "stream.webrtc_ready")] WebRTCReady { + /// Optional transition ID if this readiness is part of a mode switch transaction + #[serde(skip_serializing_if = "Option::is_none")] + transition_id: Option, /// Current video codec codec: String, /// Whether hardware encoding is being used @@ -215,12 +238,26 @@ pub enum SystemEvent { /// from the current stream and reconnect using the new mode. #[serde(rename = "stream.mode_changed")] StreamModeChanged { + /// Optional transition ID if this change is part of a mode switch transaction + #[serde(skip_serializing_if = "Option::is_none")] + transition_id: Option, /// New mode: "mjpeg", "h264", "h265", "vp8", or "vp9" mode: String, /// Previous mode: "mjpeg", "h264", "h265", "vp8", or "vp9" previous_mode: String, }, + /// Stream mode switching completed (transactional end marker) + /// + /// Sent when the backend considers the new mode ready for clients to connect. + #[serde(rename = "stream.mode_ready")] + StreamModeReady { + /// Unique transition ID for this mode switch transaction + transition_id: String, + /// Active mode after switch: "mjpeg", "h264", "h265", "vp8", "vp9" + mode: String, + }, + // ============================================================================ // HID Events // ============================================================================ @@ -491,6 +528,7 @@ impl SystemEvent { /// Get the event name (for filtering/routing) pub fn event_name(&self) -> &'static str { match self { + Self::StreamModeSwitching { .. } => "stream.mode_switching", Self::StreamStateChanged { .. } => "stream.state_changed", Self::StreamConfigChanging { .. } => "stream.config_changing", Self::StreamConfigApplied { .. } => "stream.config_applied", @@ -500,6 +538,7 @@ impl SystemEvent { Self::WebRTCReady { .. } => "stream.webrtc_ready", Self::StreamStatsUpdate { .. } => "stream.stats_update", Self::StreamModeChanged { .. } => "stream.mode_changed", + Self::StreamModeReady { .. } => "stream.mode_ready", Self::HidStateChanged { .. } => "hid.state_changed", Self::HidBackendSwitching { .. } => "hid.backend_switching", Self::HidDeviceLost { .. } => "hid.device_lost", @@ -589,6 +628,7 @@ mod tests { #[test] fn test_serialization() { let event = SystemEvent::StreamConfigApplied { + transition_id: None, device: "/dev/video0".to_string(), resolution: (1920, 1080), format: "mjpeg".to_string(), @@ -600,6 +640,9 @@ mod tests { assert!(json.contains("/dev/video0")); let deserialized: SystemEvent = serde_json::from_str(&json).unwrap(); - assert!(matches!(deserialized, SystemEvent::StreamConfigApplied { .. })); + assert!(matches!( + deserialized, + SystemEvent::StreamConfigApplied { .. 
} + )); } } diff --git a/src/extensions/manager.rs b/src/extensions/manager.rs index 6624f99b..279d303c 100644 --- a/src/extensions/manager.rs +++ b/src/extensions/manager.rs @@ -210,7 +210,11 @@ impl ExtensionManager { } /// Build command arguments for an extension - async fn build_args(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<Vec<String>, String> { + async fn build_args( + &self, + id: ExtensionId, + config: &ExtensionsConfig, + ) -> Result<Vec<String>, String> { match id { ExtensionId::Ttyd => { let c = &config.ttyd; @@ -219,9 +223,11 @@ impl ExtensionManager { Self::prepare_ttyd_socket().await?; let mut args = vec![ - "-i".to_string(), TTYD_SOCKET_PATH.to_string(), // Unix socket - "-b".to_string(), "/api/terminal".to_string(), // Base path for reverse proxy - "-W".to_string(), // Writable (allow input) + "-i".to_string(), + TTYD_SOCKET_PATH.to_string(), // Unix socket + "-b".to_string(), + "/api/terminal".to_string(), // Base path for reverse proxy + "-W".to_string(), // Writable (allow input) ]; // Add credential if set (still useful for additional security layer) @@ -313,7 +319,10 @@ impl ExtensionManager { } // Remove old socket file if exists - if tokio::fs::try_exists(TTYD_SOCKET_PATH).await.unwrap_or(false) { + if tokio::fs::try_exists(TTYD_SOCKET_PATH) + .await + .unwrap_or(false) + { tokio::fs::remove_file(TTYD_SOCKET_PATH) .await .map_err(|e| format!("Failed to remove old socket: {}", e))?; @@ -374,8 +383,8 @@ impl ExtensionManager { /// Start all enabled extensions in parallel pub async fn start_enabled(&self, config: &ExtensionsConfig) { - use std::pin::Pin; use futures::Future; + use std::pin::Pin; let mut start_futures: Vec + Send + '_>>> = Vec::new(); @@ -416,10 +425,7 @@ impl ExtensionManager { /// Stop all running extensions in parallel pub async fn stop_all(&self) { - let stop_futures: Vec<_> = ExtensionId::all() - .iter() - .map(|id| self.stop(*id)) - .collect(); + let stop_futures: Vec<_> = ExtensionId::all().iter().map(|id| self.stop(*id)).collect(); futures::future::join_all(stop_futures).await; } } diff --git a/src/hid/ch9329.rs b/src/hid/ch9329.rs index a6f11397..5423627c 100644 --- a/src/hid/ch9329.rs +++ b/src/hid/ch9329.rs @@ -21,7 +21,7 @@ use async_trait::async_trait; use parking_lot::{Mutex, RwLock}; use serde::{Deserialize, Serialize}; use std::io::{Read, Write}; -use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU8, AtomicU32, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, Ordering}; use std::time::{Duration, Instant}; use tracing::{debug, info, trace, warn}; @@ -358,8 +358,7 @@ impl Response { /// Check if the response indicates success pub fn is_success(&self) -> bool { - !self.is_error - && (self.data.is_empty() || self.data[0] == Ch9329Error::Success as u8) + !self.is_error && (self.data.is_empty() || self.data[0] == Ch9329Error::Success as u8) } } @@ -489,7 +488,10 @@ impl Ch9329Backend { .map_err(|e| Self::serial_error_to_hid_error(e, "Failed to open serial port"))?; *self.port.lock() = Some(port); - info!("CH9329 serial port reopened: {} @ {} baud", self.port_path, self.baud_rate); + info!( + "CH9329 serial port reopened: {} @ {} baud", + self.port_path, self.baud_rate + ); // Verify connection with GET_INFO command self.query_chip_info().map_err(|e| { @@ -518,7 +520,10 @@ impl Ch9329Backend { /// Returns the packet buffer and the actual length #[inline] fn build_packet_buf(&self, cmd: u8, data: &[u8]) -> ([u8; MAX_PACKET_SIZE], usize) { - debug_assert!(data.len() <= MAX_DATA_LEN, "Data too long for CH9329 packet");
+ debug_assert!( + data.len() <= MAX_DATA_LEN, + "Data too long for CH9329 packet" + ); let len = data.len() as u8; let packet_len = 6 + data.len(); @@ -554,16 +559,19 @@ impl Ch9329Backend { let mut port_guard = self.port.lock(); if let Some(ref mut port) = *port_guard { - port.write_all(&packet[..packet_len]).map_err(|e| { - AppError::HidError { + port.write_all(&packet[..packet_len]) + .map_err(|e| AppError::HidError { backend: "ch9329".to_string(), reason: format!("Failed to write to CH9329: {}", e), error_code: "write_failed".to_string(), - } - })?; + })?; // Only log mouse button events at debug level to avoid flooding if cmd == cmd::SEND_MS_ABS_DATA && data.len() >= 2 && data[1] != 0 { - debug!("CH9329 TX [cmd=0x{:02X}]: {:02X?}", cmd, &packet[..packet_len]); + debug!( + "CH9329 TX [cmd=0x{:02X}]: {:02X?}", + cmd, + &packet[..packet_len] + ); } Ok(()) } else { @@ -655,7 +663,11 @@ impl Ch9329Backend { info!( "CH9329: Recovery successful, chip version: {}, USB: {}", info.version, - if info.usb_connected { "connected" } else { "disconnected" } + if info.usb_connected { + "connected" + } else { + "disconnected" + } ); // Reset error count on successful recovery self.error_count.store(0, Ordering::Relaxed); @@ -695,9 +707,8 @@ impl Ch9329Backend { let mut port_guard = self.port.lock(); if let Some(ref mut port) = *port_guard { // Send packet - port.write_all(&packet).map_err(|e| { - AppError::Internal(format!("Failed to write to CH9329: {}", e)) - })?; + port.write_all(&packet) + .map_err(|e| AppError::Internal(format!("Failed to write to CH9329: {}", e)))?; trace!("CH9329 TX: {:02X?}", packet); // Wait for response - use shorter delay for faster response @@ -725,7 +736,10 @@ impl Ch9329Backend { debug!("CH9329 response timeout (may be normal)"); Err(AppError::Internal("CH9329 response timeout".to_string())) } - Err(e) => Err(AppError::Internal(format!("Failed to read from CH9329: {}", e))), + Err(e) => Err(AppError::Internal(format!( + "Failed to read from CH9329: {}", + e + ))), } } else { Err(AppError::Internal("CH9329 port not opened".to_string())) @@ -799,7 +813,9 @@ impl Ch9329Backend { if response.is_success() { Ok(()) } else { - Err(AppError::Internal("Failed to restore factory defaults".to_string())) + Err(AppError::Internal( + "Failed to restore factory defaults".to_string(), + )) } } @@ -820,7 +836,9 @@ impl Ch9329Backend { /// For other multimedia keys: data = [0x02, byte2, byte3, byte4] pub fn send_media_key(&self, data: &[u8]) -> Result<()> { if data.len() < 2 || data.len() > 4 { - return Err(AppError::Internal("Invalid media key data length".to_string())); + return Err(AppError::Internal( + "Invalid media key data length".to_string(), + )); } self.send_packet(cmd::SEND_KB_MEDIA_DATA, data) } @@ -871,10 +889,7 @@ impl Ch9329Backend { // Use send_packet which has retry logic built-in self.send_packet(cmd::SEND_MS_ABS_DATA, &data)?; - trace!( - "CH9329 mouse: buttons=0x{:02X} pos=({},{})", - buttons, x, y - ); + trace!("CH9329 mouse: buttons=0x{:02X} pos=({},{})", buttons, x, y); Ok(()) } @@ -930,7 +945,11 @@ impl HidBackend for Ch9329Backend { info!( "CH9329 chip detected: {}, USB: {}, LEDs: NumLock={}, CapsLock={}, ScrollLock={}", info.version, - if info.usb_connected { "connected" } else { "disconnected" }, + if info.usb_connected { + "connected" + } else { + "disconnected" + }, info.num_lock, info.caps_lock, info.scroll_lock @@ -1128,10 +1147,7 @@ pub fn detect_ch9329() -> Option { && response[0] == PACKET_HEADER[0] && response[1] == PACKET_HEADER[1] { - info!( - "CH9329 
detected on {} @ {} baud", - port_path, baud_rate - ); + info!("CH9329 detected on {} @ {} baud", port_path, baud_rate); return Some(port_path.to_string()); } } @@ -1176,10 +1192,7 @@ pub fn detect_ch9329_with_baud() -> Option<(String, u32)> { && response[0] == PACKET_HEADER[0] && response[1] == PACKET_HEADER[1] { - info!( - "CH9329 detected on {} @ {} baud", - port_path, baud_rate - ); + info!("CH9329 detected on {} @ {} baud", port_path, baud_rate); return Some((port_path.to_string(), baud_rate)); } } @@ -1217,7 +1230,7 @@ mod tests { assert_eq!(packet[3], cmd::SEND_KB_GENERAL_DATA); // Command assert_eq!(packet[4], 8); // Length (8 data bytes) assert_eq!(&packet[5..13], &data); // Data - // Checksum = 0x57 + 0xAB + 0x00 + 0x02 + 0x08 + 0x00 + 0x00 + 0x04 + ... = 0x10 + // Checksum = 0x57 + 0xAB + 0x00 + 0x02 + 0x08 + 0x00 + 0x00 + 0x04 + ... = 0x10 let expected_checksum: u8 = packet[..13].iter().fold(0u8, |acc, &x| acc.wrapping_add(x)); assert_eq!(packet[13], expected_checksum); } @@ -1234,10 +1247,10 @@ mod tests { assert_eq!(packet[1], 0xAB); assert_eq!(packet[2], 0x00); // Address assert_eq!(packet[3], 0x05); // CMD_SEND_MS_REL_DATA - assert_eq!(packet[4], 5); // Length = 5 + assert_eq!(packet[4], 5); // Length = 5 assert_eq!(packet[5], 0x01); // Mode marker assert_eq!(packet[6], 0x00); // Buttons - assert_eq!(packet[7], 50); // X delta + assert_eq!(packet[7], 50); // X delta } #[test] @@ -1248,7 +1261,9 @@ mod tests { assert_eq!(checksum, 0x03); // Known packet: Keyboard 'A' press - let packet = [0x57u8, 0xAB, 0x00, 0x02, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; + let packet = [ + 0x57u8, 0xAB, 0x00, 0x02, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; let checksum = Ch9329Backend::calculate_checksum(&packet); assert_eq!(checksum, 0x10); } @@ -1258,11 +1273,11 @@ mod tests { // Valid GET_INFO response let response_bytes = [ 0x57, 0xAB, // Header - 0x00, // Address - 0x81, // Command (GET_INFO | 0x80 = success) - 0x08, // Length + 0x00, // Address + 0x81, // Command (GET_INFO | 0x80 = success) + 0x08, // Length 0x31, 0x01, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // Data - 0xE0, // Checksum (calculated) + 0xE0, // Checksum (calculated) ]; // Note: checksum in test is just placeholder, parse will validate diff --git a/src/hid/datachannel.rs b/src/hid/datachannel.rs index 2c898e72..00ecfb12 100644 --- a/src/hid/datachannel.rs +++ b/src/hid/datachannel.rs @@ -210,27 +210,23 @@ pub fn encode_mouse_event(event: &MouseEvent) -> Vec { let y_bytes = (event.y as i16).to_le_bytes(); let extra = match event.event_type { - MouseEventType::Down | MouseEventType::Up => { - event.button.as_ref().map(|b| match b { + MouseEventType::Down | MouseEventType::Up => event + .button + .as_ref() + .map(|b| match b { MouseButton::Left => 0u8, MouseButton::Middle => 1u8, MouseButton::Right => 2u8, MouseButton::Back => 3u8, MouseButton::Forward => 4u8, - }).unwrap_or(0) - } + }) + .unwrap_or(0), MouseEventType::Scroll => event.scroll as u8, _ => 0, }; vec![ - MSG_MOUSE, - event_type, - x_bytes[0], - x_bytes[1], - y_bytes[0], - y_bytes[1], - extra, + MSG_MOUSE, event_type, x_bytes[0], x_bytes[1], y_bytes[0], y_bytes[1], extra, ] } diff --git a/src/hid/keymap.rs b/src/hid/keymap.rs index 8866d1b7..7d570e33 100644 --- a/src/hid/keymap.rs +++ b/src/hid/keymap.rs @@ -278,16 +278,16 @@ static JS_TO_USB_TABLE: [u8; 256] = { } // Numbers 1-9, 0 (JS 49-57, 48 -> USB 0x1E-0x27) - table[49] = usb::KEY_1; // 1 - table[50] = usb::KEY_2; // 2 - table[51] = usb::KEY_3; // 3 - table[52] = usb::KEY_4; // 4 
- table[53] = usb::KEY_5; // 5 - table[54] = usb::KEY_6; // 6 - table[55] = usb::KEY_7; // 7 - table[56] = usb::KEY_8; // 8 - table[57] = usb::KEY_9; // 9 - table[48] = usb::KEY_0; // 0 + table[49] = usb::KEY_1; // 1 + table[50] = usb::KEY_2; // 2 + table[51] = usb::KEY_3; // 3 + table[52] = usb::KEY_4; // 4 + table[53] = usb::KEY_5; // 5 + table[54] = usb::KEY_6; // 6 + table[55] = usb::KEY_7; // 7 + table[56] = usb::KEY_8; // 8 + table[57] = usb::KEY_9; // 9 + table[48] = usb::KEY_0; // 0 // Function keys F1-F12 (JS 112-123 -> USB 0x3A-0x45) table[112] = usb::KEY_F1; @@ -304,25 +304,25 @@ static JS_TO_USB_TABLE: [u8; 256] = { table[123] = usb::KEY_F12; // Control keys - table[13] = usb::KEY_ENTER; // Enter - table[27] = usb::KEY_ESCAPE; // Escape - table[8] = usb::KEY_BACKSPACE; // Backspace - table[9] = usb::KEY_TAB; // Tab - table[32] = usb::KEY_SPACE; // Space - table[20] = usb::KEY_CAPS_LOCK; // Caps Lock + table[13] = usb::KEY_ENTER; // Enter + table[27] = usb::KEY_ESCAPE; // Escape + table[8] = usb::KEY_BACKSPACE; // Backspace + table[9] = usb::KEY_TAB; // Tab + table[32] = usb::KEY_SPACE; // Space + table[20] = usb::KEY_CAPS_LOCK; // Caps Lock // Punctuation (JS codes vary by browser/layout) - table[189] = usb::KEY_MINUS; // - - table[187] = usb::KEY_EQUAL; // = - table[219] = usb::KEY_LEFT_BRACKET; // [ + table[189] = usb::KEY_MINUS; // - + table[187] = usb::KEY_EQUAL; // = + table[219] = usb::KEY_LEFT_BRACKET; // [ table[221] = usb::KEY_RIGHT_BRACKET; // ] - table[220] = usb::KEY_BACKSLASH; // \ - table[186] = usb::KEY_SEMICOLON; // ; - table[222] = usb::KEY_APOSTROPHE; // ' - table[192] = usb::KEY_GRAVE; // ` - table[188] = usb::KEY_COMMA; // , - table[190] = usb::KEY_PERIOD; // . - table[191] = usb::KEY_SLASH; // / + table[220] = usb::KEY_BACKSLASH; // \ + table[186] = usb::KEY_SEMICOLON; // ; + table[222] = usb::KEY_APOSTROPHE; // ' + table[192] = usb::KEY_GRAVE; // ` + table[188] = usb::KEY_COMMA; // , + table[190] = usb::KEY_PERIOD; // . + table[191] = usb::KEY_SLASH; // / // Navigation keys table[45] = usb::KEY_INSERT; @@ -359,14 +359,14 @@ static JS_TO_USB_TABLE: [u8; 256] = { // Special keys table[19] = usb::KEY_PAUSE; table[145] = usb::KEY_SCROLL_LOCK; - table[93] = usb::KEY_APPLICATION; // Context menu + table[93] = usb::KEY_APPLICATION; // Context menu // Modifier keys table[17] = usb::KEY_LEFT_CTRL; table[16] = usb::KEY_LEFT_SHIFT; table[18] = usb::KEY_LEFT_ALT; - table[91] = usb::KEY_LEFT_META; // Left Windows/Command - table[92] = usb::KEY_RIGHT_META; // Right Windows/Command + table[91] = usb::KEY_LEFT_META; // Left Windows/Command + table[92] = usb::KEY_RIGHT_META; // Right Windows/Command table }; diff --git a/src/hid/mod.rs b/src/hid/mod.rs index 25177c29..46e4c45b 100644 --- a/src/hid/mod.rs +++ b/src/hid/mod.rs @@ -102,8 +102,14 @@ impl HidController { info!("Creating OTG HID backend from device paths"); Box::new(otg::OtgBackend::from_handles(handles)?) } - HidBackendType::Ch9329 { ref port, baud_rate } => { - info!("Initializing CH9329 HID backend on {} @ {} baud", port, baud_rate); + HidBackendType::Ch9329 { + ref port, + baud_rate, + } => { + info!( + "Initializing CH9329 HID backend on {} @ {} baud", + port, baud_rate + ); Box::new(ch9329::Ch9329Backend::with_baud_rate(port, baud_rate)?) 
} HidBackendType::None => { @@ -157,16 +163,25 @@ impl HidController { // Report error to monitor, but skip temporary EAGAIN retries // - "eagain_retry": within threshold, just temporary busy // - "eagain": exceeded threshold, report as error - if let AppError::HidError { ref backend, ref reason, ref error_code } = e { + if let AppError::HidError { + ref backend, + ref reason, + ref error_code, + } = e + { if error_code != "eagain_retry" { - self.monitor.report_error(backend, None, reason, error_code).await; + self.monitor + .report_error(backend, None, reason, error_code) + .await; } } Err(e) } } } - None => Err(AppError::BadRequest("HID backend not available".to_string())), + None => Err(AppError::BadRequest( + "HID backend not available".to_string(), + )), } } @@ -188,16 +203,25 @@ impl HidController { // Report error to monitor, but skip temporary EAGAIN retries // - "eagain_retry": within threshold, just temporary busy // - "eagain": exceeded threshold, report as error - if let AppError::HidError { ref backend, ref reason, ref error_code } = e { + if let AppError::HidError { + ref backend, + ref reason, + ref error_code, + } = e + { if error_code != "eagain_retry" { - self.monitor.report_error(backend, None, reason, error_code).await; + self.monitor + .report_error(backend, None, reason, error_code) + .await; } } Err(e) } } } - None => Err(AppError::BadRequest("HID backend not available".to_string())), + None => Err(AppError::BadRequest( + "HID backend not available".to_string(), + )), } } @@ -205,26 +229,33 @@ impl HidController { pub async fn send_consumer(&self, event: ConsumerEvent) -> Result<()> { let backend = self.backend.read().await; match backend.as_ref() { - Some(b) => { - match b.send_consumer(event).await { - Ok(_) => { - if self.monitor.is_error().await { - let backend_type = self.backend_type.read().await; - self.monitor.report_recovered(backend_type.name_str()).await; - } - Ok(()) - } - Err(e) => { - if let AppError::HidError { ref backend, ref reason, ref error_code } = e { - if error_code != "eagain_retry" { - self.monitor.report_error(backend, None, reason, error_code).await; - } - } - Err(e) + Some(b) => match b.send_consumer(event).await { + Ok(_) => { + if self.monitor.is_error().await { + let backend_type = self.backend_type.read().await; + self.monitor.report_recovered(backend_type.name_str()).await; } + Ok(()) } - } - None => Err(AppError::BadRequest("HID backend not available".to_string())), + Err(e) => { + if let AppError::HidError { + ref backend, + ref reason, + ref error_code, + } = e + { + if error_code != "eagain_retry" { + self.monitor + .report_error(backend, None, reason, error_code) + .await; + } + } + Err(e) + } + }, + None => Err(AppError::BadRequest( + "HID backend not available".to_string(), + )), } } @@ -269,9 +300,9 @@ impl HidController { // Include error information from monitor let (error, error_code) = match self.monitor.status().await { - HidHealthStatus::Error { reason, error_code, .. } => { - (Some(reason), Some(error_code)) - } + HidHealthStatus::Error { + reason, error_code, .. 
+ } => (Some(reason), Some(error_code)), _ => (None, None), }; @@ -320,7 +351,7 @@ impl HidController { None => { warn!("OTG backend requires OtgService, but it's not available"); return Err(AppError::Config( - "OTG backend not available (OtgService missing)".to_string() + "OTG backend not available (OtgService missing)".to_string(), )); } }; @@ -341,7 +372,10 @@ impl HidController { warn!("Failed to initialize OTG backend: {}", e); // Cleanup: disable HID in OtgService if let Err(e2) = otg_service.disable_hid().await { - warn!("Failed to cleanup HID after init failure: {}", e2); + warn!( + "Failed to cleanup HID after init failure: {}", + e2 + ); } None } @@ -363,8 +397,14 @@ impl HidController { } } } - HidBackendType::Ch9329 { ref port, baud_rate } => { - info!("Initializing CH9329 HID backend on {} @ {} baud", port, baud_rate); + HidBackendType::Ch9329 { + ref port, + baud_rate, + } => { + info!( + "Initializing CH9329 HID backend on {} @ {} baud", + port, baud_rate + ); match ch9329::Ch9329Backend::with_baud_rate(port, baud_rate) { Ok(b) => { let boxed = Box::new(b); diff --git a/src/hid/monitor.rs b/src/hid/monitor.rs index 1c03979c..0ce84d88 100644 --- a/src/hid/monitor.rs +++ b/src/hid/monitor.rs @@ -144,7 +144,8 @@ impl HidHealthMonitor { // Check if we're in cooldown period after recent recovery let current_ms = self.start_instant.elapsed().as_millis() as u64; let last_recovery = self.last_recovery_ms.load(Ordering::Relaxed); - let in_cooldown = last_recovery > 0 && current_ms < last_recovery + self.config.recovery_cooldown_ms; + let in_cooldown = + last_recovery > 0 && current_ms < last_recovery + self.config.recovery_cooldown_ms; // Check if error code changed let error_changed = { @@ -229,10 +230,7 @@ impl HidHealthMonitor { // Only log and publish events if there were multiple retries // (avoid log spam for transient single-retry recoveries) if retry_count > 1 { - debug!( - "HID {} recovered after {} retries", - backend, retry_count - ); + debug!("HID {} recovered after {} retries", backend, retry_count); // Publish recovery event if let Some(ref events) = *self.events.read().await { @@ -372,9 +370,7 @@ mod tests { let monitor = HidHealthMonitor::with_defaults(); for i in 1..=5 { - monitor - .report_error("otg", None, "Error", "io_error") - .await; + monitor.report_error("otg", None, "Error", "io_error").await; assert_eq!(monitor.retry_count(), i); } } @@ -387,9 +383,7 @@ mod tests { }); for _ in 0..100 { - monitor - .report_error("otg", None, "Error", "io_error") - .await; + monitor.report_error("otg", None, "Error", "io_error").await; assert!(monitor.should_retry()); } } @@ -417,9 +411,7 @@ mod tests { async fn test_reset() { let monitor = HidHealthMonitor::with_defaults(); - monitor - .report_error("otg", None, "Error", "io_error") - .await; + monitor.report_error("otg", None, "Error", "io_error").await; assert!(monitor.is_error().await); monitor.reset().await; diff --git a/src/hid/otg.rs b/src/hid/otg.rs index f8253c61..1c34f6e6 100644 --- a/src/hid/otg.rs +++ b/src/hid/otg.rs @@ -30,9 +30,11 @@ use tracing::{debug, info, trace, warn}; use super::backend::HidBackend; use super::keymap; -use super::types::{ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType}; +use super::types::{ + ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType, +}; use crate::error::{AppError, Result}; -use crate::otg::{HidDevicePaths, wait_for_hid_devices}; +use crate::otg::{wait_for_hid_devices, HidDevicePaths}; /// Device type 
for ensure_device operations #[derive(Debug, Clone, Copy)] @@ -73,11 +75,21 @@ impl LedState { /// Convert to raw byte pub fn to_byte(&self) -> u8 { let mut b = 0u8; - if self.num_lock { b |= 0x01; } - if self.caps_lock { b |= 0x02; } - if self.scroll_lock { b |= 0x04; } - if self.compose { b |= 0x08; } - if self.kana { b |= 0x10; } + if self.num_lock { + b |= 0x01; + } + if self.caps_lock { + b |= 0x02; + } + if self.scroll_lock { + b |= 0x04; + } + if self.compose { + b |= 0x08; + } + if self.kana { + b |= 0x10; + } b } } @@ -145,7 +157,9 @@ impl OtgBackend { keyboard_path: paths.keyboard, mouse_rel_path: paths.mouse_relative, mouse_abs_path: paths.mouse_absolute, - consumer_path: paths.consumer.unwrap_or_else(|| PathBuf::from("/dev/hidg3")), + consumer_path: paths + .consumer + .unwrap_or_else(|| PathBuf::from("/dev/hidg3")), keyboard_dev: Mutex::new(None), mouse_rel_dev: Mutex::new(None), mouse_abs_dev: Mutex::new(None), @@ -198,7 +212,8 @@ impl OtgBackend { Ok(1) => { // Device ready, check for errors if let Some(revents) = pollfd[0].revents() { - if revents.contains(PollFlags::POLLERR) || revents.contains(PollFlags::POLLHUP) { + if revents.contains(PollFlags::POLLERR) || revents.contains(PollFlags::POLLHUP) + { return Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "Device error or hangup", @@ -297,7 +312,10 @@ impl OtgBackend { // Close the device if open (device was removed) let mut dev = dev_mutex.lock(); if dev.is_some() { - debug!("Device path {} no longer exists, closing handle", path.display()); + debug!( + "Device path {} no longer exists, closing handle", + path.display() + ); *dev = None; } self.online.store(false, Ordering::Relaxed); @@ -335,20 +353,24 @@ impl OtgBackend { .custom_flags(libc::O_NONBLOCK) .open(path) .map_err(|e| { - AppError::Internal(format!("Failed to open HID device {}: {}", path.display(), e)) + AppError::Internal(format!( + "Failed to open HID device {}: {}", + path.display(), + e + )) }) } /// Convert I/O error to HidError with appropriate error code fn io_error_to_hid_error(e: std::io::Error, operation: &str) -> AppError { let error_code = match e.raw_os_error() { - Some(32) => "epipe", // EPIPE - broken pipe - Some(108) => "eshutdown", // ESHUTDOWN - transport endpoint shutdown - Some(11) => "eagain", // EAGAIN - resource temporarily unavailable - Some(6) => "enxio", // ENXIO - no such device or address - Some(19) => "enodev", // ENODEV - no such device - Some(5) => "eio", // EIO - I/O error - Some(2) => "enoent", // ENOENT - no such file or directory + Some(32) => "epipe", // EPIPE - broken pipe + Some(108) => "eshutdown", // ESHUTDOWN - transport endpoint shutdown + Some(11) => "eagain", // EAGAIN - resource temporarily unavailable + Some(6) => "enxio", // ENXIO - no such device or address + Some(19) => "enodev", // ENODEV - no such device + Some(5) => "eio", // EIO - I/O error + Some(2) => "enoent", // ENOENT - no such file or directory _ => "io_error", }; @@ -361,9 +383,7 @@ impl OtgBackend { /// Check if all HID device files exist pub fn check_devices_exist(&self) -> bool { - self.keyboard_path.exists() - && self.mouse_rel_path.exists() - && self.mouse_abs_path.exists() + self.keyboard_path.exists() && self.mouse_rel_path.exists() && self.mouse_abs_path.exists() } /// Get list of missing device paths @@ -415,7 +435,10 @@ impl OtgBackend { self.eagain_count.store(0, Ordering::Relaxed); debug!("Keyboard ESHUTDOWN, closing for recovery"); *dev = None; - Err(Self::io_error_to_hid_error(e, "Failed to write keyboard report")) + 
Err(Self::io_error_to_hid_error( + e, + "Failed to write keyboard report", + )) } Some(11) => { // EAGAIN after poll - should be rare, silently drop @@ -426,7 +449,10 @@ impl OtgBackend { self.online.store(false, Ordering::Relaxed); self.eagain_count.store(0, Ordering::Relaxed); warn!("Keyboard write error: {}", e); - Err(Self::io_error_to_hid_error(e, "Failed to write keyboard report")) + Err(Self::io_error_to_hid_error( + e, + "Failed to write keyboard report", + )) } } } @@ -472,7 +498,10 @@ impl OtgBackend { self.eagain_count.store(0, Ordering::Relaxed); debug!("Relative mouse ESHUTDOWN, closing for recovery"); *dev = None; - Err(Self::io_error_to_hid_error(e, "Failed to write mouse report")) + Err(Self::io_error_to_hid_error( + e, + "Failed to write mouse report", + )) } Some(11) => { // EAGAIN after poll - should be rare, silently drop @@ -482,7 +511,10 @@ impl OtgBackend { self.online.store(false, Ordering::Relaxed); self.eagain_count.store(0, Ordering::Relaxed); warn!("Relative mouse write error: {}", e); - Err(Self::io_error_to_hid_error(e, "Failed to write mouse report")) + Err(Self::io_error_to_hid_error( + e, + "Failed to write mouse report", + )) } } } @@ -534,7 +566,10 @@ impl OtgBackend { self.eagain_count.store(0, Ordering::Relaxed); debug!("Absolute mouse ESHUTDOWN, closing for recovery"); *dev = None; - Err(Self::io_error_to_hid_error(e, "Failed to write mouse report")) + Err(Self::io_error_to_hid_error( + e, + "Failed to write mouse report", + )) } Some(11) => { // EAGAIN after poll - should be rare, silently drop @@ -544,7 +579,10 @@ impl OtgBackend { self.online.store(false, Ordering::Relaxed); self.eagain_count.store(0, Ordering::Relaxed); warn!("Absolute mouse write error: {}", e); - Err(Self::io_error_to_hid_error(e, "Failed to write mouse report")) + Err(Self::io_error_to_hid_error( + e, + "Failed to write mouse report", + )) } } } @@ -590,7 +628,10 @@ impl OtgBackend { self.online.store(false, Ordering::Relaxed); debug!("Consumer control ESHUTDOWN, closing for recovery"); *dev = None; - Err(Self::io_error_to_hid_error(e, "Failed to write consumer report")) + Err(Self::io_error_to_hid_error( + e, + "Failed to write consumer report", + )) } Some(11) => { // EAGAIN after poll - silently drop @@ -599,7 +640,10 @@ impl OtgBackend { _ => { self.online.store(false, Ordering::Relaxed); warn!("Consumer control write error: {}", e); - Err(Self::io_error_to_hid_error(e, "Failed to write consumer report")) + Err(Self::io_error_to_hid_error( + e, + "Failed to write consumer report", + )) } } } @@ -632,7 +676,10 @@ impl OtgBackend { } Ok(_) => Ok(None), // No data available Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(None), - Err(e) => Err(AppError::Internal(format!("Failed to read LED state: {}", e))), + Err(e) => Err(AppError::Internal(format!( + "Failed to read LED state: {}", + e + ))), } } else { Ok(None) @@ -677,34 +724,55 @@ impl HidBackend for OtgBackend { *self.keyboard_dev.lock() = Some(file); info!("Keyboard device opened: {}", self.keyboard_path.display()); } else { - warn!("Keyboard device not found: {}", self.keyboard_path.display()); + warn!( + "Keyboard device not found: {}", + self.keyboard_path.display() + ); } // Open relative mouse device if self.mouse_rel_path.exists() { let file = Self::open_device(&self.mouse_rel_path)?; *self.mouse_rel_dev.lock() = Some(file); - info!("Relative mouse device opened: {}", self.mouse_rel_path.display()); + info!( + "Relative mouse device opened: {}", + self.mouse_rel_path.display() + ); } else { - 
warn!("Relative mouse device not found: {}", self.mouse_rel_path.display()); + warn!( + "Relative mouse device not found: {}", + self.mouse_rel_path.display() + ); } // Open absolute mouse device if self.mouse_abs_path.exists() { let file = Self::open_device(&self.mouse_abs_path)?; *self.mouse_abs_dev.lock() = Some(file); - info!("Absolute mouse device opened: {}", self.mouse_abs_path.display()); + info!( + "Absolute mouse device opened: {}", + self.mouse_abs_path.display() + ); } else { - warn!("Absolute mouse device not found: {}", self.mouse_abs_path.display()); + warn!( + "Absolute mouse device not found: {}", + self.mouse_abs_path.display() + ); } // Open consumer control device (optional, may not exist on older setups) if self.consumer_path.exists() { let file = Self::open_device(&self.consumer_path)?; *self.consumer_dev.lock() = Some(file); - info!("Consumer control device opened: {}", self.consumer_path.display()); + info!( + "Consumer control device opened: {}", + self.consumer_path.display() + ); } else { - debug!("Consumer control device not found: {}", self.consumer_path.display()); + debug!( + "Consumer control device not found: {}", + self.consumer_path.display() + ); } // Mark as online if all devices opened successfully diff --git a/src/hid/types.rs b/src/hid/types.rs index bff86b67..c6848d4a 100644 --- a/src/hid/types.rs +++ b/src/hid/types.rs @@ -341,12 +341,7 @@ pub struct MouseReport { impl MouseReport { /// Convert to bytes for USB HID (relative mouse) pub fn to_bytes_relative(&self) -> [u8; 4] { - [ - self.buttons, - self.x as u8, - self.y as u8, - self.wheel as u8, - ] + [self.buttons, self.x as u8, self.y as u8, self.wheel as u8] } /// Convert to bytes for USB HID (absolute mouse) diff --git a/src/hid/websocket.rs b/src/hid/websocket.rs index e427a29f..4cff8bc4 100644 --- a/src/hid/websocket.rs +++ b/src/hid/websocket.rs @@ -50,7 +50,11 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc) { vec![RESP_ERR_HID_UNAVAILABLE] }; - if sender.send(Message::Binary(initial_response.into())).await.is_err() { + if sender + .send(Message::Binary(initial_response.into())) + .await + .is_err() + { error!("Failed to send initial HID status"); return; } @@ -66,7 +70,9 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc) { warn!("HID controller not available, ignoring message"); } // Send error response (optional, for client awareness) - let _ = sender.send(Message::Binary(vec![RESP_ERR_HID_UNAVAILABLE].into())).await; + let _ = sender + .send(Message::Binary(vec![RESP_ERR_HID_UNAVAILABLE].into())) + .await; continue; } @@ -81,9 +87,14 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc) { Ok(Message::Text(text)) => { // Text messages are no longer supported if log_throttler.should_log("text_message_rejected") { - debug!("Received text message (not supported): {} bytes", text.len()); + debug!( + "Received text message (not supported): {} bytes", + text.len() + ); } - let _ = sender.send(Message::Binary(vec![RESP_ERR_INVALID_MESSAGE].into())).await; + let _ = sender + .send(Message::Binary(vec![RESP_ERR_INVALID_MESSAGE].into())) + .await; } Ok(Message::Ping(data)) => { let _ = sender.send(Message::Pong(data)).await; @@ -142,7 +153,7 @@ async fn handle_binary_message(data: &[u8], state: &AppState) -> Result<(), Stri #[cfg(test)] mod tests { use super::*; - use crate::hid::datachannel::{MSG_KEYBOARD, MSG_MOUSE, KB_EVENT_DOWN, MS_EVENT_MOVE}; + use crate::hid::datachannel::{KB_EVENT_DOWN, MSG_KEYBOARD, MSG_MOUSE, MS_EVENT_MOVE}; #[test] fn test_response_codes() { 
diff --git a/src/main.rs b/src/main.rs index 189ecc43..08291de1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,9 +4,9 @@ use std::sync::Arc; use axum_server::tls_rustls::RustlsConfig; use clap::{Parser, ValueEnum}; +use rustls::crypto::{ring, CryptoProvider}; use tokio::sync::broadcast; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -use rustls::crypto::{ring, CryptoProvider}; use one_kvm::atx::AtxController; use one_kvm::audio::{AudioController, AudioControllerConfig, AudioQuality}; @@ -26,7 +26,15 @@ use one_kvm::webrtc::{WebRtcStreamer, WebRtcStreamerConfig}; /// Log level for the application #[derive(Debug, Clone, Copy, Default, ValueEnum)] -enum LogLevel {Error, Warn, #[default] Info, Verbose, Debug, Trace,} +enum LogLevel { + Error, + Warn, + #[default] + Info, + Verbose, + Debug, + Trace, +} /// One-KVM command line arguments #[derive(Parser, Debug)] @@ -82,10 +90,7 @@ async fn main() -> anyhow::Result<()> { CryptoProvider::install_default(ring::default_provider()) .expect("Failed to install rustls crypto provider"); - tracing::info!( - "Starting One-KVM v{}", - env!("CARGO_PKG_VERSION") - ); + tracing::info!("Starting One-KVM v{}", env!("CARGO_PKG_VERSION")); // Determine data directory (CLI arg takes precedence) let data_dir = args.data_dir.unwrap_or_else(get_data_dir); @@ -153,21 +158,37 @@ async fn main() -> anyhow::Result<()> { // Parse video configuration once (avoid duplication) let (video_format, video_resolution) = parse_video_config(&config); - tracing::debug!("Parsed video config: {} @ {}x{}", video_format, video_resolution.width, video_resolution.height); + tracing::debug!( + "Parsed video config: {} @ {}x{}", + video_format, + video_resolution.width, + video_resolution.height + ); // Create video streamer and initialize with config if device is set let streamer = Streamer::new(); streamer.set_event_bus(events.clone()).await; if let Some(ref device_path) = config.video.device { if let Err(e) = streamer - .apply_video_config(device_path, video_format, video_resolution, config.video.fps) + .apply_video_config( + device_path, + video_format, + video_resolution, + config.video.fps, + ) .await { - tracing::warn!("Failed to initialize video with config: {}, will auto-detect", e); + tracing::warn!( + "Failed to initialize video with config: {}, will auto-detect", + e + ); } else { tracing::info!( "Video configured: {} @ {}x{} {}", - device_path, video_resolution.width, video_resolution.height, video_format + device_path, + video_resolution.width, + video_resolution.height, + video_format ); } } @@ -185,8 +206,18 @@ async fn main() -> anyhow::Result<()> { let mut turn_servers = vec![]; // Check if user configured custom servers - let has_custom_stun = config.stream.stun_server.as_ref().map(|s| !s.is_empty()).unwrap_or(false); - let has_custom_turn = config.stream.turn_server.as_ref().map(|s| !s.is_empty()).unwrap_or(false); + let has_custom_stun = config + .stream + .stun_server + .as_ref() + .map(|s| !s.is_empty()) + .unwrap_or(false); + let has_custom_turn = config + .stream + .turn_server + .as_ref() + .map(|s| !s.is_empty()) + .unwrap_or(false); // If no custom servers, use public ICE servers (like RustDesk) if !has_custom_stun && !has_custom_turn { @@ -201,7 +232,9 @@ async fn main() -> anyhow::Result<()> { turn_servers.push(turn); } } else { - tracing::info!("No public ICE servers configured, using host candidates only"); + tracing::info!( + "No public ICE servers configured, using host candidates only" + ); } } else { // Use custom servers @@ 
-214,13 +247,18 @@ async fn main() -> anyhow::Result<()> { if let Some(ref turn) = config.stream.turn_server { if !turn.is_empty() { let username = config.stream.turn_username.clone().unwrap_or_default(); - let credential = config.stream.turn_password.clone().unwrap_or_default(); + let credential = + config.stream.turn_password.clone().unwrap_or_default(); turn_servers.push(one_kvm::webrtc::config::TurnServer::new( turn.clone(), username.clone(), credential, )); - tracing::info!("Using custom TURN server: {} (user: {})", turn, username); + tracing::info!( + "Using custom TURN server: {} (user: {})", + turn, + username + ); } } } @@ -237,7 +275,6 @@ async fn main() -> anyhow::Result<()> { }; tracing::info!("WebRTC streamer created (supports H264, extensible to VP8/VP9/H265)"); - // Create OTG Service (single instance for centralized USB gadget management) let otg_service = Arc::new(OtgService::new()); tracing::info!("OTG Service created"); @@ -285,14 +322,26 @@ async fn main() -> anyhow::Result<()> { if ventoy_resource_dir.exists() { if let Err(e) = ventoy_img::init_resources(&ventoy_resource_dir) { tracing::warn!("Failed to initialize Ventoy resources: {}", e); - tracing::info!("Ventoy resource files should be placed in: {}", ventoy_resource_dir.display()); + tracing::info!( + "Ventoy resource files should be placed in: {}", + ventoy_resource_dir.display() + ); tracing::info!("Required files: {:?}", ventoy_img::required_files()); } else { - tracing::info!("Ventoy resources initialized from {}", ventoy_resource_dir.display()); + tracing::info!( + "Ventoy resources initialized from {}", + ventoy_resource_dir.display() + ); } } else { - tracing::warn!("Ventoy resource directory not found: {}", ventoy_resource_dir.display()); - tracing::info!("Create the directory and place the following files: {:?}", ventoy_img::required_files()); + tracing::warn!( + "Ventoy resource directory not found: {}", + ventoy_resource_dir.display() + ); + tracing::info!( + "Create the directory and place the following files: {:?}", + ventoy_img::required_files() + ); } let controller = MsdController::new( @@ -382,27 +431,42 @@ async fn main() -> anyhow::Result<()> { let (actual_format, actual_resolution, actual_fps) = streamer.current_video_config().await; tracing::info!( "Initial video config from capturer: {}x{} {:?} @ {}fps", - actual_resolution.width, actual_resolution.height, actual_format, actual_fps + actual_resolution.width, + actual_resolution.height, + actual_format, + actual_fps ); - webrtc_streamer.update_video_config(actual_resolution, actual_format, actual_fps).await; + webrtc_streamer + .update_video_config(actual_resolution, actual_format, actual_fps) + .await; webrtc_streamer.set_video_source(frame_tx).await; tracing::info!("WebRTC streamer connected to video frame source"); } else { - tracing::warn!("Video capturer not ready, WebRTC will connect to frame source when available"); + tracing::warn!( + "Video capturer not ready, WebRTC will connect to frame source when available" + ); } // Create video stream manager (unified MJPEG/WebRTC management) // Use with_webrtc_streamer to ensure we use the same WebRtcStreamer instance - let stream_manager = VideoStreamManager::with_webrtc_streamer(streamer.clone(), webrtc_streamer.clone()); + let stream_manager = + VideoStreamManager::with_webrtc_streamer(streamer.clone(), webrtc_streamer.clone()); stream_manager.set_event_bus(events.clone()).await; stream_manager.set_config_store(config_store.clone()).await; // Initialize stream manager with configured mode 
let initial_mode = config.stream.mode.clone(); if let Err(e) = stream_manager.init_with_mode(initial_mode.clone()).await { - tracing::warn!("Failed to initialize stream manager with mode {:?}: {}", initial_mode, e); + tracing::warn!( + "Failed to initialize stream manager with mode {:?}: {}", + initial_mode, + e + ); } else { - tracing::info!("Video stream manager initialized with mode: {:?}", initial_mode); + tracing::info!( + "Video stream manager initialized with mode: {:?}", + initial_mode + ); } // Create RustDesk service (optional, based on config) @@ -421,7 +485,9 @@ async fn main() -> anyhow::Result<()> { Some(Arc::new(service)) } else { if config.rustdesk.enabled { - tracing::warn!("RustDesk enabled but configuration is incomplete (missing server or credentials)"); + tracing::warn!( + "RustDesk enabled but configuration is incomplete (missing server or credentials)" + ); } else { tracing::info!("RustDesk disabled in configuration"); } @@ -458,7 +524,8 @@ async fn main() -> anyhow::Result<()> { cfg.rustdesk.public_key = updated_config.public_key.clone(); cfg.rustdesk.private_key = updated_config.private_key.clone(); cfg.rustdesk.signing_public_key = updated_config.signing_public_key.clone(); - cfg.rustdesk.signing_private_key = updated_config.signing_private_key.clone(); + cfg.rustdesk.signing_private_key = + updated_config.signing_private_key.clone(); cfg.rustdesk.uuid = updated_config.uuid.clone(); }) .await @@ -542,8 +609,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!("Starting HTTPS server on {}", bind_addr); - let server = axum_server::bind_rustls(bind_addr, tls_config) - .serve(app.into_make_service()); + let server = axum_server::bind_rustls(bind_addr, tls_config).serve(app.into_make_service()); tokio::select! { _ = shutdown_signal => { @@ -600,8 +666,8 @@ fn init_logging(level: LogLevel, verbose_count: u8) { }; // Environment variable takes highest priority - let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| filter.into()); + let env_filter = + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| filter.into()); tracing_subscriber::registry() .with(env_filter) @@ -662,7 +728,8 @@ fn spawn_device_info_broadcaster(state: Arc, events: Arc) { loop { // Use timeout to handle pending broadcasts let recv_result = if pending_broadcast { - let remaining = DEBOUNCE_MS.saturating_sub(last_broadcast.elapsed().as_millis() as u64); + let remaining = + DEBOUNCE_MS.saturating_sub(last_broadcast.elapsed().as_millis() as u64); tokio::time::timeout(Duration::from_millis(remaining), rx.recv()).await } else { Ok(rx.recv().await) @@ -674,6 +741,7 @@ fn spawn_device_info_broadcaster(state: Arc, events: Arc) { event, SystemEvent::StreamStateChanged { .. } | SystemEvent::StreamConfigApplied { .. } + | SystemEvent::StreamModeReady { .. } | SystemEvent::HidStateChanged { .. } | SystemEvent::MsdStateChanged { .. } | SystemEvent::AtxStateChanged { .. 
} @@ -706,7 +774,10 @@ fn spawn_device_info_broadcaster(state: Arc, events: Arc) { } }); - tracing::info!("DeviceInfo broadcaster task started (debounce: {}ms)", DEBOUNCE_MS); + tracing::info!( + "DeviceInfo broadcaster task started (debounce: {}ms)", + DEBOUNCE_MS + ); } /// Clean up subsystems on shutdown diff --git a/src/msd/controller.rs b/src/msd/controller.rs index 5db2296f..0b8dc676 100644 --- a/src/msd/controller.rs +++ b/src/msd/controller.rs @@ -99,7 +99,10 @@ impl MsdController { initialized: true, path: self.drive_path.clone(), }); - debug!("Found existing virtual drive: {}", self.drive_path.display()); + debug!( + "Found existing virtual drive: {}", + self.drive_path.display() + ); } } @@ -146,7 +149,12 @@ impl MsdController { /// * `image` - Image info to mount /// * `cdrom` - Mount as CD-ROM (read-only, removable) /// * `read_only` - Mount as read-only - pub async fn connect_image(&self, image: &ImageInfo, cdrom: bool, read_only: bool) -> Result<()> { + pub async fn connect_image( + &self, + image: &ImageInfo, + cdrom: bool, + read_only: bool, + ) -> Result<()> { // Acquire operation lock to prevent concurrent operations let _op_guard = self.operation_lock.write().await; @@ -154,7 +162,9 @@ impl MsdController { if !state.available { let err = AppError::Internal("MSD not available".to_string()); - self.monitor.report_error("MSD not available", "not_available").await; + self.monitor + .report_error("MSD not available", "not_available") + .await; return Err(err); } @@ -167,7 +177,9 @@ impl MsdController { // Verify image exists if !image.path.exists() { let error_msg = format!("Image file not found: {}", image.path.display()); - self.monitor.report_error(&error_msg, "image_not_found").await; + self.monitor + .report_error(&error_msg, "image_not_found") + .await; return Err(AppError::Internal(error_msg)); } @@ -182,12 +194,16 @@ impl MsdController { if let Some(ref msd) = *self.msd_function.read().await { if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await { let error_msg = format!("Failed to configure LUN: {}", e); - self.monitor.report_error(&error_msg, "configfs_error").await; + self.monitor + .report_error(&error_msg, "configfs_error") + .await; return Err(e); } } else { let err = AppError::Internal("MSD function not initialized".to_string()); - self.monitor.report_error("MSD function not initialized", "not_initialized").await; + self.monitor + .report_error("MSD function not initialized", "not_initialized") + .await; return Err(err); } @@ -236,7 +252,9 @@ impl MsdController { if !state.available { let err = AppError::Internal("MSD not available".to_string()); - self.monitor.report_error("MSD not available", "not_available").await; + self.monitor + .report_error("MSD not available", "not_available") + .await; return Err(err); } @@ -248,10 +266,11 @@ impl MsdController { // Check drive exists if !self.drive_path.exists() { - let err = AppError::Internal( - "Virtual drive not initialized. Call init first.".to_string(), - ); - self.monitor.report_error("Virtual drive not initialized", "drive_not_found").await; + let err = + AppError::Internal("Virtual drive not initialized. 
Call init first.".to_string()); + self.monitor + .report_error("Virtual drive not initialized", "drive_not_found") + .await; return Err(err); } @@ -262,12 +281,16 @@ impl MsdController { if let Some(ref msd) = *self.msd_function.read().await { if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await { let error_msg = format!("Failed to configure LUN: {}", e); - self.monitor.report_error(&error_msg, "configfs_error").await; + self.monitor + .report_error(&error_msg, "configfs_error") + .await; return Err(e); } } else { let err = AppError::Internal("MSD function not initialized".to_string()); - self.monitor.report_error("MSD function not initialized", "not_initialized").await; + self.monitor + .report_error("MSD function not initialized", "not_initialized") + .await; return Err(err); } @@ -381,12 +404,9 @@ impl MsdController { } // Extract filename for initial response - let display_filename = filename.clone().unwrap_or_else(|| { - url.rsplit('/') - .next() - .unwrap_or("download") - .to_string() - }); + let display_filename = filename + .clone() + .unwrap_or_else(|| url.rsplit('/').next().unwrap_or("download").to_string()); // Create initial progress let initial_progress = DownloadProgress { diff --git a/src/msd/image.rs b/src/msd/image.rs index da524d66..d08b7a18 100644 --- a/src/msd/image.rs +++ b/src/msd/image.rs @@ -42,9 +42,8 @@ impl ImageManager { /// Ensure images directory exists pub fn ensure_dir(&self) -> Result<()> { - fs::create_dir_all(&self.images_path).map_err(|e| { - AppError::Internal(format!("Failed to create images directory: {}", e)) - })?; + fs::create_dir_all(&self.images_path) + .map_err(|e| AppError::Internal(format!("Failed to create images directory: {}", e)))?; Ok(()) } @@ -54,9 +53,9 @@ impl ImageManager { let mut images = Vec::new(); - for entry in fs::read_dir(&self.images_path).map_err(|e| { - AppError::Internal(format!("Failed to read images directory: {}", e)) - })? { + for entry in fs::read_dir(&self.images_path) + .map_err(|e| AppError::Internal(format!("Failed to read images directory: {}", e)))? + { let entry = entry.map_err(|e| { AppError::Internal(format!("Failed to read directory entry: {}", e)) })?; @@ -146,9 +145,8 @@ impl ImageManager { ))); } - let mut file = File::create(&path).map_err(|e| { - AppError::Internal(format!("Failed to create image file: {}", e)) - })?; + let mut file = File::create(&path) + .map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?; file.write_all(data).map_err(|e| { // Try to clean up on error @@ -193,9 +191,8 @@ impl ImageManager { } // Create file and copy data - let mut file = File::create(&path).map_err(|e| { - AppError::Internal(format!("Failed to create image file: {}", e)) - })?; + let mut file = File::create(&path) + .map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?; let bytes_written = io::copy(reader, &mut file).map_err(|e| { let _ = fs::remove_file(&path); @@ -244,9 +241,11 @@ impl ImageManager { let mut bytes_written: u64 = 0; // Stream chunks directly to disk - while let Some(chunk) = field.chunk().await.map_err(|e| { - AppError::Internal(format!("Failed to read upload chunk: {}", e)) - })? { + while let Some(chunk) = field + .chunk() + .await + .map_err(|e| AppError::Internal(format!("Failed to read upload chunk: {}", e)))? 
+ { // Check size limit bytes_written += chunk.len() as u64; if bytes_written > MAX_IMAGE_SIZE { @@ -260,15 +259,15 @@ impl ImageManager { } // Write chunk to file - file.write_all(&chunk).await.map_err(|e| { - AppError::Internal(format!("Failed to write chunk: {}", e)) - })?; + file.write_all(&chunk) + .await + .map_err(|e| AppError::Internal(format!("Failed to write chunk: {}", e)))?; } // Flush and close file - file.flush().await.map_err(|e| { - AppError::Internal(format!("Failed to flush file: {}", e)) - })?; + file.flush() + .await + .map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?; drop(file); // Move temp file to final location @@ -279,7 +278,10 @@ impl ImageManager { AppError::Internal(format!("Failed to rename temp file: {}", e)) })?; - info!("Created image (streaming): {} ({} bytes)", name, bytes_written); + info!( + "Created image (streaming): {} ({} bytes)", + name, bytes_written + ); self.get_by_name(&name) } @@ -288,9 +290,8 @@ impl ImageManager { pub fn delete(&self, id: &str) -> Result<()> { let image = self.get(id)?; - fs::remove_file(&image.path).map_err(|e| { - AppError::Internal(format!("Failed to delete image: {}", e)) - })?; + fs::remove_file(&image.path) + .map_err(|e| AppError::Internal(format!("Failed to delete image: {}", e)))?; info!("Deleted image: {}", image.name); Ok(()) @@ -304,9 +305,8 @@ impl ImageManager { return Err(AppError::NotFound(format!("Image not found: {}", name))); } - fs::remove_file(&path).map_err(|e| { - AppError::Internal(format!("Failed to delete image: {}", e)) - })?; + fs::remove_file(&path) + .map_err(|e| AppError::Internal(format!("Failed to delete image: {}", e)))?; info!("Deleted image: {}", name); Ok(()) @@ -414,7 +414,9 @@ impl ImageManager { }; if final_filename.is_empty() { - return Err(AppError::BadRequest("Could not determine filename".to_string())); + return Err(AppError::BadRequest( + "Could not determine filename".to_string(), + )); } // Check if file already exists @@ -468,16 +470,14 @@ impl ImageManager { progress_callback(0, content_length); while let Some(chunk_result) = stream.next().await { - let chunk = chunk_result - .map_err(|e| AppError::Internal(format!("Download error: {}", e)))?; + let chunk = + chunk_result.map_err(|e| AppError::Internal(format!("Download error: {}", e)))?; - file.write_all(&chunk) - .await - .map_err(|e| { - // Cleanup on error - let _ = std::fs::remove_file(&temp_path); - AppError::Internal(format!("Failed to write data: {}", e)) - })?; + file.write_all(&chunk).await.map_err(|e| { + // Cleanup on error + let _ = std::fs::remove_file(&temp_path); + AppError::Internal(format!("Failed to write data: {}", e)) + })?; downloaded += chunk.len() as u64; diff --git a/src/msd/mod.rs b/src/msd/mod.rs index bfc0ff13..1359209f 100644 --- a/src/msd/mod.rs +++ b/src/msd/mod.rs @@ -15,19 +15,19 @@ //! 
``` pub mod controller; -pub mod ventoy_drive; pub mod image; pub mod monitor; pub mod types; +pub mod ventoy_drive; pub use controller::MsdController; -pub use ventoy_drive::VentoyDrive; pub use image::ImageManager; pub use monitor::{MsdHealthMonitor, MsdHealthStatus, MsdMonitorConfig}; pub use types::{ DownloadProgress, DownloadStatus, DriveFile, DriveInfo, DriveInitRequest, ImageDownloadRequest, ImageInfo, MsdConnectRequest, MsdMode, MsdState, }; +pub use ventoy_drive::VentoyDrive; // Re-export from otg module for backward compatibility pub use crate::otg::{MsdFunction, MsdLunConfig}; diff --git a/src/msd/monitor.rs b/src/msd/monitor.rs index aa1e09fc..a9f80109 100644 --- a/src/msd/monitor.rs +++ b/src/msd/monitor.rs @@ -120,7 +120,10 @@ impl MsdHealthMonitor { // Log with throttling (always log if error type changed) let throttle_key = format!("msd_{}", error_code); if error_changed || self.throttler.should_log(&throttle_key) { - warn!("MSD error: {} (code: {}, count: {})", reason, error_code, count); + warn!( + "MSD error: {} (code: {}, count: {})", + reason, error_code, count + ); } // Update last error code diff --git a/src/msd/ventoy_drive.rs b/src/msd/ventoy_drive.rs index 04d03be2..8839a6e1 100644 --- a/src/msd/ventoy_drive.rs +++ b/src/msd/ventoy_drive.rs @@ -71,13 +71,11 @@ impl VentoyDrive { // Run Ventoy creation in blocking task let info = tokio::task::spawn_blocking(move || { - VentoyImage::create(&path, &size_str, DEFAULT_LABEL) - .map_err(ventoy_to_app_error)?; + VentoyImage::create(&path, &size_str, DEFAULT_LABEL).map_err(ventoy_to_app_error)?; // Get file metadata for DriveInfo - let metadata = std::fs::metadata(&path).map_err(|e| { - AppError::Internal(format!("Failed to read drive metadata: {}", e)) - })?; + let metadata = std::fs::metadata(&path) + .map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?; Ok::(DriveInfo { size: metadata.len(), @@ -104,16 +102,13 @@ impl VentoyDrive { let _lock = self.lock.read().await; // Read lock for info query tokio::task::spawn_blocking(move || { - let metadata = std::fs::metadata(&path).map_err(|e| { - AppError::Internal(format!("Failed to read drive metadata: {}", e)) - })?; + let metadata = std::fs::metadata(&path) + .map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?; // Open image to get file list and calculate used space let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?; - let files = image - .list_files_recursive() - .map_err(ventoy_to_app_error)?; + let files = image.list_files_recursive().map_err(ventoy_to_app_error)?; let used: u64 = files .iter() @@ -190,9 +185,11 @@ impl VentoyDrive { let mut bytes_written: u64 = 0; - while let Some(chunk) = field.chunk().await.map_err(|e| { - AppError::Internal(format!("Failed to read upload chunk: {}", e)) - })? { + while let Some(chunk) = field + .chunk() + .await + .map_err(|e| AppError::Internal(format!("Failed to read upload chunk: {}", e)))? + { bytes_written += chunk.len() as u64; tokio::io::AsyncWriteExt::write_all(&mut temp_file, &chunk) .await @@ -248,9 +245,7 @@ impl VentoyDrive { tokio::task::spawn_blocking(move || { let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?; - image - .read_file(&file_path) - .map_err(ventoy_to_app_error) + image.read_file(&file_path).map_err(ventoy_to_app_error) }) .await .map_err(|e| AppError::Internal(format!("Task join error: {}", e)))? 
@@ -321,7 +316,8 @@ impl VentoyDrive { let lock = self.lock.clone(); // Create a channel for streaming data - let (tx, rx) = tokio::sync::mpsc::channel::>(8); + let (tx, rx) = + tokio::sync::mpsc::channel::>(8); // Spawn blocking task to read and send chunks tokio::task::spawn_blocking(move || { @@ -404,20 +400,14 @@ fn ventoy_to_app_error(err: VentoyError) -> AppError { match err { VentoyError::Io(e) => AppError::Io(e), VentoyError::InvalidSize(s) => AppError::BadRequest(format!("Invalid size: {}", s)), - VentoyError::SizeParseError(s) => { - AppError::BadRequest(format!("Size parse error: {}", s)) - } - VentoyError::FilesystemError(s) => { - AppError::Internal(format!("Filesystem error: {}", s)) - } + VentoyError::SizeParseError(s) => AppError::BadRequest(format!("Size parse error: {}", s)), + VentoyError::FilesystemError(s) => AppError::Internal(format!("Filesystem error: {}", s)), VentoyError::ImageError(s) => AppError::Internal(format!("Image error: {}", s)), VentoyError::FileNotFound(s) => AppError::NotFound(format!("File not found: {}", s)), VentoyError::ResourceNotFound(s) => { AppError::Internal(format!("Resource not found: {}", s)) } - VentoyError::PartitionError(s) => { - AppError::Internal(format!("Partition error: {}", s)) - } + VentoyError::PartitionError(s) => AppError::Internal(format!("Partition error: {}", s)), } } @@ -481,7 +471,8 @@ impl std::io::Write for ChannelWriter { let space = STREAM_CHUNK_SIZE - self.buffer.len(); let to_copy = std::cmp::min(space, buf.len() - written); - self.buffer.extend_from_slice(&buf[written..written + to_copy]); + self.buffer + .extend_from_slice(&buf[written..written + to_copy]); written += to_copy; if self.buffer.len() >= STREAM_CHUNK_SIZE { @@ -512,10 +503,7 @@ mod tests { use tempfile::TempDir; /// Path to ventoy resources directory - static RESOURCE_DIR: &str = concat!( - env!("CARGO_MANIFEST_DIR"), - "/../ventoy-img-rs/resources" - ); + static RESOURCE_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../ventoy-img-rs/resources"); /// Initialize ventoy resources once fn init_ventoy_resources() -> bool { @@ -561,7 +549,10 @@ mod tests { if !output.status.success() { return Err(std::io::Error::new( std::io::ErrorKind::Other, - format!("xz decompress failed: {}", String::from_utf8_lossy(&output.stderr)), + format!( + "xz decompress failed: {}", + String::from_utf8_lossy(&output.stderr) + ), )); } diff --git a/src/otg/configfs.rs b/src/otg/configfs.rs index 11cae1f4..5e7fd493 100644 --- a/src/otg/configfs.rs +++ b/src/otg/configfs.rs @@ -109,15 +109,25 @@ pub fn read_file(path: &Path) -> Result { /// Create directory if not exists pub fn create_dir(path: &Path) -> Result<()> { - fs::create_dir_all(path) - .map_err(|e| AppError::Internal(format!("Failed to create directory {}: {}", path.display(), e))) + fs::create_dir_all(path).map_err(|e| { + AppError::Internal(format!( + "Failed to create directory {}: {}", + path.display(), + e + )) + }) } /// Remove directory pub fn remove_dir(path: &Path) -> Result<()> { if path.exists() { - fs::remove_dir(path) - .map_err(|e| AppError::Internal(format!("Failed to remove directory {}: {}", path.display(), e)))?; + fs::remove_dir(path).map_err(|e| { + AppError::Internal(format!( + "Failed to remove directory {}: {}", + path.display(), + e + )) + })?; } Ok(()) } @@ -125,14 +135,21 @@ pub fn remove_dir(path: &Path) -> Result<()> { /// Remove file pub fn remove_file(path: &Path) -> Result<()> { if path.exists() { - fs::remove_file(path) - .map_err(|e| AppError::Internal(format!("Failed to remove file 
{}: {}", path.display(), e)))?; + fs::remove_file(path).map_err(|e| { + AppError::Internal(format!("Failed to remove file {}: {}", path.display(), e)) + })?; } Ok(()) } /// Create symlink pub fn create_symlink(src: &Path, dest: &Path) -> Result<()> { - std::os::unix::fs::symlink(src, dest) - .map_err(|e| AppError::Internal(format!("Failed to create symlink {} -> {}: {}", dest.display(), src.display(), e))) + std::os::unix::fs::symlink(src, dest).map_err(|e| { + AppError::Internal(format!( + "Failed to create symlink {} -> {}: {}", + dest.display(), + src.display(), + e + )) + }) } diff --git a/src/otg/hid.rs b/src/otg/hid.rs index f172fac3..38283b79 100644 --- a/src/otg/hid.rs +++ b/src/otg/hid.rs @@ -3,7 +3,9 @@ use std::path::{Path, PathBuf}; use tracing::debug; -use super::configfs::{create_dir, create_symlink, remove_dir, remove_file, write_bytes, write_file}; +use super::configfs::{ + create_dir, create_symlink, remove_dir, remove_file, write_bytes, write_file, +}; use super::function::{FunctionMeta, GadgetFunction}; use super::report_desc::{CONSUMER_CONTROL, KEYBOARD, MOUSE_ABSOLUTE, MOUSE_RELATIVE}; use crate::error::Result; @@ -39,20 +41,20 @@ impl HidFunctionType { /// Get HID protocol pub fn protocol(&self) -> u8 { match self { - HidFunctionType::Keyboard => 1, // Keyboard - HidFunctionType::MouseRelative => 2, // Mouse - HidFunctionType::MouseAbsolute => 2, // Mouse - HidFunctionType::ConsumerControl => 0, // None + HidFunctionType::Keyboard => 1, // Keyboard + HidFunctionType::MouseRelative => 2, // Mouse + HidFunctionType::MouseAbsolute => 2, // Mouse + HidFunctionType::ConsumerControl => 0, // None } } /// Get HID subclass pub fn subclass(&self) -> u8 { match self { - HidFunctionType::Keyboard => 1, // Boot interface - HidFunctionType::MouseRelative => 1, // Boot interface - HidFunctionType::MouseAbsolute => 0, // No boot interface - HidFunctionType::ConsumerControl => 0, // No boot interface + HidFunctionType::Keyboard => 1, // Boot interface + HidFunctionType::MouseRelative => 1, // Boot interface + HidFunctionType::MouseAbsolute => 0, // No boot interface + HidFunctionType::ConsumerControl => 0, // No boot interface } } @@ -169,14 +171,27 @@ impl GadgetFunction for HidFunction { create_dir(&func_path)?; // Set HID parameters - write_file(&func_path.join("protocol"), &self.func_type.protocol().to_string())?; - write_file(&func_path.join("subclass"), &self.func_type.subclass().to_string())?; - write_file(&func_path.join("report_length"), &self.func_type.report_length().to_string())?; + write_file( + &func_path.join("protocol"), + &self.func_type.protocol().to_string(), + )?; + write_file( + &func_path.join("subclass"), + &self.func_type.subclass().to_string(), + )?; + write_file( + &func_path.join("report_length"), + &self.func_type.report_length().to_string(), + )?; // Write report descriptor write_bytes(&func_path.join("report_desc"), self.func_type.report_desc())?; - debug!("Created HID function: {} at {}", self.name(), func_path.display()); + debug!( + "Created HID function: {} at {}", + self.name(), + func_path.display() + ); Ok(()) } diff --git a/src/otg/manager.rs b/src/otg/manager.rs index 6840d639..7e50949d 100644 --- a/src/otg/manager.rs +++ b/src/otg/manager.rs @@ -7,7 +7,8 @@ use tracing::{debug, error, info, warn}; use super::configfs::{ create_dir, find_udc, is_configfs_available, remove_dir, write_file, CONFIGFS_PATH, - DEFAULT_GADGET_NAME, DEFAULT_USB_BCD_DEVICE, USB_BCD_USB, DEFAULT_USB_PRODUCT_ID, DEFAULT_USB_VENDOR_ID, + DEFAULT_GADGET_NAME, 
DEFAULT_USB_BCD_DEVICE, DEFAULT_USB_PRODUCT_ID, DEFAULT_USB_VENDOR_ID, + USB_BCD_USB, }; use super::endpoint::{EndpointAllocator, DEFAULT_MAX_ENDPOINTS}; use super::function::{FunctionMeta, GadgetFunction}; @@ -77,7 +78,11 @@ impl OtgGadgetManager { } /// Create a new gadget manager with custom descriptor - pub fn with_descriptor(gadget_name: &str, max_endpoints: u8, descriptor: GadgetDescriptor) -> Self { + pub fn with_descriptor( + gadget_name: &str, + max_endpoints: u8, + descriptor: GadgetDescriptor, + ) -> Self { let gadget_path = PathBuf::from(CONFIGFS_PATH).join(gadget_name); let config_path = gadget_path.join("configs/c.1"); @@ -303,10 +308,22 @@ impl OtgGadgetManager { /// Set USB device descriptors fn set_device_descriptors(&self) -> Result<()> { - write_file(&self.gadget_path.join("idVendor"), &format!("0x{:04x}", self.descriptor.vendor_id))?; - write_file(&self.gadget_path.join("idProduct"), &format!("0x{:04x}", self.descriptor.product_id))?; - write_file(&self.gadget_path.join("bcdDevice"), &format!("0x{:04x}", self.descriptor.device_version))?; - write_file(&self.gadget_path.join("bcdUSB"), &format!("0x{:04x}", USB_BCD_USB))?; + write_file( + &self.gadget_path.join("idVendor"), + &format!("0x{:04x}", self.descriptor.vendor_id), + )?; + write_file( + &self.gadget_path.join("idProduct"), + &format!("0x{:04x}", self.descriptor.product_id), + )?; + write_file( + &self.gadget_path.join("bcdDevice"), + &format!("0x{:04x}", self.descriptor.device_version), + )?; + write_file( + &self.gadget_path.join("bcdUSB"), + &format!("0x{:04x}", USB_BCD_USB), + )?; write_file(&self.gadget_path.join("bDeviceClass"), "0x00")?; // Composite device write_file(&self.gadget_path.join("bDeviceSubClass"), "0x00")?; write_file(&self.gadget_path.join("bDeviceProtocol"), "0x00")?; @@ -319,8 +336,14 @@ impl OtgGadgetManager { let strings_path = self.gadget_path.join("strings/0x409"); create_dir(&strings_path)?; - write_file(&strings_path.join("serialnumber"), &self.descriptor.serial_number)?; - write_file(&strings_path.join("manufacturer"), &self.descriptor.manufacturer)?; + write_file( + &strings_path.join("serialnumber"), + &self.descriptor.serial_number, + )?; + write_file( + &strings_path.join("manufacturer"), + &self.descriptor.manufacturer, + )?; write_file(&strings_path.join("product"), &self.descriptor.product)?; debug!("Created USB strings"); Ok(()) @@ -349,7 +372,10 @@ impl OtgGadgetManager { /// Get endpoint usage info pub fn endpoint_info(&self) -> (u8, u8) { - (self.endpoint_allocator.used(), self.endpoint_allocator.max()) + ( + self.endpoint_allocator.used(), + self.endpoint_allocator.max(), + ) } /// Get gadget path diff --git a/src/otg/msd.rs b/src/otg/msd.rs index 8d644f66..bd37fa71 100644 --- a/src/otg/msd.rs +++ b/src/otg/msd.rs @@ -161,7 +161,10 @@ impl MsdFunction { // Write only changed attributes let cdrom_changed = current_cdrom != new_cdrom; if cdrom_changed { - debug!("Updating LUN {} cdrom: {} -> {}", lun, current_cdrom, new_cdrom); + debug!( + "Updating LUN {} cdrom: {} -> {}", + lun, current_cdrom, new_cdrom + ); write_file(&lun_path.join("cdrom"), new_cdrom)?; } if current_ro != new_ro { @@ -169,11 +172,17 @@ impl MsdFunction { write_file(&lun_path.join("ro"), new_ro)?; } if current_removable != new_removable { - debug!("Updating LUN {} removable: {} -> {}", lun, current_removable, new_removable); + debug!( + "Updating LUN {} removable: {} -> {}", + lun, current_removable, new_removable + ); write_file(&lun_path.join("removable"), new_removable)?; } if current_nofua != 
new_nofua { - debug!("Updating LUN {} nofua: {} -> {}", lun, current_nofua, new_nofua); + debug!( + "Updating LUN {} nofua: {} -> {}", + lun, current_nofua, new_nofua + ); write_file(&lun_path.join("nofua"), new_nofua)?; } @@ -258,11 +267,17 @@ impl MsdFunction { // forced_eject forcibly detaches the backing file regardless of host state let forced_eject_path = lun_path.join("forced_eject"); if forced_eject_path.exists() { - debug!("Using forced_eject to disconnect LUN {} at {:?}", lun, forced_eject_path); + debug!( + "Using forced_eject to disconnect LUN {} at {:?}", + lun, forced_eject_path + ); match write_file(&forced_eject_path, "1") { Ok(_) => debug!("forced_eject write succeeded"), Err(e) => { - warn!("forced_eject write failed: {}, falling back to clearing file", e); + warn!( + "forced_eject write failed: {}, falling back to clearing file", + e + ); write_file(&lun_path.join("file"), "")?; } } diff --git a/src/otg/report_desc.rs b/src/otg/report_desc.rs index 1dac7df0..45871e2b 100644 --- a/src/otg/report_desc.rs +++ b/src/otg/report_desc.rs @@ -135,17 +135,17 @@ pub const MOUSE_ABSOLUTE: &[u8] = &[ /// [0-1] Consumer Control Usage (16-bit little-endian) /// Supports: Play/Pause, Stop, Next/Prev Track, Mute, Volume Up/Down, etc. pub const CONSUMER_CONTROL: &[u8] = &[ - 0x05, 0x0C, // Usage Page (Consumer) - 0x09, 0x01, // Usage (Consumer Control) - 0xA1, 0x01, // Collection (Application) - 0x15, 0x00, // Logical Minimum (0) + 0x05, 0x0C, // Usage Page (Consumer) + 0x09, 0x01, // Usage (Consumer Control) + 0xA1, 0x01, // Collection (Application) + 0x15, 0x00, // Logical Minimum (0) 0x26, 0xFF, 0x03, // Logical Maximum (1023) - 0x19, 0x00, // Usage Minimum (0) + 0x19, 0x00, // Usage Minimum (0) 0x2A, 0xFF, 0x03, // Usage Maximum (1023) - 0x75, 0x10, // Report Size (16) - 0x95, 0x01, // Report Count (1) - 0x81, 0x00, // Input (Data, Array) - 0xC0, // End Collection + 0x75, 0x10, // Report Size (16) + 0x95, 0x01, // Report Count (1) + 0x81, 0x00, // Input (Data, Array) + 0xC0, // End Collection ]; #[cfg(test)] diff --git a/src/otg/service.rs b/src/otg/service.rs index f64948f8..38755bc8 100644 --- a/src/otg/service.rs +++ b/src/otg/service.rs @@ -27,8 +27,8 @@ use tracing::{debug, info, warn}; use super::manager::{wait_for_hid_devices, GadgetDescriptor, OtgGadgetManager}; use super::msd::MsdFunction; -use crate::error::{AppError, Result}; use crate::config::OtgDescriptorConfig; +use crate::error::{AppError, Result}; /// Bitflags for requested functions (lock-free) const FLAG_HID: u8 = 0b01; @@ -254,8 +254,9 @@ impl OtgService { // Get MSD function let msd = self.msd_function.read().await; - msd.clone() - .ok_or_else(|| AppError::Internal("MSD function not set after gadget setup".to_string())) + msd.clone().ok_or_else(|| { + AppError::Internal("MSD function not set after gadget setup".to_string()) + }) } /// Disable MSD function @@ -465,7 +466,10 @@ impl OtgService { device_version: super::configfs::DEFAULT_USB_BCD_DEVICE, manufacturer: config.manufacturer.clone(), product: config.product.clone(), - serial_number: config.serial_number.clone().unwrap_or_else(|| "0123456789".to_string()), + serial_number: config + .serial_number + .clone() + .unwrap_or_else(|| "0123456789".to_string()), }; // Update stored descriptor diff --git a/src/rustdesk/bytes_codec.rs b/src/rustdesk/bytes_codec.rs index ec210bc7..a592f7f5 100644 --- a/src/rustdesk/bytes_codec.rs +++ b/src/rustdesk/bytes_codec.rs @@ -34,7 +34,10 @@ pub fn encode_frame(data: &[u8]) -> io::Result> { let h = ((len << 2) as u32) | 0x3; 
buf.extend_from_slice(&h.to_le_bytes()); } else { - return Err(io::Error::new(io::ErrorKind::InvalidInput, "Message too large")); + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Message too large", + )); } buf.extend_from_slice(data); @@ -79,7 +82,10 @@ pub async fn read_frame(reader: &mut R) -> io::Result MAX_PACKET_LENGTH { - return Err(io::Error::new(io::ErrorKind::InvalidData, "Message too large")); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Message too large", + )); } // Read message body @@ -133,7 +139,10 @@ pub fn encode_frame_into(data: &[u8], buf: &mut BytesMut) -> io::Result<()> { } else if len <= MAX_PACKET_LENGTH { buf.put_u32_le(((len << 2) as u32) | 0x3); } else { - return Err(io::Error::new(io::ErrorKind::InvalidInput, "Message too large")); + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Message too large", + )); } buf.extend_from_slice(data); @@ -216,7 +225,10 @@ impl BytesCodec { n >>= 2; if n > self.max_packet_length { - return Err(io::Error::new(io::ErrorKind::InvalidData, "Message too large")); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Message too large", + )); } src.advance(head_len); @@ -245,7 +257,10 @@ impl BytesCodec { } else if len <= MAX_PACKET_LENGTH { buf.put_u32_le(((len << 2) as u32) | 0x3); } else { - return Err(io::Error::new(io::ErrorKind::InvalidInput, "Message too large")); + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Message too large", + )); } buf.extend(data); diff --git a/src/rustdesk/config.rs b/src/rustdesk/config.rs index a51649ec..3a72e792 100644 --- a/src/rustdesk/config.rs +++ b/src/rustdesk/config.rs @@ -116,9 +116,9 @@ impl RustDeskConfig { /// Get the UUID bytes (returns None if not set) pub fn get_uuid_bytes(&self) -> Option<[u8; 16]> { - self.uuid.as_ref().and_then(|s| { - uuid::Uuid::parse_str(s).ok().map(|u| *u.as_bytes()) - }) + self.uuid + .as_ref() + .and_then(|s| uuid::Uuid::parse_str(s).ok().map(|u| *u.as_bytes())) } /// Get the rendezvous server address with default port @@ -135,26 +135,29 @@ impl RustDeskConfig { /// Get the relay server address with default port pub fn relay_addr(&self) -> Option { - self.relay_server.as_ref().map(|s| { - if s.contains(':') { - s.clone() - } else { - format!("{}:21117", s) - } - }).or_else(|| { - // Default: same host as rendezvous server - let server = &self.rendezvous_server; - if !server.is_empty() { - let host = server.split(':').next().unwrap_or(""); - if !host.is_empty() { - Some(format!("{}:21117", host)) + self.relay_server + .as_ref() + .map(|s| { + if s.contains(':') { + s.clone() + } else { + format!("{}:21117", s) + } + }) + .or_else(|| { + // Default: same host as rendezvous server + let server = &self.rendezvous_server; + if !server.is_empty() { + let host = server.split(':').next().unwrap_or(""); + if !host.is_empty() { + Some(format!("{}:21117", host)) + } else { + None + } } else { None } - } else { - None - } - }) + }) } } @@ -222,7 +225,10 @@ mod tests { // Explicit relay server config.relay_server = Some("relay.example.com".to_string()); - assert_eq!(config.relay_addr(), Some("relay.example.com:21117".to_string())); + assert_eq!( + config.relay_addr(), + Some("relay.example.com:21117".to_string()) + ); // No rendezvous server, relay is None config.rendezvous_server = String::new(); diff --git a/src/rustdesk/connection.rs b/src/rustdesk/connection.rs index 4c75c55b..5957a528 100644 --- a/src/rustdesk/connection.rs +++ b/src/rustdesk/connection.rs @@ -13,16 +13,16 @@ use std::sync::Arc; use 
std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use bytes::{Bytes, BytesMut}; -use sodiumoxide::crypto::box_; use parking_lot::RwLock; use protobuf::Message as ProtobufMessage; -use tokio::net::TcpStream; +use sodiumoxide::crypto::box_; use tokio::net::tcp::OwnedWriteHalf; +use tokio::net::TcpStream; use tokio::sync::{broadcast, mpsc, Mutex}; use tracing::{debug, error, info, warn}; use crate::audio::AudioController; -use crate::hid::{HidController, KeyboardEvent, KeyEventType, KeyboardModifiers}; +use crate::hid::{HidController, KeyEventType, KeyboardEvent, KeyboardModifiers}; use crate::video::encoder::registry::{EncoderRegistry, VideoEncoderType}; use crate::video::encoder::BitratePreset; use crate::video::stream_manager::VideoStreamManager; @@ -33,10 +33,9 @@ use super::crypto::{self, KeyPair, SigningKeyPair}; use super::frame_adapters::{AudioFrameAdapter, VideoCodec, VideoFrameAdapter}; use super::hid_adapter::{convert_key_event, convert_mouse_event, mouse_type}; use super::protocol::{ - message, misc, login_response, - KeyEvent, MouseEvent, Clipboard, Misc, LoginRequest, LoginResponse, PeerInfo, - IdPk, SignedId, Hash, TestDelay, ControlKey, - decode_message, HbbMessage, DisplayInfo, SupportedEncoding, OptionMessage, PublicKey, + decode_message, login_response, message, misc, Clipboard, ControlKey, DisplayInfo, Hash, + HbbMessage, IdPk, KeyEvent, LoginRequest, LoginResponse, Misc, MouseEvent, OptionMessage, + PeerInfo, PublicKey, SignedId, SupportedEncoding, TestDelay, }; use sodiumoxide::crypto::secretbox; @@ -268,7 +267,11 @@ impl Connection { } /// Handle an incoming TCP connection - pub async fn handle_tcp(&mut self, stream: TcpStream, peer_addr: SocketAddr) -> anyhow::Result<()> { + pub async fn handle_tcp( + &mut self, + stream: TcpStream, + peer_addr: SocketAddr, + ) -> anyhow::Result<()> { info!("New connection from {}", peer_addr); *self.state.write() = ConnectionState::Handshaking; @@ -279,7 +282,9 @@ impl Connection { // Send our SignedId first (this is what RustDesk protocol expects) // The SignedId contains our device ID and temporary public key let signed_id_msg = self.create_signed_id_message(&self.device_id.clone()); - let signed_id_bytes = signed_id_msg.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode SignedId: {}", e))?; + let signed_id_bytes = signed_id_msg + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode SignedId: {}", e))?; debug!("Sending SignedId with device_id={}", self.device_id); self.send_framed_arc(&writer, &signed_id_bytes).await?; @@ -402,7 +407,11 @@ impl Connection { } /// Send framed message using Arc> with RustDesk's variable-length encoding - async fn send_framed_arc(&self, writer: &Arc>, data: &[u8]) -> anyhow::Result<()> { + async fn send_framed_arc( + &self, + writer: &Arc>, + data: &[u8], + ) -> anyhow::Result<()> { let mut w = writer.lock().await; write_frame(&mut *w, data).await?; Ok(()) @@ -480,7 +489,9 @@ impl Connection { pk.symmetric_value.len() ); if pk.asymmetric_value.is_empty() && pk.symmetric_value.is_empty() { - warn!("Received EMPTY PublicKey - client may have failed signature verification!"); + warn!( + "Received EMPTY PublicKey - client may have failed signature verification!" 
+ ); } self.handle_peer_public_key(pk, writer).await?; } @@ -535,7 +546,7 @@ impl Connection { info!("Received SignedId from peer, id_len={}", si.id.len()); self.handle_signed_id(si, writer).await?; return Ok(()); - }, + } message::Union::Hash(_) => "Hash", message::Union::VideoFrame(_) => "VideoFrame", message::Union::CursorData(_) => "CursorData", @@ -564,16 +575,26 @@ impl Connection { lr: &LoginRequest, writer: &Arc>, ) -> anyhow::Result { - info!("Login request from {} ({}), password_len={}", lr.my_id, lr.my_name, lr.password.len()); + info!( + "Login request from {} ({}), password_len={}", + lr.my_id, + lr.my_name, + lr.password.len() + ); // Check if our server requires a password if !self.password.is_empty() { // Server requires password if lr.password.is_empty() { // Client sent empty password - tell them to enter password - info!("Empty password from {}, requesting password input", lr.my_id); + info!( + "Empty password from {}, requesting password input", + lr.my_id + ); let error_response = self.create_login_error_response("Empty Password"); - let response_bytes = error_response.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + let response_bytes = error_response + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; self.send_encrypted_arc(writer, &response_bytes).await?; // Don't close connection - wait for retry with password return Ok(false); @@ -583,7 +604,9 @@ impl Connection { if !self.verify_password(&lr.password) { warn!("Wrong password from {}", lr.my_id); let error_response = self.create_login_error_response("Wrong Password"); - let response_bytes = error_response.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + let response_bytes = error_response + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; self.send_encrypted_arc(writer, &response_bytes).await?; // Don't close connection - wait for retry with correct password return Ok(false); @@ -601,7 +624,9 @@ impl Connection { info!("Negotiated video codec: {:?}", negotiated); let response = self.create_login_response(true); - let response_bytes = response.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + let response_bytes = response + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; self.send_encrypted_arc(writer, &response_bytes).await?; Ok(true) } @@ -679,7 +704,10 @@ impl Connection { }; if let Some(preset) = preset { - info!("Client requested quality preset: {:?} (image_quality={})", preset, image_quality); + info!( + "Client requested quality preset: {:?} (image_quality={})", + preset, image_quality + ); if let Some(ref video_manager) = self.video_manager { if let Err(e) = video_manager.set_bitrate_preset(preset).await { warn!("Failed to set bitrate preset: {}", e); @@ -729,7 +757,10 @@ impl Connection { // Log custom_image_quality (accept but don't process) if opt.custom_image_quality > 0 { - debug!("Client sent custom_image_quality: {} (ignored)", opt.custom_image_quality); + debug!( + "Client sent custom_image_quality: {} (ignored)", + opt.custom_image_quality + ); } if opt.custom_fps > 0 { debug!("Client requested FPS: {}", opt.custom_fps); @@ -779,7 +810,10 @@ impl Connection { let negotiated_codec = self.negotiated_codec.unwrap_or(VideoEncoderType::H264); let task = tokio::spawn(async move { - info!("Starting video streaming for connection {} with codec {:?}", conn_id, negotiated_codec); + info!( + "Starting video streaming for connection {} 
with codec {:?}", + conn_id, negotiated_codec + ); if let Err(e) = run_video_streaming( conn_id, @@ -788,7 +822,9 @@ impl Connection { state, shutdown_tx, negotiated_codec, - ).await { + ) + .await + { error!("Video streaming error for connection {}: {}", conn_id, e); } @@ -815,13 +851,9 @@ impl Connection { let task = tokio::spawn(async move { info!("Starting audio streaming for connection {}", conn_id); - if let Err(e) = run_audio_streaming( - conn_id, - audio_controller, - audio_tx, - state, - shutdown_tx, - ).await { + if let Err(e) = + run_audio_streaming(conn_id, audio_controller, audio_tx, state, shutdown_tx).await + { error!("Audio streaming error for connection {}: {}", conn_id, e); } @@ -894,7 +926,10 @@ impl Connection { self.encryption_enabled = true; } Err(e) => { - warn!("Failed to decrypt session key: {:?}, falling back to unencrypted", e); + warn!( + "Failed to decrypt session key: {:?}, falling back to unencrypted", + e + ); // Continue without encryption - some clients may not support it self.encryption_enabled = false; } @@ -917,8 +952,13 @@ impl Connection { // This tells the client what salt to use for password hashing // Must be encrypted if session key was negotiated let hash_msg = self.create_hash_message(); - let hash_bytes = hash_msg.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; - debug!("Sending Hash message for password authentication (encrypted={})", self.encryption_enabled); + let hash_bytes = hash_msg + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + debug!( + "Sending Hash message for password authentication (encrypted={})", + self.encryption_enabled + ); self.send_encrypted_arc(writer, &hash_bytes).await?; Ok(()) @@ -971,7 +1011,9 @@ impl Connection { // If we haven't sent our SignedId yet, send it now // (This handles the case where client sends SignedId before we do) let signed_id_msg = self.create_signed_id_message(&self.device_id.clone()); - let signed_id_bytes = signed_id_msg.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + let signed_id_bytes = signed_id_msg + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; self.send_framed_arc(writer, &signed_id_bytes).await?; Ok(()) @@ -1073,7 +1115,8 @@ impl Connection { msg } else { let mut login_response = LoginResponse::new(); - login_response.union = Some(login_response::Union::Error("Invalid password".to_string())); + login_response.union = + Some(login_response::Union::Error("Invalid password".to_string())); login_response.enable_trusted_devices = false; let mut msg = HbbMessage::new(); @@ -1133,7 +1176,9 @@ impl Connection { let mut response = HbbMessage::new(); response.union = Some(message::Union::TestDelay(test_delay)); - let data = response.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + let data = response + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; self.send_encrypted_arc(writer, &data).await?; debug!( @@ -1161,10 +1206,7 @@ impl Connection { /// The client will echo this back, allowing us to calculate RTT. /// The measured delay is then included in future TestDelay messages /// for the client to display. 
- async fn send_test_delay( - &mut self, - writer: &Arc>, - ) -> anyhow::Result<()> { + async fn send_test_delay(&mut self, writer: &Arc>) -> anyhow::Result<()> { // Get current time in milliseconds since epoch let time_ms = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -1180,13 +1222,18 @@ impl Connection { let mut msg = HbbMessage::new(); msg.union = Some(message::Union::TestDelay(test_delay)); - let data = msg.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + let data = msg + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; self.send_encrypted_arc(writer, &data).await?; // Record when we sent this, so we can calculate RTT when client echoes back self.last_test_delay_sent = Some(Instant::now()); - debug!("TestDelay sent: time={}, last_delay={}ms", time_ms, self.last_delay); + debug!( + "TestDelay sent: time={}, last_delay={}ms", + time_ms, self.last_delay + ); Ok(()) } @@ -1208,7 +1255,10 @@ impl Connection { self.last_caps_lock = caps_lock_in_modifiers; // Send CapsLock key press (down + up) to toggle state on target if let Some(ref hid) = self.hid { - debug!("CapsLock state changed to {}, sending CapsLock key", caps_lock_in_modifiers); + debug!( + "CapsLock state changed to {}, sending CapsLock key", + caps_lock_in_modifiers + ); let caps_down = KeyboardEvent { event_type: KeyEventType::Down, key: 0x39, // USB HID CapsLock @@ -1234,7 +1284,9 @@ impl Connection { if let Some(kb_event) = convert_key_event(ke) { debug!( "Converted to HID: key=0x{:02X}, event_type={:?}, modifiers={:02X}", - kb_event.key, kb_event.event_type, kb_event.modifiers.to_hid_byte() + kb_event.key, + kb_event.event_type, + kb_event.modifiers.to_hid_byte() ); // Send to HID controller if available if let Some(ref hid) = self.hid { @@ -1393,7 +1445,11 @@ impl ConnectionManager { } /// Accept a new connection - pub async fn accept_connection(&self, stream: TcpStream, peer_addr: SocketAddr) -> anyhow::Result { + pub async fn accept_connection( + &self, + stream: TcpStream, + peer_addr: SocketAddr, + ) -> anyhow::Result { let id = { let mut next = self.next_id.write(); let id = *next; @@ -1406,14 +1462,14 @@ impl ConnectionManager { let hid = self.hid.read().clone(); let audio = self.audio.read().clone(); let video_manager = self.video_manager.read().clone(); - let (mut conn, _rx) = Connection::new(id, &config, signing_keypair, hid, audio, video_manager); + let (mut conn, _rx) = + Connection::new(id, &config, signing_keypair, hid, audio, video_manager); // Track connection state for external access let state = conn.state.clone(); - self.connections.write().push(Arc::new(RwLock::new(ConnectionInfo { - id, - state, - }))); + self.connections + .write() + .push(Arc::new(RwLock::new(ConnectionInfo { id, state }))); // Spawn connection handler - Connection is moved, not locked tokio::spawn(async move { @@ -1466,7 +1522,10 @@ async fn run_video_streaming( }; // Set the video codec on the shared pipeline before subscribing - info!("Setting video codec to {:?} for connection {}", negotiated_codec, conn_id); + info!( + "Setting video codec to {:?} for connection {}", + negotiated_codec, conn_id + ); if let Err(e) = video_manager.set_video_codec(webrtc_codec).await { error!("Failed to set video codec: {}", e); // Continue anyway, will use whatever codec the pipeline already has @@ -1485,7 +1544,10 @@ async fn run_video_streaming( let mut encoded_count: u64 = 0; let mut last_log_time = Instant::now(); - info!("Started shared video streaming for connection {} (codec: 
{:?})", conn_id, codec); + info!( + "Started shared video streaming for connection {} (codec: {:?})", + conn_id, codec + ); // Outer loop: handles pipeline restarts by re-subscribing 'subscribe_loop: loop { @@ -1500,7 +1562,10 @@ async fn run_video_streaming( Some(rx) => rx, None => { // Pipeline not ready yet, wait and retry - debug!("No encoded frame source available for connection {}, retrying...", conn_id); + debug!( + "No encoded frame source available for connection {}, retrying...", + conn_id + ); tokio::time::sleep(Duration::from_millis(100)).await; continue 'subscribe_loop; } @@ -1619,13 +1684,19 @@ async fn run_audio_streaming( Some(rx) => rx, None => { // Audio not available, wait and retry - debug!("No audio source available for connection {}, retrying...", conn_id); + debug!( + "No audio source available for connection {}, retrying...", + conn_id + ); tokio::time::sleep(Duration::from_millis(500)).await; continue 'subscribe_loop; } }; - info!("RustDesk connection {} subscribed to audio pipeline", conn_id); + info!( + "RustDesk connection {} subscribed to audio pipeline", + conn_id + ); // Send audio format message once before sending frames if !audio_adapter.format_sent() { diff --git a/src/rustdesk/crypto.rs b/src/rustdesk/crypto.rs index 931470e5..10860402 100644 --- a/src/rustdesk/crypto.rs +++ b/src/rustdesk/crypto.rs @@ -86,8 +86,12 @@ impl KeyPair { /// Create from base64-encoded keys pub fn from_base64(public_key: &str, secret_key: &str) -> Result { - let pk_bytes = BASE64.decode(public_key).map_err(|_| CryptoError::InvalidKeyLength)?; - let sk_bytes = BASE64.decode(secret_key).map_err(|_| CryptoError::InvalidKeyLength)?; + let pk_bytes = BASE64 + .decode(public_key) + .map_err(|_| CryptoError::InvalidKeyLength)?; + let sk_bytes = BASE64 + .decode(secret_key) + .map_err(|_| CryptoError::InvalidKeyLength)?; Self::from_keys(&pk_bytes, &sk_bytes) } } @@ -140,7 +144,10 @@ pub fn decrypt_with_key( /// Compute a shared symmetric key from public/private keypair /// This is the precomputed key for the NaCl box -pub fn precompute_key(their_public_key: &PublicKey, our_secret_key: &SecretKey) -> box_::PrecomputedKey { +pub fn precompute_key( + their_public_key: &PublicKey, + our_secret_key: &SecretKey, +) -> box_::PrecomputedKey { box_::precompute(their_public_key, our_secret_key) } @@ -207,8 +214,8 @@ pub fn decrypt_symmetric_key( return Err(CryptoError::InvalidKeyLength); } - let their_pk = PublicKey::from_slice(their_temp_public_key) - .ok_or(CryptoError::InvalidKeyLength)?; + let their_pk = + PublicKey::from_slice(their_temp_public_key).ok_or(CryptoError::InvalidKeyLength)?; // Use zero nonce as per RustDesk protocol let nonce = box_::Nonce([0u8; box_::NONCEBYTES]); @@ -294,8 +301,12 @@ impl SigningKeyPair { /// Create from base64-encoded keys pub fn from_base64(public_key: &str, secret_key: &str) -> Result { - let pk_bytes = BASE64.decode(public_key).map_err(|_| CryptoError::InvalidKeyLength)?; - let sk_bytes = BASE64.decode(secret_key).map_err(|_| CryptoError::InvalidKeyLength)?; + let pk_bytes = BASE64 + .decode(public_key) + .map_err(|_| CryptoError::InvalidKeyLength)?; + let sk_bytes = BASE64 + .decode(secret_key) + .map_err(|_| CryptoError::InvalidKeyLength)?; Self::from_keys(&pk_bytes, &sk_bytes) } @@ -321,8 +332,7 @@ impl SigningKeyPair { /// which is required by RustDesk's protocol where clients encrypt the /// symmetric key using the public key from IdPk. 
pub fn to_curve25519_pk(&self) -> Result { - ed25519::to_curve25519_pk(&self.public_key) - .map_err(|_| CryptoError::KeyConversionFailed) + ed25519::to_curve25519_pk(&self.public_key).map_err(|_| CryptoError::KeyConversionFailed) } /// Convert Ed25519 secret key to Curve25519 secret key for decryption @@ -330,14 +340,16 @@ impl SigningKeyPair { /// This allows decrypting messages that were encrypted using the /// converted public key. pub fn to_curve25519_sk(&self) -> Result { - ed25519::to_curve25519_sk(&self.secret_key) - .map_err(|_| CryptoError::KeyConversionFailed) + ed25519::to_curve25519_sk(&self.secret_key).map_err(|_| CryptoError::KeyConversionFailed) } } /// Verify a signed message /// Returns the original message if signature is valid -pub fn verify_signed(signed_message: &[u8], public_key: &sign::PublicKey) -> Result, CryptoError> { +pub fn verify_signed( + signed_message: &[u8], + public_key: &sign::PublicKey, +) -> Result, CryptoError> { sign::verify(signed_message, public_key).map_err(|_| CryptoError::SignatureVerificationFailed) } @@ -374,7 +386,8 @@ mod tests { let message = b"Hello, RustDesk!"; let (nonce, ciphertext) = encrypt_box(message, &bob.public_key, &alice.secret_key); - let plaintext = decrypt_box(&ciphertext, &nonce, &alice.public_key, &bob.secret_key).unwrap(); + let plaintext = + decrypt_box(&ciphertext, &nonce, &alice.public_key, &bob.secret_key).unwrap(); assert_eq!(plaintext, message); } diff --git a/src/rustdesk/frame_adapters.rs b/src/rustdesk/frame_adapters.rs index 3697ab1b..14e4d321 100644 --- a/src/rustdesk/frame_adapters.rs +++ b/src/rustdesk/frame_adapters.rs @@ -7,9 +7,8 @@ use bytes::Bytes; use protobuf::Message as ProtobufMessage; use super::protocol::hbb::message::{ - message as msg_union, misc as misc_union, video_frame as vf_union, - AudioFormat, AudioFrame, CursorData, CursorPosition, - EncodedVideoFrame, EncodedVideoFrames, Message, Misc, VideoFrame, + message as msg_union, misc as misc_union, video_frame as vf_union, AudioFormat, AudioFrame, + CursorData, CursorPosition, EncodedVideoFrame, EncodedVideoFrames, Message, Misc, VideoFrame, }; /// Video codec type for RustDesk @@ -63,7 +62,12 @@ impl VideoFrameAdapter { /// Convert encoded video data to RustDesk Message (zero-copy version) /// /// This version takes Bytes directly to avoid copying the frame data. - pub fn encode_frame_from_bytes(&mut self, data: Bytes, is_keyframe: bool, timestamp_ms: u64) -> Message { + pub fn encode_frame_from_bytes( + &mut self, + data: Bytes, + is_keyframe: bool, + timestamp_ms: u64, + ) -> Message { // Calculate relative timestamp if self.seq == 0 { self.timestamp_base = timestamp_ms; @@ -104,13 +108,23 @@ impl VideoFrameAdapter { /// Encode frame to bytes for sending (zero-copy version) /// /// Takes Bytes directly to avoid copying the frame data. 
- pub fn encode_frame_bytes_zero_copy(&mut self, data: Bytes, is_keyframe: bool, timestamp_ms: u64) -> Bytes { + pub fn encode_frame_bytes_zero_copy( + &mut self, + data: Bytes, + is_keyframe: bool, + timestamp_ms: u64, + ) -> Bytes { let msg = self.encode_frame_from_bytes(data, is_keyframe, timestamp_ms); Bytes::from(msg.write_to_bytes().unwrap_or_default()) } /// Encode frame to bytes for sending - pub fn encode_frame_bytes(&mut self, data: &[u8], is_keyframe: bool, timestamp_ms: u64) -> Bytes { + pub fn encode_frame_bytes( + &mut self, + data: &[u8], + is_keyframe: bool, + timestamp_ms: u64, + ) -> Bytes { self.encode_frame_bytes_zero_copy(Bytes::copy_from_slice(data), is_keyframe, timestamp_ms) } @@ -234,15 +248,13 @@ mod tests { let msg = adapter.encode_frame(&data, true, 0); match &msg.union { - Some(msg_union::Union::VideoFrame(vf)) => { - match &vf.union { - Some(vf_union::Union::H264s(frames)) => { - assert_eq!(frames.frames.len(), 1); - assert!(frames.frames[0].key); - } - _ => panic!("Expected H264s"), + Some(msg_union::Union::VideoFrame(vf)) => match &vf.union { + Some(vf_union::Union::H264s(frames)) => { + assert_eq!(frames.frames.len(), 1); + assert!(frames.frames[0].key); } - } + _ => panic!("Expected H264s"), + }, _ => panic!("Expected VideoFrame"), } } @@ -256,15 +268,13 @@ mod tests { assert!(adapter.format_sent()); match &msg.union { - Some(msg_union::Union::Misc(misc)) => { - match &misc.union { - Some(misc_union::Union::AudioFormat(fmt)) => { - assert_eq!(fmt.sample_rate, 48000); - assert_eq!(fmt.channels, 2); - } - _ => panic!("Expected AudioFormat"), + Some(msg_union::Union::Misc(misc)) => match &misc.union { + Some(misc_union::Union::AudioFormat(fmt)) => { + assert_eq!(fmt.sample_rate, 48000); + assert_eq!(fmt.channels, 2); } - } + _ => panic!("Expected AudioFormat"), + }, _ => panic!("Expected Misc"), } } diff --git a/src/rustdesk/hid_adapter.rs b/src/rustdesk/hid_adapter.rs index 75edba96..4261f8d9 100644 --- a/src/rustdesk/hid_adapter.rs +++ b/src/rustdesk/hid_adapter.rs @@ -2,13 +2,13 @@ //! //! Converts RustDesk HID events (KeyEvent, MouseEvent) to One-KVM HID events. 
-use protobuf::Enum; -use crate::hid::{ - KeyboardEvent, KeyboardModifiers, KeyEventType, - MouseButton, MouseEvent as OneKvmMouseEvent, MouseEventType, -}; -use super::protocol::{KeyEvent, MouseEvent, ControlKey}; use super::protocol::hbb::message::key_event as ke_union; +use super::protocol::{ControlKey, KeyEvent, MouseEvent}; +use crate::hid::{ + KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent as OneKvmMouseEvent, + MouseEventType, +}; +use protobuf::Enum; /// Mouse event types from RustDesk protocol /// mask = (button << 3) | event_type @@ -32,7 +32,11 @@ pub mod mouse_button { /// Convert RustDesk MouseEvent to One-KVM MouseEvent(s) /// Returns a Vec because a single RustDesk event may need multiple One-KVM events /// (e.g., move + button + scroll) -pub fn convert_mouse_event(event: &MouseEvent, screen_width: u32, screen_height: u32) -> Vec { +pub fn convert_mouse_event( + event: &MouseEvent, + screen_width: u32, + screen_height: u32, +) -> Vec { let mut events = Vec::new(); // RustDesk uses absolute coordinates @@ -243,10 +247,10 @@ fn parse_modifiers(event: &KeyEvent) -> KeyboardModifiers { /// Convert RustDesk ControlKey to USB HID usage code fn control_key_to_hid(key: i32) -> Option { match key { - x if x == ControlKey::Alt as i32 => Some(0xE2), // Left Alt + x if x == ControlKey::Alt as i32 => Some(0xE2), // Left Alt x if x == ControlKey::Backspace as i32 => Some(0x2A), x if x == ControlKey::CapsLock as i32 => Some(0x39), - x if x == ControlKey::Control as i32 => Some(0xE0), // Left Ctrl + x if x == ControlKey::Control as i32 => Some(0xE0), // Left Ctrl x if x == ControlKey::Delete as i32 => Some(0x4C), x if x == ControlKey::DownArrow as i32 => Some(0x51), x if x == ControlKey::End as i32 => Some(0x4D), @@ -265,12 +269,12 @@ fn control_key_to_hid(key: i32) -> Option { x if x == ControlKey::F12 as i32 => Some(0x45), x if x == ControlKey::Home as i32 => Some(0x4A), x if x == ControlKey::LeftArrow as i32 => Some(0x50), - x if x == ControlKey::Meta as i32 => Some(0xE3), // Left GUI/Windows + x if x == ControlKey::Meta as i32 => Some(0xE3), // Left GUI/Windows x if x == ControlKey::PageDown as i32 => Some(0x4E), x if x == ControlKey::PageUp as i32 => Some(0x4B), x if x == ControlKey::Return as i32 => Some(0x28), x if x == ControlKey::RightArrow as i32 => Some(0x4F), - x if x == ControlKey::Shift as i32 => Some(0xE1), // Left Shift + x if x == ControlKey::Shift as i32 => Some(0xE1), // Left Shift x if x == ControlKey::Space as i32 => Some(0x2C), x if x == ControlKey::Tab as i32 => Some(0x2B), x if x == ControlKey::UpArrow as i32 => Some(0x52), @@ -330,7 +334,7 @@ fn ascii_to_hid(ascii: u32) -> Option { Some((ascii - 65 + 0x04) as u8) } // Numbers 0-9 (ASCII 48-57) - 48 => Some(0x27), // 0 + 48 => Some(0x27), // 0 49..=57 => Some((ascii - 49 + 0x1E) as u8), // 1-9 // Common punctuation 32 => Some(0x2C), // Space @@ -341,17 +345,17 @@ fn ascii_to_hid(ascii: u32) -> Option { 8 => Some(0x2A), // Backspace 127 => Some(0x4C), // Delete // Symbols (US keyboard layout) - 45 => Some(0x2D), // - - 61 => Some(0x2E), // = - 91 => Some(0x2F), // [ - 93 => Some(0x30), // ] - 92 => Some(0x31), // \ - 59 => Some(0x33), // ; - 39 => Some(0x34), // ' - 96 => Some(0x35), // ` - 44 => Some(0x36), // , - 46 => Some(0x37), // . 
- 47 => Some(0x38), // / + 45 => Some(0x2D), // - + 61 => Some(0x2E), // = + 91 => Some(0x2F), // [ + 93 => Some(0x30), // ] + 92 => Some(0x31), // \ + 59 => Some(0x33), // ; + 39 => Some(0x34), // ' + 96 => Some(0x35), // ` + 44 => Some(0x36), // , + 46 => Some(0x37), // . + 47 => Some(0x38), // / _ => None, } } @@ -394,10 +398,10 @@ fn windows_vk_to_hid(vk: u32) -> Option { }) } // Numbers 0-9 (VK_0=0x30 to VK_9=0x39) - 0x30 => Some(0x27), // 0 + 0x30 => Some(0x27), // 0 0x31..=0x39 => Some((vk - 0x31 + 0x1E) as u8), // 1-9 // Numpad 0-9 (VK_NUMPAD0=0x60 to VK_NUMPAD9=0x69) - 0x60 => Some(0x62), // Numpad 0 + 0x60 => Some(0x62), // Numpad 0 0x61..=0x69 => Some((vk - 0x61 + 0x59) as u8), // Numpad 1-9 // Numpad operators 0x6A => Some(0x55), // Numpad * @@ -451,7 +455,7 @@ fn x11_keycode_to_hid(keycode: u32) -> Option { match keycode { // Numbers: X11 keycode 10="1", 11="2", ..., 18="9", 19="0" 10..=18 => Some((keycode - 10 + 0x1E) as u8), // 1-9 - 19 => Some(0x27), // 0 + 19 => Some(0x27), // 0 // Punctuation 20 => Some(0x2D), // - 21 => Some(0x2E), // = @@ -533,7 +537,9 @@ mod tests { let events = convert_mouse_event(&event, 1920, 1080); assert!(events.len() >= 2); // Should have a button down event - assert!(events.iter().any(|e| e.event_type == MouseEventType::Down && e.button == Some(MouseButton::Left))); + assert!(events + .iter() + .any(|e| e.event_type == MouseEventType::Down && e.button == Some(MouseButton::Left))); } #[test] @@ -542,7 +548,9 @@ mod tests { let mut key_event = KeyEvent::new(); key_event.down = true; key_event.press = false; - key_event.union = Some(ke_union::Union::ControlKey(EnumOrUnknown::new(ControlKey::Return))); + key_event.union = Some(ke_union::Union::ControlKey(EnumOrUnknown::new( + ControlKey::Return, + ))); let result = convert_key_event(&key_event); assert!(result.is_some()); diff --git a/src/rustdesk/mod.rs b/src/rustdesk/mod.rs index b262632f..23178c47 100644 --- a/src/rustdesk/mod.rs +++ b/src/rustdesk/mod.rs @@ -205,7 +205,8 @@ impl RustDeskService { self.connection_manager.set_audio(self.audio.clone()); // Set the video manager on connection manager for video streaming - self.connection_manager.set_video_manager(self.video_manager.clone()); + self.connection_manager + .set_video_manager(self.video_manager.clone()); *self.rendezvous.write() = Some(mediator.clone()); @@ -231,105 +232,117 @@ impl RustDeskService { let audio_punch = self.audio.clone(); let service_config_punch = self.config.clone(); - mediator.set_punch_callback(Arc::new(move |peer_addr, rendezvous_addr, relay_server, uuid, socket_addr, device_id| { - let conn_mgr = connection_manager_punch.clone(); - let video = video_manager_punch.clone(); - let hid = hid_punch.clone(); - let audio = audio_punch.clone(); - let config = service_config_punch.clone(); + mediator.set_punch_callback(Arc::new( + move |peer_addr, rendezvous_addr, relay_server, uuid, socket_addr, device_id| { + let conn_mgr = connection_manager_punch.clone(); + let video = video_manager_punch.clone(); + let hid = hid_punch.clone(); + let audio = audio_punch.clone(); + let config = service_config_punch.clone(); - tokio::spawn(async move { - // Get relay_key from config (no public server fallback) - let relay_key = { - let cfg = config.read(); - cfg.relay_key.clone().unwrap_or_default() - }; + tokio::spawn(async move { + // Get relay_key from config (no public server fallback) + let relay_key = { + let cfg = config.read(); + cfg.relay_key.clone().unwrap_or_default() + }; - // Try P2P direct connection first - if let Some(addr) 
= peer_addr { - info!("Attempting P2P direct connection to {}", addr); - match punch::try_direct_connection(addr).await { - punch::PunchResult::DirectConnection(stream) => { - info!("P2P direct connection succeeded to {}", addr); - if let Err(e) = conn_mgr.accept_connection(stream, addr).await { - error!("Failed to accept P2P connection: {}", e); + // Try P2P direct connection first + if let Some(addr) = peer_addr { + info!("Attempting P2P direct connection to {}", addr); + match punch::try_direct_connection(addr).await { + punch::PunchResult::DirectConnection(stream) => { + info!("P2P direct connection succeeded to {}", addr); + if let Err(e) = conn_mgr.accept_connection(stream, addr).await { + error!("Failed to accept P2P connection: {}", e); + } + return; + } + punch::PunchResult::NeedRelay => { + info!("P2P direct connection failed, falling back to relay"); } - return; - } - punch::PunchResult::NeedRelay => { - info!("P2P direct connection failed, falling back to relay"); } } - } - // Fall back to relay - if let Err(e) = handle_relay_request( - &rendezvous_addr, - &relay_server, - &uuid, - &socket_addr, - &device_id, - &relay_key, - conn_mgr, - video, - hid, - audio, - ).await { - error!("Failed to handle relay request: {}", e); - } - }); - })); + // Fall back to relay + if let Err(e) = handle_relay_request( + &rendezvous_addr, + &relay_server, + &uuid, + &socket_addr, + &device_id, + &relay_key, + conn_mgr, + video, + hid, + audio, + ) + .await + { + error!("Failed to handle relay request: {}", e); + } + }); + }, + )); // Set the relay callback on the mediator - mediator.set_relay_callback(Arc::new(move |rendezvous_addr, relay_server, uuid, socket_addr, device_id| { - let conn_mgr = connection_manager.clone(); - let video = video_manager.clone(); - let hid = hid.clone(); - let audio = audio.clone(); - let config = service_config.clone(); + mediator.set_relay_callback(Arc::new( + move |rendezvous_addr, relay_server, uuid, socket_addr, device_id| { + let conn_mgr = connection_manager.clone(); + let video = video_manager.clone(); + let hid = hid.clone(); + let audio = audio.clone(); + let config = service_config.clone(); - tokio::spawn(async move { - // Get relay_key from config (no public server fallback) - let relay_key = { - let cfg = config.read(); - cfg.relay_key.clone().unwrap_or_default() - }; + tokio::spawn(async move { + // Get relay_key from config (no public server fallback) + let relay_key = { + let cfg = config.read(); + cfg.relay_key.clone().unwrap_or_default() + }; - if let Err(e) = handle_relay_request( - &rendezvous_addr, - &relay_server, - &uuid, - &socket_addr, - &device_id, - &relay_key, - conn_mgr, - video, - hid, - audio, - ).await { - error!("Failed to handle relay request: {}", e); - } - }); - })); + if let Err(e) = handle_relay_request( + &rendezvous_addr, + &relay_server, + &uuid, + &socket_addr, + &device_id, + &relay_key, + conn_mgr, + video, + hid, + audio, + ) + .await + { + error!("Failed to handle relay request: {}", e); + } + }); + }, + )); // Set the intranet callback on the mediator for same-LAN connections let connection_manager2 = self.connection_manager.clone(); - mediator.set_intranet_callback(Arc::new(move |rendezvous_addr, peer_socket_addr, local_addr, relay_server, device_id| { - let conn_mgr = connection_manager2.clone(); + mediator.set_intranet_callback(Arc::new( + move |rendezvous_addr, peer_socket_addr, local_addr, relay_server, device_id| { + let conn_mgr = connection_manager2.clone(); - tokio::spawn(async move { - if let Err(e) = 
handle_intranet_request( - &rendezvous_addr, - &peer_socket_addr, - local_addr, - &relay_server, - &device_id, - conn_mgr, - ).await { - error!("Failed to handle intranet request: {}", e); - } - }); - })); + tokio::spawn(async move { + if let Err(e) = handle_intranet_request( + &rendezvous_addr, + &peer_socket_addr, + local_addr, + &relay_server, + &device_id, + conn_mgr, + ) + .await + { + error!("Failed to handle intranet request: {}", e); + } + }); + }, + )); // Spawn rendezvous task let status = self.status.clone(); @@ -471,7 +484,9 @@ impl RustDeskService { // Save signing keypair (Ed25519) let signing_pk = skp.public_key_base64(); let signing_sk = skp.secret_key_base64(); - if config.signing_public_key.as_ref() != Some(&signing_pk) || config.signing_private_key.as_ref() != Some(&signing_sk) { + if config.signing_public_key.as_ref() != Some(&signing_pk) + || config.signing_private_key.as_ref() != Some(&signing_sk) + { config.signing_public_key = Some(signing_pk); config.signing_private_key = Some(signing_sk); changed = true; @@ -522,13 +537,18 @@ async fn handle_relay_request( _hid: Arc, _audio: Arc, ) -> anyhow::Result<()> { - info!("Handling relay request: rendezvous={}, relay={}, uuid={}", rendezvous_addr, relay_server, uuid); + info!( + "Handling relay request: rendezvous={}, relay={}, uuid={}", + rendezvous_addr, relay_server, uuid + ); // Step 1: Connect to RENDEZVOUS server and send RelayResponse let rendezvous_socket_addr: SocketAddr = tokio::net::lookup_host(rendezvous_addr) .await? .next() - .ok_or_else(|| anyhow::anyhow!("Failed to resolve rendezvous server: {}", rendezvous_addr))?; + .ok_or_else(|| { + anyhow::anyhow!("Failed to resolve rendezvous server: {}", rendezvous_addr) + })?; let mut rendezvous_stream = tokio::time::timeout( Duration::from_millis(RELAY_CONNECT_TIMEOUT_MS), @@ -537,12 +557,17 @@ async fn handle_relay_request( .await .map_err(|_| anyhow::anyhow!("Rendezvous connection timeout"))??; - debug!("Connected to rendezvous server at {}", rendezvous_socket_addr); + debug!( + "Connected to rendezvous server at {}", + rendezvous_socket_addr + ); // Send RelayResponse to rendezvous server with client's socket_addr // IMPORTANT: Include our device ID so rendezvous server can look up and sign our public key let relay_response = make_relay_response(uuid, socket_addr, relay_server, device_id); - let bytes = relay_response.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + let bytes = relay_response + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; bytes_codec::write_frame(&mut rendezvous_stream, &bytes).await?; debug!("Sent RelayResponse to rendezvous server for uuid={}", uuid); @@ -568,7 +593,9 @@ async fn handle_relay_request( // The licence_key is required if the relay server is configured with -k option // The socket_addr is CRITICAL - the relay server uses it to match us with the peer let request_relay = make_request_relay(uuid, relay_key, socket_addr); - let bytes = request_relay.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + let bytes = request_relay + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; bytes_codec::write_frame(&mut stream, &bytes).await?; debug!("Sent RequestRelay to relay server for uuid={}", uuid); @@ -576,8 +603,13 @@ async fn handle_relay_request( let peer_addr = rendezvous::AddrMangle::decode(socket_addr).unwrap_or(relay_addr); // Step 3: Accept connection - relay server bridges the connection - 
connection_manager.accept_connection(stream, peer_addr).await?; - info!("Relay connection established for uuid={}, peer={}", uuid, peer_addr); + connection_manager + .accept_connection(stream, peer_addr) + .await?; + info!( + "Relay connection established for uuid={}, peer={}", + uuid, peer_addr + ); Ok(()) } @@ -608,14 +640,15 @@ async fn handle_intranet_request( debug!("Peer address from FetchLocalAddr: {:?}", peer_addr); // Connect to rendezvous server via TCP with timeout - let mut stream = tokio::time::timeout( - Duration::from_secs(5), - TcpStream::connect(rendezvous_addr), - ) - .await - .map_err(|_| anyhow::anyhow!("Timeout connecting to rendezvous server"))??; + let mut stream = + tokio::time::timeout(Duration::from_secs(5), TcpStream::connect(rendezvous_addr)) + .await + .map_err(|_| anyhow::anyhow!("Timeout connecting to rendezvous server"))??; - info!("Connected to rendezvous server for intranet: {}", rendezvous_addr); + info!( + "Connected to rendezvous server for intranet: {}", + rendezvous_addr + ); // Build LocalAddr message with our local address (mangled) let local_addr_bytes = AddrMangle::encode(local_addr); @@ -626,7 +659,9 @@ async fn handle_intranet_request( device_id, env!("CARGO_PKG_VERSION"), ); - let bytes = msg.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + let bytes = msg + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; // Send LocalAddr using RustDesk's variable-length framing bytes_codec::write_frame(&mut stream, &bytes).await?; @@ -640,11 +675,15 @@ async fn handle_intranet_request( // Get peer address for logging/connection tracking let effective_peer_addr = peer_addr.unwrap_or_else(|| { // If we can't decode the peer address, use the rendezvous server address - rendezvous_addr.parse().unwrap_or_else(|_| "0.0.0.0:0".parse().unwrap()) + rendezvous_addr + .parse() + .unwrap_or_else(|_| "0.0.0.0:0".parse().unwrap()) }); // Accept the connection - the stream is now a proxied connection to the client - connection_manager.accept_connection(stream, effective_peer_addr).await?; + connection_manager + .accept_connection(stream, effective_peer_addr) + .await?; info!("Intranet connection established via rendezvous server proxy"); Ok(()) diff --git a/src/rustdesk/protocol.rs b/src/rustdesk/protocol.rs index c772ec42..3519faaa 100644 --- a/src/rustdesk/protocol.rs +++ b/src/rustdesk/protocol.rs @@ -14,22 +14,20 @@ pub mod hbb { // Re-export commonly used types pub use hbb::rendezvous::{ - rendezvous_message, relay_response, punch_hole_response, - ConnType, ConfigUpdate, FetchLocalAddr, HealthCheck, KeyExchange, LocalAddr, NatType, - OnlineRequest, OnlineResponse, PeerDiscovery, PunchHole, PunchHoleRequest, PunchHoleResponse, - PunchHoleSent, RegisterPeer, RegisterPeerResponse, RegisterPk, RegisterPkResponse, - RelayResponse, RendezvousMessage, RequestRelay, SoftwareUpdate, TestNatRequest, - TestNatResponse, + punch_hole_response, relay_response, rendezvous_message, ConfigUpdate, ConnType, + FetchLocalAddr, HealthCheck, KeyExchange, LocalAddr, NatType, OnlineRequest, OnlineResponse, + PeerDiscovery, PunchHole, PunchHoleRequest, PunchHoleResponse, PunchHoleSent, RegisterPeer, + RegisterPeerResponse, RegisterPk, RegisterPkResponse, RelayResponse, RendezvousMessage, + RequestRelay, SoftwareUpdate, TestNatRequest, TestNatResponse, }; // Re-export message.proto types pub use hbb::message::{ - message, misc, login_response, key_event, - AudioFormat, AudioFrame, Auth2FA, Clipboard, CursorData, CursorPosition, 
EncodedVideoFrame, - EncodedVideoFrames, Hash, IdPk, KeyEvent, LoginRequest, LoginResponse, MouseEvent, Misc, - OptionMessage, PeerInfo, PublicKey, SignedId, SupportedDecoding, VideoFrame, TestDelay, - Features, SupportedResolutions, WindowsSessions, Message as HbbMessage, ControlKey, - DisplayInfo, SupportedEncoding, + key_event, login_response, message, misc, AudioFormat, AudioFrame, Auth2FA, Clipboard, + ControlKey, CursorData, CursorPosition, DisplayInfo, EncodedVideoFrame, EncodedVideoFrames, + Features, Hash, IdPk, KeyEvent, LoginRequest, LoginResponse, Message as HbbMessage, Misc, + MouseEvent, OptionMessage, PeerInfo, PublicKey, SignedId, SupportedDecoding, SupportedEncoding, + SupportedResolutions, TestDelay, VideoFrame, WindowsSessions, }; /// Helper to create a RendezvousMessage with RegisterPeer @@ -80,7 +78,12 @@ pub fn make_punch_hole_sent( /// IMPORTANT: The union field should be `Id` (our device ID), NOT `Pk`. /// The rendezvous server will look up our registered public key using this ID, /// sign it with the server's private key, and set the `pk` field before forwarding to client. -pub fn make_relay_response(uuid: &str, socket_addr: &[u8], relay_server: &str, device_id: &str) -> RendezvousMessage { +pub fn make_relay_response( + uuid: &str, + socket_addr: &[u8], + relay_server: &str, + device_id: &str, +) -> RendezvousMessage { let mut rr = RelayResponse::new(); rr.socket_addr = socket_addr.to_vec().into(); rr.uuid = uuid.to_string(); diff --git a/src/rustdesk/punch.rs b/src/rustdesk/punch.rs index cc9aa774..ad6cea80 100644 --- a/src/rustdesk/punch.rs +++ b/src/rustdesk/punch.rs @@ -69,10 +69,7 @@ impl PunchHoleHandler { /// /// Tries direct connection first, falls back to relay if needed. /// Returns true if direct connection succeeded, false if relay is needed. 
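(Editor's note: a minimal caller-side sketch of the contract documented above — try P2P first, fall back to relay. Illustrative only, not part of this patch; the relay fallback itself is wired up elsewhere via RelayCallback/PunchCallback.)

    // Returns whether the direct P2P attempt succeeded; on `false` the caller is
    // expected to trigger the relay flow.
    async fn connect_with_fallback(
        handler: &PunchHoleHandler,
        peer_addr: Option<std::net::SocketAddr>,
    ) -> bool {
        let direct_ok = handler.handle_punch_hole(peer_addr).await;
        if !direct_ok {
            tracing::debug!("P2P punch failed or no peer address; relay fallback expected");
        }
        direct_ok
    }
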
- pub async fn handle_punch_hole( - &self, - peer_addr: Option, - ) -> bool { + pub async fn handle_punch_hole(&self, peer_addr: Option) -> bool { let peer_addr = match peer_addr { Some(addr) => addr, None => { @@ -84,7 +81,11 @@ impl PunchHoleHandler { match try_direct_connection(peer_addr).await { PunchResult::DirectConnection(stream) => { // Direct connection succeeded, accept it - match self.connection_manager.accept_connection(stream, peer_addr).await { + match self + .connection_manager + .accept_connection(stream, peer_addr) + .await + { Ok(_) => { info!("P2P direct connection established with {}", peer_addr); true diff --git a/src/rustdesk/rendezvous.rs b/src/rustdesk/rendezvous.rs index a28bdc3b..8b411769 100644 --- a/src/rustdesk/rendezvous.rs +++ b/src/rustdesk/rendezvous.rs @@ -18,8 +18,8 @@ use tracing::{debug, error, info, warn}; use super::config::RustDeskConfig; use super::crypto::{KeyPair, SigningKeyPair}; use super::protocol::{ - rendezvous_message, make_punch_hole_sent, make_register_peer, - make_register_pk, NatType, RendezvousMessage, decode_rendezvous_message, + decode_rendezvous_message, make_punch_hole_sent, make_register_peer, make_register_pk, + rendezvous_message, NatType, RendezvousMessage, }; /// Registration interval in milliseconds @@ -81,7 +81,8 @@ pub type RelayCallback = Arc, String) + S /// Callback type for P2P punch hole requests /// Parameters: peer_addr (decoded), relay_callback_params (rendezvous_addr, relay_server, uuid, socket_addr, device_id) /// Returns: should call relay callback if P2P fails -pub type PunchCallback = Arc, String, String, String, Vec, String) + Send + Sync>; +pub type PunchCallback = + Arc, String, String, String, Vec, String) + Send + Sync>; /// Callback type for intranet/local address connections /// Parameters: rendezvous_addr, peer_socket_addr (mangled), local_addr, relay_server, device_id @@ -232,7 +233,8 @@ impl RendezvousMediator { if signing_guard.is_none() { let config = self.config.read(); // Try to load from config first - if let (Some(pk), Some(sk)) = (&config.signing_public_key, &config.signing_private_key) { + if let (Some(pk), Some(sk)) = (&config.signing_public_key, &config.signing_private_key) + { if let Ok(skp) = SigningKeyPair::from_base64(pk, sk) { debug!("Loaded signing keypair from config"); *signing_guard = Some(skp.clone()); @@ -265,14 +267,20 @@ impl RendezvousMediator { config.enabled, effective_server ); if !config.enabled || effective_server.is_empty() { - info!("Rendezvous mediator not starting: enabled={}, server='{}'", config.enabled, effective_server); + info!( + "Rendezvous mediator not starting: enabled={}, server='{}'", + config.enabled, effective_server + ); return Ok(()); } *self.status.write() = RendezvousStatus::Connecting; let addr = config.rendezvous_addr(); - info!("Starting rendezvous mediator for {} to {}", config.device_id, addr); + info!( + "Starting rendezvous mediator for {} to {}", + config.device_id, addr + ); // Resolve server address let server_addr: SocketAddr = tokio::net::lookup_host(&addr) @@ -376,7 +384,9 @@ impl RendezvousMediator { let serial = *self.serial.read(); let msg = make_register_peer(&id, serial); - let bytes = msg.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + let bytes = msg + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; socket.send(&bytes).await?; Ok(()) } @@ -393,7 +403,9 @@ impl RendezvousMediator { debug!("Sending RegisterPk: id={}", id); let msg = make_register_pk(&id, &uuid, pk, ""); - let 
bytes = msg.write_to_bytes().map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; + let bytes = msg + .write_to_bytes() + .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; socket.send(&bytes).await?; Ok(()) } @@ -540,7 +552,7 @@ impl RendezvousMediator { ); let msg = make_punch_hole_sent( - &ph.socket_addr.to_vec(), // Use peer's socket_addr, not ours + &ph.socket_addr.to_vec(), // Use peer's socket_addr, not ours &id, &ph.relay_server, ph.nat_type.enum_value().unwrap_or(NatType::UNKNOWN_NAT), @@ -570,9 +582,22 @@ impl RendezvousMediator { // Use punch callback if set (tries P2P first, then relay) // Otherwise fall back to relay callback directly if let Some(callback) = self.punch_callback.read().as_ref() { - callback(peer_addr, rendezvous_addr, relay_server, uuid, ph.socket_addr.to_vec(), device_id); + callback( + peer_addr, + rendezvous_addr, + relay_server, + uuid, + ph.socket_addr.to_vec(), + device_id, + ); } else if let Some(callback) = self.relay_callback.read().as_ref() { - callback(rendezvous_addr, relay_server, uuid, ph.socket_addr.to_vec(), device_id); + callback( + rendezvous_addr, + relay_server, + uuid, + ph.socket_addr.to_vec(), + device_id, + ); } } } @@ -591,7 +616,13 @@ impl RendezvousMediator { let config = self.config.read().clone(); let rendezvous_addr = config.rendezvous_addr(); let device_id = config.device_id.clone(); - callback(rendezvous_addr, relay_server, rr.uuid.clone(), rr.socket_addr.to_vec(), device_id); + callback( + rendezvous_addr, + relay_server, + rr.uuid.clone(), + rr.socket_addr.to_vec(), + device_id, + ); } } Some(rendezvous_message::Union::FetchLocalAddr(fla)) => { @@ -602,7 +633,8 @@ impl RendezvousMediator { peer_addr, fla.socket_addr.len(), fla.relay_server ); // Respond with our local address for same-LAN direct connection - self.send_local_addr(socket, &fla.socket_addr, &fla.relay_server).await?; + self.send_local_addr(socket, &fla.socket_addr, &fla.relay_server) + .await?; } Some(rendezvous_message::Union::ConfigureUpdate(cu)) => { info!("Received ConfigureUpdate, serial={}", cu.serial); diff --git a/src/state.rs b/src/state.rs index d1104aef..5a4f8b65 100644 --- a/src/state.rs +++ b/src/state.rs @@ -5,7 +5,10 @@ use crate::atx::AtxController; use crate::audio::AudioController; use crate::auth::{SessionStore, UserStore}; use crate::config::ConfigStore; -use crate::events::{AtxDeviceInfo, AudioDeviceInfo, EventBus, HidDeviceInfo, MsdDeviceInfo, SystemEvent, VideoDeviceInfo}; +use crate::events::{ + AtxDeviceInfo, AudioDeviceInfo, EventBus, HidDeviceInfo, MsdDeviceInfo, SystemEvent, + VideoDeviceInfo, +}; use crate::extensions::ExtensionManager; use crate::hid::HidController; use crate::msd::MsdController; diff --git a/src/stream/mjpeg.rs b/src/stream/mjpeg.rs index ccf4c352..1b7607b6 100644 --- a/src/stream/mjpeg.rs +++ b/src/stream/mjpeg.rs @@ -12,8 +12,8 @@ use std::time::{Duration, Instant}; use tokio::sync::broadcast; use tracing::{debug, info, warn}; -use crate::video::encoder::JpegEncoder; use crate::video::encoder::traits::{Encoder, EncoderConfig}; +use crate::video::encoder::JpegEncoder; use crate::video::format::PixelFormat; use crate::video::VideoFrame; @@ -256,7 +256,10 @@ impl MjpegStreamHandler { let config = EncoderConfig::jpeg(resolution, 85); match JpegEncoder::new(config) { Ok(enc) => { - debug!("Created JPEG encoder for MJPEG stream: {}x{}", resolution.width, resolution.height); + debug!( + "Created JPEG encoder for MJPEG stream: {}x{}", + resolution.width, resolution.height + ); enc } Err(e) => { @@ -270,37 
+273,40 @@ impl MjpegStreamHandler { // Check if resolution changed if encoder.config().resolution != resolution { - debug!("Resolution changed, recreating JPEG encoder: {}x{}", resolution.width, resolution.height); + debug!( + "Resolution changed, recreating JPEG encoder: {}x{}", + resolution.width, resolution.height + ); let config = EncoderConfig::jpeg(resolution, 85); - *encoder = JpegEncoder::new(config).map_err(|e| format!("Failed to create encoder: {}", e))?; + *encoder = + JpegEncoder::new(config).map_err(|e| format!("Failed to create encoder: {}", e))?; } // Encode based on input format let encoded = match frame.format { - PixelFormat::Yuyv => { - encoder.encode_yuyv(frame.data(), sequence) - .map_err(|e| format!("YUYV encode failed: {}", e))? - } - PixelFormat::Nv12 => { - encoder.encode_nv12(frame.data(), sequence) - .map_err(|e| format!("NV12 encode failed: {}", e))? - } - PixelFormat::Rgb24 => { - encoder.encode_rgb(frame.data(), sequence) - .map_err(|e| format!("RGB encode failed: {}", e))? - } - PixelFormat::Bgr24 => { - encoder.encode_bgr(frame.data(), sequence) - .map_err(|e| format!("BGR encode failed: {}", e))? - } + PixelFormat::Yuyv => encoder + .encode_yuyv(frame.data(), sequence) + .map_err(|e| format!("YUYV encode failed: {}", e))?, + PixelFormat::Nv12 => encoder + .encode_nv12(frame.data(), sequence) + .map_err(|e| format!("NV12 encode failed: {}", e))?, + PixelFormat::Rgb24 => encoder + .encode_rgb(frame.data(), sequence) + .map_err(|e| format!("RGB encode failed: {}", e))?, + PixelFormat::Bgr24 => encoder + .encode_bgr(frame.data(), sequence) + .map_err(|e| format!("BGR encode failed: {}", e))?, _ => { - return Err(format!("Unsupported format for JPEG encoding: {}", frame.format)); + return Err(format!( + "Unsupported format for JPEG encoding: {}", + frame.format + )); } }; - // Create new VideoFrame with JPEG data - Ok(VideoFrame::from_vec( - encoded.data.to_vec(), + // Create new VideoFrame with JPEG data (zero-copy: Bytes -> Arc) + Ok(VideoFrame::new( + encoded.data, resolution, PixelFormat::Mjpeg, 0, // stride not relevant for JPEG @@ -333,7 +339,11 @@ impl MjpegStreamHandler { pub fn register_client(&self, client_id: ClientId) { let session = ClientSession::new(client_id.clone()); self.clients.write().insert(client_id.clone(), session); - info!("Client {} connected (total: {})", client_id, self.client_count()); + info!( + "Client {} connected (total: {})", + client_id, + self.client_count() + ); } /// Unregister a client @@ -391,7 +401,9 @@ impl MjpegStreamHandler { *self.auto_pause_config.write() = config; info!( "Auto-pause config updated: enabled={}, delay={}s, timeout={}s", - config_clone.enabled, config_clone.shutdown_delay_secs, config_clone.client_timeout_secs + config_clone.enabled, + config_clone.shutdown_delay_secs, + config_clone.client_timeout_secs ); } @@ -440,10 +452,7 @@ impl ClientGuard { /// Create a new client guard pub fn new(client_id: ClientId, handler: Arc) -> Self { handler.register_client(client_id.clone()); - Self { - client_id, - handler, - } + Self { client_id, handler } } /// Get client ID @@ -535,8 +544,8 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool { #[cfg(test)] mod tests { use super::*; - use bytes::Bytes; use crate::video::{format::Resolution, PixelFormat}; + use bytes::Bytes; #[tokio::test] async fn test_stream_handler() { diff --git a/src/stream/mjpeg_streamer.rs b/src/stream/mjpeg_streamer.rs index b6d10977..ad4fb96c 100644 --- a/src/stream/mjpeg_streamer.rs +++ b/src/stream/mjpeg_streamer.rs @@ 
-228,16 +228,17 @@ impl MjpegStreamer { let device = self.current_device.read().await; let config = self.config.read().await; - let (resolution, format, frames_captured) = if let Some(ref cap) = *self.capturer.read().await { - let stats = cap.stats().await; - ( - Some((config.resolution.width, config.resolution.height)), - Some(config.format.to_string()), - stats.frames_captured, - ) - } else { - (None, None, 0) - }; + let (resolution, format, frames_captured) = + if let Some(ref cap) = *self.capturer.read().await { + let stats = cap.stats().await; + ( + Some((config.resolution.width, config.resolution.height)), + Some(config.format.to_string()), + stats.frames_captured, + ) + } else { + (None, None, 0) + }; MjpegStreamerStats { state: state.to_string(), @@ -286,7 +287,10 @@ impl MjpegStreamer { /// Initialize with specific device pub async fn init_with_device(self: &Arc, device: VideoDeviceInfo) -> Result<()> { - info!("MjpegStreamer: Initializing with device: {}", device.path.display()); + info!( + "MjpegStreamer: Initializing with device: {}", + device.path.display() + ); let config = self.config.read().await.clone(); @@ -322,7 +326,9 @@ impl MjpegStreamer { let _lock = self.start_lock.lock().await; if self.config_changing.load(Ordering::SeqCst) { - return Err(AppError::VideoError("Config change in progress".to_string())); + return Err(AppError::VideoError( + "Config change in progress".to_string(), + )); } let state = *self.state.read().await; @@ -332,7 +338,8 @@ impl MjpegStreamer { // Get capturer let capturer = self.capturer.read().await.clone(); - let capturer = capturer.ok_or_else(|| AppError::VideoError("Not initialized".to_string()))?; + let capturer = + capturer.ok_or_else(|| AppError::VideoError("Not initialized".to_string()))?; // Start capture capturer.start().await?; @@ -412,7 +419,9 @@ impl MjpegStreamer { let device = devices .into_iter() .find(|d| d.path == *path) - .ok_or_else(|| AppError::VideoError(format!("Device not found: {}", path.display())))?; + .ok_or_else(|| { + AppError::VideoError(format!("Device not found: {}", path.display())) + })?; self.init_with_device(device).await?; } diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 706cb573..110d77b1 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -13,5 +13,7 @@ pub mod mjpeg_streamer; pub mod ws_hid; pub use mjpeg::{ClientGuard, MjpegStreamHandler}; -pub use mjpeg_streamer::{MjpegStreamer, MjpegStreamerConfig, MjpegStreamerState, MjpegStreamerStats}; +pub use mjpeg_streamer::{ + MjpegStreamer, MjpegStreamerConfig, MjpegStreamerState, MjpegStreamerStats, +}; pub use ws_hid::WsHidHandler; diff --git a/src/stream/ws_hid.rs b/src/stream/ws_hid.rs index 929d7278..0940e884 100644 --- a/src/stream/ws_hid.rs +++ b/src/stream/ws_hid.rs @@ -142,7 +142,9 @@ impl WsHidHandler { shutdown_tx, }); - self.clients.write().insert(client_id.clone(), client.clone()); + self.clients + .write() + .insert(client_id.clone(), client.clone()); info!( "WsHidHandler: Client {} connected (total: {})", client_id, @@ -182,7 +184,11 @@ impl WsHidHandler { let (mut sender, mut receiver) = socket.split(); // Send initial status as binary: 0x00 = ok, 0x01 = error - let status_byte = if self.is_hid_available() { 0x00u8 } else { 0x01u8 }; + let status_byte = if self.is_hid_available() { + 0x00u8 + } else { + 0x01u8 + }; let _ = sender.send(Message::Binary(vec![status_byte].into())).await; loop { @@ -230,7 +236,10 @@ impl WsHidHandler { let hid = self.hid_controller.read().clone(); if let Some(hid) = hid { if let Err(e) = 
hid.reset().await { - warn!("WsHidHandler: Failed to reset HID on client {} disconnect: {}", client_id, e); + warn!( + "WsHidHandler: Failed to reset HID on client {} disconnect: {}", + client_id, e + ); } else { debug!("WsHidHandler: HID reset on client {} disconnect", client_id); } diff --git a/src/utils/throttle.rs b/src/utils/throttle.rs index f983e61d..68ac2edd 100644 --- a/src/utils/throttle.rs +++ b/src/utils/throttle.rs @@ -16,6 +16,7 @@ use std::time::{Duration, Instant}; /// /// ```rust /// use one_kvm::utils::LogThrottler; +/// use std::time::Duration; /// /// let throttler = LogThrottler::new(Duration::from_secs(5)); /// diff --git a/src/video/capture.rs b/src/video/capture.rs index b56bb374..21598d9e 100644 --- a/src/video/capture.rs +++ b/src/video/capture.rs @@ -231,7 +231,9 @@ impl VideoCapturer { let last_error = self.last_error.clone(); let handle = tokio::task::spawn_blocking(move || { - capture_loop(config, state, stats, frame_tx, stop_flag, sequence, last_error); + capture_loop( + config, state, stats, frame_tx, stop_flag, sequence, last_error, + ); }); *self.capture_handle.lock().await = Some(handle); @@ -275,14 +277,7 @@ fn capture_loop( sequence: Arc, error_holder: Arc>>, ) { - let result = run_capture( - &config, - &state, - &stats, - &frame_tx, - &stop_flag, - &sequence, - ); + let result = run_capture(&config, &state, &stats, &frame_tx, &stop_flag, &sequence); match result { Ok(_) => { @@ -503,7 +498,10 @@ fn run_capture_inner( // Validate frame if frame_size < MIN_FRAME_SIZE { - debug!("Dropping small frame: {} bytes (bytesused={})", frame_size, meta.bytesused); + debug!( + "Dropping small frame: {} bytes (bytesused={})", + frame_size, meta.bytesused + ); if let Ok(mut s) = stats.try_lock() { s.frames_dropped += 1; } @@ -606,18 +604,12 @@ impl FrameGrabber { } /// Capture a single frame - pub async fn grab( - &self, - resolution: Resolution, - format: PixelFormat, - ) -> Result { + pub async fn grab(&self, resolution: Resolution, format: PixelFormat) -> Result { let device_path = self.device_path.clone(); - tokio::task::spawn_blocking(move || { - grab_single_frame(&device_path, resolution, format) - }) - .await - .map_err(|e| AppError::VideoError(format!("Grab task failed: {}", e)))? + tokio::task::spawn_blocking(move || grab_single_frame(&device_path, resolution, format)) + .await + .map_err(|e| AppError::VideoError(format!("Grab task failed: {}", e)))? 
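(Editor's sketch, not part of the patch: how the single-frame grab above might be used from an async handler. `grabber` is an assumed, already-constructed FrameGrabber; Resolution::HD720 and PixelFormat::Yuyv are existing crate types used elsewhere in this diff.)

    async fn snapshot(grabber: &FrameGrabber) -> Result<VideoFrame> {
        // grab() moves the blocking V4L2 capture onto spawn_blocking internally,
        // so awaiting it does not stall the async runtime.
        let frame = grabber.grab(Resolution::HD720, PixelFormat::Yuyv).await?;
        tracing::debug!("snapshot: {} bytes ({})", frame.data().len(), frame.format);
        Ok(frame)
    }
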
} } @@ -626,14 +618,13 @@ fn grab_single_frame( resolution: Resolution, format: PixelFormat, ) -> Result { - let device = Device::with_path(device_path).map_err(|e| { - AppError::VideoError(format!("Failed to open device: {}", e)) - })?; + let device = Device::with_path(device_path) + .map_err(|e| AppError::VideoError(format!("Failed to open device: {}", e)))?; let fmt = Format::new(resolution.width, resolution.height, format.to_fourcc()); - let actual = device.set_format(&fmt).map_err(|e| { - AppError::VideoError(format!("Failed to set format: {}", e)) - })?; + let actual = device + .set_format(&fmt) + .map_err(|e| AppError::VideoError(format!("Failed to set format: {}", e)))?; let mut stream = MmapStream::with_buffers(&device, BufferType::VideoCapture, 2) .map_err(|e| AppError::VideoError(format!("Failed to create stream: {}", e)))?; @@ -643,8 +634,7 @@ fn grab_single_frame( match stream.next() { Ok((buf, _meta)) => { if buf.len() >= MIN_FRAME_SIZE { - let actual_format = - PixelFormat::from_fourcc(actual.fourcc).unwrap_or(format); + let actual_format = PixelFormat::from_fourcc(actual.fourcc).unwrap_or(format); return Ok(VideoFrame::new( Bytes::copy_from_slice(buf), @@ -657,16 +647,15 @@ fn grab_single_frame( } Err(e) => { if attempt == 4 { - return Err(AppError::VideoError(format!( - "Failed to grab frame: {}", - e - ))); + return Err(AppError::VideoError(format!("Failed to grab frame: {}", e))); } } } } - Err(AppError::VideoError("Failed to capture valid frame".to_string())) + Err(AppError::VideoError( + "Failed to capture valid frame".to_string(), + )) } #[cfg(test)] diff --git a/src/video/convert.rs b/src/video/convert.rs index ca60b80d..b93b63c1 100644 --- a/src/video/convert.rs +++ b/src/video/convert.rs @@ -233,6 +233,16 @@ impl PixelConverter { } } + /// Create a new converter for NV21 → YUV420P + pub fn nv21_to_yuv420p(resolution: Resolution) -> Self { + Self { + src_format: PixelFormat::Nv21, + dst_format: PixelFormat::Yuv420, + resolution, + output_buffer: Yuv420pBuffer::new(resolution), + } + } + /// Create a new converter for YVU420 → YUV420P (swap U and V planes) pub fn yvu420_to_yuv420p(resolution: Resolution) -> Self { Self { @@ -272,23 +282,39 @@ impl PixelConverter { match (self.src_format, self.dst_format) { (PixelFormat::Yuyv, PixelFormat::Yuv420) => { libyuv::yuy2_to_i420(input, self.output_buffer.as_bytes_mut(), width, height) - .map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?; + .map_err(|e| { + AppError::VideoError(format!("libyuv conversion failed: {}", e)) + })?; } (PixelFormat::Uyvy, PixelFormat::Yuv420) => { libyuv::uyvy_to_i420(input, self.output_buffer.as_bytes_mut(), width, height) - .map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?; + .map_err(|e| { + AppError::VideoError(format!("libyuv conversion failed: {}", e)) + })?; } (PixelFormat::Nv12, PixelFormat::Yuv420) => { libyuv::nv12_to_i420(input, self.output_buffer.as_bytes_mut(), width, height) - .map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?; + .map_err(|e| { + AppError::VideoError(format!("libyuv conversion failed: {}", e)) + })?; + } + (PixelFormat::Nv21, PixelFormat::Yuv420) => { + libyuv::nv21_to_i420(input, self.output_buffer.as_bytes_mut(), width, height) + .map_err(|e| { + AppError::VideoError(format!("libyuv conversion failed: {}", e)) + })?; } (PixelFormat::Rgb24, PixelFormat::Yuv420) => { libyuv::rgb24_to_i420(input, self.output_buffer.as_bytes_mut(), width, height) - .map_err(|e| 
AppError::VideoError(format!("libyuv conversion failed: {}", e)))?; + .map_err(|e| { + AppError::VideoError(format!("libyuv conversion failed: {}", e)) + })?; } (PixelFormat::Bgr24, PixelFormat::Yuv420) => { libyuv::bgr24_to_i420(input, self.output_buffer.as_bytes_mut(), width, height) - .map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?; + .map_err(|e| { + AppError::VideoError(format!("libyuv conversion failed: {}", e)) + })?; } (PixelFormat::Yvyu, PixelFormat::Yuv420) => { // YVYU is not directly supported by libyuv, use software conversion @@ -307,7 +333,9 @@ impl PixelConverter { expected_size ))); } - self.output_buffer.as_bytes_mut().copy_from_slice(&input[..expected_size]); + self.output_buffer + .as_bytes_mut() + .copy_from_slice(&input[..expected_size]); } _ => { return Err(AppError::VideoError(format!( @@ -426,6 +454,8 @@ pub struct Nv12Converter { resolution: Resolution, /// Output buffer (reused across conversions) output_buffer: Nv12Buffer, + /// Optional I420 buffer for intermediate conversions + i420_buffer: Option, } impl Nv12Converter { @@ -435,6 +465,7 @@ impl Nv12Converter { src_format: PixelFormat::Bgr24, resolution, output_buffer: Nv12Buffer::new(resolution), + i420_buffer: None, } } @@ -444,6 +475,7 @@ impl Nv12Converter { src_format: PixelFormat::Rgb24, resolution, output_buffer: Nv12Buffer::new(resolution), + i420_buffer: None, } } @@ -453,6 +485,37 @@ impl Nv12Converter { src_format: PixelFormat::Yuyv, resolution, output_buffer: Nv12Buffer::new(resolution), + i420_buffer: None, + } + } + + /// Create a new converter for YUV420P (I420) → NV12 + pub fn yuv420_to_nv12(resolution: Resolution) -> Self { + Self { + src_format: PixelFormat::Yuv420, + resolution, + output_buffer: Nv12Buffer::new(resolution), + i420_buffer: None, + } + } + + /// Create a new converter for NV21 → NV12 + pub fn nv21_to_nv12(resolution: Resolution) -> Self { + Self { + src_format: PixelFormat::Nv21, + resolution, + output_buffer: Nv12Buffer::new(resolution), + i420_buffer: Some(Yuv420pBuffer::new(resolution)), + } + } + + /// Create a new converter for NV16 → NV12 (downsample chroma vertically) + pub fn nv16_to_nv12(resolution: Resolution) -> Self { + Self { + src_format: PixelFormat::Nv16, + resolution, + output_buffer: Nv12Buffer::new(resolution), + i420_buffer: None, } } @@ -460,12 +523,45 @@ impl Nv12Converter { pub fn convert(&mut self, input: &[u8]) -> Result<&[u8]> { let width = self.resolution.width as i32; let height = self.resolution.height as i32; - let dst = self.output_buffer.as_bytes_mut(); + // Handle formats that need custom conversion without holding dst borrow + match self.src_format { + PixelFormat::Nv21 => { + let mut i420 = self.i420_buffer.take().ok_or_else(|| { + AppError::VideoError("NV21 I420 buffer not initialized".to_string()) + })?; + { + let dst = self.output_buffer.as_bytes_mut(); + Self::convert_nv21_to_nv12_with_dims( + self.resolution.width as usize, + self.resolution.height as usize, + input, + dst, + &mut i420, + )?; + } + self.i420_buffer = Some(i420); + return Ok(self.output_buffer.as_bytes()); + } + PixelFormat::Nv16 => { + let dst = self.output_buffer.as_bytes_mut(); + Self::convert_nv16_to_nv12_with_dims( + self.resolution.width as usize, + self.resolution.height as usize, + input, + dst, + )?; + return Ok(self.output_buffer.as_bytes()); + } + _ => {} + } + + let dst = self.output_buffer.as_bytes_mut(); let result = match self.src_format { PixelFormat::Bgr24 => libyuv::bgr24_to_nv12(input, dst, width, height), 
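(Editor's illustration, not part of the patch: the buffer arithmetic behind the NV16 → NV12 path of this converter. NV16 carries a full-height interleaved UV plane, so the conversion averages each pair of UV rows to halve the chroma height while the luma plane is copied unchanged.)

    #[test]
    fn nv16_to_nv12_sizes() {
        let (w, h) = (1280usize, 720usize);
        let y_size = w * h;        // 921_600 bytes of luma
        let nv16_uv = y_size;      // 921_600 bytes of interleaved UV (4:2:2)
        let nv12_uv = y_size / 2;  // 460_800 bytes after vertical 2:1 averaging
        assert_eq!(y_size + nv16_uv, 1_843_200); // minimum NV16 input length
        assert_eq!(y_size + nv12_uv, 1_382_400); // NV12 output length
    }
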
PixelFormat::Rgb24 => libyuv::rgb24_to_nv12(input, dst, width, height), PixelFormat::Yuyv => libyuv::yuy2_to_nv12(input, dst, width, height), + PixelFormat::Yuv420 => libyuv::i420_to_nv12(input, dst, width, height), _ => { return Err(AppError::VideoError(format!( "Unsupported conversion to NV12: {}", @@ -474,10 +570,71 @@ impl Nv12Converter { } }; - result.map_err(|e| AppError::VideoError(format!("libyuv NV12 conversion failed: {}", e)))?; + result + .map_err(|e| AppError::VideoError(format!("libyuv NV12 conversion failed: {}", e)))?; Ok(self.output_buffer.as_bytes()) } + fn convert_nv21_to_nv12_with_dims( + width: usize, + height: usize, + input: &[u8], + dst: &mut [u8], + yuv: &mut Yuv420pBuffer, + ) -> Result<()> { + libyuv::nv21_to_i420(input, yuv.as_bytes_mut(), width as i32, height as i32) + .map_err(|e| AppError::VideoError(format!("libyuv NV21->I420 failed: {}", e)))?; + libyuv::i420_to_nv12(yuv.as_bytes(), dst, width as i32, height as i32) + .map_err(|e| AppError::VideoError(format!("libyuv I420->NV12 failed: {}", e)))?; + + Ok(()) + } + + fn convert_nv16_to_nv12_with_dims( + width: usize, + height: usize, + input: &[u8], + dst: &mut [u8], + ) -> Result<()> { + let y_size = width * height; + let uv_size_nv16 = y_size; // NV16 chroma plane is full height + let uv_size_nv12 = y_size / 2; + + if input.len() < y_size + uv_size_nv16 { + return Err(AppError::VideoError(format!( + "NV16 data too small: {} < {}", + input.len(), + y_size + uv_size_nv16 + ))); + } + + // Copy Y plane as-is + dst[..y_size].copy_from_slice(&input[..y_size]); + + // Downsample chroma vertically: average pairs of rows + let src_uv = &input[y_size..y_size + uv_size_nv16]; + let dst_uv = &mut dst[y_size..y_size + uv_size_nv12]; + + let src_row_bytes = width; + let dst_row_bytes = width; + let dst_rows = height / 2; + + for row in 0..dst_rows { + let src_row0 = + &src_uv[row * 2 * src_row_bytes..row * 2 * src_row_bytes + src_row_bytes]; + let src_row1 = &src_uv + [(row * 2 + 1) * src_row_bytes..(row * 2 + 1) * src_row_bytes + src_row_bytes]; + let dst_row = &mut dst_uv[row * dst_row_bytes..row * dst_row_bytes + dst_row_bytes]; + + for i in 0..dst_row_bytes { + let sum = src_row0[i] as u16 + src_row1[i] as u16; + dst_row[i] = (sum / 2) as u8; + } + } + + Ok(()) + } + /// Get output buffer length pub fn output_len(&self) -> usize { self.output_buffer.len() @@ -542,10 +699,8 @@ mod tests { // Create YUYV data (4x4 = 32 bytes) let yuyv = vec![ - 16, 128, 17, 129, 18, 130, 19, 131, - 20, 132, 21, 133, 22, 134, 23, 135, - 24, 136, 25, 137, 26, 138, 27, 139, - 28, 140, 29, 141, 30, 142, 31, 143, + 16, 128, 17, 129, 18, 130, 19, 131, 20, 132, 21, 133, 22, 134, 23, 135, 24, 136, 25, + 137, 26, 138, 27, 139, 28, 140, 29, 141, 30, 142, 31, 143, ]; let result = converter.convert(&yuyv).unwrap(); diff --git a/src/video/device.rs b/src/video/device.rs index 58eb4aed..4a655d51 100644 --- a/src/video/device.rs +++ b/src/video/device.rs @@ -95,9 +95,10 @@ impl VideoDevice { /// Get device capabilities pub fn capabilities(&self) -> Result { - let caps = self.device.query_caps().map_err(|e| { - AppError::VideoError(format!("Failed to query capabilities: {}", e)) - })?; + let caps = self + .device + .query_caps() + .map_err(|e| AppError::VideoError(format!("Failed to query capabilities: {}", e)))?; Ok(DeviceCapabilities { video_capture: caps.capabilities.contains(Flags::VIDEO_CAPTURE), @@ -110,9 +111,10 @@ impl VideoDevice { /// Get detailed device information pub fn info(&self) -> Result { - let caps = 
self.device.query_caps().map_err(|e| { - AppError::VideoError(format!("Failed to query capabilities: {}", e)) - })?; + let caps = self + .device + .query_caps() + .map_err(|e| AppError::VideoError(format!("Failed to query capabilities: {}", e)))?; let capabilities = DeviceCapabilities { video_capture: caps.capabilities.contains(Flags::VIDEO_CAPTURE), @@ -128,7 +130,8 @@ impl VideoDevice { let is_capture_card = Self::detect_capture_card(&caps.card, &caps.driver, &formats); // Calculate priority score - let priority = Self::calculate_priority(&caps.card, &caps.driver, &formats, is_capture_card); + let priority = + Self::calculate_priority(&caps.card, &caps.driver, &formats, is_capture_card); Ok(VideoDeviceInfo { path: self.path.clone(), @@ -148,9 +151,10 @@ impl VideoDevice { let mut formats = Vec::new(); // Get supported formats - let format_descs = self.device.enum_formats().map_err(|e| { - AppError::VideoError(format!("Failed to enumerate formats: {}", e)) - })?; + let format_descs = self + .device + .enum_formats() + .map_err(|e| AppError::VideoError(format!("Failed to enumerate formats: {}", e)))?; for desc in format_descs { // Try to convert FourCC to our PixelFormat @@ -186,7 +190,9 @@ impl VideoDevice { for size in sizes { match size.size { v4l::framesize::FrameSizeEnum::Discrete(d) => { - let fps = self.enumerate_fps(fourcc, d.width, d.height).unwrap_or_default(); + let fps = self + .enumerate_fps(fourcc, d.width, d.height) + .unwrap_or_default(); resolutions.push(ResolutionInfo::new(d.width, d.height, fps)); } v4l::framesize::FrameSizeEnum::Stepwise(s) => { @@ -202,8 +208,11 @@ impl VideoDevice { && res.height >= s.min_height && res.height <= s.max_height { - let fps = self.enumerate_fps(fourcc, res.width, res.height).unwrap_or_default(); - resolutions.push(ResolutionInfo::new(res.width, res.height, fps)); + let fps = self + .enumerate_fps(fourcc, res.width, res.height) + .unwrap_or_default(); + resolutions + .push(ResolutionInfo::new(res.width, res.height, fps)); } } } @@ -255,7 +264,7 @@ impl VideoDevice { fps_list.push(30); } } - + fps_list.sort_by(|a, b| b.cmp(a)); fps_list.dedup(); Ok(fps_list) @@ -263,9 +272,9 @@ impl VideoDevice { /// Get current format pub fn get_format(&self) -> Result { - self.device.format().map_err(|e| { - AppError::VideoError(format!("Failed to get format: {}", e)) - }) + self.device + .format() + .map_err(|e| AppError::VideoError(format!("Failed to get format: {}", e))) } /// Set capture format @@ -273,9 +282,10 @@ impl VideoDevice { let fmt = Format::new(width, height, format.to_fourcc()); // Request the format - let actual = self.device.set_format(&fmt).map_err(|e| { - AppError::VideoError(format!("Failed to set format: {}", e)) - })?; + let actual = self + .device + .set_format(&fmt) + .map_err(|e| AppError::VideoError(format!("Failed to set format: {}", e)))?; if actual.width != width || actual.height != height { warn!( @@ -374,9 +384,9 @@ pub fn enumerate_devices() -> Result> { let mut devices = Vec::new(); // Scan /dev/video* devices - for entry in std::fs::read_dir("/dev").map_err(|e| { - AppError::VideoError(format!("Failed to read /dev: {}", e)) - })? { + for entry in std::fs::read_dir("/dev") + .map_err(|e| AppError::VideoError(format!("Failed to read /dev: {}", e)))? 
+ { let entry = match entry { Ok(e) => e, Err(_) => continue, @@ -432,9 +442,10 @@ pub fn enumerate_devices() -> Result> { pub fn find_best_device() -> Result { let devices = enumerate_devices()?; - devices.into_iter().next().ok_or_else(|| { - AppError::VideoError("No video capture devices found".to_string()) - }) + devices + .into_iter() + .next() + .ok_or_else(|| AppError::VideoError("No video capture devices found".to_string())) } #[cfg(test)] diff --git a/src/video/encoder/h264.rs b/src/video/encoder/h264.rs index 21886f97..d02a398b 100644 --- a/src/video/encoder/h264.rs +++ b/src/video/encoder/h264.rs @@ -99,8 +99,18 @@ pub enum H264InputFormat { Yuv420p, /// NV12 - Y plane + interleaved UV plane (optimal for VAAPI) Nv12, + /// NV21 - Y plane + interleaved VU plane + Nv21, + /// NV16 - Y plane + interleaved UV plane (4:2:2) + Nv16, + /// NV24 - Y plane + interleaved UV plane (4:4:4) + Nv24, /// YUYV422 - packed YUV 4:2:2 format (optimal for RKMPP direct input) Yuyv422, + /// RGB24 - packed RGB format (RKMPP direct input) + Rgb24, + /// BGR24 - packed BGR format (RKMPP direct input) + Bgr24, } impl Default for H264InputFormat { @@ -202,7 +212,7 @@ pub fn get_available_encoders(width: u32, height: u32) -> Vec { fps: 30, gop: 30, rc: RateControl::RC_CBR, - quality: Quality::Quality_Low, // Use low quality preset for fastest encoding (ultrafast) + quality: Quality::Quality_Low, // Use low quality preset for fastest encoding (ultrafast) kbs: 2000, q: 23, thread_count: 4, @@ -270,9 +280,8 @@ impl H264Encoder { // Detect best encoder let (_encoder_type, codec_name) = detect_best_encoder(width, height); - let codec_name = codec_name.ok_or_else(|| { - AppError::VideoError("No H.264 encoder available".to_string()) - })?; + let codec_name = codec_name + .ok_or_else(|| AppError::VideoError("No H.264 encoder available".to_string()))?; Self::with_codec(config, &codec_name) } @@ -287,8 +296,13 @@ impl H264Encoder { // Select pixel format based on config let pixfmt = match config.input_format { H264InputFormat::Nv12 => AVPixelFormat::AV_PIX_FMT_NV12, + H264InputFormat::Nv21 => AVPixelFormat::AV_PIX_FMT_NV21, + H264InputFormat::Nv16 => AVPixelFormat::AV_PIX_FMT_NV16, + H264InputFormat::Nv24 => AVPixelFormat::AV_PIX_FMT_NV24, H264InputFormat::Yuv420p => AVPixelFormat::AV_PIX_FMT_YUV420P, H264InputFormat::Yuyv422 => AVPixelFormat::AV_PIX_FMT_YUYV422, + H264InputFormat::Rgb24 => AVPixelFormat::AV_PIX_FMT_RGB24, + H264InputFormat::Bgr24 => AVPixelFormat::AV_PIX_FMT_BGR24, }; info!( @@ -306,10 +320,10 @@ impl H264Encoder { fps: config.fps as i32, gop: config.gop_size as i32, rc: RateControl::RC_CBR, - quality: Quality::Quality_Low, // Use low quality preset for fastest encoding (lowest latency) + quality: Quality::Quality_Low, // Use low quality preset for fastest encoding (lowest latency) kbs: config.bitrate_kbps as i32, q: 23, - thread_count: 4, // Use 4 threads for better performance + thread_count: 4, // Use 4 threads for better performance }; let inner = HwEncoder::new(ctx).map_err(|_| { @@ -353,9 +367,9 @@ impl H264Encoder { /// Update bitrate dynamically pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> { - self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| { - AppError::VideoError("Failed to set bitrate".to_string()) - })?; + self.inner + .set_bitrate(bitrate_kbps as i32) + .map_err(|_| AppError::VideoError("Failed to set bitrate".to_string()))?; self.config.bitrate_kbps = bitrate_kbps; debug!("Bitrate updated to {} kbps", bitrate_kbps); Ok(()) @@ -394,16 +408,7 @@ impl 
H264Encoder { Ok(owned_frames) } Err(e) => { - // For the first ~30 frames, x264 may fail due to initialization - // Log as warning instead of error to avoid alarming users - if self.frame_count <= 30 { - warn!( - "Encode failed during initialization (frame {}): {} - this is normal for x264", - self.frame_count, e - ); - } else { - error!("Encode failed: {}", e); - } + error!("Encode failed: {}", e); Err(AppError::VideoError(format!("Encode failed: {}", e))) } } @@ -458,7 +463,9 @@ impl Encoder for H264Encoder { if frames.is_empty() { // Encoder needs more frames (shouldn't happen with our config) warn!("Encoder returned no frames"); - return Err(AppError::VideoError("Encoder returned no frames".to_string())); + return Err(AppError::VideoError( + "Encoder returned no frames".to_string(), + )); } // Take ownership of the first frame (zero-copy) @@ -493,8 +500,13 @@ impl Encoder for H264Encoder { // Check if the format matches our configured input format match self.config.input_format { H264InputFormat::Nv12 => matches!(format, PixelFormat::Nv12), + H264InputFormat::Nv21 => matches!(format, PixelFormat::Nv21), + H264InputFormat::Nv16 => matches!(format, PixelFormat::Nv16), + H264InputFormat::Nv24 => matches!(format, PixelFormat::Nv24), H264InputFormat::Yuv420p => matches!(format, PixelFormat::Yuv420), H264InputFormat::Yuyv422 => matches!(format, PixelFormat::Yuyv), + H264InputFormat::Rgb24 => matches!(format, PixelFormat::Rgb24), + H264InputFormat::Bgr24 => matches!(format, PixelFormat::Bgr24), } } } @@ -538,7 +550,11 @@ mod tests { let config = H264Config::low_latency(Resolution::HD720, 2000); match H264Encoder::new(config) { Ok(encoder) => { - println!("Created encoder: {} ({})", encoder.codec_name(), encoder.encoder_type()); + println!( + "Created encoder: {} ({})", + encoder.codec_name(), + encoder.encoder_type() + ); } Err(e) => { println!("Failed to create encoder: {}", e); diff --git a/src/video/encoder/h265.rs b/src/video/encoder/h265.rs index 9fe82b39..8a89015d 100644 --- a/src/video/encoder/h265.rs +++ b/src/video/encoder/h265.rs @@ -92,8 +92,18 @@ pub enum H265InputFormat { Yuv420p, /// NV12 - Y plane + interleaved UV plane (optimal for hardware encoders) Nv12, + /// NV21 - Y plane + interleaved VU plane + Nv21, + /// NV16 - Y plane + interleaved UV plane (4:2:2) + Nv16, + /// NV24 - Y plane + interleaved UV plane (4:4:4) + Nv24, /// YUYV422 - packed YUV 4:2:2 format (optimal for RKMPP direct input) Yuyv422, + /// RGB24 - packed RGB format (RKMPP direct input) + Rgb24, + /// BGR24 - packed BGR format (RKMPP direct input) + Bgr24, } impl Default for H265InputFormat { @@ -252,10 +262,7 @@ pub fn detect_best_h265_encoder(width: u32, height: u32) -> (H265EncoderType, Op H265EncoderType::Software // Default to software for unknown }; - info!( - "Selected H.265 encoder: {} ({})", - codec.name, encoder_type - ); + info!("Selected H.265 encoder: {} ({})", codec.name, encoder_type); (encoder_type, Some(codec.name.clone())) } @@ -304,7 +311,8 @@ impl H265Encoder { if encoder_type == H265EncoderType::None { return Err(AppError::VideoError( - "No H.265 encoder available. Please ensure FFmpeg is built with libx265 support.".to_string(), + "No H.265 encoder available. Please ensure FFmpeg is built with libx265 support." 
+ .to_string(), )); } @@ -336,8 +344,17 @@ impl H265Encoder { } else { match config.input_format { H265InputFormat::Nv12 => (AVPixelFormat::AV_PIX_FMT_NV12, H265InputFormat::Nv12), - H265InputFormat::Yuv420p => (AVPixelFormat::AV_PIX_FMT_YUV420P, H265InputFormat::Yuv420p), - H265InputFormat::Yuyv422 => (AVPixelFormat::AV_PIX_FMT_YUYV422, H265InputFormat::Yuyv422), + H265InputFormat::Nv21 => (AVPixelFormat::AV_PIX_FMT_NV21, H265InputFormat::Nv21), + H265InputFormat::Nv16 => (AVPixelFormat::AV_PIX_FMT_NV16, H265InputFormat::Nv16), + H265InputFormat::Nv24 => (AVPixelFormat::AV_PIX_FMT_NV24, H265InputFormat::Nv24), + H265InputFormat::Yuv420p => { + (AVPixelFormat::AV_PIX_FMT_YUV420P, H265InputFormat::Yuv420p) + } + H265InputFormat::Yuyv422 => { + (AVPixelFormat::AV_PIX_FMT_YUYV422, H265InputFormat::Yuyv422) + } + H265InputFormat::Rgb24 => (AVPixelFormat::AV_PIX_FMT_RGB24, H265InputFormat::Rgb24), + H265InputFormat::Bgr24 => (AVPixelFormat::AV_PIX_FMT_BGR24, H265InputFormat::Bgr24), } }; @@ -407,9 +424,9 @@ impl H265Encoder { /// Update bitrate dynamically pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> { - self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| { - AppError::VideoError("Failed to set H.265 bitrate".to_string()) - })?; + self.inner + .set_bitrate(bitrate_kbps as i32) + .map_err(|_| AppError::VideoError("Failed to set H.265 bitrate".to_string()))?; self.config.bitrate_kbps = bitrate_kbps; debug!("H.265 bitrate updated to {} kbps", bitrate_kbps); Ok(()) @@ -464,7 +481,10 @@ impl H265Encoder { if keyframe || self.frame_count % 30 == 1 { debug!( "[H265] Encoded frame #{}: output_size={}, keyframe={}, frame_count={}", - self.frame_count, total_size, keyframe, owned_frames.len() + self.frame_count, + total_size, + keyframe, + owned_frames.len() ); // Log first few bytes of keyframe for debugging @@ -477,7 +497,10 @@ impl H265Encoder { } } } else { - warn!("[H265] Encoder returned empty frame list for frame #{}", self.frame_count); + warn!( + "[H265] Encoder returned empty frame list for frame #{}", + self.frame_count + ); } Ok(owned_frames) @@ -567,8 +590,13 @@ impl Encoder for H265Encoder { fn supports_format(&self, format: PixelFormat) -> bool { match self.config.input_format { H265InputFormat::Nv12 => matches!(format, PixelFormat::Nv12), + H265InputFormat::Nv21 => matches!(format, PixelFormat::Nv21), + H265InputFormat::Nv16 => matches!(format, PixelFormat::Nv16), + H265InputFormat::Nv24 => matches!(format, PixelFormat::Nv24), H265InputFormat::Yuv420p => matches!(format, PixelFormat::Yuv420), H265InputFormat::Yuyv422 => matches!(format, PixelFormat::Yuyv), + H265InputFormat::Rgb24 => matches!(format, PixelFormat::Rgb24), + H265InputFormat::Bgr24 => matches!(format, PixelFormat::Bgr24), } } } @@ -580,7 +608,10 @@ mod tests { #[test] fn test_detect_h265_encoder() { let (encoder_type, codec_name) = detect_best_h265_encoder(1280, 720); - println!("Detected H.265 encoder: {:?} ({:?})", encoder_type, codec_name); + println!( + "Detected H.265 encoder: {:?} ({:?})", + encoder_type, codec_name + ); } #[test] diff --git a/src/video/encoder/jpeg.rs b/src/video/encoder/jpeg.rs index b4a54f58..9b3f4c78 100644 --- a/src/video/encoder/jpeg.rs +++ b/src/video/encoder/jpeg.rs @@ -35,10 +35,12 @@ impl JpegEncoder { // I420: Y = width*height, U = width*height/4, V = width*height/4 let i420_size = width * height * 3 / 2; - let mut compressor = turbojpeg::Compressor::new() - .map_err(|e| AppError::VideoError(format!("Failed to create turbojpeg compressor: {}", e)))?; + let mut 
compressor = turbojpeg::Compressor::new().map_err(|e| { + AppError::VideoError(format!("Failed to create turbojpeg compressor: {}", e)) + })?; - compressor.set_quality(config.quality.min(100) as i32) + compressor + .set_quality(config.quality.min(100) as i32) .map_err(|e| AppError::VideoError(format!("Failed to set JPEG quality: {}", e)))?; Ok(Self { @@ -56,7 +58,8 @@ impl JpegEncoder { /// Set JPEG quality (1-100) pub fn set_quality(&mut self, quality: u32) -> Result<()> { - self.compressor.set_quality(quality.min(100) as i32) + self.compressor + .set_quality(quality.min(100) as i32) .map_err(|e| AppError::VideoError(format!("Failed to set JPEG quality: {}", e)))?; self.config.quality = quality; Ok(()) @@ -73,12 +76,14 @@ impl JpegEncoder { pixels: self.i420_buffer.as_slice(), width, height, - align: 1, // No padding between rows + align: 1, // No padding between rows subsamp: turbojpeg::Subsamp::Sub2x2, // YUV 4:2:0 }; // Compress YUV directly to JPEG (skips color space conversion!) - let jpeg_data = self.compressor.compress_yuv_to_vec(yuv_image) + let jpeg_data = self + .compressor + .compress_yuv_to_vec(yuv_image) .map_err(|e| AppError::VideoError(format!("JPEG compression failed: {}", e)))?; Ok(EncodedFrame::jpeg( diff --git a/src/video/encoder/mod.rs b/src/video/encoder/mod.rs index ac9a4432..a0e10fb3 100644 --- a/src/video/encoder/mod.rs +++ b/src/video/encoder/mod.rs @@ -19,7 +19,9 @@ pub mod vp8; pub mod vp9; // Core traits and types -pub use traits::{BitratePreset, EncodedFormat, EncodedFrame, Encoder, EncoderConfig, EncoderFactory}; +pub use traits::{ + BitratePreset, EncodedFormat, EncodedFrame, Encoder, EncoderConfig, EncoderFactory, +}; // WebRTC codec abstraction pub use codec::{CodecFrame, VideoCodec, VideoCodecConfig, VideoCodecFactory, VideoCodecType}; diff --git a/src/video/encoder/registry.rs b/src/video/encoder/registry.rs index f76545ed..edcb7780 100644 --- a/src/video/encoder/registry.rs +++ b/src/video/encoder/registry.rs @@ -264,10 +264,7 @@ impl EncoderRegistry { if let Some(encoder) = AvailableEncoder::from_codec_info(codec_info) { debug!( "Detected encoder: {} ({}) - {} priority={}", - encoder.codec_name, - encoder.format, - encoder.backend, - encoder.priority + encoder.codec_name, encoder.format, encoder.backend, encoder.priority ); self.encoders @@ -336,13 +333,15 @@ impl EncoderRegistry { format: VideoEncoderType, hardware_only: bool, ) -> Option<&AvailableEncoder> { - self.encoders.get(&format)?.iter().find(|e| { - if hardware_only { - e.is_hardware - } else { - true - } - }) + self.encoders.get(&format)?.iter().find( + |e| { + if hardware_only { + e.is_hardware + } else { + true + } + }, + ) } /// Get all encoders for a format @@ -523,9 +522,6 @@ mod tests { // Should have detected at least H264 (software fallback available) println!("Available formats: {:?}", registry.available_formats(false)); - println!( - "Selectable formats: {:?}", - registry.selectable_formats() - ); + println!("Selectable formats: {:?}", registry.selectable_formats()); } } diff --git a/src/video/encoder/traits.rs b/src/video/encoder/traits.rs index a9f96688..940ec245 100644 --- a/src/video/encoder/traits.rs +++ b/src/video/encoder/traits.rs @@ -5,8 +5,8 @@ use serde::{Deserialize, Serialize}; use std::time::Instant; use typeshare::typeshare; -use crate::video::format::{PixelFormat, Resolution}; use crate::error::Result; +use crate::video::format::{PixelFormat, Resolution}; /// Bitrate preset for video encoding /// @@ -46,10 +46,10 @@ impl BitratePreset { /// Quality preset uses 
longer GOP for better compression efficiency. pub fn gop_size(&self, fps: u32) -> u32 { match self { - Self::Speed => (fps / 2).max(15), // 0.5 second, minimum 15 frames - Self::Balanced => fps, // 1 second - Self::Quality => fps * 2, // 2 seconds - Self::Custom(_) => fps, // Default 1 second for custom + Self::Speed => (fps / 2).max(15), // 0.5 second, minimum 15 frames + Self::Balanced => fps, // 1 second + Self::Quality => fps * 2, // 2 seconds + Self::Custom(_) => fps, // Default 1 second for custom } } diff --git a/src/video/encoder/vp8.rs b/src/video/encoder/vp8.rs index a9af912c..868af8ee 100644 --- a/src/video/encoder/vp8.rs +++ b/src/video/encoder/vp8.rs @@ -186,10 +186,7 @@ pub fn detect_best_vp8_encoder(width: u32, height: u32) -> (VP8EncoderType, Opti VP8EncoderType::Software // Default to software for unknown }; - info!( - "Selected VP8 encoder: {} ({})", - codec.name, encoder_type - ); + info!("Selected VP8 encoder: {} ({})", codec.name, encoder_type); (encoder_type, Some(codec.name.clone())) } @@ -238,7 +235,8 @@ impl VP8Encoder { if encoder_type == VP8EncoderType::None { return Err(AppError::VideoError( - "No VP8 encoder available. Please ensure FFmpeg is built with libvpx support.".to_string(), + "No VP8 encoder available. Please ensure FFmpeg is built with libvpx support." + .to_string(), )); } @@ -270,7 +268,9 @@ impl VP8Encoder { } else { match config.input_format { VP8InputFormat::Nv12 => (AVPixelFormat::AV_PIX_FMT_NV12, VP8InputFormat::Nv12), - VP8InputFormat::Yuv420p => (AVPixelFormat::AV_PIX_FMT_YUV420P, VP8InputFormat::Yuv420p), + VP8InputFormat::Yuv420p => { + (AVPixelFormat::AV_PIX_FMT_YUV420P, VP8InputFormat::Yuv420p) + } } }; @@ -340,9 +340,9 @@ impl VP8Encoder { /// Update bitrate dynamically pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> { - self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| { - AppError::VideoError("Failed to set VP8 bitrate".to_string()) - })?; + self.inner + .set_bitrate(bitrate_kbps as i32) + .map_err(|_| AppError::VideoError("Failed to set VP8 bitrate".to_string()))?; self.config.bitrate_kbps = bitrate_kbps; debug!("VP8 bitrate updated to {} kbps", bitrate_kbps); Ok(()) @@ -470,7 +470,10 @@ mod tests { #[test] fn test_detect_vp8_encoder() { let (encoder_type, codec_name) = detect_best_vp8_encoder(1280, 720); - println!("Detected VP8 encoder: {:?} ({:?})", encoder_type, codec_name); + println!( + "Detected VP8 encoder: {:?} ({:?})", + encoder_type, codec_name + ); } #[test] diff --git a/src/video/encoder/vp9.rs b/src/video/encoder/vp9.rs index 3725388b..6995db5d 100644 --- a/src/video/encoder/vp9.rs +++ b/src/video/encoder/vp9.rs @@ -186,10 +186,7 @@ pub fn detect_best_vp9_encoder(width: u32, height: u32) -> (VP9EncoderType, Opti VP9EncoderType::Software // Default to software for unknown }; - info!( - "Selected VP9 encoder: {} ({})", - codec.name, encoder_type - ); + info!("Selected VP9 encoder: {} ({})", codec.name, encoder_type); (encoder_type, Some(codec.name.clone())) } @@ -238,7 +235,8 @@ impl VP9Encoder { if encoder_type == VP9EncoderType::None { return Err(AppError::VideoError( - "No VP9 encoder available. Please ensure FFmpeg is built with libvpx support.".to_string(), + "No VP9 encoder available. Please ensure FFmpeg is built with libvpx support." 
+ .to_string(), )); } @@ -270,7 +268,9 @@ impl VP9Encoder { } else { match config.input_format { VP9InputFormat::Nv12 => (AVPixelFormat::AV_PIX_FMT_NV12, VP9InputFormat::Nv12), - VP9InputFormat::Yuv420p => (AVPixelFormat::AV_PIX_FMT_YUV420P, VP9InputFormat::Yuv420p), + VP9InputFormat::Yuv420p => { + (AVPixelFormat::AV_PIX_FMT_YUV420P, VP9InputFormat::Yuv420p) + } } }; @@ -340,9 +340,9 @@ impl VP9Encoder { /// Update bitrate dynamically pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> { - self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| { - AppError::VideoError("Failed to set VP9 bitrate".to_string()) - })?; + self.inner + .set_bitrate(bitrate_kbps as i32) + .map_err(|_| AppError::VideoError("Failed to set VP9 bitrate".to_string()))?; self.config.bitrate_kbps = bitrate_kbps; debug!("VP9 bitrate updated to {} kbps", bitrate_kbps); Ok(()) @@ -470,7 +470,10 @@ mod tests { #[test] fn test_detect_vp9_encoder() { let (encoder_type, codec_name) = detect_best_vp9_encoder(1280, 720); - println!("Detected VP9 encoder: {:?} ({:?})", encoder_type, codec_name); + println!( + "Detected VP9 encoder: {:?} ({:?})", + encoder_type, codec_name + ); } #[test] diff --git a/src/video/format.rs b/src/video/format.rs index 05d014d6..4097ae6f 100644 --- a/src/video/format.rs +++ b/src/video/format.rs @@ -20,6 +20,8 @@ pub enum PixelFormat { Uyvy, /// NV12 semi-planar format (Y plane + interleaved UV) Nv12, + /// NV21 semi-planar format (Y plane + interleaved VU) + Nv21, /// NV16 semi-planar format Nv16, /// NV24 semi-planar format @@ -48,6 +50,7 @@ impl PixelFormat { PixelFormat::Yvyu => fourcc::FourCC::new(b"YVYU"), PixelFormat::Uyvy => fourcc::FourCC::new(b"UYVY"), PixelFormat::Nv12 => fourcc::FourCC::new(b"NV12"), + PixelFormat::Nv21 => fourcc::FourCC::new(b"NV21"), PixelFormat::Nv16 => fourcc::FourCC::new(b"NV16"), PixelFormat::Nv24 => fourcc::FourCC::new(b"NV24"), PixelFormat::Yuv420 => fourcc::FourCC::new(b"YU12"), @@ -69,6 +72,7 @@ impl PixelFormat { b"YVYU" => Some(PixelFormat::Yvyu), b"UYVY" => Some(PixelFormat::Uyvy), b"NV12" => Some(PixelFormat::Nv12), + b"NV21" => Some(PixelFormat::Nv21), b"NV16" => Some(PixelFormat::Nv16), b"NV24" => Some(PixelFormat::Nv24), b"YU12" | b"I420" => Some(PixelFormat::Yuv420), @@ -92,7 +96,9 @@ impl PixelFormat { match self { PixelFormat::Mjpeg | PixelFormat::Jpeg => None, PixelFormat::Yuyv | PixelFormat::Yvyu | PixelFormat::Uyvy => Some(2), - PixelFormat::Nv12 | PixelFormat::Yuv420 | PixelFormat::Yvu420 => None, // Variable + PixelFormat::Nv12 | PixelFormat::Nv21 | PixelFormat::Yuv420 | PixelFormat::Yvu420 => { + None + } // Variable PixelFormat::Nv16 => None, PixelFormat::Nv24 => None, PixelFormat::Rgb565 => Some(2), @@ -108,7 +114,9 @@ impl PixelFormat { match self { PixelFormat::Mjpeg | PixelFormat::Jpeg => None, PixelFormat::Yuyv | PixelFormat::Yvyu | PixelFormat::Uyvy => Some(pixels * 2), - PixelFormat::Nv12 | PixelFormat::Yuv420 | PixelFormat::Yvu420 => Some(pixels * 3 / 2), + PixelFormat::Nv12 | PixelFormat::Nv21 | PixelFormat::Yuv420 | PixelFormat::Yvu420 => { + Some(pixels * 3 / 2) + } PixelFormat::Nv16 => Some(pixels * 2), PixelFormat::Nv24 => Some(pixels * 3), PixelFormat::Rgb565 => Some(pixels * 2), @@ -125,6 +133,7 @@ impl PixelFormat { PixelFormat::Jpeg => 99, PixelFormat::Yuyv => 80, PixelFormat::Nv12 => 75, + PixelFormat::Nv21 => 74, PixelFormat::Yuv420 => 70, PixelFormat::Uyvy => 65, PixelFormat::Yvyu => 64, @@ -144,7 +153,10 @@ impl PixelFormat { /// Software encoding prefers: YUYV > NV12 /// /// Returns None if no suitable format is 
available - pub fn recommended_for_encoding(available: &[PixelFormat], is_hardware: bool) -> Option { + pub fn recommended_for_encoding( + available: &[PixelFormat], + is_hardware: bool, + ) -> Option { if is_hardware { // Hardware encoding: NV12 > YUYV if available.contains(&PixelFormat::Nv12) { @@ -175,6 +187,7 @@ impl PixelFormat { PixelFormat::Yvyu, PixelFormat::Uyvy, PixelFormat::Nv12, + PixelFormat::Nv21, PixelFormat::Nv16, PixelFormat::Nv24, PixelFormat::Yuv420, @@ -196,6 +209,7 @@ impl fmt::Display for PixelFormat { PixelFormat::Yvyu => "YVYU", PixelFormat::Uyvy => "UYVY", PixelFormat::Nv12 => "NV12", + PixelFormat::Nv21 => "NV21", PixelFormat::Nv16 => "NV16", PixelFormat::Nv24 => "NV24", PixelFormat::Yuv420 => "YUV420", @@ -220,6 +234,7 @@ impl std::str::FromStr for PixelFormat { "YVYU" => Ok(PixelFormat::Yvyu), "UYVY" => Ok(PixelFormat::Uyvy), "NV12" => Ok(PixelFormat::Nv12), + "NV21" => Ok(PixelFormat::Nv21), "NV16" => Ok(PixelFormat::Nv16), "NV24" => Ok(PixelFormat::Nv24), "YUV420" | "I420" => Ok(PixelFormat::Yuv420), diff --git a/src/video/frame.rs b/src/video/frame.rs index 131e5928..cd66796e 100644 --- a/src/video/frame.rs +++ b/src/video/frame.rs @@ -106,9 +106,9 @@ impl VideoFrame { /// Get hash of frame data (computed once, cached) /// Used for fast frame deduplication comparison pub fn get_hash(&self) -> u64 { - *self.hash.get_or_init(|| { - xxhash_rust::xxh64::xxh64(self.data.as_ref(), 0) - }) + *self + .hash + .get_or_init(|| xxhash_rust::xxh64::xxh64(self.data.as_ref(), 0)) } /// Check if format is JPEG/MJPEG diff --git a/src/video/h264_pipeline.rs b/src/video/h264_pipeline.rs index 6c6c20b5..ddc47d8f 100644 --- a/src/video/h264_pipeline.rs +++ b/src/video/h264_pipeline.rs @@ -93,10 +93,7 @@ impl H264Pipeline { pub fn new(config: H264PipelineConfig) -> Result { info!( "Creating H264 pipeline: {}x{} @ {} kbps, {} fps", - config.resolution.width, - config.resolution.height, - config.bitrate_kbps, - config.fps + config.resolution.width, config.resolution.height, config.bitrate_kbps, config.fps ); // Determine encoder input format based on pipeline input @@ -154,7 +151,7 @@ impl H264Pipeline { // MJPEG/JPEG input - not supported (requires libjpeg for decoding) PixelFormat::Mjpeg | PixelFormat::Jpeg => { return Err(AppError::VideoError( - "MJPEG input format not supported in this build".to_string() + "MJPEG input format not supported in this build".to_string(), )); } @@ -216,7 +213,10 @@ impl H264Pipeline { } let _ = self.running.send(true); - info!("Starting H264 pipeline (input format: {})", self.config.input_format); + info!( + "Starting H264 pipeline (input format: {})", + self.config.input_format + ); let encoder = self.encoder.lock().await.take(); let nv12_converter = self.nv12_converter.lock().await.take(); diff --git a/src/video/mod.rs b/src/video/mod.rs index d024e1ce..b5664f48 100644 --- a/src/video/mod.rs +++ b/src/video/mod.rs @@ -18,11 +18,15 @@ pub mod video_session; pub use capture::VideoCapturer; pub use convert::{PixelConverter, Yuv420pBuffer}; pub use device::{VideoDevice, VideoDeviceInfo}; -pub use encoder::{JpegEncoder, H264Encoder, H264EncoderType}; +pub use encoder::{H264Encoder, H264EncoderType, JpegEncoder}; pub use format::PixelFormat; pub use frame::VideoFrame; pub use h264_pipeline::{H264Pipeline, H264PipelineBuilder, H264PipelineConfig}; -pub use shared_video_pipeline::{EncodedVideoFrame, SharedVideoPipeline, SharedVideoPipelineConfig, SharedVideoPipelineStats}; +pub use shared_video_pipeline::{ + EncodedVideoFrame, SharedVideoPipeline, 
SharedVideoPipelineConfig, SharedVideoPipelineStats, +}; pub use stream_manager::VideoStreamManager; pub use streamer::{Streamer, StreamerState}; -pub use video_session::{VideoSessionManager, VideoSessionManagerConfig, VideoSessionInfo, VideoSessionState, CodecInfo}; +pub use video_session::{ + CodecInfo, VideoSessionInfo, VideoSessionManager, VideoSessionManagerConfig, VideoSessionState, +}; diff --git a/src/video/shared_video_pipeline.rs b/src/video/shared_video_pipeline.rs index 0d5b5699..2ed46296 100644 --- a/src/video/shared_video_pipeline.rs +++ b/src/video/shared_video_pipeline.rs @@ -28,8 +28,10 @@ const AUTO_STOP_GRACE_PERIOD_SECS: u64 = 3; use crate::error::{AppError, Result}; use crate::video::convert::{Nv12Converter, PixelConverter}; -use crate::video::encoder::h264::{H264Config, H264Encoder}; -use crate::video::encoder::h265::{H265Config, H265Encoder}; +use crate::video::encoder::h264::{detect_best_encoder, H264Config, H264Encoder, H264InputFormat}; +use crate::video::encoder::h265::{ + detect_best_h265_encoder, H265Config, H265Encoder, H265InputFormat, +}; use crate::video::encoder::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType}; use crate::video::encoder::traits::EncoderConfig; use crate::video::encoder::vp8::{VP8Config, VP8Encoder}; @@ -157,7 +159,6 @@ pub struct SharedVideoPipelineStats { pub subscribers: u64, } - /// Universal video encoder trait object #[allow(dead_code)] trait VideoEncoderTrait: Send { @@ -300,7 +301,7 @@ pub struct SharedVideoPipeline { /// Whether the encoder needs YUV420P (true) or NV12 (false) encoder_needs_yuv420p: AtomicBool, /// Whether YUYV direct input is enabled (RKMPP optimization) - yuyv_direct_input: AtomicBool, + direct_input: AtomicBool, frame_tx: broadcast::Sender, stats: Mutex, running: watch::Sender, @@ -326,7 +327,7 @@ impl SharedVideoPipeline { config.input_format ); - let (frame_tx, _) = broadcast::channel(16); // Reduced from 64 for lower latency + let (frame_tx, _) = broadcast::channel(16); // Reduced from 64 for lower latency let (running_tx, running_rx) = watch::channel(false); let pipeline = Arc::new(Self { @@ -335,7 +336,7 @@ impl SharedVideoPipeline { nv12_converter: Mutex::new(None), yuv420p_converter: Mutex::new(None), encoder_needs_yuv420p: AtomicBool::new(false), - yuyv_direct_input: AtomicBool::new(false), + direct_input: AtomicBool::new(false), frame_tx, stats: Mutex::new(SharedVideoPipelineStats::default()), running: running_tx, @@ -354,29 +355,108 @@ impl SharedVideoPipeline { let registry = EncoderRegistry::global(); // Helper to get codec name for specific backend - let get_codec_name = |format: VideoEncoderType, backend: Option| -> Option { - match backend { - Some(b) => registry.encoder_with_backend(format, b).map(|e| e.codec_name.clone()), - None => registry.best_encoder(format, false).map(|e| e.codec_name.clone()), - } - }; + let get_codec_name = + |format: VideoEncoderType, backend: Option| -> Option { + match backend { + Some(b) => registry + .encoder_with_backend(format, b) + .map(|e| e.codec_name.clone()), + None => registry + .best_encoder(format, false) + .map(|e| e.codec_name.clone()), + } + }; - // Check if RKMPP backend is available for YUYV direct input optimization - let is_rkmpp_available = registry.encoder_with_backend(VideoEncoderType::H264, EncoderBackend::Rkmpp).is_some(); + // Check if RKMPP backend is available for direct input optimization + let is_rkmpp_available = registry + .encoder_with_backend(VideoEncoderType::H264, EncoderBackend::Rkmpp) + .is_some(); let use_yuyv_direct 
= is_rkmpp_available && config.input_format == PixelFormat::Yuyv; + let use_rkmpp_direct = is_rkmpp_available + && matches!( + config.input_format, + PixelFormat::Yuyv + | PixelFormat::Yuv420 + | PixelFormat::Rgb24 + | PixelFormat::Bgr24 + | PixelFormat::Nv12 + | PixelFormat::Nv16 + | PixelFormat::Nv21 + | PixelFormat::Nv24 + ); if use_yuyv_direct { - info!("RKMPP backend detected with YUYV input, enabling YUYV direct input optimization"); + info!( + "RKMPP backend detected with YUYV input, enabling YUYV direct input optimization" + ); + } else if use_rkmpp_direct { + info!( + "RKMPP backend detected with {} input, enabling direct input optimization", + config.input_format + ); } // Create encoder based on codec type let encoder: Box = match config.output_codec { VideoEncoderType::H264 => { - // Determine H264 input format based on backend and input format - let h264_input_format = if use_yuyv_direct { - crate::video::encoder::h264::H264InputFormat::Yuyv422 + let codec_name = if use_rkmpp_direct { + // Force RKMPP backend for direct input + get_codec_name(VideoEncoderType::H264, Some(EncoderBackend::Rkmpp)).ok_or_else( + || { + AppError::VideoError( + "RKMPP backend not available for H.264".to_string(), + ) + }, + )? + } else if let Some(ref backend) = config.encoder_backend { + // Specific backend requested + get_codec_name(VideoEncoderType::H264, Some(*backend)).ok_or_else(|| { + AppError::VideoError(format!( + "Backend {:?} does not support H.264", + backend + )) + })? } else { - crate::video::encoder::h264::H264InputFormat::Nv12 + // Auto select best available encoder + let (_encoder_type, detected) = + detect_best_encoder(config.resolution.width, config.resolution.height); + detected.ok_or_else(|| { + AppError::VideoError("No H.264 encoder available".to_string()) + })? + }; + + let is_rkmpp = codec_name.contains("rkmpp"); + let direct_input_format = if is_rkmpp { + match config.input_format { + PixelFormat::Yuyv => Some(H264InputFormat::Yuyv422), + PixelFormat::Yuv420 => Some(H264InputFormat::Yuv420p), + PixelFormat::Rgb24 => Some(H264InputFormat::Rgb24), + PixelFormat::Bgr24 => Some(H264InputFormat::Bgr24), + PixelFormat::Nv12 => Some(H264InputFormat::Nv12), + PixelFormat::Nv16 => Some(H264InputFormat::Nv16), + PixelFormat::Nv21 => Some(H264InputFormat::Nv21), + PixelFormat::Nv24 => Some(H264InputFormat::Nv24), + _ => None, + } + } else if codec_name.contains("libx264") { + match config.input_format { + PixelFormat::Nv12 => Some(H264InputFormat::Nv12), + PixelFormat::Nv16 => Some(H264InputFormat::Nv16), + PixelFormat::Nv21 => Some(H264InputFormat::Nv21), + PixelFormat::Yuv420 => Some(H264InputFormat::Yuv420p), + _ => None, + } + } else { + None + }; + + // Choose input format: prefer direct input when supported + let h264_input_format = if let Some(fmt) = direct_input_format { + fmt + } else if codec_name.contains("libx264") { + H264InputFormat::Yuv420p + } else { + H264InputFormat::Nv12 }; let encoder_config = H264Config { @@ -387,69 +467,124 @@ impl SharedVideoPipeline { input_format: h264_input_format, }; - let encoder = if use_yuyv_direct { - // Force RKMPP backend for YUYV direct input - let codec_name = get_codec_name(VideoEncoderType::H264, Some(EncoderBackend::Rkmpp)) - .ok_or_else(|| AppError::VideoError( - "RKMPP backend not available for H.264".to_string() - ))?; - info!("Creating H264 encoder with RKMPP backend for YUYV direct input (codec: {})", codec_name); - H264Encoder::with_codec(encoder_config, &codec_name)? 
+ if use_rkmpp_direct { + info!( + "Creating H264 encoder with RKMPP backend for {} direct input (codec: {})", + config.input_format, codec_name + ); } else if let Some(ref backend) = config.encoder_backend { - // Specific backend requested - let codec_name = get_codec_name(VideoEncoderType::H264, Some(*backend)) - .ok_or_else(|| AppError::VideoError(format!( - "Backend {:?} does not support H.264", backend - )))?; - info!("Creating H264 encoder with backend {:?} (codec: {})", backend, codec_name); - H264Encoder::with_codec(encoder_config, &codec_name)? - } else { - // Auto select - H264Encoder::new(encoder_config)? - }; + info!( + "Creating H264 encoder with backend {:?} (codec: {})", + backend, codec_name + ); + } + + let encoder = H264Encoder::with_codec(encoder_config, &codec_name)?; info!("Created H264 encoder: {}", encoder.codec_name()); Box::new(H264EncoderWrapper(encoder)) } VideoEncoderType::H265 => { - // Determine H265 input format based on backend and input format - let encoder_config = if use_yuyv_direct { - H265Config::low_latency_yuyv422(config.resolution, config.bitrate_kbps()) + let codec_name = if use_rkmpp_direct { + get_codec_name(VideoEncoderType::H265, Some(EncoderBackend::Rkmpp)).ok_or_else( + || { + AppError::VideoError( + "RKMPP backend not available for H.265".to_string(), + ) + }, + )? + } else if let Some(ref backend) = config.encoder_backend { + get_codec_name(VideoEncoderType::H265, Some(*backend)).ok_or_else(|| { + AppError::VideoError(format!( + "Backend {:?} does not support H.265", + backend + )) + })? } else { - H265Config::low_latency(config.resolution, config.bitrate_kbps()) + let (_encoder_type, detected) = + detect_best_h265_encoder(config.resolution.width, config.resolution.height); + detected.ok_or_else(|| { + AppError::VideoError("No H.265 encoder available".to_string()) + })? }; - let encoder = if use_yuyv_direct { - // Force RKMPP backend for YUYV direct input - let codec_name = get_codec_name(VideoEncoderType::H265, Some(EncoderBackend::Rkmpp)) - .ok_or_else(|| AppError::VideoError( - "RKMPP backend not available for H.265".to_string() - ))?; - info!("Creating H265 encoder with RKMPP backend for YUYV direct input (codec: {})", codec_name); - H265Encoder::with_codec(encoder_config, &codec_name)? - } else if let Some(ref backend) = config.encoder_backend { - let codec_name = get_codec_name(VideoEncoderType::H265, Some(*backend)) - .ok_or_else(|| AppError::VideoError(format!( - "Backend {:?} does not support H.265", backend - )))?; - info!("Creating H265 encoder with backend {:?} (codec: {})", backend, codec_name); - H265Encoder::with_codec(encoder_config, &codec_name)? + let is_rkmpp = codec_name.contains("rkmpp"); + let direct_input_format = if is_rkmpp { + match config.input_format { + PixelFormat::Yuyv => Some(H265InputFormat::Yuyv422), + PixelFormat::Yuv420 => Some(H265InputFormat::Yuv420p), + PixelFormat::Rgb24 => Some(H265InputFormat::Rgb24), + PixelFormat::Bgr24 => Some(H265InputFormat::Bgr24), + PixelFormat::Nv12 => Some(H265InputFormat::Nv12), + PixelFormat::Nv16 => Some(H265InputFormat::Nv16), + PixelFormat::Nv21 => Some(H265InputFormat::Nv21), + PixelFormat::Nv24 => Some(H265InputFormat::Nv24), + _ => None, + } + } else if codec_name.contains("libx265") { + match config.input_format { + PixelFormat::Yuv420 => Some(H265InputFormat::Yuv420p), + _ => None, + } } else { - H265Encoder::new(encoder_config)? 
+ None }; + let h265_input_format = if let Some(fmt) = direct_input_format { + fmt + } else if codec_name.contains("libx265") { + H265InputFormat::Yuv420p + } else { + H265InputFormat::Nv12 + }; + + let encoder_config = H265Config { + base: EncoderConfig { + resolution: config.resolution, + input_format: config.input_format, + quality: config.bitrate_kbps(), + fps: config.fps, + gop_size: config.gop_size(), + }, + bitrate_kbps: config.bitrate_kbps(), + gop_size: config.gop_size(), + fps: config.fps, + input_format: h265_input_format, + }; + + if use_rkmpp_direct { + info!( + "Creating H265 encoder with RKMPP backend for {} direct input (codec: {})", + config.input_format, codec_name + ); + } else if let Some(ref backend) = config.encoder_backend { + info!( + "Creating H265 encoder with backend {:?} (codec: {})", + backend, codec_name + ); + } + + let encoder = H265Encoder::with_codec(encoder_config, &codec_name)?; + info!("Created H265 encoder: {}", encoder.codec_name()); Box::new(H265EncoderWrapper(encoder)) } VideoEncoderType::VP8 => { - let encoder_config = VP8Config::low_latency(config.resolution, config.bitrate_kbps()); + let encoder_config = + VP8Config::low_latency(config.resolution, config.bitrate_kbps()); let encoder = if let Some(ref backend) = config.encoder_backend { let codec_name = get_codec_name(VideoEncoderType::VP8, Some(*backend)) - .ok_or_else(|| AppError::VideoError(format!( - "Backend {:?} does not support VP8", backend - )))?; - info!("Creating VP8 encoder with backend {:?} (codec: {})", backend, codec_name); + .ok_or_else(|| { + AppError::VideoError(format!( + "Backend {:?} does not support VP8", + backend + )) + })?; + info!( + "Creating VP8 encoder with backend {:?} (codec: {})", + backend, codec_name + ); VP8Encoder::with_codec(encoder_config, &codec_name)? } else { VP8Encoder::new(encoder_config)? @@ -459,14 +594,21 @@ impl SharedVideoPipeline { Box::new(VP8EncoderWrapper(encoder)) } VideoEncoderType::VP9 => { - let encoder_config = VP9Config::low_latency(config.resolution, config.bitrate_kbps()); + let encoder_config = + VP9Config::low_latency(config.resolution, config.bitrate_kbps()); let encoder = if let Some(ref backend) = config.encoder_backend { let codec_name = get_codec_name(VideoEncoderType::VP9, Some(*backend)) - .ok_or_else(|| AppError::VideoError(format!( - "Backend {:?} does not support VP9", backend - )))?; - info!("Creating VP9 encoder with backend {:?} (codec: {})", backend, codec_name); + .ok_or_else(|| { + AppError::VideoError(format!( + "Backend {:?} does not support VP9", + backend + )) + })?; + info!( + "Creating VP9 encoder with backend {:?} (codec: {})", + backend, codec_name + ); VP9Encoder::with_codec(encoder_config, &codec_name)? } else { VP9Encoder::new(encoder_config)? 
@@ -477,25 +619,71 @@ impl SharedVideoPipeline { } }; - // Determine if encoder needs YUV420P (software encoders) or NV12 (hardware encoders) + // Determine if encoder can take direct input without conversion let codec_name = encoder.codec_name(); - let needs_yuv420p = codec_name.contains("libvpx") || codec_name.contains("libx265"); + let use_direct_input = if codec_name.contains("rkmpp") { + matches!( + config.input_format, + PixelFormat::Yuyv + | PixelFormat::Yuv420 + | PixelFormat::Rgb24 + | PixelFormat::Bgr24 + | PixelFormat::Nv12 + | PixelFormat::Nv16 + | PixelFormat::Nv21 + | PixelFormat::Nv24 + ) + } else if codec_name.contains("libx264") { + matches!( + config.input_format, + PixelFormat::Nv12 | PixelFormat::Nv16 | PixelFormat::Nv21 | PixelFormat::Yuv420 + ) + } else { + false + }; + + // Determine if encoder needs YUV420P (software encoders) or NV12 (hardware encoders) + let needs_yuv420p = if codec_name.contains("libx264") { + !matches!( + config.input_format, + PixelFormat::Nv12 | PixelFormat::Nv16 | PixelFormat::Nv21 | PixelFormat::Yuv420 + ) + } else { + codec_name.contains("libvpx") || codec_name.contains("libx265") + }; info!( "Encoder {} needs {} format", codec_name, - if use_yuyv_direct { "YUYV422 (direct)" } else if needs_yuv420p { "YUV420P" } else { "NV12" } + if use_direct_input { + "direct" + } else if needs_yuv420p { + "YUV420P" + } else { + "NV12" + } ); // Create converter or decoder based on input format and encoder needs - info!("Initializing input format handler for: {} -> {}", - config.input_format, - if use_yuyv_direct { "YUYV422 (direct)" } else if needs_yuv420p { "YUV420P" } else { "NV12" }); + info!( + "Initializing input format handler for: {} -> {}", + config.input_format, + if use_direct_input { + "direct" + } else if needs_yuv420p { + "YUV420P" + } else { + "NV12" + } + ); let (nv12_converter, yuv420p_converter) = if use_yuyv_direct { // RKMPP with YUYV direct input - skip all conversion info!("YUYV direct input enabled for RKMPP, skipping format conversion"); (None, None) + } else if use_direct_input { + info!("Direct input enabled, skipping format conversion"); + (None, None) } else if needs_yuv420p { // Software encoder needs YUV420P match config.input_format { @@ -505,19 +693,38 @@ impl SharedVideoPipeline { } PixelFormat::Yuyv => { info!("Using YUYV->YUV420P converter"); - (None, Some(PixelConverter::yuyv_to_yuv420p(config.resolution))) + ( + None, + Some(PixelConverter::yuyv_to_yuv420p(config.resolution)), + ) } PixelFormat::Nv12 => { info!("Using NV12->YUV420P converter"); - (None, Some(PixelConverter::nv12_to_yuv420p(config.resolution))) + ( + None, + Some(PixelConverter::nv12_to_yuv420p(config.resolution)), + ) + } + PixelFormat::Nv21 => { + info!("Using NV21->YUV420P converter"); + ( + None, + Some(PixelConverter::nv21_to_yuv420p(config.resolution)), + ) } PixelFormat::Rgb24 => { info!("Using RGB24->YUV420P converter"); - (None, Some(PixelConverter::rgb24_to_yuv420p(config.resolution))) + ( + None, + Some(PixelConverter::rgb24_to_yuv420p(config.resolution)), + ) } PixelFormat::Bgr24 => { info!("Using BGR24->YUV420P converter"); - (None, Some(PixelConverter::bgr24_to_yuv420p(config.resolution))) + ( + None, + Some(PixelConverter::bgr24_to_yuv420p(config.resolution)), + ) } _ => { return Err(AppError::VideoError(format!( @@ -537,6 +744,18 @@ impl SharedVideoPipeline { info!("Using YUYV->NV12 converter"); (Some(Nv12Converter::yuyv_to_nv12(config.resolution)), None) } + PixelFormat::Nv21 => { + info!("Using NV21->NV12 converter"); + 
(Some(Nv12Converter::nv21_to_nv12(config.resolution)), None) + } + PixelFormat::Nv16 => { + info!("Using NV16->NV12 converter"); + (Some(Nv12Converter::nv16_to_nv12(config.resolution)), None) + } + PixelFormat::Yuv420 => { + info!("Using YUV420P->NV12 converter"); + (Some(Nv12Converter::yuv420_to_nv12(config.resolution)), None) + } PixelFormat::Rgb24 => { info!("Using RGB24->NV12 converter"); (Some(Nv12Converter::rgb24_to_nv12(config.resolution)), None) @@ -557,8 +776,9 @@ impl SharedVideoPipeline { *self.encoder.lock().await = Some(encoder); *self.nv12_converter.lock().await = nv12_converter; *self.yuv420p_converter.lock().await = yuv420p_converter; - self.encoder_needs_yuv420p.store(needs_yuv420p, Ordering::Release); - self.yuyv_direct_input.store(use_yuyv_direct, Ordering::Release); + self.encoder_needs_yuv420p + .store(needs_yuv420p, Ordering::Release); + self.direct_input.store(use_direct_input, Ordering::Release); Ok(()) } @@ -646,7 +866,10 @@ impl SharedVideoPipeline { } /// Start the pipeline - pub async fn start(self: &Arc, mut frame_rx: broadcast::Receiver) -> Result<()> { + pub async fn start( + self: &Arc, + mut frame_rx: broadcast::Receiver, + ) -> Result<()> { if *self.running_rx.borrow() { warn!("Pipeline already running"); return Ok(()); @@ -657,7 +880,10 @@ impl SharedVideoPipeline { let config = self.config.read().await.clone(); let gop_size = config.gop_size(); - info!("Starting {} pipeline (GOP={})", config.output_codec, gop_size); + info!( + "Starting {} pipeline (GOP={})", + config.output_codec, gop_size + ); let pipeline = self.clone(); @@ -674,7 +900,6 @@ impl SharedVideoPipeline { let mut local_errors: u64 = 0; let mut local_dropped: u64 = 0; let mut local_skipped: u64 = 0; - // Track when we last had subscribers for auto-stop feature let mut no_subscribers_since: Option = None; let grace_period = Duration::from_secs(AUTO_STOP_GRACE_PERIOD_SECS); @@ -790,7 +1015,11 @@ impl SharedVideoPipeline { } /// Encode a single frame - async fn encode_frame(&self, frame: &VideoFrame, frame_count: u64) -> Result> { + async fn encode_frame( + &self, + frame: &VideoFrame, + frame_count: u64, + ) -> Result> { let config = self.config.read().await; let raw_frame = frame.data(); let fps = config.fps; @@ -835,9 +1064,9 @@ impl SharedVideoPipeline { let needs_yuv420p = self.encoder_needs_yuv420p.load(Ordering::Acquire); let mut encoder_guard = self.encoder.lock().await; - let encoder = encoder_guard.as_mut().ok_or_else(|| { - AppError::VideoError("Encoder not initialized".to_string()) - })?; + let encoder = encoder_guard + .as_mut() + .ok_or_else(|| AppError::VideoError("Encoder not initialized".to_string()))?; // Check and consume keyframe request (atomic, no lock contention) if self.keyframe_requested.swap(false, Ordering::AcqRel) { @@ -848,13 +1077,15 @@ impl SharedVideoPipeline { let encode_result = if needs_yuv420p && yuv420p_converter.is_some() { // Software encoder with direct input conversion to YUV420P let conv = yuv420p_converter.as_mut().unwrap(); - let yuv420p_data = conv.convert(raw_frame) + let yuv420p_data = conv + .convert(raw_frame) .map_err(|e| AppError::VideoError(format!("YUV420P conversion failed: {}", e)))?; encoder.encode_raw(yuv420p_data, pts_ms) } else if nv12_converter.is_some() { // Hardware encoder with input conversion to NV12 let conv = nv12_converter.as_mut().unwrap(); - let nv12_data = conv.convert(raw_frame) + let nv12_data = conv + .convert(raw_frame) .map_err(|e| AppError::VideoError(format!("NV12 conversion failed: {}", e)))?; 
encoder.encode_raw(nv12_data, pts_ms) } else { @@ -871,7 +1102,6 @@ impl SharedVideoPipeline { if !frames.is_empty() { let encoded = frames.into_iter().next().unwrap(); let is_keyframe = encoded.key == 1; - let sequence = self.sequence.fetch_add(1, Ordering::Relaxed) + 1; // Debug log for H265 encoded frame @@ -901,17 +1131,23 @@ impl SharedVideoPipeline { })) } else { if codec == VideoEncoderType::H265 { - warn!("[Pipeline-H265] Encoder returned no frames for frame #{}", frame_count); + warn!( + "[Pipeline-H265] Encoder returned no frames for frame #{}", + frame_count + ); } Ok(None) } } Err(e) => { if codec == VideoEncoderType::H265 { - error!("[Pipeline-H265] Encode error at frame #{}: {}", frame_count, e); + error!( + "[Pipeline-H265] Encode error at frame #{}: {}", + frame_count, e + ); } Err(e) - }, + } } } @@ -924,7 +1160,10 @@ impl SharedVideoPipeline { } /// Set bitrate using preset - pub async fn set_bitrate_preset(&self, preset: crate::video::encoder::BitratePreset) -> Result<()> { + pub async fn set_bitrate_preset( + &self, + preset: crate::video::encoder::BitratePreset, + ) -> Result<()> { let bitrate_kbps = preset.bitrate_kbps(); if let Some(ref mut encoder) = *self.encoder.lock().await { encoder.set_bitrate(bitrate_kbps)?; @@ -965,11 +1204,7 @@ fn parse_h265_nal_types(data: &[u8]) -> Vec<(u8, usize)> { && data[i + 3] == 1 { i + 4 - } else if i + 3 <= data.len() - && data[i] == 0 - && data[i + 1] == 0 - && data[i + 2] == 1 - { + } else if i + 3 <= data.len() && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 { i + 3 } else { i += 1; diff --git a/src/video/stream_manager.rs b/src/video/stream_manager.rs index 1fdaa718..3dc56398 100644 --- a/src/video/stream_manager.rs +++ b/src/video/stream_manager.rs @@ -30,6 +30,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::sync::RwLock; use tracing::{debug, error, info, warn}; +use uuid::Uuid; use crate::config::{ConfigStore, StreamMode}; use crate::error::Result; @@ -55,6 +56,17 @@ pub struct StreamManagerConfig { pub fps: u32, } +/// Result of a mode switch request. +#[derive(Debug, Clone)] +pub struct ModeSwitchTransaction { + /// Whether this request started a new switch. + pub accepted: bool, + /// Whether a switch is currently in progress after handling this request. + pub switching: bool, + /// Transition ID if a switch is/was in progress. 
+    pub transition_id: Option<String>,
+}
+
 impl Default for StreamManagerConfig {
     fn default() -> Self {
         Self {
@@ -90,6 +102,8 @@ pub struct VideoStreamManager {
     config_store: RwLock>,
     /// Mode switching lock to prevent concurrent switch requests
     switching: AtomicBool,
+    /// Current mode switch transaction ID (set while switching=true)
+    transition_id: RwLock<Option<String>>,
 }
 
 impl VideoStreamManager {
@@ -105,6 +119,7 @@ impl VideoStreamManager {
             events: RwLock::new(None),
             config_store: RwLock::new(None),
             switching: AtomicBool::new(false),
+            transition_id: RwLock::new(None),
         })
     }
 
@@ -113,6 +128,11 @@ impl VideoStreamManager {
         self.switching.load(Ordering::SeqCst)
     }
 
+    /// Get current mode switch transition ID, if any
+    pub async fn current_transition_id(&self) -> Option<String> {
+        self.transition_id.read().await.clone()
+    }
+
     /// Set event bus for notifications
     pub async fn set_event_bus(&self, events: Arc) {
         *self.events.write().await = Some(events);
@@ -188,7 +208,9 @@ impl VideoStreamManager {
                "Reconnecting frame source to WebRTC after init: {}x{} {:?} @ {}fps (receiver_count={})",
                resolution.width, resolution.height, format, fps, frame_tx.receiver_count()
            );
-            self.webrtc_streamer.update_video_config(resolution, format, fps).await;
+            self.webrtc_streamer
+                .update_video_config(resolution, format, fps)
+                .await;
            self.webrtc_streamer.set_video_source(frame_tx).await;
        }
 
@@ -204,6 +226,18 @@ impl VideoStreamManager {
    /// 4. Start the new mode (ensuring video capture runs for WebRTC)
    /// 5. Update configuration
    pub async fn switch_mode(self: &Arc<Self>, new_mode: StreamMode) -> Result<()> {
+        let _ = self.switch_mode_transaction(new_mode).await?;
+        Ok(())
+    }
+
+    /// Switch streaming mode with a transaction ID for correlating events
+    ///
+    /// If a switch is already in progress, returns `accepted=false` with the
+    /// current `transition_id` (if known) and does not start a new switch.
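+    ///
+    /// A minimal caller sketch (hypothetical; assumes `manager` is an
+    /// `Arc<VideoStreamManager>` and the caller observes the `StreamModeReady`
+    /// event elsewhere, e.g. over the event WebSocket):
+    /// ```ignore
+    /// let tx = manager.switch_mode_transaction(StreamMode::WebRTC).await?;
+    /// if tx.accepted {
+    ///     // Wait for the StreamModeReady event carrying tx.transition_id
+    /// }
+    /// ```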
+ pub async fn switch_mode_transaction( + self: &Arc, + new_mode: StreamMode, + ) -> Result { let current_mode = self.mode.read().await.clone(); if current_mode == new_mode { @@ -212,19 +246,85 @@ impl VideoStreamManager { if new_mode == StreamMode::WebRTC { self.ensure_video_capture_running().await?; } - return Ok(()); + return Ok(ModeSwitchTransaction { + accepted: false, + switching: false, + transition_id: None, + }); } // Acquire switching lock - prevent concurrent switch requests - if self.switching.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err() { + if self + .switching + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_err() + { debug!("Mode switch already in progress, ignoring duplicate request"); - return Ok(()); + return Ok(ModeSwitchTransaction { + accepted: false, + switching: true, + transition_id: self.transition_id.read().await.clone(), + }); } - // Use a helper to ensure we release the lock when done - let result = self.do_switch_mode(current_mode, new_mode.clone()).await; - self.switching.store(false, Ordering::SeqCst); - result + let transition_id = Uuid::new_v4().to_string(); + *self.transition_id.write().await = Some(transition_id.clone()); + + // Publish transaction start event + let from_mode_str = self.mode_to_string(¤t_mode).await; + let to_mode_str = self.mode_to_string(&new_mode).await; + self.publish_event(SystemEvent::StreamModeSwitching { + transition_id: transition_id.clone(), + to_mode: to_mode_str, + from_mode: from_mode_str, + }) + .await; + + // Perform the switch asynchronously so the HTTP handler can return + // immediately and clients can reliably wait for WebSocket events. + let manager = Arc::clone(self); + let transition_id_for_task = transition_id.clone(); + tokio::spawn(async move { + let result = manager + .do_switch_mode(current_mode, new_mode, transition_id_for_task.clone()) + .await; + + if let Err(e) = result { + error!( + "Mode switch transaction {} failed: {}", + transition_id_for_task, e + ); + } + + // Publish transaction end marker with best-effort actual mode + let actual_mode = manager.mode.read().await.clone(); + let actual_mode_str = manager.mode_to_string(&actual_mode).await; + manager + .publish_event(SystemEvent::StreamModeReady { + transition_id: transition_id_for_task.clone(), + mode: actual_mode_str, + }) + .await; + + *manager.transition_id.write().await = None; + manager.switching.store(false, Ordering::SeqCst); + }); + + Ok(ModeSwitchTransaction { + accepted: true, + switching: true, + transition_id: Some(transition_id), + }) + } + + async fn mode_to_string(&self, mode: &StreamMode) -> String { + match mode { + StreamMode::Mjpeg => "mjpeg".to_string(), + StreamMode::WebRTC => { + let codec = self.webrtc_streamer.current_video_codec().await; + codec_to_string(codec) + } + } } /// Ensure video capture is running (for WebRTC mode) @@ -257,7 +357,9 @@ impl VideoStreamManager { "Reconnecting frame source to WebRTC: {}x{} {:?} @ {}fps", resolution.width, resolution.height, format, fps ); - self.webrtc_streamer.update_video_config(resolution, format, fps).await; + self.webrtc_streamer + .update_video_config(resolution, format, fps) + .await; self.webrtc_streamer.set_video_source(frame_tx).await; } @@ -265,7 +367,12 @@ impl VideoStreamManager { } /// Internal implementation of mode switching (called with lock held) - async fn do_switch_mode(self: &Arc, current_mode: StreamMode, new_mode: StreamMode) -> Result<()> { + async fn do_switch_mode( + self: &Arc, + current_mode: StreamMode, 
+ new_mode: StreamMode, + transition_id: String, + ) -> Result<()> { info!("Switching video mode: {:?} -> {:?}", current_mode, new_mode); // Get the actual mode strings (with codec info for WebRTC) @@ -286,6 +393,7 @@ impl VideoStreamManager { // 1. Publish mode change event (clients should prepare to reconnect) self.publish_event(SystemEvent::StreamModeChanged { + transition_id: Some(transition_id.clone()), mode: new_mode_str, previous_mode: previous_mode_str, }) @@ -320,15 +428,26 @@ impl VideoStreamManager { // Auto-switch to MJPEG format if device supports it if let Some(device) = self.streamer.current_device().await { - let (current_format, resolution, fps) = self.streamer.current_video_config().await; - let available_formats: Vec = device.formats.iter().map(|f| f.format).collect(); + let (current_format, resolution, fps) = + self.streamer.current_video_config().await; + let available_formats: Vec = + device.formats.iter().map(|f| f.format).collect(); // If current format is not MJPEG and device supports MJPEG, switch to it - if current_format != PixelFormat::Mjpeg && available_formats.contains(&PixelFormat::Mjpeg) { + if current_format != PixelFormat::Mjpeg + && available_formats.contains(&PixelFormat::Mjpeg) + { info!("Auto-switching to MJPEG format for MJPEG mode"); let device_path = device.path.to_string_lossy().to_string(); - if let Err(e) = self.streamer.apply_video_config(&device_path, PixelFormat::Mjpeg, resolution, fps).await { - warn!("Failed to auto-switch to MJPEG format: {}, keeping current format", e); + if let Err(e) = self + .streamer + .apply_video_config(&device_path, PixelFormat::Mjpeg, resolution, fps) + .await + { + warn!( + "Failed to auto-switch to MJPEG format: {}, keeping current format", + e + ); } } } @@ -353,21 +472,29 @@ impl VideoStreamManager { // Auto-switch to non-compressed format if current format is MJPEG/JPEG if let Some(device) = self.streamer.current_device().await { - let (current_format, resolution, fps) = self.streamer.current_video_config().await; + let (current_format, resolution, fps) = + self.streamer.current_video_config().await; if current_format.is_compressed() { - let available_formats: Vec = device.formats.iter().map(|f| f.format).collect(); + let available_formats: Vec = + device.formats.iter().map(|f| f.format).collect(); // Determine if using hardware encoding let is_hardware = self.webrtc_streamer.is_hardware_encoding().await; - if let Some(recommended) = PixelFormat::recommended_for_encoding(&available_formats, is_hardware) { + if let Some(recommended) = + PixelFormat::recommended_for_encoding(&available_formats, is_hardware) + { info!( "Auto-switching from {:?} to {:?} for WebRTC encoding (hardware={})", current_format, recommended, is_hardware ); let device_path = device.path.to_string_lossy().to_string(); - if let Err(e) = self.streamer.apply_video_config(&device_path, recommended, resolution, fps).await { + if let Err(e) = self + .streamer + .apply_video_config(&device_path, recommended, resolution, fps) + .await + { warn!("Failed to auto-switch format for WebRTC: {}, keeping current format", e); } } @@ -394,33 +521,24 @@ impl VideoStreamManager { "Connecting frame source to WebRTC pipeline: {}x{} {:?} @ {}fps", resolution.width, resolution.height, format, fps ); - self.webrtc_streamer.update_video_config(resolution, format, fps).await; + self.webrtc_streamer + .update_video_config(resolution, format, fps) + .await; self.webrtc_streamer.set_video_source(frame_tx).await; - // Get device path for events - let device_path = 
self.streamer.current_device().await - .map(|d| d.path.to_string_lossy().to_string()) - .unwrap_or_default(); - - // Publish StreamConfigApplied event - clients can now safely connect - self.publish_event(SystemEvent::StreamConfigApplied { - device: device_path, - resolution: (resolution.width, resolution.height), - format: format!("{:?}", format).to_lowercase(), - fps, - }) - .await; - // Publish WebRTCReady event - frame source is now connected let codec = self.webrtc_streamer.current_video_codec().await; let is_hardware = self.webrtc_streamer.is_hardware_encoding().await; self.publish_event(SystemEvent::WebRTCReady { + transition_id: Some(transition_id.clone()), codec: codec_to_string(codec), hardware: is_hardware, }) .await; } else { - warn!("No frame source available for WebRTC - sessions may fail to receive video"); + warn!( + "No frame source available for WebRTC - sessions may fail to receive video" + ); } info!("WebRTC mode activated (sessions created on-demand)"); @@ -483,13 +601,16 @@ impl VideoStreamManager { if let Some(frame_tx) = self.streamer.frame_sender().await { // Note: update_video_config was already called above with the requested config, // but verify that actual capture matches - let (actual_format, actual_resolution, actual_fps) = self.streamer.current_video_config().await; + let (actual_format, actual_resolution, actual_fps) = + self.streamer.current_video_config().await; if actual_format != format || actual_resolution != resolution || actual_fps != fps { info!( "Actual capture config differs from requested, updating WebRTC: {}x{} {:?} @ {}fps", actual_resolution.width, actual_resolution.height, actual_format, actual_fps ); - self.webrtc_streamer.update_video_config(actual_resolution, actual_format, actual_fps).await; + self.webrtc_streamer + .update_video_config(actual_resolution, actual_format, actual_fps) + .await; } info!("Reconnecting frame source to WebRTC after config change"); self.webrtc_streamer.set_video_source(frame_tx).await; @@ -522,7 +643,9 @@ impl VideoStreamManager { if let Some(frame_tx) = self.streamer.frame_sender().await { // Synchronize WebRTC config with actual capture format let (format, resolution, fps) = self.streamer.current_video_config().await; - self.webrtc_streamer.update_video_config(resolution, format, fps).await; + self.webrtc_streamer + .update_video_config(resolution, format, fps) + .await; self.webrtc_streamer.set_video_source(frame_tx).await; } } @@ -620,7 +743,9 @@ impl VideoStreamManager { // ========================================================================= /// List available video devices - pub async fn list_devices(&self) -> crate::error::Result> { + pub async fn list_devices( + &self, + ) -> crate::error::Result> { self.streamer.list_devices().await } @@ -640,7 +765,9 @@ impl VideoStreamManager { } /// Get frame sender for video frames - pub async fn frame_sender(&self) -> Option> { + pub async fn frame_sender( + &self, + ) -> Option> { self.streamer.frame_sender().await } @@ -654,12 +781,17 @@ impl VideoStreamManager { /// Returns None if video capture cannot be started or pipeline creation fails. pub async fn subscribe_encoded_frames( &self, - ) -> Option> { + ) -> Option< + tokio::sync::broadcast::Receiver, + > { // 1. 
Ensure video capture is initialized if self.streamer.state().await == StreamerState::Uninitialized { tracing::info!("Initializing video capture for encoded frame subscription"); if let Err(e) = self.streamer.init_auto().await { - tracing::error!("Failed to initialize video capture for encoded frames: {}", e); + tracing::error!( + "Failed to initialize video capture for encoded frames: {}", + e + ); return None; } } @@ -688,13 +820,22 @@ impl VideoStreamManager { let (format, resolution, fps) = self.streamer.current_video_config().await; tracing::info!( "Connecting encoded frame subscription: {}x{} {:?} @ {}fps", - resolution.width, resolution.height, format, fps + resolution.width, + resolution.height, + format, + fps ); - self.webrtc_streamer.update_video_config(resolution, format, fps).await; + self.webrtc_streamer + .update_video_config(resolution, format, fps) + .await; // 5. Use WebRtcStreamer to ensure the shared video pipeline is running // This will create the pipeline if needed - match self.webrtc_streamer.ensure_video_pipeline_for_external(frame_tx).await { + match self + .webrtc_streamer + .ensure_video_pipeline_for_external(frame_tx) + .await + { Ok(pipeline) => Some(pipeline.subscribe()), Err(e) => { tracing::error!("Failed to start shared video pipeline: {}", e); @@ -704,7 +845,9 @@ impl VideoStreamManager { } /// Get the current video encoding configuration from the shared pipeline - pub async fn get_encoding_config(&self) -> Option { + pub async fn get_encoding_config( + &self, + ) -> Option { self.webrtc_streamer.get_pipeline_config().await } @@ -712,7 +855,10 @@ impl VideoStreamManager { /// /// This allows external consumers (like RustDesk) to set the video codec /// before subscribing to encoded frames. - pub async fn set_video_codec(&self, codec: crate::video::encoder::VideoCodecType) -> crate::error::Result<()> { + pub async fn set_video_codec( + &self, + codec: crate::video::encoder::VideoCodecType, + ) -> crate::error::Result<()> { self.webrtc_streamer.set_video_codec(codec).await } @@ -720,7 +866,10 @@ impl VideoStreamManager { /// /// This allows external consumers (like RustDesk) to adjust the video quality /// based on client preferences. 
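+    ///
+    /// A minimal sketch of an external caller (hypothetical; `client_preset`
+    /// stands in for whatever `BitratePreset` value the consumer negotiated):
+    /// ```ignore
+    /// stream_manager.set_bitrate_preset(client_preset).await?;
+    /// ```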
- pub async fn set_bitrate_preset(&self, preset: crate::video::encoder::BitratePreset) -> crate::error::Result<()> { + pub async fn set_bitrate_preset( + &self, + preset: crate::video::encoder::BitratePreset, + ) -> crate::error::Result<()> { self.webrtc_streamer.set_bitrate_preset(preset).await } diff --git a/src/video/streamer.rs b/src/video/streamer.rs index 92b60fa9..7e0ab5ae 100644 --- a/src/video/streamer.rs +++ b/src/video/streamer.rs @@ -133,7 +133,12 @@ impl Streamer { /// Get current state as SystemEvent pub async fn current_state_event(&self) -> SystemEvent { let state = *self.state.read().await; - let device = self.current_device.read().await.as_ref().map(|d| d.path.display().to_string()); + let device = self + .current_device + .read() + .await + .as_ref() + .map(|d| d.path.display().to_string()); SystemEvent::StreamStateChanged { state: match state { @@ -162,7 +167,8 @@ impl Streamer { /// Check if config is currently being changed /// When true, auto-start should be blocked to prevent device busy errors pub fn is_config_changing(&self) -> bool { - self.config_changing.load(std::sync::atomic::Ordering::SeqCst) + self.config_changing + .load(std::sync::atomic::Ordering::SeqCst) } /// Get MJPEG handler for stream endpoints @@ -209,13 +215,17 @@ impl Streamer { fps: u32, ) -> Result<()> { // Set config_changing flag to prevent frontend mode sync during config change - self.config_changing.store(true, std::sync::atomic::Ordering::SeqCst); + self.config_changing + .store(true, std::sync::atomic::Ordering::SeqCst); - let result = self.apply_video_config_inner(device_path, format, resolution, fps).await; + let result = self + .apply_video_config_inner(device_path, format, resolution, fps) + .await; // Clear the flag after config change is complete // The stream will be started by MJPEG client connection, not here - self.config_changing.store(false, std::sync::atomic::Ordering::SeqCst); + self.config_changing + .store(false, std::sync::atomic::Ordering::SeqCst); result } @@ -230,6 +240,7 @@ impl Streamer { ) -> Result<()> { // Publish "config changing" event self.publish_event(SystemEvent::StreamConfigChanging { + transition_id: None, reason: "device_switch".to_string(), }) .await; @@ -254,7 +265,9 @@ impl Streamer { .iter() .any(|r| r.width == resolution.width && r.height == resolution.height) { - return Err(AppError::VideoError("Requested resolution not supported".to_string())); + return Err(AppError::VideoError( + "Requested resolution not supported".to_string(), + )); } // IMPORTANT: Disconnect all MJPEG clients FIRST before stopping capture @@ -277,7 +290,6 @@ impl Streamer { // Explicitly drop the capturer to release V4L2 resources drop(capturer); } - } // Update config @@ -305,9 +317,12 @@ impl Streamer { *self.state.write().await = StreamerState::Ready; // Publish "config applied" event - info!("Publishing StreamConfigApplied event: {}x{} {:?} @ {}fps", - resolution.width, resolution.height, format, fps); + info!( + "Publishing StreamConfigApplied event: {}x{} {:?} @ {}fps", + resolution.width, resolution.height, format, fps + ); self.publish_event(SystemEvent::StreamConfigApplied { + transition_id: None, device: device_path.to_string(), resolution: (resolution.width, resolution.height), format: format!("{:?}", format), @@ -381,7 +396,11 @@ impl Streamer { } /// Select best format for device - fn select_format(&self, device: &VideoDeviceInfo, preferred: PixelFormat) -> Result { + fn select_format( + &self, + device: &VideoDeviceInfo, + preferred: PixelFormat, + ) -> Result { 
// Check if preferred format is available if device.formats.iter().any(|f| f.format == preferred) { return Ok(preferred); @@ -410,9 +429,10 @@ impl Streamer { // Check if preferred resolution is available if format_info.resolutions.is_empty() - || format_info.resolutions.iter().any(|r| { - r.width == preferred.width && r.height == preferred.height - }) + || format_info + .resolutions + .iter() + .any(|r| r.width == preferred.width && r.height == preferred.height) { return Ok(preferred); } @@ -528,7 +548,10 @@ impl Streamer { // Stop the streamer if let Some(streamer) = state_ref.upgrade() { if let Err(e) = streamer.stop().await { - warn!("Failed to stop streamer during idle cleanup: {}", e); + warn!( + "Failed to stop streamer during idle cleanup: {}", + e + ); } } break; @@ -609,8 +632,14 @@ impl Streamer { // Start background tasks only once per Streamer instance // Use compare_exchange to atomically check and set the flag - if self.background_tasks_started - .compare_exchange(false, true, std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst) + if self + .background_tasks_started + .compare_exchange( + false, + true, + std::sync::atomic::Ordering::SeqCst, + std::sync::atomic::Ordering::SeqCst, + ) .is_ok() { info!("Starting background tasks (stats, cleanup, monitor)"); @@ -626,10 +655,12 @@ impl Streamer { let clients_stat = streamer.mjpeg_handler().get_clients_stat(); let clients = clients_stat.len() as u64; - streamer.publish_event(SystemEvent::StreamStatsUpdate { - clients, - clients_stat, - }).await; + streamer + .publish_event(SystemEvent::StreamStatsUpdate { + clients, + clients_stat, + }) + .await; } else { break; } @@ -649,7 +680,9 @@ impl Streamer { loop { interval.tick().await; - let Some(streamer) = monitor_ref.upgrade() else { break; }; + let Some(streamer) = monitor_ref.upgrade() else { + break; + }; // Check auto-pause configuration let config = monitor_handler.auto_pause_config(); @@ -663,10 +696,16 @@ impl Streamer { if count == 0 { if zero_since.is_none() { zero_since = Some(std::time::Instant::now()); - info!("No clients connected, starting shutdown timer ({}s)", config.shutdown_delay_secs); + info!( + "No clients connected, starting shutdown timer ({}s)", + config.shutdown_delay_secs + ); } else if let Some(since) = zero_since { if since.elapsed().as_secs() >= config.shutdown_delay_secs { - info!("Auto-pausing stream (no clients for {}s)", config.shutdown_delay_secs); + info!( + "Auto-pausing stream (no clients for {}s)", + config.shutdown_delay_secs + ); if let Err(e) = streamer.stop().await { error!("Auto-pause failed: {}", e); } @@ -734,8 +773,14 @@ impl Streamer { clients: self.mjpeg_handler.client_count(), target_fps: config.fps, fps: capture_stats.as_ref().map(|s| s.current_fps).unwrap_or(0.0), - frames_captured: capture_stats.as_ref().map(|s| s.frames_captured).unwrap_or(0), - frames_dropped: capture_stats.as_ref().map(|s| s.frames_dropped).unwrap_or(0), + frames_captured: capture_stats + .as_ref() + .map(|s| s.frames_captured) + .unwrap_or(0), + frames_dropped: capture_stats + .as_ref() + .map(|s| s.frames_dropped) + .unwrap_or(0), } } @@ -776,7 +821,10 @@ impl Streamer { /// until the device is recovered. 
async fn start_device_recovery_internal(self: &Arc) { // Check if recovery is already in progress - if self.recovery_in_progress.swap(true, std::sync::atomic::Ordering::SeqCst) { + if self + .recovery_in_progress + .swap(true, std::sync::atomic::Ordering::SeqCst) + { debug!("Device recovery already in progress, skipping"); return; } @@ -786,7 +834,9 @@ impl Streamer { let capturer = self.capturer.read().await; if let Some(cap) = capturer.as_ref() { cap.last_error().unwrap_or_else(|| { - let device_path = self.current_device.blocking_read() + let device_path = self + .current_device + .blocking_read() .as_ref() .map(|d| d.path.display().to_string()) .unwrap_or_else(|| "unknown".to_string()); @@ -800,13 +850,15 @@ impl Streamer { // Store error info *self.last_lost_device.write().await = Some(device.clone()); *self.last_lost_reason.write().await = Some(reason.clone()); - self.recovery_retry_count.store(0, std::sync::atomic::Ordering::Relaxed); + self.recovery_retry_count + .store(0, std::sync::atomic::Ordering::Relaxed); // Publish device lost event self.publish_event(SystemEvent::StreamDeviceLost { device: device.clone(), reason: reason.clone(), - }).await; + }) + .await; // Start recovery task let streamer = Arc::clone(self); @@ -814,11 +866,16 @@ impl Streamer { let device_path = device.clone(); loop { - let attempt = streamer.recovery_retry_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1; + let attempt = streamer + .recovery_retry_count + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) + + 1; // Check if still in device lost state let current_state = *streamer.state.read().await; - if current_state != StreamerState::DeviceLost && current_state != StreamerState::Recovering { + if current_state != StreamerState::DeviceLost + && current_state != StreamerState::Recovering + { info!("Stream state changed during recovery, stopping recovery task"); break; } @@ -828,11 +885,16 @@ impl Streamer { // Publish reconnecting event (every 5 attempts to avoid spam) if attempt == 1 || attempt % 5 == 0 { - streamer.publish_event(SystemEvent::StreamReconnecting { - device: device_path.clone(), - attempt, - }).await; - info!("Attempting to recover video device {} (attempt {})", device_path, attempt); + streamer + .publish_event(SystemEvent::StreamReconnecting { + device: device_path.clone(), + attempt, + }) + .await; + info!( + "Attempting to recover video device {} (attempt {})", + device_path, attempt + ); } // Wait before retry (1 second) @@ -848,13 +910,20 @@ impl Streamer { // Try to restart capture match streamer.restart_capturer().await { Ok(_) => { - info!("Video device {} recovered after {} attempts", device_path, attempt); - streamer.recovery_in_progress.store(false, std::sync::atomic::Ordering::SeqCst); + info!( + "Video device {} recovered after {} attempts", + device_path, attempt + ); + streamer + .recovery_in_progress + .store(false, std::sync::atomic::Ordering::SeqCst); // Publish recovered event - streamer.publish_event(SystemEvent::StreamRecovered { - device: device_path.clone(), - }).await; + streamer + .publish_event(SystemEvent::StreamRecovered { + device: device_path.clone(), + }) + .await; // Clear error info *streamer.last_lost_device.write().await = None; @@ -867,7 +936,9 @@ impl Streamer { } } - streamer.recovery_in_progress.store(false, std::sync::atomic::Ordering::SeqCst); + streamer + .recovery_in_progress + .store(false, std::sync::atomic::Ordering::SeqCst); }); } } diff --git a/src/video/video_session.rs b/src/video/video_session.rs index 9734f11a..81b5ee41 
100644 --- a/src/video/video_session.rs +++ b/src/video/video_session.rs @@ -234,10 +234,7 @@ impl VideoSessionManager { let mut sessions = self.sessions.write().await; sessions.insert(session_id.clone(), session); - info!( - "Video session created: {} (codec: {})", - session_id, codec - ); + info!("Video session created: {} (codec: {})", session_id, codec); Ok(session_id) } @@ -428,8 +425,7 @@ impl VideoSessionManager { sessions .iter() .filter(|(_, s)| { - (s.state == VideoSessionState::Paused - || s.state == VideoSessionState::Created) + (s.state == VideoSessionState::Paused || s.state == VideoSessionState::Created) && now.duration_since(s.last_activity) > timeout }) .map(|(id, _)| id.clone()) diff --git a/src/web/handlers/config/apply.rs b/src/web/handlers/config/apply.rs index 49b98579..6dac7862 100644 --- a/src/web/handlers/config/apply.rs +++ b/src/web/handlers/config/apply.rs @@ -31,15 +31,14 @@ pub async fn apply_video_config( .format .as_ref() .and_then(|f| { - serde_json::from_value::( - serde_json::Value::String(f.clone()), - ) + serde_json::from_value::(serde_json::Value::String( + f.clone(), + )) .ok() }) .unwrap_or(crate::video::format::PixelFormat::Mjpeg); - let resolution = - crate::video::format::Resolution::new(new_config.width, new_config.height); + let resolution = crate::video::format::Resolution::new(new_config.width, new_config.height); // Step 1: 更新 WebRTC streamer 配置(停止现有 pipeline 和 sessions) state @@ -162,9 +161,16 @@ pub async fn apply_hid_config( // 如果描述符变更且当前使用 OTG 后端,需要重建 Gadget if descriptor_changed && new_config.backend == HidBackend::Otg { tracing::info!("OTG descriptor changed, updating gadget..."); - if let Err(e) = state.otg_service.update_descriptor(&new_config.otg_descriptor).await { + if let Err(e) = state + .otg_service + .update_descriptor(&new_config.otg_descriptor) + .await + { tracing::error!("Failed to update OTG descriptor: {}", e); - return Err(AppError::Config(format!("OTG descriptor update failed: {}", e))); + return Err(AppError::Config(format!( + "OTG descriptor update failed: {}", + e + ))); } tracing::info!("OTG descriptor updated successfully"); } @@ -197,7 +203,10 @@ pub async fn apply_hid_config( .await .map_err(|e| AppError::Config(format!("HID reload failed: {}", e)))?; - tracing::info!("HID backend reloaded successfully: {:?}", new_config.backend); + tracing::info!( + "HID backend reloaded successfully: {:?}", + new_config.backend + ); // When switching to OTG backend, automatically enable MSD if not already enabled // OTG HID and MSD share the same USB gadget, so it makes sense to enable both @@ -245,7 +254,11 @@ pub async fn apply_msd_config( let old_msd_enabled = old_config.enabled; let new_msd_enabled = new_config.enabled; - tracing::info!("MSD enabled: old={}, new={}", old_msd_enabled, new_msd_enabled); + tracing::info!( + "MSD enabled: old={}, new={}", + old_msd_enabled, + new_msd_enabled + ); if old_msd_enabled != new_msd_enabled { if new_msd_enabled { @@ -257,9 +270,9 @@ pub async fn apply_msd_config( &new_config.images_path, &new_config.drive_path, ); - msd.init().await.map_err(|e| { - AppError::Config(format!("MSD initialization failed: {}", e)) - })?; + msd.init() + .await + .map_err(|e| AppError::Config(format!("MSD initialization failed: {}", e)))?; // Set event bus let events = state.events.clone(); @@ -429,7 +442,10 @@ pub async fn apply_rustdesk_config( if let Err(e) = service.restart(new_config.clone()).await { tracing::error!("Failed to restart RustDesk service: {}", e); } else { - tracing::info!("RustDesk 
service restarted with ID: {}", new_config.device_id); + tracing::info!( + "RustDesk service restarted with ID: {}", + new_config.device_id + ); // Save generated keypair and UUID to config credentials_to_save = service.save_credentials(); } diff --git a/src/web/handlers/config/mod.rs b/src/web/handlers/config/mod.rs index b53ac841..5c9ee5e9 100644 --- a/src/web/handlers/config/mod.rs +++ b/src/web/handlers/config/mod.rs @@ -19,26 +19,26 @@ pub(crate) mod apply; mod types; -pub(crate) mod video; -mod stream; -mod hid; -mod msd; mod atx; mod audio; +mod hid; +mod msd; mod rustdesk; +mod stream; +pub(crate) mod video; mod web; // 导出 handler 函数 -pub use video::{get_video_config, update_video_config}; -pub use stream::{get_stream_config, update_stream_config}; -pub use hid::{get_hid_config, update_hid_config}; -pub use msd::{get_msd_config, update_msd_config}; pub use atx::{get_atx_config, update_atx_config}; pub use audio::{get_audio_config, update_audio_config}; +pub use hid::{get_hid_config, update_hid_config}; +pub use msd::{get_msd_config, update_msd_config}; pub use rustdesk::{ - get_rustdesk_config, get_rustdesk_status, update_rustdesk_config, - regenerate_device_id, regenerate_device_password, get_device_password, + get_device_password, get_rustdesk_config, get_rustdesk_status, regenerate_device_id, + regenerate_device_password, update_rustdesk_config, }; +pub use stream::{get_stream_config, update_stream_config}; +pub use video::{get_video_config, update_video_config}; pub use web::{get_web_config, update_web_config}; // 保留全局配置查询(向后兼容) diff --git a/src/web/handlers/config/rustdesk.rs b/src/web/handlers/config/rustdesk.rs index 1b291461..29dbf3c9 100644 --- a/src/web/handlers/config/rustdesk.rs +++ b/src/web/handlers/config/rustdesk.rs @@ -48,12 +48,16 @@ pub struct RustDeskStatusResponse { } /// 获取 RustDesk 配置 -pub async fn get_rustdesk_config(State(state): State>) -> Json { +pub async fn get_rustdesk_config( + State(state): State>, +) -> Json { Json(RustDeskConfigResponse::from(&state.config.get().rustdesk)) } /// 获取 RustDesk 完整状态(配置 + 服务状态) -pub async fn get_rustdesk_status(State(state): State>) -> Json { +pub async fn get_rustdesk_status( + State(state): State>, +) -> Json { let config = state.config.get().rustdesk.clone(); // 获取服务状态 diff --git a/src/web/handlers/config/types.rs b/src/web/handlers/config/types.rs index dfa0da8c..b23bdb49 100644 --- a/src/web/handlers/config/types.rs +++ b/src/web/handlers/config/types.rs @@ -1,9 +1,9 @@ -use serde::Deserialize; -use typeshare::typeshare; use crate::config::*; use crate::error::AppError; use crate::rustdesk::config::RustDeskConfig; use crate::video::encoder::BitratePreset; +use serde::Deserialize; +use typeshare::typeshare; // ===== Video Config ===== #[typeshare] @@ -21,12 +21,16 @@ impl VideoConfigUpdate { pub fn validate(&self) -> crate::error::Result<()> { if let Some(width) = self.width { if !(320..=7680).contains(&width) { - return Err(AppError::BadRequest("Invalid width: must be 320-7680".into())); + return Err(AppError::BadRequest( + "Invalid width: must be 320-7680".into(), + )); } } if let Some(height) = self.height { if !(240..=4320).contains(&height) { - return Err(AppError::BadRequest("Invalid height: must be 240-4320".into())); + return Err(AppError::BadRequest( + "Invalid height: must be 240-4320".into(), + )); } } if let Some(fps) = self.fps { @@ -36,7 +40,9 @@ impl VideoConfigUpdate { } if let Some(quality) = self.quality { if !(1..=100).contains(&quality) { - return Err(AppError::BadRequest("Invalid quality: must 
be 1-100".into())); + return Err(AppError::BadRequest( + "Invalid quality: must be 1-100".into(), + )); } } Ok(()) @@ -126,7 +132,8 @@ impl StreamConfigUpdate { if let Some(ref stun) = self.stun_server { if !stun.is_empty() && !stun.starts_with("stun:") { return Err(AppError::BadRequest( - "STUN server must start with 'stun:' (e.g., stun:stun.l.google.com:19302)".into(), + "STUN server must start with 'stun:' (e.g., stun:stun.l.google.com:19302)" + .into(), )); } } @@ -153,16 +160,32 @@ impl StreamConfigUpdate { } // STUN/TURN settings - empty string means clear (use public servers), Some("value") means set custom if let Some(ref stun) = self.stun_server { - config.stun_server = if stun.is_empty() { None } else { Some(stun.clone()) }; + config.stun_server = if stun.is_empty() { + None + } else { + Some(stun.clone()) + }; } if let Some(ref turn) = self.turn_server { - config.turn_server = if turn.is_empty() { None } else { Some(turn.clone()) }; + config.turn_server = if turn.is_empty() { + None + } else { + Some(turn.clone()) + }; } if let Some(ref username) = self.turn_username { - config.turn_username = if username.is_empty() { None } else { Some(username.clone()) }; + config.turn_username = if username.is_empty() { + None + } else { + Some(username.clone()) + }; } if let Some(ref password) = self.turn_password { - config.turn_password = if password.is_empty() { None } else { Some(password.clone()) }; + config.turn_password = if password.is_empty() { + None + } else { + Some(password.clone()) + }; } } } @@ -185,19 +208,25 @@ impl OtgDescriptorConfigUpdate { // Validate manufacturer string length if let Some(ref s) = self.manufacturer { if s.len() > 126 { - return Err(AppError::BadRequest("Manufacturer string too long (max 126 chars)".into())); + return Err(AppError::BadRequest( + "Manufacturer string too long (max 126 chars)".into(), + )); } } // Validate product string length if let Some(ref s) = self.product { if s.len() > 126 { - return Err(AppError::BadRequest("Product string too long (max 126 chars)".into())); + return Err(AppError::BadRequest( + "Product string too long (max 126 chars)".into(), + )); } } // Validate serial number string length if let Some(ref s) = self.serial_number { if s.len() > 126 { - return Err(AppError::BadRequest("Serial number string too long (max 126 chars)".into())); + return Err(AppError::BadRequest( + "Serial number string too long (max 126 chars)".into(), + )); } } Ok(()) @@ -469,7 +498,8 @@ impl RustDeskConfigUpdate { if let Some(ref server) = self.rendezvous_server { if !server.is_empty() && !server.contains(':') { return Err(AppError::BadRequest( - "Rendezvous server must be in format 'host:port' (e.g., rs.example.com:21116)".into(), + "Rendezvous server must be in format 'host:port' (e.g., rs.example.com:21116)" + .into(), )); } } @@ -477,7 +507,8 @@ impl RustDeskConfigUpdate { if let Some(ref server) = self.relay_server { if !server.is_empty() && !server.contains(':') { return Err(AppError::BadRequest( - "Relay server must be in format 'host:port' (e.g., rs.example.com:21117)".into(), + "Relay server must be in format 'host:port' (e.g., rs.example.com:21117)" + .into(), )); } } @@ -500,10 +531,18 @@ impl RustDeskConfigUpdate { config.rendezvous_server = server.clone(); } if let Some(ref server) = self.relay_server { - config.relay_server = if server.is_empty() { None } else { Some(server.clone()) }; + config.relay_server = if server.is_empty() { + None + } else { + Some(server.clone()) + }; } if let Some(ref key) = self.relay_key { - 
config.relay_key = if key.is_empty() { None } else { Some(key.clone()) }; + config.relay_key = if key.is_empty() { + None + } else { + Some(key.clone()) + }; } if let Some(ref password) = self.device_password { if !password.is_empty() { diff --git a/src/web/handlers/extensions.rs b/src/web/handlers/extensions.rs index dabbead6..f91cdb2b 100644 --- a/src/web/handlers/extensions.rs +++ b/src/web/handlers/extensions.rs @@ -10,8 +10,8 @@ use typeshare::typeshare; use crate::error::{AppError, Result}; use crate::extensions::{ - EasytierConfig, EasytierInfo, ExtensionId, ExtensionInfo, ExtensionLogs, - ExtensionsStatus, GostcConfig, GostcInfo, TtydConfig, TtydInfo, + EasytierConfig, EasytierInfo, ExtensionId, ExtensionInfo, ExtensionLogs, ExtensionsStatus, + GostcConfig, GostcInfo, TtydConfig, TtydInfo, }; use crate::state::AppState; @@ -108,9 +108,7 @@ pub async fn stop_extension( let mgr = &state.extensions; // Stop the extension - mgr.stop(ext_id) - .await - .map_err(|e| AppError::Internal(e))?; + mgr.stop(ext_id).await.map_err(|e| AppError::Internal(e))?; // Return updated status Ok(Json(ExtensionInfo { diff --git a/src/web/handlers/mod.rs b/src/web/handlers/mod.rs index bf34ecc5..a8bbf379 100644 --- a/src/web/handlers/mod.rs +++ b/src/web/handlers/mod.rs @@ -124,8 +124,7 @@ pub async fn system_info(State(state): State>) -> Json backend: if config.atx.enabled { Some(format!( "power: {:?}, reset: {:?}", - config.atx.power.driver, - config.atx.reset.driver + config.atx.power.driver, config.atx.reset.driver )) } else { None @@ -208,7 +207,8 @@ fn get_cpu_model() -> String { } /// CPU usage state for calculating usage between samples -static CPU_PREV_STATS: std::sync::OnceLock> = std::sync::OnceLock::new(); +static CPU_PREV_STATS: std::sync::OnceLock> = + std::sync::OnceLock::new(); /// Get CPU usage percentage (0.0 - 100.0) fn get_cpu_usage() -> f32 { @@ -268,7 +268,12 @@ struct MemInfo { fn get_meminfo() -> MemInfo { let content = match std::fs::read_to_string("/proc/meminfo") { Ok(c) => c, - Err(_) => return MemInfo { total: 0, available: 0 }, + Err(_) => { + return MemInfo { + total: 0, + available: 0, + } + } }; let mut total = 0u64; @@ -276,11 +281,19 @@ fn get_meminfo() -> MemInfo { for line in content.lines() { if line.starts_with("MemTotal:") { - if let Some(kb) = line.split_whitespace().nth(1).and_then(|v| v.parse::().ok()) { + if let Some(kb) = line + .split_whitespace() + .nth(1) + .and_then(|v| v.parse::().ok()) + { total = kb * 1024; } } else if line.starts_with("MemAvailable:") { - if let Some(kb) = line.split_whitespace().nth(1).and_then(|v| v.parse::().ok()) { + if let Some(kb) = line + .split_whitespace() + .nth(1) + .and_then(|v| v.parse::().ok()) + { available = kb * 1024; } } @@ -312,10 +325,7 @@ fn get_network_addresses() -> Vec { if !ipv4_map.contains_key(&ifaddr.interface_name) { if let Some(addr) = ifaddr.address { if let Some(sockaddr_in) = addr.as_sockaddr_in() { - ipv4_map.insert( - ifaddr.interface_name.clone(), - sockaddr_in.ip().to_string(), - ); + ipv4_map.insert(ifaddr.interface_name.clone(), sockaddr_in.ip().to_string()); } } } @@ -624,10 +634,7 @@ pub async fn setup_init( if new_config.extensions.ttyd.enabled { if let Err(e) = state .extensions - .start( - crate::extensions::ExtensionId::Ttyd, - &new_config.extensions, - ) + .start(crate::extensions::ExtensionId::Ttyd, &new_config.extensions) .await { tracing::warn!("Failed to start ttyd during setup: {}", e); @@ -658,7 +665,10 @@ pub async fn setup_init( if let Err(e) = 
state.audio.update_config(audio_config).await { tracing::warn!("Failed to start audio during setup: {}", e); } else { - tracing::info!("Audio started during setup: device={}", new_config.audio.device); + tracing::info!( + "Audio started during setup: device={}", + new_config.audio.device + ); } // Also enable WebRTC audio if let Err(e) = state.stream_manager.set_webrtc_audio_enabled(true).await { @@ -666,7 +676,10 @@ pub async fn setup_init( } } - tracing::info!("System initialized successfully with admin user: {}", req.username); + tracing::info!( + "System initialized successfully with admin user: {}", + req.username + ); Ok(Json(LoginResponse { success: true, @@ -798,10 +811,19 @@ pub async fn update_config( if let Some(frame_tx) = state.stream_manager.frame_sender().await { let receiver_count = frame_tx.receiver_count(); // Use WebRtcStreamer (new unified interface) - state.stream_manager.webrtc_streamer().set_video_source(frame_tx).await; - tracing::info!("WebRTC streamer frame source updated with new capturer (receiver_count={})", receiver_count); + state + .stream_manager + .webrtc_streamer() + .set_video_source(frame_tx) + .await; + tracing::info!( + "WebRTC streamer frame source updated with new capturer (receiver_count={})", + receiver_count + ); } else { - tracing::warn!("No frame source available after config change - streamer may not be running"); + tracing::warn!( + "No frame source available after config change - streamer may not be running" + ); } } @@ -831,8 +853,11 @@ pub async fn update_config( .await .ok(); // Ignore error if no active stream - tracing::info!("Stream config applied: encoder={:?}, bitrate={}", - new_config.stream.encoder, new_config.stream.bitrate_preset); + tracing::info!( + "Stream config applied: encoder={:?}, bitrate={}", + new_config.stream.encoder, + new_config.stream.bitrate_preset + ); } // HID config processing - always reload if section was sent @@ -860,7 +885,10 @@ pub async fn update_config( })); } - tracing::info!("HID backend reloaded successfully: {:?}", new_config.hid.backend); + tracing::info!( + "HID backend reloaded successfully: {:?}", + new_config.hid.backend + ); } // Audio config processing - always reload if section was sent @@ -888,7 +916,11 @@ pub async fn update_config( } // Also update WebRTC audio enabled state - if let Err(e) = state.stream_manager.set_webrtc_audio_enabled(new_config.audio.enabled).await { + if let Err(e) = state + .stream_manager + .set_webrtc_audio_enabled(new_config.audio.enabled) + .await + { tracing::warn!("Failed to update WebRTC audio state: {}", e); } else { tracing::info!("WebRTC audio enabled: {}", new_config.audio.enabled); @@ -911,7 +943,11 @@ pub async fn update_config( let old_msd_enabled = old_config.msd.enabled; let new_msd_enabled = new_config.msd.enabled; - tracing::info!("MSD enabled: old={}, new={}", old_msd_enabled, new_msd_enabled); + tracing::info!( + "MSD enabled: old={}, new={}", + old_msd_enabled, + new_msd_enabled + ); if old_msd_enabled != new_msd_enabled { if new_msd_enabled { @@ -953,7 +989,10 @@ pub async fn update_config( tracing::info!("MSD shutdown complete"); } } else { - tracing::info!("MSD enabled state unchanged ({}), no reload needed", new_msd_enabled); + tracing::info!( + "MSD enabled state unchanged ({}), no reload needed", + new_msd_enabled + ); } } @@ -1060,7 +1099,12 @@ fn extract_usb_bus_from_bus_info(bus_info: &str) -> Option { if parts.len() == 2 { let port = parts[0]; // Verify it looks like a USB port (starts with digit) - if port.chars().next().map(|c| 
c.is_ascii_digit()).unwrap_or(false) { + if port + .chars() + .next() + .map(|c| c.is_ascii_digit()) + .unwrap_or(false) + { return Some(port.to_string()); } } @@ -1115,7 +1159,10 @@ pub async fn list_devices(State(state): State>) -> Json continue, }; // Check if matches any prefix - if serial_prefixes.iter().any(|prefix| name.starts_with(prefix)) { + if serial_prefixes + .iter() + .any(|prefix| name.starts_with(prefix)) + { let path = entry.path(); if let Some(p) = path.to_str() { serial_devices.push(SerialDevice { @@ -1156,7 +1203,9 @@ pub async fn list_devices(State(state): State>) -> Json>) -> Json>) -> Json { @@ -1216,6 +1265,9 @@ pub struct SetStreamModeRequest { pub struct StreamModeResponse { pub success: bool, pub mode: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub transition_id: Option, + pub switching: bool, pub message: Option, } @@ -1223,12 +1275,27 @@ pub struct StreamModeResponse { pub async fn stream_mode_get(State(state): State>) -> Json { let mode = state.stream_manager.current_mode().await; let mode_str = match mode { - StreamMode::Mjpeg => "mjpeg", - StreamMode::WebRTC => "webrtc", + StreamMode::Mjpeg => "mjpeg".to_string(), + StreamMode::WebRTC => { + use crate::video::encoder::VideoCodecType; + let codec = state + .stream_manager + .webrtc_streamer() + .current_video_codec() + .await; + match codec { + VideoCodecType::H264 => "h264".to_string(), + VideoCodecType::H265 => "h265".to_string(), + VideoCodecType::VP8 => "vp8".to_string(), + VideoCodecType::VP9 => "vp9".to_string(), + } + } }; Json(StreamModeResponse { success: true, - mode: mode_str.to_string(), + mode: mode_str, + transition_id: state.stream_manager.current_transition_id().await, + switching: state.stream_manager.is_switching(), message: None, }) } @@ -1258,15 +1325,24 @@ pub async fn stream_mode_set( // Set video codec if switching to WebRTC mode with specific codec if let Some(codec) = video_codec { info!("Setting WebRTC video codec to {:?}", codec); - if let Err(e) = state.stream_manager.webrtc_streamer().set_video_codec(codec).await { + if let Err(e) = state + .stream_manager + .webrtc_streamer() + .set_video_codec(codec) + .await + { warn!("Failed to set video codec: {}", e); } } - state.stream_manager.switch_mode(new_mode.clone()).await?; + let tx = state + .stream_manager + .switch_mode_transaction(new_mode.clone()) + .await?; - // Return the actual codec being used - let mode_str = match (&new_mode, &video_codec) { + // Return the requested codec identifier (for UI display). The actual active mode + // may differ if the request was rejected due to an in-progress switch. 
+ let requested_mode_str = match (&new_mode, &video_codec) { (StreamMode::Mjpeg, _) => "mjpeg", (StreamMode::WebRTC, Some(VideoCodecType::H264)) => "h264", (StreamMode::WebRTC, Some(VideoCodecType::H265)) => "h265", @@ -1275,10 +1351,39 @@ pub async fn stream_mode_set( (StreamMode::WebRTC, None) => "webrtc", }; + let active_mode_str = match state.stream_manager.current_mode().await { + StreamMode::Mjpeg => "mjpeg".to_string(), + StreamMode::WebRTC => { + let codec = state + .stream_manager + .webrtc_streamer() + .current_video_codec() + .await; + match codec { + VideoCodecType::H264 => "h264".to_string(), + VideoCodecType::H265 => "h265".to_string(), + VideoCodecType::VP8 => "vp8".to_string(), + VideoCodecType::VP9 => "vp9".to_string(), + } + } + }; + Ok(Json(StreamModeResponse { - success: true, - mode: mode_str.to_string(), - message: Some(format!("Switched to {} mode", mode_str)), + success: tx.accepted, + mode: if tx.accepted { + requested_mode_str.to_string() + } else { + active_mode_str + }, + transition_id: tx.transition_id, + switching: tx.switching, + message: Some(if tx.accepted { + format!("Switching to {} mode", requested_mode_str) + } else if tx.switching { + "Mode switch already in progress".to_string() + } else { + "No switch needed".to_string() + }), })) } @@ -1470,7 +1575,9 @@ pub async fn mjpeg_stream( return axum::response::Response::builder() .status(axum::http::StatusCode::SERVICE_UNAVAILABLE) .header("Content-Type", "application/json") - .body(axum::body::Body::from(r#"{"error":"MJPEG mode not active. Current mode is WebRTC."}"#)) + .body(axum::body::Body::from( + r#"{"error":"MJPEG mode not active. Current mode is WebRTC."}"#, + )) .unwrap(); } @@ -1479,7 +1586,9 @@ pub async fn mjpeg_stream( return axum::response::Response::builder() .status(axum::http::StatusCode::SERVICE_UNAVAILABLE) .header("Content-Type", "application/json") - .body(axum::body::Body::from(r#"{"error":"Video configuration is being changed. Please retry shortly."}"#)) + .body(axum::body::Body::from( + r#"{"error":"Video configuration is being changed. 
Please retry shortly."}"#, + )) .unwrap(); } @@ -1493,8 +1602,9 @@ pub async fn mjpeg_stream( let handler = state.stream_manager.mjpeg_handler(); // Use provided client ID or generate a new one - let client_id = query.client_id - .filter(|id| !id.is_empty() && id.len() <= 64) // Validate: non-empty, max 64 chars + let client_id = query + .client_id + .filter(|id| !id.is_empty() && id.len() <= 64) // Validate: non-empty, max 64 chars .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); // Create RAII guard - this will automatically register and unregister the client @@ -1538,10 +1648,8 @@ pub async fn mjpeg_stream( } // Wait for new frame notification with timeout - let result = tokio::time::timeout( - std::time::Duration::from_secs(5), - notify_rx.recv() - ).await; + let result = + tokio::time::timeout(std::time::Duration::from_secs(5), notify_rx.recv()).await; match result { Ok(Ok(())) => { @@ -1622,7 +1730,10 @@ pub async fn mjpeg_stream( Response::builder() .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "multipart/x-mixed-replace; boundary=frame") + .header( + header::CONTENT_TYPE, + "multipart/x-mixed-replace; boundary=frame", + ) .header(header::CACHE_CONTROL, "no-cache, no-store, must-revalidate") .header(header::PRAGMA, "no-cache") .header(header::EXPIRES, "0") @@ -1636,14 +1747,12 @@ pub async fn snapshot(State(state): State>) -> impl IntoResponse { let handler = state.stream_manager.mjpeg_handler(); match handler.current_frame() { - Some(frame) if frame.is_valid_jpeg() => { - Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "image/jpeg") - .header(header::CACHE_CONTROL, "no-cache") - .body(Body::from(frame.data_bytes())) - .unwrap() - } + Some(frame) if frame.is_valid_jpeg() => Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "image/jpeg") + .header(header::CACHE_CONTROL, "no-cache") + .body(Body::from(frame.data_bytes())) + .unwrap(), _ => Response::builder() .status(StatusCode::SERVICE_UNAVAILABLE) .body(Body::from("No frame available")) @@ -1674,7 +1783,7 @@ fn create_mjpeg_part(jpeg_data: &[u8]) -> bytes::Bytes { // WebRTC // ============================================================================ -use crate::webrtc::signaling::{IceCandidateRequest, OfferRequest, AnswerResponse}; +use crate::webrtc::signaling::{AnswerResponse, IceCandidateRequest, OfferRequest}; /// Create WebRTC session #[derive(Serialize)] @@ -1692,7 +1801,11 @@ pub async fn webrtc_create_session( )); } - let session_id = state.stream_manager.webrtc_streamer().create_session().await?; + let session_id = state + .stream_manager + .webrtc_streamer() + .create_session() + .await?; Ok(Json(CreateSessionResponse { session_id })) } @@ -1986,7 +2099,9 @@ pub async fn msd_image_upload( // Use streaming upload - chunks are written directly to disk // This avoids loading the entire file into memory - let image = manager.create_from_multipart_field(&filename, field).await?; + let image = manager + .create_from_multipart_field(&filename, field) + .await?; return Ok(Json(image)); } } @@ -2033,9 +2148,7 @@ pub async fn msd_image_download( .as_ref() .ok_or_else(|| AppError::Internal("MSD not initialized".to_string()))?; - let progress = controller - .download_image(req.url, req.filename) - .await?; + let progress = controller.download_image(req.url, req.filename).await?; Ok(Json(progress)) } @@ -2076,9 +2189,9 @@ pub async fn msd_connect( match req.mode { MsdMode::Image => { - let image_id = req - .image_id - .ok_or_else(|| 
AppError::BadRequest("image_id required for image mode".to_string()))?; + let image_id = req.image_id.ok_or_else(|| { + AppError::BadRequest("image_id required for image mode".to_string()) + })?; // Get image info from ImageManager let images_path = std::path::PathBuf::from(&config.msd.images_path); @@ -2170,9 +2283,8 @@ pub async fn msd_drive_delete(State(state): State>) -> Result, ) -> Result>> { // Check if current user is admin - let current_user = state.users.get(&session.user_id).await? + let current_user = state + .users + .get(&session.user_id) + .await? .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; if !current_user.is_admin { @@ -2588,7 +2705,10 @@ pub async fn create_user( Json(req): Json, ) -> Result> { // Check if current user is admin - let current_user = state.users.get(&session.user_id).await? + let current_user = state + .users + .get(&session.user_id) + .await? .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; if !current_user.is_admin { @@ -2597,13 +2717,20 @@ pub async fn create_user( // Validate input if req.username.len() < 2 { - return Err(AppError::BadRequest("Username must be at least 2 characters".to_string())); + return Err(AppError::BadRequest( + "Username must be at least 2 characters".to_string(), + )); } if req.password.len() < 4 { - return Err(AppError::BadRequest("Password must be at least 4 characters".to_string())); + return Err(AppError::BadRequest( + "Password must be at least 4 characters".to_string(), + )); } - let user = state.users.create(&req.username, &req.password, req.is_admin).await?; + let user = state + .users + .create(&req.username, &req.password, req.is_admin) + .await?; info!("User created: {} (admin: {})", user.username, user.is_admin); Ok(Json(UserResponse::from(user))) } @@ -2623,7 +2750,10 @@ pub async fn update_user( Json(req): Json, ) -> Result> { // Check if current user is admin - let current_user = state.users.get(&session.user_id).await? + let current_user = state + .users + .get(&session.user_id) + .await? .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; if !current_user.is_admin { @@ -2631,13 +2761,18 @@ pub async fn update_user( } // Get target user - let mut user = state.users.get(&user_id).await? + let mut user = state + .users + .get(&user_id) + .await? .ok_or_else(|| AppError::NotFound("User not found".to_string()))?; // Update fields if provided if let Some(username) = req.username { if username.len() < 2 { - return Err(AppError::BadRequest("Username must be at least 2 characters".to_string())); + return Err(AppError::BadRequest( + "Username must be at least 2 characters".to_string(), + )); } user.username = username; } @@ -2647,7 +2782,9 @@ pub async fn update_user( // Note: We need to add an update method to UserStore // For now, return error - Err(AppError::Internal("User update not yet implemented".to_string())) + Err(AppError::Internal( + "User update not yet implemented".to_string(), + )) } /// Delete user (admin only) @@ -2657,7 +2794,10 @@ pub async fn delete_user( Path(user_id): Path, ) -> Result> { // Check if current user is admin - let current_user = state.users.get(&session.user_id).await? + let current_user = state + .users + .get(&session.user_id) + .await? 
.ok_or_else(|| AppError::AuthError("User not found".to_string()))?; if !current_user.is_admin { @@ -2666,17 +2806,24 @@ pub async fn delete_user( // Prevent deleting self if user_id == session.user_id { - return Err(AppError::BadRequest("Cannot delete your own account".to_string())); + return Err(AppError::BadRequest( + "Cannot delete your own account".to_string(), + )); } // Check if this is the last admin let users = state.users.list().await?; let admin_count = users.iter().filter(|u| u.is_admin).count(); - let target_user = state.users.get(&user_id).await? + let target_user = state + .users + .get(&user_id) + .await? .ok_or_else(|| AppError::NotFound("User not found".to_string()))?; if target_user.is_admin && admin_count <= 1 { - return Err(AppError::BadRequest("Cannot delete the last admin user".to_string())); + return Err(AppError::BadRequest( + "Cannot delete the last admin user".to_string(), + )); } state.users.delete(&user_id).await?; @@ -2703,30 +2850,45 @@ pub async fn change_user_password( Json(req): Json, ) -> Result> { // Check if current user is admin or changing own password - let current_user = state.users.get(&session.user_id).await? + let current_user = state + .users + .get(&session.user_id) + .await? .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; let is_self = user_id == session.user_id; let is_admin = current_user.is_admin; if !is_self && !is_admin { - return Err(AppError::Forbidden("Cannot change other user's password".to_string())); + return Err(AppError::Forbidden( + "Cannot change other user's password".to_string(), + )); } // Validate new password if req.new_password.len() < 4 { - return Err(AppError::BadRequest("Password must be at least 4 characters".to_string())); + return Err(AppError::BadRequest( + "Password must be at least 4 characters".to_string(), + )); } // If changing own password, verify current password if is_self { - let verified = state.users.verify(¤t_user.username, &req.current_password).await?; + let verified = state + .users + .verify(¤t_user.username, &req.current_password) + .await?; if verified.is_none() { - return Err(AppError::AuthError("Current password is incorrect".to_string())); + return Err(AppError::AuthError( + "Current password is incorrect".to_string(), + )); } } - state.users.update_password(&user_id, &req.new_password).await?; + state + .users + .update_password(&user_id, &req.new_password) + .await?; info!("Password changed for user ID: {}", user_id); Ok(Json(LoginResponse { diff --git a/src/web/handlers/terminal.rs b/src/web/handlers/terminal.rs index 67feb0b5..c88f05f6 100644 --- a/src/web/handlers/terminal.rs +++ b/src/web/handlers/terminal.rs @@ -14,9 +14,7 @@ use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::UnixStream; use tokio_tungstenite::tungstenite::{ - client::IntoClientRequest, - http::HeaderValue, - Message as TungsteniteMessage, + client::IntoClientRequest, http::HeaderValue, Message as TungsteniteMessage, }; use crate::error::AppError; @@ -60,10 +58,9 @@ async fn handle_terminal_websocket(client_ws: WebSocket, query_string: String) { } }; - request.headers_mut().insert( - "Sec-WebSocket-Protocol", - HeaderValue::from_static("tty"), - ); + request + .headers_mut() + .insert("Sec-WebSocket-Protocol", HeaderValue::from_static("tty")); // Create WebSocket connection to ttyd let ws_stream = match tokio_tungstenite::client_async(request, unix_stream).await { @@ -143,7 +140,11 @@ pub async fn terminal_proxy( // Build HTTP request to forward let method = 
req.method().as_str(); - let query = req.uri().query().map(|q| format!("?{}", q)).unwrap_or_default(); + let query = req + .uri() + .query() + .map(|q| format!("?{}", q)) + .unwrap_or_default(); let uri_path = if path_str.is_empty() { format!("/api/terminal/{}", query) } else { @@ -203,7 +204,8 @@ pub async fn terminal_proxy( .unwrap_or(200); // Build response - let mut builder = Response::builder().status(StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK)); + let mut builder = + Response::builder().status(StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK)); // Forward response headers for line in headers_part.lines().skip(1) { diff --git a/src/web/mod.rs b/src/web/mod.rs index 18c09693..0bff07e0 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -1,6 +1,6 @@ mod audio_ws; -mod routes; mod handlers; +mod routes; mod static_files; mod ws; diff --git a/src/web/routes.rs b/src/web/routes.rs index 74200187..a913c796 100644 --- a/src/web/routes.rs +++ b/src/web/routes.rs @@ -78,9 +78,15 @@ pub fn create_router(state: Arc) -> Router { .route("/config", get(handlers::config::get_all_config)) .route("/config", post(handlers::update_config)) .route("/config/video", get(handlers::config::get_video_config)) - .route("/config/video", patch(handlers::config::update_video_config)) + .route( + "/config/video", + patch(handlers::config::update_video_config), + ) .route("/config/stream", get(handlers::config::get_stream_config)) - .route("/config/stream", patch(handlers::config::update_stream_config)) + .route( + "/config/stream", + patch(handlers::config::update_stream_config), + ) .route("/config/hid", get(handlers::config::get_hid_config)) .route("/config/hid", patch(handlers::config::update_hid_config)) .route("/config/msd", get(handlers::config::get_msd_config)) @@ -88,14 +94,35 @@ pub fn create_router(state: Arc) -> Router { .route("/config/atx", get(handlers::config::get_atx_config)) .route("/config/atx", patch(handlers::config::update_atx_config)) .route("/config/audio", get(handlers::config::get_audio_config)) - .route("/config/audio", patch(handlers::config::update_audio_config)) + .route( + "/config/audio", + patch(handlers::config::update_audio_config), + ) // RustDesk configuration endpoints - .route("/config/rustdesk", get(handlers::config::get_rustdesk_config)) - .route("/config/rustdesk", patch(handlers::config::update_rustdesk_config)) - .route("/config/rustdesk/status", get(handlers::config::get_rustdesk_status)) - .route("/config/rustdesk/password", get(handlers::config::get_device_password)) - .route("/config/rustdesk/regenerate-id", post(handlers::config::regenerate_device_id)) - .route("/config/rustdesk/regenerate-password", post(handlers::config::regenerate_device_password)) + .route( + "/config/rustdesk", + get(handlers::config::get_rustdesk_config), + ) + .route( + "/config/rustdesk", + patch(handlers::config::update_rustdesk_config), + ) + .route( + "/config/rustdesk/status", + get(handlers::config::get_rustdesk_status), + ) + .route( + "/config/rustdesk/password", + get(handlers::config::get_device_password), + ) + .route( + "/config/rustdesk/regenerate-id", + post(handlers::config::regenerate_device_id), + ) + .route( + "/config/rustdesk/regenerate-password", + post(handlers::config::regenerate_device_password), + ) // Web server configuration .route("/config/web", get(handlers::config::get_web_config)) .route("/config/web", patch(handlers::config::update_web_config)) @@ -105,7 +132,10 @@ pub fn create_router(state: Arc) -> Router { .route("/msd/status", 
get(handlers::msd_status)) .route("/msd/images", get(handlers::msd_images_list)) .route("/msd/images/download", post(handlers::msd_image_download)) - .route("/msd/images/download/cancel", post(handlers::msd_image_download_cancel)) + .route( + "/msd/images/download/cancel", + post(handlers::msd_image_download_cancel), + ) .route("/msd/images/{id}", get(handlers::msd_image_get)) .route("/msd/images/{id}", delete(handlers::msd_image_delete)) .route("/msd/connect", post(handlers::msd_connect)) @@ -115,8 +145,14 @@ pub fn create_router(state: Arc) -> Router { .route("/msd/drive", delete(handlers::msd_drive_delete)) .route("/msd/drive/init", post(handlers::msd_drive_init)) .route("/msd/drive/files", get(handlers::msd_drive_files)) - .route("/msd/drive/files/{*path}", get(handlers::msd_drive_download)) - .route("/msd/drive/files/{*path}", delete(handlers::msd_drive_file_delete)) + .route( + "/msd/drive/files/{*path}", + get(handlers::msd_drive_download), + ) + .route( + "/msd/drive/files/{*path}", + delete(handlers::msd_drive_file_delete), + ) .route("/msd/drive/mkdir/{*path}", post(handlers::msd_drive_mkdir)) // ATX (Power Control) endpoints .route("/atx/status", get(handlers::atx_status)) @@ -132,13 +168,34 @@ pub fn create_router(state: Arc) -> Router { // Extension management endpoints .route("/extensions", get(handlers::extensions::list_extensions)) .route("/extensions/{id}", get(handlers::extensions::get_extension)) - .route("/extensions/{id}/start", post(handlers::extensions::start_extension)) - .route("/extensions/{id}/stop", post(handlers::extensions::stop_extension)) - .route("/extensions/{id}/logs", get(handlers::extensions::get_extension_logs)) - .route("/extensions/ttyd/config", patch(handlers::extensions::update_ttyd_config)) - .route("/extensions/ttyd/status", get(handlers::extensions::get_ttyd_status)) - .route("/extensions/gostc/config", patch(handlers::extensions::update_gostc_config)) - .route("/extensions/easytier/config", patch(handlers::extensions::update_easytier_config)) + .route( + "/extensions/{id}/start", + post(handlers::extensions::start_extension), + ) + .route( + "/extensions/{id}/stop", + post(handlers::extensions::stop_extension), + ) + .route( + "/extensions/{id}/logs", + get(handlers::extensions::get_extension_logs), + ) + .route( + "/extensions/ttyd/config", + patch(handlers::extensions::update_ttyd_config), + ) + .route( + "/extensions/ttyd/status", + get(handlers::extensions::get_ttyd_status), + ) + .route( + "/extensions/gostc/config", + patch(handlers::extensions::update_gostc_config), + ) + .route( + "/extensions/easytier/config", + patch(handlers::extensions::update_easytier_config), + ) // Terminal (ttyd) reverse proxy - WebSocket and HTTP .route("/terminal", get(handlers::terminal::terminal_index)) .route("/terminal/", get(handlers::terminal::terminal_index)) @@ -148,9 +205,7 @@ pub fn create_router(state: Arc) -> Router { .layer(middleware::from_fn_with_state(state.clone(), require_admin)); // Combine protected routes (user + admin) - let protected_routes = Router::new() - .merge(user_routes) - .merge(admin_routes); + let protected_routes = Router::new().merge(user_routes).merge(admin_routes); // Stream endpoints (accessible with auth, but typically embedded in pages) let stream_routes = Router::new() diff --git a/src/web/static_files.rs b/src/web/static_files.rs index cbe78c06..3fb84bbf 100644 --- a/src/web/static_files.rs +++ b/src/web/static_files.rs @@ -26,16 +26,18 @@ pub struct StaticAssets; #[cfg(debug_assertions)] fn get_static_base_dir() -> 
PathBuf { static BASE_DIR: OnceLock = OnceLock::new(); - BASE_DIR.get_or_init(|| { - // Try to get executable directory - if let Ok(exe_path) = std::env::current_exe() { - if let Some(exe_dir) = exe_path.parent() { - return exe_dir.join("web").join("dist"); + BASE_DIR + .get_or_init(|| { + // Try to get executable directory + if let Ok(exe_path) = std::env::current_exe() { + if let Some(exe_dir) = exe_path.parent() { + return exe_dir.join("web").join("dist"); + } } - } - // Fallback to current directory - PathBuf::from("web/dist") - }).clone() + // Fallback to current directory + PathBuf::from("web/dist") + }) + .clone() } /// Create router for static file serving @@ -102,29 +104,29 @@ fn try_serve_file(path: &str) -> Option> { // Debug mode: read from file system let base_dir = get_static_base_dir(); let file_path = base_dir.join(path); - + // Check if file exists and is within base directory (prevent directory traversal) if !file_path.starts_with(&base_dir) { tracing::warn!("Path traversal attempt blocked: {}", path); return None; } - + // Normalize path to prevent directory traversal (only if file exists) - if let (Ok(normalized_path), Ok(normalized_base)) = - (file_path.canonicalize(), base_dir.canonicalize()) + if let (Ok(normalized_path), Ok(normalized_base)) = + (file_path.canonicalize(), base_dir.canonicalize()) { if !normalized_path.starts_with(&normalized_base) { tracing::warn!("Path traversal attempt blocked (canonicalized): {}", path); return None; } } - + match std::fs::read(&file_path) { Ok(data) => { let mime = mime_guess::from_path(path) .first_or_octet_stream() .to_string(); - + return Some( Response::builder() .status(StatusCode::OK) @@ -145,16 +147,16 @@ fn try_serve_file(path: &str) -> Option> { } } } - + #[cfg(not(debug_assertions))] { // Release mode: use embedded assets let asset = StaticAssets::get(path)?; - + let mime = mime_guess::from_path(path) .first_or_octet_stream() .to_string(); - + Some( Response::builder() .status(StatusCode::OK) diff --git a/src/webrtc/h265_payloader.rs b/src/webrtc/h265_payloader.rs index 628757de..5263f0eb 100644 --- a/src/webrtc/h265_payloader.rs +++ b/src/webrtc/h265_payloader.rs @@ -38,12 +38,12 @@ const H265_NAL_PPS: u8 = 34; const H265_NAL_AUD: u8 = 35; const H265_NAL_FILLER: u8 = 38; #[allow(dead_code)] -const H265_NAL_SEI_PREFIX: u8 = 39; // PREFIX_SEI_NUT +const H265_NAL_SEI_PREFIX: u8 = 39; // PREFIX_SEI_NUT #[allow(dead_code)] -const H265_NAL_SEI_SUFFIX: u8 = 40; // SUFFIX_SEI_NUT +const H265_NAL_SEI_SUFFIX: u8 = 40; // SUFFIX_SEI_NUT #[allow(dead_code)] -const H265_NAL_AP: u8 = 48; // Aggregation Packet -const H265_NAL_FU: u8 = 49; // Fragmentation Unit +const H265_NAL_AP: u8 = 48; // Aggregation Packet +const H265_NAL_FU: u8 = 49; // Fragmentation Unit /// H.265 NAL header size const H265_NAL_HEADER_SIZE: usize = 2; @@ -228,7 +228,8 @@ impl H265Payloader { let fragment_size = remaining.min(max_fragment_size); // Create FU packet - let mut packet = BytesMut::with_capacity(H265_NAL_HEADER_SIZE + H265_FU_HEADER_SIZE + fragment_size); + let mut packet = + BytesMut::with_capacity(H265_NAL_HEADER_SIZE + H265_FU_HEADER_SIZE + fragment_size); // NAL header for FU (2 bytes) // Preserve F bit (bit 7) and LayerID MSB (bit 0) from original, set Type to 49 diff --git a/src/webrtc/mod.rs b/src/webrtc/mod.rs index 5f6baf39..ac264a7c 100644 --- a/src/webrtc/mod.rs +++ b/src/webrtc/mod.rs @@ -42,5 +42,7 @@ pub use rtp::{H264VideoTrack, H264VideoTrackConfig, OpusAudioTrack}; pub use session::WebRtcSessionManager; pub use 
signaling::{ConnectionState, IceCandidate, SdpAnswer, SdpOffer, SignalingMessage}; pub use universal_session::{UniversalSession, UniversalSessionConfig, UniversalSessionInfo}; -pub use video_track::{UniversalVideoTrack, UniversalVideoTrackConfig, VideoCodec, VideoTrackStats}; +pub use video_track::{ + UniversalVideoTrack, UniversalVideoTrackConfig, VideoCodec, VideoTrackStats, +}; pub use webrtc_streamer::{SessionInfo, WebRtcStreamer, WebRtcStreamerConfig, WebRtcStreamerStats}; diff --git a/src/webrtc/peer.rs b/src/webrtc/peer.rs index 49636676..6705aae1 100644 --- a/src/webrtc/peer.rs +++ b/src/webrtc/peer.rs @@ -92,10 +92,9 @@ impl PeerConnection { }; // Create peer connection - let pc = api - .new_peer_connection(rtc_config) - .await - .map_err(|e| AppError::VideoError(format!("Failed to create peer connection: {}", e)))?; + let pc = api.new_peer_connection(rtc_config).await.map_err(|e| { + AppError::VideoError(format!("Failed to create peer connection: {}", e)) + })?; let pc = Arc::new(pc); @@ -125,68 +124,69 @@ impl PeerConnection { let session_id = self.session_id.clone(); // Connection state change handler - self.pc.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - let state = state.clone(); - let session_id = session_id.clone(); + self.pc + .on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { + let state = state.clone(); + let session_id = session_id.clone(); - Box::pin(async move { - let new_state = match s { - RTCPeerConnectionState::New => ConnectionState::New, - RTCPeerConnectionState::Connecting => ConnectionState::Connecting, - RTCPeerConnectionState::Connected => ConnectionState::Connected, - RTCPeerConnectionState::Disconnected => ConnectionState::Disconnected, - RTCPeerConnectionState::Failed => ConnectionState::Failed, - RTCPeerConnectionState::Closed => ConnectionState::Closed, - _ => return, - }; + Box::pin(async move { + let new_state = match s { + RTCPeerConnectionState::New => ConnectionState::New, + RTCPeerConnectionState::Connecting => ConnectionState::Connecting, + RTCPeerConnectionState::Connected => ConnectionState::Connected, + RTCPeerConnectionState::Disconnected => ConnectionState::Disconnected, + RTCPeerConnectionState::Failed => ConnectionState::Failed, + RTCPeerConnectionState::Closed => ConnectionState::Closed, + _ => return, + }; - info!("Peer {} connection state: {}", session_id, new_state); - let _ = state.send(new_state); - }) - })); + info!("Peer {} connection state: {}", session_id, new_state); + let _ = state.send(new_state); + }) + })); // ICE candidate handler let ice_candidates = self.ice_candidates.clone(); - self.pc.on_ice_candidate(Box::new(move |candidate: Option| { - let ice_candidates = ice_candidates.clone(); + self.pc + .on_ice_candidate(Box::new(move |candidate: Option| { + let ice_candidates = ice_candidates.clone(); - Box::pin(async move { - if let Some(c) = candidate { - let candidate_str = c.to_json() - .map(|j| j.candidate) - .unwrap_or_default(); + Box::pin(async move { + if let Some(c) = candidate { + let candidate_str = c.to_json().map(|j| j.candidate).unwrap_or_default(); - debug!("ICE candidate: {}", candidate_str); + debug!("ICE candidate: {}", candidate_str); - let mut candidates = ice_candidates.lock().await; - candidates.push(IceCandidate { - candidate: candidate_str, - sdp_mid: c.to_json().ok().and_then(|j| j.sdp_mid), - sdp_mline_index: c.to_json().ok().and_then(|j| j.sdp_mline_index), - username_fragment: None, - }); - } - }) - })); + let mut candidates = 
ice_candidates.lock().await; + candidates.push(IceCandidate { + candidate: candidate_str, + sdp_mid: c.to_json().ok().and_then(|j| j.sdp_mid), + sdp_mline_index: c.to_json().ok().and_then(|j| j.sdp_mline_index), + username_fragment: None, + }); + } + }) + })); // Data channel handler - note: HID processing is done when hid_controller is set let data_channel = self.data_channel.clone(); - self.pc.on_data_channel(Box::new(move |dc: Arc| { - let data_channel = data_channel.clone(); + self.pc + .on_data_channel(Box::new(move |dc: Arc| { + let data_channel = data_channel.clone(); - Box::pin(async move { - info!("Data channel opened: {}", dc.label()); + Box::pin(async move { + info!("Data channel opened: {}", dc.label()); - // Store data channel - *data_channel.write().await = Some(dc.clone()); + // Store data channel + *data_channel.write().await = Some(dc.clone()); - // Message handler logs messages; HID processing requires set_hid_controller() - dc.on_message(Box::new(move |msg: DataChannelMessage| { - debug!("DataChannel message: {} bytes", msg.data.len()); - Box::pin(async {}) - })); - }) - })); + // Message handler logs messages; HID processing requires set_hid_controller() + dc.on_message(Box::new(move |msg: DataChannelMessage| { + debug!("DataChannel message: {} bytes", msg.data.len()); + Box::pin(async {}) + })); + }) + })); } /// Set HID controller for processing DataChannel messages @@ -206,7 +206,11 @@ impl PeerConnection { let is_hid_channel = label == "hid" || label == "hid-unreliable"; if is_hid_channel { - info!("HID DataChannel opened: {} (unreliable: {})", label, label == "hid-unreliable"); + info!( + "HID DataChannel opened: {} (unreliable: {})", + label, + label == "hid-unreliable" + ); // Store the reliable data channel for sending responses if label == "hid" { @@ -291,10 +295,9 @@ impl PeerConnection { let sdp = RTCSessionDescription::offer(offer.sdp) .map_err(|e| AppError::VideoError(format!("Invalid SDP offer: {}", e)))?; - self.pc - .set_remote_description(sdp) - .await - .map_err(|e| AppError::VideoError(format!("Failed to set remote description: {}", e)))?; + self.pc.set_remote_description(sdp).await.map_err(|e| { + AppError::VideoError(format!("Failed to set remote description: {}", e)) + })?; // Create answer let answer = self @@ -373,7 +376,11 @@ impl PeerConnection { // Reset HID state to release any held keys/buttons if let Some(ref hid) = self.hid_controller { if let Err(e) = hid.reset().await { - tracing::warn!("Failed to reset HID on peer {} close: {}", self.session_id, e); + tracing::warn!( + "Failed to reset HID on peer {} close: {}", + self.session_id, + e + ); } else { tracing::debug!("HID reset on peer {} close", self.session_id); } diff --git a/src/webrtc/rtp.rs b/src/webrtc/rtp.rs index 325feb47..4dbdf26b 100644 --- a/src/webrtc/rtp.rs +++ b/src/webrtc/rtp.rs @@ -21,9 +21,9 @@ use tokio::sync::Mutex; use tracing::{debug, error, trace}; use webrtc::media::io::h264_reader::H264Reader; use webrtc::media::Sample; +use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; use webrtc::track::track_local::TrackLocal; -use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use crate::error::{AppError, Result}; use crate::video::format::Resolution; @@ -168,7 +168,12 @@ impl H264VideoTrack { /// * `data` - H264 Annex B encoded frame data /// * `duration` - Frame duration (typically 1/fps seconds) /// * `is_keyframe` - Whether this is a keyframe (IDR frame) - pub async 
fn write_frame(&self, data: &[u8], _duration: Duration, is_keyframe: bool) -> Result<()> { + pub async fn write_frame( + &self, + data: &[u8], + _duration: Duration, + is_keyframe: bool, + ) -> Result<()> { if data.is_empty() { return Ok(()); } @@ -324,9 +329,9 @@ impl H264VideoTrack { let mut payloader = self.payloader.lock().await; let bytes = Bytes::copy_from_slice(data); - payloader.payload(mtu, &bytes).map_err(|e| { - AppError::VideoError(format!("H264 packetization failed: {}", e)) - }) + payloader + .payload(mtu, &bytes) + .map_err(|e| AppError::VideoError(format!("H264 packetization failed: {}", e))) } /// Get configuration @@ -423,7 +428,10 @@ impl OpusAudioTrack { let mut stats = self.stats.lock().await; stats.errors += 1; error!("Failed to write Opus sample: {}", e); - Err(AppError::WebRtcError(format!("Failed to write audio sample: {}", e))) + Err(AppError::WebRtcError(format!( + "Failed to write audio sample: {}", + e + ))) } } } diff --git a/src/webrtc/track.rs b/src/webrtc/track.rs index 4f5f6c65..e144a8be 100644 --- a/src/webrtc/track.rs +++ b/src/webrtc/track.rs @@ -4,9 +4,9 @@ use std::sync::Arc; use std::time::Instant; use tokio::sync::{broadcast, watch, Mutex}; use tracing::{debug, error, info}; +use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; use webrtc::track::track_local::TrackLocalWriter; -use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use crate::video::frame::VideoFrame; @@ -56,7 +56,9 @@ impl VideoCodecType { pub fn sdp_fmtp(&self) -> &'static str { match self { - VideoCodecType::H264 => "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + VideoCodecType::H264 => { + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f" + } VideoCodecType::VP8 => "", VideoCodecType::VP9 => "profile-id=0", } @@ -156,10 +158,7 @@ impl VideoTrack { } /// Start sending frames from a broadcast receiver - pub async fn start_sending( - &self, - mut frame_rx: broadcast::Receiver, - ) { + pub async fn start_sending(&self, mut frame_rx: broadcast::Receiver) { let _ = self.running.send(true); let track = self.track.clone(); let stats = self.stats.clone(); diff --git a/src/webrtc/universal_session.rs b/src/webrtc/universal_session.rs index 3fb8e1e2..f091408d 100644 --- a/src/webrtc/universal_session.rs +++ b/src/webrtc/universal_session.rs @@ -18,7 +18,9 @@ use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; use webrtc::peer_connection::RTCPeerConnection; -use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType}; +use webrtc::rtp_transceiver::rtp_codec::{ + RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType, +}; use webrtc::rtp_transceiver::RTCPFeedback; use super::config::WebRtcConfig; @@ -192,7 +194,8 @@ impl UniversalSession { clock_rate: 90000, channels: 0, // Match browser's fmtp format for profile-id=1 - sdp_fmtp_line: "level-id=180;profile-id=1;tier-flag=0;tx-mode=SRST".to_owned(), + sdp_fmtp_line: "level-id=180;profile-id=1;tier-flag=0;tx-mode=SRST" + .to_owned(), rtcp_feedback: video_rtcp_feedback.clone(), }, payload_type: 49, // Use same payload type as browser @@ -200,7 +203,9 @@ impl UniversalSession { }, RTPCodecType::Video, ) - .map_err(|e| AppError::VideoError(format!("Failed to register H.265 codec: {}", e)))?; + 
.map_err(|e| { + AppError::VideoError(format!("Failed to register H.265 codec: {}", e)) + })?; // Also register profile-id=2 (Main 10) variant media_engine @@ -210,7 +215,8 @@ impl UniversalSession { mime_type: MIME_TYPE_H265.to_owned(), clock_rate: 90000, channels: 0, - sdp_fmtp_line: "level-id=180;profile-id=2;tier-flag=0;tx-mode=SRST".to_owned(), + sdp_fmtp_line: "level-id=180;profile-id=2;tier-flag=0;tx-mode=SRST" + .to_owned(), rtcp_feedback: video_rtcp_feedback, }, payload_type: 51, @@ -218,7 +224,12 @@ impl UniversalSession { }, RTPCodecType::Video, ) - .map_err(|e| AppError::VideoError(format!("Failed to register H.265 codec (profile 2): {}", e)))?; + .map_err(|e| { + AppError::VideoError(format!( + "Failed to register H.265 codec (profile 2): {}", + e + )) + })?; info!("Registered H.265/HEVC codec for session {}", session_id); } @@ -269,10 +280,9 @@ impl UniversalSession { ..Default::default() }; - let pc = api - .new_peer_connection(rtc_config) - .await - .map_err(|e| AppError::VideoError(format!("Failed to create peer connection: {}", e)))?; + let pc = api.new_peer_connection(rtc_config).await.map_err(|e| { + AppError::VideoError(format!("Failed to create peer connection: {}", e)) + })?; let pc = Arc::new(pc); @@ -291,7 +301,10 @@ impl UniversalSession { pc.add_track(audio.as_track_local()) .await .map_err(|e| AppError::AudioError(format!("Failed to add audio track: {}", e)))?; - info!("Opus audio track added to peer connection (session {})", session_id); + info!( + "Opus audio track added to peer connection (session {})", + session_id + ); } // Create state channel @@ -479,11 +492,13 @@ impl UniversalSession { &self, mut frame_rx: broadcast::Receiver, on_connected: F, - ) - where + ) where F: FnOnce() + Send + 'static, { - info!("Starting {} session {} with shared encoder", self.codec, self.session_id); + info!( + "Starting {} session {} with shared encoder", + self.codec, self.session_id + ); let video_track = self.video_track.clone(); let mut state_rx = self.state_rx.clone(); @@ -492,7 +507,10 @@ impl UniversalSession { let expected_codec = self.codec; let handle = tokio::spawn(async move { - info!("Video receiver waiting for connection for session {}", session_id); + info!( + "Video receiver waiting for connection for session {}", + session_id + ); // Wait for Connected state before sending frames loop { @@ -500,7 +518,10 @@ impl UniversalSession { if current_state == ConnectionState::Connected { break; } - if matches!(current_state, ConnectionState::Closed | ConnectionState::Failed) { + if matches!( + current_state, + ConnectionState::Closed | ConnectionState::Failed + ) { info!("Session {} closed before connecting", session_id); return; } @@ -509,7 +530,10 @@ impl UniversalSession { } } - info!("Video receiver started for session {} (ICE connected)", session_id); + info!( + "Video receiver started for session {} (ICE connected)", + session_id + ); // Request keyframe now that connection is established on_connected(); @@ -592,7 +616,10 @@ impl UniversalSession { } } - info!("Video receiver stopped for session {} (sent {} frames)", session_id, frames_sent); + info!( + "Video receiver stopped for session {} (sent {} frames)", + session_id, frames_sent + ); }); *self.video_receiver_handle.lock().await = Some(handle); @@ -620,7 +647,10 @@ impl UniversalSession { if current_state == ConnectionState::Connected { break; } - if matches!(current_state, ConnectionState::Closed | ConnectionState::Failed) { + if matches!( + current_state, + ConnectionState::Closed | 
ConnectionState::Failed + ) { info!("Session {} closed before audio could start", session_id); return; } @@ -629,7 +659,10 @@ impl UniversalSession { } } - info!("Audio receiver started for session {} (ICE connected)", session_id); + info!( + "Audio receiver started for session {} (ICE connected)", + session_id + ); let mut packets_sent: u64 = 0; @@ -673,7 +706,10 @@ impl UniversalSession { } } - info!("Audio receiver stopped for session {} (sent {} packets)", session_id, packets_sent); + info!( + "Audio receiver stopped for session {} (sent {} packets)", + session_id, packets_sent + ); }); *self.audio_receiver_handle.lock().await = Some(handle); @@ -697,8 +733,7 @@ impl UniversalSession { || offer.sdp.to_lowercase().contains("hevc"); info!( "[SDP] Session {} offer contains H.265: {}", - self.session_id, - has_h265 + self.session_id, has_h265 ); if !has_h265 { warn!("[SDP] Browser offer does not include H.265 codec! Session may fail."); @@ -708,10 +743,9 @@ impl UniversalSession { let sdp = RTCSessionDescription::offer(offer.sdp) .map_err(|e| AppError::VideoError(format!("Invalid SDP offer: {}", e)))?; - self.pc - .set_remote_description(sdp) - .await - .map_err(|e| AppError::VideoError(format!("Failed to set remote description: {}", e)))?; + self.pc.set_remote_description(sdp).await.map_err(|e| { + AppError::VideoError(format!("Failed to set remote description: {}", e)) + })?; let answer = self .pc @@ -725,8 +759,7 @@ impl UniversalSession { || answer.sdp.to_lowercase().contains("hevc"); info!( "[SDP] Session {} answer contains H.265: {}", - self.session_id, - has_h265 + self.session_id, has_h265 ); if !has_h265 { warn!("[SDP] Answer does not include H.265! Codec negotiation may have failed."); @@ -821,9 +854,21 @@ mod tests { #[test] fn test_encoder_type_to_video_codec() { - assert_eq!(encoder_type_to_video_codec(VideoEncoderType::H264), VideoCodec::H264); - assert_eq!(encoder_type_to_video_codec(VideoEncoderType::H265), VideoCodec::H265); - assert_eq!(encoder_type_to_video_codec(VideoEncoderType::VP8), VideoCodec::VP8); - assert_eq!(encoder_type_to_video_codec(VideoEncoderType::VP9), VideoCodec::VP9); + assert_eq!( + encoder_type_to_video_codec(VideoEncoderType::H264), + VideoCodec::H264 + ); + assert_eq!( + encoder_type_to_video_codec(VideoEncoderType::H265), + VideoCodec::H265 + ); + assert_eq!( + encoder_type_to_video_codec(VideoEncoderType::VP8), + VideoCodec::VP8 + ); + assert_eq!( + encoder_type_to_video_codec(VideoEncoderType::VP9), + VideoCodec::VP9 + ); } } diff --git a/src/webrtc/webrtc_streamer.rs b/src/webrtc/webrtc_streamer.rs index 7dda63bf..510540fe 100644 --- a/src/webrtc/webrtc_streamer.rs +++ b/src/webrtc/webrtc_streamer.rs @@ -41,12 +41,14 @@ use crate::audio::shared_pipeline::{SharedAudioPipeline, SharedAudioPipelineConf use crate::audio::{AudioController, OpusFrame}; use crate::error::{AppError, Result}; use crate::hid::HidController; -use crate::video::encoder::registry::VideoEncoderType; use crate::video::encoder::registry::EncoderBackend; +use crate::video::encoder::registry::VideoEncoderType; use crate::video::encoder::VideoCodecType; use crate::video::format::{PixelFormat, Resolution}; use crate::video::frame::VideoFrame; -use crate::video::shared_video_pipeline::{SharedVideoPipeline, SharedVideoPipelineConfig, SharedVideoPipelineStats}; +use crate::video::shared_video_pipeline::{ + SharedVideoPipeline, SharedVideoPipelineConfig, SharedVideoPipelineStats, +}; use super::config::{TurnServer, WebRtcConfig}; use super::signaling::{ConnectionState, IceCandidate, 
SdpAnswer, SdpOffer}; @@ -489,7 +491,9 @@ impl WebRtcStreamer { } } } else { - info!("No video pipeline exists yet, frame source will be used when pipeline is created"); + info!( + "No video pipeline exists yet, frame source will be used when pipeline is created" + ); } } @@ -517,24 +521,21 @@ impl WebRtcStreamer { /// Only restarts the encoding pipeline if configuration actually changed. /// This allows multiple consumers (WebRTC, RustDesk) to share the same pipeline /// without interrupting each other when they call this method with the same config. - pub async fn update_video_config( - &self, - resolution: Resolution, - format: PixelFormat, - fps: u32, - ) { + pub async fn update_video_config(&self, resolution: Resolution, format: PixelFormat, fps: u32) { // Check if configuration actually changed let config = self.config.read().await; - let config_changed = config.resolution != resolution - || config.input_format != format - || config.fps != fps; + let config_changed = + config.resolution != resolution || config.input_format != format || config.fps != fps; drop(config); if !config_changed { // Configuration unchanged, no need to restart pipeline trace!( "Video config unchanged: {}x{} {:?} @ {} fps", - resolution.width, resolution.height, format, fps + resolution.width, + resolution.height, + format, + fps ); return; } @@ -554,7 +555,10 @@ impl WebRtcStreamer { // Close all existing sessions - they need to reconnect let session_count = self.close_all_sessions().await; if session_count > 0 { - info!("Closed {} existing sessions due to config change", session_count); + info!( + "Closed {} existing sessions due to config change", + session_count + ); } // Update config (preserve user-configured bitrate) @@ -581,17 +585,17 @@ impl WebRtcStreamer { // Close all existing sessions - they need to reconnect with new encoder let session_count = self.close_all_sessions().await; if session_count > 0 { - info!("Closed {} existing sessions due to encoder backend change", session_count); + info!( + "Closed {} existing sessions due to encoder backend change", + session_count + ); } // Update config let mut config = self.config.write().await; config.encoder_backend = encoder_backend; - info!( - "WebRTC encoder backend updated: {:?}", - encoder_backend - ); + info!("WebRTC encoder backend updated: {:?}", encoder_backend); } /// Check if current encoder configuration uses hardware encoding @@ -694,7 +698,11 @@ impl WebRtcStreamer { let codec = *self.video_codec.read().await; // Ensure video pipeline is running - let frame_tx = self.video_frame_tx.read().await.clone() + let frame_tx = self + .video_frame_tx + .read() + .await + .clone() .ok_or_else(|| AppError::VideoError("No video frame source".to_string()))?; let pipeline = self.ensure_video_pipeline(frame_tx).await?; @@ -729,15 +737,20 @@ impl WebRtcStreamer { // Request keyframe after ICE connection is established (via callback) let pipeline_for_callback = pipeline.clone(); let session_id_for_callback = session_id.clone(); - session.start_from_video_pipeline(pipeline.subscribe(), move || { - // Spawn async task to request keyframe - let pipeline = pipeline_for_callback; - let sid = session_id_for_callback; - tokio::spawn(async move { - info!("Requesting keyframe for session {} after ICE connected", sid); - pipeline.request_keyframe().await; - }); - }).await; + session + .start_from_video_pipeline(pipeline.subscribe(), move || { + // Spawn async task to request keyframe + let pipeline = pipeline_for_callback; + let sid = session_id_for_callback; + 
tokio::spawn(async move { + info!( + "Requesting keyframe for session {} after ICE connected", + sid + ); + pipeline.request_keyframe().await; + }); + }) + .await; // Start audio if enabled if session_config.audio_enabled { @@ -863,7 +876,9 @@ impl WebRtcStreamer { .filter(|(_, s)| { matches!( s.state(), - ConnectionState::Closed | ConnectionState::Failed | ConnectionState::Disconnected + ConnectionState::Closed + | ConnectionState::Failed + | ConnectionState::Disconnected ) }) .map(|(id, _)| id.clone()) @@ -967,10 +982,7 @@ impl WebRtcStreamer { }; if pipeline_running { - info!( - "Restarting video pipeline to apply new bitrate: {}", - preset - ); + info!("Restarting video pipeline to apply new bitrate: {}", preset); // Save video_frame_tx BEFORE stopping pipeline (monitor task will clear it) let saved_frame_tx = self.video_frame_tx.read().await.clone(); @@ -1005,13 +1017,18 @@ impl WebRtcStreamer { info!("Reconnecting session {} to new pipeline", session_id); let pipeline_for_callback = pipeline.clone(); let sid = session_id.clone(); - session.start_from_video_pipeline(pipeline.subscribe(), move || { - let pipeline = pipeline_for_callback; - tokio::spawn(async move { - info!("Requesting keyframe for session {} after reconnect", sid); - pipeline.request_keyframe().await; - }); - }).await; + session + .start_from_video_pipeline(pipeline.subscribe(), move || { + let pipeline = pipeline_for_callback; + tokio::spawn(async move { + info!( + "Requesting keyframe for session {} after reconnect", + sid + ); + pipeline.request_keyframe().await; + }); + }) + .await; } } diff --git a/web/src/api/index.ts b/web/src/api/index.ts index b2a624af..4f0d7200 100644 --- a/web/src/api/index.ts +++ b/web/src/api/index.ts @@ -227,10 +227,10 @@ export const streamApi = { getSnapshotUrl: () => `${API_BASE}/snapshot`, getMode: () => - request<{ success: boolean; mode: string; message?: string }>('/stream/mode'), + request<{ success: boolean; mode: string; transition_id?: string; switching?: boolean; message?: string }>('/stream/mode'), setMode: (mode: string) => - request<{ success: boolean; mode: string; message?: string }>('/stream/mode', { + request<{ success: boolean; mode: string; transition_id?: string; switching?: boolean; message?: string }>('/stream/mode', { method: 'POST', body: JSON.stringify({ mode }), }), diff --git a/web/src/components/VideoConfigPopover.vue b/web/src/components/VideoConfigPopover.vue index cbf014c6..9d1529bd 100644 --- a/web/src/components/VideoConfigPopover.vue +++ b/web/src/components/VideoConfigPopover.vue @@ -189,6 +189,7 @@ const selectedFormat = ref('') const selectedResolution = ref('') const selectedFps = ref(30) const selectedBitratePreset = ref<'Speed' | 'Balanced' | 'Quality'>('Balanced') +const isDirty = ref(false) // UI state const applying = ref(false) @@ -327,6 +328,25 @@ function initializeFromCurrent() { selectedFormat.value = config.format selectedResolution.value = `${config.width}x${config.height}` selectedFps.value = config.fps + isDirty.value = false +} + +function syncFromCurrentIfChanged() { + const config = currentConfig.value + const nextResolution = `${config.width}x${config.height}` + + if (selectedDevice.value === config.device + && selectedFormat.value === config.format + && selectedResolution.value === nextResolution + && selectedFps.value === config.fps) { + return + } + + selectedDevice.value = config.device + selectedFormat.value = config.format + selectedResolution.value = nextResolution + selectedFps.value = config.fps + isDirty.value = false } 
// Handle video mode change @@ -339,6 +359,7 @@ function handleVideoModeChange(mode: unknown) { function handleDeviceChange(devicePath: unknown) { if (typeof devicePath !== 'string') return selectedDevice.value = devicePath + isDirty.value = true // Auto-select first format const device = devices.value.find(d => d.path === devicePath) @@ -358,6 +379,7 @@ function handleDeviceChange(devicePath: unknown) { function handleFormatChange(format: unknown) { if (typeof format !== 'string') return selectedFormat.value = format + isDirty.value = true // Auto-select first resolution for this format const formatData = availableFormats.value.find(f => f.format === format) @@ -372,6 +394,7 @@ function handleFormatChange(format: unknown) { function handleResolutionChange(resolution: unknown) { if (typeof resolution !== 'string') return selectedResolution.value = resolution + isDirty.value = true // Auto-select first FPS for this resolution const resolutionData = availableResolutions.value.find( @@ -386,6 +409,7 @@ function handleResolutionChange(resolution: unknown) { function handleFpsChange(fps: unknown) { if (typeof fps !== 'string' && typeof fps !== 'number') return selectedFps.value = typeof fps === 'string' ? Number(fps) : fps + isDirty.value = true } // Apply bitrate preset change @@ -427,6 +451,7 @@ async function applyVideoConfig() { }) toast.success(t('config.applied')) + isDirty.value = false // Stream state will be updated via WebSocket system.device_info event } catch (e) { console.info('[VideoConfig] Failed to apply config:', e) @@ -455,8 +480,17 @@ watch(() => props.open, (isOpen) => { loadEncoderBackend() // Initialize from current config initializeFromCurrent() + } else { + isDirty.value = false } }) + +// Sync selected values when backend config changes (e.g., auto format switch on mode change) +watch(currentConfig, () => { + if (applying.value) return + if (props.open && isDirty.value) return + syncFromCurrentIfChanged() +}, { deep: true })