From 935fa823f27dbc843597998f227ae44daa6c54b6 Mon Sep 17 00:00:00 2001 From: mofeng-git Date: Mon, 18 May 2026 22:43:28 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=A2=9E=E5=8A=A0=20?= =?UTF-8?q?Windows=20=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 49 +- agents.md | 14 + libs/hwcodec/build.rs | 59 +- libs/hwcodec/cpp/common/util.cpp | 15 +- libs/hwcodec/cpp/ffmpeg_capture.cpp | 879 ++++ libs/hwcodec/cpp/ffmpeg_capture_ffi.h | 64 + libs/hwcodec/src/capture.rs | 297 ++ libs/hwcodec/src/ffmpeg_ram/encode.rs | 8 +- libs/hwcodec/src/lib.rs | 2 + res/vcpkg/libyuv/build.rs | 2 + src/atx/controller.rs | 224 +- src/atx/disabled_key.rs | 34 + src/atx/disabled_led.rs | 34 + src/atx/executor.rs | 592 +-- src/atx/gpio_linux.rs | 106 + src/atx/hidraw_linux.rs | 190 + src/atx/led.rs | 33 +- src/atx/mod.rs | 65 +- src/atx/serial_relay.rs | 141 + src/atx/traits.rs | 51 + src/atx/types.rs | 53 +- src/atx/wol.rs | 74 +- src/audio/capture.rs | 339 +- src/audio/capture_linux.rs | 334 ++ src/audio/capture_windows.rs | 516 +++ src/audio/controller.rs | 382 +- src/audio/device.rs | 206 +- src/audio/device_linux.rs | 201 + src/audio/device_windows.rs | 232 ++ src/audio/mod.rs | 10 +- src/audio/recovery.rs | 320 ++ src/audio/types.rs | 85 + src/config/mod.rs | 8 +- src/config/persistence.rs | 5 - src/config/schema.rs | 827 ---- src/config/schema/atx.rs | 28 + src/config/schema/common.rs | 64 + src/config/schema/hid.rs | 309 ++ src/config/schema/mod.rs | 44 + src/config/schema/stream.rs | 149 + src/config/schema/web.rs | 129 + src/config/store.rs | 35 +- src/diagnostics/linux.rs | 280 ++ src/diagnostics/mod.rs | 47 + src/diagnostics/windows.rs | 249 ++ src/extensions/manager.rs | 138 +- src/extensions/mod.rs | 7 +- src/extensions/software.rs | 15 + src/extensions/software_linux.rs | 19 + src/extensions/software_windows.rs | 47 + src/extensions/types.rs | 13 +- src/hid/ch9329.rs | 283 +- src/hid/ch9329_proto.rs | 225 ++ src/hid/factory.rs | 80 + src/hid/mod.rs | 162 +- src/hid/otg.rs | 208 +- src/hid/otg_device.rs | 50 + src/lib.rs | 7 + src/main.rs | 136 +- src/otg/mod.rs | 23 + src/otg/self_check.rs | 740 ++++ src/platform/capabilities.rs | 89 + src/platform/defaults.rs | 62 + src/platform/linux.rs | 23 + src/platform/mod.rs | 10 + src/{video => platform}/usb_reset.rs | 31 - src/platform/windows.rs | 33 + src/redfish/auth.rs | 6 +- src/redfish/routes/account.rs | 22 +- src/redfish/routes/chassis.rs | 6 +- src/redfish/routes/event.rs | 3 +- src/redfish/routes/managers.rs | 7 +- src/redfish/routes/mod.rs | 5 +- src/redfish/routes/session.rs | 7 +- src/redfish/routes/systems.rs | 11 +- src/redfish/routes/virtual_media.rs | 11 +- src/redfish/schema.rs | 5 +- src/rtsp/bitstream.rs | 4 +- src/rtsp/codec.rs | 2 +- src/rtsp/sdp.rs | 11 +- src/rtsp/streaming.rs | 4 +- src/rustdesk/connection.rs | 22 +- src/rustdesk/frame_adapters.rs | 2 +- src/rustdesk/rendezvous.rs | 1 + src/state.rs | 49 +- src/stream/mjpeg.rs | 4 +- src/stream_encoder.rs | 2 +- src/utils/host.rs | 14 +- src/utils/mod.rs | 6 + src/utils/net_disabled.rs | 14 + src/utils/serial.rs | 12 + .../{v4l2r_capture.rs => capture/linux.rs} | 14 +- src/video/capture/mod.rs | 12 + src/video/capture/runtime.rs | 70 + .../{capture_status.rs => capture/status.rs} | 2 +- src/video/capture/windows.rs | 181 + src/video/capture_limits.rs | 30 - src/video/{ => codec}/convert.rs | 0 src/video/{encoder => codec}/h264.rs | 0 src/video/codec/h264_bitstream.rs | 299 ++ src/video/{encoder => codec}/h265.rs | 0 src/video/{encoder => codec}/jpeg.rs | 2 +- src/video/{decoder => codec}/mjpeg_rkmpp.rs | 2 +- src/video/{decoder => codec}/mjpeg_turbo.rs | 0 src/video/{encoder => codec}/mod.rs | 50 +- src/video/{encoder => codec}/registry.rs | 0 src/video/{encoder => codec}/self_check.rs | 0 src/video/{encoder => codec}/traits.rs | 0 .../codec.rs => codec/video_codec.rs} | 2 +- src/video/{encoder => codec}/vp8.rs | 0 src/video/{encoder => codec}/vp9.rs | 0 src/video/codec_constraints.rs | 4 +- src/video/decoder/mod.rs | 7 - src/video/{csi_bridge.rs => device/bridge.rs} | 2 +- src/video/device/disabled_bridge.rs | 80 + src/video/{device.rs => device/linux.rs} | 4 +- src/video/device/mod.rs | 35 + src/video/device/windows.rs | 359 ++ src/video/format.rs | 3 + src/video/mod.rs | 83 +- .../encoder_state.rs | 18 +- src/video/pipeline/mod.rs | 9 + .../shared.rs} | 164 +- src/video/signal.rs | 47 + src/video/stream_manager.rs | 26 +- src/video/streamer.rs | 61 +- src/video/types.rs | 12 +- src/web/handlers/account.rs | 151 + src/web/handlers/atx_api.rs | 197 + src/web/handlers/audio_api.rs | 83 + src/web/handlers/auth.rs | 107 + src/web/handlers/config/apply.rs | 21 +- src/web/handlers/config/atx.rs | 27 + src/web/handlers/config/mod.rs | 2 + src/web/handlers/config/rustdesk.rs | 5 +- src/web/handlers/config/types.rs | 21 +- src/web/handlers/devices.rs | 17 +- src/web/handlers/extensions.rs | 112 +- src/web/handlers/hid_api.rs | 53 + src/web/handlers/inventory.rs | 182 + src/web/handlers/mod.rs | 3542 +---------------- src/web/handlers/msd_api.rs | 405 ++ src/web/handlers/setup.rs | 261 ++ src/web/handlers/stream.rs | 626 +++ src/web/handlers/system.rs | 113 + src/web/handlers/terminal.rs | 28 +- src/web/handlers/update_api.rs | 31 + src/web/handlers/webrtc.rs | 194 + src/web/routes.rs | 84 +- src/webrtc/rtp.rs | 244 +- src/webrtc/universal_session.rs | 76 +- src/webrtc/video_track.rs | 50 +- src/webrtc/webrtc_streamer.rs | 57 +- web/src/api/index.ts | 38 +- web/src/components/ActionBar.vue | 8 +- web/src/components/VideoConfigPopover.vue | 3 +- web/src/i18n/en-US.ts | 6 + web/src/i18n/zh-CN.ts | 6 + web/src/lib/video-device-label.ts | 51 + web/src/stores/system.ts | 15 +- web/src/views/ConsoleView.vue | 11 +- web/src/views/SettingsView.vue | 141 +- web/src/views/SetupView.vue | 38 +- 163 files changed, 11419 insertions(+), 7581 deletions(-) create mode 100644 agents.md create mode 100644 libs/hwcodec/cpp/ffmpeg_capture.cpp create mode 100644 libs/hwcodec/cpp/ffmpeg_capture_ffi.h create mode 100644 libs/hwcodec/src/capture.rs create mode 100644 src/atx/disabled_key.rs create mode 100644 src/atx/disabled_led.rs create mode 100644 src/atx/gpio_linux.rs create mode 100644 src/atx/hidraw_linux.rs create mode 100644 src/atx/serial_relay.rs create mode 100644 src/atx/traits.rs create mode 100644 src/audio/capture_linux.rs create mode 100644 src/audio/capture_windows.rs create mode 100644 src/audio/device_linux.rs create mode 100644 src/audio/device_windows.rs create mode 100644 src/audio/recovery.rs create mode 100644 src/audio/types.rs delete mode 100644 src/config/persistence.rs delete mode 100644 src/config/schema.rs create mode 100644 src/config/schema/atx.rs create mode 100644 src/config/schema/common.rs create mode 100644 src/config/schema/hid.rs create mode 100644 src/config/schema/mod.rs create mode 100644 src/config/schema/stream.rs create mode 100644 src/config/schema/web.rs create mode 100644 src/diagnostics/linux.rs create mode 100644 src/diagnostics/mod.rs create mode 100644 src/diagnostics/windows.rs create mode 100644 src/extensions/software.rs create mode 100644 src/extensions/software_linux.rs create mode 100644 src/extensions/software_windows.rs create mode 100644 src/hid/ch9329_proto.rs create mode 100644 src/hid/factory.rs create mode 100644 src/hid/otg_device.rs create mode 100644 src/otg/self_check.rs create mode 100644 src/platform/capabilities.rs create mode 100644 src/platform/defaults.rs create mode 100644 src/platform/linux.rs create mode 100644 src/platform/mod.rs rename src/{video => platform}/usb_reset.rs (75%) create mode 100644 src/platform/windows.rs create mode 100644 src/utils/net_disabled.rs create mode 100644 src/utils/serial.rs rename src/video/{v4l2r_capture.rs => capture/linux.rs} (98%) create mode 100644 src/video/capture/mod.rs create mode 100644 src/video/capture/runtime.rs rename src/video/{capture_status.rs => capture/status.rs} (98%) create mode 100644 src/video/capture/windows.rs delete mode 100644 src/video/capture_limits.rs rename src/video/{ => codec}/convert.rs (100%) rename src/video/{encoder => codec}/h264.rs (100%) create mode 100644 src/video/codec/h264_bitstream.rs rename src/video/{encoder => codec}/h265.rs (100%) rename src/video/{encoder => codec}/jpeg.rs (99%) rename src/video/{decoder => codec}/mjpeg_rkmpp.rs (98%) rename src/video/{decoder => codec}/mjpeg_turbo.rs (100%) rename src/video/{encoder => codec}/mod.rs (71%) rename src/video/{encoder => codec}/registry.rs (100%) rename src/video/{encoder => codec}/self_check.rs (100%) rename src/video/{encoder => codec}/traits.rs (100%) rename src/video/{encoder/codec.rs => codec/video_codec.rs} (99%) rename src/video/{encoder => codec}/vp8.rs (100%) rename src/video/{encoder => codec}/vp9.rs (100%) delete mode 100644 src/video/decoder/mod.rs rename src/video/{csi_bridge.rs => device/bridge.rs} (99%) create mode 100644 src/video/device/disabled_bridge.rs rename src/video/{device.rs => device/linux.rs} (99%) create mode 100644 src/video/device/mod.rs create mode 100644 src/video/device/windows.rs rename src/video/{shared_video_pipeline => pipeline}/encoder_state.rs (97%) create mode 100644 src/video/pipeline/mod.rs rename src/video/{shared_video_pipeline.rs => pipeline/shared.rs} (93%) create mode 100644 src/video/signal.rs create mode 100644 src/web/handlers/account.rs create mode 100644 src/web/handlers/atx_api.rs create mode 100644 src/web/handlers/audio_api.rs create mode 100644 src/web/handlers/auth.rs create mode 100644 src/web/handlers/hid_api.rs create mode 100644 src/web/handlers/inventory.rs create mode 100644 src/web/handlers/msd_api.rs create mode 100644 src/web/handlers/setup.rs create mode 100644 src/web/handlers/stream.rs create mode 100644 src/web/handlers/system.rs create mode 100644 src/web/handlers/update_api.rs create mode 100644 src/web/handlers/webrtc.rs create mode 100644 web/src/lib/video-device-label.ts diff --git a/Cargo.toml b/Cargo.toml index ddb2921f..62b4f763 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,12 +63,6 @@ clap = { version = "4", features = ["derive"] } # Time (cookie max_age + RFC3339 timestamps) time = { version = "0.3", features = ["serde", "formatting", "parsing"] } -# Video capture (V4L2) -v4l2r = "0.0.7" - -# JPEG encoding (libjpeg-turbo, SIMD accelerated) -turbojpeg = "1.3" - # Bytes handling bytes = "1" bytemuck = { version = "1.24", features = ["derive"] } @@ -93,11 +87,6 @@ rtp = "0.14" rtsp-types = "0.1" sdp-types = "0.1" -# Audio (ALSA capture + Opus encoding) -# Note: audiopus links to libopus.so (unavoidable for audio support) -alsa = "0.11" -audiopus = "0.2" - # HID (serial port for CH9329) serialport = "4" async-trait = "0.1" @@ -106,22 +95,42 @@ libc = "0.2" # Ventoy bootable image support ventoy-img = { path = "libs/ventoy-img-rs" } -# ATX (GPIO control) -gpio-cdev = "0.6" - -# H264 hardware/software encoding (hwcodec from rustdesk) -hwcodec = { path = "libs/hwcodec" } - # RustDesk protocol support protobuf = { version = "3.7", features = ["with-bytes"] } sodiumoxide = "0.2" sha2 = "0.10" -# High-performance pixel format conversion (libyuv) -libyuv = { path = "res/vcpkg/libyuv" } - # TypeScript type generation typeshare = "1.0" +[target.'cfg(any(unix, windows))'.dependencies] +# Video encoding/decoding (FFmpeg/libjpeg-turbo/libyuv; available on Windows and Linux) +hwcodec = { path = "libs/hwcodec" } +libyuv = { path = "res/vcpkg/libyuv" } +turbojpeg = "1.3" +# Note: audiopus links to libopus.so (unavoidable for audio support) +audiopus = "0.2" + +[target.'cfg(unix)'.dependencies] +# Video capture (V4L2) +v4l2r = "0.0.7" + +# Audio (ALSA capture) +alsa = "0.11" + +# ATX (GPIO control) +gpio-cdev = "0.6" + +[target.'cfg(windows)'.dependencies] +cpal = { version = "0.17", default-features = false } +windows-sys = { version = "0.61", features = [ + "Win32_Foundation", + "Win32_NetworkManagement_IpHelper", + "Win32_NetworkManagement_Ndis", + "Win32_Networking_WinSock", + "Win32_System_SystemInformation", + "Win32_System_Threading", +] } + [dev-dependencies] tempfile = "3" diff --git a/agents.md b/agents.md new file mode 100644 index 00000000..5e73defe --- /dev/null +++ b/agents.md @@ -0,0 +1,14 @@ +# Agents Notes + +## Windows MSVC Build + +Run from the repository root in PowerShell: + +```powershell +$env:VCPKG_ROOT='C:\Users\mofen\code\vcpkg' +$env:TURBOJPEG_SOURCE='explicit' +$env:TURBOJPEG_LIB_DIR='C:\Users\mofen\code\vcpkg\installed\x64-windows-static\lib' +$env:TURBOJPEG_INCLUDE_DIR='C:\Users\mofen\code\vcpkg\installed\x64-windows-static\include' + +cargo build --target x86_64-pc-windows-msvc +``` diff --git a/libs/hwcodec/build.rs b/libs/hwcodec/build.rs index ba799ba8..a8d9b292 100644 --- a/libs/hwcodec/build.rs +++ b/libs/hwcodec/build.rs @@ -34,7 +34,9 @@ fn build_common(builder: &mut Build) { // system #[cfg(windows)] { - ["d3d11", "dxgi"].map(|lib| println!("cargo:rustc-link-lib={}", lib)); + for lib in ["d3d11", "dxgi"] { + println!("cargo:rustc-link-lib={}", lib); + } } builder.include(&common_dir); @@ -99,6 +101,7 @@ mod ffmpeg { link_os(); build_ffmpeg_ram(builder); build_ffmpeg_hw(builder); + build_ffmpeg_capture(builder); } /// Link system FFmpeg using pkg-config or custom path @@ -282,15 +285,24 @@ mod ffmpeg { ) ); { - // Only need avcodec and avutil for encoding + // avdevice/avformat are needed by the Windows DirectShow capture bridge. let mut static_libs = vec!["avcodec", "avutil"]; if target_os == "windows" { - static_libs.push("libmfx"); + static_libs.extend([ + "avformat", + "avdevice", + "avfilter", + "swresample", + "swscale", + "vpx", + "libx264", + "x265-static", + "libmfx", + ]); + } + for lib in static_libs { + println!("cargo:rustc-link-lib=static={}", lib); } - static_libs - .iter() - .map(|lib| println!("cargo:rustc-link-lib=static={}", lib)) - .count(); } let include = path.join("include"); @@ -304,7 +316,10 @@ mod ffmpeg { let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap(); let dyn_libs: Vec<&str> = if target_os == "windows" { - ["User32", "bcrypt", "ole32", "advapi32"].to_vec() + [ + "User32", "bcrypt", "ole32", "advapi32", "mfuuid", "strmiids", + ] + .to_vec() } else if target_os == "linux" { // Base libraries for all Linux platforms let mut v = vec!["drm", "stdc++"]; @@ -375,6 +390,34 @@ mod ffmpeg { } } + fn build_ffmpeg_capture(builder: &mut Build) { + let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default(); + if target_os != "windows" { + return; + } + + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let capture_header = manifest_dir + .join("cpp") + .join("ffmpeg_capture_ffi.h") + .to_string_lossy() + .to_string(); + bindgen::builder() + .header(capture_header) + .rustified_enum("*") + .generate() + .unwrap() + .write_to_file( + Path::new(&env::var_os("OUT_DIR").unwrap()).join("ffmpeg_capture_ffi.rs"), + ) + .unwrap(); + + builder.file(manifest_dir.join("cpp").join("ffmpeg_capture.cpp")); + println!("cargo:rustc-link-lib=strmiids"); + println!("cargo:rustc-link-lib=oleaut32"); + println!("cargo:rustc-link-lib=quartz"); + } + fn build_ffmpeg_hw(builder: &mut Build) { let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); let ffmpeg_hw_dir = manifest_dir.join("cpp").join("ffmpeg_hw"); diff --git a/libs/hwcodec/cpp/common/util.cpp b/libs/hwcodec/cpp/common/util.cpp index 9661bc2c..8f5e19a6 100644 --- a/libs/hwcodec/cpp/common/util.cpp +++ b/libs/hwcodec/cpp/common/util.cpp @@ -1,4 +1,5 @@ extern "C" { +#include #include } @@ -99,13 +100,12 @@ void set_av_codec_ctx(AVCodecContext *c, const std::string &name, int kbs, c->color_primaries = AVCOL_PRI_SMPTE170M; c->color_trc = AVCOL_TRC_SMPTE170M; - // Profile selection: use BASELINE for software H264 (faster, simpler) - if (is_software_h264(name)) { - c->profile = FF_PROFILE_H264_BASELINE; // Simpler profile for real-time - } else if (name.find("h264") != std::string::npos) { - c->profile = FF_PROFILE_H264_HIGH; + // WebRTC SDP advertises constrained baseline. Keep hardware and software + // encoders on the same browser-friendly H264 profile. + if (name.find("h264") != std::string::npos) { + c->profile = AV_PROFILE_H264_CONSTRAINED_BASELINE; } else if (name.find("hevc") != std::string::npos) { - c->profile = FF_PROFILE_HEVC_MAIN; + c->profile = AV_PROFILE_HEVC_MAIN; } } @@ -120,8 +120,7 @@ bool set_lantency_free(void *priv_data, const std::string &name) { } if (name.find("amf") != std::string::npos) { if ((ret = av_opt_set(priv_data, "query_timeout", "1000", 0)) < 0) { - LOG_ERROR(std::string("amf set_lantency_free failed, ret = ") + av_err2str(ret)); - return false; + LOG_WARN(std::string("amf query_timeout option is unavailable, ret = ") + av_err2str(ret)); } } if (name.find("qsv") != std::string::npos) { diff --git a/libs/hwcodec/cpp/ffmpeg_capture.cpp b/libs/hwcodec/cpp/ffmpeg_capture.cpp new file mode 100644 index 00000000..ab0adc2c --- /dev/null +++ b/libs/hwcodec/cpp/ffmpeg_capture.cpp @@ -0,0 +1,879 @@ +#define NOMINMAX +#include "ffmpeg_capture_ffi.h" + +#include +#include +#include +extern "C" { +#include +#include +#include +#include +#include +#include +} +#include +#include +#include +#include +#include +#include + + +#pragma comment(lib, "strmiids") + +thread_local std::string g_last_error; + +struct HwcodecDshowCaptureContext { + AVFormatContext* format_ctx = nullptr; + int stream_index = -1; + int width = 0; + int height = 0; + int pixel_format = HWCODEC_CAPTURE_FMT_UNKNOWN; + int stride = 0; + int timeout_ms = 2000; + std::atomic deadline_ms{0}; + std::atomic timed_out{0}; + uint64_t sequence = 0; +}; + +namespace { +struct DshowCapabilityEntry { + std::string format; + int width = 0; + int height = 0; + std::vector fps; +}; + +const char* requested_pixel_format_name(int requested_format); + +void set_last_error(const std::string& message) { + g_last_error = message; +} + +std::string ffmpeg_error(int errnum) { + char buffer[AV_ERROR_MAX_STRING_SIZE] = {0}; + av_make_error_string(buffer, sizeof(buffer), errnum); + return std::string(buffer); +} + +long long now_ms() { + return static_cast(GetTickCount64()); +} + +std::string wide_to_utf8(const wchar_t* value) { + if (!value) { + return std::string(); + } + int size = WideCharToMultiByte(CP_UTF8, 0, value, -1, nullptr, 0, nullptr, nullptr); + if (size <= 1) { + return std::string(); + } + std::string result(static_cast(size - 1), '\0'); + WideCharToMultiByte( + CP_UTF8, + 0, + value, + -1, + result.empty() ? nullptr : &result[0], + size, + nullptr, + nullptr); + return result; +} + +void add_fps_candidate(std::vector* fps, LONGLONG interval_100ns) { + if (!fps || interval_100ns <= 0) { + return; + } + + double fps_value = 10000000.0 / static_cast(interval_100ns); + int rounded = static_cast(fps_value + 0.5); + if (rounded <= 0) { + return; + } + if (std::find(fps->begin(), fps->end(), rounded) == fps->end()) { + fps->push_back(rounded); + } +} + +void normalize_fps(std::vector* fps) { + if (!fps) { + return; + } + std::sort(fps->begin(), fps->end(), std::greater()); + fps->erase(std::unique(fps->begin(), fps->end()), fps->end()); +} + +const char* media_subtype_to_format(const GUID& subtype) { + if (subtype == MEDIASUBTYPE_MJPG) { + return "MJPEG"; + } + if (subtype == MEDIASUBTYPE_YUY2) { + return "YUYV"; + } + if (subtype == MEDIASUBTYPE_UYVY) { + return "UYVY"; + } + if (subtype == MEDIASUBTYPE_YVYU) { + return "YVYU"; + } + if (subtype == MEDIASUBTYPE_NV12) { + return "NV12"; + } + if (subtype == MEDIASUBTYPE_RGB24) { + return "RGB24"; + } + if (subtype == MEDIASUBTYPE_RGB32) { + return "BGR24"; + } + if (subtype == MEDIASUBTYPE_IYUV) { + return "YUV420"; + } + if (subtype == MEDIASUBTYPE_YV12) { + return "YVU420"; + } + return nullptr; +} + +void free_media_type(AM_MEDIA_TYPE* media_type) { + if (!media_type) { + return; + } + if (media_type->cbFormat != 0) { + CoTaskMemFree(media_type->pbFormat); + media_type->cbFormat = 0; + media_type->pbFormat = nullptr; + } + if (media_type->pUnk != nullptr) { + media_type->pUnk->Release(); + media_type->pUnk = nullptr; + } + CoTaskMemFree(media_type); +} + +bool fill_capability_entry( + const AM_MEDIA_TYPE* media_type, + const VIDEO_STREAM_CONFIG_CAPS* caps, + DshowCapabilityEntry* out_entry) { + if (!media_type || !out_entry) { + return false; + } + + const char* format = media_subtype_to_format(media_type->subtype); + if (!format) { + return false; + } + + LONG width = 0; + LONG height = 0; + REFERENCE_TIME avg_time_per_frame = 0; + + if (media_type->formattype == FORMAT_VideoInfo && media_type->pbFormat && + media_type->cbFormat >= sizeof(VIDEOINFOHEADER)) { + const auto* info = reinterpret_cast(media_type->pbFormat); + width = info->bmiHeader.biWidth; + height = std::abs(info->bmiHeader.biHeight); + avg_time_per_frame = info->AvgTimePerFrame; + } else if (media_type->formattype == FORMAT_VideoInfo2 && media_type->pbFormat && + media_type->cbFormat >= sizeof(VIDEOINFOHEADER2)) { + const auto* info = reinterpret_cast(media_type->pbFormat); + width = info->bmiHeader.biWidth; + height = std::abs(info->bmiHeader.biHeight); + avg_time_per_frame = info->AvgTimePerFrame; + } + + if ((width <= 0 || height <= 0) && caps) { + width = std::max(caps->InputSize.cx, caps->MinOutputSize.cx); + height = std::max(caps->InputSize.cy, caps->MinOutputSize.cy); + if (width <= 0 || height <= 0) { + width = caps->MaxOutputSize.cx; + height = caps->MaxOutputSize.cy; + } + } + + if (width <= 0 || height <= 0) { + return false; + } + + out_entry->format = format; + out_entry->width = static_cast(width); + out_entry->height = static_cast(height); + out_entry->fps.clear(); + + add_fps_candidate(&out_entry->fps, avg_time_per_frame); + if (caps) { + add_fps_candidate(&out_entry->fps, caps->MinFrameInterval); + add_fps_candidate(&out_entry->fps, caps->MaxFrameInterval); + } + normalize_fps(&out_entry->fps); + return true; +} + +void append_stream_capabilities(IAMStreamConfig* stream_config, std::vector* entries) { + if (!stream_config || !entries) { + return; + } + + int cap_count = 0; + int cap_size = 0; + HRESULT hr = stream_config->GetNumberOfCapabilities(&cap_count, &cap_size); + if (FAILED(hr) || cap_count <= 0 || cap_size < static_cast(sizeof(VIDEO_STREAM_CONFIG_CAPS))) { + return; + } + + std::vector caps_buffer(static_cast(cap_size)); + for (int index = 0; index < cap_count; ++index) { + AM_MEDIA_TYPE* media_type = nullptr; + hr = stream_config->GetStreamCaps(index, &media_type, caps_buffer.data()); + if (FAILED(hr) || !media_type) { + continue; + } + + DshowCapabilityEntry entry; + const auto* caps = reinterpret_cast(caps_buffer.data()); + if (fill_capability_entry(media_type, caps, &entry)) { + entries->push_back(std::move(entry)); + } + + free_media_type(media_type); + } +} + +bool find_device_filter(const std::string& device_name, IBaseFilter** out_filter) { + if (!out_filter) { + return false; + } + *out_filter = nullptr; + + ICreateDevEnum* dev_enum = nullptr; + IEnumMoniker* enum_moniker = nullptr; + HRESULT hr = CoCreateInstance( + CLSID_SystemDeviceEnum, + nullptr, + CLSCTX_INPROC_SERVER, + IID_ICreateDevEnum, + reinterpret_cast(&dev_enum)); + if (FAILED(hr) || !dev_enum) { + return false; + } + + hr = dev_enum->CreateClassEnumerator(CLSID_VideoInputDeviceCategory, &enum_moniker, 0); + dev_enum->Release(); + if (hr != S_OK || !enum_moniker) { + return false; + } + + bool found = false; + IMoniker* moniker = nullptr; + ULONG fetched = 0; + while (!found && enum_moniker->Next(1, &moniker, &fetched) == S_OK) { + IPropertyBag* bag = nullptr; + hr = moniker->BindToStorage(nullptr, nullptr, IID_IPropertyBag, reinterpret_cast(&bag)); + if (SUCCEEDED(hr) && bag) { + VARIANT name; + VariantInit(&name); + if (SUCCEEDED(bag->Read(L"FriendlyName", &name, nullptr)) && name.vt == VT_BSTR) { + auto utf8_name = wide_to_utf8(name.bstrVal); + if (utf8_name == device_name) { + hr = moniker->BindToObject(nullptr, nullptr, IID_IBaseFilter, reinterpret_cast(out_filter)); + found = SUCCEEDED(hr) && *out_filter != nullptr; + } + } + VariantClear(&name); + bag->Release(); + } + moniker->Release(); + } + enum_moniker->Release(); + return found; +} + +std::string build_capabilities_payload(const std::vector& entries) { + std::string payload; + for (size_t i = 0; i < entries.size(); ++i) { + const auto& entry = entries[i]; + payload += entry.format; + payload.push_back('|'); + payload += std::to_string(entry.width); + payload.push_back('|'); + payload += std::to_string(entry.height); + payload.push_back('|'); + for (size_t fps_index = 0; fps_index < entry.fps.size(); ++fps_index) { + payload += std::to_string(entry.fps[fps_index]); + if (fps_index + 1 < entry.fps.size()) { + payload.push_back(','); + } + } + if (i + 1 < entries.size()) { + payload.push_back('\n'); + } + } + return payload; +} + +char* copy_payload(const std::string& payload) { + char* out = reinterpret_cast(std::malloc(payload.size() + 1)); + if (!out) { + set_last_error("Failed to allocate capture payload buffer"); + return nullptr; + } + std::memcpy(out, payload.c_str(), payload.size() + 1); + return out; +} + +int open_dshow_input_with_options( + AVFormatContext** format_ctx, + const AVInputFormat* input, + const std::string& device_name, + int width, + int height, + int fps, + int requested_format, + bool use_video_size, + bool use_framerate, + bool use_pixel_format, + std::string* attempt_desc) { + if (!format_ctx || !input) { + return AVERROR(EINVAL); + } + + AVDictionary* options = nullptr; + std::vector parts; + + if (use_video_size && width > 0 && height > 0) { + std::string video_size = std::to_string(width) + "x" + std::to_string(height); + av_dict_set(&options, "video_size", video_size.c_str(), 0); + parts.push_back("video_size=" + video_size); + } + if (use_framerate && fps > 0) { + std::string framerate = std::to_string(fps); + av_dict_set(&options, "framerate", framerate.c_str(), 0); + parts.push_back("framerate=" + framerate); + } + + av_dict_set(&options, "rtbufsize", "64M", 0); + parts.push_back("rtbufsize=64M"); + + const char* pixel_format_name = requested_pixel_format_name(requested_format); + if (use_pixel_format && pixel_format_name) { + av_dict_set(&options, "pixel_format", pixel_format_name, 0); + parts.push_back(std::string("pixel_format=") + pixel_format_name); + } + + if (attempt_desc) { + *attempt_desc = parts.empty() ? "default options" : "options{"; + if (!parts.empty()) { + for (size_t i = 0; i < parts.size(); ++i) { + if (i > 0) { + attempt_desc->append(", "); + } + attempt_desc->append(parts[i]); + } + attempt_desc->append("}"); + } + } + + std::string input_name = "video=" + device_name; + int ret = avformat_open_input(format_ctx, input_name.c_str(), input, &options); + av_dict_free(&options); + return ret; +} + +class ScopedComInit { + public: + ScopedComInit() { + HRESULT hr = CoInitializeEx(nullptr, COINIT_MULTITHREADED); + initialized_ = hr == S_OK || hr == S_FALSE; + } + + ~ScopedComInit() { + if (initialized_) { + CoUninitialize(); + } + } + + private: + bool initialized_ = false; +}; + +int capture_stride(int pixel_format, int width) { + switch (pixel_format) { + case HWCODEC_CAPTURE_FMT_YUYV: + case HWCODEC_CAPTURE_FMT_YVYU: + case HWCODEC_CAPTURE_FMT_UYVY: + return width * 2; + case HWCODEC_CAPTURE_FMT_RGB24: + case HWCODEC_CAPTURE_FMT_BGR24: + return width * 3; + case HWCODEC_CAPTURE_FMT_NV24: + return width * 2; + case HWCODEC_CAPTURE_FMT_NV12: + case HWCODEC_CAPTURE_FMT_NV21: + case HWCODEC_CAPTURE_FMT_NV16: + case HWCODEC_CAPTURE_FMT_YUV420: + case HWCODEC_CAPTURE_FMT_YVU420: + case HWCODEC_CAPTURE_FMT_GREY: + case HWCODEC_CAPTURE_FMT_MJPEG: + case HWCODEC_CAPTURE_FMT_JPEG: + default: + return width; + } +} + +int map_raw_pixfmt(int format) { + switch (format) { + case AV_PIX_FMT_YUYV422: + return HWCODEC_CAPTURE_FMT_YUYV; + case AV_PIX_FMT_UYVY422: + return HWCODEC_CAPTURE_FMT_UYVY; +#ifdef AV_PIX_FMT_YVYU422 + case AV_PIX_FMT_YVYU422: + return HWCODEC_CAPTURE_FMT_YVYU; +#endif + case AV_PIX_FMT_NV12: + return HWCODEC_CAPTURE_FMT_NV12; + case AV_PIX_FMT_NV21: + return HWCODEC_CAPTURE_FMT_NV21; +#ifdef AV_PIX_FMT_NV16 + case AV_PIX_FMT_NV16: + return HWCODEC_CAPTURE_FMT_NV16; +#endif +#ifdef AV_PIX_FMT_NV24 + case AV_PIX_FMT_NV24: + return HWCODEC_CAPTURE_FMT_NV24; +#endif + case AV_PIX_FMT_YUV420P: + return HWCODEC_CAPTURE_FMT_YUV420; +#ifdef AV_PIX_FMT_YVU420P + case AV_PIX_FMT_YVU420P: + return HWCODEC_CAPTURE_FMT_YVU420; +#endif + case AV_PIX_FMT_RGB24: + return HWCODEC_CAPTURE_FMT_RGB24; + case AV_PIX_FMT_BGR24: + return HWCODEC_CAPTURE_FMT_BGR24; + case AV_PIX_FMT_GRAY8: + return HWCODEC_CAPTURE_FMT_GREY; + default: + return HWCODEC_CAPTURE_FMT_UNKNOWN; + } +} + +int map_codec_to_capture_format(const AVCodecParameters* codecpar) { + if (!codecpar) { + return HWCODEC_CAPTURE_FMT_UNKNOWN; + } + + switch (codecpar->codec_id) { + case AV_CODEC_ID_MJPEG: + return HWCODEC_CAPTURE_FMT_MJPEG; + case AV_CODEC_ID_JPEG2000: + return HWCODEC_CAPTURE_FMT_JPEG; + case AV_CODEC_ID_RAWVIDEO: + return map_raw_pixfmt(codecpar->format); + default: + return HWCODEC_CAPTURE_FMT_UNKNOWN; + } +} + +int interrupt_callback(void* opaque) { + auto* ctx = reinterpret_cast(opaque); + if (!ctx) { + return 0; + } + auto deadline = ctx->deadline_ms.load(); + if (deadline <= 0) { + return 0; + } + if (now_ms() > deadline) { + ctx->timed_out.store(1); + return 1; + } + return 0; +} + +const char* requested_pixel_format_name(int requested_format) { + switch (requested_format) { + case HWCODEC_CAPTURE_FMT_YUYV: + return "yuyv422"; + case HWCODEC_CAPTURE_FMT_UYVY: + return "uyvy422"; + case HWCODEC_CAPTURE_FMT_NV12: + return "nv12"; + case HWCODEC_CAPTURE_FMT_NV21: + return "nv21"; + case HWCODEC_CAPTURE_FMT_RGB24: + return "rgb24"; + case HWCODEC_CAPTURE_FMT_BGR24: + return "bgr24"; + case HWCODEC_CAPTURE_FMT_GREY: + return "gray"; + default: + return nullptr; + } +} +} // namespace + +extern "C" const char* hwcodec_capture_last_error(void) { + return g_last_error.c_str(); +} + +extern "C" char* hwcodec_dshow_list_video_devices(void) { + ScopedComInit com; + + ICreateDevEnum* dev_enum = nullptr; + IEnumMoniker* enum_moniker = nullptr; + HRESULT hr = CoCreateInstance( + CLSID_SystemDeviceEnum, + nullptr, + CLSCTX_INPROC_SERVER, + IID_ICreateDevEnum, + reinterpret_cast(&dev_enum)); + if (FAILED(hr)) { + set_last_error("Failed to create DirectShow device enumerator"); + return nullptr; + } + + hr = dev_enum->CreateClassEnumerator(CLSID_VideoInputDeviceCategory, &enum_moniker, 0); + dev_enum->Release(); + if (hr != S_OK || !enum_moniker) { + char* out = reinterpret_cast(std::malloc(1)); + if (out) { + out[0] = '\0'; + } + return out; + } + + std::vector devices; + IMoniker* moniker = nullptr; + ULONG fetched = 0; + while (enum_moniker->Next(1, &moniker, &fetched) == S_OK) { + IPropertyBag* bag = nullptr; + hr = moniker->BindToStorage(nullptr, nullptr, IID_IPropertyBag, reinterpret_cast(&bag)); + if (SUCCEEDED(hr) && bag) { + VARIANT name; + VariantInit(&name); + if (SUCCEEDED(bag->Read(L"FriendlyName", &name, nullptr)) && name.vt == VT_BSTR) { + auto utf8_name = wide_to_utf8(name.bstrVal); + if (!utf8_name.empty()) { + devices.push_back(utf8_name); + } + } + VariantClear(&name); + bag->Release(); + } + moniker->Release(); + } + enum_moniker->Release(); + + std::string payload; + for (size_t i = 0; i < devices.size(); ++i) { + payload += devices[i]; + if (i + 1 < devices.size()) { + payload.push_back('\n'); + } + } + + return copy_payload(payload); +} + +extern "C" char* hwcodec_dshow_list_device_capabilities(const char* device_name) { + if (!device_name || device_name[0] == '\0') { + set_last_error("DirectShow device name is empty"); + return nullptr; + } + + ScopedComInit com; + IBaseFilter* filter = nullptr; + if (!find_device_filter(device_name, &filter) || !filter) { + set_last_error("Failed to find DirectShow device filter"); + return nullptr; + } + + std::vector entries; + IEnumPins* enum_pins = nullptr; + HRESULT hr = filter->EnumPins(&enum_pins); + if (SUCCEEDED(hr) && enum_pins) { + IPin* pin = nullptr; + ULONG fetched = 0; + while (enum_pins->Next(1, &pin, &fetched) == S_OK) { + PIN_DIRECTION direction = PINDIR_INPUT; + if (SUCCEEDED(pin->QueryDirection(&direction)) && direction == PINDIR_OUTPUT) { + IAMStreamConfig* stream_config = nullptr; + if (SUCCEEDED(pin->QueryInterface(IID_IAMStreamConfig, reinterpret_cast(&stream_config))) && + stream_config) { + append_stream_capabilities(stream_config, &entries); + stream_config->Release(); + } + } + pin->Release(); + } + enum_pins->Release(); + } + filter->Release(); + + std::sort(entries.begin(), entries.end(), [](const DshowCapabilityEntry& left, const DshowCapabilityEntry& right) { + if (left.format != right.format) { + return left.format < right.format; + } + if (left.width != right.width) { + return left.width < right.width; + } + if (left.height != right.height) { + return left.height < right.height; + } + return left.fps > right.fps; + }); + entries.erase( + std::unique(entries.begin(), entries.end(), [](const DshowCapabilityEntry& left, const DshowCapabilityEntry& right) { + return left.format == right.format && left.width == right.width && left.height == right.height && left.fps == right.fps; + }), + entries.end()); + + return copy_payload(build_capabilities_payload(entries)); +} + +extern "C" void hwcodec_capture_string_free(char* ptr) { + if (ptr) { + std::free(ptr); + } +} + +extern "C" HwcodecDshowCaptureContext* hwcodec_dshow_capture_open( + const char* device_name, + int width, + int height, + int fps, + int requested_format, + int timeout_ms) { + if (!device_name || device_name[0] == '\0') { + set_last_error("Device name is empty"); + return nullptr; + } + + avdevice_register_all(); + + const AVInputFormat* input = av_find_input_format("dshow"); + if (!input) { + set_last_error("FFmpeg dshow input format is unavailable"); + return nullptr; + } + + auto* ctx = new HwcodecDshowCaptureContext(); + ctx->timeout_ms = timeout_ms > 0 ? timeout_ms : 2000; + ctx->format_ctx = avformat_alloc_context(); + if (!ctx->format_ctx) { + delete ctx; + set_last_error("Failed to allocate FFmpeg format context"); + return nullptr; + } + ctx->format_ctx->interrupt_callback.callback = interrupt_callback; + ctx->format_ctx->interrupt_callback.opaque = ctx; + + std::string open_attempt; + int ret = open_dshow_input_with_options( + &ctx->format_ctx, + input, + device_name, + width, + height, + fps, + requested_format, + true, + true, + true, + &open_attempt); + + if (ret < 0) { + avformat_free_context(ctx->format_ctx); + ctx->format_ctx = avformat_alloc_context(); + if (!ctx->format_ctx) { + delete ctx; + set_last_error("Failed to allocate FFmpeg format context for fallback open"); + return nullptr; + } + ctx->format_ctx->interrupt_callback.callback = interrupt_callback; + ctx->format_ctx->interrupt_callback.opaque = ctx; + + std::string fallback_attempt; + ret = open_dshow_input_with_options( + &ctx->format_ctx, + input, + device_name, + width, + height, + fps, + requested_format, + true, + false, + true, + &fallback_attempt); + if (ret >= 0) { + open_attempt = fallback_attempt; + } + } + + if (ret < 0) { + avformat_free_context(ctx->format_ctx); + ctx->format_ctx = avformat_alloc_context(); + if (!ctx->format_ctx) { + delete ctx; + set_last_error("Failed to allocate FFmpeg format context for final fallback open"); + return nullptr; + } + ctx->format_ctx->interrupt_callback.callback = interrupt_callback; + ctx->format_ctx->interrupt_callback.opaque = ctx; + + std::string fallback_attempt; + ret = open_dshow_input_with_options( + &ctx->format_ctx, + input, + device_name, + width, + height, + fps, + requested_format, + false, + false, + false, + &fallback_attempt); + if (ret >= 0) { + open_attempt = fallback_attempt; + } + } + + if (ret < 0) { + set_last_error("Failed to open dshow input (" + open_attempt + "): " + ffmpeg_error(ret)); + avformat_free_context(ctx->format_ctx); + delete ctx; + return nullptr; + } + + ret = avformat_find_stream_info(ctx->format_ctx, nullptr); + if (ret < 0) { + set_last_error("Failed to read stream info: " + ffmpeg_error(ret)); + avformat_close_input(&ctx->format_ctx); + delete ctx; + return nullptr; + } + + for (unsigned int i = 0; i < ctx->format_ctx->nb_streams; ++i) { + AVStream* stream = ctx->format_ctx->streams[i]; + if (stream && stream->codecpar && stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) { + ctx->stream_index = static_cast(i); + ctx->width = stream->codecpar->width > 0 ? stream->codecpar->width : width; + ctx->height = stream->codecpar->height > 0 ? stream->codecpar->height : height; + ctx->pixel_format = map_codec_to_capture_format(stream->codecpar); + ctx->stride = capture_stride(ctx->pixel_format, ctx->width); + break; + } + } + + if (ctx->stream_index < 0) { + set_last_error("No video stream found on DirectShow device"); + avformat_close_input(&ctx->format_ctx); + delete ctx; + return nullptr; + } + + if (ctx->pixel_format == HWCODEC_CAPTURE_FMT_UNKNOWN) { + set_last_error("DirectShow stream format is unsupported in current Windows backend"); + avformat_close_input(&ctx->format_ctx); + delete ctx; + return nullptr; + } + + return ctx; +} + +extern "C" int hwcodec_dshow_capture_info( + HwcodecDshowCaptureContext* ctx, + HwcodecCaptureStreamInfo* out_info) { + if (!ctx || !out_info) { + set_last_error("Invalid capture context"); + return -1; + } + + out_info->width = ctx->width; + out_info->height = ctx->height; + out_info->pixel_format = ctx->pixel_format; + out_info->stride = ctx->stride; + return 0; +} + +extern "C" int hwcodec_dshow_capture_read( + HwcodecDshowCaptureContext* ctx, + uint8_t** out_data, + int* out_len, + uint64_t* out_sequence) { + if (!ctx || !out_data || !out_len || !out_sequence) { + set_last_error("Invalid capture read arguments"); + return -1; + } + + *out_data = nullptr; + *out_len = 0; + *out_sequence = 0; + + AVPacket packet; + av_init_packet(&packet); + packet.data = nullptr; + packet.size = 0; + + while (true) { + ctx->timed_out.store(0); + ctx->deadline_ms.store(now_ms() + ctx->timeout_ms); + int ret = av_read_frame(ctx->format_ctx, &packet); + ctx->deadline_ms.store(0); + + if (ret < 0) { + if (ctx->timed_out.load() != 0) { + set_last_error("Timed out waiting for frame"); + return -110; + } + set_last_error("Failed to read frame: " + ffmpeg_error(ret)); + return ret; + } + + if (packet.stream_index != ctx->stream_index) { + av_packet_unref(&packet); + continue; + } + + if (packet.size <= 0 || !packet.data) { + av_packet_unref(&packet); + continue; + } + + auto* buffer = reinterpret_cast(std::malloc(static_cast(packet.size))); + if (!buffer) { + av_packet_unref(&packet); + set_last_error("Failed to allocate packet buffer"); + return -12; + } + + std::memcpy(buffer, packet.data, static_cast(packet.size)); + *out_data = buffer; + *out_len = packet.size; + *out_sequence = ctx->sequence++; + av_packet_unref(&packet); + return 0; + } +} + +extern "C" void hwcodec_dshow_capture_packet_free(uint8_t* data) { + if (data) { + std::free(data); + } +} + +extern "C" void hwcodec_dshow_capture_close(HwcodecDshowCaptureContext* ctx) { + if (!ctx) { + return; + } + if (ctx->format_ctx) { + avformat_close_input(&ctx->format_ctx); + } + delete ctx; +} diff --git a/libs/hwcodec/cpp/ffmpeg_capture_ffi.h b/libs/hwcodec/cpp/ffmpeg_capture_ffi.h new file mode 100644 index 00000000..9d180885 --- /dev/null +++ b/libs/hwcodec/cpp/ffmpeg_capture_ffi.h @@ -0,0 +1,64 @@ +#ifndef HWCODEC_FFMPEG_CAPTURE_FFI_H +#define HWCODEC_FFMPEG_CAPTURE_FFI_H + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct HwcodecDshowCaptureContext HwcodecDshowCaptureContext; + +enum HwcodecCapturePixelFormat { + HWCODEC_CAPTURE_FMT_UNKNOWN = 0, + HWCODEC_CAPTURE_FMT_MJPEG = 1, + HWCODEC_CAPTURE_FMT_JPEG = 2, + HWCODEC_CAPTURE_FMT_YUYV = 3, + HWCODEC_CAPTURE_FMT_YVYU = 4, + HWCODEC_CAPTURE_FMT_UYVY = 5, + HWCODEC_CAPTURE_FMT_NV12 = 6, + HWCODEC_CAPTURE_FMT_NV21 = 7, + HWCODEC_CAPTURE_FMT_NV16 = 8, + HWCODEC_CAPTURE_FMT_NV24 = 9, + HWCODEC_CAPTURE_FMT_YUV420 = 10, + HWCODEC_CAPTURE_FMT_YVU420 = 11, + HWCODEC_CAPTURE_FMT_RGB24 = 12, + HWCODEC_CAPTURE_FMT_BGR24 = 13, + HWCODEC_CAPTURE_FMT_GREY = 14, +}; + +typedef struct HwcodecCaptureStreamInfo { + int width; + int height; + int pixel_format; + int stride; +} HwcodecCaptureStreamInfo; + +const char* hwcodec_capture_last_error(void); +char* hwcodec_dshow_list_video_devices(void); +char* hwcodec_dshow_list_device_capabilities(const char* device_name); +void hwcodec_capture_string_free(char* ptr); + +HwcodecDshowCaptureContext* hwcodec_dshow_capture_open( + const char* device_name, + int width, + int height, + int fps, + int requested_format, + int timeout_ms); +int hwcodec_dshow_capture_info( + HwcodecDshowCaptureContext* ctx, + HwcodecCaptureStreamInfo* out_info); +int hwcodec_dshow_capture_read( + HwcodecDshowCaptureContext* ctx, + uint8_t** out_data, + int* out_len, + uint64_t* out_sequence); +void hwcodec_dshow_capture_packet_free(uint8_t* data); +void hwcodec_dshow_capture_close(HwcodecDshowCaptureContext* ctx); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/libs/hwcodec/src/capture.rs b/libs/hwcodec/src/capture.rs new file mode 100644 index 00000000..216fcaca --- /dev/null +++ b/libs/hwcodec/src/capture.rs @@ -0,0 +1,297 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +use std::ffi::{CStr, CString}; +use std::os::raw::c_int; + +include!(concat!(env!("OUT_DIR"), "/ffmpeg_capture_ffi.rs")); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CapturePixelFormat { + Unknown, + Mjpeg, + Jpeg, + Yuyv, + Yvyu, + Uyvy, + Nv12, + Nv21, + Nv16, + Nv24, + Yuv420, + Yvu420, + Rgb24, + Bgr24, + Grey, +} + +impl CapturePixelFormat { + pub fn to_ffi(self) -> c_int { + match self { + Self::Unknown => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UNKNOWN as c_int, + Self::Mjpeg => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_MJPEG as c_int, + Self::Jpeg => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_JPEG as c_int, + Self::Yuyv => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUYV as c_int, + Self::Yvyu => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVYU as c_int, + Self::Uyvy => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UYVY as c_int, + Self::Nv12 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV12 as c_int, + Self::Nv21 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV21 as c_int, + Self::Nv16 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV16 as c_int, + Self::Nv24 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV24 as c_int, + Self::Yuv420 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUV420 as c_int, + Self::Yvu420 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVU420 as c_int, + Self::Rgb24 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_RGB24 as c_int, + Self::Bgr24 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_BGR24 as c_int, + Self::Grey => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_GREY as c_int, + } + } + + pub fn from_ffi(value: c_int) -> Self { + match value { + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_MJPEG as c_int => Self::Mjpeg, + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_JPEG as c_int => Self::Jpeg, + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUYV as c_int => Self::Yuyv, + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVYU as c_int => Self::Yvyu, + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UYVY as c_int => Self::Uyvy, + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV12 as c_int => Self::Nv12, + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV21 as c_int => Self::Nv21, + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV16 as c_int => Self::Nv16, + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV24 as c_int => Self::Nv24, + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUV420 as c_int => { + Self::Yuv420 + } + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVU420 as c_int => { + Self::Yvu420 + } + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_RGB24 as c_int => Self::Rgb24, + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_BGR24 as c_int => Self::Bgr24, + x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_GREY as c_int => Self::Grey, + _ => Self::Unknown, + } + } + + pub fn from_name(name: &str) -> Option { + match name.trim().to_ascii_uppercase().as_str() { + "MJPEG" | "MJPG" => Some(Self::Mjpeg), + "JPEG" => Some(Self::Jpeg), + "YUYV" => Some(Self::Yuyv), + "YVYU" => Some(Self::Yvyu), + "UYVY" => Some(Self::Uyvy), + "NV12" => Some(Self::Nv12), + "NV21" => Some(Self::Nv21), + "NV16" => Some(Self::Nv16), + "NV24" => Some(Self::Nv24), + "YUV420" | "I420" | "IYUV" => Some(Self::Yuv420), + "YVU420" | "YV12" => Some(Self::Yvu420), + "RGB24" => Some(Self::Rgb24), + "BGR24" => Some(Self::Bgr24), + "GREY" | "GRAY" | "Y800" => Some(Self::Grey), + _ => None, + } + } +} + +#[derive(Debug, Clone)] +pub struct DshowCapability { + pub format: CapturePixelFormat, + pub width: u32, + pub height: u32, + pub fps: Vec, +} + +#[derive(Debug, Clone, Copy)] +pub struct CaptureStreamInfo { + pub width: i32, + pub height: i32, + pub pixel_format: CapturePixelFormat, + pub stride: i32, +} + +#[derive(Debug)] +pub struct CaptureError { + pub code: i32, + pub message: String, +} + +impl std::fmt::Display for CaptureError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for CaptureError {} + +fn last_error_message() -> String { + unsafe { + let ptr = hwcodec_capture_last_error(); + if ptr.is_null() { + return String::new(); + } + CStr::from_ptr(ptr).to_string_lossy().to_string() + } +} + +pub fn list_dshow_video_devices() -> Result, CaptureError> { + unsafe { + let ptr = hwcodec_dshow_list_video_devices(); + if ptr.is_null() { + return Err(CaptureError { + code: -1, + message: last_error_message(), + }); + } + let payload = CStr::from_ptr(ptr).to_string_lossy().to_string(); + hwcodec_capture_string_free(ptr as *mut _); + Ok(payload + .lines() + .map(str::trim) + .filter(|line| !line.is_empty()) + .map(ToOwned::to_owned) + .collect()) + } +} + +pub fn list_dshow_device_capabilities(device_name: &str) -> Result, CaptureError> { + let device_name = CString::new(device_name).map_err(|_| CaptureError { + code: -1, + message: "device name contains NUL byte".to_string(), + })?; + + unsafe { + let ptr = hwcodec_dshow_list_device_capabilities(device_name.as_ptr()); + if ptr.is_null() { + return Err(CaptureError { + code: -1, + message: last_error_message(), + }); + } + + let payload = CStr::from_ptr(ptr).to_string_lossy().to_string(); + hwcodec_capture_string_free(ptr as *mut _); + + let capabilities = payload + .lines() + .filter_map(parse_dshow_capability_line) + .collect(); + Ok(capabilities) + } +} + +fn parse_dshow_capability_line(line: &str) -> Option { + let mut parts = line.split('|'); + let format = CapturePixelFormat::from_name(parts.next()?.trim())?; + let width = parts.next()?.trim().parse::().ok()?; + let height = parts.next()?.trim().parse::().ok()?; + let fps = parts + .next() + .unwrap_or_default() + .split(',') + .filter_map(|value| value.trim().parse::().ok()) + .filter(|value| *value > 0) + .collect::>(); + + Some(DshowCapability { + format, + width, + height, + fps, + }) +} + +pub struct DshowCapture { + ctx: *mut HwcodecDshowCaptureContext, +} + +unsafe impl Send for DshowCapture {} + +impl DshowCapture { + pub fn open( + device_name: &str, + width: i32, + height: i32, + fps: i32, + requested_format: CapturePixelFormat, + timeout_ms: i32, + ) -> Result { + let device_name = CString::new(device_name).map_err(|_| CaptureError { + code: -1, + message: "device name contains NUL byte".to_string(), + })?; + unsafe { + let ctx = hwcodec_dshow_capture_open( + device_name.as_ptr(), + width, + height, + fps, + requested_format.to_ffi(), + timeout_ms, + ); + if ctx.is_null() { + return Err(CaptureError { + code: -1, + message: last_error_message(), + }); + } + Ok(Self { ctx }) + } + } + + pub fn info(&self) -> Result { + unsafe { + let mut info = HwcodecCaptureStreamInfo { + width: 0, + height: 0, + pixel_format: HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UNKNOWN as c_int, + stride: 0, + }; + let ret = hwcodec_dshow_capture_info(self.ctx, &mut info); + if ret != 0 { + return Err(CaptureError { + code: ret, + message: last_error_message(), + }); + } + Ok(CaptureStreamInfo { + width: info.width, + height: info.height, + pixel_format: CapturePixelFormat::from_ffi(info.pixel_format), + stride: info.stride, + }) + } + } + + pub fn read_packet(&mut self) -> Result<(Vec, u64), CaptureError> { + unsafe { + let mut data = std::ptr::null_mut(); + let mut len = 0; + let mut sequence = 0u64; + let ret = hwcodec_dshow_capture_read(self.ctx, &mut data, &mut len, &mut sequence); + if ret != 0 { + return Err(CaptureError { + code: ret, + message: last_error_message(), + }); + } + if data.is_null() || len <= 0 { + return Err(CaptureError { + code: -1, + message: "empty packet returned by capture backend".to_string(), + }); + } + let slice = std::slice::from_raw_parts(data, len as usize); + let vec = slice.to_vec(); + hwcodec_dshow_capture_packet_free(data); + Ok((vec, sequence)) + } + } +} + +impl Drop for DshowCapture { + fn drop(&mut self) { + unsafe { + hwcodec_dshow_capture_close(self.ctx); + } + self.ctx = std::ptr::null_mut(); + } +} diff --git a/libs/hwcodec/src/ffmpeg_ram/encode.rs b/libs/hwcodec/src/ffmpeg_ram/encode.rs index 49140b78..64d3d279 100644 --- a/libs/hwcodec/src/ffmpeg_ram/encode.rs +++ b/libs/hwcodec/src/ffmpeg_ram/encode.rs @@ -257,7 +257,13 @@ struct ProbePolicy { impl ProbePolicy { fn for_codec(codec_name: &str) -> Self { - if codec_name.contains("v4l2m2m") { + if codec_name.contains("amf") { + Self { + max_attempts: 5, + request_keyframe: true, + accept_any_output: true, + } + } else if codec_name.contains("v4l2m2m") { Self { max_attempts: 5, request_keyframe: true, diff --git a/libs/hwcodec/src/lib.rs b/libs/hwcodec/src/lib.rs index 9645c1f1..9a57fa75 100644 --- a/libs/hwcodec/src/lib.rs +++ b/libs/hwcodec/src/lib.rs @@ -1,3 +1,5 @@ +#[cfg(windows)] +pub mod capture; pub mod common; pub mod ffmpeg; #[cfg(any(target_arch = "aarch64", target_arch = "arm", feature = "rkmpp"))] diff --git a/res/vcpkg/libyuv/build.rs b/res/vcpkg/libyuv/build.rs index 9cb312df..a9d5fb31 100644 --- a/res/vcpkg/libyuv/build.rs +++ b/res/vcpkg/libyuv/build.rs @@ -154,11 +154,13 @@ fn link_vcpkg(mut path: PathBuf) -> bool { if use_static && static_lib.exists() { // Static linking (for deb packaging) println!("cargo:rustc-link-lib=static=yuv"); + #[cfg(target_os = "linux")] println!("cargo:rustc-link-lib=stdc++"); println!("cargo:info=Using libyuv from vcpkg (static linking)"); } else { // Dynamic linking (default for development) println!("cargo:rustc-link-lib=yuv"); + #[cfg(target_os = "linux")] println!("cargo:rustc-link-lib=stdc++"); println!("cargo:info=Using libyuv from vcpkg (dynamic linking)"); } diff --git a/src/atx/controller.rs b/src/atx/controller.rs index c4b022f4..74fb31c7 100644 --- a/src/atx/controller.rs +++ b/src/atx/controller.rs @@ -11,20 +11,14 @@ use super::led::LedSensor; use super::types::{AtxAction, AtxKeyConfig, AtxLedConfig, AtxState, PowerStatus}; use crate::error::{AppError, Result}; -/// ATX power control configuration #[derive(Debug, Clone, Default)] pub struct AtxControllerConfig { - /// Whether ATX is enabled pub enabled: bool, - /// Power button configuration (used for both short and long press) pub power: AtxKeyConfig, - /// Reset button configuration pub reset: AtxKeyConfig, - /// LED sensing configuration pub led: AtxLedConfig, } -/// Internal state holding all ATX components /// Grouped together to reduce lock acquisitions struct AtxInner { config: AtxControllerConfig, @@ -33,12 +27,9 @@ struct AtxInner { led_sensor: Option, } -/// ATX Controller -/// /// Manages ATX power control through independent executors for each action. /// Supports hot-reload of configuration. pub struct AtxController { - /// Single lock for all internal state to reduce lock contention inner: RwLock, } @@ -53,6 +44,24 @@ impl AtxController { && power.baud_rate == reset.baud_rate } + async fn init_key_executor( + warn_label: &str, + info_label: &str, + config: AtxKeyConfig, + mut executor: AtxKeyExecutor, + ) -> Option { + if let Err(e) = executor.init().await { + warn!("Failed to initialize {} executor: {}", warn_label, e); + return None; + } + + info!( + "{} executor initialized: {:?} on {} pin {}", + info_label, config.driver, config.device, config.pin + ); + Some(executor) + } + async fn init_components(inner: &mut AtxInner) { if Self::should_share_serial_device(&inner.config.power, &inner.config.reset) { match AtxKeyExecutor::open_shared_serial( @@ -60,36 +69,28 @@ impl AtxController { inner.config.power.baud_rate, ) { Ok(shared_serial) => { - let mut power_executor = AtxKeyExecutor::new_with_shared_serial( - inner.config.power.clone(), - shared_serial.clone(), - ); - if let Err(e) = power_executor.init().await { - warn!("Failed to initialize power executor: {}", e); - } else { - info!( - "Power executor initialized: {:?} on {} pin {}", - inner.config.power.driver, - inner.config.power.device, - inner.config.power.pin + for (slot, warn_label, info_label, config, serial) in [ + ( + &mut inner.power_executor, + "power", + "Power", + inner.config.power.clone(), + shared_serial.clone(), + ), + ( + &mut inner.reset_executor, + "reset", + "Reset", + inner.config.reset.clone(), + shared_serial, + ), + ] { + let executor = AtxKeyExecutor::new_with_shared_serial( + config.clone(), + serial, ); - inner.power_executor = Some(power_executor); - } - - let mut reset_executor = AtxKeyExecutor::new_with_shared_serial( - inner.config.reset.clone(), - shared_serial, - ); - if let Err(e) = reset_executor.init().await { - warn!("Failed to initialize reset executor: {}", e); - } else { - info!( - "Reset executor initialized: {:?} on {} pin {}", - inner.config.reset.driver, - inner.config.reset.device, - inner.config.reset.pin - ); - inner.reset_executor = Some(reset_executor); + *slot = Self::init_key_executor(warn_label, info_label, config, executor) + .await; } } Err(e) => { @@ -100,40 +101,18 @@ impl AtxController { } } } else { - // Initialize power executor - if inner.config.power.is_configured() { - let mut executor = AtxKeyExecutor::new(inner.config.power.clone()); - if let Err(e) = executor.init().await { - warn!("Failed to initialize power executor: {}", e); - } else { - info!( - "Power executor initialized: {:?} on {} pin {}", - inner.config.power.driver, - inner.config.power.device, - inner.config.power.pin - ); - inner.power_executor = Some(executor); - } - } - - // Initialize reset executor - if inner.config.reset.is_configured() { - let mut executor = AtxKeyExecutor::new(inner.config.reset.clone()); - if let Err(e) = executor.init().await { - warn!("Failed to initialize reset executor: {}", e); - } else { - info!( - "Reset executor initialized: {:?} on {} pin {}", - inner.config.reset.driver, - inner.config.reset.device, - inner.config.reset.pin - ); - inner.reset_executor = Some(executor); + for (slot, warn_label, info_label, config) in [ + (&mut inner.power_executor, "power", "Power", inner.config.power.clone()), + (&mut inner.reset_executor, "reset", "Reset", inner.config.reset.clone()), + ] { + if config.is_configured() { + let executor = AtxKeyExecutor::new(config.clone()); + *slot = Self::init_key_executor(warn_label, info_label, config, executor) + .await; } } } - // Initialize LED sensor if inner.config.led.is_configured() { let mut sensor = LedSensor::new(inner.config.led.clone()); if let Err(e) = sensor.init().await { @@ -149,19 +128,17 @@ impl AtxController { } async fn shutdown_components(inner: &mut AtxInner) { - if let Some(executor) = inner.power_executor.as_mut() { - if let Err(e) = executor.shutdown().await { - warn!("Failed to shutdown power executor: {}", e); + for (slot, label) in [ + (&mut inner.power_executor, "power"), + (&mut inner.reset_executor, "reset"), + ] { + if let Some(executor) = slot.as_mut() { + if let Err(e) = executor.shutdown().await { + warn!("Failed to shutdown {} executor: {}", label, e); + } } + *slot = None; } - inner.power_executor = None; - - if let Some(executor) = inner.reset_executor.as_mut() { - if let Err(e) = executor.shutdown().await { - warn!("Failed to shutdown reset executor: {}", e); - } - } - inner.reset_executor = None; if let Some(sensor) = inner.led_sensor.as_mut() { if let Err(e) = sensor.shutdown().await { @@ -171,7 +148,20 @@ impl AtxController { inner.led_sensor = None; } - /// Create a new ATX controller with the specified configuration + async fn read_power_status(sensor: Option<&LedSensor>) -> PowerStatus { + let Some(sensor) = sensor else { + return PowerStatus::Unknown; + }; + + match sensor.read().await { + Ok(status) => status, + Err(e) => { + debug!("Failed to read ATX LED sensor: {}", e); + PowerStatus::Unknown + } + } + } + pub fn new(config: AtxControllerConfig) -> Self { Self { inner: RwLock::new(AtxInner { @@ -183,12 +173,10 @@ impl AtxController { } } - /// Create a disabled ATX controller pub fn disabled() -> Self { Self::new(AtxControllerConfig::default()) } - /// Initialize the ATX controller and its executors pub async fn init(&self) -> Result<()> { let mut inner = self.inner.write().await; @@ -204,7 +192,6 @@ impl AtxController { Ok(()) } - /// Reload ATX controller configuration pub async fn reload(&self, config: AtxControllerConfig) -> Result<()> { let mut inner = self.inner.write().await; @@ -225,7 +212,6 @@ impl AtxController { Ok(()) } - /// Shutdown ATX controller and release all resources pub async fn shutdown(&self) -> Result<()> { let mut inner = self.inner.write().await; Self::shutdown_components(&mut inner).await; @@ -233,86 +219,48 @@ impl AtxController { Ok(()) } - /// Trigger a power action (short/long/reset) pub async fn trigger_power_action(&self, action: AtxAction) -> Result<()> { let inner = self.inner.read().await; - match action { - AtxAction::Short | AtxAction::Long => { - if let Some(executor) = &inner.power_executor { - let duration = match action { - AtxAction::Short => timing::SHORT_PRESS, - AtxAction::Long => timing::LONG_PRESS, - _ => unreachable!(), - }; - executor.pulse(duration).await?; - } else { - return Err(AppError::Config( - "Power button not configured for ATX controller".to_string(), - )); - } - } - AtxAction::Reset => { - if let Some(executor) = &inner.reset_executor { - executor.pulse(timing::RESET_PRESS).await?; - } else { - return Err(AppError::Config( - "Reset button not configured for ATX controller".to_string(), - )); - } - } - } + let (executor, duration) = match action { + AtxAction::Short => (inner.power_executor.as_ref(), timing::SHORT_PRESS), + AtxAction::Long => (inner.power_executor.as_ref(), timing::LONG_PRESS), + AtxAction::Reset => (inner.reset_executor.as_ref(), timing::RESET_PRESS), + }; + let Some(executor) = executor else { + return Err(AppError::Config(match action { + AtxAction::Reset => "Reset button not configured for ATX controller", + _ => "Power button not configured for ATX controller", + } + .to_string())); + }; + + executor.pulse(duration).await?; Ok(()) } - /// Trigger a short power button press pub async fn power_short(&self) -> Result<()> { self.trigger_power_action(AtxAction::Short).await } - /// Trigger a long power button press pub async fn power_long(&self) -> Result<()> { self.trigger_power_action(AtxAction::Long).await } - /// Trigger a reset button press pub async fn reset(&self) -> Result<()> { self.trigger_power_action(AtxAction::Reset).await } - /// Get the current power status using the LED sensor (if configured) pub async fn power_status(&self) -> PowerStatus { let inner = self.inner.read().await; - - if let Some(sensor) = &inner.led_sensor { - match sensor.read().await { - Ok(status) => status, - Err(e) => { - debug!("Failed to read ATX LED sensor: {}", e); - PowerStatus::Unknown - } - } - } else { - PowerStatus::Unknown - } + Self::read_power_status(inner.led_sensor.as_ref()).await } - /// Get a snapshot of the ATX state for API responses pub async fn state(&self) -> AtxState { let inner = self.inner.read().await; - let power_status = if let Some(sensor) = &inner.led_sensor { - match sensor.read().await { - Ok(status) => status, - Err(e) => { - debug!("Failed to read ATX LED sensor: {}", e); - PowerStatus::Unknown - } - } - } else { - PowerStatus::Unknown - }; + let power_status = Self::read_power_status(inner.led_sensor.as_ref()).await; AtxState { available: inner.config.enabled, diff --git a/src/atx/disabled_key.rs b/src/atx/disabled_key.rs new file mode 100644 index 00000000..ba1d6acd --- /dev/null +++ b/src/atx/disabled_key.rs @@ -0,0 +1,34 @@ +use async_trait::async_trait; +use std::time::Duration; + +use super::traits::AtxKeyBackend; +use crate::error::{AppError, Result}; + +pub struct DisabledAtxKeyBackend { + reason: &'static str, +} + +impl DisabledAtxKeyBackend { + pub fn new(reason: &'static str) -> Self { + Self { reason } + } +} + +#[async_trait] +impl AtxKeyBackend for DisabledAtxKeyBackend { + async fn init(&mut self) -> Result<()> { + Err(AppError::Internal(self.reason.to_string())) + } + + async fn pulse(&self, _duration: Duration) -> Result<()> { + Err(AppError::Internal(self.reason.to_string())) + } + + async fn shutdown(&mut self) -> Result<()> { + Ok(()) + } + + fn is_initialized(&self) -> bool { + false + } +} diff --git a/src/atx/disabled_led.rs b/src/atx/disabled_led.rs new file mode 100644 index 00000000..52ef9034 --- /dev/null +++ b/src/atx/disabled_led.rs @@ -0,0 +1,34 @@ +#![allow(dead_code)] + +use super::types::{AtxLedConfig, PowerStatus}; +use crate::error::Result; + +pub struct LedSensor { + config: AtxLedConfig, +} + +impl LedSensor { + pub fn new(config: AtxLedConfig) -> Self { + Self { config } + } + + pub fn is_configured(&self) -> bool { + self.config.is_configured() + } + + pub fn is_initialized(&self) -> bool { + false + } + + pub async fn init(&mut self) -> Result<()> { + Ok(()) + } + + pub async fn read(&self) -> Result { + Ok(PowerStatus::Unknown) + } + + pub async fn shutdown(&mut self) -> Result<()> { + Ok(()) + } +} diff --git a/src/atx/executor.rs b/src/atx/executor.rs index 5cd276d7..788dca6d 100644 --- a/src/atx/executor.rs +++ b/src/atx/executor.rs @@ -1,497 +1,150 @@ -//! ATX Key Executor -//! -//! Lightweight executor for a single ATX key operation. -//! Each executor handles one button (power or reset) with its own hardware binding. +//! ATX key executor backend selector. -use gpio_cdev::{Chip, LineHandle, LineRequestFlags}; -use serialport::SerialPort; -use std::fs::{File, OpenOptions}; -use std::io::Write; -use std::os::fd::AsRawFd; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Mutex}; use std::time::Duration; -use tokio::time::sleep; -use tracing::{debug, info}; +use tracing::debug; -use super::types::{ActiveLevel, AtxDriverType, AtxKeyConfig}; +use super::serial_relay::SerialRelayBackend; +use super::traits::{AtxKeyBackend, AtxKeyBackendContext, SharedSerialHandle}; +use super::types::{AtxDriverType, AtxKeyConfig}; use crate::error::{AppError, Result}; -pub type SharedSerialHandle = Arc>>; - -const USB_RELAY_MAX_CHANNEL: u8 = 8; -const USB_RELAY_REPORT_LEN: usize = 9; -const HIDIOCSFEATURE_9: libc::c_ulong = 0xC009_4806; // _IOC(_IOC_READ|_IOC_WRITE, 'H', 0x06, 9) - -/// Timing constants for ATX operations pub mod timing { use std::time::Duration; - /// Short press duration (power on/graceful shutdown) pub const SHORT_PRESS: Duration = Duration::from_millis(500); - - /// Long press duration (force power off) pub const LONG_PRESS: Duration = Duration::from_millis(5000); - - /// Reset press duration pub const RESET_PRESS: Duration = Duration::from_millis(500); } -/// Executor for a single ATX key operation -/// -/// Each executor manages one hardware button (power or reset). -/// It handles both GPIO and USB relay backends. pub struct AtxKeyExecutor { config: AtxKeyConfig, - gpio_handle: Mutex>, - /// Cached USB relay file handle to avoid repeated open/close syscalls - usb_relay_handle: Mutex>, - /// Cached Serial port handle (can be shared across power/reset executors) - serial_handle: Mutex>, - initialized: AtomicBool, + backend: Option>, } impl AtxKeyExecutor { - /// Create a new executor with the given configuration pub fn new(config: AtxKeyConfig) -> Self { - Self { - config, - gpio_handle: Mutex::new(None), - usb_relay_handle: Mutex::new(None), - serial_handle: Mutex::new(None), - initialized: AtomicBool::new(false), - } + Self::with_context(config, AtxKeyBackendContext::Standalone) } - /// Create a new executor with a pre-opened shared serial handle. pub fn new_with_shared_serial(config: AtxKeyConfig, serial_handle: SharedSerialHandle) -> Self { - Self { - config, - gpio_handle: Mutex::new(None), - usb_relay_handle: Mutex::new(None), - serial_handle: Mutex::new(Some(serial_handle)), - initialized: AtomicBool::new(false), - } + Self::with_context(config, AtxKeyBackendContext::SharedSerial(serial_handle)) } - /// Open a serial relay device and wrap it for shared use. pub fn open_shared_serial(device: &str, baud_rate: u32) -> Result { - let port = serialport::new(device, baud_rate) - .timeout(Duration::from_millis(100)) - .open() - .map_err(|e| AppError::Internal(format!("Serial port open failed: {}", e)))?; - Ok(Arc::new(Mutex::new(port))) + SerialRelayBackend::open_shared_serial(device, baud_rate) + } + + fn with_context(config: AtxKeyConfig, context: AtxKeyBackendContext) -> Self { + let backend = build_backend(&config, context); + Self { config, backend } } - /// Check if this executor is configured pub fn is_configured(&self) -> bool { self.config.is_configured() } - /// Check if this executor is initialized - pub fn is_initialized(&self) -> bool { - self.initialized.load(Ordering::Relaxed) - } - - /// Initialize the executor pub async fn init(&mut self) -> Result<()> { if !self.config.is_configured() { debug!("ATX key executor not configured, skipping init"); return Ok(()); } - self.validate_runtime_config()?; - - match self.config.driver { - AtxDriverType::Gpio => self.init_gpio().await?, - AtxDriverType::UsbRelay => self.init_usb_relay().await?, - AtxDriverType::Serial => self.init_serial().await?, - AtxDriverType::None => {} - } - - self.initialized.store(true, Ordering::Relaxed); - Ok(()) - } - - fn validate_runtime_config(&self) -> Result<()> { - match self.config.driver { - AtxDriverType::Serial => { - if self.config.pin == 0 { - return Err(AppError::Config( - "Serial ATX channel must be 1-based (>= 1)".to_string(), - )); - } - if self.config.pin > u8::MAX as u32 { - return Err(AppError::Config(format!( - "Serial ATX channel must be <= {}", - u8::MAX - ))); - } - if self.config.baud_rate == 0 { - return Err(AppError::Config( - "Serial ATX baud_rate must be greater than 0".to_string(), - )); - } - } - AtxDriverType::UsbRelay => { - if self.config.pin == 0 { - return Err(AppError::Config( - "USB relay channel must be 1-based (>= 1)".to_string(), - )); - } - if self.config.pin > u8::MAX as u32 { - return Err(AppError::Config(format!( - "USB relay channel must be <= {}", - u8::MAX - ))); - } - if self.config.pin > USB_RELAY_MAX_CHANNEL as u32 { - return Err(AppError::Config(format!( - "USB HID relay channel must be <= {}", - USB_RELAY_MAX_CHANNEL - ))); - } - } - AtxDriverType::Gpio | AtxDriverType::None => {} - } - Ok(()) - } - - /// Initialize GPIO backend - async fn init_gpio(&mut self) -> Result<()> { - info!( - "Initializing GPIO ATX executor on {} pin {}", - self.config.device, self.config.pin - ); - - let mut chip = Chip::new(&self.config.device) - .map_err(|e| AppError::Internal(format!("GPIO chip open failed: {}", e)))?; - - let line = chip.get_line(self.config.pin).map_err(|e| { - AppError::Internal(format!("GPIO line {} failed: {}", self.config.pin, e)) + let backend = self.backend.as_mut().ok_or_else(|| { + AppError::Internal(format!( + "ATX backend {:?} is unsupported on this platform", + self.config.driver + )) })?; - // Initial value depends on active level (start in inactive state) - let initial_value = match self.config.active_level { - ActiveLevel::High => 0, // Inactive = low - ActiveLevel::Low => 1, // Inactive = high - }; - - let handle = line - .request(LineRequestFlags::OUTPUT, initial_value, "one-kvm-atx") - .map_err(|e| AppError::Internal(format!("GPIO request failed: {}", e)))?; - - *self.gpio_handle.lock().unwrap() = Some(handle); - debug!("GPIO pin {} configured successfully", self.config.pin); - Ok(()) + backend.init().await } - /// Initialize USB relay backend - async fn init_usb_relay(&self) -> Result<()> { - info!( - "Initializing USB relay ATX executor on {} channel {}", - self.config.device, self.config.pin - ); - - // Open and cache the device handle - let device = OpenOptions::new() - .read(true) - .write(true) - .open(&self.config.device) - .map_err(|e| AppError::Internal(format!("USB relay device open failed: {}", e)))?; - - *self.usb_relay_handle.lock().unwrap() = Some(device); - - // Ensure relay is off initially - self.send_usb_relay_command(false)?; - - debug!( - "USB relay channel {} configured successfully", - self.config.pin - ); - Ok(()) - } - - /// Initialize Serial relay backend - async fn init_serial(&self) -> Result<()> { - info!( - "Initializing Serial relay ATX executor on {} channel {}", - self.config.device, self.config.pin - ); - - let existing_handle = self.serial_handle.lock().unwrap().as_ref().cloned(); - if existing_handle.is_none() { - let shared = Self::open_shared_serial(&self.config.device, self.config.baud_rate)?; - *self.serial_handle.lock().unwrap() = Some(shared); - } - - // Ensure relay is off initially - self.send_serial_relay_command(false)?; - - debug!( - "Serial relay channel {} configured successfully", - self.config.pin - ); - Ok(()) - } - - /// Pulse the button for the specified duration pub async fn pulse(&self, duration: Duration) -> Result<()> { if !self.is_configured() { return Err(AppError::Internal("ATX key not configured".to_string())); } - if !self.is_initialized() { + let backend = self.backend.as_ref().ok_or_else(|| { + AppError::Internal(format!( + "ATX backend {:?} is unsupported on this platform", + self.config.driver + )) + })?; + + if !backend.is_initialized() { return Err(AppError::Internal("ATX key not initialized".to_string())); } - match self.config.driver { - AtxDriverType::Gpio => self.pulse_gpio(duration).await, - AtxDriverType::UsbRelay => self.pulse_usb_relay(duration).await, - AtxDriverType::Serial => self.pulse_serial(duration).await, - AtxDriverType::None => Ok(()), - } + backend.pulse(duration).await } - /// Pulse GPIO pin - async fn pulse_gpio(&self, duration: Duration) -> Result<()> { - let (active, inactive) = match self.config.active_level { - ActiveLevel::High => (1u8, 0u8), - ActiveLevel::Low => (0u8, 1u8), - }; - - // Set to active state - { - let guard = self.gpio_handle.lock().unwrap(); - let handle = guard - .as_ref() - .ok_or_else(|| AppError::Internal("GPIO not initialized".to_string()))?; - handle - .set_value(active) - .map_err(|e| AppError::Internal(format!("GPIO set failed: {}", e)))?; - } - - // Wait for duration (no lock held) - sleep(duration).await; - - // Set to inactive state - { - let guard = self.gpio_handle.lock().unwrap(); - if let Some(handle) = guard.as_ref() { - handle.set_value(inactive).ok(); - } - } - - Ok(()) - } - - /// Pulse USB relay - async fn pulse_usb_relay(&self, duration: Duration) -> Result<()> { - // Turn relay on - self.send_usb_relay_command(true)?; - - // Wait for duration - sleep(duration).await; - - // Turn relay off - self.send_usb_relay_command(false)?; - - Ok(()) - } - - /// Send USB relay command using cached handle - fn send_usb_relay_command(&self, on: bool) -> Result<()> { - let channel = u8::try_from(self.config.pin).map_err(|_| { - AppError::Config(format!( - "USB relay channel {} exceeds max {}", - self.config.pin, - u8::MAX - )) - })?; - if channel == 0 { - return Err(AppError::Config( - "USB relay channel must be 1-based (>= 1)".to_string(), - )); - } - if channel > USB_RELAY_MAX_CHANNEL { - return Err(AppError::Config(format!( - "USB HID relay channel must be <= {}", - USB_RELAY_MAX_CHANNEL - ))); - } - - let cmd = Self::build_usb_relay_command(channel, on); - - let mut guard = self.usb_relay_handle.lock().unwrap(); - let device = guard - .as_mut() - .ok_or_else(|| AppError::Internal("USB relay not initialized".to_string()))?; - - if let Err(feature_err) = Self::send_usb_relay_feature_report(device, &cmd) { - debug!( - "USB relay feature report failed ({}), falling back to hidraw write", - feature_err - ); - device.write_all(&cmd).map_err(|write_err| { - AppError::Internal(format!( - "USB relay feature report failed: {}; raw write failed: {}", - feature_err, write_err - )) - })?; - device - .flush() - .map_err(|e| AppError::Internal(format!("USB relay flush failed: {}", e)))?; - } - - Ok(()) - } - - fn build_usb_relay_command(channel: u8, on: bool) -> [u8; USB_RELAY_REPORT_LEN] { - let mut cmd = [0x00; USB_RELAY_REPORT_LEN]; - cmd[1] = if on { 0xFF } else { 0xFD }; - cmd[2] = channel; - cmd - } - - fn send_usb_relay_feature_report( - device: &File, - report: &[u8; USB_RELAY_REPORT_LEN], - ) -> std::io::Result<()> { - // Linux hidraw feature reports include the report ID as the first byte. - let rc = unsafe { libc::ioctl(device.as_raw_fd(), HIDIOCSFEATURE_9, report.as_ptr()) }; - if rc < 0 { - Err(std::io::Error::last_os_error()) - } else { - Ok(()) - } - } - - /// Pulse Serial relay - async fn pulse_serial(&self, duration: Duration) -> Result<()> { - info!( - "Pulse serial relay on {} pin {}", - self.config.device, self.config.pin - ); - // Turn relay on - self.send_serial_relay_command(true)?; - - // Wait for duration - sleep(duration).await; - - // Turn relay off - self.send_serial_relay_command(false)?; - - Ok(()) - } - - /// Send Serial relay command using cached handle - fn send_serial_relay_command(&self, on: bool) -> Result<()> { - let channel = u8::try_from(self.config.pin).map_err(|_| { - AppError::Config(format!( - "Serial relay channel {} exceeds max {}", - self.config.pin, - u8::MAX - )) - })?; - if channel == 0 { - return Err(AppError::Config( - "Serial relay channel must be 1-based (>= 1)".to_string(), - )); - } - - // LCUS-Type Protocol - // Frame: [StopByte(A0), Channel, State, Checksum] - // Checksum = A0 + channel + state - let state = if on { 1 } else { 0 }; - let checksum = 0xA0u8.wrapping_add(channel).wrapping_add(state); - - // Example for Channel 1: - // ON: A0 01 01 A2 - // OFF: A0 01 00 A1 - let cmd = [0xA0, channel, state, checksum]; - - let serial_handle = self - .serial_handle - .lock() - .unwrap() - .as_ref() - .cloned() - .ok_or_else(|| AppError::Internal("Serial relay not initialized".to_string()))?; - let mut port = serial_handle.lock().unwrap(); - - port.write_all(&cmd) - .map_err(|e| AppError::Internal(format!("Serial relay write failed: {}", e)))?; - port.flush() - .map_err(|e| AppError::Internal(format!("Serial relay flush failed: {}", e)))?; - - Ok(()) - } - - /// Shutdown the executor pub async fn shutdown(&mut self) -> Result<()> { - if !self.is_initialized() { - return Ok(()); + if let Some(backend) = self.backend.as_mut() { + backend.shutdown().await?; } - - match self.config.driver { - AtxDriverType::Gpio => { - // Release GPIO handle - *self.gpio_handle.lock().unwrap() = None; - } - AtxDriverType::UsbRelay => { - // Ensure relay is off before closing handle - let _ = self.send_usb_relay_command(false); - // Release USB relay handle - *self.usb_relay_handle.lock().unwrap() = None; - } - AtxDriverType::Serial => { - // Ensure relay is off before closing handle - let _ = self.send_serial_relay_command(false); - // Release Serial relay handle - *self.serial_handle.lock().unwrap() = None; - } - AtxDriverType::None => {} - } - - self.initialized.store(false, Ordering::Relaxed); - debug!("ATX key executor shutdown complete"); Ok(()) } } -impl Drop for AtxKeyExecutor { - fn drop(&mut self) { - // Ensure GPIO lines are released - *self.gpio_handle.lock().unwrap() = None; - - // Ensure USB relay is off and handle released - if self.config.driver == AtxDriverType::UsbRelay && self.is_initialized() { - let _ = self.send_usb_relay_command(false); - } - *self.usb_relay_handle.lock().unwrap() = None; - - // Ensure Serial relay is off and handle released - if self.config.driver == AtxDriverType::Serial && self.is_initialized() { - let _ = self.send_serial_relay_command(false); - } - *self.serial_handle.lock().unwrap() = None; +fn build_backend( + config: &AtxKeyConfig, + context: AtxKeyBackendContext, +) -> Option> { + match config.driver { + AtxDriverType::Serial => Some(match context { + AtxKeyBackendContext::Standalone => Box::new(SerialRelayBackend::new(config.clone())), + AtxKeyBackendContext::SharedSerial(handle) => Box::new( + SerialRelayBackend::new_with_shared_serial(config.clone(), handle), + ), + }), + AtxDriverType::Gpio => build_gpio_backend(config), + AtxDriverType::UsbRelay => build_hidraw_backend(config), + AtxDriverType::None => None, } } +#[cfg(unix)] +fn build_gpio_backend(config: &AtxKeyConfig) -> Option> { + Some(Box::new(super::gpio_linux::GpioLinuxBackend::new( + config.clone(), + ))) +} + +#[cfg(not(unix))] +fn build_gpio_backend(_config: &AtxKeyConfig) -> Option> { + Some(Box::new(super::disabled_key::DisabledAtxKeyBackend::new( + "GPIO ATX backend is only available on Linux", + ))) +} + +#[cfg(unix)] +fn build_hidraw_backend(config: &AtxKeyConfig) -> Option> { + Some(Box::new(super::hidraw_linux::HidrawLinuxRelayBackend::new( + config.clone(), + ))) +} + +#[cfg(not(unix))] +fn build_hidraw_backend(_config: &AtxKeyConfig) -> Option> { + Some(Box::new(super::disabled_key::DisabledAtxKeyBackend::new( + "USB hidraw relay backend is only available on Linux", + ))) +} + #[cfg(test)] mod tests { use super::*; + use crate::atx::ActiveLevel; #[test] - fn test_executor_creation() { + fn executor_creation() { let config = AtxKeyConfig::default(); let executor = AtxKeyExecutor::new(config); assert!(!executor.is_configured()); - assert!(!executor.is_initialized()); } #[test] - fn test_executor_with_gpio_config() { + fn executor_with_gpio_config() { let config = AtxKeyConfig { driver: AtxDriverType::Gpio, device: "/dev/gpiochip0".to_string(), @@ -501,16 +154,15 @@ mod tests { }; let executor = AtxKeyExecutor::new(config); assert!(executor.is_configured()); - assert!(!executor.is_initialized()); } #[test] - fn test_executor_with_usb_relay_config() { + fn executor_with_usb_relay_config() { let config = AtxKeyConfig { driver: AtxDriverType::UsbRelay, device: "/dev/hidraw0".to_string(), pin: 1, - active_level: ActiveLevel::High, // Ignored for USB relay + active_level: ActiveLevel::High, baud_rate: 9600, }; let executor = AtxKeyExecutor::new(config); @@ -518,12 +170,12 @@ mod tests { } #[test] - fn test_executor_with_serial_config() { + fn executor_with_serial_config() { let config = AtxKeyConfig { driver: AtxDriverType::Serial, device: "/dev/ttyUSB0".to_string(), pin: 1, - active_level: ActiveLevel::High, // Ignored + active_level: ActiveLevel::High, baud_rate: 9600, }; let executor = AtxKeyExecutor::new(config); @@ -531,91 +183,9 @@ mod tests { } #[test] - fn test_timing_constants() { + fn timing_constants() { assert_eq!(timing::SHORT_PRESS.as_millis(), 500); assert_eq!(timing::LONG_PRESS.as_millis(), 5000); assert_eq!(timing::RESET_PRESS.as_millis(), 500); } - - #[test] - fn test_usb_relay_command_format() { - assert_eq!( - AtxKeyExecutor::build_usb_relay_command(1, true), - [0x00, 0xFF, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] - ); - assert_eq!( - AtxKeyExecutor::build_usb_relay_command(1, false), - [0x00, 0xFD, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] - ); - } - - #[tokio::test] - async fn test_executor_init_rejects_serial_channel_zero() { - let config = AtxKeyConfig { - driver: AtxDriverType::Serial, - device: "/dev/ttyUSB0".to_string(), - pin: 0, - active_level: ActiveLevel::High, - baud_rate: 9600, - }; - let mut executor = AtxKeyExecutor::new(config); - let err = executor.init().await.unwrap_err(); - assert!(matches!(err, AppError::Config(_))); - } - - #[tokio::test] - async fn test_executor_init_rejects_usb_relay_channel_zero() { - let config = AtxKeyConfig { - driver: AtxDriverType::UsbRelay, - device: "/dev/hidraw0".to_string(), - pin: 0, - active_level: ActiveLevel::High, - baud_rate: 9600, - }; - let mut executor = AtxKeyExecutor::new(config); - let err = executor.init().await.unwrap_err(); - assert!(matches!(err, AppError::Config(_))); - } - - #[tokio::test] - async fn test_executor_init_rejects_usb_relay_channel_overflow() { - let config = AtxKeyConfig { - driver: AtxDriverType::UsbRelay, - device: "/dev/hidraw0".to_string(), - pin: USB_RELAY_MAX_CHANNEL as u32 + 1, - active_level: ActiveLevel::High, - baud_rate: 9600, - }; - let mut executor = AtxKeyExecutor::new(config); - let err = executor.init().await.unwrap_err(); - assert!(matches!(err, AppError::Config(_))); - } - - #[tokio::test] - async fn test_executor_init_rejects_serial_channel_overflow() { - let config = AtxKeyConfig { - driver: AtxDriverType::Serial, - device: "/dev/ttyUSB0".to_string(), - pin: 256, - active_level: ActiveLevel::High, - baud_rate: 9600, - }; - let mut executor = AtxKeyExecutor::new(config); - let err = executor.init().await.unwrap_err(); - assert!(matches!(err, AppError::Config(_))); - } - - #[tokio::test] - async fn test_executor_init_rejects_zero_serial_baud_rate() { - let config = AtxKeyConfig { - driver: AtxDriverType::Serial, - device: "/dev/ttyUSB0".to_string(), - pin: 1, - active_level: ActiveLevel::High, - baud_rate: 0, - }; - let mut executor = AtxKeyExecutor::new(config); - let err = executor.init().await.unwrap_err(); - assert!(matches!(err, AppError::Config(_))); - } } diff --git a/src/atx/gpio_linux.rs b/src/atx/gpio_linux.rs new file mode 100644 index 00000000..869ab778 --- /dev/null +++ b/src/atx/gpio_linux.rs @@ -0,0 +1,106 @@ +use async_trait::async_trait; +use gpio_cdev::{Chip, LineHandle, LineRequestFlags}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Mutex; +use std::time::Duration; +use tokio::time::sleep; +use tracing::{debug, info}; + +use super::traits::AtxKeyBackend; +use super::types::{ActiveLevel, AtxKeyConfig}; +use crate::error::{AppError, Result}; + +pub struct GpioLinuxBackend { + config: AtxKeyConfig, + handle: Mutex>, + initialized: AtomicBool, +} + +impl GpioLinuxBackend { + pub fn new(config: AtxKeyConfig) -> Self { + Self { + config, + handle: Mutex::new(None), + initialized: AtomicBool::new(false), + } + } +} + +#[async_trait] +impl AtxKeyBackend for GpioLinuxBackend { + async fn init(&mut self) -> Result<()> { + info!( + "Initializing GPIO ATX backend on {} pin {}", + self.config.device, self.config.pin + ); + + let mut chip = Chip::new(&self.config.device) + .map_err(|e| AppError::Internal(format!("GPIO chip open failed: {}", e)))?; + + let line = chip.get_line(self.config.pin).map_err(|e| { + AppError::Internal(format!("GPIO line {} failed: {}", self.config.pin, e)) + })?; + + let initial_value = match self.config.active_level { + ActiveLevel::High => 0, + ActiveLevel::Low => 1, + }; + + let handle = line + .request(LineRequestFlags::OUTPUT, initial_value, "one-kvm-atx") + .map_err(|e| AppError::Internal(format!("GPIO request failed: {}", e)))?; + + *self.handle.lock().unwrap() = Some(handle); + self.initialized.store(true, Ordering::Relaxed); + debug!("GPIO pin {} configured successfully", self.config.pin); + Ok(()) + } + + async fn pulse(&self, duration: Duration) -> Result<()> { + if !self.is_initialized() { + return Err(AppError::Internal("GPIO not initialized".to_string())); + } + + let (active, inactive) = match self.config.active_level { + ActiveLevel::High => (1u8, 0u8), + ActiveLevel::Low => (0u8, 1u8), + }; + + { + let guard = self.handle.lock().unwrap(); + let handle = guard + .as_ref() + .ok_or_else(|| AppError::Internal("GPIO not initialized".to_string()))?; + handle + .set_value(active) + .map_err(|e| AppError::Internal(format!("GPIO set failed: {}", e)))?; + } + + sleep(duration).await; + + { + let guard = self.handle.lock().unwrap(); + if let Some(handle) = guard.as_ref() { + handle.set_value(inactive).ok(); + } + } + + Ok(()) + } + + async fn shutdown(&mut self) -> Result<()> { + *self.handle.lock().unwrap() = None; + self.initialized.store(false, Ordering::Relaxed); + Ok(()) + } + + fn is_initialized(&self) -> bool { + self.initialized.load(Ordering::Relaxed) + } +} + +impl Drop for GpioLinuxBackend { + fn drop(&mut self) { + *self.handle.lock().unwrap() = None; + } +} diff --git a/src/atx/hidraw_linux.rs b/src/atx/hidraw_linux.rs new file mode 100644 index 00000000..829a93dd --- /dev/null +++ b/src/atx/hidraw_linux.rs @@ -0,0 +1,190 @@ +use async_trait::async_trait; +use std::fs::{File, OpenOptions}; +use std::io::Write; +use std::os::fd::AsRawFd; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Mutex; +use std::time::Duration; +use tokio::time::sleep; +use tracing::{debug, info}; + +use super::traits::AtxKeyBackend; +use super::types::AtxKeyConfig; +use crate::error::{AppError, Result}; + +const USB_RELAY_MAX_CHANNEL: u8 = 8; +const USB_RELAY_REPORT_LEN: usize = 9; +const HIDIOCSFEATURE_9: libc::c_ulong = 0xC009_4806; + +pub struct HidrawLinuxRelayBackend { + config: AtxKeyConfig, + handle: Mutex>, + initialized: AtomicBool, +} + +impl HidrawLinuxRelayBackend { + pub fn new(config: AtxKeyConfig) -> Self { + Self { + config, + handle: Mutex::new(None), + initialized: AtomicBool::new(false), + } + } + + fn validate_config(&self) -> Result<()> { + if self.config.pin == 0 { + return Err(AppError::Config( + "USB relay channel must be 1-based (>= 1)".to_string(), + )); + } + if self.config.pin > USB_RELAY_MAX_CHANNEL as u32 { + return Err(AppError::Config(format!( + "USB HID relay channel must be <= {}", + USB_RELAY_MAX_CHANNEL + ))); + } + Ok(()) + } + + fn send_command(&self, on: bool) -> Result<()> { + let channel = u8::try_from(self.config.pin).map_err(|_| { + AppError::Config(format!( + "USB relay channel {} exceeds max {}", + self.config.pin, + u8::MAX + )) + })?; + if channel == 0 { + return Err(AppError::Config( + "USB relay channel must be 1-based (>= 1)".to_string(), + )); + } + if channel > USB_RELAY_MAX_CHANNEL { + return Err(AppError::Config(format!( + "USB HID relay channel must be <= {}", + USB_RELAY_MAX_CHANNEL + ))); + } + + let cmd = Self::build_command(channel, on); + let mut guard = self.handle.lock().unwrap(); + let device = guard + .as_mut() + .ok_or_else(|| AppError::Internal("USB relay not initialized".to_string()))?; + + if let Err(feature_err) = Self::send_feature_report(device, &cmd) { + debug!( + "USB relay feature report failed ({}), falling back to hidraw write", + feature_err + ); + device.write_all(&cmd).map_err(|write_err| { + AppError::Internal(format!( + "USB relay feature report failed: {}; raw write failed: {}", + feature_err, write_err + )) + })?; + device + .flush() + .map_err(|e| AppError::Internal(format!("USB relay flush failed: {}", e)))?; + } + + Ok(()) + } + + pub fn build_command(channel: u8, on: bool) -> [u8; USB_RELAY_REPORT_LEN] { + let mut cmd = [0x00; USB_RELAY_REPORT_LEN]; + cmd[1] = if on { 0xFF } else { 0xFD }; + cmd[2] = channel; + cmd + } + + fn send_feature_report( + device: &File, + report: &[u8; USB_RELAY_REPORT_LEN], + ) -> std::io::Result<()> { + let rc = unsafe { libc::ioctl(device.as_raw_fd(), HIDIOCSFEATURE_9, report.as_ptr()) }; + if rc < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(()) + } + } +} + +#[async_trait] +impl AtxKeyBackend for HidrawLinuxRelayBackend { + async fn init(&mut self) -> Result<()> { + self.validate_config()?; + + info!( + "Initializing USB relay ATX backend on {} channel {}", + self.config.device, self.config.pin + ); + + let device = OpenOptions::new() + .read(true) + .write(true) + .open(&self.config.device) + .map_err(|e| AppError::Internal(format!("USB relay device open failed: {}", e)))?; + + *self.handle.lock().unwrap() = Some(device); + self.send_command(false)?; + self.initialized.store(true, Ordering::Relaxed); + + debug!( + "USB relay channel {} configured successfully", + self.config.pin + ); + Ok(()) + } + + async fn pulse(&self, duration: Duration) -> Result<()> { + if !self.is_initialized() { + return Err(AppError::Internal("USB relay not initialized".to_string())); + } + + self.send_command(true)?; + sleep(duration).await; + self.send_command(false)?; + Ok(()) + } + + async fn shutdown(&mut self) -> Result<()> { + if self.is_initialized() { + let _ = self.send_command(false); + } + *self.handle.lock().unwrap() = None; + self.initialized.store(false, Ordering::Relaxed); + Ok(()) + } + + fn is_initialized(&self) -> bool { + self.initialized.load(Ordering::Relaxed) + } +} + +impl Drop for HidrawLinuxRelayBackend { + fn drop(&mut self) { + if self.is_initialized() { + let _ = self.send_command(false); + } + *self.handle.lock().unwrap() = None; + } +} + +#[cfg(test)] +mod tests { + use super::HidrawLinuxRelayBackend; + + #[test] + fn usb_relay_command_format() { + assert_eq!( + HidrawLinuxRelayBackend::build_command(1, true), + [0x00, 0xFF, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] + ); + assert_eq!( + HidrawLinuxRelayBackend::build_command(1, false), + [0x00, 0xFD, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] + ); + } +} diff --git a/src/atx/led.rs b/src/atx/led.rs index 6f970479..0b4db12e 100644 --- a/src/atx/led.rs +++ b/src/atx/led.rs @@ -10,9 +10,6 @@ use tracing::{debug, info}; use super::types::{AtxLedConfig, PowerStatus}; use crate::error::{AppError, Result}; -/// LED sensor for reading power status -/// -/// Uses GPIO to read the power LED state and determine if the system is on or off. pub struct LedSensor { config: AtxLedConfig, handle: Mutex>, @@ -20,7 +17,6 @@ pub struct LedSensor { } impl LedSensor { - /// Create a new LED sensor with the given configuration pub fn new(config: AtxLedConfig) -> Self { Self { config, @@ -29,17 +25,6 @@ impl LedSensor { } } - /// Check if the sensor is configured - pub fn is_configured(&self) -> bool { - self.config.is_configured() - } - - /// Check if the sensor is initialized - pub fn is_initialized(&self) -> bool { - self.initialized.load(Ordering::Relaxed) - } - - /// Initialize the LED sensor pub async fn init(&mut self) -> Result<()> { if !self.config.is_configured() { debug!("LED sensor not configured, skipping init"); @@ -72,9 +57,8 @@ impl LedSensor { Ok(()) } - /// Read the current power status pub async fn read(&self) -> Result { - if !self.is_configured() || !self.is_initialized() { + if !self.config.is_configured() || !self.initialized.load(Ordering::Relaxed) { return Ok(PowerStatus::Unknown); } @@ -85,11 +69,10 @@ impl LedSensor { .get_value() .map_err(|e| AppError::Internal(format!("LED read failed: {}", e)))?; - // Apply inversion if configured let is_on = if self.config.inverted { - value == 0 // Active low: 0 means on + value == 0 } else { - value == 1 // Active high: 1 means on + value == 1 }; Ok(if is_on { @@ -102,7 +85,6 @@ impl LedSensor { } } - /// Shutdown the LED sensor pub async fn shutdown(&mut self) -> Result<()> { *self.handle.lock().unwrap() = None; self.initialized.store(false, Ordering::Relaxed); @@ -125,8 +107,8 @@ mod tests { fn test_led_sensor_creation() { let config = AtxLedConfig::default(); let sensor = LedSensor::new(config); - assert!(!sensor.is_configured()); - assert!(!sensor.is_initialized()); + assert!(!sensor.config.is_configured()); + assert!(!sensor.initialized.load(Ordering::Relaxed)); } #[test] @@ -138,8 +120,8 @@ mod tests { inverted: false, }; let sensor = LedSensor::new(config); - assert!(sensor.is_configured()); - assert!(!sensor.is_initialized()); + assert!(sensor.config.is_configured()); + assert!(!sensor.initialized.load(Ordering::Relaxed)); } #[test] @@ -151,7 +133,6 @@ mod tests { inverted: true, }; let sensor = LedSensor::new(config); - assert!(sensor.is_configured()); assert!(sensor.config.inverted); } } diff --git a/src/atx/mod.rs b/src/atx/mod.rs index 95083259..4834aad3 100644 --- a/src/atx/mod.rs +++ b/src/atx/mod.rs @@ -2,53 +2,22 @@ //! //! Provides ATX power management functionality for IP-KVM. //! Supports flexible hardware binding with independent configuration for each action. -//! -//! # Features -//! -//! - Power button control (short press for on/graceful shutdown, long press for force off) -//! - Reset button control -//! - Power status monitoring via LED sensing (GPIO only) -//! - Independent hardware binding for each action (GPIO or USB relay) -//! - Hot-reload configuration support -//! -//! # Hardware Support -//! -//! - **GPIO**: Uses Linux GPIO character device (/dev/gpiochipX) for direct hardware control -//! - **USB Relay**: Uses HID USB relay modules for isolated switching -//! - **Serial Relay**: Uses LCUS-style serial relay modules -//! -//! # Example -//! -//! ```ignore -//! use one_kvm::atx::{AtxController, AtxControllerConfig, AtxKeyConfig, AtxDriverType, ActiveLevel}; -//! -//! let config = AtxControllerConfig { -//! enabled: true, -//! power: AtxKeyConfig { -//! driver: AtxDriverType::Gpio, -//! device: "/dev/gpiochip0".to_string(), -//! pin: 5, -//! active_level: ActiveLevel::High, -//! baud_rate: 9600, -//! }, -//! reset: AtxKeyConfig { -//! driver: AtxDriverType::UsbRelay, -//! device: "/dev/hidraw0".to_string(), -//! pin: 0, -//! active_level: ActiveLevel::High, -//! baud_rate: 9600, -//! }, -//! led: Default::default(), -//! }; -//! -//! let controller = AtxController::new(config); -//! controller.init().await?; -//! controller.power_short().await?; // Turn on or graceful shutdown -//! ``` mod controller; +#[cfg(not(unix))] +mod disabled_key; mod executor; +#[cfg(unix)] +mod gpio_linux; +#[cfg(unix)] +mod hidraw_linux; +#[cfg(unix)] mod led; +#[cfg(not(unix))] +#[path = "disabled_led.rs"] +mod led; +mod serial_relay; +mod traits; mod types; mod wol; @@ -58,8 +27,9 @@ pub use types::{ ActiveLevel, AtxAction, AtxDevices, AtxDriverType, AtxKeyConfig, AtxLedConfig, AtxPowerRequest, AtxState, PowerStatus, }; -pub use wol::send_wol; +pub use wol::{list_wol_history, record_wol_history, send_wol}; +#[cfg(any(unix, test))] fn hidraw_uevent_is_usb_relay(uevent: &str) -> bool { let upper = uevent.to_ascii_uppercase(); upper.contains("000016C0:000005DF") @@ -69,6 +39,7 @@ fn hidraw_uevent_is_usb_relay(uevent: &str) -> bool { || upper.contains("USB RELAY") } +#[cfg(unix)] fn is_usb_relay_hidraw(name: &str) -> bool { let uevent_path = format!("/sys/class/hidraw/{}/device/uevent", name); std::fs::read_to_string(uevent_path) @@ -82,7 +53,9 @@ fn is_usb_relay_hidraw(name: &str) -> bool { pub fn discover_devices() -> AtxDevices { let mut devices = AtxDevices::default(); - // Single pass through /dev directory + devices.serial_ports = crate::utils::list_serial_ports(); + + #[cfg(unix)] if let Ok(entries) = std::fs::read_dir("/dev") { for entry in entries.flatten() { let name = entry.file_name(); @@ -100,6 +73,7 @@ pub fn discover_devices() -> AtxDevices { devices.gpio_chips.sort(); devices.usb_relays.sort(); devices.serial_ports.sort(); + devices.serial_ports.dedup(); devices } @@ -129,7 +103,6 @@ mod tests { #[test] fn test_module_exports() { - // Verify all public exports are accessible let _: AtxDriverType = AtxDriverType::None; let _: ActiveLevel = ActiveLevel::High; let _: AtxKeyConfig = AtxKeyConfig::default(); diff --git a/src/atx/serial_relay.rs b/src/atx/serial_relay.rs new file mode 100644 index 00000000..bdc23f7b --- /dev/null +++ b/src/atx/serial_relay.rs @@ -0,0 +1,141 @@ +use async_trait::async_trait; +use std::io::Write; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use tokio::time::sleep; +use tracing::{debug, info}; + +use super::traits::{validate_serial_config, AtxKeyBackend, SharedSerialHandle}; +use super::types::AtxKeyConfig; +use crate::error::{AppError, Result}; + +pub struct SerialRelayBackend { + config: AtxKeyConfig, + serial_handle: Mutex>, + initialized: AtomicBool, +} + +impl SerialRelayBackend { + pub fn new(config: AtxKeyConfig) -> Self { + Self { + config, + serial_handle: Mutex::new(None), + initialized: AtomicBool::new(false), + } + } + + pub fn new_with_shared_serial(config: AtxKeyConfig, serial_handle: SharedSerialHandle) -> Self { + Self { + config, + serial_handle: Mutex::new(Some(serial_handle)), + initialized: AtomicBool::new(false), + } + } + + pub fn open_shared_serial(device: &str, baud_rate: u32) -> Result { + let port = serialport::new(device, baud_rate) + .timeout(Duration::from_millis(100)) + .open() + .map_err(|e| AppError::Internal(format!("Serial port open failed: {}", e)))?; + Ok(Arc::new(Mutex::new(port))) + } + + fn send_command(&self, on: bool) -> Result<()> { + let channel = u8::try_from(self.config.pin).map_err(|_| { + AppError::Config(format!( + "Serial relay channel {} exceeds max {}", + self.config.pin, + u8::MAX + )) + })?; + + let state = if on { 1 } else { 0 }; + let checksum = 0xA0u8.wrapping_add(channel).wrapping_add(state); + let cmd = [0xA0, channel, state, checksum]; + + let serial_handle = self + .serial_handle + .lock() + .unwrap() + .as_ref() + .cloned() + .ok_or_else(|| AppError::Internal("Serial relay not initialized".to_string()))?; + let mut port = serial_handle.lock().unwrap(); + + port.write_all(&cmd) + .map_err(|e| AppError::Internal(format!("Serial relay write failed: {}", e)))?; + port.flush() + .map_err(|e| AppError::Internal(format!("Serial relay flush failed: {}", e)))?; + + Ok(()) + } +} + +#[async_trait] +impl AtxKeyBackend for SerialRelayBackend { + async fn init(&mut self) -> Result<()> { + validate_serial_config(&self.config)?; + + info!( + "Initializing Serial relay ATX backend on {} channel {}", + self.config.device, self.config.pin + ); + + let existing_handle = self.serial_handle.lock().unwrap().as_ref().cloned(); + if existing_handle.is_none() { + let shared = Self::open_shared_serial(&self.config.device, self.config.baud_rate)?; + *self.serial_handle.lock().unwrap() = Some(shared); + } + + self.send_command(false)?; + self.initialized.store(true, Ordering::Relaxed); + + debug!( + "Serial relay channel {} configured successfully", + self.config.pin + ); + Ok(()) + } + + async fn pulse(&self, duration: Duration) -> Result<()> { + if !self.is_initialized() { + return Err(AppError::Internal( + "Serial relay not initialized".to_string(), + )); + } + + info!( + "Pulse serial relay on {} pin {}", + self.config.device, self.config.pin + ); + self.send_command(true)?; + sleep(duration).await; + self.send_command(false)?; + Ok(()) + } + + async fn shutdown(&mut self) -> Result<()> { + if !self.is_initialized() { + return Ok(()); + } + + let _ = self.send_command(false); + *self.serial_handle.lock().unwrap() = None; + self.initialized.store(false, Ordering::Relaxed); + Ok(()) + } + + fn is_initialized(&self) -> bool { + self.initialized.load(Ordering::Relaxed) + } +} + +impl Drop for SerialRelayBackend { + fn drop(&mut self) { + if self.is_initialized() { + let _ = self.send_command(false); + } + *self.serial_handle.lock().unwrap() = None; + } +} diff --git a/src/atx/traits.rs b/src/atx/traits.rs new file mode 100644 index 00000000..d422b2a9 --- /dev/null +++ b/src/atx/traits.rs @@ -0,0 +1,51 @@ +use async_trait::async_trait; +use serialport::SerialPort; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use super::types::AtxKeyConfig; +use crate::error::Result; + +pub type SharedSerialHandle = Arc>>; + +#[async_trait] +pub trait AtxKeyBackend: Send + Sync { + async fn init(&mut self) -> Result<()>; + + async fn pulse(&self, duration: Duration) -> Result<()>; + + async fn shutdown(&mut self) -> Result<()>; + + fn is_initialized(&self) -> bool; +} + +#[derive(Debug, Clone)] +pub enum AtxKeyBackendContext { + Standalone, + SharedSerial(SharedSerialHandle), +} + +pub fn validate_serial_config(config: &AtxKeyConfig) -> Result<()> { + if config.device.trim().is_empty() { + return Err(crate::error::AppError::Config( + "Serial ATX device cannot be empty".to_string(), + )); + } + if config.pin == 0 { + return Err(crate::error::AppError::Config( + "Serial ATX channel must be 1-based (>= 1)".to_string(), + )); + } + if config.pin > u8::MAX as u32 { + return Err(crate::error::AppError::Config(format!( + "Serial ATX channel must be <= {}", + u8::MAX + ))); + } + if config.baud_rate == 0 { + return Err(crate::error::AppError::Config( + "Serial ATX baud_rate must be greater than 0".to_string(), + )); + } + Ok(()) +} diff --git a/src/atx/types.rs b/src/atx/types.rs index fdf38d58..ba353062 100644 --- a/src/atx/types.rs +++ b/src/atx/types.rs @@ -6,67 +6,43 @@ use serde::{Deserialize, Serialize}; use typeshare::typeshare; -/// Power status #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(rename_all = "lowercase")] pub enum PowerStatus { - /// Power is on On, - /// Power is off Off, - /// Power status unknown (no LED connected) #[default] Unknown, } -/// Driver type for ATX key operations #[typeshare] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(rename_all = "lowercase")] pub enum AtxDriverType { - /// GPIO control via Linux character device Gpio, - /// USB HID relay module UsbRelay, - /// Serial/COM port relay (taobao LCUS type) Serial, - /// Disabled / Not configured #[default] None, } -/// Active level for GPIO pins #[typeshare] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(rename_all = "lowercase")] pub enum ActiveLevel { - /// Active high (default for most cases) #[default] High, - /// Active low (inverted) Low, } -/// Configuration for a single ATX key (power or reset) -/// This is the "four-tuple" configuration: (driver, device, pin/channel, level) #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(default)] pub struct AtxKeyConfig { - /// Driver type (GPIO or USB Relay) pub driver: AtxDriverType, - /// Device path: - /// - For GPIO: /dev/gpiochipX - /// - For USB Relay: /dev/hidrawX pub device: String, - /// Pin or channel number: - /// - For GPIO: GPIO pin number - /// - For USB Relay: relay channel (1-based) - /// - For Serial Relay (LCUS): relay channel (1-based) pub pin: u32, - /// Active level (only applicable to GPIO, ignored for USB Relay) pub active_level: ActiveLevel, - /// Baud rate for serial relay (start with 9600) pub baud_rate: u32, } @@ -83,77 +59,54 @@ impl Default for AtxKeyConfig { } impl AtxKeyConfig { - /// Check if this key is configured pub fn is_configured(&self) -> bool { self.driver != AtxDriverType::None && !self.device.is_empty() } } -/// LED sensing configuration (optional) #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] #[serde(default)] pub struct AtxLedConfig { - /// Whether LED sensing is enabled pub enabled: bool, - /// GPIO chip for LED sensing pub gpio_chip: String, - /// GPIO pin for LED input pub gpio_pin: u32, - /// Whether LED is active low (inverted logic) pub inverted: bool, } impl AtxLedConfig { - /// Check if LED sensing is configured pub fn is_configured(&self) -> bool { self.enabled && !self.gpio_chip.is_empty() } } -/// ATX state information #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct AtxState { - /// Whether ATX feature is available/enabled pub available: bool, - /// Whether power button is configured pub power_configured: bool, - /// Whether reset button is configured pub reset_configured: bool, - /// Current power status pub power_status: PowerStatus, - /// Whether power LED sensing is supported pub led_supported: bool, } -/// ATX power action request #[derive(Debug, Clone, Deserialize)] pub struct AtxPowerRequest { - /// Action to perform: "short", "long", "reset" pub action: AtxAction, } -/// ATX power action #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum AtxAction { - /// Short press power button (turn on or graceful shutdown) Short, - /// Long press power button (force power off) Long, - /// Press reset button Reset, } -/// Available ATX devices for discovery #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AtxDevices { - /// Available GPIO chips (/dev/gpiochip*) pub gpio_chips: Vec, - /// Available USB HID relay devices (/dev/hidraw*) pub usb_relays: Vec, - /// Available Serial ports (/dev/ttyUSB*) pub serial_ports: Vec, } @@ -201,13 +154,13 @@ mod tests { assert!(!config.is_configured()); config.driver = AtxDriverType::Gpio; - assert!(!config.is_configured()); // device still empty + assert!(!config.is_configured()); config.device = "/dev/gpiochip0".to_string(); assert!(config.is_configured()); config.driver = AtxDriverType::None; - assert!(!config.is_configured()); // driver is None + assert!(!config.is_configured()); } #[test] @@ -224,7 +177,7 @@ mod tests { assert!(!config.is_configured()); config.enabled = true; - assert!(!config.is_configured()); // gpio_chip still empty + assert!(!config.is_configured()); config.gpio_chip = "/dev/gpiochip0".to_string(); assert!(config.is_configured()); diff --git a/src/atx/wol.rs b/src/atx/wol.rs index 9da93cd7..ab66bd2e 100644 --- a/src/atx/wol.rs +++ b/src/atx/wol.rs @@ -3,18 +3,14 @@ //! Sends magic packets to wake up remote machines. use std::net::{SocketAddr, UdpSocket}; -use tracing::{debug, info}; +use tracing::info; use crate::error::{AppError, Result}; -/// WOL magic packet structure: -/// - 6 bytes of 0xFF -/// - 16 repetitions of the target MAC address (6 bytes each) -/// Total: 6 + 16 * 6 = 102 bytes +const WOL_HISTORY_MAX_ENTRIES: i64 = 50; + const MAGIC_PACKET_SIZE: usize = 102; -/// Parse MAC address string into bytes -/// Supports formats: "AA:BB:CC:DD:EE:FF" or "AA-BB-CC-DD-EE-FF" fn parse_mac_address(mac: &str) -> Result<[u8; 6]> { let mac = mac.trim().to_uppercase(); let parts: Vec<&str> = if mac.contains(':') { @@ -44,16 +40,13 @@ fn parse_mac_address(mac: &str) -> Result<[u8; 6]> { Ok(bytes) } -/// Build WOL magic packet fn build_magic_packet(mac: &[u8; 6]) -> [u8; MAGIC_PACKET_SIZE] { let mut packet = [0u8; MAGIC_PACKET_SIZE]; - // First 6 bytes are 0xFF for byte in packet.iter_mut().take(6) { *byte = 0xFF; } - // Next 96 bytes are 16 repetitions of the MAC address for i in 0..16 { let offset = 6 + i * 6; packet[offset..offset + 6].copy_from_slice(mac); @@ -73,16 +66,13 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> { info!("Sending WOL packet to {} via {:?}", mac_address, interface); - // Create UDP socket let socket = UdpSocket::bind("0.0.0.0:0") .map_err(|e| AppError::Internal(format!("Failed to create UDP socket: {}", e)))?; - // Enable broadcast socket .set_broadcast(true) .map_err(|e| AppError::Internal(format!("Failed to enable broadcast: {}", e)))?; - // Bind to specific interface if specified #[cfg(target_os = "linux")] if let Some(iface) = interface { if !iface.is_empty() { @@ -90,8 +80,7 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> { let fd = socket.as_raw_fd(); let iface_bytes = iface.as_bytes(); - // SO_BINDTODEVICE requires interface name as null-terminated string - let mut iface_buf = [0u8; 16]; // IFNAMSIZ is typically 16 + let mut iface_buf = [0u8; 16]; let len = iface_bytes.len().min(15); iface_buf[..len].copy_from_slice(&iface_bytes[..len]); @@ -112,18 +101,16 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> { iface, err ))); } - debug!("Bound to interface: {}", iface); + tracing::debug!("Bound to interface: {}", iface); } } - // Send to broadcast address on port 9 (discard protocol, commonly used for WOL) let broadcast_addr: SocketAddr = "255.255.255.255:9".parse().unwrap(); socket .send_to(&packet, broadcast_addr) .map_err(|e| AppError::Internal(format!("Failed to send WOL packet: {}", e)))?; - // Also try sending to port 7 (echo protocol, alternative WOL port) let broadcast_addr_7: SocketAddr = "255.255.255.255:7".parse().unwrap(); let _ = socket.send_to(&packet, broadcast_addr_7); @@ -131,6 +118,55 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> { Ok(()) } +pub async fn record_wol_history(pool: &sqlx::Pool, mac_address: &str) -> Result<()> { + sqlx::query( + r#" + INSERT INTO wol_history (mac_address, updated_at) + VALUES (?1, CAST(strftime('%s', 'now') AS INTEGER)) + ON CONFLICT(mac_address) DO UPDATE SET + updated_at = excluded.updated_at + "#, + ) + .bind(mac_address) + .execute(pool) + .await?; + + sqlx::query( + r#" + DELETE FROM wol_history + WHERE mac_address NOT IN ( + SELECT mac_address FROM wol_history + ORDER BY updated_at DESC + LIMIT ?1 + ) + "#, + ) + .bind(WOL_HISTORY_MAX_ENTRIES) + .execute(pool) + .await?; + + Ok(()) +} + +pub async fn list_wol_history( + pool: &sqlx::Pool, + limit: usize, +) -> Result> { + let rows = sqlx::query_as( + r#" + SELECT mac_address, updated_at + FROM wol_history + ORDER BY updated_at DESC + LIMIT ?1 + "#, + ) + .bind(limit as i64) + .fetch_all(pool) + .await?; + + Ok(rows) +} + #[cfg(test)] mod tests { use super::*; @@ -159,12 +195,10 @@ mod tests { let mac = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]; let packet = build_magic_packet(&mac); - // Check header (6 bytes of 0xFF) for byte in packet.iter().take(6) { assert_eq!(*byte, 0xFF); } - // Check MAC repetitions for i in 0..16 { let offset = 6 + i * 6; assert_eq!(&packet[offset..offset + 6], &mac); diff --git a/src/audio/capture.rs b/src/audio/capture.rs index 14a3510b..8565763d 100644 --- a/src/audio/capture.rs +++ b/src/audio/capture.rs @@ -1,334 +1,9 @@ -use alsa::pcm::{Access, Format, Frames, HwParams, State, IO}; -use alsa::{Direction, ValueOr, PCM}; -use bytes::Bytes; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::Arc; -use std::time::Instant; -use tokio::sync::{broadcast, watch, Mutex}; -use tracing::{debug, info}; +#[cfg(unix)] +#[path = "capture_linux.rs"] +mod imp; -use super::device::AudioDeviceInfo; -use crate::error::{AppError, Result}; -use crate::utils::LogThrottler; -use crate::{error_throttled, warn_throttled}; +#[cfg(windows)] +#[path = "capture_windows.rs"] +mod imp; -#[derive(Debug, Clone)] -pub struct AudioConfig { - pub device_name: String, - pub sample_rate: u32, - pub channels: u32, - pub frame_size: u32, - pub buffer_frames: u32, - pub period_frames: u32, -} - -impl Default for AudioConfig { - fn default() -> Self { - Self { - device_name: String::new(), - sample_rate: 48000, - channels: 2, - frame_size: 960, - buffer_frames: 4096, - period_frames: 960, - } - } -} - -impl AudioConfig { - pub fn for_device(device: &AudioDeviceInfo) -> Self { - Self { - device_name: device.name.clone(), - ..Default::default() - } - } - - pub fn bytes_per_sample(&self) -> u32 { - 2 * self.channels - } - - pub fn bytes_per_frame(&self) -> usize { - (self.frame_size * self.bytes_per_sample()) as usize - } -} - -#[derive(Debug, Clone)] -pub struct AudioFrame { - pub data: Bytes, - pub sample_rate: u32, - pub channels: u32, - pub samples: u32, - pub sequence: u64, - pub timestamp: Instant, -} - -impl AudioFrame { - pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self { - let bps = 2 * channels; - Self { - samples: data.len() as u32 / bps, - data, - sample_rate, - channels, - sequence, - timestamp: Instant::now(), - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CaptureState { - Stopped, - Running, - Error, -} - -pub struct AudioCapturer { - config: AudioConfig, - state: Arc>, - state_rx: watch::Receiver, - frame_tx: broadcast::Sender, - stop_flag: Arc, - sequence: Arc, - capture_handle: Mutex>>, - log_throttler: LogThrottler, -} - -impl AudioCapturer { - pub fn new(config: AudioConfig) -> Self { - let (state_tx, state_rx) = watch::channel(CaptureState::Stopped); - let (frame_tx, _) = broadcast::channel(16); - - Self { - config, - state: Arc::new(state_tx), - state_rx, - frame_tx, - stop_flag: Arc::new(AtomicBool::new(false)), - sequence: Arc::new(AtomicU64::new(0)), - capture_handle: Mutex::new(None), - log_throttler: LogThrottler::with_secs(5), - } - } - - pub fn state(&self) -> CaptureState { - *self.state_rx.borrow() - } - - pub fn state_watch(&self) -> watch::Receiver { - self.state_rx.clone() - } - - pub fn subscribe(&self) -> broadcast::Receiver { - self.frame_tx.subscribe() - } - - pub async fn start(&self) -> Result<()> { - if self.state() == CaptureState::Running { - return Ok(()); - } - - debug!( - "Starting audio capture on {} at {}Hz {}ch", - self.config.device_name, self.config.sample_rate, self.config.channels - ); - - self.stop_flag.store(false, Ordering::SeqCst); - - let config = self.config.clone(); - let state = self.state.clone(); - let frame_tx = self.frame_tx.clone(); - let stop_flag = self.stop_flag.clone(); - let sequence = self.sequence.clone(); - let log_throttler = self.log_throttler.clone(); - - let handle = tokio::task::spawn_blocking(move || { - let result = run_capture( - &config, - &state, - &frame_tx, - &stop_flag, - &sequence, - &log_throttler, - ); - - if let Err(e) = result { - error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e); - let _ = state.send(CaptureState::Error); - } else { - let _ = state.send(CaptureState::Stopped); - } - }); - - *self.capture_handle.lock().await = Some(handle); - Ok(()) - } - - pub async fn stop(&self) -> Result<()> { - info!("Stopping audio capture"); - self.stop_flag.store(true, Ordering::SeqCst); - - if let Some(handle) = self.capture_handle.lock().await.take() { - let _ = handle.await; - } - - let _ = self.state.send(CaptureState::Stopped); - Ok(()) - } - - pub fn is_running(&self) -> bool { - self.state() == CaptureState::Running - } -} - -fn run_capture( - config: &AudioConfig, - state: &watch::Sender, - frame_tx: &broadcast::Sender, - stop_flag: &AtomicBool, - sequence: &AtomicU64, - log_throttler: &LogThrottler, -) -> Result<()> { - let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| { - AppError::AudioError(format!( - "Failed to open audio device {}: {}", - config.device_name, e - )) - })?; - - { - let hwp = HwParams::any(&pcm) - .map_err(|e| AppError::AudioError(format!("Failed to get HwParams: {}", e)))?; - - hwp.set_channels(config.channels) - .map_err(|e| AppError::AudioError(format!("Failed to set channels: {}", e)))?; - - hwp.set_rate(config.sample_rate, ValueOr::Nearest) - .map_err(|e| AppError::AudioError(format!("Failed to set sample rate: {}", e)))?; - - hwp.set_format(Format::s16()) - .map_err(|e| AppError::AudioError(format!("Failed to set format: {}", e)))?; - - hwp.set_access(Access::RWInterleaved) - .map_err(|e| AppError::AudioError(format!("Failed to set access: {}", e)))?; - - hwp.set_buffer_size_near(config.buffer_frames as Frames) - .map_err(|e| AppError::AudioError(format!("Failed to set buffer size: {}", e)))?; - - hwp.set_period_size_near(config.period_frames as Frames, ValueOr::Nearest) - .map_err(|e| AppError::AudioError(format!("Failed to set period size: {}", e)))?; - - pcm.hw_params(&hwp) - .map_err(|e| AppError::AudioError(format!("Failed to apply hw params: {}", e)))?; - } - - let hw_now = pcm.hw_params_current().map_err(|e| { - AppError::AudioError(format!("Failed to read hw_params after apply: {}", e)) - })?; - let actual_rate = hw_now - .get_rate() - .map_err(|e| AppError::AudioError(format!("Failed to read sample rate: {}", e)))?; - let actual_ch = hw_now - .get_channels() - .map_err(|e| AppError::AudioError(format!("Failed to read channels: {}", e)))?; - if actual_rate != 48_000 { - return Err(AppError::AudioError(format!( - "Audio capture requires 48000 Hz; device is {} Hz", - actual_rate - ))); - } - if actual_ch != 2 { - return Err(AppError::AudioError(format!( - "Audio capture requires 2 channels (stereo); device has {}", - actual_ch - ))); - } - debug!("Audio capture: 48000 Hz, 2 ch"); - - pcm.prepare() - .map_err(|e| AppError::AudioError(format!("Failed to prepare PCM: {}", e)))?; - - let _ = state.send(CaptureState::Running); - - let period_frames = pcm - .hw_params_current() - .ok() - .and_then(|h| h.get_period_size().ok()) - .map(|f| f as usize) - .unwrap_or(1024) - .max(256); - let buf_frames = period_frames.saturating_mul(4).max(2048); - let bytes_per_frame = (config.channels as usize) * 2; - let mut buffer = vec![0u8; buf_frames * bytes_per_frame]; - - while !stop_flag.load(Ordering::Relaxed) { - match pcm.state() { - State::XRun => { - warn_throttled!(log_throttler, "xrun", "Audio buffer overrun, recovering"); - let _ = pcm.prepare(); - continue; - } - State::Suspended => { - warn_throttled!( - log_throttler, - "suspended", - "Audio device suspended, recovering" - ); - let _ = pcm.resume(); - continue; - } - _ => {} - } - - // io_bytes: USB capture often lacks mmap (io_checked requires it). - let io: IO = pcm.io_bytes(); - - match io.readi(&mut buffer) { - Ok(frames_read) => { - if frames_read == 0 { - continue; - } - - let byte_count = frames_read * config.channels as usize * 2; - - let seq = sequence.fetch_add(1, Ordering::Relaxed); - let frame = AudioFrame::new_interleaved( - Bytes::copy_from_slice(&buffer[..byte_count]), - config.channels, - 48_000, - seq, - ); - - if frame_tx.receiver_count() > 0 { - if let Err(e) = frame_tx.send(frame) { - debug!("No audio receivers: {}", e); - } - } - } - Err(e) => { - let desc = e.to_string(); - if is_device_lost_error(&desc) { - return Err(AppError::AudioError(format!( - "Audio device lost while reading {}: {}", - config.device_name, e - ))); - } else if desc.contains("EPIPE") || desc.contains("Broken pipe") { - warn_throttled!(log_throttler, "buffer_overrun", "Audio buffer overrun"); - let _ = pcm.prepare(); - } else { - error_throttled!(log_throttler, "read_error", "Audio read error: {}", e); - } - } - } - } - - info!("Audio capture stopped"); - Ok(()) -} - -fn is_device_lost_error(desc: &str) -> bool { - desc.contains("No such device") - || desc.contains("ENODEV") - || desc.contains("ENXIO") - || desc.contains("ESHUTDOWN") -} +pub use imp::*; diff --git a/src/audio/capture_linux.rs b/src/audio/capture_linux.rs new file mode 100644 index 00000000..14a3510b --- /dev/null +++ b/src/audio/capture_linux.rs @@ -0,0 +1,334 @@ +use alsa::pcm::{Access, Format, Frames, HwParams, State, IO}; +use alsa::{Direction, ValueOr, PCM}; +use bytes::Bytes; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::Instant; +use tokio::sync::{broadcast, watch, Mutex}; +use tracing::{debug, info}; + +use super::device::AudioDeviceInfo; +use crate::error::{AppError, Result}; +use crate::utils::LogThrottler; +use crate::{error_throttled, warn_throttled}; + +#[derive(Debug, Clone)] +pub struct AudioConfig { + pub device_name: String, + pub sample_rate: u32, + pub channels: u32, + pub frame_size: u32, + pub buffer_frames: u32, + pub period_frames: u32, +} + +impl Default for AudioConfig { + fn default() -> Self { + Self { + device_name: String::new(), + sample_rate: 48000, + channels: 2, + frame_size: 960, + buffer_frames: 4096, + period_frames: 960, + } + } +} + +impl AudioConfig { + pub fn for_device(device: &AudioDeviceInfo) -> Self { + Self { + device_name: device.name.clone(), + ..Default::default() + } + } + + pub fn bytes_per_sample(&self) -> u32 { + 2 * self.channels + } + + pub fn bytes_per_frame(&self) -> usize { + (self.frame_size * self.bytes_per_sample()) as usize + } +} + +#[derive(Debug, Clone)] +pub struct AudioFrame { + pub data: Bytes, + pub sample_rate: u32, + pub channels: u32, + pub samples: u32, + pub sequence: u64, + pub timestamp: Instant, +} + +impl AudioFrame { + pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self { + let bps = 2 * channels; + Self { + samples: data.len() as u32 / bps, + data, + sample_rate, + channels, + sequence, + timestamp: Instant::now(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CaptureState { + Stopped, + Running, + Error, +} + +pub struct AudioCapturer { + config: AudioConfig, + state: Arc>, + state_rx: watch::Receiver, + frame_tx: broadcast::Sender, + stop_flag: Arc, + sequence: Arc, + capture_handle: Mutex>>, + log_throttler: LogThrottler, +} + +impl AudioCapturer { + pub fn new(config: AudioConfig) -> Self { + let (state_tx, state_rx) = watch::channel(CaptureState::Stopped); + let (frame_tx, _) = broadcast::channel(16); + + Self { + config, + state: Arc::new(state_tx), + state_rx, + frame_tx, + stop_flag: Arc::new(AtomicBool::new(false)), + sequence: Arc::new(AtomicU64::new(0)), + capture_handle: Mutex::new(None), + log_throttler: LogThrottler::with_secs(5), + } + } + + pub fn state(&self) -> CaptureState { + *self.state_rx.borrow() + } + + pub fn state_watch(&self) -> watch::Receiver { + self.state_rx.clone() + } + + pub fn subscribe(&self) -> broadcast::Receiver { + self.frame_tx.subscribe() + } + + pub async fn start(&self) -> Result<()> { + if self.state() == CaptureState::Running { + return Ok(()); + } + + debug!( + "Starting audio capture on {} at {}Hz {}ch", + self.config.device_name, self.config.sample_rate, self.config.channels + ); + + self.stop_flag.store(false, Ordering::SeqCst); + + let config = self.config.clone(); + let state = self.state.clone(); + let frame_tx = self.frame_tx.clone(); + let stop_flag = self.stop_flag.clone(); + let sequence = self.sequence.clone(); + let log_throttler = self.log_throttler.clone(); + + let handle = tokio::task::spawn_blocking(move || { + let result = run_capture( + &config, + &state, + &frame_tx, + &stop_flag, + &sequence, + &log_throttler, + ); + + if let Err(e) = result { + error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e); + let _ = state.send(CaptureState::Error); + } else { + let _ = state.send(CaptureState::Stopped); + } + }); + + *self.capture_handle.lock().await = Some(handle); + Ok(()) + } + + pub async fn stop(&self) -> Result<()> { + info!("Stopping audio capture"); + self.stop_flag.store(true, Ordering::SeqCst); + + if let Some(handle) = self.capture_handle.lock().await.take() { + let _ = handle.await; + } + + let _ = self.state.send(CaptureState::Stopped); + Ok(()) + } + + pub fn is_running(&self) -> bool { + self.state() == CaptureState::Running + } +} + +fn run_capture( + config: &AudioConfig, + state: &watch::Sender, + frame_tx: &broadcast::Sender, + stop_flag: &AtomicBool, + sequence: &AtomicU64, + log_throttler: &LogThrottler, +) -> Result<()> { + let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| { + AppError::AudioError(format!( + "Failed to open audio device {}: {}", + config.device_name, e + )) + })?; + + { + let hwp = HwParams::any(&pcm) + .map_err(|e| AppError::AudioError(format!("Failed to get HwParams: {}", e)))?; + + hwp.set_channels(config.channels) + .map_err(|e| AppError::AudioError(format!("Failed to set channels: {}", e)))?; + + hwp.set_rate(config.sample_rate, ValueOr::Nearest) + .map_err(|e| AppError::AudioError(format!("Failed to set sample rate: {}", e)))?; + + hwp.set_format(Format::s16()) + .map_err(|e| AppError::AudioError(format!("Failed to set format: {}", e)))?; + + hwp.set_access(Access::RWInterleaved) + .map_err(|e| AppError::AudioError(format!("Failed to set access: {}", e)))?; + + hwp.set_buffer_size_near(config.buffer_frames as Frames) + .map_err(|e| AppError::AudioError(format!("Failed to set buffer size: {}", e)))?; + + hwp.set_period_size_near(config.period_frames as Frames, ValueOr::Nearest) + .map_err(|e| AppError::AudioError(format!("Failed to set period size: {}", e)))?; + + pcm.hw_params(&hwp) + .map_err(|e| AppError::AudioError(format!("Failed to apply hw params: {}", e)))?; + } + + let hw_now = pcm.hw_params_current().map_err(|e| { + AppError::AudioError(format!("Failed to read hw_params after apply: {}", e)) + })?; + let actual_rate = hw_now + .get_rate() + .map_err(|e| AppError::AudioError(format!("Failed to read sample rate: {}", e)))?; + let actual_ch = hw_now + .get_channels() + .map_err(|e| AppError::AudioError(format!("Failed to read channels: {}", e)))?; + if actual_rate != 48_000 { + return Err(AppError::AudioError(format!( + "Audio capture requires 48000 Hz; device is {} Hz", + actual_rate + ))); + } + if actual_ch != 2 { + return Err(AppError::AudioError(format!( + "Audio capture requires 2 channels (stereo); device has {}", + actual_ch + ))); + } + debug!("Audio capture: 48000 Hz, 2 ch"); + + pcm.prepare() + .map_err(|e| AppError::AudioError(format!("Failed to prepare PCM: {}", e)))?; + + let _ = state.send(CaptureState::Running); + + let period_frames = pcm + .hw_params_current() + .ok() + .and_then(|h| h.get_period_size().ok()) + .map(|f| f as usize) + .unwrap_or(1024) + .max(256); + let buf_frames = period_frames.saturating_mul(4).max(2048); + let bytes_per_frame = (config.channels as usize) * 2; + let mut buffer = vec![0u8; buf_frames * bytes_per_frame]; + + while !stop_flag.load(Ordering::Relaxed) { + match pcm.state() { + State::XRun => { + warn_throttled!(log_throttler, "xrun", "Audio buffer overrun, recovering"); + let _ = pcm.prepare(); + continue; + } + State::Suspended => { + warn_throttled!( + log_throttler, + "suspended", + "Audio device suspended, recovering" + ); + let _ = pcm.resume(); + continue; + } + _ => {} + } + + // io_bytes: USB capture often lacks mmap (io_checked requires it). + let io: IO = pcm.io_bytes(); + + match io.readi(&mut buffer) { + Ok(frames_read) => { + if frames_read == 0 { + continue; + } + + let byte_count = frames_read * config.channels as usize * 2; + + let seq = sequence.fetch_add(1, Ordering::Relaxed); + let frame = AudioFrame::new_interleaved( + Bytes::copy_from_slice(&buffer[..byte_count]), + config.channels, + 48_000, + seq, + ); + + if frame_tx.receiver_count() > 0 { + if let Err(e) = frame_tx.send(frame) { + debug!("No audio receivers: {}", e); + } + } + } + Err(e) => { + let desc = e.to_string(); + if is_device_lost_error(&desc) { + return Err(AppError::AudioError(format!( + "Audio device lost while reading {}: {}", + config.device_name, e + ))); + } else if desc.contains("EPIPE") || desc.contains("Broken pipe") { + warn_throttled!(log_throttler, "buffer_overrun", "Audio buffer overrun"); + let _ = pcm.prepare(); + } else { + error_throttled!(log_throttler, "read_error", "Audio read error: {}", e); + } + } + } + } + + info!("Audio capture stopped"); + Ok(()) +} + +fn is_device_lost_error(desc: &str) -> bool { + desc.contains("No such device") + || desc.contains("ENODEV") + || desc.contains("ENXIO") + || desc.contains("ESHUTDOWN") +} diff --git a/src/audio/capture_windows.rs b/src/audio/capture_windows.rs new file mode 100644 index 00000000..36429dc3 --- /dev/null +++ b/src/audio/capture_windows.rs @@ -0,0 +1,516 @@ +use bytes::Bytes; +use cpal::traits::{DeviceTrait, StreamTrait}; +use cpal::{BufferSize, SampleFormat, StreamConfig}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::mpsc; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::{broadcast, watch, Mutex}; +use tracing::{debug, info}; + +use crate::audio::device::{find_wasapi_device, AudioDeviceInfo}; +use crate::error::{AppError, Result}; +use crate::error_throttled; +use crate::utils::LogThrottler; + +#[derive(Debug, Clone)] +pub struct AudioConfig { + pub device_name: String, + pub sample_rate: u32, + pub channels: u32, + pub frame_size: u32, + pub buffer_frames: u32, + pub period_frames: u32, +} + +impl Default for AudioConfig { + fn default() -> Self { + Self { + device_name: String::new(), + sample_rate: 48000, + channels: 2, + frame_size: 960, + buffer_frames: 4096, + period_frames: 960, + } + } +} + +impl AudioConfig { + pub fn for_device(device: &AudioDeviceInfo) -> Self { + Self { + device_name: device.name.clone(), + ..Default::default() + } + } + + pub fn bytes_per_sample(&self) -> u32 { + 2 * self.channels + } + + pub fn bytes_per_frame(&self) -> usize { + (self.frame_size * self.bytes_per_sample()) as usize + } +} + +#[derive(Debug, Clone)] +pub struct AudioFrame { + pub data: Bytes, + pub sample_rate: u32, + pub channels: u32, + pub samples: u32, + pub sequence: u64, + pub timestamp: Instant, +} + +impl AudioFrame { + pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self { + let bps = 2 * channels; + Self { + samples: data.len() as u32 / bps, + data, + sample_rate, + channels, + sequence, + timestamp: Instant::now(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CaptureState { + Stopped, + Running, + Error, +} + +pub struct AudioCapturer { + config: AudioConfig, + state: Arc>, + state_rx: watch::Receiver, + frame_tx: broadcast::Sender, + stop_flag: Arc, + sequence: Arc, + capture_handle: Mutex>>, + log_throttler: LogThrottler, +} + +impl AudioCapturer { + pub fn new(config: AudioConfig) -> Self { + let (state_tx, state_rx) = watch::channel(CaptureState::Stopped); + let (frame_tx, _) = broadcast::channel(16); + + Self { + config, + state: Arc::new(state_tx), + state_rx, + frame_tx, + stop_flag: Arc::new(AtomicBool::new(false)), + sequence: Arc::new(AtomicU64::new(0)), + capture_handle: Mutex::new(None), + log_throttler: LogThrottler::with_secs(5), + } + } + + pub fn state(&self) -> CaptureState { + *self.state_rx.borrow() + } + + pub fn state_watch(&self) -> watch::Receiver { + self.state_rx.clone() + } + + pub fn subscribe(&self) -> broadcast::Receiver { + self.frame_tx.subscribe() + } + + pub async fn start(&self) -> Result<()> { + if self.state() == CaptureState::Running { + return Ok(()); + } + + debug!( + "Starting WASAPI audio capture on {} at {}Hz {}ch", + self.config.device_name, self.config.sample_rate, self.config.channels + ); + + self.stop_flag.store(false, Ordering::SeqCst); + + let config = self.config.clone(); + let state = self.state.clone(); + let frame_tx = self.frame_tx.clone(); + let stop_flag = self.stop_flag.clone(); + let sequence = self.sequence.clone(); + let log_throttler = self.log_throttler.clone(); + + let handle = tokio::task::spawn_blocking(move || { + let result = run_capture( + &config, + &state, + &frame_tx, + &stop_flag, + &sequence, + &log_throttler, + ); + + if let Err(e) = result { + error_throttled!( + log_throttler, + "capture_error", + "WASAPI audio capture error: {}", + e + ); + let _ = state.send(CaptureState::Error); + } else { + let _ = state.send(CaptureState::Stopped); + } + }); + + *self.capture_handle.lock().await = Some(handle); + Ok(()) + } + + pub async fn stop(&self) -> Result<()> { + info!("Stopping WASAPI audio capture"); + self.stop_flag.store(true, Ordering::SeqCst); + + if let Some(handle) = self.capture_handle.lock().await.take() { + let _ = handle.await; + } + + let _ = self.state.send(CaptureState::Stopped); + Ok(()) + } + + pub fn is_running(&self) -> bool { + self.state() == CaptureState::Running + } +} + +fn run_capture( + config: &AudioConfig, + state: &watch::Sender, + frame_tx: &broadcast::Sender, + stop_flag: &AtomicBool, + sequence: &AtomicU64, + log_throttler: &LogThrottler, +) -> Result<()> { + let device = find_wasapi_device(&config.device_name)?; + let device_label = device_label(&device); + + let supported = select_input_config(&device, config)?; + let sample_format = supported.sample_format(); + let input_channels = supported.channels() as u32; + let input_rate = supported.sample_rate(); + let stream_config = StreamConfig { + channels: supported.channels(), + sample_rate: supported.sample_rate(), + buffer_size: BufferSize::Fixed(config.period_frames.max(128)), + }; + + debug!( + "WASAPI capture selected: {} @ {}Hz {}ch {:?}", + device_label, input_rate, input_channels, sample_format + ); + + let (tx, rx) = mpsc::sync_channel::>(8); + let (err_tx, err_rx) = mpsc::sync_channel::(1); + let callback_stop = Arc::new(AtomicBool::new(false)); + + let stream = match sample_format { + SampleFormat::F32 => build_stream::( + &device, + &stream_config, + input_channels, + input_rate, + tx.clone(), + err_tx.clone(), + callback_stop.clone(), + ), + SampleFormat::I16 => build_stream::( + &device, + &stream_config, + input_channels, + input_rate, + tx.clone(), + err_tx.clone(), + callback_stop.clone(), + ), + SampleFormat::U16 => build_stream::( + &device, + &stream_config, + input_channels, + input_rate, + tx.clone(), + err_tx.clone(), + callback_stop.clone(), + ), + other => { + return Err(AppError::AudioError(format!( + "Unsupported WASAPI sample format: {:?}", + other + ))); + } + }?; + + stream + .play() + .map_err(|e| AppError::AudioError(format!("Failed to start WASAPI stream: {}", e)))?; + + let _ = state.send(CaptureState::Running); + + while !stop_flag.load(Ordering::Relaxed) { + if let Ok(err) = err_rx.try_recv() { + return Err(AppError::AudioError(format!( + "WASAPI stream error for {}: {}", + device_label, err + ))); + } + + match rx.recv_timeout(Duration::from_millis(100)) { + Ok(samples) => { + if samples.is_empty() { + continue; + } + let seq = sequence.fetch_add(1, Ordering::Relaxed); + let frame = AudioFrame::new_interleaved( + Bytes::copy_from_slice(bytemuck::cast_slice(&samples)), + 2, + 48_000, + seq, + ); + if frame_tx.receiver_count() > 0 { + if let Err(e) = frame_tx.send(frame) { + debug!("No audio receivers: {}", e); + } + } + } + Err(mpsc::RecvTimeoutError::Timeout) => {} + Err(mpsc::RecvTimeoutError::Disconnected) => { + return Err(AppError::AudioError(format!( + "WASAPI capture callback stopped for {}", + device_label + ))); + } + } + } + + callback_stop.store(true, Ordering::SeqCst); + drop(stream); + + info!("WASAPI audio capture stopped"); + let _ = log_throttler; + Ok(()) +} + +fn select_input_config( + device: &cpal::Device, + config: &AudioConfig, +) -> Result { + let requested_rate = config.sample_rate; + let mut fallback = None; + + let configs = device.supported_input_configs().map_err(|e| { + AppError::AudioError(format!("Failed to query WASAPI input configs: {}", e)) + })?; + + for range in configs { + let sample_format = range.sample_format(); + if !matches!( + sample_format, + SampleFormat::F32 | SampleFormat::I16 | SampleFormat::U16 + ) { + continue; + } + + if fallback + .as_ref() + .is_none_or(|best: &cpal::SupportedStreamConfigRange| { + range.cmp_default_heuristics(best).is_gt() + }) + { + fallback = Some(range); + } + + if range.channels() >= 2 + && range.min_sample_rate() <= requested_rate + && requested_rate <= range.max_sample_rate() + { + return Ok(range.with_sample_rate(requested_rate)); + } + } + + if let Some(range) = fallback { + let rate = if range.min_sample_rate() <= requested_rate + && requested_rate <= range.max_sample_rate() + { + requested_rate + } else { + range.with_max_sample_rate().sample_rate() + }; + return Ok(range.with_sample_rate(rate)); + } + + device.default_input_config().map_err(|e| { + AppError::AudioError(format!( + "No supported WASAPI input format found, and default config failed: {}", + e + )) + }) +} + +fn build_stream( + device: &cpal::Device, + config: &StreamConfig, + input_channels: u32, + input_rate: u32, + tx: mpsc::SyncSender>, + err_tx: mpsc::SyncSender, + stop_flag: Arc, +) -> Result +where + T: cpal::SizedSample + SampleToI16, +{ + let mut converter = PcmConverter::new(input_channels, input_rate, 2, 48_000); + let data_tx = tx.clone(); + let stream = device + .build_input_stream( + config, + move |data: &[T], _| { + if stop_flag.load(Ordering::Relaxed) { + return; + } + let pcm = converter.convert(data); + if !pcm.is_empty() { + let _ = data_tx.try_send(pcm); + } + }, + move |err| { + let _ = err_tx.try_send(err.to_string()); + }, + Some(Duration::from_secs(2)), + ) + .map_err(|e| AppError::AudioError(format!("Failed to build WASAPI input stream: {}", e)))?; + Ok(stream) +} + +trait SampleToI16: Copy + Send + 'static { + fn to_i16_sample(self) -> i16; +} + +impl SampleToI16 for i16 { + fn to_i16_sample(self) -> i16 { + self + } +} + +impl SampleToI16 for u16 { + fn to_i16_sample(self) -> i16 { + (self as i32 - 32768).clamp(i16::MIN as i32, i16::MAX as i32) as i16 + } +} + +impl SampleToI16 for f32 { + fn to_i16_sample(self) -> i16 { + (self.clamp(-1.0, 1.0) * i16::MAX as f32).round() as i16 + } +} + +struct PcmConverter { + input_channels: usize, + input_rate: u32, + output_channels: usize, + output_rate: u32, + input_position: u64, + next_output_position: u64, +} + +impl PcmConverter { + fn new(input_channels: u32, input_rate: u32, output_channels: u32, output_rate: u32) -> Self { + Self { + input_channels: input_channels.max(1) as usize, + input_rate: input_rate.max(1), + output_channels: output_channels.max(1) as usize, + output_rate: output_rate.max(1), + input_position: 0, + next_output_position: 0, + } + } + + fn convert(&mut self, input: &[T]) -> Vec { + let frames = input.len() / self.input_channels; + if frames == 0 { + return Vec::new(); + } + + if self.input_rate == self.output_rate { + self.input_position = self.input_position.saturating_add(frames as u64); + return self.convert_channels(input, frames); + } + + let start = self.input_position; + let end = start.saturating_add(frames as u64); + let mut out = Vec::with_capacity( + ((frames as u64 * self.output_rate as u64 / self.input_rate as u64 + 2) as usize) + * self.output_channels, + ); + + while self.source_position_for_output(self.next_output_position) < end { + let src = self.source_position_for_output(self.next_output_position); + if src >= start { + let local = (src - start) as usize; + self.push_frame(input, local.min(frames - 1), &mut out); + } + self.next_output_position = self.next_output_position.saturating_add(1); + } + + self.input_position = end; + out + } + + fn source_position_for_output(&self, output_position: u64) -> u64 { + output_position.saturating_mul(self.input_rate as u64) / self.output_rate as u64 + } + + fn convert_channels(&self, input: &[T], frames: usize) -> Vec { + let mut out = Vec::with_capacity(frames * self.output_channels); + for frame in 0..frames { + self.push_frame(input, frame, &mut out); + } + out + } + + fn push_frame(&self, input: &[T], frame: usize, out: &mut Vec) { + let base = frame * self.input_channels; + let left = input + .get(base) + .copied() + .map(SampleToI16::to_i16_sample) + .unwrap_or(0); + let right = if self.input_channels > 1 { + input + .get(base + 1) + .copied() + .map(SampleToI16::to_i16_sample) + .unwrap_or(left) + } else { + left + }; + + out.push(left); + if self.output_channels > 1 { + out.push(right); + } + } +} + +fn device_label(device: &cpal::Device) -> String { + device + .description() + .map(|desc| desc.to_string()) + .or_else(|_| { + #[allow(deprecated)] + device.name() + }) + .unwrap_or_else(|_| "Unknown WASAPI capture device".to_string()) +} diff --git a/src/audio/controller.rs b/src/audio/controller.rs index 5a8680e9..84742fa2 100644 --- a/src/audio/controller.rs +++ b/src/audio/controller.rs @@ -1,107 +1,21 @@ //! Device selection, quality presets, streaming. -use serde::{Deserialize, Serialize}; -use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::time::Duration; use tokio::sync::RwLock; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; use super::capture::AudioConfig; -use super::device::{ - enumerate_audio_devices, enumerate_audio_devices_with_current, find_best_audio_device, - AudioDeviceInfo, -}; -use super::encoder::{OpusConfig, OpusFrame}; +use super::device::{enumerate_audio_devices_with_current, find_best_audio_device, AudioDeviceInfo}; +use super::encoder::OpusFrame; use super::monitor::AudioHealthMonitor; -use super::streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig}; +use super::streamer::{AudioStreamer, AudioStreamerConfig}; +use super::recovery; +use super::types::{AudioControllerConfig, AudioQuality, AudioStatus}; use crate::error::{AppError, Result}; -use crate::events::{EventBus, StreamDeviceLostKind, SystemEvent}; +use crate::events::EventBus; -const AUDIO_RECOVERY_RETRY_DELAY: Duration = Duration::from_secs(1); - -type AudioRecoveredCallback = Arc; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] -#[serde(rename_all = "lowercase")] -pub enum AudioQuality { - Voice, - #[default] - Balanced, - High, -} - -impl AudioQuality { - pub fn bitrate(&self) -> u32 { - match self { - AudioQuality::Voice => 32000, - AudioQuality::Balanced => 64000, - AudioQuality::High => 128000, - } - } - - pub fn to_opus_config(&self) -> OpusConfig { - match self { - AudioQuality::Voice => OpusConfig::voice(), - AudioQuality::Balanced => OpusConfig::default(), - AudioQuality::High => OpusConfig::music(), - } - } -} - -impl FromStr for AudioQuality { - type Err = AppError; - - fn from_str(s: &str) -> std::result::Result { - match s.trim().to_lowercase().as_str() { - "voice" => Ok(Self::Voice), - "balanced" => Ok(Self::Balanced), - "high" => Ok(Self::High), - _ => Err(AppError::BadRequest(format!( - "invalid audio quality {:?} (expected voice, balanced, or high)", - s.trim() - ))), - } - } -} - -impl std::fmt::Display for AudioQuality { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - AudioQuality::Voice => write!(f, "voice"), - AudioQuality::Balanced => write!(f, "balanced"), - AudioQuality::High => write!(f, "high"), - } - } -} - -#[derive(Debug, Clone)] -pub struct AudioControllerConfig { - pub enabled: bool, - pub device: String, - pub quality: AudioQuality, -} - -impl Default for AudioControllerConfig { - fn default() -> Self { - Self { - enabled: false, - device: String::new(), - quality: AudioQuality::Balanced, - } - } -} - -#[derive(Debug, Clone, Serialize)] -pub struct AudioStatus { - pub enabled: bool, - pub streaming: bool, - pub device: Option, - pub quality: AudioQuality, - pub subscriber_count: usize, - pub error: Option, -} +pub(super) type AudioRecoveredCallback = Arc; pub struct AudioController { config: Arc>, @@ -135,274 +49,12 @@ impl AudioController { } async fn mark_device_info_dirty(&self) { - if let Some(ref bus) = *self.event_bus.read().await { + if let Some(bus) = self.event_bus.read().await.as_ref() { bus.mark_device_info_dirty(); } } - - async fn publish_state( - event_bus: &Arc>>>, - state: &str, - device: Option, - reason: Option<&str>, - next_retry_ms: Option, - ) { - if let Some(ref bus) = *event_bus.read().await { - bus.publish(SystemEvent::StreamStateChanged { - state: state.to_string(), - device, - reason: reason.map(str::to_string), - next_retry_ms, - }); - bus.mark_device_info_dirty(); - } - } - - async fn publish_device_lost( - event_bus: &Arc>>>, - device: &str, - reason: &str, - ) { - if let Some(ref bus) = *event_bus.read().await { - bus.publish(SystemEvent::StreamDeviceLost { - kind: StreamDeviceLostKind::Audio, - device: device.to_string(), - reason: reason.to_string(), - }); - } - } - - async fn publish_reconnecting( - event_bus: &Arc>>>, - device: &str, - attempt: u32, - ) { - if let Some(ref bus) = *event_bus.read().await { - bus.publish(SystemEvent::StreamReconnecting { - device: device.to_string(), - attempt, - }); - } - } - - async fn publish_recovered(event_bus: &Arc>>>, device: &str) { - if let Some(ref bus) = *event_bus.read().await { - bus.publish(SystemEvent::StreamRecovered { - device: device.to_string(), - }); - } - } - - fn select_recovery_device( - devices: &[AudioDeviceInfo], - preferred: &str, - ) -> Option { - if !preferred.trim().is_empty() { - if let Some(device) = devices.iter().find(|d| d.name == preferred) { - return Some(device.clone()); - } - } - - devices - .iter() - .find(|d| d.is_hdmi && d.sample_rates.contains(&48_000) && d.channels.contains(&2)) - .or_else(|| { - devices - .iter() - .find(|d| d.sample_rates.contains(&48_000) && d.channels.contains(&2)) - }) - .or_else(|| devices.first()) - .cloned() - } - - fn spawn_stream_monitor_from_parts( - config: Arc>, - streamer_slot: Arc>>>, - event_bus: Arc>>>, - monitor: Arc, - recovery_in_progress: Arc, - recovered_callback: Arc>>, - streamer: Arc, - device: String, - ) { - let mut state_rx = streamer.state_watch(); - - tokio::spawn(async move { - loop { - if state_rx.changed().await.is_err() { - return; - } - - if *state_rx.borrow() != AudioStreamState::Error { - continue; - } - - { - let current = streamer_slot.read().await; - if !current - .as_ref() - .is_some_and(|current| Arc::ptr_eq(current, &streamer)) - { - return; - } - } - - let reason = format!("Audio device lost: {}", device); - monitor.report_error(&reason, "device_lost").await; - Self::spawn_recovery_task_from_parts( - config, - streamer_slot, - event_bus, - monitor, - recovery_in_progress, - recovered_callback, - device, - reason, - ); - return; - } - }); - } - - fn spawn_recovery_task_from_parts( - config: Arc>, - streamer_slot: Arc>>>, - event_bus: Arc>>>, - monitor: Arc, - recovery_in_progress: Arc, - recovered_callback: Arc>>, - lost_device: String, - reason: String, - ) { - if recovery_in_progress.swap(true, Ordering::SeqCst) { - debug!("Audio recovery already in progress"); - return; - } - - tokio::spawn(async move { - warn!("Audio recovery started for {}: {}", lost_device, reason); - Self::publish_device_lost(&event_bus, &lost_device, &reason).await; - Self::publish_state( - &event_bus, - "device_lost", - Some(lost_device.clone()), - Some("audio_device_lost"), - Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64), - ) - .await; - - let mut attempt = 0u32; - - loop { - if !recovery_in_progress.load(Ordering::SeqCst) { - debug!("Audio recovery canceled"); - return; - } - - if streamer_slot - .read() - .await - .as_ref() - .is_some_and(|s| s.is_running()) - { - recovery_in_progress.store(false, Ordering::SeqCst); - return; - } - - let cfg = config.read().await.clone(); - if !cfg.enabled { - recovery_in_progress.store(false, Ordering::SeqCst); - return; - } - - attempt = attempt.saturating_add(1); - Self::publish_reconnecting(&event_bus, &lost_device, attempt).await; - Self::publish_state( - &event_bus, - "device_lost", - Some(lost_device.clone()), - Some("audio_reconnecting"), - Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64), - ) - .await; - - tokio::time::sleep(AUDIO_RECOVERY_RETRY_DELAY).await; - - let devices = match enumerate_audio_devices() { - Ok(devices) => devices, - Err(e) => { - debug!( - "Audio recovery enumerate failed (attempt {}): {}", - attempt, e - ); - continue; - } - }; - - let Some(device) = Self::select_recovery_device(&devices, &cfg.device) else { - debug!("No audio devices found during recovery attempt {}", attempt); - continue; - }; - - let streamer_config = AudioStreamerConfig { - capture: AudioConfig { - device_name: device.name.clone(), - ..Default::default() - }, - opus: cfg.quality.to_opus_config(), - }; - let new_streamer = Arc::new(AudioStreamer::with_config(streamer_config)); - - match new_streamer.start().await { - Ok(()) => { - { - let mut cfg = config.write().await; - cfg.device = device.name.clone(); - } - *streamer_slot.write().await = Some(new_streamer.clone()); - monitor.report_recovered().await; - Self::publish_recovered(&event_bus, &device.name).await; - if let Some(callback) = recovered_callback.read().await.clone() { - callback(); - } - Self::publish_state( - &event_bus, - "streaming", - Some(device.name.clone()), - None, - None, - ) - .await; - recovery_in_progress.store(false, Ordering::SeqCst); - info!( - "Audio device recovered with {} after {} attempts", - device.name, attempt - ); - Self::spawn_stream_monitor_from_parts( - config, - streamer_slot, - event_bus, - monitor, - recovery_in_progress, - recovered_callback, - new_streamer, - device.name, - ); - return; - } - Err(e) => { - debug!( - "Audio recovery start failed with {} (attempt {}): {}", - device.name, attempt, e - ); - } - } - } - }); - } - fn spawn_recovery_task(&self, lost_device: String, reason: String) { - Self::spawn_recovery_task_from_parts( + recovery::spawn_recovery_task( self.config.clone(), self.streamer.clone(), self.event_bus.clone(), @@ -415,7 +67,7 @@ impl AudioController { } fn spawn_stream_monitor(&self, streamer: Arc, device: String) { - Self::spawn_stream_monitor_from_parts( + recovery::spawn_stream_monitor( self.config.clone(), self.streamer.clone(), self.event_bus.clone(), @@ -477,7 +129,7 @@ impl AudioController { config.quality = quality; } - if let Some(ref streamer) = *self.streamer.read().await { + if let Some(streamer) = self.streamer.read().await.as_ref() { streamer.set_bitrate(quality.bitrate()).await?; } @@ -578,11 +230,11 @@ impl AudioController { } pub async fn is_streaming(&self) -> bool { - if let Some(ref streamer) = *self.streamer.read().await { - streamer.is_running() - } else { - false - } + self.streamer + .read() + .await + .as_ref() + .is_some_and(|streamer| streamer.is_running()) } pub async fn status(&self) -> AudioStatus { diff --git a/src/audio/device.rs b/src/audio/device.rs index 66df38cc..e8752ac2 100644 --- a/src/audio/device.rs +++ b/src/audio/device.rs @@ -1,201 +1,9 @@ -use alsa::pcm::HwParams; -use alsa::{Direction, PCM}; -use serde::Serialize; -use tracing::{debug, info, warn}; +#[cfg(unix)] +#[path = "device_linux.rs"] +mod imp; -use crate::error::{AppError, Result}; +#[cfg(windows)] +#[path = "device_windows.rs"] +mod imp; -#[derive(Debug, Clone, Serialize)] -pub struct AudioDeviceInfo { - pub name: String, - pub description: String, - pub card_index: i32, - pub device_index: i32, - pub sample_rates: Vec, - pub channels: Vec, - pub is_capture: bool, - pub is_hdmi: bool, - pub usb_bus: Option, -} - -fn get_usb_bus_info(card_index: i32) -> Option { - if card_index < 0 { - return None; - } - - let device_path = format!("/sys/class/sound/card{}/device", card_index); - let link_target = std::fs::read_link(&device_path).ok()?; - let link_str = link_target.to_string_lossy(); - - for component in link_str.split('/') { - if component.contains('-') && !component.contains(':') { - if component - .chars() - .next() - .map(|c| c.is_ascii_digit()) - .unwrap_or(false) - { - return Some(component.to_string()); - } - } - } - - None -} - -pub fn enumerate_audio_devices() -> Result> { - enumerate_audio_devices_with_current(None) -} - -pub fn enumerate_audio_devices_with_current( - current_device: Option<&str>, -) -> Result> { - let mut devices = Vec::new(); - - let cards = alsa::card::Iter::new(); - - for card_result in cards { - let card = match card_result { - Ok(c) => c, - Err(e) => { - debug!("Error iterating card: {}", e); - continue; - } - }; - - let card_index = card.get_index(); - let card_name = card.get_name().unwrap_or_else(|_| "Unknown".to_string()); - let card_longname = card.get_longname().unwrap_or_else(|_| card_name.clone()); - - debug!("Found audio card {}: {}", card_index, card_longname); - - let long_lower = card_longname.to_lowercase(); - let is_hdmi = long_lower.contains("hdmi") - || long_lower.contains("capture") - || long_lower.contains("usb"); - - let usb_bus = get_usb_bus_info(card_index); - - for device_index in 0..8 { - let device_name = format!("hw:{},{}", card_index, device_index); - let is_current_device = current_device == Some(device_name.as_str()); - - let mut push_info = - |sample_rates: Vec, channels: Vec, description: String| { - devices.push(AudioDeviceInfo { - name: device_name.clone(), - description, - card_index, - device_index, - sample_rates, - channels, - is_capture: true, - is_hdmi, - usb_bus: usb_bus.clone(), - }); - }; - - match PCM::new(&device_name, Direction::Capture, false) { - Ok(pcm) => { - let (sample_rates, channels) = query_device_caps(&pcm); - - if !sample_rates.is_empty() && !channels.is_empty() { - push_info( - sample_rates, - channels, - format!("{} - Device {}", card_longname, device_index), - ); - } - } - Err(_) => { - if is_current_device { - debug!( - "Device {} is busy (in use by us), adding with default caps", - device_name - ); - push_info( - vec![44100, 48000], - vec![2], - format!("{} - Device {} (in use)", card_longname, device_index), - ); - } - } - } - } - } - - info!("Found {} audio capture devices", devices.len()); - Ok(devices) -} - -fn query_device_caps(pcm: &PCM) -> (Vec, Vec) { - let hwp = match HwParams::any(pcm) { - Ok(h) => h, - Err(_) => return (vec![], vec![]), - }; - - let common_rates = [8000, 16000, 22050, 44100, 48000, 96000]; - let mut supported_rates = Vec::new(); - - for rate in &common_rates { - if hwp.test_rate(*rate).is_ok() { - supported_rates.push(*rate); - } - } - - let mut supported_channels = Vec::new(); - for ch in 1..=8 { - if hwp.test_channels(ch).is_ok() { - supported_channels.push(ch); - } - } - - (supported_rates, supported_channels) -} - -pub fn find_best_audio_device() -> Result { - let devices = enumerate_audio_devices()?; - - if devices.is_empty() { - return Err(AppError::AudioError( - "No audio capture devices found".to_string(), - )); - } - - let mut first_48k_stereo: Option<&AudioDeviceInfo> = None; - for device in &devices { - if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) { - continue; - } - if device.is_hdmi { - info!("Selected HDMI audio device: {}", device.description); - return Ok(device.clone()); - } - if first_48k_stereo.is_none() { - first_48k_stereo = Some(device); - } - } - if let Some(device) = first_48k_stereo { - info!("Selected audio device: {}", device.description); - return Ok(device.clone()); - } - - let device = devices.into_iter().next().unwrap(); - warn!( - "Using fallback audio device: {} (may not support optimal settings)", - device.description - ); - Ok(device) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_enumerate_devices() { - let result = enumerate_audio_devices(); - println!("Audio devices: {:?}", result); - assert!(result.is_ok()); - } -} +pub use imp::*; diff --git a/src/audio/device_linux.rs b/src/audio/device_linux.rs new file mode 100644 index 00000000..66df38cc --- /dev/null +++ b/src/audio/device_linux.rs @@ -0,0 +1,201 @@ +use alsa::pcm::HwParams; +use alsa::{Direction, PCM}; +use serde::Serialize; +use tracing::{debug, info, warn}; + +use crate::error::{AppError, Result}; + +#[derive(Debug, Clone, Serialize)] +pub struct AudioDeviceInfo { + pub name: String, + pub description: String, + pub card_index: i32, + pub device_index: i32, + pub sample_rates: Vec, + pub channels: Vec, + pub is_capture: bool, + pub is_hdmi: bool, + pub usb_bus: Option, +} + +fn get_usb_bus_info(card_index: i32) -> Option { + if card_index < 0 { + return None; + } + + let device_path = format!("/sys/class/sound/card{}/device", card_index); + let link_target = std::fs::read_link(&device_path).ok()?; + let link_str = link_target.to_string_lossy(); + + for component in link_str.split('/') { + if component.contains('-') && !component.contains(':') { + if component + .chars() + .next() + .map(|c| c.is_ascii_digit()) + .unwrap_or(false) + { + return Some(component.to_string()); + } + } + } + + None +} + +pub fn enumerate_audio_devices() -> Result> { + enumerate_audio_devices_with_current(None) +} + +pub fn enumerate_audio_devices_with_current( + current_device: Option<&str>, +) -> Result> { + let mut devices = Vec::new(); + + let cards = alsa::card::Iter::new(); + + for card_result in cards { + let card = match card_result { + Ok(c) => c, + Err(e) => { + debug!("Error iterating card: {}", e); + continue; + } + }; + + let card_index = card.get_index(); + let card_name = card.get_name().unwrap_or_else(|_| "Unknown".to_string()); + let card_longname = card.get_longname().unwrap_or_else(|_| card_name.clone()); + + debug!("Found audio card {}: {}", card_index, card_longname); + + let long_lower = card_longname.to_lowercase(); + let is_hdmi = long_lower.contains("hdmi") + || long_lower.contains("capture") + || long_lower.contains("usb"); + + let usb_bus = get_usb_bus_info(card_index); + + for device_index in 0..8 { + let device_name = format!("hw:{},{}", card_index, device_index); + let is_current_device = current_device == Some(device_name.as_str()); + + let mut push_info = + |sample_rates: Vec, channels: Vec, description: String| { + devices.push(AudioDeviceInfo { + name: device_name.clone(), + description, + card_index, + device_index, + sample_rates, + channels, + is_capture: true, + is_hdmi, + usb_bus: usb_bus.clone(), + }); + }; + + match PCM::new(&device_name, Direction::Capture, false) { + Ok(pcm) => { + let (sample_rates, channels) = query_device_caps(&pcm); + + if !sample_rates.is_empty() && !channels.is_empty() { + push_info( + sample_rates, + channels, + format!("{} - Device {}", card_longname, device_index), + ); + } + } + Err(_) => { + if is_current_device { + debug!( + "Device {} is busy (in use by us), adding with default caps", + device_name + ); + push_info( + vec![44100, 48000], + vec![2], + format!("{} - Device {} (in use)", card_longname, device_index), + ); + } + } + } + } + } + + info!("Found {} audio capture devices", devices.len()); + Ok(devices) +} + +fn query_device_caps(pcm: &PCM) -> (Vec, Vec) { + let hwp = match HwParams::any(pcm) { + Ok(h) => h, + Err(_) => return (vec![], vec![]), + }; + + let common_rates = [8000, 16000, 22050, 44100, 48000, 96000]; + let mut supported_rates = Vec::new(); + + for rate in &common_rates { + if hwp.test_rate(*rate).is_ok() { + supported_rates.push(*rate); + } + } + + let mut supported_channels = Vec::new(); + for ch in 1..=8 { + if hwp.test_channels(ch).is_ok() { + supported_channels.push(ch); + } + } + + (supported_rates, supported_channels) +} + +pub fn find_best_audio_device() -> Result { + let devices = enumerate_audio_devices()?; + + if devices.is_empty() { + return Err(AppError::AudioError( + "No audio capture devices found".to_string(), + )); + } + + let mut first_48k_stereo: Option<&AudioDeviceInfo> = None; + for device in &devices { + if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) { + continue; + } + if device.is_hdmi { + info!("Selected HDMI audio device: {}", device.description); + return Ok(device.clone()); + } + if first_48k_stereo.is_none() { + first_48k_stereo = Some(device); + } + } + if let Some(device) = first_48k_stereo { + info!("Selected audio device: {}", device.description); + return Ok(device.clone()); + } + + let device = devices.into_iter().next().unwrap(); + warn!( + "Using fallback audio device: {} (may not support optimal settings)", + device.description + ); + Ok(device) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_enumerate_devices() { + let result = enumerate_audio_devices(); + println!("Audio devices: {:?}", result); + assert!(result.is_ok()); + } +} diff --git a/src/audio/device_windows.rs b/src/audio/device_windows.rs new file mode 100644 index 00000000..8840aa5b --- /dev/null +++ b/src/audio/device_windows.rs @@ -0,0 +1,232 @@ +use cpal::traits::{DeviceTrait, HostTrait}; +use cpal::DeviceId; +use serde::Serialize; +use std::str::FromStr; +use tracing::{debug, info, warn}; + +use crate::error::{AppError, Result}; + +#[derive(Debug, Clone, Serialize)] +pub struct AudioDeviceInfo { + pub name: String, + pub description: String, + pub card_index: i32, + pub device_index: i32, + pub sample_rates: Vec, + pub channels: Vec, + pub is_capture: bool, + pub is_hdmi: bool, + pub usb_bus: Option, +} + +pub fn enumerate_audio_devices() -> Result> { + enumerate_audio_devices_with_current(None) +} + +pub fn enumerate_audio_devices_with_current( + current_device: Option<&str>, +) -> Result> { + let host = cpal::default_host(); + let devices = host + .input_devices() + .map_err(|e| AppError::AudioError(format!("Failed to enumerate WASAPI devices: {}", e)))?; + + let mut result = Vec::new(); + + for (index, device) in devices.enumerate() { + let labels = device_labels(&device); + let id = device + .id() + .map(|id| id.to_string()) + .unwrap_or_else(|_| format!("wasapi-index:{}", index)); + + let (sample_rates, channels) = query_device_caps(&device); + if sample_rates.is_empty() || channels.is_empty() { + debug!( + "Skipping WASAPI endpoint without usable input caps: {}", + labels.search_text + ); + continue; + } + + let is_current = + current_device == Some(id.as_str()) || current_device == Some(labels.display.as_str()); + let description = if is_current { + format!("{} (in use)", labels.display) + } else { + labels.display.clone() + }; + + let lower = labels.search_text.to_lowercase(); + let is_hdmi = lower.contains("hdmi") + || lower.contains("capture") + || lower.contains("usb") + || lower.contains("digital"); + + result.push(AudioDeviceInfo { + name: id, + description, + card_index: index as i32, + device_index: 0, + sample_rates, + channels, + is_capture: true, + is_hdmi, + usb_bus: None, + }); + } + + info!("Found {} WASAPI audio capture devices", result.len()); + Ok(result) +} + +fn query_device_caps(device: &cpal::Device) -> (Vec, Vec) { + let mut sample_rates = Vec::new(); + let mut channels = Vec::new(); + + if let Ok(configs) = device.supported_input_configs() { + for cfg in configs { + for rate in [8000, 16000, 22050, 44100, 48000, 96000] { + if cfg.min_sample_rate() <= rate + && rate <= cfg.max_sample_rate() + && !sample_rates.contains(&rate) + { + sample_rates.push(rate); + } + } + + let ch = cfg.channels() as u32; + if !channels.contains(&ch) { + channels.push(ch); + } + } + } + + if (sample_rates.is_empty() || channels.is_empty()) && device.default_input_config().is_ok() { + if let Ok(default_cfg) = device.default_input_config() { + if !sample_rates.contains(&default_cfg.sample_rate()) { + sample_rates.push(default_cfg.sample_rate()); + } + let ch = default_cfg.channels() as u32; + if !channels.contains(&ch) { + channels.push(ch); + } + } + } + + sample_rates.sort_unstable(); + channels.sort_unstable(); + (sample_rates, channels) +} + +struct DeviceLabels { + display: String, + search_text: String, +} + +fn device_labels(device: &cpal::Device) -> DeviceLabels { + match device.description() { + Ok(desc) => { + let formatted = desc.to_string(); + let display = desc + .extended() + .first() + .cloned() + .unwrap_or_else(|| formatted.clone()); + let mut parts = vec![formatted, desc.name().to_string(), display.clone()]; + parts.extend(desc.extended().iter().cloned()); + + DeviceLabels { + display, + search_text: parts.join(" "), + } + } + Err(_) => { + #[allow(deprecated)] + let display = device + .name() + .unwrap_or_else(|_| "Unknown WASAPI capture device".to_string()); + DeviceLabels { + display: display.clone(), + search_text: display, + } + } + } +} + +pub(crate) fn find_wasapi_device(requested_device: &str) -> Result { + let host = cpal::default_host(); + let trimmed = requested_device.trim(); + + if trimmed.is_empty() + || trimmed.eq_ignore_ascii_case("auto") + || trimmed.eq_ignore_ascii_case("default") + { + return host.default_input_device().ok_or_else(|| { + AppError::AudioError("No default WASAPI input device found".to_string()) + }); + } + + if let Ok(id) = DeviceId::from_str(trimmed) { + if let Some(device) = host.device_by_id(&id) { + return Ok(device); + } + } + + let needle = trimmed.to_lowercase(); + let devices = host + .input_devices() + .map_err(|e| AppError::AudioError(format!("Failed to enumerate WASAPI devices: {}", e)))?; + + for device in devices { + let id_match = device + .id() + .map(|id| id.to_string() == trimmed) + .unwrap_or(false); + let labels = device_labels(&device); + if id_match || labels.search_text.to_lowercase().contains(&needle) { + return Ok(device); + } + } + + Err(AppError::AudioError(format!( + "WASAPI audio device not found: {}", + requested_device + ))) +} + +pub fn find_best_audio_device() -> Result { + let devices = enumerate_audio_devices()?; + + if devices.is_empty() { + return Err(AppError::AudioError( + "No WASAPI audio capture devices found".to_string(), + )); + } + + let mut first_48k_stereo: Option<&AudioDeviceInfo> = None; + for device in &devices { + if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) { + continue; + } + if device.is_hdmi { + info!("Selected WASAPI capture device: {}", device.description); + return Ok(device.clone()); + } + if first_48k_stereo.is_none() { + first_48k_stereo = Some(device); + } + } + + if let Some(device) = first_48k_stereo { + info!("Selected WASAPI capture device: {}", device.description); + return Ok(device.clone()); + } + + let device = devices.into_iter().next().unwrap(); + warn!( + "Using fallback WASAPI audio device: {} (will resample if needed)", + device.description + ); + Ok(device) +} diff --git a/src/audio/mod.rs b/src/audio/mod.rs index 0ba10841..b6973d0e 100644 --- a/src/audio/mod.rs +++ b/src/audio/mod.rs @@ -1,15 +1,21 @@ -//! ALSA capture, Opus encode, device enumeration, streaming, controller, health monitor. +//! Platform audio capture, Opus encode, device enumeration, streaming, controller, health monitor. +#[cfg(any(unix, windows))] pub mod capture; pub mod controller; +#[cfg(any(unix, windows))] pub mod device; +#[cfg(any(unix, windows))] pub mod encoder; pub mod monitor; +pub mod recovery; pub mod streamer; +pub mod types; pub use capture::{AudioCapturer, AudioConfig, AudioFrame}; -pub use controller::{AudioController, AudioControllerConfig, AudioQuality, AudioStatus}; +pub use controller::AudioController; pub use device::{enumerate_audio_devices, enumerate_audio_devices_with_current, AudioDeviceInfo}; pub use encoder::{OpusConfig, OpusEncoder, OpusFrame}; pub use monitor::{AudioHealthMonitor, AudioHealthStatus}; pub use streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig}; +pub use types::{AudioControllerConfig, AudioQuality, AudioStatus}; diff --git a/src/audio/recovery.rs b/src/audio/recovery.rs new file mode 100644 index 00000000..d22f90fa --- /dev/null +++ b/src/audio/recovery.rs @@ -0,0 +1,320 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{debug, info, warn}; + +use super::capture::AudioConfig; +use super::device::{enumerate_audio_devices, AudioDeviceInfo}; +use super::monitor::AudioHealthMonitor; +use super::streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig}; +use super::types::AudioControllerConfig; +use super::controller::AudioRecoveredCallback; +use crate::events::{EventBus, StreamDeviceLostKind, SystemEvent}; + +const AUDIO_RECOVERY_RETRY_DELAY: std::time::Duration = std::time::Duration::from_secs(1); + +pub(super) fn select_recovery_device( + devices: &[AudioDeviceInfo], + preferred: &str, +) -> Option { + if let Some(device) = devices + .iter() + .find(|d| !preferred.trim().is_empty() && d.name == preferred) + { + return Some(device.clone()); + } + + devices + .iter() + .find(|d| d.is_hdmi && d.sample_rates.contains(&48_000) && d.channels.contains(&2)) + .or_else(|| { + devices + .iter() + .find(|d| d.sample_rates.contains(&48_000) && d.channels.contains(&2)) + }) + .or_else(|| devices.first()) + .cloned() +} + +async fn publish_state( + event_bus: &Arc>>>, + state: &str, + device: Option, + reason: Option<&str>, + next_retry_ms: Option, +) { + if let Some(bus) = event_bus.read().await.as_ref() { + bus.publish(SystemEvent::StreamStateChanged { + state: state.to_string(), + device, + reason: reason.map(str::to_string), + next_retry_ms, + }); + bus.mark_device_info_dirty(); + } +} + +async fn publish_device_lost( + event_bus: &Arc>>>, + device: &str, + reason: &str, +) { + if let Some(bus) = event_bus.read().await.as_ref() { + bus.publish(SystemEvent::StreamDeviceLost { + kind: StreamDeviceLostKind::Audio, + device: device.to_string(), + reason: reason.to_string(), + }); + } +} + +async fn publish_reconnecting( + event_bus: &Arc>>>, + device: &str, + attempt: u32, +) { + if let Some(bus) = event_bus.read().await.as_ref() { + bus.publish(SystemEvent::StreamReconnecting { + device: device.to_string(), + attempt, + }); + } +} + +async fn publish_recovered(event_bus: &Arc>>>, device: &str) { + if let Some(bus) = event_bus.read().await.as_ref() { + bus.publish(SystemEvent::StreamRecovered { + device: device.to_string(), + }); + } +} + +fn spawn_stream_monitor_from_parts( + config: Arc>, + streamer_slot: Arc>>>, + event_bus: Arc>>>, + monitor: Arc, + recovery_in_progress: Arc, + recovered_callback: Arc>>, + streamer: Arc, + device: String, +) { + let mut state_rx = streamer.state_watch(); + + tokio::spawn(async move { + loop { + if state_rx.changed().await.is_err() { + return; + } + + if *state_rx.borrow() != AudioStreamState::Error { + continue; + } + + { + let current = streamer_slot.read().await; + if !current + .as_ref() + .is_some_and(|current| Arc::ptr_eq(current, &streamer)) + { + return; + } + } + + let reason = format!("Audio device lost: {}", device); + monitor.report_error(&reason, "device_lost").await; + spawn_recovery_task_from_parts( + config, + streamer_slot, + event_bus, + monitor, + recovery_in_progress, + recovered_callback, + device, + reason, + ); + return; + } + }); +} + +fn spawn_recovery_task_from_parts( + config: Arc>, + streamer_slot: Arc>>>, + event_bus: Arc>>>, + monitor: Arc, + recovery_in_progress: Arc, + recovered_callback: Arc>>, + lost_device: String, + reason: String, +) { + if recovery_in_progress.swap(true, Ordering::SeqCst) { + debug!("Audio recovery already in progress"); + return; + } + + tokio::spawn(async move { + warn!("Audio recovery started for {}: {}", lost_device, reason); + publish_device_lost(&event_bus, &lost_device, &reason).await; + publish_state( + &event_bus, + "device_lost", + Some(lost_device.clone()), + Some("audio_device_lost"), + Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64), + ) + .await; + + let mut attempt = 0u32; + + loop { + if !recovery_in_progress.load(Ordering::SeqCst) { + debug!("Audio recovery canceled"); + return; + } + + if streamer_slot + .read() + .await + .as_ref() + .is_some_and(|s| s.is_running()) + { + recovery_in_progress.store(false, Ordering::SeqCst); + return; + } + + let cfg: AudioControllerConfig = config.read().await.clone(); + if !cfg.enabled { + recovery_in_progress.store(false, Ordering::SeqCst); + return; + } + + attempt = attempt.saturating_add(1); + publish_reconnecting(&event_bus, &lost_device, attempt).await; + publish_state( + &event_bus, + "device_lost", + Some(lost_device.clone()), + Some("audio_reconnecting"), + Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64), + ) + .await; + + tokio::time::sleep(AUDIO_RECOVERY_RETRY_DELAY).await; + + let devices = match enumerate_audio_devices() { + Ok(devices) => devices, + Err(e) => { + debug!( + "Audio recovery enumerate failed (attempt {}): {}", + attempt, e + ); + continue; + } + }; + + let Some(device) = select_recovery_device(&devices, &cfg.device) else { + debug!("No audio devices found during recovery attempt {}", attempt); + continue; + }; + + let streamer_config = AudioStreamerConfig { + capture: AudioConfig { + device_name: device.name.clone(), + ..Default::default() + }, + opus: cfg.quality.to_opus_config(), + }; + let new_streamer = Arc::new(AudioStreamer::with_config(streamer_config)); + + match new_streamer.start().await { + Ok(()) => { + { + let mut cfg = config.write().await; + cfg.device = device.name.clone(); + } + *streamer_slot.write().await = Some(new_streamer.clone()); + monitor.report_recovered().await; + publish_recovered(&event_bus, &device.name).await; + if let Some(callback) = recovered_callback.read().await.clone() { + callback(); + } + publish_state( + &event_bus, + "streaming", + Some(device.name.clone()), + None, + None, + ) + .await; + recovery_in_progress.store(false, Ordering::SeqCst); + info!( + "Audio device recovered with {} after {} attempts", + device.name, attempt + ); + spawn_stream_monitor_from_parts( + config, + streamer_slot, + event_bus, + monitor, + recovery_in_progress, + recovered_callback, + new_streamer, + device.name, + ); + return; + } + Err(e) => { + debug!( + "Audio recovery start failed with {} (attempt {}): {}", + device.name, attempt, e + ); + } + } + } + }); +} + +pub(super) fn spawn_stream_monitor( + config: Arc>, + streamer_slot: Arc>>>, + event_bus: Arc>>>, + monitor: Arc, + recovery_in_progress: Arc, + recovered_callback: Arc>>, + streamer: Arc, + device: String, +) { + spawn_stream_monitor_from_parts( + config, + streamer_slot, + event_bus, + monitor, + recovery_in_progress, + recovered_callback, + streamer, + device, + ); +} + +pub(super) fn spawn_recovery_task( + config: Arc>, + streamer_slot: Arc>>>, + event_bus: Arc>>>, + monitor: Arc, + recovery_in_progress: Arc, + recovered_callback: Arc>>, + lost_device: String, + reason: String, +) { + spawn_recovery_task_from_parts( + config, + streamer_slot, + event_bus, + monitor, + recovery_in_progress, + recovered_callback, + lost_device, + reason, + ); +} diff --git a/src/audio/types.rs b/src/audio/types.rs new file mode 100644 index 00000000..3232ef62 --- /dev/null +++ b/src/audio/types.rs @@ -0,0 +1,85 @@ +use serde::{Deserialize, Serialize}; +use std::str::FromStr; + +use super::encoder::OpusConfig; +use crate::error::AppError; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum AudioQuality { + Voice, + #[default] + Balanced, + High, +} + +impl AudioQuality { + pub fn bitrate(&self) -> u32 { + match self { + AudioQuality::Voice => 32000, + AudioQuality::Balanced => 64000, + AudioQuality::High => 128000, + } + } + + pub fn to_opus_config(&self) -> OpusConfig { + match self { + AudioQuality::Voice => OpusConfig::voice(), + AudioQuality::Balanced => OpusConfig::default(), + AudioQuality::High => OpusConfig::music(), + } + } +} + +impl FromStr for AudioQuality { + type Err = AppError; + + fn from_str(s: &str) -> std::result::Result { + match s.trim().to_lowercase().as_str() { + "voice" => Ok(Self::Voice), + "balanced" => Ok(Self::Balanced), + "high" => Ok(Self::High), + _ => Err(AppError::BadRequest(format!( + "invalid audio quality {:?} (expected voice, balanced, or high)", + s.trim() + ))), + } + } +} + +impl std::fmt::Display for AudioQuality { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AudioQuality::Voice => write!(f, "voice"), + AudioQuality::Balanced => write!(f, "balanced"), + AudioQuality::High => write!(f, "high"), + } + } +} + +#[derive(Debug, Clone)] +pub struct AudioControllerConfig { + pub enabled: bool, + pub device: String, + pub quality: AudioQuality, +} + +impl Default for AudioControllerConfig { + fn default() -> Self { + Self { + enabled: false, + device: String::new(), + quality: AudioQuality::Balanced, + } + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct AudioStatus { + pub enabled: bool, + pub streaming: bool, + pub device: Option, + pub quality: AudioQuality, + pub subscriber_count: usize, + pub error: Option, +} diff --git a/src/config/mod.rs b/src/config/mod.rs index f6e65c6e..57afda5a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,7 +1,11 @@ -mod persistence; mod schema; mod store; -pub use persistence::ConfigChange; +/// Configuration change event +#[derive(Debug, Clone)] +pub struct ConfigChange { + pub key: String, +} + pub use schema::*; pub use store::ConfigStore; diff --git a/src/config/persistence.rs b/src/config/persistence.rs deleted file mode 100644 index 105f8a76..00000000 --- a/src/config/persistence.rs +++ /dev/null @@ -1,5 +0,0 @@ -/// Configuration change event -#[derive(Debug, Clone)] -pub struct ConfigChange { - pub key: String, -} diff --git a/src/config/schema.rs b/src/config/schema.rs deleted file mode 100644 index 6dadbeb1..00000000 --- a/src/config/schema.rs +++ /dev/null @@ -1,827 +0,0 @@ -use serde::{Deserialize, Serialize}; -use typeshare::typeshare; - -// Re-export domain config types that are embedded in AppConfig. -// These are simple data types defined in their respective modules; -// keeping the re-export here is acceptable since they flow inward. -pub use crate::extensions::ExtensionsConfig; -pub use crate::rustdesk::config::RustDeskConfig; - -/// Bitrate preset for video encoding -/// -/// Simplifies bitrate configuration by providing three intuitive presets -/// plus a custom option for advanced users. -#[typeshare] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(tag = "type", content = "value")] -#[derive(Default)] -pub enum BitratePreset { - /// Speed priority: 1 Mbps, lowest latency, smaller GOP - Speed, - /// Balanced: 4 Mbps, good quality/latency tradeoff - #[default] - Balanced, - /// Quality priority: 8 Mbps, best visual quality - Quality, - /// Custom bitrate in kbps (for advanced users) - Custom(u32), -} - -impl BitratePreset { - /// Get bitrate value in kbps - pub fn bitrate_kbps(&self) -> u32 { - match self { - Self::Speed => 1000, - Self::Balanced => 4000, - Self::Quality => 8000, - Self::Custom(kbps) => *kbps, - } - } - - /// Get recommended GOP size based on preset - pub fn gop_size(&self, fps: u32) -> u32 { - match self { - Self::Speed => (fps / 2).max(15), - Self::Balanced => fps, - Self::Quality => fps * 2, - Self::Custom(_) => fps, - } - } - - /// Get quality preset name for encoder configuration - pub fn quality_level(&self) -> &'static str { - match self { - Self::Speed => "low", - Self::Balanced => "medium", - Self::Quality => "high", - Self::Custom(_) => "medium", - } - } - - /// Create from kbps value, mapping to nearest preset or Custom - pub fn from_kbps(kbps: u32) -> Self { - match kbps { - 0..=1500 => Self::Speed, - 1501..=6000 => Self::Balanced, - 6001..=10000 => Self::Quality, - _ => Self::Custom(kbps), - } - } -} - -impl std::fmt::Display for BitratePreset { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Speed => write!(f, "Speed (1 Mbps)"), - Self::Balanced => write!(f, "Balanced (4 Mbps)"), - Self::Quality => write!(f, "Quality (8 Mbps)"), - Self::Custom(kbps) => write!(f, "Custom ({} kbps)", kbps), - } - } -} - -/// Main application configuration -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(default)] -#[derive(Default)] -pub struct AppConfig { - /// Whether initial setup has been completed - pub initialized: bool, - /// Authentication settings - pub auth: AuthConfig, - /// Video capture settings - pub video: VideoConfig, - /// HID (keyboard/mouse) settings - pub hid: HidConfig, - /// Mass Storage Device settings - pub msd: MsdConfig, - /// ATX power control settings - pub atx: AtxConfig, - /// Audio settings - pub audio: AudioConfig, - /// Streaming settings - pub stream: StreamConfig, - /// Web server settings - pub web: WebConfig, - /// Extensions settings (ttyd, gostc, easytier) - pub extensions: ExtensionsConfig, - /// RustDesk remote access settings - pub rustdesk: RustDeskConfig, - /// RTSP streaming settings - pub rtsp: RtspConfig, - /// Redfish API settings - pub redfish: RedfishConfig, -} - -/// Authentication configuration -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(default)] -pub struct AuthConfig { - /// Session timeout in seconds - pub session_timeout_secs: u32, - /// Allow multiple concurrent web sessions (single-user mode) - pub single_user_allow_multiple_sessions: bool, - /// Enable 2FA - pub totp_enabled: bool, - /// TOTP secret (encrypted) - pub totp_secret: Option, -} - -impl Default for AuthConfig { - fn default() -> Self { - Self { - session_timeout_secs: 3600 * 24, // 24 hours - single_user_allow_multiple_sessions: false, - totp_enabled: false, - totp_secret: None, - } - } -} - -/// Video capture configuration -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(default)] -pub struct VideoConfig { - /// Video device path (e.g., /dev/video0) - pub device: Option, - /// Video pixel format (e.g., "MJPEG", "YUYV", "NV12") - pub format: Option, - /// Resolution width - pub width: u32, - /// Resolution height - pub height: u32, - /// Frame rate - pub fps: u32, - /// JPEG quality (1-100) - pub quality: u32, -} - -impl Default for VideoConfig { - fn default() -> Self { - Self { - device: None, - format: None, // Auto-detect or use MJPEG as default - width: 1920, - height: 1080, - fps: 30, - quality: 80, - } - } -} - -/// HID backend type -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(rename_all = "lowercase")] -#[derive(Default)] -pub enum HidBackend { - /// USB OTG HID gadget - Otg, - /// CH9329 serial HID controller - Ch9329, - /// Disabled - #[default] - None, -} - -/// OTG USB device descriptor configuration -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct OtgDescriptorConfig { - /// USB Vendor ID (e.g., 0x1d6b) - pub vendor_id: u16, - /// USB Product ID (e.g., 0x0104) - pub product_id: u16, - /// Manufacturer string - pub manufacturer: String, - /// Product string - pub product: String, - /// Serial number (optional, auto-generated if not set) - pub serial_number: Option, -} - -impl Default for OtgDescriptorConfig { - fn default() -> Self { - Self { - vendor_id: 0x1d6b, // Linux Foundation - product_id: 0x0104, // Multifunction Composite Gadget - manufacturer: "One-KVM".to_string(), - product: "One-KVM USB Device".to_string(), - serial_number: None, - } - } -} - -/// OTG HID function profile -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(rename_all = "snake_case")] -#[derive(Default)] -pub enum OtgHidProfile { - /// Full HID device set (keyboard + relative mouse + absolute mouse + consumer control) - #[default] - #[serde(alias = "full_no_msd")] - Full, - /// Full HID device set without consumer control - #[serde(alias = "full_no_consumer_no_msd")] - FullNoConsumer, - /// Legacy profile: only keyboard - LegacyKeyboard, - /// Legacy profile: only relative mouse - LegacyMouseRelative, - /// Custom function selection - Custom, -} - -/// OTG endpoint budget policy. -#[typeshare] -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -#[derive(Default)] -pub enum OtgEndpointBudget { - /// Derive a safe default from the selected UDC. - #[default] - Auto, - /// Limit OTG gadget functions to 5 endpoints. - Five, - /// Limit OTG gadget functions to 6 endpoints. - Six, - /// Do not impose a software endpoint budget. - Unlimited, -} - -impl OtgEndpointBudget { - /// Resolve endpoint limit assuming a known budget variant (not Auto). - pub fn endpoint_limit_raw(&self) -> Option { - match self { - Self::Five => Some(5), - Self::Six => Some(6), - Self::Unlimited => None, - Self::Auto => None, // resolved via `HidConfig::resolved_otg_endpoint_limit` - } - } -} - -/// OTG HID function selection (used when profile is Custom) -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(default)] -pub struct OtgHidFunctions { - pub keyboard: bool, - pub mouse_relative: bool, - pub mouse_absolute: bool, - pub consumer: bool, -} - -impl OtgHidFunctions { - pub fn full() -> Self { - Self { - keyboard: true, - mouse_relative: true, - mouse_absolute: true, - consumer: true, - } - } - - pub fn full_no_consumer() -> Self { - Self { - keyboard: true, - mouse_relative: true, - mouse_absolute: true, - consumer: false, - } - } - - pub fn legacy_keyboard() -> Self { - Self { - keyboard: true, - mouse_relative: false, - mouse_absolute: false, - consumer: false, - } - } - - pub fn legacy_mouse_relative() -> Self { - Self { - keyboard: false, - mouse_relative: true, - mouse_absolute: false, - consumer: false, - } - } - - pub fn is_empty(&self) -> bool { - !self.keyboard && !self.mouse_relative && !self.mouse_absolute && !self.consumer - } - - pub fn endpoint_cost(&self, keyboard_leds: bool) -> u8 { - let mut endpoints = 0; - if self.keyboard { - endpoints += 1; - if keyboard_leds { - endpoints += 1; - } - } - if self.mouse_relative { - endpoints += 1; - } - if self.mouse_absolute { - endpoints += 1; - } - if self.consumer { - endpoints += 1; - } - endpoints - } -} - -impl Default for OtgHidFunctions { - fn default() -> Self { - Self::full() - } -} - -impl OtgHidProfile { - pub fn from_legacy_str(value: &str) -> Option { - match value { - "full" | "full_no_msd" => Some(Self::Full), - "full_no_consumer" | "full_no_consumer_no_msd" => Some(Self::FullNoConsumer), - "legacy_keyboard" => Some(Self::LegacyKeyboard), - "legacy_mouse_relative" => Some(Self::LegacyMouseRelative), - "custom" => Some(Self::Custom), - _ => None, - } - } - - pub fn resolve_functions(&self, custom: &OtgHidFunctions) -> OtgHidFunctions { - match self { - Self::Full => OtgHidFunctions::full(), - Self::FullNoConsumer => OtgHidFunctions::full_no_consumer(), - Self::LegacyKeyboard => OtgHidFunctions::legacy_keyboard(), - Self::LegacyMouseRelative => OtgHidFunctions::legacy_mouse_relative(), - Self::Custom => custom.clone(), - } - } -} - -/// HID configuration -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(default)] -pub struct HidConfig { - /// HID backend type - pub backend: HidBackend, - /// OTG UDC (USB Device Controller) name - pub otg_udc: Option, - /// OTG USB device descriptor configuration - #[serde(default)] - pub otg_descriptor: OtgDescriptorConfig, - /// OTG HID function profile - #[serde(default)] - pub otg_profile: OtgHidProfile, - /// OTG endpoint budget policy - #[serde(default)] - pub otg_endpoint_budget: OtgEndpointBudget, - /// OTG HID function selection (used when profile is Custom) - #[serde(default)] - pub otg_functions: OtgHidFunctions, - /// Enable keyboard LED/status feedback for OTG keyboard - #[serde(default)] - pub otg_keyboard_leds: bool, - /// CH9329 serial port - pub ch9329_port: String, - /// CH9329 baud rate - pub ch9329_baudrate: u32, - /// Mouse mode: absolute or relative - pub mouse_absolute: bool, -} - -impl Default for HidConfig { - fn default() -> Self { - Self { - backend: HidBackend::None, - otg_udc: None, - otg_descriptor: OtgDescriptorConfig::default(), - otg_profile: OtgHidProfile::default(), - otg_endpoint_budget: OtgEndpointBudget::default(), - otg_functions: OtgHidFunctions::default(), - otg_keyboard_leds: false, - ch9329_port: "/dev/ttyUSB0".to_string(), - ch9329_baudrate: 9600, - mouse_absolute: true, - } - } -} - -impl HidConfig { - /// Resolve effective OTG HID functions from profile + custom selection. - /// Pure logic, no external dependency. - pub fn effective_otg_functions(&self) -> OtgHidFunctions { - self.otg_profile.resolve_functions(&self.otg_functions) - } - - /// Whether keyboard LED feedback is effectively enabled. - pub fn effective_otg_keyboard_leds(&self) -> bool { - self.otg_keyboard_leds && self.effective_otg_functions().keyboard - } - - /// Effective HID functions after applying all constraints. - pub fn constrained_otg_functions(&self) -> OtgHidFunctions { - self.effective_otg_functions() - } - - /// Calculate required endpoint count for the current function selection. - pub fn effective_otg_required_endpoints(&self, msd_enabled: bool) -> u8 { - let functions = self.effective_otg_functions(); - let mut endpoints = functions.endpoint_cost(self.effective_otg_keyboard_leds()); - if msd_enabled { - endpoints += 2; - } - endpoints - } - - /// Validate endpoint budget for the current OTG configuration (UDC-aware when budget is Auto). - pub fn validate_otg_endpoint_budget(&self, msd_enabled: bool) -> crate::error::Result<()> { - if self.backend != HidBackend::Otg { - return Ok(()); - } - - let functions = self.effective_otg_functions(); - if functions.is_empty() { - return Err(crate::error::AppError::BadRequest( - "OTG HID functions cannot be empty".to_string(), - )); - } - - let resolved_limit = self.resolved_otg_endpoint_limit(); - let required = self.effective_otg_required_endpoints(msd_enabled); - if let Some(limit) = resolved_limit { - if required > limit { - return Err(crate::error::AppError::BadRequest(format!( - "OTG selection requires {} endpoints, but the configured limit is {}", - required, limit - ))); - } - } - - Ok(()) - } - - /// Effective OTG UDC name (for change detection / service). - #[inline] - pub fn resolved_otg_udc(&self) -> Option { - if self.backend != HidBackend::Otg { - return None; - } - self.otg_udc - .as_ref() - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .or_else(|| crate::otg::OtgGadgetManager::find_udc()) - } - - /// Resolved endpoint limit used for OTG gadget allocator / validation. - #[inline] - pub fn resolved_otg_endpoint_limit(&self) -> Option { - if self.backend != HidBackend::Otg { - return None; - } - match self.otg_endpoint_budget { - OtgEndpointBudget::Five => Some(5), - OtgEndpointBudget::Six => Some(6), - OtgEndpointBudget::Unlimited => None, - OtgEndpointBudget::Auto => { - let udc = self.resolved_otg_udc().unwrap_or_default(); - if crate::otg::configfs::is_low_endpoint_udc(&udc) { - Some(5) - } else { - Some(6) - } - } - } - } -} - -/// MSD configuration -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(default)] -pub struct MsdConfig { - /// Enable MSD functionality - pub enabled: bool, - /// MSD base directory (absolute path) - pub msd_dir: String, -} - -impl Default for MsdConfig { - fn default() -> Self { - Self { - enabled: true, - msd_dir: String::new(), - } - } -} - -impl MsdConfig { - pub fn msd_dir_path(&self) -> std::path::PathBuf { - std::path::PathBuf::from(&self.msd_dir) - } - - pub fn images_dir(&self) -> std::path::PathBuf { - self.msd_dir_path().join("images") - } - - pub fn ventoy_dir(&self) -> std::path::PathBuf { - self.msd_dir_path().join("ventoy") - } - - pub fn drive_path(&self) -> std::path::PathBuf { - self.ventoy_dir().join("ventoy.img") - } -} - -// Re-export ATX types from atx module for configuration -pub use crate::atx::{ActiveLevel, AtxDriverType, AtxKeyConfig, AtxLedConfig}; - -/// ATX power control configuration -/// -/// Each ATX action (power, reset) can be independently configured with its own -/// hardware binding using the four-tuple: (driver, device, pin, active_level). -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(default)] -#[derive(Default)] -pub struct AtxConfig { - /// Enable ATX functionality - pub enabled: bool, - /// Power button configuration (used for both short and long press) - pub power: AtxKeyConfig, - /// Reset button configuration - pub reset: AtxKeyConfig, - /// LED sensing configuration (optional) - pub led: AtxLedConfig, - /// Network interface for WOL packets (empty = auto) - pub wol_interface: String, -} - -impl AtxConfig { - /// Convert to AtxControllerConfig for the controller - pub fn to_controller_config(&self) -> crate::atx::AtxControllerConfig { - crate::atx::AtxControllerConfig { - enabled: self.enabled, - power: self.power.clone(), - reset: self.reset.clone(), - led: self.led.clone(), - } - } -} - -/// Audio configuration -/// -/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo). -/// These are optimal for Opus encoding and match WebRTC requirements. -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(default)] -pub struct AudioConfig { - /// Enable audio capture - pub enabled: bool, - /// ALSA device name - pub device: String, - /// Audio quality preset: "voice", "balanced", "high" - pub quality: String, -} - -impl Default for AudioConfig { - fn default() -> Self { - Self { - enabled: false, - device: String::new(), - quality: "balanced".to_string(), - } - } -} - -/// Stream mode -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(rename_all = "lowercase")] -#[derive(Default)] -pub enum StreamMode { - /// WebRTC with H264/H265 - WebRTC, - /// MJPEG over HTTP - #[default] - Mjpeg, -} - -/// RTSP output codec -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(rename_all = "lowercase")] -#[derive(Default)] -pub enum RtspCodec { - #[default] - H264, - H265, -} - -/// RTSP configuration -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(default)] -pub struct RtspConfig { - /// Enable RTSP output - pub enabled: bool, - /// Bind IP address - pub bind: String, - /// RTSP TCP listen port - pub port: u16, - /// Stream path (without leading slash) - pub path: String, - /// Allow only one client connection at a time - pub allow_one_client: bool, - /// Output codec (H264/H265) - pub codec: RtspCodec, - /// Optional username for authentication - pub username: Option, - /// Optional password for authentication - #[typeshare(skip)] - pub password: Option, -} - -impl Default for RtspConfig { - fn default() -> Self { - Self { - enabled: false, - bind: "0.0.0.0".to_string(), - port: 8554, - path: "live".to_string(), - allow_one_client: true, - codec: RtspCodec::H264, - username: None, - password: None, - } - } -} - -/// Encoder type -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(rename_all = "lowercase")] -#[derive(Default)] -pub enum EncoderType { - /// Auto-detect best encoder - #[default] - Auto, - /// Software encoder (libx264) - Software, - /// VAAPI hardware encoder - Vaapi, - /// NVIDIA NVENC hardware encoder - Nvenc, - /// Intel Quick Sync hardware encoder - Qsv, - /// AMD AMF hardware encoder - Amf, - /// Rockchip MPP hardware encoder - Rkmpp, - /// V4L2 M2M hardware encoder - V4l2m2m, -} - -impl EncoderType { - /// Get display name for UI - pub fn display_name(&self) -> &'static str { - match self { - EncoderType::Auto => "Auto (Recommended)", - EncoderType::Software => "Software (CPU)", - EncoderType::Vaapi => "VAAPI", - EncoderType::Nvenc => "NVIDIA NVENC", - EncoderType::Qsv => "Intel Quick Sync", - EncoderType::Amf => "AMD AMF", - EncoderType::Rkmpp => "Rockchip MPP", - EncoderType::V4l2m2m => "V4L2 M2M", - } - } -} - -/// Streaming configuration -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(default)] -pub struct StreamConfig { - /// Stream mode - pub mode: StreamMode, - /// Encoder type for H264/H265 - pub encoder: EncoderType, - /// Bitrate preset (Speed/Balanced/Quality) - pub bitrate_preset: BitratePreset, - /// Custom STUN server (e.g., "stun:stun.l.google.com:19302") - /// If empty, uses public ICE servers from secrets.toml - pub stun_server: Option, - /// Custom TURN server (e.g., "turn:turn.example.com:3478") - /// If empty, uses public ICE servers from secrets.toml - pub turn_server: Option, - /// TURN username - pub turn_username: Option, - /// TURN password (stored encrypted in DB, not exposed via API) - pub turn_password: Option, - /// Auto-pause when no clients connected - #[typeshare(skip)] - pub auto_pause_enabled: bool, - /// Auto-pause delay (seconds) - #[typeshare(skip)] - pub auto_pause_delay_secs: u64, - /// Client timeout for cleanup (seconds) - #[typeshare(skip)] - pub client_timeout_secs: u64, -} - -impl Default for StreamConfig { - fn default() -> Self { - Self { - mode: StreamMode::Mjpeg, - encoder: EncoderType::Auto, - bitrate_preset: BitratePreset::Balanced, - // Empty means use public ICE servers (like RustDesk) - stun_server: None, - turn_server: None, - turn_username: None, - turn_password: None, - auto_pause_enabled: false, - auto_pause_delay_secs: 10, - client_timeout_secs: 30, - } - } -} - -impl StreamConfig { - /// Whether built-in / public ICE is used (no custom STUN or TURN URL configured). - pub fn is_using_public_ice_servers(&self) -> bool { - let no_custom_stun = self - .stun_server - .as_ref() - .map_or(true, |s| s.trim().is_empty()); - let no_custom_turn = self - .turn_server - .as_ref() - .map_or(true, |s| s.trim().is_empty()); - no_custom_stun && no_custom_turn - } -} - -/// Web server configuration persisted in the database (includes on-disk TLS paths). -/// -/// The HTTP API for `/api/config/web` uses `WebConfigResponse` instead: no path fields, includes `has_custom_cert`. -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(default)] -pub struct WebConfig { - /// HTTP port - pub http_port: u16, - /// HTTPS port - pub https_port: u16, - /// Bind addresses (preferred) - pub bind_addresses: Vec, - /// Bind address (legacy) - pub bind_address: String, - /// Enable HTTPS - pub https_enabled: bool, - /// Custom SSL certificate path - pub ssl_cert_path: Option, - /// Custom SSL key path - pub ssl_key_path: Option, -} - -impl Default for WebConfig { - fn default() -> Self { - Self { - http_port: 8080, - https_port: 8443, - bind_addresses: Vec::new(), - bind_address: "0.0.0.0".to_string(), - https_enabled: false, - ssl_cert_path: None, - ssl_key_path: None, - } - } -} - -/// Redfish API configuration -#[typeshare] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(default)] -pub struct RedfishConfig { - /// Enable Redfish API endpoint - pub enabled: bool, -} - -impl Default for RedfishConfig { - fn default() -> Self { - Self { enabled: false } - } -} diff --git a/src/config/schema/atx.rs b/src/config/schema/atx.rs new file mode 100644 index 00000000..d8913393 --- /dev/null +++ b/src/config/schema/atx.rs @@ -0,0 +1,28 @@ +use serde::{Deserialize, Serialize}; +use typeshare::typeshare; + +pub use crate::atx::{ActiveLevel, AtxDriverType, AtxKeyConfig, AtxLedConfig}; + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +#[derive(Default)] +pub struct AtxConfig { + pub enabled: bool, + pub power: AtxKeyConfig, + pub reset: AtxKeyConfig, + pub led: AtxLedConfig, + pub wol_interface: String, +} + +impl AtxConfig { + pub fn to_controller_config(&self) -> crate::atx::AtxControllerConfig { + crate::atx::AtxControllerConfig { + enabled: self.enabled, + power: self.power.clone(), + reset: self.reset.clone(), + led: self.led.clone(), + } + } +} + diff --git a/src/config/schema/common.rs b/src/config/schema/common.rs new file mode 100644 index 00000000..2ce3cc2a --- /dev/null +++ b/src/config/schema/common.rs @@ -0,0 +1,64 @@ +use serde::{Deserialize, Serialize}; +use typeshare::typeshare; + +#[typeshare] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", content = "value")] +#[derive(Default)] +pub enum BitratePreset { + Speed, + #[default] + Balanced, + Quality, + Custom(u32), +} + +impl BitratePreset { + pub fn bitrate_kbps(&self) -> u32 { + match self { + Self::Speed => 1000, + Self::Balanced => 4000, + Self::Quality => 8000, + Self::Custom(kbps) => *kbps, + } + } + + pub fn gop_size(&self, fps: u32) -> u32 { + match self { + Self::Speed => (fps / 2).max(15), + Self::Balanced => fps, + Self::Quality => fps * 2, + Self::Custom(_) => fps, + } + } + + pub fn quality_level(&self) -> &'static str { + match self { + Self::Speed => "low", + Self::Balanced => "medium", + Self::Quality => "high", + Self::Custom(_) => "medium", + } + } + + pub fn from_kbps(kbps: u32) -> Self { + match kbps { + 0..=1500 => Self::Speed, + 1501..=6000 => Self::Balanced, + 6001..=10000 => Self::Quality, + _ => Self::Custom(kbps), + } + } +} + +impl std::fmt::Display for BitratePreset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Speed => write!(f, "Speed (1 Mbps)"), + Self::Balanced => write!(f, "Balanced (4 Mbps)"), + Self::Quality => write!(f, "Quality (8 Mbps)"), + Self::Custom(kbps) => write!(f, "Custom ({} kbps)", kbps), + } + } +} + diff --git a/src/config/schema/hid.rs b/src/config/schema/hid.rs new file mode 100644 index 00000000..7dae19e6 --- /dev/null +++ b/src/config/schema/hid.rs @@ -0,0 +1,309 @@ +use serde::{Deserialize, Serialize}; +use typeshare::typeshare; + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +#[derive(Default)] +pub enum HidBackend { + Otg, + Ch9329, + #[default] + None, +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct OtgDescriptorConfig { + pub vendor_id: u16, + pub product_id: u16, + pub manufacturer: String, + pub product: String, + pub serial_number: Option, +} + +impl Default for OtgDescriptorConfig { + fn default() -> Self { + Self { + vendor_id: 0x1d6b, + product_id: 0x0104, + manufacturer: "One-KVM".to_string(), + product: "One-KVM USB Device".to_string(), + serial_number: None, + } + } +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +#[derive(Default)] +pub enum OtgHidProfile { + #[default] + #[serde(alias = "full_no_msd")] + Full, + #[serde(alias = "full_no_consumer_no_msd")] + FullNoConsumer, + LegacyKeyboard, + LegacyMouseRelative, + Custom, +} + +#[typeshare] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +#[derive(Default)] +pub enum OtgEndpointBudget { + #[default] + Auto, + Five, + Six, + Unlimited, +} + +impl OtgEndpointBudget { + pub fn endpoint_limit_raw(&self) -> Option { + match self { + Self::Five => Some(5), + Self::Six => Some(6), + Self::Unlimited => None, + Self::Auto => None, + } + } +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(default)] +pub struct OtgHidFunctions { + pub keyboard: bool, + pub mouse_relative: bool, + pub mouse_absolute: bool, + pub consumer: bool, +} + +impl OtgHidFunctions { + pub fn full() -> Self { + Self { + keyboard: true, + mouse_relative: true, + mouse_absolute: true, + consumer: true, + } + } + + pub fn full_no_consumer() -> Self { + Self { + keyboard: true, + mouse_relative: true, + mouse_absolute: true, + consumer: false, + } + } + + pub fn legacy_keyboard() -> Self { + Self { + keyboard: true, + mouse_relative: false, + mouse_absolute: false, + consumer: false, + } + } + + pub fn legacy_mouse_relative() -> Self { + Self { + keyboard: false, + mouse_relative: true, + mouse_absolute: false, + consumer: false, + } + } + + pub fn is_empty(&self) -> bool { + !self.keyboard && !self.mouse_relative && !self.mouse_absolute && !self.consumer + } + + pub fn endpoint_cost(&self, keyboard_leds: bool) -> u8 { + let mut endpoints = 0; + if self.keyboard { + endpoints += 1; + if keyboard_leds { + endpoints += 1; + } + } + if self.mouse_relative { + endpoints += 1; + } + if self.mouse_absolute { + endpoints += 1; + } + if self.consumer { + endpoints += 1; + } + endpoints + } +} + +impl Default for OtgHidFunctions { + fn default() -> Self { + Self::full() + } +} + +impl OtgHidProfile { + pub fn from_legacy_str(value: &str) -> Option { + match value { + "full" | "full_no_msd" => Some(Self::Full), + "full_no_consumer" | "full_no_consumer_no_msd" => Some(Self::FullNoConsumer), + "legacy_keyboard" => Some(Self::LegacyKeyboard), + "legacy_mouse_relative" => Some(Self::LegacyMouseRelative), + "custom" => Some(Self::Custom), + _ => None, + } + } + + pub fn resolve_functions(&self, custom: &OtgHidFunctions) -> OtgHidFunctions { + match self { + Self::Full => OtgHidFunctions::full(), + Self::FullNoConsumer => OtgHidFunctions::full_no_consumer(), + Self::LegacyKeyboard => OtgHidFunctions::legacy_keyboard(), + Self::LegacyMouseRelative => OtgHidFunctions::legacy_mouse_relative(), + Self::Custom => custom.clone(), + } + } +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(default)] +pub struct HidConfig { + pub backend: HidBackend, + pub otg_udc: Option, + #[serde(default)] + pub otg_descriptor: OtgDescriptorConfig, + #[serde(default)] + pub otg_profile: OtgHidProfile, + #[serde(default)] + pub otg_endpoint_budget: OtgEndpointBudget, + #[serde(default)] + pub otg_functions: OtgHidFunctions, + #[serde(default)] + pub otg_keyboard_leds: bool, + pub ch9329_port: String, + pub ch9329_baudrate: u32, + pub mouse_absolute: bool, +} + +impl Default for HidConfig { + fn default() -> Self { + Self { + backend: HidBackend::None, + otg_udc: None, + otg_descriptor: OtgDescriptorConfig::default(), + otg_profile: OtgHidProfile::default(), + otg_endpoint_budget: OtgEndpointBudget::default(), + otg_functions: OtgHidFunctions::default(), + otg_keyboard_leds: false, + ch9329_port: "/dev/ttyUSB0".to_string(), + ch9329_baudrate: 9600, + mouse_absolute: true, + } + } +} + +impl HidConfig { + pub fn effective_otg_functions(&self) -> OtgHidFunctions { + self.otg_profile.resolve_functions(&self.otg_functions) + } + + pub fn effective_otg_keyboard_leds(&self) -> bool { + self.otg_keyboard_leds && self.effective_otg_functions().keyboard + } + + pub fn constrained_otg_functions(&self) -> OtgHidFunctions { + self.effective_otg_functions() + } + + pub fn effective_otg_required_endpoints(&self, msd_enabled: bool) -> u8 { + let functions = self.effective_otg_functions(); + let mut endpoints = functions.endpoint_cost(self.effective_otg_keyboard_leds()); + if msd_enabled { + endpoints += 2; + } + endpoints + } + + pub fn validate_otg_endpoint_budget(&self, msd_enabled: bool) -> crate::error::Result<()> { + if self.backend != HidBackend::Otg { + return Ok(()); + } + + let functions = self.effective_otg_functions(); + if functions.is_empty() { + return Err(crate::error::AppError::BadRequest( + "OTG HID functions cannot be empty".to_string(), + )); + } + + let resolved_limit = self.resolved_otg_endpoint_limit(); + let required = self.effective_otg_required_endpoints(msd_enabled); + if let Some(limit) = resolved_limit { + if required > limit { + return Err(crate::error::AppError::BadRequest(format!( + "OTG selection requires {} endpoints, but the configured limit is {}", + required, limit + ))); + } + } + + Ok(()) + } + + #[inline] + pub fn resolved_otg_udc(&self) -> Option { + if self.backend != HidBackend::Otg { + return None; + } + self.otg_udc + .as_ref() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .or_else(|| { + #[cfg(unix)] + { + crate::otg::OtgGadgetManager::find_udc() + } + #[cfg(not(unix))] + { + None + } + }) + } + + #[inline] + pub fn resolved_otg_endpoint_limit(&self) -> Option { + if self.backend != HidBackend::Otg { + return None; + } + match self.otg_endpoint_budget { + OtgEndpointBudget::Five => Some(5), + OtgEndpointBudget::Six => Some(6), + OtgEndpointBudget::Unlimited => None, + OtgEndpointBudget::Auto => { + #[cfg(unix)] + let udc = self.resolved_otg_udc().unwrap_or_default(); + #[cfg(unix)] + if crate::otg::configfs::is_low_endpoint_udc(&udc) { + Some(5) + } else { + Some(6) + } + #[cfg(not(unix))] + { + Some(6) + } + } + } + } +} + diff --git a/src/config/schema/mod.rs b/src/config/schema/mod.rs new file mode 100644 index 00000000..52b013b6 --- /dev/null +++ b/src/config/schema/mod.rs @@ -0,0 +1,44 @@ +use serde::{Deserialize, Serialize}; +use typeshare::typeshare; + +pub use crate::extensions::ExtensionsConfig; +pub use crate::rustdesk::config::RustDeskConfig; + +mod atx; +mod common; +mod hid; +mod stream; +mod web; + +pub use atx::*; +pub use common::*; +pub use hid::*; +pub use stream::*; +pub use web::*; + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +#[derive(Default)] +pub struct AppConfig { + pub initialized: bool, + pub auth: AuthConfig, + pub video: VideoConfig, + pub hid: HidConfig, + pub msd: MsdConfig, + pub atx: AtxConfig, + pub audio: AudioConfig, + pub stream: StreamConfig, + pub web: WebConfig, + pub extensions: ExtensionsConfig, + pub rustdesk: RustDeskConfig, + pub rtsp: RtspConfig, + pub redfish: RedfishConfig, +} + +impl AppConfig { + pub fn apply_platform_defaults(&mut self) { + crate::platform::defaults::apply(self); + } +} + diff --git a/src/config/schema/stream.rs b/src/config/schema/stream.rs new file mode 100644 index 00000000..aefec097 --- /dev/null +++ b/src/config/schema/stream.rs @@ -0,0 +1,149 @@ +use serde::{Deserialize, Serialize}; +use typeshare::typeshare; + +use super::BitratePreset; + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +#[derive(Default)] +pub enum StreamMode { + WebRTC, + #[default] + Mjpeg, +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +#[derive(Default)] +pub enum RtspCodec { + #[default] + H264, + H265, +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct RtspConfig { + pub enabled: bool, + pub bind: String, + pub port: u16, + pub path: String, + pub allow_one_client: bool, + pub codec: RtspCodec, + pub username: Option, + #[typeshare(skip)] + pub password: Option, +} + +impl Default for RtspConfig { + fn default() -> Self { + Self { + enabled: false, + bind: "0.0.0.0".to_string(), + port: 8554, + path: "live".to_string(), + allow_one_client: true, + codec: RtspCodec::H264, + username: None, + password: None, + } + } +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +#[derive(Default)] +pub enum EncoderType { + #[default] + Auto, + Software, + Vaapi, + Nvenc, + Qsv, + Amf, + Rkmpp, + V4l2m2m, +} + +impl EncoderType { + pub fn display_name(&self) -> &'static str { + match self { + EncoderType::Auto => "Auto (Recommended)", + EncoderType::Software => "Software (CPU)", + EncoderType::Vaapi => "VAAPI", + EncoderType::Nvenc => "NVIDIA NVENC", + EncoderType::Qsv => "Intel Quick Sync", + EncoderType::Amf => "AMD AMF", + EncoderType::Rkmpp => "Rockchip MPP", + EncoderType::V4l2m2m => "V4L2 M2M", + } + } +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct StreamConfig { + pub mode: StreamMode, + pub encoder: EncoderType, + pub bitrate_preset: BitratePreset, + pub stun_server: Option, + pub turn_server: Option, + pub turn_username: Option, + pub turn_password: Option, + #[typeshare(skip)] + pub auto_pause_enabled: bool, + #[typeshare(skip)] + pub auto_pause_delay_secs: u64, + #[typeshare(skip)] + pub client_timeout_secs: u64, +} + +impl Default for StreamConfig { + fn default() -> Self { + Self { + mode: StreamMode::Mjpeg, + encoder: EncoderType::Auto, + bitrate_preset: BitratePreset::Balanced, + stun_server: None, + turn_server: None, + turn_username: None, + turn_password: None, + auto_pause_enabled: false, + auto_pause_delay_secs: 10, + client_timeout_secs: 30, + } + } +} + +impl StreamConfig { + pub fn is_using_public_ice_servers(&self) -> bool { + let no_custom_stun = self + .stun_server + .as_ref() + .map_or(true, |s| s.trim().is_empty()); + let no_custom_turn = self + .turn_server + .as_ref() + .map_or(true, |s| s.trim().is_empty()); + no_custom_stun && no_custom_turn + } +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct RedfishConfig { + pub enabled: bool, +} + +impl Default for RedfishConfig { + fn default() -> Self { + Self { enabled: false } + } +} + diff --git a/src/config/schema/web.rs b/src/config/schema/web.rs new file mode 100644 index 00000000..d835ccf1 --- /dev/null +++ b/src/config/schema/web.rs @@ -0,0 +1,129 @@ +use serde::{Deserialize, Serialize}; +use typeshare::typeshare; + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct AuthConfig { + pub session_timeout_secs: u32, + pub single_user_allow_multiple_sessions: bool, + pub totp_enabled: bool, + pub totp_secret: Option, +} + +impl Default for AuthConfig { + fn default() -> Self { + Self { + session_timeout_secs: 3600 * 24, + single_user_allow_multiple_sessions: false, + totp_enabled: false, + totp_secret: None, + } + } +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(default)] +pub struct VideoConfig { + pub device: Option, + pub format: Option, + pub width: u32, + pub height: u32, + pub fps: u32, + pub quality: u32, +} + +impl Default for VideoConfig { + fn default() -> Self { + Self { + device: None, + format: None, + width: 1920, + height: 1080, + fps: 30, + quality: 80, + } + } +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct MsdConfig { + pub enabled: bool, + pub msd_dir: String, +} + +impl Default for MsdConfig { + fn default() -> Self { + Self { + enabled: true, + msd_dir: String::new(), + } + } +} + +impl MsdConfig { + pub fn msd_dir_path(&self) -> std::path::PathBuf { + std::path::PathBuf::from(&self.msd_dir) + } + + pub fn images_dir(&self) -> std::path::PathBuf { + self.msd_dir_path().join("images") + } + + pub fn ventoy_dir(&self) -> std::path::PathBuf { + self.msd_dir_path().join("ventoy") + } + + pub fn drive_path(&self) -> std::path::PathBuf { + self.ventoy_dir().join("ventoy.img") + } +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct AudioConfig { + pub enabled: bool, + pub device: String, + pub quality: String, +} + +impl Default for AudioConfig { + fn default() -> Self { + Self { + enabled: false, + device: String::new(), + quality: "balanced".to_string(), + } + } +} + +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct WebConfig { + pub http_port: u16, + pub https_port: u16, + pub bind_addresses: Vec, + pub bind_address: String, + pub https_enabled: bool, + pub ssl_cert_path: Option, + pub ssl_key_path: Option, +} + +impl Default for WebConfig { + fn default() -> Self { + Self { + http_port: 8080, + https_port: 8443, + bind_addresses: Vec::new(), + bind_address: "0.0.0.0".to_string(), + https_enabled: false, + ssl_cert_path: None, + ssl_key_path: None, + } + } +} diff --git a/src/config/store.rs b/src/config/store.rs index 00edfca3..1ec7e73b 100644 --- a/src/config/store.rs +++ b/src/config/store.rs @@ -4,29 +4,20 @@ use std::sync::Arc; use tokio::sync::broadcast; use tokio::sync::Mutex; -use super::persistence::ConfigChange; use super::AppConfig; +use super::ConfigChange; use crate::error::{AppError, Result}; -/// Configuration store backed by SQLite -/// -/// Uses `ArcSwap` for lock-free reads, providing high performance -/// for frequent configuration access in hot paths. #[derive(Clone)] pub struct ConfigStore { pool: Pool, - /// Lock-free cache using ArcSwap for zero-cost reads cache: Arc>, change_tx: broadcast::Sender, - /// Serializes `set` / `update` so concurrent PATCH handlers cannot clobber each other write_lock: Arc>, } impl ConfigStore { - /// Create a new configuration store pub fn new(pool: Pool) -> Result { - // Load or create default config synchronously wrapper - // (actual DB load is async, handled in init()) Ok(Self { pool, cache: Arc::new(ArcSwap::from_pointee(AppConfig::default())), @@ -35,14 +26,12 @@ impl ConfigStore { }) } - /// Load configuration from database (call after new()) pub async fn load(&self) -> Result<()> { let config = Self::load_config(&self.pool).await?; self.cache.store(Arc::new(config)); Ok(()) } - /// Load configuration from database async fn load_config(pool: &Pool) -> Result { let row: Option<(String,)> = sqlx::query_as("SELECT value FROM config WHERE key = 'app_config'") @@ -54,7 +43,6 @@ impl ConfigStore { serde_json::from_str(&json).map_err(|e| AppError::Config(e.to_string())) } None => { - // Create default config let config = AppConfig::default(); Self::save_config_to_db(pool, &config).await?; Ok(config) @@ -62,7 +50,6 @@ impl ConfigStore { } } - /// Save configuration to database async fn save_config_to_db(pool: &Pool, config: &AppConfig) -> Result<()> { let json = serde_json::to_string(config)?; @@ -80,21 +67,15 @@ impl ConfigStore { Ok(()) } - /// Get current configuration (lock-free, zero-copy) - /// - /// Returns an `Arc` for efficient sharing without cloning. - /// This is a lock-free operation with minimal overhead. pub fn get(&self) -> Arc { self.cache.load_full() } - /// Set entire configuration pub async fn set(&self, config: AppConfig) -> Result<()> { let _guard = self.write_lock.lock().await; Self::save_config_to_db(&self.pool, &config).await?; self.cache.store(Arc::new(config)); - // Notify subscribers let _ = self.change_tx.send(ConfigChange { key: "app_config".to_string(), }); @@ -102,27 +83,19 @@ impl ConfigStore { Ok(()) } - /// Update configuration with a closure - /// - /// Uses read-modify-write under a mutex so concurrent `update` / `set` calls are serialized - /// and merged correctly (each closure sees the latest stored config). pub async fn update(&self, f: F) -> Result<()> where F: FnOnce(&mut AppConfig), { let _guard = self.write_lock.lock().await; - // Load current config, clone it for modification let current = self.cache.load(); let mut config = (**current).clone(); f(&mut config); - // Persist to database first Self::save_config_to_db(&self.pool, &config).await?; - // Then update cache atomically self.cache.store(Arc::new(config)); - // Notify subscribers let _ = self.change_tx.send(ConfigChange { key: "app_config".to_string(), }); @@ -130,12 +103,10 @@ impl ConfigStore { Ok(()) } - /// Subscribe to configuration changes pub fn subscribe(&self) -> broadcast::Receiver { self.change_tx.subscribe() } - /// Check if system is initialized (lock-free) pub fn is_initialized(&self) -> bool { self.cache.load().initialized } @@ -158,11 +129,9 @@ mod tests { let store = ConfigStore::new(db.clone_pool()).unwrap(); store.load().await.unwrap(); - // Check default config (now lock-free, no await needed) let config = store.get(); assert!(!config.initialized); - // Update config store .update(|c| { c.initialized = true; @@ -171,12 +140,10 @@ mod tests { .await .unwrap(); - // Verify update let config = store.get(); assert!(config.initialized); assert_eq!(config.web.http_port, 9000); - // Create new store instance and verify persistence let store2 = ConfigStore::new(db.clone_pool()).unwrap(); store2.load().await.unwrap(); let config = store2.get(); diff --git a/src/diagnostics/linux.rs b/src/diagnostics/linux.rs new file mode 100644 index 00000000..29403ae1 --- /dev/null +++ b/src/diagnostics/linux.rs @@ -0,0 +1,280 @@ +use super::{DeviceInfo, DiskSpaceInfo, NetworkAddress}; +use crate::error::{AppError, Result}; +use crate::utils::hostname_uname; + +pub fn get_disk_space(path: &std::path::Path) -> Result { + let stat = nix::sys::statvfs::statvfs(path) + .map_err(|e| AppError::Internal(format!("Failed to get disk space: {}", e)))?; + + let block_size = stat.block_size() as u64; + let total = stat.blocks() as u64 * block_size; + let available = stat.blocks_available() as u64 * block_size; + let used = total - available; + + Ok(DiskSpaceInfo { + total, + available, + used, + }) +} + +pub fn get_device_info() -> DeviceInfo { + let mem_info = get_meminfo(); + + DeviceInfo { + hostname: hostname_uname(), + cpu_model: get_cpu_model(), + cpu_usage: get_cpu_usage(), + memory_total: mem_info.total, + memory_used: mem_info.total.saturating_sub(mem_info.available), + network_addresses: get_network_addresses(), + serial_ports: crate::utils::list_serial_ports(), + } +} + +fn get_cpu_model() -> String { + let cpuinfo = std::fs::read_to_string("/proc/cpuinfo").ok(); + + if let Some(model) = parse_cpu_model_from_cpuinfo_content(cpuinfo.as_deref()) { + return model; + } + + if let Some(model) = read_device_tree_model() { + return model; + } + + if let Some(content) = cpuinfo.as_deref() { + let cores = content + .lines() + .filter(|line| line.starts_with("processor")) + .count(); + if cores > 0 { + return format!("{} {}C", std::env::consts::ARCH, cores); + } + } + + std::env::consts::ARCH.to_string() +} + +fn parse_cpu_model_from_cpuinfo_content(content: Option<&str>) -> Option { + let content = content?; + + content + .lines() + .find(|line| line.starts_with("model name") || line.starts_with("Model")) + .and_then(|line| line.split(':').nth(1)) + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) +} + +fn read_device_tree_model() -> Option { + std::fs::read("/proc/device-tree/model") + .ok() + .and_then(|bytes| parse_device_tree_model_bytes(bytes.as_slice())) +} + +fn parse_device_tree_model_bytes(bytes: &[u8]) -> Option { + let model = String::from_utf8_lossy(bytes) + .trim_matches(|c: char| c == '\0' || c.is_whitespace()) + .to_string(); + + if model.is_empty() { + None + } else { + Some(model) + } +} + +static CPU_PREV_STATS: std::sync::OnceLock> = + std::sync::OnceLock::new(); + +fn get_cpu_usage() -> f32 { + let content = match std::fs::read_to_string("/proc/stat") { + Ok(c) => c, + Err(_) => return 0.0, + }; + + let cpu_line = match content.lines().next() { + Some(line) if line.starts_with("cpu ") => line, + _ => return 0.0, + }; + + let parts: Vec = cpu_line + .split_whitespace() + .skip(1) + .take(8) + .filter_map(|s| s.parse().ok()) + .collect(); + + if parts.len() < 4 { + return 0.0; + } + + let idle = parts[3] + parts.get(4).unwrap_or(&0); + let total: u64 = parts.iter().sum(); + + let prev_mutex = CPU_PREV_STATS.get_or_init(|| std::sync::Mutex::new((0, 0))); + let mut prev = prev_mutex.lock().unwrap(); + let (prev_idle, prev_total) = *prev; + + let idle_delta = idle.saturating_sub(prev_idle); + let total_delta = total.saturating_sub(prev_total); + *prev = (idle, total); + + if total_delta == 0 { + return 0.0; + } + + let usage = 100.0 * (1.0 - (idle_delta as f64 / total_delta as f64)); + usage as f32 +} + +struct MemInfo { + total: u64, + available: u64, +} + +fn get_meminfo() -> MemInfo { + let content = match std::fs::read_to_string("/proc/meminfo") { + Ok(c) => c, + Err(_) => { + return MemInfo { + total: 0, + available: 0, + } + } + }; + + let mut total = 0u64; + let mut available = 0u64; + + for line in content.lines() { + if line.starts_with("MemTotal:") { + if let Some(kb) = line + .split_whitespace() + .nth(1) + .and_then(|v| v.parse::().ok()) + { + total = kb * 1024; + } + } else if line.starts_with("MemAvailable:") { + if let Some(kb) = line + .split_whitespace() + .nth(1) + .and_then(|v| v.parse::().ok()) + { + available = kb * 1024; + } + } + + if total > 0 && available > 0 { + break; + } + } + + MemInfo { total, available } +} + +fn get_network_addresses() -> Vec { + let all_addrs = match nix::ifaddrs::getifaddrs() { + Ok(addrs) => addrs, + Err(_) => return Vec::new(), + }; + + let mut up_ifaces = std::collections::HashSet::new(); + let net_dir = match std::fs::read_dir("/sys/class/net") { + Ok(dir) => dir, + Err(_) => return Vec::new(), + }; + + for entry in net_dir.flatten() { + let iface_name = match entry.file_name().into_string() { + Ok(name) => name, + Err(_) => continue, + }; + + if iface_name == "lo" { + continue; + } + + let operstate_path = entry.path().join("operstate"); + let is_up = std::fs::read_to_string(&operstate_path) + .map(|s| s.trim() == "up") + .unwrap_or(false); + + if is_up { + up_ifaces.insert(iface_name); + } + } + + let mut addresses = Vec::new(); + let mut seen = std::collections::HashSet::new(); + for ifaddr in all_addrs { + let iface_name = &ifaddr.interface_name; + if iface_name == "lo" || !up_ifaces.contains(iface_name) { + continue; + } + + if let Some(addr) = ifaddr.address { + if let Some(sockaddr_in) = addr.as_sockaddr_in() { + let ip = sockaddr_in.ip(); + if ip.is_loopback() { + continue; + } + let ip_str = ip.to_string(); + if seen.insert((iface_name.clone(), ip_str.clone())) { + addresses.push(NetworkAddress { + interface: iface_name.clone(), + ip: ip_str, + }); + } + } else if let Some(sockaddr_in6) = addr.as_sockaddr_in6() { + let ip = sockaddr_in6.ip(); + if ip.is_loopback() || ip.is_unspecified() || ip.is_unicast_link_local() { + continue; + } + let ip_str = ip.to_string(); + if seen.insert((iface_name.clone(), ip_str.clone())) { + addresses.push(NetworkAddress { + interface: iface_name.clone(), + ip: ip_str, + }); + } + } + } + } + + addresses +} + +#[cfg(test)] +mod tests { + use super::{parse_cpu_model_from_cpuinfo_content, parse_device_tree_model_bytes}; + + #[test] + fn parse_cpu_model_from_model_name_field() { + let input = "processor\t: 0\nmodel name\t: Intel(R) Xeon(R)\n"; + assert_eq!( + parse_cpu_model_from_cpuinfo_content(input), + Some("Intel(R) Xeon(R)".to_string()) + ); + } + + #[test] + fn parse_cpu_model_from_model_field() { + let input = "processor\t: 0\nModel\t\t: Raspberry Pi 4 Model B Rev 1.4\n"; + assert_eq!( + parse_cpu_model_from_cpuinfo_content(input), + Some("Raspberry Pi 4 Model B Rev 1.4".to_string()) + ); + } + + #[test] + fn parse_device_tree_model_trimmed() { + let input = b"Onething OEC Box\0\n"; + assert_eq!( + parse_device_tree_model_bytes(input), + Some("Onething OEC Box".to_string()) + ); + } +} diff --git a/src/diagnostics/mod.rs b/src/diagnostics/mod.rs new file mode 100644 index 00000000..4bca8def --- /dev/null +++ b/src/diagnostics/mod.rs @@ -0,0 +1,47 @@ +//! Host diagnostics used by the web status API. + +use serde::Serialize; + +use crate::error::Result; + +#[derive(Debug, Clone, Serialize)] +pub struct DeviceInfo { + pub hostname: String, + pub cpu_model: String, + pub cpu_usage: f32, + pub memory_total: u64, + pub memory_used: u64, + pub network_addresses: Vec, + pub serial_ports: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub struct NetworkAddress { + pub interface: String, + pub ip: String, +} + +#[derive(Debug, Clone, Serialize)] +pub struct DiskSpaceInfo { + pub total: u64, + pub available: u64, + pub used: u64, +} + +#[cfg(unix)] +mod linux; +#[cfg(windows)] +mod windows; + +#[cfg(unix)] +use linux as platform; +#[cfg(windows)] +use windows as platform; + +pub fn get_disk_space(path: &std::path::Path) -> Result { + platform::get_disk_space(path) +} + +pub fn get_device_info() -> DeviceInfo { + platform::get_device_info() +} diff --git a/src/diagnostics/windows.rs b/src/diagnostics/windows.rs new file mode 100644 index 00000000..ce8685ef --- /dev/null +++ b/src/diagnostics/windows.rs @@ -0,0 +1,249 @@ +use super::{DeviceInfo, DiskSpaceInfo, NetworkAddress}; +use crate::error::{AppError, Result}; +use crate::utils::hostname_uname; +use std::ffi::CStr; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::{Mutex, OnceLock}; +use windows_sys::Win32::Foundation::{ERROR_BUFFER_OVERFLOW, ERROR_SUCCESS, FILETIME}; +use windows_sys::Win32::NetworkManagement::IpHelper::{ + GetAdaptersAddresses, GAA_FLAG_SKIP_ANYCAST, GAA_FLAG_SKIP_DNS_SERVER, GAA_FLAG_SKIP_MULTICAST, + IP_ADAPTER_ADDRESSES_LH, +}; +use windows_sys::Win32::NetworkManagement::Ndis::IfOperStatusUp; +use windows_sys::Win32::Networking::WinSock::{ + AF_INET, AF_INET6, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6, +}; +use windows_sys::Win32::System::SystemInformation::{ + GetNativeSystemInfo, GlobalMemoryStatusEx, MEMORYSTATUSEX, PROCESSOR_ARCHITECTURE_AMD64, + PROCESSOR_ARCHITECTURE_ARM64, PROCESSOR_ARCHITECTURE_INTEL, SYSTEM_INFO, +}; +use windows_sys::Win32::System::Threading::GetSystemTimes; + +pub fn get_disk_space(_path: &std::path::Path) -> Result { + Err(AppError::Internal( + "Disk space reporting is unavailable on Windows".to_string(), + )) +} + +pub fn get_device_info() -> DeviceInfo { + let (memory_total, memory_used) = get_memory_usage(); + + DeviceInfo { + hostname: hostname_uname(), + cpu_model: get_cpu_model(), + cpu_usage: get_cpu_usage(), + memory_total, + memory_used, + network_addresses: get_network_addresses(), + serial_ports: crate::utils::list_serial_ports(), + } +} + +fn get_cpu_model() -> String { + std::env::var("PROCESSOR_IDENTIFIER") + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .unwrap_or_else(get_cpu_arch_label) +} + +fn get_cpu_arch_label() -> String { + let mut info = std::mem::MaybeUninit::::zeroed(); + unsafe { + GetNativeSystemInfo(info.as_mut_ptr()); + let info = info.assume_init(); + match info.Anonymous.Anonymous.wProcessorArchitecture { + PROCESSOR_ARCHITECTURE_AMD64 => "x86_64".to_string(), + PROCESSOR_ARCHITECTURE_ARM64 => "aarch64".to_string(), + PROCESSOR_ARCHITECTURE_INTEL => "x86".to_string(), + _ => std::env::consts::ARCH.to_string(), + } + } +} + +fn get_memory_usage() -> (u64, u64) { + let mut status = MEMORYSTATUSEX { + dwLength: std::mem::size_of::() as u32, + ..unsafe { std::mem::zeroed() } + }; + + let ok = unsafe { GlobalMemoryStatusEx(&mut status) }; + if ok == 0 { + return (0, 0); + } + + ( + status.ullTotalPhys, + status.ullTotalPhys.saturating_sub(status.ullAvailPhys), + ) +} + +fn get_cpu_usage() -> f32 { + static LAST_SAMPLE: OnceLock>> = OnceLock::new(); + + let Some(current) = read_cpu_times() else { + return 0.0; + }; + let sample = LAST_SAMPLE.get_or_init(|| Mutex::new(None)); + let Ok(mut last) = sample.lock() else { + return 0.0; + }; + + let (previous, current) = if let Some(previous) = last.replace(current) { + (previous, current) + } else { + drop(last); + std::thread::sleep(std::time::Duration::from_millis(100)); + let Some(next) = read_cpu_times() else { + return 0.0; + }; + if let Ok(mut last) = sample.lock() { + *last = Some(next); + } + (current, next) + }; + + let idle = current.idle.saturating_sub(previous.idle); + let kernel = current.kernel.saturating_sub(previous.kernel); + let user = current.user.saturating_sub(previous.user); + let total = kernel.saturating_add(user); + + if total == 0 { + return 0.0; + } + + ((total.saturating_sub(idle)) as f64 * 100.0 / total as f64).clamp(0.0, 100.0) as f32 +} + +#[derive(Clone, Copy)] +struct CpuTimes { + idle: u64, + kernel: u64, + user: u64, +} + +fn read_cpu_times() -> Option { + let mut idle = FILETIME { + dwLowDateTime: 0, + dwHighDateTime: 0, + }; + let mut kernel = idle; + let mut user = idle; + + let ok = unsafe { GetSystemTimes(&mut idle, &mut kernel, &mut user) }; + if ok == 0 { + return None; + } + + Some(CpuTimes { + idle: filetime_to_u64(idle), + kernel: filetime_to_u64(kernel), + user: filetime_to_u64(user), + }) +} + +fn filetime_to_u64(time: FILETIME) -> u64 { + ((time.dwHighDateTime as u64) << 32) | time.dwLowDateTime as u64 +} + +fn get_network_addresses() -> Vec { + let mut buffer_len = 15_000u32; + let flags = GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER; + + for _ in 0..2 { + let mut buffer = vec![0u8; buffer_len as usize]; + let ret = unsafe { + GetAdaptersAddresses( + 0, + flags, + std::ptr::null_mut(), + buffer.as_mut_ptr() as *mut IP_ADAPTER_ADDRESSES_LH, + &mut buffer_len, + ) + }; + + if ret == ERROR_BUFFER_OVERFLOW { + continue; + } + if ret != ERROR_SUCCESS { + return Vec::new(); + } + + let mut addresses = Vec::new(); + let mut adapter = buffer.as_ptr() as *const IP_ADAPTER_ADDRESSES_LH; + while !adapter.is_null() { + let adapter_ref = unsafe { &*adapter }; + if adapter_ref.OperStatus != IfOperStatusUp { + adapter = adapter_ref.Next; + continue; + } + + let interface = adapter_name(adapter_ref); + let mut unicast = adapter_ref.FirstUnicastAddress; + + while !unicast.is_null() { + let unicast_ref = unsafe { &*unicast }; + if let Some(ip) = sockaddr_to_ip(unicast_ref.Address.lpSockaddr) { + addresses.push(NetworkAddress { + interface: interface.clone(), + ip, + }); + } + unicast = unicast_ref.Next; + } + + adapter = adapter_ref.Next; + } + + addresses.sort_by(|a, b| a.interface.cmp(&b.interface).then(a.ip.cmp(&b.ip))); + addresses.dedup_by(|a, b| a.interface == b.interface && a.ip == b.ip); + return addresses; + } + + Vec::new() +} + +fn adapter_name(adapter: &IP_ADAPTER_ADDRESSES_LH) -> String { + unsafe { + if !adapter.FriendlyName.is_null() { + let mut len = 0usize; + while *adapter.FriendlyName.add(len) != 0 { + len += 1; + } + let name = + String::from_utf16_lossy(std::slice::from_raw_parts(adapter.FriendlyName, len)); + if !name.trim().is_empty() { + return name; + } + } + + if !adapter.AdapterName.is_null() { + return CStr::from_ptr(adapter.AdapterName.cast()) + .to_string_lossy() + .into_owned(); + } + } + + "unknown".to_string() +} + +fn sockaddr_to_ip(sockaddr: *const SOCKADDR) -> Option { + if sockaddr.is_null() { + return None; + } + + let family = unsafe { (*sockaddr).sa_family }; + match family { + AF_INET => { + let addr = unsafe { *(sockaddr as *const SOCKADDR_IN) }; + let bytes = unsafe { addr.sin_addr.S_un.S_addr.to_ne_bytes() }; + Some(Ipv4Addr::from(bytes).to_string()) + } + AF_INET6 => { + let addr = unsafe { *(sockaddr as *const SOCKADDR_IN6) }; + let bytes = unsafe { addr.sin6_addr.u.Byte }; + Some(Ipv6Addr::from(bytes).to_string()) + } + _ => None, + } +} diff --git a/src/extensions/manager.rs b/src/extensions/manager.rs index 7cb9d9a2..a1474e4d 100644 --- a/src/extensions/manager.rs +++ b/src/extensions/manager.rs @@ -1,5 +1,4 @@ use std::collections::{HashMap, VecDeque}; -use std::path::Path; use std::process::Stdio; use std::sync::Arc; @@ -13,8 +12,16 @@ use crate::events::EventBus; const LOG_BUFFER_SIZE: usize = 200; const LOG_BATCH_SIZE: usize = 16; +#[cfg(unix)] pub const TTYD_SOCKET_PATH: &str = "/var/run/one-kvm/ttyd.sock"; +#[cfg(windows)] +pub const TTYD_TCP_ADDR: &str = "127.0.0.1:7681"; +#[cfg(windows)] +const TTYD_TCP_HOST: &str = "127.0.0.1"; +#[cfg(windows)] +const TTYD_TCP_PORT: &str = "7681"; + struct ExtensionProcess { child: Child, logs: Arc>>, @@ -36,7 +43,7 @@ impl ExtensionManager { pub fn new() -> Self { let availability = ExtensionId::all() .iter() - .map(|id| (*id, Path::new(id.binary_path()).exists())) + .map(|id| (*id, id.binary_path().exists())) .collect(); Self { @@ -64,6 +71,20 @@ impl ExtensionManager { *self.availability.get(&id).unwrap_or(&false) } + fn is_enabled_for_config(id: ExtensionId, config: &ExtensionsConfig) -> bool { + match id { + ExtensionId::Ttyd => config.ttyd.enabled, + ExtensionId::Gostc => { + config.gostc.enabled + && !config.gostc.key.is_empty() + && !config.gostc.addr.trim().is_empty() + } + ExtensionId::Easytier => { + config.easytier.enabled && !config.easytier.network_name.is_empty() + } + } + } + pub async fn status(&self, id: ExtensionId) -> ExtensionStatus { if !self.check_available(id) { return ExtensionStatus::Unavailable; @@ -105,7 +126,11 @@ impl ExtensionManager { pub async fn start(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<(), String> { if !self.check_available(id) { - return Err(format!("{} not found at {}", id, id.binary_path())); + return Err(format!( + "{} not found at {}", + id, + id.binary_path().display() + )); } self.stop(id).await.ok(); @@ -115,7 +140,7 @@ impl ExtensionManager { tracing::info!( "Starting extension {}: {} {}", id, - id.binary_path(), + id.binary_path().display(), Self::redact_args_for_log(&args).join(" ") ); @@ -232,15 +257,7 @@ impl ExtensionManager { ExtensionId::Ttyd => { let c = &config.ttyd; - Self::prepare_ttyd_socket().await?; - - let mut args = vec![ - "-i".to_string(), - TTYD_SOCKET_PATH.to_string(), - "-b".to_string(), - "/api/terminal".to_string(), - "-W".to_string(), - ]; + let mut args = Self::build_ttyd_listen_args().await?; args.push(c.shell.clone()); Ok(args) @@ -302,6 +319,43 @@ impl ExtensionManager { } } + #[cfg(unix)] + async fn build_ttyd_listen_args() -> Result, String> { + Self::prepare_ttyd_socket().await?; + + Ok(vec![ + "-i".to_string(), + TTYD_SOCKET_PATH.to_string(), + "-b".to_string(), + "/api/terminal".to_string(), + "-W".to_string(), + ]) + } + + #[cfg(windows)] + async fn build_ttyd_listen_args() -> Result, String> { + let cwd = std::env::var("USERPROFILE") + .ok() + .filter(|path| !path.trim().is_empty()) + .unwrap_or_else(|| { + std::env::current_dir() + .map(|path| path.to_string_lossy().to_string()) + .unwrap_or_else(|_| ".".to_string()) + }); + + Ok(vec![ + "-i".to_string(), + TTYD_TCP_HOST.to_string(), + "-p".to_string(), + TTYD_TCP_PORT.to_string(), + "-b".to_string(), + "/api/terminal".to_string(), + "-w".to_string(), + cwd, + "-W".to_string(), + ]) + } + fn redact_args_for_log(args: &[String]) -> Vec { let mut redacted = Vec::with_capacity(args.len()); let mut redact_next = false; @@ -330,8 +384,9 @@ impl ExtensionManager { redacted } + #[cfg(unix)] async fn prepare_ttyd_socket() -> Result<(), String> { - let socket_path = Path::new(TTYD_SOCKET_PATH); + let socket_path = std::path::Path::new(TTYD_SOCKET_PATH); if let Some(socket_dir) = socket_path.parent() { if !socket_dir.exists() { @@ -357,18 +412,7 @@ impl ExtensionManager { let checks: Vec<_> = ExtensionId::all() .iter() .filter_map(|id| { - let should_run = match id { - ExtensionId::Ttyd => config.ttyd.enabled, - ExtensionId::Gostc => { - config.gostc.enabled - && !config.gostc.key.is_empty() - && !config.gostc.addr.trim().is_empty() - } - ExtensionId::Easytier => { - config.easytier.enabled && !config.easytier.network_name.is_empty() - } - }; - if should_run && self.check_available(*id) { + if Self::is_enabled_for_config(*id, config) && self.check_available(*id) { Some(*id) } else { None @@ -404,41 +448,15 @@ impl ExtensionManager { } pub async fn start_enabled(&self, config: &ExtensionsConfig) { - use futures::Future; - use std::pin::Pin; - - let mut start_futures: Vec + Send + '_>>> = Vec::new(); - - if config.ttyd.enabled && self.check_available(ExtensionId::Ttyd) { - start_futures.push(Box::pin(async { - if let Err(e) = self.start(ExtensionId::Ttyd, config).await { - tracing::error!("Failed to start ttyd: {}", e); + let start_futures: Vec<_> = ExtensionId::all() + .iter() + .filter(|id| Self::is_enabled_for_config(**id, config) && self.check_available(**id)) + .map(|id| async move { + if let Err(e) = self.start(*id, config).await { + tracing::error!("Failed to start {}: {}", id, e); } - })); - } - - if config.gostc.enabled - && !config.gostc.key.is_empty() - && !config.gostc.addr.trim().is_empty() - && self.check_available(ExtensionId::Gostc) - { - start_futures.push(Box::pin(async { - if let Err(e) = self.start(ExtensionId::Gostc, config).await { - tracing::error!("Failed to start gostc: {}", e); - } - })); - } - - if config.easytier.enabled - && !config.easytier.network_name.is_empty() - && self.check_available(ExtensionId::Easytier) - { - start_futures.push(Box::pin(async { - if let Err(e) = self.start(ExtensionId::Easytier, config).await { - tracing::error!("Failed to start easytier: {}", e); - } - })); - } + }) + .collect(); futures::future::join_all(start_futures).await; } diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 0d242ff7..7f599dab 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,5 +1,10 @@ mod manager; +mod software; mod types; -pub use manager::{ExtensionManager, TTYD_SOCKET_PATH}; +pub use manager::ExtensionManager; +#[cfg(unix)] +pub use manager::TTYD_SOCKET_PATH; +#[cfg(windows)] +pub use manager::TTYD_TCP_ADDR; pub use types::*; diff --git a/src/extensions/software.rs b/src/extensions/software.rs new file mode 100644 index 00000000..cf17eb91 --- /dev/null +++ b/src/extensions/software.rs @@ -0,0 +1,15 @@ +use std::path::PathBuf; + +use super::ExtensionId; + +#[cfg_attr(windows, path = "software_windows.rs")] +#[cfg_attr(not(windows), path = "software_linux.rs")] +mod platform; + +pub fn binary_path(id: ExtensionId) -> PathBuf { + platform::binary_path(id) +} + +pub fn default_ttyd_shell() -> &'static str { + platform::default_ttyd_shell() +} diff --git a/src/extensions/software_linux.rs b/src/extensions/software_linux.rs new file mode 100644 index 00000000..a3757c79 --- /dev/null +++ b/src/extensions/software_linux.rs @@ -0,0 +1,19 @@ +use std::path::PathBuf; + +use super::ExtensionId; + +pub fn default_binary_path(id: ExtensionId) -> &'static str { + match id { + ExtensionId::Ttyd => "/usr/bin/ttyd", + ExtensionId::Gostc => "/usr/bin/gostc", + ExtensionId::Easytier => "/usr/bin/easytier-core", + } +} + +pub fn binary_path(id: ExtensionId) -> PathBuf { + PathBuf::from(default_binary_path(id)) +} + +pub fn default_ttyd_shell() -> &'static str { + "/bin/bash" +} diff --git a/src/extensions/software_windows.rs b/src/extensions/software_windows.rs new file mode 100644 index 00000000..74f67a90 --- /dev/null +++ b/src/extensions/software_windows.rs @@ -0,0 +1,47 @@ +use std::path::PathBuf; + +use super::ExtensionId; + +pub fn default_binary_path(id: ExtensionId) -> &'static str { + match id { + ExtensionId::Ttyd => "ttyd.win32.exe", + ExtensionId::Gostc => "gostc.exe", + ExtensionId::Easytier => "easytier-core.exe", + } +} + +pub fn binary_path(id: ExtensionId) -> PathBuf { + if id == ExtensionId::Ttyd { + if let Some(path) = env_path("ONE_KVM_TTYD_PATH") { + return path; + } + } + + find_in_app_dir(default_binary_path(id)) + .unwrap_or_else(|| PathBuf::from(default_binary_path(id))) +} + +pub fn default_ttyd_shell() -> &'static str { + "cmd" +} + +fn env_path(name: &str) -> Option { + std::env::var(name) + .ok() + .map(|path| path.trim().to_string()) + .filter(|path| !path.is_empty()) + .map(PathBuf::from) +} + +fn find_in_app_dir(binary_name: &str) -> Option { + if let Ok(exe_path) = std::env::current_exe() { + if let Some(exe_dir) = exe_path.parent() { + let bundled = exe_dir.join(binary_name); + if bundled.exists() { + return Some(bundled); + } + } + } + + None +} diff --git a/src/extensions/types.rs b/src/extensions/types.rs index ff6d2c99..527e63ae 100644 --- a/src/extensions/types.rs +++ b/src/extensions/types.rs @@ -1,6 +1,8 @@ use serde::{Deserialize, Serialize}; use typeshare::typeshare; +use super::software; + #[typeshare] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] @@ -11,12 +13,8 @@ pub enum ExtensionId { } impl ExtensionId { - pub fn binary_path(&self) -> &'static str { - match self { - Self::Ttyd => "/usr/bin/ttyd", - Self::Gostc => "/usr/bin/gostc", - Self::Easytier => "/usr/bin/easytier-core", - } + pub fn binary_path(&self) -> std::path::PathBuf { + software::binary_path(*self) } pub fn all() -> &'static [ExtensionId] { @@ -54,7 +52,6 @@ pub enum ExtensionStatus { Unavailable, Stopped, Running { pid: u32 }, - Failed { error: String }, } impl ExtensionStatus { @@ -75,7 +72,7 @@ impl Default for TtydConfig { fn default() -> Self { Self { enabled: false, - shell: "/bin/bash".to_string(), + shell: software::default_ttyd_shell().to_string(), } } } diff --git a/src/hid/ch9329.rs b/src/hid/ch9329.rs index 4120e901..7641a2f0 100644 --- a/src/hid/ch9329.rs +++ b/src/hid/ch9329.rs @@ -10,29 +10,24 @@ use async_trait::async_trait; use parking_lot::{Mutex, RwLock}; -use serde::{Deserialize, Serialize}; use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU8, Ordering}; use std::sync::{mpsc, Arc}; use std::thread; use std::time::{Duration, Instant}; use tokio::sync::watch; -use tracing::{info, trace, warn}; +use tracing::{info, trace}; use super::backend::{HidBackend, HidBackendRuntimeSnapshot}; +use super::ch9329_proto::{ + build_packet, cmd, expected_response_cmd, try_extract_response, ChipInfo, LedStatus, Response, + DEFAULT_ADDR, DEFAULT_BAUD_RATE, MAX_PACKET_SIZE, +}; use super::types::{KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType}; use crate::error::{AppError, Result}; use crate::events::LedState; -const PACKET_HEADER: [u8; 2] = [0x57, 0xAB]; - -const DEFAULT_ADDR: u8 = 0x00; - -pub const DEFAULT_BAUD_RATE: u32 = 9600; - const RESPONSE_TIMEOUT_MS: u64 = 500; -const MAX_DATA_LEN: usize = 64; - const CH9329_MOUSE_RESOLUTION: u32 = 4096; const PROBE_INTERVAL_MS: u64 = 100; @@ -41,173 +36,6 @@ const RECONNECT_DELAY_MS: u64 = 2000; const INIT_WAIT_MS: u64 = 3000; -pub mod cmd { - pub const GET_INFO: u8 = 0x01; - pub const SEND_KB_GENERAL_DATA: u8 = 0x02; - pub const SEND_KB_MEDIA_DATA: u8 = 0x03; - pub const SEND_MS_ABS_DATA: u8 = 0x04; - pub const SEND_MS_REL_DATA: u8 = 0x05; - pub const SEND_MY_HID_DATA: u8 = 0x06; - pub const SET_DEFAULT_CFG: u8 = 0x0C; - pub const RESET: u8 = 0x0F; -} - -const RESPONSE_SUCCESS_MASK: u8 = 0x80; -const RESPONSE_ERROR_MASK: u8 = 0xC0; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(u8)] -pub enum Ch9329Error { - Success = 0x00, - Timeout = 0xE1, - InvalidHeader = 0xE2, - InvalidCommand = 0xE3, - ChecksumError = 0xE4, - ParameterError = 0xE5, - OperationFailed = 0xE6, -} - -impl From for Ch9329Error { - fn from(code: u8) -> Self { - match code { - 0x00 => Ch9329Error::Success, - 0xE1 => Ch9329Error::Timeout, - 0xE2 => Ch9329Error::InvalidHeader, - 0xE3 => Ch9329Error::InvalidCommand, - 0xE4 => Ch9329Error::ChecksumError, - 0xE5 => Ch9329Error::ParameterError, - 0xE6 => Ch9329Error::OperationFailed, - _ => Ch9329Error::OperationFailed, - } - } -} - -impl std::fmt::Display for Ch9329Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Ch9329Error::Success => write!(f, "Success"), - Ch9329Error::Timeout => write!(f, "Serial receive timeout"), - Ch9329Error::InvalidHeader => write!(f, "Invalid packet header"), - Ch9329Error::InvalidCommand => write!(f, "Invalid command code"), - Ch9329Error::ChecksumError => write!(f, "Checksum mismatch"), - Ch9329Error::ParameterError => write!(f, "Parameter error"), - Ch9329Error::OperationFailed => write!(f, "Operation failed"), - } - } -} - -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct ChipInfo { - pub version: String, - pub version_raw: u8, - pub usb_connected: bool, - pub num_lock: bool, - pub caps_lock: bool, - pub scroll_lock: bool, -} - -impl ChipInfo { - pub fn from_response(data: &[u8]) -> Option { - if data.len() < 8 { - return None; - } - - let version_raw = data[0]; - let version = format!("V{}.{}", version_raw >> 4, version_raw & 0x0F); - let usb_connected = data[1] == 0x01; - let led_status = data[2]; - - Some(Self { - version, - version_raw, - usb_connected, - num_lock: (led_status & 0x01) != 0, - caps_lock: (led_status & 0x02) != 0, - scroll_lock: (led_status & 0x04) != 0, - }) - } -} - -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] -pub struct LedStatus { - pub num_lock: bool, - pub caps_lock: bool, - pub scroll_lock: bool, -} - -impl From for LedStatus { - fn from(byte: u8) -> Self { - Self { - num_lock: (byte & 0x01) != 0, - caps_lock: (byte & 0x02) != 0, - scroll_lock: (byte & 0x04) != 0, - } - } -} - -#[derive(Debug)] -pub struct Response { - pub address: u8, - pub cmd: u8, - pub data: Vec, - pub is_error: bool, - pub error_code: Option, -} - -impl Response { - pub fn parse(bytes: &[u8]) -> Option { - if bytes.len() < 6 { - return None; - } - - if bytes[0] != PACKET_HEADER[0] || bytes[1] != PACKET_HEADER[1] { - return None; - } - - let address = bytes[2]; - let cmd = bytes[3]; - let len = bytes[4] as usize; - - if bytes.len() < 5 + len + 1 { - return None; - } - - let expected_checksum = bytes[5 + len]; - let calculated_checksum = bytes[..5 + len] - .iter() - .fold(0u8, |acc, &x| acc.wrapping_add(x)); - - if expected_checksum != calculated_checksum { - warn!( - "CH9329 checksum mismatch: expected {:02X}, got {:02X}", - expected_checksum, calculated_checksum - ); - return None; - } - - let data = bytes[5..5 + len].to_vec(); - let is_error = (cmd & RESPONSE_ERROR_MASK) == RESPONSE_ERROR_MASK; - let error_code = if is_error && !data.is_empty() { - Some(Ch9329Error::from(data[0])) - } else { - None - }; - - Some(Self { - address, - cmd, - data, - is_error, - error_code, - }) - } - - pub fn is_success(&self) -> bool { - !self.is_error && (self.data.is_empty() || self.data[0] == Ch9329Error::Success as u8) - } -} - -const MAX_PACKET_SIZE: usize = 70; struct Ch9329RuntimeState { initialized: AtomicBool, @@ -331,6 +159,13 @@ impl Ch9329Backend { } pub fn check_port_exists(&self) -> bool { + #[cfg(windows)] + { + return crate::utils::list_serial_ports() + .iter() + .any(|port| port.eq_ignore_ascii_case(&self.port_path)); + } + #[cfg(not(windows))] std::path::Path::new(&self.port_path).exists() } @@ -358,39 +193,8 @@ impl Ch9329Backend { } #[inline] - fn calculate_checksum(data: &[u8]) -> u8 { - data.iter().fold(0u8, |acc, &x| acc.wrapping_add(x)) - } - - #[inline] - fn build_packet_buf(address: u8, cmd: u8, data: &[u8]) -> ([u8; MAX_PACKET_SIZE], usize) { - debug_assert!( - data.len() <= MAX_DATA_LEN, - "Data too long for CH9329 packet" - ); - - let len = data.len() as u8; - let packet_len = 6 + data.len(); - let mut packet = [0u8; MAX_PACKET_SIZE]; - - packet[0] = PACKET_HEADER[0]; - packet[1] = PACKET_HEADER[1]; - packet[2] = address; - packet[3] = cmd; - packet[4] = len; - packet[5..5 + data.len()].copy_from_slice(data); - let checksum = Self::calculate_checksum(&packet[..5 + data.len()]); - packet[5 + data.len()] = checksum; - - (packet, packet_len) - } - - fn build_packet(address: u8, cmd: u8, data: &[u8]) -> Vec { - let (buf, len) = Self::build_packet_buf(address, cmd, data); - buf[..len].to_vec() - } - fn open_port(port_path: &str, baud_rate: u32) -> Result> { + #[cfg(not(windows))] if !std::path::Path::new(port_path).exists() { return Err(Self::backend_error( format!("Serial port {} not found", port_path), @@ -410,46 +214,13 @@ impl Ch9329Backend { cmd: u8, data: &[u8], ) -> Result<()> { - let packet = Self::build_packet(address, cmd, data); + let packet = build_packet(address, cmd, data); port.write_all(&packet).map_err(|e| { Self::backend_error(format!("Failed to write to CH9329: {}", e), "write_failed") })?; Ok(()) } - fn try_extract_response(buffer: &[u8]) -> Option<(Response, usize)> { - let mut offset = 0; - while offset + 6 <= buffer.len() { - if buffer[offset] != PACKET_HEADER[0] || buffer[offset + 1] != PACKET_HEADER[1] { - offset += 1; - continue; - } - - let len = buffer[offset + 4] as usize; - let frame_len = 6 + len; - if offset + frame_len > buffer.len() { - return None; - } - - let frame = &buffer[offset..offset + frame_len]; - if let Some(response) = Response::parse(frame) { - return Some((response, offset + frame_len)); - } - - offset += 1; - } - - None - } - - fn expected_response_cmd(cmd: u8, is_error: bool) -> u8 { - cmd | if is_error { - RESPONSE_ERROR_MASK - } else { - RESPONSE_SUCCESS_MASK - } - } - fn xfer_packet( port: &mut dyn serialport::SerialPort, address: u8, @@ -460,8 +231,8 @@ impl Ch9329Backend { let mut pending = Vec::with_capacity(128); let deadline = Instant::now() + Duration::from_millis(RESPONSE_TIMEOUT_MS); - let expected_ok = Self::expected_response_cmd(cmd, false); - let expected_err = Self::expected_response_cmd(cmd, true); + let expected_ok = expected_response_cmd(cmd, false); + let expected_err = expected_response_cmd(cmd, true); loop { let mut chunk = [0u8; 128]; @@ -469,7 +240,7 @@ impl Ch9329Backend { Ok(n) if n > 0 => { pending.extend_from_slice(&chunk[..n]); - while let Some((response, consumed)) = Self::try_extract_response(&pending) { + while let Some((response, consumed)) = try_extract_response(&pending) { pending.drain(..consumed); if response.cmd == expected_ok || response.cmd == expected_err { return Ok(response); @@ -1023,7 +794,14 @@ impl HidBackend for Ch9329Backend { let mut online = initialized && self.runtime.online.load(Ordering::Relaxed); let mut error = self.runtime.last_error.read().clone(); - if initialized && !self.check_port_exists() { + #[cfg(windows)] + let port_still_present = crate::utils::list_serial_ports() + .iter() + .any(|port| port.eq_ignore_ascii_case(&self.port_path)); + #[cfg(not(windows))] + let port_still_present = self.check_port_exists(); + + if initialized && !port_still_present { online = false; error = Some(( format!("Serial port {} not found", self.port_path), @@ -1066,14 +844,15 @@ impl HidBackend for Ch9329Backend { #[cfg(test)] mod tests { use super::*; + use super::ch9329_proto::{build_packet, calculate_checksum}; #[test] fn test_packet_building() { - let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::GET_INFO, &[]); + let packet = build_packet(DEFAULT_ADDR, cmd::GET_INFO, &[]); assert_eq!(packet, vec![0x57, 0xAB, 0x00, 0x01, 0x00, 0x03]); let data = [0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; // 'A' key - let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_KB_GENERAL_DATA, &data); + let packet = build_packet(DEFAULT_ADDR, cmd::SEND_KB_GENERAL_DATA, &data); assert_eq!(packet[0], 0x57); // Header assert_eq!(packet[1], 0xAB); // Header @@ -1090,7 +869,7 @@ mod tests { #[test] fn test_relative_mouse_packet() { let data = [0x01, 0x00, 50u8, 0x00, 0x00]; - let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_MS_REL_DATA, &data); + let packet = build_packet(DEFAULT_ADDR, cmd::SEND_MS_REL_DATA, &data); assert_eq!(packet[0], 0x57); assert_eq!(packet[1], 0xAB); @@ -1105,13 +884,13 @@ mod tests { #[test] fn test_checksum_calculation() { let packet = [0x57u8, 0xAB, 0x00, 0x01, 0x00]; - let checksum = Ch9329Backend::calculate_checksum(&packet); + let checksum = calculate_checksum(&packet); assert_eq!(checksum, 0x03); let packet = [ 0x57u8, 0xAB, 0x00, 0x02, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, ]; - let checksum = Ch9329Backend::calculate_checksum(&packet); + let checksum = calculate_checksum(&packet); assert_eq!(checksum, 0x10); } diff --git a/src/hid/ch9329_proto.rs b/src/hid/ch9329_proto.rs new file mode 100644 index 00000000..1b0b27d8 --- /dev/null +++ b/src/hid/ch9329_proto.rs @@ -0,0 +1,225 @@ +//! Shared CH9329 protocol types and packet helpers. + +use serde::{Deserialize, Serialize}; + +const PACKET_HEADER: [u8; 2] = [0x57, 0xAB]; +pub const RESPONSE_SUCCESS_MASK: u8 = 0x80; +pub const RESPONSE_ERROR_MASK: u8 = 0xC0; + +pub const DEFAULT_ADDR: u8 = 0x00; +pub const DEFAULT_BAUD_RATE: u32 = 9600; +pub const MAX_DATA_LEN: usize = 64; +pub const MAX_PACKET_SIZE: usize = 70; + +pub mod cmd { + pub const GET_INFO: u8 = 0x01; + pub const SEND_KB_GENERAL_DATA: u8 = 0x02; + pub const SEND_KB_MEDIA_DATA: u8 = 0x03; + pub const SEND_MS_ABS_DATA: u8 = 0x04; + pub const SEND_MS_REL_DATA: u8 = 0x05; + pub const RESET: u8 = 0x0F; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum Ch9329Error { + Success = 0x00, + Timeout = 0xE1, + InvalidHeader = 0xE2, + InvalidCommand = 0xE3, + ChecksumError = 0xE4, + ParameterError = 0xE5, + OperationFailed = 0xE6, +} + +impl From for Ch9329Error { + fn from(code: u8) -> Self { + match code { + 0x00 => Ch9329Error::Success, + 0xE1 => Ch9329Error::Timeout, + 0xE2 => Ch9329Error::InvalidHeader, + 0xE3 => Ch9329Error::InvalidCommand, + 0xE4 => Ch9329Error::ChecksumError, + 0xE5 => Ch9329Error::ParameterError, + 0xE6 => Ch9329Error::OperationFailed, + _ => Ch9329Error::OperationFailed, + } + } +} + +impl std::fmt::Display for Ch9329Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Ch9329Error::Success => write!(f, "Success"), + Ch9329Error::Timeout => write!(f, "Serial receive timeout"), + Ch9329Error::InvalidHeader => write!(f, "Invalid packet header"), + Ch9329Error::InvalidCommand => write!(f, "Invalid command code"), + Ch9329Error::ChecksumError => write!(f, "Checksum mismatch"), + Ch9329Error::ParameterError => write!(f, "Parameter error"), + Ch9329Error::OperationFailed => write!(f, "Operation failed"), + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct ChipInfo { + pub version: String, + pub version_raw: u8, + pub usb_connected: bool, + pub num_lock: bool, + pub caps_lock: bool, + pub scroll_lock: bool, +} + +impl ChipInfo { + pub fn from_response(data: &[u8]) -> Option { + if data.len() < 8 { + return None; + } + + let version_raw = data[0]; + let version = format!("V{}.{}", version_raw >> 4, version_raw & 0x0F); + let usb_connected = data[1] == 0x01; + let led_status = data[2]; + + Some(Self { + version, + version_raw, + usb_connected, + num_lock: (led_status & 0x01) != 0, + caps_lock: (led_status & 0x02) != 0, + scroll_lock: (led_status & 0x04) != 0, + }) + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct LedStatus { + pub num_lock: bool, + pub caps_lock: bool, + pub scroll_lock: bool, +} + +impl From for LedStatus { + fn from(byte: u8) -> Self { + Self { + num_lock: (byte & 0x01) != 0, + caps_lock: (byte & 0x02) != 0, + scroll_lock: (byte & 0x04) != 0, + } + } +} + +#[derive(Debug)] +pub struct Response { + pub cmd: u8, + pub data: Vec, + pub is_error: bool, + pub error_code: Option, +} + +impl Response { + pub fn parse(bytes: &[u8]) -> Option { + if bytes.len() < 6 || bytes[0] != PACKET_HEADER[0] || bytes[1] != PACKET_HEADER[1] { + return None; + } + + let cmd = bytes[3]; + let len = bytes[4] as usize; + if bytes.len() < 5 + len + 1 { + return None; + } + + let expected_checksum = bytes[5 + len]; + let calculated_checksum = bytes[..5 + len] + .iter() + .fold(0u8, |acc, &x| acc.wrapping_add(x)); + if expected_checksum != calculated_checksum { + tracing::warn!( + "CH9329 checksum mismatch: expected {:02X}, got {:02X}", + expected_checksum, + calculated_checksum + ); + return None; + } + + let data = bytes[5..5 + len].to_vec(); + let is_error = (cmd & RESPONSE_ERROR_MASK) == RESPONSE_ERROR_MASK; + let error_code = if is_error && !data.is_empty() { + Some(Ch9329Error::from(data[0])) + } else { + None + }; + + Some(Self { + cmd, + data, + is_error, + error_code, + }) + } +} + +#[inline] +pub fn calculate_checksum(data: &[u8]) -> u8 { + data.iter().fold(0u8, |acc, &x| acc.wrapping_add(x)) +} + +#[inline] +pub fn build_packet_buf(address: u8, cmd: u8, data: &[u8]) -> ([u8; MAX_PACKET_SIZE], usize) { + debug_assert!(data.len() <= MAX_DATA_LEN, "Data too long for CH9329 packet"); + + let len = data.len() as u8; + let packet_len = 6 + data.len(); + let mut packet = [0u8; MAX_PACKET_SIZE]; + + packet[0] = PACKET_HEADER[0]; + packet[1] = PACKET_HEADER[1]; + packet[2] = address; + packet[3] = cmd; + packet[4] = len; + packet[5..5 + data.len()].copy_from_slice(data); + packet[5 + data.len()] = calculate_checksum(&packet[..5 + data.len()]); + + (packet, packet_len) +} + +#[inline] +pub fn build_packet(address: u8, cmd: u8, data: &[u8]) -> Vec { + let (buf, len) = build_packet_buf(address, cmd, data); + buf[..len].to_vec() +} + +#[inline] +pub fn expected_response_cmd(cmd: u8, is_error: bool) -> u8 { + cmd | if is_error { + RESPONSE_ERROR_MASK + } else { + RESPONSE_SUCCESS_MASK + } +} + +pub fn try_extract_response(buffer: &[u8]) -> Option<(Response, usize)> { + let mut offset = 0; + while offset + 6 <= buffer.len() { + if buffer[offset] != PACKET_HEADER[0] || buffer[offset + 1] != PACKET_HEADER[1] { + offset += 1; + continue; + } + + let len = buffer[offset + 4] as usize; + let frame_len = 6 + len; + if offset + frame_len > buffer.len() { + return None; + } + + let frame = &buffer[offset..offset + frame_len]; + if let Some(response) = Response::parse(frame) { + return Some((response, offset + frame_len)); + } + + offset += 1; + } + + None +} diff --git a/src/hid/factory.rs b/src/hid/factory.rs new file mode 100644 index 00000000..4812d8dc --- /dev/null +++ b/src/hid/factory.rs @@ -0,0 +1,80 @@ +use std::sync::Arc; + +use tracing::{info, warn}; + +use super::{ch9329, HidBackend, HidBackendType}; +use crate::error::{AppError, Result}; +#[cfg(unix)] +use crate::otg::OtgService; + +pub struct HidBackendFactory { + #[cfg(unix)] + otg_service: Option>, +} + +impl HidBackendFactory { + #[cfg(unix)] + pub fn new(otg_service: Option>) -> Self { + Self { otg_service } + } + + #[cfg(not(unix))] + pub fn new() -> Self { + Self {} + } + + pub async fn create_initialized( + &self, + backend_type: &HidBackendType, + ) -> Result>> { + let backend = match self.create(backend_type).await? { + Some(backend) => backend, + None => return Ok(None), + }; + + backend.init().await?; + Ok(Some(backend)) + } + + async fn create(&self, backend_type: &HidBackendType) -> Result>> { + match backend_type { + HidBackendType::Otg => self.create_otg_backend().await.map(Some), + HidBackendType::Ch9329 { port, baud_rate } => { + info!( + "Initializing CH9329 HID backend on {} @ {} baud", + port, baud_rate + ); + Ok(Some(Arc::new(ch9329::Ch9329Backend::with_baud_rate( + port, *baud_rate, + )?))) + } + HidBackendType::None => { + warn!("HID backend disabled"); + Ok(None) + } + } + } + + #[cfg(unix)] + async fn create_otg_backend(&self) -> Result> { + let otg_service = self + .otg_service + .as_ref() + .ok_or_else(|| AppError::Config("OTG backend not available".to_string()))?; + + let handles = otg_service + .hid_device_paths() + .await + .ok_or_else(|| AppError::Config("OTG HID paths are not available".to_string()))?; + + info!("Creating OTG HID backend from device paths"); + Ok(Arc::new(super::otg::OtgBackend::from_handles(handles)?)) + } + + #[cfg(not(unix))] + async fn create_otg_backend(&self) -> Result> { + Err(AppError::Config( + "OTG HID is only available on Linux".to_string(), + )) + } +} diff --git a/src/hid/mod.rs b/src/hid/mod.rs index 2bd572ca..8e30ff30 100644 --- a/src/hid/mod.rs +++ b/src/hid/mod.rs @@ -1,11 +1,16 @@ //! HID path: browser (WebSocket or WebRTC DataChannel) → queue → OTG gadget or CH9329. pub mod backend; +mod ch9329_proto; pub mod ch9329; pub mod consumer; pub mod datachannel; +mod factory; pub mod keyboard; +#[cfg(unix)] pub mod otg; +#[cfg(unix)] +mod otg_device; pub mod types; pub mod websocket; @@ -95,7 +100,9 @@ use tracing::{info, warn}; use crate::error::{AppError, Result}; use crate::events::EventBus; +#[cfg(unix)] use crate::otg::OtgService; +use factory::HidBackendFactory; use tokio::sync::mpsc; use tokio::sync::Mutex; use tokio::task::JoinHandle; @@ -112,7 +119,7 @@ enum QueuedHidEvent { } pub struct HidController { - otg_service: Option>, + backend_factory: HidBackendFactory, backend: Arc>>>, backend_type: Arc>, events: Arc>>>, @@ -127,11 +134,33 @@ pub struct HidController { } impl HidController { + #[cfg(unix)] pub fn new(backend_type: HidBackendType, otg_service: Option>) -> Self { let (hid_tx, hid_rx) = mpsc::channel(HID_EVENT_QUEUE_CAPACITY); Self { - otg_service, backend: Arc::new(RwLock::new(None)), + backend_factory: HidBackendFactory::new(otg_service), + backend_type: Arc::new(RwLock::new(backend_type.clone())), + events: Arc::new(tokio::sync::RwLock::new(None)), + runtime_state: Arc::new(RwLock::new(HidRuntimeState::from_backend_type( + &backend_type, + ))), + hid_tx, + hid_rx: Mutex::new(Some(hid_rx)), + pending_move: Arc::new(parking_lot::Mutex::new(None)), + pending_move_flag: Arc::new(AtomicBool::new(false)), + hid_worker: Mutex::new(None), + runtime_worker: Mutex::new(None), + backend_available: Arc::new(AtomicBool::new(false)), + } + } + + #[cfg(not(unix))] + pub fn new(backend_type: HidBackendType) -> Self { + let (hid_tx, hid_rx) = mpsc::channel(HID_EVENT_QUEUE_CAPACITY); + Self { + backend: Arc::new(RwLock::new(None)), + backend_factory: HidBackendFactory::new(), backend_type: Arc::new(RwLock::new(backend_type.clone())), events: Arc::new(tokio::sync::RwLock::new(None)), runtime_state: Arc::new(RwLock::new(HidRuntimeState::from_backend_type( @@ -153,51 +182,22 @@ impl HidController { pub async fn init(&self) -> Result<()> { let backend_type = self.backend_type.read().await.clone(); - let backend: Arc = match backend_type { - HidBackendType::Otg => { - let otg_service = self - .otg_service - .as_ref() - .ok_or_else(|| AppError::Internal("OtgService not available".into()))?; - - let handles = otg_service.hid_device_paths().await.ok_or_else(|| { - AppError::Config("OTG HID paths are not available".to_string()) - })?; - - info!("Creating OTG HID backend from device paths"); - Arc::new(otg::OtgBackend::from_handles(handles)?) - } - HidBackendType::Ch9329 { - ref port, - baud_rate, - } => { - info!( - "Initializing CH9329 HID backend on {} @ {} baud", - port, baud_rate - ); - Arc::new(ch9329::Ch9329Backend::with_baud_rate(port, baud_rate)?) - } - HidBackendType::None => { - warn!("HID backend disabled"); - return Ok(()); - } - }; - - if let Err(e) = backend.init().await { - self.backend_available.store(false, Ordering::Release); - let error_state = { - let backend_type = self.backend_type.read().await.clone(); + let backend = match self.backend_factory.create_initialized(&backend_type).await { + Ok(Some(backend)) => backend, + Ok(None) => return Ok(()), + Err(error) => { + self.backend_available.store(false, Ordering::Release); let current = self.runtime_state.read().await.clone(); - HidRuntimeState::with_error( + let error_state = HidRuntimeState::with_error( &backend_type, ¤t, - format!("Failed to initialize HID backend: {}", e), + format!("Failed to initialize HID backend: {}", error), "init_failed", - ) - }; - self.apply_runtime_state(error_state).await; - return Err(e); - } + ); + self.apply_runtime_state(error_state).await; + return Err(error); + } + }; *self.backend.write().await = Some(backend); self.sync_runtime_state_from_backend().await; @@ -298,73 +298,15 @@ impl HidController { } } - let new_backend: Option> = match new_backend_type { - HidBackendType::Otg => { - info!("Initializing OTG HID backend"); - - let otg_service = match self.otg_service.as_ref() { - Some(svc) => svc, - None => { - warn!("OTG backend requires OtgService, but it's not available"); - return Err(AppError::Config( - "OTG backend not available (OtgService missing)".to_string(), - )); - } - }; - - match otg_service.hid_device_paths().await { - Some(handles) => match otg::OtgBackend::from_handles(handles) { - Ok(backend) => { - let backend = Arc::new(backend); - match backend.init().await { - Ok(_) => { - info!("OTG backend initialized successfully"); - Some(backend) - } - Err(e) => { - warn!("Failed to initialize OTG backend: {}", e); - None - } - } - } - Err(e) => { - warn!("Failed to create OTG backend: {}", e); - None - } - }, - None => { - warn!("OTG HID paths are not available"); - None - } - } - } - HidBackendType::Ch9329 { - ref port, - baud_rate, - } => { - info!( - "Initializing CH9329 HID backend on {} @ {} baud", - port, baud_rate - ); - match ch9329::Ch9329Backend::with_baud_rate(port, baud_rate) { - Ok(b) => { - let backend = Arc::new(b); - match backend.init().await { - Ok(_) => Some(backend), - Err(e) => { - warn!("Failed to initialize CH9329 backend: {}", e); - None - } - } - } - Err(e) => { - warn!("Failed to create CH9329 backend: {}", e); - None - } - } - } - HidBackendType::None => { - warn!("HID backend disabled"); + let new_backend = match self + .backend_factory + .create_initialized(&new_backend_type) + .await + { + Ok(backend) => backend, + Err(error) if matches!(&new_backend_type, HidBackendType::None) => return Err(error), + Err(error) => { + warn!("Failed to initialize HID backend: {}", error); None } }; diff --git a/src/hid/otg.rs b/src/hid/otg.rs index 5acd91e3..f2eb3b81 100644 --- a/src/hid/otg.rs +++ b/src/hid/otg.rs @@ -4,12 +4,10 @@ //! Polled timed writes (JetKVM-style). Treat `ESHUTDOWN` (108) by closing handles and reopening; keep fd on `EAGAIN` (11). Host/gadget teardown during MSD resembles PiKVM. use async_trait::async_trait; -use nix::poll::{poll, PollFd, PollFlags, PollTimeout}; use parking_lot::Mutex; use std::fs::{self, File, OpenOptions}; use std::io::{Read, Write}; use std::os::unix::fs::OpenOptionsExt; -use std::os::unix::io::AsFd; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; use std::sync::Arc; @@ -19,6 +17,7 @@ use tokio::sync::watch; use tracing::{debug, info, trace, warn}; use super::backend::{HidBackend, HidBackendRuntimeSnapshot}; +use super::otg_device::OtgDeviceIo; use super::types::{ ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType, }; @@ -87,7 +86,6 @@ pub struct OtgBackend { last_error: parking_lot::RwLock>, last_error_log: parking_lot::Mutex, error_count: AtomicU8, - eagain_count: AtomicU8, runtime_notify_tx: watch::Sender<()>, runtime_worker_stop: Arc, runtime_worker: Mutex>>, @@ -119,7 +117,6 @@ impl OtgBackend { last_error: parking_lot::RwLock::new(None), last_error_log: parking_lot::Mutex::new(std::time::Instant::now()), error_count: AtomicU8::new(0), - eagain_count: AtomicU8::new(0), runtime_notify_tx, runtime_worker_stop: Arc::new(AtomicBool::new(false)), runtime_worker: Mutex::new(None), @@ -179,34 +176,11 @@ impl OtgBackend { fn reset_error_count(&self) { self.error_count.store(0, Ordering::Relaxed); - self.eagain_count.store(0, Ordering::Relaxed); } /// Poll-based write with `HID_WRITE_TIMEOUT_MS`; timeout → drop (JetKVM-style). fn write_with_timeout(&self, file: &mut File, data: &[u8]) -> std::io::Result { - let mut pollfd = [PollFd::new(file.as_fd(), PollFlags::POLLOUT)]; - - match poll(&mut pollfd, PollTimeout::from(HID_WRITE_TIMEOUT_MS as u16)) { - Ok(1) => { - if let Some(revents) = pollfd[0].revents() { - if revents.contains(PollFlags::POLLERR) || revents.contains(PollFlags::POLLHUP) - { - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "Device error or hangup", - )); - } - } - file.write_all(data)?; - Ok(true) - } - Ok(0) => { - trace!("HID write timeout, dropping data"); - Ok(false) - } - Ok(_) => Ok(false), - Err(e) => Err(std::io::Error::other(e)), - } + OtgDeviceIo::write_with_timeout(file, data, HID_WRITE_TIMEOUT_MS) } pub fn set_udc_name(&self, udc: &str) { @@ -357,6 +331,32 @@ impl OtgBackend { } } + fn handle_write_error( + &self, + dev: &mut Option, + err: std::io::Error, + operation: &str, + device_label: &str, + ) -> Result<()> { + match err.raw_os_error() { + Some(108) => { + debug!("{} ESHUTDOWN, closing for recovery", device_label); + *dev = None; + self.record_error(format!("{}: {}", operation, err), "eshutdown"); + Err(Self::io_error_to_hid_error(err, operation)) + } + Some(11) => { + trace!("{} EAGAIN after poll, dropping", device_label); + Ok(()) + } + _ => { + warn!("{} write error: {}", device_label, err); + self.record_error(format!("{}: {}", operation, err), Self::io_error_code(&err)); + Err(Self::io_error_to_hid_error(err, operation)) + } + } + } + pub fn check_devices_exist(&self) -> bool { self.keyboard_path.as_ref().is_none_or(|p| p.exists()) && self.mouse_rel_path.as_ref().is_none_or(|p| p.exists()) @@ -405,41 +405,12 @@ impl OtgBackend { self.log_throttled_error("HID keyboard write timeout, dropped"); Ok(()) } - Err(e) => { - let error_code = e.raw_os_error(); - - match error_code { - Some(108) => { - self.eagain_count.store(0, Ordering::Relaxed); - debug!("Keyboard ESHUTDOWN, closing for recovery"); - *dev = None; - self.record_error( - format!("Failed to write keyboard report: {}", e), - "eshutdown", - ); - Err(Self::io_error_to_hid_error( - e, - "Failed to write keyboard report", - )) - } - Some(11) => { - trace!("Keyboard EAGAIN after poll, dropping"); - Ok(()) - } - _ => { - self.eagain_count.store(0, Ordering::Relaxed); - warn!("Keyboard write error: {}", e); - self.record_error( - format!("Failed to write keyboard report: {}", e), - Self::io_error_code(&e), - ); - Err(Self::io_error_to_hid_error( - e, - "Failed to write keyboard report", - )) - } - } - } + Err(e) => self.handle_write_error( + &mut dev, + e, + "Failed to write keyboard report", + "Keyboard", + ), } } else { Err(AppError::HidError { @@ -468,38 +439,12 @@ impl OtgBackend { Ok(()) } Ok(false) => Ok(()), - Err(e) => { - let error_code = e.raw_os_error(); - - match error_code { - Some(108) => { - self.eagain_count.store(0, Ordering::Relaxed); - debug!("Relative mouse ESHUTDOWN, closing for recovery"); - *dev = None; - self.record_error( - format!("Failed to write mouse report: {}", e), - "eshutdown", - ); - Err(Self::io_error_to_hid_error( - e, - "Failed to write mouse report", - )) - } - Some(11) => Ok(()), - _ => { - self.eagain_count.store(0, Ordering::Relaxed); - warn!("Relative mouse write error: {}", e); - self.record_error( - format!("Failed to write mouse report: {}", e), - Self::io_error_code(&e), - ); - Err(Self::io_error_to_hid_error( - e, - "Failed to write mouse report", - )) - } - } - } + Err(e) => self.handle_write_error( + &mut dev, + e, + "Failed to write mouse report", + "Relative mouse", + ), } } else { Err(AppError::HidError { @@ -534,38 +479,12 @@ impl OtgBackend { Ok(()) } Ok(false) => Ok(()), - Err(e) => { - let error_code = e.raw_os_error(); - - match error_code { - Some(108) => { - self.eagain_count.store(0, Ordering::Relaxed); - debug!("Absolute mouse ESHUTDOWN, closing for recovery"); - *dev = None; - self.record_error( - format!("Failed to write mouse report: {}", e), - "eshutdown", - ); - Err(Self::io_error_to_hid_error( - e, - "Failed to write mouse report", - )) - } - Some(11) => Ok(()), - _ => { - self.eagain_count.store(0, Ordering::Relaxed); - warn!("Absolute mouse write error: {}", e); - self.record_error( - format!("Failed to write mouse report: {}", e), - Self::io_error_code(&e), - ); - Err(Self::io_error_to_hid_error( - e, - "Failed to write mouse report", - )) - } - } - } + Err(e) => self.handle_write_error( + &mut dev, + e, + "Failed to write mouse report", + "Absolute mouse", + ), } } else { Err(AppError::HidError { @@ -597,35 +516,12 @@ impl OtgBackend { Ok(()) } Ok(false) => Ok(()), - Err(e) => { - let error_code = e.raw_os_error(); - match error_code { - Some(108) => { - debug!("Consumer control ESHUTDOWN, closing for recovery"); - *dev = None; - self.record_error( - format!("Failed to write consumer report: {}", e), - "eshutdown", - ); - Err(Self::io_error_to_hid_error( - e, - "Failed to write consumer report", - )) - } - Some(11) => Ok(()), - _ => { - warn!("Consumer control write error: {}", e); - self.record_error( - format!("Failed to write consumer report: {}", e), - Self::io_error_code(&e), - ); - Err(Self::io_error_to_hid_error( - e, - "Failed to write consumer report", - )) - } - } - } + Err(e) => self.handle_write_error( + &mut dev, + e, + "Failed to write consumer report", + "Consumer control", + ), } } else { Err(AppError::HidError { diff --git a/src/hid/otg_device.rs b/src/hid/otg_device.rs new file mode 100644 index 00000000..d9307f42 --- /dev/null +++ b/src/hid/otg_device.rs @@ -0,0 +1,50 @@ +#[cfg(unix)] +use std::fs::{File, OpenOptions}; +#[cfg(unix)] +use std::io::{Read, Write}; +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; +#[cfg(unix)] +use std::os::unix::io::AsFd; +#[cfg(unix)] +use std::path::PathBuf; + +#[cfg(unix)] +use nix::poll::{poll, PollFd, PollFlags, PollTimeout}; +#[cfg(unix)] +use tracing::trace; + +#[cfg(unix)] +pub struct OtgDeviceIo; + +#[cfg(unix)] +impl OtgDeviceIo { + pub fn write_with_timeout( + file: &mut File, + data: &[u8], + timeout_ms: i32, + ) -> std::io::Result { + let mut pollfd = [PollFd::new(file.as_fd(), PollFlags::POLLOUT)]; + match poll(&mut pollfd, PollTimeout::from(timeout_ms as u16)) { + Ok(1) => { + if let Some(revents) = pollfd[0].revents() { + if revents.contains(PollFlags::POLLERR) || revents.contains(PollFlags::POLLHUP) + { + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "Device error or hangup", + )); + } + } + file.write_all(data)?; + Ok(true) + } + Ok(0) => { + trace!("HID write timeout, dropping data"); + Ok(false) + } + Ok(_) => Ok(false), + Err(e) => Err(std::io::Error::other(e)), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 4ded12c1..460771ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,23 @@ //! Core library for One-KVM (IP‑KVM: capture, HID, OTG, streaming, Web UI glue). +#[cfg(not(any(unix, windows)))] +compile_error!("One-KVM supports Linux and Windows targets only."); + pub mod atx; pub mod audio; pub mod auth; pub mod config; pub mod db; +pub mod diagnostics; pub mod error; pub mod events; pub mod extensions; pub mod hid; +#[cfg(unix)] pub mod msd; +#[cfg(unix)] pub mod otg; +pub mod platform; pub mod redfish; pub mod rtsp; pub mod rustdesk; diff --git a/src/main.rs b/src/main.rs index ac0f74dc..9488eb86 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,8 +20,11 @@ use one_kvm::db::DatabasePool; use one_kvm::events::EventBus; use one_kvm::extensions::ExtensionManager; use one_kvm::hid::{HidBackendType, HidController}; +#[cfg(unix)] use one_kvm::msd::MsdController; +#[cfg(unix)] use one_kvm::otg::OtgService; +use one_kvm::platform::PlatformCapabilities; use one_kvm::rtsp::RtspService; use one_kvm::rustdesk::RustDeskService; use one_kvm::state::AppState; @@ -78,7 +81,7 @@ struct CliArgs { #[arg(long, value_name = "FILE", requires = "ssl_cert")] ssl_key: Option, - /// Data directory path (default: /etc/one-kvm) + /// Data directory path (default: /etc/one-kvm, or the executable directory on Windows) #[arg(short = 'd', long, value_name = "DIR")] data_dir: Option, @@ -119,6 +122,12 @@ async fn main() -> anyhow::Result<()> { .expect("Failed to install rustls crypto provider"); tracing::info!("Starting One-KVM v{}", env!("CARGO_PKG_VERSION")); + let platform = PlatformCapabilities::current(); + tracing::info!( + "Platform mode: {:?} ({})", + platform.mode, + platform.mode_label + ); let data_dir = args.data_dir.clone().unwrap_or_else(get_data_dir); tracing::info!("Data directory: {}", data_dir.display()); @@ -128,37 +137,7 @@ async fn main() -> anyhow::Result<()> { return Ok(()); } - tokio::fs::create_dir_all(&data_dir).await?; - - let db = open_database_pool(&data_dir).await?; - let config_store = ConfigStore::new(db.clone_pool())?; - config_store.load().await?; - let mut config = (*config_store.get()).clone(); - - let mut msd_dir_updated = false; - if config.msd.msd_dir.trim().is_empty() { - let msd_dir = data_dir.join("msd"); - config.msd.msd_dir = msd_dir.to_string_lossy().to_string(); - msd_dir_updated = true; - } else if !PathBuf::from(&config.msd.msd_dir).is_absolute() { - let msd_dir = data_dir.join(&config.msd.msd_dir); - tracing::warn!( - "MSD directory is relative, rebasing to {}", - msd_dir.display() - ); - config.msd.msd_dir = msd_dir.to_string_lossy().to_string(); - msd_dir_updated = true; - } - if msd_dir_updated { - config_store.set(config.clone()).await?; - } - let msd_dir = PathBuf::from(&config.msd.msd_dir); - if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("images")).await { - tracing::warn!("Failed to create MSD images directory: {}", e); - } - if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("ventoy")).await { - tracing::warn!("Failed to create MSD ventoy directory: {}", e); - } + let (db, config_store, mut config) = load_runtime_config(&data_dir).await?; if let Some(addr) = args.address { config.web.bind_address = addr.clone(); @@ -311,9 +290,12 @@ async fn main() -> anyhow::Result<()> { }; tracing::info!("WebRTC streamer created"); + #[cfg(unix)] let otg_service = Arc::new(OtgService::new()); + #[cfg(unix)] tracing::info!("OTG Service created"); + #[cfg(unix)] if let Err(e) = otg_service.apply_config(&config.hid, &config.msd).await { tracing::warn!("Failed to apply OTG config: {}", e); } @@ -326,12 +308,16 @@ async fn main() -> anyhow::Result<()> { }, config::HidBackend::None => HidBackendType::None, }; + #[cfg(unix)] let hid = Arc::new(HidController::new(hid_backend, Some(otg_service.clone()))); + #[cfg(not(unix))] + let hid = Arc::new(HidController::new(hid_backend)); hid.set_event_bus(events.clone()).await; if let Err(e) = hid.init().await { tracing::warn!("Failed to initialize HID backend: {}", e); } + #[cfg(unix)] let msd = if config.msd.enabled { let ventoy_resource_dir = data_dir.join("ventoy"); if ventoy_resource_dir.exists() { @@ -544,10 +530,12 @@ async fn main() -> anyhow::Result<()> { config_store.clone(), session_store, user_store, + #[cfg(unix)] otg_service, stream_manager, webrtc_streamer.clone(), hid, + #[cfg(unix)] msd, atx, audio, @@ -703,12 +691,12 @@ fn init_logging(level: LogLevel, verbose_count: u8) { }; let filter = match effective_level { - LogLevel::Error => "one_kvm=error,tower_http=error", - LogLevel::Warn => "one_kvm=warn,tower_http=warn", - LogLevel::Info => "one_kvm=info,tower_http=info", - LogLevel::Verbose => "one_kvm=debug,tower_http=info", - LogLevel::Debug => "one_kvm=debug,tower_http=debug", - LogLevel::Trace => "one_kvm=trace,tower_http=debug", + LogLevel::Error => "one_kvm=error,tower_http=error,webrtc_sctp=warn", + LogLevel::Warn => "one_kvm=warn,tower_http=warn,webrtc_sctp=warn", + LogLevel::Info => "one_kvm=info,tower_http=info,webrtc_sctp=warn", + LogLevel::Verbose => "one_kvm=debug,tower_http=info,webrtc_sctp=warn", + LogLevel::Debug => "one_kvm=debug,tower_http=debug,webrtc_sctp=warn", + LogLevel::Trace => "one_kvm=trace,tower_http=debug,webrtc_sctp=warn", }; let env_filter = @@ -728,6 +716,19 @@ fn get_data_dir() -> PathBuf { return PathBuf::from(path); } + #[cfg(windows)] + { + if let Ok(exe_path) = std::env::current_exe() { + if let Some(exe_dir) = exe_path.parent() { + return exe_dir.join("one-kvm"); + } + } + return std::env::current_dir() + .map(|dir| dir.join("one-kvm")) + .unwrap_or_else(|_| PathBuf::from("one-kvm")); + } + + #[cfg(not(windows))] PathBuf::from("/etc/one-kvm") } @@ -771,6 +772,64 @@ async fn run_cli_command(command: CliCommand, data_dir: PathBuf) -> anyhow::Resu } } +async fn load_runtime_config( + data_dir: &Path, +) -> anyhow::Result<(DatabasePool, ConfigStore, AppConfig)> { + tokio::fs::create_dir_all(data_dir).await?; + + let db = open_database_pool(data_dir).await?; + let config_store = ConfigStore::new(db.clone_pool())?; + config_store.load().await?; + let mut config = (*config_store.get()).clone(); + config.apply_platform_defaults(); + + prepare_linux_runtime_dirs(data_dir, &config_store, &mut config).await?; + + Ok((db, config_store, config)) +} + +#[cfg(unix)] +async fn prepare_linux_runtime_dirs( + data_dir: &Path, + config_store: &ConfigStore, + config: &mut AppConfig, +) -> anyhow::Result<()> { + let mut msd_dir_updated = false; + if config.msd.msd_dir.trim().is_empty() { + let msd_dir = data_dir.join("msd"); + config.msd.msd_dir = msd_dir.to_string_lossy().to_string(); + msd_dir_updated = true; + } else if !PathBuf::from(&config.msd.msd_dir).is_absolute() { + let msd_dir = data_dir.join(&config.msd.msd_dir); + tracing::warn!( + "MSD directory is relative, rebasing to {}", + msd_dir.display() + ); + config.msd.msd_dir = msd_dir.to_string_lossy().to_string(); + msd_dir_updated = true; + } + if msd_dir_updated { + config_store.set(config.clone()).await?; + } + let msd_dir = PathBuf::from(&config.msd.msd_dir); + if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("images")).await { + tracing::warn!("Failed to create MSD images directory: {}", e); + } + if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("ventoy")).await { + tracing::warn!("Failed to create MSD ventoy directory: {}", e); + } + Ok(()) +} + +#[cfg(not(unix))] +async fn prepare_linux_runtime_dirs( + _data_dir: &Path, + _config_store: &ConfigStore, + _config: &mut AppConfig, +) -> anyhow::Result<()> { + Ok(()) +} + async fn run_user_action( action: UserAction, users: &UserStore, @@ -1048,6 +1107,7 @@ async fn cleanup(state: &Arc) { tracing::warn!("Failed to shutdown HID: {}", e); } + #[cfg(unix)] if let Some(msd) = state.msd.write().await.as_mut() { if let Err(e) = msd.shutdown().await { tracing::warn!("Failed to shutdown MSD: {}", e); diff --git a/src/otg/mod.rs b/src/otg/mod.rs index 9d6924e1..4d0deed8 100644 --- a/src/otg/mod.rs +++ b/src/otg/mod.rs @@ -1,14 +1,37 @@ //! USB OTG composite gadget (HID + MSD). +#[cfg(unix)] pub mod configfs; pub mod endpoint; +#[cfg(unix)] pub mod function; +#[cfg(unix)] pub mod hid; +#[cfg(unix)] pub mod manager; +#[cfg(unix)] pub mod msd; pub mod report_desc; +pub mod self_check; +#[cfg(unix)] pub mod service; +#[cfg(unix)] pub use manager::{wait_for_hid_devices, OtgGadgetManager}; +#[cfg(unix)] pub use msd::{MsdFunction, MsdLunConfig}; +#[cfg(unix)] pub use service::{HidDevicePaths, OtgService}; + +/// List USB Device Controller names exposed by sysfs. +pub fn list_udc_devices() -> Vec { + let mut devices: Vec = std::fs::read_dir("/sys/class/udc") + .ok() + .into_iter() + .flat_map(|entries| entries.filter_map(|entry| entry.ok())) + .filter_map(|entry| entry.file_name().to_str().map(str::to_owned)) + .collect(); + + devices.sort(); + devices +} diff --git a/src/otg/self_check.rs b/src/otg/self_check.rs new file mode 100644 index 00000000..a2180196 --- /dev/null +++ b/src/otg/self_check.rs @@ -0,0 +1,740 @@ +use serde::Serialize; + +use crate::utils::{list_dir_names, read_trimmed}; + +#[derive(Serialize, Clone, Copy, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum OtgSelfCheckLevel { + Info, + Warn, + Error, +} + +#[derive(Serialize)] +pub struct OtgSelfCheckItem { + pub id: &'static str, + pub ok: bool, + pub level: OtgSelfCheckLevel, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub hint: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, +} + +#[derive(Serialize)] +pub struct OtgSelfCheckResponse { + pub overall_ok: bool, + pub error_count: usize, + pub warning_count: usize, + pub hid_backend: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub selected_udc: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub bound_udc: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub udc_state: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub udc_speed: Option, + pub available_udcs: Vec, + pub other_gadgets: Vec, + pub checks: Vec, +} + +fn push_otg_check( + checks: &mut Vec, + id: &'static str, + ok: bool, + level: OtgSelfCheckLevel, + message: impl Into, + hint: Option>, + path: Option>, +) { + checks.push(OtgSelfCheckItem { + id, + ok, + level, + message: message.into(), + hint: hint.map(|v| v.into()), + path: path.map(|v| v.into()), + }); +} + +fn proc_modules_has(module_name: &str) -> bool { + std::fs::read_to_string("/proc/modules") + .ok() + .map(|content| { + content + .lines() + .filter_map(|line| line.split_whitespace().next()) + .any(|name| name == module_name) + }) + .unwrap_or(false) +} + +fn modules_metadata_has(module_name: &str) -> bool { + let kernel_release = match read_trimmed(std::path::Path::new("/proc/sys/kernel/osrelease")) { + Some(value) if !value.is_empty() => value, + _ => return false, + }; + + let module_dir = std::path::Path::new("/lib/modules").join(kernel_release); + let candidates = ["modules.builtin", "modules.builtin.modinfo", "modules.dep"]; + + candidates.iter().any(|filename| { + let path = module_dir.join(filename); + std::fs::read_to_string(path) + .ok() + .map(|content| { + let module_token = format!("/{module_name}.ko"); + content.lines().any(|line| { + line.contains(&module_token) + || line.contains(module_name) + || line.contains(&module_name.replace('_', "-")) + }) + }) + .unwrap_or(false) + }) +} + +fn kernel_config_option_enabled(option_name: &str) -> bool { + let kernel_release = match read_trimmed(std::path::Path::new("/proc/sys/kernel/osrelease")) { + Some(value) if !value.is_empty() => value, + _ => return false, + }; + + let config_paths = [ + std::path::PathBuf::from(format!("/boot/config-{kernel_release}")), + std::path::PathBuf::from("/boot/config"), + std::path::PathBuf::from(format!("/lib/modules/{kernel_release}/build/.config")), + ]; + + config_paths.iter().any(|path| { + std::fs::read_to_string(path) + .ok() + .map(|content| { + let enabled_y = format!("{option_name}=y"); + let enabled_m = format!("{option_name}=m"); + content + .lines() + .any(|line| line == enabled_y || line == enabled_m) + }) + .unwrap_or(false) + }) +} + +fn detect_libcomposite_available(gadget_root: &std::path::Path) -> bool { + let sys_module = std::path::Path::new("/sys/module/libcomposite").exists(); + if sys_module { + return true; + } + + if proc_modules_has("libcomposite") { + return true; + } + + if modules_metadata_has("libcomposite") { + return true; + } + + if kernel_config_option_enabled("CONFIG_USB_LIBCOMPOSITE") + || kernel_config_option_enabled("CONFIG_USB_CONFIGFS") + { + return true; + } + + // Fallback: if usb_gadget path exists, libcomposite may be built-in and already active. + gadget_root.exists() +} + +/// OTG self-check status for troubleshooting USB gadget issues +pub fn run(config: &crate::config::AppConfig) -> OtgSelfCheckResponse { + let hid_backend_is_otg = matches!(config.hid.backend, crate::config::HidBackend::Otg); + let mut checks = Vec::new(); + + let build_response = |checks: Vec, + selected_udc: Option, + bound_udc: Option, + udc_state: Option, + udc_speed: Option, + available_udcs: Vec, + other_gadgets: Vec| { + let error_count = checks + .iter() + .filter(|item| item.level == OtgSelfCheckLevel::Error) + .count(); + let warning_count = checks + .iter() + .filter(|item| item.level == OtgSelfCheckLevel::Warn) + .count(); + + OtgSelfCheckResponse { + overall_ok: error_count == 0, + error_count, + warning_count, + hid_backend: format!("{:?}", config.hid.backend).to_lowercase(), + selected_udc, + bound_udc, + udc_state, + udc_speed, + available_udcs, + other_gadgets, + checks, + } + }; + + let udc_root = std::path::Path::new("/sys/class/udc"); + let available_udcs = list_dir_names(udc_root); + let selected_udc = config + .hid + .otg_udc + .clone() + .filter(|udc| !udc.trim().is_empty()) + .or_else(|| available_udcs.first().cloned()); + let mut udc_stage_ok = true; + if !udc_root.exists() { + udc_stage_ok = false; + push_otg_check( + &mut checks, + "udc_dir_exists", + false, + OtgSelfCheckLevel::Error, + "Check /sys/class/udc existence", + Some("Ensure UDC/OTG kernel drivers are enabled"), + Some("/sys/class/udc"), + ); + } else if available_udcs.is_empty() { + udc_stage_ok = false; + push_otg_check( + &mut checks, + "udc_has_entries", + false, + OtgSelfCheckLevel::Error, + "Check available UDC entries", + Some("Ensure OTG controller is enabled in device tree"), + Some("/sys/class/udc"), + ); + } else { + push_otg_check( + &mut checks, + "udc_has_entries", + true, + OtgSelfCheckLevel::Info, + "Check available UDC entries", + None::, + Some("/sys/class/udc"), + ); + } + + let mut configured_udc_ok = true; + if let Some(config_udc) = config + .hid + .otg_udc + .clone() + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + { + if available_udcs.iter().any(|item| item == &config_udc) { + push_otg_check( + &mut checks, + "configured_udc_valid", + true, + OtgSelfCheckLevel::Info, + "Check configured UDC validity", + None::, + Some("/sys/class/udc"), + ); + } else { + configured_udc_ok = false; + push_otg_check( + &mut checks, + "configured_udc_valid", + false, + OtgSelfCheckLevel::Error, + "Check configured UDC validity", + Some("Please reselect UDC in HID OTG settings"), + Some("/sys/class/udc"), + ); + } + } else { + push_otg_check( + &mut checks, + "configured_udc_valid", + !available_udcs.is_empty(), + if available_udcs.is_empty() { + OtgSelfCheckLevel::Warn + } else { + OtgSelfCheckLevel::Info + }, + "Check configured UDC validity", + Some( + "You can set hid_otg_udc in settings to avoid ambiguity in multi-controller setups", + ), + Some("/sys/class/udc"), + ); + } + + if !udc_stage_ok || !configured_udc_ok { + return build_response( + checks, + selected_udc, + None, + None, + None, + available_udcs, + vec![], + ); + } + + let gadget_root = std::path::Path::new("/sys/kernel/config/usb_gadget"); + let configfs_mounted = std::fs::read_to_string("/proc/mounts") + .ok() + .map(|mounts| { + mounts.lines().any(|line| { + let mut parts = line.split_whitespace(); + let _src = parts.next(); + let mount_point = parts.next(); + let fs_type = parts.next(); + mount_point == Some("/sys/kernel/config") && fs_type == Some("configfs") + }) + }) + .unwrap_or(false); + + let mut gadget_config_ok = true; + + if configfs_mounted { + push_otg_check( + &mut checks, + "configfs_mounted", + true, + OtgSelfCheckLevel::Info, + "Check configfs mount status", + None::, + Some("/sys/kernel/config"), + ); + } else { + gadget_config_ok = false; + push_otg_check( + &mut checks, + "configfs_mounted", + false, + OtgSelfCheckLevel::Error, + "Check configfs mount status", + Some("Try: mount -t configfs none /sys/kernel/config"), + Some("/sys/kernel/config"), + ); + } + + if gadget_root.exists() { + push_otg_check( + &mut checks, + "usb_gadget_dir_exists", + true, + OtgSelfCheckLevel::Info, + "Check /sys/kernel/config/usb_gadget access", + None::, + Some("/sys/kernel/config/usb_gadget"), + ); + } else { + gadget_config_ok = false; + push_otg_check( + &mut checks, + "usb_gadget_dir_exists", + false, + OtgSelfCheckLevel::Error, + "Check /sys/kernel/config/usb_gadget access", + Some("Ensure configfs and USB gadget support are enabled"), + Some("/sys/kernel/config/usb_gadget"), + ); + } + + let libcomposite_available = detect_libcomposite_available(gadget_root); + if libcomposite_available { + push_otg_check( + &mut checks, + "libcomposite_loaded", + true, + OtgSelfCheckLevel::Info, + "Check libcomposite module status", + None::, + Some("/sys/module/libcomposite"), + ); + } else { + gadget_config_ok = false; + push_otg_check( + &mut checks, + "libcomposite_loaded", + false, + OtgSelfCheckLevel::Error, + "Check libcomposite module status", + Some("Try: modprobe libcomposite"), + Some("/sys/module/libcomposite"), + ); + } + + if !gadget_config_ok { + return build_response( + checks, + selected_udc, + None, + None, + None, + available_udcs, + vec![], + ); + } + + let gadget_names = list_dir_names(gadget_root); + let one_kvm_path = gadget_root.join("one-kvm"); + let one_kvm_exists = one_kvm_path.exists(); + if one_kvm_exists { + push_otg_check( + &mut checks, + "one_kvm_gadget_exists", + true, + OtgSelfCheckLevel::Info, + "Check one-kvm gadget presence", + None::, + Some(one_kvm_path.display().to_string()), + ); + } else { + push_otg_check( + &mut checks, + "one_kvm_gadget_exists", + false, + if hid_backend_is_otg { + OtgSelfCheckLevel::Error + } else { + OtgSelfCheckLevel::Warn + }, + "Check one-kvm gadget presence", + Some("Enable OTG HID or MSD to let one-kvm gadget be created automatically"), + Some(one_kvm_path.display().to_string()), + ); + } + + let other_gadgets = gadget_names + .iter() + .filter(|name| name.as_str() != "one-kvm") + .cloned() + .collect::>(); + if other_gadgets.is_empty() { + push_otg_check( + &mut checks, + "other_gadgets", + true, + OtgSelfCheckLevel::Info, + "Check for other gadget services", + None::, + Some("/sys/kernel/config/usb_gadget"), + ); + } else { + push_otg_check( + &mut checks, + "other_gadgets", + false, + OtgSelfCheckLevel::Warn, + "Check for other gadget services", + Some("Potential UDC contention with one-kvm; check other OTG services"), + Some("/sys/kernel/config/usb_gadget"), + ); + } + + let mut bound_udc = None; + + if one_kvm_exists { + let one_kvm_udc_path = one_kvm_path.join("UDC"); + let current_udc = read_trimmed(&one_kvm_udc_path).unwrap_or_default(); + if current_udc.is_empty() { + push_otg_check( + &mut checks, + "one_kvm_bound_udc", + false, + OtgSelfCheckLevel::Warn, + "Check one-kvm UDC binding", + Some("Ensure HID/MSD is enabled and initialized successfully"), + Some(one_kvm_udc_path.display().to_string()), + ); + } else { + push_otg_check( + &mut checks, + "one_kvm_bound_udc", + true, + OtgSelfCheckLevel::Info, + "Check one-kvm UDC binding", + None::, + Some(one_kvm_udc_path.display().to_string()), + ); + bound_udc = Some(current_udc); + } + + let functions_path = one_kvm_path.join("functions"); + let function_names = list_dir_names(&functions_path) + .into_iter() + .filter(|name| name.contains(".usb")) + .collect::>(); + let hid_functions = function_names + .iter() + .filter(|name| name.starts_with("hid.usb")) + .cloned() + .collect::>(); + if hid_functions.is_empty() { + push_otg_check( + &mut checks, + "hid_functions_present", + false, + if hid_backend_is_otg { + OtgSelfCheckLevel::Error + } else { + OtgSelfCheckLevel::Warn + }, + "Check HID function creation", + Some("Check OTG HID config and enable at least one HID function"), + Some(functions_path.display().to_string()), + ); + } else { + push_otg_check( + &mut checks, + "hid_functions_present", + true, + OtgSelfCheckLevel::Info, + "Check HID function creation", + None::, + Some(functions_path.display().to_string()), + ); + } + + let config_path = one_kvm_path.join("configs/c.1"); + if !config_path.exists() { + push_otg_check( + &mut checks, + "config_c1_exists", + false, + OtgSelfCheckLevel::Error, + "Check configs/c.1 structure", + Some("Gadget structure is incomplete; try restarting One-KVM"), + Some(config_path.display().to_string()), + ); + } else { + push_otg_check( + &mut checks, + "config_c1_exists", + true, + OtgSelfCheckLevel::Info, + "Check configs/c.1 structure", + None::, + Some(config_path.display().to_string()), + ); + + let linked_functions = list_dir_names(&config_path) + .into_iter() + .filter(|name| name.contains(".usb")) + .collect::>(); + let missing_links = function_names + .iter() + .filter(|func| !linked_functions.iter().any(|link| link == *func)) + .cloned() + .collect::>(); + + if missing_links.is_empty() { + push_otg_check( + &mut checks, + "function_links_ok", + true, + OtgSelfCheckLevel::Info, + "Check function links in configs/c.1", + None::, + Some(config_path.display().to_string()), + ); + } else { + push_otg_check( + &mut checks, + "function_links_ok", + false, + OtgSelfCheckLevel::Warn, + "Check function links in configs/c.1", + Some("Reinitialize OTG (toggle HID backend once or restart service)"), + Some(config_path.display().to_string()), + ); + } + } + + let missing_hid_devices = hid_functions + .iter() + .filter_map(|name| { + let index = name.strip_prefix("hid.usb")?.parse::().ok()?; + let dev_path = std::path::PathBuf::from(format!("/dev/hidg{}", index)); + if dev_path.exists() { + None + } else { + Some(dev_path.display().to_string()) + } + }) + .collect::>(); + + if !hid_functions.is_empty() { + if missing_hid_devices.is_empty() { + push_otg_check( + &mut checks, + "hid_device_nodes", + true, + OtgSelfCheckLevel::Info, + "Check /dev/hidg* device nodes", + None::, + Some("/dev/hidg*"), + ); + } else { + push_otg_check( + &mut checks, + "hid_device_nodes", + false, + OtgSelfCheckLevel::Warn, + "Check /dev/hidg* device nodes", + Some("Ensure gadget is bound and check kernel logs"), + Some("/dev/hidg*"), + ); + } + } + } + + if !other_gadgets.is_empty() { + let check_udc = bound_udc.clone().or_else(|| selected_udc.clone()); + if let Some(target_udc) = check_udc { + let conflicting_gadgets = other_gadgets + .iter() + .filter_map(|name| { + let udc_file = gadget_root.join(name).join("UDC"); + let udc = read_trimmed(&udc_file)?; + if udc == target_udc { + Some(name.clone()) + } else { + None + } + }) + .collect::>(); + + if conflicting_gadgets.is_empty() { + push_otg_check( + &mut checks, + "udc_conflict", + true, + OtgSelfCheckLevel::Info, + "Check UDC binding conflicts", + None::, + Some("/sys/kernel/config/usb_gadget/*/UDC"), + ); + } else { + push_otg_check( + &mut checks, + "udc_conflict", + false, + OtgSelfCheckLevel::Error, + "Check UDC binding conflicts", + Some("Stop other OTG services or switch one-kvm to an idle UDC"), + Some("/sys/kernel/config/usb_gadget/*/UDC"), + ); + } + } + } + + let active_udc = bound_udc.clone().or_else(|| selected_udc.clone()); + let mut udc_state = None; + let mut udc_speed = None; + + if let Some(udc) = active_udc.clone() { + let state_path = udc_root.join(&udc).join("state"); + match read_trimmed(&state_path) { + Some(state_name) if state_name.eq_ignore_ascii_case("configured") => { + udc_state = Some(state_name.clone()); + push_otg_check( + &mut checks, + "udc_state", + true, + OtgSelfCheckLevel::Info, + "Check UDC connection state", + None::, + Some(state_path.display().to_string()), + ); + } + Some(state_name) => { + udc_state = Some(state_name.clone()); + push_otg_check( + &mut checks, + "udc_state", + false, + OtgSelfCheckLevel::Warn, + "Check UDC connection state", + Some("Ensure target host is connected and has recognized the USB device"), + Some(state_path.display().to_string()), + ); + } + None => { + push_otg_check( + &mut checks, + "udc_state", + false, + OtgSelfCheckLevel::Warn, + "Check UDC connection state", + Some("Ensure UDC name is valid and check kernel permissions"), + Some(state_path.display().to_string()), + ); + } + } + + let speed_path = udc_root.join(&udc).join("current_speed"); + if let Some(speed) = read_trimmed(&speed_path) { + udc_speed = Some(speed.clone()); + let is_unknown = speed.eq_ignore_ascii_case("unknown"); + push_otg_check( + &mut checks, + "udc_speed", + !is_unknown, + if is_unknown { + OtgSelfCheckLevel::Warn + } else { + OtgSelfCheckLevel::Info + }, + "Check UDC current link speed", + if is_unknown { + Some("Device may not be fully enumerated; try reconnecting USB".to_string()) + } else { + None + }, + Some(speed_path.display().to_string()), + ); + } + } else { + push_otg_check( + &mut checks, + "udc_state", + false, + OtgSelfCheckLevel::Warn, + "Check UDC connection state", + Some("Ensure UDC is available and one-kvm gadget is bound first"), + Some("/sys/class/udc"), + ); + } + + let error_count = checks + .iter() + .filter(|item| item.level == OtgSelfCheckLevel::Error) + .count(); + let warning_count = checks + .iter() + .filter(|item| item.level == OtgSelfCheckLevel::Warn) + .count(); + + OtgSelfCheckResponse { + overall_ok: error_count == 0, + error_count, + warning_count, + hid_backend: format!("{:?}", config.hid.backend).to_lowercase(), + selected_udc, + bound_udc, + udc_state, + udc_speed, + available_udcs, + other_gadgets, + checks, + } +} diff --git a/src/platform/capabilities.rs b/src/platform/capabilities.rs new file mode 100644 index 00000000..494f77bb --- /dev/null +++ b/src/platform/capabilities.rs @@ -0,0 +1,89 @@ +//! Runtime platform mode and feature capability reporting. + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PlatformMode { + Linux, + Windows, +} + +impl PlatformMode { + pub const fn current() -> Self { + if cfg!(windows) { + Self::Windows + } else { + Self::Linux + } + } + + pub const fn label(self) -> &'static str { + match self { + Self::Linux => "Linux", + Self::Windows => "Windows", + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FeatureCapability { + pub available: bool, + pub backends: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub selected_backend: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, +} + +impl FeatureCapability { + pub fn available(backends: impl IntoIterator>) -> Self { + let backends = backends.into_iter().map(Into::into).collect(); + Self { + available: true, + backends, + selected_backend: None, + reason: None, + } + } + + pub fn unsupported(reason: impl Into) -> Self { + Self { + available: false, + backends: Vec::new(), + selected_backend: None, + reason: Some(reason.into()), + } + } + + pub fn with_selected_backend(mut self, backend: Option) -> Self { + self.selected_backend = backend; + self + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlatformCapabilities { + pub mode: PlatformMode, + pub mode_label: &'static str, + pub video_capture: FeatureCapability, + pub encoder: FeatureCapability, + pub hid: FeatureCapability, + pub atx: FeatureCapability, + pub msd: FeatureCapability, + pub otg: FeatureCapability, + pub audio: FeatureCapability, + pub rustdesk: FeatureCapability, + pub diagnostics: FeatureCapability, + pub extensions: FeatureCapability, + pub service_installation: FeatureCapability, +} + +impl PlatformCapabilities { + pub fn current() -> Self { + match PlatformMode::current() { + PlatformMode::Linux => crate::platform::linux::capabilities(), + PlatformMode::Windows => crate::platform::windows::capabilities(), + } + } +} diff --git a/src/platform/defaults.rs b/src/platform/defaults.rs new file mode 100644 index 00000000..90b77230 --- /dev/null +++ b/src/platform/defaults.rs @@ -0,0 +1,62 @@ +use crate::config::{AppConfig, AtxDriverType, HidBackend}; + +pub fn apply(config: &mut AppConfig) { + if cfg!(windows) { + apply_windows(config); + } +} + +fn apply_windows(config: &mut AppConfig) { + config.msd.enabled = false; + config.hid.otg_udc = None; + if config.hid.backend == HidBackend::Otg { + config.hid.backend = HidBackend::None; + } + if config.hid.ch9329_port == "/dev/ttyUSB0" { + config.hid.ch9329_port = "COM3".to_string(); + } + if !config.initialized { + config.audio.enabled = false; + config.audio.device.clear(); + } + + if matches!( + config.atx.power.driver, + AtxDriverType::Gpio | AtxDriverType::UsbRelay + ) { + config.atx.power.driver = AtxDriverType::None; + } + if matches!( + config.atx.reset.driver, + AtxDriverType::Gpio | AtxDriverType::UsbRelay + ) { + config.atx.reset.driver = AtxDriverType::None; + } + if !config.initialized + && config.atx.power.driver == AtxDriverType::None + && config.atx.power.device.is_empty() + { + config.atx.power.driver = AtxDriverType::Serial; + config.atx.power.device = "COM4".to_string(); + config.atx.power.pin = 1; + config.atx.power.baud_rate = 9600; + } + if !config.initialized + && config.atx.reset.driver == AtxDriverType::None + && config.atx.reset.device.is_empty() + { + config.atx.reset.driver = AtxDriverType::Serial; + config.atx.reset.device = "COM4".to_string(); + config.atx.reset.pin = 2; + config.atx.reset.baud_rate = 9600; + } + + config + .video + .device + .get_or_insert_with(|| "auto".to_string()); + config + .video + .format + .get_or_insert_with(|| "MJPEG".to_string()); +} diff --git a/src/platform/linux.rs b/src/platform/linux.rs new file mode 100644 index 00000000..0035e249 --- /dev/null +++ b/src/platform/linux.rs @@ -0,0 +1,23 @@ +//! Linux platform capabilities. + +use super::{FeatureCapability, PlatformCapabilities, PlatformMode}; + +pub fn capabilities() -> PlatformCapabilities { + PlatformCapabilities { + mode: PlatformMode::Linux, + mode_label: PlatformMode::Linux.label(), + video_capture: FeatureCapability::available(["v4l2"]), + encoder: FeatureCapability::available([ + "software", "vaapi", "nvenc", "qsv", "amf", "rkmpp", "v4l2m2m", + ]), + hid: FeatureCapability::available(["otg", "ch9329", "none"]), + atx: FeatureCapability::available(["gpio", "usb_relay", "serial", "wol", "none"]), + msd: FeatureCapability::available(["configfs"]), + otg: FeatureCapability::available(["configfs"]), + audio: FeatureCapability::available(["alsa"]), + rustdesk: FeatureCapability::available(["builtin"]), + diagnostics: FeatureCapability::available(["linux"]), + extensions: FeatureCapability::available(["linux"]), + service_installation: FeatureCapability::available(["systemd"]), + } +} diff --git a/src/platform/mod.rs b/src/platform/mod.rs new file mode 100644 index 00000000..06c876d7 --- /dev/null +++ b/src/platform/mod.rs @@ -0,0 +1,10 @@ +//! Platform selection and capability reporting. + +pub mod capabilities; +pub mod defaults; +pub mod linux; +#[cfg(unix)] +pub mod usb_reset; +pub mod windows; + +pub use capabilities::{FeatureCapability, PlatformCapabilities, PlatformMode}; diff --git a/src/video/usb_reset.rs b/src/platform/usb_reset.rs similarity index 75% rename from src/video/usb_reset.rs rename to src/platform/usb_reset.rs index d2dc9443..826d4745 100644 --- a/src/video/usb_reset.rs +++ b/src/platform/usb_reset.rs @@ -1,14 +1,9 @@ //! USB device enumeration and reset via sysfs `authorized`. -//! -//! Provides APIs for the settings page to list and reset USB devices. -//! Requires write access to `/sys/bus/usb/devices/.../authorized` (typically root). use std::io; use std::path::{Path, PathBuf}; use std::time::{Duration, Instant}; -/// Walk up from a V4L sysfs `device` link until we find a USB device node -/// (`busnum` + `devnum` present). fn usb_device_dir_for_v4l_sysfs(device_link: &Path) -> io::Result { let mut p = device_link.canonicalize()?; loop { @@ -28,57 +23,38 @@ fn usb_device_dir_for_v4l_sysfs(device_link: &Path) -> io::Result { } } -// --------------------------------------------------------------------------- -// USB device enumeration & reset-by-bus/dev (for the settings API) -// --------------------------------------------------------------------------- - use serde::Serialize; -/// Information about a single USB device, read from `/sys/bus/usb/devices/`. #[derive(Debug, Serialize)] pub struct UsbDeviceInfo { - /// USB bus number (`busnum` sysfs attribute). pub bus_num: u32, - /// USB device number on the bus (`devnum` sysfs attribute). pub dev_num: u32, - /// Vendor ID hex string, e.g. `"1d6b"`. pub id_vendor: String, - /// Product ID hex string, e.g. `"0002"`. pub id_product: String, - /// Product name from sysfs `product`. #[serde(skip_serializing_if = "Option::is_none")] pub product: Option, - /// Manufacturer name from sysfs `manufacturer`. #[serde(skip_serializing_if = "Option::is_none")] pub manufacturer: Option, - /// Speed in Mbps from sysfs `speed`, e.g. `"480"`. #[serde(skip_serializing_if = "Option::is_none")] pub speed: Option, - /// `true` if authorized=1, `false` if authorized=0, `None` if no file. #[serde(skip_serializing_if = "Option::is_none")] pub authorized: Option, - /// Kernel driver bound to this device (from driver symlink). #[serde(skip_serializing_if = "Option::is_none")] pub driver: Option, - /// Associated `/dev/videoN` node, if this USB device has a V4L2 child. #[serde(skip_serializing_if = "Option::is_none")] pub video_device: Option, } -/// Read a sysfs string attribute, trimming trailing newline. fn read_sysfs_str(dir: &Path, attr: &str) -> Option { std::fs::read_to_string(dir.join(attr)) .ok() .map(|s| s.trim_end().to_string()) } -/// Read a sysfs u32 attribute. fn read_sysfs_u32(dir: &Path, attr: &str) -> Option { read_sysfs_str(dir, attr).and_then(|s| s.parse().ok()) } -/// Build a map from USB sysfs dir → video device name by scanning -/// `/sys/class/video4linux/`. fn build_usb_to_video_map() -> std::collections::HashMap { let mut map = std::collections::HashMap::new(); let v4l_class = Path::new("/sys/class/video4linux"); @@ -92,7 +68,6 @@ fn build_usb_to_video_map() -> std::collections::HashMap { Some(s) if s.starts_with("video") => s, _ => continue, }; - // Resolve the device symlink and walk up to find the USB parent let device_link = v4l_class.join(name_str).join("device"); if let Ok(usb_dir) = usb_device_dir_for_v4l_sysfs(&device_link) { if let Some(key) = usb_dir.file_name().and_then(|k| k.to_str()) { @@ -103,7 +78,6 @@ fn build_usb_to_video_map() -> std::collections::HashMap { map } -/// List all USB devices visible in `/sys/bus/usb/devices/`. pub fn list_usb_devices() -> Vec { let usb_bus = Path::new("/sys/bus/usb/devices"); let entries = match std::fs::read_dir(usb_bus) { @@ -117,7 +91,6 @@ pub fn list_usb_devices() -> Vec { .flatten() .filter_map(|entry| { let dir = entry.path(); - // Only consider entries that have busnum + devnum (actual devices, not interfaces) let bus_num = read_sysfs_u32(&dir, "busnum")?; let dev_num = read_sysfs_u32(&dir, "devnum")?; @@ -158,13 +131,10 @@ pub fn list_usb_devices() -> Vec { }) .collect(); - // Sort by bus, then device number for stable ordering. devices.sort_by(|a, b| (a.bus_num, a.dev_num).cmp(&(b.bus_num, b.dev_num))); devices } -/// Reset a USB device identified by bus/dev numbers via the `authorized` sysfs -/// attribute. After re-authorizing, waits for the device to reappear. pub fn reset_usb_device(bus_num: u32, dev_num: u32) -> io::Result<()> { let usb_bus = Path::new("/sys/bus/usb/devices"); let entries = std::fs::read_dir(usb_bus)?; @@ -187,7 +157,6 @@ pub fn reset_usb_device(bus_num: u32, dev_num: u32) -> io::Result<()> { std::thread::sleep(Duration::from_millis(300)); std::fs::write(&authorized, b"1")?; - // Wait for device to reappear let wait_until = Instant::now() + Duration::from_secs(2); while !dir.join("busnum").exists() { if Instant::now() >= wait_until { diff --git a/src/platform/windows.rs b/src/platform/windows.rs new file mode 100644 index 00000000..25e91939 --- /dev/null +++ b/src/platform/windows.rs @@ -0,0 +1,33 @@ +//! Windows platform capabilities. + +use super::{FeatureCapability, PlatformCapabilities, PlatformMode}; + +pub fn capabilities() -> PlatformCapabilities { + let linux_only = "unsupported on Windows"; + PlatformCapabilities { + mode: PlatformMode::Windows, + mode_label: PlatformMode::Windows.label(), + video_capture: FeatureCapability::available(["directshow_uvc", "mjpeg"]) + .with_selected_backend(Some("directshow_uvc".to_string())), + encoder: FeatureCapability::available([ + "ffmpeg_h264", + "ffmpeg_h265", + "ffmpeg_vp8", + "ffmpeg_vp9", + "software", + "mjpeg", + ]), + hid: FeatureCapability::available(["ch9329", "none"]) + .with_selected_backend(Some("ch9329".to_string())), + atx: FeatureCapability::available(["serial", "wol", "none"]), + msd: FeatureCapability::unsupported(linux_only), + otg: FeatureCapability::unsupported(linux_only), + audio: FeatureCapability::available(["wasapi", "opus"]) + .with_selected_backend(Some("wasapi".to_string())), + rustdesk: FeatureCapability::available(["builtin", "tcp_direct", "relay"]) + .with_selected_backend(Some("builtin".to_string())), + diagnostics: FeatureCapability::available(["windows"]), + extensions: FeatureCapability::available(["windows_safe"]), + service_installation: FeatureCapability::available(["windows_service"]), + } +} diff --git a/src/redfish/auth.rs b/src/redfish/auth.rs index b1f61b35..cbac4de8 100644 --- a/src/redfish/auth.rs +++ b/src/redfish/auth.rs @@ -62,10 +62,8 @@ pub async fn redfish_auth_middleware( } fn is_redfish_public_endpoint(path: &str, method: &Method) -> bool { - matches!( - path, - "/" | "/v1" | "/v1/" | "/v1/odata" - ) || path.starts_with("/v1/$metadata") + matches!(path, "/" | "/v1" | "/v1/" | "/v1/odata") + || path.starts_with("/v1/$metadata") || (path == "/v1/SessionService/Sessions" && *method == Method::POST) } diff --git a/src/redfish/routes/account.rs b/src/redfish/routes/account.rs index c49253db..45fbcb05 100644 --- a/src/redfish/routes/account.rs +++ b/src/redfish/routes/account.rs @@ -8,29 +8,20 @@ use axum::{ use std::sync::Arc; -use super::{empty_collection, resource_not_found}; use super::super::schema::*; +use super::{empty_collection, resource_not_found}; use crate::state::AppState; pub(crate) fn router(state: Arc) -> Router> { Router::new() .route("/v1/AccountService", get(account_service)) - .route( - "/v1/AccountService/Accounts", - get(account_list), - ) + .route("/v1/AccountService/Accounts", get(account_list)) .route( "/v1/AccountService/Accounts/{account_id}", get(account_detail), ) - .route( - "/v1/AccountService/Roles", - get(roles_stub), - ) - .route( - "/v1/AccountService/Roles/{role_id}", - get(role_detail_stub), - ) + .route("/v1/AccountService/Roles", get(roles_stub)) + .route("/v1/AccountService/Roles/{role_id}", get(role_detail_stub)) .with_state(state) } @@ -77,7 +68,10 @@ async fn account_list(State(state): State>) -> Response { "/redfish/v1/$metadata#ManagerAccountCollection.ManagerAccountCollection", "Accounts Collection", "Collection of Accounts", - vec![odata_ref(&format!("/redfish/v1/AccountService/Accounts/{}", user.id))], + vec![odata_ref(&format!( + "/redfish/v1/AccountService/Accounts/{}", + user.id + ))], )) .into_response() } diff --git a/src/redfish/routes/chassis.rs b/src/redfish/routes/chassis.rs index 6e02d582..5a2ae4d1 100644 --- a/src/redfish/routes/chassis.rs +++ b/src/redfish/routes/chassis.rs @@ -7,8 +7,8 @@ use axum::{ use std::sync::Arc; -use super::{empty_collection, get_power_state, validate_id, RESOURCE_ID}; use super::super::schema::*; +use super::{empty_collection, get_power_state, validate_id, RESOURCE_ID}; use crate::state::AppState; pub(crate) fn router(state: Arc) -> Router> { @@ -64,9 +64,7 @@ async fn chassis_detail( .into_response() } -async fn chassis_power( - Path(chassis_id): Path, -) -> Response { +async fn chassis_power(Path(chassis_id): Path) -> Response { if let Some(resp) = validate_id(&chassis_id) { return resp; } diff --git a/src/redfish/routes/event.rs b/src/redfish/routes/event.rs index 1d1d64ca..f4f4ae68 100644 --- a/src/redfish/routes/event.rs +++ b/src/redfish/routes/event.rs @@ -48,8 +48,7 @@ async fn event_service() -> Json { server_sent_event_uri: Some("/redfish/v1/EventService/SSE".to_string()), actions: EventServiceActions { submit_test_event: ActionTarget { - target: "/redfish/v1/EventService/Actions/EventService.SubmitTestEvent" - .to_string(), + target: "/redfish/v1/EventService/Actions/EventService.SubmitTestEvent".to_string(), }, }, }) diff --git a/src/redfish/routes/managers.rs b/src/redfish/routes/managers.rs index a6d0df72..5197bc61 100644 --- a/src/redfish/routes/managers.rs +++ b/src/redfish/routes/managers.rs @@ -7,8 +7,8 @@ use axum::{ use std::sync::Arc; -use super::{empty_collection, validate_id, RESOURCE_ID}; use super::super::schema::*; +use super::{empty_collection, validate_id, RESOURCE_ID}; use crate::state::AppState; pub(crate) fn router(state: Arc) -> Router> { @@ -86,7 +86,10 @@ async fn manager_detail( manager_for_servers: vec![odata_ref(&format!("/redfish/v1/Systems/{}", RESOURCE_ID))], manager_for_chassis: vec![odata_ref(&format!("/redfish/v1/Chassis/{}", RESOURCE_ID))], }, - network_protocol: odata_ref(&format!("/redfish/v1/Managers/{}/NetworkProtocol", manager_id)), + network_protocol: odata_ref(&format!( + "/redfish/v1/Managers/{}/NetworkProtocol", + manager_id + )), }) .into_response() } diff --git a/src/redfish/routes/mod.rs b/src/redfish/routes/mod.rs index a41b8385..f50a73c0 100644 --- a/src/redfish/routes/mod.rs +++ b/src/redfish/routes/mod.rs @@ -4,6 +4,7 @@ mod event; mod managers; mod session; mod systems; +#[cfg(unix)] mod virtual_media; use axum::{ @@ -191,7 +192,6 @@ pub fn create_redfish_router(state: Arc) -> Router { .merge(systems::router(state.clone())) .merge(chassis::router(state.clone())) .merge(managers::router(state.clone())) - .merge(virtual_media::router(state.clone())) .merge(session::router(state.clone())) .merge(account::router(state.clone())) .merge(event::router(state.clone())) @@ -200,6 +200,9 @@ pub fn create_redfish_router(state: Arc) -> Router { redfish_auth_middleware, )); + #[cfg(unix)] + let redfish_routes = redfish_routes.merge(virtual_media::router(state.clone())); + Router::new() .route("/redfish", get(service_root_redirect)) .nest("/redfish/", redfish_routes) diff --git a/src/redfish/routes/session.rs b/src/redfish/routes/session.rs index c9b67dd0..fc18d68c 100644 --- a/src/redfish/routes/session.rs +++ b/src/redfish/routes/session.rs @@ -9,8 +9,8 @@ use tracing::info; use std::sync::Arc; -use super::empty_collection; use super::super::schema::*; +use super::empty_collection; use crate::state::AppState; pub(crate) fn router(state: Arc) -> Router> { @@ -56,7 +56,10 @@ async fn session_list(State(state): State>) -> Response { let mut members = Vec::new(); for id in &session_ids { if state.sessions.get(id).await.ok().flatten().is_some() { - members.push(odata_ref(&format!("/redfish/v1/SessionService/Sessions/{}", id))); + members.push(odata_ref(&format!( + "/redfish/v1/SessionService/Sessions/{}", + id + ))); } } diff --git a/src/redfish/routes/systems.rs b/src/redfish/routes/systems.rs index 8dfb3dd5..7d1b6859 100644 --- a/src/redfish/routes/systems.rs +++ b/src/redfish/routes/systems.rs @@ -8,8 +8,8 @@ use axum::{ use std::sync::Arc; use tracing::info; -use super::{get_power_state, validate_id, service_unavailable, empty_collection, RESOURCE_ID}; use super::super::schema::*; +use super::{empty_collection, get_power_state, service_unavailable, validate_id, RESOURCE_ID}; use crate::state::AppState; pub(crate) fn router(state: Arc) -> Router> { @@ -208,13 +208,14 @@ async fn system_reset( } } -async fn system_set_default_boot_order( - Path(system_id): Path, -) -> Response { +async fn system_set_default_boot_order(Path(system_id): Path) -> Response { if let Some(resp) = validate_id(&system_id) { return resp; } - info!("Redfish: SetDefaultBootOrder for system {} (accepted, no-op)", system_id); + info!( + "Redfish: SetDefaultBootOrder for system {} (accepted, no-op)", + system_id + ); StatusCode::NO_CONTENT.into_response() } diff --git a/src/redfish/routes/virtual_media.rs b/src/redfish/routes/virtual_media.rs index 8be015c5..ebde46ad 100644 --- a/src/redfish/routes/virtual_media.rs +++ b/src/redfish/routes/virtual_media.rs @@ -9,8 +9,8 @@ use tracing::{info, warn}; use std::sync::Arc; -use super::{empty_collection, validate_id, service_unavailable, resource_not_found, RESOURCE_ID}; use super::super::schema::*; +use super::{empty_collection, resource_not_found, service_unavailable, validate_id, RESOURCE_ID}; use crate::state::AppState; pub(crate) fn router(state: Arc) -> Router> { @@ -95,7 +95,10 @@ async fn virtual_media_detail( Json(VirtualMedia { odata_type: "#VirtualMedia.v1_6_2.VirtualMedia".to_string(), - odata_id: format!("/redfish/v1/Managers/{}/VirtualMedia/{}", manager_id, media_id), + odata_id: format!( + "/redfish/v1/Managers/{}/VirtualMedia/{}", + manager_id, media_id + ), odata_context: "/redfish/v1/$metadata#VirtualMedia.VirtualMedia".to_string(), id: media_id.clone(), name: "Virtual Media 1".to_string(), @@ -153,7 +156,9 @@ async fn virtual_media_insert( if msd.state().await.connected { return ( StatusCode::CONFLICT, - Json(RedfishError::general_error("Virtual media already inserted")), + Json(RedfishError::general_error( + "Virtual media already inserted", + )), ) .into_response(); } diff --git a/src/redfish/schema.rs b/src/redfish/schema.rs index 2ddf6a5b..7663a4f1 100644 --- a/src/redfish/schema.rs +++ b/src/redfish/schema.rs @@ -527,7 +527,10 @@ pub struct RedfishError { pub struct RedfishErrorBody { pub code: String, pub message: String, - #[serde(rename = "@Message.ExtendedInfo", skip_serializing_if = "Vec::is_empty")] + #[serde( + rename = "@Message.ExtendedInfo", + skip_serializing_if = "Vec::is_empty" + )] pub extended_info: Vec, } diff --git a/src/rtsp/bitstream.rs b/src/rtsp/bitstream.rs index 0da15e9b..9466ec75 100644 --- a/src/rtsp/bitstream.rs +++ b/src/rtsp/bitstream.rs @@ -1,7 +1,7 @@ use bytes::Bytes; -use crate::video::encoder::registry::VideoEncoderType; -use crate::video::shared_video_pipeline::EncodedVideoFrame; +use crate::video::codec::registry::VideoEncoderType; +use crate::video::pipeline::EncodedVideoFrame; use super::state::ParameterSets; diff --git a/src/rtsp/codec.rs b/src/rtsp/codec.rs index 31c00818..516f9b9e 100644 --- a/src/rtsp/codec.rs +++ b/src/rtsp/codec.rs @@ -1,5 +1,5 @@ use crate::config::RtspCodec; -use crate::video::encoder::VideoCodecType; +use crate::video::codec::VideoCodecType; pub(crate) fn rtsp_codec_to_video(codec: RtspCodec) -> VideoCodecType { match codec { diff --git a/src/rtsp/sdp.rs b/src/rtsp/sdp.rs index 61acf57e..48f3b45b 100644 --- a/src/rtsp/sdp.rs +++ b/src/rtsp/sdp.rs @@ -2,8 +2,10 @@ use base64::Engine; use sdp_types as sdp; use crate::config::RtspConfig; -use crate::video::encoder::VideoCodecType; -use crate::webrtc::rtp::parse_profile_level_id_from_sps; +use crate::video::codec::h264_bitstream::{ + parse_profile_level_id_from_sps, FALLBACK_WEBRTC_PROFILE_LEVEL_ID, +}; +use crate::video::codec::VideoCodecType; use super::state::ParameterSets; @@ -15,7 +17,10 @@ pub(crate) fn build_h264_fmtp(payload_type: u8, params: &ParameterSets) -> Strin attrs.push(format!("profile-level-id={}", profile_level_id)); } } else { - attrs.push("profile-level-id=42e01f".to_string()); + attrs.push(format!( + "profile-level-id={}", + FALLBACK_WEBRTC_PROFILE_LEVEL_ID + )); } if let (Some(sps), Some(pps)) = (params.h264_sps.as_ref(), params.h264_pps.as_ref()) { diff --git a/src/rtsp/streaming.rs b/src/rtsp/streaming.rs index 31f7aca7..c1b71b73 100644 --- a/src/rtsp/streaming.rs +++ b/src/rtsp/streaming.rs @@ -11,8 +11,8 @@ use webrtc::util::{Marshal, MarshalSize}; use crate::config::RtspCodec; use crate::error::{AppError, Result}; -use crate::video::encoder::registry::VideoEncoderType; -use crate::video::shared_video_pipeline::EncodedVideoFrame; +use crate::video::codec::registry::VideoEncoderType; +use crate::video::pipeline::EncodedVideoFrame; use crate::video::VideoStreamManager; use crate::webrtc::h265_payloader::H265Payloader; diff --git a/src/rustdesk/connection.rs b/src/rustdesk/connection.rs index 2f6c4294..3895fccb 100644 --- a/src/rustdesk/connection.rs +++ b/src/rustdesk/connection.rs @@ -16,11 +16,11 @@ use tracing::{debug, error, info, warn}; use crate::audio::AudioController; use crate::hid::{CanonicalKey, HidController, KeyEventType, KeyboardEvent, KeyboardModifiers}; use crate::utils::hostname_from_etc; +use crate::video::codec::registry::{EncoderRegistry, VideoEncoderType}; +use crate::video::codec::BitratePreset; use crate::video::codec_constraints::{ encoder_codec_to_id, encoder_codec_to_video_codec, video_codec_to_encoder_codec, }; -use crate::video::encoder::registry::{EncoderRegistry, VideoEncoderType}; -use crate::video::encoder::BitratePreset; use crate::video::stream_manager::VideoStreamManager; use super::bytes_codec::{read_frame, write_frame, write_frame_buffered}; @@ -637,22 +637,22 @@ impl Connection { // Check availability in priority order // H264 is preferred because it has the best hardware encoder support (RKMPP, VAAPI, etc.) // and most RustDesk clients support H264 hardware decoding - if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H264) + if constraints.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::H264) && registry.is_codec_available(VideoEncoderType::H264) { return VideoEncoderType::H264; } - if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H265) + if constraints.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::H265) && registry.is_codec_available(VideoEncoderType::H265) { return VideoEncoderType::H265; } - if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP8) + if constraints.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::VP8) && registry.is_codec_available(VideoEncoderType::VP8) { return VideoEncoderType::VP8; } - if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP9) + if constraints.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::VP9) && registry.is_codec_available(VideoEncoderType::VP9) { return VideoEncoderType::VP9; @@ -1106,16 +1106,16 @@ impl Connection { // Check which encoders are available (include software fallback) let h264_available = constraints - .is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H264) + .is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::H264) && registry.is_codec_available(VideoEncoderType::H264); let h265_available = constraints - .is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H265) + .is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::H265) && registry.is_codec_available(VideoEncoderType::H265); let vp8_available = constraints - .is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP8) + .is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::VP8) && registry.is_codec_available(VideoEncoderType::VP8); let vp9_available = constraints - .is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP9) + .is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::VP9) && registry.is_codec_available(VideoEncoderType::VP9); info!( @@ -1574,7 +1574,7 @@ async fn run_video_streaming( shutdown_tx: broadcast::Sender<()>, negotiated_codec: VideoEncoderType, ) -> anyhow::Result<()> { - use crate::video::encoder::VideoCodecType; + use crate::video::codec::VideoCodecType; // Convert VideoEncoderType to VideoCodecType for the pipeline let webrtc_codec = match negotiated_codec { diff --git a/src/rustdesk/frame_adapters.rs b/src/rustdesk/frame_adapters.rs index a8794cdf..93dc6f67 100644 --- a/src/rustdesk/frame_adapters.rs +++ b/src/rustdesk/frame_adapters.rs @@ -91,7 +91,7 @@ impl VideoFrameAdapter { return data; } - let (sps, pps) = crate::webrtc::rtp::extract_sps_pps(&data); + let (sps, pps) = crate::video::codec::h264_bitstream::extract_sps_pps(&data); let mut has_sps = false; let mut has_pps = false; diff --git a/src/rustdesk/rendezvous.rs b/src/rustdesk/rendezvous.rs index 0e62390e..36066c6b 100644 --- a/src/rustdesk/rendezvous.rs +++ b/src/rustdesk/rendezvous.rs @@ -696,6 +696,7 @@ fn try_into_v4(addr: SocketAddr) -> SocketAddr { addr } +#[cfg(target_os = "linux")] fn is_virtual_interface(name: &str) -> bool { name.starts_with("docker") || name.starts_with("br-") diff --git a/src/state.rs b/src/state.rs index f0935639..0237673d 100644 --- a/src/state.rs +++ b/src/state.rs @@ -12,7 +12,9 @@ use crate::events::{ }; use crate::extensions::{ExtensionId, ExtensionManager}; use crate::hid::HidController; +#[cfg(unix)] use crate::msd::MsdController; +#[cfg(unix)] use crate::otg::OtgService; use crate::rtsp::RtspService; use crate::rustdesk::RustDeskService; @@ -51,10 +53,12 @@ pub struct AppState { pub config: ConfigStore, pub sessions: SessionStore, pub users: UserStore, + #[cfg(unix)] pub otg_service: Arc, pub stream_manager: Arc, pub webrtc: Arc, pub hid: Arc, + #[cfg(unix)] pub msd: Arc>>, pub atx: Arc>>, pub audio: Arc, @@ -77,11 +81,11 @@ impl AppState { config: ConfigStore, sessions: SessionStore, users: UserStore, - otg_service: Arc, + #[cfg(unix)] otg_service: Arc, stream_manager: Arc, webrtc: Arc, hid: Arc, - msd: Option, + #[cfg(unix)] msd: Option, atx: Option, audio: Arc, rustdesk: Option>, @@ -99,10 +103,12 @@ impl AppState { config, sessions, users, + #[cfg(unix)] otg_service, stream_manager, webrtc, hid, + #[cfg(unix)] msd: Arc::new(RwLock::new(msd)), atx: Arc::new(RwLock::new(atx)), audio, @@ -202,23 +208,30 @@ impl AppState { } async fn collect_msd_info(&self) -> Option { - let msd_guard = self.msd.read().await; - let msd = msd_guard.as_ref()?; + #[cfg(not(unix))] + { + None + } + #[cfg(unix)] + { + let msd_guard = self.msd.read().await; + let msd = msd_guard.as_ref()?; - let state = msd.state().await; - let error = msd.monitor().error_message().await; - Some(MsdDeviceInfo { - available: state.available, - mode: match state.mode { - crate::msd::MsdMode::None => "none", - crate::msd::MsdMode::Image => "image", - crate::msd::MsdMode::Drive => "drive", - } - .to_string(), - connected: state.connected, - image_id: state.current_image.map(|img| img.id), - error, - }) + let state = msd.state().await; + let error = msd.monitor().error_message().await; + Some(MsdDeviceInfo { + available: state.available, + mode: match state.mode { + crate::msd::MsdMode::None => "none", + crate::msd::MsdMode::Image => "image", + crate::msd::MsdMode::Drive => "drive", + } + .to_string(), + connected: state.connected, + image_id: state.current_image.map(|img| img.id), + error, + }) + } } async fn collect_atx_info(&self) -> Option { diff --git a/src/stream/mjpeg.rs b/src/stream/mjpeg.rs index 2a721ea6..d0a633f3 100644 --- a/src/stream/mjpeg.rs +++ b/src/stream/mjpeg.rs @@ -11,8 +11,8 @@ use tracing::{debug, info, warn}; /// Generation token paired with `client_id` so [`unregister_client`] ignores stale drops. pub type ClientGeneration = u64; -use crate::video::encoder::traits::{Encoder, EncoderConfig}; -use crate::video::encoder::JpegEncoder; +use crate::video::codec::traits::{Encoder, EncoderConfig}; +use crate::video::codec::JpegEncoder; use crate::video::format::PixelFormat; use crate::video::VideoFrame; diff --git a/src/stream_encoder.rs b/src/stream_encoder.rs index 49ba45d1..ce0360f5 100644 --- a/src/stream_encoder.rs +++ b/src/stream_encoder.rs @@ -1,7 +1,7 @@ //! `EncoderType` → `EncoderBackend` (breaks config ↔ video import cycles). use crate::config::EncoderType; -use crate::video::encoder::EncoderBackend; +use crate::video::codec::EncoderBackend; /// `None` means “auto” in WebRTC / pipeline (same as `EncoderType::Auto`). pub fn encoder_type_to_backend(encoder: EncoderType) -> Option { diff --git a/src/utils/host.rs b/src/utils/host.rs index d1db2786..7173399c 100644 --- a/src/utils/host.rs +++ b/src/utils/host.rs @@ -9,7 +9,15 @@ pub fn hostname_from_etc() -> String { /// Current kernel hostname (`gethostname`). Used for live device info in the UI. pub fn hostname_uname() -> String { - nix::unistd::gethostname() - .map(|s| s.to_string_lossy().into_owned()) - .unwrap_or_else(|_| "unknown".to_string()) + #[cfg(unix)] + { + nix::unistd::gethostname() + .map(|s| s.to_string_lossy().into_owned()) + .unwrap_or_else(|_| "unknown".to_string()) + } + + #[cfg(not(unix))] + { + std::env::var("COMPUTERNAME").unwrap_or_else(|_| "unknown".to_string()) + } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 9f382e54..f38bd456 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -2,10 +2,16 @@ pub mod fs; pub mod host; +#[cfg(unix)] pub mod net; +#[cfg(not(unix))] +#[path = "net_disabled.rs"] +pub mod net; +pub mod serial; pub mod throttle; pub use fs::{list_dir_names, read_trimmed}; pub use host::{hostname_from_etc, hostname_uname}; pub use net::{bind_tcp_listener, bind_udp_socket}; +pub use serial::list_serial_ports; pub use throttle::LogThrottler; diff --git a/src/utils/net_disabled.rs b/src/utils/net_disabled.rs new file mode 100644 index 00000000..1abf5c57 --- /dev/null +++ b/src/utils/net_disabled.rs @@ -0,0 +1,14 @@ +use std::io; +use std::net::{SocketAddr, TcpListener, UdpSocket}; + +pub fn bind_tcp_listener(addr: SocketAddr) -> io::Result { + let listener = TcpListener::bind(addr)?; + listener.set_nonblocking(true)?; + Ok(listener) +} + +pub fn bind_udp_socket(addr: SocketAddr) -> io::Result { + let socket = UdpSocket::bind(addr)?; + socket.set_nonblocking(true)?; + Ok(socket) +} diff --git a/src/utils/serial.rs b/src/utils/serial.rs new file mode 100644 index 00000000..3a971abe --- /dev/null +++ b/src/utils/serial.rs @@ -0,0 +1,12 @@ +//! Cross-platform serial port discovery helpers. + +/// Return serial port names that users can put directly into the config. +pub fn list_serial_ports() -> Vec { + let mut ports: Vec = serialport::available_ports() + .map(|ports| ports.into_iter().map(|port| port.port_name).collect()) + .unwrap_or_default(); + + ports.sort(); + ports.dedup(); + ports +} diff --git a/src/video/v4l2r_capture.rs b/src/video/capture/linux.rs similarity index 98% rename from src/video/v4l2r_capture.rs rename to src/video/capture/linux.rs index 26630494..e6a67212 100644 --- a/src/video/v4l2r_capture.rs +++ b/src/video/capture/linux.rs @@ -22,9 +22,9 @@ use v4l2r::nix::errno::Errno; use v4l2r::{Format as V4l2rFormat, PixelFormat as V4l2rPixelFormat, QueueType}; use crate::error::{AppError, Result}; -use crate::video::csi_bridge::{self, CsiBridgeKind, ProbeResult}; +use crate::video::device::bridge::{self as csi_bridge, CsiBridgeKind, ProbeResult}; use crate::video::format::{PixelFormat, Resolution}; -use crate::video::SignalStatus; +use crate::video::signal::SignalStatus; /// `io::Error` payload when the driver posts `V4L2_EVENT_SOURCE_CHANGE`. pub const SOURCE_CHANGED_MARKER: &str = "v4l2_source_changed"; @@ -60,7 +60,7 @@ impl BridgeContext { } /// V4L2 capture stream backed by v4l2r ioctl. -pub struct V4l2rCaptureStream { +pub struct CaptureStream { fd: File, queue: QueueType, resolution: Resolution, @@ -72,7 +72,7 @@ pub struct V4l2rCaptureStream { bridge_kind: Option, } -impl V4l2rCaptureStream { +impl CaptureStream { /// UVC: uses `resolution`. CSI bridges: DV-probe first; may return `CaptureNoSignal`. pub fn open( device_path: impl AsRef, @@ -538,7 +538,7 @@ impl V4l2rCaptureStream { } } -impl Drop for V4l2rCaptureStream { +impl Drop for CaptureStream { fn drop(&mut self) { // Release ordering matters on rkcif: a subsequent open()/S_FMT from a // freshly-constructed stream returns EBUSY if the previous capture has @@ -571,9 +571,9 @@ impl Drop for V4l2rCaptureStream { } /// Driver-name check for CSI/HDMI bridge devices (rk_hdmirx, rkcif, tc358743, -/// …) that expose DV timings. Kept in sync with `video::is_csi_hdmi_bridge` +/// …) that expose DV timings. Kept in sync with `video::device::is_csi_hdmi_bridge` /// but queries the raw V4L2 driver string so we don't need a full -/// `VideoDeviceInfo` at `V4l2rCaptureStream::open` time. +/// `VideoDeviceInfo` at `CaptureStream::open` time. fn is_csi_bridge_driver(driver: &str) -> bool { let d = driver.to_ascii_lowercase(); d == "rk_hdmirx" || d == "rkcif" || d == "tc358743" || d.starts_with("rkcif") diff --git a/src/video/capture/mod.rs b/src/video/capture/mod.rs new file mode 100644 index 00000000..3eb930cd --- /dev/null +++ b/src/video/capture/mod.rs @@ -0,0 +1,12 @@ +//! Video capture implementations and capture-state helpers. + +pub(crate) mod runtime; +pub(crate) mod status; + +#[cfg(unix)] +mod linux; +#[cfg(windows)] +#[path = "windows.rs"] +mod linux; + +pub use linux::*; diff --git a/src/video/capture/runtime.rs b/src/video/capture/runtime.rs new file mode 100644 index 00000000..a09fcdde --- /dev/null +++ b/src/video/capture/runtime.rs @@ -0,0 +1,70 @@ +use std::path::Path; +use std::time::Duration; + +use crate::error::AppError; +use crate::video::capture::status::signal_status_from_capture_kind; +use crate::video::format::{PixelFormat, Resolution}; +use crate::video::signal::SignalStatus; + +use super::{BridgeContext, CaptureStream}; + +pub enum CaptureOpenResult { + Opened(CaptureStream), + NoSignal(SignalStatus), + DeviceLost(String), + Fatal, +} + +pub fn open_capture_stream( + device_path: &Path, + resolution: Resolution, + format: PixelFormat, + fps: u32, + buffer_count: u32, + timeout: Duration, + bridge_ctx: BridgeContext, +) -> Result { + CaptureStream::open_with_bridge( + device_path, + resolution, + format, + fps, + buffer_count.max(1), + timeout, + bridge_ctx, + ) +} + +pub fn open_capture_stream_for_retry( + device_path: &Path, + resolution: Resolution, + format: PixelFormat, + fps: u32, + buffer_count: u32, + timeout: Duration, + bridge_ctx: BridgeContext, + is_device_lost_message: impl FnOnce(&str) -> bool, +) -> CaptureOpenResult { + match open_capture_stream( + device_path, + resolution, + format, + fps, + buffer_count, + timeout, + bridge_ctx, + ) { + Ok(stream) => CaptureOpenResult::Opened(stream), + Err(AppError::CaptureNoSignal { kind }) => { + CaptureOpenResult::NoSignal(signal_status_from_capture_kind(&kind)) + } + Err(error) => { + let reason = error.to_string(); + if is_device_lost_message(&reason) { + CaptureOpenResult::DeviceLost(reason) + } else { + CaptureOpenResult::Fatal + } + } + } +} diff --git a/src/video/capture_status.rs b/src/video/capture/status.rs similarity index 98% rename from src/video/capture_status.rs rename to src/video/capture/status.rs index 565510cc..c6b91939 100644 --- a/src/video/capture_status.rs +++ b/src/video/capture/status.rs @@ -2,7 +2,7 @@ use std::io; -use crate::video::SignalStatus; +use crate::video::signal::SignalStatus; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CaptureIoErrorKind { diff --git a/src/video/capture/windows.rs b/src/video/capture/windows.rs new file mode 100644 index 00000000..e3d2cbf5 --- /dev/null +++ b/src/video/capture/windows.rs @@ -0,0 +1,181 @@ +use std::io; +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use crate::error::{AppError, Result}; +use crate::video::device::bridge::{CsiBridgeKind, ProbeResult}; +use crate::video::device::{directshow_display_name_from_path, normalize_windows_device_path}; +use crate::video::format::{PixelFormat, Resolution}; + +pub const SOURCE_CHANGED_MARKER: &str = "dshow_source_changed"; + +pub fn is_source_changed_error(err: &io::Error) -> bool { + err.get_ref() + .map(|inner| inner.to_string() == SOURCE_CHANGED_MARKER) + .unwrap_or(false) +} + +#[derive(Debug, Clone, Copy)] +pub struct CaptureMeta { + pub bytes_used: usize, + pub sequence: u64, +} + +#[derive(Debug, Clone, Default)] +pub struct BridgeContext { + pub subdev_path: Option, + pub kind: Option, +} + +impl BridgeContext { + pub fn from_parts(subdev_path: Option, kind: Option) -> Self { + Self { subdev_path, kind } + } + + pub fn has_subdev(&self) -> bool { + false + } +} + +pub struct CaptureStream { + capture: hwcodec::capture::DshowCapture, + resolution: Resolution, + format: PixelFormat, + stride: u32, +} + +unsafe impl Send for CaptureStream {} + +impl CaptureStream { + pub fn open( + device_path: impl AsRef, + resolution: Resolution, + format: PixelFormat, + fps: u32, + buffer_count: u32, + timeout: Duration, + ) -> Result { + let _ = buffer_count; + let path = normalize_windows_device_path(device_path); + let display_name = directshow_display_name_from_path(&path).ok_or_else(|| { + AppError::VideoError(format!( + "Unsupported DirectShow device path: {}", + path.display() + )) + })?; + let capture = hwcodec::capture::DshowCapture::open( + &display_name, + resolution.width as i32, + resolution.height as i32, + fps as i32, + map_pixel_format(format), + timeout.as_millis().clamp(1, i32::MAX as u128) as i32, + ) + .map_err(|e| AppError::VideoError(format!("Failed to open DirectShow capture: {}", e)))?; + let info = capture.info().map_err(|e| { + AppError::VideoError(format!("Failed to query DirectShow capture: {}", e)) + })?; + let actual_format = map_capture_format(info.pixel_format)?; + let actual_resolution = + Resolution::new(info.width.max(1) as u32, info.height.max(1) as u32); + + Ok(Self { + capture, + resolution: actual_resolution, + format: actual_format, + stride: info.stride.max(0) as u32, + }) + } + + pub fn open_with_bridge( + device_path: impl AsRef, + resolution: Resolution, + format: PixelFormat, + fps: u32, + buffer_count: u32, + timeout: Duration, + bridge: BridgeContext, + ) -> Result { + let _ = bridge; + Self::open(device_path, resolution, format, fps, buffer_count, timeout) + } + + pub fn resolution(&self) -> Resolution { + self.resolution + } + + pub fn format(&self) -> PixelFormat { + self.format + } + + pub fn stride(&self) -> u32 { + self.stride + } + + pub fn next_into(&mut self, dst: &mut Vec) -> io::Result { + match self.capture.read_packet() { + Ok((packet, sequence)) => { + dst.clear(); + dst.extend_from_slice(&packet); + Ok(CaptureMeta { + bytes_used: packet.len(), + sequence, + }) + } + Err(err) => { + let kind = if err.code == -110 { + io::ErrorKind::TimedOut + } else { + io::ErrorKind::Other + }; + Err(io::Error::new(kind, err.message)) + } + } + } + + pub fn probe_bridge_signal_with_timeout(&self, _limit: Duration) -> Option { + None + } +} + +fn map_pixel_format(format: PixelFormat) -> hwcodec::capture::CapturePixelFormat { + match format { + PixelFormat::Mjpeg => hwcodec::capture::CapturePixelFormat::Mjpeg, + PixelFormat::Jpeg => hwcodec::capture::CapturePixelFormat::Jpeg, + PixelFormat::Yuyv => hwcodec::capture::CapturePixelFormat::Yuyv, + PixelFormat::Yvyu => hwcodec::capture::CapturePixelFormat::Yvyu, + PixelFormat::Uyvy => hwcodec::capture::CapturePixelFormat::Uyvy, + PixelFormat::Nv12 => hwcodec::capture::CapturePixelFormat::Nv12, + PixelFormat::Nv21 => hwcodec::capture::CapturePixelFormat::Nv21, + PixelFormat::Nv16 => hwcodec::capture::CapturePixelFormat::Nv16, + PixelFormat::Nv24 => hwcodec::capture::CapturePixelFormat::Nv24, + PixelFormat::Yuv420 => hwcodec::capture::CapturePixelFormat::Yuv420, + PixelFormat::Yvu420 => hwcodec::capture::CapturePixelFormat::Yvu420, + PixelFormat::Rgb24 => hwcodec::capture::CapturePixelFormat::Rgb24, + PixelFormat::Bgr24 => hwcodec::capture::CapturePixelFormat::Bgr24, + PixelFormat::Grey => hwcodec::capture::CapturePixelFormat::Grey, + PixelFormat::Rgb565 => hwcodec::capture::CapturePixelFormat::Unknown, + } +} + +fn map_capture_format(format: hwcodec::capture::CapturePixelFormat) -> Result { + match format { + hwcodec::capture::CapturePixelFormat::Mjpeg => Ok(PixelFormat::Mjpeg), + hwcodec::capture::CapturePixelFormat::Jpeg => Ok(PixelFormat::Jpeg), + hwcodec::capture::CapturePixelFormat::Yuyv => Ok(PixelFormat::Yuyv), + hwcodec::capture::CapturePixelFormat::Yvyu => Ok(PixelFormat::Yvyu), + hwcodec::capture::CapturePixelFormat::Uyvy => Ok(PixelFormat::Uyvy), + hwcodec::capture::CapturePixelFormat::Nv12 => Ok(PixelFormat::Nv12), + hwcodec::capture::CapturePixelFormat::Nv21 => Ok(PixelFormat::Nv21), + hwcodec::capture::CapturePixelFormat::Nv16 => Ok(PixelFormat::Nv16), + hwcodec::capture::CapturePixelFormat::Nv24 => Ok(PixelFormat::Nv24), + hwcodec::capture::CapturePixelFormat::Yuv420 => Ok(PixelFormat::Yuv420), + hwcodec::capture::CapturePixelFormat::Yvu420 => Ok(PixelFormat::Yvu420), + hwcodec::capture::CapturePixelFormat::Rgb24 => Ok(PixelFormat::Rgb24), + hwcodec::capture::CapturePixelFormat::Bgr24 => Ok(PixelFormat::Bgr24), + hwcodec::capture::CapturePixelFormat::Grey => Ok(PixelFormat::Grey), + hwcodec::capture::CapturePixelFormat::Unknown => Err(AppError::ServiceUnavailable( + "DirectShow returned an unsupported pixel format".to_string(), + )), + } +} diff --git a/src/video/capture_limits.rs b/src/video/capture_limits.rs deleted file mode 100644 index 9cecc31e..00000000 --- a/src/video/capture_limits.rs +++ /dev/null @@ -1,30 +0,0 @@ -//! Shared tuning for V4L2 MJPEG capture paths (`Streamer` + `SharedVideoPipeline`). - -/// Frames smaller than this are treated as incomplete / noise. -pub(crate) const MIN_CAPTURE_FRAME_SIZE: usize = 128; - -/// After startup, validate JPEG header every N frames to limit CPU use. -pub(crate) const JPEG_VALIDATE_INTERVAL: u64 = 30; - -/// Validate every MJPEG frame for the first N frames (UVC warm-up / bad headers). -pub(crate) const STARTUP_JPEG_VALIDATE_FRAMES: u64 = 3; - -#[inline] -pub(crate) fn should_validate_jpeg_frame(validate_counter: u64) -> bool { - validate_counter <= STARTUP_JPEG_VALIDATE_FRAMES - || validate_counter.is_multiple_of(JPEG_VALIDATE_INTERVAL) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn jpeg_validation_policy_startup_then_interval() { - assert!(should_validate_jpeg_frame(1)); - assert!(should_validate_jpeg_frame(2)); - assert!(should_validate_jpeg_frame(3)); - assert!(!should_validate_jpeg_frame(4)); - assert!(should_validate_jpeg_frame(30)); - } -} diff --git a/src/video/convert.rs b/src/video/codec/convert.rs similarity index 100% rename from src/video/convert.rs rename to src/video/codec/convert.rs diff --git a/src/video/encoder/h264.rs b/src/video/codec/h264.rs similarity index 100% rename from src/video/encoder/h264.rs rename to src/video/codec/h264.rs diff --git a/src/video/codec/h264_bitstream.rs b/src/video/codec/h264_bitstream.rs new file mode 100644 index 00000000..eac14041 --- /dev/null +++ b/src/video/codec/h264_bitstream.rs @@ -0,0 +1,299 @@ +//! H.264 Annex-B/AVCC bitstream helpers shared by WebRTC, RTSP and RustDesk. + +pub const FALLBACK_WEBRTC_PROFILE_LEVEL_ID: &str = "42e01f"; + +pub fn webrtc_fmtp_line(profile_level_id: &str) -> String { + format!( + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id={}", + profile_level_id + ) +} + +pub fn fallback_webrtc_fmtp_line() -> String { + webrtc_fmtp_line(FALLBACK_WEBRTC_PROFILE_LEVEL_ID) +} + +pub fn strip_aud_nal_units(data: &[u8]) -> Vec { + let mut result = Vec::with_capacity(data.len()); + let mut i = 0; + + while i < data.len() { + let (start_code_pos, start_code_len) = if i + 4 <= data.len() + && data[i] == 0 + && data[i + 1] == 0 + && data[i + 2] == 0 + && data[i + 3] == 1 + { + (i, 4) + } else if i + 3 <= data.len() && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 { + (i, 3) + } else { + i += 1; + continue; + }; + + let nal_start = start_code_pos + start_code_len; + if nal_start >= data.len() { + break; + } + + let nal_type = data[nal_start] & 0x1F; + + let mut nal_end = data.len(); + let mut j = nal_start + 1; + while j + 3 <= data.len() { + if (data[j] == 0 && data[j + 1] == 0 && data[j + 2] == 1) + || (j + 4 <= data.len() + && data[j] == 0 + && data[j + 1] == 0 + && data[j + 2] == 0 + && data[j + 3] == 1) + { + nal_end = j; + break; + } + j += 1; + } + + if nal_type != 9 && nal_type != 12 { + result.extend_from_slice(&data[start_code_pos..nal_end]); + } + + i = nal_end; + } + + if result.is_empty() && !data.is_empty() { + return data.to_vec(); + } + + result +} + +pub fn extract_sps_pps(data: &[u8]) -> (Option>, Option>) { + let mut sps: Option> = None; + let mut pps: Option> = None; + let mut i = 0; + + while i < data.len() { + let start_code_len = if i + 4 <= data.len() + && data[i] == 0 + && data[i + 1] == 0 + && data[i + 2] == 0 + && data[i + 3] == 1 + { + 4 + } else if i + 3 <= data.len() && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 { + 3 + } else { + i += 1; + continue; + }; + + let nal_start = i + start_code_len; + if nal_start >= data.len() { + break; + } + + let nal_type = data[nal_start] & 0x1F; + + let mut nal_end = data.len(); + let mut j = nal_start + 1; + while j + 3 <= data.len() { + if (data[j] == 0 && data[j + 1] == 0 && data[j + 2] == 1) + || (j + 4 <= data.len() + && data[j] == 0 + && data[j + 1] == 0 + && data[j + 2] == 0 + && data[j + 3] == 1) + { + nal_end = j; + break; + } + j += 1; + } + + match nal_type { + 7 => { + sps = Some(data[nal_start..nal_end].to_vec()); + } + 8 => { + pps = Some(data[nal_start..nal_end].to_vec()); + } + _ => {} + } + + i = nal_end; + } + + (sps, pps) +} + +pub fn has_sps_pps(data: &[u8]) -> bool { + let mut has_sps = false; + let mut has_pps = false; + let mut i = 0; + + while i < data.len() { + let start_code_len = if i + 4 <= data.len() + && data[i] == 0 + && data[i + 1] == 0 + && data[i + 2] == 0 + && data[i + 3] == 1 + { + 4 + } else if i + 3 <= data.len() && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 { + 3 + } else { + i += 1; + continue; + }; + + let nal_start = i + start_code_len; + if nal_start >= data.len() { + break; + } + + let nal_type = data[nal_start] & 0x1F; + + match nal_type { + 7 => has_sps = true, + 8 => has_pps = true, + _ => {} + } + + if has_sps && has_pps { + return true; + } + + i = nal_start + 1; + } + + has_sps && has_pps +} + +pub fn is_keyframe(data: &[u8]) -> bool { + let mut i = 0; + while i < data.len() { + if i + 3 < data.len() && data[i] == 0 && data[i + 1] == 0 { + let nal_start = if data[i + 2] == 1 { + i + 3 + } else if i + 4 < data.len() && data[i + 2] == 0 && data[i + 3] == 1 { + i + 4 + } else { + i += 1; + continue; + }; + + if nal_start < data.len() { + let nal_type = data[nal_start] & 0x1F; + if nal_type == 5 { + return true; + } + } + i = nal_start; + } else { + i += 1; + } + } + false +} + +/// `profile-level-id` hex for SDP (`42001f` etc.); expects SPS NAL without start code. +pub fn parse_profile_level_id_from_sps(sps: &[u8]) -> Option { + if sps.len() < 4 { + return None; + } + + let profile_idc = sps[1]; + let constraint_set_flags = sps[2]; + let level_idc = sps[3]; + + Some(format!( + "{:02x}{:02x}{:02x}", + profile_idc, constraint_set_flags, level_idc + )) +} + +pub fn extract_profile_level_id(data: &[u8]) -> Option { + let (sps, _) = extract_sps_pps(data); + sps.and_then(|sps_data| parse_profile_level_id_from_sps(&sps_data)) +} + +pub fn is_annex_b(data: &[u8]) -> bool { + data.starts_with(&[0, 0, 1]) || data.starts_with(&[0, 0, 0, 1]) +} + +pub fn avcc_to_annex_b(data: &[u8]) -> Option> { + let mut pos = 0; + let mut output = Vec::with_capacity(data.len() + 16); + let mut nalu_count = 0usize; + + while pos + 4 <= data.len() { + let nalu_len = + u32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize; + pos += 4; + if nalu_len == 0 || pos + nalu_len > data.len() { + return None; + } + + let nal_type = data[pos] & 0x1F; + if nal_type != 9 && nal_type != 12 { + output.extend_from_slice(&[0, 0, 0, 1]); + output.extend_from_slice(&data[pos..pos + nalu_len]); + } + nalu_count += 1; + pos += nalu_len; + } + + if pos == data.len() && nalu_count > 0 && !output.is_empty() { + Some(output) + } else { + None + } +} + +pub fn normalize_for_webrtc(data: &[u8]) -> Vec { + if is_annex_b(data) { + return strip_aud_nal_units(data); + } + + if let Some(annex_b) = avcc_to_annex_b(data) { + return strip_aud_nal_units(&annex_b); + } + + data.to_vec() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn detects_h264_keyframes() { + let idr_frame = vec![0x00, 0x00, 0x00, 0x01, 0x65]; + assert!(is_keyframe(&idr_frame)); + + let idr_frame_3 = vec![0x00, 0x00, 0x01, 0x65]; + assert!(is_keyframe(&idr_frame_3)); + + let p_frame = vec![0x00, 0x00, 0x00, 0x01, 0x41]; + assert!(!is_keyframe(&p_frame)); + + let sps = vec![0x00, 0x00, 0x00, 0x01, 0x67]; + assert!(!is_keyframe(&sps)); + + let multi_nal = vec![ + 0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x01, 0x68, 0xce, + 0x38, 0x80, 0x00, 0x00, 0x00, 0x01, 0x65, 0x88, 0x84, + ]; + assert!(is_keyframe(&multi_nal)); + } + + #[test] + fn parses_profile_level_id_from_sps() { + assert_eq!( + parse_profile_level_id_from_sps(&[0x67, 0x42, 0x40, 0x2a]), + Some("42402a".to_string()) + ); + } +} diff --git a/src/video/encoder/h265.rs b/src/video/codec/h265.rs similarity index 100% rename from src/video/encoder/h265.rs rename to src/video/codec/h265.rs diff --git a/src/video/encoder/jpeg.rs b/src/video/codec/jpeg.rs similarity index 99% rename from src/video/encoder/jpeg.rs rename to src/video/codec/jpeg.rs index f7a448ef..3c2fe79a 100644 --- a/src/video/encoder/jpeg.rs +++ b/src/video/codec/jpeg.rs @@ -347,7 +347,7 @@ impl JpegEncoder { } } -impl crate::video::encoder::traits::Encoder for JpegEncoder { +impl crate::video::codec::traits::Encoder for JpegEncoder { fn name(&self) -> &str { "JPEG (libyuv+turbojpeg)" } diff --git a/src/video/decoder/mjpeg_rkmpp.rs b/src/video/codec/mjpeg_rkmpp.rs similarity index 98% rename from src/video/decoder/mjpeg_rkmpp.rs rename to src/video/codec/mjpeg_rkmpp.rs index 686a722c..c933a500 100644 --- a/src/video/decoder/mjpeg_rkmpp.rs +++ b/src/video/codec/mjpeg_rkmpp.rs @@ -5,7 +5,7 @@ use hwcodec::ffmpeg_ram::decode::{DecodeContext, Decoder}; use tracing::{info, warn}; use crate::error::{AppError, Result}; -use crate::video::convert::Nv12Converter; +use crate::video::codec::convert::Nv12Converter; use crate::video::format::Resolution; pub struct MjpegRkmppDecoder { diff --git a/src/video/decoder/mjpeg_turbo.rs b/src/video/codec/mjpeg_turbo.rs similarity index 100% rename from src/video/decoder/mjpeg_turbo.rs rename to src/video/codec/mjpeg_turbo.rs diff --git a/src/video/encoder/mod.rs b/src/video/codec/mod.rs similarity index 71% rename from src/video/encoder/mod.rs rename to src/video/codec/mod.rs index e33ef620..27584170 100644 --- a/src/video/encoder/mod.rs +++ b/src/video/codec/mod.rs @@ -1,57 +1,45 @@ -//! Video encoder implementations -//! -//! This module provides video encoding capabilities including: -//! - JPEG encoding for raw frames (YUYV, NV12, etc.) -//! - H264 encoding (hardware + software) -//! - H265 encoding (hardware + software) -//! - VP8 encoding (hardware + software) -//! - VP9 encoding (hardware + software) -//! - WebRTC video codec abstraction -//! - Encoder registry for automatic detection +//! Video codec, conversion, encoding, and decoding implementations. use hwcodec::common::DataFormat; use hwcodec::ffmpeg_ram::CodecInfo; -pub mod codec; +pub mod convert; + pub mod h264; +pub mod h264_bitstream; pub mod h265; pub mod jpeg; pub mod registry; pub mod self_check; pub mod traits; +pub mod video_codec; pub mod vp8; pub mod vp9; -// Core traits and types -pub use traits::{ - BitratePreset, EncodedFormat, EncodedFrame, Encoder, EncoderConfig, EncoderFactory, -}; +pub mod mjpeg_turbo; -// WebRTC codec abstraction -pub use codec::{CodecFrame, VideoCodec, VideoCodecConfig, VideoCodecFactory, VideoCodecType}; +#[cfg(any(target_arch = "aarch64", target_arch = "arm"))] +pub mod mjpeg_rkmpp; -// Encoder registry +pub use convert::{PixelConverter, Yuv420pBuffer}; +pub use h264::{H264Config, H264Encoder, H264EncoderType, H264InputFormat}; +pub use h265::{H265Config, H265Encoder, H265EncoderType, H265InputFormat}; +pub use jpeg::JpegEncoder; +pub use mjpeg_turbo::MjpegTurboDecoder; pub use registry::{AvailableEncoder, EncoderBackend, EncoderRegistry, VideoEncoderType}; pub use self_check::{ build_hardware_self_check_runtime_error, run_hardware_self_check, VideoEncoderSelfCheckCell, VideoEncoderSelfCheckCodec, VideoEncoderSelfCheckResponse, VideoEncoderSelfCheckRow, }; - -// H264 encoder -pub use h264::{H264Config, H264Encoder, H264EncoderType, H264InputFormat}; - -// H265 encoder -pub use h265::{H265Config, H265Encoder, H265EncoderType, H265InputFormat}; - -// VP8 encoder +pub use traits::{ + BitratePreset, EncodedFormat, EncodedFrame, Encoder, EncoderConfig, EncoderFactory, +}; +pub use video_codec::{ + CodecFrame, VideoCodec, VideoCodecConfig, VideoCodecFactory, VideoCodecType, +}; pub use vp8::{VP8Config, VP8Encoder, VP8EncoderType, VP8InputFormat}; - -// VP9 encoder pub use vp9::{VP9Config, VP9Encoder, VP9EncoderType, VP9InputFormat}; -// JPEG encoder -pub use jpeg::JpegEncoder; - pub(crate) fn select_codec_for_format( encoders: &[CodecInfo], format: DataFormat, diff --git a/src/video/encoder/registry.rs b/src/video/codec/registry.rs similarity index 100% rename from src/video/encoder/registry.rs rename to src/video/codec/registry.rs diff --git a/src/video/encoder/self_check.rs b/src/video/codec/self_check.rs similarity index 100% rename from src/video/encoder/self_check.rs rename to src/video/codec/self_check.rs diff --git a/src/video/encoder/traits.rs b/src/video/codec/traits.rs similarity index 100% rename from src/video/encoder/traits.rs rename to src/video/codec/traits.rs diff --git a/src/video/encoder/codec.rs b/src/video/codec/video_codec.rs similarity index 99% rename from src/video/encoder/codec.rs rename to src/video/codec/video_codec.rs index 1b646dae..e2aa604b 100644 --- a/src/video/encoder/codec.rs +++ b/src/video/codec/video_codec.rs @@ -258,7 +258,7 @@ pub trait VideoCodec: Send { /// Get SDP fmtp parameters (codec-specific) /// - /// For H264: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f" + /// For H264: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=" /// For VP8/VP9: None or empty fn sdp_fmtp(&self) -> Option; diff --git a/src/video/encoder/vp8.rs b/src/video/codec/vp8.rs similarity index 100% rename from src/video/encoder/vp8.rs rename to src/video/codec/vp8.rs diff --git a/src/video/encoder/vp9.rs b/src/video/codec/vp9.rs similarity index 100% rename from src/video/encoder/vp9.rs rename to src/video/codec/vp9.rs diff --git a/src/video/codec_constraints.rs b/src/video/codec_constraints.rs index 41d2cd14..bde68f4a 100644 --- a/src/video/codec_constraints.rs +++ b/src/video/codec_constraints.rs @@ -1,7 +1,7 @@ use crate::config::{AppConfig, RtspCodec, StreamMode}; use crate::error::Result; -use crate::video::encoder::registry::VideoEncoderType; -use crate::video::encoder::VideoCodecType; +use crate::video::codec::registry::VideoEncoderType; +use crate::video::codec::VideoCodecType; use crate::video::VideoStreamManager; use std::sync::Arc; diff --git a/src/video/decoder/mod.rs b/src/video/decoder/mod.rs deleted file mode 100644 index 32b47d91..00000000 --- a/src/video/decoder/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -//! Video decoder implementations -//! -//! This module provides video decoding capabilities. - -pub mod mjpeg_turbo; - -pub use mjpeg_turbo::MjpegTurboDecoder; diff --git a/src/video/csi_bridge.rs b/src/video/device/bridge.rs similarity index 99% rename from src/video/csi_bridge.rs rename to src/video/device/bridge.rs index a7f83c6b..174688c8 100644 --- a/src/video/csi_bridge.rs +++ b/src/video/device/bridge.rs @@ -17,7 +17,7 @@ use v4l2r::bindings::{ use v4l2r::ioctl::{self, Event as V4l2Event, EventType, QueryDvTimingsError, SubscribeEventFlags}; use v4l2r::nix::errno::Errno; -use crate::video::SignalStatus; +use crate::video::signal::SignalStatus; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CsiBridgeKind { diff --git a/src/video/device/disabled_bridge.rs b/src/video/device/disabled_bridge.rs new file mode 100644 index 00000000..6da93963 --- /dev/null +++ b/src/video/device/disabled_bridge.rs @@ -0,0 +1,80 @@ +use std::fs::File; +use std::io; +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use crate::video::signal::SignalStatus; + +pub const RK628_SUBDEV_PROBE_TIMEOUT: Duration = Duration::from_millis(3000); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CsiBridgeKind { + Rk628, + RkHdmirx, + Tc358743, + Unknown, +} + +#[derive(Debug, Clone)] +pub enum ProbeResult { + Locked(DvTimingsMode), + NoCable, + NoSync, + OutOfRange, + NoSignal, +} + +impl ProbeResult { + pub fn as_status(&self) -> Option { + match self { + ProbeResult::Locked(_) => None, + ProbeResult::NoCable => Some(SignalStatus::NoCable), + ProbeResult::NoSync => Some(SignalStatus::NoSync), + ProbeResult::OutOfRange => Some(SignalStatus::OutOfRange), + ProbeResult::NoSignal => Some(SignalStatus::NoSignal), + } + } + + pub fn is_locked(&self) -> bool { + matches!(self, ProbeResult::Locked(_)) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DvTimingsMode { + pub width: u32, + pub height: u32, + pub pixelclock: u64, + pub fps: Option, + pub raw: (), +} + +pub fn discover_subdev_for_video(_video_path: &Path) -> Option<(PathBuf, CsiBridgeKind)> { + None +} + +pub fn open_subdev(path: &Path) -> io::Result { + File::open(path) +} + +pub fn probe_signal(_subdev_fd: &File, _kind: CsiBridgeKind) -> ProbeResult { + ProbeResult::NoSignal +} + +pub fn probe_signal_thread_timeout( + _subdev_fd: &File, + _kind: CsiBridgeKind, + _timeout: Duration, +) -> Option { + Some(ProbeResult::NoSignal) +} + +pub fn apply_dv_timings(_subdev_fd: &File, _timings: ()) {} + +pub fn subscribe_source_change(_subdev_fd: &File) -> io::Result<()> { + Ok(()) +} + +pub fn wait_source_change(_subdev_fd: &File, _timeout: Duration) -> io::Result { + Ok(false) +} diff --git a/src/video/device.rs b/src/video/device/linux.rs similarity index 99% rename from src/video/device.rs rename to src/video/device/linux.rs index 30e2a796..4ccd6ef8 100644 --- a/src/video/device.rs +++ b/src/video/device/linux.rs @@ -16,10 +16,10 @@ use v4l2r::ioctl::{ use v4l2r::nix::errno::Errno; use v4l2r::{Format as V4l2rFormat, QueueType}; -use super::csi_bridge; -use super::format::{PixelFormat, Resolution}; +use super::bridge as csi_bridge; use super::{is_rk_hdmirx_driver, is_rkcif_driver}; use crate::error::{AppError, Result}; +use crate::video::format::{PixelFormat, Resolution}; /// Per-node probe limit; rkcif/RK628 ioctl chains can exceed 1s under contention. const DEVICE_PROBE_TIMEOUT_MS: u64 = 10_000; diff --git a/src/video/device/mod.rs b/src/video/device/mod.rs new file mode 100644 index 00000000..d959d48e --- /dev/null +++ b/src/video/device/mod.rs @@ -0,0 +1,35 @@ +//! Video device discovery, capability probing, and platform adapters. + +#[cfg(unix)] +mod linux; +#[cfg(windows)] +mod windows; + +#[cfg(unix)] +pub use linux::*; +#[cfg(windows)] +pub use windows::*; + +#[cfg(unix)] +pub mod bridge; +#[cfg(windows)] +#[path = "disabled_bridge.rs"] +pub mod bridge; + +pub(crate) fn is_rk_hdmirx_driver(driver: &str, card: &str) -> bool { + driver.eq_ignore_ascii_case("rk_hdmirx") || card.eq_ignore_ascii_case("rk_hdmirx") +} + +pub(crate) fn is_rk_hdmirx_device(device: &VideoDeviceInfo) -> bool { + is_rk_hdmirx_driver(&device.driver, &device.card) +} + +pub(crate) fn is_rkcif_driver(driver: &str) -> bool { + driver.eq_ignore_ascii_case("rkcif") +} + +/// Unified check for CSI/HDMI bridge devices (rk_hdmirx, rkcif, etc.) +/// that require special enumeration and format-selection logic. +pub(crate) fn is_csi_hdmi_bridge(device: &VideoDeviceInfo) -> bool { + is_rk_hdmirx_device(device) || is_rkcif_driver(&device.driver) +} diff --git a/src/video/device/windows.rs b/src/video/device/windows.rs new file mode 100644 index 00000000..60d67210 --- /dev/null +++ b/src/video/device/windows.rs @@ -0,0 +1,359 @@ +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; + +use crate::error::{AppError, Result}; +use crate::video::format::{PixelFormat, Resolution}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VideoDeviceInfo { + pub path: PathBuf, + pub name: String, + pub driver: String, + pub bus_info: String, + pub card: String, + pub formats: Vec, + pub capabilities: DeviceCapabilities, + pub is_capture_card: bool, + pub priority: u32, + pub has_signal: bool, + pub subdev_path: Option, + pub bridge_kind: Option, +} + +#[derive(Debug, Clone)] +pub struct VideoDeviceRecoveryHint { + pub path: PathBuf, + pub name: String, + pub driver: String, + pub bus_info: String, + pub card: String, + pub is_capture_card: bool, +} + +impl From<&VideoDeviceInfo> for VideoDeviceRecoveryHint { + fn from(device: &VideoDeviceInfo) -> Self { + Self { + path: device.path.clone(), + name: device.name.clone(), + driver: device.driver.clone(), + bus_info: device.bus_info.clone(), + card: device.card.clone(), + is_capture_card: device.is_capture_card, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FormatInfo { + pub format: PixelFormat, + pub resolutions: Vec, + pub description: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResolutionInfo { + pub width: u32, + pub height: u32, + pub fps: Vec, +} + +impl ResolutionInfo { + pub fn new(width: u32, height: u32, fps: Vec) -> Self { + Self { width, height, fps } + } + + pub fn resolution(&self) -> Resolution { + Resolution::new(self.width, self.height) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct DeviceCapabilities { + pub video_capture: bool, + pub video_capture_mplane: bool, + pub video_output: bool, + pub streaming: bool, + pub read_write: bool, +} + +pub struct VideoDevice { + pub path: PathBuf, +} + +pub(crate) const DIRECTSHOW_DEVICE_PREFIX: &str = "dshow:"; + +impl VideoDevice { + pub fn open(path: impl AsRef) -> Result { + let path = normalize_windows_device_path(path.as_ref()); + if enumerate_devices()? + .iter() + .any(|device| device.path == path) + { + Ok(Self { path }) + } else { + Err(AppError::VideoError(format!( + "Windows video device not found: {}", + path.display() + ))) + } + } + + pub fn open_readonly(path: impl AsRef) -> Result { + Self::open(path) + } + + pub fn info(&self) -> Result { + enumerate_devices()? + .into_iter() + .find(|device| device.path == self.path) + .ok_or_else(|| { + AppError::VideoError(format!( + "Windows video device not found: {}", + self.path.display() + )) + }) + } +} + +pub(crate) fn normalize_windows_device_path(path: impl AsRef) -> PathBuf { + let raw = path.as_ref().to_string_lossy(); + if raw.eq_ignore_ascii_case("auto") { + return find_best_device() + .map(|device| device.path) + .unwrap_or_else(|_| PathBuf::from(raw.as_ref())); + } + PathBuf::from(raw.as_ref()) +} + +pub(crate) fn directshow_display_name_from_path(path: impl AsRef) -> Option { + path.as_ref() + .to_string_lossy() + .strip_prefix(DIRECTSHOW_DEVICE_PREFIX) + .map(str::to_string) +} + +pub fn enumerate_devices() -> Result> { + let names = hwcodec::capture::list_dshow_video_devices().map_err(|e| { + AppError::VideoError(format!("Failed to enumerate DirectShow devices: {}", e)) + })?; + + let mut devices = names + .into_iter() + .enumerate() + .map(|(index, name)| directshow_device_from_name(index, name)) + .collect::>(); + + devices.sort_by(|a, b| { + b.priority + .cmp(&a.priority) + .then_with(|| a.name.cmp(&b.name)) + }); + Ok(devices) +} + +pub fn find_best_device() -> Result { + enumerate_devices()?.into_iter().next().ok_or_else(|| { + AppError::VideoError("No DirectShow video capture devices found".to_string()) + }) +} + +pub fn parse_bridge_kind(value: Option<&str>) -> Option { + value.and_then(|_| None) +} + +pub fn select_recovery_device( + devices: &[VideoDeviceInfo], + hint: &VideoDeviceRecoveryHint, +) -> Option { + devices + .iter() + .find(|device| device.path == hint.path || device.bus_info == hint.bus_info) + .cloned() +} + +fn directshow_device_from_name(index: usize, name: String) -> VideoDeviceInfo { + let name = if name.trim().is_empty() { + format!("Windows Capture Device {}", index + 1) + } else { + name + }; + let path = PathBuf::from(format!("{}{}", DIRECTSHOW_DEVICE_PREFIX, name)); + let formats = enumerate_directshow_formats(&name); + let priority = score_capture_device(&name, &path.to_string_lossy(), &formats); + + VideoDeviceInfo { + path, + name: name.clone(), + driver: "directshow".to_string(), + bus_info: name.clone(), + card: name, + formats, + capabilities: DeviceCapabilities { + video_capture: true, + video_capture_mplane: false, + video_output: false, + streaming: true, + read_write: false, + }, + is_capture_card: true, + priority, + has_signal: true, + subdev_path: None, + bridge_kind: None, + } +} + +fn enumerate_directshow_formats(name: &str) -> Vec { + let Ok(capabilities) = hwcodec::capture::list_dshow_device_capabilities(name) else { + return fallback_windows_formats(); + }; + + let mut formats: Vec = Vec::new(); + for capability in capabilities { + let Some(format) = map_capture_format(capability.format) else { + continue; + }; + if capability.width == 0 || capability.height == 0 { + continue; + } + + if let Some(existing) = formats.iter_mut().find(|info| info.format == format) { + merge_resolution( + &mut existing.resolutions, + capability.width, + capability.height, + &capability.fps, + ); + continue; + } + + let mut resolutions = Vec::new(); + merge_resolution( + &mut resolutions, + capability.width, + capability.height, + &capability.fps, + ); + formats.push(FormatInfo { + format, + resolutions, + description: format_description(format).to_string(), + }); + } + + for info in &mut formats { + info.resolutions.sort_by(|left, right| { + b_pixels(right) + .cmp(&b_pixels(left)) + .then_with(|| right.width.cmp(&left.width)) + .then_with(|| right.height.cmp(&left.height)) + }); + } + formats.sort_by(|a, b| { + b.format + .priority() + .cmp(&a.format.priority()) + .then_with(|| a.description.cmp(&b.description)) + }); + + if formats.is_empty() { + fallback_windows_formats() + } else { + formats + } +} + +fn merge_resolution(resolutions: &mut Vec, width: u32, height: u32, fps: &[u32]) { + if let Some(existing) = resolutions + .iter_mut() + .find(|resolution| resolution.width == width && resolution.height == height) + { + existing.fps.extend(fps.iter().map(|value| *value as f64)); + normalize_fps_list(&mut existing.fps); + return; + } + + let mut fps_values = fps.iter().map(|value| *value as f64).collect::>(); + normalize_fps_list(&mut fps_values); + resolutions.push(ResolutionInfo::new(width, height, fps_values)); +} + +fn b_pixels(resolution: &ResolutionInfo) -> u32 { + resolution.width.saturating_mul(resolution.height) +} + +fn map_capture_format(format: hwcodec::capture::CapturePixelFormat) -> Option { + match format { + hwcodec::capture::CapturePixelFormat::Mjpeg => Some(PixelFormat::Mjpeg), + hwcodec::capture::CapturePixelFormat::Jpeg => Some(PixelFormat::Jpeg), + hwcodec::capture::CapturePixelFormat::Yuyv => Some(PixelFormat::Yuyv), + hwcodec::capture::CapturePixelFormat::Yvyu => Some(PixelFormat::Yvyu), + hwcodec::capture::CapturePixelFormat::Uyvy => Some(PixelFormat::Uyvy), + hwcodec::capture::CapturePixelFormat::Nv12 => Some(PixelFormat::Nv12), + hwcodec::capture::CapturePixelFormat::Nv21 => Some(PixelFormat::Nv21), + hwcodec::capture::CapturePixelFormat::Nv16 => Some(PixelFormat::Nv16), + hwcodec::capture::CapturePixelFormat::Nv24 => Some(PixelFormat::Nv24), + hwcodec::capture::CapturePixelFormat::Yuv420 => Some(PixelFormat::Yuv420), + hwcodec::capture::CapturePixelFormat::Yvu420 => Some(PixelFormat::Yvu420), + hwcodec::capture::CapturePixelFormat::Rgb24 => Some(PixelFormat::Rgb24), + hwcodec::capture::CapturePixelFormat::Bgr24 => Some(PixelFormat::Bgr24), + hwcodec::capture::CapturePixelFormat::Grey => Some(PixelFormat::Grey), + hwcodec::capture::CapturePixelFormat::Unknown => None, + } +} + +fn normalize_fps_list(fps_list: &mut Vec) { + fps_list.retain(|fps| fps.is_finite() && *fps > 0.0); + for fps in fps_list.iter_mut() { + *fps = (*fps * 100.0).round() / 100.0; + } + fps_list.sort_by(|a, b| b.total_cmp(a)); + fps_list.dedup_by(|a, b| (*a - *b).abs() < 0.01); +} + +fn format_description(format: PixelFormat) -> &'static str { + match format { + PixelFormat::Mjpeg => "MJPEG", + PixelFormat::Jpeg => "JPEG", + PixelFormat::Yuyv => "YUYV 4:2:2", + PixelFormat::Yvyu => "YVYU 4:2:2", + PixelFormat::Uyvy => "UYVY 4:2:2", + PixelFormat::Nv12 => "NV12", + PixelFormat::Nv21 => "NV21", + PixelFormat::Nv16 => "NV16", + PixelFormat::Nv24 => "NV24", + PixelFormat::Yuv420 => "YUV420", + PixelFormat::Yvu420 => "YVU420", + PixelFormat::Rgb565 => "RGB565", + PixelFormat::Rgb24 => "RGB24", + PixelFormat::Bgr24 => "BGR24", + PixelFormat::Grey => "GREY", + } +} + +fn score_capture_device(name: &str, device_id: &str, formats: &[FormatInfo]) -> u32 { + let haystack = format!("{} {}", name, device_id).to_ascii_lowercase(); + let mut score = 50; + + if formats + .iter() + .any(|format| format.format == PixelFormat::Mjpeg) + { + score += 25; + } + for keyword in ["capture", "hdmi", "uvc", "video", "usb"] { + if haystack.contains(keyword) { + score += 10; + } + } + + score +} + +fn fallback_windows_formats() -> Vec { + vec![FormatInfo { + format: PixelFormat::Mjpeg, + resolutions: Vec::new(), + description: "DirectShow auto-detected stream format".to_string(), + }] +} diff --git a/src/video/format.rs b/src/video/format.rs index ae55b27e..b35142d7 100644 --- a/src/video/format.rs +++ b/src/video/format.rs @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize}; use std::fmt; +#[cfg(unix)] use v4l2r::PixelFormat as V4l2rPixelFormat; /// Supported pixel formats @@ -85,11 +86,13 @@ impl PixelFormat { } /// Convert to v4l2r PixelFormat + #[cfg(unix)] pub fn to_v4l2r(&self) -> V4l2rPixelFormat { V4l2rPixelFormat::from(&self.to_fourcc()) } /// Convert from v4l2r PixelFormat + #[cfg(unix)] pub fn from_v4l2r(format: V4l2rPixelFormat) -> Option { let repr: [u8; 4] = format.into(); Self::from_fourcc(repr) diff --git a/src/video/mod.rs b/src/video/mod.rs index fa64949c..d3fff9d9 100644 --- a/src/video/mod.rs +++ b/src/video/mod.rs @@ -2,81 +2,30 @@ //! //! This module provides V4L2 video capture, encoding, and streaming functionality. -pub(crate) mod capture_limits; -pub(crate) mod capture_status; +pub mod capture; +pub mod codec; pub mod codec_constraints; -pub mod convert; -pub mod csi_bridge; -pub mod decoder; pub mod device; -pub mod encoder; pub mod format; pub mod frame; -pub mod shared_video_pipeline; +pub mod pipeline; +pub mod signal; pub mod stream_manager; pub mod streamer; pub mod traits; pub mod types; -pub mod usb_reset; -pub mod v4l2r_capture; -pub use convert::{PixelConverter, Yuv420pBuffer}; +pub use codec::{H264Encoder, H264EncoderType, JpegEncoder, PixelConverter, Yuv420pBuffer}; pub use device::{VideoDevice, VideoDeviceInfo}; -pub use encoder::{H264Encoder, H264EncoderType, JpegEncoder}; pub use format::PixelFormat; pub use frame::VideoFrame; -pub use shared_video_pipeline::{ +pub use pipeline::{ EncodedVideoFrame, SharedVideoPipeline, SharedVideoPipelineConfig, SharedVideoPipelineStats, }; +pub use signal::SignalStatus; pub use stream_manager::VideoStreamManager; pub use streamer::{Streamer, StreamerState}; -/// Fine-grained signal status reported by CSI/HDMI bridge devices. -/// -/// Only `rk_hdmirx` / `rkcif` / tc358743-class bridges can distinguish these -/// via `VIDIOC_QUERY_DV_TIMINGS` errno; USB UVC devices always report `Ok` -/// until they fail with a generic timeout. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SignalStatus { - /// HDMI cable physically disconnected (`ENOLINK`). - NoCable, - /// TMDS signal present but timings cannot be locked (`ENOLCK`). - NoSync, - /// Timings outside of hardware capability (`ERANGE`). - OutOfRange, - /// Generic "no usable source" (fallback for EINVAL / EIO / unknown errnos). - NoSignal, - /// UVC/USB isochronous protocol error (common kernel: status -71 / userspace EPROTO). - UvcUsbError, - /// UVC capture stalled (repeated DQBUF timeouts; often cable, hub, or controller load). - UvcCaptureStall, -} - -impl SignalStatus { - pub fn as_str(self) -> &'static str { - match self { - SignalStatus::NoCable => "no_cable", - SignalStatus::NoSync => "no_sync", - SignalStatus::OutOfRange => "out_of_range", - SignalStatus::NoSignal => "no_signal", - SignalStatus::UvcUsbError => "uvc_usb_error", - SignalStatus::UvcCaptureStall => "uvc_capture_stall", - } - } - - pub fn from_str(s: &str) -> Option { - Some(match s { - "no_cable" => SignalStatus::NoCable, - "no_sync" => SignalStatus::NoSync, - "out_of_range" => SignalStatus::OutOfRange, - "no_signal" => SignalStatus::NoSignal, - "uvc_usb_error" => SignalStatus::UvcUsbError, - "uvc_capture_stall" => SignalStatus::UvcCaptureStall, - _ => return None, - }) - } -} - impl From for streamer::StreamerState { fn from(value: SignalStatus) -> Self { match value { @@ -89,21 +38,3 @@ impl From for streamer::StreamerState { } } } - -pub(crate) fn is_rk_hdmirx_driver(driver: &str, card: &str) -> bool { - driver.eq_ignore_ascii_case("rk_hdmirx") || card.eq_ignore_ascii_case("rk_hdmirx") -} - -pub(crate) fn is_rk_hdmirx_device(device: &device::VideoDeviceInfo) -> bool { - is_rk_hdmirx_driver(&device.driver, &device.card) -} - -pub(crate) fn is_rkcif_driver(driver: &str) -> bool { - driver.eq_ignore_ascii_case("rkcif") -} - -/// Unified check for CSI/HDMI bridge devices (rk_hdmirx, rkcif, etc.) -/// that require special enumeration and format-selection logic. -pub(crate) fn is_csi_hdmi_bridge(device: &device::VideoDeviceInfo) -> bool { - is_rk_hdmirx_device(device) || is_rkcif_driver(&device.driver) -} diff --git a/src/video/shared_video_pipeline/encoder_state.rs b/src/video/pipeline/encoder_state.rs similarity index 97% rename from src/video/shared_video_pipeline/encoder_state.rs rename to src/video/pipeline/encoder_state.rs index b00f03a1..612df68b 100644 --- a/src/video/shared_video_pipeline/encoder_state.rs +++ b/src/video/pipeline/encoder_state.rs @@ -1,12 +1,12 @@ use crate::error::{AppError, Result}; -use crate::video::convert::{Nv12Converter, PixelConverter}; -use crate::video::decoder::MjpegTurboDecoder; -use crate::video::encoder::h264::{H264Config, H264Encoder, H264InputFormat}; -use crate::video::encoder::h265::{H265Config, H265Encoder, H265InputFormat}; -use crate::video::encoder::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType}; -use crate::video::encoder::traits::EncoderConfig; -use crate::video::encoder::vp8::{VP8Config, VP8Encoder}; -use crate::video::encoder::vp9::{VP9Config, VP9Encoder}; +use crate::video::codec::convert::{Nv12Converter, PixelConverter}; +use crate::video::codec::h264::{H264Config, H264Encoder, H264InputFormat}; +use crate::video::codec::h265::{H265Config, H265Encoder, H265InputFormat}; +use crate::video::codec::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType}; +use crate::video::codec::traits::EncoderConfig; +use crate::video::codec::vp8::{VP8Config, VP8Encoder}; +use crate::video::codec::vp9::{VP9Config, VP9Encoder}; +use crate::video::codec::MjpegTurboDecoder; use crate::video::format::{PixelFormat, Resolution}; #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] use hwcodec::ffmpeg_hw::{ @@ -14,7 +14,7 @@ use hwcodec::ffmpeg_hw::{ }; use tracing::info; -use super::SharedVideoPipelineConfig; +use super::shared::SharedVideoPipelineConfig; pub(super) struct EncoderThreadState { pub(super) encoder: Option>, diff --git a/src/video/pipeline/mod.rs b/src/video/pipeline/mod.rs new file mode 100644 index 00000000..9a43bbd4 --- /dev/null +++ b/src/video/pipeline/mod.rs @@ -0,0 +1,9 @@ +//! Video processing pipelines. + +mod encoder_state; +mod shared; + +pub use shared::{ + EncodedVideoFrame, PipelineStateNotification, SharedVideoPipeline, SharedVideoPipelineConfig, + SharedVideoPipelineStats, +}; diff --git a/src/video/shared_video_pipeline.rs b/src/video/pipeline/shared.rs similarity index 93% rename from src/video/shared_video_pipeline.rs rename to src/video/pipeline/shared.rs index 8d9402a3..42c96870 100644 --- a/src/video/shared_video_pipeline.rs +++ b/src/video/pipeline/shared.rs @@ -16,8 +16,6 @@ //! Session1 Session2 Session3 ... //! ``` -mod encoder_state; - use bytes::Bytes; use parking_lot::Mutex as ParkingMutex; use parking_lot::RwLock as ParkingRwLock; @@ -28,7 +26,7 @@ use std::time::{Duration, Instant}; use tokio::sync::{mpsc, watch, Mutex, RwLock}; use tracing::{debug, error, info, trace, warn}; -use self::encoder_state::{build_encoder_state, EncoderThreadState}; +use super::encoder_state::{build_encoder_state, EncoderThreadState}; /// Grace period before auto-stopping pipeline when no subscribers (in seconds) const AUTO_STOP_GRACE_PERIOD_SECS: u64 = 3; @@ -41,20 +39,27 @@ const NOSIGNAL_POLL_MAX: Duration = Duration::from_secs(20); /// Throttle repeated encoding errors to avoid log flooding const ENCODE_ERROR_THROTTLE_SECS: u64 = 5; +static PROCESS_START: std::sync::OnceLock = std::sync::OnceLock::new(); + use crate::error::{AppError, Result}; use crate::utils::LogThrottler; -use crate::video::capture_limits::{should_validate_jpeg_frame, MIN_CAPTURE_FRAME_SIZE}; -use crate::video::capture_status::{ +use crate::video::capture::runtime::{ + open_capture_stream, open_capture_stream_for_retry, CaptureOpenResult, +}; +use crate::video::capture::status::{ capture_error_log_key, classify_capture_io_error, is_device_lost_message, signal_status_from_capture_kind, CaptureIoErrorKind, }; -use crate::video::csi_bridge::{self, ProbeResult}; +use crate::video::capture::{is_source_changed_error, BridgeContext, CaptureStream}; +use crate::video::codec::h264_bitstream; +use crate::video::codec::registry::{EncoderBackend, VideoEncoderType}; +use crate::video::device::bridge::{self as csi_bridge, ProbeResult}; use crate::video::device::parse_bridge_kind; -use crate::video::encoder::registry::{EncoderBackend, VideoEncoderType}; use crate::video::format::{PixelFormat, Resolution}; use crate::video::frame::{FrameBuffer, FrameBufferPool, VideoFrame}; -use crate::video::v4l2r_capture::{is_source_changed_error, BridgeContext, V4l2rCaptureStream}; -use crate::video::SignalStatus; +use crate::video::signal::SignalStatus; + +const MIN_CAPTURE_FRAME_SIZE: usize = 128; #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] use hwcodec::ffmpeg_hw::last_error_message as ffmpeg_hw_last_error; @@ -122,7 +127,7 @@ pub struct SharedVideoPipelineConfig { /// Output codec type pub output_codec: VideoEncoderType, /// Bitrate preset (replaces raw bitrate_kbps) - pub bitrate_preset: crate::video::encoder::BitratePreset, + pub bitrate_preset: crate::video::codec::BitratePreset, /// Target FPS pub fps: u32, /// Encoder backend (None = auto select best available) @@ -135,7 +140,7 @@ impl Default for SharedVideoPipelineConfig { resolution: Resolution::HD720, input_format: PixelFormat::Yuyv, output_codec: VideoEncoderType::H264, - bitrate_preset: crate::video::encoder::BitratePreset::Balanced, + bitrate_preset: crate::video::codec::BitratePreset::Balanced, fps: 30, encoder_backend: None, } @@ -154,7 +159,7 @@ impl SharedVideoPipelineConfig { } /// Create H264 config with bitrate preset - pub fn h264(resolution: Resolution, preset: crate::video::encoder::BitratePreset) -> Self { + pub fn h264(resolution: Resolution, preset: crate::video::codec::BitratePreset) -> Self { Self { resolution, output_codec: VideoEncoderType::H264, @@ -164,7 +169,7 @@ impl SharedVideoPipelineConfig { } /// Create H265 config with bitrate preset - pub fn h265(resolution: Resolution, preset: crate::video::encoder::BitratePreset) -> Self { + pub fn h265(resolution: Resolution, preset: crate::video::codec::BitratePreset) -> Self { Self { resolution, output_codec: VideoEncoderType::H265, @@ -174,7 +179,7 @@ impl SharedVideoPipelineConfig { } /// Create VP8 config with bitrate preset - pub fn vp8(resolution: Resolution, preset: crate::video::encoder::BitratePreset) -> Self { + pub fn vp8(resolution: Resolution, preset: crate::video::codec::BitratePreset) -> Self { Self { resolution, output_codec: VideoEncoderType::VP8, @@ -184,7 +189,7 @@ impl SharedVideoPipelineConfig { } /// Create VP9 config with bitrate preset - pub fn vp9(resolution: Resolution, preset: crate::video::encoder::BitratePreset) -> Self { + pub fn vp9(resolution: Resolution, preset: crate::video::codec::BitratePreset) -> Self { Self { resolution, output_codec: VideoEncoderType::VP9, @@ -195,7 +200,7 @@ impl SharedVideoPipelineConfig { /// Create config with legacy bitrate_kbps (for compatibility during migration) pub fn with_bitrate_kbps(mut self, bitrate_kbps: u32) -> Self { - self.bitrate_preset = crate::video::encoder::BitratePreset::from_kbps(bitrate_kbps); + self.bitrate_preset = crate::video::codec::BitratePreset::from_kbps(bitrate_kbps); self } } @@ -261,6 +266,8 @@ pub struct SharedVideoPipeline { stats: Mutex, running: watch::Sender, running_rx: watch::Receiver, + h264_profile_level_id: watch::Sender>, + h264_profile_level_id_rx: watch::Receiver>, cmd_tx: ParkingRwLock>>, /// Fast running flag for blocking capture loop running_flag: AtomicBool, @@ -268,9 +275,9 @@ pub struct SharedVideoPipeline { sequence: AtomicU64, /// Atomic flag for keyframe request (avoids lock contention) keyframe_requested: AtomicBool, - /// Pipeline start time for PTS calculation (epoch millis, 0 = not set) - /// Uses AtomicI64 instead of Mutex for lock-free access - pipeline_start_time_ms: AtomicI64, + /// Pipeline start time for monotonic PTS calculation (microseconds from process start). + /// Uses AtomicI64 instead of Mutex for lock-free access. + pipeline_start_time_us: AtomicI64, pending_sync_geometry: ParkingMutex>, device_lost_reason: ParkingMutex>, state_notifier: ParkingRwLock>>, @@ -365,6 +372,7 @@ impl SharedVideoPipeline { ); let (running_tx, running_rx) = watch::channel(false); + let (h264_profile_tx, h264_profile_rx) = watch::channel(None); let pipeline = Arc::new(Self { config: RwLock::new(config), @@ -372,11 +380,13 @@ impl SharedVideoPipeline { stats: Mutex::new(SharedVideoPipelineStats::default()), running: running_tx, running_rx, + h264_profile_level_id: h264_profile_tx, + h264_profile_level_id_rx: h264_profile_rx, cmd_tx: ParkingRwLock::new(None), running_flag: AtomicBool::new(false), sequence: AtomicU64::new(0), keyframe_requested: AtomicBool::new(false), - pipeline_start_time_ms: AtomicI64::new(0), + pipeline_start_time_us: AtomicI64::new(0), pending_sync_geometry: ParkingMutex::new(None), device_lost_reason: ParkingMutex::new(None), state_notifier: ParkingRwLock::new(None), @@ -518,6 +528,20 @@ impl SharedVideoPipeline { self.running_rx.clone() } + pub fn h264_profile_level_id_watch(&self) -> watch::Receiver> { + self.h264_profile_level_id_rx.clone() + } + + fn update_h264_profile_level_id(&self, data: &[u8]) { + let Some(profile_level_id) = h264_bitstream::extract_profile_level_id(data) else { + return; + }; + if self.h264_profile_level_id.borrow().as_deref() == Some(profile_level_id.as_str()) { + return; + } + let _ = self.h264_profile_level_id.send(Some(profile_level_id)); + } + async fn broadcast_encoded(&self, frame: Arc) { let subscribers = { let guard = self.subscribers.read(); @@ -568,7 +592,7 @@ impl SharedVideoPipeline { subdev_path.clone(), parse_bridge_kind(bridge_kind.as_deref()), ); - let preopened: Option = match V4l2rCaptureStream::open_with_bridge( + let preopened: Option = match open_capture_stream( &device_path, config.resolution, config.input_format, @@ -712,7 +736,7 @@ impl SharedVideoPipeline { let bridge_ctx = BridgeContext::from_parts(subdev_path, parse_bridge_kind(bridge_kind.as_deref())); std::thread::spawn(move || { - let mut stream: Option = None; + let mut stream: Option = None; let mut initial_geometry: Option<(Resolution, PixelFormat)> = None; let mut resolution = config.resolution; let mut pixel_format = config.input_format; @@ -727,7 +751,7 @@ impl SharedVideoPipeline { stream = Some(s); } None => { - match V4l2rCaptureStream::open_with_bridge( + match open_capture_stream( &device_path, config.resolution, config.input_format, @@ -786,24 +810,13 @@ impl SharedVideoPipeline { } } - /// Helper: try to (re)open the capture stream. Returns: - /// * `Ok(Some(stream))` — opened successfully - /// * `Ok(None)` — CaptureNoSignal, keep retrying later - /// * `Err(())` — fatal (stop pipeline) - enum OpenResult { - Opened(V4l2rCaptureStream), - NoSignal(SignalStatus), - DeviceLost(String), - Fatal, - } - fn open_or_retry( device_path: &std::path::Path, config: &SharedVideoPipelineConfig, buffer_count: u32, bridge_ctx: BridgeContext, - ) -> OpenResult { - match V4l2rCaptureStream::open_with_bridge( + ) -> CaptureOpenResult { + match open_capture_stream_for_retry( device_path, config.resolution, config.input_format, @@ -811,28 +824,27 @@ impl SharedVideoPipeline { buffer_count.max(1), Duration::from_secs(2), bridge_ctx, + is_device_lost_message, ) { - Ok(s) => OpenResult::Opened(s), - Err(AppError::CaptureNoSignal { kind }) => { - debug!("Capture soft-restart: still no signal ({})", kind); - OpenResult::NoSignal(signal_status_from_capture_kind(&kind)) + CaptureOpenResult::NoSignal(status) => { + debug!("Capture soft-restart: still no signal ({:?})", status); + CaptureOpenResult::NoSignal(status) } - Err(e) => { - let reason = e.to_string(); - if is_device_lost_message(&reason) { - error!("Capture device lost during soft-restart: {}", e); - return OpenResult::DeviceLost(reason); - } - error!("Capture soft-restart failed: {}", e); - OpenResult::Fatal + CaptureOpenResult::DeviceLost(reason) => { + error!("Capture device lost during soft-restart: {}", reason); + CaptureOpenResult::DeviceLost(reason) } + CaptureOpenResult::Fatal => { + error!("Capture soft-restart failed"); + CaptureOpenResult::Fatal + } + opened => opened, } } let mut no_subscribers_since: Option = None; let grace_period = Duration::from_secs(AUTO_STOP_GRACE_PERIOD_SECS); let mut sequence: u64 = 0; - let mut validate_counter: u64 = 0; let mut consecutive_timeouts: u32 = 0; let capture_error_throttler = LogThrottler::with_secs(5); let mut suppressed_capture_errors: HashMap = HashMap::new(); @@ -869,7 +881,7 @@ impl SharedVideoPipeline { if stream.is_none() { match open_or_retry(&device_path, &config, buffer_count, bridge_ctx.clone()) { - OpenResult::Opened(new_stream) => { + CaptureOpenResult::Opened(new_stream) => { let new_res = new_stream.resolution(); let new_fmt = new_stream.format(); let new_stride = new_stream.stride(); @@ -945,7 +957,7 @@ impl SharedVideoPipeline { resolution.width, resolution.height, pixel_format, stride ); } - OpenResult::NoSignal(status) => { + CaptureOpenResult::NoSignal(status) => { consecutive_timeouts = consecutive_timeouts.saturating_add(1); if consecutive_timeouts >= CAPTURE_TIMEOUT_STOP_THRESHOLD { warn!( @@ -966,14 +978,14 @@ impl SharedVideoPipeline { std::thread::sleep(Duration::from_millis(wait_ms)); continue; } - OpenResult::DeviceLost(reason) => { + CaptureOpenResult::DeviceLost(reason) => { pipeline.mark_device_lost(reason); let _ = pipeline.running.send(false); pipeline.running_flag.store(false, Ordering::Release); let _ = frame_seq_tx.send(sequence.wrapping_add(1)); break; } - OpenResult::Fatal => { + CaptureOpenResult::Fatal => { let _ = pipeline.running.send(false); pipeline.running_flag.store(false, Ordering::Release); let _ = frame_seq_tx.send(sequence.wrapping_add(1)); @@ -1194,14 +1206,6 @@ impl SharedVideoPipeline { continue; } - validate_counter = validate_counter.wrapping_add(1); - if pixel_format.is_compressed() - && should_validate_jpeg_frame(validate_counter) - && !VideoFrame::is_valid_jpeg_bytes(&owned[..frame_size]) - { - continue; - } - owned.truncate(frame_size); // Notify streaming only after frame validation passes — // stale/warm-up frames from V4L2 kernel queues can cause @@ -1246,27 +1250,22 @@ impl SharedVideoPipeline { let input_format = state.input_format; let raw_frame = frame.data(); - // Calculate PTS from real capture timestamp (lock-free using AtomicI64) - // This ensures smooth playback even when capture timing varies - let frame_ts_ms = frame.capture_ts.elapsed().as_millis() as i64; - // Convert Instant to a comparable value (negate elapsed to get "time since epoch") - let current_ts_ms = -(frame_ts_ms); - - // Try to set start time if not yet set (first frame wins) - let start_ts = self.pipeline_start_time_ms.load(Ordering::Acquire); - let pts_ms = if start_ts == 0 { - // First frame - try to set the start time - // Use compare_exchange to ensure only one thread sets it - let _ = self.pipeline_start_time_ms.compare_exchange( + let process_start = PROCESS_START.get_or_init(Instant::now); + let current_ts_us = process_start.elapsed().as_micros() as i64; + let start_ts_us = self.pipeline_start_time_us.load(Ordering::Acquire); + let pts_ms = if start_ts_us == 0 { + let start_ts_us = match self.pipeline_start_time_us.compare_exchange( 0, - current_ts_ms, + current_ts_us, Ordering::AcqRel, Ordering::Acquire, - ); - 0 // First frame has PTS 0 + ) { + Ok(_) => current_ts_us, + Err(existing) => existing, + }; + current_ts_us.saturating_sub(start_ts_us) / 1000 } else { - // Subsequent frames: PTS = current - start - current_ts_ms - start_ts + current_ts_us.saturating_sub(start_ts_us) / 1000 }; #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] @@ -1370,6 +1369,9 @@ impl SharedVideoPipeline { let encoded = frames.into_iter().next().unwrap(); let is_keyframe = encoded.key == 1; let sequence = self.sequence.fetch_add(1, Ordering::Relaxed) + 1; + if codec == VideoEncoderType::H264 { + self.update_h264_profile_level_id(&encoded.data); + } // Debug log for H265 encoded frame if codec == VideoEncoderType::H265 && (is_keyframe || frame_count % 30 == 1) { @@ -1464,7 +1466,7 @@ impl SharedVideoPipeline { /// Set bitrate using preset pub async fn set_bitrate_preset( &self, - preset: crate::video::encoder::BitratePreset, + preset: crate::video::codec::BitratePreset, ) -> Result<()> { let bitrate_kbps = preset.bitrate_kbps(); let gop = { @@ -1478,7 +1480,7 @@ impl SharedVideoPipeline { /// Set bitrate using raw kbps value (converts to appropriate preset) pub async fn set_bitrate(&self, bitrate_kbps: u32) -> Result<()> { - let preset = crate::video::encoder::BitratePreset::from_kbps(bitrate_kbps); + let preset = crate::video::codec::BitratePreset::from_kbps(bitrate_kbps); self.set_bitrate_preset(preset).await } @@ -1549,7 +1551,7 @@ fn parse_h265_nal_types(data: &[u8]) -> Vec<(u8, usize)> { #[cfg(test)] mod tests { use super::*; - use crate::video::encoder::BitratePreset; + use crate::video::codec::BitratePreset; #[test] fn test_pipeline_config() { diff --git a/src/video/signal.rs b/src/video/signal.rs new file mode 100644 index 00000000..9c18614c --- /dev/null +++ b/src/video/signal.rs @@ -0,0 +1,47 @@ +//! Video signal status classification. + +/// Fine-grained signal status reported by CSI/HDMI bridge devices. +/// +/// Only `rk_hdmirx` / `rkcif` / tc358743-class bridges can distinguish these +/// via `VIDIOC_QUERY_DV_TIMINGS` errno; USB UVC devices always report `Ok` +/// until they fail with a generic timeout. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SignalStatus { + /// HDMI cable physically disconnected (`ENOLINK`). + NoCable, + /// TMDS signal present but timings cannot be locked (`ENOLCK`). + NoSync, + /// Timings outside of hardware capability (`ERANGE`). + OutOfRange, + /// Generic "no usable source" (fallback for EINVAL / EIO / unknown errnos). + NoSignal, + /// UVC/USB isochronous protocol error (common kernel: status -71 / userspace EPROTO). + UvcUsbError, + /// UVC capture stalled (repeated DQBUF timeouts; often cable, hub, or controller load). + UvcCaptureStall, +} + +impl SignalStatus { + pub fn as_str(self) -> &'static str { + match self { + SignalStatus::NoCable => "no_cable", + SignalStatus::NoSync => "no_sync", + SignalStatus::OutOfRange => "out_of_range", + SignalStatus::NoSignal => "no_signal", + SignalStatus::UvcUsbError => "uvc_usb_error", + SignalStatus::UvcCaptureStall => "uvc_capture_stall", + } + } + + pub fn from_str(s: &str) -> Option { + Some(match s { + "no_cable" => SignalStatus::NoCable, + "no_sync" => SignalStatus::NoSync, + "out_of_range" => SignalStatus::OutOfRange, + "no_signal" => SignalStatus::NoSignal, + "uvc_usb_error" => SignalStatus::UvcUsbError, + "uvc_capture_stall" => SignalStatus::UvcCaptureStall, + _ => return None, + }) + } +} diff --git a/src/video/stream_manager.rs b/src/video/stream_manager.rs index cddad1a4..3ee22698 100644 --- a/src/video/stream_manager.rs +++ b/src/video/stream_manager.rs @@ -37,8 +37,8 @@ use crate::events::{EventBus, SystemEvent, VideoDeviceInfo}; use crate::hid::HidController; use crate::stream::MjpegStreamHandler; use crate::video::codec_constraints::StreamCodecConstraints; +use crate::video::device::is_csi_hdmi_bridge; use crate::video::format::{PixelFormat, Resolution}; -use crate::video::is_csi_hdmi_bridge; use crate::video::streamer::{Streamer, StreamerState, StreamerStats}; use crate::video::traits::VideoOutput; @@ -762,9 +762,7 @@ impl VideoStreamManager { pub async fn subscribe_encoded_frames( &self, ) -> Option< - tokio::sync::mpsc::Receiver< - std::sync::Arc, - >, + tokio::sync::mpsc::Receiver>, > { // 1. Ensure video capture is initialized (for config discovery) if self.streamer.state().await == StreamerState::Uninitialized { @@ -803,12 +801,12 @@ impl VideoStreamManager { /// Get the current video encoding configuration from the shared pipeline pub async fn get_encoding_config( &self, - ) -> Option { + ) -> Option { self.webrtc_streamer.get_pipeline_config().await } /// Get current video codec type - pub async fn current_video_codec(&self) -> crate::video::encoder::VideoCodecType { + pub async fn current_video_codec(&self) -> crate::video::codec::VideoCodecType { self.webrtc_streamer.current_video_codec().await } @@ -823,7 +821,7 @@ impl VideoStreamManager { /// before subscribing to encoded frames. pub async fn set_video_codec( &self, - codec: crate::video::encoder::VideoCodecType, + codec: crate::video::codec::VideoCodecType, ) -> crate::error::Result<()> { self.webrtc_streamer.set_video_codec(codec).await } @@ -834,7 +832,7 @@ impl VideoStreamManager { /// based on client preferences. pub async fn set_bitrate_preset( &self, - preset: crate::video::encoder::BitratePreset, + preset: crate::video::codec::BitratePreset, ) -> crate::error::Result<()> { self.webrtc_streamer.set_bitrate_preset(preset).await } @@ -908,19 +906,19 @@ impl VideoStreamManager { } /// Convert VideoCodecType to lowercase string for frontend -fn codec_to_string(codec: crate::video::encoder::VideoCodecType) -> String { +fn codec_to_string(codec: crate::video::codec::VideoCodecType) -> String { match codec { - crate::video::encoder::VideoCodecType::H264 => "h264".to_string(), - crate::video::encoder::VideoCodecType::H265 => "h265".to_string(), - crate::video::encoder::VideoCodecType::VP8 => "vp8".to_string(), - crate::video::encoder::VideoCodecType::VP9 => "vp9".to_string(), + crate::video::codec::VideoCodecType::H264 => "h264".to_string(), + crate::video::codec::VideoCodecType::H265 => "h265".to_string(), + crate::video::codec::VideoCodecType::VP8 => "vp8".to_string(), + crate::video::codec::VideoCodecType::VP9 => "vp9".to_string(), } } #[cfg(test)] mod tests { use super::*; - use crate::video::encoder::VideoCodecType; + use crate::video::codec::VideoCodecType; #[test] fn test_codec_to_string() { diff --git a/src/video/streamer.rs b/src/video/streamer.rs index 364069d5..57176182 100644 --- a/src/video/streamer.rs +++ b/src/video/streamer.rs @@ -11,24 +11,25 @@ use std::time::Duration; use tokio::sync::RwLock; use tracing::{debug, error, info, trace, warn}; -use super::csi_bridge; use super::device::{ - enumerate_devices, find_best_device, parse_bridge_kind, select_recovery_device, VideoDevice, - VideoDeviceInfo, VideoDeviceRecoveryHint, + bridge as csi_bridge, enumerate_devices, find_best_device, is_csi_hdmi_bridge, + parse_bridge_kind, select_recovery_device, VideoDevice, VideoDeviceInfo, + VideoDeviceRecoveryHint, }; use super::format::{PixelFormat, Resolution}; use super::frame::{FrameBuffer, FrameBufferPool, VideoFrame}; -use super::is_csi_hdmi_bridge; use crate::error::{AppError, Result}; use crate::events::{EventBus, StreamDeviceLostKind, SystemEvent}; use crate::stream::MjpegStreamHandler; use crate::utils::LogThrottler; -use crate::video::capture_limits::{should_validate_jpeg_frame, MIN_CAPTURE_FRAME_SIZE}; -use crate::video::capture_status::{ +use crate::video::capture::runtime::open_capture_stream; +use crate::video::capture::status::{ capture_error_log_key, classify_capture_io_error, signal_status_from_capture_kind, CaptureIoErrorKind, }; -use crate::video::v4l2r_capture::{is_source_changed_error, BridgeContext, V4l2rCaptureStream}; +use crate::video::capture::{is_source_changed_error, BridgeContext, CaptureStream}; + +const MIN_CAPTURE_FRAME_SIZE: usize = 128; /// Streamer configuration #[derive(Debug, Clone)] @@ -358,10 +359,17 @@ impl Streamer { self.publish_event(self.current_state_event().await).await; let devices = enumerate_devices()?; - let device = devices - .into_iter() - .find(|d| d.path.to_string_lossy() == device_path) - .ok_or_else(|| AppError::VideoError("Video device not found".to_string()))?; + let device = if device_path.eq_ignore_ascii_case("auto") { + devices + .into_iter() + .next() + .ok_or_else(|| AppError::VideoError("No video devices found".to_string()))? + } else { + devices + .into_iter() + .find(|d| d.path.to_string_lossy() == device_path) + .ok_or_else(|| AppError::VideoError("Video device not found".to_string()))? + }; let (format, resolution) = self.resolve_capture_config(&device, format, resolution)?; @@ -853,12 +861,12 @@ impl Streamer { // On RK628 this prevents a kernel null-pointer deref. if let Some(subdev_path) = bridge_ctx.subdev_path.as_ref() { match probe_subdev_signal(subdev_path, bridge_ctx.kind) { - Some(crate::video::SignalStatus::NoCable) - | Some(crate::video::SignalStatus::NoSync) - | Some(crate::video::SignalStatus::NoSignal) - | Some(crate::video::SignalStatus::OutOfRange) => { + Some(crate::video::signal::SignalStatus::NoCable) + | Some(crate::video::signal::SignalStatus::NoSync) + | Some(crate::video::signal::SignalStatus::NoSignal) + | Some(crate::video::signal::SignalStatus::OutOfRange) => { let status = probe_subdev_signal(subdev_path, bridge_ctx.kind) - .unwrap_or(crate::video::SignalStatus::NoSignal); + .unwrap_or(crate::video::signal::SignalStatus::NoSignal); let wait_secs = backoff_secs(no_signal_restart_count); debug!( "Pre-STREAMON gate: subdev {:?} reports {:?} — \ @@ -884,7 +892,7 @@ impl Streamer { } // ── Open the capture stream ───────────────────────────────────────── - let mut stream_opt: Option = None; + let mut stream_opt: Option = None; let mut last_error: Option = None; for attempt in 0..MAX_RETRIES { @@ -893,7 +901,7 @@ impl Streamer { return; } - match V4l2rCaptureStream::open_with_bridge( + match open_capture_stream( &device_path, config.resolution, config.format, @@ -985,7 +993,6 @@ impl Streamer { let buffer_pool = Arc::new(FrameBufferPool::new(BUFFER_COUNT.max(4) as usize)); let mut signal_present = true; - let mut validate_counter: u64 = 0; let mut idle_since: Option = None; let mut fps_frame_count: u64 = 0; @@ -1091,7 +1098,7 @@ impl Streamer { break 'capture; } CaptureIoErrorKind::TransientSignal { status } => { - if status == Some(crate::video::SignalStatus::UvcUsbError) { + if status == Some(crate::video::signal::SignalStatus::UvcUsbError) { warn!( "Capture transient error (EPROTO/-71, often UVC USB): {}", e @@ -1145,14 +1152,6 @@ impl Streamer { continue 'capture; } - validate_counter = validate_counter.wrapping_add(1); - if pixel_format.is_compressed() - && should_validate_jpeg_frame(validate_counter) - && !VideoFrame::is_valid_jpeg_bytes(&owned[..frame_size]) - { - continue 'capture; - } - owned.truncate(frame_size); let frame = VideoFrame::from_pooled( Arc::new(FrameBuffer::new(owned, Some(buffer_pool.clone()))), @@ -1275,7 +1274,7 @@ impl Streamer { // Reset no_signal_since so the back-off timer is fresh for the new session. // no_signal_since will be re-set if the new session immediately times out. - // Continue 'session → re-open V4l2rCaptureStream with updated config. + // Continue 'session → re-open CaptureStream with updated config. } // 'session self.direct_active.store(false, Ordering::SeqCst); @@ -1580,7 +1579,7 @@ pub struct StreamerStats { fn probe_subdev_signal( subdev_path: &std::path::Path, kind: Option, -) -> Option { +) -> Option { let fd = match csi_bridge::open_subdev(subdev_path) { Ok(f) => f, Err(e) => { @@ -1588,7 +1587,7 @@ fn probe_subdev_signal( "probe_subdev_signal: failed to open {:?}: {}", subdev_path, e ); - return Some(crate::video::SignalStatus::NoSignal); + return Some(crate::video::signal::SignalStatus::NoSignal); } }; let kind = kind.unwrap_or(csi_bridge::CsiBridgeKind::Unknown); diff --git a/src/video/types.rs b/src/video/types.rs index c6333b2d..56ad9f9f 100644 --- a/src/video/types.rs +++ b/src/video/types.rs @@ -9,14 +9,14 @@ pub use super::format::{PixelFormat, Resolution}; // From video::frame pub use super::frame::VideoFrame; -// From video::encoder (codec-level types) -pub use super::encoder::{BitratePreset, VideoCodecType}; +// From video::codec (codec-level types) +pub use super::codec::{BitratePreset, VideoCodecType}; -// From video::encoder::registry -pub use super::encoder::registry::{EncoderBackend, VideoEncoderType}; +// From video::codec::registry +pub use super::codec::registry::{EncoderBackend, VideoEncoderType}; -// From video::shared_video_pipeline -pub use super::shared_video_pipeline::{ +// From video::pipeline +pub use super::pipeline::{ EncodedVideoFrame, PipelineStateNotification, SharedVideoPipeline, SharedVideoPipelineConfig, SharedVideoPipelineStats, }; diff --git a/src/web/handlers/account.rs b/src/web/handlers/account.rs new file mode 100644 index 00000000..5ac6c668 --- /dev/null +++ b/src/web/handlers/account.rs @@ -0,0 +1,151 @@ +use super::*; + +/// Change password request +#[derive(Deserialize)] +pub struct ChangePasswordRequest { + pub current_password: String, + pub new_password: String, +} + +/// Change current user's password +pub async fn change_password( + State(state): State>, + axum::Extension(session): axum::Extension, + Json(req): Json, +) -> Result> { + let current_user = state + .users + .single_user() + .await? + .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; + + if current_user.id != session.user_id { + return Err(AppError::AuthError("Invalid session".to_string())); + } + + if req.new_password.len() < 4 { + return Err(AppError::BadRequest( + "Password must be at least 4 characters".to_string(), + )); + } + + let verified = state + .users + .verify(¤t_user.username, &req.current_password) + .await?; + if verified.is_none() { + return Err(AppError::AuthError( + "Current password is incorrect".to_string(), + )); + } + + state + .users + .update_password(&session.user_id, &req.new_password) + .await?; + info!("Password changed for user ID: {}", session.user_id); + + Ok(Json(LoginResponse { + success: true, + message: Some("Password changed successfully".to_string()), + })) +} + +/// Change username request +#[derive(Deserialize)] +pub struct ChangeUsernameRequest { + pub username: String, + pub current_password: String, +} + +/// Change current user's username +pub async fn change_username( + State(state): State>, + axum::Extension(session): axum::Extension, + Json(req): Json, +) -> Result> { + let current_user = state + .users + .single_user() + .await? + .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; + + if current_user.id != session.user_id { + return Err(AppError::AuthError("Invalid session".to_string())); + } + + if req.username.len() < 2 { + return Err(AppError::BadRequest( + "Username must be at least 2 characters".to_string(), + )); + } + + let verified = state + .users + .verify(¤t_user.username, &req.current_password) + .await?; + if verified.is_none() { + return Err(AppError::AuthError( + "Current password is incorrect".to_string(), + )); + } + + if current_user.username != req.username { + state + .users + .update_username(&session.user_id, &req.username) + .await?; + } + info!("Username changed for user ID: {}", session.user_id); + + Ok(Json(LoginResponse { + success: true, + message: Some("Username changed successfully".to_string()), + })) +} + +/// Restart the application +pub async fn system_restart(State(state): State>) -> Json { + info!("System restart requested via API"); + + // Send shutdown signal + let _ = state.shutdown_tx.send(()); + + // Spawn restart task in background + tokio::spawn(async { + // Wait for resources to be released (OTG, video, etc.) + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + + // Get current executable and args + let exe = match std::env::current_exe() { + Ok(e) => e, + Err(e) => { + tracing::error!("Failed to get current exe: {}", e); + std::process::exit(1); + } + }; + let args: Vec = std::env::args().skip(1).collect(); + + info!("Restarting: {:?} {:?}", exe, args); + + // Use exec to replace current process (Unix) + #[cfg(unix)] + { + use std::os::unix::process::CommandExt; + let err = std::process::Command::new(&exe).args(&args).exec(); + tracing::error!("Failed to restart: {}", err); + std::process::exit(1); + } + + #[cfg(not(unix))] + { + let _ = std::process::Command::new(&exe).args(&args).spawn(); + std::process::exit(0); + } + }); + + Json(LoginResponse { + success: true, + message: Some("Restarting...".to_string()), + }) +} diff --git a/src/web/handlers/atx_api.rs b/src/web/handlers/atx_api.rs new file mode 100644 index 00000000..5dc6869f --- /dev/null +++ b/src/web/handlers/atx_api.rs @@ -0,0 +1,197 @@ +use super::*; + +use crate::atx::{AtxState, PowerStatus}; + +const WOL_HISTORY_DEFAULT_LIMIT: usize = 5; +const WOL_HISTORY_MAX_LIMIT: usize = 50; + +/// ATX state response +#[derive(Serialize)] +pub struct AtxStateResponse { + pub available: bool, + pub backend: String, + pub initialized: bool, + pub power_status: String, + pub led_supported: bool, +} + +impl From for AtxStateResponse { + fn from(state: AtxState) -> Self { + Self { + available: state.available, + backend: if state.power_configured || state.reset_configured { + format!( + "power: {}, reset: {}", + if state.power_configured { "yes" } else { "no" }, + if state.reset_configured { "yes" } else { "no" } + ) + } else { + "none".to_string() + }, + initialized: state.power_configured || state.reset_configured, + power_status: match state.power_status { + PowerStatus::On => "on".to_string(), + PowerStatus::Off => "off".to_string(), + PowerStatus::Unknown => "unknown".to_string(), + }, + led_supported: state.led_supported, + } + } +} + +/// Get ATX status +pub async fn atx_status(State(state): State>) -> Result> { + let atx_guard = state.atx.read().await; + + match atx_guard.as_ref() { + Some(atx) => { + let atx_state = atx.state().await; + Ok(Json(AtxStateResponse::from(atx_state))) + } + None => Ok(Json(AtxStateResponse { + available: false, + backend: "none".to_string(), + initialized: false, + power_status: "unknown".to_string(), + led_supported: false, + })), + } +} + +/// ATX power control request +#[derive(Deserialize)] +pub struct AtxPowerControlRequest { + pub action: String, // "short", "long", "reset" +} + +/// Control ATX power +pub async fn atx_power( + State(state): State>, + Json(req): Json, +) -> Result> { + let atx_guard = state.atx.read().await; + let atx = atx_guard + .as_ref() + .ok_or_else(|| AppError::Internal("ATX controller not initialized".to_string()))?; + + match req.action.as_str() { + "short" => { + atx.power_short().await?; + Ok(Json(LoginResponse { + success: true, + message: Some("Power short press executed".to_string()), + })) + } + "long" => { + atx.power_long().await?; + Ok(Json(LoginResponse { + success: true, + message: Some("Power long press (force off) executed".to_string()), + })) + } + "reset" => { + atx.reset().await?; + Ok(Json(LoginResponse { + success: true, + message: Some("Reset button pressed".to_string()), + })) + } + _ => Err(AppError::BadRequest(format!( + "Unknown ATX action: {}. Valid actions: short, long, reset", + req.action + ))), + } +} + +/// WOL request body +#[derive(Debug, Deserialize)] +pub struct WolRequest { + /// Target MAC address (e.g., "AA:BB:CC:DD:EE:FF" or "AA-BB-CC-DD-EE-FF") + pub mac_address: String, +} + +#[derive(Debug, Deserialize, Default)] +pub struct WolHistoryQuery { + /// Maximum history entries to return + pub limit: Option, +} + +#[derive(Debug, Serialize)] +pub struct WolHistoryEntry { + pub mac_address: String, + pub updated_at: i64, +} + +#[derive(Debug, Serialize)] +pub struct WolHistoryResponse { + pub history: Vec, +} + +fn normalize_wol_mac_address(mac_address: &str) -> String { + let normalized = mac_address.trim().to_uppercase().replace('-', ":"); + + if normalized.len() == 12 && normalized.chars().all(|c| c.is_ascii_hexdigit()) { + let mut mac_with_separator = String::with_capacity(17); + for (index, chunk) in normalized.as_bytes().chunks(2).enumerate() { + if index > 0 { + mac_with_separator.push(':'); + } + mac_with_separator.push(chunk[0] as char); + mac_with_separator.push(chunk[1] as char); + } + mac_with_separator + } else { + normalized + } +} + +/// Send Wake-on-LAN magic packet +pub async fn atx_wol( + State(state): State>, + Json(req): Json, +) -> Result> { + let mac_address = normalize_wol_mac_address(&req.mac_address); + + // Get WOL interface from config + let config = state.config.get(); + let interface = if config.atx.wol_interface.is_empty() { + None + } else { + Some(config.atx.wol_interface.as_str()) + }; + + // Send WOL packet + crate::atx::send_wol(&mac_address, interface)?; + + if let Err(error) = crate::atx::record_wol_history(state.db.pool(), &mac_address).await { + warn!("Failed to persist WOL history: {}", error); + } + + Ok(Json(LoginResponse { + success: true, + message: Some(format!("WOL packet sent to {}", mac_address)), + })) +} + +/// Get WOL history +pub async fn atx_wol_history( + State(state): State>, + Query(query): Query, +) -> Result> { + let limit = query + .limit + .unwrap_or(WOL_HISTORY_DEFAULT_LIMIT) + .clamp(1, WOL_HISTORY_MAX_LIMIT); + + let rows = crate::atx::list_wol_history(state.db.pool(), limit).await?; + + let history = rows + .into_iter() + .map(|(mac_address, updated_at)| WolHistoryEntry { + mac_address, + updated_at, + }) + .collect(); + + Ok(Json(WolHistoryResponse { history })) +} diff --git a/src/web/handlers/audio_api.rs b/src/web/handlers/audio_api.rs new file mode 100644 index 00000000..3d560e4d --- /dev/null +++ b/src/web/handlers/audio_api.rs @@ -0,0 +1,83 @@ +use super::*; + +use crate::audio::{AudioQuality, AudioStatus}; + +/// Audio status response (re-exports AudioStatus from audio module) +pub type AudioStatusResponse = AudioStatus; + +/// Get audio status +pub async fn audio_status(State(state): State>) -> Json { + Json(state.audio.status().await) +} + +/// Start audio streaming +pub async fn start_audio_streaming( + State(state): State>, +) -> Result> { + state.audio.start_streaming().await?; + + // Reconnect audio sources for existing WebRTC sessions + // This ensures sessions created before audio was enabled will receive audio + state.stream_manager.reconnect_webrtc_audio_sources().await; + + Ok(Json(LoginResponse { + success: true, + message: Some("Audio streaming started".to_string()), + })) +} + +/// Stop audio streaming +pub async fn stop_audio_streaming( + State(state): State>, +) -> Result> { + state.audio.stop_streaming().await?; + Ok(Json(LoginResponse { + success: true, + message: Some("Audio streaming stopped".to_string()), + })) +} + +/// Set audio quality request +#[derive(Deserialize)] +pub struct SetAudioQualityRequest { + pub quality: String, +} + +/// Set audio quality +pub async fn set_audio_quality( + State(state): State>, + Json(req): Json, +) -> Result> { + let quality = req.quality.parse::()?; + state.audio.set_quality(quality).await?; + Ok(Json(LoginResponse { + success: true, + message: Some(format!("Audio quality set to {}", quality)), + })) +} + +/// Select audio device request +#[derive(Deserialize)] +pub struct SelectAudioDeviceRequest { + pub device: String, +} + +/// Select audio device +pub async fn select_audio_device( + State(state): State>, + Json(req): Json, +) -> Result> { + state.audio.select_device(&req.device).await?; + Ok(Json(LoginResponse { + success: true, + message: Some(format!("Audio device selected: {}", req.device)), + })) +} + +/// List audio devices +pub async fn list_audio_devices( + State(state): State>, +) -> Result>> { + let devices = state.audio.list_devices().await?; + Ok(Json(devices)) +} diff --git a/src/web/handlers/auth.rs b/src/web/handlers/auth.rs new file mode 100644 index 00000000..babf37a9 --- /dev/null +++ b/src/web/handlers/auth.rs @@ -0,0 +1,107 @@ +use super::*; + +#[derive(Deserialize)] +pub struct LoginRequest { + pub username: String, + pub password: String, +} + +#[derive(Serialize)] +pub struct LoginResponse { + pub success: bool, + pub message: Option, +} + +pub async fn login( + State(state): State>, + cookies: CookieJar, + Json(req): Json, +) -> Result<(CookieJar, Json)> { + let config = state.config.get(); + + // Check if system is initialized + if !config.initialized { + return Err(AppError::BadRequest("System not initialized".to_string())); + } + + // Verify user credentials + let user = state + .users + .verify(&req.username, &req.password) + .await? + .ok_or_else(|| AppError::AuthError("Invalid username or password".to_string()))?; + + if !config.auth.single_user_allow_multiple_sessions { + // Kick existing sessions before creating a new one. + let revoked_ids = state.sessions.list_ids().await?; + state.sessions.delete_all().await?; + state.remember_revoked_sessions(revoked_ids).await; + } + + // Create session + let session = state.sessions.create(&user.id).await?; + + // Set session cookie + let cookie = Cookie::build((SESSION_COOKIE, session.id)) + .path("/") + .http_only(true) + .same_site(SameSite::Lax) + .max_age(time::Duration::seconds( + config.auth.session_timeout_secs as i64, + )) + .build(); + + Ok(( + cookies.add(cookie), + Json(LoginResponse { + success: true, + message: None, + }), + )) +} + +pub async fn logout( + State(state): State>, + cookies: CookieJar, +) -> Result<(CookieJar, Json)> { + // Get session ID from cookie + if let Some(cookie) = cookies.get(SESSION_COOKIE) { + state.sessions.delete(cookie.value()).await?; + } + + // Remove cookie + let cookie = Cookie::build((SESSION_COOKIE, "")) + .path("/") + .max_age(time::Duration::ZERO) + .build(); + + Ok(( + cookies.remove(cookie), + Json(LoginResponse { + success: true, + message: Some("Logged out".to_string()), + }), + )) +} + +#[derive(Serialize)] +pub struct AuthCheckResponse { + pub authenticated: bool, + pub user: Option, +} + +pub async fn auth_check( + State(state): State>, + axum::Extension(session): axum::Extension, +) -> Json { + // Get user info from user_id + let username = match state.users.single_user().await { + Ok(Some(user)) if user.id == session.user_id => Some(user.username), + _ => None, + }; + + Json(AuthCheckResponse { + authenticated: true, + user: username, + }) +} diff --git a/src/web/handlers/config/apply.rs b/src/web/handlers/config/apply.rs index 25f7c114..d8cbd42a 100644 --- a/src/web/handlers/config/apply.rs +++ b/src/web/handlers/config/apply.rs @@ -39,12 +39,20 @@ fn hid_backend_type(config: &HidConfig) -> crate::hid::HidBackendType { } async fn reconcile_otg_from_store(state: &Arc) -> Result<()> { - let config = state.config.get(); - state - .otg_service - .apply_config(&config.hid, &config.msd) - .await - .map_err(|e| AppError::Config(format!("OTG reconcile failed: {}", e))) + #[cfg(not(unix))] + { + let _ = state; + Ok(()) + } + #[cfg(unix)] + { + let config = state.config.get(); + state + .otg_service + .apply_config(&config.hid, &config.msd) + .await + .map_err(|e| AppError::Config(format!("OTG reconcile failed: {}", e))) + } } pub async fn apply_video_config( @@ -207,6 +215,7 @@ pub async fn apply_hid_config( Ok(()) } +#[cfg(unix)] pub async fn apply_msd_config( state: &Arc, old_config: &MsdConfig, diff --git a/src/web/handlers/config/atx.rs b/src/web/handlers/config/atx.rs index 1e60270f..3621ba83 100644 --- a/src/web/handlers/config/atx.rs +++ b/src/web/handlers/config/atx.rs @@ -25,6 +25,7 @@ pub async fn update_atx_config( let _apply_guard = try_apply_lock(&state.config_apply_locks.atx, "atx")?; let mut merged_atx_config = old_atx_config.clone(); req.apply_to(&mut merged_atx_config); + validate_windows_atx_backends(&merged_atx_config)?; validate_serial_device_conflict(&merged_atx_config, ¤t_config.hid)?; state @@ -41,6 +42,23 @@ pub async fn update_atx_config( Ok(Json(new_atx_config)) } +fn validate_windows_atx_backends(atx: &AtxConfig) -> Result<()> { + if !cfg!(windows) { + return Ok(()); + } + + for (name, key) in [("power", &atx.power), ("reset", &atx.reset)] { + if !matches!(key.driver, AtxDriverType::Serial | AtxDriverType::None) { + return Err(AppError::BadRequest(format!( + "Windows ATX {} only supports serial relay or none", + name + ))); + } + } + + Ok(()) +} + fn validate_serial_device_conflict(atx: &AtxConfig, hid: &HidConfig) -> Result<()> { if hid.backend != HidBackend::Ch9329 { return Ok(()); @@ -91,4 +109,13 @@ mod tests { assert!(validate_serial_device_conflict(&atx, &hid).is_ok()); } + + #[test] + fn test_validate_windows_atx_backends_allows_serial() { + let mut atx = AtxConfig::default(); + atx.power.driver = AtxDriverType::Serial; + atx.reset.driver = AtxDriverType::None; + + assert!(validate_windows_atx_backends(&atx).is_ok()); + } } diff --git a/src/web/handlers/config/mod.rs b/src/web/handlers/config/mod.rs index a02446b3..8886be64 100644 --- a/src/web/handlers/config/mod.rs +++ b/src/web/handlers/config/mod.rs @@ -5,6 +5,7 @@ mod atx; mod audio; mod auth; mod hid; +#[cfg(unix)] mod msd; mod redfish; mod rtsp; @@ -17,6 +18,7 @@ pub use atx::{get_atx_config, update_atx_config}; pub use audio::{get_audio_config, update_audio_config}; pub use auth::{get_auth_config, update_auth_config}; pub use hid::{get_hid_config, update_hid_config}; +#[cfg(unix)] pub use msd::{get_msd_config, update_msd_config}; pub use redfish::{get_redfish_config, update_redfish_config}; pub use rtsp::{get_rtsp_config, get_rtsp_status, update_rtsp_config}; diff --git a/src/web/handlers/config/rustdesk.rs b/src/web/handlers/config/rustdesk.rs index a901b47e..3546c30b 100644 --- a/src/web/handlers/config/rustdesk.rs +++ b/src/web/handlers/config/rustdesk.rs @@ -77,11 +77,14 @@ pub async fn update_rustdesk_config( let _apply_guard = try_apply_lock(&state.config_apply_locks.rustdesk, "rustdesk")?; let old_config = state.config.get().rustdesk.clone(); + let mut merged_config = old_config.clone(); + req.apply_to(&mut merged_config); + req.validate_merged(&merged_config)?; state .config .update(|config| { - req.apply_to(&mut config.rustdesk); + config.rustdesk = merged_config.clone(); }) .await?; diff --git a/src/web/handlers/config/types.rs b/src/web/handlers/config/types.rs index cbddd46b..f9c5935a 100644 --- a/src/web/handlers/config/types.rs +++ b/src/web/handlers/config/types.rs @@ -4,6 +4,7 @@ use crate::rtsp::RtspServiceStatus; use crate::rustdesk::config::RustDeskConfig; use base64::{engine::general_purpose::STANDARD, Engine as _}; use serde::{Deserialize, Serialize}; +#[cfg(unix)] use std::path::Path; use typeshare::typeshare; @@ -358,12 +359,14 @@ impl HidConfigUpdate { } #[typeshare] +#[cfg(unix)] #[derive(Debug, Deserialize)] pub struct MsdConfigUpdate { pub enabled: Option, pub msd_dir: Option, } +#[cfg(unix)] impl MsdConfigUpdate { pub fn validate(&self) -> crate::error::Result<()> { if let Some(ref dir) = self.msd_dir { @@ -472,7 +475,8 @@ impl AtxConfigUpdate { fn validate_key_config(key: &AtxKeyConfigUpdate, name: &str) -> crate::error::Result<()> { if let Some(ref device) = key.device { - if !device.is_empty() && !std::path::Path::new(device).exists() { + if !device.trim().is_empty() && !cfg!(windows) && !std::path::Path::new(device).exists() + { return Err(AppError::BadRequest(format!( "{} device '{}' does not exist", name, device @@ -542,6 +546,12 @@ impl AtxConfigUpdate { ) -> crate::error::Result<()> { match key.driver { crate::atx::AtxDriverType::Serial => { + if key.device.trim().is_empty() { + return Err(AppError::BadRequest(format!( + "{} serial device cannot be empty", + name + ))); + } if key.pin == 0 { return Err(AppError::BadRequest(format!( "{} serial channel must be 1-based (>= 1)", @@ -739,6 +749,15 @@ impl RustDeskConfigUpdate { Ok(()) } + pub fn validate_merged(&self, config: &RustDeskConfig) -> crate::error::Result<()> { + if config.enabled && config.rendezvous_server.trim().is_empty() { + return Err(AppError::BadRequest( + "RustDesk ID server is required".into(), + )); + } + Ok(()) + } + pub fn apply_to(&self, config: &mut RustDeskConfig) { if let Some(enabled) = self.enabled { config.enabled = enabled; diff --git a/src/web/handlers/devices.rs b/src/web/handlers/devices.rs index da2d3cb6..742506e6 100644 --- a/src/web/handlers/devices.rs +++ b/src/web/handlers/devices.rs @@ -1,29 +1,38 @@ use axum::Json; +#[cfg(unix)] use serde::Deserialize; use crate::atx::{discover_devices, AtxDevices}; +#[cfg(unix)] use crate::error::{AppError, Result}; -use crate::video::usb_reset; +#[cfg(unix)] +use crate::platform::usb_reset; pub async fn list_atx_devices() -> Json { Json(discover_devices()) } +#[cfg(unix)] pub async fn list_usb_devices() -> Json> { Json(usb_reset::list_usb_devices()) } +#[cfg(unix)] #[derive(Deserialize)] pub struct UsbResetRequest { pub bus_num: u32, pub dev_num: u32, } +#[cfg(unix)] pub async fn reset_usb_device(Json(req): Json) -> Result> { usb_reset::reset_usb_device(req.bus_num, req.dev_num).map_err(|e| { - AppError::VideoError(format!( - "USB reset failed for device {}-{}: {}", - req.bus_num, req.dev_num, e + AppError::Io(std::io::Error::new( + e.kind(), + format!( + "USB reset failed for device {}-{}: {}", + req.bus_num, req.dev_num, e + ), )) })?; Ok(Json(serde_json::json!({ diff --git a/src/web/handlers/extensions.rs b/src/web/handlers/extensions.rs index d6cd4b0e..2e7ee8c9 100644 --- a/src/web/handlers/extensions.rs +++ b/src/web/handlers/extensions.rs @@ -13,6 +13,27 @@ use crate::extensions::{ }; use crate::state::AppState; +fn validate_gostc_enabled(config: &GostcConfig) -> Result<()> { + if config.addr.trim().is_empty() { + return Err(AppError::BadRequest( + "GOSTC server address is required".into(), + )); + } + if config.key.is_empty() { + return Err(AppError::BadRequest("GOSTC client key is required".into())); + } + Ok(()) +} + +fn validate_easytier_enabled(config: &EasytierConfig) -> Result<()> { + if config.network_name.trim().is_empty() { + return Err(AppError::BadRequest( + "EasyTier network name is required".into(), + )); + } + Ok(()) +} + pub async fn list_extensions(State(state): State>) -> Json { let config = state.config.get(); let mgr = &state.extensions; @@ -179,40 +200,40 @@ pub async fn update_gostc_config( State(state): State>, Json(req): Json, ) -> Result> { - let was_enabled = state.config.get().extensions.gostc.enabled; + let current_config = state.config.get(); + let was_enabled = current_config.extensions.gostc.enabled; + let mut next_gostc = current_config.extensions.gostc.clone(); + + if let Some(enabled) = req.enabled { + next_gostc.enabled = enabled; + } + if let Some(ref addr) = req.addr { + next_gostc.addr = addr.clone(); + } + if let Some(ref key) = req.key { + next_gostc.key = key.clone(); + } + if let Some(tls) = req.tls { + next_gostc.tls = tls; + } + + if next_gostc.enabled { + validate_gostc_enabled(&next_gostc)?; + } state .config .update(|config| { - let gostc = &mut config.extensions.gostc; - if let Some(enabled) = req.enabled { - gostc.enabled = enabled; - } - if let Some(ref addr) = req.addr { - gostc.addr = addr.clone(); - } - if let Some(ref key) = req.key { - gostc.key = key.clone(); - } - if let Some(tls) = req.tls { - gostc.tls = tls; - } + config.extensions.gostc = next_gostc.clone(); }) .await?; let new_config = state.config.get(); let is_enabled = new_config.extensions.gostc.enabled; - let has_key = !new_config.extensions.gostc.key.is_empty(); - let has_addr = !new_config.extensions.gostc.addr.trim().is_empty(); if was_enabled && !is_enabled { state.extensions.stop(ExtensionId::Gostc).await.ok(); - } else if !was_enabled - && is_enabled - && has_key - && has_addr - && state.extensions.check_available(ExtensionId::Gostc) - { + } else if !was_enabled && is_enabled && state.extensions.check_available(ExtensionId::Gostc) { state .extensions .start(ExtensionId::Gostc, &new_config.extensions) @@ -227,40 +248,43 @@ pub async fn update_easytier_config( State(state): State>, Json(req): Json, ) -> Result> { - let was_enabled = state.config.get().extensions.easytier.enabled; + let current_config = state.config.get(); + let was_enabled = current_config.extensions.easytier.enabled; + let mut next_easytier = current_config.extensions.easytier.clone(); + + if let Some(enabled) = req.enabled { + next_easytier.enabled = enabled; + } + if let Some(ref name) = req.network_name { + next_easytier.network_name = name.clone(); + } + if let Some(ref secret) = req.network_secret { + next_easytier.network_secret = secret.clone(); + } + if let Some(ref peers) = req.peer_urls { + next_easytier.peer_urls = peers.clone(); + } + if req.virtual_ip.is_some() { + next_easytier.virtual_ip = req.virtual_ip.clone(); + } + + if next_easytier.enabled { + validate_easytier_enabled(&next_easytier)?; + } state .config .update(|config| { - let et = &mut config.extensions.easytier; - if let Some(enabled) = req.enabled { - et.enabled = enabled; - } - if let Some(ref name) = req.network_name { - et.network_name = name.clone(); - } - if let Some(ref secret) = req.network_secret { - et.network_secret = secret.clone(); - } - if let Some(ref peers) = req.peer_urls { - et.peer_urls = peers.clone(); - } - if req.virtual_ip.is_some() { - et.virtual_ip = req.virtual_ip.clone(); - } + config.extensions.easytier = next_easytier.clone(); }) .await?; let new_config = state.config.get(); let is_enabled = new_config.extensions.easytier.enabled; - let has_name = !new_config.extensions.easytier.network_name.is_empty(); if was_enabled && !is_enabled { state.extensions.stop(ExtensionId::Easytier).await.ok(); - } else if !was_enabled - && is_enabled - && has_name - && state.extensions.check_available(ExtensionId::Easytier) + } else if !was_enabled && is_enabled && state.extensions.check_available(ExtensionId::Easytier) { state .extensions diff --git a/src/web/handlers/hid_api.rs b/src/web/handlers/hid_api.rs new file mode 100644 index 00000000..00597b2d --- /dev/null +++ b/src/web/handlers/hid_api.rs @@ -0,0 +1,53 @@ +use super::*; + +#[derive(Serialize)] +pub struct HidStatus { + pub available: bool, + pub backend: String, + pub initialized: bool, + pub online: bool, + pub supports_absolute_mouse: bool, + pub keyboard_leds_enabled: bool, + pub led_state: crate::hid::LedState, + pub screen_resolution: Option<(u32, u32)>, + pub device: Option, + pub error: Option, + pub error_code: Option, +} + +/// OTG self-check status for troubleshooting USB gadget issues +#[cfg(unix)] +pub async fn hid_otg_self_check( + State(state): State>, +) -> Json { + let config = state.config.get(); + Json(crate::otg::self_check::run(config.as_ref())) +} + +/// Get HID status +pub async fn hid_status(State(state): State>) -> Json { + let hid = state.hid.snapshot().await; + Json(HidStatus { + available: hid.available, + backend: hid.backend, + initialized: hid.initialized, + online: hid.online, + supports_absolute_mouse: hid.supports_absolute_mouse, + keyboard_leds_enabled: hid.keyboard_leds_enabled, + led_state: hid.led_state, + screen_resolution: hid.screen_resolution, + device: hid.device, + error: hid.error, + error_code: hid.error_code, + }) +} + +/// Reset HID state +pub async fn hid_reset(State(state): State>) -> Result> { + state.hid.reset().await?; + + Ok(Json(LoginResponse { + success: true, + message: Some("HID state reset".to_string()), + })) +} diff --git a/src/web/handlers/inventory.rs b/src/web/handlers/inventory.rs new file mode 100644 index 00000000..5908d1ed --- /dev/null +++ b/src/web/handlers/inventory.rs @@ -0,0 +1,182 @@ +use super::*; + +#[derive(Serialize)] +pub struct DeviceList { + pub video: Vec, + pub serial: Vec, + pub audio: Vec, + pub udc: Vec, + pub extensions: ExtensionsAvailability, +} + +#[derive(Serialize)] +pub struct ExtensionsAvailability { + pub ttyd_available: bool, + pub rustdesk_available: bool, +} + +#[derive(Serialize)] +pub struct VideoDevice { + pub path: String, + pub name: String, + pub driver: String, + pub formats: Vec, + pub usb_bus: Option, + pub has_signal: bool, +} + +#[derive(Serialize)] +pub struct VideoFormat { + pub format: String, + pub description: String, + pub resolutions: Vec, +} + +#[derive(Serialize)] +pub struct VideoResolution { + pub width: u32, + pub height: u32, + pub fps: Vec, +} + +#[derive(Serialize)] +pub struct SerialDevice { + pub path: String, + pub name: String, +} + +#[derive(Serialize)] +pub struct AudioDevice { + pub name: String, + pub description: String, + pub is_hdmi: bool, + pub usb_bus: Option, +} + +#[derive(Serialize)] +pub struct UdcDevice { + pub name: String, +} + +/// Extract USB bus port from V4L2 bus_info string +/// Examples: +/// - "usb-0000:00:14.0-1" -> Some("1") +/// - "usb-xhci-hcd.0-1.2" -> Some("1.2") +/// - "usb-0000:00:14.0-1.3.2" -> Some("1.3.2") +/// - "platform:..." -> None +fn extract_usb_bus_from_bus_info(bus_info: &str) -> Option { + if !bus_info.starts_with("usb-") { + return None; + } + // Find the last '-' which separates the USB port + // e.g., "usb-0000:00:14.0-1" -> "1" + // e.g., "usb-xhci-hcd.0-1.2" -> "1.2" + let parts: Vec<&str> = bus_info.rsplitn(2, '-').collect(); + if parts.len() == 2 { + let port = parts[0]; + // Verify it looks like a USB port (starts with digit) + if port + .chars() + .next() + .map(|c| c.is_ascii_digit()) + .unwrap_or(false) + { + return Some(port.to_string()); + } + } + None +} + +pub async fn list_devices(State(state): State>) -> Json { + let platform = PlatformCapabilities::current(); + + // Detect video devices + let video_devices = match state.stream_manager.list_devices().await { + Ok(devices) => devices + .into_iter() + .map(|d| { + // Extract USB bus from bus_info (e.g., "usb-0000:00:14.0-1" -> "1") + // or "usb-xhci-hcd.0-1.2" -> "1.2" + let usb_bus = extract_usb_bus_from_bus_info(&d.bus_info); + VideoDevice { + path: d.path.to_string_lossy().to_string(), + name: d.name, + driver: d.driver, + formats: d + .formats + .iter() + .map(|f| VideoFormat { + format: format!("{}", f.format), + description: f.description.clone(), + resolutions: f + .resolutions + .iter() + .map(|r| VideoResolution { + width: r.width, + height: r.height, + fps: r.fps.clone(), + }) + .collect(), + }) + .collect(), + usb_bus, + has_signal: d.has_signal, + } + }) + .collect(), + Err(e) => { + warn!(error = %e, "Video device enumeration failed; returning empty video list for /api/devices"); + vec![] + } + }; + + let serial_devices = list_serial_ports() + .into_iter() + .map(|path| SerialDevice { + name: std::path::Path::new(&path) + .file_name() + .and_then(|name| name.to_str()) + .unwrap_or(&path) + .to_string(), + path, + }) + .collect(); + + #[cfg(unix)] + let udc_devices = crate::otg::list_udc_devices() + .into_iter() + .map(|name| UdcDevice { name }) + .collect(); + #[cfg(not(unix))] + let udc_devices = Vec::new(); + + // Detect audio devices + let audio_devices = match state.audio.list_devices().await { + Ok(devices) => devices + .into_iter() + .map(|d| AudioDevice { + name: d.name, + description: d.description, + is_hdmi: d.is_hdmi, + usb_bus: d.usb_bus, + }) + .collect(), + Err(_) => vec![], + }; + + // Check extension availability + let ttyd_available = state + .extensions + .check_available(crate::extensions::ExtensionId::Ttyd); + + Json(DeviceList { + video: video_devices, + serial: serial_devices, + audio: audio_devices, + udc: udc_devices, + extensions: ExtensionsAvailability { + ttyd_available, + rustdesk_available: platform.rustdesk.available, + }, + }) +} diff --git a/src/web/handlers/mod.rs b/src/web/handlers/mod.rs index 41621b8c..81eac2e0 100644 --- a/src/web/handlers/mod.rs +++ b/src/web/handlers/mod.rs @@ -3,7 +3,38 @@ pub mod devices; pub mod extensions; pub mod terminal; -use axum::{extract::State, Json}; +mod account; +mod atx_api; +mod audio_api; +mod auth; +mod hid_api; +mod inventory; +#[cfg(unix)] +mod msd_api; +mod setup; +mod stream; +mod system; +mod update_api; +mod webrtc; + +pub use account::*; +pub use atx_api::*; +pub use audio_api::*; +pub use auth::*; +pub use hid_api::*; +pub use inventory::*; +#[cfg(unix)] +pub use msd_api::*; +pub use setup::*; +pub use stream::*; +pub use system::*; +pub use update_api::*; +pub use webrtc::*; + +use axum::{ + extract::{Query, State}, + Json, +}; use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite}; use serde::{Deserialize, Serialize}; use std::sync::Arc; @@ -12,3513 +43,14 @@ use tracing::{info, warn}; use self::config::apply::ConfigApplyOptions; use crate::auth::{Session, SESSION_COOKIE}; use crate::config::StreamMode; +use crate::diagnostics::{get_device_info, get_disk_space, DeviceInfo, DiskSpaceInfo}; use crate::error::{AppError, Result}; +use crate::platform::PlatformCapabilities; use crate::state::AppState; use crate::update::{UpdateChannel, UpdateOverviewResponse, UpdateStatusResponse, UpgradeRequest}; -use crate::utils::{hostname_uname, list_dir_names, read_trimmed}; -use crate::video::codec_constraints::codec_to_id; -use crate::video::encoder::{ +use crate::utils::list_serial_ports; +use crate::video::codec::{ build_hardware_self_check_runtime_error, run_hardware_self_check, BitratePreset, VideoEncoderSelfCheckResponse, }; - -/// Health check response -#[derive(Serialize)] -pub struct HealthResponse { - pub status: &'static str, - pub version: &'static str, -} - -pub async fn health_check() -> Json { - Json(HealthResponse { - status: "ok", - version: env!("CARGO_PKG_VERSION"), - }) -} - -/// System info response -#[derive(Serialize)] -pub struct SystemInfo { - pub version: &'static str, - pub build_date: &'static str, - pub initialized: bool, - pub capabilities: Capabilities, - #[serde(skip_serializing_if = "Option::is_none")] - pub disk_space: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub device_info: Option, -} - -/// Device information (hostname, CPU, memory, network) -#[derive(Serialize)] -pub struct DeviceInfo { - pub hostname: String, - pub cpu_model: String, - pub cpu_usage: f32, - pub memory_total: u64, - pub memory_used: u64, - pub network_addresses: Vec, -} - -/// Network interface address -#[derive(Serialize)] -pub struct NetworkAddress { - pub interface: String, - pub ip: String, -} - -/// Disk space information -#[derive(Serialize)] -pub struct DiskSpaceInfo { - pub total: u64, - pub available: u64, - pub used: u64, -} - -#[derive(Serialize)] -pub struct Capabilities { - pub video: CapabilityInfo, - pub hid: CapabilityInfo, - pub msd: CapabilityInfo, - pub atx: CapabilityInfo, - pub audio: CapabilityInfo, -} - -#[derive(Serialize)] -pub struct CapabilityInfo { - pub available: bool, - pub backend: Option, -} - -pub async fn system_info(State(state): State>) -> Json { - let config = state.config.get(); - - // Get disk space information for MSD base directory - let disk_space = { - let msd_dir = config.msd.msd_dir_path(); - if msd_dir.as_os_str().is_empty() { - None - } else { - get_disk_space(&msd_dir).ok() - } - }; - - // Get device information (hostname, CPU, memory, network) - let device_info = Some(get_device_info()); - - Json(SystemInfo { - version: env!("CARGO_PKG_VERSION"), - build_date: env!("BUILD_DATE"), - initialized: config.initialized, - capabilities: Capabilities { - video: CapabilityInfo { - available: config.video.device.is_some(), - backend: config.video.device.clone(), - }, - hid: CapabilityInfo { - available: config.hid.backend != crate::config::HidBackend::None, - backend: Some(format!("{:?}", config.hid.backend)), - }, - msd: CapabilityInfo { - available: config.msd.enabled, - backend: None, - }, - atx: CapabilityInfo { - available: config.atx.enabled, - backend: if config.atx.enabled { - Some(format!( - "power: {:?}, reset: {:?}", - config.atx.power.driver, config.atx.reset.driver - )) - } else { - None - }, - }, - audio: CapabilityInfo { - available: config.audio.enabled, - backend: Some(config.audio.device.clone()), - }, - }, - disk_space, - device_info, - }) -} - -/// Get disk space information for a given path -fn get_disk_space(path: &std::path::Path) -> Result { - let stat = nix::sys::statvfs::statvfs(path) - .map_err(|e| AppError::Internal(format!("Failed to get disk space: {}", e)))?; - - let block_size = stat.block_size() as u64; - let total = stat.blocks() as u64 * block_size; - let available = stat.blocks_available() as u64 * block_size; - let used = total - available; - - Ok(DiskSpaceInfo { - total, - available, - used, - }) -} - -/// Get device information (hostname, CPU, memory, network) -fn get_device_info() -> DeviceInfo { - // Get memory info in a single read - let mem_info = get_meminfo(); - - DeviceInfo { - hostname: hostname_uname(), - cpu_model: get_cpu_model(), - cpu_usage: get_cpu_usage(), - memory_total: mem_info.total, - memory_used: mem_info.total.saturating_sub(mem_info.available), - network_addresses: get_network_addresses(), - } -} - -/// Get CPU model name from /proc/cpuinfo, fallback to device-tree model -fn get_cpu_model() -> String { - let cpuinfo = std::fs::read_to_string("/proc/cpuinfo").ok(); - - if let Some(model) = cpuinfo - .as_deref() - .and_then(parse_cpu_model_from_cpuinfo_content) - { - return model; - } - - if let Some(model) = read_device_tree_model() { - return model; - } - - if let Some(content) = cpuinfo.as_deref() { - let cores = content - .lines() - .filter(|line| line.starts_with("processor")) - .count(); - if cores > 0 { - return format!("{} {}C", std::env::consts::ARCH, cores); - } - } - - std::env::consts::ARCH.to_string() -} - -fn parse_cpu_model_from_cpuinfo_content(content: &str) -> Option { - content - .lines() - .find(|line| line.starts_with("model name") || line.starts_with("Model")) - .and_then(|line| line.split(':').nth(1)) - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) -} - -fn read_device_tree_model() -> Option { - std::fs::read("/proc/device-tree/model") - .ok() - .and_then(|bytes| parse_device_tree_model_bytes(&bytes)) -} - -fn parse_device_tree_model_bytes(bytes: &[u8]) -> Option { - let model = String::from_utf8_lossy(bytes) - .trim_matches(|c: char| c == '\0' || c.is_whitespace()) - .to_string(); - - if model.is_empty() { - None - } else { - Some(model) - } -} - -/// CPU usage state for calculating usage between samples -static CPU_PREV_STATS: std::sync::OnceLock> = - std::sync::OnceLock::new(); - -/// Get CPU usage percentage (0.0 - 100.0) -fn get_cpu_usage() -> f32 { - // Parse /proc/stat for CPU times - let content = match std::fs::read_to_string("/proc/stat") { - Ok(c) => c, - Err(_) => return 0.0, - }; - - let cpu_line = match content.lines().next() { - Some(line) if line.starts_with("cpu ") => line, - _ => return 0.0, - }; - - // Parse CPU times: user, nice, system, idle, iowait, irq, softirq, steal - let parts: Vec = cpu_line - .split_whitespace() - .skip(1) // skip "cpu" - .take(8) - .filter_map(|s| s.parse().ok()) - .collect(); - - if parts.len() < 4 { - return 0.0; - } - - let idle = parts[3] + parts.get(4).unwrap_or(&0); // idle + iowait - let total: u64 = parts.iter().sum(); - - // Get or initialize previous stats - let prev_mutex = CPU_PREV_STATS.get_or_init(|| std::sync::Mutex::new((0, 0))); - let mut prev = prev_mutex.lock().unwrap(); - let (prev_idle, prev_total) = *prev; - - // Calculate delta - let idle_delta = idle.saturating_sub(prev_idle); - let total_delta = total.saturating_sub(prev_total); - - // Update previous stats - *prev = (idle, total); - - if total_delta == 0 { - return 0.0; - } - - let usage = 100.0 * (1.0 - (idle_delta as f64 / total_delta as f64)); - usage as f32 -} - -/// Memory info parsed from /proc/meminfo -struct MemInfo { - total: u64, - available: u64, -} - -/// Parse memory info from /proc/meminfo in a single read -fn get_meminfo() -> MemInfo { - let content = match std::fs::read_to_string("/proc/meminfo") { - Ok(c) => c, - Err(_) => { - return MemInfo { - total: 0, - available: 0, - } - } - }; - - let mut total = 0u64; - let mut available = 0u64; - - for line in content.lines() { - if line.starts_with("MemTotal:") { - if let Some(kb) = line - .split_whitespace() - .nth(1) - .and_then(|v| v.parse::().ok()) - { - total = kb * 1024; - } - } else if line.starts_with("MemAvailable:") { - if let Some(kb) = line - .split_whitespace() - .nth(1) - .and_then(|v| v.parse::().ok()) - { - available = kb * 1024; - } - } - // Early exit if both values found - if total > 0 && available > 0 { - break; - } - } - - MemInfo { total, available } -} - -/// Get network addresses for all non-loopback interfaces -fn get_network_addresses() -> Vec { - // Get all interface addresses in a single system call - let all_addrs = match nix::ifaddrs::getifaddrs() { - Ok(addrs) => addrs, - Err(_) => return Vec::new(), - }; - - // Check which interfaces are up - let mut up_ifaces = std::collections::HashSet::new(); - let net_dir = match std::fs::read_dir("/sys/class/net") { - Ok(dir) => dir, - Err(_) => return Vec::new(), - }; - - for entry in net_dir.flatten() { - let iface_name = match entry.file_name().into_string() { - Ok(name) => name, - Err(_) => continue, - }; - - // Skip loopback - if iface_name == "lo" { - continue; - } - - // Check if interface is up by reading operstate - let operstate_path = entry.path().join("operstate"); - let is_up = std::fs::read_to_string(&operstate_path) - .map(|s| s.trim() == "up") - .unwrap_or(false); - - if !is_up { - continue; - } - - up_ifaces.insert(iface_name); - } - - let mut addresses = Vec::new(); - let mut seen = std::collections::HashSet::new(); - for ifaddr in all_addrs { - let iface_name = &ifaddr.interface_name; - if iface_name == "lo" || !up_ifaces.contains(iface_name) { - continue; - } - - if let Some(addr) = ifaddr.address { - if let Some(sockaddr_in) = addr.as_sockaddr_in() { - let ip = sockaddr_in.ip(); - if ip.is_loopback() { - continue; - } - let ip_str = ip.to_string(); - if seen.insert((iface_name.clone(), ip_str.clone())) { - addresses.push(NetworkAddress { - interface: iface_name.clone(), - ip: ip_str, - }); - } - } else if let Some(sockaddr_in6) = addr.as_sockaddr_in6() { - let ip = sockaddr_in6.ip(); - if ip.is_loopback() || ip.is_unspecified() || ip.is_unicast_link_local() { - continue; - } - let ip_str = ip.to_string(); - if seen.insert((iface_name.clone(), ip_str.clone())) { - addresses.push(NetworkAddress { - interface: iface_name.clone(), - ip: ip_str, - }); - } - } - } - } - - addresses -} - -#[cfg(test)] -mod tests { - use super::{parse_cpu_model_from_cpuinfo_content, parse_device_tree_model_bytes}; - - #[test] - fn parse_cpu_model_from_model_name_field() { - let input = "processor\t: 0\nmodel name\t: Intel(R) Xeon(R)\n"; - assert_eq!( - parse_cpu_model_from_cpuinfo_content(input), - Some("Intel(R) Xeon(R)".to_string()) - ); - } - - #[test] - fn parse_cpu_model_from_model_field() { - let input = "processor\t: 0\nModel\t\t: Raspberry Pi 4 Model B Rev 1.4\n"; - assert_eq!( - parse_cpu_model_from_cpuinfo_content(input), - Some("Raspberry Pi 4 Model B Rev 1.4".to_string()) - ); - } - - #[test] - fn parse_device_tree_model_trimmed() { - let input = b"Onething OEC Box\0\n"; - assert_eq!( - parse_device_tree_model_bytes(input), - Some("Onething OEC Box".to_string()) - ); - } -} - -#[derive(Deserialize)] -pub struct LoginRequest { - pub username: String, - pub password: String, -} - -#[derive(Serialize)] -pub struct LoginResponse { - pub success: bool, - pub message: Option, -} - -pub async fn login( - State(state): State>, - cookies: CookieJar, - Json(req): Json, -) -> Result<(CookieJar, Json)> { - let config = state.config.get(); - - // Check if system is initialized - if !config.initialized { - return Err(AppError::BadRequest("System not initialized".to_string())); - } - - // Verify user credentials - let user = state - .users - .verify(&req.username, &req.password) - .await? - .ok_or_else(|| AppError::AuthError("Invalid username or password".to_string()))?; - - if !config.auth.single_user_allow_multiple_sessions { - // Kick existing sessions before creating a new one. - let revoked_ids = state.sessions.list_ids().await?; - state.sessions.delete_all().await?; - state.remember_revoked_sessions(revoked_ids).await; - } - - // Create session - let session = state.sessions.create(&user.id).await?; - - // Set session cookie - let cookie = Cookie::build((SESSION_COOKIE, session.id)) - .path("/") - .http_only(true) - .same_site(SameSite::Lax) - .max_age(time::Duration::seconds( - config.auth.session_timeout_secs as i64, - )) - .build(); - - Ok(( - cookies.add(cookie), - Json(LoginResponse { - success: true, - message: None, - }), - )) -} - -pub async fn logout( - State(state): State>, - cookies: CookieJar, -) -> Result<(CookieJar, Json)> { - // Get session ID from cookie - if let Some(cookie) = cookies.get(SESSION_COOKIE) { - state.sessions.delete(cookie.value()).await?; - } - - // Remove cookie - let cookie = Cookie::build((SESSION_COOKIE, "")) - .path("/") - .max_age(time::Duration::ZERO) - .build(); - - Ok(( - cookies.remove(cookie), - Json(LoginResponse { - success: true, - message: Some("Logged out".to_string()), - }), - )) -} - -#[derive(Serialize)] -pub struct AuthCheckResponse { - pub authenticated: bool, - pub user: Option, -} - -pub async fn auth_check( - State(state): State>, - axum::Extension(session): axum::Extension, -) -> Json { - // Get user info from user_id - let username = match state.users.single_user().await { - Ok(Some(user)) if user.id == session.user_id => Some(user.username), - _ => None, - }; - - Json(AuthCheckResponse { - authenticated: true, - user: username, - }) -} - -#[derive(Serialize)] -pub struct SetupStatus { - pub initialized: bool, - pub needs_setup: bool, -} - -pub async fn setup_status(State(state): State>) -> Json { - let initialized = state.config.is_initialized(); - Json(SetupStatus { - initialized, - needs_setup: !initialized, - }) -} - -#[derive(Deserialize)] -pub struct SetupRequest { - // Account settings - pub username: String, - pub password: String, - // Video settings - pub video_device: Option, - pub video_format: Option, - pub video_width: Option, - pub video_height: Option, - pub video_fps: Option, - // Audio settings - pub audio_device: Option, - // HID settings - pub hid_backend: Option, - pub hid_ch9329_port: Option, - pub hid_ch9329_baudrate: Option, - pub hid_otg_udc: Option, - pub hid_otg_profile: Option, - pub hid_otg_endpoint_budget: Option, - pub hid_otg_keyboard_leds: Option, - pub msd_enabled: Option, - // Extension settings - pub ttyd_enabled: Option, - pub rustdesk_enabled: Option, -} - -pub async fn setup_init( - State(state): State>, - Json(req): Json, -) -> Result> { - // Check if already initialized - if state.config.is_initialized() { - return Err(AppError::BadRequest("Already initialized".to_string())); - } - - // Validate username - if req.username.len() < 2 { - return Err(AppError::BadRequest( - "Username must be at least 2 characters".to_string(), - )); - } - - // Validate password - if req.password.len() < 4 { - return Err(AppError::BadRequest( - "Password must be at least 4 characters".to_string(), - )); - } - - // Create single system user - state - .users - .create_first_user(&req.username, &req.password) - .await?; - - // Update config - state - .config - .update(|config| { - config.initialized = true; - - // Video settings - if let Some(device) = req.video_device.clone() { - config.video.device = Some(device); - } - if let Some(format) = req.video_format.clone() { - config.video.format = Some(format); - } - if let Some(width) = req.video_width { - config.video.width = width; - } - if let Some(height) = req.video_height { - config.video.height = height; - } - if let Some(fps) = req.video_fps { - config.video.fps = fps; - } - - // Audio settings - if let Some(device) = req.audio_device.clone() { - config.audio.device = device; - config.audio.enabled = true; - } - - // HID settings - if let Some(backend) = req.hid_backend.clone() { - config.hid.backend = match backend.as_str() { - "otg" => crate::config::HidBackend::Otg, - "ch9329" => crate::config::HidBackend::Ch9329, - _ => crate::config::HidBackend::None, - }; - } - if let Some(port) = req.hid_ch9329_port.clone() { - config.hid.ch9329_port = port; - } - if let Some(baudrate) = req.hid_ch9329_baudrate { - config.hid.ch9329_baudrate = baudrate; - } - if let Some(udc) = req.hid_otg_udc.clone() { - config.hid.otg_udc = Some(udc); - } - if let Some(profile) = req.hid_otg_profile.clone() { - if let Some(parsed) = crate::config::OtgHidProfile::from_legacy_str(&profile) { - config.hid.otg_profile = parsed; - } - } - if let Some(budget) = req.hid_otg_endpoint_budget { - config.hid.otg_endpoint_budget = budget; - } - if let Some(enabled) = req.hid_otg_keyboard_leds { - config.hid.otg_keyboard_leds = enabled; - } - if let Some(enabled) = req.msd_enabled { - config.msd.enabled = enabled; - } - - // Extension settings - if let Some(enabled) = req.ttyd_enabled { - config.extensions.ttyd.enabled = enabled; - } - if let Some(enabled) = req.rustdesk_enabled { - config.rustdesk.enabled = enabled; - } - }) - .await?; - - // Get updated config for HID reload - let new_config = state.config.get(); - - if let Err(e) = state - .otg_service - .apply_config(&new_config.hid, &new_config.msd) - .await - { - tracing::warn!("Failed to apply OTG config during setup: {}", e); - } - - tracing::info!( - "Extension config after save: ttyd.enabled={}, rustdesk.enabled={}", - new_config.extensions.ttyd.enabled, - new_config.rustdesk.enabled - ); - - // Initialize HID backend with new config - let new_hid_backend = match new_config.hid.backend { - crate::config::HidBackend::Otg => crate::hid::HidBackendType::Otg, - crate::config::HidBackend::Ch9329 => crate::hid::HidBackendType::Ch9329 { - port: new_config.hid.ch9329_port.clone(), - baud_rate: new_config.hid.ch9329_baudrate, - }, - crate::config::HidBackend::None => crate::hid::HidBackendType::None, - }; - - // Reload HID backend - if let Err(e) = state.hid.reload(new_hid_backend).await { - tracing::warn!("Failed to initialize HID backend during setup: {}", e); - // Don't fail setup, just warn - } else { - tracing::info!("HID backend initialized: {:?}", new_config.hid.backend); - } - - // Start extensions if enabled - if new_config.extensions.ttyd.enabled { - if let Err(e) = state - .extensions - .start(crate::extensions::ExtensionId::Ttyd, &new_config.extensions) - .await - { - tracing::warn!("Failed to start ttyd during setup: {}", e); - } else { - tracing::info!("ttyd started during setup"); - } - } - - // Start RustDesk if enabled - if new_config.rustdesk.enabled { - let empty_config = crate::rustdesk::config::RustDeskConfig::default(); - if let Err(e) = config::apply::apply_rustdesk_config( - &state, - &empty_config, - &new_config.rustdesk, - ConfigApplyOptions::default(), - ) - .await - { - tracing::warn!("Failed to start RustDesk during setup: {}", e); - } else { - tracing::info!("RustDesk started during setup"); - } - } - - // Start RTSP if enabled - if new_config.rtsp.enabled { - let empty_config = crate::config::RtspConfig::default(); - if let Err(e) = config::apply::apply_rtsp_config( - &state, - &empty_config, - &new_config.rtsp, - ConfigApplyOptions::default(), - ) - .await - { - tracing::warn!("Failed to start RTSP during setup: {}", e); - } else { - tracing::info!("RTSP started during setup"); - } - } - - // Start audio streaming if audio device was selected during setup - if new_config.audio.enabled { - let audio_config = crate::audio::AudioControllerConfig { - enabled: true, - device: new_config.audio.device.clone(), - quality: new_config - .audio - .quality - .parse::()?, - }; - if let Err(e) = state.audio.update_config(audio_config).await { - tracing::warn!("Failed to start audio during setup: {}", e); - } else { - tracing::info!( - "Audio started during setup: device={}", - new_config.audio.device - ); - } - // Also enable WebRTC audio - if let Err(e) = state.stream_manager.set_webrtc_audio_enabled(true).await { - tracing::warn!("Failed to enable WebRTC audio during setup: {}", e); - } - } - - tracing::info!("System initialized successfully"); - - Ok(Json(LoginResponse { - success: true, - message: Some("Setup completed".to_string()), - })) -} - -#[derive(Serialize)] -pub struct DeviceList { - pub video: Vec, - pub serial: Vec, - pub audio: Vec, - pub udc: Vec, - pub extensions: ExtensionsAvailability, -} - -#[derive(Serialize)] -pub struct ExtensionsAvailability { - pub ttyd_available: bool, - pub rustdesk_available: bool, -} - -#[derive(Serialize)] -pub struct VideoDevice { - pub path: String, - pub name: String, - pub driver: String, - pub formats: Vec, - pub usb_bus: Option, - pub has_signal: bool, -} - -#[derive(Serialize)] -pub struct VideoFormat { - pub format: String, - pub description: String, - pub resolutions: Vec, -} - -#[derive(Serialize)] -pub struct VideoResolution { - pub width: u32, - pub height: u32, - pub fps: Vec, -} - -#[derive(Serialize)] -pub struct SerialDevice { - pub path: String, - pub name: String, -} - -#[derive(Serialize)] -pub struct AudioDevice { - pub name: String, - pub description: String, - pub is_hdmi: bool, - pub usb_bus: Option, -} - -#[derive(Serialize)] -pub struct UdcDevice { - pub name: String, -} - -/// Extract USB bus port from V4L2 bus_info string -/// Examples: -/// - "usb-0000:00:14.0-1" -> Some("1") -/// - "usb-xhci-hcd.0-1.2" -> Some("1.2") -/// - "usb-0000:00:14.0-1.3.2" -> Some("1.3.2") -/// - "platform:..." -> None -fn extract_usb_bus_from_bus_info(bus_info: &str) -> Option { - if !bus_info.starts_with("usb-") { - return None; - } - // Find the last '-' which separates the USB port - // e.g., "usb-0000:00:14.0-1" -> "1" - // e.g., "usb-xhci-hcd.0-1.2" -> "1.2" - let parts: Vec<&str> = bus_info.rsplitn(2, '-').collect(); - if parts.len() == 2 { - let port = parts[0]; - // Verify it looks like a USB port (starts with digit) - if port - .chars() - .next() - .map(|c| c.is_ascii_digit()) - .unwrap_or(false) - { - return Some(port.to_string()); - } - } - None -} - -pub async fn list_devices(State(state): State>) -> Json { - // Detect video devices - let video_devices = match state.stream_manager.list_devices().await { - Ok(devices) => devices - .into_iter() - .map(|d| { - // Extract USB bus from bus_info (e.g., "usb-0000:00:14.0-1" -> "1") - // or "usb-xhci-hcd.0-1.2" -> "1.2" - let usb_bus = extract_usb_bus_from_bus_info(&d.bus_info); - VideoDevice { - path: d.path.to_string_lossy().to_string(), - name: d.name, - driver: d.driver, - formats: d - .formats - .iter() - .map(|f| VideoFormat { - format: format!("{}", f.format), - description: f.description.clone(), - resolutions: f - .resolutions - .iter() - .map(|r| VideoResolution { - width: r.width, - height: r.height, - fps: r.fps.clone(), - }) - .collect(), - }) - .collect(), - usb_bus, - has_signal: d.has_signal, - } - }) - .collect(), - Err(e) => { - warn!(error = %e, "Video device enumeration failed; returning empty video list for /api/devices"); - vec![] - } - }; - - // Detect serial devices (common USB/ACM ports) - single directory read - let serial_prefixes = ["ttyUSB", "ttyACM", "ttyS"]; - let mut serial_devices = Vec::new(); - if let Ok(entries) = std::fs::read_dir("/dev") { - for entry in entries.flatten() { - let file_name = entry.file_name(); - let name = match file_name.to_str() { - Some(n) => n, - None => continue, - }; - // Check if matches any prefix - if serial_prefixes - .iter() - .any(|prefix| name.starts_with(prefix)) - { - let path = entry.path(); - if let Some(p) = path.to_str() { - serial_devices.push(SerialDevice { - path: p.to_string(), - name: name.to_string(), - }); - } - } - } - } - serial_devices.sort_by(|a, b| a.path.cmp(&b.path)); - - // Detect UDC (USB Device Controller) devices - let mut udc_devices = Vec::new(); - if let Ok(entries) = std::fs::read_dir("/sys/class/udc") { - for entry in entries.flatten() { - if let Some(name) = entry.file_name().to_str() { - udc_devices.push(UdcDevice { - name: name.to_string(), - }); - } - } - } - udc_devices.sort_by(|a, b| a.name.cmp(&b.name)); - - // Detect audio devices - let audio_devices = match state.audio.list_devices().await { - Ok(devices) => devices - .into_iter() - .map(|d| AudioDevice { - name: d.name, - description: d.description, - is_hdmi: d.is_hdmi, - usb_bus: d.usb_bus, - }) - .collect(), - Err(_) => vec![], - }; - - // Check extension availability - let ttyd_available = state - .extensions - .check_available(crate::extensions::ExtensionId::Ttyd); - - Json(DeviceList { - video: video_devices, - serial: serial_devices, - audio: audio_devices, - udc: udc_devices, - extensions: ExtensionsAvailability { - ttyd_available, - rustdesk_available: true, // RustDesk is built-in - }, - }) -} - -use crate::video::streamer::StreamerStats; -use axum::{ - body::Body, - http::{header, StatusCode}, - response::{IntoResponse, Response}, -}; - -/// Get stream state -pub async fn stream_state(State(state): State>) -> Json { - Json(state.stream_manager.stats().await) -} - -/// Start streaming -pub async fn stream_start(State(state): State>) -> Result> { - state.stream_manager.start().await?; - Ok(Json(LoginResponse { - success: true, - message: Some("Streaming started".to_string()), - })) -} - -/// Stop streaming -pub async fn stream_stop(State(state): State>) -> Result> { - state.stream_manager.stop().await?; - Ok(Json(LoginResponse { - success: true, - message: Some("Streaming stopped".to_string()), - })) -} - -/// Stream mode request -#[derive(Deserialize)] -pub struct SetStreamModeRequest { - /// Target mode: "mjpeg" or "webrtc" - pub mode: String, -} - -/// Stream mode response -#[derive(Serialize)] -pub struct StreamModeResponse { - pub success: bool, - pub mode: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub transition_id: Option, - pub switching: bool, - pub message: Option, -} - -/// Get current stream mode -pub async fn stream_mode_get(State(state): State>) -> Json { - let mode = state.stream_manager.current_mode().await; - let mode_str = match mode { - StreamMode::Mjpeg => "mjpeg".to_string(), - StreamMode::WebRTC => { - use crate::video::encoder::VideoCodecType; - let codec = state.stream_manager.current_video_codec().await; - match codec { - VideoCodecType::H264 => "h264".to_string(), - VideoCodecType::H265 => "h265".to_string(), - VideoCodecType::VP8 => "vp8".to_string(), - VideoCodecType::VP9 => "vp9".to_string(), - } - } - }; - Json(StreamModeResponse { - success: true, - mode: mode_str, - transition_id: state.stream_manager.current_transition_id().await, - switching: state.stream_manager.is_switching(), - message: None, - }) -} - -/// Set stream mode (switch between MJPEG and WebRTC) -pub async fn stream_mode_set( - State(state): State>, - Json(req): Json, -) -> Result> { - use crate::video::encoder::VideoCodecType; - - let constraints = state.stream_manager.codec_constraints().await; - - let mode_lower = req.mode.to_lowercase(); - let (new_mode, video_codec) = match mode_lower.as_str() { - "mjpeg" => (StreamMode::Mjpeg, None), - "webrtc" | "h264" => (StreamMode::WebRTC, Some(VideoCodecType::H264)), - "h265" => (StreamMode::WebRTC, Some(VideoCodecType::H265)), - "vp8" => (StreamMode::WebRTC, Some(VideoCodecType::VP8)), - "vp9" => (StreamMode::WebRTC, Some(VideoCodecType::VP9)), - _ => { - return Err(AppError::BadRequest(format!( - "Invalid mode '{}'. Valid modes: mjpeg, h264, h265, vp8, vp9", - req.mode - ))); - } - }; - - if new_mode == StreamMode::Mjpeg && !constraints.is_mjpeg_allowed() { - return Err(AppError::BadRequest(format!( - "Codec 'mjpeg' is not allowed: {}", - constraints.reason - ))); - } - - if let Some(codec) = video_codec { - if !constraints.is_webrtc_codec_allowed(codec) { - return Err(AppError::BadRequest(format!( - "Codec '{}' is not allowed: {}", - codec_to_id(codec), - constraints.reason - ))); - } - } - - let requested_mode_str = match (&new_mode, &video_codec) { - (StreamMode::Mjpeg, _) => "mjpeg", - (StreamMode::WebRTC, Some(VideoCodecType::H264)) => "h264", - (StreamMode::WebRTC, Some(VideoCodecType::H265)) => "h265", - (StreamMode::WebRTC, Some(VideoCodecType::VP8)) => "vp8", - (StreamMode::WebRTC, Some(VideoCodecType::VP9)) => "vp9", - (StreamMode::WebRTC, None) => "webrtc", - }; - - // Detect codec-only switch: already in WebRTC mode, just changing codec. - // switch_mode_transaction treats this as "no switch needed" since StreamMode - // is still WebRTC, so we handle codec change + event emission here. - let current_mode = state.stream_manager.current_mode().await; - let prev_codec = state.stream_manager.current_video_codec().await; - - let codec_changed = video_codec.is_some_and(|c| c != prev_codec); - let is_codec_only_switch = - current_mode == StreamMode::WebRTC && new_mode == StreamMode::WebRTC && codec_changed; - - if let Some(codec) = video_codec { - info!("Setting WebRTC video codec to {:?}", codec); - if let Err(e) = state.stream_manager.set_video_codec(codec).await { - warn!("Failed to set video codec: {}", e); - } - } - - // For codec-only switch, emit events directly instead of going through - // switch_mode_transaction (which short-circuits when mode is unchanged). - if is_codec_only_switch { - let transition_id = uuid::Uuid::new_v4().to_string(); - - state - .stream_manager - .notify_codec_switch(&transition_id, requested_mode_str, &codec_to_id(prev_codec)) - .await; - - return Ok(Json(StreamModeResponse { - success: true, - mode: requested_mode_str.to_string(), - transition_id: Some(transition_id), - switching: false, - message: Some(format!("Codec switched to {}", requested_mode_str)), - })); - } - - let tx = state - .stream_manager - .switch_mode_transaction(new_mode.clone()) - .await?; - - let active_mode_str = match state.stream_manager.current_mode().await { - StreamMode::Mjpeg => "mjpeg".to_string(), - StreamMode::WebRTC => { - let codec = state.stream_manager.current_video_codec().await; - match codec { - VideoCodecType::H264 => "h264".to_string(), - VideoCodecType::H265 => "h265".to_string(), - VideoCodecType::VP8 => "vp8".to_string(), - VideoCodecType::VP9 => "vp9".to_string(), - } - } - }; - - let no_switch_needed = !tx.accepted && !tx.switching && tx.transition_id.is_none(); - Ok(Json(StreamModeResponse { - success: tx.accepted || no_switch_needed, - mode: if tx.accepted { - requested_mode_str.to_string() - } else { - active_mode_str - }, - transition_id: tx.transition_id, - switching: tx.switching, - message: Some(if tx.accepted { - format!("Switching to {} mode", requested_mode_str) - } else if tx.switching { - "Mode switch already in progress".to_string() - } else { - "No switch needed".to_string() - }), - })) -} - -/// Available video codec info -#[derive(Serialize)] -pub struct VideoCodecInfo { - /// Codec identifier (mjpeg, h264, h265, vp8, vp9) - pub id: String, - /// Display name - pub name: String, - /// Protocol (http or webrtc) - pub protocol: String, - /// Whether hardware accelerated - pub hardware: bool, - /// Encoder backend name (e.g., "vaapi", "nvenc", "software") - pub backend: Option, - /// Whether this codec is available - pub available: bool, -} - -/// Encoder backend info -#[derive(Serialize)] -pub struct EncoderBackendInfo { - /// Backend identifier (vaapi, nvenc, qsv, amf, rkmpp, v4l2m2m, software) - pub id: String, - /// Display name - pub name: String, - /// Whether this is a hardware backend - pub is_hardware: bool, - /// Supported video formats (h264, h265, vp8, vp9) - pub supported_formats: Vec, -} - -/// Available codecs response -#[derive(Serialize)] -pub struct AvailableCodecsResponse { - pub success: bool, - /// Available encoder backends - pub backends: Vec, - /// Available codecs (for backward compatibility) - pub codecs: Vec, -} - -/// Stream constraints response -#[derive(Serialize)] -pub struct StreamConstraintsResponse { - pub success: bool, - pub allowed_codecs: Vec, - pub locked_codec: Option, - pub disallow_mjpeg: bool, - pub sources: ConstraintSources, - pub reason: String, - pub current_mode: String, -} - -#[derive(Serialize)] -pub struct ConstraintSources { - pub rustdesk: bool, - pub rtsp: bool, -} - -/// Get stream codec constraints derived from enabled services. -pub async fn stream_constraints_get( - State(state): State>, -) -> Json { - use crate::video::encoder::VideoCodecType; - - let constraints = state.stream_manager.codec_constraints().await; - let current_mode = state.stream_manager.current_mode().await; - let current_mode = match current_mode { - StreamMode::Mjpeg => "mjpeg".to_string(), - StreamMode::WebRTC => { - let codec = state.stream_manager.current_video_codec().await; - match codec { - VideoCodecType::H264 => "h264".to_string(), - VideoCodecType::H265 => "h265".to_string(), - VideoCodecType::VP8 => "vp8".to_string(), - VideoCodecType::VP9 => "vp9".to_string(), - } - } - }; - - Json(StreamConstraintsResponse { - success: true, - allowed_codecs: constraints - .allowed_codecs_for_api() - .into_iter() - .map(str::to_string) - .collect(), - locked_codec: constraints - .locked_codec - .map(codec_to_id) - .map(str::to_string), - disallow_mjpeg: !constraints.allow_mjpeg, - sources: ConstraintSources { - rustdesk: constraints.rustdesk_enabled, - rtsp: constraints.rtsp_enabled, - }, - reason: constraints.reason, - current_mode, - }) -} - -/// Set bitrate request -#[derive(Deserialize)] -pub struct SetBitrateRequest { - pub bitrate_preset: BitratePreset, -} - -/// Set stream bitrate (real-time adjustment) -pub async fn stream_set_bitrate( - State(state): State>, - Json(req): Json, -) -> Result> { - // Update config - state - .config - .update(|config| { - config.stream.bitrate_preset = req.bitrate_preset; - }) - .await?; - - // Apply to WebRTC streamer (real-time adjustment) - if let Err(e) = state - .stream_manager - .set_bitrate_preset(req.bitrate_preset) - .await - { - warn!("Failed to set bitrate dynamically: {}", e); - // Don't fail the request - config is saved, will apply on next connection - } else { - info!("Bitrate updated to {}", req.bitrate_preset); - } - - Ok(Json(LoginResponse { - success: true, - message: Some(format!("Bitrate set to {}", req.bitrate_preset)), - })) -} - -/// Get available video codecs -pub async fn stream_codecs_list() -> Json { - use crate::video::encoder::registry::{EncoderRegistry, VideoEncoderType}; - - let registry = EncoderRegistry::global(); - - // Build backends list - let mut backends = Vec::new(); - for backend in registry.available_backends() { - let formats = registry.formats_for_backend(backend); - let format_ids: Vec = formats - .iter() - .map(|f| match f { - VideoEncoderType::H264 => "h264", - VideoEncoderType::H265 => "h265", - VideoEncoderType::VP8 => "vp8", - VideoEncoderType::VP9 => "vp9", - }) - .map(String::from) - .collect(); - - backends.push(EncoderBackendInfo { - id: format!("{:?}", backend).to_lowercase(), - name: backend.display_name().to_string(), - is_hardware: backend.is_hardware(), - supported_formats: format_ids, - }); - } - - // Build codecs list (for backward compatibility) - let mut codecs = Vec::new(); - - // MJPEG is always available (HTTP streaming) - codecs.push(VideoCodecInfo { - id: "mjpeg".to_string(), - name: "MJPEG / HTTP".to_string(), - protocol: "http".to_string(), - hardware: false, - backend: Some("software".to_string()), - available: true, - }); - - // Check H264 availability (supports software fallback) - let h264_encoder = registry.best_available_encoder(VideoEncoderType::H264); - codecs.push(VideoCodecInfo { - id: "h264".to_string(), - name: "H.264 / WebRTC".to_string(), - protocol: "webrtc".to_string(), - hardware: h264_encoder.map(|e| e.is_hardware).unwrap_or(false), - backend: h264_encoder.map(|e| e.backend.to_string()), - available: h264_encoder.is_some(), - }); - - // Check H265 availability (now supports software too) - let h265_encoder = registry.best_available_encoder(VideoEncoderType::H265); - codecs.push(VideoCodecInfo { - id: "h265".to_string(), - name: "H.265 / WebRTC".to_string(), - protocol: "webrtc".to_string(), - hardware: h265_encoder.map(|e| e.is_hardware).unwrap_or(false), - backend: h265_encoder.map(|e| e.backend.to_string()), - available: h265_encoder.is_some(), - }); - - // Check VP8 availability (now supports software too) - let vp8_encoder = registry.best_available_encoder(VideoEncoderType::VP8); - codecs.push(VideoCodecInfo { - id: "vp8".to_string(), - name: "VP8 / WebRTC".to_string(), - protocol: "webrtc".to_string(), - hardware: vp8_encoder.map(|e| e.is_hardware).unwrap_or(false), - backend: vp8_encoder.map(|e| e.backend.to_string()), - available: vp8_encoder.is_some(), - }); - - // Check VP9 availability (now supports software too) - let vp9_encoder = registry.best_available_encoder(VideoEncoderType::VP9); - codecs.push(VideoCodecInfo { - id: "vp9".to_string(), - name: "VP9 / WebRTC".to_string(), - protocol: "webrtc".to_string(), - hardware: vp9_encoder.map(|e| e.is_hardware).unwrap_or(false), - backend: vp9_encoder.map(|e| e.backend.to_string()), - available: vp9_encoder.is_some(), - }); - - Json(AvailableCodecsResponse { - success: true, - backends, - codecs, - }) -} - -/// Run hardware encoder smoke tests across common resolutions/codecs. -pub async fn video_encoder_self_check() -> Json { - let response = tokio::task::spawn_blocking(run_hardware_self_check) - .await - .unwrap_or_else(|_| build_hardware_self_check_runtime_error()); - - Json(response) -} - -/// Query parameters for MJPEG stream -#[derive(Deserialize, Default)] -pub struct MjpegStreamQuery { - /// Optional client ID (if not provided, a random UUID will be generated) - pub client_id: Option, -} - -/// MJPEG stream endpoint -pub async fn mjpeg_stream( - State(state): State>, - Query(query): Query, -) -> impl IntoResponse { - // Check if MJPEG mode is active - if !state.stream_manager.is_mjpeg_enabled().await { - return axum::response::Response::builder() - .status(axum::http::StatusCode::SERVICE_UNAVAILABLE) - .header("Content-Type", "application/json") - .body(axum::body::Body::from( - r#"{"error":"MJPEG mode not active. Current mode is WebRTC."}"#, - )) - .unwrap(); - } - - // Check if config is being changed - reject new connections during config change - if state.stream_manager.is_config_changing() { - return axum::response::Response::builder() - .status(axum::http::StatusCode::SERVICE_UNAVAILABLE) - .header("Content-Type", "application/json") - .body(axum::body::Body::from( - r#"{"error":"Video configuration is being changed. Please retry shortly."}"#, - )) - .unwrap(); - } - - // Ensure stream is started (but not during config change) - if !state.stream_manager.is_streaming().await && !state.stream_manager.is_config_changing() { - if let Err(e) = state.stream_manager.start().await { - tracing::error!("Failed to auto-start stream: {}", e); - } - } - - let handler = state.stream_manager.mjpeg_handler(); - - // Use provided client ID or generate a new one - let client_id = query - .client_id - .filter(|id| !id.is_empty() && id.len() <= 64) // Validate: non-empty, max 64 chars - .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); - - // Create RAII guard - this will automatically register and unregister the client - let guard = Arc::new(crate::stream::mjpeg::ClientGuard::new( - client_id.clone(), - handler.clone(), - )); - - let (tx, mut rx) = tokio::sync::mpsc::channel::(1); - - let guard_clone = guard.clone(); - let handler_clone = handler.clone(); - tokio::spawn(async move { - let _guard = guard_clone; // Keep guard alive - let mut notify_rx = handler_clone.subscribe(); - let mut last_seq = 0u64; - let mut timeout_count = 0; - - // Send initial frame if available - if let Some(frame) = handler_clone.current_frame() { - if frame.is_valid_jpeg() { - let data = create_mjpeg_part(frame.data()); - // send() blocks until receiver is ready (backpressure) - if tx.send(data).await.is_ok() { - // FPS recording moved to async_stream after yield - last_seq = frame.sequence; - } else { - return; // Receiver closed - } - } - } - - loop { - // Check if stream went offline (e.g., during config change) - if !handler_clone.is_online() { - break; - } - - // Wait for new frame notification with timeout - let result = - tokio::time::timeout(std::time::Duration::from_secs(5), notify_rx.recv()).await; - - match result { - Ok(Ok(())) => { - // Check online status after receiving notification - // set_offline() sends a notification, so we need to check here - if !handler_clone.is_online() { - break; - } - timeout_count = 0; - if let Some(frame) = handler_clone.current_frame() { - // Use != instead of > to handle sequence reset when capturer restarts - // (e.g., after video config change, new capturer starts from seq=0) - if frame.sequence != last_seq && frame.is_valid_jpeg() { - let data = create_mjpeg_part(frame.data()); - if tx.send(data).await.is_ok() { - last_seq = frame.sequence; - } else { - break; - } - } - } - } - Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => { - break; - } - Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => { - // Receiver was too slow - skip missed frames and jump to latest - if !handler_clone.is_online() { - break; - } - timeout_count = 0; - - if let Some(frame) = handler_clone.current_frame() { - if frame.is_valid_jpeg() { - // Send current frame immediately and reset sequence tracking - if tx.send(create_mjpeg_part(frame.data())).await.is_ok() { - last_seq = frame.sequence; - } else { - break; - } - } - } - } - Err(_) => { - // Timeout - check if still online - timeout_count += 1; - if timeout_count > 6 || !handler_clone.is_online() { - break; - } - // Send last frame again to keep connection alive - let Some(frame) = handler_clone.current_frame() else { - continue; - }; - - if frame.is_valid_jpeg() - && tx.send(create_mjpeg_part(frame.data())).await.is_err() - { - break; - } - } - } - } - - }); - - // Create stream that receives from channel and forwards to the HTTP - // body. Record FPS *before* yield so the final frame of a session - // still gets counted (after-yield code in async_stream! only runs - // when the consumer polls again, which never happens for the last - // frame of a closing connection). - let handler_for_stream = handler.clone(); - let guard_for_stream = guard.clone(); - let body_stream = async_stream::stream! { - while let Some(data) = rx.recv().await { - handler_for_stream.record_frame_sent(guard_for_stream.id()); - yield Ok::(data); - } - }; - - Response::builder() - .status(StatusCode::OK) - .header( - header::CONTENT_TYPE, - "multipart/x-mixed-replace; boundary=frame", - ) - .header(header::CACHE_CONTROL, "no-cache, no-store, must-revalidate") - .header(header::PRAGMA, "no-cache") - .header(header::EXPIRES, "0") - .header(header::CONNECTION, "keep-alive") - .body(Body::from_stream(body_stream)) - .unwrap() -} - -/// Single JPEG snapshot -pub async fn snapshot(State(state): State>) -> impl IntoResponse { - let handler = state.stream_manager.mjpeg_handler(); - - match handler.current_frame() { - Some(frame) if frame.is_valid_jpeg() => Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "image/jpeg") - .header(header::CACHE_CONTROL, "no-cache") - .body(Body::from(frame.data_bytes())) - .unwrap(), - _ => Response::builder() - .status(StatusCode::SERVICE_UNAVAILABLE) - .body(Body::from("No frame available")) - .unwrap(), - } -} - -/// Create MJPEG multipart frame bytes -fn create_mjpeg_part(jpeg_data: &[u8]) -> bytes::Bytes { - use bytes::{BufMut, BytesMut}; - - let mut buf = BytesMut::with_capacity(128 + jpeg_data.len()); - - // Write boundary and headers - buf.put_slice(b"--frame\r\n"); - buf.put_slice(b"Content-Type: image/jpeg\r\n"); - buf.put_slice(format!("Content-Length: {}\r\n", jpeg_data.len()).as_bytes()); - buf.put_slice(b"\r\n"); - - // Write JPEG data - buf.put_slice(jpeg_data); - buf.put_slice(b"\r\n"); - - buf.freeze() -} - -use crate::webrtc::signaling::{AnswerResponse, IceCandidateRequest, OfferRequest}; - -/// Create WebRTC session -#[derive(Serialize)] -pub struct CreateSessionResponse { - pub session_id: String, -} - -pub async fn webrtc_create_session( - State(state): State>, -) -> Result> { - // Check if WebRTC mode is active - if !state.stream_manager.is_webrtc_enabled().await { - return Err(AppError::ServiceUnavailable( - "WebRTC mode not active. Current mode is MJPEG.".to_string(), - )); - } - - let session_id = state.webrtc.create_session().await?; - Ok(Json(CreateSessionResponse { session_id })) -} - -/// Handle WebRTC offer -pub async fn webrtc_offer( - State(state): State>, - Json(req): Json, -) -> Result> { - // Check if WebRTC mode is active - if !state.stream_manager.is_webrtc_enabled().await { - return Err(AppError::ServiceUnavailable( - "WebRTC mode not active. Current mode is MJPEG.".to_string(), - )); - } - - // Backward compatibility: `client_id` is treated as an existing session_id hint. - // New clients should not pass it; each offer creates a fresh session. - let webrtc = &state.webrtc; - let session_id = if let Some(client_id) = &req.client_id { - // Reuse only when it matches an active session ID. - if webrtc.get_session(client_id).await.is_some() { - client_id.clone() - } else { - webrtc.create_session().await? - } - } else { - webrtc.create_session().await? - }; - - // Handle offer - let offer = crate::webrtc::SdpOffer::new(req.sdp); - let answer = webrtc.handle_offer(&session_id, offer).await?; - - Ok(Json(AnswerResponse::new( - answer.sdp, - session_id, - answer.ice_candidates.unwrap_or_default(), - ))) -} - -/// Add ICE candidate -pub async fn webrtc_ice_candidate( - State(state): State>, - Json(req): Json, -) -> Result> { - state - .webrtc - .add_ice_candidate(&req.session_id, req.candidate) - .await?; - - Ok(Json(LoginResponse { - success: true, - message: None, - })) -} - -/// Get WebRTC session info -#[derive(Serialize)] -pub struct WebRtcSessionInfo { - pub session_id: String, - pub state: String, -} - -#[derive(Serialize)] -pub struct WebRtcStatus { - pub session_count: usize, - pub sessions: Vec, -} - -pub async fn webrtc_status(State(state): State>) -> Json { - let sessions = state.webrtc.list_sessions().await; - Json(WebRtcStatus { - session_count: sessions.len(), - sessions: sessions - .into_iter() - .map(|s| WebRtcSessionInfo { - session_id: s.session_id, - state: s.state, - }) - .collect(), - }) -} - -/// Close WebRTC session -#[derive(Deserialize)] -pub struct CloseSessionRequest { - pub session_id: String, -} - -pub async fn webrtc_close_session( - State(state): State>, - Json(req): Json, -) -> Result> { - state.webrtc.close_session(&req.session_id).await?; - - Ok(Json(LoginResponse { - success: true, - message: Some("Session closed".to_string()), - })) -} - -/// ICE servers configuration for WebRTC -#[derive(Serialize)] -pub struct IceServersResponse { - pub ice_servers: Vec, - pub mdns_mode: String, -} - -#[derive(Serialize)] -pub struct IceServerInfo { - pub urls: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub username: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub credential: Option, -} - -/// Get ICE servers configuration for client-side WebRTC -/// Returns user-configured servers, or Google STUN as fallback if none configured -pub async fn webrtc_ice_servers(State(state): State>) -> Json { - use crate::webrtc::config::public_ice; - use crate::webrtc::mdns::{mdns_mode, mdns_mode_label}; - - let config = state.config.get(); - let mut ice_servers = Vec::new(); - - // Check if user has configured custom ICE servers - let has_custom_stun = config - .stream - .stun_server - .as_ref() - .map(|s| !s.is_empty()) - .unwrap_or(false); - let has_custom_turn = config - .stream - .turn_server - .as_ref() - .map(|s| !s.is_empty()) - .unwrap_or(false); - - if has_custom_stun || has_custom_turn { - // Use user-configured ICE servers - if let Some(ref stun) = config.stream.stun_server { - if !stun.is_empty() { - ice_servers.push(IceServerInfo { - urls: vec![stun.clone()], - username: None, - credential: None, - }); - } - } - - if let Some(ref turn) = config.stream.turn_server { - if !turn.is_empty() { - let username = config.stream.turn_username.clone(); - let credential = config.stream.turn_password.clone(); - if username.is_some() && credential.is_some() { - ice_servers.push(IceServerInfo { - urls: vec![turn.clone()], - username, - credential, - }); - } - } - } - } else { - // No custom servers — baked-in public STUN - ice_servers.push(IceServerInfo { - urls: vec![public_ice::stun_server().to_string()], - username: None, - credential: None, - }); - // Note: TURN servers are not provided - users must configure their own - } - - let mdns_mode = mdns_mode(); - let mdns_mode = mdns_mode_label(mdns_mode).to_string(); - - Json(IceServersResponse { - ice_servers, - mdns_mode, - }) -} - -/// HID status response -#[derive(Serialize)] -pub struct HidStatus { - pub available: bool, - pub backend: String, - pub initialized: bool, - pub online: bool, - pub supports_absolute_mouse: bool, - pub keyboard_leds_enabled: bool, - pub led_state: crate::hid::LedState, - pub screen_resolution: Option<(u32, u32)>, - pub device: Option, - pub error: Option, - pub error_code: Option, -} - -#[derive(Serialize, Clone, Copy, PartialEq, Eq)] -#[serde(rename_all = "lowercase")] -pub enum OtgSelfCheckLevel { - Info, - Warn, - Error, -} - -#[derive(Serialize)] -pub struct OtgSelfCheckItem { - pub id: &'static str, - pub ok: bool, - pub level: OtgSelfCheckLevel, - pub message: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub hint: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub path: Option, -} - -#[derive(Serialize)] -pub struct OtgSelfCheckResponse { - pub overall_ok: bool, - pub error_count: usize, - pub warning_count: usize, - pub hid_backend: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub selected_udc: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub bound_udc: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub udc_state: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub udc_speed: Option, - pub available_udcs: Vec, - pub other_gadgets: Vec, - pub checks: Vec, -} - -fn push_otg_check( - checks: &mut Vec, - id: &'static str, - ok: bool, - level: OtgSelfCheckLevel, - message: impl Into, - hint: Option>, - path: Option>, -) { - checks.push(OtgSelfCheckItem { - id, - ok, - level, - message: message.into(), - hint: hint.map(|v| v.into()), - path: path.map(|v| v.into()), - }); -} - -fn proc_modules_has(module_name: &str) -> bool { - std::fs::read_to_string("/proc/modules") - .ok() - .map(|content| { - content - .lines() - .filter_map(|line| line.split_whitespace().next()) - .any(|name| name == module_name) - }) - .unwrap_or(false) -} - -fn modules_metadata_has(module_name: &str) -> bool { - let kernel_release = match read_trimmed(std::path::Path::new("/proc/sys/kernel/osrelease")) { - Some(value) if !value.is_empty() => value, - _ => return false, - }; - - let module_dir = std::path::Path::new("/lib/modules").join(kernel_release); - let candidates = ["modules.builtin", "modules.builtin.modinfo", "modules.dep"]; - - candidates.iter().any(|filename| { - let path = module_dir.join(filename); - std::fs::read_to_string(path) - .ok() - .map(|content| { - let module_token = format!("/{module_name}.ko"); - content.lines().any(|line| { - line.contains(&module_token) - || line.contains(module_name) - || line.contains(&module_name.replace('_', "-")) - }) - }) - .unwrap_or(false) - }) -} - -fn kernel_config_option_enabled(option_name: &str) -> bool { - let kernel_release = match read_trimmed(std::path::Path::new("/proc/sys/kernel/osrelease")) { - Some(value) if !value.is_empty() => value, - _ => return false, - }; - - let config_paths = [ - std::path::PathBuf::from(format!("/boot/config-{kernel_release}")), - std::path::PathBuf::from("/boot/config"), - std::path::PathBuf::from(format!("/lib/modules/{kernel_release}/build/.config")), - ]; - - config_paths.iter().any(|path| { - std::fs::read_to_string(path) - .ok() - .map(|content| { - let enabled_y = format!("{option_name}=y"); - let enabled_m = format!("{option_name}=m"); - content - .lines() - .any(|line| line == enabled_y || line == enabled_m) - }) - .unwrap_or(false) - }) -} - -fn detect_libcomposite_available(gadget_root: &std::path::Path) -> bool { - let sys_module = std::path::Path::new("/sys/module/libcomposite").exists(); - if sys_module { - return true; - } - - if proc_modules_has("libcomposite") { - return true; - } - - if modules_metadata_has("libcomposite") { - return true; - } - - if kernel_config_option_enabled("CONFIG_USB_LIBCOMPOSITE") - || kernel_config_option_enabled("CONFIG_USB_CONFIGFS") - { - return true; - } - - // Fallback: if usb_gadget path exists, libcomposite may be built-in and already active. - gadget_root.exists() -} - -/// OTG self-check status for troubleshooting USB gadget issues -pub async fn hid_otg_self_check(State(state): State>) -> Json { - let config = state.config.get(); - let hid_backend_is_otg = matches!(config.hid.backend, crate::config::HidBackend::Otg); - let mut checks = Vec::new(); - - let build_response = |checks: Vec, - selected_udc: Option, - bound_udc: Option, - udc_state: Option, - udc_speed: Option, - available_udcs: Vec, - other_gadgets: Vec| { - let error_count = checks - .iter() - .filter(|item| item.level == OtgSelfCheckLevel::Error) - .count(); - let warning_count = checks - .iter() - .filter(|item| item.level == OtgSelfCheckLevel::Warn) - .count(); - - Json(OtgSelfCheckResponse { - overall_ok: error_count == 0, - error_count, - warning_count, - hid_backend: format!("{:?}", config.hid.backend).to_lowercase(), - selected_udc, - bound_udc, - udc_state, - udc_speed, - available_udcs, - other_gadgets, - checks, - }) - }; - - let udc_root = std::path::Path::new("/sys/class/udc"); - let available_udcs = list_dir_names(udc_root); - let selected_udc = config - .hid - .otg_udc - .clone() - .filter(|udc| !udc.trim().is_empty()) - .or_else(|| available_udcs.first().cloned()); - let mut udc_stage_ok = true; - if !udc_root.exists() { - udc_stage_ok = false; - push_otg_check( - &mut checks, - "udc_dir_exists", - false, - OtgSelfCheckLevel::Error, - "Check /sys/class/udc existence", - Some("Ensure UDC/OTG kernel drivers are enabled"), - Some("/sys/class/udc"), - ); - } else if available_udcs.is_empty() { - udc_stage_ok = false; - push_otg_check( - &mut checks, - "udc_has_entries", - false, - OtgSelfCheckLevel::Error, - "Check available UDC entries", - Some("Ensure OTG controller is enabled in device tree"), - Some("/sys/class/udc"), - ); - } else { - push_otg_check( - &mut checks, - "udc_has_entries", - true, - OtgSelfCheckLevel::Info, - "Check available UDC entries", - None::, - Some("/sys/class/udc"), - ); - } - - let mut configured_udc_ok = true; - if let Some(config_udc) = config - .hid - .otg_udc - .clone() - .map(|value| value.trim().to_string()) - .filter(|value| !value.is_empty()) - { - if available_udcs.iter().any(|item| item == &config_udc) { - push_otg_check( - &mut checks, - "configured_udc_valid", - true, - OtgSelfCheckLevel::Info, - "Check configured UDC validity", - None::, - Some("/sys/class/udc"), - ); - } else { - configured_udc_ok = false; - push_otg_check( - &mut checks, - "configured_udc_valid", - false, - OtgSelfCheckLevel::Error, - "Check configured UDC validity", - Some("Please reselect UDC in HID OTG settings"), - Some("/sys/class/udc"), - ); - } - } else { - push_otg_check( - &mut checks, - "configured_udc_valid", - !available_udcs.is_empty(), - if available_udcs.is_empty() { - OtgSelfCheckLevel::Warn - } else { - OtgSelfCheckLevel::Info - }, - "Check configured UDC validity", - Some( - "You can set hid_otg_udc in settings to avoid ambiguity in multi-controller setups", - ), - Some("/sys/class/udc"), - ); - } - - if !udc_stage_ok || !configured_udc_ok { - return build_response( - checks, - selected_udc, - None, - None, - None, - available_udcs, - vec![], - ); - } - - let gadget_root = std::path::Path::new("/sys/kernel/config/usb_gadget"); - let configfs_mounted = std::fs::read_to_string("/proc/mounts") - .ok() - .map(|mounts| { - mounts.lines().any(|line| { - let mut parts = line.split_whitespace(); - let _src = parts.next(); - let mount_point = parts.next(); - let fs_type = parts.next(); - mount_point == Some("/sys/kernel/config") && fs_type == Some("configfs") - }) - }) - .unwrap_or(false); - - let mut gadget_config_ok = true; - - if configfs_mounted { - push_otg_check( - &mut checks, - "configfs_mounted", - true, - OtgSelfCheckLevel::Info, - "Check configfs mount status", - None::, - Some("/sys/kernel/config"), - ); - } else { - gadget_config_ok = false; - push_otg_check( - &mut checks, - "configfs_mounted", - false, - OtgSelfCheckLevel::Error, - "Check configfs mount status", - Some("Try: mount -t configfs none /sys/kernel/config"), - Some("/sys/kernel/config"), - ); - } - - if gadget_root.exists() { - push_otg_check( - &mut checks, - "usb_gadget_dir_exists", - true, - OtgSelfCheckLevel::Info, - "Check /sys/kernel/config/usb_gadget access", - None::, - Some("/sys/kernel/config/usb_gadget"), - ); - } else { - gadget_config_ok = false; - push_otg_check( - &mut checks, - "usb_gadget_dir_exists", - false, - OtgSelfCheckLevel::Error, - "Check /sys/kernel/config/usb_gadget access", - Some("Ensure configfs and USB gadget support are enabled"), - Some("/sys/kernel/config/usb_gadget"), - ); - } - - let libcomposite_available = detect_libcomposite_available(gadget_root); - if libcomposite_available { - push_otg_check( - &mut checks, - "libcomposite_loaded", - true, - OtgSelfCheckLevel::Info, - "Check libcomposite module status", - None::, - Some("/sys/module/libcomposite"), - ); - } else { - gadget_config_ok = false; - push_otg_check( - &mut checks, - "libcomposite_loaded", - false, - OtgSelfCheckLevel::Error, - "Check libcomposite module status", - Some("Try: modprobe libcomposite"), - Some("/sys/module/libcomposite"), - ); - } - - if !gadget_config_ok { - return build_response( - checks, - selected_udc, - None, - None, - None, - available_udcs, - vec![], - ); - } - - let gadget_names = list_dir_names(gadget_root); - let one_kvm_path = gadget_root.join("one-kvm"); - let one_kvm_exists = one_kvm_path.exists(); - if one_kvm_exists { - push_otg_check( - &mut checks, - "one_kvm_gadget_exists", - true, - OtgSelfCheckLevel::Info, - "Check one-kvm gadget presence", - None::, - Some(one_kvm_path.display().to_string()), - ); - } else { - push_otg_check( - &mut checks, - "one_kvm_gadget_exists", - false, - if hid_backend_is_otg { - OtgSelfCheckLevel::Error - } else { - OtgSelfCheckLevel::Warn - }, - "Check one-kvm gadget presence", - Some("Enable OTG HID or MSD to let one-kvm gadget be created automatically"), - Some(one_kvm_path.display().to_string()), - ); - } - - let other_gadgets = gadget_names - .iter() - .filter(|name| name.as_str() != "one-kvm") - .cloned() - .collect::>(); - if other_gadgets.is_empty() { - push_otg_check( - &mut checks, - "other_gadgets", - true, - OtgSelfCheckLevel::Info, - "Check for other gadget services", - None::, - Some("/sys/kernel/config/usb_gadget"), - ); - } else { - push_otg_check( - &mut checks, - "other_gadgets", - false, - OtgSelfCheckLevel::Warn, - "Check for other gadget services", - Some("Potential UDC contention with one-kvm; check other OTG services"), - Some("/sys/kernel/config/usb_gadget"), - ); - } - - let mut bound_udc = None; - - if one_kvm_exists { - let one_kvm_udc_path = one_kvm_path.join("UDC"); - let current_udc = read_trimmed(&one_kvm_udc_path).unwrap_or_default(); - if current_udc.is_empty() { - push_otg_check( - &mut checks, - "one_kvm_bound_udc", - false, - OtgSelfCheckLevel::Warn, - "Check one-kvm UDC binding", - Some("Ensure HID/MSD is enabled and initialized successfully"), - Some(one_kvm_udc_path.display().to_string()), - ); - } else { - push_otg_check( - &mut checks, - "one_kvm_bound_udc", - true, - OtgSelfCheckLevel::Info, - "Check one-kvm UDC binding", - None::, - Some(one_kvm_udc_path.display().to_string()), - ); - bound_udc = Some(current_udc); - } - - let functions_path = one_kvm_path.join("functions"); - let function_names = list_dir_names(&functions_path) - .into_iter() - .filter(|name| name.contains(".usb")) - .collect::>(); - let hid_functions = function_names - .iter() - .filter(|name| name.starts_with("hid.usb")) - .cloned() - .collect::>(); - if hid_functions.is_empty() { - push_otg_check( - &mut checks, - "hid_functions_present", - false, - if hid_backend_is_otg { - OtgSelfCheckLevel::Error - } else { - OtgSelfCheckLevel::Warn - }, - "Check HID function creation", - Some("Check OTG HID config and enable at least one HID function"), - Some(functions_path.display().to_string()), - ); - } else { - push_otg_check( - &mut checks, - "hid_functions_present", - true, - OtgSelfCheckLevel::Info, - "Check HID function creation", - None::, - Some(functions_path.display().to_string()), - ); - } - - let config_path = one_kvm_path.join("configs/c.1"); - if !config_path.exists() { - push_otg_check( - &mut checks, - "config_c1_exists", - false, - OtgSelfCheckLevel::Error, - "Check configs/c.1 structure", - Some("Gadget structure is incomplete; try restarting One-KVM"), - Some(config_path.display().to_string()), - ); - } else { - push_otg_check( - &mut checks, - "config_c1_exists", - true, - OtgSelfCheckLevel::Info, - "Check configs/c.1 structure", - None::, - Some(config_path.display().to_string()), - ); - - let linked_functions = list_dir_names(&config_path) - .into_iter() - .filter(|name| name.contains(".usb")) - .collect::>(); - let missing_links = function_names - .iter() - .filter(|func| !linked_functions.iter().any(|link| link == *func)) - .cloned() - .collect::>(); - - if missing_links.is_empty() { - push_otg_check( - &mut checks, - "function_links_ok", - true, - OtgSelfCheckLevel::Info, - "Check function links in configs/c.1", - None::, - Some(config_path.display().to_string()), - ); - } else { - push_otg_check( - &mut checks, - "function_links_ok", - false, - OtgSelfCheckLevel::Warn, - "Check function links in configs/c.1", - Some("Reinitialize OTG (toggle HID backend once or restart service)"), - Some(config_path.display().to_string()), - ); - } - } - - let missing_hid_devices = hid_functions - .iter() - .filter_map(|name| { - let index = name.strip_prefix("hid.usb")?.parse::().ok()?; - let dev_path = std::path::PathBuf::from(format!("/dev/hidg{}", index)); - if dev_path.exists() { - None - } else { - Some(dev_path.display().to_string()) - } - }) - .collect::>(); - - if !hid_functions.is_empty() { - if missing_hid_devices.is_empty() { - push_otg_check( - &mut checks, - "hid_device_nodes", - true, - OtgSelfCheckLevel::Info, - "Check /dev/hidg* device nodes", - None::, - Some("/dev/hidg*"), - ); - } else { - push_otg_check( - &mut checks, - "hid_device_nodes", - false, - OtgSelfCheckLevel::Warn, - "Check /dev/hidg* device nodes", - Some("Ensure gadget is bound and check kernel logs"), - Some("/dev/hidg*"), - ); - } - } - } - - if !other_gadgets.is_empty() { - let check_udc = bound_udc.clone().or_else(|| selected_udc.clone()); - if let Some(target_udc) = check_udc { - let conflicting_gadgets = other_gadgets - .iter() - .filter_map(|name| { - let udc_file = gadget_root.join(name).join("UDC"); - let udc = read_trimmed(&udc_file)?; - if udc == target_udc { - Some(name.clone()) - } else { - None - } - }) - .collect::>(); - - if conflicting_gadgets.is_empty() { - push_otg_check( - &mut checks, - "udc_conflict", - true, - OtgSelfCheckLevel::Info, - "Check UDC binding conflicts", - None::, - Some("/sys/kernel/config/usb_gadget/*/UDC"), - ); - } else { - push_otg_check( - &mut checks, - "udc_conflict", - false, - OtgSelfCheckLevel::Error, - "Check UDC binding conflicts", - Some("Stop other OTG services or switch one-kvm to an idle UDC"), - Some("/sys/kernel/config/usb_gadget/*/UDC"), - ); - } - } - } - - let active_udc = bound_udc.clone().or_else(|| selected_udc.clone()); - let mut udc_state = None; - let mut udc_speed = None; - - if let Some(udc) = active_udc.clone() { - let state_path = udc_root.join(&udc).join("state"); - match read_trimmed(&state_path) { - Some(state_name) if state_name.eq_ignore_ascii_case("configured") => { - udc_state = Some(state_name.clone()); - push_otg_check( - &mut checks, - "udc_state", - true, - OtgSelfCheckLevel::Info, - "Check UDC connection state", - None::, - Some(state_path.display().to_string()), - ); - } - Some(state_name) => { - udc_state = Some(state_name.clone()); - push_otg_check( - &mut checks, - "udc_state", - false, - OtgSelfCheckLevel::Warn, - "Check UDC connection state", - Some("Ensure target host is connected and has recognized the USB device"), - Some(state_path.display().to_string()), - ); - } - None => { - push_otg_check( - &mut checks, - "udc_state", - false, - OtgSelfCheckLevel::Warn, - "Check UDC connection state", - Some("Ensure UDC name is valid and check kernel permissions"), - Some(state_path.display().to_string()), - ); - } - } - - let speed_path = udc_root.join(&udc).join("current_speed"); - if let Some(speed) = read_trimmed(&speed_path) { - udc_speed = Some(speed.clone()); - let is_unknown = speed.eq_ignore_ascii_case("unknown"); - push_otg_check( - &mut checks, - "udc_speed", - !is_unknown, - if is_unknown { - OtgSelfCheckLevel::Warn - } else { - OtgSelfCheckLevel::Info - }, - "Check UDC current link speed", - if is_unknown { - Some("Device may not be fully enumerated; try reconnecting USB".to_string()) - } else { - None - }, - Some(speed_path.display().to_string()), - ); - } - } else { - push_otg_check( - &mut checks, - "udc_state", - false, - OtgSelfCheckLevel::Warn, - "Check UDC connection state", - Some("Ensure UDC is available and one-kvm gadget is bound first"), - Some("/sys/class/udc"), - ); - } - - let error_count = checks - .iter() - .filter(|item| item.level == OtgSelfCheckLevel::Error) - .count(); - let warning_count = checks - .iter() - .filter(|item| item.level == OtgSelfCheckLevel::Warn) - .count(); - - Json(OtgSelfCheckResponse { - overall_ok: error_count == 0, - error_count, - warning_count, - hid_backend: format!("{:?}", config.hid.backend).to_lowercase(), - selected_udc, - bound_udc, - udc_state, - udc_speed, - available_udcs, - other_gadgets, - checks, - }) -} - -/// Get HID status -pub async fn hid_status(State(state): State>) -> Json { - let hid = state.hid.snapshot().await; - Json(HidStatus { - available: hid.available, - backend: hid.backend, - initialized: hid.initialized, - online: hid.online, - supports_absolute_mouse: hid.supports_absolute_mouse, - keyboard_leds_enabled: hid.keyboard_leds_enabled, - led_state: hid.led_state, - screen_resolution: hid.screen_resolution, - device: hid.device, - error: hid.error, - error_code: hid.error_code, - }) -} - -/// Reset HID state -pub async fn hid_reset(State(state): State>) -> Result> { - state.hid.reset().await?; - - Ok(Json(LoginResponse { - success: true, - message: Some("HID state reset".to_string()), - })) -} - -use crate::msd::{ - DownloadProgress, DriveFile, DriveInfo, DriveInitRequest, ImageDownloadRequest, ImageInfo, - ImageManager, MsdConnectRequest, MsdMode, MsdState, VentoyDrive, -}; -use axum::extract::{Multipart, Path as AxumPath, Query}; -use std::collections::HashMap; - -/// MSD status response -#[derive(Serialize)] -pub struct MsdStatus { - pub available: bool, - pub state: MsdState, -} - -/// Get MSD status -pub async fn msd_status(State(state): State>) -> Result> { - let msd_guard = state.msd.read().await; - match msd_guard.as_ref() { - Some(controller) => { - let msd_state = controller.state().await; - Ok(Json(MsdStatus { - available: true, - state: msd_state, - })) - } - None => Ok(Json(MsdStatus { - available: false, - state: MsdState::default(), - })), - } -} - -/// List all available images -pub async fn msd_images_list(State(state): State>) -> Result>> { - let config = state.config.get(); - let images_path = config.msd.images_dir(); - let manager = ImageManager::new(images_path); - - let images = manager.list()?; - Ok(Json(images)) -} - -/// Upload new image (streaming - memory efficient for large files) -pub async fn msd_image_upload( - State(state): State>, - mut multipart: Multipart, -) -> Result> { - let config = state.config.get(); - let images_path = config.msd.images_dir(); - let manager = ImageManager::new(images_path); - - while let Some(field) = multipart - .next_field() - .await - .map_err(|e| AppError::Internal(format!("Multipart error: {}", e)))? - { - let name = field.name().unwrap_or("file").to_string(); - if name == "file" { - let filename = field - .file_name() - .ok_or_else(|| AppError::BadRequest("Missing filename".to_string()))? - .to_string(); - - // Use streaming upload - chunks are written directly to disk - // This avoids loading the entire file into memory - let image = manager - .create_from_multipart_field(&filename, field) - .await?; - return Ok(Json(image)); - } - } - - Err(AppError::BadRequest("No file provided".to_string())) -} - -/// Get image by ID -pub async fn msd_image_get( - State(state): State>, - AxumPath(id): AxumPath, -) -> Result> { - let config = state.config.get(); - let images_path = config.msd.images_dir(); - let manager = ImageManager::new(images_path); - - let image = manager.get(&id)?; - Ok(Json(image)) -} - -/// Delete image by ID -pub async fn msd_image_delete( - State(state): State>, - AxumPath(id): AxumPath, -) -> Result> { - let config = state.config.get(); - let images_path = config.msd.images_dir(); - let manager = ImageManager::new(images_path); - - manager.delete(&id)?; - Ok(Json(LoginResponse { - success: true, - message: Some("Image deleted".to_string()), - })) -} - -/// Download image from URL -pub async fn msd_image_download( - State(state): State>, - Json(req): Json, -) -> Result> { - let msd_guard = state.msd.read().await; - let controller = msd_guard - .as_ref() - .ok_or_else(|| AppError::Internal("MSD not initialized".to_string()))?; - - let progress = controller.download_image(req.url, req.filename).await?; - - Ok(Json(progress)) -} - -/// Cancel image download -#[derive(serde::Deserialize)] -pub struct CancelDownloadRequest { - pub download_id: String, -} - -pub async fn msd_image_download_cancel( - State(state): State>, - Json(req): Json, -) -> Result> { - let msd_guard = state.msd.read().await; - let controller = msd_guard - .as_ref() - .ok_or_else(|| AppError::Internal("MSD not initialized".to_string()))?; - - controller.cancel_download(&req.download_id).await?; - - Ok(Json(LoginResponse { - success: true, - message: Some("Download cancelled".to_string()), - })) -} - -/// Connect MSD (image or drive) -pub async fn msd_connect( - State(state): State>, - Json(req): Json, -) -> Result> { - let config = state.config.get(); - let mut msd_guard = state.msd.write().await; - let controller = msd_guard - .as_mut() - .ok_or_else(|| AppError::Internal("MSD not initialized".to_string()))?; - - match req.mode { - MsdMode::Image => { - let image_id = req.image_id.ok_or_else(|| { - AppError::BadRequest("image_id required for image mode".to_string()) - })?; - - // Get image info from ImageManager - let images_path = config.msd.images_dir(); - let manager = ImageManager::new(images_path); - let image = manager.get(&image_id)?; - - // Get mount options from request (defaults: cdrom=false, read_only=false) - let cdrom = req.cdrom.unwrap_or(false); - let read_only = req.read_only.unwrap_or(false); - - controller.connect_image(&image, cdrom, read_only).await?; - } - MsdMode::Drive => { - controller.connect_drive().await?; - } - MsdMode::None => { - return Err(AppError::BadRequest("Invalid mode: none".to_string())); - } - } - - Ok(Json(LoginResponse { - success: true, - message: Some("MSD connected".to_string()), - })) -} - -/// Disconnect MSD -pub async fn msd_disconnect(State(state): State>) -> Result> { - let mut msd_guard = state.msd.write().await; - let controller = msd_guard - .as_mut() - .ok_or_else(|| AppError::Internal("MSD not initialized".to_string()))?; - - controller.disconnect().await?; - - Ok(Json(LoginResponse { - success: true, - message: Some("MSD disconnected".to_string()), - })) -} - -/// Get drive info -pub async fn msd_drive_info(State(state): State>) -> Result> { - let config = state.config.get(); - let drive_path = config.msd.drive_path(); - let drive = VentoyDrive::new(drive_path); - - if !drive.exists() { - return Err(AppError::NotFound("Drive not initialized".to_string())); - } - - let info = drive.info().await?; - Ok(Json(info)) -} - -/// Initialize Ventoy drive -pub async fn msd_drive_init( - State(state): State>, - Json(req): Json, -) -> Result> { - let config = state.config.get(); - let drive_path = config.msd.drive_path(); - let drive = VentoyDrive::new(drive_path); - - let info = drive.init(req.size_mb).await?; - Ok(Json(info)) -} - -/// Delete virtual drive -pub async fn msd_drive_delete(State(state): State>) -> Result> { - let config = state.config.get(); - - // Check if drive is currently connected - let msd_guard = state.msd.write().await; - if let Some(controller) = msd_guard.as_ref() { - let msd_state = controller.state().await; - if msd_state.connected && msd_state.mode == crate::msd::types::MsdMode::Drive { - return Err(AppError::BadRequest( - "Cannot delete drive while connected. Disconnect first.".to_string(), - )); - } - } - drop(msd_guard); - - // Delete the drive file - let drive_path = config.msd.drive_path(); - if drive_path.exists() { - std::fs::remove_file(&drive_path) - .map_err(|e| AppError::Internal(format!("Failed to delete drive file: {}", e)))?; - } - - Ok(Json(LoginResponse { - success: true, - message: Some("Virtual drive deleted".to_string()), - })) -} - -/// List drive files -pub async fn msd_drive_files( - State(state): State>, - Query(params): Query>, -) -> Result>> { - let config = state.config.get(); - let drive_path = config.msd.drive_path(); - let drive = VentoyDrive::new(drive_path); - - let dir_path = params.get("path").map(|s| s.as_str()).unwrap_or("/"); - let files = drive.list_files(dir_path).await?; - Ok(Json(files)) -} - -/// Upload file to drive (streaming - memory efficient for large files) -pub async fn msd_drive_upload( - State(state): State>, - Query(params): Query>, - mut multipart: Multipart, -) -> Result> { - let config = state.config.get(); - let drive_path = config.msd.drive_path(); - let drive = VentoyDrive::new(drive_path); - - let target_dir = params.get("path").map(|s| s.as_str()).unwrap_or("/"); - - while let Some(field) = multipart - .next_field() - .await - .map_err(|e| AppError::Internal(format!("Multipart error: {}", e)))? - { - let name = field.name().unwrap_or("file").to_string(); - if name == "file" { - let filename = field - .file_name() - .ok_or_else(|| AppError::BadRequest("Missing filename".to_string()))? - .to_string(); - - let file_path = if target_dir == "/" { - format!("/{}", filename) - } else { - format!("{}/{}", target_dir.trim_end_matches('/'), filename) - }; - - // Use streaming upload - chunks are written directly to disk - // This avoids loading the entire file into memory - drive - .write_file_from_multipart_field(&file_path, field) - .await?; - - return Ok(Json(LoginResponse { - success: true, - message: Some(format!("File uploaded: {}", file_path)), - })); - } - } - - Err(AppError::BadRequest("No file provided".to_string())) -} - -/// Download file from drive (streaming for large files) -pub async fn msd_drive_download( - State(state): State>, - AxumPath(file_path): AxumPath, -) -> Result { - let config = state.config.get(); - let drive_path = config.msd.drive_path(); - let drive = VentoyDrive::new(drive_path); - - // Get file stream (returns file size and channel receiver) - let (file_size, mut rx) = drive.read_file_stream(&file_path).await?; - - // Extract filename for Content-Disposition - let filename = file_path.split('/').next_back().unwrap_or("download"); - - // Create a stream from the channel receiver - let body_stream = async_stream::stream! { - while let Some(chunk) = rx.recv().await { - yield chunk; - } - }; - - Ok(Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "application/octet-stream") - .header(header::CONTENT_LENGTH, file_size) - .header( - header::CONTENT_DISPOSITION, - format!("attachment; filename=\"{}\"", filename), - ) - .body(Body::from_stream(body_stream)) - .unwrap()) -} - -/// Delete file from drive -pub async fn msd_drive_file_delete( - State(state): State>, - AxumPath(file_path): AxumPath, -) -> Result> { - let config = state.config.get(); - let drive_path = config.msd.drive_path(); - let drive = VentoyDrive::new(drive_path); - - drive.delete(&file_path).await?; - - Ok(Json(LoginResponse { - success: true, - message: Some(format!("Deleted: {}", file_path)), - })) -} - -/// Create directory in drive -pub async fn msd_drive_mkdir( - State(state): State>, - AxumPath(dir_path): AxumPath, -) -> Result> { - let config = state.config.get(); - let drive_path = config.msd.drive_path(); - let drive = VentoyDrive::new(drive_path); - - drive.mkdir(&dir_path).await?; - - Ok(Json(LoginResponse { - success: true, - message: Some(format!("Directory created: {}", dir_path)), - })) -} - -use crate::atx::{AtxState, PowerStatus}; - -const WOL_HISTORY_MAX_ENTRIES: i64 = 50; -const WOL_HISTORY_DEFAULT_LIMIT: usize = 5; -const WOL_HISTORY_MAX_LIMIT: usize = 50; - -/// ATX state response -#[derive(Serialize)] -pub struct AtxStateResponse { - pub available: bool, - pub backend: String, - pub initialized: bool, - pub power_status: String, - pub led_supported: bool, -} - -impl From for AtxStateResponse { - fn from(state: AtxState) -> Self { - Self { - available: state.available, - backend: if state.power_configured || state.reset_configured { - format!( - "power: {}, reset: {}", - if state.power_configured { "yes" } else { "no" }, - if state.reset_configured { "yes" } else { "no" } - ) - } else { - "none".to_string() - }, - initialized: state.power_configured || state.reset_configured, - power_status: match state.power_status { - PowerStatus::On => "on".to_string(), - PowerStatus::Off => "off".to_string(), - PowerStatus::Unknown => "unknown".to_string(), - }, - led_supported: state.led_supported, - } - } -} - -/// Get ATX status -pub async fn atx_status(State(state): State>) -> Result> { - let atx_guard = state.atx.read().await; - - match atx_guard.as_ref() { - Some(atx) => { - let atx_state = atx.state().await; - Ok(Json(AtxStateResponse::from(atx_state))) - } - None => Ok(Json(AtxStateResponse { - available: false, - backend: "none".to_string(), - initialized: false, - power_status: "unknown".to_string(), - led_supported: false, - })), - } -} - -/// ATX power control request -#[derive(Deserialize)] -pub struct AtxPowerControlRequest { - pub action: String, // "short", "long", "reset" -} - -/// Control ATX power -pub async fn atx_power( - State(state): State>, - Json(req): Json, -) -> Result> { - let atx_guard = state.atx.read().await; - let atx = atx_guard - .as_ref() - .ok_or_else(|| AppError::Internal("ATX controller not initialized".to_string()))?; - - match req.action.as_str() { - "short" => { - atx.power_short().await?; - Ok(Json(LoginResponse { - success: true, - message: Some("Power short press executed".to_string()), - })) - } - "long" => { - atx.power_long().await?; - Ok(Json(LoginResponse { - success: true, - message: Some("Power long press (force off) executed".to_string()), - })) - } - "reset" => { - atx.reset().await?; - Ok(Json(LoginResponse { - success: true, - message: Some("Reset button pressed".to_string()), - })) - } - _ => Err(AppError::BadRequest(format!( - "Unknown ATX action: {}. Valid actions: short, long, reset", - req.action - ))), - } -} - -/// WOL request body -#[derive(Debug, Deserialize)] -pub struct WolRequest { - /// Target MAC address (e.g., "AA:BB:CC:DD:EE:FF" or "AA-BB-CC-DD-EE-FF") - pub mac_address: String, -} - -#[derive(Debug, Deserialize, Default)] -pub struct WolHistoryQuery { - /// Maximum history entries to return - pub limit: Option, -} - -#[derive(Debug, Serialize)] -pub struct WolHistoryEntry { - pub mac_address: String, - pub updated_at: i64, -} - -#[derive(Debug, Serialize)] -pub struct WolHistoryResponse { - pub history: Vec, -} - -fn normalize_wol_mac_address(mac_address: &str) -> String { - let normalized = mac_address.trim().to_uppercase().replace('-', ":"); - - if normalized.len() == 12 && normalized.chars().all(|c| c.is_ascii_hexdigit()) { - let mut mac_with_separator = String::with_capacity(17); - for (index, chunk) in normalized.as_bytes().chunks(2).enumerate() { - if index > 0 { - mac_with_separator.push(':'); - } - mac_with_separator.push(chunk[0] as char); - mac_with_separator.push(chunk[1] as char); - } - mac_with_separator - } else { - normalized - } -} - -async fn record_wol_history(state: &Arc, mac_address: &str) -> Result<()> { - sqlx::query( - r#" - INSERT INTO wol_history (mac_address, updated_at) - VALUES (?1, CAST(strftime('%s', 'now') AS INTEGER)) - ON CONFLICT(mac_address) DO UPDATE SET - updated_at = excluded.updated_at - "#, - ) - .bind(mac_address) - .execute(state.db.pool()) - .await?; - - sqlx::query( - r#" - DELETE FROM wol_history - WHERE mac_address NOT IN ( - SELECT mac_address FROM wol_history - ORDER BY updated_at DESC - LIMIT ?1 - ) - "#, - ) - .bind(WOL_HISTORY_MAX_ENTRIES) - .execute(state.db.pool()) - .await?; - - Ok(()) -} - -/// Send Wake-on-LAN magic packet -pub async fn atx_wol( - State(state): State>, - Json(req): Json, -) -> Result> { - let mac_address = normalize_wol_mac_address(&req.mac_address); - - // Get WOL interface from config - let config = state.config.get(); - let interface = if config.atx.wol_interface.is_empty() { - None - } else { - Some(config.atx.wol_interface.as_str()) - }; - - // Send WOL packet - crate::atx::send_wol(&mac_address, interface)?; - - if let Err(error) = record_wol_history(&state, &mac_address).await { - warn!("Failed to persist WOL history: {}", error); - } - - Ok(Json(LoginResponse { - success: true, - message: Some(format!("WOL packet sent to {}", mac_address)), - })) -} - -/// Get WOL history -pub async fn atx_wol_history( - State(state): State>, - Query(query): Query, -) -> Result> { - let limit = query - .limit - .unwrap_or(WOL_HISTORY_DEFAULT_LIMIT) - .clamp(1, WOL_HISTORY_MAX_LIMIT); - - let rows: Vec<(String, i64)> = sqlx::query_as( - r#" - SELECT mac_address, updated_at - FROM wol_history - ORDER BY updated_at DESC - LIMIT ?1 - "#, - ) - .bind(limit as i64) - .fetch_all(state.db.pool()) - .await?; - - let history = rows - .into_iter() - .map(|(mac_address, updated_at)| WolHistoryEntry { - mac_address, - updated_at, - }) - .collect(); - - Ok(Json(WolHistoryResponse { history })) -} - -use crate::audio::{AudioQuality, AudioStatus}; - -/// Audio status response (re-exports AudioStatus from audio module) -pub type AudioStatusResponse = AudioStatus; - -/// Get audio status -pub async fn audio_status(State(state): State>) -> Json { - Json(state.audio.status().await) -} - -/// Start audio streaming -pub async fn start_audio_streaming( - State(state): State>, -) -> Result> { - state.audio.start_streaming().await?; - - // Reconnect audio sources for existing WebRTC sessions - // This ensures sessions created before audio was enabled will receive audio - state.stream_manager.reconnect_webrtc_audio_sources().await; - - Ok(Json(LoginResponse { - success: true, - message: Some("Audio streaming started".to_string()), - })) -} - -/// Stop audio streaming -pub async fn stop_audio_streaming( - State(state): State>, -) -> Result> { - state.audio.stop_streaming().await?; - Ok(Json(LoginResponse { - success: true, - message: Some("Audio streaming stopped".to_string()), - })) -} - -/// Set audio quality request -#[derive(Deserialize)] -pub struct SetAudioQualityRequest { - pub quality: String, -} - -/// Set audio quality -pub async fn set_audio_quality( - State(state): State>, - Json(req): Json, -) -> Result> { - let quality = req.quality.parse::()?; - state.audio.set_quality(quality).await?; - Ok(Json(LoginResponse { - success: true, - message: Some(format!("Audio quality set to {}", quality)), - })) -} - -/// Select audio device request -#[derive(Deserialize)] -pub struct SelectAudioDeviceRequest { - pub device: String, -} - -/// Select audio device -pub async fn select_audio_device( - State(state): State>, - Json(req): Json, -) -> Result> { - state.audio.select_device(&req.device).await?; - Ok(Json(LoginResponse { - success: true, - message: Some(format!("Audio device selected: {}", req.device)), - })) -} - -/// List audio devices -pub async fn list_audio_devices( - State(state): State>, -) -> Result>> { - let devices = state.audio.list_devices().await?; - Ok(Json(devices)) -} - -/// Change password request -#[derive(Deserialize)] -pub struct ChangePasswordRequest { - pub current_password: String, - pub new_password: String, -} - -/// Change current user's password -pub async fn change_password( - State(state): State>, - axum::Extension(session): axum::Extension, - Json(req): Json, -) -> Result> { - let current_user = state - .users - .single_user() - .await? - .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; - - if current_user.id != session.user_id { - return Err(AppError::AuthError("Invalid session".to_string())); - } - - if req.new_password.len() < 4 { - return Err(AppError::BadRequest( - "Password must be at least 4 characters".to_string(), - )); - } - - let verified = state - .users - .verify(¤t_user.username, &req.current_password) - .await?; - if verified.is_none() { - return Err(AppError::AuthError( - "Current password is incorrect".to_string(), - )); - } - - state - .users - .update_password(&session.user_id, &req.new_password) - .await?; - info!("Password changed for user ID: {}", session.user_id); - - Ok(Json(LoginResponse { - success: true, - message: Some("Password changed successfully".to_string()), - })) -} - -/// Change username request -#[derive(Deserialize)] -pub struct ChangeUsernameRequest { - pub username: String, - pub current_password: String, -} - -/// Change current user's username -pub async fn change_username( - State(state): State>, - axum::Extension(session): axum::Extension, - Json(req): Json, -) -> Result> { - let current_user = state - .users - .single_user() - .await? - .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; - - if current_user.id != session.user_id { - return Err(AppError::AuthError("Invalid session".to_string())); - } - - if req.username.len() < 2 { - return Err(AppError::BadRequest( - "Username must be at least 2 characters".to_string(), - )); - } - - let verified = state - .users - .verify(¤t_user.username, &req.current_password) - .await?; - if verified.is_none() { - return Err(AppError::AuthError( - "Current password is incorrect".to_string(), - )); - } - - if current_user.username != req.username { - state - .users - .update_username(&session.user_id, &req.username) - .await?; - } - info!("Username changed for user ID: {}", session.user_id); - - Ok(Json(LoginResponse { - success: true, - message: Some("Username changed successfully".to_string()), - })) -} - -/// Restart the application -pub async fn system_restart(State(state): State>) -> Json { - info!("System restart requested via API"); - - // Send shutdown signal - let _ = state.shutdown_tx.send(()); - - // Spawn restart task in background - tokio::spawn(async { - // Wait for resources to be released (OTG, video, etc.) - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - - // Get current executable and args - let exe = match std::env::current_exe() { - Ok(e) => e, - Err(e) => { - tracing::error!("Failed to get current exe: {}", e); - std::process::exit(1); - } - }; - let args: Vec = std::env::args().skip(1).collect(); - - info!("Restarting: {:?} {:?}", exe, args); - - // Use exec to replace current process (Unix) - #[cfg(unix)] - { - use std::os::unix::process::CommandExt; - let err = std::process::Command::new(&exe).args(&args).exec(); - tracing::error!("Failed to restart: {}", err); - std::process::exit(1); - } - - #[cfg(not(unix))] - { - let _ = std::process::Command::new(&exe).args(&args).spawn(); - std::process::exit(0); - } - }); - - Json(LoginResponse { - success: true, - message: Some("Restarting...".to_string()), - }) -} - -#[derive(Deserialize)] -pub struct UpdateOverviewQuery { - pub channel: Option, -} - -pub async fn update_overview( - State(state): State>, - axum::extract::Query(query): axum::extract::Query, -) -> Result> { - let channel = query.channel.unwrap_or(UpdateChannel::Stable); - let response = state.update.overview(channel).await?; - Ok(Json(response)) -} - -pub async fn update_upgrade( - State(state): State>, - Json(req): Json, -) -> Result> { - state.update.start_upgrade(req, state.shutdown_tx.clone())?; - - Ok(Json(LoginResponse { - success: true, - message: Some("Upgrade started".to_string()), - })) -} - -pub async fn update_status(State(state): State>) -> Json { - Json(state.update.status().await) -} +use crate::video::codec_constraints::codec_to_id; diff --git a/src/web/handlers/msd_api.rs b/src/web/handlers/msd_api.rs new file mode 100644 index 00000000..464a6940 --- /dev/null +++ b/src/web/handlers/msd_api.rs @@ -0,0 +1,405 @@ +use super::*; + +use crate::msd::{ + DownloadProgress, DriveFile, DriveInfo, DriveInitRequest, ImageDownloadRequest, ImageInfo, + ImageManager, MsdConnectRequest, MsdMode, MsdState, VentoyDrive, +}; +#[cfg(unix)] +use axum::extract::{Multipart, Path as AxumPath}; +#[cfg(unix)] +use std::collections::HashMap; + +/// MSD status response +#[cfg(unix)] +#[derive(Serialize)] +pub struct MsdStatus { + pub available: bool, + pub state: MsdState, +} + +/// Get MSD status +#[cfg(unix)] +pub async fn msd_status(State(state): State>) -> Result> { + let msd_guard = state.msd.read().await; + match msd_guard.as_ref() { + Some(controller) => { + let msd_state = controller.state().await; + Ok(Json(MsdStatus { + available: true, + state: msd_state, + })) + } + None => Ok(Json(MsdStatus { + available: false, + state: MsdState::default(), + })), + } +} + +/// List all available images +#[cfg(unix)] +pub async fn msd_images_list(State(state): State>) -> Result>> { + let config = state.config.get(); + let images_path = config.msd.images_dir(); + let manager = ImageManager::new(images_path); + + let images = manager.list()?; + Ok(Json(images)) +} + +/// Upload new image (streaming - memory efficient for large files) +#[cfg(unix)] +pub async fn msd_image_upload( + State(state): State>, + mut multipart: Multipart, +) -> Result> { + let config = state.config.get(); + let images_path = config.msd.images_dir(); + let manager = ImageManager::new(images_path); + + while let Some(field) = multipart + .next_field() + .await + .map_err(|e| AppError::Internal(format!("Multipart error: {}", e)))? + { + let name = field.name().unwrap_or("file").to_string(); + if name == "file" { + let filename = field + .file_name() + .ok_or_else(|| AppError::BadRequest("Missing filename".to_string()))? + .to_string(); + + // Use streaming upload - chunks are written directly to disk + // This avoids loading the entire file into memory + let image = manager + .create_from_multipart_field(&filename, field) + .await?; + return Ok(Json(image)); + } + } + + Err(AppError::BadRequest("No file provided".to_string())) +} + +/// Get image by ID +#[cfg(unix)] +pub async fn msd_image_get( + State(state): State>, + AxumPath(id): AxumPath, +) -> Result> { + let config = state.config.get(); + let images_path = config.msd.images_dir(); + let manager = ImageManager::new(images_path); + + let image = manager.get(&id)?; + Ok(Json(image)) +} + +/// Delete image by ID +#[cfg(unix)] +pub async fn msd_image_delete( + State(state): State>, + AxumPath(id): AxumPath, +) -> Result> { + let config = state.config.get(); + let images_path = config.msd.images_dir(); + let manager = ImageManager::new(images_path); + + manager.delete(&id)?; + Ok(Json(LoginResponse { + success: true, + message: Some("Image deleted".to_string()), + })) +} + +/// Download image from URL +#[cfg(unix)] +pub async fn msd_image_download( + State(state): State>, + Json(req): Json, +) -> Result> { + let msd_guard = state.msd.read().await; + let controller = msd_guard + .as_ref() + .ok_or_else(|| AppError::Internal("MSD not initialized".to_string()))?; + + let progress = controller.download_image(req.url, req.filename).await?; + + Ok(Json(progress)) +} + +/// Cancel image download +#[cfg(unix)] +#[derive(serde::Deserialize)] +pub struct CancelDownloadRequest { + pub download_id: String, +} + +#[cfg(unix)] +pub async fn msd_image_download_cancel( + State(state): State>, + Json(req): Json, +) -> Result> { + let msd_guard = state.msd.read().await; + let controller = msd_guard + .as_ref() + .ok_or_else(|| AppError::Internal("MSD not initialized".to_string()))?; + + controller.cancel_download(&req.download_id).await?; + + Ok(Json(LoginResponse { + success: true, + message: Some("Download cancelled".to_string()), + })) +} + +/// Connect MSD (image or drive) +#[cfg(unix)] +pub async fn msd_connect( + State(state): State>, + Json(req): Json, +) -> Result> { + let config = state.config.get(); + let mut msd_guard = state.msd.write().await; + let controller = msd_guard + .as_mut() + .ok_or_else(|| AppError::Internal("MSD not initialized".to_string()))?; + + match req.mode { + MsdMode::Image => { + let image_id = req.image_id.ok_or_else(|| { + AppError::BadRequest("image_id required for image mode".to_string()) + })?; + + // Get image info from ImageManager + let images_path = config.msd.images_dir(); + let manager = ImageManager::new(images_path); + let image = manager.get(&image_id)?; + + // Get mount options from request (defaults: cdrom=false, read_only=false) + let cdrom = req.cdrom.unwrap_or(false); + let read_only = req.read_only.unwrap_or(false); + + controller.connect_image(&image, cdrom, read_only).await?; + } + MsdMode::Drive => { + controller.connect_drive().await?; + } + MsdMode::None => { + return Err(AppError::BadRequest("Invalid mode: none".to_string())); + } + } + + Ok(Json(LoginResponse { + success: true, + message: Some("MSD connected".to_string()), + })) +} + +/// Disconnect MSD +#[cfg(unix)] +pub async fn msd_disconnect(State(state): State>) -> Result> { + let mut msd_guard = state.msd.write().await; + let controller = msd_guard + .as_mut() + .ok_or_else(|| AppError::Internal("MSD not initialized".to_string()))?; + + controller.disconnect().await?; + + Ok(Json(LoginResponse { + success: true, + message: Some("MSD disconnected".to_string()), + })) +} + +/// Get drive info +#[cfg(unix)] +pub async fn msd_drive_info(State(state): State>) -> Result> { + let config = state.config.get(); + let drive_path = config.msd.drive_path(); + let drive = VentoyDrive::new(drive_path); + + if !drive.exists() { + return Err(AppError::NotFound("Drive not initialized".to_string())); + } + + let info = drive.info().await?; + Ok(Json(info)) +} + +/// Initialize Ventoy drive +#[cfg(unix)] +pub async fn msd_drive_init( + State(state): State>, + Json(req): Json, +) -> Result> { + let config = state.config.get(); + let drive_path = config.msd.drive_path(); + let drive = VentoyDrive::new(drive_path); + + let info = drive.init(req.size_mb).await?; + Ok(Json(info)) +} + +/// Delete virtual drive +#[cfg(unix)] +pub async fn msd_drive_delete(State(state): State>) -> Result> { + let config = state.config.get(); + + // Check if drive is currently connected + let msd_guard = state.msd.write().await; + if let Some(controller) = msd_guard.as_ref() { + let msd_state = controller.state().await; + if msd_state.connected && msd_state.mode == crate::msd::types::MsdMode::Drive { + return Err(AppError::BadRequest( + "Cannot delete drive while connected. Disconnect first.".to_string(), + )); + } + } + drop(msd_guard); + + // Delete the drive file + let drive_path = config.msd.drive_path(); + if drive_path.exists() { + std::fs::remove_file(&drive_path) + .map_err(|e| AppError::Internal(format!("Failed to delete drive file: {}", e)))?; + } + + Ok(Json(LoginResponse { + success: true, + message: Some("Virtual drive deleted".to_string()), + })) +} + +/// List drive files +#[cfg(unix)] +pub async fn msd_drive_files( + State(state): State>, + Query(params): Query>, +) -> Result>> { + let config = state.config.get(); + let drive_path = config.msd.drive_path(); + let drive = VentoyDrive::new(drive_path); + + let dir_path = params.get("path").map(|s| s.as_str()).unwrap_or("/"); + let files = drive.list_files(dir_path).await?; + Ok(Json(files)) +} + +/// Upload file to drive (streaming - memory efficient for large files) +#[cfg(unix)] +pub async fn msd_drive_upload( + State(state): State>, + Query(params): Query>, + mut multipart: Multipart, +) -> Result> { + let config = state.config.get(); + let drive_path = config.msd.drive_path(); + let drive = VentoyDrive::new(drive_path); + + let target_dir = params.get("path").map(|s| s.as_str()).unwrap_or("/"); + + while let Some(field) = multipart + .next_field() + .await + .map_err(|e| AppError::Internal(format!("Multipart error: {}", e)))? + { + let name = field.name().unwrap_or("file").to_string(); + if name == "file" { + let filename = field + .file_name() + .ok_or_else(|| AppError::BadRequest("Missing filename".to_string()))? + .to_string(); + + let file_path = if target_dir == "/" { + format!("/{}", filename) + } else { + format!("{}/{}", target_dir.trim_end_matches('/'), filename) + }; + + // Use streaming upload - chunks are written directly to disk + // This avoids loading the entire file into memory + drive + .write_file_from_multipart_field(&file_path, field) + .await?; + + return Ok(Json(LoginResponse { + success: true, + message: Some(format!("File uploaded: {}", file_path)), + })); + } + } + + Err(AppError::BadRequest("No file provided".to_string())) +} + +/// Download file from drive (streaming for large files) +#[cfg(unix)] +pub async fn msd_drive_download( + State(state): State>, + AxumPath(file_path): AxumPath, +) -> Result { + let config = state.config.get(); + let drive_path = config.msd.drive_path(); + let drive = VentoyDrive::new(drive_path); + + // Get file stream (returns file size and channel receiver) + let (file_size, mut rx) = drive.read_file_stream(&file_path).await?; + + // Extract filename for Content-Disposition + let filename = file_path.split('/').next_back().unwrap_or("download"); + + // Create a stream from the channel receiver + let body_stream = async_stream::stream! { + while let Some(chunk) = rx.recv().await { + yield chunk; + } + }; + + Ok(Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header(header::CONTENT_LENGTH, file_size) + .header( + header::CONTENT_DISPOSITION, + format!("attachment; filename=\"{}\"", filename), + ) + .body(Body::from_stream(body_stream)) + .unwrap()) +} + +/// Delete file from drive +#[cfg(unix)] +pub async fn msd_drive_file_delete( + State(state): State>, + AxumPath(file_path): AxumPath, +) -> Result> { + let config = state.config.get(); + let drive_path = config.msd.drive_path(); + let drive = VentoyDrive::new(drive_path); + + drive.delete(&file_path).await?; + + Ok(Json(LoginResponse { + success: true, + message: Some(format!("Deleted: {}", file_path)), + })) +} + +/// Create directory in drive +#[cfg(unix)] +pub async fn msd_drive_mkdir( + State(state): State>, + AxumPath(dir_path): AxumPath, +) -> Result> { + let config = state.config.get(); + let drive_path = config.msd.drive_path(); + let drive = VentoyDrive::new(drive_path); + + drive.mkdir(&dir_path).await?; + + Ok(Json(LoginResponse { + success: true, + message: Some(format!("Directory created: {}", dir_path)), + })) +} diff --git a/src/web/handlers/setup.rs b/src/web/handlers/setup.rs new file mode 100644 index 00000000..f2f95526 --- /dev/null +++ b/src/web/handlers/setup.rs @@ -0,0 +1,261 @@ +use super::*; + +#[derive(Serialize)] +pub struct SetupStatus { + pub initialized: bool, + pub needs_setup: bool, + pub platform: PlatformCapabilities, +} + +pub async fn setup_status(State(state): State>) -> Json { + let initialized = state.config.is_initialized(); + Json(SetupStatus { + initialized, + needs_setup: !initialized, + platform: PlatformCapabilities::current(), + }) +} + +#[derive(Deserialize)] +pub struct SetupRequest { + // Account settings + pub username: String, + pub password: String, + // Video settings + pub video_device: Option, + pub video_format: Option, + pub video_width: Option, + pub video_height: Option, + pub video_fps: Option, + // Audio settings + pub audio_device: Option, + // HID settings + pub hid_backend: Option, + pub hid_ch9329_port: Option, + pub hid_ch9329_baudrate: Option, + pub hid_otg_udc: Option, + pub hid_otg_profile: Option, + pub hid_otg_endpoint_budget: Option, + pub hid_otg_keyboard_leds: Option, + pub msd_enabled: Option, + // Extension settings + pub ttyd_enabled: Option, + pub rustdesk_enabled: Option, +} + +pub async fn setup_init( + State(state): State>, + Json(req): Json, +) -> Result> { + // Check if already initialized + if state.config.is_initialized() { + return Err(AppError::BadRequest("Already initialized".to_string())); + } + + // Validate username + if req.username.len() < 2 { + return Err(AppError::BadRequest( + "Username must be at least 2 characters".to_string(), + )); + } + + // Validate password + if req.password.len() < 4 { + return Err(AppError::BadRequest( + "Password must be at least 4 characters".to_string(), + )); + } + + // Create single system user + state + .users + .create_first_user(&req.username, &req.password) + .await?; + + // Update config + state + .config + .update(|config| { + config.initialized = true; + + // Video settings + if let Some(device) = req.video_device.clone() { + config.video.device = Some(device); + } + if let Some(format) = req.video_format.clone() { + config.video.format = Some(format); + } + if let Some(width) = req.video_width { + config.video.width = width; + } + if let Some(height) = req.video_height { + config.video.height = height; + } + if let Some(fps) = req.video_fps { + config.video.fps = fps; + } + + // Audio settings + if let Some(device) = req.audio_device.clone() { + config.audio.device = device; + config.audio.enabled = true; + } + + // HID settings + if let Some(backend) = req.hid_backend.clone() { + config.hid.backend = match backend.as_str() { + "otg" => crate::config::HidBackend::Otg, + "ch9329" => crate::config::HidBackend::Ch9329, + _ => crate::config::HidBackend::None, + }; + } + if let Some(port) = req.hid_ch9329_port.clone() { + config.hid.ch9329_port = port; + } + if let Some(baudrate) = req.hid_ch9329_baudrate { + config.hid.ch9329_baudrate = baudrate; + } + if let Some(udc) = req.hid_otg_udc.clone() { + config.hid.otg_udc = Some(udc); + } + if let Some(profile) = req.hid_otg_profile.clone() { + if let Some(parsed) = crate::config::OtgHidProfile::from_legacy_str(&profile) { + config.hid.otg_profile = parsed; + } + } + if let Some(budget) = req.hid_otg_endpoint_budget { + config.hid.otg_endpoint_budget = budget; + } + if let Some(enabled) = req.hid_otg_keyboard_leds { + config.hid.otg_keyboard_leds = enabled; + } + if let Some(enabled) = req.msd_enabled { + config.msd.enabled = enabled; + } + + // Extension settings + if let Some(enabled) = req.ttyd_enabled { + config.extensions.ttyd.enabled = enabled; + } + if let Some(enabled) = req.rustdesk_enabled { + config.rustdesk.enabled = enabled; + } + }) + .await?; + + // Get updated config for HID reload + let new_config = state.config.get(); + + #[cfg(unix)] + { + if let Err(e) = state + .otg_service + .apply_config(&new_config.hid, &new_config.msd) + .await + { + tracing::warn!("Failed to apply OTG config during setup: {}", e); + } + } + + tracing::info!( + "Extension config after save: ttyd.enabled={}, rustdesk.enabled={}", + new_config.extensions.ttyd.enabled, + new_config.rustdesk.enabled + ); + + // Initialize HID backend with new config + let new_hid_backend = match new_config.hid.backend { + crate::config::HidBackend::Otg => crate::hid::HidBackendType::Otg, + crate::config::HidBackend::Ch9329 => crate::hid::HidBackendType::Ch9329 { + port: new_config.hid.ch9329_port.clone(), + baud_rate: new_config.hid.ch9329_baudrate, + }, + crate::config::HidBackend::None => crate::hid::HidBackendType::None, + }; + + // Reload HID backend + if let Err(e) = state.hid.reload(new_hid_backend).await { + tracing::warn!("Failed to initialize HID backend during setup: {}", e); + // Don't fail setup, just warn + } else { + tracing::info!("HID backend initialized: {:?}", new_config.hid.backend); + } + + // Start extensions if enabled + if new_config.extensions.ttyd.enabled { + if let Err(e) = state + .extensions + .start(crate::extensions::ExtensionId::Ttyd, &new_config.extensions) + .await + { + tracing::warn!("Failed to start ttyd during setup: {}", e); + } else { + tracing::info!("ttyd started during setup"); + } + } + + // Start RustDesk if enabled + if new_config.rustdesk.enabled { + let empty_config = crate::rustdesk::config::RustDeskConfig::default(); + if let Err(e) = config::apply::apply_rustdesk_config( + &state, + &empty_config, + &new_config.rustdesk, + ConfigApplyOptions::default(), + ) + .await + { + tracing::warn!("Failed to start RustDesk during setup: {}", e); + } else { + tracing::info!("RustDesk started during setup"); + } + } + + // Start RTSP if enabled + if new_config.rtsp.enabled { + let empty_config = crate::config::RtspConfig::default(); + if let Err(e) = config::apply::apply_rtsp_config( + &state, + &empty_config, + &new_config.rtsp, + ConfigApplyOptions::default(), + ) + .await + { + tracing::warn!("Failed to start RTSP during setup: {}", e); + } else { + tracing::info!("RTSP started during setup"); + } + } + + // Start audio streaming if audio device was selected during setup + if new_config.audio.enabled { + let audio_config = crate::audio::AudioControllerConfig { + enabled: true, + device: new_config.audio.device.clone(), + quality: new_config + .audio + .quality + .parse::()?, + }; + if let Err(e) = state.audio.update_config(audio_config).await { + tracing::warn!("Failed to start audio during setup: {}", e); + } else { + tracing::info!( + "Audio started during setup: device={}", + new_config.audio.device + ); + } + // Also enable WebRTC audio + if let Err(e) = state.stream_manager.set_webrtc_audio_enabled(true).await { + tracing::warn!("Failed to enable WebRTC audio during setup: {}", e); + } + } + + tracing::info!("System initialized successfully"); + + Ok(Json(LoginResponse { + success: true, + message: Some("Setup completed".to_string()), + })) +} diff --git a/src/web/handlers/stream.rs b/src/web/handlers/stream.rs new file mode 100644 index 00000000..f7495058 --- /dev/null +++ b/src/web/handlers/stream.rs @@ -0,0 +1,626 @@ +use super::*; + +use crate::video::streamer::StreamerStats; +use axum::{ + body::Body, + http::{header, StatusCode}, + response::{IntoResponse, Response}, +}; + +fn stream_mode_label(mode: StreamMode, codec: crate::video::codec::VideoCodecType) -> &'static str { + match mode { + StreamMode::Mjpeg => "mjpeg", + StreamMode::WebRTC => codec_to_id(codec), + } +} + +/// Get stream state +pub async fn stream_state(State(state): State>) -> Json { + Json(state.stream_manager.stats().await) +} + +/// Start streaming +pub async fn stream_start(State(state): State>) -> Result> { + state.stream_manager.start().await?; + Ok(Json(LoginResponse { + success: true, + message: Some("Streaming started".to_string()), + })) +} + +/// Stop streaming +pub async fn stream_stop(State(state): State>) -> Result> { + state.stream_manager.stop().await?; + Ok(Json(LoginResponse { + success: true, + message: Some("Streaming stopped".to_string()), + })) +} + +/// Stream mode request +#[derive(Deserialize)] +pub struct SetStreamModeRequest { + /// Target mode: "mjpeg" or "webrtc" + pub mode: String, +} + +/// Stream mode response +#[derive(Serialize)] +pub struct StreamModeResponse { + pub success: bool, + pub mode: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub transition_id: Option, + pub switching: bool, + pub message: Option, +} + +/// Get current stream mode +pub async fn stream_mode_get(State(state): State>) -> Json { + let mode = state.stream_manager.current_mode().await; + let codec = state.stream_manager.current_video_codec().await; + let mode_str = stream_mode_label(mode, codec).to_string(); + + Json(StreamModeResponse { + success: true, + mode: mode_str, + transition_id: state.stream_manager.current_transition_id().await, + switching: state.stream_manager.is_switching(), + message: None, + }) +} + +/// Set stream mode (switch between MJPEG and WebRTC) +pub async fn stream_mode_set( + State(state): State>, + Json(req): Json, +) -> Result> { + use crate::video::codec::VideoCodecType; + + let constraints = state.stream_manager.codec_constraints().await; + + let mode_lower = req.mode.to_lowercase(); + let (new_mode, video_codec) = match mode_lower.as_str() { + "mjpeg" => (StreamMode::Mjpeg, None), + "webrtc" | "h264" => (StreamMode::WebRTC, Some(VideoCodecType::H264)), + "h265" => (StreamMode::WebRTC, Some(VideoCodecType::H265)), + "vp8" => (StreamMode::WebRTC, Some(VideoCodecType::VP8)), + "vp9" => (StreamMode::WebRTC, Some(VideoCodecType::VP9)), + _ => { + return Err(AppError::BadRequest(format!( + "Invalid mode '{}'. Valid modes: mjpeg, h264, h265, vp8, vp9", + req.mode + ))); + } + }; + + if new_mode == StreamMode::Mjpeg && !constraints.is_mjpeg_allowed() { + return Err(AppError::BadRequest(format!( + "Codec 'mjpeg' is not allowed: {}", + constraints.reason + ))); + } + + if let Some(codec) = video_codec { + if !constraints.is_webrtc_codec_allowed(codec) { + return Err(AppError::BadRequest(format!( + "Codec '{}' is not allowed: {}", + codec_to_id(codec), + constraints.reason + ))); + } + } + + let requested_mode_str = match (&new_mode, &video_codec) { + (StreamMode::Mjpeg, _) => "mjpeg", + (StreamMode::WebRTC, Some(VideoCodecType::H264)) => "h264", + (StreamMode::WebRTC, Some(VideoCodecType::H265)) => "h265", + (StreamMode::WebRTC, Some(VideoCodecType::VP8)) => "vp8", + (StreamMode::WebRTC, Some(VideoCodecType::VP9)) => "vp9", + (StreamMode::WebRTC, None) => "webrtc", + }; + + // Detect codec-only switch: already in WebRTC mode, just changing codec. + // switch_mode_transaction treats this as "no switch needed" since StreamMode + // is still WebRTC, so we handle codec change + event emission here. + let current_mode = state.stream_manager.current_mode().await; + let prev_codec = state.stream_manager.current_video_codec().await; + + let codec_changed = video_codec.is_some_and(|c| c != prev_codec); + let is_codec_only_switch = + current_mode == StreamMode::WebRTC && new_mode == StreamMode::WebRTC && codec_changed; + + if let Some(codec) = video_codec { + info!("Setting WebRTC video codec to {:?}", codec); + if let Err(e) = state.stream_manager.set_video_codec(codec).await { + warn!("Failed to set video codec: {}", e); + } + } + + // For codec-only switch, emit events directly instead of going through + // switch_mode_transaction (which short-circuits when mode is unchanged). + if is_codec_only_switch { + let transition_id = uuid::Uuid::new_v4().to_string(); + + state + .stream_manager + .notify_codec_switch(&transition_id, requested_mode_str, &codec_to_id(prev_codec)) + .await; + + return Ok(Json(StreamModeResponse { + success: true, + mode: requested_mode_str.to_string(), + transition_id: Some(transition_id), + switching: false, + message: Some(format!("Codec switched to {}", requested_mode_str)), + })); + } + + let tx = state + .stream_manager + .switch_mode_transaction(new_mode.clone()) + .await?; + + let active_mode = state.stream_manager.current_mode().await; + let active_codec = state.stream_manager.current_video_codec().await; + let active_mode_str = stream_mode_label(active_mode, active_codec).to_string(); + + let no_switch_needed = !tx.accepted && !tx.switching && tx.transition_id.is_none(); + Ok(Json(StreamModeResponse { + success: tx.accepted || no_switch_needed, + mode: if tx.accepted { + requested_mode_str.to_string() + } else { + active_mode_str + }, + transition_id: tx.transition_id, + switching: tx.switching, + message: Some(if tx.accepted { + format!("Switching to {} mode", requested_mode_str) + } else if tx.switching { + "Mode switch already in progress".to_string() + } else { + "No switch needed".to_string() + }), + })) +} + +/// Available video codec info +#[derive(Serialize)] +pub struct VideoCodecInfo { + /// Codec identifier (mjpeg, h264, h265, vp8, vp9) + pub id: String, + /// Display name + pub name: String, + /// Protocol (http or webrtc) + pub protocol: String, + /// Whether hardware accelerated + pub hardware: bool, + /// Encoder backend name (e.g., "vaapi", "nvenc", "software") + pub backend: Option, + /// Whether this codec is available + pub available: bool, +} + +/// Encoder backend info +#[derive(Serialize)] +pub struct EncoderBackendInfo { + /// Backend identifier (vaapi, nvenc, qsv, amf, rkmpp, v4l2m2m, software) + pub id: String, + /// Display name + pub name: String, + /// Whether this is a hardware backend + pub is_hardware: bool, + /// Supported video formats (h264, h265, vp8, vp9) + pub supported_formats: Vec, +} + +/// Available codecs response +#[derive(Serialize)] +pub struct AvailableCodecsResponse { + pub success: bool, + /// Available encoder backends + pub backends: Vec, + /// Available codecs (for backward compatibility) + pub codecs: Vec, +} + +/// Stream constraints response +#[derive(Serialize)] +pub struct StreamConstraintsResponse { + pub success: bool, + pub allowed_codecs: Vec, + pub locked_codec: Option, + pub disallow_mjpeg: bool, + pub sources: ConstraintSources, + pub reason: String, + pub current_mode: String, +} + +#[derive(Serialize)] +pub struct ConstraintSources { + pub rustdesk: bool, + pub rtsp: bool, +} + +/// Get stream codec constraints derived from enabled services. +pub async fn stream_constraints_get( + State(state): State>, +) -> Json { + let constraints = state.stream_manager.codec_constraints().await; + let current_mode = state.stream_manager.current_mode().await; + let current_codec = state.stream_manager.current_video_codec().await; + let current_mode = stream_mode_label(current_mode, current_codec).to_string(); + + Json(StreamConstraintsResponse { + success: true, + allowed_codecs: constraints + .allowed_codecs_for_api() + .into_iter() + .map(str::to_string) + .collect(), + locked_codec: constraints + .locked_codec + .map(codec_to_id) + .map(str::to_string), + disallow_mjpeg: !constraints.allow_mjpeg, + sources: ConstraintSources { + rustdesk: constraints.rustdesk_enabled, + rtsp: constraints.rtsp_enabled, + }, + reason: constraints.reason, + current_mode, + }) +} + +/// Set bitrate request +#[derive(Deserialize)] +pub struct SetBitrateRequest { + pub bitrate_preset: BitratePreset, +} + +/// Set stream bitrate (real-time adjustment) +pub async fn stream_set_bitrate( + State(state): State>, + Json(req): Json, +) -> Result> { + // Update config + state + .config + .update(|config| { + config.stream.bitrate_preset = req.bitrate_preset; + }) + .await?; + + // Apply to WebRTC streamer (real-time adjustment) + if let Err(e) = state + .stream_manager + .set_bitrate_preset(req.bitrate_preset) + .await + { + warn!("Failed to set bitrate dynamically: {}", e); + // Don't fail the request - config is saved, will apply on next connection + } else { + info!("Bitrate updated to {}", req.bitrate_preset); + } + + Ok(Json(LoginResponse { + success: true, + message: Some(format!("Bitrate set to {}", req.bitrate_preset)), + })) +} + +/// Get available video codecs +pub async fn stream_codecs_list() -> Json { + use crate::video::codec::registry::{EncoderRegistry, VideoEncoderType}; + + let registry = EncoderRegistry::global(); + + // Build backends list + let mut backends = Vec::new(); + for backend in registry.available_backends() { + let formats = registry.formats_for_backend(backend); + let format_ids: Vec = formats + .iter() + .copied() + .map(crate::video::codec_constraints::encoder_codec_to_id) + .map(String::from) + .collect(); + + backends.push(EncoderBackendInfo { + id: format!("{:?}", backend).to_lowercase(), + name: backend.display_name().to_string(), + is_hardware: backend.is_hardware(), + supported_formats: format_ids, + }); + } + + // Build codecs list (for backward compatibility) + let mut codecs = Vec::new(); + + // MJPEG is always available (HTTP streaming) + codecs.push(VideoCodecInfo { + id: "mjpeg".to_string(), + name: "MJPEG / HTTP".to_string(), + protocol: "http".to_string(), + hardware: false, + backend: Some("software".to_string()), + available: true, + }); + + // Check H264 availability (supports software fallback) + let h264_encoder = registry.best_available_encoder(VideoEncoderType::H264); + codecs.push(VideoCodecInfo { + id: "h264".to_string(), + name: "H.264 / WebRTC".to_string(), + protocol: "webrtc".to_string(), + hardware: h264_encoder.map(|e| e.is_hardware).unwrap_or(false), + backend: h264_encoder.map(|e| e.backend.to_string()), + available: h264_encoder.is_some(), + }); + + // Check H265 availability (now supports software too) + let h265_encoder = registry.best_available_encoder(VideoEncoderType::H265); + codecs.push(VideoCodecInfo { + id: "h265".to_string(), + name: "H.265 / WebRTC".to_string(), + protocol: "webrtc".to_string(), + hardware: h265_encoder.map(|e| e.is_hardware).unwrap_or(false), + backend: h265_encoder.map(|e| e.backend.to_string()), + available: h265_encoder.is_some(), + }); + + // Check VP8 availability (now supports software too) + let vp8_encoder = registry.best_available_encoder(VideoEncoderType::VP8); + codecs.push(VideoCodecInfo { + id: "vp8".to_string(), + name: "VP8 / WebRTC".to_string(), + protocol: "webrtc".to_string(), + hardware: vp8_encoder.map(|e| e.is_hardware).unwrap_or(false), + backend: vp8_encoder.map(|e| e.backend.to_string()), + available: vp8_encoder.is_some(), + }); + + // Check VP9 availability (now supports software too) + let vp9_encoder = registry.best_available_encoder(VideoEncoderType::VP9); + codecs.push(VideoCodecInfo { + id: "vp9".to_string(), + name: "VP9 / WebRTC".to_string(), + protocol: "webrtc".to_string(), + hardware: vp9_encoder.map(|e| e.is_hardware).unwrap_or(false), + backend: vp9_encoder.map(|e| e.backend.to_string()), + available: vp9_encoder.is_some(), + }); + + Json(AvailableCodecsResponse { + success: true, + backends, + codecs, + }) +} + +/// Run hardware encoder smoke tests across common resolutions/codecs. +pub async fn video_encoder_self_check() -> Json { + let response = tokio::task::spawn_blocking(run_hardware_self_check) + .await + .unwrap_or_else(|_| build_hardware_self_check_runtime_error()); + + Json(response) +} + +/// Query parameters for MJPEG stream +#[derive(Deserialize, Default)] +pub struct MjpegStreamQuery { + /// Optional client ID (if not provided, a random UUID will be generated) + pub client_id: Option, +} + +/// MJPEG stream endpoint +pub async fn mjpeg_stream( + State(state): State>, + Query(query): Query, +) -> impl IntoResponse { + // Check if MJPEG mode is active + if !state.stream_manager.is_mjpeg_enabled().await { + return axum::response::Response::builder() + .status(axum::http::StatusCode::SERVICE_UNAVAILABLE) + .header("Content-Type", "application/json") + .body(axum::body::Body::from( + r#"{"error":"MJPEG mode not active. Current mode is WebRTC."}"#, + )) + .unwrap(); + } + + // Check if config is being changed - reject new connections during config change + if state.stream_manager.is_config_changing() { + return axum::response::Response::builder() + .status(axum::http::StatusCode::SERVICE_UNAVAILABLE) + .header("Content-Type", "application/json") + .body(axum::body::Body::from( + r#"{"error":"Video configuration is being changed. Please retry shortly."}"#, + )) + .unwrap(); + } + + // Ensure stream is started (but not during config change) + if !state.stream_manager.is_streaming().await && !state.stream_manager.is_config_changing() { + if let Err(e) = state.stream_manager.start().await { + tracing::error!("Failed to auto-start stream: {}", e); + } + } + + let handler = state.stream_manager.mjpeg_handler(); + + // Use provided client ID or generate a new one + let client_id = query + .client_id + .filter(|id| !id.is_empty() && id.len() <= 64) // Validate: non-empty, max 64 chars + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + + // Create RAII guard - this will automatically register and unregister the client + let guard = Arc::new(crate::stream::mjpeg::ClientGuard::new( + client_id.clone(), + handler.clone(), + )); + + let (tx, mut rx) = tokio::sync::mpsc::channel::(1); + + let guard_clone = guard.clone(); + let handler_clone = handler.clone(); + tokio::spawn(async move { + let _guard = guard_clone; // Keep guard alive + let mut notify_rx = handler_clone.subscribe(); + let mut last_seq = 0u64; + let mut timeout_count = 0; + + // Send initial frame if available + if let Some(frame) = handler_clone.current_frame() { + if frame.is_valid_jpeg() { + let data = create_mjpeg_part(frame.data()); + // send() blocks until receiver is ready (backpressure) + if tx.send(data).await.is_ok() { + // FPS recording moved to async_stream after yield + last_seq = frame.sequence; + } else { + return; // Receiver closed + } + } + } + + loop { + // Check if stream went offline (e.g., during config change) + if !handler_clone.is_online() { + break; + } + + // Wait for new frame notification with timeout + let result = + tokio::time::timeout(std::time::Duration::from_secs(5), notify_rx.recv()).await; + + match result { + Ok(Ok(())) => { + // Check online status after receiving notification + // set_offline() sends a notification, so we need to check here + if !handler_clone.is_online() { + break; + } + timeout_count = 0; + if let Some(frame) = handler_clone.current_frame() { + // Use != instead of > to handle sequence reset when capturer restarts + // (e.g., after video config change, new capturer starts from seq=0) + if frame.sequence != last_seq && frame.is_valid_jpeg() { + let data = create_mjpeg_part(frame.data()); + if tx.send(data).await.is_ok() { + last_seq = frame.sequence; + } else { + break; + } + } + } + } + Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => { + break; + } + Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => { + // Receiver was too slow - skip missed frames and jump to latest + if !handler_clone.is_online() { + break; + } + timeout_count = 0; + + if let Some(frame) = handler_clone.current_frame() { + if frame.is_valid_jpeg() { + // Send current frame immediately and reset sequence tracking + if tx.send(create_mjpeg_part(frame.data())).await.is_ok() { + last_seq = frame.sequence; + } else { + break; + } + } + } + } + Err(_) => { + // Timeout - check if still online + timeout_count += 1; + if timeout_count > 6 || !handler_clone.is_online() { + break; + } + // Send last frame again to keep connection alive + let Some(frame) = handler_clone.current_frame() else { + continue; + }; + + if frame.is_valid_jpeg() + && tx.send(create_mjpeg_part(frame.data())).await.is_err() + { + break; + } + } + } + } + }); + + // Create stream that receives from channel and forwards to the HTTP + // body. Record FPS *before* yield so the final frame of a session + // still gets counted (after-yield code in async_stream! only runs + // when the consumer polls again, which never happens for the last + // frame of a closing connection). + let handler_for_stream = handler.clone(); + let guard_for_stream = guard.clone(); + let body_stream = async_stream::stream! { + while let Some(data) = rx.recv().await { + handler_for_stream.record_frame_sent(guard_for_stream.id()); + yield Ok::(data); + } + }; + + Response::builder() + .status(StatusCode::OK) + .header( + header::CONTENT_TYPE, + "multipart/x-mixed-replace; boundary=frame", + ) + .header(header::CACHE_CONTROL, "no-cache, no-store, must-revalidate") + .header(header::PRAGMA, "no-cache") + .header(header::EXPIRES, "0") + .header(header::CONNECTION, "keep-alive") + .body(Body::from_stream(body_stream)) + .unwrap() +} + +/// Single JPEG snapshot +pub async fn snapshot(State(state): State>) -> impl IntoResponse { + let handler = state.stream_manager.mjpeg_handler(); + + match handler.current_frame() { + Some(frame) if frame.is_valid_jpeg() => Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "image/jpeg") + .header(header::CACHE_CONTROL, "no-cache") + .body(Body::from(frame.data_bytes())) + .unwrap(), + _ => Response::builder() + .status(StatusCode::SERVICE_UNAVAILABLE) + .body(Body::from("No frame available")) + .unwrap(), + } +} + +/// Create MJPEG multipart frame bytes +fn create_mjpeg_part(jpeg_data: &[u8]) -> bytes::Bytes { + use bytes::{BufMut, BytesMut}; + + let mut buf = BytesMut::with_capacity(128 + jpeg_data.len()); + + // Write boundary and headers + buf.put_slice(b"--frame\r\n"); + buf.put_slice(b"Content-Type: image/jpeg\r\n"); + buf.put_slice(format!("Content-Length: {}\r\n", jpeg_data.len()).as_bytes()); + buf.put_slice(b"\r\n"); + + // Write JPEG data + buf.put_slice(jpeg_data); + buf.put_slice(b"\r\n"); + + buf.freeze() +} diff --git a/src/web/handlers/system.rs b/src/web/handlers/system.rs new file mode 100644 index 00000000..77a18505 --- /dev/null +++ b/src/web/handlers/system.rs @@ -0,0 +1,113 @@ +use super::*; + +/// Health check response +#[derive(Serialize)] +pub struct HealthResponse { + pub status: &'static str, + pub version: &'static str, +} + +pub async fn health_check() -> Json { + Json(HealthResponse { + status: "ok", + version: env!("CARGO_PKG_VERSION"), + }) +} + +/// System info response +#[derive(Serialize)] +pub struct SystemInfo { + pub version: &'static str, + pub build_date: &'static str, + pub initialized: bool, + pub platform: PlatformCapabilities, + pub capabilities: Capabilities, + #[serde(skip_serializing_if = "Option::is_none")] + pub disk_space: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub device_info: Option, +} + +#[derive(Serialize)] +pub struct Capabilities { + pub video: CapabilityInfo, + pub hid: CapabilityInfo, + pub msd: CapabilityInfo, + pub atx: CapabilityInfo, + pub audio: CapabilityInfo, + pub rustdesk: CapabilityInfo, +} + +#[derive(Serialize)] +pub struct CapabilityInfo { + pub available: bool, + pub backend: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, +} + +pub async fn system_info(State(state): State>) -> Json { + let config = state.config.get(); + let platform = PlatformCapabilities::current(); + + // Get disk space information for MSD base directory + let disk_space = { + let msd_dir = config.msd.msd_dir_path(); + if msd_dir.as_os_str().is_empty() { + None + } else { + get_disk_space(&msd_dir).ok() + } + }; + + // Get device information (hostname, CPU, memory, network) + let device_info = Some(get_device_info()); + + Json(SystemInfo { + version: env!("CARGO_PKG_VERSION"), + build_date: env!("BUILD_DATE"), + initialized: config.initialized, + platform: platform.clone(), + capabilities: Capabilities { + video: CapabilityInfo { + available: config.video.device.is_some(), + backend: config.video.device.clone(), + reason: None, + }, + hid: CapabilityInfo { + available: config.hid.backend != crate::config::HidBackend::None, + backend: Some(format!("{:?}", config.hid.backend)), + reason: None, + }, + msd: CapabilityInfo { + available: config.msd.enabled && platform.msd.available, + backend: None, + reason: platform.msd.reason.clone(), + }, + atx: CapabilityInfo { + available: config.atx.enabled, + backend: if config.atx.enabled { + Some(format!( + "power: {:?}, reset: {:?}", + config.atx.power.driver, config.atx.reset.driver + )) + } else { + None + }, + reason: None, + }, + audio: CapabilityInfo { + available: config.audio.enabled && platform.audio.available, + backend: Some(config.audio.device.clone()), + reason: platform.audio.reason.clone(), + }, + rustdesk: CapabilityInfo { + available: config.rustdesk.enabled && platform.rustdesk.available, + backend: platform.rustdesk.selected_backend.clone(), + reason: platform.rustdesk.reason.clone(), + }, + }, + disk_space, + device_info, + }) +} diff --git a/src/web/handlers/terminal.rs b/src/web/handlers/terminal.rs index 37185286..39868539 100644 --- a/src/web/handlers/terminal.rs +++ b/src/web/handlers/terminal.rs @@ -10,13 +10,19 @@ use axum::{ use futures::{SinkExt, StreamExt}; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; +#[cfg(windows)] +use tokio::net::TcpStream; +#[cfg(unix)] use tokio::net::UnixStream; use tokio_tungstenite::tungstenite::{ client::IntoClientRequest, http::HeaderValue, Message as TungsteniteMessage, }; use crate::error::AppError; +#[cfg(unix)] use crate::extensions::TTYD_SOCKET_PATH; +#[cfg(windows)] +use crate::extensions::TTYD_TCP_ADDR; use crate::state::AppState; pub async fn terminal_ws( @@ -35,10 +41,10 @@ pub async fn terminal_ws( } async fn handle_terminal_websocket(client_ws: WebSocket, query_string: String) { - let unix_stream = match UnixStream::connect(TTYD_SOCKET_PATH).await { + let ttyd_stream = match connect_ttyd().await { Ok(s) => s, Err(e) => { - tracing::error!("Failed to connect to ttyd socket: {}", e); + tracing::error!("Failed to connect to ttyd: {}", e); return; } }; @@ -56,7 +62,7 @@ async fn handle_terminal_websocket(client_ws: WebSocket, query_string: String) { .headers_mut() .insert("Sec-WebSocket-Protocol", HeaderValue::from_static("tty")); - let ws_stream = match tokio_tungstenite::client_async(request, unix_stream).await { + let ws_stream = match tokio_tungstenite::client_async(request, ttyd_stream).await { Ok((ws, _)) => ws, Err(e) => { tracing::error!("Failed to establish WebSocket with ttyd: {}", e); @@ -121,7 +127,7 @@ pub async fn terminal_proxy( ) -> Result { let path_str = path.map(|p| p.0).unwrap_or_default(); - let mut unix_stream = UnixStream::connect(TTYD_SOCKET_PATH) + let mut ttyd_stream = connect_ttyd() .await .map_err(|e| AppError::ServiceUnavailable(format!("ttyd not running: {}", e)))?; @@ -155,13 +161,13 @@ pub async fn terminal_proxy( method, uri_path, headers_str ); - unix_stream + ttyd_stream .write_all(http_request.as_bytes()) .await .map_err(|e| AppError::Internal(format!("Failed to send request: {}", e)))?; let mut response_buf = Vec::new(); - unix_stream + ttyd_stream .read_to_end(&mut response_buf) .await .map_err(|e| AppError::Internal(format!("Failed to read response: {}", e)))?; @@ -211,6 +217,16 @@ pub async fn terminal_proxy( .map_err(|e| AppError::Internal(format!("Failed to build response: {}", e))) } +#[cfg(unix)] +async fn connect_ttyd() -> std::io::Result { + UnixStream::connect(TTYD_SOCKET_PATH).await +} + +#[cfg(windows)] +async fn connect_ttyd() -> std::io::Result { + TcpStream::connect(TTYD_TCP_ADDR).await +} + pub async fn terminal_index( State(state): State>, req: Request, diff --git a/src/web/handlers/update_api.rs b/src/web/handlers/update_api.rs new file mode 100644 index 00000000..3750f973 --- /dev/null +++ b/src/web/handlers/update_api.rs @@ -0,0 +1,31 @@ +use super::*; + +#[derive(Deserialize)] +pub struct UpdateOverviewQuery { + pub channel: Option, +} + +pub async fn update_overview( + State(state): State>, + axum::extract::Query(query): axum::extract::Query, +) -> Result> { + let channel = query.channel.unwrap_or(UpdateChannel::Stable); + let response = state.update.overview(channel).await?; + Ok(Json(response)) +} + +pub async fn update_upgrade( + State(state): State>, + Json(req): Json, +) -> Result> { + state.update.start_upgrade(req, state.shutdown_tx.clone())?; + + Ok(Json(LoginResponse { + success: true, + message: Some("Upgrade started".to_string()), + })) +} + +pub async fn update_status(State(state): State>) -> Json { + Json(state.update.status().await) +} diff --git a/src/web/handlers/webrtc.rs b/src/web/handlers/webrtc.rs new file mode 100644 index 00000000..216eb27f --- /dev/null +++ b/src/web/handlers/webrtc.rs @@ -0,0 +1,194 @@ +use super::*; + +use crate::webrtc::signaling::{AnswerResponse, IceCandidateRequest, OfferRequest}; + +/// Create WebRTC session +#[derive(Serialize)] +pub struct CreateSessionResponse { + pub session_id: String, +} + +pub async fn webrtc_create_session( + State(state): State>, +) -> Result> { + // Check if WebRTC mode is active + if !state.stream_manager.is_webrtc_enabled().await { + return Err(AppError::ServiceUnavailable( + "WebRTC mode not active. Current mode is MJPEG.".to_string(), + )); + } + + let session_id = state.webrtc.create_session().await?; + Ok(Json(CreateSessionResponse { session_id })) +} + +/// Handle WebRTC offer +pub async fn webrtc_offer( + State(state): State>, + Json(req): Json, +) -> Result> { + // Check if WebRTC mode is active + if !state.stream_manager.is_webrtc_enabled().await { + return Err(AppError::ServiceUnavailable( + "WebRTC mode not active. Current mode is MJPEG.".to_string(), + )); + } + + // Backward compatibility: `client_id` is treated as an existing session_id hint. + // New clients should not pass it; each offer creates a fresh session. + let webrtc = &state.webrtc; + let session_id = if let Some(client_id) = &req.client_id { + // Reuse only when it matches an active session ID. + if webrtc.get_session(client_id).await.is_some() { + client_id.clone() + } else { + webrtc.create_session().await? + } + } else { + webrtc.create_session().await? + }; + + // Handle offer + let offer = crate::webrtc::SdpOffer::new(req.sdp); + let answer = webrtc.handle_offer(&session_id, offer).await?; + + Ok(Json(AnswerResponse::new( + answer.sdp, + session_id, + answer.ice_candidates.unwrap_or_default(), + ))) +} + +/// Add ICE candidate +pub async fn webrtc_ice_candidate( + State(state): State>, + Json(req): Json, +) -> Result> { + state + .webrtc + .add_ice_candidate(&req.session_id, req.candidate) + .await?; + + Ok(Json(LoginResponse { + success: true, + message: None, + })) +} + +/// Get WebRTC session info +#[derive(Serialize)] +pub struct WebRtcSessionInfo { + pub session_id: String, + pub state: String, +} + +#[derive(Serialize)] +pub struct WebRtcStatus { + pub session_count: usize, + pub sessions: Vec, +} + +pub async fn webrtc_status(State(state): State>) -> Json { + let sessions = state.webrtc.list_sessions().await; + Json(WebRtcStatus { + session_count: sessions.len(), + sessions: sessions + .into_iter() + .map(|s| WebRtcSessionInfo { + session_id: s.session_id, + state: s.state, + }) + .collect(), + }) +} + +/// Close WebRTC session +#[derive(Deserialize)] +pub struct CloseSessionRequest { + pub session_id: String, +} + +pub async fn webrtc_close_session( + State(state): State>, + Json(req): Json, +) -> Result> { + state.webrtc.close_session(&req.session_id).await?; + + Ok(Json(LoginResponse { + success: true, + message: Some("Session closed".to_string()), + })) +} + +/// ICE servers configuration for WebRTC +#[derive(Serialize)] +pub struct IceServersResponse { + pub ice_servers: Vec, + pub mdns_mode: String, +} + +#[derive(Serialize)] +pub struct IceServerInfo { + pub urls: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub username: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub credential: Option, +} + +fn non_empty_config_value(value: &Option) -> Option<&str> { + value.as_deref().filter(|value| !value.is_empty()) +} + +/// Get ICE servers configuration for client-side WebRTC +/// Returns user-configured servers, or Google STUN as fallback if none configured +pub async fn webrtc_ice_servers(State(state): State>) -> Json { + use crate::webrtc::config::public_ice; + use crate::webrtc::mdns::{mdns_mode, mdns_mode_label}; + + let config = state.config.get(); + let mut ice_servers = Vec::new(); + + // Check if user has configured custom ICE servers + let stun_server = non_empty_config_value(&config.stream.stun_server); + let turn_server = non_empty_config_value(&config.stream.turn_server); + + if stun_server.is_some() || turn_server.is_some() { + // Use user-configured ICE servers + if let Some(stun) = stun_server { + ice_servers.push(IceServerInfo { + urls: vec![stun.to_string()], + username: None, + credential: None, + }); + } + + if let Some(turn) = turn_server { + let username = config.stream.turn_username.clone(); + let credential = config.stream.turn_password.clone(); + if username.is_some() && credential.is_some() { + ice_servers.push(IceServerInfo { + urls: vec![turn.to_string()], + username, + credential, + }); + } + } + } else { + // No custom servers — baked-in public STUN + ice_servers.push(IceServerInfo { + urls: vec![public_ice::stun_server().to_string()], + username: None, + credential: None, + }); + // Note: TURN servers are not provided - users must configure their own + } + + let mdns_mode = mdns_mode(); + let mdns_mode = mdns_mode_label(mdns_mode).to_string(); + + Json(IceServersResponse { + ice_servers, + mdns_mode, + }) +} diff --git a/src/web/routes.rs b/src/web/routes.rs index 79171823..80aef025 100644 --- a/src/web/routes.rs +++ b/src/web/routes.rs @@ -1,7 +1,8 @@ +#[cfg(unix)] +use axum::{extract::DefaultBodyLimit, routing::delete}; use axum::{ - extract::DefaultBodyLimit, middleware, - routing::{any, delete, get, patch, post}, + routing::{any, get, patch, post}, Router, }; use std::sync::Arc; @@ -72,7 +73,6 @@ pub fn create_router(state: Arc) -> Router { .route("/webrtc/close", post(handlers::webrtc_close_session)) // HID endpoints .route("/hid/status", get(handlers::hid_status)) - .route("/hid/otg/self-check", get(handlers::hid_otg_self_check)) .route("/hid/reset", post(handlers::hid_reset)) // WebSocket HID endpoint (for MJPEG mode) .route("/ws/hid", any(ws_hid_handler)) @@ -99,8 +99,6 @@ pub fn create_router(state: Arc) -> Router { ) .route("/config/hid", get(handlers::config::get_hid_config)) .route("/config/hid", patch(handlers::config::update_hid_config)) - .route("/config/msd", get(handlers::config::get_msd_config)) - .route("/config/msd", patch(handlers::config::update_msd_config)) .route("/config/atx", get(handlers::config::get_atx_config)) .route("/config/atx", patch(handlers::config::update_atx_config)) .route("/config/audio", get(handlers::config::get_audio_config)) @@ -148,38 +146,15 @@ pub fn create_router(state: Arc) -> Router { .route("/config/auth", patch(handlers::config::update_auth_config)) // Redfish configuration .route("/config/redfish", get(handlers::config::get_redfish_config)) - .route("/config/redfish", patch(handlers::config::update_redfish_config)) + .route( + "/config/redfish", + patch(handlers::config::update_redfish_config), + ) // System control .route("/system/restart", post(handlers::system_restart)) .route("/update/overview", get(handlers::update_overview)) .route("/update/upgrade", post(handlers::update_upgrade)) .route("/update/status", get(handlers::update_status)) - // MSD (Mass Storage Device) endpoints - .route("/msd/status", get(handlers::msd_status)) - .route("/msd/images", get(handlers::msd_images_list)) - .route("/msd/images/download", post(handlers::msd_image_download)) - .route( - "/msd/images/download/cancel", - post(handlers::msd_image_download_cancel), - ) - .route("/msd/images/{id}", get(handlers::msd_image_get)) - .route("/msd/images/{id}", delete(handlers::msd_image_delete)) - .route("/msd/connect", post(handlers::msd_connect)) - .route("/msd/disconnect", post(handlers::msd_disconnect)) - // MSD Virtual Drive endpoints - .route("/msd/drive", get(handlers::msd_drive_info)) - .route("/msd/drive", delete(handlers::msd_drive_delete)) - .route("/msd/drive/init", post(handlers::msd_drive_init)) - .route("/msd/drive/files", get(handlers::msd_drive_files)) - .route( - "/msd/drive/files/{*path}", - get(handlers::msd_drive_download), - ) - .route( - "/msd/drive/files/{*path}", - delete(handlers::msd_drive_file_delete), - ) - .route("/msd/drive/mkdir/{*path}", post(handlers::msd_drive_mkdir)) // ATX (Power Control) endpoints .route("/atx/status", get(handlers::atx_status)) .route("/atx/power", post(handlers::atx_power)) @@ -187,11 +162,6 @@ pub fn create_router(state: Arc) -> Router { .route("/atx/wol/history", get(handlers::atx_wol_history)) // Device discovery endpoints .route("/devices/atx", get(handlers::devices::list_atx_devices)) - .route("/devices/usb", get(handlers::devices::list_usb_devices)) - .route( - "/devices/usb/reset", - post(handlers::devices::reset_usb_device), - ) // Extension management endpoints .route("/extensions", get(handlers::extensions::list_extensions)) .route("/extensions/{id}", get(handlers::extensions::get_extension)) @@ -225,6 +195,43 @@ pub fn create_router(state: Arc) -> Router { .route("/terminal/ws", get(handlers::terminal::terminal_ws)) .route("/terminal/{*path}", get(handlers::terminal::terminal_proxy)); + #[cfg(unix)] + let user_routes = { + user_routes + .route("/hid/otg/self-check", get(handlers::hid_otg_self_check)) + .route("/config/msd", get(handlers::config::get_msd_config)) + .route("/config/msd", patch(handlers::config::update_msd_config)) + .route("/msd/status", get(handlers::msd_status)) + .route("/msd/images", get(handlers::msd_images_list)) + .route("/msd/images/download", post(handlers::msd_image_download)) + .route( + "/msd/images/download/cancel", + post(handlers::msd_image_download_cancel), + ) + .route("/msd/images/{id}", get(handlers::msd_image_get)) + .route("/msd/images/{id}", delete(handlers::msd_image_delete)) + .route("/msd/connect", post(handlers::msd_connect)) + .route("/msd/disconnect", post(handlers::msd_disconnect)) + .route("/msd/drive", get(handlers::msd_drive_info)) + .route("/msd/drive", delete(handlers::msd_drive_delete)) + .route("/msd/drive/init", post(handlers::msd_drive_init)) + .route("/msd/drive/files", get(handlers::msd_drive_files)) + .route( + "/msd/drive/files/{*path}", + get(handlers::msd_drive_download), + ) + .route( + "/msd/drive/files/{*path}", + delete(handlers::msd_drive_file_delete), + ) + .route("/msd/drive/mkdir/{*path}", post(handlers::msd_drive_mkdir)) + .route("/devices/usb", get(handlers::devices::list_usb_devices)) + .route( + "/devices/usb/reset", + post(handlers::devices::reset_usb_device), + ) + }; + // Protected routes (all authenticated users) let protected_routes = user_routes; @@ -237,10 +244,13 @@ pub fn create_router(state: Arc) -> Router { // Large file upload routes (MSD images and drive files) // Use streaming upload to support files larger than available RAM // Disable body limit for streaming uploads - files are written directly to disk + #[cfg(unix)] let upload_routes = Router::new() .route("/msd/images", post(handlers::msd_image_upload)) .route("/msd/drive/files", post(handlers::msd_drive_upload)) .layer(DefaultBodyLimit::disable()); + #[cfg(not(unix))] + let upload_routes = Router::new(); // Combine API routes let api_routes = Router::new() diff --git a/src/webrtc/rtp.rs b/src/webrtc/rtp.rs index 1e847873..7e39dacc 100644 --- a/src/webrtc/rtp.rs +++ b/src/webrtc/rtp.rs @@ -1,4 +1,4 @@ -//! Opus outbound track plus H.264 Annex B helpers (SPS/PPS, keyframe scan). Video RTP lives in [`crate::webrtc::video_track`]. +//! Opus outbound track. Video RTP lives in [`crate::webrtc::video_track`]. use bytes::Bytes; use std::sync::Arc; @@ -65,245 +65,3 @@ impl OpusAudioTrack { }) } } - -/// Strips AUD (9) and filler (12) NALs; some WebRTC stacks dislike AUD. -pub fn strip_aud_nal_units(data: &[u8]) -> Vec { - let mut result = Vec::with_capacity(data.len()); - let mut i = 0; - - while i < data.len() { - let (start_code_pos, start_code_len) = if i + 4 <= data.len() - && data[i] == 0 - && data[i + 1] == 0 - && data[i + 2] == 0 - && data[i + 3] == 1 - { - (i, 4) - } else if i + 3 <= data.len() && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 { - (i, 3) - } else { - i += 1; - continue; - }; - - let nal_start = start_code_pos + start_code_len; - if nal_start >= data.len() { - break; - } - - let nal_type = data[nal_start] & 0x1F; - - let mut nal_end = data.len(); - let mut j = nal_start + 1; - while j + 3 <= data.len() { - if (data[j] == 0 && data[j + 1] == 0 && data[j + 2] == 1) - || (j + 4 <= data.len() - && data[j] == 0 - && data[j + 1] == 0 - && data[j + 2] == 0 - && data[j + 3] == 1) - { - nal_end = j; - break; - } - j += 1; - } - - if nal_type != 9 && nal_type != 12 { - result.extend_from_slice(&data[start_code_pos..nal_end]); - } - - i = nal_end; - } - - if result.is_empty() && !data.is_empty() { - return data.to_vec(); - } - - result -} - -pub fn extract_sps_pps(data: &[u8]) -> (Option>, Option>) { - let mut sps: Option> = None; - let mut pps: Option> = None; - let mut i = 0; - - while i < data.len() { - let start_code_len = if i + 4 <= data.len() - && data[i] == 0 - && data[i + 1] == 0 - && data[i + 2] == 0 - && data[i + 3] == 1 - { - 4 - } else if i + 3 <= data.len() && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 { - 3 - } else { - i += 1; - continue; - }; - - let nal_start = i + start_code_len; - if nal_start >= data.len() { - break; - } - - let nal_type = data[nal_start] & 0x1F; - - let mut nal_end = data.len(); - let mut j = nal_start + 1; - while j + 3 <= data.len() { - if (data[j] == 0 && data[j + 1] == 0 && data[j + 2] == 1) - || (j + 4 <= data.len() - && data[j] == 0 - && data[j + 1] == 0 - && data[j + 2] == 0 - && data[j + 3] == 1) - { - nal_end = j; - break; - } - j += 1; - } - - match nal_type { - 7 => { - sps = Some(data[nal_start..nal_end].to_vec()); - } - 8 => { - pps = Some(data[nal_start..nal_end].to_vec()); - } - _ => {} - } - - i = nal_end; - } - - (sps, pps) -} - -pub fn has_sps_pps(data: &[u8]) -> bool { - let mut has_sps = false; - let mut has_pps = false; - let mut i = 0; - - while i < data.len() { - let start_code_len = if i + 4 <= data.len() - && data[i] == 0 - && data[i + 1] == 0 - && data[i + 2] == 0 - && data[i + 3] == 1 - { - 4 - } else if i + 3 <= data.len() && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 { - 3 - } else { - i += 1; - continue; - }; - - let nal_start = i + start_code_len; - if nal_start >= data.len() { - break; - } - - let nal_type = data[nal_start] & 0x1F; - - match nal_type { - 7 => has_sps = true, - 8 => has_pps = true, - _ => {} - } - - if has_sps && has_pps { - return true; - } - - i = nal_start + 1; - } - - has_sps && has_pps -} - -pub fn is_h264_keyframe(data: &[u8]) -> bool { - let mut i = 0; - while i < data.len() { - if i + 3 < data.len() && data[i] == 0 && data[i + 1] == 0 { - let nal_start = if data[i + 2] == 1 { - i + 3 - } else if i + 4 < data.len() && data[i + 2] == 0 && data[i + 3] == 1 { - i + 4 - } else { - i += 1; - continue; - }; - - if nal_start < data.len() { - let nal_type = data[nal_start] & 0x1F; - if nal_type == 5 { - return true; - } - } - i = nal_start; - } else { - i += 1; - } - } - false -} - -/// `profile-level-id` hex for SDP (`42001f` etc.); expects SPS NAL RBSP without start code. -pub fn parse_profile_level_id_from_sps(sps: &[u8]) -> Option { - if sps.len() < 4 { - return None; - } - - let profile_idc = sps[1]; - let constraint_set_flags = sps[2]; - let level_idc = sps[3]; - - Some(format!( - "{:02x}{:02x}{:02x}", - profile_idc, constraint_set_flags, level_idc - )) -} - -pub fn extract_profile_level_id(data: &[u8]) -> Option { - let (sps, _) = extract_sps_pps(data); - sps.and_then(|sps_data| parse_profile_level_id_from_sps(&sps_data)) -} - -pub mod profiles { - pub const CONSTRAINED_BASELINE_31: &str = "42e01f"; - pub const BASELINE_31: &str = "42001f"; - pub const MAIN_31: &str = "4d001f"; - pub const HIGH_31: &str = "64001f"; - pub const HIGH_40: &str = "640028"; - pub const HIGH_51: &str = "640033"; -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_is_h264_keyframe() { - let idr_frame = vec![0x00, 0x00, 0x00, 0x01, 0x65]; - assert!(is_h264_keyframe(&idr_frame)); - - let idr_frame_3 = vec![0x00, 0x00, 0x01, 0x65]; - assert!(is_h264_keyframe(&idr_frame_3)); - - let p_frame = vec![0x00, 0x00, 0x00, 0x01, 0x41]; - assert!(!is_h264_keyframe(&p_frame)); - - let sps = vec![0x00, 0x00, 0x00, 0x01, 0x67]; - assert!(!is_h264_keyframe(&sps)); - - let multi_nal = vec![ - 0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x01, 0x68, 0xce, - 0x38, 0x80, 0x00, 0x00, 0x00, 0x01, 0x65, 0x88, 0x84, - ]; - assert!(is_h264_keyframe(&multi_nal)); - } -} diff --git a/src/webrtc/universal_session.rs b/src/webrtc/universal_session.rs index 39965a82..1bd9f1b5 100644 --- a/src/webrtc/universal_session.rs +++ b/src/webrtc/universal_session.rs @@ -1,5 +1,6 @@ //! One browser session: negotiated [`RTCPeerConnection`], outbound video/audio, HID DataChannel. +use std::net::IpAddr; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::{watch, Mutex, RwLock}; @@ -33,6 +34,7 @@ use crate::audio::OpusFrame; use crate::error::{AppError, Result}; use crate::hid::datachannel::{parse_hid_message, HidChannelEvent}; use crate::hid::HidController; +use crate::video::codec::h264_bitstream; use crate::video::types::{ BitratePreset, EncodedVideoFrame, PixelFormat, Resolution, VideoEncoderType, }; @@ -40,49 +42,16 @@ use std::sync::atomic::AtomicBool; const MIME_TYPE_H265: &str = "video/H265"; -fn h264_contains_parameter_sets(data: &[u8]) -> bool { - let mut i = 0usize; - while i + 4 <= data.len() { - let sc_len = if i + 4 <= data.len() - && data[i] == 0 - && data[i + 1] == 0 - && data[i + 2] == 0 - && data[i + 3] == 1 - { - 4 - } else if i + 3 <= data.len() && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 { - 3 - } else { - i += 1; - continue; - }; - - let nal_start = i + sc_len; - if nal_start < data.len() { - let nal_type = data[nal_start] & 0x1F; - if nal_type == 7 || nal_type == 8 { - return true; - } +fn is_allowed_ice_ip(ip: IpAddr) -> bool { + match ip { + IpAddr::V4(ip) => !ip.is_link_local(), + IpAddr::V6(ip) => { + !(ip.is_loopback() + || ip.is_unspecified() + || ip.is_unique_local() + || ip.is_unicast_link_local()) } - i = nal_start.saturating_add(1); } - - let mut pos = 0usize; - while pos + 4 <= data.len() { - let nalu_len = - u32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize; - pos += 4; - if nalu_len == 0 || pos + nalu_len > data.len() { - break; - } - let nal_type = data[pos] & 0x1F; - if nal_type == 7 || nal_type == 8 { - return true; - } - pos += nalu_len; - } - - false } #[derive(Debug, Clone)] @@ -94,6 +63,7 @@ pub struct UniversalSessionConfig { pub bitrate_preset: BitratePreset, pub fps: u32, pub audio_enabled: bool, + pub video_sdp_fmtp_line: Option, } impl Default for UniversalSessionConfig { @@ -106,6 +76,7 @@ impl Default for UniversalSessionConfig { bitrate_preset: BitratePreset::Balanced, fps: 30, audio_enabled: false, + video_sdp_fmtp_line: None, } } } @@ -167,6 +138,7 @@ impl UniversalSession { resolution: config.resolution, bitrate_kbps: config.bitrate_preset.bitrate_kbps(), fps: config.fps, + sdp_fmtp_line: config.video_sdp_fmtp_line.clone(), }; let video_track = Arc::new(UniversalVideoTrack::new(track_config)); @@ -257,6 +229,7 @@ impl UniversalSession { .map_err(|e| AppError::VideoError(format!("Failed to register interceptors: {}", e)))?; let mut setting_engine = SettingEngine::default(); + setting_engine.set_ip_filter(Box::new(is_allowed_ice_ip)); let mode = mdns_mode(); setting_engine.set_ice_multicast_dns_mode(mode); if mode == MulticastDnsMode::QueryAndGather { @@ -629,7 +602,7 @@ impl UniversalSession { // before IDR. Keep this frame so browser can decode the next IDR. let forward_h264_parameter_frame = waiting_for_keyframe && expected_codec == VideoEncoderType::H264 - && h264_contains_parameter_sets(encoded_frame.data.as_ref()); + && h264_bitstream::has_sps_pps(encoded_frame.data.as_ref()); let now = Instant::now(); if now.duration_since(last_keyframe_request) @@ -930,4 +903,23 @@ mod tests { VideoCodec::VP9 ); } + + #[test] + fn test_ice_ip_filter_excludes_link_local_ipv4() { + assert!(!is_allowed_ice_ip("169.254.44.156".parse().unwrap())); + assert!(!is_allowed_ice_ip("169.254.228.140".parse().unwrap())); + assert!(is_allowed_ice_ip("192.168.10.9".parse().unwrap())); + assert!(is_allowed_ice_ip("10.0.0.5".parse().unwrap())); + } + + #[test] + fn test_ice_ip_filter_excludes_local_ipv6() { + assert!(!is_allowed_ice_ip("::1".parse().unwrap())); + assert!(!is_allowed_ice_ip("::".parse().unwrap())); + assert!(!is_allowed_ice_ip( + "fe80::3fb1:b28d:a3b0:d160".parse().unwrap() + )); + assert!(!is_allowed_ice_ip("fc00::1".parse().unwrap())); + assert!(is_allowed_ice_ip("2001:4860:4860::8888".parse().unwrap())); + } } diff --git a/src/webrtc/video_track.rs b/src/webrtc/video_track.rs index cab70a54..ceacf7fb 100644 --- a/src/webrtc/video_track.rs +++ b/src/webrtc/video_track.rs @@ -15,6 +15,7 @@ use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; use super::h265_payloader::H265Payloader; use crate::error::{AppError, Result}; +use crate::video::codec::h264_bitstream; use crate::video::types::Resolution; const RTP_MTU: usize = 1200; @@ -52,9 +53,7 @@ impl VideoCodec { pub fn sdp_fmtp(&self) -> String { match self { - VideoCodec::H264 => { - "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f".to_string() - } + VideoCodec::H264 => h264_bitstream::fallback_webrtc_fmtp_line(), VideoCodec::H265 => "level-id=180;profile-id=1;tier-flag=0;tx-mode=SRST".to_string(), VideoCodec::VP8 => String::new(), VideoCodec::VP9 => "profile-id=0".to_string(), @@ -85,6 +84,7 @@ pub struct UniversalVideoTrackConfig { pub resolution: Resolution, pub bitrate_kbps: u32, pub fps: u32, + pub sdp_fmtp_line: Option, } impl Default for UniversalVideoTrackConfig { @@ -96,6 +96,7 @@ impl Default for UniversalVideoTrackConfig { resolution: Resolution::HD720, bitrate_kbps: 8000, fps: 30, + sdp_fmtp_line: None, } } } @@ -154,11 +155,18 @@ struct H265RtpState { timestamp_increment: u32, } +#[derive(Default)] +struct H264TrackState { + sps: Option>, + pps: Option>, +} + pub struct UniversalVideoTrack { track: TrackType, codec: VideoCodec, config: UniversalVideoTrackConfig, h265_state: Option>, + h264_state: Mutex, } impl UniversalVideoTrack { @@ -167,7 +175,10 @@ impl UniversalVideoTrack { mime_type: config.codec.mime_type().to_string(), clock_rate: config.codec.clock_rate(), channels: 0, - sdp_fmtp_line: config.codec.sdp_fmtp(), + sdp_fmtp_line: config + .sdp_fmtp_line + .clone() + .unwrap_or_else(|| config.codec.sdp_fmtp()), rtcp_feedback: vec![], }; @@ -201,6 +212,7 @@ impl UniversalVideoTrack { codec: config.codec, config, h265_state, + h264_state: Mutex::new(H264TrackState::default()), } } @@ -239,6 +251,36 @@ impl UniversalVideoTrack { /// One Annex-B AU per sample so the stack can STAP/FU internally. async fn write_h264_frame(&self, data: Bytes, _is_keyframe: bool) -> Result<()> { + let normalized = h264_bitstream::normalize_for_webrtc(data.as_ref()); + let mut data = Bytes::from(normalized); + + let idr = h264_bitstream::is_keyframe(data.as_ref()); + let has_parameter_sets = h264_bitstream::has_sps_pps(data.as_ref()); + + { + let mut state = self.h264_state.lock().await; + let (sps, pps) = h264_bitstream::extract_sps_pps(data.as_ref()); + if let Some(sps) = sps { + state.sps = Some(sps); + } + if let Some(pps) = pps { + state.pps = Some(pps); + } + + if idr && !has_parameter_sets { + if let (Some(sps), Some(pps)) = (&state.sps, &state.pps) { + let mut with_parameter_sets = + Vec::with_capacity(data.len() + sps.len() + pps.len() + 8); + with_parameter_sets.extend_from_slice(&[0, 0, 0, 1]); + with_parameter_sets.extend_from_slice(sps); + with_parameter_sets.extend_from_slice(&[0, 0, 0, 1]); + with_parameter_sets.extend_from_slice(pps); + with_parameter_sets.extend_from_slice(data.as_ref()); + data = Bytes::from(with_parameter_sets); + } + } + } + let frame_duration = Duration::from_micros(1_000_000 / self.config.fps.max(1) as u64); let sample = Sample { data, diff --git a/src/webrtc/webrtc_streamer.rs b/src/webrtc/webrtc_streamer.rs index b1a51832..a0799b21 100644 --- a/src/webrtc/webrtc_streamer.rs +++ b/src/webrtc/webrtc_streamer.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, RwLock as StdRwLock}; +use std::time::Duration; use tokio::sync::RwLock; use tracing::{debug, info, trace, warn}; @@ -11,6 +12,7 @@ use crate::audio::{AudioController, OpusFrame}; use crate::error::{AppError, Result}; use crate::events::{EventBus, StreamDeviceLostKind, SystemEvent}; use crate::hid::HidController; +use crate::video::codec::h264_bitstream; use crate::video::device::{ enumerate_devices, select_recovery_device, VideoDevice, VideoDeviceRecoveryHint, }; @@ -24,6 +26,8 @@ use super::config::{TurnServer, WebRtcConfig}; use super::signaling::{ConnectionState, IceCandidate, SdpAnswer, SdpOffer}; use super::universal_session::{UniversalSession, UniversalSessionConfig}; +const H264_PROFILE_DETECT_TIMEOUT: Duration = Duration::from_millis(500); + #[derive(Debug, Clone)] pub struct WebRtcStreamerConfig { pub webrtc: WebRtcConfig, @@ -170,7 +174,7 @@ impl WebRtcStreamer { /// Get list of supported video codecs pub fn supported_video_codecs(&self) -> Vec { - use crate::video::encoder::registry::EncoderRegistry; + use crate::video::codec::registry::EncoderRegistry; let registry = EncoderRegistry::global(); VideoEncoderType::ordered() @@ -583,6 +587,47 @@ impl WebRtcStreamer { self.ensure_video_pipeline().await } + async fn negotiate_video_fmtp( + &self, + pipeline: &SharedVideoPipeline, + codec: VideoEncoderType, + ) -> Option { + match codec { + VideoEncoderType::H264 => { + let profile_level_id = Self::wait_for_h264_profile_level_id(pipeline) + .await + .unwrap_or_else(|| { + h264_bitstream::FALLBACK_WEBRTC_PROFILE_LEVEL_ID.to_string() + }); + Some(h264_bitstream::webrtc_fmtp_line(&profile_level_id)) + } + _ => None, + } + } + + async fn wait_for_h264_profile_level_id(pipeline: &SharedVideoPipeline) -> Option { + let mut rx = pipeline.h264_profile_level_id_watch(); + if let Some(profile_level_id) = rx.borrow().clone() { + return Some(profile_level_id); + } + + let wait = async { + loop { + if rx.changed().await.is_err() { + return None; + } + if let Some(profile_level_id) = rx.borrow().clone() { + return Some(profile_level_id); + } + } + }; + + tokio::time::timeout(H264_PROFILE_DETECT_TIMEOUT, wait) + .await + .ok() + .flatten() + } + /// Get the current pipeline configuration (if pipeline is running) pub async fn get_pipeline_config(&self) -> Option { if let Some(ref pipeline) = *self.video_pipeline.read().await { @@ -855,7 +900,7 @@ impl WebRtcStreamer { Some(backend) => backend.is_hardware(), None => { // Auto mode: check if hardware encoder is available for current codec - use crate::video::encoder::registry::{EncoderRegistry, VideoEncoderType}; + use crate::video::codec::registry::{EncoderRegistry, VideoEncoderType}; let codec_type = match *self.video_codec.read().await { VideoCodecType::H264 => VideoEncoderType::H264, VideoCodecType::H265 => VideoEncoderType::H265, @@ -960,8 +1005,16 @@ impl WebRtcStreamer { bitrate_preset: config.bitrate_preset, fps: config.fps, audio_enabled: *self.audio_enabled.read().await, + video_sdp_fmtp_line: None, }; drop(config); + let video_sdp_fmtp_line = self + .negotiate_video_fmtp(&pipeline, session_config.codec) + .await; + let session_config = UniversalSessionConfig { + video_sdp_fmtp_line, + ..session_config + }; // Create universal session let event_bus = self.events.read().await.clone(); diff --git a/web/src/api/index.ts b/web/src/api/index.ts index 7da22c19..c2e5ff0f 100644 --- a/web/src/api/index.ts +++ b/web/src/api/index.ts @@ -46,6 +46,30 @@ export interface DeviceInfo { memory_total: number memory_used: number network_addresses: NetworkAddress[] + serial_ports: string[] +} + +export interface FeatureCapability { + available: boolean + backends: string[] + selected_backend?: string + reason?: string +} + +export interface PlatformCapabilities { + mode: 'linux' | 'windows' + mode_label: string + video_capture: FeatureCapability + encoder: FeatureCapability + hid: FeatureCapability + atx: FeatureCapability + msd: FeatureCapability + otg: FeatureCapability + audio: FeatureCapability + rustdesk: FeatureCapability + diagnostics: FeatureCapability + extensions: FeatureCapability + service_installation: FeatureCapability } export const systemApi = { @@ -54,12 +78,14 @@ export const systemApi = { version: string build_date: string initialized: boolean + platform: PlatformCapabilities capabilities: { - video: { available: boolean; backend?: string } - hid: { available: boolean; backend?: string } - msd: { available: boolean } - atx: { available: boolean; backend?: string } - audio: { available: boolean; backend?: string } + video: { available: boolean; backend?: string; reason?: string } + hid: { available: boolean; backend?: string; reason?: string } + msd: { available: boolean; backend?: string; reason?: string } + atx: { available: boolean; backend?: string; reason?: string } + audio: { available: boolean; backend?: string; reason?: string } + rustdesk: { available: boolean; backend?: string; reason?: string } } disk_space?: { total: number @@ -72,7 +98,7 @@ export const systemApi = { health: () => request<{ status: string; version: string }>('/health'), setupStatus: () => - request<{ initialized: boolean; needs_setup: boolean }>('/setup'), + request<{ initialized: boolean; needs_setup: boolean; platform: PlatformCapabilities }>('/setup'), setup: (data: { username: string diff --git a/web/src/components/ActionBar.vue b/web/src/components/ActionBar.vue index 82d449fd..f1626115 100644 --- a/web/src/components/ActionBar.vue +++ b/web/src/components/ActionBar.vue @@ -63,6 +63,7 @@ const props = defineProps<{ mouseMode?: 'absolute' | 'relative' videoMode?: VideoMode ttydRunning?: boolean + showTerminal?: boolean }>() const emit = defineEmits<{ @@ -194,6 +195,7 @@ const RIGHT_FIXED_PX = 120 const collapsibleItems = computed(() => { const items = ITEM_SPECS.slice(3).filter(item => { if (item.id === 'msd' && !showMsd.value) return false + if (item.id === 'extension' && props.showTerminal === false) return false return true }) return items @@ -339,7 +341,7 @@ const hasOverflow = computed(() => { -
+