diff --git a/.gitignore b/.gitignore index 7158bec6..725f1869 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,4 @@ CLAUDE.md # Secrets (compile-time configuration) secrets.toml +.env diff --git a/Cargo.toml b/Cargo.toml index a95b5b8f..c4f7dc1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "one-kvm" -version = "0.1.1" +version = "0.1.4" edition = "2021" authors = ["SilentWind"] description = "A open and lightweight IP-KVM solution written in Rust" @@ -129,6 +129,7 @@ tempfile = "3" [build-dependencies] protobuf-codegen = "3.7" toml = "0.9" +cc = "1" [profile.release] opt-level = 3 diff --git a/README.md b/README.md index 8ead0e9c..4f1f91e6 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,4 @@
- One-KVM Logo

One-KVM

Rust 编写的开放轻量 IP-KVM 解决方案,实现 BIOS 级远程管理

@@ -19,16 +18,6 @@ --- -## 📋 目录 - -- [项目概述](#项目概述) -- [迁移说明](#迁移说明) -- [功能介绍](#功能介绍) -- [快速开始](#快速开始) -- [贡献与反馈](#贡献与反馈) -- [致谢](#致谢) -- [许可证](#许可证) - ## 📖 项目概述 **One-KVM Rust** 是一个用 Rust 编写的轻量级 IP-KVM 解决方案,可通过网络远程管理服务器和工作站,实现 BIOS 级远程控制。 @@ -66,7 +55,7 @@ - **VAAPI**:Intel/AMD GPU - **RKMPP**:Rockchip SoC -- **V4L2 M2M**:RaspberryPi +- **V4L2 M2M**:通用硬件编码器(尚未实现) - **软件编码**:CPU 编码 ### 扩展能力 @@ -74,85 +63,14 @@ - Web UI 配置,多语言支持(中文/英文) - 内置 Web 终端(ttyd)内网穿透支持(gostc)、P2P 组网支持(EasyTier)、RustDesk 协议集成(用于跨平台远程访问能力扩展) -## ⚡ 快速开始 +## ⚡ 安装使用 -安装方式:Docker / DEB 软件包 / 飞牛 NAS(FPK)。 - -### 方式一:Docker 安装(推荐) - -前提条件: - -- Linux 主机已安装 Docker -- 插好 USB HDMI 采集卡 -- 启用 USB OTG 或插好 CH340+CH9329 HID 线(用于 HID 模拟) - -启动容器: - -```bash -docker run --name one-kvm -itd --privileged=true \ - -v /dev:/dev -v /sys/:/sys \ - --net=host \ - silentwind0/one-kvm -``` - -访问 Web 界面:`http://<设备IP>:8080`(首次访问会引导创建管理员账户)。默认端口:HTTP `8080`;启用 HTTPS 后为 `8443`。 - -#### 常用环境变量(Docker) - -| 变量名 | 默认值 | 说明 | -|------|------|------| -| `ENABLE_HTTPS` | `false` | 是否启用 HTTPS(`true/false`) | -| `HTTP_PORT` | `8080` | HTTP 端口(`ENABLE_HTTPS=false` 时生效) | -| `HTTPS_PORT` | `8443` | HTTPS 端口(`ENABLE_HTTPS=true` 时生效) | -| `BIND_ADDRESS` | - | 监听地址(如 `0.0.0.0`) | -| `VERBOSE` | `0` | 日志详细程度:`1`(-v)、`2`(-vv)、`3`(-vvv) | -| `DATA_DIR` | `/etc/one-kvm` | 数据目录(等价于 `one-kvm -d `,优先级高于 `ONE_KVM_DATA_DIR`) | - -> 说明:`--privileged=true` 和挂载 `/dev`、`/sys` 是硬件访问所需配置,当前版本不可省略。 -> -> 兼容性:同时支持旧变量名 `ONE_KVM_DATA_DIR`。 -> -> HTTPS:未提供证书时会自动生成默认自签名证书。 -> -> Ventoy:若修改 `DATA_DIR`,请确保 Ventoy 资源文件位于 `${DATA_DIR}/ventoy`(`boot.img`、`core.img`、`ventoy.disk.img`)。 - -### 方式二:DEB 软件包安装 - -前提条件: - -- Debian 11+ / Ubuntu 22+ -- 插好 USB HDMI 采集卡、HID 线(OTG 或 CH340+CH9329) - -安装步骤: - -1. 从 GitHub Releases 下载适合架构的 `one-kvm_*.deb`:[Releases](https://github.com/mofeng-git/One-KVM/releases) -2. 安装: - -```bash -sudo apt update -sudo apt install ./one-kvm_*_*.deb -``` - -访问 Web 界面:`http://<设备IP>:8080`。 - -### 方式三:飞牛 NAS(FPK)安装 - -前提条件: - -- 飞牛 NAS 系统(目前仅支持 x86_64 架构) -- 插好 USB HDMI 采集卡、CH340+CH9329 HID 线 - -安装步骤: - -1. 从 GitHub Releases 下载 `*.fpk` 软件包:[Releases](https://github.com/mofeng-git/One-KVM/releases) -2. 在飞牛应用商店选择“手动安装”,导入 `*.fpk` - -访问 Web 界面:`http://<设备IP>:8420`。 +可以访问 [One-KVM Rust 文档站点](https://docs.one-kvm.cn/) 获取详细信息。 ## 报告问题 如果您发现了问题,请: -1. 使用 [GitHub Issues](https://github.com/mofeng-git/One-KVM/issues) 报告 +1. 使用 [GitHub Issues](https://github.com/mofeng-git/One-KVM/issues) 报告,或加入 QQ 群聊反馈。 2. 提供详细的错误信息和复现步骤 3. 包含您的硬件配置和系统信息 @@ -269,6 +187,14 @@ sudo apt install ./one-kvm_*_*.deb - 葱 +- MaxZ + +- 爱发电用户_c5f33 + +- 爱发电用户_09386 + +- 爱发电用户_JT6c + - ...... @@ -277,11 +203,6 @@ sudo apt install ./one-kvm_*_*.deb 本项目得到以下赞助商的支持: -**CDN 加速及安全防护:** -- **[Tencent EdgeOne](https://edgeone.ai/zh?from=github)** - 提供 CDN 加速及安全防护服务 - -![Tencent EdgeOne](https://edgeone.ai/media/34fe3a45-492d-4ea4-ae5d-ea1087ca7b4b.png) - **文件存储服务:** - **[Huang1111公益计划](https://pan.huang1111.cn/s/mxkx3T1)** - 提供免登录下载服务 diff --git a/build/Dockerfile.runtime b/build/Dockerfile.runtime index 44617adf..ccdc1658 100644 --- a/build/Dockerfile.runtime +++ b/build/Dockerfile.runtime @@ -39,7 +39,7 @@ RUN apt-get update && \ COPY --chmod=755 init.sh /init.sh # Copy binaries (these are placed by the build script) -COPY --chmod=755 one-kvm ttyd gostc easytier-core /usr/bin/ +COPY --chmod=755 one-kvm ttyd /usr/bin/ # Copy ventoy resources if they exist COPY ventoy/ /etc/one-kvm/ventoy/ diff --git a/build/Dockerfile.runtime-full b/build/Dockerfile.runtime-full new file mode 100644 index 00000000..32428305 --- /dev/null +++ b/build/Dockerfile.runtime-full @@ -0,0 +1,48 @@ +# One-KVM Runtime Image (full) +# This Dockerfile only packages pre-compiled binaries (no compilation) +# Used after cross-compiling with `cross build` +# Using Debian 11 for maximum compatibility (GLIBC 2.31) + +ARG TARGETPLATFORM=linux/amd64 + +FROM debian:11-slim + +ARG TARGETPLATFORM + +# Install runtime dependencies in a single layer +# All codec libraries (libx264, libx265, libopus) are now statically linked +# Only hardware acceleration drivers and core system libraries remain dynamic +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + # Core runtime (all platforms) - no codec libs needed + ca-certificates \ + libudev1 \ + libasound2 \ + # v4l2 is handled by kernel, minimal userspace needed + libv4l-0 \ + && \ + # Platform-specific hardware acceleration + if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \ + apt-get install -y --no-install-recommends \ + libva2 libva-drm2 libva-x11-2 libx11-6 libxcb1 libxau6 libxdmcp6 libmfx1; \ + elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + apt-get install -y --no-install-recommends \ + libdrm2 libva2; \ + elif [ "$TARGETPLATFORM" = "linux/arm/v7" ]; then \ + apt-get install -y --no-install-recommends \ + libdrm2 libva2; \ + fi && \ + rm -rf /var/lib/apt/lists/* && \ + mkdir -p /etc/one-kvm/ventoy + +# Copy init script +COPY --chmod=755 init.sh /init.sh + +# Copy binaries (these are placed by the build script) +COPY --chmod=755 one-kvm ttyd gostc easytier-core /usr/bin/ + +# Copy ventoy resources if they exist +COPY ventoy/ /etc/one-kvm/ventoy/ + +# Entrypoint +CMD ["/init.sh"] diff --git a/build/package-docker.sh b/build/package-docker.sh index ceb1217d..94c9ba82 100755 --- a/build/package-docker.sh +++ b/build/package-docker.sh @@ -25,11 +25,13 @@ echo_error() { echo -e "${RED}[ERROR]${NC} $1"; } # Configuration REGISTRY="${REGISTRY:-}" # e.g., docker.io/username or ghcr.io/username -IMAGE_NAME="${IMAGE_NAME:-one-kvm}" +IMAGE_NAME="${IMAGE_NAME:-}" TAG="${TAG:-latest}" +VARIANT="${VARIANT:-minimal}" +INCLUDE_THIRD_PARTY=false SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" -STAGING_DIR="$PROJECT_ROOT/build-staging" +BASE_STAGING_DIR="$PROJECT_ROOT/build-staging" # Full image name with registry get_full_image_name() { @@ -77,6 +79,18 @@ while [[ $# -gt 0 ]]; do REGISTRY="$2" shift 2 ;; + --image-name) + IMAGE_NAME="$2" + shift 2 + ;; + --variant) + VARIANT="$2" + shift 2 + ;; + --full) + VARIANT="full" + shift + ;; --build) BUILD_BINARY=true shift @@ -91,9 +105,12 @@ while [[ $# -gt 0 ]]; do echo " Use comma to specify multiple: linux/amd64,linux/arm64" echo " Default: $DEFAULT_PLATFORM" echo " --registry REGISTRY Container registry (e.g., docker.io/user, ghcr.io/user)" + echo " --image-name NAME Override image name (default: one-kvm or one-kvm-full)" echo " --push Push image to registry" echo " --load Load image to local Docker (single platform only)" echo " --tag TAG Image tag (default: latest)" + echo " --variant VARIANT Image variant: minimal or full (default: minimal)" + echo " --full Shortcut for --variant full" echo " --build Also build the binary with cross (optional)" echo " --help Show this help" echo "" @@ -101,6 +118,9 @@ while [[ $# -gt 0 ]]; do echo " # Build for current platform and load locally" echo " $0 --platform linux/arm64 --load" echo "" + echo " # Build full image (includes gostc + easytier)" + echo " $0 --variant full --platform linux/arm64 --load" + echo "" echo " # Build and push single platform" echo " $0 --platform linux/arm64 --registry docker.io/user --push" echo "" @@ -115,6 +135,28 @@ while [[ $# -gt 0 ]]; do esac done +# Normalize variant and image name +case "$VARIANT" in + minimal) + INCLUDE_THIRD_PARTY=false + ;; + full) + INCLUDE_THIRD_PARTY=true + ;; + *) + echo_error "Unknown variant: $VARIANT (expected: minimal or full)" + exit 1 + ;; +esac + +if [ -z "$IMAGE_NAME" ]; then + if [ "$VARIANT" = "full" ]; then + IMAGE_NAME="one-kvm-full" + else + IMAGE_NAME="one-kvm" + fi +fi + # Default platform if [ -z "$PLATFORMS" ]; then PLATFORMS="$DEFAULT_PLATFORM" @@ -176,21 +218,23 @@ download_tools() { chmod +x "$staging/ttyd" fi - # gostc - if [ ! -f "$staging/gostc" ]; then - curl -fsSL "$GOSTC_URL" -o /tmp/gostc.tar.gz - tar -xzf /tmp/gostc.tar.gz -C "$staging" - chmod +x "$staging/gostc" - rm /tmp/gostc.tar.gz - fi + if [ "$INCLUDE_THIRD_PARTY" = true ]; then + # gostc + if [ ! -f "$staging/gostc" ]; then + curl -fsSL "$GOSTC_URL" -o /tmp/gostc.tar.gz + tar -xzf /tmp/gostc.tar.gz -C "$staging" + chmod +x "$staging/gostc" + rm /tmp/gostc.tar.gz + fi - # easytier - if [ ! -f "$staging/easytier-core" ]; then - curl -fsSL "$EASYTIER_URL" -o /tmp/easytier.zip - unzip -o /tmp/easytier.zip -d /tmp/easytier - cp "/tmp/easytier/$EASYTIER_DIR/easytier-core" "$staging/easytier-core" - chmod +x "$staging/easytier-core" - rm -rf /tmp/easytier.zip /tmp/easytier + # easytier + if [ ! -f "$staging/easytier-core" ]; then + curl -fsSL "$EASYTIER_URL" -o /tmp/easytier.zip + unzip -o /tmp/easytier.zip -d /tmp/easytier + cp "/tmp/easytier/$EASYTIER_DIR/easytier-core" "$staging/easytier-core" + chmod +x "$staging/easytier-core" + rm -rf /tmp/easytier.zip /tmp/easytier + fi fi } @@ -198,13 +242,14 @@ download_tools() { build_for_platform() { local platform="$1" local target=$(platform_to_target "$platform") - local staging="$STAGING_DIR/$target" + local staging="$BASE_STAGING_DIR/$VARIANT/$target" echo_info "==========================================" echo_info "Processing: $platform ($target)" echo_info "==========================================" # Create staging directory + rm -rf "$staging" mkdir -p "$staging/ventoy" # Build binary if requested @@ -252,7 +297,11 @@ build_for_platform() { fi # Copy Dockerfile - cp "$PROJECT_ROOT/build/Dockerfile.runtime" "$staging/Dockerfile" + local dockerfile="$PROJECT_ROOT/build/Dockerfile.runtime" + if [ "$INCLUDE_THIRD_PARTY" = true ]; then + dockerfile="$PROJECT_ROOT/build/Dockerfile.runtime-full" + fi + cp "$dockerfile" "$staging/Dockerfile" # Build Docker image echo_info "Building Docker image..." @@ -292,6 +341,7 @@ main() { echo_info "One-KVM Docker Image Builder" echo_info "Image: $full_image:$TAG" + echo_info "Variant: $VARIANT" echo_info "Platforms: $PLATFORMS" if [ -n "$REGISTRY" ]; then echo_info "Registry: $REGISTRY" diff --git a/libs/hwcodec/build.rs b/libs/hwcodec/build.rs index bdae4450..f2c89db2 100644 --- a/libs/hwcodec/build.rs +++ b/libs/hwcodec/build.rs @@ -98,6 +98,7 @@ mod ffmpeg { link_os(); build_ffmpeg_ram(builder); + build_ffmpeg_hw(builder); } /// Link system FFmpeg using pkg-config or custom path @@ -374,4 +375,57 @@ mod ffmpeg { ); } } + + fn build_ffmpeg_hw(builder: &mut Build) { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let ffmpeg_hw_dir = manifest_dir.join("cpp").join("ffmpeg_hw"); + let ffi_header = ffmpeg_hw_dir + .join("ffmpeg_hw_ffi.h") + .to_string_lossy() + .to_string(); + bindgen::builder() + .header(ffi_header) + .rustified_enum("*") + .generate() + .unwrap() + .write_to_file(Path::new(&env::var_os("OUT_DIR").unwrap()).join("ffmpeg_hw_ffi.rs")) + .unwrap(); + + let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default(); + let enable_rkmpp = matches!(target_arch.as_str(), "aarch64" | "arm") + || std::env::var_os("CARGO_FEATURE_RKMPP").is_some(); + if enable_rkmpp { + // Include RGA headers for NV16->NV12 conversion (RGA im2d API) + let rga_sys_dirs = [ + Path::new("/usr/aarch64-linux-gnu/include/rga"), + Path::new("/usr/include/rga"), + ]; + let mut added = false; + for dir in rga_sys_dirs.iter() { + if dir.exists() { + builder.include(dir); + added = true; + } + } + if !added { + // Fallback to repo-local rkrga headers if present + let repo_root = manifest_dir + .parent() + .and_then(|p| p.parent()) + .map(|p| p.to_path_buf()) + .unwrap_or_else(|| manifest_dir.clone()); + let rkrga_dir = repo_root.join("ffmpeg").join("rkrga"); + if rkrga_dir.exists() { + builder.include(rkrga_dir.join("include")); + builder.include(rkrga_dir.join("im2d_api")); + } + } + builder.file(ffmpeg_hw_dir.join("ffmpeg_hw_mjpeg_h26x.cpp")); + } else { + println!( + "cargo:info=Skipping ffmpeg_hw_mjpeg_h26x.cpp (RKMPP) for arch {}", + target_arch + ); + } + } } diff --git a/libs/hwcodec/cpp/common/platform/linux/linux.cpp b/libs/hwcodec/cpp/common/platform/linux/linux.cpp index c9036d20..e9f1e6c4 100644 --- a/libs/hwcodec/cpp/common/platform/linux/linux.cpp +++ b/libs/hwcodec/cpp/common/platform/linux/linux.cpp @@ -1,12 +1,16 @@ #include "linux.h" #include "../../log.h" +#include +#include #include #include #include +#include #include #include #include #include +#include // Check for NVIDIA driver support by loading CUDA libraries int linux_support_nv() @@ -106,6 +110,57 @@ int linux_support_rkmpp() { // Check for V4L2 Memory-to-Memory (M2M) codec support // Returns 0 if a M2M capable device is found, -1 otherwise int linux_support_v4l2m2m() { + auto to_lower = [](std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return value; + }; + + auto read_text_file = [](const char *path, std::string *out) -> bool { + std::ifstream file(path); + if (!file.is_open()) { + return false; + } + std::getline(file, *out); + return !out->empty(); + }; + + auto allow_video0_probe = []() -> bool { + const char *env = std::getenv("ONE_KVM_V4L2M2M_ALLOW_VIDEO0"); + if (env == nullptr) { + return false; + } + if (env[0] == '\0') { + return false; + } + return std::strcmp(env, "0") != 0; + }; + + auto is_amlogic_vdec = [&]() -> bool { + std::string name; + std::string modalias; + if (read_text_file("/sys/class/video4linux/video0/name", &name)) { + const std::string lowered = to_lower(name); + if (lowered.find("meson") != std::string::npos || + lowered.find("vdec") != std::string::npos || + lowered.find("decoder") != std::string::npos || + lowered.find("video-decoder") != std::string::npos) { + return true; + } + } + if (read_text_file("/sys/class/video4linux/video0/device/modalias", &modalias)) { + const std::string lowered = to_lower(modalias); + if (lowered.find("amlogic") != std::string::npos || + lowered.find("meson") != std::string::npos || + lowered.find("gxl-vdec") != std::string::npos || + lowered.find("gx-vdec") != std::string::npos) { + return true; + } + } + return false; + }; + // Check common V4L2 M2M device paths used by various ARM SoCs // /dev/video10 - Standard on many SoCs // /dev/video11 - Standard on many SoCs (often decoder) @@ -124,6 +179,13 @@ int linux_support_v4l2m2m() { for (size_t i = 0; i < sizeof(m2m_devices) / sizeof(m2m_devices[0]); i++) { if (access(m2m_devices[i], F_OK) == 0) { + if (std::strcmp(m2m_devices[i], "/dev/video0") == 0) { + if (!allow_video0_probe() && is_amlogic_vdec()) { + LOG_TRACE(std::string("V4L2 M2M: Skipping /dev/video0 (Amlogic vdec)")); + continue; + } + } + // Device exists, check if it's an M2M device by trying to open it int fd = open(m2m_devices[i], O_RDWR | O_NONBLOCK); if (fd >= 0) { diff --git a/libs/hwcodec/cpp/ffmpeg_hw/ffmpeg_hw_ffi.h b/libs/hwcodec/cpp/ffmpeg_hw/ffmpeg_hw_ffi.h new file mode 100644 index 00000000..ac4cba21 --- /dev/null +++ b/libs/hwcodec/cpp/ffmpeg_hw/ffmpeg_hw_ffi.h @@ -0,0 +1,50 @@ +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// MJPEG -> H26x (H.264 / H.265) hardware pipeline +typedef struct FfmpegHwMjpegH26x FfmpegHwMjpegH26x; + +// Create a new MJPEG -> H26x pipeline. +FfmpegHwMjpegH26x* ffmpeg_hw_mjpeg_h26x_new(const char* dec_name, + const char* enc_name, + int width, + int height, + int fps, + int bitrate_kbps, + int gop, + int thread_count); + +// Encode one MJPEG frame. Returns 1 if output produced, 0 if no output, <0 on error. +int ffmpeg_hw_mjpeg_h26x_encode(FfmpegHwMjpegH26x* ctx, + const uint8_t* data, + int len, + int64_t pts_ms, + uint8_t** out_data, + int* out_len, + int* out_keyframe); + +// Reconfigure bitrate/gop (best-effort, may recreate encoder internally). +int ffmpeg_hw_mjpeg_h26x_reconfigure(FfmpegHwMjpegH26x* ctx, + int bitrate_kbps, + int gop); + +// Request next frame to be a keyframe. +int ffmpeg_hw_mjpeg_h26x_request_keyframe(FfmpegHwMjpegH26x* ctx); + +// Free pipeline resources. +void ffmpeg_hw_mjpeg_h26x_free(FfmpegHwMjpegH26x* ctx); + +// Free packet buffer allocated by ffmpeg_hw_mjpeg_h26x_encode. +void ffmpeg_hw_packet_free(uint8_t* data); + +// Get last error message (thread-local). +const char* ffmpeg_hw_last_error(void); + +#ifdef __cplusplus +} +#endif diff --git a/libs/hwcodec/cpp/ffmpeg_hw/ffmpeg_hw_mjpeg_h26x.cpp b/libs/hwcodec/cpp/ffmpeg_hw/ffmpeg_hw_mjpeg_h26x.cpp new file mode 100644 index 00000000..d19aafca --- /dev/null +++ b/libs/hwcodec/cpp/ffmpeg_hw/ffmpeg_hw_mjpeg_h26x.cpp @@ -0,0 +1,468 @@ +extern "C" { +#include +#include +#include +#include +#include +#include +#include +} + +#include +#include +#include +#include + +#define LOG_MODULE "FFMPEG_HW" +#include "../common/log.h" + +#include "ffmpeg_hw_ffi.h" + +namespace { +thread_local std::string g_last_error; + +static void set_last_error(const std::string &msg) { + g_last_error = msg; + LOG_ERROR(msg); +} + +static std::string make_err(const std::string &ctx, int err) { + return ctx + " (ret=" + std::to_string(err) + "): " + av_err2str(err); +} + +static const char* pix_fmt_name(AVPixelFormat fmt) { + const char *name = av_get_pix_fmt_name(fmt); + return name ? name : "unknown"; +} + +struct FfmpegHwMjpegH26xCtx { + AVCodecContext *dec_ctx = nullptr; + AVCodecContext *enc_ctx = nullptr; + AVPacket *dec_pkt = nullptr; + AVFrame *dec_frame = nullptr; + AVPacket *enc_pkt = nullptr; + AVBufferRef *hw_device_ctx = nullptr; + AVBufferRef *hw_frames_ctx = nullptr; + AVPixelFormat hw_pixfmt = AV_PIX_FMT_NONE; + std::string dec_name; + std::string enc_name; + int width = 0; + int height = 0; + int aligned_width = 0; + int aligned_height = 0; + int fps = 30; + int bitrate_kbps = 2000; + int gop = 60; + int thread_count = 1; + bool force_keyframe = false; +}; + +static enum AVPixelFormat get_hw_format(AVCodecContext *ctx, + const enum AVPixelFormat *pix_fmts) { + auto *self = reinterpret_cast(ctx->opaque); + if (self && self->hw_pixfmt != AV_PIX_FMT_NONE) { + const enum AVPixelFormat *p; + for (p = pix_fmts; *p != AV_PIX_FMT_NONE; p++) { + if (*p == self->hw_pixfmt) { + return *p; + } + } + } + return pix_fmts[0]; +} + +static int init_decoder(FfmpegHwMjpegH26xCtx *ctx) { + const AVCodec *dec = avcodec_find_decoder_by_name(ctx->dec_name.c_str()); + if (!dec) { + set_last_error("Decoder not found: " + ctx->dec_name); + return -1; + } + + ctx->dec_ctx = avcodec_alloc_context3(dec); + if (!ctx->dec_ctx) { + set_last_error("Failed to allocate decoder context"); + return -1; + } + + ctx->dec_ctx->width = ctx->width; + ctx->dec_ctx->height = ctx->height; + ctx->dec_ctx->thread_count = ctx->thread_count > 0 ? ctx->thread_count : 1; + ctx->dec_ctx->opaque = ctx; + + // Pick HW pixfmt for RKMPP + const AVCodecHWConfig *cfg = nullptr; + for (int i = 0; (cfg = avcodec_get_hw_config(dec, i)); i++) { + if (cfg->device_type == AV_HWDEVICE_TYPE_RKMPP) { + ctx->hw_pixfmt = cfg->pix_fmt; + break; + } + } + if (ctx->hw_pixfmt == AV_PIX_FMT_NONE) { + set_last_error("No RKMPP hw pixfmt for decoder"); + return -1; + } + + int ret = av_hwdevice_ctx_create(&ctx->hw_device_ctx, + AV_HWDEVICE_TYPE_RKMPP, NULL, NULL, 0); + if (ret < 0) { + set_last_error(make_err("av_hwdevice_ctx_create failed", ret)); + return -1; + } + + ctx->dec_ctx->hw_device_ctx = av_buffer_ref(ctx->hw_device_ctx); + ctx->dec_ctx->get_format = get_hw_format; + + ret = avcodec_open2(ctx->dec_ctx, dec, NULL); + if (ret < 0) { + set_last_error(make_err("avcodec_open2 decoder failed", ret)); + return -1; + } + + ctx->dec_pkt = av_packet_alloc(); + ctx->dec_frame = av_frame_alloc(); + ctx->enc_pkt = av_packet_alloc(); + if (!ctx->dec_pkt || !ctx->dec_frame || !ctx->enc_pkt) { + set_last_error("Failed to allocate packet/frame"); + return -1; + } + + return 0; +} + +static int init_encoder(FfmpegHwMjpegH26xCtx *ctx, AVBufferRef *frames_ctx) { + const AVCodec *enc = avcodec_find_encoder_by_name(ctx->enc_name.c_str()); + if (!enc) { + set_last_error("Encoder not found: " + ctx->enc_name); + return -1; + } + + ctx->enc_ctx = avcodec_alloc_context3(enc); + if (!ctx->enc_ctx) { + set_last_error("Failed to allocate encoder context"); + return -1; + } + + ctx->enc_ctx->width = ctx->width; + ctx->enc_ctx->height = ctx->height; + ctx->enc_ctx->coded_width = ctx->width; + ctx->enc_ctx->coded_height = ctx->height; + ctx->aligned_width = ctx->width; + ctx->aligned_height = ctx->height; + ctx->enc_ctx->time_base = AVRational{1, 1000}; + ctx->enc_ctx->framerate = AVRational{ctx->fps, 1}; + ctx->enc_ctx->bit_rate = (int64_t)ctx->bitrate_kbps * 1000; + ctx->enc_ctx->gop_size = ctx->gop > 0 ? ctx->gop : ctx->fps; + ctx->enc_ctx->max_b_frames = 0; + ctx->enc_ctx->pix_fmt = AV_PIX_FMT_DRM_PRIME; + ctx->enc_ctx->sw_pix_fmt = AV_PIX_FMT_NV12; + + if (frames_ctx) { + AVHWFramesContext *hwfc = reinterpret_cast(frames_ctx->data); + if (hwfc) { + ctx->enc_ctx->pix_fmt = static_cast(hwfc->format); + ctx->enc_ctx->sw_pix_fmt = static_cast(hwfc->sw_format); + if (hwfc->width > 0) { + ctx->aligned_width = hwfc->width; + ctx->enc_ctx->coded_width = hwfc->width; + } + if (hwfc->height > 0) { + ctx->aligned_height = hwfc->height; + ctx->enc_ctx->coded_height = hwfc->height; + } + } + ctx->hw_frames_ctx = av_buffer_ref(frames_ctx); + ctx->enc_ctx->hw_frames_ctx = av_buffer_ref(frames_ctx); + } + if (ctx->hw_device_ctx) { + ctx->enc_ctx->hw_device_ctx = av_buffer_ref(ctx->hw_device_ctx); + } + + AVDictionary *opts = nullptr; + av_dict_set(&opts, "rc_mode", "CBR", 0); + if (enc->id == AV_CODEC_ID_H264) { + av_dict_set(&opts, "profile", "high", 0); + } else if (enc->id == AV_CODEC_ID_HEVC) { + av_dict_set(&opts, "profile", "main", 0); + } + av_dict_set_int(&opts, "qp_init", 23, 0); + av_dict_set_int(&opts, "qp_max", 48, 0); + av_dict_set_int(&opts, "qp_min", 0, 0); + av_dict_set_int(&opts, "qp_max_i", 48, 0); + av_dict_set_int(&opts, "qp_min_i", 0, 0); + int ret = avcodec_open2(ctx->enc_ctx, enc, &opts); + av_dict_free(&opts); + if (ret < 0) { + std::string detail = "avcodec_open2 encoder failed: "; + detail += ctx->enc_name; + detail += " fmt=" + std::string(pix_fmt_name(ctx->enc_ctx->pix_fmt)); + detail += " sw=" + std::string(pix_fmt_name(ctx->enc_ctx->sw_pix_fmt)); + detail += " size=" + std::to_string(ctx->enc_ctx->width) + "x" + std::to_string(ctx->enc_ctx->height); + detail += " fps=" + std::to_string(ctx->fps); + set_last_error(make_err(detail, ret)); + avcodec_free_context(&ctx->enc_ctx); + ctx->enc_ctx = nullptr; + if (ctx->hw_frames_ctx) { + av_buffer_unref(&ctx->hw_frames_ctx); + ctx->hw_frames_ctx = nullptr; + } + return -1; + } + + return 0; +} + +static void free_encoder(FfmpegHwMjpegH26xCtx *ctx) { + if (ctx->enc_ctx) { + avcodec_free_context(&ctx->enc_ctx); + ctx->enc_ctx = nullptr; + } + if (ctx->hw_frames_ctx) { + av_buffer_unref(&ctx->hw_frames_ctx); + ctx->hw_frames_ctx = nullptr; + } +} + +} // namespace + +extern "C" FfmpegHwMjpegH26x* ffmpeg_hw_mjpeg_h26x_new(const char* dec_name, + const char* enc_name, + int width, + int height, + int fps, + int bitrate_kbps, + int gop, + int thread_count) { + if (!dec_name || !enc_name || width <= 0 || height <= 0) { + set_last_error("Invalid parameters for ffmpeg_hw_mjpeg_h26x_new"); + return nullptr; + } + + auto *ctx = new FfmpegHwMjpegH26xCtx(); + ctx->dec_name = dec_name; + ctx->enc_name = enc_name; + ctx->width = width; + ctx->height = height; + ctx->fps = fps > 0 ? fps : 30; + ctx->bitrate_kbps = bitrate_kbps > 0 ? bitrate_kbps : 2000; + ctx->gop = gop > 0 ? gop : ctx->fps; + ctx->thread_count = thread_count > 0 ? thread_count : 1; + + if (init_decoder(ctx) != 0) { + ffmpeg_hw_mjpeg_h26x_free(reinterpret_cast(ctx)); + return nullptr; + } + + return reinterpret_cast(ctx); +} + +extern "C" int ffmpeg_hw_mjpeg_h26x_encode(FfmpegHwMjpegH26x* handle, + const uint8_t* data, + int len, + int64_t pts_ms, + uint8_t** out_data, + int* out_len, + int* out_keyframe) { + if (!handle || !data || len <= 0 || !out_data || !out_len || !out_keyframe) { + set_last_error("Invalid parameters for encode"); + return -1; + } + + auto *ctx = reinterpret_cast(handle); + *out_data = nullptr; + *out_len = 0; + *out_keyframe = 0; + + av_packet_unref(ctx->dec_pkt); + int ret = av_new_packet(ctx->dec_pkt, len); + if (ret < 0) { + set_last_error(make_err("av_new_packet failed", ret)); + return -1; + } + memcpy(ctx->dec_pkt->data, data, len); + ctx->dec_pkt->size = len; + + ret = avcodec_send_packet(ctx->dec_ctx, ctx->dec_pkt); + if (ret < 0) { + set_last_error(make_err("avcodec_send_packet failed", ret)); + return -1; + } + + while (true) { + ret = avcodec_receive_frame(ctx->dec_ctx, ctx->dec_frame); + if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { + return 0; + } + if (ret < 0) { + set_last_error(make_err("avcodec_receive_frame failed", ret)); + return -1; + } + + if (ctx->dec_frame->format != AV_PIX_FMT_DRM_PRIME) { + set_last_error("Decoder output is not DRM_PRIME"); + av_frame_unref(ctx->dec_frame); + return -1; + } + + if (!ctx->enc_ctx) { + if (!ctx->dec_frame->hw_frames_ctx) { + set_last_error("Decoder returned frame without hw_frames_ctx"); + av_frame_unref(ctx->dec_frame); + return -1; + } + if (init_encoder(ctx, ctx->dec_frame->hw_frames_ctx) != 0) { + av_frame_unref(ctx->dec_frame); + return -1; + } + } + + AVFrame *send_frame = ctx->dec_frame; + AVFrame *tmp = nullptr; + if (ctx->force_keyframe) { + tmp = av_frame_clone(send_frame); + if (tmp) { + tmp->pict_type = AV_PICTURE_TYPE_I; + send_frame = tmp; + } + ctx->force_keyframe = false; + } + + // Apply visible size crop if aligned buffer is larger than display size + if (ctx->aligned_width > 0 && ctx->width > 0 && ctx->aligned_width > ctx->width) { + send_frame->crop_right = ctx->aligned_width - ctx->width; + } + if (ctx->aligned_height > 0 && ctx->height > 0 && ctx->aligned_height > ctx->height) { + send_frame->crop_bottom = ctx->aligned_height - ctx->height; + } + + send_frame->pts = pts_ms; // time_base is ms + + ret = avcodec_send_frame(ctx->enc_ctx, send_frame); + if (tmp) { + av_frame_free(&tmp); + } + if (ret < 0) { + std::string detail = "avcodec_send_frame failed"; + if (send_frame) { + detail += " frame_fmt="; + detail += pix_fmt_name(static_cast(send_frame->format)); + detail += " w=" + std::to_string(send_frame->width); + detail += " h=" + std::to_string(send_frame->height); + if (send_frame->format == AV_PIX_FMT_DRM_PRIME && send_frame->data[0]) { + const AVDRMFrameDescriptor *drm = + reinterpret_cast(send_frame->data[0]); + if (drm && drm->layers[0].format) { + detail += " drm_fmt=0x"; + char buf[9]; + snprintf(buf, sizeof(buf), "%08x", drm->layers[0].format); + detail += buf; + } + if (drm && drm->objects[0].format_modifier) { + detail += " drm_mod=0x"; + char buf[17]; + snprintf(buf, sizeof(buf), "%016llx", + (unsigned long long)drm->objects[0].format_modifier); + detail += buf; + } + } + } + set_last_error(make_err(detail, ret)); + av_frame_unref(ctx->dec_frame); + return -1; + } + + av_packet_unref(ctx->enc_pkt); + ret = avcodec_receive_packet(ctx->enc_ctx, ctx->enc_pkt); + if (ret == AVERROR(EAGAIN)) { + av_frame_unref(ctx->dec_frame); + return 0; + } + if (ret < 0) { + set_last_error(make_err("avcodec_receive_packet failed", ret)); + av_frame_unref(ctx->dec_frame); + return -1; + } + + if (ctx->enc_pkt->size > 0) { + uint8_t *buf = (uint8_t*)malloc(ctx->enc_pkt->size); + if (!buf) { + set_last_error("malloc for output packet failed"); + av_packet_unref(ctx->enc_pkt); + av_frame_unref(ctx->dec_frame); + return -1; + } + memcpy(buf, ctx->enc_pkt->data, ctx->enc_pkt->size); + *out_data = buf; + *out_len = ctx->enc_pkt->size; + *out_keyframe = (ctx->enc_pkt->flags & AV_PKT_FLAG_KEY) ? 1 : 0; + av_packet_unref(ctx->enc_pkt); + av_frame_unref(ctx->dec_frame); + return 1; + } + + av_frame_unref(ctx->dec_frame); + } +} + +extern "C" int ffmpeg_hw_mjpeg_h26x_reconfigure(FfmpegHwMjpegH26x* handle, + int bitrate_kbps, + int gop) { + if (!handle) { + set_last_error("Invalid handle for reconfigure"); + return -1; + } + auto *ctx = reinterpret_cast(handle); + if (!ctx->enc_ctx || !ctx->hw_frames_ctx) { + set_last_error("Encoder not initialized for reconfigure"); + return -1; + } + + ctx->bitrate_kbps = bitrate_kbps > 0 ? bitrate_kbps : ctx->bitrate_kbps; + ctx->gop = gop > 0 ? gop : ctx->gop; + + AVBufferRef *frames_ref = ctx->hw_frames_ctx ? av_buffer_ref(ctx->hw_frames_ctx) : nullptr; + free_encoder(ctx); + + if (init_encoder(ctx, frames_ref) != 0) { + if (frames_ref) av_buffer_unref(&frames_ref); + return -1; + } + if (frames_ref) av_buffer_unref(&frames_ref); + + return 0; +} + +extern "C" int ffmpeg_hw_mjpeg_h26x_request_keyframe(FfmpegHwMjpegH26x* handle) { + if (!handle) { + set_last_error("Invalid handle for request_keyframe"); + return -1; + } + auto *ctx = reinterpret_cast(handle); + ctx->force_keyframe = true; + return 0; +} + +extern "C" void ffmpeg_hw_mjpeg_h26x_free(FfmpegHwMjpegH26x* handle) { + auto *ctx = reinterpret_cast(handle); + if (!ctx) return; + + if (ctx->dec_pkt) av_packet_free(&ctx->dec_pkt); + if (ctx->dec_frame) av_frame_free(&ctx->dec_frame); + if (ctx->enc_pkt) av_packet_free(&ctx->enc_pkt); + + if (ctx->dec_ctx) avcodec_free_context(&ctx->dec_ctx); + free_encoder(ctx); + + if (ctx->hw_device_ctx) av_buffer_unref(&ctx->hw_device_ctx); + + delete ctx; +} + +extern "C" void ffmpeg_hw_packet_free(uint8_t* data) { + if (data) { + free(data); + } +} + +extern "C" const char* ffmpeg_hw_last_error(void) { + return g_last_error.c_str(); +} diff --git a/libs/hwcodec/src/ffmpeg_hw/mod.rs b/libs/hwcodec/src/ffmpeg_hw/mod.rs new file mode 100644 index 00000000..222c9d14 --- /dev/null +++ b/libs/hwcodec/src/ffmpeg_hw/mod.rs @@ -0,0 +1,118 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +use std::{ + ffi::{CStr, CString}, + os::raw::c_int, +}; + +include!(concat!(env!("OUT_DIR"), "/ffmpeg_hw_ffi.rs")); + +#[derive(Debug, Clone)] +pub struct HwMjpegH26xConfig { + pub decoder: String, + pub encoder: String, + pub width: i32, + pub height: i32, + pub fps: i32, + pub bitrate_kbps: i32, + pub gop: i32, + pub thread_count: i32, +} + +pub struct HwMjpegH26xPipeline { + ctx: *mut FfmpegHwMjpegH26x, + config: HwMjpegH26xConfig, +} + +unsafe impl Send for HwMjpegH26xPipeline {} + +impl HwMjpegH26xPipeline { + pub fn new(config: HwMjpegH26xConfig) -> Result { + unsafe { + let dec = CString::new(config.decoder.as_str()).map_err(|_| "decoder name invalid".to_string())?; + let enc = CString::new(config.encoder.as_str()).map_err(|_| "encoder name invalid".to_string())?; + let ctx = ffmpeg_hw_mjpeg_h26x_new( + dec.as_ptr(), + enc.as_ptr(), + config.width, + config.height, + config.fps, + config.bitrate_kbps, + config.gop, + config.thread_count, + ); + if ctx.is_null() { + return Err(last_error_message()); + } + Ok(Self { ctx, config }) + } + } + + pub fn encode(&mut self, data: &[u8], pts_ms: i64) -> Result, bool)>, String> { + unsafe { + let mut out_data: *mut u8 = std::ptr::null_mut(); + let mut out_len: c_int = 0; + let mut out_key: c_int = 0; + let ret = ffmpeg_hw_mjpeg_h26x_encode( + self.ctx, + data.as_ptr(), + data.len() as c_int, + pts_ms, + &mut out_data, + &mut out_len, + &mut out_key, + ); + if ret < 0 { + return Err(last_error_message()); + } + if out_data.is_null() || out_len == 0 { + return Ok(None); + } + let slice = std::slice::from_raw_parts(out_data, out_len as usize); + let mut vec = Vec::with_capacity(slice.len()); + vec.extend_from_slice(slice); + ffmpeg_hw_packet_free(out_data); + Ok(Some((vec, out_key != 0))) + } + } + + pub fn reconfigure(&mut self, bitrate_kbps: i32, gop: i32) -> Result<(), String> { + unsafe { + let ret = ffmpeg_hw_mjpeg_h26x_reconfigure(self.ctx, bitrate_kbps, gop); + if ret != 0 { + return Err(last_error_message()); + } + self.config.bitrate_kbps = bitrate_kbps; + self.config.gop = gop; + Ok(()) + } + } + + pub fn request_keyframe(&mut self) { + unsafe { + let _ = ffmpeg_hw_mjpeg_h26x_request_keyframe(self.ctx); + } + } +} + +impl Drop for HwMjpegH26xPipeline { + fn drop(&mut self) { + unsafe { + ffmpeg_hw_mjpeg_h26x_free(self.ctx); + } + self.ctx = std::ptr::null_mut(); + } +} + +pub fn last_error_message() -> String { + unsafe { + let ptr = ffmpeg_hw_last_error(); + if ptr.is_null() { + return String::new(); + } + let cstr = CStr::from_ptr(ptr); + cstr.to_string_lossy().to_string() + } +} diff --git a/libs/hwcodec/src/lib.rs b/libs/hwcodec/src/lib.rs index 33054dd8..9645c1f1 100644 --- a/libs/hwcodec/src/lib.rs +++ b/libs/hwcodec/src/lib.rs @@ -1,5 +1,7 @@ pub mod common; pub mod ffmpeg; +#[cfg(any(target_arch = "aarch64", target_arch = "arm", feature = "rkmpp"))] +pub mod ffmpeg_hw; pub mod ffmpeg_ram; #[no_mangle] diff --git a/src/audio/capture.rs b/src/audio/capture.rs index 229e1650..33fcc673 100644 --- a/src/audio/capture.rs +++ b/src/audio/capture.rs @@ -117,21 +117,11 @@ pub enum CaptureState { Error, } -/// Audio capture statistics -#[derive(Debug, Clone, Default)] -pub struct AudioStats { - pub frames_captured: u64, - pub frames_dropped: u64, - pub buffer_overruns: u64, - pub current_latency_ms: f32, -} - /// ALSA audio capturer pub struct AudioCapturer { config: AudioConfig, state: Arc>, state_rx: watch::Receiver, - stats: Arc>, frame_tx: broadcast::Sender, stop_flag: Arc, sequence: Arc, @@ -150,7 +140,6 @@ impl AudioCapturer { config, state: Arc::new(state_tx), state_rx, - stats: Arc::new(Mutex::new(AudioStats::default())), frame_tx, stop_flag: Arc::new(AtomicBool::new(false)), sequence: Arc::new(AtomicU64::new(0)), @@ -174,11 +163,6 @@ impl AudioCapturer { self.frame_tx.subscribe() } - /// Get statistics - pub async fn stats(&self) -> AudioStats { - self.stats.lock().await.clone() - } - /// Start capturing pub async fn start(&self) -> Result<()> { if self.state() == CaptureState::Running { @@ -194,7 +178,6 @@ impl AudioCapturer { let config = self.config.clone(); let state = self.state.clone(); - let stats = self.stats.clone(); let frame_tx = self.frame_tx.clone(); let stop_flag = self.stop_flag.clone(); let sequence = self.sequence.clone(); @@ -204,7 +187,6 @@ impl AudioCapturer { capture_loop( config, state, - stats, frame_tx, stop_flag, sequence, @@ -239,7 +221,6 @@ impl AudioCapturer { fn capture_loop( config: AudioConfig, state: Arc>, - stats: Arc>, frame_tx: broadcast::Sender, stop_flag: Arc, sequence: Arc, @@ -248,7 +229,6 @@ fn capture_loop( let result = run_capture( &config, &state, - &stats, &frame_tx, &stop_flag, &sequence, @@ -266,7 +246,6 @@ fn capture_loop( fn run_capture( config: &AudioConfig, state: &watch::Sender, - stats: &Arc>, frame_tx: &broadcast::Sender, stop_flag: &AtomicBool, sequence: &AtomicU64, @@ -334,9 +313,6 @@ fn run_capture( match pcm.state() { State::XRun => { warn_throttled!(log_throttler, "xrun", "Audio buffer overrun, recovering"); - if let Ok(mut s) = stats.try_lock() { - s.buffer_overruns += 1; - } let _ = pcm.prepare(); continue; } @@ -377,11 +353,6 @@ fn run_capture( debug!("No audio receivers: {}", e); } } - - // Update stats - if let Ok(mut s) = stats.try_lock() { - s.frames_captured += 1; - } } Err(e) => { // Check for buffer overrun (EPIPE = 32 on Linux) @@ -389,21 +360,12 @@ fn run_capture( if desc.contains("EPIPE") || desc.contains("Broken pipe") { // Buffer overrun warn_throttled!(log_throttler, "buffer_overrun", "Audio buffer overrun"); - if let Ok(mut s) = stats.try_lock() { - s.buffer_overruns += 1; - } let _ = pcm.prepare(); } else if desc.contains("No such device") || desc.contains("ENODEV") { // Device disconnected - use longer throttle for this error_throttled!(log_throttler, "no_device", "Audio read error: {}", e); - if let Ok(mut s) = stats.try_lock() { - s.frames_dropped += 1; - } } else { error_throttled!(log_throttler, "read_error", "Audio read error: {}", e); - if let Ok(mut s) = stats.try_lock() { - s.frames_dropped += 1; - } } } } diff --git a/src/audio/controller.rs b/src/audio/controller.rs index 9858a764..ea3621d0 100644 --- a/src/audio/controller.rs +++ b/src/audio/controller.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use std::sync::Arc; -use tokio::sync::{broadcast, RwLock}; +use tokio::sync::RwLock; use tracing::info; use super::capture::AudioConfig; @@ -104,10 +104,6 @@ pub struct AudioStatus { pub quality: AudioQuality, /// Number of connected subscribers pub subscriber_count: usize, - /// Frames encoded - pub frames_encoded: u64, - /// Bytes output - pub bytes_output: u64, /// Error message if any pub error: Option, } @@ -352,17 +348,11 @@ impl AudioController { let streaming = self.is_streaming().await; let error = self.last_error.read().await.clone(); - let (subscriber_count, frames_encoded, bytes_output) = - if let Some(ref streamer) = *self.streamer.read().await { - let stats = streamer.stats().await; - ( - stats.subscriber_count, - stats.frames_encoded, - stats.bytes_output, - ) - } else { - (0, 0, 0) - }; + let subscriber_count = if let Some(ref streamer) = *self.streamer.read().await { + streamer.stats().await.subscriber_count + } else { + 0 + }; AudioStatus { enabled: config.enabled, @@ -374,14 +364,12 @@ impl AudioController { }, quality: config.quality, subscriber_count, - frames_encoded, - bytes_output, error, } } /// Subscribe to Opus frames (for WebSocket clients) - pub fn subscribe_opus(&self) -> Option> { + pub fn subscribe_opus(&self) -> Option>>> { // Use try_read to avoid blocking - this is called from sync context sometimes if let Ok(guard) = self.streamer.try_read() { guard.as_ref().map(|s| s.subscribe_opus()) @@ -391,7 +379,9 @@ impl AudioController { } /// Subscribe to Opus frames (async version) - pub async fn subscribe_opus_async(&self) -> Option> { + pub async fn subscribe_opus_async( + &self, + ) -> Option>>> { self.streamer .read() .await diff --git a/src/audio/mod.rs b/src/audio/mod.rs index b6a3f9c6..829bef91 100644 --- a/src/audio/mod.rs +++ b/src/audio/mod.rs @@ -6,7 +6,6 @@ //! - Audio device enumeration //! - Audio streaming pipeline //! - High-level audio controller -//! - Shared audio pipeline for WebRTC multi-session support //! - Device health monitoring pub mod capture; @@ -14,7 +13,6 @@ pub mod controller; pub mod device; pub mod encoder; pub mod monitor; -pub mod shared_pipeline; pub mod streamer; pub use capture::{AudioCapturer, AudioConfig, AudioFrame}; @@ -22,7 +20,4 @@ pub use controller::{AudioController, AudioControllerConfig, AudioQuality, Audio pub use device::{enumerate_audio_devices, enumerate_audio_devices_with_current, AudioDeviceInfo}; pub use encoder::{OpusConfig, OpusEncoder, OpusFrame}; pub use monitor::{AudioHealthMonitor, AudioHealthStatus, AudioMonitorConfig}; -pub use shared_pipeline::{ - SharedAudioPipeline, SharedAudioPipelineConfig, SharedAudioPipelineStats, -}; pub use streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig}; diff --git a/src/audio/shared_pipeline.rs b/src/audio/shared_pipeline.rs deleted file mode 100644 index 0e2cab4b..00000000 --- a/src/audio/shared_pipeline.rs +++ /dev/null @@ -1,450 +0,0 @@ -//! Shared Audio Pipeline for WebRTC -//! -//! This module provides a shared audio encoding pipeline that can serve -//! multiple WebRTC sessions with a single encoder instance. -//! -//! # Architecture -//! -//! ```text -//! AudioCapturer (ALSA) -//! | -//! v (broadcast::Receiver) -//! SharedAudioPipeline (single Opus encoder) -//! | -//! v (broadcast::Sender) -//! ┌────┴────┬────────┬────────┐ -//! v v v v -//! Session1 Session2 Session3 ... -//! (RTP) (RTP) (RTP) (RTP) -//! ``` -//! -//! # Key Features -//! -//! - **Single encoder**: All sessions share one Opus encoder -//! - **Broadcast distribution**: Encoded frames are broadcast to all subscribers -//! - **Dynamic bitrate**: Bitrate can be changed at runtime -//! - **Statistics**: Tracks encoding performance metrics - -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::Arc; -use std::time::Instant; -use tokio::sync::{broadcast, Mutex, RwLock}; -use tracing::{debug, error, info, trace, warn}; - -use super::capture::AudioFrame; -use super::encoder::{OpusConfig, OpusEncoder, OpusFrame}; -use crate::error::{AppError, Result}; - -/// Shared audio pipeline configuration -#[derive(Debug, Clone)] -pub struct SharedAudioPipelineConfig { - /// Sample rate (must match audio capture) - pub sample_rate: u32, - /// Number of channels (1 or 2) - pub channels: u32, - /// Target bitrate in bps - pub bitrate: u32, - /// Opus application mode - pub application: OpusApplicationMode, - /// Enable forward error correction - pub fec: bool, - /// Broadcast channel capacity - pub channel_capacity: usize, -} - -impl Default for SharedAudioPipelineConfig { - fn default() -> Self { - Self { - sample_rate: 48000, - channels: 2, - bitrate: 64000, - application: OpusApplicationMode::Audio, - fec: true, - channel_capacity: 16, // Reduced from 64 for lower latency - } - } -} - -impl SharedAudioPipelineConfig { - /// Create config optimized for voice - pub fn voice() -> Self { - Self { - bitrate: 32000, - application: OpusApplicationMode::Voip, - ..Default::default() - } - } - - /// Create config optimized for music/high quality - pub fn high_quality() -> Self { - Self { - bitrate: 128000, - application: OpusApplicationMode::Audio, - ..Default::default() - } - } - - /// Convert to OpusConfig - pub fn to_opus_config(&self) -> OpusConfig { - OpusConfig { - sample_rate: self.sample_rate, - channels: self.channels, - bitrate: self.bitrate, - application: match self.application { - OpusApplicationMode::Voip => super::encoder::OpusApplication::Voip, - OpusApplicationMode::Audio => super::encoder::OpusApplication::Audio, - OpusApplicationMode::LowDelay => super::encoder::OpusApplication::LowDelay, - }, - fec: self.fec, - } - } -} - -/// Opus application mode -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum OpusApplicationMode { - /// Voice over IP - optimized for speech - Voip, - /// General audio - balanced quality - Audio, - /// Low delay mode - minimal latency - LowDelay, -} - -/// Shared audio pipeline statistics -#[derive(Debug, Clone, Default)] -pub struct SharedAudioPipelineStats { - /// Frames received from audio capture - pub frames_received: u64, - /// Frames successfully encoded - pub frames_encoded: u64, - /// Frames dropped (encode errors) - pub frames_dropped: u64, - /// Total bytes encoded - pub bytes_encoded: u64, - /// Number of active subscribers - pub subscribers: u64, - /// Average encode time in milliseconds - pub avg_encode_time_ms: f32, - /// Current bitrate in bps - pub current_bitrate: u32, - /// Pipeline running time in seconds - pub running_time_secs: f64, -} - -/// Shared Audio Pipeline -/// -/// Provides a single Opus encoder that serves multiple WebRTC sessions. -/// All sessions receive the same encoded audio stream via broadcast channel. -pub struct SharedAudioPipeline { - /// Configuration - config: RwLock, - /// Opus encoder (protected by mutex for encoding) - encoder: Mutex>, - /// Broadcast sender for encoded Opus frames - opus_tx: broadcast::Sender, - /// Running state - running: AtomicBool, - /// Statistics - stats: Mutex, - /// Start time for running time calculation - start_time: RwLock>, - /// Encode time accumulator for averaging - encode_time_sum_us: AtomicU64, - /// Encode count for averaging - encode_count: AtomicU64, - /// Stop signal (atomic for lock-free checking) - stop_flag: AtomicBool, - /// Encoding task handle - task_handle: Mutex>>, -} - -impl SharedAudioPipeline { - /// Create a new shared audio pipeline - pub fn new(config: SharedAudioPipelineConfig) -> Result> { - let (opus_tx, _) = broadcast::channel(config.channel_capacity); - - Ok(Arc::new(Self { - config: RwLock::new(config), - encoder: Mutex::new(None), - opus_tx, - running: AtomicBool::new(false), - stats: Mutex::new(SharedAudioPipelineStats::default()), - start_time: RwLock::new(None), - encode_time_sum_us: AtomicU64::new(0), - encode_count: AtomicU64::new(0), - stop_flag: AtomicBool::new(false), - task_handle: Mutex::new(None), - })) - } - - /// Create with default configuration - pub fn default_config() -> Result> { - Self::new(SharedAudioPipelineConfig::default()) - } - - /// Start the audio encoding pipeline - /// - /// # Arguments - /// * `audio_rx` - Receiver for raw audio frames from AudioCapturer - pub async fn start(self: &Arc, audio_rx: broadcast::Receiver) -> Result<()> { - if self.running.load(Ordering::SeqCst) { - return Ok(()); - } - - let config = self.config.read().await.clone(); - - info!( - "Starting shared audio pipeline: {}Hz {}ch {}bps", - config.sample_rate, config.channels, config.bitrate - ); - - // Create encoder - let opus_config = config.to_opus_config(); - let encoder = OpusEncoder::new(opus_config)?; - *self.encoder.lock().await = Some(encoder); - - // Reset stats - { - let mut stats = self.stats.lock().await; - *stats = SharedAudioPipelineStats::default(); - stats.current_bitrate = config.bitrate; - } - - // Reset counters - self.encode_time_sum_us.store(0, Ordering::SeqCst); - self.encode_count.store(0, Ordering::SeqCst); - *self.start_time.write().await = Some(Instant::now()); - self.stop_flag.store(false, Ordering::SeqCst); - - self.running.store(true, Ordering::SeqCst); - - // Start encoding task - let pipeline = self.clone(); - let handle = tokio::spawn(async move { - pipeline.encoding_task(audio_rx).await; - }); - - *self.task_handle.lock().await = Some(handle); - - info!("Shared audio pipeline started"); - Ok(()) - } - - /// Stop the audio encoding pipeline - pub fn stop(&self) { - if !self.running.load(Ordering::SeqCst) { - return; - } - - info!("Stopping shared audio pipeline"); - - // Signal stop (atomic, no lock needed) - self.stop_flag.store(true, Ordering::SeqCst); - - self.running.store(false, Ordering::SeqCst); - } - - /// Check if pipeline is running - pub fn is_running(&self) -> bool { - self.running.load(Ordering::SeqCst) - } - - /// Subscribe to encoded Opus frames - pub fn subscribe(&self) -> broadcast::Receiver { - self.opus_tx.subscribe() - } - - /// Get number of active subscribers - pub fn subscriber_count(&self) -> usize { - self.opus_tx.receiver_count() - } - - /// Get current statistics - pub async fn stats(&self) -> SharedAudioPipelineStats { - let mut stats = self.stats.lock().await.clone(); - stats.subscribers = self.subscriber_count() as u64; - - // Calculate average encode time - let count = self.encode_count.load(Ordering::SeqCst); - if count > 0 { - let sum_us = self.encode_time_sum_us.load(Ordering::SeqCst); - stats.avg_encode_time_ms = (sum_us as f64 / count as f64 / 1000.0) as f32; - } - - // Calculate running time - if let Some(start) = *self.start_time.read().await { - stats.running_time_secs = start.elapsed().as_secs_f64(); - } - - stats - } - - /// Set bitrate dynamically - pub async fn set_bitrate(&self, bitrate: u32) -> Result<()> { - // Update config - self.config.write().await.bitrate = bitrate; - - // Update encoder if running - if let Some(ref mut encoder) = *self.encoder.lock().await { - encoder.set_bitrate(bitrate)?; - } - - // Update stats - self.stats.lock().await.current_bitrate = bitrate; - - info!("Shared audio pipeline bitrate changed to {}bps", bitrate); - Ok(()) - } - - /// Update configuration (requires restart) - pub async fn update_config(&self, config: SharedAudioPipelineConfig) -> Result<()> { - if self.is_running() { - return Err(AppError::AudioError( - "Cannot update config while pipeline is running".to_string(), - )); - } - - *self.config.write().await = config; - Ok(()) - } - - /// Internal encoding task - async fn encoding_task(self: Arc, mut audio_rx: broadcast::Receiver) { - info!("Audio encoding task started"); - - loop { - // Check stop flag (atomic, no async lock needed) - if self.stop_flag.load(Ordering::Relaxed) { - break; - } - - // Receive audio frame with timeout - let recv_result = - tokio::time::timeout(std::time::Duration::from_secs(2), audio_rx.recv()).await; - - match recv_result { - Ok(Ok(audio_frame)) => { - // Update received count - { - let mut stats = self.stats.lock().await; - stats.frames_received += 1; - } - - // Encode frame - let encode_start = Instant::now(); - let encode_result = { - let mut encoder_guard = self.encoder.lock().await; - if let Some(ref mut encoder) = *encoder_guard { - Some(encoder.encode_frame(&audio_frame)) - } else { - None - } - }; - let encode_time = encode_start.elapsed(); - - // Update encode time stats - self.encode_time_sum_us - .fetch_add(encode_time.as_micros() as u64, Ordering::SeqCst); - self.encode_count.fetch_add(1, Ordering::SeqCst); - - match encode_result { - Some(Ok(opus_frame)) => { - // Update stats - { - let mut stats = self.stats.lock().await; - stats.frames_encoded += 1; - stats.bytes_encoded += opus_frame.data.len() as u64; - } - - // Broadcast to subscribers - if self.opus_tx.receiver_count() > 0 { - if let Err(e) = self.opus_tx.send(opus_frame) { - trace!("No audio subscribers: {}", e); - } - } - } - Some(Err(e)) => { - error!("Opus encode error: {}", e); - let mut stats = self.stats.lock().await; - stats.frames_dropped += 1; - } - None => { - warn!("Encoder not available"); - break; - } - } - } - Ok(Err(broadcast::error::RecvError::Closed)) => { - info!("Audio source channel closed"); - break; - } - Ok(Err(broadcast::error::RecvError::Lagged(n))) => { - warn!("Audio pipeline lagged by {} frames", n); - let mut stats = self.stats.lock().await; - stats.frames_dropped += n; - } - Err(_) => { - // Timeout - check if still running - if !self.running.load(Ordering::SeqCst) { - break; - } - debug!("Audio receive timeout, continuing..."); - } - } - } - - // Cleanup - self.running.store(false, Ordering::SeqCst); - *self.encoder.lock().await = None; - - let stats = self.stats().await; - info!( - "Audio encoding task ended: {} frames encoded, {} dropped, {:.1}s runtime", - stats.frames_encoded, stats.frames_dropped, stats.running_time_secs - ); - } -} - -impl Drop for SharedAudioPipeline { - fn drop(&mut self) { - self.stop(); - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_config_default() { - let config = SharedAudioPipelineConfig::default(); - assert_eq!(config.sample_rate, 48000); - assert_eq!(config.channels, 2); - assert_eq!(config.bitrate, 64000); - } - - #[test] - fn test_config_voice() { - let config = SharedAudioPipelineConfig::voice(); - assert_eq!(config.bitrate, 32000); - assert_eq!(config.application, OpusApplicationMode::Voip); - } - - #[test] - fn test_config_high_quality() { - let config = SharedAudioPipelineConfig::high_quality(); - assert_eq!(config.bitrate, 128000); - } - - #[tokio::test] - async fn test_pipeline_creation() { - let config = SharedAudioPipelineConfig::default(); - let pipeline = SharedAudioPipeline::new(config); - assert!(pipeline.is_ok()); - - let pipeline = pipeline.unwrap(); - assert!(!pipeline.is_running()); - assert_eq!(pipeline.subscriber_count(), 0); - } -} diff --git a/src/audio/streamer.rs b/src/audio/streamer.rs index 3e208b0b..0d843e05 100644 --- a/src/audio/streamer.rs +++ b/src/audio/streamer.rs @@ -7,7 +7,7 @@ use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; use std::time::Instant; use tokio::sync::{broadcast, watch, Mutex, RwLock}; -use tracing::{error, info, trace, warn}; +use tracing::{error, info, warn}; use super::capture::{AudioCapturer, AudioConfig, CaptureState}; use super::encoder::{OpusConfig, OpusEncoder, OpusFrame}; @@ -72,18 +72,9 @@ impl AudioStreamerConfig { /// Audio stream statistics #[derive(Debug, Clone, Default)] pub struct AudioStreamStats { - /// Frames captured from ALSA - pub frames_captured: u64, /// Frames encoded to Opus - pub frames_encoded: u64, - /// Total bytes output (Opus) - pub bytes_output: u64, - /// Current encoding bitrate - pub current_bitrate: u32, /// Number of active subscribers pub subscriber_count: usize, - /// Buffer overruns - pub buffer_overruns: u64, } /// Audio streamer @@ -95,7 +86,7 @@ pub struct AudioStreamer { state_rx: watch::Receiver, capturer: RwLock>>, encoder: Arc>>, - opus_tx: broadcast::Sender, + opus_tx: watch::Sender>>, stats: Arc>, sequence: AtomicU64, stream_start_time: RwLock>, @@ -111,7 +102,7 @@ impl AudioStreamer { /// Create a new audio streamer with specified configuration pub fn with_config(config: AudioStreamerConfig) -> Self { let (state_tx, state_rx) = watch::channel(AudioStreamState::Stopped); - let (opus_tx, _) = broadcast::channel(16); // Buffer size 16 for low latency + let (opus_tx, _opus_rx) = watch::channel(None); Self { config: RwLock::new(config), @@ -138,7 +129,7 @@ impl AudioStreamer { } /// Subscribe to Opus frames - pub fn subscribe_opus(&self) -> broadcast::Receiver { + pub fn subscribe_opus(&self) -> watch::Receiver>> { self.opus_tx.subscribe() } @@ -175,9 +166,6 @@ impl AudioStreamer { encoder.set_bitrate(bitrate)?; } - // Update stats - self.stats.lock().await.current_bitrate = bitrate; - info!("Audio bitrate changed to {}bps", bitrate); Ok(()) } @@ -216,7 +204,6 @@ impl AudioStreamer { { let mut stats = self.stats.lock().await; *stats = AudioStreamStats::default(); - stats.current_bitrate = config.opus.bitrate; } // Record start time @@ -227,12 +214,11 @@ impl AudioStreamer { let capturer_for_task = capturer.clone(); let encoder = self.encoder.clone(); let opus_tx = self.opus_tx.clone(); - let stats = self.stats.clone(); let state = self.state.clone(); let stop_flag = self.stop_flag.clone(); tokio::spawn(async move { - Self::stream_task(capturer_for_task, encoder, opus_tx, stats, state, stop_flag).await; + Self::stream_task(capturer_for_task, encoder, opus_tx, state, stop_flag).await; }); Ok(()) @@ -273,8 +259,7 @@ impl AudioStreamer { async fn stream_task( capturer: Arc, encoder: Arc>>, - opus_tx: broadcast::Sender, - stats: Arc>, + opus_tx: watch::Sender>>, state: watch::Sender, stop_flag: Arc, ) { @@ -302,12 +287,6 @@ impl AudioStreamer { match recv_result { Ok(Ok(audio_frame)) => { - // Update capture stats - { - let mut s = stats.lock().await; - s.frames_captured += 1; - } - // Encode to Opus let opus_result = { let mut enc_guard = encoder.lock().await; @@ -320,18 +299,9 @@ impl AudioStreamer { match opus_result { Some(Ok(opus_frame)) => { - // Update stats - { - let mut s = stats.lock().await; - s.frames_encoded += 1; - s.bytes_output += opus_frame.data.len() as u64; - } - - // Broadcast to subscribers + // Publish latest frame to subscribers if opus_tx.receiver_count() > 0 { - if let Err(e) = opus_tx.send(opus_frame) { - trace!("No audio subscribers: {}", e); - } + let _ = opus_tx.send(Some(Arc::new(opus_frame))); } } Some(Err(e)) => { @@ -349,8 +319,6 @@ impl AudioStreamer { } Ok(Err(broadcast::error::RecvError::Lagged(n))) => { warn!("Audio receiver lagged by {} frames", n); - let mut s = stats.lock().await; - s.buffer_overruns += n; } Err(_) => { // Timeout - check if still capturing diff --git a/src/auth/middleware.rs b/src/auth/middleware.rs index 5954b700..5bbbd2f0 100644 --- a/src/auth/middleware.rs +++ b/src/auth/middleware.rs @@ -2,20 +2,18 @@ use axum::{ extract::{Request, State}, http::StatusCode, middleware::Next, - response::Response, + response::{IntoResponse, Response}, + Json, }; use axum_extra::extract::CookieJar; use std::sync::Arc; +use crate::error::ErrorResponse; use crate::state::AppState; /// Session cookie name pub const SESSION_COOKIE: &str = "one_kvm_session"; -/// Auth layer for extracting session from request -#[derive(Clone)] -pub struct AuthLayer; - /// Extract session ID from request pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap) -> Option { // First try cookie @@ -42,20 +40,20 @@ pub async fn auth_middleware( mut request: Request, next: Next, ) -> Result { + let raw_path = request.uri().path(); + // When this middleware is mounted under /api, Axum strips the prefix for the inner router. + // Normalize the path so checks work whether it is mounted or not. + let path = raw_path.strip_prefix("/api").unwrap_or(raw_path); + // Check if system is initialized if !state.config.is_initialized() { - // Allow access to setup endpoints when not initialized - let path = request.uri().path(); - if path.starts_with("/api/setup") - || path == "/api/info" - || path.starts_with("/") && !path.starts_with("/api/") - { + // Allow only setup-related endpoints when not initialized + if is_setup_public_endpoint(path) { return Ok(next.run(request).await); } } // Public endpoints that don't require auth - let path = request.uri().path(); if is_public_endpoint(path) { return Ok(next.run(request).await); } @@ -69,28 +67,36 @@ pub async fn auth_middleware( request.extensions_mut().insert(session); return Ok(next.run(request).await); } + + let message = if state.is_session_revoked(&session_id).await { + "Logged in elsewhere" + } else { + "Session expired" + }; + return Ok(unauthorized_response(message)); } - Err(StatusCode::UNAUTHORIZED) + Ok(unauthorized_response("Not authenticated")) +} + +fn unauthorized_response(message: &str) -> Response { + let body = ErrorResponse { + success: false, + message: message.to_string(), + }; + (StatusCode::UNAUTHORIZED, Json(body)).into_response() } /// Check if endpoint is public (no auth required) fn is_public_endpoint(path: &str) -> bool { - // Note: paths here are relative to /api since middleware is applied before nest + // Note: paths here are relative to /api since middleware is applied within the nested router matches!( path, "/" | "/auth/login" - | "/info" | "/health" | "/setup" | "/setup/init" - // Also check with /api prefix for direct access - | "/api/auth/login" - | "/api/info" - | "/api/health" - | "/api/setup" - | "/api/setup/init" ) || path.starts_with("/assets/") || path.starts_with("/static/") || path.ends_with(".js") @@ -100,46 +106,10 @@ fn is_public_endpoint(path: &str) -> bool { || path.ends_with(".svg") } -/// Require authentication - returns 401 if not authenticated -pub async fn require_auth( - State(state): State>, - cookies: CookieJar, - request: Request, - next: Next, -) -> Result { - let session_id = extract_session_id(&cookies, request.headers()); - - if let Some(session_id) = session_id { - if let Ok(Some(_session)) = state.sessions.get(&session_id).await { - return Ok(next.run(request).await); - } - } - - Err(StatusCode::UNAUTHORIZED) -} - -/// Require admin privileges - returns 403 if not admin -pub async fn require_admin( - State(state): State>, - cookies: CookieJar, - request: Request, - next: Next, -) -> Result { - let session_id = extract_session_id(&cookies, request.headers()); - - if let Some(session_id) = session_id { - if let Ok(Some(session)) = state.sessions.get(&session_id).await { - // Get user and check admin status - if let Ok(Some(user)) = state.users.get(&session.user_id).await { - if user.is_admin { - return Ok(next.run(request).await); - } - // User is authenticated but not admin - return Err(StatusCode::FORBIDDEN); - } - } - } - - // Not authenticated at all - Err(StatusCode::UNAUTHORIZED) +/// Setup-only endpoints allowed before initialization. +fn is_setup_public_endpoint(path: &str) -> bool { + matches!( + path, + "/setup" | "/setup/init" | "/devices" | "/stream/codecs" + ) } diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 0e5147be..8d9ba479 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -3,7 +3,7 @@ mod password; mod session; mod user; -pub use middleware::{auth_middleware, require_admin, AuthLayer, SESSION_COOKIE}; +pub use middleware::{auth_middleware, SESSION_COOKIE}; pub use password::{hash_password, verify_password}; pub use session::{Session, SessionStore}; pub use user::{User, UserStore}; diff --git a/src/auth/session.rs b/src/auth/session.rs index 653a1443..9902acc8 100644 --- a/src/auth/session.rs +++ b/src/auth/session.rs @@ -116,6 +116,22 @@ impl SessionStore { Ok(result.rows_affected()) } + /// Delete all sessions + pub async fn delete_all(&self) -> Result { + let result = sqlx::query("DELETE FROM sessions") + .execute(&self.pool) + .await?; + Ok(result.rows_affected()) + } + + /// List all session IDs + pub async fn list_ids(&self) -> Result> { + let rows: Vec<(String,)> = sqlx::query_as("SELECT id FROM sessions") + .fetch_all(&self.pool) + .await?; + Ok(rows.into_iter().map(|(id,)| id).collect()) + } + /// Extend session expiration pub async fn extend(&self, session_id: &str) -> Result<()> { let new_expires = Utc::now() + self.default_ttl; diff --git a/src/auth/user.rs b/src/auth/user.rs index 8c68cb79..f731f52f 100644 --- a/src/auth/user.rs +++ b/src/auth/user.rs @@ -149,6 +149,33 @@ impl UserStore { Ok(()) } + /// Update username + pub async fn update_username(&self, user_id: &str, new_username: &str) -> Result<()> { + if let Some(existing) = self.get_by_username(new_username).await? { + if existing.id != user_id { + return Err(AppError::BadRequest(format!( + "Username '{}' already exists", + new_username + ))); + } + } + + let now = Utc::now(); + let result = + sqlx::query("UPDATE users SET username = ?1, updated_at = ?2 WHERE id = ?3") + .bind(new_username) + .bind(now.to_rfc3339()) + .bind(user_id) + .execute(&self.pool) + .await?; + + if result.rows_affected() == 0 { + return Err(AppError::NotFound("User not found".to_string())); + } + + Ok(()) + } + /// List all users pub async fn list(&self) -> Result> { let rows: Vec = sqlx::query_as( diff --git a/src/config/schema.rs b/src/config/schema.rs index 7fa53497..63abe268 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -61,6 +61,8 @@ impl Default for AppConfig { pub struct AuthConfig { /// Session timeout in seconds pub session_timeout_secs: u32, + /// Allow multiple concurrent web sessions (single-user mode) + pub single_user_allow_multiple_sessions: bool, /// Enable 2FA pub totp_enabled: bool, /// TOTP secret (encrypted) @@ -71,6 +73,7 @@ impl Default for AuthConfig { fn default() -> Self { Self { session_timeout_secs: 3600 * 24, // 24 hours + single_user_allow_multiple_sessions: false, totp_enabled: false, totp_secret: None, } @@ -156,6 +159,106 @@ impl Default for OtgDescriptorConfig { } } +/// OTG HID function profile +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum OtgHidProfile { + /// Full HID device set (keyboard + relative mouse + absolute mouse + consumer control) + Full, + /// Full HID device set without MSD + FullNoMsd, + /// Full HID device set without consumer control + FullNoConsumer, + /// Full HID device set without consumer control and MSD + FullNoConsumerNoMsd, + /// Legacy profile: only keyboard + LegacyKeyboard, + /// Legacy profile: only relative mouse + LegacyMouseRelative, + /// Custom function selection + Custom, +} + +impl Default for OtgHidProfile { + fn default() -> Self { + Self::Full + } +} + +/// OTG HID function selection (used when profile is Custom) +#[typeshare] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(default)] +pub struct OtgHidFunctions { + pub keyboard: bool, + pub mouse_relative: bool, + pub mouse_absolute: bool, + pub consumer: bool, +} + +impl OtgHidFunctions { + pub fn full() -> Self { + Self { + keyboard: true, + mouse_relative: true, + mouse_absolute: true, + consumer: true, + } + } + + pub fn full_no_consumer() -> Self { + Self { + keyboard: true, + mouse_relative: true, + mouse_absolute: true, + consumer: false, + } + } + + pub fn legacy_keyboard() -> Self { + Self { + keyboard: true, + mouse_relative: false, + mouse_absolute: false, + consumer: false, + } + } + + pub fn legacy_mouse_relative() -> Self { + Self { + keyboard: false, + mouse_relative: true, + mouse_absolute: false, + consumer: false, + } + } + + pub fn is_empty(&self) -> bool { + !self.keyboard && !self.mouse_relative && !self.mouse_absolute && !self.consumer + } +} + +impl Default for OtgHidFunctions { + fn default() -> Self { + Self::full() + } +} + +impl OtgHidProfile { + pub fn resolve_functions(&self, custom: &OtgHidFunctions) -> OtgHidFunctions { + match self { + Self::Full => OtgHidFunctions::full(), + Self::FullNoMsd => OtgHidFunctions::full(), + Self::FullNoConsumer => OtgHidFunctions::full_no_consumer(), + Self::FullNoConsumerNoMsd => OtgHidFunctions::full_no_consumer(), + Self::LegacyKeyboard => OtgHidFunctions::legacy_keyboard(), + Self::LegacyMouseRelative => OtgHidFunctions::legacy_mouse_relative(), + Self::Custom => custom.clone(), + } + } +} + /// HID configuration #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -172,6 +275,12 @@ pub struct HidConfig { /// OTG USB device descriptor configuration #[serde(default)] pub otg_descriptor: OtgDescriptorConfig, + /// OTG HID function profile + #[serde(default)] + pub otg_profile: OtgHidProfile, + /// OTG HID function selection (used when profile is Custom) + #[serde(default)] + pub otg_functions: OtgHidFunctions, /// CH9329 serial port pub ch9329_port: String, /// CH9329 baud rate @@ -188,6 +297,8 @@ impl Default for HidConfig { otg_mouse: "/dev/hidg1".to_string(), otg_udc: None, otg_descriptor: OtgDescriptorConfig::default(), + otg_profile: OtgHidProfile::default(), + otg_functions: OtgHidFunctions::default(), ch9329_port: "/dev/ttyUSB0".to_string(), ch9329_baudrate: 9600, mouse_absolute: true, @@ -195,6 +306,12 @@ impl Default for HidConfig { } } +impl HidConfig { + pub fn effective_otg_functions(&self) -> OtgHidFunctions { + self.otg_profile.resolve_functions(&self.otg_functions) + } +} + /// MSD configuration #[typeshare] #[derive(Debug, Clone, Serialize, Deserialize)] @@ -459,7 +576,9 @@ pub struct WebConfig { pub http_port: u16, /// HTTPS port pub https_port: u16, - /// Bind address + /// Bind addresses (preferred) + pub bind_addresses: Vec, + /// Bind address (legacy) pub bind_address: String, /// Enable HTTPS pub https_enabled: bool, @@ -474,6 +593,7 @@ impl Default for WebConfig { Self { http_port: 8080, https_port: 8443, + bind_addresses: Vec::new(), bind_address: "0.0.0.0".to_string(), https_enabled: false, ssl_cert_path: None, diff --git a/src/events/types.rs b/src/events/types.rs index 38835088..44a69b2d 100644 --- a/src/events/types.rs +++ b/src/events/types.rs @@ -222,6 +222,22 @@ pub enum SystemEvent { hardware: bool, }, + /// WebRTC ICE candidate (server -> client trickle) + #[serde(rename = "webrtc.ice_candidate")] + WebRTCIceCandidate { + /// WebRTC session ID + session_id: String, + /// ICE candidate data + candidate: crate::webrtc::signaling::IceCandidate, + }, + + /// WebRTC ICE gathering complete (server -> client) + #[serde(rename = "webrtc.ice_complete")] + WebRTCIceComplete { + /// WebRTC session ID + session_id: String, + }, + /// Stream statistics update (sent periodically for client stats) #[serde(rename = "stream.stats_update")] StreamStatsUpdate { @@ -539,6 +555,8 @@ impl SystemEvent { Self::StreamStatsUpdate { .. } => "stream.stats_update", Self::StreamModeChanged { .. } => "stream.mode_changed", Self::StreamModeReady { .. } => "stream.mode_ready", + Self::WebRTCIceCandidate { .. } => "webrtc.ice_candidate", + Self::WebRTCIceComplete { .. } => "webrtc.ice_complete", Self::HidStateChanged { .. } => "hid.state_changed", Self::HidBackendSwitching { .. } => "hid.backend_switching", Self::HidDeviceLost { .. } => "hid.device_lost", diff --git a/src/hid/ch9329.rs b/src/hid/ch9329.rs index 5423627c..0893a49e 100644 --- a/src/hid/ch9329.rs +++ b/src/hid/ch9329.rs @@ -395,6 +395,8 @@ pub struct Ch9329Backend { last_abs_x: AtomicU16, /// Last absolute mouse Y position (CH9329 coordinate: 0-4095) last_abs_y: AtomicU16, + /// Whether relative mouse mode is active (set by incoming events) + relative_mouse_active: AtomicBool, /// Consecutive error count error_count: AtomicU32, /// Whether a reset is in progress @@ -426,6 +428,7 @@ impl Ch9329Backend { address: DEFAULT_ADDR, last_abs_x: AtomicU16::new(0), last_abs_y: AtomicU16::new(0), + relative_mouse_active: AtomicBool::new(false), error_count: AtomicU32::new(0), reset_in_progress: AtomicBool::new(false), last_success: Mutex::new(None), @@ -1014,12 +1017,14 @@ impl HidBackend for Ch9329Backend { match event.event_type { MouseEventType::Move => { // Relative movement - send delta directly without inversion + self.relative_mouse_active.store(true, Ordering::Relaxed); let dx = event.x.clamp(-127, 127) as i8; let dy = event.y.clamp(-127, 127) as i8; self.send_mouse_relative(buttons, dx, dy, 0)?; } MouseEventType::MoveAbs => { // Absolute movement + self.relative_mouse_active.store(false, Ordering::Relaxed); // Frontend sends 0-32767 (HID standard), CH9329 expects 0-4095 let x = ((event.x.clamp(0, 32767) as u32) * CH9329_MOUSE_RESOLUTION / 32768) as u16; let y = ((event.y.clamp(0, 32767) as u32) * CH9329_MOUSE_RESOLUTION / 32768) as u16; @@ -1031,28 +1036,40 @@ impl HidBackend for Ch9329Backend { MouseEventType::Down => { if let Some(button) = event.button { let bit = button.to_hid_bit(); - let x = self.last_abs_x.load(Ordering::Relaxed); - let y = self.last_abs_y.load(Ordering::Relaxed); let new_buttons = self.mouse_buttons.fetch_or(bit, Ordering::Relaxed) | bit; trace!("Mouse down: {:?} buttons=0x{:02X}", button, new_buttons); - self.send_mouse_absolute(new_buttons, x, y, 0)?; + if self.relative_mouse_active.load(Ordering::Relaxed) { + self.send_mouse_relative(new_buttons, 0, 0, 0)?; + } else { + let x = self.last_abs_x.load(Ordering::Relaxed); + let y = self.last_abs_y.load(Ordering::Relaxed); + self.send_mouse_absolute(new_buttons, x, y, 0)?; + } } } MouseEventType::Up => { if let Some(button) = event.button { let bit = button.to_hid_bit(); - let x = self.last_abs_x.load(Ordering::Relaxed); - let y = self.last_abs_y.load(Ordering::Relaxed); let new_buttons = self.mouse_buttons.fetch_and(!bit, Ordering::Relaxed) & !bit; trace!("Mouse up: {:?} buttons=0x{:02X}", button, new_buttons); - self.send_mouse_absolute(new_buttons, x, y, 0)?; + if self.relative_mouse_active.load(Ordering::Relaxed) { + self.send_mouse_relative(new_buttons, 0, 0, 0)?; + } else { + let x = self.last_abs_x.load(Ordering::Relaxed); + let y = self.last_abs_y.load(Ordering::Relaxed); + self.send_mouse_absolute(new_buttons, x, y, 0)?; + } } } MouseEventType::Scroll => { - // Use absolute mouse for scroll with last position - let x = self.last_abs_x.load(Ordering::Relaxed); - let y = self.last_abs_y.load(Ordering::Relaxed); - self.send_mouse_absolute(buttons, x, y, event.scroll)?; + if self.relative_mouse_active.load(Ordering::Relaxed) { + self.send_mouse_relative(buttons, 0, 0, event.scroll)?; + } else { + // Use absolute mouse for scroll with last position + let x = self.last_abs_x.load(Ordering::Relaxed); + let y = self.last_abs_y.load(Ordering::Relaxed); + self.send_mouse_absolute(buttons, x, y, event.scroll)?; + } } } @@ -1073,6 +1090,7 @@ impl HidBackend for Ch9329Backend { self.mouse_buttons.store(0, Ordering::Relaxed); self.last_abs_x.store(0, Ordering::Relaxed); self.last_abs_y.store(0, Ordering::Relaxed); + self.relative_mouse_active.store(false, Ordering::Relaxed); self.send_mouse_absolute(0, 0, 0, 0)?; // Reset media keys diff --git a/src/hid/mod.rs b/src/hid/mod.rs index 46e4c45b..611bdea8 100644 --- a/src/hid/mod.rs +++ b/src/hid/mod.rs @@ -43,24 +43,52 @@ pub struct HidInfo { } use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use tokio::sync::RwLock; use tracing::{info, warn}; use crate::error::{AppError, Result}; use crate::otg::OtgService; +use tokio::sync::mpsc; +use tokio::sync::Mutex; +use tokio::task::JoinHandle; +use std::time::Duration; + +const HID_EVENT_QUEUE_CAPACITY: usize = 64; +const HID_EVENT_SEND_TIMEOUT_MS: u64 = 30; + +#[derive(Debug)] +enum HidEvent { + Keyboard(KeyboardEvent), + Mouse(MouseEvent), + Consumer(ConsumerEvent), + Reset, +} /// HID controller managing keyboard and mouse input pub struct HidController { /// OTG Service reference (only used when backend is OTG) otg_service: Option>, /// Active backend - backend: Arc>>>, + backend: Arc>>>, /// Backend type (mutable for reload) - backend_type: RwLock, + backend_type: Arc>, /// Event bus for broadcasting state changes (optional) events: tokio::sync::RwLock>>, /// Health monitor for error tracking and recovery monitor: Arc, + /// HID event queue sender (non-blocking) + hid_tx: mpsc::Sender, + /// HID event queue receiver (moved into worker on first start) + hid_rx: Mutex>>, + /// Coalesced mouse move (latest) + pending_move: Arc>>, + /// Pending move flag (fast path) + pending_move_flag: Arc, + /// Worker task handle + hid_worker: Mutex>>, + /// Backend availability fast flag + backend_available: AtomicBool, } impl HidController { @@ -68,12 +96,19 @@ impl HidController { /// /// For OTG backend, otg_service should be provided to support hot-reload pub fn new(backend_type: HidBackendType, otg_service: Option>) -> Self { + let (hid_tx, hid_rx) = mpsc::channel(HID_EVENT_QUEUE_CAPACITY); Self { otg_service, backend: Arc::new(RwLock::new(None)), - backend_type: RwLock::new(backend_type), + backend_type: Arc::new(RwLock::new(backend_type)), events: tokio::sync::RwLock::new(None), monitor: Arc::new(HidHealthMonitor::with_defaults()), + hid_tx, + hid_rx: Mutex::new(Some(hid_rx)), + pending_move: Arc::new(parking_lot::Mutex::new(None)), + pending_move_flag: Arc::new(AtomicBool::new(false)), + hid_worker: Mutex::new(None), + backend_available: AtomicBool::new(false), } } @@ -87,7 +122,7 @@ impl HidController { /// Initialize the HID backend pub async fn init(&self) -> Result<()> { let backend_type = self.backend_type.read().await.clone(); - let backend: Box = match backend_type { + let backend: Arc = match backend_type { HidBackendType::Otg => { // Request HID functions from OtgService let otg_service = self @@ -100,7 +135,7 @@ impl HidController { // Create OtgBackend from handles (no longer manages gadget itself) info!("Creating OTG HID backend from device paths"); - Box::new(otg::OtgBackend::from_handles(handles)?) + Arc::new(otg::OtgBackend::from_handles(handles)?) } HidBackendType::Ch9329 { ref port, @@ -110,7 +145,7 @@ impl HidController { "Initializing CH9329 HID backend on {} @ {} baud", port, baud_rate ); - Box::new(ch9329::Ch9329Backend::with_baud_rate(port, baud_rate)?) + Arc::new(ch9329::Ch9329Backend::with_baud_rate(port, baud_rate)?) } HidBackendType::None => { warn!("HID backend disabled"); @@ -120,6 +155,10 @@ impl HidController { backend.init().await?; *self.backend.write().await = Some(backend); + self.backend_available.store(true, Ordering::Release); + + // Start HID event worker (once) + self.start_event_worker().await; info!("HID backend initialized: {:?}", backend_type); Ok(()) @@ -131,6 +170,7 @@ impl HidController { // Close the backend *self.backend.write().await = None; + self.backend_available.store(false, Ordering::Release); // If OTG backend, notify OtgService to disable HID let backend_type = self.backend_type.read().await.clone(); @@ -147,125 +187,47 @@ impl HidController { /// Send keyboard event pub async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> { - let backend = self.backend.read().await; - match backend.as_ref() { - Some(b) => { - match b.send_keyboard(event).await { - Ok(_) => { - // Check if we were in an error state and now recovered - if self.monitor.is_error().await { - let backend_type = self.backend_type.read().await; - self.monitor.report_recovered(backend_type.name_str()).await; - } - Ok(()) - } - Err(e) => { - // Report error to monitor, but skip temporary EAGAIN retries - // - "eagain_retry": within threshold, just temporary busy - // - "eagain": exceeded threshold, report as error - if let AppError::HidError { - ref backend, - ref reason, - ref error_code, - } = e - { - if error_code != "eagain_retry" { - self.monitor - .report_error(backend, None, reason, error_code) - .await; - } - } - Err(e) - } - } - } - None => Err(AppError::BadRequest( + if !self.backend_available.load(Ordering::Acquire) { + return Err(AppError::BadRequest( "HID backend not available".to_string(), - )), + )); } + self.enqueue_event(HidEvent::Keyboard(event)).await } /// Send mouse event pub async fn send_mouse(&self, event: MouseEvent) -> Result<()> { - let backend = self.backend.read().await; - match backend.as_ref() { - Some(b) => { - match b.send_mouse(event).await { - Ok(_) => { - // Check if we were in an error state and now recovered - if self.monitor.is_error().await { - let backend_type = self.backend_type.read().await; - self.monitor.report_recovered(backend_type.name_str()).await; - } - Ok(()) - } - Err(e) => { - // Report error to monitor, but skip temporary EAGAIN retries - // - "eagain_retry": within threshold, just temporary busy - // - "eagain": exceeded threshold, report as error - if let AppError::HidError { - ref backend, - ref reason, - ref error_code, - } = e - { - if error_code != "eagain_retry" { - self.monitor - .report_error(backend, None, reason, error_code) - .await; - } - } - Err(e) - } - } - } - None => Err(AppError::BadRequest( + if !self.backend_available.load(Ordering::Acquire) { + return Err(AppError::BadRequest( "HID backend not available".to_string(), - )), + )); + } + + if matches!(event.event_type, MouseEventType::Move | MouseEventType::MoveAbs) { + // Best-effort: drop/merge move events if queue is full + self.enqueue_mouse_move(event) + } else { + self.enqueue_event(HidEvent::Mouse(event)).await } } /// Send consumer control event (multimedia keys) pub async fn send_consumer(&self, event: ConsumerEvent) -> Result<()> { - let backend = self.backend.read().await; - match backend.as_ref() { - Some(b) => match b.send_consumer(event).await { - Ok(_) => { - if self.monitor.is_error().await { - let backend_type = self.backend_type.read().await; - self.monitor.report_recovered(backend_type.name_str()).await; - } - Ok(()) - } - Err(e) => { - if let AppError::HidError { - ref backend, - ref reason, - ref error_code, - } = e - { - if error_code != "eagain_retry" { - self.monitor - .report_error(backend, None, reason, error_code) - .await; - } - } - Err(e) - } - }, - None => Err(AppError::BadRequest( + if !self.backend_available.load(Ordering::Acquire) { + return Err(AppError::BadRequest( "HID backend not available".to_string(), - )), + )); } + self.enqueue_event(HidEvent::Consumer(event)).await } /// Reset all keys (release all pressed keys) pub async fn reset(&self) -> Result<()> { - let backend = self.backend.read().await; - match backend.as_ref() { - Some(b) => b.reset().await, - None => Ok(()), + if !self.backend_available.load(Ordering::Acquire) { + return Ok(()); } + // Reset is important but best-effort; enqueue to avoid blocking + self.enqueue_event(HidEvent::Reset).await } /// Check if backend is available @@ -332,6 +294,7 @@ impl HidController { /// Reload the HID backend with new type pub async fn reload(&self, new_backend_type: HidBackendType) -> Result<()> { info!("Reloading HID backend: {:?}", new_backend_type); + self.backend_available.store(false, Ordering::Release); // Shutdown existing backend first if let Some(backend) = self.backend.write().await.take() { @@ -341,7 +304,7 @@ impl HidController { } // Create and initialize new backend - let new_backend: Option> = match new_backend_type { + let new_backend: Option> = match new_backend_type { HidBackendType::Otg => { info!("Initializing OTG HID backend"); @@ -362,11 +325,11 @@ impl HidController { // Create OtgBackend from handles match otg::OtgBackend::from_handles(handles) { Ok(backend) => { - let boxed: Box = Box::new(backend); - match boxed.init().await { + let backend = Arc::new(backend); + match backend.init().await { Ok(_) => { info!("OTG backend initialized successfully"); - Some(boxed) + Some(backend) } Err(e) => { warn!("Failed to initialize OTG backend: {}", e); @@ -407,9 +370,9 @@ impl HidController { ); match ch9329::Ch9329Backend::with_baud_rate(port, baud_rate) { Ok(b) => { - let boxed = Box::new(b); - match boxed.init().await { - Ok(_) => Some(boxed), + let backend = Arc::new(b); + match backend.init().await { + Ok(_) => Some(backend), Err(e) => { warn!("Failed to initialize CH9329 backend: {}", e); None @@ -432,6 +395,8 @@ impl HidController { if self.backend.read().await.is_some() { info!("HID backend reloaded successfully: {:?}", new_backend_type); + self.backend_available.store(true, Ordering::Release); + self.start_event_worker().await; // Update backend_type on success *self.backend_type.write().await = new_backend_type.clone(); @@ -452,6 +417,7 @@ impl HidController { Ok(()) } else { warn!("HID backend reload resulted in no active backend"); + self.backend_available.store(false, Ordering::Release); // Update backend_type even on failure (to reflect the attempted change) *self.backend_type.write().await = new_backend_type.clone(); @@ -477,6 +443,148 @@ impl HidController { events.publish(event); } } + + async fn start_event_worker(&self) { + let mut worker_guard = self.hid_worker.lock().await; + if worker_guard.is_some() { + return; + } + + let mut rx_guard = self.hid_rx.lock().await; + let rx = match rx_guard.take() { + Some(rx) => rx, + None => return, + }; + + let backend = self.backend.clone(); + let monitor = self.monitor.clone(); + let backend_type = self.backend_type.clone(); + let pending_move = self.pending_move.clone(); + let pending_move_flag = self.pending_move_flag.clone(); + + let handle = tokio::spawn(async move { + let mut rx = rx; + loop { + let event = match rx.recv().await { + Some(ev) => ev, + None => break, + }; + + process_hid_event( + event, + &backend, + &monitor, + &backend_type, + ) + .await; + + // After each event, flush latest move if pending + if pending_move_flag.swap(false, Ordering::AcqRel) { + let move_event = { pending_move.lock().take() }; + if let Some(move_event) = move_event { + process_hid_event( + HidEvent::Mouse(move_event), + &backend, + &monitor, + &backend_type, + ) + .await; + } + } + } + }); + + *worker_guard = Some(handle); + } + + fn enqueue_mouse_move(&self, event: MouseEvent) -> Result<()> { + match self.hid_tx.try_send(HidEvent::Mouse(event.clone())) { + Ok(_) => Ok(()), + Err(mpsc::error::TrySendError::Full(_)) => { + *self.pending_move.lock() = Some(event); + self.pending_move_flag.store(true, Ordering::Release); + Ok(()) + } + Err(mpsc::error::TrySendError::Closed(_)) => Err(AppError::BadRequest( + "HID event queue closed".to_string(), + )), + } + } + + async fn enqueue_event(&self, event: HidEvent) -> Result<()> { + match self.hid_tx.try_send(event) { + Ok(_) => Ok(()), + Err(mpsc::error::TrySendError::Full(ev)) => { + // For non-move events, wait briefly to avoid dropping critical input + let tx = self.hid_tx.clone(); + let send_result = + tokio::time::timeout(Duration::from_millis(HID_EVENT_SEND_TIMEOUT_MS), tx.send(ev)) + .await; + if send_result.is_ok() { + Ok(()) + } else { + warn!("HID event queue full, dropping event"); + Ok(()) + } + } + Err(mpsc::error::TrySendError::Closed(_)) => Err(AppError::BadRequest( + "HID event queue closed".to_string(), + )), + } + } +} + +async fn process_hid_event( + event: HidEvent, + backend: &Arc>>>, + monitor: &Arc, + backend_type: &Arc>, +) { + let backend_opt = backend.read().await.clone(); + let backend = match backend_opt { + Some(b) => b, + None => return, + }; + + let result = tokio::task::spawn_blocking(move || { + futures::executor::block_on(async move { + match event { + HidEvent::Keyboard(ev) => backend.send_keyboard(ev).await, + HidEvent::Mouse(ev) => backend.send_mouse(ev).await, + HidEvent::Consumer(ev) => backend.send_consumer(ev).await, + HidEvent::Reset => backend.reset().await, + } + }) + }) + .await; + + let result = match result { + Ok(r) => r, + Err(_) => return, + }; + + match result { + Ok(_) => { + if monitor.is_error().await { + let backend_type = backend_type.read().await; + monitor.report_recovered(backend_type.name_str()).await; + } + } + Err(e) => { + if let AppError::HidError { + ref backend, + ref reason, + ref error_code, + } = e + { + if error_code != "eagain_retry" { + monitor + .report_error(backend, None, reason, error_code) + .await; + } + } + } + } } impl Default for HidController { diff --git a/src/hid/otg.rs b/src/hid/otg.rs index 1c34f6e6..b21917d4 100644 --- a/src/hid/otg.rs +++ b/src/hid/otg.rs @@ -109,13 +109,13 @@ impl LedState { /// reopened on the next operation attempt. pub struct OtgBackend { /// Keyboard device path (/dev/hidg0) - keyboard_path: PathBuf, + keyboard_path: Option, /// Relative mouse device path (/dev/hidg1) - mouse_rel_path: PathBuf, + mouse_rel_path: Option, /// Absolute mouse device path (/dev/hidg2) - mouse_abs_path: PathBuf, + mouse_abs_path: Option, /// Consumer control device path (/dev/hidg3) - consumer_path: PathBuf, + consumer_path: Option, /// Keyboard device file keyboard_dev: Mutex>, /// Relative mouse device file @@ -145,7 +145,7 @@ pub struct OtgBackend { } /// Write timeout in milliseconds (same as JetKVM's hidWriteTimeout) -const HID_WRITE_TIMEOUT_MS: i32 = 500; +const HID_WRITE_TIMEOUT_MS: i32 = 20; impl OtgBackend { /// Create OTG backend from device paths provided by OtgService @@ -157,9 +157,7 @@ impl OtgBackend { keyboard_path: paths.keyboard, mouse_rel_path: paths.mouse_relative, mouse_abs_path: paths.mouse_absolute, - consumer_path: paths - .consumer - .unwrap_or_else(|| PathBuf::from("/dev/hidg3")), + consumer_path: paths.consumer, keyboard_dev: Mutex::new(None), mouse_rel_dev: Mutex::new(None), mouse_abs_dev: Mutex::new(None), @@ -300,13 +298,25 @@ impl OtgBackend { /// 2. If handle is None but path exists, reopen the device /// 3. Return whether the device is ready for I/O fn ensure_device(&self, device_type: DeviceType) -> Result<()> { - let (path, dev_mutex) = match device_type { + let (path_opt, dev_mutex) = match device_type { DeviceType::Keyboard => (&self.keyboard_path, &self.keyboard_dev), DeviceType::MouseRelative => (&self.mouse_rel_path, &self.mouse_rel_dev), DeviceType::MouseAbsolute => (&self.mouse_abs_path, &self.mouse_abs_dev), DeviceType::ConsumerControl => (&self.consumer_path, &self.consumer_dev), }; + let path = match path_opt { + Some(p) => p, + None => { + self.online.store(false, Ordering::Relaxed); + return Err(AppError::HidError { + backend: "otg".to_string(), + reason: "Device disabled".to_string(), + error_code: "disabled".to_string(), + }); + } + }; + // Check if device path exists if !path.exists() { // Close the device if open (device was removed) @@ -383,20 +393,40 @@ impl OtgBackend { /// Check if all HID device files exist pub fn check_devices_exist(&self) -> bool { - self.keyboard_path.exists() && self.mouse_rel_path.exists() && self.mouse_abs_path.exists() + self.keyboard_path + .as_ref() + .map_or(true, |p| p.exists()) + && self + .mouse_rel_path + .as_ref() + .map_or(true, |p| p.exists()) + && self + .mouse_abs_path + .as_ref() + .map_or(true, |p| p.exists()) + && self + .consumer_path + .as_ref() + .map_or(true, |p| p.exists()) } /// Get list of missing device paths pub fn get_missing_devices(&self) -> Vec { let mut missing = Vec::new(); - if !self.keyboard_path.exists() { - missing.push(self.keyboard_path.display().to_string()); + if let Some(ref path) = self.keyboard_path { + if !path.exists() { + missing.push(path.display().to_string()); + } } - if !self.mouse_rel_path.exists() { - missing.push(self.mouse_rel_path.display().to_string()); + if let Some(ref path) = self.mouse_rel_path { + if !path.exists() { + missing.push(path.display().to_string()); + } } - if !self.mouse_abs_path.exists() { - missing.push(self.mouse_abs_path.display().to_string()); + if let Some(ref path) = self.mouse_abs_path { + if !path.exists() { + missing.push(path.display().to_string()); + } } missing } @@ -407,6 +437,10 @@ impl OtgBackend { /// ESHUTDOWN errors by closing the device handle for later reconnection. /// Uses write_with_timeout to avoid blocking on busy devices. fn send_keyboard_report(&self, report: &KeyboardReport) -> Result<()> { + if self.keyboard_path.is_none() { + return Ok(()); + } + // Ensure device is ready self.ensure_device(DeviceType::Keyboard)?; @@ -472,6 +506,10 @@ impl OtgBackend { /// ESHUTDOWN errors by closing the device handle for later reconnection. /// Uses write_with_timeout to avoid blocking on busy devices. fn send_mouse_report_relative(&self, buttons: u8, dx: i8, dy: i8, wheel: i8) -> Result<()> { + if self.mouse_rel_path.is_none() { + return Ok(()); + } + // Ensure device is ready self.ensure_device(DeviceType::MouseRelative)?; @@ -534,6 +572,10 @@ impl OtgBackend { /// ESHUTDOWN errors by closing the device handle for later reconnection. /// Uses write_with_timeout to avoid blocking on busy devices. fn send_mouse_report_absolute(&self, buttons: u8, x: u16, y: u16, wheel: i8) -> Result<()> { + if self.mouse_abs_path.is_none() { + return Ok(()); + } + // Ensure device is ready self.ensure_device(DeviceType::MouseAbsolute)?; @@ -600,6 +642,10 @@ impl OtgBackend { /// /// Sends a consumer control usage code and then releases it (sends 0x0000). fn send_consumer_report(&self, usage: u16) -> Result<()> { + if self.consumer_path.is_none() { + return Ok(()); + } + // Ensure device is ready self.ensure_device(DeviceType::ConsumerControl)?; @@ -708,71 +754,72 @@ impl HidBackend for OtgBackend { } // Wait for devices to appear (they should already exist from OtgService) - let device_paths = vec![ - self.keyboard_path.clone(), - self.mouse_rel_path.clone(), - self.mouse_abs_path.clone(), - ]; + let mut device_paths = Vec::new(); + if let Some(ref path) = self.keyboard_path { + device_paths.push(path.clone()); + } + if let Some(ref path) = self.mouse_rel_path { + device_paths.push(path.clone()); + } + if let Some(ref path) = self.mouse_abs_path { + device_paths.push(path.clone()); + } + if let Some(ref path) = self.consumer_path { + device_paths.push(path.clone()); + } + + if device_paths.is_empty() { + return Err(AppError::Internal( + "No HID devices configured for OTG backend".into(), + )); + } if !wait_for_hid_devices(&device_paths, 2000).await { return Err(AppError::Internal("HID devices did not appear".into())); } // Open keyboard device - if self.keyboard_path.exists() { - let file = Self::open_device(&self.keyboard_path)?; - *self.keyboard_dev.lock() = Some(file); - info!("Keyboard device opened: {}", self.keyboard_path.display()); - } else { - warn!( - "Keyboard device not found: {}", - self.keyboard_path.display() - ); + if let Some(ref path) = self.keyboard_path { + if path.exists() { + let file = Self::open_device(path)?; + *self.keyboard_dev.lock() = Some(file); + info!("Keyboard device opened: {}", path.display()); + } else { + warn!("Keyboard device not found: {}", path.display()); + } } // Open relative mouse device - if self.mouse_rel_path.exists() { - let file = Self::open_device(&self.mouse_rel_path)?; - *self.mouse_rel_dev.lock() = Some(file); - info!( - "Relative mouse device opened: {}", - self.mouse_rel_path.display() - ); - } else { - warn!( - "Relative mouse device not found: {}", - self.mouse_rel_path.display() - ); + if let Some(ref path) = self.mouse_rel_path { + if path.exists() { + let file = Self::open_device(path)?; + *self.mouse_rel_dev.lock() = Some(file); + info!("Relative mouse device opened: {}", path.display()); + } else { + warn!("Relative mouse device not found: {}", path.display()); + } } // Open absolute mouse device - if self.mouse_abs_path.exists() { - let file = Self::open_device(&self.mouse_abs_path)?; - *self.mouse_abs_dev.lock() = Some(file); - info!( - "Absolute mouse device opened: {}", - self.mouse_abs_path.display() - ); - } else { - warn!( - "Absolute mouse device not found: {}", - self.mouse_abs_path.display() - ); + if let Some(ref path) = self.mouse_abs_path { + if path.exists() { + let file = Self::open_device(path)?; + *self.mouse_abs_dev.lock() = Some(file); + info!("Absolute mouse device opened: {}", path.display()); + } else { + warn!("Absolute mouse device not found: {}", path.display()); + } } // Open consumer control device (optional, may not exist on older setups) - if self.consumer_path.exists() { - let file = Self::open_device(&self.consumer_path)?; - *self.consumer_dev.lock() = Some(file); - info!( - "Consumer control device opened: {}", - self.consumer_path.display() - ); - } else { - debug!( - "Consumer control device not found: {}", - self.consumer_path.display() - ); + if let Some(ref path) = self.consumer_path { + if path.exists() { + let file = Self::open_device(path)?; + *self.consumer_dev.lock() = Some(file); + info!("Consumer control device opened: {}", path.display()); + } else { + debug!("Consumer control device not found: {}", path.display()); + } } // Mark as online if all devices opened successfully @@ -905,7 +952,9 @@ impl HidBackend for OtgBackend { } fn supports_absolute_mouse(&self) -> bool { - self.mouse_abs_path.exists() + self.mouse_abs_path + .as_ref() + .map_or(false, |p| p.exists()) } async fn send_consumer(&self, event: ConsumerEvent) -> Result<()> { @@ -928,7 +977,7 @@ pub fn is_otg_available() -> bool { let mouse_rel = PathBuf::from("/dev/hidg1"); let mouse_abs = PathBuf::from("/dev/hidg2"); - kb.exists() && mouse_rel.exists() && mouse_abs.exists() + kb.exists() || mouse_rel.exists() || mouse_abs.exists() } /// Implement Drop for OtgBackend to close device files diff --git a/src/main.rs b/src/main.rs index 9c2d93ea..d360ae80 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,11 @@ -use std::net::SocketAddr; +use std::collections::HashSet; +use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; use std::sync::Arc; use axum_server::tls_rustls::RustlsConfig; use clap::{Parser, ValueEnum}; +use futures::{stream::FuturesUnordered, StreamExt}; use rustls::crypto::{ring, CryptoProvider}; use tokio::sync::broadcast; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -16,9 +18,10 @@ use one_kvm::events::EventBus; use one_kvm::extensions::ExtensionManager; use one_kvm::hid::{HidBackendType, HidController}; use one_kvm::msd::MsdController; -use one_kvm::otg::OtgService; +use one_kvm::otg::{configfs, OtgService}; use one_kvm::rustdesk::RustDeskService; use one_kvm::state::AppState; +use one_kvm::utils::bind_tcp_listener; use one_kvm::video::format::{PixelFormat, Resolution}; use one_kvm::video::{Streamer, VideoStreamManager}; use one_kvm::web; @@ -134,7 +137,8 @@ async fn main() -> anyhow::Result<()> { // Apply CLI argument overrides to config (only if explicitly specified) if let Some(addr) = args.address { - config.web.bind_address = addr; + config.web.bind_address = addr.clone(); + config.web.bind_addresses = vec![addr]; } if let Some(port) = args.http_port { config.web.http_port = port; @@ -153,19 +157,18 @@ async fn main() -> anyhow::Result<()> { config.web.ssl_key_path = Some(key_path.to_string_lossy().to_string()); } - // Log final configuration - if config.web.https_enabled { - tracing::info!( - "Server will listen on: https://{}:{}", - config.web.bind_address, - config.web.https_port - ); + let bind_ips = resolve_bind_addresses(&config.web)?; + let scheme = if config.web.https_enabled { "https" } else { "http" }; + let bind_port = if config.web.https_enabled { + config.web.https_port } else { - tracing::info!( - "Server will listen on: http://{}:{}", - config.web.bind_address, - config.web.http_port - ); + config.web.http_port + }; + + // Log final configuration + for ip in &bind_ips { + let addr = SocketAddr::new(*ip, bind_port); + tracing::info!("Server will listen on: {}://{}", scheme, addr); } // Initialize session store @@ -309,11 +312,21 @@ async fn main() -> anyhow::Result<()> { // Pre-enable OTG functions to avoid gadget recreation (prevents kernel crashes) let will_use_otg_hid = matches!(config.hid.backend, config::HidBackend::Otg); - let will_use_msd = config.msd.enabled || will_use_otg_hid; + let will_use_msd = config.msd.enabled; if will_use_otg_hid { - if !config.msd.enabled { - tracing::info!("OTG HID enabled, automatically enabling MSD functionality"); + let mut hid_functions = config.hid.effective_otg_functions(); + if let Some(udc) = configfs::resolve_udc_name(config.hid.otg_udc.as_deref()) { + if configfs::is_low_endpoint_udc(&udc) && hid_functions.consumer { + tracing::warn!( + "UDC {} has low endpoint resources, disabling consumer control", + udc + ); + hid_functions.consumer = false; + } + } + if let Err(e) = otg_service.update_hid_functions(hid_functions).await { + tracing::warn!("Failed to apply HID functions: {}", e); } if let Err(e) = otg_service.enable_hid().await { tracing::warn!("Failed to pre-enable HID: {}", e); @@ -448,27 +461,26 @@ async fn main() -> anyhow::Result<()> { } } - // Set up frame source from video streamer (if capturer is available) - // The frame source allows WebRTC sessions to receive live video frames - if let Some(frame_tx) = streamer.frame_sender().await { - // Synchronize WebRTC config with actual capture format before connecting - let (actual_format, actual_resolution, actual_fps) = streamer.current_video_config().await; - tracing::info!( - "Initial video config from capturer: {}x{} {:?} @ {}fps", - actual_resolution.width, - actual_resolution.height, - actual_format, - actual_fps - ); + // Configure direct capture for WebRTC encoder pipeline + let (device_path, actual_resolution, actual_format, actual_fps, jpeg_quality) = + streamer.current_capture_config().await; + tracing::info!( + "Initial video config: {}x{} {:?} @ {}fps", + actual_resolution.width, + actual_resolution.height, + actual_format, + actual_fps + ); + webrtc_streamer + .update_video_config(actual_resolution, actual_format, actual_fps) + .await; + if let Some(device_path) = device_path { webrtc_streamer - .update_video_config(actual_resolution, actual_format, actual_fps) + .set_capture_device(device_path, jpeg_quality) .await; - webrtc_streamer.set_video_source(frame_tx).await; - tracing::info!("WebRTC streamer connected to video frame source"); + tracing::info!("WebRTC streamer configured for direct capture"); } else { - tracing::warn!( - "Video capturer not ready, WebRTC will connect to frame source when available" - ); + tracing::warn!("No capture device configured for WebRTC"); } // Create video stream manager (unified MJPEG/WebRTC management) @@ -589,12 +601,8 @@ async fn main() -> anyhow::Result<()> { // Create router let app = web::create_router(state.clone()); - // Determine bind address based on HTTPS setting - let bind_addr: SocketAddr = if config.web.https_enabled { - format!("{}:{}", config.web.bind_address, config.web.https_port).parse()? - } else { - format!("{}:{}", config.web.bind_address, config.web.http_port).parse()? - }; + // Bind sockets for configured addresses + let listeners = bind_tcp_listeners(&bind_ips, bind_port)?; // Setup graceful shutdown let shutdown_signal = async move { @@ -631,33 +639,44 @@ async fn main() -> anyhow::Result<()> { RustlsConfig::from_pem_file(&cert_path, &key_path).await? }; - tracing::info!("Starting HTTPS server on {}", bind_addr); + let mut servers = FuturesUnordered::new(); + for listener in listeners { + let local_addr = listener.local_addr()?; + tracing::info!("Starting HTTPS server on {}", local_addr); - let server = axum_server::bind_rustls(bind_addr, tls_config).serve(app.into_make_service()); + let server = axum_server::from_tcp_rustls(listener, tls_config.clone())? + .serve(app.clone().into_make_service()); + servers.push(async move { server.await }); + } tokio::select! { _ = shutdown_signal => { cleanup(&state).await; } - result = server => { - if let Err(e) = result { + result = servers.next() => { + if let Some(Err(e)) = result { tracing::error!("HTTPS server error: {}", e); } cleanup(&state).await; } } } else { - tracing::info!("Starting HTTP server on {}", bind_addr); + let mut servers = FuturesUnordered::new(); + for listener in listeners { + let local_addr = listener.local_addr()?; + tracing::info!("Starting HTTP server on {}", local_addr); - let listener = tokio::net::TcpListener::bind(bind_addr).await?; - let server = axum::serve(listener, app); + let listener = tokio::net::TcpListener::from_std(listener)?; + let server = axum::serve(listener, app.clone()); + servers.push(async move { server.await }); + } tokio::select! { _ = shutdown_signal => { cleanup(&state).await; } - result = server => { - if let Err(e) = result { + result = servers.next() => { + if let Some(Err(e)) = result { tracing::error!("HTTP server error: {}", e); } cleanup(&state).await; @@ -710,6 +729,47 @@ fn get_data_dir() -> PathBuf { PathBuf::from("/etc/one-kvm") } +/// Resolve bind IPs from config, preferring bind_addresses when set. +fn resolve_bind_addresses(web: &config::WebConfig) -> anyhow::Result> { + let raw_addrs = if !web.bind_addresses.is_empty() { + web.bind_addresses.as_slice() + } else { + std::slice::from_ref(&web.bind_address) + }; + + let mut seen = HashSet::new(); + let mut addrs = Vec::new(); + for addr in raw_addrs { + let ip: IpAddr = addr + .parse() + .map_err(|_| anyhow::anyhow!("Invalid bind address: {}", addr))?; + if seen.insert(ip) { + addrs.push(ip); + } + } + + Ok(addrs) +} + +fn bind_tcp_listeners(addrs: &[IpAddr], port: u16) -> anyhow::Result> { + let mut listeners = Vec::new(); + for ip in addrs { + let addr = SocketAddr::new(*ip, port); + match bind_tcp_listener(addr) { + Ok(listener) => listeners.push(listener), + Err(err) => { + tracing::warn!("Failed to bind {}: {}", addr, err); + } + } + } + + if listeners.is_empty() { + anyhow::bail!("Failed to bind any addresses on port {}", port); + } + + Ok(listeners) +} + /// Parse video format and resolution from config (avoids code duplication) fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) { let format = config diff --git a/src/otg/configfs.rs b/src/otg/configfs.rs index 5e7fd493..bacdd50b 100644 --- a/src/otg/configfs.rs +++ b/src/otg/configfs.rs @@ -43,6 +43,23 @@ pub fn find_udc() -> Option { .next() } +/// Check if UDC is known to have low endpoint resources +pub fn is_low_endpoint_udc(name: &str) -> bool { + let name = name.to_ascii_lowercase(); + name.contains("musb") || name.contains("musb-hdrc") +} + +/// Resolve preferred UDC name if available, otherwise auto-detect +pub fn resolve_udc_name(preferred: Option<&str>) -> Option { + if let Some(name) = preferred { + let path = Path::new("/sys/class/udc").join(name); + if path.exists() { + return Some(name.to_string()); + } + } + find_udc() +} + /// Write string content to a file /// /// For sysfs files, this function appends a newline and flushes diff --git a/src/otg/manager.rs b/src/otg/manager.rs index 7e50949d..64773d2b 100644 --- a/src/otg/manager.rs +++ b/src/otg/manager.rs @@ -6,9 +6,9 @@ use std::path::PathBuf; use tracing::{debug, error, info, warn}; use super::configfs::{ - create_dir, find_udc, is_configfs_available, remove_dir, write_file, CONFIGFS_PATH, - DEFAULT_GADGET_NAME, DEFAULT_USB_BCD_DEVICE, DEFAULT_USB_PRODUCT_ID, DEFAULT_USB_VENDOR_ID, - USB_BCD_USB, + create_dir, create_symlink, find_udc, is_configfs_available, remove_dir, remove_file, + write_file, CONFIGFS_PATH, DEFAULT_GADGET_NAME, DEFAULT_USB_BCD_DEVICE, DEFAULT_USB_PRODUCT_ID, + DEFAULT_USB_VENDOR_ID, USB_BCD_USB, }; use super::endpoint::{EndpointAllocator, DEFAULT_MAX_ENDPOINTS}; use super::function::{FunctionMeta, GadgetFunction}; @@ -16,6 +16,8 @@ use super::hid::HidFunction; use super::msd::MsdFunction; use crate::error::{AppError, Result}; +const REBIND_DELAY_MS: u64 = 300; + /// USB Gadget device descriptor configuration #[derive(Debug, Clone)] pub struct GadgetDescriptor { @@ -249,9 +251,15 @@ impl OtgGadgetManager { AppError::Internal("No USB Device Controller (UDC) found".to_string()) })?; + // Recreate config symlinks before binding to avoid kernel gadget issues after rebind + if let Err(e) = self.recreate_config_links() { + warn!("Failed to recreate gadget config links before bind: {}", e); + } + info!("Binding gadget to UDC: {}", udc); write_file(&self.gadget_path.join("UDC"), &udc)?; self.bound_udc = Some(udc); + std::thread::sleep(std::time::Duration::from_millis(REBIND_DELAY_MS)); Ok(()) } @@ -262,6 +270,7 @@ impl OtgGadgetManager { write_file(&self.gadget_path.join("UDC"), "")?; self.bound_udc = None; info!("Unbound gadget from UDC"); + std::thread::sleep(std::time::Duration::from_millis(REBIND_DELAY_MS)); } Ok(()) } @@ -382,6 +391,47 @@ impl OtgGadgetManager { pub fn gadget_path(&self) -> &PathBuf { &self.gadget_path } + + /// Recreate config symlinks from functions directory + fn recreate_config_links(&self) -> Result<()> { + let functions_path = self.gadget_path.join("functions"); + if !functions_path.exists() || !self.config_path.exists() { + return Ok(()); + } + + let entries = std::fs::read_dir(&functions_path).map_err(|e| { + AppError::Internal(format!( + "Failed to read functions directory {}: {}", + functions_path.display(), + e + )) + })?; + + for entry in entries.flatten() { + let name = entry.file_name(); + let name = match name.to_str() { + Some(n) => n, + None => continue, + }; + if !name.contains(".usb") { + continue; + } + + let src = functions_path.join(name); + let dest = self.config_path.join(name); + + if dest.exists() { + if let Err(e) = remove_file(&dest) { + warn!("Failed to remove existing config link {}: {}", dest.display(), e); + continue; + } + } + + create_symlink(&src, &dest)?; + } + + Ok(()) + } } impl Default for OtgGadgetManager { diff --git a/src/otg/service.rs b/src/otg/service.rs index 38755bc8..2be3c944 100644 --- a/src/otg/service.rs +++ b/src/otg/service.rs @@ -27,7 +27,7 @@ use tracing::{debug, info, warn}; use super::manager::{wait_for_hid_devices, GadgetDescriptor, OtgGadgetManager}; use super::msd::MsdFunction; -use crate::config::OtgDescriptorConfig; +use crate::config::{OtgDescriptorConfig, OtgHidFunctions}; use crate::error::{AppError, Result}; /// Bitflags for requested functions (lock-free) @@ -37,23 +37,42 @@ const FLAG_MSD: u8 = 0b10; /// HID device paths #[derive(Debug, Clone)] pub struct HidDevicePaths { - pub keyboard: PathBuf, - pub mouse_relative: PathBuf, - pub mouse_absolute: PathBuf, + pub keyboard: Option, + pub mouse_relative: Option, + pub mouse_absolute: Option, pub consumer: Option, } impl Default for HidDevicePaths { fn default() -> Self { Self { - keyboard: PathBuf::from("/dev/hidg0"), - mouse_relative: PathBuf::from("/dev/hidg1"), - mouse_absolute: PathBuf::from("/dev/hidg2"), - consumer: Some(PathBuf::from("/dev/hidg3")), + keyboard: None, + mouse_relative: None, + mouse_absolute: None, + consumer: None, } } } +impl HidDevicePaths { + pub fn existing_paths(&self) -> Vec { + let mut paths = Vec::new(); + if let Some(ref p) = self.keyboard { + paths.push(p.clone()); + } + if let Some(ref p) = self.mouse_relative { + paths.push(p.clone()); + } + if let Some(ref p) = self.mouse_absolute { + paths.push(p.clone()); + } + if let Some(ref p) = self.consumer { + paths.push(p.clone()); + } + paths + } +} + /// OTG Service state #[derive(Debug, Clone, Default)] pub struct OtgServiceState { @@ -65,6 +84,8 @@ pub struct OtgServiceState { pub msd_enabled: bool, /// HID device paths (set after gadget setup) pub hid_paths: Option, + /// HID function selection (set after gadget setup) + pub hid_functions: Option, /// Error message if setup failed pub error: Option, } @@ -83,6 +104,8 @@ pub struct OtgService { msd_function: RwLock>, /// Requested functions flags (atomic, lock-free read/write) requested_flags: AtomicU8, + /// Requested HID function set + hid_functions: RwLock, /// Current descriptor configuration current_descriptor: RwLock, } @@ -95,6 +118,7 @@ impl OtgService { state: RwLock::new(OtgServiceState::default()), msd_function: RwLock::new(None), requested_flags: AtomicU8::new(0), + hid_functions: RwLock::new(OtgHidFunctions::default()), current_descriptor: RwLock::new(GadgetDescriptor::default()), } } @@ -167,6 +191,35 @@ impl OtgService { self.state.read().await.hid_paths.clone() } + /// Get current HID function selection + pub async fn hid_functions(&self) -> OtgHidFunctions { + self.hid_functions.read().await.clone() + } + + /// Update HID function selection + pub async fn update_hid_functions(&self, functions: OtgHidFunctions) -> Result<()> { + if functions.is_empty() { + return Err(AppError::BadRequest( + "OTG HID functions cannot be empty".to_string(), + )); + } + + { + let mut current = self.hid_functions.write().await; + if *current == functions { + return Ok(()); + } + *current = functions; + } + + // If HID is active, recreate gadget with new function set + if self.is_hid_requested() { + self.recreate_gadget().await?; + } + + Ok(()) + } + /// Get MSD function handle (for LUN configuration) pub async fn msd_function(&self) -> Option { self.msd_function.read().await.clone() @@ -182,13 +235,16 @@ impl OtgService { // Mark HID as requested (lock-free) self.set_hid_requested(true); - // Check if already enabled + // Check if already enabled and function set unchanged + let requested_functions = self.hid_functions.read().await.clone(); { let state = self.state.read().await; if state.hid_enabled { - if let Some(ref paths) = state.hid_paths { - info!("HID already enabled, returning existing paths"); - return Ok(paths.clone()); + if state.hid_functions.as_ref() == Some(&requested_functions) { + if let Some(ref paths) = state.hid_paths { + info!("HID already enabled, returning existing paths"); + return Ok(paths.clone()); + } } } } @@ -294,6 +350,11 @@ impl OtgService { // Read requested flags atomically (lock-free) let hid_requested = self.is_hid_requested(); let msd_requested = self.is_msd_requested(); + let hid_functions = if hid_requested { + self.hid_functions.read().await.clone() + } else { + OtgHidFunctions::default() + }; info!( "Recreating gadget with: HID={}, MSD={}", @@ -303,9 +364,15 @@ impl OtgService { // Check if gadget already matches requested state { let state = self.state.read().await; + let functions_match = if hid_requested { + state.hid_functions.as_ref() == Some(&hid_functions) + } else { + state.hid_functions.is_none() + }; if state.gadget_active && state.hid_enabled == hid_requested && state.msd_enabled == msd_requested + && functions_match { info!("Gadget already has requested functions, skipping recreate"); return Ok(()); @@ -333,6 +400,7 @@ impl OtgService { state.hid_enabled = false; state.msd_enabled = false; state.hid_paths = None; + state.hid_functions = None; state.error = None; } @@ -361,28 +429,65 @@ impl OtgService { // Add HID functions if requested if hid_requested { - match ( - manager.add_keyboard(), - manager.add_mouse_relative(), - manager.add_mouse_absolute(), - manager.add_consumer_control(), - ) { - (Ok(kb), Ok(rel), Ok(abs), Ok(consumer)) => { - hid_paths = Some(HidDevicePaths { - keyboard: kb, - mouse_relative: rel, - mouse_absolute: abs, - consumer: Some(consumer), - }); - debug!("HID functions added to gadget"); - } - (Err(e), _, _, _) | (_, Err(e), _, _) | (_, _, Err(e), _) | (_, _, _, Err(e)) => { - let error = format!("Failed to add HID functions: {}", e); - let mut state = self.state.write().await; - state.error = Some(error.clone()); - return Err(AppError::Internal(error)); + if hid_functions.is_empty() { + let error = "HID functions set is empty".to_string(); + let mut state = self.state.write().await; + state.error = Some(error.clone()); + return Err(AppError::BadRequest(error)); + } + + let mut paths = HidDevicePaths::default(); + + if hid_functions.keyboard { + match manager.add_keyboard() { + Ok(kb) => paths.keyboard = Some(kb), + Err(e) => { + let error = format!("Failed to add keyboard HID function: {}", e); + let mut state = self.state.write().await; + state.error = Some(error.clone()); + return Err(AppError::Internal(error)); + } } } + + if hid_functions.mouse_relative { + match manager.add_mouse_relative() { + Ok(rel) => paths.mouse_relative = Some(rel), + Err(e) => { + let error = format!("Failed to add relative mouse HID function: {}", e); + let mut state = self.state.write().await; + state.error = Some(error.clone()); + return Err(AppError::Internal(error)); + } + } + } + + if hid_functions.mouse_absolute { + match manager.add_mouse_absolute() { + Ok(abs) => paths.mouse_absolute = Some(abs), + Err(e) => { + let error = format!("Failed to add absolute mouse HID function: {}", e); + let mut state = self.state.write().await; + state.error = Some(error.clone()); + return Err(AppError::Internal(error)); + } + } + } + + if hid_functions.consumer { + match manager.add_consumer_control() { + Ok(consumer) => paths.consumer = Some(consumer), + Err(e) => { + let error = format!("Failed to add consumer HID function: {}", e); + let mut state = self.state.write().await; + state.error = Some(error.clone()); + return Err(AppError::Internal(error)); + } + } + } + + hid_paths = Some(paths); + debug!("HID functions added to gadget"); } // Add MSD function if requested @@ -423,12 +528,8 @@ impl OtgService { // Wait for HID devices to appear if let Some(ref paths) = hid_paths { - let device_paths = vec![ - paths.keyboard.clone(), - paths.mouse_relative.clone(), - paths.mouse_absolute.clone(), - ]; - if !wait_for_hid_devices(&device_paths, 2000).await { + let device_paths = paths.existing_paths(); + if !device_paths.is_empty() && !wait_for_hid_devices(&device_paths, 2000).await { warn!("HID devices did not appear after gadget setup"); } } @@ -448,6 +549,11 @@ impl OtgService { state.hid_enabled = hid_requested; state.msd_enabled = msd_requested; state.hid_paths = hid_paths; + state.hid_functions = if hid_requested { + Some(hid_functions) + } else { + None + }; state.error = None; } @@ -509,6 +615,7 @@ impl OtgService { state.hid_enabled = false; state.msd_enabled = false; state.hid_paths = None; + state.hid_functions = None; state.error = None; } diff --git a/src/rustdesk/connection.rs b/src/rustdesk/connection.rs index 5957a528..b408ed12 100644 --- a/src/rustdesk/connection.rs +++ b/src/rustdesk/connection.rs @@ -47,6 +47,11 @@ const DEFAULT_SCREEN_HEIGHT: u32 = 1080; /// Default mouse event throttle interval (16ms ≈ 60Hz) const DEFAULT_MOUSE_THROTTLE_MS: u64 = 16; +/// Advertised RustDesk version for client compatibility. +const RUSTDESK_COMPAT_VERSION: &str = "1.4.5"; +// Advertised platform for RustDesk clients. This affects which UI options are shown. +const RUSTDESK_COMPAT_PLATFORM: &str = "Windows"; + /// Input event throttler /// /// Limits the rate of input events sent to HID devices to prevent EAGAIN errors. @@ -164,6 +169,8 @@ pub struct Connection { last_test_delay_sent: Option, /// Last known CapsLock state from RustDesk modifiers (for detecting toggle) last_caps_lock: bool, + /// Whether relative mouse mode is currently active for this connection + relative_mouse_active: bool, } /// Messages sent to connection handler @@ -241,6 +248,7 @@ impl Connection { last_delay: 0, last_test_delay_sent: None, last_caps_lock: false, + relative_mouse_active: false, }; (conn, rx) @@ -623,7 +631,7 @@ impl Connection { self.negotiated_codec = Some(negotiated); info!("Negotiated video codec: {:?}", negotiated); - let response = self.create_login_response(true); + let response = self.create_login_response(true).await; let response_bytes = response .write_to_bytes() .map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?; @@ -673,7 +681,11 @@ impl Connection { Some(misc::Union::RefreshVideo(refresh)) => { if *refresh { debug!("Video refresh requested"); - // TODO: Request keyframe from encoder + if let Some(ref video_manager) = self.video_manager { + if let Err(e) = video_manager.request_keyframe().await { + warn!("Failed to request keyframe: {}", e); + } + } } } Some(misc::Union::VideoReceived(received)) => { @@ -1064,7 +1076,7 @@ impl Connection { } /// Create login response with dynamically detected encoder capabilities - fn create_login_response(&self, success: bool) -> HbbMessage { + async fn create_login_response(&self, success: bool) -> HbbMessage { if success { // Dynamically detect available encoders let registry = EncoderRegistry::global(); @@ -1080,11 +1092,21 @@ impl Connection { h264_available, h265_available, vp8_available, vp9_available ); + let mut display_width = self.screen_width; + let mut display_height = self.screen_height; + if let Some(ref video_manager) = self.video_manager { + let video_info = video_manager.get_video_info().await; + if let Some((width, height)) = video_info.resolution { + display_width = width; + display_height = height; + } + } + let mut display_info = DisplayInfo::new(); display_info.x = 0; display_info.y = 0; - display_info.width = 1920; - display_info.height = 1080; + display_info.width = display_width as i32; + display_info.height = display_height as i32; display_info.name = "KVM Display".to_string(); display_info.online = true; display_info.cursor_embedded = false; @@ -1099,11 +1121,11 @@ impl Connection { let mut peer_info = PeerInfo::new(); peer_info.username = "one-kvm".to_string(); peer_info.hostname = get_hostname(); - peer_info.platform = "Linux".to_string(); + peer_info.platform = RUSTDESK_COMPAT_PLATFORM.to_string(); peer_info.displays.push(display_info); peer_info.current_display = 0; peer_info.sas_enabled = false; - peer_info.version = env!("CARGO_PKG_VERSION").to_string(); + peer_info.version = RUSTDESK_COMPAT_VERSION.to_string(); peer_info.encoding = protobuf::MessageField::some(encoding); let mut login_response = LoginResponse::new(); @@ -1310,9 +1332,16 @@ impl Connection { async fn handle_mouse_event(&mut self, me: &MouseEvent) -> anyhow::Result<()> { // Parse RustDesk mask format: (button << 3) | event_type let event_type = me.mask & 0x07; + let is_relative_move = event_type == mouse_type::MOVE_RELATIVE; + + if is_relative_move { + self.relative_mouse_active = true; + } else if event_type == mouse_type::MOVE { + self.relative_mouse_active = false; + } // Check if this is a pure move event (no button/scroll) - let is_pure_move = event_type == mouse_type::MOVE; + let is_pure_move = event_type == mouse_type::MOVE || is_relative_move; // For pure move events, apply throttling if is_pure_move && !self.input_throttler.should_send_mouse_move() { @@ -1323,7 +1352,8 @@ impl Connection { debug!("Mouse event: x={}, y={}, mask={}", me.x, me.y, me.mask); // Convert RustDesk mouse event to One-KVM mouse events - let mouse_events = convert_mouse_event(me, self.screen_width, self.screen_height); + let mouse_events = + convert_mouse_event(me, self.screen_width, self.screen_height, self.relative_mouse_active); // Send to HID controller if available if let Some(ref hid) = self.hid { @@ -1543,6 +1573,9 @@ async fn run_video_streaming( let mut shutdown_rx = shutdown_tx.subscribe(); let mut encoded_count: u64 = 0; let mut last_log_time = Instant::now(); + let mut waiting_for_keyframe = true; + let mut last_sequence: Option = None; + let mut last_keyframe_request = Instant::now() - Duration::from_secs(1); info!( "Started shared video streaming for connection {} (codec: {:?})", @@ -1582,6 +1615,9 @@ async fn run_video_streaming( config.bitrate_preset ); } + if let Err(e) = video_manager.request_keyframe().await { + debug!("Failed to request keyframe for connection {}: {}", conn_id, e); + } // Inner loop: receives frames from current subscription loop { @@ -1600,43 +1636,63 @@ async fn run_video_streaming( } result = encoded_frame_rx.recv() => { - match result { - Ok(frame) => { - // Convert EncodedVideoFrame to RustDesk VideoFrame message - // Use zero-copy version: Bytes.clone() only increments refcount - let msg_bytes = video_adapter.encode_frame_bytes_zero_copy( - frame.data.clone(), - frame.is_keyframe, - frame.pts_ms as u64, - ); - - // Send to connection (blocks if channel is full, providing backpressure) - if video_tx.send(msg_bytes).await.is_err() { - debug!("Video channel closed for connection {}", conn_id); - break 'subscribe_loop; - } - - encoded_count += 1; - - // Log stats periodically - if last_log_time.elapsed().as_secs() >= 10 { - info!( - "Video streaming stats for connection {}: {} frames forwarded", - conn_id, encoded_count - ); - last_log_time = Instant::now(); - } - } - Err(broadcast::error::RecvError::Lagged(n)) => { - debug!("Connection {} lagged {} encoded frames", conn_id, n); - } - Err(broadcast::error::RecvError::Closed) => { - // Pipeline was restarted (e.g., bitrate/codec change) - // Re-subscribe to the new pipeline + let frame = match result { + Some(frame) => frame, + None => { info!("Video pipeline closed for connection {}, re-subscribing...", conn_id); tokio::time::sleep(Duration::from_millis(100)).await; continue 'subscribe_loop; } + }; + + let gap_detected = if let Some(prev) = last_sequence { + frame.sequence > prev.saturating_add(1) + } else { + false + }; + + if waiting_for_keyframe || gap_detected { + if frame.is_keyframe { + waiting_for_keyframe = false; + } else { + if gap_detected { + waiting_for_keyframe = true; + } + let now = Instant::now(); + if now.duration_since(last_keyframe_request) >= Duration::from_millis(200) { + if let Err(e) = video_manager.request_keyframe().await { + debug!("Failed to request keyframe for connection {}: {}", conn_id, e); + } + last_keyframe_request = now; + } + continue; + } + } + + // Convert EncodedVideoFrame to RustDesk VideoFrame message + // Use zero-copy version: Bytes.clone() only increments refcount + let msg_bytes = video_adapter.encode_frame_bytes_zero_copy( + frame.data.clone(), + frame.is_keyframe, + frame.pts_ms as u64, + ); + + // Send to connection (backpressure instead of dropping) + if video_tx.send(msg_bytes).await.is_err() { + debug!("Video channel closed for connection {}", conn_id); + break 'subscribe_loop; + } + + last_sequence = Some(frame.sequence); + encoded_count += 1; + + // Log stats periodically + if last_log_time.elapsed().as_secs() >= 30 { + info!( + "Video streaming stats for connection {}: {} frames forwarded", + conn_id, encoded_count + ); + last_log_time = Instant::now(); } } } @@ -1725,39 +1781,38 @@ async fn run_audio_streaming( break 'subscribe_loop; } - result = opus_rx.recv() => { - match result { - Ok(opus_frame) => { - // Convert OpusFrame to RustDesk AudioFrame message - let msg_bytes = audio_adapter.encode_opus_bytes(&opus_frame.data); + result = opus_rx.changed() => { + if result.is_err() { + // Pipeline was restarted + info!("Audio pipeline closed for connection {}, re-subscribing...", conn_id); + audio_adapter.reset(); + tokio::time::sleep(Duration::from_millis(100)).await; + continue 'subscribe_loop; + } - // Send to connection (blocks if channel is full, providing backpressure) - if audio_tx.send(msg_bytes).await.is_err() { - debug!("Audio channel closed for connection {}", conn_id); - break 'subscribe_loop; - } + let opus_frame = match opus_rx.borrow().clone() { + Some(frame) => frame, + None => continue, + }; - frame_count += 1; + // Convert OpusFrame to RustDesk AudioFrame message + let msg_bytes = audio_adapter.encode_opus_bytes(&opus_frame.data); - // Log stats periodically - if last_log_time.elapsed().as_secs() >= 30 { - info!( - "Audio streaming stats for connection {}: {} frames forwarded", - conn_id, frame_count - ); - last_log_time = Instant::now(); - } - } - Err(broadcast::error::RecvError::Lagged(n)) => { - debug!("Connection {} lagged {} audio frames", conn_id, n); - } - Err(broadcast::error::RecvError::Closed) => { - // Pipeline was restarted - info!("Audio pipeline closed for connection {}, re-subscribing...", conn_id); - audio_adapter.reset(); - tokio::time::sleep(Duration::from_millis(100)).await; - continue 'subscribe_loop; - } + // Send to connection (blocks if channel is full, providing backpressure) + if audio_tx.send(msg_bytes).await.is_err() { + debug!("Audio channel closed for connection {}", conn_id); + break 'subscribe_loop; + } + + frame_count += 1; + + // Log stats periodically + if last_log_time.elapsed().as_secs() >= 30 { + info!( + "Audio streaming stats for connection {}: {} frames forwarded", + conn_id, frame_count + ); + last_log_time = Instant::now(); } } } diff --git a/src/rustdesk/frame_adapters.rs b/src/rustdesk/frame_adapters.rs index 14e4d321..fbee2c1e 100644 --- a/src/rustdesk/frame_adapters.rs +++ b/src/rustdesk/frame_adapters.rs @@ -42,6 +42,9 @@ pub struct VideoFrameAdapter { seq: u32, /// Timestamp offset timestamp_base: u64, + /// Cached H264 SPS/PPS (Annex B NAL without start code) + h264_sps: Option, + h264_pps: Option, } impl VideoFrameAdapter { @@ -51,6 +54,8 @@ impl VideoFrameAdapter { codec, seq: 0, timestamp_base: 0, + h264_sps: None, + h264_pps: None, } } @@ -68,6 +73,7 @@ impl VideoFrameAdapter { is_keyframe: bool, timestamp_ms: u64, ) -> Message { + let data = self.prepare_h264_frame(data, is_keyframe); // Calculate relative timestamp if self.seq == 0 { self.timestamp_base = timestamp_ms; @@ -100,6 +106,41 @@ impl VideoFrameAdapter { msg } + fn prepare_h264_frame(&mut self, data: Bytes, is_keyframe: bool) -> Bytes { + if self.codec != VideoCodec::H264 { + return data; + } + + // Parse SPS/PPS from Annex B data (without start codes) + let (sps, pps) = crate::webrtc::rtp::extract_sps_pps(&data); + let mut has_sps = false; + let mut has_pps = false; + + if let Some(sps) = sps { + self.h264_sps = Some(Bytes::from(sps)); + has_sps = true; + } + if let Some(pps) = pps { + self.h264_pps = Some(Bytes::from(pps)); + has_pps = true; + } + + // Inject cached SPS/PPS before IDR when missing + if is_keyframe && (!has_sps || !has_pps) { + if let (Some(ref sps), Some(ref pps)) = (self.h264_sps.as_ref(), self.h264_pps.as_ref()) { + let mut out = Vec::with_capacity(8 + sps.len() + pps.len() + data.len()); + out.extend_from_slice(&[0, 0, 0, 1]); + out.extend_from_slice(sps); + out.extend_from_slice(&[0, 0, 0, 1]); + out.extend_from_slice(pps); + out.extend_from_slice(&data); + return Bytes::from(out); + } + } + + data + } + /// Convert encoded video data to RustDesk Message pub fn encode_frame(&mut self, data: &[u8], is_keyframe: bool, timestamp_ms: u64) -> Message { self.encode_frame_from_bytes(Bytes::copy_from_slice(data), is_keyframe, timestamp_ms) diff --git a/src/rustdesk/hid_adapter.rs b/src/rustdesk/hid_adapter.rs index 4261f8d9..7a89cdbb 100644 --- a/src/rustdesk/hid_adapter.rs +++ b/src/rustdesk/hid_adapter.rs @@ -18,6 +18,7 @@ pub mod mouse_type { pub const UP: i32 = 2; pub const WHEEL: i32 = 3; pub const TRACKPAD: i32 = 4; + pub const MOVE_RELATIVE: i32 = 5; } /// Mouse button IDs from RustDesk protocol (before left shift by 3) @@ -36,23 +37,25 @@ pub fn convert_mouse_event( event: &MouseEvent, screen_width: u32, screen_height: u32, + relative_mode: bool, ) -> Vec { let mut events = Vec::new(); - // RustDesk uses absolute coordinates - let x = event.x.max(0) as u32; - let y = event.y.max(0) as u32; - - // Normalize to 0-32767 range for absolute mouse (USB HID standard) - let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32; - let abs_y = ((y as u64 * 32767) / screen_height.max(1) as u64) as i32; - // Parse RustDesk mask format: (button << 3) | event_type let event_type = event.mask & 0x07; let button_id = event.mask >> 3; + let include_abs_move = !relative_mode; match event_type { mouse_type::MOVE => { + // RustDesk uses absolute coordinates + let x = event.x.max(0) as u32; + let y = event.y.max(0) as u32; + + // Normalize to 0-32767 range for absolute mouse (USB HID standard) + let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32; + let abs_y = ((y as u64 * 32767) / screen_height.max(1) as u64) as i32; + // Move event - may have button held down (button_id > 0 means dragging) // Just send move, button state is tracked separately by HID backend events.push(OneKvmMouseEvent { @@ -63,55 +66,83 @@ pub fn convert_mouse_event( scroll: 0, }); } - mouse_type::DOWN => { - // Button down - first move, then press + mouse_type::MOVE_RELATIVE => { + // Relative movement uses delta values directly (dx, dy). events.push(OneKvmMouseEvent { - event_type: MouseEventType::MoveAbs, - x: abs_x, - y: abs_y, + event_type: MouseEventType::Move, + x: event.x, + y: event.y, button: None, scroll: 0, }); + } + mouse_type::DOWN => { + if include_abs_move { + // Button down - first move, then press + let x = event.x.max(0) as u32; + let y = event.y.max(0) as u32; + let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32; + let abs_y = ((y as u64 * 32767) / screen_height.max(1) as u64) as i32; + events.push(OneKvmMouseEvent { + event_type: MouseEventType::MoveAbs, + x: abs_x, + y: abs_y, + button: None, + scroll: 0, + }); + } if let Some(button) = button_id_to_button(button_id) { events.push(OneKvmMouseEvent { event_type: MouseEventType::Down, - x: abs_x, - y: abs_y, + x: 0, + y: 0, button: Some(button), scroll: 0, }); } } mouse_type::UP => { - // Button up - first move, then release - events.push(OneKvmMouseEvent { - event_type: MouseEventType::MoveAbs, - x: abs_x, - y: abs_y, - button: None, - scroll: 0, - }); + if include_abs_move { + // Button up - first move, then release + let x = event.x.max(0) as u32; + let y = event.y.max(0) as u32; + let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32; + let abs_y = ((y as u64 * 32767) / screen_height.max(1) as u64) as i32; + events.push(OneKvmMouseEvent { + event_type: MouseEventType::MoveAbs, + x: abs_x, + y: abs_y, + button: None, + scroll: 0, + }); + } if let Some(button) = button_id_to_button(button_id) { events.push(OneKvmMouseEvent { event_type: MouseEventType::Up, - x: abs_x, - y: abs_y, + x: 0, + y: 0, button: Some(button), scroll: 0, }); } } mouse_type::WHEEL => { - // Scroll event - move first, then scroll - events.push(OneKvmMouseEvent { - event_type: MouseEventType::MoveAbs, - x: abs_x, - y: abs_y, - button: None, - scroll: 0, - }); + if include_abs_move { + // Scroll event - move first, then scroll + let x = event.x.max(0) as u32; + let y = event.y.max(0) as u32; + let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32; + let abs_y = ((y as u64 * 32767) / screen_height.max(1) as u64) as i32; + events.push(OneKvmMouseEvent { + event_type: MouseEventType::MoveAbs, + x: abs_x, + y: abs_y, + button: None, + scroll: 0, + }); + } // RustDesk encodes scroll direction in the y coordinate // Positive y = scroll up, Negative y = scroll down @@ -119,21 +150,27 @@ pub fn convert_mouse_event( let scroll = if event.y > 0 { 1i8 } else { -1i8 }; events.push(OneKvmMouseEvent { event_type: MouseEventType::Scroll, - x: abs_x, - y: abs_y, + x: 0, + y: 0, button: None, scroll, }); } _ => { - // Unknown event type, just move - events.push(OneKvmMouseEvent { - event_type: MouseEventType::MoveAbs, - x: abs_x, - y: abs_y, - button: None, - scroll: 0, - }); + if include_abs_move { + // Unknown event type, just move + let x = event.x.max(0) as u32; + let y = event.y.max(0) as u32; + let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32; + let abs_y = ((y as u64 * 32767) / screen_height.max(1) as u64) as i32; + events.push(OneKvmMouseEvent { + event_type: MouseEventType::MoveAbs, + x: abs_x, + y: abs_y, + button: None, + scroll: 0, + }); + } } } @@ -522,7 +559,7 @@ mod tests { event.y = 300; event.mask = mouse_type::MOVE; // Pure move event - let events = convert_mouse_event(&event, 1920, 1080); + let events = convert_mouse_event(&event, 1920, 1080, false); assert!(!events.is_empty()); assert_eq!(events[0].event_type, MouseEventType::MoveAbs); } @@ -534,7 +571,7 @@ mod tests { event.y = 300; event.mask = (mouse_button::LEFT << 3) | mouse_type::DOWN; - let events = convert_mouse_event(&event, 1920, 1080); + let events = convert_mouse_event(&event, 1920, 1080, false); assert!(events.len() >= 2); // Should have a button down event assert!(events @@ -542,6 +579,20 @@ mod tests { .any(|e| e.event_type == MouseEventType::Down && e.button == Some(MouseButton::Left))); } + #[test] + fn test_convert_mouse_move_relative() { + let mut event = MouseEvent::new(); + event.x = -12; + event.y = 8; + event.mask = mouse_type::MOVE_RELATIVE; + + let events = convert_mouse_event(&event, 1920, 1080, true); + assert_eq!(events.len(), 1); + assert_eq!(events[0].event_type, MouseEventType::Move); + assert_eq!(events[0].x, -12); + assert_eq!(events[0].y, 8); + } + #[test] fn test_convert_key_event() { use protobuf::EnumOrUnknown; diff --git a/src/rustdesk/mod.rs b/src/rustdesk/mod.rs index 23178c47..5b636497 100644 --- a/src/rustdesk/mod.rs +++ b/src/rustdesk/mod.rs @@ -23,7 +23,7 @@ pub mod protocol; pub mod punch; pub mod rendezvous; -use std::net::SocketAddr; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; use std::time::Duration; @@ -37,6 +37,7 @@ use tracing::{debug, error, info, warn}; use crate::audio::AudioController; use crate::hid::HidController; use crate::video::stream_manager::VideoStreamManager; +use crate::utils::bind_tcp_listener; use self::config::RustDeskConfig; use self::connection::ConnectionManager; @@ -84,7 +85,7 @@ pub struct RustDeskService { status: Arc>, rendezvous: Arc>>>, rendezvous_handle: Arc>>>, - tcp_listener_handle: Arc>>>, + tcp_listener_handle: Arc>>>>, listen_port: Arc>, connection_manager: Arc, video_manager: Arc, @@ -212,8 +213,8 @@ impl RustDeskService { // Start TCP listener BEFORE the rendezvous mediator to ensure port is set correctly // This prevents race condition where mediator starts registration with wrong port - let (tcp_handle, listen_port) = self.start_tcp_listener_with_port().await?; - *self.tcp_listener_handle.write() = Some(tcp_handle); + let (tcp_handles, listen_port) = self.start_tcp_listener_with_port().await?; + *self.tcp_listener_handle.write() = Some(tcp_handles); // Set the listen port on mediator before starting the registration loop mediator.set_listen_port(listen_port); @@ -373,52 +374,83 @@ impl RustDeskService { /// Start TCP listener for direct peer connections /// Returns the join handle and the port that was bound - async fn start_tcp_listener_with_port(&self) -> anyhow::Result<(JoinHandle<()>, u16)> { + async fn start_tcp_listener_with_port(&self) -> anyhow::Result<(Vec>, u16)> { // Try to bind to the default port, or find an available port - let listener = match TcpListener::bind(format!("0.0.0.0:{}", DIRECT_LISTEN_PORT)).await { - Ok(l) => l, - Err(_) => { - // Try binding to port 0 to get an available port - TcpListener::bind("0.0.0.0:0").await? + let (listeners, listen_port) = match self.bind_direct_listeners(DIRECT_LISTEN_PORT) { + Ok(result) => result, + Err(err) => { + warn!( + "Failed to bind RustDesk TCP on port {}: {}, falling back to random port", + DIRECT_LISTEN_PORT, err + ); + self.bind_direct_listeners(0)? } }; - let local_addr = listener.local_addr()?; - let listen_port = local_addr.port(); *self.listen_port.write() = listen_port; - info!("RustDesk TCP listener started on {}", local_addr); let connection_manager = self.connection_manager.clone(); - let mut shutdown_rx = self.shutdown_tx.subscribe(); + let mut handles = Vec::new(); - let handle = tokio::spawn(async move { - loop { - tokio::select! { - result = listener.accept() => { - match result { - Ok((stream, peer_addr)) => { - info!("Accepted direct connection from {}", peer_addr); - let conn_mgr = connection_manager.clone(); - tokio::spawn(async move { - if let Err(e) = conn_mgr.accept_connection(stream, peer_addr).await { - error!("Failed to handle direct connection from {}: {}", peer_addr, e); - } - }); - } - Err(e) => { - error!("TCP accept error: {}", e); + for listener in listeners { + let local_addr = listener.local_addr()?; + info!("RustDesk TCP listener started on {}", local_addr); + + let conn_mgr = connection_manager.clone(); + let mut shutdown_rx = self.shutdown_tx.subscribe(); + let handle = tokio::spawn(async move { + loop { + tokio::select! { + result = listener.accept() => { + match result { + Ok((stream, peer_addr)) => { + info!("Accepted direct connection from {}", peer_addr); + let conn_mgr = conn_mgr.clone(); + tokio::spawn(async move { + if let Err(e) = conn_mgr.accept_connection(stream, peer_addr).await { + error!("Failed to handle direct connection from {}: {}", peer_addr, e); + } + }); + } + Err(e) => { + error!("TCP accept error: {}", e); + } } } - } - _ = shutdown_rx.recv() => { - info!("TCP listener shutting down"); - break; + _ = shutdown_rx.recv() => { + info!("TCP listener shutting down"); + break; + } } } - } - }); + }); + handles.push(handle); + } - Ok((handle, listen_port)) + Ok((handles, listen_port)) + } + + fn bind_direct_listeners(&self, port: u16) -> anyhow::Result<(Vec, u16)> { + let v4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port); + let v4_listener = bind_tcp_listener(v4_addr)?; + let listen_port = v4_listener.local_addr()?.port(); + + let mut listeners = vec![TcpListener::from_std(v4_listener)?]; + + let v6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), listen_port); + match bind_tcp_listener(v6_addr) { + Ok(v6_listener) => { + listeners.push(TcpListener::from_std(v6_listener)?); + } + Err(err) => { + warn!( + "IPv6 listener unavailable on port {}: {}, continuing with IPv4 only", + listen_port, err + ); + } + } + + Ok((listeners, listen_port)) } /// Stop the RustDesk service @@ -446,8 +478,10 @@ impl RustDeskService { } // Wait for TCP listener task to finish - if let Some(handle) = self.tcp_listener_handle.write().take() { - handle.abort(); + if let Some(handles) = self.tcp_listener_handle.write().take() { + for handle in handles { + handle.abort(); + } } *self.rendezvous.write() = None; diff --git a/src/rustdesk/rendezvous.rs b/src/rustdesk/rendezvous.rs index 8b411769..d347f81f 100644 --- a/src/rustdesk/rendezvous.rs +++ b/src/rustdesk/rendezvous.rs @@ -4,7 +4,7 @@ //! It registers the device ID and public key, handles punch hole requests, //! and relay requests. -use std::net::SocketAddr; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -15,6 +15,8 @@ use tokio::sync::broadcast; use tokio::time::interval; use tracing::{debug, error, info, warn}; +use crate::utils::bind_udp_socket; + use super::config::RustDeskConfig; use super::crypto::{KeyPair, SigningKeyPair}; use super::protocol::{ @@ -288,8 +290,13 @@ impl RendezvousMediator { .next() .ok_or_else(|| anyhow::anyhow!("Failed to resolve {}", addr))?; - // Create UDP socket - let socket = UdpSocket::bind("0.0.0.0:0").await?; + // Create UDP socket (match address family, enforce IPV6_V6ONLY) + let bind_addr = match server_addr { + SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), + }; + let std_socket = bind_udp_socket(bind_addr)?; + let socket = UdpSocket::from_std(std_socket)?; socket.connect(server_addr).await?; info!("Connected to rendezvous server at {}", server_addr); diff --git a/src/state.rs b/src/state.rs index 5a4f8b65..b322f7ed 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{collections::VecDeque, sync::Arc}; use tokio::sync::{broadcast, RwLock}; use crate::atx::AtxController; @@ -56,6 +56,8 @@ pub struct AppState { pub events: Arc, /// Shutdown signal sender pub shutdown_tx: broadcast::Sender<()>, + /// Recently revoked session IDs (for client kick detection) + pub revoked_sessions: Arc>>, /// Data directory path data_dir: std::path::PathBuf, } @@ -92,6 +94,7 @@ impl AppState { extensions, events, shutdown_tx, + revoked_sessions: Arc::new(RwLock::new(VecDeque::new())), data_dir, }) } @@ -106,6 +109,26 @@ impl AppState { self.shutdown_tx.subscribe() } + /// Record revoked session IDs (bounded queue) + pub async fn remember_revoked_sessions(&self, session_ids: Vec) { + if session_ids.is_empty() { + return; + } + let mut guard = self.revoked_sessions.write().await; + for id in session_ids { + guard.push_back(id); + } + while guard.len() > 32 { + guard.pop_front(); + } + } + + /// Check if a session ID was revoked (kicked) + pub async fn is_session_revoked(&self, session_id: &str) -> bool { + let guard = self.revoked_sessions.read().await; + guard.iter().any(|id| id == session_id) + } + /// Get complete device information for WebSocket clients /// /// This method collects the current state of all devices (video, HID, MSD, ATX, Audio) diff --git a/src/stream/mjpeg.rs b/src/stream/mjpeg.rs index 7ed8de50..d98e8c39 100644 --- a/src/stream/mjpeg.rs +++ b/src/stream/mjpeg.rs @@ -157,6 +157,8 @@ pub struct MjpegStreamHandler { max_drop_same_frames: AtomicU64, /// JPEG encoder for non-JPEG input formats jpeg_encoder: ParkingMutex>, + /// JPEG quality for software encoding (1-100) + jpeg_quality: AtomicU64, } impl MjpegStreamHandler { @@ -179,9 +181,16 @@ impl MjpegStreamHandler { last_frame_ts: ParkingRwLock::new(None), dropped_same_frames: AtomicU64::new(0), max_drop_same_frames: AtomicU64::new(max_drop), + jpeg_quality: AtomicU64::new(80), } } + /// Set JPEG quality for software encoding (1-100) + pub fn set_jpeg_quality(&self, quality: u8) { + let clamped = quality.clamp(1, 100) as u64; + self.jpeg_quality.store(clamped, Ordering::Relaxed); + } + /// Update current frame pub fn update_frame(&self, frame: VideoFrame) { // Fast path: if no MJPEG clients are connected, do minimal bookkeeping and avoid @@ -260,23 +269,24 @@ impl MjpegStreamHandler { fn encode_to_jpeg(&self, frame: &VideoFrame) -> Result { let resolution = frame.resolution; let sequence = self.sequence.load(Ordering::Relaxed); + let desired_quality = self.jpeg_quality.load(Ordering::Relaxed) as u32; // Get or create encoder let mut encoder_guard = self.jpeg_encoder.lock(); let encoder = encoder_guard.get_or_insert_with(|| { - let config = EncoderConfig::jpeg(resolution, 85); + let config = EncoderConfig::jpeg(resolution, desired_quality); match JpegEncoder::new(config) { Ok(enc) => { debug!( - "Created JPEG encoder for MJPEG stream: {}x{}", - resolution.width, resolution.height + "Created JPEG encoder for MJPEG stream: {}x{} (q={})", + resolution.width, resolution.height, desired_quality ); enc } Err(e) => { warn!("Failed to create JPEG encoder: {}, using default", e); // Try with default config - JpegEncoder::new(EncoderConfig::jpeg(resolution, 85)) + JpegEncoder::new(EncoderConfig::jpeg(resolution, desired_quality)) .expect("Failed to create default JPEG encoder") } } @@ -288,9 +298,16 @@ impl MjpegStreamHandler { "Resolution changed, recreating JPEG encoder: {}x{}", resolution.width, resolution.height ); - let config = EncoderConfig::jpeg(resolution, 85); + let config = EncoderConfig::jpeg(resolution, desired_quality); *encoder = JpegEncoder::new(config).map_err(|e| format!("Failed to create encoder: {}", e))?; + } else if encoder.config().quality != desired_quality { + if let Err(e) = encoder.set_quality(desired_quality) { + warn!("Failed to set JPEG quality: {}, recreating encoder", e); + let config = EncoderConfig::jpeg(resolution, desired_quality); + *encoder = JpegEncoder::new(config) + .map_err(|e| format!("Failed to create encoder: {}", e))?; + } } // Encode based on input format diff --git a/src/stream/mjpeg_streamer.rs b/src/stream/mjpeg_streamer.rs index ad4fb96c..79e0fb38 100644 --- a/src/stream/mjpeg_streamer.rs +++ b/src/stream/mjpeg_streamer.rs @@ -15,11 +15,18 @@ //! //! Note: Audio WebSocket is handled separately by audio_ws.rs (/api/ws/audio) +use std::io; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use tokio::sync::{broadcast, RwLock}; -use tracing::info; +use tokio::sync::{Mutex, RwLock}; +use tracing::{error, info, warn}; +use v4l::buffer::Type as BufferType; +use v4l::io::traits::CaptureStream; +use v4l::prelude::*; +use v4l::video::Capture; +use v4l::video::capture::Parameters; +use v4l::Format; use crate::audio::AudioController; use crate::error::{AppError, Result}; @@ -28,11 +35,16 @@ use crate::hid::HidController; use crate::video::capture::{CaptureConfig, VideoCapturer}; use crate::video::device::{enumerate_devices, find_best_device, VideoDeviceInfo}; use crate::video::format::{PixelFormat, Resolution}; -use crate::video::frame::VideoFrame; +use crate::video::frame::{FrameBuffer, FrameBufferPool, VideoFrame}; use super::mjpeg::MjpegStreamHandler; use super::ws_hid::WsHidHandler; +/// Minimum valid frame size for capture +const MIN_CAPTURE_FRAME_SIZE: usize = 128; +/// Validate JPEG header every N frames to reduce overhead +const JPEG_VALIDATE_INTERVAL: u64 = 30; + /// MJPEG streamer configuration #[derive(Debug, Clone)] pub struct MjpegStreamerConfig { @@ -104,8 +116,6 @@ pub struct MjpegStreamerStats { pub mjpeg_clients: u64, /// WebSocket HID client count pub ws_hid_clients: usize, - /// Total frames captured - pub frames_captured: u64, } /// MJPEG Streamer @@ -130,6 +140,9 @@ pub struct MjpegStreamer { // === Control === start_lock: tokio::sync::Mutex<()>, + direct_stop: AtomicBool, + direct_active: AtomicBool, + direct_handle: Mutex>>, events: RwLock>>, config_changing: AtomicBool, } @@ -148,6 +161,9 @@ impl MjpegStreamer { ws_hid_handler: WsHidHandler::new(), hid_controller: RwLock::new(None), start_lock: tokio::sync::Mutex::new(()), + direct_stop: AtomicBool::new(false), + direct_active: AtomicBool::new(false), + direct_handle: Mutex::new(None), events: RwLock::new(None), config_changing: AtomicBool::new(false), }) @@ -166,6 +182,9 @@ impl MjpegStreamer { ws_hid_handler: WsHidHandler::new(), hid_controller: RwLock::new(None), start_lock: tokio::sync::Mutex::new(()), + direct_stop: AtomicBool::new(false), + direct_active: AtomicBool::new(false), + direct_handle: Mutex::new(None), events: RwLock::new(None), config_changing: AtomicBool::new(false), }) @@ -228,17 +247,22 @@ impl MjpegStreamer { let device = self.current_device.read().await; let config = self.config.read().await; - let (resolution, format, frames_captured) = - if let Some(ref cap) = *self.capturer.read().await { - let stats = cap.stats().await; + let (resolution, format) = { + if self.direct_active.load(Ordering::Relaxed) { + ( + Some((config.resolution.width, config.resolution.height)), + Some(config.format.to_string()), + ) + } else if let Some(ref cap) = *self.capturer.read().await { + let _ = cap; ( Some((config.resolution.width, config.resolution.height)), Some(config.format.to_string()), - stats.frames_captured, ) } else { - (None, None, 0) - }; + (None, None) + } + }; MjpegStreamerStats { state: state.to_string(), @@ -248,7 +272,6 @@ impl MjpegStreamer { fps: config.fps, mjpeg_clients: self.mjpeg_handler.client_count(), ws_hid_clients: self.ws_hid_handler.client_count(), - frames_captured, } } @@ -266,15 +289,6 @@ impl MjpegStreamer { self.ws_hid_handler.clone() } - /// Get frame sender for WebRTC integration - pub async fn frame_sender(&self) -> Option> { - if let Some(ref cap) = *self.capturer.read().await { - Some(cap.frame_sender()) - } else { - None - } - } - // ======================================================================== // Initialization // ======================================================================== @@ -293,6 +307,7 @@ impl MjpegStreamer { ); let config = self.config.read().await.clone(); + self.mjpeg_handler.set_jpeg_quality(config.jpeg_quality); // Create capture config let capture_config = CaptureConfig { @@ -336,22 +351,23 @@ impl MjpegStreamer { return Ok(()); } - // Get capturer - let capturer = self.capturer.read().await.clone(); - let capturer = - capturer.ok_or_else(|| AppError::VideoError("Not initialized".to_string()))?; + let device = self + .current_device + .read() + .await + .clone() + .ok_or_else(|| AppError::VideoError("Not initialized".to_string()))?; - // Start capture - capturer.start().await?; + let config = self.config.read().await.clone(); - // Start frame forwarding task - let handler = self.mjpeg_handler.clone(); - let mut frame_rx = capturer.frame_sender().subscribe(); - tokio::spawn(async move { - while let Ok(frame) = frame_rx.recv().await { - handler.update_frame(frame); - } + self.direct_stop.store(false, Ordering::SeqCst); + self.direct_active.store(true, Ordering::SeqCst); + + let streamer = self.clone(); + let handle = tokio::task::spawn_blocking(move || { + streamer.run_direct_capture(device.path, config); }); + *self.direct_handle.lock().await = Some(handle); // Note: Audio WebSocket is handled separately by audio_ws.rs (/api/ws/audio) @@ -370,7 +386,14 @@ impl MjpegStreamer { return Ok(()); } - // Stop capturer + self.direct_stop.store(true, Ordering::SeqCst); + + if let Some(handle) = self.direct_handle.lock().await.take() { + let _ = handle.await; + } + self.direct_active.store(false, Ordering::SeqCst); + + // Stop capturer (legacy path) if let Some(ref cap) = *self.capturer.read().await { let _ = cap.stop().await; } @@ -412,6 +435,7 @@ impl MjpegStreamer { // Update config *self.config.write().await = config.clone(); + self.mjpeg_handler.set_jpeg_quality(config.jpeg_quality); // Re-initialize if device path is set if let Some(ref path) = config.device_path { @@ -448,6 +472,202 @@ impl MjpegStreamer { }); } } + + /// Direct capture loop for MJPEG mode (single loop, no broadcast) + fn run_direct_capture(self: Arc, device_path: PathBuf, config: MjpegStreamerConfig) { + const MAX_RETRIES: u32 = 5; + const RETRY_DELAY_MS: u64 = 200; + + let handle = tokio::runtime::Handle::current(); + let mut last_state = MjpegStreamerState::Streaming; + + let mut set_state = |new_state: MjpegStreamerState| { + if new_state != last_state { + handle.block_on(async { + *self.state.write().await = new_state; + self.publish_state_change().await; + }); + last_state = new_state; + } + }; + + let mut device_opt: Option = None; + let mut format_opt: Option = None; + let mut last_error: Option = None; + + for attempt in 0..MAX_RETRIES { + if self.direct_stop.load(Ordering::Relaxed) { + self.direct_active.store(false, Ordering::SeqCst); + return; + } + + let device = match Device::with_path(&device_path) { + Ok(d) => d, + Err(e) => { + let err_str = e.to_string(); + if err_str.contains("busy") || err_str.contains("resource") { + warn!( + "Device busy on attempt {}/{}, retrying in {}ms...", + attempt + 1, + MAX_RETRIES, + RETRY_DELAY_MS + ); + std::thread::sleep(std::time::Duration::from_millis(RETRY_DELAY_MS)); + last_error = Some(err_str); + continue; + } + last_error = Some(err_str); + break; + } + }; + + let requested = Format::new( + config.resolution.width, + config.resolution.height, + config.format.to_fourcc(), + ); + + match device.set_format(&requested) { + Ok(actual) => { + device_opt = Some(device); + format_opt = Some(actual); + break; + } + Err(e) => { + let err_str = e.to_string(); + if err_str.contains("busy") || err_str.contains("resource") { + warn!( + "Device busy on set_format attempt {}/{}, retrying in {}ms...", + attempt + 1, + MAX_RETRIES, + RETRY_DELAY_MS + ); + std::thread::sleep(std::time::Duration::from_millis(RETRY_DELAY_MS)); + last_error = Some(err_str); + continue; + } + last_error = Some(err_str); + break; + } + } + } + + let (device, actual_format) = match (device_opt, format_opt) { + (Some(d), Some(f)) => (d, f), + _ => { + error!( + "Failed to open device {:?}: {}", + device_path, + last_error.unwrap_or_else(|| "unknown error".to_string()) + ); + set_state(MjpegStreamerState::Error); + self.mjpeg_handler.set_offline(); + self.direct_active.store(false, Ordering::SeqCst); + return; + } + }; + + info!( + "Capture format: {}x{} {:?} stride={}", + actual_format.width, actual_format.height, actual_format.fourcc, actual_format.stride + ); + + let resolution = Resolution::new(actual_format.width, actual_format.height); + let pixel_format = + PixelFormat::from_fourcc(actual_format.fourcc).unwrap_or(config.format); + + if config.fps > 0 { + if let Err(e) = device.set_params(&Parameters::with_fps(config.fps)) { + warn!("Failed to set hardware FPS: {}", e); + } + } + + let mut stream = match MmapStream::with_buffers(&device, BufferType::VideoCapture, 4) { + Ok(s) => s, + Err(e) => { + error!("Failed to create capture stream: {}", e); + set_state(MjpegStreamerState::Error); + self.mjpeg_handler.set_offline(); + self.direct_active.store(false, Ordering::SeqCst); + return; + } + }; + + let buffer_pool = Arc::new(FrameBufferPool::new(8)); + let mut signal_present = true; + let mut sequence: u64 = 0; + let mut validate_counter: u64 = 0; + + while !self.direct_stop.load(Ordering::Relaxed) { + let (buf, meta) = match stream.next() { + Ok(frame_data) => frame_data, + Err(e) => { + if e.kind() == io::ErrorKind::TimedOut { + if signal_present { + signal_present = false; + set_state(MjpegStreamerState::NoSignal); + } + std::thread::sleep(std::time::Duration::from_millis(100)); + continue; + } + + let is_device_lost = match e.raw_os_error() { + Some(6) => true, // ENXIO + Some(19) => true, // ENODEV + Some(5) => true, // EIO + Some(32) => true, // EPIPE + Some(108) => true, // ESHUTDOWN + _ => false, + }; + + if is_device_lost { + error!("Video device lost: {} - {}", device_path.display(), e); + set_state(MjpegStreamerState::Error); + self.mjpeg_handler.set_offline(); + self.direct_active.store(false, Ordering::SeqCst); + return; + } + + error!("Capture error: {}", e); + continue; + } + }; + + let frame_size = meta.bytesused as usize; + if frame_size < MIN_CAPTURE_FRAME_SIZE { + continue; + } + + validate_counter = validate_counter.wrapping_add(1); + if pixel_format.is_compressed() + && validate_counter % JPEG_VALIDATE_INTERVAL == 0 + && !VideoFrame::is_valid_jpeg_bytes(&buf[..frame_size]) + { + continue; + } + + let mut owned = buffer_pool.take(frame_size); + owned.resize(frame_size, 0); + owned[..frame_size].copy_from_slice(&buf[..frame_size]); + let frame = VideoFrame::from_pooled( + Arc::new(FrameBuffer::new(owned, Some(buffer_pool.clone()))), + resolution, + pixel_format, + actual_format.stride, + sequence, + ); + sequence = sequence.wrapping_add(1); + + if !signal_present { + signal_present = true; + set_state(MjpegStreamerState::Streaming); + } + + self.mjpeg_handler.update_frame(frame); + } + + self.direct_active.store(false, Ordering::SeqCst); + } } impl Default for MjpegStreamer { @@ -463,6 +683,9 @@ impl Default for MjpegStreamer { ws_hid_handler: WsHidHandler::new(), hid_controller: RwLock::new(None), start_lock: tokio::sync::Mutex::new(()), + direct_stop: AtomicBool::new(false), + direct_active: AtomicBool::new(false), + direct_handle: Mutex::new(None), events: RwLock::new(None), config_changing: AtomicBool::new(false), } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 06002c2f..12bf372a 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -3,5 +3,7 @@ //! This module contains common utilities used across the codebase. pub mod throttle; +pub mod net; pub use throttle::LogThrottler; +pub use net::{bind_tcp_listener, bind_udp_socket}; diff --git a/src/utils/net.rs b/src/utils/net.rs new file mode 100644 index 00000000..2bd38a29 --- /dev/null +++ b/src/utils/net.rs @@ -0,0 +1,84 @@ +//! Networking helpers for binding sockets with explicit IPv6-only behavior. + +use std::io; +use std::net::{SocketAddr, TcpListener, UdpSocket}; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd}; + +use nix::sys::socket::{ + self, sockopt, AddressFamily, Backlog, SockFlag, SockProtocol, SockType, SockaddrIn, + SockaddrIn6, +}; + +fn socket_addr_family(addr: &SocketAddr) -> AddressFamily { + match addr { + SocketAddr::V4(_) => AddressFamily::Inet, + SocketAddr::V6(_) => AddressFamily::Inet6, + } +} + +/// Bind a TCP listener with IPv6-only set for IPv6 sockets. +pub fn bind_tcp_listener(addr: SocketAddr) -> io::Result { + let domain = socket_addr_family(&addr); + let fd = socket::socket( + domain, + SockType::Stream, + SockFlag::SOCK_CLOEXEC, + SockProtocol::Tcp, + ) + .map_err(io::Error::from)?; + + socket::setsockopt(&fd, sockopt::ReuseAddr, &true).map_err(io::Error::from)?; + + if matches!(addr, SocketAddr::V6(_)) { + socket::setsockopt(&fd, sockopt::Ipv6V6Only, &true).map_err(io::Error::from)?; + } + + match addr { + SocketAddr::V4(v4) => { + let sockaddr = SockaddrIn::from(v4); + socket::bind(fd.as_raw_fd(), &sockaddr).map_err(io::Error::from)?; + } + SocketAddr::V6(v6) => { + let sockaddr = SockaddrIn6::from(v6); + socket::bind(fd.as_raw_fd(), &sockaddr).map_err(io::Error::from)?; + } + } + socket::listen(&fd, Backlog::MAXCONN).map_err(io::Error::from)?; + + let listener = unsafe { TcpListener::from_raw_fd(fd.into_raw_fd()) }; + listener.set_nonblocking(true)?; + Ok(listener) +} + +/// Bind a UDP socket with IPv6-only set for IPv6 sockets. +pub fn bind_udp_socket(addr: SocketAddr) -> io::Result { + let domain = socket_addr_family(&addr); + let fd = socket::socket( + domain, + SockType::Datagram, + SockFlag::SOCK_CLOEXEC, + SockProtocol::Udp, + ) + .map_err(io::Error::from)?; + + socket::setsockopt(&fd, sockopt::ReuseAddr, &true).map_err(io::Error::from)?; + + if matches!(addr, SocketAddr::V6(_)) { + socket::setsockopt(&fd, sockopt::Ipv6V6Only, &true).map_err(io::Error::from)?; + } + + match addr { + SocketAddr::V4(v4) => { + let sockaddr = SockaddrIn::from(v4); + socket::bind(fd.as_raw_fd(), &sockaddr).map_err(io::Error::from)?; + } + SocketAddr::V6(v6) => { + let sockaddr = SockaddrIn6::from(v6); + socket::bind(fd.as_raw_fd(), &sockaddr).map_err(io::Error::from)?; + } + } + + let socket = unsafe { UdpSocket::from_raw_fd(fd.into_raw_fd()) }; + socket.set_nonblocking(true)?; + Ok(socket) +} diff --git a/src/video/capture.rs b/src/video/capture.rs index 21598d9e..8701521f 100644 --- a/src/video/capture.rs +++ b/src/video/capture.rs @@ -2,13 +2,13 @@ //! //! Provides async video capture using memory-mapped buffers. -use bytes::Bytes; use std::io; use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use bytes::Bytes; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; -use tokio::sync::{broadcast, watch, Mutex}; +use tokio::sync::{watch, Mutex}; use tracing::{debug, error, info, warn}; use v4l::buffer::Type as BufferType; use v4l::io::traits::CaptureStream; @@ -92,20 +92,8 @@ impl CaptureConfig { /// Capture statistics #[derive(Debug, Clone, Default)] pub struct CaptureStats { - /// Total frames captured - pub frames_captured: u64, - /// Frames dropped (invalid/too small) - pub frames_dropped: u64, /// Current FPS (calculated) pub current_fps: f32, - /// Average frame size in bytes - pub avg_frame_size: usize, - /// Capture errors - pub errors: u64, - /// Last frame timestamp - pub last_frame_ts: Option, - /// Whether signal is present - pub signal_present: bool, } /// Video capturer state @@ -131,9 +119,7 @@ pub struct VideoCapturer { state: Arc>, state_rx: watch::Receiver, stats: Arc>, - frame_tx: broadcast::Sender, stop_flag: Arc, - sequence: Arc, capture_handle: Mutex>>, /// Last error that occurred (device path, reason) last_error: Arc>>, @@ -143,16 +129,13 @@ impl VideoCapturer { /// Create a new video capturer pub fn new(config: CaptureConfig) -> Self { let (state_tx, state_rx) = watch::channel(CaptureState::Stopped); - let (frame_tx, _) = broadcast::channel(16); // Buffer size 16 for low latency Self { config, state: Arc::new(state_tx), state_rx, stats: Arc::new(Mutex::new(CaptureStats::default())), - frame_tx, stop_flag: Arc::new(AtomicBool::new(false)), - sequence: Arc::new(AtomicU64::new(0)), capture_handle: Mutex::new(None), last_error: Arc::new(parking_lot::RwLock::new(None)), } @@ -178,16 +161,6 @@ impl VideoCapturer { *self.last_error.write() = None; } - /// Subscribe to frames - pub fn subscribe(&self) -> broadcast::Receiver { - self.frame_tx.subscribe() - } - - /// Get frame sender (for sharing with other components like WebRTC) - pub fn frame_sender(&self) -> broadcast::Sender { - self.frame_tx.clone() - } - /// Get capture statistics pub async fn stats(&self) -> CaptureStats { self.stats.lock().await.clone() @@ -225,15 +198,11 @@ impl VideoCapturer { let config = self.config.clone(); let state = self.state.clone(); let stats = self.stats.clone(); - let frame_tx = self.frame_tx.clone(); let stop_flag = self.stop_flag.clone(); - let sequence = self.sequence.clone(); let last_error = self.last_error.clone(); let handle = tokio::task::spawn_blocking(move || { - capture_loop( - config, state, stats, frame_tx, stop_flag, sequence, last_error, - ); + capture_loop(config, state, stats, stop_flag, last_error); }); *self.capture_handle.lock().await = Some(handle); @@ -272,12 +241,10 @@ fn capture_loop( config: CaptureConfig, state: Arc>, stats: Arc>, - frame_tx: broadcast::Sender, stop_flag: Arc, - sequence: Arc, error_holder: Arc>>, ) { - let result = run_capture(&config, &state, &stats, &frame_tx, &stop_flag, &sequence); + let result = run_capture(&config, &state, &stats, &stop_flag); match result { Ok(_) => { @@ -300,9 +267,7 @@ fn run_capture( config: &CaptureConfig, state: &watch::Sender, stats: &Arc>, - frame_tx: &broadcast::Sender, stop_flag: &AtomicBool, - sequence: &AtomicU64, ) -> Result<()> { // Retry logic for device busy errors const MAX_RETRIES: u32 = 5; @@ -368,16 +333,7 @@ fn run_capture( }; // Device opened and format set successfully - proceed with capture - return run_capture_inner( - config, - state, - stats, - frame_tx, - stop_flag, - sequence, - device, - actual_format, - ); + return run_capture_inner(config, state, stats, stop_flag, device, actual_format); } // All retries exhausted @@ -391,9 +347,7 @@ fn run_capture_inner( config: &CaptureConfig, state: &watch::Sender, stats: &Arc>, - frame_tx: &broadcast::Sender, stop_flag: &AtomicBool, - sequence: &AtomicU64, device: Device, actual_format: Format, ) -> Result<()> { @@ -402,8 +356,6 @@ fn run_capture_inner( actual_format.width, actual_format.height, actual_format.fourcc, actual_format.stride ); - let resolution = Resolution::new(actual_format.width, actual_format.height); - let pixel_format = PixelFormat::from_fourcc(actual_format.fourcc).unwrap_or(config.format); // Try to set hardware FPS (V4L2 VIDIOC_S_PARM) if config.fps > 0 { @@ -449,18 +401,13 @@ fn run_capture_inner( // Main capture loop while !stop_flag.load(Ordering::Relaxed) { // Try to capture a frame - let (buf, meta) = match stream.next() { + let (_buf, meta) = match stream.next() { Ok(frame_data) => frame_data, Err(e) => { if e.kind() == io::ErrorKind::TimedOut { warn!("Capture timeout - no signal?"); let _ = state.send(CaptureState::NoSignal); - // Update stats - if let Ok(mut s) = stats.try_lock() { - s.signal_present = false; - } - // Wait a bit before retrying std::thread::sleep(Duration::from_millis(100)); continue; @@ -486,9 +433,6 @@ fn run_capture_inner( } error!("Capture error: {}", e); - if let Ok(mut s) = stats.try_lock() { - s.errors += 1; - } continue; } }; @@ -502,54 +446,16 @@ fn run_capture_inner( "Dropping small frame: {} bytes (bytesused={})", frame_size, meta.bytesused ); - if let Ok(mut s) = stats.try_lock() { - s.frames_dropped += 1; - } continue; } - // For JPEG formats, validate header - if pixel_format.is_compressed() && !is_valid_jpeg(&buf[..frame_size]) { - debug!("Dropping invalid JPEG frame (size={})", frame_size); - if let Ok(mut s) = stats.try_lock() { - s.frames_dropped += 1; - } - continue; - } - - // Create frame with actual data size - let seq = sequence.fetch_add(1, Ordering::Relaxed); - let frame = VideoFrame::new( - Bytes::copy_from_slice(&buf[..frame_size]), - resolution, - pixel_format, - actual_format.stride, - seq, - ); - // Update state if was no signal if *state.borrow() == CaptureState::NoSignal { let _ = state.send(CaptureState::Running); } - // Send frame to subscribers - let receiver_count = frame_tx.receiver_count(); - if receiver_count > 0 { - if let Err(e) = frame_tx.send(frame) { - debug!("No active receivers for frame: {}", e); - } - } else if seq % 60 == 0 { - // Log every 60 frames (about 1 second at 60fps) when no receivers - debug!("No receivers for video frames (receiver_count=0)"); - } - - // Update stats + // Update FPS calculation if let Ok(mut s) = stats.try_lock() { - s.frames_captured += 1; - s.signal_present = true; - s.last_frame_ts = Some(Instant::now()); - - // Update FPS calculation fps_frame_count += 1; let elapsed = fps_window_start.elapsed(); @@ -571,6 +477,7 @@ fn run_capture_inner( } /// Validate JPEG frame data +#[cfg(test)] fn is_valid_jpeg(data: &[u8]) -> bool { if data.len() < 125 { return false; diff --git a/src/video/decoder/mjpeg_rkmpp.rs b/src/video/decoder/mjpeg_rkmpp.rs index c95ada1e..686a722c 100644 --- a/src/video/decoder/mjpeg_rkmpp.rs +++ b/src/video/decoder/mjpeg_rkmpp.rs @@ -2,7 +2,7 @@ use hwcodec::ffmpeg::AVPixelFormat; use hwcodec::ffmpeg_ram::decode::{DecodeContext, Decoder}; -use tracing::warn; +use tracing::{info, warn}; use crate::error::{AppError, Result}; use crate::video::convert::Nv12Converter; @@ -72,6 +72,9 @@ impl MjpegRkmppDecoder { ); } } else { + if frame.pixfmt == AVPixelFormat::AV_PIX_FMT_NV16 { + info!("mjpeg_rkmpp output pixfmt NV16 on first frame; converting to NV12"); + } self.last_pixfmt = Some(frame.pixfmt); } diff --git a/src/video/decoder/mod.rs b/src/video/decoder/mod.rs index 55a1569f..32b47d91 100644 --- a/src/video/decoder/mod.rs +++ b/src/video/decoder/mod.rs @@ -2,10 +2,6 @@ //! //! This module provides video decoding capabilities. -#[cfg(any(target_arch = "aarch64", target_arch = "arm"))] -pub mod mjpeg_rkmpp; pub mod mjpeg_turbo; -#[cfg(any(target_arch = "aarch64", target_arch = "arm"))] -pub use mjpeg_rkmpp::MjpegRkmppDecoder; pub use mjpeg_turbo::MjpegTurboDecoder; diff --git a/src/video/device.rs b/src/video/device.rs index 4a655d51..c99b4786 100644 --- a/src/video/device.rs +++ b/src/video/device.rs @@ -2,6 +2,8 @@ use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; +use std::sync::mpsc; +use std::time::Duration; use tracing::{debug, info, warn}; use v4l::capability::Flags; use v4l::prelude::*; @@ -12,6 +14,8 @@ use v4l::FourCC; use super::format::{PixelFormat, Resolution}; use crate::error::{AppError, Result}; +const DEVICE_PROBE_TIMEOUT_MS: u64 = 400; + /// Information about a video device #[derive(Debug, Clone, Serialize, Deserialize)] pub struct VideoDeviceInfo { @@ -401,32 +405,29 @@ pub fn enumerate_devices() -> Result> { debug!("Found video device: {:?}", path); - // Try to open and query the device - match VideoDevice::open(&path) { - Ok(device) => { - match device.info() { - Ok(info) => { - // Only include devices with video capture capability - if info.capabilities.video_capture || info.capabilities.video_capture_mplane - { - info!( - "Found capture device: {} ({}) - {} formats", - info.name, - info.driver, - info.formats.len() - ); - devices.push(info); - } else { - debug!("Skipping non-capture device: {:?}", path); - } - } - Err(e) => { - debug!("Failed to get info for {:?}: {}", path, e); - } + if !sysfs_maybe_capture(&path) { + debug!("Skipping non-capture candidate (sysfs): {:?}", path); + continue; + } + + // Try to open and query the device (with timeout) + match probe_device_with_timeout(&path, Duration::from_millis(DEVICE_PROBE_TIMEOUT_MS)) { + Some(info) => { + // Only include devices with video capture capability + if info.capabilities.video_capture || info.capabilities.video_capture_mplane { + info!( + "Found capture device: {} ({}) - {} formats", + info.name, + info.driver, + info.formats.len() + ); + devices.push(info); + } else { + debug!("Skipping non-capture device: {:?}", path); } } - Err(e) => { - debug!("Failed to open {:?}: {}", path, e); + None => { + debug!("Failed to probe {:?}", path); } } } @@ -438,6 +439,104 @@ pub fn enumerate_devices() -> Result> { Ok(devices) } +fn probe_device_with_timeout(path: &Path, timeout: Duration) -> Option { + let path = path.to_path_buf(); + let path_for_thread = path.clone(); + let (tx, rx) = mpsc::channel(); + + std::thread::spawn(move || { + let result = (|| -> Result { + let device = VideoDevice::open(&path_for_thread)?; + device.info() + })(); + let _ = tx.send(result); + }); + + match rx.recv_timeout(timeout) { + Ok(Ok(info)) => Some(info), + Ok(Err(e)) => { + debug!("Failed to get info for {:?}: {}", path, e); + None + } + Err(mpsc::RecvTimeoutError::Timeout) => { + warn!("Timed out probing video device: {:?}", path); + None + } + Err(_) => None, + } +} + +fn sysfs_maybe_capture(path: &Path) -> bool { + let name = match path.file_name().and_then(|n| n.to_str()) { + Some(name) => name, + None => return true, + }; + let sysfs_base = Path::new("/sys/class/video4linux").join(name); + + let sysfs_name = read_sysfs_string(&sysfs_base.join("name")) + .unwrap_or_default() + .to_lowercase(); + let uevent = read_sysfs_string(&sysfs_base.join("device/uevent")) + .unwrap_or_default() + .to_lowercase(); + let driver = extract_uevent_value(&uevent, "driver"); + + let mut maybe_capture = false; + let capture_hints = [ + "capture", + "hdmi", + "usb", + "uvc", + "ms2109", + "ms2130", + "macrosilicon", + "tc358743", + "grabber", + ]; + if capture_hints.iter().any(|hint| sysfs_name.contains(hint)) { + maybe_capture = true; + } + if let Some(driver) = driver { + if driver.contains("uvcvideo") || driver.contains("tc358743") { + maybe_capture = true; + } + } + + let skip_hints = [ + "codec", + "decoder", + "encoder", + "isp", + "mem2mem", + "m2m", + "vbi", + "radio", + "metadata", + "output", + ]; + if skip_hints.iter().any(|hint| sysfs_name.contains(hint)) && !maybe_capture { + return false; + } + + true +} + +fn read_sysfs_string(path: &Path) -> Option { + std::fs::read_to_string(path) + .ok() + .map(|value| value.trim().to_string()) +} + +fn extract_uevent_value(content: &str, key: &str) -> Option { + let key_upper = key.to_ascii_uppercase(); + for line in content.lines() { + if let Some(value) = line.strip_prefix(&format!("{}=", key_upper)) { + return Some(value.to_lowercase()); + } + } + None +} + /// Find the best video device for KVM use pub fn find_best_device() -> Result { let devices = enumerate_devices()?; diff --git a/src/video/encoder/h264.rs b/src/video/encoder/h264.rs index d02a398b..65c2512b 100644 --- a/src/video/encoder/h264.rs +++ b/src/video/encoder/h264.rs @@ -511,21 +511,6 @@ impl Encoder for H264Encoder { } } -/// Encoder statistics -#[derive(Debug, Clone, Default)] -pub struct EncoderStats { - /// Total frames encoded - pub frames_encoded: u64, - /// Total bytes output - pub bytes_output: u64, - /// Current encoding FPS - pub fps: f32, - /// Average encoding time per frame (ms) - pub avg_encode_time_ms: f32, - /// Keyframes encoded - pub keyframes: u64, -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/video/encoder/registry.rs b/src/video/encoder/registry.rs index edcb7780..1f9dd1a9 100644 --- a/src/video/encoder/registry.rs +++ b/src/video/encoder/registry.rs @@ -7,7 +7,7 @@ use std::collections::HashMap; use std::sync::OnceLock; -use tracing::{debug, info}; +use tracing::{debug, info, warn}; use hwcodec::common::{DataFormat, Quality, RateControl}; use hwcodec::ffmpeg::AVPixelFormat; @@ -255,8 +255,33 @@ impl EncoderRegistry { thread_count: 1, }; - // Get all available encoders from hwcodec - let all_encoders = HwEncoder::available_encoders(ctx, None); + const DETECT_TIMEOUT_MS: u64 = 5000; + + // Get all available encoders from hwcodec with a hard timeout + let all_encoders = { + use std::sync::mpsc; + use std::time::Duration; + + info!("Encoder detection timeout: {}ms", DETECT_TIMEOUT_MS); + + let (tx, rx) = mpsc::channel(); + let ctx_clone = ctx.clone(); + std::thread::spawn(move || { + let result = HwEncoder::available_encoders(ctx_clone, None); + let _ = tx.send(result); + }); + + match rx.recv_timeout(Duration::from_millis(DETECT_TIMEOUT_MS)) { + Ok(encoders) => encoders, + Err(_) => { + warn!( + "Encoder detection timed out after {}ms, skipping hardware detection", + DETECT_TIMEOUT_MS + ); + Vec::new() + } + } + }; info!("Found {} encoders from hwcodec", all_encoders.len()); diff --git a/src/video/frame.rs b/src/video/frame.rs index cd66796e..dc8f4c92 100644 --- a/src/video/frame.rs +++ b/src/video/frame.rs @@ -1,17 +1,110 @@ //! Video frame data structures use bytes::Bytes; +use parking_lot::Mutex; use std::sync::Arc; use std::sync::OnceLock; use std::time::Instant; use super::format::{PixelFormat, Resolution}; +#[derive(Clone)] +enum FrameData { + Bytes(Bytes), + Pooled(Arc), +} + +impl std::fmt::Debug for FrameData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FrameData::Bytes(bytes) => f + .debug_struct("FrameData::Bytes") + .field("len", &bytes.len()) + .finish(), + FrameData::Pooled(buf) => f + .debug_struct("FrameData::Pooled") + .field("len", &buf.len()) + .finish(), + } + } +} + +#[derive(Debug)] +pub struct FrameBufferPool { + pool: Mutex>>, + max_buffers: usize, +} + +impl FrameBufferPool { + pub fn new(max_buffers: usize) -> Self { + Self { + pool: Mutex::new(Vec::new()), + max_buffers: max_buffers.max(1), + } + } + + pub fn take(&self, min_capacity: usize) -> Vec { + let mut pool = self.pool.lock(); + if let Some(mut buf) = pool.pop() { + if buf.capacity() < min_capacity { + buf.reserve(min_capacity - buf.capacity()); + } + buf + } else { + Vec::with_capacity(min_capacity) + } + } + + pub fn put(&self, mut buf: Vec) { + buf.clear(); + let mut pool = self.pool.lock(); + if pool.len() < self.max_buffers { + pool.push(buf); + } + } +} + +pub struct FrameBuffer { + data: Vec, + pool: Option>, +} + +impl FrameBuffer { + pub fn new(data: Vec, pool: Option>) -> Self { + Self { data, pool } + } + + pub fn as_slice(&self) -> &[u8] { + &self.data + } + + pub fn len(&self) -> usize { + self.data.len() + } +} + +impl std::fmt::Debug for FrameBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FrameBuffer") + .field("len", &self.data.len()) + .finish() + } +} + +impl Drop for FrameBuffer { + fn drop(&mut self) { + if let Some(pool) = self.pool.take() { + let data = std::mem::take(&mut self.data); + pool.put(data); + } + } +} + /// A video frame with metadata #[derive(Debug, Clone)] pub struct VideoFrame { /// Raw frame data - data: Arc, + data: FrameData, /// Cached xxHash64 of frame data (lazy computed for deduplication) hash: Arc>, /// Frame resolution @@ -40,7 +133,7 @@ impl VideoFrame { sequence: u64, ) -> Self { Self { - data: Arc::new(data), + data: FrameData::Bytes(data), hash: Arc::new(OnceLock::new()), resolution, format, @@ -63,24 +156,51 @@ impl VideoFrame { Self::new(Bytes::from(data), resolution, format, stride, sequence) } + /// Create a frame from pooled buffer + pub fn from_pooled( + data: Arc, + resolution: Resolution, + format: PixelFormat, + stride: u32, + sequence: u64, + ) -> Self { + Self { + data: FrameData::Pooled(data), + hash: Arc::new(OnceLock::new()), + resolution, + format, + stride, + key_frame: true, + sequence, + capture_ts: Instant::now(), + online: true, + } + } + /// Get frame data as bytes slice pub fn data(&self) -> &[u8] { - &self.data + match &self.data { + FrameData::Bytes(bytes) => bytes, + FrameData::Pooled(buf) => buf.as_slice(), + } } /// Get frame data as Bytes (cheap clone) pub fn data_bytes(&self) -> Bytes { - (*self.data).clone() + match &self.data { + FrameData::Bytes(bytes) => bytes.clone(), + FrameData::Pooled(buf) => Bytes::copy_from_slice(buf.as_slice()), + } } /// Get data length pub fn len(&self) -> usize { - self.data.len() + self.data().len() } /// Check if frame is empty pub fn is_empty(&self) -> bool { - self.data.is_empty() + self.data().is_empty() } /// Get width @@ -108,7 +228,7 @@ impl VideoFrame { pub fn get_hash(&self) -> u64 { *self .hash - .get_or_init(|| xxhash_rust::xxh64::xxh64(self.data.as_ref(), 0)) + .get_or_init(|| xxhash_rust::xxh64::xxh64(self.data(), 0)) } /// Check if format is JPEG/MJPEG @@ -121,25 +241,27 @@ impl VideoFrame { if !self.is_jpeg() { return false; } - if self.data.len() < 125 { + Self::is_valid_jpeg_bytes(self.data()) + } + + /// Validate JPEG bytes without constructing a frame + pub fn is_valid_jpeg_bytes(data: &[u8]) -> bool { + if data.len() < 125 { return false; } - // Check JPEG header - let start_marker = ((self.data[0] as u16) << 8) | self.data[1] as u16; + let start_marker = ((data[0] as u16) << 8) | data[1] as u16; if start_marker != 0xFFD8 { return false; } - // Check JPEG end marker - let end = self.data.len(); - let end_marker = ((self.data[end - 2] as u16) << 8) | self.data[end - 1] as u16; - // Valid end markers: 0xFFD9, 0xD900, 0x0000 (padded) + let end = data.len(); + let end_marker = ((data[end - 2] as u16) << 8) | data[end - 1] as u16; matches!(end_marker, 0xFFD9 | 0xD900 | 0x0000) } /// Create an offline placeholder frame pub fn offline(resolution: Resolution, format: PixelFormat) -> Self { Self { - data: Arc::new(Bytes::new()), + data: FrameData::Bytes(Bytes::new()), hash: Arc::new(OnceLock::new()), resolution, format, @@ -175,65 +297,3 @@ impl From<&VideoFrame> for FrameMeta { } } } - -/// Ring buffer for storing recent frames -pub struct FrameRing { - frames: Vec>, - capacity: usize, - write_pos: usize, - count: usize, -} - -impl FrameRing { - /// Create a new frame ring with specified capacity - pub fn new(capacity: usize) -> Self { - assert!(capacity > 0, "Ring capacity must be > 0"); - Self { - frames: (0..capacity).map(|_| None).collect(), - capacity, - write_pos: 0, - count: 0, - } - } - - /// Push a frame into the ring - pub fn push(&mut self, frame: VideoFrame) { - self.frames[self.write_pos] = Some(frame); - self.write_pos = (self.write_pos + 1) % self.capacity; - if self.count < self.capacity { - self.count += 1; - } - } - - /// Get the latest frame - pub fn latest(&self) -> Option<&VideoFrame> { - if self.count == 0 { - return None; - } - let pos = if self.write_pos == 0 { - self.capacity - 1 - } else { - self.write_pos - 1 - }; - self.frames[pos].as_ref() - } - - /// Get number of frames in ring - pub fn len(&self) -> usize { - self.count - } - - /// Check if ring is empty - pub fn is_empty(&self) -> bool { - self.count == 0 - } - - /// Clear all frames - pub fn clear(&mut self) { - for frame in &mut self.frames { - *frame = None; - } - self.write_pos = 0; - self.count = 0; - } -} diff --git a/src/video/h264_pipeline.rs b/src/video/h264_pipeline.rs index ddc47d8f..c20273cf 100644 --- a/src/video/h264_pipeline.rs +++ b/src/video/h264_pipeline.rs @@ -53,22 +53,8 @@ impl Default for H264PipelineConfig { /// H264 pipeline statistics #[derive(Debug, Clone, Default)] pub struct H264PipelineStats { - /// Total frames captured - pub frames_captured: u64, - /// Total frames encoded - pub frames_encoded: u64, - /// Frames dropped (encoding too slow) - pub frames_dropped: u64, - /// Total bytes encoded - pub bytes_encoded: u64, - /// Keyframes encoded - pub keyframes_encoded: u64, - /// Average encoding time per frame (ms) - pub avg_encode_time_ms: f32, /// Current encoding FPS pub current_fps: f32, - /// Errors encountered - pub errors: u64, } /// H264 video encoding pipeline @@ -84,8 +70,6 @@ pub struct H264Pipeline { stats: Arc>, /// Running state running: watch::Sender, - /// Encode time accumulator for averaging - encode_times: Arc>>, } impl H264Pipeline { @@ -183,7 +167,6 @@ impl H264Pipeline { video_track, stats: Arc::new(Mutex::new(H264PipelineStats::default())), running: running_tx, - encode_times: Arc::new(Mutex::new(Vec::with_capacity(100))), }) } @@ -222,7 +205,6 @@ impl H264Pipeline { let nv12_converter = self.nv12_converter.lock().await.take(); let video_track = self.video_track.clone(); let stats = self.stats.clone(); - let encode_times = self.encode_times.clone(); let config = self.config.clone(); let mut running_rx = self.running.subscribe(); @@ -275,12 +257,6 @@ impl H264Pipeline { } } - // Update captured count - { - let mut s = stats.lock().await; - s.frames_captured += 1; - } - // Convert to NV12 for VAAPI encoder // BGR24/RGB24/YUYV -> NV12 (via NV12 converter) // NV12 -> pass through @@ -297,8 +273,6 @@ impl H264Pipeline { Ok(nv12_data) => encoder.encode_raw(nv12_data, pts_ms), Err(e) => { error!("NV12 conversion failed: {}", e); - let mut s = stats.lock().await; - s.errors += 1; continue; } } @@ -323,35 +297,13 @@ impl H264Pipeline { .await { error!("Failed to write frame to track: {}", e); - let mut s = stats.lock().await; - s.errors += 1; } else { - // Update stats - let encode_time = start.elapsed().as_secs_f32() * 1000.0; - let mut s = stats.lock().await; - s.frames_encoded += 1; - s.bytes_encoded += frame.data.len() as u64; - if is_keyframe { - s.keyframes_encoded += 1; - } - - // Update encode time average - let mut times = encode_times.lock().await; - times.push(encode_time); - if times.len() > 100 { - times.remove(0); - } - if !times.is_empty() { - s.avg_encode_time_ms = - times.iter().sum::() / times.len() as f32; - } + let _ = start; } } } Err(e) => { error!("Encoding failed: {}", e); - let mut s = stats.lock().await; - s.errors += 1; } } @@ -365,8 +317,7 @@ impl H264Pipeline { } } Err(broadcast::error::RecvError::Lagged(n)) => { - let mut s = stats.lock().await; - s.frames_dropped += n; + let _ = n; } Err(broadcast::error::RecvError::Closed) => { info!("Frame channel closed, stopping H264 pipeline"); diff --git a/src/video/shared_video_pipeline.rs b/src/video/shared_video_pipeline.rs index 928709e4..6ab721ff 100644 --- a/src/video/shared_video_pipeline.rs +++ b/src/video/shared_video_pipeline.rs @@ -17,20 +17,31 @@ //! ``` use bytes::Bytes; +use parking_lot::RwLock as ParkingRwLock; use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; -use tokio::sync::{broadcast, watch, Mutex, RwLock}; +use tokio::sync::{broadcast, mpsc, watch, Mutex, RwLock}; use tracing::{debug, error, info, trace, warn}; /// Grace period before auto-stopping pipeline when no subscribers (in seconds) const AUTO_STOP_GRACE_PERIOD_SECS: u64 = 3; +/// Minimum valid frame size for capture +const MIN_CAPTURE_FRAME_SIZE: usize = 128; +/// Validate JPEG header every N frames to reduce overhead +const JPEG_VALIDATE_INTERVAL: u64 = 30; use crate::error::{AppError, Result}; use crate::video::convert::{Nv12Converter, PixelConverter}; -#[cfg(any(target_arch = "aarch64", target_arch = "arm"))] -use crate::video::decoder::MjpegRkmppDecoder; use crate::video::decoder::MjpegTurboDecoder; +#[cfg(any(target_arch = "aarch64", target_arch = "arm"))] +use hwcodec::ffmpeg_hw::{last_error_message as ffmpeg_hw_last_error, HwMjpegH26xConfig, HwMjpegH26xPipeline}; +use v4l::buffer::Type as BufferType; +use v4l::io::traits::CaptureStream; +use v4l::prelude::*; +use v4l::video::Capture; +use v4l::video::capture::Parameters; +use v4l::Format; use crate::video::encoder::h264::{detect_best_encoder, H264Config, H264Encoder, H264InputFormat}; use crate::video::encoder::h265::{ detect_best_h265_encoder, H265Config, H265Encoder, H265InputFormat, @@ -40,7 +51,7 @@ use crate::video::encoder::traits::EncoderConfig; use crate::video::encoder::vp8::{detect_best_vp8_encoder, VP8Config, VP8Encoder}; use crate::video::encoder::vp9::{detect_best_vp9_encoder, VP9Config, VP9Encoder}; use crate::video::format::{PixelFormat, Resolution}; -use crate::video::frame::VideoFrame; +use crate::video::frame::{FrameBuffer, FrameBufferPool, VideoFrame}; /// Encoded video frame for distribution #[derive(Debug, Clone)] @@ -59,6 +70,10 @@ pub struct EncodedVideoFrame { pub codec: VideoEncoderType, } +enum PipelineCmd { + SetBitrate { bitrate_kbps: u32, gop: u32 }, +} + /// Shared video pipeline configuration #[derive(Debug, Clone)] pub struct SharedVideoPipelineConfig { @@ -150,16 +165,22 @@ impl SharedVideoPipelineConfig { /// Pipeline statistics #[derive(Debug, Clone, Default)] pub struct SharedVideoPipelineStats { - pub frames_captured: u64, - pub frames_encoded: u64, - pub frames_dropped: u64, - pub frames_skipped: u64, - pub bytes_encoded: u64, - pub keyframes_encoded: u64, - pub avg_encode_time_ms: f32, pub current_fps: f32, - pub errors: u64, - pub subscribers: u64, +} + +struct EncoderThreadState { + encoder: Option>, + mjpeg_decoder: Option, + nv12_converter: Option, + yuv420p_converter: Option, + encoder_needs_yuv420p: bool, + #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] + ffmpeg_hw_pipeline: Option, + #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] + ffmpeg_hw_enabled: bool, + fps: u32, + codec: VideoEncoderType, + input_format: PixelFormat, } /// Universal video encoder trait object @@ -296,16 +317,12 @@ impl VideoEncoderTrait for VP9EncoderWrapper { } enum MjpegDecoderKind { - #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] - Rkmpp(MjpegRkmppDecoder), Turbo(MjpegTurboDecoder), } impl MjpegDecoderKind { fn decode(&mut self, data: &[u8]) -> Result> { match self { - #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] - MjpegDecoderKind::Rkmpp(decoder) => decoder.decode_to_nv12(data), MjpegDecoderKind::Turbo(decoder) => decoder.decode_to_rgb(data), } } @@ -314,18 +331,13 @@ impl MjpegDecoderKind { /// Universal shared video pipeline pub struct SharedVideoPipeline { config: RwLock, - encoder: Mutex>>, - mjpeg_decoder: Mutex>, - nv12_converter: Mutex>, - yuv420p_converter: Mutex>, - /// Whether the encoder needs YUV420P (true) or NV12 (false) - encoder_needs_yuv420p: AtomicBool, - /// Whether YUYV direct input is enabled (RKMPP optimization) - direct_input: AtomicBool, - frame_tx: broadcast::Sender, + subscribers: ParkingRwLock>>>, stats: Mutex, running: watch::Sender, running_rx: watch::Receiver, + cmd_tx: ParkingRwLock>>, + /// Fast running flag for blocking capture loop + running_flag: AtomicBool, /// Frame sequence counter (atomic for lock-free access) sequence: AtomicU64, /// Atomic flag for keyframe request (avoids lock contention) @@ -347,21 +359,16 @@ impl SharedVideoPipeline { config.input_format ); - let (frame_tx, _) = broadcast::channel(16); // Reduced from 64 for lower latency let (running_tx, running_rx) = watch::channel(false); let pipeline = Arc::new(Self { config: RwLock::new(config), - encoder: Mutex::new(None), - mjpeg_decoder: Mutex::new(None), - nv12_converter: Mutex::new(None), - yuv420p_converter: Mutex::new(None), - encoder_needs_yuv420p: AtomicBool::new(false), - direct_input: AtomicBool::new(false), - frame_tx, + subscribers: ParkingRwLock::new(Vec::new()), stats: Mutex::new(SharedVideoPipelineStats::default()), running: running_tx, running_rx, + cmd_tx: ParkingRwLock::new(None), + running_flag: AtomicBool::new(false), sequence: AtomicU64::new(0), keyframe_requested: AtomicBool::new(false), pipeline_start_time_ms: AtomicI64::new(0), @@ -370,9 +377,7 @@ impl SharedVideoPipeline { Ok(pipeline) } - /// Initialize encoder based on config - async fn init_encoder(&self) -> Result<()> { - let config = self.config.read().await.clone(); + fn build_encoder_state(config: &SharedVideoPipelineConfig) -> Result { let registry = EncoderRegistry::global(); // Helper to get codec name for specific backend @@ -501,46 +506,62 @@ impl SharedVideoPipeline { } }; + #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] let is_rkmpp_encoder = selected_codec_name.contains("rkmpp"); - let is_software_encoder = selected_codec_name.contains("libx264") - || selected_codec_name.contains("libx265") - || selected_codec_name.contains("libvpx"); + #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] + if needs_mjpeg_decode + && is_rkmpp_encoder + && matches!(config.output_codec, VideoEncoderType::H264 | VideoEncoderType::H265) + { + info!( + "Initializing FFmpeg HW MJPEG->{} pipeline (no fallback)", + config.output_codec + ); + let hw_config = HwMjpegH26xConfig { + decoder: "mjpeg_rkmpp".to_string(), + encoder: selected_codec_name.clone(), + width: config.resolution.width as i32, + height: config.resolution.height as i32, + fps: config.fps as i32, + bitrate_kbps: config.bitrate_kbps() as i32, + gop: config.gop_size() as i32, + thread_count: 1, + }; + let pipeline = HwMjpegH26xPipeline::new(hw_config).map_err(|e| { + let detail = if e.is_empty() { ffmpeg_hw_last_error() } else { e }; + AppError::VideoError(format!( + "FFmpeg HW MJPEG->{} init failed: {}", + config.output_codec, detail + )) + })?; + info!("Using FFmpeg HW MJPEG->{} pipeline", config.output_codec); + return Ok(EncoderThreadState { + encoder: None, + mjpeg_decoder: None, + nv12_converter: None, + yuv420p_converter: None, + encoder_needs_yuv420p: false, + #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] + ffmpeg_hw_pipeline: Some(pipeline), + #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] + ffmpeg_hw_enabled: true, + fps: config.fps, + codec: config.output_codec, + input_format: config.input_format, + }); + } let pipeline_input_format = if needs_mjpeg_decode { - if is_rkmpp_encoder { - info!( - "MJPEG input detected, using RKMPP decoder ({} -> NV12 with NV16 fallback)", - config.input_format - ); - #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] - { - let decoder = MjpegRkmppDecoder::new(config.resolution)?; - *self.mjpeg_decoder.lock().await = Some(MjpegDecoderKind::Rkmpp(decoder)); - PixelFormat::Nv12 - } - #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))] - { - return Err(AppError::VideoError( - "RKMPP MJPEG decode is only supported on ARM builds".to_string(), - )); - } - } else if is_software_encoder { - info!( - "MJPEG input detected, using TurboJPEG decoder ({} -> RGB24)", - config.input_format - ); - let decoder = MjpegTurboDecoder::new(config.resolution)?; - *self.mjpeg_decoder.lock().await = Some(MjpegDecoderKind::Turbo(decoder)); - PixelFormat::Rgb24 - } else { - return Err(AppError::VideoError( - "MJPEG input requires RKMPP or software encoder".to_string(), - )); - } + info!( + "MJPEG input detected, using TurboJPEG decoder ({} -> RGB24)", + config.input_format + ); + let decoder = MjpegTurboDecoder::new(config.resolution)?; + (Some(MjpegDecoderKind::Turbo(decoder)), PixelFormat::Rgb24) } else { - *self.mjpeg_decoder.lock().await = None; - config.input_format + (None, config.input_format) }; + let (mjpeg_decoder, pipeline_input_format) = pipeline_input_format; // Create encoder based on codec type let encoder: Box = match config.output_codec { @@ -856,24 +877,32 @@ impl SharedVideoPipeline { } }; - *self.encoder.lock().await = Some(encoder); - *self.nv12_converter.lock().await = nv12_converter; - *self.yuv420p_converter.lock().await = yuv420p_converter; - self.encoder_needs_yuv420p - .store(needs_yuv420p, Ordering::Release); - self.direct_input.store(use_direct_input, Ordering::Release); - - Ok(()) + Ok(EncoderThreadState { + encoder: Some(encoder), + mjpeg_decoder, + nv12_converter, + yuv420p_converter, + encoder_needs_yuv420p: needs_yuv420p, + #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] + ffmpeg_hw_pipeline: None, + #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] + ffmpeg_hw_enabled: false, + fps: config.fps, + codec: config.output_codec, + input_format: config.input_format, + }) } /// Subscribe to encoded frames - pub fn subscribe(&self) -> broadcast::Receiver { - self.frame_tx.subscribe() + pub fn subscribe(&self) -> mpsc::Receiver> { + let (tx, rx) = mpsc::channel(4); + self.subscribers.write().push(tx); + rx } /// Get subscriber count pub fn subscriber_count(&self) -> usize { - self.frame_tx.receiver_count() + self.subscribers.read().iter().filter(|tx| !tx.is_closed()).count() } /// Report that a receiver has lagged behind @@ -899,11 +928,50 @@ impl SharedVideoPipeline { info!("[Pipeline] Keyframe requested for new client"); } + fn send_cmd(&self, cmd: PipelineCmd) { + let tx = self.cmd_tx.read().clone(); + if let Some(tx) = tx { + let _ = tx.send(cmd); + } + } + + fn clear_cmd_tx(&self) { + let mut guard = self.cmd_tx.write(); + *guard = None; + } + + fn apply_cmd(&self, state: &mut EncoderThreadState, cmd: PipelineCmd) -> Result<()> { + match cmd { + PipelineCmd::SetBitrate { bitrate_kbps, gop } => { + #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))] + let _ = gop; + #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] + if state.ffmpeg_hw_enabled { + if let Some(ref mut pipeline) = state.ffmpeg_hw_pipeline { + pipeline + .reconfigure(bitrate_kbps as i32, gop as i32) + .map_err(|e| { + let detail = if e.is_empty() { ffmpeg_hw_last_error() } else { e }; + AppError::VideoError(format!( + "FFmpeg HW reconfigure failed: {}", + detail + )) + })?; + return Ok(()); + } + } + + if let Some(ref mut encoder) = state.encoder { + encoder.set_bitrate(bitrate_kbps)?; + } + } + } + Ok(()) + } + /// Get current stats pub async fn stats(&self) -> SharedVideoPipelineStats { - let mut stats = self.stats.lock().await.clone(); - stats.subscribers = self.frame_tx.receiver_count() as u64; - stats + self.stats.lock().await.clone() } /// Check if running @@ -919,6 +987,27 @@ impl SharedVideoPipeline { self.running_rx.clone() } + async fn broadcast_encoded(&self, frame: Arc) { + let subscribers = { + let guard = self.subscribers.read(); + if guard.is_empty() { + return; + } + guard.iter().cloned().collect::>() + }; + + for tx in &subscribers { + if tx.send(frame.clone()).await.is_err() { + // Receiver dropped; cleanup happens below. + } + } + + if subscribers.iter().any(|tx| tx.is_closed()) { + let mut guard = self.subscribers.write(); + guard.retain(|tx| !tx.is_closed()); + } + } + /// Get current codec pub async fn current_codec(&self) -> VideoEncoderType { self.config.read().await.output_codec @@ -938,12 +1027,7 @@ impl SharedVideoPipeline { config.output_codec = codec; } - // Clear encoder state - *self.encoder.lock().await = None; - *self.mjpeg_decoder.lock().await = None; - *self.nv12_converter.lock().await = None; - *self.yuv420p_converter.lock().await = None; - self.encoder_needs_yuv420p.store(false, Ordering::Release); + self.clear_cmd_tx(); info!("Switched to {} codec", codec); Ok(()) @@ -959,10 +1043,10 @@ impl SharedVideoPipeline { return Ok(()); } - self.init_encoder().await?; - let _ = self.running.send(true); - let config = self.config.read().await.clone(); + let mut encoder_state = Self::build_encoder_state(&config)?; + let _ = self.running.send(true); + self.running_flag.store(true, Ordering::Release); let gop_size = config.gop_size(); info!( "Starting {} pipeline (GOP={})", @@ -970,6 +1054,11 @@ impl SharedVideoPipeline { ); let pipeline = self.clone(); + let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel(); + { + let mut guard = self.cmd_tx.write(); + *guard = Some(cmd_tx); + } tokio::spawn(async move { let mut frame_count: u64 = 0; @@ -977,13 +1066,6 @@ impl SharedVideoPipeline { let mut fps_frame_count: u64 = 0; let mut running_rx = pipeline.running_rx.clone(); - // Local counters for batch stats update (reduce lock contention) - let mut local_frames_encoded: u64 = 0; - let mut local_bytes_encoded: u64 = 0; - let mut local_keyframes: u64 = 0; - let mut local_errors: u64 = 0; - let mut local_dropped: u64 = 0; - let mut local_skipped: u64 = 0; // Track when we last had subscribers for auto-stop feature let mut no_subscribers_since: Option = None; let grace_period = Duration::from_secs(AUTO_STOP_GRACE_PERIOD_SECS); @@ -1001,7 +1083,12 @@ impl SharedVideoPipeline { result = frame_rx.recv() => { match result { Ok(video_frame) => { - let subscriber_count = pipeline.frame_tx.receiver_count(); + while let Ok(cmd) = cmd_rx.try_recv() { + if let Err(e) = pipeline.apply_cmd(&mut encoder_state, cmd) { + error!("Failed to apply pipeline command: {}", e); + } + } + let subscriber_count = pipeline.subscriber_count(); if subscriber_count == 0 { // Track when we started having no subscribers @@ -1019,6 +1106,9 @@ impl SharedVideoPipeline { ); // Signal stop and break out of loop let _ = pipeline.running.send(false); + pipeline + .running_flag + .store(false, Ordering::Release); break; } } @@ -1033,18 +1123,10 @@ impl SharedVideoPipeline { } } - match pipeline.encode_frame(&video_frame, frame_count).await { + match pipeline.encode_frame_sync(&mut encoder_state, &video_frame, frame_count) { Ok(Some(encoded_frame)) => { - // Send frame to all subscribers - // Note: broadcast::send is non-blocking - let _ = pipeline.frame_tx.send(encoded_frame.clone()); - - // Update local counters (no lock) - local_frames_encoded += 1; - local_bytes_encoded += encoded_frame.data.len() as u64; - if encoded_frame.is_keyframe { - local_keyframes += 1; - } + let encoded_arc = Arc::new(encoded_frame); + pipeline.broadcast_encoded(encoded_arc).await; frame_count += 1; fps_frame_count += 1; @@ -1052,11 +1134,10 @@ impl SharedVideoPipeline { Ok(None) => {} Err(e) => { error!("Encoding failed: {}", e); - local_errors += 1; } } - // Batch update stats every second (reduces lock contention) + // Update FPS every second (reduces lock contention) let fps_elapsed = last_fps_time.elapsed(); if fps_elapsed >= Duration::from_secs(1) { let current_fps = @@ -1064,27 +1145,13 @@ impl SharedVideoPipeline { fps_frame_count = 0; last_fps_time = Instant::now(); - // Single lock acquisition for all stats + // Single lock acquisition for FPS let mut s = pipeline.stats.lock().await; - s.frames_encoded += local_frames_encoded; - s.bytes_encoded += local_bytes_encoded; - s.keyframes_encoded += local_keyframes; - s.errors += local_errors; - s.frames_dropped += local_dropped; - s.frames_skipped += local_skipped; s.current_fps = current_fps; - - // Reset local counters - local_frames_encoded = 0; - local_bytes_encoded = 0; - local_keyframes = 0; - local_errors = 0; - local_dropped = 0; - local_skipped = 0; } } Err(broadcast::error::RecvError::Lagged(n)) => { - local_dropped += n; + let _ = n; } Err(broadcast::error::RecvError::Closed) => { break; @@ -1094,37 +1161,277 @@ impl SharedVideoPipeline { } } + pipeline.clear_cmd_tx(); + pipeline.running_flag.store(false, Ordering::Release); info!("Video pipeline stopped"); }); Ok(()) } - /// Encode a single frame - async fn encode_frame( + /// Start the pipeline by owning capture + encode in a single loop. + /// + /// This avoids the raw-frame broadcast path and keeps capture and encode + /// in the same thread for lower overhead. + pub async fn start_with_device( + self: &Arc, + device_path: std::path::PathBuf, + buffer_count: u32, + _jpeg_quality: u8, + ) -> Result<()> { + if *self.running_rx.borrow() { + warn!("Pipeline already running"); + return Ok(()); + } + + let config = self.config.read().await.clone(); + let mut encoder_state = Self::build_encoder_state(&config)?; + let _ = self.running.send(true); + self.running_flag.store(true, Ordering::Release); + + let pipeline = self.clone(); + let latest_frame: Arc>>> = + Arc::new(ParkingRwLock::new(None)); + let (frame_seq_tx, mut frame_seq_rx) = watch::channel(0u64); + let buffer_pool = Arc::new(FrameBufferPool::new(buffer_count.max(4) as usize)); + let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel(); + { + let mut guard = self.cmd_tx.write(); + *guard = Some(cmd_tx); + } + + // Encoder loop (runs on tokio, consumes latest frame) + { + let pipeline = pipeline.clone(); + let latest_frame = latest_frame.clone(); + tokio::spawn(async move { + let mut frame_count: u64 = 0; + let mut last_fps_time = Instant::now(); + let mut fps_frame_count: u64 = 0; + let mut last_seq = *frame_seq_rx.borrow(); + + while pipeline.running_flag.load(Ordering::Acquire) { + if frame_seq_rx.changed().await.is_err() { + break; + } + if !pipeline.running_flag.load(Ordering::Acquire) { + break; + } + + let seq = *frame_seq_rx.borrow(); + if seq == last_seq { + continue; + } + last_seq = seq; + + if pipeline.subscriber_count() == 0 { + continue; + } + + while let Ok(cmd) = cmd_rx.try_recv() { + if let Err(e) = pipeline.apply_cmd(&mut encoder_state, cmd) { + error!("Failed to apply pipeline command: {}", e); + } + } + + let frame = { + let guard = latest_frame.read(); + guard.clone() + }; + let frame = match frame { + Some(f) => f, + None => continue, + }; + + match pipeline.encode_frame_sync(&mut encoder_state, &frame, frame_count) { + Ok(Some(encoded_frame)) => { + let encoded_arc = Arc::new(encoded_frame); + pipeline.broadcast_encoded(encoded_arc).await; + + frame_count += 1; + fps_frame_count += 1; + } + Ok(None) => {} + Err(e) => { + error!("Encoding failed: {}", e); + } + } + + let fps_elapsed = last_fps_time.elapsed(); + if fps_elapsed >= Duration::from_secs(1) { + let current_fps = fps_frame_count as f32 / fps_elapsed.as_secs_f32(); + fps_frame_count = 0; + last_fps_time = Instant::now(); + + let mut s = pipeline.stats.lock().await; + s.current_fps = current_fps; + } + } + + pipeline.clear_cmd_tx(); + }); + } + + // Capture loop (runs on thread, updates latest frame) + { + let pipeline = pipeline.clone(); + let latest_frame = latest_frame.clone(); + let frame_seq_tx = frame_seq_tx.clone(); + let buffer_pool = buffer_pool.clone(); + std::thread::spawn(move || { + let device = match Device::with_path(&device_path) { + Ok(d) => d, + Err(e) => { + error!("Failed to open device {:?}: {}", device_path, e); + let _ = pipeline.running.send(false); + pipeline.running_flag.store(false, Ordering::Release); + let _ = frame_seq_tx.send(1); + return; + } + }; + + let requested_format = Format::new( + config.resolution.width, + config.resolution.height, + config.input_format.to_fourcc(), + ); + + let actual_format = match device.set_format(&requested_format) { + Ok(f) => f, + Err(e) => { + error!("Failed to set capture format: {}", e); + let _ = pipeline.running.send(false); + pipeline.running_flag.store(false, Ordering::Release); + let _ = frame_seq_tx.send(1); + return; + } + }; + + let resolution = Resolution::new(actual_format.width, actual_format.height); + let pixel_format = + PixelFormat::from_fourcc(actual_format.fourcc).unwrap_or(config.input_format); + let stride = actual_format.stride; + + if config.fps > 0 { + if let Err(e) = device.set_params(&Parameters::with_fps(config.fps)) { + warn!("Failed to set hardware FPS: {}", e); + } + } + + let mut stream = match MmapStream::with_buffers( + &device, + BufferType::VideoCapture, + buffer_count.max(1), + ) { + Ok(s) => s, + Err(e) => { + error!("Failed to create capture stream: {}", e); + let _ = pipeline.running.send(false); + pipeline.running_flag.store(false, Ordering::Release); + let _ = frame_seq_tx.send(1); + return; + } + }; + + let mut no_subscribers_since: Option = None; + let grace_period = Duration::from_secs(AUTO_STOP_GRACE_PERIOD_SECS); + let mut sequence: u64 = 0; + let mut validate_counter: u64 = 0; + + while pipeline.running_flag.load(Ordering::Acquire) { + let subscriber_count = pipeline.subscriber_count(); + if subscriber_count == 0 { + if no_subscribers_since.is_none() { + no_subscribers_since = Some(Instant::now()); + trace!("No subscribers, starting grace period timer"); + } + + if let Some(since) = no_subscribers_since { + if since.elapsed() >= grace_period { + info!( + "No subscribers for {}s, auto-stopping video pipeline", + grace_period.as_secs() + ); + let _ = pipeline.running.send(false); + pipeline.running_flag.store(false, Ordering::Release); + let _ = frame_seq_tx.send(sequence.wrapping_add(1)); + break; + } + } + + std::thread::sleep(Duration::from_millis(5)); + continue; + } else if no_subscribers_since.is_some() { + trace!("Subscriber connected, resetting grace period timer"); + no_subscribers_since = None; + } + + let (buf, meta) = match stream.next() { + Ok(frame_data) => frame_data, + Err(e) => { + if e.kind() == std::io::ErrorKind::TimedOut { + warn!("Capture timeout - no signal?"); + } else { + error!("Capture error: {}", e); + } + continue; + } + }; + + let frame_size = meta.bytesused as usize; + if frame_size < MIN_CAPTURE_FRAME_SIZE { + continue; + } + + validate_counter = validate_counter.wrapping_add(1); + if pixel_format.is_compressed() + && validate_counter % JPEG_VALIDATE_INTERVAL == 0 + && !VideoFrame::is_valid_jpeg_bytes(&buf[..frame_size]) + { + continue; + } + + let mut owned = buffer_pool.take(frame_size); + owned.resize(frame_size, 0); + owned[..frame_size].copy_from_slice(&buf[..frame_size]); + let frame = Arc::new(VideoFrame::from_pooled( + Arc::new(FrameBuffer::new(owned, Some(buffer_pool.clone()))), + resolution, + pixel_format, + stride, + sequence, + )); + sequence = sequence.wrapping_add(1); + + { + let mut guard = latest_frame.write(); + *guard = Some(frame); + } + let _ = frame_seq_tx.send(sequence); + + } + + pipeline.running_flag.store(false, Ordering::Release); + let _ = pipeline.running.send(false); + let _ = frame_seq_tx.send(sequence.wrapping_add(1)); + info!("Video pipeline stopped"); + }); + } + + Ok(()) + } + + /// Encode a single frame (synchronous, no async locks) + fn encode_frame_sync( &self, + state: &mut EncoderThreadState, frame: &VideoFrame, frame_count: u64, ) -> Result> { - let (fps, codec, input_format) = { - let config = self.config.read().await; - (config.fps, config.output_codec, config.input_format) - }; - + let fps = state.fps; + let codec = state.codec; + let input_format = state.input_format; let raw_frame = frame.data(); - let decoded_buf = if input_format.is_compressed() { - let decoded = { - let mut decoder_guard = self.mjpeg_decoder.lock().await; - let decoder = decoder_guard.as_mut().ok_or_else(|| { - AppError::VideoError("MJPEG decoder not initialized".to_string()) - })?; - decoder.decode(raw_frame)? - }; - Some(decoded) - } else { - None - }; - let raw_frame = decoded_buf.as_deref().unwrap_or(raw_frame); // Calculate PTS from real capture timestamp (lock-free using AtomicI64) // This ensures smooth playback even when capture timing varies @@ -1149,6 +1456,53 @@ impl SharedVideoPipeline { current_ts_ms - start_ts }; + #[cfg(any(target_arch = "aarch64", target_arch = "arm"))] + if state.ffmpeg_hw_enabled { + if input_format != PixelFormat::Mjpeg { + return Err(AppError::VideoError( + "FFmpeg HW pipeline requires MJPEG input".to_string(), + )); + } + let pipeline = state.ffmpeg_hw_pipeline.as_mut().ok_or_else(|| { + AppError::VideoError("FFmpeg HW pipeline not initialized".to_string()) + })?; + + if self.keyframe_requested.swap(false, Ordering::AcqRel) { + pipeline.request_keyframe(); + debug!("[Pipeline] FFmpeg HW keyframe requested"); + } + + let packet = pipeline.encode(raw_frame, pts_ms).map_err(|e| { + let detail = if e.is_empty() { ffmpeg_hw_last_error() } else { e }; + AppError::VideoError(format!("FFmpeg HW encode failed: {}", detail)) + })?; + + if let Some((data, is_keyframe)) = packet { + let sequence = self.sequence.fetch_add(1, Ordering::Relaxed) + 1; + return Ok(Some(EncodedVideoFrame { + data: Bytes::from(data), + pts_ms, + is_keyframe, + sequence, + duration: Duration::from_millis(1000 / fps as u64), + codec, + })); + } + + return Ok(None); + } + + let decoded_buf = if input_format.is_compressed() { + let decoder = state.mjpeg_decoder.as_mut().ok_or_else(|| { + AppError::VideoError("MJPEG decoder not initialized".to_string()) + })?; + let decoded = decoder.decode(raw_frame)?; + Some(decoded) + } else { + None + }; + let raw_frame = decoded_buf.as_deref().unwrap_or(raw_frame); + // Debug log for H265 if codec == VideoEncoderType::H265 && frame_count % 30 == 1 { debug!( @@ -1159,12 +1513,9 @@ impl SharedVideoPipeline { ); } - let mut nv12_converter = self.nv12_converter.lock().await; - let mut yuv420p_converter = self.yuv420p_converter.lock().await; - let needs_yuv420p = self.encoder_needs_yuv420p.load(Ordering::Acquire); - let mut encoder_guard = self.encoder.lock().await; - - let encoder = encoder_guard + let needs_yuv420p = state.encoder_needs_yuv420p; + let encoder = state + .encoder .as_mut() .ok_or_else(|| AppError::VideoError("Encoder not initialized".to_string()))?; @@ -1174,16 +1525,16 @@ impl SharedVideoPipeline { debug!("[Pipeline] Keyframe will be generated for this frame"); } - let encode_result = if needs_yuv420p && yuv420p_converter.is_some() { + let encode_result = if needs_yuv420p && state.yuv420p_converter.is_some() { // Software encoder with direct input conversion to YUV420P - let conv = yuv420p_converter.as_mut().unwrap(); + let conv = state.yuv420p_converter.as_mut().unwrap(); let yuv420p_data = conv .convert(raw_frame) .map_err(|e| AppError::VideoError(format!("YUV420P conversion failed: {}", e)))?; encoder.encode_raw(yuv420p_data, pts_ms) - } else if nv12_converter.is_some() { + } else if state.nv12_converter.is_some() { // Hardware encoder with input conversion to NV12 - let conv = nv12_converter.as_mut().unwrap(); + let conv = state.nv12_converter.as_mut().unwrap(); let nv12_data = conv .convert(raw_frame) .map_err(|e| AppError::VideoError(format!("NV12 conversion failed: {}", e)))?; @@ -1193,10 +1544,6 @@ impl SharedVideoPipeline { encoder.encode_raw(raw_frame, pts_ms) }; - drop(encoder_guard); - drop(nv12_converter); - drop(yuv420p_converter); - match encode_result { Ok(frames) => { if !frames.is_empty() { @@ -1255,6 +1602,8 @@ impl SharedVideoPipeline { pub fn stop(&self) { if *self.running_rx.borrow() { let _ = self.running.send(false); + self.running_flag.store(false, Ordering::Release); + self.clear_cmd_tx(); info!("Stopping video pipeline"); } } @@ -1265,10 +1614,12 @@ impl SharedVideoPipeline { preset: crate::video::encoder::BitratePreset, ) -> Result<()> { let bitrate_kbps = preset.bitrate_kbps(); - if let Some(ref mut encoder) = *self.encoder.lock().await { - encoder.set_bitrate(bitrate_kbps)?; - self.config.write().await.bitrate_preset = preset; - } + let gop = { + let mut config = self.config.write().await; + config.bitrate_preset = preset; + config.gop_size() + }; + self.send_cmd(PipelineCmd::SetBitrate { bitrate_kbps, gop }); Ok(()) } diff --git a/src/video/stream_manager.rs b/src/video/stream_manager.rs index 3dc56398..cabe553a 100644 --- a/src/video/stream_manager.rs +++ b/src/video/stream_manager.rs @@ -135,7 +135,8 @@ impl VideoStreamManager { /// Set event bus for notifications pub async fn set_event_bus(&self, events: Arc) { - *self.events.write().await = Some(events); + *self.events.write().await = Some(events.clone()); + self.webrtc_streamer.set_event_bus(events).await; } /// Set configuration store @@ -199,19 +200,20 @@ impl VideoStreamManager { } } - // Always reconnect frame source after initialization - // This ensures WebRTC has the correct frame_tx from the current capturer - if let Some(frame_tx) = self.streamer.frame_sender().await { - // Synchronize WebRTC config with actual capture format - let (format, resolution, fps) = self.streamer.current_video_config().await; - info!( - "Reconnecting frame source to WebRTC after init: {}x{} {:?} @ {}fps (receiver_count={})", - resolution.width, resolution.height, format, fps, frame_tx.receiver_count() - ); + // Configure WebRTC capture source after initialization + let (device_path, resolution, format, fps, jpeg_quality) = + self.streamer.current_capture_config().await; + info!( + "WebRTC capture config after init: {}x{} {:?} @ {}fps", + resolution.width, resolution.height, format, fps + ); + self.webrtc_streamer + .update_video_config(resolution, format, fps) + .await; + if let Some(device_path) = device_path { self.webrtc_streamer - .update_video_config(resolution, format, fps) + .set_capture_device(device_path, jpeg_quality) .await; - self.webrtc_streamer.set_video_source(frame_tx).await; } Ok(()) @@ -329,7 +331,7 @@ impl VideoStreamManager { /// Ensure video capture is running (for WebRTC mode) async fn ensure_video_capture_running(self: &Arc) -> Result<()> { - // Initialize streamer if not already initialized + // Initialize streamer if not already initialized (for config discovery) if self.streamer.state().await == StreamerState::Uninitialized { info!("Initializing video capture for WebRTC (ensure)"); if let Err(e) = self.streamer.init_auto().await { @@ -338,29 +340,19 @@ impl VideoStreamManager { } } - // Start video capture if not streaming - if self.streamer.state().await != StreamerState::Streaming { - info!("Starting video capture for WebRTC (ensure)"); - if let Err(e) = self.streamer.start().await { - error!("Failed to start video capture: {}", e); - return Err(e); - } - - // Wait a bit for capture to stabilize - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - } - - // Reconnect frame source to WebRTC - if let Some(frame_tx) = self.streamer.frame_sender().await { - let (format, resolution, fps) = self.streamer.current_video_config().await; - info!( - "Reconnecting frame source to WebRTC: {}x{} {:?} @ {}fps", - resolution.width, resolution.height, format, fps - ); + let (device_path, resolution, format, fps, jpeg_quality) = + self.streamer.current_capture_config().await; + info!( + "Configuring WebRTC capture: {}x{} {:?} @ {}fps", + resolution.width, resolution.height, format, fps + ); + self.webrtc_streamer + .update_video_config(resolution, format, fps) + .await; + if let Some(device_path) = device_path { self.webrtc_streamer - .update_video_config(resolution, format, fps) + .set_capture_device(device_path, jpeg_quality) .await; - self.webrtc_streamer.set_video_source(frame_tx).await; } Ok(()) @@ -403,7 +395,6 @@ impl VideoStreamManager { match current_mode { StreamMode::Mjpeg => { info!("Stopping MJPEG streaming"); - // Only stop MJPEG distribution, keep video capture running for WebRTC self.streamer.mjpeg_handler().set_offline(); if let Err(e) = self.streamer.stop().await { warn!("Error stopping MJPEG streamer: {}", e); @@ -458,10 +449,9 @@ impl VideoStreamManager { } } StreamMode::WebRTC => { - // WebRTC mode: ensure video capture is running for H264 encoding + // WebRTC mode: configure direct capture for encoder pipeline info!("Activating WebRTC mode"); - // Initialize streamer if not already initialized if self.streamer.state().await == StreamerState::Uninitialized { info!("Initializing video capture for WebRTC"); if let Err(e) = self.streamer.init_auto().await { @@ -470,77 +460,32 @@ impl VideoStreamManager { } } - // Auto-switch to non-compressed format if current format is MJPEG/JPEG - if let Some(device) = self.streamer.current_device().await { - let (current_format, resolution, fps) = - self.streamer.current_video_config().await; - - if current_format.is_compressed() { - let available_formats: Vec = - device.formats.iter().map(|f| f.format).collect(); - - // Determine if using hardware encoding - let is_hardware = self.webrtc_streamer.is_hardware_encoding().await; - - if let Some(recommended) = - PixelFormat::recommended_for_encoding(&available_formats, is_hardware) - { - info!( - "Auto-switching from {:?} to {:?} for WebRTC encoding (hardware={})", - current_format, recommended, is_hardware - ); - let device_path = device.path.to_string_lossy().to_string(); - if let Err(e) = self - .streamer - .apply_video_config(&device_path, recommended, resolution, fps) - .await - { - warn!("Failed to auto-switch format for WebRTC: {}, keeping current format", e); - } - } - } - } - - // Start video capture if not streaming - if self.streamer.state().await != StreamerState::Streaming { - info!("Starting video capture for WebRTC"); - if let Err(e) = self.streamer.start().await { - error!("Failed to start video capture for WebRTC: {}", e); - return Err(e); - } - } - - // Wait a bit for capture to stabilize - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - // Connect frame source to WebRTC with correct format - if let Some(frame_tx) = self.streamer.frame_sender().await { - // Synchronize WebRTC config with actual capture format - let (format, resolution, fps) = self.streamer.current_video_config().await; - info!( - "Connecting frame source to WebRTC pipeline: {}x{} {:?} @ {}fps", - resolution.width, resolution.height, format, fps - ); - self.webrtc_streamer - .update_video_config(resolution, format, fps) - .await; - self.webrtc_streamer.set_video_source(frame_tx).await; - - // Publish WebRTCReady event - frame source is now connected - let codec = self.webrtc_streamer.current_video_codec().await; - let is_hardware = self.webrtc_streamer.is_hardware_encoding().await; - self.publish_event(SystemEvent::WebRTCReady { - transition_id: Some(transition_id.clone()), - codec: codec_to_string(codec), - hardware: is_hardware, - }) + let (device_path, resolution, format, fps, jpeg_quality) = + self.streamer.current_capture_config().await; + info!( + "Configuring WebRTC capture pipeline: {}x{} {:?} @ {}fps", + resolution.width, resolution.height, format, fps + ); + self.webrtc_streamer + .update_video_config(resolution, format, fps) .await; + if let Some(device_path) = device_path { + self.webrtc_streamer + .set_capture_device(device_path, jpeg_quality) + .await; } else { - warn!( - "No frame source available for WebRTC - sessions may fail to receive video" - ); + warn!("No capture device configured for WebRTC"); } + let codec = self.webrtc_streamer.current_video_codec().await; + let is_hardware = self.webrtc_streamer.is_hardware_encoding().await; + self.publish_event(SystemEvent::WebRTCReady { + transition_id: Some(transition_id.clone()), + codec: codec_to_string(codec), + hardware: is_hardware, + }) + .await; + info!("WebRTC mode activated (sessions created on-demand)"); } } @@ -587,36 +532,34 @@ impl VideoStreamManager { .update_video_config(resolution, format, fps) .await; - // Restart video capture for WebRTC (it was stopped during config change) - info!("Restarting video capture for WebRTC after config change"); - if let Err(e) = self.streamer.start().await { - error!("Failed to restart video capture for WebRTC: {}", e); - return Err(e); + let (device_path, actual_resolution, actual_format, actual_fps, jpeg_quality) = + self.streamer.current_capture_config().await; + if actual_format != format || actual_resolution != resolution || actual_fps != fps { + info!( + "Actual capture config differs from requested, updating WebRTC: {}x{} {:?} @ {}fps", + actual_resolution.width, actual_resolution.height, actual_format, actual_fps + ); + self.webrtc_streamer + .update_video_config(actual_resolution, actual_format, actual_fps) + .await; } - - // Wait a bit for capture to stabilize - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - // Reconnect frame source with the new capturer - if let Some(frame_tx) = self.streamer.frame_sender().await { - // Note: update_video_config was already called above with the requested config, - // but verify that actual capture matches - let (actual_format, actual_resolution, actual_fps) = - self.streamer.current_video_config().await; - if actual_format != format || actual_resolution != resolution || actual_fps != fps { - info!( - "Actual capture config differs from requested, updating WebRTC: {}x{} {:?} @ {}fps", - actual_resolution.width, actual_resolution.height, actual_format, actual_fps - ); - self.webrtc_streamer - .update_video_config(actual_resolution, actual_format, actual_fps) - .await; - } - info!("Reconnecting frame source to WebRTC after config change"); - self.webrtc_streamer.set_video_source(frame_tx).await; + if let Some(device_path) = device_path { + info!("Configuring direct capture for WebRTC after config change"); + self.webrtc_streamer + .set_capture_device(device_path, jpeg_quality) + .await; } else { - warn!("No frame source available after config change"); + warn!("No capture device configured for WebRTC after config change"); } + + let codec = self.webrtc_streamer.current_video_codec().await; + let is_hardware = self.webrtc_streamer.is_hardware_encoding().await; + self.publish_event(SystemEvent::WebRTCReady { + transition_id: None, + codec: codec_to_string(codec), + hardware: is_hardware, + }) + .await; } Ok(()) @@ -631,22 +574,23 @@ impl VideoStreamManager { self.streamer.start().await?; } StreamMode::WebRTC => { - // Ensure video capture is running + // Ensure device is initialized for config discovery if self.streamer.state().await == StreamerState::Uninitialized { self.streamer.init_auto().await?; } - if self.streamer.state().await != StreamerState::Streaming { - self.streamer.start().await?; - } - // Connect frame source with correct format - if let Some(frame_tx) = self.streamer.frame_sender().await { - // Synchronize WebRTC config with actual capture format - let (format, resolution, fps) = self.streamer.current_video_config().await; + // Synchronize WebRTC config with current capture config + let (device_path, resolution, format, fps, jpeg_quality) = + self.streamer.current_capture_config().await; + self.webrtc_streamer + .update_video_config(resolution, format, fps) + .await; + if let Some(device_path) = device_path { self.webrtc_streamer - .update_video_config(resolution, format, fps) + .set_capture_device(device_path, jpeg_quality) .await; - self.webrtc_streamer.set_video_source(frame_tx).await; + } else { + warn!("No capture device configured for WebRTC"); } } } @@ -764,13 +708,6 @@ impl VideoStreamManager { self.streamer.is_streaming().await } - /// Get frame sender for video frames - pub async fn frame_sender( - &self, - ) -> Option> { - self.streamer.frame_sender().await - } - /// Subscribe to encoded video frames from the shared video pipeline /// /// This allows RustDesk (and other consumers) to receive H264/H265/VP8/VP9 @@ -781,10 +718,10 @@ impl VideoStreamManager { /// Returns None if video capture cannot be started or pipeline creation fails. pub async fn subscribe_encoded_frames( &self, - ) -> Option< - tokio::sync::broadcast::Receiver, - > { - // 1. Ensure video capture is initialized + ) -> Option>> { + // 1. Ensure video capture is initialized (for config discovery) if self.streamer.state().await == StreamerState::Uninitialized { tracing::info!("Initializing video capture for encoded frame subscription"); if let Err(e) = self.streamer.init_auto().await { @@ -796,28 +733,9 @@ impl VideoStreamManager { } } - // 2. Ensure video capture is running (streaming) - if self.streamer.state().await != StreamerState::Streaming { - tracing::info!("Starting video capture for encoded frame subscription"); - if let Err(e) = self.streamer.start().await { - tracing::error!("Failed to start video capture for encoded frames: {}", e); - return None; - } - // Wait for capture to stabilize - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - } - - // 3. Get frame sender from running capture - let frame_tx = match self.streamer.frame_sender().await { - Some(tx) => tx, - None => { - tracing::warn!("Cannot subscribe to encoded frames: no frame sender available"); - return None; - } - }; - - // 4. Synchronize WebRTC config with actual capture format - let (format, resolution, fps) = self.streamer.current_video_config().await; + // 2. Synchronize WebRTC config with capture config + let (device_path, resolution, format, fps, jpeg_quality) = + self.streamer.current_capture_config().await; tracing::info!( "Connecting encoded frame subscription: {}x{} {:?} @ {}fps", resolution.width, @@ -828,14 +746,17 @@ impl VideoStreamManager { self.webrtc_streamer .update_video_config(resolution, format, fps) .await; + if let Some(device_path) = device_path { + self.webrtc_streamer + .set_capture_device(device_path, jpeg_quality) + .await; + } else { + tracing::warn!("No capture device configured for encoded frames"); + return None; + } - // 5. Use WebRtcStreamer to ensure the shared video pipeline is running - // This will create the pipeline if needed - match self - .webrtc_streamer - .ensure_video_pipeline_for_external(frame_tx) - .await - { + // 3. Use WebRtcStreamer to ensure the shared video pipeline is running + match self.webrtc_streamer.ensure_video_pipeline_for_external().await { Ok(pipeline) => Some(pipeline.subscribe()), Err(e) => { tracing::error!("Failed to start shared video pipeline: {}", e); @@ -873,6 +794,11 @@ impl VideoStreamManager { self.webrtc_streamer.set_bitrate_preset(preset).await } + /// Request a keyframe from the shared video pipeline + pub async fn request_keyframe(&self) -> crate::error::Result<()> { + self.webrtc_streamer.request_keyframe().await + } + /// Publish event to event bus async fn publish_event(&self, event: SystemEvent) { if let Some(ref events) = *self.events.read().await { diff --git a/src/video/streamer.rs b/src/video/streamer.rs index 7e0ab5ae..2b11e744 100644 --- a/src/video/streamer.rs +++ b/src/video/streamer.rs @@ -4,17 +4,28 @@ //! managing the lifecycle of the capture thread and MJPEG/WebRTC distribution. use std::path::PathBuf; +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::Arc; -use tokio::sync::{broadcast, RwLock}; +use tokio::sync::RwLock; use tracing::{debug, error, info, trace, warn}; -use super::capture::{CaptureConfig, CaptureState, VideoCapturer}; use super::device::{enumerate_devices, find_best_device, VideoDeviceInfo}; use super::format::{PixelFormat, Resolution}; -use super::frame::VideoFrame; +use super::frame::{FrameBuffer, FrameBufferPool, VideoFrame}; use crate::error::{AppError, Result}; use crate::events::{EventBus, SystemEvent}; use crate::stream::MjpegStreamHandler; +use v4l::buffer::Type as BufferType; +use v4l::io::traits::CaptureStream; +use v4l::prelude::*; +use v4l::video::capture::Parameters; +use v4l::video::Capture; +use v4l::Format; + +/// Minimum valid frame size for capture +const MIN_CAPTURE_FRAME_SIZE: usize = 128; +/// Validate JPEG header every N frames to reduce overhead +const JPEG_VALIDATE_INTERVAL: u64 = 30; /// Streamer configuration #[derive(Debug, Clone)] @@ -65,11 +76,14 @@ pub enum StreamerState { /// Video streamer service pub struct Streamer { config: RwLock, - capturer: RwLock>>, mjpeg_handler: Arc, current_device: RwLock>, state: RwLock, start_lock: tokio::sync::Mutex<()>, + direct_stop: AtomicBool, + direct_active: AtomicBool, + direct_handle: tokio::sync::Mutex>>, + current_fps: AtomicU32, /// Event bus for broadcasting state changes (optional) events: RwLock>>, /// Last published state (for change detection) @@ -94,11 +108,14 @@ impl Streamer { pub fn new() -> Arc { Arc::new(Self { config: RwLock::new(StreamerConfig::default()), - capturer: RwLock::new(None), mjpeg_handler: Arc::new(MjpegStreamHandler::new()), current_device: RwLock::new(None), state: RwLock::new(StreamerState::Uninitialized), start_lock: tokio::sync::Mutex::new(()), + direct_stop: AtomicBool::new(false), + direct_active: AtomicBool::new(false), + direct_handle: tokio::sync::Mutex::new(None), + current_fps: AtomicU32::new(0), events: RwLock::new(None), last_published_state: RwLock::new(None), config_changing: std::sync::atomic::AtomicBool::new(false), @@ -114,11 +131,14 @@ impl Streamer { pub fn with_config(config: StreamerConfig) -> Arc { Arc::new(Self { config: RwLock::new(config), - capturer: RwLock::new(None), mjpeg_handler: Arc::new(MjpegStreamHandler::new()), current_device: RwLock::new(None), state: RwLock::new(StreamerState::Uninitialized), start_lock: tokio::sync::Mutex::new(()), + direct_stop: AtomicBool::new(false), + direct_active: AtomicBool::new(false), + direct_handle: tokio::sync::Mutex::new(None), + current_fps: AtomicU32::new(0), events: RwLock::new(None), last_published_state: RwLock::new(None), config_changing: std::sync::atomic::AtomicBool::new(false), @@ -176,20 +196,6 @@ impl Streamer { self.mjpeg_handler.clone() } - /// Get frame sender for WebRTC integration - /// Returns None if no capturer is initialized - pub async fn frame_sender(&self) -> Option> { - let capturer = self.capturer.read().await; - capturer.as_ref().map(|c| c.frame_sender()) - } - - /// Subscribe to video frames - /// Returns None if no capturer is initialized - pub async fn subscribe_frames(&self) -> Option> { - let capturer = self.capturer.read().await; - capturer.as_ref().map(|c| c.subscribe()) - } - /// Get current device info pub async fn current_device(&self) -> Option { self.current_device.read().await.clone() @@ -201,6 +207,20 @@ impl Streamer { (config.format, config.resolution, config.fps) } + /// Get current capture configuration for direct pipelines + pub async fn current_capture_config( + &self, + ) -> (Option, Resolution, PixelFormat, u32, u8) { + let config = self.config.read().await; + ( + config.device_path.clone(), + config.resolution, + config.format, + config.fps, + config.jpeg_quality, + ) + } + /// List available video devices pub async fn list_devices(&self) -> Result> { enumerate_devices() @@ -278,18 +298,11 @@ impl Streamer { // Give clients time to receive the disconnect signal and close their connections tokio::time::sleep(std::time::Duration::from_millis(100)).await; - // Stop existing capturer and wait for device release - { - // Take ownership of the old capturer to ensure it's dropped - let old_capturer = self.capturer.write().await.take(); - if let Some(capturer) = old_capturer { - info!("Stopping existing capture before applying new config..."); - if let Err(e) = capturer.stop().await { - warn!("Error stopping old capturer: {}", e); - } - // Explicitly drop the capturer to release V4L2 resources - drop(capturer); - } + // Stop active capture and wait for device release + if self.direct_active.load(Ordering::SeqCst) { + info!("Stopping existing capture before applying new config..."); + self.stop().await?; + tokio::time::sleep(std::time::Duration::from_millis(100)).await; } // Update config @@ -301,18 +314,6 @@ impl Streamer { cfg.fps = fps; } - // Recreate capturer - let capture_config = CaptureConfig { - device_path: device.path.clone(), - resolution, - format, - fps, - jpeg_quality: self.config.read().await.jpeg_quality, - ..Default::default() - }; - - let capturer = Arc::new(VideoCapturer::new(capture_config)); - *self.capturer.write().await = Some(capturer.clone()); *self.current_device.write().await = Some(device.clone()); *self.state.write().await = StreamerState::Ready; @@ -374,21 +375,6 @@ impl Streamer { // Store device info *self.current_device.write().await = Some(device.clone()); - // Create capturer - let config = self.config.read().await; - let capture_config = CaptureConfig { - device_path: device.path.clone(), - resolution: config.resolution, - format: config.format, - fps: config.fps, - jpeg_quality: config.jpeg_quality, - ..Default::default() - }; - drop(config); - - let capturer = Arc::new(VideoCapturer::new(capture_config)); - *self.capturer.write().await = Some(capturer); - *self.state.write().await = StreamerState::Ready; info!("Streamer initialized: {} @ {}", format, resolution); @@ -445,43 +431,30 @@ impl Streamer { .ok_or_else(|| AppError::VideoError("No resolutions available".to_string())) } - /// Restart the capturer only (for recovery - doesn't spawn new monitor) - /// - /// This is a simpler version of start() used during device recovery. - /// It doesn't spawn a new state monitor since the existing one is still active. - async fn restart_capturer(&self) -> Result<()> { - let capturer = self.capturer.read().await; - let capturer = capturer - .as_ref() - .ok_or_else(|| AppError::VideoError("Capturer not initialized".to_string()))?; + /// Restart capture for recovery (direct capture path) + async fn restart_capture(self: &Arc) -> Result<()> { + self.direct_stop.store(false, Ordering::SeqCst); + self.start().await?; - // Start capture - capturer.start().await?; - - // Set MJPEG handler online - self.mjpeg_handler.set_online(); - - // Start frame distribution task - let mjpeg_handler = self.mjpeg_handler.clone(); - let mut frame_rx = capturer.subscribe(); - - tokio::spawn(async move { - debug!("Recovery frame distribution task started"); - loop { - match frame_rx.recv().await { - Ok(frame) => { - mjpeg_handler.update_frame(frame); - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {} - Err(tokio::sync::broadcast::error::RecvError::Closed) => { - debug!("Frame channel closed"); - break; - } + // Wait briefly for the capture thread to initialize the device. + // If it fails immediately, the state will flip to Error/DeviceLost. + for _ in 0..5 { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + let state = *self.state.read().await; + match state { + StreamerState::Streaming | StreamerState::NoSignal => return Ok(()), + StreamerState::Error | StreamerState::DeviceLost => { + return Err(AppError::VideoError( + "Failed to restart capture".to_string(), + )) } + _ => {} } - }); + } - Ok(()) + Err(AppError::VideoError( + "Capture restart timed out".to_string(), + )) } /// Start streaming @@ -498,138 +471,26 @@ impl Streamer { self.init_auto().await?; } - let capturer = self.capturer.read().await; - let capturer = capturer - .as_ref() - .ok_or_else(|| AppError::VideoError("Capturer not initialized".to_string()))?; + let device = self + .current_device + .read() + .await + .clone() + .ok_or_else(|| AppError::VideoError("No video device configured".to_string()))?; - // Start capture - capturer.start().await?; + let config = self.config.read().await.clone(); + self.direct_stop.store(false, Ordering::SeqCst); + self.direct_active.store(true, Ordering::SeqCst); - // Set MJPEG handler online before starting frame distribution - // This is important after config changes where disconnect_all_clients() set it offline + let streamer = self.clone(); + let handle = tokio::task::spawn_blocking(move || { + streamer.run_direct_capture(device.path, config); + }); + *self.direct_handle.lock().await = Some(handle); + + // Set MJPEG handler online before starting capture self.mjpeg_handler.set_online(); - // Start frame distribution task - let mjpeg_handler = self.mjpeg_handler.clone(); - let mut frame_rx = capturer.subscribe(); - let state_ref = Arc::downgrade(self); - let frame_tx = capturer.frame_sender(); - - tokio::spawn(async move { - info!("Frame distribution task started"); - - // Track when we started having no active consumers - let mut idle_since: Option = None; - const IDLE_STOP_DELAY_SECS: u64 = 5; - - loop { - match frame_rx.recv().await { - Ok(frame) => { - mjpeg_handler.update_frame(frame); - - // Check if there are any active consumers: - // - MJPEG clients via mjpeg_handler - // - Other subscribers (WebRTC/RustDesk) via frame_tx receiver_count - // Note: receiver_count includes this task, so > 1 means other subscribers - let mjpeg_clients = mjpeg_handler.client_count(); - let other_subscribers = frame_tx.receiver_count().saturating_sub(1); - - if mjpeg_clients == 0 && other_subscribers == 0 { - if idle_since.is_none() { - idle_since = Some(std::time::Instant::now()); - trace!("No active video consumers, starting idle timer"); - } else if let Some(since) = idle_since { - if since.elapsed().as_secs() >= IDLE_STOP_DELAY_SECS { - info!( - "No active video consumers for {}s, stopping frame distribution", - IDLE_STOP_DELAY_SECS - ); - // Stop the streamer - if let Some(streamer) = state_ref.upgrade() { - if let Err(e) = streamer.stop().await { - warn!( - "Failed to stop streamer during idle cleanup: {}", - e - ); - } - } - break; - } - } - } else { - // Reset idle timer when we have consumers - if idle_since.is_some() { - trace!("Video consumers active, resetting idle timer"); - idle_since = None; - } - } - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {} - Err(tokio::sync::broadcast::error::RecvError::Closed) => { - debug!("Frame channel closed"); - break; - } - } - - // Check if streamer still exists - if state_ref.upgrade().is_none() { - break; - } - } - info!("Frame distribution task ended"); - }); - - // Monitor capture state - let mut state_rx = capturer.state_watch(); - let state_ref = Arc::downgrade(self); - let mjpeg_handler = self.mjpeg_handler.clone(); - - tokio::spawn(async move { - while state_rx.changed().await.is_ok() { - let capture_state = *state_rx.borrow(); - match capture_state { - CaptureState::Running => { - if let Some(streamer) = state_ref.upgrade() { - *streamer.state.write().await = StreamerState::Streaming; - } - } - CaptureState::NoSignal => { - mjpeg_handler.set_offline(); - if let Some(streamer) = state_ref.upgrade() { - *streamer.state.write().await = StreamerState::NoSignal; - } - } - CaptureState::Stopped => { - mjpeg_handler.set_offline(); - if let Some(streamer) = state_ref.upgrade() { - *streamer.state.write().await = StreamerState::Ready; - } - } - CaptureState::Error => { - mjpeg_handler.set_offline(); - if let Some(streamer) = state_ref.upgrade() { - *streamer.state.write().await = StreamerState::Error; - } - } - CaptureState::DeviceLost => { - mjpeg_handler.set_offline(); - if let Some(streamer) = state_ref.upgrade() { - *streamer.state.write().await = StreamerState::DeviceLost; - // Start device recovery task (fire and forget) - let streamer_clone = Arc::clone(&streamer); - tokio::spawn(async move { - streamer_clone.start_device_recovery_internal().await; - }); - } - } - CaptureState::Starting => { - // Starting state - device is initializing, no action needed - } - } - } - }); - // Start background tasks only once per Streamer instance // Use compare_exchange to atomically check and set the flag if self @@ -735,9 +596,11 @@ impl Streamer { /// Stop streaming pub async fn stop(&self) -> Result<()> { - if let Some(capturer) = self.capturer.read().await.as_ref() { - capturer.stop().await?; + self.direct_stop.store(true, Ordering::SeqCst); + if let Some(handle) = self.direct_handle.lock().await.take() { + let _ = handle.await; } + self.direct_active.store(false, Ordering::SeqCst); self.mjpeg_handler.set_offline(); *self.state.write().await = StreamerState::Ready; @@ -749,6 +612,258 @@ impl Streamer { Ok(()) } + /// Direct capture loop for MJPEG mode (single loop, no broadcast) + fn run_direct_capture(self: Arc, device_path: PathBuf, config: StreamerConfig) { + const MAX_RETRIES: u32 = 5; + const RETRY_DELAY_MS: u64 = 200; + const IDLE_STOP_DELAY_SECS: u64 = 5; + const BUFFER_COUNT: u32 = 2; + + let handle = tokio::runtime::Handle::current(); + let mut last_state = StreamerState::Streaming; + + let mut set_state = |new_state: StreamerState| { + if new_state != last_state { + handle.block_on(async { + *self.state.write().await = new_state; + self.publish_event(self.current_state_event().await).await; + }); + last_state = new_state; + } + }; + + let mut device_opt: Option = None; + let mut format_opt: Option = None; + let mut last_error: Option = None; + + for attempt in 0..MAX_RETRIES { + if self.direct_stop.load(Ordering::Relaxed) { + self.direct_active.store(false, Ordering::SeqCst); + return; + } + + let device = match Device::with_path(&device_path) { + Ok(d) => d, + Err(e) => { + let err_str = e.to_string(); + if err_str.contains("busy") || err_str.contains("resource") { + warn!( + "Device busy on attempt {}/{}, retrying in {}ms...", + attempt + 1, + MAX_RETRIES, + RETRY_DELAY_MS + ); + std::thread::sleep(std::time::Duration::from_millis(RETRY_DELAY_MS)); + last_error = Some(err_str); + continue; + } + last_error = Some(err_str); + break; + } + }; + + let requested = Format::new( + config.resolution.width, + config.resolution.height, + config.format.to_fourcc(), + ); + + match device.set_format(&requested) { + Ok(actual) => { + device_opt = Some(device); + format_opt = Some(actual); + break; + } + Err(e) => { + let err_str = e.to_string(); + if err_str.contains("busy") || err_str.contains("resource") { + warn!( + "Device busy on set_format attempt {}/{}, retrying in {}ms...", + attempt + 1, + MAX_RETRIES, + RETRY_DELAY_MS + ); + std::thread::sleep(std::time::Duration::from_millis(RETRY_DELAY_MS)); + last_error = Some(err_str); + continue; + } + last_error = Some(err_str); + break; + } + } + } + + let (device, actual_format) = match (device_opt, format_opt) { + (Some(d), Some(f)) => (d, f), + _ => { + error!( + "Failed to open device {:?}: {}", + device_path, + last_error.unwrap_or_else(|| "unknown error".to_string()) + ); + self.mjpeg_handler.set_offline(); + set_state(StreamerState::Error); + self.direct_active.store(false, Ordering::SeqCst); + self.current_fps.store(0, Ordering::Relaxed); + return; + } + }; + + info!( + "Capture format: {}x{} {:?} stride={}", + actual_format.width, actual_format.height, actual_format.fourcc, actual_format.stride + ); + + let resolution = Resolution::new(actual_format.width, actual_format.height); + let pixel_format = + PixelFormat::from_fourcc(actual_format.fourcc).unwrap_or(config.format); + + if config.fps > 0 { + if let Err(e) = device.set_params(&Parameters::with_fps(config.fps)) { + warn!("Failed to set hardware FPS: {}", e); + } + } + + let mut stream = + match MmapStream::with_buffers(&device, BufferType::VideoCapture, BUFFER_COUNT) { + Ok(s) => s, + Err(e) => { + error!("Failed to create capture stream: {}", e); + self.mjpeg_handler.set_offline(); + set_state(StreamerState::Error); + self.direct_active.store(false, Ordering::SeqCst); + self.current_fps.store(0, Ordering::Relaxed); + return; + } + }; + + let buffer_pool = Arc::new(FrameBufferPool::new(BUFFER_COUNT.max(4) as usize)); + let mut signal_present = true; + let mut sequence: u64 = 0; + let mut validate_counter: u64 = 0; + let mut idle_since: Option = None; + + let mut fps_frame_count: u64 = 0; + let mut last_fps_time = std::time::Instant::now(); + + while !self.direct_stop.load(Ordering::Relaxed) { + let mjpeg_clients = self.mjpeg_handler.client_count(); + if mjpeg_clients == 0 { + if idle_since.is_none() { + idle_since = Some(std::time::Instant::now()); + trace!("No active video consumers, starting idle timer"); + } else if let Some(since) = idle_since { + if since.elapsed().as_secs() >= IDLE_STOP_DELAY_SECS { + info!( + "No active video consumers for {}s, stopping capture", + IDLE_STOP_DELAY_SECS + ); + self.mjpeg_handler.set_offline(); + set_state(StreamerState::Ready); + break; + } + } + } else if idle_since.is_some() { + trace!("Video consumers active, resetting idle timer"); + idle_since = None; + } + + let (buf, meta) = match stream.next() { + Ok(frame_data) => frame_data, + Err(e) => { + if e.kind() == std::io::ErrorKind::TimedOut { + if signal_present { + signal_present = false; + self.mjpeg_handler.set_offline(); + set_state(StreamerState::NoSignal); + self.current_fps.store(0, Ordering::Relaxed); + fps_frame_count = 0; + last_fps_time = std::time::Instant::now(); + } + std::thread::sleep(std::time::Duration::from_millis(100)); + continue; + } + + let is_device_lost = match e.raw_os_error() { + Some(6) => true, // ENXIO + Some(19) => true, // ENODEV + Some(5) => true, // EIO + Some(32) => true, // EPIPE + Some(108) => true, // ESHUTDOWN + _ => false, + }; + + if is_device_lost { + error!("Video device lost: {} - {}", device_path.display(), e); + self.mjpeg_handler.set_offline(); + handle.block_on(async { + *self.last_lost_device.write().await = + Some(device_path.display().to_string()); + *self.last_lost_reason.write().await = Some(e.to_string()); + }); + set_state(StreamerState::DeviceLost); + handle.block_on(async { + let streamer = Arc::clone(&self); + tokio::spawn(async move { + streamer.start_device_recovery_internal().await; + }); + }); + break; + } + + error!("Capture error: {}", e); + continue; + } + }; + + let frame_size = meta.bytesused as usize; + if frame_size < MIN_CAPTURE_FRAME_SIZE { + continue; + } + + validate_counter = validate_counter.wrapping_add(1); + if pixel_format.is_compressed() + && validate_counter % JPEG_VALIDATE_INTERVAL == 0 + && !VideoFrame::is_valid_jpeg_bytes(&buf[..frame_size]) + { + continue; + } + + let mut owned = buffer_pool.take(frame_size); + owned.resize(frame_size, 0); + owned[..frame_size].copy_from_slice(&buf[..frame_size]); + let frame = VideoFrame::from_pooled( + Arc::new(FrameBuffer::new(owned, Some(buffer_pool.clone()))), + resolution, + pixel_format, + actual_format.stride, + sequence, + ); + sequence = sequence.wrapping_add(1); + + if !signal_present { + signal_present = true; + self.mjpeg_handler.set_online(); + set_state(StreamerState::Streaming); + } + + self.mjpeg_handler.update_frame(frame); + + fps_frame_count += 1; + let fps_elapsed = last_fps_time.elapsed(); + if fps_elapsed >= std::time::Duration::from_secs(1) { + let current_fps = fps_frame_count as f32 / fps_elapsed.as_secs_f32(); + fps_frame_count = 0; + last_fps_time = std::time::Instant::now(); + self.current_fps + .store((current_fps * 100.0) as u32, Ordering::Relaxed); + } + } + + self.direct_active.store(false, Ordering::SeqCst); + self.current_fps.store(0, Ordering::Relaxed); + } + /// Check if streaming pub async fn is_streaming(&self) -> bool { self.state().await == StreamerState::Streaming @@ -756,14 +871,8 @@ impl Streamer { /// Get stream statistics pub async fn stats(&self) -> StreamerStats { - let capturer = self.capturer.read().await; - let capture_stats = if let Some(c) = capturer.as_ref() { - Some(c.stats().await) - } else { - None - }; - let config = self.config.read().await; + let fps = self.current_fps.load(Ordering::Relaxed) as f32 / 100.0; StreamerStats { state: self.state().await, @@ -772,15 +881,7 @@ impl Streamer { resolution: Some((config.resolution.width, config.resolution.height)), clients: self.mjpeg_handler.client_count(), target_fps: config.fps, - fps: capture_stats.as_ref().map(|s| s.current_fps).unwrap_or(0.0), - frames_captured: capture_stats - .as_ref() - .map(|s| s.frames_captured) - .unwrap_or(0), - frames_dropped: capture_stats - .as_ref() - .map(|s| s.frames_dropped) - .unwrap_or(0), + fps, } } @@ -829,23 +930,23 @@ impl Streamer { return; } - // Get last lost device info from capturer - let (device, reason) = { - let capturer = self.capturer.read().await; - if let Some(cap) = capturer.as_ref() { - cap.last_error().unwrap_or_else(|| { - let device_path = self - .current_device - .blocking_read() - .as_ref() - .map(|d| d.path.display().to_string()) - .unwrap_or_else(|| "unknown".to_string()); - (device_path, "Device lost".to_string()) - }) - } else { - ("unknown".to_string(), "Device lost".to_string()) - } + // Get last lost device info (from direct capture) + let device = if let Some(device) = self.last_lost_device.read().await.clone() { + device + } else { + self.current_device + .read() + .await + .as_ref() + .map(|d| d.path.display().to_string()) + .unwrap_or_else(|| "unknown".to_string()) }; + let reason = self + .last_lost_reason + .read() + .await + .clone() + .unwrap_or_else(|| "Device lost".to_string()); // Store error info *self.last_lost_device.write().await = Some(device.clone()); @@ -908,7 +1009,7 @@ impl Streamer { } // Try to restart capture - match streamer.restart_capturer().await { + match streamer.restart_capture().await { Ok(_) => { info!( "Video device {} recovered after {} attempts", @@ -947,11 +1048,14 @@ impl Default for Streamer { fn default() -> Self { Self { config: RwLock::new(StreamerConfig::default()), - capturer: RwLock::new(None), mjpeg_handler: Arc::new(MjpegStreamHandler::new()), current_device: RwLock::new(None), state: RwLock::new(StreamerState::Uninitialized), start_lock: tokio::sync::Mutex::new(()), + direct_stop: AtomicBool::new(false), + direct_active: AtomicBool::new(false), + direct_handle: tokio::sync::Mutex::new(None), + current_fps: AtomicU32::new(0), events: RwLock::new(None), last_published_state: RwLock::new(None), config_changing: std::sync::atomic::AtomicBool::new(false), @@ -976,8 +1080,6 @@ pub struct StreamerStats { pub target_fps: u32, /// Current actual FPS pub fps: f32, - pub frames_captured: u64, - pub frames_dropped: u64, } impl serde::Serialize for StreamerState { diff --git a/src/video/video_session.rs b/src/video/video_session.rs index 81b5ee41..7b51f725 100644 --- a/src/video/video_session.rs +++ b/src/video/video_session.rs @@ -83,7 +83,7 @@ struct VideoSession { /// Last activity time last_activity: Instant, /// Frame receiver - frame_rx: Option>, + frame_rx: Option>>, /// Stats frames_received: u64, bytes_received: u64, @@ -243,7 +243,7 @@ impl VideoSessionManager { pub async fn start_session( &self, session_id: &str, - ) -> Result> { + ) -> Result>> { // Ensure pipeline is running with correct codec self.ensure_pipeline_for_session(session_id).await?; diff --git a/src/web/audio_ws.rs b/src/web/audio_ws.rs index db9f60f0..564c663e 100644 --- a/src/web/audio_ws.rs +++ b/src/web/audio_ws.rs @@ -26,7 +26,6 @@ use axum::{ use futures::{SinkExt, StreamExt}; use std::sync::Arc; use std::time::Instant; -use tokio::sync::broadcast; use tracing::{debug, info, warn}; use crate::audio::OpusFrame; @@ -79,23 +78,21 @@ async fn handle_audio_socket(socket: WebSocket, state: Arc) { loop { tokio::select! { // Receive Opus frames and send to client - opus_result = opus_rx.recv() => { - match opus_result { - Ok(frame) => { - let binary = encode_audio_packet(&frame, stream_start); - if sender.send(Message::Binary(binary.into())).await.is_err() { - debug!("Failed to send audio frame, client disconnected"); - break; - } - } - Err(broadcast::error::RecvError::Lagged(n)) => { - warn!("Audio WebSocket client lagged by {} frames", n); - // Continue - just skip the missed frames - } - Err(broadcast::error::RecvError::Closed) => { - info!("Audio stream closed"); - break; - } + opus_result = opus_rx.changed() => { + if opus_result.is_err() { + info!("Audio stream closed"); + break; + } + + let frame = match opus_rx.borrow().clone() { + Some(frame) => frame, + None => continue, + }; + + let binary = encode_audio_packet(&frame, stream_start); + if sender.send(Message::Binary(binary.into())).await.is_err() { + debug!("Failed to send audio frame, client disconnected"); + break; } } diff --git a/src/web/handlers/config/apply.rs b/src/web/handlers/config/apply.rs index e80c8683..d7dd88e0 100644 --- a/src/web/handlers/config/apply.rs +++ b/src/web/handlers/config/apply.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use crate::config::*; use crate::error::{AppError, Result}; +use crate::events::SystemEvent; use crate::state::AppState; /// 应用 Video 配置变更 @@ -57,27 +58,55 @@ pub async fn apply_video_config( .map_err(|e| AppError::VideoError(format!("Failed to apply video config: {}", e)))?; tracing::info!("Video config applied to streamer"); - // Step 3: 重启 streamer - if let Err(e) = state.stream_manager.start().await { - tracing::error!("Failed to start streamer after config change: {}", e); - } else { - tracing::info!("Streamer started after config change"); + // Step 3: 重启 streamer(仅 MJPEG 模式) + if !state.stream_manager.is_webrtc_enabled().await { + if let Err(e) = state.stream_manager.start().await { + tracing::error!("Failed to start streamer after config change: {}", e); + } else { + tracing::info!("Streamer started after config change"); + } } - // Step 4: 更新 WebRTC frame source - if let Some(frame_tx) = state.stream_manager.frame_sender().await { - let receiver_count = frame_tx.receiver_count(); + // 配置 WebRTC direct capture(所有模式统一配置) + let (device_path, _resolution, _format, _fps, jpeg_quality) = state + .stream_manager + .streamer() + .current_capture_config() + .await; + if let Some(device_path) = device_path { state .stream_manager .webrtc_streamer() - .set_video_source(frame_tx) + .set_capture_device(device_path, jpeg_quality) .await; - tracing::info!( - "WebRTC streamer frame source updated (receiver_count={})", - receiver_count - ); } else { - tracing::warn!("No frame source available after config change"); + tracing::warn!("No capture device configured for WebRTC"); + } + + if state.stream_manager.is_webrtc_enabled().await { + use crate::video::encoder::VideoCodecType; + let codec = state + .stream_manager + .webrtc_streamer() + .current_video_codec() + .await; + let codec_str = match codec { + VideoCodecType::H264 => "h264", + VideoCodecType::H265 => "h265", + VideoCodecType::VP8 => "vp8", + VideoCodecType::VP9 => "vp9", + } + .to_string(); + let is_hardware = state + .stream_manager + .webrtc_streamer() + .is_hardware_encoding() + .await; + state.events.publish(SystemEvent::WebRTCReady { + transition_id: None, + codec: codec_str, + hardware: is_hardware, + }); } tracing::info!("Video config applied successfully"); @@ -157,6 +186,31 @@ pub async fn apply_hid_config( ) -> Result<()> { // 检查 OTG 描述符是否变更 let descriptor_changed = old_config.otg_descriptor != new_config.otg_descriptor; + let old_hid_functions = old_config.effective_otg_functions(); + let mut new_hid_functions = new_config.effective_otg_functions(); + + // Low-endpoint UDCs (e.g., musb) cannot handle consumer control endpoints reliably + if new_config.backend == HidBackend::Otg { + if let Some(udc) = + crate::otg::configfs::resolve_udc_name(new_config.otg_udc.as_deref()) + { + if crate::otg::configfs::is_low_endpoint_udc(&udc) && new_hid_functions.consumer { + tracing::warn!( + "UDC {} has low endpoint resources, disabling consumer control", + udc + ); + new_hid_functions.consumer = false; + } + } + } + + let hid_functions_changed = old_hid_functions != new_hid_functions; + + if new_config.backend == HidBackend::Otg && new_hid_functions.is_empty() { + return Err(AppError::BadRequest( + "OTG HID functions cannot be empty".to_string(), + )); + } // 如果描述符变更且当前使用 OTG 后端,需要重建 Gadget if descriptor_changed && new_config.backend == HidBackend::Otg { @@ -181,6 +235,7 @@ pub async fn apply_hid_config( && old_config.ch9329_baudrate == new_config.ch9329_baudrate && old_config.otg_udc == new_config.otg_udc && !descriptor_changed + && !hid_functions_changed { tracing::info!("HID config unchanged, skipping reload"); return Ok(()); @@ -188,6 +243,16 @@ pub async fn apply_hid_config( tracing::info!("Applying HID config changes..."); + if new_config.backend == HidBackend::Otg + && (hid_functions_changed || old_config.backend != HidBackend::Otg) + { + state + .otg_service + .update_hid_functions(new_hid_functions.clone()) + .await + .map_err(|e| AppError::Config(format!("OTG HID function update failed: {}", e)))?; + } + let new_hid_backend = match new_config.backend { HidBackend::Otg => crate::hid::HidBackendType::Otg, HidBackend::Ch9329 => crate::hid::HidBackendType::Ch9329 { @@ -208,32 +273,6 @@ pub async fn apply_hid_config( new_config.backend ); - // When switching to OTG backend, automatically enable MSD if not already enabled - // OTG HID and MSD share the same USB gadget, so it makes sense to enable both - if new_config.backend == HidBackend::Otg && old_config.backend != HidBackend::Otg { - let msd_guard = state.msd.read().await; - if msd_guard.is_none() { - drop(msd_guard); // Release read lock before acquiring write lock - - tracing::info!("OTG HID enabled, automatically initializing MSD..."); - - // Get MSD config from store - let config = state.config.get(); - - let msd = - crate::msd::MsdController::new(state.otg_service.clone(), config.msd.msd_dir_path()); - - if let Err(e) = msd.init().await { - tracing::warn!("Failed to auto-initialize MSD for OTG: {}", e); - } else { - let events = state.events.clone(); - msd.set_event_bus(events).await; - *state.msd.write().await = Some(msd); - tracing::info!("MSD automatically initialized for OTG mode"); - } - } - } - Ok(()) } diff --git a/src/web/handlers/config/auth.rs b/src/web/handlers/config/auth.rs new file mode 100644 index 00000000..f08879ae --- /dev/null +++ b/src/web/handlers/config/auth.rs @@ -0,0 +1,33 @@ +use axum::{extract::State, Json}; +use std::sync::Arc; + +use crate::config::AuthConfig; +use crate::error::Result; +use crate::state::AppState; + +use super::types::AuthConfigUpdate; + +/// Get auth configuration (sensitive fields are cleared) +pub async fn get_auth_config(State(state): State>) -> Json { + let mut auth = state.config.get().auth.clone(); + auth.totp_secret = None; + Json(auth) +} + +/// Update auth configuration +pub async fn update_auth_config( + State(state): State>, + Json(update): Json, +) -> Result> { + update.validate()?; + state + .config + .update(|config| { + update.apply_to(&mut config.auth); + }) + .await?; + + let mut auth = state.config.get().auth.clone(); + auth.totp_secret = None; + Ok(Json(auth)) +} diff --git a/src/web/handlers/config/mod.rs b/src/web/handlers/config/mod.rs index 5c9ee5e9..6748ac20 100644 --- a/src/web/handlers/config/mod.rs +++ b/src/web/handlers/config/mod.rs @@ -21,6 +21,7 @@ mod types; mod atx; mod audio; +mod auth; mod hid; mod msd; mod rustdesk; @@ -31,6 +32,7 @@ mod web; // 导出 handler 函数 pub use atx::{get_atx_config, update_atx_config}; pub use audio::{get_audio_config, update_audio_config}; +pub use auth::{get_auth_config, update_auth_config}; pub use hid::{get_hid_config, update_hid_config}; pub use msd::{get_msd_config, update_msd_config}; pub use rustdesk::{ diff --git a/src/web/handlers/config/types.rs b/src/web/handlers/config/types.rs index a3b0e049..3e500085 100644 --- a/src/web/handlers/config/types.rs +++ b/src/web/handlers/config/types.rs @@ -6,6 +6,25 @@ use serde::Deserialize; use std::path::Path; use typeshare::typeshare; +// ===== Auth Config ===== +#[typeshare] +#[derive(Debug, Deserialize)] +pub struct AuthConfigUpdate { + pub single_user_allow_multiple_sessions: Option, +} + +impl AuthConfigUpdate { + pub fn validate(&self) -> crate::error::Result<()> { + Ok(()) + } + + pub fn apply_to(&self, config: &mut AuthConfig) { + if let Some(allow_multiple) = self.single_user_allow_multiple_sessions { + config.single_user_allow_multiple_sessions = allow_multiple; + } + } +} + // ===== Video Config ===== #[typeshare] #[derive(Debug, Deserialize)] @@ -252,6 +271,32 @@ impl OtgDescriptorConfigUpdate { } } +#[typeshare] +#[derive(Debug, Deserialize)] +pub struct OtgHidFunctionsUpdate { + pub keyboard: Option, + pub mouse_relative: Option, + pub mouse_absolute: Option, + pub consumer: Option, +} + +impl OtgHidFunctionsUpdate { + pub fn apply_to(&self, config: &mut OtgHidFunctions) { + if let Some(enabled) = self.keyboard { + config.keyboard = enabled; + } + if let Some(enabled) = self.mouse_relative { + config.mouse_relative = enabled; + } + if let Some(enabled) = self.mouse_absolute { + config.mouse_absolute = enabled; + } + if let Some(enabled) = self.consumer { + config.consumer = enabled; + } + } +} + #[typeshare] #[derive(Debug, Deserialize)] pub struct HidConfigUpdate { @@ -260,6 +305,8 @@ pub struct HidConfigUpdate { pub ch9329_baudrate: Option, pub otg_udc: Option, pub otg_descriptor: Option, + pub otg_profile: Option, + pub otg_functions: Option, pub mouse_absolute: Option, } @@ -295,6 +342,12 @@ impl HidConfigUpdate { if let Some(ref desc) = self.otg_descriptor { desc.apply_to(&mut config.otg_descriptor); } + if let Some(profile) = self.otg_profile.clone() { + config.otg_profile = profile; + } + if let Some(ref functions) = self.otg_functions { + functions.apply_to(&mut config.otg_functions); + } if let Some(absolute) = self.mouse_absolute { config.mouse_absolute = absolute; } @@ -557,6 +610,7 @@ impl RustDeskConfigUpdate { pub struct WebConfigUpdate { pub http_port: Option, pub https_port: Option, + pub bind_addresses: Option>, pub bind_address: Option, pub https_enabled: Option, } @@ -573,6 +627,13 @@ impl WebConfigUpdate { return Err(AppError::BadRequest("HTTPS port cannot be 0".into())); } } + if let Some(ref addrs) = self.bind_addresses { + for addr in addrs { + if addr.parse::().is_err() { + return Err(AppError::BadRequest("Invalid bind address".into())); + } + } + } if let Some(ref addr) = self.bind_address { if addr.parse::().is_err() { return Err(AppError::BadRequest("Invalid bind address".into())); @@ -588,8 +649,16 @@ impl WebConfigUpdate { if let Some(port) = self.https_port { config.https_port = port; } - if let Some(ref addr) = self.bind_address { + if let Some(ref addrs) = self.bind_addresses { + config.bind_addresses = addrs.clone(); + if let Some(first) = addrs.first() { + config.bind_address = first.clone(); + } + } else if let Some(ref addr) = self.bind_address { config.bind_address = addr.clone(); + if config.bind_addresses.is_empty() { + config.bind_addresses = vec![addr.clone()]; + } } if let Some(enabled) = self.https_enabled { config.https_enabled = enabled; diff --git a/src/web/handlers/mod.rs b/src/web/handlers/mod.rs index 787697bf..ea0d077f 100644 --- a/src/web/handlers/mod.rs +++ b/src/web/handlers/mod.rs @@ -12,6 +12,7 @@ use tracing::{info, warn}; use crate::auth::{Session, SESSION_COOKIE}; use crate::config::{AppConfig, StreamMode}; use crate::error::{AppError, Result}; +use crate::events::SystemEvent; use crate::state::AppState; use crate::video::encoder::BitratePreset; @@ -315,28 +316,11 @@ fn get_network_addresses() -> Vec { Err(_) => return Vec::new(), }; - // Build a map of interface name -> IPv4 address - let mut ipv4_map: std::collections::HashMap = std::collections::HashMap::new(); - for ifaddr in all_addrs { - // Skip loopback - if ifaddr.interface_name == "lo" { - continue; - } - // Only collect IPv4 addresses (skip if already have one for this interface) - if !ipv4_map.contains_key(&ifaddr.interface_name) { - if let Some(addr) = ifaddr.address { - if let Some(sockaddr_in) = addr.as_sockaddr_in() { - ipv4_map.insert(ifaddr.interface_name.clone(), sockaddr_in.ip().to_string()); - } - } - } - } - - // Now check which interfaces are up - let mut addresses = Vec::new(); + // Check which interfaces are up + let mut up_ifaces = std::collections::HashSet::new(); let net_dir = match std::fs::read_dir("/sys/class/net") { Ok(dir) => dir, - Err(_) => return addresses, + Err(_) => return Vec::new(), }; for entry in net_dir.flatten() { @@ -360,12 +344,43 @@ fn get_network_addresses() -> Vec { continue; } - // Get IP from pre-fetched map - if let Some(ip) = ipv4_map.remove(&iface_name) { - addresses.push(NetworkAddress { - interface: iface_name, - ip, - }); + up_ifaces.insert(iface_name); + } + + let mut addresses = Vec::new(); + let mut seen = std::collections::HashSet::new(); + for ifaddr in all_addrs { + let iface_name = &ifaddr.interface_name; + if iface_name == "lo" || !up_ifaces.contains(iface_name) { + continue; + } + + if let Some(addr) = ifaddr.address { + if let Some(sockaddr_in) = addr.as_sockaddr_in() { + let ip = sockaddr_in.ip(); + if ip.is_loopback() { + continue; + } + let ip_str = ip.to_string(); + if seen.insert((iface_name.clone(), ip_str.clone())) { + addresses.push(NetworkAddress { + interface: iface_name.clone(), + ip: ip_str, + }); + } + } else if let Some(sockaddr_in6) = addr.as_sockaddr_in6() { + let ip = sockaddr_in6.ip(); + if ip.is_loopback() || ip.is_unspecified() || ip.is_unicast_link_local() { + continue; + } + let ip_str = ip.to_string(); + if seen.insert((iface_name.clone(), ip_str.clone())) { + addresses.push(NetworkAddress { + interface: iface_name.clone(), + ip: ip_str, + }); + } + } } } @@ -407,6 +422,13 @@ pub async fn login( .await? .ok_or_else(|| AppError::AuthError("Invalid username or password".to_string()))?; + if !config.auth.single_user_allow_multiple_sessions { + // Kick existing sessions before creating a new one. + let revoked_ids = state.sessions.list_ids().await?; + state.sessions.delete_all().await?; + state.remember_revoked_sessions(revoked_ids).await; + } + // Create session let session = state.sessions.create(&user.id).await?; @@ -457,7 +479,6 @@ pub async fn logout( pub struct AuthCheckResponse { pub authenticated: bool, pub user: Option, - pub is_admin: bool, } pub async fn auth_check( @@ -465,15 +486,14 @@ pub async fn auth_check( axum::Extension(session): axum::Extension, ) -> Json { // Get user info from user_id - let (username, is_admin) = match state.users.get(&session.user_id).await { - Ok(Some(user)) => (Some(user.username), user.is_admin), - _ => (Some(session.user_id.clone()), false), // Fallback to user_id if user not found + let username = match state.users.get(&session.user_id).await { + Ok(Some(user)) => Some(user.username), + _ => Some(session.user_id.clone()), // Fallback to user_id if user not found }; Json(AuthCheckResponse { authenticated: true, user: username, - is_admin, }) } @@ -513,11 +533,39 @@ pub struct SetupRequest { pub hid_ch9329_port: Option, pub hid_ch9329_baudrate: Option, pub hid_otg_udc: Option, + pub hid_otg_profile: Option, // Extension settings pub ttyd_enabled: Option, pub rustdesk_enabled: Option, } +fn normalize_otg_profile_for_low_endpoint(config: &mut AppConfig) { + if !matches!(config.hid.backend, crate::config::HidBackend::Otg) { + return; + } + let udc = crate::otg::configfs::resolve_udc_name(config.hid.otg_udc.as_deref()); + let Some(udc) = udc else { + return; + }; + if !crate::otg::configfs::is_low_endpoint_udc(&udc) { + return; + } + match config.hid.otg_profile { + crate::config::OtgHidProfile::Full => { + config.hid.otg_profile = crate::config::OtgHidProfile::FullNoConsumer; + } + crate::config::OtgHidProfile::FullNoMsd => { + config.hid.otg_profile = crate::config::OtgHidProfile::FullNoConsumerNoMsd; + } + crate::config::OtgHidProfile::Custom => { + if config.hid.otg_functions.consumer { + config.hid.otg_functions.consumer = false; + } + } + _ => {} + } +} + pub async fn setup_init( State(state): State>, Json(req): Json, @@ -593,6 +641,33 @@ pub async fn setup_init( if let Some(udc) = req.hid_otg_udc.clone() { config.hid.otg_udc = Some(udc); } + if let Some(profile) = req.hid_otg_profile.clone() { + config.hid.otg_profile = match profile.as_str() { + "full" => crate::config::OtgHidProfile::Full, + "full_no_msd" => crate::config::OtgHidProfile::FullNoMsd, + "full_no_consumer" => crate::config::OtgHidProfile::FullNoConsumer, + "full_no_consumer_no_msd" => crate::config::OtgHidProfile::FullNoConsumerNoMsd, + "legacy_keyboard" => crate::config::OtgHidProfile::LegacyKeyboard, + "legacy_mouse_relative" => crate::config::OtgHidProfile::LegacyMouseRelative, + "custom" => crate::config::OtgHidProfile::Custom, + _ => config.hid.otg_profile.clone(), + }; + if matches!(config.hid.backend, crate::config::HidBackend::Otg) { + match config.hid.otg_profile { + crate::config::OtgHidProfile::Full + | crate::config::OtgHidProfile::FullNoConsumer => { + config.msd.enabled = true; + } + crate::config::OtgHidProfile::FullNoMsd + | crate::config::OtgHidProfile::FullNoConsumerNoMsd + | crate::config::OtgHidProfile::LegacyKeyboard + | crate::config::OtgHidProfile::LegacyMouseRelative => { + config.msd.enabled = false; + } + crate::config::OtgHidProfile::Custom => {} + } + } + } // Extension settings if let Some(enabled) = req.ttyd_enabled { @@ -601,12 +676,32 @@ pub async fn setup_init( if let Some(enabled) = req.rustdesk_enabled { config.rustdesk.enabled = enabled; } + + normalize_otg_profile_for_low_endpoint(config); }) .await?; // Get updated config for HID reload let new_config = state.config.get(); + if matches!(new_config.hid.backend, crate::config::HidBackend::Otg) { + let mut hid_functions = new_config.hid.effective_otg_functions(); + if let Some(udc) = + crate::otg::configfs::resolve_udc_name(new_config.hid.otg_udc.as_deref()) + { + if crate::otg::configfs::is_low_endpoint_udc(&udc) && hid_functions.consumer { + tracing::warn!( + "UDC {} has low endpoint resources, disabling consumer control", + udc + ); + hid_functions.consumer = false; + } + } + if let Err(e) = state.otg_service.update_hid_functions(hid_functions).await { + tracing::warn!("Failed to apply HID functions during setup: {}", e); + } + } + tracing::info!( "Extension config after save: ttyd.enabled={}, rustdesk.enabled={}", new_config.extensions.ttyd.enabled, @@ -719,6 +814,9 @@ pub async fn update_config( let new_config: AppConfig = serde_json::from_value(merged) .map_err(|e| AppError::BadRequest(format!("Invalid config format: {}", e)))?; + let mut new_config = new_config; + normalize_otg_profile_for_low_endpoint(&mut new_config); + // Apply the validated config state.config.set(new_config.clone()).await?; @@ -797,34 +895,57 @@ pub async fn update_config( } tracing::info!("Video config applied successfully"); - // Step 3: Start the streamer to begin capturing frames - // This is necessary because apply_video_config only creates the capturer but doesn't start it - if let Err(e) = state.stream_manager.start().await { - tracing::error!("Failed to start streamer after config change: {}", e); - // Don't fail the request - the stream might start later when client connects - } else { - tracing::info!("Streamer started after config change"); + // Step 3: Start the streamer to begin capturing frames (MJPEG mode only) + if !state.stream_manager.is_webrtc_enabled().await { + // This is necessary because apply_video_config only creates the capturer but doesn't start it + if let Err(e) = state.stream_manager.start().await { + tracing::error!("Failed to start streamer after config change: {}", e); + // Don't fail the request - the stream might start later when client connects + } else { + tracing::info!("Streamer started after config change"); + } } - // Update frame source from the NEW capturer - // This is critical - the old frame_tx is invalid after config change - // New sessions will use this frame_tx when they connect - if let Some(frame_tx) = state.stream_manager.frame_sender().await { - let receiver_count = frame_tx.receiver_count(); - // Use WebRtcStreamer (new unified interface) + // Configure WebRTC direct capture (all modes) + let (device_path, _resolution, _format, _fps, jpeg_quality) = state + .stream_manager + .streamer() + .current_capture_config() + .await; + if let Some(device_path) = device_path { state .stream_manager .webrtc_streamer() - .set_video_source(frame_tx) + .set_capture_device(device_path, jpeg_quality) .await; - tracing::info!( - "WebRTC streamer frame source updated with new capturer (receiver_count={})", - receiver_count - ); } else { - tracing::warn!( - "No frame source available after config change - streamer may not be running" - ); + tracing::warn!("No capture device configured for WebRTC"); + } + + if state.stream_manager.is_webrtc_enabled().await { + use crate::video::encoder::VideoCodecType; + let codec = state + .stream_manager + .webrtc_streamer() + .current_video_codec() + .await; + let codec_str = match codec { + VideoCodecType::H264 => "h264", + VideoCodecType::H265 => "h265", + VideoCodecType::VP8 => "vp8", + VideoCodecType::VP9 => "vp9", + } + .to_string(); + let is_hardware = state + .stream_manager + .webrtc_streamer() + .is_hardware_encoding() + .await; + state.events.publish(SystemEvent::WebRTCReady { + transition_id: None, + codec: codec_str, + hardware: is_hardware, + }); } } @@ -1388,8 +1509,9 @@ pub async fn stream_mode_set( } }; + let no_switch_needed = !tx.accepted && !tx.switching && tx.transition_id.is_none(); Ok(Json(StreamModeResponse { - success: tx.accepted, + success: tx.accepted || no_switch_needed, mode: if tx.accepted { requested_mode_str.to_string() } else { @@ -1935,6 +2057,7 @@ pub async fn webrtc_close_session( #[derive(Serialize)] pub struct IceServersResponse { pub ice_servers: Vec, + pub mdns_mode: String, } #[derive(Serialize)] @@ -1950,6 +2073,7 @@ pub struct IceServerInfo { /// Returns user-configured servers, or Google STUN as fallback if none configured pub async fn webrtc_ice_servers(State(state): State>) -> Json { use crate::webrtc::config::public_ice; + use crate::webrtc::mdns::{mdns_mode, mdns_mode_label}; let config = state.config.get(); let mut ice_servers = Vec::new(); @@ -2005,7 +2129,13 @@ pub async fn webrtc_ice_servers(State(state): State>) -> Json for UserResponse { - fn from(user: crate::auth::User) -> Self { - Self { - id: user.id, - username: user.username, - is_admin: user.is_admin, - created_at: user.created_at.to_rfc3339(), - updated_at: user.updated_at.to_rfc3339(), - } - } -} - -/// List all users (admin only) -pub async fn list_users( - State(state): State>, - Extension(session): Extension, -) -> Result>> { - // Check if current user is admin - let current_user = state - .users - .get(&session.user_id) - .await? - .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; - - if !current_user.is_admin { - return Err(AppError::Forbidden("Admin access required".to_string())); - } - - let users = state.users.list().await?; - let response: Vec = users.into_iter().map(UserResponse::from).collect(); - Ok(Json(response)) -} - -/// Create user request -#[derive(Deserialize)] -pub struct CreateUserRequest { - pub username: String, - pub password: String, - pub is_admin: bool, -} - -/// Create new user (admin only) -pub async fn create_user( - State(state): State>, - Extension(session): Extension, - Json(req): Json, -) -> Result> { - // Check if current user is admin - let current_user = state - .users - .get(&session.user_id) - .await? - .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; - - if !current_user.is_admin { - return Err(AppError::Forbidden("Admin access required".to_string())); - } - - // Validate input - if req.username.len() < 2 { - return Err(AppError::BadRequest( - "Username must be at least 2 characters".to_string(), - )); - } - if req.password.len() < 4 { - return Err(AppError::BadRequest( - "Password must be at least 4 characters".to_string(), - )); - } - - let user = state - .users - .create(&req.username, &req.password, req.is_admin) - .await?; - info!("User created: {} (admin: {})", user.username, user.is_admin); - Ok(Json(UserResponse::from(user))) -} - -/// Update user request -#[derive(Deserialize)] -pub struct UpdateUserRequest { - pub username: Option, - pub is_admin: Option, -} - -/// Update user (admin only) -pub async fn update_user( - State(state): State>, - Extension(session): Extension, - Path(user_id): Path, - Json(req): Json, -) -> Result> { - // Check if current user is admin - let current_user = state - .users - .get(&session.user_id) - .await? - .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; - - if !current_user.is_admin { - return Err(AppError::Forbidden("Admin access required".to_string())); - } - - // Get target user - let mut user = state - .users - .get(&user_id) - .await? - .ok_or_else(|| AppError::NotFound("User not found".to_string()))?; - - // Update fields if provided - if let Some(username) = req.username { - if username.len() < 2 { - return Err(AppError::BadRequest( - "Username must be at least 2 characters".to_string(), - )); - } - user.username = username; - } - if let Some(is_admin) = req.is_admin { - user.is_admin = is_admin; - } - - // Note: We need to add an update method to UserStore - // For now, return error - Err(AppError::Internal( - "User update not yet implemented".to_string(), - )) -} - -/// Delete user (admin only) -pub async fn delete_user( - State(state): State>, - Extension(session): Extension, - Path(user_id): Path, -) -> Result> { - // Check if current user is admin - let current_user = state - .users - .get(&session.user_id) - .await? - .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; - - if !current_user.is_admin { - return Err(AppError::Forbidden("Admin access required".to_string())); - } - - // Prevent deleting self - if user_id == session.user_id { - return Err(AppError::BadRequest( - "Cannot delete your own account".to_string(), - )); - } - - // Check if this is the last admin - let users = state.users.list().await?; - let admin_count = users.iter().filter(|u| u.is_admin).count(); - let target_user = state - .users - .get(&user_id) - .await? - .ok_or_else(|| AppError::NotFound("User not found".to_string()))?; - - if target_user.is_admin && admin_count <= 1 { - return Err(AppError::BadRequest( - "Cannot delete the last admin user".to_string(), - )); - } - - state.users.delete(&user_id).await?; - info!("User deleted: {}", target_user.username); - - Ok(Json(LoginResponse { - success: true, - message: Some("User deleted successfully".to_string()), - })) -} - /// Change password request #[derive(Deserialize)] pub struct ChangePasswordRequest { @@ -2862,54 +2801,39 @@ pub struct ChangePasswordRequest { pub new_password: String, } -/// Change user password -pub async fn change_user_password( +/// Change current user's password +pub async fn change_password( State(state): State>, - Extension(session): Extension, - Path(user_id): Path, + axum::Extension(session): axum::Extension, Json(req): Json, ) -> Result> { - // Check if current user is admin or changing own password let current_user = state .users .get(&session.user_id) .await? .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; - let is_self = user_id == session.user_id; - let is_admin = current_user.is_admin; - - if !is_self && !is_admin { - return Err(AppError::Forbidden( - "Cannot change other user's password".to_string(), - )); - } - - // Validate new password if req.new_password.len() < 4 { return Err(AppError::BadRequest( "Password must be at least 4 characters".to_string(), )); } - // If changing own password, verify current password - if is_self { - let verified = state - .users - .verify(¤t_user.username, &req.current_password) - .await?; - if verified.is_none() { - return Err(AppError::AuthError( - "Current password is incorrect".to_string(), - )); - } + let verified = state + .users + .verify(¤t_user.username, &req.current_password) + .await?; + if verified.is_none() { + return Err(AppError::AuthError( + "Current password is incorrect".to_string(), + )); } state .users - .update_password(&user_id, &req.new_password) + .update_password(&session.user_id, &req.new_password) .await?; - info!("Password changed for user ID: {}", user_id); + info!("Password changed for user ID: {}", session.user_id); Ok(Json(LoginResponse { success: true, @@ -2917,6 +2841,55 @@ pub async fn change_user_password( })) } +/// Change username request +#[derive(Deserialize)] +pub struct ChangeUsernameRequest { + pub username: String, + pub current_password: String, +} + +/// Change current user's username +pub async fn change_username( + State(state): State>, + axum::Extension(session): axum::Extension, + Json(req): Json, +) -> Result> { + let current_user = state + .users + .get(&session.user_id) + .await? + .ok_or_else(|| AppError::AuthError("User not found".to_string()))?; + + if req.username.len() < 2 { + return Err(AppError::BadRequest( + "Username must be at least 2 characters".to_string(), + )); + } + + let verified = state + .users + .verify(¤t_user.username, &req.current_password) + .await?; + if verified.is_none() { + return Err(AppError::AuthError( + "Current password is incorrect".to_string(), + )); + } + + if current_user.username != req.username { + state + .users + .update_username(&session.user_id, &req.username) + .await?; + } + info!("Username changed for user ID: {}", session.user_id); + + Ok(Json(LoginResponse { + success: true, + message: Some("Username changed successfully".to_string()), + })) +} + // ============================================================================ // System Control // ============================================================================ diff --git a/src/web/routes.rs b/src/web/routes.rs index a913c796..02e74bf5 100644 --- a/src/web/routes.rs +++ b/src/web/routes.rs @@ -1,7 +1,7 @@ use axum::{ extract::DefaultBodyLimit, middleware, - routing::{any, delete, get, patch, post, put}, + routing::{any, delete, get, patch, post}, Router, }; use std::sync::Arc; @@ -13,7 +13,7 @@ use tower_http::{ use super::audio_ws::audio_ws_handler; use super::handlers; use super::ws::ws_handler; -use crate::auth::{auth_middleware, require_admin}; +use crate::auth::auth_middleware; use crate::hid::websocket::ws_hid_handler; use crate::state::AppState; @@ -32,11 +32,13 @@ pub fn create_router(state: Arc) -> Router { .route("/setup", get(handlers::setup_status)) .route("/setup/init", post(handlers::setup_init)); - // User routes (authenticated users - both regular and admin) + // Authenticated routes (all logged-in users) let user_routes = Router::new() .route("/info", get(handlers::system_info)) .route("/auth/logout", post(handlers::logout)) .route("/auth/check", get(handlers::auth_check)) + .route("/auth/password", post(handlers::change_password)) + .route("/auth/username", post(handlers::change_username)) .route("/devices", get(handlers::list_devices)) // WebSocket endpoint for real-time events .route("/ws", any(ws_handler)) @@ -69,11 +71,6 @@ pub fn create_router(state: Arc) -> Router { .route("/audio/devices", get(handlers::list_audio_devices)) // Audio WebSocket endpoint .route("/ws/audio", any(audio_ws_handler)) - // User can change their own password (handler will check ownership) - .route("/users/{id}/password", post(handlers::change_user_password)); - - // Admin-only routes (require admin privileges) - let admin_routes = Router::new() // Configuration management (domain-separated endpoints) .route("/config", get(handlers::config::get_all_config)) .route("/config", post(handlers::update_config)) @@ -126,6 +123,9 @@ pub fn create_router(state: Arc) -> Router { // Web server configuration .route("/config/web", get(handlers::config::get_web_config)) .route("/config/web", patch(handlers::config::update_web_config)) + // Auth configuration + .route("/config/auth", get(handlers::config::get_auth_config)) + .route("/config/auth", patch(handlers::config::update_auth_config)) // System control .route("/system/restart", post(handlers::system_restart)) // MSD (Mass Storage Device) endpoints @@ -160,11 +160,6 @@ pub fn create_router(state: Arc) -> Router { .route("/atx/wol", post(handlers::atx_wol)) // Device discovery endpoints .route("/devices/atx", get(handlers::devices::list_atx_devices)) - // User management endpoints - .route("/users", get(handlers::list_users)) - .route("/users", post(handlers::create_user)) - .route("/users/{id}", put(handlers::update_user)) - .route("/users/{id}", delete(handlers::delete_user)) // Extension management endpoints .route("/extensions", get(handlers::extensions::list_extensions)) .route("/extensions/{id}", get(handlers::extensions::get_extension)) @@ -200,12 +195,10 @@ pub fn create_router(state: Arc) -> Router { .route("/terminal", get(handlers::terminal::terminal_index)) .route("/terminal/", get(handlers::terminal::terminal_index)) .route("/terminal/ws", get(handlers::terminal::terminal_ws)) - .route("/terminal/{*path}", get(handlers::terminal::terminal_proxy)) - // Apply admin middleware to all admin routes - .layer(middleware::from_fn_with_state(state.clone(), require_admin)); + .route("/terminal/{*path}", get(handlers::terminal::terminal_proxy)); - // Combine protected routes (user + admin) - let protected_routes = Router::new().merge(user_routes).merge(admin_routes); + // Protected routes (all authenticated users) + let protected_routes = user_routes; // Stream endpoints (accessible with auth, but typically embedded in pages) let stream_routes = Router::new() diff --git a/src/webrtc/mdns.rs b/src/webrtc/mdns.rs new file mode 100644 index 00000000..61e65c2e --- /dev/null +++ b/src/webrtc/mdns.rs @@ -0,0 +1,34 @@ +use webrtc::ice::mdns::MulticastDnsMode; + +pub fn mdns_mode_from_env() -> Option { + let raw = std::env::var("ONE_KVM_WEBRTC_MDNS_MODE").ok()?; + let value = raw.trim().to_ascii_lowercase(); + if value.is_empty() { + return None; + } + + match value.as_str() { + "disabled" | "off" | "false" | "0" => Some(MulticastDnsMode::Disabled), + "query" | "query_only" | "query-only" => Some(MulticastDnsMode::QueryOnly), + "gather" | "query_and_gather" | "query-and-gather" | "on" | "true" | "1" => { + Some(MulticastDnsMode::QueryAndGather) + } + _ => None, + } +} + +pub fn mdns_mode() -> MulticastDnsMode { + mdns_mode_from_env().unwrap_or(MulticastDnsMode::QueryAndGather) +} + +pub fn mdns_mode_label(mode: MulticastDnsMode) -> &'static str { + match mode { + MulticastDnsMode::Disabled => "disabled", + MulticastDnsMode::QueryOnly => "query_only", + MulticastDnsMode::QueryAndGather => "query_and_gather", + } +} + +pub fn default_mdns_host_name(session_id: &str) -> String { + format!("{session_id}.local") +} diff --git a/src/webrtc/mod.rs b/src/webrtc/mod.rs index ac264a7c..7fdf3c72 100644 --- a/src/webrtc/mod.rs +++ b/src/webrtc/mod.rs @@ -27,6 +27,7 @@ pub mod config; pub mod h265_payloader; +pub(crate) mod mdns; pub mod peer; pub mod rtp; pub mod session; @@ -42,7 +43,5 @@ pub use rtp::{H264VideoTrack, H264VideoTrackConfig, OpusAudioTrack}; pub use session::WebRtcSessionManager; pub use signaling::{ConnectionState, IceCandidate, SdpAnswer, SdpOffer, SignalingMessage}; pub use universal_session::{UniversalSession, UniversalSessionConfig, UniversalSessionInfo}; -pub use video_track::{ - UniversalVideoTrack, UniversalVideoTrackConfig, VideoCodec, VideoTrackStats, -}; +pub use video_track::{UniversalVideoTrack, UniversalVideoTrackConfig, VideoCodec}; pub use webrtc_streamer::{SessionInfo, WebRtcStreamer, WebRtcStreamerConfig, WebRtcStreamerStats}; diff --git a/src/webrtc/peer.rs b/src/webrtc/peer.rs index 6705aae1..767a0ba3 100644 --- a/src/webrtc/peer.rs +++ b/src/webrtc/peer.rs @@ -5,9 +5,11 @@ use tokio::sync::{broadcast, watch, Mutex, RwLock}; use tracing::{debug, info}; use webrtc::api::interceptor_registry::register_default_interceptors; use webrtc::api::media_engine::MediaEngine; +use webrtc::api::setting_engine::SettingEngine; use webrtc::api::APIBuilder; use webrtc::data_channel::data_channel_message::DataChannelMessage; use webrtc::data_channel::RTCDataChannel; +use webrtc::ice::mdns::MulticastDnsMode; use webrtc::ice_transport::ice_candidate::RTCIceCandidate; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::interceptor::registry::Registry; @@ -17,6 +19,7 @@ use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; use webrtc::peer_connection::RTCPeerConnection; use super::config::WebRtcConfig; +use super::mdns::{default_mdns_host_name, mdns_mode}; use super::signaling::{ConnectionState, IceCandidate, SdpAnswer, SdpOffer}; use super::track::{VideoTrack, VideoTrackConfig}; use crate::error::{AppError, Result}; @@ -60,8 +63,17 @@ impl PeerConnection { registry = register_default_interceptors(registry, &mut media_engine) .map_err(|e| AppError::VideoError(format!("Failed to register interceptors: {}", e)))?; - // Create API + // Create API (with optional mDNS settings) + let mut setting_engine = SettingEngine::default(); + let mode = mdns_mode(); + setting_engine.set_ice_multicast_dns_mode(mode); + if mode == MulticastDnsMode::QueryAndGather { + setting_engine.set_multicast_dns_host_name(default_mdns_host_name(&session_id)); + } + info!("WebRTC mDNS mode: {:?} (session {})", mode, session_id); + let api = APIBuilder::new() + .with_setting_engine(setting_engine) .with_media_engine(media_engine) .with_interceptor_registry(registry) .build(); @@ -418,7 +430,7 @@ pub struct PeerConnectionManager { impl PeerConnectionManager { /// Create a new peer connection manager pub fn new(config: WebRtcConfig) -> Self { - let (frame_tx, _) = broadcast::channel(16); // Buffer size 16 for low latency + let (frame_tx, _) = broadcast::channel(16); Self { config, @@ -430,7 +442,7 @@ impl PeerConnectionManager { /// Create a new peer connection manager with HID controller pub fn with_hid(config: WebRtcConfig, hid: Arc) -> Self { - let (frame_tx, _) = broadcast::channel(16); // Buffer size 16 for low latency + let (frame_tx, _) = broadcast::channel(16); Self { config, diff --git a/src/webrtc/rtp.rs b/src/webrtc/rtp.rs index 4dbdf26b..e8dac17b 100644 --- a/src/webrtc/rtp.rs +++ b/src/webrtc/rtp.rs @@ -42,8 +42,6 @@ pub struct H264VideoTrack { config: H264VideoTrackConfig, /// H264 payloader for manual packetization (if needed) payloader: Mutex, - /// Statistics - stats: Mutex, /// Cached SPS NAL unit for injection before IDR frames /// Some hardware encoders don't repeat SPS/PPS with every keyframe cached_sps: Mutex>, @@ -83,21 +81,6 @@ impl Default for H264VideoTrackConfig { } } -/// H264 track statistics -#[derive(Debug, Clone, Default)] -pub struct H264TrackStats { - /// Frames sent - pub frames_sent: u64, - /// Bytes sent - pub bytes_sent: u64, - /// Packets sent (RTP packets) - pub packets_sent: u64, - /// Key frames sent - pub keyframes_sent: u64, - /// Errors encountered - pub errors: u64, -} - impl H264VideoTrack { /// Create a new H264 video track /// @@ -134,7 +117,6 @@ impl H264VideoTrack { track, config, payloader: Mutex::new(H264Payloader::default()), - stats: Mutex::new(H264TrackStats::default()), cached_sps: Mutex::new(None), cached_pps: Mutex::new(None), } @@ -150,11 +132,6 @@ impl H264VideoTrack { self.track.clone() } - /// Get current statistics - pub async fn stats(&self) -> H264TrackStats { - self.stats.lock().await.clone() - } - /// Write an H264 encoded frame to the track /// /// The frame data should be H264 Annex B format (with start codes 0x00000001 or 0x000001). @@ -288,16 +265,6 @@ impl H264VideoTrack { nal_count += 1; } - // Update statistics - if nal_count > 0 { - let mut stats = self.stats.lock().await; - stats.frames_sent += 1; - stats.bytes_sent += total_bytes; - if is_keyframe { - stats.keyframes_sent += 1; - } - } - trace!( "Sent frame: {} NAL units, {} bytes, keyframe={}", nal_count, @@ -344,19 +311,6 @@ impl H264VideoTrack { pub struct OpusAudioTrack { /// The underlying WebRTC track track: Arc, - /// Statistics - stats: Mutex, -} - -/// Opus track statistics -#[derive(Debug, Clone, Default)] -pub struct OpusTrackStats { - /// Packets sent - pub packets_sent: u64, - /// Bytes sent - pub bytes_sent: u64, - /// Errors - pub errors: u64, } impl OpusAudioTrack { @@ -378,7 +332,6 @@ impl OpusAudioTrack { Self { track, - stats: Mutex::new(OpusTrackStats::default()), } } @@ -392,11 +345,6 @@ impl OpusAudioTrack { self.track.clone() } - /// Get statistics - pub async fn stats(&self) -> OpusTrackStats { - self.stats.lock().await.clone() - } - /// Write Opus encoded audio data /// /// # Arguments @@ -417,23 +365,13 @@ impl OpusAudioTrack { ..Default::default() }; - match self.track.write_sample(&sample).await { - Ok(_) => { - let mut stats = self.stats.lock().await; - stats.packets_sent += 1; - stats.bytes_sent += data.len() as u64; - Ok(()) - } - Err(e) => { - let mut stats = self.stats.lock().await; - stats.errors += 1; + self.track + .write_sample(&sample) + .await + .map_err(|e| { error!("Failed to write Opus sample: {}", e); - Err(AppError::WebRtcError(format!( - "Failed to write audio sample: {}", - e - ))) - } - } + AppError::WebRtcError(format!("Failed to write audio sample: {}", e)) + }) } } diff --git a/src/webrtc/track.rs b/src/webrtc/track.rs index e144a8be..f9617df2 100644 --- a/src/webrtc/track.rs +++ b/src/webrtc/track.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use std::time::Instant; -use tokio::sync::{broadcast, watch, Mutex}; +use tokio::sync::{broadcast, watch}; use tracing::{debug, error, info}; use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; @@ -87,38 +87,11 @@ pub fn audio_codec_capability() -> RTCRtpCodecCapability { } } -/// Video track statistics -#[derive(Debug, Clone, Default)] -pub struct VideoTrackStats { - /// Frames sent - pub frames_sent: u64, - /// Bytes sent - pub bytes_sent: u64, - /// Packets sent - pub packets_sent: u64, - /// Packets lost (RTCP feedback) - pub packets_lost: u64, - /// Current bitrate (bps) - pub current_bitrate: u64, - /// Round trip time (ms) - pub rtt_ms: f64, - /// Jitter (ms) - pub jitter_ms: f64, -} - /// Video track for WebRTC streaming pub struct VideoTrack { config: VideoTrackConfig, /// RTP track track: Arc, - /// Statistics - stats: Arc>, - /// Sequence number for RTP - sequence_number: Arc>, - /// Timestamp for RTP - timestamp: Arc>, - /// Last frame time - last_frame_time: Arc>>, /// Running flag running: Arc>, } @@ -139,10 +112,6 @@ impl VideoTrack { Self { config, track, - stats: Arc::new(Mutex::new(VideoTrackStats::default())), - sequence_number: Arc::new(Mutex::new(0)), - timestamp: Arc::new(Mutex::new(0)), - last_frame_time: Arc::new(Mutex::new(None)), running: Arc::new(running_tx), } } @@ -152,25 +121,17 @@ impl VideoTrack { self.track.clone() } - /// Get current statistics - pub async fn stats(&self) -> VideoTrackStats { - self.stats.lock().await.clone() - } - /// Start sending frames from a broadcast receiver pub async fn start_sending(&self, mut frame_rx: broadcast::Receiver) { let _ = self.running.send(true); let track = self.track.clone(); - let stats = self.stats.clone(); - let sequence_number = self.sequence_number.clone(); - let timestamp = self.timestamp.clone(); - let last_frame_time = self.last_frame_time.clone(); let clock_rate = self.config.clock_rate; let mut running_rx = self.running.subscribe(); info!("Starting video track sender"); tokio::spawn(async move { + let mut state = SendState::default(); loop { tokio::select! { result = frame_rx.recv() => { @@ -179,10 +140,7 @@ impl VideoTrack { if let Err(e) = Self::send_frame( &track, &frame, - &stats, - &sequence_number, - ×tamp, - &last_frame_time, + &mut state, clock_rate, ).await { debug!("Failed to send frame: {}", e); @@ -219,29 +177,22 @@ impl VideoTrack { async fn send_frame( track: &TrackLocalStaticRTP, frame: &VideoFrame, - stats: &Mutex, - sequence_number: &Mutex, - timestamp: &Mutex, - last_frame_time: &Mutex>, + state: &mut SendState, clock_rate: u32, ) -> Result<(), Box> { // Calculate timestamp increment based on frame timing let now = Instant::now(); - let mut last_time = last_frame_time.lock().await; - let timestamp_increment = if let Some(last) = *last_time { + let timestamp_increment = if let Some(last) = state.last_frame_time { let elapsed = now.duration_since(last); ((elapsed.as_secs_f64() * clock_rate as f64) as u32).min(clock_rate / 10) } else { clock_rate / 30 // Default to 30 fps }; - *last_time = Some(now); - drop(last_time); + state.last_frame_time = Some(now); // Update timestamp - let mut ts = timestamp.lock().await; - *ts = ts.wrapping_add(timestamp_increment); - let _current_ts = *ts; - drop(ts); + state.timestamp = state.timestamp.wrapping_add(timestamp_increment); + let _current_ts = state.timestamp; // For H.264, we need to packetize into RTP // This is a simplified implementation - real implementation needs proper NAL unit handling @@ -257,33 +208,34 @@ impl VideoTrack { let _is_last = i == packet_count - 1; // Get sequence number - let mut seq = sequence_number.lock().await; - let _seq_num = *seq; - *seq = seq.wrapping_add(1); - drop(seq); + let _seq_num = state.sequence_number; + state.sequence_number = state.sequence_number.wrapping_add(1); // Build RTP packet payload // For simplicity, just send raw data - real implementation needs proper RTP packetization - let payload = data[start..end].to_vec(); + let payload = &data[start..end]; bytes_sent += payload.len() as u64; // Write sample (the track handles RTP header construction) - if let Err(e) = track.write(&payload).await { + if let Err(e) = track.write(payload).await { error!("Failed to write RTP packet: {}", e); return Err(e.into()); } } - // Update stats - let mut s = stats.lock().await; - s.frames_sent += 1; - s.bytes_sent += bytes_sent; - s.packets_sent += packet_count as u64; + let _ = bytes_sent; Ok(()) } } +#[derive(Debug, Default)] +struct SendState { + sequence_number: u16, + timestamp: u32, + last_frame_time: Option, +} + /// Audio track configuration #[derive(Debug, Clone)] pub struct AudioTrackConfig { diff --git a/src/webrtc/unified_video_track.rs b/src/webrtc/unified_video_track.rs index 3cf6e6df..d288f0e0 100644 --- a/src/webrtc/unified_video_track.rs +++ b/src/webrtc/unified_video_track.rs @@ -123,15 +123,6 @@ impl Default for UnifiedVideoTrackConfig { } } -/// Unified video track statistics -#[derive(Debug, Clone, Default)] -pub struct UnifiedVideoTrackStats { - pub frames_sent: u64, - pub bytes_sent: u64, - pub keyframes_sent: u64, - pub errors: u64, -} - /// Cached NAL parameter sets for H264 struct H264ParameterSets { sps: Option, @@ -179,8 +170,6 @@ pub struct UnifiedVideoTrack { track: Arc, /// Track configuration config: UnifiedVideoTrackConfig, - /// Statistics - stats: Mutex, /// H264 parameter set cache h264_params: Mutex, /// H265 parameter set cache @@ -207,7 +196,6 @@ impl UnifiedVideoTrack { Self { track, config, - stats: Mutex::new(UnifiedVideoTrackStats::default()), h264_params: Mutex::new(H264ParameterSets { sps: None, pps: None }), h265_params: Mutex::new(H265ParameterSets { vps: None, sps: None, pps: None }), } @@ -277,9 +265,6 @@ impl UnifiedVideoTrack { } /// Get statistics - pub async fn stats(&self) -> UnifiedVideoTrackStats { - self.stats.lock().await.clone() - } /// Write an encoded frame to the track /// @@ -504,13 +489,6 @@ impl UnifiedVideoTrack { debug!("VP8 write_sample failed: {}", e); } - let mut stats = self.stats.lock().await; - stats.frames_sent += 1; - stats.bytes_sent += data.len() as u64; - if is_keyframe { - stats.keyframes_sent += 1; - } - trace!("VP8 frame: {} bytes, keyframe={}", data.len(), is_keyframe); Ok(()) } @@ -531,13 +509,6 @@ impl UnifiedVideoTrack { debug!("VP9 write_sample failed: {}", e); } - let mut stats = self.stats.lock().await; - stats.frames_sent += 1; - stats.bytes_sent += data.len() as u64; - if is_keyframe { - stats.keyframes_sent += 1; - } - trace!("VP9 frame: {} bytes, keyframe={}", data.len(), is_keyframe); Ok(()) } @@ -572,15 +543,6 @@ impl UnifiedVideoTrack { total_bytes += nal_data.len() as u64; } - if nal_count > 0 { - let mut stats = self.stats.lock().await; - stats.frames_sent += 1; - stats.bytes_sent += total_bytes; - if is_keyframe { - stats.keyframes_sent += 1; - } - } - trace!("Sent {} NAL units, {} bytes, keyframe={}", nal_count, total_bytes, is_keyframe); Ok(()) } diff --git a/src/webrtc/universal_session.rs b/src/webrtc/universal_session.rs index 81f7e34c..b62bc89f 100644 --- a/src/webrtc/universal_session.rs +++ b/src/webrtc/universal_session.rs @@ -4,13 +4,16 @@ //! Replaces the H264-only H264Session with a more flexible implementation. use std::sync::Arc; -use tokio::sync::{broadcast, watch, Mutex, RwLock}; -use tracing::{debug, info, trace, warn}; +use std::time::{Duration, Instant}; +use tokio::sync::{watch, Mutex, RwLock}; +use tracing::{debug, info, warn}; use webrtc::api::interceptor_registry::register_default_interceptors; use webrtc::api::media_engine::MediaEngine; +use webrtc::api::setting_engine::SettingEngine; use webrtc::api::APIBuilder; use webrtc::data_channel::data_channel_message::DataChannelMessage; use webrtc::data_channel::RTCDataChannel; +use webrtc::ice::mdns::MulticastDnsMode; use webrtc::ice_transport::ice_candidate::RTCIceCandidate; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::interceptor::registry::Registry; @@ -24,17 +27,21 @@ use webrtc::rtp_transceiver::rtp_codec::{ use webrtc::rtp_transceiver::RTCPFeedback; use super::config::WebRtcConfig; +use super::mdns::{default_mdns_host_name, mdns_mode}; use super::rtp::OpusAudioTrack; use super::signaling::{ConnectionState, IceCandidate, SdpAnswer, SdpOffer}; use super::video_track::{UniversalVideoTrack, UniversalVideoTrackConfig, VideoCodec}; use crate::audio::OpusFrame; use crate::error::{AppError, Result}; +use crate::events::{EventBus, SystemEvent}; use crate::hid::datachannel::{parse_hid_message, HidChannelEvent}; use crate::hid::HidController; use crate::video::encoder::registry::VideoEncoderType; use crate::video::encoder::BitratePreset; use crate::video::format::{PixelFormat, Resolution}; use crate::video::shared_video_pipeline::EncodedVideoFrame; +use std::sync::atomic::AtomicBool; +use webrtc::ice_transport::ice_gatherer_state::RTCIceGathererState; /// H.265/HEVC MIME type (RFC 7798) const MIME_TYPE_H265: &str = "video/H265"; @@ -117,6 +124,8 @@ pub struct UniversalSession { ice_candidates: Arc>>, /// HID controller reference hid_controller: Option>, + /// Event bus for WebRTC signaling events (optional) + event_bus: Option>, /// Video frame receiver handle video_receiver_handle: Mutex>>, /// Audio frame receiver handle @@ -127,7 +136,11 @@ pub struct UniversalSession { impl UniversalSession { /// Create a new universal WebRTC session - pub async fn new(config: UniversalSessionConfig, session_id: String) -> Result { + pub async fn new( + config: UniversalSessionConfig, + session_id: String, + event_bus: Option>, + ) -> Result { info!( "Creating {} session: {} @ {}x{} (audio={})", config.codec, @@ -243,8 +256,17 @@ impl UniversalSession { registry = register_default_interceptors(registry, &mut media_engine) .map_err(|e| AppError::VideoError(format!("Failed to register interceptors: {}", e)))?; - // Create API + // Create API (with optional mDNS settings) + let mut setting_engine = SettingEngine::default(); + let mode = mdns_mode(); + setting_engine.set_ice_multicast_dns_mode(mode); + if mode == MulticastDnsMode::QueryAndGather { + setting_engine.set_multicast_dns_host_name(default_mdns_host_name(&session_id)); + } + info!("WebRTC mDNS mode: {:?} (session {})", mode, session_id); + let api = APIBuilder::new() + .with_setting_engine(setting_engine) .with_media_engine(media_engine) .with_interceptor_registry(registry) .build(); @@ -321,6 +343,7 @@ impl UniversalSession { state_rx, ice_candidates: Arc::new(Mutex::new(vec![])), hid_controller: None, + event_bus, video_receiver_handle: Mutex::new(None), audio_receiver_handle: Mutex::new(None), fps: config.fps, @@ -337,6 +360,7 @@ impl UniversalSession { let state = self.state.clone(); let session_id = self.session_id.clone(); let codec = self.codec; + let event_bus = self.event_bus.clone(); // Connection state change handler self.pc @@ -372,32 +396,56 @@ impl UniversalSession { // ICE gathering state handler let session_id_gather = self.session_id.clone(); + let event_bus_gather = event_bus.clone(); self.pc .on_ice_gathering_state_change(Box::new(move |state| { let session_id = session_id_gather.clone(); + let event_bus = event_bus_gather.clone(); Box::pin(async move { - debug!("[ICE] Session {} gathering state: {:?}", session_id, state); + if matches!(state, RTCIceGathererState::Complete) { + if let Some(bus) = event_bus.as_ref() { + bus.publish(SystemEvent::WebRTCIceComplete { session_id }); + } + } }) })); // ICE candidate handler let ice_candidates = self.ice_candidates.clone(); + let session_id_candidate = self.session_id.clone(); + let event_bus_candidate = event_bus.clone(); self.pc .on_ice_candidate(Box::new(move |candidate: Option| { let ice_candidates = ice_candidates.clone(); + let session_id = session_id_candidate.clone(); + let event_bus = event_bus_candidate.clone(); Box::pin(async move { if let Some(c) = candidate { - let candidate_str = c.to_json().map(|j| j.candidate).unwrap_or_default(); - debug!("ICE candidate: {}", candidate_str); + let candidate_json = c.to_json().ok(); + let candidate_str = candidate_json + .as_ref() + .map(|j| j.candidate.clone()) + .unwrap_or_default(); + let candidate = IceCandidate { + candidate: candidate_str, + sdp_mid: candidate_json.as_ref().and_then(|j| j.sdp_mid.clone()), + sdp_mline_index: candidate_json.as_ref().and_then(|j| j.sdp_mline_index), + username_fragment: candidate_json + .as_ref() + .and_then(|j| j.username_fragment.clone()), + }; let mut candidates = ice_candidates.lock().await; - candidates.push(IceCandidate { - candidate: candidate_str, - sdp_mid: c.to_json().ok().and_then(|j| j.sdp_mid), - sdp_mline_index: c.to_json().ok().and_then(|j| j.sdp_mline_index), - username_fragment: None, - }); + candidates.push(candidate.clone()); + drop(candidates); + + if let Some(bus) = event_bus.as_ref() { + bus.publish(SystemEvent::WebRTCIceCandidate { + session_id, + candidate, + }); + } } }) })); @@ -488,13 +536,11 @@ impl UniversalSession { /// /// The `on_connected` callback is called when ICE connection is established, /// allowing the caller to request a keyframe at the right time. - pub async fn start_from_video_pipeline( + pub async fn start_from_video_pipeline( &self, - mut frame_rx: broadcast::Receiver, - on_connected: F, - ) where - F: FnOnce() + Send + 'static, - { + mut frame_rx: tokio::sync::mpsc::Receiver>, + request_keyframe: Arc, + ) { info!( "Starting {} session {} with shared encoder", self.codec, self.session_id @@ -505,6 +551,7 @@ impl UniversalSession { let session_id = self.session_id.clone(); let _fps = self.fps; let expected_codec = self.codec; + let send_in_flight = Arc::new(AtomicBool::new(false)); let handle = tokio::spawn(async move { info!( @@ -536,7 +583,10 @@ impl UniversalSession { ); // Request keyframe now that connection is established - on_connected(); + request_keyframe(); + let mut waiting_for_keyframe = true; + let mut last_sequence: Option = None; + let mut last_keyframe_request = Instant::now() - Duration::from_secs(1); let mut frames_sent: u64 = 0; @@ -556,64 +606,81 @@ impl UniversalSession { } result = frame_rx.recv() => { - match result { - Ok(encoded_frame) => { - // Verify codec matches - let frame_codec = match encoded_frame.codec { - VideoEncoderType::H264 => VideoEncoderType::H264, - VideoEncoderType::H265 => VideoEncoderType::H265, - VideoEncoderType::VP8 => VideoEncoderType::VP8, - VideoEncoderType::VP9 => VideoEncoderType::VP9, - }; - - if frame_codec != expected_codec { - trace!("Skipping frame with codec {:?}, expected {:?}", frame_codec, expected_codec); - continue; - } - - // Debug log for H265 frames - if expected_codec == VideoEncoderType::H265 { - if encoded_frame.is_keyframe || frames_sent % 30 == 0 { - debug!( - "[Session-H265] Received frame #{}: size={}, keyframe={}, seq={}", - frames_sent, - encoded_frame.data.len(), - encoded_frame.is_keyframe, - encoded_frame.sequence - ); - } - } - - // Send encoded frame via RTP - if let Err(e) = video_track - .write_frame_bytes( - encoded_frame.data.clone(), - encoded_frame.is_keyframe, - ) - .await - { - if frames_sent % 100 == 0 { - debug!("Failed to write frame to track: {}", e); - } - } else { - frames_sent += 1; - - // Log successful H265 frame send - if expected_codec == VideoEncoderType::H265 && (encoded_frame.is_keyframe || frames_sent % 30 == 0) { - debug!( - "[Session-H265] Frame #{} sent successfully", - frames_sent - ); - } - } - } - Err(broadcast::error::RecvError::Lagged(n)) => { - debug!("Session {} lagged by {} frames", session_id, n); - } - Err(broadcast::error::RecvError::Closed) => { + let encoded_frame = match result { + Some(frame) => frame, + None => { info!("Video frame channel closed for session {}", session_id); break; } + }; + + // Verify codec matches + let frame_codec = match encoded_frame.codec { + VideoEncoderType::H264 => VideoEncoderType::H264, + VideoEncoderType::H265 => VideoEncoderType::H265, + VideoEncoderType::VP8 => VideoEncoderType::VP8, + VideoEncoderType::VP9 => VideoEncoderType::VP9, + }; + + if frame_codec != expected_codec { + continue; + } + + // Debug log for H265 frames + if expected_codec == VideoEncoderType::H265 { + if encoded_frame.is_keyframe || frames_sent % 30 == 0 { + debug!( + "[Session-H265] Received frame #{}: size={}, keyframe={}, seq={}", + frames_sent, + encoded_frame.data.len(), + encoded_frame.is_keyframe, + encoded_frame.sequence + ); + } + } + + // Ensure decoder starts from a keyframe and recover on gaps. + let mut gap_detected = false; + if let Some(prev) = last_sequence { + if encoded_frame.sequence > prev.saturating_add(1) { + gap_detected = true; + } + } + + if waiting_for_keyframe || gap_detected { + if encoded_frame.is_keyframe { + waiting_for_keyframe = false; + } else { + if gap_detected { + waiting_for_keyframe = true; + } + let now = Instant::now(); + if now.duration_since(last_keyframe_request) + >= Duration::from_millis(200) + { + request_keyframe(); + last_keyframe_request = now; + } + continue; + } + } + + let _ = send_in_flight; + + // Send encoded frame via RTP (drop if previous send is still in flight) + let send_result = video_track + .write_frame_bytes( + encoded_frame.data.clone(), + encoded_frame.is_keyframe, + ) + .await; + let _ = send_in_flight; + + if send_result.is_err() { + // Keep quiet unless debugging send failures elsewhere + } else { + frames_sent += 1; + last_sequence = Some(encoded_frame.sequence); } } } @@ -629,7 +696,10 @@ impl UniversalSession { } /// Start receiving Opus audio frames - pub async fn start_audio_from_opus(&self, mut opus_rx: broadcast::Receiver) { + pub async fn start_audio_from_opus( + &self, + mut opus_rx: tokio::sync::watch::Receiver>>, + ) { let audio_track = match &self.audio_track { Some(track) => track.clone(), None => { @@ -684,26 +754,25 @@ impl UniversalSession { } } - result = opus_rx.recv() => { - match result { - Ok(opus_frame) => { - // 20ms at 48kHz = 960 samples - let samples = 960u32; - if let Err(e) = audio_track.write_packet(&opus_frame.data, samples).await { - if packets_sent % 100 == 0 { - debug!("Failed to write audio packet: {}", e); - } - } else { - packets_sent += 1; - } - } - Err(broadcast::error::RecvError::Lagged(n)) => { - warn!("Session {} audio lagged by {} packets", session_id, n); - } - Err(broadcast::error::RecvError::Closed) => { - info!("Opus channel closed for session {}", session_id); - break; + result = opus_rx.changed() => { + if result.is_err() { + info!("Opus channel closed for session {}", session_id); + break; + } + + let opus_frame = match opus_rx.borrow().clone() { + Some(frame) => frame, + None => continue, + }; + + // 20ms at 48kHz = 960 samples + let samples = 960u32; + if let Err(e) = audio_track.write_packet(&opus_frame.data, samples).await { + if packets_sent % 100 == 0 { + debug!("Failed to write audio packet: {}", e); } + } else { + packets_sent += 1; } } } diff --git a/src/webrtc/video_track.rs b/src/webrtc/video_track.rs index ef4614cd..7ad4c99c 100644 --- a/src/webrtc/video_track.rs +++ b/src/webrtc/video_track.rs @@ -186,19 +186,6 @@ impl UniversalVideoTrackConfig { } } -/// Track statistics -#[derive(Debug, Clone, Default)] -pub struct VideoTrackStats { - /// Frames sent - pub frames_sent: u64, - /// Bytes sent - pub bytes_sent: u64, - /// Keyframes sent - pub keyframes_sent: u64, - /// Errors - pub errors: u64, -} - /// Track type wrapper to support different underlying track implementations enum TrackType { /// Sample-based track with built-in payloader (H264, VP8, VP9) @@ -227,8 +214,6 @@ pub struct UniversalVideoTrack { codec: VideoCodec, /// Configuration config: UniversalVideoTrackConfig, - /// Statistics - stats: Mutex, /// H265 RTP state (only used for H265) h265_state: Option>, } @@ -277,7 +262,6 @@ impl UniversalVideoTrack { track, codec: config.codec, config, - stats: Mutex::new(VideoTrackStats::default()), h265_state, } } @@ -301,9 +285,6 @@ impl UniversalVideoTrack { } /// Get current statistics - pub async fn stats(&self) -> VideoTrackStats { - self.stats.lock().await.clone() - } /// Write an encoded frame to the track /// @@ -332,7 +313,7 @@ impl UniversalVideoTrack { /// /// Sends the entire Annex B frame as a single Sample to allow the /// H264Payloader to aggregate SPS+PPS into STAP-A packets. - async fn write_h264_frame(&self, data: Bytes, is_keyframe: bool) -> Result<()> { + async fn write_h264_frame(&self, data: Bytes, _is_keyframe: bool) -> Result<()> { // Send entire Annex B frame as one Sample // The H264Payloader in rtp crate will: // 1. Parse NAL units from Annex B format @@ -340,7 +321,6 @@ impl UniversalVideoTrack { // 3. Aggregate SPS+PPS+IDR into STAP-A when possible // 4. Fragment large NALs using FU-A let frame_duration = Duration::from_micros(1_000_000 / self.config.fps.max(1) as u64); - let data_len = data.len(); let sample = Sample { data, duration: frame_duration, @@ -358,14 +338,6 @@ impl UniversalVideoTrack { } } - // Update stats - let mut stats = self.stats.lock().await; - stats.frames_sent += 1; - stats.bytes_sent += data_len as u64; - if is_keyframe { - stats.keyframes_sent += 1; - } - Ok(()) } @@ -379,11 +351,10 @@ impl UniversalVideoTrack { } /// Write VP8 frame - async fn write_vp8_frame(&self, data: Bytes, is_keyframe: bool) -> Result<()> { + async fn write_vp8_frame(&self, data: Bytes, _is_keyframe: bool) -> Result<()> { // VP8 frames are sent directly without NAL parsing // Calculate frame duration based on configured FPS let frame_duration = Duration::from_micros(1_000_000 / self.config.fps.max(1) as u64); - let data_len = data.len(); let sample = Sample { data, duration: frame_duration, @@ -401,23 +372,14 @@ impl UniversalVideoTrack { } } - // Update stats - let mut stats = self.stats.lock().await; - stats.frames_sent += 1; - stats.bytes_sent += data_len as u64; - if is_keyframe { - stats.keyframes_sent += 1; - } - Ok(()) } /// Write VP9 frame - async fn write_vp9_frame(&self, data: Bytes, is_keyframe: bool) -> Result<()> { + async fn write_vp9_frame(&self, data: Bytes, _is_keyframe: bool) -> Result<()> { // VP9 frames are sent directly without NAL parsing // Calculate frame duration based on configured FPS let frame_duration = Duration::from_micros(1_000_000 / self.config.fps.max(1) as u64); - let data_len = data.len(); let sample = Sample { data, duration: frame_duration, @@ -435,19 +397,11 @@ impl UniversalVideoTrack { } } - // Update stats - let mut stats = self.stats.lock().await; - stats.frames_sent += 1; - stats.bytes_sent += data_len as u64; - if is_keyframe { - stats.keyframes_sent += 1; - } - Ok(()) } /// Send H265 NAL units via custom H265Payloader - async fn send_h265_rtp(&self, payload: Bytes, is_keyframe: bool) -> Result<()> { + async fn send_h265_rtp(&self, payload: Bytes, _is_keyframe: bool) -> Result<()> { let rtp_track = match &self.track { TrackType::Rtp(t) => t, TrackType::Sample(_) => { @@ -486,8 +440,6 @@ impl UniversalVideoTrack { (payloads, timestamp, seq_start, num_payloads) }; // Lock released here, before network I/O - let mut total_bytes = 0u64; - // Send RTP packets without holding the lock for (i, payload_data) in payloads.into_iter().enumerate() { let seq = seq_start.wrapping_add(i as u16); @@ -513,15 +465,6 @@ impl UniversalVideoTrack { trace!("H265 write_rtp failed: {}", e); } - total_bytes += payload_data.len() as u64; - } - - // Update stats - let mut stats = self.stats.lock().await; - stats.frames_sent += 1; - stats.bytes_sent += total_bytes; - if is_keyframe { - stats.keyframes_sent += 1; } Ok(()) diff --git a/src/webrtc/webrtc_streamer.rs b/src/webrtc/webrtc_streamer.rs index 510540fe..44ed2b13 100644 --- a/src/webrtc/webrtc_streamer.rs +++ b/src/webrtc/webrtc_streamer.rs @@ -15,10 +15,6 @@ //! | +-- VP8 Encoder (hardware only - VAAPI) //! | +-- VP9 Encoder (hardware only - VAAPI) //! | -//! +-- Audio Pipeline -//! | +-- SharedAudioPipeline -//! | +-- OpusEncoder -//! | //! +-- UniversalSession[] (video + audio tracks + DataChannel) //! +-- UniversalVideoTrack (H264/H265/VP8/VP9) //! +-- Audio Track (RTP/Opus) @@ -29,23 +25,23 @@ //! //! - **Single encoder**: All sessions share one video encoder //! - **Multi-codec support**: H264, H265, VP8, VP9 -//! - **Audio support**: Opus audio streaming via SharedAudioPipeline +//! - **Audio support**: Opus audio streaming via AudioController //! - **HID via DataChannel**: Keyboard/mouse events through WebRTC DataChannel use std::collections::HashMap; +use std::path::PathBuf; use std::sync::Arc; -use tokio::sync::{broadcast, RwLock}; -use tracing::{debug, error, info, trace, warn}; +use tokio::sync::RwLock; +use tracing::{debug, info, trace, warn}; -use crate::audio::shared_pipeline::{SharedAudioPipeline, SharedAudioPipelineConfig}; use crate::audio::{AudioController, OpusFrame}; +use crate::events::EventBus; use crate::error::{AppError, Result}; use crate::hid::HidController; use crate::video::encoder::registry::EncoderBackend; use crate::video::encoder::registry::VideoEncoderType; use crate::video::encoder::VideoCodecType; use crate::video::format::{PixelFormat, Resolution}; -use crate::video::frame::VideoFrame; use crate::video::shared_video_pipeline::{ SharedVideoPipeline, SharedVideoPipelineConfig, SharedVideoPipelineStats, }; @@ -91,6 +87,14 @@ impl Default for WebRtcStreamerConfig { } } +/// Capture device configuration for direct capture pipeline +#[derive(Debug, Clone)] +pub struct CaptureDeviceConfig { + pub device_path: PathBuf, + pub buffer_count: u32, + pub jpeg_quality: u8, +} + /// WebRTC streamer statistics #[derive(Debug, Clone, Default)] pub struct WebRtcStreamerStats { @@ -102,30 +106,12 @@ pub struct WebRtcStreamerStats { pub video_pipeline: Option, /// Audio enabled pub audio_enabled: bool, - /// Audio pipeline stats (if available) - pub audio_pipeline: Option, } /// Video pipeline statistics #[derive(Debug, Clone, Default)] pub struct VideoPipelineStats { - pub frames_encoded: u64, - pub frames_dropped: u64, - pub bytes_encoded: u64, - pub keyframes_encoded: u64, - pub avg_encode_time_ms: f32, pub current_fps: f32, - pub subscribers: u64, -} - -/// Audio pipeline statistics -#[derive(Debug, Clone, Default)] -pub struct AudioPipelineStats { - pub frames_encoded: u64, - pub frames_dropped: u64, - pub bytes_encoded: u64, - pub avg_encode_time_ms: f32, - pub subscribers: u64, } /// Session info for listing @@ -151,20 +137,21 @@ pub struct WebRtcStreamer { video_pipeline: RwLock>>, /// All sessions (unified management) sessions: Arc>>>, - /// Video frame source - video_frame_tx: RwLock>>, + /// Capture device configuration for direct capture mode + capture_device: RwLock>, // === Audio === /// Audio enabled flag audio_enabled: RwLock, - /// Shared audio pipeline for Opus encoding - audio_pipeline: RwLock>>, /// Audio controller reference audio_controller: RwLock>>, // === Controllers === /// HID controller for DataChannel hid_controller: RwLock>>, + + /// Event bus for WebRTC signaling (optional) + events: RwLock>>, } impl WebRtcStreamer { @@ -180,11 +167,11 @@ impl WebRtcStreamer { video_codec: RwLock::new(config.video_codec), video_pipeline: RwLock::new(None), sessions: Arc::new(RwLock::new(HashMap::new())), - video_frame_tx: RwLock::new(None), + capture_device: RwLock::new(None), audio_enabled: RwLock::new(config.audio_enabled), - audio_pipeline: RwLock::new(None), audio_controller: RwLock::new(None), hid_controller: RwLock::new(None), + events: RwLock::new(None), }) } @@ -219,9 +206,10 @@ impl WebRtcStreamer { // Update codec *self.video_codec.write().await = codec; - // Create new pipeline with new codec - if let Some(ref tx) = *self.video_frame_tx.read().await { - self.ensure_video_pipeline(tx.clone()).await?; + // Create new pipeline with new codec if capture source is configured + let has_capture = self.capture_device.read().await.is_some(); + if has_capture { + self.ensure_video_pipeline().await?; } info!("Video codec switched to {:?}", codec); @@ -263,10 +251,7 @@ impl WebRtcStreamer { } /// Ensure video pipeline is initialized and running - async fn ensure_video_pipeline( - self: &Arc, - tx: broadcast::Sender, - ) -> Result> { + async fn ensure_video_pipeline(self: &Arc) -> Result> { let mut pipeline_guard = self.video_pipeline.write().await; if let Some(ref pipeline) = *pipeline_guard { @@ -290,7 +275,16 @@ impl WebRtcStreamer { info!("Creating shared video pipeline for {:?}", codec); let pipeline = SharedVideoPipeline::new(pipeline_config)?; - pipeline.start(tx.subscribe()).await?; + let capture_device = self.capture_device.read().await.clone(); + if let Some(device) = capture_device { + pipeline + .start_with_device(device.device_path, device.buffer_count, device.jpeg_quality) + .await?; + } else { + return Err(AppError::VideoError( + "No capture device configured".to_string(), + )); + } // Start a monitor task to detect when pipeline auto-stops let pipeline_weak = Arc::downgrade(&pipeline); @@ -317,11 +311,7 @@ impl WebRtcStreamer { } drop(pipeline_guard); - // NOTE: Don't clear video_frame_tx here! - // The frame source is managed by stream_manager and should - // remain available for new sessions. Only stream_manager - // should clear it during mode switches. - info!("Video pipeline stopped, but keeping frame source for new sessions"); + info!("Video pipeline stopped, but keeping capture config for new sessions"); } break; } @@ -339,9 +329,8 @@ impl WebRtcStreamer { /// components (like RustDesk) that need to share the encoded video stream. pub async fn ensure_video_pipeline_for_external( self: &Arc, - tx: broadcast::Sender, ) -> Result> { - self.ensure_video_pipeline(tx).await + self.ensure_video_pipeline().await } /// Get the current pipeline configuration (if pipeline is running) @@ -353,6 +342,18 @@ impl WebRtcStreamer { } } + /// Request the encoder to generate a keyframe on next encode + pub async fn request_keyframe(&self) -> Result<()> { + if let Some(ref pipeline) = *self.video_pipeline.read().await { + pipeline.request_keyframe().await; + Ok(()) + } else { + Err(AppError::VideoError( + "Video pipeline not running".to_string(), + )) + } + } + // === Audio Management === /// Check if audio is enabled @@ -367,13 +368,10 @@ impl WebRtcStreamer { self.config.write().await.audio_enabled = enabled; if enabled && !was_enabled { - // Start audio pipeline if we have an audio controller - if let Some(ref controller) = *self.audio_controller.read().await { - self.start_audio_pipeline(controller.clone()).await?; + // Reconnect audio for existing sessions if we have a controller + if let Some(ref _controller) = *self.audio_controller.read().await { + self.reconnect_audio_sources().await; } - } else if !enabled && was_enabled { - // Stop audio pipeline - self.stop_audio_pipeline().await; } info!("WebRTC audio enabled: {}", enabled); @@ -385,61 +383,16 @@ impl WebRtcStreamer { info!("Setting audio controller for WebRTC streamer"); *self.audio_controller.write().await = Some(controller.clone()); - // Start audio pipeline if audio is enabled + // Reconnect audio for existing sessions if audio is enabled if *self.audio_enabled.read().await { - if let Err(e) = self.start_audio_pipeline(controller).await { - error!("Failed to start audio pipeline: {}", e); - } + self.reconnect_audio_sources().await; } } - /// Start the shared audio pipeline - async fn start_audio_pipeline(&self, controller: Arc) -> Result<()> { - // Check if already running - if let Some(ref pipeline) = *self.audio_pipeline.read().await { - if pipeline.is_running() { - debug!("Audio pipeline already running"); - return Ok(()); - } - } - - // Get Opus frame receiver from audio controller - let _opus_rx = match controller.subscribe_opus_async().await { - Some(rx) => rx, - None => { - warn!("Audio controller not streaming, cannot start audio pipeline"); - return Ok(()); - } - }; - - // Create shared audio pipeline config - let config = SharedAudioPipelineConfig::default(); - let pipeline = SharedAudioPipeline::new(config)?; - - // Note: SharedAudioPipeline expects raw AudioFrame, but AudioController - // already provides encoded OpusFrame. We'll pass the OpusFrame directly - // to sessions instead of re-encoding. - // For now, store the pipeline reference for future use - *self.audio_pipeline.write().await = Some(pipeline); - - // Reconnect audio for all existing sessions - self.reconnect_audio_sources().await; - - info!("WebRTC audio pipeline started"); - Ok(()) - } - - /// Stop the shared audio pipeline - async fn stop_audio_pipeline(&self) { - if let Some(ref pipeline) = *self.audio_pipeline.read().await { - pipeline.stop(); - } - *self.audio_pipeline.write().await = None; - info!("WebRTC audio pipeline stopped"); - } - /// Subscribe to encoded Opus frames (for sessions) - pub async fn subscribe_opus(&self) -> Option> { + pub async fn subscribe_opus( + &self, + ) -> Option>>> { if let Some(ref controller) = *self.audio_controller.read().await { controller.subscribe_opus_async().await } else { @@ -463,38 +416,22 @@ impl WebRtcStreamer { } } - // === Video Frame Source === - - /// Set video frame source - pub async fn set_video_source(&self, tx: broadcast::Sender) { + /// Set capture device for direct capture pipeline + pub async fn set_capture_device(&self, device_path: PathBuf, jpeg_quality: u8) { info!( - "Setting video source for WebRTC streamer (receiver_count={})", - tx.receiver_count() + "Setting direct capture device for WebRTC: {:?}", + device_path ); - *self.video_frame_tx.write().await = Some(tx.clone()); + *self.capture_device.write().await = Some(CaptureDeviceConfig { + device_path, + buffer_count: 2, + jpeg_quality, + }); + } - // Start or restart pipeline if it exists - if let Some(ref pipeline) = *self.video_pipeline.read().await { - if !pipeline.is_running() { - info!("Starting video pipeline with new frame source"); - if let Err(e) = pipeline.start(tx.subscribe()).await { - error!("Failed to start video pipeline: {}", e); - } - } else { - // Pipeline is already running but may have old frame source - // We need to restart it with the new frame source - info!("Video pipeline already running, restarting with new frame source"); - pipeline.stop(); - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; - if let Err(e) = pipeline.start(tx.subscribe()).await { - error!("Failed to restart video pipeline: {}", e); - } - } - } else { - info!( - "No video pipeline exists yet, frame source will be used when pipeline is created" - ); - } + /// Clear direct capture device configuration + pub async fn clear_capture_device(&self) { + *self.capture_device.write().await = None; } /// Prepare for configuration change @@ -509,11 +446,6 @@ impl WebRtcStreamer { self.close_all_sessions().await; } - /// Reconnect video source after configuration change - pub async fn reconnect_video_source(&self, tx: broadcast::Sender) { - self.set_video_source(tx).await; - } - // === Configuration === /// Update video configuration @@ -690,6 +622,11 @@ impl WebRtcStreamer { *self.hid_controller.write().await = Some(hid); } + /// Set event bus for WebRTC signaling events + pub async fn set_event_bus(&self, events: Arc) { + *self.events.write().await = Some(events); + } + // === Session Management === /// Create a new WebRTC session @@ -698,13 +635,7 @@ impl WebRtcStreamer { let codec = *self.video_codec.read().await; // Ensure video pipeline is running - let frame_tx = self - .video_frame_tx - .read() - .await - .clone() - .ok_or_else(|| AppError::VideoError("No video frame source".to_string()))?; - let pipeline = self.ensure_video_pipeline(frame_tx).await?; + let pipeline = self.ensure_video_pipeline().await?; // Create session config let config = self.config.read().await; @@ -720,7 +651,9 @@ impl WebRtcStreamer { drop(config); // Create universal session - let mut session = UniversalSession::new(session_config.clone(), session_id.clone()).await?; + let event_bus = self.events.read().await.clone(); + let mut session = + UniversalSession::new(session_config.clone(), session_id.clone(), event_bus).await?; // Set HID controller if available // Note: We DON'T create a data channel here - the frontend creates it. @@ -734,22 +667,22 @@ impl WebRtcStreamer { let session = Arc::new(session); // Subscribe to video pipeline frames - // Request keyframe after ICE connection is established (via callback) + // Request keyframe after ICE connection is established and on gaps let pipeline_for_callback = pipeline.clone(); let session_id_for_callback = session_id.clone(); + let request_keyframe = Arc::new(move || { + let pipeline = pipeline_for_callback.clone(); + let sid = session_id_for_callback.clone(); + tokio::spawn(async move { + info!( + "Requesting keyframe for session {} after ICE connected", + sid + ); + pipeline.request_keyframe().await; + }); + }); session - .start_from_video_pipeline(pipeline.subscribe(), move || { - // Spawn async task to request keyframe - let pipeline = pipeline_for_callback; - let sid = session_id_for_callback; - tokio::spawn(async move { - info!( - "Requesting keyframe for session {} after ICE connected", - sid - ); - pipeline.request_keyframe().await; - }); - }) + .start_from_video_pipeline(pipeline.subscribe(), request_keyframe) .await; // Start audio if enabled @@ -913,27 +846,7 @@ impl WebRtcStreamer { let video_pipeline = if let Some(ref pipeline) = *self.video_pipeline.read().await { let s = pipeline.stats().await; Some(VideoPipelineStats { - frames_encoded: s.frames_encoded, - frames_dropped: s.frames_dropped, - bytes_encoded: s.bytes_encoded, - keyframes_encoded: s.keyframes_encoded, - avg_encode_time_ms: s.avg_encode_time_ms, current_fps: s.current_fps, - subscribers: s.subscribers, - }) - } else { - None - }; - - // Get audio pipeline stats - let audio_pipeline = if let Some(ref pipeline) = *self.audio_pipeline.read().await { - let stats = pipeline.stats().await; - Some(AudioPipelineStats { - frames_encoded: stats.frames_encoded, - frames_dropped: stats.frames_dropped, - bytes_encoded: stats.bytes_encoded, - avg_encode_time_ms: stats.avg_encode_time_ms, - subscribers: stats.subscribers, }) } else { None @@ -944,7 +857,6 @@ impl WebRtcStreamer { video_codec: format!("{:?}", codec), video_pipeline, audio_enabled: *self.audio_enabled.read().await, - audio_pipeline, } } @@ -984,9 +896,6 @@ impl WebRtcStreamer { if pipeline_running { info!("Restarting video pipeline to apply new bitrate: {}", preset); - // Save video_frame_tx BEFORE stopping pipeline (monitor task will clear it) - let saved_frame_tx = self.video_frame_tx.read().await.clone(); - // Stop existing pipeline if let Some(ref pipeline) = *self.video_pipeline.read().await { pipeline.stop(); @@ -998,46 +907,43 @@ impl WebRtcStreamer { // Clear pipeline reference - will be recreated *self.video_pipeline.write().await = None; - // Recreate pipeline with new config if we have a frame source - if let Some(tx) = saved_frame_tx { - // Get existing sessions that need to be reconnected - let session_ids: Vec = self.sessions.read().await.keys().cloned().collect(); + let has_source = self.capture_device.read().await.is_some(); + if !has_source { + return Ok(()); + } - if !session_ids.is_empty() { - // Restore video_frame_tx before recreating pipeline - *self.video_frame_tx.write().await = Some(tx.clone()); + let session_ids: Vec = self.sessions.read().await.keys().cloned().collect(); + if !session_ids.is_empty() { + let pipeline = self.ensure_video_pipeline().await?; - // Recreate pipeline - let pipeline = self.ensure_video_pipeline(tx).await?; - - // Reconnect all sessions to new pipeline - let sessions = self.sessions.read().await; - for session_id in &session_ids { - if let Some(session) = sessions.get(session_id) { - info!("Reconnecting session {} to new pipeline", session_id); - let pipeline_for_callback = pipeline.clone(); - let sid = session_id.clone(); - session - .start_from_video_pipeline(pipeline.subscribe(), move || { - let pipeline = pipeline_for_callback; - tokio::spawn(async move { - info!( - "Requesting keyframe for session {} after reconnect", - sid - ); - pipeline.request_keyframe().await; - }); - }) - .await; - } + let sessions = self.sessions.read().await; + for session_id in &session_ids { + if let Some(session) = sessions.get(session_id) { + info!("Reconnecting session {} to new pipeline", session_id); + let pipeline_for_callback = pipeline.clone(); + let sid = session_id.clone(); + let request_keyframe = Arc::new(move || { + let pipeline = pipeline_for_callback.clone(); + let sid = sid.clone(); + tokio::spawn(async move { + info!( + "Requesting keyframe for session {} after reconnect", + sid + ); + pipeline.request_keyframe().await; + }); + }); + session + .start_from_video_pipeline(pipeline.subscribe(), request_keyframe) + .await; } - - info!( - "Video pipeline restarted with {}, reconnected {} sessions", - preset, - session_ids.len() - ); } + + info!( + "Video pipeline restarted with {}, reconnected {} sessions", + preset, + session_ids.len() + ); } } else { debug!( @@ -1057,11 +963,11 @@ impl Default for WebRtcStreamer { video_codec: RwLock::new(VideoCodecType::H264), video_pipeline: RwLock::new(None), sessions: Arc::new(RwLock::new(HashMap::new())), - video_frame_tx: RwLock::new(None), + capture_device: RwLock::new(None), audio_enabled: RwLock::new(false), - audio_pipeline: RwLock::new(None), audio_controller: RwLock::new(None), hid_controller: RwLock::new(None), + events: RwLock::new(None), } } } diff --git a/test/bench_kvm.py b/test/bench_kvm.py new file mode 100644 index 00000000..f425052b --- /dev/null +++ b/test/bench_kvm.py @@ -0,0 +1,566 @@ +#!/usr/bin/env python3 +""" +One-KVM benchmark script (Windows-friendly). + +Measures FPS + CPU usage across: +- input pixel formats (capture card formats) +- output codecs (mjpeg/h264/h265/vp8/vp9) +- resolution/FPS matrix +- encoder backends (software/hardware) + +Requirements: + pip install requests websockets playwright + playwright install +""" + +from __future__ import annotations + +import argparse +import asyncio +import csv +import json +import sys +import threading +import time +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple + +import requests +import websockets +from playwright.async_api import async_playwright + + +SESSION_COOKIE = "one_kvm_session" +DEFAULT_MATRIX = [ + (1920, 1080, 30), + (1920, 1080, 60), + (1280, 720, 30), + (1280, 720, 60), +] + + +@dataclass +class Case: + input_format: str + output_codec: str + encoder: Optional[str] + width: int + height: int + fps: int + + +@dataclass +class Result: + input_format: str + output_codec: str + encoder: str + width: int + height: int + fps: int + avg_fps: float + avg_cpu: float + note: str = "" + + +class KvmClient: + def __init__(self, base_url: str, username: str, password: str) -> None: + self.base = base_url.rstrip("/") + self.s = requests.Session() + self.login(username, password) + + def login(self, username: str, password: str) -> None: + r = self.s.post(f"{self.base}/api/auth/login", json={"username": username, "password": password}) + r.raise_for_status() + + def get_cookie(self) -> str: + return self.s.cookies.get(SESSION_COOKIE, "") + + def get_video_config(self) -> Dict: + r = self.s.get(f"{self.base}/api/config/video") + r.raise_for_status() + return r.json() + + def get_stream_config(self) -> Dict: + r = self.s.get(f"{self.base}/api/config/stream") + r.raise_for_status() + return r.json() + + def get_devices(self) -> Dict: + r = self.s.get(f"{self.base}/api/devices") + r.raise_for_status() + return r.json() + + def get_codecs(self) -> Dict: + r = self.s.get(f"{self.base}/api/stream/codecs") + r.raise_for_status() + return r.json() + + def patch_video(self, device: Optional[str], fmt: str, w: int, h: int, fps: int) -> None: + payload: Dict[str, object] = {"format": fmt, "width": w, "height": h, "fps": fps} + if device: + payload["device"] = device + r = self.s.patch(f"{self.base}/api/config/video", json=payload) + r.raise_for_status() + + def patch_stream(self, encoder: Optional[str]) -> None: + if encoder is None: + return + r = self.s.patch(f"{self.base}/api/config/stream", json={"encoder": encoder}) + r.raise_for_status() + + def set_mode(self, mode: str) -> None: + r = self.s.post(f"{self.base}/api/stream/mode", json={"mode": mode}) + r.raise_for_status() + + def get_mode(self) -> Dict: + r = self.s.get(f"{self.base}/api/stream/mode") + r.raise_for_status() + return r.json() + + def wait_mode_ready(self, mode: str, timeout_sec: int = 20) -> None: + deadline = time.time() + timeout_sec + while time.time() < deadline: + data = self.get_mode() + if not data.get("switching") and data.get("mode") == mode: + return + time.sleep(0.5) + raise RuntimeError(f"mode switch timeout: {mode}") + + def start_stream(self) -> None: + r = self.s.post(f"{self.base}/api/stream/start") + r.raise_for_status() + + def stop_stream(self) -> None: + r = self.s.post(f"{self.base}/api/stream/stop") + r.raise_for_status() + + def cpu_sample(self) -> float: + r = self.s.get(f"{self.base}/api/info") + r.raise_for_status() + return float(r.json()["device_info"]["cpu_usage"]) + + def close_webrtc_session(self, session_id: str) -> None: + if not session_id: + return + self.s.post(f"{self.base}/api/webrtc/close", json={"session_id": session_id}) + + +class MjpegStream: + def __init__(self, url: str, cookie: str) -> None: + self._stop = threading.Event() + self._resp = requests.get(url, stream=True, headers={"Cookie": f"{SESSION_COOKIE}={cookie}"}) + self._thread = threading.Thread(target=self._reader, daemon=True) + self._thread.start() + + def _reader(self) -> None: + try: + for chunk in self._resp.iter_content(chunk_size=4096): + if self._stop.is_set(): + break + if not chunk: + time.sleep(0.01) + except Exception: + pass + + def close(self) -> None: + self._stop.set() + try: + self._resp.close() + except Exception: + pass + + +def parse_matrix(values: Optional[List[str]]) -> List[Tuple[int, int, int]]: + if not values: + return DEFAULT_MATRIX + result: List[Tuple[int, int, int]] = [] + for item in values: + # WIDTHxHEIGHT@FPS + part = item.strip().lower() + if "@" not in part or "x" not in part: + raise ValueError(f"invalid matrix item: {item}") + res_part, fps_part = part.split("@", 1) + w_str, h_str = res_part.split("x", 1) + result.append((int(w_str), int(h_str), int(fps_part))) + return result + + +def avg(values: Iterable[float]) -> float: + vals = list(values) + return sum(vals) / len(vals) if vals else 0.0 + + +def normalize_format(fmt: str) -> str: + return fmt.strip().upper() + + +def select_device(devices: Dict, preferred: Optional[str]) -> Optional[Dict]: + video_devices = devices.get("video", []) + if preferred: + for d in video_devices: + if d.get("path") == preferred: + return d + return video_devices[0] if video_devices else None + + +def build_supported_map(device: Dict) -> Dict[str, Dict[Tuple[int, int], List[int]]]: + supported: Dict[str, Dict[Tuple[int, int], List[int]]] = {} + for fmt in device.get("formats", []): + fmt_name = normalize_format(fmt.get("format", "")) + res_map: Dict[Tuple[int, int], List[int]] = {} + for res in fmt.get("resolutions", []): + key = (int(res.get("width", 0)), int(res.get("height", 0))) + fps_list = [int(f) for f in res.get("fps", [])] + res_map[key] = fps_list + supported[fmt_name] = res_map + return supported + + +def is_combo_supported( + supported: Dict[str, Dict[Tuple[int, int], List[int]]], + fmt: str, + width: int, + height: int, + fps: int, +) -> bool: + res_map = supported.get(fmt) + if not res_map: + return False + fps_list = res_map.get((width, height), []) + return fps in fps_list + + +async def mjpeg_sample( + base_url: str, + cookie: str, + client_id: str, + duration_sec: float, + cpu_sample_fn, +) -> Tuple[float, float]: + mjpeg_url = f"{base_url}/api/stream/mjpeg?client_id={client_id}" + stream = MjpegStream(mjpeg_url, cookie) + ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") + "/api/ws" + + fps_samples: List[float] = [] + cpu_samples: List[float] = [] + + # discard first cpu sample (needs delta) + cpu_sample_fn() + + try: + async with websockets.connect(ws_url, extra_headers={"Cookie": f"{SESSION_COOKIE}={cookie}"}) as ws: + start = time.time() + while time.time() - start < duration_sec: + try: + msg = await asyncio.wait_for(ws.recv(), timeout=1.0) + except asyncio.TimeoutError: + msg = None + + if msg: + data = json.loads(msg) + if data.get("type") == "stream.stats_update": + clients = data.get("clients_stat", {}) + if client_id in clients: + fps = float(clients[client_id].get("fps", 0)) + fps_samples.append(fps) + + cpu_samples.append(float(cpu_sample_fn())) + finally: + stream.close() + + return avg(fps_samples), avg(cpu_samples) + + +async def webrtc_sample( + base_url: str, + cookie: str, + duration_sec: float, + cpu_sample_fn, + headless: bool, +) -> Tuple[float, float, str]: + fps_samples: List[float] = [] + cpu_samples: List[float] = [] + session_id = "" + + # discard first cpu sample (needs delta) + cpu_sample_fn() + + async with async_playwright() as p: + browser = await p.chromium.launch(headless=headless) + context = await browser.new_context() + await context.add_cookies([{ + "name": SESSION_COOKIE, + "value": cookie, + "url": base_url, + "path": "/", + }]) + page = await context.new_page() + await page.goto(base_url + "/", wait_until="domcontentloaded") + + await page.evaluate( + """ + async (base) => { + const pc = new RTCPeerConnection(); + pc.addTransceiver('video', { direction: 'recvonly' }); + pc.addTransceiver('audio', { direction: 'recvonly' }); + pc.onicecandidate = async (e) => { + if (e.candidate && window.__sid) { + await fetch(base + "/api/webrtc/ice", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ session_id: window.__sid, candidate: e.candidate }) + }); + } + }; + const offer = await pc.createOffer(); + await pc.setLocalDescription(offer); + const resp = await fetch(base + "/api/webrtc/offer", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ sdp: offer.sdp }) + }); + const ans = await resp.json(); + window.__sid = ans.session_id; + await pc.setRemoteDescription({ type: "answer", sdp: ans.sdp }); + (ans.ice_candidates || []).forEach(c => pc.addIceCandidate(c)); + window.__kvmStats = { pc, lastTs: 0, lastFrames: 0 }; + } + """, + base_url, + ) + + try: + await page.wait_for_function( + "window.__kvmStats && window.__kvmStats.pc && window.__kvmStats.pc.connectionState === 'connected'", + timeout=15000, + ) + except Exception: + pass + + start = time.time() + while time.time() - start < duration_sec: + fps = await page.evaluate( + """ + async () => { + const s = window.__kvmStats; + const report = await s.pc.getStats(); + let fps = 0; + for (const r of report.values()) { + if (r.type === "inbound-rtp" && r.kind === "video") { + if (r.framesPerSecond) { + fps = r.framesPerSecond; + } else if (r.framesDecoded && s.lastTs) { + const dt = (r.timestamp - s.lastTs) / 1000.0; + const df = r.framesDecoded - s.lastFrames; + fps = dt > 0 ? df / dt : 0; + } + s.lastTs = r.timestamp; + s.lastFrames = r.framesDecoded || s.lastFrames; + break; + } + } + return fps; + } + """ + ) + fps_samples.append(float(fps)) + cpu_samples.append(float(cpu_sample_fn())) + await asyncio.sleep(1) + + session_id = await page.evaluate("window.__sid || ''") + await browser.close() + + return avg(fps_samples), avg(cpu_samples), session_id + + +async def run_case( + client: KvmClient, + device: Optional[str], + case: Case, + duration_sec: float, + warmup_sec: float, + headless: bool, +) -> Result: + client.patch_video(device, case.input_format, case.width, case.height, case.fps) + + if case.output_codec != "mjpeg": + client.patch_stream(case.encoder) + + client.set_mode(case.output_codec) + client.wait_mode_ready(case.output_codec) + + client.start_stream() + time.sleep(warmup_sec) + + note = "" + if case.output_codec == "mjpeg": + avg_fps, avg_cpu = await mjpeg_sample( + client.base, + client.get_cookie(), + client_id=f"bench-{int(time.time() * 1000)}", + duration_sec=duration_sec, + cpu_sample_fn=client.cpu_sample, + ) + else: + avg_fps, avg_cpu, session_id = await webrtc_sample( + client.base, + client.get_cookie(), + duration_sec=duration_sec, + cpu_sample_fn=client.cpu_sample, + headless=headless, + ) + if session_id: + client.close_webrtc_session(session_id) + else: + note = "no-session-id" + + client.stop_stream() + + return Result( + input_format=case.input_format, + output_codec=case.output_codec, + encoder=case.encoder or "n/a", + width=case.width, + height=case.height, + fps=case.fps, + avg_fps=avg_fps, + avg_cpu=avg_cpu, + note=note, + ) + + +def write_csv(results: List[Result], path: str) -> None: + with open(path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["input_format", "output_codec", "encoder", "width", "height", "fps", "avg_fps", "avg_cpu", "note"]) + for r in results: + w.writerow([r.input_format, r.output_codec, r.encoder, r.width, r.height, r.fps, f"{r.avg_fps:.2f}", f"{r.avg_cpu:.2f}", r.note]) + + +def write_md(results: List[Result], path: str) -> None: + lines = [ + "| input_format | output_codec | encoder | width | height | fps | avg_fps | avg_cpu | note |", + "|---|---|---|---:|---:|---:|---:|---:|---|", + ] + for r in results: + lines.append( + f"| {r.input_format} | {r.output_codec} | {r.encoder} | {r.width} | {r.height} | {r.fps} | {r.avg_fps:.2f} | {r.avg_cpu:.2f} | {r.note} |" + ) + with open(path, "w", encoding="utf-8") as f: + f.write("\n".join(lines)) + + +def main() -> int: + parser = argparse.ArgumentParser(description="One-KVM benchmark (FPS + CPU)") + parser.add_argument("--base-url", required=True, help="e.g. http://192.168.1.50") + parser.add_argument("--username", required=True) + parser.add_argument("--password", required=True) + parser.add_argument("--device", help="video device path, e.g. /dev/video0") + parser.add_argument("--input-formats", help="comma list, e.g. MJPEG,YUYV,NV12") + parser.add_argument("--output-codecs", help="comma list, e.g. mjpeg,h264,h265,vp8,vp9") + parser.add_argument("--encoder-backends", help="comma list, e.g. software,auto,vaapi,nvenc,qsv,amf,rkmpp,v4l2m2m") + parser.add_argument("--matrix", action="append", help="repeatable WIDTHxHEIGHT@FPS, e.g. 1920x1080@30") + parser.add_argument("--duration", type=float, default=30.0, help="sample duration seconds (default 30)") + parser.add_argument("--warmup", type=float, default=3.0, help="warmup seconds before sampling") + parser.add_argument("--csv", default="bench_results.csv") + parser.add_argument("--md", default="bench_results.md") + parser.add_argument("--headless", action="store_true", help="run browser headless (default: headful)") + + args = parser.parse_args() + + if sys.platform.startswith("win"): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + base_url = args.base_url.strip() + if not base_url.startswith(("http://", "https://")): + base_url = "http://" + base_url + client = KvmClient(base_url, args.username, args.password) + + devices = client.get_devices() + video_cfg = client.get_video_config() + device_path = args.device or video_cfg.get("device") + device_info = select_device(devices, device_path) + if not device_info: + print("No video device found.", file=sys.stderr) + return 2 + device_path = device_info.get("path") + + supported_map = build_supported_map(device_info) + + if args.input_formats: + input_formats = [normalize_format(f) for f in args.input_formats.split(",") if f.strip()] + else: + input_formats = list(supported_map.keys()) + + matrix = parse_matrix(args.matrix) + + codecs_info = client.get_codecs() + available_codecs = {c["id"] for c in codecs_info.get("codecs", []) if c.get("available")} + available_codecs.add("mjpeg") + + if args.output_codecs: + output_codecs = [c.strip().lower() for c in args.output_codecs.split(",") if c.strip()] + else: + output_codecs = sorted(list(available_codecs)) + + if args.encoder_backends: + encoder_backends = [e.strip().lower() for e in args.encoder_backends.split(",") if e.strip()] + else: + encoder_backends = ["software", "auto"] + + cases: List[Case] = [] + for fmt in input_formats: + for (w, h, fps) in matrix: + if not is_combo_supported(supported_map, fmt, w, h, fps): + continue + for codec in output_codecs: + if codec not in available_codecs: + continue + if codec == "mjpeg": + cases.append(Case(fmt, codec, None, w, h, fps)) + else: + for enc in encoder_backends: + cases.append(Case(fmt, codec, enc, w, h, fps)) + + print(f"Total cases: {len(cases)}") + results: List[Result] = [] + + for idx, case in enumerate(cases, 1): + print(f"[{idx}/{len(cases)}] {case.input_format} {case.output_codec} {case.encoder or 'n/a'} {case.width}x{case.height}@{case.fps}") + try: + result = asyncio.run( + run_case( + client, + device=device_path, + case=case, + duration_sec=args.duration, + warmup_sec=args.warmup, + headless=args.headless, + ) + ) + results.append(result) + print(f" -> avg_fps={result.avg_fps:.2f}, avg_cpu={result.avg_cpu:.2f}") + except Exception as exc: + results.append( + Result( + input_format=case.input_format, + output_codec=case.output_codec, + encoder=case.encoder or "n/a", + width=case.width, + height=case.height, + fps=case.fps, + avg_fps=0.0, + avg_cpu=0.0, + note=f"error: {exc}", + ) + ) + print(f" -> error: {exc}") + + write_csv(results, args.csv) + write_md(results, args.md) + print(f"Saved: {args.csv}, {args.md}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/web/index.html b/web/index.html index f4713daf..82971921 100644 --- a/web/index.html +++ b/web/index.html @@ -2,7 +2,7 @@ - + One-KVM diff --git a/web/package-lock.json b/web/package-lock.json index 0287dad9..fbfd1886 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "web", - "version": "0.0.0", + "version": "0.1.4", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "web", - "version": "0.0.0", + "version": "0.1.4", "dependencies": { "@vueuse/core": "^14.1.0", "class-variance-authority": "^0.7.1", @@ -1368,6 +1368,7 @@ "integrity": "sha512-GNWcUTRBgIRJD5zj+Tq0fKOJ5XZajIiBroOF0yvj2bSU1WvNdYS/dn9UxwsujGW4JX06dnHyjV2y9rRaybH0iQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "undici-types": "~7.16.0" } @@ -1782,6 +1783,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -2448,6 +2450,7 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -2495,6 +2498,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -2787,7 +2791,8 @@ "resolved": "https://registry.npmmirror.com/tailwindcss/-/tailwindcss-4.1.17.tgz", "integrity": "sha512-j9Ee2YjuQqYT9bbRTfTZht9W/ytp5H+jJpZKiYdP/bpnXARAuELt9ofP0lPnmHjbga7SNQIxdTAXCmtKVYjN+Q==", "dev": true, - "license": "MIT" + "license": "MIT", + "peer": true }, "node_modules/tapable": { "version": "2.3.0", @@ -2841,6 +2846,7 @@ "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "devOptional": true, "license": "Apache-2.0", + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -2906,6 +2912,7 @@ "integrity": "sha512-tI2l/nFHC5rLh7+5+o7QjKjSR04ivXDF4jcgV0f/bTQ+OJiITy5S6gaynVsEM+7RqzufMnVbIon6Sr5x1SDYaQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.5.0", @@ -2987,6 +2994,7 @@ "resolved": "https://registry.npmmirror.com/vue/-/vue-3.5.25.tgz", "integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==", "license": "MIT", + "peer": true, "dependencies": { "@vue/compiler-dom": "3.5.25", "@vue/compiler-sfc": "3.5.25", diff --git a/web/package.json b/web/package.json index 8f474a59..31096b3b 100644 --- a/web/package.json +++ b/web/package.json @@ -1,7 +1,7 @@ { "name": "web", "private": true, - "version": "0.1.1", + "version": "0.1.4", "type": "module", "scripts": { "dev": "vite", diff --git a/web/public/favicon.png b/web/public/favicon.png new file mode 100644 index 00000000..50433288 Binary files /dev/null and b/web/public/favicon.png differ diff --git a/web/src/api/config.ts b/web/src/api/config.ts index e7ac249f..8edef184 100644 --- a/web/src/api/config.ts +++ b/web/src/api/config.ts @@ -7,6 +7,8 @@ import type { AppConfig, + AuthConfig, + AuthConfigUpdate, VideoConfig, VideoConfigUpdate, StreamConfigResponse, @@ -41,6 +43,24 @@ export const configApi = { getAll: () => request('/config'), } +// ===== Auth 配置 API ===== +export const authConfigApi = { + /** + * 获取认证配置 + */ + get: () => request('/config/auth'), + + /** + * 更新认证配置 + * @param config 要更新的字段 + */ + update: (config: AuthConfigUpdate) => + request('/config/auth', { + method: 'PATCH', + body: JSON.stringify(config), + }), +} + // ===== Video 配置 API ===== export const videoConfigApi = { /** @@ -316,6 +336,7 @@ export const rustdeskConfigApi = { export interface WebConfig { http_port: number https_port: number + bind_addresses: string[] bind_address: string https_enabled: boolean } @@ -324,6 +345,7 @@ export interface WebConfig { export interface WebConfigUpdate { http_port?: number https_port?: number + bind_addresses?: string[] bind_address?: string https_enabled?: boolean } diff --git a/web/src/api/index.ts b/web/src/api/index.ts index f962bdb8..5cc503ae 100644 --- a/web/src/api/index.ts +++ b/web/src/api/index.ts @@ -16,7 +16,19 @@ export const authApi = { request<{ success: boolean }>('/auth/logout', { method: 'POST' }), check: () => - request<{ authenticated: boolean; user?: string; is_admin?: boolean }>('/auth/check'), + request<{ authenticated: boolean; user?: string }>('/auth/check'), + + changePassword: (currentPassword: string, newPassword: string) => + request<{ success: boolean }>('/auth/password', { + method: 'POST', + body: JSON.stringify({ current_password: currentPassword, new_password: newPassword }), + }), + + changeUsername: (username: string, currentPassword: string) => + request<{ success: boolean }>('/auth/username', { + method: 'POST', + body: JSON.stringify({ username, current_password: currentPassword }), + }), } // System API @@ -72,6 +84,7 @@ export const systemApi = { hid_ch9329_port?: string hid_ch9329_baudrate?: number hid_otg_udc?: string + hid_otg_profile?: string encoder_backend?: string audio_device?: string ttyd_enabled?: boolean @@ -121,8 +134,6 @@ export const streamApi = { clients: number target_fps: number fps: number - frames_captured: number - frames_dropped: number }>('/stream/status'), start: () => @@ -200,7 +211,7 @@ export const webrtcApi = { }), getIceServers: () => - request<{ ice_servers: IceServerConfig[] }>('/webrtc/ice-servers'), + request<{ ice_servers: IceServerConfig[]; mdns_mode: string }>('/webrtc/ice-servers'), } // HID API @@ -516,6 +527,7 @@ export const configApi = { // 导出新的域分离配置 API export { + authConfigApi, videoConfigApi, streamConfigApi, hidConfigApi, @@ -535,6 +547,8 @@ export { // 导出生成的类型 export type { AppConfig, + AuthConfig, + AuthConfigUpdate, VideoConfig, VideoConfigUpdate, StreamConfig, @@ -588,53 +602,4 @@ export const audioApi = { }), } -// User Management API -export interface User { - id: string - username: string - role: 'admin' | 'user' - created_at: string -} - -interface UserApiResponse { - id: string - username: string - is_admin: boolean - created_at: string -} - -export const userApi = { - list: async () => { - const rawUsers = await request('/users') - const users: User[] = rawUsers.map(u => ({ - id: u.id, - username: u.username, - role: u.is_admin ? 'admin' : 'user', - created_at: u.created_at, - })) - return { success: true, users } - }, - - create: (username: string, password: string, role: 'admin' | 'user' = 'user') => - request('/users', { - method: 'POST', - body: JSON.stringify({ username, password, is_admin: role === 'admin' }), - }), - - update: (id: string, data: { username?: string; role?: 'admin' | 'user' }) => - request<{ success: boolean }>(`/users/${id}`, { - method: 'PUT', - body: JSON.stringify({ username: data.username, is_admin: data.role === 'admin' }), - }), - - delete: (id: string) => - request<{ success: boolean }>(`/users/${id}`, { method: 'DELETE' }), - - changePassword: (id: string, newPassword: string, currentPassword?: string) => - request<{ success: boolean }>(`/users/${id}/password`, { - method: 'POST', - body: JSON.stringify({ new_password: newPassword, current_password: currentPassword }), - }), -} - export { ApiError } diff --git a/web/src/api/request.ts b/web/src/api/request.ts index d0ebcce0..7341b4a8 100644 --- a/web/src/api/request.ts +++ b/web/src/api/request.ts @@ -81,7 +81,12 @@ export async function request( // Handle HTTP errors (in case backend returns non-2xx) if (!response.ok) { const message = getErrorMessage(data, `HTTP ${response.status}`) - if (toastOnError && shouldShowToast(toastKey)) { + const normalized = message.toLowerCase() + const isNotAuthenticated = normalized.includes('not authenticated') + const isSessionExpired = normalized.includes('session expired') + const isLoggedInElsewhere = normalized.includes('logged in elsewhere') + const isAuthIssue = response.status === 401 && (isNotAuthenticated || isSessionExpired || isLoggedInElsewhere) + if (toastOnError && shouldShowToast(toastKey) && !isAuthIssue) { toast.error(t('api.operationFailed'), { description: message, duration: 4000, @@ -130,4 +135,3 @@ export async function request( throw new ApiError(0, t('api.networkError')) } } - diff --git a/web/src/components/ActionBar.vue b/web/src/components/ActionBar.vue index d600f52b..2d2422e5 100644 --- a/web/src/components/ActionBar.vue +++ b/web/src/components/ActionBar.vue @@ -52,14 +52,13 @@ const overflowMenuOpen = ref(false) const hidBackend = computed(() => (systemStore.hid?.backend ?? '').toLowerCase()) const isCh9329Backend = computed(() => hidBackend.value.includes('ch9329')) const showMsd = computed(() => { - return props.isAdmin && !isCh9329Backend.value + return !!systemStore.msd?.available && !isCh9329Backend.value }) const props = defineProps<{ mouseMode?: 'absolute' | 'relative' videoMode?: VideoMode ttydRunning?: boolean - isAdmin?: boolean }>() const emit = defineEmits<{ @@ -86,25 +85,23 @@ const extensionOpen = ref(false)